Buckets:
| # /// script | |
| # requires-python = ">=3.11" | |
| # dependencies = [ | |
| # "transformers>=5.11,<6", | |
| # "torch>=2.8,<2.10", # a100-large job image driver is CUDA 12.9; torch 2.10+ ships cu13 wheels | |
| # "torchvision", | |
| # "accelerate", | |
| # "huggingface-hub", | |
| # "pillow", | |
| # ] | |
| # /// | |
| """Image-arm vibe check: does giving DiffusionGemma the source page image help? | |
| Runs the 6 image-paired Space examples (4 Commoner 1901 regions sharing one | |
| full newspaper page, 2 Britannica 1771 leaves with long s) through three | |
| conditions, DiffusionGemma only: | |
| text OCR text -> corrected text (v1 benchmark baseline) | |
| image_text page image + OCR text -> corrected text (grounded correction) | |
| image_only page image -> transcription, anchored by the passage's opening | |
| words (512 new tokens = first multi-canvas generation test) | |
| No gold for these examples — output is qualitative, eyeballed via job logs and | |
| a JSON artifact uploaded to a private dataset repo. | |
| Run on HF Jobs: | |
| hf jobs uv run --flavor a100-large --timeout 30m \ | |
| -e HF_XET_HIGH_PERFORMANCE=1 -s HF_TOKEN \ | |
| vibe_image.py | |
| """ | |
| import json | |
| import time | |
| from pathlib import Path | |
# Hub repo ids used by main():
#   SPACE_REPO - Space that examples.json and the page images are downloaded from
#   DG_MODEL   - checkpoint loaded for both the processor and the model
#   OUT_REPO   - dataset repo (created private) that receives the JSON artifact
SPACE_REPO = "davanstrien/diffusiongemma-ocr-correction"
DG_MODEL = "google/diffusiongemma-26B-A4B-it"
OUT_REPO = "davanstrien/diffusiongemma-image-vibe"
# NOTE(review): the three prompt literals below are reproduced line-for-line
# from the extracted source; blank lines inside them may have been dropped by
# the extraction — compare against the upstream file before relying on the
# exact prompt text.
#
# Prompt for the "text" condition (also used for the warmup generation).
# {ocr} is filled with the example's "ocr_input".
CORRECT_PROMPT = """\
Correct the OCR errors in the following text from a historical printed page.
Fix only recognition errors (wrong, missing, or extra characters). Do not modernise \
spelling, do not rephrase, and do not add or remove content. Preserve the original \
punctuation unless it is clearly an OCR error.
Output only the corrected text, with no commentary or preamble.
OCR text:
{ocr}"""
# Prompt for the "image_text" condition: same instructions as CORRECT_PROMPT,
# but tells the model to treat the accompanying page image as the reference.
GROUNDED_PROMPT = """\
The image shows the printed source page for the OCR text below.
Correct the OCR errors in the text, using the image as the authoritative reference. \
Fix only recognition errors (wrong, missing, or extra characters). Do not modernise \
spelling, do not rephrase, and do not add or remove content. Preserve the original \
punctuation unless it is clearly an OCR error.
Output only the corrected text, with no commentary or preamble.
OCR text:
{ocr}"""
# Prompt for the "image_only" condition. {anchor} is filled with the first
# eight whitespace-separated words of the example's OCR text.
TRANSCRIBE_PROMPT = """\
Transcribe the passage in the image that begins with the words: "{anchor}".
Reproduce the printed text exactly as it appears, preserving original spelling \
and punctuation. Output only the transcription, with no commentary or preamble."""
# Markers that end the usable part of a decoded generation.
STOP_MARKERS = ("<turn|>", "<eos>", "<end_of_turn>", "<pad>")


def extract_answer(raw: str) -> str:
    """Pull the answer text out of a decoded generation.

    The text is first truncated at the earliest occurrence of any marker in
    ``STOP_MARKERS``. Of what remains, only the part after the final
    ``<channel|>`` tag is kept (everything, when the tag is absent), and
    surrounding whitespace is stripped. Mirrors the extraction used in
    benchmark.py.
    """
    cut = len(raw)
    for marker in STOP_MARKERS:
        pos = raw.find(marker)
        if pos != -1 and pos < cut:
            cut = pos
    head = raw[:cut]
    # rpartition returns the whole string as its last element when the
    # separator is missing, so no membership check is needed.
    return head.rpartition("<channel|>")[2].strip()
def main() -> None:
    """Run every image-paired Space example through the three conditions.

    Steps, in order:

    1. Download ``examples.json`` from ``SPACE_REPO`` and keep the entries
       that have a truthy ``"image"`` field; download each referenced image
       once and attach it to its examples as ``e["_image"]``.
    2. Load the processor and model for ``DG_MODEL``.
    3. Run one warmup generation whose result is discarded.
    4. For every example and every condition (``text``, ``image_text``,
       ``image_only``) generate, print a preview, and collect the result.
       A generation that raises is recorded as ``{"error": ...}`` for that
       condition and the run continues.
    5. Write all records to ``image_vibe_outputs.json`` and upload that file
       to the dataset repo ``OUT_REPO``.

    Raises:
        SystemExit: if ``examples.json`` contains no image-paired example
            (checked before the model is loaded).
    """
    # GPU / Hub dependencies are imported only when the script actually runs.
    import torch
    from huggingface_hub import HfApi, hf_hub_download
    from PIL import Image
    from transformers import (
        AutoProcessor,
        DiffusionGemmaForBlockDiffusion,
        TextDiffusionStreamer,
    )

    class StepCountingStreamer(TextDiffusionStreamer):
        """Streamer that counts ``put_draft`` calls and outputs nothing."""

        def __init__(self, tokenizer):
            super().__init__(tokenizer=tokenizer)
            self.n_steps = 0

        def put_draft(self, value, **kwargs):
            # Each call is counted; the total is reported as "denoising_steps".
            self.n_steps += 1

        def put(self, value):
            pass

        def end(self):
            pass

    examples_path = hf_hub_download(SPACE_REPO, "examples.json", repo_type="space")
    # Decode explicitly as UTF-8: the examples contain non-ASCII characters
    # (e.g. long s) and must not depend on the locale's default encoding.
    all_examples = json.loads(Path(examples_path).read_text(encoding="utf-8"))
    examples = [e for e in all_examples if e.get("image")]
    print(f"{len(examples)} image-paired examples")
    if not examples:
        # Fail before loading the model instead of with an IndexError at the
        # warmup step.
        raise SystemExit("no image-paired examples found in examples.json")

    # Several examples can reference the same page, so images are cached by
    # their path inside the Space.
    images: dict[str, Image.Image] = {}
    for e in examples:
        fname = "images/" + e["image"].removeprefix("static/")
        if fname not in images:
            p = hf_hub_download(SPACE_REPO, fname, repo_type="space")
            # convert() returns a new, fully loaded image, so the file handle
            # behind the opened image can be closed right away.
            with Image.open(p) as page:
                images[fname] = page.convert("RGB")
            print(f"{fname}: {images[fname].size}")
        e["_image"] = images[fname]

    print(f"loading {DG_MODEL} ...")
    processor = AutoProcessor.from_pretrained(DG_MODEL)
    model = DiffusionGemmaForBlockDiffusion.from_pretrained(
        DG_MODEL, dtype="auto", device_map="auto"
    )
    tokenizer = processor.tokenizer

    def generate(content: list[dict], max_new_tokens: int) -> dict:
        """Generate for one single-turn user message and time the call.

        Returns a dict with the extracted answer (``text``), the decoded
        generation including special tokens (``raw``), wall-clock ``seconds``,
        the streamer's ``denoising_steps`` count and ``prompt_tokens``.
        """
        message = [{"role": "user", "content": content}]
        inputs = processor.apply_chat_template(
            message,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)
        input_len = inputs["input_ids"].shape[-1]
        streamer = StepCountingStreamer(tokenizer)
        # Synchronize on both sides so the timer covers the GPU work itself.
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        output = model.generate(**inputs, max_new_tokens=max_new_tokens, streamer=streamer)
        torch.cuda.synchronize()
        seconds = time.perf_counter() - t0
        seq = output.sequences if hasattr(output, "sequences") else output
        # Drop the prompt tokens when the returned sequence is longer than the
        # prompt; otherwise decode the sequence as returned.
        generated = seq[0][input_len:] if seq.shape[-1] > input_len else seq[0]
        raw = tokenizer.decode(generated, skip_special_tokens=False)
        return {
            "text": extract_answer(raw),
            "raw": raw,
            "seconds": round(seconds, 3),
            "denoising_steps": streamer.n_steps,
            "prompt_tokens": input_len,
        }

    def conditions(e: dict) -> dict[str, tuple[list[dict], int]]:
        """Map each condition name to ``(message content, max_new_tokens)``."""
        ocr = e["ocr_input"]
        # First eight whitespace-separated words locate the passage on the page.
        anchor = " ".join(ocr.split()[:8])
        img = e["_image"]
        # model card: image before text for multimodal prompts
        return {
            "text": ([{"type": "text", "text": CORRECT_PROMPT.format(ocr=ocr)}], 256),
            "image_text": (
                [
                    {"type": "image", "image": img},
                    {"type": "text", "text": GROUNDED_PROMPT.format(ocr=ocr)},
                ],
                256,
            ),
            "image_only": (
                [
                    {"type": "image", "image": img},
                    {"type": "text", "text": TRANSCRIBE_PROMPT.format(anchor=anchor)},
                ],
                512,
            ),
        }

    print("warmup generation (uncounted) ...")
    generate([{"type": "text", "text": CORRECT_PROMPT.format(ocr=examples[0]["ocr_input"])}], 256)

    results = []
    for e in examples:
        rec = {"id": e["id"], "label": e["label"], "ocr_input": e["ocr_input"], "output": {}}
        for cond, (content, max_new) in conditions(e).items():
            try:
                out = generate(content, max_new)
            except Exception as exc:  # surface per-condition failures, keep going
                print(f"[{e['id']} | {cond}] FAILED: {type(exc).__name__}: {exc}")
                rec["output"][cond] = {"error": f"{type(exc).__name__}: {exc}"}
                continue
            rec["output"][cond] = out
            print(
                f"\n[{e['id']} | {cond}] {out['seconds']}s, {out['denoising_steps']} steps, "
                f"{out['prompt_tokens']} prompt tok"
            )
            print(f" OUT: {out['text'][:400]}")
        # NOTE(review): indentation was lost in the extracted source; the OCR
        # preview is printed once per example here — confirm against upstream.
        ocr_preview = e["ocr_input"][:400]
        print(f" OCR: {ocr_preview}")
        results.append(rec)

    out_path = Path("image_vibe_outputs.json")
    # ensure_ascii=False keeps non-ASCII characters literal, so the file must
    # be written as UTF-8 regardless of the locale.
    out_path.write_text(json.dumps(results, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"\nwrote {out_path}")

    api = HfApi()
    api.create_repo(OUT_REPO, repo_type="dataset", private=True, exist_ok=True)
    api.upload_file(
        path_or_fileobj=out_path,
        path_in_repo=out_path.name,
        repo_id=OUT_REPO,
        repo_type="dataset",
    )
    print(f"uploaded to https://huggingface.co/datasets/{OUT_REPO} (private)")


if __name__ == "__main__":
    main()
Xet Storage Details
- Size:
- 7.76 kB
- Xet hash:
- a88e66f42707eb1545d83a1c496811c6300f9c860d7ce45f85a10cea2c15cc43
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.