"""
Adapter abstract base class module.

Defines the abstract interfaces for the task-queue, storage, database
and other adapters.
"""
|
|
|
|
|
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, List, Optional, AsyncGenerator, Any

if TYPE_CHECKING:
    from ..models.domain import Task
|
|
|
|
|
|
|
|
class TaskQueueAdapter(ABC):
    """
    Abstract base class for task-queue adapters.

    Defines the common task-queue interface, supporting two kinds of
    implementation: local (asyncio.subprocess) and server (Celery).

    Example:
        >>> adapter = AsyncTrainingManager(db_path="./data/tasks.db")
        >>> job_id = await adapter.enqueue("task-123", {"exp_name": "test"})
        >>> status = await adapter.get_status(job_id)
        >>> async for progress in adapter.subscribe_progress("task-123"):
        ...     print(progress)
    """

    @abstractmethod
    async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str:
        """
        Add a task to the queue.

        Args:
            task_id: Unique task identifier.
            config: Task configuration dict containing all parameters
                required for training.
            priority: Task priority ("low", "normal", "high").

        Returns:
            job_id: ID of the job inside the queue.

        Raises:
            ValueError: If the configuration is invalid.
        """
        pass

    @abstractmethod
    async def get_status(self, job_id: str) -> Dict:
        """
        Get the status of a task.

        Args:
            job_id: Job ID.

        Returns:
            Status dict containing:
            - status: task state (queued, running, completed, failed, cancelled)
            - progress: progress (0.0-1.0)
            - current_stage: name of the current stage
            - message: status message
            - error_message: error details (if the task failed)
        """
        pass

    @abstractmethod
    async def cancel(self, job_id: str) -> bool:
        """
        Cancel a task.

        Args:
            job_id: Job ID.

        Returns:
            Whether the cancellation succeeded.
        """
        pass

    @abstractmethod
    async def subscribe_progress(self, task_id: str) -> AsyncGenerator[Dict, None]:
        """
        Subscribe to task progress (used for SSE streams).

        Args:
            task_id: Task ID.

        Yields:
            Progress dict containing:
            - type: message type ("progress", "log", "heartbeat")
            - stage: current stage
            - progress: progress value
            - message: progress message
            - status: state (running, completed, failed, cancelled)

        Note:
            The generator ends automatically once status reaches a
            terminal state.
        """
        # The unreachable ``yield`` turns this into an async *generator*
        # function, so the abstract declaration agrees with its
        # ``AsyncGenerator`` annotation and with ``async for`` callers.
        # Without it, calling the method returns a coroutine resolving to None.
        return
        yield {}  # pragma: no cover
|
|
|
|
|
|
|
class ProgressAdapter(ABC):
    """
    Abstract base class for progress-management adapters.

    Used to update and subscribe to task progress. Supports two kinds of
    implementation: local (in-memory queue) and server (Redis Pub/Sub).
    """

    @abstractmethod
    async def update_progress(self, task_id: str, progress: Dict) -> None:
        """
        Update the progress of a task.

        Args:
            task_id: Task ID.
            progress: Progress information dict.
        """
        pass

    @abstractmethod
    async def get_progress(self, task_id: str) -> Optional[Dict]:
        """
        Get the current progress of a task.

        Args:
            task_id: Task ID.

        Returns:
            The latest progress information, or None if there is none.
        """
        pass

    @abstractmethod
    async def subscribe(self, task_id: str) -> AsyncGenerator[Dict, None]:
        """
        Subscribe to progress updates.

        Args:
            task_id: Task ID.

        Yields:
            Progress information dicts.
        """
        # The unreachable ``yield`` turns this into an async *generator*
        # function, so the abstract declaration agrees with its
        # ``AsyncGenerator`` annotation. Without it, calling the method
        # returns a coroutine resolving to None, which cannot be iterated.
        return
        yield {}  # pragma: no cover
|
|
|
|
|
|
|
class StorageAdapter(ABC):
    """
    Abstract base class for storage adapters.

    Defines the common file-storage interface, supporting two kinds of
    implementation: the local file system and object storage (S3/MinIO).

    Example:
        >>> adapter = LocalStorageAdapter(base_path="./data/files")
        >>> file_id = await adapter.upload_file(data, "audio.wav", {"purpose": "training"})
        >>> content = await adapter.download_file(file_id)
        >>> await adapter.delete_file(file_id)
    """

    @abstractmethod
    async def upload_file(
        self,
        file_data: bytes,
        filename: str,
        metadata: Dict[str, Any]
    ) -> str:
        """
        Upload a file.

        Args:
            file_data: Binary file content.
            filename: Original file name.
            metadata: File metadata, which may contain:
                - content_type: MIME type
                - purpose: file purpose (training, reference, output)
                - any other custom fields

        Returns:
            file_id: Unique file identifier.

        Raises:
            IOError: If storing the file fails.
        """
        pass

    @abstractmethod
    async def download_file(self, file_id: str) -> bytes:
        """
        Download a file.

        Args:
            file_id: Unique file identifier.

        Returns:
            Binary file content.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        pass

    @abstractmethod
    async def delete_file(self, file_id: str) -> bool:
        """
        Delete a file.

        Args:
            file_id: Unique file identifier.

        Returns:
            Whether the deletion succeeded.
        """
        pass

    @abstractmethod
    async def get_file_metadata(self, file_id: str) -> Optional[Dict[str, Any]]:
        """
        Get the metadata of a file.

        Args:
            file_id: Unique file identifier.

        Returns:
            File metadata dict containing:
            - id: file ID
            - filename: original file name
            - content_type: MIME type
            - size_bytes: file size
            - purpose: file purpose
            - uploaded_at: upload time
            - audio files additionally include: duration_seconds, sample_rate

            Returns None if the file does not exist.
        """
        pass

    @abstractmethod
    async def list_files(
        self,
        purpose: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """
        List files.

        Args:
            purpose: Filter by purpose (training, reference, output).
            limit: Maximum number of entries to return.
            offset: Offset of the first entry.

        Returns:
            List of file metadata dicts.
        """
        pass

    @abstractmethod
    async def file_exists(self, file_id: str) -> bool:
        """
        Check whether a file exists.

        Args:
            file_id: Unique file identifier.

        Returns:
            Whether the file exists.
        """
        pass
|
|
|
|
|
|
|
class DatabaseAdapter(ABC):
    """
    Abstract base class for database adapters.

    Defines the common data-persistence interface, supporting two kinds of
    implementation: SQLite and PostgreSQL.

    Manages the following entities:
    - Task: Quick Mode one-click training task
    - Experiment: Advanced Mode experiment
    - Stage: the individual stages of an experiment
    - File: uploaded file records (optional, used together with StorageAdapter)

    Example:
        >>> adapter = SQLiteAdapter(db_path="./data/app.db")
        >>> task = await adapter.create_task(task_data)
        >>> task = await adapter.get_task(task_id)
        >>> await adapter.update_task(task_id, {"status": "completed"})
    """

    # ------------------------------------------------------------------
    # Task operations
    # ------------------------------------------------------------------

    @abstractmethod
    async def create_task(self, task: "Task") -> "Task":
        """
        Create a task.

        Args:
            task: Task domain-model instance.

        Returns:
            The created Task instance (including generated fields such as
            created_at).
        """
        pass

    @abstractmethod
    async def get_task(self, task_id: str) -> Optional["Task"]:
        """
        Get a task.

        Args:
            task_id: Unique task identifier.

        Returns:
            The Task instance, or None if it does not exist.
        """
        pass

    @abstractmethod
    async def update_task(self, task_id: str, updates: Dict[str, Any]) -> Optional["Task"]:
        """
        Update a task.

        Args:
            task_id: Unique task identifier.
            updates: Dict of the fields to update.

        Returns:
            The updated Task instance, or None if it does not exist.
        """
        pass

    @abstractmethod
    async def list_tasks(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List["Task"]:
        """
        Query the list of tasks.

        Args:
            status: Filter by status.
            limit: Maximum number of entries to return.
            offset: Offset of the first entry.

        Returns:
            List of Task instances.
        """
        pass

    @abstractmethod
    async def delete_task(self, task_id: str) -> bool:
        """
        Delete a task.

        Args:
            task_id: Unique task identifier.

        Returns:
            Whether the deletion succeeded.
        """
        pass

    @abstractmethod
    async def count_tasks(self, status: Optional[str] = None) -> int:
        """
        Count tasks.

        Args:
            status: Filter by status.

        Returns:
            Number of tasks.
        """
        pass

    @abstractmethod
    async def get_task_by_exp_name(self, exp_name: str) -> Optional["Task"]:
        """
        Get a task by its experiment name.

        Used to check whether an exp_name already exists.

        Args:
            exp_name: Experiment name.

        Returns:
            The Task instance, or None if it does not exist.
        """
        pass

    # ------------------------------------------------------------------
    # Experiment operations
    # ------------------------------------------------------------------

    @abstractmethod
    async def create_experiment(self, experiment: Dict[str, Any]) -> Dict[str, Any]:
        """
        Create an experiment.

        Args:
            experiment: Experiment data dict.

        Returns:
            The created experiment data.
        """
        pass

    @abstractmethod
    async def get_experiment(self, exp_id: str) -> Optional[Dict[str, Any]]:
        """
        Get an experiment.

        Args:
            exp_id: Unique experiment identifier.

        Returns:
            Experiment data dict, or None if it does not exist.
        """
        pass

    @abstractmethod
    async def update_experiment(
        self,
        exp_id: str,
        updates: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Update an experiment.

        Args:
            exp_id: Unique experiment identifier.
            updates: Dict of the fields to update.

        Returns:
            The updated experiment data, or None if it does not exist.
        """
        pass

    @abstractmethod
    async def list_experiments(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """
        Query the list of experiments.

        Args:
            status: Filter by status.
            limit: Maximum number of entries to return.
            offset: Offset of the first entry.

        Returns:
            List of experiment data dicts.
        """
        pass

    @abstractmethod
    async def delete_experiment(self, exp_id: str) -> bool:
        """
        Delete an experiment.

        Args:
            exp_id: Unique experiment identifier.

        Returns:
            Whether the deletion succeeded.
        """
        pass

    # ------------------------------------------------------------------
    # Stage operations
    # ------------------------------------------------------------------

    @abstractmethod
    async def update_stage(
        self,
        exp_id: str,
        stage_type: str,
        updates: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Update the state of a stage.

        Args:
            exp_id: Unique experiment identifier.
            stage_type: Stage type.
            updates: Dict of the fields to update.

        Returns:
            The updated stage data, or None if it does not exist.
        """
        pass

    @abstractmethod
    async def get_stage(
        self,
        exp_id: str,
        stage_type: str
    ) -> Optional[Dict[str, Any]]:
        """
        Get the state of a stage.

        Args:
            exp_id: Unique experiment identifier.
            stage_type: Stage type.

        Returns:
            Stage data dict, or None if it does not exist.
        """
        pass

    @abstractmethod
    async def get_all_stages(self, exp_id: str) -> List[Dict[str, Any]]:
        """
        Get the state of every stage of an experiment.

        Args:
            exp_id: Unique experiment identifier.

        Returns:
            List of stage data dicts.
        """
        pass

    # ------------------------------------------------------------------
    # File record operations
    # ------------------------------------------------------------------

    @abstractmethod
    async def create_file_record(self, file_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Create a file record.

        Args:
            file_data: File metadata.

        Returns:
            The created file record.
        """
        pass

    @abstractmethod
    async def get_file_record(self, file_id: str) -> Optional[Dict[str, Any]]:
        """
        Get a file record.

        Args:
            file_id: Unique file identifier.

        Returns:
            The file record, or None if it does not exist.
        """
        pass

    @abstractmethod
    async def delete_file_record(self, file_id: str) -> bool:
        """
        Delete a file record.

        Args:
            file_id: Unique file identifier.

        Returns:
            Whether the deletion succeeded.
        """
        pass

    @abstractmethod
    async def list_file_records(
        self,
        purpose: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """
        Query the list of file records.

        Args:
            purpose: Filter by purpose.
            limit: Maximum number of entries to return.
            offset: Offset of the first entry.

        Returns:
            List of file records.
        """
        pass
|
|