File size: 5,614 Bytes
e054d0c 8f68d0a e054d0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
"""
适配器工厂模块
根据 DEPLOYMENT_MODE 配置自动选择本地或服务器适配器。
Example:
>>> from app.core.adapters import get_database_adapter, get_storage_adapter
>>> db = get_database_adapter()
>>> storage = get_storage_adapter()
"""
from functools import lru_cache
from typing import TYPE_CHECKING
from project_config import settings
if TYPE_CHECKING:
from ..adapters.base import (
DatabaseAdapter,
ProgressAdapter,
StorageAdapter,
TaskQueueAdapter,
)
class AdapterFactory:
"""
适配器工厂
根据 DEPLOYMENT_MODE 配置创建对应的适配器实例。
- local 模式: SQLite + 本地文件系统 + asyncio.subprocess
- server 模式: PostgreSQL + S3/MinIO + Celery (Phase 2)
"""
@staticmethod
def create_storage_adapter() -> "StorageAdapter":
"""
创建存储适配器
Returns:
本地模式返回 LocalStorageAdapter
服务器模式返回 S3StorageAdapter (Phase 2)
"""
if settings.DEPLOYMENT_MODE == "local":
from ..adapters.local.storage import LocalStorageAdapter
return LocalStorageAdapter(base_path=str(settings.DATA_DIR / "files"))
else:
# Phase 2: 服务器模式
raise NotImplementedError("Server mode storage adapter not implemented yet")
@staticmethod
def create_database_adapter() -> "DatabaseAdapter":
"""
创建数据库适配器
Returns:
本地模式返回 SQLiteAdapter
服务器模式返回 PostgreSQLAdapter (Phase 2)
"""
if settings.DEPLOYMENT_MODE == "local":
from ..adapters.local.database import SQLiteAdapter
return SQLiteAdapter(db_path=str(settings.SQLITE_PATH))
else:
# Phase 2: 服务器模式
raise NotImplementedError("Server mode database adapter not implemented yet")
@staticmethod
def create_task_queue_adapter(database_adapter: "DatabaseAdapter" = None) -> "TaskQueueAdapter":
"""
创建任务队列适配器
Args:
database_adapter: 数据库适配器,用于同步任务状态到 tasks 表。
如果未提供,将自动创建一个实例。
Returns:
本地模式返回 AsyncTrainingManager
服务器模式返回 CeleryTaskQueueAdapter (Phase 2)
"""
if settings.DEPLOYMENT_MODE == "local":
from ..adapters.local.task_queue import AsyncTrainingManager
from ..adapters.local.database import SQLiteAdapter
# 如果未提供 database_adapter,创建一个新实例用于状态同步
if database_adapter is None:
database_adapter = SQLiteAdapter(db_path=str(settings.SQLITE_PATH))
return AsyncTrainingManager(
db_path=str(settings.SQLITE_PATH),
database_adapter=database_adapter
)
else:
# Phase 2: 服务器模式
raise NotImplementedError("Server mode task queue adapter not implemented yet")
@staticmethod
def create_progress_adapter() -> "ProgressAdapter":
"""
创建进度管理适配器
Returns:
本地模式返回 LocalProgressAdapter
服务器模式返回 RedisProgressAdapter (Phase 2)
"""
if settings.DEPLOYMENT_MODE == "local":
from ..adapters.local.progress import LocalProgressAdapter
return LocalProgressAdapter()
else:
# Phase 2: 服务器模式
raise NotImplementedError("Server mode progress adapter not implemented yet")
# ============================================================
# 全局单例获取函数(使用 lru_cache 缓存实例)
# ============================================================
@lru_cache()
def get_storage_adapter() -> "StorageAdapter":
"""
获取存储适配器单例
Returns:
StorageAdapter 实例
"""
return AdapterFactory.create_storage_adapter()
@lru_cache()
def get_database_adapter() -> "DatabaseAdapter":
"""
获取数据库适配器单例
Returns:
DatabaseAdapter 实例
"""
return AdapterFactory.create_database_adapter()
@lru_cache()
def get_task_queue_adapter() -> "TaskQueueAdapter":
"""
获取任务队列适配器单例
使用共享的数据库适配器实例来确保状态同步一致性。
Returns:
TaskQueueAdapter 实例
"""
# 使用共享的数据库适配器实例
db_adapter = get_database_adapter()
return AdapterFactory.create_task_queue_adapter(database_adapter=db_adapter)
@lru_cache()
def get_progress_adapter() -> "ProgressAdapter":
"""
获取进度管理适配器单例
Returns:
ProgressAdapter 实例
"""
return AdapterFactory.create_progress_adapter()
# ============================================================
# 便捷别名(向后兼容)
# ============================================================
# 延迟初始化的全局变量,在首次访问时创建
# 注意:这些是函数调用的结果,不是直接的实例引用
# 如果需要在模块级别使用,请调用对应的 get_*_adapter() 函数
__all__ = [
"AdapterFactory",
"get_storage_adapter",
"get_database_adapter",
"get_task_queue_adapter",
"get_progress_adapter",
]
|