| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import math |
| | from collections import OrderedDict |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.jit import Final |
| | from timm.models.vision_transformer import Attention, Mlp, RmsNorm, use_fused_attn |
| |
|
| |
|
| | |
| | |
| | |
class TimestepEmbedder(nn.Module):
    """
    Turn scalar (possibly fractional) timesteps into ``hidden_size``-wide vectors.

    Each timestep is first expanded into fixed sinusoidal features of width
    ``frequency_embedding_size`` (cast to ``dtype``), which are then fed through
    a small MLP: Linear -> SiLU -> Linear.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=torch.bfloat16):
        super().__init__()
        # forward() feeds the `dtype`-cast sinusoidal features straight into this
        # MLP, so its parameters are expected to live in the same dtype.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.dtype = dtype

    def timestep_embedding(self, t, dim, max_period=10000):
        """
        Build sinusoidal features for a batch of timesteps.

        :param t: 1-D tensor with one (possibly fractional) timestep per batch element.
        :param dim: width of the returned features.
        :param max_period: sets the lowest frequency, 1 / max_period.
        :return: (N, dim) tensor in ``self.dtype``. The first ``dim // 2`` columns
                 hold cosines and the next ``dim // 2`` hold sines; an odd ``dim``
                 gets one trailing zero column.
        """
        half = dim // 2
        # Geometric ladder of frequencies from 1 down towards 1 / max_period,
        # computed in float32 on the same device as `t`.
        steps = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        freqs = torch.exp(-math.log(max_period) * steps / half)

        phases = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat((phases.cos(), phases.sin()), dim=-1)
        if dim % 2:
            # Pad a single zero column so odd widths are honoured.
            emb = torch.cat((emb, torch.zeros_like(emb[:, :1])), dim=-1)
        return emb.to(self.dtype)

    def forward(self, t):
        """Embed the 1-D timestep tensor ``t`` into (N, hidden_size) vectors."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
| |
|
| |
|
| | |
| | |
| | |
class CrossAttention(nn.Module):
    """
    Multi-head cross-attention: queries come from ``x``, keys/values from ``c``.

    If timm's ``use_fused_attn()`` is true at construction time, attention is
    computed with ``F.scaled_dot_product_attention``. Otherwise it is computed
    explicitly as ``softmax((q * scale) @ k^T) @ v``.
    """
    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0,
        proj_drop: float = 0,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        # Submodule names define the state_dict keys, and their creation order
        # defines how parameter initialisation consumes the RNG; keep both.
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        # Optional per-head normalisation of queries and keys.
        make_norm = (lambda: norm_layer(head_dim)) if qk_norm else nn.Identity
        self.q_norm = make_norm()
        self.k_norm = make_norm()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, c: torch.Tensor,
                mask: torch.Tensor | None = None) -> torch.Tensor:
        """
        Attend from ``x`` to the context ``c``.

        :param x: query tokens of shape (B, N, C), with C == dim.
        :param c: context tokens of shape (B, L, dim); ``self.kv`` maps ``dim`` to ``2 * dim``.
        :param mask: optional mask with B * L elements (presumably (B, L)). It is
                     broadcast over heads and query positions. For a boolean mask,
                     True means "may attend" in both code paths. The behaviour for
                     non-boolean masks differs between paths — TODO confirm callers
                     only pass bool.
        :return: (B, N, C) tensor after the output projection.
        """
        B, N, C = x.shape
        _, ctx_len, _ = c.shape
        H, D = self.num_heads, self.head_dim

        # (B, N, C) -> (B, H, N, D)
        q = self.q(x).reshape(B, N, H, D).transpose(1, 2)
        # (B, L, 2C) -> (2, B, H, L, D) -> k, v each (B, H, L, D)
        k, v = self.kv(c).reshape(B, ctx_len, 2, H, D).permute(2, 0, 3, 1, 4).unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)

        if mask is not None:
            # One entry per context token, shared by every head and query.
            mask = mask.reshape(B, 1, 1, ctx_len).expand(-1, -1, N, -1)

        if self.fused_attn:
            out = F.scaled_dot_product_attention(
                query=q,
                key=k,
                value=v,
                dropout_p=self.attn_drop.p if self.training else 0.,
                attn_mask=mask,
            )
        else:
            scores = (q * self.scale) @ k.transpose(-2, -1)
            if mask is not None:
                scores = scores.masked_fill(mask.logical_not(), float('-inf'))
            probs = scores.softmax(dim=-1)
            if self.attn_drop.p > 0:
                probs = self.attn_drop(probs)
            out = probs @ v

        # (B, H, N, D) -> (B, N, C)
        out = self.proj(out.transpose(1, 2).reshape(B, N, C))
        if self.proj_drop.p > 0:
            out = self.proj_drop(out)
        return out
| |
|
| |
|
| | |
| | |
| | |
class RDTBlock(nn.Module):
    """
    One RDT transformer block with cross-attention conditioning.

    It applies three pre-norm residual sub-layers in order:
    self-attention over ``x``, cross-attention from ``x`` to the condition
    tokens ``c`` (optionally masked), and a feed-forward timm ``Mlp``.
    Each sub-layer has its own RmsNorm.
    """
    def __init__(self, hidden_size, num_heads, **block_kwargs):
        super().__init__()

        def tanh_gelu():
            # Activation factory handed to timm's Mlp.
            return nn.GELU(approximate="tanh")

        # Options shared by the self- and cross-attention layers.
        shared_attn_args = dict(
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=True,
            norm_layer=RmsNorm,
            **block_kwargs,
        )

        # Submodules are created in the original order, so parameter
        # initialisation consumes the RNG identically.
        self.norm1 = RmsNorm(hidden_size, eps=1e-6)
        self.attn = Attention(dim=hidden_size, **shared_attn_args)
        self.cross_attn = CrossAttention(hidden_size, **shared_attn_args)
        self.norm2 = RmsNorm(hidden_size, eps=1e-6)
        self.ffn = Mlp(
            in_features=hidden_size,
            hidden_features=hidden_size,
            act_layer=tanh_gelu,
            drop=0,
        )
        self.norm3 = RmsNorm(hidden_size, eps=1e-6)

    def forward(self, x, c, mask=None):
        """Run the block on tokens ``x`` conditioned on ``c`` (``mask`` goes to cross-attention)."""
        x = self.attn(self.norm1(x)) + x
        x = self.cross_attn(self.norm2(x), c, mask) + x
        x = self.ffn(self.norm3(x)) + x
        return x
| |
|
| |
|
class FinalLayer(nn.Module):
    """
    Output head of RDT.

    Normalises the tokens with RmsNorm, then projects them from ``hidden_size``
    to ``out_channels`` with a timm ``Mlp`` (hidden width ``hidden_size``,
    tanh-approximated GELU).
    """
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = RmsNorm(hidden_size, eps=1e-6)
        self.ffn_final = Mlp(
            in_features=hidden_size,
            hidden_features=hidden_size,
            out_features=out_channels,
            act_layer=lambda: nn.GELU(approximate="tanh"),
            drop=0,
        )

    def forward(self, x):
        """Normalise ``x`` and project it to ``out_channels`` features."""
        return self.ffn_final(self.norm_final(x))
| |
|
| |
|
| | |
| | |
| | |
| | |
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Fixed 1-D sin-cos embedding of a list of positions.

    embed_dim: output width per position; must be even (asserted).
    pos: positions to encode; anything ``np.array`` accepts. It is flattened
         to shape (M,).
    out: (M, embed_dim) array. The first half holds sines, the second half
         cosines, of ``pos * 10000 ** (-2i / embed_dim)``.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    exponents = np.arange(half, dtype=np.float64) / (embed_dim / 2.)
    inv_freq = 1. / 10000**exponents

    positions = pos if isinstance(pos, np.ndarray) else np.array(pos, dtype=np.float64)
    angles = np.outer(positions.reshape(-1), inv_freq)  # (M, D/2)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
| |
|
| |
|
def get_nd_sincos_pos_embed_from_grid(embed_dim, grid_sizes):
    """
    Build an N-D sin-cos positional embedding over a regular grid.

    The channel axis is split evenly among the grid dimensions whose size is
    > 1, and each share is rounded down to an even width. Each such dimension
    gets a 1-D sin-cos embedding of its index in its own channel slice,
    broadcast across the other dimensions. Dimensions of size <= 1 get no
    channels, and channels left over by the rounding stay zero.

    embed_dim: output dimension for each position
    grid_sizes: the grid sizes in each dimension (K,); any sequence of ints
    out: float64 array of shape (grid_sizes[0], ..., grid_sizes[K-1], embed_dim).
         It is all zeros when no dimension has size > 1.
    """
    # Accept lists as well as tuples (needed for the `+ (embed_dim,)` below).
    grid_sizes = tuple(grid_sizes)
    emb = np.zeros(grid_sizes + (embed_dim,))

    valid_axes = [axis for axis, size in enumerate(grid_sizes) if size > 1]
    if not valid_axes:
        # Nothing varies along any axis; avoid dividing by zero below.
        return emb

    dim_for_each_grid = embed_dim // len(valid_axes)
    # The 1-D sin-cos helper requires an even width.
    if dim_for_each_grid % 2 != 0:
        dim_for_each_grid -= 1

    for slot, axis in enumerate(valid_axes):
        grid_size = grid_sizes[axis]
        # Shape the 1-D table so it broadcasts along every other axis. Using the
        # explicit size (not -1) keeps the reshape valid when the width is 0.
        posemb_shape = [1] * len(grid_sizes) + [dim_for_each_grid]
        posemb_shape[axis] = grid_size
        start = slot * dim_for_each_grid
        emb[..., start:start + dim_for_each_grid] += \
            get_1d_sincos_pos_embed_from_grid(
                dim_for_each_grid, np.arange(grid_size)).reshape(posemb_shape)
    return emb
| |
|
| |
|
def get_multimodal_cond_pos_embed(embed_dim, mm_cond_lens: OrderedDict,
                                  embed_modality=True):
    """
    Generate position embeddings for multimodal conditions.

    mm_cond_lens: an OrderedDict containing
        (modality name, modality token length) pairs.
        For `"image"` modality, the value can be a multi-dimensional tuple/list.
        If a length is < 0, there is no positional variation for that modality
        or grid axis. A negative image axis is embedded as size 1 and broadcast
        over abs(length) entries. A negative 1-D length gives every one of its
        abs(length) tokens the position-0 embedding.
    embed_modality: whether to embed the modality information. Default is True.
        When True, the first ``embed_dim // 2`` channels hold a sin-cos code of
        the modality index and the last ``embed_dim // 2`` channels hold the
        within-modality position. When False, all channels hold the position.

    Returns a float64 array of shape (total_tokens, embed_dim). Modalities are
    stacked in the dict's iteration order.
    """
    num_modalities = len(mm_cond_lens)
    modality_pos_embed = np.zeros((num_modalities, embed_dim))
    if embed_modality:
        # Modality index is encoded in the first half of the channels.
        modality_pos_embed[:, :embed_dim // 2] = get_1d_sincos_pos_embed_from_grid(
            embed_dim // 2, np.arange(num_modalities))
        # The remaining half carries the token position.
        pos_embed_dim = embed_dim // 2
    else:
        # Only the token position is embedded, using all channels.
        pos_embed_dim = embed_dim

    # Position features occupy the trailing `pos_embed_dim` channels. An explicit
    # start index (rather than `-pos_embed_dim`) stays correct when it is 0.
    pos_start = embed_dim - pos_embed_dim

    pieces = [np.zeros((0, embed_dim))]
    for idx, (modality, cond_len) in enumerate(mm_cond_lens.items()):
        if modality == "image" and isinstance(cond_len, (tuple, list)):
            all_grid_sizes = tuple(abs(x) for x in cond_len)
            # Non-positive axes are embedded as size 1 and then broadcast.
            embed_grid_sizes = tuple(x if x > 0 else 1 for x in cond_len)
            cond_sincos_embed = get_nd_sincos_pos_embed_from_grid(
                pos_embed_dim, embed_grid_sizes)
            cond_pos_embed = np.zeros(all_grid_sizes + (embed_dim,))
            cond_pos_embed[..., pos_start:] += cond_sincos_embed
            cond_pos_embed = cond_pos_embed.reshape((-1, embed_dim))
        else:
            cond_sincos_embed = get_1d_sincos_pos_embed_from_grid(
                pos_embed_dim, np.arange(cond_len if cond_len > 0 else 1))
            cond_pos_embed = np.zeros((abs(cond_len), embed_dim))
            # A (1, D) table broadcasts over all rows when cond_len <= 0.
            cond_pos_embed[:, pos_start:] += cond_sincos_embed
        cond_pos_embed += modality_pos_embed[idx]
        pieces.append(cond_pos_embed)

    # Concatenate once instead of growing the array inside the loop.
    return np.concatenate(pieces, axis=0)
| |
|