liumaolin
feat(api): implement local training MVP with adapter pattern
e054d0c
"""
Advanced Mode 实验 API
专家用户分阶段训练 API 端点
API 列表:
- POST /experiments 创建实验
- GET /experiments 获取实验列表
- GET /experiments/{exp_id} 获取实验详情
- PATCH /experiments/{exp_id} 更新实验配置
- DELETE /experiments/{exp_id} 删除实验
- POST /experiments/{exp_id}/stages/{stage_type} 执行阶段
- GET /experiments/{exp_id}/stages 获取所有阶段状态
- GET /experiments/{exp_id}/stages/{stage_type} 获取阶段详情
- DELETE /experiments/{exp_id}/stages/{stage_type} 取消阶段
- GET /experiments/{exp_id}/stages/{stage_type}/progress SSE 阶段进度
"""
import json
from typing import Any, Dict, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from ....models.schemas.experiment import (
ExperimentCreate,
ExperimentUpdate,
ExperimentResponse,
ExperimentListResponse,
StageStatus,
StageExecuteResponse,
StagesListResponse,
STAGE_DEPENDENCIES,
get_stage_params_class,
)
from ....models.schemas.common import SuccessResponse, ErrorResponse
from ....services.experiment_service import ExperimentService
from ...deps import get_experiment_service
# Router that every endpoint in this module registers on.
router = APIRouter()

# Stage types accepted in the {stage_type} path segment: the keys of
# STAGE_DEPENDENCIES.
VALID_STAGE_TYPES = list(STAGE_DEPENDENCIES)
@router.post(
    "",
    response_model=ExperimentResponse,
    summary="创建实验",
    description="""
创建实验(专家用户)。
创建实验但不立即执行,用户可以逐阶段控制训练流程。
实验创建后,所有阶段状态为 `pending`,需要手动触发执行。
**训练阶段**:
- `audio_slice`: 音频切片
- `asr`: 语音识别
- `text_feature`: 文本特征提取
- `hubert_feature`: HuBERT 特征提取
- `semantic_token`: 语义 Token 提取
- `sovits_train`: SoVITS 训练
- `gpt_train`: GPT 训练
""",
)
async def create_experiment(
    request: ExperimentCreate,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """Create an experiment.

    Args:
        request: Creation payload, forwarded unchanged to the service.
        service: Experiment service injected by FastAPI.

    Returns:
        Whatever ``service.create_experiment`` returns for the new experiment.
    """
    created = await service.create_experiment(request)
    return created
@router.get(
    "",
    response_model=ExperimentListResponse,
    summary="获取实验列表",
    description="获取所有实验列表,支持按状态筛选和分页。",
)
async def list_experiments(
    status: Optional[str] = Query(None, description="按状态筛选"),
    limit: int = Query(50, ge=1, le=100, description="每页数量"),
    offset: int = Query(0, ge=0, description="偏移量"),
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentListResponse:
    """List experiments, optionally filtered by status, with paging.

    Args:
        status: Optional status filter; ``None`` applies no filter value.
        limit: Page size, constrained to 1..100 (default 50).
        offset: Number of items to skip, must be >= 0 (default 0).
        service: Experiment service injected by FastAPI.

    Returns:
        The list response produced by ``service.list_experiments``.
    """
    paging = {"status": status, "limit": limit, "offset": offset}
    return await service.list_experiments(**paging)
@router.get(
    "/{exp_id}",
    response_model=ExperimentResponse,
    summary="获取实验详情",
    description="获取指定实验的详细信息,包括所有阶段状态。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def get_experiment(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """Fetch a single experiment by id.

    Args:
        exp_id: Experiment identifier taken from the path.
        service: Experiment service injected by FastAPI.

    Returns:
        The experiment returned by the service.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    found = await service.get_experiment(exp_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="实验不存在")
@router.patch(
    "/{exp_id}",
    response_model=ExperimentResponse,
    summary="更新实验配置",
    description="更新实验的基础配置(非阶段参数)。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def update_experiment(
    exp_id: str,
    request: ExperimentUpdate,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """Update the configuration of an existing experiment.

    Args:
        exp_id: Experiment identifier taken from the path.
        request: Update payload, forwarded unchanged to the service.
        service: Experiment service injected by FastAPI.

    Returns:
        The updated experiment returned by the service.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    updated = await service.update_experiment(exp_id, request)
    if updated:
        return updated
    raise HTTPException(status_code=404, detail="实验不存在")
@router.delete(
    "/{exp_id}",
    response_model=SuccessResponse,
    summary="删除实验",
    description="删除实验及其所有阶段数据。如果有正在运行的阶段,会先取消执行。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def delete_experiment(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> SuccessResponse:
    """Delete an experiment.

    Args:
        exp_id: Experiment identifier taken from the path.
        service: Experiment service injected by FastAPI.

    Returns:
        A success response carrying a confirmation message.

    Raises:
        HTTPException: 404 when ``service.delete_experiment`` reports failure
            with a falsy result.
    """
    deleted = await service.delete_experiment(exp_id)
    if deleted:
        return SuccessResponse(message="实验已删除")
    raise HTTPException(status_code=404, detail="实验不存在")
@router.post(
    "/{exp_id}/stages/{stage_type}",
    response_model=StageExecuteResponse,
    summary="执行阶段",
    description="""
执行指定阶段。
**阶段依赖关系**:
- `audio_slice`: 无依赖
- `asr`: 依赖 audio_slice
- `text_feature`: 依赖 asr
- `hubert_feature`: 依赖 audio_slice
- `semantic_token`: 依赖 hubert_feature
- `sovits_train`: 依赖 text_feature, semantic_token
- `gpt_train`: 依赖 text_feature, semantic_token
如果依赖阶段未完成,会返回 400 错误。
如果阶段已完成,会重新执行(返回 `rerun: true`)。
""",
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效或依赖未满足"},
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def execute_stage(
    exp_id: str,
    stage_type: str,
    params: Dict[str, Any] = Body(default={}),
    service: ExperimentService = Depends(get_experiment_service),
) -> StageExecuteResponse:
    """Execute one stage of an experiment.

    Checks run in this order, and the first failure aborts the request:
    stage type is known, experiment exists, dependency check is satisfied,
    and the request body validates against the stage's parameter class.

    Args:
        exp_id: Experiment identifier taken from the path.
        stage_type: Stage name; must be one of ``VALID_STAGE_TYPES``.
        params: Raw stage parameters from the request body (defaults to an
            empty object).
        service: Experiment service injected by FastAPI.

    Returns:
        The result of ``service.execute_stage``.

    Raises:
        HTTPException: 400 for an unknown stage type, unsatisfied
            dependencies, or invalid parameters; 404 when the experiment
            lookup or the execution call returns a falsy result.
    """
    # Reject unknown stage types before touching the service.
    if stage_type not in VALID_STAGE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}。有效类型: {', '.join(VALID_STAGE_TYPES)}",
        )

    # The experiment must exist before dependencies can be checked.
    experiment = await service.get_experiment(exp_id)
    if not experiment:
        raise HTTPException(status_code=404, detail="实验不存在")

    # The service reports whether prerequisites are met and which are missing.
    deps = await service.check_stage_dependencies(exp_id, stage_type)
    if not deps["satisfied"]:
        raise HTTPException(
            status_code=400,
            detail=f"依赖阶段未完成: {', '.join(deps['missing'])}",
        )

    # Validate the body against the stage-specific parameter class and keep
    # only the fields the caller explicitly set.
    try:
        params_class = get_stage_params_class(stage_type)
        validated_params = params_class(**params)
        params = validated_params.model_dump(exclude_unset=True)
    except ValueError as exc:
        # Chain the cause so the validation error stays visible in tracebacks.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    result = await service.execute_stage(exp_id, stage_type, params)
    if not result:
        # A falsy result here is reported as a missing experiment.
        raise HTTPException(status_code=404, detail="实验不存在")
    return result
@router.get(
    "/{exp_id}/stages",
    response_model=StagesListResponse,
    summary="获取所有阶段状态",
    description="获取实验的所有阶段状态列表。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def get_all_stages(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StagesListResponse:
    """Return the status of every stage of an experiment.

    Args:
        exp_id: Experiment identifier taken from the path.
        service: Experiment service injected by FastAPI.

    Returns:
        The stage list produced by ``service.get_all_stages``.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    stages = await service.get_all_stages(exp_id)
    if stages:
        return stages
    raise HTTPException(status_code=404, detail="实验不存在")
@router.get(
    "/{exp_id}/stages/{stage_type}",
    response_model=StageStatus,
    summary="获取阶段详情",
    description="获取指定阶段的详细状态和结果。",
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def get_stage(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StageStatus:
    """Return the status of a single stage.

    Args:
        exp_id: Experiment identifier taken from the path.
        stage_type: Stage name; must be one of ``VALID_STAGE_TYPES``.
        service: Experiment service injected by FastAPI.

    Returns:
        The stage status returned by ``service.get_stage``.

    Raises:
        HTTPException: 400 for an unknown stage type; 404 when the service
            returns a falsy result.
    """
    stage_is_known = stage_type in VALID_STAGE_TYPES
    if not stage_is_known:
        raise HTTPException(status_code=400, detail=f"无效的阶段类型: {stage_type}")

    stage_status = await service.get_stage(exp_id, stage_type)
    if stage_status:
        return stage_status
    raise HTTPException(status_code=404, detail="实验或阶段不存在")
@router.delete(
    "/{exp_id}/stages/{stage_type}",
    response_model=SuccessResponse,
    summary="取消阶段",
    description="取消正在执行的阶段。只有运行中的阶段可以取消。",
    responses={
        400: {"model": ErrorResponse, "description": "阶段未运行或无法取消"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def cancel_stage(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> SuccessResponse:
    """Cancel a stage of an experiment.

    Args:
        exp_id: Experiment identifier taken from the path.
        stage_type: Stage name; must be one of ``VALID_STAGE_TYPES``.
        service: Experiment service injected by FastAPI.

    Returns:
        A success response naming the cancelled stage.

    Raises:
        HTTPException: 400 for an unknown stage type or when
            ``service.cancel_stage`` returns a falsy result; 404 when the
            experiment does not exist.
    """
    if stage_type not in VALID_STAGE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}",
        )

    # The route declares a 404 response, so a missing experiment must not be
    # reported as the 400 "cannot cancel" error below.
    experiment = await service.get_experiment(exp_id)
    if not experiment:
        raise HTTPException(status_code=404, detail="实验不存在")

    success = await service.cancel_stage(exp_id, stage_type)
    if not success:
        raise HTTPException(
            status_code=400,
            detail="阶段未运行或无法取消",
        )
    return SuccessResponse(message=f"阶段 {stage_type} 已取消")
@router.get(
    "/{exp_id}/stages/{stage_type}/progress",
    summary="SSE 阶段进度订阅",
    description="""
订阅阶段进度更新(Server-Sent Events)。
返回的事件流格式:
```
event: progress
data: {"epoch": 8, "total_epochs": 16, "progress": 0.50, "loss": 0.034}
event: checkpoint
data: {"epoch": 8, "model_path": "logs/my_voice/sovits_e8.pth"}
event: completed
data: {"status": "completed", "final_loss": 0.023}
```
""",
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def subscribe_stage_progress(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StreamingResponse:
    """Stream progress updates for a stage as Server-Sent Events.

    Args:
        exp_id: Experiment identifier taken from the path.
        stage_type: Stage name; must be one of ``VALID_STAGE_TYPES``.
        service: Experiment service injected by FastAPI.

    Returns:
        A ``text/event-stream`` response. Each update from the service becomes
        one ``event:`` / ``data:`` frame, and the stream ends after the first
        update with status ``completed``, ``failed`` or ``cancelled``.

    Raises:
        HTTPException: 400 for an unknown stage type; 404 when the experiment
            lookup returns a falsy result.
    """
    stage_is_known = stage_type in VALID_STAGE_TYPES
    if not stage_is_known:
        raise HTTPException(status_code=400, detail=f"无效的阶段类型: {stage_type}")

    # Fail before the stream starts if the experiment is missing.
    if not await service.get_experiment(exp_id):
        raise HTTPException(status_code=404, detail="实验不存在")

    terminal_statuses = ("completed", "failed", "cancelled")

    async def event_generator():
        """Yield one SSE frame per progress update."""
        async for update in service.subscribe_stage_progress(exp_id, stage_type):
            state = update.get("status")
            # A terminal status names the event and ends the stream.
            terminal_event = next(
                (name for name in terminal_statuses if state == name),
                None,
            )

            if terminal_event is not None:
                event_name = terminal_event
            elif update.get("model_path"):
                # A model path on a non-terminal update marks a checkpoint.
                event_name = "checkpoint"
            else:
                event_name = update.get("type", "progress")

            payload = json.dumps(update, ensure_ascii=False)
            yield f"event: {event_name}\ndata: {payload}\n\n"

            if terminal_event is not None:
                break

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers=sse_headers,
    )