""" STT confidence extractor. Wraps Whisper's generate() with return_dict_in_generate=True to compute avg_logprob — the mean log-probability over generated tokens. This mirrors the avg_logprob field returned by the OpenAI Whisper API. Threshold: avg_logprob < -1.0 signals a low-confidence transcription where the model was essentially guessing. The caller should treat this as "confused" and prompt the user to repeat and explain the word. """ from __future__ import annotations import logging import numpy as np import torch logger = logging.getLogger(__name__) # Anything below this is considered "confused" transcription LOW_CONFIDENCE_THRESHOLD: float = -1.0 # Message substituted for the transcript when confidence is low CONFUSION_PROMPT: str = ( "The user spoke, but I am confused. " "Ask the user in English to repeat the local word and explain its meaning." ) def transcribe_with_confidence( audio_np: np.ndarray, model, processor, forced_ids, max_new_tokens: int = 256, ) -> tuple[str, float]: """ Run Whisper and return (text, avg_logprob). avg_logprob is in (-inf, 0]. A value close to 0 means high confidence. Returns avg_logprob = 0.0 if computation fails (treated as confident). Args: audio_np: float32 audio at 16 kHz. model: WhisperForConditionalGeneration instance. processor: WhisperProcessor instance. forced_ids: Output of get_decoder_prompt_ids() or None. max_new_tokens: Maximum tokens to generate. """ inputs = processor.feature_extractor( audio_np, sampling_rate=16_000, return_tensors="pt" ) input_features = inputs.input_features with torch.no_grad(): output = model.generate( input_features, forced_decoder_ids=forced_ids, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, ) text = processor.batch_decode(output.sequences, skip_special_tokens=True)[0].strip() # Compute avg log-prob via model.compute_transition_scores avg_logprob = 0.0 try: transition_scores = model.compute_transition_scores( output.sequences, output.scores, normalize_logits=True, ) # Shape: (batch, generated_len). Take batch[0], skip zero-padded positions. scores = transition_scores[0] valid = scores[scores != 0] if valid.numel() > 0: avg_logprob = valid.mean().item() except Exception as exc: logger.debug("avg_logprob computation failed: %s", exc) logger.debug( "STT confidence: avg_logprob=%.3f text=%r", avg_logprob, text[:60], ) return text, avg_logprob