File size: 2,807 Bytes
5d61943
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import logging

import torch
import torch.distributed as dist
from torch.nn import functional as F


def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and the console (stderr).

    On rank 0 the root logger is configured via ``logging.basicConfig`` with a
    console ``StreamHandler`` (which defaults to ``sys.stderr``) and a
    ``FileHandler`` writing to ``{logging_dir}/log.txt``. On every other rank
    the returned logger only gets a ``NullHandler`` so it emits nothing.

    If ``torch.distributed`` is unavailable or no default process group has
    been initialized (e.g. single-process debugging), the caller is treated
    as rank 0 instead of raising.

    Args:
        logging_dir: directory in which ``log.txt`` is created (rank 0 only).

    Returns:
        The module-level ``logging.Logger``.
    """
    # dist.get_rank() raises when no process group exists; fall back to rank 0.
    is_distributed = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if is_distributed else 0

    if rank == 0:  # real logger
        logging.basicConfig(
            level=logging.INFO,
            format="[\033[34m%(asctime)s\033[0m] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler(f"{logging_dir}/log.txt"),
            ],
        )
        logger = logging.getLogger(__name__)
    else:  # dummy logger (does nothing)
        logger = logging.getLogger(__name__)
        # A handler must be present so records are not routed to
        # logging.lastResort on non-zero ranks.
        logger.addHandler(logging.NullHandler())
    return logger


@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Each EMA parameter ``e`` whose counterpart ``m`` in ``model`` has
    ``requires_grad == True`` is updated in place as
    ``e <- e + (1 - decay) * (m - e)`` (i.e. ``decay * e + (1 - decay) * m``).
    Parameters whose source is frozen are left untouched.

    Args:
        ema_model: module holding the EMA weights; modified in place.
        model: module being trained; only read. Parameters are paired
            positionally via ``.parameters()``, so both modules are assumed
            to share the same architecture (TODO confirm at call sites).
        decay: EMA decay factor; larger values make the EMA move more slowly.
    """
    ema_ps = []
    ps = []

    for e, m in zip(ema_model.parameters(), model.parameters()):
        if m.requires_grad:
            ema_ps.append(e)
            ps.append(m)

    # torch._foreach_* ops reject empty tensor lists; with a fully frozen
    # model the EMA step should be a no-op, not an error.
    if not ema_ps:
        return
    torch._foreach_lerp_(ema_ps, ps, 1.0 - decay)


@torch.no_grad()
def sync_frozen_params_once(ema_model, model):
    """
    Copy every frozen parameter of ``model`` into ``ema_model`` in place.

    A parameter counts as frozen when the ``model`` side has
    ``requires_grad == False``. Parameters are paired positionally through
    ``.parameters()``. Trainable parameters are not touched.
    """
    paired = zip(ema_model.parameters(), model.parameters())
    for ema_param, src_param in paired:
        if src_param.requires_grad:
            continue
        ema_param.copy_(src_param)


def requires_grad(model, flag=True):
    """
    Enable or disable gradient tracking for every parameter of ``model``.

    Args:
        model: module whose ``.parameters()`` are updated in place.
        flag: value assigned to each parameter's ``requires_grad`` attribute.
    """
    params = model.parameters()
    for param in params:
        param.requires_grad = flag

def patchify_raster(x, p):
    """
    Split an image batch into p x p patches and flatten it into pixel tokens.

    Args:
        x: tensor of shape (B, C, H, W); H and W must be multiples of p.
        p: patch side length.

    Returns:
        Tensor of shape (B, H * W, C). Tokens are grouped by patch (patches
        in row-major order over the image), and the pixels inside each patch
        are also in row-major order.
    """
    batch, channels, height, width = x.shape

    assert height % p == 0 and width % p == 0, (
        f"Image dimensions ({height},{width}) must be divisible by patch size {p}"
    )

    grid_h, grid_w = height // p, width // p

    # (B, C, H, W) -> (B, C, grid_h, p, grid_w, p)
    tiles = x.view(batch, channels, grid_h, p, grid_w, p)
    # -> (B, grid_h, grid_w, p, p, C): patch indices first, channels last
    tiles = tiles.permute(0, 2, 4, 3, 5, 1).contiguous()
    return tiles.reshape(batch, -1, channels)

def unpatchify_raster(x, p, target_shape):
    """
    Fold a patch-ordered token sequence back into an image batch.

    This inverts ``patchify_raster``: ``x.view`` splits the token axis into
    (patch row, patch col, in-patch row, in-patch col), and the permute moves
    channels back in front of the spatial axes.

    Args:
        x: tensor of shape (B, N, C) with N == H * W, tokens grouped by
            p x p patch (patches row-major, pixels row-major within a patch).
        p: patch side length; must divide both H and W.
        target_shape: (H, W) of the output image.

    Returns:
        Tensor of shape (B, C, H, W).
    """
    B, N, C = x.shape
    H, W = target_shape

    # Validate up front, as patchify_raster does, so bad shapes fail with a
    # readable message instead of an opaque view/reshape size error.
    assert H % p == 0 and W % p == 0, f"Image dimensions ({H},{W}) must be divisible by patch size {p}"
    assert N == H * W, f"N ({N}) must equal H*W ({H*W})"

    h_patches = H // p
    w_patches = W // p

    x = x.view(B, h_patches, w_patches, p, p, C)

    # (B, h, w, p_row, p_col, C) -> (B, C, h, p_row, w, p_col)
    x = x.permute(0, 5, 1, 3, 2, 4).contiguous()

    x = x.reshape(B, C, H, W)

    return x

def patchify_raster_2d(x: torch.Tensor, p: int, H: int, W: int) -> torch.Tensor:
    """
    Reorder the rows of a per-position tensor from raster to patch-major order.

    ``x`` holds one (C1, C2) slice per position of an H x W grid, in row-major
    order. The output holds the same slices, grouped by p x p patch: patches
    are visited row by row, and positions inside each patch row by row. This
    is the same token order ``patchify_raster`` produces for an image.

    Args:
        x: tensor of shape (N, C1, C2) with N == H * W. It need not be
            contiguous.
        p: patch side length; must divide both H and W.
        H: grid height.
        W: grid width.

    Returns:
        Tensor of shape (N, C1, C2) with its rows permuted.
    """
    N, C1, C2 = x.shape

    assert N == H * W, f"N ({N}) must equal H*W ({H*W})"
    assert H % p == 0 and W % p == 0, f"H/W ({H}/{W}) must be divisible by patch size {p}"

    C_prime = C1 * C2
    h_patches = H // p
    w_patches = W // p

    # reshape rather than view: view fails when the trailing (C1, C2) dims are
    # not contiguous (e.g. after a transpose). For contiguous input this is
    # still a view, exactly as before.
    x_split = x.reshape(h_patches, p, w_patches, p, C_prime)

    # (h, p_row, w, p_col, C') -> (h, w, p_row, p_col, C')
    x_permuted = x_split.permute(0, 2, 1, 3, 4)

    return x_permuted.reshape(N, C1, C2)