#!/usr/bin/env python3
"""
Pipeline wrapper script.

This script runs as a standalone subprocess: it executes a TrainingPipeline
and writes progress to stdout as JSON. The parent process
(AsyncTrainingManager) obtains live progress by parsing stdout.

Progress message format:
    ##PROGRESS##{"type": "progress", "stage": "...", ...}##

Usage:
    python run_pipeline.py --config /path/to/config.json --task-id task-123
"""

import argparse
import json
import sys
import os
import traceback
from datetime import datetime, timezone
from typing import Dict, Any

# Make the project modules importable (must happen before importing them).
from pathlib import Path
_SCRIPT_DIR = Path(__file__).parent.resolve()
_API_SERVER_ROOT = _SCRIPT_DIR.parent.parent
_PROJECT_ROOT = _API_SERVER_ROOT.parent
sys.path.insert(0, str(_PROJECT_ROOT))

# Project configuration module.
from project_config import settings, PROJECT_ROOT, get_pythonpath


# Prefix and suffix that frame each progress message so the parent process
# can pick it out of the stdout stream.
PROGRESS_PREFIX = "##PROGRESS##"
PROGRESS_SUFFIX = "##"


def _utc_timestamp() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    Produces the same format as the deprecated ``datetime.utcnow().isoformat()``
    (no UTC offset suffix), so consumers of the timestamp see no change.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


def emit_progress(progress_info: Dict[str, Any]) -> None:
    """
    Write one framed progress message to stdout.

    The message is serialized as JSON (non-ASCII characters are kept as-is),
    wrapped in PROGRESS_PREFIX / PROGRESS_SUFFIX and flushed immediately.

    Args:
        progress_info: Progress information dictionary. It is not modified;
            a "timestamp" key is added to the emitted copy when absent.
    """
    # Work on a shallow copy so the caller's dict is never mutated.
    payload = dict(progress_info)
    if "timestamp" not in payload:
        payload["timestamp"] = _utc_timestamp()

    json_str = json.dumps(payload, ensure_ascii=False)
    # flush=True: each message is pushed to the pipe as soon as it is printed.
    print(f"{PROGRESS_PREFIX}{json_str}{PROGRESS_SUFFIX}", flush=True)


def emit_log(level: str, message: str, **extra) -> None:
    """
    Emit a log message through the progress channel.

    Args:
        level: Log level (info, warning, error).
        message: Log message text.
        **extra: Additional key/value pairs merged into the emitted message.
    """
    emit_progress({
        "type": "log",
        "level": level,
        "message": message,
        **extra
    })


def load_config(config_path: str) -> Dict[str, Any]:
    """
    Load the JSON configuration file.

    Args:
        config_path: Path to the configuration file.

    Returns:
        The parsed configuration dictionary.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def build_pipeline(config: Dict[str, Any]):
    """
    Build a TrainingPipeline from the configuration.

    Args:
        config: Configuration dictionary containing:
            - exp_name: experiment name (required)
            - version: model version (defaults to "v2")
            - stages: list of stages to run, either stage-type strings or
              dicts with a "type" key plus stage-specific overrides
            - the per-stage settings themselves

    Returns:
        The populated TrainingPipeline instance.

    Raises:
        KeyError: If "exp_name" is missing from ``config``.
    """
    from training_pipeline import (
        TrainingPipeline,
        ModelVersion,
        # Config classes
        AudioSliceConfig,
        ASRConfig,
        DenoiseConfig,
        FeatureExtractionConfig,
        SoVITSTrainConfig,
        GPTTrainConfig,
        InferenceConfig,
        # Stage classes
        AudioSliceStage,
        ASRStage,
        DenoiseStage,
        TextFeatureStage,
        HuBERTFeatureStage,
        SemanticTokenStage,
        SoVITSTrainStage,
        GPTTrainStage,
        InferenceStage,
    )

    pipeline = TrainingPipeline()

    exp_name = config["exp_name"]
    version_str = config.get("version", "v2")
    # A string is converted to ModelVersion; any other value is passed through.
    version = ModelVersion(version_str) if isinstance(version_str, str) else version_str

    # Parameters shared by every stage config.
    # NOTE: these are read from the top-level config only, so a per-stage dict
    # entry does not override exp_root / gpu_numbers / is_half.
    base_params = {
        "exp_name": exp_name,
        "exp_root": config.get("exp_root", "logs"),
        "gpu_numbers": config.get("gpu_numbers", "0"),
        "is_half": config.get("is_half", True),
    }

    # Default SoVITS generator checkpoint, shared by the feature stages and
    # the SoVITS training stage.
    default_s2g = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"

    def feature_config(cfg):
        """Build the FeatureExtractionConfig shared by the three feature stages."""
        return FeatureExtractionConfig(
            **base_params,
            version=version,
            bert_pretrained_dir=cfg.get(
                "bert_pretrained_dir",
                "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
            ),
            ssl_pretrained_dir=cfg.get(
                "ssl_pretrained_dir",
                "GPT_SoVITS/pretrained_models/chinese-hubert-base",
            ),
            pretrained_s2G=cfg.get("pretrained_s2G", default_s2g),
        )

    # Stage type -> factory taking the effective stage configuration.
    stage_builders = {
        "audio_slice": lambda cfg: AudioSliceStage(AudioSliceConfig(
            **base_params,
            input_path=cfg.get("input_path", ""),
            threshold=cfg.get("threshold", -34),
            min_length=cfg.get("min_length", 4000),
            min_interval=cfg.get("min_interval", 300),
            hop_size=cfg.get("hop_size", 10),
            max_sil_kept=cfg.get("max_sil_kept", 500),
            max_amp=cfg.get("max_amp", 0.9),
            alpha=cfg.get("alpha", 0.25),
            n_parts=cfg.get("n_parts", 4),
        )),
        "asr": lambda cfg: ASRStage(ASRConfig(
            **base_params,
            model=cfg.get("model", "达摩 ASR (中文)"),
            model_size=cfg.get("model_size", "large"),
            language=cfg.get("language", "zh"),
            precision=cfg.get("precision", "float32"),
        )),
        "denoise": lambda cfg: DenoiseStage(DenoiseConfig(
            **base_params,
            input_dir=cfg.get("input_dir", ""),
            output_dir=cfg.get("output_dir", "output/denoise_opt"),
        )),
        "text_feature": lambda cfg: TextFeatureStage(feature_config(cfg)),
        "hubert_feature": lambda cfg: HuBERTFeatureStage(feature_config(cfg)),
        "semantic_token": lambda cfg: SemanticTokenStage(feature_config(cfg)),
        "sovits_train": lambda cfg: SoVITSTrainStage(SoVITSTrainConfig(
            **base_params,
            version=version,
            batch_size=cfg.get("batch_size", 4),
            total_epoch=cfg.get("total_epoch", 8),
            text_low_lr_rate=cfg.get("text_low_lr_rate", 0.4),
            save_every_epoch=cfg.get("save_every_epoch", 4),
            if_save_latest=cfg.get("if_save_latest", True),
            if_save_every_weights=cfg.get("if_save_every_weights", True),
            pretrained_s2G=cfg.get("pretrained_s2G", default_s2g),
            pretrained_s2D=cfg.get(
                "pretrained_s2D",
                "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2D2333k.pth",
            ),
            if_grad_ckpt=cfg.get("if_grad_ckpt", False),
            lora_rank=cfg.get("lora_rank", 32),
        )),
        "gpt_train": lambda cfg: GPTTrainStage(GPTTrainConfig(
            **base_params,
            version=version,
            batch_size=cfg.get("batch_size", 4),
            total_epoch=cfg.get("total_epoch", 15),
            save_every_epoch=cfg.get("save_every_epoch", 5),
            if_save_latest=cfg.get("if_save_latest", True),
            if_save_every_weights=cfg.get("if_save_every_weights", True),
            if_dpo=cfg.get("if_dpo", False),
            pretrained_s1=cfg.get(
                "pretrained_s1",
                "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
            ),
        )),
        "inference": lambda cfg: InferenceStage(InferenceConfig(
            **base_params,
            version=version,
            gpt_path=cfg.get("gpt_path", ""),
            sovits_path=cfg.get("sovits_path", ""),
            ref_text=cfg.get("ref_text", ""),
            ref_audio_path=cfg.get("ref_audio_path", ""),
            target_text=cfg.get("target_text", ""),
            text_split_method=cfg.get("text_split_method", "cut1"),
        )),
    }

    # Add stages in the order given. "stages" may be:
    #   1. a list of strings: ["audio_slice", "asr", ...]
    #   2. a list of dicts:   [{"type": "audio_slice", "threshold": -30}, ...]
    stages = config.get("stages", [])

    for stage_item in stages:
        if isinstance(stage_item, str):
            stage_type = stage_item
            # A bare stage name uses the global config as its stage config.
            stage_config = config
        elif isinstance(stage_item, dict):
            stage_type = stage_item.get("type")
            # Stage-specific keys take precedence over the global config.
            stage_config = {**config, **stage_item}
        else:
            emit_log("warning", f"无效的阶段配置类型: {type(stage_item)}")
            continue

        if stage_type in stage_builders:
            stage = stage_builders[stage_type](stage_config)
            pipeline.add_stage(stage)
            emit_log("info", f"已添加阶段: {stage.name}")
        else:
            # Unknown stage types are skipped with a warning, not treated as fatal.
            emit_log("warning", f"未知阶段类型: {stage_type}")

    return pipeline


def run_pipeline(config: Dict[str, Any], task_id: str) -> bool:
    """
    Build and execute the pipeline, reporting progress on stdout.

    Any exception raised while building or running the pipeline is caught and
    reported as a "failed" progress message that includes the traceback.

    Args:
        config: Configuration dictionary.
        task_id: Task ID attached to every progress message.

    Returns:
        True if the pipeline ran to completion, False otherwise.
    """
    emit_progress({
        "type": "progress",
        "status": "running",
        "message": "正在初始化训练流水线...",
        "task_id": task_id,
        "progress": 0.0,
        "overall_progress": 0.0,
    })

    try:
        pipeline = build_pipeline(config)

        stages = pipeline.get_stages()
        if not stages:
            emit_progress({
                "type": "progress",
                "status": "failed",
                "message": "没有配置任何训练阶段",
                "task_id": task_id,
            })
            return False

        emit_log("info", f"训练流水线已初始化,共 {len(stages)} 个阶段")

        # Run the pipeline; each yielded item is read with dict-style .get().
        for progress in pipeline.run():
            # Re-shape the pipeline's progress record into the wire format.
            # Every record is forwarded with status "running", including one
            # whose own status is "failed" (that case is re-emitted below).
            emit_progress({
                "type": "progress",
                "status": "running",
                "stage": progress.get("stage"),
                "stage_index": progress.get("stage_index"),
                "total_stages": progress.get("total_stages"),
                "progress": progress.get("progress", 0.0),
                "overall_progress": progress.get("overall_progress", 0.0),
                "message": progress.get("message"),
                "task_id": task_id,
                "data": progress.get("data", {}),
            })

            # Stop at the first failed stage.
            if progress.get("status") == "failed":
                emit_progress({
                    "type": "progress",
                    "status": "failed",
                    "stage": progress.get("stage"),
                    "message": progress.get("message", "阶段执行失败"),
                    "task_id": task_id,
                })
                return False

        # All stages finished.
        emit_progress({
            "type": "progress",
            "status": "completed",
            "message": "训练流水线执行完成",
            "task_id": task_id,
            "progress": 1.0,
            "overall_progress": 1.0,
        })
        return True

    except Exception as e:
        # Top-level boundary of the subprocess: report the failure through the
        # progress channel instead of letting the exception escape.
        error_msg = str(e)
        error_trace = traceback.format_exc()

        emit_progress({
            "type": "progress",
            "status": "failed",
            "message": f"执行出错: {error_msg}",
            "error": error_msg,
            "traceback": error_trace,
            "task_id": task_id,
        })
        return False


def main():
    """Parse CLI arguments, load the config and run the pipeline.

    Exits with status 0 on success and 1 on failure (including a config file
    that cannot be loaded).
    """
    parser = argparse.ArgumentParser(description="执行 GPT-SoVITS 训练流水线")
    parser.add_argument("--config", required=True, help="配置文件路径 (JSON)")
    parser.add_argument("--task-id", required=True, help="任务ID")

    args = parser.parse_args()

    emit_log("info", f"启动训练任务: {args.task_id}")
    emit_log("info", f"配置文件: {args.config}")

    try:
        config = load_config(args.config)
    except Exception as e:
        emit_progress({
            "type": "progress",
            "status": "failed",
            "message": f"加载配置文件失败: {e}",
            "task_id": args.task_id,
        })
        sys.exit(1)

    success = run_pipeline(config, args.task_id)
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()