├── .gitignore ├── LICENSE ├── README.md ├── assets ├── codebooks.png ├── reconstructions.png └── variance_ratio.png ├── model.py ├── plot.ipynb └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # PyCharm project settings 98 | .idea 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 bshall 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vector Quantized VAE 2 | A PyTorch implementation of [Continuous Relaxation Training of Discrete Latent Variable Image Models](http://bayesiandeeplearning.org/2017/papers/54.pdf). 3 | 4 | Ensure you have Python 3.7 and PyTorch 1.2 or greater. 5 | To train the `VQVAE` model with 8 categorical dimensions and 128 codes per dimension 6 | run the following command: 7 | ``` 8 | python train.py --model=VQVAE --latent-dim=8 --num-embeddings=128 9 | ``` 10 | To train the `GS-Soft` model use `--model=GSSOFT`. 11 | Pretrained weights for the `VQVAE` and `GS-Soft` models can be found 12 | [here](https://github.com/bshall/VectorQuantizedVAE/releases/tag/v0.1). 13 | 14 |
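For example, the `GS-Soft` model with the same codebook configuration as above can be trained with:

```
python train.py --model=GSSOFT --latent-dim=8 --num-embeddings=128
```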

15 | ![VQVAE Reconstructions](assets/reconstructions.png) 16 |

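In `train.py`, reconstructions are logged by taking the most likely pixel value under the decoder's `Categorical` output. A minimal sketch of doing the same from a saved checkpoint (the checkpoint path below is a placeholder for one of the pretrained weights or your own run):

```
import torch
from torchvision import datasets, transforms, utils

from model import VQVAE

# Default hyperparameters from train.py; adjust to match your checkpoint.
model = VQVAE(channels=256, latent_dim=8, num_embeddings=128, embedding_dim=32)
checkpoint = torch.load("VQVAE_C_256_N_8_M_128_D_32/model.ckpt-250000.pt",  # placeholder path
                        map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x - 0.5),  # same shift applied during training
])
test_set = datasets.CIFAR10("./CIFAR10", train=False, download=True, transform=transform)
images = torch.stack([test_set[i][0] for i in range(8)])

with torch.no_grad():
    dist, _, _ = model(images)

# dist.logits has shape (B, 3, 32, 32, 256); take the mode of each pixel distribution.
reconstructions = torch.argmax(dist.logits, dim=-1).float() / 255
utils.save_image(reconstructions, "reconstructions.png", nrow=8)
```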
17 | 18 | The `VQVAE` model achieves ~4.82 bits per dimension (bpd), while the `GS-Soft` model achieves ~4.6 bpd. 19 | 20 | # Analysis of the Codebooks 21 | 22 | As demonstrated in the paper, the codebook matrices are effectively low-rank, with the codes spanning only a few of the embedding dimensions: 23 | 24 |

25 | ![Explained Variance Ratio](assets/variance_ratio.png) 26 |

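The figure presumably comes from `plot.ipynb`, which is not reproduced in this dump. A rough sketch of the underlying computation, reusing the `model` loaded in the snippet above (the per-codebook SVD is an assumption about how the explained-variance ratios were obtained):

```
codebooks = model.codebook.embedding  # shape (latent_dim, num_embeddings, embedding_dim)
for n, codes in enumerate(codebooks):
    centered = codes - codes.mean(dim=0)
    _, s, _ = torch.svd(centered)
    ratio = s ** 2 / torch.sum(s ** 2)  # explained variance ratio per component
    print("codebook {}: first 3 components explain {:.1%}".format(n, ratio[:3].sum().item()))
```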
27 | 28 | Projecting the codes onto the first 3 principal components shows that they typically tile 29 | continuous 1- or 2-D manifolds: 30 | 31 |

32 | ![Codebook principal components](assets/codebooks.png) 33 |

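A corresponding sketch for the projection itself, again assuming the `model` and imports from the snippets above:

```
codes = model.codebook.embedding[0]  # one codebook, shape (num_embeddings, embedding_dim)
centered = codes - codes.mean(dim=0)
_, _, v = torch.svd(centered)        # columns of v are the principal directions
projected = centered @ v[:, :3]      # (num_embeddings, 3) coordinates to scatter-plot
```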
34 | -------------------------------------------------------------------------------- /assets/codebooks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshall/VectorQuantizedVAE/d7bc845c9d46d232f5c6aae1a825b8c6952084dc/assets/codebooks.png -------------------------------------------------------------------------------- /assets/reconstructions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshall/VectorQuantizedVAE/d7bc845c9d46d232f5c6aae1a825b8c6952084dc/assets/reconstructions.png -------------------------------------------------------------------------------- /assets/variance_ratio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshall/VectorQuantizedVAE/d7bc845c9d46d232f5c6aae1a825b8c6952084dc/assets/variance_ratio.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import Categorical, RelaxedOneHotCategorical 5 | import math 6 | 7 | 8 | class VQEmbeddingEMA(nn.Module): 9 | def __init__(self, latent_dim, num_embeddings, embedding_dim, commitment_cost=0.25, decay=0.999, epsilon=1e-5): 10 | super(VQEmbeddingEMA, self).__init__() 11 | self.commitment_cost = commitment_cost 12 | self.decay = decay 13 | self.epsilon = epsilon 14 | 15 | embedding = torch.zeros(latent_dim, num_embeddings, embedding_dim) 16 | embedding.uniform_(-1/num_embeddings, 1/num_embeddings) 17 | self.register_buffer("embedding", embedding) 18 | self.register_buffer("ema_count", torch.zeros(latent_dim, num_embeddings)) 19 | self.register_buffer("ema_weight", self.embedding.clone()) 20 | 21 | def forward(self, x): 22 | B, C, H, W = x.size() 23 | N, M, D = self.embedding.size() 24 | assert C == N * D 25 | 26 | x = x.view(B, N, D, H, W).permute(1, 0, 3, 4, 2) 27 | x_flat = x.detach().reshape(N, -1, D) 28 | 29 | distances = torch.baddbmm(torch.sum(self.embedding ** 2, dim=2).unsqueeze(1) + 30 | torch.sum(x_flat ** 2, dim=2, keepdim=True), 31 | x_flat, self.embedding.transpose(1, 2), 32 | alpha=-2.0, beta=1.0) 33 | 34 | indices = torch.argmin(distances, dim=-1) 35 | encodings = F.one_hot(indices, M).float() 36 | quantized = torch.gather(self.embedding, 1, indices.unsqueeze(-1).expand(-1, -1, D)) 37 | quantized = quantized.view_as(x) 38 | 39 | if self.training: 40 | self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=1) 41 | 42 | n = torch.sum(self.ema_count, dim=-1, keepdim=True) 43 | self.ema_count = (self.ema_count + self.epsilon) / (n + M * self.epsilon) * n 44 | 45 | dw = torch.bmm(encodings.transpose(1, 2), x_flat) 46 | self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * dw 47 | 48 | self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1) 49 | 50 | e_latent_loss = F.mse_loss(x, quantized.detach()) 51 | loss = self.commitment_cost * e_latent_loss 52 | 53 | quantized = x + (quantized - x).detach() 54 | 55 | avg_probs = torch.mean(encodings, dim=1) 56 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10), dim=-1)) 57 | 58 | return quantized.permute(1, 0, 4, 2, 3).reshape(B, C, H, W), loss, perplexity.sum() 59 | 60 | 61 | class VQEmbeddingGSSoft(nn.Module): 62 | def __init__(self, latent_dim, 
num_embeddings, embedding_dim): 63 | super(VQEmbeddingGSSoft, self).__init__() 64 | 65 | self.embedding = nn.Parameter(torch.Tensor(latent_dim, num_embeddings, embedding_dim)) 66 | nn.init.uniform_(self.embedding, -1/num_embeddings, 1/num_embeddings) 67 | 68 | def forward(self, x): 69 | B, C, H, W = x.size() 70 | N, M, D = self.embedding.size() 71 | assert C == N * D 72 | 73 | x = x.view(B, N, D, H, W).permute(1, 0, 3, 4, 2) 74 | x_flat = x.reshape(N, -1, D) 75 | 76 | distances = torch.baddbmm(torch.sum(self.embedding ** 2, dim=2).unsqueeze(1) + 77 | torch.sum(x_flat ** 2, dim=2, keepdim=True), 78 | x_flat, self.embedding.transpose(1, 2), 79 | alpha=-2.0, beta=1.0) 80 | distances = distances.view(N, B, H, W, M) 81 | 82 | dist = RelaxedOneHotCategorical(0.5, logits=-distances) 83 | if self.training: 84 | samples = dist.rsample().view(N, -1, M) 85 | else: 86 | samples = torch.argmax(dist.probs, dim=-1) 87 | samples = F.one_hot(samples, M).float() 88 | samples = samples.view(N, -1, M) 89 | 90 | quantized = torch.bmm(samples, self.embedding) 91 | quantized = quantized.view_as(x) 92 | 93 | KL = dist.probs * (dist.logits + math.log(M)) 94 | KL[(dist.probs == 0).expand_as(KL)] = 0 95 | KL = KL.sum(dim=(0, 2, 3, 4)).mean() 96 | 97 | avg_probs = torch.mean(samples, dim=1) 98 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10), dim=-1)) 99 | 100 | return quantized.permute(1, 0, 4, 2, 3).reshape(B, C, H, W), KL, perplexity.sum() 101 | 102 | 103 | class Residual(nn.Module): 104 | def __init__(self, channels): 105 | super(Residual, self).__init__() 106 | self.block = nn.Sequential( 107 | nn.ReLU(True), 108 | nn.Conv2d(channels, channels, 3, 1, 1, bias=False), 109 | nn.BatchNorm2d(channels), 110 | nn.ReLU(True), 111 | nn.Conv2d(channels, channels, 1, bias=False), 112 | nn.BatchNorm2d(channels) 113 | ) 114 | 115 | def forward(self, x): 116 | return x + self.block(x) 117 | 118 | 119 | class Encoder(nn.Module): 120 | def __init__(self, channels, latent_dim, embedding_dim): 121 | super(Encoder, self).__init__() 122 | self.encoder = nn.Sequential( 123 | nn.Conv2d(3, channels, 4, 2, 1, bias=False), 124 | nn.BatchNorm2d(channels), 125 | nn.ReLU(True), 126 | nn.Conv2d(channels, channels, 4, 2, 1, bias=False), 127 | nn.BatchNorm2d(channels), 128 | Residual(channels), 129 | Residual(channels), 130 | nn.Conv2d(channels, latent_dim * embedding_dim, 1) 131 | ) 132 | 133 | def forward(self, x): 134 | return self.encoder(x) 135 | 136 | 137 | class Decoder(nn.Module): 138 | def __init__(self, channels, latent_dim, embedding_dim): 139 | super(Decoder, self).__init__() 140 | self.decoder = nn.Sequential( 141 | nn.Conv2d(latent_dim * embedding_dim, channels, 1, bias=False), 142 | nn.BatchNorm2d(channels), 143 | Residual(channels), 144 | Residual(channels), 145 | nn.ConvTranspose2d(channels, channels, 4, 2, 1, bias=False), 146 | nn.BatchNorm2d(channels), 147 | nn.ReLU(True), 148 | nn.ConvTranspose2d(channels, channels, 4, 2, 1, bias=False), 149 | nn.BatchNorm2d(channels), 150 | nn.ReLU(True), 151 | nn.Conv2d(channels, 3 * 256, 1) 152 | ) 153 | 154 | def forward(self, x): 155 | x = self.decoder(x) 156 | B, _, H, W = x.size() 157 | x = x.view(B, 3, 256, H, W).permute(0, 1, 3, 4, 2) 158 | dist = Categorical(logits=x) 159 | return dist 160 | 161 | 162 | class VQVAE(nn.Module): 163 | def __init__(self, channels, latent_dim, num_embeddings, embedding_dim): 164 | super(VQVAE, self).__init__() 165 | self.encoder = Encoder(channels, latent_dim, embedding_dim) 166 | self.codebook = VQEmbeddingEMA(latent_dim, 
num_embeddings, embedding_dim) 167 | self.decoder = Decoder(channels, latent_dim, embedding_dim) 168 | 169 | def forward(self, x): 170 | x = self.encoder(x) 171 | x, loss, perplexity = self.codebook(x) 172 | dist = self.decoder(x) 173 | return dist, loss, perplexity 174 | 175 | 176 | class GSSOFT(nn.Module): 177 | def __init__(self, channels, latent_dim, num_embeddings, embedding_dim): 178 | super(GSSOFT, self).__init__() 179 | self.encoder = Encoder(channels, latent_dim, embedding_dim) 180 | self.codebook = VQEmbeddingGSSoft(latent_dim, num_embeddings, embedding_dim) 181 | self.decoder = Decoder(channels, latent_dim, embedding_dim) 182 | 183 | def forward(self, x): 184 | x = self.encoder(x) 185 | x, KL, perplexity = self.codebook(x) 186 | dist = self.decoder(x) 187 | return dist, KL, perplexity 188 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | 6 | from tqdm import tqdm 7 | import torch 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torchvision import datasets, transforms, utils 12 | 13 | from model import VQVAE, GSSOFT 14 | 15 | 16 | def save_checkpoint(model, optimizer, step, checkpoint_dir): 17 | checkpoint_state = { 18 | "model": model.state_dict(), 19 | "optimizer": optimizer.state_dict(), 20 | "step": step} 21 | checkpoint_path = checkpoint_dir / "model.ckpt-{}.pt".format(step) 22 | torch.save(checkpoint_state, checkpoint_path) 23 | print("Saved checkpoint: {}".format(checkpoint_path)) 24 | 25 | 26 | def shift(x): 27 | return x - 0.5 28 | 29 | 30 | def train_gssoft(args): 31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 32 | 33 | model = GSSOFT(args.channels, args.latent_dim, args.num_embeddings, args.embedding_dim) 34 | model.to(device) 35 | 36 | model_name = "{}_C_{}_N_{}_M_{}_D_{}".format(args.model, args.channels, args.latent_dim, 37 | args.num_embeddings, args.embedding_dim) 38 | 39 | checkpoint_dir = Path(model_name) 40 | checkpoint_dir.mkdir(parents=True, exist_ok=True) 41 | 42 | writer = SummaryWriter(log_dir=Path("runs") / model_name) 43 | 44 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) 45 | 46 | if args.resume is not None: 47 | print("Resume checkpoint from: {}:".format(args.resume)) 48 | checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage) 49 | model.load_state_dict(checkpoint["model"]) 50 | optimizer.load_state_dict(checkpoint["optimizer"]) 51 | global_step = checkpoint["step"] 52 | else: 53 | global_step = 0 54 | 55 | transform = transforms.Compose([ 56 | transforms.ToTensor(), 57 | transforms.Lambda(shift) 58 | ]) 59 | training_dataset = datasets.CIFAR10("./CIFAR10", train=True, download=True, 60 | transform=transform) 61 | 62 | test_dataset = datasets.CIFAR10("./CIFAR10", train=False, download=True, 63 | transform=transform) 64 | 65 | training_dataloader = DataLoader(training_dataset, batch_size=args.batch_size, shuffle=True, 66 | num_workers=args.num_workers, pin_memory=True) 67 | 68 | test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True, drop_last=True, 69 | num_workers=args.num_workers, pin_memory=True) 70 | 71 | num_epochs = args.num_training_steps // len(training_dataloader) + 1 72 | start_epoch = global_step // len(training_dataloader) + 1 73 | 74 | N = 3 * 32 * 32 75 | 76 
| for epoch in range(start_epoch, num_epochs + 1): 77 | model.train() 78 | average_logp = average_KL = average_elbo = average_bpd = average_perplexity = 0 79 | for i, (images, _) in enumerate(tqdm(training_dataloader), 1): 80 | images = images.to(device) 81 | 82 | dist, KL, perplexity = model(images) 83 | targets = (images + 0.5) * 255 84 | targets = targets.long() 85 | logp = dist.log_prob(targets).sum((1, 2, 3)).mean() 86 | loss = (KL - logp) / N 87 | elbo = (KL - logp) / N 88 | bpd = elbo / np.log(2) 89 | 90 | optimizer.zero_grad() 91 | loss.backward() 92 | optimizer.step() 93 | 94 | global_step += 1 95 | 96 | if global_step % 25000 == 0: 97 | save_checkpoint(model, optimizer, global_step, checkpoint_dir) 98 | 99 | average_logp += (logp.item() - average_logp) / i 100 | average_KL += (KL.item() - average_KL) / i 101 | average_elbo += (elbo.item() - average_elbo) / i 102 | average_bpd += (bpd.item() - average_bpd) / i 103 | average_perplexity += (perplexity.item() - average_perplexity) / i 104 | 105 | writer.add_scalar("logp/train", average_logp, epoch) 106 | writer.add_scalar("kl/train", average_KL, epoch) 107 | writer.add_scalar("elbo/train", average_elbo, epoch) 108 | writer.add_scalar("bpd/train", average_bpd, epoch) 109 | writer.add_scalar("perplexity/train", average_perplexity, epoch) 110 | 111 | model.eval() 112 | average_logp = average_KL = average_elbo = average_bpd = average_perplexity = 0 113 | for i, (images, _) in enumerate(test_dataloader, 1): 114 | images = images.to(device) 115 | 116 | with torch.no_grad(): 117 | dist, KL, perplexity = model(images) 118 | 119 | targets = (images + 0.5) * 255 120 | targets = targets.long() 121 | logp = dist.log_prob(targets).sum((1, 2, 3)).mean() 122 | elbo = (KL - logp) / N 123 | bpd = elbo / np.log(2) 124 | 125 | average_logp += (logp.item() - average_logp) / i 126 | average_KL += (KL.item() - average_KL) / i 127 | average_elbo += (elbo.item() - average_elbo) / i 128 | average_bpd += (bpd.item() - average_bpd) / i 129 | average_perplexity += (perplexity.item() - average_perplexity) / i 130 | 131 | writer.add_scalar("logp/test", average_logp, epoch) 132 | writer.add_scalar("kl/test", average_KL, epoch) 133 | writer.add_scalar("elbo/test", average_elbo, epoch) 134 | writer.add_scalar("bpd/test", average_bpd, epoch) 135 | writer.add_scalar("perplexity/test", average_perplexity, epoch) 136 | 137 | samples = torch.argmax(dist.logits, dim=-1) 138 | grid = utils.make_grid(samples.float() / 255) 139 | writer.add_image("reconstructions", grid, epoch) 140 | 141 | print("epoch:{}, logp:{:.3E}, KL:{:.3E}, elbo:{:.3f}, bpd:{:.3f}, perplexity:{:.3f}" 142 | .format(epoch, average_logp, average_KL, average_elbo, average_bpd, average_perplexity)) 143 | 144 | 145 | def train_vqvae(args): 146 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 147 | 148 | model = VQVAE(args.channels, args.latent_dim, args.num_embeddings, args.embedding_dim) 149 | model.to(device) 150 | 151 | model_name = "{}_C_{}_N_{}_M_{}_D_{}".format(args.model, args.channels, args.latent_dim, 152 | args.num_embeddings, args.embedding_dim) 153 | 154 | checkpoint_dir = Path(model_name) 155 | checkpoint_dir.mkdir(parents=True, exist_ok=True) 156 | 157 | writer = SummaryWriter(log_dir=Path("runs") / model_name) 158 | 159 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) 160 | 161 | if args.resume is not None: 162 | print("Resume checkpoint from: {}:".format(args.resume)) 163 | checkpoint = torch.load(args.resume, map_location=lambda storage, loc: 
storage) 164 | model.load_state_dict(checkpoint["model"]) 165 | optimizer.load_state_dict(checkpoint["optimizer"]) 166 | global_step = checkpoint["step"] 167 | else: 168 | global_step = 0 169 | 170 | transform = transforms.Compose([ 171 | transforms.ToTensor(), 172 | transforms.Lambda(shift) 173 | ]) 174 | training_dataset = datasets.CIFAR10("./CIFAR10", train=True, download=True, 175 | transform=transform) 176 | 177 | test_dataset = datasets.CIFAR10("./CIFAR10", train=False, download=True, 178 | transform=transform) 179 | 180 | training_dataloader = DataLoader(training_dataset, batch_size=args.batch_size, shuffle=True, 181 | num_workers=args.num_workers, pin_memory=True) 182 | 183 | test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True, drop_last=True, 184 | num_workers=args.num_workers, pin_memory=True) 185 | 186 | num_epochs = args.num_training_steps // len(training_dataloader) + 1 187 | start_epoch = global_step // len(training_dataloader) + 1 188 | 189 | N = 3 * 32 * 32 190 | KL = args.latent_dim * 8 * 8 * np.log(args.num_embeddings) 191 | 192 | for epoch in range(start_epoch, num_epochs + 1): 193 | model.train() 194 | average_logp = average_vq_loss = average_elbo = average_bpd = average_perplexity = 0 195 | for i, (images, _) in enumerate(tqdm(training_dataloader), 1): 196 | images = images.to(device) 197 | 198 | dist, vq_loss, perplexity = model(images) 199 | targets = (images + 0.5) * 255 200 | targets = targets.long() 201 | logp = dist.log_prob(targets).sum((1, 2, 3)).mean() 202 | loss = - logp / N + vq_loss 203 | elbo = (KL - logp) / N 204 | bpd = elbo / np.log(2) 205 | 206 | optimizer.zero_grad() 207 | loss.backward() 208 | optimizer.step() 209 | 210 | global_step += 1 211 | 212 | if global_step % 25000 == 0: 213 | save_checkpoint(model, optimizer, global_step, checkpoint_dir) 214 | 215 | average_logp += (logp.item() - average_logp) / i 216 | average_vq_loss += (vq_loss.item() - average_vq_loss) / i 217 | average_elbo += (elbo.item() - average_elbo) / i 218 | average_bpd += (bpd.item() - average_bpd) / i 219 | average_perplexity += (perplexity.item() - average_perplexity) / i 220 | 221 | writer.add_scalar("logp/train", average_logp, epoch) 222 | writer.add_scalar("kl/train", KL, epoch) 223 | writer.add_scalar("vqloss/train", average_vq_loss, epoch) 224 | writer.add_scalar("elbo/train", average_elbo, epoch) 225 | writer.add_scalar("bpd/train", average_bpd, epoch) 226 | writer.add_scalar("perplexity/train", average_perplexity, epoch) 227 | 228 | model.eval() 229 | average_logp = average_vq_loss = average_elbo = average_bpd = average_perplexity = 0 230 | for i, (images, _) in enumerate(test_dataloader, 1): 231 | images = images.to(device) 232 | 233 | with torch.no_grad(): 234 | dist, vq_loss, perplexity = model(images) 235 | 236 | targets = (images + 0.5) * 255 237 | targets = targets.long() 238 | logp = dist.log_prob(targets).sum((1, 2, 3)).mean() 239 | elbo = (KL - logp) / N 240 | bpd = elbo / np.log(2) 241 | 242 | average_logp += (logp.item() - average_logp) / i 243 | average_vq_loss += (vq_loss.item() - average_vq_loss) / i 244 | average_elbo += (elbo.item() - average_elbo) / i 245 | average_bpd += (bpd.item() - average_bpd) / i 246 | average_perplexity += (perplexity.item() - average_perplexity) / i 247 | 248 | writer.add_scalar("logp/test", average_logp, epoch) 249 | writer.add_scalar("kl/test", KL, epoch) 250 | writer.add_scalar("vqloss/test", average_vq_loss, epoch) 251 | writer.add_scalar("elbo/test", average_elbo, epoch) 252 | 
writer.add_scalar("bpd/test", average_bpd, epoch) 253 | writer.add_scalar("perplexity/test", average_perplexity, epoch) 254 | 255 | samples = torch.argmax(dist.logits, dim=-1) 256 | grid = utils.make_grid(samples.float() / 255) 257 | writer.add_image("reconstructions", grid, epoch) 258 | 259 | print("epoch:{}, logp:{:.3E}, vq loss:{:.3E}, elbo:{:.3f}, bpd:{:.3f}, perplexity:{:.3f}" 260 | .format(epoch, average_logp, average_vq_loss, average_elbo, average_bpd, average_perplexity)) 261 | 262 | 263 | if __name__ == "__main__": 264 | parser = argparse.ArgumentParser() 265 | parser.add_argument("--num-workers", type=int, default=4, help="Number of dataloader workers.") 266 | parser.add_argument("--resume", type=str, default=None, help="Checkpoint path to resume.") 267 | parser.add_argument("--model", choices=["VQVAE", "GSSOFT"], help="Select model to train (either VQVAE or GSSOFT)") 268 | parser.add_argument("--channels", type=int, default=256, help="Number of channels in conv layers.") 269 | parser.add_argument("--latent-dim", type=int, default=8, help="Dimension of categorical latents.") 270 | parser.add_argument("--num-embeddings", type=int, default=128, help="Number of codebook embeddings size.") 271 | parser.add_argument("--embedding-dim", type=int, default=32, help="Dimension of codebook embeddings.") 272 | parser.add_argument("--learning-rate", type=float, default=5e-4, help="Learning rate.") 273 | parser.add_argument("--batch-size", type=int, default=128, help="Batch size.") 274 | parser.add_argument("--num-training-steps", type=int, default=250000, help="Number of training steps.") 275 | args = parser.parse_args() 276 | if args.model == "VQVAE": 277 | train_vqvae(args) 278 | if args.model == "GSSOFT": 279 | train_gssoft(args) 280 | --------------------------------------------------------------------------------