"""
Log-mel spectrogram extraction, padding/truncation, and batch collation for Whisper.
"""
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import numpy as np
import torch
import torchaudio

if TYPE_CHECKING:
    from transformers import WhisperProcessor

logger = logging.getLogger(__name__)

TARGET_SR = 16_000
MEL_FRAMES = 3000  # 30 seconds at 100 frames/sec
N_MELS = 80


class AudioFeatureExtractor:
    """Wraps WhisperProcessor to extract and normalize audio features."""

    def __init__(self, processor: "WhisperProcessor", config: dict) -> None:
        self.processor = processor
        self.sample_rate = config.get("audio", {}).get("sample_rate", TARGET_SR)

    def extract(self, audio: np.ndarray, sr: int) -> torch.Tensor:
        """
        Resample audio to 16kHz, extract log-mel features.
        Returns tensor of shape (80, 3000).
        """
        if sr != TARGET_SR:
            # resample expects a floating-point tensor; cast defensively
            tensor = torch.from_numpy(audio.astype(np.float32)).unsqueeze(0)
            tensor = torchaudio.functional.resample(tensor, sr, TARGET_SR)
            audio = tensor.squeeze(0).numpy()

        inputs = self.processor.feature_extractor(
            audio,
            sampling_rate=TARGET_SR,
            return_tensors="pt",
        )
        features = inputs.input_features[0]  # (80, 3000)
        return features

    def pad_or_truncate(self, features: torch.Tensor) -> torch.Tensor:
        """Ensure features are exactly (80, 3000)."""
        _, t = features.shape
        if t < MEL_FRAMES:
            # allocate padding on the same device as the features to avoid a
            # device mismatch when features live on GPU
            pad = torch.zeros(
                N_MELS, MEL_FRAMES - t, dtype=features.dtype, device=features.device
            )
            features = torch.cat([features, pad], dim=-1)
        elif t > MEL_FRAMES:
            features = features[:, :MEL_FRAMES]
        return features
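A quick sanity check of the padding/truncation contract above. This is a NumPy mirror of the torch method, kept as an illustrative sketch (`pad_or_truncate_np` is not part of the module API): any `(n_mels, t)` input comes out exactly `(80, 3000)`, padded with zeros on the right or cut at frame 3000.

```python
import numpy as np

N_MELS, MEL_FRAMES = 80, 3000

def pad_or_truncate_np(features: np.ndarray) -> np.ndarray:
    """NumPy mirror of AudioFeatureExtractor.pad_or_truncate for (n_mels, t) arrays."""
    _, t = features.shape
    if t < MEL_FRAMES:
        # right-pad with zeros up to MEL_FRAMES
        pad = np.zeros((N_MELS, MEL_FRAMES - t), dtype=features.dtype)
        features = np.concatenate([features, pad], axis=-1)
    elif t > MEL_FRAMES:
        # cut anything past the 30-second mark
        features = features[:, :MEL_FRAMES]
    return features

short = np.ones((80, 1200), dtype=np.float32)   # under 30 s: padded
long_ = np.ones((80, 4500), dtype=np.float32)   # over 30 s: truncated
assert pad_or_truncate_np(short).shape == (80, 3000)
assert pad_or_truncate_np(long_).shape == (80, 3000)
```

Note that padding with literal zeros matches what the HF feature extractor does internally; if you instead pad raw audio with silence before extraction, the padded mel frames are the log-mel of silence, not zero.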


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Pads input_features to uniform length and label sequences with -100
    (so they are ignored in the cross-entropy loss).
    Compatible with HuggingFace Seq2SeqTrainer.
    """
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: list[dict]) -> dict[str, torch.Tensor]:
        # Separate input_features and labels
        input_features = [{"input_features": f["input_features"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]

        # Pad input features (processor handles this)
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Pad labels
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # Remove decoder start token if it was prepended
        if (labels[:, 0] == self.decoder_start_token_id).all().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
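The label-masking step in the collator is easiest to see on a tiny batch. Below is a NumPy sketch of the same logic (`masked_fill` over the attention mask, then stripping a uniformly-prepended decoder-start token); the token ids are made up for illustration and do not correspond to a real tokenizer:

```python
import numpy as np

DECODER_START_ID = 50258  # hypothetical decoder start token id
PAD_ID = 50257            # hypothetical pad token id

# two tokenized label sequences, padded to the batch max length
input_ids = np.array([
    [DECODER_START_ID, 7, 8, PAD_ID],
    [DECODER_START_ID, 9, PAD_ID, PAD_ID],
])
attention_mask = np.array([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
])

# positions where attention_mask != 1 become -100, so cross-entropy ignores them
labels = np.where(attention_mask == 1, input_ids, -100)

# if every row starts with the decoder start token, drop it: the trainer
# re-prepends it via shift-right when building decoder_input_ids
if (labels[:, 0] == DECODER_START_ID).all():
    labels = labels[:, 1:]

# labels is now [[7, 8, -100], [9, -100, -100]]
```

The `-100` sentinel is the default `ignore_index` of `torch.nn.CrossEntropyLoss`, which is why padded positions masked this way contribute nothing to the loss.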