Daankular committed on
Commit
d18dbc0
·
verified ·
1 Parent(s): 0d48b9b

Upload external/MV-Adapter/scripts/inference_tg2mv_sdxl.py with huggingface_hub

Browse files
external/MV-Adapter/scripts/inference_tg2mv_sdxl.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
5
+
6
+ from mvadapter.models.attention_processor import DecoupledMVRowColSelfAttnProcessor2_0
7
+ from mvadapter.pipelines.pipeline_mvadapter_t2mv_sdxl import MVAdapterT2MVSDXLPipeline
8
+ from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
9
+ from mvadapter.utils import get_orthogonal_camera, make_image_grid, tensor_to_image
10
+ from mvadapter.utils.mesh_utils import NVDiffRastContextWrapper, load_mesh, render
11
+
12
+
13
def prepare_pipeline(
    base_model,
    vae_model,
    unet_model,
    lora_model,
    adapter_path,
    scheduler,
    num_views,
    device,
    dtype,
):
    """Build and configure an ``MVAdapterT2MVSDXLPipeline`` for inference.

    Args:
        base_model: Identifier passed to
            ``MVAdapterT2MVSDXLPipeline.from_pretrained``.
        vae_model: Optional identifier for a replacement ``AutoencoderKL``;
            ``None`` keeps the VAE that ships with ``base_model``.
        unet_model: Optional identifier for a replacement
            ``UNet2DConditionModel``; ``None`` keeps the base UNet.
        lora_model: Optional LoRA spec of the form
            ``"<repo-or-dir>/<weight file name>"``. It is split on the last
            ``"/"``; the left part is passed as the model and the right part
            as ``weight_name`` to ``load_lora_weights``.
        adapter_path: Location passed to ``load_custom_adapter``; the weight
            file ``mvadapter_tg2mv_sdxl.safetensors`` is loaded from it.
        scheduler: ``"ddpm"`` selects ``DDPMScheduler``, ``"lcm"`` selects
            ``LCMScheduler``; any other value (including ``None``) passes
            ``scheduler_class=None`` to ``ShiftSNRScheduler.from_scheduler``.
        num_views: Number of views the custom adapter is initialised for.
        device: Device the pipeline and its condition encoder are moved to.
        dtype: Torch dtype the pipeline and its condition encoder are cast to.

    Returns:
        The configured pipeline, with VAE slicing enabled.

    Raises:
        ValueError: If ``lora_model`` is given but contains no ``"/"``.
    """
    # Validate the LoRA spec first so a malformed value fails immediately
    # instead of after the (expensive) model loading below.
    lora_repo = lora_weight_name = None
    if lora_model is not None:
        lora_repo, separator, lora_weight_name = lora_model.rpartition("/")
        if not separator:
            raise ValueError(
                "lora_model must look like '<repo-or-dir>/<weight file name>', "
                f"got {lora_model!r}"
            )

    # Load vae and unet if provided
    pipe_kwargs = {}
    if vae_model is not None:
        pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
    if unet_model is not None:
        pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)

    # Prepare pipeline
    pipe: MVAdapterT2MVSDXLPipeline
    pipe = MVAdapterT2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)

    # Pick the scheduler class if a known name was given; unknown names and
    # None both leave it as None, exactly as before.
    scheduler_classes = {"ddpm": DDPMScheduler, "lcm": LCMScheduler}
    scheduler_class = scheduler_classes.get(scheduler)

    pipe.scheduler = ShiftSNRScheduler.from_scheduler(
        pipe.scheduler,
        shift_mode="interpolated",
        shift_scale=8.0,
        scheduler_class=scheduler_class,
    )
    pipe.init_custom_adapter(
        num_views=num_views, self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0
    )
    pipe.load_custom_adapter(
        adapter_path, weight_name="mvadapter_tg2mv_sdxl.safetensors"
    )

    pipe.to(device=device, dtype=dtype)
    # The condition encoder is moved explicitly in addition to pipe.to().
    pipe.cond_encoder.to(device=device, dtype=dtype)

    # load lora if provided
    if lora_model is not None:
        pipe.load_lora_weights(lora_repo, weight_name=lora_weight_name)

    # vae slicing for lower memory usage
    pipe.enable_vae_slicing()

    return pipe
67
+
68
+
69
def run_pipeline(
    pipe,
    mesh_path,
    num_views,
    text,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    negative_prompt,
    lora_scale=1.0,
    device="cuda",
):
    """Render geometry guidance from a mesh and run the multi-view pipeline.

    The mesh is rendered from a fixed set of six orthographic cameras; the
    resulting position and normal maps are concatenated and passed to the
    pipeline as ``control_image``.

    Args:
        pipe: Callable pipeline (as returned by ``prepare_pipeline``).
        mesh_path: Path of the mesh handed to ``load_mesh``.
        num_views: Number of views; must be 6 to match the fixed camera set.
        text: Prompt passed to the pipeline.
        height: Render and generation height.
        width: Render and generation width.
        num_inference_steps: Forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline.
        seed: Seed for a ``torch.Generator``; ``-1`` means no generator is
            passed to the pipeline.
        negative_prompt: Forwarded to the pipeline.
        lora_scale: Passed as ``cross_attention_kwargs={"scale": lora_scale}``.
        device: Device used for cameras, rasterizer context, mesh, control
            images and the generator.

    Returns:
        Tuple ``(images, pos_images, normal_images)``: the pipeline's
        ``.images`` plus the position and normal maps converted with
        ``tensor_to_image``.

    Raises:
        ValueError: If ``num_views`` does not match the six fixed cameras.
    """
    # Fixed camera layout: six elevation/azimuth pairs.
    elevations = [0, 0, 0, 0, 89.99, -89.99]
    azimuths = [x - 90 for x in [0, 90, 180, 270, 180, 180]]

    # distance and num_images_per_prompt are sized by num_views while the
    # camera angles are not, so any other count yields inconsistent inputs.
    # Fail early, before the mesh is loaded and rendered.
    if num_views != len(elevations):
        raise ValueError(
            f"run_pipeline renders a fixed set of {len(elevations)} camera "
            f"views; got num_views={num_views}"
        )

    # Prepare cameras
    cameras = get_orthogonal_camera(
        elevation_deg=elevations,
        distance=[1.8] * num_views,
        left=-0.55,
        right=0.55,
        bottom=-0.55,
        top=0.55,
        azimuth_deg=azimuths,
        device=device,
    )
    ctx = NVDiffRastContextWrapper(device=device)

    mesh = load_mesh(mesh_path, rescale=True, device=device)
    render_out = render(
        ctx,
        mesh,
        cameras,
        height=height,
        width=width,
        render_attr=False,
        normal_background=0.0,
    )

    # Shift/scale both maps and clamp to [0, 1]; computed once and reused for
    # the control tensor and for the preview images.
    pos_maps = (render_out.pos + 0.5).clamp(0, 1)
    normal_maps = (render_out.normal / 2 + 0.5).clamp(0, 1)

    # Concatenate along the last axis, then move that axis to position 1
    # (presumably NHWC -> NCHW; confirm against render()'s output layout).
    control_images = (
        torch.cat([pos_maps, normal_maps], dim=-1).permute(0, 3, 1, 2).to(device)
    )

    pos_images = tensor_to_image(pos_maps, batched=True)
    normal_images = tensor_to_image(normal_maps, batched=True)

    # Only seed the pipeline when an explicit seed was requested.
    pipe_kwargs = {}
    if seed != -1:
        pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)

    images = pipe(
        text,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_views,
        control_image=control_images,
        control_conditioning_scale=1.0,
        negative_prompt=negative_prompt,
        cross_attention_kwargs={"scale": lora_scale},
        **pipe_kwargs,
    ).images

    return images, pos_images, normal_images
141
+
142
+
143
def derive_sidecar_path(output, suffix):
    """Return ``output`` with its extension replaced by ``suffix + ".png"``.

    Only the extension of the final path component is stripped, so dots in
    directory names (``out/run.1/result``) or a leading ``./`` are left alone.
    """
    import os

    stem, _extension = os.path.splitext(output)
    return f"{stem}{suffix}.png"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Models
    parser.add_argument(
        "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
    )
    parser.add_argument(
        "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix"
    )
    parser.add_argument("--unet_model", type=str, default=None)
    parser.add_argument("--scheduler", type=str, default=None)
    parser.add_argument("--lora_model", type=str, default=None)
    parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
    parser.add_argument("--num_views", type=int, default=6)
    # Device
    parser.add_argument("--device", type=str, default="cuda")
    # Inference
    parser.add_argument("--mesh", type=str, required=True)
    parser.add_argument("--text", type=str, required=True)
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=7.0)
    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="watermark, ugly, deformed, noisy, blurry, low contrast",
    )
    parser.add_argument("--lora_scale", type=float, default=1.0)
    parser.add_argument("--output", type=str, default="output.png")
    args = parser.parse_args()

    pipe = prepare_pipeline(
        base_model=args.base_model,
        vae_model=args.vae_model,
        unet_model=args.unet_model,
        lora_model=args.lora_model,
        adapter_path=args.adapter_path,
        scheduler=args.scheduler,
        num_views=args.num_views,
        device=args.device,
        dtype=torch.float16,
    )
    images, pos_images, normal_images = run_pipeline(
        pipe,
        mesh_path=args.mesh,
        num_views=args.num_views,
        text=args.text,
        height=768,
        width=768,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        negative_prompt=args.negative_prompt,
        lora_scale=args.lora_scale,
        device=args.device,
    )

    # Main result goes to --output; the position and normal guidance grids are
    # written next to it with "_pos" / "_nor" appended to the file stem.
    make_image_grid(images, rows=1).save(args.output)
    make_image_grid(pos_images, rows=1).save(derive_sidecar_path(args.output, "_pos"))
    make_image_grid(normal_images, rows=1).save(
        derive_sidecar_path(args.output, "_nor")
    )