| import argparse
|
|
|
| import torch
|
| from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
|
|
|
| from mvadapter.pipelines.pipeline_mvadapter_t2mv_sd import MVAdapterT2MVSDPipeline
|
| from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
|
| from mvadapter.utils.mesh_utils import get_orthogonal_camera
|
| from mvadapter.utils.geometry import get_plucker_embeds_from_cameras_ortho
|
| from mvadapter.utils import make_image_grid
|
|
|
|
|
def prepare_pipeline(
    base_model,
    vae_model,
    unet_model,
    lora_model,
    adapter_path,
    scheduler,
    num_views,
    device,
    dtype,
):
    """Build an ``MVAdapterT2MVSDPipeline`` that is ready for inference.

    Args:
        base_model: Identifier passed to
            ``MVAdapterT2MVSDPipeline.from_pretrained``.
        vae_model: Optional identifier loaded with
            ``AutoencoderKL.from_pretrained`` and handed to the pipeline as
            its ``vae``. Skipped when ``None``.
        unet_model: Optional identifier loaded with
            ``UNet2DConditionModel.from_pretrained`` and handed to the
            pipeline as its ``unet``. Skipped when ``None``.
        lora_model: Optional LoRA location written as ``<model>/<weight file>``.
            It is split at the last ``/``; the left part is the model and the
            right part the ``weight_name``. Skipped when ``None``.
        adapter_path: Location passed to ``pipe.load_custom_adapter``.
        scheduler: ``"ddpm"`` selects ``DDPMScheduler`` and ``"lcm"`` selects
            ``LCMScheduler``. Any other value forwards ``scheduler_class=None``
            to ``ShiftSNRScheduler.from_scheduler``.
        num_views: Forwarded to ``pipe.init_custom_adapter``.
        device: Device the pipeline and its ``cond_encoder`` are moved to.
        dtype: Dtype the pipeline and its ``cond_encoder`` are cast to.

    Returns:
        The configured pipeline, with VAE slicing enabled.

    Raises:
        ValueError: If ``lora_model`` is given but contains no ``/`` (the
            two-way unpacking of the split result fails).
    """
    # Only components the caller explicitly overrides are passed along; the
    # rest come from ``base_model``.
    component_overrides = {}
    if vae_model is not None:
        component_overrides["vae"] = AutoencoderKL.from_pretrained(vae_model)
    if unet_model is not None:
        component_overrides["unet"] = UNet2DConditionModel.from_pretrained(unet_model)

    pipe: MVAdapterT2MVSDPipeline = MVAdapterT2MVSDPipeline.from_pretrained(
        base_model, **component_overrides
    )

    # Map the scheduler name to a class; unrecognised names (and None) yield
    # None, which is forwarded unchanged.
    scheduler_by_name = {"ddpm": DDPMScheduler, "lcm": LCMScheduler}
    scheduler_class = scheduler_by_name.get(scheduler)

    # Wrap the pipeline's current scheduler with the SNR-shifted variant.
    pipe.scheduler = ShiftSNRScheduler.from_scheduler(
        pipe.scheduler,
        shift_mode="interpolated",
        shift_scale=8.0,
        scheduler_class=scheduler_class,
    )

    # Create the multi-view adapter modules, then load their weights.
    pipe.init_custom_adapter(num_views=num_views)
    pipe.load_custom_adapter(
        adapter_path, weight_name="mvadapter_t2mv_sd21.safetensors"
    )

    # The cond_encoder is moved explicitly in addition to the pipeline itself.
    pipe.to(device=device, dtype=dtype)
    pipe.cond_encoder.to(device=device, dtype=dtype)

    if lora_model is not None:
        # "<model>/<file>" -> ("<model>", "<file>"), splitting at the last "/".
        lora_source, lora_weight_name = lora_model.rsplit("/", 1)
        pipe.load_lora_weights(lora_source, weight_name=lora_weight_name)

    pipe.enable_vae_slicing()

    return pipe
|
|
|
|
|
def run_pipeline(
    pipe,
    num_views,
    text,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    negative_prompt,
    lora_scale=1.0,
    device="cuda",
):
    """Run the multi-view pipeline for one text prompt and return its images.

    Args:
        pipe: Callable pipeline (as returned by ``prepare_pipeline``).
        num_views: Number of images requested per prompt; also sizes the
            per-view ``distance`` list and the list passed to the Plucker
            embedding helper.
        text: Prompt passed as the first positional argument to ``pipe``.
        height: Output height forwarded to ``pipe``.
        width: Output width forwarded to ``pipe``; also the size argument of
            the Plucker embedding helper.
        num_inference_steps: Forwarded to ``pipe``.
        guidance_scale: Forwarded to ``pipe``.
        seed: Seed for a ``torch.Generator`` on ``device``. ``-1`` means no
            generator is passed to ``pipe``.
        negative_prompt: Forwarded to ``pipe``.
        lora_scale: Passed as ``cross_attention_kwargs={"scale": lora_scale}``.
        device: Device used for the cameras and the random generator.

    Returns:
        The ``images`` attribute of the pipeline's output.
    """
    # NOTE(review): the elevation and azimuth lists always hold six entries,
    # while ``distance`` scales with ``num_views`` -- presumably only
    # ``num_views == 6`` is expected here; confirm against
    # ``get_orthogonal_camera``.
    view_azimuths = [0, 45, 90, 180, 270, 315]
    cameras = get_orthogonal_camera(
        elevation_deg=[0] * 6,
        distance=[1.8] * num_views,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=[azimuth - 90 for azimuth in view_azimuths],
        device=device,
    )

    plucker_embeds = get_plucker_embeds_from_cameras_ortho(
        cameras.c2w, [1.1] * num_views, width
    )
    # Apply (x + 1) / 2 and clamp the result into [0, 1] before using the
    # embeddings as control images.
    control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)

    # A seed of -1 leaves the generator unset.
    if seed != -1:
        sampling_kwargs = {
            "generator": torch.Generator(device=device).manual_seed(seed)
        }
    else:
        sampling_kwargs = {}

    output = pipe(
        text,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_views,
        control_image=control_images,
        control_conditioning_scale=1.0,
        negative_prompt=negative_prompt,
        cross_attention_kwargs={"scale": lora_scale},
        **sampling_kwargs,
    )

    return output.images
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, build the pipeline, generate
    # the views and save them as a single-row image grid.
    parser = argparse.ArgumentParser()

    # (flag, add_argument keyword arguments), registered in this order.
    cli_options = (
        # Models and pipeline setup.
        (
            "--base_model",
            {"type": str, "default": "stabilityai/stable-diffusion-2-1-base"},
        ),
        ("--vae_model", {"type": str, "default": None}),
        ("--unet_model", {"type": str, "default": None}),
        ("--scheduler", {"type": str, "default": None}),
        ("--lora_model", {"type": str, "default": None}),
        ("--adapter_path", {"type": str, "default": "huanngzh/mv-adapter"}),
        ("--num_views", {"type": int, "default": 6}),
        # Device.
        ("--device", {"type": str, "default": "cuda"}),
        # Inference settings.
        ("--text", {"type": str, "required": True}),
        ("--num_inference_steps", {"type": int, "default": 50}),
        ("--guidance_scale", {"type": float, "default": 7.0}),
        ("--seed", {"type": int, "default": -1}),
        (
            "--negative_prompt",
            {
                "type": str,
                "default": "watermark, ugly, deformed, noisy, blurry, low contrast",
            },
        ),
        ("--lora_scale", {"type": float, "default": 1.0}),
        ("--output", {"type": str, "default": "output.png"}),
    )
    for flag, flag_kwargs in cli_options:
        parser.add_argument(flag, **flag_kwargs)
    args = parser.parse_args()

    # The pipeline is always built in half precision.
    pipe = prepare_pipeline(
        base_model=args.base_model,
        vae_model=args.vae_model,
        unet_model=args.unet_model,
        lora_model=args.lora_model,
        adapter_path=args.adapter_path,
        scheduler=args.scheduler,
        num_views=args.num_views,
        device=args.device,
        dtype=torch.float16,
    )

    # Output resolution is fixed at 512x512.
    images = run_pipeline(
        pipe,
        num_views=args.num_views,
        text=args.text,
        height=512,
        width=512,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        negative_prompt=args.negative_prompt,
        lora_scale=args.lora_scale,
        device=args.device,
    )

    # Lay all views out in one row and write the grid to disk.
    grid = make_image_grid(images, rows=1)
    grid.save(args.output)
|
|
|