davanstrien's picture
download
raw
10.1 kB
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "jiwer>=3.0",
# ]
# ///
"""Metrics for the DiffusionGemma vs Gemma-4 post-OCR correction benchmark.
Computes CER/WER (jiwer), relative CER reduction, over-correction rate and
fix rate from the raw_outputs.jsonl produced by benchmark.py. No GPU, no
model downloads — metric bugs never require re-running generation.
Usage:
uv run metrics.py test
uv run metrics.py summarize --infile results/raw_outputs.jsonl --outdir results/
"""
import argparse
import difflib
import json
import re
import statistics
import unicodedata
from pathlib import Path
import jiwer
# Model keys in the order they are scored and listed in the summary table.
MODELS = ["diffusiongemma", "diffusiongemma_canvas", "gemma4", "gemma4_moe"]

# Short tag per model, appended to metric names to form row keys
# (for example "cer_dg" or "fixrate_g4moe").
SUFFIX = dict(
    diffusiongemma="dg",
    diffusiongemma_canvas="dgc",
    gemma4="g4",
    gemma4_moe="g4moe",
)

# Human-readable model names shown in the Markdown summary.
MODEL_LABELS = dict(
    diffusiongemma="DiffusionGemma 26B-A4B-it",
    diffusiongemma_canvas="DiffusionGemma (OCR-seeded canvas)",
    gemma4="Gemma-4-E4B-it",
    gemma4_moe="Gemma-4-26B-A4B-it (MoE)",
)
def normalize(text: str) -> str:
    """Return `text` in NFC form with whitespace squeezed and trimmed.

    Every run of whitespace becomes a single space and leading/trailing
    whitespace is dropped. This is intentionally forgiving: the benchmark
    does not score layout or whitespace reconstruction.
    """
    composed = unicodedata.normalize("NFC", text)
    # split() with no argument cuts on whitespace runs and discards empty
    # edge pieces, so re-joining with one space collapses and trims at once.
    return " ".join(composed.split())
def _equal_input_positions(a: str, b: str) -> set[int]:
    """Return the indices of `a` that SequenceMatcher aligns unchanged to `b`.

    autojunk is switched off on purpose: with it enabled, difflib's
    popularity heuristic quietly worsens alignments once texts reach
    roughly 200 characters.
    """
    matcher = difflib.SequenceMatcher(None, a, b, autojunk=False)
    return {
        index
        for tag, start, stop, _b_start, _b_stop in matcher.get_opcodes()
        if tag == "equal"
        for index in range(start, stop)
    }
def overcorrection_and_fix_rate(
    ocr: str, gold: str, output: str
) -> tuple[float | None, float | None]:
    """Return (over-correction rate, fix rate) at character level.

    Both rates are computed over index positions of the OCR input:

        C     = positions already correct   (input <-> gold alignment)
        P     = positions the model kept    (input <-> output alignment)
        wrong = every other input position

        over-correction = |C - P| / |C|          correct chars the model changed
        fix rate        = |wrong - P| / |wrong|  wrong chars the model changed

    Because everything is indexed on the input side, an output that is
    longer or shorter than the input needs no special handling. A rate
    whose denominator would be zero is returned as None.
    """
    already_right = _equal_input_positions(ocr, gold)
    kept = _equal_input_positions(ocr, output)
    needs_fix = set(range(len(ocr))).difference(already_right)

    def changed_share(group: set[int]) -> float | None:
        # Fraction of `group` that the model did not preserve.
        if not group:
            return None
        return len(group - kept) / len(group)

    return changed_share(already_right), changed_share(needs_fix)
def passage_metrics(record: dict) -> dict:
    """Build the metrics row for one raw_outputs.jsonl record.

    The row holds the record id, the normalized gold length, input CER/WER
    and, for each model in MODELS that has an output in the record,
    CER/WER, relative CER reduction, over-correction rate, fix rate and
    timing fields. No text fields are copied into the row, so it is safe
    to publish.
    """
    ocr_text = normalize(record["ocr_input"])
    gold_text = normalize(record["gold"])

    def score(metric, hypothesis: str) -> float:
        # An empty hypothesis is scored as 1.0 without calling jiwer.
        return metric(gold_text, hypothesis) if hypothesis else 1.0

    metrics: dict = {
        "id": record["id"],
        "n_chars_gold": len(gold_text),
        "cer_input": score(jiwer.cer, ocr_text),
        "wer_input": score(jiwer.wer, ocr_text),
    }
    base_cer = metrics["cer_input"]

    outputs = record["output"]
    for model_key in MODELS:
        result = outputs.get(model_key)
        if result is None:
            continue  # this model was not run for this passage

        hypothesis = normalize(result["text"])
        tag = SUFFIX[model_key]

        model_cer = score(jiwer.cer, hypothesis)
        model_wer = score(jiwer.wer, hypothesis)
        # Relative reduction is undefined when the input is already perfect.
        if base_cer > 0:
            relative_reduction = (base_cer - model_cer) / base_cer
        else:
            relative_reduction = None
        over_rate, fix_rate = overcorrection_and_fix_rate(ocr_text, gold_text, hypothesis)

        elapsed = result.get("seconds")
        n_tokens = result.get("tokens_generated")
        # Throughput needs both a token count and a non-zero duration.
        throughput = n_tokens / elapsed if n_tokens and elapsed else None

        metrics.update(
            {
                f"cer_{tag}": model_cer,
                f"wer_{tag}": model_wer,
                f"rel_cer_red_{tag}": relative_reduction,
                f"overcorr_{tag}": over_rate,
                f"fixrate_{tag}": fix_rate,
                f"seconds_{tag}": elapsed,
                f"tok_s_{tag}": throughput,
            }
        )

        steps = result.get("denoising_steps")
        if steps is not None:
            metrics[f"denoising_steps_{tag}"] = steps

    return metrics
def _mean(values: list) -> float | None:
vals = [v for v in values if v is not None]
return statistics.mean(vals) if vals else None
def _median(values: list) -> float | None:
vals = [v for v in values if v is not None]
return statistics.median(vals) if vals else None
def _fmt(value: float | None, pct: bool = False, digits: int = 3) -> str:
if value is None:
return "—"
if pct:
return f"{value * 100:.1f}%"
return f"{value:.{digits}f}"
def summarize(infile: Path, outdir: Path) -> str:
    """Compute per-passage metrics and render the Markdown benchmark summary.

    Reads one JSON record per non-blank line of `infile`, then writes
    `per_passage_metrics.jsonl` and `summary.md` into `outdir` (created if
    it does not exist).

    All files are opened as UTF-8 explicitly. The summary contains
    non-ASCII characters (arrows, em dashes, a middle dot), so relying on
    the locale's default encoding can fail on non-UTF-8 platforms.

    Args:
        infile: Path to the raw_outputs.jsonl file to read.
        outdir: Directory that receives the two output files.

    Returns:
        The Markdown summary text, identical to what is written to
        `summary.md`.

    Raises:
        ValueError: If `infile` contains no records.
    """
    # Iterate the file instead of using str.splitlines(): the latter also
    # breaks on U+2028, U+2029 and NEL, which may legally appear unescaped
    # inside a JSON string and would split a record in half.
    with infile.open(encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]
    if not records:
        raise ValueError(f"no records found in {infile}")

    rows = [passage_metrics(r) for r in records]
    outdir.mkdir(parents=True, exist_ok=True)
    with (outdir / "per_passage_metrics.jsonl").open("w", encoding="utf-8") as f:
        f.writelines(json.dumps(row) + "\n" for row in rows)

    # Report a model if any record has an output for it, not only the first
    # record. A null output is treated as absent, matching passage_metrics.
    present = [
        m for m in MODELS if any(r["output"].get(m) is not None for r in records)
    ]

    # corpus-level micro CER (footnote row)
    golds = [normalize(r["gold"]) for r in records]
    micro = {"input": jiwer.cer(golds, [normalize(r["ocr_input"]) for r in records])}
    for model in present:
        outs = [r["output"].get(model) for r in records]
        # Only comparable with the input figure when every passage has an output.
        if all(o is not None for o in outs):
            micro[SUFFIX[model]] = jiwer.cer(golds, [normalize(o["text"]) for o in outs])

    def column(key: str) -> list:
        # Values of one metric across all passages; None where it is missing.
        return [row.get(key) for row in rows]

    lines = [
        "# Benchmark summary",
        "",
        f"Passages: {len(rows)} · macro means over passages (micro CER in footnote)",
        "",
        "| Model | CER ↓ | WER ↓ | Rel. CER reduction ↑ | Over-correction ↓ | Fix rate ↑ | Median s/passage | tok/s |",
        "|---|---|---|---|---|---|---|---|",
        f"| OCR input (uncorrected) | {_fmt(_mean(column('cer_input')))} | {_fmt(_mean(column('wer_input')))} | — | — | — | — | — |",
    ]
    for model in present:
        suffix = SUFFIX[model]
        lines.append(
            f"| {MODEL_LABELS[model]} "
            f"| {_fmt(_mean(column(f'cer_{suffix}')))} "
            f"| {_fmt(_mean(column(f'wer_{suffix}')))} "
            f"| {_fmt(_mean(column(f'rel_cer_red_{suffix}')), pct=True)} "
            f"| {_fmt(_mean(column(f'overcorr_{suffix}')), pct=True)} "
            f"| {_fmt(_mean(column(f'fixrate_{suffix}')), pct=True)} "
            f"| {_fmt(_median(column(f'seconds_{suffix}')), digits=2)} "
            f"| {_fmt(_mean(column(f'tok_s_{suffix}')), digits=1)} |"
        )

    micro_parts = ", ".join(
        f"{MODEL_LABELS.get(m, m)}: {_fmt(micro.get(SUFFIX[m]))}" for m in present
    )
    lines += ["", f"Micro (corpus-level) CER — input: {_fmt(micro.get('input'))}, {micro_parts}."]

    for model in present:
        steps = _mean(column(f"denoising_steps_{SUFFIX[model]}"))
        if steps is not None:
            lines.append(f"Mean denoising steps, {MODEL_LABELS[model]}: {steps:.1f} (max 48).")

    # Append the first non-empty "meta" object found in the records, if any.
    if meta := next((r.get("meta") for r in records if r.get("meta")), None):
        lines += ["", "## Config", "", "```json", json.dumps(meta, indent=2), "```"]

    summary = "\n".join(lines) + "\n"
    (outdir / "summary.md").write_text(summary, encoding="utf-8")
    return summary
# ---------------------------------------------------------------- tests
def run_tests() -> None:
    """Self-checks for normalize, the rate metrics and passage_metrics.

    Raises AssertionError on the first failing check; prints a confirmation
    line when every check passes.
    """
    rates = overcorrection_and_fix_rate

    assert normalize("a\n\n b\t c ") == "a b c"

    noisy, truth = "the qvick brown fox", "the quick brown fox"

    # perfect copy of a partially-wrong input: changed nothing -> overcorr 0
    over, fixed = rates(noisy, truth, noisy)
    assert over == 0.0, over
    assert fixed == 0.0, fixed

    # perfect correction: fixed everything, touched nothing correct
    over, fixed = rates(noisy, truth, truth)
    assert over == 0.0, over
    assert fixed == 1.0, fixed

    # full rewrite: every already-correct char lost
    over, fixed = rates(noisy, truth, "zzzzzzzz")
    assert over == 1.0, over
    assert fixed == 1.0, fixed

    # empty output -> all correct chars lost
    over, _ = rates(noisy, truth, "")
    assert over == 1.0, over

    # insertion-only edit: input chars all preserved -> overcorr 0
    over, _ = rates("the cat sat", "the cat sat down", "the cat sat down")
    assert over == 0.0, over

    # deletion of a correct word counts as over-correction
    over, _ = rates("the cat sat", "the cat sat", "the sat")
    assert over is not None and over > 0.0, over

    # input identical to gold: no wrong chars -> fix rate None
    _, fixed = rates("abc", "abc", "abc")
    assert fixed is None

    # input shares nothing with gold -> overcorr None (excluded)
    over, _ = rates("xxxx", "yyyy", "yyyy")
    assert over is None

    # passage_metrics end-to-end shape, no text fields
    sample = {
        "id": "t1",
        "ocr_input": noisy,
        "gold": truth,
        "output": {
            "diffusiongemma": {
                "text": truth,
                "seconds": 2.0,
                "tokens_generated": 10,
                "denoising_steps": 7,
            },
            "gemma4": {"text": noisy, "seconds": 1.0, "tokens_generated": 10},
        },
    }
    row = passage_metrics(sample)
    assert row["cer_dg"] == 0.0
    assert row["overcorr_dg"] == 0.0 and row["fixrate_dg"] == 1.0
    assert row["overcorr_g4"] == 0.0 and row["fixrate_g4"] == 0.0
    assert row["tok_s_dg"] == 5.0
    assert row["denoising_steps_dg"] == 7
    # Any key containing "text" is a leak; a key containing "input" is a
    # leak only when its value is a string (cer_input/wer_input are floats).
    leaked = [
        key
        for key in row
        if "text" in key or ("input" in key and isinstance(row[key], str))
    ]
    assert not leaked

    print("all tests passed")
def main() -> None:
    """Command-line entry point.

    `test` runs the built-in self-checks; `summarize` computes the metrics
    and prints the Markdown summary.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    commands = parser.add_subparsers(dest="cmd", required=True)

    commands.add_parser("test")

    summarize_cmd = commands.add_parser("summarize")
    summarize_cmd.add_argument(
        "--infile", type=Path, default=Path("results/raw_outputs.jsonl")
    )
    summarize_cmd.add_argument("--outdir", type=Path, default=Path("results"))

    args = parser.parse_args()
    if args.cmd != "test":
        print(summarize(args.infile, args.outdir))
        return
    run_tests()


if __name__ == "__main__":
    main()

Xet Storage Details

Size:
10.1 kB
·
Xet hash:
3f3d64a35ddd347c11001a2505a38503c1f8186372f28f7bf09973a96aaa63a0

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.