Spaces:
Running on Zero
Running on Zero
| """ | |
| Gemma Diffusion — live website builder (gradio.Server backend + custom frontend). | |
| ZeroGPU port. `gradio.Server` (a FastAPI subclass) gives us Gradio's queue + SSE | |
| streaming while we serve our own hand-written HTML/CSS/JS frontend. The single | |
| streaming endpoint `/generate` runs the block-diffusion model and yields JSON frames | |
| (one per denoising step) that the frontend renders side-by-side: the raw HTML canvas | |
| diffusing on the left, the live rendered page on the right. | |
| ZeroGPU specifics: | |
| - `import spaces` happens before `torch`. | |
| - The model is loaded once at module scope with `.to("cuda")` (ZeroGPU registers it). | |
| - The actual `model.generate` call lives inside the `@spaces.GPU` function `_gpu_stream`; | |
| the `gradio.Server` endpoint only marshals picklable CPU tensors in/out of it. | |
| Refs: | |
| - https://huggingface.co/blog/introducing-gradio-server | |
| - https://huggingface.co/docs/hub/spaces-zerogpu | |
| """ | |
| import glob | |
| import os | |
| import subprocess | |
| import sys | |
| # Set before torch is imported (transformers pulls torch in). | |
| os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") | |
| import spaces # must precede torch so ZeroGPU can patch it | |
def _has_diffusion_gemma(module) -> bool:
    """Return True when *module* exposes the custom DiffusionGemma model code."""
    return hasattr(module, "DiffusionGemmaForBlockDiffusion") or hasattr(
        getattr(module, "models", object), "diffusion_gemma"
    )


def _ensure_transformers():
    """Install the bundled custom DiffusionGemma `transformers` wheel at runtime.

    Spaces installs `requirements.txt` *before* copying the repo files into the image,
    so the wheel can't be referenced by local path there. By the time this app runs the
    file is present next to this module, so we install it here (only if a stock / no
    transformers is importable) before importing torch/transformers below.

    Returns None in every case. Does nothing when a DiffusionGemma-capable
    `transformers` is already importable or when no `transformers-*.whl` sits next to
    this file.

    Raises:
        subprocess.CalledProcessError: if `pip install` of the wheel fails.
    """
    try:
        import transformers  # noqa: F401

        if _has_diffusion_gemma(transformers):
            return
    except Exception:
        # Best effort: a missing or broken install falls through to the wheel install.
        pass

    here = os.path.dirname(os.path.abspath(__file__))
    wheels = sorted(glob.glob(os.path.join(here, "transformers-*.whl")))
    if not wheels:
        return

    print(f"[gdiff] Installing bundled transformers wheel: {os.path.basename(wheels[0])}", flush=True)
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-cache-dir", wheels[0]])

    import importlib

    # The probe import above may have left the *stock* package cached in sys.modules.
    # invalidate_caches() only resets the path finders, so without evicting those
    # entries the later `from transformers import ...` would still see the old package.
    for name in list(sys.modules):
        if name == "transformers" or name.startswith("transformers."):
            del sys.modules[name]
    importlib.invalidate_caches()


_ensure_transformers()
| import json | |
| import queue as queue_lib | |
| import re | |
| import threading | |
| import time as _time | |
| import torch | |
| from fastapi.responses import HTMLResponse | |
| from gradio import Server | |
| from transformers import AutoTokenizer, DiffusionGemmaForBlockDiffusion | |
| from transformers.generation.streamers import BaseStreamer | |
# Directory that contains this module.
HERE = os.path.dirname(os.path.abspath(__file__))

# Checkpoint id/path and Hub token; both can be overridden from the environment.
MODEL_PATH = os.getenv("GDIFF_MODEL_PATH", "google/diffusiongemma-26B-A4B-it")
HF_TOKEN = os.getenv("HF_TOKEN")

MAX_ITERS_CAP = 120  # hard cap on denoising steps per block

# ZeroGPU: the 26B checkpoint (~49 GB bf16) needs the full backing card.
GPU_SIZE = os.getenv("GDIFF_GPU_SIZE", "xlarge")

# System turn placed in front of every conversation sent to the model.
SYSTEM_PROMPT = "".join(
    (
        "You are an expert front-end web developer with great visual taste. When asked to ",
        "build or change a web page, respond with a SINGLE, complete, self-contained HTML5 ",
        "document. Put all CSS in a <style> tag and any JavaScript in a <script> tag inside ",
        "the document. Do not load external assets. When asked to modify an existing page, ",
        "return the FULL updated HTML document with the change applied. Do not include ",
        "explanations or markdown code fences — output only raw HTML, starting with ",
        "<!DOCTYPE html>.",
    )
)

# Chat-template control markers (e.g. <|channel|>, <turn>, <tool_call|>) to strip from output.
_MARKER_RE = re.compile(
    r"<\|?(?:channel|turn|think|image|audio|video|tool(?:_call|_response)?)\|?>"
)

# Body of a ``` / ```html fenced block, with surrounding whitespace trimmed.
_FENCE_RE = re.compile(r"```(?:html)?\s*(.*?)\s*```", re.DOTALL)
| # --------------------------------------------------------------------------- # | |
| # Model (loaded once at module scope; ZeroGPU registers .to("cuda") tensors) | |
| # --------------------------------------------------------------------------- # | |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[gdiff] Loading model from {MODEL_PATH} on {DEVICE} ...", flush=True)

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=HF_TOKEN)

# Load the checkpoint in bf16, then move it to the target device in a separate step.
model = DiffusionGemmaForBlockDiffusion.from_pretrained(
    MODEL_PATH, dtype=torch.bfloat16, low_cpu_mem_usage=True, token=HF_TOKEN
)
model = model.to(DEVICE)
model.eval()

# Canvas size comes from the model config; padding falls back to id 0 when the
# tokenizer reports no (or a zero) pad token id.
CANVAS_LEN = model.config.canvas_length
PAD_ID = tokenizer.pad_token_id or 0
print(f"[gdiff] Model ready | canvas_length={CANVAS_LEN}", flush=True)

# Cache of the last *cleaned* page so a follow-up tweak can warm-start in place.
model._last_clean_html = None
| # --------------------------------------------------------------------------- # | |
| # Helpers (CPU-only; safe to run in the gradio.Server main process) | |
| # --------------------------------------------------------------------------- # | |
def warm_canvas_from_cache():
    """Build the first-block starting canvas from the previously cached clean page.

    The cached HTML is re-tokenized (rather than reusing raw output tokens) so a
    mangled header can't compound across tweaks. The ids are truncated to
    ``CANVAS_LEN`` and right-padded with ``PAD_ID`` up to exactly that length.

    Returns:
        A CPU ``torch.long`` tensor of shape ``(1, CANVAS_LEN)`` (it is pickled across
        the ZeroGPU process boundary and moved to CUDA inside the GPU worker), or
        ``None`` when nothing is cached or the cached page tokenizes to nothing.
    """
    cached_page = getattr(model, "_last_clean_html", None)
    if not cached_page:
        return None

    token_ids = tokenizer(cached_page, add_special_tokens=False).input_ids[:CANVAS_LEN]
    if not token_ids:
        return None

    shortfall = CANVAS_LEN - len(token_ids)
    if shortfall > 0:
        token_ids = token_ids + [PAD_ID] * shortfall

    # Wrapping in a list adds the leading batch dimension.
    return torch.tensor([token_ids], dtype=torch.long)
def last_assistant_html(history_json: str):
    """Return the content of the most recent non-empty assistant turn, if any.

    Args:
        history_json: JSON-encoded list of ``{"role": ..., "content": ...}`` turns as
            sent by the client. Treated as untrusted input.

    Returns:
        The ``content`` of the last turn whose role is ``"assistant"`` and whose
        content is truthy, or ``None`` when there is no such turn or the payload is
        empty, not valid JSON, not a JSON list, or not a string at all.
    """
    if not history_json:
        return None
    try:
        history = json.loads(history_json)
    except (TypeError, ValueError):  # JSONDecodeError is a ValueError subclass
        return None
    if not isinstance(history, list):
        return None

    for turn in reversed(history):
        # Skip anything that is not a turn object instead of crashing on `.get`.
        if not isinstance(turn, dict):
            continue
        if turn.get("role") == "assistant" and turn.get("content"):
            return turn["content"]
    return None
def clean_text(text: str) -> str:
    """Remove every ``_MARKER_RE`` match from *text*, then strip leading whitespace."""
    without_markers = _MARKER_RE.sub("", text)
    return without_markers.lstrip()
_ANCHOR_FLAGS = re.ASCII | re.IGNORECASE

# Structural anchors in priority order, each paired with the header text that has to be
# rebuilt in front of it. `<head(?!er)` keeps a `<header>` element from being mistaken
# for the document `<head>`.
_HTML_ANCHORS = (
    (re.compile(r"<!doctype", _ANCHOR_FLAGS), ""),
    (re.compile(r"<html", _ANCHOR_FLAGS), "<!DOCTYPE html>\n"),
    (re.compile(r"<head(?!er)", _ANCHOR_FLAGS), '<!DOCTYPE html>\n<html lang="en">\n'),
    (
        re.compile(r"<body", _ANCHOR_FLAGS),
        '<!DOCTYPE html>\n<html lang="en">\n<head><meta charset="UTF-8">'
        '<meta name="viewport" content="width=device-width, initial-scale=1.0"></head>\n',
    ),
)


def extract_html(text: str) -> str:
    """Pull a usable HTML document out of the (possibly mangled) model output.

    Anchor on the first intact structural tag and rebuild whatever the diffused tweak
    ate off the front, so the result is always a valid document (never quirks mode and
    never a broken ``DOCTYPE>`` / ``html lang=`` header).

    Control markers are stripped first, and the body of a Markdown code fence is used
    when one is present. Anchors are tried by kind, not by position: doctype, then
    ``<html``, then ``<head``, then ``<body``.

    Args:
        text: Raw decoded model output.

    Returns:
        The document from the chosen anchor onward with any missing header prepended,
        or the stripped text unchanged when no anchor is found.
    """
    text = clean_text(text)
    fenced = _FENCE_RE.search(text)
    if fenced:
        text = fenced.group(1)

    # Search the original text case-insensitively rather than indexing it with offsets
    # taken from text.lower(): lowercasing can change the length of non-ASCII text,
    # which would shift every offset after such a character.
    for pattern, prefix in _HTML_ANCHORS:
        found = pattern.search(text)
        if found:
            return prefix + text[found.start():].strip()
    return text.strip()
class QueueDiffusionStreamer(BaseStreamer):
    """Streamer that pushes generation progress onto a queue.

    Every item placed on the queue is a ``(kind, text, block, step)`` tuple where
    ``kind`` is ``"commit"``, ``"draft"`` or ``"end"`` and ``text`` is the decoded
    output so far (special tokens skipped).

    NOTE(review): ``put_draft`` is not called anywhere in this file; presumably the
    model's ``generate`` invokes it once per denoising step — confirm against the
    bundled transformers build.
    """

    def __init__(self, tok, q: "queue_lib.Queue"):
        self.tok = tok
        self.q = q
        self.confirmed_ids: list[int] = []  # ids of all committed blocks so far
        self.prompt_skipped = False  # flips to True after the first `put`
        self.block = 0  # number of committed blocks
        self.step = 0  # drafts seen since the last commit

    @staticmethod
    def _as_id_list(value):
        """Return *value* as a flat list of ids, taking row 0 when it has >1 dims."""
        row = value[0] if value.dim() > 1 else value
        return row.tolist()

    def _decode(self, ids):
        return self.tok.decode(ids, skip_special_tokens=True)

    def put(self, value):
        """Record a finished block and emit a ``commit`` frame.

        The very first call is discarded (it is treated as the prompt).
        """
        new_ids = self._as_id_list(value)
        if not self.prompt_skipped:
            self.prompt_skipped = True
            return
        self.confirmed_ids.extend(new_ids)
        self.block += 1
        self.step = 0
        frame = ("commit", self._decode(self.confirmed_ids), self.block, self.step)
        self.q.put(frame)

    def put_draft(self, value):
        """Emit a ``draft`` frame: committed text plus the in-progress block."""
        self.step += 1
        draft_ids = self._as_id_list(value)
        preview = self._decode(self.confirmed_ids + draft_ids)
        # The draft belongs to the block that is about to be committed, hence +1.
        self.q.put(("draft", preview, self.block + 1, self.step))

    def end(self):
        """Emit the terminating ``end`` frame carrying the committed text."""
        final_text = self._decode(self.confirmed_ids)
        self.q.put(("end", final_text, self.block, self.step))
def build_messages(history_json: str, prompt: str):
    """Assemble the chat message list: system prompt, prior turns, then the new prompt.

    Args:
        history_json: JSON-encoded list of ``{"role": ..., "content": ...}`` turns as
            sent by the client. Treated as untrusted input: an empty, non-string,
            undecodable or non-list payload is treated as an empty history, and
            entries that are not objects are skipped.
        prompt: The new user request; appended as the final ``user`` message.

    Returns:
        A list of ``{"role", "content"}`` dicts starting with the system message.
        Only history turns whose role is ``"user"`` or ``"assistant"`` and whose
        content is truthy are kept.
    """
    try:
        history = json.loads(history_json) if history_json else []
    except (TypeError, ValueError):  # JSONDecodeError is a ValueError subclass
        history = []
    if not isinstance(history, list):
        history = []

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages.extend(
        {"role": turn["role"], "content": turn["content"]}
        for turn in history
        if isinstance(turn, dict)
        and turn.get("role") in ("user", "assistant")
        and turn.get("content")
    )
    messages.append({"role": "user", "content": prompt})
    return messages
| # --------------------------------------------------------------------------- # | |
| # GPU work — runs in a forked ZeroGPU worker process. | |
| # Inputs/outputs cross the boundary via pickle, so only CPU tensors / plain | |
| # Python objects go in and out (no CUDA tensors are returned). | |
| # --------------------------------------------------------------------------- # | |
def _estimate_duration(input_ids, max_new_tokens=2048, max_iters=64, full_denoise=False, canvas_ids=None):
    """Estimate the GPU seconds a generation needs, capped at 120.

    The parameter list mirrors ``_gpu_stream``; only ``max_new_tokens`` and
    ``max_iters`` influence the result. The estimate is a 30 s base plus 0.3 s per
    denoising step per block, where the block count is ``max_new_tokens`` floor-divided
    by the canvas length (at least one).

    NOTE(review): no call site is visible in this file; presumably this is passed to
    ZeroGPU as a dynamic duration callback — confirm.
    """
    canvas = max(1, CANVAS_LEN)
    n_blocks = max(1, int(max_new_tokens) // canvas)
    estimate = 30 + n_blocks * int(max_iters) * 0.3
    # xlarge internally doubles this for the quota check
    return int(min(120, estimate))
def _gpu_stream(input_ids, max_new_tokens, max_iters, full_denoise, canvas_ids):
    """Run ``model.generate`` on a background thread and yield its progress frames.

    Args:
        input_ids: Prompt token ids; moved to the model's device here.
        max_new_tokens: Generation budget, passed as ``max_new_tokens``.
        max_iters: Passed as ``max_denoising_steps``.
        full_denoise: When true, sets ``confidence_threshold=1e-9`` and
            ``stability_threshold=max_iters`` on the generate call.
        canvas_ids: Optional starting canvas; moved to the model's device when given.

    Yields:
        ``(kind, text, block, step)`` tuples taken from the streamer queue. A failure
        inside ``generate`` is yielded as a single ``("error", message, 0, 0)`` tuple,
        after which the generator stops. ``"end"`` frames are consumed, not yielded.

    NOTE(review): the module docstring describes this as the ``@spaces.GPU`` function,
    but no decorator is visible on it in this view — confirm.
    """
    device = model.device
    input_ids = input_ids.to(device)

    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "max_denoising_steps": int(max_iters),
    }
    if full_denoise:
        gen_kwargs.update(confidence_threshold=1e-9, stability_threshold=int(max_iters))
    if canvas_ids is not None:
        gen_kwargs["canvas_ids"] = canvas_ids.to(device)

    frames: "queue_lib.Queue" = queue_lib.Queue()
    streamer = QueueDiffusionStreamer(tokenizer, frames)
    failure = {}

    def _run_generation():
        try:
            with torch.inference_mode():
                model.generate(input_ids, streamer=streamer, **gen_kwargs)
        except Exception as exc:  # surface to the endpoint
            failure["msg"] = f"{type(exc).__name__}: {exc}"
            frames.put(("error", str(exc), 0, 0))
        finally:
            frames.put(("end", "", 0, 0))  # always unblock the consumer

    producer = threading.Thread(target=_run_generation)
    producer.start()
    try:
        while True:
            frame = frames.get()
            kind, text = frame[0], frame[1]
            if kind == "end":
                return
            if kind == "error":
                yield ("error", failure.get("msg", text), 0, 0)
                return
            yield frame
    finally:
        # Wait for the generation thread whether the stream finished or was abandoned.
        producer.join()
| # --------------------------------------------------------------------------- # | |
| # Server | |
| # --------------------------------------------------------------------------- # | |
# The gradio.Server instance for this app; aliased as `demo` and launched at the bottom
# of the file.
app = Server(title="Gemma Diffusion Website Builder")
def generate(
    prompt: str,
    history_json: str = "[]",
    max_new_tokens: int = 2048,
    max_iters: int = 64,
    full_denoise: bool = False,
    anim_delay: float = 0.0,
    warm_start: bool = True,
) -> str:
    """Stream the diffusion generation as JSON frames (one per denoising step).

    The model writes a self-contained HTML document; the frontend renders it live.

    Args:
        prompt: The user's request. Blank prompts produce a single error frame.
        history_json: JSON-encoded prior conversation turns.
        max_new_tokens: Generation budget forwarded to the GPU function.
        max_iters: Denoising steps per block, clamped to ``[1, MAX_ITERS_CAP]``.
        full_denoise: Forwarded to the GPU function.
        anim_delay: Seconds to sleep after each draft frame. Negative, NaN and
            infinite values are treated as no delay.
        warm_start: When true and the history already holds an assistant page, seed
            the first canvas from the cached previous page.

    Yields:
        JSON strings. ``{"kind": "error", "message": ...}`` on failure (the stream then
        stops); ``draft`` / ``commit`` frames with ``source``, ``block``, ``step``,
        ``canvas``, ``max_iters`` and ``warming`` while generating; and a final
        ``{"kind": "done", "source": ...}`` frame with the extracted HTML document.

    NOTE(review): no route registration for this function is visible in this view; the
    module docstring says it is served as `/generate` — confirm.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        yield json.dumps({"kind": "error", "message": "Empty prompt."})
        return

    # Validate client-supplied numbers up front so a bad value becomes an error frame
    # instead of an exception before (or in the middle of) the stream.
    try:
        max_new_tokens = int(max_new_tokens)
        delay = float(anim_delay or 0.0)
        max_iters = max(1, min(int(max_iters), MAX_ITERS_CAP))
    except (TypeError, ValueError, OverflowError):
        yield json.dumps({"kind": "error", "message": "Invalid numeric parameter."})
        return
    # time.sleep() rejects negative and NaN values; treat anything that is not a
    # positive finite number as "no delay".
    if not 0.0 < delay < float("inf"):
        delay = 0.0

    messages = build_messages(history_json, prompt)

    # Tweak warm-start: seed the diffusion's first canvas with the previous page's own
    # tokens (native `canvas_ids` API) so the model edits the existing page in place.
    # NOTE(review): the cached page lives on the shared `model` object, so it is
    # process-wide rather than per-session — confirm this is acceptable when several
    # users hit the endpoint.
    is_tweak = bool(last_assistant_html(history_json))
    canvas_ids = warm_canvas_from_cache() if (warm_start and is_tweak) else None
    warming = canvas_ids is not None

    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )["input_ids"]

    last_text = ""
    for kind, text, block, step in _gpu_stream(
        input_ids, max_new_tokens, max_iters, bool(full_denoise), canvas_ids
    ):
        if kind == "error":
            yield json.dumps({"kind": "error", "message": text})
            return
        last_text = text
        yield json.dumps(
            {
                "kind": "draft" if kind == "draft" else "commit",
                "source": clean_text(text),
                "block": block,
                "step": step,
                "canvas": CANVAS_LEN,
                "max_iters": max_iters,
                "warming": warming,
            }
        )
        if delay and kind == "draft":
            _time.sleep(delay)

    final_source = extract_html(last_text)
    # Cache the *cleaned* output so the next tweak warm-starts from a valid header.
    if final_source.strip():
        model._last_clean_html = final_source
    yield json.dumps({"kind": "done", "source": final_source})
async def homepage():
    """Return the text of ``index.html`` located next to this module (read as UTF-8)."""
    index_path = os.path.join(HERE, "index.html")
    with open(index_path, "r", encoding="utf-8") as page:
        return page.read()
# HF Spaces' gradio runtime looks for a top-level `demo` (or `app`) to launch.
demo = app

if __name__ == "__main__":
    # Bind address and port can be overridden with GDIFF_HOST / GDIFF_PORT.
    host = os.environ.get("GDIFF_HOST", "0.0.0.0")
    port = int(os.environ.get("GDIFF_PORT", "7860"))
    app.launch(server_name=host, server_port=port, show_error=True)