ground-zero / scripts /patch_notebook_cer.py
jefffffff9
Phase 3: Voice-to-Voice S2S pipeline — F5-TTS, LLM brain, CER metric
8952fff
"""Patch kaggle_master_trainer.ipynb: add bam_normalize, replace WER with CER."""
import json, sys, re
# Switch console output to UTF-8 so that printing the status lines (some of
# which embed repr() snippets of notebook text) does not depend on the
# console's default encoding.  `reconfigure` is a method of io.TextIOWrapper;
# a replaced stdout object may not provide it, so only call it when present
# rather than failing with AttributeError before any work is done.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8")

# Notebook that is patched in place; the path is relative to the current
# working directory.
NB = "notebooks/kaggle_master_trainer.ipynb"

with open(NB, encoding="utf-8") as f:
    nb = json.load(f)

# `cells` is the same list object that lives inside `nb`, so assigning to
# cells[i]["source"] below changes what is written back when `nb` is dumped.
cells = nb["cells"]

# Status lines collected by each patch step and printed at the end of the run.
changed = []
# ── Cell 10 (idx=11): inject _bam_norm definition ────────────────────────────
# Anchor text: the cell's header comment followed by its first import line.
OLD_TOP = (
    "# -- Cell 10: Text cleaning utilities -----------------------------------------\n"
    "import re, unicodedata"
)
# Replacement for the anchor: a renamed header, the same import line, and then
# the normaliser itself -- a table of (spelling, replacement) pairs, a compiled
# alternation over the spellings, a lookup dict, and `_bam_norm`, which
# lower-cases the text, NFC-normalises it and substitutes every match.
NEW_TOP = (
    "# -- Cell 10: Text cleaning utilities + Bambara phonetic normaliser -----------\n"
    "import re, unicodedata\n"
    "\n"
    "# Phonetic normaliser: unifies French-influenced spellings before training.\n"
    "# ou->u, dj->j, gn->ny_palatal etc. so spelling variants map to same token.\n"
    "_BAM_NORM_RULES = [('ou','u'),('dj','j'),('gn','\u0272'),('ny','\u0272'),"
    "('ch','c'),('oo','\u0254'),('ee','\u025b')]\n"
    "_BAM_NORM_PAT = re.compile('|'.join(re.escape(s) for s,_ in _BAM_NORM_RULES))\n"
    "_BAM_NORM_MAP = {s:d for s,d in _BAM_NORM_RULES}\n"
    "\n"
    "def _bam_norm(text):\n"
    " import unicodedata as _ud\n"
    " text = _ud.normalize('NFC', text.lower())\n"
    " return _BAM_NORM_PAT.sub(lambda m: _BAM_NORM_MAP[m.group(0)], text)\n"
)

# The cell's "source" is joined into one string so the multi-line anchor can
# be matched, and is stored back as a single-element list.
cell10_src = "".join(cells[11]["source"])
if OLD_TOP not in cell10_src:
    changed.append("Cell 10: OLD_TOP not found - skip")
else:
    cells[11]["source"] = [cell10_src.replace(OLD_TOP, NEW_TOP)]
    changed.append("Cell 10: _bam_norm injected")
# ── Cell 11 (idx=12): apply _bam_norm in prepare_dataset ─────────────────────
# Rewrites the line that cleans the raw transcript so that, for lang == 'bam',
# the text first passes through `_bam_norm`; other languages are passed to
# `clean_text` unchanged.
old = "".join(cells[12]["source"])
OLD_PREP = " cleaned = clean_text(str(raw_text), lang=lang)"
NEW_PREP = (
    " _norm_text = _bam_norm(str(raw_text)) if lang == 'bam' else str(raw_text)\n"
    " cleaned = clean_text(_norm_text, lang=lang)"
)
if OLD_PREP in old:
    cells[12]["source"] = [old.replace(OLD_PREP, NEW_PREP)]
    changed.append("Cell 11: normaliser applied in prepare_dataset")
else:
    # Diagnostic: show what follows the first 'cleaned' in the cell so the
    # pattern can be adjusted.  str.find returns -1 when the word is absent;
    # slicing from -1 would yield an unrelated or empty snippet, so that case
    # is reported explicitly instead.
    pos = old.find("cleaned")
    if pos >= 0:
        changed.append(f"Cell 11: prepare pattern not found ({old[pos:pos + 60]!r})")
    else:
        changed.append("Cell 11: prepare pattern not found (no 'cleaned' in cell)")
# ── Cell 14 (idx=17): WER -> CER in compute_metrics ──────────────────────────
# Applies three text substitutions to the metrics cell:
#   1. rename the header comment,
#   2. append a `_cer_transform` definition after the existing transform,
#   3. make compute_metrics return both 'cer' and 'wer'.
# Each substitution is skipped when its result is already present, so running
# the script twice does not insert a second `_cer_transform` block, and any
# pattern that is neither present nor already patched is reported.
old = "".join(cells[17]["source"])

OLD_HEADER = "# -- Cell 14: Data collator + WER metric"
NEW_HEADER = "# -- Cell 14: Data collator + CER metric"

# Add CER transform after existing transform definition.  A custom transform
# passed to jiwer.cer has to end with ReduceToListOfListOfChars (the character
# counterpart of the ReduceToListOfListOfWords step used for WER); without it
# jiwer rejects the transformed input.
OLD_TRANSFORM_END = " jiwer.ReduceToListOfListOfWords(),\n])"
NEW_TRANSFORM_END = (
    " jiwer.ReduceToListOfListOfWords(),\n"
    "])\n"
    "\n"
    "# CER transform (ends with a char-split step instead of a word-split step)\n"
    "_cer_transform = jiwer.Compose([\n"
    " jiwer.ToLowerCase(),\n"
    " jiwer.RemoveMultipleSpaces(),\n"
    " jiwer.Strip(),\n"
    " jiwer.RemovePunctuation(),\n"
    " jiwer.ReduceToListOfListOfChars(),\n"
    "])"
)

# Replace return value in compute_metrics
OLD_RETURN = (
    " wer = jiwer.wer(label_str, pred_str,\n"
    " hypothesis_transform=transform,\n"
    " reference_transform=transform)\n"
    " return {'wer': round(wer, 4)}"
)
NEW_RETURN = (
    " cer = jiwer.cer(\n"
    " label_str, pred_str,\n"
    " reference_transform=_cer_transform,\n"
    " hypothesis_transform=_cer_transform,\n"
    " )\n"
    " wer = jiwer.wer(label_str, pred_str,\n"
    " hypothesis_transform=transform,\n"
    " reference_transform=transform)\n"
    " return {'cer': round(cer, 4), 'wer': round(wer, 4)}"
)

# (label, text to find, replacement, marker proving it is already applied).
# The transform marker is needed because OLD_TRANSFORM_END is a prefix of
# NEW_TRANSFORM_END and would therefore match again on a patched cell.
_cell14_steps = (
    ("header comment", OLD_HEADER, NEW_HEADER, NEW_HEADER),
    ("transform definition", OLD_TRANSFORM_END, NEW_TRANSFORM_END,
     "_cer_transform = jiwer.Compose("),
    ("compute_metrics return", OLD_RETURN, NEW_RETURN, NEW_RETURN),
)

new = old
_cell14_missing = []
for _label, _before, _after, _marker in _cell14_steps:
    if _marker in new:
        # Already patched: leave this part of the cell alone.
        continue
    if _before not in new:
        _cell14_missing.append(_label)
        continue
    new = new.replace(_before, _after)

if new != old:
    cells[17]["source"] = [new]
    changed.append("Cell 14: WER->CER in compute_metrics")
else:
    changed.append("Cell 14: no changes applied")
if _cell14_missing:
    # A partially patched cell would otherwise be reported as a full success.
    changed.append(
        " Cell 14: WARNING pattern(s) not found: " + ", ".join(_cell14_missing)
    )
# ── Cell 15 (idx=19): metric_for_best_model ───────────────────────────────────
# Switches the keyword argument that names the model-selection metric from
# 'wer' to 'cer'.  Every occurrence in the cell is replaced.
_wer_selector = " metric_for_best_model='wer',"
_cer_selector = " metric_for_best_model='cer',"

cell15_src = "".join(cells[19]["source"])
if _wer_selector in cell15_src:
    cells[19]["source"] = [cell15_src.replace(_wer_selector, _cer_selector)]
    changed.append("Cell 15: metric_for_best_model=cer")
else:
    changed.append("Cell 15: no change")
# ── Cell 17 (idx=22): CER display in evaluation ───────────────────────────────
# The three-line block that reads 'eval_wer' and prints WER plus eval loss is
# replaced by one that also reads 'eval_cer' and prints CER first.
OLD_WER_PRINT = (
    "wer_score = eval_results.get('eval_wer', float('nan'))\n"
    "print(f'\\n? Final WER : {wer_score:.1%}')\n"
    "print(f' Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')"
)
NEW_WER_PRINT = (
    "cer_score = eval_results.get('eval_cer', float('nan'))\n"
    "wer_score = eval_results.get('eval_wer', float('nan'))\n"
    "print(f'\\n\u2705 Final CER : {cer_score:.1%} (primary — lower is better)')\n"
    "print(f' Final WER : {wer_score:.1%} (secondary)')\n"
    "print(f' Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')"
)

cell17_src = "".join(cells[22]["source"])
if OLD_WER_PRINT not in cell17_src:
    changed.append("Cell 17: print pattern not found")
    # Help with debugging: quote up to 100 characters starting at the first
    # mention of wer_score, if the cell mentions it at all.
    hit = cell17_src.find("wer_score")
    if hit != -1:
        changed.append(f" ...found: {cell17_src[hit:hit + 100]!r}")
else:
    cells[22]["source"] = [cell17_src.replace(OLD_WER_PRINT, NEW_WER_PRINT)]
    changed.append("Cell 17: CER display")
# ── Cell 19 push (idx=25): cer_score in commit msg ───────────────────────────
# Two edits that only make sense together: the line defining `_wer_part` is
# renamed to define `_cer_part` (formatted from `cer_score`), and the commit
# message fragment that interpolates `_wer_part` is switched to `_cer_part`.
# Applying just one of them would leave the notebook cell referring to a name
# it no longer defines, so the cell is patched only when both are found.
# NOTE: the patched text reads `cer_score`, which this script introduces into
# the notebook through its NEW_WER_PRINT replacement for cell index 22.
old = "".join(cells[25]["source"])

OLD_PART_DEF = "_wer_part = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'"
NEW_PART_DEF = "_cer_part = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'"
OLD_PART_USE = "f'{train_result.global_step} steps | WER {_wer_part} | '"
NEW_PART_USE = "f'{train_result.global_step} steps | CER {_cer_part} | '"

_has_def = OLD_PART_DEF in old
_has_use = OLD_PART_USE in old

if _has_def and _has_use:
    new = old.replace(OLD_PART_DEF, NEW_PART_DEF).replace(OLD_PART_USE, NEW_PART_USE)
    cells[25]["source"] = [new]
    changed.append("Cell 19: CER in commit msg")
elif _has_def or _has_use:
    # Half a match: report it and leave the cell exactly as it was.
    _found = "definition" if _has_def else "usage"
    changed.append(
        f"Cell 19: skipped - only the _wer_part {_found} was found; cell left unchanged"
    )
else:
    changed.append("Cell 19: no change")
# ── Cell 20 summary (idx=26) ─────────────────────────────────────────────────
# The two-line WER summary (format `_wer_disp`, print it) becomes a four-line
# summary that formats and prints CER first and WER second.
_old_summary = (
    "_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\n"
    "print(f' Eval WER : {_wer_disp}')"
)
_new_summary = (
    "_cer_disp = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'\n"
    "_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\n"
    "print(f' Eval CER (primary) : {_cer_disp}')\n"
    "print(f' Eval WER (secondary): {_wer_disp}')"
)

cell20_src = "".join(cells[26]["source"])
if _old_summary in cell20_src:
    cells[26]["source"] = [cell20_src.replace(_old_summary, _new_summary)]
    changed.append("Cell 20: CER in summary")
else:
    changed.append("Cell 20: no change")
# Write the notebook back over the original file.
# Serialise to a string first: opening the file in "w" mode truncates it
# immediately, so an error raised while encoding must happen before the open
# call, otherwise the notebook on disk would be left truncated.
# ensure_ascii=False keeps non-ASCII characters as-is instead of \u escapes.
_nb_text = json.dumps(nb, ensure_ascii=False, indent=1)
with open(NB, "w", encoding="utf-8") as f:
    f.write(_nb_text)

# Report what each patch step did, in the order the steps ran.
for msg in changed:
    print(msg)
print("Done.")