import argparse
from functools import partial

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint

from .diff_head import DiffHead
from .layers import TransformerBlock, get_2d_pos, precompute_freqs_cis_2d
from .qae import VQModel


def get_model_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, choices=list(BitDance_models.keys()), default="BitDance-L"
    )
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--down-size", type=int, default=16, choices=[16])
    parser.add_argument("--patch-size", type=int, default=1, choices=[1, 2, 4])
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cls-token-num", type=int, default=64)
    parser.add_argument("--latent-dim", type=int, default=16)
    parser.add_argument("--diff-batch-mul", type=int, default=4)
    parser.add_argument("--grad-checkpointing", action="store_true")
    parser.add_argument("--trained-vae", type=str, default="")
    parser.add_argument("--drop-rate", type=float, default=0.0)
    parser.add_argument("--perturb-schedule", type=str, default="constant")
    parser.add_argument("--perturb-rate", type=float, default=0.0)
    parser.add_argument("--perturb-rate-max", type=float, default=0.3)
    parser.add_argument("--time-schedule", type=str, default="logit_normal")
    parser.add_argument("--time-shift", type=float, default=1.0)
    parser.add_argument("--P-std", type=float, default=1.0)
    parser.add_argument("--P-mean", type=float, default=0.0)
    return parser


def create_model(args, device):
    model = BitDance_models[args.model](
        resolution=args.image_size,
        down_size=args.down_size,
        patch_size=args.patch_size,
        latent_dim=args.latent_dim,
        diff_batch_mul=args.diff_batch_mul,
        cls_token_num=args.cls_token_num,
        num_classes=args.num_classes,
        grad_checkpointing=args.grad_checkpointing,
        trained_vae=args.trained_vae,
        drop_rate=args.drop_rate,
        perturb_schedule=args.perturb_schedule,
        perturb_rate=args.perturb_rate,
        perturb_rate_max=args.perturb_rate_max,
        time_schedule=args.time_schedule,
        time_shift=args.time_shift,
        P_std=args.P_std,
        P_mean=args.P_mean,
    ).to(device, memory_format=torch.channels_last)
    return model
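
# Usage sketch (assumes a typical training-script entry point, which is not part of this file):
#   args = get_model_args().parse_args()
#   model = create_model(args, torch.device("cuda"))
#   if args.trained_vae:
#       model.load_vae_weight()  # load frozen VAE weights from args.trained_vae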


class MLPConnector(nn.Module):
    """SwiGLU-style MLP that projects patchified latents into the transformer width."""

    def __init__(self, in_dim, dim, dropout_p=0.0):
        super().__init__()
        hidden_dim = int(dim * 1.5)
        self.w1 = nn.Linear(in_dim, hidden_dim * 2, bias=True)
        self.w2 = nn.Linear(hidden_dim, dim, bias=True)
        self.ffn_dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        # The fused projection is split into a value half and a gate half.
        h1, h2 = self.w1(x).chunk(2, dim=-1)
        return self.ffn_dropout(self.w2(F.silu(h1) * h2))


def flip_tensor_elements_uniform_prob(tensor: torch.Tensor, p_max: float) -> torch.Tensor:
    """Flip the sign of each element with a probability drawn uniformly from [0, p_max]."""
    if not 0.0 <= p_max <= 1.0:
        raise ValueError(f"p_max must be in [0.0, 1.0], but got: {p_max}")
    r1 = torch.rand_like(tensor)
    r2 = torch.rand_like(tensor)
    # Element-wise flip probability is p_max * r2, i.e. uniform on [0, p_max].
    flip_mask = r1 < p_max * r2
    multiplier = torch.where(flip_mask, -1.0, 1.0)
    multiplier = multiplier.to(tensor.dtype)
    flipped_tensor = tensor * multiplier
    return flipped_tensor
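
# Illustrative check (not part of the module's API): since each element's flip
# probability is uniform on [0, p_max], the expected flip rate is p_max / 2.
#   t = torch.ones(100_000)
#   flipped = flip_tensor_elements_uniform_prob(t, p_max=0.3)
#   print((flipped < 0).float().mean())  # roughly 0.15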


class BitDance(nn.Module):
    """Class-conditional autoregressive transformer over frozen-VAE latent tokens,
    with a diffusion head predicting each next token."""

    def __init__(
        self,
        dim,
        n_layer,
        n_head,
        diff_layers,
        diff_dim,
        diff_adanln_layers,
        latent_dim,
        down_size,
        patch_size,
        resolution,
        diff_batch_mul,
        grad_checkpointing=False,
        cls_token_num=16,
        num_classes: int = 1000,
        class_dropout_prob: float = 0.1,
        trained_vae: str = "",
        drop_rate: float = 0.0,
        perturb_schedule: str = "constant",
        perturb_rate: float = 0.0,
        perturb_rate_max: float = 0.3,
        time_schedule: str = "logit_normal",
        time_shift: float = 1.0,
        P_std: float = 1.0,
        P_mean: float = 0.0,
    ):
        super().__init__()

        self.n_layer = n_layer
        self.resolution = resolution
        self.down_size = down_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.cls_token_num = cls_token_num
        self.class_dropout_prob = class_dropout_prob
        self.latent_dim = latent_dim
        self.trained_vae = trained_vae
        self.perturb_schedule = perturb_schedule
        self.perturb_rate = perturb_rate
        self.perturb_rate_max = perturb_rate_max

        # Configuration for the frozen VQ autoencoder that tokenizes images.
        ddconfig = {
            "double_z": False,
            "z_channels": latent_dim,
            "in_channels": 3,
            "out_ch": 3,
            "ch": 256,
            "ch_mult": [1, 1, 2, 2, 4],
            "num_res_blocks": 4,
        }
        num_codebooks = 4

        self.vae = VQModel(ddconfig, num_codebooks)
        self.grad_checkpointing = grad_checkpointing

        # Each class label is embedded as `cls_token_num` prefix tokens.
        self.cls_embedding = nn.Embedding(num_classes + 1, dim * self.cls_token_num)
        self.proj_in = MLPConnector(latent_dim * self.patch_size * self.patch_size, dim, drop_rate)
        self.emb_norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)
        self.h, self.w = resolution // (down_size * patch_size), resolution // (down_size * patch_size)
        self.total_tokens = self.h * self.w + self.cls_token_num
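
        # Worked example with the argparse defaults (resolution=256, down_size=16,
        # patch_size=1, cls_token_num=64): h = w = 256 // 16 = 16, so there are
        # 256 image tokens plus 64 class tokens, i.e. total_tokens = 320.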

        self.layers = nn.ModuleList()
        for layer_id in range(n_layer):
            self.layers.append(
                TransformerBlock(
                    dim,
                    n_head,
                    resid_dropout_p=drop_rate,
                    causal=True,
                )
            )

        self.norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)
        self.pos_for_diff = nn.Embedding(self.h * self.w, dim)
        self.head = DiffHead(
            ch_target=latent_dim * self.patch_size * self.patch_size,
            ch_cond=dim,
            ch_latent=diff_dim,
            depth_latent=diff_layers,
            depth_adanln=diff_adanln_layers,
            grad_checkpointing=grad_checkpointing,
            time_shift=time_shift,
            time_schedule=time_schedule,
            P_std=P_std,
            P_mean=P_mean,
        )
        self.diff_batch_mul = diff_batch_mul

        patch_2d_pos = get_2d_pos(resolution, int(down_size * patch_size))

        # 2D rotary frequencies for class + image positions; the final position is
        # dropped because the last image token is only predicted, never fed back in.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis_2d(
                patch_2d_pos,
                dim // n_head,
                10000,
                cls_token_num=self.cls_token_num,
            )[:-1],
            persistent=False,
        )
        self.freeze_vae()

        self.initialize_weights()

    def load_vae_weight(self):
        state = torch.load(
            self.trained_vae,
            map_location="cpu",
        )
        missing_keys, unexpected_keys = self.vae.load_state_dict(state["state_dict"], strict=False)
        print(f"loading vae, missing_keys: {missing_keys}, unexpected_keys: {unexpected_keys}")
        del state

    def non_decay_keys(self):
        return ["proj_in", "cls_embedding"]

    def freeze_module(self, module: nn.Module):
        for param in module.parameters():
            param.requires_grad = False

    def freeze_vae(self):
        self.freeze_module(self.vae)
        self.vae.eval()

    def initialize_weights(self):
        self.apply(self.__init_weights)
        self.head.initialize_weights()

    def __init_weights(self, module):
        std = 0.02
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def drop_label(self, class_id):
        # Randomly replace labels with the null class (index `num_classes`) during
        # training, which enables classifier-free guidance at sampling time.
        if self.class_dropout_prob > 0.0 and self.training:
            is_drop = (
                torch.rand(class_id.shape, device=class_id.device)
                < self.class_dropout_prob
            )
            class_id = torch.where(is_drop, self.num_classes, class_id)
        return class_id

    def patchify(self, x):
        """(bsz, c, h, w) latent grid -> (bsz, num_patches, c * p * p) token sequence."""
        bsz, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(bsz, c, h_, p, w_, p)
        x = torch.einsum('nchpwq->nhwcpq', x)
        x = x.reshape(bsz, h_ * w_, c * p ** 2)
        return x

    def unpatchify(self, x):
        """Inverse of patchify: token sequence -> (bsz, c, h, w) latent grid."""
        bsz = x.shape[0]
        p = self.patch_size
        c = self.latent_dim
        h_, w_ = self.h, self.w

        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum('nhwcpq->nchpwq', x)
        x = x.reshape(bsz, c, h_ * p, w_ * p)
        return x
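
    # Round-trip sketch (illustrative only): patchify and unpatchify are exact
    # inverses for a latent grid of the expected shape, e.g.
    #   z = torch.randn(2, self.latent_dim, self.h * self.patch_size, self.w * self.patch_size)
    #   assert torch.equal(self.unpatchify(self.patchify(z)), z)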

    def forward(
        self,
        images,
        class_id,
        cached=False,
    ):
        if cached:
            # `images` already holds pre-computed VAE latents.
            vae_latent = images
        else:
            vae_latent, _, _, _ = self.vae.encode(images)

        vae_latent = self.patchify(vae_latent)
        x = vae_latent.clone().detach()
        if self.training:
            # Optionally perturb the teacher-forced inputs with random sign flips.
            if self.perturb_schedule == "constant":
                x = flip_tensor_elements_uniform_prob(x, self.perturb_rate)
            else:
                raise NotImplementedError(f"unknown perturb_schedule {self.perturb_schedule}")
        # Drop the last token: it is only ever a prediction target, never an input.
        x = self.proj_in(x[:, :-1, :])
        class_id = self.drop_label(class_id)
        bsz = x.shape[0]
        c = self.cls_embedding(class_id).view(bsz, self.cls_token_num, -1)
        x = torch.cat([c, x], dim=1)
        x = self.emb_norm(x)

        if self.grad_checkpointing and self.training:
            for layer in self.layers:
                block = partial(layer.forward, freqs_cis=self.freqs_cis)
                x = checkpoint(block, x, use_reentrant=False)
        else:
            for layer in self.layers:
                x = layer(x, self.freqs_cis)

        # Keep only the positions that condition the image-token predictions.
        x = x[:, -self.h * self.w :, :]
        x = self.norm(x)
        x = x + self.pos_for_diff.weight

        target = vae_latent.clone().detach()
        x = x.view(-1, x.shape[-1])
        target = target.view(-1, target.shape[-1])

        # Repeat each (condition, target) pair so the diffusion head sees several
        # noise samples per token in one pass.
        x = x.repeat(self.diff_batch_mul, 1)
        target = target.repeat(self.diff_batch_mul, 1)
        loss = self.head(target, x)

        return loss

    def enable_kv_cache(self, bsz):
        for layer in self.layers:
            layer.attention.enable_kv_cache(bsz, self.total_tokens)

    @torch.compile()
    def forward_model(self, x, start_pos, end_pos):
        x = self.emb_norm(x)
        for layer in self.layers:
            x = layer.forward_onestep(
                x, self.freqs_cis[start_pos:end_pos], start_pos, end_pos
            )
        x = self.norm(x)
        return x
def head_sample(self, x, diff_pos, sample_steps, cfg_scale, cfg_schedule="linear"): |
|
|
x = x + self.pos_for_diff.weight[diff_pos : diff_pos + 1, :] |
|
|
x = x.view(-1, x.shape[-1]) |
|
|
seq_len = self.h * self.w |
|
|
if cfg_scale > 1.0: |
|
|
if cfg_schedule == "constant": |
|
|
cfg_iter = cfg_scale |
|
|
elif cfg_schedule == "linear": |
|
|
start = 1.0 |
|
|
cfg_iter = start + (cfg_scale - start) * diff_pos / seq_len |
|
|
else: |
|
|
raise NotImplementedError(f"unknown cfg_schedule {cfg_schedule}") |
|
|
else: |
|
|
cfg_iter = 1.0 |
|
|
pred = self.head.sample(x, num_sampling_steps=sample_steps, cfg=cfg_iter) |
|
|
pred = pred.view(-1, 1, pred.shape[-1]) |
|
|
|
|
|
pred = torch.sign(pred) |
|
|
return pred |
|
|
|
|
|

    @torch.no_grad()
    def sample(self, cond, sample_steps, cfg_scale=1.0, cfg_schedule="linear", chunk_size=0):
        self.eval()
        if cfg_scale > 1.0:
            # Duplicate the batch with null labels for classifier-free guidance.
            cond_null = torch.ones_like(cond) * self.num_classes
            cond_combined = torch.cat([cond, cond_null])
        else:
            cond_combined = cond
        bsz = cond_combined.shape[0]
        act_bsz = bsz // 2 if cfg_scale > 1.0 else bsz
        self.enable_kv_cache(bsz)

        c = self.cls_embedding(cond_combined).view(bsz, self.cls_token_num, -1)
        last_pred = None
        all_preds = []
        for i in range(self.h * self.w):
            if i == 0:
                # Prefill the KV cache with the class-token prefix.
                x = self.forward_model(c, 0, self.cls_token_num)
            else:
                # Feed back the previously sampled token, one position at a time.
                x = self.proj_in(last_pred)
                x = self.forward_model(
                    x, i + self.cls_token_num - 1, i + self.cls_token_num
                )
            last_pred = self.head_sample(
                x[:, -1:, :],
                i,
                sample_steps,
                cfg_scale,
                cfg_schedule,
            )
            all_preds.append(last_pred)

        # Keep only the conditional half of the batch and decode back to pixels.
        x = torch.cat(all_preds, dim=-2)[:act_bsz]
        if x.dim() == 3:
            x = self.unpatchify(x)
        if chunk_size > 0:
            recon = self.decode_in_chunks(x, chunk_size)
        else:
            recon = self.vae.decode(x)
        return recon
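
    # Sampling sketch (illustrative; the class ids and hyperparameters are arbitrary):
    #   cond = torch.tensor([207, 360], device=device)
    #   imgs = model.sample(cond, sample_steps=32, cfg_scale=2.0, chunk_size=16)
    # With chunk_size > 0 the decoded images are returned on the CPU; otherwise
    # they stay on the model's device.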

    def decode_in_chunks(self, latent_tensor, chunk_size=64):
        total_bsz = latent_tensor.shape[0]
        recon_chunks_on_cpu = []
        with torch.no_grad():
            for i in range(0, total_bsz, chunk_size):
                end_idx = min(i + chunk_size, total_bsz)
                latent_chunk = latent_tensor[i:end_idx]
                recon_chunk = self.vae.decode(latent_chunk)
                recon_chunks_on_cpu.append(recon_chunk.cpu())
        return torch.cat(recon_chunks_on_cpu, dim=0)

    def get_fsdp_wrap_module_list(self):
        return list(self.layers)


def BitDance_H(**kwargs):
    return BitDance(
        n_layer=40,
        n_head=20,
        dim=1280,
        diff_layers=12,
        diff_dim=1280,
        diff_adanln_layers=3,
        **kwargs,
    )


def BitDance_L(**kwargs):
    return BitDance(
        n_layer=32,
        n_head=16,
        dim=1024,
        diff_layers=8,
        diff_dim=1024,
        diff_adanln_layers=2,
        **kwargs,
    )


def BitDance_B(**kwargs):
    return BitDance(
        n_layer=24,
        n_head=12,
        dim=768,
        diff_layers=6,
        diff_dim=768,
        diff_adanln_layers=2,
        **kwargs,
    )


BitDance_models = {
    "BitDance-B": BitDance_B,
    "BitDance-L": BitDance_L,
    "BitDance-H": BitDance_H,
}
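

if __name__ == "__main__":
    # Smoke-test sketch, not part of the original training/eval entry points.
    # Assumptions: the module is run with `python -m <package>.<module>` so the
    # relative imports resolve, and the VAE weights are either unset or valid.
    args = get_model_args().parse_args([])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = create_model(args, device)
    if args.trained_vae:
        model.load_vae_weight()
    class_ids = torch.randint(0, args.num_classes, (2,), device=device)
    images = model.sample(class_ids, sample_steps=32, cfg_scale=1.5)
    print("sampled image batch:", tuple(images.shape))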