| | |
| | import os |
| | import random |
| | import math |
| |
|
| | import numpy as np |
| | from tqdm import tqdm |
| | from omegaconf import OmegaConf |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.nn.parallel import DistributedDataParallel as DDP |
| |
|
| | from Models.models.transformer import MaskTransformer |
| | from Models.models.vqgan import VQModel |
| |
|
| |
|
class MaskGIT(nn.Module):
    """MaskGIT generator: a frozen VQGAN autoencoder paired with a masked
    bidirectional transformer that predicts discrete visual tokens."""

    def __init__(self, args):
        """ Initialization of the model (VQGAN and Masked Transformer), optimizer, criterion, etc."""
        super().__init__()

        self.args = args                              # run configuration (device, sizes, folders, flags, ...)
        self.patch_size = self.args.img_size // 16    # side of the token grid; the VQGAN downsamples by 16
        self.scaler = torch.cuda.amp.GradScaler()     # mixed-precision gradient scaler (used for training)
        self.vit = self.get_network("vit")            # masked transformer over token grids
        self.ae = self.get_network("autoencoder")     # VQGAN used to encode/decode between pixels and tokens
|
| | def get_network(self, archi): |
| | """ return the network, load checkpoint if self.args.resume == True |
| | :param |
| | archi -> str: vit|autoencoder, the architecture to load |
| | :return |
| | model -> nn.Module: the network |
| | """ |
| | if archi == "vit": |
| | if self.args.vit_size == "base": |
| | model = MaskTransformer( |
| | img_size=self.args.img_size, hidden_dim=768, codebook_size=1024, depth=24, heads=16, mlp_dim=3072, dropout=0.1 |
| | ) |
| | elif self.args.vit_size == "big": |
| | model = MaskTransformer( |
| | img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=32, heads=16, mlp_dim=3072, dropout=0.1 |
| | ) |
| | elif self.args.vit_size == "huge": |
| | model = MaskTransformer( |
| | img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=48, heads=16, mlp_dim=3072, dropout=0.1 |
| | ) |
| |
|
| | if self.args.resume: |
| | ckpt = self.args.vit_folder |
| | ckpt += "current.pth" if os.path.isdir(self.args.vit_folder) else "" |
| | if self.args.is_master: |
| | print("load ckpt from:", ckpt) |
| | |
| | checkpoint = torch.load(ckpt, map_location='cpu') |
| | |
| | model.load_state_dict(checkpoint['model_state_dict'], strict=False) |
| |
|
| | model = model.to(self.args.device) |
| |
|
| | if self.args.is_multi_gpus: |
| | model = DDP(model, device_ids=[self.args.device]) |
| |
|
| | elif archi == "autoencoder": |
| | |
| | config = OmegaConf.load(os.path.join(self.args.vqgan_folder, "model.yaml")) |
| | model = VQModel(**config.model.params) |
| | checkpoint = torch.load(os.path.join(self.args.vqgan_folder, "last.ckpt"), map_location="cpu")["state_dict"] |
| | |
| | model.load_state_dict(checkpoint, strict=False) |
| | model = model.eval() |
| | model = model.to(self.args.device) |
| |
|
| | if self.args.is_multi_gpus: |
| | model = DDP(model, device_ids=[self.args.device]) |
| | model = model.module |
| | else: |
| | model = None |
| |
|
| | if self.args.is_master: |
| | print(f"Size of model {archi}: " |
| | f"{sum(p.numel() for p in model.parameters() if p.requires_grad) / 10 ** 6:.3f}M") |
| |
|
| | return model |
| |
|
| | def adap_sche(self, step, mode="arccos", leave=False): |
| | """ Create a sampling scheduler |
| | :param |
| | step -> int: number of prediction during inference |
| | mode -> str: the rate of value to unmask |
| | leave -> bool: tqdm arg on either to keep the bar or not |
| | :return |
| | scheduler -> torch.LongTensor(): the list of token to predict at each step |
| | """ |
| | r = torch.linspace(1, 0, step) |
| | if mode == "root": |
| | val_to_mask = 1 - (r ** .5) |
| | elif mode == "linear": |
| | val_to_mask = 1 - r |
| | elif mode == "square": |
| | val_to_mask = 1 - (r ** 2) |
| | elif mode == "cosine": |
| | val_to_mask = torch.cos(r * math.pi * 0.5) |
| | elif mode == "arccos": |
| | val_to_mask = torch.arccos(r) / (math.pi * 0.5) |
| | else: |
| | return |
| |
|
| | |
| | sche = (val_to_mask / val_to_mask.sum()) * (self.patch_size * self.patch_size) |
| | sche = sche.round() |
| | sche[sche == 0] = 1 |
| | sche[-1] += (self.patch_size * self.patch_size) - sche.sum() |
| | return tqdm(sche.int(), leave=leave) |
| |
|
    def sample(self, init_code=None, nb_sample=50, labels=None, sm_temp=1, w=3,
               randomize="linear", r_temp=4.5, sched_mode="arccos", step=12):
        """ Generate sample with the MaskGIT model
        :param
            init_code   -> torch.LongTensor: nb_sample x 16 x 16, the starting initialization code
            nb_sample   -> int: the number of image to generated
            labels      -> torch.LongTensor: the list of classes to generate
            sm_temp     -> float: the temperature before softmax
            w           -> float: scale for the classifier free guidance
            randomize   -> str: linear|warm_up|random|no, either or not to add randomness
            r_temp      -> float: temperature for the randomness
            sched_mode  -> str: root|linear|square|cosine|arccos, the shape of the scheduler
            step:       -> int: number of step for the decoding
        :return
            x    -> torch.FloatTensor: nb_sample x 3 x 256 x 256, the generated images
            code -> torch.LongTensor: nb_sample x step x 16 x 16, the code corresponding to the generated images
        """
        self.vit.eval()
        l_codes = []  # intermediate token grids, one snapshot per decoding step
        l_mask = []   # intermediate masks, one snapshot per decoding step
        with torch.no_grad():
            if labels is None:
                # Default labels: 9 fixed classes + 1 random, tiled to cover nb_sample.
                # NOTE(review): assumes nb_sample is a multiple of 10 -- otherwise the
                # label list is shorter than nb_sample; confirm against callers.
                labels = [1, 7, 282, 604, 724, 179, 751, 404, 850, random.randint(0, 999)] * (nb_sample // 10)
                labels = torch.LongTensor(labels).to(self.args.device)

            # drop=True marks samples whose label is dropped (unconditional branch);
            # the conditional pass below uses ~drop.
            drop = torch.ones(nb_sample, dtype=torch.bool).to(self.args.device)
            if init_code is not None:
                code = init_code
                # 1024 is treated as the "masked token" id: those positions remain to predict.
                mask = (init_code == 1024).float().view(nb_sample, self.patch_size*self.patch_size)
            else:
                if self.args.mask_value < 0:
                    # Negative mask_value: start from uniformly random codebook indices.
                    code = torch.randint(0, 1024, (nb_sample, self.patch_size, self.patch_size)).to(self.args.device)
                else:
                    # Otherwise start with every position set to the dedicated mask value.
                    code = torch.full((nb_sample, self.patch_size, self.patch_size), self.args.mask_value).to(self.args.device)
                # mask == 1 means the position has not been predicted yet.
                mask = torch.ones(nb_sample, self.patch_size*self.patch_size).to(self.args.device)

            # sched_mode may also be a precomputed iterable of per-step token counts.
            if isinstance(sched_mode, str):
                scheduler = self.adap_sche(step, mode=sched_mode)
            else:
                scheduler = sched_mode

            # Iterative decoding: at step `indice`, commit the `t` most confident predictions.
            for indice, t in enumerate(scheduler):
                if mask.sum() < t:
                    # Fewer masked tokens left than scheduled: clamp t.
                    t = int(mask.sum().item())

                if mask.sum() == 0:
                    break  # everything already predicted

                with torch.cuda.amp.autocast():
                    if w != 0:
                        # Classifier-free guidance: one batched forward computes the
                        # conditional (first half) and unconditional (second half) logits.
                        logit = self.vit(torch.cat([code.clone(), code.clone()], dim=0),
                                         torch.cat([labels, labels], dim=0),
                                         torch.cat([~drop, drop], dim=0))
                        logit_c, logit_u = torch.chunk(logit, 2, dim=0)
                        # Guidance strength ramps linearly from 0 to w over the steps.
                        _w = w * (indice / (len(scheduler)-1))
                        logit = (1 + _w) * logit_c - _w * logit_u
                    else:
                        # No guidance: single conditional forward pass.
                        logit = self.vit(code.clone(), labels, drop_label=~drop)

                prob = torch.softmax(logit * sm_temp, -1)
                # Sample a token per position from the softmax distribution.
                distri = torch.distributions.Categorical(probs=prob)
                pred_code = distri.sample()

                # Confidence of each sampled token = its probability under the model.
                conf = torch.gather(prob, 2, pred_code.view(nb_sample, self.patch_size*self.patch_size, 1))

                if randomize == "linear":
                    # Gumbel noise on the log-confidence, annealed to zero over the steps.
                    ratio = (indice / len(scheduler))
                    rand = r_temp * np.random.gumbel(size=(nb_sample, self.patch_size*self.patch_size)) * (1 - ratio)
                    conf = torch.log(conf.squeeze()) + torch.from_numpy(rand).to(self.args.device)
                elif randomize == "warm_up":
                    # Fully random selection for the first two steps only.
                    conf = torch.rand_like(conf) if indice < 2 else conf
                elif randomize == "random":
                    conf = torch.rand_like(conf)

                # Already-predicted positions can never be selected again.
                conf[~mask.bool()] = -math.inf

                # Keep the t highest-confidence positions; tresh_conf is the cut-off value.
                tresh_conf, indice_mask = torch.topk(conf.view(nb_sample, -1), k=t, dim=-1)
                tresh_conf = tresh_conf[:, -1]

                # Commit the selected predictions into the code grid.
                conf = (conf >= tresh_conf.unsqueeze(-1)).view(nb_sample, self.patch_size, self.patch_size)
                f_mask = (mask.view(nb_sample, self.patch_size, self.patch_size).float() * conf.view(nb_sample, self.patch_size, self.patch_size).float()).bool()
                code[f_mask] = pred_code.view(nb_sample, self.patch_size, self.patch_size)[f_mask]

                # Mark the committed positions as predicted.
                for i_mask, ind_mask in enumerate(indice_mask):
                    mask[i_mask, ind_mask] = 0
                l_codes.append(pred_code.view(nb_sample, self.patch_size, self.patch_size).clone())
                l_mask.append(mask.view(nb_sample, self.patch_size, self.patch_size).clone())

            # Clamp to valid codebook ids (drops any leftover mask value), then decode
            # to pixels and map from [-1, 1] to [0, 1].
            _code = torch.clamp(code, 0, 1023)
            x = self.ae.decode_code(_code)
            x = (torch.clamp(x, -1, 1) + 1) / 2
        self.vit.train()
        return x, l_codes, l_mask
| |
|