Daankular committed on
Commit
0d48b9b
·
verified ·
1 Parent(s): 0cac2db

Upload external/MV-Adapter/scripts/inference_tg2mv_sd.py with huggingface_hub

Browse files
external/MV-Adapter/scripts/inference_tg2mv_sd.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
5
+
6
+ from mvadapter.models.attention_processor import DecoupledMVRowColSelfAttnProcessor2_0
7
+ from mvadapter.pipelines.pipeline_mvadapter_t2mv_sd import MVAdapterT2MVSDPipeline
8
+ from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
9
+ from mvadapter.utils import get_orthogonal_camera, make_image_grid, tensor_to_image
10
+ from mvadapter.utils.mesh_utils import NVDiffRastContextWrapper, load_mesh, render
11
+
12
+
13
def split_lora_path(lora_model):
    """Split a LoRA spec into ``(location, weight_name)`` at its last ``/``.

    ``"user/repo/weights.safetensors"`` yields
    ``("user/repo", "weights.safetensors")``. A bare file name with no ``/``
    yields ``(".", name)`` so that it resolves against the current directory
    instead of failing on a two-value unpack.
    """
    location, separator, weight_name = lora_model.rpartition("/")
    if not separator:
        # No directory / repo part was given; fall back to the working dir.
        location = "."
    return location, weight_name


def prepare_pipeline(
    base_model,
    vae_model,
    unet_model,
    lora_model,
    adapter_path,
    scheduler,
    num_views,
    device,
    dtype,
):
    """Build the multi-view adapter pipeline and move it to ``device``.

    Args:
        base_model: Identifier passed to ``MVAdapterT2MVSDPipeline.from_pretrained``.
        vae_model: Optional identifier of a replacement VAE; ``None`` keeps
            the one shipped with ``base_model``.
        unet_model: Optional identifier of a replacement UNet; ``None`` keeps
            the one shipped with ``base_model``.
        lora_model: Optional LoRA spec of the form ``<location>/<weight file>``.
            A bare file name is looked up in the current directory.
        adapter_path: Location handed to ``pipe.load_custom_adapter``.
        scheduler: ``"ddpm"`` or ``"lcm"`` select the matching scheduler class;
            any other value passes ``scheduler_class=None`` to
            ``ShiftSNRScheduler.from_scheduler``.
        num_views: Number of views the custom adapter is initialised for.
        device: Target device for the pipeline and its condition encoder.
        dtype: Target dtype for the pipeline and its condition encoder.

    Returns:
        The configured pipeline with VAE slicing enabled.
    """
    # Load vae and unet if provided
    pipe_kwargs = {}
    if vae_model is not None:
        pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
    if unet_model is not None:
        pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)

    # Prepare pipeline
    pipe: MVAdapterT2MVSDPipeline
    pipe = MVAdapterT2MVSDPipeline.from_pretrained(base_model, **pipe_kwargs)

    # Load scheduler if provided
    scheduler_class = None
    if scheduler == "ddpm":
        scheduler_class = DDPMScheduler
    elif scheduler == "lcm":
        scheduler_class = LCMScheduler

    pipe.scheduler = ShiftSNRScheduler.from_scheduler(
        pipe.scheduler,
        shift_mode="interpolated",
        shift_scale=8.0,
        scheduler_class=scheduler_class,
    )
    pipe.init_custom_adapter(
        num_views=num_views, self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0
    )
    pipe.load_custom_adapter(
        adapter_path, weight_name="mvadapter_tg2mv_sd21.safetensors"
    )

    pipe.to(device=device, dtype=dtype)
    # The condition encoder is moved explicitly in addition to pipe.to().
    pipe.cond_encoder.to(device=device, dtype=dtype)

    # load lora if provided
    if lora_model is not None:
        model_, name_ = split_lora_path(lora_model)
        pipe.load_lora_weights(model_, weight_name=name_)

    # vae slicing for lower memory usage
    pipe.enable_vae_slicing()

    return pipe
67
+
68
+
69
def run_pipeline(
    pipe,
    mesh_path,
    num_views,
    text,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    negative_prompt,
    lora_scale=1.0,
    device="cuda",
):
    """Render geometry guidance from a mesh and run the multi-view pipeline.

    The mesh is rendered from a fixed set of orthographic cameras; the
    resulting position and normal maps are concatenated along the last axis,
    permuted to channel-first order and passed to ``pipe`` as the control
    image.

    Args:
        pipe: Callable pipeline whose result exposes ``.images``.
        mesh_path: Path of the mesh handed to ``load_mesh``.
        num_views: Number of camera distances and images per prompt.
        text: Prompt passed as the first positional argument to ``pipe``.
        height: Render and generation height.
        width: Render and generation width.
        num_inference_steps: Forwarded to ``pipe``.
        guidance_scale: Forwarded to ``pipe``.
        seed: Seed for a ``torch.Generator``; ``-1`` means no generator is
            passed.
        negative_prompt: Forwarded to ``pipe``.
        lora_scale: Value sent as ``cross_attention_kwargs["scale"]``.
        device: Device for cameras, rasteriser context, mesh and generator.

    Returns:
        A tuple ``(images, pos_images, normal_images)``.
    """
    # NOTE(review): both angle lists have six entries while ``distance`` is
    # sized by ``num_views`` -- confirm how a mismatch is handled upstream.
    view_elevations = [0, 0, 0, 0, 89.99, -89.99]
    view_azimuths = [0, 90, 180, 270, 180, 180]
    cameras = get_orthogonal_camera(
        elevation_deg=view_elevations,
        distance=[1.8] * num_views,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=[angle - 90 for angle in view_azimuths],
        device=device,
    )

    raster_ctx = NVDiffRastContextWrapper(device=device)
    mesh = load_mesh(mesh_path, rescale=True, device=device)
    render_out = render(
        raster_ctx,
        mesh,
        cameras,
        height=height,
        width=width,
        render_attr=False,
        normal_background=0.0,
    )

    # Shift/scale both maps into [0, 1] before converting or stacking them.
    position_map = (render_out.pos + 0.5).clamp(0, 1)
    pos_images = tensor_to_image(position_map, batched=True)
    normal_map = (render_out.normal / 2 + 0.5).clamp(0, 1)
    normal_images = tensor_to_image(normal_map, batched=True)

    stacked_maps = torch.cat([position_map, normal_map], dim=-1)
    control_images = stacked_maps.permute(0, 3, 1, 2).to(device)

    # Only seed the run when a real seed was requested.
    extra_kwargs = (
        {}
        if seed == -1
        else {"generator": torch.Generator(device=device).manual_seed(seed)}
    )

    result = pipe(
        text,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_views,
        control_image=control_images,
        control_conditioning_scale=1.0,
        negative_prompt=negative_prompt,
        cross_attention_kwargs={"scale": lora_scale},
        **extra_kwargs,
    )

    return result.images, pos_images, normal_images
141
+
142
+
143
def sibling_output_path(output, suffix):
    """Return ``output`` with its extension replaced by ``suffix + ".png"``.

    Uses ``os.path.splitext`` so that only the extension of the final path
    component is removed. Splitting on the last ``"."`` of the whole string
    would cut inside a dotted directory name when the file itself has no
    extension (for example ``"./out/result"``).
    """
    import os

    stem, _extension = os.path.splitext(output)
    return stem + suffix + ".png"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Models
    parser.add_argument(
        "--base_model", type=str, default="stabilityai/stable-diffusion-2-1-base"
    )
    parser.add_argument("--vae_model", type=str, default=None)
    parser.add_argument("--unet_model", type=str, default=None)
    parser.add_argument("--scheduler", type=str, default=None)
    parser.add_argument("--lora_model", type=str, default=None)
    parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
    parser.add_argument("--num_views", type=int, default=6)
    # Device
    parser.add_argument("--device", type=str, default="cuda")
    # Inference
    parser.add_argument("--text", type=str, required=True)
    parser.add_argument("--mesh", type=str, required=True)
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=7.0)
    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="watermark, ugly, deformed, noisy, blurry, low contrast",
    )
    parser.add_argument("--lora_scale", type=float, default=1.0)
    parser.add_argument("--output", type=str, default="output.png")
    args = parser.parse_args()

    pipe = prepare_pipeline(
        base_model=args.base_model,
        vae_model=args.vae_model,
        unet_model=args.unet_model,
        lora_model=args.lora_model,
        adapter_path=args.adapter_path,
        scheduler=args.scheduler,
        num_views=args.num_views,
        device=args.device,
        dtype=torch.float16,
    )
    images, pos_images, normal_images = run_pipeline(
        pipe,
        mesh_path=args.mesh,
        num_views=args.num_views,
        text=args.text,
        height=512,
        width=512,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        negative_prompt=args.negative_prompt,
        lora_scale=args.lora_scale,
        device=args.device,
    )

    # Generated views go to --output; the guidance maps are saved next to it
    # with "_pos" / "_nor" appended to the file stem.
    make_image_grid(images, rows=1).save(args.output)
    make_image_grid(pos_images, rows=1).save(
        sibling_output_path(args.output, "_pos")
    )
    make_image_grid(normal_images, rows=1).save(
        sibling_output_path(args.output, "_nor")
    )