Buckets:
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "transformers>=5.11,<6",
#     "torch>=2.8,<2.10",  # a100-large job image driver is CUDA 12.9; torch 2.10+ ships cu13 wheels
#     "torchvision",  # Gemma4Processor imports it even for text-only use
#     "accelerate",
#     "huggingface-hub",
#     "pillow",
#     "pyzipper",
#     "requests",
# ]
# ///
"""Generation runs for the DiffusionGemma vs Gemma-4 post-OCR correction benchmark.

GENERATION ONLY — all metrics are computed offline by metrics.py from the
JSONL this produces, so metric changes never require re-running GPU jobs.

Designed to run on HF Jobs:

    hf jobs uv run --flavor a100-large --timeout 45m \
        -e HF_XET_HIGH_PERFORMANCE=1 -s HF_TOKEN \
        benchmark.py -- --mode smoke
"""
| import argparse | |
| import difflib | |
| import gc | |
| import json | |
| import random | |
| import re | |
| import time | |
| from pathlib import Path | |
| import requests | |
# Checkpoint loaded for the "diffusiongemma" and "diffusiongemma_canvas" runs.
DG_MODEL = "google/diffusiongemma-26B-A4B-it"
# Default autoregressive baseline ("gemma4"); its tokenizer is also the one
# used to measure passage length when sampling.
G4_MODEL = "google/gemma-4-E4B-it"
# Parameter-matched AR baseline (26B MoE / 4B active, same as DiffusionGemma) —
# the model DeepMind benched DiffusionGemma against (per João Gante).
G4_MOE_MODEL = "google/gemma-4-26B-A4B-it"
# BLN600 (CC-BY-NC-4.0): resolved via the figshare API for DOI 10.15131/shef.data.25439023.
# NC license — the text is downloaded at run time and must never be committed
# to a public repo; raw outputs go to a PRIVATE dataset repo only.
BLN600_API = "https://api.figshare.com/v2/articles/25439023"
# Password applied to the BLN600 archive before extraction.
BLN600_ZIP_PASSWORD = b"BLN600"
# ICDAR2019 post-OCR (CC-BY-4.0) — fallback eval set + source of the Space's
# committable examples.
ICDAR_URL = "https://zenodo.org/records/3515403/files/ICDAR2019-POCR-ground-truth.zip"
MAX_PASSAGE_TOKENS = 220  # margin under DiffusionGemma's fixed 256-token canvas
N_EXAMPLES = 6  # Space dropdown examples (ICDAR, CC-BY)
# User-turn prompt; `{ocr}` is filled with the (possibly trimmed) OCR passage.
PROMPT_TEMPLATE = """\
Correct the OCR errors in the following text from a 19th-century English newspaper.
Fix only recognition errors (wrong, missing, or extra characters). Do not modernise \
spelling, do not rephrase, and do not add or remove content. Preserve the original \
punctuation unless it is clearly an OCR error.
Output only the corrected text, with no commentary or preamble.
OCR text:
{ocr}"""
# A raw decode is truncated at the earliest occurrence of any of these markers.
STOP_MARKERS = ("<turn|>", "<eos>", "<end_of_turn>", "<pad>")
def extract_answer(raw: str) -> tuple[str, str]:
    """Separate a raw model decode into ``(answer, thought)``.

    DiffusionGemma's generated block looks like
    `<|channel>thought\\n<channel|>ANSWER<turn|><eos>...` even with thinking
    off (empty thought), so the answer is whatever follows the LAST
    `<channel|>`. Gemma-4 emits plain text, in which case the decode is simply
    cut at the first stop marker and the thought is empty.

    Args:
        raw: Decoded generation, special tokens included.

    Returns:
        The stripped answer text and the stripped thought text ("" if none).
    """
    # Everything from the earliest stop marker onwards is discarded.
    cut = len(raw)
    for marker in STOP_MARKERS:
        pos = raw.find(marker)
        if pos != -1 and pos < cut:
            cut = pos
    body = raw[:cut]

    head, sep, answer = body.rpartition("<channel|>")
    if not sep:
        # No channel close marker: the whole body is the answer.
        return body.strip(), ""

    found = re.search(r"<\|channel>thought(.*)$", head, flags=re.DOTALL)
    thought = found.group(1).strip() if found else ""
    return answer.strip(), thought
# ---------------------------------------------------------------- data
| def _download(url: str, dest: Path, **kwargs) -> Path: | |
| dest.parent.mkdir(parents=True, exist_ok=True) | |
| if dest.exists(): | |
| return dest | |
| print(f"downloading {url} -> {dest}") | |
| with requests.get(url, stream=True, timeout=600, **kwargs) as r: | |
| r.raise_for_status() | |
| with dest.open("wb") as f: | |
| for chunk in r.iter_content(chunk_size=1 << 20): | |
| f.write(chunk) | |
| return dest | |
def download_bln600(workdir: Path) -> list[dict]:
    """Download + parse BLN600 into [{id, ocr_input, gold}].

    Handles both a folder layout (OCR Text/*.txt paired with Ground Truth/*.txt
    by stem) and a CSV layout ('OCR Text'/'Ground Truth' columns); the folder
    layout is tried first.

    Args:
        workdir: Directory used for the downloaded zip and the extracted tree.

    Returns:
        One dict per passage with keys ``id``, ``ocr_input`` and ``gold``.

    Raises:
        requests.HTTPError: If the figshare metadata request fails.
        RuntimeError: If the article lists no zip file, or neither layout is
            found in the extracted archive.
    """
    import csv
    from itertools import islice

    import pyzipper

    resp = requests.get(BLN600_API, timeout=60)
    resp.raise_for_status()  # fail here rather than on a confusing JSON/KeyError
    meta = resp.json()
    zips = [f for f in meta["files"] if f["name"].lower().endswith(".zip")]
    if not zips:
        raise RuntimeError(f"no zip in figshare article files: {[f['name'] for f in meta['files']]}")
    zip_path = _download(zips[0]["download_url"], workdir / zips[0]["name"])
    extract_dir = workdir / "bln600"
    if not extract_dir.exists():
        with pyzipper.AESZipFile(zip_path) as zf:
            zf.setpassword(BLN600_ZIP_PASSWORD)
            zf.extractall(extract_dir)

    # folder layout: pair OCR Text/ and Ground Truth/ files by stem.
    # Classify on the path *inside* the archive only, so a workdir whose own
    # path happens to contain "ocr"/"gold"/"ground" cannot mislabel every file.
    ocr_files: dict[str, Path] = {}
    gold_files: dict[str, Path] = {}
    for p in extract_dir.rglob("*.txt"):
        folder = p.parent.relative_to(extract_dir).as_posix().lower()
        if "ocr" in folder:
            ocr_files[p.stem] = p
        if "ground" in folder or "gold" in folder:
            gold_files[p.stem] = p
    common = sorted(set(ocr_files) & set(gold_files))
    if common:
        print(f"BLN600 folder layout: {len(common)} aligned pairs")
        return [
            {
                "id": f"bln600/{stem}",
                "ocr_input": ocr_files[stem].read_text(encoding="utf-8", errors="replace"),
                "gold": gold_files[stem].read_text(encoding="utf-8", errors="replace"),
            }
            for stem in common
        ]

    # CSV layout fallback ("utf-8-sig" also tolerates a leading BOM in the header)
    for csv_path in extract_dir.rglob("*.csv"):
        with csv_path.open(newline="", encoding="utf-8-sig", errors="replace") as f:
            rows = list(csv.DictReader(f))
        if rows and "OCR Text" in rows[0] and "Ground Truth" in rows[0]:
            print(f"BLN600 CSV layout: {len(rows)} rows from {csv_path.name}")
            return [
                {"id": f"bln600/{i}", "ocr_input": r["OCR Text"], "gold": r["Ground Truth"]}
                for i, r in enumerate(rows)
            ]

    # Show a bounded sample of the archive to make the failure debuggable.
    listing = [str(p.relative_to(extract_dir)) for p in islice(extract_dir.rglob("*"), 40)]
    raise RuntimeError(f"could not parse BLN600; archive contents: {listing}")
def download_icdar_english(workdir: Path) -> list[dict]:
    """ICDAR2019 post-OCR English subset.

    Format: per-passage .txt files with [OCR_toInput]/[OCR_aligned]/[ GS_aligned]
    lines; '@' are alignment pads and are removed from the gold text.

    Args:
        workdir: Directory used for the downloaded zip and the extracted tree.

    Returns:
        One dict per passage (``id``, ``ocr_input``, ``gold``) for every file
        under a path component starting with "EN" that has both a non-empty
        OCR line and a non-empty gold line.
    """
    import zipfile

    zip_path = _download(ICDAR_URL, workdir / "icdar2019.zip")
    extract_dir = workdir / "icdar2019"
    if not extract_dir.exists():
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(extract_dir)

    passages = []
    for p in sorted(extract_dir.rglob("*.txt")):
        # Use forward slashes so the "(^|/)EN" test (and the ids) do not depend
        # on the OS path separator.
        rel = p.relative_to(extract_dir).as_posix()
        if not re.search(r"(^|/)EN", rel):
            continue
        ocr = gold = None
        for line in p.read_text(encoding="utf-8", errors="replace").splitlines():
            if line.startswith("[OCR_toInput]"):
                ocr = line.removeprefix("[OCR_toInput]").strip()
            elif line.startswith(("[ GS_aligned]", "[GS_aligned]")):
                gold = line.split("]", 1)[1].replace("@", "").strip()
        if ocr and gold:
            passages.append({"id": f"icdar2019/{rel}", "ocr_input": ocr, "gold": gold})
    print(f"ICDAR2019 English: {len(passages)} passages")
    return passages
| def trim_pair(ocr: str, gold: str, n_tokens, max_tokens: int) -> tuple[str, str] | None: | |
| """Trim an aligned (ocr, gold) pair so both sides fit in max_tokens. | |
| Cuts at a whitespace position inside a character-aligned "equal" region so | |
| the pair stays parallel after trimming (independent token-count truncation | |
| would misalign the endings and corrupt tail CER). Returns None if no valid | |
| cut point exists. | |
| """ | |
| if n_tokens(ocr) <= max_tokens and n_tokens(gold) <= max_tokens: | |
| return ocr, gold | |
| sm = difflib.SequenceMatcher(None, ocr, gold, autojunk=False) | |
| # candidate (i_cut, j_cut) pairs, ascending: whitespace inside equal blocks | |
| candidates = [ | |
| (i1 + m.start(), j1 + m.start()) | |
| for op, i1, i2, j1, _j2 in sm.get_opcodes() | |
| if op == "equal" | |
| for m in re.finditer(r"\s", ocr[i1:i2]) | |
| ] | |
| if not candidates: | |
| return None | |
| def fits(idx: int) -> bool: | |
| i_cut, j_cut = candidates[idx] | |
| return n_tokens(ocr[:i_cut]) <= max_tokens and n_tokens(gold[:j_cut]) <= max_tokens | |
| if not fits(0): | |
| return None | |
| # token counts grow with cut position -> binary search the largest fit | |
| lo, hi = 0, len(candidates) - 1 | |
| while lo < hi: | |
| mid = (lo + hi + 1) // 2 | |
| if fits(mid): | |
| lo = mid | |
| else: | |
| hi = mid - 1 | |
| i_cut, j_cut = candidates[lo] | |
| return ocr[:i_cut].rstrip(), gold[:j_cut].rstrip() | |
def sample_passages(passages: list[dict], n: int, seed: int) -> list[dict]:
    """Pick up to ``n`` passages in a seeded random order.

    Each candidate is passed through ``trim_pair`` with the G4_MODEL tokenizer
    so that over-long pairs are align-trimmed to MAX_PASSAGE_TOKENS; candidates
    that cannot be trimmed, or whose gold text ends up under 200 characters,
    are skipped.

    Args:
        passages: Dicts with ``id``, ``ocr_input`` and ``gold``.
        n: Maximum number of passages to return.
        seed: Seed for the shuffle.

    Returns:
        New dicts (``id``, ``ocr_input``, ``gold``) for the kept passages.
    """
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(G4_MODEL)

    def n_tokens(text: str) -> int:
        return len(tokenizer(text)["input_ids"])

    shuffled = random.Random(seed).sample(passages, len(passages))  # seeded order
    kept: list[dict] = []
    trimmed_count = 0
    skipped_count = 0
    for passage in shuffled:
        if len(kept) >= n:
            break
        original_pair = (passage["ocr_input"], passage["gold"])
        pair = trim_pair(original_pair[0], original_pair[1], n_tokens, MAX_PASSAGE_TOKENS)
        # drop degenerate/too-short cuts
        if pair is None or len(pair[1]) < 200:
            skipped_count += 1
            continue
        if pair != original_pair:
            trimmed_count += 1
        kept.append({"id": passage["id"], "ocr_input": pair[0], "gold": pair[1]})
    print(
        f"sampled {len(kept)} passages ({trimmed_count} trimmed to <= {MAX_PASSAGE_TOKENS} "
        f"tokens, {skipped_count} skipped as untrimmable/too short)"
    )
    return kept
# ---------------------------------------------------------------- generation
def clean_output(text: str) -> str:
    """Strip surrounding whitespace and a leading "corrected text:" label.

    The label match is case-insensitive and the colon is optional. A note is
    printed whenever a label was removed.
    """
    body = text.strip()
    label = re.match(r"^\s*corrected text:?\s*", body, flags=re.IGNORECASE)
    if label is None:
        return body
    print(" [clean_output stripped a prefix]")
    return body[label.end():]
def count_generated_tokens(generated_ids, tokenizer) -> int:
    """Count the tokens that precede the first EOS or pad token.

    Args:
        generated_ids: Sequence-like object exposing ``tolist()``.
        tokenizer: Object exposing ``eos_token_id`` and ``pad_token_id``.

    Returns:
        Index of the first EOS/pad id, or the full length if neither occurs.
    """
    terminators = {tokenizer.eos_token_id, tokenizer.pad_token_id}
    token_ids = generated_ids.tolist()
    return next(
        (pos for pos, token_id in enumerate(token_ids) if token_id in terminators),
        len(token_ids),
    )
def run_model(model_key: str, passages: list[dict], smoke: bool) -> dict[str, dict]:
    """Load one model, run all passages, free the model. Returns id -> output dict.

    model_key "diffusiongemma_canvas" = same model, but the denoising canvas is
    initialised with the OCR text (via the undocumented `decoder_input_ids`
    hook in DiffusionGemmaForBlockDiffusion.generate) instead of random tokens
    — testing whether correction-as-denoising stays closer to the input.

    Args:
        model_key: Keys starting with "diffusiongemma" load DG_MODEL;
            "gemma4_moe" loads G4_MOE_MODEL; anything else loads G4_MODEL.
        passages: Dicts with ``id``, ``ocr_input`` and ``gold``. The first one
            is additionally used for an untimed warm-up generation.
        smoke: When true, print OCR/gold/raw/cleaned text for every passage.

    Returns:
        Mapping from passage id to a dict with ``text``, ``seconds``,
        ``tokens_generated``, ``denoising_steps``, ``tokens_per_forward`` and
        ``thought_chars``. Empty if ``passages`` is empty.
    """
    if not passages:
        # Nothing to generate: skip the model load instead of failing on the
        # warm-up's passages[0] lookup.
        return {}

    import torch
    from transformers import AutoProcessor

    canvas_init = model_key == "diffusiongemma_canvas"
    is_dg = model_key.startswith("diffusiongemma")
    if is_dg:
        from transformers import DiffusionGemmaForBlockDiffusion, TextDiffusionStreamer

        class StepCountingStreamer(TextDiffusionStreamer):
            """Counts denoising steps; suppresses the default console printing
            (the parent prints every draft with ANSI rewrites — unusable in job logs)."""

            def __init__(self, tokenizer):
                super().__init__(tokenizer=tokenizer)
                self.n_steps = 0

            def put_draft(self, value, **kwargs):
                self.n_steps += 1

            def put(self, value):
                pass

            def end(self):
                pass

        model_id = DG_MODEL
        print(f"loading {model_id} ...")
        processor = AutoProcessor.from_pretrained(model_id)
        model = DiffusionGemmaForBlockDiffusion.from_pretrained(
            model_id, dtype="auto", device_map="auto"
        )
    else:
        from transformers import AutoModelForMultimodalLM

        model_id = G4_MOE_MODEL if model_key == "gemma4_moe" else G4_MODEL
        print(f"loading {model_id} ...")
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForMultimodalLM.from_pretrained(model_id, dtype="auto", device_map="auto")
    tokenizer = processor.tokenizer
    canvas_rng = torch.Generator().manual_seed(0)  # deterministic canvas tail padding

    def generate(ocr_text: str) -> dict:
        """Run one timed generation and return the parsed output plus stats."""
        message = [{"role": "user", "content": PROMPT_TEMPLATE.format(ocr=ocr_text)}]
        inputs = processor.apply_chat_template(
            message, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
        ).to(model.device)
        input_len = inputs["input_ids"].shape[-1]
        gen_kwargs: dict = {"max_new_tokens": 256}
        streamer = None
        if is_dg:
            # generation_config defaults for the entropy sampler (no greedy equivalent)
            streamer = StepCountingStreamer(tokenizer)
            gen_kwargs["streamer"] = streamer
            if canvas_init:
                # Seed the first denoising canvas with the OCR text instead of
                # random tokens; pad the tail with random tokens as the sampler
                # would. Canvas must be exactly canvas_length wide.
                canvas_length = getattr(model.generation_config, "canvas_length", None) or 256
                ids = tokenizer(ocr_text, add_special_tokens=False)["input_ids"][:canvas_length]
                vocab = model.config.text_config.vocab_size
                pad = torch.randint(vocab, (canvas_length - len(ids),), generator=canvas_rng)
                canvas = torch.cat([torch.tensor(ids, dtype=torch.long), pad])
                gen_kwargs["decoder_input_ids"] = canvas.unsqueeze(0).to(model.device)
        else:
            gen_kwargs["do_sample"] = False  # greedy
        # Synchronise on both sides so the wall-clock covers the GPU work.
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        output = model.generate(**inputs, **gen_kwargs)
        torch.cuda.synchronize()
        seconds = time.perf_counter() - t0
        # DiffusionGemma returns a DiffusionGemmaGenerationOutput (sequences
        # includes the prompt, like AR generate); plain tensor for Gemma-4.
        seq = output.sequences if hasattr(output, "sequences") else output
        generated = seq[0][input_len:] if seq.shape[-1] > input_len else seq[0]
        tpf = getattr(output, "tokens_per_forward", None)
        if torch.is_tensor(tpf):
            tpf = int(tpf.flatten()[0])
        raw = tokenizer.decode(generated, skip_special_tokens=False)
        answer, thought = extract_answer(raw)
        if thought:
            print(f" [WARNING: thought content present ({len(thought)} chars)]")
        return {
            "text": clean_output(answer),
            "_raw": raw,
            "seconds": round(seconds, 3),
            "tokens_generated": count_generated_tokens(generated, tokenizer),
            "denoising_steps": streamer.n_steps if streamer else None,
            "tokens_per_forward": tpf,
            "thought_chars": len(thought),
        }

    results: dict[str, dict] = {}
    try:
        print("warmup generation (uncounted) ...")
        generate(passages[0]["ocr_input"])
        for i, p in enumerate(passages):
            out = generate(p["ocr_input"])
            raw = out.pop("_raw")  # kept out of the stored record; shown in smoke mode only
            results[p["id"]] = out
            print(
                f"[{model_key} {i + 1}/{len(passages)}] {out['seconds']}s, "
                f"{out['tokens_generated']} tok"
                + (f", {out['denoising_steps']} steps" if out["denoising_steps"] else "")
            )
            if smoke:
                print(f" OCR: {p['ocr_input'][:200]}")
                print(f" GOLD: {p['gold'][:200]}")
                print(f" RAW: {raw[:300]}")
                print(f" OUT: {out['text'][:200]}")
    finally:
        # Best-effort release, also attempted when a generation raises so a
        # caller that continues with another model is not left with this one.
        model = None  # noqa: F841 — drop the closure-captured ref so the GPU frees
        gc.collect()
        torch.cuda.empty_cache()
    return results
# ---------------------------------------------------------------- main
def main() -> None:
    """Command-line entry point: sample passages, run the models, save outputs.

    Steps:
      1. Parse arguments and print library / GPU versions.
      2. Load the chosen dataset and sample passages (3 in smoke mode,
         ``--n`` in full mode); optionally sample ICDAR example passages.
      3. Run every selected model over passages + examples. Only a failure of
         the "diffusiongemma_canvas" condition is tolerated; any other
         exception propagates.
      4. Write ``raw_outputs.jsonl`` (run metadata attached to the first
         record) and, if examples were requested, ``examples_cached.json``.
      5. If ``--out-repo`` is given, upload both files to that dataset repo
         after verifying that the repo is private.

    Raises:
        RuntimeError: If ``--out-repo`` names a repo that is not private.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--mode", choices=["smoke", "full"], default="smoke")
    parser.add_argument("--n", type=int, default=75)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--models", choices=["both", "dg", "g4", "g4moe", "all"], default="both")
    parser.add_argument(
        "--canvas-init",
        action="store_true",
        help="also run DiffusionGemma with the canvas initialised from the OCR text",
    )
    parser.add_argument("--dataset", choices=["bln600", "icdar"], default="bln600")
    parser.add_argument("--cache-examples", action="store_true")
    parser.add_argument("--out-repo", default=None, help="private dataset repo for raw outputs")
    parser.add_argument("--workdir", type=Path, default=Path("data"))
    args = parser.parse_args()

    import torch
    import transformers

    print(f"transformers {transformers.__version__}, torch {torch.__version__}, "
          f"cuda: {torch.cuda.get_device_name(0)}")
    loader = download_bln600 if args.dataset == "bln600" else download_icdar_english
    passages = sample_passages(loader(args.workdir), args.n if args.mode == "full" else 3, args.seed)
    print(f"running {len(passages)} passages, models={args.models}")

    examples = []
    if args.cache_examples:
        icdar = download_icdar_english(args.workdir)
        examples = sample_passages(icdar, N_EXAMPLES, args.seed)
        for e in examples:
            # Namespace example ids so they cannot collide with benchmark ids.
            e["id"] = "example/" + e["id"]

    model_keys = {
        "both": ["diffusiongemma", "gemma4"],
        "dg": ["diffusiongemma"],
        "g4": ["gemma4"],
        "g4moe": ["gemma4_moe"],
        "all": ["diffusiongemma", "gemma4", "gemma4_moe"],
    }[args.models]
    if args.canvas_init:
        model_keys.insert(1, "diffusiongemma_canvas")

    all_passages = passages + examples
    outputs: dict[str, dict[str, dict]] = {}
    for key in model_keys:
        try:
            outputs[key] = run_model(key, all_passages, smoke=args.mode == "smoke")
        except Exception as e:  # noqa: BLE001 — a failed condition shouldn't sink the others
            if key == "diffusiongemma_canvas":
                print(f"[{key} FAILED, continuing without it: {type(e).__name__}: {e}]")
            else:
                raise

    meta = {
        "date": time.strftime("%Y-%m-%d"),
        "dataset": args.dataset,
        "n": len(passages),
        "seed": args.seed,
        "max_passage_tokens": MAX_PASSAGE_TOKENS,
        "prompt": PROMPT_TEMPLATE,
        "transformers": transformers.__version__,
        "torch": torch.__version__,
        "gpu": torch.cuda.get_device_name(0),
        "generation": {
            "diffusiongemma": "generation_config defaults (entropy sampler), max_new_tokens=256",
            "diffusiongemma_canvas": "as diffusiongemma, but first canvas seeded with the OCR"
            " text via decoder_input_ids (random tail padding, seed 0)",
            "gemma4": "do_sample=False (greedy), max_new_tokens=256",
            "gemma4_moe": "do_sample=False (greedy), max_new_tokens=256",
        },
    }

    out_path = Path("raw_outputs.jsonl")
    with out_path.open("w") as f:
        for i, p in enumerate(passages):
            record = {
                "id": p["id"],
                "ocr_input": p["ocr_input"],
                "gold": p["gold"],
                # Models that failed (and were skipped) are simply absent.
                "output": {k: outputs[k][p["id"]] for k in model_keys if k in outputs},
            }
            if i == 0:
                record["meta"] = meta  # run metadata rides on the first record only
            f.write(json.dumps(record) + "\n")
    print(f"wrote {out_path} ({len(passages)} records)")

    cache_path = None
    if examples:
        cache_path = Path("examples_cached.json")
        cache_path.write_text(
            json.dumps(
                [
                    {
                        "id": e["id"],
                        "ocr_input": e["ocr_input"],
                        "gold": e["gold"],
                        "output": {k: outputs[k][e["id"]] for k in model_keys if k in outputs},
                    }
                    for e in examples
                ],
                indent=2,
            )
        )
        print(f"wrote {cache_path}")

    if args.out_repo:
        from huggingface_hub import HfApi

        api = HfApi()
        api.create_repo(args.out_repo, repo_type="dataset", private=True, exist_ok=True)
        # exist_ok=True leaves an already-existing repo's visibility untouched,
        # so confirm it really is private before uploading licence-restricted text.
        if not api.repo_info(args.out_repo, repo_type="dataset").private:
            raise RuntimeError(
                f"refusing to upload: dataset repo {args.out_repo} is not private "
                "(raw outputs may contain non-commercially licensed text)"
            )
        api.upload_file(
            path_or_fileobj=out_path, path_in_repo=out_path.name,
            repo_id=args.out_repo, repo_type="dataset",
        )
        if cache_path:
            api.upload_file(
                path_or_fileobj=cache_path, path_in_repo=cache_path.name,
                repo_id=args.out_repo, repo_type="dataset",
            )
        print(f"uploaded to https://huggingface.co/datasets/{args.out_repo} (private)")
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Xet Storage Details
- Size:
- 20.8 kB
- Xet hash:
- d371aeac2af63d6a057caeb4d105cae9e1e38ce4160f7686b79dad9457834c8a
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.