import math

import torch
import torch.nn as nn
import torch.utils.checkpoint

from .utils.modules import PatchEmbed, TimestepEmbedder, PE_wrapper, RMSNorm
from .blocks import DiTBlock, JointDiTBlock, FinalBlock

class UDiT(nn.Module):
    """U-shaped Diffusion Transformer: a DiT with U-Net-style long skip
    connections between the first and second halves of its blocks."""
    def __init__(self,
                 img_size=224, patch_size=16, in_chans=3,
                 input_type='2d', out_chans=None,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
                 qkv_bias=False, qk_scale=None, qk_norm=None,
                 act_layer='gelu', norm_layer='layernorm',
                 context_norm=False,
                 use_checkpoint=False,
                 # time fusion
                 time_fusion='token',
                 ada_lora_rank=None, ada_lora_alpha=None,
                 cls_dim=None,
                 # context fusion
                 context_dim=768, context_fusion='concat',
                 context_max_length=128, context_pe_method='sinu',
                 pe_method='abs', rope_mode='none',
                 use_conv=True,
                 skip=True, skip_norm=True):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        # patch embedding (for 2d inputs, img_size is expected to be an
        # (H, W) tuple; for 1d inputs, a single length)
        self.in_chans = in_chans
        self.input_type = input_type
        if self.input_type == '2d':
            num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
        elif self.input_type == '1d':
            num_patches = img_size // patch_size
        else:
            raise NotImplementedError(f'input_type {input_type} is not supported')
        self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans,
                                      embed_dim=embed_dim, input_type=input_type)
        out_chans = in_chans if out_chans is None else out_chans
        self.out_chans = out_chans

        # positional embedding for the input sequence
        self.rope = rope_mode
        self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method,
                               length=num_patches)

        print(f'x position embedding: {pe_method}')
        print(f'rope mode: {self.rope}')

        # time step embedding
        self.time_embed = TimestepEmbedder(embed_dim)
        self.time_fusion = time_fusion
        self.use_adanorm = False

        # optional class-conditioning embedding
        if cls_dim is not None:
            self.cls_embed = nn.Sequential(
                nn.Linear(cls_dim, embed_dim, bias=True),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=True),)
        else:
            self.cls_embed = None

        # time fusion: either a prepended token or adaptive layer norm
        if time_fusion == 'token':
            # time token (and optional cls token) prepended to the sequence
            self.extras = 2 if self.cls_embed else 1
            self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras)
        elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']:
            self.use_adanorm = True
            # adaLN: the time embedding modulates the blocks instead of
            # being attended to as a token
            self.time_act = nn.SiLU()
            self.extras = 0
            self.time_ada_final = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
            if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']:
                # one shared modulation for all blocks (6 * embed_dim,
                # presumably shift/scale/gate for attn and mlp as in DiT)
                self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
            else:
                self.time_ada = None
        else:
            raise NotImplementedError
        print(f'time fusion mode: {self.time_fusion}')

        # context fusion
        self.use_context = False
        self.context_cross = False
        self.context_max_length = context_max_length
        self.context_fusion = 'none'
        if context_dim is not None:
            self.use_context = True
            self.context_embed = nn.Sequential(
                nn.Linear(context_dim, embed_dim, bias=True),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=True),)
            self.context_fusion = context_fusion
            if context_fusion in ('concat', 'joint'):
                self.extras += context_max_length
                self.context_pe = PE_wrapper(dim=embed_dim,
                                             method=context_pe_method,
                                             length=context_max_length)
                # context is folded into the main sequence, so the blocks
                # need no cross-attention branch
                context_dim = None
            elif context_fusion == 'cross':
                self.context_pe = PE_wrapper(dim=embed_dim,
                                             method=context_pe_method,
                                             length=context_max_length)
                self.context_cross = True
                context_dim = embed_dim
            else:
                raise NotImplementedError
            print(f'context fusion mode: {context_fusion}')
            print(f'context position embedding: {context_pe_method}')

        if self.context_fusion == 'joint':
            Block = JointDiTBlock
            # joint blocks expect a pair of skip flags (x skip, context skip)
            self.use_skip = skip[0]
        else:
            Block = DiTBlock
            self.use_skip = skip

        if norm_layer == 'layernorm':
            norm_layer = nn.LayerNorm
        elif norm_layer == 'rmsnorm':
            norm_layer = RMSNorm
        else:
            raise NotImplementedError

        print(f'use long skip connection: {skip}')
        # first half of the U: no skip inputs
        self.in_blocks = nn.ModuleList([
            Block(
                dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
                act_layer=act_layer, norm_layer=norm_layer,
                time_fusion=time_fusion,
                ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
                skip=False, skip_norm=False,
                rope_mode=self.rope,
                context_norm=context_norm,
                use_checkpoint=use_checkpoint)
            for _ in range(depth // 2)])

        self.mid_block = Block(
            dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
            act_layer=act_layer, norm_layer=norm_layer,
            time_fusion=time_fusion,
            ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
            skip=False, skip_norm=False,
            rope_mode=self.rope,
            context_norm=context_norm,
            use_checkpoint=use_checkpoint)

        # second half of the U: receives the long skip connections
        self.out_blocks = nn.ModuleList([
            Block(
                dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
                act_layer=act_layer, norm_layer=norm_layer,
                time_fusion=time_fusion,
                ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
                skip=skip, skip_norm=skip_norm,
                rope_mode=self.rope,
                context_norm=context_norm,
                use_checkpoint=use_checkpoint)
            for _ in range(depth // 2)])

        # final block: project tokens back to patches / output channels
        self.use_conv = use_conv
        self.final_block = FinalBlock(embed_dim=embed_dim,
                                      patch_size=patch_size,
                                      img_size=img_size,
                                      in_chans=out_chans,
                                      input_type=input_type,
                                      norm_layer=norm_layer,
                                      use_conv=use_conv,
                                      use_adanorm=self.use_adanorm)
        self.initialize_weights()

    def _init_ada(self):
        # adaLN-zero: zero-init all modulation layers so each block starts
        # as (approximately) an identity mapping, following DiT
        if self.time_fusion == 'ada':
            nn.init.constant_(self.time_ada_final.weight, 0)
            nn.init.constant_(self.time_ada_final.bias, 0)
            for block in self.in_blocks:
                nn.init.constant_(block.adaln.time_ada.weight, 0)
                nn.init.constant_(block.adaln.time_ada.bias, 0)
            nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0)
            nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0)
            for block in self.out_blocks:
                nn.init.constant_(block.adaln.time_ada.weight, 0)
                nn.init.constant_(block.adaln.time_ada.bias, 0)
        elif self.time_fusion == 'ada_single':
            nn.init.constant_(self.time_ada.weight, 0)
            nn.init.constant_(self.time_ada.bias, 0)
            nn.init.constant_(self.time_ada_final.weight, 0)
            nn.init.constant_(self.time_ada_final.bias, 0)
        elif self.time_fusion in ['ada_lora', 'ada_lora_bias']:
            nn.init.constant_(self.time_ada.weight, 0)
            nn.init.constant_(self.time_ada.bias, 0)
            nn.init.constant_(self.time_ada_final.weight, 0)
            nn.init.constant_(self.time_ada_final.bias, 0)
            # LoRA branches: kaiming-init A, zero-init B, so each low-rank
            # update starts at zero
            for block in self.in_blocks:
                nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
                                         a=math.sqrt(5))
                nn.init.constant_(block.adaln.lora_b.weight, 0)
            nn.init.kaiming_uniform_(self.mid_block.adaln.lora_a.weight,
                                     a=math.sqrt(5))
            nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0)
            for block in self.out_blocks:
                nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
                                         a=math.sqrt(5))
                nn.init.constant_(block.adaln.lora_b.weight, 0)

    def initialize_weights(self):
        # default init: xavier-uniform for every linear layer
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # init patch embedding like a linear layer rather than a conv
        w = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.patch_embed.proj.bias, 0)

        # zero-init the adaptive norm modulation layers
        if self.use_adanorm:
            self._init_ada()

        # zero-init the cross-attention output projection so that
        # conditioning starts as a no-op
        if self.context_cross:
            for block in self.in_blocks:
                nn.init.constant_(block.cross_attn.proj.weight, 0)
                nn.init.constant_(block.cross_attn.proj.bias, 0)
            nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0)
            nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0)
            for block in self.out_blocks:
                nn.init.constant_(block.cross_attn.proj.weight, 0)
                nn.init.constant_(block.cross_attn.proj.bias, 0)

        # zero-init the last cls-embedding layer when it feeds adaLN
        if self.cls_embed:
            if self.use_adanorm:
                nn.init.constant_(self.cls_embed[-1].weight, 0)
                nn.init.constant_(self.cls_embed[-1].bias, 0)

        # xavier-init the final conv layer
        if self.use_conv:
            nn.init.xavier_uniform_(self.final_block.final_layer.weight)
            nn.init.constant_(self.final_block.final_layer.bias, 0)

    def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
        assert context.shape[-2] == self.context_max_length
        B = x.shape[0]
        # default to all-valid masks when none are provided
        if x_mask is None:
            x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
        if context_mask is None:
            context_mask = torch.ones(B, context.shape[-2],
                                      device=context.device).bool()
        # prepend the context tokens (and their mask) to the sequence
        x_mask = torch.cat([context_mask, x_mask], dim=1)
        x = torch.cat((context, x), dim=1)
        return x, x_mask

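    # Shape sketch (illustrative numbers, not from the original source):
    # with B=2, 100 patch tokens, context_max_length=128, embed_dim=768,
    #   x: (2, 100, 768), context: (2, 128, 768)
    # _concat_x_context returns
    #   x: (2, 228, 768) and x_mask: (2, 228), with context tokens first.
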
    def forward(self, x, timesteps, context,
                x_mask=None, context_mask=None,
                cls_token=None, controlnet_skips=None,
                ):
        # broadcast a scalar timestep over the batch
        if timesteps.dim() == 0:
            timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)

        # patchify the input and add its positional embedding
        x = self.patch_embed(x)
        x = self.x_pe(x)

        B, L, D = x.shape

        # embed the context and fuse it with the input sequence
        if self.use_context:
            context_token = self.context_embed(context)
            context_token = self.context_pe(context_token)
            if self.context_fusion in ('concat', 'joint'):
                x, x_mask = self._concat_x_context(x=x, context=context_token,
                                                   x_mask=x_mask,
                                                   context_mask=context_mask)
                context_token, context_mask = None, None
        else:
            context_token, context_mask = None, None

        # time (and optional cls) conditioning
        time_token = self.time_embed(timesteps)
        if self.cls_embed:
            cls_token = self.cls_embed(cls_token)
        time_ada = None
        time_ada_final = None
        if self.use_adanorm:
            if self.cls_embed:
                time_token = time_token + cls_token
            time_token = self.time_act(time_token)
            time_ada_final = self.time_ada_final(time_token)
            if self.time_ada is not None:
                time_ada = self.time_ada(time_token)
        else:
            # 'token' fusion: prepend the time (and cls) tokens to x
            time_token = time_token.unsqueeze(dim=1)
            if self.cls_embed:
                cls_token = cls_token.unsqueeze(dim=1)
                time_token = torch.cat([time_token, cls_token], dim=1)
            time_token = self.time_pe(time_token)
            x = torch.cat((time_token, x), dim=1)
            if x_mask is not None:
                x_mask = torch.cat(
                    [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(),
                     x_mask], dim=1)
            time_token = None

        # first half of the U: cache activations for the long skips
        skips = []
        for blk in self.in_blocks:
            x = blk(x=x, time_token=time_token, time_ada=time_ada,
                    skip=None, context=context_token,
                    x_mask=x_mask, context_mask=context_mask,
                    extras=self.extras)
            if self.use_skip:
                skips.append(x)

        x = self.mid_block(x=x, time_token=time_token, time_ada=time_ada,
                           skip=None, context=context_token,
                           x_mask=x_mask, context_mask=context_mask,
                           extras=self.extras)

        # second half of the U: consume the cached skips, optionally
        # adding controlnet features on the way
        for blk in self.out_blocks:
            if self.use_skip:
                skip = skips.pop()
                if controlnet_skips:
                    skip = skip + controlnet_skips.pop()
            else:
                skip = None
                if controlnet_skips:
                    x = x + controlnet_skips.pop()

            x = blk(x=x, time_token=time_token, time_ada=time_ada,
                    skip=skip, context=context_token,
                    x_mask=x_mask, context_mask=context_mask,
                    extras=self.extras)

        # unpatchify back to the output shape
        x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)

        return x
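

# Minimal usage sketch (the hyper-parameters and shapes below are
# illustrative assumptions, not values prescribed by this module):
#
#     model = UDiT(img_size=256, patch_size=4, in_chans=8, input_type='1d',
#                  embed_dim=768, depth=12, num_heads=12,
#                  time_fusion='ada_single',
#                  context_dim=768, context_fusion='cross',
#                  context_max_length=128)
#     x = torch.randn(2, 8, 256)           # (B, C, L) latent input
#     t = torch.randint(0, 1000, (2,))     # per-sample diffusion timesteps
#     ctx = torch.randn(2, 128, 768)       # (B, context_max_length, context_dim)
#     out = model(x, t, ctx)               # expected (B, out_chans, L)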