"""
Router Model Architecture for Smart ASR Routing.

Regression-based approach: predicts WER for each backend model.
"""
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from dataclasses import dataclass |
| | from typing import Optional, Dict |
| |
|
| | from transformers import PreTrainedModel, PretrainedConfig, WhisperModel, WhisperFeatureExtractor |
| | from transformers.modeling_outputs import ModelOutput |
| |
|
| |
|
class AttentionPooling(nn.Module):
    """Pool a variable-length sequence into one vector with learned attention."""

    def __init__(self, input_dim: int):
        super().__init__()
        # One scalar score per timestep, squashed with tanh.
        self.attention = nn.Sequential(
            nn.Linear(input_dim, 1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Args:
            x: [Batch, Time, Dim]
            mask: [Batch, Time] (1 for valid, 0 for pad)

        Returns:
            pooled: [Batch, Dim]
        """
        logits = self.attention(x)  # [Batch, Time, 1]

        if mask is not None:
            # Large negative logit on padded positions so softmax zeroes them out.
            logits = logits.masked_fill(mask.unsqueeze(-1) == 0, -1e9)

        attn = F.softmax(logits, dim=1)
        return (x * attn).sum(dim=1)
| |
|
| |
|
class ASRRouterConfig(PretrainedConfig):
    """Hyperparameters for the ASRRouter MLP head.

    Args:
        input_dim: Size of the incoming audio embedding.
        hidden_dim: Width of the first hidden layer.
        intermediate_dim: Width of the second hidden layer.
        dropout: Dropout probability used after the first hidden layer.
        num_models: Number of backend ASR systems being scored.
    """

    model_type = "asr_router"

    def __init__(
        self,
        input_dim: int = 384,
        hidden_dim: int = 128,
        intermediate_dim: int = 64,
        dropout: float = 0.1,
        num_models: int = 3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_models = num_models
        self.dropout = dropout
        self.intermediate_dim = intermediate_dim
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
| |
|
| |
|
@dataclass
class RouterOutput(ModelOutput):
    """Output container for ASRRouterModel.forward."""
    # MSE loss against label WERs; None when no labels were provided.
    loss: Optional[torch.FloatTensor] = None
    # Predicted WER per backend model — presumably shape [Batch, num_models]; confirm against forward().
    pred_wers: torch.FloatTensor = None
| |
|
| |
|
class ASRRouterModel(PreTrainedModel):
    """
    Regression Router.
    Input: 384-dimensional Whisper encoder embeddings
    Output: Estimated WER (0.0+, unbounded) for each backend model.
    Uses Softplus activation to ensure non-negative outputs while allowing WER > 1.0.
    """

    config_class = ASRRouterConfig

    # Maps output column index -> backend model identifier.
    MODEL_ID_MAP = {0: "kyutai", 1: "granite", 2: "tiny_audio"}

    def __init__(self, config: ASRRouterConfig):
        super().__init__(config)

        # Two GELU blocks with LayerNorm, then a linear head with one
        # regression unit per backend model.
        layers = [
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.GELU(),
            nn.LayerNorm(config.hidden_dim),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, config.intermediate_dim),
            nn.GELU(),
            nn.LayerNorm(config.intermediate_dim),
            nn.Linear(config.intermediate_dim, config.num_models),
        ]
        self.network = nn.Sequential(*layers)

        self.post_init()

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> RouterOutput:
        """Predict per-backend WERs; MSE loss is computed when labels are given."""
        # Softplus keeps predictions non-negative without capping them at 1.0.
        pred_wers = F.softplus(self.network(embeddings))

        loss = F.mse_loss(pred_wers, labels) if labels is not None else None

        return RouterOutput(loss=loss, pred_wers=pred_wers)

    def predict_proba(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Get predicted WERs for each model."""
        with torch.no_grad():
            raw = self.network(embeddings)
            return F.softplus(raw)
| |
|
| |
|
class RouterWithFeatureExtractor:
    """
    Production-ready router with attention pooling and memory optimizations.

    Pipeline: raw 16 kHz waveform -> Whisper log-mel features -> Whisper-tiny
    encoder hidden states -> attention-pooled 384-d embedding -> per-backend
    WER prediction by the router.

    NOTE(review): self.attention_pooling is freshly initialized here, so its
    weights are random unless loaded from a checkpoint elsewhere — confirm
    against the deployment/loading code.
    """

    def __init__(self, router: ASRRouterModel, device: str = "cpu"):
        """
        Args:
            router: Trained WER-regression router.
            device: Torch device string, e.g. "cpu" or "cuda".
        """
        self.device = device
        self.router = router.to(device)
        self.router.eval()

        # Pools the encoder's [B, T, 384] states down to a [B, 384] embedding.
        self.attention_pooling = AttentionPooling(input_dim=384).to(device)
        self.attention_pooling.eval()

        print("Loading Whisper Encoder...")
        full_whisper = WhisperModel.from_pretrained("openai/whisper-tiny")
        self.whisper_encoder = full_whisper.encoder.to(device)
        self.whisper_encoder.eval()

        # Only the encoder is needed; free the decoder to reduce memory.
        del full_whisper.decoder
        del full_whisper
        # Fixed: was a conditional expression abused as a statement
        # (`torch.cuda.empty_cache() if ... else None`).
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

    def extract_features(self, waveform: torch.Tensor) -> torch.Tensor:
        """Extract embeddings using Attention Pooling for variable lengths.

        Args:
            waveform: [Time] or [Batch, Time] mono audio at 16 kHz (assumed —
                the extractor is called with sampling_rate=16000).

        Returns:
            [Batch, 384] pooled encoder embeddings.
        """
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # WhisperFeatureExtractor expects a list of 1-D numpy arrays.
        audio_np = [w.cpu().numpy() for w in waveform]

        inputs = self.feature_extractor(
            audio_np,
            sampling_rate=16000,
            return_tensors="pt",
            return_attention_mask=True,
        )

        input_features = inputs.input_features.to(self.device)
        attention_mask = inputs.attention_mask.to(self.device)

        with torch.no_grad():
            last_hidden_state = self.whisper_encoder(input_features).last_hidden_state

        # The encoder's time axis differs from the feature-frame axis, so the
        # frame mask is resized to the hidden-state length before pooling; this
        # keeps padded frames out of the attention weights.
        mask_resized = F.interpolate(
            attention_mask.unsqueeze(1).float(),
            size=last_hidden_state.shape[1],
            mode='nearest'
        ).squeeze(1)

        return self.attention_pooling(last_hidden_state, mask_resized)

    def predict(self, waveform: torch.Tensor) -> Dict:
        """Select the model with the lowest predicted WER.

        Returns:
            Dict with "selected_model" (str), "predicted_wers" (name -> float),
            and "confidence" (heuristic 1 - best WER, floored at 0).
        """
        embeddings = self.extract_features(waveform)

        with torch.no_grad():
            output = self.router(embeddings)
            pred_wers = output.pred_wers[0].cpu().numpy()

        # Consistency fix: derive names from ASRRouterModel.MODEL_ID_MAP
        # instead of re-hardcoding them here, so the two cannot drift.
        scores = {
            name: float(pred_wers[idx])
            for idx, name in ASRRouterModel.MODEL_ID_MAP.items()
        }

        best_model = min(scores.items(), key=lambda x: x[1])

        return {
            "selected_model": best_model[0],
            "predicted_wers": scores,
            # Heuristic only: a WER >= 1.0 yields confidence 0.
            "confidence": max(0.0, 1.0 - best_model[1])
        }
| |
|
| |
|
| | |
| | ASRRouterConfig.register_for_auto_class() |
| | ASRRouterModel.register_for_auto_class("AutoModel") |
| |
|