# GPT-SoVITS Voice Training HTTP API Service: Architecture Design

> **About this document**: This is the complete architecture design document for the API service. It contains both the design specification and reference implementation code.

## Implementation Progress Overview

| Module | Status | Notes |
|------|------|------|
| **Architecture design** | ✅ Done | Dual-mode API design (Quick Mode + Advanced Mode) |
| **Pydantic schemas** | ✅ Implemented | `app/models/schemas/`: task.py, experiment.py, file.py, common.py |
| **Database schema** | ✅ Designed | SQLite/PostgreSQL table structure |
| **Adapter base classes** | ✅ Implemented | `TaskQueueAdapter`, `ProgressAdapter`, `StorageAdapter`, `DatabaseAdapter` |
| **AsyncTrainingManager** | ✅ Implemented | Complete local task queue implementation |
| **Configuration** | ✅ Implemented | `app/core/config.py` |
| **Domain models** | ✅ Implemented | `Task`, `TaskStatus`, `ProgressInfo` |
| **Pipeline script** | ✅ Implemented | `app/scripts/run_pipeline.py` |
| **Storage adapter** | ✅ Implemented | `app/adapters/local/storage.py`: LocalStorageAdapter |
| **Database adapter** | ✅ Implemented | `app/adapters/local/database.py`: SQLiteAdapter |
| **Progress adapter** | ✅ Implemented | `app/adapters/local/progress.py`: LocalProgressAdapter |
| **Adapter factory** | ✅ Implemented | `app/core/adapters.py`: AdapterFactory |
| **API endpoints** | ✅ Implemented | `app/api/v1/endpoints/`: tasks, experiments, files, stages |
| **Service layer** | ✅ Implemented | `app/services/`: TaskService, ExperimentService, FileService |
| **FastAPI entry point** | ✅ Implemented | `app/main.py`: application entry point and lifecycle management |

---
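## 1. Architecture Overview

### 1.1 Comparison of the Two Deployment Scenarios

| Dimension | Local training on macOS | Server-side training on Linux |
|------|--------------|------------------|
| **Users** | Individual developers, small-scale training | Production, multi-user, large-scale training |
| **Concurrency** | Single user, serial tasks | Multiple users, concurrent tasks |
| **Resource management** | Simple (single-machine GPU) | Complex (multi-GPU, distributed) |
| **Persistence** | Lightweight (SQLite/files) | Heavyweight (PostgreSQL/distributed storage) |
| **Task queue** | Simple queue (in-memory/SQLite) | Distributed queue (Celery + Redis) |
| **API complexity** | Simplified | Full |

### 1.1.1 Run Modes for Local Training on macOS

Local training on macOS can run in three ways. Choose the task management approach that fits the final delivery form:

| Run mode | Description | How it starts | Recommended task management |
|----------|------|----------|-------------|
| **Development** | Run the Python script directly | `python main.py` / `uvicorn` | asyncio.subprocess ⭐ |
| **PyInstaller bundle** | Packaged as a standalone executable | `./app` (single executable) | asyncio.subprocess ⭐ |
| **Electron integration** | Runs as an Electron child process | Electron spawns the Python process | asyncio.subprocess ⭐ |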
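#### ⚠️ Special Considerations for PyInstaller + Electron

When the training project is bundled with PyInstaller and embedded in an Electron application, **Huey (a lightweight Python task queue) is not a good fit**, for the following reasons:

1. **Conflicting multi-process architecture**: Huey requires a separate `huey_consumer` process.
2. **Complex process lifecycle**: Electron would have to manage several Python child processes.
3. **Harder packaging**: PyInstaller must bundle every dependency correctly.

**Recommended approach**: use **`asyncio.subprocess`** (see Section 7.1). The training job already runs as a subprocess, so no additional task queue is needed.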
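To make the recommendation concrete, here is a minimal sketch of launching a child process with `asyncio.create_subprocess_exec` and reading its output line by line without blocking the event loop. The function name `run_and_collect` and the inline child script are illustrative assumptions, not code from the project; the real training script and its output format are not shown here.

```python
import asyncio
import sys
from typing import List


async def run_and_collect(cmd: List[str]) -> List[str]:
    """Start a child process and collect its stdout lines as they arrive."""
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,  # merge stderr into stdout
    )
    assert proc.stdout is not None
    lines: List[str] = []
    # The StreamReader yields one line at a time; other coroutines keep
    # running while this loop waits for the next line.
    async for raw in proc.stdout:
        lines.append(raw.decode().rstrip())
    await proc.wait()
    return lines


def demo() -> List[str]:
    # Hypothetical stand-in for the training script: prints two progress lines.
    child = "print('step 1/2'); print('step 2/2')"
    return asyncio.run(run_and_collect([sys.executable, "-c", child]))
```

Because the child is an ordinary operating-system process, nothing beyond the Python interpreter has to be bundled or supervised, which is the property that matters for the PyInstaller and Electron cases.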
### 1.2 Unified Architecture Design Principle

**Core idea**: use the adapter pattern to keep the API layer and the business logic layer identical across deployments, and switch the underlying storage and task execution through adapters.

```
┌─────────────────────────────────────────────────────┐
│             Unified API Layer (FastAPI)             │
│  /api/v1/tasks, /api/v1/experiments, /files, etc.   │
└────────────────────┬────────────────────────────────┘
                     │
┌────────────────────▼────────────────────────────────┐
│               Service Layer (Unified)               │
│  TaskService, ExperimentService, FileService, etc.  │
└────────┬───────────────────────────────┬────────────┘
         │                               │
         │        Adapter Pattern        │
         │                               │
    ┌────▼─────┐                   ┌─────▼──────┐
    │  Local   │                   │   Server   │
    │ Adapter  │                   │  Adapter   │
    └────┬─────┘                   └─────┬──────┘
         │                               │
    ┌────▼─────────────┐        ┌────────▼────────────┐
    │ Local Backend    │        │ Server Backend      │
    │ - SQLite         │        │ - PostgreSQL        │
    │ - asyncio.subproc│        │ - Celery+Redis      │
    │ - Local FS       │        │ - S3/MinIO          │
    └──────────────────┘        └─────────────────────┘
```

---
## 2. Technology Stack Comparison

### 2.1 Local Training on macOS

```yaml
Web framework: FastAPI
Database: SQLite (aiosqlite)
Task management: asyncio.subprocess (recommended); the training script itself runs as a subprocess
File storage: local filesystem
Progress push: SSE (Server-Sent Events)
Cache: in-memory (lru_cache / cachetools)
Logging: Loguru
Configuration: YAML / .env file
```

**Pros**:
- No extra services required (Redis, PostgreSQL)
- Simple deployment, one-command startup
- Well suited to individual use

**Cons**:
- No horizontal scaling
- Single point of failure
- Limited task concurrency

### 2.2 Server-Side Training on Linux

```yaml
Web framework: FastAPI
Database: PostgreSQL + Alembic (schema migrations)
Task queue: Celery + Redis
File storage: MinIO / S3
Progress push: SSE + Redis Pub/Sub
Cache: Redis
Logging: Loguru + ELK Stack (optional)
Monitoring: Prometheus + Grafana
Configuration: environment variables + Consul/etcd (optional)
```

**Pros**:
- High concurrency and high availability
- Horizontal scaling
- Full monitoring and alerting

**Cons**:
- Complex deployment
- Depends on additional services

---
## 3. Unified Architecture Design

### 3.1 Project Structure

> **Legend**: ✅ implemented | [TODO] designed, not yet developed | [Phase 2] server mode, to be implemented later

```
api_server/
├── app/
│   ├── __init__.py                    # ✅ implemented
│   │
│   ├── api/                           # ✅ API routing layer
│   │   ├── __init__.py                # ✅ implemented
│   │   ├── deps.py                    # ✅ implemented - dependency injection
│   │   └── v1/
│   │       ├── __init__.py            # ✅ implemented
│   │       ├── endpoints/
│   │       │   ├── __init__.py        # ✅ implemented
│   │       │   ├── tasks.py           # ✅ implemented - Quick Mode task management
│   │       │   ├── experiments.py     # ✅ implemented - Advanced Mode experiment management
│   │       │   ├── stages.py          # ✅ implemented - stage parameter templates
│   │       │   ├── files.py           # ✅ implemented - file management
│   │       │   ├── models.py          # [TODO] model management
│   │       │   └── inference.py       # [TODO] inference endpoints
│   │       └── router.py              # ✅ implemented - route registration
│   │
│   ├── core/
│   │   ├── __init__.py                # ✅ implemented
│   │   ├── config.py                  # ✅ implemented - Settings, path constants, get_pythonpath()
│   │   ├── adapters.py                # ✅ implemented - adapter factory
│   │   └── enums.py                   # [TODO] enum definitions
│   │
│   ├── services/                      # ✅ business logic layer
│   │   ├── __init__.py                # ✅ implemented
│   │   ├── task_service.py            # ✅ implemented - Quick Mode task service
│   │   ├── experiment_service.py      # ✅ implemented - Advanced Mode experiment service
│   │   ├── file_service.py            # ✅ implemented - file management service
│   │   ├── pipeline_service.py        # [TODO]
│   │   └── progress_service.py        # [TODO]
│   │
│   ├── adapters/                      # adapter layer
│   │   ├── __init__.py                # ✅ implemented
│   │   ├── base.py                    # ✅ implemented - TaskQueueAdapter, ProgressAdapter, StorageAdapter, DatabaseAdapter
│   │   ├── local/
│   │   │   ├── __init__.py            # ✅ implemented
│   │   │   ├── task_queue.py          # ✅ implemented - AsyncTrainingManager (complete)
│   │   │   ├── storage.py             # ✅ implemented - LocalStorageAdapter
│   │   │   ├── database.py            # ✅ implemented - SQLiteAdapter
│   │   │   └── progress.py            # ✅ implemented - LocalProgressAdapter
│   │   └── server/                    # [Phase 2]
│   │       ├── storage.py             # S3/MinIO adapter
│   │       ├── task_queue.py          # Celery adapter
│   │       └── database.py            # PostgreSQL adapter
│   │
│   ├── models/
│   │   ├── __init__.py                # ✅ implemented
│   │   ├── domain.py                  # ✅ implemented - Task, TaskStatus, ProgressInfo
│   │   └── schemas/                   # ✅ implemented - Pydantic models
│   │       ├── __init__.py            # ✅ implemented - schema module exports
│   │       ├── common.py              # ✅ implemented - common response models
│   │       ├── task.py                # ✅ implemented - Quick Mode task models
│   │       ├── experiment.py          # ✅ implemented - Advanced Mode experiment/stage models
│   │       ├── file.py                # ✅ implemented - file upload/download models
│   │       └── inference.py           # [TODO] inference models
│   │
│   ├── scripts/
│   │   ├── __init__.py                # ✅ implemented
│   │   └── run_pipeline.py            # ✅ implemented - pipeline subprocess runner
│   │
│   ├── workers/                       # [TODO] task executors
│   │   ├── local_worker.py            # local executor
│   │   └── celery_worker.py           # [Phase 2] Celery executor
│   │
│   └── main.py                        # ✅ implemented - FastAPI entry point
│
├── data/                              # data directory
│   ├── configs/                       # task configuration files
│   ├── tasks.db                       # SQLite database
│   └── test_config.json               # test configuration
│
├── config/                            # [TODO]
│   ├── local.yaml                     # local configuration
│   └── server.yaml                    # server configuration
│
├── requirements/                      # [TODO]
│   ├── base.txt                       # shared dependencies
│   ├── local.txt                      # extra dependencies for local mode
│   └── server.txt                     # extra dependencies for server mode
│
├── docker-compose.local.yml           # [TODO] local development
├── docker-compose.server.yml          # [Phase 2] server deployment
└── README.md                          # [TODO]
```
### 3.2 Core Adapter Design

#### 3.2.1 Abstract Base Classes ✅ Done

> **Implementation status**: all adapter abstract base classes are implemented in `app/adapters/base.py`:
> - `TaskQueueAdapter`: task queue interface
> - `ProgressAdapter`: progress management interface
> - `StorageAdapter`: file storage interface
> - `DatabaseAdapter`: database operations interface

```python
# app/adapters/base.py (part 1: task queue and progress interfaces)
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Dict, List, Optional


class TaskQueueAdapter(ABC):
    """
    Abstract base class for task queue adapters.

    Defines the common task queue interface. Two implementations are
    supported: local (asyncio.subprocess) and server (Celery).
    """

    @abstractmethod
    async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str:
        """Add a task to the queue and return its job_id."""
        pass

    @abstractmethod
    async def get_status(self, job_id: str) -> Dict:
        """Return the task status."""
        pass

    @abstractmethod
    async def cancel(self, job_id: str) -> bool:
        """Cancel the task."""
        pass

    @abstractmethod
    async def subscribe_progress(self, task_id: str) -> AsyncGenerator[Dict, None]:
        """Subscribe to task progress (SSE stream)."""
        pass


class ProgressAdapter(ABC):
    """
    Abstract base class for progress adapters.

    Updates and publishes task progress. Two implementations are
    supported: local (in-memory queue) and server (Redis Pub/Sub).
    """

    @abstractmethod
    async def update_progress(self, task_id: str, progress: Dict) -> None:
        """Update the progress."""
        pass

    @abstractmethod
    async def get_progress(self, task_id: str) -> Optional[Dict]:
        """Return the current progress."""
        pass

    @abstractmethod
    async def subscribe(self, task_id: str) -> AsyncGenerator[Dict, None]:
        """Subscribe to progress updates."""
        pass
```

```python
# app/adapters/base.py (part 2: storage and database interfaces)
from app.models.domain import Task


class StorageAdapter(ABC):
    """Abstract base class for storage adapters."""

    @abstractmethod
    async def upload_file(self, file_data: bytes, filename: str, metadata: Dict) -> str:
        """Upload a file and return its file ID."""
        pass

    @abstractmethod
    async def download_file(self, file_id: str) -> bytes:
        """Download a file."""
        pass

    @abstractmethod
    async def delete_file(self, file_id: str) -> bool:
        """Delete a file."""
        pass

    @abstractmethod
    async def get_file_metadata(self, file_id: str) -> Dict:
        """Return the file metadata."""
        pass


class DatabaseAdapter(ABC):
    """Abstract base class for database adapters."""

    @abstractmethod
    async def create_task(self, task: Task) -> Task:
        """Create a task."""
        pass

    @abstractmethod
    async def get_task(self, task_id: str) -> Optional[Task]:
        """Fetch a task."""
        pass

    @abstractmethod
    async def update_task(self, task_id: str, updates: Dict) -> Task:
        """Update a task."""
        pass

    @abstractmethod
    async def list_tasks(self, filters: Dict, limit: int, offset: int) -> List[Task]:
        """Query the task list."""
        pass

    @abstractmethod
    async def delete_task(self, task_id: str) -> bool:
        """Delete a task."""
        pass
```
#### 3.2.2 Local Adapter Implementations

##### AsyncTrainingManager ✅ Fully Implemented

> **Implementation file**: `app/adapters/local/task_queue.py`
>
> This is the core component of local mode. It implements:
> - Task enqueueing and asynchronous execution
> - Subprocess management (`asyncio.create_subprocess_exec`)
> - Progress parsing and SSE stream push
> - Task state persistence (SQLite)
> - Task cancellation and recovery

```python
# app/adapters/local/task_queue.py - ✅ fully implemented (method summary)
import asyncio
from typing import AsyncGenerator, Dict, List, Optional

from app.adapters.base import TaskQueueAdapter
from app.core.config import settings


class AsyncTrainingManager(TaskQueueAdapter):
    """
    Asynchronous task manager built on asyncio.subprocess.

    Characteristics:
    1. Starts training subprocesses with asyncio.create_subprocess_exec()
    2. Fully non-blocking, which fits FastAPI's async model
    3. Persists task state in SQLite so tasks can be recovered after a restart
    4. Parses subprocess output in real time to derive progress
    """

    def __init__(self, db_path: Optional[str] = None, max_concurrent: int = 1):
        self.db_path = db_path or str(settings.SQLITE_PATH)
        self.max_concurrent = max_concurrent
        self.running_processes: Dict[str, asyncio.subprocess.Process] = {}
        self.progress_channels: Dict[str, asyncio.Queue] = {}
        self._init_db_sync()

    async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str:
        """Add a task to the queue and start it asynchronously."""
        # ... see the source file for the full implementation

    async def get_status(self, job_id: str) -> Dict:
        """Return the task status."""
        # ... see the source file for the full implementation

    async def get_status_by_task_id(self, task_id: str) -> Dict:
        """Return the task status, looked up by task_id."""
        # ... see the source file for the full implementation

    async def cancel(self, job_id: str) -> bool:
        """Cancel a task (graceful termination first, then forced termination)."""
        # ... see the source file for the full implementation

    async def subscribe_progress(self, task_id: str) -> AsyncGenerator[Dict, None]:
        """Subscribe to task progress (used for the SSE stream)."""
        # ... see the source file for the full implementation

    async def list_tasks(self, status: Optional[str] = None, limit: int = 50, offset: int = 0) -> List[Dict]:
        """List tasks."""
        # ... see the source file for the full implementation

    async def recover_pending_tasks(self) -> int:
        """Recover unfinished tasks after an application restart."""
        # ... see the source file for the full implementation

    async def cleanup_old_tasks(self, days: int = 7) -> int:
        """Clean up old task records."""
        # ... see the source file for the full implementation
```
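The `cancel` docstring above describes graceful termination followed by forced termination. The real implementation lives in the source file and is not reproduced here; the following is only a sketch of that two-step pattern, with the helper name `stop_process` and the grace period chosen for illustration.

```python
import asyncio
import sys


async def stop_process(proc: asyncio.subprocess.Process, grace_seconds: float = 5.0) -> int:
    """Ask the child to exit; force-kill it if it does not exit in time."""
    if proc.returncode is not None:
        return proc.returncode  # the process has already finished
    proc.terminate()  # polite request (SIGTERM on POSIX)
    try:
        return await asyncio.wait_for(proc.wait(), timeout=grace_seconds)
    except asyncio.TimeoutError:
        proc.kill()  # forced stop (SIGKILL on POSIX)
        return await proc.wait()


async def _demo() -> int:
    # Hypothetical stand-in for a long training run.
    proc = await asyncio.create_subprocess_exec(
        sys.executable, "-c", "import time; time.sleep(60)"
    )
    await asyncio.sleep(0.2)  # give the child a moment to start
    return await stop_process(proc, grace_seconds=2.0)


def demo() -> int:
    return asyncio.run(_demo())
```

The grace period gives a training script the chance to flush a checkpoint before it is killed; whether the project's script actually handles the termination signal is not stated in this document.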
##### LocalStorageAdapter ✅ Implemented

> **Implementation file**: `app/adapters/local/storage.py`
>
> A storage adapter backed by the local filesystem, using aiofiles for asynchronous I/O.
> It supports file upload/download, metadata management, and audio information extraction.

```python
# app/adapters/local/storage.py - ✅ fully implemented (method summary)
class LocalStorageAdapter(StorageAdapter):
    """
    Local filesystem storage adapter.

    Characteristics:
    1. Uses aiofiles for asynchronous file reads and writes
    2. Stores metadata in .meta.json files
    3. Extracts audio file information (duration, sample rate, etc.)
    """

    async def upload_file(self, file_data: bytes, filename: str, metadata: Dict) -> str:
        """Upload a file and return its file_id."""
        # ... see the source file for the full implementation

    async def download_file(self, file_id: str) -> bytes:
        """Download a file."""
        # ... see the source file for the full implementation

    async def delete_file(self, file_id: str) -> bool:
        """Delete a file and its metadata."""
        # ... see the source file for the full implementation

    async def get_file_metadata(self, file_id: str) -> Optional[Dict]:
        """Return the file metadata."""
        # ... see the source file for the full implementation

    async def list_files(self, purpose: Optional[str] = None, limit: int = 50, offset: int = 0) -> List[Dict]:
        """List files."""
        # ... see the source file for the full implementation
```
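As an illustration of the `.meta.json` sidecar idea, the sketch below stores each file under a generated ID and writes its metadata next to it. It is an assumption-laden stand-in, not the project's code: the class name `SidecarFileStore` and the on-disk layout are hypothetical, and it uses `asyncio.to_thread` from the standard library instead of aiofiles so that it runs without third-party packages.

```python
import asyncio
import json
import tempfile
import uuid
from pathlib import Path
from typing import Dict, Optional


class SidecarFileStore:
    """Hypothetical store: every file gets a `<file_id>.meta.json` sidecar."""

    def __init__(self, base_path: str) -> None:
        self.base = Path(base_path)
        self.base.mkdir(parents=True, exist_ok=True)

    async def upload_file(self, file_data: bytes, filename: str, metadata: Dict) -> str:
        file_id = str(uuid.uuid4())
        meta = {**metadata, "filename": filename, "size": len(file_data)}
        # Blocking disk writes run in a worker thread so the event loop stays free.
        await asyncio.to_thread((self.base / file_id).write_bytes, file_data)
        meta_path = self.base / f"{file_id}.meta.json"
        await asyncio.to_thread(meta_path.write_text, json.dumps(meta), "utf-8")
        return file_id

    async def get_file_metadata(self, file_id: str) -> Optional[Dict]:
        meta_path = self.base / f"{file_id}.meta.json"
        if not meta_path.exists():
            return None  # unknown file_id
        return json.loads(await asyncio.to_thread(meta_path.read_text, "utf-8"))


async def _demo() -> Optional[Dict]:
    with tempfile.TemporaryDirectory() as tmp:
        store = SidecarFileStore(tmp)
        file_id = await store.upload_file(b"RIFF", "voice.wav", {"purpose": "training"})
        return await store.get_file_metadata(file_id)


def demo() -> Optional[Dict]:
    return asyncio.run(_demo())
```

Keeping metadata in a sidecar file means the file listing can be rebuilt from the directory alone, at the cost of two writes per upload.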
##### SQLiteAdapter ✅ Implemented

> **Implementation file**: `app/adapters/local/database.py`
>
> A database adapter built on SQLite + aiosqlite, with full CRUD support for Task and Experiment.

```python
# app/adapters/local/database.py - ✅ fully implemented (method summary)
class SQLiteAdapter(DatabaseAdapter):
    """
    SQLite database adapter.

    Characteristics:
    1. Uses aiosqlite for asynchronous database operations
    2. Manages both Task (Quick Mode) and Experiment (Advanced Mode)
    3. Initializes the database tables automatically
    """

    # Task CRUD
    async def create_task(self, task: Task) -> Task: ...
    async def get_task(self, task_id: str) -> Optional[Task]: ...
    async def update_task(self, task_id: str, updates: Dict) -> Optional[Task]: ...
    async def list_tasks(self, status: Optional[str] = None, limit: int = 50, offset: int = 0) -> List[Task]: ...
    async def delete_task(self, task_id: str) -> bool: ...
    async def count_tasks(self, status: Optional[str] = None) -> int: ...

    # Experiment CRUD
    async def create_experiment(self, experiment: Dict) -> Dict: ...
    async def get_experiment(self, exp_id: str) -> Optional[Dict]: ...
    async def update_experiment(self, exp_id: str, updates: Dict) -> Optional[Dict]: ...
    async def list_experiments(self, status: Optional[str] = None, limit: int = 50, offset: int = 0) -> List[Dict]: ...
    async def delete_experiment(self, exp_id: str) -> bool: ...

    # Stage operations
    async def update_stage(self, exp_id: str, stage_type: str, updates: Dict) -> Optional[Dict]: ...
    async def get_stage(self, exp_id: str, stage_type: str) -> Optional[Dict]: ...
    async def get_all_stages(self, exp_id: str) -> List[Dict]: ...

    # File records
    async def create_file_record(self, file_data: Dict) -> Dict: ...
    async def get_file_record(self, file_id: str) -> Optional[Dict]: ...
    async def delete_file_record(self, file_id: str) -> bool: ...
    async def list_file_records(self, purpose: Optional[str] = None, limit: int = 50, offset: int = 0) -> List[Dict]: ...
```
##### LocalProgressAdapter ✅ Implemented

> **Implementation file**: `app/adapters/local/progress.py`
>
> A progress adapter based on in-memory queues, with support for multiple subscribers.

```python
# app/adapters/local/progress.py - ✅ fully implemented (method summary)
class LocalProgressAdapter(ProgressAdapter):
    """
    Local in-memory progress adapter.

    Characteristics:
    1. Keeps the latest progress in an in-memory dictionary
    2. Uses asyncio.Queue to implement the subscriber pattern
    3. Allows several subscribers to follow the same task at once
    4. Compatible with AsyncTrainingManager's progress push mechanism
    """

    async def update_progress(self, task_id: str, progress: Dict) -> None:
        """Update the progress and notify all subscribers."""
        # ... see the source file for the full implementation

    async def get_progress(self, task_id: str) -> Optional[Dict]:
        """Return the current progress."""
        # ... see the source file for the full implementation

    async def subscribe(self, task_id: str) -> AsyncGenerator[Dict, None]:
        """Subscribe to progress updates (with heartbeats and automatic cleanup)."""
        # ... see the source file for the full implementation
```
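The multi-subscriber behaviour described above can be sketched with one `asyncio.Queue` per subscriber. This is a simplified illustration under stated assumptions: the class name `ProgressHub` is hypothetical, heartbeats are omitted, and the set of terminal statuses is borrowed from the Redis adapter shown later in this document.

```python
import asyncio
from typing import AsyncGenerator, Dict, List, Optional, Tuple

TERMINAL_STATUSES = {"completed", "failed", "cancelled"}


class ProgressHub:
    """Hypothetical in-memory fan-out: one queue per subscriber."""

    def __init__(self) -> None:
        self._latest: Dict[str, Dict] = {}
        self._subscribers: Dict[str, List[asyncio.Queue]] = {}

    async def update_progress(self, task_id: str, progress: Dict) -> None:
        self._latest[task_id] = progress
        # Every subscriber receives its own copy of the update.
        for queue in self._subscribers.get(task_id, []):
            await queue.put(progress)

    async def get_progress(self, task_id: str) -> Optional[Dict]:
        return self._latest.get(task_id)

    async def subscribe(self, task_id: str) -> AsyncGenerator[Dict, None]:
        queue: asyncio.Queue = asyncio.Queue()
        self._subscribers.setdefault(task_id, []).append(queue)
        try:
            while True:
                progress = await queue.get()
                yield progress
                if progress.get("status") in TERMINAL_STATUSES:
                    break
        finally:
            # Automatic cleanup when the subscriber finishes or disconnects.
            self._subscribers[task_id].remove(queue)


async def _collect(hub: ProgressHub, task_id: str) -> List[Dict]:
    return [update async for update in hub.subscribe(task_id)]


async def _demo() -> Tuple[List[Dict], List[Dict]]:
    hub = ProgressHub()
    first = asyncio.create_task(_collect(hub, "t1"))
    second = asyncio.create_task(_collect(hub, "t1"))
    await asyncio.sleep(0.01)  # let both subscribers register
    await hub.update_progress("t1", {"status": "running", "progress": 0.5})
    await hub.update_progress("t1", {"status": "completed", "progress": 1.0})
    return await first, await second


def demo() -> Tuple[List[Dict], List[Dict]]:
    return asyncio.run(_demo())
```

A single shared queue would deliver each update to only one consumer, which is why the sketch keeps a separate queue per subscriber.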
#### 3.2.3 Server Adapter Implementations

```python
# app/adapters/server/storage.py
import asyncio
import io
import uuid
from typing import Dict

from minio import Minio

from app.adapters.base import StorageAdapter


class S3StorageAdapter(StorageAdapter):
    """MinIO/S3 object storage adapter."""

    def __init__(self, endpoint: str, access_key: str, secret_key: str, bucket: str):
        self.client = Minio(
            endpoint,
            access_key=access_key,
            secret_key=secret_key,
            secure=False,  # plain HTTP; set to True when the endpoint serves TLS
        )
        self.bucket = bucket
        # Make sure the bucket exists
        if not self.client.bucket_exists(bucket):
            self.client.make_bucket(bucket)

    async def upload_file(self, file_data: bytes, filename: str, metadata: Dict) -> str:
        file_id = str(uuid.uuid4())
        # The MinIO client is synchronous, so the upload runs in a worker
        # thread to avoid blocking the event loop.
        await asyncio.to_thread(
            self.client.put_object,
            self.bucket,
            file_id,
            io.BytesIO(file_data),
            len(file_data),
            metadata=metadata,
        )
        return file_id

    # ... other methods
```
```python
# app/adapters/server/database.py
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

from app.adapters.base import DatabaseAdapter
from app.models.domain import Task


class PostgreSQLAdapter(DatabaseAdapter):
    """PostgreSQL database adapter."""

    def __init__(self, database_url: str):
        self.engine = create_async_engine(database_url)

    # Uses the SQLAlchemy ORM. TaskModel is the ORM-mapped task table; its
    # definition is not shown in this document.
    async def create_task(self, task: Task) -> Task:
        async with AsyncSession(self.engine) as session:
            # .dict() and .from_orm() are Pydantic v1-style calls. Under
            # Pydantic v2 the equivalents are model_dump() and
            # model_validate() with from_attributes enabled.
            db_task = TaskModel(**task.dict())
            session.add(db_task)
            await session.commit()
            await session.refresh(db_task)
            return Task.from_orm(db_task)

    # ... other methods
```
```python
# app/adapters/server/task_queue.py
from typing import Dict

from celery import Celery

from app.adapters.base import TaskQueueAdapter


class CeleryTaskQueueAdapter(TaskQueueAdapter):
    """Celery distributed task queue."""

    def __init__(self, broker_url: str, backend_url: str):
        self.celery_app = Celery(
            'gpt_sovits_training',
            broker=broker_url,
            backend=backend_url,
        )

    async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str:
        # Imported here so that local mode never needs the Celery worker module
        from app.workers.celery_worker import execute_training_pipeline

        result = execute_training_pipeline.apply_async(
            args=[task_id, config],
            queue=f'queue_{priority}',
            # _get_priority_value maps the priority name to a numeric broker
            # priority; it is one of the methods omitted below.
            priority=self._get_priority_value(priority),
        )
        return result.id

    async def get_status(self, job_id: str) -> Dict:
        result = self.celery_app.AsyncResult(job_id)
        return {
            "status": result.state,
            "info": result.info,
        }

    # ... other methods
```
```python
# app/adapters/server/progress.py
import json
import time
from typing import AsyncGenerator, Dict, Optional

import redis.asyncio as redis

from app.adapters.base import ProgressAdapter


class RedisProgressAdapter(ProgressAdapter):
    """Redis-backed progress management."""

    def __init__(self, redis_url: str):
        self.redis = redis.from_url(redis_url)

    async def update_progress(self, task_id: str, progress: Dict) -> None:
        # Save the latest progress in a Redis hash
        await self.redis.hset(
            f"task:progress:{task_id}",
            mapping={
                "data": json.dumps(progress),
                "updated_at": time.time(),
            },
        )
        # Publish to Redis Pub/Sub. The channel shares its name with the hash
        # key; keys and channels live in separate namespaces, so they do not clash.
        await self.redis.publish(
            f"task:progress:{task_id}",
            json.dumps(progress),
        )

    async def get_progress(self, task_id: str) -> Optional[Dict]:
        data = await self.redis.hget(f"task:progress:{task_id}", "data")
        if data:
            return json.loads(data)
        return None

    async def subscribe(self, task_id: str) -> AsyncGenerator[Dict, None]:
        pubsub = self.redis.pubsub()
        await pubsub.subscribe(f"task:progress:{task_id}")
        try:
            async for message in pubsub.listen():
                if message['type'] == 'message':
                    progress = json.loads(message['data'])
                    yield progress
                    # Stop streaming once the task reaches a terminal state
                    if progress.get('status') in ['completed', 'failed', 'cancelled']:
                        break
        finally:
            await pubsub.unsubscribe(f"task:progress:{task_id}")
```
### 3.3 Adapter Factory

```python
# app/core/adapters.py
from app.core.config import settings
from app.adapters.base import StorageAdapter, DatabaseAdapter, TaskQueueAdapter, ProgressAdapter


class AdapterFactory:
    """Adapter factory: creates the adapters that match the configured deployment mode."""

    @staticmethod
    def create_storage_adapter() -> StorageAdapter:
        if settings.DEPLOYMENT_MODE == "local":
            from app.adapters.local.storage import LocalStorageAdapter
            return LocalStorageAdapter(base_path=settings.LOCAL_STORAGE_PATH)
        else:
            from app.adapters.server.storage import S3StorageAdapter
            return S3StorageAdapter(
                endpoint=settings.S3_ENDPOINT,
                access_key=settings.S3_ACCESS_KEY,
                secret_key=settings.S3_SECRET_KEY,
                bucket=settings.S3_BUCKET,
            )

    @staticmethod
    def create_database_adapter() -> DatabaseAdapter:
        if settings.DEPLOYMENT_MODE == "local":
            from app.adapters.local.database import SQLiteAdapter
            return SQLiteAdapter(db_path=settings.SQLITE_PATH)
        else:
            from app.adapters.server.database import PostgreSQLAdapter
            return PostgreSQLAdapter(database_url=settings.DATABASE_URL)

    @staticmethod
    def create_task_queue_adapter() -> TaskQueueAdapter:
        if settings.DEPLOYMENT_MODE == "local":
            from app.adapters.local.task_queue import AsyncTrainingManager
            return AsyncTrainingManager(db_path=settings.SQLITE_PATH)
        else:
            from app.adapters.server.task_queue import CeleryTaskQueueAdapter
            return CeleryTaskQueueAdapter(
                broker_url=settings.CELERY_BROKER_URL,
                backend_url=settings.CELERY_RESULT_BACKEND,
            )

    @staticmethod
    def create_progress_adapter() -> ProgressAdapter:
        if settings.DEPLOYMENT_MODE == "local":
            from app.adapters.local.progress import LocalProgressAdapter
            return LocalProgressAdapter()
        else:
            from app.adapters.server.progress import RedisProgressAdapter
            return RedisProgressAdapter(redis_url=settings.REDIS_URL)


# Module-level singletons, created once at import time
storage_adapter = AdapterFactory.create_storage_adapter()
database_adapter = AdapterFactory.create_database_adapter()
task_queue_adapter = AdapterFactory.create_task_queue_adapter()
progress_adapter = AdapterFactory.create_progress_adapter()
```
### 3.4 Unified Configuration

```python
# app/core/config.py
from pydantic_settings import BaseSettings
from typing import Literal


class Settings(BaseSettings):
    # Deployment mode
    DEPLOYMENT_MODE: Literal["local", "server"] = "local"

    # Common settings
    API_V1_PREFIX: str = "/api/v1"
    PROJECT_NAME: str = "GPT-SoVITS Training API"

    # Local mode settings
    LOCAL_STORAGE_PATH: str = "./data/files"
    SQLITE_PATH: str = "./data/app.db"
    LOCAL_MAX_WORKERS: int = 1  # number of training tasks that may run at once locally

    # Server mode settings
    DATABASE_URL: str = "postgresql+asyncpg://user:pass@localhost/gpt_sovits"
    REDIS_URL: str = "redis://localhost:6379/0"
    CELERY_BROKER_URL: str = "redis://localhost:6379/1"
    CELERY_RESULT_BACKEND: str = "redis://localhost:6379/2"
    S3_ENDPOINT: str = "localhost:9000"
    S3_ACCESS_KEY: str = "minioadmin"
    S3_SECRET_KEY: str = "minioadmin"
    S3_BUCKET: str = "gpt-sovits"

    class Config:
        env_file = ".env"
        case_sensitive = True


settings = Settings()
```

---
## 4. Unified API (Identical Across Modes)

The API is exactly the same in local mode and in server mode.

### 4.1 API Design Goals

Two independent API families serve two different user groups:

| User type | Need | API mode | Core concept | API prefix |
|----------|------|----------|----------|----------|
| **Beginners** | Upload audio and train, with no need to understand the details | Quick Mode | Task | `/api/v1/tasks` |
| **Experts** | Fine-grained control over each stage's parameters, stage-by-stage execution | Advanced Mode | Experiment + Stage | `/api/v1/experiments` |

### 4.2 Full API Endpoint List

#### Quick Mode API (Beginners)

| Method | Path | Description |
|------|------|------|
| `POST` | `/api/v1/tasks` | Create a one-click training task |
| `GET` | `/api/v1/tasks` | List tasks |
| `GET` | `/api/v1/tasks/{task_id}` | Get task details |
| `DELETE` | `/api/v1/tasks/{task_id}` | Cancel a task |
| `GET` | `/api/v1/tasks/{task_id}/progress` | Subscribe to progress over SSE |

#### Advanced Mode API (Experts)

| Method | Path | Description |
|------|------|------|
| `POST` | `/api/v1/experiments` | Create an experiment (does not start it) |
| `GET` | `/api/v1/experiments` | List experiments |
| `GET` | `/api/v1/experiments/{exp_id}` | Get experiment details |
| `DELETE` | `/api/v1/experiments/{exp_id}` | Delete an experiment |
| `PATCH` | `/api/v1/experiments/{exp_id}` | Update the experiment's base configuration |
| `POST` | `/api/v1/experiments/{exp_id}/stages/{stage_type}` | Run the given stage |
| `GET` | `/api/v1/experiments/{exp_id}/stages` | Get the status of all stages |
| `GET` | `/api/v1/experiments/{exp_id}/stages/{stage_type}` | Get the status/result of the given stage |
| `GET` | `/api/v1/experiments/{exp_id}/stages/{stage_type}/progress` | Subscribe to stage progress over SSE |
| `DELETE` | `/api/v1/experiments/{exp_id}/stages/{stage_type}` | Cancel a running stage |

#### Common API

| Method | Path | Description |
|------|------|------|
| `POST` | `/api/v1/files` | Upload a file |
| `GET` | `/api/v1/files` | List files |
| `GET` | `/api/v1/files/{file_id}` | Download a file |
| `DELETE` | `/api/v1/files/{file_id}` | Delete a file |
| `GET` | `/api/v1/stages/presets` | List stage presets |
| `GET` | `/api/v1/stages/{stage_type}/schema` | Get the parameter template for a stage |

---
## 4.3 Quick Mode API in Detail (Beginners)

### 4.3.1 Create a One-Click Training Task

```
POST /api/v1/tasks
```

The caller only supplies an uploaded audio file. The system configures all training parameters automatically and runs the full pipeline:

```json
{
  "exp_name": "my_voice",
  "audio_file_id": "550e8400-e29b-41d4-a716-446655440000",
  "options": {
    "version": "v2",
    "language": "zh",
    "quality": "standard"
  }
}
```

**Parameters**:

| Field | Type | Required | Description |
|------|------|------|------|
| `exp_name` | string | Yes | Experiment name |
| `audio_file_id` | string | Yes | ID of an already uploaded audio file |
| `options.version` | string | No | Model version, default `"v2"` |
| `options.language` | string | No | Language, default `"zh"` |
| `options.quality` | string | No | Training quality: `"fast"` / `"standard"` / `"high"` |

**Quality presets**:

| quality | SoVITS epochs | GPT epochs | Training time |
|---------|---------------|------------|----------|
| `fast` | 4 | 8 | ~10 min |
| `standard` | 8 | 15 | ~20 min |
| `high` | 16 | 30 | ~40 min |
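The preset table maps directly onto a lookup structure. The sketch below shows one way the server might resolve a `quality` value into epoch counts; the names `QUALITY_PRESETS` and `resolve_quality` and the key names are assumptions for illustration, and only the numbers come from the table.

```python
from typing import Dict

# Epoch counts copied from the quality preset table.
QUALITY_PRESETS: Dict[str, Dict[str, int]] = {
    "fast": {"sovits_epochs": 4, "gpt_epochs": 8},
    "standard": {"sovits_epochs": 8, "gpt_epochs": 15},
    "high": {"sovits_epochs": 16, "gpt_epochs": 30},
}


def resolve_quality(quality: str) -> Dict[str, int]:
    """Return the epoch counts for a preset, rejecting unknown names."""
    try:
        return dict(QUALITY_PRESETS[quality])  # copy, so callers cannot mutate the table
    except KeyError:
        allowed = ", ".join(sorted(QUALITY_PRESETS))
        raise ValueError(f"unknown quality {quality!r}; expected one of: {allowed}") from None
```

Rejecting unknown values early keeps a typo in the request from silently falling back to some other preset.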
**Pipeline executed automatically by the system**:

```
audio_slice -> asr -> text_feature -> hubert_feature -> semantic_token -> sovits_train -> gpt_train
```

**Example response**:

```json
{
  "id": "task-550e8400-e29b-41d4-a716-446655440000",
  "exp_name": "my_voice",
  "status": "queued",
  "current_stage": null,
  "progress": 0.0,
  "overall_progress": 0.0,
  "created_at": "2024-01-01T10:00:00Z"
}
```

### 4.3.2 Get Task Status

```
GET /api/v1/tasks/{task_id}
```

**Example response**:

```json
{
  "id": "task-550e8400-e29b-41d4-a716-446655440000",
  "exp_name": "my_voice",
  "status": "running",
  "current_stage": "sovits_train",
  "progress": 0.45,
  "overall_progress": 0.72,
  "message": "SoVITS training, epoch 8/16",
  "created_at": "2024-01-01T10:00:00Z",
  "started_at": "2024-01-01T10:00:05Z"
}
```

### 4.3.3 SSE Progress Subscription

```
GET /api/v1/tasks/{task_id}/progress
```

Returns an SSE stream that pushes progress updates in real time. Events are separated by a blank line:

```
event: progress
data: {"stage": "sovits_train", "progress": 0.45, "message": "Epoch 8/16"}

event: progress
data: {"stage": "sovits_train", "progress": 0.50, "message": "Epoch 9/16"}

event: completed
data: {"status": "completed", "message": "Training completed"}
```
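For readers unfamiliar with the SSE wire format, the sketch below formats and parses events of the shape used in this stream: an `event:` line, a `data:` line carrying JSON, and a blank line that ends the event. The helper names `format_sse` and `parse_sse` are illustrative assumptions; the parser handles only the fields used here, not the full SSE specification.

```python
import json
from typing import Dict, Iterable, Iterator, List, Tuple


def format_sse(event: str, data: Dict) -> str:
    """Serialize one event; the trailing blank line terminates it."""
    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"


def parse_sse(lines: Iterable[str]) -> Iterator[Tuple[str, Dict]]:
    """Yield (event name, decoded JSON payload) pairs from SSE text lines."""
    event = "message"  # default event name when no `event:` line is present
    data_lines: List[str] = []
    for line in lines:
        line = line.rstrip("\r\n")
        if line == "":
            # A blank line dispatches the event collected so far.
            if data_lines:
                yield event, json.loads("\n".join(data_lines))
            event, data_lines = "message", []
        elif line.startswith("event:"):
            event = line[len("event:"):].strip()
        elif line.startswith("data:"):
            data_lines.append(line[len("data:"):].strip())


def demo() -> List[Tuple[str, Dict]]:
    stream = format_sse("progress", {"stage": "sovits_train", "progress": 0.45})
    stream += format_sse("completed", {"status": "completed"})
    return list(parse_sse(stream.splitlines()))
```

A client would typically stop reading once it sees the `completed` event, mirroring the end of the stream shown above.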
---

## 4.4 Advanced Mode API in Detail (Experts)

Advanced Mode introduces the concept of an **Experiment**, which lets the frontend trigger training stage by stage through separate API calls.

### 4.4.1 Expert Mode Interaction Flow

```mermaid
sequenceDiagram
    participant Frontend
    participant API
    participant Pipeline
    Frontend->>API: POST /experiments (create experiment)
    API-->>Frontend: {id: "exp-abc123"}
    Frontend->>API: POST /experiments/exp-abc123/stages/audio_slice
    API->>Pipeline: Start audio slicing
    Frontend->>API: GET .../audio_slice/progress (SSE)
    Pipeline-->>Frontend: Progress updates...
    Pipeline-->>Frontend: {status: "completed"}
    Note over Frontend: User reviews the slices and adjusts parameters
    Frontend->>API: POST /experiments/exp-abc123/stages/asr
    API->>Pipeline: Start ASR
    Pipeline-->>Frontend: Progress updates...
    Note over Frontend: Continue with the later stages...
```

### 4.4.2 Create an Experiment

```
POST /api/v1/experiments
```

Creates an experiment without running it, so the user can control each stage individually:

```json
{
  "exp_name": "my_voice_custom",
  "version": "v2",
  "gpu_numbers": "0",
  "is_half": true,
  "audio_file_id": "550e8400-e29b-41d4-a716-446655440000"
}
```

**Parameters**:

| Field | Type | Required | Description |
|------|------|------|------|
| `exp_name` | string | Yes | Experiment name |
| `version` | string | No | Model version, default `"v2"` |
| `gpu_numbers` | string | No | GPU index, default `"0"` |
| `is_half` | bool | No | Whether to use half precision, default `true` |
| `audio_file_id` | string | Yes | ID of an already uploaded audio file |

**Example response**:

```json
{
  "id": "exp-abc123",
  "exp_name": "my_voice_custom",
  "version": "v2",
  "status": "created",
  "stages": {
    "audio_slice": { "status": "pending" },
    "asr": { "status": "pending" },
    "text_feature": { "status": "pending" },
    "hubert_feature": { "status": "pending" },
    "semantic_token": { "status": "pending" },
    "sovits_train": { "status": "pending" },
    "gpt_train": { "status": "pending" }
  },
  "created_at": "2024-01-01T10:00:00Z"
}
```

### 4.4.3 Run a Stage

```
POST /api/v1/experiments/{exp_id}/stages/{stage_type}
```

Triggers the given stage. Stage-specific parameters may be passed in the request body to override the defaults.

**Available stage types (stage_type)**:

| stage_type | Description | Depends on |
|------------|------|----------|
| `audio_slice` | Audio slicing | None |
| `asr` | Speech recognition | audio_slice |
| `text_feature` | Text feature extraction | asr |
| `hubert_feature` | HuBERT feature extraction | audio_slice |
| `semantic_token` | Semantic token extraction | hubert_feature |
| `sovits_train` | SoVITS training | text_feature, semantic_token |
| `gpt_train` | GPT training | text_feature, semantic_token |
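The dependency column forms a small directed acyclic graph. The sketch below encodes it and shows one way a server might report which prerequisites are still missing before a stage is run. Whether the actual API enforces these dependencies is not stated in this document; `STAGE_DEPENDENCIES`, `missing_dependencies`, and `one_valid_order` are hypothetical names.

```python
from graphlib import TopologicalSorter
from typing import Dict, List, Set

# Stage -> stages it depends on, copied from the table above.
STAGE_DEPENDENCIES: Dict[str, Set[str]] = {
    "audio_slice": set(),
    "asr": {"audio_slice"},
    "text_feature": {"asr"},
    "hubert_feature": {"audio_slice"},
    "semantic_token": {"hubert_feature"},
    "sovits_train": {"text_feature", "semantic_token"},
    "gpt_train": {"text_feature", "semantic_token"},
}


def missing_dependencies(stage_type: str, completed: Set[str]) -> List[str]:
    """Return the direct prerequisites of stage_type that are not completed yet."""
    return sorted(STAGE_DEPENDENCIES[stage_type] - completed)


def one_valid_order() -> List[str]:
    """Return one execution order that respects every dependency."""
    return list(TopologicalSorter(STAGE_DEPENDENCIES).static_order())
```

For example, `missing_dependencies("sovits_train", {"audio_slice", "asr", "text_feature"})` reports that `semantic_token` still has to finish. The Quick Mode pipeline order listed in Section 4.3.1 is one of the orders that satisfies this graph.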
**Example request (run audio slicing)**:

```
POST /api/v1/experiments/exp-abc123/stages/audio_slice
```

```json
{
  "threshold": -34,
  "min_length": 4000,
  "min_interval": 300,
  "hop_size": 10,
  "max_sil_kept": 500
}
```

**Example request (run SoVITS training)**:

```
POST /api/v1/experiments/exp-abc123/stages/sovits_train
```

```json
{
  "batch_size": 8,
  "total_epoch": 16,
  "save_every_epoch": 4,
  "pretrained_s2G": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
  "pretrained_s2D": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2D2333k.pth"
}
```

**Example response**:

```json
{
  "exp_id": "exp-abc123",
  "stage_type": "sovits_train",
  "status": "running",
  "job_id": "job-xyz789",
  "config": {
    "batch_size": 8,
    "total_epoch": 16,
    "save_every_epoch": 4
  },
  "started_at": "2024-01-01T10:30:00Z"
}
```
### 4.4.4 Get Stage Status

```
GET /api/v1/experiments/{exp_id}/stages/{stage_type}
```

**Example response (completed)**:

```json
{
  "stage_type": "sovits_train",
  "status": "completed",
  "started_at": "2024-01-01T10:30:00Z",
  "completed_at": "2024-01-01T11:00:00Z",
  "config": {
    "batch_size": 8,
    "total_epoch": 16,
    "save_every_epoch": 4
  },
  "outputs": {
    "model_path": "logs/my_voice_custom/sovits_e16.pth",
    "metrics": {
      "final_loss": 0.023,
      "best_epoch": 14
    }
  }
}
```

**Example response (running)**:

```json
{
  "stage_type": "sovits_train",
  "status": "running",
  "started_at": "2024-01-01T10:30:00Z",
  "progress": 0.45,
  "message": "Epoch 8/16, Loss: 0.034"
}
```
| ### 4.4.5 Get the status of all stages | |
| ``` | |
| GET /api/v1/experiments/{exp_id}/stages | |
| ``` | |
| **Response example**: | |
| ```json | |
| { | |
| "exp_id": "exp-abc123", | |
| "stages": [ | |
| { | |
| "stage_type": "audio_slice", | |
| "status": "completed", | |
| "completed_at": "2024-01-01T10:05:00Z" | |
| }, | |
| { | |
| "stage_type": "asr", | |
| "status": "completed", | |
| "completed_at": "2024-01-01T10:10:00Z" | |
| }, | |
| { | |
| "stage_type": "text_feature", | |
| "status": "completed", | |
| "completed_at": "2024-01-01T10:12:00Z" | |
| }, | |
| { | |
| "stage_type": "hubert_feature", | |
| "status": "completed", | |
| "completed_at": "2024-01-01T10:20:00Z" | |
| }, | |
| { | |
| "stage_type": "semantic_token", | |
| "status": "completed", | |
| "completed_at": "2024-01-01T10:25:00Z" | |
| }, | |
| { | |
| "stage_type": "sovits_train", | |
| "status": "running", | |
| "started_at": "2024-01-01T10:30:00Z", | |
| "progress": 0.45 | |
| }, | |
| { | |
| "stage_type": "gpt_train", | |
| "status": "pending" | |
| } | |
| ] | |
| } | |
| ``` | |
| ### 4.4.6 Subscribe to stage progress via SSE | |
| ``` | |
| GET /api/v1/experiments/{exp_id}/stages/{stage_type}/progress | |
| ``` | |
| Returns an SSE stream that pushes stage progress in real time: | |
| ``` | |
| event: progress | |
| data: {"epoch": 8, "total_epochs": 16, "progress": 0.50, "loss": 0.034} | |
| event: progress | |
| data: {"epoch": 9, "total_epochs": 16, "progress": 0.56, "loss": 0.031} | |
| event: checkpoint | |
| data: {"epoch": 8, "model_path": "logs/my_voice/sovits_e8.pth"} | |
| event: completed | |
| data: {"status": "completed", "final_loss": 0.023} | |
| ``` | |
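On the wire, each SSE event ends with a blank line. The sketch below shows how a client could turn such a stream into `(event, payload)` pairs using only the standard library. It is an illustration under stated assumptions rather than the project's client code: `parse_sse` is a hypothetical helper, and it only handles the `event:` and `data:` fields used in the sample stream.

```python
import json
from typing import Iterable, Iterator


def parse_sse(lines: Iterable[str]) -> Iterator[tuple[str, dict]]:
    """Yield (event_name, payload) pairs from the lines of an SSE stream.

    An event is dispatched when a blank line is reached; an unterminated
    event at the end of the input is discarded, as the SSE format specifies.
    """
    event, data = "message", []
    for raw in lines:
        line = raw.rstrip("\r\n")
        if not line:  # blank line terminates the current event
            if data:
                yield event, json.loads("\n".join(data))
            event, data = "message", []
        elif line.startswith(":"):  # comment / keep-alive line
            continue
        elif line.startswith("event:"):
            event = line[len("event:"):].strip()
        elif line.startswith("data:"):
            data.append(line[len("data:"):].strip())


stream = [
    'event: progress',
    'data: {"epoch": 8, "total_epochs": 16, "progress": 0.50, "loss": 0.034}',
    '',
    'event: completed',
    'data: {"status": "completed", "final_loss": 0.023}',
    '',
]
events = list(parse_sse(stream))
```

A real client would feed the function with the decoded lines of the HTTP response body instead of a list.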
| ### 4.4.7 Cancel stage execution | |
| ``` | |
| DELETE /api/v1/experiments/{exp_id}/stages/{stage_type} | |
| ``` | |
| Cancels a stage that is currently running. | |
| **Response example**: | |
| ```json | |
| { | |
| "success": true, | |
| "message": "Stage sovits_train cancelled", | |
| "stage_type": "sovits_train", | |
| "status": "cancelled" | |
| } | |
| ``` | |
| ### 4.4.8 Re-run a stage | |
| Advanced users can re-run any completed stage with new parameters: | |
| ``` | |
| POST /api/v1/experiments/{exp_id}/stages/sovits_train | |
| ``` | |
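| If the stage has already completed, calling the endpoint again re-runs it. The response then carries a `rerun: true` flag together with a summary of the previous run: | |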
| ```json | |
| { | |
| "exp_id": "exp-abc123", | |
| "stage_type": "sovits_train", | |
| "status": "running", | |
| "rerun": true, | |
| "previous_run": { | |
| "completed_at": "2024-01-01T11:00:00Z", | |
| "outputs": { "model_path": "logs/my_voice/sovits_e16.pth" } | |
| } | |
| } | |
| ``` | |
| --- | |
| ## 4.5 Stage parameter template API | |
| ### 4.5.1 List stage presets | |
| ``` | |
| GET /api/v1/stages/presets | |
| ``` | |
| **Response example**: | |
| ```json | |
| { | |
|   "presets": [ | |
|     { | |
|       "id": "full_training", | |
|       "name": "Full training pipeline", | |
|       "description": "Standard training that runs every stage", | |
|       "stages": ["audio_slice", "asr", "text_feature", "hubert_feature", "semantic_token", "sovits_train", "gpt_train"] | |
|     }, | |
|     { | |
|       "id": "retrain_sovits", | |
|       "name": "Retrain SoVITS", | |
|       "description": "Skip preprocessing and retrain SoVITS only", | |
|       "stages": ["sovits_train"] | |
|     }, | |
|     { | |
|       "id": "feature_extraction", | |
|       "name": "Feature extraction", | |
|       "description": "Run audio slicing and feature extraction only", | |
|       "stages": ["audio_slice", "asr", "text_feature", "hubert_feature", "semantic_token"] | |
|     } | |
|   ] | |
| } | |
| ``` | |
| ### 4.5.2 Get the parameter template of a stage | |
| ``` | |
| GET /api/v1/stages/{stage_type}/schema | |
| ``` | |
| **Response example** (`/api/v1/stages/audio_slice/schema`): | |
| ```json | |
| { | |
|   "type": "audio_slice", | |
|   "name": "Audio slicing", | |
|   "description": "Split long audio into short segments", | |
|   "parameters": { | |
|     "threshold": { | |
|       "type": "integer", | |
|       "default": -34, | |
|       "min": -60, | |
|       "max": 0, | |
|       "description": "Silence detection threshold (dB)" | |
|     }, | |
|     "min_length": { | |
|       "type": "integer", | |
|       "default": 4000, | |
|       "min": 1000, | |
|       "max": 10000, | |
|       "description": "Minimum slice length (ms)" | |
|     }, | |
|     "min_interval": { | |
|       "type": "integer", | |
|       "default": 300, | |
|       "min": 100, | |
|       "max": 1000, | |
|       "description": "Minimum silence interval (ms)" | |
|     }, | |
|     "hop_size": { | |
|       "type": "integer", | |
|       "default": 10, | |
|       "min": 5, | |
|       "max": 50, | |
|       "description": "Detection hop size (ms)" | |
|     }, | |
|     "max_sil_kept": { | |
|       "type": "integer", | |
|       "default": 500, | |
|       "min": 100, | |
|       "max": 2000, | |
|       "description": "Maximum silence kept around each slice (ms)" | |
|     } | |
|   } | |
| } | |
| ``` | |
| **Response example** (`/api/v1/stages/sovits_train/schema`): | |
| ```json | |
| { | |
|   "type": "sovits_train", | |
|   "name": "SoVITS training", | |
|   "description": "Train the SoVITS vocoder model", | |
|   "parameters": { | |
|     "batch_size": { | |
|       "type": "integer", | |
|       "default": 4, | |
|       "min": 1, | |
|       "max": 32, | |
|       "description": "Batch size; reduce it when GPU memory is insufficient" | |
|     }, | |
|     "total_epoch": { | |
|       "type": "integer", | |
|       "default": 8, | |
|       "min": 1, | |
|       "max": 100, | |
|       "description": "Total number of training epochs" | |
|     }, | |
|     "save_every_epoch": { | |
|       "type": "integer", | |
|       "default": 4, | |
|       "min": 1, | |
|       "description": "Save the model every N epochs" | |
|     }, | |
|     "pretrained_s2G": { | |
|       "type": "string", | |
|       "default": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", | |
|       "description": "Path to the pretrained generator model" | |
|     }, | |
|     "pretrained_s2D": { | |
|       "type": "string", | |
|       "default": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2D2333k.pth", | |
|       "description": "Path to the pretrained discriminator model" | |
|     } | |
|   } | |
| } | |
| ``` | |
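A parameter template of this shape can also drive validation on the client or in a UI form. The following sketch, which assumes only the `type`, `default`, `min` and `max` keys shown in the responses, fills in defaults and checks ranges. `apply_schema` is a hypothetical helper, not part of the service, and parameters that are absent from the template are simply ignored here.

```python
from typing import Any

# Parameter template shaped like the `parameters` object in the response
# (only a subset of the SoVITS fields is reproduced).
SOVITS_TRAIN_SCHEMA = {
    "batch_size": {"type": "integer", "default": 4, "min": 1, "max": 32},
    "total_epoch": {"type": "integer", "default": 8, "min": 1, "max": 100},
    "save_every_epoch": {"type": "integer", "default": 4, "min": 1},
}

_PYTHON_TYPES = {"integer": int, "string": str}


def apply_schema(params: dict[str, Any], schema: dict[str, dict]) -> dict[str, Any]:
    """Fill in defaults and check type and range for every templated parameter."""
    resolved: dict[str, Any] = {}
    for name, spec in schema.items():
        value = params.get(name, spec.get("default"))
        expected = _PYTHON_TYPES[spec["type"]]
        # bool is a subclass of int, so reject it explicitly.
        if not isinstance(value, expected) or isinstance(value, bool):
            raise ValueError(f"{name}: expected {spec['type']}, got {value!r}")
        if "min" in spec and value < spec["min"]:
            raise ValueError(f"{name}: {value} is below the minimum {spec['min']}")
        if "max" in spec and value > spec["max"]:
            raise ValueError(f"{name}: {value} is above the maximum {spec['max']}")
        resolved[name] = value
    return resolved


config = apply_schema({"batch_size": 8, "total_epoch": 16}, SOVITS_TRAIN_SCHEMA)
```

On the server side the same constraints are expressed through the Pydantic `Field(ge=..., le=...)` declarations in section 4.6.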
| --- | |
| ## 4.6 Pydantic Schema design | |
| ### 4.6.1 Quick Mode Schema | |
| ```python | |
| from typing import Literal | |
| from pydantic import BaseModel, Field | |
| class QuickModeOptions(BaseModel): | |
|     version: Literal["v1", "v2", "v2Pro", "v3", "v4"] = "v2" | |
|     language: str = "zh" | |
|     quality: Literal["fast", "standard", "high"] = "standard" | |
| class QuickModeRequest(BaseModel): | |
|     """One-click training request for novice users.""" | |
|     exp_name: str = Field(..., min_length=1, max_length=100) | |
|     audio_file_id: str | |
|     options: QuickModeOptions = Field(default_factory=QuickModeOptions) | |
| ``` | |
| ### 4.6.2 Advanced Mode Schema | |
| ```python | |
| from datetime import datetime | |
| from typing import Any, Dict, Literal, Optional | |
| from pydantic import BaseModel, ConfigDict, Field | |
| # ============================================================ | |
| # Experiment management | |
| # ============================================================ | |
| class ExperimentCreate(BaseModel): | |
|     """Request body for creating an experiment.""" | |
|     exp_name: str = Field(..., min_length=1, max_length=100, description="Experiment name") | |
|     version: Literal["v1", "v2", "v2Pro", "v3", "v4"] = Field(default="v2", description="Model version") | |
|     gpu_numbers: str = Field(default="0", description="GPU number(s)") | |
|     is_half: bool = Field(default=True, description="Whether to use half precision") | |
|     audio_file_id: str = Field(..., description="Audio file ID") | |
| class ExperimentUpdate(BaseModel): | |
|     """Request body for updating an experiment.""" | |
|     exp_name: Optional[str] = Field(None, min_length=1, max_length=100) | |
|     gpu_numbers: Optional[str] = None | |
|     is_half: Optional[bool] = None | |
| class StageStatus(BaseModel): | |
|     """Status of a single stage.""" | |
|     stage_type: str | |
|     status: Literal["pending", "running", "completed", "failed", "cancelled"] | |
|     progress: Optional[float] = None | |
|     message: Optional[str] = None | |
|     started_at: Optional[datetime] = None | |
|     completed_at: Optional[datetime] = None | |
|     config: Optional[Dict[str, Any]] = None | |
|     outputs: Optional[Dict[str, Any]] = None | |
| class ExperimentResponse(BaseModel): | |
|     """Experiment response.""" | |
|     id: str | |
|     exp_name: str | |
|     version: str | |
|     status: str | |
|     gpu_numbers: str | |
|     is_half: bool | |
|     audio_file_id: str | |
|     stages: Dict[str, StageStatus] | |
|     created_at: datetime | |
|     updated_at: Optional[datetime] = None | |
| # ============================================================ | |
| # Stage execution | |
| # ============================================================ | |
| class StageExecuteRequest(BaseModel): | |
|     """Base class for stage execution requests.""" | |
|     # Allow extra fields so that stage-specific parameters pass through. | |
|     model_config = ConfigDict(extra="allow") | |
| class AudioSliceParams(StageExecuteRequest): | |
|     """Audio slicing parameters.""" | |
|     threshold: int = Field(default=-34, ge=-60, le=0, description="Silence detection threshold (dB)") | |
|     min_length: int = Field(default=4000, ge=1000, le=10000, description="Minimum slice length (ms)") | |
|     min_interval: int = Field(default=300, ge=100, le=1000, description="Minimum silence interval (ms)") | |
|     hop_size: int = Field(default=10, ge=5, le=50, description="Detection hop size (ms)") | |
|     max_sil_kept: int = Field(default=500, ge=100, le=2000, description="Maximum silence kept (ms)") | |
| class ASRParams(StageExecuteRequest): | |
|     """ASR parameters.""" | |
|     model: str = Field(default="达摩 ASR (中文)", description="ASR model") | |
|     language: str = Field(default="zh", description="Language") | |
| class SoVITSTrainParams(StageExecuteRequest): | |
|     """SoVITS training parameters.""" | |
|     batch_size: int = Field(default=4, ge=1, le=32, description="Batch size") | |
|     total_epoch: int = Field(default=8, ge=1, le=100, description="Total number of epochs") | |
|     save_every_epoch: int = Field(default=4, ge=1, description="Checkpoint interval in epochs") | |
|     pretrained_s2G: Optional[str] = Field(None, description="Path to the pretrained generator") | |
|     pretrained_s2D: Optional[str] = Field(None, description="Path to the pretrained discriminator") | |
| class GPTTrainParams(StageExecuteRequest): | |
|     """GPT training parameters.""" | |
|     batch_size: int = Field(default=4, ge=1, le=32, description="Batch size") | |
|     total_epoch: int = Field(default=15, ge=1, le=100, description="Total number of epochs") | |
|     save_every_epoch: int = Field(default=5, ge=1, description="Checkpoint interval in epochs") | |
|     pretrained_s1: Optional[str] = Field(None, description="Path to the pretrained model") | |
| class StageExecuteResponse(BaseModel): | |
|     """Stage execution response.""" | |
|     exp_id: str | |
|     stage_type: str | |
|     status: Literal["running", "queued"] | |
|     job_id: str | |
|     config: Dict[str, Any] | |
|     rerun: bool = False | |
|     previous_run: Optional[Dict[str, Any]] = None | |
|     started_at: datetime | |
| ``` | |
| ### 4.6.3 Task Schema (Quick Mode response) | |
| ```python | |
| from datetime import datetime | |
| from typing import Literal, Optional | |
| from pydantic import BaseModel, ConfigDict, Field | |
| class TaskResponse(BaseModel): | |
|     """Task response (Quick Mode).""" | |
|     # Allow building the response directly from ORM/domain objects. | |
|     model_config = ConfigDict(from_attributes=True) | |
|     id: str = Field(..., description="Unique task identifier") | |
|     exp_name: str = Field(..., description="Experiment name") | |
|     status: Literal["queued", "running", "completed", "failed", "cancelled"] | |
|     current_stage: Optional[str] = None | |
|     progress: float = Field(default=0.0, ge=0.0, le=1.0, description="Progress of the current stage") | |
|     overall_progress: float = Field(default=0.0, ge=0.0, le=1.0, description="Overall progress") | |
|     message: Optional[str] = None | |
|     error_message: Optional[str] = None | |
|     created_at: Optional[datetime] = None | |
|     started_at: Optional[datetime] = None | |
|     completed_at: Optional[datetime] = None | |
| ``` | |
| --- | |
| ## 4.7 API implementation examples | |
| ### 4.7.1 Quick Mode API implementation | |
| ```python | |
| # app/api/v1/endpoints/tasks.py | |
| from fastapi import APIRouter, HTTPException, Depends | |
| from app.services.task_service import TaskService | |
| from app.models.schemas.task import QuickModeRequest, TaskResponse | |
| # NOTE: get_task_service is the dependency provider; its module is not shown in this document. | |
| router = APIRouter() | |
| @router.post("/tasks", response_model=TaskResponse) | |
| async def create_task( | |
|     request: QuickModeRequest, | |
|     task_service: TaskService = Depends(get_task_service) | |
| ): | |
|     """ | |
|     Create a one-click training task (novice users). | |
|     After the audio file is uploaded, the system configures the parameters | |
|     automatically and runs the full training pipeline. | |
|     """ | |
|     return await task_service.create_quick_task(request) | |
| @router.get("/tasks/{task_id}", response_model=TaskResponse) | |
| async def get_task( | |
|     task_id: str, | |
|     task_service: TaskService = Depends(get_task_service) | |
| ): | |
|     """Get task details.""" | |
|     task = await task_service.get_task(task_id) | |
|     if not task: | |
|         raise HTTPException(status_code=404, detail="Task not found") | |
|     return task | |
| @router.delete("/tasks/{task_id}") | |
| async def cancel_task( | |
|     task_id: str, | |
|     task_service: TaskService = Depends(get_task_service) | |
| ): | |
|     """Cancel a task.""" | |
|     success = await task_service.cancel_task(task_id) | |
|     if not success: | |
|         raise HTTPException(status_code=404, detail="Task not found or cannot be cancelled") | |
|     return {"success": True, "message": "Task cancelled"} | |
| ``` | |
| ### 4.7.2 Advanced Mode API implementation | |
| ```python | |
| # app/api/v1/endpoints/experiments.py | |
| from fastapi import APIRouter, HTTPException, Depends, Body | |
| from typing import Dict, Any | |
| from app.services.experiment_service import ExperimentService | |
| from app.models.schemas.experiment import ( | |
|     ExperimentCreate, | |
|     ExperimentResponse, | |
|     StageExecuteResponse, | |
|     StageStatus, | |
| ) | |
| # NOTE: get_experiment_service is the dependency provider; its module is not shown in this document. | |
| router = APIRouter() | |
| @router.post("/experiments", response_model=ExperimentResponse) | |
| async def create_experiment( | |
|     request: ExperimentCreate, | |
|     experiment_service: ExperimentService = Depends(get_experiment_service) | |
| ): | |
|     """ | |
|     Create an experiment (advanced users). | |
|     The experiment is created but not started, so the user can drive the | |
|     training pipeline stage by stage. | |
|     """ | |
|     return await experiment_service.create_experiment(request) | |
| @router.get("/experiments/{exp_id}", response_model=ExperimentResponse) | |
| async def get_experiment( | |
|     exp_id: str, | |
|     experiment_service: ExperimentService = Depends(get_experiment_service) | |
| ): | |
|     """Get experiment details.""" | |
|     experiment = await experiment_service.get_experiment(exp_id) | |
|     if not experiment: | |
|         raise HTTPException(status_code=404, detail="Experiment not found") | |
|     return experiment | |
| @router.post("/experiments/{exp_id}/stages/{stage_type}", response_model=StageExecuteResponse) | |
| async def execute_stage( | |
|     exp_id: str, | |
|     stage_type: str, | |
|     params: Dict[str, Any] = Body(default={}), | |
|     experiment_service: ExperimentService = Depends(get_experiment_service) | |
| ): | |
|     """ | |
|     Execute the given stage. | |
|     Stage-specific parameters override the defaults. A completed stage is re-run. | |
|     """ | |
|     # Validate the stage type | |
|     valid_stages = ["audio_slice", "asr", "text_feature", "hubert_feature", | |
|                     "semantic_token", "sovits_train", "gpt_train"] | |
|     if stage_type not in valid_stages: | |
|         raise HTTPException(status_code=400, detail=f"Invalid stage type: {stage_type}") | |
|     # Reject unknown experiments before touching their stages | |
|     if not await experiment_service.get_experiment(exp_id): | |
|         raise HTTPException(status_code=404, detail="Experiment not found") | |
|     # Check that the prerequisite stages have completed | |
|     dependencies = await experiment_service.check_stage_dependencies(exp_id, stage_type) | |
|     if not dependencies["satisfied"]: | |
|         raise HTTPException( | |
|             status_code=400, | |
|             detail=f"Prerequisite stages not completed: {', '.join(dependencies['missing'])}" | |
|         ) | |
|     return await experiment_service.execute_stage(exp_id, stage_type, params) | |
| # NOTE: section 4.4.5 documents a list-based payload; this sketch returns a mapping keyed by stage type. | |
| @router.get("/experiments/{exp_id}/stages", response_model=Dict[str, StageStatus]) | |
| async def get_all_stages( | |
|     exp_id: str, | |
|     experiment_service: ExperimentService = Depends(get_experiment_service) | |
| ): | |
|     """Get the status of all stages.""" | |
|     return await experiment_service.get_all_stages(exp_id) | |
| @router.get("/experiments/{exp_id}/stages/{stage_type}", response_model=StageStatus) | |
| async def get_stage( | |
|     exp_id: str, | |
|     stage_type: str, | |
|     experiment_service: ExperimentService = Depends(get_experiment_service) | |
| ): | |
|     """Get the status and results of one stage.""" | |
|     stage = await experiment_service.get_stage(exp_id, stage_type) | |
|     if not stage: | |
|         raise HTTPException(status_code=404, detail="Stage not found") | |
|     return stage | |
| @router.delete("/experiments/{exp_id}/stages/{stage_type}") | |
| async def cancel_stage( | |
|     exp_id: str, | |
|     stage_type: str, | |
|     experiment_service: ExperimentService = Depends(get_experiment_service) | |
| ): | |
|     """Cancel a running stage.""" | |
|     success = await experiment_service.cancel_stage(exp_id, stage_type) | |
|     if not success: | |
|         raise HTTPException(status_code=400, detail="Stage not running or cannot be cancelled") | |
|     return { | |
|         "success": True, | |
|         "message": f"Stage {stage_type} cancelled", | |
|         "stage_type": stage_type, | |
|         "status": "cancelled", | |
|     } | |
| ``` | |
| ### 4.7.3 Service layer implementation | |
| ```python | |
| # app/services/experiment_service.py | |
| from typing import Dict, Any | |
| from datetime import datetime, timezone | |
| import uuid | |
| from app.core.adapters import database_adapter, task_queue_adapter | |
| from app.models.schemas.experiment import ( | |
|     ExperimentCreate, | |
|     ExperimentResponse, | |
|     StageExecuteResponse, | |
| ) | |
| # Stage dependency graph | |
| STAGE_DEPENDENCIES = { | |
|     "audio_slice": [], | |
|     "asr": ["audio_slice"], | |
|     "text_feature": ["asr"], | |
|     "hubert_feature": ["audio_slice"], | |
|     "semantic_token": ["hubert_feature"], | |
|     "sovits_train": ["text_feature", "semantic_token"], | |
|     "gpt_train": ["text_feature", "semantic_token"], | |
| } | |
| class ExperimentService: | |
|     """Experiment service (Advanced Mode).""" | |
|     def __init__(self): | |
|         self.db = database_adapter | |
|         self.queue = task_queue_adapter | |
|     async def create_experiment(self, request: ExperimentCreate) -> ExperimentResponse: | |
|         """Create an experiment.""" | |
|         exp_id = f"exp-{uuid.uuid4().hex[:8]}" | |
|         # Initialise every stage as pending; stage_type is required by StageStatus | |
|         stages = { | |
|             stage: {"stage_type": stage, "status": "pending", "config": None, "outputs": None} | |
|             for stage in STAGE_DEPENDENCIES | |
|         } | |
|         experiment = { | |
|             "id": exp_id, | |
|             "exp_name": request.exp_name, | |
|             "version": request.version, | |
|             "gpu_numbers": request.gpu_numbers, | |
|             "is_half": request.is_half, | |
|             "audio_file_id": request.audio_file_id, | |
|             "status": "created", | |
|             "stages": stages, | |
|             "created_at": datetime.now(timezone.utc), | |
|         } | |
|         await self.db.create_experiment(experiment) | |
|         return ExperimentResponse(**experiment) | |
|     async def check_stage_dependencies(self, exp_id: str, stage_type: str) -> Dict: | |
|         """Check whether the dependencies of a stage are satisfied.""" | |
|         experiment = await self.db.get_experiment(exp_id) | |
|         dependencies = STAGE_DEPENDENCIES.get(stage_type, []) | |
|         missing = [ | |
|             dep for dep in dependencies | |
|             if experiment["stages"][dep]["status"] != "completed" | |
|         ] | |
|         return { | |
|             "satisfied": len(missing) == 0, | |
|             "missing": missing | |
|         } | |
|     async def execute_stage( | |
|         self, | |
|         exp_id: str, | |
|         stage_type: str, | |
|         params: Dict[str, Any] | |
|     ) -> StageExecuteResponse: | |
|         """Execute a stage.""" | |
|         experiment = await self.db.get_experiment(exp_id) | |
|         # Detect a re-run of a completed stage | |
|         current_stage = experiment["stages"][stage_type] | |
|         is_rerun = current_stage["status"] == "completed" | |
|         previous_run = current_stage if is_rerun else None | |
|         # Build the stage configuration | |
|         stage_config = { | |
|             "exp_id": exp_id, | |
|             "exp_name": experiment["exp_name"], | |
|             "version": experiment["version"], | |
|             "gpu_numbers": experiment["gpu_numbers"], | |
|             "is_half": experiment["is_half"], | |
|             "stage_type": stage_type, | |
|             "params": params, | |
|         } | |
|         # Enqueue the stage for execution | |
|         job_id = await self.queue.enqueue_stage( | |
|             exp_id=exp_id, | |
|             stage_type=stage_type, | |
|             config=stage_config | |
|         ) | |
|         # Use one timestamp for both the stored state and the response | |
|         started_at = datetime.now(timezone.utc) | |
|         await self.db.update_stage(exp_id, stage_type, { | |
|             "status": "running", | |
|             "config": params, | |
|             "started_at": started_at, | |
|             "job_id": job_id, | |
|         }) | |
|         return StageExecuteResponse( | |
|             exp_id=exp_id, | |
|             stage_type=stage_type, | |
|             status="running", | |
|             job_id=job_id, | |
|             config=params, | |
|             rerun=is_rerun, | |
|             previous_run=previous_run, | |
|             started_at=started_at, | |
|         ) | |
| ``` | |
| --- | |
| ## 5. Deployment Configuration | |
| ### 5.1 Local mode (macOS) | |
| **Configuration file: config/local.yaml** | |
| ```yaml | |
| deployment_mode: local | |
| local_storage_path: ./data/files | |
| sqlite_path: ./data/app.db | |
| local_max_workers: 1 # single GPU on macOS, tasks run serially | |
| ``` | |
| **Start commands**: | |
| ```shell | |
| # Install dependencies | |
| pip install -r requirements/base.txt -r requirements/local.txt | |
| # Start the API service | |
| uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload | |
| # No additional services are required. | |
| ``` | |
| **docker-compose.local.yml**: | |
| ```yaml | |
| version: '3.8' | |
| services: | |
|   api: | |
|     build: . | |
|     ports: | |
|       - "8000:8000" | |
|     volumes: | |
|       - ./data:/app/data | |
|       - ./logs:/app/logs | |
|     environment: | |
|       - DEPLOYMENT_MODE=local | |
| ``` | |
| ### 5.2 Server mode (Linux) | |
| **Configuration file: config/server.yaml** | |
| ```yaml | |
| deployment_mode: server | |
| database_url: postgresql+asyncpg://user:pass@postgres/gpt_sovits | |
| redis_url: redis://redis:6379/0 | |
| celery_broker_url: redis://redis:6379/1 | |
| s3_endpoint: minio:9000 | |
| ``` | |
| **Start commands**: | |
| ```shell | |
| # Start all services with docker-compose | |
| docker-compose -f docker-compose.server.yml up -d | |
| ``` | |
| **docker-compose.server.yml**: | |
| ```yaml | |
| version: '3.8' | |
| services: | |
|   api: | |
|     build: . | |
|     ports: | |
|       - "8000:8000" | |
|     depends_on: | |
|       - postgres | |
|       - redis | |
|       - minio | |
|     environment: | |
|       - DEPLOYMENT_MODE=server | |
|       - DATABASE_URL=postgresql+asyncpg://user:pass@postgres/gpt_sovits | |
|       - REDIS_URL=redis://redis:6379/0 | |
|   celery-worker: | |
|     build: . | |
|     command: celery -A app.workers.celery_worker worker --loglevel=info --concurrency=2 | |
|     depends_on: | |
|       - redis | |
|       - postgres | |
|     environment: | |
|       - DEPLOYMENT_MODE=server | |
|       # Assumption: the worker reads the same connection settings as the API | |
|       - DATABASE_URL=postgresql+asyncpg://user:pass@postgres/gpt_sovits | |
|       - REDIS_URL=redis://redis:6379/0 | |
|     deploy: | |
|       replicas: 2 # multiple workers | |
|   postgres: | |
|     image: postgres:15 | |
|     volumes: | |
|       - postgres_data:/var/lib/postgresql/data | |
|     environment: | |
|       # Must match the user, password and database in DATABASE_URL | |
|       POSTGRES_USER: user | |
|       POSTGRES_PASSWORD: pass | |
|       POSTGRES_DB: gpt_sovits | |
|   redis: | |
|     image: redis:7-alpine | |
|   minio: | |
|     image: minio/minio | |
|     command: server /data --console-address ":9001" | |
|     ports: | |
|       - "9000:9000" | |
|       - "9001:9001" | |
|     volumes: | |
|       - minio_data:/data | |
| volumes: | |
|   postgres_data: | |
|   minio_data: | |
| ``` | |
| --- | |
| ## 6. Database Options Compared | |
| ### 6.1 Local mode - SQLite | |
| **Schema**: | |
| ```sql | |
| -- tasks table (Quick Mode one-click training tasks) | |
| CREATE TABLE tasks ( | |
| id TEXT PRIMARY KEY, | |
| exp_name TEXT NOT NULL, | |
| version TEXT NOT NULL, | |
| status TEXT NOT NULL, | |
| current_stage TEXT, | |
| overall_progress REAL, | |
| config TEXT, -- JSON | |
| created_at TEXT, | |
| started_at TEXT, | |
| completed_at TEXT, | |
| error_message TEXT | |
| ); | |
| -- experiments table (Advanced Mode experiments) | |
| CREATE TABLE experiments ( | |
| id TEXT PRIMARY KEY, | |
| exp_name TEXT NOT NULL, | |
| version TEXT NOT NULL, | |
| exp_root TEXT DEFAULT 'logs', | |
| gpu_numbers TEXT DEFAULT '0', | |
| is_half INTEGER DEFAULT 1, | |
| audio_file_id TEXT NOT NULL, | |
| status TEXT NOT NULL, | |
| created_at TEXT, | |
| updated_at TEXT, | |
| FOREIGN KEY (audio_file_id) REFERENCES files(id) | |
| ); | |
| -- stages table (Advanced Mode stage status) | |
| CREATE TABLE stages ( | |
| id TEXT PRIMARY KEY, | |
| experiment_id TEXT NOT NULL, | |
| stage_type TEXT NOT NULL, | |
| status TEXT DEFAULT 'pending', | |
| progress REAL DEFAULT 0, | |
| message TEXT, | |
| job_id TEXT, | |
| config TEXT, -- JSON | |
| outputs TEXT, -- JSON | |
| started_at TEXT, | |
| completed_at TEXT, | |
| error_message TEXT, | |
| FOREIGN KEY (experiment_id) REFERENCES experiments(id) | |
| ); | |
| -- files table | |
| CREATE TABLE files ( | |
| id TEXT PRIMARY KEY, | |
| filename TEXT NOT NULL, | |
| storage_path TEXT NOT NULL, | |
| purpose TEXT, | |
| size_bytes INTEGER, | |
| uploaded_at TEXT | |
| ); | |
| -- models table | |
| CREATE TABLE models ( | |
| id TEXT PRIMARY KEY, | |
| task_id TEXT, | |
| experiment_id TEXT, | |
| exp_name TEXT NOT NULL, | |
| model_type TEXT NOT NULL, | |
| storage_path TEXT NOT NULL, | |
| epoch INTEGER, | |
| created_at TEXT, | |
| FOREIGN KEY (task_id) REFERENCES tasks(id), | |
| FOREIGN KEY (experiment_id) REFERENCES experiments(id) | |
| ); | |
| -- Indexes | |
| CREATE INDEX idx_tasks_status ON tasks(status); | |
| CREATE INDEX idx_experiments_status ON experiments(status); | |
| CREATE INDEX idx_stages_experiment ON stages(experiment_id); | |
| CREATE INDEX idx_stages_status ON stages(status); | |
| ``` | |
| **Migration management**: a simple version-number file plus SQL scripts. | |
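A lightweight migration runner for the local mode could look like the sketch below. Note the substitution: instead of a separate version-number file, it stores the schema version in SQLite's built-in `PRAGMA user_version`, which keeps the version and the data in the same file. The `MIGRATIONS` list and `migrate` function are hypothetical, and the SQL statements are shortened stand-ins for the real schema scripts.

```python
import sqlite3

# Ordered migrations: index 0 upgrades version 0 -> 1, and so on.
# The statements are shortened stand-ins for the real schema scripts.
MIGRATIONS = [
    "CREATE TABLE files (id TEXT PRIMARY KEY, filename TEXT NOT NULL);",
    "CREATE TABLE tasks (id TEXT PRIMARY KEY, status TEXT NOT NULL);"
    "CREATE INDEX idx_tasks_status ON tasks(status);",
]


def migrate(conn: sqlite3.Connection) -> int:
    """Apply every migration newer than the stored version; return the new version."""
    current = conn.execute("PRAGMA user_version").fetchone()[0]
    for version, script in enumerate(MIGRATIONS[current:], start=current + 1):
        conn.executescript(script)
        # PRAGMA does not accept bound parameters; `version` is a trusted int.
        conn.execute(f"PRAGMA user_version = {version}")
        conn.commit()
    return conn.execute("PRAGMA user_version").fetchone()[0]


conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys = ON")  # SQLite does not enforce foreign keys by default
version = migrate(conn)
```

Running `migrate` a second time is a no-op, so it can be called unconditionally at application start-up.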
| ### 6.2 Server mode - PostgreSQL | |
| **Using SQLAlchemy + Alembic**: | |
| ```python | |
| # app/models/db/models.py | |
| from sqlalchemy import Column, String, Float, JSON, DateTime, Boolean, ForeignKey | |
| from sqlalchemy.orm import declarative_base, relationship | |
| Base = declarative_base() | |
| # The `files` table referenced by ExperimentModel.audio_file_id needs its own model; it is not shown here. | |
| class TaskModel(Base): | |
|     """Quick Mode task model.""" | |
|     __tablename__ = "tasks" | |
|     id = Column(String, primary_key=True) | |
|     exp_name = Column(String, nullable=False, index=True) | |
|     version = Column(String, nullable=False) | |
|     status = Column(String, nullable=False, index=True) | |
|     current_stage = Column(String) | |
|     overall_progress = Column(Float) | |
|     config = Column(JSON) | |
|     created_at = Column(DateTime, index=True) | |
|     started_at = Column(DateTime) | |
|     completed_at = Column(DateTime) | |
|     error_message = Column(String) | |
| class ExperimentModel(Base): | |
|     """Advanced Mode experiment model.""" | |
|     __tablename__ = "experiments" | |
|     id = Column(String, primary_key=True) | |
|     exp_name = Column(String, nullable=False, index=True) | |
|     version = Column(String, nullable=False) | |
|     exp_root = Column(String, default="logs") | |
|     gpu_numbers = Column(String, default="0") | |
|     is_half = Column(Boolean, default=True) | |
|     audio_file_id = Column(String, ForeignKey("files.id"), nullable=False) | |
|     status = Column(String, nullable=False, index=True) | |
|     created_at = Column(DateTime, index=True) | |
|     updated_at = Column(DateTime) | |
|     # Relationships | |
|     stages = relationship("StageModel", back_populates="experiment") | |
| class StageModel(Base): | |
|     """Advanced Mode stage model.""" | |
|     __tablename__ = "stages" | |
|     id = Column(String, primary_key=True) | |
|     experiment_id = Column(String, ForeignKey("experiments.id"), nullable=False) | |
|     stage_type = Column(String, nullable=False) | |
|     status = Column(String, default="pending", index=True) | |
|     progress = Column(Float, default=0) | |
|     message = Column(String) | |
|     job_id = Column(String) | |
|     config = Column(JSON) | |
|     outputs = Column(JSON) | |
|     started_at = Column(DateTime) | |
|     completed_at = Column(DateTime) | |
|     error_message = Column(String) | |
|     # Relationships | |
|     experiment = relationship("ExperimentModel", back_populates="stages") | |
| ``` | |
| **Migration**: `alembic upgrade head` | |
| --- | |
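| ## 7. Task Queue Options Compared | |
| ### 7.0 Key finding: the execution model of the training pipeline | |
| > [!IMPORTANT] | |
| > **Training jobs actually run in subprocesses!** | |
| > | |
| > An analysis of `training_pipeline/stages/training.py` shows that every training stage invokes a standalone Python script through `subprocess.Popen`: | |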
| > ```python | |
| > cmd = f'PYTHONPATH=.:GPT_SoVITS "{cfg.python_exec}" -s GPT_SoVITS/s2_train.py --config "{tmp_config_path}"' | |
| > self._process = self._run_command(cmd, wait=True) | |
| > ``` | |
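| **This means**: | |
| 1. The GPU-intensive training computation happens in **separate subprocesses** and is not constrained by the Python GIL of the API process | |
| 2. The FastAPI main process only has to "manage" these subprocesses: start, monitor, stop | |
| 3. A ThreadPoolExecutor would merely act as a "supervisor" here, waiting for the blocking subprocess call to finish | |
| 4. A cleaner option is `asyncio.subprocess`, which is fully non-blocking | |
| **Process model diagram**: | |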
| ``` | |
| ┌─────────────────────────────────────────────────────────────────┐ | |
| │                      FastAPI main process                       │ | |
| │  ┌──────────────────┐    ┌──────────────────────────────────┐   │ | |
| │  │ AsyncIO Event    │    │ AsyncTrainingManager             │   │ | |
| │  │ Loop             │◄───│ - manages subprocess lifecycle   │   │ | |
| │  │                  │    │ - async reads of stdout/stderr   │   │ | |
| │  │                  │    │ - pushes progress to SSE         │   │ | |
| │  └──────────────────┘    └───────────────┬──────────────────┘   │ | |
| └──────────────────────────────────────────┼──────────────────────┘ | |
|                                            │ asyncio.create_subprocess_exec() | |
|                  ┌─────────────────────────┼─────────────────────────┐ | |
|                  ▼                         ▼                         ▼ | |
|           ┌──────────────┐          ┌──────────────┐          ┌──────────────┐ | |
|           │ s2_train.py  │          │ s1_train.py  │          │ inference.py │ | |
|           │ (GPU train)  │          │ (GPU train)  │          │ (inference)  │ | |
|           └──────────────┘          └──────────────┘          └──────────────┘ | |
| ``` | |
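The process model can be tried out in isolation. The sketch below starts a child process with `asyncio.create_subprocess_exec()` and reads its output line by line without blocking the event loop. It is a simplified illustration: `run_and_stream` is a hypothetical helper, and a one-line Python snippet that prints `====> Epoch: N` markers stands in for the real training script.

```python
import asyncio
import sys


async def run_and_stream(*cmd: str) -> tuple[int, list[str]]:
    """Start a child process and collect its output lines without blocking the loop."""
    process = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,  # merge stderr into the same stream
    )
    lines: list[str] = []
    assert process.stdout is not None
    async for raw in process.stdout:  # yields one line at a time as it arrives
        lines.append(raw.decode(errors="replace").rstrip())
    return await process.wait(), lines


# Stand-in for a training script: prints two epoch markers and exits.
child_code = "for e in (1, 2): print(f'====> Epoch: {e}', flush=True)"
returncode, output = asyncio.run(run_and_stream(sys.executable, "-c", child_code))
```

While the `async for` loop is waiting for the next line, the event loop remains free to serve HTTP requests, which is the property the thread-pool approach lacks.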
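| ### 7.0.1 Progress tracking capabilities | |
| An analysis of `GPT_SoVITS/s2_train.py` shows that the training script produces the following output: | |
| | Output type | Destination | Example | Traceability | | |
| |---------|---------|------|---------| | |
| | **Epoch progress** | logger → stdout | `"====> Epoch: 5"` | ✅ Parseable | | |
| | **Training percentage** | logger → stdout | `"Train Epoch: 1 [50.0%]"` | ✅ Parseable | | |
| | **Loss values** | logger → stdout | `[0.23, 0.45, ...]` | ✅ Parseable | | |
| | **Batch progress bar** | tqdm → stderr | `45%\|████▌ \| 45/100` | ⚠️ Irregular format | | |
| | **Model checkpoint** | logger → stdout | `"saving ckpt xxx_e5:..."` | ✅ Parseable | | |
| **Current problems**: | |
| 1. ❌ The output is not JSON and has to be parsed with regular expressions | |
| 2. ❌ The tqdm progress bar format is complex and hard to parse precisely | |
| 3. ❌ There is no unified progress reporting protocol | |
| **Solution**: modify the training script to emit progress as JSON | |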
| ```python | |
| # Add a progress-reporting helper to the training script | |
| import json | |
| from typing import Optional | |
| def report_progress(stage: str, epoch: int, total_epochs: int, | |
|                     batch: Optional[int] = None, total_batches: Optional[int] = None, | |
|                     loss: Optional[dict] = None, message: Optional[str] = None): | |
|     """Print progress as JSON on stdout so that the manager can parse it.""" | |
|     progress_info = { | |
|         "type": "progress", | |
|         "stage": stage, | |
|         "epoch": epoch, | |
|         "total_epochs": total_epochs, | |
|         # Percentage (0-100); epochs are 1-based, so the current epoch is not counted as done | |
|         "progress": (epoch - 1) / total_epochs * 100, | |
|     } | |
|     if batch is not None and total_batches: | |
|         progress_info["batch"] = batch | |
|         progress_info["total_batches"] = total_batches | |
|         progress_info["progress"] = (epoch - 1 + batch / total_batches) / total_epochs * 100 | |
|     if loss: | |
|         progress_info["loss"] = loss | |
|     if message: | |
|         progress_info["message"] = message | |
|     # A special prefix and suffix make the line easy to recognise | |
|     print(f"##PROGRESS##{json.dumps(progress_info)}##", flush=True) | |
| # Call it inside the training loop | |
| for epoch in range(epoch_str, hps.train.epochs + 1): | |
|     report_progress("SoVITS training", epoch, hps.train.epochs, message=f"Starting epoch {epoch}") | |
|     for batch_idx, data in enumerate(train_loader): | |
|         # ... training code ... | |
|         if batch_idx % 10 == 0:  # report every 10 batches | |
|             report_progress("SoVITS training", epoch, hps.train.epochs, | |
|                             batch_idx, len(train_loader), | |
|                             loss={"g_total": loss_gen_all.item()}) | |
| ``` | |
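| **Parsing on the manager side**: | |
| ```python | |
| import json | |
| import re | |
| async def _monitor_process_output(self, task_id: str, process): | |
|     """Parse the subprocess output to obtain progress.""" | |
|     async for line in process.stdout: | |
|         text = line.decode(errors="replace").strip() | |
|         # Detect the JSON progress marker | |
|         if text.startswith("##PROGRESS##") and text.endswith("##"): | |
|             json_str = text[12:-2]  # strip the "##PROGRESS##" prefix and the "##" suffix | |
|             try: | |
|                 progress_info = json.loads(json_str) | |
|             except json.JSONDecodeError: | |
|                 continue  # ignore malformed progress lines | |
|             await self._send_progress(task_id, progress_info) | |
|         # Backward compatibility: parse the legacy log format with a regex | |
|         elif "Train Epoch:" in text: | |
|             match = re.search(r"Train Epoch: (\d+) \[(\d+\.?\d*)%\]", text) | |
|             if match: | |
|                 epoch, percent = match.groups() | |
|                 # percent is the progress within the current epoch, not overall progress | |
|                 await self._send_progress(task_id, { | |
|                     "stage": "SoVITS training", | |
|                     "epoch": int(epoch), | |
|                     "progress": float(percent), | |
|                     "message": text | |
|                 }) | |
| ``` | |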
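Keeping the line-parsing logic in a pure function makes the protocol easy to unit-test without spawning a process. The sketch below is one possible factoring, not the project's code: `parse_progress_line` and the `epoch_percent` key are hypothetical names, while the `##PROGRESS##…##` framing and the legacy `Train Epoch: N [P%]` format are taken from the text above.

```python
import json
import re
from typing import Optional

PREFIX, SUFFIX = "##PROGRESS##", "##"
LEGACY_RE = re.compile(r"Train Epoch: (\d+) \[(\d+\.?\d*)%\]")


def parse_progress_line(text: str) -> Optional[dict]:
    """Return a progress dict for a recognised line, or None for ordinary log output."""
    text = text.strip()
    framed = text.startswith(PREFIX) and text.endswith(SUFFIX)
    if framed and len(text) >= len(PREFIX) + len(SUFFIX):
        try:
            payload = json.loads(text[len(PREFIX):-len(SUFFIX)])
        except json.JSONDecodeError:
            return None  # malformed marker: treat it as plain log text
        return payload if isinstance(payload, dict) else None
    match = LEGACY_RE.search(text)
    if match:
        epoch, percent = match.groups()
        # The legacy percentage describes progress within the epoch only.
        return {"epoch": int(epoch), "epoch_percent": float(percent), "message": text}
    return None
```

The asynchronous monitor then reduces to decoding each line, calling this function, and forwarding any non-`None` result.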
| --- | |
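| ### 7.0.2 Task control capabilities | |
| | Operation | Implementation | macOS support | Notes | | |
| |------|---------|-----------|------| | |
| | **Terminate** | `process.terminate()` | ✅ Fully supported | Sends SIGTERM; work in the current epoch may be lost | | |
| | **Force kill** | `process.kill()` | ✅ Fully supported | Sends SIGKILL and stops the process unconditionally | | |
| | **Pause** | `os.kill(pid, signal.SIGSTOP)` | ⚠️ Supported but risky | GPU/CUDA state may become inconsistent | | |
| | **Resume** | `os.kill(pid, signal.SIGCONT)` | ⚠️ Must be paired with SIGSTOP | Same as above | | |
| | **Graceful stop** | Requires cooperation from the training script | ❌ Not supported yet | The training script has to be modified | | |
| **Graceful stop approach**: | |
| The training script has to be modified to handle signals: | |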
| ```python | |
| # Add near the top of the training script | |
| import json | |
| import signal | |
| import time | |
| should_stop = False | |
| should_pause = False | |
| def handle_stop_signal(signum, frame): | |
|     """On SIGUSR1, stop once the current epoch has finished.""" | |
|     global should_stop | |
|     should_stop = True | |
|     print(json.dumps({"type": "signal", "message": "Stop signal received; stopping after the current epoch"}), flush=True) | |
| def handle_pause_signal(signum, frame): | |
|     """On SIGUSR2, toggle between paused and running.""" | |
|     global should_pause | |
|     should_pause = not should_pause | |
|     status = "paused" if should_pause else "resumed" | |
|     print(json.dumps({"type": "signal", "message": f"Training {status}"}), flush=True) | |
| signal.signal(signal.SIGUSR1, handle_stop_signal) | |
| signal.signal(signal.SIGUSR2, handle_pause_signal) | |
| # Check the flags inside the training loop | |
| for epoch in range(epoch_str, hps.train.epochs + 1): | |
|     # Pause check (a stop request also ends the pause) | |
|     while should_pause and not should_stop: | |
|         time.sleep(1) | |
|     # Stop check: runs at the start of an epoch, i.e. after the previous one has finished | |
|     if should_stop: | |
|         print(json.dumps({"type": "progress", "status": "stopped", | |
|                           "message": f"Training stopped before epoch {epoch}"}), flush=True) | |
|         # Save a checkpoint | |
|         save_checkpoint(...) | |
|         break | |
|     # ... normal training ... | |
| ``` | |
**Manager-side control:**

```python
import asyncio
import os
import signal


class AsyncTrainingManager:
    async def pause(self, task_id: str) -> bool:
        """Pause the task (sending SIGUSR2 again resumes it, because the handler toggles)."""
        if task_id in self.running_processes:
            process = self.running_processes[task_id]
            os.kill(process.pid, signal.SIGUSR2)
            return True
        return False

    async def graceful_stop(self, task_id: str) -> bool:
        """Graceful stop (stops after the current epoch finishes)."""
        if task_id in self.running_processes:
            process = self.running_processes[task_id]
            os.kill(process.pid, signal.SIGUSR1)
            return True
        return False

    async def force_stop(self, task_id: str) -> bool:
        """Force stop: terminate, then kill if the process does not exit in time."""
        if task_id in self.running_processes:
            process = self.running_processes[task_id]
            process.terminate()
            try:
                await asyncio.wait_for(process.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                process.kill()
                await process.wait()  # reap the killed process
            return True
        return False
```

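
The signal-based graceful stop described above can be exercised end to end with a small, self-contained sketch. The child script here is a hypothetical stand-in for a training script (it only sleeps in short "epochs"), and `run_graceful_stop_demo` is an illustrative name, not part of the project. The parent waits for a `ready` line before sending `SIGUSR1`, so the handler is known to be installed. This is POSIX-only.

```python
import asyncio
import signal
import sys
import textwrap
from typing import Tuple

# Hypothetical stand-in for a training script: installs a SIGUSR1 handler,
# then runs short simulated epochs until asked to stop.
CHILD_SOURCE = textwrap.dedent("""
    import signal, time
    should_stop = False
    def on_stop(signum, frame):
        global should_stop
        should_stop = True
    signal.signal(signal.SIGUSR1, on_stop)
    print("ready", flush=True)
    epoch = 0
    while not should_stop:
        time.sleep(0.05)  # one simulated epoch
        epoch += 1
    print(f"stopped after epoch {epoch}", flush=True)
""")


async def run_graceful_stop_demo() -> Tuple[int, str]:
    process = await asyncio.create_subprocess_exec(
        sys.executable, "-c", CHILD_SOURCE,
        stdout=asyncio.subprocess.PIPE,
    )
    # Wait until the child reports that its handler is installed.
    await process.stdout.readline()
    # Request a graceful stop; the child finishes its current epoch first.
    process.send_signal(signal.SIGUSR1)
    remaining, _ = await process.communicate()
    return process.returncode, remaining.decode().strip()


if __name__ == "__main__":
    print(asyncio.run(run_graceful_stop_demo()))
```

Because the child exits on its own after the flag is set, the return code is that of a normal exit rather than a signal termination, which lets the manager distinguish a graceful stop from a forced one.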
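
> [!WARNING]
> **Risks of pausing training:**
> - Pausing a process with SIGSTOP/SIGCONT on macOS may leave GPU resources locked
> - After a long pause, the CUDA context may no longer be valid on resume
> - Recommended instead: save a checkpoint and terminate, then resume from the checkpoint when needed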
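
The recommended "checkpoint, terminate, resume" pattern can be sketched as follows. This is a minimal illustration under stated assumptions: the checkpoint is a JSON file holding only an epoch counter, and `train`, `stop_after`, and `demo` are hypothetical names. A real training script would save model and optimizer state as well.

```python
import json
import tempfile
from pathlib import Path
from typing import Optional, Tuple


def train(checkpoint_path: Path, total_epochs: int, stop_after: Optional[int] = None) -> int:
    """Run simulated epochs, resuming from checkpoint_path if it exists.

    Returns the last completed epoch.
    """
    start_epoch = 0
    if checkpoint_path.exists():
        start_epoch = json.loads(checkpoint_path.read_text())["epoch"]

    completed = start_epoch
    for epoch in range(start_epoch + 1, total_epochs + 1):
        # ... one real training epoch would run here ...
        completed = epoch
        # Persist progress after every epoch so a terminated run can resume.
        checkpoint_path.write_text(json.dumps({"epoch": completed}))
        if stop_after is not None and completed >= stop_after:
            break  # simulates the process being terminated
    return completed


def demo() -> Tuple[int, int]:
    with tempfile.TemporaryDirectory() as tmp:
        ckpt = Path(tmp) / "ckpt.json"
        first = train(ckpt, total_epochs=5, stop_after=2)  # interrupted run
        second = train(ckpt, total_epochs=5)               # resumed run
        return first, second
```

The second call starts from the epoch recorded by the first, so no work is repeated and no process has to stay suspended while holding GPU resources.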
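
---

### 7.1 Local Mode - Task Management Options ✅ Implemented

> [!TIP]
> Consider the following when choosing a task management approach:
> - **Execution model**: training already runs in a child process, so the task manager only needs to monitor it
> - **Delivery form**: PyInstaller packaging requires a single main process
> - **Simplicity**: asyncio.subprocess is simpler than a ThreadPool

#### Option 1: asyncio.subprocess ⭐⭐ Recommended (all scenarios) ✅ Selected and implemented

> **Implementation file**: `app/adapters/local/task_queue.py`

**Core design ideas:**

- Start the training subprocess asynchronously with `asyncio.create_subprocess_exec()`
- Fully non-blocking, which fits FastAPI's async model
- No ThreadPool needed, so the architecture is simpler
- Read subprocess output asynchronously and parse progress in real time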
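
The core idea in the list above can be sketched in a few lines. The child command here is a hypothetical stand-in that prints one plain log line and two JSON progress lines; `collect_progress` is an illustrative name and not the project's API.

```python
import asyncio
import json
import sys
from typing import Dict, List

# Hypothetical child: one plain log line, then one JSON progress line per step.
CHILD_SOURCE = (
    "import json\n"
    "print('loading model')\n"
    "for step in (50, 100):\n"
    "    print(json.dumps({'type': 'progress', 'progress': step}), flush=True)\n"
)


async def collect_progress() -> List[Dict]:
    process = await asyncio.create_subprocess_exec(
        sys.executable, "-c", CHILD_SOURCE,
        stdout=asyncio.subprocess.PIPE,
    )
    events: List[Dict] = []
    while True:
        # readline() suspends this coroutine, leaving the event loop free
        line = await process.stdout.readline()
        if not line:
            break  # EOF: the child closed stdout
        text = line.decode().strip()
        try:
            event = json.loads(text)
        except json.JSONDecodeError:
            continue  # not JSON, treat as an ordinary log line
        if isinstance(event, dict):
            events.append(event)
    await process.wait()
    return events


if __name__ == "__main__":
    print(asyncio.run(collect_progress()))
```

While the coroutine waits on `readline()`, the same event loop can keep serving HTTP requests, which is why no worker thread is needed.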
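
**Pros:**

- Pure asyncio, integrates directly with FastAPI
- No ThreadPool, so no thread-management overhead
- Monitors multiple subprocesses asynchronously
- Simpler code structure
- Compatible with PyInstaller packaging

**Cons:**

- Requires changing how the Pipeline is executed (from synchronous to asynchronous)
- Progress must be extracted from stdout/stderr
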
**Full implementation:**

```python
# app/adapters/local/task_queue.py
import asyncio
import json
import os
import sys
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import AsyncGenerator, Dict, List, Set

import aiosqlite

from app.adapters.base import TaskQueueAdapter


def _utc_now() -> str:
    """Current UTC time as an ISO 8601 string."""
    return datetime.now(timezone.utc).isoformat()


class AsyncTrainingManager(TaskQueueAdapter):
    """
    Asynchronous task manager built on asyncio.subprocess.

    Features:
    1. Starts training subprocesses with asyncio.create_subprocess_exec()
    2. Fully non-blocking, which fits FastAPI's async model
    3. Persists task state in SQLite so it survives an application restart
    4. Parses subprocess output in real time to track progress
    """

    # Columns that _update_status() may set through keyword arguments
    _STATUS_COLUMNS = {"started_at", "completed_at", "error_message"}

    def __init__(self, db_path: str = "./data/tasks.db"):
        self.db_path = db_path

        # Runtime state
        self.running_processes: Dict[str, asyncio.subprocess.Process] = {}  # task_id -> Process
        self.progress_channels: Dict[str, asyncio.Queue] = {}  # task_id -> Queue
        self._stderr_tail: Dict[str, List[str]] = {}  # task_id -> last stderr lines
        self._cancel_requested: Set[str] = set()  # task_ids cancelled via cancel()
        # Strong references keep background tasks from being garbage-collected
        self._background_tasks: Set[asyncio.Task] = set()

        # Initialize the database
        self._init_db_sync()

    def _init_db_sync(self):
        """Initialize the database synchronously (called at startup)."""
        import sqlite3
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS task_queue (
                    job_id TEXT PRIMARY KEY,
                    task_id TEXT NOT NULL,
                    config TEXT NOT NULL,
                    status TEXT DEFAULT 'queued',
                    current_stage TEXT,
                    progress REAL DEFAULT 0,
                    created_at TEXT,
                    started_at TEXT,
                    completed_at TEXT,
                    error_message TEXT
                )
            ''')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_task_queue_status ON task_queue(status)')
            conn.commit()

    def _spawn(self, job_id: str, task_id: str, config: Dict) -> None:
        """Start the pipeline coroutine in the background and keep a reference to it."""
        task = asyncio.create_task(self._run_training_async(job_id, task_id, config))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str:
        """Add a task to the queue and start it asynchronously."""
        job_id = str(uuid.uuid4())

        # Persist to SQLite
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                '''INSERT INTO task_queue (job_id, task_id, config, status, created_at)
                   VALUES (?, ?, ?, 'queued', ?)''',
                (job_id, task_id, json.dumps(config), _utc_now())
            )
            await db.commit()

        # Create the progress queue
        self.progress_channels[task_id] = asyncio.Queue()

        # Start the training task in the background
        self._spawn(job_id, task_id, config)

        return job_id

    async def _run_training_async(self, job_id: str, task_id: str, config: Dict):
        """Run the training Pipeline asynchronously."""
        try:
            await self._update_status(job_id, 'running', started_at=_utc_now())
            await self._send_progress(task_id, {"status": "running", "message": "Training is starting..."})

            # Build the training command.
            # The wrapper script runs the full Pipeline and prints progress as JSON lines.
            script_path = self._get_pipeline_script_path()
            config_path = await self._write_config_file(task_id, config)

            # Create the subprocess.
            # Note: in a PyInstaller-frozen build, sys.executable is the bundled
            # executable rather than a Python interpreter, so the bundle's entry
            # point has to dispatch on these arguments itself.
            process = await asyncio.create_subprocess_exec(
                sys.executable, script_path,
                '--config', config_path,
                '--task-id', task_id,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                env={**os.environ, 'PYTHONPATH': os.pathsep.join(['.', 'GPT_SoVITS'])}
            )

            self.running_processes[task_id] = process

            # Read stdout/stderr until EOF and parse progress
            await self._monitor_process_output(task_id, process)

            # Wait for the process to exit
            returncode = await process.wait()

            if task_id in self._cancel_requested:
                # cancel() terminated the process; do not report this as a failure
                await self._update_status(job_id, 'cancelled')
                await self._send_progress(task_id, {"status": "cancelled", "message": "Task cancelled"})
            elif returncode == 0:
                await self._update_status(job_id, 'completed', completed_at=_utc_now())
                await self._send_progress(task_id, {"status": "completed", "progress": 100, "message": "Training completed"})
            else:
                # stderr was already drained by _monitor_process_output, so use the captured tail
                error_msg = "\n".join(self._stderr_tail.get(task_id, [])) or f"Process exited with code {returncode}"
                await self._update_status(job_id, 'failed', error_message=error_msg)
                await self._send_progress(task_id, {"status": "failed", "error": error_msg})

        except asyncio.CancelledError:
            await self._update_status(job_id, 'cancelled')
            await self._send_progress(task_id, {"status": "cancelled", "message": "Task cancelled"})
            raise
        except Exception as e:
            await self._update_status(job_id, 'failed', error_message=str(e))
            await self._send_progress(task_id, {"status": "failed", "error": str(e)})
        finally:
            self.running_processes.pop(task_id, None)
            self._stderr_tail.pop(task_id, None)
            self._cancel_requested.discard(task_id)
            # Remove the temporary config file
            await self._cleanup_config_file(task_id)

    async def _monitor_process_output(self, task_id: str, process: asyncio.subprocess.Process):
        """Monitor subprocess output asynchronously and parse progress."""

        async def read_stream(stream, is_stderr=False):
            while True:
                line = await stream.readline()
                if not line:
                    break

                text = line.decode(errors='replace').strip()
                if not text:
                    continue

                # Try to parse a JSON progress message
                if text.startswith('{') and text.endswith('}'):
                    try:
                        progress_info = json.loads(text)
                        await self._send_progress(task_id, progress_info)

                        # Also persist progress to the database
                        if 'progress' in progress_info or 'stage' in progress_info:
                            await self._update_progress_in_db(task_id, progress_info)
                    except json.JSONDecodeError:
                        pass
                elif is_stderr:
                    # Keep the last 20 lines for the failure message, and forward as a log event
                    tail = self._stderr_tail.setdefault(task_id, [])
                    tail.append(text)
                    del tail[:-20]
                    await self._send_progress(task_id, {"type": "log", "level": "error", "message": text})

        # Read stdout and stderr concurrently
        await asyncio.gather(
            read_stream(process.stdout, is_stderr=False),
            read_stream(process.stderr, is_stderr=True)
        )

    async def _send_progress(self, task_id: str, progress_info: Dict):
        """Push progress to the subscriber queue."""
        if task_id in self.progress_channels:
            await self.progress_channels[task_id].put(progress_info)

    async def _update_status(self, job_id: str, status: str, **kwargs):
        """Update the task status."""
        async with aiosqlite.connect(self.db_path) as db:
            updates = ["status = ?"]
            values = [status]

            for key, value in kwargs.items():
                # Column names are interpolated into SQL, so only allow known ones
                if key not in self._STATUS_COLUMNS:
                    raise ValueError(f"Unexpected column: {key}")
                updates.append(f"{key} = ?")
                values.append(value)

            values.append(job_id)
            await db.execute(
                f"UPDATE task_queue SET {', '.join(updates)} WHERE job_id = ?",
                values
            )
            await db.commit()

    async def _update_progress_in_db(self, task_id: str, progress_info: Dict):
        """Update progress in the database."""
        async with aiosqlite.connect(self.db_path) as db:
            updates = []
            values = []

            if 'progress' in progress_info:
                updates.append("progress = ?")
                values.append(progress_info['progress'])
            if 'stage' in progress_info:
                updates.append("current_stage = ?")
                values.append(progress_info['stage'])

            if updates:
                values.append(task_id)
                await db.execute(
                    f"UPDATE task_queue SET {', '.join(updates)} WHERE task_id = ?",
                    values
                )
                await db.commit()

    async def get_status(self, job_id: str) -> Dict:
        """Get the task status."""
        async with aiosqlite.connect(self.db_path) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM task_queue WHERE job_id = ?", (job_id,)
            ) as cursor:
                row = await cursor.fetchone()
                if row:
                    return dict(row)
        return {"status": "not_found"}

    async def cancel(self, job_id: str) -> bool:
        """Cancel a task."""
        # Look up the task_id
        async with aiosqlite.connect(self.db_path) as db:
            async with db.execute(
                "SELECT task_id FROM task_queue WHERE job_id = ?", (job_id,)
            ) as cursor:
                row = await cursor.fetchone()
                if not row:
                    return False
                task_id = row[0]

        # Terminate the process
        if task_id in self.running_processes:
            process = self.running_processes[task_id]
            # Tell _run_training_async that the non-zero exit is intentional
            self._cancel_requested.add(task_id)
            process.terminate()

            # Wait for the process to exit
            try:
                await asyncio.wait_for(process.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                process.kill()
                await process.wait()

            await self._update_status(job_id, 'cancelled')
            return True

        return False

    async def subscribe_progress(self, task_id: str) -> AsyncGenerator[Dict, None]:
        """Subscribe to task progress (SSE stream)."""
        if task_id not in self.progress_channels:
            self.progress_channels[task_id] = asyncio.Queue()

        queue = self.progress_channels[task_id]

        while True:
            try:
                progress = await asyncio.wait_for(queue.get(), timeout=30.0)
                yield progress

                if progress.get('status') in ['completed', 'failed', 'cancelled']:
                    break
            except asyncio.TimeoutError:
                # Send a heartbeat to keep the connection alive
                yield {"type": "heartbeat", "timestamp": _utc_now()}

    async def recover_pending_tasks(self) -> int:
        """
        Recover unfinished tasks after an application restart.

        Note: subprocesses are gone once the application has restarted, so this can only:
        1. Mark tasks in the running state as interrupted
        2. Optionally restart tasks in the queued state
        """
        async with aiosqlite.connect(self.db_path) as db:
            # Mark running tasks as interrupted (the user decides whether to retry)
            await db.execute(
                "UPDATE task_queue SET status = 'interrupted' WHERE status = 'running'"
            )
            await db.commit()

            # Restart queued tasks
            db.row_factory = aiosqlite.Row
            async with db.execute(
                "SELECT * FROM task_queue WHERE status = 'queued' ORDER BY created_at"
            ) as cursor:
                queued_tasks = await cursor.fetchall()

        for task in queued_tasks:
            task_id = task['task_id']
            config = json.loads(task['config'])
            job_id = task['job_id']

            self.progress_channels[task_id] = asyncio.Queue()
            self._spawn(job_id, task_id, config)

        return len(queued_tasks)

    def _get_pipeline_script_path(self) -> str:
        """Return the path of the Pipeline runner script."""
        # The script wraps TrainingPipeline and prints progress as JSON
        return os.path.join(os.path.dirname(__file__), '..', '..', 'scripts', 'run_pipeline.py')

    async def _write_config_file(self, task_id: str, config: Dict) -> str:
        """Write the temporary config file."""
        config_dir = Path(self.db_path).parent / 'configs'
        config_dir.mkdir(parents=True, exist_ok=True)
        config_path = config_dir / f"{task_id}.json"

        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f)

        return str(config_path)

    async def _cleanup_config_file(self, task_id: str):
        """Remove the temporary config file."""
        config_path = Path(self.db_path).parent / 'configs' / f"{task_id}.json"
        if config_path.exists():
            config_path.unlink()
```

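
The `subscribe_progress` generator yields plain dicts, including heartbeat dicts. Before they can be sent over an SSE connection, each one has to be serialized as an SSE frame. The helper below is a minimal sketch of that step; `format_sse` is a hypothetical name, and sending heartbeats as SSE comment lines is one possible choice rather than what the project necessarily does.

```python
import json
from typing import Dict


def format_sse(event: Dict) -> str:
    """Serialize one progress dict as a Server-Sent Events frame."""
    # Heartbeats go out as SSE comments (lines starting with ':'), which
    # clients ignore but which keep the connection alive.
    if event.get("type") == "heartbeat":
        return ": heartbeat\n\n"
    # ensure_ascii=False keeps non-ASCII progress messages readable on the wire
    payload = json.dumps(event, ensure_ascii=False)
    # A frame is one or more "field: value" lines terminated by a blank line
    return f"data: {payload}\n\n"


if __name__ == "__main__":
    print(format_sse({"status": "running", "progress": 10}), end="")
```

An endpoint would iterate over `subscribe_progress(task_id)` and write `format_sse(event)` for each item to a streaming response with the `text/event-stream` media type.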

#### Option 2: ThreadPoolExecutor + SQLite persistence (alternative)

If you prefer not to change how the existing Pipeline is executed, you can keep wrapping the synchronous call in a ThreadPool:

**Pros:**

- No changes to the existing Pipeline code
- Standard library only, very few dependencies
- Simple to implement

**Cons:**

- ThreadPool threads are used only to wait on a blocking subprocess
- Less elegant use of resources
- Not truly asynchronous

> [!NOTE]
> This option uses `concurrent.futures.ThreadPoolExecutor` to wrap synchronous subprocess calls as asynchronous operations.
> It works, but adds thread overhead that asyncio.subprocess avoids.

```python
# Simplified implementation sketch
import uuid
from concurrent.futures import ThreadPoolExecutor

from app.adapters.base import TaskQueueAdapter
# TrainingPipeline is the existing synchronous pipeline class; its import is omitted here


class ThreadPoolAdapter(TaskQueueAdapter):
    def __init__(self):
        self.executor = ThreadPoolExecutor(max_workers=1)

    async def enqueue(self, task_id, config, priority="normal"):
        job_id = str(uuid.uuid4())
        # Run the synchronous pipeline on a worker thread
        self.executor.submit(self._run_sync, task_id, config)
        return job_id

    def _run_sync(self, task_id, config):
        # Run the Pipeline synchronously
        pipeline = TrainingPipeline(config)
        pipeline.run()
```

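
When the result of the blocking call is needed from async code, the worker-thread approach is usually paired with `loop.run_in_executor`, which returns an awaitable. The sketch below illustrates that pattern under stated assumptions: `run_pipeline_sync` is a hypothetical stand-in for a blocking `pipeline.run()` call and only sleeps briefly.

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List


def run_pipeline_sync(task_id: str) -> str:
    """Hypothetical stand-in for a blocking pipeline.run() call."""
    time.sleep(0.1)
    return f"{task_id}: done"


async def run_in_thread(executor: ThreadPoolExecutor, task_id: str) -> str:
    loop = asyncio.get_running_loop()
    # The blocking call runs on a worker thread; awaiting it keeps the event loop free.
    return await loop.run_in_executor(executor, run_pipeline_sync, task_id)


async def demo() -> List[str]:
    # max_workers=1 serializes the two jobs, like the single-worker adapter above.
    with ThreadPoolExecutor(max_workers=1) as executor:
        return await asyncio.gather(
            run_in_thread(executor, "task-a"),
            run_in_thread(executor, "task-b"),
        )


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Each pending job still occupies a thread that does nothing but wait, which is the overhead the note above refers to.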

#### Option 3: Huey (development mode only; not recommended for PyInstaller packaging)

> [!WARNING]
> Huey needs a separate consumer process, so it is **not suitable** for PyInstaller packaging or Electron integration.
> Use it only in plain Python development mode.

```python
# Install with: pip install huey

# Configuration
from huey import SqliteHuey

huey = SqliteHuey('gpt_sovits', filename='./data/tasks.db')


@huey.task()
def execute_training_pipeline(task_id, config):
    # Run the training
    pass
```

**Pros:**

- Lightweight, with a small codebase
- Supports a SQLite backend (persistent)
- Supports task retries and scheduled tasks
- Supports priority queues
- No external services required (no Redis or other broker)

**Cons:**

- Requires a separate huey_consumer process
- Incompatible with PyInstaller single-file packaging
- Fewer features than Celery
- Smaller community

---

### 7.2 Server Mode - Celery [Phase 3]

> **Note**: This section is the server-mode design, which the roadmap in Section 9 schedules as Phase 3. The current stage prioritizes local mode.

```python
# app/workers/celery_worker.py
from celery import Celery

from app.core.config import settings

celery_app = Celery(
    'gpt_sovits',
    broker=settings.CELERY_BROKER_URL,
    backend=settings.CELERY_RESULT_BACKEND
)

celery_app.conf.update(
    task_serializer='json',
    accept_content=['json'],
    result_serializer='json',
    timezone='UTC',
    task_routes={
        'app.workers.celery_worker.execute_training_pipeline': {'queue': 'training'},
        'app.workers.celery_worker.execute_inference': {'queue': 'inference'}
    }
)


@celery_app.task(bind=True, max_retries=3)
def execute_training_pipeline(self, task_id: str, config: dict):
    """Run the training Pipeline (similar to the Huey version)."""
    # Same logic as above
    pass
```

---

## 8. Full Comparison Table

| Dimension | Local development mode (macOS) | PyInstaller/Electron mode | Server mode (Linux) |
|------|---------------------|--------------------------|-------------------|
| **Database** | SQLite (single file) | SQLite (single file) | PostgreSQL (cluster) |
| **Task management** | asyncio.subprocess ⭐ | asyncio.subprocess ⭐ | Celery + Redis |
| **Execution model** | Subprocess (s2_train.py, etc.) | Subprocess (s2_train.py, etc.) | Distributed workers |
| **File storage** | Local filesystem | Local filesystem | MinIO/S3 |
| **Progress reporting** | stdout parsing + asyncio.Queue | stdout parsing + asyncio.Queue | Redis Pub/Sub |
| **Concurrency** | 1-2 tasks | 1 task (serial) | Scales horizontally |
| **Service dependencies** | 0 (all-in-one) | 0 (all-in-one) | 3+ (PostgreSQL, Redis, MinIO) |
| **Start command** | `uvicorn app.main:app` | Electron launches a Python subprocess | `docker-compose up` |
| **Use case** | Development and debugging | Desktop application distribution | Production, multi-user |
| **Deployment complexity** | ⭐ | ⭐⭐ | ⭐⭐⭐⭐ |
| **Packaging** | Not needed | PyInstaller single file | Docker image |
| **Maintenance cost** | Low | Low | Medium |

---

## 9. Recommended Implementation Path

### Phase 1: Local Mode MVP

#### 1.1 Architecture Design and Schema Definition ✅ Done

| Task | Status | Notes |
|------|------|------|
| API architecture design | ✅ Done | Dual-mode design (Quick Mode + Advanced Mode) |
| Pydantic Schema design | ✅ Done | Fully defined in development.md |
| Database Schema design | ✅ Done | tasks, experiments, stages table structures |
| Stage parameter Schema design | ✅ Done | AudioSliceParams, SoVITSTrainParams, etc. |

#### 1.2 Core Infrastructure ✅ Done

| Task | Status | Implementation file |
|------|------|----------|
| Adapter abstract base classes | ✅ Done | `app/adapters/base.py` - TaskQueueAdapter, ProgressAdapter |
| AsyncTrainingManager | ✅ Done | `app/adapters/local/task_queue.py` - full implementation |
| Configuration module | ✅ Done | `app/core/config.py` - Settings, path constants |
| Domain models | ✅ Done | `app/models/domain.py` - Task, TaskStatus, ProgressInfo |
| Pipeline wrapper script | ✅ Done | `app/scripts/run_pipeline.py` - subprocess runner |

**AsyncTrainingManager implemented features:**

- ✅ Task enqueueing and asynchronous execution (`enqueue`)
- ✅ Subprocess management (`asyncio.create_subprocess_exec`)
- ✅ Progress parsing and publishing (`_monitor_process_output`)
- ✅ Task status queries (`get_status`, `get_status_by_task_id`)
- ✅ Task cancellation (`cancel`)
- ✅ Progress subscription as an SSE stream (`subscribe_progress`)
- ✅ Task listing (`list_tasks`)
- ✅ Task recovery (`recover_pending_tasks`)
- ✅ Cleanup of old tasks (`cleanup_old_tasks`)

#### 1.3 Pydantic Schema Files ✅ Done

| Task | Status | Notes |
|------|------|------|
| `app/models/schemas/common.py` | ✅ Done | SuccessResponse, ErrorResponse, PaginatedResponse |
| `app/models/schemas/task.py` | ✅ Done | QuickModeOptions, QuickModeRequest, TaskResponse, TaskListResponse |
| `app/models/schemas/experiment.py` | ✅ Done | ExperimentCreate, StageStatus, per-stage parameter classes, etc. |
| `app/models/schemas/file.py` | ✅ Done | FileMetadata, FileUploadResponse, FileListResponse |

#### 1.4 Storage and Database Adapters ✅ Done

| Task | Status | Notes |
|------|------|------|
| StorageAdapter abstract class | ✅ Done | `app/adapters/base.py` - file storage interface |
| DatabaseAdapter abstract class | ✅ Done | `app/adapters/base.py` - database operations interface |
| LocalStorageAdapter | ✅ Done | `app/adapters/local/storage.py` - local filesystem storage |
| SQLiteAdapter | ✅ Done | `app/adapters/local/database.py` - SQLite database adapter |
| LocalProgressAdapter | ✅ Done | `app/adapters/local/progress.py` - in-memory progress management |

**LocalStorageAdapter implemented features:**

- ✅ File upload/download (`upload_file`, `download_file`)
- ✅ File deletion (`delete_file`)
- ✅ Metadata management (`.meta.json` files)
- ✅ File listing (`list_files`)
- ✅ Audio information extraction (duration, sample rate)

**SQLiteAdapter implemented features:**

- ✅ Task CRUD (Quick Mode)
- ✅ Experiment CRUD (Advanced Mode)
- ✅ Stage status management
- ✅ File record management
- ✅ Automatic table initialization

**LocalProgressAdapter implemented features:**

- ✅ Progress update and storage (`update_progress`)
- ✅ Subscriber pattern (`subscribe`)
- ✅ Multiple subscribers
- ✅ Heartbeat mechanism

#### 1.5 API Endpoints ✅ Done

| Task | Status | Notes |
|------|------|------|
| Quick Mode API (`/tasks`) | ✅ Implemented | `app/api/v1/endpoints/tasks.py` |
| Advanced Mode API (`/experiments`) | ✅ Implemented | `app/api/v1/endpoints/experiments.py` |
| File management API (`/files`) | ✅ Implemented | `app/api/v1/endpoints/files.py` |
| Stage template API (`/stages`) | ✅ Implemented | `app/api/v1/endpoints/stages.py` |
| Route registration | ✅ Implemented | `app/api/v1/router.py` |
| FastAPI entry point | ✅ Implemented | `app/main.py` |
| Adapter factory | ✅ Implemented | `app/core/adapters.py` |
| Dependency injection | ✅ Implemented | `app/api/deps.py` |

**API endpoint implemented features:**

- ✅ Quick Mode: create task, list tasks, task details, cancel task, SSE progress subscription
- ✅ Advanced Mode: create experiment, list experiments, experiment details, update/delete experiment, run stage, stage status, cancel stage, SSE stage progress
- ✅ File management: upload, list, download, and delete files
- ✅ Stage templates: preset list, stage parameter templates

#### 1.6 Service Layer ✅ Done

| Task | Status | Notes |
|------|------|------|
| TaskService | ✅ Implemented | `app/services/task_service.py` |
| ExperimentService | ✅ Implemented | `app/services/experiment_service.py` |
| FileService | ✅ Implemented | `app/services/file_service.py` |

**Service layer implemented features:**

- ✅ TaskService: create one-click training tasks, quality preset configuration, task status management, progress subscription
- ✅ ExperimentService: experiment CRUD, stage dependency checks, stage execution/cancellation, progress subscription
- ✅ FileService: file upload/download, metadata management, audio information extraction

#### 1.7 Testing and Validation

| Task | Status | Notes |
|------|------|------|
| Quick Mode end-to-end test | 🔲 Not started | Upload audio → training completes |
| Advanced Mode stage-by-stage test | 🔲 Not started | Run each stage, then re-run |
| Task cancellation/recovery test | 🔲 Not started | Verify task lifecycle management |

---

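### Phase 2: Electron Integration Preparation

| Task | Status | Notes |
|------|------|------|
| Task persistence and recovery | 🔲 Not started | Restore task state after an application restart |
| PyInstaller packaging configuration | 🔲 Not started | .spec file configuration |
| Electron process management module | 🔲 Not started | spawn/kill the Python process |
| IPC communication layer | 🔲 Not started | HTTP API or WebSocket |
| macOS signing and notarization | 🔲 Not started | Optional, for distribution |

---

### Phase 3: Server Mode

| Task | Status | Notes |
|------|------|------|
| PostgreSQL adapter | 🔲 Not started | SQLAlchemy + Alembic |
| Celery task queue adapter | 🔲 Not started | Distributed task execution |
| S3/MinIO storage adapter | 🔲 Not started | Object storage |
| Redis progress adapter | 🔲 Not started | Pub/Sub progress publishing |
| Authentication and authorization | 🔲 Not started | JWT / API Key |
| Monitoring and alerting | 🔲 Not started | Prometheus + Grafana |
| Docker deployment configuration | 🔲 Not started | docker-compose.yml |

---

### Phase 4: Enhancements

| Task | Status | Notes |
|------|------|------|
| Model version management | 🔲 Not started | Store and switch between multiple model versions |
| Batch inference | 🔲 Not started | Batch TTS generation |
| Scheduled tasks | 🔲 Not started | Planned training jobs |
| Webhook notifications | 🔲 Not started | Callback when training completes |
| Training dataset management | 🔲 Not started | Dataset version control |

---
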
## 10. Key Code Examples

### 10.1 Entry Point (automatic mode detection)

```python
# app/main.py
import os

from fastapi import FastAPI

from app.core.config import settings
from app.api.v1.router import api_router

app = FastAPI(title=settings.PROJECT_NAME)


# Newer FastAPI versions deprecate on_event in favor of a lifespan handler
@app.on_event("startup")
async def startup_event():
    print(f"Starting in {settings.DEPLOYMENT_MODE.upper()} mode")

    if settings.DEPLOYMENT_MODE == "local":
        print("Using SQLite + asyncio.subprocess + Local FileSystem")
        # No separate worker process is needed: AsyncTrainingManager runs in this process
    else:
        print("Using PostgreSQL + Celery + MinIO")
        # Initialize the database connection pool
        # Warm up the Redis connection


app.include_router(api_router, prefix=settings.API_V1_PREFIX)


@app.get("/health")
async def health():
    # Assumed endpoint: the Electron health check in Section 11.2 polls /health
    return {"status": "ok"}


if __name__ == "__main__":
    import uvicorn

    # API_PORT is set by the Electron process manager (Section 11.2); default to 8000
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", "8000")))
```

### 10.2 Environment Variables

**.env.local**:

```
DEPLOYMENT_MODE=local
LOCAL_STORAGE_PATH=./data/files
SQLITE_PATH=./data/app.db
LOCAL_MAX_WORKERS=1
```

**.env.server**:

```
DEPLOYMENT_MODE=server
DATABASE_URL=postgresql+asyncpg://user:pass@localhost/gpt_sovits
REDIS_URL=redis://localhost:6379/0
CELERY_BROKER_URL=redis://localhost:6379/1
CELERY_RESULT_BACKEND=redis://localhost:6379/2
S3_ENDPOINT=localhost:9000
S3_ACCESS_KEY=minioadmin
S3_SECRET_KEY=minioadmin
```
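
To illustrate how the local variables above might be read and validated, here is a minimal standard-library sketch. The real `Settings` class in `app/core/config.py` is not shown in this document, so `LocalSettings` and `load_local_settings` are hypothetical names, and the defaults simply mirror the `.env.local` values above.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class LocalSettings:
    deployment_mode: str
    local_storage_path: str
    sqlite_path: str
    local_max_workers: int


def load_local_settings(env: Mapping[str, str] = os.environ) -> LocalSettings:
    """Build settings from environment variables, falling back to local defaults."""
    mode = env.get("DEPLOYMENT_MODE", "local")
    if mode not in ("local", "server"):
        raise ValueError(f"Unsupported DEPLOYMENT_MODE: {mode!r}")
    return LocalSettings(
        deployment_mode=mode,
        local_storage_path=env.get("LOCAL_STORAGE_PATH", "./data/files"),
        sqlite_path=env.get("SQLITE_PATH", "./data/app.db"),
        # Environment values are strings, so numeric settings need conversion
        local_max_workers=int(env.get("LOCAL_MAX_WORKERS", "1")),
    )


if __name__ == "__main__":
    print(load_local_settings())
```

Passing the environment as a parameter keeps the loader easy to test without mutating `os.environ`.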

---

## 11. Electron Integration Guide

### 11.1 Architecture Overview

```
┌─────────────────────────────────────────────────────────────┐
│                    Electron Main Process                    │
│  ┌─────────────────┐     ┌──────────────────────────────┐   │
│  │ Process Manager │────▶│ Python (PyInstaller Bundle)  │   │
│  └────────┬────────┘     │ ┌──────────────────────────┐ │   │
│           │              │ │ FastAPI HTTP Server      │ │   │
│           │              │ │ + asyncio.subprocess     │ │   │
│           │              │ │ + SQLite Database        │ │   │
│           │              │ └──────────────────────────┘ │   │
│           │              └──────────────┬───────────────┘   │
│           │                             │                   │
│  ┌────────▼─────────────────────────────▼────────────────┐  │
│  │             Renderer Process (Vue/React)              │  │
│  │         HTTP API / SSE Progress Subscription          │  │
│  └───────────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────────┘
```

### 11.2 Python Process Management (Electron side)

```javascript
// electron/python-manager.js
const { app } = require('electron');
const { spawn } = require('child_process');
const path = require('path');
const http = require('http');

class PythonProcessManager {
  constructor() {
    this.pythonProcess = null;
    this.apiPort = 8765;
    this.isReady = false;
  }

  /**
   * Start the Python backend process and wait until its API responds.
   */
  async start() {
    const { command, args } = this.getPythonCommand();

    this.pythonProcess = spawn(command, args, {
      env: {
        ...process.env,
        DEPLOYMENT_MODE: 'local',
        API_PORT: this.apiPort.toString(),
        // Store data under Electron's userData directory
        DATA_PATH: path.join(app.getPath('userData'), 'training-data')
      },
      stdio: ['pipe', 'pipe', 'pipe']
    });

    this.pythonProcess.stdout.on('data', (data) => {
      console.log(`[Python] ${data}`);
    });

    // Uvicorn writes its startup log to stderr by default, so readiness is
    // detected with the health check below rather than by matching log text
    this.pythonProcess.stderr.on('data', (data) => {
      console.error(`[Python stderr] ${data}`);
    });

    this.pythonProcess.on('close', (code) => {
      console.log(`Python process exited with code ${code}`);
      this.isReady = false;
    });

    // Poll /health once per second for up to 30 seconds
    const ready = await this.waitForReady(30);
    if (!ready) {
      this.stop();
      throw new Error('Python server startup timeout');
    }
    this.isReady = true;
  }

  /**
   * Return the command and arguments used to launch the backend.
   */
  getPythonCommand() {
    if (process.env.NODE_ENV === 'development') {
      // Development mode uses the system Python; this assumes the working
      // directory is the project root so that app.main is importable
      return { command: 'python', args: ['-m', 'app.main'] };
    }

    // Production mode uses the PyInstaller-built executable
    const resourcesPath = process.resourcesPath;
    const binary = process.platform === 'win32' ? 'gpt-sovits-api.exe' : 'gpt-sovits-api';
    return { command: path.join(resourcesPath, 'python', binary), args: [] };
  }

  /**
   * Wait for the API service to become ready.
   */
  async waitForReady(maxRetries = 30) {
    for (let i = 0; i < maxRetries; i++) {
      try {
        await this.healthCheck();
        return true;
      } catch {
        await new Promise(r => setTimeout(r, 1000));
      }
    }
    return false;
  }

  /**
   * Health check.
   */
  healthCheck() {
    return new Promise((resolve, reject) => {
      http.get(`http://localhost:${this.apiPort}/health`, (res) => {
        res.resume(); // discard the response body
        if (res.statusCode === 200) resolve();
        else reject(new Error(`Unexpected status ${res.statusCode}`));
      }).on('error', reject);
    });
  }

  /**
   * Stop the Python process.
   */
  stop() {
    if (this.pythonProcess) {
      this.pythonProcess.kill('SIGTERM');
      this.pythonProcess = null;
      this.isReady = false;
    }
  }

  /**
   * Return the API base URL.
   */
  getApiBaseUrl() {
    return `http://localhost:${this.apiPort}`;
  }
}

module.exports = PythonProcessManager;
```

### 11.3 PyInstaller Packaging Configuration

```python
# gpt-sovits-api.spec
# -*- mode: python ; coding: utf-8 -*-

# Note: PyInstaller 6 removed bytecode encryption; on that version the
# cipher-related arguments below can be dropped.
block_cipher = None

a = Analysis(
    ['app/main.py'],
    pathex=[],
    binaries=[],
    datas=[
        # Include the pretrained models
        ('pretrained_models', 'pretrained_models'),
        # Include the configuration files
        ('config', 'config'),
    ],
    hiddenimports=[
        'uvicorn.logging',
        'uvicorn.loops',
        'uvicorn.loops.auto',
        'uvicorn.protocols',
        'uvicorn.protocols.http',
        'uvicorn.protocols.http.auto',
        'uvicorn.protocols.websockets',
        'uvicorn.protocols.websockets.auto',
        'uvicorn.lifespan',
        'uvicorn.lifespan.on',
        'aiosqlite',
        'torch',
        'torchaudio',
        # Add every other implicit import that is needed
    ],
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[
        'tkinter',
        'matplotlib',
        'IPython',
        'jupyter',
    ],
    win_no_prefer_redirects=False,
    win_private_assemblies=False,
    cipher=block_cipher,
    noarchive=False,
)

pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

exe = EXE(
    pyz,
    a.scripts,
    a.binaries,
    a.zipfiles,
    a.datas,
    [],
    name='gpt-sovits-api',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=True,
    upx_exclude=[],
    runtime_tmpdir=None,
    console=True,  # Set to False to hide the console window
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
)
```

### 11.4 Adapter Factory Update (Electron mode support)

```python
# app/core/adapters.py
import os

from app.core.config import settings


class AdapterFactory:
    @staticmethod
    def create_task_queue_adapter():
        # Local mode (including PyInstaller/Electron) uses the asyncio.subprocess manager
        if settings.DEPLOYMENT_MODE == "local":
            from app.adapters.local.task_queue import AsyncTrainingManager

            # Resolve the data path from the environment (Electron sets DATA_PATH)
            data_path = os.environ.get('DATA_PATH', './data')
            db_path = os.path.join(data_path, 'tasks.db')

            return AsyncTrainingManager(db_path=db_path)
        else:
            from app.adapters.server.task_queue import CeleryTaskQueueAdapter
            return CeleryTaskQueueAdapter(
                broker_url=settings.CELERY_BROKER_URL,
                backend_url=settings.CELERY_RESULT_BACKEND
            )
```

### 11.5 Packaging and Distribution Checklist

```markdown
## macOS packaging checklist

- [ ] Sign the Python executable (if distributing outside the App Store)
- [ ] Handle Gatekeeper (first launch requires right-click > Open)
- [ ] Test startup on a clean system
- [ ] Verify that model files are bundled correctly
- [ ] Test the task recovery mechanism
- [ ] Verify that the SSE progress stream works
- [ ] Test that the Python process is cleaned up when Electron exits

## Directory layout

YourApp.app/
└── Contents/
    ├── MacOS/
    │   └── YourApp                 # Electron main executable
    ├── Resources/
    │   ├── python/
    │   │   └── gpt-sovits-api      # PyInstaller-built Python backend
    │   ├── pretrained_models/      # Pretrained models
    │   └── ...
    └── Info.plist
```
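
The checklist item about cleaning up when Electron exits also applies one level down: when the Python backend shuts down, any training subprocesses it started should be terminated too. The sketch below shows one way this could be done at shutdown. `terminate_all` is a hypothetical helper, and the sleeping child is a stand-in for a real training process.

```python
import asyncio
import sys
from typing import Dict


async def terminate_all(
    processes: Dict[str, asyncio.subprocess.Process], timeout: float = 5.0
) -> None:
    """Terminate every tracked child, escalating to kill() after `timeout` seconds."""
    # Signal all children first so they shut down in parallel
    for process in processes.values():
        if process.returncode is None:
            process.terminate()
    for process in processes.values():
        try:
            await asyncio.wait_for(process.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            process.kill()
            await process.wait()  # reap the killed process
    processes.clear()


async def demo() -> bool:
    # Hypothetical long-running child standing in for a training process
    child = await asyncio.create_subprocess_exec(
        sys.executable, "-c", "import time; time.sleep(60)"
    )
    running = {"task-1": child}
    await terminate_all(running)
    return child.returncode is not None and not running


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

In the architecture described here, a function like this would be called with the manager's `running_processes` mapping from the application's shutdown hook.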
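
---

## Summary

Core ideas of this architecture:

1. **Unified interface**: the API layer and the business logic layer are identical across deployment modes
2. **Adapter pattern**: storage, queue, and cache backends are swapped through adapters
3. **Configuration-driven**: the deployment mode is controlled by environment variables
4. **Incremental**: build the local version first (fast validation), then extend to the server version
5. **Zero-dependency local deployment**: local mode needs no Docker, Redis, or PostgreSQL
6. **Subprocess execution model**: training runs in a subprocess, and the main process only manages it
7. **asyncio.subprocess recommended**: fully non-blocking, which fits FastAPI's async model

**Recommended starting point:**

- **All local scenarios**: use `asyncio.subprocess` + SQLite (`AsyncTrainingManager`)
- **Electron desktop application**: same as above, compatible with PyInstaller packaging
- **Server production environment**: use Celery + Redis for a distributed task queue

> [!TIP]
> Key insight: because the training Pipeline already invokes standalone Python scripts through subprocess,
> `asyncio.create_subprocess_exec()` is the most natural choice,
> and it avoids the extra complexity of a ThreadPool.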