# NOTE(review): the three lines that preceded this file's docstring
# ("Toadoum's picture / Upload 5 files / ddbabb4 verified") were HuggingFace
# Hub page artifacts accidentally pasted into the source, not Python code.
# They are preserved here as a comment so the module parses.
"""
PlotWeaver Voice Agent — HuggingFace Space (Gradio 6 + Python 3.13)
====================================================================
Hausa-first conversational AI for African banks, telecoms, and delivery.
Pipeline (all real, running on CPU):
ASR (openai/whisper-small)
→ NLU (rule-based + Qwen2.5-1.5B-Instruct fallback, see nlu.py)
→ Dialogue FSM (see dialogue.py)
→ TTS (facebook/mms-tts-hau)
First turn: ~30-60s model downloads. Subsequent turns: ~5-10s on CPU.
"""
from __future__ import annotations
import time
import uuid
import html as html_lib
from typing import Optional
import gradio as gr
import numpy as np
import torch
from transformers import (
VitsModel, AutoTokenizer,
WhisperProcessor, WhisperForConditionalGeneration,
)
from dialogue import (
DialogueState, SCENARIOS,
get_prompt, get_expected_slot, transition,
)
from nlu import parse as nlu_parse
# ---------------------------------------------------------------------------
# Model loading (lazy, cached)
# ---------------------------------------------------------------------------
# Lazily-initialized model singletons — populated on first use by
# load_asr() / load_tts() below and reused across turns (CPU inference).
_asr_model = None        # WhisperForConditionalGeneration, set by load_asr()
_asr_processor = None    # WhisperProcessor, set by load_asr()
_tts_model = None        # VitsModel (facebook/mms-tts-hau), set by load_tts()
_tts_tokenizer = None    # AutoTokenizer for the TTS checkpoint, set by load_tts()
def load_asr():
    """Load Whisper-small (model + processor) once and cache it in module globals.

    Returns the cached ``(_asr_model, _asr_processor)`` pair on every
    subsequent call; the first call downloads the checkpoint.
    """
    global _asr_model, _asr_processor
    if _asr_model is not None:
        return _asr_model, _asr_processor
    print("Loading Whisper-small…")
    _asr_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    _asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
    _asr_model.eval()
    print("Whisper-small ready.")
    return _asr_model, _asr_processor
def load_tts():
    """Load the MMS-TTS Hausa VITS model + tokenizer once, caching in globals.

    Returns the cached ``(_tts_model, _tts_tokenizer)`` pair; only the first
    call incurs the download/initialization cost.
    """
    global _tts_model, _tts_tokenizer
    if _tts_model is not None:
        return _tts_model, _tts_tokenizer
    print("Loading MMS-TTS Hausa…")
    _tts_model = VitsModel.from_pretrained("facebook/mms-tts-hau")
    _tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-hau")
    _tts_model.eval()
    print("MMS-TTS Hausa ready.")
    return _tts_model, _tts_tokenizer
def transcribe_hausa(audio_tuple) -> str:
    """Transcribe a Gradio audio tuple ``(sample_rate, np.ndarray)`` to Hausa text.

    Returns "" for a missing or empty recording. The waveform is normalized
    to mono float32 in [-1, 1], capped at 30 s, resampled to 16 kHz, and fed
    to Whisper-small with decoding forced to Hausa transcription.
    """
    if audio_tuple is None:
        return ""
    sample_rate, audio_array = audio_tuple
    if audio_array is None or len(audio_array) == 0:
        return ""
    # Normalize dtype. BUG FIX: the original ran np.iinfo() on *any* non-float32
    # dtype, which raises ValueError for float dtypes (e.g. float64 from an
    # uploaded file) — iinfo is only defined for integer types. Integer PCM is
    # scaled by its max value; other float dtypes are simply cast.
    if np.issubdtype(audio_array.dtype, np.integer):
        audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
    elif audio_array.dtype != np.float32:
        audio_array = audio_array.astype(np.float32)
    # Collapse stereo/multi-channel recordings to mono.
    if audio_array.ndim > 1:
        audio_array = audio_array.mean(axis=1)
    # Cap at 30s (Whisper training chunk size); done at the native rate so the
    # cut corresponds to the same wall-clock duration regardless of resampling.
    max_samples = sample_rate * 30
    if len(audio_array) > max_samples:
        audio_array = audio_array[:max_samples]
    if sample_rate != 16000:
        import scipy.signal
        num_samples = int(len(audio_array) * 16000 / sample_rate)
        audio_array = scipy.signal.resample(audio_array, num_samples).astype(np.float32)
    model, processor = load_asr()
    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
    # Pin the decoder to Hausa + "transcribe" (not translate / auto-detect).
    forced_ids = processor.get_decoder_prompt_ids(language="hausa", task="transcribe")
    with torch.no_grad():
        ids = model.generate(inputs.input_features, forced_decoder_ids=forced_ids, max_new_tokens=128)
    text = processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
    return text
def synthesize_hausa(text: str) -> Optional[tuple]:
    """Synthesize Hausa speech for *text* with MMS-TTS.

    Returns a ``(sampling_rate, float32_waveform)`` tuple ready for
    ``gr.Audio``, or None when *text* is empty or whitespace-only.
    """
    if not text.strip():
        return None
    model, tokenizer = load_tts()
    encoded = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        waveform = model(**encoded).waveform
    samples = waveform.squeeze().cpu().numpy().astype(np.float32)
    return model.config.sampling_rate, samples
# ---------------------------------------------------------------------------
# WhatsApp-style HTML rendering
# ---------------------------------------------------------------------------
def _now() -> str:
return time.strftime("%H:%M")
def _user_bubble(text: str, is_voice: bool) -> str:
    """Render one outgoing (user) chat bubble as HTML.

    Text messages become a plain green bubble. Voice messages get a fake
    waveform (20 bars whose heights follow a sine curve) with the ASR
    transcript quoted underneath.
    """
    escaped = html_lib.escape(text)
    if not is_voice:
        return f'<div class="pw-b user">{escaped}<div class="pw-b-meta">{_now()} ✓✓</div></div>'
    bars = "".join(
        f'<span style="height:{4 + int(8 * abs(np.sin(i * 0.7)))}px;"></span>'
        for i in range(20)
    )
    return f'''<div class="pw-b user">
<div class="pw-voice-row">
<div class="pw-voice-icon">▶</div>
<div class="pw-voice-bars">{bars}</div>
</div>
<div style="font-size:12px;color:#667781;margin-top:3px;">"{escaped}"</div>
<div class="pw-b-meta">{_now()} ✓✓</div>
</div>'''
def _bot_bubble(text_ha: str, text_en: str) -> str:
    """Render one incoming (bot) bubble: Hausa text with its English gloss."""
    hausa = html_lib.escape(text_ha)
    english = html_lib.escape(text_en)
    return f'''<div class="pw-b bot">
<div>{hausa}</div>
<div class="pw-b-trans">{english}</div>
<div class="pw-b-meta">{_now()} ✓✓</div>
</div>'''
def render_whatsapp(session: dict) -> str:
    """Render the whole conversation as a self-contained WhatsApp-style phone mock.

    Reads the serialized session dict: ``vertical`` selects the header
    name/avatar, ``escalate_to_human`` toggles the warning banner, and
    ``history`` supplies the bubbles. Safe to call with an empty/None
    session — it falls back to the bank vertical and a placeholder body.
    The <style> block is embedded in every render so the HTML component is
    fully self-contained inside Gradio.
    """
    vertical = session.get("vertical", "bank") if session else "bank"
    name = SCENARIOS[vertical]["name"]
    # Two-letter avatar initials per vertical (Bank / Telecom / Delivery).
    avatar = {"bank": "PB", "telecom": "PT", "ecommerce": "PD"}[vertical]
    escalated = session.get("escalate_to_human", False) if session else False
    bubbles = []
    for msg in session.get("history", []) if session else []:
        if msg["role"] == "user":
            bubbles.append(_user_bubble(msg["text"], msg.get("is_voice", False)))
        else:
            bubbles.append(_bot_bubble(msg.get("text_ha", ""), msg.get("text_en", "")))
    banner = ('<div class="pw-esc-banner">Session escalated to human agent</div>'
              if escalated else "")
    if not bubbles:
        body = '<div style="text-align:center;color:#667781;font-size:12px;padding:40px 0;">Send a message to begin…</div>'
    else:
        body = "".join(bubbles)
    # NOTE: the CSS below uses doubled braces ({{ }}) because this is an
    # f-string; only {avatar}/{name}/{banner}/{body} are interpolated.
    return f"""
<div class="pw-phone">
<div class="pw-ph-header">
<div class="pw-ph-avatar">{avatar}</div>
<div>
<div class="pw-ph-name">{html_lib.escape(name)}</div>
<div class="pw-ph-status">online • voice agent</div>
</div>
</div>
<div class="pw-ph-messages">
{banner}
{body}
</div>
</div>
<style>
.pw-phone {{ max-width: 480px; margin: 0 auto; background: #ECE5DD; border-radius: 14px; overflow: hidden; border: 1px solid #ccc; display: flex; flex-direction: column; min-height: 540px; font-family: -apple-system, "Segoe UI", Roboto, sans-serif; }}
.pw-ph-header {{ background: #075E54; color: #fff; padding: 10px 14px; display: flex; align-items: center; gap: 10px; }}
.pw-ph-avatar {{ width: 36px; height: 36px; border-radius: 50%; background: #128C7E; display: flex; align-items: center; justify-content: center; font-weight: 500; font-size: 13px; color: #fff; }}
.pw-ph-name {{ font-size: 14px; font-weight: 500; line-height: 1.2; }}
.pw-ph-status {{ font-size: 11px; color: #D4EDE8; }}
.pw-ph-messages {{ flex: 1; padding: 14px 10px; background: #ECE5DD; background-image: radial-gradient(#D8CFC2 1px, transparent 1px); background-size: 18px 18px; max-height: 480px; overflow-y: auto; min-height: 420px; }}
.pw-b {{ max-width: 80%; padding: 7px 10px 5px; border-radius: 8px; margin-bottom: 6px; font-size: 13.5px; line-height: 1.4; color: #1f2d1f; word-wrap: break-word; }}
.pw-b.user {{ background: #DCF8C6; margin-left: auto; border-bottom-right-radius: 2px; }}
.pw-b.bot {{ background: #fff; margin-right: auto; border-bottom-left-radius: 2px; }}
.pw-b-meta {{ font-size: 10px; color: #667781; margin-top: 3px; text-align: right; }}
.pw-b-trans {{ font-size: 11px; color: #667781; font-style: italic; margin-top: 3px; border-top: 1px solid #E5E5E5; padding-top: 3px; }}
.pw-voice-row {{ display: flex; align-items: center; gap: 8px; }}
.pw-voice-icon {{ width: 22px; height: 22px; border-radius: 50%; background: #128C7E; color: #fff; font-size: 10px; display: flex; align-items: center; justify-content: center; }}
.pw-voice-bars {{ flex: 1; height: 14px; display: flex; align-items: center; gap: 2px; }}
.pw-voice-bars span {{ flex: 1; background: #8D9A9F; border-radius: 1px; }}
.pw-esc-banner {{ background: #FAEEDA; color: #854F0B; font-size: 12px; padding: 8px 12px; border-radius: 8px; margin-bottom: 10px; border: 1px solid #EF9F27; text-align: center; }}
</style>
"""
# ---------------------------------------------------------------------------
# Core turn handler
# ---------------------------------------------------------------------------
def run_turn(user_text: str, session: dict, is_voice: bool = False):
    """Execute one dialogue turn: NLU → FSM transition → prompt → TTS.

    Returns ``(updated_session_dict, bot_audio)`` where ``bot_audio`` is a
    ``(rate, waveform)`` tuple or None when synthesis is skipped or fails.
    """
    state = DialogueState.from_dict(session) if session else None
    if state is None:
        # No (or empty) session yet: mint a short random id, default vertical.
        state = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical="bank")
    slot = get_expected_slot(state.vertical, state.current_state)
    intent, entities, _ = nlu_parse(user_text, slot)
    state = transition(state, intent, entities)
    prompt = get_prompt(state.vertical, state.current_state)
    state.history.append({"role": "user", "text": user_text, "is_voice": is_voice})
    state.history.append({"role": "bot", "text_ha": prompt["ha"], "text_en": prompt["en"]})
    # TTS is best-effort: a synthesis failure must never abort the turn.
    try:
        audio = synthesize_hausa(prompt["ha"])
    except Exception as exc:
        print(f"TTS failed: {exc}")
        audio = None
    return state.to_dict(), audio
# ---------------------------------------------------------------------------
# Gradio event handlers
# ---------------------------------------------------------------------------
def on_vertical_change(vertical: str):
    """Start a brand-new session for *vertical* seeded with its greeting.

    Returns ``(session_dict, rendered_html, None)`` — the None clears the
    bot-audio player.
    """
    fresh = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical=vertical)
    greeting = get_prompt(vertical, "greeting")
    fresh.history.append({"role": "bot", "text_ha": greeting["ha"], "text_en": greeting["en"]})
    session = fresh.to_dict()
    return session, render_whatsapp(session), None
def on_text_submit(text: str, session: dict):
    """Handle typed input; returns ``(session, html, bot_audio, cleared_text)``.

    Blank/whitespace-only input is ignored (session re-rendered unchanged).
    The trailing "" clears the textbox component.
    """
    if not (text and text.strip()):
        return session, render_whatsapp(session), None, ""
    updated, speech = run_turn(text, session, is_voice=False)
    return updated, render_whatsapp(updated), speech, ""
def on_audio_submit(audio_data, session: dict):
    """Handle a recorded/uploaded clip: transcribe, then run a normal turn.

    On missing audio, ASR failure, or an empty transcript the session is
    returned unchanged with no bot audio.
    """
    if audio_data is None:
        return session, render_whatsapp(session), None
    transcript = ""
    try:
        transcript = transcribe_hausa(audio_data)
    except Exception as exc:
        print(f"ASR failed: {exc}")
    if not transcript:
        # Nothing usable was heard; leave the conversation untouched.
        return session, render_whatsapp(session), None
    updated, speech = run_turn(transcript, session, is_voice=True)
    return updated, render_whatsapp(updated), speech
def on_reset(session: dict):
    """Restart the conversation while keeping the currently selected vertical."""
    current = session.get("vertical", "bank") if session else "bank"
    return on_vertical_change(current)
# ---------------------------------------------------------------------------
# Gradio UI (chat-only, minimal components)
# ---------------------------------------------------------------------------
# Page-level CSS: narrows the Gradio container and pads the phone mock area.
CUSTOM_CSS = """
.gradio-container { max-width: 720px !important; margin: 0 auto !important; }
#whatsapp-container { padding: 20px 0; }
"""
with gr.Blocks(css=CUSTOM_CSS, title="PlotWeaver Voice Agent") as demo:
    # Static page header.
    gr.HTML("""
<div style="text-align:center; padding: 0 0 12px;">
<h1 style="margin:0 0 4px; font-size: 22px; font-weight: 500;">PlotWeaver Voice Agent</h1>
<p style="margin:0; color: #5f5e5a; font-size: 14px;">Hausa-first conversational AI — pick a vertical, type or speak in Hausa</p>
</div>
""")
    # Per-browser session: the serialized DialogueState dict is the single
    # source of truth threaded through every handler.
    session_state = gr.State({})
    vertical_radio = gr.Radio(
        choices=[("PlotWeaver Bank", "bank"),
                 ("PlotWeaver Telecom", "telecom"),
                 ("PlotWeaver Delivery", "ecommerce")],
        value="bank",
        label="Vertical",
        container=False,
    )
    # The rendered WhatsApp-style conversation (see render_whatsapp).
    whatsapp_html = gr.HTML(elem_id="whatsapp-container")
    with gr.Row():
        text_input = gr.Textbox(
            placeholder="Type in Hausa… e.g. 'duba ma'auni'",
            label="",
            scale=4,
            container=False,
        )
        send_btn = gr.Button("Send", scale=1, variant="primary")
        reset_btn = gr.Button("Reset", scale=1)
    # Voice input: numpy tuples feed transcribe_hausa via on_audio_submit.
    audio_input = gr.Audio(
        sources=["microphone", "upload"],
        type="numpy",
        label="Record or upload Hausa audio (click Stop when done recording)",
    )
    # Bot reply playback; autoplay so the TTS answer is heard immediately.
    bot_audio = gr.Audio(
        label="Bot response (Hausa TTS)",
        autoplay=True,
        interactive=False,
    )
    # Events
    # Initial page load: start a bank-vertical session with its greeting.
    demo.load(
        fn=lambda: on_vertical_change("bank"),
        outputs=[session_state, whatsapp_html, bot_audio],
    )
    # Switching vertical discards the current session and starts fresh.
    vertical_radio.change(
        fn=on_vertical_change,
        inputs=[vertical_radio],
        outputs=[session_state, whatsapp_html, bot_audio],
    )
    # Send button and Enter-in-textbox share the same handler; the fourth
    # output ("" from on_text_submit) clears the textbox.
    send_btn.click(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=[session_state, whatsapp_html, bot_audio, text_input],
    )
    text_input.submit(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=[session_state, whatsapp_html, bot_audio, text_input],
    )
    # Fires when the microphone recording is stopped (not on file upload;
    # uploads are handled when this event fires for the component value).
    audio_input.stop_recording(
        fn=on_audio_submit,
        inputs=[audio_input, session_state],
        outputs=[session_state, whatsapp_html, bot_audio],
    )
    reset_btn.click(
        fn=on_reset,
        inputs=[session_state],
        outputs=[session_state, whatsapp_html, bot_audio],
    )
if __name__ == "__main__":
    # 0.0.0.0:7860 is the standard bind for a HuggingFace Space container.
    demo.launch(server_name="0.0.0.0", server_port=7860)