Daankular committed on
Commit
7ee007c
·
verified ·
1 Parent(s): e8c88d7

Upload external/MV-Adapter/scripts/inference_t2mv_sd.py with huggingface_hub

Browse files
external/MV-Adapter/scripts/inference_t2mv_sd.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
5
+
6
+ from mvadapter.pipelines.pipeline_mvadapter_t2mv_sd import MVAdapterT2MVSDPipeline
7
+ from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
8
+ from mvadapter.utils.mesh_utils import get_orthogonal_camera
9
+ from mvadapter.utils.geometry import get_plucker_embeds_from_cameras_ortho
10
+ from mvadapter.utils import make_image_grid
11
+
12
+
13
def prepare_pipeline(
    base_model,
    vae_model,
    unet_model,
    lora_model,
    adapter_path,
    scheduler,
    num_views,
    device,
    dtype,
):
    """Build the text-to-multi-view MV-Adapter pipeline used for inference.

    Args:
        base_model: Identifier handed to
            ``MVAdapterT2MVSDPipeline.from_pretrained``.
        vae_model: Optional identifier of a replacement VAE, loaded with
            ``AutoencoderKL.from_pretrained``; ``None`` keeps the VAE of
            ``base_model``.
        unet_model: Optional identifier of a replacement UNet, loaded with
            ``UNet2DConditionModel.from_pretrained``; ``None`` keeps the UNet
            of ``base_model``.
        lora_model: Optional LoRA spec of the form ``"<repo_or_dir>/<weight
            file name>"``. It is split on the last ``"/"``.
        adapter_path: Location passed to ``pipe.load_custom_adapter``; the
            weight file name is fixed to ``mvadapter_t2mv_sd21.safetensors``.
        scheduler: ``"ddpm"`` or ``"lcm"`` to select the scheduler class given
            to ``ShiftSNRScheduler.from_scheduler``. Any other value
            (including ``None``) passes ``scheduler_class=None``.
        num_views: Number of views forwarded to ``pipe.init_custom_adapter``.
        device: Device the pipeline and its condition encoder are moved to.
        dtype: Dtype the pipeline and its condition encoder are cast to.

    Returns:
        The configured pipeline with VAE slicing enabled.

    Raises:
        ValueError: If ``lora_model`` is given but contains no ``"/"``.
    """
    # Validate the LoRA spec before any model is loaded, so a malformed value
    # fails immediately with a readable message instead of a tuple-unpacking
    # error after the whole pipeline has already been built.
    lora_repo = None
    lora_weight_name = None
    if lora_model is not None:
        lora_repo, separator, lora_weight_name = lora_model.rpartition("/")
        if not separator:
            raise ValueError(
                "lora_model must look like '<repo_or_dir>/<weight_name>', "
                f"got {lora_model!r}"
            )

    # Load vae and unet if provided
    pipe_kwargs = {}
    if vae_model is not None:
        pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
    if unet_model is not None:
        pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)

    # Prepare pipeline
    pipe: MVAdapterT2MVSDPipeline
    pipe = MVAdapterT2MVSDPipeline.from_pretrained(base_model, **pipe_kwargs)

    # Load scheduler if provided
    scheduler_class = None
    if scheduler == "ddpm":
        scheduler_class = DDPMScheduler
    elif scheduler == "lcm":
        scheduler_class = LCMScheduler

    # The pipeline's scheduler is always wrapped, even when no explicit
    # scheduler class was requested.
    pipe.scheduler = ShiftSNRScheduler.from_scheduler(
        pipe.scheduler,
        shift_mode="interpolated",
        shift_scale=8.0,
        scheduler_class=scheduler_class,
    )
    pipe.init_custom_adapter(num_views=num_views)
    pipe.load_custom_adapter(
        adapter_path, weight_name="mvadapter_t2mv_sd21.safetensors"
    )

    pipe.to(device=device, dtype=dtype)
    # The condition encoder is moved explicitly as well.
    # NOTE(review): presumably pipe.to() does not cover it - confirm.
    pipe.cond_encoder.to(device=device, dtype=dtype)

    # load lora if provided
    if lora_model is not None:
        pipe.load_lora_weights(lora_repo, weight_name=lora_weight_name)

    # vae slicing for lower memory usage
    pipe.enable_vae_slicing()

    return pipe
65
+
66
+
67
def run_pipeline(
    pipe,
    num_views,
    text,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    negative_prompt,
    lora_scale=1.0,
    device="cuda",
):
    """Generate multi-view images for a text prompt.

    Orthographic cameras are created for the requested views, turned into
    Plucker-embedding control images, and passed to ``pipe`` together with the
    prompt.

    Args:
        pipe: Callable pipeline; its result must expose ``.images``.
        num_views: Number of views to generate, between 1 and 6. The views
            are taken in order from the azimuths 0, 45, 90, 180, 270 and 315
            degrees.
            NOTE(review): the released adapter weights are presumably meant
            for all six views - confirm before relying on fewer.
        text: Prompt passed to the pipeline.
        height: Output height passed to the pipeline.
        width: Output width passed to the pipeline; also used as the image
            size of the Plucker embeddings.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for the random generator; ``-1`` means no generator is
            passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        lora_scale: Value passed as ``cross_attention_kwargs["scale"]``.
        device: Device for the cameras and the random generator.

    Returns:
        The ``.images`` attribute of the pipeline output.

    Raises:
        ValueError: If ``num_views`` is outside the range 1..6.
    """
    base_azimuths_deg = (0, 45, 90, 180, 270, 315)
    max_views = len(base_azimuths_deg)
    # Every per-view list below must have the same length. The original code
    # hard-coded six elevations/azimuths while sizing the other lists by
    # num_views, so reject counts that cannot be served by the preset.
    if not 1 <= num_views <= max_views:
        raise ValueError(
            f"num_views must be between 1 and {max_views}, got {num_views}"
        )

    # Prepare cameras
    cameras = get_orthogonal_camera(
        elevation_deg=[0] * num_views,
        distance=[1.8] * num_views,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        # NOTE(review): the -90 degree shift presumably aligns azimuth 0 with
        # the camera helper's front-view convention - confirm.
        azimuth_deg=[x - 90 for x in base_azimuths_deg[:num_views]],
        device=device,
    )

    # 1.1 equals the frustum width (right - left).
    # NOTE(review): presumably the orthographic scale - confirm.
    plucker_embeds = get_plucker_embeds_from_cameras_ortho(
        cameras.c2w, [1.1] * num_views, width
    )
    # Shift by one, halve, then clamp so the control images stay in [0, 1].
    control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)

    pipe_kwargs = {}
    if seed != -1:
        pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)

    images = pipe(
        text,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_views,
        control_image=control_images,
        control_conditioning_scale=1.0,
        negative_prompt=negative_prompt,
        cross_attention_kwargs={"scale": lora_scale},
        **pipe_kwargs,
    ).images

    return images
116
+
117
+
118
if __name__ == "__main__":
    # Command-line entry point: parse options, build the pipeline, generate
    # the views and save them as a single-row image grid.
    cli = argparse.ArgumentParser()

    # Models
    cli.add_argument(
        "--base_model",
        type=str,
        default="stabilityai/stable-diffusion-2-1-base",
    )
    # Optional overrides that all default to None.
    for optional_flag in ("--vae_model", "--unet_model", "--scheduler", "--lora_model"):
        cli.add_argument(optional_flag, type=str, default=None)
    cli.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
    cli.add_argument("--num_views", type=int, default=6)

    # Device
    cli.add_argument("--device", type=str, default="cuda")

    # Inference
    cli.add_argument("--text", type=str, required=True)
    cli.add_argument("--num_inference_steps", type=int, default=50)
    cli.add_argument("--guidance_scale", type=float, default=7.0)
    cli.add_argument("--seed", type=int, default=-1)
    cli.add_argument(
        "--negative_prompt",
        type=str,
        default="watermark, ugly, deformed, noisy, blurry, low contrast",
    )
    cli.add_argument("--lora_scale", type=float, default=1.0)
    cli.add_argument("--output", type=str, default="output.png")

    opts = cli.parse_args()

    # Build the pipeline; weights are always run in half precision here.
    model_options = {
        "base_model": opts.base_model,
        "vae_model": opts.vae_model,
        "unet_model": opts.unet_model,
        "lora_model": opts.lora_model,
        "adapter_path": opts.adapter_path,
        "scheduler": opts.scheduler,
        "num_views": opts.num_views,
        "device": opts.device,
    }
    pipeline = prepare_pipeline(dtype=torch.float16, **model_options)

    # Generate the views at a fixed 512x512 resolution.
    generation_options = {
        "num_views": opts.num_views,
        "text": opts.text,
        "height": 512,
        "width": 512,
        "num_inference_steps": opts.num_inference_steps,
        "guidance_scale": opts.guidance_scale,
        "seed": opts.seed,
        "negative_prompt": opts.negative_prompt,
        "lora_scale": opts.lora_scale,
        "device": opts.device,
    }
    views = run_pipeline(pipeline, **generation_options)

    # Lay the views out in one row and write the grid to the output path.
    grid = make_image_grid(views, rows=1)
    grid.save(opts.output)