liumaolin committed on
Commit
8178cb9
·
1 Parent(s): c8b2614

Refactor config classes and inference pipeline to improve path handling, weight management, and modularity

Browse files
training_pipeline/configs.py CHANGED
@@ -3,9 +3,11 @@
3
  """
4
 
5
  import os
 
6
  import sys
7
  from dataclasses import dataclass
8
 
 
9
  from .enums import ModelVersion
10
 
11
 
@@ -38,7 +40,7 @@ class AudioSliceConfig(BaseConfig):
38
 
39
  @property
40
  def output_dir(self):
41
- return os.path.join(self.exp_root, self.exp_name, "slicer_opt")
42
 
43
 
44
  @dataclass
@@ -51,11 +53,11 @@ class ASRConfig(BaseConfig):
51
 
52
  @property
53
  def input_dir(self):
54
- return os.path.join(self.exp_root, self.exp_name, "slicer_opt")
55
 
56
  @property
57
  def output_dir(self):
58
- return os.path.join(self.exp_root, self.exp_name, "asr_opt")
59
 
60
 
61
  @dataclass
@@ -76,12 +78,13 @@ class FeatureExtractionConfig(BaseConfig):
76
  @property
77
  def inp_text(self):
78
  """标注文件路径"""
79
- return os.path.join(self.exp_root, self.exp_name, "slicer_opt.list")
80
 
81
  @property
82
  def inp_wav_dir(self):
83
  """音频目录"""
84
- return os.path.join(self.exp_root, self.exp_name, "slicer_opt")
 
85
 
86
 
87
  @dataclass
@@ -99,7 +102,13 @@ class SoVITSTrainConfig(BaseConfig):
99
  if_grad_ckpt: bool = False
100
  lora_rank: int = 32
101
 
 
 
 
 
 
102
 
 
103
  @dataclass
104
  class GPTTrainConfig(BaseConfig):
105
  """GPT训练配置"""
@@ -112,12 +121,43 @@ class GPTTrainConfig(BaseConfig):
112
  if_dpo: bool = False
113
  pretrained_s1: str = 'GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt'
114
 
 
 
 
 
 
 
115
 
116
  @dataclass
117
  class InferenceConfig(BaseConfig):
118
  """推理配置"""
 
 
119
  gpt_path: str = ""
120
  sovits_path: str = ""
121
  bert_path: str = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
122
  cnhubert_base_path: str = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
123
  batched_infer_enabled: bool = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  """
4
 
5
  import os
6
+ import pathlib
7
  import sys
8
  from dataclasses import dataclass
9
 
10
+ from config import SoVITS_weight_version2root, GPT_weight_version2root
11
  from .enums import ModelVersion
12
 
13
 
 
40
 
41
  @property
42
  def output_dir(self):
43
+ return os.path.join(self.exp_dir, "slicer_opt")
44
 
45
 
46
  @dataclass
 
53
 
54
  @property
55
  def input_dir(self):
56
+ return os.path.join(self.exp_dir, "slicer_opt")
57
 
58
  @property
59
  def output_dir(self):
60
+ return os.path.join(self.exp_dir, "asr_opt")
61
 
62
 
63
  @dataclass
 
78
  @property
79
  def inp_text(self):
80
  """标注文件路径"""
81
+ return os.path.join(self.exp_dir, 'asr_opt', "slicer_opt.list")
82
 
83
  @property
84
  def inp_wav_dir(self):
85
  """音频目录"""
86
+ return os.path.join(self.exp_dir, "slicer_opt")
87
+
88
 
89
 
90
  @dataclass
 
102
  if_grad_ckpt: bool = False
103
  lora_rank: int = 32
104
 
105
+ @property
106
+ def output_dir(self):
107
+ _output_dir = os.path.join(self.exp_dir, SoVITS_weight_version2root[self.version.value])
108
+ os.makedirs(_output_dir, exist_ok=True)
109
+ return _output_dir
110
 
111
+
112
  @dataclass
113
  class GPTTrainConfig(BaseConfig):
114
  """GPT训练配置"""
 
121
  if_dpo: bool = False
122
  pretrained_s1: str = 'GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt'
123
 
124
+ @property
125
+ def output_dir(self):
126
+ _output_dir = os.path.join(self.exp_dir, GPT_weight_version2root[self.version.value])
127
+ os.makedirs(_output_dir, exist_ok=True)
128
+ return _output_dir
129
+
130
 
131
  @dataclass
132
  class InferenceConfig(BaseConfig):
133
  """推理配置"""
134
+ version: ModelVersion = ModelVersion.V2_PRO
135
+
136
  gpt_path: str = ""
137
  sovits_path: str = ""
138
  bert_path: str = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
139
  cnhubert_base_path: str = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
140
  batched_infer_enabled: bool = False
141
+
142
+ ref_text: str = ""
143
+ ref_audio_path: str = ""
144
+ target_text: str = ""
145
+ text_split_method: str = "cut1"
146
+
147
+ @property
148
+ def output_dir(self):
149
+ return os.path.join(self.exp_dir, 'inference')
150
+
151
+ def gpt_paths(self) -> list:
152
+ """获取所有 GPT 权重路径"""
153
+ base = pathlib.Path(self.exp_dir) / GPT_weight_version2root[self.version.value]
154
+ if not base.exists():
155
+ return []
156
+ return sorted([item.as_posix() for item in base.iterdir() if item.is_file()])
157
+
158
+ def sovits_paths(self) -> list:
159
+ """获取所有 SoVITS 权重路径"""
160
+ base = pathlib.Path(self.exp_dir) / SoVITS_weight_version2root[self.version.value]
161
+ if not base.exists():
162
+ return []
163
+ return sorted([item.as_posix() for item in base.iterdir() if item.is_file()])
training_pipeline/stages/inference.py CHANGED
@@ -4,50 +4,233 @@
4
  包含:
5
  - InferenceStage: TTS推理
6
  """
7
-
8
  import os
 
 
 
9
  from typing import Dict, Any, Generator
10
 
 
 
 
 
11
  from ..base import BaseStage
12
- from ..enums import StageStatus
13
  from ..configs import InferenceConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  class InferenceStage(BaseStage):
17
  """1C - TTS推理"""
18
-
19
  def __init__(self, config: InferenceConfig):
20
  super().__init__(config)
21
  self.config: InferenceConfig = config
22
-
23
  @property
24
  def name(self) -> str:
25
  return "TTS推理"
26
-
27
  def validate(self) -> bool:
28
- return self.config.gpt_path != "" and self.config.sovits_path != ""
29
-
 
 
 
 
 
 
30
  def run(self) -> Generator[Dict[str, Any], None, None]:
31
  self._status = StageStatus.RUNNING
32
  cfg = self.config
33
-
34
- # 设置环境变量
35
- os.environ["gpt_path"] = cfg.gpt_path
36
- os.environ["sovits_path"] = cfg.sovits_path
37
- os.environ["cnhubert_base_path"] = cfg.cnhubert_base_path
38
- os.environ["bert_path"] = cfg.bert_path
39
- os.environ["_CUDA_VISIBLE_DEVICES"] = cfg.gpu_numbers
40
- os.environ["is_half"] = str(cfg.is_half)
41
-
42
- if cfg.batched_infer_enabled:
43
- cmd = f'"{cfg.python_exec}" -s GPT_SoVITS/inference_webui_fast.py'
44
  else:
45
- cmd = f'"{cfg.python_exec}" -s GPT_SoVITS/inference_webui.py'
46
-
47
- yield self._make_progress("推理WebUI启动中...", 0.5)
48
-
49
- self._process = self._run_command(cmd, wait=False)
50
-
51
- self._status = StageStatus.COMPLETED
52
- yield self._make_progress("推理WebUI已启动", 1.0)
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  包含:
5
  - InferenceStage: TTS推理
6
  """
 
7
  import os
8
+ import sys
9
+ from itertools import product
10
+ from pathlib import Path
11
  from typing import Dict, Any, Generator
12
 
13
+ import soundfile as sf
14
+
15
+ from GPT_SoVITS.TTS_infer_pack.TTS import TTS_Config, TTS
16
+ from GPT_SoVITS.utils import HParams
17
  from ..base import BaseStage
 
18
  from ..configs import InferenceConfig
19
+ from ..enums import StageStatus
20
+
21
+ if "utils" not in sys.modules:
22
+ class GPTSoVITSFixedUtilsModule:
23
+ HParams = HParams
24
+
25
+
26
+ sys.modules['utils'] = GPTSoVITSFixedUtilsModule
27
+
28
+
29
+ def create_tts_module(cfg: InferenceConfig):
30
+ """创建 TTS 模块
31
+
32
+ Args:
33
+ cfg: 推理配置
34
+
35
+ Returns:
36
+ TTS 模块实例
37
+ """
38
+
39
+ tts_config = TTS_Config({
40
+ "v2": {
41
+ "device": "cpu",
42
+ "is_half": False,
43
+ "version": "v2",
44
+ "t2s_weights_path": cfg.gpt_path,
45
+ "vits_weights_path": cfg.sovits_path,
46
+ "cnhuhbert_base_path": cfg.cnhubert_base_path,
47
+ "bert_base_path": cfg.bert_path,
48
+ },
49
+ })
50
+ return TTS(tts_config)
51
+
52
+
53
+ def create_inference_config(
54
+ text: str,
55
+ ref_audio_path: str,
56
+ prompt_text: str = "",
57
+ text_lang: str = "zh",
58
+ prompt_lang: str = "zh",
59
+ aux_ref_audio_paths: list = None,
60
+ top_k: int = 15,
61
+ top_p: float = 1.0,
62
+ temperature: float = 1.0,
63
+ text_split_method: str = "cut1",
64
+ batch_size: int = 1,
65
+ batch_threshold: float = 0.75,
66
+ split_bucket: bool = True,
67
+ speed_factor: float = 1.0,
68
+ fragment_interval: float = 0.3,
69
+ seed: int = -1,
70
+ parallel_infer: bool = False,
71
+ repetition_penalty: float = 1.35,
72
+ sample_steps: int = 32,
73
+ super_sampling: bool = False,
74
+ return_fragment: bool = False,
75
+ streaming_mode: bool = False,
76
+ overlap_length: int = 2,
77
+ min_chunk_length: int = 16,
78
+ fixed_length_chunk: bool = False,
79
+ ) -> Dict[str, Any]:
80
+ """创建推理配置
81
+
82
+ Args:
83
+ text: 要合成的文本
84
+ ref_audio_path: 参考音频路径
85
+ prompt_text: 参考音频的提示文本
86
+ text_lang: 文本语言
87
+ prompt_lang: 提示文本语言
88
+ aux_ref_audio_paths: 辅助参考音频路径列表,用于多说话人音色融合
89
+ top_k: top k 采样
90
+ top_p: top p 采样
91
+ temperature: 采样温度
92
+ text_split_method: 文本分割方法,详见 text_segmentation_method.py
93
+ batch_size: 推理批次大小
94
+ batch_threshold: 批次分割阈值
95
+ split_bucket: 是否将批次分割成多个桶
96
+ speed_factor: 控制合成音频的速度
97
+ fragment_interval: 控制音频片段的间隔
98
+ seed: 随机种子,用于可复现性
99
+ parallel_infer: 是否使用并行推理
100
+ repetition_penalty: T2S 模型的重复惩罚
101
+ sample_steps: VITS V3 模型的采样步数
102
+ super_sampling: VITS V3 模型是否使用超采样
103
+ return_fragment: 是否逐步返回音频片段(最佳质量,最慢响应)
104
+ streaming_mode: 是否按块返回音频(中等质量,较慢响应)
105
+ overlap_length: 流式模式下语义 token 的重叠长度
106
+ min_chunk_length: 流式模式下语义 token 的最小块长度
107
+ fixed_length_chunk: 是否使用固定长度块(较低质量,更快响应)
108
+
109
+ Returns:
110
+ 推理配置字典
111
+ """
112
+ if aux_ref_audio_paths is None:
113
+ aux_ref_audio_paths = []
114
+
115
+ return {
116
+ "text": text, # str.(required) text to be synthesized
117
+ "text_lang": text_lang, # str.(required) language of the text to be synthesized
118
+ "ref_audio_path": ref_audio_path, # str.(required) reference audio path
119
+ "aux_ref_audio_paths": aux_ref_audio_paths,
120
+ # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion
121
+ "prompt_text": prompt_text, # str.(optional) prompt text for the reference audio
122
+ "prompt_lang": prompt_lang, # str.(required) language of the prompt text for the reference audio
123
+ "top_k": top_k, # int. top k sampling
124
+ "top_p": top_p, # float. top p sampling
125
+ "temperature": temperature, # float. temperature for sampling
126
+ "text_split_method": text_split_method, # str. text split method, see text_segmentation_method.py for details.
127
+ "batch_size": batch_size, # int. batch size for inference
128
+ "batch_threshold": batch_threshold, # float. threshold for batch splitting.
129
+ "split_bucket": split_bucket, # bool. whether to split the batch into multiple buckets.
130
+ "speed_factor": speed_factor, # float. control the speed of the synthesized audio.
131
+ "fragment_interval": fragment_interval, # float. to control the interval of the audio fragment.
132
+ "seed": seed, # int. random seed for reproducibility.
133
+ "parallel_infer": parallel_infer, # bool. whether to use parallel inference.
134
+ "repetition_penalty": repetition_penalty, # float. repetition penalty for T2S model.
135
+ "sample_steps": sample_steps, # int. number of sampling steps for VITS model V3.
136
+ "super_sampling": super_sampling, # bool. whether to use super-sampling for audio when using VITS model V3.
137
+ "return_fragment": return_fragment,
138
+ # bool. step by step return the audio fragment. (Best Quality, Slowest response speed. old version of streaming mode)
139
+ "streaming_mode": streaming_mode, # bool. return audio chunk by chunk. (Medium quality, Slow response speed)
140
+ "overlap_length": overlap_length, # int. overlap length of semantic tokens for streaming mode.
141
+ "min_chunk_length": min_chunk_length,
142
+ # int. The minimum chunk length of semantic tokens for streaming mode. (affects audio chunk size)
143
+ "fixed_length_chunk": fixed_length_chunk,
144
+ # bool. When turned on, it can achieve faster streaming response, but with lower quality. (lower quality, faster response speed)
145
+ }
146
 
147
 
148
  class InferenceStage(BaseStage):
149
  """1C - TTS推理"""
150
+
151
  def __init__(self, config: InferenceConfig):
152
  super().__init__(config)
153
  self.config: InferenceConfig = config
154
+
155
  @property
156
  def name(self) -> str:
157
  return "TTS推理"
158
+
159
  def validate(self) -> bool:
160
+ # 如果指定了具体路径,使用指定的路径
161
+ if self.config.gpt_path and self.config.sovits_path:
162
+ return True
163
+ # 否则检查是否能从实验目录获取路径
164
+ gpt_paths = self.config.gpt_paths()
165
+ sovits_paths = self.config.sovits_paths()
166
+ return len(gpt_paths) > 0 and len(sovits_paths) > 0
167
+
168
  def run(self) -> Generator[Dict[str, Any], None, None]:
169
  self._status = StageStatus.RUNNING
170
  cfg = self.config
171
+
172
+ # 确保输出目录存在
173
+ os.makedirs(cfg.output_dir, exist_ok=True)
174
+
175
+ # 获取所有权重路径
176
+ if cfg.gpt_path and cfg.sovits_path:
177
+ # 使用指定的单一路径
178
+ combinations = [(cfg.gpt_path, cfg.sovits_path)]
 
 
 
179
  else:
180
+ # 获取所有路径进行排列组合
181
+ gpt_paths = cfg.gpt_paths()
182
+ sovits_paths = cfg.sovits_paths()
183
+ combinations = list(product(gpt_paths, sovits_paths))
184
+
185
+ total_combinations = len(combinations)
186
+ yield self._make_progress(f"共 {total_combinations} 个组合待推理", 0.0)
 
187
 
188
+ for idx, (gpt_path, sovits_path) in enumerate(combinations):
189
+ # 提取权重文件名(不含扩展名)
190
+ gpt_name = Path(gpt_path).stem
191
+ sovits_name = Path(sovits_path).stem
192
+
193
+ # 生成独立的输出文件名
194
+ output_filename = f"{cfg.exp_name}_gpt-{gpt_name}_sovits-{sovits_name}.wav"
195
+ output_path = os.path.join(cfg.output_dir, output_filename)
196
+
197
+ progress = (idx / total_combinations)
198
+ yield self._make_progress(
199
+ f"[{idx + 1}/{total_combinations}] GPT: {gpt_name}, SoVITS: {sovits_name}",
200
+ progress
201
+ )
202
+
203
+ # 创建临时配置用于当前组合
204
+ temp_cfg = InferenceConfig(
205
+ exp_name=cfg.exp_name,
206
+ exp_root=cfg.exp_root,
207
+ gpt_path=gpt_path,
208
+ sovits_path=sovits_path,
209
+ bert_path=cfg.bert_path,
210
+ cnhubert_base_path=cfg.cnhubert_base_path,
211
+ ref_text=cfg.ref_text,
212
+ ref_audio_path=cfg.ref_audio_path,
213
+ target_text=cfg.target_text,
214
+ )
215
+
216
+ # 创建 TTS 模块并推理
217
+ module = create_tts_module(temp_cfg)
218
+ inference_config = create_inference_config(
219
+ text=cfg.target_text,
220
+ ref_audio_path=cfg.ref_audio_path,
221
+ prompt_text=cfg.ref_text,
222
+ )
223
+
224
+ for item in module.run(inference_config):
225
+ sample_rate, audio_data = item[0], item[1]
226
+ # 保存到独立的输出文件
227
+ sf.write(output_path, audio_data, sample_rate, subtype='PCM_16')
228
+ break
229
+
230
+ yield self._make_progress(
231
+ f"[{idx + 1}/{total_combinations}] 已保存: {output_filename}",
232
+ (idx + 1) / total_combinations
233
+ )
234
+
235
+ self._status = StageStatus.COMPLETED
236
+ yield self._make_progress(f"推理完成,共生成 {total_combinations} 个音频文件", 1.0)
training_pipeline/stages/training.py CHANGED
@@ -6,16 +6,16 @@
6
  - GPTTrainStage: GPT模型训练
7
  """
8
 
9
- import os
10
  import json
 
11
  from typing import Dict, Any, Generator
12
 
13
  import yaml
14
 
15
  from ..base import BaseStage
16
- from ..enums import StageStatus, ModelVersion
17
  from ..configs import SoVITSTrainConfig, GPTTrainConfig
18
- from config import SoVITS_weight_version2root, GPT_weight_version2root
 
19
 
20
  class SoVITSTrainStage(BaseStage):
21
  """1Ba - SoVITS模型训练"""
@@ -69,8 +69,7 @@ class SoVITSTrainStage(BaseStage):
69
  data["train"]["lora_rank"] = cfg.lora_rank
70
  data["model"]["version"] = version_str
71
  data["data"]["exp_dir"] = data["s2_ckpt_dir"] = s2_dir
72
- data["save_weight_dir"] = SoVITS_weight_version2root[version_str]
73
- os.makedirs(SoVITS_weight_version2root[version_str], exist_ok=True)
74
  data["name"] = cfg.exp_name
75
  data["version"] = version_str
76
 
@@ -136,8 +135,7 @@ class GPTTrainStage(BaseStage):
136
  data["train"]["if_save_every_weights"] = cfg.if_save_every_weights
137
  data["train"]["if_save_latest"] = cfg.if_save_latest
138
  data["train"]["if_dpo"] = cfg.if_dpo
139
- data["train"]["half_weights_save_dir"] = GPT_weight_version2root[cfg.version.value]
140
- os.makedirs(GPT_weight_version2root[cfg.version.value], exist_ok=True)
141
  data["train"]["exp_name"] = cfg.exp_name
142
  data["train_semantic_path"] = f"{s1_dir}/6-name2semantic.tsv"
143
  data["train_phoneme_path"] = f"{s1_dir}/2-name2text.txt"
 
6
  - GPTTrainStage: GPT模型训练
7
  """
8
 
 
9
  import json
10
+ import os
11
  from typing import Dict, Any, Generator
12
 
13
  import yaml
14
 
15
  from ..base import BaseStage
 
16
  from ..configs import SoVITSTrainConfig, GPTTrainConfig
17
+ from ..enums import StageStatus, ModelVersion
18
+
19
 
20
  class SoVITSTrainStage(BaseStage):
21
  """1Ba - SoVITS模型训练"""
 
69
  data["train"]["lora_rank"] = cfg.lora_rank
70
  data["model"]["version"] = version_str
71
  data["data"]["exp_dir"] = data["s2_ckpt_dir"] = s2_dir
72
+ data["save_weight_dir"] = cfg.output_dir
 
73
  data["name"] = cfg.exp_name
74
  data["version"] = version_str
75
 
 
135
  data["train"]["if_save_every_weights"] = cfg.if_save_every_weights
136
  data["train"]["if_save_latest"] = cfg.if_save_latest
137
  data["train"]["if_dpo"] = cfg.if_dpo
138
+ data["train"]["half_weights_save_dir"] = cfg.output_dir
 
139
  data["train"]["exp_name"] = cfg.exp_name
140
  data["train_semantic_path"] = f"{s1_dir}/6-name2semantic.tsv"
141
  data["train_phoneme_path"] = f"{s1_dir}/2-name2text.txt"