| """Audio projector modules for bridging encoder and decoder embeddings. |
| |
| This module contains all projector architectures: |
| - MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling |
| - MOSAProjector: MOSA-style dense mixture of experts |
| - SharedMoEAudioProjector: Shared expert + sparse routed experts |
| - QFormerAudioProjector: BLIP-2 QFormer with learnable queries (Granite-style) |
| """ |
|
|
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, Blip2QFormerConfig
from transformers.models.llama.modeling_llama import LlamaRMSNorm


class MLPAudioProjector(nn.Module):
    """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""

    def __init__(self, config):
        """Initialize MLP projector.

        Args:
            config: ASRConfig with encoder_dim, llm_dim, projector_pool_stride
        """
        super().__init__()

        encoder_dim = getattr(config, "encoder_dim", 768)
        llm_dim = getattr(config, "llm_dim", 2048)
        self.k = getattr(config, "projector_pool_stride", 4)

        # Stacking k consecutive frames widens the input by a factor of k before projection.
        in_dim = encoder_dim * self.k
        hidden_dim = llm_dim * 2
        self.linear_1 = nn.Linear(in_dim, hidden_dim)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(hidden_dim, llm_dim)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length (matches GLM-ASR)."""
        return (input_length - self.k) // self.k + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project audio features to LLM embedding space.

        Args:
            x: Audio encoder output of shape [batch, seq_len, encoder_dim]

        Returns:
            Projected features of shape [batch, (seq_len - k) // k + 1, llm_dim]
        """
        batch, seq, dim = x.shape

        # Drop trailing frames that do not fill a full group of k, then stack each
        # group of k frames into one (dim * k)-wide vector.
        out_len = (seq - self.k) // self.k + 1
        x = x[:, : out_len * self.k, :]
        x = x.reshape(batch, out_len, dim * self.k)

        x = self.linear_1(x)
        x = self.act(x)
        return self.linear_2(x)


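# Illustrative sketch (not used by the training code): how MLPAudioProjector's frame
# stacking trades sequence length for feature width. `SimpleNamespace` is a stand-in
# for the real ASRConfig; its field names mirror the getattr() lookups above.
def _example_mlp_projector_shapes():
    from types import SimpleNamespace

    cfg = SimpleNamespace(encoder_dim=768, llm_dim=2048, projector_pool_stride=4)
    proj = MLPAudioProjector(cfg)
    feats = torch.randn(2, 101, 768)  # [batch, frames, encoder_dim]
    out = proj(feats)
    # 101 frames grouped 4 at a time -> (101 - 4) // 4 + 1 = 25 output tokens.
    assert out.shape == (2, proj.get_output_length(101), 2048)

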
class SimpleAdapter(nn.Module):
    """Simple 2-layer GELU adapter (from MOSA paper)."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


class SwiGLU(nn.Module):
    """SwiGLU activation with gated linear units (used in LLaMA, Mistral, etc.).

    Computes w3(silu(w1(x)) * w2(x)); input and output dimensions are the same.
    """

    def __init__(self, dim: int, hidden_dim: int, bias: bool = False):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w3 = nn.Linear(hidden_dim, dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class AsymmetricSwiGLU(nn.Module):
    """SwiGLU that handles different input and output dimensions."""

    def __init__(
        self, in_features: int, hidden_features: int, out_features: int, bias: bool = False
    ):
        super().__init__()
        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


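# Quick shape sketch for the gated-MLP helpers above (illustrative only): SwiGLU
# preserves the feature dimension, while AsymmetricSwiGLU can change it.
def _example_swiglu_shapes():
    x = torch.randn(4, 10, 256)
    same_dim = SwiGLU(dim=256, hidden_dim=512)
    asym = AsymmetricSwiGLU(in_features=256, hidden_features=512, out_features=1024)
    assert same_dim(x).shape == (4, 10, 256)
    assert asym(x).shape == (4, 10, 1024)

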
class MOSAProjector(nn.Module):
    """MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.

    Based on "MOSA: Mixtures of Simple Adapters" (arXiv:2508.18998).
    Uses softmax gating over all experts (dense MoE) with only cross-entropy loss.
    Uses Conv1d for downsampling (2 layers, stride 2 each = 4x total).
    """

    def __init__(self, config):
        """Initialize MOSA projector.

        Args:
            config: ASRConfig with encoder_dim, llm_dim, num_experts
        """
        super().__init__()
        self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
        self.llm_dim = getattr(config, "llm_dim", None) or 2048
        self.num_experts = getattr(config, "num_experts", None) or 4
        adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
        router_hidden = getattr(config, "router_hidden_dim", None) or 512

        # Two stride-2 Conv1d layers: 4x temporal downsampling, projecting to llm_dim.
        self.downsampler = nn.Sequential(
            nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
        )

        # Lightweight 2-layer ReLU router: one logit per expert for every token.
        self.router = nn.Sequential(
            nn.Linear(self.llm_dim, router_hidden),
            nn.ReLU(),
            nn.Linear(router_hidden, self.num_experts),
        )

        # Dense mixture: every expert processes every token.
        self.experts = nn.ModuleList(
            [
                SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
                for _ in range(self.num_experts)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project audio features using mixture of experts.

        Args:
            x: Audio encoder output of shape [batch, seq_len, encoder_dim]

        Returns:
            Projected features of shape [batch, out_len, llm_dim]
        """
        # Conv1d expects [batch, channels, seq_len].
        x = x.transpose(1, 2)
        x = self.downsampler(x)
        x = x.transpose(1, 2)

        # Dense softmax gating over all experts: [batch, seq, num_experts].
        routing_weights = F.softmax(self.router(x), dim=-1)

        # Stack expert outputs to [experts, batch, seq, dim] and mix with the gate weights.
        expert_outputs = torch.stack([expert(x) for expert in self.experts])
        return torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length after Conv1d downsampling (4x reduction)."""
        # Conv1d length formula: (L + 2 * padding - kernel_size) // stride + 1.
        after_conv1 = (input_length + 2 * 1 - 3) // 2 + 1
        return (after_conv1 + 2 * 1 - 3) // 2 + 1


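# Illustrative sketch (toy dimensions, not a training configuration): MOSAProjector's
# two stride-2 convolutions give roughly 4x temporal downsampling. `SimpleNamespace`
# stands in for the real ASRConfig.
def _example_mosa_projector_shapes():
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        encoder_dim=80, llm_dim=64, num_experts=4, adapter_hidden_dim=128, router_hidden_dim=32
    )
    proj = MOSAProjector(cfg)
    out = proj(torch.randn(2, 100, 80))  # [batch, frames, encoder_dim]
    # Conv1: (100 + 2 - 3) // 2 + 1 = 50; Conv2: (50 + 2 - 3) // 2 + 1 = 25.
    assert proj.get_output_length(100) == 25
    assert out.shape == (2, 25, 64)

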
class MoEAudioProjector(nn.Module):
    """MoE projector with shared expert (DeepSeek-style), pure PyTorch implementation.

    Uses 4 sparse experts with top-2 routing plus a shared expert that processes all tokens.
    No external dependencies (megablocks removed).

    Architecture matches main branch: norm → experts(in_dim → hidden → out_dim)
    """

    def __init__(self, config):
        """Initialize MoE projector.

        Args:
            config: ASRConfig with encoder_dim, llm_dim, num_experts, num_experts_per_tok
        """
        super().__init__()

        self.k = getattr(config, "projector_pool_stride", 4)
        self.aux_coef = getattr(config, "router_aux_loss_coef", 0.01)

        # Router regularization: z-loss keeps logits small, jitter noise adds exploration.
        self.router_z_loss_coef = getattr(config, "router_z_loss_coef", 1e-4)
        self.router_jitter_noise = getattr(config, "router_jitter_noise", 0.01)

        in_dim = config.encoder_dim * self.k
        out_dim = config.llm_dim

        # Expert hidden width defaults to the LLM embedding size.
        hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim

        self.num_experts = getattr(config, "num_experts", 4)
        self.top_k = getattr(config, "num_experts_per_tok", 2)

        # Normalize the stacked frames before routing and expert projection.
        self.norm = LlamaRMSNorm(in_dim, eps=1e-6)

        # Linear router: one logit per expert.
        self.router = nn.Linear(in_dim, self.num_experts, bias=False)

        # Sparse experts, each a 2-layer GELU adapter.
        self.experts = nn.ModuleList(
            [SimpleAdapter(in_dim, hidden_dim, out_dim) for _ in range(self.num_experts)]
        )

        # Shared expert processes every token regardless of routing.
        self.shared_expert = SimpleAdapter(in_dim, hidden_dim, out_dim)

        self._init_weights()

        self.last_aux_loss = torch.tensor(0.0)

    def _init_weights(self):
        """Initialize weights for stable training start."""
        with torch.no_grad():
            # Small router weights keep early routing close to uniform.
            nn.init.normal_(self.router.weight, mean=0.0, std=0.02)

            # Experts: Xavier first layer, small-variance output layer.
            for expert in [self.shared_expert, *self.experts]:
                nn.init.xavier_uniform_(expert.fc1.weight)
                nn.init.normal_(expert.fc2.weight, mean=0.0, std=0.01)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length (matches MLP projector)."""
        return (input_length - self.k) // self.k + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project audio features using shared + sparse MoE.

        Args:
            x: Audio encoder output of shape [batch, seq_len, encoder_dim]

        Returns:
            Projected features of shape [batch, out_len, llm_dim]
        """
        # Frame stacking: same k-frame grouping as the MLP projector.
        batch, seq, dim = x.shape
        out_len = (seq - self.k) // self.k + 1
        x = x[:, : out_len * self.k, :]
        x = x.reshape(batch, out_len, dim * self.k)

        # Normalize, then flatten to [tokens, dim] for expert dispatch.
        x = self.norm(x)
        flat_x = x.view(-1, x.size(-1))

        # Shared expert handles every token.
        output = self.shared_expert(flat_x)

        # Sparse experts add their weighted contributions in place.
        self.last_aux_loss = self._forward_sparse(flat_x, output)

        return output.view(batch, out_len, -1)

    def _forward_sparse(self, x: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
        """Stability-hardened sparse expert dispatch (in-place add to output).

        Args:
            x: Flattened input of shape [tokens, dim]
            output: Output tensor to add sparse expert results into (in-place)

        Returns:
            Auxiliary loss tensor
        """
        logits = self.router(x)

        if self.training and self.router_jitter_noise > 0:
            # Multiplicative jitter on the router logits encourages exploration.
            noise = torch.empty_like(logits).uniform_(
                1.0 - self.router_jitter_noise, 1.0 + self.router_jitter_noise
            )
            logits = logits * noise

        # Softmax in float32 for numerical stability, then cast back to the input dtype.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(x)

        # Pick the top-k experts per token.
        top_k_weights, top_k_indices = torch.topk(probs, self.top_k, dim=-1)

        # Renormalize the selected weights so they sum to 1 per token.
        top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6)

        aux_loss = torch.tensor(0.0, device=x.device)

        if self.training:
            # Load balancing: penalize deviation from uniform expert usage.
            prob_per_expert = probs.mean(0)
            target = 1.0 / self.num_experts
            balance_loss = (
                self.aux_coef * ((prob_per_expert - target) ** 2).mean() * self.num_experts
            )

            # Router z-loss: penalize large logits to keep the softmax well conditioned.
            z_loss = self.router_z_loss_coef * torch.logsumexp(logits, dim=-1).pow(2).mean()

            aux_loss = balance_loss + z_loss

        # Dispatch each token only to the experts it was routed to.
        for i, expert in enumerate(self.experts):
            # Positions (token, top-k slot) routed to expert i.
            mask = top_k_indices == i

            if mask.any():
                token_idx, k_idx = torch.where(mask)

                # Run the expert only on its routed tokens.
                expert_input = x[token_idx]
                expert_output = expert(expert_input)

                # Scale by the renormalized routing weight for this expert.
                weight = top_k_weights[token_idx, k_idx].unsqueeze(-1)
                weighted_output = (expert_output * weight).type_as(output)

                # Accumulate into the shared-expert output in place.
                output.index_add_(0, token_idx, weighted_output)

        return aux_loss

    def get_aux_loss(self) -> torch.Tensor:
        """Return auxiliary load balancing loss."""
        return self.last_aux_loss


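# Illustrative sketch (toy dimensions): the shared expert runs on every token, the
# sparse experts add top-2 contributions, and the router's auxiliary loss is exposed
# via get_aux_loss() so a trainer can add it to the main objective. `SimpleNamespace`
# stands in for the real ASRConfig.
def _example_moe_projector_aux_loss():
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        encoder_dim=32,
        llm_dim=64,
        projector_pool_stride=4,
        num_experts=4,
        num_experts_per_tok=2,
        router_aux_loss_coef=0.01,
        router_z_loss_coef=1e-4,
        router_jitter_noise=0.01,
        projector_hidden_dim=None,
    )
    proj = MoEAudioProjector(cfg).train()
    out = proj(torch.randn(2, 40, 32))  # 40 frames -> (40 - 4) // 4 + 1 = 10 tokens
    aux = proj.get_aux_loss()  # balance loss + z-loss, non-zero only in training mode
    assert out.shape == (2, 10, 64) and aux.item() >= 0.0

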
class QFormerAudioProjector(nn.Module):
    """BLIP-2 QFormer projector with learnable queries.

    Based on GraniteSpeechEncoderProjector - uses a QFormer model with learnable
    query embeddings to compress and project audio encoder outputs. The audio
    sequence is processed in windows and downsampled via cross-attention.
    """

    def __init__(self, config):
        """Initialize QFormer projector.

        Args:
            config: ASRConfig with encoder_dim, llm_dim, qformer_* settings
        """
        super().__init__()

        encoder_dim = config.encoder_dim
        llm_dim = config.llm_dim

        # Each window of window_size frames is compressed to num_queries output tokens.
        self.window_size = getattr(config, "qformer_window_size", 15)
        self.downsample_rate = getattr(config, "downsample_rate", 5)
        self.num_queries = self.window_size // self.downsample_rate

        # QFormer hyperparameters (hidden size defaults to the encoder width).
        qformer_hidden = getattr(config, "qformer_hidden_size", None) or encoder_dim
        qformer_num_layers = getattr(config, "qformer_num_layers", 2)
        qformer_num_heads = getattr(config, "qformer_num_heads", 16)
        qformer_intermediate = getattr(config, "qformer_intermediate_size", None) or (
            qformer_hidden * 4
        )

        # Learnable query embeddings shared across all windows.
        self.query = nn.Parameter(torch.zeros(1, self.num_queries, qformer_hidden))
        self.query.data.normal_(mean=0.0, std=1.0)

        # Project encoder features into the QFormer width if the two differ.
        if encoder_dim != qformer_hidden:
            self.encoder_proj = nn.Linear(encoder_dim, qformer_hidden, bias=False)
        else:
            self.encoder_proj = None

        # Cross-attention in every layer so the queries always see the audio features.
        qformer_config = Blip2QFormerConfig(
            hidden_size=qformer_hidden,
            num_hidden_layers=qformer_num_layers,
            num_attention_heads=qformer_num_heads,
            intermediate_size=qformer_intermediate,
            encoder_hidden_size=qformer_hidden,
            cross_attention_frequency=1,
            hidden_act="gelu",
            attention_probs_dropout_prob=0.1,
            hidden_dropout_prob=0.1,
            layer_norm_eps=1e-12,
            initializer_range=0.02,
        )
        self.qformer = AutoModel.from_config(qformer_config)

        # Final projection into the LLM embedding space.
        self.linear = nn.Linear(qformer_hidden, llm_dim)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length."""
        # Each (possibly padded) window contributes num_queries output tokens.
        nblocks = math.ceil(input_length / self.window_size)
        return nblocks * self.num_queries

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compress windows of audio features into query tokens and project to LLM space.

        Args:
            hidden_states: [batch_size, seq_len, encoder_dim]

        Returns:
            projected: [batch_size, num_output_tokens, llm_dim]
        """
        batch_size, seq_len, dim = hidden_states.size()

        # Match the query dtype (e.g., under mixed-precision training).
        target_dtype = self.query.dtype
        if hidden_states.dtype != target_dtype:
            hidden_states = hidden_states.to(target_dtype)

        # Optional projection into the QFormer width.
        if self.encoder_proj is not None:
            hidden_states = self.encoder_proj(hidden_states)

        # Zero-pad the sequence to a whole number of windows.
        nblocks = math.ceil(seq_len / self.window_size)
        pad = nblocks * self.window_size - seq_len
        if pad > 0:
            hidden_states = F.pad(hidden_states, (0, 0, 0, pad), "constant", 0)

        # Fold windows into the batch dimension: [batch * nblocks, window_size, hidden].
        effective_batch = batch_size * nblocks
        hidden_states = hidden_states.view(effective_batch, self.window_size, -1)

        # One copy of the learnable queries per window.
        query_embeds = self.query.expand(effective_batch, -1, -1)

        # Cross-attend the queries over each window of audio features.
        query_output = self.qformer(
            query_embeds=query_embeds,
            encoder_hidden_states=hidden_states,
            return_dict=True,
        )

        # Unfold the windows back into the sequence dimension.
        output_tokens = nblocks * self.num_queries
        query_proj = query_output.last_hidden_state.view(batch_size, output_tokens, -1)

        # Project to the LLM embedding dimension.
        return self.linear(query_proj)


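# Worked example of the QFormer windowing arithmetic (mirrors get_output_length without
# instantiating the model): with the defaults window_size=15 and downsample_rate=5, each
# window is compressed to 15 // 5 = 3 query tokens.
def _example_qformer_output_length():
    window_size, downsample_rate = 15, 5
    num_queries = window_size // downsample_rate  # 3 query tokens per window
    nblocks = math.ceil(100 / window_size)  # a 100-frame utterance -> 7 windows (last padded)
    assert nblocks * num_queries == 21

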
PROJECTOR_CLASSES = {
    "mlp": MLPAudioProjector,
    "mosa": MOSAProjector,
    "moe": MoEAudioProjector,
    "qformer": QFormerAudioProjector,
}
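
# Example of how a projector would typically be selected from this registry. The
# `projector_type` field name is an assumption about ASRConfig, shown only for
# illustration:
#
#     projector_cls = PROJECTOR_CLASSES[getattr(config, "projector_type", "mlp")]
#     projector = projector_cls(config)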