| import argparse |
| from functools import partial |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn import functional as F |
| from torch.utils.checkpoint import checkpoint |
|
|
| from .diff_head import DiffHead |
| from .layers import TransformerBlock, get_2d_pos, precompute_freqs_cis_2d |
| from .qae import VQModel |
|
|
def get_model_args():
    """Create the argument parser holding every BitDance model option.

    Returns:
        argparse.ArgumentParser: a parser with all model flags registered.
        Parsing is left to the caller.
    """
    # One (flag, add_argument options) entry per option, listed in the
    # order they are registered so the generated help text keeps its layout.
    option_table = (
        (
            "--model",
            dict(type=str, choices=list(BitDance_models.keys()), default="BitDance-L"),
        ),
        ("--image-size", dict(type=int, choices=[256, 512], default=256)),
        ("--down-size", dict(type=int, default=16, choices=[16])),
        ("--patch-size", dict(type=int, default=1, choices=[1, 2, 4])),
        ("--num-classes", dict(type=int, default=1000)),
        ("--cls-token-num", dict(type=int, default=64)),
        ("--latent-dim", dict(type=int, default=16)),
        ("--diff-batch-mul", dict(type=int, default=4)),
        ("--grad-checkpointing", dict(action="store_true")),
        ("--trained-vae", dict(type=str, default="")),
        ("--drop-rate", dict(type=float, default=0.0)),
        ("--perturb-schedule", dict(type=str, default="constant")),
        ("--perturb-rate", dict(type=float, default=0.0)),
        ("--perturb-rate-max", dict(type=float, default=0.3)),
        ("--time-schedule", dict(type=str, default='logit_normal')),
        ("--time-shift", dict(type=float, default=1.)),
        ("--P-std", dict(type=float, default=1.)),
        ("--P-mean", dict(type=float, default=0.)),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser
|
|
|
|
def create_model(args, device):
    """Instantiate the model selected by ``args.model`` and move it to ``device``.

    Args:
        args: namespace carrying the attributes registered by the model
            argument parser (``model``, ``image_size``, ``down_size``, ...).
        device: target device handed to ``Module.to``.

    Returns:
        The constructed model, placed on ``device`` with
        ``torch.channels_last`` memory format.
    """
    factory = BitDance_models[args.model]

    # ``image_size`` is the only option whose constructor keyword differs
    # from its argparse name; everything else is forwarded one-to-one.
    model_kwargs = {
        "resolution": args.image_size,
        "down_size": args.down_size,
        "patch_size": args.patch_size,
        "latent_dim": args.latent_dim,
        "diff_batch_mul": args.diff_batch_mul,
        "cls_token_num": args.cls_token_num,
        "num_classes": args.num_classes,
        "grad_checkpointing": args.grad_checkpointing,
        "trained_vae": args.trained_vae,
        "drop_rate": args.drop_rate,
        "perturb_schedule": args.perturb_schedule,
        "perturb_rate": args.perturb_rate,
        "perturb_rate_max": args.perturb_rate_max,
        "time_schedule": args.time_schedule,
        "time_shift": args.time_shift,
        "P_std": args.P_std,
        "P_mean": args.P_mean,
    }

    network = factory(**model_kwargs)
    return network.to(device, memory_format=torch.channels_last)
|
|
class MLPConnector(nn.Module):
    """SiLU-gated two-layer MLP mapping ``in_dim`` features to ``dim``.

    ``w1`` emits two halves of width ``int(dim * 1.5)``.  The first half is
    passed through SiLU and multiplied with the second half, ``w2`` projects
    the product to ``dim``, and dropout is applied to the result.
    """

    def __init__(self, in_dim, dim, dropout_p=0.0):
        super().__init__()
        gate_width = int(dim * 1.5)
        # One fused projection yields both the gate half and the value half.
        self.w1 = nn.Linear(in_dim, 2 * gate_width, bias=True)
        self.w2 = nn.Linear(gate_width, dim, bias=True)
        self.ffn_dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Apply the gated MLP along the last dimension of ``x``."""
        fused = self.w1(x)
        gate, value = torch.chunk(fused, 2, dim=-1)
        projected = self.w2(F.silu(gate) * value)
        return self.ffn_dropout(projected)
|
|
def flip_tensor_elements_uniform_prob(tensor: torch.Tensor, p_max: float) -> torch.Tensor:
    """Randomly negate elements of ``tensor``.

    Two independent uniform samples ``u`` and ``v`` are drawn per element;
    the element is negated when ``u < p_max * v``.

    Args:
        tensor: input tensor; the result has the same shape and dtype.
        p_max: scale applied to the second uniform sample, in ``[0.0, 1.0]``.

    Returns:
        A new tensor equal to ``tensor`` with the selected elements negated.

    Raises:
        ValueError: if ``p_max`` lies outside ``[0.0, 1.0]``.
    """
    in_range = p_max >= 0.0 and p_max <= 1.0
    if not in_range:
        raise ValueError(f"p_max must be in [0.0, 1.0] range, but got: {p_max}")

    # Draw order matters for reproducibility under a fixed seed.
    draw = torch.rand_like(tensor)
    threshold_noise = torch.rand_like(tensor)
    should_flip = draw < p_max * threshold_noise

    signs = torch.where(should_flip, -1.0, 1.0).to(tensor.dtype)
    return tensor * signs
|
|
class BitDance(nn.Module):
    """Class-conditional autoregressive transformer over patchified VAE latents.

    Class tokens and projected latent patches are run through a stack of
    causal ``TransformerBlock`` layers.  The resulting hidden states condition
    a ``DiffHead`` that is trained against the latent patch at each position
    and sampled from at inference.  The ``VQModel`` VAE is frozen; it only
    encodes images for training and decodes sampled latents.
    """

    def __init__(
        self,
        dim,
        n_layer,
        n_head,
        diff_layers,
        diff_dim,
        diff_adanln_layers,
        latent_dim,
        down_size,
        patch_size,
        resolution,
        diff_batch_mul,
        grad_checkpointing=False,
        cls_token_num=16,
        num_classes: int = 1000,
        class_dropout_prob: float = 0.1,
        trained_vae: str = "",
        drop_rate: float = 0.0,
        perturb_schedule: str = "constant",
        perturb_rate: float = 0.0,
        perturb_rate_max: float = 0.3,
        time_schedule: str = 'logit_normal',
        time_shift: float = 1.,
        P_std: float = 1.,
        P_mean: float = 0.,
    ):
        """Build the VAE, the transformer trunk and the diffusion head.

        Args:
            dim: transformer width.
            n_layer: number of transformer blocks.
            n_head: attention heads per block.
            diff_layers, diff_dim, diff_adanln_layers: depth/width settings
                forwarded to ``DiffHead``.
            latent_dim: channel count of the VAE latent.
            down_size: spatial down-sampling factor between image and latent.
            patch_size: side length of the latent patch forming one token.
            resolution: input image side length in pixels.
            diff_batch_mul: number of times each (condition, target) pair is
                repeated before being handed to the diffusion head.
            grad_checkpointing: checkpoint each transformer block in training.
            cls_token_num: number of class-conditioning tokens prepended.
            num_classes: number of real classes; index ``num_classes`` is the
                extra "dropped label" embedding.
            class_dropout_prob: probability of replacing a label with the
                dropped-label index during training.
            trained_vae: checkpoint path read by ``load_vae_weight``.
            drop_rate: dropout probability for ``proj_in`` and the blocks.
            perturb_schedule: only ``"constant"`` is implemented in ``forward``.
            perturb_rate: ``p_max`` passed to the sign-flip perturbation.
            perturb_rate_max: stored on the module; not read in this class.
            time_schedule, time_shift, P_std, P_mean: forwarded to ``DiffHead``.
        """
        super().__init__()

        self.n_layer = n_layer
        self.resolution = resolution
        self.down_size = down_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.cls_token_num = cls_token_num
        self.class_dropout_prob = class_dropout_prob
        self.latent_dim = latent_dim
        self.trained_vae = trained_vae
        self.perturb_schedule = perturb_schedule
        self.perturb_rate = perturb_rate
        self.perturb_rate_max = perturb_rate_max

        ddconfig = {
            "double_z": False,
            "z_channels": latent_dim,
            "in_channels": 3,
            "out_ch": 3,
            "ch": 256,
            "ch_mult": [1, 1, 2, 2, 4],
            "num_res_blocks": 4,
        }
        num_codebooks = 4

        self.vae = VQModel(ddconfig, num_codebooks)
        self.grad_checkpointing = grad_checkpointing

        # One extra embedding row (index ``num_classes``) is the dropped label.
        self.cls_embedding = nn.Embedding(num_classes + 1, dim * self.cls_token_num)
        self.proj_in = MLPConnector(latent_dim * self.patch_size * self.patch_size, dim, drop_rate)
        self.emb_norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)
        tokens_per_side = resolution // (down_size * patch_size)
        self.h, self.w = tokens_per_side, tokens_per_side
        self.total_tokens = self.h * self.w + self.cls_token_num

        self.layers = torch.nn.ModuleList(
            [
                TransformerBlock(
                    dim,
                    n_head,
                    resid_dropout_p=drop_rate,
                    causal=True,
                )
                for _ in range(n_layer)
            ]
        )

        self.norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)
        self.pos_for_diff = nn.Embedding(self.h * self.w, dim)
        self.head = DiffHead(
            ch_target=latent_dim * self.patch_size * self.patch_size,
            ch_cond=dim,
            ch_latent=diff_dim,
            depth_latent=diff_layers,
            depth_adanln=diff_adanln_layers,
            grad_checkpointing=grad_checkpointing,
            time_shift=time_shift,
            time_schedule=time_schedule,
            P_std=P_std,
            P_mean=P_mean,
        )
        self.diff_batch_mul = diff_batch_mul

        patch_2d_pos = get_2d_pos(resolution, int(down_size * patch_size))

        # The final entry is dropped: the last image token is never an input,
        # so the sequence seen by the trunk is ``total_tokens - 1`` long.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis_2d(
                patch_2d_pos,
                dim // n_head,
                10000,
                cls_token_num=self.cls_token_num,
            )[:-1],
            persistent=False,
        )
        self.freeze_vae()

        self.initialize_weights()

    def train(self, mode: bool = True):
        """Set train/eval mode, keeping a frozen VAE in eval mode.

        ``nn.Module.train`` recurses into every child, which would put the
        VAE back into training mode after ``freeze_vae`` switched it to eval.
        While the VAE has no trainable parameters it is forced back to eval.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        if not any(p.requires_grad for p in self.vae.parameters()):
            self.vae.eval()
        return self

    def load_vae_weight(self):
        """Load VAE weights from ``self.trained_vae`` (non-strict).

        The checkpoint must contain a ``"state_dict"`` entry.  Missing keys
        are always printed; unexpected keys are printed when present.
        """
        # NOTE(security): depending on the installed torch version,
        # ``torch.load`` may unpickle arbitrary objects.  Only load
        # checkpoints from trusted sources.
        state = torch.load(
            self.trained_vae,
            map_location="cpu",
        )
        missing_keys, unexpected_keys = self.vae.load_state_dict(state["state_dict"], strict=False)
        print(f"loading vae, missing_keys: {missing_keys}")
        if unexpected_keys:
            # A non-strict load would otherwise hide checkpoint/model mismatch.
            print(f"loading vae, unexpected_keys: {unexpected_keys}")
        del state

    def non_decay_keys(self):
        """Return the list of name fragments this model reports as non-decay keys."""
        return ["proj_in", "cls_embedding"]

    def freeze_module(self, module: nn.Module):
        """Disable gradients for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = False

    def freeze_vae(self):
        """Freeze the VAE parameters and switch it to eval mode."""
        self.freeze_module(self.vae)
        self.vae.eval()

    def initialize_weights(self):
        """Initialise all Linear/Embedding weights, then let the head re-initialise itself."""
        # ``apply`` visits every submodule of this model, the VAE included.
        self.apply(self.__init_weights)
        self.head.initialize_weights()

    def __init_weights(self, module):
        """Normal(0, 0.02) init for Linear and Embedding weights; zero Linear biases."""
        std = 0.02
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def drop_label(self, class_id):
        """Randomly replace labels with the dropped-label index during training.

        Outside training, or when ``class_dropout_prob`` is not positive,
        ``class_id`` is returned unchanged.
        """
        if self.class_dropout_prob > 0.0 and self.training:
            is_drop = (
                torch.rand(class_id.shape, device=class_id.device)
                < self.class_dropout_prob
            )
            class_id = torch.where(is_drop, self.num_classes, class_id)
        return class_id

    def patchify(self, x):
        """Turn a ``(B, C, H, W)`` latent into ``(B, H/p * W/p, C * p * p)`` tokens.

        Tokens are ordered row-major over the patch grid; within a token the
        layout is channel, then patch row, then patch column.
        """
        bsz, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(bsz, c, h_, p, w_, p)
        x = torch.einsum('nchpwq->nhwcpq', x)
        x = x.reshape(bsz, h_ * w_, c * p ** 2)
        return x

    def unpatchify(self, x):
        """Inverse of ``patchify`` for the model's own ``h`` x ``w`` token grid."""
        bsz = x.shape[0]
        p = self.patch_size
        c = self.latent_dim
        h_, w_ = self.h, self.w

        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum('nhwcpq->nchpwq', x)
        x = x.reshape(bsz, c, h_ * p, w_ * p)
        return x

    def forward(
        self,
        images,
        class_id,
        cached=False
    ):
        """Compute the training objective for a batch.

        Args:
            images: input images, or pre-computed 4-D VAE latents when
                ``cached`` is true.
            class_id: integer class labels, one per sample.
            cached: when true, ``images`` is used directly as the latent and
                the VAE encoder is skipped.

        Returns:
            Whatever ``DiffHead.__call__(target, cond)`` returns (used as the
            loss by this module).

        Raises:
            NotImplementedError: in training mode, when ``perturb_schedule``
                is not ``"constant"``.
        """
        if cached:
            vae_latent = images
        else:
            vae_latent, _, _, _ = self.vae.encode(images)

        vae_latent = self.patchify(vae_latent)
        x = vae_latent.clone().detach()
        if self.training:
            # Random sign flips perturb the inputs only; targets stay clean.
            if self.perturb_schedule == "constant":
                x = flip_tensor_elements_uniform_prob(x, self.perturb_rate)
            else:
                raise NotImplementedError(f"unknown perturb_schedule {self.perturb_schedule}")
        # The last patch is never fed as an input; it only serves as a target.
        x = self.proj_in(x[:, :-1, :])
        class_id = self.drop_label(class_id)
        bsz = x.shape[0]
        c = self.cls_embedding(class_id).view(bsz, self.cls_token_num, -1)
        x = torch.cat([c, x], dim=1)
        x = self.emb_norm(x)

        if self.grad_checkpointing and self.training:
            for layer in self.layers:
                block = partial(layer.forward, freqs_cis=self.freqs_cis)
                x = checkpoint(block, x, use_reentrant=False)
        else:
            for layer in self.layers:
                x = layer(x, self.freqs_cis)

        # Keep the last ``h * w`` positions: one conditioning vector per patch.
        x = x[:, -self.h * self.w :, :]
        x = self.norm(x)
        x = x + self.pos_for_diff.weight

        target = vae_latent.clone().detach()
        x = x.view(-1, x.shape[-1])
        target = target.view(-1, target.shape[-1])

        x = x.repeat(self.diff_batch_mul, 1)
        target = target.repeat(self.diff_batch_mul, 1)
        loss = self.head(target, x)

        return loss

    def enable_kv_cache(self, bsz):
        """Enable the attention KV cache of every layer for ``bsz`` sequences."""
        for layer in self.layers:
            layer.attention.enable_kv_cache(bsz, self.total_tokens)

    @torch.compile()
    def forward_model(self, x, start_pos, end_pos):
        """Run the trunk incrementally on positions ``[start_pos, end_pos)``."""
        x = self.emb_norm(x)
        for layer in self.layers:
            x = layer.forward_onestep(
                x, self.freqs_cis[start_pos:end_pos,], start_pos, end_pos
            )
        x = self.norm(x)
        return x

    def head_sample(self, x, diff_pos, sample_steps, cfg_scale, cfg_schedule="linear"):
        """Sample the latent patch at position ``diff_pos`` from the head.

        Args:
            x: conditioning hidden state for this position.
            diff_pos: index of the patch being generated.
            sample_steps: number of sampling steps passed to the head.
            cfg_scale: guidance scale; values ``<= 1.0`` use a factor of 1.0.
            cfg_schedule: ``"constant"`` uses ``cfg_scale`` at every position;
                ``"linear"`` ramps from 1.0 toward ``cfg_scale`` with
                ``diff_pos``.

        Returns:
            The sign of the head's prediction, reshaped to ``(-1, 1, C)``.

        Raises:
            NotImplementedError: for an unknown ``cfg_schedule`` when
                ``cfg_scale > 1.0``.
        """
        x = x + self.pos_for_diff.weight[diff_pos : diff_pos + 1, :]
        x = x.view(-1, x.shape[-1])
        seq_len = self.h * self.w
        if cfg_scale > 1.0:
            if cfg_schedule == "constant":
                cfg_iter = cfg_scale
            elif cfg_schedule == "linear":
                start = 1.0
                cfg_iter = start + (cfg_scale - start) * diff_pos / seq_len
            else:
                raise NotImplementedError(f"unknown cfg_schedule {cfg_schedule}")
        else:
            cfg_iter = 1.0
        pred = self.head.sample(x, num_sampling_steps=sample_steps, cfg=cfg_iter)
        pred = pred.view(-1, 1, pred.shape[-1])

        pred = torch.sign(pred)
        return pred

    @torch.no_grad()
    def sample(self, cond, sample_steps, cfg_scale=1.0, cfg_schedule="linear", chunk_size=0):
        """Generate images autoregressively for the class labels in ``cond``.

        Puts the module in eval mode.  When ``cfg_scale > 1.0`` the batch is
        doubled with dropped-label conditions, and only the first half of the
        generated latents is decoded.

        Args:
            cond: integer class labels.
            sample_steps: sampling steps per token for the head.
            cfg_scale: guidance scale.
            cfg_schedule: see ``head_sample``.
            chunk_size: when positive, decode in chunks of this many samples
                via ``decode_in_chunks`` (result is moved to CPU); otherwise
                decode the whole batch at once.

        Returns:
            The VAE decoder output for the sampled latents.
        """
        self.eval()
        if cfg_scale > 1.0:
            cond_null = torch.ones_like(cond) * self.num_classes
            cond_combined = torch.cat([cond, cond_null])
        else:
            cond_combined = cond
        bsz = cond_combined.shape[0]
        act_bsz = bsz // 2 if cfg_scale > 1.0 else bsz
        self.enable_kv_cache(bsz)

        c = self.cls_embedding(cond_combined).view(bsz, self.cls_token_num, -1)
        last_pred = None
        all_preds = []
        for i in range(self.h * self.w):
            if i == 0:
                # Prefill: push all class tokens through the trunk at once.
                x = self.forward_model(c, 0, self.cls_token_num)
            else:
                x = self.proj_in(last_pred)
                x = self.forward_model(
                    x, i + self.cls_token_num - 1, i + self.cls_token_num
                )
            last_pred = self.head_sample(
                x[:, -1:, :],
                i,
                sample_steps,
                cfg_scale,
                cfg_schedule,
            )
            all_preds.append(last_pred)

        x = torch.cat(all_preds, dim=-2)[:act_bsz]
        if x.dim() == 3:
            x = self.unpatchify(x)
        if chunk_size > 0:
            recon = self.decode_in_chunks(x, chunk_size)
        else:
            recon = self.vae.decode(x)
        return recon

    def decode_in_chunks(self, latent_tensor, chunk_size=64):
        """Decode latents ``chunk_size`` samples at a time, collecting results on CPU.

        Args:
            latent_tensor: latents batched along dimension 0.
            chunk_size: maximum number of samples decoded per VAE call.

        Returns:
            The decoded chunks concatenated along dimension 0, on CPU.

        Raises:
            ValueError: if ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            # A negative step would yield no chunks and fail later in
            # ``torch.cat`` with an unrelated message; fail early and clearly.
            raise ValueError(f"chunk_size must be positive, but got: {chunk_size}")
        total_bsz = latent_tensor.shape[0]
        recon_chunks_on_cpu = []
        with torch.no_grad():
            for i in range(0, total_bsz, chunk_size):
                end_idx = min(i + chunk_size, total_bsz)
                latent_chunk = latent_tensor[i:end_idx]
                recon_chunk = self.vae.decode(latent_chunk)
                # Move each chunk off the device immediately to bound memory.
                recon_chunks_on_cpu.append(recon_chunk.cpu())
        return torch.cat(recon_chunks_on_cpu, dim=0)

    def get_fsdp_wrap_module_list(self):
        """Return the transformer blocks as a plain list."""
        return list(self.layers)
|
|
def BitDance_H(**kwargs):
    """Build the "H" preset: 40 layers, 20 heads, width 1280, 12-layer head.

    All remaining constructor arguments are taken from ``kwargs``.
    """
    preset = dict(
        n_layer=40,
        n_head=20,
        dim=1280,
        diff_layers=12,
        diff_dim=1280,
        diff_adanln_layers=3,
    )
    # Double unpacking still rejects a preset key repeated in ``kwargs``.
    return BitDance(**preset, **kwargs)
|
|
|
|
def BitDance_L(**kwargs):
    """Build the "L" preset: 32 layers, 16 heads, width 1024, 8-layer head.

    All remaining constructor arguments are taken from ``kwargs``.
    """
    preset = dict(
        n_layer=32,
        n_head=16,
        dim=1024,
        diff_layers=8,
        diff_dim=1024,
        diff_adanln_layers=2,
    )
    # Double unpacking still rejects a preset key repeated in ``kwargs``.
    return BitDance(**preset, **kwargs)
|
|
|
|
def BitDance_B(**kwargs):
    """Build the "B" preset: 24 layers, 12 heads, width 768, 6-layer head.

    All remaining constructor arguments are taken from ``kwargs``.
    """
    preset = dict(
        n_layer=24,
        n_head=12,
        dim=768,
        diff_layers=6,
        diff_dim=768,
        diff_adanln_layers=2,
    )
    # Double unpacking still rejects a preset key repeated in ``kwargs``.
    return BitDance(**preset, **kwargs)
|
|
|
|
# Registry mapping the ``--model`` command-line choice to its factory.
BitDance_models = dict(
    (
        ("BitDance-B", BitDance_B),
        ("BitDance-L", BitDance_L),
        ("BitDance-H", BitDance_H),
    )
)
|
|