| """ | |
| Log-mel spectrogram extraction, padding/truncation, and batch collation for Whisper. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from dataclasses import dataclass | |
| from typing import TYPE_CHECKING, Any | |
| import numpy as np | |
| import torch | |
| import torchaudio | |
| if TYPE_CHECKING: | |
| from transformers import WhisperProcessor | |
| logger = logging.getLogger(__name__) | |
| TARGET_SR = 16_000 | |
| MEL_FRAMES = 3000 # 30 seconds at 100 frames/sec | |
| N_MELS = 80 | |
class AudioFeatureExtractor:
    """Wraps WhisperProcessor to extract and normalize audio features."""

    def __init__(self, processor: "WhisperProcessor", config: dict) -> None:
        """
        Args:
            processor: HuggingFace WhisperProcessor providing the feature extractor.
            config: Nested config dict; reads ``audio.sample_rate`` (defaults to 16 kHz).
        """
        self.processor = processor
        self.sample_rate = config.get("audio", {}).get("sample_rate", TARGET_SR)

    def extract(self, audio: np.ndarray, sr: int) -> torch.Tensor:
        """
        Resample audio to 16 kHz and extract log-mel features.

        Args:
            audio: 1-D waveform array (any numeric dtype; integer PCM is accepted).
            sr: Sample rate of ``audio`` in Hz.

        Returns:
            Tensor of shape (80, 3000) — log-mel features for a 30 s window.
        """
        if sr != TARGET_SR:
            # torchaudio's resampler requires a floating-point tensor; cast
            # first so integer PCM input (e.g. int16 from a WAV file) works.
            tensor = torch.from_numpy(audio).to(torch.float32).unsqueeze(0)
            tensor = torchaudio.functional.resample(tensor, sr, TARGET_SR)
            audio = tensor.squeeze(0).numpy()
        inputs = self.processor.feature_extractor(
            audio,
            sampling_rate=TARGET_SR,
            return_tensors="pt",
        )
        # (batch=1, 80, 3000) -> (80, 3000)
        return inputs.input_features[0]

    def pad_or_truncate(self, features: torch.Tensor) -> torch.Tensor:
        """Ensure features span exactly MEL_FRAMES frames, i.e. (n_mels, 3000)."""
        n_mels, t = features.shape
        if t < MEL_FRAMES:
            # Match dtype AND device so torch.cat also works for GPU tensors;
            # use the tensor's own mel count rather than assuming N_MELS.
            pad = torch.zeros(
                n_mels,
                MEL_FRAMES - t,
                dtype=features.dtype,
                device=features.device,
            )
            features = torch.cat([features, pad], dim=-1)
        elif t > MEL_FRAMES:
            features = features[:, :MEL_FRAMES]
        return features
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Pads input_features to uniform length and label sequences with -100
    (so they are ignored in the cross-entropy loss).
    Compatible with HuggingFace Seq2SeqTrainer.

    Attributes:
        processor: WhisperProcessor exposing ``feature_extractor`` and ``tokenizer``.
        decoder_start_token_id: Token id stripped from the front of labels if
            the tokenizer already prepended it (the model re-adds it itself).
    """

    # NOTE: the bare field annotations below only become an auto-generated
    # __init__ because of the @dataclass decorator; without it the attributes
    # read in __call__ would never be set.
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: list[dict]) -> dict[str, torch.Tensor]:
        """Collate a list of example dicts into a padded training batch.

        Args:
            features: Examples, each with ``input_features`` and ``labels``.

        Returns:
            Dict with padded ``input_features`` and ``labels`` tensors, where
            label padding positions are set to -100.
        """
        # Separate input_features and labels; they are padded independently.
        input_features = [{"input_features": f["input_features"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]
        # Pad input features (processor handles this)
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        # Pad labels, then replace padding token ids with -100 so the loss
        # function ignores those positions.
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )
        # Remove decoder start token if it was prepended during tokenization;
        # the model prepends it again at training time.
        if (labels[:, 0] == self.decoder_start_token_id).all().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch