import math
from typing import Tuple

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn.functional import scaled_dot_product_attention


def modulate(x, shift, scale):
    return x * (1 + scale) + shift


class Embed(nn.Module):
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer=None,
        bias: bool = True,
    ):
        super().__init__()
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x


class PatchEmbed(nn.Module):
    def __init__(
        self,
        in_channels=8,
        embed_dim=1152,
        bias=True,
        patch_size=(1, 1),
    ):
        super().__init__()
        # patch_size is a (patch_h, patch_w) tuple.
        self.patch_h, self.patch_w = patch_size

        self.patch_size = patch_size
        self.proj = nn.Linear(in_channels * self.patch_h * self.patch_w, embed_dim, bias=bias)
        self.in_channels = in_channels
        self.embed_dim = embed_dim

    def forward(self, latent):
        # Fold each (patch_h, patch_w) patch into the channel dimension, then project.
        x = rearrange(latent, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1=self.patch_h, p2=self.patch_w)
        x = self.proj(x)
        return x


class FinalLayer(nn.Module):
    """Final layer with configurable patch_size support"""

    def __init__(self, hidden_size, out_channels=8, patch_size=(1, 1)):
        super().__init__()
        self.patch_h, self.patch_w = patch_size

        self.linear = nn.Linear(hidden_size, out_channels * self.patch_h * self.patch_w, bias=True)
        self.out_channels = out_channels
        self.patch_size = patch_size

    def forward(self, x, target_height, target_width):
        x = self.linear(x)
        # Unfold tokens back into a (C, H, W) latent; target_height and
        # target_width are measured in patches.
        x = rearrange(x, 'b (h w) (c p1 p2) -> b c (h p1) (w p2)',
                      h=target_height, w=target_width,
                      p1=self.patch_h, p2=self.patch_w, c=self.out_channels)
        return x


class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10):
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        )
        args = t[..., None].float() * freqs[None, ...]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        # SwiGLU: silu(w1(x)) gates w3(x) before the output projection.
        x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
        return x


def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale=1.0):
    # 2D rotary embedding table: half of the complex frequencies encode the x
    # coordinate and half the y coordinate, interleaved pairwise. Returns a
    # complex tensor of shape (height * width, dim // 2).
    if isinstance(scale, float):
        scale = (scale, scale)
    x_pos = torch.linspace(0, width * scale[0], width)
    y_pos = torch.linspace(0, height * scale[1], height)
    y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
    y_pos = y_pos.reshape(-1)
    x_pos = x_pos.reshape(-1)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    x_freqs = torch.outer(x_pos, freqs).float()
    y_freqs = torch.outer(y_pos, freqs).float()
    x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
    y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
    freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1)
    freqs_cis = freqs_cis.reshape(height * width, -1)
    return freqs_cis


@torch.compiler.disable
def apply_rotary_emb_2d(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Rotate query/key pairs in the complex plane by the precomputed 2D
    # frequencies. xq and xk are (B, heads, N, head_dim); freqs_cis is
    # (N, head_dim // 2).
    freqs_cis = freqs_cis[None, None, :, :]

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class RAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = True,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = RMSNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = self.q_norm(q.contiguous())
        k = self.k_norm(k.contiguous())
        q, k = apply_rotary_emb_2d(q, k, freqs_cis=pos)

        q = q.view(B, self.num_heads, -1, C // self.num_heads)
        k = k.view(B, self.num_heads, -1, C // self.num_heads).contiguous()
        v = v.view(B, self.num_heads, -1, C // self.num_heads).contiguous()

        x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.attn_drop.p if self.training else 0.0)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class CrossAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        context_dim: int,
        num_heads: int,
        qkv_bias: bool = False,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv_proj = nn.Linear(context_dim, dim * 2, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, context: torch.Tensor, context_mask: torch.Tensor = None) -> torch.Tensor:
        B, N, C = x.shape
        B_ctx, M, C_ctx = context.shape

        q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        kv = self.kv_proj(context).reshape(B_ctx, M, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        # Turn the boolean context mask into an additive mask: padded positions
        # get -inf so they receive zero attention weight.
        attn_mask = None
        if context_mask is not None:
            attn_mask = torch.zeros(B, 1, 1, M, dtype=q.dtype, device=q.device)
            attn_mask.masked_fill_(~context_mask.unsqueeze(1).unsqueeze(2), float('-inf'))

        attn = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=self.proj_drop.p if self.training else 0.0)

        x = attn.permute(0, 2, 1, 3).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class DDTBlock(nn.Module):
    def __init__(self, hidden_size, groups, mlp_ratio=4.0, context_dim=None, is_encoder_block=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)

        self.norm_cross = RMSNorm(hidden_size, eps=1e-6) if context_dim else nn.Identity()
        self.cross_attn = CrossAttention(hidden_size, context_dim, groups) if context_dim else None

        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = FeedForward(hidden_size, mlp_hidden_dim)

        # Encoder blocks use a shared adaLN head supplied at forward time;
        # other blocks own their modulation head.
        self.is_encoder_block = is_encoder_block
        if not is_encoder_block:
            self.adaLN_modulation = nn.Sequential(
                nn.Linear(hidden_size, 6 * hidden_size, bias=True)
            )

    def forward(self, x, c, pos, mask=None, context=None, context_mask=None, shared_adaLN=None):
        if self.is_encoder_block:
            adaLN_output = shared_adaLN(c)
        else:
            adaLN_output = self.adaLN_modulation(c)

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN_output.chunk(6, dim=-1)

        # Self-attention with adaLN modulation and a gated residual.
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)

        if self.cross_attn is not None and context is not None:
            x = x + self.cross_attn(self.norm_cross(x), context=context, context_mask=context_mask)

        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class LocalSongModel(nn.Module):
    """Latent song model built from DDT blocks.

    Tokens pass through a decoder-width entry block (section A), a narrower
    encoder-width middle stack (section B), and a final group of decoder-width
    blocks that consume the concatenation of both sections. Even-indexed
    blocks cross-attend to the tag embeddings.
    """

    def __init__(
        self,
        in_channels=8,
        num_groups=16,
        hidden_size=1024,
        decoder_hidden_size=2048,
        num_blocks=36,
        patch_size=(16, 1),
        num_classes=2304,
        max_tags=8,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.decoder_hidden_size = decoder_hidden_size
        self.num_groups = num_groups
        self.num_blocks = num_blocks
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.max_tags = max_tags

        self.patch_h, self.patch_w = patch_size

        self.x_embedder = PatchEmbed(
            in_channels=in_channels,
            embed_dim=decoder_hidden_size,
            bias=True,
            patch_size=patch_size
        )

        self.s_embedder = PatchEmbed(
            in_channels=in_channels,
            embed_dim=decoder_hidden_size,
            bias=True,
            patch_size=patch_size
        )

        self.encoder_to_decoder = nn.Linear(hidden_size, decoder_hidden_size, bias=False)

        self.a_to_b_proj = nn.Linear(decoder_hidden_size, hidden_size, bias=False)

        self.t_embedder = TimestepEmbedder(hidden_size)

        self.y_embedder = nn.Embedding(num_classes + 1, hidden_size, padding_idx=0)

        self.final_layer = FinalLayer(
            decoder_hidden_size,
            out_channels=in_channels,
            patch_size=patch_size
        )

        self.shared_encoder_adaLN = nn.Sequential(
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

        self.shared_decoder_adaLN = nn.Sequential(
            nn.Linear(hidden_size, 6 * decoder_hidden_size, bias=True)
        )

        self.blocks = nn.ModuleList()
        # Block layout: block 0 and the last three blocks run at decoder width;
        # the blocks in between run at the narrower encoder width. Every block
        # is constructed with is_encoder_block=True, so all of them take their
        # adaLN modulation from the shared head passed in at forward time, and
        # even-indexed blocks also cross-attend to the tag embeddings.
        for i in range(self.num_blocks):
            is_encoder = i < self.num_blocks

            if is_encoder:
                if i < 1:
                    block_hidden_size = decoder_hidden_size
                    num_heads = self.num_groups
                elif i >= self.num_blocks - 3:
                    block_hidden_size = decoder_hidden_size
                    num_heads = self.num_groups
                else:
                    block_hidden_size = hidden_size
                    num_heads = self.num_groups
            else:
                block_hidden_size = decoder_hidden_size
                num_heads = self.num_groups

            context_dim = hidden_size if i % 2 == 0 and is_encoder else None

            self.blocks.append(
                DDTBlock(
                    block_hidden_size,
                    num_heads,
                    context_dim=context_dim,
                    is_encoder_block=is_encoder
                )
            )

        self.bc_projection = nn.Linear(decoder_hidden_size + hidden_size, decoder_hidden_size, bias=False)

        self.initialize_weights()
        self.precompute_encoder_pos = dict()
        self.precompute_decoder_pos = dict()

    from functools import lru_cache

    @lru_cache
    def fetch_encoder_pos(self, height, width, device):
        key = (height, width)
        if key in self.precompute_encoder_pos:
            return self.precompute_encoder_pos[key].to(device)
        else:
            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
            self.precompute_encoder_pos[key] = pos
            return pos

    @lru_cache
    def fetch_decoder_pos(self, height, width, device):
        key = (height, width)
        if key in self.precompute_decoder_pos:
            return self.precompute_decoder_pos[key].to(device)
        else:
            pos = precompute_freqs_cis_2d(self.decoder_hidden_size // self.num_groups, height, width).to(device)
            self.precompute_decoder_pos[key] = pos
            return pos

    def initialize_weights(self):
        for embedder in [self.x_embedder, self.s_embedder]:
            nn.init.xavier_uniform_(embedder.proj.weight)
            if embedder.proj.bias is not None:
                nn.init.constant_(embedder.proj.bias, 0)

        nn.init.xavier_uniform_(self.encoder_to_decoder.weight)
        nn.init.xavier_uniform_(self.a_to_b_proj.weight)

        nn.init.normal_(self.y_embedder.weight, std=0.02)

        with torch.no_grad():
            self.y_embedder.weight[0].fill_(0)

        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        nn.init.constant_(self.shared_encoder_adaLN[-1].weight, 0)
        nn.init.constant_(self.shared_encoder_adaLN[-1].bias, 0)
        nn.init.constant_(self.shared_decoder_adaLN[-1].weight, 0)
        nn.init.constant_(self.shared_decoder_adaLN[-1].bias, 0)

        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

        nn.init.xavier_uniform_(self.bc_projection.weight)

    def embed_condition(self, cond):
        # cond is a list of per-sample tag-ID lists; tags are truncated/padded
        # to max_tags, with index 0 reserved for padding.
        device = self.y_embedder.weight.device

        max_len = self.max_tags
        batch_size = len(cond)

        padded_tags = torch.zeros(batch_size, max_len, dtype=torch.long, device=device)

        for i, tags in enumerate(cond):
            truncated_tags = tags[:max_len]
            padded_tags[i, :len(truncated_tags)] = torch.tensor(truncated_tags, dtype=torch.long, device=device)

        padding_mask = (padded_tags != 0)

        embedded = self.y_embedder(padded_tags)

        return embedded, padding_mask

    def forward(self, x, t, y):
        y_emb, padding_mask = self.embed_condition(y)

        return self.forward_emb(x, t, y_emb, padding_mask)

    @torch.compile()
    def forward_emb(self, x, t, y_emb, padding_mask=None):
        B, _, H, W = x.shape

        h_patches = H // self.patch_h
        w_patches = W // self.patch_w
        encoder_pos = self.fetch_encoder_pos(h_patches, w_patches, x.device)
        decoder_pos = self.fetch_decoder_pos(h_patches, w_patches, x.device)

        t_emb = self.t_embedder(t.view(-1)).view(B, 1, self.hidden_size)
        t_cond = nn.functional.silu(t_emb)

        s = self.s_embedder(x)

        # Section A: the first block, at decoder width.
        s_section_a = s
        for i in range(min(1, self.num_blocks)):
            block_context = y_emb if i % 2 == 0 else None
            s_section_a = self.blocks[i](s_section_a, t_cond, decoder_pos, None, context=block_context, context_mask=padding_mask, shared_adaLN=self.shared_decoder_adaLN)

        # Section B: project down to encoder width and run the middle blocks.
        s_section_a_projected = self.a_to_b_proj(s_section_a)
        s_section_b = s_section_a_projected

        for i in range(1, self.num_blocks - 3):
            block_context = y_emb if i % 2 == 0 else None
            s_section_b = self.blocks[i](s_section_b, t_cond, encoder_pos, None, context=block_context, context_mask=padding_mask, shared_adaLN=self.shared_encoder_adaLN)

        # Fuse both sections and run the final decoder-width blocks.
        s_concat = torch.cat([s_section_a, s_section_b], dim=-1)
        s = self.bc_projection(s_concat)

        for i in range(max(1, self.num_blocks - 3), self.num_blocks):
            block_context = y_emb if i % 2 == 0 else None
            s = self.blocks[i](s, t_cond, decoder_pos, None, context=block_context, context_mask=padding_mask, shared_adaLN=self.shared_decoder_adaLN)

        s = self.final_layer(s, H // self.patch_h, W // self.patch_w)

        return s
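

# Minimal usage sketch (an assumption, not part of the training pipeline): it
# only demonstrates the expected input shapes and the conditioning format. The
# small configuration below is illustrative; the defaults above are the real
# ones. Constraints implied by the code: H must be divisible by patch_size[0]
# and W by patch_size[1], both hidden sizes by num_groups, the resulting head
# dims by 4 (for the 2D rotary embedding), and tag IDs lie in [1, num_classes]
# (0 is the padding index). forward_emb is wrapped in torch.compile(), so the
# first call also triggers compilation and needs a working compile backend.
if __name__ == "__main__":
    model = LocalSongModel(
        in_channels=8,
        num_groups=4,
        hidden_size=64,
        decoder_hidden_size=128,
        num_blocks=6,
        patch_size=(16, 1),
        num_classes=100,
        max_tags=4,
    )
    x = torch.randn(2, 8, 32, 4)   # (B, C, H, W) latent; H % 16 == 0, W % 1 == 0
    t = torch.rand(2)              # one timestep per sample
    y = [[1, 2, 3], [4]]           # per-sample tag-ID lists (padded to max_tags)
    with torch.no_grad():
        out = model(x, t, y)
    print(out.shape)               # expected: torch.Size([2, 8, 32, 4])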