Buckets:
| # /// script | |
| # requires-python = ">=3.11" | |
| # dependencies = [ | |
| # "jiwer>=3.0", | |
| # ] | |
| # /// | |
| """Metrics for the DiffusionGemma vs Gemma-4 post-OCR correction benchmark. | |
| Computes CER/WER (jiwer), relative CER reduction, over-correction rate and | |
| fix rate from the raw_outputs.jsonl produced by benchmark.py. No GPU, no | |
| model downloads — metric bugs never require re-running generation. | |
| Usage: | |
| uv run metrics.py test | |
| uv run metrics.py summarize --infile results/raw_outputs.jsonl --outdir results/ | |
| """ | |
| import argparse | |
| import difflib | |
| import json | |
| import re | |
| import statistics | |
| import unicodedata | |
| from pathlib import Path | |
| import jiwer | |
# Single source of truth for the benchmarked models:
# (key used in raw_outputs.jsonl, short metric-column suffix, display label).
# Row order is the order models appear in the summary table.
_MODEL_TABLE = (
    ("diffusiongemma", "dg", "DiffusionGemma 26B-A4B-it"),
    ("diffusiongemma_canvas", "dgc", "DiffusionGemma (OCR-seeded canvas)"),
    ("gemma4", "g4", "Gemma-4-E4B-it"),
    ("gemma4_moe", "g4moe", "Gemma-4-26B-A4B-it (MoE)"),
)

# Model keys, in reporting order.
MODELS = [key for key, _suffix, _label in _MODEL_TABLE]

# Model key -> suffix appended to per-passage metric names (e.g. "cer_dg").
SUFFIX = {key: suffix for key, suffix, _label in _MODEL_TABLE}

# Model key -> human-readable label used in the Markdown summary.
MODEL_LABELS = {key: label for key, _suffix, label in _MODEL_TABLE}
def normalize(text: str) -> str:
    """Return *text* in NFC form with whitespace runs collapsed to one space.

    Leading and trailing whitespace is removed. The normalisation is lenient
    on purpose: layout/whitespace reconstruction is not what this benchmark
    measures.
    """
    composed = unicodedata.normalize("NFC", text)
    single_spaced = re.sub(r"\s+", " ", composed)
    return single_spaced.strip()
| def _equal_input_positions(a: str, b: str) -> set[int]: | |
| """Indices of `a` that survive unchanged in `b`, per SequenceMatcher. | |
| autojunk must stay off: it silently degrades alignments on texts >~200 chars.""" | |
| sm = difflib.SequenceMatcher(None, a, b, autojunk=False) | |
| positions: set[int] = set() | |
| for op, i1, i2, _j1, _j2 in sm.get_opcodes(): | |
| if op == "equal": | |
| positions.update(range(i1, i2)) | |
| return positions | |
def overcorrection_and_fix_rate(
    ocr: str, gold: str, output: str
) -> tuple[float | None, float | None]:
    """Compute character-level over-correction rate and fix rate.

    All bookkeeping is done on indices of the OCR input, so a model output of
    a different length needs no special handling:

        C     = input positions already correct (input<->gold alignment)
        P     = input positions preserved in the output (input<->output alignment)
        wrong = every input position not in C

        over-correction = |C - P| / |C|          (already-correct chars the model changed)
        fix rate        = |wrong - P| / |wrong|  (wrong chars the model changed)

    Returns:
        ``(over_correction, fix_rate)``; either entry is None when its
        denominator is 0.
    """
    already_correct = _equal_input_positions(ocr, gold)
    kept_by_model = _equal_input_positions(ocr, output)
    erroneous = set(range(len(ocr))) - already_correct

    overcorr = None
    if already_correct:
        overcorr = len(already_correct - kept_by_model) / len(already_correct)

    fixrate = None
    if erroneous:
        fixrate = len(erroneous - kept_by_model) / len(erroneous)

    return overcorr, fixrate
def passage_metrics(record: dict) -> dict:
    """Build the per-passage metric row for one raw_outputs.jsonl record.

    The row holds the record id, the normalised gold length, input CER/WER
    and, for every model in MODELS that has an output in the record, its
    CER/WER, relative CER reduction, over-correction rate, fix rate and timing
    figures. No text fields are copied into the row, so it is safe to publish.
    """
    ocr = normalize(record["ocr_input"])
    gold = normalize(record["gold"])
    row: dict = {"id": record["id"], "n_chars_gold": len(gold)}

    # An empty hypothesis is scored as 100% error instead of being sent to jiwer.
    if ocr:
        row["cer_input"] = jiwer.cer(gold, ocr)
        row["wer_input"] = jiwer.wer(gold, ocr)
    else:
        row["cer_input"] = row["wer_input"] = 1.0
    baseline_cer = row["cer_input"]

    outputs = record["output"]
    for model in MODELS:
        result = outputs.get(model)
        if result is None:
            # Model was not run for this passage.
            continue

        hyp = normalize(result["text"])
        tag = SUFFIX[model]

        if hyp:
            cer = jiwer.cer(gold, hyp)
            wer = jiwer.wer(gold, hyp)
        else:
            cer = wer = 1.0

        # Relative improvement is undefined when the input was already perfect.
        rel_reduction = (baseline_cer - cer) / baseline_cer if baseline_cer > 0 else None
        overcorr, fixrate = overcorrection_and_fix_rate(ocr, gold, hyp)

        seconds = result.get("seconds")
        tokens = result.get("tokens_generated")
        throughput = tokens / seconds if tokens and seconds else None

        row.update(
            {
                f"cer_{tag}": cer,
                f"wer_{tag}": wer,
                f"rel_cer_red_{tag}": rel_reduction,
                f"overcorr_{tag}": overcorr,
                f"fixrate_{tag}": fixrate,
                f"seconds_{tag}": seconds,
                f"tok_s_{tag}": throughput,
            }
        )

        steps = result.get("denoising_steps")
        if steps is not None:
            row[f"denoising_steps_{tag}"] = steps

    return row
| def _mean(values: list) -> float | None: | |
| vals = [v for v in values if v is not None] | |
| return statistics.mean(vals) if vals else None | |
| def _median(values: list) -> float | None: | |
| vals = [v for v in values if v is not None] | |
| return statistics.median(vals) if vals else None | |
| def _fmt(value: float | None, pct: bool = False, digits: int = 3) -> str: | |
| if value is None: | |
| return "—" | |
| if pct: | |
| return f"{value * 100:.1f}%" | |
| return f"{value:.{digits}f}" | |
def summarize(infile: Path, outdir: Path) -> str:
    """Compute all metrics for a raw-outputs file and write the reports.

    Reads one JSON record per non-blank line of ``infile``, then writes
    ``per_passage_metrics.jsonl`` (one metric row per passage) and
    ``summary.md`` (macro-averaged table, micro CER footnote, optional config
    dump) into ``outdir``, creating the directory if needed.

    Args:
        infile: Path to the raw_outputs.jsonl produced by the benchmark run.
        outdir: Directory that receives the two output files.

    Returns:
        The Markdown summary that was written to ``summary.md``.

    Raises:
        ValueError: If ``infile`` contains no records.
    """
    # Iterate the file object rather than str.splitlines(): the latter also
    # splits on U+2028/U+2029/U+0085, which may legally appear unescaped inside
    # a JSON string and would cut a record in half.
    with infile.open(encoding="utf-8") as src:
        records = [json.loads(line) for line in src if line.strip()]
    if not records:
        # Fail before touching outdir instead of hitting an IndexError later.
        raise ValueError(f"no records found in {infile}")

    rows = [passage_metrics(r) for r in records]
    outdir.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: the locale default cannot be relied on to encode the
    # non-ASCII characters used in the summary.
    with (outdir / "per_passage_metrics.jsonl").open("w", encoding="utf-8") as f:
        f.writelines(json.dumps(row) + "\n" for row in rows)

    # Report every model that produced output for at least one passage, using
    # the same "missing or null means absent" rule as passage_metrics.
    present = [
        m for m in MODELS if any(r["output"].get(m) is not None for r in records)
    ]

    # corpus-level micro CER (footnote row)
    golds = [normalize(r["gold"]) for r in records]
    micro = {"input": jiwer.cer(golds, [normalize(r["ocr_input"]) for r in records])}
    for model in present:
        outs = [r["output"].get(model) for r in records]
        # Micro CER is only comparable when the model covered every passage.
        if all(out is not None for out in outs):
            micro[SUFFIX[model]] = jiwer.cer(golds, [normalize(out["text"]) for out in outs])

    def column(key: str) -> list:
        """Values of one metric across all passages (None where absent)."""
        return [row.get(key) for row in rows]

    lines = [
        "# Benchmark summary",
        "",
        f"Passages: {len(rows)} · macro means over passages (micro CER in footnote)",
        "",
        "| Model | CER ↓ | WER ↓ | Rel. CER reduction ↑ | Over-correction ↓ | Fix rate ↑ | Median s/passage | tok/s |",
        "|---|---|---|---|---|---|---|---|",
        f"| OCR input (uncorrected) | {_fmt(_mean(column('cer_input')))} | {_fmt(_mean(column('wer_input')))} | — | — | — | — | — |",
    ]
    for model in present:
        suffix = SUFFIX[model]
        lines.append(
            f"| {MODEL_LABELS[model]} "
            f"| {_fmt(_mean(column(f'cer_{suffix}')))} "
            f"| {_fmt(_mean(column(f'wer_{suffix}')))} "
            f"| {_fmt(_mean(column(f'rel_cer_red_{suffix}')), pct=True)} "
            f"| {_fmt(_mean(column(f'overcorr_{suffix}')), pct=True)} "
            f"| {_fmt(_mean(column(f'fixrate_{suffix}')), pct=True)} "
            f"| {_fmt(_median(column(f'seconds_{suffix}')), digits=2)} "
            f"| {_fmt(_mean(column(f'tok_s_{suffix}')), digits=1)} |"
        )

    micro_parts = ", ".join(
        f"{MODEL_LABELS.get(m, m)}: {_fmt(micro.get(SUFFIX[m]))}" for m in present
    )
    lines += ["", f"Micro (corpus-level) CER — input: {_fmt(micro.get('input'))}, {micro_parts}."]

    for model in present:
        steps = _mean(column(f"denoising_steps_{SUFFIX[model]}"))
        if steps is not None:
            lines.append(f"Mean denoising steps, {MODEL_LABELS[model]}: {steps:.1f} (max 48).")

    # The run configuration is taken from the first record that carries one.
    if meta := next((r.get("meta") for r in records if r.get("meta")), None):
        lines += ["", "## Config", "", "```json", json.dumps(meta, indent=2), "```"]

    summary = "\n".join(lines) + "\n"
    (outdir / "summary.md").write_text(summary, encoding="utf-8")
    return summary
# ---------------------------------------------------------------- tests
def run_tests() -> None:
    """Run the built-in self-checks; raises AssertionError on the first failure."""
    assert normalize("a\n\n b\t c ") == "a b c"

    ocr, gold = "the qvick brown fox", "the quick brown fox"

    # Marker for "this rate is not asserted for the case".
    unchecked = object()
    # (input, gold, model output, expected over-correction, expected fix rate)
    exact_cases = [
        # perfect copy of a partially-wrong input: changed nothing -> overcorr 0
        (ocr, gold, ocr, 0.0, 0.0),
        # perfect correction: fixed everything, touched nothing correct
        (ocr, gold, gold, 0.0, 1.0),
        # full rewrite: every already-correct char lost
        (ocr, gold, "zzzzzzzz", 1.0, 1.0),
        # empty output -> all correct chars lost
        (ocr, gold, "", 1.0, unchecked),
        # insertion-only edit: input chars all preserved -> overcorr 0
        ("the cat sat", "the cat sat down", "the cat sat down", 0.0, unchecked),
    ]
    for source, reference, produced, want_oc, want_fx in exact_cases:
        got_oc, got_fx = overcorrection_and_fix_rate(source, reference, produced)
        assert got_oc == want_oc, got_oc
        if want_fx is not unchecked:
            assert got_fx == want_fx, got_fx

    # deletion of a correct word counts as over-correction
    got_oc, _ = overcorrection_and_fix_rate("the cat sat", "the cat sat", "the sat")
    assert got_oc is not None and got_oc > 0.0, got_oc

    # input identical to gold: no wrong chars -> fix rate None
    _, got_fx = overcorrection_and_fix_rate("abc", "abc", "abc")
    assert got_fx is None

    # input shares nothing with gold -> overcorr None (excluded)
    got_oc, _ = overcorrection_and_fix_rate("xxxx", "yyyy", "yyyy")
    assert got_oc is None

    # passage_metrics end-to-end shape, no text fields
    sample = {
        "id": "t1",
        "ocr_input": ocr,
        "gold": gold,
        "output": {
            "diffusiongemma": {"text": gold, "seconds": 2.0, "tokens_generated": 10, "denoising_steps": 7},
            "gemma4": {"text": ocr, "seconds": 1.0, "tokens_generated": 10},
        },
    }
    row = passage_metrics(sample)
    expected_values = {
        "cer_dg": 0.0,
        "overcorr_dg": 0.0,
        "fixrate_dg": 1.0,
        "overcorr_g4": 0.0,
        "fixrate_g4": 0.0,
        "tok_s_dg": 5.0,
        "denoising_steps_dg": 7,
    }
    for key, want in expected_values.items():
        assert row[key] == want

    # Any "text" key is a leak; an "input" key is a leak only if it holds a string.
    assert not any(
        "text" in k or ("input" in k and isinstance(row[k], str)) for k in row
    )
    print("all tests passed")
def main() -> None:
    """Command-line entry point.

    ``test`` runs the built-in self-checks; ``summarize`` computes the metrics
    for ``--infile`` and prints the Markdown summary that is also written to
    ``--outdir``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's line breaks so the "Usage:" examples stay
        # on their own lines in --help instead of being re-wrapped.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    sub = parser.add_subparsers(dest="cmd", required=True)
    sub.add_parser("test")
    s = sub.add_parser("summarize")
    s.add_argument("--infile", type=Path, default=Path("results/raw_outputs.jsonl"))
    s.add_argument("--outdir", type=Path, default=Path("results"))
    args = parser.parse_args()
    if args.cmd == "test":
        run_tests()
    else:
        print(summarize(args.infile, args.outdir))


if __name__ == "__main__":
    main()
Xet Storage Details
- Size:
- 10.1 kB
- Xet hash:
- 3f3d64a35ddd347c11001a2505a38503c1f8186372f28f7bf09973a96aaa63a0
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.