Spaces:
Running
Running
| """ | |
| Agricultural vocabulary for Bambara and Fula. | |
| Used to bias the Whisper decoder toward domain-specific terms via decoder prompt injection. | |
| """ | |
| from __future__ import annotations | |
| from typing import TYPE_CHECKING | |
| import torch | |
| if TYPE_CHECKING: | |
| from transformers import WhisperProcessor | |
# Bambara (bam) agricultural vocabulary.
# Keys are Bambara terms; values are English glosses (as supplied by the
# original author — not independently verified here). Insertion order matters:
# AgriculturalDictionary.get_prompt_text() joins the keys in this order to
# build the decoder prompt, so reordering entries changes the prompt text.
BAMBARA_VOCAB: dict[str, str] = {
    "sɛnɛ": "farming",
    "jiriw": "trees",
    "nɔgɔ": "soil",
    "sani": "fertilizer",
    "kogomali": "groundnut",
    "kaba": "corn/maize",
    "tiga": "peanut",
    "ji": "water",
    "sanji": "rain",
    "teliman": "weather",
    "suruku": "pest/predator",
    "bunding": "soil/earth",
    "sira": "path/way",
    "foro": "field",
    "dugu": "village/land",
    "dibi": "darkness/shade",
    "fanga": "strength/fertilizer",
    "kungoloni": "insects/pests",
}
# Fula (ful / Fulfulde) agricultural vocabulary.
# Keys are Fula terms; values are English glosses (as supplied by the
# original author — not independently verified here). Insertion order matters:
# AgriculturalDictionary.get_prompt_text() joins the keys in this order to
# build the decoder prompt, so reordering entries changes the prompt text.
FULA_VOCAB: dict[str, str] = {
    "ngesa": "field",
    "leydi": "land/soil",
    "kosam": "milk",
    "nagge": "cattle",
    "leeɗe": "crops",
    "ndiyam": "water",
    "yeeso": "wind/weather",
    "laabi": "road/way",
    "demoore": "farming",
    "hoore": "head/top",
    "biñ-biñ": "insects/pests",
    "fuɗorde": "sunrise/east field",
    "ngaari": "bull",
    "mbabba": "donkey",
    "ladde": "bush/forest",
    "wutte": "clothing/harvest",
}
# Registry mapping a language code ("bam" for Bambara, "ful" for Fula)
# to its vocabulary dict. AgriculturalDictionary.get_vocab() looks up
# languages here and raises ValueError for codes not present.
LANGUAGE_VOCABS: dict[str, dict[str, str]] = {
    "bam": BAMBARA_VOCAB,
    "ful": FULA_VOCAB,
}
class AgriculturalDictionary:
    """Turns the agricultural word lists into Whisper decoder prompt material.

    Stateless helper: every method takes the language code explicitly and
    reads from the module-level ``LANGUAGE_VOCABS`` registry.
    """

    def get_vocab(self, language: str) -> dict[str, str]:
        """Return the term -> gloss mapping for *language*.

        Raises:
            ValueError: if *language* has no registered vocabulary.
        """
        try:
            return LANGUAGE_VOCABS[language]
        except KeyError:
            raise ValueError(f"No vocabulary for language '{language}'. Available: {list(LANGUAGE_VOCABS)}") from None

    def get_prompt_text(self, language: str) -> str:
        """Return a comma-joined string of all terms, used as decoder text prompt."""
        terms = list(self.get_vocab(language))
        return ", ".join(terms)

    def build_prompt_ids(self, processor: "WhisperProcessor", language: str) -> torch.Tensor:
        """Tokenize the vocabulary as a decoder prompt.

        Returns a (1, N) tensor of token IDs, intended to be passed as
        `decoder_input_ids` or `prompt_ids` to model.generate() to bias
        decoding toward known agricultural terms.

        NOTE(review): HF's `processor.get_prompt_ids()` also prepends the
        `<|startofprev|>` token (and returns a 1-D tensor); confirm which
        form the caller's generate() invocation expects.
        """
        prompt = self.get_prompt_text(language)
        encoded = processor.tokenizer(
            prompt,
            add_special_tokens=False,
            return_tensors="pt",
        )
        return encoded.input_ids  # shape: (1, N)

    def get_token_ids(self, processor: "WhisperProcessor", language: str) -> list[int]:
        """Return flat list of token IDs for all vocabulary terms."""
        batch = self.build_prompt_ids(processor, language)
        return batch[0].tolist()