| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import math |
| from typing import Dict, List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin |
| from diffusers.models.attention_dispatch import dispatch_attention_fn |
| from diffusers.models.attention_processor import Attention, AttentionProcessor |
| from diffusers.models.modeling_outputs import Transformer2DModelOutput |
| from diffusers.models.modeling_utils import ModelMixin |
| from diffusers.models.normalization import RMSNorm |
| from diffusers.utils import ( |
| is_torch_version, |
| ) |
| from diffusers.utils.torch_utils import maybe_allow_in_graph |
| from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
# Width of the conditioning vector fed to every adaLN modulation MLP; layers
# use min(dim, ADALN_EMBED_DIM) as the modulation input size.
ADALN_EMBED_DIM = 256
# Token sequences (captions and image patches) are right-padded so their
# length is a multiple of this value.
SEQ_MULTI_OF = 32
|
|
|
|
def zero_module(module):
    """
    Zero out every parameter of *module* in place and return it.

    Commonly used to initialize residual/control projections so they start as
    identity (no contribution) at the beginning of training.

    Args:
        module (nn.Module): The module whose parameters are overwritten.

    Returns:
        nn.Module: The same module instance with all parameters set to zero.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
|
|
|
|
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar diffusion timesteps: a fixed sinusoidal frequency encoding
    followed by a two-layer SiLU MLP projecting to the conditioning width.
    """

    def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
        """
        Args:
            out_size (int): Output dimension of the embedding.
            mid_size (int, optional): Hidden width of the MLP. Defaults to ``out_size``.
            frequency_embedding_size (int, optional): Dimension of the sinusoidal
                encoding fed into the MLP. Defaults to 256.
        """
        super().__init__()
        mid_size = out_size if mid_size is None else mid_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, mid_size, bias=True),
            nn.SiLU(),
            nn.Linear(mid_size, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal timestep embeddings (cos halves first, then sin).

        Args:
            t (torch.Tensor): 1-D tensor of N timesteps.
            dim (int): Embedding dimension; odd dims get a trailing zero column.
            max_period (int, optional): Base period of the frequencies. Defaults to 10000.

        Returns:
            torch.Tensor: Float32 embeddings of shape (N, dim).
        """
        # Force full precision for the trigonometry regardless of autocast state.
        with torch.amp.autocast("cuda", enabled=False):
            half = dim // 2
            exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
            freqs = torch.exp(-math.log(max_period) * exponents)
            angles = t[:, None] * freqs[None]
            embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
            if dim % 2:
                # Pad odd target dims with a zero column.
                embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
            return embedding

    def forward(self, t):
        """
        Embed timesteps *t* through the sinusoidal encoding and the MLP.

        Args:
            t (torch.Tensor): 1-D tensor of timesteps.

        Returns:
            torch.Tensor: Embeddings of shape (N, out_size) in the MLP's dtype.
        """
        freq_emb = self.timestep_embedding(t, self.frequency_embedding_size)
        target_dtype = self.mlp[0].weight.dtype
        # Match the MLP's floating dtype (skip for quantized/integer weights).
        if target_dtype.is_floating_point:
            freq_emb = freq_emb.to(target_dtype)
        return self.mlp(freq_emb)
|
|
|
|
class FeedForward(nn.Module):
    """
    SwiGLU feed-forward network: ``w2(silu(w1(x)) * w3(x))``.
    """

    def __init__(self, dim: int, hidden_dim: int):
        """
        Args:
            dim (int): Input and output feature dimension.
            hidden_dim (int): Hidden dimension of the gated projection.
        """
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def _forward_silu_gating(self, x1, x3):
        """
        Gate ``x3`` with the SiLU activation of ``x1``.

        Args:
            x1 (torch.Tensor): Activation branch.
            x3 (torch.Tensor): Gate branch.

        Returns:
            torch.Tensor: ``silu(x1) * x3``.
        """
        return F.silu(x1) * x3

    def forward(self, x):
        """
        Apply the SwiGLU feed-forward transform.

        Args:
            x (torch.Tensor): Input of shape (..., dim).

        Returns:
            torch.Tensor: Output of shape (..., dim).
        """
        gated = self._forward_silu_gating(self.w1(x), self.w3(x))
        return self.w2(gated)
|
|
|
|
class FinalLayer(nn.Module):
    """
    Output head of the transformer: scale-only AdaLN modulation over a
    parameter-free LayerNorm, followed by a linear projection to the patch
    dimension.
    """

    def __init__(self, hidden_size, out_channels):
        """
        Args:
            hidden_size (int): Width of the incoming hidden states.
            out_channels (int): Output dimension (flattened patch size).
        """
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        # Conditioning input is capped at ADALN_EMBED_DIM, matching the other
        # adaLN projections in this file.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
        )

    def forward(self, x, c):
        """
        Args:
            x (torch.Tensor): Hidden states of shape (batch, seq, hidden_size).
            c (torch.Tensor): Conditioning vector (e.g. timestep embedding).

        Returns:
            torch.Tensor: Projected output of shape (batch, seq, out_channels).
        """
        modulation = self.adaLN_modulation(c)
        # Scale-only modulation; unsqueeze broadcasts over the sequence axis.
        x = self.norm_final(x) * (1.0 + modulation).unsqueeze(1)
        return self.linear(x)
|
|
|
|
class RopeEmbedder:
    """
    Rotary positional embedding (RoPE) lookup for 3D (frame, height, width)
    coordinates, with per-device caching of the precomputed cos/sin tables.
    """

    def __init__(self, theta: float = 256.0, axes_dims: List[int] = (32, 48, 48), axes_lens: List[int] = (1024, 512, 512)):
        """
        Args:
            theta (float, optional): Base of the rotary frequencies. Defaults to 256.0.
            axes_dims (List[int], optional): Rotary dimension budget per axis
                (F, H, W). Defaults to (32, 48, 48).
            axes_lens (List[int], optional): Maximum coordinate per axis.
                Defaults to (1024, 512, 512).
        """
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        # device -> list of per-axis (max_len, dim // 2, 2) cos/sin tables
        self.freqs_cis_cache = {}

    def _precompute_freqs_cis(self, device):
        """
        Build (or fetch from cache) the cos/sin tables for each axis.

        Args:
            device (torch.device): Device on which the tables live.

        Returns:
            List[torch.Tensor]: One (max_len, dim // 2, 2) table per axis, where
            the last dimension holds (cos, sin).
        """
        cached = self.freqs_cis_cache.get(device)
        if cached is not None:
            return cached
        tables = []
        for dim, max_len in zip(self.axes_dims, self.axes_lens):
            half = dim // 2
            inv_freq = 1.0 / (self.theta ** (torch.arange(0, half, device=device, dtype=torch.float32) / half))
            positions = torch.arange(max_len, device=device, dtype=torch.float32)
            angles = torch.outer(positions, inv_freq)
            tables.append(torch.stack([angles.cos(), angles.sin()], dim=-1))
        self.freqs_cis_cache[device] = tables
        return tables

    def __call__(self, ids: torch.Tensor):
        """
        Look up RoPE values for a batch of 3D coordinates.

        Args:
            ids (torch.Tensor): Integer coordinates of shape (N, 3).

        Returns:
            torch.Tensor: Concatenated cos/sin values of shape
            (N, sum(axes_dims) // 2, 2).
        """
        assert ids.ndim == 2 and ids.shape[1] == len(self.axes_dims)
        tables = self._precompute_freqs_cis(ids.device)
        per_axis = [table[ids[:, axis]] for axis, table in enumerate(tables)]
        return torch.cat(per_axis, dim=-2)
|
|
|
|
class ZSingleStreamAttnProcessor:
    """
    An attention processor that applies Rotary Positional Embeddings (RoPE) to query and key tensors
    before computing scaled dot-product attention.

    Self-attention only: q, k and v are all projected from ``hidden_states``;
    attention itself is routed through diffusers' ``dispatch_attention_fn`` so
    the backend can be swapped globally.
    """

    # Optional overrides consumed by dispatch_attention_fn; None selects the
    # default backend and no attention parallelism.
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        """
        Initializes the ZSingleStreamAttnProcessor.

        Raises:
            ImportError: If PyTorch < 2.0 is installed (no ``F.scaled_dot_product_attention``).
        """
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        The forward call for the attention processor.

        Args:
            attn (Attention): The attention layer that this processor is attached to.
            hidden_states (torch.Tensor): The input hidden states.
            encoder_hidden_states (Optional[torch.Tensor], optional): Accepted for
                interface compatibility; not used (self-attention only). Defaults to None.
            attention_mask (Optional[torch.Tensor], optional): The attention mask;
                a 2-D mask is broadcast over heads and query positions. Defaults to None.
            freqs_cis (Optional[torch.Tensor], optional): Precomputed RoPE cos/sin
                values (last dim: [cos, sin]). Defaults to None.

        Returns:
            torch.Tensor: The output of the attention mechanism.
        """

        def apply_rotary_emb(q_or_k: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
            """
            Applies RoPE to a query or key tensor.

            Treats the last dimension as interleaved (even, odd) pairs and rotates
            each pair by the per-position angle in ``freqs_cis``. Math is done in
            float32 and cast back to the input dtype.
            """
            # (B, S, H, D) -> (B, H, S, D) so positions line up with freqs_cis.
            x = q_or_k.transpose(1, 2)
            x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2)
            x0 = x_reshaped[..., 0]
            x1 = x_reshaped[..., 1]
            # unsqueeze(1) broadcasts the tables across the head dimension.
            freqs_cos = freqs_cis[..., 0].unsqueeze(1)
            freqs_sin = freqs_cis[..., 1].unsqueeze(1)
            # Standard 2D rotation of each (x0, x1) pair.
            x_rotated_0 = x0 * freqs_cos - x1 * freqs_sin
            x_rotated_1 = x0 * freqs_sin + x1 * freqs_cos
            x_rotated = torch.stack((x_rotated_0, x_rotated_1), dim=-1)
            # Re-interleave pairs and restore the (B, S, H, D) layout.
            x_out = x_rotated.flatten(-2).transpose(1, 2)
            return x_out.to(q_or_k.dtype)

        # Self-attention projections from the single input stream.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # (B, S, heads * head_dim) -> (B, S, heads, head_dim)
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        # Optional per-head query/key normalization before RoPE.
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if freqs_cis is not None:
            query = apply_rotary_emb(query, freqs_cis)
            key = apply_rotary_emb(key, freqs_cis)

        # Expand a (B, S) mask to (B, 1, 1, S) so it broadcasts over heads and
        # query positions inside SDPA.
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = attention_mask[:, None, None, :]

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )

        # Merge heads back: (B, S, heads, head_dim) -> (B, S, heads * head_dim).
        hidden_states = hidden_states.flatten(2, 3)

        # NOTE(review): `.to(hidden_states.dtype)` is a no-op (a tensor cast to
        # its own dtype) — presumably a cast to the residual/projection dtype
        # was intended; confirm before changing.
        output = attn.to_out[0](hidden_states.to(hidden_states.dtype))
        if len(attn.to_out) > 1:
            output = attn.to_out[1](output)

        return output
|
|
|
|
@maybe_allow_in_graph
class ZImageTransformerBlock(nn.Module):
    """
    Transformer block: RoPE self-attention followed by a SwiGLU feed-forward
    network, each wrapped in a pre-norm / post-norm RMSNorm sandwich. When
    ``modulation`` is enabled, an AdaLN projection of the conditioning vector
    supplies per-branch scales and tanh gates.
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
    ):
        """
        Args:
            layer_id (int): Index of this layer within its stack.
            dim (int): Feature dimension of inputs and outputs.
            n_heads (int): Number of attention heads.
            n_kv_heads (int): Number of key/value heads (kept for interface
                compatibility; the attention layer here uses ``n_heads`` throughout).
            norm_eps (float): Epsilon for all RMSNorms.
            qk_norm (bool): If True, apply RMSNorm to queries and keys.
            modulation (bool, optional): Enable AdaLN modulation. Defaults to True.
        """
        super().__init__()
        self.layer_id = layer_id
        self.dim = dim
        self.head_dim = dim // n_heads
        self.modulation = modulation

        self.attention = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // n_heads,
            heads=n_heads,
            qk_norm="rms_norm" if qk_norm else None,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=ZSingleStreamAttnProcessor(),
        )
        # SwiGLU FFN with an 8/3 expansion ratio.
        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))

        # Pre-norms (before attention/FFN) and post-norms (on the branch outputs).
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        if modulation:
            # Projects the conditioning vector to (scale_msa, gate_msa,
            # scale_mlp, gate_mlp), each of width `dim`.
            self.adaLN_modulation = nn.Sequential(
                nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
            )

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        """
        Returns a dict mapping fully-qualified child-module names to their
        attention processors, collected recursively.
        """
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        """
        Sets the attention processor(s) for this block. Accepts either a single
        processor instance (applied to every attention layer) or a dict keyed
        like :attr:`attn_processors`.

        Raises:
            ValueError: If a dict is passed whose size does not match the number
                of attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if isinstance(processor, dict):
                    module.set_processor(processor.pop(f"{name}.processor"))
                else:
                    module.set_processor(processor)
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(self, x, attn_mask, freqs_cis, adaln_input=None):
        """
        Args:
            x (torch.Tensor): Input token sequence.
            attn_mask (torch.Tensor): Attention mask forwarded to the attention layer.
            freqs_cis (torch.Tensor): Precomputed RoPE cos/sin values.
            adaln_input (torch.Tensor, optional): Conditioning vector for AdaLN;
                required when ``self.modulation`` is True. Defaults to None.

        Returns:
            torch.Tensor: Output with the same shape as ``x``.
        """
        if not self.modulation:
            # Unconditioned variant: plain residual branches.
            attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
            x = x + self.attention_norm2(attn_out)
            x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
            return x

        # unsqueeze(1) broadcasts the modulation over the sequence axis.
        scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
        scale_msa = scale_msa + 1.0
        scale_mlp = scale_mlp + 1.0
        gate_msa = gate_msa.tanh()
        gate_mlp = gate_mlp.tanh()

        attn_out = self.attention(self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis)
        x = x + self.attention_norm2(attn_out) * gate_msa

        ffn_out = self.feed_forward(self.ffn_norm1(x) * scale_mlp)
        x = x + self.ffn_norm2(ffn_out) * gate_mlp
        return x
|
|
|
|
class ZImageControlTransformerBlock(ZImageTransformerBlock):
    """
    Control-pathway transformer block. Behaves like :class:`ZImageTransformerBlock`
    but carries a running stack of per-layer control "hints": each block appends
    a zero-initialized skip projection of its output, and block 0 additionally
    fuses the control signal with the main-pathway reference.
    """

    def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, qk_norm: bool, modulation=True, block_id=0):
        """
        Args:
            layer_id (int): Index of the layer.
            dim (int): Feature dimension.
            n_heads (int): Number of attention heads.
            n_kv_heads (int): Number of key/value heads.
            norm_eps (float): Epsilon for RMSNorm.
            qk_norm (bool): Whether to normalize queries and keys.
            modulation (bool, optional): Enable AdaLN modulation. Defaults to True.
            block_id (int, optional): Index of this control block; block 0 owns
                the input fusion projection. Defaults to 0.
        """
        super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
        self.block_id = block_id
        # Zero-init so the control pathway starts as a no-op.
        if block_id == 0:
            self.before_proj = zero_module(nn.Linear(self.dim, self.dim))
        self.after_proj = zero_module(nn.Linear(self.dim, self.dim))

    def forward(self, c, x, **kwargs):
        """
        Args:
            c (torch.Tensor): Control signal; for block 0 this is the raw control
                state, for later blocks a stacked tensor of accumulated hints
                plus the running control state (last entry).
            x (torch.Tensor): Reference tensor from the main pathway.
            **kwargs: Forwarded to the parent block's ``forward``.

        Returns:
            torch.Tensor: Stack of all hint tensors so far plus the updated
            control state as the final entry.
        """
        if self.block_id == 0:
            # First block: fuse the control signal with the main-pathway input.
            c = self.before_proj(c) + x
            accumulated = []
        else:
            # Later blocks: unpack prior hints; the running state is last.
            accumulated = list(torch.unbind(c))
            c = accumulated.pop(-1)

        c = super().forward(c, **kwargs)
        accumulated += [self.after_proj(c), c]
        return torch.stack(accumulated)
|
|
|
|
class BaseZImageTransformerBlock(ZImageTransformerBlock):
    """
    Main-pathway transformer block. Identical to :class:`ZImageTransformerBlock`
    except that, when control hints are provided, the hint matching this block's
    ``block_id`` is added (scaled) to the output.
    """

    def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, qk_norm: bool, modulation=True, block_id=0):
        """
        Args:
            layer_id (int): Index of the layer.
            dim (int): Feature dimension.
            n_heads (int): Number of attention heads.
            n_kv_heads (int): Number of key/value heads.
            norm_eps (float): Epsilon for RMSNorm.
            qk_norm (bool): Whether to normalize queries and keys.
            modulation (bool, optional): Enable AdaLN modulation. Defaults to True.
            block_id (int, optional): Index into the hint stack, or None when this
                layer receives no hint. Defaults to 0.
        """
        super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
        self.block_id = block_id

    def forward(self, hidden_states, hints=None, context_scale=1.0, **kwargs):
        """
        Args:
            hidden_states (torch.Tensor): Input tensor.
            hints (List[torch.Tensor], optional): Control hints from the control
                pathway. Defaults to None.
            context_scale (float, optional): Multiplier applied to the injected
                hint. Defaults to 1.0.
            **kwargs: Forwarded to the parent block's ``forward``.

        Returns:
            torch.Tensor: Output tensor, with the control hint added when available.
        """
        out = super().forward(hidden_states, **kwargs)
        if hints is not None and self.block_id is not None:
            out = out + hints[self.block_id] * context_scale
        return out
|
|
|
|
| class ZImageControlTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): |
| _supports_gradient_checkpointing = True |
| _keys_to_ignore_on_load_unexpected = [ |
| r"control_layers\..*", |
| r"control_noise_refiner\..*", |
| r"control_all_x_embedder\..*", |
| ] |
| _no_split_modules = ["ZImageTransformerBlock", "BaseZImageTransformerBlock", "ZImageControlTransformerBlock"] |
| _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] |
| _group_offload_block_modules = ["t_embedder", "cap_embedder"] |
|
|
    @register_to_config
    def __init__(
        self,
        control_layers_places=None,
        control_refiner_layers_places=None,
        control_in_dim=None,
        add_control_noise_refiner=False,
        all_patch_size=(2,),
        all_f_patch_size=(1,),
        in_channels=16,
        dim=3840,
        n_layers=30,
        n_refiner_layers=2,
        n_heads=30,
        n_kv_heads=30,
        norm_eps=1e-5,
        qk_norm=True,
        cap_feat_dim=2560,
        rope_theta=256.0,
        t_scale=1000.0,
        axes_dims=[32, 48, 48],
        axes_lens=[1024, 512, 512],
        use_controlnet=True,
        checkpoint_ratio=0.5,
    ):
        """
        Initializes the ZImageControlTransformer2DModel.

        NOTE(review): ``axes_dims`` / ``axes_lens`` use mutable list defaults;
        they are only read here, so this is benign, but tuples would be safer.

        Args:
            control_layers_places (List[int], optional): Indices of main layers where control hints are injected.
            control_refiner_layers_places (List[int], optional): Indices of noise refiner layers for two-stage control.
            control_in_dim (int, optional): Input channel dimension for the control context.
            add_control_noise_refiner (bool, optional): Whether to add a dedicated refiner for the control signal.
            all_patch_size (Tuple[int], optional): Tuple of patch sizes for spatial dimensions.
            all_f_patch_size (Tuple[int], optional): Tuple of patch sizes for the frame dimension.
            in_channels (int, optional): Number of input channels for the latent image.
            dim (int, optional): The main dimension of the transformer model.
            n_layers (int, optional): The number of main transformer layers.
            n_refiner_layers (int, optional): The number of layers in the refiner blocks.
            n_heads (int, optional): The number of attention heads.
            n_kv_heads (int, optional): The number of key/value heads.
            norm_eps (float, optional): Epsilon for RMSNorm.
            qk_norm (bool, optional): Whether to apply normalization to query and key.
            cap_feat_dim (int, optional): The dimension of the input caption features.
            rope_theta (float, optional): The base for RoPE.
            t_scale (float, optional): A scaling factor for the timestep.
            axes_dims (List[int], optional): Dimensions for each axis in RoPE.
            axes_lens (List[int], optional): Maximum lengths for each axis in RoPE.
            use_controlnet (bool, optional): If False, control-related layers will not be created to save memory.
            checkpoint_ratio (float, optional): The ratio of layers to apply gradient checkpointing to.
        """
        super().__init__()
        self.use_controlnet = use_controlnet
        self.in_channels = in_channels
        # Latents are reconstructed with the same channel count they came in with.
        self.out_channels = in_channels
        self.all_patch_size = all_patch_size
        self.all_f_patch_size = all_f_patch_size
        self.dim = dim
        # Control branch defaults to the main model width when not specified.
        self.control_in_dim = self.dim if control_in_dim is None else control_in_dim
        # Control inputs wider than the 16-channel latent are treated as a
        # two-stage control signal (see the control_noise_refiner setup below).
        self.is_two_stage_control = self.control_in_dim > 16
        self.n_heads = n_heads
        self.rope_theta = rope_theta
        self.t_scale = t_scale
        self.gradient_checkpointing = False
        self.checkpoint_ratio = checkpoint_ratio
        # Spatial and frame patch sizes are paired one-to-one.
        assert len(all_patch_size) == len(all_f_patch_size)

        # Defaults: inject hints into every other main layer and into every
        # noise-refiner layer.
        self.control_layers_places = list(range(0, n_layers, 2)) if control_layers_places is None else control_layers_places
        self.control_refiner_layers_places = list(range(0, n_refiner_layers)) if control_refiner_layers_places is None else control_refiner_layers_places
        self.add_control_noise_refiner = add_control_noise_refiner
        # Layer 0 must receive a hint: the control branch seeds its state in the
        # block with block_id == 0 (see ZImageControlTransformerBlock).
        assert 0 in self.control_layers_places
        # Maps: main/refiner layer index -> position in the hint stack.
        self.control_layers_mapping = {i: n for n, i in enumerate(self.control_layers_places)}
        self.control_refiner_layers_mapping = {i: n for n, i in enumerate(self.control_refiner_layers_places)}

        # One patch embedder per (patch_size, f_patch_size) combination, keyed
        # by the "P-F" string used throughout the model.
        self.all_x_embedder = nn.ModuleDict(
            {
                f"{patch_size}-{f_patch_size}": nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
                for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size)
            }
        )

        # Matching output heads projecting back to flattened patches.
        self.all_final_layer = nn.ModuleDict(
            {
                f"{patch_size}-{f_patch_size}": FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
                for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size)
            }
        )

        # Caption-context refiner stack (no AdaLN modulation).
        self.context_refiner = nn.ModuleList(
            [ZImageTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=False) for i in range(n_refiner_layers)]
        )
        self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
        self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))
        # Learned pad tokens for image and caption sequences. NOTE(review):
        # torch.empty leaves these uninitialized — presumably filled by a
        # loaded checkpoint or an external init; confirm before training from scratch.
        self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
        self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))

        # RoPE axes must exactly partition the per-head dimension.
        head_dim = dim // n_heads
        assert head_dim == sum(axes_dims)
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)

        # Main transformer stack; block_id is None for layers without a hint.
        self.layers = nn.ModuleList(
            [BaseZImageTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=self.control_layers_mapping.get(i)) for i in range(n_layers)]
        )

        # Noise refiner stack; layer_id offset by 1000 to distinguish from the
        # main layers.
        self.noise_refiner = nn.ModuleList(
            [
                BaseZImageTransformerBlock(
                    1000 + i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True, block_id=self.control_refiner_layers_mapping.get(i)
                )
                for i in range(n_refiner_layers)
            ]
        )

        if self.use_controlnet:
            # Control branch: one control block per injection site.
            self.control_layers = nn.ModuleList(
                [ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i) for i in self.control_layers_places]
            )
            self.control_all_x_embedder = nn.ModuleDict(
                {
                    f"{patch_size}-{f_patch_size}": nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True)
                    for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size)
                }
            )

            if self.is_two_stage_control:
                # Two-stage control: optional dedicated control refiner.
                if self.add_control_noise_refiner:
                    self.control_noise_refiner = nn.ModuleList(
                        [
                            ZImageControlTransformerBlock(1000 + layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True, block_id=layer_id)
                            for layer_id in range(n_refiner_layers)
                        ]
                    )
                else:
                    self.control_noise_refiner = None
            else:
                # Single-stage control: plain (non-control) refiner blocks.
                self.control_noise_refiner = nn.ModuleList(
                    [ZImageTransformerBlock(1000 + i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True) for i in range(n_refiner_layers)]
                )
        else:
            # Controlnet disabled: skip building control modules to save memory.
            self.control_layers = None
            self.control_all_x_embedder = None
            self.control_noise_refiner = None
|
|
| def _unpatchify(self, x_image_tokens: torch.Tensor, all_sizes: List[Tuple], patch_size: int, f_patch_size: int) -> torch.Tensor: |
| """ |
| Converts a sequence of image tokens back into a batched image tensor. This version is robust |
| to batches containing images of different original sizes. |
| |
| Args: |
| x_image_tokens (torch.Tensor): A tensor of image tokens with shape [B, SeqLen, Dim]. |
| all_sizes (List[Tuple]): A list of tuples with the original (F, H, W) size for each image in the batch. |
| patch_size (int): The spatial patch size (height and width). |
| f_patch_size (int): The frame/temporal patch size. |
| |
| Returns: |
| torch.Tensor: The reconstructed latent tensor with shape [B, C, F, H, W]. |
| """ |
| pH = pW = patch_size |
| pF = f_patch_size |
| batch_size = x_image_tokens.shape[0] |
| unpatched_images = [] |
|
|
| for i in range(batch_size): |
| F, H, W = all_sizes[i] |
| F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW |
| original_seq_len = F_tokens * H_tokens * W_tokens |
| current_image_tokens = x_image_tokens[i, :original_seq_len, :] |
| unpatched_image = current_image_tokens.view(F_tokens, H_tokens, W_tokens, pF, pH, pW, self.out_channels) |
| unpatched_image = unpatched_image.permute(6, 0, 3, 1, 4, 2, 5).reshape(self.out_channels, F, H, W) |
| unpatched_images.append(unpatched_image) |
|
|
| try: |
| final_tensor = torch.stack(unpatched_images, dim=0) |
| except RuntimeError: |
| raise ValueError( |
| "Could not stack unpatched images into a single batch tensor. " |
| "This typically occurs if you are trying to generate images of different sizes in the same batch." |
| ) |
|
|
| return final_tensor |
|
|
| def _patchify( |
| self, |
| all_image: List[torch.Tensor], |
| patch_size: int, |
| f_patch_size: int, |
| cap_padding_len: int, |
| ): |
| """ |
| Converts a list of image tensors into patch sequences and computes their positional IDs. |
| |
| Args: |
| all_image (List[torch.Tensor]): A list of image tensors to process. |
| patch_size (int): The spatial patch size. |
| f_patch_size (int): The frame/temporal patch size. |
| cap_padding_len (int): The length of the padded caption sequence, used as an offset for image position IDs. |
| |
| Returns: |
| Tuple: A tuple containing lists of processed patches, sizes, position IDs, and padding masks. |
| """ |
| pH = pW = patch_size |
| pF = f_patch_size |
| device = all_image[0].device |
|
|
| all_image_out = [] |
| all_image_size = [] |
| all_image_pos_ids = [] |
| all_image_pad_mask = [] |
|
|
| for i, image in enumerate(all_image): |
| C, F, H, W = image.size() |
| all_image_size.append((F, H, W)) |
| F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW |
|
|
| image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) |
| image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) |
|
|
| image_ori_len = len(image) |
| image_padding_len = (-image_ori_len) % SEQ_MULTI_OF |
|
|
| image_ori_pos_ids = self._create_coordinate_grid( |
| size=(F_tokens, H_tokens, W_tokens), |
| start=(cap_padding_len + 1, 0, 0), |
| device=device, |
| ).flatten(0, 2) |
| image_padding_pos_ids = ( |
| self._create_coordinate_grid( |
| size=(1, 1, 1), |
| start=(0, 0, 0), |
| device=device, |
| ) |
| .flatten(0, 2) |
| .repeat(image_padding_len, 1) |
| ) |
| image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) |
| all_image_pos_ids.append(image_padded_pos_ids) |
| all_image_pad_mask.append( |
| torch.cat( |
| [ |
| torch.zeros((image_ori_len,), dtype=torch.bool, device=device), |
| torch.ones((image_padding_len,), dtype=torch.bool, device=device), |
| ], |
| dim=0, |
| ) |
| ) |
| image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) |
| all_image_out.append(image_padded_feat) |
|
|
| return ( |
| all_image_out, |
| all_image_size, |
| all_image_pos_ids, |
| all_image_pad_mask, |
| ) |
|
|
| def _patchify_and_embed( |
| self, |
| all_image: List[torch.Tensor], |
| all_cap_feats: List[torch.Tensor], |
| patch_size: int, |
| f_patch_size: int, |
| ): |
| """ |
| Processes a batch of images and caption features by converting them into padded patch sequences |
| and generating their corresponding positional IDs and padding masks. This is the general-purpose, |
| robust version that iterates through the batch. |
| |
| Args: |
| all_image (List[torch.Tensor]): A list of image tensors. |
| all_cap_feats (List[torch.Tensor]): A list of caption feature tensors. |
| patch_size (int): The spatial patch size. |
| f_patch_size (int): The frame/temporal patch size. |
| |
| Returns: |
| Tuple: A tuple containing all processed data structures (image patches, caption features, sizes, |
| position IDs, and padding masks) as lists. |
| """ |
| pH = pW = patch_size |
| pF = f_patch_size |
| device = all_image[0].device |
|
|
| all_image_out, all_image_size, all_image_pos_ids, all_image_pad_mask = [], [], [], [] |
| all_cap_pos_ids, all_cap_pad_mask, all_cap_feats_out = [], [], [] |
|
|
| for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): |
| cap_ori_len = len(cap_feat) |
| cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF |
| cap_total_len = cap_ori_len + cap_padding_len |
|
|
| cap_padded_pos_ids = self._create_coordinate_grid(size=(cap_total_len, 1, 1), start=(1, 0, 0), device=device).flatten(0, 2) |
| all_cap_pos_ids.append(cap_padded_pos_ids) |
|
|
| cap_mask = torch.ones(cap_total_len, dtype=torch.bool, device=device) |
| cap_mask[:cap_ori_len] = False |
| all_cap_pad_mask.append(cap_mask) |
|
|
| if cap_padding_len > 0: |
| padding_tensor = cap_feat[-1:].repeat(cap_padding_len, 1) |
| cap_padded_feat = torch.cat([cap_feat, padding_tensor], dim=0) |
| else: |
| cap_padded_feat = cap_feat |
| all_cap_feats_out.append(cap_padded_feat) |
|
|
| C, Fr, H, W = image.size() |
| all_image_size.append((Fr, H, W)) |
| F_tokens, H_tokens, W_tokens = Fr // pF, H // pH, W // pW |
|
|
| image_reshaped = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW).permute(1, 3, 5, 2, 4, 6, 0).reshape(-1, pF * pH * pW * C) |
|
|
| image_ori_len = image_reshaped.shape[0] |
| image_padding_len = (-image_ori_len) % SEQ_MULTI_OF |
| image_total_len = image_ori_len + image_padding_len |
|
|
| image_ori_pos_ids = self._create_coordinate_grid(size=(F_tokens, H_tokens, W_tokens), start=(cap_total_len + 1, 0, 0), device=device).flatten(0, 2) |
| if image_padding_len > 0: |
| image_padding_pos_ids = torch.zeros((image_padding_len, 3), dtype=torch.int32, device=device) |
| image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) |
| else: |
| image_padded_pos_ids = image_ori_pos_ids |
| all_image_pos_ids.append(image_padded_pos_ids) |
|
|
| image_mask = torch.ones(image_total_len, dtype=torch.bool, device=device) |
| image_mask[:image_ori_len] = False |
| all_image_pad_mask.append(image_mask) |
|
|
| if image_padding_len > 0: |
| padding_tensor = image_reshaped[-1:].repeat(image_padding_len, 1) |
| image_padded_feat = torch.cat([image_reshaped, padding_tensor], dim=0) |
| else: |
| image_padded_feat = image_reshaped |
| all_image_out.append(image_padded_feat) |
|
|
| return ( |
| all_image_out, |
| all_cap_feats_out, |
| all_image_size, |
| all_image_pos_ids, |
| all_cap_pos_ids, |
| all_image_pad_mask, |
| all_cap_pad_mask, |
| ) |
|
|
    def _process_cap_feats_with_cfg_cache(self, cap_feats_list, cap_pos_ids, cap_inner_pad_mask):
        """
        Processes caption features with intelligent duplicate detection to avoid redundant computation,
        especially for Classifier-Free Guidance (CFG) where prompts are repeated.

        Args:
            cap_feats_list (List[torch.Tensor]): List of padded caption feature tensors.
            cap_pos_ids (List[torch.Tensor]): List of corresponding position ID tensors.
            cap_inner_pad_mask (List[torch.Tensor]): List of corresponding padding masks.

        Returns:
            Tuple: A tuple of batched tensors for padded features, RoPE frequencies, attention mask, and sequence lengths.
        """
        device = cap_feats_list[0].device
        bsz = len(cap_feats_list)

        # The dedup fast path requires uniform shapes (torch.equal below needs
        # equal shapes) and a batch of at least two items to have any benefit.
        shapes_equal = all(c.shape == cap_feats_list[0].shape for c in cap_feats_list)

        if shapes_equal and bsz >= 2:
            # tensor_mapping[i] is the index into `unique_tensors` of the
            # unique caption equal to batch item i.
            unique_indices = [0]
            unique_tensors = [cap_feats_list[0]]
            tensor_mapping = [0]

            # O(bsz^2) pairwise comparison — acceptable for the small batch
            # sizes typical of CFG (cond/uncond pairs).
            for i in range(1, bsz):
                found_match = False
                for j, unique_tensor in enumerate(unique_tensors):
                    if torch.equal(cap_feats_list[i], unique_tensor):
                        tensor_mapping.append(j)
                        found_match = True
                        break

                if not found_match:
                    unique_indices.append(i)
                    unique_tensors.append(cap_feats_list[i])
                    tensor_mapping.append(len(unique_tensors) - 1)

            # Only take the cached path when duplicates actually exist.
            if len(unique_tensors) < bsz:
                unique_cap_feats_list = [cap_feats_list[i] for i in unique_indices]
                unique_cap_pos_ids = [cap_pos_ids[i] for i in unique_indices]
                unique_cap_inner_pad_mask = [cap_inner_pad_mask[i] for i in unique_indices]

                cap_item_seqlens_unique = [len(i) for i in unique_cap_feats_list]
                cap_max_item_seqlen = max(cap_item_seqlens_unique)

                # Embed only the unique captions; overwrite inner-padding
                # positions with the learned pad token.
                cap_feats_cat = torch.cat(unique_cap_feats_list, dim=0)
                cap_feats_embedded = self.cap_embedder(cap_feats_cat)
                cap_feats_embedded[torch.cat(unique_cap_inner_pad_mask)] = self.cap_pad_token
                cap_feats_padded_unique = pad_sequence(list(cap_feats_embedded.split(cap_item_seqlens_unique, dim=0)), batch_first=True, padding_value=0.0)

                # RoPE frequencies for the unique captions only.
                cap_freqs_cis_cat = self.rope_embedder(torch.cat(unique_cap_pos_ids, dim=0))
                cap_freqs_cis_unique = pad_sequence(list(cap_freqs_cis_cat.split(cap_item_seqlens_unique, dim=0)), batch_first=True, padding_value=0.0)

                # Fancy-index the unique results back into full batch order.
                cap_feats_padded = cap_feats_padded_unique[tensor_mapping]
                cap_freqs_cis = cap_freqs_cis_unique[tensor_mapping]

                # Shapes are uniform on this path, so every item gets the same
                # (full-length, all-True) attention mask.
                seq_lens_tensor = torch.tensor([cap_max_item_seqlen] * bsz, device=device, dtype=torch.int32)
                arange = torch.arange(cap_max_item_seqlen, device=device, dtype=torch.int32)
                cap_attn_mask = arange[None, :] < seq_lens_tensor[:, None]

                cap_item_seqlens = [cap_max_item_seqlen] * bsz

                return cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens

        # Generic path: embed every caption independently (ragged lengths OK).
        cap_item_seqlens = [len(i) for i in cap_feats_list]
        cap_max_item_seqlen = max(cap_item_seqlens)
        cap_feats_cat = torch.cat(cap_feats_list, dim=0)
        cap_feats_embedded = self.cap_embedder(cap_feats_cat)
        cap_feats_embedded[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
        cap_feats_padded = pad_sequence(list(cap_feats_embedded.split(cap_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)

        cap_freqs_cis_cat = self.rope_embedder(torch.cat(cap_pos_ids, dim=0))
        cap_freqs_cis = pad_sequence(list(cap_freqs_cis_cat.split(cap_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)

        # Boolean mask: True where the position is a real token for that item.
        seq_lens_tensor = torch.tensor(cap_item_seqlens, device=device, dtype=torch.int32)
        arange = torch.arange(cap_max_item_seqlen, device=device, dtype=torch.int32)
        cap_attn_mask = arange[None, :] < seq_lens_tensor[:, None]

        return cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens
|
|
| @staticmethod |
| def _create_coordinate_grid(size, start=None, device=None): |
| """ |
| Creates a 3D coordinate grid. |
| |
| Args: |
| size (Tuple[int]): The dimensions of the grid (F, H, W). |
| start (Tuple[int], optional): The starting coordinates for each axis. Defaults to (0, 0, 0). |
| device (torch.device, optional): The device to create the tensor on. Defaults to None. |
| |
| Returns: |
| torch.Tensor: The coordinate grid tensor. |
| """ |
| if start is None: |
| start = (0 for _ in size) |
| axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] |
| grids = torch.meshgrid(axes, indexing="ij") |
| return torch.stack(grids, dim=-1) |
|
|
| def _apply_transformer_blocks(self, hidden_states, layers, checkpoint_ratio=0.5, **kwargs): |
| """ |
| Applies a list of transformer layers to the hidden states, with optional selective gradient checkpointing. |
| |
| Args: |
| hidden_states (torch.Tensor): The input tensor. |
| layers (nn.ModuleList): The list of transformer layers to apply. |
| checkpoint_ratio (float, optional): The ratio of layers to apply gradient checkpointing to. Defaults to 0.5. |
| **kwargs: Additional keyword arguments to pass to each layer's forward method. |
| |
| Returns: |
| torch.Tensor: The output tensor after applying all layers. |
| """ |
| if torch.is_grad_enabled() and self.gradient_checkpointing: |
|
|
| def create_custom_forward(module, **static_kwargs): |
| def custom_forward(*inputs): |
| return module(*inputs, **static_kwargs) |
|
|
| return custom_forward |
|
|
| ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} |
|
|
| checkpoint_every_n = max(1, int(1.0 / checkpoint_ratio)) if checkpoint_ratio > 0 else len(layers) + 1 |
|
|
| for i, layer in enumerate(layers): |
| if i % checkpoint_every_n == 0: |
| hidden_states = torch.utils.checkpoint.checkpoint( |
| create_custom_forward(layer, **kwargs), |
| hidden_states, |
| **ckpt_kwargs, |
| ) |
| else: |
| hidden_states = layer(hidden_states, **kwargs) |
| else: |
| for layer in layers: |
| hidden_states = layer(hidden_states, **kwargs) |
|
|
| return hidden_states |
|
|
| def _prepare_control_inputs(self, control_context, cap_feats_ref, t, patch_size, f_patch_size, device): |
| """ |
| Prepares the control context for the transformer, including patchifying, embedding, and generating |
| positional information. Includes a fast path for batches with uniform shapes. |
| |
| Args: |
| control_context (torch.Tensor or List[torch.Tensor]): The control context input. |
| cap_feats_ref (List[torch.Tensor]): A reference to caption features for padding calculation. |
| t (torch.Tensor): The timestep tensor. |
| patch_size (int): The spatial patch size. |
| f_patch_size (int): The frame/temporal patch size. |
| device (torch.device): The target device. |
| |
| Returns: |
| Dict: A dictionary containing the processed control tensors ('c', 'c_item_seqlens', 'attn_mask', etc.). |
| """ |
| bsz = control_context.shape[0] |
|
|
| if isinstance(control_context, torch.Tensor) and control_context.ndim == 5: |
| control_list = list(torch.unbind(control_context, dim=0)) |
| else: |
| control_list = control_context |
|
|
| pH = pW = patch_size |
| pF = f_patch_size |
| cap_padding_len = cap_feats_ref[0].size(0) if isinstance(cap_feats_ref, list) else cap_feats_ref.shape[1] |
|
|
| shapes = [c.shape for c in control_list] |
| same_shape = all(s == shapes[0] for s in shapes) |
|
|
| if same_shape and bsz >= 2: |
| control_batch = torch.stack(control_list, dim=0) |
| B, C, F, H, W = control_batch.shape |
| F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW |
|
|
| control_batch = control_batch.view(B, C, F_tokens, pF, H_tokens, pH, W_tokens, pW) |
| control_batch = control_batch.permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(B, F_tokens * H_tokens * W_tokens, pF * pH * pW * C) |
|
|
| ori_len = control_batch.shape[1] |
| padding_len = (-ori_len) % SEQ_MULTI_OF |
|
|
| if padding_len > 0: |
| pad_tensor = control_batch[:, -1:, :].repeat(1, padding_len, 1) |
| control_batch = torch.cat([control_batch, pad_tensor], dim=1) |
|
|
| c = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_batch) |
|
|
| final_seq_len = control_batch.shape[1] |
| pos_ids_ori = self._create_coordinate_grid( |
| size=(F_tokens, H_tokens, W_tokens), |
| start=(cap_padding_len + 1, 0, 0), |
| device=device, |
| ).flatten(0, 2) |
|
|
| pos_ids_pad = torch.zeros((padding_len, 3), dtype=torch.int32, device=device) |
| pos_ids_padded = torch.cat([pos_ids_ori, pos_ids_pad], dim=0) |
|
|
| c_freqs_cis_single = self.rope_embedder(pos_ids_padded) |
| c_freqs_cis = c_freqs_cis_single.unsqueeze(0).repeat(B, 1, 1, 1) |
| c_attn_mask = torch.ones((B, final_seq_len), dtype=torch.bool, device=device) |
|
|
| return {"c": c, "c_item_seqlens": [final_seq_len] * B, "attn_mask": c_attn_mask, "freqs_cis": c_freqs_cis, "adaln_input": t.type_as(c)} |
|
|
| (c_patches, _, c_pos_ids, c_inner_pad_mask) = self._patchify(control_list, patch_size, f_patch_size, cap_padding_len) |
|
|
| c_item_seqlens = [len(p) for p in c_patches] |
| c_max_item_seqlen = max(c_item_seqlens) |
|
|
| c = torch.cat(c_patches, dim=0) |
| c = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](c) |
| c[torch.cat(c_inner_pad_mask)] = self.x_pad_token |
| c = list(c.split(c_item_seqlens, dim=0)) |
|
|
| c_freqs_cis_list = [] |
| for pos_ids in c_pos_ids: |
| c_freqs_cis_list.append(self.rope_embedder(pos_ids)) |
|
|
| c_padded = pad_sequence(c, batch_first=True, padding_value=0.0) |
| c_freqs_cis_padded = pad_sequence(c_freqs_cis_list, batch_first=True, padding_value=0.0) |
|
|
| seq_lens_tensor = torch.tensor(c_item_seqlens, device=device, dtype=torch.int32) |
| arange = torch.arange(c_max_item_seqlen, device=device, dtype=torch.int32) |
| c_attn_mask = arange[None, :] < seq_lens_tensor[:, None] |
|
|
| return {"c": c_padded, "c_item_seqlens": c_item_seqlens, "attn_mask": c_attn_mask, "freqs_cis": c_freqs_cis_padded, "adaln_input": t.type_as(c_padded)} |
|
|
| def _patchify_and_embed_batch_optimized(self, all_image, all_cap_feats, patch_size, f_patch_size): |
| """ |
| An optimized version of _patchify_and_embed for batches where all images and captions have |
| uniform shapes. It processes the entire batch using vectorized operations instead of a loop. |
| |
| Args: |
| all_image (List[torch.Tensor]): List of image tensors, all of the same shape. |
| all_cap_feats (List[torch.Tensor]): List of caption features, all of the same shape. |
| patch_size (int): The spatial patch size. |
| f_patch_size (int): The frame/temporal patch size. |
| |
| Returns: |
| Tuple: A tuple containing all processed data structures, matching the output of the standard method. |
| """ |
| pH = pW = patch_size |
| pF = f_patch_size |
| device = all_image[0].device |
|
|
| image_shapes = [img.shape for img in all_image] |
| cap_shapes = [cap.shape for cap in all_cap_feats] |
|
|
| same_image_shape = all(s == image_shapes[0] for s in image_shapes) |
| same_cap_shape = all(s == cap_shapes[0] for s in cap_shapes) |
|
|
| if not (same_image_shape and same_cap_shape): |
| return self._patchify_and_embed(all_image, all_cap_feats, patch_size, f_patch_size) |
|
|
| images_batch = torch.stack(all_image, dim=0) |
| caps_batch = torch.stack(all_cap_feats, dim=0) |
|
|
| B, C, Fr, H, W = images_batch.shape |
| cap_ori_len = caps_batch.shape[1] |
|
|
| cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF |
| cap_total_len = cap_ori_len + cap_padding_len |
|
|
| if cap_padding_len > 0: |
| cap_pad = caps_batch[:, -1:, :].repeat(1, cap_padding_len, 1) |
| caps_batch = torch.cat([caps_batch, cap_pad], dim=1) |
|
|
| cap_pos_ids = self._create_coordinate_grid(size=(cap_total_len, 1, 1), start=(1, 0, 0), device=device).flatten(0, 2).unsqueeze(0).repeat(B, 1, 1) |
|
|
| cap_mask = torch.zeros((B, cap_total_len), dtype=torch.bool, device=device) |
| if cap_padding_len > 0: |
| cap_mask[:, cap_ori_len:] = True |
|
|
| F_tokens, H_tokens, W_tokens = Fr // pF, H // pH, W // pW |
| images_reshaped = ( |
| images_batch.view(B, C, F_tokens, pF, H_tokens, pH, W_tokens, pW) |
| .permute(0, 2, 4, 6, 3, 5, 7, 1) |
| .reshape(B, F_tokens * H_tokens * W_tokens, pF * pH * pW * C) |
| ) |
|
|
| image_ori_len = images_reshaped.shape[1] |
| image_padding_len = (-image_ori_len) % SEQ_MULTI_OF |
| image_total_len = image_ori_len + image_padding_len |
|
|
| if image_padding_len > 0: |
| img_pad = images_reshaped[:, -1:, :].repeat(1, image_padding_len, 1) |
| images_reshaped = torch.cat([images_reshaped, img_pad], dim=1) |
|
|
| image_pos_ids = ( |
| self._create_coordinate_grid(size=(F_tokens, H_tokens, W_tokens), start=(cap_total_len + 1, 0, 0), device=device) |
| .flatten(0, 2) |
| .unsqueeze(0) |
| .repeat(B, 1, 1) |
| ) |
|
|
| if image_padding_len > 0: |
| img_pos_pad = torch.zeros((B, image_padding_len, 3), dtype=torch.int32, device=device) |
| image_pos_ids = torch.cat([image_pos_ids, img_pos_pad], dim=1) |
|
|
| image_mask = torch.zeros((B, image_total_len), dtype=torch.bool, device=device) |
| if image_padding_len > 0: |
| image_mask[:, image_ori_len:] = True |
|
|
| all_image_size = [(Fr, H, W)] * B |
|
|
| return ( |
| list(torch.unbind(images_reshaped, dim=0)), |
| list(torch.unbind(caps_batch, dim=0)), |
| all_image_size, |
| list(torch.unbind(image_pos_ids, dim=0)), |
| list(torch.unbind(cap_pos_ids, dim=0)), |
| list(torch.unbind(image_mask, dim=0)), |
| list(torch.unbind(cap_mask, dim=0)), |
| ) |
|
|
| def forward( |
| self, |
| x: List[torch.Tensor], |
| t, |
| cap_feats: List[torch.Tensor], |
| patch_size=2, |
| f_patch_size=1, |
| control_context=None, |
| conditioning_scale=1.0, |
| refiner_conditioning_scale=1.0, |
| ): |
| """ |
| The main forward pass of the transformer model. |
| |
| Args: |
| x (List[torch.Tensor]): |
| A list of latent image tensors. |
| t (torch.Tensor): |
| A batch of timesteps. |
| cap_feats (List[torch.Tensor]): |
| A list of caption feature tensors. |
| patch_size (int, optional): |
| The spatial patch size to use. Defaults to 2. |
| f_patch_size (int, optional): |
| The frame/temporal patch size to use. Defaults to 1. |
| control_context (torch.Tensor, optional): |
| The control context tensor. Defaults to None. |
| conditioning_scale (float, optional): |
| The scale for applying control hints. Defaults to 1.0. |
| refiner_conditioning_scale (float, optional): |
| The scale for applying refiner control hints. Defaults to 1.0. |
| |
| Returns: |
| Transformer2DModelOutput: An object containing the final denoised sample. |
| """ |
|
|
| is_control_mode = self.use_controlnet and control_context is not None and conditioning_scale > 0 |
| if refiner_conditioning_scale is None: |
| refiner_conditioning_scale = conditioning_scale or 1.0 |
|
|
| assert patch_size in self.all_patch_size |
| assert f_patch_size in self.all_f_patch_size |
|
|
| bsz = len(x) |
| device = x[0].device |
|
|
| t = t * self.t_scale |
| t = self.t_embedder(t) |
|
|
| can_optimize_patchify = ( |
| bsz == len(cap_feats) and bsz >= 2 and all(img.shape == x[0].shape for img in x) and all(cap.shape == cap_feats[0].shape for cap in cap_feats) |
| ) |
|
|
| if can_optimize_patchify: |
| (x_list, cap_feats_list, x_size, x_pos_ids, cap_pos_ids, x_inner_pad_mask, cap_inner_pad_mask) = self._patchify_and_embed_batch_optimized( |
| x, cap_feats, patch_size, f_patch_size |
| ) |
| else: |
| (x_list, cap_feats_list, x_size, x_pos_ids, cap_pos_ids, x_inner_pad_mask, cap_inner_pad_mask) = self._patchify_and_embed( |
| x, cap_feats, patch_size, f_patch_size |
| ) |
|
|
| x_item_seqlens = [len(i) for i in x_list] |
| x_max_item_seqlen = max(x_item_seqlens) if x_item_seqlens else 0 |
| x_cat = torch.cat(x_list, dim=0) if x_list else torch.empty(0, x_list[0].shape[1] if x_list else 0, device=device) |
| x_embedded = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_cat) |
| if x_inner_pad_mask and torch.cat(x_inner_pad_mask).any(): |
| x_embedded[torch.cat(x_inner_pad_mask)] = self.x_pad_token |
| x = pad_sequence(list(x_embedded.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0) |
| adaln_input = t.to(device).type_as(x) |
|
|
| cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens = self._process_cap_feats_with_cfg_cache( |
| cap_feats_list, cap_pos_ids, cap_inner_pad_mask |
| ) |
|
|
| x_freqs_cis_cat = self.rope_embedder(torch.cat(x_pos_ids, dim=0)) if x_pos_ids else torch.empty(0, device=device) |
| x_freqs_cis = pad_sequence(list(x_freqs_cis_cat.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0) |
|
|
| seq_lens_tensor = torch.tensor(x_item_seqlens, device=device, dtype=torch.int32) |
| arange = torch.arange(x_max_item_seqlen, device=device, dtype=torch.int32) |
| x_attn_mask = arange[None, :] < seq_lens_tensor[:, None] |
|
|
|
|
| refiner_hints = None |
| if is_control_mode and self.is_two_stage_control: |
| prepared_control = self._prepare_control_inputs(control_context, cap_feats_padded, t, patch_size, f_patch_size, device) |
| c = prepared_control["c"] |
| """ |
| kwargs_for_control_refiner = { |
| "x": x, |
| "attn_mask": prepared_control["attn_mask"], |
| "freqs_cis": prepared_control["freqs_cis"], |
| "adaln_input": prepared_control["adaln_input"], |
| } |
| c_processed = self._apply_transformer_blocks( |
| c, |
| self.control_noise_refiner if self.add_control_noise_refiner else self.control_layers, |
| checkpoint_ratio=self.checkpoint_ratio, |
| **kwargs_for_control_refiner, |
| ) |
| refiner_hints = torch.unbind(c_processed)[:-1] |
| control_context_processed = torch.unbind(c_processed)[-1] |
| control_context_item_seqlens = prepared_control["c_item_seqlens"] |
| """ |
| kwargs_for_control_refiner = { |
| "x": x, |
| "attn_mask": x_attn_mask, |
| "freqs_cis": x_freqs_cis, |
| "adaln_input": adaln_input, |
| } |
| c_processed = self._apply_transformer_blocks( |
| c, |
| self.control_noise_refiner if self.add_control_noise_refiner else self.control_layers, |
| checkpoint_ratio=self.checkpoint_ratio, |
| **kwargs_for_control_refiner, |
| ) |
| refiner_hints = torch.unbind(c_processed)[:-1] |
| control_context_processed = torch.unbind(c_processed)[-1] |
| control_context_item_seqlens = prepared_control["c_item_seqlens"] |
| kwargs_for_refiner = { |
| "attn_mask": x_attn_mask, |
| "freqs_cis": x_freqs_cis, |
| "adaln_input": adaln_input, |
| "context_scale": refiner_conditioning_scale, |
| } |
| if refiner_hints is not None: |
| kwargs_for_refiner["hints"] = refiner_hints |
| x = self._apply_transformer_blocks(x, self.noise_refiner, checkpoint_ratio=1.0, **kwargs_for_refiner) |
|
|
| kwargs_for_context = {"attn_mask": cap_attn_mask, "freqs_cis": cap_freqs_cis} |
| cap_feats = self._apply_transformer_blocks(cap_feats_padded, self.context_refiner, checkpoint_ratio=1.0, **kwargs_for_context) |
|
|
| unified_item_seqlens = [a + b for a, b in zip(x_item_seqlens, cap_item_seqlens)] |
| unified_max_item_seqlen = max(unified_item_seqlens) if unified_item_seqlens else 0 |
| unified = torch.zeros((bsz, unified_max_item_seqlen, x.shape[-1]), dtype=x.dtype, device=device) |
| unified_freqs_cis = torch.zeros((bsz, unified_max_item_seqlen, x_freqs_cis.shape[-2], x_freqs_cis.shape[-1]), dtype=x_freqs_cis.dtype, device=device) |
|
|
| for i in range(bsz): |
| x_len = x_item_seqlens[i] |
| cap_len = cap_item_seqlens[i] |
| unified[i, :x_len] = x[i, :x_len] |
| unified[i, x_len : x_len + cap_len] = cap_feats[i, :cap_len] |
| unified_freqs_cis[i, :x_len] = x_freqs_cis[i, :x_len] |
| unified_freqs_cis[i, x_len : x_len + cap_len] = cap_freqs_cis[i, :cap_len] |
|
|
| seq_lens_tensor = torch.tensor(unified_item_seqlens, device=device, dtype=torch.int32) |
| arange = torch.arange(unified_max_item_seqlen, device=device, dtype=torch.int32) |
| unified_attn_mask = arange[None, :] < seq_lens_tensor[:, None] |
|
|
| hints = None |
| if is_control_mode: |
| kwargs_for_hints = { |
| "attn_mask": unified_attn_mask, |
| "freqs_cis": unified_freqs_cis, |
| "adaln_input": adaln_input, |
| } |
| if self.is_two_stage_control: |
| control_context_unified_list = [ |
| torch.cat([control_context_processed[i][: control_context_item_seqlens[i]], cap_feats[i, : cap_item_seqlens[i]]], dim=0) for i in range(bsz) |
| ] |
| c = pad_sequence(control_context_unified_list, batch_first=True, padding_value=0.0) |
| new_kwargs = dict(x=unified, **kwargs_for_hints) |
| c_processed = self._apply_transformer_blocks(c, self.control_layers, checkpoint_ratio=self.checkpoint_ratio, **new_kwargs) |
| hints = torch.unbind(c_processed)[:-1] |
| else: |
| prepared_control = self._prepare_control_inputs(control_context, cap_feats_padded, t, patch_size, f_patch_size, device) |
| c = prepared_control["c"] |
| kwargs_for_v1_refiner = { |
| "attn_mask": prepared_control["attn_mask"], |
| "freqs_cis": prepared_control["freqs_cis"], |
| "adaln_input": prepared_control["adaln_input"], |
| } |
| c = self._apply_transformer_blocks(c, self.control_noise_refiner, checkpoint_ratio=self.checkpoint_ratio, **kwargs_for_v1_refiner) |
| c_item_seqlens = prepared_control["c_item_seqlens"] |
| control_context_unified_list = [torch.cat([c[i, : c_item_seqlens[i]], cap_feats[i, : cap_item_seqlens[i]]], dim=0) for i in range(bsz)] |
| c_unified = pad_sequence(control_context_unified_list, batch_first=True, padding_value=0.0) |
| new_kwargs = dict(x=unified, **kwargs_for_hints) |
| c_processed = self._apply_transformer_blocks(c_unified, self.control_layers, checkpoint_ratio=self.checkpoint_ratio, **new_kwargs) |
| hints = torch.unbind(c_processed)[:-1] |
|
|
| kwargs_for_layers = {"attn_mask": unified_attn_mask, "freqs_cis": unified_freqs_cis, "adaln_input": adaln_input} |
| if hints is not None: |
| kwargs_for_layers["hints"] = hints |
| kwargs_for_layers["context_scale"] = conditioning_scale |
| unified = self._apply_transformer_blocks(unified, self.layers, checkpoint_ratio=self.checkpoint_ratio, **kwargs_for_layers) |
|
|
| unified_out = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) |
| x_image_tokens = unified_out[:, :x_max_item_seqlen] |
| x_final_tensor = self._unpatchify(x_image_tokens, x_size, patch_size, f_patch_size) |
|
|
| return Transformer2DModelOutput(sample=x_final_tensor) |
|
|