# Source: Z-Image-SAM-ControlNet / diffusers_local/z_image_control_transformer_2d.py
# Uploaded by neuralvfx in commit 7f0b483 ("Initial commit with large files tracked by LFS").
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
# Refactored and optimized by DEVAIEXP Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention_dispatch import dispatch_attention_fn
from diffusers.models.attention_processor import Attention, AttentionProcessor
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import RMSNorm
from diffusers.utils import (
is_torch_version,
)
from diffusers.utils.torch_utils import maybe_allow_in_graph
from torch.nn.utils.rnn import pad_sequence
# Upper bound on the width of the AdaLN conditioning vector: the timestep embedder outputs
# min(dim, ADALN_EMBED_DIM) features and every adaLN_modulation Linear consumes that width.
ADALN_EMBED_DIM: int = 256
# Caption and image token sequences are padded up to a multiple of this length.
SEQ_MULTI_OF: int = 32
def zero_module(module):
    """
    Overwrite every parameter of ``module`` with zeros, in place.

    Args:
        module (nn.Module): Module whose parameters (including those of submodules) are zeroed.

    Returns:
        nn.Module: The very same ``module`` object, to allow inline use such as
        ``zero_module(nn.Linear(a, b))``.
    """
    params = list(module.parameters())
    for tensor in params:
        nn.init.zeros_(tensor)
    return module
class TimestepEmbedder(nn.Module):
    """
    Turn scalar timesteps into feature vectors.

    The timesteps are first expanded into cosine/sine features, and those features are then fed
    through a two-layer MLP (Linear -> SiLU -> Linear).
    """

    def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
        """
        Build the MLP that follows the sinusoidal features.

        Args:
            out_size (int): Width of the returned embedding.
            mid_size (int, optional): Hidden width of the MLP; falls back to ``out_size`` when None.
            frequency_embedding_size (int, optional): Number of sinusoidal features fed to the MLP.
                Defaults to 256.
        """
        super().__init__()
        hidden = out_size if mid_size is None else mid_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden, bias=True),
            nn.SiLU(),
            nn.Linear(hidden, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build cosine/sine features for a 1-D tensor of timesteps.

        Frequencies are ``max_period ** (-k / (dim // 2))`` for ``k`` in ``[0, dim // 2)``. The output
        holds all cosines first, then all sines. One zero column is appended when ``dim`` is odd.

        Args:
            t (torch.Tensor): 1-D tensor of N timesteps.
            dim (int): Number of output features per timestep.
            max_period (int, optional): Controls the lowest frequency. Defaults to 10000.

        Returns:
            torch.Tensor: Tensor of shape (N, dim).
        """
        # Keep the trigonometry out of CUDA autocast so the frequencies are evaluated in float32.
        with torch.amp.autocast("cuda", enabled=False):
            half = dim // 2
            exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
            freqs = torch.exp(-math.log(max_period) * exponents / half)
            phases = t[:, None] * freqs[None]
            pieces = [torch.cos(phases), torch.sin(phases)]
            if dim % 2:
                pieces.append(torch.zeros_like(phases[:, :1]))
            return torch.cat(pieces, dim=-1)

    def forward(self, t):
        """
        Embed timesteps.

        Args:
            t (torch.Tensor): 1-D tensor of timesteps.

        Returns:
            torch.Tensor: MLP output of shape (N, out_size).
        """
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        target_dtype = self.mlp[0].weight.dtype
        # Match the MLP's dtype only for floating weights (e.g. leave quantized integer weights alone).
        return self.mlp(features.to(target_dtype) if target_dtype.is_floating_point else features)
class FeedForward(nn.Module):
    """
    Bias-free SwiGLU feed-forward layer: ``w2(silu(w1(x)) * w3(x))``.
    """

    def __init__(self, dim: int, hidden_dim: int):
        """
        Create the three projections of the SwiGLU block.

        Args:
            dim (int): Width of the input and of the output.
            hidden_dim (int): Width of the gated intermediate representation.
        """
        super().__init__()
        # Registration order (w1, w2, w3) is kept so parameter order and RNG use stay unchanged.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def _forward_silu_gating(self, x1, x3):
        """
        Gate ``x3`` with the SiLU of ``x1``.

        Args:
            x1 (torch.Tensor): Pre-activation branch (output of ``w1``).
            x3 (torch.Tensor): Linear branch (output of ``w3``).

        Returns:
            torch.Tensor: ``silu(x1) * x3``.
        """
        activated = F.silu(x1)
        return activated * x3

    def forward(self, x):
        """
        Apply the SwiGLU feed-forward transform.

        Args:
            x (torch.Tensor): Input whose last dimension is ``dim``.

        Returns:
            torch.Tensor: Output with the same last dimension ``dim``.
        """
        up = self.w1(x)
        gated = self._forward_silu_gating(up, self.w3(x))
        return self.w2(gated)
class FinalLayer(nn.Module):
    """
    Output head: scale-only AdaLN modulation of a parameter-free LayerNorm, then a linear projection.
    """

    def __init__(self, hidden_size, out_channels):
        """
        Create the normalization, output projection and modulation network.

        Args:
            hidden_size (int): Width of the incoming token features.
            out_channels (int): Width of each projected output token.
        """
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        cond_width = min(hidden_size, ADALN_EMBED_DIM)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(cond_width, hidden_size, bias=True))

    def forward(self, x, c):
        """
        Modulate, normalize and project the tokens.

        Args:
            x (torch.Tensor): Token features; assumed (batch, seq, hidden_size) — TODO confirm.
            c (torch.Tensor): Conditioning vector fed to ``adaLN_modulation``.

        Returns:
            torch.Tensor: ``linear(norm_final(x) * (1 + adaLN_modulation(c)))`` with the scale
            broadcast over dimension 1.
        """
        scale = self.adaLN_modulation(c) + 1.0
        normed = self.norm_final(x)
        return self.linear(normed * scale[:, None])
class RopeEmbedder:
    """
    Table-lookup rotary position embeddings for integer (F, H, W)-style coordinates.

    For each axis a table of shape ``(axes_lens[i], axes_dims[i] // 2, 2)`` holding
    ``(cos, sin)`` of ``position * theta ** (-k / (axes_dims[i] // 2))`` is built once per device
    and cached; embedding a coordinate is a row lookup per axis.
    """

    def __init__(self, theta: float = 256.0, axes_dims: List[int] = (32, 48, 48), axes_lens: List[int] = (1024, 512, 512)):
        """
        Store the RoPE configuration; tables are built lazily on first use.

        Args:
            theta (float, optional): Base of the geometric frequency progression. Defaults to 256.0.
            axes_dims (List[int], optional): Rotary width of each axis. Defaults to (32, 48, 48).
            axes_lens (List[int], optional): Number of precomputed positions per axis.
                Defaults to (1024, 512, 512).
        """
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        self.freqs_cis_cache = {}

    def _axis_table(self, dim, max_len, device):
        """Return the (max_len, dim // 2, 2) cos/sin table for a single axis."""
        half = dim // 2
        inv_freq = 1.0 / (self.theta ** (torch.arange(0, half, device=device, dtype=torch.float32) / half))
        angles = torch.outer(torch.arange(max_len, device=device, dtype=torch.float32), inv_freq)
        return torch.stack([angles.cos(), angles.sin()], dim=-1)

    def _precompute_freqs_cis(self, device):
        """
        Get the per-axis cos/sin tables for ``device``, building and caching them if needed.

        Args:
            device (torch.device): Device on which the tables live.

        Returns:
            List[torch.Tensor]: One table per axis (the same list object on repeated calls).
        """
        cached = self.freqs_cis_cache.get(device)
        if cached is not None:
            return cached
        tables = [self._axis_table(dim, max_len, device) for dim, max_len in zip(self.axes_dims, self.axes_lens)]
        self.freqs_cis_cache[device] = tables
        return tables

    def __call__(self, ids: torch.Tensor):
        """
        Look up rotary (cos, sin) pairs for a batch of integer coordinates.

        Args:
            ids (torch.Tensor): Integer tensor of shape (N, len(axes_dims)).

        Returns:
            torch.Tensor: Shape (N, sum(axes_dims) // 2, 2), the per-axis lookups concatenated
            along the frequency dimension.
        """
        assert ids.ndim == 2 and ids.shape[1] == len(self.axes_dims)
        tables = self._precompute_freqs_cis(ids.device)
        lookups = [tables[axis][ids[:, axis]] for axis in range(len(self.axes_dims))]
        return torch.cat(lookups, dim=-2)
class ZSingleStreamAttnProcessor:
    """
    Self-attention processor that applies rotary position embeddings (RoPE) to queries and keys
    before delegating the attention computation to diffusers' ``dispatch_attention_fn``.
    """

    # Populated by diffusers when an attention backend / parallel config is selected.
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        """
        Initializes the ZSingleStreamAttnProcessor.

        Raises:
            ImportError: If ``torch.nn.functional.scaled_dot_product_attention`` is unavailable.
        """
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher.")

    @staticmethod
    def _apply_rotary_emb(q_or_k: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        """
        Rotate consecutive channel pairs of ``q_or_k`` by the angles encoded in ``freqs_cis``.

        Args:
            q_or_k (torch.Tensor): (batch, seq, heads, head_dim) tensor as produced in ``__call__``.
            freqs_cis (torch.Tensor): (..., seq, head_dim // 2, 2) tensor whose last axis holds
                (cos, sin); it is broadcast over the heads axis.

        Returns:
            torch.Tensor: Rotated tensor with the same shape and dtype as ``q_or_k``.
        """
        # The rotation is done in float32, then cast back to the input dtype.
        x = q_or_k.float().reshape(*q_or_k.shape[:-1], -1, 2)
        x0 = x[..., 0]
        x1 = x[..., 1]
        # Insert a singleton heads axis: (..., seq, half) -> (..., seq, 1, half). This broadcasts
        # against (batch, seq, heads, half) without transposing (and copying) q/k.
        freqs_cos = freqs_cis[..., 0].unsqueeze(-2)
        freqs_sin = freqs_cis[..., 1].unsqueeze(-2)
        x_rotated = torch.stack((x0 * freqs_cos - x1 * freqs_sin, x0 * freqs_sin + x1 * freqs_cos), dim=-1)
        return x_rotated.flatten(-2).to(q_or_k.dtype)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        The forward call for the attention processor.

        Args:
            attn (Attention): The attention layer that this processor is attached to.
            hidden_states (torch.Tensor): The input hidden states; assumed (batch, seq, dim) — TODO confirm.
            encoder_hidden_states (Optional[torch.Tensor], optional): Not used; this is self-attention only.
            attention_mask (Optional[torch.Tensor], optional): Attention mask. A 2-D (batch, seq) mask is
                expanded to (batch, 1, 1, seq) before being passed on.
            freqs_cis (Optional[torch.Tensor], optional): Precomputed RoPE (cos, sin) pairs. When None,
                no rotary embedding is applied.

        Returns:
            torch.Tensor: Output of ``attn.to_out`` applied to the attended features.
        """
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        # Split the channel axis into heads: (..., heads * head_dim) -> (..., heads, head_dim).
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)
        if freqs_cis is not None:
            query = self._apply_rotary_emb(query, freqs_cis)
            key = self._apply_rotary_emb(key, freqs_cis)
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = attention_mask[:, None, None, :]
        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        # Merge heads back into the channel axis and cast to the query dtype. The original cast was
        # `.to(hidden_states.dtype)`, a no-op; the cast to query.dtype guards the output projection
        # against a backend that returns a different precision.
        hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
        output = attn.to_out[0](hidden_states)
        if len(attn.to_out) > 1:
            output = attn.to_out[1](output)
        return output
@maybe_allow_in_graph
class ZImageTransformerBlock(nn.Module):
    """
    Transformer block: RoPE self-attention followed by a SwiGLU feed-forward network.

    Each sub-layer's input is RMS-normalized and its output is RMS-normalized again before being
    added to the residual stream. When ``modulation`` is enabled, a linear projection of
    ``adaln_input`` yields per-sample scales (applied to the normalized inputs) and tanh-bounded
    gates (applied to the normalized sub-layer outputs).
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
    ):
        """
        Initializes the ZImageTransformerBlock.

        Args:
            layer_id (int): The index of the layer (stored, not used by the visible code).
            dim (int): The dimension of the input and output features.
            n_heads (int): The number of attention heads.
            n_kv_heads (int): Accepted for config compatibility; not used by this block.
            norm_eps (float): Epsilon for the four RMSNorm layers.
            qk_norm (bool): Whether the attention layer normalizes queries and keys ("rms_norm").
            modulation (bool, optional): Whether to create and use the AdaLN modulation projection.
                Defaults to True.
        """
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        # NOTE(review): the q/k norm eps is hard-coded to 1e-5 rather than norm_eps — confirm intended.
        self.attention = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // n_heads,
            heads=n_heads,
            qk_norm="rms_norm" if qk_norm else None,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=ZSingleStreamAttnProcessor(),
        )
        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
        self.layer_id = layer_id
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
        self.modulation = modulation
        if modulation:
            # Produces 4 * dim values: scale/gate for attention and scale/gate for the FFN.
            self.adaLN_modulation = nn.Sequential(
                nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
            )

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        """
        Returns a dictionary of all attention processors used in the module, keyed by
        ``"<module path>.processor"``.
        """
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)
        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        """
        Sets the attention processor(s) for the attention layers in this block.

        Args:
            processor: Either a single processor applied to every submodule exposing ``set_processor``,
                or a dict keyed like :attr:`attn_processors`. The caller's dict is not modified.

        Raises:
            ValueError: If a dict is passed whose size differs from the number of attention layers.
            KeyError: If a dict is passed that lacks the key for one of the attention layers.
        """
        count = len(self.attn_processors.keys())
        if isinstance(processor, dict):
            if len(processor) != count:
                raise ValueError(
                    f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                    f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
                )
            # Pop from a shallow copy so the caller's mapping is left intact.
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(self, x, attn_mask, freqs_cis, adaln_input=None):
        """
        Defines the forward pass for the transformer block.

        Args:
            x (torch.Tensor): The input tensor.
            attn_mask (torch.Tensor): The attention mask forwarded to the attention layer.
            freqs_cis (torch.Tensor): The RoPE frequencies forwarded to the attention layer.
            adaln_input (torch.Tensor, optional): Conditioning tensor for AdaLN; required when the
                block was built with ``modulation=True``. Defaults to None.

        Returns:
            torch.Tensor: The output tensor of the block (same shape as ``x``).

        Raises:
            ValueError: If modulation is enabled but ``adaln_input`` is None.
        """
        if self.modulation:
            if adaln_input is None:
                raise ValueError("adaln_input is required when the block is built with modulation=True.")
            # Assuming adaln_input is (batch, cond_dim): each chunk becomes (batch, 1, dim) and broadcasts over tokens.
            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
            scale_msa = scale_msa + 1.0
            gate_msa = gate_msa.tanh()
            scale_mlp = scale_mlp + 1.0
            gate_mlp = gate_mlp.tanh()
            normed = self.attention_norm1(x)
            normed = normed * scale_msa
            attn_out = self.attention(normed, attention_mask=attn_mask, freqs_cis=freqs_cis)
            attn_out = self.attention_norm2(attn_out) * gate_msa
            x = x + attn_out
            normed = self.ffn_norm1(x)
            normed = normed * scale_mlp
            ffn_out = self.feed_forward(normed)
            ffn_out = self.ffn_norm2(ffn_out) * gate_mlp
            x = x + ffn_out
        else:
            normed = self.attention_norm1(x)
            attn_out = self.attention(normed, attention_mask=attn_mask, freqs_cis=freqs_cis)
            x = x + self.attention_norm2(attn_out)
            normed = self.ffn_norm1(x)
            ffn_out = self.feed_forward(normed)
            x = x + self.ffn_norm2(ffn_out)
        return x
class ZImageControlTransformerBlock(ZImageTransformerBlock):
    """
    Control-branch variant of ZImageTransformerBlock.

    Control blocks are chained through a single stacked tensor. The first block (``block_id == 0``)
    seeds the chain from the control features. Every block appends a zero-initialised projection of
    its output (the hint consumed by the main branch) followed by its raw output, which the next
    control block continues from.
    """

    def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, qk_norm: bool, modulation=True, block_id=0):
        """
        Build the underlying transformer block plus the control projections.

        Args:
            layer_id (int): Index of the layer.
            dim (int): Feature width.
            n_heads (int): Number of attention heads.
            n_kv_heads (int): Number of key/value heads (forwarded to the base block).
            norm_eps (float): RMSNorm epsilon.
            qk_norm (bool): Whether queries and keys are normalized.
            modulation (bool, optional): Whether AdaLN modulation is enabled. Defaults to True.
            block_id (int, optional): Position marker; only ``0`` is special (it owns
                ``before_proj``). Defaults to 0.
        """
        super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
        self.block_id = block_id
        # Both projections start at zero, so initially they contribute nothing.
        if block_id == 0:
            self.before_proj = zero_module(nn.Linear(self.dim, self.dim))
        self.after_proj = zero_module(nn.Linear(self.dim, self.dim))

    def forward(self, c, x, **kwargs):
        """
        Run one control step.

        Args:
            c (torch.Tensor): For the first block, the embedded control features. Otherwise, the
                stacked tensor returned by the previous control block.
            x (torch.Tensor): Main-branch features, added to ``before_proj(c)`` by the first block
                only; ignored otherwise.
            **kwargs: Forwarded to ``ZImageTransformerBlock.forward`` (mask, RoPE, AdaLN input).

        Returns:
            torch.Tensor: Stack of all hints produced so far, this block's hint, and this block's output.
        """
        if self.block_id == 0:
            history = []
            current = self.before_proj(c) + x
        else:
            history = list(torch.unbind(c))
            current = history.pop()
        current = super().forward(current, **kwargs)
        history.append(self.after_proj(current))
        history.append(current)
        return torch.stack(history)
class BaseZImageTransformerBlock(ZImageTransformerBlock):
    """
    Main-branch transformer block that can add a scaled control hint to its output.
    """

    def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, qk_norm: bool, modulation=True, block_id=0):
        """
        Build the underlying transformer block and remember which hint (if any) to consume.

        Args:
            layer_id (int): Index of the layer.
            dim (int): Feature width.
            n_heads (int): Number of attention heads.
            n_kv_heads (int): Number of key/value heads (forwarded to the base block).
            norm_eps (float): RMSNorm epsilon.
            qk_norm (bool): Whether queries and keys are normalized.
            modulation (bool, optional): Whether AdaLN modulation is enabled. Defaults to True.
            block_id (int or None, optional): Index into ``hints``; ``None`` means this layer never
                receives a hint. Defaults to 0.
        """
        super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
        self.block_id = block_id

    def forward(self, hidden_states, hints=None, context_scale=1.0, **kwargs):
        """
        Run the transformer block, then optionally inject a control hint.

        Args:
            hidden_states (torch.Tensor): Input features.
            hints (sequence of torch.Tensor, optional): Control hints indexed by ``block_id``.
            context_scale (float, optional): Multiplier applied to the hint. Defaults to 1.0.
            **kwargs: Forwarded to ``ZImageTransformerBlock.forward``.

        Returns:
            torch.Tensor: Block output, plus ``hints[block_id] * context_scale`` when both a hint list
            and a ``block_id`` are available.
        """
        out = super().forward(hidden_states, **kwargs)
        if self.block_id is None or hints is None:
            return out
        return out + hints[self.block_id] * context_scale
class ZImageControlTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
_keys_to_ignore_on_load_unexpected = [
r"control_layers\..*",
r"control_noise_refiner\..*",
r"control_all_x_embedder\..*",
]
_no_split_modules = ["ZImageTransformerBlock", "BaseZImageTransformerBlock", "ZImageControlTransformerBlock"]
_skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"]
_group_offload_block_modules = ["t_embedder", "cap_embedder"]
@register_to_config
def __init__(
    self,
    control_layers_places=None,
    control_refiner_layers_places=None,
    control_in_dim=None,
    add_control_noise_refiner=False,
    all_patch_size=(2,),
    all_f_patch_size=(1,),
    in_channels=16,
    dim=3840,
    n_layers=30,
    n_refiner_layers=2,
    n_heads=30,
    n_kv_heads=30,
    norm_eps=1e-5,
    qk_norm=True,
    cap_feat_dim=2560,
    rope_theta=256.0,
    t_scale=1000.0,
    axes_dims=[32, 48, 48],
    axes_lens=[1024, 512, 512],
    use_controlnet=True,
    checkpoint_ratio=0.5,
):
    """
    Initializes the ZImageControlTransformer2DModel.

    Args:
        control_layers_places (List[int], optional): Indices of main layers where control hints are
            injected. Defaults to every second layer (0, 2, 4, ...). Must contain 0.
        control_refiner_layers_places (List[int], optional): Indices of noise refiner layers that
            receive hints. Defaults to all refiner layers.
        control_in_dim (int, optional): Input channel dimension for the control context; defaults
            to ``dim``.
        add_control_noise_refiner (bool, optional): In two-stage mode, whether to build a dedicated
            control noise refiner.
        all_patch_size (Tuple[int], optional): Spatial patch sizes, one per supported patching.
        all_f_patch_size (Tuple[int], optional): Frame patch sizes; must match ``all_patch_size`` in length.
        in_channels (int, optional): Number of latent channels (also the output channel count).
        dim (int, optional): The main dimension of the transformer model.
        n_layers (int, optional): The number of main transformer layers.
        n_refiner_layers (int, optional): The number of layers in each refiner stack.
        n_heads (int, optional): The number of attention heads.
        n_kv_heads (int, optional): The number of key/value heads.
        norm_eps (float, optional): Epsilon for RMSNorm.
        qk_norm (bool, optional): Whether to apply normalization to query and key.
        cap_feat_dim (int, optional): The dimension of the input caption features.
        rope_theta (float, optional): The base for RoPE.
        t_scale (float, optional): A scaling factor for the timestep.
        axes_dims (List[int], optional): RoPE width per axis; must sum to ``dim // n_heads``.
        axes_lens (List[int], optional): Maximum lengths for each axis in RoPE.
        use_controlnet (bool, optional): If False, control-related layers are not created to save memory.
        checkpoint_ratio (float, optional): The ratio of layers to apply gradient checkpointing to.

    Raises:
        ValueError: If ``all_patch_size`` and ``all_f_patch_size`` differ in length, if 0 is not in
            the control layer places, or if ``dim // n_heads`` differs from ``sum(axes_dims)``.
    """
    super().__init__()
    self.use_controlnet = use_controlnet
    self.in_channels = in_channels
    self.out_channels = in_channels
    self.all_patch_size = all_patch_size
    self.all_f_patch_size = all_f_patch_size
    self.dim = dim
    self.control_in_dim = self.dim if control_in_dim is None else control_in_dim
    # Control inputs wider than 16 channels select the two-stage layout; otherwise the "V1" layout
    # below is used. (16 presumably matches the default latent channel count — TODO confirm.)
    self.is_two_stage_control = self.control_in_dim > 16
    self.n_heads = n_heads
    self.rope_theta = rope_theta
    self.t_scale = t_scale
    self.gradient_checkpointing = False
    self.checkpoint_ratio = checkpoint_ratio
    # Explicit checks instead of `assert`, so invalid configs are rejected even under `python -O`.
    if len(all_patch_size) != len(all_f_patch_size):
        raise ValueError(
            f"all_patch_size and all_f_patch_size must have the same length, got {len(all_patch_size)} and {len(all_f_patch_size)}."
        )
    self.control_layers_places = list(range(0, n_layers, 2)) if control_layers_places is None else control_layers_places
    self.control_refiner_layers_places = list(range(0, n_refiner_layers)) if control_refiner_layers_places is None else control_refiner_layers_places
    self.add_control_noise_refiner = add_control_noise_refiner
    # The control block created for layer 0 is the only one with `before_proj`, i.e. it starts the hint chain.
    if 0 not in self.control_layers_places:
        raise ValueError(f"control_layers_places must contain layer 0, got {self.control_layers_places}.")
    # Map layer index -> position in the hint list consumed by BaseZImageTransformerBlock.
    self.control_layers_mapping = {i: n for n, i in enumerate(self.control_layers_places)}
    self.control_refiner_layers_mapping = {i: n for n, i in enumerate(self.control_refiner_layers_places)}
    # NOTE: module creation order below is kept as-is; it determines parameter order and RNG use.
    self.all_x_embedder = nn.ModuleDict(
        {
            f"{patch_size}-{f_patch_size}": nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
            for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size)
        }
    )
    self.all_final_layer = nn.ModuleDict(
        {
            f"{patch_size}-{f_patch_size}": FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
            for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size)
        }
    )
    self.context_refiner = nn.ModuleList(
        [ZImageTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=False) for i in range(n_refiner_layers)]
    )
    self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
    self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))
    # Uninitialized here; expected to be filled from a checkpoint.
    self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
    self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
    head_dim = dim // n_heads
    if head_dim != sum(axes_dims):
        raise ValueError(f"dim // n_heads ({head_dim}) must equal sum(axes_dims) ({sum(axes_dims)}).")
    self.axes_dims = axes_dims
    self.axes_lens = axes_lens
    self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
    self.layers = nn.ModuleList(
        [BaseZImageTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=self.control_layers_mapping.get(i)) for i in range(n_layers)]
    )
    # Refiner layer ids are offset by 1000, presumably to keep them distinct from main-layer ids.
    self.noise_refiner = nn.ModuleList(
        [
            BaseZImageTransformerBlock(
                1000 + i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True, block_id=self.control_refiner_layers_mapping.get(i)
            )
            for i in range(n_refiner_layers)
        ]
    )
    if self.use_controlnet:
        self.control_layers = nn.ModuleList(
            [ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i) for i in self.control_layers_places]
        )
        self.control_all_x_embedder = nn.ModuleDict(
            {
                f"{patch_size}-{f_patch_size}": nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True)
                for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size)
            }
        )
        if self.is_two_stage_control:
            if self.add_control_noise_refiner:
                self.control_noise_refiner = nn.ModuleList(
                    [
                        ZImageControlTransformerBlock(1000 + layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True, block_id=layer_id)
                        for layer_id in range(n_refiner_layers)
                    ]
                )
            else:
                self.control_noise_refiner = None
        else:  # V1
            self.control_noise_refiner = nn.ModuleList(
                [ZImageTransformerBlock(1000 + i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True) for i in range(n_refiner_layers)]
            )
    else:
        self.control_layers = None
        self.control_all_x_embedder = None
        self.control_noise_refiner = None
def _unpatchify(self, x_image_tokens: torch.Tensor, all_sizes: List[Tuple], patch_size: int, f_patch_size: int) -> torch.Tensor:
    """
    Convert per-image token sequences back into a batched latent tensor.

    Each row of ``x_image_tokens`` may end with padding tokens. Only the first
    ``(F // pF) * (H // pH) * (W // pW)`` tokens of row ``i`` are used, where ``(F, H, W)`` is
    ``all_sizes[i]``. Rows can therefore carry different amounts of padding. However, the unpatched
    images must all share one shape to be stacked into a batch.

    Args:
        x_image_tokens (torch.Tensor): Tokens of shape [B, SeqLen, pF * pH * pW * out_channels].
        all_sizes (List[Tuple]): The original (F, H, W) size for each image in the batch.
        patch_size (int): The spatial patch size (height and width).
        f_patch_size (int): The frame/temporal patch size.

    Returns:
        torch.Tensor: The reconstructed latent tensor with shape [B, C, F, H, W].

    Raises:
        ValueError: If the reconstructed images cannot be stacked (e.g. differing sizes).
    """
    pH = pW = patch_size
    pF = f_patch_size
    channels = self.out_channels
    unpatched_images = []
    for i in range(x_image_tokens.shape[0]):
        # Distinct names avoid shadowing the module-level `F` (torch.nn.functional).
        frames, height, width = all_sizes[i]
        f_tokens, h_tokens, w_tokens = frames // pF, height // pH, width // pW
        seq_len = f_tokens * h_tokens * w_tokens
        tokens = x_image_tokens[i, :seq_len]
        # `reshape` (rather than `view`) also accepts non-contiguous token tensors.
        image = tokens.reshape(f_tokens, h_tokens, w_tokens, pF, pH, pW, channels)
        image = image.permute(6, 0, 3, 1, 4, 2, 5).reshape(channels, frames, height, width)
        unpatched_images.append(image)
    try:
        return torch.stack(unpatched_images, dim=0)
    except RuntimeError as err:
        raise ValueError(
            "Could not stack unpatched images into a single batch tensor. "
            "This typically occurs if you are trying to generate images of different sizes in the same batch."
        ) from err
def _patchify(
    self,
    all_image: List[torch.Tensor],
    patch_size: int,
    f_patch_size: int,
    cap_padding_len: int,
):
    """
    Split each (C, F, H, W) image into flattened patches padded to a multiple of SEQ_MULTI_OF.

    Real patches get coordinate ids from ``self._create_coordinate_grid``. The grid starts at frame
    offset ``cap_padding_len + 1``, so image positions follow the caption positions. Padding
    slots repeat the last real patch. They take their ids from a (1, 1, 1) grid started at
    (0, 0, 0) (presumably all-zero ids — TODO confirm against `_create_coordinate_grid`), and are
    flagged True in the padding mask.

    Args:
        all_image (List[torch.Tensor]): Images to patchify, each (C, F, H, W).
        patch_size (int): Spatial patch size (height and width).
        f_patch_size (int): Frame patch size.
        cap_padding_len (int): Length of the padded caption sequence, used as the frame offset of
            image position ids.

    Returns:
        Tuple: (padded patch tensors, original (F, H, W) sizes, position ids, padding masks), each a
        list with one entry per image.
    """
    pH = pW = patch_size
    pF = f_patch_size
    device = all_image[0].device
    patches, sizes, pos_ids, pad_masks = [], [], [], []
    for image in all_image:
        channels, frames, height, width = image.size()
        sizes.append((frames, height, width))
        f_tok, h_tok, w_tok = frames // pF, height // pH, width // pW
        tokens = (
            image.view(channels, f_tok, pF, h_tok, pH, w_tok, pW)
            .permute(1, 3, 5, 2, 4, 6, 0)
            .reshape(f_tok * h_tok * w_tok, pF * pH * pW * channels)
        )
        n_real = len(tokens)
        n_pad = (-n_real) % SEQ_MULTI_OF
        real_ids = self._create_coordinate_grid(
            size=(f_tok, h_tok, w_tok),
            start=(cap_padding_len + 1, 0, 0),
            device=device,
        ).flatten(0, 2)
        pad_ids = self._create_coordinate_grid(
            size=(1, 1, 1),
            start=(0, 0, 0),
            device=device,
        ).flatten(0, 2).repeat(n_pad, 1)
        pos_ids.append(torch.cat([real_ids, pad_ids], dim=0))
        real_flags = torch.zeros((n_real,), dtype=torch.bool, device=device)
        pad_flags = torch.ones((n_pad,), dtype=torch.bool, device=device)
        pad_masks.append(torch.cat([real_flags, pad_flags], dim=0))
        patches.append(torch.cat([tokens, tokens[-1:].repeat(n_pad, 1)], dim=0))
    return patches, sizes, pos_ids, pad_masks
def _patchify_and_embed(
    self,
    all_image: List[torch.Tensor],
    all_cap_feats: List[torch.Tensor],
    patch_size: int,
    f_patch_size: int,
):
    """
    Pad captions and patchify images, producing per-sample sequences, coordinate ids and masks.

    Caption features are padded to a multiple of SEQ_MULTI_OF by repeating their last row.
    Their ids come from a (total_len, 1, 1) grid starting at (1, 0, 0). Image patches are padded
    the same way. Their ids come from a grid whose frame axis starts right after the padded
    caption, and padding patches get all-zero int32 ids. In every mask, True marks a padding position.

    Args:
        all_image (List[torch.Tensor]): Images, each (C, F, H, W).
        all_cap_feats (List[torch.Tensor]): Caption features, each (L, cap_dim); paired with the
            images via ``zip``.
        patch_size (int): Spatial patch size.
        f_patch_size (int): Frame patch size.

    Returns:
        Tuple: (image patches, caption features, image sizes, image position ids, caption position
        ids, image padding masks, caption padding masks), each a list with one entry per sample.
    """
    pH = pW = patch_size
    pF = f_patch_size
    device = all_image[0].device
    image_tokens_out, image_sizes, image_pos_ids, image_pad_masks = [], [], [], []
    cap_pos_ids, cap_pad_masks, cap_feats_out = [], [], []
    for image, cap_feat in zip(all_image, all_cap_feats):
        # Caption: pad, assign ids, build mask.
        cap_len = len(cap_feat)
        cap_pad = (-cap_len) % SEQ_MULTI_OF
        cap_total = cap_len + cap_pad
        cap_pos_ids.append(self._create_coordinate_grid(size=(cap_total, 1, 1), start=(1, 0, 0), device=device).flatten(0, 2))
        cap_mask = torch.ones(cap_total, dtype=torch.bool, device=device)
        cap_mask[:cap_len] = False
        cap_pad_masks.append(cap_mask)
        cap_feats_out.append(torch.cat([cap_feat, cap_feat[-1:].repeat(cap_pad, 1)], dim=0) if cap_pad > 0 else cap_feat)

        # Image: patchify, pad, assign ids (offset past the padded caption), build mask.
        channels, frames, height, width = image.size()
        image_sizes.append((frames, height, width))
        f_tok, h_tok, w_tok = frames // pF, height // pH, width // pW
        tokens = image.view(channels, f_tok, pF, h_tok, pH, w_tok, pW).permute(1, 3, 5, 2, 4, 6, 0).reshape(-1, pF * pH * pW * channels)
        n_real = tokens.shape[0]
        n_pad = (-n_real) % SEQ_MULTI_OF
        ids = self._create_coordinate_grid(size=(f_tok, h_tok, w_tok), start=(cap_total + 1, 0, 0), device=device).flatten(0, 2)
        if n_pad > 0:
            ids = torch.cat([ids, torch.zeros((n_pad, 3), dtype=torch.int32, device=device)], dim=0)
        image_pos_ids.append(ids)
        image_mask = torch.ones(n_real + n_pad, dtype=torch.bool, device=device)
        image_mask[:n_real] = False
        image_pad_masks.append(image_mask)
        if n_pad > 0:
            tokens = torch.cat([tokens, tokens[-1:].repeat(n_pad, 1)], dim=0)
        image_tokens_out.append(tokens)
    return (
        image_tokens_out,
        cap_feats_out,
        image_sizes,
        image_pos_ids,
        cap_pos_ids,
        image_pad_masks,
        cap_pad_masks,
    )
def _process_cap_feats_with_cfg_cache(self, cap_feats_list, cap_pos_ids, cap_inner_pad_mask):
    """
    Embeds caption features and builds their RoPE frequencies and attention mask, computing each
    distinct caption only once.

    With Classifier-Free Guidance the same prompt often appears several times in one batch. When every
    caption has the same padded shape, exact duplicates are detected and only the unique captions are
    passed through ``cap_embedder`` and ``rope_embedder``. The results are then gathered back into the
    original batch order.

    Two captions count as duplicates only if BOTH their padded features and their padding masks are
    equal. Padding repeats the last real token, so captions of different true lengths can produce
    identical padded feature tensors while still requiring different pad masks.

    Args:
        cap_feats_list (List[torch.Tensor]): Padded caption feature tensors, one per batch item.
        cap_pos_ids (List[torch.Tensor]): Position ID tensors matching ``cap_feats_list``.
        cap_inner_pad_mask (List[torch.Tensor]): Boolean masks (True marks a padding position) matching
            ``cap_feats_list``.
    Returns:
        Tuple: ``(cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens)``. The first three are
        batched tensors whose batch dimension is ``len(cap_feats_list)``; ``cap_item_seqlens`` is a list of
        per-item sequence lengths.
    """
    device = cap_feats_list[0].device
    bsz = len(cap_feats_list)

    # Batch positions that are actually embedded, plus (only when duplicates were found) the mapping
    # from each batch position to its row in that embedded subset.
    compute_indices = list(range(bsz))
    batch_to_unique = None
    if bsz >= 2 and all(c.shape == cap_feats_list[0].shape for c in cap_feats_list):
        unique_indices = []
        mapping = []
        for i in range(bsz):
            for slot, j in enumerate(unique_indices):
                if torch.equal(cap_feats_list[i], cap_feats_list[j]) and torch.equal(cap_inner_pad_mask[i], cap_inner_pad_mask[j]):
                    mapping.append(slot)
                    break
            else:
                unique_indices.append(i)
                mapping.append(len(unique_indices) - 1)
        if len(unique_indices) < bsz:
            compute_indices = unique_indices
            batch_to_unique = mapping

    feats = [cap_feats_list[i] for i in compute_indices]
    pos_ids = [cap_pos_ids[i] for i in compute_indices]
    pad_masks = [cap_inner_pad_mask[i] for i in compute_indices]
    item_seqlens = [len(f) for f in feats]

    embedded = self.cap_embedder(torch.cat(feats, dim=0))
    # Padded caption positions carry the learned caption pad token.
    embedded[torch.cat(pad_masks)] = self.cap_pad_token
    cap_feats_padded = pad_sequence(list(embedded.split(item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
    freqs_cat = self.rope_embedder(torch.cat(pos_ids, dim=0))
    cap_freqs_cis = pad_sequence(list(freqs_cat.split(item_seqlens, dim=0)), batch_first=True, padding_value=0.0)

    if batch_to_unique is not None:
        # Expand the unique results back to the full batch order.
        cap_feats_padded = cap_feats_padded[batch_to_unique]
        cap_freqs_cis = cap_freqs_cis[batch_to_unique]
        item_seqlens = [item_seqlens[k] for k in batch_to_unique]

    max_seqlen = max(item_seqlens)
    seq_lens_tensor = torch.tensor(item_seqlens, device=device, dtype=torch.int32)
    arange = torch.arange(max_seqlen, device=device, dtype=torch.int32)
    # True for positions inside each item's sequence length.
    cap_attn_mask = arange[None, :] < seq_lens_tensor[:, None]
    return cap_feats_padded, cap_freqs_cis, cap_attn_mask, item_seqlens
@staticmethod
def _create_coordinate_grid(size, start=None, device=None):
    """
    Builds an int32 coordinate grid spanning ``size``.

    For the usual 3-axis case, ``grid[f, h, w]`` equals ``(start[0] + f, start[1] + h, start[2] + w)``.

    Args:
        size (Tuple[int]): Extent along each axis, e.g. (F, H, W).
        start (Tuple[int], optional): First coordinate along each axis; zeros when omitted.
        device (torch.device, optional): Device on which the grid is allocated.
    Returns:
        torch.Tensor: Grid whose last dimension stacks the per-axis coordinates.
    """
    offsets = [0 for _ in size] if start is None else start
    ranges = []
    for offset, extent in zip(offsets, size):
        ranges.append(torch.arange(offset, offset + extent, dtype=torch.int32, device=device))
    mesh = torch.meshgrid(*ranges, indexing="ij")
    return torch.stack(mesh, dim=-1)
def _apply_transformer_blocks(self, hidden_states, layers, checkpoint_ratio=0.5, **kwargs):
    """
    Feeds ``hidden_states`` through ``layers`` in order, optionally checkpointing some of them.

    Checkpointing is only considered while autograd is enabled and ``self.gradient_checkpointing`` is set.
    Then every ``max(1, int(1 / checkpoint_ratio))``-th layer, starting with the first, runs under
    ``torch.utils.checkpoint.checkpoint``. A non-positive ratio disables checkpointing entirely.

    Args:
        hidden_states (torch.Tensor): Input to the first layer.
        layers (nn.ModuleList): Layers applied sequentially, each called as ``layer(hidden_states, **kwargs)``.
        checkpoint_ratio (float, optional): Approximate fraction of layers to checkpoint. Defaults to 0.5.
        **kwargs: Extra keyword arguments forwarded unchanged to every layer.
    Returns:
        torch.Tensor: Output of the last layer, or the input itself when ``layers`` is empty.
    """
    if not (torch.is_grad_enabled() and self.gradient_checkpointing):
        for layer in layers:
            hidden_states = layer(hidden_states, **kwargs)
        return hidden_states

    def _bind_kwargs(module, **bound):
        # Keyword arguments are closed over so only ``hidden_states`` is handed to checkpoint positionally.
        def _call(*args):
            return module(*args, **bound)

        return _call

    checkpoint_options = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
    if checkpoint_ratio > 0:
        stride = max(1, int(1.0 / checkpoint_ratio))
    else:
        stride = len(layers) + 1
    for position, layer in enumerate(layers):
        if position % stride != 0:
            hidden_states = layer(hidden_states, **kwargs)
            continue
        hidden_states = torch.utils.checkpoint.checkpoint(
            _bind_kwargs(layer, **kwargs),
            hidden_states,
            **checkpoint_options,
        )
    return hidden_states
def _prepare_control_inputs(self, control_context, cap_feats_ref, t, patch_size, f_patch_size, device):
    """
    Patchifies and embeds the control context, and builds its RoPE frequencies and attention mask.

    Control tokens are laid out like image tokens:
    - Each item's patch sequence is right-padded to a multiple of ``SEQ_MULTI_OF`` by repeating its
      last patch.
    - Padded positions are overwritten with ``self.x_pad_token`` after embedding.
    - RoPE positions start right after the padded caption.

    When the batch has at least two items of identical shape, the whole batch is handled in one
    vectorized pass. Otherwise the per-item ``self._patchify`` path is used. The fast path mirrors the
    per-item path's padding and pad-token handling, so results do not depend on the batch composition.

    Args:
        control_context (torch.Tensor or List[torch.Tensor]): Either a ``(B, C, F, H, W)`` tensor or a list
            of per-item ``(C, F, H, W)`` tensors.
        cap_feats_ref (torch.Tensor or List[torch.Tensor]): Caption features, used only to read the padded
            caption length (``size(0)`` of the first list item, or ``shape[1]`` of a batched tensor).
        t (torch.Tensor): The timestep tensor; returned cast to the control dtype as ``adaln_input``.
        patch_size (int): The spatial patch size.
        f_patch_size (int): The frame/temporal patch size.
        device (torch.device): Device for the generated position ids and masks.
    Returns:
        Dict: Keys ``'c'``, ``'c_item_seqlens'``, ``'attn_mask'``, ``'freqs_cis'`` and ``'adaln_input'``.
    """
    if isinstance(control_context, torch.Tensor) and control_context.ndim == 5:
        control_list = list(torch.unbind(control_context, dim=0))
    else:
        control_list = control_context
    # Read from the per-item list so list inputs (which have no ``.shape``) are supported.
    bsz = len(control_list)
    pH = pW = patch_size
    pF = f_patch_size
    embed_key = f"{patch_size}-{f_patch_size}"
    # Padded caption length; control RoPE positions start right after it, like the image tokens.
    cap_total_len = cap_feats_ref[0].size(0) if isinstance(cap_feats_ref, list) else cap_feats_ref.shape[1]

    if bsz >= 2 and all(item.shape == control_list[0].shape for item in control_list):
        control_batch = torch.stack(control_list, dim=0)
        # ``Fr`` (not ``F``) so the module-level ``torch.nn.functional as F`` is not shadowed.
        B, C, Fr, H, W = control_batch.shape
        F_tokens, H_tokens, W_tokens = Fr // pF, H // pH, W // pW
        control_batch = (
            control_batch.view(B, C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
            .permute(0, 2, 4, 6, 3, 5, 7, 1)
            .reshape(B, F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
        )
        ori_len = control_batch.shape[1]
        padding_len = (-ori_len) % SEQ_MULTI_OF
        if padding_len > 0:
            pad_tensor = control_batch[:, -1:, :].repeat(1, padding_len, 1)
            control_batch = torch.cat([control_batch, pad_tensor], dim=1)
        c = self.control_all_x_embedder[embed_key](control_batch)
        if padding_len > 0:
            # Same treatment as the per-item path below: padded positions carry the learned pad token.
            c[:, ori_len:] = self.x_pad_token
        final_seq_len = control_batch.shape[1]
        pos_ids_ori = self._create_coordinate_grid(
            size=(F_tokens, H_tokens, W_tokens),
            start=(cap_total_len + 1, 0, 0),
            device=device,
        ).flatten(0, 2)
        pos_ids_pad = torch.zeros((padding_len, 3), dtype=torch.int32, device=device)
        pos_ids_padded = torch.cat([pos_ids_ori, pos_ids_pad], dim=0)
        c_freqs_cis_single = self.rope_embedder(pos_ids_padded)
        # Every item shares the same positions; replicate along a new batch dim for any output rank.
        c_freqs_cis = c_freqs_cis_single.unsqueeze(0).repeat(B, *([1] * c_freqs_cis_single.ndim))
        c_attn_mask = torch.ones((B, final_seq_len), dtype=torch.bool, device=device)
        return {"c": c, "c_item_seqlens": [final_seq_len] * B, "attn_mask": c_attn_mask, "freqs_cis": c_freqs_cis, "adaln_input": t.type_as(c)}

    (c_patches, _, c_pos_ids, c_inner_pad_mask) = self._patchify(control_list, patch_size, f_patch_size, cap_total_len)
    c_item_seqlens = [len(p) for p in c_patches]
    c_max_item_seqlen = max(c_item_seqlens)
    c = self.control_all_x_embedder[embed_key](torch.cat(c_patches, dim=0))
    c[torch.cat(c_inner_pad_mask)] = self.x_pad_token
    c_padded = pad_sequence(list(c.split(c_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
    c_freqs_cis_padded = pad_sequence([self.rope_embedder(pos_ids) for pos_ids in c_pos_ids], batch_first=True, padding_value=0.0)
    seq_lens_tensor = torch.tensor(c_item_seqlens, device=device, dtype=torch.int32)
    arange = torch.arange(c_max_item_seqlen, device=device, dtype=torch.int32)
    c_attn_mask = arange[None, :] < seq_lens_tensor[:, None]
    return {"c": c_padded, "c_item_seqlens": c_item_seqlens, "attn_mask": c_attn_mask, "freqs_cis": c_freqs_cis_padded, "adaln_input": t.type_as(c_padded)}
def _patchify_and_embed_batch_optimized(self, all_image, all_cap_feats, patch_size, f_patch_size):
    """
    Vectorized counterpart of ``_patchify_and_embed`` for uniformly shaped batches.

    Images are stacked and patchified together, and captions are stacked and padded together, instead
    of looping item by item. Every sequence is right-padded to a multiple of ``SEQ_MULTI_OF`` by
    repeating its last row. Padded image positions receive all-zero position ids, and every padded
    position is flagged ``True`` in its mask. If the images or the captions do not all share one shape,
    the call is delegated unchanged to ``self._patchify_and_embed``.

    Args:
        all_image (List[torch.Tensor]): Per-item image latents of shape ``(C, F, H, W)``.
        all_cap_feats (List[torch.Tensor]): Per-item caption features of shape ``(L, D)``.
        patch_size (int): The spatial patch size.
        f_patch_size (int): The frame/temporal patch size.
    Returns:
        Tuple: ``(image_tokens, cap_feats, image_sizes, image_pos_ids, cap_pos_ids, image_pad_masks,
        cap_pad_masks)``, each a per-item list, in the same layout as ``_patchify_and_embed``.
    """
    device = all_image[0].device
    uniform_images = all(img.shape == all_image[0].shape for img in all_image)
    uniform_caps = all(cap.shape == all_cap_feats[0].shape for cap in all_cap_feats)
    if not uniform_images or not uniform_caps:
        return self._patchify_and_embed(all_image, all_cap_feats, patch_size, f_patch_size)

    def _repeat_last(batch, extra):
        # Right-pad dim 1 by repeating the final row ``extra`` times.
        if extra > 0:
            batch = torch.cat([batch, batch[:, -1:, :].repeat(1, extra, 1)], dim=1)
        return batch

    def _padding_mask(total, valid, count):
        # True marks padded positions (index >= valid).
        mask = torch.zeros((count, total), dtype=torch.bool, device=device)
        mask[:, valid:] = True
        return mask

    images = torch.stack(all_image, dim=0)
    caps = torch.stack(all_cap_feats, dim=0)
    n_items, channels, frames, height, width = images.shape

    cap_len = caps.shape[1]
    cap_full = cap_len + (-cap_len) % SEQ_MULTI_OF
    caps = _repeat_last(caps, cap_full - cap_len)
    cap_pos = self._create_coordinate_grid(size=(cap_full, 1, 1), start=(1, 0, 0), device=device).flatten(0, 2)
    cap_pos = cap_pos.unsqueeze(0).repeat(n_items, 1, 1)
    cap_mask = _padding_mask(cap_full, cap_len, n_items)

    pF, pH, pW = f_patch_size, patch_size, patch_size
    f_tok, h_tok, w_tok = frames // pF, height // pH, width // pW
    tokens = images.view(n_items, channels, f_tok, pF, h_tok, pH, w_tok, pW)
    tokens = tokens.permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(n_items, f_tok * h_tok * w_tok, pF * pH * pW * channels)
    img_len = tokens.shape[1]
    img_full = img_len + (-img_len) % SEQ_MULTI_OF
    tokens = _repeat_last(tokens, img_full - img_len)
    img_pos = self._create_coordinate_grid(size=(f_tok, h_tok, w_tok), start=(cap_full + 1, 0, 0), device=device).flatten(0, 2)
    img_pos = img_pos.unsqueeze(0).repeat(n_items, 1, 1)
    if img_full > img_len:
        zero_pos = torch.zeros((n_items, img_full - img_len, 3), dtype=torch.int32, device=device)
        img_pos = torch.cat([img_pos, zero_pos], dim=1)
    img_mask = _padding_mask(img_full, img_len, n_items)

    return (
        list(tokens.unbind(0)),
        list(caps.unbind(0)),
        [(frames, height, width) for _ in range(n_items)],
        list(img_pos.unbind(0)),
        list(cap_pos.unbind(0)),
        list(img_mask.unbind(0)),
        list(cap_mask.unbind(0)),
    )
def forward(
    self,
    x: List[torch.Tensor],
    t,
    cap_feats: List[torch.Tensor],
    patch_size=2,
    f_patch_size=1,
    control_context=None,
    conditioning_scale=1.0,
    refiner_conditioning_scale=1.0,
):
    """
    The main forward pass of the transformer model.

    Stages:
    1. Patchify latents and captions.
    2. Embed the image and caption tokens.
    3. Run the noise refiner on image tokens and the context refiner on caption tokens.
    4. Join both into one ``[image | caption]`` sequence per item.
    5. Run the main layers, optionally with ControlNet hints.
    6. Apply the final layer and unpatchify the image tokens.

    Args:
        x (List[torch.Tensor]):
            A list of latent image tensors.
        t (torch.Tensor):
            A batch of timesteps.
        cap_feats (List[torch.Tensor]):
            A list of caption feature tensors.
        patch_size (int, optional):
            The spatial patch size to use; must be in ``self.all_patch_size``. Defaults to 2.
        f_patch_size (int, optional):
            The frame/temporal patch size to use; must be in ``self.all_f_patch_size``. Defaults to 1.
        control_context (torch.Tensor or List[torch.Tensor], optional):
            The control context. Control is applied only when ``self.use_controlnet`` is set, this is
            not None and ``conditioning_scale > 0``. Defaults to None.
        conditioning_scale (float, optional):
            The scale for applying control hints in the main layers. Defaults to 1.0.
        refiner_conditioning_scale (float, optional):
            The scale passed to the noise refiner as ``context_scale``. When None, it falls back to
            ``conditioning_scale``, or 1.0 if that is falsy. Defaults to 1.0.
    Returns:
        Transformer2DModelOutput: An object containing the final denoised sample.
    """
    is_control_mode = self.use_controlnet and control_context is not None and conditioning_scale > 0
    if refiner_conditioning_scale is None:
        refiner_conditioning_scale = conditioning_scale or 1.0
    assert patch_size in self.all_patch_size
    assert f_patch_size in self.all_f_patch_size
    bsz = len(x)
    device = x[0].device
    t = t * self.t_scale
    t = self.t_embedder(t)

    # 1) Patchify latents and captions (vectorized when every item shares one shape).
    can_optimize_patchify = (
        bsz == len(cap_feats)
        and bsz >= 2
        and all(img.shape == x[0].shape for img in x)
        and all(cap.shape == cap_feats[0].shape for cap in cap_feats)
    )
    patchify = self._patchify_and_embed_batch_optimized if can_optimize_patchify else self._patchify_and_embed
    (x_list, cap_feats_list, x_size, x_pos_ids, cap_pos_ids, x_inner_pad_mask, cap_inner_pad_mask) = patchify(
        x, cap_feats, patch_size, f_patch_size
    )

    # 2) Embed image tokens; padded positions get the learned ``x_pad_token``.
    x_item_seqlens = [len(i) for i in x_list]
    x_max_item_seqlen = max(x_item_seqlens) if x_item_seqlens else 0
    x_embedded = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x_list, dim=0))
    if x_inner_pad_mask and torch.cat(x_inner_pad_mask).any():
        x_embedded[torch.cat(x_inner_pad_mask)] = self.x_pad_token
    x = pad_sequence(list(x_embedded.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
    adaln_input = t.to(device).type_as(x)

    # 3) Caption tokens (repeated CFG captions are embedded once).
    cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens = self._process_cap_feats_with_cfg_cache(
        cap_feats_list, cap_pos_ids, cap_inner_pad_mask
    )

    # 4) RoPE frequencies and attention mask (True inside each item's length) for the image tokens.
    x_freqs_cis_cat = self.rope_embedder(torch.cat(x_pos_ids, dim=0))
    x_freqs_cis = pad_sequence(list(x_freqs_cis_cat.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
    seq_lens_tensor = torch.tensor(x_item_seqlens, device=device, dtype=torch.int32)
    arange = torch.arange(x_max_item_seqlen, device=device, dtype=torch.int32)
    x_attn_mask = arange[None, :] < seq_lens_tensor[:, None]

    # 5) Two-stage control: refine the control tokens first so they can feed hints to the noise refiner.
    refiner_hints = None
    if is_control_mode and self.is_two_stage_control:
        prepared_control = self._prepare_control_inputs(control_context, cap_feats_padded, t, patch_size, f_patch_size, device)
        # NOTE(review): control tokens reuse the image-token mask and RoPE frequencies, which presumes the
        # control grid matches the latent grid — confirm for control inputs of a different resolution.
        kwargs_for_control_refiner = {
            "x": x,
            "attn_mask": x_attn_mask,
            "freqs_cis": x_freqs_cis,
            "adaln_input": adaln_input,
        }
        c_processed = self._apply_transformer_blocks(
            prepared_control["c"],
            self.control_noise_refiner if self.add_control_noise_refiner else self.control_layers,
            checkpoint_ratio=self.checkpoint_ratio,
            **kwargs_for_control_refiner,
        )
        # All but the last slice along dim 0 are refiner hints; the last slice is the refined control context.
        c_slices = torch.unbind(c_processed)
        refiner_hints = c_slices[:-1]
        control_context_processed = c_slices[-1]
        control_context_item_seqlens = prepared_control["c_item_seqlens"]

    # 6) Noise refiner on image tokens, context refiner on caption tokens.
    kwargs_for_refiner = {
        "attn_mask": x_attn_mask,
        "freqs_cis": x_freqs_cis,
        "adaln_input": adaln_input,
        "context_scale": refiner_conditioning_scale,
    }
    if refiner_hints is not None:
        kwargs_for_refiner["hints"] = refiner_hints
    x = self._apply_transformer_blocks(x, self.noise_refiner, checkpoint_ratio=1.0, **kwargs_for_refiner)
    cap_feats = self._apply_transformer_blocks(
        cap_feats_padded,
        self.context_refiner,
        checkpoint_ratio=1.0,
        attn_mask=cap_attn_mask,
        freqs_cis=cap_freqs_cis,
    )

    # 7) Unified sequence per item: [image tokens | caption tokens], right-padded with zeros.
    unified_item_seqlens = [a + b for a, b in zip(x_item_seqlens, cap_item_seqlens)]
    unified_max_item_seqlen = max(unified_item_seqlens) if unified_item_seqlens else 0
    unified = torch.zeros((bsz, unified_max_item_seqlen, x.shape[-1]), dtype=x.dtype, device=device)
    unified_freqs_cis = torch.zeros(
        (bsz, unified_max_item_seqlen, x_freqs_cis.shape[-2], x_freqs_cis.shape[-1]),
        dtype=x_freqs_cis.dtype,
        device=device,
    )
    for i in range(bsz):
        x_len = x_item_seqlens[i]
        cap_len = cap_item_seqlens[i]
        unified[i, :x_len] = x[i, :x_len]
        unified[i, x_len : x_len + cap_len] = cap_feats[i, :cap_len]
        unified_freqs_cis[i, :x_len] = x_freqs_cis[i, :x_len]
        unified_freqs_cis[i, x_len : x_len + cap_len] = cap_freqs_cis[i, :cap_len]
    seq_lens_tensor = torch.tensor(unified_item_seqlens, device=device, dtype=torch.int32)
    arange = torch.arange(unified_max_item_seqlen, device=device, dtype=torch.int32)
    unified_attn_mask = arange[None, :] < seq_lens_tensor[:, None]

    # 8) Control hints for the main layers, computed from [control tokens | refined caption tokens].
    hints = None
    if is_control_mode:
        kwargs_for_hints = {
            "x": unified,
            "attn_mask": unified_attn_mask,
            "freqs_cis": unified_freqs_cis,
            "adaln_input": adaln_input,
        }
        if self.is_two_stage_control:
            control_tokens = [control_context_processed[i][: control_context_item_seqlens[i]] for i in range(bsz)]
        else:
            prepared_control = self._prepare_control_inputs(control_context, cap_feats_padded, t, patch_size, f_patch_size, device)
            c = self._apply_transformer_blocks(
                prepared_control["c"],
                self.control_noise_refiner,
                checkpoint_ratio=self.checkpoint_ratio,
                attn_mask=prepared_control["attn_mask"],
                freqs_cis=prepared_control["freqs_cis"],
                adaln_input=prepared_control["adaln_input"],
            )
            c_item_seqlens = prepared_control["c_item_seqlens"]
            control_tokens = [c[i, : c_item_seqlens[i]] for i in range(bsz)]
        control_context_unified_list = [
            torch.cat([control_tokens[i], cap_feats[i, : cap_item_seqlens[i]]], dim=0) for i in range(bsz)
        ]
        c_unified = pad_sequence(control_context_unified_list, batch_first=True, padding_value=0.0)
        c_processed = self._apply_transformer_blocks(c_unified, self.control_layers, checkpoint_ratio=self.checkpoint_ratio, **kwargs_for_hints)
        hints = torch.unbind(c_processed)[:-1]

    # 9) Main layers, final projection, and unpatchify of the leading image tokens.
    kwargs_for_layers = {"attn_mask": unified_attn_mask, "freqs_cis": unified_freqs_cis, "adaln_input": adaln_input}
    if hints is not None:
        kwargs_for_layers["hints"] = hints
        kwargs_for_layers["context_scale"] = conditioning_scale
    unified = self._apply_transformer_blocks(unified, self.layers, checkpoint_ratio=self.checkpoint_ratio, **kwargs_for_layers)
    unified_out = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
    x_image_tokens = unified_out[:, :x_max_item_seqlen]
    x_final_tensor = self._unpatchify(x_image_tokens, x_size, patch_size, f_patch_size)
    return Transformer2DModelOutput(sample=x_final_tensor)