File size: 5,614 Bytes
e054d0c
 
 
 
 
 
 
 
 
 
 
 
 
 
8f68d0a
e054d0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
适配器工厂模块

根据 DEPLOYMENT_MODE 配置自动选择本地或服务器适配器。

Example:
    >>> from app.core.adapters import get_database_adapter, get_storage_adapter
    >>> db = get_database_adapter()
    >>> storage = get_storage_adapter()
"""

from functools import lru_cache
from typing import TYPE_CHECKING

from project_config import settings

if TYPE_CHECKING:
    from ..adapters.base import (
        DatabaseAdapter,
        ProgressAdapter,
        StorageAdapter,
        TaskQueueAdapter,
    )


class AdapterFactory:
    """
    适配器工厂
    
    根据 DEPLOYMENT_MODE 配置创建对应的适配器实例。
    
    - local 模式: SQLite + 本地文件系统 + asyncio.subprocess
    - server 模式: PostgreSQL + S3/MinIO + Celery (Phase 2)
    """
    
    @staticmethod
    def create_storage_adapter() -> "StorageAdapter":
        """
        创建存储适配器
        
        Returns:
            本地模式返回 LocalStorageAdapter
            服务器模式返回 S3StorageAdapter (Phase 2)
        """
        if settings.DEPLOYMENT_MODE == "local":
            from ..adapters.local.storage import LocalStorageAdapter
            return LocalStorageAdapter(base_path=str(settings.DATA_DIR / "files"))
        else:
            # Phase 2: 服务器模式
            raise NotImplementedError("Server mode storage adapter not implemented yet")
    
    @staticmethod
    def create_database_adapter() -> "DatabaseAdapter":
        """
        创建数据库适配器
        
        Returns:
            本地模式返回 SQLiteAdapter
            服务器模式返回 PostgreSQLAdapter (Phase 2)
        """
        if settings.DEPLOYMENT_MODE == "local":
            from ..adapters.local.database import SQLiteAdapter
            return SQLiteAdapter(db_path=str(settings.SQLITE_PATH))
        else:
            # Phase 2: 服务器模式
            raise NotImplementedError("Server mode database adapter not implemented yet")
    
    @staticmethod
    def create_task_queue_adapter(database_adapter: "DatabaseAdapter" = None) -> "TaskQueueAdapter":
        """
        创建任务队列适配器
        
        Args:
            database_adapter: 数据库适配器,用于同步任务状态到 tasks 表。
                              如果未提供,将自动创建一个实例。
        
        Returns:
            本地模式返回 AsyncTrainingManager
            服务器模式返回 CeleryTaskQueueAdapter (Phase 2)
        """
        if settings.DEPLOYMENT_MODE == "local":
            from ..adapters.local.task_queue import AsyncTrainingManager
            from ..adapters.local.database import SQLiteAdapter
            
            # 如果未提供 database_adapter,创建一个新实例用于状态同步
            if database_adapter is None:
                database_adapter = SQLiteAdapter(db_path=str(settings.SQLITE_PATH))
            
            return AsyncTrainingManager(
                db_path=str(settings.SQLITE_PATH),
                database_adapter=database_adapter
            )
        else:
            # Phase 2: 服务器模式
            raise NotImplementedError("Server mode task queue adapter not implemented yet")
    
    @staticmethod
    def create_progress_adapter() -> "ProgressAdapter":
        """
        创建进度管理适配器
        
        Returns:
            本地模式返回 LocalProgressAdapter
            服务器模式返回 RedisProgressAdapter (Phase 2)
        """
        if settings.DEPLOYMENT_MODE == "local":
            from ..adapters.local.progress import LocalProgressAdapter
            return LocalProgressAdapter()
        else:
            # Phase 2: 服务器模式
            raise NotImplementedError("Server mode progress adapter not implemented yet")


# ============================================================
# 全局单例获取函数(使用 lru_cache 缓存实例)
# ============================================================

@lru_cache()
def get_storage_adapter() -> "StorageAdapter":
    """
    获取存储适配器单例
    
    Returns:
        StorageAdapter 实例
    """
    return AdapterFactory.create_storage_adapter()


@lru_cache()
def get_database_adapter() -> "DatabaseAdapter":
    """
    获取数据库适配器单例
    
    Returns:
        DatabaseAdapter 实例
    """
    return AdapterFactory.create_database_adapter()


@lru_cache()
def get_task_queue_adapter() -> "TaskQueueAdapter":
    """
    获取任务队列适配器单例
    
    使用共享的数据库适配器实例来确保状态同步一致性。
    
    Returns:
        TaskQueueAdapter 实例
    """
    # 使用共享的数据库适配器实例
    db_adapter = get_database_adapter()
    return AdapterFactory.create_task_queue_adapter(database_adapter=db_adapter)


@lru_cache()
def get_progress_adapter() -> "ProgressAdapter":
    """
    获取进度管理适配器单例
    
    Returns:
        ProgressAdapter 实例
    """
    return AdapterFactory.create_progress_adapter()


# ============================================================
# 便捷别名(向后兼容)
# ============================================================

# 延迟初始化的全局变量,在首次访问时创建
# 注意:这些是函数调用的结果,不是直接的实例引用
# 如果需要在模块级别使用,请调用对应的 get_*_adapter() 函数

__all__ = [
    "AdapterFactory",
    "get_storage_adapter",
    "get_database_adapter",
    "get_task_queue_adapter",
    "get_progress_adapter",
]