"""Patch kaggle_master_trainer.ipynb: add bam_normalize, replace WER with CER.""" import json, sys, re sys.stdout.reconfigure(encoding="utf-8") NB = "notebooks/kaggle_master_trainer.ipynb" with open(NB, encoding="utf-8") as f: nb = json.load(f) cells = nb["cells"] changed = [] # ── Cell 10 (idx=11): inject _bam_norm definition ──────────────────────────── old = "".join(cells[11]["source"]) OLD_TOP = ( "# -- Cell 10: Text cleaning utilities -----------------------------------------\n" "import re, unicodedata" ) NEW_TOP = ( "# -- Cell 10: Text cleaning utilities + Bambara phonetic normaliser -----------\n" "import re, unicodedata\n" "\n" "# Phonetic normaliser: unifies French-influenced spellings before training.\n" "# ou->u, dj->j, gn->ny_palatal etc. so spelling variants map to same token.\n" "_BAM_NORM_RULES = [('ou','u'),('dj','j'),('gn','\u0272'),('ny','\u0272')," "('ch','c'),('oo','\u0254'),('ee','\u025b')]\n" "_BAM_NORM_PAT = re.compile('|'.join(re.escape(s) for s,_ in _BAM_NORM_RULES))\n" "_BAM_NORM_MAP = {s:d for s,d in _BAM_NORM_RULES}\n" "\n" "def _bam_norm(text):\n" " import unicodedata as _ud\n" " text = _ud.normalize('NFC', text.lower())\n" " return _BAM_NORM_PAT.sub(lambda m: _BAM_NORM_MAP[m.group(0)], text)\n" ) if OLD_TOP in old: cells[11]["source"] = [old.replace(OLD_TOP, NEW_TOP)] changed.append("Cell 10: _bam_norm injected") else: changed.append("Cell 10: OLD_TOP not found - skip") # ── Cell 11 (idx=12): apply _bam_norm in prepare_dataset ───────────────────── old = "".join(cells[12]["source"]) OLD_PREP = " cleaned = clean_text(str(raw_text), lang=lang)" NEW_PREP = ( " _norm_text = _bam_norm(str(raw_text)) if lang == 'bam' else str(raw_text)\n" " cleaned = clean_text(_norm_text, lang=lang)" ) if OLD_PREP in old: cells[12]["source"] = [old.replace(OLD_PREP, NEW_PREP)] changed.append("Cell 11: normaliser applied in prepare_dataset") else: changed.append(f"Cell 11: prepare pattern not found ({repr(old[old.find('cleaned'):old.find('cleaned')+60])})") # ── Cell 14 (idx=17): WER -> CER in compute_metrics ────────────────────────── old = "".join(cells[17]["source"]) # Replace header comment new = old.replace( "# -- Cell 14: Data collator + WER metric", "# -- Cell 14: Data collator + CER metric" ) # Add CER transform after existing transform definition OLD_TRANSFORM_END = " jiwer.ReduceToListOfListOfWords(),\n])" NEW_TRANSFORM_END = ( " jiwer.ReduceToListOfListOfWords(),\n" "])\n" "\n" "# CER transform (no word-split step needed)\n" "_cer_transform = jiwer.Compose([\n" " jiwer.ToLowerCase(),\n" " jiwer.RemoveMultipleSpaces(),\n" " jiwer.Strip(),\n" " jiwer.RemovePunctuation(),\n" "])" ) new = new.replace(OLD_TRANSFORM_END, NEW_TRANSFORM_END) # Replace return value in compute_metrics OLD_RETURN = ( " wer = jiwer.wer(label_str, pred_str,\n" " hypothesis_transform=transform,\n" " reference_transform=transform)\n" " return {'wer': round(wer, 4)}" ) NEW_RETURN = ( " cer = jiwer.cer(\n" " label_str, pred_str,\n" " reference_transform=_cer_transform,\n" " hypothesis_transform=_cer_transform,\n" " )\n" " wer = jiwer.wer(label_str, pred_str,\n" " hypothesis_transform=transform,\n" " reference_transform=transform)\n" " return {'cer': round(cer, 4), 'wer': round(wer, 4)}" ) new = new.replace(OLD_RETURN, NEW_RETURN) if new != old: cells[17]["source"] = [new] changed.append("Cell 14: WER->CER in compute_metrics") else: changed.append("Cell 14: no changes applied") # ── Cell 15 (idx=19): metric_for_best_model ─────────────────────────────────── old = "".join(cells[19]["source"]) new = old.replace( " metric_for_best_model='wer',", " metric_for_best_model='cer'," ) if new != old: cells[19]["source"] = [new] changed.append("Cell 15: metric_for_best_model=cer") else: changed.append("Cell 15: no change") # ── Cell 17 (idx=22): CER display in evaluation ─────────────────────────────── old = "".join(cells[22]["source"]) OLD_WER_PRINT = ( "wer_score = eval_results.get('eval_wer', float('nan'))\n" "print(f'\\n? Final WER : {wer_score:.1%}')\n" "print(f' Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')" ) NEW_WER_PRINT = ( "cer_score = eval_results.get('eval_cer', float('nan'))\n" "wer_score = eval_results.get('eval_wer', float('nan'))\n" "print(f'\\n\u2705 Final CER : {cer_score:.1%} (primary — lower is better)')\n" "print(f' Final WER : {wer_score:.1%} (secondary)')\n" "print(f' Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')" ) if OLD_WER_PRINT in old: cells[22]["source"] = [old.replace(OLD_WER_PRINT, NEW_WER_PRINT)] changed.append("Cell 17: CER display") else: changed.append("Cell 17: print pattern not found") # Try to find what's there idx = old.find("wer_score") if idx >= 0: changed.append(f" ...found: {repr(old[idx:idx+100])}") # ── Cell 19 push (idx=25): cer_score in commit msg ─────────────────────────── old = "".join(cells[25]["source"]) new = ( old .replace( "_wer_part = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'", "_cer_part = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'" ) .replace( "f'{train_result.global_step} steps | WER {_wer_part} | '", "f'{train_result.global_step} steps | CER {_cer_part} | '" ) ) if new != old: cells[25]["source"] = [new] changed.append("Cell 19: CER in commit msg") else: changed.append("Cell 19: no change") # ── Cell 20 summary (idx=26) ───────────────────────────────────────────────── old = "".join(cells[26]["source"]) new = old.replace( "_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\n" "print(f' Eval WER : {_wer_disp}')", "_cer_disp = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'\n" "_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\n" "print(f' Eval CER (primary) : {_cer_disp}')\n" "print(f' Eval WER (secondary): {_wer_disp}')" ) if new != old: cells[26]["source"] = [new] changed.append("Cell 20: CER in summary") else: changed.append("Cell 20: no change") with open(NB, "w", encoding="utf-8") as f: json.dump(nb, f, ensure_ascii=False, indent=1) for msg in changed: print(msg) print("Done.")