File size: 2,842 Bytes
76db545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
Agricultural vocabulary for Bambara and Fula.
Used to bias the Whisper decoder toward domain-specific terms via decoder prompt injection.
"""
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from transformers import WhisperProcessor

# Bambara (bam) agricultural vocabulary.
# Keys: Bambara terms. Only the keys are used to build the decoder prompt —
# AgriculturalDictionary.get_prompt_text joins them with ", " in dict
# insertion order, so reordering or editing entries changes the tokenized prompt.
# Values: English glosses for human readers; the prompt-building code does not
# use them.
# NOTE(review): glosses are unverified here, and some entries pair two loosely
# related senses (e.g. "fanga": "strength/fertilizer") or overlap with other
# entries ("nɔgɔ" and "bunding" both cover soil) — confirm with a native speaker.
BAMBARA_VOCAB: dict[str, str] = {
    "sɛnɛ": "farming",
    "jiriw": "trees",
    "nɔgɔ": "soil",
    "sani": "fertilizer",
    "kogomali": "groundnut",
    "kaba": "corn/maize",
    "tiga": "peanut",
    "ji": "water",
    "sanji": "rain",
    "teliman": "weather",
    "suruku": "pest/predator",
    "bunding": "soil/earth",
    "sira": "path/way",
    "foro": "field",
    "dugu": "village/land",
    "dibi": "darkness/shade",
    "fanga": "strength/fertilizer",
    "kungoloni": "insects/pests",
}

# Fula (ful / Fulfulde) agricultural vocabulary.
# Keys: Fulfulde terms. Only the keys are used to build the decoder prompt —
# AgriculturalDictionary.get_prompt_text joins them with ", " in dict
# insertion order, so reordering or editing entries changes the tokenized prompt.
# Values: English glosses for human readers; the prompt-building code does not
# use them.
# NOTE(review): glosses are unverified here, and some combine unrelated senses
# (e.g. "wutte": "clothing/harvest", "fuɗorde": "sunrise/east field") — confirm
# with a native speaker. Keys contain non-ASCII letters (ɗ, ñ); keep their
# exact code points when editing, since they are tokenized verbatim.
FULA_VOCAB: dict[str, str] = {
    "ngesa": "field",
    "leydi": "land/soil",
    "kosam": "milk",
    "nagge": "cattle",
    "leeɗe": "crops",
    "ndiyam": "water",
    "yeeso": "wind/weather",
    "laabi": "road/way",
    "demoore": "farming",
    "hoore": "head/top",
    "biñ-biñ": "insects/pests",
    "fuɗorde": "sunrise/east field",
    "ngaari": "bull",
    "mbabba": "donkey",
    "ladde": "bush/forest",
    "wutte": "clothing/harvest",
}

# Registry from language code to its term -> English-gloss vocabulary.
# Every AgriculturalDictionary lookup goes through this table; to support a new
# language, define its vocabulary dict and register it here under its code.
LANGUAGE_VOCABS: dict[str, dict[str, str]] = dict(
    bam=BAMBARA_VOCAB,
    ful=FULA_VOCAB,
)


class AgriculturalDictionary:
    """Turn the per-language agricultural vocabularies into Whisper decoder prompts.

    The class holds no state: every method resolves the language code through
    the module-level ``LANGUAGE_VOCABS`` registry (e.g. ``"bam"``, ``"ful"``).
    """

    def _lookup(self, language: str) -> dict[str, str]:
        """Return the shared (not copied) vocabulary registered for *language*.

        Raises:
            ValueError: if *language* has no registered vocabulary.
        """
        try:
            return LANGUAGE_VOCABS[language]
        except KeyError:
            raise ValueError(
                f"No vocabulary for language '{language}'. Available: {list(LANGUAGE_VOCABS)}"
            ) from None

    def get_vocab(self, language: str) -> dict[str, str]:
        """Return a copy of the term -> English-gloss mapping for *language*.

        A copy is returned so that callers cannot accidentally mutate the shared
        module-level vocabulary, which would silently change every later prompt.

        Raises:
            ValueError: if *language* has no registered vocabulary.
        """
        return dict(self._lookup(language))

    def get_prompt_text(self, language: str) -> str:
        """Return all terms for *language* joined with ", ", in insertion order.

        Only the source-language terms (the dict keys) are included; the English
        glosses are not part of the prompt.

        Raises:
            ValueError: if *language* has no registered vocabulary.
        """
        # Iterating a dict yields its keys in insertion order.
        return ", ".join(self._lookup(language))

    def build_prompt_ids(self, processor: "WhisperProcessor", language: str) -> torch.Tensor:
        """Tokenize the prompt text as plain text, without any special tokens.

        Returns:
            An integer tensor of shape ``(1, N)`` holding the raw token IDs of
            :meth:`get_prompt_text`.

        Note:
            This is *not* the format ``generate(prompt_ids=...)`` expects: Whisper
            wants a rank-1 tensor that begins with ``<|startofprev|>``. Use
            :meth:`build_generation_prompt_ids` for that. The tensor also lacks the
            ``<|startoftranscript|>`` start token that ``decoder_input_ids``
            normally begins with.
        """
        prompt_text = self.get_prompt_text(language)
        token_ids = processor.tokenizer(
            prompt_text,
            return_tensors="pt",
            add_special_tokens=False,
        ).input_ids
        return token_ids  # shape: (1, N)

    def build_generation_prompt_ids(
        self, processor: "WhisperProcessor", language: str
    ) -> torch.Tensor:
        """Build a prompt tensor ready for ``model.generate(prompt_ids=...)``.

        Delegates to ``processor.get_prompt_ids``, which prepends the
        ``<|startofprev|>`` token and a leading space to the prompt text, as
        Whisper's previous-context prompting expects.

        Returns:
            A rank-1 integer tensor of token IDs.

        Raises:
            ValueError: if *language* has no registered vocabulary.

        NOTE(review): Whisper's decoder context is limited, so a very long
        vocabulary can make ``generate`` reject the prompt. Check the prompt
        length against ``model.config.max_target_positions``.
        """
        prompt_text = self.get_prompt_text(language)
        return processor.get_prompt_ids(prompt_text, return_tensors="pt")

    def get_token_ids(self, processor: "WhisperProcessor", language: str) -> list[int]:
        """Return the plain-text prompt token IDs as a flat Python list.

        Equivalent to ``build_prompt_ids(processor, language)[0].tolist()``; the
        list contains no special tokens.
        """
        return self.build_prompt_ids(processor, language)[0].tolist()