|
|
"""
Adapter factory module.

Selects the local or server adapter implementations according to the
``DEPLOYMENT_MODE`` setting. Only the local mode is implemented so far;
any other mode raises ``NotImplementedError``.

Example:
    >>> from app.core.adapters import get_database_adapter, get_storage_adapter
    >>> db = get_database_adapter()
    >>> storage = get_storage_adapter()
"""
|
|
|
|
|
from functools import lru_cache |
|
|
from typing import TYPE_CHECKING |
|
|
|
|
|
from project_config import settings |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from ..adapters.base import ( |
|
|
DatabaseAdapter, |
|
|
ProgressAdapter, |
|
|
StorageAdapter, |
|
|
TaskQueueAdapter, |
|
|
) |
|
|
|
|
|
|
|
|
class AdapterFactory:
    """
    Adapter factory.

    Creates adapter instances according to ``settings.DEPLOYMENT_MODE``.

    - local mode: SQLite + local file system + asyncio.subprocess
    - server mode: PostgreSQL + S3/MinIO + Celery (Phase 2, not implemented)

    Any mode other than ``"local"`` currently raises ``NotImplementedError``.
    """

    @staticmethod
    def _is_local_mode() -> bool:
        """Return True when the configured deployment mode is ``"local"``."""
        return settings.DEPLOYMENT_MODE == "local"

    @staticmethod
    def _not_implemented(adapter_kind: str) -> NotImplementedError:
        """
        Build the error raised when a non-local deployment mode is requested.

        The configured mode is included in the message so that a misconfigured
        value (for example ``"Local"`` or a typo) is not mistaken for the
        genuinely unimplemented server mode.

        Args:
            adapter_kind: Human-readable adapter name used in the message.

        Returns:
            A ``NotImplementedError`` instance for the caller to raise.
        """
        return NotImplementedError(
            f"Server mode {adapter_kind} adapter not implemented yet "
            f"(DEPLOYMENT_MODE={settings.DEPLOYMENT_MODE!r})"
        )

    @staticmethod
    def create_storage_adapter() -> "StorageAdapter":
        """
        Create a storage adapter.

        Returns:
            In local mode, a ``LocalStorageAdapter`` rooted at
            ``settings.DATA_DIR / "files"``.
            Server mode (S3StorageAdapter) is planned for Phase 2.

        Raises:
            NotImplementedError: If the deployment mode is not ``"local"``.
        """
        if not AdapterFactory._is_local_mode():
            raise AdapterFactory._not_implemented("storage")

        # Imported lazily so only the backend for the active mode is loaded.
        from ..adapters.local.storage import LocalStorageAdapter

        return LocalStorageAdapter(base_path=str(settings.DATA_DIR / "files"))

    @staticmethod
    def create_database_adapter() -> "DatabaseAdapter":
        """
        Create a database adapter.

        Returns:
            In local mode, a ``SQLiteAdapter`` for ``settings.SQLITE_PATH``.
            Server mode (PostgreSQLAdapter) is planned for Phase 2.

        Raises:
            NotImplementedError: If the deployment mode is not ``"local"``.
        """
        if not AdapterFactory._is_local_mode():
            raise AdapterFactory._not_implemented("database")

        from ..adapters.local.database import SQLiteAdapter

        return SQLiteAdapter(db_path=str(settings.SQLITE_PATH))

    @staticmethod
    def create_task_queue_adapter(
        database_adapter: "DatabaseAdapter | None" = None,
    ) -> "TaskQueueAdapter":
        """
        Create a task queue adapter.

        Args:
            database_adapter: Database adapter used to sync task status to the
                ``tasks`` table. If omitted, a new one is created via
                ``create_database_adapter()``.

        Returns:
            In local mode, an ``AsyncTrainingManager`` bound to
            ``settings.SQLITE_PATH``.
            Server mode (CeleryTaskQueueAdapter) is planned for Phase 2.

        Raises:
            NotImplementedError: If the deployment mode is not ``"local"``.
        """
        if not AdapterFactory._is_local_mode():
            raise AdapterFactory._not_implemented("task queue")

        from ..adapters.local.task_queue import AsyncTrainingManager

        if database_adapter is None:
            # Reuse the factory method rather than duplicating the
            # SQLiteAdapter construction details here.
            database_adapter = AdapterFactory.create_database_adapter()

        return AsyncTrainingManager(
            db_path=str(settings.SQLITE_PATH),
            database_adapter=database_adapter,
        )

    @staticmethod
    def create_progress_adapter() -> "ProgressAdapter":
        """
        Create a progress-tracking adapter.

        Returns:
            In local mode, a ``LocalProgressAdapter``.
            Server mode (RedisProgressAdapter) is planned for Phase 2.

        Raises:
            NotImplementedError: If the deployment mode is not ``"local"``.
        """
        if not AdapterFactory._is_local_mode():
            raise AdapterFactory._not_implemented("progress")

        from ..adapters.local.progress import LocalProgressAdapter

        return LocalProgressAdapter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache()
def get_storage_adapter() -> "StorageAdapter":
    """
    Return the process-wide storage adapter.

    The first successful call builds the adapter through ``AdapterFactory``;
    ``lru_cache`` hands back that same instance on every later call. Call
    ``get_storage_adapter.cache_clear()`` to force a rebuild.

    Returns:
        The cached StorageAdapter instance.
    """
    adapter = AdapterFactory.create_storage_adapter()
    return adapter
|
|
|
|
|
|
|
|
@lru_cache()
def get_database_adapter() -> "DatabaseAdapter":
    """
    Return the process-wide database adapter.

    The first successful call builds the adapter through ``AdapterFactory``;
    ``lru_cache`` hands back that same instance on every later call. Call
    ``get_database_adapter.cache_clear()`` to force a rebuild.

    Returns:
        The cached DatabaseAdapter instance.
    """
    adapter = AdapterFactory.create_database_adapter()
    return adapter
|
|
|
|
|
|
|
|
@lru_cache()
def get_task_queue_adapter() -> "TaskQueueAdapter":
    """
    Return the process-wide task queue adapter.

    The adapter is built with the shared, cached instance from
    ``get_database_adapter()``, so task-status syncing goes through the same
    database adapter as the rest of the application.

    Returns:
        The cached TaskQueueAdapter instance.
    """
    return AdapterFactory.create_task_queue_adapter(
        database_adapter=get_database_adapter(),
    )
|
|
|
|
|
|
|
|
@lru_cache()
def get_progress_adapter() -> "ProgressAdapter":
    """
    Return the process-wide progress-tracking adapter.

    The first successful call builds the adapter through ``AdapterFactory``;
    ``lru_cache`` hands back that same instance on every later call. Call
    ``get_progress_adapter.cache_clear()`` to force a rebuild.

    Returns:
        The cached ProgressAdapter instance.
    """
    adapter = AdapterFactory.create_progress_adapter()
    return adapter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Public API: the factory class plus its cached singleton accessors.
__all__ = [
    "AdapterFactory",
    "get_storage_adapter", "get_database_adapter",
    "get_task_queue_adapter", "get_progress_adapter",
]
|
|
|