"""GLAP (Generalized Language Audio Pretraining) HuggingFace model. Audio encoder adapted from dasheng-denoiser (Apache 2.0). Text encoder adapted from SONAR standalone (Apache 2.0). """ from __future__ import annotations import math from pathlib import Path from typing import List, Optional, Sequence import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from einops import rearrange from einops.layers.torch import Rearrange from transformers import PreTrainedModel try: from huggingface_hub import hf_hub_download except ImportError: hf_hub_download = None # type: ignore[assignment,misc] from .configuration_glap import GlapConfig # ============================================================================ # Audio Encoder (adapted from dasheng-denoiser/modeling_dasheng_encoder.py) # ============================================================================ class FrontEnd(nn.Sequential): def __init__( self, f_min: int = 0, sample_rate: int = 16000, win_size: int = 512, center: bool = True, n_fft: int = 512, f_max: Optional[int] = 8000, hop_size: int = 160, n_mels: int = 64, ): audio_transforms = __import__("importlib").import_module( "torchaudio.transforms" ) self.f_min = f_min self.sample_rate = sample_rate self.win_size = win_size self.center = center self.n_fft = n_fft self.f_max = f_max self.hop_size = hop_size self.n_mels = n_mels with torch.device("cpu"): super().__init__( audio_transforms.MelSpectrogram( f_min=self.f_min, sample_rate=self.sample_rate, win_length=self.win_size, center=self.center, n_fft=self.n_fft, f_max=self.f_max, hop_length=self.hop_size, n_mels=self.n_mels, ), audio_transforms.AmplitudeToDB(top_db=120), ) @torch.autocast(enabled=False, device_type="cuda") def forward(self, x, attention_mask=None): features = super().forward(x) if attention_mask is not None: lengths = attention_mask.float().sum(-1) // self.hop_size attention_mask = ( torch.arange(features.shape[-1], device=features.device) < lengths.unsqueeze(-1) ).int() return features, attention_mask class Mlp(nn.Module): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: type[nn.Module] = nn.GELU, drop: float = 0.0, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class AudioAttention(nn.Module): def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0, ): super().__init__() self.num_heads = num_heads self.scale = (dim // num_heads) ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x, mask: Optional[torch.Tensor] = None): B, N, C = x.shape qkv = ( self.qkv(x) .reshape(B, N, 3, self.num_heads, C // self.num_heads) .permute(2, 0, 3, 1, 4) ) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * self.scale if mask is not None: if mask.dtype != torch.bool: padding_mask = mask == 0 else: padding_mask = mask padding_mask = padding_mask.view(B, 1, 1, N) attn = attn.masked_fill(padding_mask, float("-inf")) attn = attn.softmax(dim=-1).nan_to_num() attn = self.attn_drop(attn) x 
class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: type[nn.Module] = nn.GELU,
        drop: float = 0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class AudioAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            if mask.dtype != torch.bool:
                padding_mask = mask == 0
            else:
                padding_mask = mask
            padding_mask = padding_mask.view(B, 1, 1, N)
            attn = attn.masked_fill(padding_mask, float("-inf"))
        attn = attn.softmax(dim=-1).nan_to_num()
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))


class AudioBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = AudioAttention(dim, num_heads, qkv_bias, attn_drop, drop)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=nn.GELU,
            drop=drop,
        )

    def forward(self, x, mask=None):
        x = x + self.attn(self.norm1(x), mask=mask)
        x = x + self.mlp(self.norm2(x))
        return x


class AudioPatchEmbed(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.stride = kwargs.get("stride", [None, 4])[-1]
        self.proj = nn.Conv2d(*args, **kwargs)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        x = self.proj(x)
        if attention_mask is not None:
            lengths = attention_mask.float().sum(-1) // self.stride
            attention_mask = (
                torch.arange(x.shape[-1], device=x.device) < lengths.unsqueeze(-1)
            ).int()
        return x, attention_mask


class DashengAudioEncoder(nn.Module):
    """Dasheng audio encoder matching the original DashengWrapper.

    Produces a single (B, embed_dim) embedding per audio input. Pads the
    spectrogram to a multiple of ``target_length``, splits it into chunks,
    processes each chunk independently through the Transformer, then
    mean-pools across chunks.
    """

    def __init__(
        self,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        patch_size: Optional[List[int]] = None,
        patch_stride: Optional[List[int]] = None,
        target_length: int = 1008,
    ):
        super().__init__()
        patch_size = patch_size or [64, 4]
        patch_stride = patch_stride or [64, 4]
        self.embed_dim = embed_dim
        self.target_length = target_length
        self.patch_stride = patch_stride
        self.time_patches = patch_stride[-1]
        self.max_t_tokens = target_length // self.time_patches
        self.front_end = FrontEnd()
        self.patch_embed = AudioPatchEmbed(
            1, embed_dim, kernel_size=patch_size, stride=patch_stride
        )
        self.init_bn = nn.Sequential(
            Rearrange("b c f t -> b f c t"),
            nn.BatchNorm2d(self.front_end.n_mels, momentum=0.01),
            Rearrange("b f c t -> b c f t"),
        )
        self.time_pos_embed = nn.Parameter(
            torch.randn(1, embed_dim, 1, target_length // self.time_patches) * 0.02
        )
        self.freq_pos_embed = nn.Parameter(torch.randn(1, embed_dim, 1, 1) * 0.02)
        self.blocks = nn.ModuleList(
            [AudioBlock(embed_dim, num_heads) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def _forward_chunk(self, x, attention_mask=None):
        x, attention_mask = self.patch_embed(x, attention_mask)
        t = x.shape[-1]
        x = x + self.time_pos_embed[:, :, :, :t] + self.freq_pos_embed
        x = rearrange(x, "b c f t -> b (f t) c")
        for block in self.blocks:
            x = block(x, mask=attention_mask)
        x = self.norm(x)
        return x

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Compute spectrogram
        x, attention_mask = self.front_end(x, attention_mask)
        x = rearrange(x, "b f t -> b 1 f t")
        x = self.init_bn(x)
        # Pad spectrogram time dim to next multiple of target_length
        if x.shape[-1] > self.target_length:
            remainder = x.shape[-1] % self.target_length
            if remainder != 0:
                pad_amount = self.target_length - remainder
                x = F.pad(x, (0, pad_amount))
        # Split into chunks along time dimension
        input_splits = x.split(self.target_length, dim=-1)
        # Process each chunk independently
        outputs = []
        chunk_size_in_patches = self.target_length // self.patch_stride[-1]
        for input_split_x in input_splits:
            output = self._forward_chunk(input_split_x, attention_mask=None)
            # Mean pool each chunk: (B, num_patches, embed_dim) -> (B, embed_dim)
            chunks = output.split(chunk_size_in_patches, dim=1)
            chunk_means = [c.mean(1) for c in chunks]
            outputs.append(torch.stack(chunk_means).mean(0))
        # Mean across all split outputs
        emb = torch.stack(outputs).mean(0)
        return emb
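# ----------------------------------------------------------------------------
# Chunking sketch for DashengAudioEncoder (a hedged illustration: the numbers
# assume the defaults above -- hop 160, target_length 1008, patch stride 4,
# embed_dim 768 -- and a hypothetical 30 s clip):
#
#     waveform (B, 480000)           # 30 s at 16 kHz
#       -> FrontEnd                  # (B, 64, 3001) mel frames
#       -> pad to 3024 frames        # next multiple of target_length = 1008
#       -> 3 chunks of 1008 frames   # x.split(target_length, dim=-1)
#       -> per chunk: patch_embed    # (B, 252, 768) tokens through 12 blocks
#       -> mean over tokens, then mean over the 3 chunks -> (B, 768)
# ----------------------------------------------------------------------------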
# ============================================================================
# Text Encoder (adapted from dasheng-glap SONAR standalone)
# ============================================================================


class SinusoidalPositionEncoder(nn.Module):
    def __init__(self, encoding_dim: int, max_seq_len: int, _legacy_pad_idx: int = 1):
        super().__init__()
        assert encoding_dim % 2 == 0
        self.encoding_dim = encoding_dim
        self.max_seq_len = max_seq_len
        self._legacy_pad_idx = _legacy_pad_idx
        start_step = 1 + _legacy_pad_idx
        steps = torch.arange(start_step, start_step + max_seq_len, dtype=torch.float32)
        self.register_buffer(
            "freqs", self._build_freqs(steps, encoding_dim), persistent=False
        )

    @staticmethod
    def _build_freqs(steps: Tensor, encoding_dim: int) -> Tensor:
        num_sin = encoding_dim // 2
        indices = torch.arange(num_sin, dtype=torch.float32)
        freq_vals = torch.exp(indices * -math.log(10000.0) / (num_sin - 1))
        l_half = torch.outer(steps, freq_vals)
        r_half = l_half[:, : encoding_dim - num_sin].clone()
        return torch.cat([l_half.sin(), r_half.cos()], dim=-1)

    def forward(self, seqs: Tensor) -> Tensor:
        seq_len = seqs.size(-2)
        return (seqs.float() + self.freqs[:seq_len]).type_as(seqs)


class SonarMultiheadAttention(nn.Module):
    def __init__(self, model_dim: int, num_heads: int, dropout_p: float = 0.0):
        super().__init__()
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads
        assert model_dim % num_heads == 0
        self.q_proj = nn.Linear(model_dim, model_dim, bias=True)
        self.k_proj = nn.Linear(model_dim, model_dim, bias=True)
        self.v_proj = nn.Linear(model_dim, model_dim, bias=True)
        self.output_proj = nn.Linear(model_dim, model_dim, bias=True)
        self.attn_dropout_p = dropout_p

    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        values: Tensor,
        padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seq_len, _ = queries.shape
        q = (
            self.q_proj(queries)
            .view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        k = (
            self.k_proj(keys)
            .view(bsz, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        v = (
            self.v_proj(values)
            .view(bsz, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        scale = self.head_dim**-0.5
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
        if padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                padding_mask[:, None, None, :], float("-inf")
            )
        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
        if self.training and self.attn_dropout_p > 0.0:
            attn_weights = F.dropout(attn_weights, p=self.attn_dropout_p)
        attn = torch.matmul(attn_weights, v)
        attn = attn.transpose(1, 2).contiguous().view(bsz, seq_len, self.model_dim)
        return self.output_proj(attn)


class _FeedForwardNetwork(nn.Module):
    def __init__(self, model_dim: int, inner_dim: int, dropout_p: float = 0.1):
        super().__init__()
        self.inner_proj = nn.Linear(model_dim, inner_dim, bias=True)
        self.output_proj = nn.Linear(inner_dim, model_dim, bias=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x: Tensor) -> Tensor:
        x = self.inner_proj(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.output_proj(x)
        return x


class SonarTransformerEncoderLayer(nn.Module):
    def __init__(
        self, model_dim: int, num_heads: int, ffn_inner_dim: int, dropout_p: float = 0.1
    ):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(model_dim)
        self.self_attn = SonarMultiheadAttention(
            model_dim, num_heads, dropout_p=dropout_p
        )
        self.ffn_layer_norm = nn.LayerNorm(model_dim)
        self.ffn = _FeedForwardNetwork(model_dim, ffn_inner_dim, dropout_p)

    def forward(self, seqs: Tensor, padding_mask: Optional[Tensor] = None) -> Tensor:
        residual = seqs
        seqs = self.self_attn_layer_norm(seqs)
        seqs = self.self_attn(seqs, seqs, seqs, padding_mask)
        seqs = seqs + residual
        residual = seqs
        seqs = self.ffn_layer_norm(seqs)
        seqs = self.ffn(seqs)
        seqs = seqs + residual
        return seqs
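# ----------------------------------------------------------------------------
# Data-flow sketch for SonarTransformerEncoderLayer (pre-LayerNorm residual
# blocks, matching the code above; dimensions assume the SonarTextEncoder
# defaults below, model_dim = 1024 and ffn_inner_dim = 8192):
#
#     x (B, T, 1024)
#     x = x + SelfAttn(LayerNorm(x), padding_mask)   # 16 heads, head_dim 64
#     x = x + FFN(LayerNorm(x))                      # 1024 -> 8192 -> ReLU -> 1024
# ----------------------------------------------------------------------------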
class _SonarTransformerEncoder(nn.Module):
    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers

    def forward(self, seqs: Tensor, padding_mask: Optional[Tensor] = None) -> Tensor:
        for layer in self.layers:
            seqs = layer(seqs, padding_mask)
        return seqs


class _SonarEmbeddingFrontend(nn.Module):
    def __init__(
        self,
        embed: nn.Embedding,
        pos_encoder: SinusoidalPositionEncoder,
        dropout_p: float = 0.1,
    ):
        super().__init__()
        self.embed = embed
        self.pos_encoder = pos_encoder
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, token_ids: Tensor) -> Tensor:
        seqs = self.embed(token_ids)
        seqs = seqs * math.sqrt(seqs.size(-1))
        seqs = self.pos_encoder(seqs)
        seqs = self.dropout(seqs)
        return seqs


class SonarTextEncoder(nn.Module):
    """24-layer SONAR text encoder with sinusoidal PE and mean pooling."""

    def __init__(
        self,
        vocab_size: int = 256206,
        model_dim: int = 1024,
        num_layers: int = 24,
        num_heads: int = 16,
        ffn_inner_dim: int = 8192,
        max_seq_len: int = 514,
        pad_idx: int = 0,
        dropout_p: float = 0.1,
    ):
        super().__init__()
        self.model_dim = model_dim
        self.pad_idx = pad_idx
        embed = nn.Embedding(vocab_size, model_dim, padding_idx=pad_idx)
        pos_encoder = SinusoidalPositionEncoder(
            model_dim, max_seq_len, _legacy_pad_idx=1
        )
        self.encoder_frontend = _SonarEmbeddingFrontend(embed, pos_encoder, dropout_p)
        layers = nn.ModuleList(
            [
                SonarTransformerEncoderLayer(
                    model_dim, num_heads, ffn_inner_dim, dropout_p
                )
                for _ in range(num_layers)
            ]
        )
        self.encoder = _SonarTransformerEncoder(layers)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(
        self,
        token_ids: Tensor,
        padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        seqs = self.encoder_frontend(token_ids)
        seqs = self.encoder(seqs, padding_mask)
        seqs = self.layer_norm(seqs)
        if padding_mask is None:
            sentence_embeddings = seqs.sum(dim=1) / (seqs.size(1) + 1e-7)
        else:
            mask = (~padding_mask).unsqueeze(-1).float()
            seqs = seqs * mask
            lengths = mask.sum(dim=1).clamp(min=1e-7)
            sentence_embeddings = seqs.sum(dim=1) / lengths
        return sentence_embeddings


# ============================================================================
# Tokenizer
# ============================================================================


class NllbTokenizer:
    """Standalone NLLB tokenizer using sentencepiece."""

    def __init__(self, model_path: str | Path, langs: Optional[List[str]] = None):
        try:
            import sentencepiece as spm
        except ImportError:
            raise ImportError("sentencepiece is required: pip install sentencepiece")
        self.sp = spm.SentencePieceProcessor()
        if not self.sp.load(str(model_path)):
            raise RuntimeError(f"Failed to load SentencePiece model from {model_path}")
        self.pad_idx = 0
        self.unk_idx = 1
        self.bos_idx = 2
        self.eos_idx = 3
        self._lang_token_to_idx = _NLLB_LANG_TOKEN_IDS

    @property
    def vocab_size(self) -> int:
        return self.sp.get_piece_size() + 206

    def create_encoder(self, lang: str = "eng_Latn"):
        lang_idx = self._lang_token_to_idx.get(lang)
        eos_idx = self.eos_idx

        def encode(text: str) -> List[int]:
            spm_ids = self.sp.encode(text, out_type=int)
            content_ids = [tid + 1 for tid in spm_ids]
            if lang_idx is not None:
                token_ids = [lang_idx] + content_ids
            else:
                token_ids = content_ids
            token_ids.append(eos_idx)
            return token_ids

        return encode
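# ----------------------------------------------------------------------------
# Tokenizer sketch (a hedged illustration: the piece ids below are
# placeholders, not real SentencePiece output; only the framing follows from
# create_encoder above):
#
#     enc = NllbTokenizer("sentencepiece.source.256000.model").create_encoder(
#         lang="eng_Latn"
#     )
#     enc("a dog barks")
#     # -> [256047, <spm id + 1>, <spm id + 1>, ..., 3]
#     #     ^ language token      ^ shifted piece ids   ^ eos_idx
# ----------------------------------------------------------------------------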
# Pre-computed NLLB language -> token ID mapping
_NLLB_LANG_TOKEN_IDS = {
    "ace_Arab": 256001, "ace_Latn": 256002, "acm_Arab": 256003, "acq_Arab": 256004,
    "aeb_Arab": 256005, "afr_Latn": 256006, "ajp_Arab": 256007, "aka_Latn": 256008,
    "amh_Ethi": 256009, "apc_Arab": 256010, "arb_Arab": 256011, "ars_Arab": 256012,
    "ary_Arab": 256013, "arz_Arab": 256014, "asm_Beng": 256015, "ast_Latn": 256016,
    "awa_Deva": 256017, "ayr_Latn": 256018, "azb_Arab": 256019, "azj_Latn": 256020,
    "bak_Cyrl": 256021, "bam_Latn": 256022, "ban_Latn": 256023, "bel_Cyrl": 256024,
    "bem_Latn": 256025, "ben_Beng": 256026, "bho_Deva": 256027, "bjn_Arab": 256028,
    "bjn_Latn": 256029, "bod_Tibt": 256030, "bos_Latn": 256031, "bug_Latn": 256032,
    "bul_Cyrl": 256033, "cat_Latn": 256034, "ceb_Latn": 256035, "ces_Latn": 256036,
    "cjk_Latn": 256037, "ckb_Arab": 256038, "crh_Latn": 256039, "cym_Latn": 256040,
    "dan_Latn": 256041, "deu_Latn": 256042, "dik_Latn": 256043, "dyu_Latn": 256044,
    "dzo_Tibt": 256045, "ell_Grek": 256046, "eng_Latn": 256047, "epo_Latn": 256048,
    "est_Latn": 256049, "eus_Latn": 256050, "ewe_Latn": 256051, "fao_Latn": 256052,
    "pes_Arab": 256053, "fij_Latn": 256054, "fin_Latn": 256055, "fon_Latn": 256056,
    "fra_Latn": 256057, "fur_Latn": 256058, "fuv_Latn": 256059, "gla_Latn": 256060,
    "gle_Latn": 256061, "glg_Latn": 256062, "grn_Latn": 256063, "guj_Gujr": 256064,
    "hat_Latn": 256065, "hau_Latn": 256066, "heb_Hebr": 256067, "hin_Deva": 256068,
    "hne_Deva": 256069, "hrv_Latn": 256070, "hun_Latn": 256071, "hye_Armn": 256072,
    "ibo_Latn": 256073, "ilo_Latn": 256074, "ind_Latn": 256075, "isl_Latn": 256076,
    "ita_Latn": 256077, "jav_Latn": 256078, "jpn_Jpan": 256079, "kab_Latn": 256080,
    "kac_Latn": 256081, "kam_Latn": 256082, "kan_Knda": 256083, "kas_Arab": 256084,
    "kas_Deva": 256085, "kat_Geor": 256086, "knc_Arab": 256087, "knc_Latn": 256088,
    "kaz_Cyrl": 256089, "kbp_Latn": 256090, "kea_Latn": 256091, "khm_Khmr": 256092,
    "kik_Latn": 256093, "kin_Latn": 256094, "kir_Cyrl": 256095, "kmb_Latn": 256096,
    "kon_Latn": 256097, "kor_Hang": 256098, "kmr_Latn": 256099, "lao_Laoo": 256100,
    "lvs_Latn": 256101, "lij_Latn": 256102, "lim_Latn": 256103, "lin_Latn": 256104,
    "lit_Latn": 256105, "lmo_Latn": 256106, "ltg_Latn": 256107, "ltz_Latn": 256108,
    "lua_Latn": 256109, "lug_Latn": 256110, "luo_Latn": 256111, "lus_Latn": 256112,
    "mag_Deva": 256113, "mai_Deva": 256114, "mal_Mlym": 256115, "mar_Deva": 256116,
    "min_Latn": 256117, "mkd_Cyrl": 256118, "plt_Latn": 256119, "mlt_Latn": 256120,
    "mni_Beng": 256121, "khk_Cyrl": 256122, "mos_Latn": 256123, "mri_Latn": 256124,
    "zsm_Latn": 256125, "mya_Mymr": 256126, "nld_Latn": 256127, "nno_Latn": 256128,
    "nob_Latn": 256129, "npi_Deva": 256130, "nso_Latn": 256131, "nus_Latn": 256132,
    "nya_Latn": 256133, "oci_Latn": 256134, "gaz_Latn": 256135, "ory_Orya": 256136,
    "pag_Latn": 256137, "pan_Guru": 256138, "pap_Latn": 256139, "pol_Latn": 256140,
    "por_Latn": 256141, "prs_Arab": 256142, "pbt_Arab": 256143, "quy_Latn": 256144,
    "ron_Latn": 256145, "run_Latn": 256146, "rus_Cyrl": 256147, "sag_Latn": 256148,
    "san_Deva": 256149, "sat_Beng": 256150, "scn_Latn": 256151, "shn_Mymr": 256152,
    "sin_Sinh": 256153, "slk_Latn": 256154, "slv_Latn": 256155, "smo_Latn": 256156,
    "sna_Latn": 256157,
"snd_Arab": 256158, "som_Latn": 256159, "sot_Latn": 256160, "spa_Latn": 256161, "als_Latn": 256162, "srd_Latn": 256163, "srp_Cyrl": 256164, "ssw_Latn": 256165, "sun_Latn": 256166, "swe_Latn": 256167, "swh_Latn": 256168, "szl_Latn": 256169, "tam_Taml": 256170, "tat_Cyrl": 256171, "tel_Telu": 256172, "tgk_Cyrl": 256173, "tgl_Latn": 256174, "tha_Thai": 256175, "tir_Ethi": 256176, "taq_Latn": 256177, "taq_Tfng": 256178, "tpi_Latn": 256179, "tsn_Latn": 256180, "tso_Latn": 256181, "tuk_Latn": 256182, "tum_Latn": 256183, "tur_Latn": 256184, "twi_Latn": 256185, "tzm_Tfng": 256186, "uig_Arab": 256187, "ukr_Cyrl": 256188, "umb_Latn": 256189, "urd_Arab": 256190, "uzn_Latn": 256191, "vec_Latn": 256192, "vie_Latn": 256193, "war_Latn": 256194, "wol_Latn": 256195, "xho_Latn": 256196, "ydd_Hebr": 256197, "yor_Latn": 256198, "yue_Hant": 256199, "zho_Hans": 256200, "zho_Hant": 256201, "zul_Latn": 256202, } # ============================================================================ # GLAP Model # ============================================================================ class GlapModel(PreTrainedModel): config_class = GlapConfig def __init__(self, config: GlapConfig): super().__init__(config) self.config = config # Audio encoder self.audio_encoder = DashengAudioEncoder( embed_dim=config.audio_embed_dim, depth=config.audio_depth, num_heads=config.audio_num_heads, patch_size=config.patch_size, patch_stride=config.patch_stride, target_length=config.target_length, ) # Text encoder self.text_encoder = SonarTextEncoder( vocab_size=config.text_vocab_size, model_dim=config.text_model_dim, num_layers=config.text_num_layers, num_heads=config.text_num_heads, ffn_inner_dim=config.text_ffn_inner_dim, max_seq_len=config.text_max_seq_len, pad_idx=config.text_pad_idx, dropout_p=config.text_dropout_p, ) # Projection layers self.audio_proj = nn.Sequential( nn.Linear(config.audio_embed_dim, config.embed_size), nn.ReLU(), nn.Linear(config.embed_size, config.embed_size), ) self.text_proj = nn.Sequential( nn.Linear(config.text_model_dim, config.embed_size), nn.ReLU(), nn.Linear(config.embed_size, config.embed_size), ) self.tokenizer: Optional[NllbTokenizer] = None self.post_init() def _init_weights(self, module): if isinstance(module, SinusoidalPositionEncoder): with torch.no_grad(): start_step = 1 + module._legacy_pad_idx steps = torch.arange( start_step, start_step + module.max_seq_len, dtype=torch.float32, ) module.freqs.copy_(module._build_freqs(steps, module.encoding_dim)) def _get_tokenizer(self) -> NllbTokenizer: if self.tokenizer is None: tokenizer_filename = "sentencepiece.source.256000.model" tokenizer_path: Optional[Path | str] = None # 1. Check config._name_or_path (local directory) model_dir = Path(self.config._name_or_path) candidate = model_dir / tokenizer_filename if candidate.exists(): tokenizer_path = candidate # 2. Check next to this file (modules cache / local install) if tokenizer_path is None: candidate = Path(__file__).parent / tokenizer_filename if candidate.exists(): tokenizer_path = candidate # 3. Download from HuggingFace Hub if tokenizer_path is None and hf_hub_download is not None: try: tokenizer_path = hf_hub_download( repo_id=self.config._name_or_path, filename=tokenizer_filename, ) except Exception: pass if tokenizer_path is None: raise FileNotFoundError( f"Could not find {tokenizer_filename}. " f"Searched {self.config._name_or_path} and " f"{Path(__file__).parent}." 
    def encode_audio(
        self,
        audio: torch.Tensor,
        audio_length: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        audio_embeds = self.audio_encoder(audio)
        audio_embeds = F.normalize(self.audio_proj(audio_embeds), dim=-1)
        return audio_embeds

    def encode_text(
        self,
        text: Sequence[str],
        source_lang: str = "eng_Latn",
    ) -> torch.Tensor:
        tokenizer = self._get_tokenizer()
        encoder_fn = tokenizer.create_encoder(lang=source_lang)
        all_token_ids: List[List[int]] = []
        max_seq_len = self.config.text_max_seq_len
        for t in text:
            token_ids = encoder_fn(t)[:max_seq_len]
            all_token_ids.append(token_ids)
        max_len = max(len(ids) for ids in all_token_ids) if all_token_ids else 0
        batch_size = len(all_token_ids)
        device = self.audio_proj[0].weight.device
        padded_ids = torch.full(
            (batch_size, max_len),
            tokenizer.pad_idx,
            dtype=torch.long,
            device="cpu",
        )
        padding_mask = torch.zeros(batch_size, max_len, dtype=torch.bool, device="cpu")
        for i, ids in enumerate(all_token_ids):
            length = len(ids)
            padded_ids[i, :length] = torch.tensor(ids, dtype=torch.long)
            padding_mask[i, length:] = True
        # Move inputs onto the text encoder's device so a GPU-placed model
        # does not receive CPU tensors.
        text_device = next(self.text_encoder.parameters()).device
        padded_ids = padded_ids.to(text_device)
        padding_mask = padding_mask.to(text_device)
        self.text_encoder.eval()
        with torch.no_grad():
            sentence_embeddings = self.text_encoder(padded_ids, padding_mask)
        text_embeds = F.normalize(
            self.text_proj(sentence_embeddings.to(device)), dim=-1
        )
        return text_embeds

    def get_audio_features(
        self,
        audio: torch.Tensor,
        audio_length: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        return self.encode_audio(audio, audio_length)

    def get_text_features(
        self,
        text: Sequence[str],
        source_lang: str = "eng_Latn",
        **kwargs,
    ) -> torch.Tensor:
        return self.encode_text(text, source_lang=source_lang)

    def forward(
        self,
        audio: Optional[torch.Tensor] = None,
        text: Optional[Sequence[str]] = None,
        audio_length: Optional[torch.Tensor] = None,
        source_lang: str = "eng_Latn",
        **kwargs,
    ):
        audio_embeds = None
        text_embeds = None
        if audio is not None:
            audio_embeds = self.encode_audio(audio, audio_length)
        if text is not None:
            text_embeds = self.encode_text(text, source_lang=source_lang)
        return audio_embeds, text_embeds

    def score(self, audio_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        return 100 * (audio_emb @ text_emb.T)

    def score_forward(
        self,
        audio: torch.Tensor,
        text: Sequence[str],
        audio_length: Optional[torch.Tensor] = None,
        source_lang: str = "eng_Latn",
    ) -> torch.Tensor:
        audio_emb, text_emb = self.forward(audio, text, audio_length, source_lang)
        return self.score(audio_emb, text_emb)
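# ----------------------------------------------------------------------------
# Minimal usage sketch (hedged): the repo id below is a placeholder, not a
# published checkpoint; any repository that ships GlapConfig/GlapModel weights
# together with the bundled SentencePiece model should behave the same way.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    model = GlapModel.from_pretrained("your-namespace/glap")  # hypothetical repo id
    model.eval()

    # 10 s of silence at 16 kHz stands in for a real mono waveform.
    dummy_audio = torch.zeros(1, 160000)
    captions = ["a dog barks", "rain falls on a roof", "someone plays the piano"]

    with torch.no_grad():
        scores = model.score_forward(dummy_audio, captions)  # (1, 3) similarity logits
    print(scores)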