import logging

import torch
import torch.distributed as dist
from torch.nn import functional as F


def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and stdout.
    Only rank 0 logs; every other rank gets a logger with a NullHandler.
    """
    if dist.get_rank() == 0:
        logging.basicConfig(
            level=logging.INFO,
            format="[\033[34m%(asctime)s\033[0m] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler(f"{logging_dir}/log.txt"),
            ],
        )
        logger = logging.getLogger(__name__)
    else:
        logger = logging.getLogger(__name__)
        logger.addHandler(logging.NullHandler())
    return logger


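# A minimal usage sketch, not taken from this module: create_logger assumes
# torch.distributed is already initialized and that logging_dir exists.
# The backend/init_method below are placeholder assumptions for a single
# local process.
def _create_logger_example(logging_dir="logs"):
    import os

    os.makedirs(logging_dir, exist_ok=True)
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method="tcp://127.0.0.1:29500",
            rank=0,
            world_size=1,
        )
    logger = create_logger(logging_dir)
    logger.info("logging from rank %d", dist.get_rank())

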
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.
    """
    ema_ps = []
    ps = []
    # Only parameters that are being trained are blended into the EMA copy;
    # frozen parameters are handled once by sync_frozen_params_once instead.
    for e, m in zip(ema_model.parameters(), model.parameters()):
        if m.requires_grad:
            ema_ps.append(e)
            ps.append(m)
    # ema = decay * ema + (1 - decay) * param, applied to all tensors at once.
    torch._foreach_lerp_(ema_ps, ps, 1.0 - decay)


@torch.no_grad()
def sync_frozen_params_once(ema_model, model):
    """
    Copy frozen (requires_grad=False) parameters from the model into the EMA
    model a single time; update_ema skips them on every step.
    """
    for e, m in zip(ema_model.parameters(), model.parameters()):
        if not m.requires_grad:
            e.copy_(m)


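# A minimal usage sketch (illustrative assumptions only, not code from this
# repo): keep an EMA copy of a model during training. Trainable parameters
# are blended every step with update_ema; frozen parameters are copied across
# exactly once with sync_frozen_params_once, since update_ema skips them.
def _ema_usage_example():
    import copy

    model = torch.nn.Linear(4, 4)
    model.bias.requires_grad_(False)      # pretend the bias is frozen
    ema = copy.deepcopy(model)
    ema.requires_grad_(False)             # the EMA copy is never trained
    sync_frozen_params_once(ema, model)   # frozen params copied once up front

    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(3):
        loss = model(torch.randn(2, 4)).sum()
        loss.backward()
        opt.step()
        opt.zero_grad()
        update_ema(ema, model, decay=0.999)   # trainable params blended each step

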
def requires_grad(model, flag=True):
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag


def patchify_raster(x, p):
    """
    Rearrange an image batch (B, C, H, W) into per-pixel tokens (B, H*W, C)
    in patch-raster order: patches row-major over the image, pixels row-major
    within each patch.
    """
    B, C, H, W = x.shape
    assert H % p == 0 and W % p == 0, f"Image dimensions ({H},{W}) must be divisible by patch size {p}"

    h_patches = H // p
    w_patches = W // p

    # (B, C, h_patches, p, w_patches, p)
    x = x.view(B, C, h_patches, p, w_patches, p)
    # (B, h_patches, w_patches, p, p, C)
    x = x.permute(0, 2, 4, 3, 5, 1).contiguous()
    # (B, H*W, C)
    x = x.reshape(B, -1, C)
    return x


def unpatchify_raster(x, p, target_shape):
    """
    Inverse of patchify_raster: rebuild (B, C, H, W) from tokens (B, H*W, C)
    laid out in patch-raster order.
    """
    B, N, C = x.shape
    H, W = target_shape

    h_patches = H // p
    w_patches = W // p

    # (B, h_patches, w_patches, p, p, C)
    x = x.view(B, h_patches, w_patches, p, p, C)
    # (B, C, h_patches, p, w_patches, p)
    x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
    x = x.reshape(B, C, H, W)
    return x


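# A quick self-consistency sketch (illustrative only): unpatchify_raster is
# the exact inverse of patchify_raster, since both are pure view/permute
# reorderings with no arithmetic.
def _patchify_roundtrip_check():
    x = torch.randn(2, 3, 16, 8)
    tokens = patchify_raster(x, p=4)                       # (2, 16 * 8, 3)
    recon = unpatchify_raster(tokens, p=4, target_shape=(16, 8))
    assert torch.equal(recon, x)

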
def patchify_raster_2d(x: torch.Tensor, p: int, H: int, W: int) -> torch.Tensor:
    """
    Reorder a flat token sequence (N, C1, C2), with N = H*W tokens in image
    raster order, into the same patch-raster order produced by patchify_raster.
    """
    N, C1, C2 = x.shape
    assert N == H * W, f"N ({N}) must equal H*W ({H*W})"
    assert H % p == 0 and W % p == 0, f"H/W ({H}/{W}) must be divisible by patch size {p}"

    C_prime = C1 * C2
    x_flat = x.view(N, C_prime)
    x_2d = x_flat.view(H, W, C_prime)

    h_patches = H // p
    w_patches = W // p

    # (h_patches, p, w_patches, p, C') -> (h_patches, w_patches, p, p, C')
    x_split = x_2d.view(h_patches, p, w_patches, p, C_prime)
    x_permuted = x_split.permute(0, 2, 1, 3, 4)
    # Flatten back to (N, C'), now grouped patch by patch, and restore (N, C1, C2).
    x_reordered = x_permuted.reshape(N, C_prime)
    out = x_reordered.view(N, C1, C2)
    return out


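# A quick self-consistency sketch (illustrative only): patchify_raster_2d
# applies the same raster-to-patch token reordering as patchify_raster, just
# on a flat (N, C1, C2) token sequence instead of a (B, C, H, W) image.
def _patchify_2d_consistency_check():
    C, H, W, p = 6, 8, 8, 4
    img = torch.randn(1, C, H, W)
    # Flatten the image into raster-order tokens of shape (N, C, 1).
    tokens = img.permute(0, 2, 3, 1).reshape(H * W, C, 1)
    out_tokens = patchify_raster_2d(tokens, p, H, W).reshape(H * W, C)
    out_img = patchify_raster(img, p)[0]                   # (H * W, C)
    assert torch.equal(out_tokens, out_img)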