"""
Public inference interface.
Accepts audio as a file path or numpy array and returns transcribed text.
Handles chunking for audio longer than 30 seconds.
"""
from __future__ import annotations

import logging
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING

import numpy as np
import torch

if TYPE_CHECKING:
    from src.engine.adapter_manager import AdapterManager
    from src.engine.whisper_base import WhisperBackbone

logger = logging.getLogger(__name__)

TARGET_SR = 16_000  # Whisper models consume 16 kHz mono audio

@dataclass
class TranscriptionResult:
    text: str
    language: str
    duration_s: float
    processing_time_ms: int
    confidence: float | None = None

class Transcriber:
    """
    Composes WhisperBackbone + AdapterManager to provide a simple transcription API.

    Thread-safety: not thread-safe by design; use one worker process per instance.
    """

    def __init__(self, backbone: "WhisperBackbone", adapter_manager: "AdapterManager") -> None:
        self._backbone = backbone
        self._adapter_manager = adapter_manager

    def transcribe(
        self,
        audio: np.ndarray,
        sample_rate: int,
        language: str,
        use_agri_prompt: bool = True,
    ) -> TranscriptionResult:
        """
        Transcribe a float32 audio array.

        Audio is expected at TARGET_SR (16 kHz); `transcribe_file` resamples on
        load. Audio longer than 30 s is split into 30 s segments and the segment
        transcriptions are concatenated.

        Note: `use_agri_prompt` is accepted for API stability but is not yet
        applied in this method.
        """
        t0 = time.time()
        # Activate the language-specific adapter before decoding.
        self._adapter_manager.activate(language)
        processor = self._backbone.processor
        model = self._adapter_manager.get_model()
        device = self._backbone.device

        duration_s = len(audio) / sample_rate
        if duration_s <= 30.0:
            text = self._transcribe_chunk(audio, sample_rate, language, processor, model, device)
        else:
            text = self._transcribe_long(audio, sample_rate, language, processor, model, device)

        elapsed_ms = int((time.time() - t0) * 1000)
        return TranscriptionResult(
            text=text.strip(),
            language=language,
            duration_s=duration_s,
            processing_time_ms=elapsed_ms,
        )

    def transcribe_file(self, audio_path: str, language: str) -> TranscriptionResult:
        """Load audio from disk, resample to 16 kHz mono, and transcribe."""
        # Lazy import: librosa is only needed on the file-loading path.
        import librosa

        audio, sr = librosa.load(audio_path, sr=TARGET_SR, mono=True)
        return self.transcribe(audio, sr, language)
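    # Callers holding a raw array at another rate must resample before calling
    # transcribe(); the feature extractor does not resample. A minimal sketch
    # with librosa (44_100 is just an example source rate):
    #
    #     audio_16k = librosa.resample(audio_44k, orig_sr=44_100, target_sr=TARGET_SR)
    #     transcriber.transcribe(audio_16k, TARGET_SR, language="ha")
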
    def _transcribe_chunk(
        self,
        audio: np.ndarray,
        sr: int,
        language: str,
        processor,
        model,
        device: str,
    ) -> str:
        """Transcribe a single ≤30 s chunk."""
        # Convert raw samples to log-mel input features.
        inputs = processor.feature_extractor(
            audio, sampling_rate=sr, return_tensors="pt"
        )
        input_features = inputs.input_features.to(device)
        if device == "cuda":
            # .half() assumes the backbone loads fp16 weights on CUDA.
            input_features = input_features.half()

        # Force the decoder to the target language and the transcribe task.
        forced_decoder_ids = processor.get_decoder_prompt_ids(
            language=language, task="transcribe"
        )
        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                forced_decoder_ids=forced_decoder_ids,
                max_new_tokens=128,
            )
        return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
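    # Note: forced_decoder_ids is deprecated in newer transformers releases;
    # there, the same effect should be achievable by passing the language and
    # task directly, e.g. model.generate(input_features, language=language,
    # task="transcribe"). Verify against the installed transformers version.
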
    def _transcribe_long(
        self,
        audio: np.ndarray,
        sr: int,
        language: str,
        processor,
        model,
        device: str,
    ) -> str:
        """
        Chunk audio into 30 s segments and concatenate the transcriptions.

        Segments do not overlap, so a word straddling a boundary may be split.
        """
        chunk_size = sr * 30  # 30 s of samples at the incoming rate
        chunks = [audio[i : i + chunk_size] for i in range(0, len(audio), chunk_size)]
        parts = []
        for chunk in chunks:
            text = self._transcribe_chunk(chunk, sr, language, processor, model, device)
            parts.append(text)
        return " ".join(parts)