# NOTE(review): removed Hugging Face Spaces UI residue ("Spaces: Running Running")
# that was pasted above the module — it is not Python and broke the file.
| """ | |
| STT confidence extractor. | |
| Wraps Whisper's generate() with return_dict_in_generate=True to compute | |
| avg_logprob — the mean log-probability over generated tokens. This mirrors | |
| the avg_logprob field returned by the OpenAI Whisper API. | |
| Threshold: avg_logprob < -1.0 signals a low-confidence transcription where | |
| the model was essentially guessing. The caller should treat this as "confused" | |
| and prompt the user to repeat and explain the word. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import numpy as np | |
| import torch | |
| logger = logging.getLogger(__name__) | |
| # Anything below this is considered "confused" transcription | |
| LOW_CONFIDENCE_THRESHOLD: float = -1.0 | |
| # Message substituted for the transcript when confidence is low | |
| CONFUSION_PROMPT: str = ( | |
| "The user spoke, but I am confused. " | |
| "Ask the user in English to repeat the local word and explain its meaning." | |
| ) | |
def transcribe_with_confidence(
    audio_np: np.ndarray,
    model,
    processor,
    forced_ids,
    max_new_tokens: int = 256,
) -> tuple[str, float]:
    """
    Run Whisper and return ``(text, avg_logprob)``.

    avg_logprob is the mean log-probability over generated tokens, in
    (-inf, 0]; a value close to 0 means high confidence. It mirrors the
    ``avg_logprob`` field of the OpenAI Whisper API. The caller compares
    it against LOW_CONFIDENCE_THRESHOLD to detect "confused" output.

    Args:
        audio_np: float32 mono audio sampled at 16 kHz.
        model: WhisperForConditionalGeneration instance.
        processor: WhisperProcessor instance.
        forced_ids: Output of get_decoder_prompt_ids() or None.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        Tuple of (stripped transcript, avg_logprob). avg_logprob is 0.0
        if the computation fails (treated as confident, by design).
    """
    # Same logger object as the module-level one (identical name); fetched
    # locally so this function has no hidden module-level dependency.
    log = logging.getLogger(__name__)

    inputs = processor.feature_extractor(
        audio_np, sampling_rate=16_000, return_tensors="pt"
    )
    input_features = inputs.input_features

    # Fix: move the features onto the model's device and floating dtype.
    # The original left them on CPU/float32, which raises for a model
    # loaded on GPU or in fp16. getattr-guarded so plain stand-ins
    # without .device/.dtype still work.
    device = getattr(model, "device", None)
    if device is not None:
        input_features = input_features.to(device)
    model_dtype = getattr(model, "dtype", None)
    if model_dtype is not None and model_dtype.is_floating_point:
        input_features = input_features.to(model_dtype)

    with torch.no_grad():
        output = model.generate(
            input_features,
            forced_decoder_ids=forced_ids,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True,
            output_scores=True,
        )

    text = processor.batch_decode(output.sequences, skip_special_tokens=True)[0].strip()

    # Mean per-token log-prob via compute_transition_scores. Best-effort:
    # any failure falls back to 0.0 ("confident") rather than crashing STT.
    avg_logprob = 0.0
    try:
        transition_scores = model.compute_transition_scores(
            output.sequences,
            output.scores,
            normalize_logits=True,
        )
        # Shape (batch, generated_len); take batch[0]. Zero entries are
        # padded positions past EOS — exclude them from the mean. (A true
        # log-prob of exactly 0.0 is also dropped; accepted heuristic.)
        token_scores = transition_scores[0]
        valid = token_scores[token_scores != 0]
        if valid.numel() > 0:
            avg_logprob = valid.mean().item()
    except Exception as exc:  # deliberate best-effort, see fallback above
        log.debug("avg_logprob computation failed: %s", exc)

    log.debug(
        "STT confidence: avg_logprob=%.3f text=%r",
        avg_logprob,
        text[:60],
    )
    return text, avg_logprob