"""
Speech Fluency Analysis - Hugging Face Gradio App
WavLM stutter detection + Whisper transcription.
"""
| |
|
| | import os |
| | import torch |
| | import torch.nn as nn |
| | import torchaudio |
| | import numpy as np |
| | import gradio as gr |
| | from datetime import datetime |
| | from transformers import WavLMModel |
| |
|
# Classifier output classes; position i corresponds to logit/probability i.
STUTTER_LABELS = [
    "Prolongation",
    "Block",
    "SoundRep",
    "WordRep",
    "Interjection",
]

# Short description of each label, used when building the recommendations text.
STUTTER_INFO = dict(
    Prolongation="Sound stretched longer than normal (e.g. 'Ssssnake')",
    Block="Complete stoppage of airflow/sound with tension",
    SoundRep="Sound/syllable repetition (e.g. 'B-b-b-ball')",
    WordRep="Whole word repetition (e.g. 'I-I-I want')",
    Interjection="Filler words like 'um', 'uh', 'like'",
)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
| |
|
| |
|
class WaveLmStutterClassification(nn.Module):
    """Frozen ``microsoft/wavlm-base`` encoder with a linear classification head.

    The attribute names ``wavlm`` and ``classifier`` are part of the
    state-dict layout and must stay as they are for checkpoints to load.
    """

    def __init__(self, num_labels=5):
        """Build the encoder and a ``hidden_size -> num_labels`` linear head."""
        super().__init__()
        self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base")
        self.hidden_size = self.wavlm.config.hidden_size
        # Freeze the whole backbone; only the linear head keeps gradients.
        self.wavlm.requires_grad_(False)
        self.classifier = nn.Linear(self.hidden_size, num_labels)

    def forward(self, x, attention_mask=None):
        """Return raw logits, one per label, for each item in the batch.

        ``attention_mask`` is forwarded to the encoder only; the pooling
        below is a plain mean over dim 1 of ``last_hidden_state``.
        """
        encoded = self.wavlm(x, attention_mask=attention_mask)
        pooled = encoded.last_hidden_state.mean(dim=1)
        return self.classifier(pooled)
| |
|
| |
|
# Module-level model handles, filled in by load_models() on first use.
wavlm_model = whisper_model = None
# Flipped to True once both models have been loaded successfully.
models_loaded = False
| |
|
| |
|
def load_models():
    """Load the WavLM stutter classifier and Whisper once per process.

    Populates the module-level ``wavlm_model`` and ``whisper_model`` and sets
    ``models_loaded``. Later calls return immediately.

    Returns:
        True once both models are available. Loading errors propagate as
        exceptions; ``models_loaded`` stays False in that case so a later
        call retries.
    """
    global wavlm_model, whisper_model, models_loaded
    if models_loaded:
        return True

    print("Loading WavLM ...")
    # One output per entry of STUTTER_LABELS (currently 5).
    wavlm_model = WaveLmStutterClassification(num_labels=len(STUTTER_LABELS))
    ckpt = "wavlm_stutter_classification_best.pth"
    if os.path.exists(ckpt):
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary
        # objects, so this must only ever point at a trusted checkpoint.
        # Consider weights_only=True if the file holds nothing but tensors.
        state = torch.load(ckpt, map_location=DEVICE, weights_only=False)
        # Accept both a bare state dict and a training checkpoint wrapper.
        if isinstance(state, dict) and "model_state_dict" in state:
            state = state["model_state_dict"]
        wavlm_model.load_state_dict(state)
    else:
        # Without the checkpoint the linear head keeps its random
        # initialisation; make that visible instead of failing silently.
        print(
            f"WARNING: checkpoint '{ckpt}' not found; the classifier head is "
            "randomly initialised and predictions will be meaningless."
        )
    wavlm_model.to(DEVICE).eval()

    print("Loading Whisper ...")
    # Imported lazily so the module can be imported without whisper loaded.
    import whisper

    whisper_model = whisper.load_model("base", device=DEVICE)

    models_loaded = True
    print("Models ready.")
    return True
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
def load_audio(path):
    """Read an audio file with torchaudio and return it as 16 kHz mono.

    Returns:
        ``(waveform, 16000)`` where ``waveform`` has its leading channel
        dimension squeezed away.
    """
    target_sr = 16000
    signal, native_sr = torchaudio.load(path)

    # Down-mix multi-channel audio by averaging across channels.
    if signal.size(0) > 1:
        signal = signal.mean(dim=0, keepdim=True)

    if native_sr != target_sr:
        resampler = torchaudio.transforms.Resample(native_sr, target_sr)
        signal = resampler(signal)

    return signal.squeeze(0), target_sr
| |
|
| |
|
def analyze_chunk(chunk, threshold=0.5):
    """Score one audio chunk with the WavLM stutter classifier.

    Returns:
        ``(detected, prob_dict)``: the labels whose sigmoid probability is
        strictly greater than ``threshold``, and a dict mapping every label
        to its probability rounded to three decimals.
    """
    batch = chunk.unsqueeze(0).to(DEVICE)  # add a batch dimension of 1
    with torch.no_grad():
        scores = torch.sigmoid(wavlm_model(batch)).cpu().numpy()[0]

    detected = []
    for idx, score in enumerate(scores):
        if score > threshold:
            detected.append(STUTTER_LABELS[idx])

    prob_dict = {
        label: round(float(score), 3)
        for label, score in zip(STUTTER_LABELS, scores)
    }
    return detected, prob_dict
| |
|
| |
|
def _severity_label(stutter_pct):
    """Return the severity band for a percentage of stutter-affected chunks."""
    bands = ((5, "Very Mild"), (10, "Mild"), (20, "Moderate"), (30, "Severe"))
    for upper_bound, label in bands:
        if stutter_pct < upper_bound:
            return label
    return "Very Severe"


def _format_summary(duration, word_count, wpm, counts, affected_chunks,
                    total_chunks, stutter_pct, severity):
    """Render the metrics table and per-label counts as Markdown."""
    lines = [
        "## Analysis Results\n",
        "| Metric | Value |",
        "|--------|-------|",
        f"| Duration | {duration:.1f}s |",
        f"| Words | {word_count} |",
        f"| Speaking Rate | {wpm:.0f} wpm |",
        f"| Stutter Events | {sum(counts.values())} |",
        f"| Affected Chunks | {affected_chunks}/{total_chunks} ({stutter_pct:.1f}%) |",
        f"| Severity | **{severity}** |",
        "",
        "### Stutter Counts",
        "",
    ]
    for label in STUTTER_LABELS:
        c = counts[label]
        bar = "X" * min(c, 20)  # cap the text bar at 20 characters
        icon = "!" if c > 0 else "o"
        lines.append(f"- {icon} **{label}**: {c} {bar}")
    return "\n".join(lines)


def _format_timeline(rows):
    """Render the per-chunk detections as a two-column Markdown table."""
    lines = ["| Time | Detected |", "|------|----------|"]
    lines.extend(
        f"| {row['time']} | {', '.join(row['detected'])} |" for row in rows
    )
    return "\n".join(lines)


def _format_recommendations(severity, counts, wpm):
    """Render severity-based advice plus the label glossary as Markdown."""
    recs = ["## Recommendations\n"]
    if severity in ("Very Mild", "Mild"):
        recs.append("- Stuttering is within the mild range. Regular monitoring is recommended.")
    elif severity == "Moderate":
        recs.append("- Consider speech therapy consultation for fluency-enhancing techniques.")
    else:
        recs.append("- Professional speech-language pathology evaluation is strongly recommended.")

    dominant = max(counts, key=counts.get)
    if counts[dominant] > 0:
        recs.append(f"- Most frequent type: **{dominant}** - {STUTTER_INFO[dominant]}")

    if wpm > 180:
        recs.append(f"- Speaking rate is high ({wpm:.0f} wpm). Slower speech may reduce stuttering.")

    recs.append("\n### Stutter Type Definitions\n")
    for label, desc in STUTTER_INFO.items():
        recs.append(f"- **{label}**: {desc}")
    return "\n".join(recs)


def analyze_audio(audio_path, threshold, progress=gr.Progress()):
    """Main pipeline: chunk -> WavLM -> Whisper -> formatted results.

    Args:
        audio_path: Path to an audio file, a ``(sample_rate, data)`` tuple
            (written to a temporary WAV first), or None.
        threshold: Probability above which a label counts as detected.
        progress: Gradio progress callback.

    Returns:
        A 4-tuple ``(summary_md, transcription, timeline_md, recs_md)``.
        On early exit the first element carries a message and the rest are
        empty strings.
    """
    if audio_path is None:
        return "Upload an audio file first.", "", "", ""

    temp_path = None
    try:
        if isinstance(audio_path, tuple):
            import tempfile

            import soundfile as sf

            sr, data = audio_path
            # Close the handle straight away: only the name is needed, and an
            # open handle blocks re-opening the file on some platforms.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                temp_path = tmp.name
            sf.write(temp_path, data, sr)
            audio_path = temp_path

        progress(0.05, desc="Loading models ...")
        if not models_loaded and not load_models():
            return "Failed to load models.", "", "", ""

        progress(0.15, desc="Loading audio ...")
        waveform, sr = load_audio(audio_path)
        if len(waveform) == 0:
            # Nothing to chunk or transcribe.
            return "The audio file contains no samples.", "", "", ""
        duration = len(waveform) / sr

        progress(0.25, desc="Detecting stutters ...")
        chunk_samples = 3 * sr  # fixed 3-second analysis windows
        counts = dict.fromkeys(STUTTER_LABELS, 0)
        timeline_rows = []
        affected_chunks = 0
        # Ceiling division so a trailing partial chunk is counted.
        total_chunks = max(1, (len(waveform) + chunk_samples - 1) // chunk_samples)

        for i, start in enumerate(range(0, len(waveform), chunk_samples)):
            progress(0.25 + 0.45 * (i / total_chunks), desc=f"Chunk {i+1}/{total_chunks} ...")
            end = min(start + chunk_samples, len(waveform))
            chunk = waveform[start:end]
            if len(chunk) < chunk_samples:
                # Zero-pad the last chunk up to the full window length.
                chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

            detected, probs = analyze_chunk(chunk, threshold)
            for label in detected:
                counts[label] += 1
            if detected:
                affected_chunks += 1

            timeline_rows.append({
                "time": f"{start/sr:.1f}-{end/sr:.1f}s",
                "detected": detected or ["Fluent"],
                "probs": probs,
            })

        progress(0.75, desc="Transcribing ...")
        transcription = whisper_model.transcribe(audio_path).get("text", "").strip()

        progress(0.90, desc="Building report ...")
        stutter_pct = (affected_chunks / total_chunks) * 100
        word_count = len(transcription.split()) if transcription else 0
        wpm = (word_count / duration) * 60 if duration > 0 else 0
        severity = _severity_label(stutter_pct)

        summary_md = _format_summary(
            duration, word_count, wpm, counts,
            affected_chunks, total_chunks, stutter_pct, severity,
        )
        timeline_md = _format_timeline(timeline_rows)
        recs_md = _format_recommendations(severity, counts, wpm)

        progress(1.0, desc="Done!")
        return summary_md, transcription, timeline_md, recs_md
    finally:
        # Remove the temporary WAV created above, if any.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass  # best-effort cleanup; never mask the real result/error
| |
|
| |
|
# Extra CSS handed to gr.Blocks: caps the page width and recolours the
# primary button.
CUSTOM_CSS = (
    "\n"
    ".gradio-container { max-width: 960px !important; }\n"
    ".gr-button-primary { background: #0f766e !important; }\n"
)
| |
|
# NOTE(review): the source had lost its indentation; the nesting below is a
# reconstruction. In particular, placing the Tabs under the Row (rather than
# inside the right-hand column) is an assumption -- confirm against the
# deployed layout.
with gr.Blocks(
    title="Speech Fluency Analysis",
    css=CUSTOM_CSS,
    theme=gr.themes.Soft(),
) as demo:

    # Page header and supported-format hint.
    gr.Markdown(
        """
        # Speech Fluency Analysis
        Upload an audio file to detect stuttering patterns using **WavLM** (stutter detection)
        and **Whisper** (transcription).

        Supported formats: **WAV, MP3, M4A, FLAC, OGG**
        """
    )

    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            audio_in = gr.Audio(label="Upload Audio", type="filepath")
            threshold = gr.Slider(
                minimum=0.3,
                maximum=0.7,
                value=0.5,
                step=0.05,
                label="Detection Threshold",
                info="Lower = more sensitive, Higher = more strict",
            )
            btn = gr.Button("Analyze", variant="primary", size="lg")

        # Right column: summary report.
        with gr.Column(scale=2):
            summary_out = gr.Markdown(value="*Upload audio and click **Analyze** to start.*")

    with gr.Tabs():
        with gr.TabItem("Transcription"):
            trans_out = gr.Textbox(label="Whisper Transcription", lines=6, interactive=False)
        with gr.TabItem("Timeline"):
            timeline_out = gr.Markdown()
        with gr.TabItem("Recommendations"):
            recs_out = gr.Markdown()

    gr.Markdown(
        "---\n*Disclaimer: AI-assisted analysis for clinical support only. "
        "Consult a qualified Speech-Language Pathologist for diagnosis.*"
    )

    # The four outputs line up with the 4-tuple returned by analyze_audio.
    btn.click(
        fn=analyze_audio,
        inputs=[audio_in, threshold],
        outputs=[summary_out, trans_out, timeline_out, recs_out],
        show_progress="full",
    )
| |
|
# Load both models before serving; analyze_audio() would otherwise load them
# lazily on the first request.
print("Loading models at startup ...")
load_models()

# Enable the request queue, then start the server.
print("Launching Gradio ...")
demo.queue()
demo.launch(ssr_mode=False)
| |
|