"""
Public inference interface.
Accepts audio as a file path or numpy array and returns transcribed text.
Handles chunking for audio longer than 30 seconds.
"""
from __future__ import annotations

import logging
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING

import numpy as np
import torch

if TYPE_CHECKING:
    from src.engine.adapter_manager import AdapterManager
    from src.engine.whisper_base import WhisperBackbone

logger = logging.getLogger(__name__)

TARGET_SR = 16_000  # Whisper models consume 16 kHz mono audio


@dataclass
class TranscriptionResult:
    """Outcome of a single transcription call."""

    text: str
    language: str
    duration_s: float
    processing_time_ms: int
    confidence: float | None = None  # not populated by the current pipeline


class Transcriber:
    """
    Composes WhisperBackbone + AdapterManager to provide a simple transcription API.
    Thread-safety: Not thread-safe by design — use one worker process.
    """

    def __init__(self, backbone: WhisperBackbone, adapter_manager: AdapterManager) -> None:
        self._backbone = backbone
        self._adapter_manager = adapter_manager

    def transcribe(
        self,
        audio: np.ndarray,
        sample_rate: int,
        language: str,
        use_agri_prompt: bool = True,  # accepted but currently unused
    ) -> TranscriptionResult:
        """
        Transcribe a float32 audio array.
        For audio > 30s, uses transformers pipeline with chunking.
        """
        t0 = time.time()

        # Activate the correct language adapter
        self._adapter_manager.activate(language)

        processor = self._backbone.processor
        model = self._adapter_manager.get_model()
        device = self._backbone.device
        duration_s = len(audio) / sample_rate

        # Whisper's encoder works on 30 s windows, so longer audio is
        # transcribed piecewise.
        if duration_s <= 30.0:
            text = self._transcribe_chunk(audio, sample_rate, language, processor, model, device)
        else:
            text = self._transcribe_long(audio, sample_rate, language, processor, model, device)

        elapsed_ms = int((time.time() - t0) * 1000)
        return TranscriptionResult(
            text=text.strip(),
            language=language,
            duration_s=duration_s,
            processing_time_ms=elapsed_ms,
        )

    def transcribe_file(self, audio_path: str, language: str) -> TranscriptionResult:
        """Load audio from disk, resampled to 16 kHz mono, and transcribe it."""
        import librosa  # imported lazily; only the file-based path needs librosa

        audio, sr = librosa.load(audio_path, sr=TARGET_SR, mono=True)
        return self.transcribe(audio, sr, language)

    def _transcribe_chunk(
        self,
        audio: np.ndarray,
        sr: int,
        language: str,
        processor,
        model,
        device: str,
    ) -> str:
        """Transcribe a single ≤30s chunk."""
        inputs = processor.feature_extractor(
            audio, sampling_rate=sr, return_tensors="pt"
        )
        input_features = inputs.input_features.to(device)
        if device == "cuda":
            # The backbone is assumed to be loaded in fp16 on GPU; match its dtype.
            input_features = input_features.half()

        # Pin language and task so the decoder does not fall back to auto-detection.
        forced_decoder_ids = processor.get_decoder_prompt_ids(
            language=language, task="transcribe"
        )

        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                forced_decoder_ids=forced_decoder_ids,
                max_new_tokens=128,
            )

        return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    def _transcribe_long(
        self,
        audio: np.ndarray,
        sr: int,
        language: str,
        processor,
        model,
        device: str,
    ) -> str:
        """Chunk audio into 30s segments and concatenate transcriptions."""
        chunk_size = TARGET_SR * 30
        chunks = [audio[i : i + chunk_size] for i in range(0, len(audio), chunk_size)]
        parts = []
        for chunk in chunks:
            text = self._transcribe_chunk(chunk, sr, language, processor, model, device)
            parts.append(text)
        return " ".join(parts)