| |
|
|
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
|
|
| from patch_utils import MindSpeedPatchesManager as aspm |
| import os |
| import torch |
| import torch.nn as nn |
| import logging |
| import torchaudio.transforms as trans |
| from s3prl.upstream.wavlm.expert import UpstreamExpert as s3prl_UpstreamExpert |
| from models.ecapa_tdnn import Conv1dReluBn, SE_Res2Block, AttentiveStatsPool |
| from models.ecapa_tdnn import ECAPA_TDNN_SMALL, ECAPA_TDNN |
|
|
def init_model_patched(model_name, checkpoint=None):
    """Build the speaker-verification model selected by ``model_name``.

    Args:
        model_name: One of 'unispeech_sat', 'wavlm_base_plus', 'wavlm_large',
            'hubert_large' or 'wav2vec2_xlsr'. Any other value selects the
            40-dimensional 'fbank' front end.
        checkpoint: Optional path handed to ``torch.load``. When given, the
            ``'model'`` entry of the loaded object is applied to the model
            with ``strict=False``.

    Returns:
        The object returned by ``ECAPA_TDNN_SMALL``, with checkpoint weights
        applied when ``checkpoint`` is not None.
    """
    # Only the 'wavlm_large' preset reads this; it is None when the
    # environment variable is unset.
    s3prl_path = os.environ.get("S3PRL_PATH")

    # model_name -> (feat_dim, feat_type, config_path)
    presets = {
        'unispeech_sat': (1024, 'unispeech_sat', 'config/unispeech_sat.th'),
        'wavlm_base_plus': (768, 'wavlm_base_plus', None),
        'wavlm_large': (1024, 'wavlm_large', s3prl_path),
        'hubert_large': (1024, 'hubert_large_ll60k', None),
        'wav2vec2_xlsr': (1024, 'wav2vec2_xlsr', None),
    }

    preset = presets.get(model_name)
    if preset is None:
        # Unknown names fall back to the filter-bank front end.
        model = ECAPA_TDNN_SMALL(feat_dim=40, feat_type='fbank')
    else:
        dim, kind, cfg = preset
        model = ECAPA_TDNN_SMALL(feat_dim=dim, feat_type=kind, config_path=cfg)

    if checkpoint is None:
        return model

    # The map_location callable returns each storage unchanged.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    loaded = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    model.load_state_dict(loaded['model'], strict=False)
    return model
|
|
|
|
class patched_ECAPA_TDNN(ECAPA_TDNN):
    """ECAPA_TDNN subclass whose constructor builds every sub-module itself.

    The parent's ``__init__`` is intentionally not executed; this constructor
    sets up the feature extractor, the TDNN layers, the pooling and the
    embedding head directly. When ``config_path`` is given, the extractor is
    created with the s3prl ``UpstreamExpert`` imported at the top of this file.
    """

    def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
                 feat_type='fbank', sr=16000, feature_selection="hidden_states",
                 update_extract=False, config_path=None):
        """Create the feature front end and the ECAPA-TDNN layers.

        Args:
            feat_dim: Feature dimension fed to the first TDNN layer (also the
                number of mel bins / MFCC coefficients for 'fbank' / 'mfcc').
            channels: Channel width of the four TDNN blocks.
            emb_dim: Output size of the final linear layer.
            global_context_att: Forwarded to ``AttentiveStatsPool``.
            feat_type: 'fbank', 'mfcc', or a name passed to
                ``torch.hub.load('s3prl/s3prl', ...)``.
            sr: Sample rate used to derive window and hop lengths.
            feature_selection: Stored on the instance as-is.
            update_extract: Whether extractor parameters stay trainable;
                forced to False for 'fbank' and 'mfcc'.
            config_path: If not None, passed to ``s3prl_UpstreamExpert``
                instead of loading the extractor through ``torch.hub``.
        """
        # Start the MRO lookup *after* ECAPA_TDNN so that ECAPA_TDNN.__init__
        # is skipped (presumably this reaches nn.Module.__init__ - confirm).
        super(ECAPA_TDNN, self).__init__()

        handcrafted = feat_type in ('fbank', 'mfcc')

        self.feat_type = feat_type
        self.feature_selection = feature_selection
        # Spectral front ends are never fine-tuned.
        self.update_extract = False if handcrafted else update_extract
        self.sr = sr

        frame_len = int(sr * 0.025)
        frame_shift = int(sr * 0.01)

        if feat_type == 'fbank':
            self.feature_extract = trans.MelSpectrogram(
                sample_rate=sr, n_fft=512, win_length=frame_len, hop_length=frame_shift,
                f_min=0.0, f_max=sr // 2, pad=0, n_mels=feat_dim)
        elif feat_type == 'mfcc':
            mel_options = dict(n_fft=512, win_length=frame_len, hop_length=frame_shift,
                               f_min=0.0, f_max=sr // 2, pad=0)
            self.feature_extract = trans.MFCC(
                sample_rate=sr, n_mfcc=feat_dim, log_mels=False, melkwargs=mel_options)
        else:
            if config_path is None:
                extractor = torch.hub.load('s3prl/s3prl', feat_type)
            else:
                extractor = s3prl_UpstreamExpert(config_path)
            self.feature_extract = extractor

            # On 24-layer encoders, switch off fp32_attention on layers 23 and 11
            # when the attention module exposes that flag.
            for layer_idx in (23, 11):
                encoder_layers = self.feature_extract.model.encoder.layers
                if len(encoder_layers) == 24 and hasattr(encoder_layers[layer_idx].self_attn, "fp32_attention"):
                    encoder_layers[layer_idx].self_attn.fp32_attention = False

            # NOTE(review): assumed to belong to this branch, as nesting was
            # reconstructed from syntax - confirm against the upstream class.
            self.feat_num = self.get_feat_num()
            self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

        if not handcrafted:
            # Parameters whose name contains one of these substrings are always frozen.
            frozen_keys = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
            for param_name, weight in self.feature_extract.named_parameters():
                if any(key in param_name for key in frozen_keys):
                    weight.requires_grad = False

        if not self.update_extract:
            for weight in self.feature_extract.parameters():
                weight.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)

        self.channels = [channels] * 4 + [1536]

        self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
        # layer2..layer4 differ only in their channel pair and in the shared
        # padding/dilation value (2, 3, 4).
        for position, dilation in enumerate((2, 3, 4), start=1):
            block = SE_Res2Block(self.channels[position - 1], self.channels[position],
                                 kernel_size=3, stride=1, padding=dilation, dilation=dilation,
                                 scale=8, se_bottleneck_dim=128)
            setattr(self, 'layer%d' % (position + 1), block)

        out_channels = self.channels[-1]
        self.conv = nn.Conv1d(channels * 3, out_channels, kernel_size=1)
        self.pooling = AttentiveStatsPool(out_channels, attention_channels=128,
                                          global_context_att=global_context_att)
        self.bn = nn.BatchNorm1d(out_channels * 2)
        self.linear = nn.Linear(out_channels * 2, emb_dim)
|
|
|
|
def patched_ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
    """Return a ``patched_ECAPA_TDNN`` with the channel width fixed at 512.

    Every argument is forwarded unchanged to ``patched_ECAPA_TDNN``.
    """
    options = dict(
        feat_dim=feat_dim,
        channels=512,
        emb_dim=emb_dim,
        feat_type=feat_type,
        sr=sr,
        feature_selection=feature_selection,
        update_extract=update_extract,
        config_path=config_path,
    )
    return patched_ECAPA_TDNN(**options)
|
|
def patch_for_npu():
    """Register this module's replacements with the patch manager and apply them.

    Two dotted targets are registered: ``models.ecapa_tdnn.ECAPA_TDNN_SMALL``
    and ``verification.init_model``. ``apply_patches`` is then called once.
    """
    replacements = (
        ('models.ecapa_tdnn.ECAPA_TDNN_SMALL', patched_ECAPA_TDNN_SMALL),
        ('verification.init_model', init_model_patched),
    )
    for dotted_target, substitute in replacements:
        aspm.register_patch(dotted_target, substitute)
    aspm.apply_patches()