from typing import Union

from torch import Tensor
import torch

import comfy.utils
import comfy.controlnet as comfy_cn
from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter, broadcast_image_to


def get_properly_arranged_t2i_weights(initial_weights: list[float]):
    # expand the 4 per-block T2I-Adapter weights into 12 per-layer weights
    # (each block's weight is repeated for its 3 layers)
    new_weights = []
    new_weights.extend([initial_weights[0]]*3)
    new_weights.extend([initial_weights[1]]*3)
    new_weights.extend([initial_weights[2]]*3)
    new_weights.extend([initial_weights[3]]*3)
    return new_weights
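

# Illustrative example (not executed): get_properly_arranged_t2i_weights([0.25, 0.5, 0.75, 1.0])
# returns [0.25, 0.25, 0.25, 0.5, 0.5, 0.5, 0.75, 0.75, 0.75, 1.0, 1.0, 1.0] - the 12
# per-layer weight slots a T2I-Adapter expects (3 layers per block).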


class ControlWeightType:
    DEFAULT = "default"
    UNIVERSAL = "universal"
    T2IADAPTER = "t2iadapter"
    CONTROLNET = "controlnet"
    CONTROLLORA = "controllora"
    CONTROLLLLITE = "controllllite"


class ControlWeights:
    def __init__(self, weight_type: str, base_multiplier: float=1.0, flip_weights: bool=False, weights: list[float]=None, weight_mask: Tensor=None):
        self.weight_type = weight_type
        self.base_multiplier = base_multiplier
        self.flip_weights = flip_weights
        self.weights = weights
        if self.weights is not None and self.flip_weights:
            self.weights.reverse()
        self.weight_mask = weight_mask

    def get(self, idx: int) -> Union[float, Tensor]:
        # if weights are present, return the per-layer entry; otherwise apply no scaling
        if self.weights is not None:
            return self.weights[idx]
        return 1.0

    @classmethod
    def default(cls):
        return cls(ControlWeightType.DEFAULT)

    @classmethod
    def universal(cls, base_multiplier: float, flip_weights: bool=False):
        return cls(ControlWeightType.UNIVERSAL, base_multiplier=base_multiplier, flip_weights=flip_weights)

    @classmethod
    def universal_mask(cls, weight_mask: Tensor):
        return cls(ControlWeightType.UNIVERSAL, weight_mask=weight_mask)

    @classmethod
    def t2iadapter(cls, weights: list[float]=None, flip_weights: bool=False):
        if weights is None:
            weights = [1.0]*12
        return cls(ControlWeightType.T2IADAPTER, weights=weights, flip_weights=flip_weights)

    @classmethod
    def controlnet(cls, weights: list[float]=None, flip_weights: bool=False):
        if weights is None:
            weights = [1.0]*13
        return cls(ControlWeightType.CONTROLNET, weights=weights, flip_weights=flip_weights)

    @classmethod
    def controllora(cls, weights: list[float]=None, flip_weights: bool=False):
        if weights is None:
            weights = [1.0]*10
        return cls(ControlWeightType.CONTROLLORA, weights=weights, flip_weights=flip_weights)

    @classmethod
    def controllllite(cls, weights: list[float]=None, flip_weights: bool=False):
        if weights is None:
            # ControlLLLite layer count is not known ahead of time, so use a generous upper bound
            weights = [1.0]*200
        return cls(ControlWeightType.CONTROLLLLITE, weights=weights, flip_weights=flip_weights)
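

# Illustrative sketch of the factory methods (assumed usage, not executed):
#   cn_weights = ControlWeights.controlnet()                        # 13 layers, all 1.0
#   soft = ControlWeights.universal(base_multiplier=0.825)          # resolved per-model later
#   masked = ControlWeights.universal_mask(weight_mask=some_mask)   # 'some_mask' is a hypothetical Tensor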


class StrengthInterpolation:
    LINEAR = "linear"
    EASE_IN = "ease-in"
    EASE_OUT = "ease-out"
    EASE_IN_OUT = "ease-in-out"
    NONE = "none"


class LatentKeyframe:
    def __init__(self, batch_index: int, strength: float) -> None:
        self.batch_index = batch_index
        self.strength = strength


class LatentKeyframeGroup:
    def __init__(self) -> None:
        self.keyframes: list[LatentKeyframe] = []

    def add(self, keyframe: LatentKeyframe) -> None:
        added = False
        # replace existing keyframe if it shares the same batch_index
        for i in range(len(self.keyframes)):
            if self.keyframes[i].batch_index == keyframe.batch_index:
                self.keyframes[i] = keyframe
                added = True
                break
        if not added:
            self.keyframes.append(keyframe)
        self.keyframes.sort(key=lambda k: k.batch_index)

    def get_index(self, index: int) -> Union[LatentKeyframe, None]:
        try:
            return self.keyframes[index]
        except IndexError:
            return None

    def __getitem__(self, index) -> LatentKeyframe:
        return self.keyframes[index]

    def is_empty(self) -> bool:
        return len(self.keyframes) == 0

    def clone(self) -> 'LatentKeyframeGroup':
        cloned = LatentKeyframeGroup()
        for tk in self.keyframes:
            cloned.add(tk)
        return cloned
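

# Illustrative sketch (assumed usage): keep full strength on frame 0 and fade frame 5.
#   lkg = LatentKeyframeGroup()
#   lkg.add(LatentKeyframe(batch_index=0, strength=1.0))
#   lkg.add(LatentKeyframe(batch_index=5, strength=0.4))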


class TimestepKeyframe:
    def __init__(self,
                 start_percent: float = 0.0,
                 strength: float = 1.0,
                 interpolation: str = StrengthInterpolation.NONE,
                 control_weights: ControlWeights = None,
                 latent_keyframes: LatentKeyframeGroup = None,
                 null_latent_kf_strength: float = 0.0,
                 inherit_missing: bool = True,
                 guarantee_usage: bool = True,
                 mask_hint_orig: Tensor = None) -> None:
        self.start_percent = start_percent
        # placeholder sigma; the real value is filled in by pre_run_advanced
        self.start_t = 999999999.9
        self.strength = strength
        self.interpolation = interpolation
        self.control_weights = control_weights
        self.latent_keyframes = latent_keyframes
        self.null_latent_kf_strength = null_latent_kf_strength
        self.inherit_missing = inherit_missing
        self.guarantee_usage = guarantee_usage
        self.mask_hint_orig = mask_hint_orig

    def has_control_weights(self):
        return self.control_weights is not None

    def has_latent_keyframes(self):
        return self.latent_keyframes is not None

    def has_mask_hint(self):
        return self.mask_hint_orig is not None

    @classmethod
    def default(cls) -> 'TimestepKeyframe':
        return cls(0.0)


class TimestepKeyframeGroup:
    def __init__(self) -> None:
        self.keyframes: list[TimestepKeyframe] = []
        self.keyframes.append(TimestepKeyframe.default())

    def add(self, keyframe: TimestepKeyframe) -> None:
        added = False
        # replace existing keyframe if it shares the same start_percent
        for i in range(len(self.keyframes)):
            if self.keyframes[i].start_percent == keyframe.start_percent:
                self.keyframes[i] = keyframe
                added = True
                break
        if not added:
            self.keyframes.append(keyframe)
        self.keyframes.sort(key=lambda k: k.start_percent)

    def get_index(self, index: int) -> Union[TimestepKeyframe, None]:
        try:
            return self.keyframes[index]
        except IndexError:
            return None

    def has_index(self, index: int) -> bool:
        return index >= 0 and index < len(self.keyframes)

    def __getitem__(self, index) -> TimestepKeyframe:
        return self.keyframes[index]

    def __len__(self) -> int:
        return len(self.keyframes)

    def is_empty(self) -> bool:
        return len(self.keyframes) == 0

    def clone(self) -> 'TimestepKeyframeGroup':
        cloned = TimestepKeyframeGroup()
        for tk in self.keyframes:
            cloned.add(tk)
        return cloned

    @classmethod
    def default(cls, keyframe: TimestepKeyframe) -> 'TimestepKeyframeGroup':
        group = cls()
        group.keyframes[0] = keyframe
        return group
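

# Illustrative sketch (assumed usage): full control for the first 60% of sampling,
# then a weaker keyframe takes over for the remainder.
#   tkg = TimestepKeyframeGroup()
#   tkg.add(TimestepKeyframe(start_percent=0.0, strength=1.0))
#   tkg.add(TimestepKeyframe(start_percent=0.6, strength=0.5))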


class AdvancedControlBase:
    def __init__(self, base: ControlBase, timestep_keyframes: TimestepKeyframeGroup, weights_default: ControlWeights):
        self.base = base
        self.compatible_weights = [ControlWeightType.UNIVERSAL]
        self.add_compatible_weight(weights_default.weight_type)
        # masks for which parts of the controlnet output to keep
        self.mask_cond_hint_original = None
        self.mask_cond_hint = None
        self.tk_mask_cond_hint_original = None
        self.tk_mask_cond_hint = None
        self.weight_mask_cond_hint = None
        # sliding-context bookkeeping
        self.sub_idxs = None
        self.full_latent_length = 0
        self.context_length = 0
        # current timestep info
        self.t: Tensor = None
        self.batched_number: int = None
        # weights
        self.weights: ControlWeights = None
        self.weights_default: ControlWeights = weights_default
        self.weights_override: ControlWeights = None
        # latent keyframes
        self.latent_keyframes: LatentKeyframeGroup = None
        self.latent_keyframe_override: LatentKeyframeGroup = None
        # initialize timestep keyframes
        self.set_timestep_keyframes(timestep_keyframes)
        # redirect the base ControlBase hooks to the advanced implementations
        self.get_control = self.get_control_inject
        self.control_merge = self.control_merge_inject
        self.pre_run = self.pre_run_inject
        self.cleanup = self.cleanup_inject

    def add_compatible_weight(self, control_weight_type: str):
        self.compatible_weights.append(control_weight_type)

    def verify_all_weights(self, throw_error=True):
        # if an override exists, only the override needs to be checked
        if self.weights_override is not None:
            if self.weights_override.weight_type not in self.compatible_weights:
                msg = f"Weight override is type {self.weights_override.weight_type}, but loaded {type(self).__name__} " + \
                      f"only supports {self.compatible_weights} weights."
                raise WeightTypeException(msg)
        # otherwise, check the weights on all timestep keyframes
        else:
            for tk in self.timestep_keyframes.keyframes:
                if tk.has_control_weights() and tk.control_weights.weight_type not in self.compatible_weights:
                    msg = f"Weight on Timestep Keyframe with start_percent={tk.start_percent} is type " + \
                          f"{tk.control_weights.weight_type}, but loaded {type(self).__name__} only supports {self.compatible_weights} weights."
                    raise WeightTypeException(msg)

    def set_timestep_keyframes(self, timestep_keyframes: TimestepKeyframeGroup):
        self.timestep_keyframes = timestep_keyframes if timestep_keyframes else TimestepKeyframeGroup()
        # reset keyframe-related state
        self.current_timestep_keyframe = None
        self.current_timestep_index = -1
        self.next_timestep_keyframe = None
        self.weights = None
        self.latent_keyframes = None

    def prepare_current_timestep(self, t: Tensor, batched_number: int):
        self.t = t
        self.batched_number = batched_number
        # get current sigma
        curr_t: float = t[0]
        prev_index = self.current_timestep_index
        # if there is a next keyframe, loop through and see if need to switch to it
        if self.timestep_keyframes.has_index(self.current_timestep_index+1):
            for i in range(self.current_timestep_index+1, len(self.timestep_keyframes)):
                eval_tk = self.timestep_keyframes[i]
                # t is in terms of sigmas, so larger values mean earlier steps;
                # a keyframe becomes active once its start_t is >= the current sigma
                if eval_tk.start_t >= curr_t:
                    self.current_timestep_index = i
                    self.current_timestep_keyframe = eval_tk
                    # keep track of weights, latent keyframes, and mask hints,
                    # accounting for inherit_missing
                    if self.current_timestep_keyframe.has_control_weights():
                        self.weights = self.current_timestep_keyframe.control_weights
                    elif not self.current_timestep_keyframe.inherit_missing:
                        self.weights = self.weights_default
                    if self.current_timestep_keyframe.has_latent_keyframes():
                        self.latent_keyframes = self.current_timestep_keyframe.latent_keyframes
                    elif not self.current_timestep_keyframe.inherit_missing:
                        self.latent_keyframes = None
                    if self.current_timestep_keyframe.has_mask_hint():
                        self.tk_mask_cond_hint_original = self.current_timestep_keyframe.mask_hint_orig
                    elif not self.current_timestep_keyframe.inherit_missing:
                        del self.tk_mask_cond_hint_original
                        self.tk_mask_cond_hint_original = None
                    # if guaranteed to be used, stop searching for a later keyframe
                    if self.current_timestep_keyframe.guarantee_usage:
                        break
                # if not active yet, no later keyframe will be either - stop looking
                else:
                    break

        # if the keyframe changed, apply overrides
        if prev_index != self.current_timestep_index:
            if self.weights_override is not None:
                self.weights = self.weights_override
            if self.latent_keyframe_override is not None:
                self.latent_keyframes = self.latent_keyframe_override

        # make sure weights are of a usable type
        self.prepare_weights()
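
    # Illustrative walkthrough (sigma values assumed for demonstration): with keyframes
    # at start_percent 0.0 and 0.6, pre_run_advanced might resolve start_t to sigmas of
    # roughly 14.6 and 3.2; while curr_t > 3.2 the first keyframe stays current, and once
    # curr_t drops to <= 3.2 the second keyframe (plus its weights/latent keyframes) takes over.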

    def prepare_weights(self):
        if self.weights is None or self.weights.weight_type == ControlWeightType.DEFAULT:
            self.weights = self.weights_default
        elif self.weights.weight_type == ControlWeightType.UNIVERSAL:
            # universal mask weights need no conversion
            if self.weights.weight_mask is not None:
                return
            self.weights = self.get_universal_weights()

    def get_universal_weights(self) -> ControlWeights:
        # subclasses convert universal weights into their own per-layer format
        return self.weights

    def set_cond_hint_mask(self, mask_hint):
        self.mask_cond_hint_original = mask_hint
        return self

    def pre_run_inject(self, model, percent_to_timestep_function):
        self.base.pre_run(model, percent_to_timestep_function)
        self.pre_run_advanced(model, percent_to_timestep_function)

    def pre_run_advanced(self, model, percent_to_timestep_function):
        # for each timestep keyframe, resolve its start_percent into a start_t sigma
        for tk in self.timestep_keyframes.keyframes:
            tk.start_t = percent_to_timestep_function(tk.start_percent)
        # clear any leftover state
        self.cleanup_advanced()

    def get_control_inject(self, x_noisy, t, cond, batched_number):
        # prepare timestep and everything related
        self.prepare_current_timestep(t=t, batched_number=batched_number)
        # if strength is 0, there is no reason to run this controlnet at all
        if self.strength == 0.0 or self.current_timestep_keyframe.strength == 0.0:
            control_prev = None
            if self.previous_controlnet is not None:
                control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
            return control_prev

        return self.get_control_advanced(x_noisy, t, cond, batched_number)

    def get_control_advanced(self, x_noisy, t, cond, batched_number):
        # subclasses provide the actual controlnet evaluation
        pass

    def calc_weight(self, idx: int, x: Tensor, layers: int) -> Union[float, Tensor]:
        if self.weights.weight_mask is not None:
            # prepare the weight mask to match x
            self.prepare_weight_mask_cond_hint(x, self.batched_number)
            # adjust the mask's influence per-layer via exponentiation
            return torch.pow(self.weight_mask_cond_hint, self.get_calc_pow(idx=idx, layers=layers))
        return self.weights.get(idx=idx)

    def get_calc_pow(self, idx: int, layers: int) -> int:
        return (layers-1) - idx
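
    # Illustrative example: for a 13-layer ControlNet (layers=13), idx=0 maps to exponent
    # 12 and idx=12 to exponent 0, so a weight mask with values below 1.0 attenuates the
    # earliest layers the most.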

    def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int):
        # apply strengths, and get batch indices to null out
        # (i.e. latents that should not be influenced by this controlnet)
        if self.latent_keyframes is not None:
            latent_count = x.size(0) // batched_number
            indices_to_null = set(range(latent_count))
            mapped_indices = None
            # if expecting subdivision, translate between subset idxs and actual idx values
            if self.sub_idxs:
                mapped_indices = {}
                for i, actual in enumerate(self.sub_idxs):
                    mapped_indices[actual] = i
            for keyframe in self.latent_keyframes:
                real_index = keyframe.batch_index
                # if negative, count from the end
                if real_index < 0:
                    real_index += latent_count if self.sub_idxs is None else self.full_latent_length

                # if not mapping indices, what you see is what you get
                if mapped_indices is None:
                    if real_index in indices_to_null:
                        indices_to_null.remove(real_index)
                # otherwise, check if batch_index is included in this subset of latents
                else:
                    real_index = mapped_indices.get(real_index, None)
                    if real_index is None:
                        continue
                    indices_to_null.remove(real_index)

                # if real_index is outside the bounds of the latents, don't apply
                if real_index >= latent_count or real_index < 0:
                    continue

                # apply keyframe strength for each batched cond/uncond
                for b in range(batched_number):
                    x[(latent_count*b)+real_index] = x[(latent_count*b)+real_index] * keyframe.strength

            # null out unlisted latents via null_latent_kf_strength, for each batched cond/uncond
            for batch_index in indices_to_null:
                for b in range(batched_number):
                    x[(latent_count*b)+batch_index] = x[(latent_count*b)+batch_index] * self.current_timestep_keyframe.null_latent_kf_strength
        # apply masks, resized to the required dims
        if self.mask_cond_hint is not None:
            masks = prepare_mask_batch(self.mask_cond_hint, x.shape)
            x[:] = x[:] * masks
        if self.tk_mask_cond_hint is not None:
            masks = prepare_mask_batch(self.tk_mask_cond_hint, x.shape)
            x[:] = x[:] * masks
        # apply the timestep keyframe's overall strength
        if self.current_timestep_keyframe.strength != 1.0:
            x[:] *= self.current_timestep_keyframe.strength
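
    # Illustrative layout (assumed cond/uncond batching): with latent_count=4 and
    # batched_number=2, x rows are ordered [c0, c1, c2, c3, u0, u1, u2, u3]; a latent
    # keyframe with batch_index=1 scales rows 1 and 5, i.e. latent_count*b + real_index.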

    def control_merge_inject(self: 'AdvancedControlBase', control_input, control_output, control_prev, output_dtype):
        out = {'input': [], 'middle': [], 'output': []}

        if control_input is not None:
            for i in range(len(control_input)):
                key = 'input'
                x = control_input[i]
                if x is not None:
                    self.apply_advanced_strengths_and_masks(x, self.batched_number)

                    x *= self.strength * self.calc_weight(i, x, len(control_input))
                    if x.dtype != output_dtype:
                        x = x.to(output_dtype)
                out[key].insert(0, x)

        if control_output is not None:
            for i in range(len(control_output)):
                # the last entry of control_output belongs to the middle block
                if i == (len(control_output) - 1):
                    key = 'middle'
                    index = 0
                else:
                    key = 'output'
                    index = i
                x = control_output[i]
                if x is not None:
                    self.apply_advanced_strengths_and_masks(x, self.batched_number)

                    if self.global_average_pooling:
                        x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])

                    x *= self.strength * self.calc_weight(i, x, len(control_output))
                    if x.dtype != output_dtype:
                        x = x.to(output_dtype)

                out[key].append(x)
        if control_prev is not None:
            # merge this controlnet's outputs with the previous controlnet chain
            for x in ['input', 'middle', 'output']:
                o = out[x]
                for i in range(len(control_prev[x])):
                    prev_val = control_prev[x][i]
                    if i >= len(o):
                        o.append(prev_val)
                    elif prev_val is not None:
                        if o[i] is None:
                            o[i] = prev_val
                        else:
                            o[i] += prev_val
        return out
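
    # The merged dict mirrors ComfyUI's control format: 'input' residuals are stored in
    # reverse order, 'middle' holds a single entry, 'output' is in forward order, and any
    # previous controlnet's contributions are summed element-wise above.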

    def prepare_mask_cond_hint(self, x_noisy: Tensor, t, cond, batched_number, dtype=None):
        self._prepare_mask("mask_cond_hint", self.mask_cond_hint_original, x_noisy, t, cond, batched_number, dtype)
        self.prepare_tk_mask_cond_hint(x_noisy, t, cond, batched_number, dtype)

    def prepare_tk_mask_cond_hint(self, x_noisy: Tensor, t, cond, batched_number, dtype=None):
        return self._prepare_mask("tk_mask_cond_hint", self.current_timestep_keyframe.mask_hint_orig, x_noisy, t, cond, batched_number, dtype)

    def prepare_weight_mask_cond_hint(self, x_noisy: Tensor, batched_number, dtype=None):
        return self._prepare_mask("weight_mask_cond_hint", self.weights.weight_mask, x_noisy, t=None, cond=None, batched_number=batched_number, dtype=dtype, direct_attn=True)

    def _prepare_mask(self, attr_name, orig_mask: Tensor, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
        # make mask appropriate dimensions, if present
        if orig_mask is not None:
            out_mask = getattr(self, attr_name)
            if self.sub_idxs is not None or out_mask is None or x_noisy.shape[2] * 8 != out_mask.shape[1] or x_noisy.shape[3] * 8 != out_mask.shape[2]:
                self._reset_attr(attr_name)
                del out_mask
                # resize mask and match batch count;
                # direct_attn masks stay at latent resolution, others are upscaled 8x toward pixel space
                multiplier = 1 if direct_attn else 8
                out_mask = prepare_mask_batch(orig_mask, x_noisy.shape, multiplier=multiplier)
                actual_latent_length = x_noisy.shape[0] // batched_number
                out_mask = comfy.utils.repeat_to_batch_size(out_mask, actual_latent_length if self.sub_idxs is None else self.full_latent_length)
                if self.sub_idxs is not None:
                    out_mask = out_mask[self.sub_idxs]
            # make sure mask batch length matches x_noisy
            if x_noisy.shape[0] != out_mask.shape[0]:
                out_mask = broadcast_image_to(out_mask, x_noisy.shape[0], batched_number)
            # default dtype to match x_noisy
            if dtype is None:
                dtype = x_noisy.dtype
            setattr(self, attr_name, out_mask.to(dtype=dtype).to(self.device))
            del out_mask

    def _reset_attr(self, attr_name, new_value=None):
        if hasattr(self, attr_name):
            delattr(self, attr_name)
        setattr(self, attr_name, new_value)

    def cleanup_inject(self):
        self.base.cleanup()
        self.cleanup_advanced()

    def cleanup_advanced(self):
        self.sub_idxs = None
        self.full_latent_length = 0
        self.context_length = 0
        self.t = None
        self.batched_number = None
        self.weights = None
        self.latent_keyframes = None
        # timestep-keyframe state
        self.current_timestep_keyframe = None
        self.next_timestep_keyframe = None
        self.current_timestep_index = -1
        # clear mask hints
        if self.mask_cond_hint is not None:
            del self.mask_cond_hint
            self.mask_cond_hint = None
        if self.tk_mask_cond_hint_original is not None:
            del self.tk_mask_cond_hint_original
            self.tk_mask_cond_hint_original = None
        if self.tk_mask_cond_hint is not None:
            del self.tk_mask_cond_hint
            self.tk_mask_cond_hint = None
        if self.weight_mask_cond_hint is not None:
            del self.weight_mask_cond_hint
            self.weight_mask_cond_hint = None

    def copy_to_advanced(self, copied: 'AdvancedControlBase'):
        copied.mask_cond_hint_original = self.mask_cond_hint_original
        copied.weights_override = self.weights_override
        copied.latent_keyframe_override = self.latent_keyframe_override


class ControlNetAdvanced(ControlNet, AdvancedControlBase):
    def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, device=None):
        super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, device=device)
        AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet())

    def get_universal_weights(self) -> ControlWeights:
        # universal weights decay by base_multiplier from the last layer toward the first
        raw_weights = [(self.weights.base_multiplier ** float(12 - i)) for i in range(13)]
        return ControlWeights.controlnet(raw_weights, self.weights.flip_weights)
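
    # Illustrative values (base_multiplier assumed to be 0.825): raw_weights rises from
    # 0.825**12 ≈ 0.10 at the first layer to 0.825**0 = 1.0 at the last, the usual
    # "soft weights" behavior of weakening early ControlNet layers.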

    def get_control_advanced(self, x_noisy, t, cond, batched_number):
        # perform special version of get_control that supports sliding context and masks
        return self.sliding_get_control(x_noisy, t, cond, batched_number)

    def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                if control_prev is not None:
                    return control_prev
                else:
                    return None

        output_dtype = x_noisy.dtype

        # make cond_hint appropriate dimensions;
        # if sub_idxs are present, cond_hint may be a subset of cond_hint_original
        if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
            if self.cond_hint is not None:
                del self.cond_hint
            self.cond_hint = None
            # if cond_hint_original covers the full latent length, subdivide it
            if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
                self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
            else:
                self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

        # prepare mask_cond_hint
        self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=self.control_model.dtype)

        context = cond['c_crossattn']
        # use 'y' if present, otherwise fall back to 'c_adm'
        y = cond.get('y', None)
        if y is None:
            y = cond.get('c_adm', None)
        if y is not None:
            y = y.to(self.control_model.dtype)
        timestep = self.model_sampling_current.timestep(t)
        x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

        control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
        return self.control_merge(None, control, control_prev, output_dtype)

    def copy(self):
        c = ControlNetAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling)
        self.copy_to(c)
        self.copy_to_advanced(c)
        return c

    @staticmethod
    def from_vanilla(v: ControlNet, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlNetAdvanced':
        return ControlNetAdvanced(control_model=v.control_model, timestep_keyframes=timestep_keyframe,
                                  global_average_pooling=v.global_average_pooling, device=v.device)


class T2IAdapterAdvanced(T2IAdapter, AdvancedControlBase):
    def __init__(self, t2i_model, timestep_keyframes: TimestepKeyframeGroup, channels_in, device=None):
        super().__init__(t2i_model=t2i_model, channels_in=channels_in, device=device)
        AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.t2iadapter())

    def get_universal_weights(self) -> ControlWeights:
        # universal weights decay by base_multiplier, then get rearranged into T2I layout
        raw_weights = [(self.weights.base_multiplier ** float(7 - i)) for i in range(8)]
        raw_weights = [raw_weights[-8], raw_weights[-3], raw_weights[-2], raw_weights[-1]]
        raw_weights = get_properly_arranged_t2i_weights(raw_weights)
        return ControlWeights.t2iadapter(raw_weights, self.weights.flip_weights)

    def get_calc_pow(self, idx: int, layers: int) -> int:
        # match the exponents to the same block arrangement used in get_universal_weights
        indices = [7 - i for i in range(8)]
        indices = [indices[-8], indices[-3], indices[-2], indices[-1]]
        indices = get_properly_arranged_t2i_weights(indices)
        return indices[idx]
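
    # Illustrative mapping: indices evaluates to [7, 6, 5, 4, 3, 2, 1, 0], then
    # [7, 2, 1, 0], then [7, 7, 7, 2, 2, 2, 1, 1, 1, 0, 0, 0], so a weight mask is
    # raised to per-layer powers matching the T2I-Adapter block arrangement.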

    def get_control_advanced(self, x_noisy, t, cond, batched_number):
        # prepare timestep and everything related
        self.prepare_current_timestep(t=t, batched_number=batched_number)
        try:
            # if sub_idxs are present, temporarily swap in the matching subset of the hint
            if self.sub_idxs is not None:
                full_cond_hint_original = self.cond_hint_original
                del self.cond_hint
                self.cond_hint = None
                self.cond_hint_original = full_cond_hint_original[self.sub_idxs]
            # mask hints
            self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
            return super().get_control(x_noisy, t, cond, batched_number)
        finally:
            if self.sub_idxs is not None:
                # restore the full cond_hint_original
                self.cond_hint_original = full_cond_hint_original
                del full_cond_hint_original

    def copy(self):
        c = T2IAdapterAdvanced(self.t2i_model, self.timestep_keyframes, self.channels_in)
        self.copy_to(c)
        self.copy_to_advanced(c)
        return c

    def cleanup(self):
        super().cleanup()
        self.cleanup_advanced()

    @staticmethod
    def from_vanilla(v: T2IAdapter, timestep_keyframe: TimestepKeyframeGroup=None) -> 'T2IAdapterAdvanced':
        return T2IAdapterAdvanced(t2i_model=v.t2i_model, timestep_keyframes=timestep_keyframe, channels_in=v.channels_in, device=v.device)


class ControlLoraAdvanced(ControlLora, AdvancedControlBase):
    def __init__(self, control_weights, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, device=None):
        super().__init__(control_weights=control_weights, global_average_pooling=global_average_pooling, device=device)
        AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllora())
        # bind relevant ControlNetAdvanced functions to this instance
        self.get_control_advanced = ControlNetAdvanced.get_control_advanced.__get__(self, type(self))
        self.sliding_get_control = ControlNetAdvanced.sliding_get_control.__get__(self, type(self))

    def get_universal_weights(self) -> ControlWeights:
        raw_weights = [(self.weights.base_multiplier ** float(9 - i)) for i in range(10)]
        return ControlWeights.controllora(raw_weights, self.weights.flip_weights)

    def copy(self):
        c = ControlLoraAdvanced(self.control_weights, self.timestep_keyframes, global_average_pooling=self.global_average_pooling)
        self.copy_to(c)
        self.copy_to_advanced(c)
        return c

    def cleanup(self):
        super().cleanup()
        self.cleanup_advanced()

    @staticmethod
    def from_vanilla(v: ControlLora, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlLoraAdvanced':
        return ControlLoraAdvanced(control_weights=v.control_weights, timestep_keyframes=timestep_keyframe,
                                   global_average_pooling=v.global_average_pooling, device=v.device)


class ControlLLLiteAdvanced(ControlNet, AdvancedControlBase):
    def __init__(self, control_weights, timestep_keyframes: TimestepKeyframeGroup, device=None):
        # NOTE: only the advanced base is wired up here; the base ControlNet.__init__ is not invoked
        AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite())


def load_controlnet(ckpt_path, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
    control = comfy_cn.load_controlnet(ckpt_path, model=model)
    # wrap the loaded vanilla controlnet in its advanced equivalent
    return convert_to_advanced(control, timestep_keyframe=timestep_keyframe)


def convert_to_advanced(control, timestep_keyframe: TimestepKeyframeGroup=None):
    # if already advanced, leave it be
    if is_advanced_controlnet(control):
        return control
    # if exactly ControlNet returned, transform it into ControlNetAdvanced
    if type(control) == ControlNet:
        return ControlNetAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
    # if exactly ControlLora returned, transform it into ControlLoraAdvanced
    elif type(control) == ControlLora:
        return ControlLoraAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
    # if a T2IAdapter (or subclass) returned, transform it into T2IAdapterAdvanced
    elif isinstance(control, T2IAdapter):
        return T2IAdapterAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
    # otherwise, leave unsupported types unchanged
    return control


def is_advanced_controlnet(input_object):
    # advanced controlnets are duck-typed by the presence of the sub_idxs attribute
    return hasattr(input_object, "sub_idxs")
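

# Illustrative sketch (hypothetical checkpoint path, assumed ComfyUI node context):
#   control = load_controlnet("checkpoints/controlnet.safetensors", timestep_keyframe=tkg)
#   assert is_advanced_controlnet(control)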


def prepare_mask_batch(mask: Tensor, shape: torch.Size, multiplier: int=1, match_dim1=False):
    mask = mask.clone()
    mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[2]*multiplier, shape[3]*multiplier), mode="bilinear")
    if match_dim1:
        mask = torch.cat([mask] * shape[1], dim=1)
    return mask
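
# Illustrative shapes (assumed): a (1, H, W) mask paired with a latent of shape
# (B, 4, 64, 64) and multiplier=8 is interpolated to (1, 1, 512, 512); with
# match_dim1=True it would also be tiled across the 4 channels.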


def normalize_min_max(x: Tensor, new_min=0.0, new_max=1.0):
    x_min, x_max = x.min(), x.max()
    return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min


def linear_conversion(x, x_min=0.0, x_max=1.0, new_min=0.0, new_max=1.0):
    return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min
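
# Illustrative values: normalize_min_max(torch.tensor([2.0, 4.0, 6.0])) gives
# tensor([0.0, 0.5, 1.0]), and linear_conversion(0.5, new_min=10.0, new_max=20.0) gives 15.0.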


class WeightTypeException(TypeError):
    """Raised when a weight type is not compatible with the loaded AdvancedControlBase object."""