Spaces:
Running
Running
| """Patch kaggle_master_trainer.ipynb: add bam_normalize, replace WER with CER.""" | |
| import json, sys, re | |
| sys.stdout.reconfigure(encoding="utf-8") | |
| NB = "notebooks/kaggle_master_trainer.ipynb" | |
| with open(NB, encoding="utf-8") as f: | |
| nb = json.load(f) | |
| cells = nb["cells"] | |
| changed = [] | |
# ── Cell 10 (idx=11): inject _bam_norm definition ────────────────────────────
# The text-cleaning cell is located by its header comment plus first import
# line. That anchor is replaced by a new header, the same import line, and the
# Bambara phonetic-normaliser definitions (rules table, regex, _bam_norm).
cleaning_src = "".join(cells[11]["source"])
OLD_TOP = (
    "# -- Cell 10: Text cleaning utilities -----------------------------------------\n"
    "import re, unicodedata"
)
NEW_TOP = (
    "# -- Cell 10: Text cleaning utilities + Bambara phonetic normaliser -----------\n"
    "import re, unicodedata\n"
    "\n"
    "# Phonetic normaliser: unifies French-influenced spellings before training.\n"
    "# ou->u, dj->j, gn->ny_palatal etc. so spelling variants map to same token.\n"
    "_BAM_NORM_RULES = [('ou','u'),('dj','j'),('gn','\u0272'),('ny','\u0272'),"
    "('ch','c'),('oo','\u0254'),('ee','\u025b')]\n"
    "_BAM_NORM_PAT = re.compile('|'.join(re.escape(s) for s,_ in _BAM_NORM_RULES))\n"
    "_BAM_NORM_MAP = {s:d for s,d in _BAM_NORM_RULES}\n"
    "\n"
    "def _bam_norm(text):\n"
    " import unicodedata as _ud\n"
    " text = _ud.normalize('NFC', text.lower())\n"
    " return _BAM_NORM_PAT.sub(lambda m: _BAM_NORM_MAP[m.group(0)], text)\n"
)
if OLD_TOP not in cleaning_src:
    changed.append("Cell 10: OLD_TOP not found - skip")
else:
    cells[11]["source"] = [cleaning_src.replace(OLD_TOP, NEW_TOP)]
    changed.append("Cell 10: _bam_norm injected")
# ── Cell 11 (idx=12): apply _bam_norm in prepare_dataset ─────────────────────
# Route Bambara ('bam') transcripts through _bam_norm before clean_text runs.
# The call is injected only when some cell defines _bam_norm; otherwise
# prepare_dataset would raise NameError on the first Bambara sample.
old = "".join(cells[12]["source"])
OLD_PREP = " cleaned = clean_text(str(raw_text), lang=lang)"
NEW_PREP = (
    " _norm_text = _bam_norm(str(raw_text)) if lang == 'bam' else str(raw_text)\n"
    " cleaned = clean_text(_norm_text, lang=lang)"
)
_bam_norm_defined = any(
    "def _bam_norm(" in "".join(cell["source"]) for cell in cells
)
if "_bam_norm(str(raw_text))" in old:
    changed.append("Cell 11: normaliser already applied - skip")
elif not _bam_norm_defined:
    changed.append("Cell 11: _bam_norm is not defined in any cell - skip")
elif OLD_PREP in old:
    cells[12]["source"] = [old.replace(OLD_PREP, NEW_PREP)]
    changed.append("Cell 11: normaliser applied in prepare_dataset")
else:
    # Show the text around 'cleaned' so OLD_PREP can be corrected by hand.
    # str.find returns -1 when absent, which must not be used as a slice start.
    _at = old.find("cleaned")
    _ctx = repr(old[_at:_at + 60]) if _at >= 0 else "'cleaned' not in cell"
    changed.append(f"Cell 11: prepare pattern not found ({_ctx})")
# ── Cell 14 (idx=17): WER -> CER in compute_metrics ──────────────────────────
# Make compute_metrics report CER (primary) alongside WER. The _cer_transform
# injection and the return rewrite are applied together or not at all, because
# the new return references _cer_transform. A second run is a no-op, because
# OLD_TRANSFORM_END is a prefix of NEW_TRANSFORM_END and would otherwise match
# again.
old = "".join(cells[17]["source"])
OLD_TRANSFORM_END = " jiwer.ReduceToListOfListOfWords(),\n])"
# jiwer.cer (jiwer 3.x, which the existing reference_transform= keyword implies)
# raises ValueError unless the transform yields a list of character lists, so
# the CER transform must end with ReduceToListOfListOfChars. Punctuation is
# removed before whitespace is collapsed/stripped so that leftover spaces are
# not scored as characters.
NEW_TRANSFORM_END = (
    " jiwer.ReduceToListOfListOfWords(),\n"
    "])\n"
    "\n"
    "# CER transform: normalise, then split into characters (required by jiwer.cer)\n"
    "_cer_transform = jiwer.Compose([\n"
    " jiwer.ToLowerCase(),\n"
    " jiwer.RemovePunctuation(),\n"
    " jiwer.RemoveMultipleSpaces(),\n"
    " jiwer.Strip(),\n"
    " jiwer.ReduceToListOfListOfChars(),\n"
    "])"
)
OLD_RETURN = (
    " wer = jiwer.wer(label_str, pred_str,\n"
    " hypothesis_transform=transform,\n"
    " reference_transform=transform)\n"
    " return {'wer': round(wer, 4)}"
)
NEW_RETURN = (
    " cer = jiwer.cer(\n"
    " label_str, pred_str,\n"
    " reference_transform=_cer_transform,\n"
    " hypothesis_transform=_cer_transform,\n"
    " )\n"
    " wer = jiwer.wer(label_str, pred_str,\n"
    " hypothesis_transform=transform,\n"
    " reference_transform=transform)\n"
    " return {'cer': round(cer, 4), 'wer': round(wer, 4)}"
)
if "_cer_transform" in old:
    changed.append("Cell 14: CER metric already present - skip")
elif OLD_TRANSFORM_END in old and OLD_RETURN in old:
    new = (
        old.replace(
            "# -- Cell 14: Data collator + WER metric",
            "# -- Cell 14: Data collator + CER metric",
        )
        .replace(OLD_TRANSFORM_END, NEW_TRANSFORM_END, 1)
        .replace(OLD_RETURN, NEW_RETURN)
    )
    cells[17]["source"] = [new]
    changed.append("Cell 14: WER->CER in compute_metrics")
else:
    # Name the anchor(s) that failed so the patterns can be fixed by hand.
    _missing = [
        label
        for label, pattern in (("transform", OLD_TRANSFORM_END), ("return", OLD_RETURN))
        if pattern not in old
    ]
    changed.append(f"Cell 14: pattern not found ({', '.join(_missing)}) - skip")
# ── Cell 15 (idx=19): metric_for_best_model ───────────────────────────────────
# Select the best checkpoint by CER instead of WER. The Trainer looks up
# "eval_" + metric_for_best_model in the evaluation metrics, so the switch is
# made only when compute_metrics (idx=17) returns a 'cer' key.
old = "".join(cells[19]["source"])
_metrics_src = "".join(cells[17]["source"])
if "'cer'" not in _metrics_src and '"cer"' not in _metrics_src:
    changed.append("Cell 15: compute_metrics returns no 'cer' - skip")
else:
    new = old.replace(
        " metric_for_best_model='wer',",
        " metric_for_best_model='cer',",
    )
    if new != old:
        cells[19]["source"] = [new]
        changed.append("Cell 15: metric_for_best_model=cer")
    else:
        changed.append("Cell 15: no change")
    # Lower CER is better, but TrainingArguments defaults greater_is_better to
    # True for any metric whose name does not end in "loss".
    if "greater_is_better" not in new:
        changed.append(
            "Cell 15: WARNING - greater_is_better not set; "
            "Trainer will treat a higher CER as better"
        )
# ── Cell 17 (idx=22): CER display in evaluation ───────────────────────────────
# Replace the final-evaluation printout so that it defines cer_score from
# eval_results['eval_cer'] and prints CER first, then WER and eval loss.
eval_src = "".join(cells[22]["source"])
OLD_WER_PRINT = (
    "wer_score = eval_results.get('eval_wer', float('nan'))\n"
    "print(f'\\n? Final WER : {wer_score:.1%}')\n"
    "print(f' Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')"
)
NEW_WER_PRINT = (
    "cer_score = eval_results.get('eval_cer', float('nan'))\n"
    "wer_score = eval_results.get('eval_wer', float('nan'))\n"
    "print(f'\\n\u2705 Final CER : {cer_score:.1%} (primary — lower is better)')\n"
    "print(f' Final WER : {wer_score:.1%} (secondary)')\n"
    "print(f' Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')"
)
if OLD_WER_PRINT not in eval_src:
    changed.append("Cell 17: print pattern not found")
    # Echo the nearest anchor so the expected pattern can be fixed by hand.
    hit = eval_src.find("wer_score")
    if hit != -1:
        changed.append(f" ...found: {eval_src[hit:hit+100]!r}")
else:
    cells[22]["source"] = [eval_src.replace(OLD_WER_PRINT, NEW_WER_PRINT)]
    changed.append("Cell 17: CER display")
# ── Cell 19 push (idx=25): cer_score in commit msg ───────────────────────────
# Report CER instead of WER in the Hub commit message. The formatted-figure
# definition and its use in the message are swapped together, so the cell never
# references an undefined _cer_part/_wer_part. The swap happens only when an
# earlier cell assigns cer_score.
old = "".join(cells[25]["source"])
_OLD_PART = "_wer_part = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'"
_NEW_PART = "_cer_part = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'"
_OLD_MSG = "f'{train_result.global_step} steps | WER {_wer_part} | '"
_NEW_MSG = "f'{train_result.global_step} steps | CER {_cer_part} | '"
# Match a real assignment ("cer_score = ..."), not a comparison ("cer_score ==").
_cer_defined = any(
    re.search(r"(?m)^\s*cer_score\s*=(?!=)", "".join(c["source"]))
    for c in cells[:25]
)
if not _cer_defined:
    changed.append("Cell 19: cer_score not defined before this cell - skip")
elif _OLD_PART in old and _OLD_MSG in old:
    new = old.replace(_OLD_PART, _NEW_PART).replace(_OLD_MSG, _NEW_MSG)
    if "_wer_part" in new:
        # _wer_part is still used elsewhere in the cell; removing its
        # definition would raise NameError, so leave the cell untouched.
        changed.append("Cell 19: _wer_part still referenced - skip")
    else:
        cells[25]["source"] = [new]
        changed.append("Cell 19: CER in commit msg")
else:
    changed.append("Cell 19: no change")
# ── Cell 20 summary (idx=26) ─────────────────────────────────────────────────
# Show CER (primary) next to WER (secondary) in the final summary. The new lines
# read cer_score, so the patch is skipped unless an earlier cell assigns it.
old = "".join(cells[26]["source"])
# Match a real assignment ("cer_score = ..."), not a comparison ("cer_score ==").
_cer_defined = any(
    re.search(r"(?m)^\s*cer_score\s*=(?!=)", "".join(c["source"]))
    for c in cells[:26]
)
if not _cer_defined:
    changed.append("Cell 20: cer_score not defined before this cell - skip")
else:
    new = old.replace(
        "_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\n"
        "print(f' Eval WER : {_wer_disp}')",
        "_cer_disp = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'\n"
        "_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\n"
        "print(f' Eval CER (primary) : {_cer_disp}')\n"
        "print(f' Eval WER (secondary): {_wer_disp}')",
    )
    if new != old:
        cells[26]["source"] = [new]
        changed.append("Cell 20: CER in summary")
    else:
        changed.append("Cell 20: no change")
# Write the patched notebook back in place. Serialising to a string before the
# file is opened (and truncated) means a serialisation error cannot leave a
# half-written notebook on disk. The trailing newline matches what nbformat /
# Jupyter emit when saving.
serialized = json.dumps(nb, ensure_ascii=False, indent=1) + "\n"
with open(NB, "w", encoding="utf-8") as f:
    f.write(serialized)
for msg in changed:
    print(msg)
print("Done.")