"""
Public inference interface.
Accepts audio as a file path or numpy array and returns transcribed text.
Handles chunking for audio longer than 30 seconds.
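
Example (a minimal sketch; assumes a WhisperBackbone and AdapterManager
have already been constructed elsewhere):

    transcriber = Transcriber(backbone, adapter_manager)
    result = transcriber.transcribe_file("clip.wav", language="en")
    print(result.text)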
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
import torch
if TYPE_CHECKING:
    from src.engine.adapter_manager import AdapterManager
    from src.engine.whisper_base import WhisperBackbone
logger = logging.getLogger(__name__)
TARGET_SR = 16_000
@dataclass
class TranscriptionResult:
    text: str
    language: str
    duration_s: float
    processing_time_ms: int
    confidence: float | None = None

class Transcriber:
"""
Composes WhisperBackbone + AdapterManager to provide a simple transcription API.
Thread-safety: Not thread-safe by design — use one worker process.
"""

    def __init__(self, backbone: "WhisperBackbone", adapter_manager: "AdapterManager") -> None:
        self._backbone = backbone
        self._adapter_manager = adapter_manager

    def transcribe(
        self,
        audio: np.ndarray,
        sample_rate: int,
        language: str,
        use_agri_prompt: bool = True,  # accepted but not currently used in this method
    ) -> TranscriptionResult:
        """
        Transcribe a float32 audio array.

        Audio longer than 30 s is split into 30 s chunks and the
        per-chunk transcriptions are concatenated.
        """
        t0 = time.time()
        # Activate the correct language adapter
        self._adapter_manager.activate(language)
        processor = self._backbone.processor
        model = self._adapter_manager.get_model()
        device = self._backbone.device
        duration_s = len(audio) / sample_rate
        if duration_s <= 30.0:
            text = self._transcribe_chunk(audio, sample_rate, language, processor, model, device)
        else:
            text = self._transcribe_long(audio, sample_rate, language, processor, model, device)
        elapsed_ms = int((time.time() - t0) * 1000)
        return TranscriptionResult(
            text=text.strip(),
            language=language,
            duration_s=duration_s,
            processing_time_ms=elapsed_ms,
        )

    def transcribe_file(self, audio_path: str, language: str) -> TranscriptionResult:
        """Load audio from disk and transcribe."""
        import librosa

        # librosa resamples to TARGET_SR and downmixes to mono on load.
        audio, sr = librosa.load(audio_path, sr=TARGET_SR, mono=True)
        return self.transcribe(audio, sr, language)

    def _transcribe_chunk(
        self,
        audio: np.ndarray,
        sr: int,
        language: str,
        processor,
        model,
        device: str,
    ) -> str:
        """Transcribe a single ≤30s chunk."""
        inputs = processor.feature_extractor(
            audio, sampling_rate=sr, return_tensors="pt"
        )
        input_features = inputs.input_features.to(device)
        if device == "cuda":
            # Match the half-precision weights used on GPU.
            input_features = input_features.half()
        # Force the decoder to the requested language and transcription task.
        forced_decoder_ids = processor.get_decoder_prompt_ids(
            language=language, task="transcribe"
        )
        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                forced_decoder_ids=forced_decoder_ids,
                max_new_tokens=128,
            )
        return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    def _transcribe_long(
        self,
        audio: np.ndarray,
        sr: int,
        language: str,
        processor,
        model,
        device: str,
    ) -> str:
        """Chunk audio into 30s segments and concatenate transcriptions."""
        # Size chunks by the actual sample rate so each one spans 30 s.
        chunk_size = sr * 30
        chunks = [audio[i : i + chunk_size] for i in range(0, len(audio), chunk_size)]
        parts = []
        for chunk in chunks:
            text = self._transcribe_chunk(chunk, sr, language, processor, model, device)
            parts.append(text)
        return " ".join(parts)