| import importlib
|
| import os
|
| from typing import Optional, Set
|
|
|
| import diffusers.loaders.single_file_model as single_file_model
|
| import diffusers.pipelines.pipeline_loading_utils as pipe_loading_utils
|
| import torch
|
| from diffusers.loaders.single_file_utils import (
|
| convert_animatediff_checkpoint_to_diffusers,
|
| convert_auraflow_transformer_checkpoint_to_diffusers,
|
| convert_autoencoder_dc_checkpoint_to_diffusers,
|
| convert_chroma_transformer_checkpoint_to_diffusers,
|
| convert_controlnet_checkpoint,
|
| convert_cosmos_transformer_checkpoint_to_diffusers,
|
| convert_flux2_transformer_checkpoint_to_diffusers,
|
| convert_flux_transformer_checkpoint_to_diffusers,
|
| convert_hidream_transformer_to_diffusers,
|
| convert_hunyuan_video_transformer_to_diffusers,
|
| convert_ldm_unet_checkpoint,
|
| convert_ldm_vae_checkpoint,
|
| convert_ltx_transformer_checkpoint_to_diffusers,
|
| convert_ltx_vae_checkpoint_to_diffusers,
|
| convert_lumina2_to_diffusers,
|
| convert_mochi_transformer_checkpoint_to_diffusers,
|
| convert_sana_transformer_to_diffusers,
|
| convert_sd3_transformer_checkpoint_to_diffusers,
|
| convert_stable_cascade_unet_single_file_to_diffusers,
|
| convert_wan_transformer_to_diffusers,
|
| convert_wan_vae_to_diffusers,
|
| convert_z_image_transformer_checkpoint_to_diffusers,
|
| create_controlnet_diffusers_config_from_ldm,
|
| create_unet_diffusers_config_from_ldm,
|
| create_vae_diffusers_config_from_ldm,
|
| )
|
| from diffusers.pipelines.pipeline_loading_utils import _unwrap_model
|
| from diffusers.utils import (
|
| _maybe_remap_transformers_class,
|
| get_class_from_dynamic_module,
|
| )
|
|
|
|
|
# The group-offloading internals are private to diffusers and may be missing from the installed
# release. Import them when possible; otherwise bind every one of the imported names to a
# placeholder so the definitions below can still be created and any later use fails with a clear
# message instead of a bare NameError.
try:
    from diffusers.hooks.group_offloading import (
        _GROUP_ID_LAZY_LEAF,
        GroupOffloadingConfig,
        ModelHook,
        ModuleGroup,
        _apply_group_offloading_hook,
        _apply_lazy_group_offloading_hook,
        _find_parent_module_in_module_dict,
        _gather_buffers_with_no_group_offloading_parent,
        _gather_parameters_with_no_group_offloading_parent,
        send_to_device,
    )
except ImportError as import_error:
    # `import_error` is unbound when the except block ends, so keep its text for the closures below.
    _group_offloading_import_reason = str(import_error)

    # Base-class / annotation placeholders.
    ModelHook = object
    ModuleGroup = object
    GroupOffloadingConfig = object
    _GROUP_ID_LAZY_LEAF = None

    def _apply_group_offloading_hook(*args, **kwargs):
        """No-op stand-in used when diffusers' group offloading cannot be imported."""
        pass

    def _missing_group_offloading_symbol(symbol):
        """Return a callable that raises ImportError naming the diffusers helper that is missing."""

        def _raise_missing(*args, **kwargs):
            raise ImportError(
                f"`{symbol}` could not be imported from `diffusers.hooks.group_offloading` "
                f"({_group_offloading_import_reason})"
            )

        _raise_missing.__name__ = symbol
        return _raise_missing

    _apply_lazy_group_offloading_hook = _missing_group_offloading_symbol("_apply_lazy_group_offloading_hook")
    _find_parent_module_in_module_dict = _missing_group_offloading_symbol("_find_parent_module_in_module_dict")
    _gather_buffers_with_no_group_offloading_parent = _missing_group_offloading_symbol(
        "_gather_buffers_with_no_group_offloading_parent"
    )
    _gather_parameters_with_no_group_offloading_parent = _missing_group_offloading_symbol(
        "_gather_parameters_with_no_group_offloading_parent"
    )
    send_to_device = _missing_group_offloading_symbol("send_to_device")
|
|
|
|
|
# Layer types that `_apply_group_offloading_leaf_level_patched` matches with `isinstance` when it
# decides which submodules receive their own offloading group.
_MY_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
    # Convolutions.
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    # Transposed convolutions.
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
    # Fully connected layer.
    torch.nn.Linear,
    # Container type, handled as a unit by the patched routine.
    torch.nn.Sequential,
)
|
|
|
|
|
class GroupOffloadingHook(ModelHook):
    r"""
    Model hook that drives a single `ModuleGroup`: it calls `group.onload_()` before a forward pass and
    `group.offload_()` after it.

    Only the group's leader modules trigger transfers: `pre_forward` acts when the hooked module is the group's
    `onload_leader`, while `initialize_hook` and `post_forward` act when it is the `offload_leader`. When a
    `next_group` is attached and that group does not onload itself, `pre_forward` onloads it as well.
    """

    _is_stateful = False

    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
        self.group = group
        # Starts unset; nothing in this class assigns it, so it is presumably wired up externally.
        self.next_group: Optional[ModuleGroup] = None
        self.config = config

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        """Offload the group right away when `module` is its offload leader; return `module` unchanged."""
        is_offload_leader = self.group.offload_leader == module
        if is_offload_leader:
            self.group.offload_()
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        """
        Onload the group (and possibly the next one) and move the call arguments to the onload device.

        Returns:
            The `(args, kwargs)` pair to use for the forward call.
        """
        group = self.group

        # The first module whose forward runs becomes the onload leader when none was assigned.
        if group.onload_leader is None:
            group.onload_leader = module

        # NOTE(review): the source this was rebuilt from carried no indentation; the steps below are
        # kept as sibling statements inside the leader check, in their original order - confirm.
        if group.onload_leader == module:
            relies_on_stream = not group.onload_self and group.stream is not None
            if relies_on_stream:
                # The group does not onload itself but a stream exists: onload here and wait for it.
                group.onload_()
                group.stream.synchronize()

            if group.onload_self:
                group.onload_()

            if group.stream is not None:
                group.stream.synchronize()

            upcoming = self.next_group
            if upcoming is not None and not upcoming.onload_self:
                upcoming.onload_()

        args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)

        excluded_names = self.config.exclude_kwargs or []
        if not excluded_names:
            kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
        else:
            # Excluded entries keep their original objects; only the remaining ones are replaced.
            transferable = {k: v for k, v in kwargs.items() if k not in excluded_names}
            moved_kwargs = send_to_device(
                transferable,
                self.group.onload_device,
                non_blocking=self.group.non_blocking,
            )
            kwargs.update(moved_kwargs)

        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
        """Offload the group once its offload leader has finished; pass `output` through untouched."""
        is_offload_leader = self.group.offload_leader == module
        if is_offload_leader:
            self.group.offload_()
        return output
|
|
|
|
|
def _apply_group_offloading_leaf_level_patched(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
    """
    Patched version of `_apply_group_offloading_leaf_level` that also supports `nn.Sequential`.

    The routine works in three passes over `module`:

    1. every submodule matching `_MY_GO_LC_SUPPORTED_PYTORCH_LAYERS` gets its own `ModuleGroup` and hook;
    2. parameters and buffers reported by the `_gather_*_with_no_group_offloading_parent` helpers are bucketed
       by the parent module name returned by `_find_parent_module_in_module_dict`, and each parent gets a
       group that holds those tensors;
    3. when `config.stream` is set, a lazy hook is additionally attached to `module` itself.

    Args:
        module: Root module whose submodules are hooked.
        config: Offloading settings; its attributes are forwarded to every `ModuleGroup`.
    """

    def _build_group(leader, group_id, **specifics):
        # All groups created here use one module as both leaders and onload themselves.
        return ModuleGroup(
            offload_device=config.offload_device,
            onload_device=config.onload_device,
            offload_to_disk_path=config.offload_to_disk_path,
            offload_leader=leader,
            onload_leader=leader,
            low_cpu_mem_usage=config.low_cpu_mem_usage,
            onload_self=True,
            group_id=group_id,
            **specifics,
        )

    # Pass 1: one group per supported layer.
    modules_with_group_offloading: Set[str] = set()
    for name, submodule in module.named_modules():
        if isinstance(submodule, _MY_GO_LC_SUPPORTED_PYTORCH_LAYERS):
            layer_group = _build_group(
                submodule,
                name,
                modules=[submodule],
                non_blocking=config.non_blocking,
                stream=config.stream,
                record_stream=config.record_stream,
            )
            _apply_group_offloading_hook(submodule, layer_group, config=config)
            modules_with_group_offloading.add(name)

    # Pass 2: tensors that pass 1 did not cover, grouped under their parent module.
    module_dict = dict(module.named_modules())
    loose_parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
    loose_buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)

    parent_to_parameters = {}
    for tensor_name, param in loose_parameters:
        owner = _find_parent_module_in_module_dict(tensor_name, module_dict)
        parent_to_parameters.setdefault(owner, []).append(param)

    parent_to_buffers = {}
    for tensor_name, buffer in loose_buffers:
        owner = _find_parent_module_in_module_dict(tensor_name, module_dict)
        parent_to_buffers.setdefault(owner, []).append(buffer)

    parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys())
    for parent_name in parent_names:
        parent_module = module_dict[parent_name]
        tensor_group = _build_group(
            parent_module,
            parent_name,
            modules=[],
            parameters=parent_to_parameters.get(parent_name, []),
            buffers=parent_to_buffers.get(parent_name, []),
            non_blocking=config.non_blocking,
            stream=config.stream,
            record_stream=config.record_stream,
        )
        _apply_group_offloading_hook(parent_module, tensor_group, config=config)

    # Pass 3: with a stream configured, attach a lazy hook on the root module. Its group is empty
    # and is created with blocking transfers and no stream of its own.
    if config.stream is None:
        return

    unmatched_group = _build_group(
        module,
        _GROUP_ID_LAZY_LEAF,
        modules=[],
        parameters=None,
        buffers=None,
        non_blocking=False,
        stream=None,
        record_stream=False,
    )
    _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
|
|
|
|
|
# Install the patched implementations on diffusers' group offloading module. When that module
# cannot be imported, only a message is printed and nothing is patched.
try:
    import diffusers.hooks.group_offloading as group_offloading_module
except ImportError as import_error:
    print(f"-> ERRO: Não foi possível importar o módulo `diffusers.hooks.group_offloading` para aplicar o patch: {import_error}")
else:
    group_offloading_module._apply_group_offloading_leaf_level = _apply_group_offloading_leaf_level_patched
    group_offloading_module.GroupOffloadingHook = GroupOffloadingHook
|
|
|
|
|
def convert_z_image_control_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    """
    Rename the keys of a Z-Image control transformer checkpoint and split its fused QKV weights.

    Every entry is popped out of `checkpoint` (the caller's mapping ends up empty) into a new dict.
    Each key then has the substring replacements below applied in order, and every key containing
    `.attention.qkv.weight` is replaced by three `to_q`/`to_k`/`to_v` entries obtained by chunking
    the tensor into three parts along dim 0.

    Args:
        checkpoint: Mutable mapping of parameter name to value; drained by this call.
        **kwargs: Accepted and ignored.

    Returns:
        A new dict holding the renamed entries.
    """
    # (substring, replacement) pairs, applied in this order to every key. Because the
    # "x_embedder." rule runs first and also matches keys containing "control_x_embedder.",
    # such keys already read "control_all_x_embedder.2-1." when the last rule is reached.
    substring_renames = (
        ("final_layer.", "all_final_layer.2-1."),
        ("x_embedder.", "all_x_embedder.2-1."),
        (".attention.out.bias", ".attention.to_out.0.bias"),
        (".attention.k_norm.weight", ".attention.norm_k.weight"),
        (".attention.q_norm.weight", ".attention.norm_q.weight"),
        (".attention.out.weight", ".attention.to_out.0.weight"),
        ("control_x_embedder.", "control_all_x_embedder.2-1."),
    )
    fused_qkv_marker = ".attention.qkv.weight"

    # Move everything out of the caller's mapping first.
    converted_state_dict = {}
    for source_name in list(checkpoint.keys()):
        converted_state_dict[source_name] = checkpoint.pop(source_name)

    # Rename in place, one key at a time, so colliding names behave as overwrites.
    for current_name in list(converted_state_dict.keys()):
        target_name = current_name
        for needle, replacement in substring_renames:
            target_name = target_name.replace(needle, replacement)
        converted_state_dict[target_name] = converted_state_dict.pop(current_name)

    # Split fused attention projections into separate q/k/v weights.
    for current_name in list(converted_state_dict.keys()):
        if fused_qkv_marker not in current_name:
            continue
        fused_weight = converted_state_dict.pop(current_name)
        query_w, key_w, value_w = torch.chunk(fused_weight, 3, dim=0)
        split_targets = (
            (".attention.to_q.weight", query_w),
            (".attention.to_k.weight", key_w),
            (".attention.to_v.weight", value_w),
        )
        for projection_suffix, projection_weight in split_targets:
            converted_state_dict[current_name.replace(fused_qkv_marker, projection_suffix)] = projection_weight

    return converted_state_dict
|
|
|
|
|
def _identity_checkpoint_mapping(checkpoint, **kwargs):
    """Return `checkpoint` unchanged.

    Accepts the same `(checkpoint, **kwargs)` call shape as the conversion functions registered
    below, so it also works when invoked with keyword arguments (a bare `lambda x: x` would raise
    `TypeError` on `checkpoint=...`).
    """
    return checkpoint


# Registry installed on `diffusers.loaders.single_file_model` at the bottom of this file.
# Keys are model class names; each value may provide:
#   - "checkpoint_mapping_fn": callable that converts a checkpoint's state dict,
#   - "config_mapping_fn":     callable that builds the model config (only some entries),
#   - "default_subfolder":     subfolder name associated with the model (only some entries),
#   - "legacy_kwargs":         old keyword name -> current keyword name (only some entries).
# Entry order matters: `_get_single_file_loadable_mapping_class` walks this dict in order.
SINGLE_FILE_LOADABLE_CLASSES = {
    "StableCascadeUNet": {
        "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
    },
    "UNet2DConditionModel": {
        "checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
        "config_mapping_fn": create_unet_diffusers_config_from_ldm,
        "default_subfolder": "unet",
        "legacy_kwargs": {
            "num_in_channels": "in_channels",
        },
    },
    "AutoencoderKL": {
        "checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
        "config_mapping_fn": create_vae_diffusers_config_from_ldm,
        "default_subfolder": "vae",
    },
    "ControlNetModel": {
        "checkpoint_mapping_fn": convert_controlnet_checkpoint,
        "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
    },
    "SD3Transformer2DModel": {
        "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "MotionAdapter": {
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
    },
    "SparseControlNetModel": {
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
    },
    "FluxTransformer2DModel": {
        "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ChromaTransformer2DModel": {
        "checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "LTXVideoTransformer3DModel": {
        "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AutoencoderKLLTXVideo": {
        "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
        "default_subfolder": "vae",
    },
    "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
    "MochiTransformer3DModel": {
        "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "HunyuanVideoTransformer3DModel": {
        "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AuraFlowTransformer2DModel": {
        "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "Lumina2Transformer2DModel": {
        "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
        "default_subfolder": "transformer",
    },
    "SanaTransformer2DModel": {
        "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "WanTransformer3DModel": {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "WanVACETransformer3DModel": {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AutoencoderKLWan": {
        "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
        "default_subfolder": "vae",
    },
    "HiDreamImageTransformer2DModel": {
        "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "CosmosTransformer3DModel": {
        "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "QwenImageTransformer2DModel": {
        # No key conversion is applied for this model; the checkpoint is passed through as-is.
        "checkpoint_mapping_fn": _identity_checkpoint_mapping,
        "default_subfolder": "transformer",
    },
    "Flux2Transformer2DModel": {
        "checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ZImageTransformer2DModel": {
        "checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ZImageControlTransformer2DModel": {
        "checkpoint_mapping_fn": convert_z_image_control_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
}
|
|
|
|
|
def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None):
    """
    Resolve a component class and the candidate base classes it can be checked against.

    Resolution order:

    1. `is_pipeline_module` is truthy: the class is read from `getattr(pipelines, library_name)`.
    2. `<cache_dir>/<component_name>/<library_name>.py` exists: the class is loaded from that file via
       `get_class_from_dynamic_module`.
    3. otherwise `library_name` is imported and the class is read from it, falling back to the
       `diffusers_local` module when that attribute lookup fails.

    Returns:
        `(class_obj, class_candidates)`. In cases 1 and 2 every key of `importable_classes` maps to
        `class_obj`; in case 3 each key maps to the attribute of that name on the imported library, or `None`.
    """
    component_folder = None
    if component_name and cache_dir:
        component_folder = os.path.join(cache_dir, component_name)

    if is_pipeline_module:
        class_obj = getattr(getattr(pipelines, library_name), class_name)
        return class_obj, {candidate: class_obj for candidate in importable_classes.keys()}

    if component_folder:
        module_file = library_name + ".py"
        if os.path.isfile(os.path.join(component_folder, module_file)):
            class_obj = get_class_from_dynamic_module(component_folder, module_file=module_file, class_name=class_name)
            return class_obj, {candidate: class_obj for candidate in importable_classes.keys()}

    library = importlib.import_module(library_name)

    # Some transformers class names are remapped before the lookup.
    if library_name == "transformers":
        remapped = _maybe_remap_transformers_class(class_name)
        class_name = remapped or class_name

    try:
        class_obj = getattr(library, class_name)
    except Exception:
        # The library does not provide the class: look it up in the local package instead.
        class_obj = getattr(importlib.import_module("diffusers_local"), class_name)

    class_candidates = {}
    for candidate in importable_classes.keys():
        class_candidates[candidate] = getattr(library, candidate, None)

    return class_obj, class_candidates
|
|
|
|
|
def _get_single_file_loadable_mapping_class(cls):
    """
    Return the `SINGLE_FILE_LOADABLE_CLASSES` key to use for `cls`.

    A class whose own name is registered maps to that entry. Otherwise the registry is walked in order
    and the first entry whose class `cls` subclasses is returned; each registered name is resolved on
    `diffusers` first and on `diffusers_local` when that lookup fails. If nothing matches, the name of
    `cls` itself is returned.
    """
    class_name_str = cls.__name__

    # Prefer the exact entry: walking the registry in order would otherwise return a base class's
    # entry for a subclass that has its own entry listed later.
    if class_name_str in SINGLE_FILE_LOADABLE_CLASSES:
        return class_name_str

    diffusers_module = importlib.import_module("diffusers")
    for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
        try:
            loadable_class = getattr(diffusers_module, loadable_class_str)
        except Exception:
            # Not available from diffusers: resolve the class from the local package instead.
            module = importlib.import_module("diffusers_local")
            loadable_class = getattr(module, loadable_class_str)
        if issubclass(cls, loadable_class):
            return loadable_class_str

    return class_name_str
|
|
|
|
|
def maybe_raise_or_warn(library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module):
    """
    Check that the component passed under `name` has the type expected for `class_name`.

    For pipeline modules no check is possible and only a notice is printed. Otherwise `library_name` is
    imported, `class_name` is resolved on it (falling back to `diffusers_local`), and the expected base
    class is the last entry of `importable_classes` that the resolved class subclasses. If no entry
    matches, a notice is printed and the check is skipped.

    Args:
        library_name: Importable module name that should provide `class_name`.
        library: Ignored on input; rebound to the imported module when a check is performed.
        class_name: Name of the class the component is expected to be compatible with.
        importable_classes: Mapping whose keys are candidate base-class names looked up on the library.
        passed_class_obj: Mapping of component name to the object supplied by the caller.
        name: Key of the component to verify in `passed_class_obj`.
        is_pipeline_module: When truthy, the type check is skipped.

    Raises:
        ValueError: If the (unwrapped) component is not an instance of a subclass of the expected class.
    """
    if is_pipeline_module:
        print(f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it has the correct type")
        return

    library = importlib.import_module(library_name)

    # Some transformers class names are remapped before the lookup.
    if library_name == "transformers":
        class_name = _maybe_remap_transformers_class(class_name) or class_name

    try:
        class_obj = getattr(library, class_name)
    except Exception:
        # The library does not provide the class: look it up in the local package instead.
        module = importlib.import_module("diffusers_local")
        class_obj = getattr(module, class_name)

    class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

    # The last matching candidate wins, as no early exit is taken.
    expected_class_obj = None
    for class_candidate in class_candidates.values():
        if class_candidate is not None and issubclass(class_obj, class_candidate):
            expected_class_obj = class_candidate

    if expected_class_obj is None:
        # Without an expected class, issubclass(model_cls, None) below would raise an opaque
        # TypeError; report that the component cannot be verified instead.
        print(
            f"Could not determine the expected class for {class_name} in `{library_name}`. "
            f"We cannot verify whether {passed_class_obj[name]} has the correct type"
        )
        return

    # The passed object may be a wrapper; compare the class of the unwrapped model.
    sub_model = passed_class_obj[name]
    unwrapped_sub_model = _unwrap_model(sub_model)
    model_cls = unwrapped_sub_model.__class__

    if not issubclass(model_cls, expected_class_obj):
        raise ValueError(f"{passed_class_obj[name]} is of type: {model_cls}, but should be {expected_class_obj}")
|
|
|
|
|
# Replace diffusers' single-file loading registry and its class lookup with the versions above.
single_file_model.SINGLE_FILE_LOADABLE_CLASSES = SINGLE_FILE_LOADABLE_CLASSES
single_file_model._get_single_file_loadable_mapping_class = _get_single_file_loadable_mapping_class

# Replace diffusers' pipeline-loading helpers with the versions above, which fall back to the
# `diffusers_local` module for class lookups.
pipe_loading_utils.get_class_obj_and_candidates = get_class_obj_and_candidates
pipe_loading_utils.maybe_raise_or_warn = maybe_raise_or_warn
|
|
|