| import folder_paths |
| import comfy.utils |
| import comfy.model_detection |
| import comfy.model_management |
| import comfy.lora |
| from comfy.model_patcher import ModelPatcher |
|
|
| from .utils import TimestepKeyframeGroup |
| from .control import ControlNetAdvanced, load_controlnet |
|
|
|
|
|
|
|
|
def convert_cn_lora_from_diffusers(cn_model: ModelPatcher, lora_path: str):
    """Load a ControlNet LoRA file and turn it into patches for ``cn_model``.

    Steps:
      1. Read the file with ``comfy.utils.load_torch_file(..., safe_load=True)``.
      2. Cast every entry, in place, to ``comfy.model_management.unet_dtype()``.
      3. Build a key map from the ControlNet's state dict with
         ``comfy.utils.unet_to_diffusers`` and pass both to
         ``comfy.lora.load_lora``.

    Args:
        cn_model: Patcher wrapping the target ControlNet; only the state dict
            of its ``.model`` is read here.
        lora_path: Full filesystem path of the LoRA weights file.

    Returns:
        Whatever ``comfy.lora.load_lora`` returns for these weights; the loader
        node hands it to ``add_patches`` on the same patcher.
    """
    weights = comfy.utils.load_torch_file(lora_path, safe_load=True)
    target_dtype = comfy.model_management.unet_dtype()
    # Overwrite entries key by key instead of building a second dict, so only
    # one mapping of LoRA tensors is held at a time.
    for name in weights:
        weights[name] = weights[name].to(target_dtype)

    # NOTE(review): unet_to_diffusers is handed the ControlNet's state_dict
    # here; confirm this yields the key map load_lora is expected to match.
    key_map = comfy.utils.unet_to_diffusers(cn_model.model.state_dict())
    return comfy.lora.load_lora(weights, to_load=key_map)
|
|
|
|
class ControlNetLoaderWithLoraAdvanced:
    """ComfyUI node: load a ControlNet and apply a ControlNet LoRA on top of it.

    Both the base ControlNet and the LoRA file are chosen from the
    ``controlnet`` model folder. The LoRA is converted into patches and added
    to the ControlNet's wrapped model with the requested strength. An optional
    ``TIMESTEP_KEYFRAME`` input is forwarded unchanged to ``load_controlnet``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input specification in ComfyUI's dict format."""
        return {
            "required": {
                "control_net_name": (folder_paths.get_filename_list("controlnet"), ),
                # The LoRA is listed from the same "controlnet" folder as the base model.
                "cn_lora_name": (folder_paths.get_filename_list("controlnet"), ),
                "cn_lora_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            },
            "optional": {
                "timestep_keyframe": ("TIMESTEP_KEYFRAME", ),
            }
        }

    RETURN_TYPES = ("CONTROL_NET", )
    FUNCTION = "load_controlnet"

    CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/LOOSEControl"

    def load_controlnet(self, control_net_name, cn_lora_name, cn_lora_strength: float,
                        timestep_keyframe: TimestepKeyframeGroup = None
                        ):
        """Load the named ControlNet and patch it with the named LoRA.

        Args:
            control_net_name: File name of the ControlNet in the "controlnet" folder.
            cn_lora_name: File name of the LoRA in the "controlnet" folder.
            cn_lora_strength: Strength passed as ``strength_patch`` to ``add_patches``.
            timestep_keyframe: Optional keyframe group forwarded to ``load_controlnet``.

        Returns:
            A one-element tuple containing the patched ControlNet.

        Raises:
            ValueError: If the loaded ControlNet is not a ``ControlNetAdvanced``,
                the only type this node knows how to patch.
        """
        controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
        # Bare ``load_controlnet`` resolves to the module-level function imported
        # from .control, not to this method.
        controlnet = load_controlnet(controlnet_path, timestep_keyframe)
        if not isinstance(controlnet, ControlNetAdvanced):
            raise ValueError(
                f"Type {type(controlnet).__name__} is not compatible with CN LoRA features at this time."
            )

        lora_path = folder_paths.get_full_path("controlnet", cn_lora_name)
        lora_data = convert_cn_lora_from_diffusers(cn_model=controlnet.control_model_wrapped, lora_path=lora_path)

        controlnet.control_model_wrapped.add_patches(lora_data, strength_patch=cn_lora_strength)

        return (controlnet,)
|
|