|
|
|
|
|
""" |
|
|
Pipeline 包装脚本 |
|
|
|
|
|
此脚本作为独立子进程运行,执行 TrainingPipeline 并将进度以 JSON 格式输出到 stdout。 |
|
|
主进程(AsyncTrainingManager)通过解析 stdout 来获取实时进度。 |
|
|
|
|
|
进度消息格式: |
|
|
##PROGRESS##{"type": "progress", "stage": "...", ...}## |
|
|
|
|
|
Usage: |
|
|
python run_pipeline.py --config /path/to/config.json --task-id task-123 |
|
|
""" |
|
|
|
|
|
import argparse
import json
import os
import sys
import traceback
from datetime import datetime, timezone
from typing import Any, Dict
|
|
|
|
|
|
|
|
from pathlib import Path |
|
|
_SCRIPT_DIR = Path(__file__).parent.resolve() |
|
|
_API_SERVER_ROOT = _SCRIPT_DIR.parent.parent |
|
|
_PROJECT_ROOT = _API_SERVER_ROOT.parent |
|
|
sys.path.insert(0, str(_PROJECT_ROOT)) |
|
|
|
|
|
|
|
|
from project_config import settings, PROJECT_ROOT, get_pythonpath |
|
|
|
|
|
|
|
|
|
|
|
PROGRESS_PREFIX = "##PROGRESS##" |
|
|
PROGRESS_SUFFIX = "##" |
|
|
|
|
|
|
|
|
def emit_progress(progress_info: Dict[str, Any]) -> None: |
|
|
""" |
|
|
输出进度消息到 stdout |
|
|
|
|
|
Args: |
|
|
progress_info: 进度信息字典 |
|
|
""" |
|
|
|
|
|
if "timestamp" not in progress_info: |
|
|
progress_info["timestamp"] = datetime.utcnow().isoformat() |
|
|
|
|
|
json_str = json.dumps(progress_info, ensure_ascii=False) |
|
|
print(f"{PROGRESS_PREFIX}{json_str}{PROGRESS_SUFFIX}", flush=True) |
|
|
|
|
|
|
|
|
def emit_log(level: str, message: str, **extra) -> None: |
|
|
""" |
|
|
输出日志消息 |
|
|
|
|
|
Args: |
|
|
level: 日志级别 (info, warning, error) |
|
|
message: 日志消息 |
|
|
**extra: 额外数据 |
|
|
""" |
|
|
emit_progress({ |
|
|
"type": "log", |
|
|
"level": level, |
|
|
"message": message, |
|
|
**extra |
|
|
}) |
|
|
|
|
|
|
|
|
def load_config(config_path: str) -> Dict[str, Any]: |
|
|
""" |
|
|
加载配置文件 |
|
|
|
|
|
Args: |
|
|
config_path: 配置文件路径 |
|
|
|
|
|
Returns: |
|
|
配置字典 |
|
|
""" |
|
|
with open(config_path, 'r', encoding='utf-8') as f: |
|
|
return json.load(f) |
|
|
|
|
|
|
|
|
def build_pipeline(config: Dict[str, Any]): |
|
|
""" |
|
|
根据配置构建 TrainingPipeline |
|
|
|
|
|
Args: |
|
|
config: 配置字典,包含: |
|
|
- exp_name: 实验名称 |
|
|
- version: 模型版本 |
|
|
- stages: 要执行的阶段列表 |
|
|
- 各阶段的具体配置 |
|
|
|
|
|
Returns: |
|
|
TrainingPipeline 实例 |
|
|
""" |
|
|
from training_pipeline import ( |
|
|
TrainingPipeline, |
|
|
ModelVersion, |
|
|
|
|
|
AudioSliceConfig, |
|
|
ASRConfig, |
|
|
DenoiseConfig, |
|
|
FeatureExtractionConfig, |
|
|
SoVITSTrainConfig, |
|
|
GPTTrainConfig, |
|
|
InferenceConfig, |
|
|
|
|
|
AudioSliceStage, |
|
|
ASRStage, |
|
|
DenoiseStage, |
|
|
TextFeatureStage, |
|
|
HuBERTFeatureStage, |
|
|
SemanticTokenStage, |
|
|
SoVITSTrainStage, |
|
|
GPTTrainStage, |
|
|
InferenceStage, |
|
|
) |
|
|
|
|
|
pipeline = TrainingPipeline() |
|
|
|
|
|
exp_name = config["exp_name"] |
|
|
version_str = config.get("version", "v2") |
|
|
version = ModelVersion(version_str) if isinstance(version_str, str) else version_str |
|
|
|
|
|
|
|
|
base_params = { |
|
|
"exp_name": exp_name, |
|
|
"exp_root": config.get("exp_root", "logs"), |
|
|
"gpu_numbers": config.get("gpu_numbers", "0"), |
|
|
"is_half": config.get("is_half", True), |
|
|
} |
|
|
|
|
|
|
|
|
stage_builders = { |
|
|
"audio_slice": lambda cfg: AudioSliceStage(AudioSliceConfig( |
|
|
**base_params, |
|
|
input_path=cfg.get("input_path", ""), |
|
|
threshold=cfg.get("threshold", -34), |
|
|
min_length=cfg.get("min_length", 4000), |
|
|
min_interval=cfg.get("min_interval", 300), |
|
|
hop_size=cfg.get("hop_size", 10), |
|
|
max_sil_kept=cfg.get("max_sil_kept", 500), |
|
|
max_amp=cfg.get("max_amp", 0.9), |
|
|
alpha=cfg.get("alpha", 0.25), |
|
|
n_parts=cfg.get("n_parts", 4), |
|
|
)), |
|
|
|
|
|
"asr": lambda cfg: ASRStage(ASRConfig( |
|
|
**base_params, |
|
|
model=cfg.get("model", "达摩 ASR (中文)"), |
|
|
model_size=cfg.get("model_size", "large"), |
|
|
language=cfg.get("language", "zh"), |
|
|
precision=cfg.get("precision", "float32"), |
|
|
)), |
|
|
|
|
|
"denoise": lambda cfg: DenoiseStage(DenoiseConfig( |
|
|
**base_params, |
|
|
input_dir=cfg.get("input_dir", ""), |
|
|
output_dir=cfg.get("output_dir", "output/denoise_opt"), |
|
|
)), |
|
|
|
|
|
"text_feature": lambda cfg: TextFeatureStage(FeatureExtractionConfig( |
|
|
**base_params, |
|
|
version=version, |
|
|
bert_pretrained_dir=cfg.get("bert_pretrained_dir", |
|
|
"GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"), |
|
|
ssl_pretrained_dir=cfg.get("ssl_pretrained_dir", |
|
|
"GPT_SoVITS/pretrained_models/chinese-hubert-base"), |
|
|
pretrained_s2G=cfg.get("pretrained_s2G", |
|
|
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"), |
|
|
)), |
|
|
|
|
|
"hubert_feature": lambda cfg: HuBERTFeatureStage(FeatureExtractionConfig( |
|
|
**base_params, |
|
|
version=version, |
|
|
bert_pretrained_dir=cfg.get("bert_pretrained_dir", |
|
|
"GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"), |
|
|
ssl_pretrained_dir=cfg.get("ssl_pretrained_dir", |
|
|
"GPT_SoVITS/pretrained_models/chinese-hubert-base"), |
|
|
pretrained_s2G=cfg.get("pretrained_s2G", |
|
|
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"), |
|
|
)), |
|
|
|
|
|
"semantic_token": lambda cfg: SemanticTokenStage(FeatureExtractionConfig( |
|
|
**base_params, |
|
|
version=version, |
|
|
bert_pretrained_dir=cfg.get("bert_pretrained_dir", |
|
|
"GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"), |
|
|
ssl_pretrained_dir=cfg.get("ssl_pretrained_dir", |
|
|
"GPT_SoVITS/pretrained_models/chinese-hubert-base"), |
|
|
pretrained_s2G=cfg.get("pretrained_s2G", |
|
|
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"), |
|
|
)), |
|
|
|
|
|
"sovits_train": lambda cfg: SoVITSTrainStage(SoVITSTrainConfig( |
|
|
**base_params, |
|
|
version=version, |
|
|
batch_size=cfg.get("batch_size", 4), |
|
|
total_epoch=cfg.get("total_epoch", 8), |
|
|
text_low_lr_rate=cfg.get("text_low_lr_rate", 0.4), |
|
|
save_every_epoch=cfg.get("save_every_epoch", 4), |
|
|
if_save_latest=cfg.get("if_save_latest", True), |
|
|
if_save_every_weights=cfg.get("if_save_every_weights", True), |
|
|
pretrained_s2G=cfg.get("pretrained_s2G", |
|
|
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"), |
|
|
pretrained_s2D=cfg.get("pretrained_s2D", |
|
|
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2D2333k.pth"), |
|
|
if_grad_ckpt=cfg.get("if_grad_ckpt", False), |
|
|
lora_rank=cfg.get("lora_rank", 32), |
|
|
)), |
|
|
|
|
|
"gpt_train": lambda cfg: GPTTrainStage(GPTTrainConfig( |
|
|
**base_params, |
|
|
version=version, |
|
|
batch_size=cfg.get("batch_size", 4), |
|
|
total_epoch=cfg.get("total_epoch", 15), |
|
|
save_every_epoch=cfg.get("save_every_epoch", 5), |
|
|
if_save_latest=cfg.get("if_save_latest", True), |
|
|
if_save_every_weights=cfg.get("if_save_every_weights", True), |
|
|
if_dpo=cfg.get("if_dpo", False), |
|
|
pretrained_s1=cfg.get("pretrained_s1", |
|
|
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"), |
|
|
)), |
|
|
|
|
|
"inference": lambda cfg: InferenceStage(InferenceConfig( |
|
|
**base_params, |
|
|
version=version, |
|
|
gpt_path=cfg.get("gpt_path", ""), |
|
|
sovits_path=cfg.get("sovits_path", ""), |
|
|
ref_text=cfg.get("ref_text", ""), |
|
|
ref_audio_path=cfg.get("ref_audio_path", ""), |
|
|
target_text=cfg.get("target_text", ""), |
|
|
text_split_method=cfg.get("text_split_method", "cut1"), |
|
|
)), |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
stages = config.get("stages", []) |
|
|
for stage_item in stages: |
|
|
|
|
|
if isinstance(stage_item, str): |
|
|
stage_type = stage_item |
|
|
stage_config = config |
|
|
elif isinstance(stage_item, dict): |
|
|
stage_type = stage_item.get("type") |
|
|
|
|
|
stage_config = {**config, **stage_item} |
|
|
else: |
|
|
emit_log("warning", f"无效的阶段配置类型: {type(stage_item)}") |
|
|
continue |
|
|
|
|
|
if stage_type in stage_builders: |
|
|
stage = stage_builders[stage_type](stage_config) |
|
|
pipeline.add_stage(stage) |
|
|
emit_log("info", f"已添加阶段: {stage.name}") |
|
|
else: |
|
|
emit_log("warning", f"未知阶段类型: {stage_type}") |
|
|
|
|
|
return pipeline |
|
|
|
|
|
|
|
|
def run_pipeline(config: Dict[str, Any], task_id: str) -> bool: |
|
|
""" |
|
|
执行 Pipeline |
|
|
|
|
|
Args: |
|
|
config: 配置字典 |
|
|
task_id: 任务ID |
|
|
|
|
|
Returns: |
|
|
是否成功完成 |
|
|
""" |
|
|
emit_progress({ |
|
|
"type": "progress", |
|
|
"status": "running", |
|
|
"message": "正在初始化训练流水线...", |
|
|
"task_id": task_id, |
|
|
"progress": 0.0, |
|
|
"overall_progress": 0.0, |
|
|
}) |
|
|
|
|
|
try: |
|
|
pipeline = build_pipeline(config) |
|
|
|
|
|
stages = pipeline.get_stages() |
|
|
if not stages: |
|
|
emit_progress({ |
|
|
"type": "progress", |
|
|
"status": "failed", |
|
|
"message": "没有配置任何训练阶段", |
|
|
"task_id": task_id, |
|
|
}) |
|
|
return False |
|
|
|
|
|
emit_log("info", f"训练流水线已初始化,共 {len(stages)} 个阶段") |
|
|
|
|
|
|
|
|
for progress in pipeline.run(): |
|
|
|
|
|
emit_progress({ |
|
|
"type": "progress", |
|
|
"status": "running", |
|
|
"stage": progress.get("stage"), |
|
|
"stage_index": progress.get("stage_index"), |
|
|
"total_stages": progress.get("total_stages"), |
|
|
"progress": progress.get("progress", 0.0), |
|
|
"overall_progress": progress.get("overall_progress", 0.0), |
|
|
"message": progress.get("message"), |
|
|
"task_id": task_id, |
|
|
"data": progress.get("data", {}), |
|
|
}) |
|
|
|
|
|
|
|
|
if progress.get("status") == "failed": |
|
|
emit_progress({ |
|
|
"type": "progress", |
|
|
"status": "failed", |
|
|
"stage": progress.get("stage"), |
|
|
"message": progress.get("message", "阶段执行失败"), |
|
|
"task_id": task_id, |
|
|
}) |
|
|
return False |
|
|
|
|
|
|
|
|
emit_progress({ |
|
|
"type": "progress", |
|
|
"status": "completed", |
|
|
"message": "训练流水线执行完成", |
|
|
"task_id": task_id, |
|
|
"progress": 1.0, |
|
|
"overall_progress": 1.0, |
|
|
}) |
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = str(e) |
|
|
error_trace = traceback.format_exc() |
|
|
emit_progress({ |
|
|
"type": "progress", |
|
|
"status": "failed", |
|
|
"message": f"执行出错: {error_msg}", |
|
|
"error": error_msg, |
|
|
"traceback": error_trace, |
|
|
"task_id": task_id, |
|
|
}) |
|
|
return False |
|
|
|
|
|
|
|
|
def main(): |
|
|
"""主函数""" |
|
|
parser = argparse.ArgumentParser(description="执行 GPT-SoVITS 训练流水线") |
|
|
parser.add_argument("--config", required=True, help="配置文件路径 (JSON)") |
|
|
parser.add_argument("--task-id", required=True, help="任务ID") |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
emit_log("info", f"启动训练任务: {args.task_id}") |
|
|
emit_log("info", f"配置文件: {args.config}") |
|
|
|
|
|
try: |
|
|
config = load_config(args.config) |
|
|
except Exception as e: |
|
|
emit_progress({ |
|
|
"type": "progress", |
|
|
"status": "failed", |
|
|
"message": f"加载配置文件失败: {e}", |
|
|
"task_id": args.task_id, |
|
|
}) |
|
|
sys.exit(1) |
|
|
|
|
|
success = run_pipeline(config, args.task_id) |
|
|
sys.exit(0 if success else 1) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|