neuralvfx committed · Commit 7f0b483 · 0 parents

Initial commit with large files tracked by LFS
.gitattributes ADDED
*.safetensors filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
__pycache__/
*.pyc
README.md ADDED
---
license: apache-2.0
datasets:
- opendiffusionai/laion2b-squareish-1536px
base_model:
- Tongyi-MAI/Z-Image
tags:
- z-image
- controlnet
thumbnail: https://huggingface.co/neuralvfx/Z-Image-SAM-ControlNet/resolve/main/assets/stacked_vertical.png
---

# Z-Image-SAM-ControlNet
![side by side](assets/side_by_side_d.png)
## Fun Facts
- This ControlNet is trained exclusively on segmentation maps generated by [Segment Anything (SAM)](https://aidemos.meta.com/segment-anything/)
- Base model used was [Tongyi-MAI/Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image)
- Trained at 1024x1024 resolution
- Trained on 220K segmented images from [laion2b-squareish-1536px](https://huggingface.co/datasets/opendiffusionai/laion2b-squareish-1536px)
- Trained using this repo: [https://github.com/aigc-apps/VideoX-Fun](https://github.com/aigc-apps/VideoX-Fun)
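
At inference time you supply a SAM-style segmentation map as the control image. A minimal sketch of producing one, assuming the `segment-anything` package and a locally downloaded `sam_vit_h_4b8939.pth` checkpoint (neither ships with this repo); the flat random-color rendering of the masks is an assumption, not the documented training palette:

```python
# Illustrative only: render a SAM segmentation of an input photo as a
# flat-color control image.
import numpy as np
import cv2
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
mask_generator = SamAutomaticMaskGenerator(sam)

image = cv2.cvtColor(cv2.imread("input.png"), cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)  # dicts with a boolean "segmentation" mask each

# Paint the largest regions first so smaller ones stay visible on top.
seg = np.zeros_like(image)
for m in sorted(masks, key=lambda m: m["area"], reverse=True):
    seg[m["segmentation"]] = np.random.randint(0, 256, 3)
cv2.imwrite("control.png", cv2.cvtColor(seg, cv2.COLOR_RGB2BGR))
```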

# Showcases
<table style="width:100%; table-layout:fixed;">
  <tr>
    <td><img src="./assets/resized_kitten_seg.png" ></td>
    <td><img src="./assets/resized_kitten.png" ></td>
  </tr>
  <tr>
    <td><img src="./assets/resized_dread_girl_seg.png" ></td>
    <td><img src="./assets/resized_dread_girl.png" ></td>
  </tr>
  <tr>
    <td><img src="./assets/resized_house_seg.png" ></td>
    <td><img src="./assets/resized_house.png" ></td>
  </tr>
</table>

# ComfyUI Usage
1) Copy the weights from `comfy-ui-patch/z-image-sam-controlnet.safetensors` to `ComfyUI/models/model_patches`
2) Use `ModelPatchLoader` to load the patch
3) Plug `MODEL_PATCH` into `model_patch` on `ZImageFunControlnet`
4) Plug the model, VAE, and image into `ZImageFunControlnet`
5) Plug the output of `ZImageFunControlnet` into the KSampler
![videoXFun Nodes](assets/comfyui.png)

# Hugging Face Usage

## Compatibility
This pipeline targets diffusers 0.37.0:
```bash
pip install -U diffusers==0.37.0
```

## Download
```bash
sudo apt-get install git-lfs
git lfs install

git clone https://huggingface.co/neuralvfx/Z-Image-SAM-ControlNet

cd Z-Image-SAM-ControlNet
```

## Inference
```python
import torch
from diffusers.utils import load_image
from diffusers_local.pipeline_z_image_control_unified import ZImageControlUnifiedPipeline
from diffusers_local.z_image_control_transformer_2d import ZImageControlTransformer2DModel

transformer = ZImageControlTransformer2DModel.from_pretrained(
    ".",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    add_control_noise_refiner=True,
)

pipe = ZImageControlUnifiedPipeline.from_pretrained(
    "Tongyi-MAI/Z-Image",
    torch_dtype=torch.bfloat16,
    transformer=transformer,
)

pipe.enable_model_cpu_offload()

image = pipe(
    prompt="some beach wood washed up on the sunny sand, spelling the words z-image, with footprints and waves crashing",
    # Negative prompt in Chinese, roughly: "low resolution, low quality, deformed
    # limbs, deformed fingers, oversaturated, waxy look, featureless faces, overly
    # smooth, AI-generated feel. Chaotic composition. Blurry, distorted text."
    negative_prompt="低分辨率,低画质,肢体畸形,手指畸形,画面过饱和,蜡像感,人脸无细节,过度光滑,画面具有AI感。构图混乱。文字模糊,扭曲。",
    control_image=load_image("assets/z-image.png"),
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=4.0,
    controlnet_conditioning_scale=1.0,
    generator=torch.Generator("cuda").manual_seed(22),
).images[0]

image.save("output.png")
```
assets/comfyui.png ADDED

Git LFS Details

  • SHA256: 5cdf98017e5c0cd6e880adefce747f2832f91fe7d0d1af1f358c312cf63bda09
  • Pointer size: 130 Bytes
  • Size of remote file: 69.7 kB
assets/gallery.png ADDED

Git LFS Details

  • SHA256: 249d295ade52970079957ac6d7bf43d5acaf79475b2fdf0d60a2e2e1e562f61b
  • Pointer size: 131 Bytes
  • Size of remote file: 658 kB
assets/girl_icon.png ADDED

Git LFS Details

  • SHA256: 1068f895476ee2375316db163ae9d12d42fb1996d38fb19bf7e2ae8d2e6e98dc
  • Pointer size: 131 Bytes
  • Size of remote file: 460 kB
assets/resized_dread_girl.png ADDED

Git LFS Details

  • SHA256: 4c0fa84baad4f6c2e264053d41aaed3cf4696cd7250a24a7def82dd1ba9af13b
  • Pointer size: 132 Bytes
  • Size of remote file: 2.11 MB
assets/resized_dread_girl_seg.png ADDED

Git LFS Details

  • SHA256: af9c472d3bb2050f0ba2c4e5d849d67a1a4e148c9746e3dad0e72752ca11a02f
  • Pointer size: 131 Bytes
  • Size of remote file: 275 kB
assets/resized_house.png ADDED

Git LFS Details

  • SHA256: 95a999066e9c8264a1b4e2a8fcde497f8f1d1f594abec026d36765505212e0f7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.28 MB
assets/resized_house_seg.png ADDED

Git LFS Details

  • SHA256: f841f654fe557240cc8159a7540c2d7caaf4aad8d73d21bdc9b3ad740e055c42
  • Pointer size: 131 Bytes
  • Size of remote file: 217 kB
assets/resized_kitten.png ADDED

Git LFS Details

  • SHA256: 0587c5229256f6be0282bce9a3839c2bfbfd55bbf4ee3b6faf62d0759a390bc6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.55 MB
assets/resized_kitten_seg.png ADDED

Git LFS Details

  • SHA256: 2cbcc3608f07af9cd2f76d4ed783b42820a28f289f24c888ebd451d93cb20de3
  • Pointer size: 131 Bytes
  • Size of remote file: 186 kB
assets/side_by_side_d.png ADDED

Git LFS Details

  • SHA256: d3b1a1f0c3d95bfa107a92a32e6cbcd1c33b6ead024b6c2703c985c558bdbf06
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB
assets/stacked_vertical.png ADDED

Git LFS Details

  • SHA256: 834b5452dc8418327c3be3ee32ef3dd93592c9d2d3f0a39f6edf199694da3753
  • Pointer size: 131 Bytes
  • Size of remote file: 863 kB
assets/z-image.png ADDED

Git LFS Details

  • SHA256: 27557ac106417f123b6707a13820aa09c3f0fe664c6fac03d4f354c314d9368d
  • Pointer size: 130 Bytes
  • Size of remote file: 33.7 kB
comfy-ui-patch/z-image-sam-controlnet.safetensors ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:d64755decb4b48ee265e271d9a65b2e5fba0d06bca79a7d382dfc7d7829ee15a
size 6712485600
config.json ADDED
{
  "_class_name": "ZImageControlTransformer2DModel",
  "_diffusers_version": "0.37.0",
  "add_control_noise_refiner": true,
  "add_control_noise_refiner_correctly": true,
  "all_f_patch_size": [
    1
  ],
  "all_patch_size": [
    2
  ],
  "axes_dims": [
    32,
    48,
    48
  ],
  "axes_lens": [
    1536,
    512,
    512
  ],
  "cap_feat_dim": 2560,
  "control_in_dim": 33,
  "control_layers_places": [
    0,
    2,
    4,
    6,
    8,
    10,
    12,
    14,
    16,
    18,
    20,
    22,
    24,
    26,
    28
  ],
  "control_refiner_layers_places": [
    0,
    1
  ],
  "dim": 3840,
  "in_channels": 16,
  "n_heads": 30,
  "n_kv_heads": 30,
  "n_layers": 30,
  "n_refiner_layers": 2,
  "norm_eps": 1e-05,
  "qk_norm": true,
  "rope_theta": 256.0,
  "siglip_feat_dim": null,
  "t_scale": 1000.0
}
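
A detail worth noting in the config above: `control_layers_places` attaches a control block to every other layer (0, 2, ..., 28) of the 30-layer backbone. A quick sanity check of that layout (illustrative, not part of the repo):

```python
# Verify that control blocks cover every even-indexed backbone layer.
import json

with open("config.json") as f:
    cfg = json.load(f)

assert cfg["control_layers_places"] == list(range(0, cfg["n_layers"], 2))
print(f"{len(cfg['control_layers_places'])} control layers over {cfg['n_layers']} backbone layers")
```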
diffusers_local/patch.py ADDED
import importlib
import os
from typing import Optional, Set

import diffusers.loaders.single_file_model as single_file_model
import diffusers.pipelines.pipeline_loading_utils as pipe_loading_utils
import torch
from diffusers.loaders.single_file_utils import (
    convert_animatediff_checkpoint_to_diffusers,
    convert_auraflow_transformer_checkpoint_to_diffusers,
    convert_autoencoder_dc_checkpoint_to_diffusers,
    convert_chroma_transformer_checkpoint_to_diffusers,
    convert_controlnet_checkpoint,
    convert_cosmos_transformer_checkpoint_to_diffusers,
    convert_flux2_transformer_checkpoint_to_diffusers,
    convert_flux_transformer_checkpoint_to_diffusers,
    convert_hidream_transformer_to_diffusers,
    convert_hunyuan_video_transformer_to_diffusers,
    convert_ldm_unet_checkpoint,
    convert_ldm_vae_checkpoint,
    convert_ltx_transformer_checkpoint_to_diffusers,
    convert_ltx_vae_checkpoint_to_diffusers,
    convert_lumina2_to_diffusers,
    convert_mochi_transformer_checkpoint_to_diffusers,
    convert_sana_transformer_to_diffusers,
    convert_sd3_transformer_checkpoint_to_diffusers,
    convert_stable_cascade_unet_single_file_to_diffusers,
    convert_wan_transformer_to_diffusers,
    convert_wan_vae_to_diffusers,
    convert_z_image_transformer_checkpoint_to_diffusers,
    create_controlnet_diffusers_config_from_ldm,
    create_unet_diffusers_config_from_ldm,
    create_vae_diffusers_config_from_ldm,
)
from diffusers.pipelines.pipeline_loading_utils import _unwrap_model
from diffusers.utils import (
    _maybe_remap_transformers_class,
    get_class_from_dynamic_module,
)


try:
    from diffusers.hooks.group_offloading import (
        _GROUP_ID_LAZY_LEAF,
        GroupOffloadingConfig,
        ModelHook,
        ModuleGroup,
        _apply_group_offloading_hook,
        _apply_lazy_group_offloading_hook,
        _find_parent_module_in_module_dict,
        _gather_buffers_with_no_group_offloading_parent,
        _gather_parameters_with_no_group_offloading_parent,
        send_to_device,
    )

except ImportError:
    ModelHook = object
    ModuleGroup = object
    GroupOffloadingConfig = object

    def _apply_group_offloading_hook(*args, **kwargs):
        pass


_MY_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d,
    torch.nn.Conv2d,
    torch.nn.Conv3d,
    torch.nn.ConvTranspose1d,
    torch.nn.ConvTranspose2d,
    torch.nn.ConvTranspose3d,
    torch.nn.Linear,
    torch.nn.Sequential,  # The layer type we want to add
)


class GroupOffloadingHook(ModelHook):
    r"""
    A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads to accelerator device for
    computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload leader"
    module that is responsible for offloading. If prefetching is enabled, the onload leader of the previous module
    group is responsible for onloading the current module group.
    """

    _is_stateful = False

    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
        self.group = group
        self.next_group: Optional[ModuleGroup] = None
        self.config = config

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        if self.group.offload_leader == module:
            self.group.offload_()
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        # If there wasn't an onload_leader assigned, we assume that the submodule that first called its forward
        # method is the onload_leader of the group.
        if self.group.onload_leader is None:
            self.group.onload_leader = module

        if self.group.onload_leader == module:
            # STEP 1: GUARANTEE THE CURRENT GROUP'S STATE
            # This section ensures that the parameters for the *current* module are on the correct device
            # before its forward pass is executed.

            # This block handles modules that are part of the prefetching chain (`onload_self` is False).
            # The original design relied on the previous module to initiate the onload, which proved fragile.
            # Our robust fix makes each module responsible for itself:
            # 1. `self.group.onload_()`: Guarantees the data transfer is initiated, acting as a backup if the
            #    previous module in the chain failed to do so.
            # 2. `self.group.stream.synchronize()`: This is the critical synchronization barrier. It forces the
            #    CPU to wait until the asynchronous transfer to the GPU is complete, preventing device mismatch errors.
            if not self.group.onload_self and self.group.stream is not None:
                self.group.onload_()
                self.group.stream.synchronize()

            # This block handles the first module in an execution chain (`onload_self` is True).
            # It is responsible for loading itself onto the device.
            if self.group.onload_self:
                self.group.onload_()
                # If streams are used, the onload() call above is asynchronous. We MUST synchronize here
                # to ensure the module is ready before its computation starts.
                if self.group.stream is not None:
                    self.group.stream.synchronize()

            # At this point, we are 100% certain that the current group's parameters are on the onload_device.

            # STEP 2: INITIATE PREFETCHING FOR THE NEXT GROUP
            # With the current group secured, we can now look ahead and start the asynchronous data transfer
            # for the next module in the execution chain. This allows the data transfer to overlap with the
            # computation of the current module's forward pass, which is the core benefit of prefetching.
            should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
            if should_onload_next_group:
                self.next_group.onload_()

        # The rest of the function handles moving positional (*args) and keyword (**kwargs)
        # arguments to the correct device.
        args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)

        exclude_kwargs = self.config.exclude_kwargs or []
        if exclude_kwargs:
            moved_kwargs = send_to_device(
                {k: v for k, v in kwargs.items() if k not in exclude_kwargs},
                self.group.onload_device,
                non_blocking=self.group.non_blocking,
            )
            kwargs.update(moved_kwargs)
        else:
            kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)

        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
        if self.group.offload_leader == module:
            self.group.offload_()
        return output


def _apply_group_offloading_leaf_level_patched(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
    """
    Fixed version of _apply_group_offloading_leaf_level that supports nn.Sequential.
    """
    modules_with_group_offloading: Set[str] = set()
    for name, submodule in module.named_modules():
        if not isinstance(submodule, _MY_GO_LC_SUPPORTED_PYTORCH_LAYERS):
            continue

        group = ModuleGroup(
            modules=[submodule],
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_to_disk_path=config.offload_to_disk_path,
            offload_leader=submodule,
            onload_leader=submodule,
            non_blocking=config.non_blocking,
            stream=config.stream,
            record_stream=config.record_stream,
            low_cpu_mem_usage=config.low_cpu_mem_usage,
            onload_self=True,
            group_id=name,
        )
        _apply_group_offloading_hook(submodule, group, config=config)
        modules_with_group_offloading.add(name)

    # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
    # of the module is called
    module_dict = dict(module.named_modules())
    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)

    # Find the closest module parent for each parameter and buffer, and attach group hooks
    parent_to_parameters = {}
    for name, param in parameters:
        parent_name = _find_parent_module_in_module_dict(name, module_dict)
        if parent_name in parent_to_parameters:
            parent_to_parameters[parent_name].append(param)
        else:
            parent_to_parameters[parent_name] = [param]

    parent_to_buffers = {}
    for name, buffer in buffers:
        parent_name = _find_parent_module_in_module_dict(name, module_dict)
        if parent_name in parent_to_buffers:
            parent_to_buffers[parent_name].append(buffer)
        else:
            parent_to_buffers[parent_name] = [buffer]

    parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys())
    for name in parent_names:
        parameters = parent_to_parameters.get(name, [])
        buffers = parent_to_buffers.get(name, [])
        parent_module = module_dict[name]
        group = ModuleGroup(
            modules=[],
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_leader=parent_module,
            onload_leader=parent_module,
            offload_to_disk_path=config.offload_to_disk_path,
            parameters=parameters,
            buffers=buffers,
            non_blocking=config.non_blocking,
            stream=config.stream,
            record_stream=config.record_stream,
            low_cpu_mem_usage=config.low_cpu_mem_usage,
            onload_self=True,
            group_id=name,
        )
        _apply_group_offloading_hook(parent_module, group, config=config)

    if config.stream is not None:
        # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
        # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
        # execution order and apply prefetching in the correct order.
        unmatched_group = ModuleGroup(
            modules=[],
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_to_disk_path=config.offload_to_disk_path,
            offload_leader=module,
            onload_leader=module,
            parameters=None,
            buffers=None,
            non_blocking=False,
            stream=None,
            record_stream=False,
            low_cpu_mem_usage=config.low_cpu_mem_usage,
            onload_self=True,
            group_id=_GROUP_ID_LAZY_LEAF,
        )
        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


try:
    import diffusers.hooks.group_offloading as group_offloading_module

    setattr(group_offloading_module, "_apply_group_offloading_leaf_level", _apply_group_offloading_leaf_level_patched)
    setattr(group_offloading_module, "GroupOffloadingHook", GroupOffloadingHook)
except ImportError as e:
    print(f"-> ERROR: Could not import the `diffusers.hooks.group_offloading` module to apply the patch: {e}")

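
# Editor's note (illustrative, not part of this file): once this module is
# imported, the monkeypatch above takes effect, and leaf-level group offloading
# can be enabled through the standard diffusers API, e.g.:
#
#     import diffusers_local.patch  # applies the patch on import
#     transformer.enable_group_offload(
#         onload_device=torch.device("cuda"),
#         offload_device=torch.device("cpu"),
#         offload_type="leaf_level",
#         use_stream=True,  # overlap host-to-device transfers with compute
#     )
#
# `transformer` here stands for the ZImageControlTransformer2DModel loaded in
# the README example, and the `enable_group_offload` call assumes a recent
# diffusers release that provides it.
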
def convert_z_image_control_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    Z_IMAGE_KEYS_RENAME_DICT = {
        "final_layer.": "all_final_layer.2-1.",
        "x_embedder.": "all_x_embedder.2-1.",
        ".attention.out.bias": ".attention.to_out.0.bias",
        ".attention.k_norm.weight": ".attention.norm_k.weight",
        ".attention.q_norm.weight": ".attention.norm_q.weight",
        ".attention.out.weight": ".attention.to_out.0.weight",
        "control_x_embedder.": "control_all_x_embedder.2-1.",
    }

    def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None:
        if ".attention.qkv.weight" not in key:
            return

        fused_qkv_weight = state_dict.pop(key)
        to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
        new_q_name = key.replace(".attention.qkv.weight", ".attention.to_q.weight")
        new_k_name = key.replace(".attention.qkv.weight", ".attention.to_k.weight")
        new_v_name = key.replace(".attention.qkv.weight", ".attention.to_v.weight")

        state_dict[new_q_name] = to_q_weight
        state_dict[new_k_name] = to_k_weight
        state_dict[new_v_name] = to_v_weight
        return

    TRANSFORMER_SPECIAL_KEYS_REMAP = {
        ".attention.qkv.weight": convert_z_image_fused_attention,
    }

    def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
        state_dict[new_key] = state_dict.pop(old_key)

    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}

    # Handle single file --> diffusers key remapping via the remap dict
    for key in list(converted_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in Z_IMAGE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)

        update_state_dict(converted_state_dict, key, new_key)

    # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
    # special_keys_remap
    for key in list(converted_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, converted_state_dict)

    return converted_state_dict

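
# Editor's note (illustrative, not part of this file): the fused-QKV split above
# turns one [3*dim, dim] projection into three [dim, dim] projections:
#
#     fused = torch.randn(3 * 3840, 3840)     # qkv.weight as stored on disk
#     q, k, v = torch.chunk(fused, 3, dim=0)  # each [3840, 3840]
#
# which matches dim=3840 from config.json.
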
SINGLE_FILE_LOADABLE_CLASSES = {
    "StableCascadeUNet": {
        "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
    },
    "UNet2DConditionModel": {
        "checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
        "config_mapping_fn": create_unet_diffusers_config_from_ldm,
        "default_subfolder": "unet",
        "legacy_kwargs": {
            "num_in_channels": "in_channels",  # Legacy kwargs supported by `from_single_file` mapped to new args
        },
    },
    "AutoencoderKL": {
        "checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
        "config_mapping_fn": create_vae_diffusers_config_from_ldm,
        "default_subfolder": "vae",
    },
    "ControlNetModel": {
        "checkpoint_mapping_fn": convert_controlnet_checkpoint,
        "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
    },
    "SD3Transformer2DModel": {
        "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "MotionAdapter": {
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
    },
    "SparseControlNetModel": {
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
    },
    "FluxTransformer2DModel": {
        "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ChromaTransformer2DModel": {
        "checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "LTXVideoTransformer3DModel": {
        "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AutoencoderKLLTXVideo": {
        "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
        "default_subfolder": "vae",
    },
    "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
    "MochiTransformer3DModel": {
        "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "HunyuanVideoTransformer3DModel": {
        "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AuraFlowTransformer2DModel": {
        "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "Lumina2Transformer2DModel": {
        "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
        "default_subfolder": "transformer",
    },
    "SanaTransformer2DModel": {
        "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "WanTransformer3DModel": {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "WanVACETransformer3DModel": {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AutoencoderKLWan": {
        "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
        "default_subfolder": "vae",
    },
    "HiDreamImageTransformer2DModel": {
        "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "CosmosTransformer3DModel": {
        "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "QwenImageTransformer2DModel": {
        "checkpoint_mapping_fn": lambda x: x,
        "default_subfolder": "transformer",
    },
    "Flux2Transformer2DModel": {
        "checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ZImageTransformer2DModel": {
        "checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ZImageControlTransformer2DModel": {
        "checkpoint_mapping_fn": convert_z_image_control_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
}


def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None):
    """Simple helper method to retrieve class object of module as well as potential parent class objects"""
    component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None

    if is_pipeline_module:
        pipeline_module = getattr(pipelines, library_name)

        class_obj = getattr(pipeline_module, class_name)
        class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
    elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
        # load custom component
        class_obj = get_class_from_dynamic_module(component_folder, module_file=library_name + ".py", class_name=class_name)
        class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
    else:
        # else we just import it from the library.
        library = importlib.import_module(library_name)

        # Handle deprecated Transformers classes
        if library_name == "transformers":
            class_name = _maybe_remap_transformers_class(class_name) or class_name

        try:
            class_obj = getattr(library, class_name)
        except Exception:
            module = importlib.import_module("diffusers_local")
            class_obj = getattr(module, class_name)
        class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

    return class_obj, class_candidates


def _get_single_file_loadable_mapping_class(cls):
    diffusers_module = importlib.import_module("diffusers")
    class_name_str = cls.__name__
    for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
        try:
            loadable_class = getattr(diffusers_module, loadable_class_str)
        except Exception:
            module = importlib.import_module("diffusers_local")
            loadable_class = getattr(module, loadable_class_str)
        if issubclass(cls, loadable_class):
            return loadable_class_str

    return class_name_str


def maybe_raise_or_warn(library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module):
    """Simple helper method to raise or warn in case incorrect module has been passed"""
    if not is_pipeline_module:
        library = importlib.import_module(library_name)

        # Handle deprecated Transformers classes
        if library_name == "transformers":
            class_name = _maybe_remap_transformers_class(class_name) or class_name

        try:
            class_obj = getattr(library, class_name)
        except Exception:
            module = importlib.import_module("diffusers_local")
            class_obj = getattr(module, class_name)

        class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

        expected_class_obj = None
        for class_name, class_candidate in class_candidates.items():
            if class_candidate is not None and issubclass(class_obj, class_candidate):
                expected_class_obj = class_candidate

        # Dynamo wraps the original model in a private class.
        # I didn't find a public API to get the original class.
        sub_model = passed_class_obj[name]
        unwrapped_sub_model = _unwrap_model(sub_model)
        model_cls = unwrapped_sub_model.__class__

        if not issubclass(model_cls, expected_class_obj):
            raise ValueError(f"{passed_class_obj[name]} is of type: {model_cls}, but should be {expected_class_obj}")
    else:
        print(f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it has the correct type")


pipe_loading_utils.get_class_obj_and_candidates = get_class_obj_and_candidates
pipe_loading_utils.maybe_raise_or_warn = maybe_raise_or_warn
single_file_model.SINGLE_FILE_LOADABLE_CLASSES = SINGLE_FILE_LOADABLE_CLASSES
single_file_model._get_single_file_loadable_mapping_class = _get_single_file_loadable_mapping_class
diffusers_local/pipeline_z_image_control_unified.py ADDED
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
# Refactored and optimized by DEVAIEXP Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import inspect
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from diffusers import AutoencoderKL, DiffusionPipeline, FlowMatchEulerDiscreteScheduler
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from PIL import Image, ImageFilter
from transformers import AutoTokenizer, PreTrainedModel

from diffusers_local.z_image_control_transformer_2d import ZImageControlTransformer2DModel


logger = logging.get_logger(__name__)


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """
    Calculates the shift value `mu` for the scheduler based on the image sequence length.

    This function implements a linear interpolation to determine the shift value based on the input
    image's sequence length, scaling between a base and a maximum shift value.

    Args:
        image_seq_len (`int`):
            The sequence length of the image latents (height * width).
        base_seq_len (`int`, *optional*, defaults to 256):
            The base sequence length for the shift calculation.
        max_seq_len (`int`, *optional*, defaults to 4096):
            The maximum sequence length for the shift calculation.
        base_shift (`float`, *optional*, defaults to 0.5):
            The shift value corresponding to `base_seq_len`.
        max_shift (`float`, *optional*, defaults to 1.15):
            The shift value corresponding to `max_seq_len`.

    Returns:
        `float`: The calculated shift value `mu`.
    """
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu

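
# Editor's note (illustrative, not part of this file): for the README's
# 1024x1024 example, assuming image_seq_len counts 2x2-patchified latent tokens
# (VAE factor 8, patch size 2), image_seq_len = (1024 // 16) ** 2 = 4096 =
# max_seq_len, so calculate_shift returns mu = max_shift = 1.15.
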
def retrieve_latents(encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"):
    """
    Retrieves latents from a VAE encoder output.

    Args:
        encoder_output (`torch.Tensor`):
            The output of a VAE encoder.
        generator (`torch.Generator`, *optional*):
            A random number generator for sampling from the latent distribution.
        sample_mode (`str`, *optional*, defaults to "sample"):
            The method to retrieve latents. Can be "sample" to sample from the distribution or
            "argmax" to take the mode.

    Returns:
        `torch.Tensor`: The retrieved latents.
    """
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigma schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

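
# Editor's note (illustrative, not part of this file): retrieve_timesteps can be
# driven by a custom sigma schedule instead of a step count, e.g.:
#
#     from diffusers import FlowMatchEulerDiscreteScheduler
#     sched = FlowMatchEulerDiscreteScheduler()
#     sigmas = [1.0, 0.8, 0.6, 0.4, 0.2]  # a hypothetical 5-step schedule
#     ts, n = retrieve_timesteps(sched, device="cpu", sigmas=sigmas)
#     # n == 5 and ts holds the scheduler's matching timesteps
#
# assuming the installed scheduler's `set_timesteps` accepts `sigmas` (the
# function checks this via `inspect.signature` and raises otherwise).
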
class ZImageControlUnifiedPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
    model_cpu_offload_seq = "text_encoder->vae->transformer"
    _optional_components = []
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: PreTrainedModel,
        tokenizer: AutoTokenizer,
        transformer: ZImageControlTransformer2DModel,
    ):
        """
        Initializes the ZImageControlUnifiedPipeline.

        Args:
            scheduler (`FlowMatchEulerDiscreteScheduler`):
                A scheduler to be used in combination with `transformer` to denoise the latents.
            vae (`AutoencoderKL`):
                Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
            text_encoder (`PreTrainedModel`):
                A pretrained text encoder model.
            tokenizer (`AutoTokenizer`):
                A tokenizer to prepare text prompts for the `text_encoder`.
            transformer (`ZImageControlTransformer2DModel`):
                The main transformer model for the diffusion process.
        """
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
            transformer=transformer,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        self.mask_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True)

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
    ):
        """
        Encodes the prompt into text embeddings.

        Args:
            prompt (`Union[str, List[str]]`):
                The prompt or prompts to guide the image generation.
            device (`Optional[torch.device]`):
                The device to move the embeddings to.
            num_images_per_prompt (`int`):
                The number of images to generate per prompt.
            do_classifier_free_guidance (`bool`):
                Whether to generate embeddings for classifier-free guidance.
            negative_prompt (`Optional[Union[str, List[str]]]`):
                The negative prompt or prompts.
            prompt_embeds (`Optional[List[torch.FloatTensor]]`):
                Pre-generated positive prompt embeddings.
            negative_prompt_embeds (`Optional[torch.FloatTensor]`):
                Pre-generated negative prompt embeddings.
            max_sequence_length (`int`):
                The maximum sequence length for tokenization.

        Returns:
            `Tuple[List[torch.Tensor], List[torch.Tensor]]`: A tuple containing the positive and negative prompt embeddings.
        """
        device = device or self._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_embeds = self._encode_prompt(
                prompt=prompt,
                device=device,
                max_sequence_length=max_sequence_length,
            )
        if num_images_per_prompt > 1:
            prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]

        if do_classifier_free_guidance:
            if negative_prompt_embeds is None:
                if negative_prompt is None:
                    negative_prompt = [""] * len(prompt)
                else:
                    negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
                    assert len(prompt) == len(negative_prompt)
                negative_prompt_embeds = self._encode_prompt(
                    prompt=negative_prompt,
                    device=device,
                    max_sequence_length=max_sequence_length,
                )

            if num_images_per_prompt > 1:
                negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]

        return prompt_embeds, negative_prompt_embeds

    def _encode_prompt(self, prompt: Union[str, List[str]], device: torch.device, max_sequence_length: int) -> List[torch.Tensor]:
        """
        Internal helper to encode a list of prompts into embeddings, applying chat templates if available.

        Args:
            prompt (`Union[str, List[str]]`):
                A list of strings to be encoded.
            device (`torch.device`):
                The target device for the embeddings.
            max_sequence_length (`int`):
                The maximum length for tokenization.

        Returns:
            `List[torch.Tensor]`: A list of embedding tensors, one for each prompt.
        """
        formatted_prompts = []
        for p in prompt:
            messages = [{"role": "user", "content": p}]
            if hasattr(self.tokenizer, "apply_chat_template"):
                formatted_prompts.append(self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=True))
            else:
                formatted_prompts.append(p)

        text_inputs = self.tokenizer(
            formatted_prompts,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)

        prompt_masks = text_inputs.attention_mask.bool()

        with torch.no_grad():
            prompt_embeds_batch = self.text_encoder(input_ids=text_inputs.input_ids, attention_mask=prompt_masks, output_hidden_states=True).hidden_states[-2]

        embeddings_list = []
        for i in range(prompt_embeds_batch.shape[0]):
            embeddings_list.append(prompt_embeds_batch[i][prompt_masks[i]])

        return embeddings_list

    def get_timesteps(self, num_inference_steps, strength, device):
        """
        Calculates the timesteps for the scheduler based on the number of inference steps and strength.
        This is primarily used for image-to-image pipelines.

        Args:
            num_inference_steps (`int`): The total number of diffusion steps.
            strength (`float`): The strength of the denoising process. A value of 1.0 means full denoising.
            device (`torch.device`): The device to place the timesteps on.

        Returns:
            `Tuple[torch.Tensor, int]`: A tuple containing the timesteps and the number of steps to run.
        """
        init_timestep = min(num_inference_steps * strength, num_inference_steps)

        t_start = int(max(num_inference_steps - init_timestep, 0))
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(t_start * self.scheduler.order)

        return timesteps, num_inference_steps - t_start

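    # Editor's note (illustrative, not part of this file): with
    # num_inference_steps=50 and strength=0.6, init_timestep = min(30, 50) = 30
    # and t_start = 50 - 30 = 20, so the img2img run keeps the last 30 of the 50
    # scheduled timesteps (assuming scheduler.order == 1).
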
331
+ def prepare_latents(
332
+ self,
333
+ batch_size: int,
334
+ num_channels_latents: int,
335
+ height: int,
336
+ width: int,
337
+ dtype: torch.dtype,
338
+ device: torch.device,
339
+ generator: torch.Generator,
340
+ image: Optional[PipelineImageInput] = None,
341
+ timestep: Optional[torch.Tensor] = None,
342
+ latents: Optional[torch.Tensor] = None,
343
+ ):
344
+ """
345
+ Prepares the initial latents for the diffusion process.
346
+
347
+ This function handles three cases:
348
+ 1. `latents` are provided: They are returned directly.
349
+ 2. `image` is None (Text-to-Image): Random noise is generated.
350
+ 3. `image` is provided (Image-to-Image): The image is encoded, and noise is added according to the timestep.
351
+
352
+ Args:
353
+ batch_size (`int`): The number of latents to generate.
354
+ num_channels_latents (`int`): The number of channels in the latents.
355
+ height (`int`): The height of the output image in pixels.
356
+ width (`int`): The width of the output image in pixels.
357
+ dtype (`torch.dtype`): The data type for the latents.
358
+ device (`torch.device`): The device to create the latents on.
359
+ generator (`torch.Generator`): A random generator for creating the initial noise.
360
+ image (`Optional[PipelineImageInput]`): An initial image for img2img mode.
361
+ timestep (`Optional[torch.Tensor]`): The starting timestep for adding noise in img2img mode.
362
+ latents (`Optional[torch.Tensor]`): Pre-generated latents.
363
+
364
+ Returns:
365
+ `torch.Tensor`: The prepared latents.
366
+ """
367
+ latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
368
+ latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
369
+ shape = (batch_size, num_channels_latents, latent_height, latent_width)
370
+
371
+ if latents is not None:
372
+ return latents.to(device=device, dtype=dtype)
373
+
374
+ if image is None:
375
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
376
+ return latents
377
+
378
+ image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(device=device, dtype=self.vae.dtype)
379
+ with torch.no_grad():
380
+ if image_tensor.shape[1] != num_channels_latents:
381
+ if isinstance(generator, list):
382
+ image_latents = [retrieve_latents(self.vae.encode(image_tensor[i : i + 1]), generator=generator[i]) for i in range(image_tensor.shape[0])]
383
+ image_latents = torch.cat(image_latents, dim=0)
384
+ else:
385
+ image_latents = retrieve_latents(self.vae.encode(image_tensor), generator=generator)
386
+
387
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
388
+ image_latents = image_latents.to(dtype)
389
+
390
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
391
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
392
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
393
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
394
+ raise ValueError(f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts.")
395
+
396
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
397
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
398
+
399
+ return latents
400
+
401
+ def _prepare_image_latents(
402
+ self,
403
+ image: PipelineImageInput,
404
+ mask_image: PipelineImageInput,
405
+ width: int,
406
+ height: int,
407
+ batch_size: int,
408
+ num_images_per_prompt: int,
409
+ device: torch.device,
410
+ dtype: torch.dtype,
411
+ do_preprocess: bool = True,
412
+ ) -> torch.Tensor:
413
+ """
414
+ Generic function to encode an image into 5D latents for inpainting context.
415
+
416
+ If `do_preprocess` is True, it processes the image (PIL/np).
417
+ If `do_preprocess` is False, it assumes 'image' is already a ready-to-use tensor.
418
+
419
+ Args:
420
+ image (`PipelineImageInput`): The input image. Can be None to return zeros.
421
+ width (`int`): The target width.
422
+ height (`int`): The target height.
423
+ batch_size (`int`): The prompt batch size.
424
+ num_images_per_prompt (`int`): The number of images per prompt.
425
+ device (`torch.device`): The target device.
426
+ dtype (`torch.dtype`): The target data type.
427
+ do_preprocess (`bool`): Whether to preprocess the image.
428
+
429
+ Returns:
430
+ `torch.Tensor`: A 5D tensor of the encoded image latents.
431
+ """
432
+ if image is None:
433
+ latent_h = height // self.vae_scale_factor
434
+ latent_w = width // self.vae_scale_factor
435
+ shape = (batch_size * num_images_per_prompt, self.transformer.in_channels, 1, latent_h, latent_w)
436
+ return torch.zeros(shape, device=device, dtype=dtype)
437
+
438
+ if do_preprocess:
439
+ image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(device=device, dtype=self.vae.dtype)
440
+ else:
441
+ image_tensor = image.to(device=device, dtype=self.vae.dtype)
442
+
443
+ if mask_image is not None:
444
+ mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width).to(device=device, dtype=self.vae.dtype)
445
+ # Tile para 3 canais (RGB)
446
+ mask_condition = torch.tile(mask_condition, [1, 3, 1, 1])
447
+ # Aplica máscara: mantém apenas áreas escuras (< 0.5)
448
+ image_tensor = image_tensor * (mask_condition < 0.5)
449
+
450
+ with torch.no_grad():
451
+ latents = retrieve_latents(self.vae.encode(image_tensor), sample_mode="argmax")
452
+ latents = (latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
453
+
454
+ effective_batch_size = batch_size * num_images_per_prompt
455
+ if latents.shape[0] != effective_batch_size:
456
+ repeat_by = effective_batch_size // latents.shape[0]
457
+ latents = latents.repeat_interleave(repeat_by, dim=0)
458
+
459
+ return latents.to(dtype=dtype).unsqueeze(2)
460
+
461
+ def _prepare_mask_latents(
462
+ self,
463
+ mask_image: PipelineImageInput,
464
+ width: int,
465
+ height: int,
466
+ batch_size: int,
467
+ num_images_per_prompt: int,
468
+ reference_latents_shape: Tuple,
469
+ device: torch.device,
470
+ dtype: torch.dtype,
471
+ invert_mask: bool = False,
472
+ do_unsqueeze: bool = True,
473
+ ) -> torch.Tensor:
474
+ """
475
+ Processes a MASK using the mask_processor, inverts it, resizes it, and formats it for the control_context.
476
+
477
+ Args:
478
+ mask_image (`PipelineImageInput`): The mask image. Can be None to return zeros.
479
+ width (`int`): The target width.
480
+ height (`int`): The target height.
481
+ batch_size (`int`): The prompt batch size.
482
+ num_images_per_prompt (`int`): The number of images per prompt.
483
+ reference_latents_shape (`Tuple`): The shape of the inpainting latents for resizing.
484
+ device (`torch.device`): The target device.
485
+ dtype (`torch.dtype`): The target data type.
486
+
487
+ Returns:
488
+ `torch.Tensor`: A 5D tensor of the processed mask latents.
489
+ """
490
+ if mask_image is None:
491
+ placeholder_shape = (
492
+ batch_size * num_images_per_prompt,
493
+ 1,
494
+ 1,
495
+ reference_latents_shape[-2],
496
+ reference_latents_shape[-1],
497
+ )
498
+ return torch.zeros(placeholder_shape, device=device, dtype=dtype)
499
+
500
+ mask_tensor = self.mask_processor.preprocess(mask_image, height=height, width=width)
501
+ mask_tensor = mask_tensor.to(device=device, dtype=dtype)
502
+
503
+ if invert_mask:
504
+ mask_tensor = 1.0 - mask_tensor
505
+
506
+ mask_latents = F.interpolate(mask_tensor, size=reference_latents_shape[-2:], mode="nearest")
507
+
508
+ if do_unsqueeze:
509
+ mask_latents = mask_latents.unsqueeze(2)
510
+
511
+ return mask_latents
512
+
513
+ def prepare_control_latents(
514
+ self, image: PipelineImageInput, width: int, height: int, batch_size: int, num_images_per_prompt: int, device: torch.device, dtype: torch.dtype
515
+ ) -> torch.Tensor:
516
+ """
517
+ Preprocesses a control image, ENCODES it with the VAE to latent space,
518
+ and returns a 5D tensor ready for the transformer model.
519
+
520
+ Args:
521
+ image (`PipelineImageInput`): The control image. Can be None to return zeros.
522
+ width (`int`): The target width.
523
+ height (`int`): The target height.
524
+ batch_size (`int`): The prompt batch size.
525
+ num_images_per_prompt (`int`): The number of images per prompt.
526
+ device (`torch.device`): The target device.
527
+ dtype (`torch.dtype`): The target data type.
528
+
529
+ Returns:
530
+ `torch.Tensor`: A 5D tensor of the control image latents.
531
+ """
532
+ if image is None:
533
+ latent_h = 2 * (int(height) // (self.vae_scale_factor * 2))
534
+ latent_w = 2 * (int(width) // (self.vae_scale_factor * 2))
535
+ return torch.zeros(
536
+ (batch_size * num_images_per_prompt, self.transformer.in_channels, 1, latent_h, latent_w),
537
+ device=device,
538
+ dtype=dtype,
539
+ )
540
+
541
+ image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(device=device, dtype=self.vae.dtype)
542
+ with torch.no_grad():
543
+ latents = retrieve_latents(self.vae.encode(image_tensor), sample_mode="argmax")
544
+ latents = (latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
545
+
546
+ effective_batch_size = batch_size * num_images_per_prompt
547
+ if latents.shape[0] < effective_batch_size:
548
+ latents = latents.repeat_interleave(effective_batch_size // latents.shape[0], dim=0)
549
+
550
+ return latents.to(dtype=dtype).unsqueeze(2)
551
+
552
+ def _expand_and_feather_mask(self, mask_image, expand_pixels=10, feather_radius=8, is_inpaint_mode=True):
553
+ """
554
+ Expands the white area of a mask using PyTorch for performance and then smooths its edges with Pillow.
555
+
556
+ Args:
557
+ mask_image (PIL.Image.Image | np.ndarray | torch.Tensor): The input mask.
558
+ expand_pixels (int): How many pixels to expand the white area.
559
+ feather_radius (int): The radius of the Gaussian blur for the gradient.
560
+ is_inpaint_mode (bool): Flag to enable/disable the operation.
561
+
562
+ Returns:
563
+ PIL.Image.Image | np.ndarray | torch.Tensor: The processed mask, in the same format as the input.
564
+ """
565
+ if not is_inpaint_mode or (expand_pixels <= 0 and feather_radius <= 0):
566
+ return mask_image
567
+
568
+ # --- 1. CONVERT TO A PYTORCH TENSOR ---
569
+ input_type = type(mask_image)
570
+
571
+ if isinstance(mask_image, Image.Image):
572
+ # Convert PIL Image to tensor
573
+ mask_tensor = T.ToTensor()(mask_image.convert("L"))
574
+ elif isinstance(mask_image, np.ndarray):
575
+ # Convert NumPy array to tensor
576
+ mask_tensor = torch.from_numpy(mask_image).permute(2, 0, 1) if mask_image.ndim == 3 else torch.from_numpy(mask_image).unsqueeze(0)
577
+ elif isinstance(mask_image, torch.Tensor):
578
+ mask_tensor = mask_image
579
+ else:
580
+ raise TypeError(f"Unsupported mask type: {input_type}")
581
+
582
+ # Ensure the tensor is on the correct device and in (batch, channels, H, W) format
583
+ mask_tensor = mask_tensor.to(device=self.device, dtype=torch.float32)
584
+ if mask_tensor.ndim == 3:
585
+ mask_tensor = mask_tensor.unsqueeze(0) # Add the batch dimension if needed
586
+
587
+ # --- 2. EXPANSION (DILATION) ON THE GPU WITH PYTORCH ---
588
+ if expand_pixels > 0:
589
+ kernel_size = expand_pixels * 2 + 1
590
+ padding = expand_pixels
591
+
592
+ # Max pooling with stride=1 implements dilation for tensors
593
+ mask_tensor = F.max_pool2d(
594
+ mask_tensor,
595
+ kernel_size=kernel_size,
596
+ stride=1,
597
+ padding=padding
598
+ )
599
+
600
+ # --- 3. CONVERT BACK TO A PIL IMAGE ---
601
+ # `ToPILImage` expects a [C, H, W] tensor, so we drop the batch dimension
602
+ to_pil = T.ToPILImage()
603
+ mask_pil = to_pil(mask_tensor.squeeze(0).cpu())
604
+
605
+ # --- 4. FEATHERING (GAUSSIAN BLUR) WITH PILLOW ---
606
+ if feather_radius > 0:
607
+ mask_pil = mask_pil.filter(ImageFilter.GaussianBlur(radius=feather_radius))
608
+
609
+ # --- 5. FINAL CONVERSION BACK TO THE ORIGINAL TYPE ---
610
+ if input_type is torch.Tensor:
611
+ # Convert back to a tensor if the input was a tensor
612
+ return T.ToTensor()(mask_pil).to(device=self.device, dtype=mask_image.dtype)
613
+ elif input_type is np.ndarray:
614
+ # Convert back to a NumPy array if the input was an array
615
+ return np.array(mask_pil)
616
+ else: # input_type is Image.Image
617
+ return mask_pil
618
+
619
+ def _apply_mask_blur(self, mask_image, mask_blur_radius, is_inpaint_mode):
620
+ """
621
+ Apply Gaussian blur to a mask image for inpainting operations.
622
+ Args:
623
+ mask_image (Image.Image | np.ndarray | torch.Tensor): The mask image to be blurred.
624
+ Can be provided as a PIL Image, NumPy array, or PyTorch tensor.
625
+ mask_blur_radius (float): The radius of the Gaussian blur filter in pixels.
626
+ Only applied if is_inpaint_mode is True and mask_blur_radius > 0.
627
+ is_inpaint_mode (bool): Flag indicating whether the pipeline is in inpainting mode.
628
+ Blur is only applied when this is True.
629
+ Returns:
630
+ Image.Image | np.ndarray | torch.Tensor: The mask image with Gaussian blur applied
631
+ if is_inpaint_mode is True and mask_blur_radius > 0. Otherwise, returns the
632
+ original mask_image unchanged. The return type matches the input type.
633
+ """
634
+ mask_to_use = mask_image
635
+ if is_inpaint_mode and mask_blur_radius > 0:
636
+ if isinstance(mask_image, Image.Image):
637
+ mask_pil = mask_image
638
+ elif isinstance(mask_image, np.ndarray):
639
+ mask_pil = Image.fromarray(mask_image)
640
+ elif isinstance(mask_image, torch.Tensor):
641
+ mask_pil = Image.fromarray(mask_image.cpu().numpy().astype(np.uint8))
642
+ else:
643
+ mask_pil = mask_image
644
+
645
+ mask_to_use = mask_pil.filter(ImageFilter.GaussianBlur(radius=mask_blur_radius))
646
+ return mask_to_use
647
+
648
+ @property
649
+ def guidance_scale(self):
650
+ return self._guidance_scale
651
+
652
+ @property
653
+ def do_classifier_free_guidance(self):
654
+ return self._guidance_scale > 1
655
+
656
+ @property
657
+ def joint_attention_kwargs(self):
658
+ return self._joint_attention_kwargs
659
+
660
+ @property
661
+ def num_timesteps(self):
662
+ return self._num_timesteps
663
+
664
+ @property
665
+ def interrupt(self):
666
+ return self._interrupt
667
+
668
+ def __call__(
669
+ self,
670
+ prompt: Union[str, List[str]],
671
+ image: Optional[PipelineImageInput] = None,
672
+ mask_image: Optional[PipelineImageInput] = None,
673
+ inpaint_mode: Literal["default", "diff", "diff+inpaint"] = "default",
674
+ mask_blur_radius: float = 8.0,
675
+ control_image: Optional[PipelineImageInput] = None,
676
+ height: Optional[int] = None,
677
+ width: Optional[int] = None,
678
+ num_inference_steps: int = 20,
679
+ sigmas: Optional[List[float]] = None,
680
+ strength: float = 1.0,
681
+ guidance_scale: float = 4.0,
682
+ cfg_normalization: bool = False,
683
+ cfg_truncation: float = 1.0,
684
+ negative_prompt: Optional[Union[str, List[str]]] = None,
685
+ num_images_per_prompt: int = 1,
686
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
687
+ latents: Optional[torch.Tensor] = None,
688
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
689
+ negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
690
+ controlnet_conditioning_scale: float = 1.0,
691
+ controlnet_refiner_conditioning_scale: float = 1.0,
692
+ output_type: str = "pil",
693
+ return_dict: bool = True,
694
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
695
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
696
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
697
+ max_sequence_length: int = 512,
698
+ ):
699
+ r"""
700
+ The main entry point for the Z-Image unified pipeline for generation.
701
+
702
+ Args:
703
+ prompt (`str` or `List[str]`, *optional*):
704
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
705
+ image (`PipelineImageInput`, *optional*):
706
+ The initial image for image-to-image or inpainting modes.
707
+ mask_image (`PipelineImageInput`, *optional*):
708
+ The mask image for inpainting. White areas are inpainted, black areas are preserved.
709
+ inpaint_mode (`str`, *optional*, defaults to `"default"`):
710
+ The inpainting mode. Can be "default", "diff", or "diff+inpaint". Determines how the inpainting
711
+ process is handled.
712
+ mask_blur_radius (`float`, *optional*, defaults to 8.0):
713
+ The radius for blurring the edges of the inpainting mask to create a smoother transition.
714
+ control_image (`PipelineImageInput`, *optional*):
715
+ The conditioning image for control modes (e.g., Canny, depth).
716
+ height (`int`, *optional*, defaults to 1024):
717
+ The height in pixels of the generated image.
718
+ width (`int`, *optional*, defaults to 1024):
719
+ The width in pixels of the generated image.
720
+ num_inference_steps (`int`, *optional*, defaults to 20):
721
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
722
+ expense of slower inference.
723
+ sigmas (`List[float]`, *optional*):
724
+ Custom sigmas to use for the denoising process. If not defined, the scheduler's default behavior
725
+ will be used.
726
+ strength (`float`, *optional*, defaults to 1.0):
727
+ Denoising strength for image-to-image. A value of 1.0 means the initial image is fully replaced,
728
+ while a lower value preserves more of the original image structure. Only used in img2img mode.
729
+ guidance_scale (`float`, *optional*, defaults to 4.0):
730
+ The scale for classifier-free guidance. A value > 1 enables it. Higher values encourage images
731
+ closer to the prompt, potentially at the cost of quality.
732
+ cfg_normalization (`bool`, *optional*, defaults to False):
733
+ Whether to apply normalization to the guidance, which can prevent oversaturation.
734
+ cfg_truncation (`float`, *optional*, defaults to 1.0):
735
+ A value between 0.0 and 1.0 that disables CFG for the final portion of the denoising steps,
736
+ specified as a fraction of total steps. For example, 0.8 disables CFG for the last 20% of steps.
737
+ negative_prompt (`str` or `List[str]`, *optional*):
738
+ The prompt or prompts not to guide the image generation.
739
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
740
+ The number of images to generate per prompt.
741
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
742
+ A torch generator to make generation deterministic.
743
+ latents (`torch.FloatTensor`, *optional*):
744
+ Pre-generated noisy latents to be used as inputs for image generation.
745
+ prompt_embeds (`List[torch.FloatTensor]`, *optional*):
746
+ Pre-generated positive text embeddings.
747
+ negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
748
+ Pre-generated negative text embeddings.
749
+ controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
750
+ The scale of the control conditioning influence.
751
+ controlnet_refiner_conditioning_scale (`float`, *optional*, defaults to 1.0):
752
+ The scale of the control refiner conditioning influence.
753
+ output_type (`str`, *optional*, defaults to `"pil"`):
754
+ The output format of the generated image. Choose between "pil" (`PIL.Image.Image`), "np" (`np.ndarray`), or "latent".
755
+ return_dict (`bool`, *optional*, defaults to `True`):
756
+ Whether to return a `ZImagePipelineOutput` instead of a plain tuple.
757
+ joint_attention_kwargs (`dict`, *optional*):
758
+ A kwargs dictionary for the `AttentionProcessor`.
759
+ callback_on_step_end (`Callable`, *optional*):
760
+ A function that is called at the end of each denoising step.
761
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
762
+ The list of tensor inputs for the `callback_on_step_end` function.
763
+ max_sequence_length (`int`, *optional*, defaults to 512):
764
+ Maximum sequence length to use with the `prompt`.
765
+
766
+ Examples:
767
+
768
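+         A minimal usage sketch (assumes `pipe` is an already-instantiated copy of this
+         pipeline on an appropriate device; the file path is a placeholder):
+
+         ```py
+         >>> import torch
+         >>> from diffusers.utils import load_image
+
+         >>> control = load_image("control.png")
+         >>> image = pipe(
+         ...     prompt="a photo of a kitten",
+         ...     control_image=control,
+         ...     num_inference_steps=20,
+         ...     guidance_scale=4.0,
+         ...     generator=torch.Generator().manual_seed(0),
+         ... ).images[0]
+         >>> image.save("output.png")
+         ```
+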
+ Returns:
769
+ [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`:
770
+ If `return_dict` is True, a `ZImagePipelineOutput` is returned, otherwise a `tuple` with the generated images.
771
+ """
772
+ self._guidance_scale = guidance_scale
773
+ self._joint_attention_kwargs = joint_attention_kwargs
774
+ self._interrupt = False
775
+ self._cfg_normalization = cfg_normalization
776
+ self._cfg_truncation = cfg_truncation
777
+ is_two_stage_control_model = self.transformer.control_in_dim > self.transformer.in_channels if hasattr(self.transformer, "control_in_dim") else False
778
+ device = self._execution_device
779
+ dtype = self.transformer.dtype
780
+ vae_scale = self.vae_scale_factor * 2
781
+ has_inpaint_inputs = image is not None and mask_image is not None
782
+ is_inpaint_control_mode = has_inpaint_inputs and inpaint_mode in ["default", "diff+inpaint"]
783
+ is_diff_mode = has_inpaint_inputs and inpaint_mode in ["diff", "diff+inpaint"]
784
+ is_img2img_mode = image is not None and not has_inpaint_inputs
785
+
786
+ ref_image = control_image if control_image is not None else image
787
+ image_height = None
788
+ image_width = None
789
+ if ref_image is not None:
790
+ if isinstance(ref_image, Image.Image):
791
+ image_height, image_width = ref_image.height, ref_image.width
792
+ else:
793
+ image_height, image_width = ref_image.shape[-2], ref_image.shape[-1]
794
+
795
+ height = height or image_height or 1024
796
+ width = width or image_width or 1024
797
+
798
+ if height % vae_scale != 0 or width % vae_scale != 0:
799
+ raise ValueError(f"Height/width must be divisible by {vae_scale}.")
800
+
801
+ batch_size = len(prompt) if isinstance(prompt, list) else 1 if prompt else len(prompt_embeds)
802
+ effective_batch_size = batch_size * num_images_per_prompt
803
+
804
+ if prompt_embeds is not None and prompt is None:
805
+ if self.do_classifier_free_guidance and negative_prompt_embeds is None:
806
+ raise ValueError(
807
+ "When `prompt_embeds` is provided without `prompt`, `negative_prompt_embeds` must also be provided for classifier-free guidance."
808
+ )
809
+ else:
810
+ (
811
+ prompt_embeds,
812
+ negative_prompt_embeds,
813
+ ) = self.encode_prompt(
814
+ prompt=prompt,
815
+ num_images_per_prompt=num_images_per_prompt,
816
+ negative_prompt=negative_prompt,
817
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
818
+ prompt_embeds=prompt_embeds,
819
+ negative_prompt_embeds=negative_prompt_embeds,
820
+ device=device,
821
+ max_sequence_length=max_sequence_length,
822
+ )
823
+
824
+ if self.do_classifier_free_guidance:
825
+ prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
826
+ else:
827
+ prompt_embeds_model_input = prompt_embeds
828
+
829
+ if control_image is not None or is_inpaint_control_mode:
830
+ control_latents = self.prepare_control_latents(control_image, width, height, batch_size, num_images_per_prompt, device, dtype)
831
+
832
+ if is_two_stage_control_model:
833
+ image_for_inpaint = None if is_diff_mode and not is_inpaint_control_mode else image
834
+ mask_for_inpaint = None if is_diff_mode and not is_inpaint_control_mode else mask_image
835
+
836
+ if is_inpaint_control_mode:
837
+ mask_for_inpaint = self._apply_mask_blur(mask_for_inpaint, mask_blur_radius, True)
838
+
839
+ inpaint_latents = self._prepare_image_latents(
840
+ image_for_inpaint, mask_for_inpaint, width, height, batch_size, num_images_per_prompt, device, dtype
841
+ )
842
+
843
+ mask_latents = self._prepare_mask_latents(
844
+ mask_for_inpaint,
845
+ width,
846
+ height,
847
+ batch_size,
848
+ num_images_per_prompt,
849
+ inpaint_latents.shape,
850
+ device,
851
+ dtype,
852
+ invert_mask=is_inpaint_control_mode,
853
+ do_unsqueeze=True,
854
+ )
855
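+ # Two-stage control: concatenate the control, mask, and inpaint latents along the channel dimension.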
+ control_context = torch.cat([control_latents, mask_latents, inpaint_latents], dim=1)
856
+ else:
857
+ control_context = control_latents
858
+ else:
859
+ control_context = None
860
+
861
+ if self.do_classifier_free_guidance and control_context is not None:
862
+ control_context_model_input = control_context.repeat(2, 1, 1, 1, 1)
863
+ else:
864
+ control_context_model_input = control_context
865
+
866
+ image_seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2))
867
+ mu = calculate_shift(image_seq_len)
868
+ self.scheduler.sigma_min = 0.0
869
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas, mu=mu)
870
+ self._num_timesteps = len(timesteps)
871
+
872
+ if is_img2img_mode:
873
+ strength = min(strength, 1.0)
874
+ else:
875
+ strength = 1.0
876
+
877
+ if strength < 1.0:
878
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
879
+ t_start = max(num_inference_steps - init_timestep, 0)
880
+ timesteps = timesteps[t_start * self.scheduler.order :]
881
+ num_steps_to_run = len(timesteps) // self.scheduler.order
882
+ else:
883
+ num_steps_to_run = num_inference_steps
884
+
885
+ latent_timestep = timesteps[:1].repeat(effective_batch_size) if strength < 1.0 else None
886
+
887
+ use_image_for_latents = is_img2img_mode
888
+
889
+ latents = self.prepare_latents(
890
+ effective_batch_size,
891
+ self.transformer.in_channels,
892
+ height,
893
+ width,
894
+ torch.float32,
895
+ device,
896
+ generator,
897
+ image=image if use_image_for_latents else None,
898
+ timestep=latent_timestep if use_image_for_latents else None,
899
+ latents=latents,
900
+ )
901
+
902
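+ # "diff" mode: precompute the noisy trajectory of the original image at every scheduler
+ # sigma, plus one time-varying mask per step, so masked regions can be re-composited
+ # from the clean image during denoising.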
+ if is_diff_mode:
903
+ original_image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(device=device, dtype=self.vae.dtype)
904
+ with torch.no_grad():
905
+ original_clean_latents = retrieve_latents(self.vae.encode(original_image_tensor), sample_mode="argmax")
906
+ original_clean_latents = (original_clean_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
907
+ original_clean_latents = original_clean_latents.to(dtype)
908
+
909
+ noise = randn_tensor(original_clean_latents.shape, generator=generator, device=device, dtype=dtype)
910
+ latents_list = []
911
+ step_indices = [(self.scheduler.timesteps == t).nonzero().item() for t in timesteps]
912
+ for i in step_indices:
913
+ sigma = self.scheduler.sigmas[i]
914
+ noisy_latent = (1.0 - sigma) * original_clean_latents + sigma * noise
915
+ latents_list.append(noisy_latent)
916
+
917
+ original_latents_trajectory = torch.cat(latents_list, dim=0)
918
+ blurred_mask_image = self._apply_mask_blur(mask_image, mask_blur_radius, True)
919
+ map_processed = self._prepare_mask_latents(
920
+ blurred_mask_image,
921
+ width,
922
+ height,
923
+ batch_size,
924
+ num_images_per_prompt,
925
+ latents.shape,
926
+ device,
927
+ dtype,
928
+ invert_mask=True,
929
+ do_unsqueeze=False,
930
+ )
931
+
932
+ thresholds = torch.arange(len(timesteps), device=device, dtype=dtype) / len(timesteps)
933
+ thresholds = thresholds.view(-1, 1, 1, 1)
934
+ time_masks = map_processed > thresholds
935
+
936
+ num_warmup_steps = len(timesteps) - num_steps_to_run * self.scheduler.order
937
+ with torch.inference_mode():
938
+ with self.progress_bar(total=num_steps_to_run) as progress_bar:
939
+ for i, t in enumerate(timesteps):
940
+ if self.interrupt:
941
+ continue
942
+
943
+ if is_diff_mode:
944
+ if i == 0:
945
+ latents = original_latents_trajectory[:1]
946
+ else:
947
+ current_mask = time_masks[i].to(latents.dtype)
948
+ current_original_latent = original_latents_trajectory[i:i+1]
949
+
950
+ if current_mask.ndim == 3:
951
+ current_mask = current_mask.unsqueeze(1)
952
+
953
+ latents = current_original_latent * current_mask + latents * (1 - current_mask)
954
+
955
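+ # Scheduler timesteps run from 1000 down to 0; map them to a normalized time in [0, 1] that increases toward the clean image.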
+ timestep = t.expand(latents.shape[0])
956
+ timestep = (1000 - timestep) / 1000
957
+
958
+ t_norm = timestep[0].item()
959
+ current_guidance_scale = self.guidance_scale
960
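+ # CFG truncation: disable guidance for the late steps once normalized time passes the cfg_truncation fraction.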
+ if self.do_classifier_free_guidance and self._cfg_truncation is not None and float(self._cfg_truncation) <= 1:
961
+ if t_norm > self._cfg_truncation:
962
+ current_guidance_scale = 0.0
963
+ apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
964
+
965
+ if apply_cfg:
966
+ latent_model_input = latents.repeat(2, 1, 1, 1)
967
+ timestep_model_input = timestep.repeat(2)
968
+ else:
969
+ latent_model_input = latents
970
+ timestep_model_input = timestep
971
+
972
+ latent_model_input = latent_model_input.to(self.transformer.dtype)
973
+ latent_model_input = latent_model_input.unsqueeze(2)
974
+ latent_model_input_list = list(latent_model_input.unbind(dim=0))
975
+
976
+ model_out_list = self.transformer(
977
+ x=latent_model_input_list,
978
+ t=timestep_model_input,
979
+ cap_feats=prompt_embeds_model_input,
980
+ control_context=control_context_model_input,
981
+ conditioning_scale=controlnet_conditioning_scale,
982
+ refiner_conditioning_scale=controlnet_refiner_conditioning_scale,
983
+ )[0]
984
+
985
+ if apply_cfg:
986
+ pos_out = model_out_list[:effective_batch_size]
987
+ neg_out = model_out_list[effective_batch_size:]
988
+
989
+ noise_pred = []
990
+ for j in range(effective_batch_size):
991
+ pos = pos_out[j].float()
992
+ neg = neg_out[j].float()
993
+
994
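+ # Guidance as implemented here: pred = pos + scale * (pos - neg).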
+ pred = pos + current_guidance_scale * (pos - neg)
995
+
996
+ if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
997
+ ori_pos_norm = torch.linalg.vector_norm(pos)
998
+ new_pos_norm = torch.linalg.vector_norm(pred)
999
+ max_new_norm = ori_pos_norm * float(self._cfg_normalization)
1000
+ if new_pos_norm > max_new_norm:
1001
+ pred = pred * (max_new_norm / new_pos_norm)
1002
+
1003
+ noise_pred.append(pred)
1004
+
1005
+ noise_pred = torch.stack(noise_pred, dim=0)
1006
+ else:
1007
+ noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
1008
+
1009
+ noise_pred = noise_pred.squeeze(2)
1010
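+ # The model output is negated to match the update direction expected by the scheduler.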
+ noise_pred = -noise_pred
1011
+
1012
+ latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents).prev_sample
1013
+
1014
+ if callback_on_step_end is not None:
1015
+ callback_kwargs = {}
1016
+ for k in callback_on_step_end_tensor_inputs:
1017
+ callback_kwargs[k] = locals()[k]
1018
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1019
+
1020
+ if isinstance(callback_outputs, dict):
1021
+ latents = callback_outputs.pop("latents", latents)
1022
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1023
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1024
+
1025
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1026
+ progress_bar.update()
1027
+
1028
+ if output_type != "latent":
1029
+ latents = latents.to(self.vae.dtype)
1030
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1031
+ with torch.no_grad():
1032
+ image = self.vae.decode(latents, return_dict=False)[0]
1033
+ image = self.image_processor.postprocess(image, output_type=output_type)
1034
+ else:
1035
+ image = latents
1036
+
1037
+ self.maybe_free_model_hooks()
1038
+
1039
+ if not return_dict:
1040
+ return (image,)
1041
+
1042
+ return ZImagePipelineOutput(images=image)
diffusers_local/z_image_control_transformer_2d.py ADDED
@@ -0,0 +1,1460 @@
1
+ # Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
2
+ # Refactored and optimized by DEVAIEXP Team
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import math
18
+ from typing import Dict, List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
25
+ from diffusers.models.attention_dispatch import dispatch_attention_fn
26
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
27
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.models.normalization import RMSNorm
30
+ from diffusers.utils import (
31
+ is_torch_version,
32
+ )
33
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
34
+ from torch.nn.utils.rnn import pad_sequence
35
+
36
+
37
+ ADALN_EMBED_DIM = 256
38
+ SEQ_MULTI_OF = 32
39
+
40
+
41
+ def zero_module(module):
42
+ """
43
+ Initializes the parameters of a given module with zeros.
44
+
45
+ Args:
46
+ module (nn.Module): The module to be zero-initialized.
47
+
48
+ Returns:
49
+ nn.Module: The same module with its parameters initialized to zero.
50
+ """
51
+ for p in module.parameters():
52
+ nn.init.zeros_(p)
53
+ return module
54
+
55
+
56
+ class TimestepEmbedder(nn.Module):
57
+ """
58
+ A module to embed timesteps into a higher-dimensional space using sinusoidal embeddings
59
+ followed by a multilayer perceptron (MLP).
60
+ """
61
+
62
+ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
63
+ """
64
+ Initializes the TimestepEmbedder module.
65
+
66
+ Args:
67
+ out_size (int): The output dimension of the embedding.
68
+ mid_size (int, optional): The intermediate dimension of the MLP. Defaults to `out_size`.
69
+ frequency_embedding_size (int, optional): The dimension of the sinusoidal frequency embedding. Defaults to 256.
70
+ """
71
+ super().__init__()
72
+ if mid_size is None:
73
+ mid_size = out_size
74
+ self.mlp = nn.Sequential(
75
+ nn.Linear(
76
+ frequency_embedding_size,
77
+ mid_size,
78
+ bias=True,
79
+ ),
80
+ nn.SiLU(),
81
+ nn.Linear(
82
+ mid_size,
83
+ out_size,
84
+ bias=True,
85
+ ),
86
+ )
87
+ self.frequency_embedding_size = frequency_embedding_size
88
+
89
+ @staticmethod
90
+ def timestep_embedding(t, dim, max_period=10000):
91
+ """
92
+ Creates sinusoidal timestep embeddings.
93
+
94
+ Args:
95
+ t (torch.Tensor): A 1-D Tensor of N timesteps.
96
+ dim (int): The dimension of the embedding.
97
+ max_period (int, optional): The maximum period for the sinusoidal frequencies. Defaults to 10000.
98
+
99
+ Returns:
100
+ torch.Tensor: The timestep embeddings with shape (N, dim).
101
+ """
102
+ with torch.amp.autocast("cuda", enabled=False):
103
+ half = dim // 2
104
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
105
+ args = t[:, None] * freqs[None]
106
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
107
+ if dim % 2:
108
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
109
+ return embedding
110
+
111
+ def forward(self, t):
112
+ """
113
+ Processes the input timesteps to generate embeddings.
114
+
115
+ Args:
116
+ t (torch.Tensor): The input timesteps.
117
+
118
+ Returns:
119
+ torch.Tensor: The final timestep embeddings after passing through the MLP.
120
+ """
121
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
122
+ weight_dtype = self.mlp[0].weight.dtype
123
+ if weight_dtype.is_floating_point:
124
+ t_freq = t_freq.to(weight_dtype)
125
+ t_emb = self.mlp(t_freq)
126
+ return t_emb
127
+
128
+
129
+ class FeedForward(nn.Module):
130
+ """
131
+ A Feed-Forward Network module using SwiGLU activation.
132
+ """
133
+
134
+ def __init__(self, dim: int, hidden_dim: int):
135
+ """
136
+ Initializes the FeedForward module.
137
+
138
+ Args:
139
+ dim (int): Input and output dimension.
140
+ hidden_dim (int): The hidden dimension of the network.
141
+ """
142
+ super().__init__()
143
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
144
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
145
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
146
+
147
+ def _forward_silu_gating(self, x1, x3):
148
+ """
149
+ Applies the SiLU gating mechanism.
150
+
151
+ Args:
152
+ x1 (torch.Tensor): The first intermediate tensor.
153
+ x3 (torch.Tensor): The second intermediate tensor (gate).
154
+
155
+ Returns:
156
+ torch.Tensor: The result of the gating operation.
157
+ """
158
+ return F.silu(x1) * x3
159
+
160
+ def forward(self, x):
161
+ """
162
+ Defines the forward pass of the FeedForward network.
163
+
164
+ Args:
165
+ x (torch.Tensor): The input tensor.
166
+
167
+ Returns:
168
+ torch.Tensor: The output tensor.
169
+ """
170
+ return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
171
+
172
+
173
+ class FinalLayer(nn.Module):
174
+ """
175
+ The final layer of the transformer, which applies AdaLN modulation and a linear projection.
176
+ """
177
+
178
+ def __init__(self, hidden_size, out_channels):
179
+ """
180
+ Initializes the FinalLayer module.
181
+
182
+ Args:
183
+ hidden_size (int): The input hidden size.
184
+ out_channels (int): The output dimension (number of channels).
185
+ """
186
+ super().__init__()
187
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
188
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
189
+ self.adaLN_modulation = nn.Sequential(
190
+ nn.SiLU(),
191
+ nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
192
+ )
193
+
194
+ def forward(self, x, c):
195
+ """
196
+ Defines the forward pass for the final layer.
197
+
198
+ Args:
199
+ x (torch.Tensor): The main input tensor from the transformer blocks.
200
+ c (torch.Tensor): The conditioning tensor (usually from timestep embedding) for AdaLN modulation.
201
+
202
+ Returns:
203
+ torch.Tensor: The final output tensor projected to the patch dimension.
204
+ """
205
+ scale = 1.0 + self.adaLN_modulation(c)
206
+ x = self.norm_final(x) * scale.unsqueeze(1)
207
+ x = self.linear(x)
208
+ return x
209
+
210
+
211
+ class RopeEmbedder:
212
+ """
213
+ Computes Rotary Positional Embeddings (RoPE) for 3D coordinates.
214
+ """
215
+
216
+ def __init__(self, theta: float = 256.0, axes_dims: List[int] = (32, 48, 48), axes_lens: List[int] = (1024, 512, 512)):
217
+ """
218
+ Initializes the RopeEmbedder.
219
+
220
+ Args:
221
+ theta (float, optional): The base for the rotary frequencies. Defaults to 256.0.
222
+ axes_dims (List[int], optional): The dimensions for each axis (F, H, W). Defaults to (32, 48, 48).
223
+ axes_lens (List[int], optional): The maximum length for each axis. Defaults to (1024, 512, 512).
224
+ """
225
+ self.theta = theta
226
+ self.axes_dims = axes_dims
227
+ self.axes_lens = axes_lens
228
+ self.freqs_cis_cache = {}
229
+
230
+ def _precompute_freqs_cis(self, device):
231
+ """
232
+ Precomputes and caches the rotary frequency tensors (cos and sin values).
233
+
234
+ Args:
235
+ device (torch.device): The device to store the cached tensors on.
236
+
237
+ Returns:
238
+ List[torch.Tensor]: A list of precomputed frequency tensors for each axis.
239
+ """
240
+ if device in self.freqs_cis_cache:
241
+ return self.freqs_cis_cache[device]
242
+ freqs_cis_list = []
243
+ for dim, max_len in zip(self.axes_dims, self.axes_lens):
244
+ half = dim // 2
245
+ freqs = 1.0 / (self.theta ** (torch.arange(0, half, device=device, dtype=torch.float32) / half))
246
+ t = torch.arange(max_len, device=device, dtype=torch.float32)
247
+ freqs = torch.outer(t, freqs)
248
+ emb = torch.stack([freqs.cos(), freqs.sin()], dim=-1)
249
+ freqs_cis_list.append(emb)
250
+ self.freqs_cis_cache[device] = freqs_cis_list
251
+ return freqs_cis_list
252
+
253
+ def __call__(self, ids: torch.Tensor):
254
+ """
255
+ Generates RoPE embeddings for a batch of 3D coordinates.
256
+
257
+ Args:
258
+ ids (torch.Tensor): A tensor of coordinates with shape (N, 3).
259
+
260
+ Returns:
261
+ torch.Tensor: The concatenated RoPE embeddings for the input coordinates.
262
+ """
263
+ assert ids.ndim == 2 and ids.shape[1] == len(self.axes_dims)
264
+ device = ids.device
265
+ freqs_cis_list = self._precompute_freqs_cis(device)
266
+ result = []
267
+ for i in range(len(self.axes_dims)):
268
+ result.append(freqs_cis_list[i][ids[:, i]])
269
+ return torch.cat(result, dim=-2)
270
+
271
+
272
+ class ZSingleStreamAttnProcessor:
273
+ """
274
+ An attention processor that applies Rotary Positional Embeddings (RoPE) to query and key tensors
275
+ before computing scaled dot-product attention.
276
+ """
277
+
278
+ _attention_backend = None
279
+ _parallel_config = None
280
+
281
+ def __init__(self):
282
+ """
283
+ Initializes the ZSingleStreamAttnProcessor.
284
+ """
285
+ if not hasattr(F, "scaled_dot_product_attention"):
286
+ raise ImportError("ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher.")
287
+
288
+ def __call__(
289
+ self,
290
+ attn: Attention,
291
+ hidden_states: torch.Tensor,
292
+ encoder_hidden_states: Optional[torch.Tensor] = None,
293
+ attention_mask: Optional[torch.Tensor] = None,
294
+ freqs_cis: Optional[torch.Tensor] = None,
295
+ ) -> torch.Tensor:
296
+ """
297
+ The forward call for the attention processor.
298
+
299
+ Args:
300
+ attn (Attention): The attention layer that this processor is attached to.
301
+ hidden_states (torch.Tensor): The input hidden states.
302
+ encoder_hidden_states (Optional[torch.Tensor], optional): Not used in self-attention. Defaults to None.
303
+ attention_mask (Optional[torch.Tensor], optional): The attention mask. Defaults to None.
304
+ freqs_cis (Optional[torch.Tensor], optional): The precomputed RoPE frequencies. Defaults to None.
305
+
306
+ Returns:
307
+ torch.Tensor: The output of the attention mechanism.
308
+ """
309
+
310
+ def apply_rotary_emb(q_or_k: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
311
+ """
312
+ Applies RoPE to a query or key tensor.
313
+ """
314
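+ # Rotate each (even, odd) channel pair: x0' = x0*cos - x1*sin, x1' = x0*sin + x1*cos.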
+ x = q_or_k.transpose(1, 2)
315
+ x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2)
316
+ x0 = x_reshaped[..., 0]
317
+ x1 = x_reshaped[..., 1]
318
+ freqs_cos = freqs_cis[..., 0].unsqueeze(1)
319
+ freqs_sin = freqs_cis[..., 1].unsqueeze(1)
320
+ x_rotated_0 = x0 * freqs_cos - x1 * freqs_sin
321
+ x_rotated_1 = x0 * freqs_sin + x1 * freqs_cos
322
+ x_rotated = torch.stack((x_rotated_0, x_rotated_1), dim=-1)
323
+ x_out = x_rotated.flatten(-2).transpose(1, 2)
324
+ return x_out.to(q_or_k.dtype)
325
+
326
+ query = attn.to_q(hidden_states)
327
+ key = attn.to_k(hidden_states)
328
+ value = attn.to_v(hidden_states)
329
+
330
+ query = query.unflatten(-1, (attn.heads, -1))
331
+ key = key.unflatten(-1, (attn.heads, -1))
332
+ value = value.unflatten(-1, (attn.heads, -1))
333
+
334
+ if attn.norm_q is not None:
335
+ query = attn.norm_q(query)
336
+ if attn.norm_k is not None:
337
+ key = attn.norm_k(key)
338
+
339
+ if freqs_cis is not None:
340
+ query = apply_rotary_emb(query, freqs_cis)
341
+ key = apply_rotary_emb(key, freqs_cis)
342
+
343
+ if attention_mask is not None and attention_mask.ndim == 2:
344
+ attention_mask = attention_mask[:, None, None, :]
345
+
346
+ hidden_states = dispatch_attention_fn(
347
+ query,
348
+ key,
349
+ value,
350
+ attn_mask=attention_mask,
351
+ dropout_p=0.0,
352
+ is_causal=False,
353
+ backend=self._attention_backend,
354
+ parallel_config=self._parallel_config,
355
+ )
356
+
357
+ hidden_states = hidden_states.flatten(2, 3)
358
+
359
+ output = attn.to_out[0](hidden_states.to(hidden_states.dtype))
360
+ if len(attn.to_out) > 1:
361
+ output = attn.to_out[1](output)
362
+
363
+ return output
364
+
365
+
366
+ @maybe_allow_in_graph
367
+ class ZImageTransformerBlock(nn.Module):
368
+ """
369
+ A standard transformer block consisting of a self-attention layer and a feed-forward network.
370
+ Includes support for AdaLN modulation.
371
+ """
372
+
373
+ def __init__(
374
+ self,
375
+ layer_id: int,
376
+ dim: int,
377
+ n_heads: int,
378
+ n_kv_heads: int,
379
+ norm_eps: float,
380
+ qk_norm: bool,
381
+ modulation=True,
382
+ ):
383
+ """
384
+ Initializes the ZImageTransformerBlock.
385
+
386
+ Args:
387
+ layer_id (int): The index of the layer.
388
+ dim (int): The dimension of the input and output features.
389
+ n_heads (int): The number of attention heads.
390
+ n_kv_heads (int): The number of key/value heads (not directly used in this simplified attention).
391
+ norm_eps (float): Epsilon for RMSNorm.
392
+ qk_norm (bool): Whether to apply normalization to query and key tensors.
393
+ modulation (bool, optional): Whether to enable AdaLN modulation. Defaults to True.
394
+ """
395
+ super().__init__()
396
+ self.dim = dim
397
+ self.head_dim = dim // n_heads
398
+ self.attention = Attention(
399
+ query_dim=dim,
400
+ cross_attention_dim=None,
401
+ dim_head=dim // n_heads,
402
+ heads=n_heads,
403
+ qk_norm="rms_norm" if qk_norm else None,
404
+ eps=1e-5,
405
+ bias=False,
406
+ out_bias=False,
407
+ processor=ZSingleStreamAttnProcessor(),
408
+ )
409
+
410
+ self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
411
+ self.layer_id = layer_id
412
+
413
+ self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
414
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
415
+
416
+ self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
417
+ self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
418
+
419
+ self.modulation = modulation
420
+ if modulation:
421
+ self.adaLN_modulation = nn.Sequential(
422
+ nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
423
+ )
424
+
425
+ @property
426
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
427
+ """
428
+ Returns a dictionary of all attention processors used in the module.
429
+ """
430
+ processors = {}
431
+
432
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
433
+ if hasattr(module, "get_processor"):
434
+ processors[f"{name}.processor"] = module.get_processor()
435
+ for sub_name, child in module.named_children():
436
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
437
+ return processors
438
+
439
+ for name, module in self.named_children():
440
+ fn_recursive_add_processors(name, module, processors)
441
+
442
+ return processors
443
+
444
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
445
+ """
446
+ Sets the attention processor for the attention layer in this block.
447
+ """
448
+ count = len(self.attn_processors.keys())
449
+
450
+ if isinstance(processor, dict) and len(processor) != count:
451
+ raise ValueError(
452
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
453
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
454
+ )
455
+
456
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
457
+ if hasattr(module, "set_processor"):
458
+ if not isinstance(processor, dict):
459
+ module.set_processor(processor)
460
+ else:
461
+ module.set_processor(processor.pop(f"{name}.processor"))
462
+ for sub_name, child in module.named_children():
463
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
464
+
465
+ for name, module in self.named_children():
466
+ fn_recursive_attn_processor(name, module, processor)
467
+
468
+ def forward(self, x, attn_mask, freqs_cis, adaln_input=None):
469
+ """
470
+ Defines the forward pass for the transformer block.
471
+
472
+ Args:
473
+ x (torch.Tensor): The input tensor.
474
+ attn_mask (torch.Tensor): The attention mask.
475
+ freqs_cis (torch.Tensor): The RoPE frequencies.
476
+ adaln_input (torch.Tensor, optional): The conditioning tensor for AdaLN. Defaults to None.
477
+
478
+ Returns:
479
+ torch.Tensor: The output tensor of the block.
480
+ """
481
+ if self.modulation:
482
+ scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
483
+ scale_msa = scale_msa + 1.0
484
+ gate_msa = gate_msa.tanh()
485
+ scale_mlp = scale_mlp + 1.0
486
+ gate_mlp = gate_mlp.tanh()
487
+
488
+ normed = self.attention_norm1(x)
489
+ normed = normed * scale_msa
490
+ attn_out = self.attention(normed, attention_mask=attn_mask, freqs_cis=freqs_cis)
491
+ attn_out = self.attention_norm2(attn_out) * gate_msa
492
+ x = x + attn_out
493
+
494
+ normed = self.ffn_norm1(x)
495
+ normed = normed * scale_mlp
496
+ ffn_out = self.feed_forward(normed)
497
+ ffn_out = self.ffn_norm2(ffn_out) * gate_mlp
498
+ x = x + ffn_out
499
+ else:
500
+ normed = self.attention_norm1(x)
501
+ attn_out = self.attention(normed, attention_mask=attn_mask, freqs_cis=freqs_cis)
502
+ x = x + self.attention_norm2(attn_out)
503
+ normed = self.ffn_norm1(x)
504
+ ffn_out = self.feed_forward(normed)
505
+ x = x + self.ffn_norm2(ffn_out)
506
+ return x
507
+
508
+
509
+ class ZImageControlTransformerBlock(ZImageTransformerBlock):
510
+ """
511
+ A specialized transformer block for the control pathway. It inherits from ZImageTransformerBlock
512
+ and adds projection layers to generate and combine control signals.
513
+ """
514
+
515
+ def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, qk_norm: bool, modulation=True, block_id=0):
516
+ """
517
+ Initializes the ZImageControlTransformerBlock.
518
+
519
+ Args:
520
+ layer_id (int): The index of the layer.
521
+ dim (int): The dimension of the features.
522
+ n_heads (int): The number of attention heads.
523
+ n_kv_heads (int): The number of key/value heads.
524
+ norm_eps (float): Epsilon for RMSNorm.
525
+ qk_norm (bool): Whether to apply normalization to query and key.
526
+ modulation (bool, optional): Whether to enable AdaLN modulation. Defaults to True.
527
+ block_id (int, optional): The index of this control block. Defaults to 0.
528
+ """
529
+ super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
530
+ self.block_id = block_id
531
+ if block_id == 0:
532
+ self.before_proj = zero_module(nn.Linear(self.dim, self.dim))
533
+ self.after_proj = zero_module(nn.Linear(self.dim, self.dim))
534
+
535
+ def forward(self, c, x, **kwargs):
536
+ """
537
+ Defines the forward pass for the control block.
538
+
539
+ Args:
540
+ c (torch.Tensor): The control signal tensor.
541
+ x (torch.Tensor): The reference tensor from the main pathway.
542
+ **kwargs: Additional arguments for the parent's forward method.
543
+
544
+ Returns:
545
+ torch.Tensor: A stacked tensor containing the skip connection and the final output.
546
+ """
547
+ if self.block_id == 0:
548
+ c = self.before_proj(c) + x
549
+ all_c = []
550
+ else:
551
+ all_c = list(torch.unbind(c))
552
+ c = all_c.pop(-1)
553
+
554
+ c = super().forward(c, **kwargs)
555
+ c_skip = self.after_proj(c)
556
+ all_c += [c_skip, c]
557
+ c = torch.stack(all_c)
558
+ return c
559
+
560
+
561
+ class BaseZImageTransformerBlock(ZImageTransformerBlock):
562
+ """
563
+ The main transformer block used in the primary pathway. It inherits from ZImageTransformerBlock
564
+ and adds the logic to inject control "hints" from the control pathway.
565
+ """
566
+
567
+ def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, qk_norm: bool, modulation=True, block_id=0):
568
+ """
569
+ Initializes the BaseZImageTransformerBlock.
570
+
571
+ Args:
572
+ layer_id (int): The index of the layer.
573
+ dim (int): The dimension of the features.
574
+ n_heads (int): The number of attention heads.
575
+ n_kv_heads (int): The number of key/value heads.
576
+ norm_eps (float): Epsilon for RMSNorm.
577
+ qk_norm (bool): Whether to apply normalization to query and key.
578
+ modulation (bool, optional): Whether to enable AdaLN modulation. Defaults to True.
579
+ block_id (int, optional): The index used to retrieve the corresponding control hint. Defaults to 0.
580
+ """
581
+ super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
582
+ self.block_id = block_id
583
+
584
+ def forward(self, hidden_states, hints=None, context_scale=1.0, **kwargs):
585
+ """
586
+ Defines the forward pass, including the injection of control hints.
587
+
588
+ Args:
589
+ hidden_states (torch.Tensor): The input tensor.
590
+ hints (List[torch.Tensor], optional): A list of control hints from the control pathway. Defaults to None.
591
+ context_scale (float, optional): A scale factor for the control hints. Defaults to 1.0.
592
+ **kwargs: Additional arguments for the parent's forward method.
593
+
594
+ Returns:
595
+ torch.Tensor: The output tensor of the block.
596
+ """
597
+ hidden_states = super().forward(hidden_states, **kwargs)
598
+ if self.block_id is not None and hints is not None:
599
+ hidden_states = hidden_states + hints[self.block_id] * context_scale
600
+ return hidden_states
601
+
602
+
603
+ class ZImageControlTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
604
+ _supports_gradient_checkpointing = True
605
+ _keys_to_ignore_on_load_unexpected = [
606
+ r"control_layers\..*",
607
+ r"control_noise_refiner\..*",
608
+ r"control_all_x_embedder\..*",
609
+ ]
610
+ _no_split_modules = ["ZImageTransformerBlock", "BaseZImageTransformerBlock", "ZImageControlTransformerBlock"]
611
+ _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"]
612
+ _group_offload_block_modules = ["t_embedder", "cap_embedder"]
613
+
614
+ @register_to_config
615
+ def __init__(
616
+ self,
617
+ control_layers_places=None,
618
+ control_refiner_layers_places=None,
619
+ control_in_dim=None,
620
+ add_control_noise_refiner=False,
621
+ all_patch_size=(2,),
622
+ all_f_patch_size=(1,),
623
+ in_channels=16,
624
+ dim=3840,
625
+ n_layers=30,
626
+ n_refiner_layers=2,
627
+ n_heads=30,
628
+ n_kv_heads=30,
629
+ norm_eps=1e-5,
630
+ qk_norm=True,
631
+ cap_feat_dim=2560,
632
+ rope_theta=256.0,
633
+ t_scale=1000.0,
634
+ axes_dims=[32, 48, 48],
635
+ axes_lens=[1024, 512, 512],
636
+ use_controlnet=True,
637
+ checkpoint_ratio=0.5,
638
+ ):
639
+ """
640
+ Initializes the ZImageControlTransformer2DModel.
641
+
642
+ Args:
643
+ control_layers_places (List[int], optional): Indices of main layers where control hints are injected.
644
+ control_refiner_layers_places (List[int], optional): Indices of noise refiner layers for two-stage control.
645
+ control_in_dim (int, optional): Input channel dimension for the control context.
646
+ add_control_noise_refiner (bool, optional): Whether to add a dedicated refiner for the control signal.
647
+ all_patch_size (Tuple[int], optional): Tuple of patch sizes for spatial dimensions.
648
+ all_f_patch_size (Tuple[int], optional): Tuple of patch sizes for the frame dimension.
649
+ in_channels (int, optional): Number of input channels for the latent image.
650
+ dim (int, optional): The main dimension of the transformer model.
651
+ n_layers (int, optional): The number of main transformer layers.
652
+ n_refiner_layers (int, optional): The number of layers in the refiner blocks.
653
+ n_heads (int, optional): The number of attention heads.
654
+ n_kv_heads (int, optional): The number of key/value heads.
655
+ norm_eps (float, optional): Epsilon for RMSNorm.
656
+ qk_norm (bool, optional): Whether to apply normalization to query and key.
657
+ cap_feat_dim (int, optional): The dimension of the input caption features.
658
+ rope_theta (float, optional): The base for RoPE.
659
+ t_scale (float, optional): A scaling factor for the timestep.
660
+ axes_dims (List[int], optional): Dimensions for each axis in RoPE.
661
+ axes_lens (List[int], optional): Maximum lengths for each axis in RoPE.
662
+ use_controlnet (bool, optional): If False, control-related layers will not be created to save memory.
663
+ checkpoint_ratio (float, optional): The ratio of layers to apply gradient checkpointing to.
664
+ """
665
+ super().__init__()
666
+ self.use_controlnet = use_controlnet
667
+ self.in_channels = in_channels
668
+ self.out_channels = in_channels
669
+ self.all_patch_size = all_patch_size
670
+ self.all_f_patch_size = all_f_patch_size
671
+ self.dim = dim
672
+ self.control_in_dim = self.dim if control_in_dim is None else control_in_dim
673
+ self.is_two_stage_control = self.control_in_dim > 16
674
+ self.n_heads = n_heads
675
+ self.rope_theta = rope_theta
676
+ self.t_scale = t_scale
677
+ self.gradient_checkpointing = False
678
+ self.checkpoint_ratio = checkpoint_ratio
679
+ assert len(all_patch_size) == len(all_f_patch_size)
680
+
681
+ self.control_layers_places = list(range(0, n_layers, 2)) if control_layers_places is None else control_layers_places
682
+ self.control_refiner_layers_places = list(range(0, n_refiner_layers)) if control_refiner_layers_places is None else control_refiner_layers_places
683
+ self.add_control_noise_refiner = add_control_noise_refiner
684
+ assert 0 in self.control_layers_places
685
+ self.control_layers_mapping = {i: n for n, i in enumerate(self.control_layers_places)}
686
+ self.control_refiner_layers_mapping = {i: n for n, i in enumerate(self.control_refiner_layers_places)}
687
+
688
+ self.all_x_embedder = nn.ModuleDict(
689
+ {
690
+ f"{patch_size}-{f_patch_size}": nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
691
+ for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size)
692
+ }
693
+ )
694
+
695
+ self.all_final_layer = nn.ModuleDict(
696
+ {
697
+ f"{patch_size}-{f_patch_size}": FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
698
+ for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size)
699
+ }
700
+ )
701
+
702
+ self.context_refiner = nn.ModuleList(
703
+ [ZImageTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=False) for i in range(n_refiner_layers)]
704
+ )
705
+ self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
706
+ self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))
707
+ self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
708
+ self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
709
+
710
+ head_dim = dim // n_heads
711
+ assert head_dim == sum(axes_dims)
712
+ self.axes_dims = axes_dims
713
+ self.axes_lens = axes_lens
714
+ self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
715
+
716
+ self.layers = nn.ModuleList(
717
+ [BaseZImageTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=self.control_layers_mapping.get(i)) for i in range(n_layers)]
718
+ )
719
+
720
+ self.noise_refiner = nn.ModuleList(
721
+ [
722
+ BaseZImageTransformerBlock(
723
+ 1000 + i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True, block_id=self.control_refiner_layers_mapping.get(i)
724
+ )
725
+ for i in range(n_refiner_layers)
726
+ ]
727
+ )
728
+
729
+ if self.use_controlnet:
730
+ self.control_layers = nn.ModuleList(
731
+ [ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i) for i in self.control_layers_places]
732
+ )
733
+ self.control_all_x_embedder = nn.ModuleDict(
734
+ {
735
+ f"{patch_size}-{f_patch_size}": nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True)
736
+ for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size)
737
+ }
738
+ )
739
+
740
+ if self.is_two_stage_control:
741
+ if self.add_control_noise_refiner:
742
+ self.control_noise_refiner = nn.ModuleList(
743
+ [
744
+ ZImageControlTransformerBlock(1000 + layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True, block_id=layer_id)
745
+ for layer_id in range(n_refiner_layers)
746
+ ]
747
+ )
748
+ else:
749
+ self.control_noise_refiner = None
750
+ else: # V1
751
+ self.control_noise_refiner = nn.ModuleList(
752
+ [ZImageTransformerBlock(1000 + i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True) for i in range(n_refiner_layers)]
753
+ )
754
+ else:
755
+ self.control_layers = None
756
+ self.control_all_x_embedder = None
757
+ self.control_noise_refiner = None
758
+
759
+ def _unpatchify(self, x_image_tokens: torch.Tensor, all_sizes: List[Tuple], patch_size: int, f_patch_size: int) -> torch.Tensor:
760
+ """
761
+ Converts a sequence of image tokens back into a batched image tensor. This version is robust
762
+ to batches containing images of different original sizes.
763
+
764
+ Args:
765
+ x_image_tokens (torch.Tensor): A tensor of image tokens with shape [B, SeqLen, Dim].
766
+ all_sizes (List[Tuple]): A list of tuples with the original (F, H, W) size for each image in the batch.
767
+ patch_size (int): The spatial patch size (height and width).
768
+ f_patch_size (int): The frame/temporal patch size.
769
+
770
+ Returns:
771
+ torch.Tensor: The reconstructed latent tensor with shape [B, C, F, H, W].
772
+ """
773
+ pH = pW = patch_size
774
+ pF = f_patch_size
775
+ batch_size = x_image_tokens.shape[0]
776
+ unpatched_images = []
777
+
778
+ for i in range(batch_size):
779
+ F, H, W = all_sizes[i]
780
+ F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
781
+ original_seq_len = F_tokens * H_tokens * W_tokens
782
+ current_image_tokens = x_image_tokens[i, :original_seq_len, :]
783
+ unpatched_image = current_image_tokens.view(F_tokens, H_tokens, W_tokens, pF, pH, pW, self.out_channels)
784
+ unpatched_image = unpatched_image.permute(6, 0, 3, 1, 4, 2, 5).reshape(self.out_channels, F, H, W)
785
+ unpatched_images.append(unpatched_image)
786
+
787
+ try:
788
+ final_tensor = torch.stack(unpatched_images, dim=0)
789
+ except RuntimeError:
790
+ raise ValueError(
791
+ "Could not stack unpatched images into a single batch tensor. "
792
+ "This typically occurs if you are trying to generate images of different sizes in the same batch."
793
+ )
794
+
795
+ return final_tensor
796
+
+     def _patchify(
+         self,
+         all_image: List[torch.Tensor],
+         patch_size: int,
+         f_patch_size: int,
+         cap_padding_len: int,
+     ):
+         """
+         Converts a list of image tensors into patch sequences and computes their positional IDs.
+
+         Args:
+             all_image (List[torch.Tensor]): A list of image tensors to process.
+             patch_size (int): The spatial patch size.
+             f_patch_size (int): The frame/temporal patch size.
+             cap_padding_len (int): The length of the padded caption sequence, used as an offset for image position IDs.
+
+         Returns:
+             Tuple: A tuple containing lists of processed patches, sizes, position IDs, and padding masks.
+         """
+         pH = pW = patch_size
+         pF = f_patch_size
+         device = all_image[0].device
+
+         all_image_out = []
+         all_image_size = []
+         all_image_pos_ids = []
+         all_image_pad_mask = []
+
+         for image in all_image:
+             C, F, H, W = image.size()
+             all_image_size.append((F, H, W))
+             F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
+
+             image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
+             image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
+
+             image_ori_len = len(image)
+             image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
+
+             image_ori_pos_ids = self._create_coordinate_grid(
+                 size=(F_tokens, H_tokens, W_tokens),
+                 start=(cap_padding_len + 1, 0, 0),
+                 device=device,
+             ).flatten(0, 2)
+             image_padding_pos_ids = (
+                 self._create_coordinate_grid(
+                     size=(1, 1, 1),
+                     start=(0, 0, 0),
+                     device=device,
+                 )
+                 .flatten(0, 2)
+                 .repeat(image_padding_len, 1)
+             )
+             image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
+             all_image_pos_ids.append(image_padded_pos_ids)
+             all_image_pad_mask.append(
+                 torch.cat(
+                     [
+                         torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
+                         torch.ones((image_padding_len,), dtype=torch.bool, device=device),
+                     ],
+                     dim=0,
+                 )
+             )
+             image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
+             all_image_out.append(image_padded_feat)
+
+         return (
+             all_image_out,
+             all_image_size,
+             all_image_pos_ids,
+             all_image_pad_mask,
+         )
+
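+     # Padding note (illustrative; SEQ_MULTI_OF is a constant defined earlier in this
+     # file, and its value of 32 is assumed here purely for the arithmetic):
+     # (-n) % SEQ_MULTI_OF rounds a length-n sequence up to the next multiple, e.g.
+     # n = 1000 needs 24 padding tokens for a total of 1024. Padding tokens repeat the
+     # last real patch and are flagged True in the returned pad mask.
+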
+     def _patchify_and_embed(
+         self,
+         all_image: List[torch.Tensor],
+         all_cap_feats: List[torch.Tensor],
+         patch_size: int,
+         f_patch_size: int,
+     ):
+         """
+         Processes a batch of images and caption features by converting them into padded patch sequences
+         and generating their corresponding positional IDs and padding masks. This is the general-purpose,
+         robust version that iterates through the batch.
+
+         Args:
+             all_image (List[torch.Tensor]): A list of image tensors.
+             all_cap_feats (List[torch.Tensor]): A list of caption feature tensors.
+             patch_size (int): The spatial patch size.
+             f_patch_size (int): The frame/temporal patch size.
+
+         Returns:
+             Tuple: A tuple containing all processed data structures (image patches, caption features, sizes,
+             position IDs, and padding masks) as lists.
+         """
+         pH = pW = patch_size
+         pF = f_patch_size
+         device = all_image[0].device
+
+         all_image_out, all_image_size, all_image_pos_ids, all_image_pad_mask = [], [], [], []
+         all_cap_pos_ids, all_cap_pad_mask, all_cap_feats_out = [], [], []
+
+         for image, cap_feat in zip(all_image, all_cap_feats):
+             cap_ori_len = len(cap_feat)
+             cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
+             cap_total_len = cap_ori_len + cap_padding_len
+
+             cap_padded_pos_ids = self._create_coordinate_grid(size=(cap_total_len, 1, 1), start=(1, 0, 0), device=device).flatten(0, 2)
+             all_cap_pos_ids.append(cap_padded_pos_ids)
+
+             cap_mask = torch.ones(cap_total_len, dtype=torch.bool, device=device)
+             cap_mask[:cap_ori_len] = False
+             all_cap_pad_mask.append(cap_mask)
+
+             if cap_padding_len > 0:
+                 padding_tensor = cap_feat[-1:].repeat(cap_padding_len, 1)
+                 cap_padded_feat = torch.cat([cap_feat, padding_tensor], dim=0)
+             else:
+                 cap_padded_feat = cap_feat
+             all_cap_feats_out.append(cap_padded_feat)
+
+             C, Fr, H, W = image.size()
+             all_image_size.append((Fr, H, W))
+             F_tokens, H_tokens, W_tokens = Fr // pF, H // pH, W // pW
+
+             image_reshaped = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW).permute(1, 3, 5, 2, 4, 6, 0).reshape(-1, pF * pH * pW * C)
+
+             image_ori_len = image_reshaped.shape[0]
+             image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
+             image_total_len = image_ori_len + image_padding_len
+
+             image_ori_pos_ids = self._create_coordinate_grid(size=(F_tokens, H_tokens, W_tokens), start=(cap_total_len + 1, 0, 0), device=device).flatten(0, 2)
+             if image_padding_len > 0:
+                 image_padding_pos_ids = torch.zeros((image_padding_len, 3), dtype=torch.int32, device=device)
+                 image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
+             else:
+                 image_padded_pos_ids = image_ori_pos_ids
+             all_image_pos_ids.append(image_padded_pos_ids)
+
+             image_mask = torch.ones(image_total_len, dtype=torch.bool, device=device)
+             image_mask[:image_ori_len] = False
+             all_image_pad_mask.append(image_mask)
+
+             if image_padding_len > 0:
+                 padding_tensor = image_reshaped[-1:].repeat(image_padding_len, 1)
+                 image_padded_feat = torch.cat([image_reshaped, padding_tensor], dim=0)
+             else:
+                 image_padded_feat = image_reshaped
+             all_image_out.append(image_padded_feat)
+
+         return (
+             all_image_out,
+             all_cap_feats_out,
+             all_image_size,
+             all_image_pos_ids,
+             all_cap_pos_ids,
+             all_image_pad_mask,
+             all_cap_pad_mask,
+         )
+
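+     # Rotary-position layout produced above: along the first (frame) axis, position 0
+     # is reserved for padding tokens, captions occupy positions 1..cap_total_len, and
+     # image tokens start at cap_total_len + 1, so caption and image tokens never share
+     # a rotary position.
+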
+     def _process_cap_feats_with_cfg_cache(self, cap_feats_list, cap_pos_ids, cap_inner_pad_mask):
+         """
+         Processes caption features with duplicate detection to avoid redundant computation,
+         especially for Classifier-Free Guidance (CFG), where prompts are repeated within a batch.
+
+         Args:
+             cap_feats_list (List[torch.Tensor]): List of padded caption feature tensors.
+             cap_pos_ids (List[torch.Tensor]): List of corresponding position ID tensors.
+             cap_inner_pad_mask (List[torch.Tensor]): List of corresponding padding masks.
+
+         Returns:
+             Tuple: A tuple of batched tensors for padded features, RoPE frequencies, attention mask, and sequence lengths.
+         """
+         device = cap_feats_list[0].device
+         bsz = len(cap_feats_list)
+
+         shapes_equal = all(c.shape == cap_feats_list[0].shape for c in cap_feats_list)
+
+         if shapes_equal and bsz >= 2:
+             unique_indices = [0]
+             unique_tensors = [cap_feats_list[0]]
+             tensor_mapping = [0]
+
+             for i in range(1, bsz):
+                 found_match = False
+                 for j, unique_tensor in enumerate(unique_tensors):
+                     if torch.equal(cap_feats_list[i], unique_tensor):
+                         tensor_mapping.append(j)
+                         found_match = True
+                         break
+
+                 if not found_match:
+                     unique_indices.append(i)
+                     unique_tensors.append(cap_feats_list[i])
+                     tensor_mapping.append(len(unique_tensors) - 1)
+
+             if len(unique_tensors) < bsz:
+                 unique_cap_feats_list = [cap_feats_list[i] for i in unique_indices]
+                 unique_cap_pos_ids = [cap_pos_ids[i] for i in unique_indices]
+                 unique_cap_inner_pad_mask = [cap_inner_pad_mask[i] for i in unique_indices]
+
+                 cap_item_seqlens_unique = [len(i) for i in unique_cap_feats_list]
+                 cap_max_item_seqlen = max(cap_item_seqlens_unique)
+
+                 cap_feats_cat = torch.cat(unique_cap_feats_list, dim=0)
+                 cap_feats_embedded = self.cap_embedder(cap_feats_cat)
+                 cap_feats_embedded[torch.cat(unique_cap_inner_pad_mask)] = self.cap_pad_token
+                 cap_feats_padded_unique = pad_sequence(list(cap_feats_embedded.split(cap_item_seqlens_unique, dim=0)), batch_first=True, padding_value=0.0)
+
+                 cap_freqs_cis_cat = self.rope_embedder(torch.cat(unique_cap_pos_ids, dim=0))
+                 cap_freqs_cis_unique = pad_sequence(list(cap_freqs_cis_cat.split(cap_item_seqlens_unique, dim=0)), batch_first=True, padding_value=0.0)
+
+                 cap_feats_padded = cap_feats_padded_unique[tensor_mapping]
+                 cap_freqs_cis = cap_freqs_cis_unique[tensor_mapping]
+
+                 seq_lens_tensor = torch.tensor([cap_max_item_seqlen] * bsz, device=device, dtype=torch.int32)
+                 arange = torch.arange(cap_max_item_seqlen, device=device, dtype=torch.int32)
+                 cap_attn_mask = arange[None, :] < seq_lens_tensor[:, None]
+
+                 cap_item_seqlens = [cap_max_item_seqlen] * bsz
+
+                 return cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens
+
+         cap_item_seqlens = [len(i) for i in cap_feats_list]
+         cap_max_item_seqlen = max(cap_item_seqlens)
+         cap_feats_cat = torch.cat(cap_feats_list, dim=0)
+         cap_feats_embedded = self.cap_embedder(cap_feats_cat)
+         cap_feats_embedded[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
+         cap_feats_padded = pad_sequence(list(cap_feats_embedded.split(cap_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
+
+         cap_freqs_cis_cat = self.rope_embedder(torch.cat(cap_pos_ids, dim=0))
+         cap_freqs_cis = pad_sequence(list(cap_freqs_cis_cat.split(cap_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
+
+         seq_lens_tensor = torch.tensor(cap_item_seqlens, device=device, dtype=torch.int32)
+         arange = torch.arange(cap_max_item_seqlen, device=device, dtype=torch.int32)
+         cap_attn_mask = arange[None, :] < seq_lens_tensor[:, None]
+
+         return cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens
+
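+     # Illustrative CFG example (names are hypothetical): for a batch such as
+     # [cond_a, uncond, cond_b, uncond], the unique tensors are [cond_a, uncond, cond_b]
+     # and tensor_mapping == [0, 1, 2, 1]; the caption embedder then runs on three rows
+     # instead of four, and row 1 of the padded result is reused for both uncond entries.
+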
+     @staticmethod
+     def _create_coordinate_grid(size, start=None, device=None):
+         """
+         Creates a 3D coordinate grid.
+
+         Args:
+             size (Tuple[int]): The dimensions of the grid (F, H, W).
+             start (Tuple[int], optional): The starting coordinates for each axis. Defaults to (0, 0, 0).
+             device (torch.device, optional): The device to create the tensor on. Defaults to None.
+
+         Returns:
+             torch.Tensor: The coordinate grid tensor.
+         """
+         if start is None:
+             start = (0,) * len(size)
+         axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
+         grids = torch.meshgrid(axes, indexing="ij")
+         return torch.stack(grids, dim=-1)
+
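+     # Quick check (illustrative): _create_coordinate_grid(size=(1, 2, 2), start=(5, 0, 0))
+     # returns a (1, 2, 2, 3) tensor holding [5, 0, 0], [5, 0, 1], [5, 1, 0], [5, 1, 1],
+     # i.e. every (frame, height, width) coordinate offset by `start`.
+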
+     def _apply_transformer_blocks(self, hidden_states, layers, checkpoint_ratio=0.5, **kwargs):
+         """
+         Applies a list of transformer layers to the hidden states, with optional selective gradient checkpointing.
+
+         Args:
+             hidden_states (torch.Tensor): The input tensor.
+             layers (nn.ModuleList): The list of transformer layers to apply.
+             checkpoint_ratio (float, optional): The ratio of layers to apply gradient checkpointing to. Defaults to 0.5.
+             **kwargs: Additional keyword arguments to pass to each layer's forward method.
+
+         Returns:
+             torch.Tensor: The output tensor after applying all layers.
+         """
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+             def create_custom_forward(module, **static_kwargs):
+                 def custom_forward(*inputs):
+                     return module(*inputs, **static_kwargs)
+
+                 return custom_forward
+
+             ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+             checkpoint_every_n = max(1, int(1.0 / checkpoint_ratio)) if checkpoint_ratio > 0 else len(layers) + 1
+
+             for i, layer in enumerate(layers):
+                 # A non-positive ratio disables checkpointing entirely; without the extra
+                 # guard, layer 0 would still be checkpointed because 0 % n == 0.
+                 if checkpoint_ratio > 0 and i % checkpoint_every_n == 0:
+                     hidden_states = torch.utils.checkpoint.checkpoint(
+                         create_custom_forward(layer, **kwargs),
+                         hidden_states,
+                         **ckpt_kwargs,
+                     )
+                 else:
+                     hidden_states = layer(hidden_states, **kwargs)
+         else:
+             for layer in layers:
+                 hidden_states = layer(hidden_states, **kwargs)
+
+         return hidden_states
+
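+     # Checkpointing cadence (illustrative): checkpoint_ratio=0.5 yields
+     # checkpoint_every_n=2, so layers 0, 2, 4, ... recompute activations during the
+     # backward pass while the others keep theirs; checkpoint_ratio=1.0 checkpoints
+     # every layer, and a non-positive ratio checkpoints none.
+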
+     def _prepare_control_inputs(self, control_context, cap_feats_ref, t, patch_size, f_patch_size, device):
+         """
+         Prepares the control context for the transformer, including patchifying, embedding, and generating
+         positional information. Includes a fast path for batches with uniform shapes.
+
+         Args:
+             control_context (torch.Tensor or List[torch.Tensor]): The control context input.
+             cap_feats_ref (List[torch.Tensor]): A reference to caption features for padding calculation.
+             t (torch.Tensor): The timestep tensor.
+             patch_size (int): The spatial patch size.
+             f_patch_size (int): The frame/temporal patch size.
+             device (torch.device): The target device.
+
+         Returns:
+             Dict: A dictionary containing the processed control tensors ('c', 'c_item_seqlens', 'attn_mask', etc.).
+         """
+         if isinstance(control_context, torch.Tensor) and control_context.ndim == 5:
+             control_list = list(torch.unbind(control_context, dim=0))
+         else:
+             control_list = control_context
+         # Derive the batch size from the list form so a List[torch.Tensor] input works too.
+         bsz = len(control_list)
+
+         pH = pW = patch_size
+         pF = f_patch_size
+         cap_padding_len = cap_feats_ref[0].size(0) if isinstance(cap_feats_ref, list) else cap_feats_ref.shape[1]
+
+         shapes = [c.shape for c in control_list]
+         same_shape = all(s == shapes[0] for s in shapes)
+
+         if same_shape and bsz >= 2:
+             control_batch = torch.stack(control_list, dim=0)
+             B, C, F, H, W = control_batch.shape
+             F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
+
+             control_batch = control_batch.view(B, C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
+             control_batch = control_batch.permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(B, F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
+
+             ori_len = control_batch.shape[1]
+             padding_len = (-ori_len) % SEQ_MULTI_OF
+
+             if padding_len > 0:
+                 pad_tensor = control_batch[:, -1:, :].repeat(1, padding_len, 1)
+                 control_batch = torch.cat([control_batch, pad_tensor], dim=1)
+
+             c = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_batch)
+
+             final_seq_len = control_batch.shape[1]
+             pos_ids_ori = self._create_coordinate_grid(
+                 size=(F_tokens, H_tokens, W_tokens),
+                 start=(cap_padding_len + 1, 0, 0),
+                 device=device,
+             ).flatten(0, 2)  # [ori_len, 3]
+
+             pos_ids_pad = torch.zeros((padding_len, 3), dtype=torch.int32, device=device)
+             pos_ids_padded = torch.cat([pos_ids_ori, pos_ids_pad], dim=0)
+
+             c_freqs_cis_single = self.rope_embedder(pos_ids_padded)
+             c_freqs_cis = c_freqs_cis_single.unsqueeze(0).repeat(B, 1, 1, 1)
+             c_attn_mask = torch.ones((B, final_seq_len), dtype=torch.bool, device=device)
+
+             return {"c": c, "c_item_seqlens": [final_seq_len] * B, "attn_mask": c_attn_mask, "freqs_cis": c_freqs_cis, "adaln_input": t.type_as(c)}
+
+         (c_patches, _, c_pos_ids, c_inner_pad_mask) = self._patchify(control_list, patch_size, f_patch_size, cap_padding_len)
+
+         c_item_seqlens = [len(p) for p in c_patches]
+         c_max_item_seqlen = max(c_item_seqlens)
+
+         c = torch.cat(c_patches, dim=0)
+         c = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](c)
+         c[torch.cat(c_inner_pad_mask)] = self.x_pad_token
+         c = list(c.split(c_item_seqlens, dim=0))
+
+         c_freqs_cis_list = []
+         for pos_ids in c_pos_ids:
+             c_freqs_cis_list.append(self.rope_embedder(pos_ids))
+
+         c_padded = pad_sequence(c, batch_first=True, padding_value=0.0)
+         c_freqs_cis_padded = pad_sequence(c_freqs_cis_list, batch_first=True, padding_value=0.0)
+
+         seq_lens_tensor = torch.tensor(c_item_seqlens, device=device, dtype=torch.int32)
+         arange = torch.arange(c_max_item_seqlen, device=device, dtype=torch.int32)
+         c_attn_mask = arange[None, :] < seq_lens_tensor[:, None]
+
+         return {"c": c_padded, "c_item_seqlens": c_item_seqlens, "attn_mask": c_attn_mask, "freqs_cis": c_freqs_cis_padded, "adaln_input": t.type_as(c_padded)}
+
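+     # Both branches above project the control patches through control_all_x_embedder
+     # keyed by "{patch_size}-{f_patch_size}", so each registered patch-size combination
+     # has its own input projection; the fast path only skips the per-item Python loop
+     # when every control image in the batch shares one shape.
+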
+     def _patchify_and_embed_batch_optimized(self, all_image, all_cap_feats, patch_size, f_patch_size):
+         """
+         An optimized version of _patchify_and_embed for batches where all images and captions have
+         uniform shapes. It processes the entire batch using vectorized operations instead of a loop.
+
+         Args:
+             all_image (List[torch.Tensor]): List of image tensors, all of the same shape.
+             all_cap_feats (List[torch.Tensor]): List of caption features, all of the same shape.
+             patch_size (int): The spatial patch size.
+             f_patch_size (int): The frame/temporal patch size.
+
+         Returns:
+             Tuple: A tuple containing all processed data structures, matching the output of the standard method.
+         """
+         pH = pW = patch_size
+         pF = f_patch_size
+         device = all_image[0].device
+
+         image_shapes = [img.shape for img in all_image]
+         cap_shapes = [cap.shape for cap in all_cap_feats]
+
+         same_image_shape = all(s == image_shapes[0] for s in image_shapes)
+         same_cap_shape = all(s == cap_shapes[0] for s in cap_shapes)
+
+         if not (same_image_shape and same_cap_shape):
+             return self._patchify_and_embed(all_image, all_cap_feats, patch_size, f_patch_size)
+
+         images_batch = torch.stack(all_image, dim=0)
+         caps_batch = torch.stack(all_cap_feats, dim=0)
+
+         B, C, Fr, H, W = images_batch.shape
+         cap_ori_len = caps_batch.shape[1]
+
+         cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
+         cap_total_len = cap_ori_len + cap_padding_len
+
+         if cap_padding_len > 0:
+             cap_pad = caps_batch[:, -1:, :].repeat(1, cap_padding_len, 1)
+             caps_batch = torch.cat([caps_batch, cap_pad], dim=1)
+
+         cap_pos_ids = self._create_coordinate_grid(size=(cap_total_len, 1, 1), start=(1, 0, 0), device=device).flatten(0, 2).unsqueeze(0).repeat(B, 1, 1)
+
+         cap_mask = torch.zeros((B, cap_total_len), dtype=torch.bool, device=device)
+         if cap_padding_len > 0:
+             cap_mask[:, cap_ori_len:] = True
+
+         F_tokens, H_tokens, W_tokens = Fr // pF, H // pH, W // pW
+         images_reshaped = (
+             images_batch.view(B, C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
+             .permute(0, 2, 4, 6, 3, 5, 7, 1)
+             .reshape(B, F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
+         )
+
+         image_ori_len = images_reshaped.shape[1]
+         image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
+         image_total_len = image_ori_len + image_padding_len
+
+         if image_padding_len > 0:
+             img_pad = images_reshaped[:, -1:, :].repeat(1, image_padding_len, 1)
+             images_reshaped = torch.cat([images_reshaped, img_pad], dim=1)
+
+         image_pos_ids = (
+             self._create_coordinate_grid(size=(F_tokens, H_tokens, W_tokens), start=(cap_total_len + 1, 0, 0), device=device)
+             .flatten(0, 2)
+             .unsqueeze(0)
+             .repeat(B, 1, 1)
+         )
+
+         if image_padding_len > 0:
+             img_pos_pad = torch.zeros((B, image_padding_len, 3), dtype=torch.int32, device=device)
+             image_pos_ids = torch.cat([image_pos_ids, img_pos_pad], dim=1)
+
+         image_mask = torch.zeros((B, image_total_len), dtype=torch.bool, device=device)
+         if image_padding_len > 0:
+             image_mask[:, image_ori_len:] = True
+
+         all_image_size = [(Fr, H, W)] * B
+
+         return (
+             list(torch.unbind(images_reshaped, dim=0)),
+             list(torch.unbind(caps_batch, dim=0)),
+             all_image_size,
+             list(torch.unbind(image_pos_ids, dim=0)),
+             list(torch.unbind(cap_pos_ids, dim=0)),
+             list(torch.unbind(image_mask, dim=0)),
+             list(torch.unbind(cap_mask, dim=0)),
+         )
+
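+     # The vectorized path unbinds its batched tensors back into per-item lists so its
+     # return signature matches _patchify_and_embed, letting forward() below stay
+     # agnostic about which path produced the patches.
+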
+     def forward(
+         self,
+         x: List[torch.Tensor],
+         t,
+         cap_feats: List[torch.Tensor],
+         patch_size=2,
+         f_patch_size=1,
+         control_context=None,
+         conditioning_scale=1.0,
+         refiner_conditioning_scale=1.0,
+     ):
+         """
+         The main forward pass of the transformer model.
+
+         Args:
+             x (List[torch.Tensor]):
+                 A list of latent image tensors.
+             t (torch.Tensor):
+                 A batch of timesteps.
+             cap_feats (List[torch.Tensor]):
+                 A list of caption feature tensors.
+             patch_size (int, optional):
+                 The spatial patch size to use. Defaults to 2.
+             f_patch_size (int, optional):
+                 The frame/temporal patch size to use. Defaults to 1.
+             control_context (torch.Tensor, optional):
+                 The control context tensor. Defaults to None.
+             conditioning_scale (float, optional):
+                 The scale for applying control hints. Defaults to 1.0.
+             refiner_conditioning_scale (float, optional):
+                 The scale for applying refiner control hints. Defaults to 1.0.
+
+         Returns:
+             Transformer2DModelOutput: An object containing the final denoised sample.
+         """
+
+         is_control_mode = self.use_controlnet and control_context is not None and conditioning_scale > 0
+         if refiner_conditioning_scale is None:
+             refiner_conditioning_scale = conditioning_scale or 1.0
+
+         assert patch_size in self.all_patch_size
+         assert f_patch_size in self.all_f_patch_size
+
+         bsz = len(x)
+         device = x[0].device
+
+         t = t * self.t_scale
+         t = self.t_embedder(t)
+
+         can_optimize_patchify = (
+             bsz == len(cap_feats) and bsz >= 2 and all(img.shape == x[0].shape for img in x) and all(cap.shape == cap_feats[0].shape for cap in cap_feats)
+         )
+
+         if can_optimize_patchify:
+             (x_list, cap_feats_list, x_size, x_pos_ids, cap_pos_ids, x_inner_pad_mask, cap_inner_pad_mask) = self._patchify_and_embed_batch_optimized(
+                 x, cap_feats, patch_size, f_patch_size
+             )
+         else:
+             (x_list, cap_feats_list, x_size, x_pos_ids, cap_pos_ids, x_inner_pad_mask, cap_inner_pad_mask) = self._patchify_and_embed(
+                 x, cap_feats, patch_size, f_patch_size
+             )
+
+         x_item_seqlens = [len(i) for i in x_list]
+         x_max_item_seqlen = max(x_item_seqlens) if x_item_seqlens else 0
+         x_cat = torch.cat(x_list, dim=0) if x_list else torch.empty(0, 0, device=device)
+         x_embedded = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_cat)
+         if x_inner_pad_mask and torch.cat(x_inner_pad_mask).any():
+             x_embedded[torch.cat(x_inner_pad_mask)] = self.x_pad_token
+         x = pad_sequence(list(x_embedded.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
+         adaln_input = t.to(device).type_as(x)
+
+         cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens = self._process_cap_feats_with_cfg_cache(
+             cap_feats_list, cap_pos_ids, cap_inner_pad_mask
+         )
+
+         x_freqs_cis_cat = self.rope_embedder(torch.cat(x_pos_ids, dim=0)) if x_pos_ids else torch.empty(0, device=device)
+         x_freqs_cis = pad_sequence(list(x_freqs_cis_cat.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
+
+         seq_lens_tensor = torch.tensor(x_item_seqlens, device=device, dtype=torch.int32)
+         arange = torch.arange(x_max_item_seqlen, device=device, dtype=torch.int32)
+         x_attn_mask = arange[None, :] < seq_lens_tensor[:, None]
+
+         refiner_hints = None
+         if is_control_mode and self.is_two_stage_control:
+             prepared_control = self._prepare_control_inputs(control_context, cap_feats_padded, t, patch_size, f_patch_size, device)
+             c = prepared_control["c"]
+             kwargs_for_control_refiner = {
+                 "x": x,
+                 "attn_mask": x_attn_mask,  # was prepared_control["attn_mask"]
+                 "freqs_cis": x_freqs_cis,  # was prepared_control["freqs_cis"]
+                 "adaln_input": adaln_input,
+             }
+             c_processed = self._apply_transformer_blocks(
+                 c,
+                 self.control_noise_refiner if self.add_control_noise_refiner else self.control_layers,
+                 checkpoint_ratio=self.checkpoint_ratio,
+                 **kwargs_for_control_refiner,
+             )
+             refiner_hints = torch.unbind(c_processed)[:-1]
+             control_context_processed = torch.unbind(c_processed)[-1]
+             control_context_item_seqlens = prepared_control["c_item_seqlens"]
+
+         kwargs_for_refiner = {
+             "attn_mask": x_attn_mask,
+             "freqs_cis": x_freqs_cis,
+             "adaln_input": adaln_input,
+             "context_scale": refiner_conditioning_scale,
+         }
+         if refiner_hints is not None:
+             kwargs_for_refiner["hints"] = refiner_hints
+         x = self._apply_transformer_blocks(x, self.noise_refiner, checkpoint_ratio=1.0, **kwargs_for_refiner)
+
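+         # At this point x holds the refined latent tokens; in two-stage control mode,
+         # the per-layer refiner_hints were injected above and scaled by
+         # refiner_conditioning_scale via the context_scale keyword.
+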
+         kwargs_for_context = {"attn_mask": cap_attn_mask, "freqs_cis": cap_freqs_cis}
+         cap_feats = self._apply_transformer_blocks(cap_feats_padded, self.context_refiner, checkpoint_ratio=1.0, **kwargs_for_context)
+
+         unified_item_seqlens = [a + b for a, b in zip(x_item_seqlens, cap_item_seqlens)]
+         unified_max_item_seqlen = max(unified_item_seqlens) if unified_item_seqlens else 0
+         unified = torch.zeros((bsz, unified_max_item_seqlen, x.shape[-1]), dtype=x.dtype, device=device)
+         unified_freqs_cis = torch.zeros((bsz, unified_max_item_seqlen, x_freqs_cis.shape[-2], x_freqs_cis.shape[-1]), dtype=x_freqs_cis.dtype, device=device)
+
+         for i in range(bsz):
+             x_len = x_item_seqlens[i]
+             cap_len = cap_item_seqlens[i]
+             unified[i, :x_len] = x[i, :x_len]
+             unified[i, x_len : x_len + cap_len] = cap_feats[i, :cap_len]
+             unified_freqs_cis[i, :x_len] = x_freqs_cis[i, :x_len]
+             unified_freqs_cis[i, x_len : x_len + cap_len] = cap_freqs_cis[i, :cap_len]
+
+         seq_lens_tensor = torch.tensor(unified_item_seqlens, device=device, dtype=torch.int32)
+         arange = torch.arange(unified_max_item_seqlen, device=device, dtype=torch.int32)
+         unified_attn_mask = arange[None, :] < seq_lens_tensor[:, None]
+
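+         # Unified sequence layout per batch row i (illustrative): [x_item_seqlens[i]
+         # image tokens | cap_item_seqlens[i] caption tokens | zero padding up to
+         # unified_max_item_seqlen]; unified_attn_mask marks only the first
+         # x_len + cap_len positions as valid.
+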
+         hints = None
+         if is_control_mode:
+             kwargs_for_hints = {
+                 "attn_mask": unified_attn_mask,
+                 "freqs_cis": unified_freqs_cis,
+                 "adaln_input": adaln_input,
+             }
+             if self.is_two_stage_control:
+                 control_context_unified_list = [
+                     torch.cat([control_context_processed[i][: control_context_item_seqlens[i]], cap_feats[i, : cap_item_seqlens[i]]], dim=0) for i in range(bsz)
+                 ]
+                 c = pad_sequence(control_context_unified_list, batch_first=True, padding_value=0.0)
+                 new_kwargs = dict(x=unified, **kwargs_for_hints)
+                 c_processed = self._apply_transformer_blocks(c, self.control_layers, checkpoint_ratio=self.checkpoint_ratio, **new_kwargs)
+                 hints = torch.unbind(c_processed)[:-1]
+             else:
+                 prepared_control = self._prepare_control_inputs(control_context, cap_feats_padded, t, patch_size, f_patch_size, device)
+                 c = prepared_control["c"]
+                 kwargs_for_v1_refiner = {
+                     "attn_mask": prepared_control["attn_mask"],
+                     "freqs_cis": prepared_control["freqs_cis"],
+                     "adaln_input": prepared_control["adaln_input"],
+                 }
+                 c = self._apply_transformer_blocks(c, self.control_noise_refiner, checkpoint_ratio=self.checkpoint_ratio, **kwargs_for_v1_refiner)
+                 c_item_seqlens = prepared_control["c_item_seqlens"]
+                 control_context_unified_list = [torch.cat([c[i, : c_item_seqlens[i]], cap_feats[i, : cap_item_seqlens[i]]], dim=0) for i in range(bsz)]
+                 c_unified = pad_sequence(control_context_unified_list, batch_first=True, padding_value=0.0)
+                 new_kwargs = dict(x=unified, **kwargs_for_hints)
+                 c_processed = self._apply_transformer_blocks(c_unified, self.control_layers, checkpoint_ratio=self.checkpoint_ratio, **new_kwargs)
+                 hints = torch.unbind(c_processed)[:-1]
+
+         kwargs_for_layers = {"attn_mask": unified_attn_mask, "freqs_cis": unified_freqs_cis, "adaln_input": adaln_input}
+         if hints is not None:
+             kwargs_for_layers["hints"] = hints
+             kwargs_for_layers["context_scale"] = conditioning_scale
+         unified = self._apply_transformer_blocks(unified, self.layers, checkpoint_ratio=self.checkpoint_ratio, **kwargs_for_layers)
+
+         unified_out = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
+         x_image_tokens = unified_out[:, :x_max_item_seqlen]
+         x_final_tensor = self._unpatchify(x_image_tokens, x_size, patch_size, f_patch_size)
+
+         return Transformer2DModelOutput(sample=x_final_tensor)