├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── codebooks.png
│   ├── reconstructions.png
│   └── variance_ratio.png
├── model.py
├── plot.ipynb
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # PyCharm project settings
98 | .idea
99 |
100 | # Rope project settings
101 | .ropeproject
102 |
103 | # mkdocs documentation
104 | /site
105 |
106 | # mypy
107 | .mypy_cache/
108 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 bshall
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Vector Quantized VAE
2 | A PyTorch implementation of [Continuous Relaxation Training of Discrete Latent Variable Image Models](http://bayesiandeeplearning.org/2017/papers/54.pdf).
3 |
4 | Ensure you have Python 3.7 and PyTorch 1.2 or greater.
5 | To train the `VQVAE` model with 8 categorical dimensions and 128 codes per dimension
6 | run the following command:
7 | ```
8 | python train.py --model=VQVAE --latent-dim=8 --num-embeddings=128
9 | ```
10 | To train the `GS-Soft` model use `--model=GSSOFT`.
11 | Pretrained weights for the `VQVAE` and `GS-Soft` models can be found
12 | [here](https://github.com/bshall/VectorQuantizedVAE/releases/tag/v0.1).
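
A minimal sketch for loading one of these checkpoints into the `VQVAE` class (the checkpoint filename below is a placeholder; `train.py` saves checkpoints as a dict with `"model"`, `"optimizer"`, and `"step"` keys, and the constructor arguments are the `train.py` defaults):
```
import torch
from model import VQVAE

# Defaults from train.py: channels=256, latent_dim=8, num_embeddings=128, embedding_dim=32.
model = VQVAE(channels=256, latent_dim=8, num_embeddings=128, embedding_dim=32)

# Placeholder path: point this at the checkpoint downloaded from the releases page.
checkpoint = torch.load("model.ckpt-250000.pt", map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()
```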
13 |
14 | ![Reconstructions](assets/reconstructions.png)
15 |
16 |
17 |
18 | The `VQVAE` model gets ~4.82 bpd while the `GS-Soft` model gets ~4.6 bpd.
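
(Both figures are bits per dimension, computed in `train.py` as the negative ELBO per observed dimension: `bpd = (KL - logp) / (3 * 32 * 32 * log(2))`, with natural logarithms. For the `VQVAE`, the KL term is the constant `latent_dim * 8 * 8 * log(num_embeddings)` of the uniform prior over the 8×8 grid of codes.)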
19 |
20 | # Analysis of the Codebooks
21 |
22 | As demonstrated in the paper, the learned codebook matrices are effectively low-dimensional: the codes span only a few directions of the embedding space:
23 |
24 | ![Explained variance ratio of the codebook principal components](assets/variance_ratio.png)
25 |
26 |
27 |
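
A minimal sketch of how this analysis can be reproduced from a trained checkpoint (the checkpoint path is a placeholder; the codebook tensor in `model.py` has shape `(latent_dim, num_embeddings, embedding_dim)`):
```
import torch
from model import VQVAE

model = VQVAE(channels=256, latent_dim=8, num_embeddings=128, embedding_dim=32)
checkpoint = torch.load("model.ckpt-250000.pt", map_location="cpu")  # placeholder path
model.load_state_dict(checkpoint["model"])

# Codebooks: one (num_embeddings, embedding_dim) matrix per latent dimension.
codebooks = model.codebook.embedding.detach()

for n, codes in enumerate(codebooks):
    codes = codes - codes.mean(dim=0)            # center the codes
    _, s, _ = torch.svd(codes)                   # singular values of the centered codebook
    variance_ratio = s ** 2 / (s ** 2).sum()     # explained variance per principal component
    print(f"codebook {n}: top-3 components explain "
          f"{variance_ratio[:3].sum().item():.1%} of the variance")
```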
28 | Projecting the codes onto the first 3 principal components shows that the codes typically tile
29 | continuous 1- or 2-D manifolds:
30 |
31 | ![Codes projected onto the first 3 principal components](assets/codebooks.png)
32 |
33 |
34 |
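
Continuing from the sketch above, the 3-D projections can be produced by projecting each codebook onto its top principal directions (matplotlib assumed; the 2×4 grid matches the default of 8 latent dimensions):
```
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(12, 6))
for n, codes in enumerate(codebooks):
    codes = codes - codes.mean(dim=0)
    _, _, v = torch.svd(codes)            # principal directions in the columns of v
    projected = codes @ v[:, :3]          # project the codes onto the first 3 components
    ax = fig.add_subplot(2, 4, n + 1, projection="3d")
    ax.scatter(*projected.T.numpy())
    ax.set_title(f"codebook {n}")
plt.show()
```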
--------------------------------------------------------------------------------
/assets/codebooks.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bshall/VectorQuantizedVAE/d7bc845c9d46d232f5c6aae1a825b8c6952084dc/assets/codebooks.png
--------------------------------------------------------------------------------
/assets/reconstructions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bshall/VectorQuantizedVAE/d7bc845c9d46d232f5c6aae1a825b8c6952084dc/assets/reconstructions.png
--------------------------------------------------------------------------------
/assets/variance_ratio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bshall/VectorQuantizedVAE/d7bc845c9d46d232f5c6aae1a825b8c6952084dc/assets/variance_ratio.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions import Categorical, RelaxedOneHotCategorical
5 | import math
6 |
7 |
8 | class VQEmbeddingEMA(nn.Module):
9 | def __init__(self, latent_dim, num_embeddings, embedding_dim, commitment_cost=0.25, decay=0.999, epsilon=1e-5):
10 | super(VQEmbeddingEMA, self).__init__()
11 | self.commitment_cost = commitment_cost
12 | self.decay = decay
13 | self.epsilon = epsilon
14 |
15 | embedding = torch.zeros(latent_dim, num_embeddings, embedding_dim)
16 | embedding.uniform_(-1/num_embeddings, 1/num_embeddings)
17 | self.register_buffer("embedding", embedding)
18 | self.register_buffer("ema_count", torch.zeros(latent_dim, num_embeddings))
19 | self.register_buffer("ema_weight", self.embedding.clone())
20 |
21 | def forward(self, x):
22 | B, C, H, W = x.size()
23 | N, M, D = self.embedding.size()
24 | assert C == N * D
25 |
26 | x = x.view(B, N, D, H, W).permute(1, 0, 3, 4, 2)
27 | x_flat = x.detach().reshape(N, -1, D)
28 |
29 | distances = torch.baddbmm(torch.sum(self.embedding ** 2, dim=2).unsqueeze(1) +
30 | torch.sum(x_flat ** 2, dim=2, keepdim=True),
31 | x_flat, self.embedding.transpose(1, 2),
32 | alpha=-2.0, beta=1.0)
33 |
34 | indices = torch.argmin(distances, dim=-1)
35 | encodings = F.one_hot(indices, M).float()
36 | quantized = torch.gather(self.embedding, 1, indices.unsqueeze(-1).expand(-1, -1, D))
37 | quantized = quantized.view_as(x)
38 |
39 | if self.training:
40 | self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=1)
41 |
42 | n = torch.sum(self.ema_count, dim=-1, keepdim=True)
43 | self.ema_count = (self.ema_count + self.epsilon) / (n + M * self.epsilon) * n
44 |
45 | dw = torch.bmm(encodings.transpose(1, 2), x_flat)
46 | self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * dw
47 |
48 | self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1)
49 |
50 | e_latent_loss = F.mse_loss(x, quantized.detach())
51 | loss = self.commitment_cost * e_latent_loss
52 |
53 |         quantized = x + (quantized - x).detach()  # straight-through estimator: pass decoder gradients to the encoder output
54 |
55 | avg_probs = torch.mean(encodings, dim=1)
56 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10), dim=-1))
57 |
58 | return quantized.permute(1, 0, 4, 2, 3).reshape(B, C, H, W), loss, perplexity.sum()
59 |
60 |
61 | class VQEmbeddingGSSoft(nn.Module):
62 | def __init__(self, latent_dim, num_embeddings, embedding_dim):
63 | super(VQEmbeddingGSSoft, self).__init__()
64 |
65 | self.embedding = nn.Parameter(torch.Tensor(latent_dim, num_embeddings, embedding_dim))
66 | nn.init.uniform_(self.embedding, -1/num_embeddings, 1/num_embeddings)
67 |
68 | def forward(self, x):
69 | B, C, H, W = x.size()
70 | N, M, D = self.embedding.size()
71 | assert C == N * D
72 |
73 | x = x.view(B, N, D, H, W).permute(1, 0, 3, 4, 2)
74 | x_flat = x.reshape(N, -1, D)
75 |
76 | distances = torch.baddbmm(torch.sum(self.embedding ** 2, dim=2).unsqueeze(1) +
77 | torch.sum(x_flat ** 2, dim=2, keepdim=True),
78 | x_flat, self.embedding.transpose(1, 2),
79 | alpha=-2.0, beta=1.0)
80 | distances = distances.view(N, B, H, W, M)
81 |
82 | dist = RelaxedOneHotCategorical(0.5, logits=-distances)
83 | if self.training:
84 | samples = dist.rsample().view(N, -1, M)
85 | else:
86 | samples = torch.argmax(dist.probs, dim=-1)
87 | samples = F.one_hot(samples, M).float()
88 | samples = samples.view(N, -1, M)
89 |
90 | quantized = torch.bmm(samples, self.embedding)
91 | quantized = quantized.view_as(x)
92 |
93 |         KL = dist.probs * (dist.logits + math.log(M))  # KL to a uniform prior over the M codes (log M is the uniform log-probability offset)
94 | KL[(dist.probs == 0).expand_as(KL)] = 0
95 | KL = KL.sum(dim=(0, 2, 3, 4)).mean()
96 |
97 | avg_probs = torch.mean(samples, dim=1)
98 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10), dim=-1))
99 |
100 | return quantized.permute(1, 0, 4, 2, 3).reshape(B, C, H, W), KL, perplexity.sum()
101 |
102 |
103 | class Residual(nn.Module):
104 | def __init__(self, channels):
105 | super(Residual, self).__init__()
106 | self.block = nn.Sequential(
107 | nn.ReLU(True),
108 | nn.Conv2d(channels, channels, 3, 1, 1, bias=False),
109 | nn.BatchNorm2d(channels),
110 | nn.ReLU(True),
111 | nn.Conv2d(channels, channels, 1, bias=False),
112 | nn.BatchNorm2d(channels)
113 | )
114 |
115 | def forward(self, x):
116 | return x + self.block(x)
117 |
118 |
119 | class Encoder(nn.Module):
120 | def __init__(self, channels, latent_dim, embedding_dim):
121 | super(Encoder, self).__init__()
122 | self.encoder = nn.Sequential(
123 | nn.Conv2d(3, channels, 4, 2, 1, bias=False),
124 | nn.BatchNorm2d(channels),
125 | nn.ReLU(True),
126 | nn.Conv2d(channels, channels, 4, 2, 1, bias=False),
127 | nn.BatchNorm2d(channels),
128 | Residual(channels),
129 | Residual(channels),
130 | nn.Conv2d(channels, latent_dim * embedding_dim, 1)
131 | )
132 |
133 | def forward(self, x):
134 | return self.encoder(x)
135 |
136 |
137 | class Decoder(nn.Module):
138 | def __init__(self, channels, latent_dim, embedding_dim):
139 | super(Decoder, self).__init__()
140 | self.decoder = nn.Sequential(
141 | nn.Conv2d(latent_dim * embedding_dim, channels, 1, bias=False),
142 | nn.BatchNorm2d(channels),
143 | Residual(channels),
144 | Residual(channels),
145 | nn.ConvTranspose2d(channels, channels, 4, 2, 1, bias=False),
146 | nn.BatchNorm2d(channels),
147 | nn.ReLU(True),
148 | nn.ConvTranspose2d(channels, channels, 4, 2, 1, bias=False),
149 | nn.BatchNorm2d(channels),
150 | nn.ReLU(True),
151 | nn.Conv2d(channels, 3 * 256, 1)
152 | )
153 |
154 | def forward(self, x):
155 | x = self.decoder(x)
156 | B, _, H, W = x.size()
157 | x = x.view(B, 3, 256, H, W).permute(0, 1, 3, 4, 2)
158 | dist = Categorical(logits=x)
159 | return dist
160 |
161 |
162 | class VQVAE(nn.Module):
163 | def __init__(self, channels, latent_dim, num_embeddings, embedding_dim):
164 | super(VQVAE, self).__init__()
165 | self.encoder = Encoder(channels, latent_dim, embedding_dim)
166 | self.codebook = VQEmbeddingEMA(latent_dim, num_embeddings, embedding_dim)
167 | self.decoder = Decoder(channels, latent_dim, embedding_dim)
168 |
169 | def forward(self, x):
170 | x = self.encoder(x)
171 | x, loss, perplexity = self.codebook(x)
172 | dist = self.decoder(x)
173 | return dist, loss, perplexity
174 |
175 |
176 | class GSSOFT(nn.Module):
177 | def __init__(self, channels, latent_dim, num_embeddings, embedding_dim):
178 | super(GSSOFT, self).__init__()
179 | self.encoder = Encoder(channels, latent_dim, embedding_dim)
180 | self.codebook = VQEmbeddingGSSoft(latent_dim, num_embeddings, embedding_dim)
181 | self.decoder = Decoder(channels, latent_dim, embedding_dim)
182 |
183 | def forward(self, x):
184 | x = self.encoder(x)
185 | x, KL, perplexity = self.codebook(x)
186 | dist = self.decoder(x)
187 | return dist, KL, perplexity
188 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import numpy as np
5 |
6 | from tqdm import tqdm
7 | import torch
8 | import torch.optim as optim
9 | from torch.utils.data import DataLoader
10 | from torch.utils.tensorboard import SummaryWriter
11 | from torchvision import datasets, transforms, utils
12 |
13 | from model import VQVAE, GSSOFT
14 |
15 |
16 | def save_checkpoint(model, optimizer, step, checkpoint_dir):
17 | checkpoint_state = {
18 | "model": model.state_dict(),
19 | "optimizer": optimizer.state_dict(),
20 | "step": step}
21 | checkpoint_path = checkpoint_dir / "model.ckpt-{}.pt".format(step)
22 | torch.save(checkpoint_state, checkpoint_path)
23 | print("Saved checkpoint: {}".format(checkpoint_path))
24 |
25 |
26 | def shift(x):
27 | return x - 0.5
28 |
29 |
30 | def train_gssoft(args):
31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32 |
33 | model = GSSOFT(args.channels, args.latent_dim, args.num_embeddings, args.embedding_dim)
34 | model.to(device)
35 |
36 | model_name = "{}_C_{}_N_{}_M_{}_D_{}".format(args.model, args.channels, args.latent_dim,
37 | args.num_embeddings, args.embedding_dim)
38 |
39 | checkpoint_dir = Path(model_name)
40 | checkpoint_dir.mkdir(parents=True, exist_ok=True)
41 |
42 | writer = SummaryWriter(log_dir=Path("runs") / model_name)
43 |
44 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
45 |
46 | if args.resume is not None:
47 |         print("Resuming from checkpoint: {}".format(args.resume))
48 | checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
49 | model.load_state_dict(checkpoint["model"])
50 | optimizer.load_state_dict(checkpoint["optimizer"])
51 | global_step = checkpoint["step"]
52 | else:
53 | global_step = 0
54 |
55 | transform = transforms.Compose([
56 | transforms.ToTensor(),
57 | transforms.Lambda(shift)
58 | ])
59 | training_dataset = datasets.CIFAR10("./CIFAR10", train=True, download=True,
60 | transform=transform)
61 |
62 | test_dataset = datasets.CIFAR10("./CIFAR10", train=False, download=True,
63 | transform=transform)
64 |
65 | training_dataloader = DataLoader(training_dataset, batch_size=args.batch_size, shuffle=True,
66 | num_workers=args.num_workers, pin_memory=True)
67 |
68 | test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True, drop_last=True,
69 | num_workers=args.num_workers, pin_memory=True)
70 |
71 | num_epochs = args.num_training_steps // len(training_dataloader) + 1
72 | start_epoch = global_step // len(training_dataloader) + 1
73 |
74 | N = 3 * 32 * 32
75 |
76 | for epoch in range(start_epoch, num_epochs + 1):
77 | model.train()
78 | average_logp = average_KL = average_elbo = average_bpd = average_perplexity = 0
79 | for i, (images, _) in enumerate(tqdm(training_dataloader), 1):
80 | images = images.to(device)
81 |
82 | dist, KL, perplexity = model(images)
83 | targets = (images + 0.5) * 255
84 | targets = targets.long()
85 | logp = dist.log_prob(targets).sum((1, 2, 3)).mean()
86 | loss = (KL - logp) / N
87 | elbo = (KL - logp) / N
88 | bpd = elbo / np.log(2)
89 |
90 | optimizer.zero_grad()
91 | loss.backward()
92 | optimizer.step()
93 |
94 | global_step += 1
95 |
96 | if global_step % 25000 == 0:
97 | save_checkpoint(model, optimizer, global_step, checkpoint_dir)
98 |
99 | average_logp += (logp.item() - average_logp) / i
100 | average_KL += (KL.item() - average_KL) / i
101 | average_elbo += (elbo.item() - average_elbo) / i
102 | average_bpd += (bpd.item() - average_bpd) / i
103 | average_perplexity += (perplexity.item() - average_perplexity) / i
104 |
105 | writer.add_scalar("logp/train", average_logp, epoch)
106 | writer.add_scalar("kl/train", average_KL, epoch)
107 | writer.add_scalar("elbo/train", average_elbo, epoch)
108 | writer.add_scalar("bpd/train", average_bpd, epoch)
109 | writer.add_scalar("perplexity/train", average_perplexity, epoch)
110 |
111 | model.eval()
112 | average_logp = average_KL = average_elbo = average_bpd = average_perplexity = 0
113 | for i, (images, _) in enumerate(test_dataloader, 1):
114 | images = images.to(device)
115 |
116 | with torch.no_grad():
117 | dist, KL, perplexity = model(images)
118 |
119 | targets = (images + 0.5) * 255
120 | targets = targets.long()
121 | logp = dist.log_prob(targets).sum((1, 2, 3)).mean()
122 | elbo = (KL - logp) / N
123 | bpd = elbo / np.log(2)
124 |
125 | average_logp += (logp.item() - average_logp) / i
126 | average_KL += (KL.item() - average_KL) / i
127 | average_elbo += (elbo.item() - average_elbo) / i
128 | average_bpd += (bpd.item() - average_bpd) / i
129 | average_perplexity += (perplexity.item() - average_perplexity) / i
130 |
131 | writer.add_scalar("logp/test", average_logp, epoch)
132 | writer.add_scalar("kl/test", average_KL, epoch)
133 | writer.add_scalar("elbo/test", average_elbo, epoch)
134 | writer.add_scalar("bpd/test", average_bpd, epoch)
135 | writer.add_scalar("perplexity/test", average_perplexity, epoch)
136 |
137 | samples = torch.argmax(dist.logits, dim=-1)
138 | grid = utils.make_grid(samples.float() / 255)
139 | writer.add_image("reconstructions", grid, epoch)
140 |
141 | print("epoch:{}, logp:{:.3E}, KL:{:.3E}, elbo:{:.3f}, bpd:{:.3f}, perplexity:{:.3f}"
142 | .format(epoch, average_logp, average_KL, average_elbo, average_bpd, average_perplexity))
143 |
144 |
145 | def train_vqvae(args):
146 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
147 |
148 | model = VQVAE(args.channels, args.latent_dim, args.num_embeddings, args.embedding_dim)
149 | model.to(device)
150 |
151 | model_name = "{}_C_{}_N_{}_M_{}_D_{}".format(args.model, args.channels, args.latent_dim,
152 | args.num_embeddings, args.embedding_dim)
153 |
154 | checkpoint_dir = Path(model_name)
155 | checkpoint_dir.mkdir(parents=True, exist_ok=True)
156 |
157 | writer = SummaryWriter(log_dir=Path("runs") / model_name)
158 |
159 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
160 |
161 | if args.resume is not None:
162 |         print("Resuming from checkpoint: {}".format(args.resume))
163 | checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
164 | model.load_state_dict(checkpoint["model"])
165 | optimizer.load_state_dict(checkpoint["optimizer"])
166 | global_step = checkpoint["step"]
167 | else:
168 | global_step = 0
169 |
170 | transform = transforms.Compose([
171 | transforms.ToTensor(),
172 | transforms.Lambda(shift)
173 | ])
174 | training_dataset = datasets.CIFAR10("./CIFAR10", train=True, download=True,
175 | transform=transform)
176 |
177 | test_dataset = datasets.CIFAR10("./CIFAR10", train=False, download=True,
178 | transform=transform)
179 |
180 | training_dataloader = DataLoader(training_dataset, batch_size=args.batch_size, shuffle=True,
181 | num_workers=args.num_workers, pin_memory=True)
182 |
183 | test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True, drop_last=True,
184 | num_workers=args.num_workers, pin_memory=True)
185 |
186 | num_epochs = args.num_training_steps // len(training_dataloader) + 1
187 | start_epoch = global_step // len(training_dataloader) + 1
188 |
189 | N = 3 * 32 * 32
190 | KL = args.latent_dim * 8 * 8 * np.log(args.num_embeddings)
191 |
192 | for epoch in range(start_epoch, num_epochs + 1):
193 | model.train()
194 | average_logp = average_vq_loss = average_elbo = average_bpd = average_perplexity = 0
195 | for i, (images, _) in enumerate(tqdm(training_dataloader), 1):
196 | images = images.to(device)
197 |
198 | dist, vq_loss, perplexity = model(images)
199 | targets = (images + 0.5) * 255
200 | targets = targets.long()
201 | logp = dist.log_prob(targets).sum((1, 2, 3)).mean()
202 | loss = - logp / N + vq_loss
203 | elbo = (KL - logp) / N
204 | bpd = elbo / np.log(2)
205 |
206 | optimizer.zero_grad()
207 | loss.backward()
208 | optimizer.step()
209 |
210 | global_step += 1
211 |
212 | if global_step % 25000 == 0:
213 | save_checkpoint(model, optimizer, global_step, checkpoint_dir)
214 |
215 | average_logp += (logp.item() - average_logp) / i
216 | average_vq_loss += (vq_loss.item() - average_vq_loss) / i
217 | average_elbo += (elbo.item() - average_elbo) / i
218 | average_bpd += (bpd.item() - average_bpd) / i
219 | average_perplexity += (perplexity.item() - average_perplexity) / i
220 |
221 | writer.add_scalar("logp/train", average_logp, epoch)
222 | writer.add_scalar("kl/train", KL, epoch)
223 | writer.add_scalar("vqloss/train", average_vq_loss, epoch)
224 | writer.add_scalar("elbo/train", average_elbo, epoch)
225 | writer.add_scalar("bpd/train", average_bpd, epoch)
226 | writer.add_scalar("perplexity/train", average_perplexity, epoch)
227 |
228 | model.eval()
229 | average_logp = average_vq_loss = average_elbo = average_bpd = average_perplexity = 0
230 | for i, (images, _) in enumerate(test_dataloader, 1):
231 | images = images.to(device)
232 |
233 | with torch.no_grad():
234 | dist, vq_loss, perplexity = model(images)
235 |
236 | targets = (images + 0.5) * 255
237 | targets = targets.long()
238 | logp = dist.log_prob(targets).sum((1, 2, 3)).mean()
239 | elbo = (KL - logp) / N
240 | bpd = elbo / np.log(2)
241 |
242 | average_logp += (logp.item() - average_logp) / i
243 | average_vq_loss += (vq_loss.item() - average_vq_loss) / i
244 | average_elbo += (elbo.item() - average_elbo) / i
245 | average_bpd += (bpd.item() - average_bpd) / i
246 | average_perplexity += (perplexity.item() - average_perplexity) / i
247 |
248 | writer.add_scalar("logp/test", average_logp, epoch)
249 | writer.add_scalar("kl/test", KL, epoch)
250 | writer.add_scalar("vqloss/test", average_vq_loss, epoch)
251 | writer.add_scalar("elbo/test", average_elbo, epoch)
252 | writer.add_scalar("bpd/test", average_bpd, epoch)
253 | writer.add_scalar("perplexity/test", average_perplexity, epoch)
254 |
255 | samples = torch.argmax(dist.logits, dim=-1)
256 | grid = utils.make_grid(samples.float() / 255)
257 | writer.add_image("reconstructions", grid, epoch)
258 |
259 | print("epoch:{}, logp:{:.3E}, vq loss:{:.3E}, elbo:{:.3f}, bpd:{:.3f}, perplexity:{:.3f}"
260 | .format(epoch, average_logp, average_vq_loss, average_elbo, average_bpd, average_perplexity))
261 |
262 |
263 | if __name__ == "__main__":
264 | parser = argparse.ArgumentParser()
265 | parser.add_argument("--num-workers", type=int, default=4, help="Number of dataloader workers.")
266 | parser.add_argument("--resume", type=str, default=None, help="Checkpoint path to resume.")
267 | parser.add_argument("--model", choices=["VQVAE", "GSSOFT"], help="Select model to train (either VQVAE or GSSOFT)")
268 | parser.add_argument("--channels", type=int, default=256, help="Number of channels in conv layers.")
269 | parser.add_argument("--latent-dim", type=int, default=8, help="Dimension of categorical latents.")
270 | parser.add_argument("--num-embeddings", type=int, default=128, help="Number of codebook embeddings size.")
271 | parser.add_argument("--embedding-dim", type=int, default=32, help="Dimension of codebook embeddings.")
272 | parser.add_argument("--learning-rate", type=float, default=5e-4, help="Learning rate.")
273 | parser.add_argument("--batch-size", type=int, default=128, help="Batch size.")
274 | parser.add_argument("--num-training-steps", type=int, default=250000, help="Number of training steps.")
275 | args = parser.parse_args()
276 | if args.model == "VQVAE":
277 | train_vqvae(args)
278 | if args.model == "GSSOFT":
279 | train_gssoft(args)
280 |
--------------------------------------------------------------------------------