Daankular committed on
Commit
62cb6e6
·
verified ·
1 Parent(s): 9a190f2

Upload external/MV-Adapter/mvadapter/utils/mesh_utils/warp.py with huggingface_hub

Browse files
external/MV-Adapter/mvadapter/utils/mesh_utils/warp.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import nvdiffrast.torch as dr
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import trimesh
8
+ from einops import rearrange
9
+ from PIL import Image, ImageDraw
10
+ from tqdm import tqdm
11
+
12
+
13
def draw(img, lines):
    """Draw line segments over an image and return it as a 512x512 PIL image.

    Args:
        img: A ``torch.Tensor``, a ``numpy.ndarray`` or a PIL image. Tensors
            are detached and moved to the CPU first. Floating-point arrays
            are clipped to ``[0, 1]`` and scaled by 255 into ``uint8``.
        lines: ``n x 4`` array or tensor of segments ``(x1, y1, x2, y2)``.
            Coordinates are mapped from ``[-1, 1]`` to pixels with
            ``(lines + 1) / 2 * width``.

    Returns:
        The image, resized to 512x512 with bicubic resampling, with every
        segment drawn in red (1 px wide).
    """
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()

    if isinstance(img, np.ndarray):
        if np.issubdtype(img.dtype, np.floating):
            # Clip before casting: values outside [0, 1] would otherwise wrap
            # around when converted to uint8.
            img = (np.clip(img, 0.0, 1.0) * 255.0).astype(np.uint8)
        img = Image.fromarray(img)

    img = img.resize((512, 512), Image.Resampling.BICUBIC)
    # The resize above makes the image square, so the width is the only scale needed.
    width = img.size[0]

    if isinstance(lines, torch.Tensor):
        lines = lines.detach().cpu().numpy()
    lines = (lines + 1) / 2 * width

    canvas = ImageDraw.Draw(img)
    for line in lines:
        canvas.line(line.tolist(), fill="red", width=1)
    return img
36
+
37
+
38
def construct_grid_mesh(n_grid):
    """Build a flat triangle mesh made of ``n_grid x n_grid`` square cells.

    The ``(n_grid + 1) ** 2`` vertices are laid out row by row; x and y span
    ``[-1, 1]`` and z is 0 for every vertex. Each cell is split into two
    triangles.

    Args:
        n_grid: Number of cells along each side of the grid.

    Returns:
        A tuple ``(mesh, vertices_movable_ids, vertices)``: the
        ``trimesh.Trimesh`` (built with ``process=False``), the indices of
        the vertices that are not on the grid border, and the vertex
        positions as a NumPy array.
    """
    side = n_grid + 1  # number of vertices per row / column

    # Row-major lattice over [0, 1]^2; a constant third coordinate is added
    # because trimesh expects 3D vertices.
    lattice = [
        [col / n_grid, row / n_grid, 1.0 / 2]
        for row in range(side)
        for col in range(side)
    ]
    # Vertices strictly inside the border, in the same row-major order.
    interior_ids = [
        row * side + col
        for row in range(1, n_grid)
        for col in range(1, n_grid)
    ]

    # Rescale from [0, 1] to [-1, 1].
    vertices = np.array(lattice)
    vertices = 2 * vertices - 1

    triangles = []
    for row in range(n_grid):
        for col in range(n_grid):
            v00 = col + row * side
            v10 = v00 + 1
            v01 = col + (row + 1) * side
            v11 = v01 + 1
            # Two triangles per cell, sharing the v10-v01 diagonal.
            triangles.append([v00, v10, v01])
            triangles.append([v10, v11, v01])
    faces = np.array(triangles)
    vertices_movable_ids = np.array(interior_ids)

    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)
    return mesh, vertices_movable_ids, vertices
76
+
77
+
78
def _rasterize_grid(ctx, v_homo, faces, resl):
    """Rasterize the undeformed grid at ``resl x resl``.

    Returns, for every pixel, the three vertex ids of the covering triangle
    and the matching barycentric weights ``(u, v, 1 - u - v)``.
    """
    rast, _ = dr.rasterize(
        ctx,
        v_homo[None, ...].float(),
        faces.int(),
        (resl, resl),
        grad_db=True,
    )
    # The last raster channel stores the triangle id offset by one
    # (0 = no triangle), hence the "- 1".
    face_ids = rast[..., 3].long() - 1
    # The grid spans the whole viewport, so every pixel must be covered.
    assert torch.all(face_ids >= 0)
    u = rast[..., 0]
    v = rast[..., 1]
    pixel_vertice_ids = faces[face_ids]
    pixel_bary_coords = torch.stack([u, v, 1 - u - v], dim=-1)
    return pixel_vertice_ids, pixel_bary_coords


def _interpolate_pixel_coords(vertices_all, pixel_vertice_ids, pixel_bary_coords):
    """Blend triangle-corner positions with barycentric weights into per-pixel xy coordinates."""
    pixel_vertices = vertices_all[pixel_vertice_ids]  # 1 x resl x resl x 3 x 3
    coords = torch.sum(pixel_vertices * pixel_bary_coords[..., None], dim=-2)
    return coords[..., :2]


def _chw_to_pil(image_chw):
    """Convert a ``C H W`` tensor into a PIL image by scaling by 255 and casting to uint8."""
    array = image_chw.permute(1, 2, 0).detach().cpu().numpy() * 255.0
    return Image.fromarray(array.astype(np.uint8))


def compute_warp_field(
    ctx,
    src_images_tensor,
    tgt_images_tensor,
    n_grid,
    optim_res,
    optim_step_per_res,
    lambda_reg,
    temp_dir,
    verbose,
    device,
):
    """Warp each source image towards its target by optimizing a deformable grid.

    For every image pair, the interior vertices of an ``n_grid x n_grid``
    grid mesh are optimized with Adam so that the source image, sampled
    through the deformed grid, matches the target (mean squared error on the
    first three channels). An edge-length term keeps each grid edge close to
    its undeformed length.

    Args:
        ctx: nvdiffrast rasterization context passed to ``dr.rasterize``.
        src_images_tensor: ``B H W C`` source images. Only the first three
            channels drive the optimization; all channels are warped in the
            output. NOTE(review): assumed to be float32 in [0, 1] and
            square (``H == W``) - confirm against callers.
        tgt_images_tensor: ``B H W C`` target images, same layout.
        n_grid: Number of grid cells per side.
        optim_res: Re-iterable sequence of square resolutions to optimize
            at, in order.
        optim_step_per_res: Number of Adam steps per resolution.
        lambda_reg: Weight of the edge-length regularizer.
        temp_dir: Directory for debug images, or ``None``. Images are only
            written when ``verbose`` is true and ``temp_dir`` is not ``None``.
        verbose: Enables the progress bar, loss printing and debug images.
        device: Torch device holding the mesh tensors.

    Returns:
        ``B x R x R x C`` tensor of warped source images clamped to
        ``[0, 1]``, where ``R = src_images_tensor.shape[2]``.
    """
    mesh, _, vertices_np = construct_grid_mesh(n_grid)

    vertices_unopt = torch.tensor(vertices_np, device=device, dtype=torch.float32)
    faces = torch.tensor(mesh.faces, device=device)
    edges = torch.tensor(mesh.edges_unique, device=device)

    # Only degree-6 vertices are optimized; in this grid triangulation those
    # are the interior vertices, so the border stays fixed.
    n_degree = torch.tensor(mesh.vertex_degree, device=device)
    v_move_indices = torch.where(n_degree == 6)[0]

    # Homogeneous positions of the undeformed grid, used for rasterization.
    v_origin_homo = torch.cat(
        [vertices_unopt, torch.ones_like(vertices_unopt[..., :1])], dim=-1
    )

    # Rest length of every edge, measured on the undeformed grid only.
    edge_vertices_unopt = vertices_unopt[edges]
    edge_len_unopt = torch.linalg.norm(
        edge_vertices_unopt[:, 0, :2] - edge_vertices_unopt[:, 1, :2],
        dim=-1,
    )

    save_vis = bool(verbose) and temp_dir is not None

    warped_images_tensor = []
    for img_idx in range(src_images_tensor.shape[0]):
        src_image_tensor = src_images_tensor[img_idx]
        tgt_image_tensor = tgt_images_tensor[img_idx]

        vis_dir = None
        if save_vis:
            vis_dir = os.path.join(temp_dir, f"{img_idx}")
            os.makedirs(vis_dir, exist_ok=True)

        # Fresh parameters and optimizer for every image.
        vertices_movable = vertices_unopt[v_move_indices].detach().clone()
        vertices_movable.requires_grad = True
        opt = torch.optim.Adam([vertices_movable], lr=0.02)

        for resl in optim_res:
            pixel_vertice_ids, pixel_bary_coords = _rasterize_grid(
                ctx, v_origin_homo, faces, resl
            )

            # src_img, tgt_img: 1 C resl resl
            src_img = F.interpolate(
                rearrange(src_image_tensor[..., :3], "H W C -> () C H W"),
                size=(resl, resl),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
            tgt_img = F.interpolate(
                rearrange(tgt_image_tensor[..., :3], "H W C -> () C H W"),
                size=(resl, resl),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )

            if save_vis:
                _chw_to_pil(tgt_img[0]).save(
                    os.path.join(vis_dir, f"target_{resl:04d}.png")
                )

            last_step = None
            src_img_warped = None
            edge_vertices_all = None
            with tqdm(range(optim_step_per_res), disable=not verbose) as pbar:
                for step in pbar:
                    opt.zero_grad()
                    vertices_all = vertices_unopt.detach().clone()
                    vertices_all[v_move_indices] = vertices_movable

                    src_img_warped = F.grid_sample(
                        src_img,
                        _interpolate_pixel_coords(
                            vertices_all, pixel_vertice_ids, pixel_bary_coords
                        ),
                        mode="bilinear",
                        align_corners=False,
                    )

                    edge_vertices_all = vertices_all[edges]
                    edge_len_all = torch.linalg.norm(
                        edge_vertices_all[:, 0, :2] - edge_vertices_all[:, 1, :2],
                        dim=-1,
                    )
                    reg_loss = ((edge_len_all - edge_len_unopt) ** 2).mean()

                    img_loss = ((src_img_warped - tgt_img) ** 2).mean()
                    loss = img_loss + lambda_reg * reg_loss

                    loss.backward()
                    opt.step()
                    if verbose:
                        print(img_loss.item(), reg_loss.item())
                    last_step = step

            # Debug images show the state evaluated in the last step; skipped
            # when no step ran at this resolution.
            if save_vis and last_step is not None:
                draw(
                    src_img[0].permute(1, 2, 0),
                    edge_vertices_all[:, :, :2].reshape(-1, 4),
                ).save(os.path.join(vis_dir, f"src_{resl:04d}_{last_step:03d}.png"))
                comparison = torch.cat(
                    [tgt_img, src_img_warped, (tgt_img - src_img_warped).abs()],
                    dim=-1,
                )
                _chw_to_pil(comparison[0]).save(
                    os.path.join(vis_dir, f"opt_{resl:04d}_{last_step:03d}.png")
                )
                _chw_to_pil(src_img_warped[0]).save(
                    os.path.join(vis_dir, f"warped_{resl:04d}_{last_step:03d}.png")
                )

        # Final warp at the input resolution (uses the second image dimension
        # for both output sides); no gradients are needed here.
        resl = src_image_tensor.shape[1]
        with torch.no_grad():
            pixel_vertice_ids, pixel_bary_coords = _rasterize_grid(
                ctx, v_origin_homo, faces, resl
            )

            vertices_all = vertices_unopt.detach().clone()
            vertices_all[v_move_indices] = vertices_movable

            src_img_warped = rearrange(
                F.grid_sample(
                    rearrange(src_image_tensor, "H W C -> () C H W"),
                    _interpolate_pixel_coords(
                        vertices_all, pixel_vertice_ids, pixel_bary_coords
                    ),
                    mode="bicubic",
                    align_corners=False,
                ),
                "() C H W -> H W C",
            ).clamp(0, 1)
        warped_images_tensor.append(src_img_warped)

    # B x R x R x C, with all input channels warped.
    warped_images_tensor = torch.stack(warped_images_tensor, dim=0)

    return warped_images_tensor