| """ | |
| Sahel-Agri Voice AI — HuggingFace Spaces (ZeroGPU) | |
| Two-way voice assistant: Bambara / Fula / French / English → voice response | |
| Environment variables (set in Space Settings → Secrets): | |
| HF_TOKEN — HF write-access token | |
| FEEDBACK_REPO_ID — e.g. ous-sow/sahel-agri-feedback (dataset, private) | |
| ADAPTER_REPO_ID — e.g. ous-sow/sahel-agri-adapters (model, private) | |
| WHISPER_MODEL_ID — default: openai/whisper-large-v3-turbo | |
| LLM_MODEL_ID — default: Qwen/Qwen2.5-72B-Instruct | |
| KAGGLE_USERNAME — Kaggle username (for auto-trigger training) | |
| KAGGLE_KEY — Kaggle API key (for auto-trigger training) | |
| KAGGLE_KERNEL_SLUG — default: ous-sow/sahel-voice-master-trainer | |
| AUTO_TRAIN_THRESHOLD — corrections count that triggers auto-training (default: 50) | |
| """ | |
from __future__ import annotations

import io
import json
import os
import sys
import threading
from datetime import datetime, timezone
from pathlib import Path

import gradio as gr
import numpy as np

ROOT = Path(__file__).parent
sys.path.insert(0, str(ROOT))
# ── env ───────────────────────────────────────────────────────────────────────
HF_TOKEN = os.environ.get("HF_TOKEN")
FEEDBACK_REPO_ID = os.environ.get("FEEDBACK_REPO_ID", "ous-sow/sahel-agri-feedback")
ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID", "ous-sow/sahel-agri-adapters")

# whisper-large-v3-turbo: 128 mel bins, matches fine-tuned adapters trained on turbo.
# whisper-small uses 80 mel bins — mismatches turbo adapters with a channel error.
# Override via WHISPER_MODEL_ID env var in Space settings if needed.
WHISPER_MODEL_ID = os.environ.get("WHISPER_MODEL_ID", "openai/whisper-large-v3-turbo")
LLM_MODEL_ID = os.environ.get("LLM_MODEL_ID", "Qwen/Qwen2.5-7B-Instruct")

KAGGLE_USERNAME = os.environ.get("KAGGLE_USERNAME", "")
KAGGLE_KEY = os.environ.get("KAGGLE_KEY", "")
KAGGLE_KERNEL_SLUG = os.environ.get("KAGGLE_KERNEL_SLUG", "ous-sow/sahel-voice-master-trainer")
AUTO_TRAIN_THRESHOLD = int(os.environ.get("AUTO_TRAIN_THRESHOLD", "50"))

# On local CPU (no HF_TOKEN / no spaces package) fall back gracefully.
_ON_SPACES = os.environ.get("SPACE_ID") is not None
SUPPORTED_LANGUAGES = {
    "Bambara — Mali (bam)": "bam",
    "Fula / Pular — Guinea (ful)": "ful",
    "French / Français": "fr",
    "English": "en",
}
# Country and dialect context used in prompts and training metadata
LANG_CONTEXT = {
    "bam": {
        "name": "Bambara",
        "country": "Mali",
        "region": "West Africa (Bamako, Ségou, Mopti dialects)",
        "script": "Latin with special characters (ɛ, ɔ, ŋ, ɲ)",
        "phonetic_note": (
            "Use standard Malian orthography: 'u' not 'ou', 'j' not 'dj', "
            "'c' not 'ch', 'ɲ' not 'gn' or 'ny', 'ɔ' not 'oo', 'ɛ' not 'ee'. "
            "This is Bambara as spoken in Mali, NOT Dioula or other dialects."
        ),
        "do_not_mix": "Fula (Pulaar/Pular), Wolof, Dioula, or any other language",
    },
    "ful": {
        "name": "Pular (Fula of Guinea)",
        "country": "Guinea",
        "region": "West Africa (Labé, Mamou, Kankan dialects)",
        "script": "Latin with special characters (ɓ, ɗ, ŋ, ɲ, ƴ)",
        "phonetic_note": (
            "Use standard Guinean Pular orthography. "
            "This is the Fula variety spoken in Guinea (Pular/Pulaar), "
            "NOT Fulfulde from Niger/Nigeria nor Wolof."
        ),
        "do_not_mix": "Bambara, Soussou, Malinké, or any other language",
    },
    "fr": {
        "name": "French",
        "country": "France / West Africa",
        "region": "",
        "script": "Latin",
        "phonetic_note": "Standard French.",
        "do_not_mix": "other languages unless the user switches",
    },
    "en": {
        "name": "English",
        "country": "",
        "region": "",
        "script": "Latin",
        "phonetic_note": "Standard English.",
        "do_not_mix": "other languages unless the user switches",
    },
}
# ── ZeroGPU decorator (no-op locally) ─────────────────────────────────────────
try:
    import spaces  # type: ignore
    _gpu = spaces.GPU(duration=55)
except ImportError:
    def _gpu(fn):  # local fallback: plain function
        return fn
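# Usage note: `_gpu` is applied as a decorator to the GPU-bound pipeline below:
#   @_gpu
#   def _run_pipeline(...): ...
# On Spaces this requests a ZeroGPU slot for up to 55 s per call; locally it is
# a no-op passthrough.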
# ── Module-level model state (CPU-resident between requests) ──────────────────
_whisper_model = None  # WhisperForConditionalGeneration (base)
_whisper_processor = None
_fine_tuned_models = {}  # lang_code -> WhisperForConditionalGeneration (full checkpoint)
_fine_tuned_processors = {}  # lang_code -> WhisperProcessor matching the fine-tuned checkpoint
_model_lock = threading.Lock()
_model_status = "not loaded"
_load_started_at: float = 0.0  # monotonic time when loading began
_LOAD_TIMEOUT = 180  # seconds before declaring a stuck load

# ── Conversation-mode state ───────────────────────────────────────────────────
_voice_ref_path: str | None = None  # path to 24 kHz WAV converted from user MP3
_voice_ref_text: str = ""  # auto-transcribed text of reference audio
_llm_client = None  # GemmaClient, lazy init
from src.tts.mms_tts import MMSTTSEngine
from src.iot.intent_parser import IntentParser
from src.iot.sensor_bridge import SensorBridge
from src.iot.voice_responder import VoiceResponder
from src.conversation.phrase_matcher import PhraseMatcher
from src.llm.gemma_client import GemmaClient
from src.data.bam_normalize import normalize as bam_normalize
from src.data.adlam import normalize_pular

_tts = MMSTTSEngine()
_intent_parser = IntentParser()
_sensor_bridge = SensorBridge()
_phrase_matcher = PhraseMatcher()

# HF API — only instantiate when token present
_hf_api = None
if HF_TOKEN:
    from huggingface_hub import HfApi
    _hf_api = HfApi(token=HF_TOKEN)
# ── Model loading ─────────────────────────────────────────────────────────────
def _do_load_whisper():
    global _whisper_model, _whisper_processor, _model_status
    try:
        import torch
        try:
            from transformers.models.whisper import WhisperProcessor, WhisperForConditionalGeneration
        except ImportError:
            from transformers.models.whisper.processing_whisper import WhisperProcessor
            from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration
        _model_status = "loading…"
        _whisper_processor = WhisperProcessor.from_pretrained(
            WHISPER_MODEL_ID, token=HF_TOKEN
        )
        try:
            _whisper_model = WhisperForConditionalGeneration.from_pretrained(
                WHISPER_MODEL_ID,
                torch_dtype=torch.float32,
                token=HF_TOKEN,
            )
        except TypeError:
            # Retry without torch_dtype for older transformers versions
            _whisper_model = WhisperForConditionalGeneration.from_pretrained(
                WHISPER_MODEL_ID,
                token=HF_TOKEN,
            )
        _whisper_model.eval()
        _model_status = f"ready ({WHISPER_MODEL_ID})"
    except Exception as e:
        _model_status = f"error: {e}"
def _ensure_whisper_loaded():
    """Load Whisper to CPU in a background thread on first call. Non-blocking."""
    import time
    global _model_status, _load_started_at
    with _model_lock:
        need_start = False
        if _whisper_model is None and "loading" not in _model_status:
            need_start = True
        elif (_whisper_model is None
              and "loading" in _model_status
              and _load_started_at
              and (time.monotonic() - _load_started_at) > _LOAD_TIMEOUT):
            _model_status = "error: load timed out after %ds — retrying" % _LOAD_TIMEOUT
            need_start = True
        if need_start:
            _model_status = "loading…"
            _load_started_at = time.monotonic()
            t = threading.Thread(target=_do_load_whisper, daemon=True)
            t.start()
    return _model_status
def _wait_for_whisper(timeout: int = 120) -> bool:
    """
    Block until Whisper is loaded or timeout (seconds) expires.
    Triggers loading if not already started. Returns True if the model is ready.
    """
    import time
    _ensure_whisper_loaded()
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if _whisper_model is not None:
            return True
        if "error" in _model_status:
            return False
        time.sleep(0.5)
    return False
def get_model_status() -> str:
    s = _ensure_whisper_loaded()
    if "ready" in s:
        return f"🟢 {s}"
    if "loading" in s:
        return f"🟡 {s}"
    if "error" in s:
        return f"🔴 {s}"
    return f"⚪ {s}"
# ── Core GPU pipeline ─────────────────────────────────────────────────────────
@_gpu
def _run_pipeline(audio_path: str, language_code: str):
    """
    Full STT → Intent → Sensor → TTS pipeline.
    Decorated with @spaces.GPU(duration=55) on HF Spaces; plain function locally.
    Returns: (transcript, english_translation, response_text, (sample_rate, wav_np))
    """
    import asyncio
    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ── 1. Whisper STT ────────────────────────────────────────────────────────
    if _whisper_model is None:
        return "⏳ Model still loading…", "", "", None
    import librosa
    audio_np, _ = librosa.load(audio_path, sr=16000, mono=True)
    # Use fine-tuned checkpoint for this language if one has been loaded;
    # otherwise fall back to base Whisper.
    active_model = _fine_tuned_models.get(language_code, _whisper_model)
    # Use the matching processor so mel-bin count is consistent with the model
    # (large-v3-turbo = 128 bins; small = 80 bins). Mismatch causes channel error.
    active_processor = _fine_tuned_processors.get(language_code, _whisper_processor)
    active_model.to(device)
    with _model_lock:
        inputs = active_processor.feature_extractor(
            audio_np, sampling_rate=16000, return_tensors="pt"
        )
        input_features = inputs.input_features.to(device)
        # Bambara and Fula have no Whisper language token — pass None so the model
        # auto-detects or falls back to multilingual decoding.
        if language_code in ("bam", "ful"):
            forced_ids = None
        else:
            forced_ids = active_processor.get_decoder_prompt_ids(
                language=language_code, task="transcribe"
            )
        with torch.no_grad():
            predicted_ids = active_model.generate(
                input_features,
                forced_decoder_ids=forced_ids if forced_ids else None,
                max_new_tokens=256,
            )
        transcript = active_processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0].strip()
    # Free GPU VRAM before TTS
    active_model.to("cpu")
    if device == "cuda":
        torch.cuda.empty_cache()

    # ── 2. Phrase library (general conversation — no sensors needed) ──────────
    phrase_match = _phrase_matcher.match(transcript, language_code)
    if phrase_match:
        response_text = phrase_match["response"]
        english_translation = phrase_match["english"]
    else:
        # ── 3. Intent + sensor data (agricultural queries) ────────────────────
        intent = _intent_parser.parse(transcript, language=language_code)
        try:
            loop = asyncio.new_event_loop()
            sensor_data = loop.run_until_complete(_sensor_bridge.fetch(intent))
            loop.close()
        except Exception:
            from src.iot.sensor_bridge import SensorData
            sensor_data = SensorData(sensor_type="soil", values={
                "moisture_pct": 45.0, "ph": 6.5, "temperature_c": 28.0
            })
        responder = VoiceResponder(language=language_code)
        response_text, english_translation = responder.generate_response(intent, sensor_data)
        # Low-confidence fallback: say "I didn't understand" in the native language
        if intent.action == "unknown" and intent.confidence < 0.15:
            from src.iot.voice_responder import BAMBARA_TEMPLATES, FULA_TEMPLATES
            if language_code == "bam":
                response_text, english_translation = BAMBARA_TEMPLATES["not_understood"]
            elif language_code == "ful":
                response_text, english_translation = FULA_TEMPLATES["not_understood"]

    # ── 4. MMS-TTS (GPU) ──────────────────────────────────────────────────────
    wav_np, sample_rate = _tts.synthesize(response_text, language_code, device=device)
    return transcript, english_translation, response_text, (sample_rate, wav_np)
# ── Conversation-mode helpers ─────────────────────────────────────────────────
# Vocabulary context cache — loaded from Hub, refreshed after each LEARNED save
_vocab_context_cache: str = ""
_vocab_lock = threading.Lock()

def _refresh_vocab_context() -> None:
    """Load vocabulary.jsonl from Hub and rebuild the LLM context string."""
    global _vocab_context_cache
    if not HF_TOKEN or not FEEDBACK_REPO_ID:
        return
    try:
        from huggingface_hub import hf_hub_download
        local = hf_hub_download(
            repo_id=FEEDBACK_REPO_ID, filename="vocabulary.jsonl",
            repo_type="dataset", token=HF_TOKEN,
        )
        entries: list[dict] = []
        with open(local, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    try:
                        entries.append(json.loads(line))
                    except Exception:
                        pass
        # Most recent first, cap at 200 entries to stay within token budget
        entries = entries[-200:][::-1]
        lines = []
        for e in entries:
            word = e.get("word", "").strip()
            tr = e.get("translation", "").strip()
            lang = e.get("language", "")
            if word:
                lines.append(f"{word} = {tr} [{lang}]" if tr else f"{word} [{lang}]")
        with _vocab_lock:
            _vocab_context_cache = "\n".join(lines)
    except Exception:
        pass  # Non-critical — LLM continues without vocab context
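# Illustrative vocabulary.jsonl line (schema inferred from the fields read
# above and written by _save_learned_async below):
#   {"word": "ji", "translation": "water", "language": "bam",
#    "source": "conversation", "timestamp": "2025-01-01T12:00:00+00:00"}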
def _get_vocab_context() -> str:
    with _vocab_lock:
        return _vocab_context_cache

def _save_learned_async(word: str, meaning: str, lang: str) -> None:
    """Persist a word/phrase learned mid-conversation to vocabulary.jsonl on Hub."""
    def _run():
        if not word.strip():
            return
        entry = {"word": word.strip(), "translation": meaning.strip(), "language": lang,
                 "source": "conversation", "timestamp": datetime.now(timezone.utc).isoformat()}
        _upload_jsonl_later("vocabulary.jsonl", [entry])
        _refresh_vocab_context()  # update cache so next turn knows this word
    threading.Thread(target=_run, daemon=True).start()
def _upload_jsonl_later(repo_path: str, entries: list[dict]) -> None:
    """Append entries to a Hub JSONL file — called from background threads."""
    if not HF_TOKEN or not FEEDBACK_REPO_ID or _hf_api is None:
        return
    from huggingface_hub import hf_hub_download
    for attempt in range(2):
        try:
            local = hf_hub_download(
                repo_id=FEEDBACK_REPO_ID, filename=repo_path,
                repo_type="dataset", token=HF_TOKEN,
            )
            with open(local, encoding="utf-8") as f:
                existing = f.read()
        except Exception:
            existing = ""
        updated = existing + "".join(json.dumps(e, ensure_ascii=False) + "\n" for e in entries)
        try:
            _hf_api.upload_file(
                path_or_fileobj=io.BytesIO(updated.encode("utf-8")),
                path_in_repo=repo_path,
                repo_id=FEEDBACK_REPO_ID,
                repo_type="dataset",
            )
            return
        except Exception:
            if attempt == 1:
                pass  # second failure: give up silently; background write is non-critical
import re as _re

_LEARNED_RE = _re.compile(
    r'\[LEARNED:\s*word=["\'](.+?)["\']\s+meaning=["\'](.+?)["\']\s*\]',
    _re.IGNORECASE,
)
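# Illustrative match (hypothetical LLM output): the tag
#   [LEARNED: word="ji" meaning="water"]
# yields groups ("ji", "water"). Single and double quotes both work, and the
# IGNORECASE flag means a lowercase [learned: ...] tag matches too.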
def _parse_and_strip_learned(text: str, lang: str) -> tuple[str, list[tuple[str, str]]]:
    """
    Extract [LEARNED: word="X" meaning="Y"] tags from LLM output.
    Returns (cleaned_text, list_of_(word, meaning) pairs).
    Saves each pair to Hub asynchronously.
    """
    learned = []
    for m in _LEARNED_RE.finditer(text):
        word, meaning = m.group(1).strip(), m.group(2).strip()
        learned.append((word, meaning))
        _save_learned_async(word, meaning, lang)
    cleaned = _LEARNED_RE.sub("", text).strip()
    return cleaned, learned
# System prompt — includes vocabulary context + conversation rules
def _build_system_prompt(language_code: str, vocab: str) -> str:
    """
    Build a language-specific system prompt that makes the LLM stay strictly
    in the correct dialect (Mali Bambara vs Guinea Pular) and never mix them.
    """
    ctx = LANG_CONTEXT.get(language_code, LANG_CONTEXT["fr"])
    lang_name = ctx["name"]
    country = ctx["country"]
    region = ctx["region"]
    phon_note = ctx["phonetic_note"]
    do_not_mix = ctx["do_not_mix"]
    region_line = f" ({region})" if region else ""
    vocab_section = (
        f"WORDS AND PHRASES YOU HAVE LEARNED FOR {lang_name.upper()}:\n{vocab}"
        if vocab
        else f"(No {lang_name} vocabulary recorded yet — the user can teach you words.)"
    )
    return f"""\
You are a voice assistant that speaks ONLY {lang_name} as used in {country}{region_line}.

CRITICAL LANGUAGE RULE:
- You MUST respond exclusively in {lang_name} ({country}).
- NEVER mix in words from {do_not_mix}.
- If the user writes in another language, gently ask them to switch to {lang_name}.
- If you are unsure of a word in {lang_name}, say so honestly — do not substitute \
a word from another language.

ORTHOGRAPHY ({lang_name}):
{phon_note}

{vocab_section}

CONVERSATION RULES:
1. Keep every response to 1–3 short spoken sentences. This is voice, not text.
2. If you do not understand, ask ONE short follow-up question in {lang_name}.
3. If the user teaches you a word ("X means Y"), confirm warmly, then append \
exactly: [LEARNED: word="X" meaning="Y"]
4. Refer back to earlier messages naturally when relevant.
5. Never invent vocabulary. Honest uncertainty is always correct."""
def _get_vocab_context_for(language_code: str) -> str:
    """Return only vocabulary entries for the given language code."""
    with _vocab_lock:
        raw = _vocab_context_cache
    if not raw:
        return ""
    lines = [
        line for line in raw.splitlines()
        if f"[{language_code}]" in line
    ]
    return "\n".join(lines)
def _build_messages(user_text: str, history: list, language_code: str) -> list[dict]:
    """Build the full message list: system (with lang-filtered vocab) + history + new turn."""
    vocab = _get_vocab_context_for(language_code)
    system = _build_system_prompt(language_code, vocab)
    messages: list[dict] = [{"role": "system", "content": system}]
    for u, a in history[-20:]:
        messages.append({"role": "user", "content": u})
        messages.append({"role": "assistant", "content": a})
    messages.append({"role": "user", "content": user_text})
    return messages
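# Illustrative message list for one prior turn (history entries are
# (user, assistant) tuples, as appended in _convo_pipeline):
#   [{"role": "system", "content": "<prompt with vocab>"},
#    {"role": "user", "content": "I ni ce"},
#    {"role": "assistant", "content": "<previous reply>"},
#    {"role": "user", "content": "<new transcript>"}]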
def set_voice_reference(audio_file) -> str:
    """
    Store an uploaded audio file as the TTS voice reference.
    Converts to 24 kHz WAV (F5-TTS requirement) and auto-transcribes.
    Returns a status string for the UI.
    """
    global _voice_ref_path, _voice_ref_text
    if audio_file is None:
        _voice_ref_path = None
        _voice_ref_text = ""
        return "🗑️ Voice reference cleared — using default MMS-TTS voice."
    try:
        from src.tts.f5_tts import to_wav_24k
        wav_path = to_wav_24k(audio_file)
        _voice_ref_path = wav_path
        # Auto-transcribe using already-loaded Whisper if available
        if _whisper_model is not None and _whisper_processor is not None:
            import torch, librosa
            audio_np, _ = librosa.load(wav_path, sr=16000, mono=True)
            with _model_lock:
                inputs = _whisper_processor.feature_extractor(
                    audio_np, sampling_rate=16000, return_tensors="pt"
                )
                with torch.no_grad():
                    ids = _whisper_model.generate(
                        inputs.input_features,
                        max_new_tokens=128,
                    )
                _voice_ref_text = _whisper_processor.batch_decode(
                    ids, skip_special_tokens=True
                )[0].strip()
            return (
                f"✅ Voice reference set!\n"
                f"File : {Path(audio_file).name}\n"
                f"Transcript : {_voice_ref_text[:80] or '(empty — F5-TTS will use in-context inference)'}"
            )
        else:
            _voice_ref_text = ""
            return (
                f"✅ Voice reference set (model not loaded yet — transcript pending).\n"
                f"File: {Path(audio_file).name}"
            )
    except Exception as exc:
        return f"❌ Could not process reference audio: {exc}"
def _convo_pipeline(audio_path: str, language_code: str, history: list):
    """
    Full S2S conversation pipeline with memory:
      1. ASR   — fine-tuned Whisper (or base) → transcript
      2. Norm  — bam_normalize() on Bambara text
      3. Brain — LLM with full conversation history + vocabulary context
      4. Learn — parse [LEARNED:] tags, persist to Hub async
      5. Mouth — F5-TTS (voice ref) or MMS-TTS fallback → audio
    Returns: (transcript, eng, response_text, audio_out, new_history)
    """
    import torch
    import logging
    log = logging.getLogger(__name__)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if _whisper_model is None:
        return "⏳ Model still loading…", "", "", None, history
    import librosa
    audio_np, _ = librosa.load(audio_path, sr=16000, mono=True)
    active_model = _fine_tuned_models.get(language_code, _whisper_model)
    active_processor = _fine_tuned_processors.get(language_code, _whisper_processor)
    active_model.to(device)
    with _model_lock:
        inputs = active_processor.feature_extractor(
            audio_np, sampling_rate=16000, return_tensors="pt"
        )
        input_features = inputs.input_features.to(device)
        forced_ids = None
        if language_code not in ("bam", "ful"):
            forced_ids = active_processor.get_decoder_prompt_ids(
                language=language_code, task="transcribe"
            )
        with torch.no_grad():
            predicted_ids = active_model.generate(
                input_features,
                forced_decoder_ids=forced_ids if forced_ids else None,
                max_new_tokens=256,
            )
        transcript = active_processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0].strip()
    active_model.to("cpu")
    if device == "cuda":
        torch.cuda.empty_cache()

    # Phonetic normalisation (Bambara: French spellings → standard; Fula: Adlam → Latin)
    if language_code == "bam":
        normalised = bam_normalize(transcript)
    elif language_code == "ful":
        normalised = normalize_pular(transcript)
    else:
        normalised = transcript

    # ── LLM brain — full context: vocab + history + new turn ──────────────────
    response_text = ""
    try:
        from huggingface_hub import InferenceClient
        client = InferenceClient(token=HF_TOKEN)
        messages = _build_messages(normalised, history, language_code)
        completion = client.chat_completion(
            model=LLM_MODEL_ID,
            messages=messages,
            max_tokens=300,
            temperature=0.65,
        )
        response_text = completion.choices[0].message.content.strip()
    except Exception as llm_err:
        log.warning("LLM failed: %s", llm_err)
        # Graceful degradation: tell user LLM is unavailable, ask them to try again
        _fallbacks = {
            "bam": "Hakɛ to, n bɛ sɔrɔ cogo dɔ la.",  # Bambara (Mali)
            "ful": "Hakke, mi waawaa jogaade modèl oo jooni.",  # Pular (Guinea)
            "fr": "Désolé, je n'ai pas pu joindre le modèle.",
        }
        response_text = _fallbacks.get(language_code, "Sorry, the language model is unavailable.")

    # ── Parse and strip [LEARNED:] tags — save async to Hub ───────────────────
    response_text, learned_pairs = _parse_and_strip_learned(response_text, language_code)
    if learned_pairs:
        log.info("Learned %d new item(s): %s", len(learned_pairs), learned_pairs)

    # ── Update conversation history ────────────────────────────────────────────
    new_history = list(history) + [(normalised, response_text)]
    if len(new_history) > 20:
        new_history = new_history[-20:]

    # ── TTS mouth — F5-TTS (voice ref) or MMS-TTS fallback ────────────────────
    audio_out = None
    if _voice_ref_path and Path(_voice_ref_path).exists():
        try:
            from src.tts.f5_tts import synthesize as f5_synthesize
            result = f5_synthesize(
                response_text,
                ref_wav_path=_voice_ref_path,
                ref_text=_voice_ref_text,
                device=device,
            )
            if result is not None:
                wav_np, sr = result
                audio_out = (sr, wav_np)
        except Exception as tts_err:
            log.warning("F5-TTS failed, using MMS-TTS: %s", tts_err)
    if audio_out is None:
        wav_np, sr = _tts.synthesize(response_text, language_code, device=device)
        audio_out = (sr, wav_np)
    return transcript, "", response_text, audio_out, new_history
# ── HF Hub feedback persistence ───────────────────────────────────────────────
def _save_feedback_to_hub(
    audio_path: str | None,
    transcript: str,
    corrected_text: str,
    english_translation: str,
    corrected_english: str,
    response_text: str,
    corrected_response: str,
    rating: int,
    notes: str,
    language_label: str,
) -> str:
    language_code = SUPPORTED_LANGUAGES.get(language_label, "bam")
    if not corrected_text.strip():
        return "⚠️ Corrected transcription is empty — please fill in what was actually said."
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_%f")
    record = {
        "id": timestamp,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "language": language_code,
        "audio_file": f"audio/{language_code}_{timestamp}.wav",
        "whisper_output": transcript,
        "corrected_text": corrected_text.strip(),
        "english_translation": english_translation.strip(),
        "corrected_english": corrected_english.strip() or english_translation.strip(),
        "response_text": response_text,
        "corrected_response": corrected_response.strip() or response_text.strip(),
        "rating": rating,
        "notes": notes.strip(),
        "is_correction": transcript.strip() != corrected_text.strip(),
        "model": WHISPER_MODEL_ID,
    }
    if _hf_api is None:
        # Local: save to disk instead
        fb_dir = ROOT / "feedback"
        fb_dir.mkdir(exist_ok=True)
        (fb_dir / "audio").mkdir(exist_ok=True)
        corrections_path = fb_dir / "corrections.jsonl"
        if audio_path:
            import shutil
            shutil.copy2(audio_path, fb_dir / "audio" / f"{language_code}_{timestamp}.wav")
        with open(corrections_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
        with open(corrections_path, encoding="utf-8") as f:
            total = sum(1 for _ in f)
        return f"✅ Saved locally (#{total}) — HF_TOKEN not set, Hub upload skipped."
    try:
        # Upload audio
        if audio_path:
            _hf_api.upload_file(
                path_or_fileobj=audio_path,
                path_in_repo=f"audio/{language_code}_{timestamp}.wav",
                repo_id=FEEDBACK_REPO_ID,
                repo_type="dataset",
            )
        # Download → append → re-upload corrections.jsonl (with retry on conflict)
        from huggingface_hub import hf_hub_download
        for attempt in range(2):
            try:
                local_jsonl = hf_hub_download(
                    repo_id=FEEDBACK_REPO_ID,
                    filename="corrections.jsonl",
                    repo_type="dataset",
                    token=HF_TOKEN,
                )
                with open(local_jsonl, encoding="utf-8") as f:
                    existing = f.read()
            except Exception:
                existing = ""
            updated = existing + json.dumps(record, ensure_ascii=False) + "\n"
            buf = io.BytesIO(updated.encode("utf-8"))
            try:
                _hf_api.upload_file(
                    path_or_fileobj=buf,
                    path_in_repo="corrections.jsonl",
                    repo_id=FEEDBACK_REPO_ID,
                    repo_type="dataset",
                )
                break
            except Exception as e:
                if attempt == 1:
                    return f"⚠️ Audio uploaded but corrections.jsonl update failed: {e}"
        total = updated.count("\n")
        _maybe_auto_trigger()
        return f"✅ Saved to Hub (#{total}) — {FEEDBACK_REPO_ID}"
    except Exception as e:
        return f"❌ Hub upload error: {e}"
# ── Adapter reload ────────────────────────────────────────────────────────────
def _reload_adapters_from_hub() -> str:
    """Download full fine-tuned checkpoints from Hub and hot-swap them into memory."""
    global _fine_tuned_models, _fine_tuned_processors
    if _hf_api is None:
        return "⚠️ HF_TOKEN not set — cannot download checkpoints."
    if _whisper_model is None:
        return "⏳ Base model not loaded yet — wait for model to finish loading and try again."
    try:
        import torch
        from huggingface_hub import snapshot_download
        try:
            from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration
            from transformers.models.whisper.processing_whisper import WhisperProcessor
        except ImportError:
            from transformers import WhisperForConditionalGeneration, WhisperProcessor
        local_dir = snapshot_download(
            repo_id=ADAPTER_REPO_ID, repo_type="model", token=HF_TOKEN
        )
        results = []
        for lang, subdir in (("bam", "adapters/bambara"), ("ful", "adapters/fula")):
            ckpt_path = Path(local_dir) / subdir
            if not ckpt_path.exists():
                results.append(f"⚠️ {lang}: `{subdir}` not found — run training notebook first")
                continue
            if not (ckpt_path / "config.json").exists():
                results.append(f"⚠️ {lang}: `{subdir}/config.json` missing — incomplete checkpoint")
                continue
            try:
                m = WhisperForConditionalGeneration.from_pretrained(
                    str(ckpt_path), torch_dtype=torch.float32
                )
                m.eval()
                _fine_tuned_models[lang] = m
                # Load the processor from the same checkpoint so the feature
                # extractor mel-bin count matches the model (e.g. 128 for
                # large-v3-turbo vs 80 for small). Mismatch causes:
                # "expected input to have 128 channels, but got 80 channels"
                _fine_tuned_processors[lang] = WhisperProcessor.from_pretrained(
                    str(ckpt_path)
                )
                results.append(f"✅ {lang}: fine-tuned checkpoint loaded from `{subdir}`")
            except Exception as e:
                results.append(f"❌ {lang}: load failed — {e}")
        summary = "\n".join(results)
        active = ", ".join(_fine_tuned_models) if _fine_tuned_models else "none"
        return f"{summary}\n\n**Active fine-tuned models:** {active}\n**Repo:** `{ADAPTER_REPO_ID}`"
    except Exception as e:
        return f"❌ Checkpoint reload failed: {e}"
def _get_adapter_status() -> str:
    lines = []
    if _fine_tuned_models:
        lines.append(f"**Fine-tuned models loaded:** {', '.join(sorted(_fine_tuned_models))}")
    else:
        lines.append("**Fine-tuned models:** none — using base Whisper for all languages")
    if _hf_api is None:
        lines.append("_HF_TOKEN not set — Hub check skipped._")
        return "\n".join(lines)
    try:
        from huggingface_hub import list_repo_files
        files = list(list_repo_files(ADAPTER_REPO_ID, repo_type="model", token=HF_TOKEN))
        bam_ok = any("bambara" in f and "config.json" in f for f in files)
        ful_ok = any("fula" in f and "config.json" in f for f in files)
        lines += [
            f"\n**Hub repo:** `{ADAPTER_REPO_ID}`",
            f"- Bambara (bam): {'✅ trained checkpoint present' if bam_ok else '⚠️ not yet trained — run Kaggle notebook'}",
            f"- Fula (ful): {'✅ trained checkpoint present' if ful_ok else '⚠️ not yet trained — run Kaggle notebook'}",
        ]
        if bam_ok or ful_ok:
            lines.append("\n_Click **Reload Models** to activate them._")
    except Exception as e:
        lines.append(f"_Could not read Hub repo: {e}_")
    return "\n".join(lines)
# ── Knowledge Base handlers ───────────────────────────────────────────────────
def _import_phrase_pairs(lang_label: str, pairs_text: str) -> str:
    """Import pasted phrase pairs into the phrase library and append to vocabulary.jsonl."""
    if not pairs_text.strip():
        return "⚠️ Nothing entered. Use the format: native phrase | english translation"
    lang = SUPPORTED_LANGUAGES.get(lang_label, "bam")
    count = _phrase_matcher.import_pairs(lang, pairs_text)
    if count == 0:
        return "⚠️ No valid phrases found. Each line must contain a | separator.\nExample: I ni ce | Hello, good day"
    _upload_phrase_additions_to_hub(lang)
    # Also append to vocabulary.jsonl so the Kaggle training notebook picks them up
    _append_phrases_to_vocabulary_jsonl(lang, pairs_text)
    total = _phrase_matcher.phrase_count(lang)
    return f"✅ Added {count} phrase(s) for {lang_label}. Library now has {total} phrases. Available immediately."
def _extract_text_from_document(file_path: str) -> str:
    """Extract plain text from a PDF, DOCX, or TXT file. Returns empty string on failure."""
    ext = Path(file_path).suffix.lower()
    try:
        if ext == ".pdf":
            from pypdf import PdfReader
            reader = PdfReader(file_path)
            return "\n".join((p.extract_text() or "") for p in reader.pages)
        if ext in (".docx", ".doc"):
            from docx import Document
            doc = Document(file_path)
            return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
        if ext in (".txt", ".md"):
            with open(file_path, encoding="utf-8", errors="ignore") as f:
                return f.read()
    except Exception as exc:
        import logging
        logging.getLogger(__name__).warning("Document extract failed for %s: %s", file_path, exc)
    return ""
def _sentences_from_text(text: str, min_words: int = 3, max_words: int = 25) -> list[str]:
    """Split extracted text into clean sentences suitable for vocabulary.jsonl."""
    import re as _re
    # Normalise whitespace and split on sentence boundaries (., !, ?, or double newline)
    text = _re.sub(r"\s+", " ", text).strip()
    raw = _re.split(r"(?<=[.!?])\s+|\n\n+", text)
    out = []
    seen = set()
    for s in raw:
        s = s.strip(" \t\"'`—–-")
        if not s:
            continue
        words = s.split()
        if not (min_words <= len(words) <= max_words):
            continue
        key = s.lower()
        if key in seen:
            continue
        seen.add(key)
        out.append(s)
    return out
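# Illustrative behaviour with the defaults (min_words=3, max_words=25):
#   _sentences_from_text("Ok. The rains came early this year! Ok.")
#   → ["The rains came early this year!"]
# Both "Ok." fragments fall under min_words. Note that the \n\n+ alternative in
# the split is effectively dead: the preceding \s+ collapse removes newlines.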
def _import_documents(lang_label: str, files: list, source_note: str) -> str:
    """Extract sentences from uploaded PDF/Word/TXT files and append to vocabulary.jsonl."""
    if not files:
        return "⚠️ Please upload at least one document first."
    lang = SUPPORTED_LANGUAGES.get(lang_label, "bam")
    total_sentences = 0
    per_file_summary = []
    all_entries: list[dict] = []
    for f in files:
        # Gradio File component returns a tempfile path (or an object with .name)
        path = f if isinstance(f, str) else getattr(f, "name", None)
        if not path:
            continue
        text = _extract_text_from_document(path)
        if not text.strip():
            per_file_summary.append(f"  - {Path(path).name}: ⚠️ no text extracted")
            continue
        # Apply language-specific normalisation (same rule as the other ingestion
        # paths) so Adlam → Latin etc.
        try:
            if lang == "ful":
                text = normalize_pular(text)
            elif lang == "bam":
                text = bam_normalize(text)
        except Exception:
            pass
        sentences = _sentences_from_text(text)
        for s in sentences:
            all_entries.append({
                "word": s,
                "translation": "",
                "language": lang,
                "source": f"document: {source_note or Path(path).name}",
            })
        per_file_summary.append(f"  - {Path(path).name}: {len(sentences)} sentence(s)")
        total_sentences += len(sentences)
    if not all_entries:
        return "⚠️ No usable sentences found in the uploaded document(s).\n" + "\n".join(per_file_summary)
    # Append to vocabulary.jsonl on Hub (same pattern as _append_phrases_to_vocabulary_jsonl)
    if _hf_api is not None and FEEDBACK_REPO_ID:
        try:
            from huggingface_hub import hf_hub_download
            try:
                local = hf_hub_download(
                    repo_id=FEEDBACK_REPO_ID, filename="vocabulary.jsonl",
                    repo_type="dataset", token=HF_TOKEN,
                )
                with open(local, encoding="utf-8") as f:
                    existing = f.read()
            except Exception:
                existing = ""
            new_lines = "".join(json.dumps(e, ensure_ascii=False) + "\n" for e in all_entries)
            _hf_api.upload_file(
                path_or_fileobj=io.BytesIO((existing + new_lines).encode("utf-8")),
                path_in_repo="vocabulary.jsonl",
                repo_id=FEEDBACK_REPO_ID,
                repo_type="dataset",
            )
            threading.Thread(target=_refresh_vocab_context, daemon=True).start()
        except Exception as exc:
            return f"⚠️ Extracted {total_sentences} sentence(s) but Hub upload failed: {exc}"
    return (
        f"✅ Imported {total_sentences} sentence(s) for {lang_label} from {len(files)} document(s).\n"
        + "\n".join(per_file_summary)
        + "\n\nThese will be used by the Kaggle training notebook on the next run."
    )
def _append_phrases_to_vocabulary_jsonl(lang: str, pairs_text: str) -> None:
    """Append phrase pairs to vocabulary.jsonl in the feedback repo (training input)."""
    if _hf_api is None or not FEEDBACK_REPO_ID:
        return
    entries = []
    for line in pairs_text.splitlines():
        if "|" not in line:
            continue
        parts = line.split("|", 1)
        word = parts[0].strip()
        translation = parts[1].strip() if len(parts) > 1 else ""
        if word:
            entries.append({"word": word, "translation": translation, "language": lang})
    if not entries:
        return
    try:
        from huggingface_hub import hf_hub_download
        for attempt in range(2):
            try:
                local = hf_hub_download(
                    repo_id=FEEDBACK_REPO_ID, filename="vocabulary.jsonl",
                    repo_type="dataset", token=HF_TOKEN,
                )
                with open(local, encoding="utf-8") as f:
                    existing = f.read()
            except Exception:
                existing = ""
            new_lines = "".join(json.dumps(e, ensure_ascii=False) + "\n" for e in entries)
            updated = existing + new_lines
            try:
                _hf_api.upload_file(
                    path_or_fileobj=io.BytesIO(updated.encode("utf-8")),
                    path_in_repo="vocabulary.jsonl",
                    repo_id=FEEDBACK_REPO_ID,
                    repo_type="dataset",
                )
                threading.Thread(target=_refresh_vocab_context, daemon=True).start()
                break
            except Exception:
                if attempt == 1:
                    pass  # Silent — phrase library still updated locally
    except Exception:
        pass  # Non-critical — phrase library already saved via _upload_phrase_additions_to_hub
def _upload_phrase_additions_to_hub(lang: str) -> None:
    """Persist user phrase additions to HF Hub so they survive Space restarts."""
    if _hf_api is None or not FEEDBACK_REPO_ID:
        return
    try:
        data = _phrase_matcher.get_additions_json(lang)
        buf = io.BytesIO(data.encode("utf-8"))  # io is already imported at module level
        _hf_api.upload_file(
            path_or_fileobj=buf,
            path_in_repo=f"phrase_additions/{lang}.json",
            repo_id=FEEDBACK_REPO_ID,
            repo_type="dataset",
        )
    except Exception as exc:
        import logging
        logging.getLogger(__name__).warning("Could not upload phrase additions: %s", exc)
def _load_phrase_additions_from_hub() -> None:
    """Download and merge user phrase additions from HF Hub at startup."""
    if _hf_api is None or not FEEDBACK_REPO_ID:
        return
    for lang in ("bam", "ful"):
        try:
            from huggingface_hub import hf_hub_download
            local = hf_hub_download(
                repo_id=FEEDBACK_REPO_ID,
                filename=f"phrase_additions/{lang}.json",
                repo_type="dataset",
                token=HF_TOKEN,
            )
            with open(local, encoding="utf-8") as f:
                data = f.read()
            _phrase_matcher.reload_from_hub_data(lang, data)
        except Exception:
            pass  # No additions saved yet — fine

# Load phrase additions + vocabulary context in background at startup
threading.Thread(target=_load_phrase_additions_from_hub, daemon=True).start()
threading.Thread(target=_refresh_vocab_context, daemon=True).start()
def _save_audio_for_training(lang_label: str, audio_path: str | None, transcript: str, source_note: str) -> str:
    """Save uploaded audio + transcription to corrections.jsonl so the Kaggle notebook picks it up."""
    transcript = transcript.strip()
    if audio_path is None:
        return "⚠️ Please upload an audio file first."
    if not transcript:
        return "⚠️ Please type the transcription — what is said in this audio."
    lang = SUPPORTED_LANGUAGES.get(lang_label, "bam")
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_%f")
    # Store under audio/ — same path structure that corrections.jsonl expects
    audio_repo_path = f"audio/{lang}_{timestamp}.wav"
    record = {
        "id": timestamp,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "language": lang,
        "audio_file": audio_repo_path,
        "transcription": transcript,   # notebook reads this field
        "corrected_text": transcript,  # also populate corrected_text for compatibility
        "source": source_note.strip() or "uploaded",
        "is_correction": False,
        "model": WHISPER_MODEL_ID,
    }
    if _hf_api is None or not FEEDBACK_REPO_ID:
        return "⚠️ HF_TOKEN not set — cannot upload to Hub."
    try:
        # Upload audio to audio/ (same bucket corrections use)
        _hf_api.upload_file(
            path_or_fileobj=audio_path,
            path_in_repo=audio_repo_path,
            repo_id=FEEDBACK_REPO_ID,
            repo_type="dataset",
        )
        # Append to corrections.jsonl (same file the notebook reads)
        from huggingface_hub import hf_hub_download
        for attempt in range(2):
            try:
                local_jsonl = hf_hub_download(
                    repo_id=FEEDBACK_REPO_ID, filename="corrections.jsonl",
                    repo_type="dataset", token=HF_TOKEN,
                )
                with open(local_jsonl, encoding="utf-8") as f:
                    existing = f.read()
            except Exception:
                existing = ""
            updated = existing + json.dumps(record, ensure_ascii=False) + "\n"
            try:
                _hf_api.upload_file(
                    path_or_fileobj=io.BytesIO(updated.encode("utf-8")),
                    path_in_repo="corrections.jsonl",
                    repo_id=FEEDBACK_REPO_ID,
                    repo_type="dataset",
                )
                break
            except Exception as e:
                if attempt == 1:
                    return f"⚠️ Audio uploaded but corrections.jsonl update failed: {e}"
        total = updated.count("\n")
        return (
            f"✅ Saved to training dataset (#{total} total corrections)!\n"
            f"Audio: {audio_repo_path}\n"
            f"Transcription: {transcript[:80]}{'…' if len(transcript) > 80 else ''}\n"
            f"Run the Kaggle notebook to include this in the next model update."
        )
    except Exception as exc:
        return f"❌ Upload failed: {exc}"
# ── Auto-training trigger ─────────────────────────────────────────────────────
def _count_corrections() -> int:
    """Return the number of entries in corrections.jsonl on the Hub."""
    if _hf_api is None:
        return 0
    try:
        from huggingface_hub import hf_hub_download
        local = hf_hub_download(
            repo_id=FEEDBACK_REPO_ID, filename="corrections.jsonl",
            repo_type="dataset", token=HF_TOKEN,
        )
        with open(local, encoding="utf-8") as f:
            return sum(1 for line in f if line.strip())
    except Exception:
        return 0
def _trigger_kaggle_training(lang: str = "bam") -> str:
    """
    Push the master trainer notebook to Kaggle, creating a new kernel version
    (i.e. triggering a run). Requires KAGGLE_USERNAME + KAGGLE_KEY secrets.
    Tries the Python API first (no PATH issues), falls back to subprocess.
    """
    if not KAGGLE_USERNAME or not KAGGLE_KEY:
        return "⚠️ KAGGLE_USERNAME / KAGGLE_KEY not set in Space secrets."
    notebooks_dir = ROOT / "notebooks"
    nb_file = notebooks_dir / "kaggle_master_trainer.ipynb"
    meta_file = notebooks_dir / "kernel-metadata.json"
    if not nb_file.exists():
        return "❌ notebooks/kaggle_master_trainer.ipynb not found in Space."
    if not meta_file.exists():
        return "❌ notebooks/kernel-metadata.json not found in Space."
    # Inject credentials into os.environ before any kaggle import —
    # Kaggle authentication reads KAGGLE_USERNAME + KAGGLE_KEY from env.
    os.environ["KAGGLE_USERNAME"] = KAGGLE_USERNAME
    os.environ["KAGGLE_KEY"] = KAGGLE_KEY
    # ── Method 1: Python API (no binary PATH issues) ──────────────────────────
    api_err = None
    try:
        from kaggle.api.kaggle_api_extended import KaggleApi
        _kapi = KaggleApi()
        _kapi.authenticate()
        _kapi.kernels_push(str(notebooks_dir))
        return (
            "✅ Kaggle training triggered!\n"
            f"Kernel: {KAGGLE_KERNEL_SLUG}\n"
            "Check https://www.kaggle.com for run progress."
        )
    except Exception as e:
        api_err = str(e)
    # ── Method 2: subprocess fallback ──────────────────────────────────────────
    import subprocess, shutil
    kaggle_bin = shutil.which("kaggle")
    if kaggle_bin is None:
        for cand in [
            Path(sys.executable).parent / "kaggle",
            Path("/usr/local/bin/kaggle"),
            Path("/usr/bin/kaggle"),
        ]:
            if cand.exists():
                kaggle_bin = str(cand)
                break
    if kaggle_bin is None:
        return (
            f"❌ Kaggle CLI not found (API error: {api_err}).\n"
            "Ensure kaggle>=1.6.0 is in requirements.txt and the Space rebuilt."
        )
    env = {
        **os.environ,
        "KAGGLE_USERNAME": KAGGLE_USERNAME,
        "KAGGLE_KEY": KAGGLE_KEY,
        "PYTHONUTF8": "1",
        "PYTHONIOENCODING": "utf-8",
    }
    try:
        result = subprocess.run(
            [kaggle_bin, "kernels", "push", "-p", str(notebooks_dir)],
            capture_output=True, text=True, timeout=60, env=env,
        )
        if result.returncode == 0:
            out = (result.stdout or "").strip()
            return f"✅ Kaggle training triggered!\n{out or 'Kernel version created.'}"
        err = (result.stderr or result.stdout or "unknown error").strip()
        return f"❌ Kaggle push failed:\n{err}"
    except Exception as e:
        return f"❌ Kaggle push failed (API: {api_err}) (CLI: {e})"
def _maybe_auto_trigger() -> None:
    """Called after each correction save. Triggers Kaggle if the threshold is met."""
    if not KAGGLE_USERNAME or not KAGGLE_KEY:
        return
    count = _count_corrections()
    if count > 0 and count % AUTO_TRAIN_THRESHOLD == 0:
        import logging

        def _run():
            msg = _trigger_kaggle_training()
            logging.getLogger(__name__).info("Auto-trigger: %s", msg)

        threading.Thread(target=_run, daemon=True).start()
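# Example: with AUTO_TRAIN_THRESHOLD=50 a run fires when the corrections count
# reaches exactly 50, 100, 150, and so on; saves in between do nothing, and a
# failed count (which returns 0) never triggers.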
# ── Bulk upload handler ───────────────────────────────────────────────────────
def _bulk_upload(lang_label: str, zip_file, csv_text: str) -> str:
    """
    Accept a ZIP of audio files + a CSV (filename,transcription) and batch-insert
    all samples into corrections.jsonl. Audio stored under audio/ in the Hub repo.
    """
    import zipfile, csv
    if _hf_api is None:
        return "⚠️ HF_TOKEN not set — cannot upload."
    if zip_file is None and not csv_text.strip():
        return "⚠️ Upload a ZIP and/or paste a CSV."
    lang = SUPPORTED_LANGUAGES.get(lang_label, "bam")
    rows = []  # (audio_bytes_or_None, filename, transcription)
    # Parse CSV
    transcript_map: dict[str, str] = {}
    if csv_text.strip():
        for row in csv.reader(csv_text.strip().splitlines()):
            if len(row) >= 2:
                transcript_map[row[0].strip()] = row[1].strip()
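    # Expected CSV shape (illustrative): one "filename,transcription" row per
    # line, e.g.
    #   rec_001.wav,I ni sogoma
    # Extra columns are ignored; filenames may be the ZIP-internal path or just
    # the basename, since both are looked up when matching below.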
    # Extract ZIP
    if zip_file is not None:
        try:
            with zipfile.ZipFile(zip_file, "r") as zf:
                for name in zf.namelist():
                    if not name.lower().endswith((".wav", ".mp3", ".ogg", ".flac", ".m4a")):
                        continue
                    text = transcript_map.get(name) or transcript_map.get(Path(name).name) or ""
                    if not text:
                        continue
                    rows.append((zf.read(name), Path(name).name, text))
        except Exception as e:
            return f"❌ ZIP read error: {e}"
    elif transcript_map:
        # CSV only — audio-less vocab entries
        for fname, text in transcript_map.items():
            rows.append((None, fname, text))
    if not rows:
        return "⚠️ No matching (audio, transcription) pairs found. Check filenames match CSV."
    # Upload batch
    records = []
    errors = 0
    for audio_bytes, fname, text in rows:
        ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_%f")
        audio_path = f"audio/{lang}_{ts}.wav"
        try:
            if audio_bytes:
                _hf_api.upload_file(
                    path_or_fileobj=io.BytesIO(audio_bytes),
                    path_in_repo=audio_path,
                    repo_id=FEEDBACK_REPO_ID,
                    repo_type="dataset",
                )
            records.append({
                "id": ts, "timestamp": datetime.now(timezone.utc).isoformat(),
                "language": lang,
                "audio_file": audio_path if audio_bytes else "",
                "transcription": text, "corrected_text": text,
                "source": f"bulk_upload:{fname}", "is_correction": False,
                "model": WHISPER_MODEL_ID,
            })
        except Exception:
            errors += 1
    # Append all to corrections.jsonl
    from huggingface_hub import hf_hub_download
    for attempt in range(2):
        try:
            local = hf_hub_download(
                repo_id=FEEDBACK_REPO_ID, filename="corrections.jsonl",
                repo_type="dataset", token=HF_TOKEN,
            )
            with open(local, encoding="utf-8") as f:
                existing = f.read()
        except Exception:
            existing = ""
        new_lines = "".join(json.dumps(r, ensure_ascii=False) + "\n" for r in records)
        updated = existing + new_lines
        try:
            _hf_api.upload_file(
                path_or_fileobj=io.BytesIO(updated.encode("utf-8")),
                path_in_repo="corrections.jsonl",
                repo_id=FEEDBACK_REPO_ID,
                repo_type="dataset",
            )
            break
        except Exception as e:
            if attempt == 1:
                return f"⚠️ Audio uploaded but corrections.jsonl failed: {e}"
    total = updated.count("\n")
    _maybe_auto_trigger()
    return (
        f"✅ Bulk upload complete!\n"
        f"  Uploaded : {len(records)} samples ({errors} errors)\n"
        f"  Dataset  : {total} total corrections\n"
        f"  Auto-train threshold: {AUTO_TRAIN_THRESHOLD} entries"
    )
# ── Internet self-teaching handlers ───────────────────────────────────────────
def _upload_jsonl(repo_path: str, entries: list[dict]) -> tuple[int, str | None]:
    """Append entries to a jsonl file on the Hub. Returns (total_lines, error_or_None)."""
    from huggingface_hub import hf_hub_download
    for attempt in range(2):
        try:
            local = hf_hub_download(
                repo_id=FEEDBACK_REPO_ID, filename=repo_path,
                repo_type="dataset", token=HF_TOKEN,
            )
            with open(local, encoding="utf-8") as f:
                existing = f.read()
        except Exception:
            existing = ""
        updated = existing + "".join(json.dumps(e, ensure_ascii=False) + "\n" for e in entries)
        try:
            _hf_api.upload_file(
                path_or_fileobj=io.BytesIO(updated.encode("utf-8")),
                path_in_repo=repo_path,
                repo_id=FEEDBACK_REPO_ID,
                repo_type="dataset",
            )
            return updated.count("\n"), None
        except Exception as e:
            if attempt == 1:
                return 0, str(e)
    return 0, "unknown error"
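# Note: _upload_jsonl (like the other append helpers above) is a plain
# download → concatenate → re-upload cycle with no locking or commit-conflict
# detection, so two concurrent writers can race and one append can be lost.
# The two-attempt loop only retries transient upload failures.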
def _harvest_wikipedia(lang_label: str, max_articles: int = 100) -> str:
    """Fetch Wikipedia text and append to vocabulary.jsonl."""
    if _hf_api is None:
        return "⚠️ HF_TOKEN not set."
    import re as _re_tmp
    _m = _re_tmp.search(r'\((\w+)\)$', lang_label.strip())
    lang = _m.group(1) if _m else SUPPORTED_LANGUAGES.get(lang_label, "bam")
    if lang not in ("bam", "ful"):
        return "⚠️ Supported for Bambara and Fula only."
    max_articles = int(max_articles)  # Gradio slider returns float
    try:
        from src.data.web_harvester import harvest_wikipedia_text
        entries = harvest_wikipedia_text(lang, max_articles=max_articles)
    except Exception as e:
        return f"❌ Harvest error: {e}"
    if not entries:
        return "⚠️ No sentences extracted. Wikipedia may be temporarily unavailable."
    total, err = _upload_jsonl("vocabulary.jsonl", entries)
    if err:
        return f"❌ Upload failed: {err}"
    threading.Thread(target=_refresh_vocab_context, daemon=True).start()
    return (
        f"✅ Wikipedia harvest complete!\n"
        f"  Language         : {lang_label}\n"
        f"  Sentences added  : {len(entries)}\n"
        f"  Vocabulary total : {total} entries"
    )
def _harvest_hf_dataset(lang_label: str, max_samples: int = 500) -> str:
    """
    Register an HF dataset as a training source by writing its config to
    dataset_sources.jsonl on the Hub. The Kaggle notebook reads this file
    at Cell 4 and loads the dataset directly — no audio re-upload needed.
    """
    if _hf_api is None:
        return "⚠️ HF_TOKEN not set."
    # Dropdown sends "Bambara (bam)" / "Fula (ful)" — extract the code in parens
    import re as _re_tmp
    _m = _re_tmp.search(r'\((\w+)\)$', lang_label.strip())
    lang = _m.group(1) if _m else SUPPORTED_LANGUAGES.get(lang_label, "bam")
    if lang not in ("bam", "ful"):
        return "⚠️ Supported for Bambara and Fula only."
    max_samples = int(max_samples)  # Gradio slider returns float
    from src.data.web_harvester import get_hf_dataset_refs
    refs = get_hf_dataset_refs(lang)
    if not refs:
        return f"⚠️ No HF dataset configured for {lang}."
    # Read existing entries to avoid duplicates
    from huggingface_hub import hf_hub_download
    try:
        local = hf_hub_download(
            repo_id=FEEDBACK_REPO_ID, filename="dataset_sources.jsonl",
            repo_type="dataset", token=HF_TOKEN,
        )
        with open(local, encoding="utf-8") as f:
            existing_entries = [json.loads(line) for line in f if line.strip()]
    except Exception:
        existing_entries = []
    existing_keys = {
        (e.get("repo", e.get("repo_id", "")), e.get("config", ""))
        for e in existing_entries
    }
    new_entries = []
    already = []
    for ref in refs:
        key = (ref.get("repo", ref.get("repo_id", "")), ref.get("config", ""))
        if key in existing_keys:
            already.append(ref["repo"])
            continue
        entry = dict(ref)
        entry["max"] = max_samples
        entry["enabled"] = True
        entry["added_at"] = datetime.now(timezone.utc).isoformat()
        new_entries.append(entry)
    if not new_entries:
        repos = ", ".join(already)
        return (
            f"✅ Already registered!\n"
            f"  `{repos}` is already in your training config.\n"
            f"  Click 'Trigger Training Now' to start a run with this data."
        )
    total, err = _upload_jsonl("dataset_sources.jsonl", new_entries)
    if err:
        return f"❌ Upload failed: {err}"
    repos = ", ".join(r["repo"] for r in refs)
    return (
        f"✅ Dataset registered for training!\n"
        f"  Source(s)   : {repos}\n"
        f"  Max samples : {max_samples}\n"
        f"  The Kaggle notebook will stream this dataset directly at training time.\n"
        f"  Click 'Trigger Training Now' to start a training run.\n"
        f"  Total registered sources: {total}"
    )
# ── Main ask handler ──────────────────────────────────────────────────────────
def handle_ask(audio_path, language_label, convo_mode: bool = False, history: list | None = None):
    """
    Main dispatcher. Always returns 5 values:
        (transcript, eng_translation, response_text, audio_out, new_history)
    new_history is the updated gr.State list of (user, asst) tuples.
    In normal (sensor) mode, history is passed through unchanged.
    """
    history = history or []
    if audio_path is None:
        return "⚠️ No audio — press Record or upload a file.", "", "", None, history
    language_code = SUPPORTED_LANGUAGES.get(language_label, "bam")
    if not _wait_for_whisper(timeout=120):
        return f"❌ Model failed to load: {_model_status}", "", "", None, history
    try:
        if convo_mode:
            transcript, eng, response_text, audio_out, new_history = _convo_pipeline(
                audio_path, language_code, history
            )
        else:
            transcript, eng, response_text, audio_out = _run_pipeline(audio_path, language_code)
            new_history = history  # sensor mode doesn't modify history
        return transcript, eng, response_text, audio_out, new_history
    except Exception as e:
        return f"❌ {e}", "", "", None, history
| # ── Two-stage pipeline (shows transcript fast, then response) ───────────────── | |
| def _do_asr(audio_path: str, language_label: str) -> str: | |
| """ | |
| Stage 1 — Whisper only. Returns the transcript string (or error/status). | |
| Blocks until the model is ready (up to 120 s) so the first request after a | |
| cold start works without the user needing to retry. | |
| """ | |
| if audio_path is None: | |
| return "⚠️ No audio — press Record or upload a file." | |
| lang = SUPPORTED_LANGUAGES.get(language_label, "bam") | |
| if not _wait_for_whisper(timeout=120): | |
| return f"❌ Model failed to load: {_model_status}" | |
| try: | |
| import librosa | |
| import torch | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| audio_np, _ = librosa.load(audio_path, sr=16000, mono=True) | |
| active_model = _fine_tuned_models.get(lang, _whisper_model) | |
| active_processor = _fine_tuned_processors.get(lang, _whisper_processor) | |
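| # Move the model to the GPU only for the duration of this request; it is | |
| # returned to CPU below so VRAM is freed between calls. | |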
| active_model.to(device) | |
| with _model_lock: | |
| input_features = active_processor.feature_extractor( | |
| audio_np, sampling_rate=16000, return_tensors="pt" | |
| ).input_features.to(device) | |
| forced_ids = None | |
| if lang not in ("bam", "ful"): | |
| forced_ids = active_processor.get_decoder_prompt_ids( | |
| language=lang, task="transcribe" | |
| ) | |
| with torch.no_grad(): | |
| ids = active_model.generate( | |
| input_features, | |
| forced_decoder_ids=forced_ids, | |
| max_new_tokens=256, | |
| ) | |
| transcript = active_processor.batch_decode(ids, skip_special_tokens=True)[0].strip() | |
| active_model.to("cpu") | |
| if device == "cuda": | |
| torch.cuda.empty_cache() | |
| # Phonetic normalisation (Bambara: French spellings → standard; Fula: Adlam → Latin) | |
| if lang == "bam": | |
| return bam_normalize(transcript) | |
| elif lang == "ful": | |
| return normalize_pular(transcript) | |
| return transcript | |
| except Exception as e: | |
| return f"❌ Transcription error: {e}" | |
| def _do_respond( | |
| transcript: str, | |
| language_label: str, | |
| convo_mode: bool, | |
| history: list, | |
| ) -> tuple: | |
| """ | |
| Stage 2 — LLM or sensor response, runs after transcript is already visible. | |
| Returns (eng_translation, response_text, audio_out, new_history, chat_msgs). | |
| """ | |
| history = history or [] | |
| # Bail early if stage 1 errored: status strings start with ⚠ / ⏳ / ❌. | |
| # ("⚠️" is two code points, so a single-character slice comparison would miss it.) | |
| if not transcript or transcript.startswith(("⚠", "⏳", "❌")): | |
| chat_msgs = [[u, v] for u, v in history] | |
| return "", "", None, history, chat_msgs | |
| lang = SUPPORTED_LANGUAGES.get(language_label, "bam") | |
| import torch | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| if convo_mode: | |
| # ── LLM brain ──────────────────────────────────────────────────────── | |
| response_text = "" | |
| try: | |
| from huggingface_hub import InferenceClient | |
| client = InferenceClient(token=HF_TOKEN) | |
| messages = _build_messages(transcript, history, lang) | |
| completion = client.chat_completion( | |
| model=LLM_MODEL_ID, | |
| messages=messages, | |
| max_tokens=150, # short spoken responses, much faster | |
| temperature=0.65, | |
| ) | |
| response_text = completion.choices[0].message.content.strip() | |
| except Exception as llm_err: | |
| import logging | |
| logging.getLogger(__name__).warning("LLM error: %s", llm_err) | |
| _fallbacks = { | |
| "bam": "Hakɛ to, tasuma tɛ kɛ sisan. I ka a lasɔrɔ tugu.", | |
| "ful": "Hakke, mi waawaa jogaade modèl oo jooni. Njaɓɓu.", | |
| "fr": "Désolé, le modèle est indisponible pour l'instant.", | |
| } | |
| response_text = _fallbacks.get(lang, "Sorry, the language model is unavailable.") | |
| # Strip [LEARNED:] tags, persist async | |
| response_text, _ = _parse_and_strip_learned(response_text, lang) | |
| # Update history | |
| new_history = list(history) + [(transcript, response_text)] | |
| if len(new_history) > 20: | |
| new_history = new_history[-20:] | |
| chat_msgs = [[u, v] for u, v in new_history] | |
| # ── TTS ─────────────────────────────────────────────────────────────── | |
| audio_out = None | |
| if _voice_ref_path and Path(_voice_ref_path).exists(): | |
| try: | |
| from src.tts.f5_tts import synthesize as f5s | |
| result = f5s(response_text, ref_wav_path=_voice_ref_path, | |
| ref_text=_voice_ref_text, device=device) | |
| if result is not None: | |
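| # src.tts.f5_tts.synthesize appears to return (waveform, sample_rate); | |
| # swap into Gradio's (sample_rate, waveform) order | |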
| audio_out = (result[1], result[0]) | |
| except Exception: | |
| pass | |
| if audio_out is None: | |
| try: | |
| wav_np, sr = _tts.synthesize(response_text, lang, device=device) | |
| audio_out = (sr, wav_np) | |
| except Exception as tts_err: | |
| import logging | |
| logging.getLogger(__name__).warning("TTS error: %s", tts_err) | |
| return "", response_text, audio_out, new_history, chat_msgs | |
| else: | |
| # ── Sensor / phrase pipeline ────────────────────────────────────────── | |
| import asyncio | |
| phrase_match = _phrase_matcher.match(transcript, lang) | |
| if phrase_match: | |
| response_text = phrase_match["response"] | |
| english_translation = phrase_match["english"] | |
| else: | |
| intent = _intent_parser.parse(transcript, language=lang) | |
| try: | |
| # asyncio.run creates, runs, and closes a fresh event loop for this call | |
| sensor_data = asyncio.run(_sensor_bridge.fetch(intent)) | |
| except Exception: | |
| from src.iot.sensor_bridge import SensorData | |
| sensor_data = SensorData(sensor_type="soil", | |
| values={"moisture_pct": 45.0, "ph": 6.5, "temperature_c": 28.0}) | |
| responder = VoiceResponder(language=lang) | |
| response_text, english_translation = responder.generate_response(intent, sensor_data) | |
| if intent.action == "unknown" and intent.confidence < 0.15: | |
| from src.iot.voice_responder import BAMBARA_TEMPLATES, FULA_TEMPLATES | |
| if lang == "bam": | |
| response_text, english_translation = BAMBARA_TEMPLATES["not_understood"] | |
| elif lang == "ful": | |
| response_text, english_translation = FULA_TEMPLATES["not_understood"] | |
| audio_out = None | |
| try: | |
| wav_np, sr = _tts.synthesize(response_text, lang, device=device) | |
| audio_out = (sr, wav_np) | |
| except Exception as tts_err: | |
| import logging | |
| logging.getLogger(__name__).warning("TTS error: %s", tts_err) | |
| chat_msgs = [[u, v] for u, v in history] | |
| return english_translation, response_text, audio_out, history, chat_msgs | |
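| # Both pipelines hand audio to gr.Audio as a (sample_rate, numpy_array) tuple. | |
| # Minimal illustrative sketch of that contract (a 0.5 s 440 Hz tone; nothing | |
| # in the handlers above calls this): | |
| def _demo_tone(sr: int = 16000, seconds: float = 0.5): | |
| t = np.linspace(0.0, seconds, int(sr * seconds), endpoint=False) | |
| return sr, (0.2 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32) | |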
| # ── Gradio UI ───────────────────────────────────────────────────────────────── | |
| def build_ui() -> gr.Blocks: | |
| with gr.Blocks(title="Sahel-Agri Voice AI") as demo: | |
| gr.Markdown("# 🌾 Sahel-Agri Voice AI") | |
| gr.Markdown( | |
| "Speak in **Bambara** or **Fula** — get agricultural insights spoken back " | |
| "in your language. Also supports French and English." | |
| ) | |
| model_status_box = gr.Textbox( | |
| value=get_model_status(), | |
| label="Model status", | |
| interactive=False, | |
| ) | |
| # gr.Timer polls get_model_status every 3s and updates the box (Gradio 5) | |
| status_timer = gr.Timer(value=3) | |
| status_timer.tick(fn=get_model_status, outputs=model_status_box) | |
| with gr.Tabs() as tabs: | |
| # ── Tab 1: Voice Assistant ──────────────────────────────────────── | |
| with gr.TabItem("🎙️ Voice Assistant", id="tab_voice"): | |
| # ── Conversation Mode controls (top bar) ───────────────────── | |
| with gr.Row(): | |
| convo_mode_toggle = gr.Checkbox( | |
| value=False, | |
| label="🔄 Conversation Mode — AI responds with LLM + cloned voice", | |
| info="When ON: mic auto-submits on stop; AI replies via LLM + F5-TTS (requires voice reference below).", | |
| ) | |
| with gr.Accordion("🎤 Voice Reference — upload an MP3/WAV of the target speaker", open=False): | |
| gr.Markdown( | |
| "Upload **5–30 seconds** of clear speech in the target voice. " | |
| "The AI will speak all its responses using this voice. " | |
| "Requires `f5-tts` and a GPU — falls back to MMS-TTS otherwise." | |
| ) | |
| with gr.Row(): | |
| voice_ref_input = gr.Audio( | |
| sources=["upload"], | |
| type="filepath", | |
| label="Reference audio (MP3 or WAV)", | |
| ) | |
| voice_ref_status = gr.Textbox( | |
| label="Status", interactive=False, lines=3 | |
| ) | |
| voice_ref_btn = gr.Button("💾 Set as Voice Reference", variant="secondary") | |
| voice_ref_btn.click( | |
| fn=set_voice_reference, | |
| inputs=[voice_ref_input], | |
| outputs=[voice_ref_status], | |
| ) | |
| gr.Markdown("---") | |
| # Per-session conversation history (not shared between users) | |
| conv_history = gr.State(value=[]) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| language_dd = gr.Dropdown( | |
| choices=list(SUPPORTED_LANGUAGES.keys()), | |
| value="Bambara (bam)", | |
| label="Language / Kan", | |
| ) | |
| audio_input = gr.Audio( | |
| sources=["microphone", "upload"], | |
| type="filepath", | |
| label="Record or upload audio", | |
| ) | |
| with gr.Row(): | |
| ask_btn = gr.Button("▶ Ask / Ɲinɛ", variant="primary") | |
| clear_btn = gr.Button("🗑 Clear", variant="secondary", size="sm") | |
| with gr.Column(scale=1): | |
| transcript_box = gr.Textbox( | |
| label="Whisper heard", | |
| lines=2, | |
| placeholder="Your words will appear here…", | |
| interactive=False, | |
| ) | |
| translation_box = gr.Textbox( | |
| label="English translation", | |
| lines=2, | |
| placeholder="(shown in sensor mode only)", | |
| interactive=False, | |
| ) | |
| response_box = gr.Textbox( | |
| label="AI response", | |
| lines=2, | |
| placeholder="Response will appear here…", | |
| interactive=False, | |
| ) | |
| audio_output = gr.Audio( | |
| label="Voice response", | |
| autoplay=True, | |
| interactive=False, | |
| ) | |
| correct_btn = gr.Button( | |
| "✏️ Something wrong? Send to Correction tab", | |
| variant="secondary", | |
| size="sm", | |
| ) | |
| # Conversation history display (Conversation Mode only) | |
| chatbot = gr.Chatbot( | |
| label="Conversation history", | |
| height=300, | |
| visible=False, | |
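| # "tuples" is deprecated in Gradio 5 but matches the [[user, assistant]] | |
| # pairs built by _do_respond | |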
| type="tuples", | |
| ) | |
| convo_mode_toggle.change( | |
| fn=lambda on: gr.update(visible=on), | |
| inputs=[convo_mode_toggle], | |
| outputs=[chatbot], | |
| ) | |
| # ── Stage 1 inputs/outputs (ASR only — fast) ───────────────── | |
| _s1_inputs = [audio_input, language_dd] | |
| _s1_outputs = [transcript_box] | |
| # ── Stage 2 inputs/outputs (LLM / sensor + TTS) ────────────── | |
| _s2_inputs = [transcript_box, language_dd, convo_mode_toggle, conv_history] | |
| _s2_outputs = [translation_box, response_box, audio_output, | |
| conv_history, chatbot] | |
| # Manual button: stage 1 then stage 2 | |
| ask_btn.click( | |
| fn=_do_asr, | |
| inputs=_s1_inputs, | |
| outputs=_s1_outputs, | |
| ).then( | |
| fn=_do_respond, | |
| inputs=_s2_inputs, | |
| outputs=_s2_outputs, | |
| ) | |
| # Auto-submit on mic stop: same chain, but stage 2 only runs when | |
| # convo_mode is ON (sensor mode has a manual button for deliberate use) | |
| audio_input.stop_recording( | |
| fn=_do_asr, | |
| inputs=_s1_inputs, | |
| outputs=_s1_outputs, | |
| ).then( | |
| fn=lambda t, ll, cm, h: _do_respond(t, ll, cm, h) if cm | |
| else ("", "", None, h, [[u, v] for u, v in (h or [])]), | |
| inputs=_s2_inputs, | |
| outputs=_s2_outputs, | |
| ) | |
| # Clear conversation | |
| clear_btn.click( | |
| fn=lambda: ([], []), | |
| outputs=[conv_history, chatbot], | |
| ) | |
| # ── Tab 2: Feedback & Correction ───────────────────────────────── | |
| with gr.TabItem("📝 Feedback & Correction", id="tab_feedback"): | |
| gr.Markdown( | |
| "Correct what Whisper heard, the English translation, and the response. " | |
| "All corrections are saved to the training dataset to improve future accuracy." | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| fb_lang = gr.Dropdown( | |
| choices=list(SUPPORTED_LANGUAGES.keys()), | |
| value="Bambara (bam)", | |
| label="Language", | |
| ) | |
| fb_audio = gr.Audio( | |
| sources=["microphone", "upload"], | |
| type="filepath", | |
| label="Audio", | |
| ) | |
| gr.Markdown("**Step 1 — Fix the transcription**") | |
| fb_transcript = gr.Textbox( | |
| label="What Whisper heard", | |
| lines=2, | |
| placeholder="Auto-filled from Tab 1…", | |
| ) | |
| fb_corrected = gr.Textbox( | |
| label="✏️ What was actually said (in Bambara/Fula)", | |
| lines=2, | |
| placeholder="Type the correct transcription here…", | |
| ) | |
| with gr.Column(): | |
| gr.Markdown("**Step 2 — Fix the English translation**") | |
| fb_english = gr.Textbox( | |
| label="Auto-generated English translation", | |
| lines=2, | |
| placeholder="Auto-filled from Tab 1…", | |
| ) | |
| fb_corrected_english = gr.Textbox( | |
| label="✏️ Correct English translation", | |
| lines=2, | |
| placeholder="Type the correct English meaning here…", | |
| ) | |
| gr.Markdown("**Step 3 — Fix the response**") | |
| fb_response = gr.Textbox( | |
| label="Auto-generated response", | |
| lines=2, | |
| placeholder="Auto-filled from Tab 1…", | |
| ) | |
| fb_corrected_response = gr.Textbox( | |
| label="✏️ Better response (in farmer's language)", | |
| lines=2, | |
| placeholder="Type a better response here…", | |
| ) | |
| fb_rating = gr.Slider( | |
| minimum=1, maximum=5, step=1, value=3, | |
| label="Overall quality (1 = poor, 5 = excellent)", | |
| ) | |
| fb_notes = gr.Textbox( | |
| label="Notes (optional)", | |
| lines=2, | |
| placeholder="e.g. noisy background, strong accent…", | |
| ) | |
| save_btn = gr.Button("💾 Save to Dataset", variant="primary") | |
| save_status = gr.Textbox( | |
| label="Save status", interactive=False, lines=2 | |
| ) | |
| save_btn.click( | |
| fn=_save_feedback_to_hub, | |
| inputs=[ | |
| fb_audio, fb_transcript, fb_corrected, | |
| fb_english, fb_corrected_english, | |
| fb_response, fb_corrected_response, | |
| fb_rating, fb_notes, fb_lang, | |
| ], | |
| outputs=[save_status], | |
| ) | |
| # Wire "Send to Correction" button — populates Tab 2 fields from Tab 1 | |
| correct_btn.click( | |
| fn=lambda t, tr, r, lang: (t, t, tr, tr, r, r, lang), | |
| inputs=[transcript_box, translation_box, response_box, language_dd], | |
| outputs=[fb_transcript, fb_corrected, fb_english, fb_corrected_english, | |
| fb_response, fb_corrected_response, fb_lang], | |
| ) | |
| # ── Tab 3: Knowledge Base ───────────────────────────────────────── | |
| with gr.TabItem("📚 Knowledge Base"): | |
| gr.Markdown( | |
| "## Teach the assistant new phrases — no technical knowledge required\n\n" | |
| "Add phrases the assistant should recognise and respond to. " | |
| "Changes take effect **immediately** and are saved to the Hub so they survive restarts." | |
| ) | |
| with gr.Row(): | |
| # ── Left: phrase pair import ────────────────────────────── | |
| with gr.Column(): | |
| gr.Markdown( | |
| "### ➕ Add phrases manually\n" | |
| "One phrase per line in the format:\n" | |
| "```\nnative phrase | English translation\n```\n" | |
| "**Examples (Bambara):**\n" | |
| "```\nI ni ce | Hello, good day\n" | |
| "Sanji bɛ na | Rain is coming\n" | |
| "N bɛ i dɛmɛ | I will help you\n```\n" | |
| "**Examples (Fula):**\n" | |
| "```\nJam waali | Hello, peace be with you\n" | |
| "Ndiyam wadata | Rain is coming\n" | |
| "Mi woni ɗoo | I am here\n```" | |
| ) | |
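| # Parsing sketch for the "native | english" lines above (assumption: the | |
| # actual splitting lives in _import_phrase_pairs, defined earlier), e.g.: | |
| #   native, english = (part.strip() for part in line.split("|", 1)) | |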
| kb_lang = gr.Dropdown( | |
| choices=["Bambara (bam)", "Fula (ful)"], | |
| value="Bambara (bam)", | |
| label="Language", | |
| ) | |
| kb_pairs = gr.Textbox( | |
| lines=10, | |
| placeholder="I ni ce | Hello, good day\nI ni sogoma | Good morning\nSanji bɛ na | Rain is coming", | |
| label="Phrase pairs (native | english) — one per line", | |
| ) | |
| kb_import_btn = gr.Button("➕ Add to Knowledge Base", variant="primary") | |
| kb_status = gr.Textbox(label="Status", interactive=False, lines=3) | |
| # ── Right: audio upload for training ───────────────────── | |
| with gr.Column(): | |
| gr.Markdown( | |
| "### 🎬 Add audio from YouTube (or anywhere)\n" | |
| "HuggingFace Spaces cannot download YouTube directly, " | |
| "so convert the video to audio first on your computer:\n\n" | |
| "**Free online converters:**\n" | |
| "- [ytmp3.cc](https://ytmp3.cc) — paste YouTube URL → download MP3\n" | |
| "- [cobalt.tools](https://cobalt.tools) — paste any video URL → download audio\n" | |
| "- [y2mate.com](https://y2mate.com) — paste YouTube URL → download MP3\n\n" | |
| "**Good YouTube search terms:**\n" | |
| "- Bambara: *'Bamanankan conversation'*, *'Bambara leçon'*, *'donsomana'*\n" | |
| "- Fula: *'Fulfulde leçon'*, *'Pular conversation'*, *'Fula radio'*\n\n" | |
| "Then upload the MP3/WAV file below with its transcription." | |
| ) | |
| yt_lang = gr.Dropdown( | |
| choices=["Bambara (bam)", "Fula (ful)"], | |
| value="Bambara (bam)", | |
| label="Language spoken in the audio", | |
| ) | |
| yt_audio = gr.Audio( | |
| sources=["upload"], | |
| type="filepath", | |
| label="Upload audio file (MP3 or WAV)", | |
| ) | |
| yt_transcript = gr.Textbox( | |
| lines=5, | |
| placeholder="Type what is said in the audio (as much as you can).\n" | |
| "Example:\nJam waali. No mbadda. Mi woni ɗoo wallude ma.", | |
| label="Transcription — what is said in this audio", | |
| ) | |
| yt_source = gr.Textbox( | |
| placeholder="e.g. YouTube: Bambara lesson by Moussa Kouyaté", | |
| label="Source (optional — for your records)", | |
| ) | |
| yt_btn = gr.Button("💾 Save Audio for Training", variant="secondary") | |
| yt_status = gr.Textbox(label="Status", interactive=False, lines=4) | |
| kb_import_btn.click( | |
| fn=_import_phrase_pairs, | |
| inputs=[kb_lang, kb_pairs], | |
| outputs=[kb_status], | |
| ) | |
| yt_btn.click( | |
| fn=_save_audio_for_training, | |
| inputs=[yt_lang, yt_audio, yt_transcript, yt_source], | |
| outputs=[yt_status], | |
| ) | |
| # ── Document upload (PDF / Word / TXT) ─────────────────────── | |
| gr.Markdown("---") | |
| gr.Markdown( | |
| "### 📄 Upload documents (PDF, Word, TXT)\n" | |
| "Extract sentences from books, articles, or lesson PDFs. " | |
| "Each sentence is added to the training vocabulary in the language you select below. " | |
| "**Upload one batch per language** — do not mix Bambara and Fula files in one upload." | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| doc_lang = gr.Dropdown( | |
| choices=["Bambara (bam)", "Fula (ful)"], | |
| value="Fula (ful)", | |
| label="Language of these documents", | |
| ) | |
| doc_files = gr.File( | |
| label="Upload .pdf, .docx, or .txt (multiple allowed)", | |
| file_count="multiple", | |
| file_types=[".pdf", ".docx", ".doc", ".txt", ".md"], | |
| ) | |
| doc_source = gr.Textbox( | |
| placeholder="e.g. SIL Pular grammar book, Labé lesson PDFs", | |
| label="Source note (optional — for your records)", | |
| ) | |
| doc_btn = gr.Button("📥 Extract & Add to Training Data", variant="primary") | |
| with gr.Column(): | |
| doc_status = gr.Textbox( | |
| label="Import status", | |
| interactive=False, | |
| lines=12, | |
| ) | |
| doc_btn.click( | |
| fn=_import_documents, | |
| inputs=[doc_lang, doc_files, doc_source], | |
| outputs=[doc_status], | |
| ) | |
| # ── Tab 4: Model Training ───────────────────────────────────────── | |
| with gr.TabItem("🔧 Model Training"): | |
| gr.Markdown( | |
| "After collecting audio corrections and YouTube samples, " | |
| "run the training notebook to fine-tune the speech model." | |
| ) | |
| adapter_status_md = gr.Markdown(value=_get_adapter_status()) | |
| reload_btn = gr.Button("🔄 Reload Fine-tuned Models from Hub") | |
| reload_out = gr.Markdown() | |
| gr.Markdown("---") | |
| gr.Markdown( | |
| "**Training notebook**: " | |
| "`notebooks/kaggle_master_trainer.ipynb` — import to Kaggle, run all cells.\n\n" | |
| "**What feeds training:**\n" | |
| "- Tab 2 corrections → `corrections.jsonl` in the feedback dataset\n" | |
| "- Tab 3 audio uploads → `corrections.jsonl` (same file)\n" | |
| "- Tab 3 phrase pairs → `vocabulary.jsonl` (used as synthetic fallback labels)\n\n" | |
| "**Feedback dataset**: " | |
| f"`{FEEDBACK_REPO_ID}` (auto-updated on each save)\n\n" | |
| "**Model checkpoint repo**: " | |
| f"`{ADAPTER_REPO_ID}` (updated after training, reload above to activate)" | |
| ) | |
| reload_btn.click(fn=_reload_adapters_from_hub, outputs=[reload_out]) | |
| reload_btn.click(fn=_get_adapter_status, outputs=[adapter_status_md]) | |
| # ── Tab 5: Bulk Upload ──────────────────────────────────────────── | |
| with gr.TabItem("📦 Bulk Upload"): | |
| gr.Markdown( | |
| "## Upload many audio samples at once\n\n" | |
| "**Step 1** — Prepare a ZIP file containing your audio files (WAV/MP3).\n\n" | |
| "**Step 2** — Prepare a CSV with two columns: `filename,transcription`\n" | |
| "```\nbam_001.wav,I ni ce a tɔ\nbam_002.wav,Sanji bɛ na sini\n```\n\n" | |
| "**Step 3** — Select language, upload ZIP, paste CSV, click Upload." | |
| ) | |
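| # CSV parsing sketch for the format above (assumption: the real logic lives | |
| # in _bulk_upload, defined earlier; splitting on the first comma only keeps | |
| # commas inside the transcription intact): | |
| #   fname, text = line.split(",", 1) | |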
| with gr.Row(): | |
| with gr.Column(): | |
| bulk_lang = gr.Dropdown( | |
| choices=["Bambara (bam)", "Fula (ful)"], | |
| value="Bambara (bam)", label="Language" | |
| ) | |
| bulk_zip = gr.File( | |
| label="ZIP file (audio files)", file_types=[".zip"] | |
| ) | |
| bulk_csv = gr.Textbox( | |
| lines=10, | |
| label="CSV — filename,transcription (one per line)", | |
| placeholder="bam_001.wav,I ni ce a tɔ\nbam_002.wav,Sanji bɛ na sini", | |
| ) | |
| bulk_btn = gr.Button("📤 Upload Batch", variant="primary") | |
| bulk_status = gr.Textbox(label="Status", interactive=False, lines=5) | |
| bulk_btn.click( | |
| fn=_bulk_upload, | |
| inputs=[bulk_lang, bulk_zip, bulk_csv], | |
| outputs=[bulk_status], | |
| ) | |
| # ── Tab 6: Self-Teaching ────────────────────────────────────────── | |
| with gr.TabItem("🌐 Self-Teaching"): | |
| gr.Markdown( | |
| "## Teach the model from the internet\n\n" | |
| "These tools pull publicly available Bambara and Fula language data " | |
| "directly into your training dataset — no manual work required." | |
| ) | |
| with gr.Row(): | |
| # Wikipedia harvest | |
| with gr.Column(): | |
| gr.Markdown( | |
| "### 📖 Wikipedia Text Harvest\n" | |
| "Pulls sentence-length text from Bambara Wikipedia (868 articles) " | |
| "or Fula Wikipedia (17,000+ articles) into `vocabulary.jsonl`.\n\n" | |
| "Use this to expand vocabulary coverage before a training run." | |
| ) | |
| wiki_lang = gr.Dropdown( | |
| choices=["Bambara (bam)", "Fula (ful)"], | |
| value="Bambara (bam)", label="Language" | |
| ) | |
| wiki_articles = gr.Slider( | |
| minimum=10, maximum=500, value=100, step=10, | |
| label="Max articles to fetch" | |
| ) | |
| wiki_btn = gr.Button("📖 Harvest Wikipedia Text", variant="secondary") | |
| wiki_status = gr.Textbox(label="Status", interactive=False, lines=4) | |
| wiki_btn.click( | |
| fn=_harvest_wikipedia, | |
| inputs=[wiki_lang, wiki_articles], | |
| outputs=[wiki_status], | |
| ) | |
| # HF dataset harvest | |
| with gr.Column(): | |
| gr.Markdown( | |
| "### 🤗 HuggingFace Dataset Import\n" | |
| "Registers large public datasets as training sources:\n" | |
| "- **Bambara**: `RobotsMali/jeli-asr` (33,000 samples)\n" | |
| "- **Fula**: `google/WaxalNLP ful_asr` + `Pullo-Africa-Protagonist/Fula-pular` (9,761 samples) + `guizme/adlam_fulfulde` (51 Adlam samples)\n\n" | |
| "This writes a reference to `dataset_sources.jsonl`. " | |
| "The Kaggle training notebook streams the dataset directly " | |
| "at training time — no re-upload needed.\n\n" | |
| "**One click is enough** — duplicates are ignored automatically." | |
| ) | |
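| # Streaming sketch of what the Kaggle notebook does with a registered source | |
| # (illustrative use of the datasets library, not code from this repo): | |
| #   from datasets import load_dataset | |
| #   ds = load_dataset(entry["repo"], entry.get("config") or None, | |
| #                     split="train", streaming=True) | |
| #   samples = list(itertools.islice(ds, entry["max"])) | |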
| hf_lang = gr.Dropdown( | |
| choices=["Bambara (bam)", "Fula (ful)"], | |
| value="Bambara (bam)", label="Language" | |
| ) | |
| hf_samples = gr.Slider( | |
| minimum=50, maximum=2000, value=500, step=50, | |
| label="Max samples to import" | |
| ) | |
| hf_btn = gr.Button("🤗 Import from HuggingFace", variant="primary") | |
| hf_status = gr.Textbox(label="Status", interactive=False, lines=5) | |
| hf_btn.click( | |
| fn=_harvest_hf_dataset, | |
| inputs=[hf_lang, hf_samples], | |
| outputs=[hf_status], | |
| ) | |
| gr.Markdown("---") | |
| gr.Markdown( | |
| "### ⚡ Auto-Training\n" | |
| f"When `corrections.jsonl` reaches a multiple of **{AUTO_TRAIN_THRESHOLD}** entries, " | |
| "the Kaggle training notebook is triggered automatically.\n\n" | |
| "To enable: add `KAGGLE_USERNAME` and `KAGGLE_KEY` in Space Settings → Secrets.\n\n" | |
| f"Kernel: `{KAGGLE_KERNEL_SLUG}`" | |
| ) | |
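| # Trigger-condition sketch (assumption: the real check runs in the feedback | |
| # save path, which is defined earlier in this file): | |
| #   if n_corrections % AUTO_TRAIN_THRESHOLD == 0 and KAGGLE_USERNAME and KAGGLE_KEY: | |
| #       _trigger_kaggle_training(lang) | |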
| with gr.Row(): | |
| trigger_lang = gr.Dropdown( | |
| choices=["Bambara (bam)", "Fula (ful)"], | |
| value="Bambara (bam)", label="Language to train" | |
| ) | |
| trigger_btn = gr.Button("⚡ Trigger Training Now", variant="secondary") | |
| trigger_out = gr.Textbox(label="Status", interactive=False, lines=2) | |
| trigger_btn.click( | |
| # trigger_lang uses the short "(bam)" / "(ful)" labels, which are not keys in | |
| # SUPPORTED_LANGUAGES, so .get() would always fall back to "bam"; extract the | |
| # code from the parentheses instead. | |
| fn=lambda l: _trigger_kaggle_training(l.split("(")[-1].rstrip(")")), | |
| inputs=[trigger_lang], | |
| outputs=[trigger_out], | |
| ) | |
| return demo | |
| # ── Entry point ─────────────────────────────────────────────────────────────── | |
| if __name__ == "__main__": | |
| try: | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| except ImportError: | |
| pass  # python-dotenv is optional; only needed to read a local .env file | |
| # Re-read env after dotenv | |
| HF_TOKEN = os.environ.get("HF_TOKEN") | |
| FEEDBACK_REPO_ID = os.environ.get("FEEDBACK_REPO_ID", "ous-sow/sahel-agri-feedback") | |
| ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID", "ous-sow/sahel-agri-adapters") | |
| WHISPER_MODEL_ID = os.environ.get("WHISPER_MODEL_ID", "openai/whisper-large-v3-turbo") | |
| LLM_MODEL_ID = os.environ.get("LLM_MODEL_ID", "Qwen/Qwen2.5-7B-Instruct") | |
| if HF_TOKEN: | |
| from huggingface_hub import HfApi | |
| _hf_api = HfApi(token=HF_TOKEN) | |
| # Load any previously saved phrase additions from HF Hub | |
| _load_phrase_additions_from_hub() | |
| # Kick off background model load immediately | |
| _ensure_whisper_loaded() | |
| print(f"Whisper model : {WHISPER_MODEL_ID}") | |
| print(f"LLM model : {LLM_MODEL_ID}") | |
| print(f"Feedback repo : {FEEDBACK_REPO_ID}") | |
| print(f"Adapter repo : {ADAPTER_REPO_ID}") | |
| print(f"HF_TOKEN set : {'yes' if HF_TOKEN else 'no (local-only mode)'}") | |
| print() | |
| demo = build_ui() | |
| demo.launch( | |
| server_port=7860, # HF Spaces standard port | |
| inbrowser=False, | |
| share=False, | |
| show_api=False, | |
| ssr_mode=False, # SSR starts a Node.js process that hangs in HF Spaces containers | |
| ) | |