# MeshForge / external/MV-Adapter/scripts/inference_scribble2mv_sdxl.py
# Daankular's picture
# Upload external/MV-Adapter/scripts/inference_scribble2mv_sdxl.py with huggingface_hub
# e8c88d7 verified
import argparse
import random
import cv2
import numpy as np
import torch
from controlnet_aux import HEDdetector, PidiNetDetector
from diffusers import (
AutoencoderKL,
ControlNetModel,
DDPMScheduler,
LCMScheduler,
UNet2DConditionModel,
)
from PIL import Image
from mvadapter.pipelines.pipeline_mvadapter_t2mv_sdxl import MVAdapterT2MVSDXLPipeline
from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
from mvadapter.utils.mesh_utils import get_orthogonal_camera
from mvadapter.utils.geometry import get_plucker_embeds_from_cameras_ortho
from mvadapter.utils import make_image_grid
def prepare_pipeline(
    base_model,
    vae_model,
    unet_model,
    lora_model,
    adapter_path,
    scheduler,
    num_views,
    device,
    dtype,
):
    """Build the multi-view SDXL pipeline with the scribble ControlNet attached.

    Args:
        base_model: Identifier passed to ``MVAdapterT2MVSDXLPipeline.from_pretrained``.
        vae_model: Optional VAE identifier; ``None`` keeps the base model's VAE.
        unet_model: Optional UNet identifier; ``None`` keeps the base model's UNet.
        lora_model: Optional ``"<repo_or_dir>/<weight_name>"`` string. It is split
            on the last ``/``; ``None`` skips LoRA loading.
        adapter_path: Location of the MV-Adapter weights
            (``mvadapter_t2mv_sdxl.safetensors`` is loaded from it).
        scheduler: ``"ddpm"`` or ``"lcm"`` to pick a scheduler class; any other
            value (including ``None``) passes ``scheduler_class=None``.
        num_views: Number of views the custom adapter is initialised for.
        device: Device the pipeline and its extra modules are moved to.
        dtype: Dtype the pipeline and its extra modules are cast to.

    Returns:
        The configured pipeline, with VAE slicing enabled.
    """
    # Optional component overrides forwarded to ``from_pretrained``.
    overrides = {}
    if vae_model is not None:
        overrides["vae"] = AutoencoderKL.from_pretrained(vae_model)
    if unet_model is not None:
        overrides["unet"] = UNet2DConditionModel.from_pretrained(unet_model)

    pipe: MVAdapterT2MVSDXLPipeline = MVAdapterT2MVSDXLPipeline.from_pretrained(
        base_model, **overrides
    )

    # Unknown scheduler names fall through to ``None``.
    scheduler_classes = {"ddpm": DDPMScheduler, "lcm": LCMScheduler}
    pipe.scheduler = ShiftSNRScheduler.from_scheduler(
        pipe.scheduler,
        shift_mode="interpolated",
        shift_scale=8.0,
        scheduler_class=scheduler_classes.get(scheduler),
    )

    # Multi-view adapter weights.
    pipe.init_custom_adapter(num_views=num_views)
    pipe.load_custom_adapter(
        adapter_path, weight_name="mvadapter_t2mv_sdxl.safetensors"
    )

    # Scribble ControlNet.
    pipe.controlnet = ControlNetModel.from_pretrained(
        "xinsir/controlnet-scribble-sdxl-1.0"
    )

    # Move the pipeline first, then the two modules that are moved explicitly.
    pipe.to(device=device, dtype=dtype)
    for extra_module in (pipe.cond_encoder, pipe.controlnet):
        extra_module.to(device=device, dtype=dtype)

    # Optional LoRA: everything before the last "/" is the location, the rest
    # is the weight file name.
    if lora_model is not None:
        lora_location, lora_weight_name = lora_model.rsplit("/", 1)
        pipe.load_lora_weights(lora_location, weight_name=lora_weight_name)

    # VAE slicing for lower memory usage.
    pipe.enable_vae_slicing()
    return pipe
def nms(x, t, s):
    """Keep directional local maxima of a blurred map and binarise them.

    ``x`` is converted to float32 and Gaussian-blurred with sigma ``s``. A pixel
    survives when it equals its own dilation under at least one of four
    3x3 line-shaped kernels (horizontal, vertical, and the two diagonals).
    Surviving pixels whose blurred value exceeds ``t`` become 255; everything
    else is 0.

    Args:
        x: Input array (anything with ``astype``).
        t: Threshold applied to the surviving blurred values.
        s: Gaussian sigma used for the initial blur.

    Returns:
        A ``uint8`` array with the blurred map's shape, holding only 0 and 255.
    """
    blurred = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    line_kernels = (
        [[0, 0, 0], [1, 1, 1], [0, 0, 0]],  # horizontal
        [[0, 1, 0], [0, 1, 0], [0, 1, 0]],  # vertical
        [[1, 0, 0], [0, 1, 0], [0, 0, 1]],  # main diagonal
        [[0, 0, 1], [0, 1, 0], [1, 0, 0]],  # anti-diagonal
    )

    maxima = np.zeros_like(blurred)
    for kernel_rows in line_kernels:
        kernel = np.array(kernel_rows, dtype=np.uint8)
        # Dilation returns the maximum under the kernel, so equality marks a
        # pixel that is the largest along that line.
        is_peak = cv2.dilate(blurred, kernel=kernel) == blurred
        np.putmask(maxima, is_peak, blurred)

    return np.where(maxima > t, 255, 0).astype(np.uint8)
# HED detectors keyed by repository id, so the weights are loaded once per
# process instead of once per control image. Cached detectors stay alive for
# the lifetime of the process.
_HED_DETECTORS = {}


def _get_hed_detector(repo_id="lllyasviel/Annotators"):
    """Return the cached HED detector for ``repo_id``, loading it on first use."""
    detector = _HED_DETECTORS.get(repo_id)
    if detector is None:
        detector = HEDdetector.from_pretrained(repo_id)
        _HED_DETECTORS[repo_id] = detector
    return detector


def preprocess_controlnet_image(image_path, height, width):
    """Turn an image file into a binary, scribble-like control image.

    Steps: load and resize the image, flatten transparency onto mid-gray, run
    the HED detector, then thin (``nms``), blur and threshold the result so it
    resembles a hand-drawn sketch.

    The threshold is drawn from the module-level ``random`` generator, so the
    line width varies between calls unless the caller seeds ``random``.

    Args:
        image_path: Path of the image to read.
        height: Target height passed to ``cv2.resize``.
        width: Target width passed to ``cv2.resize``.

    Returns:
        A ``PIL.Image`` whose pixels are either 0 or 255.

    Raises:
        OSError: If ``cv2.imread`` cannot read ``image_path`` (missing file,
            unreadable file or unsupported format).
    """
    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Could not read control image: {image_path!r}")
    image = cv2.resize(image, (width, height))

    if image.ndim == 2:
        # IMREAD_UNCHANGED keeps single-channel files 2-D; expand them so the
        # rest of the function always sees a channel axis.
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:
        # Alpha-composite onto a 50% gray background.
        alpha_channel = image[:, :, 3] / 255.0
        rgb_channels = image[:, :, :3] / 255.0
        gray_background = np.ones_like(rgb_channels) * 0.5
        image = (
            alpha_channel[..., None] * rgb_channels
            + (1 - alpha_channel[..., None]) * gray_background
        )
        image = (image * 255).astype(np.uint8)

    # NOTE(review): cv2.imread yields BGR channel order and no conversion is
    # done before the detector runs -- confirm whether the detector expects RGB.
    processor = _get_hed_detector()
    image = processor(image, scribble=False)

    # The following steps simulate a human sketch; different thresholds give
    # different line widths.
    image = np.array(image)
    image = nms(image, 127, 3)
    image = cv2.GaussianBlur(image, (0, 0), 3)

    # Higher threshold, thinner line.
    random_val = int(round(random.uniform(0.01, 0.10), 2) * 255)
    image[image > random_val] = 255
    image[image < 255] = 0
    return Image.fromarray(image)
def _preprocess_controlnet_images(paths, height, width, seed):
    """Preprocess every control image, repeatably when ``seed`` is not -1.

    ``preprocess_controlnet_image`` draws a threshold from the ``random``
    module. For a real seed the module is seeded for the duration of the
    preprocessing and its previous state is restored afterwards, so callers'
    use of ``random`` is not disturbed.
    """
    if seed == -1:
        return [preprocess_controlnet_image(path, height, width) for path in paths]

    saved_state = random.getstate()
    random.seed(seed)
    try:
        return [preprocess_controlnet_image(path, height, width) for path in paths]
    finally:
        random.setstate(saved_state)


def run_pipeline(
    pipe,
    num_views,
    text,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    controlnet_images,
    controlnet_conditioning_scale,
    lora_scale=1.0,
    device="cuda",
):
    """Generate multi-view images from a prompt and scribble control images.

    Args:
        pipe: Pipeline returned by ``prepare_pipeline``.
        num_views: Number of views to generate. Must equal the number of
            built-in camera azimuths (6).
        text: Prompt passed to the pipeline.
        height: Output height; also used when resizing the control images.
        width: Output width; also used when resizing the control images and
            when building the camera embeddings.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for the torch generator and for the random scribble
            threshold; ``-1`` leaves both unseeded.
        controlnet_images: Paths of the images to convert into scribbles.
        controlnet_conditioning_scale: ControlNet conditioning scale.
        lora_scale: Passed as ``cross_attention_kwargs={"scale": ...}``.
        device: Device for the cameras and the torch generator.

    Returns:
        ``(images, controlnet_image)``: the pipeline's ``.images`` and the
        list of preprocessed control images.

    Raises:
        ValueError: If ``num_views`` differs from the number of built-in
            camera azimuths.
    """
    # Fixed camera ring, shifted by -90 degrees.
    azimuth_deg = [x - 90 for x in [0, 45, 90, 180, 270, 315]]
    if num_views != len(azimuth_deg):
        # The other per-view lists below are sized by num_views; fail early
        # rather than hand mismatched lists to the camera helper.
        raise ValueError(
            f"num_views must be {len(azimuth_deg)} to match the built-in "
            f"camera azimuths, got {num_views}"
        )

    # Prepare cameras
    cameras = get_orthogonal_camera(
        elevation_deg=[0] * num_views,
        distance=[1.8] * num_views,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=azimuth_deg,
        device=device,
    )
    plucker_embeds = get_plucker_embeds_from_cameras_ortho(
        cameras.c2w, [1.1] * num_views, width
    )
    # Rescale with (x + 1) / 2 and clamp to [0, 1]; presumably the embeddings
    # are in [-1, 1] -- TODO confirm.
    control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)

    pipe_kwargs = {}
    if seed != -1:
        pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)

    # Prepare controlnet images
    controlnet_image = _preprocess_controlnet_images(
        controlnet_images, height, width, seed
    )
    pipe_kwargs.update(
        {
            "controlnet_image": controlnet_image,
            "controlnet_conditioning_scale": controlnet_conditioning_scale,
        }
    )

    images = pipe(
        text,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_views,
        control_image=control_images,
        control_conditioning_scale=1.0,
        negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
        cross_attention_kwargs={"scale": lora_scale},
        **pipe_kwargs,
    ).images
    return images, controlnet_image
def controlnet_output_path(output):
    """Return the path for the control-image grid saved next to ``output``.

    The file extension of ``output`` (if any) is replaced by
    ``_controlnet.png``. Only the final path component is inspected for an
    extension, so dots in directory names (``./out/result``) are left alone.
    """
    import os

    stem, _extension = os.path.splitext(output)
    return stem + "_controlnet.png"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Models
    parser.add_argument(
        "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
    )
    parser.add_argument(
        "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix"
    )
    parser.add_argument("--unet_model", type=str, default=None)
    parser.add_argument(
        "--scheduler",
        type=str,
        default=None,
        help='"ddpm" or "lcm"; any other value selects no scheduler class',
    )
    parser.add_argument(
        "--lora_model",
        type=str,
        default=None,
        help="LoRA given as <repo_or_dir>/<weight_name>",
    )
    parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
    parser.add_argument("--num_views", type=int, default=6)
    # Device
    parser.add_argument("--device", type=str, default="cuda")
    # Inference
    parser.add_argument("--text", type=str, required=True, help="text prompt")
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=7.0)
    parser.add_argument(
        "--seed", type=int, default=-1, help="-1 leaves the generator unseeded"
    )
    parser.add_argument("--lora_scale", type=float, default=1.0)
    parser.add_argument(
        "--output",
        type=str,
        default="output.png",
        help="path of the generated image grid",
    )
    parser.add_argument(
        "--controlnet_images",
        type=str,
        nargs="+",
        required=True,
        help="one or more image paths to turn into scribble control images",
    )
    parser.add_argument("--controlnet_conditioning_scale", type=float, default=1.0)
    args = parser.parse_args()

    pipe = prepare_pipeline(
        base_model=args.base_model,
        vae_model=args.vae_model,
        unet_model=args.unet_model,
        lora_model=args.lora_model,
        adapter_path=args.adapter_path,
        scheduler=args.scheduler,
        num_views=args.num_views,
        device=args.device,
        dtype=torch.float16,
    )

    # Resolution is fixed at 768x768 for this script.
    images, controlnet_images = run_pipeline(
        pipe,
        num_views=args.num_views,
        text=args.text,
        height=768,
        width=768,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        controlnet_images=args.controlnet_images,
        controlnet_conditioning_scale=args.controlnet_conditioning_scale,
        lora_scale=args.lora_scale,
        device=args.device,
    )

    # Save the generated views and, alongside them, the control images used.
    make_image_grid(images, rows=1).save(args.output)
    make_image_grid(controlnet_images, rows=1).save(
        controlnet_output_path(args.output)
    )