"""
PlotWeaver Voice Agent — HuggingFace Space (Gradio 6 + Python 3.13)
====================================================================
Hausa-first conversational AI for African banks, telecoms, and delivery.
Pipeline (all real, running on CPU):
ASR (openai/whisper-small)
→ NLU (rule-based + Qwen2.5-1.5B-Instruct fallback, see nlu.py)
→ Dialogue FSM (see dialogue.py)
→ TTS (facebook/mms-tts-hau)
First turn: ~30-60s model downloads. Subsequent turns: ~5-10s on CPU.
"""
from __future__ import annotations
import time
import uuid
import html as html_lib
from typing import Optional
import gradio as gr
import numpy as np
import torch
from transformers import (
VitsModel, AutoTokenizer,
WhisperProcessor, WhisperForConditionalGeneration,
)
from dialogue import (
DialogueState, SCENARIOS,
get_prompt, get_expected_slot, transition,
)
from nlu import parse as nlu_parse
# ---------------------------------------------------------------------------
# Model loading (lazy, cached)
# ---------------------------------------------------------------------------
# Lazily-initialized model singletons. They stay None until the first call to
# load_asr() / load_tts(), so Space startup is fast and each model is
# downloaded and initialized at most once per process.
_asr_model = None
_asr_processor = None
_tts_model = None
_tts_tokenizer = None
def load_asr():
    """Return the cached (model, processor) pair for Whisper-small.

    Loads both from the HuggingFace hub on first use and stores them in
    module-level globals, so later turns skip the slow download/initialization.
    The model is put in eval mode since we only ever run inference.
    """
    global _asr_model, _asr_processor
    if _asr_model is not None:
        return _asr_model, _asr_processor
    print("Loading Whisper-small…")
    _asr_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    _asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    _asr_model.eval()
    print("Whisper-small ready.")
    return _asr_model, _asr_processor
def load_tts():
    """Return the cached (model, tokenizer) pair for MMS-TTS Hausa.

    Mirrors load_asr(): first call downloads and caches the VITS model and its
    tokenizer in module-level globals; subsequent calls are cheap lookups.
    """
    global _tts_model, _tts_tokenizer
    if _tts_model is not None:
        return _tts_model, _tts_tokenizer
    print("Loading MMS-TTS Hausa…")
    _tts_model = VitsModel.from_pretrained("facebook/mms-tts-hau")
    _tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-hau")
    _tts_model.eval()
    print("MMS-TTS Hausa ready.")
    return _tts_model, _tts_tokenizer
def transcribe_hausa(audio_tuple) -> str:
    """Transcribe a Gradio ``(sample_rate, np.ndarray)`` audio tuple to Hausa text.

    Returns "" for missing or empty audio. The waveform is normalized to mono
    float32 in [-1, 1], truncated to 30 s, and resampled to 16 kHz before
    Whisper inference.
    """
    if audio_tuple is None:
        return ""
    sample_rate, audio_array = audio_tuple
    if audio_array is None or len(audio_array) == 0:
        return ""
    # Normalize to float32. np.iinfo() only accepts integer dtypes, so float
    # input (e.g. float64 from some sources) must be cast directly — the old
    # unconditional np.iinfo() call raised ValueError on float input.
    if np.issubdtype(audio_array.dtype, np.integer):
        audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
    elif audio_array.dtype != np.float32:
        audio_array = audio_array.astype(np.float32)
    # Downmix multi-channel (e.g. stereo) recordings to mono.
    if audio_array.ndim > 1:
        audio_array = audio_array.mean(axis=1)
    # Cap at 30s (Whisper training chunk size)
    max_samples = sample_rate * 30
    if len(audio_array) > max_samples:
        audio_array = audio_array[:max_samples]
    # Whisper expects 16 kHz input.
    if sample_rate != 16000:
        import scipy.signal
        num_samples = int(len(audio_array) * 16000 / sample_rate)
        audio_array = scipy.signal.resample(audio_array, num_samples).astype(np.float32)
    model, processor = load_asr()
    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
    # Pin decoding to Hausa transcription. NOTE: forced_decoder_ids is
    # deprecated in newer transformers in favour of
    # generate(language=..., task=...); kept for compatibility here.
    forced_ids = processor.get_decoder_prompt_ids(language="hausa", task="transcribe")
    with torch.no_grad():
        ids = model.generate(inputs.input_features, forced_decoder_ids=forced_ids, max_new_tokens=128)
    text = processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
    return text
def synthesize_hausa(text: str) -> Optional[tuple]:
    """Synthesize Hausa speech for ``text``.

    Returns a Gradio-compatible ``(sample_rate, float32 waveform)`` tuple, or
    None when the text is empty/whitespace-only.
    """
    if not text.strip():
        return None
    model, tokenizer = load_tts()
    encoded = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        waveform = model(**encoded).waveform
    samples = waveform.squeeze().cpu().numpy().astype(np.float32)
    return (model.config.sampling_rate, samples)
# ---------------------------------------------------------------------------
# WhatsApp-style HTML rendering
# ---------------------------------------------------------------------------
def _now() -> str:
return time.strftime("%H:%M")
def _user_bubble(text: str, is_voice: bool) -> str:
    """Render one outgoing (user) chat message as an HTML fragment.

    NOTE(review): the f-string templates below are empty — the bubble HTML
    markup (and the interpolations of ``t`` and ``bars``) appears to have
    been stripped from this file, so the function currently returns only
    whitespace. Restore the original markup — verify against version control.
    """
    # Escape user-supplied text so it is safe to embed in HTML.
    t = html_lib.escape(text)
    if is_voice:
        # 20 faux waveform bars for a voice-note style bubble.
        bars = "".join(
            f''
            for i in range(20)
        )
        return f'''
'''
    return f''
def _bot_bubble(text_ha: str, text_en: str) -> str:
    """Render one incoming (bot) chat message — Hausa text with an English gloss.

    NOTE(review): the returned f-string is empty — the bubble HTML markup
    (including the interpolations of ``ha`` and ``en``) appears to have been
    stripped from this file. Restore it from version control.
    """
    # Escape both variants before embedding in HTML.
    ha = html_lib.escape(text_ha)
    en = html_lib.escape(text_en)
    return f''''''
def render_whatsapp(session: dict) -> str:
    """Render the whole conversation as a WhatsApp-style HTML panel.

    NOTE(review): the original HTML templates were stripped from this file —
    several single-quoted string literals were left spanning lines, which is a
    SyntaxError. The markup below is a minimal, valid reconstruction that
    preserves the visible text; restore the original styled templates from
    version control if available.
    """
    vertical = session.get("vertical", "bank") if session else "bank"
    name = SCENARIOS[vertical]["name"]
    # Two-letter avatar badge per vertical.
    avatar = {"bank": "PB", "telecom": "PT", "ecommerce": "PD"}[vertical]
    escalated = session.get("escalate_to_human", False) if session else False
    bubbles = []
    for msg in session.get("history", []) if session else []:
        if msg["role"] == "user":
            bubbles.append(_user_bubble(msg["text"], msg.get("is_voice", False)))
        else:
            bubbles.append(_bot_bubble(msg.get("text_ha", ""), msg.get("text_en", "")))
    banner = (
        '<div class="escalation-banner">Session escalated to human agent</div>'
        if escalated
        else ""
    )
    if not bubbles:
        body = '<div class="empty-hint">Send a message to begin…</div>'
    else:
        body = "".join(bubbles)
    return f"""<div class="whatsapp">
  <div class="header"><span class="avatar">{avatar}</span><span class="name">{html_lib.escape(name)}</span></div>
  {banner}
  <div class="messages">{body}</div>
</div>"""
# ---------------------------------------------------------------------------
# Core turn handler
# ---------------------------------------------------------------------------
def run_turn(user_text: str, session: dict, is_voice: bool = False):
    """Advance the dialogue by one turn for the given user utterance.

    Parses the text with the NLU (biased toward the slot the FSM currently
    expects), steps the dialogue FSM, appends both the user message and the
    bot reply to the session history, and synthesizes the Hausa reply as
    audio. TTS failures are non-fatal and degrade to text-only.

    Returns (updated_session_dict, bot_audio).
    """
    dialogue = DialogueState.from_dict(session) if session else None
    if dialogue is None:
        # No prior session: start fresh in the default "bank" vertical.
        dialogue = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical="bank")
    slot = get_expected_slot(dialogue.vertical, dialogue.current_state)
    intent, entities, _ = nlu_parse(user_text, slot)
    dialogue = transition(dialogue, intent, entities)
    reply = get_prompt(dialogue.vertical, dialogue.current_state)
    dialogue.history.append({"role": "user", "text": user_text, "is_voice": is_voice})
    dialogue.history.append({"role": "bot", "text_ha": reply["ha"], "text_en": reply["en"]})
    try:
        bot_audio = synthesize_hausa(reply["ha"])
    except Exception as e:
        # Best-effort TTS: log and continue with a text-only reply.
        print(f"TTS failed: {e}")
        bot_audio = None
    return dialogue.to_dict(), bot_audio
# ---------------------------------------------------------------------------
# Gradio event handlers
# ---------------------------------------------------------------------------
def on_vertical_change(vertical: str):
    """Start a brand-new session for ``vertical``, seeded with its greeting.

    Returns (session_dict, chat_html, cleared_bot_audio) for the Gradio
    outputs [session_state, whatsapp_html, bot_audio].
    """
    fresh = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical=vertical)
    greeting = get_prompt(vertical, "greeting")
    fresh.history.append({"role": "bot", "text_ha": greeting["ha"], "text_en": greeting["en"]})
    snapshot = fresh.to_dict()
    return snapshot, render_whatsapp(snapshot), None
def on_text_submit(text: str, session: dict):
    """Handle Send / Enter on the textbox.

    Blank input is a no-op re-render; otherwise one dialogue turn runs. The
    trailing "" clears the textbox component.
    """
    if not (text and text.strip()):
        return session, render_whatsapp(session), None, ""
    updated, reply_audio = run_turn(text, session, is_voice=False)
    return updated, render_whatsapp(updated), reply_audio, ""
def on_audio_submit(audio_data, session: dict):
    """Handle a finished mic recording or upload: ASR, then a dialogue turn.

    Missing audio, ASR errors, and empty transcripts all fall back to
    re-rendering the unchanged session with no bot audio.
    """
    if audio_data is None:
        return session, render_whatsapp(session), None
    try:
        transcript = transcribe_hausa(audio_data)
    except Exception as e:
        # ASR failure should not kill the session — log and re-render as-is.
        print(f"ASR failed: {e}")
        return session, render_whatsapp(session), None
    if not transcript:
        return session, render_whatsapp(session), None
    updated, reply_audio = run_turn(transcript, session, is_voice=True)
    return updated, render_whatsapp(updated), reply_audio
def on_reset(session: dict):
    """Restart the conversation, keeping the currently-selected vertical."""
    current = (session or {}).get("vertical", "bank")
    return on_vertical_change(current)
# ---------------------------------------------------------------------------
# Gradio UI (chat-only, minimal components)
# ---------------------------------------------------------------------------
CUSTOM_CSS = """
.gradio-container { max-width: 720px !important; margin: 0 auto !important; }
#whatsapp-container { padding: 20px 0; }
"""
# Top-level Gradio UI: a single chat column with a vertical picker, the
# rendered WhatsApp-style transcript, text + microphone input, and an
# autoplaying audio player for the bot's TTS reply. All conversation state
# lives in `session_state` as a plain dict (DialogueState.to_dict()).
with gr.Blocks(css=CUSTOM_CSS, title="PlotWeaver Voice Agent") as demo:
    # NOTE(review): this header renders as plain text — the original HTML
    # tags appear to have been stripped; confirm against version control.
    gr.HTML("""
PlotWeaver Voice Agent
Hausa-first conversational AI — pick a vertical, type or speak in Hausa
""")
    # Serialized dialogue session shared across all event handlers.
    session_state = gr.State({})
    # Business vertical selector; label text shown to the user, value is the
    # internal vertical key used by the dialogue FSM.
    vertical_radio = gr.Radio(
        choices=[("PlotWeaver Bank", "bank"),
                 ("PlotWeaver Telecom", "telecom"),
                 ("PlotWeaver Delivery", "ecommerce")],
        value="bank",
        label="Vertical",
        container=False,
    )
    # Chat transcript panel, filled by render_whatsapp().
    whatsapp_html = gr.HTML(elem_id="whatsapp-container")
    with gr.Row():
        text_input = gr.Textbox(
            placeholder="Type in Hausa… e.g. 'duba ma'auni'",
            label="",
            scale=4,
            container=False,
        )
        send_btn = gr.Button("Send", scale=1, variant="primary")
        reset_btn = gr.Button("Reset", scale=1)
    # Hausa speech input: record from the mic or upload a file; delivered to
    # handlers as a (sample_rate, np.ndarray) tuple (type="numpy").
    audio_input = gr.Audio(
        sources=["microphone", "upload"],
        type="numpy",
        label="Record or upload Hausa audio (click Stop when done recording)",
    )
    # Bot TTS output; autoplay so the reply is spoken as soon as it arrives.
    bot_audio = gr.Audio(
        label="Bot response (Hausa TTS)",
        autoplay=True,
        interactive=False,
    )
    # Events
    # Initial page load seeds a fresh "bank" session with its greeting.
    demo.load(
        fn=lambda: on_vertical_change("bank"),
        outputs=[session_state, whatsapp_html, bot_audio],
    )
    # Switching vertical discards the old session and starts a new one.
    vertical_radio.change(
        fn=on_vertical_change,
        inputs=[vertical_radio],
        outputs=[session_state, whatsapp_html, bot_audio],
    )
    # Send button and textbox Enter share the same handler.
    send_btn.click(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=[session_state, whatsapp_html, bot_audio, text_input],
    )
    text_input.submit(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=[session_state, whatsapp_html, bot_audio, text_input],
    )
    # Fires when the user clicks Stop on the mic recording.
    audio_input.stop_recording(
        fn=on_audio_submit,
        inputs=[audio_input, session_state],
        outputs=[session_state, whatsapp_html, bot_audio],
    )
    # Reset restarts the current vertical's conversation from its greeting.
    reset_btn.click(
        fn=on_reset,
        inputs=[session_state],
        outputs=[session_state, whatsapp_html, bot_audio],
    )
# Bind on all interfaces at 7860, the conventional HuggingFace Spaces port.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)