# PASD / app_zero.py — fffiloni's Space ("track tqdm", revision 01fd7bd)
import sys
import types
import datetime
import re
from pathlib import Path
import huggingface_hub
# -------------------------------------------------------------------
# Compatibility shim: older diffusers may still expect cached_download
# -------------------------------------------------------------------
if not hasattr(huggingface_hub, "cached_download"):
    # Re-expose the removed helper as a thin alias of hf_hub_download.
    def _cached_download(*args, **kwargs):
        return huggingface_hub.hf_hub_download(*args, **kwargs)

    huggingface_hub.cached_download = _cached_download
import torch
import numpy as np
import einops
import spaces
import gradio as gr
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
from pytorch_lightning import seed_everything
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
from diffusers import (
AutoencoderKL,
DDIMScheduler,
PNDMScheduler,
DPMSolverMultistepScheduler,
UniPCMultistepScheduler,
)
# -------------------------------------------------------------------
# GPU spoof for Spaces env compatibility
# -------------------------------------------------------------------
# Hardware probes must succeed at import time even without a visible GPU
# (ZeroGPU Spaces), so report a fixed NVIDIA A10G profile.
torch.cuda.get_device_capability = lambda *_args, **_kwargs: (8, 6)


def _fake_device_properties(*_args, **_kwargs):
    # Fresh namespace per call, matching the original lambda's behavior.
    return types.SimpleNamespace(
        name="NVIDIA A10G",
        major=8,
        minor=6,
        total_memory=23836033024,
        multi_processor_count=80,
    )


torch.cuda.get_device_properties = _fake_device_properties
# -------------------------------------------------------------------
# Download required assets
# -------------------------------------------------------------------
# Pretrained PASD checkpoints (all four variants).
huggingface_hub.snapshot_download(
    repo_id="camenduru/PASD",
    allow_patterns=["pasd/**", "pasd_light/**", "pasd_light_rrdb/**", "pasd_rrdb/**"],
    local_dir="PASD/runs",
)
# Personalized SD 1.5 checkpoint merged into the pipeline at startup.
huggingface_hub.hf_hub_download(
    repo_id="camenduru/PASD",
    filename="majicmixRealistic_v6.safetensors",
    local_dir="PASD/checkpoints/personalized_models",
)
# RetinaFace detector weights.
huggingface_hub.hf_hub_download(
    repo_id="akhaliq/RetinaFace-R50",
    filename="RetinaFace-R50.pth",
    local_dir="PASD/annotator/ckpts",
)
# -------------------------------------------------------------------
# PASD local path
# -------------------------------------------------------------------
# Make the vendored PASD sources importable as top-level packages.
sys.path.append("./PASD")
# -------------------------------------------------------------------
# Runtime patching helpers
# -------------------------------------------------------------------
def patch_file(path_str: str, replacements: list[tuple[str, str]]) -> None:
    """Apply literal (old, new) text replacements to a file in place.

    Missing files and read/write failures are reported to stdout and
    otherwise ignored so start-up can continue on a best-effort basis.
    """
    path = Path(path_str)
    if not path.exists():
        print(f"[patch] file not found: {path}")
        return
    try:
        text = path.read_text(encoding="utf-8")
    except Exception as e:
        print(f"[patch] failed reading {path}: {e}")
        return
    updated = text
    for old, new in replacements:
        updated = updated.replace(old, new)
    if updated == text:
        print(f"[patch] no changes: {path}")
        return
    try:
        path.write_text(updated, encoding="utf-8")
        print(f"[patch] updated: {path}")
    except Exception as e:
        print(f"[patch] failed writing {path}: {e}")
def patch_controlnet_loader_import(path_str: str) -> None:
    """Normalize the FromOriginalControlnetMixin import in a vendored file.

    diffusers renamed/moved this mixin across releases; this rewrites the
    target file (PASD's controlnet.py) to use one defensive try/except
    import block so it loads against any installed diffusers version.
    Failures are printed and otherwise ignored.
    """
    path = Path(path_str)
    if not path.exists():
        print(f"[patch] file not found: {path}")
        return
    try:
        text = path.read_text(encoding="utf-8")
    except Exception as e:
        print(f"[patch] failed reading {path}: {e}")
        return
    # Fallback block written into the patched file: try the new name first,
    # then the old one, else define a no-op stand-in class.
    # NOTE(review): inner indentation reconstructed — it must be valid
    # top-level Python once inserted into controlnet.py; confirm on disk.
    safe_block = """try:
    from diffusers.loaders import FromOriginalControlNetMixin as FromOriginalControlnetMixin
except Exception:
    try:
        from diffusers.loaders import FromOriginalControlnetMixin
    except Exception:
        class FromOriginalControlnetMixin:
            pass
"""
    original = text
    # Remove old single-line imports of the mixin.
    text = re.sub(
        r"(?m)^from diffusers\.loaders[^\n]*FromOriginalControl\w*Mixin[^\n]*\n",
        "",
        text,
    )
    text = re.sub(
        r"(?m)^from diffusers\.loaders\.single_file_model[^\n]*FromOriginal\w+[^\n]*\n",
        "",
        text,
    )
    # Remove old broken try/except blocks related to this mixin — the lambda
    # only deletes a matched block if it actually mentions the mixin.
    text = re.sub(
        r"(?ms)^try:\n(?:(?: |\t).*\n)+?except Exception:\n(?:(?: |\t).*\n)+?(?=^(?:class|def|@|from |import |\Z))",
        lambda m: "" if "FromOriginalControl" in m.group(0) else m.group(0),
        text,
    )
    # Normalize the mixin reference in the rest of the file.
    text = text.replace("FromOriginalControlNetMixin", "FromOriginalControlnetMixin")
    marker = "class ControlNetConditioningEmbedding"
    # Insert the safe import block just before the first class definition,
    # or at the top of the file when the marker is absent.
    if safe_block not in text:
        idx = text.find(marker)
        if idx != -1:
            text = text[:idx] + safe_block + text[idx:]
        else:
            text = safe_block + text
    if text != original:
        try:
            path.write_text(text, encoding="utf-8")
            print(f"[patch] normalized: {path}")
        except Exception as e:
            print(f"[patch] failed writing {path}: {e}")
    else:
        print(f"[patch] no changes: {path}")
def patch_pasd_for_diffusers() -> None:
    """Rewrite vendored PASD sources to match current diffusers layouts."""
    # Literal replacements per file; dict insertion order fixes patch order.
    replacements_by_file = {
        # pipeline_utils path moved
        "./PASD/pipelines/pipeline_pasd.py": [
            (
                "from diffusers.pipeline_utils import DiffusionPipeline",
                "from diffusers import DiffusionPipeline",
            ),
        ],
        # PositionNet -> GLIGENTextBoundingboxProjection alias
        "./PASD/models/pasd/unet_2d_condition.py": [
            (" PositionNet,\n", ""),
            (
                " GLIGENTextBoundingboxProjection,\n",
                " GLIGENTextBoundingboxProjection as PositionNet,\n",
            ),
        ],
        # internal module paths moved in newer diffusers
        "./PASD/models/pasd/unet_2d_blocks.py": [
            (
                "from diffusers.models.attention import AdaGroupNorm",
                "from diffusers.models.normalization import AdaGroupNorm",
            ),
            (
                "from diffusers.models.dual_transformer_2d import DualTransformer2DModel",
                "from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel",
            ),
            (
                "from diffusers.models.transformer_2d import Transformer2DModel",
                "from diffusers.models.transformers.transformer_2d import Transformer2DModel",
            ),
        ],
    }
    for file_path, pairs in replacements_by_file.items():
        patch_file(file_path, pairs)
    # robust controlnet patch
    patch_controlnet_loader_import("./PASD/models/pasd/controlnet.py")


patch_pasd_for_diffusers()
# -------------------------------------------------------------------
# Import PASD modules only after patching
# -------------------------------------------------------------------
from pipelines.pipeline_pasd import StableDiffusionControlNetPipeline
from myutils.misc import load_dreambooth_lora
from myutils.wavelet_color_fix import wavelet_color_fix
from annotator.retinaface import RetinaFaceDetection
# Selects which vendored UNet/ControlNet implementation is imported below.
use_pasd_light = False
# Face detector built from the RetinaFace checkpoint downloaded above.
# NOTE(review): not referenced anywhere else in this file — presumably kept
# for parity with upstream PASD; confirm before removing.
face_detector = RetinaFaceDetection()
if use_pasd_light:
    from models.pasd_light.unet_2d_condition import UNet2DConditionModel
    from models.pasd_light.controlnet import ControlNetModel
else:
    from models.pasd.unet_2d_condition import UNet2DConditionModel
    from models.pasd.controlnet import ControlNetModel
# -------------------------------------------------------------------
# Model setup
# -------------------------------------------------------------------
# SD 1.5 base components come from the hub; PASD UNet/ControlNet come from
# the checkpoint snapshot downloaded above.
pretrained_model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
ckpt_path = "PASD/runs/pasd/checkpoint-100000"
dreambooth_lora_path = "PASD/checkpoints/personalized_models/majicmixRealistic_v6.safetensors"
weight_dtype = torch.float16  # half precision for inference
device = "cuda"
# Diffusion scheduler (UniPC multistep sampler).
scheduler = UniPCMultistepScheduler.from_pretrained(
    pretrained_model_path, subfolder="scheduler"
)
# CLIP text encoder + tokenizer from the SD 1.5 base model.
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_path, subfolder="text_encoder"
)
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_path, subfolder="tokenizer"
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_path, subfolder="vae"
)
feature_extractor = CLIPImageProcessor.from_pretrained(
    pretrained_model_path, subfolder="feature_extractor"
)
# PASD-finetuned UNet and ControlNet from the downloaded checkpoint.
unet = UNet2DConditionModel.from_pretrained(
    ckpt_path, subfolder="unet"
)
controlnet = ControlNetModel.from_pretrained(
    ckpt_path, subfolder="controlnet"
)
# Inference only: freeze every sub-model.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
controlnet.requires_grad_(False)
# Merge the personalized DreamBooth/LoRA weights into the frozen models.
unet, vae, text_encoder = load_dreambooth_lora(
    unet, vae, text_encoder, dreambooth_lora_path
)
text_encoder.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)
controlnet.to(device, dtype=weight_dtype)
# Assemble the ControlNet-conditioned SD pipeline; safety checker disabled.
validation_pipeline = StableDiffusionControlNetPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    unet=unet,
    controlnet=controlnet,
    scheduler=scheduler,
    safety_checker=None,
    requires_safety_checker=False,
)
# Tiled VAE decoding to bound peak VRAM on large outputs.
validation_pipeline._init_tiled_vae(decoder_tile_size=224)
# -------------------------------------------------------------------
# ResNet helper
# -------------------------------------------------------------------
# ImageNet classifier used to auto-augment the prompt with a detected
# category name during inference.
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
resnet = resnet50(weights=weights).eval()  # .eval() returns the module
def resize_image(image_path: str, target_height: int) -> Image.Image:
    """Load an image and scale it to target_height, preserving aspect ratio."""
    with Image.open(image_path) as img:
        width, height = img.size
        ratio = target_height / float(height)
        new_width = int(float(width) * ratio)
        return img.resize((new_width, target_height), Image.LANCZOS)
@spaces.GPU(enable_queue=True)
def inference(
    input_image,
    prompt,
    a_prompt,
    n_prompt,
    denoise_steps,
    upscale,
    alpha,
    cfg,
    seed,
    progress=gr.Progress(track_tqdm=True)
):
    """Super-resolve one image with the PASD ControlNet pipeline.

    Args:
        input_image: filesystem path to the uploaded image (the UI uses
            gr.Image(type="filepath")).
        prompt: user prompt; auto-augmented with a ResNet-50 category guess.
        a_prompt: positive quality prompt appended to the final prompt.
        n_prompt: negative prompt.
        denoise_steps: number of diffusion inference steps.
        upscale: integer upsampling factor.
        alpha: ControlNet conditioning scale.
        cfg: classifier-free guidance scale.
        seed: RNG seed; -1 is coerced to 0.
        progress: Gradio progress tracker (track_tqdm mirrors pipeline tqdm).

    Returns:
        (input_path, result_path, result_path): saved JPEG paths for the
        resized input, the result image, and the downloadable file output.
    """
    if seed == -1:
        seed = 0
    # Normalize the working size: height fixed at 512, width keeps aspect.
    input_image = resize_image(input_image, 512)
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    with torch.no_grad():
        seed_everything(seed)
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)
        input_image = input_image.convert("RGB")
        # Classify the input with ResNet-50 and, when confident enough,
        # append the predicted ImageNet category to the prompt.
        batch = preprocess(input_image).unsqueeze(0)
        prediction = resnet(batch).squeeze(0).softmax(0)
        class_id = prediction.argmax().item()
        score = prediction[class_id].item()
        category_name = weights.meta["categories"][class_id]
        if score >= 0.1:
            prompt += f"{category_name}" if prompt == "" else f", {category_name}"
        prompt = a_prompt if prompt == "" else f"{prompt}, {a_prompt}"
        ori_width, ori_height = input_image.size
        rscale = upscale
        # Upscale, then snap both dimensions down to multiples of 8
        # (presumably required by the 8x latent downsampling — confirm).
        input_image = input_image.resize(
            (input_image.size[0] * rscale, input_image.size[1] * rscale)
        )
        input_image = input_image.resize(
            (input_image.size[0] // 8 * 8, input_image.size[1] // 8 * 8)
        )
        width, height = input_image.size
        try:
            image = validation_pipeline(
                None,
                prompt,
                input_image,
                num_inference_steps=denoise_steps,
                generator=generator,
                height=height,
                width=width,
                guidance_scale=cfg,
                negative_prompt=n_prompt,
                conditioning_scale=alpha,
                eta=0.0,
            ).images[0]
            # Transfer low-frequency color statistics from the input to the
            # result, then restore the exact target size.
            image = wavelet_color_fix(image, input_image)
            image = image.resize((ori_width * rscale, ori_height * rscale))
        except Exception as e:
            # Best effort: log the error and return a blank placeholder
            # instead of crashing the Space.
            print(f"[inference] error: {e}")
            image = Image.new(mode="RGB", size=(512, 512))
    result_path = f"result_{timestamp}.jpg"
    input_path = f"input_{timestamp}.jpg"
    image.save(result_path, "JPEG")
    input_image.save(input_path, "JPEG")
    return input_path, result_path, result_path
css = """
#col-container{
margin: 0 auto;
max-width: 720px;
}
#project-links{
margin: 0 0 12px !important;
column-gap: 8px;
display: flex;
justify-content: center;
flex-wrap: nowrap;
flex-direction: row;
align-items: center;
}
"""
# FIX: the stylesheet must be supplied at Blocks construction time —
# previously `css` was never passed to gr.Blocks(), so it had no effect.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Static header: title, project/paper badges, duplicate-space badge.
        gr.HTML("""
        <h2 style="text-align: center;">
            PASD Magnify
        </h2>
        <p style="text-align: center;">
            Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
        </p>
        <p id="project-links" align="center">
            <a href="https://github.com/yangxy/PASD"><img src="https://img.shields.io/badge/Project-Page-Green"></a>
            <a href="https://huggingface.co/papers/2308.14469"><img src="https://img.shields.io/badge/Paper-Arxiv-red"></a>
        </p>
        <p style="margin:12px auto;display: flex;justify-content: center;">
            <a href="https://huggingface.co/spaces/fffiloni/PASD?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space">
            </a>
        </p>
        """)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    type="filepath",
                    sources=["upload"],
                    value="PASD/samples/frog.png",
                    label="Input image",
                )
                prompt_in = gr.Textbox(label="Prompt", value="Frog")
                with gr.Accordion(label="Advanced settings", open=False):
                    added_prompt = gr.Textbox(
                        label="Added Prompt",
                        value="clean, high-resolution, 8k, best quality, masterpiece",
                    )
                    neg_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                    )
                    denoise_steps = gr.Slider(
                        label="Denoise Steps",
                        minimum=10,
                        maximum=50,
                        value=20,
                        step=1,
                    )
                    upsample_scale = gr.Slider(
                        label="Upsample Scale",
                        minimum=1,
                        maximum=4,
                        value=2,
                        step=1,
                    )
                    condition_scale = gr.Slider(
                        label="Conditioning Scale",
                        minimum=0.5,
                        maximum=1.5,
                        value=1.1,
                        step=0.1,
                    )
                    classifier_free_guidance = gr.Slider(
                        label="Classifier-free Guidance",
                        minimum=0.1,
                        maximum=10.0,
                        value=7.5,
                        step=0.1,
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=-1,
                        maximum=2147483647,
                        step=1,
                        randomize=True,
                    )
                submit_btn = gr.Button("Submit")
            with gr.Column():
                before_img = gr.Image(label="Input")
                after_img = gr.Image(label="Result")
                file_output = gr.File(label="Downloadable image result")
    # Wire the button to inference; endpoint kept out of the public API.
    submit_btn.click(
        fn=inference,
        inputs=[
            input_image,
            prompt_in,
            added_prompt,
            neg_prompt,
            denoise_steps,
            upsample_scale,
            condition_scale,
            classifier_free_guidance,
            seed,
        ],
        outputs=[
            before_img,
            after_img,
            file_output,
        ],
        api_visibility="private",
    )
# FIX: Blocks.launch() accepts no `css` keyword — passing css=css here
# raises a TypeError at startup. Styling belongs on gr.Blocks(css=...),
# so the invalid kwarg is dropped from the launch call.
demo.queue(max_size=10).launch(
    ssr_mode=False,
    mcp_server=False,
)