Spaces: Runtime error

llvictorll committed
Commit 9c4679e · 1 Parent(s): 8d9cd7d

init
Browse files:
- Models/__init__.py +0 -0
- Models/__pycache__/__init__.cpython-38.pyc +0 -0
- Models/models/__init__.py +0 -0
- Models/models/__pycache__/__init__.cpython-38.pyc +0 -0
- Models/models/__pycache__/transformer.cpython-38.pyc +0 -0
- Models/models/__pycache__/vqgan.cpython-38.pyc +0 -0
- Models/models/transformer.py +117 -0
- Models/models/vqgan.py +294 -0
- Models/modules/diffusionmodules/__pycache__/model.cpython-38.pyc +0 -0
- Models/modules/diffusionmodules/model.py +436 -0
- Models/modules/util.py +130 -0
- Models/modules/vqvae/__pycache__/quantize.cpython-38.pyc +0 -0
- Models/modules/vqvae/quantize.py +335 -0
- Models/util.py +157 -0
- __init__.py +0 -0
- __pycache__/runner.cpython-38.pyc +0 -0
- flagged/log.csv +2 -0
- gradio_app.py +77 -0
- runner.py +221 -0
Models/__init__.py
ADDED: file without changes (empty file)

Models/__pycache__/__init__.cpython-38.pyc
ADDED: binary file (174 Bytes)
Models/models/__init__.py
ADDED: file without changes (empty file)

Models/models/__pycache__/__init__.cpython-38.pyc
ADDED: binary file (181 Bytes)

Models/models/__pycache__/transformer.cpython-38.pyc
ADDED: binary file (4.29 kB)

Models/models/__pycache__/vqgan.cpython-38.pyc
ADDED: binary file (8.6 kB)
Models/models/transformer.py
ADDED
@@ -0,0 +1,117 @@
# BERT architecture for the Masked Bidirectional Encoder Transformer
import torch
from torch import nn


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim, bias=True),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim, bias=True),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.):
        super(Attention, self).__init__()
        self.dim = embed_dim
        self.mha = nn.MultiheadAttention(embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True, bias=True)

    def forward(self, x):
        attention_value, attention_weight = self.mha(x, x, x)
        return attention_value, attention_weight


class TransformerEncoder(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        l_attn = []
        for attn, ff in self.layers:
            attention_value, attention_weight = attn(x)
            x = attention_value + x
            x = ff(x) + x
            l_attn.append(attention_weight)
        return x, l_attn


class MaskTransformer(nn.Module):
    def __init__(self, img_size=256, hidden_dim=768, codebook_size=1024, depth=24, heads=8, mlp_dim=3072, dropout=0.1, nclass=1000):
        super().__init__()

        self.nclass = nclass
        self.patch_size = img_size // 16
        self.codebook_size = codebook_size
        self.tok_emb = nn.Embedding(codebook_size+1+nclass+1, hidden_dim)  # +1 for the mask of the viz token, +1 for mask of the class
        # self.msk_emb = nn.Embedding(2, hidden_dim)
        self.pos_emb = nn.init.trunc_normal_(nn.Parameter(torch.zeros(1, (self.patch_size*self.patch_size)+1, hidden_dim)), 0., 0.02)
        self.first_layer = nn.Sequential(
            nn.LayerNorm(hidden_dim, eps=1e-12),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim, eps=1e-12),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
        )

        self.transformer = TransformerEncoder(dim=hidden_dim, depth=depth, heads=heads, mlp_dim=mlp_dim, dropout=dropout)

        self.last_layer = nn.Sequential(
            nn.LayerNorm(hidden_dim, eps=1e-12),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim, eps=1e-12),
        )

        self.bias = nn.Parameter(torch.zeros((self.patch_size*self.patch_size)+1, codebook_size+1+nclass+1))

    def forward(self, img_token, y=None, drop_label=None, return_attn=False):  # , masking_flag=None):
        b, w, h = img_token.size()

        cls_token = y.view(b, -1) + self.codebook_size + 1
        cls_token[drop_label] = self.codebook_size + 1 + self.nclass
        input = torch.cat([img_token.view(b, -1), cls_token.view(b, -1)], -1)
        tok_embeddings = self.tok_emb(input)
        pos_embeddings = self.pos_emb
        x = tok_embeddings + pos_embeddings

        # if masking_flag is not None:
        #     flag = torch.cat([masking_flag.view(b, -1), torch.zeros_like(cls_token.view(b, -1))], -1)
        #     x += self.msk_emb(flag)

        x = self.first_layer(x)
        x, attn = self.transformer(x)
        x = self.last_layer(x)

        logit = torch.matmul(x, self.tok_emb.weight.T) + self.bias

        if return_attn:
            return logit[:, :self.patch_size * self.patch_size, :self.codebook_size + 1], attn

        return logit[:, :self.patch_size*self.patch_size, :self.codebook_size+1]
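
Not part of the commit: a minimal smoke test for the MaskTransformer above, assuming the default img_size=256 (a 16x16 grid of codebook indices plus one class token) and ImageNet-style labels. The depth is reduced to keep it cheap, and every tensor is a dummy chosen only to check shapes.

# Hypothetical usage sketch (not in the commit): verify MaskTransformer output shapes.
import torch
from Models.models.transformer import MaskTransformer

model = MaskTransformer(img_size=256, hidden_dim=768, codebook_size=1024,
                        depth=2, heads=8, mlp_dim=3072, nclass=1000)
img_token = torch.randint(0, 1024, (2, 16, 16))   # 16x16 grid of VQ code indices per sample
y = torch.randint(0, 1000, (2,))                  # one class label per sample
drop_label = torch.tensor([False, True])          # drop the label of the second sample
logit = model(img_token, y=y, drop_label=drop_label)
print(logit.shape)  # torch.Size([2, 256, 1025]): per-token logits over codebook + mask token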
Models/models/vqgan.py
ADDED
@@ -0,0 +1,294 @@
import importlib
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


from Models.modules.diffusionmodules.model import Encoder, Decoder
from Models.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from Models.modules.vqvae.quantize import GumbelQuantize


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if not "target" in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


class VQModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap, sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.image_key = image_key
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.get_codebook_entry(code_b.view(-1), (code_b.size(0), code_b.size(1), code_b.size(2), 256))
        dec = self.decode(quant_b)
        return dec

    def forward(self, input):
        quant, diff, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, diff

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        return x.float()

    def training_step(self, batch, batch_idx, optimizer_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")

            self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class GumbelVQ(VQModel):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 temperature_scheduler_config,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 kl_weight=1e-8,
                 remap=None,
                 ):

        z_channels = ddconfig["z_channels"]
        super().__init__(ddconfig,
                         lossconfig,
                         n_embed,
                         embed_dim,
                         ckpt_path=None,
                         ignore_keys=ignore_keys,
                         image_key=image_key,
                         colorize_nlabels=colorize_nlabels,
                         monitor=monitor,
                         )

        # self.loss.n_classes = n_embed
        self.vocab_size = n_embed

        self.quantize = GumbelQuantize(z_channels, embed_dim,
                                       n_embed=n_embed,
                                       kl_weight=kl_weight, temp_init=1.0,
                                       remap=remap)

        # self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config)  # annealing of temp

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def temperature_scheduling(self):
        self.quantize.temperature = self.temperature_scheduler(self.global_step)

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode_code(self, code_b):
        quant_b = self.quantize.get_codebook_entry(code_b.view(-1), (code_b.size(0), 32, 32, 8192))
        dec = self.decode(quant_b)
        return dec

    def training_step(self, batch, batch_idx, optimizer_idx):
        self.temperature_scheduling()
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencoder
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        # encode
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, _, _ = self.quantize(h)
        # decode
        x_rec = self.decode(quant)
        log["inputs"] = x
        log["reconstructions"] = x_rec
        return log

    def reco(self, x):  # , batch, **kwargs):
        # log = dict()
        # x = self.get_input(batch, self.image_key)
        # x = x.to(self.device)
        # encode
        h = self.encoder(x)
        # print(h, h.size())
        h = self.quant_conv(h)
        quant, _, _ = self.quantize(h)
        # print(quant, quant.size())
        # exit()
        # decode
        x_rec = self.decode(quant)
        # log["inputs"] = x
        # log["reconstructions"] = x_rec
        return x_rec
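
A rough round-trip sketch for VQModel, not part of the commit. The ddconfig below follows the widely used VQGAN f/16 layout and is an assumption; the configuration this Space actually loads is not included here. lossconfig is passed as None only because the loss is commented out in the class above.

# Hypothetical sketch: encode an image to a quantized latent and decode it back.
# ddconfig values are assumed, not taken from this commit.
import torch
from Models.models.vqgan import VQModel

ddconfig = dict(double_z=False, z_channels=256, resolution=256, in_channels=3,
                out_ch=3, ch=128, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,
                attn_resolutions=[16], dropout=0.0)
model = VQModel(ddconfig=ddconfig, lossconfig=None, n_embed=1024, embed_dim=256).eval()

x = torch.randn(1, 3, 256, 256)              # dummy image batch
with torch.no_grad():
    quant, emb_loss, info = model.encode(x)  # quant: (1, 256, 16, 16)
    x_rec = model.decode(quant)              # x_rec: (1, 3, 256, 256)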
Models/modules/diffusionmodules/__pycache__/model.cpython-38.pyc
ADDED: binary file (9.58 kB)
Models/modules/diffusionmodules/model.py
ADDED
@@ -0,0 +1,436 @@
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn


def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def nonlinearity(x):
    # swish
    return x*torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x+h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)    # b,hw,c
        k = k.reshape(b, c, h*w)  # b,c,hw
        w_ = torch.bmm(q, k)      # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)     # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x+h_


class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)

        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        # print("Working with z of shape {} = {} dimensions.".format(
        #     self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class UpsampleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(in_channels=block_in,
                                             out_channels=block_out,
                                             temb_channels=self.temb_ch,
                                             dropout=dropout))
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        # upsampling
        h = x
        for k, i_level in enumerate(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[k](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
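
Not in the commit: a quick, illustrative shape check for the sinusoidal embedding helper above; the batch size and embedding dimension are arbitrary.

# Illustrative only: sinusoidal timestep embeddings, concatenated [sin | cos] halves.
import torch
from Models.modules.diffusionmodules.model import get_timestep_embedding

t = torch.arange(4)                   # 4 integer timesteps
emb = get_timestep_embedding(t, 128)  # -> torch.Size([4, 128])
print(emb.shape)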
Models/modules/util.py
ADDED
@@ -0,0 +1,130 @@
import torch
import torch.nn as nn


def count_params(model):
    total_params = sum(p.numel() for p in model.parameters())
    return total_params


class ActNorm(nn.Module):
    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height*width*torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class Labelator(AbstractEncoder):
    """Net2Net Interface for Class-Conditional Model"""
    def __init__(self, n_classes, quantize_interface=True):
        super().__init__()
        self.n_classes = n_classes
        self.quantize_interface = quantize_interface

    def encode(self, c):
        c = c[:, None]
        if self.quantize_interface:
            return c, None, [None, None, c.long()]
        return c


class SOSProvider(AbstractEncoder):
    # for unconditional training
    def __init__(self, sos_token, quantize_interface=True):
        super().__init__()
        self.sos_token = sos_token
        self.quantize_interface = quantize_interface

    def encode(self, x):
        # get batch size from data and replicate sos_token
        c = torch.ones(x.shape[0], 1)*self.sos_token
        c = c.long().to(x.device)
        if self.quantize_interface:
            return c, None, [None, None, c]
        return c
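
An illustrative sketch, not in the commit, of ActNorm's data-dependent initialization: the first forward pass in training mode fixes loc and scale from batch statistics, after which reverse() inverts the transform. All shapes are arbitrary.

# Illustrative sketch: ActNorm initializes loc/scale from the first batch it sees.
import torch
from Models.modules.util import ActNorm

norm = ActNorm(num_features=8, logdet=True)
x = torch.randn(4, 8, 16, 16) * 3.0 + 1.0
h, logdet = norm(x)                      # first call in training mode triggers initialize()
print(h.mean().item(), h.std().item())   # roughly 0 and 1 after init
x_back = norm(h, reverse=True)           # inverse transform recovers x (up to numerics)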
Models/modules/vqvae/__pycache__/quantize.cpython-38.pyc
ADDED: binary file (8.96 kB)
Models/modules/vqvae/quantize.py
ADDED
@@ -0,0 +1,335 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import einsum
from einops import rearrange


class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    ____________________________________________
    Discretization bottleneck part of the VQ-VAE.
    Inputs:
    - n_e : number of embeddings
    - e_dim : dimension of embedding
    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
    _____________________________________________
    """

    # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
    # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
    # used wherever VectorQuantizer has been used before and is additionally
    # more efficient.
    def __init__(self, n_e, e_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        """
        Inputs the output of the encoder network z and maps it to a discrete
        one-hot vector that is the index of the closest embedding vector e_j
        z (continuous) -> z_q (discrete)
        z.shape = (batch, channel, height, width)
        quantization pipeline:
            1. get encoder input (B,C,H,W)
            2. flatten input to (B*H*W,C)
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
            torch.matmul(z_flattened, self.embedding.weight.t())

        ## could possible replace this here
        # #\start...
        # find closest encodings
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)

        min_encodings = torch.zeros(
            min_encoding_indices.shape[0], self.n_e).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # dtype min encodings: torch.float32
        # min_encodings shape: torch.Size([2048, 512])
        # min_encoding_indices.shape: torch.Size([2048, 1])

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # .........\end

        # with:
        # .........\start
        # min_encoding_indices = torch.argmin(d, dim=1)
        # z_q = self.embedding(min_encoding_indices)
        # ......\end......... (TODO)

        # compute loss for embedding
        loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
            torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        # TODO: check for more easy handling with nn.Embedding
        min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
        min_encodings.scatter_(1, indices[:, None], 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:
            z_q = z_q.view(shape)

            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


class GumbelQuantize(nn.Module):
    """
    credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
    Gumbel Softmax trick quantizer
    Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
    https://arxiv.org/abs/1611.01144
    """

    def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
                 kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
                 remap=None, unknown_index="random"):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.n_embed = n_embed
        print(n_embed)
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight

        self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
        self.embed = nn.Embedding(n_embed, embedding_dim)

        self.use_vqinterface = use_vqinterface

        self.remap = remap

        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_embed

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, return_logits=False):
        # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
        hard = self.straight_through if self.training else True
        temp = self.temperature if temp is None else temp

        logits = self.proj(z)
        if self.remap is not None:
            # continue only with used logits
            full_zeros = torch.zeros_like(logits)
            logits = logits[:, self.used, ...]

        soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
        if self.remap is not None:
            # go back to all entries but unused set to zero
            full_zeros[:, self.used, ...] = soft_one_hot
            soft_one_hot = full_zeros
        z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()

        ind = soft_one_hot.argmax(dim=1)
        if self.remap is not None:
            ind = self.remap_to_used(ind)
        if self.use_vqinterface:
            if return_logits:
                return z_q, diff, (None, None, ind), logits
            return z_q, diff, (None, None, ind)
        return z_q, diff, ind

    def get_codebook_entry(self, indices, shape):
        b, h, w, c = shape
        assert b * h * w == indices.shape[0]
        indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
        if self.remap is not None:
            indices = self.unmap_to_all(indices)
        one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
        # print(one_hot.size())
        # exit()
        z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)

        return z_q


class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
                 sane_index_shape=False, legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits == False, "Only for interface compatible with Gumbel"
        assert return_logits == False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
                torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
                torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
|
| 308 |
+
|
| 309 |
+
if self.remap is not None:
|
| 310 |
+
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
|
| 311 |
+
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
| 312 |
+
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
| 313 |
+
|
| 314 |
+
if self.sane_index_shape:
|
| 315 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
| 316 |
+
z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
| 317 |
+
|
| 318 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
| 319 |
+
|
| 320 |
+
def get_codebook_entry(self, indices, shape):
|
| 321 |
+
# shape specifying (batch, height, width, channel)
|
| 322 |
+
if self.remap is not None:
|
| 323 |
+
indices = indices.reshape(shape[0], -1) # add batch axis
|
| 324 |
+
indices = self.unmap_to_all(indices)
|
| 325 |
+
indices = indices.reshape(-1) # flatten again
|
| 326 |
+
|
| 327 |
+
# get quantized latent vectors
|
| 328 |
+
z_q = self.embedding(indices)
|
| 329 |
+
|
| 330 |
+
if shape is not None:
|
| 331 |
+
z_q = z_q.view(shape)
|
| 332 |
+
# reshape back to match original input shape
|
| 333 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
| 334 |
+
|
| 335 |
+
return z_q
|
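A minimal smoke test for VectorQuantizer2 (an illustrative sketch, not part of this commit; the n_e, e_dim and beta values below are assumptions):

import torch

# Hypothetical hyper-parameters: nothing in this file fixes them.
quantizer = VectorQuantizer2(n_e=1024, e_dim=256, beta=0.25, legacy=False)
z = torch.randn(2, 256, 16, 16)            # encoder output: B x e_dim x H x W
z_q, loss, (_, _, indices) = quantizer(z)  # nearest-neighbour quantization
assert z_q.shape == z.shape                # straight-through estimator keeps the shape
print(loss.item(), indices.shape)          # indices: one codebook id per B*H*W position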
Models/util.py
ADDED
@@ -0,0 +1,157 @@
import os, hashlib
import requests
from tqdm import tqdm

URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}

MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}


def download(url, local_path, chunk_size=1024):
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(chunk_size)


def md5_hash(path):
    with open(path, "rb") as f:
        content = f.read()
    return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path


class KeyNotFoundError(Exception):
    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited
        messages = list()
        if keys is not None:
            messages.append("Key not found: {}".format(keys))
        if visited is not None:
            messages.append("Visited: {}".format(visited))
        messages.append("Cause:\n{}".format(cause))
        message = "\n".join(messages)
        super().__init__(message)


def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary.
    key : str
        key/to/value, path like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found.
    expand : bool
        Whether to expand callable nodes on the path or not.

    Returns
    -------
    The desired value or if :attr:`default` is not ``None`` and the
    :attr:`key` is not found returns ``default``.

    Raises
    ------
    Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
    ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for key in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                parent[last_key] = list_or_dict

            last_key = key
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[key]
                else:
                    list_or_dict = list_or_dict[int(key)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited)

            visited += [key]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError as e:
        if default is None:
            raise e
        else:
            list_or_dict = default
            success = False

    if not pass_success:
        return list_or_dict
    else:
        return list_or_dict, success


if __name__ == "__main__":
    config = {"keya": "a",
              "keyb": "b",
              "keyc":
                  {"cc1": 1,
                   "cc2": 2,
                   }
              }
    from omegaconf import OmegaConf
    config = OmegaConf.create(config)
    print(config)
    retrieve(config, "keya")
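For reference, a short sketch of how retrieve resolves path-like keys (illustrative only, not part of this commit; the config below is made up):

config = {"keyc": {"cc1": 1, "cc2": 2}, "paths": ["a", "b"]}
retrieve(config, "keyc/cc2")                 # -> 2
retrieve(config, "paths/0")                  # list indices work too -> "a"
retrieve(config, "keyc/missing", default=0)  # -> 0 instead of raising KeyNotFoundError
val, ok = retrieve(config, "missing", default=0, pass_success=True)  # -> (0, False)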
__init__.py
ADDED
File without changes
__pycache__/runner.cpython-38.pyc
ADDED
Binary file (7.04 kB). View file
flagged/log.csv
ADDED
@@ -0,0 +1,2 @@
cls,sm_temp,w,r_temp,step,seed,nb_img,output,flag,username,timestamp
31,1.3,25,4.5,16,1,1,,,,2024-01-22 21:12:04.408431
gradio_app.py
ADDED
@@ -0,0 +1,77 @@
import argparse
import torch
import gradio as gr
from torchvision import transforms
from runner import MaskGIT
import numpy as np
import random
import torchvision.utils as vutils


class Args(argparse.Namespace):
    data_folder = ""
    vqgan_folder = r"C:\Users\vbesnier\Experiment\VQGAN"
    writer_log = ""
    data = ""
    mask_value = 1024
    seed = 1
    channel = 3
    num_workers = 0
    iter = 0
    global_epoch = 0
    lr = 1e-4
    drop_label = 0.1
    resume = True
    device = "cpu"
    print(device)
    debug = True
    test_only = False
    is_master = True
    is_multi_gpus = False
    vit_size = "base"
    vit_folder = r"C:\Users\vbesnier\Experiment\MaskGIT\current.pth"
    img_size = 256
    patch_size = 256 // 16


def set_seed(seed):
    if seed > 0:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.enabled = False
        torch.backends.cudnn.deterministic = True


args = Args()
maskgit = MaskGIT(args)


# Function to perform image synthesis
def synthesize_image(cls, sm_temp=1, w=3, r_temp=4.5, step=8, seed=1, nb_img=1):
    # Perform image synthesis with the MaskGIT model
    set_seed(seed)
    with torch.no_grad():
        labels = [cls] * nb_img
        labels = torch.LongTensor(labels).to(args.device)
        gen_sample = maskgit.sample(nb_sample=labels.size(0), labels=labels, sm_temp=sm_temp, w=w,
                                    randomize="linear", r_temp=r_temp, sched_mode="arccos",
                                    step=step)[0]

    # Post-process the output into a PIL image grid
    output_image = transforms.ToPILImage()(vutils.make_grid(gen_sample, nrow=2, padding=0, normalize=True))

    return output_image


# Gradio Interface
app = gr.Interface(
    fn=synthesize_image,
    inputs=[gr.Number(31), gr.Number(1.3), gr.Number(25), gr.Number(4.5), gr.Number(16),
            gr.Slider(0, 1000, 60), gr.Number(1, maximum=4)],
    outputs=gr.Image(),
    title="Image Synthesis using MaskGIT",
)

# Launch the Gradio app
app.launch(share=True)
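A possible variant with labeled widgets (an assumption, not in this commit): the labels simply mirror synthesize_image's parameters and the header of flagged/log.csv above.

# Hypothetical labeled version of the same interface.
app = gr.Interface(
    fn=synthesize_image,
    inputs=[gr.Number(value=31, label="cls (ImageNet class id)"),
            gr.Number(value=1.3, label="sm_temp (softmax temperature)"),
            gr.Number(value=25, label="w (guidance scale)"),
            gr.Number(value=4.5, label="r_temp (gumbel temperature)"),
            gr.Number(value=16, label="step (decoding steps)"),
            gr.Slider(0, 1000, value=60, label="seed"),
            gr.Number(value=1, maximum=4, label="nb_img")],
    outputs=gr.Image(label="samples"),
    title="Image Synthesis using MaskGIT",
)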
runner.py
ADDED
@@ -0,0 +1,221 @@
# Trainer for MaskGIT
import os
import random
import math

import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from Models.models.transformer import MaskTransformer
from Models.models.vqgan import VQModel


class MaskGIT(nn.Module):

    def __init__(self, args):
        """ Initialization of the model (VQGAN and Masked Transformer), optimizer, criterion, etc."""
        super().__init__()

        self.args = args                            # Main arguments, see main.py
        self.patch_size = self.args.img_size // 16  # Number of visual tokens per side (the transformer input adds +1 class token)
        self.scaler = torch.cuda.amp.GradScaler()   # Init scaler for mixed-precision training
        self.vit = self.get_network("vit")          # Load Masked Bidirectional Transformer
        self.ae = self.get_network("autoencoder")   # Load VQGAN

    def get_network(self, archi):
        """ Return the network; load a checkpoint if self.args.resume == True
           :param
            archi -> str: vit|autoencoder, the architecture to load
           :return
            model -> nn.Module: the network
        """
        if archi == "vit":
            if self.args.vit_size == "base":
                model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=768, codebook_size=1024, depth=24, heads=16, mlp_dim=3072, dropout=0.1  # Base
                )
            elif self.args.vit_size == "big":
                model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=32, heads=16, mlp_dim=3072, dropout=0.1  # Big
                )
            elif self.args.vit_size == "huge":
                model = MaskTransformer(
                    img_size=self.args.img_size, hidden_dim=1024, codebook_size=1024, depth=48, heads=16, mlp_dim=3072, dropout=0.1  # Huge
                )

            if self.args.resume:
                ckpt = self.args.vit_folder
                ckpt += "current.pth" if os.path.isdir(self.args.vit_folder) else ""
                if self.args.is_master:
                    print("load ckpt from:", ckpt)
                # Read checkpoint file
                checkpoint = torch.load(ckpt, map_location='cpu')
                # Load network
                model.load_state_dict(checkpoint['model_state_dict'], strict=False)

            model = model.to(self.args.device)

            if self.args.is_multi_gpus:  # wrap the model for multi-GPU training if available
                model = DDP(model, device_ids=[self.args.device])

        elif archi == "autoencoder":
            # Load config
            config = OmegaConf.load(os.path.join(self.args.vqgan_folder, "model.yaml"))
            model = VQModel(**config.model.params)
            checkpoint = torch.load(os.path.join(self.args.vqgan_folder, "last.ckpt"), map_location="cpu")["state_dict"]
            # Load network
            model.load_state_dict(checkpoint, strict=False)
            model = model.eval()
            model = model.to(self.args.device)

            if self.args.is_multi_gpus:  # wrap the model for multi-GPU inference if available
                model = DDP(model, device_ids=[self.args.device])
                model = model.module
        else:
            model = None

        if self.args.is_master:
            print(f"Size of model {archi}: "
                  f"{sum(p.numel() for p in model.parameters() if p.requires_grad) / 10 ** 6:.3f}M")

        return model

    def adap_sche(self, step, mode="arccos", leave=False):
        """ Create a sampling scheduler
           :param
            step  -> int: number of prediction steps during inference
            mode  -> str: shape of the schedule controlling the rate of unmasking
            leave -> bool: tqdm arg, whether to keep the progress bar
           :return
            scheduler -> torch.LongTensor(): the number of tokens to predict at each step
        """
        r = torch.linspace(1, 0, step)
        if mode == "root":  # root scheduler
            val_to_mask = 1 - (r ** .5)
        elif mode == "linear":  # linear scheduler
            val_to_mask = 1 - r
        elif mode == "square":  # square scheduler
            val_to_mask = 1 - (r ** 2)
        elif mode == "cosine":  # cosine scheduler
            val_to_mask = torch.cos(r * math.pi * 0.5)
        elif mode == "arccos":  # arc cosine scheduler
            val_to_mask = torch.arccos(r) / (math.pi * 0.5)
        else:
            return

        # fill the scheduler with the ratio of tokens to predict at each step
        sche = (val_to_mask / val_to_mask.sum()) * (self.patch_size * self.patch_size)
        sche = sche.round()
        sche[sche == 0] = 1                                           # predict at least 1 token per step
        sche[-1] += (self.patch_size * self.patch_size) - sche.sum()  # the total must equal the number of codes
        return tqdm(sche.int(), leave=leave)

    def sample(self, init_code=None, nb_sample=50, labels=None, sm_temp=1, w=3,
               randomize="linear", r_temp=4.5, sched_mode="arccos", step=12):
        """ Generate samples with the MaskGIT model
          :param
           init_code  -> torch.LongTensor: nb_sample x 16 x 16, the starting initialization code
           nb_sample  -> int: the number of images to generate
           labels     -> torch.LongTensor: the list of classes to generate
           sm_temp    -> float: temperature applied before the softmax
           w          -> float: scale for the classifier-free guidance
           randomize  -> str: linear|warm_up|random|no, whether to add randomness
           r_temp     -> float: temperature for the randomness
           sched_mode -> str: root|linear|square|cosine|arccos, the shape of the scheduler
           step       -> int: number of decoding steps
          :return
           x    -> torch.FloatTensor: nb_sample x 3 x 256 x 256, the generated images
           code -> torch.LongTensor: nb_sample x step x 16 x 16, the codes corresponding to the generated images
        """
        self.vit.eval()
        l_codes = []  # Save the intermediate codes predicted
        l_mask = []   # Save the intermediate masks
        with torch.no_grad():
            if labels is None:  # Default classes generated
                # goldfish, chicken, tiger cat, hourglass, ship, dog, race car, airliner, teddy bear, random
                labels = [1, 7, 282, 604, 724, 179, 751, 404, 850, random.randint(0, 999)] * (nb_sample // 10)
                labels = torch.LongTensor(labels).to(self.args.device)

            drop = torch.ones(nb_sample, dtype=torch.bool).to(self.args.device)
            if init_code is not None:  # Start from a pre-defined code
                code = init_code
                mask = (init_code == 1024).float().view(nb_sample, self.patch_size*self.patch_size)
            else:  # Initialize a code
                if self.args.mask_value < 0:  # Code initialized with random tokens
                    code = torch.randint(0, 1024, (nb_sample, self.patch_size, self.patch_size)).to(self.args.device)
                else:  # Code initialized with masked tokens
                    code = torch.full((nb_sample, self.patch_size, self.patch_size), self.args.mask_value).to(self.args.device)
                mask = torch.ones(nb_sample, self.patch_size*self.patch_size).to(self.args.device)

            # Instantiate scheduler
            if isinstance(sched_mode, str):  # Standard ones
                scheduler = self.adap_sche(step, mode=sched_mode)
            else:  # Custom one
                scheduler = sched_mode

            # Beginning of sampling, t = number of tokens to predict at step "indice"
            for indice, t in enumerate(scheduler):
                if mask.sum() < t:  # Cannot predict more tokens than 16*16 or 32*32
                    t = int(mask.sum().item())

                if mask.sum() == 0:  # Break if the code is fully predicted
                    break

                with torch.cuda.amp.autocast():  # half precision
                    if w != 0:
                        # Model prediction
                        logit = self.vit(torch.cat([code.clone(), code.clone()], dim=0),
                                         torch.cat([labels, labels], dim=0),
                                         torch.cat([~drop, drop], dim=0))
                        logit_c, logit_u = torch.chunk(logit, 2, dim=0)
                        _w = w * (indice / (len(scheduler)-1))
                        # Classifier-free guidance
                        logit = (1 + _w) * logit_c - _w * logit_u
                    else:
                        logit = self.vit(code.clone(), labels, drop_label=~drop)

                prob = torch.softmax(logit * sm_temp, -1)
                # Sample the code from the softmax prediction
                distri = torch.distributions.Categorical(probs=prob)
                pred_code = distri.sample()

                conf = torch.gather(prob, 2, pred_code.view(nb_sample, self.patch_size*self.patch_size, 1))

                if randomize == "linear":  # add gumbel noise decreasing over the sampling process
                    ratio = (indice / len(scheduler))
                    rand = r_temp * np.random.gumbel(size=(nb_sample, self.patch_size*self.patch_size)) * (1 - ratio)
                    conf = torch.log(conf.squeeze()) + torch.from_numpy(rand).to(self.args.device)
                elif randomize == "warm_up":  # choose random confidences for the first 2 steps
                    conf = torch.rand_like(conf) if indice < 2 else conf
                elif randomize == "random":  # choose random confidences at each step
                    conf = torch.rand_like(conf)

                # do not predict on already predicted tokens
                conf[~mask.bool()] = -math.inf

                # keep the t predicted tokens with the highest confidence
                tresh_conf, indice_mask = torch.topk(conf.view(nb_sample, -1), k=t, dim=-1)
                tresh_conf = tresh_conf[:, -1]

                # replace the chosen tokens
                conf = (conf >= tresh_conf.unsqueeze(-1)).view(nb_sample, self.patch_size, self.patch_size)
                f_mask = (mask.view(nb_sample, self.patch_size, self.patch_size).float() * conf.view(nb_sample, self.patch_size, self.patch_size).float()).bool()
                code[f_mask] = pred_code.view(nb_sample, self.patch_size, self.patch_size)[f_mask]

                # update the mask
                for i_mask, ind_mask in enumerate(indice_mask):
                    mask[i_mask, ind_mask] = 0
                l_codes.append(pred_code.view(nb_sample, self.patch_size, self.patch_size).clone())
                l_mask.append(mask.view(nb_sample, self.patch_size, self.patch_size).clone())

            # decode the final prediction
            _code = torch.clamp(code, 0, 1023)  # the VQGAN codebook has only 1024 entries
            x = self.ae.decode_code(_code)
            x = (torch.clamp(x, -1, 1) + 1) / 2

        self.vit.train()
        return x, l_codes, l_mask
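To make adap_sche concrete, here is a standalone recomputation of the arccos schedule (an illustrative check, not part of this commit; step=8 and a 16x16 token grid are assumed):

import math
import torch

step, n_tokens = 8, 16 * 16                          # assumed values
r = torch.linspace(1, 0, step)
val_to_mask = torch.arccos(r) / (math.pi * 0.5)      # arccos schedule, as above
sche = (val_to_mask / val_to_mask.sum()) * n_tokens  # ratios -> token counts
sche = sche.round()
sche[sche == 0] = 1                                  # at least one token per step
sche[-1] += n_tokens - sche.sum()                    # force the total to 256
print(sche.int().tolist(), int(sche.sum()))          # few tokens early, many late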