"""Gradio demo: Canny-edge ControlNet on Stable Diffusion 1.5.

The app computes a percentile-thresholded Canny map from an uploaded image
(CPU, live preview) and feeds it to a ControlNet-conditioned SD1.5 pipeline
(GPU) to generate stylised images that follow the edge structure.
"""

import os
from dataclasses import dataclass

import cv2
import gradio as gr
import numpy as np
import spaces  # type: ignore
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file

from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.controlnets.controlnet import ControlNetModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines.controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from transformers import CLIPTextModel, CLIPTokenizer

BIG_CSS = """
/* Global bump */
.gradio-container { font-size: 18px !important; }

/* Force most UI text bigger */
.gradio-container * { font-size: 18px !important; }

/* Keep markdown headings bigger */
.gradio-container h1 { font-size: 28px !important; }
.gradio-container h2 { font-size: 24px !important; }
.gradio-container h3 { font-size: 20px !important; }

/* Slightly smaller helper/info text if you want */
.gradio-container .info, .gradio-container .prose p, .gradio-container .prose li { font-size: 16px !important; line-height: 1.35 !important; }
"""


# -----------------------------
# Pipeline builder
# -----------------------------
def build_controlnet_pipe(
    base_model_name: str,
    controlnet: ControlNetModel,
    vae: AutoencoderKL,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    device: torch.device,
    weight_dtype: torch.dtype,
    use_unipc: bool = True,
) -> StableDiffusionControlNetPipeline:
    """Assemble a StableDiffusionControlNetPipeline from preloaded components.

    The safety checker is disabled and the per-step progress bar is silenced
    (the Gradio UI provides its own feedback). When ``use_unipc`` is True the
    default scheduler is swapped for UniPC, which converges in fewer steps.
    """
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        base_model_name,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        torch_dtype=weight_dtype,
    )
    if use_unipc:
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)
    pipe.set_progress_bar_config(disable=True)
    return pipe


@dataclass
class CannyCFG:
    """Tunables for the percentile-based Canny edge extractor."""

    use_clahe: bool = True       # equalise local contrast before detection
    clahe_clip: float = 2.0      # CLAHE clip limit
    clahe_grid: int = 8          # CLAHE tile grid size (grid x grid)
    gaussian_ksize: int = 5      # blur kernel size (forced odd downstream)
    gaussian_sigma: float = 1.2  # blur sigma
    high_pct: float = 90.0       # higher -> fewer edges (stricter)
    low_ratio: float = 0.4       # low = low_ratio * high
    aperture_size: int = 3       # Sobel aperture for cv2.Canny (3, 5 or 7)
    l2_gradient: bool = True     # use L2 norm for gradient magnitude


def canny_percentile(pil_img: Image.Image, cfg: CannyCFG) -> Image.Image:
    """Run Canny with thresholds derived from the image's gradient-magnitude
    percentile, so edge density adapts to image content rather than using
    fixed absolute thresholds.

    Returns a single-channel ("L") PIL image of the edge map.
    """
    gray = np.array(pil_img.convert("L"), dtype=np.uint8)
    if cfg.use_clahe:
        clahe = cv2.createCLAHE(
            clipLimit=float(cfg.clahe_clip),
            tileGridSize=(int(cfg.clahe_grid), int(cfg.clahe_grid)),
        )
        gray = clahe.apply(gray)
    k = int(cfg.gaussian_ksize) | 1  # ensure odd
    blur = cv2.GaussianBlur(gray, (k, k), float(cfg.gaussian_sigma))
    # Estimate the gradient-magnitude distribution to pick data-driven thresholds.
    gx = cv2.Sobel(blur, cv2.CV_32F, 1, 0, ksize=3)
    gy = cv2.Sobel(blur, cv2.CV_32F, 0, 1, ksize=3)
    mag = cv2.magnitude(gx, gy)
    high = float(np.percentile(mag, float(cfg.high_pct)))
    low = float(cfg.low_ratio) * high
    if high <= low:
        # Degenerate (e.g. flat image): keep thresholds ordered for cv2.Canny.
        high = low + 1.0
    ap = int(cfg.aperture_size)
    if ap not in (3, 5, 7):  # cv2.Canny only accepts these apertures
        ap = 3
    edges = cv2.Canny(
        blur,
        threshold1=low,
        threshold2=high,
        apertureSize=ap,
        L2gradient=bool(cfg.l2_gradient),
    )
    return Image.fromarray(edges, mode="L")


# -----------------------------
# Config
# -----------------------------
BASE_MODEL = "sd-legacy/stable-diffusion-v1-5"
WEIGHTS_REPO = "mvp-lab/ControlNet_Weight"
WEIGHTS_FILENAME = "diffusion_pytorch_model_1.safetensors"

# Prefer a local checkpoint (overridable via env var); fall back to the Hub.
LOCAL_WEIGHTS = os.getenv(
    "CONTROLNET_WEIGHTS",
    "/home/nik/ImperialWork/GenerativeAi/sd15-controlnet-trainer/controlnet_laion/final/diffusion_pytorch_model.safetensors",
)
if os.path.isfile(LOCAL_WEIGHTS):
    CONTROLNET_PATH = LOCAL_WEIGHTS
else:
    CONTROLNET_PATH = hf_hub_download(repo_id=WEIGHTS_REPO, filename=WEIGHTS_FILENAME, repo_type="model")

DTYPE = torch.float32
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# Model load (once)
# -----------------------------
vae = AutoencoderKL.from_pretrained(BASE_MODEL, subfolder="vae", torch_dtype=DTYPE)
unet = UNet2DConditionModel.from_pretrained(BASE_MODEL, subfolder="unet", torch_dtype=DTYPE)
tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(BASE_MODEL, subfolder="text_encoder", torch_dtype=DTYPE)

# Inference only: freeze everything.
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)

controlnet = ControlNetModel.from_unet(unet, conditioning_channels=3)
state = load_file(CONTROLNET_PATH)
missing, unexpected = controlnet.load_state_dict(state, strict=False)
if missing or unexpected:
    # strict=False would otherwise hide a mismatched checkpoint entirely.
    print(
        f"[controlnet] load_state_dict: {len(missing)} missing, "
        f"{len(unexpected)} unexpected keys"
    )

pipe = build_controlnet_pipe(
    base_model_name=BASE_MODEL,
    controlnet=controlnet,
    vae=vae,
    unet=unet,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    device=DEVICE,
    weight_dtype=DTYPE,
    use_unipc=True,
)


# -----------------------------
# Helpers: fixed resize policy (longest side = 512, keep aspect, divisible by 8)
# -----------------------------
def round_down_to_multiple(x: int, m: int = 8) -> int:
    """Round ``x`` down to a multiple of ``m``, with a floor of ``m`` itself."""
    return max(m, (x // m) * m)


def resize_longest_side_div8(img: Image.Image, longest: int = 512) -> tuple[Image.Image, int, int]:
    """Resize so the longest side equals ``longest``, keeping aspect ratio and
    rounding both sides down to a multiple of 8 (SD1.5 latent requirement).

    Returns (resized_image, width, height). Raises ValueError on a degenerate
    image size.
    """
    w, h = img.size
    if w <= 0 or h <= 0:
        raise ValueError("Invalid image size")
    scale = float(longest) / float(max(w, h))
    tw = int(round(w * scale))
    th = int(round(h * scale))
    # round_down_to_multiple already enforces a floor of 8.
    tw = round_down_to_multiple(tw, 8)
    th = round_down_to_multiple(th, 8)
    resized = img.resize((tw, th), resample=Image.BICUBIC)  # type: ignore
    return resized, tw, th


def compute_canny_rgb(
    img_rgb_resized: Image.Image,
    use_clahe: bool,
    edge_amount: float,
    smoothing: float,
) -> Image.Image:
    """Map the two UI sliders onto CannyCFG parameters and return an RGB edge map.

    ``edge_amount`` lowers the high-threshold percentile (more edges);
    ``smoothing`` raises the Gaussian sigma (fewer texture/noise edges).
    """
    high_pct = 95.0 - 20.0 * float(edge_amount)  # 0 => 95 (few), 1 => 75 (many)
    high_pct = float(np.clip(high_pct, 70.0, 99.0))
    gaussian_sigma = 0.6 + 2.2 * float(smoothing)  # 0 => 0.6, 1 => 2.8
    cfg = CannyCFG(
        use_clahe=bool(use_clahe),
        clahe_clip=2.0,
        clahe_grid=8,
        gaussian_ksize=5,
        gaussian_sigma=float(gaussian_sigma),
        high_pct=float(high_pct),
        low_ratio=0.4,
        aperture_size=3,
        l2_gradient=True,
    )
    edges_l = canny_percentile(img_rgb_resized, cfg)
    # ControlNet was built with 3 conditioning channels; replicate L -> RGB.
    return edges_l.convert("RGB")


def update_canny_preview(input_image, use_clahe, edge_amount, smoothing):
    """CPU callback: recompute the Canny preview whenever an input changes.

    Returns (preview_image, canny_state, width, height); all-None/defaults
    when no image is loaded yet.
    """
    if input_image is None:
        return None, None, 512, 512
    if not isinstance(input_image, Image.Image):
        input_image = Image.fromarray(input_image)
    img_rgb0 = input_image.convert("RGB")
    img_rgb, width, height = resize_longest_side_div8(img_rgb0, longest=512)
    canny = compute_canny_rgb(
        img_rgb,
        use_clahe=use_clahe,
        edge_amount=float(edge_amount),
        smoothing=float(smoothing),
    )
    return canny, canny, width, height


@spaces.GPU
@torch.inference_mode()
def generate_from_canny(
    canny: Image.Image,
    width: int,
    height: int,
    prompt: str,
    negative_prompt: str,
    guidance_scale: float,
    num_inference_steps: int,
    num_images: int,
    controlnet_conditioning_scale: float,
):
    """GPU callback: run the pipeline on the stored Canny conditioning image.

    Seeds are fixed at 0..n-1 so a given set of settings is reproducible.
    Returns (first_image, all_images).
    """
    if canny is None:
        raise gr.Error("Canny conditioning image missing. Upload an image first.")
    if int(num_images) < 1:
        raise gr.Error("num_images must be >= 1")
    gens = [torch.Generator(device=DEVICE).manual_seed(i) for i in range(int(num_images))]
    imgs = pipe(
        prompt=[prompt] * int(num_images),
        negative_prompt=[negative_prompt] * int(num_images),
        image=[canny] * int(num_images),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
        height=int(height),
        width=int(width),
        generator=gens,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
    ).images  # type: ignore
    first = imgs[0] if imgs else None
    return first, imgs


def next_image(images, idx):
    """Advance the gallery index (wraps around); returns (image, idx, label)."""
    if not images:
        return None, 0, "0 / 0"
    idx = (int(idx) + 1) % len(images)
    return images[idx], idx, f"{idx + 1} / {len(images)}"


def prev_image(images, idx):
    """Step the gallery index back (wraps around); returns (image, idx, label)."""
    if not images:
        return None, 0, "0 / 0"
    idx = (int(idx) - 1) % len(images)
    return images[idx], idx, f"{idx + 1} / {len(images)}"


# -----------------------------
# UI
# -----------------------------
IMG_H = 360  # uniform-ish size for both preview boxes

with gr.Blocks(css=BIG_CSS) as demo:
    gr.Markdown("# Canny-Edge ControlNet Demo")
    gr.Markdown("**Note:** Trained on aesthetic/artistic images — best results come from similar, stylised inputs.")

    # state
    canny_state = gr.State(None)
    width_state = gr.State(512)
    height_state = gr.State(512)
    gen_images_state = gr.State([])  # list[PIL]
    gen_index_state = gr.State(0)

    with gr.Row():
        # ---- Left: Canny + Canny controls ----
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                image_mode="RGB",
                height=IMG_H,
            )
            canny_preview = gr.Image(
                label="Canny edges",
                type="pil",
                height=IMG_H,
            )
            gr.Markdown("### Edge controls")
            use_clahe = gr.Checkbox(
                label="Stabilise contrast (CLAHE)",
                value=True,
                info="Helps edges stay consistent under different lighting/contrast.",
            )
            edge_amount = gr.Slider(
                label="Edge Amount",
                minimum=0.0,
                maximum=1.0,
                value=0.6,
                step=0.01,
                info="More = detect more edges (more detail). Less = cleaner outline.",
            )
            smoothing = gr.Slider(
                label="Smoothing",
                minimum=0.0,
                maximum=1.0,
                value=0.4,
                step=0.01,
                info="More = reduce tiny texture/noise edges, cleaner structure.",
            )

        # ---- Right: Generated output + generation controls ----
        with gr.Column(scale=1):
            generated = gr.Image(
                label="Generated image",
                type="pil",
                height=IMG_H,
            )
            with gr.Row():
                prev_btn = gr.Button("◀ Prev")
                page_label = gr.Markdown("0 / 0")
                next_btn = gr.Button("Next ▶")
            gr.Markdown("### Generation controls")
            positive_prompt = gr.Textbox(
                label="Positive Prompt",
                value="",
                lines=2,
                info="Describe what you want. The edges guide the structure.",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="",
                lines=2,
                info="Things to avoid (e.g. blurry, deformed, low quality).",
            )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=15.0,
                    value=7.5,
                    step=0.1,
                    info="Higher = follow text prompt more strongly (can drift from edges).",
                )
                controlnet_conditioning_scale = gr.Slider(
                    label="Control Strength",
                    minimum=0.0,
                    maximum=2.0,
                    value=1.0,
                    step=0.05,
                    info="Higher = follow edges more strongly. Too high can reduce creativity.",
                )
            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="Steps",
                    minimum=10,
                    maximum=80,
                    value=50,
                    step=1,
                    info="More steps can improve quality but is slower.",
                )
                num_images = gr.Slider(
                    label="Samples",
                    minimum=1,
                    maximum=8,
                    value=4,
                    step=1,
                    info="How many images to generate.",
                )
            run_btn = gr.Button("Generate", variant="primary")

    # Auto-update Canny preview on changes (CPU)
    auto_inputs = [input_image, use_clahe, edge_amount, smoothing]
    for c in auto_inputs:
        c.change(
            fn=update_canny_preview,
            inputs=auto_inputs,
            outputs=[canny_preview, canny_state, width_state, height_state],
        )

    # Generate (GPU) -> store list -> show first -> update paging label
    run_btn.click(
        fn=generate_from_canny,
        inputs=[
            canny_state,
            width_state,
            height_state,
            positive_prompt,
            negative_prompt,
            guidance_scale,
            num_inference_steps,
            num_images,
            controlnet_conditioning_scale,
        ],
        outputs=[generated, gen_images_state],  # visible output first => proper "Generating..." UX
    ).then(
        fn=lambda imgs: (0, f"1 / {len(imgs)}") if imgs else (0, "0 / 0"),
        inputs=[gen_images_state],
        outputs=[gen_index_state, page_label],
    )

    # Paging buttons (CPU)
    next_btn.click(
        fn=next_image,
        inputs=[gen_images_state, gen_index_state],
        outputs=[generated, gen_index_state, page_label],
    )
    prev_btn.click(
        fn=prev_image,
        inputs=[gen_images_state, gen_index_state],
        outputs=[generated, gen_index_state, page_label],
    )

if __name__ == "__main__":
    demo.launch()