Daankular committed on
Commit 708c3df · 1 Parent(s): 0183927

Patch autoencoder_kl_triposg: fix embedder fp32 upcast in _decode


FrequencyPositionalEmbedding uses float32 frequency buffers, which
upcast float16 query coords to float32 during embedding. The result
then hits proj_query (float16 weight), causing:
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and Half

Fix: after self.embedder(queries), cast back to model_dtype
(= self.decoder.proj_query.weight.dtype) so the decoder always
receives embeddings matching its parameter dtype.
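
For context, a minimal, self-contained sketch of the failure mode and the cast; ToyFrequencyEmbedding below is a hypothetical stand-in for FrequencyPositionalEmbedding, not the TripoSG implementation:

import torch
import torch.nn as nn

class ToyFrequencyEmbedding(nn.Module):
    # Stand-in: registers its frequency table in float32, like the real embedder's buffers.
    def __init__(self, num_freqs: int = 8):
        super().__init__()
        self.register_buffer("frequencies", 2.0 ** torch.arange(num_freqs, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # float16 coords * float32 buffer -> float32 output (the upcast described above)
        embed = x[..., None] * self.frequencies
        return torch.cat([embed.sin(), embed.cos()], dim=-1).flatten(-2)

embedder = ToyFrequencyEmbedding()
proj_query = nn.Linear(48, 64).half()                 # float16 weight, like decoder.proj_query
queries = torch.randn(1, 5, 3, dtype=torch.float16)   # float16 xyz query coords

emb = embedder(queries)                               # dtype: torch.float32
# proj_query(emb)  # RuntimeError: mat1 and mat2 must have the same dtype, but got Float and Half
emb = emb.to(dtype=proj_query.weight.dtype)           # the fix: cast back to the model dtype
out = proj_query(emb)                                 # float16 matmul succeeds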

patches/triposg/triposg/models/autoencoders/autoencoder_kl_triposg.py ADDED
@@ -0,0 +1,541 @@
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention_processor import Attention, AttentionProcessor
+from diffusers.models.autoencoders.vae import DecoderOutput
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import FP32LayerNorm, LayerNorm
+from diffusers.utils import logging
+from diffusers.utils.accelerate_utils import apply_forward_hook
+from einops import repeat
+# from torch_cluster import fps
+from tqdm import tqdm
+
+from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0, FlashTripoSGAttnProcessor2_0
+from ..embeddings import FrequencyPositionalEmbedding
+from ..transformers.triposg_transformer import DiTBlock
+from .vae import DiagonalGaussianDistribution
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class TripoSGEncoder(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 3,
+        dim: int = 512,
+        num_attention_heads: int = 8,
+        num_layers: int = 8,
+    ):
+        super().__init__()
+
+        self.proj_in = nn.Linear(in_channels, dim, bias=True)
+
+        self.blocks = nn.ModuleList(
+            [
+                DiTBlock(
+                    dim=dim,
+                    num_attention_heads=num_attention_heads,
+                    use_self_attention=False,
+                    use_cross_attention=True,
+                    cross_attention_dim=dim,
+                    cross_attention_norm_type="layer_norm",
+                    activation_fn="gelu",
+                    norm_type="fp32_layer_norm",
+                    norm_eps=1e-5,
+                    qk_norm=False,
+                    qkv_bias=False,
+                )  # cross attention
+            ]
+            + [
+                DiTBlock(
+                    dim=dim,
+                    num_attention_heads=num_attention_heads,
+                    use_self_attention=True,
+                    self_attention_norm_type="fp32_layer_norm",
+                    use_cross_attention=False,
+                    activation_fn="gelu",
+                    norm_type="fp32_layer_norm",
+                    norm_eps=1e-5,
+                    qk_norm=False,
+                    qkv_bias=False,
+                )
+                for _ in range(num_layers)  # self attention
+            ]
+        )
+
+        self.norm_out = LayerNorm(dim)
+
+    def forward(self, sample_1: torch.Tensor, sample_2: torch.Tensor):
+        hidden_states = self.proj_in(sample_1)
+        encoder_hidden_states = self.proj_in(sample_2)
+
+        for layer, block in enumerate(self.blocks):
+            if layer == 0:
+                hidden_states = block(
+                    hidden_states, encoder_hidden_states=encoder_hidden_states
+                )
+            else:
+                hidden_states = block(hidden_states)
+
+        hidden_states = self.norm_out(hidden_states)
+
+        return hidden_states
+
+
+class TripoSGDecoder(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 1,
+        dim: int = 512,
+        num_attention_heads: int = 8,
+        num_layers: int = 16,
+        grad_type: str = "analytical",
+        grad_interval: float = 0.001,
+    ):
+        super().__init__()
+
+        if grad_type not in ["numerical", "analytical"]:
+            raise ValueError(f"grad_type must be one of ['numerical', 'analytical']")
+        self.grad_type = grad_type
+        self.grad_interval = grad_interval
+
+        self.blocks = nn.ModuleList(
+            [
+                DiTBlock(
+                    dim=dim,
+                    num_attention_heads=num_attention_heads,
+                    use_self_attention=True,
+                    self_attention_norm_type="fp32_layer_norm",
+                    use_cross_attention=False,
+                    activation_fn="gelu",
+                    norm_type="fp32_layer_norm",
+                    norm_eps=1e-5,
+                    qk_norm=False,
+                    qkv_bias=False,
+                )
+                for _ in range(num_layers)  # self attention
+            ]
+            + [
+                DiTBlock(
+                    dim=dim,
+                    num_attention_heads=num_attention_heads,
+                    use_self_attention=False,
+                    use_cross_attention=True,
+                    cross_attention_dim=dim,
+                    cross_attention_norm_type="layer_norm",
+                    activation_fn="gelu",
+                    norm_type="fp32_layer_norm",
+                    norm_eps=1e-5,
+                    qk_norm=False,
+                    qkv_bias=False,
+                )  # cross attention
+            ]
+        )
+
+        self.proj_query = nn.Linear(in_channels, dim, bias=True)
+
+        self.norm_out = LayerNorm(dim)
+        self.proj_out = nn.Linear(dim, out_channels, bias=True)
+
+    def set_topk(self, topk):
+        self.blocks[-1].set_topk(topk)
+
+    def set_flash_processor(self, processor):
+        self.blocks[-1].set_flash_processor(processor)
+
+    def query_geometry(
+        self,
+        model_fn: callable,
+        queries: torch.Tensor,
+        sample: torch.Tensor,
+        grad: bool = False,
+    ):
+        logits = model_fn(queries, sample)
+        if grad:
+            with torch.autocast(device_type="cuda", dtype=torch.float32):
+                if self.grad_type == "numerical":
+                    interval = self.grad_interval
+                    grad_value = []
+                    for offset in [
+                        (interval, 0, 0),
+                        (0, interval, 0),
+                        (0, 0, interval),
+                    ]:
+                        offset_tensor = torch.tensor(offset, device=queries.device)[
+                            None, :
+                        ]
+                        res_p = model_fn(queries + offset_tensor, sample)[..., 0]
+                        res_n = model_fn(queries - offset_tensor, sample)[..., 0]
+                        grad_value.append((res_p - res_n) / (2 * interval))
+                    grad_value = torch.stack(grad_value, dim=-1)
+                else:
+                    queries_d = torch.clone(queries)
+                    queries_d.requires_grad = True
+                    with torch.enable_grad():
+                        res_d = model_fn(queries_d, sample)
+                        grad_value = torch.autograd.grad(
+                            res_d,
+                            [queries_d],
+                            grad_outputs=torch.ones_like(res_d),
+                            create_graph=self.training,
+                        )[0]
+        else:
+            grad_value = None
+
+        return logits, grad_value
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        queries: torch.Tensor,
+        kv_cache: Optional[torch.Tensor] = None,
+    ):
+        if kv_cache is None:
+            hidden_states = sample
+            for _, block in enumerate(self.blocks[:-1]):
+                hidden_states = block(hidden_states)
+            kv_cache = hidden_states
+
+        # query grid logits by cross attention
+        def query_fn(q, kv):
+            q = self.proj_query(q)
+            l = self.blocks[-1](q, encoder_hidden_states=kv)
+            return self.proj_out(self.norm_out(l))
+
+        logits, grad = self.query_geometry(
+            query_fn, queries, kv_cache, grad=self.training
+        )
+        logits = logits * -1 if not isinstance(logits, Tuple) else logits[0] * -1
+
+        return logits, kv_cache
+
+
+class TripoSGVAEModel(ModelMixin, ConfigMixin):
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 3,  # NOTE xyz instead of feature dim
+        latent_channels: int = 64,
+        num_attention_heads: int = 8,
+        width_encoder: int = 512,
+        width_decoder: int = 1024,
+        num_layers_encoder: int = 8,
+        num_layers_decoder: int = 16,
+        embedding_type: str = "frequency",
+        embed_frequency: int = 8,
+        embed_include_pi: bool = False,
+    ):
+        super().__init__()
+
+        self.out_channels = 1
+
+        if embedding_type == "frequency":
+            self.embedder = FrequencyPositionalEmbedding(
+                num_freqs=embed_frequency,
+                logspace=True,
+                input_dim=in_channels,
+                include_pi=embed_include_pi,
+            )
+        else:
+            raise NotImplementedError(
+                f"Embedding type {embedding_type} is not supported."
+            )
+
+        self.encoder = TripoSGEncoder(
+            in_channels=in_channels + self.embedder.out_dim,
+            dim=width_encoder,
+            num_attention_heads=num_attention_heads,
+            num_layers=num_layers_encoder,
+        )
+        self.decoder = TripoSGDecoder(
+            in_channels=self.embedder.out_dim,
+            out_channels=self.out_channels,
+            dim=width_decoder,
+            num_attention_heads=num_attention_heads,
+            num_layers=num_layers_decoder,
+        )
+
+        self.quant = nn.Linear(width_encoder, latent_channels * 2, bias=True)
+        self.post_quant = nn.Linear(latent_channels, width_decoder, bias=True)
+
+        self.use_slicing = False
+        self.slicing_length = 1
+
+    def set_flash_decoder(self):
+        self.decoder.set_flash_processor(FlashTripoSGAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError(
+                    "`fuse_qkv_projections()` is not supported for models having added KV projections."
+                )
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(
+            name: str,
+            module: torch.nn.Module,
+            processors: Dict[str, AttentionProcessor],
+        ):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
+    ):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(TripoSGAttnProcessor2_0())
+
+    def enable_slicing(self, slicing_length: int = 1) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+        self.slicing_length = slicing_length
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
+    def _sample_features(
+        self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
+    ):
+        """
+        Sample points from features of the input point cloud.
+
+        Args:
+            x (torch.Tensor): The input point cloud. shape: (B, N, C)
+            num_tokens (int, optional): The number of points to sample. Defaults to 2048.
+            seed (Optional[int], optional): The random seed. Defaults to None.
+        """
+        rng = np.random.default_rng(seed)
+        indices = rng.choice(
+            x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1]
+        )
+        selected_points = x[:, indices]
+
+        batch_size, num_points, num_channels = selected_points.shape
+        flattened_points = selected_points.view(batch_size * num_points, num_channels)
+        batch_indices = (
+            torch.arange(batch_size).to(x.device).repeat_interleave(num_points)
+        )
+
+        # fps sampling
+        sampling_ratio = 1.0 / 4
+        sampled_indices = fps(
+            flattened_points[:, :3],
+            batch_indices,
+            ratio=sampling_ratio,
+            random_start=self.training,
+        )
+        sampled_points = flattened_points[sampled_indices].view(
+            batch_size, -1, num_channels
+        )
+
+        return sampled_points
+
+    def _encode(
+        self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
+    ):
+        position_channels = self.config.in_channels
+        positions, features = x[..., :position_channels], x[..., position_channels:]
+        x_kv = torch.cat([self.embedder(positions), features], dim=-1)
+
+        sampled_x = self._sample_features(x, num_tokens, seed)
+        positions, features = (
+            sampled_x[..., :position_channels],
+            sampled_x[..., position_channels:],
+        )
+        x_q = torch.cat([self.embedder(positions), features], dim=-1)
+
+        x = self.encoder(x_q, x_kv)
+
+        x = self.quant(x)
+
+        return x
+
+    @apply_forward_hook
+    def encode(
+        self, x: torch.Tensor, return_dict: bool = True, **kwargs
+    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+        """
+        Encode a batch of point features into latents.
+        """
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [
+                self._encode(x_slice, **kwargs)
+                for x_slice in x.split(self.slicing_length)
+            ]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x, **kwargs)
+
+        posterior = DiagonalGaussianDistribution(h, feature_dim=-1)
+
+        if not return_dict:
+            return (posterior,)
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def _decode(
+        self,
+        z: torch.Tensor,
+        sampled_points: torch.Tensor,
+        num_chunks: int = 50000,
+        to_cpu: bool = False,
+        return_dict: bool = True,
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        xyz_samples = sampled_points
+
+        z = self.post_quant(z)
+        # Determine the model's operating dtype from proj_query weights.
+        # FrequencyPositionalEmbedding buffers may be float32, causing the embedder
+        # to upcast float16 xyz coords to float32, which then mismatches the float16
+        # proj_query weight in the decoder. Force embeddings back to model dtype.
+        model_dtype = self.decoder.proj_query.weight.dtype
+
+        num_points = xyz_samples.shape[1]
+        kv_cache = None
+        dec = []
+
+        for i in range(0, num_points, num_chunks):
+            queries = xyz_samples[:, i : i + num_chunks, :].to(z.device, dtype=z.dtype)
+            queries = self.embedder(queries).to(dtype=model_dtype)
+
+            z_, kv_cache = self.decoder(z, queries, kv_cache)
+            dec.append(z_ if not to_cpu else z_.cpu())
+
+        z = torch.cat(dec, dim=1)
+
+        if not return_dict:
+            return (z,)
+
+        return DecoderOutput(sample=z)
+
+    @apply_forward_hook
+    def decode(
+        self,
+        z: torch.Tensor,
+        sampled_points: torch.Tensor,
+        return_dict: bool = True,
+        **kwargs,
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [
+                self._decode(z_slice, p_slice, **kwargs).sample
+                for z_slice, p_slice in zip(
+                    z.split(self.slicing_length),
+                    sampled_points.split(self.slicing_length),
+                )
+            ]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z, sampled_points, **kwargs).sample
+
+        if not return_dict:
+            return (decoded,)
+        return DecoderOutput(sample=decoded)
+
+    def forward(self, x: torch.Tensor):
+        pass
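
With this patch applied, a full float16 decode should run without the dtype error. A hedged usage sketch follows; the import path and tensor shapes are assumptions based on this repo's layout, not a documented API:

import torch
from triposg.models.autoencoders.autoencoder_kl_triposg import TripoSGVAEModel

vae = TripoSGVAEModel().to("cuda", dtype=torch.float16)                      # default config from this file
latents = torch.randn(1, 2048, 64, device="cuda", dtype=torch.float16)      # (B, tokens, latent_channels)
query_xyz = torch.rand(1, 100_000, 3, device="cuda", dtype=torch.float16)   # xyz points to evaluate

with torch.no_grad():
    # _decode embeds query_xyz, casts the embedding back to float16, and
    # cross-attends against the decoded latents in chunks of num_chunks points.
    logits = vae.decode(latents, query_xyz).sample                           # (1, 100_000, 1)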