# ground-zero/src/engine/stt_processor.py
# Commit jefffffff9 (3657607): Add confidence loop, curiosity engine, and lightweight TTS
"""
STT confidence extractor.
Wraps Whisper's generate() with return_dict_in_generate=True to compute
avg_logprob — the mean log-probability over generated tokens. This mirrors
the avg_logprob field returned by the OpenAI Whisper API.
Threshold: avg_logprob < -1.0 signals a low-confidence transcription where
the model was essentially guessing. The caller should treat this as "confused"
and prompt the user to repeat and explain the word.
"""
from __future__ import annotations
import logging
import numpy as np
import torch
logger = logging.getLogger(__name__)
# Anything below this is considered "confused" transcription
LOW_CONFIDENCE_THRESHOLD: float = -1.0
# Message substituted for the transcript when confidence is low
CONFUSION_PROMPT: str = (
"The user spoke, but I am confused. "
"Ask the user in English to repeat the local word and explain its meaning."
)
def transcribe_with_confidence(
    audio_np: np.ndarray,
    model,
    processor,
    forced_ids,
    max_new_tokens: int = 256,
) -> tuple[str, float]:
    """
    Run Whisper and return ``(text, avg_logprob)``.

    ``avg_logprob`` is the mean log-probability over the generated tokens,
    in (-inf, 0]; a value close to 0 means high confidence.  If the score
    computation fails for any reason, 0.0 is returned so the caller treats
    the transcript as confident rather than discarding it.

    Args:
        audio_np: float32 mono audio sampled at 16 kHz.
        model: WhisperForConditionalGeneration instance.
        processor: WhisperProcessor instance.
        forced_ids: Output of get_decoder_prompt_ids() or None.
        max_new_tokens: Maximum tokens to generate.

    Returns:
        Tuple of (stripped transcript text, avg_logprob).
    """
    # Same logger object as the module-level one (getLogger is idempotent
    # per name); resolved here so the function is self-contained.
    log = logging.getLogger(__name__)

    inputs = processor.feature_extractor(
        audio_np, sampling_rate=16_000, return_tensors="pt"
    )
    # Fix: match the model's device and dtype so a CUDA and/or fp16 model
    # does not raise on CPU float32 features.  No-op for a CPU fp32 model.
    input_features = inputs.input_features.to(model.device, model.dtype)

    with torch.no_grad():
        output = model.generate(
            input_features,
            forced_decoder_ids=forced_ids,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True,  # required to get .scores back
            output_scores=True,
        )

    text = processor.batch_decode(output.sequences, skip_special_tokens=True)[0].strip()

    # Mean per-token log-prob via compute_transition_scores; this mirrors
    # the avg_logprob field returned by the OpenAI Whisper API.
    avg_logprob = 0.0
    try:
        transition_scores = model.compute_transition_scores(
            output.sequences,
            output.scores,
            normalize_logits=True,
        )
        # Shape (batch, generated_len); batch size is 1 here.  Zero entries
        # are padded positions, not real tokens, so drop them before the mean.
        scores = transition_scores[0]
        valid = scores[scores != 0]
        if valid.numel() > 0:
            avg_logprob = valid.mean().item()
    except Exception as exc:  # best-effort: confidence must never break STT
        log.debug("avg_logprob computation failed: %s", exc)

    log.debug(
        "STT confidence: avg_logprob=%.3f text=%r",
        avg_logprob,
        text[:60],
    )
    return text, avg_logprob