Toadoum committed
Commit ddbabb4 · verified · 1 Parent(s): facd497

Upload 5 files

Files changed (5):
  1. README.md +34 -7
  2. app.py +344 -0
  3. dialogue.py +237 -0
  4. nlu.py +310 -0
  5. requirements.txt +7 -0
README.md CHANGED
@@ -1,12 +1,39 @@
 ---
-title: Voice AI Agent Clean
-emoji: 👁
-colorFrom: purple
-colorTo: purple
+title: PlotWeaver Voice Agent
+emoji: 🗣️
+colorFrom: green
+colorTo: blue
 sdk: gradio
-sdk_version: 6.13.0
+sdk_version: 6.12.0
 app_file: app.py
-pinned: false
+pinned: true
+short_description: Hausa voice AI for African banks, telecoms, and delivery
+license: apache-2.0
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# PlotWeaver Voice Agent
+
+Hausa-first conversational AI demo. Product 7 of the PlotWeaver suite: voice bots for WhatsApp, phone, and customer support across African banks, telecoms, and delivery services.
+
+## What it does
+
+- **ASR**: Whisper-small transcribes your Hausa audio
+- **NLU**: Hybrid — rule-based keyword matcher (fast path) + Qwen2.5-1.5B-Instruct (zero-shot fallback for paraphrases)
+- **Dialogue manager**: deterministic FSM across 3 verticals (Bank, Telecom, Delivery)
+- **TTS**: `facebook/mms-tts-hau` synthesizes the bot's Hausa response
+
+## How to use
+
+1. Pick a vertical (Bank / Telecom / Delivery)
+2. Type a Hausa phrase, record with the microphone, or upload audio
+3. The bot's audio response autoplays
+
+**Sample bank flow**: Type `duba ma'auni` → then `1234` → the bot returns a balance.
+
+**Escalation**: Say "mutum" or "wakili" at any time to flag human handoff.
+
+## Notes
+
+The first turn cold-starts ASR + TTS (~640 MB download). Qwen2.5-1.5B (~3 GB) loads only when the rule-based NLU misses. Subsequent turns: 5-10 s on CPU.
+
+This is a POC. Production plan: fine-tuned Hausa Whisper, a fine-tuned AfroXLMR NLU (replacing the LLM tier), live WhatsApp Business Cloud integration, and Twilio Voice.
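
The four-stage turn described in the README (ASR → NLU → FSM → TTS) can be sketched end-to-end with stubbed stages. All stub functions below are illustrative placeholders, not the Space's code; only the state and intent names come from the files in this commit:

```python
# Minimal sketch of one turn of the pipeline with stand-in stages.
def asr_stub(audio):
    # Real pipeline: openai/whisper-small transcription.
    return "duba ma'auni"

def nlu_stub(text):
    # Real pipeline: keyword rules with a Qwen2.5 fallback (nlu.py).
    return "check_balance" if "ma'auni" in text else "unknown"

def fsm_stub(state, intent):
    # Real pipeline: deterministic transition table (dialogue.py).
    table = {("greeting", "check_balance"): "ask_account_number"}
    return table.get((state, intent), state)

def run_turn_stub(state, audio):
    text = asr_stub(audio)
    intent = nlu_stub(text)
    return fsm_stub(state, intent)

print(run_turn_stub("greeting", b"..."))  # → ask_account_number
```

The real `run_turn` in app.py follows the same shape, with the extra steps of appending to history and synthesizing the reply audio.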
app.py ADDED
@@ -0,0 +1,344 @@
+"""
+PlotWeaver Voice Agent — HuggingFace Space (Gradio 6 + Python 3.13)
+====================================================================
+Hausa-first conversational AI for African banks, telecoms, and delivery.
+
+Pipeline (all real, running on CPU):
+    ASR (openai/whisper-small)
+    → NLU (rule-based + Qwen2.5-1.5B-Instruct fallback, see nlu.py)
+    → Dialogue FSM (see dialogue.py)
+    → TTS (facebook/mms-tts-hau)
+
+First turn: ~30-60s model downloads. Subsequent turns: ~5-10s on CPU.
+"""
+from __future__ import annotations
+import time
+import uuid
+import html as html_lib
+from typing import Optional
+
+import gradio as gr
+import numpy as np
+import torch
+from transformers import (
+    VitsModel, AutoTokenizer,
+    WhisperProcessor, WhisperForConditionalGeneration,
+)
+
+from dialogue import (
+    DialogueState, SCENARIOS,
+    get_prompt, get_expected_slot, transition,
+)
+from nlu import parse as nlu_parse
+
+
+# ---------------------------------------------------------------------------
+# Model loading (lazy, cached)
+# ---------------------------------------------------------------------------
+_asr_model = None
+_asr_processor = None
+_tts_model = None
+_tts_tokenizer = None
+
+
+def load_asr():
+    global _asr_model, _asr_processor
+    if _asr_model is None:
+        print("Loading Whisper-small…")
+        _asr_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
+        _asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
+        _asr_model.eval()
+        print("Whisper-small ready.")
+    return _asr_model, _asr_processor
+
+
+def load_tts():
+    global _tts_model, _tts_tokenizer
+    if _tts_model is None:
+        print("Loading MMS-TTS Hausa…")
+        _tts_model = VitsModel.from_pretrained("facebook/mms-tts-hau")
+        _tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-hau")
+        _tts_model.eval()
+        print("MMS-TTS Hausa ready.")
+    return _tts_model, _tts_tokenizer
+
+
+def transcribe_hausa(audio_tuple) -> str:
+    if audio_tuple is None:
+        return ""
+    sample_rate, audio_array = audio_tuple
+    if audio_array is None or len(audio_array) == 0:
+        return ""
+    if audio_array.dtype != np.float32:
+        audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
+    if audio_array.ndim > 1:
+        audio_array = audio_array.mean(axis=1)
+    # Cap at 30s (Whisper training chunk size)
+    max_samples = sample_rate * 30
+    if len(audio_array) > max_samples:
+        audio_array = audio_array[:max_samples]
+    if sample_rate != 16000:
+        import scipy.signal
+        num_samples = int(len(audio_array) * 16000 / sample_rate)
+        audio_array = scipy.signal.resample(audio_array, num_samples).astype(np.float32)
+
+    model, processor = load_asr()
+    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
+    forced_ids = processor.get_decoder_prompt_ids(language="hausa", task="transcribe")
+    with torch.no_grad():
+        ids = model.generate(inputs.input_features, forced_decoder_ids=forced_ids, max_new_tokens=128)
+    text = processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
+    return text
+
+
+def synthesize_hausa(text: str) -> Optional[tuple]:
+    if not text.strip():
+        return None
+    model, tokenizer = load_tts()
+    inputs = tokenizer(text, return_tensors="pt")
+    with torch.no_grad():
+        out = model(**inputs).waveform
+    audio = out.squeeze().cpu().numpy().astype(np.float32)
+    return (model.config.sampling_rate, audio)
+
+
+# ---------------------------------------------------------------------------
+# WhatsApp-style HTML rendering
+# ---------------------------------------------------------------------------
+def _now() -> str:
+    return time.strftime("%H:%M")
+
+
+def _user_bubble(text: str, is_voice: bool) -> str:
+    t = html_lib.escape(text)
+    if is_voice:
+        bars = "".join(
+            f'<span style="height:{4 + int(8 * abs(np.sin(i * 0.7)))}px;"></span>'
+            for i in range(20)
+        )
+        return f'''<div class="pw-b user">
+            <div class="pw-voice-row">
+                <div class="pw-voice-icon">▶</div>
+                <div class="pw-voice-bars">{bars}</div>
+            </div>
+            <div style="font-size:12px;color:#667781;margin-top:3px;">"{t}"</div>
+            <div class="pw-b-meta">{_now()} ✓✓</div>
+        </div>'''
+    return f'<div class="pw-b user">{t}<div class="pw-b-meta">{_now()} ✓✓</div></div>'
+
+
+def _bot_bubble(text_ha: str, text_en: str) -> str:
+    ha = html_lib.escape(text_ha)
+    en = html_lib.escape(text_en)
+    return f'''<div class="pw-b bot">
+        <div>{ha}</div>
+        <div class="pw-b-trans">{en}</div>
+        <div class="pw-b-meta">{_now()} ✓✓</div>
+    </div>'''
+
+
+def render_whatsapp(session: dict) -> str:
+    vertical = session.get("vertical", "bank") if session else "bank"
+    name = SCENARIOS[vertical]["name"]
+    avatar = {"bank": "PB", "telecom": "PT", "ecommerce": "PD"}[vertical]
+    escalated = session.get("escalate_to_human", False) if session else False
+
+    bubbles = []
+    for msg in session.get("history", []) if session else []:
+        if msg["role"] == "user":
+            bubbles.append(_user_bubble(msg["text"], msg.get("is_voice", False)))
+        else:
+            bubbles.append(_bot_bubble(msg.get("text_ha", ""), msg.get("text_en", "")))
+
+    banner = ('<div class="pw-esc-banner">Session escalated to human agent</div>'
+              if escalated else "")
+
+    if not bubbles:
+        body = '<div style="text-align:center;color:#667781;font-size:12px;padding:40px 0;">Send a message to begin…</div>'
+    else:
+        body = "".join(bubbles)
+
+    return f"""
+    <div class="pw-phone">
+        <div class="pw-ph-header">
+            <div class="pw-ph-avatar">{avatar}</div>
+            <div>
+                <div class="pw-ph-name">{html_lib.escape(name)}</div>
+                <div class="pw-ph-status">online • voice agent</div>
+            </div>
+        </div>
+        <div class="pw-ph-messages">
+            {banner}
+            {body}
+        </div>
+    </div>
+    <style>
+    .pw-phone {{ max-width: 480px; margin: 0 auto; background: #ECE5DD; border-radius: 14px; overflow: hidden; border: 1px solid #ccc; display: flex; flex-direction: column; min-height: 540px; font-family: -apple-system, "Segoe UI", Roboto, sans-serif; }}
+    .pw-ph-header {{ background: #075E54; color: #fff; padding: 10px 14px; display: flex; align-items: center; gap: 10px; }}
+    .pw-ph-avatar {{ width: 36px; height: 36px; border-radius: 50%; background: #128C7E; display: flex; align-items: center; justify-content: center; font-weight: 500; font-size: 13px; color: #fff; }}
+    .pw-ph-name {{ font-size: 14px; font-weight: 500; line-height: 1.2; }}
+    .pw-ph-status {{ font-size: 11px; color: #D4EDE8; }}
+    .pw-ph-messages {{ flex: 1; padding: 14px 10px; background: #ECE5DD; background-image: radial-gradient(#D8CFC2 1px, transparent 1px); background-size: 18px 18px; max-height: 480px; overflow-y: auto; min-height: 420px; }}
+    .pw-b {{ max-width: 80%; padding: 7px 10px 5px; border-radius: 8px; margin-bottom: 6px; font-size: 13.5px; line-height: 1.4; color: #1f2d1f; word-wrap: break-word; }}
+    .pw-b.user {{ background: #DCF8C6; margin-left: auto; border-bottom-right-radius: 2px; }}
+    .pw-b.bot {{ background: #fff; margin-right: auto; border-bottom-left-radius: 2px; }}
+    .pw-b-meta {{ font-size: 10px; color: #667781; margin-top: 3px; text-align: right; }}
+    .pw-b-trans {{ font-size: 11px; color: #667781; font-style: italic; margin-top: 3px; border-top: 1px solid #E5E5E5; padding-top: 3px; }}
+    .pw-voice-row {{ display: flex; align-items: center; gap: 8px; }}
+    .pw-voice-icon {{ width: 22px; height: 22px; border-radius: 50%; background: #128C7E; color: #fff; font-size: 10px; display: flex; align-items: center; justify-content: center; }}
+    .pw-voice-bars {{ flex: 1; height: 14px; display: flex; align-items: center; gap: 2px; }}
+    .pw-voice-bars span {{ flex: 1; background: #8D9A9F; border-radius: 1px; }}
+    .pw-esc-banner {{ background: #FAEEDA; color: #854F0B; font-size: 12px; padding: 8px 12px; border-radius: 8px; margin-bottom: 10px; border: 1px solid #EF9F27; text-align: center; }}
+    </style>
+    """
+
+
+# ---------------------------------------------------------------------------
+# Core turn handler
+# ---------------------------------------------------------------------------
+def run_turn(user_text: str, session: dict, is_voice: bool = False):
+    """Returns (updated_session_dict, bot_audio)."""
+    state = DialogueState.from_dict(session) if session else None
+    if state is None:
+        state = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical="bank")
+
+    expected = get_expected_slot(state.vertical, state.current_state)
+    intent, entities, _ = nlu_parse(user_text, expected)
+    state = transition(state, intent, entities)
+
+    prompt = get_prompt(state.vertical, state.current_state)
+
+    state.history.append({"role": "user", "text": user_text, "is_voice": is_voice})
+    state.history.append({"role": "bot", "text_ha": prompt["ha"], "text_en": prompt["en"]})
+
+    try:
+        audio = synthesize_hausa(prompt["ha"])
+    except Exception as e:
+        print(f"TTS failed: {e}")
+        audio = None
+
+    return state.to_dict(), audio
+
+
+# ---------------------------------------------------------------------------
+# Gradio event handlers
+# ---------------------------------------------------------------------------
+def on_vertical_change(vertical: str):
+    state = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical=vertical)
+    greet = get_prompt(vertical, "greeting")
+    state.history.append({"role": "bot", "text_ha": greet["ha"], "text_en": greet["en"]})
+    session = state.to_dict()
+    return session, render_whatsapp(session), None
+
+
+def on_text_submit(text: str, session: dict):
+    if not text or not text.strip():
+        return session, render_whatsapp(session), None, ""
+    new_session, audio = run_turn(text, session, is_voice=False)
+    return new_session, render_whatsapp(new_session), audio, ""
+
+
+def on_audio_submit(audio_data, session: dict):
+    if audio_data is None:
+        return session, render_whatsapp(session), None
+    try:
+        text = transcribe_hausa(audio_data)
+    except Exception as e:
+        print(f"ASR failed: {e}")
+        return session, render_whatsapp(session), None
+    if not text:
+        return session, render_whatsapp(session), None
+    new_session, audio = run_turn(text, session, is_voice=True)
+    return new_session, render_whatsapp(new_session), audio
+
+
+def on_reset(session: dict):
+    vertical = session.get("vertical", "bank") if session else "bank"
+    return on_vertical_change(vertical)
+
+
+# ---------------------------------------------------------------------------
+# Gradio UI (chat-only, minimal components)
+# ---------------------------------------------------------------------------
+CUSTOM_CSS = """
+.gradio-container { max-width: 720px !important; margin: 0 auto !important; }
+#whatsapp-container { padding: 20px 0; }
+"""
+
+with gr.Blocks(css=CUSTOM_CSS, title="PlotWeaver Voice Agent") as demo:
+    gr.HTML("""
+    <div style="text-align:center; padding: 0 0 12px;">
+        <h1 style="margin:0 0 4px; font-size: 22px; font-weight: 500;">PlotWeaver Voice Agent</h1>
+        <p style="margin:0; color: #5f5e5a; font-size: 14px;">Hausa-first conversational AI — pick a vertical, type or speak in Hausa</p>
+    </div>
+    """)
+
+    session_state = gr.State({})
+
+    vertical_radio = gr.Radio(
+        choices=[("PlotWeaver Bank", "bank"),
+                 ("PlotWeaver Telecom", "telecom"),
+                 ("PlotWeaver Delivery", "ecommerce")],
+        value="bank",
+        label="Vertical",
+        container=False,
+    )
+
+    whatsapp_html = gr.HTML(elem_id="whatsapp-container")
+
+    with gr.Row():
+        text_input = gr.Textbox(
+            placeholder="Type in Hausa… e.g. 'duba ma'auni'",
+            label="",
+            scale=4,
+            container=False,
+        )
+        send_btn = gr.Button("Send", scale=1, variant="primary")
+        reset_btn = gr.Button("Reset", scale=1)
+
+    audio_input = gr.Audio(
+        sources=["microphone", "upload"],
+        type="numpy",
+        label="Record or upload Hausa audio (click Stop when done recording)",
+    )
+
+    bot_audio = gr.Audio(
+        label="Bot response (Hausa TTS)",
+        autoplay=True,
+        interactive=False,
+    )
+
+    # Events
+    demo.load(
+        fn=lambda: on_vertical_change("bank"),
+        outputs=[session_state, whatsapp_html, bot_audio],
+    )
+    vertical_radio.change(
+        fn=on_vertical_change,
+        inputs=[vertical_radio],
+        outputs=[session_state, whatsapp_html, bot_audio],
+    )
+    send_btn.click(
+        fn=on_text_submit,
+        inputs=[text_input, session_state],
+        outputs=[session_state, whatsapp_html, bot_audio, text_input],
+    )
+    text_input.submit(
+        fn=on_text_submit,
+        inputs=[text_input, session_state],
+        outputs=[session_state, whatsapp_html, bot_audio, text_input],
+    )
+    audio_input.stop_recording(
+        fn=on_audio_submit,
+        inputs=[audio_input, session_state],
+        outputs=[session_state, whatsapp_html, bot_audio],
+    )
+    reset_btn.click(
+        fn=on_reset,
+        inputs=[session_state],
+        outputs=[session_state, whatsapp_html, bot_audio],
+    )
+
+
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860)
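
`transcribe_hausa` above scales int16 microphone samples into float32 by dividing by `np.iinfo(dtype).max`. A pure-Python stand-in (an illustrative assumption, not the Space's code) makes the scaling explicit:

```python
def pcm16_to_float32(samples):
    # Mirror of app.py's audio_array.astype(np.float32) / np.iinfo(int16).max:
    # divide each signed 16-bit sample by 32767 (the int16 maximum).
    return [s / 32767.0 for s in samples]

scaled = pcm16_to_float32([0, 16384, 32767, -32768])
# Note: the minimum sample -32768 maps to just below -1.0, a side effect of
# dividing by the positive maximum (32767) rather than 32768.
```

This is why Whisper-bound audio from Gradio's numpy source ends up in roughly [-1.0, 1.0], the range the processor expects.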
dialogue.py ADDED
@@ -0,0 +1,237 @@
+"""
+PlotWeaver Voice Agent — Dialogue Manager
+==========================================
+FSM for multi-turn Hausa conversations across 3 verticals.
+State lives in Gradio session state (dict) — no Redis needed in the Space.
+"""
+from __future__ import annotations
+from dataclasses import dataclass, field, asdict
+from enum import Enum
+from typing import Optional
+
+
+class Vertical(str, Enum):
+    BANK = "bank"
+    TELECOM = "telecom"
+    ECOMMERCE = "ecommerce"
+
+
+@dataclass
+class DialogueState:
+    session_id: str
+    vertical: str
+    current_state: str = "greeting"
+    slots: dict = field(default_factory=dict)
+    turn_count: int = 0
+    escalate_to_human: bool = False
+    history: list = field(default_factory=list)
+
+    def to_dict(self):
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, d):
+        return cls(**d) if d else None
+
+
+SCENARIOS = {
+    "bank": {
+        "name": "PlotWeaver Bank",
+        "states": {
+            "greeting": {
+                "ha": "Sannu! Wannan shine mataimakin banki na PlotWeaver. Yaya zan taimake ka yau? Za ka iya ce 'duba ma'auni', 'toshe kati', ko 'canjin kuɗi'.",
+                "en": "Hello! This is the PlotWeaver banking assistant. How can I help you today? You can say 'check balance', 'block card', or 'transfer money'.",
+                "expects": "intent",
+                "transitions": {"check_balance": "ask_account_number", "block_card": "confirm_block_card", "transfer_money": "ask_recipient"},
+            },
+            "ask_account_number": {
+                "ha": "Don Allah ka faɗi lambobin ƙarshe huɗu na asusunka.",
+                "en": "Please say the last four digits of your account number.",
+                "expects": "digits",
+                "transitions": {"provide_digits": "return_balance"},
+            },
+            "return_balance": {
+                "ha": "Ma'aunin asusunka shine Naira dubu ɗari biyu da arba'in da biyar. Akwai wani abu?",
+                "en": "Your account balance is two hundred forty-five thousand Naira. Anything else?",
+                "expects": "yesno",
+                "transitions": {"yes": "greeting", "no": "exit"},
+            },
+            "confirm_block_card": {
+                "ha": "Don tabbatar, kana son toshe katinka? Ka ce 'i' ko 'a'a'.",
+                "en": "To confirm, you want to block your card? Say 'yes' or 'no'.",
+                "expects": "yesno",
+                "transitions": {"yes": "card_blocked", "no": "greeting"},
+            },
+            "card_blocked": {
+                "ha": "An toshe katinka. Sabon kati zai iso a cikin kwanaki uku zuwa biyar. Ana juya ka ga wakili don tabbatar.",
+                "en": "Your card is blocked. A new card will arrive in 3-5 days. Transferring you to an agent for confirmation.",
+                "expects": None, "terminal": True, "escalate": True,
+            },
+            "ask_recipient": {
+                "ha": "Zuwa wa kake son turawa? Ka faɗi sunan mai karɓa.",
+                "en": "Who do you want to transfer to? Say the recipient's name.",
+                "expects": "name",
+                "transitions": {"provide_name": "ask_amount"},
+            },
+            "ask_amount": {
+                "ha": "Nawa kake son turawa, a Naira?",
+                "en": "How much do you want to transfer, in Naira?",
+                "expects": "amount",
+                "transitions": {"provide_amount": "confirm_transfer"},
+            },
+            "confirm_transfer": {
+                "ha": "Zan tura kuɗin yanzu. Ka ce 'i' don ci gaba.",
+                "en": "I'll send the money now. Say 'yes' to continue.",
+                "expects": "yesno",
+                "transitions": {"yes": "transfer_done", "no": "greeting"},
+            },
+            "transfer_done": {
+                "ha": "An tura kuɗin. Godiya da zabar PlotWeaver Bank.",
+                "en": "Money sent. Thank you for choosing PlotWeaver Bank.",
+                "expects": None, "terminal": True,
+            },
+        },
+    },
+    "telecom": {
+        "name": "PlotWeaver Telecom",
+        "states": {
+            "greeting": {
+                "ha": "Sannu! Wannan shine PlotWeaver Telecom. Kana son 'saya airtime', 'saya bundle', ko 'yin korafi'?",
+                "en": "Hello! This is PlotWeaver Telecom. Would you like to 'buy airtime', 'buy bundle', or 'file a complaint'?",
+                "expects": "intent",
+                "transitions": {"buy_airtime": "ask_airtime_amount", "buy_bundle": "ask_bundle_type", "complaint": "ask_complaint"},
+            },
+            "ask_airtime_amount": {
+                "ha": "Nawa na airtime kake son saya? Misali, Naira ɗari ko dubu.",
+                "en": "How much airtime? For example 100 or 1000 Naira.",
+                "expects": "amount",
+                "transitions": {"provide_amount": "airtime_done"},
+            },
+            "airtime_done": {
+                "ha": "An kara airtime. Ma'aunin ka sabo shine Naira dubu ɗaya da ɗari biyar.",
+                "en": "Airtime loaded. Your new balance is 1500 Naira.",
+                "expects": None, "terminal": True,
+            },
+            "ask_bundle_type": {
+                "ha": "Wane irin bundle? Muna da 'rana', 'mako', ko 'wata'.",
+                "en": "Which bundle type? 'day', 'week', or 'month'.",
+                "expects": "bundle",
+                "transitions": {"provide_bundle": "bundle_done"},
+            },
+            "bundle_done": {
+                "ha": "An kunna bundle ɗinka. Za ka iya yin amfani da shi yanzu.",
+                "en": "Your bundle is active. You can use it now.",
+                "expects": None, "terminal": True,
+            },
+            "ask_complaint": {
+                "ha": "Me ya faru? Ka bayyana matsalar da kake fuskanta.",
+                "en": "What happened? Please describe the issue.",
+                "expects": "text",
+                "transitions": {"provide_text": "escalate"},
+            },
+            "escalate": {
+                "ha": "Nagode. Zan juya ka ga wakili na mutum yanzu.",
+                "en": "Thank you. I'll transfer you to a human agent now.",
+                "expects": None, "terminal": True, "escalate": True,
+            },
+        },
+    },
+    "ecommerce": {
+        "name": "PlotWeaver Delivery",
+        "states": {
+            "greeting": {
+                "ha": "Sannu! Wannan shine PlotWeaver Delivery. Kana son 'bincika oda', 'sake tsara lokaci', ko 'mayar da kaya'?",
+                "en": "Hello! This is PlotWeaver Delivery. Would you like to 'check order', 'reschedule', or 'return'?",
+                "expects": "intent",
+                "transitions": {"check_order": "ask_order_id", "reschedule": "ask_order_id_reschedule", "return_item": "ask_order_id_return"},
+            },
+            "ask_order_id": {
+                "ha": "Ka faɗi lambar oda naka.",
+                "en": "Say your order number.",
+                "expects": "digits",
+                "transitions": {"provide_digits": "order_status"},
+            },
+            "order_status": {
+                "ha": "Oda ɗinka yana kan hanya. Za a isar gobe da yamma.",
+                "en": "Your order is on the way. It will be delivered tomorrow evening.",
+                "expects": None, "terminal": True,
+            },
+            "ask_order_id_reschedule": {
+                "ha": "Ka faɗi lambar oda da kake son sake tsarawa.",
+                "en": "Say the order number you want to reschedule.",
+                "expects": "digits",
+                "transitions": {"provide_digits": "ask_new_date"},
+            },
+            "ask_new_date": {
+                "ha": "Wace rana kake so? Misali 'jumma'a' ko 'asabar'.",
+                "en": "Which day? For example 'Friday' or 'Saturday'.",
+                "expects": "date",
+                "transitions": {"provide_date": "reschedule_done"},
+            },
+            "reschedule_done": {
+                "ha": "An sake tsara isar. Za ka sami SMS na tabbatarwa.",
+                "en": "Delivery rescheduled. You'll receive a confirmation SMS.",
+                "expects": None, "terminal": True,
+            },
+            "ask_order_id_return": {
+                "ha": "Ka faɗi lambar oda da kake son mayarwa.",
+                "en": "Say the order number you want to return.",
+                "expects": "digits",
+                "transitions": {"provide_digits": "return_reason"},
+            },
+            "return_reason": {
+                "ha": "Me ya sa kake son mayarwa?",
+                "en": "Why do you want to return it?",
+                "expects": "text",
+                "transitions": {"provide_reason": "return_done"},
+            },
+            "return_done": {
+                "ha": "An karɓi buƙatarka. Wakili zai tattara kaya a gobe.",
+                "en": "Your request is received. An agent will collect the item tomorrow.",
+                "expects": None, "terminal": True,
+            },
+        },
+    },
+}
+
+
+def get_prompt(vertical: str, state_name: str) -> dict:
+    if state_name == "escalate_virtual":
+        return {"ha": "Zan juya ka ga wakili na mutum yanzu. Ka jira ɗan lokaci.",
+                "en": "I'll transfer you to a human agent now. Please hold."}
+    if state_name == "exit":
+        return {"ha": "Nagode. Sai watan.", "en": "Thank you. Goodbye."}
+    s = SCENARIOS[vertical]["states"].get(state_name)
+    if not s:
+        return {"ha": "Ban fahimci abin da ka ce ba.", "en": "I didn't understand."}
+    return {"ha": s["ha"], "en": s["en"]}
+
+
+def get_expected_slot(vertical: str, state_name: str) -> Optional[str]:
+    s = SCENARIOS[vertical]["states"].get(state_name)
+    return s.get("expects") if s else None
+
+
+def transition(state: DialogueState, intent: str, entities: dict) -> DialogueState:
+    state.turn_count += 1
+    for k, v in entities.items():
+        state.slots[k] = v
+
+    if intent == "human_agent" or state.turn_count > 12:
+        state.current_state = "escalate_virtual"
+        state.escalate_to_human = True
+        return state
+
+    current = SCENARIOS[state.vertical]["states"].get(state.current_state)
+    if not current:
+        state.current_state = "greeting"
+        return state
+
+    next_state = current.get("transitions", {}).get(intent)
+    if next_state:
+        state.current_state = next_state
+        target = SCENARIOS[state.vertical]["states"].get(next_state, {})
+        if target.get("escalate"):
+            state.escalate_to_human = True
+
+    return state
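
The `transitions` lookup in `transition` above can be traced on the bank flow with a trimmed, self-contained copy of the SCENARIOS table (not an import of dialogue.py; only the state and intent names come from the file):

```python
# Trimmed bank-vertical transition table, same shape as SCENARIOS["bank"]["states"].
SCENARIOS_MINI = {
    "greeting": {"transitions": {"check_balance": "ask_account_number"}},
    "ask_account_number": {"transitions": {"provide_digits": "return_balance"}},
    "return_balance": {"transitions": {"no": "exit"}},
}

def step(state, intent):
    nxt = SCENARIOS_MINI.get(state, {}).get("transitions", {}).get(intent)
    return nxt or state  # unmatched intents keep the current state (re-prompt)

s = "greeting"
for intent in ("check_balance", "provide_digits", "no"):
    s = step(s, intent)
print(s)  # → exit
```

The full `transition` adds two behaviors on top of this lookup: slot filling from entities, and unconditional escalation on `human_agent` or after 12 turns.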
nlu.py ADDED
@@ -0,0 +1,310 @@
+"""
+NLU — Hybrid Hausa intent + entity extraction.
+
+Three-tier architecture:
+    1. Rule-based keyword matcher (fast path, ~80% of demo utterances)
+    2. Qwen2.5-1.5B-Instruct zero-shot JSON extractor (paraphrases, novel phrasings)
+    3. Rule-based fallback (if LLM fails or returns unparseable output)
+
+The LLM is lazy-loaded on first non-matched utterance so the Space boots fast.
+In production this would be replaced with a fine-tuned classifier on
+PlotWeaver's Hausa intent corpus.
+"""
+from __future__ import annotations
+import re
+import json
+import logging
+from typing import Optional
+
+logger = logging.getLogger("plotweaver.nlu")
+
+# ---------------------------------------------------------------------------
+# Layer 1: rule-based fast path (covers common demo phrases)
+# ---------------------------------------------------------------------------
+INTENT_KEYWORDS = {
+    "check_balance": ["duba", "ma'auni", "balance", "kudi", "asusu"],
+    "block_card": ["toshe", "kati", "block"],
+    "transfer_money": ["tura", "canji", "canjin", "aika", "transfer"],
+    "buy_airtime": ["airtime", "caji"],
+    "buy_bundle": ["bundle", "data", "intanet"],
+    "complaint": ["korafi", "matsala", "complain"],
+    "check_order": ["bincika", "order", "oda"],
+    "reschedule": ["sake tsara", "reschedule", "canja lokaci"],
+    "return_item": ["mayar", "mayarwa", "return"],
+    "human_agent": ["mutum", "wakili", "agent", "human"],
+    "yes": ["i ", " i", "eh", "haka ne", "yes", "ok", "okay"],
+    "no": ["a'a", "a'aa", "ba haka", " no", "no "],
+}
+
+WORD_DIGITS = {
+    "sifili": "0", "daya": "1", "ɗaya": "1", "biyu": "2", "uku": "3",
+    "hudu": "4", "huɗu": "4", "biyar": "5", "shida": "6", "bakwai": "7",
+    "takwas": "8", "tara": "9",
+}
+
+WORD_AMOUNTS = {
+    "dubu goma": 10000, "dubu biyar": 5000, "dubu biyu": 2000,
+    "dubu": 1000, "ɗari biyar": 500, "dari biyar": 500,
+    "ɗari": 100, "dari": 100,
+}
+
+
+def _norm(t: str) -> str:
+    return " " + t.lower().strip() + " "
+
+
+def _match_intent_kw(text: str) -> Optional[str]:
+    t = _norm(text)
+    for intent, kws in INTENT_KEYWORDS.items():
+        for kw in kws:
+            if kw in t:
+                return intent
+    return None
+
+
+def _extract_digits(text: str) -> Optional[str]:
+    m = re.findall(r"\d+", text)
+    if m:
+        return "".join(m)
+    tokens = text.lower().split()
+    d = [WORD_DIGITS[tok] for tok in tokens if tok in WORD_DIGITS]
+    return "".join(d) if d else None
+
+
+def _extract_amount(text: str) -> Optional[int]:
+    m = re.search(r"\d+", text)
+    if m:
+        return int(m.group())
+    t = text.lower()
+    for phrase in sorted(WORD_AMOUNTS.keys(), key=len, reverse=True):
+        if phrase in t:
+            return WORD_AMOUNTS[phrase]
+    return None
+
+
+def _rule_based_parse(text: str, expected: Optional[str]) -> tuple[str, dict]:
+    """Layer 1 + 3: deterministic keyword + slot matcher."""
+    entities: dict = {}
+    if not text or not text.strip():
+        return "unknown", entities
+
+    # Universal escape
+    if _match_intent_kw(text) == "human_agent":
+        return "human_agent", entities
+
+    if expected == "digits":
+        d = _extract_digits(text)
+        if d:
+            entities["digits"] = d
+            return "provide_digits", entities
+
+    if expected == "amount":
+        a = _extract_amount(text)
+        if a is not None:
+            entities["amount"] = a
+            return "provide_amount", entities
+
+    if expected == "name":
+        name = text.strip().split()[-1] if text.strip() else ""
+        if name:
+            entities["name"] = name
+            return "provide_name", entities
+
+    if expected == "date":
+        entities["date"] = text.strip()
+        return "provide_date", entities
+
+    if expected == "bundle":
+        t = text.lower()
+        for b in ("rana", "mako", "wata"):
+            if b in t:
+                entities["bundle"] = b
+                return "provide_bundle", entities
+
+    if expected == "text":
+        entities["text"] = text.strip()
+        return "provide_text", entities
+
+    if expected == "yesno":
+        i = _match_intent_kw(text)
+        if i in ("yes", "no"):
+            return i, entities
+
+    i = _match_intent_kw(text)
+    if i:
+        return i, entities
+
+    return "unknown", entities
+
+
+# ---------------------------------------------------------------------------
+# Layer 2: Qwen2.5-1.5B-Instruct zero-shot NLU
+# ---------------------------------------------------------------------------
+_llm_model = None
+_llm_tokenizer = None
+_llm_failed = False  # set to True after any load failure, to prevent retries
+
+
+def _load_llm():
+    """Lazy-load Qwen2.5-1.5B-Instruct. Called only when rule-based misses."""
+    global _llm_model, _llm_tokenizer, _llm_failed
+    if _llm_failed:
+        return None, None
+    if _llm_model is not None:
+        return _llm_model, _llm_tokenizer
+    try:
+        import torch
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        logger.info("Loading Qwen2.5-1.5B-Instruct for NLU…")
+        model_id = "Qwen/Qwen2.5-1.5B-Instruct"
+        _llm_tokenizer = AutoTokenizer.from_pretrained(model_id)
+        _llm_model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.float32,  # CPU — bfloat16 not broadly supported
+            low_cpu_mem_usage=True,
+        )
+        _llm_model.eval()
+        logger.info("Qwen2.5-1.5B-Instruct ready.")
+        return _llm_model, _llm_tokenizer
+    except Exception as e:
+        logger.warning(f"LLM load failed: {e}")
+        _llm_failed = True
+        return None, None
+
+
+# Candidate intents per expected-slot context. Keeps the LLM prompt small
+# and constrains output to valid options only.
+CANDIDATE_INTENTS = {
+    None: ["check_balance", "block_card", "transfer_money",
+           "buy_airtime", "buy_bundle", "complaint",
+           "check_order", "reschedule", "return_item",
+           "human_agent", "unknown"],
+    "intent": ["check_balance", "block_card", "transfer_money",
+               "buy_airtime", "buy_bundle", "complaint",
+               "check_order", "reschedule", "return_item",
+               "human_agent", "unknown"],
+    "yesno": ["yes", "no", "human_agent", "unknown"],
+    "digits": ["provide_digits", "human_agent", "unknown"],
+    "amount": ["provide_amount", "human_agent", "unknown"],
+    "name": ["provide_name", "human_agent", "unknown"],
+    "date": ["provide_date", "human_agent", "unknown"],
+    "bundle": ["provide_bundle", "human_agent", "unknown"],
+ "bundle": ["provide_bundle", "human_agent", "unknown"],
192
+ "text": ["provide_text", "human_agent", "unknown"],
193
+ }
194
+
195
+
196
+ SYSTEM_PROMPT = """You are an intent classifier for a Hausa-language customer service voice agent.
197
+
198
+ Analyze the user's Hausa utterance and return a JSON object with:
199
+ - "intent": one of the candidate intents provided
200
+ - "entities": a dict of extracted values (may be empty)
201
+
202
+ Intent meanings:
203
+ - check_balance: user wants to check their account balance
204
+ - block_card: user wants to block or freeze their bank card
205
+ - transfer_money: user wants to transfer or send money
206
+ - buy_airtime: user wants to buy phone airtime
207
+ - buy_bundle: user wants to buy a data bundle
208
+ - complaint: user wants to file a complaint
209
+ - check_order: user wants to check an order status
210
+ - reschedule: user wants to reschedule a delivery
211
+ - return_item: user wants to return an item
212
+ - human_agent: user wants to speak to a human
213
+ - yes / no: affirmative or negative response
214
+ - provide_digits / provide_amount / provide_name / provide_date / provide_bundle / provide_text: user is providing specific information
215
+ - unknown: cannot determine the intent
216
+
217
+ Return ONLY a valid JSON object, no explanation. Example: {"intent": "check_balance", "entities": {}}"""
218
+
219
+
220
+ def _llm_parse(text: str, expected: Optional[str]) -> Optional[tuple[str, dict]]:
221
+ """Layer 2: zero-shot LLM classification. Returns None on any failure."""
222
+ model, tokenizer = _load_llm()
223
+ if model is None:
224
+ return None
225
+
226
+ candidates = CANDIDATE_INTENTS.get(expected, CANDIDATE_INTENTS[None])
227
+ user_prompt = (
228
+ f'Hausa utterance: "{text}"\n'
229
+ f'Expected slot type: {expected or "any"}\n'
230
+ f'Candidate intents: {", ".join(candidates)}\n\n'
231
+ 'Respond with JSON only.'
232
+ )
233
+ messages = [
234
+ {"role": "system", "content": SYSTEM_PROMPT},
235
+ {"role": "user", "content": user_prompt},
236
+ ]
237
+ try:
238
+ import torch
239
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
240
+ inputs = tokenizer(prompt, return_tensors="pt")
241
+ with torch.no_grad():
242
+ out = model.generate(
243
+ **inputs,
244
+ max_new_tokens=80,
245
+ do_sample=False,
246
+ pad_token_id=tokenizer.eos_token_id,
247
+ )
248
+ generated = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
249
+ logger.info(f"LLM raw output: {generated}")
250
+
251
+ # Extract JSON (model sometimes wraps it in markdown fences or prose)
252
+         m = re.search(r"\{.*\}", generated, re.DOTALL)  # greedy, so the nested "entities" braces stay inside the match
+         if not m:
+             return None
+         parsed = json.loads(m.group())
+         intent = parsed.get("intent", "unknown")
+         entities = parsed.get("entities", {}) or {}
+         if not isinstance(entities, dict):
+             entities = {}
+         # Validate intent is in candidate list
+         if intent not in candidates:
+             logger.info(f"LLM returned out-of-candidate intent: {intent}")
+             return None
+         return intent, entities
+     except Exception as e:
+         logger.warning(f"LLM inference failed: {e}")
+         return None
+
+
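
The JSON-recovery step of `_llm_parse` can be exercised on its own. A small sketch with an invented example of chatty model output; note the greedy `.*`, since a non-greedy `.*?` would stop at the first `}` (the close of the nested `entities` object) and hand `json.loads` an unbalanced fragment:

```python
import json
import re

# Invented example of model output that wraps the JSON in prose.
generated = 'Sure! {"intent": "buy_airtime", "entities": {"amount": 500}} Anything else?'

m = re.search(r"\{.*\}", generated, re.DOTALL)  # greedy: first "{" to last "}"
parsed = json.loads(m.group())
print(parsed)  # {'intent': 'buy_airtime', 'entities': {'amount': 500}}
```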
+ # ---------------------------------------------------------------------------
+ # Public API
+ # ---------------------------------------------------------------------------
+ def parse(text: str, expected: Optional[str] = None,
+           use_llm: bool = True) -> tuple[str, dict, str]:
+     """
+     Hybrid NLU. Returns (intent, entities, source) where source is one of
+     'rule', 'llm', or 'rule_fallback'.
+
+     Flow:
+       1. Try rule-based keyword/slot matcher (fast, deterministic)
+       2. If result is 'unknown' AND use_llm=True: try Qwen2.5 zero-shot
+       3. If LLM fails or returns invalid output: return rule-based 'unknown'
+     """
+     intent, entities = _rule_based_parse(text, expected)
+
+     if intent != "unknown":
+         return intent, entities, "rule"
+
+     if not use_llm:
+         return intent, entities, "rule"
+
+     # Rule-based missed — try LLM
+     llm_result = _llm_parse(text, expected)
+     if llm_result is None:
+         return intent, entities, "rule_fallback"
+
+     llm_intent, llm_entities = llm_result
+
+     # Sanity-check entities for slot-typed expected (LLM might hallucinate
+     # digits; re-run our deterministic extractors for strict-format slots)
+     if expected == "digits":
+         d = _extract_digits(text)
+         if d:
+             llm_entities["digits"] = d
+     elif expected == "amount":
+         a = _extract_amount(text)
+         if a is not None:
+             llm_entities["amount"] = a
+
+     return llm_intent, llm_entities, "llm"
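
The three-layer dispatch in `parse` can be sketched end-to-end with both layers stubbed out. Everything below (the keywords and the stub return values) is invented for illustration; only the control flow mirrors the module:

```python
from typing import Optional

def _rule_based_parse(text: str, expected: Optional[str]) -> tuple[str, dict]:
    # Stub for Layer 1: a single invented keyword match.
    if "ma'auni" in text.lower():
        return "check_balance", {}
    return "unknown", {}

def _llm_parse(text: str, expected: Optional[str]):
    # Stub for Layer 2: pretend the LLM catches one paraphrase the rules miss.
    return ("check_balance", {}) if "kudi" in text.lower() else None

def parse(text, expected=None, use_llm=True):
    intent, entities = _rule_based_parse(text, expected)    # Layer 1
    if intent != "unknown" or not use_llm:
        return intent, entities, "rule"
    llm_result = _llm_parse(text, expected)                 # Layer 2
    if llm_result is None:
        return intent, entities, "rule_fallback"            # Layer 3
    return llm_result[0], llm_result[1], "llm"

print(parse("ina son duba ma'auni"))  # ('check_balance', {}, 'rule')
print(parse("nawa ne kudina?"))       # ('check_balance', {}, 'llm')
print(parse("lafiya lau"))            # ('unknown', {}, 'rule_fallback')
```

The `source` tag in the return value makes it easy to log how often the LLM fallback actually fires in production.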
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers==4.48.0
+ torch==2.5.1
+ accelerate==1.2.1
+ numpy==2.1.3
+ scipy==1.15.0
+ sentencepiece==0.2.0
+ audioop-lts==0.2.2