Spaces:
Running on Zero
Running on Zero
File size: 6,832 Bytes
d066167 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 | import re
import os.path as osp
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as tf
from torch.utils.checkpoint import checkpoint
import numpy as np
import itertools
import importlib
from tqdm import tqdm
from inspect import isfunction
from functools import wraps
from safetensors import safe_open
def exists(x):
    """Return True when *x* carries a value, i.e. is not None."""
    return not (x is None)
def append_dims(x, target_dims) -> torch.Tensor:
    """Right-pad *x* with singleton dimensions until it has *target_dims* dims.

    Raises:
        ValueError: if *x* already has more than *target_dims* dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    # Indexing with trailing None slots inserts the extra size-1 axes as a view.
    index = (...,) + (None,) * missing
    return x[index]
def default(val, d):
    """Return *val* unless it is None; otherwise return the fallback *d*,
    calling it first when it is a plain function."""
    if val is not None:
        return val
    return d() if isfunction(d) else d
def expand_to_batch_size(x, bs):
    """Tile a tensor (or each tensor in a list) *bs* times along dim 0,
    leaving all other dimensions unchanged."""
    def _tile(t):
        return t.repeat(bs, *([1] * (t.ndim - 1)))

    if isinstance(x, list):
        return [_tile(item) for item in x]
    return _tile(x)
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as "pkg.mod.Class" to the named attribute.

    Args:
        string: dotted import path; everything before the last dot is the module.
        reload: when True, re-import the module first (picks up source edits).
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    module = importlib.import_module(module_name, package=None)
    return getattr(module, attr_name)
def instantiate_from_config(config):
    """Build the object described by *config*.

    *config* is a mapping with key "target" (dotted import path) and an
    optional "params" mapping of constructor kwargs. The sentinel strings
    '__is_first_stage__' and "__is_unconditional__" yield None.

    Raises:
        KeyError: when "target" is missing and *config* is not a sentinel.
    """
    # NOTE: when config is a sentinel string, the membership test below is a
    # substring check, which is False for both sentinels — behavior preserved.
    if "target" not in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
def scaled_resize(x: torch.Tensor, scale_factor, interpolation_mode="bicubic"):
    """Rescale a 4D (B, C, H, W) tensor by *scale_factor* with the given
    interpolation mode, preserving aspect ratio."""
    resized = F.interpolate(x, scale_factor=scale_factor, mode=interpolation_mode)
    return resized
def get_crop_scale(h, w, bgh, bgw):
    """Return fractional crop sizes (ch, cw) that cut a (h, w) image down to
    the aspect ratio of a (bgh, bgw) background.

    The wider of the two aspect ratios keeps its full extent along the
    dimension that already fits; the other dimension is shrunk.
    """
    if (w / h) > (bgw / bgh):
        # Generated image is wider than background: keep full width, trim height.
        return (h / w) * (bgw / bgh), 1.0
    # Background is at least as wide: keep full height, trim width.
    return 1.0, (w / h) * (bgh / bgw)
def warp_resize(x: torch.Tensor, target_size, interpolation_mode="bicubic"):
    """Resize a 4D (B, C, H, W) tensor to exactly *target_size*, distorting
    the aspect ratio if necessary (no cropping)."""
    assert x.ndim == 4
    return F.interpolate(x, size=target_size, mode=interpolation_mode)
def resize_and_crop(x: torch.Tensor, ch, cw, th, tw):
    """Crop the top-left (ch*h, cw*w) region of a (B, C, H, W) tensor and
    resize the crop to (th, tw)."""
    _, _, h, w = x.shape
    crop_h = int(ch * h)
    crop_w = int(cw * w)
    return tf.resized_crop(x, 0, 0, crop_h, crop_w, size=[th, tw])
def fitting_weights(model, sd):
    """Adapt checkpoint tensors in `sd` to the (possibly resized) shapes of
    `model`'s parameters/buffers by tiling old values cyclically, then return
    the mutated `sd`.

    Only the first two axes may differ between old and new shapes; any
    trailing axes must match exactly. Keys present in `sd` but absent from
    the model (and vice versa) are left untouched.

    Args:
        model: module whose named_parameters()/named_buffers() define targets.
        sd: checkpoint state dict (mutated in place).

    Returns:
        The same `sd`, with resized entries replaced.
    """
    n_params = len([name for name, _ in
                    itertools.chain(model.named_parameters(),
                                    model.named_buffers())])
    for name, param in tqdm(
            itertools.chain(model.named_parameters(),
                            model.named_buffers()),
            desc="Fitting old weights to new weights",
            total=n_params
    ):
        if not name in sd:
            continue
        old_shape = sd[name].shape
        new_shape = param.shape
        assert len(old_shape) == len(new_shape)
        if len(new_shape) > 2:
            # we only modify first two axes
            assert new_shape[2:] == old_shape[2:]
        # assumes first axis corresponds to output dim
        if not new_shape == old_shape:
            new_param = param.clone()
            old_param = sd[name]
            device = old_param.device
            if len(new_shape) == 1:
                # Vectorized 1D case: index old entries cyclically (modulo) up
                # to the new length.
                new_param = old_param[torch.arange(new_shape[0], device=device) % old_shape[0]]
            elif len(new_shape) >= 2:
                # Vectorized 2D case: broadcast row/column index grids, each
                # wrapped modulo the old size.
                i_indices = torch.arange(new_shape[0], device=device)[:, None] % old_shape[0]
                j_indices = torch.arange(new_shape[1], device=device)[None, :] % old_shape[1]
                # Use advanced indexing to extract all values at once
                new_param = old_param[i_indices, j_indices]
                # Count how many times each old column is used
                n_used_old = torch.bincount(
                    torch.arange(new_shape[1], device=device) % old_shape[1],
                    minlength=old_shape[1]
                )
                # Map to new shape
                n_used_new = n_used_old[torch.arange(new_shape[1], device=device) % old_shape[1]]
                # Reshape for broadcasting
                n_used_new = n_used_new.reshape(1, new_shape[1])
                while len(n_used_new.shape) < len(new_shape):
                    n_used_new = n_used_new.unsqueeze(-1)
                # Normalize: divide by per-column usage count — presumably to
                # keep the layer's output magnitude stable when input columns
                # are duplicated (TODO confirm against callers).
                new_param = new_param / n_used_new
            sd[name] = new_param
    return sd
def count_params(model, verbose=False):
    """Return the total number of elements across *model*'s parameters,
    optionally printing the count in millions."""
    total_params = 0
    for p in model.parameters():
        total_params += p.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params
# Checkpoint file extensions accepted by load_weights().
VALID_FORMATS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"]
def load_weights(path, weights_only=True):
    """Load a checkpoint state dict from *path* onto the CPU.

    Supports the extensions listed in VALID_FORMATS. Safetensors files are
    read tensor-by-tensor; all other formats go through torch.load. A
    wrapping {"state_dict": ...} container is unwrapped automatically.

    Args:
        path: checkpoint file path.
        weights_only: forwarded to torch.load (not used for safetensors).

    Returns:
        dict mapping parameter names to tensors.
    """
    ext = osp.splitext(path)[-1]
    assert ext in VALID_FORMATS, f"Invalid checkpoint format {ext}"
    if ext == ".safetensors":
        # Use the context manager so the file handle is released promptly
        # (the previous version leaked the open handle).
        with safe_open(path, framework="pt", device="cpu") as safe_sd:
            sd = {key: safe_sd.get_tensor(key) for key in safe_sd.keys()}
    else:
        sd = torch.load(path, map_location="cpu", weights_only=weights_only)
    if "state_dict" in sd:
        sd = sd["state_dict"]
    return sd
def delete_states(sd, delete_keys: list[str] = (), skip_keys: list[str] = ()):
    """Remove entries of *sd* whose keys match any regex in *delete_keys*,
    unless the key also matches a regex in *skip_keys*.

    Patterns are tested with re.match (anchored at the start of the key).
    Mutates *sd* in place and returns it.

    Fixes two defects of the previous version: a key could be deleted once
    per non-matching skip pattern (raising KeyError on the second delete),
    and a key was only protected when it matched every skip pattern rather
    than any one of them.
    """
    for k in list(sd.keys()):
        if any(re.match(sk, k) for sk in skip_keys):
            continue
        if any(re.match(dk, k) for dk in delete_keys):
            del sd[k]
    return sd
def autocast(f, enabled=True):
    """Wrap callable *f* so it executes under CUDA automatic mixed precision.

    Args:
        f: callable to wrap.
        enabled: toggles autocast inside the returned wrapper.

    Returns:
        A wrapper with *f*'s metadata preserved (functools.wraps was missing
        in the previous version, unlike checkpoint_wrapper in this file).
    """
    @wraps(f)
    def do_autocast(*args, **kwargs):
        # NOTE(review): torch.cuda.amp.autocast is deprecated in recent torch
        # in favor of torch.amp.autocast("cuda", ...); kept for compatibility
        # with the torch version this file targets.
        with torch.cuda.amp.autocast(
            enabled=enabled,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)
    return do_autocast
def checkpoint_wrapper(func):
    """Decorator for methods: route the call through torch activation
    checkpointing whenever `self.checkpoint` is missing or truthy; call the
    method directly otherwise."""
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # Absence of the attribute defaults to checkpointing being on.
        use_ckpt = getattr(self, 'checkpoint', True)
        if not use_ckpt:
            return func(self, *args, **kwargs)

        def bound_func(*inner_args, **inner_kwargs):
            # Re-bind self so checkpoint() only sees the tensor arguments.
            return func(self, *inner_args, **inner_kwargs)

        return checkpoint(bound_func, *args, use_reentrant=False, **kwargs)
    return wrapper
|