File size: 7,120 Bytes
8952fff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
"""Patch kaggle_master_trainer.ipynb: add bam_normalize, replace WER with CER."""
import json, sys, re
sys.stdout.reconfigure(encoding="utf-8")

NB = "notebooks/kaggle_master_trainer.ipynb"

with open(NB, encoding="utf-8") as f:
    nb = json.load(f)
cells = nb["cells"]

changed = []

# ── Cell 10 (idx=11): inject _bam_norm definition ────────────────────────────
# Replace the cell's header block with one that also defines the Bambara
# phonetic normaliser used by prepare_dataset (patched in the next section).
src10 = "".join(cells[11]["source"])
OLD_TOP = (
    "# -- Cell 10: Text cleaning utilities -----------------------------------------\n"
    "import re, unicodedata"
)
NEW_TOP = (
    "# -- Cell 10: Text cleaning utilities + Bambara phonetic normaliser -----------\n"
    "import re, unicodedata\n"
    "\n"
    "# Phonetic normaliser: unifies French-influenced spellings before training.\n"
    "# ou->u, dj->j, gn->ny_palatal etc. so spelling variants map to same token.\n"
    "_BAM_NORM_RULES = [('ou','u'),('dj','j'),('gn','\u0272'),('ny','\u0272'),"
    "('ch','c'),('oo','\u0254'),('ee','\u025b')]\n"
    "_BAM_NORM_PAT = re.compile('|'.join(re.escape(s) for s,_ in _BAM_NORM_RULES))\n"
    "_BAM_NORM_MAP = {s:d for s,d in _BAM_NORM_RULES}\n"
    "\n"
    "def _bam_norm(text):\n"
    "    import unicodedata as _ud\n"
    "    text = _ud.normalize('NFC', text.lower())\n"
    "    return _BAM_NORM_PAT.sub(lambda m: _BAM_NORM_MAP[m.group(0)], text)\n"
)
if OLD_TOP not in src10:
    changed.append("Cell 10: OLD_TOP not found - skip")
else:
    cells[11]["source"] = [src10.replace(OLD_TOP, NEW_TOP)]
    changed.append("Cell 10: _bam_norm injected")

# ── Cell 11 (idx=12): apply _bam_norm in prepare_dataset ─────────────────────
# Route Bambara text through _bam_norm before the generic clean_text pass.
old = "".join(cells[12]["source"])
OLD_PREP = "    cleaned         = clean_text(str(raw_text), lang=lang)"
NEW_PREP = (
    "    _norm_text      = _bam_norm(str(raw_text)) if lang == 'bam' else str(raw_text)\n"
    "    cleaned         = clean_text(_norm_text, lang=lang)"
)
if OLD_PREP in old:
    cells[12]["source"] = [old.replace(OLD_PREP, NEW_PREP)]
    changed.append("Cell 11: normaliser applied in prepare_dataset")
else:
    # Report context around the expected anchor to aid debugging.
    # Guard against str.find returning -1 (anchor absent): the original
    # sliced old[-1:59], which yields a bogus/misleading snippet.
    idx = old.find("cleaned")
    snippet = repr(old[idx:idx + 60]) if idx >= 0 else "anchor 'cleaned' absent"
    changed.append(f"Cell 11: prepare pattern not found ({snippet})")

# ── Cell 14 (idx=17): WER -> CER in compute_metrics ──────────────────────────
# Three coordinated replacements: the header comment, a CER jiwer transform
# appended after the WER transform, and the compute_metrics return value.
# Each replacement is tracked individually so a partially-applied patch
# (e.g. only the header hit) is reported instead of silently passing —
# the original reported success if *any* of the three replacements matched.
old = "".join(cells[17]["source"])
OLD_TRANSFORM_END = "    jiwer.ReduceToListOfListOfWords(),\n])"
NEW_TRANSFORM_END = (
    "    jiwer.ReduceToListOfListOfWords(),\n"
    "])\n"
    "\n"
    "# CER transform (no word-split step needed)\n"
    "_cer_transform = jiwer.Compose([\n"
    "    jiwer.ToLowerCase(),\n"
    "    jiwer.RemoveMultipleSpaces(),\n"
    "    jiwer.Strip(),\n"
    "    jiwer.RemovePunctuation(),\n"
    "])"
)
OLD_RETURN = (
    "    wer = jiwer.wer(label_str, pred_str,\n"
    "                    hypothesis_transform=transform,\n"
    "                    reference_transform=transform)\n"
    "    return {'wer': round(wer, 4)}"
)
NEW_RETURN = (
    "    cer = jiwer.cer(\n"
    "        label_str, pred_str,\n"
    "        reference_transform=_cer_transform,\n"
    "        hypothesis_transform=_cer_transform,\n"
    "    )\n"
    "    wer = jiwer.wer(label_str, pred_str,\n"
    "                    hypothesis_transform=transform,\n"
    "                    reference_transform=transform)\n"
    "    return {'cer': round(cer, 4), 'wer': round(wer, 4)}"
)
new = old
missed = []
for tag, pat, rep in (
    ("header", "# -- Cell 14: Data collator + WER metric",
               "# -- Cell 14: Data collator + CER metric"),
    ("transform", OLD_TRANSFORM_END, NEW_TRANSFORM_END),
    ("return", OLD_RETURN, NEW_RETURN),
):
    if pat in new:
        new = new.replace(pat, rep)
    else:
        missed.append(tag)
if new != old:
    cells[17]["source"] = [new]
    if missed:
        changed.append(f"Cell 14: partial WER->CER (missed: {', '.join(missed)})")
    else:
        changed.append("Cell 14: WER->CER in compute_metrics")
else:
    changed.append("Cell 14: no changes applied")

# ── Cell 15 (idx=19): metric_for_best_model ───────────────────────────────────
# Switch the TrainingArguments best-model criterion from WER to CER.
src15 = "".join(cells[19]["source"])
patched15 = src15.replace(
    "    metric_for_best_model='wer',",
    "    metric_for_best_model='cer',"
)
if patched15 == src15:
    changed.append("Cell 15: no change")
else:
    cells[19]["source"] = [patched15]
    changed.append("Cell 15: metric_for_best_model=cer")

# ── Cell 17 (idx=22): CER display in evaluation ───────────────────────────────
# Replace the WER-only result printout with a CER-primary / WER-secondary one.
# NOTE(review): the '?' in OLD_WER_PRINT looks like a mojibake'd emoji in the
# notebook source — the pattern must stay byte-identical to match; confirm
# against the actual cell if the fallback branch fires.
src17 = "".join(cells[22]["source"])
OLD_WER_PRINT = (
    "wer_score = eval_results.get('eval_wer', float('nan'))\n"
    "print(f'\\n? Final WER : {wer_score:.1%}')\n"
    "print(f'   Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')"
)
NEW_WER_PRINT = (
    "cer_score = eval_results.get('eval_cer', float('nan'))\n"
    "wer_score = eval_results.get('eval_wer', float('nan'))\n"
    "print(f'\\n\u2705 Final CER : {cer_score:.1%}  (primary — lower is better)')\n"
    "print(f'   Final WER : {wer_score:.1%}  (secondary)')\n"
    "print(f'   Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')"
)
if OLD_WER_PRINT in src17:
    cells[22]["source"] = [src17.replace(OLD_WER_PRINT, NEW_WER_PRINT)]
    changed.append("Cell 17: CER display")
else:
    changed.append("Cell 17: print pattern not found")
    # Show what actually surrounds 'wer_score' so the anchor can be fixed.
    hit = src17.find("wer_score")
    if hit >= 0:
        changed.append(f"  ...found: {repr(src17[hit:hit+100])}")

# ── Cell 19 push (idx=25): cer_score in commit msg ───────────────────────────
# Two replacements: the NaN-guard helper variable and the commit-message
# f-string itself. (`x == x` is the NaN check: NaN != NaN.)
src19 = "".join(cells[25]["source"])
patched19 = src19.replace(
    "_wer_part = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'",
    "_cer_part = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'",
)
patched19 = patched19.replace(
    "f'{train_result.global_step} steps | WER {_wer_part} | '",
    "f'{train_result.global_step} steps | CER {_cer_part} | '",
)
if patched19 == src19:
    changed.append("Cell 19: no change")
else:
    cells[25]["source"] = [patched19]
    changed.append("Cell 19: CER in commit msg")

# ── Cell 20 summary (idx=26) ─────────────────────────────────────────────────
# Expand the single WER summary line into CER (primary) + WER (secondary).
OLD_SUMMARY = (
    "_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\n"
    "print(f'  Eval WER          : {_wer_disp}')"
)
NEW_SUMMARY = (
    "_cer_disp = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'\n"
    "_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\n"
    "print(f'  Eval CER (primary)  : {_cer_disp}')\n"
    "print(f'  Eval WER (secondary): {_wer_disp}')"
)
src20 = "".join(cells[26]["source"])
patched20 = src20.replace(OLD_SUMMARY, NEW_SUMMARY)
if patched20 == src20:
    changed.append("Cell 20: no change")
else:
    cells[26]["source"] = [patched20]
    changed.append("Cell 20: CER in summary")

with open(NB, "w", encoding="utf-8") as f:
    json.dump(nb, f, ensure_ascii=False, indent=1)

for msg in changed:
    print(msg)
print("Done.")