liumaolin committed on
Commit ·
8178cb9
1
Parent(s): c8b2614
Refactor config classes and inference pipeline to improve path handling, weight management, and modularity
Browse files
training_pipeline/configs.py
CHANGED
|
@@ -3,9 +3,11 @@
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import os
|
|
|
|
| 6 |
import sys
|
| 7 |
from dataclasses import dataclass
|
| 8 |
|
|
|
|
| 9 |
from .enums import ModelVersion
|
| 10 |
|
| 11 |
|
|
@@ -38,7 +40,7 @@ class AudioSliceConfig(BaseConfig):
|
|
| 38 |
|
| 39 |
@property
|
| 40 |
def output_dir(self):
|
| 41 |
-
return os.path.join(self.
|
| 42 |
|
| 43 |
|
| 44 |
@dataclass
|
|
@@ -51,11 +53,11 @@ class ASRConfig(BaseConfig):
|
|
| 51 |
|
| 52 |
@property
|
| 53 |
def input_dir(self):
|
| 54 |
-
return os.path.join(self.
|
| 55 |
|
| 56 |
@property
|
| 57 |
def output_dir(self):
|
| 58 |
-
return os.path.join(self.
|
| 59 |
|
| 60 |
|
| 61 |
@dataclass
|
|
@@ -76,12 +78,13 @@ class FeatureExtractionConfig(BaseConfig):
|
|
| 76 |
@property
|
| 77 |
def inp_text(self):
|
| 78 |
"""标注文件路径"""
|
| 79 |
-
return os.path.join(self.
|
| 80 |
|
| 81 |
@property
|
| 82 |
def inp_wav_dir(self):
|
| 83 |
"""音频目录"""
|
| 84 |
-
return os.path.join(self.
|
|
|
|
| 85 |
|
| 86 |
|
| 87 |
@dataclass
|
|
@@ -99,7 +102,13 @@ class SoVITSTrainConfig(BaseConfig):
|
|
| 99 |
if_grad_ckpt: bool = False
|
| 100 |
lora_rank: int = 32
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
|
|
|
| 103 |
@dataclass
|
| 104 |
class GPTTrainConfig(BaseConfig):
|
| 105 |
"""GPT训练配置"""
|
|
@@ -112,12 +121,43 @@ class GPTTrainConfig(BaseConfig):
|
|
| 112 |
if_dpo: bool = False
|
| 113 |
pretrained_s1: str = 'GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt'
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
@dataclass
|
| 117 |
class InferenceConfig(BaseConfig):
|
| 118 |
"""推理配置"""
|
|
|
|
|
|
|
| 119 |
gpt_path: str = ""
|
| 120 |
sovits_path: str = ""
|
| 121 |
bert_path: str = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
|
| 122 |
cnhubert_base_path: str = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
| 123 |
batched_infer_enabled: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import os
|
| 6 |
+
import pathlib
|
| 7 |
import sys
|
| 8 |
from dataclasses import dataclass
|
| 9 |
|
| 10 |
+
from config import SoVITS_weight_version2root, GPT_weight_version2root
|
| 11 |
from .enums import ModelVersion
|
| 12 |
|
| 13 |
|
|
|
|
| 40 |
|
| 41 |
@property
def output_dir(self):
    """Return the ``slicer_opt`` directory located under ``self.exp_dir``."""
    experiment_dir = self.exp_dir
    return os.path.join(experiment_dir, "slicer_opt")
|
| 44 |
|
| 45 |
|
| 46 |
@dataclass
|
|
|
|
| 53 |
|
| 54 |
@property
def input_dir(self):
    """Return the ``slicer_opt`` directory under ``self.exp_dir`` (the ASR input)."""
    experiment_dir = self.exp_dir
    return os.path.join(experiment_dir, "slicer_opt")
|
| 57 |
|
| 58 |
@property
def output_dir(self):
    """Return the ``asr_opt`` directory located under ``self.exp_dir``."""
    experiment_dir = self.exp_dir
    return os.path.join(experiment_dir, "asr_opt")
|
| 61 |
|
| 62 |
|
| 63 |
@dataclass
|
|
|
|
| 78 |
@property
def inp_text(self):
    """Annotation file path: ``<exp_dir>/asr_opt/slicer_opt.list``."""
    path_parts = (self.exp_dir, 'asr_opt', "slicer_opt.list")
    return os.path.join(*path_parts)
|
| 82 |
|
| 83 |
@property
def inp_wav_dir(self):
    """Audio directory: ``<exp_dir>/slicer_opt``."""
    experiment_dir = self.exp_dir
    return os.path.join(experiment_dir, "slicer_opt")
|
| 87 |
+
|
| 88 |
|
| 89 |
|
| 90 |
@dataclass
|
|
|
|
| 102 |
if_grad_ckpt: bool = False
|
| 103 |
lora_rank: int = 32
|
| 104 |
|
| 105 |
+
@property
def output_dir(self):
    """Return the SoVITS weight directory for this experiment, creating it if needed.

    The path is ``self.exp_dir`` joined with the entry that
    ``SoVITS_weight_version2root`` holds for ``self.version.value``.
    """
    weight_subdir = SoVITS_weight_version2root[self.version.value]
    weight_dir = os.path.join(self.exp_dir, weight_subdir)
    # Accessing the property has a side effect: the directory is created on demand.
    os.makedirs(weight_dir, exist_ok=True)
    return weight_dir
|
| 110 |
|
| 111 |
+
|
| 112 |
@dataclass
|
| 113 |
class GPTTrainConfig(BaseConfig):
|
| 114 |
"""GPT训练配置"""
|
|
|
|
| 121 |
if_dpo: bool = False
|
| 122 |
pretrained_s1: str = 'GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt'
|
| 123 |
|
| 124 |
+
@property
def output_dir(self):
    """Return the GPT weight directory for this experiment, creating it if needed.

    The path is ``self.exp_dir`` joined with the entry that
    ``GPT_weight_version2root`` holds for ``self.version.value``.
    """
    weight_subdir = GPT_weight_version2root[self.version.value]
    weight_dir = os.path.join(self.exp_dir, weight_subdir)
    # Accessing the property has a side effect: the directory is created on demand.
    os.makedirs(weight_dir, exist_ok=True)
    return weight_dir
|
| 129 |
+
|
| 130 |
|
| 131 |
@dataclass
class InferenceConfig(BaseConfig):
    """Inference configuration.

    Holds the model version, optional explicit weight paths, the paths of the
    auxiliary BERT / CN-HuBERT models, and the reference/target text inputs
    for a synthesis run. ``gpt_paths`` and ``sovits_paths`` list the weight
    files found under the experiment directory.
    """
    version: ModelVersion = ModelVersion.V2_PRO

    # Explicit weight files; both default to the empty string.
    gpt_path: str = ""
    sovits_path: str = ""
    bert_path: str = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
    cnhubert_base_path: str = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
    batched_infer_enabled: bool = False

    # Reference prompt (text + audio) and the text to synthesize.
    ref_text: str = ""
    ref_audio_path: str = ""
    target_text: str = ""
    text_split_method: str = "cut1"

    @property
    def output_dir(self):
        """Return the ``inference`` directory located under ``self.exp_dir``."""
        return os.path.join(self.exp_dir, 'inference')

    def _list_weight_files(self, weight_subdir: str) -> list:
        """Return the sorted POSIX paths of the regular files in ``exp_dir/weight_subdir``.

        Returns an empty list when that location is missing or is not a
        directory, so callers never see ``NotADirectoryError``.
        """
        base = pathlib.Path(self.exp_dir) / weight_subdir
        if not base.is_dir():
            return []
        return sorted(item.as_posix() for item in base.iterdir() if item.is_file())

    def gpt_paths(self) -> list:
        """Return all GPT weight paths found for ``self.version``.

        Returns:
            Sorted list of file paths (POSIX style); empty when the weight
            directory does not exist.
        """
        return self._list_weight_files(GPT_weight_version2root[self.version.value])

    def sovits_paths(self) -> list:
        """Return all SoVITS weight paths found for ``self.version``.

        Returns:
            Sorted list of file paths (POSIX style); empty when the weight
            directory does not exist.
        """
        return self._list_weight_files(SoVITS_weight_version2root[self.version.value])
|
training_pipeline/stages/inference.py
CHANGED
|
@@ -4,50 +4,233 @@
|
|
| 4 |
包含:
|
| 5 |
- InferenceStage: TTS推理
|
| 6 |
"""
|
| 7 |
-
|
| 8 |
import os
|
|
|
|
|
|
|
|
|
|
| 9 |
from typing import Dict, Any, Generator
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
from ..base import BaseStage
|
| 12 |
-
from ..enums import StageStatus
|
| 13 |
from ..configs import InferenceConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
class InferenceStage(BaseStage):
|
| 17 |
"""1C - TTS推理"""
|
| 18 |
-
|
| 19 |
def __init__(self, config: InferenceConfig):
|
| 20 |
super().__init__(config)
|
| 21 |
self.config: InferenceConfig = config
|
| 22 |
-
|
| 23 |
@property
|
| 24 |
def name(self) -> str:
|
| 25 |
return "TTS推理"
|
| 26 |
-
|
| 27 |
def validate(self) -> bool:
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
def run(self) -> Generator[Dict[str, Any], None, None]:
|
| 31 |
self._status = StageStatus.RUNNING
|
| 32 |
cfg = self.config
|
| 33 |
-
|
| 34 |
-
#
|
| 35 |
-
os.
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
if cfg.batched_infer_enabled:
|
| 43 |
-
cmd = f'"{cfg.python_exec}" -s GPT_SoVITS/inference_webui_fast.py'
|
| 44 |
else:
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
self.
|
| 52 |
-
yield self._make_progress("推理WebUI已启动", 1.0)
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
包含:
|
| 5 |
- InferenceStage: TTS推理
|
| 6 |
"""
|
|
|
|
| 7 |
import os
|
| 8 |
+
import sys
|
| 9 |
+
from itertools import product
|
| 10 |
+
from pathlib import Path
|
| 11 |
from typing import Dict, Any, Generator
|
| 12 |
|
| 13 |
+
import soundfile as sf
|
| 14 |
+
|
| 15 |
+
from GPT_SoVITS.TTS_infer_pack.TTS import TTS_Config, TTS
|
| 16 |
+
from GPT_SoVITS.utils import HParams
|
| 17 |
from ..base import BaseStage
|
|
|
|
| 18 |
from ..configs import InferenceConfig
|
| 19 |
+
from ..enums import StageStatus
|
| 20 |
+
|
| 21 |
+
# Register a stand-in for a top-level ``utils`` module, but only when no module
# named ``utils`` has been loaded yet. The stand-in exposes a single attribute,
# ``HParams``, taken from ``GPT_SoVITS.utils``.
# NOTE(review): presumably needed so that checkpoints pickled against
# ``utils.HParams`` can be loaded from this package layout — confirm.
if "utils" not in sys.modules:
    class GPTSoVITSFixedUtilsModule:
        # Re-export of GPT_SoVITS.utils.HParams under the ``utils`` namespace.
        HParams = HParams


    # A class object (not a real module) is installed as ``sys.modules['utils']``.
    sys.modules['utils'] = GPTSoVITSFixedUtilsModule
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def create_tts_module(cfg: InferenceConfig):
    """Create a TTS module.

    Builds a ``TTS_Config`` with a single ``"v2"`` section (CPU device, full
    precision) whose weight and auxiliary-model paths are taken from ``cfg``,
    then wraps it in a ``TTS`` instance.

    Args:
        cfg: Inference configuration supplying ``gpt_path``, ``sovits_path``,
            ``cnhubert_base_path`` and ``bert_path``.

    Returns:
        The TTS module instance.
    """
    v2_settings = {
        "device": "cpu",
        "is_half": False,
        "version": "v2",
        "t2s_weights_path": cfg.gpt_path,
        "vits_weights_path": cfg.sovits_path,
        # NOTE(review): the key spelling "cnhuhbert" is kept exactly as written;
        # presumably it matches what TTS_Config expects — confirm.
        "cnhuhbert_base_path": cfg.cnhubert_base_path,
        "bert_base_path": cfg.bert_path,
    }
    return TTS(TTS_Config({"v2": v2_settings}))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def create_inference_config(
        text: str,
        ref_audio_path: str,
        prompt_text: str = "",
        text_lang: str = "zh",
        prompt_lang: str = "zh",
        aux_ref_audio_paths: list = None,
        top_k: int = 15,
        top_p: float = 1.0,
        temperature: float = 1.0,
        text_split_method: str = "cut1",
        batch_size: int = 1,
        batch_threshold: float = 0.75,
        split_bucket: bool = True,
        speed_factor: float = 1.0,
        fragment_interval: float = 0.3,
        seed: int = -1,
        parallel_infer: bool = False,
        repetition_penalty: float = 1.35,
        sample_steps: int = 32,
        super_sampling: bool = False,
        return_fragment: bool = False,
        streaming_mode: bool = False,
        overlap_length: int = 2,
        min_chunk_length: int = 16,
        fixed_length_chunk: bool = False,
) -> Dict[str, Any]:
    """Assemble the inference request dictionary.

    Every argument is copied unchanged into the returned dictionary under a
    key of the same name; ``aux_ref_audio_paths`` becomes a new empty list
    when it is ``None``.

    Args:
        text: Text to be synthesized (required).
        ref_audio_path: Reference audio path (required).
        prompt_text: Prompt text for the reference audio (optional).
        text_lang: Language of the text to be synthesized.
        prompt_lang: Language of the prompt text.
        aux_ref_audio_paths: Auxiliary reference audio paths, used for
            multi-speaker tone fusion (optional).
        top_k: Top-k sampling.
        top_p: Top-p sampling.
        temperature: Sampling temperature.
        text_split_method: Text split method; see text_segmentation_method.py.
        batch_size: Batch size for inference.
        batch_threshold: Threshold for batch splitting.
        split_bucket: Whether to split the batch into multiple buckets.
        speed_factor: Controls the speed of the synthesized audio.
        fragment_interval: Controls the interval between audio fragments.
        seed: Random seed for reproducibility.
        parallel_infer: Whether to use parallel inference.
        repetition_penalty: Repetition penalty for the T2S model.
        sample_steps: Number of sampling steps for the VITS V3 model.
        super_sampling: Whether to use super-sampling with the VITS V3 model.
        return_fragment: Return the audio fragment by fragment (best quality,
            slowest response; old version of streaming mode).
        streaming_mode: Return audio chunk by chunk (medium quality, slow
            response).
        overlap_length: Overlap length of semantic tokens in streaming mode.
        min_chunk_length: Minimum chunk length of semantic tokens in streaming
            mode (affects audio chunk size).
        fixed_length_chunk: Use fixed-length chunks (lower quality, faster
            response).

    Returns:
        The inference configuration dictionary.
    """
    aux_paths = [] if aux_ref_audio_paths is None else aux_ref_audio_paths

    # Keyword order below fixes the key order of the resulting dictionary.
    return dict(
        text=text,
        text_lang=text_lang,
        ref_audio_path=ref_audio_path,
        aux_ref_audio_paths=aux_paths,
        prompt_text=prompt_text,
        prompt_lang=prompt_lang,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        text_split_method=text_split_method,
        batch_size=batch_size,
        batch_threshold=batch_threshold,
        split_bucket=split_bucket,
        speed_factor=speed_factor,
        fragment_interval=fragment_interval,
        seed=seed,
        parallel_infer=parallel_infer,
        repetition_penalty=repetition_penalty,
        sample_steps=sample_steps,
        super_sampling=super_sampling,
        return_fragment=return_fragment,
        streaming_mode=streaming_mode,
        overlap_length=overlap_length,
        min_chunk_length=min_chunk_length,
        fixed_length_chunk=fixed_length_chunk,
    )
|
| 146 |
|
| 147 |
|
| 148 |
class InferenceStage(BaseStage):
    """1C - TTS inference.

    Synthesizes ``config.target_text`` once per (GPT weight, SoVITS weight)
    combination and writes each result to its own WAV file under
    ``config.output_dir``.
    """

    def __init__(self, config: InferenceConfig):
        super().__init__(config)
        self.config: InferenceConfig = config

    @property
    def name(self) -> str:
        """Display name of the stage."""
        return "TTS推理"

    def validate(self) -> bool:
        """Return True when at least one weight combination is available.

        Explicit ``gpt_path`` and ``sovits_path`` take precedence; otherwise
        both weight directories of the experiment must contain a file.
        """
        # Explicit paths were given: use them as they are.
        if self.config.gpt_path and self.config.sovits_path:
            return True
        # Otherwise check whether weights can be discovered in the experiment directory.
        gpt_paths = self.config.gpt_paths()
        sovits_paths = self.config.sovits_paths()
        return len(gpt_paths) > 0 and len(sovits_paths) > 0

    def run(self) -> Generator[Dict[str, Any], None, None]:
        """Run inference for every weight combination, yielding progress records.

        Yields:
            Whatever ``self._make_progress(message, fraction)`` returns: one
            record before the loop, two per combination, and a final one.
        """
        self._status = StageStatus.RUNNING
        cfg = self.config

        # Make sure the output directory exists.
        os.makedirs(cfg.output_dir, exist_ok=True)

        if cfg.gpt_path and cfg.sovits_path:
            # A single explicitly configured pair.
            combinations = [(cfg.gpt_path, cfg.sovits_path)]
        else:
            # Cartesian product of all discovered weights.
            combinations = list(product(cfg.gpt_paths(), cfg.sovits_paths()))

        total_combinations = len(combinations)
        yield self._make_progress(f"共 {total_combinations} 个组合待推理", 0.0)

        saved_count = 0  # files actually written, reported in the final message
        for idx, (gpt_path, sovits_path) in enumerate(combinations):
            # Weight file names without extension, used to build a unique output name.
            gpt_name = Path(gpt_path).stem
            sovits_name = Path(sovits_path).stem

            output_filename = f"{cfg.exp_name}_gpt-{gpt_name}_sovits-{sovits_name}.wav"
            output_path = os.path.join(cfg.output_dir, output_filename)

            yield self._make_progress(
                f"[{idx + 1}/{total_combinations}] GPT: {gpt_name}, SoVITS: {sovits_name}",
                idx / total_combinations,
            )

            # Per-combination config: same settings as cfg, with this pair of
            # weights. version and text_split_method are forwarded so the
            # caller's choices are not silently reset to the defaults.
            temp_cfg = InferenceConfig(
                exp_name=cfg.exp_name,
                exp_root=cfg.exp_root,
                version=cfg.version,
                gpt_path=gpt_path,
                sovits_path=sovits_path,
                bert_path=cfg.bert_path,
                cnhubert_base_path=cfg.cnhubert_base_path,
                ref_text=cfg.ref_text,
                ref_audio_path=cfg.ref_audio_path,
                target_text=cfg.target_text,
                text_split_method=cfg.text_split_method,
            )

            # Create the TTS module and run inference.
            module = create_tts_module(temp_cfg)
            inference_config = create_inference_config(
                text=cfg.target_text,
                ref_audio_path=cfg.ref_audio_path,
                prompt_text=cfg.ref_text,
                text_split_method=cfg.text_split_method,
            )

            # Only the first yielded item is used, as before.
            first_item = next(iter(module.run(inference_config)), None)
            if first_item is None:
                # Nothing was synthesized: do not claim that a file was saved.
                yield self._make_progress(
                    f"[{idx + 1}/{total_combinations}] 未生成音频: {output_filename}",
                    (idx + 1) / total_combinations,
                )
                continue

            sample_rate, audio_data = first_item[0], first_item[1]
            sf.write(output_path, audio_data, sample_rate, subtype='PCM_16')
            saved_count += 1

            yield self._make_progress(
                f"[{idx + 1}/{total_combinations}] 已保存: {output_filename}",
                (idx + 1) / total_combinations,
            )

        self._status = StageStatus.COMPLETED
        yield self._make_progress(f"推理完成,共生成 {saved_count} 个音频文件", 1.0)
|
training_pipeline/stages/training.py
CHANGED
|
@@ -6,16 +6,16 @@
|
|
| 6 |
- GPTTrainStage: GPT模型训练
|
| 7 |
"""
|
| 8 |
|
| 9 |
-
import os
|
| 10 |
import json
|
|
|
|
| 11 |
from typing import Dict, Any, Generator
|
| 12 |
|
| 13 |
import yaml
|
| 14 |
|
| 15 |
from ..base import BaseStage
|
| 16 |
-
from ..enums import StageStatus, ModelVersion
|
| 17 |
from ..configs import SoVITSTrainConfig, GPTTrainConfig
|
| 18 |
-
from
|
|
|
|
| 19 |
|
| 20 |
class SoVITSTrainStage(BaseStage):
|
| 21 |
"""1Ba - SoVITS模型训练"""
|
|
@@ -69,8 +69,7 @@ class SoVITSTrainStage(BaseStage):
|
|
| 69 |
data["train"]["lora_rank"] = cfg.lora_rank
|
| 70 |
data["model"]["version"] = version_str
|
| 71 |
data["data"]["exp_dir"] = data["s2_ckpt_dir"] = s2_dir
|
| 72 |
-
data["save_weight_dir"] =
|
| 73 |
-
os.makedirs(SoVITS_weight_version2root[version_str], exist_ok=True)
|
| 74 |
data["name"] = cfg.exp_name
|
| 75 |
data["version"] = version_str
|
| 76 |
|
|
@@ -136,8 +135,7 @@ class GPTTrainStage(BaseStage):
|
|
| 136 |
data["train"]["if_save_every_weights"] = cfg.if_save_every_weights
|
| 137 |
data["train"]["if_save_latest"] = cfg.if_save_latest
|
| 138 |
data["train"]["if_dpo"] = cfg.if_dpo
|
| 139 |
-
data["train"]["half_weights_save_dir"] =
|
| 140 |
-
os.makedirs(GPT_weight_version2root[cfg.version.value], exist_ok=True)
|
| 141 |
data["train"]["exp_name"] = cfg.exp_name
|
| 142 |
data["train_semantic_path"] = f"{s1_dir}/6-name2semantic.tsv"
|
| 143 |
data["train_phoneme_path"] = f"{s1_dir}/2-name2text.txt"
|
|
|
|
| 6 |
- GPTTrainStage: GPT模型训练
|
| 7 |
"""
|
| 8 |
|
|
|
|
| 9 |
import json
|
| 10 |
+
import os
|
| 11 |
from typing import Dict, Any, Generator
|
| 12 |
|
| 13 |
import yaml
|
| 14 |
|
| 15 |
from ..base import BaseStage
|
|
|
|
| 16 |
from ..configs import SoVITSTrainConfig, GPTTrainConfig
|
| 17 |
+
from ..enums import StageStatus, ModelVersion
|
| 18 |
+
|
| 19 |
|
| 20 |
class SoVITSTrainStage(BaseStage):
|
| 21 |
"""1Ba - SoVITS模型训练"""
|
|
|
|
| 69 |
data["train"]["lora_rank"] = cfg.lora_rank
|
| 70 |
data["model"]["version"] = version_str
|
| 71 |
data["data"]["exp_dir"] = data["s2_ckpt_dir"] = s2_dir
|
| 72 |
+
data["save_weight_dir"] = cfg.output_dir
|
|
|
|
| 73 |
data["name"] = cfg.exp_name
|
| 74 |
data["version"] = version_str
|
| 75 |
|
|
|
|
| 135 |
data["train"]["if_save_every_weights"] = cfg.if_save_every_weights
|
| 136 |
data["train"]["if_save_latest"] = cfg.if_save_latest
|
| 137 |
data["train"]["if_dpo"] = cfg.if_dpo
|
| 138 |
+
data["train"]["half_weights_save_dir"] = cfg.output_dir
|
|
|
|
| 139 |
data["train"]["exp_name"] = cfg.exp_name
|
| 140 |
data["train_semantic_path"] = f"{s1_dir}/6-name2semantic.tsv"
|
| 141 |
data["train_phoneme_path"] = f"{s1_dir}/2-name2text.txt"
|