davanstrien's picture
download
raw
20.8 kB
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "transformers>=5.11,<6",
# "torch>=2.8,<2.10", # a100-large job image driver is CUDA 12.9; torch 2.10+ ships cu13 wheels
# "torchvision", # Gemma4Processor imports it even for text-only use
# "accelerate",
# "huggingface-hub",
# "pillow",
# "pyzipper",
# "requests",
# ]
# ///
"""Generation runs for the DiffusionGemma vs Gemma-4 post-OCR correction benchmark.
GENERATION ONLY — all metrics are computed offline by metrics.py from the
JSONL this produces, so metric changes never require re-running GPU jobs.
Designed to run on HF Jobs:
hf jobs uv run --flavor a100-large --timeout 45m \
-e HF_XET_HIGH_PERFORMANCE=1 -s HF_TOKEN \
benchmark.py -- --mode smoke
"""
import argparse
import difflib
import gc
import json
import random
import re
import time
from pathlib import Path
import requests
# Hub checkpoints. run_model loads DG_MODEL for every "diffusiongemma*" key and
# G4_MODEL for the "gemma4" key; sample_passages also uses G4_MODEL's tokenizer
# to measure passage length.
DG_MODEL = "google/diffusiongemma-26B-A4B-it"
G4_MODEL = "google/gemma-4-E4B-it"
# Parameter-matched AR baseline (26B MoE / 4B active, same as DiffusionGemma) —
# the model DeepMind benched DiffusionGemma against (per João Gante).
G4_MOE_MODEL = "google/gemma-4-26B-A4B-it"
# BLN600 (CC-BY-NC-4.0): resolved via the figshare API for DOI 10.15131/shef.data.25439023.
# NC license — the text is downloaded at run time and must never be committed
# to a public repo; raw outputs go to a PRIVATE dataset repo only.
BLN600_API = "https://api.figshare.com/v2/articles/25439023"
# Password for the BLN600 zip, applied with pyzipper.AESZipFile.setpassword.
BLN600_ZIP_PASSWORD = b"BLN600"
# ICDAR2019 post-OCR (CC-BY-4.0) — fallback eval set + source of the Space's
# committable examples.
ICDAR_URL = "https://zenodo.org/records/3515403/files/ICDAR2019-POCR-ground-truth.zip"
# Per-side token budget for an (ocr, gold) pair, counted with the G4_MODEL
# tokenizer; pairs over budget are align-trimmed by trim_pair.
MAX_PASSAGE_TOKENS = 220 # margin under DiffusionGemma's fixed 256-token canvas
N_EXAMPLES = 6 # Space dropdown examples (ICDAR, CC-BY)
# User-turn prompt; {ocr} receives the passage's OCR text. Each trailing
# backslash is a line continuation inside the string, so those lines are joined
# (the one right after the opening quotes drops the leading newline).
PROMPT_TEMPLATE = """\
Correct the OCR errors in the following text from a 19th-century English newspaper.
Fix only recognition errors (wrong, missing, or extra characters). Do not modernise \
spelling, do not rephrase, and do not add or remove content. Preserve the original \
punctuation unless it is clearly an OCR error.
Output only the corrected text, with no commentary or preamble.
OCR text:
{ocr}"""
# Markers at which extract_answer truncates a raw decode (earliest occurrence wins).
STOP_MARKERS = ("<turn|>", "<eos>", "<end_of_turn>", "<pad>")
def extract_answer(raw: str) -> tuple[str, str]:
    """Return ``(answer, thought)`` parsed from a raw decode that keeps special tokens.

    The text is first cut at the earliest occurrence of any ``STOP_MARKERS``
    entry. If a ``<channel|>`` separator survives the cut, the answer is
    everything after the LAST one, and the thought is whatever follows
    ``<|channel>thought`` before it (empty if that tag is absent). Without a
    separator the whole cut text is the answer and the thought is empty.

    DiffusionGemma decodes look like
    ``<|channel>thought\\n<channel|>ANSWER<turn|><eos>...`` even with thinking
    off (empty thought); Gemma-4 output is treated as plain text, so only the
    stop-marker cut applies to it.
    """
    cut_points = [pos for pos in (raw.find(marker) for marker in STOP_MARKERS) if pos >= 0]
    text = raw[: min(cut_points)] if cut_points else raw

    if "<channel|>" not in text:
        return text.strip(), ""

    before, _, after = text.rpartition("<channel|>")
    match = re.search(r"<\|channel>thought(.*)$", before, flags=re.DOTALL)
    thought = match.group(1).strip() if match else ""
    return after.strip(), thought
# ---------------------------------------------------------------- data
def _download(url: str, dest: Path, **kwargs) -> Path:
    """Stream ``url`` into ``dest`` unless ``dest`` already exists; return ``dest``.

    An existing ``dest`` is treated as a complete cached copy. To keep that
    assumption valid, the body is streamed into a sibling ``<name>.part`` file
    and only renamed onto ``dest`` after the whole response has been written.
    An interrupted download therefore never leaves a truncated file behind that
    a later run would mistake for the cache.

    Extra keyword arguments are forwarded to ``requests.get``. ``timeout``
    defaults to 600 s, and a caller-supplied value takes precedence.

    Raises:
        requests.HTTPError: on a non-2xx response (via ``raise_for_status``).
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    if dest.exists():
        return dest
    print(f"downloading {url} -> {dest}")
    kwargs.setdefault("timeout", 600)
    tmp = dest.with_name(dest.name + ".part")
    try:
        with requests.get(url, stream=True, **kwargs) as r:
            r.raise_for_status()
            with tmp.open("wb") as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        # Rename only once the stream finished, so `dest` is always complete.
        tmp.replace(dest)
    finally:
        # No-op after a successful replace; removes the partial file on failure.
        tmp.unlink(missing_ok=True)
    return dest
def download_bln600(workdir: Path) -> list[dict]:
    """Download, extract and parse BLN600 into ``[{id, ocr_input, gold}]``.

    The zip is located through the figshare article API (``BLN600_API``),
    cached under ``workdir`` and extracted with ``BLN600_ZIP_PASSWORD``. Two
    layouts are tried, in this order:

    1. folder layout -- ``OCR Text/*.txt`` paired with ``Ground Truth/*.txt``
       by file stem;
    2. CSV layout -- a ``*.csv`` whose header has ``OCR Text`` and
       ``Ground Truth`` columns.

    In the folder layout each ``.txt`` is assigned by the innermost directory
    below the extraction root whose name mentions "ground"/"gold" (gold side)
    or "ocr" (OCR side). Only archive-internal path components are inspected.
    This prevents a ``workdir`` whose absolute path happens to contain "ocr"
    from pulling gold files onto the OCR side, and no file lands on both sides.

    Raises:
        requests.HTTPError: if the figshare API request fails.
        RuntimeError: if the article has no zip or neither layout can be parsed.
    """
    import csv
    import itertools

    import pyzipper

    resp = requests.get(BLN600_API, timeout=60)
    resp.raise_for_status()  # surface API errors instead of a KeyError on "files"
    meta = resp.json()
    zips = [f for f in meta["files"] if f["name"].lower().endswith(".zip")]
    if not zips:
        raise RuntimeError(f"no zip in figshare article files: {[f['name'] for f in meta['files']]}")
    zip_path = _download(zips[0]["download_url"], workdir / zips[0]["name"])
    extract_dir = workdir / "bln600"
    if not extract_dir.exists():
        with pyzipper.AESZipFile(zip_path) as zf:
            zf.setpassword(BLN600_ZIP_PASSWORD)
            zf.extractall(extract_dir)

    def side_of(txt_path: Path) -> str | None:
        """Return "gold", "ocr" or None for a .txt file inside ``extract_dir``."""
        # Innermost directory first; a name with both keywords counts as gold.
        for part in reversed(txt_path.parent.relative_to(extract_dir).parts):
            name = part.lower()
            if "ground" in name or "gold" in name:
                return "gold"
            if "ocr" in name:
                return "ocr"
        return None

    # folder layout: pair OCR and gold files by stem (last file wins on a duplicate stem)
    ocr_files: dict[str, Path] = {}
    gold_files: dict[str, Path] = {}
    for p in extract_dir.rglob("*.txt"):
        side = side_of(p)
        if side == "ocr":
            ocr_files[p.stem] = p
        elif side == "gold":
            gold_files[p.stem] = p
    common = sorted(set(ocr_files) & set(gold_files))
    if common:
        print(f"BLN600 folder layout: {len(common)} aligned pairs")
        return [
            {
                "id": f"bln600/{stem}",
                "ocr_input": ocr_files[stem].read_text(encoding="utf-8", errors="replace"),
                "gold": gold_files[stem].read_text(encoding="utf-8", errors="replace"),
            }
            for stem in common
        ]

    # CSV layout fallback
    for csv_path in extract_dir.rglob("*.csv"):
        with csv_path.open(newline="", encoding="utf-8", errors="replace") as f:
            rows = list(csv.DictReader(f))
        if rows and "OCR Text" in rows[0] and "Ground Truth" in rows[0]:
            print(f"BLN600 CSV layout: {len(rows)} rows from {csv_path.name}")
            return [
                {"id": f"bln600/{i}", "ocr_input": r["OCR Text"], "gold": r["Ground Truth"]}
                for i, r in enumerate(rows)
            ]

    # Only the first 40 entries are listed; islice avoids materialising the full walk.
    listing = [str(p.relative_to(extract_dir)) for p in itertools.islice(extract_dir.rglob("*"), 40)]
    raise RuntimeError(f"could not parse BLN600; archive contents: {listing}")
def download_icdar_english(workdir: Path) -> list[dict]:
    """Download and parse the English subset of ICDAR2019 post-OCR.

    Each passage is a ``.txt`` file with ``[OCR_toInput]``, ``[OCR_aligned]``
    and ``[ GS_aligned]`` lines. ``[OCR_toInput]`` becomes ``ocr_input``. The
    aligned ground-truth line, with its ``@`` alignment pads removed, becomes
    ``gold``. Only files whose archive-relative path has a component starting
    with ``EN`` are kept, and files missing either side are skipped.

    Paths are matched and turned into ids in POSIX form, so the ``EN`` filter
    and the ``icdar2019/<path>`` ids are the same on every OS (a native Windows
    path uses ``\\`` separators and would never match ``(^|/)EN``).

    Returns:
        ``[{id, ocr_input, gold}]`` in sorted-path order.
    """
    import zipfile

    zip_path = _download(ICDAR_URL, workdir / "icdar2019.zip")
    extract_dir = workdir / "icdar2019"
    if not extract_dir.exists():
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(extract_dir)
    passages = []
    for p in sorted(extract_dir.rglob("*.txt")):
        rel = p.relative_to(extract_dir).as_posix()
        if not re.search(r"(^|/)EN", rel):
            continue
        ocr = gold = None
        for line in p.read_text(encoding="utf-8", errors="replace").splitlines():
            if line.startswith("[OCR_toInput]"):
                ocr = line.removeprefix("[OCR_toInput]").strip()
            elif line.startswith(("[ GS_aligned]", "[GS_aligned]")):
                gold = line.split("]", 1)[1].replace("@", "").strip()
        if ocr and gold:
            passages.append({"id": f"icdar2019/{rel}", "ocr_input": ocr, "gold": gold})
    print(f"ICDAR2019 English: {len(passages)} passages")
    return passages
def trim_pair(ocr: str, gold: str, n_tokens, max_tokens: int) -> tuple[str, str] | None:
"""Trim an aligned (ocr, gold) pair so both sides fit in max_tokens.
Cuts at a whitespace position inside a character-aligned "equal" region so
the pair stays parallel after trimming (independent token-count truncation
would misalign the endings and corrupt tail CER). Returns None if no valid
cut point exists.
"""
if n_tokens(ocr) <= max_tokens and n_tokens(gold) <= max_tokens:
return ocr, gold
sm = difflib.SequenceMatcher(None, ocr, gold, autojunk=False)
# candidate (i_cut, j_cut) pairs, ascending: whitespace inside equal blocks
candidates = [
(i1 + m.start(), j1 + m.start())
for op, i1, i2, j1, _j2 in sm.get_opcodes()
if op == "equal"
for m in re.finditer(r"\s", ocr[i1:i2])
]
if not candidates:
return None
def fits(idx: int) -> bool:
i_cut, j_cut = candidates[idx]
return n_tokens(ocr[:i_cut]) <= max_tokens and n_tokens(gold[:j_cut]) <= max_tokens
if not fits(0):
return None
# token counts grow with cut position -> binary search the largest fit
lo, hi = 0, len(candidates) - 1
while lo < hi:
mid = (lo + hi + 1) // 2
if fits(mid):
lo = mid
else:
hi = mid - 1
i_cut, j_cut = candidates[lo]
return ocr[:i_cut].rstrip(), gold[:j_cut].rstrip()
def sample_passages(passages: list[dict], n: int, seed: int) -> list[dict]:
    """Pick up to ``n`` usable passages in a seed-determined order.

    ``passages`` is shuffled with ``random.Random(seed)`` and scanned in that
    order. Each pair is passed through ``trim_pair`` with a budget of
    ``MAX_PASSAGE_TOKENS``, where tokens are counted with the ``G4_MODEL``
    tokenizer. A passage is skipped when no fitting cut exists or when the
    resulting gold text is shorter than 200 characters. Scanning stops once
    ``n`` passages have been kept, and a one-line summary is printed.

    Returns:
        ``[{id, ocr_input, gold}]`` holding the possibly-trimmed texts.
    """
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(G4_MODEL)

    def n_tokens(text: str) -> int:
        return len(tokenizer(text)["input_ids"])

    order = random.Random(seed).sample(passages, len(passages))  # seeded shuffle
    kept: list[dict] = []
    trimmed_count = 0
    skipped_count = 0
    for candidate in order:
        if len(kept) >= n:
            break
        source_pair = (candidate["ocr_input"], candidate["gold"])
        pair = trim_pair(*source_pair, n_tokens, MAX_PASSAGE_TOKENS)
        usable = pair is not None and len(pair[1]) >= 200  # drop degenerate/too-short cuts
        if not usable:
            skipped_count += 1
            continue
        if pair != source_pair:
            trimmed_count += 1
        kept.append({"id": candidate["id"], "ocr_input": pair[0], "gold": pair[1]})

    summary = (
        f"sampled {len(kept)} passages ({trimmed_count} trimmed to <= {MAX_PASSAGE_TOKENS} "
        f"tokens, {skipped_count} skipped as untrimmable/too short)"
    )
    print(summary)
    return kept
# ---------------------------------------------------------------- generation
def clean_output(text: str) -> str:
    """Strip surrounding whitespace and a leading "Corrected text" label.

    The label is matched case-insensitively, and the trailing colon is
    optional. The ``\\b`` word boundary ensures that only the whole word "text"
    counts, so a genuine passage that merely starts with e.g.
    "Corrected textiles ..." is left intact. A notice is printed whenever a
    label was removed.
    """
    stripped = text.strip()
    cleaned = re.sub(r"^\s*corrected text\b:?\s*", "", stripped, flags=re.IGNORECASE)
    if cleaned != stripped:
        print(" [clean_output stripped a prefix]")
    return cleaned
def count_generated_tokens(generated_ids, tokenizer) -> int:
    """Count the tokens that precede the first EOS *or* pad token.

    ``generated_ids`` must provide ``.tolist()`` returning a flat list of token
    ids. Counting stops at the first id equal to ``tokenizer.eos_token_id`` or
    ``tokenizer.pad_token_id`` (that token is not counted). If neither occurs,
    the full length is returned.
    """
    id_list = generated_ids.tolist()
    stop_ids = {tokenizer.eos_token_id, tokenizer.pad_token_id}
    first_stop = next((pos for pos, tid in enumerate(id_list) if tid in stop_ids), None)
    return len(id_list) if first_stop is None else first_stop
def run_model(model_key: str, passages: list[dict], smoke: bool) -> dict[str, dict]:
    """Load one model, generate a correction for every passage, then free the model.

    Args:
        model_key: keys starting with ``"diffusiongemma"`` load ``DG_MODEL``;
            ``"gemma4_moe"`` loads ``G4_MOE_MODEL``; any other key loads
            ``G4_MODEL``. ``"diffusiongemma_canvas"`` uses the same
            DiffusionGemma model, but the denoising canvas is initialised with
            the OCR text instead of random tokens. This goes through the
            undocumented ``decoder_input_ids`` hook in
            ``DiffusionGemmaForBlockDiffusion.generate`` and tests whether
            correction-as-denoising stays closer to the input.
        passages: dicts with ``id``, ``ocr_input`` and ``gold`` keys.
        smoke: when true, also print OCR/gold/raw/output snippets per passage.

    Returns:
        ``{passage_id: output}``, where each output has ``text``, ``seconds``,
        ``tokens_generated``, ``denoising_steps``, ``tokens_per_forward`` and
        ``thought_chars``. An empty ``passages`` returns ``{}`` without loading
        anything.

    Notes:
        * torch's global RNG is re-seeded (outside the timed region) before
          every generation. Any draws from it are then reproducible and do not
          depend on passage order or on the uncounted warm-up call.
        * A tensor ``tokens_per_forward`` is stored via ``.item()``, which
          keeps a fractional value instead of truncating it.
        * The model reference is dropped and the CUDA cache emptied in a
          ``finally`` block, so a failing condition does not pin GPU memory
          while the caller moves on to the next model.
    """
    import torch
    from transformers import AutoProcessor

    if not passages:
        # Nothing to do -- and the warm-up below would index passages[0].
        return {}

    canvas_init = model_key == "diffusiongemma_canvas"
    is_dg = model_key.startswith("diffusiongemma")
    if is_dg:
        from transformers import DiffusionGemmaForBlockDiffusion, TextDiffusionStreamer

        class StepCountingStreamer(TextDiffusionStreamer):
            """Counts denoising steps; suppresses the default console printing
            (the parent prints every draft with ANSI rewrites — unusable in job logs)."""

            def __init__(self, tokenizer):
                super().__init__(tokenizer=tokenizer)
                self.n_steps = 0

            def put_draft(self, value, **kwargs):
                self.n_steps += 1

            def put(self, value):
                pass

            def end(self):
                pass

        model_id = DG_MODEL
        print(f"loading {model_id} ...")
        processor = AutoProcessor.from_pretrained(model_id)
        model = DiffusionGemmaForBlockDiffusion.from_pretrained(
            model_id, dtype="auto", device_map="auto"
        )
    else:
        from transformers import AutoModelForMultimodalLM

        model_id = G4_MOE_MODEL if model_key == "gemma4_moe" else G4_MODEL
        print(f"loading {model_id} ...")
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForMultimodalLM.from_pretrained(model_id, dtype="auto", device_map="auto")
    tokenizer = processor.tokenizer
    canvas_rng = torch.Generator().manual_seed(0)  # deterministic canvas tail padding

    def generate(ocr_text: str) -> dict:
        message = [{"role": "user", "content": PROMPT_TEMPLATE.format(ocr=ocr_text)}]
        inputs = processor.apply_chat_template(
            message, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
        ).to(model.device)
        input_len = inputs["input_ids"].shape[-1]
        gen_kwargs: dict = {"max_new_tokens": 256}
        streamer = None
        if is_dg:
            # generation_config defaults for the entropy sampler (no greedy equivalent)
            streamer = StepCountingStreamer(tokenizer)
            gen_kwargs["streamer"] = streamer
            if canvas_init:
                # Seed the first denoising canvas with the OCR text instead of
                # random tokens; pad the tail with random tokens as the sampler
                # would. Canvas must be exactly canvas_length wide.
                canvas_length = getattr(model.generation_config, "canvas_length", None) or 256
                ids = tokenizer(ocr_text, add_special_tokens=False)["input_ids"][:canvas_length]
                vocab = model.config.text_config.vocab_size
                pad = torch.randint(vocab, (canvas_length - len(ids),), generator=canvas_rng)
                canvas = torch.cat([torch.tensor(ids, dtype=torch.long), pad])
                gen_kwargs["decoder_input_ids"] = canvas.unsqueeze(0).to(model.device)
        else:
            gen_kwargs["do_sample"] = False  # greedy
        # Reproducibility: presumably the diffusion sampler draws its random
        # canvas/noise from torch's global RNG (not visible here -- TODO confirm);
        # re-seeding per call makes every passage's draws order-independent.
        torch.manual_seed(0)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        output = model.generate(**inputs, **gen_kwargs)
        torch.cuda.synchronize()
        seconds = time.perf_counter() - t0
        # DiffusionGemma returns a DiffusionGemmaGenerationOutput (sequences
        # includes the prompt, like AR generate); plain tensor for Gemma-4.
        seq = output.sequences if hasattr(output, "sequences") else output
        generated = seq[0][input_len:] if seq.shape[-1] > input_len else seq[0]
        tpf = getattr(output, "tokens_per_forward", None)
        if torch.is_tensor(tpf):
            # .item() keeps the tensor's own type (int or float) instead of truncating.
            tpf = tpf.flatten()[0].item()
        raw = tokenizer.decode(generated, skip_special_tokens=False)
        answer, thought = extract_answer(raw)
        if thought:
            print(f" [WARNING: thought content present ({len(thought)} chars)]")
        return {
            "text": clean_output(answer),
            "_raw": raw,
            "seconds": round(seconds, 3),
            "tokens_generated": count_generated_tokens(generated, tokenizer),
            "denoising_steps": streamer.n_steps if streamer else None,
            "tokens_per_forward": tpf,
            "thought_chars": len(thought),
        }

    try:
        print("warmup generation (uncounted) ...")
        generate(passages[0]["ocr_input"])
        results: dict[str, dict] = {}
        for i, p in enumerate(passages):
            out = generate(p["ocr_input"])
            raw = out.pop("_raw")
            results[p["id"]] = out
            print(
                f"[{model_key} {i + 1}/{len(passages)}] {out['seconds']}s, "
                f"{out['tokens_generated']} tok"
                + (f", {out['denoising_steps']} steps" if out["denoising_steps"] else "")
            )
            if smoke:
                print(f" OCR: {p['ocr_input'][:200]}")
                print(f" GOLD: {p['gold'][:200]}")
                print(f" RAW: {raw[:300]}")
                print(f" OUT: {out['text'][:200]}")
        return results
    finally:
        # Runs on success and on failure: drop the closure-captured ref so the GPU frees.
        model = None  # noqa: F841
        gc.collect()
        torch.cuda.empty_cache()
# ---------------------------------------------------------------- main
def main() -> None:
    """Command-line entry point: sample passages, run each model, save raw outputs.

    Writes ``raw_outputs.jsonl`` to the current directory, with one record per
    evaluation passage; the first record also carries run metadata under
    ``"meta"``. With ``--cache-examples`` it also writes
    ``examples_cached.json``. With ``--out-repo`` both files are uploaded to a
    private Hugging Face dataset repo.

    Raises:
        SystemExit: if sampling leaves no passages or examples to generate.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--mode", choices=["smoke", "full"], default="smoke")
    parser.add_argument("--n", type=int, default=75)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--models", choices=["both", "dg", "g4", "g4moe", "all"], default="both")
    parser.add_argument(
        "--canvas-init",
        action="store_true",
        help="also run DiffusionGemma with the canvas initialised from the OCR text",
    )
    parser.add_argument("--dataset", choices=["bln600", "icdar"], default="bln600")
    parser.add_argument("--cache-examples", action="store_true")
    parser.add_argument("--out-repo", default=None, help="private dataset repo for raw outputs")
    parser.add_argument("--workdir", type=Path, default=Path("data"))
    args = parser.parse_args()

    import torch
    import transformers

    gpu_name = torch.cuda.get_device_name(0)
    print(f"transformers {transformers.__version__}, torch {torch.__version__}, "
          f"cuda: {gpu_name}")

    loader = download_bln600 if args.dataset == "bln600" else download_icdar_english
    passages = sample_passages(loader(args.workdir), args.n if args.mode == "full" else 3, args.seed)
    print(f"running {len(passages)} passages, models={args.models}")

    examples = []
    if args.cache_examples:
        # Keep the Space examples disjoint from the eval set. With --dataset icdar,
        # the same list and seed would otherwise yield the same shuffle, so the
        # examples would overlap (and re-generate) the first eval passages.
        eval_ids = {p["id"] for p in passages}
        icdar = [p for p in download_icdar_english(args.workdir) if p["id"] not in eval_ids]
        examples = sample_passages(icdar, N_EXAMPLES, args.seed)
        for e in examples:
            e["id"] = "example/" + e["id"]

    all_passages = passages + examples
    if not all_passages:
        raise SystemExit("no usable passages after sampling/trimming; nothing to generate")

    model_keys = {
        "both": ["diffusiongemma", "gemma4"],
        "dg": ["diffusiongemma"],
        "g4": ["gemma4"],
        "g4moe": ["gemma4_moe"],
        "all": ["diffusiongemma", "gemma4", "gemma4_moe"],
    }[args.models]
    if args.canvas_init:
        model_keys.insert(1, "diffusiongemma_canvas")

    outputs: dict[str, dict[str, dict]] = {}
    for key in model_keys:
        try:
            outputs[key] = run_model(key, all_passages, smoke=args.mode == "smoke")
        except Exception as e:  # noqa: BLE001 — a failed condition shouldn't sink the others
            if key == "diffusiongemma_canvas":
                print(f"[{key} FAILED, continuing without it: {type(e).__name__}: {e}]")
            else:
                raise
    # Conditions that actually produced outputs, in run order.
    ran = [k for k in model_keys if k in outputs]

    meta = {
        "date": time.strftime("%Y-%m-%d"),
        "dataset": args.dataset,
        "n": len(passages),
        "seed": args.seed,
        "models": ran,
        "max_passage_tokens": MAX_PASSAGE_TOKENS,
        "prompt": PROMPT_TEMPLATE,
        "transformers": transformers.__version__,
        "torch": torch.__version__,
        "gpu": gpu_name,
        "generation": {
            "diffusiongemma": "generation_config defaults (entropy sampler), max_new_tokens=256",
            "diffusiongemma_canvas": "as diffusiongemma, but first canvas seeded with the OCR"
            " text via decoder_input_ids (random tail padding, seed 0)",
            "gemma4": "do_sample=False (greedy), max_new_tokens=256",
            "gemma4_moe": "do_sample=False (greedy), max_new_tokens=256",
        },
    }

    out_path = Path("raw_outputs.jsonl")
    with out_path.open("w", encoding="utf-8") as f:
        for i, p in enumerate(passages):
            record = {
                "id": p["id"],
                "ocr_input": p["ocr_input"],
                "gold": p["gold"],
                "output": {k: outputs[k][p["id"]] for k in ran},
            }
            if i == 0:
                record["meta"] = meta
            f.write(json.dumps(record) + "\n")
    print(f"wrote {out_path} ({len(passages)} records)")

    cache_path = None
    if examples:
        cache_path = Path("examples_cached.json")
        cache_path.write_text(
            json.dumps(
                [
                    {
                        "id": e["id"],
                        "ocr_input": e["ocr_input"],
                        "gold": e["gold"],
                        "output": {k: outputs[k][e["id"]] for k in ran},
                    }
                    for e in examples
                ],
                indent=2,
            ),
            encoding="utf-8",
        )
        print(f"wrote {cache_path}")

    if args.out_repo:
        from huggingface_hub import HfApi

        api = HfApi()
        api.create_repo(args.out_repo, repo_type="dataset", private=True, exist_ok=True)
        api.upload_file(
            path_or_fileobj=out_path, path_in_repo=out_path.name,
            repo_id=args.out_repo, repo_type="dataset",
        )
        if cache_path:
            api.upload_file(
                path_or_fileobj=cache_path, path_in_repo=cache_path.name,
                repo_id=args.out_repo, repo_type="dataset",
            )
        print(f"uploaded to https://huggingface.co/datasets/{args.out_repo} (private)")


if __name__ == "__main__":
    main()

Xet Storage Details

Size:
20.8 kB
·
Xet hash:
d371aeac2af63d6a057caeb4d105cae9e1e38ce4160f7686b79dad9457834c8a

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.