File size: 6,832 Bytes
d066167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import re
import os.path as osp

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as tf
from torch.utils.checkpoint import checkpoint

import numpy as np
import itertools
import importlib

from tqdm import tqdm
from inspect import isfunction
from functools import wraps
from safetensors import safe_open



def exists(x):
    """Return True when *x* is not the None sentinel."""
    return not (x is None)

def append_dims(x, target_dims) -> torch.Tensor:
    """Right-pad *x* with singleton dimensions until it has target_dims dims.

    Raises ValueError when x already has more dims than target_dims.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    # Build an explicit index: keep existing dims, then one None per new dim.
    index = [Ellipsis]
    index.extend([None] * missing)
    return x[tuple(index)]


def default(val, d):
    """Return *val* unless it is None; otherwise return *d* (calling it first
    if it is a plain function, so defaults can be lazily constructed)."""
    if val is None:
        return d() if isfunction(d) else d
    return val

def expand_to_batch_size(x, bs):
    """Tile a tensor (or each tensor in a list) *bs* times along dim 0."""
    def _tile(t):
        reps = [bs] + [1] * (t.dim() - 1)
        return t.repeat(*reps)

    if isinstance(x, list):
        return [_tile(t) for t in x]
    return _tile(x)


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path "pkg.mod.Name" to the attribute it names.

    With reload=True the module is re-imported first so the lookup sees
    freshly loaded code.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    module = importlib.import_module(module_name, package=None)
    return getattr(module, attr_name)

def instantiate_from_config(config):
    """Build the object named by config["target"], passing config["params"].

    The two sentinel strings deliberately map to "no module" and yield None.
    """
    if "target" not in config:
        if config in ('__is_first_stage__', "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    cls = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return cls(**kwargs)


def scaled_resize(x: torch.Tensor, scale_factor, interpolation_mode="bicubic"):
    """Rescale a 4D NCHW tensor by *scale_factor* with the given interpolation."""
    resized = F.interpolate(x, scale_factor=scale_factor, mode=interpolation_mode)
    return resized

def get_crop_scale(h, w, bgh, bgw):
    """Fractional crop (ch, cw) of a (bgh, bgw) background matching a w/h target.

    Whichever image is relatively wider keeps its full extent on that axis;
    the other axis is cropped proportionally so the aspect ratios agree.
    """
    if (w / h) > (bgw / bgh):
        # Target is wider than the background: keep full width, crop height.
        return (h / w) * (bgw / bgh), 1.0
    # Background is wider (or equal): keep full height, crop width.
    return 1.0, (w / h) * (bgh / bgw)

def warp_resize(x: torch.Tensor, target_size, interpolation_mode="bicubic"):
    """Resize a 4D NCHW tensor to *target_size*, ignoring aspect ratio."""
    assert x.ndim == 4
    out = F.interpolate(x, size=target_size, mode=interpolation_mode)
    return out

def resize_and_crop(x: torch.Tensor, ch, cw, th, tw):
    """Crop the top-left (ch*h, cw*w) region of *x*, then resize it to (th, tw)."""
    _, _, h, w = x.shape
    crop_h = int(ch * h)
    crop_w = int(cw * w)
    return tf.resized_crop(x, 0, 0, crop_h, crop_w, size=[th, tw])


def fitting_weights(model, sd):
    """Adapt checkpoint tensors in ``sd`` to the (possibly resized) shapes of
    ``model``'s parameters and buffers by tiling cyclically along the first
    two axes.

    For every named parameter/buffer present in ``sd`` whose shape differs
    from the model's, the old tensor is repeated (modulo-indexed) to fill the
    new shape. In the >=2D case the result is additionally divided by how many
    times each old column was reused, so column sums are preserved.
    Entries missing from ``sd`` are skipped; ``sd`` is mutated and returned.

    NOTE(review): only column (axis-1) reuse is normalized, not row (axis-0)
    reuse — presumably intentional for output-dim duplication; confirm.
    """
    # Pre-count entries only so tqdm can show a progress total.
    n_params = len([name for name, _ in
                    itertools.chain(model.named_parameters(),
                                    model.named_buffers())])
    for name, param in tqdm(
            itertools.chain(model.named_parameters(),
                            model.named_buffers()),
            desc="Fitting old weights to new weights",
            total=n_params
    ):
        if not name in sd:
            continue
        old_shape = sd[name].shape
        new_shape = param.shape
        # Ranks must agree — only the sizes of the leading axes may differ.
        assert len(old_shape) == len(new_shape)
        if len(new_shape) > 2:
            # we only modify first two axes
            assert new_shape[2:] == old_shape[2:]
        # assumes first axis corresponds to output dim
        if not new_shape == old_shape:
            new_param = param.clone()
            old_param = sd[name]
            device = old_param.device
            if len(new_shape) == 1:
                # Vectorized 1D case
                # Tile the old vector cyclically to the new length.
                new_param = old_param[torch.arange(new_shape[0], device=device) % old_shape[0]]
            elif len(new_shape) >= 2:
                # Vectorized 2D case
                # Modulo index grids: row i pulls old row i % old_rows,
                # column j pulls old column j % old_cols.
                i_indices = torch.arange(new_shape[0], device=device)[:, None] % old_shape[0]
                j_indices = torch.arange(new_shape[1], device=device)[None, :] % old_shape[1]

                # Use advanced indexing to extract all values at once
                new_param = old_param[i_indices, j_indices]

                # Count how many times each old column is used
                n_used_old = torch.bincount(
                    torch.arange(new_shape[1], device=device) % old_shape[1],
                    minlength=old_shape[1]
                )

                # Map to new shape
                n_used_new = n_used_old[torch.arange(new_shape[1], device=device) % old_shape[1]]

                # Reshape for broadcasting
                n_used_new = n_used_new.reshape(1, new_shape[1])
                while len(n_used_new.shape) < len(new_shape):
                    n_used_new = n_used_new.unsqueeze(-1)

                # Normalize
                new_param = new_param / n_used_new

            sd[name] = new_param
    return sd


def count_params(model, verbose=False):
    """Return the total element count over all parameters of *model*.

    With verbose=True, also print the count in millions.
    """
    total_params = 0
    for p in model.parameters():
        total_params += p.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params


# Checkpoint file extensions accepted by load_weights.
VALID_FORMATS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"]

def load_weights(path, weights_only=True):
    """Load a checkpoint state dict from *path* onto CPU.

    Handles safetensors files as well as torch pickle formats; a wrapping
    {"state_dict": ...} dict (Lightning-style) is unwrapped automatically.

    Args:
        path: checkpoint file whose extension is one of VALID_FORMATS.
        weights_only: forwarded to torch.load for the pickle formats.

    Returns:
        dict mapping tensor names to CPU tensors.
    """
    ext = osp.splitext(path)[-1]
    assert ext in VALID_FORMATS, f"Invalid checkpoint format {ext}"
    if ext == ".safetensors":
        # Use the handle as a context manager so the file is closed promptly
        # (the previous version leaked the open handle).
        with safe_open(path, framework="pt", device="cpu") as safe_sd:
            sd = {key: safe_sd.get_tensor(key) for key in safe_sd.keys()}
    else:
        sd = torch.load(path, map_location="cpu", weights_only=weights_only)
        if "state_dict" in sd.keys():
            sd = sd["state_dict"]
    return sd


def delete_states(sd, delete_keys: list[str] = (), skip_keys: list[str] = ()):
    """Remove entries of *sd* whose key matches any pattern in *delete_keys*.

    Patterns use re.match semantics (anchored at the start of the key).
    A key matched by any pattern in *skip_keys* is protected and kept even
    if a delete pattern also matches. Mutates and returns *sd*.

    Fixes two defects of the previous version: a key matching several delete
    patterns was `del`eted twice (KeyError), and with multiple skip patterns
    a key was deleted unless *all* of them matched instead of being protected
    by *any* match.
    """
    for k in list(sd.keys()):
        if any(re.match(sk, k) for sk in skip_keys):
            continue  # protected by a skip pattern
        if any(re.match(ik, k) for ik in delete_keys):
            del sd[k]
    return sd


def autocast(f, enabled=True):
    """Wrap *f* so it executes inside a CUDA autocast context.

    The context inherits the current autocast GPU dtype and cache setting.

    Args:
        f: callable to wrap.
        enabled: whether autocast is actually active inside the wrapper.

    Returns:
        The wrapped callable. Uses functools.wraps so f's __name__/__doc__
        survive (the previous version dropped them), consistent with
        checkpoint_wrapper in this module.
    """
    @wraps(f)
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(
            enabled=enabled,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)

    return do_autocast


def checkpoint_wrapper(func):
    """Decorator: route a method through gradient checkpointing.

    Checkpointing (non-reentrant) is used unless the instance carries a
    falsy ``checkpoint`` attribute; absence of the attribute means "on".
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        use_ckpt = getattr(self, 'checkpoint', True)
        if not use_ckpt:
            return func(self, *args, **kwargs)

        def bound_func(*a, **kw):
            # Close over self so checkpoint only sees tensor-like args.
            return func(self, *a, **kw)

        return checkpoint(bound_func, *args, use_reentrant=False, **kwargs)
    return wrapper