import logging

import torch
import torch.distributed as dist


def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and stdout.
    """
    if dist.get_rank() == 0:  # real logger on the main process
        logging.basicConfig(
            level=logging.INFO,
            format="[\033[34m%(asctime)s\033[0m] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler(f"{logging_dir}/log.txt"),
            ],
        )
        logger = logging.getLogger(__name__)
    else:  # dummy logger (does nothing) on all other ranks
        logger = logging.getLogger(__name__)
        logger.addHandler(logging.NullHandler())
    return logger
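
# Minimal usage sketch (assumes torch.distributed has already been initialized,
# e.g. via dist.init_process_group(); the directory path is a placeholder):
#
#   logger = create_logger("results/run_0")
#   logger.info("Training started")  # printed and written to log.txt on rank 0 only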


@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.
    """
    # Collect only trainable parameters; frozen ones never change and are
    # copied once via sync_frozen_params_once.
    ema_ps = []
    ps = []
    for e, m in zip(ema_model.parameters(), model.parameters()):
        if m.requires_grad:
            ema_ps.append(e)
            ps.append(m)
    # Fused in-place update: ema = decay * ema + (1 - decay) * param.
    torch._foreach_lerp_(ema_ps, ps, 1.0 - decay)


@torch.no_grad()
def sync_frozen_params_once(ema_model, model):
    """
    Copy parameters that are excluded from the EMA update (requires_grad=False)
    from the model into the EMA model. Only needs to run once, since frozen
    parameters do not change during training.
    """
    for e, m in zip(ema_model.parameters(), model.parameters()):
        if not m.requires_grad:
            e.copy_(m)
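
# Minimal training-loop sketch showing how the EMA helpers fit together
# (`model`, `opt`, `loader`, and `loss_fn` are placeholders, not part of
# this module; requires `import copy`):
#
#   ema = copy.deepcopy(model)
#   requires_grad(ema, False)
#   sync_frozen_params_once(ema, model)  # one-time copy of frozen weights
#   for batch in loader:
#       loss = loss_fn(model(batch))
#       loss.backward()
#       opt.step()
#       opt.zero_grad()
#       update_ema(ema, model)           # after every optimizer step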


def requires_grad(model, flag=True):
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag


def patchify_raster(x, p):
    """
    Reorder a (B, C, H, W) image into a (B, H*W, C) token sequence in which
    pixels are grouped patch-by-patch (raster order within each p x p patch).
    """
    B, C, H, W = x.shape
    assert H % p == 0 and W % p == 0, f"Image dimensions ({H},{W}) must be divisible by patch size {p}"
    h_patches = H // p
    w_patches = W // p
    # (B, C, h_patches, p, w_patches, p) -> (B, h_patches, w_patches, p, p, C)
    x = x.view(B, C, h_patches, p, w_patches, p)
    x = x.permute(0, 2, 4, 3, 5, 1).contiguous()
    x = x.reshape(B, -1, C)
    return x


def unpatchify_raster(x, p, target_shape):
    """
    Inverse of patchify_raster: fold a (B, H*W, C) patch-ordered token
    sequence back into a (B, C, H, W) image.
    """
    B, N, C = x.shape
    H, W = target_shape
    assert N == H * W, f"N ({N}) must equal H*W ({H * W})"
    h_patches = H // p
    w_patches = W // p
    # (B, h_patches, w_patches, p, p, C) -> (B, C, h_patches, p, w_patches, p)
    x = x.view(B, h_patches, w_patches, p, p, C)
    x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
    x = x.reshape(B, C, H, W)
    return x
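
# Round-trip sanity check (a quick sketch; shapes chosen arbitrarily):
#
#   x = torch.randn(2, 4, 32, 32)
#   tokens = patchify_raster(x, p=8)  # (2, 1024, 4)
#   assert torch.equal(unpatchify_raster(tokens, 8, (32, 32)), x)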


def patchify_raster_2d(x: torch.Tensor, p: int, H: int, W: int) -> torch.Tensor:
    """
    Reorder a (N, C1, C2) token sequence (N = H*W, tokens in row-major image
    order) so that tokens are grouped patch-by-patch, preserving the
    per-token (C1, C2) channel shape.
    """
    N, C1, C2 = x.shape
    assert N == H * W, f"N ({N}) must equal H*W ({H * W})"
    assert H % p == 0 and W % p == 0, f"H/W ({H}/{W}) must be divisible by patch size {p}"
    C_prime = C1 * C2
    # Flatten the per-token channels and lay the tokens out on the image grid.
    x_flat = x.view(N, C_prime)
    x_2d = x_flat.view(H, W, C_prime)
    h_patches = H // p
    w_patches = W // p
    # (h_patches, p, w_patches, p, C') -> (h_patches, w_patches, p, p, C')
    x_split = x_2d.view(h_patches, p, w_patches, p, C_prime)
    x_permuted = x_split.permute(0, 2, 1, 3, 4)
    x_reordered = x_permuted.reshape(N, C_prime)
    out = x_reordered.view(N, C1, C2)
    return out
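
# Usage sketch (arbitrary shapes): reorder 16*16 = 256 tokens, each carrying a
# (4, 8) channel block, into 4x4-patch order. The shape is unchanged; only the
# token order along dim 0 differs:
#
#   x = torch.randn(256, 4, 8)
#   out = patchify_raster_2d(x, p=4, H=16, W=16)  # still (256, 4, 8)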