import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from .utils.attention import Attention, JointAttention
from .utils.modules import unpatchify, FeedForward, film_modulate


class AdaLN(nn.Module):
    """
    Adaptive LayerNorm modulation: produces the six per-block modulation
    vectors (shift/scale/gate for the attention and MLP branches) from a
    timestep embedding, under several time-fusion modes.
    """

    def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
        super().__init__()
        self.ada_mode = ada_mode
        self.scale_shift_table = None
        if ada_mode == 'ada':
            # per-block adaLN: each block owns a full dim -> 6*dim projection
            self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
        elif ada_mode == 'ada_single':
            # adaLN-single (PixArt-style): one projection is shared across
            # blocks; each block only learns a zero-initialized bias table
            self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
        elif ada_mode in ['ada_lora', 'ada_lora_bias']:
            # low-rank (LoRA-style) per-block correction on top of the
            # shared projection
            self.lora_a = nn.Linear(dim, r * 6, bias=False)
            self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
            self.scaling = alpha / r
            if ada_mode == 'ada_lora_bias':
                # additionally learn a per-block bias table
                self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
        else:
            raise NotImplementedError

    def forward(self, time_token=None, time_ada=None):
        if self.ada_mode == 'ada':
            # per-block mode computes the modulation from the time token
            assert time_ada is None
            B = time_token.shape[0]
            time_ada = self.time_ada(time_token).reshape(B, 6, -1)
        elif self.ada_mode == 'ada_single':
            # shared mode receives `time_ada` precomputed outside the block
            B = time_ada.shape[0]
            time_ada = time_ada.reshape(B, 6, -1)
            time_ada = self.scale_shift_table[None] + time_ada
        elif self.ada_mode in ['ada_lora', 'ada_lora_bias']:
            B = time_ada.shape[0]
            time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
            time_ada = time_ada + time_ada_lora
            time_ada = time_ada.reshape(B, 6, -1)
            if self.scale_shift_table is not None:
                time_ada = self.scale_shift_table[None] + time_ada
        else:
            raise NotImplementedError
        return time_ada
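

# Usage sketch (illustrative; dim/batch values are assumptions, not from the
# original code). In 'ada' mode the (B, dim) time token is projected to
# (B, 6, dim), later split into shift/scale/gate for attention and MLP:
#
#   adaln = AdaLN(dim=256, ada_mode='ada')
#   time_token = torch.randn(4, 256)
#   time_ada = adaln(time_token=time_token)  # -> (4, 6, 256)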


class DiTBlock(nn.Module):
    """
    A modified PixArt block with adaptive layer norm (adaLN-single)
    conditioning: self-attention, optional cross-attention to a context
    sequence, and an MLP, with an optional long-skip branch (concat + linear).
    """

    def __init__(self, dim, context_dim=None,
                 num_heads=8, mlp_ratio=4.,
                 qkv_bias=False, qk_scale=None, qk_norm=None,
                 act_layer='gelu', norm_layer=nn.LayerNorm,
                 time_fusion='none',
                 ada_lora_rank=None, ada_lora_alpha=None,
                 skip=False, skip_norm=False,
                 rope_mode='none',
                 context_norm=False,
                 use_checkpoint=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim=dim,
                              num_heads=num_heads,
                              qkv_bias=qkv_bias, qk_scale=qk_scale,
                              qk_norm=qk_norm,
                              rope_mode=rope_mode)

        if context_dim is not None:
            self.use_context = True
            self.cross_attn = Attention(dim=dim,
                                        num_heads=num_heads,
                                        context_dim=context_dim,
                                        qkv_bias=qkv_bias, qk_scale=qk_scale,
                                        qk_norm=qk_norm,
                                        rope_mode='none')
            self.norm2 = norm_layer(dim)
            if context_norm:
                self.norm_context = norm_layer(context_dim)
            else:
                self.norm_context = nn.Identity()
        else:
            self.use_context = False

        self.norm3 = norm_layer(dim)
        self.mlp = FeedForward(dim=dim, mult=mlp_ratio,
                               activation_fn=act_layer, dropout=0)

        # any mode other than plain token concatenation uses adaLN modulation
        self.use_adanorm = time_fusion != 'token'
        if self.use_adanorm:
            self.adaln = AdaLN(dim, ada_mode=time_fusion,
                               r=ada_lora_rank, alpha=ada_lora_alpha)
        if skip:
            # long skip: concatenate the skipped features and project back
            self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity()
            self.skip_linear = nn.Linear(2 * dim, dim)
        else:
            self.skip_linear = None

        self.use_checkpoint = use_checkpoint

    def forward(self, x, time_token=None, time_ada=None,
                skip=None, context=None,
                x_mask=None, context_mask=None, extras=None):
        if self.use_checkpoint:
            # gradient checkpointing trades recomputation for activation memory
            return checkpoint(self._forward, x,
                              time_token, time_ada, skip, context,
                              x_mask, context_mask, extras,
                              use_reentrant=False)
        else:
            return self._forward(x,
                                 time_token, time_ada, skip, context,
                                 x_mask, context_mask, extras)

    def _forward(self, x, time_token=None, time_ada=None,
                 skip=None, context=None,
                 x_mask=None, context_mask=None, extras=None):
        B, T, C = x.shape
        if self.skip_linear is not None:
            assert skip is not None
            cat = torch.cat([x, skip], dim=-1)
            cat = self.skip_norm(cat)
            x = self.skip_linear(cat)

        if self.use_adanorm:
            time_ada = self.adaln(time_token, time_ada)
            (shift_msa, scale_msa, gate_msa,
             shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)

        # self-attention; note the (1 - gate) convention: a zero gate lets
        # the branch through at full strength
        if self.use_adanorm:
            x_norm = film_modulate(self.norm1(x), shift=shift_msa,
                                   scale=scale_msa)
            x = x + (1 - gate_msa) * self.attn(x_norm, context=None,
                                               context_mask=x_mask,
                                               extras=extras)
        else:
            x = x + self.attn(self.norm1(x), context=None, context_mask=x_mask,
                              extras=extras)

        # cross-attention to the (optionally normalized) context sequence
        if self.use_context:
            assert context is not None
            x = x + self.cross_attn(x=self.norm2(x),
                                    context=self.norm_context(context),
                                    context_mask=context_mask, extras=extras)

        # MLP
        if self.use_adanorm:
            x_norm = film_modulate(self.norm3(x), shift=shift_mlp,
                                   scale=scale_mlp)
            x = x + (1 - gate_mlp) * self.mlp(x_norm)
        else:
            x = x + self.mlp(self.norm3(x))

        return x
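

# Usage sketch (illustrative; all sizes are assumptions, not from the
# original code). A block with adaLN-single time fusion and cross-attention
# to a text context:
#
#   block = DiTBlock(dim=256, context_dim=768, num_heads=8,
#                    time_fusion='ada_single')
#   x = torch.randn(2, 100, 256)          # (B, T, C) latent tokens
#   time_ada = torch.randn(2, 6 * 256)    # from the shared adaLN projection
#   context = torch.randn(2, 77, 768)     # e.g. text encoder states
#   y = block(x, time_ada=time_ada, context=context)  # -> (2, 100, 256)
#
# film_modulate is assumed to apply x * (1 + scale) + shift per token.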


class JointDiTBlock(nn.Module):
    """
    A DiT block with joint attention: the input sequence packs a context
    stream and a latent stream; each stream keeps its own norms and MLP,
    while self-attention runs jointly over both.
    """

    def __init__(self, dim, context_dim=None,
                 num_heads=8, mlp_ratio=4.,
                 qkv_bias=False, qk_scale=None, qk_norm=None,
                 act_layer='gelu', norm_layer=nn.LayerNorm,
                 time_fusion='none',
                 ada_lora_rank=None, ada_lora_alpha=None,
                 skip=(False, False),
                 rope_mode='none',
                 context_norm=False,  # accepted but unused in this block
                 use_checkpoint=False):
        super().__init__()

        # the context stream is packed into `x`, so no separate context_dim
        assert context_dim is None
        self.attn_norm_x = norm_layer(dim)
        self.attn_norm_c = norm_layer(dim)
        self.attn = JointAttention(dim=dim,
                                   num_heads=num_heads,
                                   qkv_bias=qkv_bias, qk_scale=qk_scale,
                                   qk_norm=qk_norm,
                                   rope_mode=rope_mode)
        self.ffn_norm_x = norm_layer(dim)
        self.ffn_norm_c = norm_layer(dim)
        self.mlp_x = FeedForward(dim=dim, mult=mlp_ratio,
                                 activation_fn=act_layer, dropout=0)
        self.mlp_c = FeedForward(dim=dim, mult=mlp_ratio,
                                 activation_fn=act_layer, dropout=0)

        # any mode other than plain token concatenation uses adaLN modulation
        self.use_adanorm = time_fusion != 'token'
        if self.use_adanorm:
            self.adaln = AdaLN(dim, ada_mode=time_fusion,
                               r=ada_lora_rank, alpha=ada_lora_alpha)

        # per-stream long-skip switches: (latent stream, context stream)
        if skip is False:
            skip_x, skip_c = False, False
        else:
            skip_x, skip_c = skip

        self.skip_linear_x = nn.Linear(2 * dim, dim) if skip_x else None
        self.skip_linear_c = nn.Linear(2 * dim, dim) if skip_c else None

        self.use_checkpoint = use_checkpoint

    def forward(self, x, time_token=None, time_ada=None,
                skip=None, context=None,
                x_mask=None, context_mask=None, extras=None):
        if self.use_checkpoint:
            # gradient checkpointing trades recomputation for activation memory
            return checkpoint(self._forward, x,
                              time_token, time_ada, skip,
                              context, x_mask, context_mask, extras,
                              use_reentrant=False)
        else:
            return self._forward(x,
                                 time_token, time_ada, skip,
                                 context, x_mask, context_mask, extras)

    def _forward(self, x, time_token=None, time_ada=None,
                 skip=None, context=None,
                 x_mask=None, context_mask=None, extras=None):

        # the joint sequence packs [context, x]; a separate context input
        # is not supported
        assert context is None and context_mask is None

        # split the packed sequence at the `extras` boundary
        context, x = x[:, :extras, :], x[:, extras:, :]
        context_mask, x_mask = x_mask[:, :extras], x_mask[:, extras:]

        if skip is not None:
            skip_c, skip_x = skip[:, :extras, :], skip[:, extras:, :]

        B, T, C = x.shape
        if self.skip_linear_x is not None:
            x = self.skip_linear_x(torch.cat([x, skip_x], dim=-1))

        if self.skip_linear_c is not None:
            context = self.skip_linear_c(torch.cat([context, skip_c], dim=-1))

        if self.use_adanorm:
            time_ada = self.adaln(time_token, time_ada)
            (shift_msa, scale_msa, gate_msa,
             shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)

        # joint attention over both streams; only the latent stream is
        # modulated and gated
        x_norm = self.attn_norm_x(x)
        c_norm = self.attn_norm_c(context)
        if self.use_adanorm:
            x_norm = film_modulate(x_norm, shift=shift_msa, scale=scale_msa)
        x_out, c_out = self.attn(x_norm, context=c_norm,
                                 x_mask=x_mask, context_mask=context_mask,
                                 extras=extras)
        if self.use_adanorm:
            x = x + (1 - gate_msa) * x_out
        else:
            x = x + x_out
        context = context + c_out

        # per-stream MLPs
        if self.use_adanorm:
            x_norm = film_modulate(self.ffn_norm_x(x),
                                   shift=shift_mlp, scale=scale_mlp)
            x = x + (1 - gate_mlp) * self.mlp_x(x_norm)
        else:
            x = x + self.mlp_x(self.ffn_norm_x(x))

        c_norm = self.ffn_norm_c(context)
        context = context + self.mlp_c(c_norm)

        # re-pack [context, x] for the next block
        return torch.cat((context, x), dim=1)
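

# Usage sketch (illustrative; all sizes are assumptions). The block expects
# context and latent tokens packed along the sequence axis, with `extras`
# giving the context length:
#
#   block = JointDiTBlock(dim=256, num_heads=8, time_fusion='ada_single')
#   packed = torch.randn(2, 77 + 100, 256)       # [context | latent]
#   mask = torch.ones(2, 177, dtype=torch.bool)
#   time_ada = torch.randn(2, 6 * 256)
#   out = block(packed, time_ada=time_ada, x_mask=mask, extras=77)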


class FinalBlock(nn.Module):
    """Final norm + linear projection back to patches, with an optional
    conv head on the unpatchified output."""

    def __init__(self, embed_dim, patch_size, in_chans,
                 img_size,
                 input_type='2d',
                 norm_layer=nn.LayerNorm,
                 use_conv=True,
                 use_adanorm=True):
        super().__init__()
        self.in_chans = in_chans
        self.img_size = img_size
        self.input_type = input_type

        self.norm = norm_layer(embed_dim)
        self.use_adanorm = use_adanorm

        if input_type == '2d':
            self.patch_dim = patch_size ** 2 * in_chans
            self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
            if use_conv:
                self.final_layer = nn.Conv2d(self.in_chans, self.in_chans,
                                             3, padding=1)
            else:
                self.final_layer = nn.Identity()

        elif input_type == '1d':
            self.patch_dim = patch_size * in_chans
            self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
            if use_conv:
                self.final_layer = nn.Conv1d(self.in_chans, self.in_chans,
                                             3, padding=1)
            else:
                self.final_layer = nn.Identity()

    def forward(self, x, time_ada=None, extras=0):
        B, T, C = x.shape
        # drop prepended extra tokens (e.g. time/context tokens)
        x = x[:, extras:, :]

        if self.use_adanorm:
            # the final layer uses only shift and scale (no gate)
            shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
            x = film_modulate(self.norm(x), shift, scale)
        else:
            x = self.norm(x)
        x = self.linear(x)
        x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
        x = self.final_layer(x)
        return x
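

# Usage sketch (illustrative; sizes and the exact unpatchify output layout
# are assumptions). For 1d input with patch_size=1, FinalBlock maps
# (B, T, embed_dim) tokens back to a channel-first signal:
#
#   final = FinalBlock(embed_dim=256, patch_size=1, in_chans=8,
#                      img_size=100, input_type='1d', use_adanorm=True)
#   x = torch.randn(2, 100, 256)
#   time_ada = torch.randn(2, 2 * 256)   # shift/scale for the final norm
#   out = final(x, time_ada=time_ada)    # expected: (2, 8, 100)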