| """ |
| PlotWeaver Voice Agent — HuggingFace Space (Gradio 6 + Python 3.13) |
| ==================================================================== |
| Hausa-first conversational AI for African banks, telecoms, and delivery. |
| |
| Pipeline (all real, running on CPU): |
| ASR (openai/whisper-small) |
| → NLU (rule-based + Qwen2.5-1.5B-Instruct fallback, see nlu.py) |
| → Dialogue FSM (see dialogue.py) |
| → TTS (facebook/mms-tts-hau) |
| |
| First turn: ~30-60s model downloads. Subsequent turns: ~5-10s on CPU. |
| """ |
| from __future__ import annotations |
| import time |
| import uuid |
| import html as html_lib |
| from typing import Optional |
|
|
| import gradio as gr |
| import numpy as np |
| import torch |
| from transformers import ( |
| VitsModel, AutoTokenizer, |
| WhisperProcessor, WhisperForConditionalGeneration, |
| ) |
|
|
| from dialogue import ( |
| DialogueState, SCENARIOS, |
| get_prompt, get_expected_slot, transition, |
| ) |
| from nlu import parse as nlu_parse |
|
|
|
|
| |
| |
| |
# Lazily-populated model caches. They stay None until load_asr() / load_tts()
# first runs; those functions fill them via `global` and reuse them afterwards.
_asr_model = _asr_processor = None  # Whisper-small model + processor
_tts_model = _tts_tokenizer = None  # MMS-TTS Hausa model + tokenizer
|
|
|
|
def load_asr():
    """Return the Whisper-small ``(model, processor)`` pair, loading it on first use.

    The pair is cached in the module globals ``_asr_model`` / ``_asr_processor``.
    Only the first call runs ``from_pretrained``; later calls return the cached
    objects unchanged.
    """
    global _asr_model, _asr_processor
    if _asr_model is not None:
        return _asr_model, _asr_processor

    print("Loading Whisper-small…")
    _asr_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    _asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    _asr_model.eval()  # inference only: put layers in eval mode
    print("Whisper-small ready.")
    return _asr_model, _asr_processor
|
|
|
|
def load_tts():
    """Return the MMS-TTS Hausa ``(model, tokenizer)`` pair, loading it on first use.

    The pair is cached in the module globals ``_tts_model`` / ``_tts_tokenizer``.
    Only the first call runs ``from_pretrained``; later calls return the cached
    objects unchanged.
    """
    global _tts_model, _tts_tokenizer
    if _tts_model is not None:
        return _tts_model, _tts_tokenizer

    print("Loading MMS-TTS Hausa…")
    _tts_model = VitsModel.from_pretrained("facebook/mms-tts-hau")
    _tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-hau")
    _tts_model.eval()  # inference only: put layers in eval mode
    print("MMS-TTS Hausa ready.")
    return _tts_model, _tts_tokenizer
|
|
|
|
def _pcm_to_float32(samples: np.ndarray) -> np.ndarray:
    """Convert PCM samples of any numeric dtype to float32 in roughly [-1, 1].

    Signed integers are divided by the dtype's maximum. Unsigned integers
    (e.g. 8-bit WAV) are re-centred on zero first. Floating-point input is
    assumed to already be in [-1, 1] and is only cast. ``np.iinfo`` raises
    ValueError for float dtypes, so float input must not reach it.
    """
    kind = samples.dtype.kind
    if kind == "i":
        return samples.astype(np.float32) / np.iinfo(samples.dtype).max
    if kind == "u":
        half = (np.iinfo(samples.dtype).max + 1) / 2
        return ((samples.astype(np.float32) - half) / half).astype(np.float32)
    return samples.astype(np.float32, copy=False)


def _prepare_whisper_audio(sample_rate: int, samples: np.ndarray,
                           target_rate: int = 16000,
                           max_seconds: int = 30) -> np.ndarray:
    """Normalise, down-mix, truncate and resample raw audio for Whisper.

    Returns a 1-D float32 array at ``target_rate``. The array may be empty when
    the clip is too short to yield a single sample after resampling.
    """
    audio = _pcm_to_float32(samples)
    if audio.ndim > 1:
        # Channels are assumed to be on axis 1 (samples, channels); average to mono.
        audio = audio.mean(axis=1)

    # Keep at most `max_seconds` of input (30 s matches Whisper's input window).
    audio = audio[: int(sample_rate * max_seconds)]

    if sample_rate != target_rate:
        num_out = int(len(audio) * target_rate / sample_rate)
        if num_out == 0:
            # scipy.signal.resample rejects a zero-length target; nothing to decode.
            return np.zeros(0, dtype=np.float32)
        import scipy.signal
        audio = scipy.signal.resample(audio, num_out)

    return audio.astype(np.float32, copy=False)


def transcribe_hausa(audio_tuple) -> str:
    """Transcribe a Gradio ``(sample_rate, samples)`` audio tuple to Hausa text.

    Args:
        audio_tuple: ``(sample_rate, numpy_array)`` as produced by a
            ``gr.Audio(type="numpy")`` component, or None.

    Returns:
        The stripped Whisper-small transcript, or "" when there is no usable
        audio (None, empty, or too short to resample).

    Raises:
        Whatever model loading / inference raises; callers are expected to catch.
    """
    if audio_tuple is None:
        return ""
    sample_rate, audio_array = audio_tuple
    if audio_array is None or len(audio_array) == 0:
        return ""

    audio = _prepare_whisper_audio(sample_rate, audio_array)
    if audio.size == 0:
        return ""

    model, processor = load_asr()
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    # NOTE(review): recent transformers releases steer Whisper users towards
    # generate(language=..., task=...) instead of forced_decoder_ids — confirm
    # against the pinned transformers version.
    forced_ids = processor.get_decoder_prompt_ids(language="hausa", task="transcribe")
    with torch.no_grad():
        ids = model.generate(inputs.input_features, forced_decoder_ids=forced_ids, max_new_tokens=128)
    return processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
|
|
|
|
def synthesize_hausa(text: str) -> Optional[tuple]:
    """Speak *text* with the MMS-TTS Hausa model.

    Args:
        text: Hausa text to synthesise.

    Returns:
        ``(sampling_rate, float32_waveform)`` in the shape ``gr.Audio`` accepts,
        or None when *text* is empty or whitespace only. The model is not
        loaded in that case.
    """
    if text.strip() == "":
        return None

    model, tokenizer = load_tts()
    encoded = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        waveform = model(**encoded).waveform
    samples = waveform.squeeze().cpu().numpy().astype(np.float32)
    return model.config.sampling_rate, samples
|
|
|
|
| |
| |
| |
| def _now() -> str: |
| return time.strftime("%H:%M") |
|
|
|
|
def _user_bubble(text: str, is_voice: bool) -> str:
    """Return the HTML for one right-aligned (user) chat bubble.

    Args:
        text: The user's message; HTML-escaped before insertion.
        is_voice: When true, render a voice-note row (play icon plus a fixed
            20-bar waveform decoration) with the transcript quoted underneath.

    Returns:
        An HTML fragment using the ``pw-*`` classes styled in
        ``render_whatsapp``. The footer time comes from ``_now()`` at call time.
    """
    escaped = html_lib.escape(text)
    if not is_voice:
        return f'<div class="pw-b user">{escaped}<div class="pw-b-meta">{_now()} ✓✓</div></div>'

    # Purely decorative: bar heights follow |sin| and are identical for every note.
    spans = []
    for idx in range(20):
        height = 4 + int(8 * abs(np.sin(idx * 0.7)))
        spans.append(f'<span style="height:{height}px;"></span>')
    waveform = "".join(spans)

    return "\n".join([
        '<div class="pw-b user">',
        '<div class="pw-voice-row">',
        '<div class="pw-voice-icon">▶</div>',
        f'<div class="pw-voice-bars">{waveform}</div>',
        "</div>",
        f'<div style="font-size:12px;color:#667781;margin-top:3px;">"{escaped}"</div>',
        f'<div class="pw-b-meta">{_now()} ✓✓</div>',
        "</div>",
    ])
|
|
|
|
def _bot_bubble(text_ha: str, text_en: str) -> str:
    """Return the HTML for one left-aligned (bot) chat bubble.

    The Hausa reply is the main line; the English translation sits underneath
    in the smaller ``pw-b-trans`` style. Both strings are HTML-escaped. The
    footer time comes from ``_now()`` at call time.
    """
    hausa, english = html_lib.escape(text_ha), html_lib.escape(text_en)
    return "\n".join([
        '<div class="pw-b bot">',
        f"<div>{hausa}</div>",
        f'<div class="pw-b-trans">{english}</div>',
        f'<div class="pw-b-meta">{_now()} ✓✓</div>',
        "</div>",
    ])
|
|
|
|
def render_whatsapp(session: dict) -> str:
    """Render the whole conversation as a self-contained WhatsApp-style HTML mock-up.

    Args:
        session: A session dict (as stored in ``gr.State``) or a falsy value.
            Reads ``vertical`` (default "bank"), ``escalate_to_human`` (default
            False) and ``history``. History entries are dicts with ``role`` "user"
            (``text``, ``is_voice``) or any other role (``text_ha``, ``text_en``).

    Returns:
        An HTML string with the phone frame, an optional escalation banner, one
        bubble per history entry (or a placeholder), and the inline ``<style>``.

    NOTE(review): bubble timestamps come from ``_now()`` at render time, so each
    re-render re-stamps the whole history with the current time. Storing a
    timestamp per history entry would preserve send times.
    """
    session = session or {}
    vertical = session.get("vertical", "bank")
    display_name = SCENARIOS[vertical]["name"]
    avatar = {"bank": "PB", "telecom": "PT", "ecommerce": "PD"}[vertical]
    escalated = session.get("escalate_to_human", False)

    bubbles = [
        _user_bubble(msg["text"], msg.get("is_voice", False))
        if msg["role"] == "user"
        else _bot_bubble(msg.get("text_ha", ""), msg.get("text_en", ""))
        for msg in session.get("history", [])
    ]

    banner = '<div class="pw-esc-banner">Session escalated to human agent</div>' if escalated else ""
    body = (
        "".join(bubbles)
        if bubbles
        else '<div style="text-align:center;color:#667781;font-size:12px;padding:40px 0;">Send a message to begin…</div>'
    )
    name = display_name

    return f"""
<div class="pw-phone">
<div class="pw-ph-header">
<div class="pw-ph-avatar">{avatar}</div>
<div>
<div class="pw-ph-name">{html_lib.escape(name)}</div>
<div class="pw-ph-status">online • voice agent</div>
</div>
</div>
<div class="pw-ph-messages">
{banner}
{body}
</div>
</div>
<style>
.pw-phone {{ max-width: 480px; margin: 0 auto; background: #ECE5DD; border-radius: 14px; overflow: hidden; border: 1px solid #ccc; display: flex; flex-direction: column; min-height: 540px; font-family: -apple-system, "Segoe UI", Roboto, sans-serif; }}
.pw-ph-header {{ background: #075E54; color: #fff; padding: 10px 14px; display: flex; align-items: center; gap: 10px; }}
.pw-ph-avatar {{ width: 36px; height: 36px; border-radius: 50%; background: #128C7E; display: flex; align-items: center; justify-content: center; font-weight: 500; font-size: 13px; color: #fff; }}
.pw-ph-name {{ font-size: 14px; font-weight: 500; line-height: 1.2; }}
.pw-ph-status {{ font-size: 11px; color: #D4EDE8; }}
.pw-ph-messages {{ flex: 1; padding: 14px 10px; background: #ECE5DD; background-image: radial-gradient(#D8CFC2 1px, transparent 1px); background-size: 18px 18px; max-height: 480px; overflow-y: auto; min-height: 420px; }}
.pw-b {{ max-width: 80%; padding: 7px 10px 5px; border-radius: 8px; margin-bottom: 6px; font-size: 13.5px; line-height: 1.4; color: #1f2d1f; word-wrap: break-word; }}
.pw-b.user {{ background: #DCF8C6; margin-left: auto; border-bottom-right-radius: 2px; }}
.pw-b.bot {{ background: #fff; margin-right: auto; border-bottom-left-radius: 2px; }}
.pw-b-meta {{ font-size: 10px; color: #667781; margin-top: 3px; text-align: right; }}
.pw-b-trans {{ font-size: 11px; color: #667781; font-style: italic; margin-top: 3px; border-top: 1px solid #E5E5E5; padding-top: 3px; }}
.pw-voice-row {{ display: flex; align-items: center; gap: 8px; }}
.pw-voice-icon {{ width: 22px; height: 22px; border-radius: 50%; background: #128C7E; color: #fff; font-size: 10px; display: flex; align-items: center; justify-content: center; }}
.pw-voice-bars {{ flex: 1; height: 14px; display: flex; align-items: center; gap: 2px; }}
.pw-voice-bars span {{ flex: 1; background: #8D9A9F; border-radius: 1px; }}
.pw-esc-banner {{ background: #FAEEDA; color: #854F0B; font-size: 12px; padding: 8px 12px; border-radius: 8px; margin-bottom: 10px; border: 1px solid #EF9F27; text-align: center; }}
</style>
"""
|
|
|
|
| |
| |
| |
def run_turn(user_text: str, session: dict, is_voice: bool = False):
    """Advance the dialogue by one user utterance.

    Steps: restore (or create a fresh "bank") DialogueState, run NLU against
    the slot the current state expects, apply the FSM transition, append the
    user message and the bot's Hausa/English prompt to the history, and
    synthesise the Hausa prompt. TTS failures are printed and yield no audio.

    Returns:
        ``(updated_session_dict, bot_audio)`` where ``bot_audio`` is
        ``synthesize_hausa``'s result or None.
    """
    state = None
    if session:
        state = DialogueState.from_dict(session)
    if state is None:
        state = DialogueState(session_id=f"sess_{uuid.uuid4().hex[:8]}", vertical="bank")

    slot = get_expected_slot(state.vertical, state.current_state)
    intent, entities, _unused = nlu_parse(user_text, slot)
    state = transition(state, intent, entities)

    reply = get_prompt(state.vertical, state.current_state)
    state.history.append({"role": "user", "text": user_text, "is_voice": is_voice})
    state.history.append({"role": "bot", "text_ha": reply["ha"], "text_en": reply["en"]})

    audio = None
    try:
        audio = synthesize_hausa(reply["ha"])
    except Exception as exc:
        # Best effort: a TTS failure should not lose the text turn.
        print(f"TTS failed: {exc}")

    return state.to_dict(), audio
|
|
|
|
| |
| |
| |
def on_vertical_change(vertical: str):
    """Start a fresh session for *vertical*, seeded with that vertical's greeting.

    Returns:
        ``(session_dict, rendered_html, None)``. The trailing None clears the bot audio.
    """
    fresh = DialogueState(session_id=f"sess_{uuid.uuid4().hex[:8]}", vertical=vertical)
    greeting = get_prompt(vertical, "greeting")
    fresh.history.append({"role": "bot", "text_ha": greeting["ha"], "text_en": greeting["en"]})
    new_session = fresh.to_dict()
    return new_session, render_whatsapp(new_session), None
|
|
|
|
def on_text_submit(text: str, session: dict):
    """Handle a typed message.

    Returns ``(session, html, bot_audio, "")``; the trailing "" clears the textbox.
    Blank or whitespace-only input leaves the session untouched and yields no audio.
    """
    if text and text.strip():
        new_session, audio = run_turn(text, session, is_voice=False)
        return new_session, render_whatsapp(new_session), audio, ""
    return session, render_whatsapp(session), None, ""
|
|
|
|
def on_audio_submit(audio_data, session: dict):
    """Handle recorded/uploaded audio: transcribe it, then run a dialogue turn.

    Returns ``(session, html, bot_audio)``. With no audio, an ASR error (printed),
    or an empty transcript, the session is returned unchanged with no bot audio.
    """
    def _no_change():
        return session, render_whatsapp(session), None

    if audio_data is None:
        return _no_change()
    try:
        transcript = transcribe_hausa(audio_data)
    except Exception as exc:
        print(f"ASR failed: {exc}")
        return _no_change()
    if not transcript:
        return _no_change()

    new_session, audio = run_turn(transcript, session, is_voice=True)
    return new_session, render_whatsapp(new_session), audio
|
|
|
|
def on_reset(session: dict):
    """Restart the conversation, keeping the session's vertical (default "bank")."""
    current_vertical = (session or {}).get("vertical", "bank")
    return on_vertical_change(current_vertical)
|
|
|
|
| |
| |
| |
# Page-level CSS: centre the app in a narrow column and pad the phone mock-up.
CUSTOM_CSS = """
.gradio-container { max-width: 720px !important; margin: 0 auto !important; }
#whatsapp-container { padding: 20px 0; }
"""


with gr.Blocks(css=CUSTOM_CSS, title="PlotWeaver Voice Agent") as demo:
    gr.HTML("""
<div style="text-align:center; padding: 0 0 12px;">
<h1 style="margin:0 0 4px; font-size: 22px; font-weight: 500;">PlotWeaver Voice Agent</h1>
<p style="margin:0; color: #5f5e5a; font-size: 14px;">Hausa-first conversational AI — pick a vertical, type or speak in Hausa</p>
</div>
""")

    # Per-browser-session dialogue state: the dict produced by DialogueState.to_dict().
    session_state = gr.State({})

    vertical_radio = gr.Radio(
        choices=[("PlotWeaver Bank", "bank"),
                 ("PlotWeaver Telecom", "telecom"),
                 ("PlotWeaver Delivery", "ecommerce")],
        value="bank",
        label="Vertical",
        container=False,
    )

    # The WhatsApp-style transcript, fully re-rendered after every event.
    whatsapp_html = gr.HTML(elem_id="whatsapp-container")

    with gr.Row():
        text_input = gr.Textbox(
            placeholder="Type in Hausa… e.g. 'duba ma'auni'",
            label="",
            scale=4,
            container=False,
        )
        send_btn = gr.Button("Send", scale=1, variant="primary")
        reset_btn = gr.Button("Reset", scale=1)

    audio_input = gr.Audio(
        sources=["microphone", "upload"],
        type="numpy",
        label="Record or upload Hausa audio (click Stop when done recording)",
    )

    bot_audio = gr.Audio(
        label="Bot response (Hausa TTS)",
        autoplay=True,
        interactive=False,
    )

    # ---- Event wiring -------------------------------------------------------
    demo.load(
        fn=lambda: on_vertical_change("bank"),
        outputs=[session_state, whatsapp_html, bot_audio],
    )
    vertical_radio.change(
        fn=on_vertical_change,
        inputs=[vertical_radio],
        outputs=[session_state, whatsapp_html, bot_audio],
    )
    send_btn.click(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=[session_state, whatsapp_html, bot_audio, text_input],
    )
    text_input.submit(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=[session_state, whatsapp_html, bot_audio, text_input],
    )
    # Microphone recordings are processed when the user clicks Stop ...
    audio_input.stop_recording(
        fn=on_audio_submit,
        inputs=[audio_input, session_state],
        outputs=[session_state, whatsapp_html, bot_audio],
    )
    # ... but files chosen through the "upload" source never fire stop_recording.
    # Without this listener, uploaded audio would be silently ignored.
    audio_input.upload(
        fn=on_audio_submit,
        inputs=[audio_input, session_state],
        outputs=[session_state, whatsapp_html, bot_audio],
    )
    reset_btn.click(
        fn=on_reset,
        inputs=[session_state],
        outputs=[session_state, whatsapp_html, bot_audio],
    )


if __name__ == "__main__":
    # Bind on all interfaces, port 7860 (the port HuggingFace Spaces expects).
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|