"""GLAP (Generalized Language Audio Pretraining) HuggingFace model.
Audio encoder adapted from dasheng-denoiser (Apache 2.0).
Text encoder adapted from SONAR standalone (Apache 2.0).
"""
from __future__ import annotations
import math
from pathlib import Path
from typing import List, Optional, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange
from einops.layers.torch import Rearrange
from transformers import PreTrainedModel
try:
from huggingface_hub import hf_hub_download
except ImportError:
hf_hub_download = None # type: ignore[assignment,misc]
from .configuration_glap import GlapConfig
# ============================================================================
# Audio Encoder (adapted from dasheng-denoiser/modeling_dasheng_encoder.py)
# ============================================================================
class FrontEnd(nn.Sequential):
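    """Mel-spectrogram front end: waveform -> log-mel (dB), 16 kHz defaults."""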
def __init__(
self,
f_min: int = 0,
sample_rate: int = 16000,
win_size: int = 512,
center: bool = True,
n_fft: int = 512,
f_max: Optional[int] = 8000,
hop_size: int = 160,
n_mels: int = 64,
):
        # Lazy import so torchaudio is only required once the front end is built.
        from torchaudio import transforms as audio_transforms
self.f_min = f_min
self.sample_rate = sample_rate
self.win_size = win_size
self.center = center
self.n_fft = n_fft
self.f_max = f_max
self.hop_size = hop_size
self.n_mels = n_mels
with torch.device("cpu"):
super().__init__(
audio_transforms.MelSpectrogram(
f_min=self.f_min,
sample_rate=self.sample_rate,
win_length=self.win_size,
center=self.center,
n_fft=self.n_fft,
f_max=self.f_max,
hop_length=self.hop_size,
n_mels=self.n_mels,
),
audio_transforms.AmplitudeToDB(top_db=120),
)
@torch.autocast(enabled=False, device_type="cuda")
def forward(self, x, attention_mask=None):
features = super().forward(x)
if attention_mask is not None:
lengths = attention_mask.float().sum(-1) // self.hop_size
attention_mask = (
torch.arange(features.shape[-1], device=features.device)
< lengths.unsqueeze(-1)
).int()
return features, attention_mask
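# Shape sketch (hypothetical input): with the defaults above, a 16 kHz batch
# of 160000-sample waveforms yields a (B, 64, 1001) dB-scaled mel spectrogram,
# since hop_size=160 and center=True give T = 160000 // 160 + 1 frames:
#
#   front_end = FrontEnd()
#   feats, _ = front_end(torch.randn(2, 160000))  # feats.shape == (2, 64, 1001)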
class Mlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: type[nn.Module] = nn.GELU,
drop: float = 0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class AudioAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
):
super().__init__()
self.num_heads = num_heads
self.scale = (dim // num_heads) ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, mask: Optional[torch.Tensor] = None):
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            # Accept either a boolean padding mask or a 0/1 attention mask,
            # where 0 marks padded positions.
            if mask.dtype != torch.bool:
                padding_mask = mask == 0
            else:
                padding_mask = mask
            padding_mask = padding_mask.view(B, 1, 1, N)
            attn = attn.masked_fill(padding_mask, float("-inf"))
        # nan_to_num guards rows that are fully masked, whose softmax is NaN.
        attn = attn.softmax(dim=-1).nan_to_num()
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
return self.proj_drop(self.proj(x))
class AudioBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim, eps=1e-6)
self.attn = AudioAttention(dim, num_heads, qkv_bias, attn_drop, drop)
self.norm2 = nn.LayerNorm(dim, eps=1e-6)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=nn.GELU,
drop=drop,
)
def forward(self, x, mask=None):
x = x + self.attn(self.norm1(x), mask=mask)
x = x + self.mlp(self.norm2(x))
return x
class AudioPatchEmbed(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Time-axis stride of the conv patcher; used to rescale mask lengths.
        self.stride = kwargs.get("stride", [None, 4])[-1]
        self.proj = nn.Conv2d(*args, **kwargs)
def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
x = self.proj(x)
if attention_mask is not None:
lengths = attention_mask.float().sum(-1) // self.stride
attention_mask = (
torch.arange(x.shape[-1], device=x.device) < lengths.unsqueeze(-1)
).int()
return x, attention_mask
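# Patching sketch: with patch_size == patch_stride == [64, 4], a (B, 1, 64, T)
# spectrogram becomes (B, embed_dim, 1, T // 4): one token per 4 frames, with
# the full 64-bin frequency axis absorbed into each token.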
class DashengAudioEncoder(nn.Module):
"""Dasheng audio encoder matching the original DashengWrapper.
Produces a single (B, embed_dim) embedding per audio input.
Pads spectrogram to a multiple of target_length, splits into chunks,
processes each chunk independently through the Transformer, then
mean-pools across chunks.
"""
def __init__(
self,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
        patch_size: Optional[List[int]] = None,
        patch_stride: Optional[List[int]] = None,
target_length: int = 1008,
):
super().__init__()
patch_size = patch_size or [64, 4]
patch_stride = patch_stride or [64, 4]
self.embed_dim = embed_dim
self.target_length = target_length
self.patch_stride = patch_stride
self.time_patches = patch_stride[-1]
self.max_t_tokens = target_length // self.time_patches
self.front_end = FrontEnd()
self.patch_embed = AudioPatchEmbed(
1, embed_dim, kernel_size=patch_size, stride=patch_stride
)
self.init_bn = nn.Sequential(
Rearrange("b c f t -> b f c t"),
nn.BatchNorm2d(self.front_end.n_mels, momentum=0.01),
Rearrange("b f c t -> b c f t"),
)
self.time_pos_embed = nn.Parameter(
torch.randn(1, embed_dim, 1, target_length // self.time_patches) * 0.02
)
self.freq_pos_embed = nn.Parameter(torch.randn(1, embed_dim, 1, 1) * 0.02)
self.blocks = nn.ModuleList(
[AudioBlock(embed_dim, num_heads) for _ in range(depth)]
)
self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
def _forward_chunk(self, x, attention_mask=None):
x, attention_mask = self.patch_embed(x, attention_mask)
t = x.shape[-1]
x = x + self.time_pos_embed[:, :, :, :t] + self.freq_pos_embed
x = rearrange(x, "b c f t -> b (f t) c")
for block in self.blocks:
x = block(x, mask=attention_mask)
x = self.norm(x)
return x
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Compute spectrogram
x, attention_mask = self.front_end(x, attention_mask)
x = rearrange(x, "b f t -> b 1 f t")
x = self.init_bn(x)
        # Pad the time dim up to the next multiple of target_length so every
        # chunk has the same width (short inputs pass through unpadded).
        if x.shape[-1] > self.target_length:
            remainder = x.shape[-1] % self.target_length
            if remainder != 0:
                x = F.pad(x, (0, self.target_length - remainder))
        # Split into fixed-size chunks along the time dimension
        input_splits = x.split(self.target_length, dim=-1)
# Process each chunk independently
outputs = []
chunk_size_in_patches = self.target_length // self.patch_stride[-1]
for input_split_x in input_splits:
output = self._forward_chunk(input_split_x, attention_mask=None)
# Mean pool each chunk: (B, num_patches, embed_dim) -> (B, embed_dim)
chunks = output.split(chunk_size_in_patches, dim=1)
chunk_means = [c.mean(1) for c in chunks]
outputs.append(torch.stack(chunk_means).mean(0))
# Mean across all split outputs
emb = torch.stack(outputs).mean(0)
return emb
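# Usage sketch (illustrative): arbitrary-length audio is handled by chunking,
# so any (B, num_samples) waveform maps to a single (B, embed_dim) vector:
#
#   encoder = DashengAudioEncoder()
#   emb = encoder(torch.randn(2, 16000 * 25))  # 2501 frames -> 3 chunks
#   assert emb.shape == (2, 768)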
# ============================================================================
# Text Encoder (adapted from dasheng-glap SONAR standalone)
# ============================================================================
class SinusoidalPositionEncoder(nn.Module):
def __init__(self, encoding_dim: int, max_seq_len: int, _legacy_pad_idx: int = 1):
super().__init__()
assert encoding_dim % 2 == 0
self.encoding_dim = encoding_dim
self.max_seq_len = max_seq_len
self._legacy_pad_idx = _legacy_pad_idx
start_step = 1 + _legacy_pad_idx
steps = torch.arange(start_step, start_step + max_seq_len, dtype=torch.float32)
self.register_buffer(
"freqs", self._build_freqs(steps, encoding_dim), persistent=False
)
@staticmethod
def _build_freqs(steps: Tensor, encoding_dim: int) -> Tensor:
num_sin = encoding_dim // 2
indices = torch.arange(num_sin, dtype=torch.float32)
freq_vals = torch.exp(indices * -math.log(10000.0) / (num_sin - 1))
l_half = torch.outer(steps, freq_vals)
r_half = l_half[:, : encoding_dim - num_sin].clone()
return torch.cat([l_half.sin(), r_half.cos()], dim=-1)
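    # Layout note: unlike the interleaved sin/cos of the original Transformer
    # encoding, this fairseq-style table stores all sines in the first half
    # of the feature dim and all cosines in the second half, with positions
    # starting at 1 + _legacy_pad_idx.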
def forward(self, seqs: Tensor) -> Tensor:
seq_len = seqs.size(-2)
return (seqs.float() + self.freqs[:seq_len]).type_as(seqs)
class SonarMultiheadAttention(nn.Module):
def __init__(self, model_dim: int, num_heads: int, dropout_p: float = 0.0):
super().__init__()
self.model_dim = model_dim
self.num_heads = num_heads
self.head_dim = model_dim // num_heads
assert model_dim % num_heads == 0
self.q_proj = nn.Linear(model_dim, model_dim, bias=True)
self.k_proj = nn.Linear(model_dim, model_dim, bias=True)
self.v_proj = nn.Linear(model_dim, model_dim, bias=True)
self.output_proj = nn.Linear(model_dim, model_dim, bias=True)
self.attn_dropout_p = dropout_p
def forward(
self,
queries: Tensor,
keys: Tensor,
values: Tensor,
padding_mask: Optional[Tensor] = None,
) -> Tensor:
bsz, seq_len, _ = queries.shape
q = (
self.q_proj(queries)
.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
k = (
self.k_proj(keys)
.view(bsz, -1, self.num_heads, self.head_dim)
.transpose(1, 2)
)
v = (
self.v_proj(values)
.view(bsz, -1, self.num_heads, self.head_dim)
.transpose(1, 2)
)
scale = self.head_dim**-0.5
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
if padding_mask is not None:
attn_weights = attn_weights.masked_fill(
padding_mask[:, None, None, :], float("-inf")
)
        # Softmax in float32 for numerical stability under mixed precision.
        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
        if self.training and self.attn_dropout_p > 0.0:
            attn_weights = F.dropout(attn_weights, p=self.attn_dropout_p)
attn = torch.matmul(attn_weights, v)
attn = attn.transpose(1, 2).contiguous().view(bsz, seq_len, self.model_dim)
return self.output_proj(attn)
class _FeedForwardNetwork(nn.Module):
def __init__(self, model_dim: int, inner_dim: int, dropout_p: float = 0.1):
super().__init__()
self.inner_proj = nn.Linear(model_dim, inner_dim, bias=True)
self.output_proj = nn.Linear(inner_dim, model_dim, bias=True)
self.dropout = nn.Dropout(dropout_p)
def forward(self, x: Tensor) -> Tensor:
x = self.inner_proj(x)
x = F.relu(x)
x = self.dropout(x)
x = self.output_proj(x)
return x
class SonarTransformerEncoderLayer(nn.Module):
def __init__(
self, model_dim: int, num_heads: int, ffn_inner_dim: int, dropout_p: float = 0.1
):
super().__init__()
self.self_attn_layer_norm = nn.LayerNorm(model_dim)
self.self_attn = SonarMultiheadAttention(
model_dim, num_heads, dropout_p=dropout_p
)
self.ffn_layer_norm = nn.LayerNorm(model_dim)
self.ffn = _FeedForwardNetwork(model_dim, ffn_inner_dim, dropout_p)
def forward(self, seqs: Tensor, padding_mask: Optional[Tensor] = None) -> Tensor:
residual = seqs
seqs = self.self_attn_layer_norm(seqs)
seqs = self.self_attn(seqs, seqs, seqs, padding_mask)
seqs = seqs + residual
residual = seqs
seqs = self.ffn_layer_norm(seqs)
seqs = self.ffn(seqs)
seqs = seqs + residual
return seqs
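# Note: the layer uses pre-LayerNorm residual blocks (norm -> sublayer -> add),
# so a final LayerNorm is applied once more at the top of SonarTextEncoder.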
class _SonarTransformerEncoder(nn.Module):
def __init__(self, layers: nn.ModuleList):
super().__init__()
self.layers = layers
def forward(self, seqs: Tensor, padding_mask: Optional[Tensor] = None) -> Tensor:
for layer in self.layers:
seqs = layer(seqs, padding_mask)
return seqs
class _SonarEmbeddingFrontend(nn.Module):
def __init__(
self,
embed: nn.Embedding,
pos_encoder: SinusoidalPositionEncoder,
dropout_p: float = 0.1,
):
super().__init__()
self.embed = embed
self.pos_encoder = pos_encoder
self.dropout = nn.Dropout(dropout_p)
def forward(self, token_ids: Tensor) -> Tensor:
seqs = self.embed(token_ids)
seqs = seqs * math.sqrt(seqs.size(-1))
seqs = self.pos_encoder(seqs)
seqs = self.dropout(seqs)
return seqs
class SonarTextEncoder(nn.Module):
"""24-layer SONAR text encoder with sinusoidal PE and mean pooling."""
def __init__(
self,
vocab_size: int = 256206,
model_dim: int = 1024,
num_layers: int = 24,
num_heads: int = 16,
ffn_inner_dim: int = 8192,
max_seq_len: int = 514,
pad_idx: int = 0,
dropout_p: float = 0.1,
):
super().__init__()
self.model_dim = model_dim
self.pad_idx = pad_idx
embed = nn.Embedding(vocab_size, model_dim, padding_idx=pad_idx)
pos_encoder = SinusoidalPositionEncoder(
model_dim, max_seq_len, _legacy_pad_idx=1
)
self.encoder_frontend = _SonarEmbeddingFrontend(embed, pos_encoder, dropout_p)
layers = nn.ModuleList(
[
SonarTransformerEncoderLayer(
model_dim, num_heads, ffn_inner_dim, dropout_p
)
for _ in range(num_layers)
]
)
self.encoder = _SonarTransformerEncoder(layers)
self.layer_norm = nn.LayerNorm(model_dim)
def forward(
self,
token_ids: Tensor,
padding_mask: Optional[Tensor] = None,
) -> Tensor:
seqs = self.encoder_frontend(token_ids)
seqs = self.encoder(seqs, padding_mask)
seqs = self.layer_norm(seqs)
        # Mean-pool over the token dimension, excluding padded positions.
        if padding_mask is None:
            sentence_embeddings = seqs.sum(dim=1) / (seqs.size(1) + 1e-7)
        else:
            mask = (~padding_mask).unsqueeze(-1).float()
            seqs = seqs * mask
            lengths = mask.sum(dim=1).clamp(min=1e-7)
            sentence_embeddings = seqs.sum(dim=1) / lengths
return sentence_embeddings
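# Usage sketch (illustrative): given padded NLLB token ids, the encoder
# returns one pooled vector per sentence:
#
#   text_enc = SonarTextEncoder()
#   ids = torch.randint(4, 256000, (2, 12))  # hypothetical token ids
#   emb = text_enc(ids)                      # emb.shape == (2, 1024)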
# ============================================================================
# Tokenizer
# ============================================================================
class NllbTokenizer:
"""Standalone NLLB tokenizer using sentencepiece."""
def __init__(self, model_path: str | Path, langs: Optional[List[str]] = None):
try:
import sentencepiece as spm
except ImportError:
raise ImportError("sentencepiece is required: pip install sentencepiece")
self.sp = spm.SentencePieceProcessor()
if not self.sp.load(str(model_path)):
raise RuntimeError(f"Failed to load SentencePiece model from {model_path}")
self.pad_idx = 0
self.unk_idx = 1
self.bos_idx = 2
self.eos_idx = 3
self._lang_token_to_idx = _NLLB_LANG_TOKEN_IDS
    @property
    def vocab_size(self) -> int:
        # SPM pieces plus the 206 symbols (the 202 language tokens and a few
        # extra specials) that NLLB appends after the SentencePiece pieces.
        return self.sp.get_piece_size() + 206
def create_encoder(self, lang: str = "eng_Latn"):
lang_idx = self._lang_token_to_idx.get(lang)
eos_idx = self.eos_idx
        def encode(text: str) -> List[int]:
            spm_ids = self.sp.encode(text, out_type=int)
            # Shift piece ids by one to align with the model's embedding
            # table, which reserves an extra leading special-token slot.
            content_ids = [tid + 1 for tid in spm_ids]
            # NLLB convention: language token first, EOS last.
            if lang_idx is not None:
                token_ids = [lang_idx] + content_ids
            else:
                token_ids = content_ids
            token_ids.append(eos_idx)
            return token_ids
return encode
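# Usage sketch (illustrative; the .model path is a placeholder):
#
#   tok = NllbTokenizer("sentencepiece.source.256000.model")
#   encode = tok.create_encoder(lang="eng_Latn")
#   ids = encode("a dog barking")  # [256047, <piece ids...>, 3] = lang ... EOS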
# Pre-computed NLLB language -> token ID mapping
_NLLB_LANG_TOKEN_IDS = {
"ace_Arab": 256001,
"ace_Latn": 256002,
"acm_Arab": 256003,
"acq_Arab": 256004,
"aeb_Arab": 256005,
"afr_Latn": 256006,
"ajp_Arab": 256007,
"aka_Latn": 256008,
"amh_Ethi": 256009,
"apc_Arab": 256010,
"arb_Arab": 256011,
"ars_Arab": 256012,
"ary_Arab": 256013,
"arz_Arab": 256014,
"asm_Beng": 256015,
"ast_Latn": 256016,
"awa_Deva": 256017,
"ayr_Latn": 256018,
"azb_Arab": 256019,
"azj_Latn": 256020,
"bak_Cyrl": 256021,
"bam_Latn": 256022,
"ban_Latn": 256023,
"bel_Cyrl": 256024,
"bem_Latn": 256025,
"ben_Beng": 256026,
"bho_Deva": 256027,
"bjn_Arab": 256028,
"bjn_Latn": 256029,
"bod_Tibt": 256030,
"bos_Latn": 256031,
"bug_Latn": 256032,
"bul_Cyrl": 256033,
"cat_Latn": 256034,
"ceb_Latn": 256035,
"ces_Latn": 256036,
"cjk_Latn": 256037,
"ckb_Arab": 256038,
"crh_Latn": 256039,
"cym_Latn": 256040,
"dan_Latn": 256041,
"deu_Latn": 256042,
"dik_Latn": 256043,
"dyu_Latn": 256044,
"dzo_Tibt": 256045,
"ell_Grek": 256046,
"eng_Latn": 256047,
"epo_Latn": 256048,
"est_Latn": 256049,
"eus_Latn": 256050,
"ewe_Latn": 256051,
"fao_Latn": 256052,
"pes_Arab": 256053,
"fij_Latn": 256054,
"fin_Latn": 256055,
"fon_Latn": 256056,
"fra_Latn": 256057,
"fur_Latn": 256058,
"fuv_Latn": 256059,
"gla_Latn": 256060,
"gle_Latn": 256061,
"glg_Latn": 256062,
"grn_Latn": 256063,
"guj_Gujr": 256064,
"hat_Latn": 256065,
"hau_Latn": 256066,
"heb_Hebr": 256067,
"hin_Deva": 256068,
"hne_Deva": 256069,
"hrv_Latn": 256070,
"hun_Latn": 256071,
"hye_Armn": 256072,
"ibo_Latn": 256073,
"ilo_Latn": 256074,
"ind_Latn": 256075,
"isl_Latn": 256076,
"ita_Latn": 256077,
"jav_Latn": 256078,
"jpn_Jpan": 256079,
"kab_Latn": 256080,
"kac_Latn": 256081,
"kam_Latn": 256082,
"kan_Knda": 256083,
"kas_Arab": 256084,
"kas_Deva": 256085,
"kat_Geor": 256086,
"knc_Arab": 256087,
"knc_Latn": 256088,
"kaz_Cyrl": 256089,
"kbp_Latn": 256090,
"kea_Latn": 256091,
"khm_Khmr": 256092,
"kik_Latn": 256093,
"kin_Latn": 256094,
"kir_Cyrl": 256095,
"kmb_Latn": 256096,
"kon_Latn": 256097,
"kor_Hang": 256098,
"kmr_Latn": 256099,
"lao_Laoo": 256100,
"lvs_Latn": 256101,
"lij_Latn": 256102,
"lim_Latn": 256103,
"lin_Latn": 256104,
"lit_Latn": 256105,
"lmo_Latn": 256106,
"ltg_Latn": 256107,
"ltz_Latn": 256108,
"lua_Latn": 256109,
"lug_Latn": 256110,
"luo_Latn": 256111,
"lus_Latn": 256112,
"mag_Deva": 256113,
"mai_Deva": 256114,
"mal_Mlym": 256115,
"mar_Deva": 256116,
"min_Latn": 256117,
"mkd_Cyrl": 256118,
"plt_Latn": 256119,
"mlt_Latn": 256120,
"mni_Beng": 256121,
"khk_Cyrl": 256122,
"mos_Latn": 256123,
"mri_Latn": 256124,
"zsm_Latn": 256125,
"mya_Mymr": 256126,
"nld_Latn": 256127,
"nno_Latn": 256128,
"nob_Latn": 256129,
"npi_Deva": 256130,
"nso_Latn": 256131,
"nus_Latn": 256132,
"nya_Latn": 256133,
"oci_Latn": 256134,
"gaz_Latn": 256135,
"ory_Orya": 256136,
"pag_Latn": 256137,
"pan_Guru": 256138,
"pap_Latn": 256139,
"pol_Latn": 256140,
"por_Latn": 256141,
"prs_Arab": 256142,
"pbt_Arab": 256143,
"quy_Latn": 256144,
"ron_Latn": 256145,
"run_Latn": 256146,
"rus_Cyrl": 256147,
"sag_Latn": 256148,
"san_Deva": 256149,
"sat_Beng": 256150,
"scn_Latn": 256151,
"shn_Mymr": 256152,
"sin_Sinh": 256153,
"slk_Latn": 256154,
"slv_Latn": 256155,
"smo_Latn": 256156,
"sna_Latn": 256157,
"snd_Arab": 256158,
"som_Latn": 256159,
"sot_Latn": 256160,
"spa_Latn": 256161,
"als_Latn": 256162,
"srd_Latn": 256163,
"srp_Cyrl": 256164,
"ssw_Latn": 256165,
"sun_Latn": 256166,
"swe_Latn": 256167,
"swh_Latn": 256168,
"szl_Latn": 256169,
"tam_Taml": 256170,
"tat_Cyrl": 256171,
"tel_Telu": 256172,
"tgk_Cyrl": 256173,
"tgl_Latn": 256174,
"tha_Thai": 256175,
"tir_Ethi": 256176,
"taq_Latn": 256177,
"taq_Tfng": 256178,
"tpi_Latn": 256179,
"tsn_Latn": 256180,
"tso_Latn": 256181,
"tuk_Latn": 256182,
"tum_Latn": 256183,
"tur_Latn": 256184,
"twi_Latn": 256185,
"tzm_Tfng": 256186,
"uig_Arab": 256187,
"ukr_Cyrl": 256188,
"umb_Latn": 256189,
"urd_Arab": 256190,
"uzn_Latn": 256191,
"vec_Latn": 256192,
"vie_Latn": 256193,
"war_Latn": 256194,
"wol_Latn": 256195,
"xho_Latn": 256196,
"ydd_Hebr": 256197,
"yor_Latn": 256198,
"yue_Hant": 256199,
"zho_Hans": 256200,
"zho_Hant": 256201,
"zul_Latn": 256202,
}
# ============================================================================
# GLAP Model
# ============================================================================
class GlapModel(PreTrainedModel):
config_class = GlapConfig
def __init__(self, config: GlapConfig):
super().__init__(config)
self.config = config
# Audio encoder
self.audio_encoder = DashengAudioEncoder(
embed_dim=config.audio_embed_dim,
depth=config.audio_depth,
num_heads=config.audio_num_heads,
patch_size=config.patch_size,
patch_stride=config.patch_stride,
target_length=config.target_length,
)
# Text encoder
self.text_encoder = SonarTextEncoder(
vocab_size=config.text_vocab_size,
model_dim=config.text_model_dim,
num_layers=config.text_num_layers,
num_heads=config.text_num_heads,
ffn_inner_dim=config.text_ffn_inner_dim,
max_seq_len=config.text_max_seq_len,
pad_idx=config.text_pad_idx,
dropout_p=config.text_dropout_p,
)
# Projection layers
self.audio_proj = nn.Sequential(
nn.Linear(config.audio_embed_dim, config.embed_size),
nn.ReLU(),
nn.Linear(config.embed_size, config.embed_size),
)
self.text_proj = nn.Sequential(
nn.Linear(config.text_model_dim, config.embed_size),
nn.ReLU(),
nn.Linear(config.embed_size, config.embed_size),
)
self.tokenizer: Optional[NllbTokenizer] = None
self.post_init()
def _init_weights(self, module):
if isinstance(module, SinusoidalPositionEncoder):
with torch.no_grad():
start_step = 1 + module._legacy_pad_idx
steps = torch.arange(
start_step,
start_step + module.max_seq_len,
dtype=torch.float32,
)
module.freqs.copy_(module._build_freqs(steps, module.encoding_dim))
def _get_tokenizer(self) -> NllbTokenizer:
if self.tokenizer is None:
tokenizer_filename = "sentencepiece.source.256000.model"
tokenizer_path: Optional[Path | str] = None
# 1. Check config._name_or_path (local directory)
model_dir = Path(self.config._name_or_path)
candidate = model_dir / tokenizer_filename
if candidate.exists():
tokenizer_path = candidate
# 2. Check next to this file (modules cache / local install)
if tokenizer_path is None:
candidate = Path(__file__).parent / tokenizer_filename
if candidate.exists():
tokenizer_path = candidate
# 3. Download from HuggingFace Hub
if tokenizer_path is None and hf_hub_download is not None:
try:
tokenizer_path = hf_hub_download(
repo_id=self.config._name_or_path,
filename=tokenizer_filename,
)
except Exception:
pass
if tokenizer_path is None:
raise FileNotFoundError(
f"Could not find {tokenizer_filename}. "
f"Searched {self.config._name_or_path} and "
f"{Path(__file__).parent}."
)
self.tokenizer = NllbTokenizer(tokenizer_path)
return self.tokenizer
    def encode_audio(
        self,
        audio: torch.Tensor,
        audio_length: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # audio_length is accepted for API compatibility but currently unused;
        # the encoder mean-pools over all (possibly padded) chunks.
        audio_embeds = self.audio_encoder(audio)
        audio_embeds = F.normalize(self.audio_proj(audio_embeds), dim=-1)
        return audio_embeds
def encode_text(
self,
text: Sequence[str],
source_lang: str = "eng_Latn",
) -> torch.Tensor:
tokenizer = self._get_tokenizer()
encoder_fn = tokenizer.create_encoder(lang=source_lang)
all_token_ids: List[List[int]] = []
max_seq_len = self.config.text_max_seq_len
for t in text:
token_ids = encoder_fn(t)[:max_seq_len]
all_token_ids.append(token_ids)
max_len = max(len(ids) for ids in all_token_ids) if all_token_ids else 0
batch_size = len(all_token_ids)
        device = self.text_proj[0].weight.device
        # Build the padded batch on CPU, then move it to the model's device
        # (the original passed CPU tensors straight into the text encoder,
        # which breaks once the model lives on GPU).
        padded_ids = torch.full(
            (batch_size, max_len),
            tokenizer.pad_idx,
            dtype=torch.long,
        )
        padding_mask = torch.zeros(batch_size, max_len, dtype=torch.bool)
        for i, ids in enumerate(all_token_ids):
            length = len(ids)
            padded_ids[i, :length] = torch.tensor(ids, dtype=torch.long)
            padding_mask[i, length:] = True
        padded_ids = padded_ids.to(device)
        padding_mask = padding_mask.to(device)
        # The text encoder runs in eval mode without gradients: text
        # embeddings are produced as frozen features in this implementation.
        self.text_encoder.eval()
        with torch.no_grad():
            sentence_embeddings = self.text_encoder(padded_ids, padding_mask)
        text_embeds = F.normalize(self.text_proj(sentence_embeddings), dim=-1)
return text_embeds
def get_audio_features(
self,
audio: torch.Tensor,
audio_length: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return self.encode_audio(audio, audio_length)
def get_text_features(
self,
text: Sequence[str],
source_lang: str = "eng_Latn",
**kwargs,
) -> torch.Tensor:
return self.encode_text(text, source_lang=source_lang)
def forward(
self,
audio: Optional[torch.Tensor] = None,
text: Optional[Sequence[str]] = None,
audio_length: Optional[torch.Tensor] = None,
source_lang: str = "eng_Latn",
**kwargs,
):
audio_embeds = None
text_embeds = None
if audio is not None:
audio_embeds = self.encode_audio(audio, audio_length)
if text is not None:
text_embeds = self.encode_text(text, source_lang=source_lang)
return audio_embeds, text_embeds
    def score(self, audio_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        # Cosine similarities (inputs are L2-normalized) scaled by a fixed
        # CLIP-style factor of 100.
        return 100 * (audio_emb @ text_emb.T)
def score_forward(
self,
audio: torch.Tensor,
text: Sequence[str],
audio_length: Optional[torch.Tensor] = None,
source_lang: str = "eng_Latn",
) -> torch.Tensor:
audio_emb, text_emb = self.forward(audio, text, audio_length, source_lang)
return self.score(audio_emb, text_emb)
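# Smoke-test sketch (illustrative; random weights, and assumes the NLLB
# sentencepiece model can be resolved as described in _get_tokenizer):
#
#   model = GlapModel(GlapConfig()).eval()
#   audio = torch.randn(2, 16000 * 5)  # two 5 s clips @ 16 kHz
#   with torch.no_grad():
#       scores = model.score_forward(audio, ["a dog barking", "rain falling"])
#   print(scores.shape)  # torch.Size([2, 2])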