| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | from patch_utils import MindSpeedPatchesManager as aspm |
| | import os |
| | import torch |
| | import torch.nn as nn |
| | import logging |
| | import torchaudio.transforms as trans |
| | from s3prl.upstream.wavlm.expert import UpstreamExpert as s3prl_UpstreamExpert |
| | from models.ecapa_tdnn import Conv1dReluBn, SE_Res2Block, AttentiveStatsPool |
| | from models.ecapa_tdnn import ECAPA_TDNN_SMALL, ECAPA_TDNN |
| |
|
def init_model_patched(model_name, checkpoint=None):
    """Build the ECAPA-TDNN model selected by ``model_name``.

    Args:
        model_name: One of ``'unispeech_sat'``, ``'wavlm_base_plus'``,
            ``'wavlm_large'``, ``'hubert_large'`` or ``'wav2vec2_xlsr'``.
            Any other value selects the 40-dimensional ``'fbank'`` model.
        checkpoint: Optional path (or file-like object) handed to
            ``torch.load``. The loaded object must expose the weights under
            the ``'model'`` key.

    Returns:
        The model returned by ``ECAPA_TDNN_SMALL``, with checkpoint weights
        loaded non-strictly when ``checkpoint`` is given.
    """
    s3prl_path = os.environ.get("S3PRL_PATH")

    # (model_name, feat_dim, feat_type, config_path)
    # Only 'wavlm_large' takes its config path from the S3PRL_PATH
    # environment variable (None when the variable is unset).
    presets = (
        ('unispeech_sat', 1024, 'unispeech_sat', 'config/unispeech_sat.th'),
        ('wavlm_base_plus', 768, 'wavlm_base_plus', None),
        ('wavlm_large', 1024, 'wavlm_large', s3prl_path),
        ('hubert_large', 1024, 'hubert_large_ll60k', None),
        ('wav2vec2_xlsr', 1024, 'wav2vec2_xlsr', None),
    )

    selected = next((entry for entry in presets if model_name == entry[0]), None)
    if selected is None:
        # Unknown names fall back to filter-bank features.
        model = ECAPA_TDNN_SMALL(feat_dim=40, feat_type='fbank')
    else:
        _, feat_dim, feat_type, config_path = selected
        model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=config_path)

    if checkpoint is None:
        return model

    # Returning each storage unchanged leaves it where torch.load
    # deserialized it instead of moving it to the device it was saved from.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    loaded = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    # strict=False: missing or unexpected keys are tolerated rather than raised.
    model.load_state_dict(loaded['model'], strict=False)
    return model
| |
|
| |
|
class patched_ECAPA_TDNN(ECAPA_TDNN):
    """``ECAPA_TDNN`` subclass that replaces the parent constructor.

    Only ``__init__`` is overridden; every other method (including
    ``get_feat_num``, which is called below) is inherited from
    ``ECAPA_TDNN``. When ``config_path`` is given, the feature extractor is
    built with ``s3prl_UpstreamExpert``.
    """

    def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
                 feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
        """Build the feature extractor and the ECAPA-TDNN layers.

        Args:
            feat_dim: Input feature dimension; also used as ``n_mels`` /
                ``n_mfcc`` for the ``'fbank'`` / ``'mfcc'`` front ends.
            channels: Channel width shared by the four TDNN layers.
            emb_dim: Output size of the final linear layer.
            global_context_att: Forwarded to ``AttentiveStatsPool``.
            feat_type: ``'fbank'``, ``'mfcc'``, or an upstream name that is
                passed to ``torch.hub.load('s3prl/s3prl', ...)`` when
                ``config_path`` is ``None``.
            sr: Sample rate; also sets the 25 ms window and 10 ms hop.
            feature_selection: Stored on the instance for inherited methods.
            update_extract: Whether the extractor's parameters stay
                trainable. Forced to ``False`` for ``'fbank'`` / ``'mfcc'``.
            config_path: Optional argument for ``s3prl_UpstreamExpert``.
        """
        # Start the MRO lookup *after* ECAPA_TDNN so the parent's __init__
        # is skipped on purpose; this constructor rebuilds everything itself.
        # Presumably this resolves to nn.Module.__init__ -- confirm against
        # the ECAPA_TDNN class definition.
        super(ECAPA_TDNN, self).__init__()

        self.feat_type = feat_type
        self.feature_selection = feature_selection
        self.update_extract = update_extract
        self.sr = sr

        uses_spectral_frontend = feat_type in ('fbank', 'mfcc')
        if uses_spectral_frontend:
            self.update_extract = False

        win_len = int(sr * 0.025)
        hop_len = int(sr * 0.01)

        if feat_type == 'fbank':
            self.feature_extract = trans.MelSpectrogram(sample_rate=sr, n_fft=512, win_length=win_len,
                                                        hop_length=hop_len, f_min=0.0, f_max=sr // 2,
                                                        pad=0, n_mels=feat_dim)
        elif feat_type == 'mfcc':
            melkwargs = {
                'n_fft': 512,
                'win_length': win_len,
                'hop_length': hop_len,
                'f_min': 0.0,
                'f_max': sr // 2,
                'pad': 0,
            }
            self.feature_extract = trans.MFCC(sample_rate=sr, n_mfcc=feat_dim, log_mels=False,
                                              melkwargs=melkwargs)
        else:
            if config_path is None:
                self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
            else:
                self.feature_extract = s3prl_UpstreamExpert(config_path)

            # Turn off ``fp32_attention`` on layers 12 and 24 of a 24-layer
            # encoder, where the attribute exists. Kept inside this branch
            # because only the s3prl extractor is accessed through
            # ``.model.encoder``.
            encoder_layers = self.feature_extract.model.encoder.layers
            if len(encoder_layers) == 24:
                for layer_idx in (23, 11):
                    attention = encoder_layers[layer_idx].self_attn
                    if hasattr(attention, "fp32_attention"):
                        attention.fp32_attention = False

        self.feat_num = self.get_feat_num()
        self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

        if not uses_spectral_frontend:
            # These parameter groups are frozen even when update_extract is True.
            freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
            for name, param in self.feature_extract.named_parameters():
                if any(freeze_val in name for freeze_val in freeze_list):
                    param.requires_grad = False

        if not self.update_extract:
            for param in self.feature_extract.parameters():
                param.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)

        self.channels = [channels] * 4 + [1536]

        self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
        self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
        self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
        self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)

        # Input width of the 1x1 convolution: three times ``channels``.
        # Presumably the outputs of layer2-4 are concatenated in the
        # inherited forward() -- confirm there.
        cat_channels = channels * 3
        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
        self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
| |
|
| |
|
def patched_ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
    """Create a ``patched_ECAPA_TDNN`` with the channel width fixed at 512.

    All arguments are forwarded unchanged to ``patched_ECAPA_TDNN``; the
    only difference from calling the class directly is the ``emb_dim``
    default of 256 and the fixed ``channels=512``.
    """
    model_kwargs = {
        'feat_dim': feat_dim,
        'channels': 512,
        'emb_dim': emb_dim,
        'feat_type': feat_type,
        'sr': sr,
        'feature_selection': feature_selection,
        'update_extract': update_extract,
        'config_path': config_path,
    }
    return patched_ECAPA_TDNN(**model_kwargs)
| |
|
def patch_for_npu():
    """Register this module's replacements with the patch manager and apply them.

    Two targets are registered, in this order:
    ``models.ecapa_tdnn.ECAPA_TDNN_SMALL`` and ``verification.init_model``.
    """
    replacements = (
        ('models.ecapa_tdnn.ECAPA_TDNN_SMALL', patched_ECAPA_TDNN_SMALL),
        ('verification.init_model', init_model_patched),
    )
    for target_path, replacement in replacements:
        aspm.register_patch(target_path, replacement)
    aspm.apply_patches()