"""
Advanced Mode experiment API.

Endpoints that let expert users create an experiment and drive its training
pipeline one stage at a time.

API list:
- POST   /experiments                                         Create an experiment
- GET    /experiments                                         List experiments
- GET    /experiments/{exp_id}                                Get experiment details
- PATCH  /experiments/{exp_id}                                Update experiment config
- DELETE /experiments/{exp_id}                                Delete an experiment
- POST   /experiments/{exp_id}/stages/{stage_type}            Execute a stage
- GET    /experiments/{exp_id}/stages                         Get all stage statuses
- GET    /experiments/{exp_id}/stages/{stage_type}            Get stage details
- DELETE /experiments/{exp_id}/stages/{stage_type}            Cancel a stage
- GET    /experiments/{exp_id}/stages/{stage_type}/progress   SSE stage progress
"""

import json
from typing import Any, AsyncIterator, Dict, Optional

from fastapi import APIRouter, Body, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse

from ....models.schemas.experiment import (
    ExperimentCreate,
    ExperimentUpdate,
    ExperimentResponse,
    ExperimentListResponse,
    StageStatus,
    StageExecuteResponse,
    StagesListResponse,
    STAGE_DEPENDENCIES,
    get_stage_params_class,
)
from ....models.schemas.common import SuccessResponse, ErrorResponse
from ....services.experiment_service import ExperimentService
from ...deps import get_experiment_service

router = APIRouter()

# Stage types accepted by the stage endpoints: the keys of STAGE_DEPENDENCIES.
VALID_STAGE_TYPES = list(STAGE_DEPENDENCIES.keys())

# Stage statuses after which the SSE progress stream is closed.
_TERMINAL_STATUSES = ("completed", "failed", "cancelled")


def _ensure_valid_stage_type(stage_type: str) -> None:
    """Raise HTTP 400 (listing the accepted values) if ``stage_type`` is unknown."""
    if stage_type not in VALID_STAGE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}。有效类型: {', '.join(VALID_STAGE_TYPES)}",
        )


async def _get_experiment_or_404(
    service: ExperimentService, exp_id: str
) -> ExperimentResponse:
    """Return the experiment from the service, or raise HTTP 404 if it returns a falsy value."""
    experiment = await service.get_experiment(exp_id)
    if not experiment:
        raise HTTPException(status_code=404, detail="实验不存在")
    return experiment


@router.post(
    "",
    response_model=ExperimentResponse,
    summary="创建实验",
    description="""
创建实验(专家用户)。

创建实验但不立即执行,用户可以逐阶段控制训练流程。
实验创建后,所有阶段状态为 `pending`,需要手动触发执行。

**训练阶段**:
- `audio_slice`: 音频切片
- `asr`: 语音识别
- `text_feature`: 文本特征提取
- `hubert_feature`: HuBERT 特征提取
- `semantic_token`: 语义 Token 提取
- `sovits_train`: SoVITS 训练
- `gpt_train`: GPT 训练
""",
)
async def create_experiment(
    request: ExperimentCreate,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """Create an experiment without starting any stage."""
    return await service.create_experiment(request)


@router.get(
    "",
    response_model=ExperimentListResponse,
    summary="获取实验列表",
    description="获取所有实验列表,支持按状态筛选和分页。",
)
async def list_experiments(
    status: Optional[str] = Query(None, description="按状态筛选"),
    limit: int = Query(50, ge=1, le=100, description="每页数量"),
    offset: int = Query(0, ge=0, description="偏移量"),
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentListResponse:
    """List experiments, optionally filtered by status, with limit/offset paging."""
    return await service.list_experiments(status=status, limit=limit, offset=offset)


@router.get(
    "/{exp_id}",
    response_model=ExperimentResponse,
    summary="获取实验详情",
    description="获取指定实验的详细信息,包括所有阶段状态。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def get_experiment(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """Return one experiment; 404 if the service finds nothing."""
    return await _get_experiment_or_404(service, exp_id)


@router.patch(
    "/{exp_id}",
    response_model=ExperimentResponse,
    summary="更新实验配置",
    description="更新实验的基础配置(非阶段参数)。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def update_experiment(
    exp_id: str,
    request: ExperimentUpdate,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """Update an experiment's base config; 404 if the service returns nothing."""
    experiment = await service.update_experiment(exp_id, request)
    if not experiment:
        raise HTTPException(status_code=404, detail="实验不存在")
    return experiment


@router.delete(
    "/{exp_id}",
    response_model=SuccessResponse,
    summary="删除实验",
    description="删除实验及其所有阶段数据。如果有正在运行的阶段,会先取消执行。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def delete_experiment(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> SuccessResponse:
    """Delete an experiment; 404 if the service reports failure."""
    success = await service.delete_experiment(exp_id)
    if not success:
        raise HTTPException(status_code=404, detail="实验不存在")
    return SuccessResponse(message="实验已删除")


@router.post(
    "/{exp_id}/stages/{stage_type}",
    response_model=StageExecuteResponse,
    summary="执行阶段",
    description="""
执行指定阶段。

**阶段依赖关系**:
- `audio_slice`: 无依赖
- `asr`: 依赖 audio_slice
- `text_feature`: 依赖 asr
- `hubert_feature`: 依赖 audio_slice
- `semantic_token`: 依赖 hubert_feature
- `sovits_train`: 依赖 text_feature, semantic_token
- `gpt_train`: 依赖 text_feature, semantic_token

如果依赖阶段未完成,会返回 400 错误。
如果阶段已完成,会重新执行(返回 `rerun: true`)。
""",
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效、依赖未满足或参数无效"},
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def execute_stage(
    exp_id: str,
    stage_type: str,
    params: Dict[str, Any] = Body(default={}),
    service: ExperimentService = Depends(get_experiment_service),
) -> StageExecuteResponse:
    """
    Execute (or re-run) one stage of an experiment.

    Checks, in order: stage type (400), experiment existence (404),
    dependency stages (400), stage-specific parameter validation (400).

    Raises:
        HTTPException: 400 or 404 as described above, or 404 if the service
            returns nothing from ``execute_stage``.
    """
    _ensure_valid_stage_type(stage_type)
    await _get_experiment_or_404(service, exp_id)

    deps = await service.check_stage_dependencies(exp_id, stage_type)
    if not deps["satisfied"]:
        raise HTTPException(
            status_code=400,
            detail=f"依赖阶段未完成: {', '.join(deps['missing'])}",
        )

    # Pydantic v2's ValidationError subclasses ValueError, so invalid params map
    # to 400 here. exclude_unset keeps only the fields the client actually sent.
    try:
        params_class = get_stage_params_class(stage_type)
        validated_params = params_class(**params)
        stage_params = validated_params.model_dump(exclude_unset=True)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    result = await service.execute_stage(exp_id, stage_type, stage_params)
    if not result:
        raise HTTPException(status_code=404, detail="实验不存在")
    return result


@router.get(
    "/{exp_id}/stages",
    response_model=StagesListResponse,
    summary="获取所有阶段状态",
    description="获取实验的所有阶段状态列表。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def get_all_stages(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StagesListResponse:
    """Return the status of every stage; 404 if the service returns nothing."""
    result = await service.get_all_stages(exp_id)
    if not result:
        raise HTTPException(status_code=404, detail="实验不存在")
    return result


@router.get(
    "/{exp_id}/stages/{stage_type}",
    response_model=StageStatus,
    summary="获取阶段详情",
    description="获取指定阶段的详细状态和结果。",
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def get_stage(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StageStatus:
    """Return one stage's status; 400 for an unknown stage type, 404 if not found."""
    _ensure_valid_stage_type(stage_type)

    stage = await service.get_stage(exp_id, stage_type)
    if not stage:
        raise HTTPException(status_code=404, detail="实验或阶段不存在")
    return stage


@router.delete(
    "/{exp_id}/stages/{stage_type}",
    response_model=SuccessResponse,
    summary="取消阶段",
    description="取消正在执行的阶段。只有运行中的阶段可以取消。",
    responses={
        400: {"model": ErrorResponse, "description": "阶段未运行或无法取消"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def cancel_stage(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> SuccessResponse:
    """
    Cancel a running stage.

    Raises:
        HTTPException: 400 for an unknown stage type or when the service
            refuses to cancel; 404 when the experiment does not exist.
    """
    _ensure_valid_stage_type(stage_type)
    # Check existence first so a missing experiment yields the documented 404
    # instead of being reported as "stage not running" (400).
    await _get_experiment_or_404(service, exp_id)

    success = await service.cancel_stage(exp_id, stage_type)
    if not success:
        raise HTTPException(status_code=400, detail="阶段未运行或无法取消")
    return SuccessResponse(message=f"阶段 {stage_type} 已取消")


@router.get(
    "/{exp_id}/stages/{stage_type}/progress",
    summary="SSE 阶段进度订阅",
    description="""
订阅阶段进度更新(Server-Sent Events)。

返回的事件流格式:

```
event: progress
data: {"epoch": 8, "total_epochs": 16, "progress": 0.50, "loss": 0.034}

event: checkpoint
data: {"epoch": 8, "model_path": "logs/my_voice/sovits_e8.pth"}

event: completed
data: {"status": "completed", "final_loss": 0.023}
```

终态事件为 `completed`、`failed` 或 `cancelled`,发送终态事件后服务端关闭事件流。
""",
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def subscribe_stage_progress(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StreamingResponse:
    """
    Stream stage progress as Server-Sent Events.

    Each item from ``service.subscribe_stage_progress`` becomes one SSE frame
    whose data is the item serialized as JSON. The stream ends after the first
    item whose ``status`` is terminal.
    """
    _ensure_valid_stage_type(stage_type)
    await _get_experiment_or_404(service, exp_id)

    async def event_generator() -> AsyncIterator[str]:
        """Translate service progress dicts into SSE frames."""
        async for progress in service.subscribe_stage_progress(exp_id, stage_type):
            status = progress.get("status")
            # Precedence: terminal status > checkpoint (has model_path) >
            # explicit "type" field > default "progress".
            if status in _TERMINAL_STATUSES:
                event_type = status
            elif progress.get("model_path"):
                event_type = "checkpoint"
            else:
                event_type = progress.get("type", "progress")

            data = json.dumps(progress, ensure_ascii=False)
            yield f"event: {event_type}\ndata: {data}\n\n"

            if status in _TERMINAL_STATUSES:
                break

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask nginx-style reverse proxies not to buffer the event stream.
            "X-Accel-Buffering": "no",
        },
    )