Instructions to use diffusers/matrix-game-2-modular with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use diffusers/matrix-game-2-modular with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("diffusers/matrix-game-2-modular", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple

import torch

from diffusers.configuration_utils import FrozenDict
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.models import AutoModel, WanTransformer3DModel
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.modular_pipelines import (
    BlockState,
    LoopSequentialPipelineBlocks,
    ModularPipelineBlocks,
    PipelineState,
    ModularPipeline
)
from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam

# Module-level logger created through diffusers' logging helper (`diffusers.utils.logging`).
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
class MatrixGameWanLoopDenoiser(ModularPipelineBlocks):
    """Denoising-loop sub-block that predicts the model output for one frame chunk.

    For each guidance batch produced by ``components.guider`` the transformer is
    run on the current chunk (``block_state.latents``) together with the image
    conditioning, the mouse/keyboard action conditioning and the KV caches kept
    on ``block_state``. The guider then combines the per-batch predictions into
    ``block_state.noise_pred``.

    Besides the declared inputs, ``__call__`` reads ``num_frames_per_block``,
    ``current_frame_idx``, ``kv_cache``, ``kv_cache_mouse``, ``kv_cache_keyboard``
    and ``kv_cache_cross_attn`` from ``block_state``. The enclosing loop wrapper
    (e.g. ``MatrixGameWanDenoiseLoopWrapper``) is expected to set them.
    """

    model_name = "MatrixGameWan"
    # Sequence positions per latent frame. Multiplying a frame index by this
    # value gives the ``current_start`` offset passed to the transformer.
    frame_seq_length = 880

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components used by this block: a CFG guider and the transformer."""
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("transformer", AutoModel),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoise the latents with guidance. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `MatrixGameWanDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        """Inputs read from the block state by ``__call__``."""
        return [
            InputParam("attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "image_mask_latents",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "image_embeds",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "keyboard_conditions",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "mouse_conditions",
                required=True,
                type_hint=torch.Tensor,
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                default=4,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                kwargs_type="guider_input_fields",
                description=(
                    "All conditional model inputs that need to be prepared with guider. "
                    "It should contain prompt_embeds/negative_prompt_embeds. "
                    "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
                ),
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> Tuple[ModularPipeline, BlockState]:
        """Run one denoising step and store the guided prediction on ``block_state.noise_pred``.

        Args:
            components: Pipeline components; ``guider`` and ``transformer`` are used.
            block_state: Loop state for the current frame chunk.
            i: Index of the current step within the timestep schedule.
            t: Current timestep (a scalar tensor, expanded to ``(batch, num_frames_per_block)``).

        Returns:
            The ``(components, block_state)`` pair, as the loop wrapper expects.
        """
        guider = components.guider
        transformer = components.transformer
        transformer_dtype = transformer.dtype

        guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

        latents = block_state.latents
        # The transformer inputs do not depend on the guidance batch, so build them once.
        transformer_kwargs = dict(
            x=latents.to(transformer_dtype),
            t=t.expand(latents.shape[0], block_state.num_frames_per_block),
            visual_context=block_state.image_embeds.to(transformer_dtype),
            cond_concat=block_state.image_mask_latents.to(transformer_dtype),
            keyboard_cond=block_state.keyboard_conditions,
            mouse_cond=block_state.mouse_conditions,
            kv_cache=block_state.kv_cache,
            kv_cache_mouse=block_state.kv_cache_mouse,
            kv_cache_keyboard=block_state.kv_cache_keyboard,
            crossattn_cache=block_state.kv_cache_cross_attn,
            current_start=block_state.current_frame_idx * self.frame_seq_length,
            num_frames_per_block=block_state.num_frames_per_block,
        )

        # NOTE(review): no guider input fields are passed, so every guidance batch
        # runs the transformer on identical inputs. Classifier-free guidance then
        # presumably reduces to the plain prediction at extra cost. Confirm whether
        # the guider should be disabled for this model.
        guider_state = guider.prepare_inputs(block_state, {})
        for guider_state_batch in guider_state:
            guider.prepare_models(transformer)
            # Store the prediction on the batch so the guider can combine batches.
            guider_state_batch.noise_pred = transformer(**transformer_kwargs)[0]
            guider.cleanup_models(transformer)

        block_state.noise_pred = guider(guider_state)[0]
        return components, block_state
class MatrixGameWanLoopAfterDenoiser(ModularPipelineBlocks):
    """Denoising-loop sub-block that turns ``block_state.noise_pred`` into updated latents.

    The update is ``latents - sigma_t * noise_pred``. ``sigma_t`` is looked up in
    ``components.scheduler.sigmas`` at the index of the current timestep. The
    arithmetic is done in float64 and the result is cast back to the incoming
    latents' dtype.
    """

    model_name = "MatrixGameWan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components used by this block: the scheduler (for its sigma table)."""
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "step within the denoising loop that update the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `MatrixGameWanDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        # `generator` is declared here, but this block's own update does not use it.
        return [
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> Tuple[ModularPipeline, BlockState]:
        """Apply ``latents - sigma_t * noise_pred`` for timestep ``t``.

        Args:
            components: Pipeline components; only ``scheduler`` is used.
            block_state: Loop state holding ``latents`` and ``noise_pred``.
            i: Step index (unused; part of the loop sub-block call signature).
            t: Current timestep, used to find the scheduler's step index.

        Returns:
            The ``(components, block_state)`` pair with ``block_state.latents`` updated.
        """
        latents_dtype = block_state.latents.dtype
        scheduler = components.scheduler
        sigma_t = scheduler.sigmas[scheduler.index_for_timestep(t)]
        # Compute in float64, then cast back. `.to` is a no-op when the dtype already matches.
        updated = block_state.latents.double() - sigma_t.double() * block_state.noise_pred.double()
        block_state.latents = updated.to(latents_dtype)
        return components, block_state
class MatrixGameWanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    """Block-wise autoregressive denoising loop for the Matrix-Game Wan transformer.

    The latent video is processed in chunks of ``num_frames_per_block`` frames
    along the frame axis (dim 2). For every chunk the loop:

    1. slices the per-chunk conditioning: the image mask latents for the chunk,
       and the mouse/keyboard action sequences up to the chunk's last frame;
    2. runs the ``sub_blocks`` once per entry of ``timesteps``. Between steps it
       re-noises the latents to the next timestep with ``scheduler.add_noise``;
    3. writes the chunk's result into the output tensor;
    4. calls the transformer once more on the chunk with a zero timestep. The
       output of that call is discarded; presumably it writes the clean chunk
       into the KV caches used by later chunks.
    """

    model_name = "MatrixGameWan"
    # Sequence positions per latent frame. Used to size the self-attention and
    # mouse caches and to compute `current_start` offsets.
    frame_seq_length = 880
    # Number of frames the rolling caches hold. -1 falls back to `default_kv_cache_frames`.
    local_attn_size = 6
    default_kv_cache_frames = 15
    # One cache entry is allocated per transformer block.
    num_transformer_blocks = 30
    # Cache geometry (previously hard-coded in the initializers). Presumably it
    # must match the transformer configuration.
    num_attention_heads = 12
    attention_head_dim = 128
    action_num_attention_heads = 16
    action_attention_head_dim = 64
    crossattn_seq_length = 257

    @staticmethod
    def _new_self_attn_cache_entry(shape, dtype, device):
        """Return a zero-filled cache dict with ``k``/``v`` of ``shape`` and zeroed end indices."""
        return {
            "k": torch.zeros(shape, dtype=dtype, device=device),
            "v": torch.zeros(shape, dtype=dtype, device=device),
            "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
            "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
        }

    def _kv_cache_frames(self) -> int:
        """Return how many frames the rolling caches hold (``local_attn_size`` unless it is -1)."""
        if self.local_attn_size != -1:
            return self.local_attn_size
        return self.default_kv_cache_frames

    def _initialize_kv_cache(self, batch_size, dtype, device):
        """Create one zeroed self-attention KV cache entry per transformer block.

        Each entry's ``k``/``v`` has shape
        ``(batch_size, frames * frame_seq_length, num_attention_heads, attention_head_dim)``,
        where ``frames`` comes from ``_kv_cache_frames()``.
        """
        kv_cache_size = self._kv_cache_frames() * self.frame_seq_length
        shape = (batch_size, kv_cache_size, self.num_attention_heads, self.attention_head_dim)
        return [
            self._new_self_attn_cache_entry(shape, dtype, device)
            for _ in range(self.num_transformer_blocks)
        ]

    def _initialize_kv_cache_mouse_and_keyboard(self, batch_size, dtype, device):
        """Create zeroed KV caches for the mouse- and keyboard-conditioning branches.

        Returns:
            ``(kv_cache_mouse, kv_cache_keyboard)``, each a list with one entry per
            transformer block. Keyboard entries have leading dimension
            ``batch_size``; mouse entries have ``batch_size * frame_seq_length``.
        """
        kv_cache_size = self._kv_cache_frames()
        keyboard_shape = (
            batch_size,
            kv_cache_size,
            self.action_num_attention_heads,
            self.action_attention_head_dim,
        )
        mouse_shape = (
            batch_size * self.frame_seq_length,
            kv_cache_size,
            self.action_num_attention_heads,
            self.action_attention_head_dim,
        )
        kv_cache_mouse = [
            self._new_self_attn_cache_entry(mouse_shape, dtype, device)
            for _ in range(self.num_transformer_blocks)
        ]
        kv_cache_keyboard = [
            self._new_self_attn_cache_entry(keyboard_shape, dtype, device)
            for _ in range(self.num_transformer_blocks)
        ]
        return kv_cache_mouse, kv_cache_keyboard

    def _initialize_crossattn_cache(self, batch_size, dtype, device):
        """Create one zeroed, not-yet-initialized cross-attention cache entry per transformer block."""
        shape = (batch_size, self.crossattn_seq_length, self.num_attention_heads, self.attention_head_dim)
        return [
            {
                "k": torch.zeros(shape, dtype=dtype, device=device),
                "v": torch.zeros(shape, dtype=dtype, device=device),
                "is_init": False,
            }
            for _ in range(self.num_transformer_blocks)
        ]

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iteratively denoise the latents over `timesteps`. "
            "The specific steps with each iteration can be customized with `sub_blocks` attributes"
        )

    @property
    def loop_expected_components(self) -> List[ComponentSpec]:
        """Components used by the loop itself (scheduler and transformer) and by its sub-blocks."""
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("scheduler", UniPCMultistepScheduler),
            ComponentSpec("transformer", AutoModel),
        ]

    @property
    def loop_inputs(self) -> List[InputParam]:
        """Inputs read by ``__call__``. ``generator`` is optional and seeds the re-noising."""
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "num_frames_per_block",
                required=True,
                type_hint=int,
                default=3,
            ),
            InputParam("generator"),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Denoise all latent frames chunk by chunk and store the result as ``latents``.

        Raises:
            ValueError: if ``num_frames_per_block`` is not a positive divisor of
                the number of latent frames. Previously, leftover frames were
                silently returned as zeros.
        """
        block_state = self.get_block_state(state)

        # The per-chunk slices below overwrite these entries on `block_state`.
        # They are restored before `set_block_state`, so the last chunk's slices
        # are not pushed back into the pipeline state. (Assumes set_block_state
        # writes back inputs whose object identity changed; verify against the
        # installed diffusers version.)
        full_inputs = {
            name: getattr(block_state, name)
            for name in ("image_mask_latents", "mouse_conditions", "keyboard_conditions")
        }

        transformer_dtype = components.transformer.dtype
        num_frames_per_block = block_state.num_frames_per_block
        timesteps = block_state.timesteps
        # Optional. `None` keeps the previous unseeded re-noising behavior.
        generator = getattr(block_state, "generator", None)

        latents = block_state.latents.to(transformer_dtype)
        image_mask_latents = full_inputs["image_mask_latents"].to(transformer_dtype)
        # A leading batch dimension is added here (assumes the action tensors arrive unbatched; TODO confirm).
        mouse_conditions = full_inputs["mouse_conditions"].unsqueeze(0).to(transformer_dtype)
        keyboard_conditions = full_inputs["keyboard_conditions"].unsqueeze(0).to(transformer_dtype)
        # Cast like MatrixGameWanLoopDenoiser does, so the cache-update call gets the same dtype.
        visual_context = block_state.image_embeds.to(transformer_dtype)

        batch_size, _, num_frames, _, _ = latents.shape
        if num_frames_per_block <= 0 or num_frames % num_frames_per_block != 0:
            raise ValueError(
                f"`num_frames_per_block` ({num_frames_per_block}) must be a positive divisor of "
                f"the number of latent frames ({num_frames})."
            )

        output = torch.zeros_like(latents)

        block_state.kv_cache = self._initialize_kv_cache(batch_size, latents.dtype, latents.device)
        block_state.kv_cache_mouse, block_state.kv_cache_keyboard = self._initialize_kv_cache_mouse_and_keyboard(
            batch_size, latents.dtype, latents.device
        )
        block_state.kv_cache_cross_attn = self._initialize_crossattn_cache(batch_size, latents.dtype, latents.device)

        for current_frame_idx in range(0, num_frames, num_frames_per_block):
            chunk = slice(current_frame_idx, current_frame_idx + num_frames_per_block)
            block_state.current_frame_idx = current_frame_idx
            block_state.image_mask_latents = image_mask_latents[:, :, chunk]
            # Actions up to the chunk's last latent frame. Assumes one initial action
            # step plus 4 steps per latent frame; TODO confirm.
            cond_idx = 1 + 4 * (current_frame_idx + num_frames_per_block - 1)
            block_state.mouse_conditions = mouse_conditions[:, :cond_idx]
            block_state.keyboard_conditions = keyboard_conditions[:, :cond_idx]
            block_state.latents = latents[:, :, chunk]

            for i, t in enumerate(timesteps):
                components, block_state = self.loop_step(components, block_state, i=i, t=t)
                if i < len(timesteps) - 1:
                    # Re-noise the prediction to the next timestep of the same schedule.
                    noise = randn_tensor(
                        block_state.latents.shape,
                        generator=generator,
                        device=block_state.latents.device,
                        dtype=block_state.latents.dtype,
                    )
                    next_t = timesteps[i + 1].expand(block_state.latents.shape[0])
                    block_state.latents = components.scheduler.add_noise(block_state.latents, noise, next_t)

            output[:, :, chunk] = block_state.latents

            # Zero-timestep pass over the finished chunk. The return value is
            # discarded; presumably the call updates the KV caches passed in.
            components.transformer(
                x=block_state.latents,
                t=t.expand(block_state.latents.shape[0], num_frames_per_block) * 0.0,
                visual_context=visual_context,
                cond_concat=block_state.image_mask_latents,
                keyboard_cond=block_state.keyboard_conditions,
                mouse_cond=block_state.mouse_conditions,
                kv_cache=block_state.kv_cache,
                kv_cache_mouse=block_state.kv_cache_mouse,
                kv_cache_keyboard=block_state.kv_cache_keyboard,
                crossattn_cache=block_state.kv_cache_cross_attn,
                current_start=current_frame_idx * self.frame_seq_length,
                num_frames_per_block=num_frames_per_block,
            )

        for name, value in full_inputs.items():
            setattr(block_state, name, value)
        block_state.latents = output
        self.set_block_state(state, block_state)
        return components, state
class MatrixGameWanDenoiseStep(MatrixGameWanDenoiseLoopWrapper):
    """Concrete denoise step: the loop wrapper with denoiser and after-denoiser sub-blocks.

    Each loop iteration runs ``MatrixGameWanLoopDenoiser`` and then
    ``MatrixGameWanLoopAfterDenoiser``, registered under the names
    ``"denoiser"`` and ``"after_denoiser"``.
    """

    block_classes = [
        MatrixGameWanLoopDenoiser,
        MatrixGameWanLoopAfterDenoiser,
    ]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoise the latents. \n"
            "Its loop logic is defined in `MatrixGameWanDenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `MatrixGameWanLoopDenoiser`\n"
            " - `MatrixGameWanLoopAfterDenoiser`\n"
            "This block supports image-to-video generation conditioned on keyboard and mouse actions."
        )