| |
| |
|
|
# Prefix put in front of every console message printed by this script.
LOG_PREFIX = '[ControlNet-Travel]'

# Folder name of the ControlNet extension, expected to sit next to this
# extension's own folder (see CTRLNET_PATH below).
CTRLNET_REPO_NAME = 'sdcontrol'
if 'externel repo sanity check':
    import sys
    from pathlib import Path
    from modules.scripts import basedir
    from traceback import print_exc

    ME_PATH = Path(basedir())
    # Derive the path from CTRLNET_REPO_NAME so the folder name is defined in
    # exactly one place (it is also used in the log messages below).
    CTRLNET_PATH = ME_PATH.parent / CTRLNET_REPO_NAME

    controlnet_found = False
    try:
        # Make the ControlNet extension's `scripts` package importable.
        sys.path.append(str(CTRLNET_PATH))

        from scripts.external_code import ControlNetUnit
        from scripts.hook import UNetModel, UnetHook, ControlParams
        from scripts.hook import *

        controlnet_found = True
        print(f'{LOG_PREFIX} extension {CTRLNET_REPO_NAME} found, ControlNet-Travel loaded :)')
    except ImportError:
        print(f'{LOG_PREFIX} extension {CTRLNET_REPO_NAME} not found, ControlNet-Travel ignored :(')
        # sys.exit raises the same SystemExit as the interactive-only `exit`
        # helper, but does not depend on the `site` module being loaded.
        sys.exit(0)
    except Exception:
        # Any other import-time failure: show the traceback, then stop
        # loading this script. KeyboardInterrupt is no longer swallowed.
        print_exc()
        sys.exit(0)
|
|
| |
|
|
|
|
| import sys |
| from PIL import Image |
|
|
| from ldm.models.diffusion.ddpm import LatentDiffusion |
| from modules import shared, devices, lowvram |
| from modules.processing import StableDiffusionProcessing as Processing |
|
|
| from scripts.prompt_travel import * |
| from manager import run_cmd |
|
|
class InterpMethod(Enum):
    """Ways of producing the intermediate hint conditions between two stages.

    The member values are the human-readable labels; the module builds the
    list of dropdown choices from them and converts the selected label back
    with ``InterpMethod(label)``.
    """

    # Blend the two cached hint tensors with `weighted_sum`.
    LINEAR = 'linear (weight sum)'
    # Interpolate the saved hint images with the external RIFE command.
    RIFE = 'rife (optical flow)'
|
|
if 'consts':
    def __(key, value=None):
        """Look up the stored UI value for widget label `key`, else `value`."""
        return opts.data.get(f'customscript/controlnet_travel.py/txt2img/{key}/value', value)

    # Widget labels (also used as the lookup keys for stored defaults).
    LABEL_CTRLNET_REF_DIR = 'Reference image folder (one ref image per stage :)'
    LABEL_INTERP_METH = 'Interpolate method'
    LABEL_SKIP_FUSE = 'Ext. skip latent fusion'
    LABEL_DEBUG_RIFE = 'Save RIFE intermediates'

    # Initial widget values.
    DEFAULT_STEPS = 10
    DEFAULT_CTRLNET_REF_DIR = str(ME_PATH.joinpath('img', 'ref_ctrlnet'))
    DEFAULT_INTERP_METH = __(LABEL_INTERP_METH, InterpMethod.LINEAR.value)
    DEFAULT_SKIP_FUSE = __(LABEL_SKIP_FUSE, False)
    DEFAULT_DEBUG_RIFE = __(LABEL_DEBUG_RIFE, False)

    # Dropdown entries, in enum declaration order.
    CHOICES_INTERP_METH = [member.value for member in InterpMethod]
|
|
if 'vars':
    # Module-level state shared between Script.run / Script.run_linear (which
    # drive the travel) and the hijacked U-Net forward (which reads and fills it).

    # One flag per control tensor; a truthy flag makes the forward pass store
    # None instead of caching that tensor. Assigned from the UI checkboxes.
    skip_fuse_plan: List[bool] = []

    # 0.0 on key-frame passes (hints/tensors are collected); otherwise the
    # blend weight used for the current in-between frame.
    interp_alpha: float = 0.0
    # Index of the current forward call within one interpolated frame; reset
    # to 0 before each such frame and incremented by every forward call.
    interp_ip: int = 0
    # Hint conditions of the previous / current key frame, one per control unit.
    from_hint_cond: List[Tensor] = []
    to_hint_cond: List[Tensor] = []
    # Interpolated hint conditions fed to the control units on in-between frames.
    mid_hint_cond: List[Tensor] = []
    # Control tensors of the previous / current key frame: one inner list per
    # forward call, entries may be None.
    from_control_tensors: List[List[Tensor]] = []
    to_control_tensors: List[List[Tensor]] = []

    # Every list that is emptied before and after a run.
    caches: List[list] = [from_hint_cond, to_hint_cond, mid_hint_cond, from_control_tensors, to_control_tensors]
|
|
|
|
| |
|
|
def hook_hijack(self:UnetHook, model:UNetModel, sd_ldm:LatentDiffusion, control_params:List[ControlParams], process:Processing):
    """Replacement for ``UnetHook.hook`` used while a ControlNet travel runs.

    Patches ``process.sample``, ``model.forward``, the transformer blocks and
    selected U-Net blocks of ``model``. On top of the regular control logic the
    patched forward does the travel bookkeeping through the module globals:

    * ``interp_alpha == 0.0`` (key frame): a CPU copy of every hint condition is
      appended to ``to_hint_cond`` and one list of control tensors per forward
      call is appended to ``to_control_tensors``.
    * otherwise (in-between frame): hints are taken from ``mid_hint_cond`` and
      each control tensor is overwritten with the ``weighted_sum`` of the
      cached ``from``/``to`` tensors for the current call index ``interp_ip``.

    Args:
        self: the hook object; receives the patch state as attributes.
        model: the U-Net whose ``forward`` is replaced.
        sd_ldm: the latent diffusion wrapper (used for ``q_sample`` etc.).
        control_params: the active control units.
        process: the processing object whose ``sample`` is wrapped.
    """
    self.model = model
    self.sd_ldm = sd_ldm
    self.control_params = control_params

    # The nested functions below rebind `self`, so keep the hook reachable.
    outer = self

    def process_sample(*args, **kwargs):
        """Mark the prompt conditionings, then run the original sampler."""
        mark_prompt_context(kwargs.get('conditioning', []), positive=True)
        mark_prompt_context(kwargs.get('unconditional_conditioning', []), positive=False)
        mark_prompt_context(getattr(process, 'hr_c', []), positive=True)
        mark_prompt_context(getattr(process, 'hr_uc', []), positive=False)
        return process.sample_before_CN_hack(*args, **kwargs)

    def forward(self:UNetModel, x:Tensor, timesteps:Tensor=None, context:Tensor=None, **kwargs):
        """Patched U-Net forward: apply control, travel fusion and post-fixes."""
        # Float 0.0 entries are placeholders meaning "no control contributed".
        total_controlnet_embedding = [0.0] * 13
        total_t2i_adapter_embedding = [0.0] * 4
        require_inpaint_hijack = False
        is_in_high_res_fix = False
        batch_size = int(x.shape[0])

        global from_hint_cond, to_hint_cond, from_control_tensors, to_control_tensors, mid_hint_cond, interp_alpha, interp_ip
        x: Tensor
        timesteps: Tensor
        context: Tensor
        kwargs: dict

        # Strip the marks added by process_sample from the context.
        cond_mark, outer.current_uc_indices, context = unmark_prompt_context(context)

        # Select which hint (low-res or high-res) each unit uses for this call.
        for param in outer.control_params:
            if param.used_hint_cond is None:
                param.used_hint_cond = param.hint_cond
                param.used_hint_cond_latent = None
                param.used_hint_inpaint_hijack = None

            if param.hr_hint_cond is not None and x.ndim == 4 and param.hint_cond.ndim == 4 and param.hr_hint_cond.ndim == 4:
                _, _, h_lr, w_lr = param.hint_cond.shape
                _, _, h_hr, w_hr = param.hr_hint_cond.shape
                _, _, h, w = x.shape
                # presumably latent -> pixel size (factor 8) - TODO confirm
                h, w = h * 8, w * 8
                if abs(h - h_lr) < abs(h - h_hr):
                    is_in_high_res_fix = False
                    if param.used_hint_cond is not param.hint_cond:
                        param.used_hint_cond = param.hint_cond
                        param.used_hint_cond_latent = None
                        param.used_hint_inpaint_hijack = None
                else:
                    is_in_high_res_fix = True
                    if param.used_hint_cond is not param.hr_hint_cond:
                        param.used_hint_cond = param.hr_hint_cond
                        param.used_hint_cond_latent = None
                        param.used_hint_inpaint_hijack = None

        # Travel: collect hints on key frames, inject interpolated hints otherwise.
        for i, param in enumerate(outer.control_params):
            if interp_alpha == 0.0:
                # Only the first forward call of a key frame needs to record them.
                if len(to_hint_cond) < len(outer.control_params):
                    to_hint_cond.append(param.used_hint_cond.clone().detach().cpu())
            else:
                param.used_hint_cond = mid_hint_cond[i].to(x.device)

        # Encode the hint to a latent for the units that need one.
        for param in outer.control_params:
            if param.used_hint_cond_latent is not None:
                continue
            if param.control_model_type not in [ControlModelType.AttentionInjection] \
                    and 'colorfix' not in param.preprocessor['name'] \
                    and 'inpaint_only' not in param.preprocessor['name']:
                continue
            param.used_hint_cond_latent = outer.call_vae_using_process(process, param.used_hint_cond, batch_size=batch_size)

        # Style adapters: their output is appended to the prompt context.
        for param in outer.control_params:
            if param.guidance_stopped:
                continue

            if param.control_model_type not in [ControlModelType.T2I_StyleAdapter]:
                continue

            param.control_model.to(devices.get_device_for("controlnet"))
            control = param.control_model(x=x, hint=param.used_hint_cond, timesteps=timesteps, context=context)
            control = torch.cat([control.clone() for _ in range(batch_size)], dim=0)
            control *= param.weight
            control *= cond_mark[:, :, :, 0]
            context = torch.cat([context, control.clone()], dim=1)

        # ControlNet / T2I-Adapter units: accumulate their scaled outputs.
        for param in outer.control_params:
            if param.guidance_stopped:
                continue

            if param.control_model_type not in [ControlModelType.ControlNet, ControlModelType.T2I_Adapter]:
                continue

            param.control_model.to(devices.get_device_for("controlnet"))

            x_in = x
            control_model = param.control_model.control_model

            if param.control_model_type == ControlModelType.ControlNet:
                if x.shape[1] != control_model.input_blocks[0][0].in_channels and x.shape[1] == 9:
                    # 9-channel input but the control model expects fewer:
                    # feed it the first 4 channels only.
                    x_in = x[:, :4, ...]
                    require_inpaint_hijack = True

            assert param.used_hint_cond is not None, f"Controlnet is enabled but no input image is given"

            hint = param.used_hint_cond

            # 4-channel hint: channel 3 is thresholded into a binary mask and
            # masked pixels of the first three channels are set to -1.
            if hint.shape[1] == 4:
                c = hint[:, 0:3, :, :]
                m = hint[:, 3:4, :, :]
                m = (m > 0.5).float()
                hint = c * (1 - m) - m

            control = param.control_model(x=x_in, hint=hint, timesteps=timesteps, context=context)
            control_scales = ([param.weight] * 13)

            if outer.lowvram:
                param.control_model.to("cpu")

            if param.cfg_injection or param.global_average_pooling:
                if param.control_model_type == ControlModelType.T2I_Adapter:
                    control = [torch.cat([c.clone() for _ in range(batch_size)], dim=0) for c in control]
                control = [c * cond_mark for c in control]

            high_res_fix_forced_soft_injection = False

            if is_in_high_res_fix:
                if 'canny' in param.preprocessor['name']:
                    high_res_fix_forced_soft_injection = True
                if 'mlsd' in param.preprocessor['name']:
                    high_res_fix_forced_soft_injection = True

            if param.soft_injection or high_res_fix_forced_soft_injection:
                if param.control_model_type == ControlModelType.T2I_Adapter:
                    control_scales = [param.weight * x for x in (0.25, 0.62, 0.825, 1.0)]
                elif param.control_model_type == ControlModelType.ControlNet:
                    control_scales = [param.weight * (0.825 ** float(12 - i)) for i in range(13)]

            if param.advanced_weighting is not None:
                control_scales = param.advanced_weighting

            control = [c * scale for c, scale in zip(control, control_scales)]
            if param.global_average_pooling:
                control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control]

            for idx, item in enumerate(control):
                target = None
                if param.control_model_type == ControlModelType.ControlNet:
                    target = total_controlnet_embedding
                if param.control_model_type == ControlModelType.T2I_Adapter:
                    target = total_t2i_adapter_embedding
                if target is not None:
                    target[idx] = item + target[idx]

        # Replace the extra channels of a 9-channel input using the masked hint.
        for param in outer.control_params:
            if param.used_hint_cond.shape[1] != 4:
                continue
            if x.shape[1] != 9:
                continue
            if param.used_hint_inpaint_hijack is None:
                mask_pixel = param.used_hint_cond[:, 3:4, :, :]
                image_pixel = param.used_hint_cond[:, 0:3, :, :]
                mask_pixel = (mask_pixel > 0.5).to(mask_pixel.dtype)
                masked_latent = outer.call_vae_using_process(process, image_pixel, batch_size, mask=mask_pixel)
                mask_latent = torch.nn.functional.max_pool2d(mask_pixel, (8, 8))
                if mask_latent.shape[0] != batch_size:
                    mask_latent = torch.cat([mask_latent.clone() for _ in range(batch_size)], dim=0)
                param.used_hint_inpaint_hijack = torch.cat([mask_latent, masked_latent], dim=1)
            # Tensor.to() returns a new tensor; the result must be stored,
            # otherwise the dtype/device conversion is silently lost.
            param.used_hint_inpaint_hijack = param.used_hint_inpaint_hijack.to(x.dtype).to(x.device)
            x = torch.cat([x[:, :4, :, :], param.used_hint_inpaint_hijack], dim=1)

        if shared.cmd_opts.medvram:
            try:
                # NOTE(review): deliberately best-effort call whose failure is
                # ignored; presumably made only for its side effects under
                # --medvram - confirm against the ControlNet extension.
                outer.sd_ldm.model()
            except:
                pass

        # Start every call with empty reference banks.
        for module in outer.attn_module_list:
            module.bank = []
            module.style_cfgs = []
        for module in outer.gn_module_list:
            module.mean_bank = []
            module.var_bank = []
            module.style_cfgs = []

        # Reference units: run the original forward on the noised reference
        # latent in "Write" mode so the patched blocks fill their banks.
        for param in outer.control_params:
            if param.guidance_stopped:
                continue

            if param.used_hint_cond_latent is None:
                continue

            if param.control_model_type not in [ControlModelType.AttentionInjection]:
                continue

            ref_xt = outer.sd_ldm.q_sample(param.used_hint_cond_latent, torch.round(timesteps.float()).long())

            # Pad the reference to 9 channels when the model input has 9.
            if x.shape[1] == 9:
                ref_xt = torch.cat([
                    ref_xt,
                    torch.zeros_like(ref_xt)[:, 0:1, :, :],
                    param.used_hint_cond_latent
                ], dim=1)

            outer.current_style_fidelity = float(param.preprocessor['threshold_a'])
            outer.current_style_fidelity = max(0.0, min(1.0, outer.current_style_fidelity))

            if param.cfg_injection:
                outer.current_style_fidelity = 1.0
            elif param.soft_injection or is_in_high_res_fix:
                outer.current_style_fidelity = 0.0

            control_name = param.preprocessor['name']

            if control_name in ['reference_only', 'reference_adain+attn']:
                outer.attention_auto_machine = AutoMachine.Write
                outer.attention_auto_machine_weight = param.weight

            if control_name in ['reference_adain', 'reference_adain+attn']:
                outer.gn_auto_machine = AutoMachine.Write
                outer.gn_auto_machine_weight = param.weight

            outer.original_forward(
                x=ref_xt.to(devices.dtype_unet),
                timesteps=timesteps.to(devices.dtype_unet),
                context=context.to(devices.dtype_unet)
            )

            outer.attention_auto_machine = AutoMachine.Read
            outer.gn_auto_machine = AutoMachine.Read

        # Travel: cache control tensors on key frames, fuse them otherwise.
        total_control = total_controlnet_embedding
        if interp_alpha == 0.0:
            tensors: List[Tensor] = []
            for i, t in enumerate(total_control):
                skipped = len(skip_fuse_plan) and skip_fuse_plan[i]
                # Entries are still the float placeholder when no ControlNet
                # unit contributed (e.g. guidance stopped); there is nothing
                # to cache then, and a float has no .clone().
                if skipped or not torch.is_tensor(t):
                    tensors.append(None)
                else:
                    tensors.append(t.clone().detach().cpu())
            to_control_tensors.append(tensors)
        else:
            for i, (ctrlA, ctrlB) in enumerate(zip(from_control_tensors[interp_ip], to_control_tensors[interp_ip])):
                live = total_control[i]
                if ctrlA is None or ctrlB is None or not torch.is_tensor(live):
                    continue
                # Overwrite in place so the object held in the list is updated.
                live.data = weighted_sum(ctrlA.to(live.device), ctrlB.to(live.device), interp_alpha)
            interp_ip += 1

        # A tensor here means a T2I adapter contributed. Comparing a tensor
        # with `!= 0` inside `if` raises for multi-element tensors, so test
        # the type instead.
        if torch.is_tensor(total_t2i_adapter_embedding[0]):
            print(f'{LOG_PREFIX} warn: currently t2i_adapter is not supported. if you wanna this, put a feature request on Kahsolt/stable-diffusion-webui-prompt-travel')

        # U-Net input blocks and middle block.
        hs = []
        with th.no_grad():
            t_emb = cond_cast_unet(timestep_embedding(timesteps, self.model_channels, repeat_only=False))
            emb = self.time_embed(t_emb)
            h = x.type(self.dtype)
            for i, module in enumerate(self.input_blocks):
                h = module(h, emb, context)

                # Adapter embeddings are consumed after every third input block.
                if (i + 1) % 3 == 0:
                    h = aligned_adding(h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack)

                hs.append(h)
            h = self.middle_block(h, emb, context)

        # The last control tensor is added to the middle block output.
        h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack)

        # U-Net output blocks; remaining control tensors are consumed from the
        # end of the list, one per skip connection.
        for i, module in enumerate(self.output_blocks):
            h = th.cat([h, aligned_adding(hs.pop(), total_controlnet_embedding.pop(), require_inpaint_hijack)], dim=1)
            h = module(h, emb, context)

        h = h.type(x.dtype)
        h = self.out(h)

        # Post-processing for 'colorfix' preprocessors.
        for param in outer.control_params:
            if param.used_hint_cond_latent is None:
                continue
            if 'colorfix' not in param.preprocessor['name']:
                continue

            k = int(param.preprocessor['threshold_a'])
            if is_in_high_res_fix:
                k *= 2

            xt = x[:, :4, :, :]

            x0_origin = param.used_hint_cond_latent
            t = torch.round(timesteps.float()).long()
            x0_prd = predict_start_from_noise(outer.sd_ldm, xt, t, h)
            # Swap the blurred component of the prediction for the reference's.
            x0 = x0_prd - blur(x0_prd, k) + blur(x0_origin, k)

            if '+sharp' in param.preprocessor['name']:
                detail_weight = float(param.preprocessor['threshold_b']) * 0.01
                neg = detail_weight * blur(x0, k) + (1 - detail_weight) * x0
                x0 = cond_mark * x0 + (1 - cond_mark) * neg

            eps_prd = predict_noise_from_start(outer.sd_ldm, xt, t, x0)

            w = max(0.0, min(1.0, float(param.weight)))
            h = eps_prd * w + h * (1 - w)

        # Post-processing for 'inpaint_only' preprocessors.
        for param in outer.control_params:
            if param.used_hint_cond_latent is None:
                continue
            if 'inpaint_only' not in param.preprocessor['name']:
                continue
            if param.used_hint_cond.shape[1] != 4:
                continue

            xt = x[:, :4, :, :]

            mask = param.used_hint_cond[:, 3:4, :, :]
            mask = torch.nn.functional.max_pool2d(mask, (10, 10), stride=(8, 8), padding=1)

            x0_origin = param.used_hint_cond_latent
            t = torch.round(timesteps.float()).long()
            x0_prd = predict_start_from_noise(outer.sd_ldm, xt, t, h)
            # Keep the prediction inside the mask, the reference outside it.
            x0 = x0_prd * mask + x0_origin * (1 - mask)
            eps_prd = predict_noise_from_start(outer.sd_ldm, xt, t, x0)

            w = max(0.0, min(1.0, float(param.weight)))
            h = eps_prd * w + h * (1 - w)

        return h

    def forward_webui(*args, **kwargs):
        """Run `forward`, moving control models back to the CPU afterwards."""
        try:
            if shared.cmd_opts.lowvram:
                lowvram.send_everything_to_cpu()

            return forward(*args, **kwargs)
        finally:
            # `self` is the hook here (this function takes no `self` parameter).
            if self.lowvram:
                for param in self.control_params:
                    if isinstance(param.control_model, torch.nn.Module):
                        param.control_model.to("cpu")

    def hacked_basic_transformer_inner_forward(self, x, context=None):
        """Transformer block forward with reference-attention Write/Read banks."""
        x_norm1 = self.norm1(x)
        self_attn1 = None
        if self.disable_self_attn:
            self_attn1 = self.attn1(x_norm1, context=context)
        else:
            self_attention_context = x_norm1
            if outer.attention_auto_machine == AutoMachine.Write:
                if outer.attention_auto_machine_weight > self.attn_weight:
                    self.bank.append(self_attention_context.detach().clone())
                    self.style_cfgs.append(outer.current_style_fidelity)
            if outer.attention_auto_machine == AutoMachine.Read:
                if len(self.bank) > 0:
                    style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
                    # Attend over the own context plus everything in the bank.
                    self_attn1_uc = self.attn1(x_norm1, context=torch.cat([self_attention_context] + self.bank, dim=1))
                    self_attn1_c = self_attn1_uc.clone()
                    if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
                        self_attn1_c[outer.current_uc_indices] = self.attn1(
                            x_norm1[outer.current_uc_indices],
                            context=self_attention_context[outer.current_uc_indices])
                    self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc
                self.bank = []
                self.style_cfgs = []
            if self_attn1 is None:
                self_attn1 = self.attn1(x_norm1, context=self_attention_context)

        x = self_attn1.to(x.dtype) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x

    def hacked_group_norm_forward(self, *args, **kwargs):
        """Block forward that records (Write) or re-applies (Read) mean/variance."""
        eps = 1e-6
        x = self.original_forward(*args, **kwargs)
        y = None
        if outer.gn_auto_machine == AutoMachine.Write:
            if outer.gn_auto_machine_weight > self.gn_weight:
                var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
                self.mean_bank.append(mean)
                self.var_bank.append(var)
                self.style_cfgs.append(outer.current_style_fidelity)
        if outer.gn_auto_machine == AutoMachine.Read:
            if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
                style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
                var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
                std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
                var_acc = sum(self.var_bank) / float(len(self.var_bank))
                std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                # Re-normalise x to the banked statistics.
                y_uc = (((x - mean) / std) * std_acc) + mean_acc
                y_c = y_uc.clone()
                if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
                    y_c[outer.current_uc_indices] = x.to(y_c.dtype)[outer.current_uc_indices]
                y = style_cfg * y_c + (1.0 - style_cfg) * y_uc
            self.mean_bank = []
            self.var_bank = []
            self.style_cfgs = []
        if y is None:
            y = x
        return y.to(x.dtype)

    # Wrap the sampler once; remember the original on the processing object.
    if getattr(process, 'sample_before_CN_hack', None) is None:
        process.sample_before_CN_hack = process.sample
    process.sample = process_sample

    model._original_forward = model.forward
    outer.original_forward = model.forward
    # Binding makes `model` the first positional argument, i.e. forward's `self`.
    model.forward = forward_webui.__get__(model, UNetModel)

    all_modules = torch_dfs(model)

    attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock)]
    attn_modules = sorted(attn_modules, key=lambda x: - x.norm1.normalized_shape[0])

    for i, module in enumerate(attn_modules):
        if getattr(module, '_original_inner_forward', None) is None:
            module._original_inner_forward = module._forward
        module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
        module.bank = []
        module.style_cfgs = []
        module.attn_weight = float(i) / float(len(attn_modules))

    gn_modules = [model.middle_block]
    model.middle_block.gn_weight = 0

    input_block_indices = [4, 5, 7, 8, 10, 11]
    for w, i in enumerate(input_block_indices):
        module = model.input_blocks[i]
        module.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
        gn_modules.append(module)

    output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
    for w, i in enumerate(output_block_indices):
        module = model.output_blocks[i]
        module.gn_weight = float(w) / float(len(output_block_indices))
        gn_modules.append(module)

    for i, module in enumerate(gn_modules):
        if getattr(module, 'original_forward', None) is None:
            module.original_forward = module.forward
        module.forward = hacked_group_norm_forward.__get__(module, torch.nn.Module)
        module.mean_bank = []
        module.var_bank = []
        module.style_cfgs = []
        module.gn_weight *= 2

    outer.attn_module_list = attn_modules
    outer.gn_module_list = gn_modules

    scripts.script_callbacks.on_cfg_denoiser(self.guidance_schedule_handler)
|
|
| |
|
|
def reset_cuda():
    """Free cached memory, then print RAM and VRAM usage if it can be read.

    The usage report is best-effort: any failure while gathering it (for
    example `psutil` not being installed) is ignored.
    """
    devices.torch_gc()
    import gc; gc.collect()

    try:
        import os
        import psutil
        mem = psutil.Process(os.getpid()).memory_info()
        print(f'[Mem] rss: {mem.rss/2**30:.3f} GB, vms: {mem.vms/2**30:.3f} GB')
        from modules.shared import mem_mon as vram_mon
        free, total = vram_mon.cuda_mem_get_info()
        print(f'[VRAM] free: {free/2**30:.3f} GB, total: {total/2**30:.3f} GB')
    except Exception:
        # Still silent, but no longer swallows KeyboardInterrupt / SystemExit
        # the way a bare `except:` does.
        pass
|
|
|
|
class Script(scripts.Script):
    """WebUI script that travels between ControlNet hint images, stage by stage."""

    def title(self):
        return 'ControlNet Travel'

    def describe(self):
        return 'Travel from one controlnet hint condition to another in the tensor space.'

    def show(self, is_img2img):
        # Only offered when the ControlNet extension could be imported.
        return controlnet_found

    def ui(self, is_img2img):
        """Build the widgets; the returned list order matches `run`'s parameters."""
        with gr.Row(variant='compact'):
            interp_meth = gr.Dropdown(label=LABEL_INTERP_METH, value=lambda: DEFAULT_INTERP_METH, choices=CHOICES_INTERP_METH)
            steps = gr.Text(label=LABEL_STEPS, value=lambda: DEFAULT_STEPS, max_lines=1)

            reset = gr.Button(value='Reset Cuda', variant='tool')
            reset.click(fn=reset_cuda, show_progress=False)

        with gr.Row(variant='compact'):
            ctrlnet_ref_dir = gr.Text(label=LABEL_CTRLNET_REF_DIR, value=lambda: DEFAULT_CTRLNET_REF_DIR, max_lines=1)

        with gr.Group(visible=DEFAULT_SKIP_FUSE) as tab_ext_skip_fuse:
            with gr.Row(variant='compact'):
                skip_in_0 = gr.Checkbox(label='in_0')
                skip_in_3 = gr.Checkbox(label='in_3')
                skip_out_0 = gr.Checkbox(label='out_0')
                skip_out_3 = gr.Checkbox(label='out_3')
            with gr.Row(variant='compact'):
                skip_in_1 = gr.Checkbox(label='in_1')
                skip_in_4 = gr.Checkbox(label='in_4')
                skip_out_1 = gr.Checkbox(label='out_1')
                skip_out_4 = gr.Checkbox(label='out_4')
            with gr.Row(variant='compact'):
                skip_in_2 = gr.Checkbox(label='in_2')
                skip_in_5 = gr.Checkbox(label='in_5')
                skip_out_2 = gr.Checkbox(label='out_2')
                skip_out_5 = gr.Checkbox(label='out_5')
            with gr.Row(variant='compact'):
                skip_mid = gr.Checkbox(label='mid')

        with gr.Row(variant='compact', visible=DEFAULT_UPSCALE) as tab_ext_upscale:
            upscale_meth = gr.Dropdown(label=LABEL_UPSCALE_METH, value=lambda: DEFAULT_UPSCALE_METH, choices=CHOICES_UPSCALER)
            upscale_ratio = gr.Slider(label=LABEL_UPSCALE_RATIO, value=lambda: DEFAULT_UPSCALE_RATIO, minimum=1.0, maximum=16.0, step=0.1)
            upscale_width = gr.Slider(label=LABEL_UPSCALE_WIDTH, value=lambda: DEFAULT_UPSCALE_WIDTH, minimum=0, maximum=2048, step=8)
            upscale_height = gr.Slider(label=LABEL_UPSCALE_HEIGHT, value=lambda: DEFAULT_UPSCALE_HEIGHT, minimum=0, maximum=2048, step=8)

        with gr.Row(variant='compact', visible=DEFAULT_VIDEO) as tab_ext_video:
            video_fmt = gr.Dropdown(label=LABEL_VIDEO_FMT, value=lambda: DEFAULT_VIDEO_FMT, choices=CHOICES_VIDEO_FMT)
            video_fps = gr.Number(label=LABEL_VIDEO_FPS, value=lambda: DEFAULT_VIDEO_FPS)
            video_pad = gr.Number(label=LABEL_VIDEO_PAD, value=lambda: DEFAULT_VIDEO_PAD, precision=0)
            video_pick = gr.Text(label=LABEL_VIDEO_PICK, value=lambda: DEFAULT_VIDEO_PICK, max_lines=1)

        with gr.Row(variant='compact') as tab_ext:
            ext_video = gr.Checkbox(label=LABEL_VIDEO, value=lambda: DEFAULT_VIDEO)
            ext_upscale = gr.Checkbox(label=LABEL_UPSCALE, value=lambda: DEFAULT_UPSCALE)
            ext_skip_fuse = gr.Checkbox(label=LABEL_SKIP_FUSE, value=lambda: DEFAULT_SKIP_FUSE)
            dbg_rife = gr.Checkbox(label=LABEL_DEBUG_RIFE, value=lambda: DEFAULT_DEBUG_RIFE)

            # Each checkbox toggles the visibility of its option panel.
            ext_video.change(gr_show, inputs=ext_video, outputs=tab_ext_video, show_progress=False)
            ext_upscale.change(gr_show, inputs=ext_upscale, outputs=tab_ext_upscale, show_progress=False)
            ext_skip_fuse.change(gr_show, inputs=ext_skip_fuse, outputs=tab_ext_skip_fuse, show_progress=False)

        # This order is the index order of `skip_fuse_plan`.
        skip_fuses = [
            skip_in_0,
            skip_in_1,
            skip_in_2,
            skip_in_3,
            skip_in_4,
            skip_in_5,
            skip_mid,
            skip_out_0,
            skip_out_1,
            skip_out_2,
            skip_out_3,
            skip_out_4,
            skip_out_5,
        ]
        return [
            interp_meth, steps, ctrlnet_ref_dir,
            upscale_meth, upscale_ratio, upscale_width, upscale_height,
            video_fmt, video_fps, video_pad, video_pick,
            ext_video, ext_upscale, ext_skip_fuse, dbg_rife,
            *skip_fuses,
        ]

    def run(self, p:Processing,
            interp_meth:str, steps:str, ctrlnet_ref_dir:str,
            upscale_meth:str, upscale_ratio:float, upscale_width:int, upscale_height:int,
            video_fmt:str, video_fps:float, video_pad:int, video_pick:str,
            ext_video:bool, ext_upscale:bool, ext_skip_fuse:bool, dbg_rife:bool,
            *skip_fuses:bool,
        ):
        """Validate the options, run the travel and return a `Processed` result.

        Invalid options produce a `Processed` with no images and an explanatory
        message instead of raising. Errors raised while generating are caught;
        their traceback becomes the result's info text.
        """
        global skip_fuse_plan

        # Locate the ControlNet always-on script and make sure a unit is enabled.
        self.controlnet_script = None
        try:
            for script in p.scripts.alwayson_scripts:
                if hasattr(script, "latest_network") and script.title().lower() == "controlnet":
                    script_args: Tuple[ControlNetUnit] = p.script_args[script.args_from:script.args_to]
                    if not any([u.enabled for u in script_args]): return Processed(p, [], p.seed, f'{CTRLNET_REPO_NAME} not enabled')
                    self.controlnet_script = script
                    break
        except ImportError:
            return Processed(p, [], p.seed, f'{CTRLNET_REPO_NAME} not installed')
        except Exception:
            print_exc()
        if not self.controlnet_script: return Processed(p, [], p.seed, f'{CTRLNET_REPO_NAME} not loaded')

        # Labels -> enums.
        interp_meth: InterpMethod = InterpMethod(interp_meth)
        video_fmt: VideoFormat = VideoFormat(video_fmt)

        # Extension options.
        if ext_video:
            if video_pad < 0: return Processed(p, [], p.seed, f'video_pad must >= 0, but got {video_pad}')
            if video_fps <= 0: return Processed(p, [], p.seed, f'video_fps must > 0, but got {video_fps}')
            try: video_slice = parse_slice(video_pick)
            except Exception: return Processed(p, [], p.seed, 'syntax error in video_slice')
        # Always reassign: leaving the global untouched when the option is off
        # would keep applying the plan of an earlier run.
        skip_fuse_plan = list(skip_fuses) if ext_skip_fuse else []

        # Reference images, one per stage.
        if not ctrlnet_ref_dir: return Processed(p, [], p.seed, f'invalid image folder path: {ctrlnet_ref_dir}')
        ctrlnet_ref_dir: Path = Path(ctrlnet_ref_dir)
        if not ctrlnet_ref_dir.is_dir(): return Processed(p, [], p.seed, f'invalid image folder path: {ctrlnet_ref_dir}')
        # iterdir() yields entries in arbitrary order; sort so that the stage
        # order is stable and follows the file names.
        self.ctrlnet_ref_fps = sorted(fp for fp in ctrlnet_ref_dir.iterdir() if fp.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.webp'])
        n_stages = len(self.ctrlnet_ref_fps)
        if n_stages == 0: return Processed(p, [], p.seed, f'no images file (*.jpg/*.png/*.bmp/*.webp) found in folder path: {ctrlnet_ref_dir}')
        if n_stages == 1: return Processed(p, [], p.seed, 'requires at least two images to travel between, but found only 1 :(')

        # Steps: one number for every transition, or a single number for all.
        try: steps: List[int] = [int(s.strip()) for s in steps.strip().split(',')]
        except Exception: return Processed(p, [], p.seed, f'cannot parse steps options: {steps}')
        if len(steps) == 1: steps = [steps[0]] * (n_stages - 1)
        elif len(steps) != n_stages - 1: return Processed(p, [], p.seed, f'stage count mismatch: len_steps({len(steps)}) != n_stages({n_stages}) - 1')
        n_frames = sum(steps) + n_stages
        if 'show_debug':
            print('n_stages:', n_stages)
            print('n_frames:', n_frames)
            print('steps:', steps)
        # Pad so that steps[i] is the step count of the transition INTO stage i.
        steps.insert(0, -1)

        # Output folders.
        travel_path = os.path.join(p.outpath_samples, 'prompt_travel')
        os.makedirs(travel_path, exist_ok=True)
        travel_number = get_next_sequence_number(travel_path)
        self.log_dp = os.path.join(travel_path, f'{travel_number:05}')
        p.outpath_samples = self.log_dp
        os.makedirs(self.log_dp, exist_ok=True)
        self.tmp_dp = Path(self.log_dp) / 'ctrl_cond'
        self.tmp_fp = self.tmp_dp / 'tmp.png'

        # Force one image per processing call.
        p.n_iter = 1
        p.batch_size = 1

        # Fix the seed so every frame uses the same one.
        p.seed = get_fixed_seed(p.seed)
        self.subseed = p.subseed
        if 'show_debug':
            print('seed:', p.seed)
            print('subseed:', p.subseed)
            print('subseed_strength:', p.subseed_strength)

        state.job_count = n_frames

        # Hand the parsed options to run_linear.
        self.n_stages = n_stages
        self.steps = steps
        self.interp_meth = interp_meth
        self.dbg_rife = dbg_rife

        def upscale_image_callback(params:ImageSaveParams):
            params.image = upscale_image(params.image, p.width, p.height, upscale_meth, upscale_ratio, upscale_width, upscale_height)

        images: List[PILImage] = []
        info: str = None
        try:
            if ext_upscale: on_before_image_saved(upscale_image_callback)

            # Swap in the travel-aware hook for the duration of the run.
            self.UnetHook_hook_original = UnetHook.hook
            UnetHook.hook = hook_hijack

            [c.clear() for c in caches]
            images, info = self.run_linear(p)
        except Exception:
            info = format_exc()
            print(info)
        finally:
            if self.tmp_fp.exists(): os.unlink(self.tmp_fp)
            [c.clear() for c in caches]

            UnetHook.hook = self.UnetHook_hook_original

            self.controlnet_script.input_image = None
            if self.controlnet_script.latest_network:
                self.controlnet_script.latest_network: UnetHook
                self.controlnet_script.latest_network.restore(p.sd_model.model.diffusion_model)
                self.controlnet_script.latest_network = None

            if ext_upscale: remove_callbacks_for_function(upscale_image_callback)

            reset_cuda()

        # Nothing to encode when generation failed before the first frame.
        if ext_video and images: save_video(images, video_slice, video_pad, video_fps, video_fmt, os.path.join(self.log_dp, f'travel-{travel_number:05}'))

        return Processed(p, images, p.seed, info)

    def run_linear(self, p:Processing) -> RunResults:
        """Render every key frame and the interpolated frames between them.

        Returns the images in frame order together with the info text of the
        first processing call. Raises `RuntimeError` if the RIFE command fails
        and `ValueError` for hint shapes or interpolation methods it cannot
        handle.
        """
        global from_hint_cond, to_hint_cond, from_control_tensors, to_control_tensors, interp_alpha, interp_ip

        images: List[PILImage] = []
        info: str = None
        def process_p(append:bool=True) -> Optional[List[PILImage]]:
            nonlocal p, images, info
            proc = process_images(p)
            if not info: info = proc.info
            if append: images.extend(proc.images)
            else: return proc.images

        # ---- rife interp utils (begin) ----
        def save_ctrl_cond(idx:int):
            # Dump the hints of stage `idx` as PNG files for the RIFE command.
            self.tmp_dp.mkdir(exist_ok=True)
            for i, x in enumerate(to_hint_cond):
                x = x[0]
                if len(x.shape) == 3:
                    if x.shape[0] == 1: x = x.squeeze_(0)
                    elif x.shape[0] == 3: x = x.permute([1, 2, 0])
                    else: raise ValueError(f'unknown cond shape: {x.shape}')
                else:
                    raise ValueError(f'unknown cond shape: {x.shape}')
                im = (x.detach().clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8)
                Image.fromarray(im).save(self.tmp_dp / f'{idx}-{i}.png')
        def rife_interp(i:int, j:int, k:int, alpha:float) -> Tensor:
            ''' interp between i-th and j-th cond of the k-th ctrlnet set '''
            fp0 = self.tmp_dp / f'{i}-{k}.png'
            fp1 = self.tmp_dp / f'{j}-{k}.png'
            fpo = self.tmp_dp / f'{i}-{j}-{alpha:.3f}.png' if self.dbg_rife else self.tmp_fp
            # The command must run even under `python -O`, which strips
            # `assert` statements together with the call inside them.
            if not run_cmd(f'rife-ncnn-vulkan -m rife-v4 -s {alpha:.3f} -0 "{fp0}" -1 "{fp1}" -o "{fpo}"'):
                raise RuntimeError(f'rife-ncnn-vulkan failed to interpolate "{fp0}" and "{fp1}"')
            # Close the file right after reading so the temp image can be
            # overwritten and deleted later.
            with Image.open(fpo) as im:
                x = torch.from_numpy(np.asarray(im) / 255.0)
            if len(x.shape) == 2: x = x.unsqueeze_(0)
            elif len(x.shape) == 3: x = x.permute([2, 0, 1])
            else: raise ValueError(f'unknown cond shape: {x.shape}')
            x = x.unsqueeze(dim=0)
            return x
        # ---- rife interp utils (end) ----

        # ---- filename reorder utils (begin) ----
        iframe = 0
        def rename_image_filename(idx:int, param: ImageSaveParams):
            # Name the saved file after its frame index, keeping the extension.
            fn = param.filename
            _, suffix = os.path.splitext(os.path.basename(fn))
            param.filename = os.path.join(os.path.dirname(fn), f'{idx:05d}' + suffix)
        class on_before_image_saved_wrapper:
            def __init__(self, callback_fn):
                self.callback_fn = callback_fn
            def __enter__(self):
                on_before_image_saved(self.callback_fn)
            def __exit__(self, exc_type, exc_value, exc_traceback):
                remove_callbacks_for_function(self.callback_fn)
        # ---- filename reorder utils (end) ----

        # Stage 0: the first key frame.
        setattr(p, 'init_images', [Image.open(self.ctrlnet_ref_fps[0])])
        interp_alpha = 0.0
        with on_before_image_saved_wrapper(partial(rename_image_filename, 0)):
            process_p()
            iframe += 1
        save_ctrl_cond(0)

        # Stages 1..n-1: render the key frame first, then fill the gap before it.
        for i in range(1, self.n_stages):
            if state.interrupted: break

            # Shift "to" into "from" IN PLACE: `caches` holds these very list
            # objects, so rebinding the names would hide them from the cleanup.
            from_hint_cond[:] = to_hint_cond ; to_hint_cond.clear()
            from_control_tensors[:] = to_control_tensors ; to_control_tensors.clear()
            setattr(p, 'init_images', [Image.open(self.ctrlnet_ref_fps[i])])
            interp_alpha = 0.0

            # The key frame is saved under the index it has after the in-betweens.
            with on_before_image_saved_wrapper(partial(rename_image_filename, iframe + self.steps[i])):
                cached_images = process_p(append=False)
            save_ctrl_cond(i)

            is_interrupted = False
            n_inter = self.steps[i] + 1
            for t in range(1, n_inter):
                if state.interrupted: is_interrupted = True ; break

                interp_alpha = t / n_inter

                mid_hint_cond.clear()
                device = devices.get_device_for("controlnet")
                if self.interp_meth == InterpMethod.LINEAR:
                    for hintA, hintB in zip(from_hint_cond, to_hint_cond):
                        hintC = weighted_sum(hintA.to(device), hintB.to(device), interp_alpha)
                        mid_hint_cond.append(hintC)
                elif self.interp_meth == InterpMethod.RIFE:
                    dtype = to_hint_cond[0].dtype
                    for k in range(len(to_hint_cond)):
                        hintC = rife_interp(i-1, i, k, interp_alpha).to(device, dtype)
                        mid_hint_cond.append(hintC)
                else: raise ValueError(f'unknown interp_meth: {self.interp_meth}')

                # Restart the per-forward-call index for this frame.
                interp_ip = 0
                with on_before_image_saved_wrapper(partial(rename_image_filename, iframe)):
                    process_p()
                    iframe += 1

            # The key frame goes after its in-between frames.
            images.extend(cached_images)
            iframe += 1

            if is_interrupted: break

        return images, info
|
|