| """GLAP (Generalized Language Audio Pretraining) HuggingFace model. |
| |
| Audio encoder adapted from dasheng-denoiser (Apache 2.0). |
| Text encoder adapted from SONAR standalone (Apache 2.0). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from pathlib import Path |
| from typing import List, Optional, Sequence |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import Tensor |
|
|
| from einops import rearrange |
| from einops.layers.torch import Rearrange |
| from transformers import PreTrainedModel |
|
|
| try: |
| from huggingface_hub import hf_hub_download |
| except ImportError: |
| hf_hub_download = None |
|
|
| from .configuration_glap import GlapConfig |


# ---------------------------------------------------------------------------
# Audio encoder (adapted from dasheng-denoiser)
# ---------------------------------------------------------------------------


class FrontEnd(nn.Sequential):
    def __init__(
        self,
        f_min: int = 0,
        sample_rate: int = 16000,
        win_size: int = 512,
        center: bool = True,
        n_fft: int = 512,
        f_max: Optional[int] = 8000,
        hop_size: int = 160,
        n_mels: int = 64,
    ):
        # Import lazily so torchaudio remains an optional dependency.
        import torchaudio.transforms as audio_transforms

        self.f_min = f_min
        self.sample_rate = sample_rate
        self.win_size = win_size
        self.center = center
        self.n_fft = n_fft
        self.f_max = f_max
        self.hop_size = hop_size
        self.n_mels = n_mels

        with torch.device("cpu"):
            super().__init__(
                audio_transforms.MelSpectrogram(
                    f_min=self.f_min,
                    sample_rate=self.sample_rate,
                    win_length=self.win_size,
                    center=self.center,
                    n_fft=self.n_fft,
                    f_max=self.f_max,
                    hop_length=self.hop_size,
                    n_mels=self.n_mels,
                ),
                audio_transforms.AmplitudeToDB(top_db=120),
            )

    # The mel front end is numerically sensitive, so it always runs in fp32.
    @torch.autocast(enabled=False, device_type="cuda")
    def forward(self, x, attention_mask=None):
        features = super().forward(x)
        if attention_mask is not None:
            # Downsample the sample-level mask to frame resolution.
            lengths = attention_mask.float().sum(-1) // self.hop_size
            attention_mask = (
                torch.arange(features.shape[-1], device=features.device)
                < lengths.unsqueeze(-1)
            ).int()
        return features, attention_mask
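

# Shape sketch for FrontEnd (assuming the defaults above; sizes illustrative):
# a 16 kHz mono waveform of shape (B, N) maps to a log-mel spectrogram of
# shape (B, 64, 1 + N // 160) in dB, with the dynamic range capped at 120 dB
# by AmplitudeToDB.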


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: type[nn.Module] = nn.GELU,
        drop: float = 0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class AudioAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale

        if mask is not None:
            # Accept either a boolean padding mask (True = pad) or a 0/1
            # validity mask (0 = pad) and normalize to boolean padding.
            if mask.dtype != torch.bool:
                padding_mask = mask == 0
            else:
                padding_mask = mask
            padding_mask = padding_mask.view(B, 1, 1, N)
            attn = attn.masked_fill(padding_mask, float("-inf"))

        # nan_to_num guards rows where every key position was masked out.
        attn = attn.softmax(dim=-1).nan_to_num()
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))


class AudioBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = AudioAttention(dim, num_heads, qkv_bias, attn_drop, drop)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=nn.GELU,
            drop=drop,
        )

    def forward(self, x, mask=None):
        x = x + self.attn(self.norm1(x), mask=mask)
        x = x + self.mlp(self.norm2(x))
        return x


class AudioPatchEmbed(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Time stride of the patch conv, used to downsample attention masks.
        self.stride = kwargs.get("stride", [None, 4])[-1]
        self.proj = nn.Conv2d(*args, **kwargs)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        x = self.proj(x)
        if attention_mask is not None:
            lengths = attention_mask.float().sum(-1) // self.stride
            attention_mask = (
                torch.arange(x.shape[-1], device=x.device) < lengths.unsqueeze(-1)
            ).int()
        return x, attention_mask


class DashengAudioEncoder(nn.Module):
    """Dasheng audio encoder matching the original DashengWrapper.

    Produces a single (B, embed_dim) embedding per audio input.
    Pads the spectrogram to a multiple of target_length, splits it into
    chunks, processes each chunk independently through the Transformer,
    then mean-pools across chunks.
    """

    def __init__(
        self,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        patch_size: Optional[List[int]] = None,
        patch_stride: Optional[List[int]] = None,
        target_length: int = 1008,
    ):
        super().__init__()
        patch_size = patch_size or [64, 4]
        patch_stride = patch_stride or [64, 4]
        self.embed_dim = embed_dim
        self.target_length = target_length
        self.patch_stride = patch_stride
        self.time_patches = patch_stride[-1]
        self.max_t_tokens = target_length // self.time_patches

        self.front_end = FrontEnd()
        self.patch_embed = AudioPatchEmbed(
            1, embed_dim, kernel_size=patch_size, stride=patch_stride
        )
        self.init_bn = nn.Sequential(
            Rearrange("b c f t -> b f c t"),
            nn.BatchNorm2d(self.front_end.n_mels, momentum=0.01),
            Rearrange("b f c t -> b c f t"),
        )

        self.time_pos_embed = nn.Parameter(
            torch.randn(1, embed_dim, 1, self.max_t_tokens) * 0.02
        )
        self.freq_pos_embed = nn.Parameter(torch.randn(1, embed_dim, 1, 1) * 0.02)

        self.blocks = nn.ModuleList(
            [AudioBlock(embed_dim, num_heads) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def _forward_chunk(self, x, attention_mask=None):
        x, attention_mask = self.patch_embed(x, attention_mask)
        t = x.shape[-1]
        x = x + self.time_pos_embed[:, :, :, :t] + self.freq_pos_embed
        x = rearrange(x, "b c f t -> b (f t) c")
        for block in self.blocks:
            x = block(x, mask=attention_mask)
        x = self.norm(x)
        return x

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # (B, T) waveform -> (B, 1, n_mels, frames) normalized log-mel input.
        x, attention_mask = self.front_end(x, attention_mask)
        x = rearrange(x, "b f t -> b 1 f t")
        x = self.init_bn(x)

        # Pad the time axis up to a multiple of target_length.
        if x.shape[-1] > self.target_length:
            remainder = x.shape[-1] % self.target_length
            if remainder != 0:
                pad_amount = self.target_length - remainder
                x = F.pad(x, (0, pad_amount))

        # Split into fixed-size chunks along time.
        input_splits = x.split(self.target_length, dim=-1)

        # Encode each chunk independently and mean-pool its tokens.
        outputs = []
        chunk_size_in_patches = self.target_length // self.patch_stride[-1]
        for input_split_x in input_splits:
            output = self._forward_chunk(input_split_x, attention_mask=None)
            # (B, tokens, C): average tokens within each sub-chunk, then
            # average the sub-chunk means.
            chunks = output.split(chunk_size_in_patches, dim=1)
            chunk_means = [c.mean(1) for c in chunks]
            outputs.append(torch.stack(chunk_means).mean(0))

        # Average the per-chunk embeddings into a single (B, C) embedding.
        emb = torch.stack(outputs).mean(0)
        return emb


# ---------------------------------------------------------------------------
# Text encoder (adapted from SONAR)
# ---------------------------------------------------------------------------


class SinusoidalPositionEncoder(nn.Module):
    def __init__(self, encoding_dim: int, max_seq_len: int, _legacy_pad_idx: int = 1):
        super().__init__()
        assert encoding_dim % 2 == 0
        self.encoding_dim = encoding_dim
        self.max_seq_len = max_seq_len
        self._legacy_pad_idx = _legacy_pad_idx
        start_step = 1 + _legacy_pad_idx
        steps = torch.arange(start_step, start_step + max_seq_len, dtype=torch.float32)
        self.register_buffer(
            "freqs", self._build_freqs(steps, encoding_dim), persistent=False
        )

    @staticmethod
    def _build_freqs(steps: Tensor, encoding_dim: int) -> Tensor:
        num_sin = encoding_dim // 2
        indices = torch.arange(num_sin, dtype=torch.float32)
        freq_vals = torch.exp(indices * -math.log(10000.0) / (num_sin - 1))
        l_half = torch.outer(steps, freq_vals)
        r_half = l_half[:, : encoding_dim - num_sin].clone()
        return torch.cat([l_half.sin(), r_half.cos()], dim=-1)
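
    # The table above follows the fairseq(2) layout rather than the
    # interleaved "Attention Is All You Need" one: for position p and
    # sine index i,
    #   freqs[p, i]         = sin(p * 10000^(-i / (dim/2 - 1)))
    #   freqs[p, dim/2 + i] = cos(p * 10000^(-i / (dim/2 - 1)))
    # with positions starting at 1 + _legacy_pad_idx, mirroring fairseq's
    # legacy padding offset.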

    def forward(self, seqs: Tensor) -> Tensor:
        seq_len = seqs.size(-2)
        return (seqs.float() + self.freqs[:seq_len]).type_as(seqs)


class SonarMultiheadAttention(nn.Module):
    def __init__(self, model_dim: int, num_heads: int, dropout_p: float = 0.0):
        super().__init__()
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads
        assert model_dim % num_heads == 0

        self.q_proj = nn.Linear(model_dim, model_dim, bias=True)
        self.k_proj = nn.Linear(model_dim, model_dim, bias=True)
        self.v_proj = nn.Linear(model_dim, model_dim, bias=True)
        self.output_proj = nn.Linear(model_dim, model_dim, bias=True)
        self.attn_dropout_p = dropout_p

    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        values: Tensor,
        padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seq_len, _ = queries.shape

        q = (
            self.q_proj(queries)
            .view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        k = (
            self.k_proj(keys)
            .view(bsz, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        v = (
            self.v_proj(values)
            .view(bsz, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        scale = self.head_dim**-0.5
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale

        # padding_mask marks padded key positions (True = pad) to be ignored.
        if padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                padding_mask[:, None, None, :], float("-inf")
            )

        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)

        if self.training and self.attn_dropout_p > 0.0:
            attn_weights = F.dropout(attn_weights, p=self.attn_dropout_p)

        attn = torch.matmul(attn_weights, v)
        attn = attn.transpose(1, 2).contiguous().view(bsz, seq_len, self.model_dim)
        return self.output_proj(attn)


class _FeedForwardNetwork(nn.Module):
    def __init__(self, model_dim: int, inner_dim: int, dropout_p: float = 0.1):
        super().__init__()
        self.inner_proj = nn.Linear(model_dim, inner_dim, bias=True)
        self.output_proj = nn.Linear(inner_dim, model_dim, bias=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x: Tensor) -> Tensor:
        x = self.inner_proj(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.output_proj(x)
        return x


class SonarTransformerEncoderLayer(nn.Module):
    def __init__(
        self, model_dim: int, num_heads: int, ffn_inner_dim: int, dropout_p: float = 0.1
    ):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(model_dim)
        self.self_attn = SonarMultiheadAttention(
            model_dim, num_heads, dropout_p=dropout_p
        )
        self.ffn_layer_norm = nn.LayerNorm(model_dim)
        self.ffn = _FeedForwardNetwork(model_dim, ffn_inner_dim, dropout_p)

    def forward(self, seqs: Tensor, padding_mask: Optional[Tensor] = None) -> Tensor:
        # Pre-LN residual blocks: normalize, transform, then add the residual.
        residual = seqs
        seqs = self.self_attn_layer_norm(seqs)
        seqs = self.self_attn(seqs, seqs, seqs, padding_mask)
        seqs = seqs + residual

        residual = seqs
        seqs = self.ffn_layer_norm(seqs)
        seqs = self.ffn(seqs)
        seqs = seqs + residual
        return seqs


class _SonarTransformerEncoder(nn.Module):
    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers

    def forward(self, seqs: Tensor, padding_mask: Optional[Tensor] = None) -> Tensor:
        for layer in self.layers:
            seqs = layer(seqs, padding_mask)
        return seqs


class _SonarEmbeddingFrontend(nn.Module):
    def __init__(
        self,
        embed: nn.Embedding,
        pos_encoder: SinusoidalPositionEncoder,
        dropout_p: float = 0.1,
    ):
        super().__init__()
        self.embed = embed
        self.pos_encoder = pos_encoder
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, token_ids: Tensor) -> Tensor:
        seqs = self.embed(token_ids)
        # Scale embeddings by sqrt(model_dim) before adding positions
        # (fairseq convention).
        seqs = seqs * math.sqrt(seqs.size(-1))
        seqs = self.pos_encoder(seqs)
        seqs = self.dropout(seqs)
        return seqs


class SonarTextEncoder(nn.Module):
    """24-layer SONAR text encoder with sinusoidal PE and mean pooling."""

    def __init__(
        self,
        vocab_size: int = 256206,
        model_dim: int = 1024,
        num_layers: int = 24,
        num_heads: int = 16,
        ffn_inner_dim: int = 8192,
        max_seq_len: int = 514,
        pad_idx: int = 0,
        dropout_p: float = 0.1,
    ):
        super().__init__()
        self.model_dim = model_dim
        self.pad_idx = pad_idx

        embed = nn.Embedding(vocab_size, model_dim, padding_idx=pad_idx)
        pos_encoder = SinusoidalPositionEncoder(
            model_dim, max_seq_len, _legacy_pad_idx=1
        )
        self.encoder_frontend = _SonarEmbeddingFrontend(embed, pos_encoder, dropout_p)

        layers = nn.ModuleList(
            [
                SonarTransformerEncoderLayer(
                    model_dim, num_heads, ffn_inner_dim, dropout_p
                )
                for _ in range(num_layers)
            ]
        )
        self.encoder = _SonarTransformerEncoder(layers)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(
        self,
        token_ids: Tensor,
        padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        seqs = self.encoder_frontend(token_ids)
        seqs = self.encoder(seqs, padding_mask)
        seqs = self.layer_norm(seqs)

        # Mean-pool over time; with a padding mask (True = pad), average
        # only the valid positions.
        if padding_mask is None:
            sentence_embeddings = seqs.sum(dim=1) / (seqs.size(1) + 1e-7)
        else:
            mask = (~padding_mask).unsqueeze(-1).float()
            seqs = seqs * mask
            lengths = mask.sum(dim=1).clamp(min=1e-7)
            sentence_embeddings = seqs.sum(dim=1) / lengths

        return sentence_embeddings


# ---------------------------------------------------------------------------
# Tokenizer (standalone NLLB, sentencepiece-based)
# ---------------------------------------------------------------------------


class NllbTokenizer:
    """Standalone NLLB tokenizer using sentencepiece."""

    def __init__(self, model_path: str | Path, langs: Optional[List[str]] = None):
        try:
            import sentencepiece as spm
        except ImportError:
            raise ImportError("sentencepiece is required: pip install sentencepiece")

        self.sp = spm.SentencePieceProcessor()
        if not self.sp.load(str(model_path)):
            raise RuntimeError(f"Failed to load SentencePiece model from {model_path}")

        # NLLB control symbols occupy the first four ids.
        self.pad_idx = 0
        self.unk_idx = 1
        self.bos_idx = 2
        self.eos_idx = 3

        self._lang_token_to_idx = _NLLB_LANG_TOKEN_IDS

    @property
    def vocab_size(self) -> int:
        # SentencePiece pieces plus the 206 special tokens appended by NLLB.
        return self.sp.get_piece_size() + 206

    def create_encoder(self, lang: str = "eng_Latn"):
        lang_idx = self._lang_token_to_idx.get(lang)
        eos_idx = self.eos_idx

        def encode(text: str) -> List[int]:
            spm_ids = self.sp.encode(text, out_type=int)
            # Offset piece ids by one to match the fairseq-style NLLB vocab.
            content_ids = [tid + 1 for tid in spm_ids]
            if lang_idx is not None:
                token_ids = [lang_idx] + content_ids
            else:
                token_ids = content_ids
            token_ids.append(eos_idx)
            return token_ids

        return encode
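

# A minimal usage sketch for the tokenizer (the path below is a placeholder
# for a local copy of the NLLB SentencePiece model, not a file shipped here):
#
#   tokenizer = NllbTokenizer("/path/to/sentencepiece.source.256000.model")
#   encode = tokenizer.create_encoder(lang="eng_Latn")
#   ids = encode("a dog barks")   # [256047, <piece ids + 1>..., 3]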


# Language-token ids appended to the NLLB vocabulary (FLORES-200 codes).
_NLLB_LANG_TOKEN_IDS = {
    "ace_Arab": 256001,
    "ace_Latn": 256002,
    "acm_Arab": 256003,
    "acq_Arab": 256004,
    "aeb_Arab": 256005,
    "afr_Latn": 256006,
    "ajp_Arab": 256007,
    "aka_Latn": 256008,
    "amh_Ethi": 256009,
    "apc_Arab": 256010,
    "arb_Arab": 256011,
    "ars_Arab": 256012,
    "ary_Arab": 256013,
    "arz_Arab": 256014,
    "asm_Beng": 256015,
    "ast_Latn": 256016,
    "awa_Deva": 256017,
    "ayr_Latn": 256018,
    "azb_Arab": 256019,
    "azj_Latn": 256020,
    "bak_Cyrl": 256021,
    "bam_Latn": 256022,
    "ban_Latn": 256023,
    "bel_Cyrl": 256024,
    "bem_Latn": 256025,
    "ben_Beng": 256026,
    "bho_Deva": 256027,
    "bjn_Arab": 256028,
    "bjn_Latn": 256029,
    "bod_Tibt": 256030,
    "bos_Latn": 256031,
    "bug_Latn": 256032,
    "bul_Cyrl": 256033,
    "cat_Latn": 256034,
    "ceb_Latn": 256035,
    "ces_Latn": 256036,
    "cjk_Latn": 256037,
    "ckb_Arab": 256038,
    "crh_Latn": 256039,
    "cym_Latn": 256040,
    "dan_Latn": 256041,
    "deu_Latn": 256042,
    "dik_Latn": 256043,
    "dyu_Latn": 256044,
    "dzo_Tibt": 256045,
    "ell_Grek": 256046,
    "eng_Latn": 256047,
    "epo_Latn": 256048,
    "est_Latn": 256049,
    "eus_Latn": 256050,
    "ewe_Latn": 256051,
    "fao_Latn": 256052,
    "pes_Arab": 256053,
    "fij_Latn": 256054,
    "fin_Latn": 256055,
    "fon_Latn": 256056,
    "fra_Latn": 256057,
    "fur_Latn": 256058,
    "fuv_Latn": 256059,
    "gla_Latn": 256060,
    "gle_Latn": 256061,
    "glg_Latn": 256062,
    "grn_Latn": 256063,
    "guj_Gujr": 256064,
    "hat_Latn": 256065,
    "hau_Latn": 256066,
    "heb_Hebr": 256067,
    "hin_Deva": 256068,
    "hne_Deva": 256069,
    "hrv_Latn": 256070,
    "hun_Latn": 256071,
    "hye_Armn": 256072,
    "ibo_Latn": 256073,
    "ilo_Latn": 256074,
    "ind_Latn": 256075,
    "isl_Latn": 256076,
    "ita_Latn": 256077,
    "jav_Latn": 256078,
    "jpn_Jpan": 256079,
    "kab_Latn": 256080,
    "kac_Latn": 256081,
    "kam_Latn": 256082,
    "kan_Knda": 256083,
    "kas_Arab": 256084,
    "kas_Deva": 256085,
    "kat_Geor": 256086,
    "knc_Arab": 256087,
    "knc_Latn": 256088,
    "kaz_Cyrl": 256089,
    "kbp_Latn": 256090,
    "kea_Latn": 256091,
    "khm_Khmr": 256092,
    "kik_Latn": 256093,
    "kin_Latn": 256094,
    "kir_Cyrl": 256095,
    "kmb_Latn": 256096,
    "kon_Latn": 256097,
    "kor_Hang": 256098,
    "kmr_Latn": 256099,
    "lao_Laoo": 256100,
    "lvs_Latn": 256101,
    "lij_Latn": 256102,
    "lim_Latn": 256103,
    "lin_Latn": 256104,
    "lit_Latn": 256105,
    "lmo_Latn": 256106,
    "ltg_Latn": 256107,
    "ltz_Latn": 256108,
    "lua_Latn": 256109,
    "lug_Latn": 256110,
    "luo_Latn": 256111,
    "lus_Latn": 256112,
    "mag_Deva": 256113,
    "mai_Deva": 256114,
    "mal_Mlym": 256115,
    "mar_Deva": 256116,
    "min_Latn": 256117,
    "mkd_Cyrl": 256118,
    "plt_Latn": 256119,
    "mlt_Latn": 256120,
    "mni_Beng": 256121,
    "khk_Cyrl": 256122,
    "mos_Latn": 256123,
    "mri_Latn": 256124,
    "zsm_Latn": 256125,
    "mya_Mymr": 256126,
    "nld_Latn": 256127,
    "nno_Latn": 256128,
    "nob_Latn": 256129,
    "npi_Deva": 256130,
    "nso_Latn": 256131,
    "nus_Latn": 256132,
    "nya_Latn": 256133,
    "oci_Latn": 256134,
    "gaz_Latn": 256135,
    "ory_Orya": 256136,
    "pag_Latn": 256137,
    "pan_Guru": 256138,
    "pap_Latn": 256139,
    "pol_Latn": 256140,
    "por_Latn": 256141,
    "prs_Arab": 256142,
    "pbt_Arab": 256143,
    "quy_Latn": 256144,
    "ron_Latn": 256145,
    "run_Latn": 256146,
    "rus_Cyrl": 256147,
    "sag_Latn": 256148,
    "san_Deva": 256149,
    "sat_Beng": 256150,
    "scn_Latn": 256151,
    "shn_Mymr": 256152,
    "sin_Sinh": 256153,
    "slk_Latn": 256154,
    "slv_Latn": 256155,
    "smo_Latn": 256156,
    "sna_Latn": 256157,
    "snd_Arab": 256158,
    "som_Latn": 256159,
    "sot_Latn": 256160,
    "spa_Latn": 256161,
    "als_Latn": 256162,
    "srd_Latn": 256163,
    "srp_Cyrl": 256164,
    "ssw_Latn": 256165,
    "sun_Latn": 256166,
    "swe_Latn": 256167,
    "swh_Latn": 256168,
    "szl_Latn": 256169,
    "tam_Taml": 256170,
    "tat_Cyrl": 256171,
    "tel_Telu": 256172,
    "tgk_Cyrl": 256173,
    "tgl_Latn": 256174,
    "tha_Thai": 256175,
    "tir_Ethi": 256176,
    "taq_Latn": 256177,
    "taq_Tfng": 256178,
    "tpi_Latn": 256179,
    "tsn_Latn": 256180,
    "tso_Latn": 256181,
    "tuk_Latn": 256182,
    "tum_Latn": 256183,
    "tur_Latn": 256184,
    "twi_Latn": 256185,
    "tzm_Tfng": 256186,
    "uig_Arab": 256187,
    "ukr_Cyrl": 256188,
    "umb_Latn": 256189,
    "urd_Arab": 256190,
    "uzn_Latn": 256191,
    "vec_Latn": 256192,
    "vie_Latn": 256193,
    "war_Latn": 256194,
    "wol_Latn": 256195,
    "xho_Latn": 256196,
    "ydd_Hebr": 256197,
    "yor_Latn": 256198,
    "yue_Hant": 256199,
    "zho_Hans": 256200,
    "zho_Hant": 256201,
    "zul_Latn": 256202,
}


# ---------------------------------------------------------------------------
# GLAP model
# ---------------------------------------------------------------------------


class GlapModel(PreTrainedModel):
    config_class = GlapConfig

    def __init__(self, config: GlapConfig):
        super().__init__(config)
        self.config = config

        # Audio branch.
        self.audio_encoder = DashengAudioEncoder(
            embed_dim=config.audio_embed_dim,
            depth=config.audio_depth,
            num_heads=config.audio_num_heads,
            patch_size=config.patch_size,
            patch_stride=config.patch_stride,
            target_length=config.target_length,
        )

        # Text branch.
        self.text_encoder = SonarTextEncoder(
            vocab_size=config.text_vocab_size,
            model_dim=config.text_model_dim,
            num_layers=config.text_num_layers,
            num_heads=config.text_num_heads,
            ffn_inner_dim=config.text_ffn_inner_dim,
            max_seq_len=config.text_max_seq_len,
            pad_idx=config.text_pad_idx,
            dropout_p=config.text_dropout_p,
        )

        # Projection heads into the shared embedding space.
        self.audio_proj = nn.Sequential(
            nn.Linear(config.audio_embed_dim, config.embed_size),
            nn.ReLU(),
            nn.Linear(config.embed_size, config.embed_size),
        )
        self.text_proj = nn.Sequential(
            nn.Linear(config.text_model_dim, config.embed_size),
            nn.ReLU(),
            nn.Linear(config.embed_size, config.embed_size),
        )

        self.tokenizer: Optional[NllbTokenizer] = None
        self.post_init()

    def _init_weights(self, module):
        # Rebuild the (non-persistent) sinusoidal table after weight init.
        if isinstance(module, SinusoidalPositionEncoder):
            with torch.no_grad():
                start_step = 1 + module._legacy_pad_idx
                steps = torch.arange(
                    start_step,
                    start_step + module.max_seq_len,
                    dtype=torch.float32,
                )
                module.freqs.copy_(module._build_freqs(steps, module.encoding_dim))

    def _get_tokenizer(self) -> NllbTokenizer:
        if self.tokenizer is None:
            tokenizer_filename = "sentencepiece.source.256000.model"
            tokenizer_path: Optional[Path | str] = None

            # 1) Look inside the checkpoint directory.
            model_dir = Path(self.config._name_or_path)
            candidate = model_dir / tokenizer_filename
            if candidate.exists():
                tokenizer_path = candidate

            # 2) Fall back to a copy next to this file.
            if tokenizer_path is None:
                candidate = Path(__file__).parent / tokenizer_filename
                if candidate.exists():
                    tokenizer_path = candidate

            # 3) As a last resort, try downloading from the Hub.
            if tokenizer_path is None and hf_hub_download is not None:
                try:
                    tokenizer_path = hf_hub_download(
                        repo_id=self.config._name_or_path,
                        filename=tokenizer_filename,
                    )
                except Exception:
                    pass

            if tokenizer_path is None:
                raise FileNotFoundError(
                    f"Could not find {tokenizer_filename}. "
                    f"Searched {self.config._name_or_path} and "
                    f"{Path(__file__).parent}."
                )

            self.tokenizer = NllbTokenizer(tokenizer_path)
        return self.tokenizer

    def encode_audio(
        self,
        audio: torch.Tensor,
        audio_length: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # audio_length is accepted for API compatibility but is currently
        # unused; the encoder chunks and mean-pools the full input.
        audio_embeds = self.audio_encoder(audio)
        audio_embeds = F.normalize(self.audio_proj(audio_embeds), dim=-1)
        return audio_embeds

    def encode_text(
        self,
        text: Sequence[str],
        source_lang: str = "eng_Latn",
    ) -> torch.Tensor:
        tokenizer = self._get_tokenizer()
        encoder_fn = tokenizer.create_encoder(lang=source_lang)

        all_token_ids: List[List[int]] = []
        max_seq_len = self.config.text_max_seq_len
        for t in text:
            token_ids = encoder_fn(t)[:max_seq_len]
            all_token_ids.append(token_ids)

        max_len = max(len(ids) for ids in all_token_ids) if all_token_ids else 0
        batch_size = len(all_token_ids)

        device = self.text_proj[0].weight.device

        padded_ids = torch.full(
            (batch_size, max_len),
            tokenizer.pad_idx,
            dtype=torch.long,
            device="cpu",
        )
        padding_mask = torch.zeros(batch_size, max_len, dtype=torch.bool, device="cpu")

        for i, ids in enumerate(all_token_ids):
            length = len(ids)
            padded_ids[i, :length] = torch.tensor(ids, dtype=torch.long)
            padding_mask[i, length:] = True

        # Move the batch onto the model's device before encoding; feeding CPU
        # tensors to a CUDA-resident text encoder would fail.
        padded_ids = padded_ids.to(device)
        padding_mask = padding_mask.to(device)

        self.text_encoder.eval()
        with torch.no_grad():
            sentence_embeddings = self.text_encoder(padded_ids, padding_mask)

        text_embeds = F.normalize(self.text_proj(sentence_embeddings), dim=-1)
        return text_embeds

    def get_audio_features(
        self,
        audio: torch.Tensor,
        audio_length: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        return self.encode_audio(audio, audio_length)

    def get_text_features(
        self,
        text: Sequence[str],
        source_lang: str = "eng_Latn",
        **kwargs,
    ) -> torch.Tensor:
        return self.encode_text(text, source_lang=source_lang)

    def forward(
        self,
        audio: Optional[torch.Tensor] = None,
        text: Optional[Sequence[str]] = None,
        audio_length: Optional[torch.Tensor] = None,
        source_lang: str = "eng_Latn",
        **kwargs,
    ):
        audio_embeds = None
        text_embeds = None

        if audio is not None:
            audio_embeds = self.encode_audio(audio, audio_length)
        if text is not None:
            text_embeds = self.encode_text(text, source_lang=source_lang)

        return audio_embeds, text_embeds

    def score(self, audio_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        # Both embeddings are L2-normalized, so this is cosine similarity
        # scaled by 100 (CLIP-style logits).
        return 100 * (audio_emb @ text_emb.T)

    def score_forward(
        self,
        audio: torch.Tensor,
        text: Sequence[str],
        audio_length: Optional[torch.Tensor] = None,
        source_lang: str = "eng_Latn",
    ) -> torch.Tensor:
        audio_emb, text_emb = self.forward(audio, text, audio_length, source_lang)
        return self.score(audio_emb, text_emb)
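

# A minimal inference sketch (the repo id and audio length below are
# placeholders; GLAP expects 16 kHz mono waveforms, per FrontEnd's defaults):
#
#   model = GlapModel.from_pretrained("<user>/glap")  # hypothetical repo id
#   audio = torch.randn(1, 160_000)                   # 10 s of 16 kHz audio
#   scores = model.score_forward(audio, ["a dog barking", "rain falling"])
#   probs = scores.softmax(dim=-1)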