|
|
""" |
|
|
Quick Mode 任务 API |
|
|
|
|
|
小白用户一键训练 API 端点 |
|
|
|
|
|
API 列表: |
|
|
- POST /tasks 创建一键训练任务 |
|
|
- GET /tasks 获取任务列表 |
|
|
- GET /tasks/{task_id} 获取任务详情 |
|
|
- DELETE /tasks/{task_id} 取消任务 |
|
|
- GET /tasks/{task_id}/progress SSE 进度订阅 |
|
|
- GET /tasks/{task_id}/outputs 获取推理输出列表 |
|
|
- GET /tasks/{task_id}/outputs/{file_type}/{filename} 下载任务相关文件 |
|
|
""" |
|
|
|
|
|
import json |
|
|
from typing import Literal, Optional |
|
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Path, Query |
|
|
from fastapi.responses import StreamingResponse, Response |
|
|
|
|
|
from ....models.schemas.task import ( |
|
|
QuickModeRequest, |
|
|
TaskResponse, |
|
|
TaskListResponse, |
|
|
InferenceOutputsResponse, |
|
|
) |
|
|
from ....models.schemas.common import SuccessResponse, ErrorResponse |
|
|
from ....services.task_service import TaskService |
|
|
from ...deps import get_task_service |
|
|
|
|
|
router = APIRouter() |
|
|
|
|
|
|
|
|
@router.post( |
|
|
"", |
|
|
response_model=TaskResponse, |
|
|
summary="创建一键训练任务", |
|
|
description=""" |
|
|
创建一键训练任务(小白用户)。 |
|
|
|
|
|
上传音频文件后,系统自动配置所有参数并执行完整训练流程: |
|
|
`audio_slice -> asr -> text_feature -> hubert_feature -> semantic_token -> sovits_train -> gpt_train -> inference` |
|
|
|
|
|
训练完成后会自动进行推理测试,生成测试音频文件。可通过 `options.inference` 配置推理参数或禁用推理阶段。 |
|
|
|
|
|
**质量预设**: |
|
|
- `fast`: SoVITS 4 epochs, GPT 8 epochs, 约10分钟 |
|
|
- `standard`: SoVITS 8 epochs, GPT 15 epochs, 约20分钟 |
|
|
- `high`: SoVITS 16 epochs, GPT 30 epochs, 约40分钟 |
|
|
""", |
|
|
responses={ |
|
|
200: {"model": TaskResponse, "description": "任务创建成功"}, |
|
|
400: {"model": ErrorResponse, "description": "请求参数错误"}, |
|
|
404: {"model": ErrorResponse, "description": "音频文件不存在"}, |
|
|
409: {"model": ErrorResponse, "description": "实验名称已存在"}, |
|
|
}, |
|
|
) |
|
|
async def create_task( |
|
|
request: QuickModeRequest, |
|
|
service: TaskService = Depends(get_task_service), |
|
|
) -> TaskResponse: |
|
|
""" |
|
|
创建一键训练任务 |
|
|
""" |
|
|
|
|
|
if await service.check_exp_name_exists(request.exp_name): |
|
|
raise HTTPException( |
|
|
status_code=409, |
|
|
detail=f"实验名称 '{request.exp_name}' 已存在,请使用不同的名称" |
|
|
) |
|
|
|
|
|
|
|
|
file_exists, audio_path = await service.validate_audio_file(request.audio_file_id) |
|
|
if not file_exists: |
|
|
raise HTTPException( |
|
|
status_code=404, |
|
|
detail=f"音频文件不存在: {request.audio_file_id}" |
|
|
) |
|
|
|
|
|
return await service.create_quick_task(request) |
|
|
|
|
|
|
|
|
@router.get( |
|
|
"", |
|
|
response_model=TaskListResponse, |
|
|
summary="获取任务列表", |
|
|
description="获取所有训练任务列表,支持按状态筛选和分页。", |
|
|
) |
|
|
async def list_tasks( |
|
|
status: Optional[str] = Query( |
|
|
None, |
|
|
description="按状态筛选: queued, running, completed, failed, cancelled, interrupted" |
|
|
), |
|
|
limit: int = Query(50, ge=1, le=100, description="每页数量"), |
|
|
offset: int = Query(0, ge=0, description="偏移量"), |
|
|
service: TaskService = Depends(get_task_service), |
|
|
) -> TaskListResponse: |
|
|
""" |
|
|
获取任务列表 |
|
|
""" |
|
|
return await service.list_tasks(status=status, limit=limit, offset=offset) |
|
|
|
|
|
|
|
|
@router.get( |
|
|
"/{task_id}", |
|
|
response_model=TaskResponse, |
|
|
summary="获取任务详情", |
|
|
description="获取指定任务的详细状态信息。", |
|
|
responses={ |
|
|
200: {"model": TaskResponse, "description": "任务详情"}, |
|
|
404: {"model": ErrorResponse, "description": "任务不存在"}, |
|
|
}, |
|
|
) |
|
|
async def get_task( |
|
|
task_id: str, |
|
|
service: TaskService = Depends(get_task_service), |
|
|
) -> TaskResponse: |
|
|
""" |
|
|
获取任务详情 |
|
|
""" |
|
|
task = await service.get_task(task_id) |
|
|
if not task: |
|
|
raise HTTPException(status_code=404, detail="任务不存在") |
|
|
return task |
|
|
|
|
|
|
|
|
@router.delete( |
|
|
"/{task_id}", |
|
|
response_model=SuccessResponse, |
|
|
summary="取消任务", |
|
|
description="取消排队中或运行中的任务。已完成、失败或已取消的任务无法取消。", |
|
|
responses={ |
|
|
200: {"model": SuccessResponse, "description": "任务取消成功"}, |
|
|
400: {"model": ErrorResponse, "description": "任务无法取消"}, |
|
|
404: {"model": ErrorResponse, "description": "任务不存在"}, |
|
|
}, |
|
|
) |
|
|
async def cancel_task( |
|
|
task_id: str, |
|
|
service: TaskService = Depends(get_task_service), |
|
|
) -> SuccessResponse: |
|
|
""" |
|
|
取消任务 |
|
|
""" |
|
|
|
|
|
task = await service.get_task(task_id) |
|
|
if not task: |
|
|
raise HTTPException(status_code=404, detail="任务不存在") |
|
|
|
|
|
success = await service.cancel_task(task_id) |
|
|
if not success: |
|
|
raise HTTPException(status_code=400, detail="任务无法取消(可能已完成或已取消)") |
|
|
|
|
|
return SuccessResponse(message="任务已取消") |
|
|
|
|
|
|
|
|
@router.get( |
|
|
"/{task_id}/progress", |
|
|
summary="SSE 进度订阅", |
|
|
description=""" |
|
|
订阅任务进度更新(Server-Sent Events)。 |
|
|
|
|
|
返回的事件流格式: |
|
|
``` |
|
|
event: progress |
|
|
data: {"stage": "sovits_train", "progress": 0.45, "message": "Epoch 8/16"} |
|
|
|
|
|
event: progress |
|
|
data: {"stage": "sovits_train", "progress": 0.50, "message": "Epoch 9/16"} |
|
|
|
|
|
event: completed |
|
|
data: {"status": "completed", "message": "训练完成"} |
|
|
``` |
|
|
|
|
|
可能的事件类型: |
|
|
- `progress`: 进度更新 |
|
|
- `log`: 日志消息 |
|
|
- `heartbeat`: 心跳(保持连接) |
|
|
- `completed`: 任务完成 |
|
|
- `failed`: 任务失败 |
|
|
- `cancelled`: 任务取消 |
|
|
""", |
|
|
responses={ |
|
|
200: {"description": "SSE 事件流"}, |
|
|
404: {"model": ErrorResponse, "description": "任务不存在"}, |
|
|
}, |
|
|
) |
|
|
async def subscribe_progress( |
|
|
task_id: str, |
|
|
service: TaskService = Depends(get_task_service), |
|
|
) -> StreamingResponse: |
|
|
""" |
|
|
SSE 进度订阅 |
|
|
""" |
|
|
|
|
|
task = await service.get_task(task_id) |
|
|
if not task: |
|
|
raise HTTPException(status_code=404, detail="任务不存在") |
|
|
|
|
|
async def event_generator(): |
|
|
"""生成 SSE 事件流""" |
|
|
async for progress in service.subscribe_progress(task_id): |
|
|
|
|
|
event_type = progress.get("type", "progress") |
|
|
status = progress.get("status") |
|
|
|
|
|
if status == "completed": |
|
|
event_type = "completed" |
|
|
elif status == "failed": |
|
|
event_type = "failed" |
|
|
elif status == "cancelled": |
|
|
event_type = "cancelled" |
|
|
elif event_type == "heartbeat": |
|
|
event_type = "heartbeat" |
|
|
|
|
|
|
|
|
data = json.dumps(progress, ensure_ascii=False) |
|
|
yield f"event: {event_type}\ndata: {data}\n\n" |
|
|
|
|
|
|
|
|
if status in ("completed", "failed", "cancelled"): |
|
|
break |
|
|
|
|
|
return StreamingResponse( |
|
|
event_generator(), |
|
|
media_type="text/event-stream", |
|
|
headers={ |
|
|
"Cache-Control": "no-cache", |
|
|
"Connection": "keep-alive", |
|
|
"X-Accel-Buffering": "no", |
|
|
}, |
|
|
) |
|
|
|
|
|
|
|
|
@router.get( |
|
|
"/{task_id}/outputs", |
|
|
response_model=InferenceOutputsResponse, |
|
|
summary="获取推理输出列表", |
|
|
description=""" |
|
|
获取任务的推理输出文件列表及推理配置信息。 |
|
|
|
|
|
训练任务完成后,推理阶段会生成测试音频文件。此端点返回所有生成的音频文件元信息, |
|
|
包括文件名、使用的模型路径、文件大小等,以及推理使用的参考音频和文本信息。 |
|
|
|
|
|
**推理配置**: |
|
|
- `ref_text`: 参考音频的文本内容 |
|
|
- `ref_audio_path`: 参考音频文件路径 |
|
|
- `target_text`: 合成的目标文本 |
|
|
|
|
|
**输出文件信息**: |
|
|
- `filename`: 文件名 |
|
|
- `gpt_model`: 使用的 GPT 模型名称 |
|
|
- `sovits_model`: 使用的 SoVITS 模型名称 |
|
|
- `gpt_path`: GPT 模型完整路径 |
|
|
- `sovits_path`: SoVITS 模型完整路径 |
|
|
- `file_path`: 输出文件相对路径 |
|
|
- `size_bytes`: 文件大小(字节) |
|
|
- `created_at`: 创建时间 |
|
|
|
|
|
**下载文件**: |
|
|
使用 `/tasks/{task_id}/outputs/{file_type}/{filename}` 端点下载相关文件。 |
|
|
""", |
|
|
responses={ |
|
|
200: {"model": InferenceOutputsResponse, "description": "推理输出列表"}, |
|
|
404: {"model": ErrorResponse, "description": "任务不存在"}, |
|
|
}, |
|
|
) |
|
|
async def get_task_outputs( |
|
|
task_id: str, |
|
|
service: TaskService = Depends(get_task_service), |
|
|
) -> InferenceOutputsResponse: |
|
|
""" |
|
|
获取任务的推理输出列表 |
|
|
""" |
|
|
result = await service.get_inference_outputs(task_id) |
|
|
if result is None: |
|
|
raise HTTPException(status_code=404, detail="任务不存在") |
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
FileType = Literal["output", "ref_audio", "gpt_model", "sovits_model"] |
|
|
|
|
|
|
|
|
@router.get( |
|
|
"/{task_id}/outputs/{file_type}/{filename:path}", |
|
|
summary="下载任务相关文件", |
|
|
description=""" |
|
|
下载任务相关的各类文件。 |
|
|
|
|
|
**文件类型 (file_type)**: |
|
|
- `output` - 推理输出音频文件 (.wav) |
|
|
- `ref_audio` - 参考音频文件 (.wav) |
|
|
- `gpt_model` - GPT 模型文件 (.ckpt) |
|
|
- `sovits_model` - SoVITS 模型文件 (.pth) |
|
|
|
|
|
**文件名来源**: |
|
|
- `output`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].filename` 获取 |
|
|
- `ref_audio`: 从 `/tasks/{task_id}/outputs` 端点的 `ref_audio_path` 获取 |
|
|
- `gpt_model`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].gpt_path` 获取文件名部分 |
|
|
- `sovits_model`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].sovits_path` 获取文件名部分 |
|
|
|
|
|
**返回**: |
|
|
- 音频文件: Content-Type: audio/wav |
|
|
- 模型文件: Content-Type: application/octet-stream |
|
|
""", |
|
|
responses={ |
|
|
200: {"description": "文件内容"}, |
|
|
404: {"model": ErrorResponse, "description": "任务或文件不存在"}, |
|
|
}, |
|
|
) |
|
|
async def download_task_file( |
|
|
task_id: str, |
|
|
file_type: FileType = Path(..., description="文件类型: output/ref_audio/gpt_model/sovits_model"), |
|
|
filename: str = Path(..., description="文件名或路径"), |
|
|
service: TaskService = Depends(get_task_service), |
|
|
) -> Response: |
|
|
""" |
|
|
下载任务相关文件(推理输出、参考音频、模型文件) |
|
|
""" |
|
|
result = await service.download_file(task_id, file_type, filename) |
|
|
if result is None: |
|
|
raise HTTPException(status_code=404, detail="任务或文件不存在") |
|
|
|
|
|
file_data, file_name, content_type = result |
|
|
|
|
|
return Response( |
|
|
content=file_data, |
|
|
media_type=content_type, |
|
|
headers={ |
|
|
"Content-Disposition": f'attachment; filename="{file_name}"', |
|
|
"Content-Length": str(len(file_data)), |
|
|
}, |
|
|
) |
|
|
|