# Spaces:
# Running on Zero
# Running on Zero
import logging
import os.path
import time
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from librosa.filters import mel as librosa_mel_fn

from funcineforge.models.utils import dtype_map
from funcineforge.models.utils.masks import add_optional_chunk_mask
from funcineforge.models.utils.nets_utils import make_pad_mask
from funcineforge.utils.device_funcs import to_device
from funcineforge.utils.hinter import hint_once
from funcineforge.utils.load_utils import extract_campp_xvec

from .modules.dit_flow_matching.dit_model import DiT
class Audio2Mel(nn.Module):
    """Waveform-to-log-mel front end (reflect-padded STFT + mel filterbank).

    ``forward`` expects audio shaped (B, 1, T) and returns
    (B, n_mel_channels, frames).  ``feat_type`` selects the output scale:

    * ``"mag_log10"``: log10 of the mel magnitude spectrogram
    * anything else (default ``"power_log"``): natural-log mel spectrogram via
      :meth:`spectral_normalize` (dynamic-range compression)
    """

    def __init__(
        self,
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        sampling_rate=22050,
        n_mel_channels=80,
        mel_fmin=0.0,
        mel_fmax=None,
        center=False,
        device='cuda',
        feat_type="power_log",
    ):
        super().__init__()
        ##############################################
        # FFT Parameters
        ##############################################
        window = torch.hann_window(win_length, device=device).float()
        # (n_mel_channels, n_fft // 2 + 1) mel filterbank from librosa
        mel_basis = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
        )
        mel_basis = torch.from_numpy(mel_basis).float().to(device)
        # buffers so they follow .to()/.float() and land in the state dict
        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("window", window)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.mel_fmax = mel_fmax
        self.center = center
        self.feat_type = feat_type

    def forward(self, audioin):
        """audioin: (B, 1, T) waveform -> (B, n_mel_channels, frames) log-mel."""
        # symmetric reflect padding so frames tile the signal with center=False
        p = (self.n_fft - self.hop_length) // 2
        audio = F.pad(audioin, (p, p), "reflect").squeeze(1)
        fft = torch.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        if self.feat_type == "mag_log10":
            power_spec = torch.sqrt(torch.pow(fft.imag, 2) + torch.pow(fft.real, 2))
            mel_output = torch.matmul(self.mel_basis, power_spec)
            return torch.log10(torch.clamp(mel_output, min=1e-5))
        power_spec = torch.pow(fft.imag, 2) + torch.pow(fft.real, 2)
        # +1e-9 guards sqrt(0) (and its gradient) at silent frames
        mel_spec = torch.matmul(self.mel_basis, torch.sqrt(power_spec + 1e-9))
        return self.spectral_normalize(mel_spec)

    def spectral_normalize(self, spec, C=1, clip_val=1e-5):
        """Dynamic-range-compress a mel spectrogram: log(clamp(spec, clip_val) * C)."""
        return self.dynamic_range_compression(spec, C, clip_val)

    def spectral_de_normalize_torch(self, spec, C=1, clip_val=1e-5):
        """Inverse of :meth:`spectral_normalize`: exp(spec) / C.

        ``clip_val`` is accepted for API symmetry but has no effect here.
        """
        # bug fix: the original forwarded clip_val to
        # dynamic_range_decompression, which does not take that argument
        # (TypeError at call time).
        return self.dynamic_range_decompression(spec, C)

    @staticmethod
    def dynamic_range_compression(x, C=1, clip_val=1e-5):
        # bug fix: @staticmethod added — as a plain instance method the bound
        # call received `self` as `x`, raising a TypeError.
        return torch.log(torch.clamp(x, min=clip_val) * C)

    @staticmethod
    def dynamic_range_decompression(x, C=1):
        # bug fix: @staticmethod added (same instance-binding issue as above).
        return torch.exp(x) / C
class LookaheadBlock(nn.Module):
    """Lookahead + causal Conv1d pair with a masked residual connection.

    ``conv1`` consumes ``pre_lookahead_len`` future frames (kernel size
    ``pre_lookahead_len + 1``, no padding); ``conv2`` is a causal 3-tap conv
    (left-padded by 2).  The result is masked by the valid-length mask and
    added back onto the input.
    """

    def __init__(self, in_channels: int, channels: int, pre_lookahead_len: int = 1):
        super().__init__()
        self.channels = channels
        self.pre_lookahead_len = pre_lookahead_len
        # keep conv1 registered before conv2 so parameter order / state-dict
        # keys stay identical to the original implementation
        self.conv1 = nn.Conv1d(
            in_channels, channels,
            kernel_size=pre_lookahead_len + 1,
            stride=1, padding=0,
        )
        self.conv2 = nn.Conv1d(
            channels, in_channels,
            kernel_size=3, stride=1, padding=0,
        )

    def forward(self, inputs, ilens, context: torch.Tensor = torch.zeros(0, 0, 0)):
        """
        inputs: (batch_size, seq_len, channels); ``context`` optionally
        supplies the ``pre_lookahead_len`` future frames (streaming mode).
        """
        hidden = inputs.transpose(1, 2).contiguous()
        ctx = context.transpose(1, 2).contiguous()
        if ctx.size(2) == 0:
            # no explicit future context: zero-pad on the right instead
            hidden = F.pad(hidden, (0, self.pre_lookahead_len), mode='constant', value=0)
        else:
            assert ctx.size(2) == self.pre_lookahead_len
            hidden = torch.concat([hidden, ctx], dim=2)
        hidden = F.leaky_relu(self.conv1(hidden))
        # left-pad by 2 so conv2 stays causal
        hidden = self.conv2(F.pad(hidden, (2, 0), mode='constant', value=0))
        hidden = hidden.transpose(1, 2).contiguous()
        valid = (~make_pad_mask(ilens).unsqueeze(-1).to(inputs.device))
        # masked residual connection
        return (hidden + inputs) * valid, ilens
class CosyVoiceFlowMatching(nn.Module):
    """CosyVoice-style flow-matching decoder: maps speech codec tokens (plus
    optional speaker embedding and prompt conditions) to mel features with a
    DiT vector-field model integrated by an Euler ODE solver."""

    def __init__(
        self,
        codebook_size: int,
        model_size: int,
        xvec_size: int = 198,
        dit_conf: Dict = None,
        mel_feat_conf: Dict = None,
        prompt_conf: Dict = None,
        **kwargs):
        """
        Args:
            codebook_size: number of valid codec token ids.
            model_size: hidden width of the lookahead block.
            xvec_size: speaker-embedding dim; projection is skipped when None.
            dit_conf: kwargs forwarded to the DiT backbone.
            mel_feat_conf: kwargs for Audio2Mel; must contain "n_mel_channels".
            prompt_conf: prompt-masking config; None disables prompt masking.
        """
        super().__init__()
        # avoid shared mutable default arguments; behavior is unchanged
        dit_conf = {} if dit_conf is None else dit_conf
        mel_feat_conf = {} if mel_feat_conf is None else mel_feat_conf
        # feat related
        self.feat_token_ratio = kwargs.get("feat_token_ratio", None)
        try:
            self.mel_extractor = Audio2Mel(**mel_feat_conf)
            self.sample_rate = self.mel_extractor.sampling_rate
        except Exception:
            # bug fix: was a bare `except:`; narrow it and log why the mel
            # front end could not be built (e.g. bad conf / no CUDA device)
            logging.exception("failed to build Audio2Mel; mel_extractor disabled")
            self.mel_extractor = None
            self.sample_rate = 24000
        self.mel_norm_type = kwargs.get("mel_norm_type", None)
        self.num_mels = num_mels = mel_feat_conf["n_mel_channels"]
        self.token_rate = kwargs.get("token_rate", 25)
        self.model_dtype = kwargs.get("model_dtype", "fp32")
        self.codebook_size = codebook_size
        # condition related
        self.prompt_conf = prompt_conf
        if self.prompt_conf is not None:
            self.prompt_masker = self.build_prompt_masker()
        # codec related
        self.codec_embedder = nn.Embedding(codebook_size, num_mels)
        lookahead_length = kwargs.get("lookahead_length", 4)
        self.lookahead_conv1d = LookaheadBlock(num_mels, model_size, lookahead_length)
        # spk embed related
        if xvec_size is not None:
            self.xvec_proj = torch.nn.Linear(xvec_size, num_mels)
        # dit model related
        self.dit_conf = dit_conf
        self.dit_model = DiT(**dit_conf)
        self.training_cfg_rate = kwargs.get("training_cfg_rate", 0)
        self.only_mask_loss = kwargs.get("only_mask_loss", True)
        # NOTE: flow matching needs right-looking (future) context frames
        self.context_size = self.lookahead_conv1d.pre_lookahead_len
| def build_prompt_masker(self): | |
| prompt_type = self.prompt_conf.get("prompt_type", "free") | |
| if prompt_type == "prefix": | |
| from funcineforge.models.utils.mask_along_axis import MaskTailVariableMaxWidth | |
| masker = MaskTailVariableMaxWidth( | |
| mask_width_ratio_range=self.prompt_conf["prompt_width_ratio_range"], | |
| ) | |
| else: | |
| raise NotImplementedError | |
| return masker | |
| def norm_spk_emb(xvec): | |
| xvec_mask = (~xvec.norm(dim=-1).isnan()) * (~xvec.norm(dim=-1).isinf()) | |
| xvec = xvec * xvec_mask.unsqueeze(-1) | |
| xvec = xvec.mean(dim=1) | |
| xvec = F.normalize(xvec, dim=1) | |
| return xvec | |
| def select_target_prompt(self, y: torch.Tensor, y_lengths: torch.Tensor): | |
| # cond_mask: 1, 1, 1, ..., 0, 0, 0 | |
| cond_mask = self.prompt_masker(y, y_lengths, return_mask=True) | |
| return cond_mask | |
    def normalize_mel_feat(self, feat, feat_lengths):
        """Normalize padded mel features per-utterance.

        ``mel_norm_type`` selects the scheme:
          * "mean_std": zero-mean / unit-std over the valid (unpadded) region.
          * "min_max": scale to [0, 1] then multiply by 3.
        Padded positions are zeroed in the output.  Raises
        NotImplementedError for any other (or unset) mel_norm_type.
        """
        # feat in B,T,D
        if self.mel_norm_type == "mean_std":
            max_length = feat.shape[1]
            # mask: (B, T, 1), 1.0 on valid frames, 0.0 on padding
            mask = (~make_pad_mask(feat_lengths, maxlen=max_length))
            mask = mask.unsqueeze(-1).to(feat)
            # mean over valid frames x all D channels
            mean = ((feat * mask).sum(dim=(1, 2), keepdim=True) /
                    (mask.sum(dim=(1, 2), keepdim=True) * feat.shape[-1]))
            var = (((feat - mean)**2 * mask).sum(dim=(1, 2), keepdim=True) /
                   (mask.sum(dim=(1, 2), keepdim=True) * feat.shape[-1] - 1))  # -1 for unbiased estimation
            std = torch.sqrt(var)
            feat = (feat - mean) / std
            feat = feat * mask
            return feat
        if self.mel_norm_type == "min_max":
            bb, tt, dd = feat.shape
            mask = (~make_pad_mask(feat_lengths, maxlen=tt))
            mask = mask.unsqueeze(-1).to(feat)
            # NOTE(review): min/max are taken over feat * mask, so the zeros
            # written at padded positions participate — the per-utterance min
            # can be pulled to 0 (or max to 0 for all-negative feats) whenever
            # padding exists.  Confirm this is intended.
            feat_min = (feat * mask).reshape([bb, tt * dd]).min(dim=1, keepdim=True).values.unsqueeze(-1)
            feat_max = (feat * mask).reshape([bb, tt * dd]).max(dim=1, keepdim=True).values.unsqueeze(-1)
            feat = (feat - feat_min) / (feat_max - feat_min)
            # noise ~ N(0, I), P(x >= 3sigma) = 0.001, 3 is enough.
            # NOTE(review): after min-max scaling feat is in [0, 1], so this
            # yields [0, 3], not [-3, 3] as the original comment claimed.
            feat = (feat * 3) * mask
            return feat
        else:
            raise NotImplementedError
| def extract_feat(self, y: torch.Tensor, y_lengths: torch.Tensor): | |
| mel_extractor = self.mel_extractor.float() | |
| feat = mel_extractor(y) | |
| feat = feat.transpose(1, 2) | |
| feat_lengths = (y_lengths / self.mel_extractor.hop_length).to(y_lengths) | |
| if self.mel_norm_type is not None: | |
| feat = self.normalize_mel_feat(feat, feat_lengths) | |
| return feat, feat_lengths | |
    def load_data(self, contents: dict, **kwargs):
        """Assemble a single-utterance inference batch from raw inputs.

        Args:
            contents: dict with "codec" (token array; lengths are read from
                dim 1, so it is assumed to already carry a batch dim — TODO
                confirm against the caller) and "vocal" (path to a prompt wav
                or a pre-extracted .npy embedding, or None).
            kwargs: "fm_use_prompt" (bool, default True); optional
                "prompt_codec" and "spk_emb" (arrays or .npy paths); extra
                options forwarded to extract_campp_xvec.

        Returns:
            dict of tensors/lengths: codec, prompt_codec, xvec and their
            lengths.  "prompt_wav"/"prompt_wav_lengths" are always None —
            the prompt wav is only used to derive a speaker embedding here.
        """
        fm_use_prompt = kwargs.get("fm_use_prompt", True)
        # codec
        codec = contents["codec"]
        if isinstance(codec, np.ndarray):
            codec = torch.from_numpy(codec)
        # codec = torch.from_numpy(codec)[None, :]
        codec_lengths = torch.tensor([codec.shape[1]], dtype=torch.int64)
        # prompt codec (optional); a string is treated as a path to a saved
        # numpy token array
        prompt_codec = kwargs.get("prompt_codec", None)
        prompt_codec_lengths = None
        if prompt_codec is not None and fm_use_prompt:
            if isinstance(prompt_codec, str) and os.path.exists(prompt_codec):
                prompt_codec = np.load(prompt_codec)
            if isinstance(prompt_codec, np.ndarray):
                prompt_codec = torch.from_numpy(prompt_codec)[None, :]
            prompt_codec_lengths = torch.tensor([prompt_codec.shape[1]], dtype=torch.int64)
        else:
            prompt_codec = None
        # speaker embedding given directly (array or .npy path) ...
        spk_emb = kwargs.get("spk_emb", None)
        spk_emb_lengths = None
        if spk_emb is not None:
            if isinstance(spk_emb, str) and os.path.exists(spk_emb):
                spk_emb = np.load(spk_emb)
            if isinstance(spk_emb, np.ndarray):
                spk_emb = torch.from_numpy(spk_emb)[None, :]
            spk_emb_lengths = torch.tensor([spk_emb.shape[1]], dtype=torch.int64)
        # ... or derived from the prompt wav, which then overrides the
        # spk_emb taken from kwargs
        prompt_wav = contents["vocal"]
        prompt_wav_lengths = None
        if prompt_wav is not None and fm_use_prompt and os.path.exists(prompt_wav):
            if prompt_wav.endswith(".npy"):
                # pre-extracted embedding stored on disk
                # NOTE(review): unlike the branches above, this stays a NumPy
                # array (no torch conversion / added batch dim) — confirm the
                # downstream to_device/_inference path handles that.
                spk_emb = np.load(prompt_wav)
                spk_emb_lengths = torch.tensor([spk_emb.shape[1]], dtype=torch.int64)
            else:
                spk_emb = extract_campp_xvec(prompt_wav, **kwargs)
                spk_emb = torch.from_numpy(spk_emb)
                spk_emb_lengths = torch.tensor([spk_emb.shape[1]], dtype=torch.int64)
            # prompt_wav = load_audio_text_image_video(prompt_wav, fs=self.sample_rate)
            # prompt_wav = prompt_wav[None, :]
            # prompt_wav_lengths = torch.tensor([prompt_wav.shape[1]], dtype=torch.int64)
        else:
            logging.info("[error] prompt_wav is None or not path or path not exists! Please provide the correct speaker embedding.")
        output = {
            "codec": codec,
            "codec_lengths": codec_lengths,
            "prompt_codec": prompt_codec,
            "prompt_codec_lengths": prompt_codec_lengths,
            "prompt_wav": None,
            "prompt_wav_lengths": None,
            "xvec": spk_emb,
            "xvec_lengths": spk_emb_lengths,
        }
        return output
| def inference( | |
| self, | |
| data_in, | |
| data_lengths=None, | |
| key: list = None, | |
| chunk_size: int = -1, | |
| finalize: bool = True, | |
| **kwargs, | |
| ): | |
| uttid = key[0] | |
| if kwargs.get("batch_size", 1) > 1: | |
| raise NotImplementedError("batch decoding is not implemented") | |
| batch = self.load_data(data_in[0], **kwargs) | |
| batch = to_device(batch, kwargs["device"]) | |
| batch.update({'finalize': finalize, 'chunk_size': chunk_size}) | |
| feat = self._inference(**batch, **kwargs) | |
| feat = feat.float() | |
| logging.info(f"{uttid}: feat lengths {feat.shape[1]}") | |
| return feat | |
    def _inference(
        self,
        codec, codec_lengths,
        prompt_codec=None, prompt_codec_lengths=None,
        prompt_wav=None, prompt_wav_lengths=None,
        xvec=None, xvec_lengths=None, chunk_size=-1, finalize=False,
        **kwargs
    ):
        """Core mel generation: embed codec tokens, build conditions, solve the ODE.

        Pipeline: project the speaker x-vector; drop out-of-range tokens;
        optionally prepend the prompt codec; embed tokens; run the lookahead
        conv (holding back ``context_size`` frames when streaming, i.e. when
        ``finalize`` is False); upsample by ``feat_token_ratio``; place the
        prompt mel into the condition tensor; integrate the flow-matching ODE.
        """
        fm_dtype = dtype_map[kwargs.get("fm_dtype", "fp32")]
        rand_xvec = None
        if xvec is not None:
            if xvec.dim() == 2:
                # assumes (B, D) -> (B, 1, D); frame axis added for norm_spk_emb
                xvec = xvec.unsqueeze(1)
            # NOTE(review): xvec_lens is computed but never used afterwards
            xvec_lens = torch.ones_like(xvec_lengths)
            rand_xvec = self.norm_spk_emb(xvec)
            self.xvec_proj.to(fm_dtype)
            rand_xvec = self.xvec_proj(rand_xvec.to(fm_dtype))
            rand_xvec = rand_xvec.unsqueeze(1)
        # tokens >= codebook_size cannot be embedded; filter them out and
        # shrink the lengths accordingly
        if (codec >= self.codebook_size).any():
            new_codec = codec[codec < self.codebook_size].unsqueeze(0)
            logging.info(f"remove out-of-range token for FM: from {codec.shape[1]} to {new_codec.shape[1]}.")
            codec_lengths = codec_lengths - (codec.shape[1] - new_codec.shape[1])
            codec = new_codec
        if prompt_codec is not None:
            codec, codec_lengths = self.concat_prompt(prompt_codec, prompt_codec_lengths, codec, codec_lengths)
        # -1 marks padding; clamp for embedding lookup, then zero those rows
        mask = (codec != -1).float().unsqueeze(-1)
        codec_emb = self.codec_embedder(torch.clamp(codec, min=0)) * mask
        self.lookahead_conv1d.to(fm_dtype)
        if finalize is True:
            # last (or only) chunk: no future frames left, use empty context
            context = torch.zeros(1, 0, self.codec_embedder.embedding_dim).to(fm_dtype)
        else:
            # streaming: hold back the last context_size frames as lookahead
            codec_emb, context = codec_emb[:, :-self.context_size].to(fm_dtype), codec_emb[:, -self.context_size:].to(fm_dtype)
            codec_lengths = codec_lengths - self.context_size
        mu, _ = self.lookahead_conv1d(codec_emb, codec_lengths, context)
        # token rate -> feature frame rate upsampling
        mu = mu.repeat_interleave(self.feat_token_ratio, dim=1)
        # print(mu.size())
        conditions = torch.zeros([mu.size(0), mu.shape[1], self.num_mels]).to(mu)
        # get conditions
        if prompt_wav is not None:
            if prompt_wav.ndim == 2:
                # raw waveform: turn it into mel features first
                prompt_wav, prompt_wav_lengths = self.extract_feat(prompt_wav, prompt_wav_lengths)
            # NOTE: for the fmax12k FM, try interpolating the mel to 2x the
            # token length instead of hard truncation
            prompt_wav = prompt_wav.to(fm_dtype)
            for i, _len in enumerate(prompt_wav_lengths):
                conditions[i, :_len] = prompt_wav[i]
        feat_lengths = codec_lengths * self.feat_token_ratio
        # NOTE: add_optional_chunk_mask can build masks for different
        # chunk_size values (-1/1/15/30)
        mask = add_optional_chunk_mask(mu, torch.ones([1, 1, mu.shape[1]]).to(mu).bool(), False, False, 0, chunk_size, -1)
        feat = self.solve_ode(mu, rand_xvec, conditions.to(fm_dtype), mask, **kwargs)
        if prompt_codec is not None and prompt_wav is not None:
            # strip the generated frames that correspond to the prompt
            feat, feat_lens = self.remove_prompt(None, prompt_wav_lengths, feat, feat_lengths)
        return feat
| def concat_prompt(prompt, prompt_lengths, text, text_lengths): | |
| xs_list, x_len_list = [], [] | |
| for idx, (_prompt_len, _text_len) in enumerate(zip(prompt_lengths, text_lengths)): | |
| xs_list.append(torch.concat([prompt[idx, :_prompt_len], text[idx, :_text_len]], dim=0)) | |
| x_len_list.append(_prompt_len + _text_len) | |
| xs = torch.nn.utils.rnn.pad_sequence(xs_list, batch_first=True, padding_value=0.0) | |
| x_lens = torch.tensor(x_len_list, dtype=torch.int64).to(xs.device) | |
| return xs, x_lens | |
| def remove_prompt(prompt, prompt_lengths, padded, padded_lengths): | |
| xs_list = [] | |
| for idx, (_prompt_len, _x_len) in enumerate(zip(prompt_lengths, padded_lengths)): | |
| xs_list.append(padded[idx, _prompt_len: _x_len]) | |
| xs = torch.nn.utils.rnn.pad_sequence(xs_list, batch_first=True, padding_value=0.0) | |
| return xs, padded_lengths - prompt_lengths | |
| def get_rand_noise(self, mu: torch.Tensor, **kwargs): | |
| use_fixed_noise_infer = kwargs.get("use_fixed_noise_infer", True) | |
| max_len = kwargs.get("max_len", 50*300) | |
| if use_fixed_noise_infer: | |
| if not hasattr(self, "rand_noise") or self.rand_noise is None or self.rand_noise.shape[2] < mu.shape[2]: | |
| self.rand_noise = torch.randn([1, max_len, mu.shape[2]]).to(mu) | |
| logging.info("init random noise for Flow") | |
| # return self.rand_noise[:, :mu.shape[1], :] | |
| return torch.concat([self.rand_noise[:, :mu.shape[1], :] for _ in range(mu.size(0))], dim = 0) | |
| else: | |
| return torch.randn_like(mu) | |
| def solve_ode(self, mu, rand_xvec, conditions, mask, **kwargs): | |
| fm_dtype = dtype_map[kwargs.get("fm_dtype", "fp32")] | |
| temperature = kwargs.get("temperature", 1.0) | |
| n_timesteps = kwargs.get("n_timesteps", 10) | |
| infer_t_scheduler = kwargs.get("infer_t_scheduler", "cosine") | |
| z = self.get_rand_noise(mu) * temperature | |
| # print("z", z.size(), "mu", mu.size()) | |
| t_span = torch.linspace(0, 1, n_timesteps + 1).to(mu) | |
| # print("t_span", t_span) | |
| if infer_t_scheduler == 'cosine': | |
| t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) | |
| fm_time = time.time() | |
| self.dit_model.to(fm_dtype) | |
| feat = self.solve_euler( | |
| z.to(fm_dtype), t_span=t_span.to(fm_dtype), mu=mu.to(fm_dtype), mask=mask, | |
| spks=rand_xvec.to(fm_dtype), cond=conditions.to(fm_dtype), **kwargs | |
| ) | |
| escape_time = (time.time() - fm_time) * 1000.0 | |
| logging.info(f"fm dec {n_timesteps} step time: {escape_time:.2f}, avg {escape_time/n_timesteps:.2f} ms") | |
| return feat | |
    def solve_euler(self, x, t_span, mu, mask, spks=None, cond=None, **kwargs):
        """
        Fixed euler solver for ODEs, with optional classifier-free guidance.

        Args:
            x (torch.Tensor): random noise to integrate from.
            t_span (torch.Tensor): n_timesteps interpolated time points
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                (assumed (batch_size, mel_timesteps, n_feats) here, since the
                CFG copies are concatenated along dim 0 — TODO confirm)
            mask (torch.Tensor): output_mask passed to the DiT model
            spks (torch.Tensor, optional): projected speaker embedding.
            cond: prompt condition — despite the old note it IS used: it is
                zeroed for the unconditional CFG branch and fed to the DiT
                model.  Either a tensor or a dict with a "prompt" pair
                (features, presumably lengths — TODO confirm).
        Returns:
            torch.Tensor: the sample after the final Euler step.
        """
        inference_cfg_rate = kwargs.get("inference_cfg_rate", 0.7)
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        # print("solve_euler cond", cond.size())
        steps = 1
        # NOTE(review): z is assigned but never used afterwards
        z, bz = x, x.shape[0]
        while steps <= len(t_span) - 1:
            if inference_cfg_rate > 0:
                # classifier-free guidance: stack a conditional and an
                # unconditional (zeroed-condition) copy of the batch and run
                # both through the model in a single forward pass
                x_in = torch.concat([x, x], dim=0)
                spks_in = torch.cat([spks, torch.zeros_like(spks)], dim=0)
                mask_in = torch.concat([mask, mask], dim=0)
                mu_in = torch.concat([mu, torch.zeros_like(mu)], dim=0)
                t_in = torch.concat([t.unsqueeze(0) for _ in range(mu_in.size(0))], dim=0)
                if isinstance(cond, torch.Tensor):
                    cond_in = torch.concat([cond, torch.zeros_like(cond)], dim=0)
                else:
                    # dict-style condition: zero only the prompt features for
                    # the unconditional half; duplicate the second entry
                    cond_in = dict(
                        prompt=[
                            torch.concat([cond["prompt"][0], torch.zeros_like(cond["prompt"][0])], dim=0),
                            torch.concat([cond["prompt"][1], cond["prompt"][1]], dim=0),
                        ]
                    )
            else:
                x_in, mask_in, mu_in, spks_in, t_in, cond_in = x, mask, mu, spks, t, cond
            # if spks is not None:
            #     cond_in = cond_in + spks
            infer_causal_mask_type = kwargs.get("infer_causal_mask_type", 0)
            chunk_mask_value = self.dit_model.causal_mask_type[infer_causal_mask_type]["prob_min"]
            hint_once(
                f"flow mask type: {infer_causal_mask_type}, mask_rank value: {chunk_mask_value}.",
                "chunk_mask_value"
            )
            # print("dit_model cond", x_in.size(), cond_in.size(), mu_in.size(), spks_in.size(), t_in.size())
            # print(t_in)
            # predict the vector field at time t
            dphi_dt = self.dit_model(
                x_in, cond_in, mu_in, spks_in, t_in,
                mask=mask_in,
                mask_rand=torch.ones_like(t_in).reshape(-1, 1, 1) * chunk_mask_value
            )
            if inference_cfg_rate > 0:
                # split conditional / unconditional halves and extrapolate
                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [bz, bz], dim=0)
                dphi_dt = ((1.0 + inference_cfg_rate) * dphi_dt -
                           inference_cfg_rate * cfg_dphi_dt)
            # Euler step
            x = x + dt * dphi_dt
            t = t + dt
            # sol.append(x)
            if steps < len(t_span) - 1:
                dt = t_span[steps + 1] - t
            steps += 1
        return x