|
|
""" |
|
|
Advanced Mode 实验 API |
|
|
|
|
|
专家用户分阶段训练 API 端点 |
|
|
|
|
|
API 列表: |
|
|
- POST /experiments 创建实验 |
|
|
- GET /experiments 获取实验列表 |
|
|
- GET /experiments/{exp_id} 获取实验详情 |
|
|
- PATCH /experiments/{exp_id} 更新实验配置 |
|
|
- DELETE /experiments/{exp_id} 删除实验 |
|
|
- POST /experiments/{exp_id}/stages/{stage_type} 执行阶段 |
|
|
- GET /experiments/{exp_id}/stages 获取所有阶段状态 |
|
|
- GET /experiments/{exp_id}/stages/{stage_type} 获取阶段详情 |
|
|
- DELETE /experiments/{exp_id}/stages/{stage_type} 取消阶段 |
|
|
- GET /experiments/{exp_id}/stages/{stage_type}/progress SSE 阶段进度 |
|
|
""" |
|
|
|
|
|
import json |
|
|
from typing import Any, Dict, Optional |
|
|
|
|
|
from fastapi import APIRouter, Body, Depends, HTTPException, Query |
|
|
from fastapi.responses import StreamingResponse |
|
|
|
|
|
from ....models.schemas.experiment import ( |
|
|
ExperimentCreate, |
|
|
ExperimentUpdate, |
|
|
ExperimentResponse, |
|
|
ExperimentListResponse, |
|
|
StageStatus, |
|
|
StageExecuteResponse, |
|
|
StagesListResponse, |
|
|
STAGE_DEPENDENCIES, |
|
|
get_stage_params_class, |
|
|
) |
|
|
from ....models.schemas.common import SuccessResponse, ErrorResponse |
|
|
from ....services.experiment_service import ExperimentService |
|
|
from ...deps import get_experiment_service |
|
|
|
|
|
# Router that every endpoint in this module registers itself on.
router = APIRouter()

# Stage names accepted for the `{stage_type}` path parameter. Iterating
# STAGE_DEPENDENCIES yields its keys in order (assumes it is a mapping, as the
# original `.keys()` call implies); that order is reused in error messages.
VALID_STAGE_TYPES = list(STAGE_DEPENDENCIES)
|
|
|
|
|
|
|
|
@router.post(
    "",
    summary="创建实验",
    response_model=ExperimentResponse,
    description="""
创建实验(专家用户)。

创建实验但不立即执行,用户可以逐阶段控制训练流程。
实验创建后,所有阶段状态为 `pending`,需要手动触发执行。

**训练阶段**:
- `audio_slice`: 音频切片
- `asr`: 语音识别
- `text_feature`: 文本特征提取
- `hubert_feature`: HuBERT 特征提取
- `semantic_token`: 语义 Token 提取
- `sovits_train`: SoVITS 训练
- `gpt_train`: GPT 训练
""",
)
async def create_experiment(
    request: ExperimentCreate,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """Create an experiment from the request payload.

    Args:
        request: Validated creation payload.
        service: Injected experiment service.

    Returns:
        The value produced by ``service.create_experiment``, unchanged.
    """
    created = await service.create_experiment(request)
    return created
|
|
|
|
|
|
|
|
@router.get(
    "",
    summary="获取实验列表",
    description="获取所有实验列表,支持按状态筛选和分页。",
    response_model=ExperimentListResponse,
)
async def list_experiments(
    status: Optional[str] = Query(None, description="按状态筛选"),
    limit: int = Query(50, ge=1, le=100, description="每页数量"),
    offset: int = Query(0, ge=0, description="偏移量"),
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentListResponse:
    """List experiments, optionally filtered by status, one page at a time.

    Args:
        status: Optional status filter, forwarded to the service as-is.
        limit: Page size; constrained to 1..100, defaults to 50.
        offset: Number of entries to skip; must be >= 0.
        service: Injected experiment service.

    Returns:
        The value produced by ``service.list_experiments``.
    """
    paging = {"status": status, "limit": limit, "offset": offset}
    return await service.list_experiments(**paging)
|
|
|
|
|
|
|
|
@router.get(
    "/{exp_id}",
    summary="获取实验详情",
    description="获取指定实验的详细信息,包括所有阶段状态。",
    response_model=ExperimentResponse,
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def get_experiment(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """Return a single experiment looked up by its id.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        service: Injected experiment service.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    found = await service.get_experiment(exp_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="实验不存在")
|
|
|
|
|
|
|
|
@router.patch(
    "/{exp_id}",
    summary="更新实验配置",
    description="更新实验的基础配置(非阶段参数)。",
    response_model=ExperimentResponse,
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def update_experiment(
    exp_id: str,
    request: ExperimentUpdate,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """Apply an update payload to an existing experiment.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        request: Validated update payload.
        service: Injected experiment service.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    updated = await service.update_experiment(exp_id, request)
    if updated:
        return updated
    raise HTTPException(status_code=404, detail="实验不存在")
|
|
|
|
|
|
|
|
@router.delete(
    "/{exp_id}",
    summary="删除实验",
    description="删除实验及其所有阶段数据。如果有正在运行的阶段,会先取消执行。",
    response_model=SuccessResponse,
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def delete_experiment(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> SuccessResponse:
    """Delete an experiment through the service layer.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        service: Injected experiment service.

    Raises:
        HTTPException: 404 when ``service.delete_experiment`` reports failure
            (a falsy return value).
    """
    removed = await service.delete_experiment(exp_id)
    if removed:
        return SuccessResponse(message="实验已删除")
    raise HTTPException(status_code=404, detail="实验不存在")
|
|
|
|
|
|
|
|
@router.post(
    "/{exp_id}/stages/{stage_type}",
    response_model=StageExecuteResponse,
    summary="执行阶段",
    description="""
执行指定阶段。

**阶段依赖关系**:
- `audio_slice`: 无依赖
- `asr`: 依赖 audio_slice
- `text_feature`: 依赖 asr
- `hubert_feature`: 依赖 audio_slice
- `semantic_token`: 依赖 hubert_feature
- `sovits_train`: 依赖 text_feature, semantic_token
- `gpt_train`: 依赖 text_feature, semantic_token

如果依赖阶段未完成,会返回 400 错误。
如果阶段已完成,会重新执行(返回 `rerun: true`)。
""",
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效或依赖未满足"},
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def execute_stage(
    exp_id: str,
    stage_type: str,
    params: Dict[str, Any] = Body(default={}),
    service: ExperimentService = Depends(get_experiment_service),
) -> StageExecuteResponse:
    """Validate a stage request and hand it to the service for execution.

    Checks run in this order, and the first failure ends the request:
    known stage type, existing experiment, satisfied dependencies, and
    parameters accepted by the stage's params class.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        stage_type: Stage name; must be one of ``VALID_STAGE_TYPES``.
        params: Raw stage parameters from the request body (may be empty).
        service: Injected experiment service.

    Returns:
        The value produced by ``service.execute_stage``.

    Raises:
        HTTPException: 400 for an unknown stage type, unmet dependencies, or
            parameters that raise ``ValueError`` during validation; 404 when
            the experiment lookup or the execution call returns a falsy value.
    """
    # Reject unknown stage names before touching the service.
    if stage_type not in VALID_STAGE_TYPES:
        valid_names = ', '.join(VALID_STAGE_TYPES)
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}。有效类型: {valid_names}",
        )

    # The experiment must exist before dependencies can be evaluated.
    experiment = await service.get_experiment(exp_id)
    if not experiment:
        raise HTTPException(status_code=404, detail="实验不存在")

    # The service reports whether prerequisites are met and which are missing.
    deps = await service.check_stage_dependencies(exp_id, stage_type)
    if not deps["satisfied"]:
        raise HTTPException(
            status_code=400,
            detail=f"依赖阶段未完成: {', '.join(deps['missing'])}",
        )

    # Validate against the stage-specific params class; only fields the caller
    # actually supplied are forwarded (exclude_unset=True).
    try:
        params_class = get_stage_params_class(stage_type)
        validated_params = params_class(**params)
        stage_params = validated_params.model_dump(exclude_unset=True)
    except ValueError as e:
        # Chain the cause so the validation error stays in tracebacks and logs.
        raise HTTPException(status_code=400, detail=str(e)) from e

    result = await service.execute_stage(exp_id, stage_type, stage_params)
    if not result:
        # Falsy result from the service is reported as a missing experiment.
        raise HTTPException(status_code=404, detail="实验不存在")

    return result
|
|
|
|
|
|
|
|
@router.get(
    "/{exp_id}/stages",
    summary="获取所有阶段状态",
    description="获取实验的所有阶段状态列表。",
    response_model=StagesListResponse,
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def get_all_stages(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StagesListResponse:
    """Return the status of every stage of one experiment.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        service: Injected experiment service.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    stages = await service.get_all_stages(exp_id)
    if stages:
        return stages
    raise HTTPException(status_code=404, detail="实验不存在")
|
|
|
|
|
|
|
|
@router.get(
    "/{exp_id}/stages/{stage_type}",
    summary="获取阶段详情",
    description="获取指定阶段的详细状态和结果。",
    response_model=StageStatus,
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def get_stage(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StageStatus:
    """Return the status of a single stage of one experiment.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        stage_type: Stage name; must be one of ``VALID_STAGE_TYPES``.
        service: Injected experiment service.

    Raises:
        HTTPException: 400 for an unknown stage type; 404 when the service
            returns a falsy result.
    """
    is_known_stage = stage_type in VALID_STAGE_TYPES
    if not is_known_stage:
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}",
        )

    found = await service.get_stage(exp_id, stage_type)
    if found:
        return found
    raise HTTPException(status_code=404, detail="实验或阶段不存在")
|
|
|
|
|
|
|
|
@router.delete(
    "/{exp_id}/stages/{stage_type}",
    response_model=SuccessResponse,
    summary="取消阶段",
    description="取消正在执行的阶段。只有运行中的阶段可以取消。",
    responses={
        400: {"model": ErrorResponse, "description": "阶段未运行或无法取消"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def cancel_stage(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> SuccessResponse:
    """Cancel a stage of an experiment.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        stage_type: Stage name; must be one of ``VALID_STAGE_TYPES``.
        service: Injected experiment service.

    Returns:
        A success message naming the cancelled stage.

    Raises:
        HTTPException: 400 for an unknown stage type or when
            ``service.cancel_stage`` reports failure; 404 when the experiment
            lookup returns a falsy value.
    """
    if stage_type not in VALID_STAGE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}",
        )

    # The route declares a 404 response, so report a missing experiment as 404
    # rather than folding it into the generic "cannot cancel" 400 below. This
    # is the same lookup the execute and progress endpoints perform.
    experiment = await service.get_experiment(exp_id)
    if not experiment:
        raise HTTPException(status_code=404, detail="实验不存在")

    success = await service.cancel_stage(exp_id, stage_type)
    if not success:
        raise HTTPException(
            status_code=400,
            detail="阶段未运行或无法取消",
        )

    return SuccessResponse(message=f"阶段 {stage_type} 已取消")
|
|
|
|
|
|
|
|
@router.get(
    "/{exp_id}/stages/{stage_type}/progress",
    summary="SSE 阶段进度订阅",
    description="""
订阅阶段进度更新(Server-Sent Events)。

返回的事件流格式:
```
event: progress
data: {"epoch": 8, "total_epochs": 16, "progress": 0.50, "loss": 0.034}

event: checkpoint
data: {"epoch": 8, "model_path": "logs/my_voice/sovits_e8.pth"}

event: completed
data: {"status": "completed", "final_loss": 0.023}
```
""",
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def subscribe_stage_progress(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StreamingResponse:
    """Stream progress updates of one stage as Server-Sent Events.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        stage_type: Stage name; must be one of ``VALID_STAGE_TYPES``.
        service: Injected experiment service.

    Returns:
        A ``text/event-stream`` response fed by the service's progress
        iterator.

    Raises:
        HTTPException: 400 for an unknown stage type; 404 when the experiment
            lookup returns a falsy value.
    """
    if stage_type not in VALID_STAGE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}",
        )

    experiment = await service.get_experiment(exp_id)
    if not experiment:
        raise HTTPException(status_code=404, detail="实验不存在")

    # Statuses that both name the event and end the stream.
    terminal_states = ("completed", "failed", "cancelled")

    async def event_generator():
        """Yield one SSE frame per update until a terminal status is seen."""
        async for update in service.subscribe_stage_progress(exp_id, stage_type):
            state = update.get("status")
            terminal = next(
                (label for label in terminal_states if state == label),
                None,
            )

            # Event name precedence: terminal status, then checkpoint (update
            # carries a truthy "model_path"), then the update's own "type".
            if terminal is not None:
                event_name = terminal
            elif update.get("model_path"):
                event_name = "checkpoint"
            else:
                event_name = update.get("type", "progress")

            payload = json.dumps(update, ensure_ascii=False)
            yield f"event: {event_name}\ndata: {payload}\n\n"

            # The terminal frame is still sent; nothing follows it.
            if terminal is not None:
                break

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
|
|