"""
STT confidence extractor.
Wraps Whisper's generate() with return_dict_in_generate=True to compute
avg_logprob — the mean log-probability over generated tokens. This mirrors
the avg_logprob field returned by the OpenAI Whisper API.
Threshold: avg_logprob < -1.0 signals a low-confidence transcription where
the model was essentially guessing. The caller should treat this as "confused"
and prompt the user to repeat and explain the word.
"""
from __future__ import annotations
import logging
import numpy as np
import torch
# Module-level logger (configured by the application, not here).
logger = logging.getLogger(__name__)
# Transcriptions with avg_logprob below this value are treated as "confused"
# (the model was effectively guessing); see transcribe_with_confidence().
LOW_CONFIDENCE_THRESHOLD: float = -1.0
# Canned instruction substituted for the transcript when confidence is low,
# telling the assistant to ask the user to repeat and explain the word.
CONFUSION_PROMPT: str = (
    "The user spoke, but I am confused. "
    "Ask the user in English to repeat the local word and explain its meaning."
)
def transcribe_with_confidence(
    audio_np: np.ndarray,
    model,
    processor,
    forced_ids,
    max_new_tokens: int = 256,
) -> tuple[str, float]:
    """
    Run Whisper and return (text, avg_logprob).

    avg_logprob is in (-inf, 0]. A value close to 0 means high confidence.
    Returns avg_logprob = 0.0 if computation fails (treated as confident).

    Args:
        audio_np: float32 audio at 16 kHz.
        model: WhisperForConditionalGeneration instance.
        processor: WhisperProcessor instance.
        forced_ids: Output of get_decoder_prompt_ids() or None.
        max_new_tokens: Maximum tokens to generate.
    """
    inputs = processor.feature_extractor(
        audio_np, sampling_rate=16_000, return_tensors="pt"
    )
    # Fix: move features onto the model's device and dtype. The original left
    # them on CPU float32, which raises for a model placed on GPU or loaded in
    # fp16. On a CPU/fp32 model this .to() is a no-op, so behavior is preserved.
    input_features = inputs.input_features.to(model.device, dtype=model.dtype)
    with torch.no_grad():
        output = model.generate(
            input_features,
            forced_decoder_ids=forced_ids,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True,
            output_scores=True,
        )
    text = processor.batch_decode(output.sequences, skip_special_tokens=True)[0].strip()

    # Mean per-token log-probability; normalize_logits=True applies
    # log_softmax so each transition score is a proper log-probability.
    avg_logprob = 0.0
    try:
        transition_scores = model.compute_transition_scores(
            output.sequences,
            output.scores,
            normalize_logits=True,
        )
        # Shape: (batch, generated_len). Take batch[0]; padded positions past
        # the end of a shorter sequence come back as exact zeros, so drop them.
        # NOTE(review): a genuine 0.0 logprob (probability 1.0 token) is also
        # dropped — acceptable slack for a confidence heuristic.
        scores = transition_scores[0]
        valid = scores[scores != 0]
        if valid.numel() > 0:
            avg_logprob = valid.mean().item()
    except Exception as exc:
        # Best-effort: confidence is advisory; never fail the transcription
        # itself, just fall back to "confident" (0.0).
        logger.debug("avg_logprob computation failed: %s", exc)

    logger.debug(
        "STT confidence: avg_logprob=%.3f text=%r",
        avg_logprob,
        text[:60],
    )
    return text, avg_logprob