""" Log-mel spectrogram extraction, padding/truncation, and batch collation for Whisper. """ from __future__ import annotations import logging from dataclasses import dataclass from typing import TYPE_CHECKING, Any import numpy as np import torch import torchaudio if TYPE_CHECKING: from transformers import WhisperProcessor logger = logging.getLogger(__name__) TARGET_SR = 16_000 MEL_FRAMES = 3000 # 30 seconds at 100 frames/sec N_MELS = 80 class AudioFeatureExtractor: """Wraps WhisperProcessor to extract and normalize audio features.""" def __init__(self, processor: "WhisperProcessor", config: dict) -> None: self.processor = processor self.sample_rate = config.get("audio", {}).get("sample_rate", TARGET_SR) def extract(self, audio: np.ndarray, sr: int) -> torch.Tensor: """ Resample audio to 16kHz, extract log-mel features. Returns tensor of shape (80, 3000). """ if sr != TARGET_SR: tensor = torch.from_numpy(audio).unsqueeze(0) tensor = torchaudio.functional.resample(tensor, sr, TARGET_SR) audio = tensor.squeeze(0).numpy() inputs = self.processor.feature_extractor( audio, sampling_rate=TARGET_SR, return_tensors="pt", ) features = inputs.input_features[0] # (80, 3000) return features def pad_or_truncate(self, features: torch.Tensor) -> torch.Tensor: """Ensure features are exactly (80, 3000).""" _, t = features.shape if t < MEL_FRAMES: pad = torch.zeros(N_MELS, MEL_FRAMES - t, dtype=features.dtype) features = torch.cat([features, pad], dim=-1) elif t > MEL_FRAMES: features = features[:, :MEL_FRAMES] return features @dataclass class DataCollatorSpeechSeq2SeqWithPadding: """ Pads input_features to uniform length and label sequences with -100 (so they are ignored in the cross-entropy loss). Compatible with HuggingFace Seq2SeqTrainer. """ processor: Any decoder_start_token_id: int def __call__(self, features: list[dict]) -> dict[str, torch.Tensor]: # Separate input_features and labels input_features = [{"input_features": f["input_features"]} for f in features] label_features = [{"input_ids": f["labels"]} for f in features] # Pad input features (processor handles this) batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt") # Pad labels labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt") labels = labels_batch["input_ids"].masked_fill( labels_batch.attention_mask.ne(1), -100 ) # Remove decoder start token if it was prepended if (labels[:, 0] == self.decoder_start_token_id).all().item(): labels = labels[:, 1:] batch["labels"] = labels return batch