| import argparse
|
|
|
| import numpy as np
|
| import torch
|
| from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
|
| from PIL import Image
|
| from torchvision import transforms
|
| from tqdm import tqdm
|
| from transformers import AutoModelForImageSegmentation
|
|
|
| from mvadapter.models.attention_processor import DecoupledMVRowColSelfAttnProcessor2_0
|
| from mvadapter.pipelines.pipeline_mvadapter_i2mv_sd import MVAdapterI2MVSDPipeline
|
| from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
|
| from mvadapter.utils import make_image_grid, tensor_to_image
|
| from mvadapter.utils.mesh_utils import (
|
| NVDiffRastContextWrapper,
|
| get_orthogonal_camera,
|
| load_mesh,
|
| render,
|
| )
|
|
|
|
|
def prepare_pipeline(
    base_model,
    vae_model,
    unet_model,
    lora_model,
    adapter_path,
    scheduler,
    num_views,
    device,
    dtype,
):
    """Assemble an ``MVAdapterI2MVSDPipeline`` that is ready for inference.

    Args:
        base_model: Identifier/path passed to
            ``MVAdapterI2MVSDPipeline.from_pretrained``.
        vae_model: Optional identifier/path of a replacement VAE
            (``AutoencoderKL``); ``None`` keeps the one from ``base_model``.
        unet_model: Optional identifier/path of a replacement UNet
            (``UNet2DConditionModel``); ``None`` keeps the one from
            ``base_model``.
        lora_model: Optional LoRA location written as ``"<model>/<weight file>"``.
            It is split at the last ``"/"``; ``None`` skips LoRA loading.
        adapter_path: Location handed to ``pipe.load_custom_adapter``.
        scheduler: ``"ddpm"`` or ``"lcm"`` to pick the scheduler class given
            to ``ShiftSNRScheduler.from_scheduler``; any other value passes
            ``scheduler_class=None``.
        num_views: Number of views forwarded to ``pipe.init_custom_adapter``.
        device: Device the pipeline and its ``cond_encoder`` are moved to.
        dtype: Dtype the pipeline and its ``cond_encoder`` are cast to.

    Returns:
        The configured pipeline with VAE slicing enabled.
    """
    # Optional component overrides; the VAE is loaded before the UNet.
    overrides = {}
    if vae_model is not None:
        overrides["vae"] = AutoencoderKL.from_pretrained(vae_model)
    if unet_model is not None:
        overrides["unet"] = UNet2DConditionModel.from_pretrained(unet_model)

    pipe: MVAdapterI2MVSDPipeline = MVAdapterI2MVSDPipeline.from_pretrained(
        base_model, **overrides
    )

    # Map the scheduler name to its class; unknown names and None yield None.
    known_schedulers = {"ddpm": DDPMScheduler, "lcm": LCMScheduler}
    chosen_scheduler_class = known_schedulers.get(scheduler)

    # Wrap the pipeline's current scheduler with the SNR-shifted variant.
    pipe.scheduler = ShiftSNRScheduler.from_scheduler(
        pipe.scheduler,
        shift_mode="interpolated",
        shift_scale=8.0,
        scheduler_class=chosen_scheduler_class,
    )

    # Install the multi-view attention processors, then load adapter weights.
    pipe.init_custom_adapter(
        num_views=num_views,
        self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0,
    )
    pipe.load_custom_adapter(
        adapter_path,
        weight_name="mvadapter_ig2mv_sd21.safetensors",
    )

    # cond_encoder is moved explicitly in addition to the pipeline itself.
    pipe.to(device=device, dtype=dtype)
    pipe.cond_encoder.to(device=device, dtype=dtype)

    if lora_model is not None:
        # Everything before the last "/" is the model, the rest the file name.
        lora_repo, lora_file = lora_model.rsplit("/", 1)
        pipe.load_lora_weights(lora_repo, weight_name=lora_file)

    pipe.enable_vae_slicing()
    return pipe
|
|
|
|
|
def remove_bg(image, net, transform, device):
    """Return a copy of ``image`` whose alpha channel is a predicted mask.

    The input image is converted to RGB before it is given to ``transform``,
    so images in any PIL mode (including RGBA) can be processed, and the
    caller's image object is left unmodified.

    Args:
        image: PIL image to matte.
        net: Segmentation model. It is called on a single-image batch and the
            last element of its output is passed through a sigmoid to obtain
            the mask.
        transform: Callable that maps a PIL image to a tensor (a batch
            dimension is added here with ``unsqueeze(0)``).
        device: Device the input batch is moved to before calling ``net``.

    Returns:
        A new RGBA PIL image with the same size as ``image``; its alpha
        channel is the predicted mask resized to the image size.
    """
    # convert() always yields a new image: this both guarantees exactly three
    # channels for the transform and avoids mutating the caller's image.
    rgb_image = image.convert("RGB")

    batch = transform(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        preds = net(batch)[-1].sigmoid().cpu()

    # The mask is predicted at the transform's resolution; bring it back to
    # the original image size before attaching it as alpha.
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(rgb_image.size)
    rgb_image.putalpha(mask)
    return rgb_image
|
|
|
|
|
def preprocess_image(image: Image.Image, height, width):
    """Crop to the visible content, center it on a gray canvas, drop alpha.

    The image is cropped to the bounding box of its pixels with alpha > 0,
    scaled (keeping aspect ratio) so that it occupies at most 90% of the
    target canvas along the limiting axis, centered on a ``height`` x
    ``width`` canvas and composited over a mid-gray (0.5) background.

    Args:
        image: PIL image. Non-RGBA images are converted to RGBA first.
        height: Height of the output canvas in pixels.
        width: Width of the output canvas in pixels.

    Returns:
        An RGB PIL image of size ``(width, height)``.

    Raises:
        ValueError: If the image has no pixel with alpha > 0.
    """
    # Channel 3 is read below, so make sure an alpha channel exists.
    if image.mode != "RGBA":
        image = image.convert("RGBA")

    rgba = np.array(image)
    visible = rgba[..., 3] > 0
    if not visible.any():
        raise ValueError(
            "preprocess_image: image has no visible (alpha > 0) pixels to crop"
        )
    src_h, src_w = visible.shape

    # Bounding box of the visible pixels, clamped to the image bounds.
    rows, cols = np.where(visible)
    y0, y1 = max(rows.min() - 1, 0), min(rows.max() + 1, src_h)
    x0, x1 = max(cols.min() - 1, 0), min(cols.max() + 1, src_w)
    crop = rgba[y0:y1, x0:x1]

    crop_h, crop_w = crop.shape[:2]
    # Scale along the axis that reaches the 90% limit first. For a square
    # canvas this is the same as comparing crop_h > crop_w; for non-square
    # canvases it keeps the scaled crop inside the canvas.
    if crop_h * width > crop_w * height:
        new_w = int(crop_w * (height * 0.9) / crop_h)
        new_h = int(height * 0.9)
    else:
        new_h = int(crop_h * (width * 0.9) / crop_w)
        new_w = int(width * 0.9)
    # Extremely thin crops could round down to 0, which cannot be resized.
    new_h = max(new_h, 1)
    new_w = max(new_w, 1)
    crop = np.array(Image.fromarray(crop).resize((new_w, new_h)))

    # Paste the resized crop at the center of a fully transparent canvas.
    top = (height - new_h) // 2
    left = (width - new_w) // 2
    canvas = np.zeros((height, width, 4), dtype=np.uint8)
    canvas[top : top + new_h, left : left + new_w] = crop

    # Alpha-composite over a 0.5 gray background and drop the alpha channel.
    canvas = canvas.astype(np.float32) / 255.0
    alpha = canvas[:, :, 3:4]
    rgb = canvas[:, :, :3] * alpha + (1 - alpha) * 0.5
    rgb = (rgb * 255).clip(0, 255).astype(np.uint8)

    return Image.fromarray(rgb)
|
|
|
|
|
def run_pipeline(
    pipe,
    mesh_path,
    num_views,
    text,
    image,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    remove_bg_fn=None,
    reference_conditioning_scale=1.0,
    negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
    lora_scale=1.0,
    device="cuda",
):
    """Render geometry conditions from a mesh and run the multi-view pipeline.

    Args:
        pipe: Pipeline object; called with the prompt plus keyword arguments.
        mesh_path: Mesh location given to ``load_mesh`` (with ``rescale=True``).
        num_views: Used for the camera distance list and as
            ``num_images_per_prompt``.
        text: Prompt passed as the first positional argument of ``pipe``.
        image: Reference image; a ``str`` is opened with ``Image.open``, any
            other value is used as-is.
        height: Render and generation height.
        width: Render and generation width.
        num_inference_steps: Forwarded to ``pipe``.
        guidance_scale: Forwarded to ``pipe``.
        seed: Integer seed for a ``torch.Generator``; ``-1`` or a non-int
            value means no generator is passed.
        remove_bg_fn: Optional callable applied to the reference image before
            ``preprocess_image``.
        reference_conditioning_scale: Forwarded to ``pipe``.
        negative_prompt: Forwarded to ``pipe``.
        lora_scale: Passed as ``cross_attention_kwargs={"scale": lora_scale}``.
        device: Device for cameras, rasterizer context, mesh, control images
            and the random generator.

    Returns:
        Tuple ``(images, pos_images, normal_images, reference_image)``: the
        pipeline's ``.images`` output, the position and normal renders
        converted with ``tensor_to_image``, and the reference image that was
        fed to the pipeline.
    """
    # NOTE(review): the angle lists below have six entries while the distance
    # list has num_views entries — presumably num_views is expected to be 6;
    # confirm against get_orthogonal_camera.
    elevations = [0, 0, 0, 0, 89.99, -89.99]
    azimuths = [angle - 90 for angle in [0, 90, 180, 270, 180, 180]]
    cameras = get_orthogonal_camera(
        elevation_deg=elevations,
        distance=[1.8] * num_views,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=azimuths,
        device=device,
    )
    ctx = NVDiffRastContextWrapper(device=device)

    mesh = load_mesh(mesh_path, rescale=True, device=device)
    render_out = render(
        ctx,
        mesh,
        cameras,
        height=height,
        width=width,
        render_attr=False,
        normal_background=0.0,
    )

    # Position / normal maps remapped into [0, 1] for saving as images.
    pos_images = tensor_to_image(
        (render_out.pos + 0.5).clamp(0, 1),
        batched=True,
    )
    normal_images = tensor_to_image(
        (render_out.normal / 2 + 0.5).clamp(0, 1),
        batched=True,
    )

    # Same maps concatenated on the last axis, then permuted with
    # (0, 3, 1, 2) so that axis moves to position 1 for the pipeline.
    control_parts = [
        (render_out.pos + 0.5).clamp(0, 1),
        (render_out.normal / 2 + 0.5).clamp(0, 1),
    ]
    control_images = torch.cat(control_parts, dim=-1)
    control_images = control_images.permute(0, 3, 1, 2).to(device)

    # Resolve the reference image and preprocess it when it has (or is
    # given) an alpha channel.
    if isinstance(image, str):
        reference_image = Image.open(image)
    else:
        reference_image = image
    if remove_bg_fn is not None:
        reference_image = preprocess_image(
            remove_bg_fn(reference_image), height, width
        )
    elif reference_image.mode == "RGBA":
        reference_image = preprocess_image(reference_image, height, width)

    call_kwargs = {
        "height": height,
        "width": width,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "num_images_per_prompt": num_views,
        "control_image": control_images,
        "control_conditioning_scale": 1.0,
        "reference_image": reference_image,
        "reference_conditioning_scale": reference_conditioning_scale,
        "negative_prompt": negative_prompt,
        "cross_attention_kwargs": {"scale": lora_scale},
    }
    # Only seed the sampler when an explicit integer seed was requested.
    if seed != -1 and isinstance(seed, int):
        call_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)

    images = pipe(text, **call_kwargs).images

    return images, pos_images, normal_images, reference_image
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: build the pipeline, generate the views for
    # one mesh + reference image, and save the result grids next to --output.
    import os

    parser = argparse.ArgumentParser()

    # Models
    parser.add_argument(
        "--base_model", type=str, default="stabilityai/stable-diffusion-2-1-base"
    )
    parser.add_argument("--vae_model", type=str, default=None)
    parser.add_argument("--unet_model", type=str, default=None)
    parser.add_argument("--scheduler", type=str, default=None)
    parser.add_argument("--lora_model", type=str, default=None)
    parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
    parser.add_argument("--num_views", type=int, default=6)

    # Device
    parser.add_argument("--device", type=str, default="cuda")

    # Inference
    parser.add_argument("--mesh", type=str, required=True)
    parser.add_argument("--image", type=str, required=True)
    parser.add_argument("--text", type=str, required=False, default="high quality")
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=3.0)
    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument("--lora_scale", type=float, default=1.0)
    parser.add_argument("--reference_conditioning_scale", type=float, default=1.0)
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="watermark, ugly, deformed, noisy, blurry, low contrast",
    )
    parser.add_argument("--output", type=str, default="output.png")

    # Extra
    parser.add_argument("--remove_bg", action="store_true", help="Remove background")
    args = parser.parse_args()

    pipe = prepare_pipeline(
        base_model=args.base_model,
        vae_model=args.vae_model,
        unet_model=args.unet_model,
        lora_model=args.lora_model,
        adapter_path=args.adapter_path,
        scheduler=args.scheduler,
        num_views=args.num_views,
        device=args.device,
        dtype=torch.float16,
    )

    if args.remove_bg:
        birefnet = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        birefnet.to(args.device)
        transform_image = transforms.Compose(
            [
                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        def remove_bg_fn(x):
            """Run background removal with the model and transform set up above."""
            return remove_bg(x, birefnet, transform_image, args.device)

    else:
        remove_bg_fn = None

    images, pos_images, normal_images, reference_image = run_pipeline(
        pipe,
        mesh_path=args.mesh,
        num_views=args.num_views,
        text=args.text,
        image=args.image,
        height=512,
        width=512,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        lora_scale=args.lora_scale,
        reference_conditioning_scale=args.reference_conditioning_scale,
        negative_prompt=args.negative_prompt,
        device=args.device,
        remove_bg_fn=remove_bg_fn,
    )

    # Derive sibling file names from the output path. splitext only strips an
    # extension from the final path component, so dots in directory names
    # (e.g. "./out/result") do not truncate the stem.
    output_stem = os.path.splitext(args.output)[0]

    make_image_grid(images, rows=1).save(args.output)
    make_image_grid(pos_images, rows=1).save(output_stem + "_pos.png")
    make_image_grid(normal_images, rows=1).save(output_stem + "_nor.png")
    reference_image.save(output_stem + "_reference.png")
|
|
|