""" Quick Mode 任务 API 小白用户一键训练 API 端点 API 列表: - POST /tasks 创建一键训练任务 - GET /tasks 获取任务列表 - GET /tasks/{task_id} 获取任务详情 - DELETE /tasks/{task_id} 取消任务 - GET /tasks/{task_id}/progress SSE 进度订阅 - GET /tasks/{task_id}/outputs 获取推理输出列表 - GET /tasks/{task_id}/outputs/{file_type}/{filename} 下载任务相关文件 """ import json from typing import Literal, Optional from fastapi import APIRouter, Depends, HTTPException, Path, Query from fastapi.responses import StreamingResponse, Response from ....models.schemas.task import ( QuickModeRequest, TaskResponse, TaskListResponse, InferenceOutputsResponse, ) from ....models.schemas.common import SuccessResponse, ErrorResponse from ....services.task_service import TaskService from ...deps import get_task_service router = APIRouter() @router.post( "", response_model=TaskResponse, summary="创建一键训练任务", description=""" 创建一键训练任务(小白用户)。 上传音频文件后,系统自动配置所有参数并执行完整训练流程: `audio_slice -> asr -> text_feature -> hubert_feature -> semantic_token -> sovits_train -> gpt_train -> inference` 训练完成后会自动进行推理测试,生成测试音频文件。可通过 `options.inference` 配置推理参数或禁用推理阶段。 **质量预设**: - `fast`: SoVITS 4 epochs, GPT 8 epochs, 约10分钟 - `standard`: SoVITS 8 epochs, GPT 15 epochs, 约20分钟 - `high`: SoVITS 16 epochs, GPT 30 epochs, 约40分钟 """, responses={ 200: {"model": TaskResponse, "description": "任务创建成功"}, 400: {"model": ErrorResponse, "description": "请求参数错误"}, 404: {"model": ErrorResponse, "description": "音频文件不存在"}, 409: {"model": ErrorResponse, "description": "实验名称已存在"}, }, ) async def create_task( request: QuickModeRequest, service: TaskService = Depends(get_task_service), ) -> TaskResponse: """ 创建一键训练任务 """ # 验证 exp_name 是否已存在 if await service.check_exp_name_exists(request.exp_name): raise HTTPException( status_code=409, detail=f"实验名称 '{request.exp_name}' 已存在,请使用不同的名称" ) # 验证音频文件是否存在 file_exists, audio_path = await service.validate_audio_file(request.audio_file_id) if not file_exists: raise HTTPException( status_code=404, detail=f"音频文件不存在: {request.audio_file_id}" ) return await service.create_quick_task(request) @router.get( "", response_model=TaskListResponse, summary="获取任务列表", description="获取所有训练任务列表,支持按状态筛选和分页。", ) async def list_tasks( status: Optional[str] = Query( None, description="按状态筛选: queued, running, completed, failed, cancelled, interrupted" ), limit: int = Query(50, ge=1, le=100, description="每页数量"), offset: int = Query(0, ge=0, description="偏移量"), service: TaskService = Depends(get_task_service), ) -> TaskListResponse: """ 获取任务列表 """ return await service.list_tasks(status=status, limit=limit, offset=offset) @router.get( "/{task_id}", response_model=TaskResponse, summary="获取任务详情", description="获取指定任务的详细状态信息。", responses={ 200: {"model": TaskResponse, "description": "任务详情"}, 404: {"model": ErrorResponse, "description": "任务不存在"}, }, ) async def get_task( task_id: str, service: TaskService = Depends(get_task_service), ) -> TaskResponse: """ 获取任务详情 """ task = await service.get_task(task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") return task @router.delete( "/{task_id}", response_model=SuccessResponse, summary="取消任务", description="取消排队中或运行中的任务。已完成、失败或已取消的任务无法取消。", responses={ 200: {"model": SuccessResponse, "description": "任务取消成功"}, 400: {"model": ErrorResponse, "description": "任务无法取消"}, 404: {"model": ErrorResponse, "description": "任务不存在"}, }, ) async def cancel_task( task_id: str, service: TaskService = Depends(get_task_service), ) -> SuccessResponse: """ 取消任务 """ # 先检查任务是否存在 task = await service.get_task(task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") success = await service.cancel_task(task_id) if not success: raise HTTPException(status_code=400, detail="任务无法取消(可能已完成或已取消)") return SuccessResponse(message="任务已取消") @router.get( "/{task_id}/progress", summary="SSE 进度订阅", description=""" 订阅任务进度更新(Server-Sent Events)。 返回的事件流格式: ``` event: progress data: {"stage": "sovits_train", "progress": 0.45, "message": "Epoch 8/16"} event: progress data: {"stage": "sovits_train", "progress": 0.50, "message": "Epoch 9/16"} event: completed data: {"status": "completed", "message": "训练完成"} ``` 可能的事件类型: - `progress`: 进度更新 - `log`: 日志消息 - `heartbeat`: 心跳(保持连接) - `completed`: 任务完成 - `failed`: 任务失败 - `cancelled`: 任务取消 """, responses={ 200: {"description": "SSE 事件流"}, 404: {"model": ErrorResponse, "description": "任务不存在"}, }, ) async def subscribe_progress( task_id: str, service: TaskService = Depends(get_task_service), ) -> StreamingResponse: """ SSE 进度订阅 """ # 先检查任务是否存在 task = await service.get_task(task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") async def event_generator(): """生成 SSE 事件流""" async for progress in service.subscribe_progress(task_id): # 确定事件类型 event_type = progress.get("type", "progress") status = progress.get("status") if status == "completed": event_type = "completed" elif status == "failed": event_type = "failed" elif status == "cancelled": event_type = "cancelled" elif event_type == "heartbeat": event_type = "heartbeat" # 构建 SSE 格式 data = json.dumps(progress, ensure_ascii=False) yield f"event: {event_type}\ndata: {data}\n\n" # 如果是终态,结束流 if status in ("completed", "failed", "cancelled"): break return StreamingResponse( event_generator(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", # Nginx 禁用缓冲 }, ) @router.get( "/{task_id}/outputs", response_model=InferenceOutputsResponse, summary="获取推理输出列表", description=""" 获取任务的推理输出文件列表及推理配置信息。 训练任务完成后,推理阶段会生成测试音频文件。此端点返回所有生成的音频文件元信息, 包括文件名、使用的模型路径、文件大小等,以及推理使用的参考音频和文本信息。 **推理配置**: - `ref_text`: 参考音频的文本内容 - `ref_audio_path`: 参考音频文件路径 - `target_text`: 合成的目标文本 **输出文件信息**: - `filename`: 文件名 - `gpt_model`: 使用的 GPT 模型名称 - `sovits_model`: 使用的 SoVITS 模型名称 - `gpt_path`: GPT 模型完整路径 - `sovits_path`: SoVITS 模型完整路径 - `file_path`: 输出文件相对路径 - `size_bytes`: 文件大小(字节) - `created_at`: 创建时间 **下载文件**: 使用 `/tasks/{task_id}/outputs/{file_type}/{filename}` 端点下载相关文件。 """, responses={ 200: {"model": InferenceOutputsResponse, "description": "推理输出列表"}, 404: {"model": ErrorResponse, "description": "任务不存在"}, }, ) async def get_task_outputs( task_id: str, service: TaskService = Depends(get_task_service), ) -> InferenceOutputsResponse: """ 获取任务的推理输出列表 """ result = await service.get_inference_outputs(task_id) if result is None: raise HTTPException(status_code=404, detail="任务不存在") return result # 文件类型定义 FileType = Literal["output", "ref_audio", "gpt_model", "sovits_model"] @router.get( "/{task_id}/outputs/{file_type}/{filename:path}", summary="下载任务相关文件", description=""" 下载任务相关的各类文件。 **文件类型 (file_type)**: - `output` - 推理输出音频文件 (.wav) - `ref_audio` - 参考音频文件 (.wav) - `gpt_model` - GPT 模型文件 (.ckpt) - `sovits_model` - SoVITS 模型文件 (.pth) **文件名来源**: - `output`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].filename` 获取 - `ref_audio`: 从 `/tasks/{task_id}/outputs` 端点的 `ref_audio_path` 获取 - `gpt_model`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].gpt_path` 获取文件名部分 - `sovits_model`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].sovits_path` 获取文件名部分 **返回**: - 音频文件: Content-Type: audio/wav - 模型文件: Content-Type: application/octet-stream """, responses={ 200: {"description": "文件内容"}, 404: {"model": ErrorResponse, "description": "任务或文件不存在"}, }, ) async def download_task_file( task_id: str, file_type: FileType = Path(..., description="文件类型: output/ref_audio/gpt_model/sovits_model"), filename: str = Path(..., description="文件名或路径"), service: TaskService = Depends(get_task_service), ) -> Response: """ 下载任务相关文件(推理输出、参考音频、模型文件) """ result = await service.download_file(task_id, file_type, filename) if result is None: raise HTTPException(status_code=404, detail="任务或文件不存在") file_data, file_name, content_type = result return Response( content=file_data, media_type=content_type, headers={ "Content-Disposition": f'attachment; filename="{file_name}"', "Content-Length": str(len(file_data)), }, )