"""
Log-mel spectrogram extraction, padding/truncation, and batch collation for Whisper.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torchaudio
if TYPE_CHECKING:
from transformers import WhisperProcessor
logger = logging.getLogger(__name__)
TARGET_SR = 16_000  # Whisper models are trained on 16 kHz audio
MEL_FRAMES = 3000  # 30 seconds at 100 frames/sec
N_MELS = 80  # number of mel filterbank channels


class AudioFeatureExtractor:
    """Wraps WhisperProcessor to extract and normalize audio features."""

    def __init__(self, processor: "WhisperProcessor", config: dict) -> None:
        """
        Args:
            processor: A loaded WhisperProcessor (provides the feature extractor).
            config: Nested config dict; reads config["audio"]["sample_rate"].
        """
        self.processor = processor
        # NOTE(review): stored but never read by extract(), which always
        # resamples to the hard-coded TARGET_SR — confirm whether the
        # configured sample rate was meant to override it.
        self.sample_rate = config.get("audio", {}).get("sample_rate", TARGET_SR)

    def extract(self, audio: np.ndarray, sr: int) -> torch.Tensor:
        """
        Resample audio to 16kHz, extract log-mel features.

        Args:
            audio: 1-D waveform array. Integer PCM (e.g. int16) is cast to
                float32 before resampling.
            sr: Sample rate of `audio` in Hz.

        Returns:
            Tensor of shape (80, 3000) of log-mel features.
        """
        if sr != TARGET_SR:
            # torchaudio's resampler requires floating-point input; raw
            # int16 PCM would fail, so cast to float32 first (no-op copy
            # when the array is already float32).
            waveform = audio.astype(np.float32, copy=False)
            tensor = torch.from_numpy(waveform).unsqueeze(0)
            tensor = torchaudio.functional.resample(tensor, sr, TARGET_SR)
            audio = tensor.squeeze(0).numpy()
        inputs = self.processor.feature_extractor(
            audio,
            sampling_rate=TARGET_SR,
            return_tensors="pt",
        )
        features = inputs.input_features[0]  # (80, 3000)
        return features

    def pad_or_truncate(self, features: torch.Tensor) -> torch.Tensor:
        """Ensure features are exactly (80, 3000).

        Pads the time axis with zeros or truncates it. The pad preserves
        the input's dtype AND device (the previous torch.zeros(...) always
        allocated on CPU, breaking GPU-resident feature tensors).
        """
        _, t = features.shape
        if t < MEL_FRAMES:
            # new_zeros inherits dtype and device from `features`.
            pad = features.new_zeros(N_MELS, MEL_FRAMES - t)
            features = torch.cat([features, pad], dim=-1)
        elif t > MEL_FRAMES:
            features = features[:, :MEL_FRAMES]
        return features
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Collates examples for Seq2SeqTrainer: input_features are padded to a
    uniform length, and label padding positions are set to -100 so the
    cross-entropy loss ignores them.
    """
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: list[dict]) -> dict[str, torch.Tensor]:
        """Build one padded training batch from a list of example dicts."""
        # Split each example into the audio half and the text half; each
        # half goes to the processor component that knows how to pad it.
        audio_inputs = [{"input_features": example["input_features"]} for example in features]
        text_inputs = [{"input_ids": example["labels"]} for example in features]

        # Audio side: the feature extractor handles padding.
        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")

        # Text side: pad with the tokenizer, then overwrite padded
        # positions (attention_mask == 0) with the loss-ignore index.
        padded_labels = self.processor.tokenizer.pad(text_inputs, return_tensors="pt")
        label_ids = padded_labels["input_ids"].masked_fill(
            padded_labels.attention_mask.ne(1), -100
        )

        # If every sequence starts with the decoder start token, strip it —
        # the model re-prepends it itself during training.
        all_start_with_bos = (label_ids[:, 0] == self.decoder_start_token_id).all().item()
        if all_start_with_bos:
            label_ids = label_ids[:, 1:]

        batch["labels"] = label_ids
        return batch