import argparse import numpy as np import torch from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel from PIL import Image from torchvision import transforms from transformers import AutoModelForImageSegmentation from mvadapter.pipelines.pipeline_mvadapter_i2mv_sd import MVAdapterI2MVSDPipeline from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler from mvadapter.utils.mesh_utils import get_orthogonal_camera from mvadapter.utils.geometry import get_plucker_embeds_from_cameras_ortho from mvadapter.utils import make_image_grid def prepare_pipeline( base_model, vae_model, unet_model, lora_model, adapter_path, scheduler, num_views, device, dtype, ): # Load vae and unet if provided pipe_kwargs = {} if vae_model is not None: pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model) if unet_model is not None: pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model) # Prepare pipeline pipe: MVAdapterI2MVSDPipeline pipe = MVAdapterI2MVSDPipeline.from_pretrained(base_model, **pipe_kwargs) # Load scheduler if provided scheduler_class = None if scheduler == "ddpm": scheduler_class = DDPMScheduler elif scheduler == "lcm": scheduler_class = LCMScheduler pipe.scheduler = ShiftSNRScheduler.from_scheduler( pipe.scheduler, shift_mode="interpolated", shift_scale=8.0, scheduler_class=scheduler_class, ) pipe.init_custom_adapter(num_views=num_views) pipe.load_custom_adapter( adapter_path, weight_name="mvadapter_i2mv_sd21.safetensors" ) pipe.to(device=device, dtype=dtype) pipe.cond_encoder.to(device=device, dtype=dtype) # load lora if provided if lora_model is not None: model_, name_ = lora_model.rsplit("/", 1) pipe.load_lora_weights(model_, weight_name=name_) # vae slicing for lower memory usage pipe.enable_vae_slicing() return pipe def remove_bg(image, net, transform, device): image_size = image.size input_images = transform(image).unsqueeze(0).to(device) with torch.no_grad(): preds = net(input_images)[-1].sigmoid().cpu() pred = preds[0].squeeze() pred_pil = transforms.ToPILImage()(pred) mask = pred_pil.resize(image_size) image.putalpha(mask) return image def preprocess_image(image: Image.Image, height, width): image = np.array(image) alpha = image[..., 3] > 0 H, W = alpha.shape # get the bounding box of alpha y, x = np.where(alpha) y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H) x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W) image_center = image[y0:y1, x0:x1] # resize the longer side to H * 0.9 H, W, _ = image_center.shape if H > W: W = int(W * (height * 0.9) / H) H = int(height * 0.9) else: H = int(H * (width * 0.9) / W) W = int(width * 0.9) image_center = np.array(Image.fromarray(image_center).resize((W, H))) # pad to H, W start_h = (height - H) // 2 start_w = (width - W) // 2 image = np.zeros((height, width, 4), dtype=np.uint8) image[start_h : start_h + H, start_w : start_w + W] = image_center image = image.astype(np.float32) / 255.0 image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5 image = (image * 255).clip(0, 255).astype(np.uint8) image = Image.fromarray(image) return image def run_pipeline( pipe, num_views, text, image, height, width, num_inference_steps, guidance_scale, seed, remove_bg_fn=None, reference_conditioning_scale=1.0, negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast", lora_scale=1.0, device="cuda", ): # Prepare cameras cameras = get_orthogonal_camera( elevation_deg=[0, 0, 0, 0, 0, 0], distance=[1.8] * num_views, left=-0.55, right=0.55, bottom=-0.55, top=0.55, azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 270, 315]], device=device, ) plucker_embeds = get_plucker_embeds_from_cameras_ortho( cameras.c2w, [1.1] * num_views, width ) control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1) # Prepare image reference_image = Image.open(image) if isinstance(image, str) else image if remove_bg_fn is not None: reference_image = remove_bg_fn(reference_image) reference_image = preprocess_image(reference_image, height, width) elif reference_image.mode == "RGBA": reference_image = preprocess_image(reference_image, height, width) pipe_kwargs = {} if seed != -1 and isinstance(seed, int): pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed) images = pipe( text, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, num_images_per_prompt=num_views, control_image=control_images, control_conditioning_scale=1.0, reference_image=reference_image, reference_conditioning_scale=reference_conditioning_scale, negative_prompt=negative_prompt, cross_attention_kwargs={"scale": lora_scale}, **pipe_kwargs, ).images return images, reference_image if __name__ == "__main__": parser = argparse.ArgumentParser() # Models parser.add_argument( "--base_model", type=str, default="stabilityai/stable-diffusion-2-1-base" ) parser.add_argument("--vae_model", type=str, default=None) parser.add_argument("--unet_model", type=str, default=None) parser.add_argument("--scheduler", type=str, default=None) parser.add_argument("--lora_model", type=str, default=None) parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter") parser.add_argument("--num_views", type=int, default=6) # Device parser.add_argument("--device", type=str, default="cuda") # Inference parser.add_argument("--image", type=str, required=True) parser.add_argument("--text", type=str, default="high quality") parser.add_argument("--num_inference_steps", type=int, default=50) parser.add_argument("--guidance_scale", type=float, default=3.0) parser.add_argument("--seed", type=int, default=-1) parser.add_argument("--lora_scale", type=float, default=1.0) parser.add_argument("--reference_conditioning_scale", type=float, default=1.0) parser.add_argument( "--negative_prompt", type=str, default="watermark, ugly, deformed, noisy, blurry, low contrast", ) parser.add_argument("--output", type=str, default="output.png") # Extra parser.add_argument("--remove_bg", action="store_true", help="Remove background") args = parser.parse_args() pipe = prepare_pipeline( base_model=args.base_model, vae_model=args.vae_model, unet_model=args.unet_model, lora_model=args.lora_model, adapter_path=args.adapter_path, scheduler=args.scheduler, num_views=args.num_views, device=args.device, dtype=torch.float16, ) if args.remove_bg: birefnet = AutoModelForImageSegmentation.from_pretrained( "ZhengPeng7/BiRefNet", trust_remote_code=True ) birefnet.to(args.device) transform_image = transforms.Compose( [ transforms.Resize((1024, 1024)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] ) remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device) else: remove_bg_fn = None images, reference_image = run_pipeline( pipe, num_views=args.num_views, text=args.text, image=args.image, height=512, width=512, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, seed=args.seed, lora_scale=args.lora_scale, reference_conditioning_scale=args.reference_conditioning_scale, negative_prompt=args.negative_prompt, device=args.device, remove_bg_fn=remove_bg_fn, ) make_image_grid(images, rows=1).save(args.output) reference_image.save(args.output.rsplit(".", 1)[0] + "_reference.png")