liumaolin
feat(api): enhance task file management and download endpoints
8064340
"""
Quick Mode 任务 API
小白用户一键训练 API 端点
API 列表:
- POST /tasks 创建一键训练任务
- GET /tasks 获取任务列表
- GET /tasks/{task_id} 获取任务详情
- DELETE /tasks/{task_id} 取消任务
- GET /tasks/{task_id}/progress SSE 进度订阅
- GET /tasks/{task_id}/outputs 获取推理输出列表
- GET /tasks/{task_id}/outputs/{file_type}/{filename} 下载任务相关文件
"""
import json
from typing import Literal, Optional
from fastapi import APIRouter, Depends, HTTPException, Path, Query
from fastapi.responses import StreamingResponse, Response
from ....models.schemas.task import (
QuickModeRequest,
TaskResponse,
TaskListResponse,
InferenceOutputsResponse,
)
from ....models.schemas.common import SuccessResponse, ErrorResponse
from ....services.task_service import TaskService
from ...deps import get_task_service
router = APIRouter()
@router.post(
"",
response_model=TaskResponse,
summary="创建一键训练任务",
description="""
创建一键训练任务(小白用户)。
上传音频文件后,系统自动配置所有参数并执行完整训练流程:
`audio_slice -> asr -> text_feature -> hubert_feature -> semantic_token -> sovits_train -> gpt_train -> inference`
训练完成后会自动进行推理测试,生成测试音频文件。可通过 `options.inference` 配置推理参数或禁用推理阶段。
**质量预设**:
- `fast`: SoVITS 4 epochs, GPT 8 epochs, 约10分钟
- `standard`: SoVITS 8 epochs, GPT 15 epochs, 约20分钟
- `high`: SoVITS 16 epochs, GPT 30 epochs, 约40分钟
""",
responses={
200: {"model": TaskResponse, "description": "任务创建成功"},
400: {"model": ErrorResponse, "description": "请求参数错误"},
404: {"model": ErrorResponse, "description": "音频文件不存在"},
409: {"model": ErrorResponse, "description": "实验名称已存在"},
},
)
async def create_task(
request: QuickModeRequest,
service: TaskService = Depends(get_task_service),
) -> TaskResponse:
"""
创建一键训练任务
"""
# 验证 exp_name 是否已存在
if await service.check_exp_name_exists(request.exp_name):
raise HTTPException(
status_code=409,
detail=f"实验名称 '{request.exp_name}' 已存在,请使用不同的名称"
)
# 验证音频文件是否存在
file_exists, audio_path = await service.validate_audio_file(request.audio_file_id)
if not file_exists:
raise HTTPException(
status_code=404,
detail=f"音频文件不存在: {request.audio_file_id}"
)
return await service.create_quick_task(request)
@router.get(
"",
response_model=TaskListResponse,
summary="获取任务列表",
description="获取所有训练任务列表,支持按状态筛选和分页。",
)
async def list_tasks(
status: Optional[str] = Query(
None,
description="按状态筛选: queued, running, completed, failed, cancelled, interrupted"
),
limit: int = Query(50, ge=1, le=100, description="每页数量"),
offset: int = Query(0, ge=0, description="偏移量"),
service: TaskService = Depends(get_task_service),
) -> TaskListResponse:
"""
获取任务列表
"""
return await service.list_tasks(status=status, limit=limit, offset=offset)
@router.get(
"/{task_id}",
response_model=TaskResponse,
summary="获取任务详情",
description="获取指定任务的详细状态信息。",
responses={
200: {"model": TaskResponse, "description": "任务详情"},
404: {"model": ErrorResponse, "description": "任务不存在"},
},
)
async def get_task(
task_id: str,
service: TaskService = Depends(get_task_service),
) -> TaskResponse:
"""
获取任务详情
"""
task = await service.get_task(task_id)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
return task
@router.delete(
"/{task_id}",
response_model=SuccessResponse,
summary="取消任务",
description="取消排队中或运行中的任务。已完成、失败或已取消的任务无法取消。",
responses={
200: {"model": SuccessResponse, "description": "任务取消成功"},
400: {"model": ErrorResponse, "description": "任务无法取消"},
404: {"model": ErrorResponse, "description": "任务不存在"},
},
)
async def cancel_task(
task_id: str,
service: TaskService = Depends(get_task_service),
) -> SuccessResponse:
"""
取消任务
"""
# 先检查任务是否存在
task = await service.get_task(task_id)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
success = await service.cancel_task(task_id)
if not success:
raise HTTPException(status_code=400, detail="任务无法取消(可能已完成或已取消)")
return SuccessResponse(message="任务已取消")
@router.get(
"/{task_id}/progress",
summary="SSE 进度订阅",
description="""
订阅任务进度更新(Server-Sent Events)。
返回的事件流格式:
```
event: progress
data: {"stage": "sovits_train", "progress": 0.45, "message": "Epoch 8/16"}
event: progress
data: {"stage": "sovits_train", "progress": 0.50, "message": "Epoch 9/16"}
event: completed
data: {"status": "completed", "message": "训练完成"}
```
可能的事件类型:
- `progress`: 进度更新
- `log`: 日志消息
- `heartbeat`: 心跳(保持连接)
- `completed`: 任务完成
- `failed`: 任务失败
- `cancelled`: 任务取消
""",
responses={
200: {"description": "SSE 事件流"},
404: {"model": ErrorResponse, "description": "任务不存在"},
},
)
async def subscribe_progress(
task_id: str,
service: TaskService = Depends(get_task_service),
) -> StreamingResponse:
"""
SSE 进度订阅
"""
# 先检查任务是否存在
task = await service.get_task(task_id)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
async def event_generator():
"""生成 SSE 事件流"""
async for progress in service.subscribe_progress(task_id):
# 确定事件类型
event_type = progress.get("type", "progress")
status = progress.get("status")
if status == "completed":
event_type = "completed"
elif status == "failed":
event_type = "failed"
elif status == "cancelled":
event_type = "cancelled"
elif event_type == "heartbeat":
event_type = "heartbeat"
# 构建 SSE 格式
data = json.dumps(progress, ensure_ascii=False)
yield f"event: {event_type}\ndata: {data}\n\n"
# 如果是终态,结束流
if status in ("completed", "failed", "cancelled"):
break
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Nginx 禁用缓冲
},
)
@router.get(
"/{task_id}/outputs",
response_model=InferenceOutputsResponse,
summary="获取推理输出列表",
description="""
获取任务的推理输出文件列表及推理配置信息。
训练任务完成后,推理阶段会生成测试音频文件。此端点返回所有生成的音频文件元信息,
包括文件名、使用的模型路径、文件大小等,以及推理使用的参考音频和文本信息。
**推理配置**:
- `ref_text`: 参考音频的文本内容
- `ref_audio_path`: 参考音频文件路径
- `target_text`: 合成的目标文本
**输出文件信息**:
- `filename`: 文件名
- `gpt_model`: 使用的 GPT 模型名称
- `sovits_model`: 使用的 SoVITS 模型名称
- `gpt_path`: GPT 模型完整路径
- `sovits_path`: SoVITS 模型完整路径
- `file_path`: 输出文件相对路径
- `size_bytes`: 文件大小(字节)
- `created_at`: 创建时间
**下载文件**:
使用 `/tasks/{task_id}/outputs/{file_type}/{filename}` 端点下载相关文件。
""",
responses={
200: {"model": InferenceOutputsResponse, "description": "推理输出列表"},
404: {"model": ErrorResponse, "description": "任务不存在"},
},
)
async def get_task_outputs(
task_id: str,
service: TaskService = Depends(get_task_service),
) -> InferenceOutputsResponse:
"""
获取任务的推理输出列表
"""
result = await service.get_inference_outputs(task_id)
if result is None:
raise HTTPException(status_code=404, detail="任务不存在")
return result
# 文件类型定义
FileType = Literal["output", "ref_audio", "gpt_model", "sovits_model"]
@router.get(
"/{task_id}/outputs/{file_type}/{filename:path}",
summary="下载任务相关文件",
description="""
下载任务相关的各类文件。
**文件类型 (file_type)**:
- `output` - 推理输出音频文件 (.wav)
- `ref_audio` - 参考音频文件 (.wav)
- `gpt_model` - GPT 模型文件 (.ckpt)
- `sovits_model` - SoVITS 模型文件 (.pth)
**文件名来源**:
- `output`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].filename` 获取
- `ref_audio`: 从 `/tasks/{task_id}/outputs` 端点的 `ref_audio_path` 获取
- `gpt_model`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].gpt_path` 获取文件名部分
- `sovits_model`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].sovits_path` 获取文件名部分
**返回**:
- 音频文件: Content-Type: audio/wav
- 模型文件: Content-Type: application/octet-stream
""",
responses={
200: {"description": "文件内容"},
404: {"model": ErrorResponse, "description": "任务或文件不存在"},
},
)
async def download_task_file(
task_id: str,
file_type: FileType = Path(..., description="文件类型: output/ref_audio/gpt_model/sovits_model"),
filename: str = Path(..., description="文件名或路径"),
service: TaskService = Depends(get_task_service),
) -> Response:
"""
下载任务相关文件(推理输出、参考音频、模型文件)
"""
result = await service.download_file(task_id, file_type, filename)
if result is None:
raise HTTPException(status_code=404, detail="任务或文件不存在")
file_data, file_name, content_type = result
return Response(
content=file_data,
media_type=content_type,
headers={
"Content-Disposition": f'attachment; filename="{file_name}"',
"Content-Length": str(len(file_data)),
},
)