File size: 12,137 Bytes
e054d0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
"""
Advanced Mode 实验 API

专家用户分阶段训练 API 端点

API 列表:
- POST   /experiments                                   创建实验
- GET    /experiments                                   获取实验列表
- GET    /experiments/{exp_id}                          获取实验详情
- PATCH  /experiments/{exp_id}                          更新实验配置
- DELETE /experiments/{exp_id}                          删除实验
- POST   /experiments/{exp_id}/stages/{stage_type}      执行阶段
- GET    /experiments/{exp_id}/stages                   获取所有阶段状态
- GET    /experiments/{exp_id}/stages/{stage_type}      获取阶段详情
- DELETE /experiments/{exp_id}/stages/{stage_type}      取消阶段
- GET    /experiments/{exp_id}/stages/{stage_type}/progress  SSE 阶段进度
"""

import json
from typing import Any, Dict, Optional

from fastapi import APIRouter, Body, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse

from ....models.schemas.experiment import (
    ExperimentCreate,
    ExperimentUpdate,
    ExperimentResponse,
    ExperimentListResponse,
    StageStatus,
    StageExecuteResponse,
    StagesListResponse,
    STAGE_DEPENDENCIES,
    get_stage_params_class,
)
from ....models.schemas.common import SuccessResponse, ErrorResponse
from ....services.experiment_service import ExperimentService
from ...deps import get_experiment_service

# Router on which every endpoint in this module is registered.
router = APIRouter()

# Stage types accepted by the stage endpoints: the keys of STAGE_DEPENDENCIES.
VALID_STAGE_TYPES = list(STAGE_DEPENDENCIES)


@router.post(
    "",
    response_model=ExperimentResponse,
    summary="创建实验",
    description="""
创建实验(专家用户)。

创建实验但不立即执行,用户可以逐阶段控制训练流程。
实验创建后,所有阶段状态为 `pending`,需要手动触发执行。

**训练阶段**:
- `audio_slice`: 音频切片
- `asr`: 语音识别
- `text_feature`: 文本特征提取
- `hubert_feature`: HuBERT 特征提取
- `semantic_token`: 语义 Token 提取
- `sovits_train`: SoVITS 训练
- `gpt_train`: GPT 训练
""",
)
async def create_experiment(
    request: ExperimentCreate,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """Create an experiment from the submitted payload.

    Args:
        request: Validated creation payload.
        service: Experiment service injected by FastAPI.

    Returns:
        Whatever ``service.create_experiment`` produces for the request.
    """
    created = await service.create_experiment(request)
    return created


@router.get(
    "",
    response_model=ExperimentListResponse,
    summary="获取实验列表",
    description="获取所有实验列表,支持按状态筛选和分页。",
)
async def list_experiments(
    status: Optional[str] = Query(None, description="按状态筛选"),
    limit: int = Query(50, ge=1, le=100, description="每页数量"),
    offset: int = Query(0, ge=0, description="偏移量"),
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentListResponse:
    """List experiments, optionally filtered by status and paginated.

    Args:
        status: Optional status filter; ``None`` applies no filter here.
        limit: Page size, constrained to 1..100 by the query validator.
        offset: Number of items to skip, constrained to be >= 0.
        service: Experiment service injected by FastAPI.

    Returns:
        The listing returned by ``service.list_experiments``.
    """
    page = await service.list_experiments(
        status=status,
        limit=limit,
        offset=offset,
    )
    return page


@router.get(
    "/{exp_id}",
    response_model=ExperimentResponse,
    summary="获取实验详情",
    description="获取指定实验的详细信息,包括所有阶段状态。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def get_experiment(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """Fetch a single experiment by its identifier.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        service: Experiment service injected by FastAPI.

    Returns:
        The experiment returned by the service.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    found = await service.get_experiment(exp_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="实验不存在")


@router.patch(
    "/{exp_id}",
    response_model=ExperimentResponse,
    summary="更新实验配置",
    description="更新实验的基础配置(非阶段参数)。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def update_experiment(
    exp_id: str,
    request: ExperimentUpdate,
    service: ExperimentService = Depends(get_experiment_service),
) -> ExperimentResponse:
    """Apply a configuration update to an existing experiment.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        request: Validated update payload.
        service: Experiment service injected by FastAPI.

    Returns:
        The updated experiment returned by the service.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    updated = await service.update_experiment(exp_id, request)
    if updated:
        return updated
    raise HTTPException(status_code=404, detail="实验不存在")


@router.delete(
    "/{exp_id}",
    response_model=SuccessResponse,
    summary="删除实验",
    description="删除实验及其所有阶段数据。如果有正在运行的阶段,会先取消执行。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def delete_experiment(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> SuccessResponse:
    """Delete an experiment.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        service: Experiment service injected by FastAPI.

    Returns:
        A success envelope with a confirmation message.

    Raises:
        HTTPException: 404 when the service reports a falsy outcome.
    """
    deleted = await service.delete_experiment(exp_id)
    if deleted:
        return SuccessResponse(message="实验已删除")
    raise HTTPException(status_code=404, detail="实验不存在")


@router.post(
    "/{exp_id}/stages/{stage_type}",
    response_model=StageExecuteResponse,
    summary="执行阶段",
    description="""
执行指定阶段。

**阶段依赖关系**:
- `audio_slice`: 无依赖
- `asr`: 依赖 audio_slice
- `text_feature`: 依赖 asr
- `hubert_feature`: 依赖 audio_slice
- `semantic_token`: 依赖 hubert_feature
- `sovits_train`: 依赖 text_feature, semantic_token
- `gpt_train`: 依赖 text_feature, semantic_token

如果依赖阶段未完成,会返回 400 错误。
如果阶段已完成,会重新执行(返回 `rerun: true`)。
""",
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效或依赖未满足"},
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def execute_stage(
    exp_id: str,
    stage_type: str,
    params: Dict[str, Any] = Body(default={}),
    service: ExperimentService = Depends(get_experiment_service),
) -> StageExecuteResponse:
    """Validate the request and run one stage of an experiment.

    Checks are performed in this order: stage type, experiment existence,
    stage dependencies, then stage parameters. The first failing check
    aborts the request.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        stage_type: Stage name; must be one of ``VALID_STAGE_TYPES``.
        params: Raw stage parameters from the request body (defaults to an
            empty mapping).
        service: Experiment service injected by FastAPI.

    Returns:
        The result returned by ``service.execute_stage``.

    Raises:
        HTTPException: 400 for an unknown stage type, unsatisfied
            dependencies, or parameters rejected with ``ValueError``;
            404 when the experiment lookup or the execution call returns a
            falsy result.
    """
    # Reject unknown stage types before touching the service.
    if stage_type not in VALID_STAGE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}。有效类型: {', '.join(VALID_STAGE_TYPES)}"
        )

    # The experiment must exist.
    experiment = await service.get_experiment(exp_id)
    if not experiment:
        raise HTTPException(status_code=404, detail="实验不存在")

    # Every prerequisite stage must be reported as satisfied.
    deps = await service.check_stage_dependencies(exp_id, stage_type)
    if not deps["satisfied"]:
        raise HTTPException(
            status_code=400,
            detail=f"依赖阶段未完成: {', '.join(deps['missing'])}"
        )

    # Validate the raw parameters through the stage-specific params class and
    # keep only the fields the caller explicitly set. The validated dump goes
    # into its own local so the request argument is not rebound.
    try:
        params_class = get_stage_params_class(stage_type)
        validated_params = params_class(**params)
        stage_params = validated_params.model_dump(exclude_unset=True)
    except ValueError as exc:
        # Chain the cause so the original validation error stays in tracebacks.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Run the stage; a falsy result is reported as a missing experiment.
    result = await service.execute_stage(exp_id, stage_type, stage_params)
    if not result:
        raise HTTPException(status_code=404, detail="实验不存在")

    return result


@router.get(
    "/{exp_id}/stages",
    response_model=StagesListResponse,
    summary="获取所有阶段状态",
    description="获取实验的所有阶段状态列表。",
    responses={
        404: {"model": ErrorResponse, "description": "实验不存在"},
    },
)
async def get_all_stages(
    exp_id: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StagesListResponse:
    """Return the status of every stage of an experiment.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        service: Experiment service injected by FastAPI.

    Returns:
        The stage listing returned by the service.

    Raises:
        HTTPException: 404 when the service returns a falsy result.
    """
    stages = await service.get_all_stages(exp_id)
    if stages:
        return stages
    raise HTTPException(status_code=404, detail="实验不存在")


@router.get(
    "/{exp_id}/stages/{stage_type}",
    response_model=StageStatus,
    summary="获取阶段详情",
    description="获取指定阶段的详细状态和结果。",
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def get_stage(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StageStatus:
    """Return the status of one stage of an experiment.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        stage_type: Stage name; must be one of ``VALID_STAGE_TYPES``.
        service: Experiment service injected by FastAPI.

    Returns:
        The stage status returned by the service.

    Raises:
        HTTPException: 400 for an unknown stage type; 404 when the service
            returns a falsy result.
    """
    # Unknown stage names never reach the service.
    stage_is_known = stage_type in VALID_STAGE_TYPES
    if not stage_is_known:
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}",
        )

    stage_status = await service.get_stage(exp_id, stage_type)
    if stage_status:
        return stage_status
    raise HTTPException(status_code=404, detail="实验或阶段不存在")


@router.delete(
    "/{exp_id}/stages/{stage_type}",
    response_model=SuccessResponse,
    summary="取消阶段",
    description="取消正在执行的阶段。只有运行中的阶段可以取消。",
    responses={
        400: {"model": ErrorResponse, "description": "阶段未运行或无法取消"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def cancel_stage(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> SuccessResponse:
    """Cancel a stage of an experiment.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        stage_type: Stage name; must be one of ``VALID_STAGE_TYPES``.
        service: Experiment service injected by FastAPI.

    Returns:
        A success envelope naming the cancelled stage.

    Raises:
        HTTPException: 400 for an unknown stage type or when the service
            reports that the cancellation did not succeed; 404 when the
            experiment lookup returns a falsy result.
    """
    # Reject unknown stage types before touching the service.
    if stage_type not in VALID_STAGE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}"
        )

    # The route documents a 404 response, so a missing experiment must be
    # told apart from a stage that merely cannot be cancelled.
    experiment = await service.get_experiment(exp_id)
    if not experiment:
        raise HTTPException(status_code=404, detail="实验不存在")

    success = await service.cancel_stage(exp_id, stage_type)
    if not success:
        raise HTTPException(
            status_code=400,
            detail="阶段未运行或无法取消"
        )

    return SuccessResponse(message=f"阶段 {stage_type} 已取消")


@router.get(
    "/{exp_id}/stages/{stage_type}/progress",
    summary="SSE 阶段进度订阅",
    description="""
订阅阶段进度更新(Server-Sent Events)。

返回的事件流格式:
```
event: progress
data: {"epoch": 8, "total_epochs": 16, "progress": 0.50, "loss": 0.034}

event: checkpoint
data: {"epoch": 8, "model_path": "logs/my_voice/sovits_e8.pth"}

event: completed
data: {"status": "completed", "final_loss": 0.023}
```
""",
    responses={
        400: {"model": ErrorResponse, "description": "阶段类型无效"},
        404: {"model": ErrorResponse, "description": "实验或阶段不存在"},
    },
)
async def subscribe_stage_progress(
    exp_id: str,
    stage_type: str,
    service: ExperimentService = Depends(get_experiment_service),
) -> StreamingResponse:
    """Stream progress updates for one stage as Server-Sent Events.

    Each update from ``service.subscribe_stage_progress`` is emitted as one
    SSE frame whose data is the JSON-encoded update. The event name is the
    update's ``status`` when that is ``completed``, ``failed`` or
    ``cancelled``; otherwise ``checkpoint`` when the update has a truthy
    ``model_path``; otherwise the update's ``type`` (default ``progress``).
    The stream stops after the first frame with one of those three statuses.

    Args:
        exp_id: Experiment identifier taken from the URL path.
        stage_type: Stage name; must be one of ``VALID_STAGE_TYPES``.
        service: Experiment service injected by FastAPI.

    Returns:
        A ``text/event-stream`` streaming response.

    Raises:
        HTTPException: 400 for an unknown stage type; 404 when the
            experiment lookup returns a falsy result.
    """
    # Unknown stage names never reach the service.
    if stage_type not in VALID_STAGE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"无效的阶段类型: {stage_type}",
        )

    # The experiment must exist before a stream is opened.
    if not await service.get_experiment(exp_id):
        raise HTTPException(status_code=404, detail="实验不存在")

    # Statuses that name their own event and end the stream.
    terminal_statuses = ("completed", "failed", "cancelled")

    async def sse_frames():
        """Yield one SSE-formatted frame per progress update."""
        async for update in service.subscribe_stage_progress(exp_id, stage_type):
            state = update.get("status")
            is_terminal = state in terminal_statuses

            # Pick the event name: terminal status, then checkpoint, then the
            # update's own type.
            if is_terminal:
                event_name = state
            elif update.get("model_path"):
                event_name = "checkpoint"
            else:
                event_name = update.get("type", "progress")

            payload = json.dumps(update, ensure_ascii=False)
            yield f"event: {event_name}\ndata: {payload}\n\n"

            # Nothing follows a terminal frame.
            if is_terminal:
                break

    stream_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        sse_frames(),
        media_type="text/event-stream",
        headers=stream_headers,
    )