# aniimage / app.py
# Author: 8BitStudio
# Last change: "Update app.py" (commit d12fd5f, verified)
import torch
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from PIL import Image, ImageEnhance, ImageFilter
import gradio as gr
# ── Config ────────────────────────────────────────────────────────────────────
# Hub identifiers for the three model components.
HF_REPO_ID = "8BitStudio/Aniimage-1"        # trained UNet weights
VAE_ID = "stabilityai/sd-vae-ft-mse"        # pretrained VAE
CLIP_ID = "openai/clip-vit-large-patch14"   # CLIP tokenizer + text encoder

# Keyword arguments for UNet2DConditionModel. Presumably these must describe
# exactly the architecture of the checkpoint in HF_REPO_ID, since the weights
# are loaded into a model built from this dict.
UNET_CONFIG = {
    "sample_size": 32,
    "in_channels": 4,
    "out_channels": 4,
    "block_out_channels": (256, 512, 768, 1024),
    "layers_per_block": 2,
    # Presumably must equal the text encoder's hidden size — verify if CLIP_ID changes.
    "cross_attention_dim": 768,
    "attention_head_dim": 8,
    "down_block_types": (
        "DownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    "up_block_types": (
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "UpBlock2D",
    ),
}

# Default negative prompt shown in the UI, as a comma-separated list of terms.
DEFAULT_NEGATIVE = ", ".join((
    "low quality", "ugly", "blurry", "distorted", "deformed", "bad anatomy",
    "bad proportions", "extra limbs", "missing limbs", "watermark", "text",
    "signature", "washed out", "flat colors", "manga panel", "disfigured",
    "poorly drawn", "jpeg artifacts", "cropped", "out of frame",
))

# Sampler names offered in the UI dropdown.
SCHEDULER_LIST = ["DPM++ 2M Karras", "DPM++ SDE Karras", "Euler a", "Euler", "DDIM"]
# ── Generator ─────────────────────────────────────────────────────────────────
class Generator:
    """Text-to-image pipeline built from a CLIP text encoder, a UNet and a VAE.

    Nothing heavy happens in ``__init__``: the models are loaded on the first
    call to :meth:`generate` (or an explicit :meth:`load`) and then reused.
    """

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.vae = None
        self.text_encoder = None
        self.tokenizer = None
        # Also serves as the "models loaded" flag checked by load().
        self.unet = None
        # Not read inside this class; generate() takes the scheduler per call.
        self.scheduler_name = "DPM++ 2M Karras"
        # Height/width of the initial noise latents (same value as
        # UNET_CONFIG["sample_size"]).
        self.latent_size = 32
        # Not read inside this class; nominal output resolution in pixels.
        self.output_size = 256

    def load(self):
        """Load the VAE, CLIP tokenizer/text encoder and UNet onto ``self.device``.

        Later calls are no-ops once the UNet has loaded successfully. The UNet
        weights come from ``unet_weights.safetensors`` in the working directory
        if that file exists. Otherwise they are downloaded from ``HF_REPO_ID``
        and copied there for later starts.
        """
        if self.unet is not None:
            return
        from diffusers import AutoencoderKL, UNet2DConditionModel
        from transformers import CLIPTextModel, CLIPTokenizer
        from huggingface_hub import hf_hub_download
        from safetensors.torch import load_file
        import shutil

        print("Loading VAE...")
        self.vae = AutoencoderKL.from_pretrained(VAE_ID).to(self.device)
        self.vae.eval()

        print("Loading CLIP...")
        self.tokenizer = CLIPTokenizer.from_pretrained(CLIP_ID)
        self.text_encoder = CLIPTextModel.from_pretrained(CLIP_ID).to(self.device)
        self.text_encoder.eval()

        print("Loading UNet...")
        weights_path = Path("unet_weights.safetensors")
        if not weights_path.exists():
            dl = hf_hub_download(repo_id=HF_REPO_ID,
                                 filename="diffusion_pytorch_model.safetensors")
            # Copy under a temporary name and rename into place. An interrupted
            # copy then never leaves a truncated file that the exists() check
            # above would accept on every later start.
            tmp_path = weights_path.with_name(weights_path.name + ".partial")
            shutil.copy2(dl, tmp_path)
            tmp_path.replace(weights_path)

        unet = UNet2DConditionModel(**UNET_CONFIG).to(self.device)
        state = load_file(str(weights_path), device=str(self.device))
        unet.load_state_dict(state)
        unet.eval()
        # Publish only after the weights are in. self.unet is the "loaded" flag,
        # so assigning it earlier would leave a randomly initialised UNet in use
        # if load_state_dict raised.
        self.unet = unet
        print(f"Ready! Running on {self.device.upper()}")

    def _make_scheduler(self, name):
        """Return a fresh diffusers scheduler for the UI display *name*.

        Any name not matched below falls back to DDIM.
        """
        from diffusers import (DDIMScheduler, DPMSolverMultistepScheduler,
                               EulerAncestralDiscreteScheduler,
                               EulerDiscreteScheduler)
        base = dict(num_train_timesteps=1000, beta_schedule="scaled_linear",
                    prediction_type="epsilon")
        if name == "DPM++ 2M Karras":
            return DPMSolverMultistepScheduler(
                **base, algorithm_type="dpmsolver++",
                solver_order=2, use_karras_sigmas=True)
        elif name == "DPM++ SDE Karras":
            return DPMSolverMultistepScheduler(
                **base, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
        elif name == "Euler a":
            return EulerAncestralDiscreteScheduler(**base)
        elif name == "Euler":
            return EulerDiscreteScheduler(**base)
        else:
            return DDIMScheduler(**base, clip_sample=False, set_alpha_to_one=False)

    def _decode_latents(self, latents):
        """Decode *latents* with the VAE and return the first image as a PIL image.

        The decoded tensor is mapped with ``x / 2 + 0.5`` and clamped to [0, 1].
        This assumes VAE output roughly in [-1, 1]. The result is converted to
        uint8, then lightly sharpened and boosted in contrast and saturation.
        """
        scaled = latents / self.vae.config.scaling_factor
        with torch.no_grad():
            image = self.vae.decode(scaled.float()).sample
        image = (image.float() / 2 + 0.5).clamp(0, 1)
        # NCHW -> NHWC, keep only the first image of the batch.
        image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
        image = (image * 255).round().astype("uint8")
        img = Image.fromarray(image)
        img = img.filter(ImageFilter.UnsharpMask(radius=1.5, percent=40, threshold=2))
        img = ImageEnhance.Contrast(img).enhance(1.06)
        img = ImageEnhance.Color(img).enhance(1.10)
        return img

    def _sharpen_latents(self, latents, amount=0.08):
        """Unsharp-mask *latents*: add ``amount`` times (latents - 3x3 box blur).

        Assumes a (batch, channels, H, W) tensor — TODO confirm for other callers.
        """
        # count_include_pad=False averages only the real neighbours at borders.
        # With zero padding the edge/corner "blur" was pulled toward 0, so the
        # detail term picked up a spurious component and amplified the border.
        blurred = F.avg_pool2d(latents, kernel_size=3, stride=1, padding=1,
                               count_include_pad=False)
        return latents + amount * (latents - blurred)

    def _encode_text(self, text):
        """Return the text encoder's first output for *text*.

        The input is padded/truncated to the tokenizer's max length. The first
        output is presumably last_hidden_state — verify against transformers.
        """
        tokens = self.tokenizer(text, padding="max_length",
                                max_length=self.tokenizer.model_max_length,
                                truncation=True, return_tensors="pt")
        return self.text_encoder(tokens.input_ids.to(self.device))[0]

    @torch.no_grad()
    def generate(self, prompt, negative_prompt="", steps=25,
                 guidance_scale=7.5, seed=-1, scheduler_name="DPM++ 2M Karras"):
        """Generate one image for *prompt* and return ``(PIL.Image, seed)``.

        Args:
            prompt: Text to condition on.
            negative_prompt: Text for the unconditional branch of
                classifier-free guidance; None is treated as "".
            steps: Number of scheduler steps.
            guidance_scale: CFG weight; 1.0 yields the prompt-only prediction.
            seed: RNG seed; any negative value picks a random one.
            scheduler_name: A name from SCHEDULER_LIST (unknown -> DDIM).

        Returns:
            The decoded image and the seed actually used. All sampling noise
            comes from a generator seeded with it.
        """
        import inspect

        self.load()
        if seed < 0:
            seed = torch.randint(0, 2**32, (1,)).item()
        gen = torch.Generator(device=self.device).manual_seed(seed)

        text_emb = self._encode_text(prompt)
        neg_emb = self._encode_text(negative_prompt or "")
        # Batch order [negative, positive] must match pred.chunk(2) below.
        combined = torch.cat([neg_emb, text_emb])

        scheduler = self._make_scheduler(scheduler_name)
        scheduler.set_timesteps(steps, device=self.device)
        # Ancestral/SDE samplers (Euler a, DPM++ SDE) draw fresh noise inside
        # step(). Without the seeded generator they would use the global RNG,
        # so the reported seed could not reproduce the image.
        step_kwargs = ({"generator": gen}
                       if "generator" in inspect.signature(scheduler.step).parameters
                       else {})

        latents = torch.randn(1, 4, self.latent_size, self.latent_size,
                              generator=gen, device=self.device)
        latents = latents * scheduler.init_noise_sigma
        for t in scheduler.timesteps:
            inp = torch.cat([latents] * 2)
            inp = scheduler.scale_model_input(inp, t)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
                                enabled=(self.device == "cuda")):
                pred = self.unet(inp, t, encoder_hidden_states=combined).sample
            # Combine CFG in fp32: guidance_scale amplifies the (text - neg)
            # difference, which loses precision in bf16.
            pred_neg, pred_text = pred.float().chunk(2)
            pred = pred_neg + guidance_scale * (pred_text - pred_neg)
            latents = scheduler.step(pred, t, latents, **step_kwargs).prev_sample

        latents = self._sharpen_latents(latents)
        return self._decode_latents(latents), seed
# ── Shared generator (weights load lazily on the first generate() call) ───────
gen = Generator()
# ── Gradio UI ─────────────────────────────────────────────────────────────────
def run(prompt, negative, steps, cfg, scheduler, seed):
    """Gradio click handler: generate one image and report the seed used.

    Returns ``(image, status_text)``. ``image`` is None when the prompt is
    empty or missing. A cleared seed field arrives as None and is treated like
    -1 (random seed); a missing negative prompt is treated as "".
    """
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt!"
    image, used_seed = gen.generate(
        prompt=prompt,
        negative_prompt=negative or "",
        steps=int(steps),
        guidance_scale=float(cfg),
        seed=-1 if seed is None else int(seed),
        scheduler_name=scheduler,
    )
    return image, f"Seed: {used_seed}"
# Two-column layout: controls on the left, result on the right.
with gr.Blocks(title="Aniimage-1 by 8BitStudio") as demo:
    # Header text; the "·" and "×" characters were mis-encoded ("Β·", "Γ—").
    gr.Markdown("# 🎨 Aniimage-1\nAnime image generator by **8BitStudio** · 256×256 · Trained from scratch on 830k Danbooru images\n\nUse plain English: *\"A smiling anime girl with red hair and a school uniform\"*")
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", lines=3,
                                placeholder="A smiling anime girl with red hair and a school uniform")
            negative = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE, lines=2)
            with gr.Row():
                steps = gr.Slider(10, 50, value=25, step=1, label="Steps")
                cfg = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="CFG Scale")
            with gr.Row():
                scheduler = gr.Dropdown(SCHEDULER_LIST, value="DPM++ 2M Karras", label="Scheduler")
                seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
            btn = gr.Button("✨ Generate", variant="primary")
        with gr.Column(scale=1):
            output = gr.Image(label="Generated Image", type="pil")
            seed_out = gr.Textbox(label="Used Seed", interactive=False)
    # Input order must match run(prompt, negative, steps, cfg, scheduler, seed).
    btn.click(run, inputs=[prompt, negative, steps, cfg, scheduler, seed],
              outputs=[output, seed_out])
    # Examples only fill the input widgets; they are not run automatically.
    gr.Examples(
        examples=[
            ["A smiling anime girl with red hair and a school uniform", DEFAULT_NEGATIVE, 25, 7.5, "DPM++ 2M Karras", -1],
            ["A mysterious anime girl with silver hair under a night sky with stars", DEFAULT_NEGATIVE, 25, 7.5, "DPM++ 2M Karras", -1],
            ["An anime girl in a maid dress holding a teacup, cherry blossoms in the background", DEFAULT_NEGATIVE, 30, 7.5, "DPM++ 2M Karras", -1],
        ],
        inputs=[prompt, negative, steps, cfg, scheduler, seed],
    )
demo.launch()