# Trainer for MaskGIT
import os
import random
import math

import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from Models.models.transformer import MaskTransformer
from Models.models.vqgan import VQModel
class MaskGIT(nn.Module):

    def __init__(self, args):
        """ Initialization of the model (VQGAN and Masked Transformer), optimizer, criterion, etc. """
        super().__init__()
        self.args = args                                # Main arguments, see main.py
        self.patch_size = self.args.img_size // 16      # Number of visual tokens per side (the transformer adds +1 class token)
        self.scaler = torch.cuda.amp.GradScaler()       # Init gradient scaler for mixed-precision (AMP) training
        self.vit = self.get_network("vit")              # Load the Masked Bidirectional Transformer
        self.ae = self.get_network("autoencoder")       # Load the VQGAN
    def get_network(self, archi):
        """ Return the network; load a checkpoint if self.args.resume == True
            :param
                archi -> str: vit|autoencoder, the architecture to load
            :return
                model -> nn.Module: the network
        """
        if archi == "vit":
            if self.args.vit_size == "base":
                model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=768, codebook_size=1024, depth=24, heads=16, mlp_dim=3072, dropout=0.1   # Base
                )
            elif self.args.vit_size == "big":
                model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=32, heads=16, mlp_dim=3072, dropout=0.1  # Big
                )
            elif self.args.vit_size == "huge":
                model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=48, heads=16, mlp_dim=3072, dropout=0.1  # Huge
                )
            else:
                raise ValueError(f"Unknown vit_size: {self.args.vit_size}, expected base|big|huge")

            if self.args.resume:
                # vit_folder is either a folder containing "current.pth" or a direct path to a checkpoint
                ckpt = self.args.vit_folder
                if os.path.isdir(ckpt):
                    ckpt = os.path.join(ckpt, "current.pth")
                if self.args.is_master:
                    print("load ckpt from:", ckpt)
                # Read the checkpoint file
                checkpoint = torch.load(ckpt, map_location='cpu')
                # Load the network weights
                model.load_state_dict(checkpoint['model_state_dict'], strict=False)

            model = model.to(self.args.device)
            if self.args.is_multi_gpus:  # put the model on multiple GPUs if available
                model = DDP(model, device_ids=[self.args.device])

        elif archi == "autoencoder":
            # Load the VQGAN config
            config = OmegaConf.load(os.path.join(self.args.vqgan_folder, "model.yaml"))
            model = VQModel(**config.model.params)
            checkpoint = torch.load(os.path.join(self.args.vqgan_folder, "last.ckpt"), map_location="cpu")["state_dict"]
            # Load the network weights
            model.load_state_dict(checkpoint, strict=False)
            model = model.eval()
            model = model.to(self.args.device)

            if self.args.is_multi_gpus:  # put the model on multiple GPUs if available
                model = DDP(model, device_ids=[self.args.device])
                model = model.module  # unwrap right away: the VQGAN stays frozen, so no gradient sync is needed
        else:
            raise ValueError(f"Unknown architecture: {archi}, expected vit|autoencoder")

        if self.args.is_master:
            print(f"Size of model {archi}: "
                  f"{sum(p.numel() for p in model.parameters() if p.requires_grad) / 10 ** 6:.3f}M")

        return model
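    # Illustrative note (an assumption, not from the original file): the "vit"
    # branch above expects checkpoints saved with a 'model_state_dict' key,
    # e.g. something along the lines of:
    #
    #   torch.save({'model_state_dict': model.state_dict()},
    #              os.path.join(args.vit_folder, "current.pth"))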
    def adap_sche(self, step, mode="arccos", leave=False):
        """ Create a sampling scheduler
            :param
                step  -> int:  number of prediction steps during inference
                mode  -> str:  scheduling mode, i.e. the rate at which tokens are unmasked
                leave -> bool: tqdm arg, whether to keep the progress bar once done
            :return
                scheduler -> a tqdm iterator over the number of tokens to predict at each step
        """
        r = torch.linspace(1, 0, step)
        if mode == "root":            # root scheduler
            val_to_mask = 1 - (r ** .5)
        elif mode == "linear":        # linear scheduler
            val_to_mask = 1 - r
        elif mode == "square":        # square scheduler
            val_to_mask = 1 - (r ** 2)
        elif mode == "cosine":        # cosine scheduler
            val_to_mask = torch.cos(r * math.pi * 0.5)
        elif mode == "arccos":        # arc cosine scheduler
            val_to_mask = torch.arccos(r) / (math.pi * 0.5)
        else:
            raise ValueError(f"Unknown scheduler mode: {mode}, expected root|linear|square|cosine|arccos")

        # fill the scheduler with the number of tokens to predict at each step
        sche = (val_to_mask / val_to_mask.sum()) * (self.patch_size * self.patch_size)
        sche = sche.round()
        sche[sche == 0] = 1                                           # predict at least 1 token per step
        sche[-1] += (self.patch_size * self.patch_size) - sche.sum()  # the counts must sum to the total number of tokens
        return tqdm(sche.int(), leave=leave)
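    # Worked example (illustrative): with patch_size=16 (i.e. 16*16 = 256 tokens)
    # and step=8 in "arccos" mode, the schedule comes out to roughly
    # [1, 18, 26, 32, 38, 43, 48, 50]: few tokens are committed in the uncertain
    # early steps and more once most of the image is fixed, summing to 256.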
    def sample(self, init_code=None, nb_sample=50, labels=None, sm_temp=1, w=3,
               randomize="linear", r_temp=4.5, sched_mode="arccos", step=12):
        """ Generate samples with the MaskGIT model
            :param
                init_code  -> torch.LongTensor: nb_sample x 16 x 16, the starting initialization code
                nb_sample  -> int:              the number of images to generate
                labels     -> torch.LongTensor: the list of classes to generate
                sm_temp    -> float:            the temperature applied before the softmax
                w          -> float:            scale for the classifier-free guidance
                randomize  -> str:              linear|warm_up|random|no, whether (and how) to add randomness
                r_temp     -> float:            temperature of the randomness
                sched_mode -> str:              root|linear|square|cosine|arccos, the shape of the scheduler
                step       -> int:              number of decoding steps
            :return
                x    -> torch.FloatTensor: nb_sample x 3 x 256 x 256, the generated images
                code -> torch.LongTensor:  nb_sample x step x 16 x 16, the codes corresponding to the generated images
        """
        self.vit.eval()
        l_codes = []  # Save the intermediate codes predicted
        l_mask = []   # Save the intermediate masks
        with torch.no_grad():
            if labels is None:  # Default classes to generate
                # goldfish, chicken, tiger cat, hourglass, ship, dog, race car, airliner, teddy bear, random
                labels = [1, 7, 282, 604, 724, 179, 751, 404, 850, random.randint(0, 999)] * ((nb_sample + 9) // 10)
                labels = torch.LongTensor(labels[:nb_sample]).to(self.args.device)  # truncate in case nb_sample is not a multiple of 10

            drop = torch.ones(nb_sample, dtype=torch.bool).to(self.args.device)
            if init_code is not None:  # Start from a pre-defined code
                code = init_code
                mask = (init_code == 1024).float().view(nb_sample, self.patch_size*self.patch_size)
            else:  # Initialize a code
                if self.args.mask_value < 0:  # initialize the code with random tokens
                    code = torch.randint(0, 1024, (nb_sample, self.patch_size, self.patch_size)).to(self.args.device)
                else:  # initialize the code with masked tokens
                    code = torch.full((nb_sample, self.patch_size, self.patch_size), self.args.mask_value).to(self.args.device)
                mask = torch.ones(nb_sample, self.patch_size*self.patch_size).to(self.args.device)
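            # Note (inferred from this file, not stated explicitly): the mask token
            # id is expected to be 1024, one past the largest valid codebook index
            # (1023), which is why the final code is clamped to [0, 1023] below
            # before being decoded by the VQGAN.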
            # Instantiate the scheduler
            if isinstance(sched_mode, str):  # standard schedulers
                scheduler = self.adap_sche(step, mode=sched_mode)
            else:                            # custom scheduler
                scheduler = sched_mode

            # Beginning of sampling, t = number of tokens to predict at step "indice"
            for indice, t in enumerate(scheduler):
                if mask.sum() < t:  # cannot predict more tokens than are still masked
                    t = int(mask.sum().item())

                if mask.sum() == 0:  # break once the code is fully predicted
                    break
                with torch.cuda.amp.autocast():  # half precision
                    if w != 0:
                        # Model prediction: batch the conditional and the unconditional
                        # passes together (drop_label is False for the first half, True for the second)
                        logit = self.vit(torch.cat([code.clone(), code.clone()], dim=0),
                                         torch.cat([labels, labels], dim=0),
                                         torch.cat([~drop, drop], dim=0))
                        logit_c, logit_u = torch.chunk(logit, 2, dim=0)
                        _w = w * (indice / (len(scheduler)-1))
                        # Classifier-Free Guidance
                        logit = (1 + _w) * logit_c - _w * logit_u
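                        # The guidance scale _w ramps linearly from 0 at the first
                        # step (pure conditional logits) to the full w at the last:
                        # e.g. with w=3, the final step uses 4*logit_c - 3*logit_u.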
                    else:
                        logit = self.vit(code.clone(), labels, drop_label=~drop)

                prob = torch.softmax(logit * sm_temp, -1)
                # Sample a code from the softmax prediction
                distri = torch.distributions.Categorical(probs=prob)
                pred_code = distri.sample()

                conf = torch.gather(prob, 2, pred_code.view(nb_sample, self.patch_size*self.patch_size, 1))

                if randomize == "linear":  # add Gumbel noise that decreases over the sampling process
                    ratio = (indice / len(scheduler))
                    rand = r_temp * np.random.gumbel(size=(nb_sample, self.patch_size*self.patch_size)) * (1 - ratio)
                    conf = torch.log(conf.squeeze(-1)) + torch.from_numpy(rand).to(self.args.device)
                elif randomize == "warm_up":  # choose random confidences for the first 2 steps
                    conf = torch.rand_like(conf) if indice < 2 else conf
                elif randomize == "random":   # choose random confidences at every step
                    conf = torch.rand_like(conf)
                conf = conf.view(nb_sample, self.patch_size*self.patch_size)  # all branches end up with shape nb_sample x (patch_size**2)

                # do not predict tokens that are already predicted
                conf[~mask.bool()] = -math.inf

                # keep the t predictions with the highest confidence
                tresh_conf, indice_mask = torch.topk(conf.view(nb_sample, -1), k=t, dim=-1)
                tresh_conf = tresh_conf[:, -1]

                # replace the chosen tokens
                conf = (conf >= tresh_conf.unsqueeze(-1)).view(nb_sample, self.patch_size, self.patch_size)
                f_mask = (mask.view(nb_sample, self.patch_size, self.patch_size).float() * conf.view(nb_sample, self.patch_size, self.patch_size).float()).bool()
                code[f_mask] = pred_code.view(nb_sample, self.patch_size, self.patch_size)[f_mask]

                # update the mask
                for i_mask, ind_mask in enumerate(indice_mask):
                    mask[i_mask, ind_mask] = 0
                l_codes.append(pred_code.view(nb_sample, self.patch_size, self.patch_size).clone())
                l_mask.append(mask.view(nb_sample, self.patch_size, self.patch_size).clone())

            # decode the final prediction
            _code = torch.clamp(code, 0, 1023)  # the VQGAN codebook only has 1024 entries
            x = self.ae.decode_code(_code)
            x = (torch.clamp(x, -1, 1) + 1) / 2  # map the output from [-1, 1] to [0, 1]

        self.vit.train()
        return x, l_codes, l_mask
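

if __name__ == "__main__":
    # Minimal usage sketch (an illustration only; the project's real entry point
    # is main.py, and the folder paths and values below are assumptions).
    from argparse import Namespace

    args = Namespace(
        img_size=256, vit_size="base",
        vit_folder="./pretrained_maskgit/",   # hypothetical folder holding current.pth
        vqgan_folder="./pretrained_vqgan/",   # hypothetical folder holding model.yaml + last.ckpt
        resume=True, device="cuda:0",
        is_master=True, is_multi_gpus=False,
        mask_value=1024,
    )
    maskgit = MaskGIT(args)
    images, codes, masks = maskgit.sample(nb_sample=10, sm_temp=1, w=3,
                                          randomize="linear", r_temp=4.5,
                                          sched_mode="arccos", step=12)
    print(images.shape)  # expected: torch.Size([10, 3, 256, 256]), values in [0, 1]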