# HuggingFace Spaces app — running on ZeroGPU hardware.
| import sys | |
| import types | |
| import datetime | |
| import re | |
| from pathlib import Path | |
| import huggingface_hub | |
# -------------------------------------------------------------------
# Compatibility shim: older diffusers may still expect cached_download
# -------------------------------------------------------------------
if not hasattr(huggingface_hub, "cached_download"):
    # Newer huggingface_hub releases removed cached_download; alias it to
    # hf_hub_download so legacy diffusers code can still import/call it.
    # NOTE(review): the two APIs do not share a signature (cached_download
    # took a URL as its first argument) — this assumes callers pass
    # hf_hub_download-style arguments; verify against the installed
    # diffusers version.
    def cached_download(*args, **kwargs):
        return huggingface_hub.hf_hub_download(*args, **kwargs)
    huggingface_hub.cached_download = cached_download
| import torch | |
| import numpy as np | |
| import einops | |
| import spaces | |
| import gradio as gr | |
| from PIL import Image | |
| from torchvision import transforms | |
| import torch.nn.functional as F | |
| from torchvision.models import resnet50, ResNet50_Weights | |
| from pytorch_lightning import seed_everything | |
| from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor | |
| from diffusers import ( | |
| AutoencoderKL, | |
| DDIMScheduler, | |
| PNDMScheduler, | |
| DPMSolverMultistepScheduler, | |
| UniPCMultistepScheduler, | |
| ) | |
# -------------------------------------------------------------------
# GPU spoof for Spaces env compatibility
# -------------------------------------------------------------------
# Monkeypatch torch.cuda device queries so any code probing the GPU at
# import time sees a fixed A10G-like device (compute capability 8.6),
# even before a real CUDA context exists — presumably required on ZeroGPU,
# where the device is attached lazily; TODO confirm.
torch.cuda.get_device_capability = lambda *args, **kwargs: (8, 6)
torch.cuda.get_device_properties = lambda *args, **kwargs: types.SimpleNamespace(
    name="NVIDIA A10G",
    major=8,
    minor=6,
    total_memory=23836033024,  # bytes (~22.2 GiB), matching an A10G
    multi_processor_count=80,
)
# -------------------------------------------------------------------
# Download required assets
# -------------------------------------------------------------------
# Fetch the four PASD checkpoint variants (full/light x plain/RRDB) into
# the vendored repo's runs/ directory.
huggingface_hub.snapshot_download(
    repo_id="camenduru/PASD",
    allow_patterns=[
        "pasd/**",
        "pasd_light/**",
        "pasd_light_rrdb/**",
        "pasd_rrdb/**",
    ],
    local_dir="PASD/runs",
)
# Personalized SD1.5 checkpoint blended in later via load_dreambooth_lora.
huggingface_hub.hf_hub_download(
    repo_id="camenduru/PASD",
    filename="majicmixRealistic_v6.safetensors",
    local_dir="PASD/checkpoints/personalized_models",
)
# RetinaFace detector weights used by the PASD annotator package.
huggingface_hub.hf_hub_download(
    repo_id="akhaliq/RetinaFace-R50",
    filename="RetinaFace-R50.pth",
    local_dir="PASD/annotator/ckpts",
)
# -------------------------------------------------------------------
# PASD local path
# -------------------------------------------------------------------
# Make the vendored PASD package importable (pipelines/, models/, myutils/, ...).
sys.path.append("./PASD")
| # ------------------------------------------------------------------- | |
| # Runtime patching helpers | |
| # ------------------------------------------------------------------- | |
def patch_file(path_str: str, replacements: list[tuple[str, str]]) -> None:
    """Apply literal (old, new) substitutions to a text file in place.

    Logs and returns early when the file is missing or unreadable; writes
    the file back (with a log line) only when at least one substitution
    actually changed the content.
    """
    target = Path(path_str)
    if not target.exists():
        print(f"[patch] file not found: {target}")
        return
    try:
        content = target.read_text(encoding="utf-8")
    except Exception as e:
        print(f"[patch] failed reading {target}: {e}")
        return
    patched = content
    for needle, replacement in replacements:
        patched = patched.replace(needle, replacement)
    if patched == content:
        print(f"[patch] no changes: {target}")
        return
    try:
        target.write_text(patched, encoding="utf-8")
        print(f"[patch] updated: {target}")
    except Exception as e:
        print(f"[patch] failed writing {target}: {e}")
def patch_controlnet_loader_import(path_str: str) -> None:
    """Rewrite a PASD controlnet module so its ``FromOriginalControlnetMixin``
    import works across old and new diffusers releases.

    Strips any pre-existing (possibly broken) import of the mixin, then
    injects a defensive try/except import block just before the first class
    definition. No-op (with a log line) when the file is missing or already
    normalized.
    """
    path = Path(path_str)
    if not path.exists():
        print(f"[patch] file not found: {path}")
        return
    try:
        text = path.read_text(encoding="utf-8")
    except Exception as e:
        print(f"[patch] failed reading {path}: {e}")
        return
    # Fallback chain injected into the target file:
    # new diffusers name -> old name -> inert stub class.
    safe_block = """try:
    from diffusers.loaders import FromOriginalControlNetMixin as FromOriginalControlnetMixin
except Exception:
    try:
        from diffusers.loaders import FromOriginalControlnetMixin
    except Exception:
        class FromOriginalControlnetMixin:
            pass
"""
    original = text
    # Remove old single-line imports of the mixin.
    text = re.sub(
        r"(?m)^from diffusers\.loaders[^\n]*FromOriginalControl\w*Mixin[^\n]*\n",
        "",
        text,
    )
    text = re.sub(
        r"(?m)^from diffusers\.loaders\.single_file_model[^\n]*FromOriginal\w+[^\n]*\n",
        "",
        text,
    )
    # Remove old broken try/except blocks related to this mixin (non-matching
    # try/except blocks are preserved via the lambda).
    text = re.sub(
        r"(?ms)^try:\n(?:(?: |\t).*\n)+?except Exception:\n(?:(?: |\t).*\n)+?(?=^(?:class|def|@|from |import |\Z))",
        lambda m: "" if "FromOriginalControl" in m.group(0) else m.group(0),
        text,
    )
    # Normalize the mixin reference in the rest of the file.
    text = text.replace("FromOriginalControlNetMixin", "FromOriginalControlnetMixin")
    marker = "class ControlNetConditioningEmbedding"
    if safe_block not in text:
        # Insert the shim right before the first class; prepend if not found.
        idx = text.find(marker)
        if idx != -1:
            text = text[:idx] + safe_block + text[idx:]
        else:
            text = safe_block + text
    if text != original:
        try:
            path.write_text(text, encoding="utf-8")
            print(f"[patch] normalized: {path}")
        except Exception as e:
            print(f"[patch] failed writing {path}: {e}")
    else:
        print(f"[patch] no changes: {path}")
def patch_pasd_for_diffusers() -> None:
    """Apply source patches so the vendored PASD code imports cleanly
    against the installed diffusers release.
    """
    # Table of (target file, [(old, new), ...]) literal substitutions,
    # applied in order via patch_file.
    simple_patches = [
        # pipeline_utils path moved
        (
            "./PASD/pipelines/pipeline_pasd.py",
            [
                (
                    "from diffusers.pipeline_utils import DiffusionPipeline",
                    "from diffusers import DiffusionPipeline",
                ),
            ],
        ),
        # PositionNet -> GLIGENTextBoundingboxProjection alias
        (
            "./PASD/models/pasd/unet_2d_condition.py",
            [
                (" PositionNet,\n", ""),
                (
                    " GLIGENTextBoundingboxProjection,\n",
                    " GLIGENTextBoundingboxProjection as PositionNet,\n",
                ),
            ],
        ),
        # internal module paths moved in newer diffusers
        (
            "./PASD/models/pasd/unet_2d_blocks.py",
            [
                (
                    "from diffusers.models.attention import AdaGroupNorm",
                    "from diffusers.models.normalization import AdaGroupNorm",
                ),
                (
                    "from diffusers.models.dual_transformer_2d import DualTransformer2DModel",
                    "from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel",
                ),
                (
                    "from diffusers.models.transformer_2d import Transformer2DModel",
                    "from diffusers.models.transformers.transformer_2d import Transformer2DModel",
                ),
            ],
        ),
    ]
    for target, replacements in simple_patches:
        patch_file(target, replacements)
    # robust controlnet patch
    patch_controlnet_loader_import("./PASD/models/pasd/controlnet.py")
patch_pasd_for_diffusers()
# -------------------------------------------------------------------
# Import PASD modules only after patching
# -------------------------------------------------------------------
from pipelines.pipeline_pasd import StableDiffusionControlNetPipeline
from myutils.misc import load_dreambooth_lora
from myutils.wavelet_color_fix import wavelet_color_fix
from annotator.retinaface import RetinaFaceDetection
# Select the full or lightweight PASD variant; the checkpoint path below
# (ckpt_path) must match the chosen variant.
use_pasd_light = False
face_detector = RetinaFaceDetection()
if use_pasd_light:
    from models.pasd_light.unet_2d_condition import UNet2DConditionModel
    from models.pasd_light.controlnet import ControlNetModel
else:
    from models.pasd.unet_2d_condition import UNet2DConditionModel
    from models.pasd.controlnet import ControlNetModel
# -------------------------------------------------------------------
# Model setup
# -------------------------------------------------------------------
pretrained_model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
ckpt_path = "PASD/runs/pasd/checkpoint-100000"
dreambooth_lora_path = "PASD/checkpoints/personalized_models/majicmixRealistic_v6.safetensors"
weight_dtype = torch.float16  # all models run in fp16 on the GPU
device = "cuda"
# Base Stable Diffusion 1.5 components.
scheduler = UniPCMultistepScheduler.from_pretrained(
    pretrained_model_path, subfolder="scheduler"
)
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_path, subfolder="text_encoder"
)
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_path, subfolder="tokenizer"
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_path, subfolder="vae"
)
feature_extractor = CLIPImageProcessor.from_pretrained(
    pretrained_model_path, subfolder="feature_extractor"
)
# PASD-specific UNet and ControlNet from the downloaded checkpoint.
unet = UNet2DConditionModel.from_pretrained(
    ckpt_path, subfolder="unet"
)
controlnet = ControlNetModel.from_pretrained(
    ckpt_path, subfolder="controlnet"
)
# Inference only: freeze all weights.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
controlnet.requires_grad_(False)
# Blend the personalized majicMIX checkpoint into the base weights.
unet, vae, text_encoder = load_dreambooth_lora(
    unet, vae, text_encoder, dreambooth_lora_path
)
text_encoder.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)
controlnet.to(device, dtype=weight_dtype)
validation_pipeline = StableDiffusionControlNetPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    unet=unet,
    controlnet=controlnet,
    scheduler=scheduler,
    safety_checker=None,
    requires_safety_checker=False,
)
# Tiled VAE decoding keeps peak VRAM bounded at large output sizes
# (private PASD pipeline API).
validation_pipeline._init_tiled_vae(decoder_tile_size=224)
# -------------------------------------------------------------------
# ResNet helper
# -------------------------------------------------------------------
# ImageNet classifier used in inference() to guess the image subject and
# append the predicted category name to the prompt.
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
resnet = resnet50(weights=weights)
resnet.eval()
def resize_image(image_path: str, target_height: int) -> Image.Image:
    """Load the image at *image_path* and scale it to *target_height*
    pixels tall, preserving the aspect ratio (LANCZOS resampling).
    """
    with Image.open(image_path) as src:
        src_w, src_h = src.size
        scale = target_height / float(src_h)
        return src.resize((int(float(src_w) * scale), target_height), Image.LANCZOS)
# @spaces.GPU requests ZeroGPU hardware for the duration of the call;
# without it the handler has no GPU on a ZeroGPU Space (and the `spaces`
# import above was otherwise unused).
@spaces.GPU
def inference(
    input_image,
    prompt,
    a_prompt,
    n_prompt,
    denoise_steps,
    upscale,
    alpha,
    cfg,
    seed,
    progress=gr.Progress(track_tqdm=True)
):
    """Super-resolve *input_image* with the PASD ControlNet pipeline.

    Parameters
    ----------
    input_image : str
        Filesystem path to the source image (gradio ``type="filepath"``).
    prompt, a_prompt, n_prompt : str
        User prompt, additional positive prompt, negative prompt.
    denoise_steps : int
        Number of diffusion denoising steps.
    upscale : int
        Integer upsampling factor.
    alpha : float
        ControlNet conditioning scale.
    cfg : float
        Classifier-free guidance scale.
    seed : int
        RNG seed; -1 (the slider sentinel) maps to 0.

    Returns
    -------
    tuple[str, str, str]
        Paths to (resized input JPEG, result JPEG, result JPEG again for
        the file-download component).
    """
    # Gradio sliders can deliver floats even with step=1 — normalize the
    # integer-valued inputs before they reach PIL / torch.
    seed = int(seed)
    denoise_steps = int(denoise_steps)
    rscale = int(upscale)
    if seed == -1:
        seed = 0
    input_image = resize_image(input_image, 512)
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    with torch.no_grad():
        seed_everything(seed)
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)
        input_image = input_image.convert("RGB")
        # Use the ImageNet classifier to guess the subject and enrich the prompt.
        batch = preprocess(input_image).unsqueeze(0)
        prediction = resnet(batch).squeeze(0).softmax(0)
        class_id = prediction.argmax().item()
        score = prediction[class_id].item()
        category_name = weights.meta["categories"][class_id]
        if score >= 0.1:
            prompt += category_name if prompt == "" else f", {category_name}"
        prompt = a_prompt if prompt == "" else f"{prompt}, {a_prompt}"
        ori_width, ori_height = input_image.size
        # Upscale, then snap both dimensions down to multiples of 8
        # (latent-space requirement for the SD VAE).
        input_image = input_image.resize((ori_width * rscale, ori_height * rscale))
        input_image = input_image.resize(
            (input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8)
        )
        width, height = input_image.size
        try:
            image = validation_pipeline(
                None,
                prompt,
                input_image,
                num_inference_steps=denoise_steps,
                generator=generator,
                height=height,
                width=width,
                guidance_scale=cfg,
                negative_prompt=n_prompt,
                conditioning_scale=alpha,
                eta=0.0,
            ).images[0]
            # Transfer low-frequency color statistics from the input to
            # correct diffusion color drift, then restore exact target size.
            image = wavelet_color_fix(image, input_image)
            image = image.resize((ori_width * rscale, ori_height * rscale))
        except Exception as e:
            # Best-effort fallback: surface a blank canvas rather than
            # crashing the UI; the error is logged to the Space console.
            print(f"[inference] error: {e}")
            image = Image.new(mode="RGB", size=(512, 512))
    result_path = f"result_{timestamp}.jpg"
    input_path = f"input_{timestamp}.jpg"
    image.save(result_path, "JPEG")
    input_image.save(input_path, "JPEG")
    return input_path, result_path, result_path
# Custom stylesheet for the gradio UI: centers the main column
# (elem_id="col-container") and lays out the #project-links badge row.
css = """
#col-container{
    margin: 0 auto;
    max-width: 720px;
}
#project-links{
    margin: 0 0 12px !important;
    column-gap: 8px;
    display: flex;
    justify-content: center;
    flex-wrap: nowrap;
    flex-direction: row;
    align-items: center;
}
"""
# Build the gradio UI.
# FIX: the stylesheet must be passed to the gr.Blocks constructor —
# Blocks.launch() has no `css` parameter, so the previous launch(css=css)
# meant the #col-container / #project-links styling was never applied.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Static header: title, paper/project badges, duplicate-space badge.
        gr.HTML("""
        <h2 style="text-align: center;">
            PASD Magnify
        </h2>
        <p style="text-align: center;">
            Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
        </p>
        <p id="project-links" align="center">
            <a href="https://github.com/yangxy/PASD"><img src="https://img.shields.io/badge/Project-Page-Green"></a>
            <a href="https://huggingface.co/papers/2308.14469"><img src="https://img.shields.io/badge/Paper-Arxiv-red"></a>
        </p>
        <p style="margin:12px auto;display: flex;justify-content: center;">
            <a href="https://huggingface.co/spaces/fffiloni/PASD?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space">
            </a>
        </p>
        """)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    type="filepath",
                    sources=["upload"],
                    value="PASD/samples/frog.png",
                    label="Input image",
                )
                prompt_in = gr.Textbox(label="Prompt", value="Frog")
                with gr.Accordion(label="Advanced settings", open=False):
                    added_prompt = gr.Textbox(
                        label="Added Prompt",
                        value="clean, high-resolution, 8k, best quality, masterpiece",
                    )
                    neg_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                    )
                    denoise_steps = gr.Slider(
                        label="Denoise Steps",
                        minimum=10,
                        maximum=50,
                        value=20,
                        step=1,
                    )
                    upsample_scale = gr.Slider(
                        label="Upsample Scale",
                        minimum=1,
                        maximum=4,
                        value=2,
                        step=1,
                    )
                    condition_scale = gr.Slider(
                        label="Conditioning Scale",
                        minimum=0.5,
                        maximum=1.5,
                        value=1.1,
                        step=0.1,
                    )
                    classifier_free_guidance = gr.Slider(
                        label="Classifier-free Guidance",
                        minimum=0.1,
                        maximum=10.0,
                        value=7.5,
                        step=0.1,
                    )
                    # -1 is the "random" sentinel handled inside inference().
                    seed = gr.Slider(
                        label="Seed",
                        minimum=-1,
                        maximum=2147483647,
                        step=1,
                        randomize=True,
                    )
                submit_btn = gr.Button("Submit")
            with gr.Column():
                before_img = gr.Image(label="Input")
                after_img = gr.Image(label="Result")
                file_output = gr.File(label="Downloadable image result")
    # Wire the button to inference(); outputs mirror its 3-tuple return.
    submit_btn.click(
        fn=inference,
        inputs=[
            input_image,
            prompt_in,
            added_prompt,
            neg_prompt,
            denoise_steps,
            upsample_scale,
            condition_scale,
            classifier_free_guidance,
            seed,
        ],
        outputs=[
            before_img,
            after_img,
            file_output,
        ],
        api_visibility="private",
    )
demo.queue(max_size=10).launch(
    ssr_mode=False,
    mcp_server=False,
)