| | import torch |
| | import torch.nn as nn |
| | from transformers import PreTrainedModel |
| | from configuration_bigcodec import BigCodecConfig |
| |
|
| | |
| | from vq.codec_encoder import CodecEncoder_Transformer |
| | from vq.codec_decoder_vocos import CodecDecoderVocos |
| | from vq.module import SemanticEncoder |
| | from transformers import AutoFeatureExtractor, Wav2Vec2BertModel |
| |
|
class XCodec2Model(PreTrainedModel):
    """Neural speech codec combining a semantic branch (Wav2Vec2-BERT) with an
    acoustic codec encoder/decoder.

    Pipeline: the input waveform is encoded twice — once by the frozen-style
    semantic backbone (layer-16 hidden states, then ``SemanticEncoder``) and
    once by the acoustic ``CodecEncoder_Transformer``. The two feature streams
    are concatenated on the channel axis, projected by ``fc_prior``, vector-
    quantized by the generator, and finally decoded back to audio.
    """

    config_class = BigCodecConfig

    def __init__(self, config: BigCodecConfig):
        super().__init__(config)

        # Semantic feature backbone. eval() disables dropout/batch-norm updates
        # but does NOT freeze parameters — NOTE(review): confirm whether
        # gradients through this branch are intended during training.
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(
            "facebook/w2v-bert-2.0",
            output_hidden_states=True,
        )
        self.semantic_model.eval()

        # Projects the raw semantic hidden states into the codec feature space.
        self.SemanticEncoder_module = SemanticEncoder(
            config.semantic_hidden_size,
            config.semantic_hidden_size,
            config.semantic_hidden_size,
        )

        # Acoustic encoder and vocoder-style decoder/quantizer.
        self.CodecEnc = CodecEncoder_Transformer()
        self.generator = CodecDecoderVocos()

        # fc_prior: projection of the concatenated (semantic + acoustic)
        # embedding before quantization. fc_post_a: projection of the
        # de-quantized embedding before waveform synthesis.
        self.fc_prior = nn.Linear(2048, 2048)
        self.fc_post_a = nn.Linear(2048, 1024)

        self.feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

    def _build_concat_features(self, input_waveform, sample_rate):
        """Run both encoding branches and return the fc_prior-projected embedding.

        Shared implementation for ``forward`` and ``encode_code`` (previously
        duplicated in both). Gradient context is the caller's responsibility.

        Args:
            input_waveform: [batch_size, waveform_length] tensor/array accepted
                by the Wav2Vec2-BERT feature extractor.
            sample_rate: sampling rate of ``input_waveform``.

        Returns:
            Channels-first embedding [batch, 2048, frames] ready for the
            generator's quantizer.
        """
        input_features = self.feature_extractor(
            input_waveform,
            sampling_rate=sample_rate,
            return_tensors="pt",
        ).input_features.to(self.device)

        # Semantic branch: layer-16 hidden states, moved to channels-first
        # layout for the convolutional SemanticEncoder.
        semantic_output = self.semantic_model(input_features)
        semantic_hidden_16 = semantic_output.hidden_states[16].transpose(1, 2)
        semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)

        # Acoustic branch: encoder expects a channel dim — [batch, 1, samples].
        wav = input_waveform.unsqueeze(1).to(self.device)
        vq_emb = self.CodecEnc(wav).transpose(1, 2)

        # The two branches may produce slightly different frame counts
        # (different hop sizes / padding); trim both to the shorter one so the
        # channel-wise concatenation below is valid.
        if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
            min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
            vq_emb = vq_emb[:, :, :min_len]
            semantic_encoded = semantic_encoded[:, :, :min_len]

        concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1)

        # fc_prior is a Linear over the channel axis, so transpose to
        # channels-last, project, and transpose back.
        return self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)

    def _codes_to_audio(self, vq_code):
        """De-quantize codes and synthesize a waveform.

        Shared implementation for ``forward`` and ``decode_code`` (previously
        duplicated). Gradient context is the caller's responsibility.

        Args:
            vq_code: quantizer indices [batch, frames] (as produced by the
                generator with ``vq=True``).

        Returns:
            Reconstructed audio tensor.
        """
        vq_post_emb = self.generator.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
        vq_post_emb = vq_post_emb.transpose(1, 2)

        # Channel-axis projection, same transpose dance as fc_prior.
        vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2)

        # generator(..., vq=False) returns a tuple whose first element is the
        # synthesized waveform.
        return self.generator(vq_post_emb.transpose(1, 2), vq=False)[0]

    def forward(self, input_waveform, sample_rate=16000):
        """Encode and immediately reconstruct a waveform (full autoencoding pass).

        Kept as ``forward`` for pipeline compatibility; ``encode_code`` /
        ``decode_code`` expose the two halves separately.

        Args:
            input_waveform: [batch_size, waveform_length]
            sample_rate: defaults to 16000.

        Returns:
            Reconstructed audio (Tensor).
        """
        # Note: unlike encode_code/decode_code, forward intentionally runs
        # without torch.no_grad() so it remains usable for training.
        concat_emb = self._build_concat_features(input_waveform, sample_rate)
        _, vq_code, _ = self.generator(concat_emb, vq=True)
        return self._codes_to_audio(vq_code)

    def encode_code(self, input_waveform, sample_rate=16000):
        """Encode an input waveform into discrete code indices.

        Args:
            input_waveform: [batch_size, waveform_length]
            sample_rate: defaults to 16000.

        Returns:
            Encoded codes (Tensor).
        """
        with torch.no_grad():
            concat_emb = self._build_concat_features(input_waveform, sample_rate)
            _, vq_code, _ = self.generator(concat_emb, vq=True)
            return vq_code

    def decode_code(self, vq_code):
        """Decode code indices back into audio.

        Args:
            vq_code: encoded codes (Tensor) [batch, frames]

        Returns:
            Decoded audio (Tensor) [batch, waveform_length]
        """
        with torch.no_grad():
            return self._codes_to_audio(vq_code)
| |
|