| | import copy |
| | import io |
| | import json |
| | import logging |
| | import re |
| | from typing import List, Union |
| |
|
| | import numpy as np |
| | from box import Box |
| | from pydub import AudioSegment |
| | from scipy.io import wavfile |
| |
|
| | from modules import generate_audio |
| | from modules.api.utils import calc_spk_style |
| | from modules.normalization import text_normalize |
| | from modules.SentenceSplitter import SentenceSplitter |
| | from modules.speaker import Speaker |
| | from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLContext, SSMLSegment |
| | from modules.utils import rng |
| | from modules.utils.audio import apply_prosody_to_audio_segment |
| |
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
| |
|
| |
|
def audio_data_to_segment_slow(audio_data, sr):
    """Convert raw samples to an ``AudioSegment`` via an in-memory WAV round trip.

    ``scipy.io.wavfile.write`` encodes ``audio_data`` at sample rate ``sr`` into a
    buffer, which pydub then decodes. No normalization or int16 conversion is
    applied here; the WAV sample format follows ``audio_data``'s dtype.
    """
    wav_buffer = io.BytesIO()
    wavfile.write(wav_buffer, sr, audio_data)
    wav_buffer.seek(0)  # rewind so pydub reads the WAV header from the start
    return AudioSegment.from_file(wav_buffer, format="wav")
| |
|
| |
|
def clip_audio(audio_data: np.ndarray, threshold: float = 0.99):
    """Hard-limit samples to the closed range ``[-threshold, threshold]``.

    Returns a new array from ``np.clip``; the input is not modified.
    """
    lower, upper = -threshold, threshold
    return np.clip(audio_data, lower, upper)
| |
|
| |
|
def normalize_audio(audio_data: np.ndarray, norm_factor: float = 0.8):
    """Peak-normalize ``audio_data`` so its largest absolute sample equals ``norm_factor``.

    Args:
        audio_data: samples to scale; anything accepted by ``np.asarray``.
        norm_factor: target peak amplitude after scaling.

    Returns:
        The scaled array. Empty input, all-zero input, and input whose peak is
        NaN are returned unscaled, because there is no positive peak to divide by.
    """
    audio_data = np.asarray(audio_data)
    if audio_data.size == 0:
        # np.max raises ValueError on an empty array; nothing to normalize.
        return audio_data
    max_amplitude = np.max(np.abs(audio_data))
    if max_amplitude > 0:
        audio_data = audio_data / max_amplitude * norm_factor
    return audio_data
| |
|
| |
|
def audio_data_to_segment(audio_data: np.ndarray, sr: int):
    """
    Convert samples to a 16-bit mono pydub ``AudioSegment`` directly from raw
    bytes, avoiding the in-memory WAV round trip of the slow variant.

    The samples are peak-normalized, clipped to +/-0.99, scaled by 32767 and
    cast to int16. NOTE(review): assumes ``audio_data`` is a 1-D float waveform;
    the segment is always built with ``channels=1``.

    optimize: https://github.com/lenML/ChatTTS-Forge/issues/57
    """
    # Clipping after normalization is a safety net before integer quantization.
    limited = clip_audio(normalize_audio(audio_data))
    pcm16 = (limited * 32767).astype(np.int16)
    return AudioSegment(
        pcm16.tobytes(),
        frame_rate=sr,
        sample_width=pcm16.dtype.itemsize,  # 2 bytes per int16 sample
        channels=1,
    )
| |
|
| |
|
def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
    """Concatenate ``audio_segments`` end-to-end in list order.

    Starts from ``AudioSegment.empty()`` and adds each segment with ``+``, so an
    empty list yields an empty segment.
    """
    return sum(audio_segments, AudioSegment.empty())
| |
|
| |
|
def to_number(value, t, default=0):
    """Convert ``value`` with the callable ``t`` (e.g. ``int`` or ``float``), or return ``default``.

    Conversion failures never raise. All of the following fall back to ``default``:
    ``ValueError`` (e.g. ``int("abc")``), ``TypeError`` (e.g. ``float(None)``), and
    ``OverflowError`` (e.g. ``int(float("inf"))``).
    """
    try:
        return t(value)
    except (ValueError, TypeError, OverflowError):
        return default
| |
|
| |
|
class TTSAudioSegment(Box):
    """Generation parameters for one synthesis unit (``_type`` "voice" or "break").

    Keys: ``_type``, ``text``, ``temperature``, ``top_P``, ``top_K``, ``spk``,
    ``infer_seed``, ``prompt1``, ``prompt2``, ``prefix``. Any of these the caller
    does not supply, either as a keyword argument or inside a positional mapping,
    is filled with the default listed in ``__init__``. Supplied values are kept
    as given.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Only fill keys that are still missing. Box's constructor has already
        # stored the entries from ``args`` and ``kwargs``. Re-assigning from
        # ``kwargs`` alone would overwrite values passed via a positional mapping.
        for key, default in (
            ("_type", "voice"),
            ("text", ""),
            ("temperature", 0.3),
            ("top_P", 0.5),
            ("top_K", 20),
            ("spk", -1),  # -1 = no speaker chosen
            ("infer_seed", -1),  # -1 = no seed chosen
            ("prompt1", ""),
            ("prompt2", ""),
            ("prefix", ""),
        ):
            if key not in self:
                setattr(self, key, default)
| |
|
| |
|
class SynthesizeSegments:
    """Turn SSML voice segments and breaks into pydub ``AudioSegment`` objects.

    Pipeline (see ``synthesize_segments``):

    1. Split voice text into sentences.
    2. Group segments whose generation parameters match (ignoring text) into buckets.
    3. Render breaks as silence.
    4. Generate each voice bucket in batches of ``batch_size``.
    """

    # Break control tokens. A text containing any of them does not get
    # ``self.eos`` appended (see ``_append_eos``).
    _EOS_TOKENS = ("[uv_break]", "[v_break]", "[lbreak]", "[llbreak]")

    # Matches one bracketed control token such as ``[uv_break]``.
    _CONTROL_TOKEN_RE = re.compile(r"\[[^\]]+?\]")

    def __init__(self, batch_size: int = 8, eos="", spliter_thr=100):
        """
        Args:
            batch_size: maximum number of texts per ``generate_audio_batch`` call.
            eos: suffix appended to texts that contain none of ``_EOS_TOKENS``.
            spliter_thr: ``threshold`` passed to ``SentenceSplitter``.
        """
        self.batch_size = batch_size
        # Fallback values drawn once per instance. Every segment without an
        # explicit speaker / inference seed receives these same values.
        self.batch_default_spk_seed = rng.np_rng()
        self.batch_default_infer_seed = rng.np_rng()
        self.eos = eos
        self.spliter_thr = spliter_thr

    @staticmethod
    def _positions_by_identity(src_segments) -> dict:
        """Map ``id(segment)`` to every index at which that object occurs in ``src_segments``.

        This replaces ``list.index``, which had three problems:
        - it is O(n) per lookup;
        - it matches by ``==``, so distinct segments that compare equal would
          share one slot;
        - it only returns the first hit, so an object reused in the list
          (e.g. the same break twice) left later slots as ``None``.
        """
        positions = {}
        for index, seg in enumerate(src_segments):
            positions.setdefault(id(seg), []).append(index)
        return positions

    def segment_to_generate_params(
        self, segment: Union[SSMLSegment, SSMLBreak]
    ) -> TTSAudioSegment:
        """Resolve the generation parameters for a single segment.

        - A break maps to ``TTSAudioSegment(_type="break")``.
        - A segment with non-None ``params`` uses them verbatim, with the text
          taken from the segment. No normalization is done and no default
          seeds are applied.
        - Otherwise parameters come from ``segment.attrs`` combined with
          ``calc_spk_style(spk, style)``:
          - Missing or unparsable numeric attrs fall back to temperature 0.3,
            top_P 0.5 and top_K 20.
          - The text is normalized unless ``attrs.normalize == "False"``.
          - A spk / infer_seed of -1 is replaced with this instance's
            batch-default values.
        """
        if isinstance(segment, SSMLBreak):
            return TTSAudioSegment(_type="break")

        params = segment.get("params", None)
        if params is not None:
            text = segment.get("text", None) or segment.text or ""
            # Merge instead of ``**params, text=text``: that form raises
            # TypeError when ``params`` itself carries a "text" key.
            return TTSAudioSegment(**{**params, "text": text})

        text = segment.get("text", None) or segment.text or ""
        is_end = segment.get("is_end", False)
        text = str(text).strip()

        attrs = segment.attrs
        spk = attrs.spk
        style = attrs.style

        # Speaker/style preset; a "spk" provided by the preset overrides attrs.spk.
        ss_params = calc_spk_style(spk, style)
        if "spk" in ss_params:
            spk = ss_params["spk"]

        seed = to_number(attrs.seed, int, ss_params.get("seed") or -1)
        top_k = to_number(attrs.top_k, int, None)
        top_p = to_number(attrs.top_p, float, None)
        temp = to_number(attrs.temp, float, None)

        prompt1 = attrs.prompt1 or ss_params.get("prompt1")
        prompt2 = attrs.prompt2 or ss_params.get("prompt2")
        prefix = attrs.prefix or ss_params.get("prefix")
        disable_normalize = attrs.get("normalize", "") == "False"

        seg = TTSAudioSegment(
            _type="voice",
            text=text,
            temperature=temp if temp is not None else 0.3,
            top_P=top_p if top_p is not None else 0.5,
            top_K=top_k if top_k is not None else 20,
            spk=spk if spk else -1,
            infer_seed=seed if seed else -1,
            prompt1=prompt1 if prompt1 else "",
            prompt2=prompt2 if prompt2 else "",
            prefix=prefix if prefix else "",
        )

        if not disable_normalize:
            seg.text = text_normalize(text, is_end=is_end)

        # -1 means "not specified". Use the per-instance defaults so that all
        # such segments share a speaker/seed, and therefore a bucket.
        if seg.spk == -1:
            seg.spk = self.batch_default_spk_seed
        if seg.infer_seed == -1:
            seg.infer_seed = self.batch_default_infer_seed

        return seg

    def process_break_segments(
        self,
        src_segments: List[SSMLBreak],
        bucket_segments: List[SSMLBreak],
        audio_segments: List[AudioSegment],
    ):
        """Fill every slot of ``audio_segments`` that holds a break with silence.

        ``int(segment.attrs.duration)`` is passed to ``AudioSegment.silent``,
        whose ``duration`` is in milliseconds.
        """
        positions = self._positions_by_identity(src_segments)
        for segment in bucket_segments:
            silence = AudioSegment.silent(duration=int(segment.attrs.duration))
            for index in positions[id(segment)]:
                audio_segments[index] = silence

    def _append_eos(self, text: str) -> str:
        """Strip ``text`` and append ``self.eos`` unless it already has a break token."""
        text = text.strip()
        if not any(token in text for token in self._EOS_TOKENS):
            text += self.eos
        return text

    def process_voice_segments(
        self,
        src_segments: List[SSMLSegment],
        bucket: List[SSMLSegment],
        audio_segments: List[AudioSegment],
    ):
        """Generate audio for one bucket of voice segments and store it in place.

        Segments in a bucket share generation parameters (see
        ``bucket_segments``). Each batch of up to ``batch_size`` texts is
        therefore generated with the parameters of the batch's first segment.
        Each segment's ``rate``/``volume``/``pitch`` is then applied to its
        own audio.
        """
        positions = self._positions_by_identity(src_segments)
        for start in range(0, len(bucket), self.batch_size):
            batch = bucket[start : start + self.batch_size]
            param_arr = [self.segment_to_generate_params(seg) for seg in batch]
            texts = [self._append_eos(p.text) for p in param_arr]

            params = param_arr[0]
            audio_datas = generate_audio.generate_audio_batch(
                texts=texts,
                temperature=params.temperature,
                top_P=params.top_P,
                top_K=params.top_K,
                spk=params.spk,
                infer_seed=params.infer_seed,
                prompt1=params.prompt1,
                prompt2=params.prompt2,
                prefix=params.prefix,
            )
            for idx, segment in enumerate(batch):
                # Indexed (not zipped) so a short result list still raises.
                sr, audio_data = audio_datas[idx]
                rate = float(segment.get("rate", "1.0"))
                volume = float(segment.get("volume", "0"))
                pitch = float(segment.get("pitch", "0"))

                audio_segment = audio_data_to_segment(audio_data, sr)
                audio_segment = apply_prosody_to_audio_segment(
                    audio_segment, rate=rate, volume=volume, pitch=pitch
                )
                for index in positions[id(segment)]:
                    audio_segments[index] = audio_segment

    def bucket_segments(
        self, segments: List[Union[SSMLSegment, SSMLBreak]]
    ) -> dict[str, List[Union[SSMLSegment, SSMLBreak]]]:
        """Group segments by their resolved generation parameters.

        Returns a dict that maps a JSON key to the segments sharing it, in
        input order. The key encodes all resolved parameters except ``text``,
        with sorted keys. Breaks are collected under the special key
        ``"<break>"``, which is always present.
        """
        buckets = {"<break>": []}
        for segment in segments:
            if isinstance(segment, SSMLBreak):
                buckets["<break>"].append(segment)
                continue

            params = self.segment_to_generate_params(segment)
            # Key Speaker objects by their id so the parameters can be
            # JSON-encoded (presumably Speaker itself is not serializable).
            if isinstance(params.spk, Speaker):
                params.spk = str(params.spk.id)

            key = json.dumps(
                {k: v for k, v in params.items() if k != "text"}, sort_keys=True
            )
            buckets.setdefault(key, []).append(segment)

        return buckets

    @classmethod
    def _is_none_speak_segment(cls, segment: SSMLSegment) -> bool:
        """True if the text is empty once bracketed control tokens and whitespace are removed."""
        return not cls._CONTROL_TOKEN_RE.sub("", segment.text.strip()).strip()

    def split_segments(self, segments: List[Union[SSMLSegment, SSMLBreak]]):
        """
        Split each voice segment's text with ``SentenceSplitter`` into one
        segment per sentence.

        - Breaks pass through unchanged. Voice segments with empty text are
          dropped.
        - Each sentence segment gets ``attrs.copy()`` and a shallow copy of
          ``params``.
        - Sentences that are only control tokens / whitespace are appended to
          the immediately preceding voice segment. A run of them all lands on
          that same segment. When the preceding item is a break, or there is
          none, such a sentence is kept on its own.
        - Finally, voice segments with blank text are removed. Breaks are
          always kept.
        """
        spliter = SentenceSplitter(threshold=self.spliter_thr)
        ret_segments: List[Union[SSMLSegment, SSMLBreak]] = []

        for segment in segments:
            if isinstance(segment, SSMLBreak):
                ret_segments.append(segment)
                continue

            text = segment.text
            if not text:
                continue

            for sentence in spliter.parse(text):
                seg = SSMLSegment(
                    text=sentence,
                    attrs=segment.attrs.copy(),
                    params=copy.copy(segment.params),
                )
                ret_segments.append(seg)
                # NOTE: not read by this class; the value is the position
                # before the merge/filter pass below.
                setattr(seg, "_idx", len(ret_segments) - 1)

        merged: List[Union[SSMLSegment, SSMLBreak]] = []
        for seg in ret_segments:
            can_merge = (
                not isinstance(seg, SSMLBreak)
                and merged
                and not isinstance(merged[-1], SSMLBreak)
                and self._is_none_speak_segment(seg)
            )
            if can_merge:
                merged[-1].text += seg.text
            else:
                merged.append(seg)

        return [
            seg
            for seg in merged
            if isinstance(seg, SSMLBreak) or seg.text.strip()
        ]

    def synthesize_segments(
        self, segments: List[Union[SSMLSegment, SSMLBreak]]
    ) -> List[AudioSegment]:
        """Synthesize ``segments`` into one ``AudioSegment`` per split segment.

        The result is aligned with the output of ``split_segments`` (one entry
        per sentence or break), not with the input list.
        """
        segments = self.split_segments(segments)
        audio_segments = [None] * len(segments)
        buckets = self.bucket_segments(segments)

        self.process_break_segments(segments, buckets.pop("<break>"), audio_segments)
        for bucket in buckets.values():
            self.process_voice_segments(segments, bucket, audio_segments)

        return audio_segments
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Demo: two speakers with the same seed/temperature, and a 1000 ms pause
    # after the first line. Writes the concatenated result to output.wav.

    def _demo_context(spk):
        """Build an SSMLContext for speaker ``spk`` with seed 42 and temp 0.1."""
        ctx = SSMLContext()
        ctx.spk = spk
        ctx.seed = 42
        ctx.temp = 0.1
        return ctx

    ctx1 = _demo_context(1)
    ctx2 = _demo_context(2)
    ssml_segments = [
        SSMLSegment(text="大🍌,一条大🍌,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()),
        SSMLBreak(duration_ms=1000),
        SSMLSegment(text="大🍉,一个大🍉,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()),
        SSMLSegment(text="大🍊,一个大🍊,嘿,你的感觉真的很奇妙", attrs=ctx2.copy()),
    ]

    synthesizer = SynthesizeSegments(batch_size=2)
    audio_segments = synthesizer.synthesize_segments(ssml_segments)
    print(audio_segments)
    combine_audio_segments(audio_segments).export("output.wav", format="wav")
| |
|