""" 适配器抽象基类模块 定义任务队列、存储、数据库等适配器的抽象接口 """ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Dict, List, Optional, AsyncGenerator, Any if TYPE_CHECKING: from ..models.domain import Task class TaskQueueAdapter(ABC): """ 任务队列适配器抽象基类 定义任务队列的通用接口,支持本地(asyncio.subprocess)和 服务器(Celery)两种实现方式。 Example: >>> adapter = AsyncTrainingManager(db_path="./data/tasks.db") >>> job_id = await adapter.enqueue("task-123", {"exp_name": "test"}) >>> status = await adapter.get_status(job_id) >>> async for progress in adapter.subscribe_progress("task-123"): ... print(progress) """ @abstractmethod async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str: """ 将任务加入队列 Args: task_id: 任务唯一标识 config: 任务配置字典,包含训练所需的所有参数 priority: 任务优先级 ("low", "normal", "high") Returns: job_id: 队列中的作业ID Raises: ValueError: 配置无效时抛出 """ pass @abstractmethod async def get_status(self, job_id: str) -> Dict: """ 获取任务状态 Args: job_id: 作业ID Returns: 状态字典,包含: - status: 任务状态 (queued, running, completed, failed, cancelled) - progress: 进度 (0.0-1.0) - current_stage: 当前阶段名称 - message: 状态消息 - error_message: 错误信息(如果失败) """ pass @abstractmethod async def cancel(self, job_id: str) -> bool: """ 取消任务 Args: job_id: 作业ID Returns: 是否成功取消 """ pass @abstractmethod async def subscribe_progress(self, task_id: str) -> AsyncGenerator[Dict, None]: """ 订阅任务进度(用于 SSE 流) Args: task_id: 任务ID Yields: 进度信息字典,包含: - type: 消息类型 ("progress", "log", "heartbeat") - stage: 当前阶段 - progress: 进度值 - message: 进度消息 - status: 状态 (running, completed, failed, cancelled) Note: 当 status 为终态时,生成器会自动结束 """ pass class ProgressAdapter(ABC): """ 进度管理适配器抽象基类 用于更新和订阅任务进度,支持本地(内存队列)和 服务器(Redis Pub/Sub)两种实现。 """ @abstractmethod async def update_progress(self, task_id: str, progress: Dict) -> None: """ 更新进度 Args: task_id: 任务ID progress: 进度信息字典 """ pass @abstractmethod async def get_progress(self, task_id: str) -> Optional[Dict]: """ 获取当前进度 Args: task_id: 任务ID Returns: 最新进度信息,不存在则返回 None """ pass @abstractmethod async def subscribe(self, task_id: str) -> AsyncGenerator[Dict, None]: """ 订阅进度更新 Args: task_id: 任务ID Yields: 进度信息字典 """ pass class StorageAdapter(ABC): """ 存储适配器抽象基类 定义文件存储的通用接口,支持本地文件系统和 对象存储(S3/MinIO)两种实现方式。 Example: >>> adapter = LocalStorageAdapter(base_path="./data/files") >>> file_id = await adapter.upload_file(data, "audio.wav", {"purpose": "training"}) >>> content = await adapter.download_file(file_id) >>> await adapter.delete_file(file_id) """ @abstractmethod async def upload_file( self, file_data: bytes, filename: str, metadata: Dict[str, Any] ) -> str: """ 上传文件 Args: file_data: 文件二进制数据 filename: 原始文件名 metadata: 文件元数据,可包含: - content_type: MIME类型 - purpose: 文件用途 (training, reference, output) - 其他自定义字段 Returns: file_id: 文件唯一标识 Raises: IOError: 存储失败时抛出 """ pass @abstractmethod async def download_file(self, file_id: str) -> bytes: """ 下载文件 Args: file_id: 文件唯一标识 Returns: 文件二进制数据 Raises: FileNotFoundError: 文件不存在时抛出 """ pass @abstractmethod async def delete_file(self, file_id: str) -> bool: """ 删除文件 Args: file_id: 文件唯一标识 Returns: 是否成功删除 """ pass @abstractmethod async def get_file_metadata(self, file_id: str) -> Optional[Dict[str, Any]]: """ 获取文件元数据 Args: file_id: 文件唯一标识 Returns: 文件元数据字典,包含: - id: 文件ID - filename: 原始文件名 - content_type: MIME类型 - size_bytes: 文件大小 - purpose: 文件用途 - uploaded_at: 上传时间 - 音频文件额外包含: duration_seconds, sample_rate 文件不存在时返回 None """ pass @abstractmethod async def list_files( self, purpose: Optional[str] = None, limit: int = 50, offset: int = 0 ) -> List[Dict[str, Any]]: """ 列出文件 Args: purpose: 按用途筛选 (training, reference, output) limit: 返回数量限制 offset: 偏移量 Returns: 文件元数据列表 """ pass @abstractmethod async def file_exists(self, file_id: str) -> bool: """ 检查文件是否存在 Args: file_id: 文件唯一标识 Returns: 文件是否存在 """ pass class DatabaseAdapter(ABC): """ 数据库适配器抽象基类 定义数据持久化的通用接口,支持 SQLite 和 PostgreSQL 两种实现方式。 管理以下实体: - Task: Quick Mode 一键训练任务 - Experiment: Advanced Mode 实验 - Stage: 实验中的各个阶段 - File: 上传的文件记录(可选,与StorageAdapter配合) Example: >>> adapter = SQLiteAdapter(db_path="./data/app.db") >>> task = await adapter.create_task(task_data) >>> task = await adapter.get_task(task_id) >>> await adapter.update_task(task_id, {"status": "completed"}) """ # ============================================================ # Task CRUD (Quick Mode) # ============================================================ @abstractmethod async def create_task(self, task: "Task") -> "Task": """ 创建任务 Args: task: Task 领域模型实例 Returns: 创建后的 Task 实例(包含生成的字段如 created_at) """ pass @abstractmethod async def get_task(self, task_id: str) -> Optional["Task"]: """ 获取任务 Args: task_id: 任务唯一标识 Returns: Task 实例,不存在则返回 None """ pass @abstractmethod async def update_task(self, task_id: str, updates: Dict[str, Any]) -> Optional["Task"]: """ 更新任务 Args: task_id: 任务唯一标识 updates: 要更新的字段字典 Returns: 更新后的 Task 实例,不存在则返回 None """ pass @abstractmethod async def list_tasks( self, status: Optional[str] = None, limit: int = 50, offset: int = 0 ) -> List["Task"]: """ 查询任务列表 Args: status: 按状态筛选 limit: 返回数量限制 offset: 偏移量 Returns: Task 实例列表 """ pass @abstractmethod async def delete_task(self, task_id: str) -> bool: """ 删除任务 Args: task_id: 任务唯一标识 Returns: 是否成功删除 """ pass @abstractmethod async def count_tasks(self, status: Optional[str] = None) -> int: """ 统计任务数量 Args: status: 按状态筛选 Returns: 任务数量 """ pass @abstractmethod async def get_task_by_exp_name(self, exp_name: str) -> Optional["Task"]: """ 根据实验名称获取任务 用于检查 exp_name 是否已存在。 Args: exp_name: 实验名称 Returns: Task 实例,不存在则返回 None """ pass # ============================================================ # Experiment CRUD (Advanced Mode) # ============================================================ @abstractmethod async def create_experiment(self, experiment: Dict[str, Any]) -> Dict[str, Any]: """ 创建实验 Args: experiment: 实验数据字典 Returns: 创建后的实验数据 """ pass @abstractmethod async def get_experiment(self, exp_id: str) -> Optional[Dict[str, Any]]: """ 获取实验 Args: exp_id: 实验唯一标识 Returns: 实验数据字典,不存在则返回 None """ pass @abstractmethod async def update_experiment( self, exp_id: str, updates: Dict[str, Any] ) -> Optional[Dict[str, Any]]: """ 更新实验 Args: exp_id: 实验唯一标识 updates: 要更新的字段字典 Returns: 更新后的实验数据,不存在则返回 None """ pass @abstractmethod async def list_experiments( self, status: Optional[str] = None, limit: int = 50, offset: int = 0 ) -> List[Dict[str, Any]]: """ 查询实验列表 Args: status: 按状态筛选 limit: 返回数量限制 offset: 偏移量 Returns: 实验数据列表 """ pass @abstractmethod async def delete_experiment(self, exp_id: str) -> bool: """ 删除实验 Args: exp_id: 实验唯一标识 Returns: 是否成功删除 """ pass # ============================================================ # Stage 操作 (Advanced Mode) # ============================================================ @abstractmethod async def update_stage( self, exp_id: str, stage_type: str, updates: Dict[str, Any] ) -> Optional[Dict[str, Any]]: """ 更新阶段状态 Args: exp_id: 实验唯一标识 stage_type: 阶段类型 updates: 要更新的字段字典 Returns: 更新后的阶段数据,不存在则返回 None """ pass @abstractmethod async def get_stage( self, exp_id: str, stage_type: str ) -> Optional[Dict[str, Any]]: """ 获取阶段状态 Args: exp_id: 实验唯一标识 stage_type: 阶段类型 Returns: 阶段数据字典,不存在则返回 None """ pass @abstractmethod async def get_all_stages(self, exp_id: str) -> List[Dict[str, Any]]: """ 获取实验的所有阶段状态 Args: exp_id: 实验唯一标识 Returns: 阶段数据列表 """ pass # ============================================================ # File 记录 (可选,与 StorageAdapter 配合) # ============================================================ @abstractmethod async def create_file_record(self, file_data: Dict[str, Any]) -> Dict[str, Any]: """ 创建文件记录 Args: file_data: 文件元数据 Returns: 创建后的文件记录 """ pass @abstractmethod async def get_file_record(self, file_id: str) -> Optional[Dict[str, Any]]: """ 获取文件记录 Args: file_id: 文件唯一标识 Returns: 文件记录,不存在则返回 None """ pass @abstractmethod async def delete_file_record(self, file_id: str) -> bool: """ 删除文件记录 Args: file_id: 文件唯一标识 Returns: 是否成功删除 """ pass @abstractmethod async def list_file_records( self, purpose: Optional[str] = None, limit: int = 50, offset: int = 0 ) -> List[Dict[str, Any]]: """ 查询文件记录列表 Args: purpose: 按用途筛选 limit: 返回数量限制 offset: 偏移量 Returns: 文件记录列表 """ pass