BiliSakura's picture
Update all files for BitDance-ImageNet-diffusers
3f7ea9b verified
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .sampling_parallel import euler_maruyama
def timestep_embedding(t, dim, max_period=10000, time_factor: float = 1000.0):
half = dim // 2
t = time_factor * t.float()
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
/ half
)
args = t[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
def time_shift_sana(t: torch.Tensor, flow_shift: float = 1., sigma: float = 1.):
return (1 / flow_shift) / ( (1 / flow_shift) + (1 / t - 1) ** sigma)
class DiffHead(nn.Module):
"""Diffusion Loss"""
def __init__(
self,
ch_target,
ch_cond,
ch_latent,
depth_latent,
depth_adanln,
grad_checkpointing=False,
time_shift=1.,
time_schedule='logit_normal',
parallel_num=4,
P_std: float = 1.,
P_mean: float = 0.,
):
super(DiffHead, self).__init__()
self.ch_target = ch_target
self.time_shift = time_shift
self.time_schedule = time_schedule
self.P_std = P_std
self.P_mean = P_mean
self.net = TransEncoder(
in_channels=ch_target,
model_channels=ch_latent,
z_channels=ch_cond,
num_res_blocks=depth_latent,
num_ada_ln_blocks=depth_adanln,
grad_checkpointing=grad_checkpointing,
parallel_num=parallel_num,
)
def forward(self, x, cond):
with torch.autocast(device_type="cuda", enabled=False):
with torch.no_grad():
if self.time_schedule == 'logit_normal':
t = (torch.randn((x.shape[0]), device=x.device) * self.P_std + self.P_mean).sigmoid()
if self.time_shift != 1.:
t = time_shift_sana(t, self.time_shift)
elif self.time_schedule == 'uniform':
t = torch.rand((x.shape[0]), device=x.device)
if self.time_shift != 1.:
t = time_shift_sana(t, self.time_shift)
else:
raise NotImplementedError(f"unknown time_schedule {self.time_schedule}")
e = torch.randn_like(x)
ti = t.view(-1, 1, 1)
z = (1.0 - ti) * e + ti * x
v = (x - z) / (1 - ti).clamp_min(0.05)
x_pred = self.net(z, t, cond)
v_pred = (x_pred - z) / (1 - ti).clamp_min(0.05)
with torch.autocast(device_type="cuda", enabled=False):
v_pred = v_pred.float()
loss = torch.mean((v - v_pred) ** 2)
return loss
def sample(
self,
z,
cfg,
num_sampling_steps,
):
return euler_maruyama(
self.ch_target,
self.net.forward,
z,
cfg,
num_sampling_steps=num_sampling_steps,
time_shift = self.time_shift,
)
def initialize_weights(self):
self.net.initialize_weights()
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
def forward(self, t):
t_freq = timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class ResBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.norm = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=True)
hidden_dim = int(channels * 1.5)
self.w1 = nn.Linear(channels, hidden_dim * 2, bias=True)
self.w2 = nn.Linear(hidden_dim, channels, bias=True)
def forward(self, x, scale, shift, gate):
h = self.norm(x) * (1 + scale) + shift
h1, h2 = self.w1(h).chunk(2, dim=-1)
h = self.w2(F.silu(h1) * h2)
return x + h * gate
class FinalLayer(nn.Module):
def __init__(self, channels, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=False)
self.ada_ln_modulation = nn.Linear(channels, channels * 2, bias=True)
self.linear = nn.Linear(channels, out_channels, bias=True)
def forward(self, x, y):
scale, shift = self.ada_ln_modulation(y).chunk(2, dim=-1)
x = self.norm_final(x) * (1.0 + scale) + shift
x = self.linear(x)
return x
class Attention(nn.Module):
def __init__(
self,
dim,
n_head,
):
super().__init__()
assert dim % n_head == 0
self.dim = dim
self.head_dim = dim // n_head
self.scale = self.head_dim**-0.5
self.n_head = n_head
total_kv_dim = (self.n_head * 3) * self.head_dim
self.wqkv = nn.Linear(dim, total_kv_dim, bias=True)
self.wo = nn.Linear(dim, dim, bias=True)
def forward(
self,
x: torch.Tensor,
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wqkv(x).chunk(3, dim=-1)
xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_head, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_head, self.head_dim)
xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
xq = xq * self.scale
attn = xq @ xk.transpose(-1, -2)
attn = F.softmax(attn, dim=-1)
output = (attn @ xv).transpose(1, 2).contiguous()
# output = flash_attn_func(
# xq,
# xk,
# xv,
# causal=False,
# )
output = output.view(bsz, seqlen, self.dim)
output = self.wo(output)
return output
class TransBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.norm1 = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=True)
self.attn = Attention(channels, n_head=channels//64)
self.norm2 = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=True)
hidden_dim = int(channels * 1.5)
self.w1 = nn.Linear(channels, hidden_dim * 2, bias=True)
self.w2 = nn.Linear(hidden_dim, channels, bias=True)
def forward(self, x, scale1, shift1, gate1, scale2, shift2, gate2):
h = self.norm1(x) * (1 + scale1) + shift1
h = self.attn(h)
x = x + h * gate1
h = self.norm2(x) * (1 + scale2) + shift2
h1, h2 = self.w1(h).chunk(2, dim=-1)
h = self.w2(F.silu(h1) * h2)
return x + h * gate2
class TransEncoder(nn.Module):
def __init__(
self,
in_channels,
model_channels,
z_channels,
num_res_blocks,
num_ada_ln_blocks=2,
grad_checkpointing=False,
parallel_num=4,
):
super().__init__()
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = in_channels
self.num_res_blocks = num_res_blocks
self.grad_checkpointing = grad_checkpointing
self.parallel_num = parallel_num
self.time_embed = TimestepEmbedder(model_channels)
self.cond_embed = nn.Linear(z_channels, model_channels)
self.input_proj = nn.Linear(in_channels, model_channels)
self.res_blocks = nn.ModuleList()
for i in range(num_res_blocks):
self.res_blocks.append(
TransBlock(
model_channels,
)
)
# share adaLN for consecutive blocks, to save computation and parameters
self.ada_ln_blocks = nn.ModuleList()
for i in range(num_ada_ln_blocks):
self.ada_ln_blocks.append(
nn.Linear(model_channels, model_channels * 6, bias=True)
)
self.ada_ln_switch_freq = max(1, num_res_blocks // num_ada_ln_blocks)
assert (
num_res_blocks % self.ada_ln_switch_freq
) == 0, "num_res_blocks must be divisible by num_ada_ln_blocks"
self.final_layer = FinalLayer(model_channels, self.out_channels)
self.initialize_weights()
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize timestep embedding MLP
nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
for block in self.ada_ln_blocks:
nn.init.constant_(block.weight, 0)
nn.init.constant_(block.bias, 0)
# Zero-out output layers
nn.init.constant_(self.final_layer.ada_ln_modulation.weight, 0)
nn.init.constant_(self.final_layer.ada_ln_modulation.bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
@torch.compile()
def forward(self, x, t, c):
x = self.input_proj(x)
t = self.time_embed(t).unsqueeze(1)
c = self.cond_embed(c)
y = F.silu(t+c)
scale1, shift1, gate1, scale2, shift2, gate2 = self.ada_ln_blocks[0](y).chunk(6, dim=-1)
for i, block in enumerate(self.res_blocks):
if i > 0 and i % self.ada_ln_switch_freq == 0:
ada_ln_block = self.ada_ln_blocks[i // self.ada_ln_switch_freq]
scale1, shift1, gate1, scale2, shift2, gate2 = ada_ln_block(y).chunk(6, dim=-1)
x = block(x, scale1, shift1, gate1, scale2, shift2, gate2)
return self.final_layer(x, y)