BiliSakura commited on
Commit
6a56bf0
·
verified ·
1 Parent(s): f55b7a3

Update all files for BitDance-ImageNet-diffusers

Browse files
BitDance_B_16x/transformer/model_parallel.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from functools import partial
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import functional as F
7
+ from torch.utils.checkpoint import checkpoint
8
+
9
+ from .diff_head_parallel import DiffHead
10
+ from .layers_parallel import TransformerBlock, get_2d_pos, precompute_freqs_cis_2d
11
+ from .qae import VQModel
12
+ from .utils import patchify_raster, unpatchify_raster, patchify_raster_2d
13
+
14
+
15
+
16
+ def get_model_args():
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument(
19
+ "--model", type=str, choices=list(BitDance_models.keys()), default="BitDance-L"
20
+ )
21
+ parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
22
+ parser.add_argument("--down-size", type=int, default=16, choices=[16])
23
+ parser.add_argument("--patch-size", type=int, default=1, choices=[1, 2, 4])
24
+ parser.add_argument("--num-classes", type=int, default=1000)
25
+ parser.add_argument("--cls-token-num", type=int, default=64)
26
+ parser.add_argument("--latent-dim", type=int, default=16)
27
+ parser.add_argument("--diff-batch-mul", type=int, default=4)
28
+ parser.add_argument("--grad-checkpointing", action="store_true")
29
+ parser.add_argument("--trained-vae", type=str, default="")
30
+ parser.add_argument("--drop-rate", type=float, default=0.0)
31
+ parser.add_argument("--perturb-schedule", type=str, default="constant")
32
+ parser.add_argument("--perturb-rate", type=float, default=0.0)
33
+ parser.add_argument("--perturb-rate-max", type=float, default=0.3)
34
+ parser.add_argument("--time-schedule", type=str, default='logit_normal')
35
+ parser.add_argument("--time-shift", type=float, default=1.)
36
+ parser.add_argument("--parallel-num", type=int, default=4)
37
+ parser.add_argument("--P-std", type=float, default=0.8)
38
+ parser.add_argument("--P-mean", type=float, default=-0.8)
39
+ parser.add_argument("--parallel-mode", type=str, default='patch', choices=['standard', 'patch'])
40
+ return parser
41
+
42
+
43
+ def create_model(args, device):
44
+ model = BitDance_models[args.model](
45
+ resolution=args.image_size,
46
+ down_size=args.down_size,
47
+ patch_size=args.patch_size,
48
+ latent_dim=args.latent_dim,
49
+ diff_batch_mul=args.diff_batch_mul,
50
+ cls_token_num=args.cls_token_num,
51
+ num_classes=args.num_classes,
52
+ grad_checkpointing=args.grad_checkpointing,
53
+ trained_vae=args.trained_vae,
54
+ drop_rate=args.drop_rate,
55
+ perturb_schedule=args.perturb_schedule,
56
+ perturb_rate=args.perturb_rate,
57
+ perturb_rate_max=args.perturb_rate_max,
58
+ time_schedule=args.time_schedule,
59
+ time_shift=args.time_shift,
60
+ parallel_num=args.parallel_num,
61
+ P_std=args.P_std,
62
+ P_mean=args.P_mean,
63
+ parallel_mode=args.parallel_mode,
64
+ ).to(device, memory_format=torch.channels_last)
65
+ return model
66
+
67
+ class MLPConnector(nn.Module):
68
+ def __init__(self, in_dim, dim, dropout_p=0.0):
69
+ super().__init__()
70
+ hidden_dim = int(dim * 1.5)
71
+ self.w1 = nn.Linear(in_dim, hidden_dim * 2, bias=True)
72
+ self.w2 = nn.Linear(hidden_dim, dim, bias=True)
73
+ self.ffn_dropout = nn.Dropout(dropout_p)
74
+
75
+ def forward(self, x):
76
+ h1, h2 = self.w1(x).chunk(2, dim=-1)
77
+ return self.ffn_dropout(self.w2(F.silu(h1) * h2))
78
+
79
+ def flip_tensor_elements_uniform_prob(tensor: torch.Tensor, p_max: float) -> torch.Tensor:
80
+ if not 0.0 <= p_max <= 1.0:
81
+ raise ValueError(f"p_max must be in [0.0, 1.0] range, but got: {p_max}")
82
+ r1 = torch.rand_like(tensor)
83
+ r2 = torch.rand_like(tensor)
84
+ flip_mask = r1 < p_max * r2
85
+ multiplier = torch.where(flip_mask, -1.0, 1.0)
86
+ multiplier = multiplier.to(tensor.dtype)
87
+ flipped_tensor = tensor * multiplier
88
+ return flipped_tensor
89
+
90
+ def get_block_causal_mask(num_tokens_total, num_tokens_causal, block_size):
91
+ assert (num_tokens_total - num_tokens_causal) % block_size == 0
92
+ attention_mask = torch.zeros(num_tokens_total, num_tokens_total)
93
+ causal_mask = torch.triu(torch.ones(num_tokens_total, num_tokens_total), diagonal=1)
94
+ attention_mask.masked_fill_(causal_mask.bool(), float('-inf'))
95
+
96
+ for i in range(num_tokens_causal, num_tokens_total, block_size):
97
+ start_idx = i
98
+ end_idx = i + block_size
99
+ attention_mask[start_idx:end_idx, start_idx:end_idx] = 0
100
+
101
+ return attention_mask
102
+
103
+ class BitDance(nn.Module):
104
+
105
+ def __init__(
106
+ self,
107
+ dim,
108
+ n_layer,
109
+ n_head,
110
+ diff_layers,
111
+ diff_dim,
112
+ diff_adanln_layers,
113
+ latent_dim,
114
+ down_size,
115
+ patch_size,
116
+ resolution,
117
+ diff_batch_mul,
118
+ grad_checkpointing=False,
119
+ cls_token_num=16,
120
+ num_classes: int = 1000,
121
+ class_dropout_prob: float = 0.1,
122
+ trained_vae: str = "",
123
+ drop_rate: float = 0.0,
124
+ perturb_schedule: str = "constant",
125
+ perturb_rate: float = 0.0,
126
+ perturb_rate_max: float = 0.3,
127
+ time_schedule: str = 'logit_normal',
128
+ time_shift: float = 1.,
129
+ parallel_num: int = 4,
130
+ P_std: float = 1.,
131
+ P_mean: float = 0.,
132
+ parallel_mode: str = 'standard',
133
+ ):
134
+ super().__init__()
135
+
136
+ self.n_layer = n_layer
137
+ self.resolution = resolution
138
+ self.down_size = down_size
139
+ self.patch_size = patch_size
140
+ self.num_classes = num_classes
141
+ self.cls_token_num = cls_token_num
142
+ self.class_dropout_prob = class_dropout_prob
143
+ self.latent_dim = latent_dim
144
+ self.trained_vae = trained_vae
145
+ self.perturb_schedule = perturb_schedule
146
+ self.perturb_rate = perturb_rate
147
+ self.perturb_rate_max = perturb_rate_max
148
+ self.parallel_num = parallel_num
149
+ self.parallel_mode = parallel_mode
150
+
151
+ # define the vae and mar model
152
+ ddconfig = {
153
+ "double_z": False,
154
+ "z_channels": latent_dim,
155
+ "in_channels": 3,
156
+ "out_ch": 3,
157
+ "ch": 256,
158
+ "ch_mult": [1,1,2,2,4],
159
+ "num_res_blocks": 4
160
+ }
161
+ num_codebooks = 4
162
+ # print(f"loading vae unexpected_keys: {unexpected_keys}")
163
+ self.vae = VQModel(ddconfig, num_codebooks)
164
+ self.grad_checkpointing = grad_checkpointing
165
+
166
+ self.cls_embedding = nn.Embedding(num_classes + 1, dim * self.cls_token_num)
167
+ self.query_token = nn.Parameter(torch.randn(1, self.parallel_num - 1, dim) * 0.02)
168
+ self.proj_in = MLPConnector(latent_dim * self.patch_size * self.patch_size, dim, drop_rate)
169
+ self.emb_norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)
170
+ self.h, self.w = resolution // (down_size * patch_size), resolution // (down_size * patch_size)
171
+ self.total_tokens = self.h * self.w + self.cls_token_num
172
+
173
+ self.layers = torch.nn.ModuleList()
174
+ for layer_id in range(n_layer):
175
+ self.layers.append(
176
+ TransformerBlock(
177
+ dim,
178
+ n_head,
179
+ resid_dropout_p=drop_rate,
180
+ )
181
+ )
182
+
183
+ self.norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)
184
+ self.pos_for_diff = nn.Embedding(self.h * self.w, dim)
185
+ self.head = DiffHead(
186
+ ch_target=latent_dim * self.patch_size * self.patch_size,
187
+ ch_cond=dim,
188
+ ch_latent=diff_dim,
189
+ depth_latent=diff_layers,
190
+ depth_adanln=diff_adanln_layers,
191
+ grad_checkpointing=grad_checkpointing,
192
+ time_shift=time_shift,
193
+ time_schedule=time_schedule,
194
+ parallel_num=parallel_num,
195
+ P_std=P_std,
196
+ P_mean=P_mean,
197
+ )
198
+ self.diff_batch_mul = diff_batch_mul
199
+
200
+ patch_2d_pos = get_2d_pos(resolution, int(down_size * patch_size))
201
+
202
+ freqs_cis = precompute_freqs_cis_2d(
203
+ patch_2d_pos,
204
+ dim // n_head,
205
+ 10000,
206
+ cls_token_num=self.cls_token_num + self.parallel_num - 1,
207
+ )
208
+
209
+ if self.parallel_mode == 'patch':
210
+ freqs_cis[-self.h * self.w:] = patchify_raster_2d(freqs_cis[-self.h * self.w:], int(self.parallel_num ** 0.5), self.h, self.w)
211
+
212
+ self.register_buffer("freqs_cis", freqs_cis[:-self.parallel_num], persistent=False)
213
+
214
+ attn_mask = get_block_causal_mask(self.h * self.w + self.cls_token_num -1, self.cls_token_num -1, self.parallel_num)
215
+ self.register_buffer("attn_mask", attn_mask.unsqueeze(0).unsqueeze(0), persistent=False)
216
+ self.freeze_vae()
217
+
218
+ self.initialize_weights()
219
+
220
+ def load_vae_weight(self):
221
+ state = torch.load(
222
+ self.trained_vae,
223
+ map_location="cpu",
224
+ )
225
+ missing_keys, unexpected_keys = self.vae.load_state_dict(state["state_dict"], strict=False)
226
+ print(f"loading vae, missing_keys: {missing_keys}")
227
+ del state
228
+
229
+ def non_decay_keys(self):
230
+ return ["proj_in", "cls_embedding", "query_token"]
231
+
232
+ def freeze_module(self, module: nn.Module):
233
+ for param in module.parameters():
234
+ param.requires_grad = False
235
+
236
+ def freeze_vae(self):
237
+ self.freeze_module(self.vae)
238
+ self.vae.eval()
239
+
240
+ def initialize_weights(self):
241
+ # Initialize nn.Linear and nn.Embedding
242
+ self.apply(self.__init_weights)
243
+ self.head.initialize_weights()
244
+ # self.vae.initialize_weights()
245
+
246
+ def __init_weights(self, module):
247
+ std = 0.02
248
+ if isinstance(module, nn.Linear):
249
+ module.weight.data.normal_(mean=0.0, std=std)
250
+ if module.bias is not None:
251
+ module.bias.data.zero_()
252
+ elif isinstance(module, nn.Embedding):
253
+ module.weight.data.normal_(mean=0.0, std=std)
254
+
255
+ def drop_label(self, class_id):
256
+ if self.class_dropout_prob > 0.0 and self.training:
257
+ is_drop = (
258
+ torch.rand(class_id.shape, device=class_id.device)
259
+ < self.class_dropout_prob
260
+ )
261
+ class_id = torch.where(is_drop, self.num_classes, class_id)
262
+ return class_id
263
+
264
+ def patchify(self, x):
265
+ bsz, c, h, w = x.shape
266
+ p = self.patch_size
267
+ h_, w_ = h // p, w // p
268
+
269
+ x = x.reshape(bsz, c, h_, p, w_, p)
270
+ x = torch.einsum('nchpwq->nhwcpq', x)
271
+ x = x.reshape(bsz, h_ * w_, c * p ** 2)
272
+ return x # [n, l, d]
273
+
274
+ def unpatchify(self, x):
275
+ bsz = x.shape[0]
276
+ p = self.patch_size
277
+ c = self.latent_dim
278
+ h_, w_ = self.h, self.w
279
+
280
+ x = x.reshape(bsz, h_, w_, c, p, p)
281
+ x = torch.einsum('nhwcpq->nchpwq', x)
282
+ x = x.reshape(bsz, c, h_ * p, w_ * p)
283
+ return x # [n, c, h, w]
284
+
285
+ def forward(
286
+ self,
287
+ images,
288
+ class_id,
289
+ cached=False
290
+ ):
291
+ if cached:
292
+ vae_latent = images
293
+ else:
294
+ vae_latent, _, _, _ = self.vae.encode(images) # b c h w
295
+
296
+ if self.parallel_mode == 'standard':
297
+ vae_latent = self.patchify(vae_latent)
298
+ elif self.parallel_mode == 'patch':
299
+ vae_latent = patchify_raster(vae_latent, int(self.parallel_num ** 0.5))
300
+ else:
301
+ raise NotImplementedError(f"unknown parallel_mode {self.parallel_mode}")
302
+ x = vae_latent.clone().detach()
303
+ if self.training:
304
+ if self.perturb_schedule =="constant":
305
+ x = flip_tensor_elements_uniform_prob(x, self.perturb_rate)
306
+ else:
307
+ raise NotImplementedError(f"unknown perturb_schedule {self.perturb_schedule}")
308
+ x = self.proj_in(x[:, :-self.parallel_num, :])
309
+ class_id = self.drop_label(class_id)
310
+ bsz = x.shape[0]
311
+ c = self.cls_embedding(class_id).view(bsz, self.cls_token_num, -1)
312
+ query_token = self.query_token.repeat(bsz, 1, 1)
313
+ x = torch.cat([c, query_token, x], dim=1)
314
+ x = self.emb_norm(x)
315
+
316
+ if self.grad_checkpointing and self.training:
317
+ for layer in self.layers:
318
+ block = partial(layer.forward, mask=self.attn_mask, freqs_cis=self.freqs_cis)
319
+ x = checkpoint(block, x, use_reentrant=False)
320
+ else:
321
+ for layer in self.layers:
322
+ x = layer(x, self.attn_mask, self.freqs_cis)
323
+
324
+ x = x[:, -self.h * self.w :, :]
325
+ x = self.norm(x)
326
+ x = x + self.pos_for_diff.weight
327
+
328
+ target = vae_latent.clone().detach()
329
+ x = x.view(-1, self.parallel_num, x.shape[-1])
330
+ target = target.view(-1, self.parallel_num, target.shape[-1])
331
+
332
+ x = x.repeat(self.diff_batch_mul, 1, 1)
333
+ target = target.repeat(self.diff_batch_mul, 1, 1)
334
+ loss = self.head(target, x)
335
+
336
+ return loss
337
+
338
+ def enable_kv_cache(self, bsz):
339
+ for layer in self.layers:
340
+ layer.attention.enable_kv_cache(bsz, self.total_tokens)
341
+
342
+ @torch.compile()
343
+ def forward_model(self, x, mask, start_pos, end_pos):
344
+ x = self.emb_norm(x)
345
+ for layer in self.layers:
346
+ x = layer.forward_onestep(
347
+ x, mask, self.freqs_cis[start_pos:end_pos,], start_pos, end_pos
348
+ )
349
+ x = self.norm(x)
350
+ return x
351
+
352
+ def head_sample(self, x, diff_pos, sample_steps, cfg_scale, cfg_schedule="linear"):
353
+ x = x + self.pos_for_diff.weight[diff_pos*self.parallel_num : (diff_pos+1)*self.parallel_num, :]
354
+ # x = x.view(-1, x.shape[-1])
355
+ seq_len = self.h * self.w // self.parallel_num
356
+ if cfg_scale > 1.0:
357
+ if cfg_schedule == "constant":
358
+ cfg_iter = cfg_scale
359
+ elif cfg_schedule == "linear":
360
+ start = 1.0
361
+ cfg_iter = start + (cfg_scale - start) * diff_pos / seq_len
362
+ else:
363
+ raise NotImplementedError(f"unknown cfg_schedule {cfg_schedule}")
364
+ else:
365
+ cfg_iter = 1.0
366
+ pred = self.head.sample(x, num_sampling_steps=sample_steps, cfg=cfg_iter)
367
+ # Important: LFQ here, sign the prediction
368
+ pred = torch.sign(pred)
369
+ return pred
370
+
371
+ @torch.no_grad()
372
+ def sample(self, cond, sample_steps, cfg_scale=1.0, cfg_schedule="linear", chunk_size=0):
373
+ self.eval()
374
+ if cfg_scale > 1.0:
375
+ cond_null = torch.ones_like(cond) * self.num_classes
376
+ cond_combined = torch.cat([cond, cond_null])
377
+ else:
378
+ cond_combined = cond
379
+ bsz = cond_combined.shape[0]
380
+ act_bsz = bsz // 2 if cfg_scale > 1.0 else bsz
381
+ self.enable_kv_cache(bsz)
382
+
383
+ c = self.cls_embedding(cond_combined).view(bsz, self.cls_token_num, -1)
384
+ last_pred = None
385
+ all_preds = []
386
+ for i in range(self.h * self.w // self.parallel_num):
387
+ if i == 0:
388
+ x = self.forward_model(torch.cat([c, self.query_token.repeat(bsz, 1, 1)], dim=1), self.attn_mask[:, :, :self.cls_token_num + self.parallel_num - 1, :self.cls_token_num + self.parallel_num - 1], 0, self.cls_token_num + self.parallel_num - 1)
389
+ else:
390
+ x = self.proj_in(last_pred)
391
+ start_pos = self.parallel_num * (i-1) + self.cls_token_num + self.parallel_num - 1
392
+ x = self.forward_model(
393
+ x,
394
+ self.attn_mask[:, :, start_pos : start_pos + self.parallel_num, : start_pos + self.parallel_num],
395
+ start_pos,
396
+ start_pos + self.parallel_num
397
+ )
398
+
399
+ last_pred = self.head_sample(
400
+ x[:, -self.parallel_num:, :],
401
+ i,
402
+ sample_steps,
403
+ cfg_scale,
404
+ cfg_schedule,
405
+ )
406
+ all_preds.append(last_pred)
407
+
408
+ x = torch.cat(all_preds, dim=-2)[:act_bsz]
409
+ if x.dim() == 3: #b n c -> b c h w
410
+ if self.parallel_mode == 'standard':
411
+ x = self.unpatchify(x)
412
+ elif self.parallel_mode == 'patch':
413
+ x = unpatchify_raster(x, int(self.parallel_num ** 0.5), (self.h, self.w))
414
+ # recon = self.vae.decode(x)
415
+ if chunk_size > 0:
416
+ recon = self.decode_in_chunks(x, chunk_size)
417
+ else:
418
+ recon = self.vae.decode(x)
419
+ return recon
420
+
421
+ def decode_in_chunks(self, latent_tensor, chunk_size=64):
422
+ total_bsz = latent_tensor.shape[0]
423
+ recon_chunks_on_cpu = []
424
+ with torch.no_grad():
425
+ for i in range(0, total_bsz, chunk_size):
426
+ end_idx = min(i + chunk_size, total_bsz)
427
+ latent_chunk = latent_tensor[i:end_idx]
428
+ recon_chunk = self.vae.decode(latent_chunk)
429
+ recon_chunks_on_cpu.append(recon_chunk.cpu())
430
+ return torch.cat(recon_chunks_on_cpu, dim=0)
431
+
432
+ def get_fsdp_wrap_module_list(self):
433
+ return list(self.layers)
434
+
435
+ def BitDance_H(**kwargs):
436
+ return BitDance(
437
+ n_layer=40,
438
+ n_head=20,
439
+ dim=1280,
440
+ diff_layers=12,
441
+ diff_dim=1280,
442
+ diff_adanln_layers=3,
443
+ **kwargs,
444
+ )
445
+
446
+
447
+ def BitDance_L(**kwargs):
448
+ return BitDance(
449
+ n_layer=32,
450
+ n_head=16,
451
+ dim=1024,
452
+ diff_layers=8,
453
+ diff_dim=1024,
454
+ diff_adanln_layers=2,
455
+ **kwargs,
456
+ )
457
+
458
+
459
+ def BitDance_B(**kwargs):
460
+ return BitDance(
461
+ n_layer=24,
462
+ n_head=12,
463
+ dim=768,
464
+ diff_layers=6,
465
+ diff_dim=768,
466
+ diff_adanln_layers=2,
467
+ **kwargs,
468
+ )
469
+
470
+
471
+ BitDance_models = {
472
+ "BitDance-B": BitDance_B,
473
+ "BitDance-L": BitDance_L,
474
+ "BitDance-H": BitDance_H,
475
+ }