""" Public inference interface. Accepts audio as a file path or numpy array and returns transcribed text. Handles chunking for audio longer than 30 seconds. """ from __future__ import annotations import logging import os import time from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING import numpy as np import torch if TYPE_CHECKING: from src.engine.adapter_manager import AdapterManager from src.engine.whisper_base import WhisperBackbone logger = logging.getLogger(__name__) TARGET_SR = 16_000 @dataclass class TranscriptionResult: text: str language: str duration_s: float processing_time_ms: int confidence: float | None = None class Transcriber: """ Composes WhisperBackbone + AdapterManager to provide a simple transcription API. Thread-safety: Not thread-safe by design — use one worker process. """ def __init__(self, backbone: "WhisperBackbone", adapter_manager: "AdapterManager") -> None: self._backbone = backbone self._adapter_manager = adapter_manager def transcribe( self, audio: np.ndarray, sample_rate: int, language: str, use_agri_prompt: bool = True, ) -> TranscriptionResult: """ Transcribe a float32 audio array. For audio > 30s, uses transformers pipeline with chunking. """ t0 = time.time() # Activate the correct language adapter self._adapter_manager.activate(language) processor = self._backbone.processor model = self._adapter_manager.get_model() device = self._backbone.device duration_s = len(audio) / sample_rate if duration_s <= 30.0: text = self._transcribe_chunk(audio, sample_rate, language, processor, model, device) else: text = self._transcribe_long(audio, sample_rate, language, processor, model, device) elapsed_ms = int((time.time() - t0) * 1000) return TranscriptionResult( text=text.strip(), language=language, duration_s=duration_s, processing_time_ms=elapsed_ms, ) def transcribe_file(self, audio_path: str, language: str) -> TranscriptionResult: """Load audio from disk and transcribe.""" import librosa audio, sr = librosa.load(audio_path, sr=TARGET_SR, mono=True) return self.transcribe(audio, sr, language) def _transcribe_chunk( self, audio: np.ndarray, sr: int, language: str, processor, model, device: str, ) -> str: """Transcribe a single ≤30s chunk.""" inputs = processor.feature_extractor( audio, sampling_rate=sr, return_tensors="pt" ) input_features = inputs.input_features.to(device) if device == "cuda": input_features = input_features.half() forced_decoder_ids = processor.get_decoder_prompt_ids( language=language, task="transcribe" ) with torch.no_grad(): predicted_ids = model.generate( input_features, forced_decoder_ids=forced_decoder_ids, max_new_tokens=128, ) return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] def _transcribe_long( self, audio: np.ndarray, sr: int, language: str, processor, model, device: str, ) -> str: """Chunk audio into 30s segments and concatenate transcriptions.""" chunk_size = TARGET_SR * 30 chunks = [audio[i : i + chunk_size] for i in range(0, len(audio), chunk_size)] parts = [] for chunk in chunks: text = self._transcribe_chunk(chunk, sr, language, processor, model, device) parts.append(text) return " ".join(parts)