| |
| import os |
| import random |
| import math |
|
|
| import numpy as np |
| from tqdm import tqdm |
| from omegaconf import OmegaConf |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
| from Models.models.transformer import MaskTransformer |
| from Models.models.vqgan import VQModel |
|
|
|
|
class MaskGIT(nn.Module):
    """ MaskGIT sampler: a class-conditional MaskTransformer ("vit") that predicts
    VQGAN token ids, plus the VQGAN autoencoder ("ae") that decodes them to images. """

    # Per-size transformer hyper-parameters; the remaining MaskTransformer
    # arguments are shared by every size (see _build_vit).
    _VIT_CONFIGS = {
        "base": {"hidden_dim": 768, "depth": 24},
        "big": {"hidden_dim": 1024, "depth": 32},
        "huge": {"hidden_dim": 1024, "depth": 48},
    }

    def __init__(self, args):
        """ Initialization of the model (VQGAN and Masked Transformer) and of the AMP grad scaler.
        :param
            args -> namespace: run configuration; this class reads img_size, vit_size, resume,
                    vit_folder, vqgan_folder, device, is_master, is_multi_gpus and mask_value
        """
        super().__init__()
        self.args = args
        # Side length of the token grid: img_size // 16 (presumably the VQGAN
        # downsampling factor -- TODO confirm against the autoencoder config).
        self.patch_size = self.args.img_size // 16
        self.scaler = torch.cuda.amp.GradScaler()
        self.vit = self.get_network("vit")
        self.ae = self.get_network("autoencoder")

    def get_network(self, archi):
        """ Return the requested network; the vit is resumed from a checkpoint if self.args.resume.
        :param
            archi -> str: vit|autoencoder, the architecture to load
        :return
            model -> nn.Module: the network, or None for an unknown archi
        :raise
            ValueError: if archi == "vit" and self.args.vit_size is not base|big|huge
        """
        if archi == "vit":
            model = self._build_vit()
        elif archi == "autoencoder":
            model = self._build_autoencoder()
        else:
            model = None

        # Only report a size when a network was actually built.
        if model is not None and self.args.is_master:
            print(f"Size of model {archi}: "
                  f"{sum(p.numel() for p in model.parameters() if p.requires_grad) / 10 ** 6:.3f}M")

        return model

    def _build_vit(self):
        """ Build the MaskTransformer for self.args.vit_size, optionally load the resume
        checkpoint, move it to self.args.device and wrap it in DDP when multi-GPU. """
        try:
            size_cfg = self._VIT_CONFIGS[self.args.vit_size]
        except KeyError:
            raise ValueError(
                f"Unknown vit_size {self.args.vit_size!r}, expected one of {sorted(self._VIT_CONFIGS)}"
            ) from None
        model = MaskTransformer(
            img_size=self.args.img_size, codebook_size=1024, heads=16, mlp_dim=3072, dropout=0.1,
            **size_cfg,
        )

        if self.args.resume:
            # vit_folder is either a checkpoint file or a directory containing current.pth.
            ckpt = self.args.vit_folder
            if os.path.isdir(ckpt):
                ckpt = os.path.join(ckpt, "current.pth")
            if self.args.is_master:
                print("load ckpt from:", ckpt)
            # NOTE(review): depending on the torch version's weights_only default, torch.load
            # may unpickle arbitrary objects -- only load trusted checkpoints.
            checkpoint = torch.load(ckpt, map_location='cpu')
            # strict=False: missing / unexpected keys are ignored rather than raising.
            model.load_state_dict(checkpoint['model_state_dict'], strict=False)

        model = model.to(self.args.device)
        if self.args.is_multi_gpus:
            model = DDP(model, device_ids=[self.args.device])
        return model

    def _build_autoencoder(self):
        """ Build the VQGAN from vqgan_folder/model.yaml, load vqgan_folder/last.ckpt,
        switch it to eval mode and move it to self.args.device. """
        config = OmegaConf.load(os.path.join(self.args.vqgan_folder, "model.yaml"))
        model = VQModel(**config.model.params)
        checkpoint = torch.load(os.path.join(self.args.vqgan_folder, "last.ckpt"), map_location="cpu")["state_dict"]
        model.load_state_dict(checkpoint, strict=False)
        model = model.eval()
        model = model.to(self.args.device)

        if self.args.is_multi_gpus:
            # The DDP wrapper is discarded immediately; presumably it is built only for its
            # construction-time sync of weights across processes -- TODO confirm intent.
            model = DDP(model, device_ids=[self.args.device])
            model = model.module
        return model

    def adap_sche(self, step, mode="arccos", leave=False):
        """ Create a sampling scheduler
        :param
            step  -> int: number of predictions (decoding steps) during inference
            mode  -> str: root|linear|square|cosine|arccos, the shape of the unmasking schedule
            leave -> bool: tqdm arg on whether to keep the progress bar once finished
        :return
            scheduler -> tqdm over a 1-D int tensor: the number of tokens to predict at each step;
                         every entry but the last is at least 1 and the entries sum to
                         patch_size * patch_size
        :raise
            ValueError: if mode is not one of the supported names
        """
        r = torch.linspace(1, 0, step)
        if mode == "root":
            val_to_mask = 1 - (r ** .5)
        elif mode == "linear":
            val_to_mask = 1 - r
        elif mode == "square":
            val_to_mask = 1 - (r ** 2)
        elif mode == "cosine":
            val_to_mask = torch.cos(r * math.pi * 0.5)
        elif mode == "arccos":
            val_to_mask = torch.arccos(r) / (math.pi * 0.5)
        else:
            raise ValueError(f"Unknown scheduler mode {mode!r}, expected root|linear|square|cosine|arccos")

        n_tokens = self.patch_size * self.patch_size
        # Normalise the curve into per-step token counts that sum to n_tokens.
        sche = (val_to_mask / val_to_mask.sum()) * n_tokens
        sche = sche.round()
        # Every step predicts at least one token...
        sche[sche == 0] = 1
        # ...and the last step absorbs the rounding error so the total is exact.
        sche[-1] += n_tokens - sche.sum()
        return tqdm(sche.int(), leave=leave)

    def sample(self, init_code=None, nb_sample=50, labels=None, sm_temp=1, w=3,
               randomize="linear", r_temp=4.5, sched_mode="arccos", step=12):
        """ Generate samples with the MaskGIT model by iterative parallel decoding.
        :param
            init_code  -> torch.LongTensor: nb_sample x patch_size x patch_size, the starting code;
                          entries equal to 1024 are treated as masked. It is not modified.
            nb_sample  -> int: the number of images to generate
            labels     -> torch.LongTensor: the class of each sample; when None, 10 fixed classes
                          (the last one drawn at random once) are cycled to nb_sample entries
            sm_temp    -> float: multiplier applied to the logits before the softmax
            w          -> float: classifier-free guidance scale, ramped linearly from 0 to w
            randomize  -> str: linear|warm_up|random|no, how the confidence used to choose which
                          tokens to keep is perturbed (any other value: no perturbation)
            r_temp     -> float: scale of the Gumbel noise when randomize == "linear"
            sched_mode -> str|iterable: root|linear|square|cosine|arccos (see adap_sche), or an
                          explicit sized iterable of token counts to unmask at each step
            step       -> int: number of decoding steps when sched_mode is a str
        :return
            x       -> torch.FloatTensor: the decoded images, mapped from [-1, 1] to [0, 1]
            l_codes -> list[torch.LongTensor]: per executed step, the nb_sample x patch_size x patch_size
                       sampled code (decoding stops early once every token is predicted)
            l_mask  -> list[torch.FloatTensor]: per executed step, the nb_sample x patch_size x patch_size
                       mask after that step (1 = still masked)
        The ViT's train/eval mode is restored on exit.
        """
        was_training = self.vit.training
        self.vit.eval()
        l_codes = []
        l_mask = []
        device = self.args.device
        grid = (nb_sample, self.patch_size, self.patch_size)
        n_tokens = self.patch_size * self.patch_size
        try:
            with torch.no_grad():
                if labels is None:
                    # Fixed classes plus one random class (drawn once), cycled to nb_sample entries.
                    base = [1, 7, 282, 604, 724, 179, 751, 404, 850, random.randint(0, 999)]
                    labels = (base * math.ceil(nb_sample / len(base)))[:nb_sample]
                    labels = torch.LongTensor(labels).to(device)

                # drop_label flags: ~drop (all False) keeps the label, drop (all True) drops it.
                drop = torch.ones(nb_sample, dtype=torch.bool).to(device)
                if init_code is not None:
                    # Copy: predicted tokens are written into `code` in place below.
                    code = init_code.clone()
                    mask = (init_code == 1024).float().view(nb_sample, n_tokens)
                else:
                    if self.args.mask_value < 0:
                        code = torch.randint(0, 1024, grid).to(device)
                    else:
                        code = torch.full(grid, self.args.mask_value).to(device)
                    mask = torch.ones(nb_sample, n_tokens).to(device)

                if isinstance(sched_mode, str):
                    scheduler = self.adap_sche(step, mode=sched_mode)
                else:
                    scheduler = sched_mode

                for indice, t in enumerate(scheduler):
                    # Tokens still masked in the most-masked sample of the batch.
                    remaining = int(mask.sum(dim=-1).max().item())
                    if remaining == 0:  # every sample is fully predicted
                        break
                    # Cannot unmask more tokens than are left in a sample.
                    t = min(int(t), remaining)

                    with torch.cuda.amp.autocast():
                        if w != 0:
                            # One batched pass: conditional half, then unconditional half.
                            logit = self.vit(torch.cat([code.clone(), code.clone()], dim=0),
                                             torch.cat([labels, labels], dim=0),
                                             torch.cat([~drop, drop], dim=0))
                            logit_c, logit_u = torch.chunk(logit, 2, dim=0)
                            # Guidance ramps linearly with the step; guard single-step schedules.
                            _w = w * (indice / max(len(scheduler) - 1, 1))
                            logit = (1 + _w) * logit_c - _w * logit_u
                        else:
                            logit = self.vit(code.clone(), labels, drop_label=~drop)

                    prob = torch.softmax(logit * sm_temp, -1)
                    distri = torch.distributions.Categorical(probs=prob)
                    pred_code = distri.sample()

                    # Probability of each sampled token, shape (nb_sample, n_tokens).
                    conf = torch.gather(prob, 2, pred_code.view(nb_sample, n_tokens, 1)).squeeze(-1)

                    if randomize == "linear":
                        # Gumbel noise on the log-confidence, decaying over the steps.
                        ratio = (indice / len(scheduler))
                        rand = r_temp * np.random.gumbel(size=(nb_sample, n_tokens)) * (1 - ratio)
                        conf = torch.log(conf) + torch.from_numpy(rand).to(device)
                    elif randomize == "warm_up":
                        conf = torch.rand_like(conf) if indice < 2 else conf
                    elif randomize == "random":
                        conf = torch.rand_like(conf)

                    # Never re-select already predicted tokens.
                    conf[~mask.bool()] = -math.inf

                    # Keep the t most confident still-masked tokens of each sample.
                    tresh_conf, indice_mask = torch.topk(conf, k=t, dim=-1)
                    tresh_conf = tresh_conf[:, -1]
                    selected = (conf >= tresh_conf.unsqueeze(-1)) & mask.bool()
                    selected = selected.view(grid)
                    code[selected] = pred_code.view(grid)[selected]

                    # Mark the chosen positions as predicted.
                    mask.scatter_(1, indice_mask, 0.)
                    l_codes.append(pred_code.view(grid).clone())
                    l_mask.append(mask.view(grid).clone())

                # The VQGAN codebook only has ids 0..1023.
                _code = torch.clamp(code, 0, 1023)
                x = self.ae.decode_code(_code)
                x = (torch.clamp(x, -1, 1) + 1) / 2
        finally:
            self.vit.train(was_training)
        return x, l_codes, l_mask
|
|