import torch
import torch.nn.functional as F
from pathlib import Path
from PIL import Image, ImageEnhance, ImageFilter
import gradio as gr

# ── Config ────────────────────────────────────────────────────────────────────
HF_REPO_ID = "8BitStudio/Aniimage-1"
VAE_ID = "stabilityai/sd-vae-ft-mse"
CLIP_ID = "openai/clip-vit-large-patch14"

UNET_CONFIG = dict(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    block_out_channels=(256, 512, 768, 1024),
    layers_per_block=2,
    cross_attention_dim=768,
    attention_head_dim=8,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D",
                      "CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D",
                    "CrossAttnUpBlock2D", "UpBlock2D"),
)

DEFAULT_NEGATIVE = (
    "low quality, ugly, blurry, distorted, deformed, bad anatomy, "
    "bad proportions, extra limbs, missing limbs, watermark, text, "
    "signature, washed out, flat colors, manga panel, disfigured, "
    "poorly drawn, jpeg artifacts, cropped, out of frame"
)

SCHEDULER_LIST = ["DPM++ 2M Karras", "DPM++ SDE Karras", "Euler a", "Euler", "DDIM"]

# ── Generator ─────────────────────────────────────────────────────────────────
class Generator:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.vae = None
        self.text_encoder = None
        self.tokenizer = None
        self.unet = None
        self.scheduler_name = "DPM++ 2M Karras"
        self.latent_size = 32   # 32×32 latents decode to 256×256 pixels (8× VAE factor)
        self.output_size = 256

    def load(self):
        # Lazily load the VAE, CLIP text encoder, and UNet on first use.
        if self.unet is not None:
            return
        from diffusers import AutoencoderKL, UNet2DConditionModel
        from transformers import CLIPTextModel, CLIPTokenizer
        from huggingface_hub import hf_hub_download
        from safetensors.torch import load_file
        import shutil

        print("Loading VAE...")
        self.vae = AutoencoderKL.from_pretrained(VAE_ID).to(self.device)
        self.vae.eval()

        print("Loading CLIP...")
        self.tokenizer = CLIPTokenizer.from_pretrained(CLIP_ID)
        self.text_encoder = CLIPTextModel.from_pretrained(CLIP_ID).to(self.device)
        self.text_encoder.eval()

        print("Loading UNet...")
        weights_path = Path("unet_weights.safetensors")
        if not weights_path.exists():
            dl = hf_hub_download(repo_id=HF_REPO_ID,
                                 filename="diffusion_pytorch_model.safetensors")
            shutil.copy2(dl, weights_path)
        self.unet = UNet2DConditionModel(**UNET_CONFIG).to(self.device)
        state = load_file(str(weights_path), device=str(self.device))
        self.unet.load_state_dict(state)
        self.unet.eval()
        print(f"Ready! Running on {self.device.upper()}")
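
    # Optional VRAM saver — a hedged sketch, not part of the original app.
    # `set_attention_slice` is an existing diffusers UNet2DConditionModel
    # method that computes attention in smaller chunks, trading some speed
    # for lower peak memory; whether it helps this particular checkpoint is
    # an assumption, untested here.
    def enable_attention_slicing(self):
        self.load()  # ensure the UNet exists before configuring it
        self.unet.set_attention_slice("auto")  # "auto" slices each layer to half its heads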
    def _make_scheduler(self, name):
        from diffusers import (DDIMScheduler, DPMSolverMultistepScheduler,
                               EulerAncestralDiscreteScheduler,
                               EulerDiscreteScheduler)
        base = dict(num_train_timesteps=1000, beta_schedule="scaled_linear",
                    prediction_type="epsilon")
        if name == "DPM++ 2M Karras":
            return DPMSolverMultistepScheduler(
                **base, algorithm_type="dpmsolver++", solver_order=2,
                use_karras_sigmas=True)
        elif name == "DPM++ SDE Karras":
            return DPMSolverMultistepScheduler(
                **base, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
        elif name == "Euler a":
            return EulerAncestralDiscreteScheduler(**base)
        elif name == "Euler":
            return EulerDiscreteScheduler(**base)
        else:
            return DDIMScheduler(**base, clip_sample=False, set_alpha_to_one=False)

    def _decode_latents(self, latents):
        # Undo the VAE scaling, decode to pixel space, then apply a light
        # sharpen / contrast / saturation pass to the PIL image.
        scaled = latents / self.vae.config.scaling_factor
        with torch.no_grad():
            image = self.vae.decode(scaled.float()).sample
        image = (image.float() / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
        image = (image * 255).round().astype("uint8")
        img = Image.fromarray(image)
        img = img.filter(ImageFilter.UnsharpMask(radius=1.5, percent=40, threshold=2))
        img = ImageEnhance.Contrast(img).enhance(1.06)
        img = ImageEnhance.Color(img).enhance(1.10)
        return img

    def _sharpen_latents(self, latents, amount=0.08):
        # Unsharp mask in latent space: add back a fraction of the difference
        # between the latents and a 3×3 box-blurred copy.
        blurred = F.avg_pool2d(latents, kernel_size=3, stride=1, padding=1)
        return latents + amount * (latents - blurred)

    @torch.no_grad()
    def generate(self, prompt, negative_prompt="", steps=25, guidance_scale=7.5,
                 seed=-1, scheduler_name="DPM++ 2M Karras"):
        self.load()
        if seed < 0:
            seed = torch.randint(0, 2**32, (1,)).item()
        gen = torch.Generator(device=self.device).manual_seed(seed)

        # Encode prompt and negative prompt with CLIP; concatenate so both
        # conditionings run through the UNet in a single batch.
        tok = self.tokenizer(prompt, padding="max_length",
                             max_length=self.tokenizer.model_max_length,
                             truncation=True, return_tensors="pt")
        text_emb = self.text_encoder(tok.input_ids.to(self.device))[0]
        tok_neg = self.tokenizer(negative_prompt, padding="max_length",
                                 max_length=self.tokenizer.model_max_length,
                                 truncation=True, return_tensors="pt")
        neg_emb = self.text_encoder(tok_neg.input_ids.to(self.device))[0]
        combined = torch.cat([neg_emb, text_emb])

        scheduler = self._make_scheduler(scheduler_name)
        scheduler.set_timesteps(steps, device=self.device)
        latents = torch.randn(1, 4, self.latent_size, self.latent_size,
                              generator=gen, device=self.device)
        latents = latents * scheduler.init_noise_sigma

        for t in scheduler.timesteps:
            inp = torch.cat([latents] * 2)
            inp = scheduler.scale_model_input(inp, t)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
                                enabled=(self.device == "cuda")):
                pred = self.unet(inp, t, encoder_hidden_states=combined).sample
            # Classifier-free guidance: push the prediction away from the
            # negative conditioning toward the positive one.
            pred_neg, pred_text = pred.chunk(2)
            pred = pred_neg + guidance_scale * (pred_text - pred_neg)
            latents = scheduler.step(pred, t, latents).prev_sample

        latents = self._sharpen_latents(latents)
        return self._decode_latents(latents), seed

# ── Generator instance (weights load lazily, on the first generate call) ──────
gen = Generator()
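
# ── Headless smoke test (sketch) ──────────────────────────────────────────────
# A minimal sketch of exercising the generator without the UI. The prompt,
# seed, and output filename below are illustrative assumptions; everything
# else uses the objects defined above. Guarded off by default so the Gradio
# app still starts normally.
RUN_SMOKE_TEST = False
if RUN_SMOKE_TEST:
    img, used_seed = gen.generate(
        prompt="A smiling anime girl with red hair and a school uniform",
        negative_prompt=DEFAULT_NEGATIVE,
        steps=25,
        guidance_scale=7.5,
        seed=1234,  # fixed seed so the run is reproducible (arbitrary value)
    )
    img.save("smoke_test.png")  # hypothetical output path
    print(f"Smoke test wrote smoke_test.png (seed {used_seed})")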
# ── Gradio UI ─────────────────────────────────────────────────────────────────
def run(prompt, negative, steps, cfg, scheduler, seed):
    if not prompt.strip():
        return None, "Please enter a prompt!"
    image, used_seed = gen.generate(
        prompt=prompt,
        negative_prompt=negative,
        steps=int(steps),
        guidance_scale=float(cfg),
        seed=int(seed),
        scheduler_name=scheduler,
    )
    return image, f"Seed: {used_seed}"

with gr.Blocks(title="Aniimage-1 by 8BitStudio") as demo:
    gr.Markdown(
        "# 🎨 Aniimage-1\n"
        "Anime image generator by **8BitStudio** · 256×256 · "
        "Trained from scratch on 830k Danbooru images\n\n"
        'Use plain English: *"A smiling anime girl with red hair and a school uniform"*'
    )
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", lines=3,
                                placeholder="A smiling anime girl with red hair and a school uniform")
            negative = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE, lines=2)
            with gr.Row():
                steps = gr.Slider(10, 50, value=25, step=1, label="Steps")
                cfg = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="CFG Scale")
            with gr.Row():
                scheduler = gr.Dropdown(SCHEDULER_LIST, value="DPM++ 2M Karras",
                                        label="Scheduler")
                seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
            btn = gr.Button("✨ Generate", variant="primary")
        with gr.Column(scale=1):
            output = gr.Image(label="Generated Image", type="pil")
            seed_out = gr.Textbox(label="Used Seed", interactive=False)

    btn.click(run, inputs=[prompt, negative, steps, cfg, scheduler, seed],
              outputs=[output, seed_out])

    gr.Examples(
        examples=[
            ["A smiling anime girl with red hair and a school uniform",
             DEFAULT_NEGATIVE, 25, 7.5, "DPM++ 2M Karras", -1],
            ["A mysterious anime girl with silver hair under a night sky with stars",
             DEFAULT_NEGATIVE, 25, 7.5, "DPM++ 2M Karras", -1],
            ["An anime girl in a maid dress holding a teacup, cherry blossoms in the background",
             DEFAULT_NEGATIVE, 30, 7.5, "DPM++ 2M Karras", -1],
        ],
        inputs=[prompt, negative, steps, cfg, scheduler, seed],
    )

demo.launch()