"""
STT confidence extractor.

Wraps Whisper's generate() with return_dict_in_generate=True to compute
avg_logprob — the mean log-probability over generated tokens. This mirrors
the avg_logprob field returned by the OpenAI Whisper API.

Threshold: avg_logprob < -1.0 signals a low-confidence transcription where
the model was essentially guessing.  The caller should treat this as "confused"
and prompt the user to repeat and explain the word.
"""
from __future__ import annotations

import logging

import numpy as np
import torch

logger = logging.getLogger(__name__)

# Anything below this is considered "confused" transcription
LOW_CONFIDENCE_THRESHOLD: float = -1.0

# Message substituted for the transcript when confidence is low
CONFUSION_PROMPT: str = (
    "The user spoke, but I am confused. "
    "Ask the user in English to repeat the local word and explain its meaning."
)


def transcribe_with_confidence(
    audio_np: np.ndarray,
    model,
    processor,
    forced_ids,
    max_new_tokens: int = 256,
) -> tuple[str, float]:
    """
    Run Whisper and return (text, avg_logprob).

    avg_logprob is in (-inf, 0].  A value close to 0 means high confidence.
    Returns avg_logprob = 0.0 if computation fails (treated as confident).

    Args:
        audio_np:       float32 audio at 16 kHz.
        model:          WhisperForConditionalGeneration instance.
        processor:      WhisperProcessor instance.
        forced_ids:     Output of get_decoder_prompt_ids() or None.
        max_new_tokens: Maximum tokens to generate.
    """
    inputs = processor.feature_extractor(
        audio_np, sampling_rate=16_000, return_tensors="pt"
    )
    # Move features to the model's device (no-op on CPU).
    input_features = inputs.input_features.to(model.device)

    with torch.no_grad():
        output = model.generate(
            input_features,
            forced_decoder_ids=forced_ids,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True,
            output_scores=True,
        )

    text = processor.batch_decode(output.sequences, skip_special_tokens=True)[0].strip()

    # Compute avg log-prob via model.compute_transition_scores
    avg_logprob = 0.0
    try:
        transition_scores = model.compute_transition_scores(
            output.sequences,
            output.scores,
            normalize_logits=True,
        )
        # Shape: (batch, generated_len). Take batch[0], skip zero-padded positions.
        scores = transition_scores[0]
        valid = scores[scores != 0]
        if valid.numel() > 0:
            avg_logprob = valid.mean().item()
    except Exception as exc:
        logger.debug("avg_logprob computation failed: %s", exc)

    logger.debug(
        "STT confidence: avg_logprob=%.3f  text=%r",
        avg_logprob,
        text[:60],
    )
    return text, avg_logprob
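

# The module docstring says the caller should treat avg_logprob below the
# threshold as "confused" and substitute CONFUSION_PROMPT for the transcript.
# A minimal caller-side sketch of that logic (resolve_transcript is a
# hypothetical helper name, not part of the original module):
def resolve_transcript(text: str, avg_logprob: float) -> str:
    """Return the transcript, or CONFUSION_PROMPT when confidence is low.

    Values strictly below LOW_CONFIDENCE_THRESHOLD trigger the substitution;
    the 0.0 fallback from transcribe_with_confidence passes through unchanged.
    """
    if avg_logprob < LOW_CONFIDENCE_THRESHOLD:
        return CONFUSION_PROMPT
    return text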