comicx / comic /engine.py
ASTRALK's picture
Upload comic/engine.py with huggingface_hub
044e70f verified
Raw
History Blame Contribute Delete
6.19 kB
"""generate_comic β€” the full idea -> finished comic pipeline.
It is a GENERATOR that yields progress events as it goes, so the Gradio app can show
live status and stream panels onto the page as they render (rather than freezing for
the whole ~minute of generation). Stages:
1. WRITER bible call: safety gate + story bible. If refused -> a 'refused' event, done.
2. WRITER panel calls: 5 batches of 2 pages each, each fed a recap of prior panels
for continuity. Yields a 'panels' event per batch.
3. ARTIST renders all 20 panels in GPU batches of RENDER_BATCH, yielding an 'image' event per panel.
4. 'done' event with the finished Comic.
The pipeline is backend-agnostic (mock or modal) via make_backends(). Errors in a
render batch are caught (its panels are left without an image) so one bad batch never
sinks the whole comic.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterator, Optional
from .backends import make_backends, WriterBackend, ArtistBackend
from .schema import Comic, ComicBible, Panel, TOTAL_PANELS
from . import writer as W
from .imaging import build_image_prompt, panel_seed
# Panels rendered per GPU pass. 4 fits klein-9B at 832x576 on an H100 with headroom;
# the serve endpoint falls back to serial if a batch ever OOMs.
# generate_comic() slices the ordered panel list into chunks of this size and makes
# one artist.render_batch() call per chunk.
RENDER_BATCH = 4
@dataclass
class GenerateEvent:
    """A single progress notification emitted by the comic pipeline.

    The UI dispatches on ``kind``; which of the remaining fields carry data
    depends on that kind.
    """

    # One of: status | bible | refused | panels | image | done | error.
    kind: str
    # Text accompanying the event (status line, refusal reason, error, ...).
    message: str = ""
    # The comic so far; carried from 'bible' onward (mutated in place).
    comic: Optional[Comic] = None
    # The panel concerned, for 'image' events.
    panel: Optional[Panel] = None
    # Coarse 0..1 progress for a progress bar.
    progress: float = 0.0
def _retry_json(fn, parse, attempts: int = 2):
"""Call fn() -> reply, parse it; retry once on a parse/JSON failure."""
last = None
for _ in range(attempts):
reply = fn()
try:
return parse(reply)
except Exception as e: # noqa: BLE001 - surface after retries
last = e
raise last if last else RuntimeError("generation failed")
def generate_comic(
    idea: str,
    writer: WriterBackend | None = None,
    artist: ArtistBackend | None = None,
    backend: str | None = None,
) -> Iterator[GenerateEvent]:
    """Yield GenerateEvents from raw idea to a fully rendered Comic.

    Args:
        idea: Free-text description of the comic. Blank or None input produces
            a single 'error' event and nothing else.
        writer: Backend whose ``chat(messages)`` reply is parsed into the bible
            and the panel script. Built via ``make_backends(backend)`` if None.
        artist: Backend whose ``render_batch(prompts, seeds)`` produces the
            panel images. Built via ``make_backends(backend)`` if None.
        backend: Passed to ``make_backends`` when either backend is missing.

    Yields:
        GenerateEvent objects in this order: a 'status'; then 'bible' (or a
        terminal 'refused' / 'error'); then a 'status' + 'panels' pair per
        page batch (or a terminal 'error'); then a 'status' per render batch
        followed by one 'image' per panel; finally 'done'. The same Comic
        object is attached from 'bible' onward and is mutated in place.
    """
    # Validate first so a blank request never pays for backend construction.
    idea = (idea or "").strip()
    if not idea:
        yield GenerateEvent("error", "Please describe the comic you want.")
        return

    if writer is None or artist is None:
        default_writer, default_artist = make_backends(backend)
        # Identity checks (not truthiness) so an explicitly passed backend is
        # always honoured, matching the guard above.
        if writer is None:
            writer = default_writer
        if artist is None:
            artist = default_artist

    # ── 1. bible + safety gate ────────────────────────────────────────────────
    yield GenerateEvent("status", "Reading your idea and planning the story…", progress=0.02)
    try:
        bible: ComicBible = _retry_json(
            lambda: writer.chat(W.build_bible_messages(idea)),
            W.parse_bible,
        )
    except Exception as e:  # noqa: BLE001
        yield GenerateEvent("error", f"Couldn't plan the comic ({type(e).__name__}). Try again.")
        return
    if not bible.approved:
        reason = bible.refusal_reason or "That request can't be turned into a comic."
        yield GenerateEvent("refused", reason)
        return
    comic = Comic(bible=bible)
    yield GenerateEvent("bible", f"“{bible.title}” — {bible.logline}", comic=comic, progress=0.1)

    # ── 2. panel script, batched, with running recap for continuity ───────────
    written: list[Panel] = []
    page_batches = W.batches()
    n_batches = len(page_batches)
    for bi, pages in enumerate(page_batches):
        yield GenerateEvent(
            "status",
            f"Writing pages {pages[0]}–{pages[-1]} of {len(comic.bible.pages)}…",
            comic=comic,
            progress=0.1 + 0.3 * (bi / n_batches),
        )
        recap = W.recap_from_panels(written)
        try:
            # The lambdas are invoked inside _retry_json during this iteration,
            # so capturing the loop variables here is safe.
            panels = _retry_json(
                lambda: writer.chat(W.build_panel_messages(bible, pages, recap)),
                lambda r: W.parse_panels(r, pages),
            )
        except Exception as e:  # noqa: BLE001
            yield GenerateEvent("error", f"Story writing failed on pages {pages} ({type(e).__name__}).")
            return
        # Assemble each panel's image prompt now (deterministic, no model call).
        for p in panels:
            p.image_prompt = build_image_prompt(p, bible)
        written.extend(panels)
        comic.panels = sorted(written, key=lambda x: x.index)
        yield GenerateEvent(
            "panels",
            f"Pages {pages[0]}–{pages[-1]} scripted.",
            comic=comic,
            progress=0.1 + 0.3 * ((bi + 1) / n_batches),
        )

    # ── 3. render every panel (batched through the GPU for throughput) ─────────
    ordered = sorted(comic.panels, key=lambda x: x.index)
    total = len(ordered) or TOTAL_PANELS
    done = 0
    for start in range(0, len(ordered), RENDER_BATCH):
        chunk = ordered[start:start + RENDER_BATCH]
        batch_progress = 0.4 + 0.6 * (start / total)
        yield GenerateEvent(
            "status",
            f"Illustrating panels {start + 1}–{start + len(chunk)} of {total}…",
            comic=comic,
            progress=batch_progress,
        )
        prompts = [p.image_prompt for p in chunk]
        seeds = [panel_seed(bible, p) for p in chunk]
        failure = None
        try:
            # Materialise inside the try so a None or lazily-failing result is
            # handled like any other batch failure.
            images = list(artist.render_batch(prompts, seeds))
        except Exception as e:  # noqa: BLE001 - a batch failure must not sink the comic
            images = []
            failure = type(e).__name__
        if failure is not None:
            # Carry the current progress; the default of 0.0 would snap the
            # progress bar back to the start mid-render.
            yield GenerateEvent(
                "status",
                f"A render batch hiccupped ({failure}); continuing…",
                comic=comic,
                progress=batch_progress,
            )
        # Pad any shortfall with None so every panel in the chunk still gets an
        # 'image' event and the progress count reaches `total`.
        images.extend([None] * (len(chunk) - len(images)))
        for panel, img in zip(chunk, images):
            panel.image = img
            done += 1
            yield GenerateEvent(
                "image",
                f"Panel {done} ready.",
                comic=comic,
                panel=panel,
                progress=0.4 + 0.6 * (done / total),
            )
    yield GenerateEvent("done", f"“{bible.title}” is ready.", comic=comic, progress=1.0)