from typing import Dict, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ..attention import BasicTransformerBlock, SkipFFTransformerBlock
from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from ..embeddings import TimestepEmbedding, get_timestep_embedding
from ..modeling_utils import ModelMixin
from ..normalization import GlobalResponseNorm, RMSNorm
from ..resnet import Downsample2D, Upsample2D


class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
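    """
    U-ViT backbone over a 2D grid of VQ token ids: a convolutional token embedding, a conv down block
    (ConvNeXt-style res blocks with skip-FF attention), a stack of transformer layers on the flattened
    feature map, a mirrored conv up block, and a 1x1-conv head that returns per-position logits over the
    codebook. Conditioning comes from the projected `encoder_hidden_states` (cross-attention) and from a
    pooled text embedding concatenated with sinusoidal micro-conditioning embeddings.
    """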
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        # global config
        hidden_size: int = 1024,
        use_bias: bool = False,
        hidden_dropout: float = 0.0,
        # conditioning dims
        cond_embed_dim: int = 768,
        micro_cond_encode_dim: int = 256,
        micro_cond_embed_dim: int = 1280,
        encoder_hidden_size: int = 768,
        # token vocabulary
        vocab_size: int = 8256,
        codebook_size: int = 8192,
        # `UVit2DConvEmbed` / conv blocks
        in_channels: int = 768,
        block_out_channels: int = 768,
        num_res_blocks: int = 3,
        downsample: bool = False,
        upsample: bool = False,
        block_num_heads: int = 12,
        # transformer layers
        num_hidden_layers: int = 22,
        num_attention_heads: int = 16,
        # attention
        attention_dropout: float = 0.0,
        # feed-forward
        intermediate_size: int = 2816,
        # norms
        layer_norm_eps: float = 1e-6,
        ln_elementwise_affine: bool = True,
        sample_size: int = 64,
    ):
        super().__init__()

        self.encoder_proj = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias)
        self.encoder_proj_layer_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)

        self.embed = UVit2DConvEmbed(
            in_channels, block_out_channels, vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias
        )

        self.cond_embed = TimestepEmbedding(
            micro_cond_embed_dim + cond_embed_dim, hidden_size, sample_proj_bias=use_bias
        )

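        # conv stage before the transformer: ConvNeXt-style res blocks interleaved with skip-FF attention,
        # plus an optional learned downsample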
        self.down_block = UVitBlock(
            block_out_channels,
            num_res_blocks,
            hidden_size,
            hidden_dropout,
            ln_elementwise_affine,
            layer_norm_eps,
            use_bias,
            block_num_heads,
            attention_dropout,
            downsample,
            False,
        )

        self.project_to_hidden_norm = RMSNorm(block_out_channels, layer_norm_eps, ln_elementwise_affine)
        self.project_to_hidden = nn.Linear(block_out_channels, hidden_size, bias=use_bias)

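        # transformer stack over the flattened feature map; pooled_text_emb conditions each block through
        # "ada_norm_continuous", encoder_hidden_states through cross-attention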
        self.transformer_layers = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_size,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=hidden_size // num_attention_heads,
                    dropout=hidden_dropout,
                    cross_attention_dim=hidden_size,
                    attention_bias=use_bias,
                    norm_type="ada_norm_continuous",
                    ada_norm_continous_conditioning_embedding_dim=hidden_size,
                    norm_elementwise_affine=ln_elementwise_affine,
                    norm_eps=layer_norm_eps,
                    ada_norm_bias=use_bias,
                    ff_inner_dim=intermediate_size,
                    ff_bias=use_bias,
                    attention_out_bias=use_bias,
                )
                for _ in range(num_hidden_layers)
            ]
        )

        self.project_from_hidden_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)
        self.project_from_hidden = nn.Linear(hidden_size, block_out_channels, bias=use_bias)

        self.up_block = UVitBlock(
            block_out_channels,
            num_res_blocks,
            hidden_size,
            hidden_dropout,
            ln_elementwise_affine,
            layer_norm_eps,
            use_bias,
            block_num_heads,
            attention_dropout,
            downsample=False,
            upsample=upsample,
        )

        self.mlm_layer = ConvMlmLayer(
            block_out_channels, in_channels, use_bias, ln_elementwise_affine, layer_norm_eps, codebook_size
        )

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        pass

    def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
        encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
        encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)

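        # each micro-conditioning value gets a sinusoidal embedding; the embeddings are concatenated with the
        # pooled text embedding and projected to the conditioning vector used by the conv and transformer blocks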
        micro_cond_embeds = get_timestep_embedding(
            micro_conds.flatten(), self.config.micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0
        )

        micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1))

        pooled_text_emb = torch.cat([pooled_text_emb, micro_cond_embeds], dim=1)
        pooled_text_emb = pooled_text_emb.to(dtype=self.dtype)
        pooled_text_emb = self.cond_embed(pooled_text_emb).to(encoder_hidden_states.dtype)

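        # look up the VQ token embeddings and project them to a 2D feature map of shape (B, C, H, W)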
        hidden_states = self.embed(input_ids)

        hidden_states = self.down_block(
            hidden_states,
            pooled_text_emb=pooled_text_emb,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
        )

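        # flatten the spatial grid into a token sequence (B, H*W, C) for the transformer stack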
        batch_size, channels, height, width = hidden_states.shape
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)

        hidden_states = self.project_to_hidden_norm(hidden_states)
        hidden_states = self.project_to_hidden(hidden_states)

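        # during training, optionally wrap each block in torch.utils.checkpoint to trade compute for memory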
        for layer in self.transformer_layers:
            if self.training and self.gradient_checkpointing:

                def layer_(*args):
                    return checkpoint(layer, *args)

            else:
                layer_ = layer

            hidden_states = layer_(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs={"pooled_text_emb": pooled_text_emb},
            )

        hidden_states = self.project_from_hidden_norm(hidden_states)
        hidden_states = self.project_from_hidden(hidden_states)

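        # restore the (B, C, H, W) layout for the conv up block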
        hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)

        hidden_states = self.up_block(
            hidden_states,
            pooled_text_emb=pooled_text_emb,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
        )

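        # 1x1-conv head: per-position logits over the VQ codebook, shape (B, codebook_size, H, W)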
        logits = self.mlm_layer(hidden_states)

        return logits

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)


class UVit2DConvEmbed(nn.Module):
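    """Embeds token ids, applies RMSNorm, and 1x1-convolves to `block_out_channels`, returning an NCHW feature map."""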
    def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, in_channels)
        self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine)
        self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias)

    def forward(self, input_ids):
        embeddings = self.embeddings(input_ids)
        embeddings = self.layer_norm(embeddings)
        embeddings = embeddings.permute(0, 3, 1, 2)
        embeddings = self.conv(embeddings)
        return embeddings


class UVitBlock(nn.Module):
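    """
    Conv stage of the U-ViT: an optional downsample, `num_res_blocks` ConvNeXt-style res blocks each followed
    by a `SkipFFTransformerBlock`, and an optional upsample.
    """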
    def __init__(
        self,
        channels,
        num_res_blocks: int,
        hidden_size,
        hidden_dropout,
        ln_elementwise_affine,
        layer_norm_eps,
        use_bias,
        block_num_heads,
        attention_dropout,
        downsample: bool,
        upsample: bool,
    ):
        super().__init__()

        if downsample:
            self.downsample = Downsample2D(
                channels,
                use_conv=True,
                padding=0,
                name="Conv2d_0",
                kernel_size=2,
                norm_type="rms_norm",
                eps=layer_norm_eps,
                elementwise_affine=ln_elementwise_affine,
                bias=use_bias,
            )
        else:
            self.downsample = None

        self.res_blocks = nn.ModuleList(
            [
                ConvNextBlock(
                    channels,
                    layer_norm_eps,
                    ln_elementwise_affine,
                    use_bias,
                    hidden_dropout,
                    hidden_size,
                )
                for _ in range(num_res_blocks)
            ]
        )

        self.attention_blocks = nn.ModuleList(
            [
                SkipFFTransformerBlock(
                    channels,
                    block_num_heads,
                    channels // block_num_heads,
                    hidden_size,
                    use_bias,
                    attention_dropout,
                    channels,
                    attention_bias=use_bias,
                    attention_out_bias=use_bias,
                )
                for _ in range(num_res_blocks)
            ]
        )

        if upsample:
            self.upsample = Upsample2D(
                channels,
                use_conv_transpose=True,
                kernel_size=2,
                padding=0,
                name="conv",
                norm_type="rms_norm",
                eps=layer_norm_eps,
                elementwise_affine=ln_elementwise_affine,
                bias=use_bias,
                interpolate=False,
            )
        else:
            self.upsample = None

    def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs):
        if self.downsample is not None:
            x = self.downsample(x)

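        # alternate res block and attention block; attention runs on the flattened (B, H*W, C) sequence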
        for res_block, attention_block in zip(self.res_blocks, self.attention_blocks):
            x = res_block(x, pooled_text_emb)

            batch_size, channels, height, width = x.shape
            x = x.view(batch_size, channels, height * width).permute(0, 2, 1)
            x = attention_block(
                x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs
            )
            x = x.permute(0, 2, 1).view(batch_size, channels, height, width)

        if self.upsample is not None:
            x = self.upsample(x)

        return x


class ConvNextBlock(nn.Module):
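    """
    ConvNeXt-style residual block: depthwise 3x3 conv, channelwise MLP with `GlobalResponseNorm`, a residual
    add, then a FiLM-like per-channel scale/shift derived from the pooled conditioning embedding.
    """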
    def __init__(
        self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4
    ):
        super().__init__()
        self.depthwise = nn.Conv2d(
            channels,
            channels,
            kernel_size=3,
            padding=1,
            groups=channels,
            bias=use_bias,
        )
        self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine)
        self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias)
        self.channelwise_act = nn.GELU()
        self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor))
        self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias)
        self.channelwise_dropout = nn.Dropout(hidden_dropout)
        self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)

    def forward(self, x, cond_embeds):
        x_res = x

        x = self.depthwise(x)

        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)

        x = self.channelwise_linear_1(x)
        x = self.channelwise_act(x)
        x = self.channelwise_norm(x)
        x = self.channelwise_linear_2(x)
        x = self.channelwise_dropout(x)

        x = x.permute(0, 3, 1, 2)

        x = x + x_res

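        # FiLM-like conditioning: the pooled embedding yields a per-channel scale and shift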
        scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
        x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]

        return x


class ConvMlmLayer(nn.Module):
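    """
    1x1-conv MLM head: projects features back to `in_channels`, RMS-normalizes, then maps each spatial
    position to `codebook_size` logits.
    """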
    def __init__(
        self,
        block_out_channels: int,
        in_channels: int,
        use_bias: bool,
        ln_elementwise_affine: bool,
        layer_norm_eps: float,
        codebook_size: int,
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias)
        self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine)
        self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias)

    def forward(self, hidden_states):
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        logits = self.conv2(hidden_states)
        return logits