| import functools |
| import importlib |
| import os |
| import fsspec |
| import numpy as np |
| import torch |
|
|
| from dataclasses import dataclass |
| from functools import partial |
| from inspect import isfunction |
| from PIL import Image, ImageDraw, ImageFont |
| from safetensors.torch import load_file |
| from tqdm import tqdm |
|
|
|
|
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.

    File names in ``sample_dir`` are ordered by the integer that precedes the
    first "." in the name; the first ``num`` files are loaded, stacked and
    saved to ``f"{sample_dir}.npz"`` under the key ``arr_0``.

    Args:
        sample_dir: Directory holding the sample images.
        num: Number of images to pack; the folder must contain at least this many.

    Returns:
        The path of the written .npz file.
    """
    imgs = sorted(os.listdir(sample_dir), key=lambda x: int(x.split(".")[0]))
    print(len(imgs))
    assert len(imgs) >= num
    samples = []
    for img_name in tqdm(imgs[:num], desc="Building .npz file from samples"):
        # Close every image file right after its pixels are copied out, so a
        # large folder does not accumulate open file handles.
        with Image.open(f"{sample_dir}/{img_name}") as sample_pil:
            sample_np = np.asarray(sample_pil).astype(np.uint8)
        samples.append(sample_np)
    samples = np.stack(samples)
    # Sanity check: the stacked array must be 4-D with 3 channels last.
    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
    return npz_path
|
|
|
|
def init_from_ckpt(model, checkpoint_dir, ignore_keys=None, verbose=False) -> None:
    """
    Load weights from a checkpoint file into ``model`` (non-strict).

    Args:
        model: Module whose ``load_state_dict`` is called with ``strict=False``.
        checkpoint_dir: Path to the checkpoint file. Paths ending in
            ``.safetensors`` are read with ``load_file``; anything else goes
            through ``torch.load``.
        ignore_keys: Optional iterable of substrings; every state-dict key
            containing at least one of them is dropped before loading.
        verbose: When true, print a summary line and a trailing blank line.

    Missing and unexpected keys are always printed when present.
    """
    if checkpoint_dir.endswith(".safetensors"):
        model_state_dict = load_file(checkpoint_dir, device="cpu")
    else:
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        model_state_dict = torch.load(checkpoint_dir, map_location="cpu")
    model_new_ckpt = dict(model_state_dict)
    if ignore_keys:
        for k in list(model_new_ckpt):
            # Delete each key at most once, even when several ignore patterns
            # match it (a second `del` would raise KeyError).
            if any(ik in k for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del model_new_ckpt[k]
    missing, unexpected = model.load_state_dict(model_new_ckpt, strict=False)
    if verbose:
        print(
            f"Restored with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
    if len(missing) > 0:
        print(f"Missing Keys: {missing}")
    if len(unexpected) > 0:
        print(f"Unexpected Keys: {unexpected}")
    if verbose:
        print("")
|
|
|
|
def get_dtype(str_dtype):
    """Map a precision tag to a torch dtype.

    ``"fp16"`` gives ``torch.float16``, ``"bf16"`` gives ``torch.bfloat16``;
    every other value falls back to ``torch.float32``.
    """
    return (
        torch.float16
        if str_dtype == "fp16"
        else torch.bfloat16
        if str_dtype == "bf16"
        else torch.float32
    )
|
|
|
|
def disabled_train(self, mode=True):
    """Drop-in replacement for ``model.train`` that does nothing.

    Assign this over a model's ``train`` method so later train/eval switches
    have no effect. ``mode`` is accepted and ignored; the receiver is
    returned unchanged.
    """
    return self
|
|
|
|
def get_string_from_tuple(s):
    """
    Return the first element of a tuple written as a string, else ``s`` itself.

    If ``s`` is a string that starts with "(" and ends with ")" and parses as
    a non-empty Python tuple literal, the tuple's first element is returned.
    In every other case (not a string, not parenthesised, not a literal, not
    a tuple, empty tuple) the input is returned unchanged.

    The text is parsed with ``ast.literal_eval`` rather than ``eval``, so only
    literals are accepted and arbitrary expressions in ``s`` are never executed.
    """
    import ast

    if not isinstance(s, str) or not (s.startswith("(") and s.endswith(")")):
        return s
    try:
        t = ast.literal_eval(s)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a plain literal: keep the best-effort contract and hand back s.
        return s
    if isinstance(t, tuple) and t:
        return t[0]
    return s
|
|
|
|
def is_power_of_two(n):
    """
    Return True if n is a power of 2, otherwise return False.

    Values less than or equal to zero are never powers of two. For positive
    n the check relies on the binary representation: a power of two has
    exactly one bit set, and subtracting 1 clears that bit while setting all
    lower bits, so ``n & (n - 1)`` is zero only for powers of two.
    """
    if n <= 0:
        return False
    below = n - 1
    return (n & below) == 0
|
|
|
|
def autocast(f, enabled=True):
    """
    Wrap ``f`` so that every call runs inside ``torch.cuda.amp.autocast``.

    Args:
        f: Callable to wrap.
        enabled: Passed through as the autocast ``enabled`` flag.

    Returns:
        A wrapper that forwards all arguments to ``f``. The autocast dtype and
        cache setting are read from torch's current autocast state at call
        time. The wrapper keeps ``f``'s name and docstring.
    """

    @functools.wraps(f)
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(
            enabled=enabled,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)

    return do_autocast
|
|
|
|
def load_partial_from_config(config):
    """Build a ``functools.partial`` from a ``target``/``params`` config.

    ``config["target"]`` is resolved with ``get_obj_from_str`` and the
    optional ``config["params"]`` mapping is bound as keyword arguments.
    """
    target = get_obj_from_str(config["target"])
    bound_kwargs = config.get("params", dict())
    return partial(target, **bound_kwargs)
|
|
|
|
def log_txt_as_img(wh, xc, size=10):
    """
    Render each caption in ``xc`` as black text on a white RGB image.

    Args:
        wh: Image size handed to ``Image.new``; ``wh[0]`` also determines how
            many characters are placed on a line before wrapping.
        xc: Sequence of captions. An entry that is a list contributes its
            first element; any other entry is used as is.
        size: Font size used with ``data/DejaVuSans.ttf``.

    Returns:
        A tensor stacking one array per caption, transposed to channel-first
        order and rescaled from [0, 255] to [-1, 1].
    """
    # Both values are the same for every caption, so compute them once.
    font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
    # Keep at least one character per line; a width of 0 would make the
    # wrapping ``range`` below fail on very narrow images.
    nc = max(1, int(40 * (wh[0] / 256)))

    txts = []
    for caption in xc:
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        text_seq = caption[0] if isinstance(caption, list) else caption
        lines = "\n".join(
            text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
        )

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            # Best effort: leave the image blank rather than abort logging.
            print("Cant encode string for logging. Skipping.")

        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts
|
|
|
|
def partialclass(cls, *args, **kwargs):
    """Return a subclass of ``cls`` whose ``__init__`` has arguments pre-bound.

    The given positional and keyword arguments are bound to ``cls.__init__``
    through ``functools.partialmethod``; callers supply only the remainder.
    """
    prebound_init = functools.partialmethod(cls.__init__, *args, **kwargs)

    class NewCls(cls):
        __init__ = prebound_init

    return NewCls
|
|
|
|
def make_path_absolute(path):
    """
    Return an absolute path for local files; leave other URLs untouched.

    ``path`` is resolved with ``fsspec.core.url_to_fs``. If the resulting
    filesystem is the local one, the stripped path is returned through
    ``os.path.abspath``; otherwise ``path`` is returned unchanged.
    """
    fs, p = fsspec.core.url_to_fs(path)
    # ``protocol`` may be one string or a tuple of aliases (for example
    # ("file", "local")), so normalise before checking for "file".
    protocols = fs.protocol
    if isinstance(protocols, str):
        protocols = (protocols,)
    if "file" in protocols:
        return os.path.abspath(p)
    return path
|
|
|
|
def ismap(x):
    """True for a 4-D torch tensor whose second dimension is larger than 3."""
    return isinstance(x, torch.Tensor) and x.dim() == 4 and x.shape[1] > 3
|
|
|
|
def isimage(x):
    """True for a 4-D torch tensor whose second dimension is 3 or 1."""
    return isinstance(x, torch.Tensor) and x.dim() == 4 and x.shape[1] in (3, 1)
|
|
|
|
def isheatmap(x):
    """True for a torch tensor with exactly two dimensions."""
    return isinstance(x, torch.Tensor) and x.ndim == 2
|
|
|
|
def isneighbors(x):
    """True for a 5-D torch tensor whose third dimension is 3 or 1."""
    return isinstance(x, torch.Tensor) and x.ndim == 5 and x.shape[2] in (3, 1)
|
|
|
|
def exists(x):
    """Tell whether ``x`` is anything other than ``None``."""
    if x is None:
        return False
    return True
|
|
|
|
def expand_dims_like(x, y):
    """
    Append trailing singleton dimensions to ``x`` until it has as many
    dimensions as ``y``.

    If ``x`` already has at least as many dimensions as ``y`` it is returned
    unchanged. The loop condition is ``<`` rather than ``!=`` so that an
    ``x`` with more dimensions than ``y`` cannot loop forever.
    """
    while x.dim() < y.dim():
        x = x.unsqueeze(-1)
    return x
|
|
|
|
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case fall back to ``d``.

    When ``d`` is a plain Python function it is called and its result is the
    fallback; any other ``d`` is returned as is.
    """
    if exists(val):
        return val
    if isfunction(d):
        return d()
    return d
|
|
|
|
def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Average over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)
|
|
|
|
def count_params(model, verbose=False):
    """Count the elements of all parameters of ``model``.

    With ``verbose`` set, also print the count in millions together with the
    model's class name. Returns the raw element count.
    """
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.0e-6:.2f} M params.")
    return total_params
|
|
|
|
def instantiate_from_config(config):
    """Instantiate the object described by a ``target``/``params`` config.

    ``config["target"]`` is resolved with ``get_obj_from_str`` and called with
    the optional ``config["params"]`` as keyword arguments. The marker values
    ``"__is_first_stage__"`` and ``"__is_unconditional__"`` yield ``None``.

    Raises:
        KeyError: if ``config`` has no ``target`` and is not a marker value.
    """
    if "target" in config:
        constructor = get_obj_from_str(config["target"])
        return constructor(**config.get("params", dict()))
    if config == "__is_first_stage__" or config == "__is_unconditional__":
        return None
    raise KeyError("Expected key `target` to instantiate.")
|
|
|
|
def get_obj_from_str(string, reload=False, invalidate_cache=True):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to the named object.

    Args:
        string: Dotted path; everything before the last "." is the module,
            the rest is the attribute looked up on it.
        reload: When true, the module is imported and reloaded first.
        invalidate_cache: When true, ``importlib.invalidate_caches()`` is
            called before importing.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if invalidate_cache:
        importlib.invalidate_caches()
    if reload:
        importlib.reload(importlib.import_module(module_path))
    resolved_module = importlib.import_module(module_path, package=None)
    return getattr(resolved_module, attr_name)
|
|
|
|
def append_zero(x):
    """Return ``x`` with one zero element concatenated at the end.

    The zero is created with ``x.new_zeros``, so it shares ``x``'s dtype and
    device.
    """
    trailing_zero = x.new_zeros([1])
    return torch.cat((x, trailing_zero))
|
|
|
|
def append_dims(x, target_dims):
    """Pad ``x`` with trailing singleton axes until it has ``target_dims`` dimensions.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        trailing_axes = (None,) * missing
        return x[(...,) + trailing_axes]
    raise ValueError(
        f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
    )
|
|
|
|
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
    """
    Instantiate ``config.model`` and load weights from ``ckpt`` (non-strict).

    Args:
        config: Object whose ``model`` attribute is passed to
            ``instantiate_from_config``.
        ckpt: Checkpoint path. The suffix selects the loader: ``ckpt`` (a
            dict with a ``state_dict`` entry, read with ``torch.load``),
            ``safetensors`` (read with ``load_file``) or ``bin`` (a plain
            state dict read with ``torch.load``).
        verbose: Print missing / unexpected keys when there are any.
        freeze: Accepted for interface compatibility; not used in this
            function.

    Returns:
        The model, switched to eval mode.

    Raises:
        NotImplementedError: if ``ckpt`` has none of the supported suffixes.
    """
    print(f"Loading model from {ckpt}")
    if ckpt.endswith("ckpt"):
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        sd = pl_sd["state_dict"]
    elif ckpt.endswith("safetensors"):
        # Use the loader this module imports; the previously referenced name
        # `load_safetensors` is not defined in this file.
        sd = load_file(ckpt)
    elif ckpt.endswith("bin"):
        sd = torch.load(ckpt, map_location="cpu")
    else:
        raise NotImplementedError(f"Unsupported checkpoint format: {ckpt}")

    model = instantiate_from_config(config.model)

    m, u = model.load_state_dict(sd, strict=False)

    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.eval()
    return model
|
|
|
|
def format_number(num):
    """Format ``num`` in thousands with no decimals and a ``k`` suffix.

    ``num`` is converted with ``float`` first, so numeric strings are accepted.
    """
    thousands = float(num) / 1000.0
    return f"{thousands:.0f}k"
|
|
|
|
def get_num_params(model: torch.nn.ModuleList) -> int:
    """Return the summed element count of every parameter in ``model``."""
    sizes = [param.numel() for param in model.parameters()]
    return sum(sizes)
|
|
|
|
def get_num_flop_per_token(num_params, model_config, seq_len) -> int:
    """Estimate FLOPs per token from parameter count and model shape.

    Computes ``6 * num_params + 12 * n_layers * n_heads * head_dim * seq_len``
    where ``head_dim`` is ``model_config.dim // model_config.n_heads``.
    """
    n_layers = model_config.n_layers
    n_heads = model_config.n_heads
    head_dim = model_config.dim // n_heads
    shape_term = 12 * n_layers * n_heads * head_dim * seq_len
    return 6 * num_params + shape_term
|
|
|
|
def get_num_flop_per_sequence_encoder_only(num_params, model_config, seq_len) -> int:
    """Estimate FLOPs per sequence from parameter count and model shape.

    Computes ``6 * num_params + 12 * n_layers * n_heads * head_dim * seq_len**2``
    where ``head_dim`` is ``model_config.dim // model_config.n_heads``.
    """
    n_layers = model_config.n_layers
    n_heads = model_config.n_heads
    head_dim = model_config.dim // n_heads
    shape_term = 12 * n_layers * n_heads * head_dim * seq_len * seq_len
    return 6 * num_params + shape_term
|
|
|
|
| |
def get_peak_flops(device_name: str) -> int:
    """Look up a peak FLOPs figure from substrings of the device name.

    "A100" is checked first, then "H100" with its "NVL" and "PCIe" variants;
    any other device name gets the same figure as the A100.
    """
    if "A100" in device_name:
        return 312e12
    if "H100" not in device_name:
        # Unrecognised device: same value as the A100 branch.
        return 312e12
    if "NVL" in device_name:
        return 1979e12
    if "PCIe" in device_name:
        return 756e12
    return 989e12
|
|
|
|
@dataclass(frozen=True)
class Color:
    """ANSI escape sequences that switch the terminal foreground colour."""

    # Colour codes 30-37.
    black = "\033[30m"
    red = "\033[31m"
    green = "\033[32m"
    yellow = "\033[33m"
    blue = "\033[34m"
    magenta = "\033[35m"
    cyan = "\033[36m"
    white = "\033[37m"
    # Code 39.
    reset = "\033[39m"
|
|
|
|
@dataclass(frozen=True)
class NoColor:
    """Same attribute names as ``Color``, with every value an empty string."""

    black = ""
    red = ""
    green = ""
    yellow = ""
    blue = ""
    magenta = ""
    cyan = ""
    white = ""
    reset = ""
|
|