davanstrien's picture
download
raw
7.76 kB
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "transformers>=5.11,<6",
# "torch>=2.8,<2.10", # a100-large job image driver is CUDA 12.9; torch 2.10+ ships cu13 wheels
# "torchvision",
# "accelerate",
# "huggingface-hub",
# "pillow",
# ]
# ///
"""Image-arm vibe check: does giving DiffusionGemma the source page image help?
Runs the 6 image-paired Space examples (4 Commoner 1901 regions sharing one
full newspaper page, 2 Britannica 1771 leaves with long s) through three
conditions, DiffusionGemma only:

    text        OCR text -> corrected text (v1 benchmark baseline)
    image_text  page image + OCR text -> corrected text (grounded correction)
    image_only  page image -> transcription, anchored by the passage's opening
                words (512 new tokens = first multi-canvas generation test)

No gold for these examples — output is qualitative, eyeballed via job logs and
a JSON artifact uploaded to a private dataset repo.
Run on HF Jobs:
hf jobs uv run --flavor a100-large --timeout 30m \
-e HF_XET_HIGH_PERFORMANCE=1 -s HF_TOKEN \
vibe_image.py
"""
import json
import time
from pathlib import Path
# Hub locations: the Space holding examples.json and the page images, the model
# under test, and the dataset repo that receives the JSON artifact.
SPACE_REPO = "davanstrien/diffusiongemma-ocr-correction"
DG_MODEL = "google/diffusiongemma-26B-A4B-it"
OUT_REPO = "davanstrien/diffusiongemma-image-vibe"

# Text-only correction prompt; `{ocr}` is filled with the raw OCR text.
CORRECT_PROMPT = (
    "Correct the OCR errors in the following text from a historical printed page.\n"
    "Fix only recognition errors (wrong, missing, or extra characters). "
    "Do not modernise spelling, do not rephrase, and do not add or remove content. "
    "Preserve the original punctuation unless it is clearly an OCR error.\n"
    "Output only the corrected text, with no commentary or preamble.\n"
    "OCR text:\n"
    "{ocr}"
)

# Image-grounded correction prompt; same rules as CORRECT_PROMPT, but the page
# image is declared the authoritative reference. `{ocr}` is the raw OCR text.
GROUNDED_PROMPT = (
    "The image shows the printed source page for the OCR text below.\n"
    "Correct the OCR errors in the text, using the image as the authoritative reference. "
    "Fix only recognition errors (wrong, missing, or extra characters). "
    "Do not modernise spelling, do not rephrase, and do not add or remove content. "
    "Preserve the original punctuation unless it is clearly an OCR error.\n"
    "Output only the corrected text, with no commentary or preamble.\n"
    "OCR text:\n"
    "{ocr}"
)

# Image-only transcription prompt; `{anchor}` is the passage's opening words,
# used to point the model at the right region of the page.
TRANSCRIBE_PROMPT = (
    'Transcribe the passage in the image that begins with the words: "{anchor}".\n'
    "Reproduce the printed text exactly as it appears, preserving original spelling "
    "and punctuation. Output only the transcription, with no commentary or preamble."
)
# Special-token strings that terminate the answer in the decoded output.
STOP_MARKERS = (
    "<turn|>",
    "<eos>",
    "<end_of_turn>",
    "<pad>",
)


def extract_answer(raw: str) -> str:
    """Pull the model's answer out of a decoded generation.

    The text is first truncated at the earliest occurrence of any stop marker;
    from what remains, everything after the LAST ``<channel|>`` is kept (the
    whole remainder if the marker is absent). The result is whitespace-stripped.
    Mirrors the extraction used in benchmark.py.
    """
    cut = len(raw)
    for marker in STOP_MARKERS:
        pos = raw.find(marker)
        # find() returns -1 when absent; only a real, earlier hit moves the cut.
        if 0 <= pos < cut:
            cut = pos
    kept = raw[:cut]
    # rsplit with maxsplit=1 yields the tail after the last marker, or the
    # whole string when the marker does not occur.
    return kept.rsplit("<channel|>", 1)[-1].strip()
def main() -> None:
    """Run the image-paired examples through all three conditions and upload the results.

    Steps:
      1. Download ``examples.json`` from the Space and keep entries that have an
         ``image`` field; download each distinct page image once.
      2. Load the DiffusionGemma processor and model.
      3. Do one uncounted warmup generation, then run every example through the
         ``text``, ``image_text`` and ``image_only`` conditions, printing a
         preview of each output next to the OCR input.
      4. Write all records to ``image_vibe_outputs.json`` and upload that file
         to the private dataset repo ``OUT_REPO``.

    Raises:
        SystemExit: if ``examples.json`` contains no image-paired examples
            (checked before the model is loaded).
    """
    # Third-party imports live inside main(); the module top level only needs
    # the standard library.
    import torch
    from huggingface_hub import HfApi, hf_hub_download
    from PIL import Image
    from transformers import (
        AutoProcessor,
        DiffusionGemmaForBlockDiffusion,
        TextDiffusionStreamer,
    )

    class StepCountingStreamer(TextDiffusionStreamer):
        """Streamer that emits nothing and only counts ``put_draft`` calls."""

        def __init__(self, tokenizer):
            super().__init__(tokenizer=tokenizer)
            self.n_steps = 0

        def put_draft(self, value, **kwargs):
            # Presumably invoked once per denoising step — the count is what is
            # reported as "denoising_steps"; confirm against the streamer API.
            self.n_steps += 1

        def put(self, value):
            pass

        def end(self):
            pass

    def sync() -> None:
        """Wait for pending CUDA work so wall-clock timings are meaningful."""
        # Guarded so a CPU-only run does not raise on the synchronize call.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    examples_path = hf_hub_download(SPACE_REPO, "examples.json", repo_type="space")
    # Explicit UTF-8: the examples hold historical text with non-ASCII characters,
    # so decoding must not depend on the locale's default encoding.
    all_examples = json.loads(Path(examples_path).read_text(encoding="utf-8"))
    examples = [e for e in all_examples if e.get("image")]
    print(f"{len(examples)} image-paired examples")
    if not examples:
        # Fail before loading the model instead of with an IndexError at warmup.
        raise SystemExit("no image-paired examples found in examples.json; nothing to run")

    # Several examples share one page image; download and decode each file once.
    images: dict[str, Image.Image] = {}
    for e in examples:
        fname = "images/" + e["image"].removeprefix("static/")
        if fname not in images:
            p = hf_hub_download(SPACE_REPO, fname, repo_type="space")
            # convert() returns a fully loaded copy, so the file can be closed.
            with Image.open(p) as src:
                images[fname] = src.convert("RGB")
            print(f"{fname}: {images[fname].size}")
        e["_image"] = images[fname]

    print(f"loading {DG_MODEL} ...")
    processor = AutoProcessor.from_pretrained(DG_MODEL)
    model = DiffusionGemmaForBlockDiffusion.from_pretrained(
        DG_MODEL, dtype="auto", device_map="auto"
    )
    tokenizer = processor.tokenizer

    def generate(content: list[dict], max_new_tokens: int) -> dict:
        """Run one single-turn user message through the model.

        Args:
            content: chat-template content parts (text and/or image dicts).
            max_new_tokens: generation budget passed to ``model.generate``.

        Returns:
            Dict with the extracted answer (``text``), the undecorated decode
            (``raw``), wall-clock ``seconds``, ``denoising_steps`` counted by
            the streamer, and ``prompt_tokens``.
        """
        message = [{"role": "user", "content": content}]
        inputs = processor.apply_chat_template(
            message, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
        ).to(model.device)
        input_len = inputs["input_ids"].shape[-1]
        streamer = StepCountingStreamer(tokenizer)
        sync()
        t0 = time.perf_counter()
        output = model.generate(**inputs, max_new_tokens=max_new_tokens, streamer=streamer)
        sync()
        seconds = time.perf_counter() - t0
        # generate() may hand back an object exposing `.sequences` or the
        # sequence tensor itself; accept both.
        seq = output.sequences if hasattr(output, "sequences") else output
        # Drop the prompt when the sequence is longer than it; otherwise decode
        # everything (presumably the output then holds only new tokens).
        generated = seq[0][input_len:] if seq.shape[-1] > input_len else seq[0]
        # Special tokens are kept so extract_answer can find its markers.
        raw = tokenizer.decode(generated, skip_special_tokens=False)
        return {
            "text": extract_answer(raw),
            "raw": raw,
            "seconds": round(seconds, 3),
            "denoising_steps": streamer.n_steps,
            "prompt_tokens": input_len,
        }

    def conditions(e: dict) -> dict[str, tuple[list[dict], int]]:
        """Map each condition name to its (message content, max_new_tokens) pair."""
        ocr = e["ocr_input"]
        # The first eight whitespace-separated OCR words anchor the transcription.
        anchor = " ".join(ocr.split()[:8])
        img = e["_image"]
        # model card: image before text for multimodal prompts
        return {
            "text": ([{"type": "text", "text": CORRECT_PROMPT.format(ocr=ocr)}], 256),
            "image_text": (
                [{"type": "image", "image": img},
                 {"type": "text", "text": GROUNDED_PROMPT.format(ocr=ocr)}],
                256,
            ),
            "image_only": (
                [{"type": "image", "image": img},
                 {"type": "text", "text": TRANSCRIBE_PROMPT.format(anchor=anchor)}],
                512,
            ),
        }

    print("warmup generation (uncounted) ...")
    generate([{"type": "text", "text": CORRECT_PROMPT.format(ocr=examples[0]["ocr_input"])}], 256)

    results = []
    for e in examples:
        rec = {"id": e["id"], "label": e["label"], "ocr_input": e["ocr_input"], "output": {}}
        for cond, (content, max_new) in conditions(e).items():
            try:
                out = generate(content, max_new)
            except Exception as exc:  # surface per-condition failures, keep going
                print(f"[{e['id']} | {cond}] FAILED: {type(exc).__name__}: {exc}")
                rec["output"][cond] = {"error": f"{type(exc).__name__}: {exc}"}
                continue
            rec["output"][cond] = out
            print(
                f"\n[{e['id']} | {cond}] {out['seconds']}s, {out['denoising_steps']} steps, "
                f"{out['prompt_tokens']} prompt tok"
            )
            print(f" OUT: {out['text'][:400]}")
            ocr_preview = e["ocr_input"][:400]
            print(f" OCR: {ocr_preview}")
        results.append(rec)

    out_path = Path("image_vibe_outputs.json")
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # written as UTF-8 regardless of the locale.
    out_path.write_text(json.dumps(results, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"\nwrote {out_path}")

    api = HfApi()
    api.create_repo(OUT_REPO, repo_type="dataset", private=True, exist_ok=True)
    api.upload_file(
        path_or_fileobj=out_path, path_in_repo=out_path.name,
        repo_id=OUT_REPO, repo_type="dataset",
    )
    print(f"uploaded to https://huggingface.co/datasets/{OUT_REPO} (private)")


if __name__ == "__main__":
    main()

Xet Storage Details

Size:
7.76 kB
·
Xet hash:
a88e66f42707eb1545d83a1c496811c6300f9c860d7ce45f85a10cea2c15cc43

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.