| import os | |
| import torch | |
| import torchaudio | |
| import whisper | |
| import onnxruntime | |
| import numpy as np | |
| import torchaudio.compliance.kaldi as kaldi | |
| from typing import Callable, List, Union | |
| from functools import partial | |
| from loguru import logger | |
| from VietTTS.utils.frontend_utils import split_text, normalize_text, mel_spectrogram | |
| from VietTTS.tokenizer.tokenizer import get_tokenizer | |
class TTSFrontEnd:
    """Front-end feature extraction for the VietTTS pipeline.

    Wraps the text tokenizer plus two ONNX models — a speaker-embedding
    extractor (pinned to CPU) and a speech tokenizer (CUDA when available) —
    and converts raw text / 16 kHz prompt audio into the tensor dict consumed
    by the TTS and voice-conversion back-end models.
    """

    # Sample rate the prompt audio is expected at (16 kHz mono).
    PROMPT_SR = 16000
    # Sample rate at which the mel features for the flow model are computed.
    FEAT_SR = 22050
    # The Whisper-based speech tokenizer only supports 30 s windows.
    MAX_PROMPT_SECONDS = 30

    def __init__(
        self,
        speech_embedding_model: str,
        speech_tokenizer_model: str,
    ):
        """
        Args:
            speech_embedding_model: path to the speaker-embedding ONNX model.
            speech_tokenizer_model: path to the speech-tokenizer ONNX model.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = get_tokenizer()
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.speech_embedding_session = onnxruntime.InferenceSession(
            speech_embedding_model,
            sess_options=option,
            providers=["CPUExecutionProvider"]
        )
        self.speech_tokenizer_session = onnxruntime.InferenceSession(
            speech_tokenizer_model,
            sess_options=option,
            providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"]
        )
        # Build the 16 kHz -> 22.05 kHz resampler once instead of on every
        # frontend_tts / frontend_vc call (it is stateless and reusable).
        self.resampler = torchaudio.transforms.Resample(
            orig_freq=self.PROMPT_SR, new_freq=self.FEAT_SR
        )
        # Cache of per-speaker info; populated externally.
        self.spk2info = {}

    def _extract_text_token(self, text: str):
        """Tokenize `text`.

        Returns:
            (token ids as int32 tensor of shape [1, T], length tensor [1]).
        """
        text_token = self.tokenizer.encode(text, allowed_special='all')
        text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
        return text_token, text_token_len

    def _extract_speech_token(self, speech: torch.Tensor):
        """Run the ONNX speech tokenizer on 16 kHz audio shaped [1, samples].

        Audio longer than MAX_PROMPT_SECONDS is truncated because the Whisper
        log-mel front-end only handles 30 s windows.

        Returns:
            (token ids as int32 tensor of shape [1, N], length tensor [1]).
        """
        max_samples = self.PROMPT_SR * self.MAX_PROMPT_SECONDS
        if speech.shape[1] > max_samples:
            # Previously truncated silently; make the data loss visible.
            logger.warning(
                f"Prompt audio longer than {self.MAX_PROMPT_SECONDS}s; "
                "truncating before speech tokenization."
            )
            speech = speech[:, :max_samples]
        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
        speech_token = self.speech_tokenizer_session.run(
            None,
            {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
             self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)}
        )[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
        return speech_token, speech_token_len

    def _extract_spk_embedding(self, speech: torch.Tensor):
        """Compute a speaker embedding from 16 kHz audio via the ONNX model.

        Features are 80-bin Kaldi fbanks, mean-normalized over time before
        being fed to the embedding network.

        Returns:
            embedding tensor of shape [1, D] on self.device.
        """
        feat = kaldi.fbank(
            waveform=speech,
            num_mel_bins=80,
            dither=0,
            sample_frequency=self.PROMPT_SR
        )
        # Cepstral mean normalization across the time axis.
        feat = feat - feat.mean(dim=0, keepdim=True)
        embedding = self.speech_embedding_session.run(
            None,
            {self.speech_embedding_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()}
        )[0].flatten().tolist()
        embedding = torch.tensor([embedding]).to(self.device)
        return embedding

    def _extract_speech_feat(self, speech: torch.Tensor):
        """Compute the 80-bin mel features for the flow model.

        Args:
            speech: audio tensor at FEAT_SR (22.05 kHz) — callers resample
                the 16 kHz prompt first.

        Returns:
            (features tensor [1, frames, 80], length tensor [1]).
        """
        speech_feat = mel_spectrogram(
            y=speech,
            n_fft=1024,
            num_mels=80,
            sampling_rate=self.FEAT_SR,
            hop_size=256,
            win_size=1024,
            fmin=0,
            fmax=8000,
            center=False
        ).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
        return speech_feat, speech_feat_len

    def preprocess_text(self, text, split=True) -> Union[str, List[str]]:
        """Normalize `text`; optionally split it into model-sized chunks.

        Returns the normalized string when split=False, otherwise a list of
        normalized chunks.
        """
        text = normalize_text(text)
        if split:
            text = list(split_text(
                text=text,
                tokenize=partial(self.tokenizer.encode, allowed_special='all'),
                token_max_n=30,
                token_min_n=10,
                merge_len=5,
                comma_split=False
            ))
        return text

    def frontend_tts(
        self,
        text: str,
        prompt_speech_16k: Union[np.ndarray, torch.Tensor]
    ) -> dict:
        """Build the model-input dict for zero-shot TTS.

        Args:
            text: text to synthesize (already normalized).
            prompt_speech_16k: 16 kHz reference audio, shape [1, samples].

        Returns:
            dict with text tokens, prompt speech tokens/features and the
            speaker embedding (used for both the LLM and the flow model).
        """
        if isinstance(prompt_speech_16k, np.ndarray):
            prompt_speech_16k = torch.from_numpy(prompt_speech_16k)
        text_token, text_token_len = self._extract_text_token(text)
        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
        # Mel features are computed at 22.05 kHz; reuse the cached resampler.
        prompt_speech_22050 = self.resampler(prompt_speech_16k)
        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        model_input = {
            'text': text_token,
            'text_len': text_token_len,
            'flow_prompt_speech_token': speech_token,
            'flow_prompt_speech_token_len': speech_token_len,
            'prompt_speech_feat': speech_feat,
            'prompt_speech_feat_len': speech_feat_len,
            'llm_embedding': embedding,
            'flow_embedding': embedding
        }
        return model_input

    def frontend_vc(
        self,
        source_speech_16k: Union[np.ndarray, torch.Tensor],
        prompt_speech_16k: Union[np.ndarray, torch.Tensor]
    ) -> dict:
        """Build the model-input dict for voice conversion.

        Args:
            source_speech_16k: 16 kHz audio whose content is to be converted.
            prompt_speech_16k: 16 kHz reference audio providing the target voice.

        Returns:
            dict with source/prompt speech tokens, prompt features and the
            target-speaker embedding for the flow model.
        """
        if isinstance(source_speech_16k, np.ndarray):
            source_speech_16k = torch.from_numpy(source_speech_16k)
        if isinstance(prompt_speech_16k, np.ndarray):
            prompt_speech_16k = torch.from_numpy(prompt_speech_16k)
        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
        # Mel features are computed at 22.05 kHz; reuse the cached resampler.
        prompt_speech_22050 = self.resampler(prompt_speech_16k)
        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
        model_input = {
            'source_speech_token': source_speech_token,
            'source_speech_token_len': source_speech_token_len,
            'flow_prompt_speech_token': prompt_speech_token,
            'flow_prompt_speech_token_len': prompt_speech_token_len,
            'prompt_speech_feat': prompt_speech_feat,
            'prompt_speech_feat_len': prompt_speech_feat_len,
            'flow_embedding': embedding
        }
        return model_input