liumaolin
feat(api): implement local training MVP with adapter pattern
e054d0c
"""
适配器抽象基类模块
定义任务队列、存储、数据库等适配器的抽象接口
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, List, Optional, AsyncGenerator, Any
if TYPE_CHECKING:
from ..models.domain import Task
class TaskQueueAdapter(ABC):
"""
任务队列适配器抽象基类
定义任务队列的通用接口,支持本地(asyncio.subprocess)和
服务器(Celery)两种实现方式。
Example:
>>> adapter = AsyncTrainingManager(db_path="./data/tasks.db")
>>> job_id = await adapter.enqueue("task-123", {"exp_name": "test"})
>>> status = await adapter.get_status(job_id)
>>> async for progress in adapter.subscribe_progress("task-123"):
... print(progress)
"""
@abstractmethod
async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str:
"""
将任务加入队列
Args:
task_id: 任务唯一标识
config: 任务配置字典,包含训练所需的所有参数
priority: 任务优先级 ("low", "normal", "high")
Returns:
job_id: 队列中的作业ID
Raises:
ValueError: 配置无效时抛出
"""
pass
@abstractmethod
async def get_status(self, job_id: str) -> Dict:
"""
获取任务状态
Args:
job_id: 作业ID
Returns:
状态字典,包含:
- status: 任务状态 (queued, running, completed, failed, cancelled)
- progress: 进度 (0.0-1.0)
- current_stage: 当前阶段名称
- message: 状态消息
- error_message: 错误信息(如果失败)
"""
pass
@abstractmethod
async def cancel(self, job_id: str) -> bool:
"""
取消任务
Args:
job_id: 作业ID
Returns:
是否成功取消
"""
pass
@abstractmethod
async def subscribe_progress(self, task_id: str) -> AsyncGenerator[Dict, None]:
"""
订阅任务进度(用于 SSE 流)
Args:
task_id: 任务ID
Yields:
进度信息字典,包含:
- type: 消息类型 ("progress", "log", "heartbeat")
- stage: 当前阶段
- progress: 进度值
- message: 进度消息
- status: 状态 (running, completed, failed, cancelled)
Note:
当 status 为终态时,生成器会自动结束
"""
pass
class ProgressAdapter(ABC):
"""
进度管理适配器抽象基类
用于更新和订阅任务进度,支持本地(内存队列)和
服务器(Redis Pub/Sub)两种实现。
"""
@abstractmethod
async def update_progress(self, task_id: str, progress: Dict) -> None:
"""
更新进度
Args:
task_id: 任务ID
progress: 进度信息字典
"""
pass
@abstractmethod
async def get_progress(self, task_id: str) -> Optional[Dict]:
"""
获取当前进度
Args:
task_id: 任务ID
Returns:
最新进度信息,不存在则返回 None
"""
pass
@abstractmethod
async def subscribe(self, task_id: str) -> AsyncGenerator[Dict, None]:
"""
订阅进度更新
Args:
task_id: 任务ID
Yields:
进度信息字典
"""
pass
class StorageAdapter(ABC):
"""
存储适配器抽象基类
定义文件存储的通用接口,支持本地文件系统和
对象存储(S3/MinIO)两种实现方式。
Example:
>>> adapter = LocalStorageAdapter(base_path="./data/files")
>>> file_id = await adapter.upload_file(data, "audio.wav", {"purpose": "training"})
>>> content = await adapter.download_file(file_id)
>>> await adapter.delete_file(file_id)
"""
@abstractmethod
async def upload_file(
self,
file_data: bytes,
filename: str,
metadata: Dict[str, Any]
) -> str:
"""
上传文件
Args:
file_data: 文件二进制数据
filename: 原始文件名
metadata: 文件元数据,可包含:
- content_type: MIME类型
- purpose: 文件用途 (training, reference, output)
- 其他自定义字段
Returns:
file_id: 文件唯一标识
Raises:
IOError: 存储失败时抛出
"""
pass
@abstractmethod
async def download_file(self, file_id: str) -> bytes:
"""
下载文件
Args:
file_id: 文件唯一标识
Returns:
文件二进制数据
Raises:
FileNotFoundError: 文件不存在时抛出
"""
pass
@abstractmethod
async def delete_file(self, file_id: str) -> bool:
"""
删除文件
Args:
file_id: 文件唯一标识
Returns:
是否成功删除
"""
pass
@abstractmethod
async def get_file_metadata(self, file_id: str) -> Optional[Dict[str, Any]]:
"""
获取文件元数据
Args:
file_id: 文件唯一标识
Returns:
文件元数据字典,包含:
- id: 文件ID
- filename: 原始文件名
- content_type: MIME类型
- size_bytes: 文件大小
- purpose: 文件用途
- uploaded_at: 上传时间
- 音频文件额外包含: duration_seconds, sample_rate
文件不存在时返回 None
"""
pass
@abstractmethod
async def list_files(
self,
purpose: Optional[str] = None,
limit: int = 50,
offset: int = 0
) -> List[Dict[str, Any]]:
"""
列出文件
Args:
purpose: 按用途筛选 (training, reference, output)
limit: 返回数量限制
offset: 偏移量
Returns:
文件元数据列表
"""
pass
@abstractmethod
async def file_exists(self, file_id: str) -> bool:
"""
检查文件是否存在
Args:
file_id: 文件唯一标识
Returns:
文件是否存在
"""
pass
class DatabaseAdapter(ABC):
"""
数据库适配器抽象基类
定义数据持久化的通用接口,支持 SQLite 和
PostgreSQL 两种实现方式。
管理以下实体:
- Task: Quick Mode 一键训练任务
- Experiment: Advanced Mode 实验
- Stage: 实验中的各个阶段
- File: 上传的文件记录(可选,与StorageAdapter配合)
Example:
>>> adapter = SQLiteAdapter(db_path="./data/app.db")
>>> task = await adapter.create_task(task_data)
>>> task = await adapter.get_task(task_id)
>>> await adapter.update_task(task_id, {"status": "completed"})
"""
# ============================================================
# Task CRUD (Quick Mode)
# ============================================================
@abstractmethod
async def create_task(self, task: "Task") -> "Task":
"""
创建任务
Args:
task: Task 领域模型实例
Returns:
创建后的 Task 实例(包含生成的字段如 created_at)
"""
pass
@abstractmethod
async def get_task(self, task_id: str) -> Optional["Task"]:
"""
获取任务
Args:
task_id: 任务唯一标识
Returns:
Task 实例,不存在则返回 None
"""
pass
@abstractmethod
async def update_task(self, task_id: str, updates: Dict[str, Any]) -> Optional["Task"]:
"""
更新任务
Args:
task_id: 任务唯一标识
updates: 要更新的字段字典
Returns:
更新后的 Task 实例,不存在则返回 None
"""
pass
@abstractmethod
async def list_tasks(
self,
status: Optional[str] = None,
limit: int = 50,
offset: int = 0
) -> List["Task"]:
"""
查询任务列表
Args:
status: 按状态筛选
limit: 返回数量限制
offset: 偏移量
Returns:
Task 实例列表
"""
pass
@abstractmethod
async def delete_task(self, task_id: str) -> bool:
"""
删除任务
Args:
task_id: 任务唯一标识
Returns:
是否成功删除
"""
pass
@abstractmethod
async def count_tasks(self, status: Optional[str] = None) -> int:
"""
统计任务数量
Args:
status: 按状态筛选
Returns:
任务数量
"""
pass
@abstractmethod
async def get_task_by_exp_name(self, exp_name: str) -> Optional["Task"]:
"""
根据实验名称获取任务
用于检查 exp_name 是否已存在。
Args:
exp_name: 实验名称
Returns:
Task 实例,不存在则返回 None
"""
pass
# ============================================================
# Experiment CRUD (Advanced Mode)
# ============================================================
@abstractmethod
async def create_experiment(self, experiment: Dict[str, Any]) -> Dict[str, Any]:
"""
创建实验
Args:
experiment: 实验数据字典
Returns:
创建后的实验数据
"""
pass
@abstractmethod
async def get_experiment(self, exp_id: str) -> Optional[Dict[str, Any]]:
"""
获取实验
Args:
exp_id: 实验唯一标识
Returns:
实验数据字典,不存在则返回 None
"""
pass
@abstractmethod
async def update_experiment(
self,
exp_id: str,
updates: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""
更新实验
Args:
exp_id: 实验唯一标识
updates: 要更新的字段字典
Returns:
更新后的实验数据,不存在则返回 None
"""
pass
@abstractmethod
async def list_experiments(
self,
status: Optional[str] = None,
limit: int = 50,
offset: int = 0
) -> List[Dict[str, Any]]:
"""
查询实验列表
Args:
status: 按状态筛选
limit: 返回数量限制
offset: 偏移量
Returns:
实验数据列表
"""
pass
@abstractmethod
async def delete_experiment(self, exp_id: str) -> bool:
"""
删除实验
Args:
exp_id: 实验唯一标识
Returns:
是否成功删除
"""
pass
# ============================================================
# Stage 操作 (Advanced Mode)
# ============================================================
@abstractmethod
async def update_stage(
self,
exp_id: str,
stage_type: str,
updates: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""
更新阶段状态
Args:
exp_id: 实验唯一标识
stage_type: 阶段类型
updates: 要更新的字段字典
Returns:
更新后的阶段数据,不存在则返回 None
"""
pass
@abstractmethod
async def get_stage(
self,
exp_id: str,
stage_type: str
) -> Optional[Dict[str, Any]]:
"""
获取阶段状态
Args:
exp_id: 实验唯一标识
stage_type: 阶段类型
Returns:
阶段数据字典,不存在则返回 None
"""
pass
@abstractmethod
async def get_all_stages(self, exp_id: str) -> List[Dict[str, Any]]:
"""
获取实验的所有阶段状态
Args:
exp_id: 实验唯一标识
Returns:
阶段数据列表
"""
pass
# ============================================================
# File 记录 (可选,与 StorageAdapter 配合)
# ============================================================
@abstractmethod
async def create_file_record(self, file_data: Dict[str, Any]) -> Dict[str, Any]:
"""
创建文件记录
Args:
file_data: 文件元数据
Returns:
创建后的文件记录
"""
pass
@abstractmethod
async def get_file_record(self, file_id: str) -> Optional[Dict[str, Any]]:
"""
获取文件记录
Args:
file_id: 文件唯一标识
Returns:
文件记录,不存在则返回 None
"""
pass
@abstractmethod
async def delete_file_record(self, file_id: str) -> bool:
"""
删除文件记录
Args:
file_id: 文件唯一标识
Returns:
是否成功删除
"""
pass
@abstractmethod
async def list_file_records(
self,
purpose: Optional[str] = None,
limit: int = 50,
offset: int = 0
) -> List[Dict[str, Any]]:
"""
查询文件记录列表
Args:
purpose: 按用途筛选
limit: 返回数量限制
offset: 偏移量
Returns:
文件记录列表
"""
pass