MoYoYo.tts / api_server /app /scripts /run_pipeline.py
liumaolin
refactor(config): centralize configuration management in `project_config`
8f68d0a
#!/usr/bin/env python3
"""
Pipeline 包装脚本
此脚本作为独立子进程运行,执行 TrainingPipeline 并将进度以 JSON 格式输出到 stdout。
主进程(AsyncTrainingManager)通过解析 stdout 来获取实时进度。
进度消息格式:
##PROGRESS##{"type": "progress", "stage": "...", ...}##
Usage:
python run_pipeline.py --config /path/to/config.json --task-id task-123
"""
import argparse
import json
import sys
import os
import traceback
from datetime import datetime
from typing import Dict, Any
# Make the project's modules importable. This has to run before the
# project-local import at the end of this section.
from pathlib import Path
# Directory that contains this script (resolved to an absolute path).
_SCRIPT_DIR = Path(__file__).parent.resolve()
# Two directory levels above the script directory.
_API_SERVER_ROOT = _SCRIPT_DIR.parent.parent
# One further level up; this is the directory that is put on sys.path.
_PROJECT_ROOT = _API_SERVER_ROOT.parent
# Inserted at position 0 so this directory is searched first.
sys.path.insert(0, str(_PROJECT_ROOT))
# Import the configuration module.
# NOTE(review): presumably project_config lives in _PROJECT_ROOT and is found
# through the sys.path entry added above — confirm.
# NOTE(review): settings, PROJECT_ROOT and get_pythonpath are not referenced in
# the visible part of this file — confirm whether the import is still needed.
from project_config import settings, PROJECT_ROOT, get_pythonpath
# Prefix and suffix that frame every progress message so the parent process
# can pick the messages out of this script's stdout.
PROGRESS_PREFIX = "##PROGRESS##"
PROGRESS_SUFFIX = "##"


def emit_progress(progress_info: Dict[str, Any]) -> None:
    """
    Write one progress message to stdout.

    The message is serialized as JSON, wrapped in PROGRESS_PREFIX and
    PROGRESS_SUFFIX, printed as one line and flushed immediately.

    Args:
        progress_info: Fields of the message. The dict itself is not
            modified; when it has no "timestamp" key, a naive-UTC ISO 8601
            timestamp is appended to the emitted copy.
    """
    # Function-scope import: the file only imports `datetime` at top level.
    from datetime import timezone

    # Shallow copy so the caller's dict is left untouched.
    payload = dict(progress_info)
    if "timestamp" not in payload:
        # Produces the same naive-UTC text as datetime.utcnow().isoformat()
        # without calling the deprecated utcnow().
        payload["timestamp"] = (
            datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        )
    json_str = json.dumps(payload, ensure_ascii=False)
    print(f"{PROGRESS_PREFIX}{json_str}{PROGRESS_SUFFIX}", flush=True)
def emit_log(level: str, message: str, **extra) -> None:
    """
    Send a log record through the progress channel.

    Args:
        level: Log level (info, warning, error).
        message: Log text.
        **extra: Additional fields merged into the record.
    """
    record: Dict[str, Any] = {"type": "log", "level": level, "message": message}
    # Extra fields go after the fixed ones; an extra "type" replaces the value
    # in place, exactly as unpacking into the literal would.
    record.update(extra)
    emit_progress(record)
def load_config(config_path: str) -> Dict[str, Any]:
    """
    Load the pipeline configuration from a JSON file.

    Args:
        config_path: Path of the JSON configuration file (read as UTF-8).

    Returns:
        The parsed configuration dict.

    Raises:
        OSError: The file cannot be opened.
        json.JSONDecodeError: The file does not contain valid JSON.
        ValueError: The top-level JSON value is not an object.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    # The rest of this script looks configuration values up by key, so reject
    # any other JSON root here with a clear message instead of failing later.
    if not isinstance(config, dict):
        raise ValueError(
            f"配置文件顶层必须是 JSON 对象, 实际类型为 {type(config).__name__}"
        )
    return config
def build_pipeline(config: Dict[str, Any]):
    """
    Build a TrainingPipeline from a configuration dict.

    Args:
        config: Configuration dict. Keys read here:
            - exp_name: experiment name (required)
            - version: model version, defaults to "v2"
            - exp_root, gpu_numbers, is_half: parameters passed to every stage
            - stages: ordered list of stages; each item is either a stage
              type string or a dict with a "type" key plus per-stage overrides
            - per-stage options, looked up by name with the defaults below

    Returns:
        A TrainingPipeline with the requested stages added in order. Items
        that are neither str nor dict, and stage types without a builder,
        are skipped with a warning log.

    Raises:
        KeyError: "exp_name" is missing from config.
    """
    # Imported inside the function rather than at module load.
    from training_pipeline import (
        TrainingPipeline,
        ModelVersion,
        # Config classes
        AudioSliceConfig,
        ASRConfig,
        DenoiseConfig,
        FeatureExtractionConfig,
        SoVITSTrainConfig,
        GPTTrainConfig,
        InferenceConfig,
        # Stage classes
        AudioSliceStage,
        ASRStage,
        DenoiseStage,
        TextFeatureStage,
        HuBERTFeatureStage,
        SemanticTokenStage,
        SoVITSTrainStage,
        GPTTrainStage,
        InferenceStage,
    )

    # Default pretrained-model paths; several stages share the same ones.
    pretrained_root = "GPT_SoVITS/pretrained_models"
    default_bert_dir = f"{pretrained_root}/chinese-roberta-wwm-ext-large"
    default_ssl_dir = f"{pretrained_root}/chinese-hubert-base"
    default_s2g = f"{pretrained_root}/gsv-v2final-pretrained/s2G2333k.pth"
    default_s2d = f"{pretrained_root}/gsv-v2final-pretrained/s2D2333k.pth"
    default_s1 = (
        f"{pretrained_root}/gsv-v2final-pretrained/"
        "s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
    )

    pipeline = TrainingPipeline()

    exp_name = config["exp_name"]
    version_str = config.get("version", "v2")
    # A string is converted through ModelVersion; any other value is used as is.
    version = ModelVersion(version_str) if isinstance(version_str, str) else version_str

    # Parameters common to every stage config.
    base_params = {
        "exp_name": exp_name,
        "exp_root": config.get("exp_root", "logs"),
        "gpu_numbers": config.get("gpu_numbers", "0"),
        "is_half": config.get("is_half", True),
    }

    def feature_config(cfg: Dict[str, Any]):
        """Build the FeatureExtractionConfig used by the three feature stages."""
        return FeatureExtractionConfig(
            **base_params,
            version=version,
            bert_pretrained_dir=cfg.get("bert_pretrained_dir", default_bert_dir),
            ssl_pretrained_dir=cfg.get("ssl_pretrained_dir", default_ssl_dir),
            pretrained_s2G=cfg.get("pretrained_s2G", default_s2g),
        )

    # Stage type -> factory taking the effective stage config dict.
    stage_builders = {
        "audio_slice": lambda cfg: AudioSliceStage(AudioSliceConfig(
            **base_params,
            input_path=cfg.get("input_path", ""),
            threshold=cfg.get("threshold", -34),
            min_length=cfg.get("min_length", 4000),
            min_interval=cfg.get("min_interval", 300),
            hop_size=cfg.get("hop_size", 10),
            max_sil_kept=cfg.get("max_sil_kept", 500),
            max_amp=cfg.get("max_amp", 0.9),
            alpha=cfg.get("alpha", 0.25),
            n_parts=cfg.get("n_parts", 4),
        )),
        "asr": lambda cfg: ASRStage(ASRConfig(
            **base_params,
            model=cfg.get("model", "达摩 ASR (中文)"),
            model_size=cfg.get("model_size", "large"),
            language=cfg.get("language", "zh"),
            precision=cfg.get("precision", "float32"),
        )),
        "denoise": lambda cfg: DenoiseStage(DenoiseConfig(
            **base_params,
            input_dir=cfg.get("input_dir", ""),
            output_dir=cfg.get("output_dir", "output/denoise_opt"),
        )),
        "text_feature": lambda cfg: TextFeatureStage(feature_config(cfg)),
        "hubert_feature": lambda cfg: HuBERTFeatureStage(feature_config(cfg)),
        "semantic_token": lambda cfg: SemanticTokenStage(feature_config(cfg)),
        "sovits_train": lambda cfg: SoVITSTrainStage(SoVITSTrainConfig(
            **base_params,
            version=version,
            batch_size=cfg.get("batch_size", 4),
            total_epoch=cfg.get("total_epoch", 8),
            text_low_lr_rate=cfg.get("text_low_lr_rate", 0.4),
            save_every_epoch=cfg.get("save_every_epoch", 4),
            if_save_latest=cfg.get("if_save_latest", True),
            if_save_every_weights=cfg.get("if_save_every_weights", True),
            pretrained_s2G=cfg.get("pretrained_s2G", default_s2g),
            pretrained_s2D=cfg.get("pretrained_s2D", default_s2d),
            if_grad_ckpt=cfg.get("if_grad_ckpt", False),
            lora_rank=cfg.get("lora_rank", 32),
        )),
        "gpt_train": lambda cfg: GPTTrainStage(GPTTrainConfig(
            **base_params,
            version=version,
            batch_size=cfg.get("batch_size", 4),
            total_epoch=cfg.get("total_epoch", 15),
            save_every_epoch=cfg.get("save_every_epoch", 5),
            if_save_latest=cfg.get("if_save_latest", True),
            if_save_every_weights=cfg.get("if_save_every_weights", True),
            if_dpo=cfg.get("if_dpo", False),
            pretrained_s1=cfg.get("pretrained_s1", default_s1),
        )),
        "inference": lambda cfg: InferenceStage(InferenceConfig(
            **base_params,
            version=version,
            gpt_path=cfg.get("gpt_path", ""),
            sovits_path=cfg.get("sovits_path", ""),
            ref_text=cfg.get("ref_text", ""),
            ref_audio_path=cfg.get("ref_audio_path", ""),
            target_text=cfg.get("target_text", ""),
            text_split_method=cfg.get("text_split_method", "cut1"),
        )),
    }

    # Add the stages in the order given. `stages` may be:
    #   1. a list of strings: ["audio_slice", "asr", ...]
    #   2. a list of dicts:   [{"type": "audio_slice", "threshold": -30}, ...]
    # An explicit null is treated like a missing key (no stages).
    stages = config.get("stages") or []
    for stage_item in stages:
        if isinstance(stage_item, str):
            stage_type = stage_item
            # A bare name takes its options from the global config.
            stage_config = config
        elif isinstance(stage_item, dict):
            stage_type = stage_item.get("type")
            # Stage-specific values override the global ones.
            stage_config = {**config, **stage_item}
        else:
            emit_log("warning", f"无效的阶段配置类型: {type(stage_item)}")
            continue

        builder = stage_builders.get(stage_type)
        if builder is None:
            emit_log("warning", f"未知阶段类型: {stage_type}")
            continue

        stage = builder(stage_config)
        pipeline.add_stage(stage)
        emit_log("info", f"已添加阶段: {stage.name}")

    return pipeline
def run_pipeline(config: Dict[str, Any], task_id: str) -> bool:
    """
    Build the pipeline, run it, and stream its progress to stdout.

    Args:
        config: Configuration dict handed to build_pipeline.
        task_id: Task ID attached to every progress message.

    Returns:
        True when the run ended without a failed status; False when no stage
        was configured, a stage reported failure, or an exception was raised.
    """
    def report(status: str, **fields) -> None:
        # Keyword arguments keep their call-site order, so each emitted
        # message keeps the field order chosen where it is reported.
        emit_progress({"type": "progress", "status": status, **fields})

    report(
        "running",
        message="正在初始化训练流水线...",
        task_id=task_id,
        progress=0.0,
        overall_progress=0.0,
    )

    try:
        pipeline = build_pipeline(config)
        configured_stages = pipeline.get_stages()

        if not configured_stages:
            report("failed", message="没有配置任何训练阶段", task_id=task_id)
            return False

        emit_log("info", f"训练流水线已初始化,共 {len(configured_stages)} 个阶段")

        # Forward every update the pipeline yields, re-shaped into the
        # progress message format.
        for update in pipeline.run():
            report(
                "running",
                stage=update.get("stage"),
                stage_index=update.get("stage_index"),
                total_stages=update.get("total_stages"),
                progress=update.get("progress", 0.0),
                overall_progress=update.get("overall_progress", 0.0),
                message=update.get("message"),
                task_id=task_id,
                data=update.get("data", {}),
            )

            # A failed update ends the run right after being forwarded.
            if update.get("status") == "failed":
                report(
                    "failed",
                    stage=update.get("stage"),
                    message=update.get("message", "阶段执行失败"),
                    task_id=task_id,
                )
                return False

        report(
            "completed",
            message="训练流水线执行完成",
            task_id=task_id,
            progress=1.0,
            overall_progress=1.0,
        )
        return True

    except Exception as exc:
        # Top-level boundary of the run: report the error instead of raising.
        error_msg = str(exc)
        error_trace = traceback.format_exc()
        report(
            "failed",
            message=f"执行出错: {error_msg}",
            error=error_msg,
            traceback=error_trace,
            task_id=task_id,
        )
        return False
def main():
    """
    Command-line entry point.

    Parses --config and --task-id, loads the configuration and runs the
    pipeline. Exits with status 0 on success and 1 when the configuration
    cannot be loaded or the pipeline does not complete.
    """
    arg_parser = argparse.ArgumentParser(description="执行 GPT-SoVITS 训练流水线")
    arg_parser.add_argument("--config", required=True, help="配置文件路径 (JSON)")
    arg_parser.add_argument("--task-id", required=True, help="任务ID")
    cli = arg_parser.parse_args()
    task_id = cli.task_id
    config_path = cli.config

    emit_log("info", f"启动训练任务: {task_id}")
    emit_log("info", f"配置文件: {config_path}")

    try:
        config = load_config(config_path)
    except Exception as exc:
        # Report the failure on the progress channel before exiting.
        emit_progress({
            "type": "progress",
            "status": "failed",
            "message": f"加载配置文件失败: {exc}",
            "task_id": task_id,
        })
        sys.exit(1)

    exit_code = 0 if run_pipeline(config, task_id) else 1
    sys.exit(exit_code)


# Run only when this file is executed as a script.
if __name__ == "__main__":
    main()