"""Hugging Face Spaces demo for PASD (Pixel-Aware Stable Diffusion) image super-resolution.

Module-level flow:
  1. Shim/monkey-patch the environment (old `cached_download` API, fake CUDA device info).
  2. Download model assets from the Hub.
  3. Patch the vendored PASD sources so they import under newer `diffusers` layouts.
  4. Build the ControlNet SR pipeline and a ResNet-50 classifier used to enrich prompts.
  5. Serve a Gradio UI.
"""

import sys
import types
import datetime
import re
from pathlib import Path

import huggingface_hub

# -------------------------------------------------------------------
# Compatibility shim: older diffusers may still expect cached_download
# -------------------------------------------------------------------
if not hasattr(huggingface_hub, "cached_download"):

    def cached_download(*args, **kwargs):
        # hf_hub_download is the modern replacement; forward everything as-is.
        return huggingface_hub.hf_hub_download(*args, **kwargs)

    huggingface_hub.cached_download = cached_download

import torch
import numpy as np
import einops
import spaces
import gradio as gr
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
from pytorch_lightning import seed_everything
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    PNDMScheduler,
    DPMSolverMultistepScheduler,
    UniPCMultistepScheduler,
)

# -------------------------------------------------------------------
# GPU spoof for Spaces env compatibility
# -------------------------------------------------------------------
# Some code paths query device capability at import time before a GPU is
# attached; report an A10G so those checks pass.
torch.cuda.get_device_capability = lambda *args, **kwargs: (8, 6)
torch.cuda.get_device_properties = lambda *args, **kwargs: types.SimpleNamespace(
    name="NVIDIA A10G",
    major=8,
    minor=6,
    total_memory=23836033024,
    multi_processor_count=80,
)

# -------------------------------------------------------------------
# Download required assets
# -------------------------------------------------------------------
huggingface_hub.snapshot_download(
    repo_id="camenduru/PASD",
    allow_patterns=[
        "pasd/**",
        "pasd_light/**",
        "pasd_light_rrdb/**",
        "pasd_rrdb/**",
    ],
    local_dir="PASD/runs",
)
huggingface_hub.hf_hub_download(
    repo_id="camenduru/PASD",
    filename="majicmixRealistic_v6.safetensors",
    local_dir="PASD/checkpoints/personalized_models",
)
huggingface_hub.hf_hub_download(
    repo_id="akhaliq/RetinaFace-R50",
    filename="RetinaFace-R50.pth",
    local_dir="PASD/annotator/ckpts",
)

# -------------------------------------------------------------------
# PASD local path
# -------------------------------------------------------------------
sys.path.append("./PASD")


# -------------------------------------------------------------------
# Runtime patching helpers
# -------------------------------------------------------------------
def patch_file(path_str: str, replacements: list[tuple[str, str]]) -> None:
    """Apply literal ``(old, new)`` string replacements to a file in place.

    Best-effort: missing files and read/write failures are logged and skipped,
    never raised, so a partially-patched checkout does not abort startup.
    """
    path = Path(path_str)
    if not path.exists():
        print(f"[patch] file not found: {path}")
        return
    try:
        text = path.read_text(encoding="utf-8")
    except Exception as e:
        print(f"[patch] failed reading {path}: {e}")
        return
    original = text
    for old, new in replacements:
        text = text.replace(old, new)
    if text != original:
        try:
            path.write_text(text, encoding="utf-8")
            print(f"[patch] updated: {path}")
        except Exception as e:
            print(f"[patch] failed writing {path}: {e}")
    else:
        print(f"[patch] no changes: {path}")


def patch_controlnet_loader_import(path_str: str) -> None:
    """Normalize the ControlNet mixin import in a vendored PASD source file.

    The mixin was renamed across diffusers versions
    (``FromOriginalControlNetMixin`` vs ``FromOriginalControlnetMixin``) and
    later removed; replace any existing import with a try/except block that
    resolves whichever name exists, falling back to a no-op class.
    Idempotent: re-running on an already-patched file makes no changes.
    """
    path = Path(path_str)
    if not path.exists():
        print(f"[patch] file not found: {path}")
        return
    try:
        text = path.read_text(encoding="utf-8")
    except Exception as e:
        print(f"[patch] failed reading {path}: {e}")
        return

    safe_block = """try:
    from diffusers.loaders import FromOriginalControlNetMixin as FromOriginalControlnetMixin
except Exception:
    try:
        from diffusers.loaders import FromOriginalControlnetMixin
    except Exception:
        class FromOriginalControlnetMixin:
            pass
"""
    original = text
    # Remove old plain imports of the mixin
    text = re.sub(
        r"(?m)^from diffusers\.loaders[^\n]*FromOriginalControl\w*Mixin[^\n]*\n",
        "",
        text,
    )
    text = re.sub(
        r"(?m)^from diffusers\.loaders\.single_file_model[^\n]*FromOriginal\w+[^\n]*\n",
        "",
        text,
    )
    # Remove old broken try/except blocks related to this mixin
    text = re.sub(
        r"(?ms)^try:\n(?:(?: |\t).*\n)+?except Exception:\n(?:(?: |\t).*\n)+?(?=^(?:class|def|@|from |import |\Z))",
        lambda m: "" if "FromOriginalControl" in m.group(0) else m.group(0),
        text,
    )
    # Normalize the mixin reference in the rest of the file
    text = text.replace("FromOriginalControlNetMixin", "FromOriginalControlnetMixin")

    # Insert the safe import block right before the first class definition
    # (or at the top of the file if the marker is absent).
    marker = "class ControlNetConditioningEmbedding"
    if safe_block not in text:
        idx = text.find(marker)
        if idx != -1:
            text = text[:idx] + safe_block + text[idx:]
        else:
            text = safe_block + text

    if text != original:
        try:
            path.write_text(text, encoding="utf-8")
            print(f"[patch] normalized: {path}")
        except Exception as e:
            print(f"[patch] failed writing {path}: {e}")
    else:
        print(f"[patch] no changes: {path}")


def patch_pasd_for_diffusers() -> None:
    """Rewrite vendored PASD sources so they import under newer diffusers releases."""
    # pipeline_utils path moved
    patch_file(
        "./PASD/pipelines/pipeline_pasd.py",
        [
            (
                "from diffusers.pipeline_utils import DiffusionPipeline",
                "from diffusers import DiffusionPipeline",
            ),
        ],
    )
    # PositionNet -> GLIGENTextBoundingboxProjection alias
    patch_file(
        "./PASD/models/pasd/unet_2d_condition.py",
        [
            ("    PositionNet,\n", ""),
            (
                "    GLIGENTextBoundingboxProjection,\n",
                "    GLIGENTextBoundingboxProjection as PositionNet,\n",
            ),
        ],
    )
    # internal module paths moved in newer diffusers
    patch_file(
        "./PASD/models/pasd/unet_2d_blocks.py",
        [
            (
                "from diffusers.models.attention import AdaGroupNorm",
                "from diffusers.models.normalization import AdaGroupNorm",
            ),
            (
                "from diffusers.models.dual_transformer_2d import DualTransformer2DModel",
                "from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel",
            ),
            (
                "from diffusers.models.transformer_2d import Transformer2DModel",
                "from diffusers.models.transformers.transformer_2d import Transformer2DModel",
            ),
        ],
    )
    # robust controlnet patch
    patch_controlnet_loader_import("./PASD/models/pasd/controlnet.py")


patch_pasd_for_diffusers()

# -------------------------------------------------------------------
# Import PASD modules only after patching
# -------------------------------------------------------------------
from pipelines.pipeline_pasd import StableDiffusionControlNetPipeline
from myutils.misc import load_dreambooth_lora
from myutils.wavelet_color_fix import wavelet_color_fix
from annotator.retinaface import RetinaFaceDetection

use_pasd_light = False
face_detector = RetinaFaceDetection()

if use_pasd_light:
    from models.pasd_light.unet_2d_condition import UNet2DConditionModel
    from models.pasd_light.controlnet import ControlNetModel
else:
    from models.pasd.unet_2d_condition import UNet2DConditionModel
    from models.pasd.controlnet import ControlNetModel

# -------------------------------------------------------------------
# Model setup
# -------------------------------------------------------------------
pretrained_model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
ckpt_path = "PASD/runs/pasd/checkpoint-100000"
dreambooth_lora_path = "PASD/checkpoints/personalized_models/majicmixRealistic_v6.safetensors"
weight_dtype = torch.float16
device = "cuda"

scheduler = UniPCMultistepScheduler.from_pretrained(
    pretrained_model_path, subfolder="scheduler"
)
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_path, subfolder="text_encoder"
)
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_path, subfolder="tokenizer"
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_path, subfolder="vae"
)
feature_extractor = CLIPImageProcessor.from_pretrained(
    pretrained_model_path, subfolder="feature_extractor"
)
unet = UNet2DConditionModel.from_pretrained(
    ckpt_path, subfolder="unet"
)
controlnet = ControlNetModel.from_pretrained(
    ckpt_path, subfolder="controlnet"
)

# Inference only: freeze every sub-model.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
controlnet.requires_grad_(False)

# Merge the personalized DreamBooth/LoRA weights into the base models.
unet, vae, text_encoder = load_dreambooth_lora(
    unet, vae, text_encoder, dreambooth_lora_path
)

text_encoder.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)
controlnet.to(device, dtype=weight_dtype)

validation_pipeline = StableDiffusionControlNetPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    unet=unet,
    controlnet=controlnet,
    scheduler=scheduler,
    safety_checker=None,
    requires_safety_checker=False,
)
# Tiled VAE decoding keeps memory bounded at large output resolutions.
validation_pipeline._init_tiled_vae(decoder_tile_size=224)

# -------------------------------------------------------------------
# ResNet helper
# -------------------------------------------------------------------
# ImageNet classifier used to auto-append a category name to the prompt.
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
resnet = resnet50(weights=weights)
resnet.eval()


def resize_image(image_path: str, target_height: int) -> Image.Image:
    """Load an image from disk and resize it to *target_height*, keeping aspect ratio."""
    with Image.open(image_path) as img:
        ratio = target_height / float(img.size[1])
        new_width = int(float(img.size[0]) * ratio)
        return img.resize((new_width, target_height), Image.LANCZOS)


@spaces.GPU(enable_queue=True)
def inference(
    input_image,
    prompt,
    a_prompt,
    n_prompt,
    denoise_steps,
    upscale,
    alpha,
    cfg,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Run PASD super-resolution on *input_image* (a file path from Gradio).

    Returns ``(input_path, result_path, result_path)`` — saved JPEG paths for
    the before/after image components and the downloadable file output.
    On pipeline failure a black 512x512 placeholder is returned instead of raising.
    """
    if seed == -1:
        seed = 0

    # Normalize the working size: height 512, aspect ratio preserved.
    input_image = resize_image(input_image, 512)
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")

    with torch.no_grad():
        seed_everything(seed)
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

        input_image = input_image.convert("RGB")

        # Classify the image; if confident enough, enrich the prompt with the
        # predicted ImageNet category name.
        batch = preprocess(input_image).unsqueeze(0)
        prediction = resnet(batch).squeeze(0).softmax(0)
        class_id = prediction.argmax().item()
        score = prediction[class_id].item()
        category_name = weights.meta["categories"][class_id]
        if score >= 0.1:
            prompt += f"{category_name}" if prompt == "" else f", {category_name}"
        prompt = a_prompt if prompt == "" else f"{prompt}, {a_prompt}"

        ori_width, ori_height = input_image.size
        rscale = upscale
        input_image = input_image.resize(
            (input_image.size[0] * rscale, input_image.size[1] * rscale)
        )
        # Round both dimensions down to a multiple of 8 (VAE latent constraint).
        input_image = input_image.resize(
            (input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8)
        )
        width, height = input_image.size

        try:
            image = validation_pipeline(
                None,
                prompt,
                input_image,
                num_inference_steps=denoise_steps,
                generator=generator,
                height=height,
                width=width,
                guidance_scale=cfg,
                negative_prompt=n_prompt,
                conditioning_scale=alpha,
                eta=0.0,
            ).images[0]
            # Transfer low-frequency color statistics from the input to avoid
            # color drift, then restore the exact requested output size.
            image = wavelet_color_fix(image, input_image)
            image = image.resize((ori_width * rscale, ori_height * rscale))
        except Exception as e:
            print(f"[inference] error: {e}")
            image = Image.new(mode="RGB", size=(512, 512))

    result_path = f"result_{timestamp}.jpg"
    input_path = f"input_{timestamp}.jpg"
    image.save(result_path, "JPEG")
    input_image.save(input_path, "JPEG")
    return input_path, result_path, result_path


css = """
#col-container{
    margin: 0 auto;
    max-width: 720px;
}
#project-links{
    margin: 0 0 12px !important;
    column-gap: 8px;
    display: flex;
    justify-content: center;
    flex-wrap: nowrap;
    flex-direction: row;
    align-items: center;
}
"""

# NOTE: custom CSS must be passed to the Blocks constructor — launch() has no
# `css` parameter, so passing it there left the styling unapplied.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
""")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    type="filepath",
                    sources=["upload"],
                    value="PASD/samples/frog.png",
                    label="Input image",
                )
                prompt_in = gr.Textbox(label="Prompt", value="Frog")
                with gr.Accordion(label="Advanced settings", open=False):
                    added_prompt = gr.Textbox(
                        label="Added Prompt",
                        value="clean, high-resolution, 8k, best quality, masterpiece",
                    )
                    neg_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                    )
                    denoise_steps = gr.Slider(
                        label="Denoise Steps",
                        minimum=10,
                        maximum=50,
                        value=20,
                        step=1,
                    )
                    upsample_scale = gr.Slider(
                        label="Upsample Scale",
                        minimum=1,
                        maximum=4,
                        value=2,
                        step=1,
                    )
                    condition_scale = gr.Slider(
                        label="Conditioning Scale",
                        minimum=0.5,
                        maximum=1.5,
                        value=1.1,
                        step=0.1,
                    )
                    classifier_free_guidance = gr.Slider(
                        label="Classifier-free Guidance",
                        minimum=0.1,
                        maximum=10.0,
                        value=7.5,
                        step=0.1,
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=-1,
                        maximum=2147483647,
                        step=1,
                        randomize=True,
                    )
                submit_btn = gr.Button("Submit")
            with gr.Column():
                before_img = gr.Image(label="Input")
                after_img = gr.Image(label="Result")
                file_output = gr.File(label="Downloadable image result")

    submit_btn.click(
        fn=inference,
        inputs=[
            input_image,
            prompt_in,
            added_prompt,
            neg_prompt,
            denoise_steps,
            upsample_scale,
            condition_scale,
            classifier_free_guidance,
            seed,
        ],
        outputs=[
            before_img,
            after_img,
            file_output,
        ],
        api_visibility="private",
    )

demo.queue(max_size=10).launch(
    ssr_mode=False,
    mcp_server=False,
)