| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn import LayerNorm |
| import torchaudio.compliance.kaldi as ta_kaldi |
|
|
| from beats.backbone import ( |
| TransformerEncoder, |
| ) |
|
|
| import logging |
| from typing import Optional |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class BEATsConfig:
    """Hyper-parameter container for the BEATs model.

    Instances start from the built-in defaults below; a dict passed to the
    constructor (or to :meth:`update`) overlays those defaults, and may also
    introduce attributes that have no default (e.g. keys restored from a
    checkpoint).
    """

    def __init__(self, cfg=None):
        """Populate the default hyper-parameters, then overlay *cfg* if given."""
        defaults = {
            # --- patch embedding front-end ---
            "input_patch_size": -1,
            "embed_dim": 512,
            "conv_bias": False,
            # --- transformer encoder ---
            "encoder_layers": 12,
            "encoder_embed_dim": 768,
            "encoder_ffn_embed_dim": 3072,
            "encoder_attention_heads": 12,
            "activation_fn": "gelu",
            # --- normalization / training schedule ---
            "layer_wise_gradient_decay_ratio": 1.0,
            "layer_norm_first": False,
            "deep_norm": False,
            # --- dropouts ---
            "dropout": 0.1,
            "attention_dropout": 0.1,
            "activation_dropout": 0.0,
            "encoder_layerdrop": 0.0,
            "dropout_input": 0.0,
            # --- convolutional positional embedding ---
            "conv_pos": 128,
            "conv_pos_groups": 16,
            # --- relative position embedding ---
            "relative_position_embedding": False,
            "num_buckets": 320,
            "max_distance": 1280,
            "gru_rel_pos": False,
            # --- fine-tuned classification head ---
            "finetuned_model": False,
            "predictor_dropout": 0.1,
            "predictor_class": 527,
        }
        for name, value in defaults.items():
            setattr(self, name, value)

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        """Overwrite attributes in place with every entry of *cfg*."""
        self.__dict__.update(cfg)
|
|
|
|
class BEATs(nn.Module):
    """BEATs audio encoder: fbank front-end + patch embedding + Transformer.

    Raw 16 kHz waveforms are converted to 128-bin log-mel filterbanks
    (via ``torchaudio.compliance.kaldi``), cut into non-overlapping square
    patches by a strided Conv2d, and encoded by a ``TransformerEncoder``.
    When ``cfg.finetuned_model`` is set, a linear head plus sigmoid produces
    per-class scores (``cfg.predictor_class`` outputs; presumably AudioSet's
    527 classes given the default — confirm against the checkpoint).
    """

    def __init__(
        self,
        cfg: BEATsConfig,
    ) -> None:
        super().__init__()
        logger.info(f"BEATs Config: {cfg.__dict__}")

        self.cfg = cfg

        # Patch-embedding width; project up to the encoder width only when
        # the two dimensions differ (otherwise no projection layer at all).
        self.embed = cfg.embed_dim
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        # kernel_size == stride, so the Conv2d tiles the (time, mel) plane
        # into non-overlapping input_patch_size x input_patch_size patches.
        self.input_patch_size = cfg.input_patch_size
        self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
                                         bias=cfg.conv_bias)

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # deep_norm and layer_norm_first are mutually exclusive encoder modes.
        assert not cfg.deep_norm or not cfg.layer_norm_first
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        # The classification head only exists in fine-tuned checkpoints;
        # pre-trained encoders leave self.predictor as None.
        if cfg.finetuned_model:
            self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
            self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
        else:
            self.predictor = None

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Downsample a boolean padding mask to match features' time length.

        ``padding_mask`` is (batch, T_in) with True marking padded positions
        (see extract_features, which zeroes logits where the mask is True).
        The mask is truncated so T_in divides features.size(1), regrouped to
        (batch, features_T, group), and reduced with all(-1): an output frame
        counts as padded only if *every* input position it covers is padded.
        """
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def preprocess(
        self,
        source: torch.Tensor,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        """Convert a batch of waveforms to normalized log-mel filterbanks.

        Each waveform is scaled by 2**15 — presumably because the Kaldi
        fbank routine expects 16-bit-PCM-range samples while the input is
        float in [-1, 1]; confirm against the data pipeline.  Features are
        128 mel bins at 16 kHz with 25 ms windows and a 10 ms shift, then
        standardized as (x - mean) / (2 * std) with corpus-level statistics
        (the defaults here match those used in extract_features).
        Returns a (batch, frames, 128) tensor.
        """
        fbanks = []
        for waveform in source:
            waveform = waveform.unsqueeze(0) * 2 ** 15
            fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
            fbanks.append(fbank)
        fbank = torch.stack(fbanks, dim=0)
        fbank = (fbank - fbank_mean) / (2 * fbank_std)
        return fbank

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
        feature_only: bool = False,
    ):
        """Encode waveforms and, if a predictor head exists, classify them.

        Args:
            source: raw waveforms, assumed (batch, samples) at 16 kHz —
                see preprocess.
            padding_mask: optional boolean mask over input positions
                (True = padding); downsampled in two stages to track each
                time-resolution change below.
            fbank_mean, fbank_std: normalization statistics for preprocess.
            feature_only: skip the classification head even when present.

        Returns:
            ``(lprobs, padding_mask)`` where lprobs are per-class sigmoid
            scores pooled over unpadded frames, when the head is applied;
            otherwise ``(encoder_features, padding_mask)``.
        """
        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to(torch.float32)

        # First mask downsampling: input resolution -> fbank frame resolution.
        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        # (B, frames, mels) -> (B, 1, frames, mels) for the Conv2d, whose
        # output grid of patches is flattened into one token sequence and
        # transposed to (B, tokens, embed).
        fbank = fbank.unsqueeze(1)
        features = self.patch_embedding(fbank)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        # Second mask downsampling: fbank frames -> patch-token sequence.
        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        if not feature_only and self.predictor is not None:
            x = self.predictor_dropout(x)
            logits = self.predictor(x)

            # Masked mean-pool over the token axis: zero out padded
            # positions, sum, then divide by the per-sample count of
            # valid (unpadded) tokens.
            if padding_mask is not None and padding_mask.any():
                logits[padding_mask] = 0
                logits = logits.sum(dim=1)
                logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
            else:
                logits = logits.mean(dim=1)

            # Independent per-class sigmoid (multi-label), not a softmax.
            lprobs = torch.sigmoid(logits)

            return lprobs, padding_mask
        else:
            return x, padding_mask