"""
Agricultural vocabulary for Bambara and Fula.

Used to bias the Whisper decoder toward domain-specific terms via decoder
prompt injection.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from transformers import WhisperProcessor

# Only the keys (the native-language terms) are placed in the decoder prompt;
# the English glosses are informational and never reach the model.
# NOTE(review): several glosses look doubtful (e.g. "sani", "teliman",
# "yeeso", "wutte") -- verify with a native speaker before relying on them.

# Bambara (bam) agricultural vocabulary
BAMBARA_VOCAB: dict[str, str] = {
    "sɛnɛ": "farming",
    "jiriw": "trees",
    "nɔgɔ": "soil",
    "sani": "fertilizer",
    "kogomali": "groundnut",
    "kaba": "corn/maize",
    "tiga": "peanut",
    "ji": "water",
    "sanji": "rain",
    "teliman": "weather",
    "suruku": "pest/predator",
    "bunding": "soil/earth",
    "sira": "path/way",
    "foro": "field",
    "dugu": "village/land",
    "dibi": "darkness/shade",
    "fanga": "strength/fertilizer",
    "kungoloni": "insects/pests",
}

# Fula (ful / Fulfulde) agricultural vocabulary
FULA_VOCAB: dict[str, str] = {
    "ngesa": "field",
    "leydi": "land/soil",
    "kosam": "milk",
    "nagge": "cattle",
    "leeɗe": "crops",
    "ndiyam": "water",
    "yeeso": "wind/weather",
    "laabi": "road/way",
    "demoore": "farming",
    "hoore": "head/top",
    "biñ-biñ": "insects/pests",
    "fuɗorde": "sunrise/east field",
    "ngaari": "bull",
    "mbabba": "donkey",
    "ladde": "bush/forest",
    "wutte": "clothing/harvest",
}

# Language code -> {term: English gloss}.
LANGUAGE_VOCABS: dict[str, dict[str, str]] = {
    "bam": BAMBARA_VOCAB,
    "ful": FULA_VOCAB,
}


class AgriculturalDictionary:
    """Converts agricultural vocabulary into decoder prompt token IDs for Whisper."""

    def get_vocab(self, language: str) -> dict[str, str]:
        """Return a copy of the ``{term: English gloss}`` vocabulary for *language*.

        A fresh dict is returned so that callers cannot accidentally mutate the
        shared module-level tables seen by every other user of this module.

        Raises:
            ValueError: if *language* is not a key of ``LANGUAGE_VOCABS``.
        """
        try:
            vocab = LANGUAGE_VOCABS[language]
        except KeyError:
            raise ValueError(
                f"No vocabulary for language '{language}'. "
                f"Available: {list(LANGUAGE_VOCABS)}"
            ) from None
        return dict(vocab)

    def get_prompt_text(self, language: str) -> str:
        """Return a comma-joined string of all terms, used as decoder text prompt."""
        # Iterating a dict yields its keys in insertion order; glosses are not used.
        return ", ".join(self.get_vocab(language))

    def build_prompt_ids(self, processor: "WhisperProcessor", language: str) -> torch.Tensor:
        """
        Tokenize the vocabulary prompt text into raw token IDs.

        The text from ``get_prompt_text`` is encoded with
        ``add_special_tokens=False``, so the result contains none of Whisper's
        special tokens (no ``<|startofprev|>``, no start-of-transcript,
        language or task tokens). It is therefore NOT directly usable as
        ``prompt_ids`` or ``decoder_input_ids`` for ``model.generate()``; use
        ``build_generate_prompt_ids`` for generation-time prompting.

        Returns:
            ``input_ids`` tensor of shape ``(1, N)``.
        """
        prompt_text = self.get_prompt_text(language)
        token_ids = processor.tokenizer(
            prompt_text,
            return_tensors="pt",
            add_special_tokens=False,
        ).input_ids
        return token_ids  # shape: (1, N)

    def build_generate_prompt_ids(
        self, processor: "WhisperProcessor", language: str
    ) -> torch.Tensor:
        """
        Return vocabulary prompt IDs in the form Whisper's ``generate`` expects.

        Delegates to ``processor.get_prompt_ids``, which (per the transformers
        docs) prepends the ``<|startofprev|>`` token and returns a rank-1
        tensor suitable for ``model.generate(input_features, prompt_ids=...)``
        to bias decoding toward the agricultural terms.

        NOTE(review): move the tensor to the model's device if your
        transformers version does not do so inside ``generate``.
        """
        return processor.get_prompt_ids(self.get_prompt_text(language), return_tensors="pt")

    def get_token_ids(self, processor: "WhisperProcessor", language: str) -> list[int]:
        """Return flat list of token IDs for all vocabulary terms (no special tokens)."""
        return self.build_prompt_ids(processor, language)[0].tolist()