Daankular commited on
Commit
2974e1c
·
verified ·
1 Parent(s): d4a2728

Upload external/MV-Adapter/scripts/inference_i2mv_sdxl.py with huggingface_hub

Browse files
external/MV-Adapter/scripts/inference_i2mv_sdxl.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import numpy as np
4
+ import torch
5
+ from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ from tqdm import tqdm
9
+ from transformers import AutoModelForImageSegmentation
10
+
11
+ from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
12
+ from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
13
+ from mvadapter.utils.mesh_utils import get_orthogonal_camera
14
+ from mvadapter.utils.geometry import get_plucker_embeds_from_cameras_ortho
15
+ from mvadapter.utils import make_image_grid
16
+
17
+
18
+ def prepare_pipeline(
19
+ base_model,
20
+ vae_model,
21
+ unet_model,
22
+ lora_model,
23
+ adapter_path,
24
+ scheduler,
25
+ num_views,
26
+ device,
27
+ dtype,
28
+ ):
29
+ # Load vae and unet if provided
30
+ pipe_kwargs = {}
31
+ if vae_model is not None:
32
+ pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
33
+ if unet_model is not None:
34
+ pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)
35
+
36
+ # Prepare pipeline
37
+ pipe: MVAdapterI2MVSDXLPipeline
38
+ pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)
39
+
40
+ # Load scheduler if provided
41
+ scheduler_class = None
42
+ if scheduler == "ddpm":
43
+ scheduler_class = DDPMScheduler
44
+ elif scheduler == "lcm":
45
+ scheduler_class = LCMScheduler
46
+
47
+ pipe.scheduler = ShiftSNRScheduler.from_scheduler(
48
+ pipe.scheduler,
49
+ shift_mode="interpolated",
50
+ shift_scale=8.0,
51
+ scheduler_class=scheduler_class,
52
+ )
53
+ pipe.init_custom_adapter(num_views=num_views)
54
+ pipe.load_custom_adapter(
55
+ adapter_path, weight_name="mvadapter_i2mv_sdxl.safetensors"
56
+ )
57
+
58
+ pipe.to(device=device, dtype=dtype)
59
+ pipe.cond_encoder.to(device=device, dtype=dtype)
60
+
61
+ # load lora if provided
62
+ if lora_model is not None:
63
+ model_, name_ = lora_model.rsplit("/", 1)
64
+ pipe.load_lora_weights(model_, weight_name=name_)
65
+
66
+ # vae slicing for lower memory usage
67
+ pipe.enable_vae_slicing()
68
+
69
+ return pipe
70
+
71
+
72
+ def remove_bg(image, net, transform, device):
73
+ image_size = image.size
74
+ input_images = transform(image).unsqueeze(0).to(device)
75
+ with torch.no_grad():
76
+ preds = net(input_images)[-1].sigmoid().cpu()
77
+ pred = preds[0].squeeze()
78
+ pred_pil = transforms.ToPILImage()(pred)
79
+ mask = pred_pil.resize(image_size)
80
+ image.putalpha(mask)
81
+ return image
82
+
83
+
84
+ def preprocess_image(image: Image.Image, height, width):
85
+ image = np.array(image)
86
+ alpha = image[..., 3] > 0
87
+ H, W = alpha.shape
88
+ # get the bounding box of alpha
89
+ y, x = np.where(alpha)
90
+ y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
91
+ x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
92
+ image_center = image[y0:y1, x0:x1]
93
+ # resize the longer side to H * 0.9
94
+ H, W, _ = image_center.shape
95
+ if H > W:
96
+ W = int(W * (height * 0.9) / H)
97
+ H = int(height * 0.9)
98
+ else:
99
+ H = int(H * (width * 0.9) / W)
100
+ W = int(width * 0.9)
101
+ image_center = np.array(Image.fromarray(image_center).resize((W, H)))
102
+ # pad to H, W
103
+ start_h = (height - H) // 2
104
+ start_w = (width - W) // 2
105
+ image = np.zeros((height, width, 4), dtype=np.uint8)
106
+ image[start_h : start_h + H, start_w : start_w + W] = image_center
107
+ image = image.astype(np.float32) / 255.0
108
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
109
+ image = (image * 255).clip(0, 255).astype(np.uint8)
110
+ image = Image.fromarray(image)
111
+
112
+ return image
113
+
114
+
115
+ def run_pipeline(
116
+ pipe,
117
+ num_views,
118
+ text,
119
+ image,
120
+ height,
121
+ width,
122
+ num_inference_steps,
123
+ guidance_scale,
124
+ seed,
125
+ remove_bg_fn=None,
126
+ reference_conditioning_scale=1.0,
127
+ negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
128
+ lora_scale=1.0,
129
+ device="cuda",
130
+ azimuth_deg=None,
131
+ ):
132
+ # Prepare cameras
133
+ if azimuth_deg is None:
134
+ azimuth_deg = [0, 45, 90, 180, 270, 315]
135
+ cameras = get_orthogonal_camera(
136
+ elevation_deg=[0] * num_views,
137
+ distance=[1.8] * num_views,
138
+ left=-0.55,
139
+ right=0.55,
140
+ bottom=-0.55,
141
+ top=0.55,
142
+ azimuth_deg=[x - 90 for x in azimuth_deg],
143
+ device=device,
144
+ )
145
+
146
+ plucker_embeds = get_plucker_embeds_from_cameras_ortho(
147
+ cameras.c2w, [1.1] * num_views, width
148
+ )
149
+ control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)
150
+
151
+ # Prepare image
152
+ reference_image = Image.open(image) if isinstance(image, str) else image
153
+ if remove_bg_fn is not None:
154
+ reference_image = remove_bg_fn(reference_image)
155
+ reference_image = preprocess_image(reference_image, height, width)
156
+ elif reference_image.mode == "RGBA":
157
+ reference_image = preprocess_image(reference_image, height, width)
158
+
159
+ pipe_kwargs = {}
160
+ if seed != -1 and isinstance(seed, int):
161
+ pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)
162
+
163
+ images = pipe(
164
+ text,
165
+ height=height,
166
+ width=width,
167
+ num_inference_steps=num_inference_steps,
168
+ guidance_scale=guidance_scale,
169
+ num_images_per_prompt=num_views,
170
+ control_image=control_images,
171
+ control_conditioning_scale=1.0,
172
+ reference_image=reference_image,
173
+ reference_conditioning_scale=reference_conditioning_scale,
174
+ negative_prompt=negative_prompt,
175
+ cross_attention_kwargs={"scale": lora_scale},
176
+ **pipe_kwargs,
177
+ ).images
178
+
179
+ return images, reference_image
180
+
181
+
182
+ if __name__ == "__main__":
183
+ parser = argparse.ArgumentParser()
184
+ # Models
185
+ parser.add_argument(
186
+ "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
187
+ )
188
+ parser.add_argument(
189
+ "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix"
190
+ )
191
+ parser.add_argument("--unet_model", type=str, default=None)
192
+ parser.add_argument("--scheduler", type=str, default=None)
193
+ parser.add_argument("--lora_model", type=str, default=None)
194
+ parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
195
+ # Device
196
+ parser.add_argument("--device", type=str, default="cuda")
197
+ # Inference
198
+ parser.add_argument("--num_views", type=int, default=6) # not used
199
+ parser.add_argument(
200
+ "--azimuth_deg", type=int, nargs="+", default=[0, 45, 90, 180, 270, 315]
201
+ )
202
+ parser.add_argument("--image", type=str, required=True)
203
+ parser.add_argument("--text", type=str, default="high quality")
204
+ parser.add_argument("--num_inference_steps", type=int, default=50)
205
+ parser.add_argument("--guidance_scale", type=float, default=3.0)
206
+ parser.add_argument("--seed", type=int, default=-1)
207
+ parser.add_argument("--lora_scale", type=float, default=1.0)
208
+ parser.add_argument("--reference_conditioning_scale", type=float, default=1.0)
209
+ parser.add_argument(
210
+ "--negative_prompt",
211
+ type=str,
212
+ default="watermark, ugly, deformed, noisy, blurry, low contrast",
213
+ )
214
+ parser.add_argument("--output", type=str, default="output.png")
215
+ # Extra
216
+ parser.add_argument("--remove_bg", action="store_true", help="Remove background")
217
+ args = parser.parse_args()
218
+
219
+ num_views = len(args.azimuth_deg)
220
+
221
+ pipe = prepare_pipeline(
222
+ base_model=args.base_model,
223
+ vae_model=args.vae_model,
224
+ unet_model=args.unet_model,
225
+ lora_model=args.lora_model,
226
+ adapter_path=args.adapter_path,
227
+ scheduler=args.scheduler,
228
+ num_views=num_views,
229
+ device=args.device,
230
+ dtype=torch.float16,
231
+ )
232
+
233
+ if args.remove_bg:
234
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
235
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
236
+ )
237
+ birefnet.to(args.device)
238
+ transform_image = transforms.Compose(
239
+ [
240
+ transforms.Resize((1024, 1024)),
241
+ transforms.ToTensor(),
242
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
243
+ ]
244
+ )
245
+ remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device)
246
+ else:
247
+ remove_bg_fn = None
248
+
249
+ images, reference_image = run_pipeline(
250
+ pipe,
251
+ num_views=num_views,
252
+ text=args.text,
253
+ image=args.image,
254
+ height=768,
255
+ width=768,
256
+ num_inference_steps=args.num_inference_steps,
257
+ guidance_scale=args.guidance_scale,
258
+ seed=args.seed,
259
+ lora_scale=args.lora_scale,
260
+ reference_conditioning_scale=args.reference_conditioning_scale,
261
+ negative_prompt=args.negative_prompt,
262
+ device=args.device,
263
+ remove_bg_fn=remove_bg_fn,
264
+ azimuth_deg=args.azimuth_deg,
265
+ )
266
+ make_image_grid(images, rows=1).save(args.output)
267
+ reference_image.save(args.output.rsplit(".", 1)[0] + "_reference.png")