Spaces:
Running
Running
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from pathlib import Path | |
| from PIL import Image, ImageEnhance, ImageFilter | |
| import gradio as gr | |
# ── Config ───────────────────────────────────────────────────────────────────
# Hub repository that hosts the trained UNet weights.
HF_REPO_ID = "8BitStudio/Aniimage-1"
# Pretrained VAE and CLIP checkpoints loaded with from_pretrained().
VAE_ID = "stabilityai/sd-vae-ft-mse"
CLIP_ID = "openai/clip-vit-large-patch14"

# Keyword arguments unpacked into UNet2DConditionModel(**UNET_CONFIG).
UNET_CONFIG = {
    "sample_size": 32,
    "in_channels": 4,
    "out_channels": 4,
    "block_out_channels": (256, 512, 768, 1024),
    "layers_per_block": 2,
    "cross_attention_dim": 768,
    "attention_head_dim": 8,
    "down_block_types": (
        "DownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    "up_block_types": (
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "UpBlock2D",
    ),
}

# Default negative prompt, assembled from individual tags joined by ", ".
_NEGATIVE_TAGS = [
    "low quality",
    "ugly",
    "blurry",
    "distorted",
    "deformed",
    "bad anatomy",
    "bad proportions",
    "extra limbs",
    "missing limbs",
    "watermark",
    "text",
    "signature",
    "washed out",
    "flat colors",
    "manga panel",
    "disfigured",
    "poorly drawn",
    "jpeg artifacts",
    "cropped",
    "out of frame",
]
DEFAULT_NEGATIVE = ", ".join(_NEGATIVE_TAGS)

# Scheduler names offered in the UI; each one is handled by
# Generator._make_scheduler (anything else falls back to DDIM there).
SCHEDULER_LIST = [
    "DPM++ 2M Karras",
    "DPM++ SDE Karras",
    "Euler a",
    "Euler",
    "DDIM",
]
# ── Generator ────────────────────────────────────────────────────────────────
class Generator:
    """Text-to-image pipeline built from a VAE, a CLIP text encoder and a UNet.

    Construction is cheap: all model attributes start as ``None`` and are
    populated by :meth:`load`, which :meth:`generate` calls on every
    invocation (it is a no-op once the UNet is in place).
    """

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.vae = None
        self.text_encoder = None
        self.tokenizer = None
        self.unet = None
        self.scheduler_name = "DPM++ 2M Karras"
        self.latent_size = 32
        self.output_size = 256

    def load(self):
        """Load VAE, CLIP and UNet onto ``self.device`` if not already loaded.

        The UNet weights are read from ``unet_weights.safetensors`` in the
        working directory; when that file is missing it is downloaded from
        ``HF_REPO_ID`` and copied there first.

        ``self.unet`` is only assigned after its weights have loaded
        successfully. It doubles as the "already loaded" flag, so publishing
        it earlier would let a failed load leave a randomly initialised UNet
        behind that later calls would silently use.
        """
        if self.unet is not None:
            return

        from diffusers import AutoencoderKL, UNet2DConditionModel
        from transformers import CLIPTextModel, CLIPTokenizer
        from huggingface_hub import hf_hub_download
        from safetensors.torch import load_file
        import shutil

        print("Loading VAE...")
        self.vae = AutoencoderKL.from_pretrained(VAE_ID).to(self.device)
        self.vae.eval()

        print("Loading CLIP...")
        self.tokenizer = CLIPTokenizer.from_pretrained(CLIP_ID)
        self.text_encoder = CLIPTextModel.from_pretrained(CLIP_ID).to(self.device)
        self.text_encoder.eval()

        print("Loading UNet...")
        weights_path = Path("unet_weights.safetensors")
        if not weights_path.exists():
            dl = hf_hub_download(repo_id=HF_REPO_ID,
                                 filename="diffusion_pytorch_model.safetensors")
            shutil.copy2(dl, weights_path)

        # Build and populate the UNet in a local first; see the docstring.
        unet = UNet2DConditionModel(**UNET_CONFIG).to(self.device)
        state = load_file(str(weights_path), device=str(self.device))
        unet.load_state_dict(state)
        unet.eval()
        self.unet = unet

        print(f"Ready! Running on {self.device.upper()}")

    def _make_scheduler(self, name):
        """Return a fresh diffusers scheduler for the given UI name.

        Any name other than the four explicitly handled ones yields a
        ``DDIMScheduler``.
        """
        from diffusers import (DDIMScheduler, DPMSolverMultistepScheduler,
                               EulerAncestralDiscreteScheduler,
                               EulerDiscreteScheduler)

        # Settings shared by every scheduler variant.
        base = dict(num_train_timesteps=1000, beta_schedule="scaled_linear",
                    prediction_type="epsilon")

        if name == "DPM++ 2M Karras":
            return DPMSolverMultistepScheduler(
                **base, algorithm_type="dpmsolver++",
                solver_order=2, use_karras_sigmas=True)
        if name == "DPM++ SDE Karras":
            return DPMSolverMultistepScheduler(
                **base, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
        if name == "Euler a":
            return EulerAncestralDiscreteScheduler(**base)
        if name == "Euler":
            return EulerDiscreteScheduler(**base)
        return DDIMScheduler(**base, clip_sample=False, set_alpha_to_one=False)

    def _decode_latents(self, latents):
        """Decode latents with the VAE and return a post-processed PIL image.

        Only the first item of the batch is converted. The decoded image is
        mapped from ``[-1, 1]`` to ``[0, 255]`` and then lightly sharpened,
        contrast-boosted and colour-boosted with PIL.
        """
        scaled = latents / self.vae.config.scaling_factor
        with torch.no_grad():
            image = self.vae.decode(scaled.float()).sample
        image = (image.float() / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
        image = (image * 255).round().astype("uint8")

        img = Image.fromarray(image)
        img = img.filter(ImageFilter.UnsharpMask(radius=1.5, percent=40, threshold=2))
        img = ImageEnhance.Contrast(img).enhance(1.06)
        img = ImageEnhance.Color(img).enhance(1.10)
        return img

    def _sharpen_latents(self, latents, amount=0.08):
        """Unsharp-mask the latents: add ``amount`` times (latents - 3x3 mean)."""
        blurred = F.avg_pool2d(latents, kernel_size=3, stride=1, padding=1)
        return latents + amount * (latents - blurred)

    def _encode_text(self, text):
        """Tokenize ``text`` (padded/truncated to the tokenizer's max length)
        and return the first output of the text encoder."""
        tokens = self.tokenizer(text, padding="max_length",
                                max_length=self.tokenizer.model_max_length,
                                truncation=True, return_tensors="pt")
        return self.text_encoder(tokens.input_ids.to(self.device))[0]

    def generate(self, prompt, negative_prompt="", steps=25,
                 guidance_scale=7.5, seed=-1, scheduler_name="DPM++ 2M Karras"):
        """Generate one image from ``prompt`` with classifier-free guidance.

        Args:
            prompt: Text describing the desired image.
            negative_prompt: Text to steer away from; ``None`` is treated as
                an empty string.
            steps: Number of scheduler timesteps.
            guidance_scale: Weight applied to (text - negative) predictions.
            seed: RNG seed for the initial noise; a negative value draws a
                random seed in ``[0, 2**32)``.
            scheduler_name: Name understood by :meth:`_make_scheduler`.

        Returns:
            ``(image, seed)`` where ``image`` is a PIL image and ``seed`` is
            the seed actually used.
        """
        self.load()

        if seed < 0:
            seed = torch.randint(0, 2**32, (1,)).item()
        # Named `rng` so it does not shadow the module-level `gen` instance.
        rng = torch.Generator(device=self.device).manual_seed(seed)

        scheduler = self._make_scheduler(scheduler_name)
        scheduler.set_timesteps(steps, device=self.device)

        # Inference only: without no_grad every UNet call would record an
        # autograd graph that the latents keep alive across all steps.
        with torch.no_grad():
            neg_emb = self._encode_text(negative_prompt or "")
            text_emb = self._encode_text(prompt)
            # Negative first, so chunk(2) below yields (negative, text).
            combined = torch.cat([neg_emb, text_emb])

            latents = torch.randn(1, 4, self.latent_size, self.latent_size,
                                  generator=rng, device=self.device)
            latents = latents * scheduler.init_noise_sigma

            for t in scheduler.timesteps:
                # Duplicate the latents so both prompts run in one batch.
                inp = torch.cat([latents] * 2)
                inp = scheduler.scale_model_input(inp, t)
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
                                    enabled=(self.device == "cuda")):
                    pred = self.unet(inp, t, encoder_hidden_states=combined).sample
                pred_neg, pred_text = pred.chunk(2)
                pred = pred_neg + guidance_scale * (pred_text - pred_neg)
                latents = scheduler.step(pred, t, latents).prev_sample

            latents = self._sharpen_latents(latents)

        return self._decode_latents(latents), seed
# ── Shared generator ─────────────────────────────────────────────────────────
# One module-level Generator instance is used by the UI callback. Constructing
# it loads no weights: generate() calls load() itself, so the models are
# fetched on the first generation request rather than here.
gen = Generator()
# ── Gradio UI ────────────────────────────────────────────────────────────────
def run(prompt, negative, steps, cfg, scheduler, seed):
    """Gradio click handler: validate the inputs and generate one image.

    Args:
        prompt: Prompt text; empty, whitespace-only or ``None`` is rejected.
        negative: Negative prompt text (``None`` is treated as empty).
        steps: Number of sampling steps (coerced to ``int``).
        cfg: Guidance scale (coerced to ``float``).
        scheduler: Scheduler name forwarded to ``gen.generate``.
        seed: Seed value; ``None`` or a negative number means "random".

    Returns:
        ``(image, status)``. ``image`` is ``None`` when the prompt is
        rejected; otherwise ``status`` reports the seed that was used.
    """
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt!"

    # A cleared Number field is delivered as None; int(None) would raise, so
    # map it to the "random seed" sentinel instead.
    seed_value = -1 if seed is None else int(seed)

    image, used_seed = gen.generate(
        prompt=prompt,
        negative_prompt=negative or "",
        steps=int(steps),
        guidance_scale=float(cfg),
        seed=seed_value,
        scheduler_name=scheduler,
    )
    return image, f"Seed: {used_seed}"
# Build the two-column UI: inputs on the left, result on the right.
# The emoji and punctuation below were mis-decoded in the page this file was
# recovered from; they are restored to the intended characters (🎨, ·, ×, ✨).
with gr.Blocks(title="Aniimage-1 by 8BitStudio") as demo:
    gr.Markdown(
        "# 🎨 Aniimage-1\n"
        "Anime image generator by **8BitStudio** · 256×256 · "
        "Trained from scratch on 830k Danbooru images\n\n"
        "Use plain English: "
        "*\"A smiling anime girl with red hair and a school uniform\"*"
    )
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", lines=3,
                                placeholder="A smiling anime girl with red hair and a school uniform")
            negative = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE, lines=2)
            with gr.Row():
                steps = gr.Slider(10, 50, value=25, step=1, label="Steps")
                cfg = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="CFG Scale")
            with gr.Row():
                scheduler = gr.Dropdown(SCHEDULER_LIST, value="DPM++ 2M Karras", label="Scheduler")
                seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
            btn = gr.Button("✨ Generate", variant="primary")
        with gr.Column(scale=1):
            output = gr.Image(label="Generated Image", type="pil")
            seed_out = gr.Textbox(label="Used Seed", interactive=False)

    # Event wiring must happen inside the Blocks context.
    btn.click(run, inputs=[prompt, negative, steps, cfg, scheduler, seed],
              outputs=[output, seed_out])

    # Clicking an example only fills the input components.
    gr.Examples(
        examples=[
            ["A smiling anime girl with red hair and a school uniform", DEFAULT_NEGATIVE, 25, 7.5, "DPM++ 2M Karras", -1],
            ["A mysterious anime girl with silver hair under a night sky with stars", DEFAULT_NEGATIVE, 25, 7.5, "DPM++ 2M Karras", -1],
            ["An anime girl in a maid dress holding a teacup, cherry blossoms in the background", DEFAULT_NEGATIVE, 30, 7.5, "DPM++ 2M Karras", -1],
        ],
        inputs=[prompt, negative, steps, cfg, scheduler, seed],
    )

demo.launch()