import torch

from refnet.util import exists, fitting_weights, instantiate_from_config, load_weights, delete_states
from refnet.ldm import LatentDiffusion
from typing import Union
from refnet.sampling import (
    UnetHook,
    KDiffusionSampler,
    DiffuserDenoiser,
)

class GuidanceFlag:
    # Decimal flags: "reference" uses the ones digit and "sketch" the tens
    # digit, so the two can be combined by addition (1 + 10 = 11 = both).
    none = 0
    reference = 1
    sketch = 10
    both = 11
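
# The flags combine by addition, which is how `generate` below classifies a
# run from its two guidance scales (a restatement of the logic used there,
# not new behavior):
#
#   cfg = int(gs[0] > 1) * GuidanceFlag.reference + int(gs[1] > 1) * GuidanceFlag.sketch
#   # neither scale > 1        -> GuidanceFlag.none      (0)
#   # reference scale > 1 only -> GuidanceFlag.reference (1)
#   # sketch scale > 1 only    -> GuidanceFlag.sketch    (10)
#   # both scales > 1          -> GuidanceFlag.both      (11)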

def reconstruct_cond(cond, uncond):
    """Concatenate unconditional inputs onto the conditional batch so a single
    forward pass serves classifier-free guidance. `uncond` may be one dict or
    a list of dicts; the "inpaint_bg" entry is passed through untouched."""
    if not isinstance(uncond, list):
        uncond = [uncond]
    for k in cond.keys():
        if k == "inpaint_bg":
            continue
        for uc in uncond:
            if isinstance(cond[k], list):
                cond[k] = [torch.cat([cond[k][i], uc[k][i]]) for i in range(len(cond[k]))]
            elif isinstance(cond[k], torch.Tensor):
                cond[k] = torch.cat([cond[k], uc[k]])
    return cond
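
# Minimal sketch (an illustration, not part of the original module): with a
# conditional batch of size 2 and one uncond dict, every tensor entry doubles
# along the batch dimension:
#
#   cond = {"context": torch.randn(2, 77, 768)}
#   uncond = {"context": torch.zeros(2, 77, 768)}
#   merged = reconstruct_cond(cond, uncond)
#   assert merged["context"].shape == (4, 77, 768)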

class CustomizedLDM(LatentDiffusion):
    def __init__(
            self,
            dtype = torch.float32,
            sigma_max = None,
            sigma_min = None,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.dtype = dtype
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        # Named sub-modules that low_vram_shift can move between CPU and GPU.
        self.model_list = {
            "first": self.first_stage_model,
            "cond": self.cond_stage_model,
            "unet": self.model,
        }
        self.switch_cond_modules = ["cond"]
        self.switch_main_modules = ["unet"]
        self.retrieve_attn_modules()
        self.retrieve_attn_layers()
    def init_from_ckpt(
            self,
            path,
            only_model = False,
            logging = False,
            make_it_fit = False,
            ignore_keys: tuple[str, ...] = (),
    ):
        sd = delete_states(load_weights(path), ignore_keys)
        if make_it_fit:
            sd = fitting_weights(self, sd)
        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model \
            else self.model.load_state_dict(sd, strict=False)

        # Keys from the conditioning branch are expected to differ, so they
        # are filtered out of the report.
        filtered_missing = [
            k for k in missing
            if "cond_stage_model" not in k and "img_embedder" not in k and "fg" not in k
        ]
        filtered_unexpect = [
            k for k in unexpected
            if "cond_stage_model" not in k and "img_embedder" not in k
        ]
        print(
            f"Restored from {path} with {len(filtered_missing)} filtered missing and "
            f"{len(filtered_unexpect)} filtered unexpected keys"
        )
        if logging:
            if len(filtered_missing) > 0:
                print(f"Filtered missing keys: {filtered_missing}")
            if len(filtered_unexpect) > 0:
                print(f"Filtered unexpected keys: {filtered_unexpect}")
    def sample(
            self,
            cond: dict,
            uncond: Union[dict, list[dict]] = None,
            cfg_scale: Union[float, list[float]] = 1.,
            bs: int = 1,
            shape: Union[tuple, list] = None,
            step: int = 20,
            sampler = "DPM++ 3M SDE",
            scheduler = "Automatic",
            device = "cuda",
            x_T = None,
            seed = None,
            deterministic = False,
            **kwargs
    ):
        shape = shape or (self.channels, self.image_size, self.image_size)
        # `or` is ambiguous on tensors, so test x_T explicitly.
        x = x_T if exists(x_T) else torch.randn(bs, *shape, device=device)
        if exists(uncond):
            cond = reconstruct_cond(cond, uncond)

        if sampler.startswith("diffuser"):
            # Use a huggingface diffusers noise sampler and scheduler.
            sampler = DiffuserDenoiser(
                sampler,
                prediction_type = "v_prediction" if self.parameterization == "v" else "epsilon",
                use_karras = scheduler == "Karras"
            )
            samples = sampler(
                x,
                cond,
                cond_scale=cfg_scale,
                unet=self,
                timesteps=step,
                generator=torch.manual_seed(seed) if exists(seed) else None,
                device=device
            )
        else:
            # Use a k-diffusion sampler and noise scheduler.
            seed = seed or torch.seed()
            sampler = KDiffusionSampler(sampler, scheduler, self, device)
            sigmas = sampler.get_sigmas(step)
            extra_args = {
                "cond": cond,
                "cond_scale": cfg_scale,
            }
            seed = [seed for _ in range(bs)] if deterministic else seed
            samples = sampler(x, sigmas, extra_args, seed, deterministic, step)
        return samples
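
    # Illustrative call (a sketch; the cond/uncond dicts come from a concrete
    # prepare_conditions implementation, and the shape assumes an SD-style
    # 4-channel latent for a 512x512 image):
    #
    #   z = model.sample(
    #       cond=c, uncond=uc, cfg_scale=7.0, bs=1, shape=(4, 64, 64),
    #       step=20, sampler="DPM++ 3M SDE", scheduler="Karras", seed=42,
    #   )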
    def switch_to_fp16(self):
        unet = self.model.diffusion_model
        unet.input_blocks = unet.input_blocks.to(self.half_precision_dtype)
        unet.middle_block = unet.middle_block.to(self.half_precision_dtype)
        unet.output_blocks = unet.output_blocks.to(self.half_precision_dtype)
        self.dtype = self.half_precision_dtype
        unet.dtype = self.half_precision_dtype

    def switch_to_fp32(self):
        unet = self.model.diffusion_model
        unet.input_blocks = unet.input_blocks.float()
        unet.middle_block = unet.middle_block.float()
        unet.output_blocks = unet.output_blocks.float()
        self.dtype = torch.float32
        unet.dtype = torch.float32

    def switch_vae_to_fp16(self):
        self.first_stage_model = self.first_stage_model.to(self.half_precision_dtype)

    def switch_vae_to_fp32(self):
        self.first_stage_model = self.first_stage_model.float()
    def low_vram_shift(self, cuda_list: Union[str, list[str]]):
        """Keep only the named sub-modules on the GPU; everything else in
        model_list is moved to the CPU first to free memory."""
        if not isinstance(cuda_list, list):
            cuda_list = [cuda_list]
        cpu_list = self.model_list.keys() - cuda_list
        for model in cpu_list:
            self.model_list[model] = self.model_list[model].cpu()
        torch.cuda.empty_cache()
        for model in cuda_list:
            self.model_list[model] = self.model_list[model].cuda()
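
    # Usage sketch: stage the conditioning models for encoding, then swap in
    # the UNet for sampling (this mirrors the calls made inside `generate`):
    #   model.low_vram_shift(model.switch_cond_modules)   # "cond" on GPU
    #   model.low_vram_shift(model.switch_main_modules)   # "unet" on GPU
    #   model.low_vram_shift("first")                     # VAE on GPU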
    def retrieve_attn_modules(self):
        # Imported locally, presumably to avoid circular imports.
        from refnet.modules.transformer import BasicTransformerBlock
        from refnet.sampling import torch_dfs

        scale_factor_levels = {"high": 0.5, "low": 0.25, "bottom": 0.25}
        attn_modules = []
        for module in torch_dfs(self.model.diffusion_model):
            if isinstance(module, BasicTransformerBlock):
                attn_modules.append(module)
        # Index groups over the depth-first list of transformer blocks.
        self.attn_modules = {
            "high": [0, 1, 2, 3] + [64, 65, 66, 67, 68, 69],
            "low": [i for i in range(4, 24)] + [i for i in range(34, 64)],
            "bottom": [i for i in range(24, 34)],
            "encoder": [i for i in range(24)],
            "decoder": [i for i in range(34, len(attn_modules))],
        }
        self.attn_modules["modules"] = attn_modules
        for k in ["high", "low", "bottom"]:
            scale_factor = scale_factor_levels[k]
            for attn in self.attn_modules[k]:
                attn_modules[attn].scale_factor = scale_factor

    def retrieve_attn_layers(self):
        # Collect the cross-attention (attn2) layers of every transformer block.
        self.attn_layers = []
        for module in self.attn_modules["modules"]:
            if hasattr(module, "attn2") and exists(getattr(module, "attn2")):
                self.attn_layers.append(module.attn2)

class CustomizedColorizer(CustomizedLDM):
    def __init__(
            self,
            control_encoder_config,
            proj_config,
            token_type = "full",
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.control_encoder = instantiate_from_config(control_encoder_config)
        self.proj = instantiate_from_config(proj_config)
        self.token_type = token_type
        self.model_list.update({"control_encoder": self.control_encoder, "proj": self.proj})
        self.switch_cond_modules += ["control_encoder", "proj"]

    def switch_to_fp16(self):
        self.control_encoder = self.control_encoder.to(self.half_precision_dtype)
        super().switch_to_fp16()

    def switch_to_fp32(self):
        self.control_encoder = self.control_encoder.float()
        super().switch_to_fp32()

from refnet.modules.unet import hack_inference_forward


class CustomizedWrapper:
    """Inference mixin. It reads attributes provided by the CustomizedLDM
    subclass it is combined with (`model`, `attn_modules`, `attn_layers`,
    `num_timesteps`, `sample`, ...), so it is not usable on its own."""

    def __init__(self):
        self.scaling_sample = False
        # Fractions of the sampling trajectory in which reference guidance is
        # active / explicitly disabled; see apply_model below.
        self.guidance_steps = (0, 1)
        self.no_guidance_steps = (-0.05, 0.05)
        hack_inference_forward(self.model.diffusion_model)
    def adjust_reference_scale(self, scale_kwargs):
        # `scale_kwargs` is either a scalar applied to every attention module,
        # or a dict: with "level_control" set, "scales" maps level names to
        # scales; otherwise "scales" lists one scale per module.
        if isinstance(scale_kwargs, dict):
            if scale_kwargs["level_control"]:
                for key in scale_kwargs["scales"]:
                    if key == "middle":
                        continue
                    for idx in self.attn_modules[key]:
                        self.attn_modules["modules"][idx].reference_scale = scale_kwargs["scales"][key]
            else:
                for idx, s in enumerate(scale_kwargs["scales"]):
                    self.attn_modules["modules"][idx].reference_scale = s
        else:
            for module in self.attn_modules["modules"]:
                module.reference_scale = scale_kwargs

    def adjust_fgbg_scale(self, fg_scale, bg_scale, merge_scale, mask_threshold):
        for layer in self.attn_layers:
            layer.fg_scale = fg_scale
            layer.bg_scale = bg_scale
            layer.merge_scale = merge_scale
            layer.mask_threshold = mask_threshold
    def apply_model(self, x_noisy, t, cond):
        # tr is the normalized sampling progress: 0 at the first (noisiest)
        # step and 1 at the last. Outside the guidance window, or inside the
        # no-guidance window, the reference context is replaced by a single
        # zero token so the UNet runs unguided.
        tr = 1 - t[0] / (self.num_timesteps - 1)
        crossattn = cond["context"][0]
        if ((tr < self.guidance_steps[0] or tr > self.guidance_steps[1]) or
                (self.no_guidance_steps[0] <= tr <= self.no_guidance_steps[1])):
            crossattn = torch.zeros_like(crossattn)[:, :1]
        cond["context"] = [crossattn]
        model_cond = {k: v for k, v in cond.items() if k != "inpaint_bg"}
        return self.model(x_noisy, t, **model_cond)
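
    # Window arithmetic, worked through (assuming num_timesteps = 1000): the
    # first sampling step has t = 999, so tr = 1 - 999/999 = 0; halfway gives
    # tr = 0.5; the final step gives tr = 1. With guidance_steps = (0.0, 0.8),
    # the reference context is therefore zeroed for the last 20% of the steps.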
    def prepare_conditions(self, *args, **kwargs):
        raise NotImplementedError("Input preprocessing is not implemented.")

    def check_manipulate(self, scales):
        # Manipulation is active if any provided scale is positive.
        if exists(scales) and len(scales) > 0:
            for scale in scales:
                if scale > 0:
                    return True
        return False
    def generate(
            self,
            # Conditional inputs
            cond: dict,
            ctl_scale: Union[float, list[float]],
            merge_scale: float,
            mask_scale: float,
            mask_thresh: float,
            mask_thresh_sketch: float,

            # Sampling settings
            sampler,
            scheduler,
            step: int,
            bs: int,
            gs: list[float],
            strength: Union[float, list[float]],
            fg_strength: float,
            bg_strength: float,
            seed: int,
            start_step: float = 0.0,
            end_step: float = 1.0,
            no_start_step: float = -0.05,
            no_end_step: float = -0.05,
            deterministic: bool = False,
            style_enhance: bool = False,
            bg_enhance: bool = False,
            fg_enhance: bool = False,
            latent_inpaint: bool = False,
            height: int = 512,
            width: int = 512,

            # Injection settings
            injection: bool = False,
            injection_cfg: float = 0.5,
            injection_control: float = 0,
            injection_start_step: float = 0,
            hook_xr: torch.Tensor = None,
            hook_xs: torch.Tensor = None,

            # Additional settings
            low_vram: bool = True,
            return_intermediate = False,
            manipulation_params = None,
            **kwargs,
    ):
        """User-facing generation entry point."""
        hook_unet = UnetHook()
        self.guidance_steps = (start_step, end_step)
        self.no_guidance_steps = (no_start_step, no_end_step)
        self.adjust_reference_scale(strength)
        self.adjust_fgbg_scale(fg_strength, bg_strength, merge_scale, mask_thresh_sketch)
        if low_vram:
            self.low_vram_shift(self.switch_cond_modules)
        else:
            self.low_vram_shift(list(self.model_list.keys()))

        c, uc = self.prepare_conditions(
            bs = bs,
            control_scale = ctl_scale,
            merge_scale = merge_scale,
            mask_scale = mask_scale,
            mask_threshold_ref = mask_thresh,
            mask_threshold_sketch = mask_thresh_sketch,
            style_enhance = style_enhance,
            bg_enhance = bg_enhance,
            fg_enhance = fg_enhance,
            latent_inpaint = latent_inpaint,
            height = height,
            width = width,
            bg_strength = bg_strength,
            low_vram = low_vram,
            **cond,
            **manipulation_params,
            **kwargs
        )
        # Classify the run by which guidance scales are active (see GuidanceFlag),
        # then collapse gs/uc accordingly and set the batch repeat factor.
        cfg = int(gs[0] > 1) * GuidanceFlag.reference + int(gs[1] > 1) * GuidanceFlag.sketch
        gr_indice = [] if cfg in (GuidanceFlag.none, GuidanceFlag.sketch) else list(range(bs, bs * 2))
        repeat = 1
        if cfg == GuidanceFlag.none:
            gs = 1
            uc = None
        if cfg == GuidanceFlag.reference:
            gs = gs[0]
            uc = uc[0]
            repeat = 2
        if cfg == GuidanceFlag.sketch:
            gs = gs[1]
            uc = uc[1]
            repeat = 2
        if cfg == GuidanceFlag.both:
            repeat = 3
        if low_vram:
            self.low_vram_shift("first")

        if injection:
            rx = self.get_first_stage_encoding(hook_xr.to(self.first_stage_model.dtype))
            hook_unet.enhance_reference(
                model = self.model,
                ldm = self,
                bs = bs * repeat,
                s = -hook_xr.to(self.dtype),
                r = rx,
                style_cfg = injection_cfg,
                control_cfg = injection_control,
                gr_indice = gr_indice,
                start_step = injection_start_step,
            )

        if low_vram:
            self.low_vram_shift(self.switch_main_modules)

        z = self.sample(
            cond = c,
            uncond = uc,
            bs = bs,
            shape = (self.channels, height // 8, width // 8),
            cfg_scale = gs,
            step = step,
            sampler = sampler,
            scheduler = scheduler,
            seed = seed,
            deterministic = deterministic,
            return_intermediate = return_intermediate,
        )
        if injection:
            hook_unet.restore(self.model)

        if low_vram:
            self.low_vram_shift("first")
        return self.decode_first_stage(z.to(self.first_stage_model.dtype))
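
# Illustrative usage (a sketch; the config path, checkpoint, cond dict keys,
# and the concrete CustomizedWrapper subclass are assumptions, not part of
# this module):
#
#   from omegaconf import OmegaConf
#
#   config = OmegaConf.load("configs/colorizer.yaml")
#   model = instantiate_from_config(config.model).cuda().eval()
#   model.init_from_ckpt("models/colorizer.ckpt")
#   model.switch_to_fp16()
#   images = model.generate(
#       cond={"reference": ref, "sketch": sketch},
#       ctl_scale=1.0, merge_scale=1.0, mask_scale=1.0,
#       mask_thresh=0.5, mask_thresh_sketch=0.5,
#       sampler="DPM++ 3M SDE", scheduler="Karras",
#       step=20, bs=1, gs=[7.0, 7.0], strength=1.0,
#       fg_strength=1.0, bg_strength=1.0, seed=42,
#   )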