| """ |
| Router Model Architecture for Smart ASR Routing. |
| |
| Regression-based approach: predicts WER for each backend model. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from dataclasses import dataclass |
| from typing import Optional, Dict |
|
|
| from transformers import PreTrainedModel, PretrainedConfig, WhisperModel, WhisperFeatureExtractor |
| from transformers.modeling_outputs import ModelOutput |
|
|
|
|
class AttentionPooling(nn.Module):
    """Learnable attention pooling for variable-length sequences.

    Scores each timestep with a Linear -> Tanh head, softmax-normalizes the
    scores over the time axis, and returns the attention-weighted sum of the
    inputs.
    """

    def __init__(self, input_dim: int):
        """
        Args:
            input_dim: Feature dimension of each timestep.
        """
        super().__init__()
        # One scalar score per timestep, squashed to [-1, 1] before softmax.
        self.attention = nn.Sequential(
            nn.Linear(input_dim, 1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: [Batch, Time, Dim]
            mask: Optional [Batch, Time] (1 for valid, 0 for pad)
        Returns:
            pooled: [Batch, Dim]
        """
        scores = self.attention(x)  # [Batch, Time, 1]

        if mask is not None:
            # Large negative constant (not -inf) so a fully padded row still
            # produces finite softmax weights instead of NaN.
            scores = scores.masked_fill(mask.unsqueeze(-1) == 0, -1e9)

        weights = F.softmax(scores, dim=1)  # normalize over Time
        return torch.sum(x * weights, dim=1)
|
|
|
|
class ASRRouterConfig(PretrainedConfig):
    """Configuration for the ASRRouter model.

    Holds the MLP layer sizes, dropout rate, and the number of backend ASR
    models the router scores.
    """
    model_type = "asr_router"

    def __init__(
        self,
        input_dim: int = 384,
        hidden_dim: int = 128,
        intermediate_dim: int = 64,
        dropout: float = 0.1,
        num_models: int = 3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Persist every architecture hyperparameter on the config instance
        # so it round-trips through save_pretrained / from_pretrained.
        for attr_name, attr_value in (
            ("input_dim", input_dim),
            ("hidden_dim", hidden_dim),
            ("intermediate_dim", intermediate_dim),
            ("dropout", dropout),
            ("num_models", num_models),
        ):
            setattr(self, attr_name, attr_value)
|
|
|
|
@dataclass
class RouterOutput(ModelOutput):
    """Output container returned by ASRRouterModel.forward."""
    # MSE loss against the ground-truth WER labels; None when no labels
    # are supplied (inference).
    loss: Optional[torch.FloatTensor] = None
    # Predicted non-negative WER per backend model.
    pred_wers: Optional[torch.FloatTensor] = None
|
|
|
|
class ASRRouterModel(PreTrainedModel):
    """
    Regression Router.
    Input: 384-dimensional Whisper encoder embeddings
    Output: Estimated WER (0.0+, unbounded) for each backend model.
    Uses Softplus activation to ensure non-negative outputs while allowing WER > 1.0.
    """
    config_class = ASRRouterConfig

    # Output column index -> backend model name.
    MODEL_ID_MAP = {0: "kyutai", 1: "granite", 2: "tiny_audio"}

    def __init__(self, config: ASRRouterConfig):
        super().__init__(config)

        # Two hidden layers (GELU + LayerNorm) followed by a linear head
        # producing one raw score per backend model.
        self.network = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.GELU(),
            nn.LayerNorm(config.hidden_dim),
            nn.Dropout(config.dropout),

            nn.Linear(config.hidden_dim, config.intermediate_dim),
            nn.GELU(),
            nn.LayerNorm(config.intermediate_dim),

            nn.Linear(config.intermediate_dim, config.num_models)
        )

        self.post_init()

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> RouterOutput:
        """Predict per-backend WERs, optionally computing an MSE loss.

        Args:
            embeddings: [Batch, input_dim] pooled Whisper embeddings.
            labels: Optional [Batch, num_models] ground-truth WERs.
        Returns:
            RouterOutput with pred_wers and (if labels given) loss.
        """
        # Softplus keeps predictions non-negative without capping them at 1.
        pred_wers = F.softplus(self.network(embeddings))

        loss = None
        if labels is not None:
            loss = F.mse_loss(pred_wers, labels)

        return RouterOutput(loss=loss, pred_wers=pred_wers)

    def predict_proba(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Get predicted WERs for each model.

        NOTE: despite the sklearn-style name, these are WER estimates
        (non-negative, unbounded), not probabilities.
        """
        with torch.no_grad():
            # Delegate to forward() so the inference path cannot drift
            # from the training path (previously duplicated the
            # softplus(network(...)) expression).
            return self.forward(embeddings).pred_wers
|
|
|
|
class RouterWithFeatureExtractor:
    """
    Production-ready router with attention pooling and memory optimizations.

    Wraps an ASRRouterModel with a Whisper-tiny encoder front end: raw
    waveform -> log-mel features -> encoder hidden states -> attention-pooled
    embedding -> per-backend WER prediction.
    """
    def __init__(self, router: ASRRouterModel, device: str = "cpu"):
        """
        Args:
            router: Trained ASRRouterModel to run on pooled embeddings.
            device: Torch device string, e.g. "cpu" or "cuda".
        """
        self.device = device
        self.router = router.to(device)
        self.router.eval()

        # NOTE(review): this pooling module is freshly initialized here
        # (random weights) — confirm trained weights are loaded by the
        # caller if the router was trained with a learned pooling.
        self.attention_pooling = AttentionPooling(input_dim=384).to(device)
        self.attention_pooling.eval()

        # Keep only the encoder; the decoder is never used for routing.
        print("Loading Whisper Encoder...")
        full_whisper = WhisperModel.from_pretrained("openai/whisper-tiny")
        self.whisper_encoder = full_whisper.encoder.to(device)
        self.whisper_encoder.eval()

        # Drop the decoder to reclaim memory.
        del full_whisper.decoder
        del full_whisper
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

    def extract_features(self, waveform: torch.Tensor) -> torch.Tensor:
        """Extract embeddings using Attention Pooling for variable lengths.

        Args:
            waveform: [Time] or [Batch, Time] audio at 16 kHz (assumed —
                sampling_rate is hard-coded below).
        Returns:
            [Batch, 384] pooled embeddings.
        """
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # Feature extractor expects numpy arrays on CPU.
        audio_np = [w.cpu().numpy() for w in waveform]

        inputs = self.feature_extractor(
            audio_np,
            sampling_rate=16000,
            return_tensors="pt",
            return_attention_mask=True
        )

        input_features = inputs.input_features.to(self.device)
        attention_mask = inputs.attention_mask.to(self.device)

        # Inference only: keep the mask resize and pooling inside no_grad
        # too, so no autograd graph is built (previously only the encoder
        # call was wrapped).
        with torch.no_grad():
            last_hidden_state = self.whisper_encoder(input_features).last_hidden_state

            # The encoder downsamples time; resize the frame-level mask to
            # the hidden-state length before pooling.
            mask_resized = F.interpolate(
                attention_mask.unsqueeze(1).float(),
                size=last_hidden_state.shape[1],
                mode='nearest'
            ).squeeze(1)

            return self.attention_pooling(last_hidden_state, mask_resized)

    def predict(self, waveform: torch.Tensor) -> Dict:
        """Select the model with the lowest predicted WER.

        Returns:
            Dict with "selected_model", "predicted_wers" (per-model floats),
            and "confidence" (1 - best WER, floored at 0).
        """
        embeddings = self.extract_features(waveform)

        with torch.no_grad():
            output = self.router(embeddings)
            pred_wers = output.pred_wers[0].cpu().numpy()

        # Derive names from the router's own index map so this dict cannot
        # fall out of sync with ASRRouterModel.MODEL_ID_MAP.
        scores = {
            name: float(pred_wers[idx])
            for idx, name in self.router.MODEL_ID_MAP.items()
        }

        best_name, best_wer = min(scores.items(), key=lambda item: item[1])

        return {
            "selected_model": best_name,
            "predicted_wers": scores,
            "confidence": max(0.0, 1.0 - best_wer)
        }
|
|
|
|
| |
# Register the custom config/model with transformers' Auto* machinery so
# they can be resolved through the AutoModel / AutoConfig APIs.
ASRRouterConfig.register_for_auto_class()
ASRRouterModel.register_for_auto_class("AutoModel")
|
|