import re
import os.path as osp
import importlib
import itertools
from functools import wraps
from inspect import isfunction

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as tf
from safetensors import safe_open
from torch.utils.checkpoint import checkpoint
from tqdm import tqdm


def exists(x):
    """Return True if *x* is not None."""
    return x is not None


def append_dims(x, target_dims) -> torch.Tensor:
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    # (...,) keeps existing dims; each trailing None adds a size-1 axis.
    return x[(...,) + (None,) * dims_to_append]


def default(val, d):
    """Return *val* if it is not None, else *d* (called first if it is a function)."""
    if exists(val):
        return val
    return d() if isfunction(d) else d


def expand_to_batch_size(x, bs):
    """Repeat a tensor (or each tensor in a list) *bs* times along dim 0.

    NOTE(review): uses ``repeat``, so if dim 0 is already > 1 the result is
    ``bs * x.shape[0]`` — callers presumably pass batch-1 inputs; confirm.
    """
    if isinstance(x, list):
        x = [xi.repeat(bs, *([1] * (len(xi.shape) - 1))) for xi in x]
    else:
        x = x.repeat(bs, *([1] * (len(x.shape) - 1)))
    return x


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path like ``pkg.mod.Class`` to the named attribute.

    If *reload* is True the module is reloaded before the lookup (the second
    ``import_module`` returns the freshly reloaded module from the cache).
    """
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    """Instantiate ``config["target"]`` with ``config["params"]`` as kwargs.

    The two sentinel strings are accepted in place of a config dict and map
    to None. Raises KeyError for any other config without a ``target`` key.
    """
    if "target" not in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def scaled_resize(x: torch.Tensor, scale_factor, interpolation_mode="bicubic"):
    """Resize *x* by *scale_factor* with ``F.interpolate``."""
    return F.interpolate(x, scale_factor=scale_factor, mode=interpolation_mode)


def get_crop_scale(h, w, bgh, bgw):
    """Return fractional crop sizes (ch, cw) in (0, 1] that give a (h, w)-shaped
    region the same aspect ratio as the (bgh, bgw) background.

    The wider of the two aspect ratios keeps its full extent (1.0) and the
    other axis is shrunk proportionally.
    """
    gen_aspect = w / h
    bg_aspect = bgw / bgh
    if gen_aspect > bg_aspect:
        cw = 1.0
        ch = (h / w) * (bgw / bgh)
    else:
        ch = 1.0
        cw = (w / h) * (bgh / bgw)
    return ch, cw


def warp_resize(x: torch.Tensor, target_size, interpolation_mode="bicubic"):
    """Resize a 4D (B, C, H, W) tensor to *target_size*, ignoring aspect ratio."""
    assert len(x.shape) == 4
    return F.interpolate(x, size=target_size, mode=interpolation_mode)


def resize_and_crop(x: torch.Tensor, ch, cw, th, tw):
    """Crop the top-left (ch*h, cw*w) region of *x* and resize it to (th, tw)."""
    b, c, h, w = x.shape
    return tf.resized_crop(x, 0, 0, int(ch * h), int(cw * w), size=[th, tw])


def fitting_weights(model, sd):
    """Adapt a state dict *sd* with mismatched first-two-axis sizes to *model*.

    For every parameter/buffer of *model* present in *sd* whose shape differs,
    old values are tiled modulo the old shape along the first (output) and
    second (input) axes; tiled input columns are then divided by how many times
    each old column was reused so summed contributions keep their scale.
    Trailing axes (dim >= 2) must match exactly. Mutates and returns *sd*.
    """
    n_params = len([name for name, _ in itertools.chain(model.named_parameters(), model.named_buffers())])
    for name, param in tqdm(
        itertools.chain(model.named_parameters(), model.named_buffers()),
        desc="Fitting old weights to new weights",
        total=n_params,
    ):
        if name not in sd:
            continue
        old_shape = sd[name].shape
        new_shape = param.shape
        assert len(old_shape) == len(new_shape)
        if len(new_shape) > 2:
            # we only modify first two axes
            assert new_shape[2:] == old_shape[2:]
        # assumes first axis corresponds to output dim
        if not new_shape == old_shape:
            new_param = param.clone()
            old_param = sd[name]
            device = old_param.device
            if len(new_shape) == 1:
                # Vectorized 1D case: tile old entries modulo old length.
                new_param = old_param[torch.arange(new_shape[0], device=device) % old_shape[0]]
            elif len(new_shape) >= 2:
                # Vectorized 2D case: index rows/cols modulo the old sizes.
                i_indices = torch.arange(new_shape[0], device=device)[:, None] % old_shape[0]
                j_indices = torch.arange(new_shape[1], device=device)[None, :] % old_shape[1]
                # Use advanced indexing to extract all values at once.
                new_param = old_param[i_indices, j_indices]
                # Count how many times each old column is used...
                n_used_old = torch.bincount(
                    torch.arange(new_shape[1], device=device) % old_shape[1],
                    minlength=old_shape[1],
                )
                # ...and map those counts back onto the new columns.
                n_used_new = n_used_old[torch.arange(new_shape[1], device=device) % old_shape[1]]
                # Reshape for broadcasting over any trailing axes.
                n_used_new = n_used_new.reshape(1, new_shape[1])
                while len(n_used_new.shape) < len(new_shape):
                    n_used_new = n_used_new.unsqueeze(-1)
                # Normalize so reused columns do not inflate the total weight.
                new_param = new_param / n_used_new
            sd[name] = new_param
    return sd


def count_params(model, verbose=False):
    """Return the total number of parameters of *model*; optionally print it."""
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params


# Checkpoint file extensions accepted by load_weights.
VALID_FORMATS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"]


def load_weights(path, weights_only=True):
    """Load a state dict from *path* (torch checkpoint or safetensors file).

    Unwraps a top-level ``"state_dict"`` key if present. *weights_only* is
    forwarded to ``torch.load`` (ignored for safetensors).
    """
    ext = osp.splitext(path)[-1]
    assert ext in VALID_FORMATS, f"Invalid checkpoint format {ext}"
    if ext == ".safetensors":
        sd = {}
        # Context manager ensures the file handle is released.
        with safe_open(path, framework="pt", device="cpu") as safe_sd:
            for key in safe_sd.keys():
                sd[key] = safe_sd.get_tensor(key)
    else:
        sd = torch.load(path, map_location="cpu", weights_only=weights_only)
    if "state_dict" in sd.keys():
        sd = sd["state_dict"]
    return sd


def delete_states(sd, delete_keys: list[str] = (), skip_keys: list[str] = ()):
    """Delete entries of *sd* whose key matches any *delete_keys* regex,
    unless the key also matches one of the *skip_keys* regexes.

    Patterns are anchored at the start of the key (``re.match``).
    Mutates and returns *sd*.

    Fix: the previous nested-loop version deleted a key once per non-matching
    skip pattern, raising KeyError on the second delete and ignoring skip
    patterns that matched only some of the list.
    """
    for k in list(sd.keys()):
        if any(re.match(sk, k) for sk in skip_keys):
            continue
        if any(re.match(ik, k) for ik in delete_keys):
            del sd[k]
    return sd


def autocast(f, enabled=True):
    """Wrap *f* so it runs under CUDA autocast with the current global settings."""
    @wraps(f)
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(
            enabled=enabled,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)
    return do_autocast


def checkpoint_wrapper(func):
    """Decorator for methods: run through gradient checkpointing when the
    instance has no ``checkpoint`` attribute or it is truthy; otherwise call
    the method directly.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if not hasattr(self, 'checkpoint') or self.checkpoint:
            # Bind self so checkpoint() only sees the tensor args.
            def bound_func(*args, **kwargs):
                return func(self, *args, **kwargs)
            return checkpoint(bound_func, *args, use_reentrant=False, **kwargs)
        else:
            return func(self, *args, **kwargs)
    return wrapper