# Spaces: Running
# File size: 7,120 Bytes
"""Patch kaggle_master_trainer.ipynb: add bam_normalize, replace WER with CER.

Loads the notebook JSON, string-replaces known snippets in specific cells
(addressed by list index), writes the notebook back in place and prints one
status line per patch step.
"""
import json
import re
import sys

# Force UTF-8 console output: status lines can include repr() snippets of cell
# source, which may contain non-ASCII characters.
sys.stdout.reconfigure(encoding="utf-8")

NB = "notebooks/kaggle_master_trainer.ipynb"  # relative to the working directory

with open(NB, encoding="utf-8") as f:
    nb = json.load(f)

cells = nb["cells"]  # the patch steps below address cells by list index
changed = []         # one human-readable status line per patch step
# ── Cell 10 (idx=11): inject _bam_norm definition ────────────────────────────
# Inserts a Bambara spelling normaliser (_bam_norm) into the text-cleaning cell,
# right after its header comment and `import re, unicodedata` line.
old = "".join(cells[11]["source"])
# Match the import line together with its newline, so that a longer import line
# (e.g. "import re, unicodedata, string") is not treated as a match and split.
OLD_TOP = (
    "# -- Cell 10: Text cleaning utilities -----------------------------------------\n"
    "import re, unicodedata\n"
)
# NOTE(review): the ('oo','ɔ') and ('ee','ɛ') rules rewrite every doubled vowel.
# If the corpus writes long vowels as doubled letters, this merges them into the
# open vowels ɔ/ɛ. Confirm that this is intended.
NEW_TOP = (
    "# -- Cell 10: Text cleaning utilities + Bambara phonetic normaliser -----------\n"
    "import re, unicodedata\n"
    "\n"
    "# Phonetic normaliser: unifies French-influenced spellings before training.\n"
    "# ou->u, dj->j, gn->ny_palatal etc. so spelling variants map to same token.\n"
    "_BAM_NORM_RULES = [('ou','u'),('dj','j'),('gn','\u0272'),('ny','\u0272'),"
    "('ch','c'),('oo','\u0254'),('ee','\u025b')]\n"
    "_BAM_NORM_PAT = re.compile('|'.join(re.escape(s) for s,_ in _BAM_NORM_RULES))\n"
    "_BAM_NORM_MAP = {s:d for s,d in _BAM_NORM_RULES}\n"
    "\n"
    "def _bam_norm(text):\n"
    " import unicodedata as _ud\n"
    " text = _ud.normalize('NFC', text.lower())\n"
    " return _BAM_NORM_PAT.sub(lambda m: _BAM_NORM_MAP[m.group(0)], text)\n"
    "\n"
)
if "def _bam_norm(" in old:
    # Already patched by an earlier run.
    changed.append("Cell 10: _bam_norm already present - skip")
elif OLD_TOP in old:
    new = old.replace(OLD_TOP, NEW_TOP, 1)
    # Store one string per line, the layout Jupyter itself writes, so the cell
    # keeps line-granular diffs.
    cells[11]["source"] = new.splitlines(keepends=True)
    changed.append("Cell 10: _bam_norm injected")
else:
    changed.append("Cell 10: OLD_TOP not found - skip")
# ── Cell 11 (idx=12): apply _bam_norm in prepare_dataset ─────────────────────
# Routes Bambara ('bam') transcripts through _bam_norm before clean_text().
old = "".join(cells[12]["source"])
# Capture the matched line's own indentation, so that both inserted lines get
# exactly that indent. A fixed-indent literal would leave the second inserted
# line mis-indented whenever the real line is indented differently.
PREP_RE = re.compile(
    r"^([ \t]*)cleaned = clean_text\(str\(raw_text\), lang=lang\)", re.M
)


def _prep_repl(m):
    """Return the two replacement lines, both at the matched line's indent."""
    indent = m.group(1)
    return (
        f"{indent}_norm_text = _bam_norm(str(raw_text)) if lang == 'bam' else str(raw_text)\n"
        f"{indent}cleaned = clean_text(_norm_text, lang=lang)"
    )


# The inserted code calls _bam_norm, which must be defined in cell idx 11.
# Without that definition, prepare_dataset would raise NameError on 'bam' rows.
_bam_norm_defined = "def _bam_norm(" in "".join(cells[11]["source"])
new, n_hits = PREP_RE.subn(_prep_repl, old)
if n_hits and _bam_norm_defined:
    cells[12]["source"] = new.splitlines(keepends=True)
    changed.append("Cell 11: normaliser applied in prepare_dataset")
elif n_hits:
    changed.append("Cell 11: _bam_norm not defined in Cell 10 - skip")
elif "_bam_norm(" in old:
    changed.append("Cell 11: normaliser already applied - skip")
else:
    # Show nearby source to help update the pattern; find() returns -1 when the
    # word is absent, which would make the slice meaningless.
    pos = old.find("cleaned")
    context = repr(old[pos:pos + 60]) if pos >= 0 else "'cleaned' not found"
    changed.append(f"Cell 11: prepare pattern not found ({context})")
# ── Cell 14 (idx=17): WER -> CER in compute_metrics ──────────────────────────
# Adds a character-level jiwer transform and makes compute_metrics return both
# 'cer' and 'wer'. Each sub-step is reported separately, so a partial patch is
# visible in the output instead of being reported as success.
old = "".join(cells[17]["source"])
new = old.replace(
    "# -- Cell 14: Data collator + WER metric",
    "# -- Cell 14: Data collator + CER metric",
)

# jiwer validates that a transform's output is a list of lists of tokens. A CER
# transform must therefore end with ReduceToListOfListOfChars(); otherwise
# jiwer.cer raises ValueError on the first evaluation. Punctuation is removed
# *before* spaces are collapsed, so "a , b" does not leave a double space that
# would be scored as an extra character.
CER_TRANSFORM = (
    "# CER transform: normalise, then split into characters\n"
    "_cer_transform = jiwer.Compose([\n"
    " jiwer.ToLowerCase(),\n"
    " jiwer.RemovePunctuation(),\n"
    " jiwer.RemoveMultipleSpaces(),\n"
    " jiwer.Strip(),\n"
    " jiwer.ReduceToListOfListOfChars(),\n"
    "])"
)
# Transform text inserted by an earlier version of this script. It has no
# character-split step, so it is repaired in place if found.
BROKEN_CER_TRANSFORM = (
    "# CER transform (no word-split step needed)\n"
    "_cer_transform = jiwer.Compose([\n"
    " jiwer.ToLowerCase(),\n"
    " jiwer.RemoveMultipleSpaces(),\n"
    " jiwer.Strip(),\n"
    " jiwer.RemovePunctuation(),\n"
    "])"
)
OLD_TRANSFORM_END = " jiwer.ReduceToListOfListOfWords(),\n])"

steps = []
if BROKEN_CER_TRANSFORM in new:
    new = new.replace(BROKEN_CER_TRANSFORM, CER_TRANSFORM)
    steps.append("CER transform repaired")
elif "_cer_transform = " in new:
    # Guard against re-inserting on reruns: the inserted text itself still
    # contains OLD_TRANSFORM_END.
    steps.append("CER transform already present")
elif OLD_TRANSFORM_END in new:
    # Insert once, right after the end of the WER transform definition.
    new = new.replace(
        OLD_TRANSFORM_END, OLD_TRANSFORM_END + "\n\n" + CER_TRANSFORM, 1
    )
    steps.append("CER transform added")
else:
    steps.append("WER transform end NOT found")

# Replace the return value in compute_metrics.
OLD_RETURN = (
    " wer = jiwer.wer(label_str, pred_str,\n"
    " hypothesis_transform=transform,\n"
    " reference_transform=transform)\n"
    " return {'wer': round(wer, 4)}"
)
NEW_RETURN = (
    " cer = jiwer.cer(\n"
    " label_str, pred_str,\n"
    " reference_transform=_cer_transform,\n"
    " hypothesis_transform=_cer_transform,\n"
    " )\n"
    " wer = jiwer.wer(label_str, pred_str,\n"
    " hypothesis_transform=transform,\n"
    " reference_transform=transform)\n"
    " return {'cer': round(cer, 4), 'wer': round(wer, 4)}"
)
if OLD_RETURN in new:
    new = new.replace(OLD_RETURN, NEW_RETURN)
    steps.append("compute_metrics returns cer+wer")
elif "'cer': round(cer, 4)" in new:
    steps.append("compute_metrics already returns cer")
else:
    steps.append("compute_metrics return NOT found")

if new != old:
    cells[17]["source"] = new.splitlines(keepends=True)
changed.append("Cell 14: " + "; ".join(steps))
# ── Cell 15 (idx=19): metric_for_best_model ───────────────────────────────────
# Selects the best checkpoint by CER instead of WER.
old = "".join(cells[19]["source"])
# The Trainer looks up "eval_" + metric_for_best_model among the evaluation
# metrics. Only switch once compute_metrics (cell idx 17) really returns 'cer';
# otherwise checkpoint selection fails at the first evaluation.
_metrics_src = "".join(cells[17]["source"])
if "'cer':" not in _metrics_src:
    changed.append("Cell 15: compute_metrics has no 'cer' key - skip")
else:
    new = old.replace(
        " metric_for_best_model='wer',",
        " metric_for_best_model='cer',",
    )
    if new != old:
        cells[19]["source"] = new.splitlines(keepends=True)
        changed.append("Cell 15: metric_for_best_model=cer")
    else:
        changed.append("Cell 15: no change")
    # For metrics whose name does not end in "loss", TrainingArguments defaults
    # greater_is_better to True. An error rate must be minimised, so flag a
    # missing explicit False rather than editing the arguments blindly.
    if not re.search(r"greater_is_better\s*=\s*False", new):
        changed.append(
            "Cell 15: WARNING - greater_is_better=False not found; "
            "lower CER must be treated as better"
        )
# ── Cell 17 (idx=22): CER display in evaluation ───────────────────────────────
# Defines cer_score from eval_results and prints CER as the primary metric, with
# WER as secondary. The later commit-message and summary patches read cer_score.
old = "".join(cells[22]["source"])
OLD_WER_PRINT = (
    "wer_score = eval_results.get('eval_wer', float('nan'))\n"
    "print(f'\\n? Final WER : {wer_score:.1%}')\n"
    "print(f' Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')"
)
NEW_WER_PRINT = (
    "cer_score = eval_results.get('eval_cer', float('nan'))\n"
    "wer_score = eval_results.get('eval_wer', float('nan'))\n"
    "print(f'\\n\u2705 Final CER : {cer_score:.1%} (primary — lower is better)')\n"
    "print(f' Final WER : {wer_score:.1%} (secondary)')\n"
    "print(f' Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')"
)
if OLD_WER_PRINT in old:
    new = old.replace(OLD_WER_PRINT, NEW_WER_PRINT)
    cells[22]["source"] = new.splitlines(keepends=True)
    changed.append("Cell 17: CER display")
elif "cer_score = " in old:
    changed.append("Cell 17: CER display already present - skip")
else:
    changed.append("Cell 17: print pattern not found")
    # Show what is actually there, to help update OLD_WER_PRINT.
    idx = old.find("wer_score")
    if idx >= 0:
        changed.append(f" ...found: {repr(old[idx:idx + 100])}")
# ── Cell 19 push (idx=25): cer_score in commit msg ───────────────────────────
# Reports CER instead of WER in the push commit message. The new text reads
# cer_score, which must be defined by the evaluation cell (idx=22).
# (In the patched strings, `x == x` is False only for NaN, which shows as 'n/a'.)
old = "".join(cells[25]["source"])
OLD_PART = "_wer_part = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'"
NEW_PART = "_cer_part = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'"
OLD_MSG = "f'{train_result.global_step} steps | WER {_wer_part} | '"
NEW_MSG = "f'{train_result.global_step} steps | CER {_cer_part} | '"
if "cer_score = " not in "".join(cells[22]["source"]):
    # Without the definition, the patched cell would raise NameError on cer_score.
    changed.append("Cell 19: cer_score not defined in Cell 17 - skip")
elif OLD_PART in old and OLD_MSG in old:
    # Rename the variable and its use together; replacing only one of them
    # would leave a reference to an undefined name.
    new = old.replace(OLD_PART, NEW_PART).replace(OLD_MSG, NEW_MSG)
    cells[25]["source"] = new.splitlines(keepends=True)
    changed.append("Cell 19: CER in commit msg")
elif NEW_PART in old:
    changed.append("Cell 19: already uses CER - skip")
else:
    changed.append("Cell 19: no change")
# ── Cell 20 summary (idx=26) ─────────────────────────────────────────────────
# Adds CER (primary) next to WER (secondary) in the final summary printout. The
# new text reads cer_score, which must be defined by the evaluation cell (idx=22).
old = "".join(cells[26]["source"])
OLD_SUMMARY = (
    "_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\n"
    "print(f' Eval WER : {_wer_disp}')"
)
NEW_SUMMARY = (
    "_cer_disp = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'\n"
    "_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\n"
    "print(f' Eval CER (primary) : {_cer_disp}')\n"
    "print(f' Eval WER (secondary): {_wer_disp}')"
)
if "cer_score = " not in "".join(cells[22]["source"]):
    # Without the definition, the patched cell would raise NameError on cer_score.
    changed.append("Cell 20: cer_score not defined in Cell 17 - skip")
elif OLD_SUMMARY in old:
    new = old.replace(OLD_SUMMARY, NEW_SUMMARY)
    cells[26]["source"] = new.splitlines(keepends=True)
    changed.append("Cell 20: CER in summary")
elif "_cer_disp = " in old:
    changed.append("Cell 20: CER already in summary - skip")
else:
    changed.append("Cell 20: no change")
# Write the patched notebook back in place, then print the per-step report.
# newline="\n" keeps LF line endings on Windows, where text mode would otherwise
# rewrite every line of the notebook as CRLF. The file also ends with a newline,
# as nbformat's own writer produces.
with open(NB, "w", encoding="utf-8", newline="\n") as f:
    json.dump(nb, f, ensure_ascii=False, indent=1)
    f.write("\n")
for msg in changed:
    print(msg)
print("Done.")