Daankular commited on
Commit
f6fb2f5
·
verified ·
1 Parent(s): 0cf1195

Upload external/MV-Adapter/scripts/inference_ig2mv_sd.py with huggingface_hub

Browse files
external/MV-Adapter/scripts/inference_ig2mv_sd.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import numpy as np
4
+ import torch
5
+ from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ from tqdm import tqdm
9
+ from transformers import AutoModelForImageSegmentation
10
+
11
+ from mvadapter.models.attention_processor import DecoupledMVRowColSelfAttnProcessor2_0
12
+ from mvadapter.pipelines.pipeline_mvadapter_i2mv_sd import MVAdapterI2MVSDPipeline
13
+ from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
14
+ from mvadapter.utils import make_image_grid, tensor_to_image
15
+ from mvadapter.utils.mesh_utils import (
16
+ NVDiffRastContextWrapper,
17
+ get_orthogonal_camera,
18
+ load_mesh,
19
+ render,
20
+ )
21
+
22
+
23
+ def prepare_pipeline(
24
+ base_model,
25
+ vae_model,
26
+ unet_model,
27
+ lora_model,
28
+ adapter_path,
29
+ scheduler,
30
+ num_views,
31
+ device,
32
+ dtype,
33
+ ):
34
+ # Load vae and unet if provided
35
+ pipe_kwargs = {}
36
+ if vae_model is not None:
37
+ pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
38
+ if unet_model is not None:
39
+ pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)
40
+
41
+ # Prepare pipeline
42
+ pipe: MVAdapterI2MVSDPipeline
43
+ pipe = MVAdapterI2MVSDPipeline.from_pretrained(base_model, **pipe_kwargs)
44
+
45
+ # Load scheduler if provided
46
+ scheduler_class = None
47
+ if scheduler == "ddpm":
48
+ scheduler_class = DDPMScheduler
49
+ elif scheduler == "lcm":
50
+ scheduler_class = LCMScheduler
51
+
52
+ pipe.scheduler = ShiftSNRScheduler.from_scheduler(
53
+ pipe.scheduler,
54
+ shift_mode="interpolated",
55
+ shift_scale=8.0,
56
+ scheduler_class=scheduler_class,
57
+ )
58
+ pipe.init_custom_adapter(
59
+ num_views=num_views, self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0
60
+ )
61
+ pipe.load_custom_adapter(
62
+ adapter_path, weight_name="mvadapter_ig2mv_sd21.safetensors"
63
+ )
64
+
65
+ pipe.to(device=device, dtype=dtype)
66
+ pipe.cond_encoder.to(device=device, dtype=dtype)
67
+
68
+ # load lora if provided
69
+ if lora_model is not None:
70
+ model_, name_ = lora_model.rsplit("/", 1)
71
+ pipe.load_lora_weights(model_, weight_name=name_)
72
+
73
+ # vae slicing for lower memory usage
74
+ pipe.enable_vae_slicing()
75
+
76
+ return pipe
77
+
78
+
79
+ def remove_bg(image, net, transform, device):
80
+ image_size = image.size
81
+ input_images = transform(image).unsqueeze(0).to(device)
82
+ with torch.no_grad():
83
+ preds = net(input_images)[-1].sigmoid().cpu()
84
+ pred = preds[0].squeeze()
85
+ pred_pil = transforms.ToPILImage()(pred)
86
+ mask = pred_pil.resize(image_size)
87
+ image.putalpha(mask)
88
+ return image
89
+
90
+
91
+ def preprocess_image(image: Image.Image, height, width):
92
+ image = np.array(image)
93
+ alpha = image[..., 3] > 0
94
+ H, W = alpha.shape
95
+ # get the bounding box of alpha
96
+ y, x = np.where(alpha)
97
+ y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
98
+ x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
99
+ image_center = image[y0:y1, x0:x1]
100
+ # resize the longer side to H * 0.9
101
+ H, W, _ = image_center.shape
102
+ if H > W:
103
+ W = int(W * (height * 0.9) / H)
104
+ H = int(height * 0.9)
105
+ else:
106
+ H = int(H * (width * 0.9) / W)
107
+ W = int(width * 0.9)
108
+ image_center = np.array(Image.fromarray(image_center).resize((W, H)))
109
+ # pad to H, W
110
+ start_h = (height - H) // 2
111
+ start_w = (width - W) // 2
112
+ image = np.zeros((height, width, 4), dtype=np.uint8)
113
+ image[start_h : start_h + H, start_w : start_w + W] = image_center
114
+ image = image.astype(np.float32) / 255.0
115
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
116
+ image = (image * 255).clip(0, 255).astype(np.uint8)
117
+ image = Image.fromarray(image)
118
+
119
+ return image
120
+
121
+
122
+ def run_pipeline(
123
+ pipe,
124
+ mesh_path,
125
+ num_views,
126
+ text,
127
+ image,
128
+ height,
129
+ width,
130
+ num_inference_steps,
131
+ guidance_scale,
132
+ seed,
133
+ remove_bg_fn=None,
134
+ reference_conditioning_scale=1.0,
135
+ negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
136
+ lora_scale=1.0,
137
+ device="cuda",
138
+ ):
139
+ # Prepare cameras
140
+ cameras = get_orthogonal_camera(
141
+ elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
142
+ distance=[1.8] * num_views,
143
+ left=-0.55,
144
+ right=0.55,
145
+ bottom=-0.55,
146
+ top=0.55,
147
+ azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
148
+ device=device,
149
+ )
150
+ ctx = NVDiffRastContextWrapper(device=device)
151
+
152
+ mesh = load_mesh(mesh_path, rescale=True, device=device)
153
+ render_out = render(
154
+ ctx,
155
+ mesh,
156
+ cameras,
157
+ height=height,
158
+ width=width,
159
+ render_attr=False,
160
+ normal_background=0.0,
161
+ )
162
+ pos_images = tensor_to_image((render_out.pos + 0.5).clamp(0, 1), batched=True)
163
+ normal_images = tensor_to_image(
164
+ (render_out.normal / 2 + 0.5).clamp(0, 1), batched=True
165
+ )
166
+ control_images = (
167
+ torch.cat(
168
+ [
169
+ (render_out.pos + 0.5).clamp(0, 1),
170
+ (render_out.normal / 2 + 0.5).clamp(0, 1),
171
+ ],
172
+ dim=-1,
173
+ )
174
+ .permute(0, 3, 1, 2)
175
+ .to(device)
176
+ )
177
+
178
+ # Prepare image
179
+ reference_image = Image.open(image) if isinstance(image, str) else image
180
+ if remove_bg_fn is not None:
181
+ reference_image = remove_bg_fn(reference_image)
182
+ reference_image = preprocess_image(reference_image, height, width)
183
+ elif reference_image.mode == "RGBA":
184
+ reference_image = preprocess_image(reference_image, height, width)
185
+
186
+ pipe_kwargs = {}
187
+ if seed != -1 and isinstance(seed, int):
188
+ pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)
189
+
190
+ images = pipe(
191
+ text,
192
+ height=height,
193
+ width=width,
194
+ num_inference_steps=num_inference_steps,
195
+ guidance_scale=guidance_scale,
196
+ num_images_per_prompt=num_views,
197
+ control_image=control_images,
198
+ control_conditioning_scale=1.0,
199
+ reference_image=reference_image,
200
+ reference_conditioning_scale=reference_conditioning_scale,
201
+ negative_prompt=negative_prompt,
202
+ cross_attention_kwargs={"scale": lora_scale},
203
+ **pipe_kwargs,
204
+ ).images
205
+
206
+ return images, pos_images, normal_images, reference_image
207
+
208
+
209
+ if __name__ == "__main__":
210
+ parser = argparse.ArgumentParser()
211
+ # Models
212
+ parser.add_argument(
213
+ "--base_model", type=str, default="stabilityai/stable-diffusion-2-1-base"
214
+ )
215
+ parser.add_argument("--vae_model", type=str, default=None)
216
+ parser.add_argument("--unet_model", type=str, default=None)
217
+ parser.add_argument("--scheduler", type=str, default=None)
218
+ parser.add_argument("--lora_model", type=str, default=None)
219
+ parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
220
+ parser.add_argument("--num_views", type=int, default=6)
221
+ # Device
222
+ parser.add_argument("--device", type=str, default="cuda")
223
+ # Inference
224
+ parser.add_argument("--mesh", type=str, required=True)
225
+ parser.add_argument("--image", type=str, required=True)
226
+ parser.add_argument("--text", type=str, required=False, default="high quality")
227
+ parser.add_argument("--num_inference_steps", type=int, default=50)
228
+ parser.add_argument("--guidance_scale", type=float, default=3.0)
229
+ parser.add_argument("--seed", type=int, default=-1)
230
+ parser.add_argument("--lora_scale", type=float, default=1.0)
231
+ parser.add_argument("--reference_conditioning_scale", type=float, default=1.0)
232
+ parser.add_argument(
233
+ "--negative_prompt",
234
+ type=str,
235
+ default="watermark, ugly, deformed, noisy, blurry, low contrast",
236
+ )
237
+ parser.add_argument("--output", type=str, default="output.png")
238
+ # Extra
239
+ parser.add_argument("--remove_bg", action="store_true", help="Remove background")
240
+ args = parser.parse_args()
241
+
242
+ pipe = prepare_pipeline(
243
+ base_model=args.base_model,
244
+ vae_model=args.vae_model,
245
+ unet_model=args.unet_model,
246
+ lora_model=args.lora_model,
247
+ adapter_path=args.adapter_path,
248
+ scheduler=args.scheduler,
249
+ num_views=args.num_views,
250
+ device=args.device,
251
+ dtype=torch.float16,
252
+ )
253
+
254
+ if args.remove_bg:
255
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
256
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
257
+ )
258
+ birefnet.to(args.device)
259
+ transform_image = transforms.Compose(
260
+ [
261
+ transforms.Resize((1024, 1024)),
262
+ transforms.ToTensor(),
263
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
264
+ ]
265
+ )
266
+ remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device)
267
+ else:
268
+ remove_bg_fn = None
269
+
270
+ images, pos_images, normal_images, reference_image = run_pipeline(
271
+ pipe,
272
+ mesh_path=args.mesh,
273
+ num_views=args.num_views,
274
+ text=args.text,
275
+ image=args.image,
276
+ height=512,
277
+ width=512,
278
+ num_inference_steps=args.num_inference_steps,
279
+ guidance_scale=args.guidance_scale,
280
+ seed=args.seed,
281
+ lora_scale=args.lora_scale,
282
+ reference_conditioning_scale=args.reference_conditioning_scale,
283
+ negative_prompt=args.negative_prompt,
284
+ device=args.device,
285
+ remove_bg_fn=remove_bg_fn,
286
+ )
287
+ make_image_grid(images, rows=1).save(args.output)
288
+ make_image_grid(pos_images, rows=1).save(args.output.rsplit(".", 1)[0] + "_pos.png")
289
+ make_image_grid(normal_images, rows=1).save(
290
+ args.output.rsplit(".", 1)[0] + "_nor.png"
291
+ )
292
+ reference_image.save(args.output.rsplit(".", 1)[0] + "_reference.png")