BiliSakura committed on
Commit e13c83d · verified · 1 Parent(s): 5344135

Update all files for BitDance-ImageNet-diffusers

Files changed (1)
  1. BitDance_B_1x/transformer/model.py +432 -0
BitDance_B_1x/transformer/model.py ADDED
@@ -0,0 +1,432 @@
import argparse
from functools import partial

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint

from .diff_head import DiffHead
from .layers import TransformerBlock, get_2d_pos, precompute_freqs_cis_2d
from .qae import VQModel


def get_model_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, choices=list(BitDance_models.keys()), default="BitDance-L"
    )
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--down-size", type=int, default=16, choices=[16])
    parser.add_argument("--patch-size", type=int, default=1, choices=[1, 2, 4])
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cls-token-num", type=int, default=64)
    parser.add_argument("--latent-dim", type=int, default=16)
    parser.add_argument("--diff-batch-mul", type=int, default=4)
    parser.add_argument("--grad-checkpointing", action="store_true")
    parser.add_argument("--trained-vae", type=str, default="")
    parser.add_argument("--drop-rate", type=float, default=0.0)
    parser.add_argument("--perturb-schedule", type=str, default="constant")
    parser.add_argument("--perturb-rate", type=float, default=0.0)
    parser.add_argument("--perturb-rate-max", type=float, default=0.3)
    parser.add_argument("--time-schedule", type=str, default='logit_normal')
    parser.add_argument("--time-shift", type=float, default=1.)
    parser.add_argument("--P-std", type=float, default=1.)
    parser.add_argument("--P-mean", type=float, default=0.)
    return parser


def create_model(args, device):
    model = BitDance_models[args.model](
        resolution=args.image_size,
        down_size=args.down_size,
        patch_size=args.patch_size,
        latent_dim=args.latent_dim,
        diff_batch_mul=args.diff_batch_mul,
        cls_token_num=args.cls_token_num,
        num_classes=args.num_classes,
        grad_checkpointing=args.grad_checkpointing,
        trained_vae=args.trained_vae,
        drop_rate=args.drop_rate,
        perturb_schedule=args.perturb_schedule,
        perturb_rate=args.perturb_rate,
        perturb_rate_max=args.perturb_rate_max,
        time_schedule=args.time_schedule,
        time_shift=args.time_shift,
        P_std=args.P_std,
        P_mean=args.P_mean,
    ).to(device, memory_format=torch.channels_last)
    return model


class MLPConnector(nn.Module):
    # Gated (SwiGLU-style) MLP, used to project patchified VAE latents into the transformer width.
    def __init__(self, in_dim, dim, dropout_p=0.0):
        super().__init__()
        hidden_dim = int(dim * 1.5)
        self.w1 = nn.Linear(in_dim, hidden_dim * 2, bias=True)
        self.w2 = nn.Linear(hidden_dim, dim, bias=True)
        self.ffn_dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        h1, h2 = self.w1(x).chunk(2, dim=-1)
        return self.ffn_dropout(self.w2(F.silu(h1) * h2))


def flip_tensor_elements_uniform_prob(tensor: torch.Tensor, p_max: float) -> torch.Tensor:
    # Randomly flip the sign of individual elements; each element's flip
    # probability is itself drawn uniformly from [0, p_max].
    if not 0.0 <= p_max <= 1.0:
        raise ValueError(f"p_max must be in [0.0, 1.0] range, but got: {p_max}")
    r1 = torch.rand_like(tensor)
    r2 = torch.rand_like(tensor)
    flip_mask = r1 < p_max * r2
    multiplier = torch.where(flip_mask, -1.0, 1.0)
    multiplier = multiplier.to(tensor.dtype)
    flipped_tensor = tensor * multiplier
    return flipped_tensor


class BitDance(nn.Module):

    def __init__(
        self,
        dim,
        n_layer,
        n_head,
        diff_layers,
        diff_dim,
        diff_adanln_layers,
        latent_dim,
        down_size,
        patch_size,
        resolution,
        diff_batch_mul,
        grad_checkpointing=False,
        cls_token_num=16,
        num_classes: int = 1000,
        class_dropout_prob: float = 0.1,
        trained_vae: str = "",
        drop_rate: float = 0.0,
        perturb_schedule: str = "constant",
        perturb_rate: float = 0.0,
        perturb_rate_max: float = 0.3,
        time_schedule: str = 'logit_normal',
        time_shift: float = 1.,
        P_std: float = 1.,
        P_mean: float = 0.,
    ):
        super().__init__()

        self.n_layer = n_layer
        self.resolution = resolution
        self.down_size = down_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.cls_token_num = cls_token_num
        self.class_dropout_prob = class_dropout_prob
        self.latent_dim = latent_dim
        self.trained_vae = trained_vae
        self.perturb_schedule = perturb_schedule
        self.perturb_rate = perturb_rate
        self.perturb_rate_max = perturb_rate_max

        # define the vae and mar model
        ddconfig = {
            "double_z": False,
            "z_channels": latent_dim,
            "in_channels": 3,
            "out_ch": 3,
            "ch": 256,
            "ch_mult": [1, 1, 2, 2, 4],
            "num_res_blocks": 4,
        }
        num_codebooks = 4
        # print(f"loading vae unexpected_keys: {unexpected_keys}")
        self.vae = VQModel(ddconfig, num_codebooks)
        self.grad_checkpointing = grad_checkpointing

        self.cls_embedding = nn.Embedding(num_classes + 1, dim * self.cls_token_num)
        self.proj_in = MLPConnector(latent_dim * self.patch_size * self.patch_size, dim, drop_rate)
        self.emb_norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)
        self.h, self.w = resolution // (down_size * patch_size), resolution // (down_size * patch_size)
        self.total_tokens = self.h * self.w + self.cls_token_num

        self.layers = torch.nn.ModuleList()
        for layer_id in range(n_layer):
            self.layers.append(
                TransformerBlock(
                    dim,
                    n_head,
                    resid_dropout_p=drop_rate,
                    causal=True,
                )
            )

        self.norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)
        self.pos_for_diff = nn.Embedding(self.h * self.w, dim)
        self.head = DiffHead(
            ch_target=latent_dim * self.patch_size * self.patch_size,
            ch_cond=dim,
            ch_latent=diff_dim,
            depth_latent=diff_layers,
            depth_adanln=diff_adanln_layers,
            grad_checkpointing=grad_checkpointing,
            time_shift=time_shift,
            time_schedule=time_schedule,
            P_std=P_std,
            P_mean=P_mean,
        )
        self.diff_batch_mul = diff_batch_mul

        patch_2d_pos = get_2d_pos(resolution, int(down_size * patch_size))

        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis_2d(
                patch_2d_pos,
                dim // n_head,
                10000,
                cls_token_num=self.cls_token_num,
            )[:-1],
            persistent=False,
        )
        self.freeze_vae()

        self.initialize_weights()

    def load_vae_weight(self):
        state = torch.load(
            self.trained_vae,
            map_location="cpu",
        )
        missing_keys, unexpected_keys = self.vae.load_state_dict(state["state_dict"], strict=False)
        print(f"loading vae, missing_keys: {missing_keys}")
        del state

    def non_decay_keys(self):
        return ["proj_in", "cls_embedding"]

    def freeze_module(self, module: nn.Module):
        for param in module.parameters():
            param.requires_grad = False

    def freeze_vae(self):
        self.freeze_module(self.vae)
        self.vae.eval()

    def initialize_weights(self):
        # Initialize nn.Linear and nn.Embedding
        self.apply(self.__init_weights)
        self.head.initialize_weights()
        # self.vae.initialize_weights()

    def __init_weights(self, module):
        std = 0.02
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def drop_label(self, class_id):
        if self.class_dropout_prob > 0.0 and self.training:
            is_drop = (
                torch.rand(class_id.shape, device=class_id.device)
                < self.class_dropout_prob
            )
            class_id = torch.where(is_drop, self.num_classes, class_id)
        return class_id

    def patchify(self, x):
        bsz, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(bsz, c, h_, p, w_, p)
        x = torch.einsum('nchpwq->nhwcpq', x)
        x = x.reshape(bsz, h_ * w_, c * p ** 2)
        return x  # [n, l, d]

    def unpatchify(self, x):
        bsz = x.shape[0]
        p = self.patch_size
        c = self.latent_dim
        h_, w_ = self.h, self.w

        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum('nhwcpq->nchpwq', x)
        x = x.reshape(bsz, c, h_ * p, w_ * p)
        return x  # [n, c, h, w]

    def forward(
        self,
        images,
        class_id,
        cached=False
    ):
        if cached:
            vae_latent = images
        else:
            vae_latent, _, _, _ = self.vae.encode(images)  # b c h w

        vae_latent = self.patchify(vae_latent)
        x = vae_latent.clone().detach()
        if self.training:
            if self.perturb_schedule == "constant":
                # Randomly flip signs of the binary latents as an input perturbation.
                x = flip_tensor_elements_uniform_prob(x, self.perturb_rate)
            else:
                raise NotImplementedError(f"unknown perturb_schedule {self.perturb_schedule}")
        # Shift by one: the last latent token is only a prediction target, never an input.
        x = self.proj_in(x[:, :-1, :])
        class_id = self.drop_label(class_id)
        bsz = x.shape[0]
        c = self.cls_embedding(class_id).view(bsz, self.cls_token_num, -1)
        x = torch.cat([c, x], dim=1)
        x = self.emb_norm(x)

        if self.grad_checkpointing and self.training:
            for layer in self.layers:
                block = partial(layer.forward, freqs_cis=self.freqs_cis)
                x = checkpoint(block, x, use_reentrant=False)
        else:
            for layer in self.layers:
                x = layer(x, self.freqs_cis)

        x = x[:, -self.h * self.w:, :]
        x = self.norm(x)
        x = x + self.pos_for_diff.weight

        target = vae_latent.clone().detach()
        x = x.view(-1, x.shape[-1])
        target = target.view(-1, target.shape[-1])

        # Repeat conditions and targets so the diffusion head is trained
        # diff_batch_mul times per token position.
        x = x.repeat(self.diff_batch_mul, 1)
        target = target.repeat(self.diff_batch_mul, 1)
        loss = self.head(target, x)

        return loss

    def enable_kv_cache(self, bsz):
        for layer in self.layers:
            layer.attention.enable_kv_cache(bsz, self.total_tokens)

    @torch.compile()
    def forward_model(self, x, start_pos, end_pos):
        x = self.emb_norm(x)
        for layer in self.layers:
            x = layer.forward_onestep(
                x, self.freqs_cis[start_pos:end_pos], start_pos, end_pos
            )
        x = self.norm(x)
        return x

    def head_sample(self, x, diff_pos, sample_steps, cfg_scale, cfg_schedule="linear"):
        x = x + self.pos_for_diff.weight[diff_pos : diff_pos + 1, :]
        x = x.view(-1, x.shape[-1])
        seq_len = self.h * self.w
        if cfg_scale > 1.0:
            if cfg_schedule == "constant":
                cfg_iter = cfg_scale
            elif cfg_schedule == "linear":
                start = 1.0
                cfg_iter = start + (cfg_scale - start) * diff_pos / seq_len
            else:
                raise NotImplementedError(f"unknown cfg_schedule {cfg_schedule}")
        else:
            cfg_iter = 1.0
        pred = self.head.sample(x, num_sampling_steps=sample_steps, cfg=cfg_iter)
        pred = pred.view(-1, 1, pred.shape[-1])
        # Important: LFQ here, sign the prediction
        pred = torch.sign(pred)
        return pred

    @torch.no_grad()
    def sample(self, cond, sample_steps, cfg_scale=1.0, cfg_schedule="linear", chunk_size=0):
        self.eval()
        if cfg_scale > 1.0:
            # Duplicate the batch with the null class for classifier-free guidance.
            cond_null = torch.ones_like(cond) * self.num_classes
            cond_combined = torch.cat([cond, cond_null])
        else:
            cond_combined = cond
        bsz = cond_combined.shape[0]
        act_bsz = bsz // 2 if cfg_scale > 1.0 else bsz
        self.enable_kv_cache(bsz)

        c = self.cls_embedding(cond_combined).view(bsz, self.cls_token_num, -1)
        last_pred = None
        all_preds = []
        for i in range(self.h * self.w):
            if i == 0:
                x = self.forward_model(c, 0, self.cls_token_num)
            else:
                x = self.proj_in(last_pred)
                x = self.forward_model(
                    x, i + self.cls_token_num - 1, i + self.cls_token_num
                )
            last_pred = self.head_sample(
                x[:, -1:, :],
                i,
                sample_steps,
                cfg_scale,
                cfg_schedule,
            )
            all_preds.append(last_pred)

        x = torch.cat(all_preds, dim=-2)[:act_bsz]
        if x.dim() == 3:  # b n c -> b c h w
            x = self.unpatchify(x)
        if chunk_size > 0:
            recon = self.decode_in_chunks(x, chunk_size)
        else:
            recon = self.vae.decode(x)
        return recon

    def decode_in_chunks(self, latent_tensor, chunk_size=64):
        total_bsz = latent_tensor.shape[0]
        recon_chunks_on_cpu = []
        with torch.no_grad():
            for i in range(0, total_bsz, chunk_size):
                end_idx = min(i + chunk_size, total_bsz)
                latent_chunk = latent_tensor[i:end_idx]
                recon_chunk = self.vae.decode(latent_chunk)
                recon_chunks_on_cpu.append(recon_chunk.cpu())
        return torch.cat(recon_chunks_on_cpu, dim=0)

    def get_fsdp_wrap_module_list(self):
        return list(self.layers)


def BitDance_H(**kwargs):
    return BitDance(
        n_layer=40,
        n_head=20,
        dim=1280,
        diff_layers=12,
        diff_dim=1280,
        diff_adanln_layers=3,
        **kwargs,
    )


def BitDance_L(**kwargs):
    return BitDance(
        n_layer=32,
        n_head=16,
        dim=1024,
        diff_layers=8,
        diff_dim=1024,
        diff_adanln_layers=2,
        **kwargs,
    )


def BitDance_B(**kwargs):
    return BitDance(
        n_layer=24,
        n_head=12,
        dim=768,
        diff_layers=6,
        diff_dim=768,
        diff_adanln_layers=2,
        **kwargs,
    )


BitDance_models = {
    "BitDance-B": BitDance_B,
    "BitDance-L": BitDance_L,
    "BitDance-H": BitDance_H,
}
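
For orientation, the sketch below shows one plausible way to wire the pieces in this file together: parse the CLI defaults from get_model_args(), build a model via create_model(), load the frozen VQ VAE weights, and draw class-conditional samples. It is not part of the commit; the import path, checkpoint path, class labels, step count, and CFG scale are illustrative assumptions, and the companion modules (diff_head, layers, qae) plus a VAE checkpoint containing a "state_dict" entry are assumed to be available.

# Minimal usage sketch (assumptions noted above; not shipped with this commit).
import torch

# Import path assumes the repository is used as a package; adjust to your layout.
from BitDance_B_1x.transformer.model import create_model, get_model_args


def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Reuse the defaults defined in get_model_args(); override the model size
    # and point --trained-vae at a placeholder checkpoint path.
    args = get_model_args().parse_args(
        ["--model", "BitDance-B", "--trained-vae", "vae.ckpt"]
    )

    model = create_model(args, device)
    model.load_vae_weight()  # loads args.trained_vae into the frozen VQ VAE

    # Class-conditional sampling: one image per (arbitrary) ImageNet label id.
    labels = torch.tensor([207, 360, 387], device=device)
    images = model.sample(
        labels,
        sample_steps=50,        # diffusion-head steps per token (illustrative)
        cfg_scale=4.0,          # >1.0 doubles the batch for classifier-free guidance
        cfg_schedule="linear",
    )
    print(images.shape)         # expected roughly (3, 3, args.image_size, args.image_size)


if __name__ == "__main__":
    main()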