# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. # Refactored and optimized by DEVAIEXP Team # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.models.attention_dispatch import dispatch_attention_fn from diffusers.models.attention_processor import Attention, AttentionProcessor from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin from diffusers.models.normalization import RMSNorm from diffusers.utils import ( is_torch_version, ) from diffusers.utils.torch_utils import maybe_allow_in_graph from torch.nn.utils.rnn import pad_sequence ADALN_EMBED_DIM = 256 SEQ_MULTI_OF = 32 def zero_module(module): """ Initializes the parameters of a given module with zeros. Args: module (nn.Module): The module to be zero-initialized. Returns: nn.Module: The same module with its parameters initialized to zero. """ for p in module.parameters(): nn.init.zeros_(p) return module class TimestepEmbedder(nn.Module): """ A module to embed timesteps into a higher-dimensional space using sinusoidal embeddings followed by a multilayer perceptron (MLP). """ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): """ Initializes the TimestepEmbedder module. Args: out_size (int): The output dimension of the embedding. mid_size (int, optional): The intermediate dimension of the MLP. Defaults to `out_size`. frequency_embedding_size (int, optional): The dimension of the sinusoidal frequency embedding. Defaults to 256. """ super().__init__() if mid_size is None: mid_size = out_size self.mlp = nn.Sequential( nn.Linear( frequency_embedding_size, mid_size, bias=True, ), nn.SiLU(), nn.Linear( mid_size, out_size, bias=True, ), ) self.frequency_embedding_size = frequency_embedding_size @staticmethod def timestep_embedding(t, dim, max_period=10000): """ Creates sinusoidal timestep embeddings. Args: t (torch.Tensor): A 1-D Tensor of N timesteps. dim (int): The dimension of the embedding. max_period (int, optional): The maximum period for the sinusoidal frequencies. Defaults to 10000. Returns: torch.Tensor: The timestep embeddings with shape (N, dim). """ with torch.amp.autocast("cuda", enabled=False): half = dim // 2 freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half) args = t[:, None] * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding def forward(self, t): """ Processes the input timesteps to generate embeddings. Args: t (torch.Tensor): The input timesteps. Returns: torch.Tensor: The final timestep embeddings after passing through the MLP. """ t_freq = self.timestep_embedding(t, self.frequency_embedding_size) weight_dtype = self.mlp[0].weight.dtype if weight_dtype.is_floating_point: t_freq = t_freq.to(weight_dtype) t_emb = self.mlp(t_freq) return t_emb class FeedForward(nn.Module): """ A Feed-Forward Network module using SwiGLU activation. """ def __init__(self, dim: int, hidden_dim: int): """ Initializes the FeedForward module. Args: dim (int): Input and output dimension. hidden_dim (int): The hidden dimension of the network. """ super().__init__() self.w1 = nn.Linear(dim, hidden_dim, bias=False) self.w2 = nn.Linear(hidden_dim, dim, bias=False) self.w3 = nn.Linear(dim, hidden_dim, bias=False) def _forward_silu_gating(self, x1, x3): """ Applies the SiLU gating mechanism. Args: x1 (torch.Tensor): The first intermediate tensor. x3 (torch.Tensor): The second intermediate tensor (gate). Returns: torch.Tensor: The result of the gating operation. """ return F.silu(x1) * x3 def forward(self, x): """ Defines the forward pass of the FeedForward network. Args: x (torch.Tensor): The input tensor. Returns: torch.Tensor: The output tensor. """ return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) class FinalLayer(nn.Module): """ The final layer of the transformer, which applies AdaLN modulation and a linear projection. """ def __init__(self, hidden_size, out_channels): """ Initializes the FinalLayer module. Args: hidden_size (int): The input hidden size. out_channels (int): The output dimension (number of channels). """ super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, out_channels, bias=True) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), ) def forward(self, x, c): """ Defines the forward pass for the final layer. Args: x (torch.Tensor): The main input tensor from the transformer blocks. c (torch.Tensor): The conditioning tensor (usually from timestep embedding) for AdaLN modulation. Returns: torch.Tensor: The final output tensor projected to the patch dimension. """ scale = 1.0 + self.adaLN_modulation(c) x = self.norm_final(x) * scale.unsqueeze(1) x = self.linear(x) return x class RopeEmbedder: """ Computes Rotary Positional Embeddings (RoPE) for 3D coordinates. """ def __init__(self, theta: float = 256.0, axes_dims: List[int] = (32, 48, 48), axes_lens: List[int] = (1024, 512, 512)): """ Initializes the RopeEmbedder. Args: theta (float, optional): The base for the rotary frequencies. Defaults to 256.0. axes_dims (List[int], optional): The dimensions for each axis (F, H, W). Defaults to (32, 48, 48). axes_lens (List[int], optional): The maximum length for each axis. Defaults to (1024, 512, 512). """ self.theta = theta self.axes_dims = axes_dims self.axes_lens = axes_lens self.freqs_cis_cache = {} def _precompute_freqs_cis(self, device): """ Precomputes and caches the rotary frequency tensors (cos and sin values). Args: device (torch.device): The device to store the cached tensors on. Returns: List[torch.Tensor]: A list of precomputed frequency tensors for each axis. """ if device in self.freqs_cis_cache: return self.freqs_cis_cache[device] freqs_cis_list = [] for dim, max_len in zip(self.axes_dims, self.axes_lens): half = dim // 2 freqs = 1.0 / (self.theta ** (torch.arange(0, half, device=device, dtype=torch.float32) / half)) t = torch.arange(max_len, device=device, dtype=torch.float32) freqs = torch.outer(t, freqs) emb = torch.stack([freqs.cos(), freqs.sin()], dim=-1) freqs_cis_list.append(emb) self.freqs_cis_cache[device] = freqs_cis_list return freqs_cis_list def __call__(self, ids: torch.Tensor): """ Generates RoPE embeddings for a batch of 3D coordinates. Args: ids (torch.Tensor): A tensor of coordinates with shape (N, 3). Returns: torch.Tensor: The concatenated RoPE embeddings for the input coordinates. """ assert ids.ndim == 2 and ids.shape[1] == len(self.axes_dims) device = ids.device freqs_cis_list = self._precompute_freqs_cis(device) result = [] for i in range(len(self.axes_dims)): result.append(freqs_cis_list[i][ids[:, i]]) return torch.cat(result, dim=-2) class ZSingleStreamAttnProcessor: """ An attention processor that applies Rotary Positional Embeddings (RoPE) to query and key tensors before computing scaled dot-product attention. """ _attention_backend = None _parallel_config = None def __init__(self): """ Initializes the ZSingleStreamAttnProcessor. """ if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher.") def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ The forward call for the attention processor. Args: attn (Attention): The attention layer that this processor is attached to. hidden_states (torch.Tensor): The input hidden states. encoder_hidden_states (Optional[torch.Tensor], optional): Not used in self-attention. Defaults to None. attention_mask (Optional[torch.Tensor], optional): The attention mask. Defaults to None. freqs_cis (Optional[torch.Tensor], optional): The precomputed RoPE frequencies. Defaults to None. Returns: torch.Tensor: The output of the attention mechanism. """ def apply_rotary_emb(q_or_k: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: """ Applies RoPE to a query or key tensor. """ x = q_or_k.transpose(1, 2) x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2) x0 = x_reshaped[..., 0] x1 = x_reshaped[..., 1] freqs_cos = freqs_cis[..., 0].unsqueeze(1) freqs_sin = freqs_cis[..., 1].unsqueeze(1) x_rotated_0 = x0 * freqs_cos - x1 * freqs_sin x_rotated_1 = x0 * freqs_sin + x1 * freqs_cos x_rotated = torch.stack((x_rotated_0, x_rotated_1), dim=-1) x_out = x_rotated.flatten(-2).transpose(1, 2) return x_out.to(q_or_k.dtype) query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) query = query.unflatten(-1, (attn.heads, -1)) key = key.unflatten(-1, (attn.heads, -1)) value = value.unflatten(-1, (attn.heads, -1)) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) if freqs_cis is not None: query = apply_rotary_emb(query, freqs_cis) key = apply_rotary_emb(key, freqs_cis) if attention_mask is not None and attention_mask.ndim == 2: attention_mask = attention_mask[:, None, None, :] hidden_states = dispatch_attention_fn( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, backend=self._attention_backend, parallel_config=self._parallel_config, ) hidden_states = hidden_states.flatten(2, 3) output = attn.to_out[0](hidden_states.to(hidden_states.dtype)) if len(attn.to_out) > 1: output = attn.to_out[1](output) return output @maybe_allow_in_graph class ZImageTransformerBlock(nn.Module): """ A standard transformer block consisting of a self-attention layer and a feed-forward network. Includes support for AdaLN modulation. """ def __init__( self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, qk_norm: bool, modulation=True, ): """ Initializes the ZImageTransformerBlock. Args: layer_id (int): The index of the layer. dim (int): The dimension of the input and output features. n_heads (int): The number of attention heads. n_kv_heads (int): The number of key/value heads (not directly used in this simplified attention). norm_eps (float): Epsilon for RMSNorm. qk_norm (bool): Whether to apply normalization to query and key tensors. modulation (bool, optional): Whether to enable AdaLN modulation. Defaults to True. """ super().__init__() self.dim = dim self.head_dim = dim // n_heads self.attention = Attention( query_dim=dim, cross_attention_dim=None, dim_head=dim // n_heads, heads=n_heads, qk_norm="rms_norm" if qk_norm else None, eps=1e-5, bias=False, out_bias=False, processor=ZSingleStreamAttnProcessor(), ) self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) self.layer_id = layer_id self.attention_norm1 = RMSNorm(dim, eps=norm_eps) self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) self.attention_norm2 = RMSNorm(dim, eps=norm_eps) self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) self.modulation = modulation if modulation: self.adaLN_modulation = nn.Sequential( nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True), ) @property def attn_processors(self) -> Dict[str, AttentionProcessor]: """ Returns a dictionary of all attention processors used in the module. """ processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): """ Sets the attention processor for the attention layer in this block. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def forward(self, x, attn_mask, freqs_cis, adaln_input=None): """ Defines the forward pass for the transformer block. Args: x (torch.Tensor): The input tensor. attn_mask (torch.Tensor): The attention mask. freqs_cis (torch.Tensor): The RoPE frequencies. adaln_input (torch.Tensor, optional): The conditioning tensor for AdaLN. Defaults to None. Returns: torch.Tensor: The output tensor of the block. """ if self.modulation: scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) scale_msa = scale_msa + 1.0 gate_msa = gate_msa.tanh() scale_mlp = scale_mlp + 1.0 gate_mlp = gate_mlp.tanh() normed = self.attention_norm1(x) normed = normed * scale_msa attn_out = self.attention(normed, attention_mask=attn_mask, freqs_cis=freqs_cis) attn_out = self.attention_norm2(attn_out) * gate_msa x = x + attn_out normed = self.ffn_norm1(x) normed = normed * scale_mlp ffn_out = self.feed_forward(normed) ffn_out = self.ffn_norm2(ffn_out) * gate_mlp x = x + ffn_out else: normed = self.attention_norm1(x) attn_out = self.attention(normed, attention_mask=attn_mask, freqs_cis=freqs_cis) x = x + self.attention_norm2(attn_out) normed = self.ffn_norm1(x) ffn_out = self.feed_forward(normed) x = x + self.ffn_norm2(ffn_out) return x class ZImageControlTransformerBlock(ZImageTransformerBlock): """ A specialized transformer block for the control pathway. It inherits from ZImageTransformerBlock and adds projection layers to generate and combine control signals. """ def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, qk_norm: bool, modulation=True, block_id=0): """ Initializes the ZImageControlTransformerBlock. Args: layer_id (int): The index of the layer. dim (int): The dimension of the features. n_heads (int): The number of attention heads. n_kv_heads (int): The number of key/value heads. norm_eps (float): Epsilon for RMSNorm. qk_norm (bool): Whether to apply normalization to query and key. modulation (bool, optional): Whether to enable AdaLN modulation. Defaults to True. block_id (int, optional): The index of this control block. Defaults to 0. """ super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation) self.block_id = block_id if block_id == 0: self.before_proj = zero_module(nn.Linear(self.dim, self.dim)) self.after_proj = zero_module(nn.Linear(self.dim, self.dim)) def forward(self, c, x, **kwargs): """ Defines the forward pass for the control block. Args: c (torch.Tensor): The control signal tensor. x (torch.Tensor): The reference tensor from the main pathway. **kwargs: Additional arguments for the parent's forward method. Returns: torch.Tensor: A stacked tensor containing the skip connection and the final output. """ if self.block_id == 0: c = self.before_proj(c) + x all_c = [] else: all_c = list(torch.unbind(c)) c = all_c.pop(-1) c = super().forward(c, **kwargs) c_skip = self.after_proj(c) all_c += [c_skip, c] c = torch.stack(all_c) return c class BaseZImageTransformerBlock(ZImageTransformerBlock): """ The main transformer block used in the primary pathway. It inherits from ZImageTransformerBlock and adds the logic to inject control "hints" from the control pathway. """ def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, qk_norm: bool, modulation=True, block_id=0): """ Initializes the BaseZImageTransformerBlock. Args: layer_id (int): The index of the layer. dim (int): The dimension of the features. n_heads (int): The number of attention heads. n_kv_heads (int): The number of key/value heads. norm_eps (float): Epsilon for RMSNorm. qk_norm (bool): Whether to apply normalization to query and key. modulation (bool, optional): Whether to enable AdaLN modulation. Defaults to True. block_id (int, optional): The index used to retrieve the corresponding control hint. Defaults to 0. """ super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation) self.block_id = block_id def forward(self, hidden_states, hints=None, context_scale=1.0, **kwargs): """ Defines the forward pass, including the injection of control hints. Args: hidden_states (torch.Tensor): The input tensor. hints (List[torch.Tensor], optional): A list of control hints from the control pathway. Defaults to None. context_scale (float, optional): A scale factor for the control hints. Defaults to 1.0. **kwargs: Additional arguments for the parent's forward method. Returns: torch.Tensor: The output tensor of the block. """ hidden_states = super().forward(hidden_states, **kwargs) if self.block_id is not None and hints is not None: hidden_states = hidden_states + hints[self.block_id] * context_scale return hidden_states class ZImageControlTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): _supports_gradient_checkpointing = True _keys_to_ignore_on_load_unexpected = [ r"control_layers\..*", r"control_noise_refiner\..*", r"control_all_x_embedder\..*", ] _no_split_modules = ["ZImageTransformerBlock", "BaseZImageTransformerBlock", "ZImageControlTransformerBlock"] _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] _group_offload_block_modules = ["t_embedder", "cap_embedder"] @register_to_config def __init__( self, control_layers_places=None, control_refiner_layers_places=None, control_in_dim=None, add_control_noise_refiner=False, all_patch_size=(2,), all_f_patch_size=(1,), in_channels=16, dim=3840, n_layers=30, n_refiner_layers=2, n_heads=30, n_kv_heads=30, norm_eps=1e-5, qk_norm=True, cap_feat_dim=2560, rope_theta=256.0, t_scale=1000.0, axes_dims=[32, 48, 48], axes_lens=[1024, 512, 512], use_controlnet=True, checkpoint_ratio=0.5, ): """ Initializes the ZImageControlTransformer2DModel. Args: control_layers_places (List[int], optional): Indices of main layers where control hints are injected. control_refiner_layers_places (List[int], optional): Indices of noise refiner layers for two-stage control. control_in_dim (int, optional): Input channel dimension for the control context. add_control_noise_refiner (bool, optional): Whether to add a dedicated refiner for the control signal. all_patch_size (Tuple[int], optional): Tuple of patch sizes for spatial dimensions. all_f_patch_size (Tuple[int], optional): Tuple of patch sizes for the frame dimension. in_channels (int, optional): Number of input channels for the latent image. dim (int, optional): The main dimension of the transformer model. n_layers (int, optional): The number of main transformer layers. n_refiner_layers (int, optional): The number of layers in the refiner blocks. n_heads (int, optional): The number of attention heads. n_kv_heads (int, optional): The number of key/value heads. norm_eps (float, optional): Epsilon for RMSNorm. qk_norm (bool, optional): Whether to apply normalization to query and key. cap_feat_dim (int, optional): The dimension of the input caption features. rope_theta (float, optional): The base for RoPE. t_scale (float, optional): A scaling factor for the timestep. axes_dims (List[int], optional): Dimensions for each axis in RoPE. axes_lens (List[int], optional): Maximum lengths for each axis in RoPE. use_controlnet (bool, optional): If False, control-related layers will not be created to save memory. checkpoint_ratio (float, optional): The ratio of layers to apply gradient checkpointing to. """ super().__init__() self.use_controlnet = use_controlnet self.in_channels = in_channels self.out_channels = in_channels self.all_patch_size = all_patch_size self.all_f_patch_size = all_f_patch_size self.dim = dim self.control_in_dim = self.dim if control_in_dim is None else control_in_dim self.is_two_stage_control = self.control_in_dim > 16 self.n_heads = n_heads self.rope_theta = rope_theta self.t_scale = t_scale self.gradient_checkpointing = False self.checkpoint_ratio = checkpoint_ratio assert len(all_patch_size) == len(all_f_patch_size) self.control_layers_places = list(range(0, n_layers, 2)) if control_layers_places is None else control_layers_places self.control_refiner_layers_places = list(range(0, n_refiner_layers)) if control_refiner_layers_places is None else control_refiner_layers_places self.add_control_noise_refiner = add_control_noise_refiner assert 0 in self.control_layers_places self.control_layers_mapping = {i: n for n, i in enumerate(self.control_layers_places)} self.control_refiner_layers_mapping = {i: n for n, i in enumerate(self.control_refiner_layers_places)} self.all_x_embedder = nn.ModuleDict( { f"{patch_size}-{f_patch_size}": nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size) } ) self.all_final_layer = nn.ModuleDict( { f"{patch_size}-{f_patch_size}": FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels) for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size) } ) self.context_refiner = nn.ModuleList( [ZImageTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=False) for i in range(n_refiner_layers)] ) self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True)) self.x_pad_token = nn.Parameter(torch.empty((1, dim))) self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) head_dim = dim // n_heads assert head_dim == sum(axes_dims) self.axes_dims = axes_dims self.axes_lens = axes_lens self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) self.layers = nn.ModuleList( [BaseZImageTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=self.control_layers_mapping.get(i)) for i in range(n_layers)] ) self.noise_refiner = nn.ModuleList( [ BaseZImageTransformerBlock( 1000 + i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True, block_id=self.control_refiner_layers_mapping.get(i) ) for i in range(n_refiner_layers) ] ) if self.use_controlnet: self.control_layers = nn.ModuleList( [ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i) for i in self.control_layers_places] ) self.control_all_x_embedder = nn.ModuleDict( { f"{patch_size}-{f_patch_size}": nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True) for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size) } ) if self.is_two_stage_control: if self.add_control_noise_refiner: self.control_noise_refiner = nn.ModuleList( [ ZImageControlTransformerBlock(1000 + layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True, block_id=layer_id) for layer_id in range(n_refiner_layers) ] ) else: self.control_noise_refiner = None else: # V1 self.control_noise_refiner = nn.ModuleList( [ZImageTransformerBlock(1000 + i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True) for i in range(n_refiner_layers)] ) else: self.control_layers = None self.control_all_x_embedder = None self.control_noise_refiner = None def _unpatchify(self, x_image_tokens: torch.Tensor, all_sizes: List[Tuple], patch_size: int, f_patch_size: int) -> torch.Tensor: """ Converts a sequence of image tokens back into a batched image tensor. This version is robust to batches containing images of different original sizes. Args: x_image_tokens (torch.Tensor): A tensor of image tokens with shape [B, SeqLen, Dim]. all_sizes (List[Tuple]): A list of tuples with the original (F, H, W) size for each image in the batch. patch_size (int): The spatial patch size (height and width). f_patch_size (int): The frame/temporal patch size. Returns: torch.Tensor: The reconstructed latent tensor with shape [B, C, F, H, W]. """ pH = pW = patch_size pF = f_patch_size batch_size = x_image_tokens.shape[0] unpatched_images = [] for i in range(batch_size): F, H, W = all_sizes[i] F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW original_seq_len = F_tokens * H_tokens * W_tokens current_image_tokens = x_image_tokens[i, :original_seq_len, :] unpatched_image = current_image_tokens.view(F_tokens, H_tokens, W_tokens, pF, pH, pW, self.out_channels) unpatched_image = unpatched_image.permute(6, 0, 3, 1, 4, 2, 5).reshape(self.out_channels, F, H, W) unpatched_images.append(unpatched_image) try: final_tensor = torch.stack(unpatched_images, dim=0) except RuntimeError: raise ValueError( "Could not stack unpatched images into a single batch tensor. " "This typically occurs if you are trying to generate images of different sizes in the same batch." ) return final_tensor def _patchify( self, all_image: List[torch.Tensor], patch_size: int, f_patch_size: int, cap_padding_len: int, ): """ Converts a list of image tensors into patch sequences and computes their positional IDs. Args: all_image (List[torch.Tensor]): A list of image tensors to process. patch_size (int): The spatial patch size. f_patch_size (int): The frame/temporal patch size. cap_padding_len (int): The length of the padded caption sequence, used as an offset for image position IDs. Returns: Tuple: A tuple containing lists of processed patches, sizes, position IDs, and padding masks. """ pH = pW = patch_size pF = f_patch_size device = all_image[0].device all_image_out = [] all_image_size = [] all_image_pos_ids = [] all_image_pad_mask = [] for i, image in enumerate(all_image): C, F, H, W = image.size() all_image_size.append((F, H, W)) F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) image_ori_len = len(image) image_padding_len = (-image_ori_len) % SEQ_MULTI_OF image_ori_pos_ids = self._create_coordinate_grid( size=(F_tokens, H_tokens, W_tokens), start=(cap_padding_len + 1, 0, 0), device=device, ).flatten(0, 2) image_padding_pos_ids = ( self._create_coordinate_grid( size=(1, 1, 1), start=(0, 0, 0), device=device, ) .flatten(0, 2) .repeat(image_padding_len, 1) ) image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) all_image_pos_ids.append(image_padded_pos_ids) all_image_pad_mask.append( torch.cat( [ torch.zeros((image_ori_len,), dtype=torch.bool, device=device), torch.ones((image_padding_len,), dtype=torch.bool, device=device), ], dim=0, ) ) image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) all_image_out.append(image_padded_feat) return ( all_image_out, all_image_size, all_image_pos_ids, all_image_pad_mask, ) def _patchify_and_embed( self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int, ): """ Processes a batch of images and caption features by converting them into padded patch sequences and generating their corresponding positional IDs and padding masks. This is the general-purpose, robust version that iterates through the batch. Args: all_image (List[torch.Tensor]): A list of image tensors. all_cap_feats (List[torch.Tensor]): A list of caption feature tensors. patch_size (int): The spatial patch size. f_patch_size (int): The frame/temporal patch size. Returns: Tuple: A tuple containing all processed data structures (image patches, caption features, sizes, position IDs, and padding masks) as lists. """ pH = pW = patch_size pF = f_patch_size device = all_image[0].device all_image_out, all_image_size, all_image_pos_ids, all_image_pad_mask = [], [], [], [] all_cap_pos_ids, all_cap_pad_mask, all_cap_feats_out = [], [], [] for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): cap_ori_len = len(cap_feat) cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF cap_total_len = cap_ori_len + cap_padding_len cap_padded_pos_ids = self._create_coordinate_grid(size=(cap_total_len, 1, 1), start=(1, 0, 0), device=device).flatten(0, 2) all_cap_pos_ids.append(cap_padded_pos_ids) cap_mask = torch.ones(cap_total_len, dtype=torch.bool, device=device) cap_mask[:cap_ori_len] = False all_cap_pad_mask.append(cap_mask) if cap_padding_len > 0: padding_tensor = cap_feat[-1:].repeat(cap_padding_len, 1) cap_padded_feat = torch.cat([cap_feat, padding_tensor], dim=0) else: cap_padded_feat = cap_feat all_cap_feats_out.append(cap_padded_feat) C, Fr, H, W = image.size() all_image_size.append((Fr, H, W)) F_tokens, H_tokens, W_tokens = Fr // pF, H // pH, W // pW image_reshaped = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW).permute(1, 3, 5, 2, 4, 6, 0).reshape(-1, pF * pH * pW * C) image_ori_len = image_reshaped.shape[0] image_padding_len = (-image_ori_len) % SEQ_MULTI_OF image_total_len = image_ori_len + image_padding_len image_ori_pos_ids = self._create_coordinate_grid(size=(F_tokens, H_tokens, W_tokens), start=(cap_total_len + 1, 0, 0), device=device).flatten(0, 2) if image_padding_len > 0: image_padding_pos_ids = torch.zeros((image_padding_len, 3), dtype=torch.int32, device=device) image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) else: image_padded_pos_ids = image_ori_pos_ids all_image_pos_ids.append(image_padded_pos_ids) image_mask = torch.ones(image_total_len, dtype=torch.bool, device=device) image_mask[:image_ori_len] = False all_image_pad_mask.append(image_mask) if image_padding_len > 0: padding_tensor = image_reshaped[-1:].repeat(image_padding_len, 1) image_padded_feat = torch.cat([image_reshaped, padding_tensor], dim=0) else: image_padded_feat = image_reshaped all_image_out.append(image_padded_feat) return ( all_image_out, all_cap_feats_out, all_image_size, all_image_pos_ids, all_cap_pos_ids, all_image_pad_mask, all_cap_pad_mask, ) def _process_cap_feats_with_cfg_cache(self, cap_feats_list, cap_pos_ids, cap_inner_pad_mask): """ Processes caption features with intelligent duplicate detection to avoid redundant computation, especially for Classifier-Free Guidance (CFG) where prompts are repeated. Args: cap_feats_list (List[torch.Tensor]): List of padded caption feature tensors. cap_pos_ids (List[torch.Tensor]): List of corresponding position ID tensors. cap_inner_pad_mask (List[torch.Tensor]): List of corresponding padding masks. Returns: Tuple: A tuple of batched tensors for padded features, RoPE frequencies, attention mask, and sequence lengths. """ device = cap_feats_list[0].device bsz = len(cap_feats_list) shapes_equal = all(c.shape == cap_feats_list[0].shape for c in cap_feats_list) if shapes_equal and bsz >= 2: unique_indices = [0] unique_tensors = [cap_feats_list[0]] tensor_mapping = [0] for i in range(1, bsz): found_match = False for j, unique_tensor in enumerate(unique_tensors): if torch.equal(cap_feats_list[i], unique_tensor): tensor_mapping.append(j) found_match = True break if not found_match: unique_indices.append(i) unique_tensors.append(cap_feats_list[i]) tensor_mapping.append(len(unique_tensors) - 1) if len(unique_tensors) < bsz: unique_cap_feats_list = [cap_feats_list[i] for i in unique_indices] unique_cap_pos_ids = [cap_pos_ids[i] for i in unique_indices] unique_cap_inner_pad_mask = [cap_inner_pad_mask[i] for i in unique_indices] cap_item_seqlens_unique = [len(i) for i in unique_cap_feats_list] cap_max_item_seqlen = max(cap_item_seqlens_unique) cap_feats_cat = torch.cat(unique_cap_feats_list, dim=0) cap_feats_embedded = self.cap_embedder(cap_feats_cat) cap_feats_embedded[torch.cat(unique_cap_inner_pad_mask)] = self.cap_pad_token cap_feats_padded_unique = pad_sequence(list(cap_feats_embedded.split(cap_item_seqlens_unique, dim=0)), batch_first=True, padding_value=0.0) cap_freqs_cis_cat = self.rope_embedder(torch.cat(unique_cap_pos_ids, dim=0)) cap_freqs_cis_unique = pad_sequence(list(cap_freqs_cis_cat.split(cap_item_seqlens_unique, dim=0)), batch_first=True, padding_value=0.0) cap_feats_padded = cap_feats_padded_unique[tensor_mapping] cap_freqs_cis = cap_freqs_cis_unique[tensor_mapping] seq_lens_tensor = torch.tensor([cap_max_item_seqlen] * bsz, device=device, dtype=torch.int32) arange = torch.arange(cap_max_item_seqlen, device=device, dtype=torch.int32) cap_attn_mask = arange[None, :] < seq_lens_tensor[:, None] cap_item_seqlens = [cap_max_item_seqlen] * bsz return cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens cap_item_seqlens = [len(i) for i in cap_feats_list] cap_max_item_seqlen = max(cap_item_seqlens) cap_feats_cat = torch.cat(cap_feats_list, dim=0) cap_feats_embedded = self.cap_embedder(cap_feats_cat) cap_feats_embedded[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token cap_feats_padded = pad_sequence(list(cap_feats_embedded.split(cap_item_seqlens, dim=0)), batch_first=True, padding_value=0.0) cap_freqs_cis_cat = self.rope_embedder(torch.cat(cap_pos_ids, dim=0)) cap_freqs_cis = pad_sequence(list(cap_freqs_cis_cat.split(cap_item_seqlens, dim=0)), batch_first=True, padding_value=0.0) seq_lens_tensor = torch.tensor(cap_item_seqlens, device=device, dtype=torch.int32) arange = torch.arange(cap_max_item_seqlen, device=device, dtype=torch.int32) cap_attn_mask = arange[None, :] < seq_lens_tensor[:, None] return cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens @staticmethod def _create_coordinate_grid(size, start=None, device=None): """ Creates a 3D coordinate grid. Args: size (Tuple[int]): The dimensions of the grid (F, H, W). start (Tuple[int], optional): The starting coordinates for each axis. Defaults to (0, 0, 0). device (torch.device, optional): The device to create the tensor on. Defaults to None. Returns: torch.Tensor: The coordinate grid tensor. """ if start is None: start = (0 for _ in size) axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] grids = torch.meshgrid(axes, indexing="ij") return torch.stack(grids, dim=-1) def _apply_transformer_blocks(self, hidden_states, layers, checkpoint_ratio=0.5, **kwargs): """ Applies a list of transformer layers to the hidden states, with optional selective gradient checkpointing. Args: hidden_states (torch.Tensor): The input tensor. layers (nn.ModuleList): The list of transformer layers to apply. checkpoint_ratio (float, optional): The ratio of layers to apply gradient checkpointing to. Defaults to 0.5. **kwargs: Additional keyword arguments to pass to each layer's forward method. Returns: torch.Tensor: The output tensor after applying all layers. """ if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, **static_kwargs): def custom_forward(*inputs): return module(*inputs, **static_kwargs) return custom_forward ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} checkpoint_every_n = max(1, int(1.0 / checkpoint_ratio)) if checkpoint_ratio > 0 else len(layers) + 1 for i, layer in enumerate(layers): if i % checkpoint_every_n == 0: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(layer, **kwargs), hidden_states, **ckpt_kwargs, ) else: hidden_states = layer(hidden_states, **kwargs) else: for layer in layers: hidden_states = layer(hidden_states, **kwargs) return hidden_states def _prepare_control_inputs(self, control_context, cap_feats_ref, t, patch_size, f_patch_size, device): """ Prepares the control context for the transformer, including patchifying, embedding, and generating positional information. Includes a fast path for batches with uniform shapes. Args: control_context (torch.Tensor or List[torch.Tensor]): The control context input. cap_feats_ref (List[torch.Tensor]): A reference to caption features for padding calculation. t (torch.Tensor): The timestep tensor. patch_size (int): The spatial patch size. f_patch_size (int): The frame/temporal patch size. device (torch.device): The target device. Returns: Dict: A dictionary containing the processed control tensors ('c', 'c_item_seqlens', 'attn_mask', etc.). """ bsz = control_context.shape[0] if isinstance(control_context, torch.Tensor) and control_context.ndim == 5: control_list = list(torch.unbind(control_context, dim=0)) else: control_list = control_context pH = pW = patch_size pF = f_patch_size cap_padding_len = cap_feats_ref[0].size(0) if isinstance(cap_feats_ref, list) else cap_feats_ref.shape[1] shapes = [c.shape for c in control_list] same_shape = all(s == shapes[0] for s in shapes) if same_shape and bsz >= 2: control_batch = torch.stack(control_list, dim=0) B, C, F, H, W = control_batch.shape F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW control_batch = control_batch.view(B, C, F_tokens, pF, H_tokens, pH, W_tokens, pW) control_batch = control_batch.permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(B, F_tokens * H_tokens * W_tokens, pF * pH * pW * C) ori_len = control_batch.shape[1] padding_len = (-ori_len) % SEQ_MULTI_OF if padding_len > 0: pad_tensor = control_batch[:, -1:, :].repeat(1, padding_len, 1) control_batch = torch.cat([control_batch, pad_tensor], dim=1) c = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_batch) final_seq_len = control_batch.shape[1] pos_ids_ori = self._create_coordinate_grid( size=(F_tokens, H_tokens, W_tokens), start=(cap_padding_len + 1, 0, 0), device=device, ).flatten(0, 2) # [ori_len, 3] pos_ids_pad = torch.zeros((padding_len, 3), dtype=torch.int32, device=device) pos_ids_padded = torch.cat([pos_ids_ori, pos_ids_pad], dim=0) c_freqs_cis_single = self.rope_embedder(pos_ids_padded) c_freqs_cis = c_freqs_cis_single.unsqueeze(0).repeat(B, 1, 1, 1) c_attn_mask = torch.ones((B, final_seq_len), dtype=torch.bool, device=device) return {"c": c, "c_item_seqlens": [final_seq_len] * B, "attn_mask": c_attn_mask, "freqs_cis": c_freqs_cis, "adaln_input": t.type_as(c)} (c_patches, _, c_pos_ids, c_inner_pad_mask) = self._patchify(control_list, patch_size, f_patch_size, cap_padding_len) c_item_seqlens = [len(p) for p in c_patches] c_max_item_seqlen = max(c_item_seqlens) c = torch.cat(c_patches, dim=0) c = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](c) c[torch.cat(c_inner_pad_mask)] = self.x_pad_token c = list(c.split(c_item_seqlens, dim=0)) c_freqs_cis_list = [] for pos_ids in c_pos_ids: c_freqs_cis_list.append(self.rope_embedder(pos_ids)) c_padded = pad_sequence(c, batch_first=True, padding_value=0.0) c_freqs_cis_padded = pad_sequence(c_freqs_cis_list, batch_first=True, padding_value=0.0) seq_lens_tensor = torch.tensor(c_item_seqlens, device=device, dtype=torch.int32) arange = torch.arange(c_max_item_seqlen, device=device, dtype=torch.int32) c_attn_mask = arange[None, :] < seq_lens_tensor[:, None] return {"c": c_padded, "c_item_seqlens": c_item_seqlens, "attn_mask": c_attn_mask, "freqs_cis": c_freqs_cis_padded, "adaln_input": t.type_as(c_padded)} def _patchify_and_embed_batch_optimized(self, all_image, all_cap_feats, patch_size, f_patch_size): """ An optimized version of _patchify_and_embed for batches where all images and captions have uniform shapes. It processes the entire batch using vectorized operations instead of a loop. Args: all_image (List[torch.Tensor]): List of image tensors, all of the same shape. all_cap_feats (List[torch.Tensor]): List of caption features, all of the same shape. patch_size (int): The spatial patch size. f_patch_size (int): The frame/temporal patch size. Returns: Tuple: A tuple containing all processed data structures, matching the output of the standard method. """ pH = pW = patch_size pF = f_patch_size device = all_image[0].device image_shapes = [img.shape for img in all_image] cap_shapes = [cap.shape for cap in all_cap_feats] same_image_shape = all(s == image_shapes[0] for s in image_shapes) same_cap_shape = all(s == cap_shapes[0] for s in cap_shapes) if not (same_image_shape and same_cap_shape): return self._patchify_and_embed(all_image, all_cap_feats, patch_size, f_patch_size) images_batch = torch.stack(all_image, dim=0) caps_batch = torch.stack(all_cap_feats, dim=0) B, C, Fr, H, W = images_batch.shape cap_ori_len = caps_batch.shape[1] cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF cap_total_len = cap_ori_len + cap_padding_len if cap_padding_len > 0: cap_pad = caps_batch[:, -1:, :].repeat(1, cap_padding_len, 1) caps_batch = torch.cat([caps_batch, cap_pad], dim=1) cap_pos_ids = self._create_coordinate_grid(size=(cap_total_len, 1, 1), start=(1, 0, 0), device=device).flatten(0, 2).unsqueeze(0).repeat(B, 1, 1) cap_mask = torch.zeros((B, cap_total_len), dtype=torch.bool, device=device) if cap_padding_len > 0: cap_mask[:, cap_ori_len:] = True F_tokens, H_tokens, W_tokens = Fr // pF, H // pH, W // pW images_reshaped = ( images_batch.view(B, C, F_tokens, pF, H_tokens, pH, W_tokens, pW) .permute(0, 2, 4, 6, 3, 5, 7, 1) .reshape(B, F_tokens * H_tokens * W_tokens, pF * pH * pW * C) ) image_ori_len = images_reshaped.shape[1] image_padding_len = (-image_ori_len) % SEQ_MULTI_OF image_total_len = image_ori_len + image_padding_len if image_padding_len > 0: img_pad = images_reshaped[:, -1:, :].repeat(1, image_padding_len, 1) images_reshaped = torch.cat([images_reshaped, img_pad], dim=1) image_pos_ids = ( self._create_coordinate_grid(size=(F_tokens, H_tokens, W_tokens), start=(cap_total_len + 1, 0, 0), device=device) .flatten(0, 2) .unsqueeze(0) .repeat(B, 1, 1) ) if image_padding_len > 0: img_pos_pad = torch.zeros((B, image_padding_len, 3), dtype=torch.int32, device=device) image_pos_ids = torch.cat([image_pos_ids, img_pos_pad], dim=1) image_mask = torch.zeros((B, image_total_len), dtype=torch.bool, device=device) if image_padding_len > 0: image_mask[:, image_ori_len:] = True all_image_size = [(Fr, H, W)] * B return ( list(torch.unbind(images_reshaped, dim=0)), list(torch.unbind(caps_batch, dim=0)), all_image_size, list(torch.unbind(image_pos_ids, dim=0)), list(torch.unbind(cap_pos_ids, dim=0)), list(torch.unbind(image_mask, dim=0)), list(torch.unbind(cap_mask, dim=0)), ) def forward( self, x: List[torch.Tensor], t, cap_feats: List[torch.Tensor], patch_size=2, f_patch_size=1, control_context=None, conditioning_scale=1.0, refiner_conditioning_scale=1.0, ): """ The main forward pass of the transformer model. Args: x (List[torch.Tensor]): A list of latent image tensors. t (torch.Tensor): A batch of timesteps. cap_feats (List[torch.Tensor]): A list of caption feature tensors. patch_size (int, optional): The spatial patch size to use. Defaults to 2. f_patch_size (int, optional): The frame/temporal patch size to use. Defaults to 1. control_context (torch.Tensor, optional): The control context tensor. Defaults to None. conditioning_scale (float, optional): The scale for applying control hints. Defaults to 1.0. refiner_conditioning_scale (float, optional): The scale for applying refiner control hints. Defaults to 1.0. Returns: Transformer2DModelOutput: An object containing the final denoised sample. """ is_control_mode = self.use_controlnet and control_context is not None and conditioning_scale > 0 if refiner_conditioning_scale is None: refiner_conditioning_scale = conditioning_scale or 1.0 assert patch_size in self.all_patch_size assert f_patch_size in self.all_f_patch_size bsz = len(x) device = x[0].device t = t * self.t_scale t = self.t_embedder(t) can_optimize_patchify = ( bsz == len(cap_feats) and bsz >= 2 and all(img.shape == x[0].shape for img in x) and all(cap.shape == cap_feats[0].shape for cap in cap_feats) ) if can_optimize_patchify: (x_list, cap_feats_list, x_size, x_pos_ids, cap_pos_ids, x_inner_pad_mask, cap_inner_pad_mask) = self._patchify_and_embed_batch_optimized( x, cap_feats, patch_size, f_patch_size ) else: (x_list, cap_feats_list, x_size, x_pos_ids, cap_pos_ids, x_inner_pad_mask, cap_inner_pad_mask) = self._patchify_and_embed( x, cap_feats, patch_size, f_patch_size ) x_item_seqlens = [len(i) for i in x_list] x_max_item_seqlen = max(x_item_seqlens) if x_item_seqlens else 0 x_cat = torch.cat(x_list, dim=0) if x_list else torch.empty(0, x_list[0].shape[1] if x_list else 0, device=device) x_embedded = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_cat) if x_inner_pad_mask and torch.cat(x_inner_pad_mask).any(): x_embedded[torch.cat(x_inner_pad_mask)] = self.x_pad_token x = pad_sequence(list(x_embedded.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0) adaln_input = t.to(device).type_as(x) cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens = self._process_cap_feats_with_cfg_cache( cap_feats_list, cap_pos_ids, cap_inner_pad_mask ) x_freqs_cis_cat = self.rope_embedder(torch.cat(x_pos_ids, dim=0)) if x_pos_ids else torch.empty(0, device=device) x_freqs_cis = pad_sequence(list(x_freqs_cis_cat.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0) seq_lens_tensor = torch.tensor(x_item_seqlens, device=device, dtype=torch.int32) arange = torch.arange(x_max_item_seqlen, device=device, dtype=torch.int32) x_attn_mask = arange[None, :] < seq_lens_tensor[:, None] refiner_hints = None if is_control_mode and self.is_two_stage_control: prepared_control = self._prepare_control_inputs(control_context, cap_feats_padded, t, patch_size, f_patch_size, device) c = prepared_control["c"] """ kwargs_for_control_refiner = { "x": x, "attn_mask": prepared_control["attn_mask"], "freqs_cis": prepared_control["freqs_cis"], "adaln_input": prepared_control["adaln_input"], } c_processed = self._apply_transformer_blocks( c, self.control_noise_refiner if self.add_control_noise_refiner else self.control_layers, checkpoint_ratio=self.checkpoint_ratio, **kwargs_for_control_refiner, ) refiner_hints = torch.unbind(c_processed)[:-1] control_context_processed = torch.unbind(c_processed)[-1] control_context_item_seqlens = prepared_control["c_item_seqlens"] """ kwargs_for_control_refiner = { "x": x, "attn_mask": x_attn_mask, # was prepared_control["attn_mask"] "freqs_cis": x_freqs_cis, # was prepared_control["freqs_cis"] "adaln_input": adaln_input, } c_processed = self._apply_transformer_blocks( c, self.control_noise_refiner if self.add_control_noise_refiner else self.control_layers, # KEEP ORIGINAL checkpoint_ratio=self.checkpoint_ratio, **kwargs_for_control_refiner, ) refiner_hints = torch.unbind(c_processed)[:-1] control_context_processed = torch.unbind(c_processed)[-1] control_context_item_seqlens = prepared_control["c_item_seqlens"] kwargs_for_refiner = { "attn_mask": x_attn_mask, "freqs_cis": x_freqs_cis, "adaln_input": adaln_input, "context_scale": refiner_conditioning_scale, } if refiner_hints is not None: kwargs_for_refiner["hints"] = refiner_hints x = self._apply_transformer_blocks(x, self.noise_refiner, checkpoint_ratio=1.0, **kwargs_for_refiner) kwargs_for_context = {"attn_mask": cap_attn_mask, "freqs_cis": cap_freqs_cis} cap_feats = self._apply_transformer_blocks(cap_feats_padded, self.context_refiner, checkpoint_ratio=1.0, **kwargs_for_context) unified_item_seqlens = [a + b for a, b in zip(x_item_seqlens, cap_item_seqlens)] unified_max_item_seqlen = max(unified_item_seqlens) if unified_item_seqlens else 0 unified = torch.zeros((bsz, unified_max_item_seqlen, x.shape[-1]), dtype=x.dtype, device=device) unified_freqs_cis = torch.zeros((bsz, unified_max_item_seqlen, x_freqs_cis.shape[-2], x_freqs_cis.shape[-1]), dtype=x_freqs_cis.dtype, device=device) for i in range(bsz): x_len = x_item_seqlens[i] cap_len = cap_item_seqlens[i] unified[i, :x_len] = x[i, :x_len] unified[i, x_len : x_len + cap_len] = cap_feats[i, :cap_len] unified_freqs_cis[i, :x_len] = x_freqs_cis[i, :x_len] unified_freqs_cis[i, x_len : x_len + cap_len] = cap_freqs_cis[i, :cap_len] seq_lens_tensor = torch.tensor(unified_item_seqlens, device=device, dtype=torch.int32) arange = torch.arange(unified_max_item_seqlen, device=device, dtype=torch.int32) unified_attn_mask = arange[None, :] < seq_lens_tensor[:, None] hints = None if is_control_mode: kwargs_for_hints = { "attn_mask": unified_attn_mask, "freqs_cis": unified_freqs_cis, "adaln_input": adaln_input, } if self.is_two_stage_control: control_context_unified_list = [ torch.cat([control_context_processed[i][: control_context_item_seqlens[i]], cap_feats[i, : cap_item_seqlens[i]]], dim=0) for i in range(bsz) ] c = pad_sequence(control_context_unified_list, batch_first=True, padding_value=0.0) new_kwargs = dict(x=unified, **kwargs_for_hints) c_processed = self._apply_transformer_blocks(c, self.control_layers, checkpoint_ratio=self.checkpoint_ratio, **new_kwargs) hints = torch.unbind(c_processed)[:-1] else: prepared_control = self._prepare_control_inputs(control_context, cap_feats_padded, t, patch_size, f_patch_size, device) c = prepared_control["c"] kwargs_for_v1_refiner = { "attn_mask": prepared_control["attn_mask"], "freqs_cis": prepared_control["freqs_cis"], "adaln_input": prepared_control["adaln_input"], } c = self._apply_transformer_blocks(c, self.control_noise_refiner, checkpoint_ratio=self.checkpoint_ratio, **kwargs_for_v1_refiner) c_item_seqlens = prepared_control["c_item_seqlens"] control_context_unified_list = [torch.cat([c[i, : c_item_seqlens[i]], cap_feats[i, : cap_item_seqlens[i]]], dim=0) for i in range(bsz)] c_unified = pad_sequence(control_context_unified_list, batch_first=True, padding_value=0.0) new_kwargs = dict(x=unified, **kwargs_for_hints) c_processed = self._apply_transformer_blocks(c_unified, self.control_layers, checkpoint_ratio=self.checkpoint_ratio, **new_kwargs) hints = torch.unbind(c_processed)[:-1] kwargs_for_layers = {"attn_mask": unified_attn_mask, "freqs_cis": unified_freqs_cis, "adaln_input": adaln_input} if hints is not None: kwargs_for_layers["hints"] = hints kwargs_for_layers["context_scale"] = conditioning_scale unified = self._apply_transformer_blocks(unified, self.layers, checkpoint_ratio=self.checkpoint_ratio, **kwargs_for_layers) unified_out = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) x_image_tokens = unified_out[:, :x_max_item_seqlen] x_final_tensor = self._unpatchify(x_image_tokens, x_size, patch_size, f_patch_size) return Transformer2DModelOutput(sample=x_final_tensor)