├── .gitignore
├── air.py
├── basic_vae.py
├── datasets
│   └── chexpert.py
├── dcgan.py
├── draw.py
├── draw_test_attn.py
├── environment.yml
├── images
│   ├── air
│   │   ├── air_count.png
│   │   ├── air_elbo.png
│   │   └── image_recons_270.png
│   ├── basic_vae
│   │   ├── reconstruction_at_epoch_24.png
│   │   ├── sample_at_epoch_24.png
│   │   └── tsne_embedding.png
│   ├── dcgan
│   │   └── latent_var_grid_sample_c1.png
│   ├── draw
│   │   ├── draw_fig_3.png
│   │   ├── draw_fig_4.png
│   │   ├── elephant.png
│   │   └── generated_32_time_steps.gif
│   ├── infogan
│   │   ├── latent_var_grid_sample_c1.png
│   │   └── latent_var_grid_sample_c2.png
│   ├── ssvae
│   │   ├── analogies_sample.png
│   │   ├── latent_var_grid_sample_c1_y2.png
│   │   └── latent_var_grid_sample_c2_y4.png
│   └── vqvae2
│       ├── 128x128_bits3_eval_reconstruction_step_87300_bottom.png
│       ├── 128x128_bits3_eval_reconstruction_step_87300_original.png
│       ├── 128x128_bits3_eval_reconstruction_step_87300_top.png
│       └── generation_sample_step_52440_top_b128_c128_outstack10_bottom_b16_c128_nres20_condstack10.png
├── infogan.py
├── optim.py
├── readme.md
├── ssvae.py
├── utils.py
├── vqvae.py
└── vqvae_prior.py
/.gitignore:
--------------------------------------------------------------------------------
1 | results/
2 | __pycache__/
3 | data
4 | *.pyc
5 |
--------------------------------------------------------------------------------
/air.py:
--------------------------------------------------------------------------------
1 | """
2 | Attend, Infer, Repeat:
3 | Fast Scene Understanding with Generative Models
4 | https://arxiv.org/pdf/1603.08575v2.pdf
5 | """
6 |
7 | import os
8 | import argparse
9 | import pprint
10 | import time
11 | from tqdm import tqdm
12 |
13 | import numpy as np
14 | from observations import multi_mnist
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | import torch.distributions as D
20 | from torch.utils.data import DataLoader, Dataset
21 | from torchvision.utils import save_image, make_grid
22 | from tensorboardX import SummaryWriter
23 |
24 |
25 | parser = argparse.ArgumentParser()
26 |
27 | # actions
28 | parser.add_argument('--train', action='store_true', help='Train a model.')
29 | parser.add_argument('--evaluate', action='store_true', help='Evaluate a model.')
30 | parser.add_argument('--generate', action='store_true', help='Generate samples from a model.')
31 | parser.add_argument('--restore_file', type=str, help='Path to model to restore.')
32 | parser.add_argument('--data_dir', default='./data/', help='Location of datasets.')
33 | parser.add_argument('--output_dir', default='./results/{}'.format(os.path.splitext(__file__)[0]))
34 | parser.add_argument('--seed', type=int, default=2182019, help='Random seed to use.')
35 | parser.add_argument('--cuda', type=int, default=None, help='Which cuda device to use')
36 | parser.add_argument('--verbose', '-v', action='count', help='Verbose mode; send gradient stats to tensorboard.')
37 | # model params
38 | parser.add_argument('--image_dims', type=tuple, default=(1,50,50), help='Dimensions of a single datapoint (e.g. (1,50,50) for multi MNIST).')
39 | parser.add_argument('--z_what_size', type=int, default=50, help='Size of the z_what latent representation.')
40 | parser.add_argument('--z_where_size', type=int, default=3, help='Size of the z_where latent representation e.g. dim=3 for (s, tx, ty) affine parametrization.')
41 | parser.add_argument('--z_pres_size', type=int, default=1, help='Size of the z_pres latent representation, e.g. dim=1 for the probability of occurrence of an object.')
42 | parser.add_argument('--enc_dec_size', type=int, default=200, help='Size of the encoder and decoder hidden layers.')
43 | parser.add_argument('--lstm_size', type=int, default=256, help='Size of the LSTM hidden layer for AIR.')
44 | parser.add_argument('--baseline_lstm_size', type=int, default=256, help='Size of the LSTM hidden layer for the gradient baseline estimator.')
45 | parser.add_argument('--attn_window_size', type=int, default=28, help='Size of the attention window of the decoder.')
46 | parser.add_argument('--max_steps', type=int, default=3, help='Maximum number of objects per image to sample a binomial from.')
47 | parser.add_argument('--likelihood_sigma', type=float, default=0.3, help='Sigma parameter for the likelihood function (a Normal distribution).')
48 | parser.add_argument('--z_pres_prior_success_prob', type=float, default=0.75, help='Prior probability of success for the num objects per image prior.')
49 | parser.add_argument('--z_pres_anneal_start_step', type=int, default=1000, help='Start step to begin annealing the num objects per image prior.')
50 | parser.add_argument('--z_pres_anneal_end_step', type=int, default=100000, help='End step to stop annealing the num objects per image prior.')
51 | parser.add_argument('--z_pres_anneal_start_value', type=float, default=0.99, help='Initial probability of success for the num objects per image prior.')
52 | parser.add_argument('--z_pres_anneal_end_value', type=float, default=1e-5, help='Final probability of success for the num objects per image prior.')
53 | parser.add_argument('--z_pres_init_encoder_bias', type=float, default=2., help='Add bias to the initialization of the z_pres encoder.')
54 | parser.add_argument('--decoder_bias', type=float, default=-2., help='Add preactivation bias to decoder.')
55 | # training params
56 | parser.add_argument('--batch_size', type=int, default=64)
57 | parser.add_argument('--n_epochs', type=int, default=1, help='Number of epochs to train.')
58 | parser.add_argument('--start_epoch', default=0, help='Starting epoch (for logging; overwritten when restoring a checkpoint).')
59 | parser.add_argument('--model_lr', type=float, default=1e-4, help='Learning rate for AIR.')
60 | parser.add_argument('--baseline_lr', type=float, default=1e-3, help='Learning rate for the gradient baseline estimator.')
61 | parser.add_argument('--log_interval', type=int, default=100, help='Write loss and parameter stats to tensorboard.')
62 | parser.add_argument('--eval_interval', type=int, default=10, help='Number of epochs to eval model and save checkpoint.')
63 | parser.add_argument('--mini_data_size', type=int, default=None, help='Train only on this number of datapoints.')
64 |
65 | # --------------------
66 | # Data
67 | # --------------------
68 |
69 | class MultiMNIST(Dataset):
70 | def __init__(self, root, training=True, download=True, max_digits=2, canvas_size=50, seed=42, mini_data_size=None):
71 | self.root = os.path.expanduser(root)
72 |
73 | # check if multi mnist already compiled
74 | self.multi_mnist_filename = 'multi_mnist_{}_{}_{}'.format(max_digits, canvas_size, seed)
75 |
76 | if not self._check_processed_exists():
77 | if self._check_raw_exists():
78 | # process into pt file
79 | data = np.load(os.path.join(self.root, 'raw', self.multi_mnist_filename + '.npz'))
80 | train_data, train_labels, test_data, test_labels = [data[f] for f in data.files]
81 | self._process_and_save(train_data, train_labels, test_data, test_labels)
82 | else:
83 | if not download:
84 | raise RuntimeError('Dataset not found. Use download=True to download it.')
85 | else:
86 | (train_data, train_labels), (test_data, test_labels) = multi_mnist(root, max_digits, canvas_size, seed)
87 | self._process_and_save(train_data, train_labels, test_data, test_labels)
88 | else:
89 | data = torch.load(os.path.join(self.root, 'processed', self.multi_mnist_filename + '.pt'))
90 | self.train_data, self.train_labels, self.test_data, self.test_labels = \
91 | data['train_data'], data['train_labels'], data['test_data'], data['test_labels']
92 |
93 | if training:
94 | self.x, self.y = self.train_data, self.train_labels
95 | else:
96 | self.x, self.y = self.test_data, self.test_labels
97 |
98 | if mini_data_size is not None:
99 | self.x = self.x[:mini_data_size]
100 | self.y = self.y[:mini_data_size]
101 |
102 | def __getitem__(self, idx):
103 | return self.x[idx].unsqueeze(0), self.y[idx]
104 |
105 | def __len__(self):
106 | return len(self.x)
107 |
108 | def _check_processed_exists(self):
109 | return os.path.exists(os.path.join(self.root, 'processed', self.multi_mnist_filename + '.pt'))
110 |
111 | def _check_raw_exists(self):
112 | return os.path.exists(os.path.join(self.root, 'raw', self.multi_mnist_filename + '.npz'))
113 |
114 | def _make_label_tensor(self, label_arr):
115 | out = torch.zeros(10)
116 | for l in label_arr:
117 | out[l] += 1
118 | return out
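# Illustrative example (editorial note, not part of the original file): for a canvas
# containing digits [3, 3, 7], _make_label_tensor returns a count vector with
# out[3] == 2 and out[7] == 1, i.e. labels encode how many of each digit appear.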
119 |
120 | def _process_and_save(self, train_data, train_labels, test_data, test_labels):
121 | self.train_data = torch.from_numpy(train_data).float() / 255
122 | self.train_labels = torch.stack([self._make_label_tensor(label) for label in train_labels])
123 | self.test_data = torch.from_numpy(test_data).float() / 255
124 | self.test_labels = torch.stack([self._make_label_tensor(label) for label in test_labels])
125 | # check folder exists
126 | if not os.path.exists(os.path.join(self.root, 'processed')):
127 | os.makedirs(os.path.join(self.root, 'processed'))
128 | with open(os.path.join(self.root, 'processed', self.multi_mnist_filename + '.pt'), 'wb') as f:
129 | torch.save({'train_data': self.train_data,
130 | 'train_labels': self.train_labels,
131 | 'test_data': self.test_data,
132 | 'test_labels': self.test_labels},
133 | f)
134 |
135 | def fetch_dataloaders(args):
136 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.device.type == 'cuda' else {}
137 | dataset = MultiMNIST(root=args.data_dir, training=True, mini_data_size=args.mini_data_size)
138 | train_dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
139 | dataset = MultiMNIST(root=args.data_dir, training=False if args.mini_data_size is None else True, mini_data_size=args.mini_data_size)
140 | test_dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=True, **kwargs)
141 | return train_dataloader, test_dataloader
142 |
143 |
144 | # --------------------
145 | # Model helper functions -- spatial transformer
146 | # --------------------
147 |
148 | def stn(image, z_where, out_dims, inverse=False, box_attn_window_color=None):
149 | """ spatial transformer network used to scale and shift input according to z_where in:
150 | 1/ x -> x_att -- shapes (H, W) -> (attn_window, attn_window) -- thus inverse = False
151 | 2/ y_att -> y -- (attn_window, attn_window) -> (H, W) -- thus inverse = True
152 |
153 | inverting the affine transform as follows: A_inv ( A * image ) = image
154 | A = [R | T] where R is rotation component of angle alpha, T is [tx, ty] translation component
155 | A_inv rotates by -alpha and translates by [-tx, -ty]
156 |
157 | if x' = R * x + T --> x = R_inv * (x' - T) = R_inv * x - R_inv * T
158 |
159 | here, z_where is 3-dim [scale, tx, ty] so inverse transform is [1/scale, -tx/scale, -ty/scale]
160 | R = [[s, 0], -> R_inv = [[1/s, 0],
161 | [0, s]] [0, 1/s]]
162 | """
163 |
164 | if box_attn_window_color is not None:
165 | # draw a box around the attention window by overwriting the boundary pixels in the given color channel
166 | with torch.no_grad():
167 | box = torch.zeros_like(image.expand(-1,3,-1,-1))
168 | c = box_attn_window_color % 3 # write the color bbox in channel c, as model time steps
169 | box[:,c,:,0] = 1
170 | box[:,c,:,-1] = 1
171 | box[:,c,0,:] = 1
172 | box[:,c,-1,:] = 1
173 | # add box to image and clamp at 1 where they overlap
174 | image = torch.clamp(image + box, 0, 1)
175 |
176 | # 1. construct 2x3 affine matrix for each datapoint in the minibatch
177 | theta = torch.zeros(2,3).repeat(image.shape[0], 1, 1).to(image.device)
178 | # set scaling
179 | theta[:, 0, 0] = theta[:, 1, 1] = z_where[:,0] if not inverse else 1 / (z_where[:,0] + 1e-9)
180 | # set translation
181 | theta[:, :, -1] = z_where[:, 1:] if not inverse else - z_where[:,1:] / (z_where[:,0].view(-1,1) + 1e-9)
182 | # 2. construct sampling grid
183 | grid = F.affine_grid(theta, torch.Size(out_dims))
184 | # 3. sample image from grid
185 | return F.grid_sample(image, grid)
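# Illustrative check (editorial sketch, not part of the original file): for z_where = [s, tx, ty]
# acting on a point p as p' = s * p + [tx, ty], the inverse parameters [1/s, -tx/s, -ty/s]
# recover p, matching the docstring above:
#   s, tx, ty = 0.5, 0.2, -0.1
#   p = torch.tensor([1.0, 2.0])
#   p_fwd = s * p + torch.tensor([tx, ty])
#   p_back = (1 / s) * p_fwd + torch.tensor([-tx / s, -ty / s])
#   assert torch.allclose(p_back, p)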
186 |
187 |
188 | # --------------------
189 | # Model helper functions -- distribution manipulations
190 | # --------------------
191 |
192 | def compute_geometric_from_bernoulli(obj_probs):
193 | """ compute a normalized truncated geometric distribution from a table of bernoulli probs
194 | args
195 | obj_probs -- tensor of shape (N, max_steps) of Bernoulli success probabilities.
196 | """
197 | cum_succ_probs = obj_probs.cumprod(1)
198 | fail_probs = 1 - obj_probs
199 | geom = torch.cat([fail_probs[:,:1], fail_probs[:,1:] * cum_succ_probs[:,:-1], cum_succ_probs[:,-1:]], dim=1)
200 | return geom / geom.sum(1, True)
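# Worked example (editorial note, not part of the original file): with max_steps = 3 and
# per-step Bernoulli success probs [0.9, 0.5, 0.2], the truncated geometric over object
# counts {0, 1, 2, 3} is
#   [1-0.9, 0.9*(1-0.5), 0.9*0.5*(1-0.2), 0.9*0.5*0.2] = [0.10, 0.45, 0.36, 0.09]
# which already sums to 1; the final normalization is a safeguard.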
201 |
202 | def compute_z_pres_kl(q_z_pres_geom, p_z_pres, writer=None):
203 | """ compute kl divergence between truncated geom prior and tabular geom posterior
204 | args
205 | p_z_pres -- torch.distributions.Geometric object
206 | q_z_pres_geom -- torch tensor of shape (N, max_steps + 1) of a normalized geometric pdf
207 | """
208 | # compute normalized truncated geometric
209 | p_z_pres_log_probs = p_z_pres.log_prob(torch.arange(q_z_pres_geom.shape[1], dtype=torch.float, device=q_z_pres_geom.device))
210 | p_z_pres_normed_log_probs = p_z_pres_log_probs - p_z_pres_log_probs.logsumexp(dim=0)
211 |
212 | kl = q_z_pres_geom * (torch.log(q_z_pres_geom + 1e-8) - p_z_pres_normed_log_probs.expand_as(q_z_pres_geom))
213 | return kl
214 |
215 | def anneal_z_pres_prob(prob, step, args):
216 | if args.z_pres_anneal_start_step < step < args.z_pres_anneal_end_step:
217 | slope = (args.z_pres_anneal_end_value - args.z_pres_anneal_start_value) / (args.z_pres_anneal_end_step - args.z_pres_anneal_start_step)
218 | prob = torch.tensor(args.z_pres_anneal_start_value + slope * (step - args.z_pres_anneal_start_step), device=prob.device)
219 | return prob
220 |
221 |
222 | # --------------------
223 | # Model
224 | # --------------------
225 |
226 | class AIR(nn.Module):
227 | def __init__(self, args):
228 | super().__init__()
229 | self.debug = False
230 | # record dims
231 | self.C, self.H, self.W = args.image_dims
232 | self.A = args.attn_window_size
233 | x_size = self.C * self.H * self.W
234 | self.lstm_size = args.lstm_size
235 | self.baseline_lstm_size = args.baseline_lstm_size
236 | self.z_what_size = args.z_what_size
237 | self.z_where_size = args.z_where_size
238 | self.max_steps = args.max_steps
239 |
240 | # --------------------
241 | # p model -- cf AIR paper section 2
242 | # --------------------
243 |
244 | # latent variable priors
245 | # z_pres ~ Ber(p) Geom(rho) discrete representation for the presence of a scene object
246 | # z_where ~ N(mu, scale); continuous 3-dim variable for pose (position and scale)
247 | # z_what ~ N(0,1); continuous representation for shape
248 | self.register_buffer('z_pres_prior', torch.tensor(args.z_pres_prior_success_prob)) # prior used for generation
249 | self.register_buffer('z_pres_prob', torch.tensor(args.z_pres_anneal_start_value)) # `current value` used for training and annealing
250 | self.register_buffer('z_what_mean', torch.zeros(args.z_what_size))
251 | self.register_buffer('z_what_scale', torch.ones(args.z_what_size))
252 | self.register_buffer('z_where_mean', torch.tensor([0.3, 0., 0.]))
253 | self.register_buffer('z_where_scale', torch.tensor([0.1, 1., 1.]))
254 |
255 | # likelihood = N(mu, sigma)
256 | self.register_buffer('likelihood_sigma', torch.tensor(args.likelihood_sigma))
257 |
258 | # likelihood p(x|n,z) of the data given the latents
259 | self.decoder = nn.Sequential(nn.Linear(args.z_what_size, args.enc_dec_size),
260 | nn.ReLU(True),
261 | nn.Linear(args.enc_dec_size, self.C * self.A ** 2))
262 | self.decoder_bias = args.decoder_bias # otherwise initial samples are heavily penalized by likelihood (cf Pyro implementation)
263 |
264 | # --------------------
265 | # q model for approximating the posterior -- cf AIR paper section 2.1
266 | # --------------------
267 |
268 | # encoder
269 | # rnn encodes the latents z_1:t over the number of steps where z_pres indicates presence of an object
270 | # q_z_pres encodes whether there is an object present in the image; q_z_pres = Bernoulli
271 | # q_z_what encodes the attention window; q_z_what = Normal(mu, sigma)
272 | # q_z_where encodes the affine transform of the image -> attn_window; q_z_where = Normal(0, cov) of dim = 3 for [scale, tx, ty]
273 | self.encoder = nn.ModuleDict({
274 | 'rnn': nn.LSTMCell(x_size + args.z_where_size + args.z_what_size + args.z_pres_size, args.lstm_size),
275 | 'z_pres': nn.Linear(args.lstm_size, 1),
276 | 'z_what': nn.Sequential(nn.Linear(self.A ** 2 , args.enc_dec_size),
277 | nn.ReLU(True),
278 | nn.Linear(args.enc_dec_size, 2 * args.z_what_size)),
279 | 'z_where': nn.Linear(args.lstm_size, 2 * args.z_where_size)})
280 |
281 | nn.init.constant_(self.encoder.z_pres.bias, args.z_pres_init_encoder_bias) # push initial num time steps probs higher
282 |
283 | # initialize STN to identity
284 | self.encoder.z_where.weight.data.zero_()
285 | self.encoder.z_where.bias.data = torch.cat([torch.zeros(args.z_where_size), -1.*torch.ones(args.z_where_size)],dim=0)
286 |
287 | # --------------------
288 | # Baseline model for NVIL per Mnih & Gregor
289 | # --------------------
290 |
291 | self.baseline = nn.ModuleDict({
292 | 'rnn': nn.LSTMCell(x_size + args.z_where_size + args.z_what_size + args.z_pres_size, args.baseline_lstm_size),
293 | 'linear': nn.Linear(args.baseline_lstm_size, 1)})
294 |
295 | @property
296 | def p_z_pres(self):
297 | return D.Geometric(probs=1-self.z_pres_prob)
298 |
299 | @property
300 | def p_z_what(self):
301 | return D.Normal(self.z_what_mean, self.z_what_scale)
302 |
303 | @property
304 | def p_z_where(self):
305 | return D.Normal(self.z_where_mean, self.z_where_scale)
306 |
307 | def forward(self, x, writer=None, box_attn_window_color=None):
308 | """ cf AIR paper Figure 3 (right) for model flow.
309 | Computes (1) inference for z latents;
310 | (2) data reconstruction given the latents;
311 | (3) baseline for decreasing gradient variance;
312 | (4) losses
313 | Returns
314 | recon_x -- tensor of shape (B, C, H, W); reconstruction of data
315 | pred_counts -- tensor of shape (B,); predicted number of objects for each data point
316 | elbo -- tensor of shape (B,); variational lower bound
317 | loss -- tensor of shape (0) of the scalar objective loss
318 | baseline loss -- tensor of shape (0) of the scalar baseline loss (cf Mnih & Gregor NVIL)
319 | """
320 | batch_size = x.shape[0]
321 | device = x.device
322 |
323 | # store for elbo computation
324 | pred_counts = torch.zeros(batch_size, self.max_steps, device=device) # store for object count accuracy
325 | obj_probs = torch.ones(batch_size, self.max_steps, device=device) # store for computing the geometric posterior
326 | baseline = torch.zeros(batch_size, device=device)
327 | kl_z_pres = torch.zeros(batch_size, device=device)
328 | kl_z_what = torch.zeros(batch_size, device=device)
329 | kl_z_where = torch.zeros(batch_size, device=device)
330 |
331 | # initialize canvas, encoder rnn, states of the latent variables, mask for z_pres, baseline rnn
332 | recon_x = torch.zeros(batch_size, 3 if box_attn_window_color is not None else self.C, self.H, self.W, device=device)
333 | h_enc = torch.zeros(batch_size, self.lstm_size, device=device)
334 | c_enc = torch.zeros_like(h_enc)
335 | z_pres = torch.ones(batch_size, 1, device=device)
336 | z_what = torch.zeros(batch_size, self.z_what_size, device=device)
337 | z_where = torch.rand(batch_size, self.z_where_size, device=device)
338 | h_baseline = torch.zeros(batch_size, self.baseline_lstm_size, device=device)
339 | c_baseline = torch.zeros_like(h_baseline)
340 |
341 | # run model forward up to a max number of reconstruction steps
342 | for i in range(self.max_steps):
343 |
344 | # --------------------
345 | # Inference step -- AIR paper fig3 middle.
346 | # 1. compute 1-dimensional Bernoulli variable indicating the entity’s presence
347 | # 2. compute 3-dimensional vector specifying the affine parameters of its position and scale (z_i_where).
348 | # 3. compute C-dimensional distributed vector describing its class or appearance (z_i_what)
349 | # --------------------
350 |
351 | # rnn encoder
352 | h_enc, c_enc = self.encoder.rnn(torch.cat([x, z_pres, z_what, z_where], dim=-1), (h_enc, c_enc))
353 |
354 | # 1. compute 1-dimensional Bernoulli variable indicating the entity’s presence; note: if z_pres == 0, subsequent masks are zeroed
355 | q_z_pres = D.Bernoulli(probs = torch.clamp(z_pres * torch.sigmoid(self.encoder.z_pres(h_enc)), 1e-5, 1 - 1e-5)) # avoid probs that are exactly 0 or 1
356 | z_pres = q_z_pres.sample()
357 |
358 | # 2. compute 3-dimensional vector specifying the affine parameters of its position and scale (z_i_where).
359 | q_z_where_mean, q_z_where_scale = self.encoder.z_where(h_enc).chunk(2, -1)
360 | q_z_where = D.Normal(q_z_where_mean + self.z_where_mean, F.softplus(q_z_where_scale) * self.z_where_scale)
361 | z_where = q_z_where.rsample()
362 |
363 | # attend to a part of the image (using a spatial transformer) to produce x_i_att
364 | x_att = stn(x.view(batch_size, self.C, self.H, self.W), z_where, (batch_size, self.C, self.A, self.A), inverse=False)
365 |
366 | # 3. compute C-dimensional distributed vector describing its class or appearance (z_i_what)
367 | q_z_what_mean, q_z_what_scale = self.encoder.z_what(x_att.flatten(start_dim=1)).chunk(2, -1)
368 | q_z_what = D.Normal(q_z_what_mean, F.softplus(q_z_what_scale))
369 | z_what = q_z_what.rsample()
370 |
371 | # --------------------
372 | # Reconstruction step
373 | # 1. computes y_i_att reconstruction of the attention window x_att
374 | # 2. add to canvas over all timesteps
375 | # --------------------
376 |
377 | # 1. compute reconstruction of the attention window
378 | y_att = torch.sigmoid(self.decoder(z_what).view(-1, self.C, self.A, self.A) + self.decoder_bias)
379 |
380 | # scale and shift y according to z_where
381 | y = stn(y_att, z_where, (batch_size, self.C, self.H, self.W), inverse=True, box_attn_window_color=i if box_attn_window_color is not None else None)
382 |
383 | # 2. add reconstruction to canvas
384 | recon_x += y * z_pres.view(-1,1,1,1)
385 |
386 | # --------------------
387 | # Baseline step -- AIR paper cf's Mnih & Gregor NVIL; specifically sec 2.3 variance reduction
388 | # --------------------
389 |
390 | # compute baseline; independent of the z latents (cf Mnih & Gregor NVIL) so detach from graph
391 | baseline_input = torch.cat([x, z_pres.detach(), z_what.detach(), z_where.detach()], dim=-1)
392 | h_baseline, c_baseline = self.baseline.rnn(baseline_input, (h_baseline, c_baseline))
393 | baseline += self.baseline.linear(h_baseline).squeeze() # note: masking by z_pres gives poorer results
394 |
395 | # --------------------
396 | # Variational lower bound / loss components
397 | # --------------------
398 |
399 | # compute kl(q||p) divergences -- sum over latent dim
400 | kl_z_what += D.kl.kl_divergence(q_z_what, self.p_z_what).sum(1) * z_pres.squeeze()
401 | kl_z_where += D.kl.kl_divergence(q_z_where, self.p_z_where).sum(1) * z_pres.squeeze()
402 |
403 | pred_counts[:,i] = z_pres.flatten()
404 | obj_probs[:,i] = q_z_pres.probs.flatten()
405 |
406 | q_z_pres = compute_geometric_from_bernoulli(obj_probs)
407 | score_fn = q_z_pres[torch.arange(batch_size), pred_counts.sum(1).long()].log() # log prob of num objects under the geometric
408 | kl_z_pres = compute_z_pres_kl(q_z_pres, self.p_z_pres, writer).sum(1) # note: mask by pred_counts makes no difference
409 |
410 | p_x_z = D.Normal(recon_x.flatten(1), self.likelihood_sigma)
411 | log_like = p_x_z.log_prob(x.view(-1, self.C, self.H, self.W).expand_as(recon_x).flatten(1)).sum(-1) # sum image dims (C, H, W)
412 |
413 | # --------------------
414 | # Compute variational bound and loss function
415 | # --------------------
416 |
417 | elbo = log_like - kl_z_pres - kl_z_what - kl_z_where # objective for loss function, but high variance
418 | loss = - torch.sum(elbo + (elbo - baseline).detach() * score_fn) # variance reduction surrogate objective (cf Mnih & Gregor NVIL)
419 | baseline_loss = F.mse_loss(elbo.detach(), baseline)
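# Editorial note (not part of the original file): the surrogate term (elbo - baseline).detach() * score_fn
# yields the REINFORCE / score-function gradient (elbo - b) * d/dtheta log q(n|x) for the discrete z_pres
# choices, with the learned baseline b reducing its variance (Mnih & Gregor, NVIL); the reparametrized
# z_what / z_where latents receive gradients through the elbo term directly.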
420 |
421 | if writer:
422 | writer.add_scalar('log_like', log_like.mean(0).item(), writer.step)
423 | writer.add_scalar('kl_z_pres', kl_z_pres.mean(0).item(), writer.step)
424 | writer.add_scalar('kl_z_what', kl_z_what.mean(0).item(), writer.step)
425 | writer.add_scalar('kl_z_where', kl_z_where.mean(0).item(), writer.step)
426 | writer.add_scalar('elbo', elbo.mean(0).item(), writer.step)
427 | writer.add_scalar('baseline', baseline.mean(0).item(), writer.step)
428 | writer.add_scalar('score_function', score_fn.mean(0).item(), writer.step)
429 | writer.add_scalar('z_pres_prob', self.z_pres_prob.item(), writer.step)
430 |
431 | return recon_x, pred_counts, elbo, loss, baseline_loss
432 |
433 | @torch.no_grad()
434 | def generate(self, n_samples):
435 | """ AIR paper figure 3 left:
436 |
437 | The generative model draws n ∼ Geom(ρ) digits {y_i_att} of size 28 × 28 (two shown), scales and shifts them
438 | according to z_i_where ∼ N (0, Σ) using spatial transformers, and sums the results {y_i} to form a 50 × 50 image.
439 | Each digit is obtained by first sampling a latent code z_i_what from the prior z_i_what ∼ N (0, 1) and
440 | propagating it through the decoder network of a variational autoencoder.
441 | The learnable parameters θ of the generative model are the parameters of this decoder network.
442 | """
443 | # sample z_pres ~ Geom(rho) -- this is the number of digits present in an image
444 | z_pres = D.Geometric(1 - self.z_pres_prior).sample((n_samples,)).clamp_(0, self.max_steps)
445 |
446 | # compute a mask on z_pres as e.g.:
447 | # z_pres = [1,4,2,0]
448 | # mask = [[1,0,0,0,0],
449 | # [1,1,1,1,0],
450 | # [1,1,0,0,0],
451 | # [0,0,0,0,0]]
452 | # thus network outputs more objects (sample z_what, z_where and decode) where z_pres is 1
453 | # and outputs nothing when z_pres is 0
454 | z_pres_mask = torch.arange(self.max_steps).float().to(z_pres.device).expand(n_samples, self.max_steps) < z_pres.view(-1,1)
455 | z_pres_mask = z_pres_mask.float().to(z_pres.device)
456 |
457 | # initialize image canvas
458 | x = torch.zeros(n_samples, self.C, self.H, self.W).to(z_pres.device)
459 |
460 | # generate digits
461 | for i in range(int(z_pres.max().item())): # up until the number of objects sampled via z_pres
462 | # sample priors
463 | z_what = self.p_z_what.sample((n_samples,))
464 | z_where = self.p_z_where.sample((n_samples,))
465 |
466 | # propagate through the decoder, scale and shift y_att according to z_where using spatial transformers
467 | y_att = torch.sigmoid(self.decoder(z_what).view(n_samples, self.C, self.A, self.A) + self.decoder_bias)
468 | y = stn(y_att, z_where, (n_samples, self.C, self.H, self.W), inverse=True, box_attn_window_color=i)
469 |
470 | # apply mask and sum results towards final image
471 | x = x + y * z_pres_mask[:,i].view(-1,1,1,1)
472 | return x
473 |
474 |
475 | # --------------------
476 | # Train and evaluate
477 | # --------------------
478 |
479 | def train_epoch(model, dataloader, model_optimizer, baseline_optimizer, anneal_z_pres_prob, epoch, writer, args):
480 | model.train()
481 |
482 | with tqdm(total=len(dataloader), desc='epoch {} / {}'.format(epoch+1, args.start_epoch + args.n_epochs)) as pbar:
483 |
484 | for i, (x, y) in enumerate(dataloader):
485 | writer.step += 1 # update global step
486 |
487 | x = x.view(x.shape[0], -1).to(args.device)
488 |
489 | # run through model and compute loss
490 | recon_x, pred_counts, elbo, loss, baseline_loss = model(x, writer if i % args.log_interval == 0 else None) # pass writer at logging intervals
491 |
492 | # anneal z_pres prior
493 | model.z_pres_prob = anneal_z_pres_prob(model.z_pres_prob, writer.step, args)
494 |
495 | model_optimizer.zero_grad()
496 | loss.backward()
497 | model_optimizer.step()
498 |
499 | baseline_optimizer.zero_grad()
500 | baseline_loss.backward()
501 | baseline_optimizer.step()
502 |
503 | # update tracking
504 | count_accuracy = torch.eq(pred_counts.sum(1).cpu(), y.sum(1)).float().mean()
505 | pbar.set_postfix(elbo='{:.3f}'.format(elbo.mean(0).item()), \
506 | loss='{:.3f}'.format(loss.item()), \
507 | count_acc='{:.2f}'.format(count_accuracy.item()))
508 | pbar.update()
509 |
510 | if i % args.log_interval == 0:
511 | writer.add_scalar('loss', loss.item(), writer.step)
512 | writer.add_scalar('baseline_loss', baseline_loss.item(), writer.step)
513 | writer.add_scalar('count_accuracy_train', count_accuracy.item(), writer.step)
514 |
515 | if args.verbose == 1:
516 | print('z_pres prior:', model.p_z_pres.log_prob(torch.arange(args.max_steps + 1.).to(args.device)).exp(), \
517 | 'post:', compute_geometric_from_bernoulli(pred_counts.mean(0).unsqueeze(0)).squeeze(), \
518 | 'ber success:', pred_counts.mean(0))
519 |
520 | @torch.no_grad()
521 | def evaluate(model, dataloader, args, n_samples=10):
522 | model.eval()
523 |
524 | # initialize trackers
525 | elbo = 0
526 | pred_counts = []
527 | true_counts = []
528 |
529 | # evaluate elbo
530 | for x, y in tqdm(dataloader):
531 | x = x.view(x.shape[0], -1).to(args.device)
532 | _, pred_counts_i, elbo_i, _, _ = model(x)
533 | elbo += elbo_i.sum(0).item()
534 | pred_counts += [pred_counts_i.cpu()]
535 | true_counts += [y]
536 | elbo /= (len(dataloader) * args.batch_size)
537 |
538 | # evaluate count accuracy; test dataset is not shuffled so predicted and true counts stay aligned
539 | pred_counts = torch.cat(pred_counts, dim=0)
540 | true_counts = torch.cat(true_counts, dim=0)
541 | count_accuracy = torch.eq(pred_counts.sum(1), true_counts.sum(1)).float().mean()
542 |
543 | # visualize reconstruction
544 | x = x[-n_samples:] # take last n_sample data points
545 | recon_x, _, _, _, _ = model(x, box_attn_window_color=True)
546 | image_recons = torch.cat([x.view(-1, *args.image_dims).expand_as(recon_x), recon_x], dim=0)
547 | image_recons = make_grid(image_recons.cpu(), nrow=n_samples, pad_value=1)
548 |
549 | return elbo, count_accuracy, image_recons
550 |
551 | def train_and_evaluate(model, train_dataloader, test_dataloader, model_optimizer, baseline_optimizer, anneal_z_pres_prob, writer, args):
552 |
553 | for epoch in range(args.start_epoch, args.start_epoch + args.n_epochs):
554 | # train
555 | train_epoch(model, train_dataloader, model_optimizer, baseline_optimizer, anneal_z_pres_prob, epoch, writer, args)
556 |
557 | # evaluate
558 | if epoch % args.eval_interval == 0:
559 | test_elbo, count_accuracy, image_recons = evaluate(model, test_dataloader, args)
560 | print('Evaluation at epoch {}: test elbo {:.3f}; count accuracy {:.3f}'.format(epoch, test_elbo, count_accuracy))
561 | writer.add_scalar('test_elbo', test_elbo, epoch)
562 | writer.add_scalar('count_accuracy_test', count_accuracy, epoch)
563 | writer.add_image('image_reconstruction', image_recons, epoch)
564 | save_image(image_recons, os.path.join(args.output_dir, 'image_recons_{}.png'.format(epoch)))
565 |
566 | # generate samples
567 | samples = model.generate(n_samples=10)
568 | images = make_grid(samples, nrow=samples.shape[0], pad_value=1)
569 | save_image(images, os.path.join(args.output_dir, 'generated_sample_{}.png'.format(epoch)))
570 | writer.add_image('training_sample', images, epoch)
571 |
572 | # save training checkpoint
573 | torch.save({'epoch': epoch,
574 | 'global_step': writer.step,
575 | 'state_dict': model.state_dict()},
576 | os.path.join(args.output_dir, 'checkpoint.pt'))
577 |
578 |
579 | # --------------------
580 | # Main
581 | # --------------------
582 |
583 | if __name__ == '__main__':
584 | args = parser.parse_args()
585 |
586 | # setup writer and output folders
587 | writer = SummaryWriter(log_dir = os.path.join(args.output_dir, time.strftime('%Y-%m-%d_%H-%M-%S', time.gmtime())) \
588 | if not args.restore_file else os.path.dirname(args.restore_file))
589 | writer.step = 0
590 | args.output_dir = writer.file_writer.get_logdir() # update output_dir with the writer unique directory
591 |
592 | # setup device
593 | args.device = torch.device('cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda is not None else 'cpu')
594 | torch.manual_seed(args.seed)
595 | if args.device.type == 'cuda': torch.cuda.manual_seed(args.seed)
596 |
597 | # load data
598 | train_dataloader, test_dataloader = fetch_dataloaders(args)
599 |
600 | # load model
601 | model = AIR(args).to(args.device)
602 |
603 | # load optimizers
604 | model_optimizer = torch.optim.RMSprop(model.parameters(), lr=args.model_lr, momentum=0.9)
605 | baseline_optimizer = torch.optim.RMSprop(model.baseline.parameters(), lr=args.baseline_lr, momentum=0.9)
606 |
607 | if args.restore_file:
608 | checkpoint = torch.load(args.restore_file, map_location=args.device)
609 | model.load_state_dict(checkpoint['state_dict'])
610 | writer.step = checkpoint['global_step']
611 | args.start_epoch = checkpoint['epoch'] + 1
612 | # set up paths
613 | args.output_dir = os.path.dirname(args.restore_file)
614 |
615 | # save settings
616 | with open(os.path.join(args.output_dir, 'config.txt'), 'a') as f:
617 | print('Parsed args:\n', pprint.pformat(args.__dict__), file=f)
618 | print('\nModel:\n', model, file=f)
619 |
620 | if args.train:
621 | train_and_evaluate(model, train_dataloader, test_dataloader, model_optimizer, baseline_optimizer, anneal_z_pres_prob, writer, args)
622 |
623 | if args.evaluate:
624 | test_elbo, count_accuracy, image_recons = evaluate(model, test_dataloader, args)
625 | print('Evaluation: test elbo {:.3f}; count accuracy {:.3f}'.format(test_elbo, count_accuracy))
626 | save_image(image_recons, os.path.join(args.output_dir, 'image_recons.png'))
627 |
628 | if args.generate:
629 | samples = model.generate(n_samples=7)
630 | images = make_grid(samples, pad_value=1)
631 | save_image(images, os.path.join(args.output_dir, 'generated_sample.png'))
632 | writer.add_image('generated_sample', images)
633 |
634 | writer.close()
635 |
--------------------------------------------------------------------------------
/basic_vae.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of Auto-encoding Variational Bayes
3 | https://arxiv.org/pdf/1312.6114.pdf
4 |
5 | Reference implementation in pytorch examples https://github.com/pytorch/examples/blob/master/vae/
6 | Toy example per Adversarial Variational Bayes https://arxiv.org/abs/1701.04722
7 |
8 | """
9 |
10 | import os
11 | import argparse
12 | from tqdm import tqdm
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import torch.distributions as D
18 | import torchvision.transforms as T
19 | from torch.utils.data import DataLoader, Dataset
20 | from torchvision.datasets import MNIST
21 | from torchvision.utils import make_grid, save_image
22 |
23 | import matplotlib
24 | matplotlib.use('TkAgg')
25 | import matplotlib.pyplot as plt
26 | from sklearn.manifold import TSNE
27 |
28 |
29 |
30 | parser = argparse.ArgumentParser()
31 | subparsers = parser.add_subparsers(help='Dataset specific configs for input and latent dimensions.', dest='dataset')
32 |
33 | # training params
34 | parser.add_argument('--batch_size', type=int, default=128)
35 | parser.add_argument('--n_epochs', type=int, default=10)
36 | parser.add_argument('--seed', type=int, default=11272018)
37 | parser.add_argument('--save_model', action='store_true')
38 | parser.add_argument('--quiet', action='store_true')
39 |
40 | parser.add_argument('--data_dir', default='./data')
41 | parser.add_argument('--output_dir', default='./results/{}'.format(os.path.splitext(__file__)[0]))
42 |
43 | # model parameters
44 | toy_subparser = subparsers.add_parser('toy')
45 | toy_subparser.add_argument('--x_dim', type=int, default=4, help='Dimension of the input data.')
46 | toy_subparser.add_argument('--z_dim', type=int, default=2, help='Size of the latent space.')
47 | toy_subparser.add_argument('--hidden_dim', type=int, default=400, help='Size of the hidden layer.')
48 |
49 | mnist_subparser = subparsers.add_parser('mnist')
50 | mnist_subparser.add_argument('--x_dim', type=int, default=28*28, help='Dimension of the input data.')
51 | mnist_subparser.add_argument('--z_dim', type=int, default=100, help='Size of the latent space.')
52 | mnist_subparser.add_argument('--hidden_dim', type=int, default=400, help='Size of the hidden layer.')
53 |
54 |
55 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
56 |
57 |
58 | # --------------------
59 | # Data
60 | # --------------------
61 |
62 | def fetch_dataloader(args, train=True, download=False):
63 |
64 | transforms = T.Compose([T.ToTensor()])
65 | dataset = MNIST(root=args.data_dir, train=train, download=download, transform=transforms)
66 |
67 | kwargs = {'num_workers': 1, 'pin_memory': True} if device.type == 'cuda' else {}
68 |
69 | return DataLoader(dataset, batch_size=args.batch_size, shuffle=train, drop_last=True, **kwargs)
70 |
71 | class ToyDataset(Dataset):
72 | def __init__(self, args):
73 | super().__init__()
74 | self.x_dim = args.x_dim
75 | self.batch_size = args.batch_size
76 |
77 | def __len__(self):
78 | return self.batch_size * 1000
79 |
80 | def __getitem__(self, i):
81 | one_hot = torch.zeros(self.x_dim)
82 | label = torch.randint(0, self.x_dim, (1, )).long()
83 | one_hot[label] = 1.
84 | return one_hot, label
85 |
86 | def fetch_toy_dataloader(args):
87 | return DataLoader(ToyDataset(args), batch_size=args.batch_size, shuffle=True)
88 |
89 |
90 | # --------------------
91 | # Plotting helpers
92 | # --------------------
93 |
94 | def plot_tsne(model, test_loader, args):
95 | data = test_loader.dataset.test_data.float() / 255.
96 | data = data.view(data.shape[0], -1).to(device)
97 | labels = test_loader.dataset.test_labels
98 | classes = torch.unique(labels, sorted=True).numpy()
99 |
100 | p_x_z, q_z_x = model(data)
101 |
102 | tsne = TSNE(n_components=2, random_state=0)
103 | z_embed = tsne.fit_transform(q_z_x.loc.cpu().numpy()) # map the posterior mean
104 |
105 | fig = plt.figure()
106 | for i in classes:
107 | mask = labels.cpu().numpy() == i
108 | plt.scatter(z_embed[mask, 0], z_embed[mask, 1], s=10, label=str(i))
109 |
110 | plt.title('Latent variable T-SNE embedding per class')
111 | plt.legend()
112 | plt.gca().axis('off')
113 | fig.savefig(os.path.join(args.output_dir, 'tsne_embedding.png'))
114 |
115 |
116 | def plot_scatter(model, args):
117 | data = torch.eye(args.x_dim).repeat(args.batch_size, 1)
118 | labels = data @ torch.arange(args.x_dim).float()
119 |
120 | _, q_z_x = model(data)
121 | z = q_z_x.sample().numpy()
122 | plt.scatter(z[:,0], z[:,1], c=labels.data.numpy(), alpha=0.5)
123 |
124 | plt.title('Latent space embedding per class\n(n_iter = {})'.format(len(ToyDataset(args))*args.n_epochs))
125 | plt.savefig(os.path.join(args.output_dir, 'latent_distribution_toy_example.png'))
126 | plt.close()
127 |
128 | # --------------------
129 | # Model
130 | # --------------------
131 |
132 | class VAE(nn.Module):
133 | def __init__(self, args):  # previously: in_dim=784, hidden_dim=400, z_dim=20
134 | super().__init__()
135 | self.fc1 = nn.Linear(args.x_dim, args.hidden_dim)
136 | self.fc21 = nn.Linear(args.hidden_dim, args.z_dim)
137 | self.fc22 = nn.Linear(args.hidden_dim, args.z_dim)
138 | self.fc3 = nn.Linear(args.z_dim, args.hidden_dim)
139 | self.fc4 = nn.Linear(args.hidden_dim, args.x_dim)
140 |
141 | # q(z|x) parametrizes the approximate posterior as a Normal(mu, scale)
142 | def encode(self, x):
143 | h1 = F.relu(self.fc1(x))
144 | mu = self.fc21(h1)
145 | scale = self.fc22(h1).exp()
146 | return D.Normal(mu, scale)
147 |
148 | # p(x|z) returns the likelihood of data given the latents
149 | def decode(self, z):
150 | h3 = F.relu(self.fc3(z))
151 | logits = self.fc4(h3)
152 | return D.Bernoulli(logits=logits)
153 |
154 | def forward(self, x):
155 | q_z_x = self.encode(x.view(x.shape[0], -1)) # returns Normal
156 | p_x_z = self.decode(q_z_x.rsample()) # returns Bernoulli; note reparametrization when sampling the approximate
157 | return p_x_z, q_z_x
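# Editorial note (not part of the original file): rsample() above draws z = mu + sigma * eps with
# eps ~ N(0, 1), i.e. the reparametrization trick, so gradients flow back into the encoder parameters.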
158 |
159 |
160 | # ELBO loss
161 | def loss_fn(p_x_z, q_z_x, x):
162 | # Equation 3 from Kingma & Welling -- Auto-Encoding Variational Bayes
163 | # ELBO = - KL( q(z|x), p(z) ) + Expectation_under_q(z|x)_[log p(x|z)]
164 | # this simplifies to eq 7 from Kingma and Welling where the expectation is an average over z samples
165 | # signs are reversed from the paper as the paper maximizes the ELBO and here we minimize -ELBO
166 | # both KLD and BCE are summed over dim 1 (image H*W) and mean over dim 0 (batch)
167 | p_z = D.Normal(torch.zeros(1, device=x.device), torch.ones(1, device=x.device))
168 | KLD = D.kl.kl_divergence(q_z_x, p_z).sum(1).mean(0) # divergence of the approximate posterior from the prior
169 | BCE = - p_x_z.log_prob(x.view(x.shape[0], -1)).sum(1).mean(0) # expected negative reconstruction error;
170 | # prob density of data x under the generative model given by z
171 | return BCE + KLD
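# Editorial note (not part of the original file): since q(z|x) = N(mu, sigma^2) and p(z) = N(0, 1),
# kl_divergence above evaluates the closed form -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2)
# (eq. 10 in Kingma & Welling) rather than a Monte Carlo estimate.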
172 |
173 |
174 | # --------------------
175 | # Train and eval
176 | # --------------------
177 |
178 | def train_epoch(model, dataloader, loss_fn, optimizer, epoch, args):
179 | model.train()
180 |
181 | ELBO_loss = 0
182 |
183 | with tqdm(total=len(dataloader), desc='epoch {} of {}'.format(epoch+1, args.n_epochs)) as pbar:
184 | for i, (data, _) in enumerate(dataloader):
185 | data = data.to(device)
186 |
187 | p_x_z, q_z_x = model(data)
188 | loss = loss_fn(p_x_z, q_z_x, data)
189 |
190 | optimizer.zero_grad()
191 | loss.backward()
192 | optimizer.step()
193 |
194 | # update tracking
195 | pbar.set_postfix(loss='{:.3f}'.format(loss.item()))
196 | pbar.update()
197 |
198 | ELBO_loss += loss.item()
199 |
200 | print('Epoch: {} Average ELBO loss: {:.4f}'.format(epoch+1, ELBO_loss / (len(dataloader))))
201 |
202 |
203 | @torch.no_grad()
204 | def evaluate(model, dataloader, loss_fn, epoch, args):
205 | model.eval()
206 |
207 | ELBO_loss = 0
208 |
209 | with tqdm(total=len(dataloader)) as pbar:
210 | for i, (data, _) in enumerate(dataloader):
211 | data = data.to(device)
212 | p_x_z, q_z_x = model(data)
213 |
214 | ELBO_loss += loss_fn(p_x_z, q_z_x, data).item()
215 |
216 | pbar.update()
217 |
218 | if i == 0 and args.dataset == 'mnist':
219 | nrow = 10
220 | n = min(data.size(0), nrow**2)
221 | real_data = make_grid(data[:n].cpu(), nrow)
222 | spacer = torch.ones(real_data.shape[0], real_data.shape[1], 5)
223 | generated_data = make_grid(p_x_z.probs.view(args.batch_size, 1, 28, 28)[:n].cpu(), nrow)
224 | image = torch.cat([real_data, spacer, generated_data], dim=-1)
225 | save_image(image, os.path.join(args.output_dir, 'reconstruction_at_epoch_' + str(epoch) + '.png'), nrow)
226 |
227 | print('Test set average ELBO loss: {:.4f}'.format(ELBO_loss / len(dataloader)))
228 |
229 |
230 | def train_and_evaluate(model, train_loader, test_loader, loss_fn, optimizer, args):
231 | for epoch in range(args.n_epochs):
232 | train_epoch(model, train_loader, loss_fn, optimizer, epoch, args)
233 | evaluate(model, test_loader, loss_fn, epoch, args)
234 |
235 | # save weights
236 | if args.save_model:
237 | torch.save(model.state_dict(), os.path.join(args.output_dir, 'vae_model_xdim{}_hdim{}_zdim{}.pt'.format(
238 | args.x_dim, args.hidden_dim, args.z_dim)))
239 |
240 | # show samples
241 | if args.dataset == 'mnist':
242 | with torch.no_grad():
243 | # sample p(z) = Normal(0, 1)
244 | prior_sample = torch.randn(64, args.z_dim).to(device)
245 | # compute likelihood p(x|z) decoder; returns torch.distribution.Bernoulli
246 | likelihood = model.decode(prior_sample).probs
247 | save_image(likelihood.cpu().view(64, 1, 28, 28), os.path.join(args.output_dir, 'sample_at_epoch_' + str(epoch) + '.png'))
248 |
249 |
250 |
251 | if __name__ == '__main__':
252 | args = parser.parse_args()
253 | if not os.path.isdir(os.path.join(args.output_dir, args.dataset)):
254 | os.makedirs(os.path.join(args.output_dir, args.dataset))
255 | args.output_dir = os.path.join(args.output_dir, args.dataset)
256 |
257 | torch.manual_seed(args.seed)
258 |
259 | # data
260 | if args.dataset == 'toy':
261 | train_loader = fetch_toy_dataloader(args)
262 | test_loader = train_loader
263 | else:
264 | train_loader = fetch_dataloader(args, train=True)
265 | test_loader = fetch_dataloader(args, train=False)
266 |
267 | # model
268 | model = VAE(args).to(device)
269 |
270 | # optimizer
271 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
272 |
273 | # train and eval
274 | train_and_evaluate(model, train_loader, test_loader, loss_fn, optimizer, args)
275 |
276 | # visualize z space
277 | with torch.no_grad():
278 | if args.dataset == 'toy':
279 | plot_scatter(model, args)
280 | else:
281 | pass
282 | plot_tsne(model, test_loader, args)
283 |
284 |
--------------------------------------------------------------------------------
/datasets/chexpert.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import glob
5 | from multiprocessing import Pool
6 | from functools import partial
7 |
8 | import pandas as pd
9 | from tqdm import tqdm
10 | from PIL import Image
11 |
12 | import torch
13 | from torch.utils.data import Dataset
14 | import torchvision.transforms as T
15 |
16 |
17 | class ChexpertDataset(Dataset):
18 | attr_all_names = ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly',
19 | 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
20 | 'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other',
21 | 'Fracture', 'Support Devices']
22 | # subset of labels to use
23 | attr_names = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']
24 |
25 | def __init__(self, root, train=True, transform=None):
26 | self.root = os.path.expanduser(root)
27 | self.transform = transform
28 | self.input_dims = (1, *(int(i) for i in self.root.strip('/').rpartition('/')[2].split('_')[-1].split('x')))
29 |
30 | if train:
31 | self.data = self._load_and_preprocess_training_data(os.path.join(self.root, 'train.csv'))
32 | else:
33 | self.data = pd.read_csv(os.path.join(self.root, 'valid.csv'), keep_default_na=True)
34 |
35 | # store index of the selected attributes in the columns of the data for faster indexing
36 | self.attr_idxs = [self.data.columns.tolist().index(a) for a in self.attr_names]
37 |
38 | def __getitem__(self, idx):
39 | # 1. select and load image
40 | img_path = self.data.iloc[idx, 1] # `Path` is the first column after index
41 | img = Image.open(os.path.join(self.root, img_path.partition('/')[2]))
42 | if self.transform is not None:
43 | img = self.transform(img)
44 |
45 | # 2. select attributes as targets
46 | attr = self.data.iloc[idx, self.attr_idxs].values.astype(float)
47 | attr = torch.from_numpy(attr).float()
48 |
49 | return img, attr
50 |
51 | def __len__(self):
52 | return len(self.data)
53 |
54 | def _load_and_preprocess_training_data(self, csv_path):
55 | # Dataset labels are: blank for unmentioned, 0 for negative, -1 for uncertain, and 1 for positive.
56 | # Process by:
57 | # 1. fill NAs (blanks for unmentioned) as 0 (negatives)
58 | # 2. fill -1 as 1 (U-Ones method described in paper)
59 |
60 | # load
61 | train_df = pd.read_csv(csv_path, keep_default_na=True)
62 |
63 | # 1. fill NAs (blanks for unmentioned) as 0 (negatives)
64 | # attr columns ['No Finding', ..., 'Support Devices']; note AP/PA remains with NAs for Lateral pictures
65 | train_df[self.attr_names] = train_df[self.attr_names].fillna(0)
66 |
67 | # 2. fill -1 as 1 (U-Ones method described in paper)
68 | train_df[self.attr_names] = train_df[self.attr_names].replace(-1,1)
69 |
70 | return train_df
71 |
72 |
73 | def compute_dataset_mean_and_std(dataset):
74 | m = 0
75 | s = 0
76 | k = 1
77 | for img, _ in tqdm(dataset):
78 | x = img.mean().item()
79 | new_m = m + (x - m)/k
80 | s += (x - m)*(x - new_m)
81 | m = new_m
82 | k += 1
83 | print('Number of datapoints: ', k)
84 | return m, math.sqrt(s/(k-1))
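# Illustrative usage (editorial sketch, not part of the original file); the root path below is
# hypothetical and should point at a resized CheXpert folder:
#   ds = ChexpertDataset(root='~/data/CheXpert_64x64', train=True,
#                        transform=T.Compose([T.ToTensor()]))
#   mean, std = compute_dataset_mean_and_std(ds)
#   print('mean {:.4f}; std {:.4f}'.format(mean, std))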
85 |
86 | # --------------------
87 | # Resize dataset
88 | # --------------------
89 |
90 | def _process_entry(img_path, root, source_dir, target_dir, transforms):
91 | # img_path is e.g. `CheXpert-v1.0-small/valid/patient64541/study1/view1_frontal.jpg`
92 | subpath = img_path.partition('/')[-1]
93 |
94 | # make new img folders/subfolders
95 | os.makedirs(os.path.dirname(os.path.join(root, target_dir, subpath)), exist_ok=True)
96 |
97 | # save resized image
98 | img = Image.open(os.path.join(root, source_dir, subpath))
99 | img = transforms(img)
100 | img.save(os.path.join(root, target_dir, subpath), quality=97)
101 | img.close()
102 |
103 | def make_resized_dataset(root, source_dir, size, n_workers):
104 | root = os.path.expanduser(root)
105 | target_dir = 'CheXpert_{}x{}'.format(size, size)
106 |
107 | assert (not os.path.exists(os.path.join(root, target_dir, 'train.csv')) and \
108 | not os.path.exists(os.path.join(root, target_dir, 'valid.csv'))), 'Data exists at target dir.'
109 | print('Resizing dataset at root {}:\n\tsource {}\n\ttarget {}\n\tNew size: {}x{}'.format(root, source_dir, target_dir, size, size))
110 |
111 | transforms = T.Compose([T.Resize(size, Image.BICUBIC), T.CenterCrop(size)])
112 |
113 |
114 | for split in ['train', 'valid']:
115 | print('Processing ', split, ' split...')
116 | csv_path = os.path.join(root, source_dir, split + '.csv')
117 |
118 | # load data and preprocess NAs
119 | df = pd.read_csv(csv_path, keep_default_na=True)
120 |
121 | if split == 'train':
122 | # 1. fill NAs (blanks for unmentioned) as 0 (negatives)
123 | # attr columns ['No Finding', ..., 'Support Devices']; note AP/PA remains with NAs for Lateral pictures
124 | df[ChexpertDataset.attr_names] = df[ChexpertDataset.attr_names].fillna(0)
125 |
126 | # 2. fill -1 as 1 (U-Ones method described in paper)
127 | df[ChexpertDataset.attr_names] = df[ChexpertDataset.attr_names].replace(-1,1)
128 |
129 | # make new folders, resize image and store
130 | f = partial(_process_entry, root=root, source_dir=source_dir, target_dir=target_dir, transforms=transforms)
131 | with Pool(n_workers) as p:
132 | p.map(f, df['Path'].tolist())
133 |
134 | # replace `CheXpert-v1.0-small` root with new root defined above
135 | df['Path'] = df['Path'].str.replace(source_dir, target_dir)
136 |
137 | # save new df
138 | df.to_csv(os.path.join(root, target_dir, split + '.csv'))
139 |
140 |
141 | # resize entire dataset
142 | if False:
143 | root = '/mnt/disks/chexpert-ssd'
144 | source_dir = 'CheXpert-v1.0-small'
145 | new_size = 64
146 | n_workers = 16
147 |
148 | make_resized_dataset(root, source_dir, new_size, n_workers)
149 |
150 | # compute dataset mean and std
151 | if False:
152 | ds = ChexpertDataset(root=args.data_dir, train=True, transform=T.Compose([T.CenterCrop(320), T.ToTensor()]))
153 | m, s = compute_dataset_mean_and_std(ds)
154 | print('Dataset mean: {}; dataset std {}'.format(m, s))
155 | # Dataset mean: 0.533048452958796; dataset std 0.03490651403764978
156 |
157 | # output a few images from the validation set and display labels
158 | if False:
159 | import torchvision.transforms as T
160 | from torchvision.utils import save_image
161 | ds = ChexpertDataset(root=args.data_dir, train=False,
162 | transform=T.Compose([T.ToTensor(), T.Normalize(mean=[0.5330], std=[0.0349])]))
163 | print('Valid dataset loaded. Length: ', len(ds))
164 | for i in range(10):
165 | img, attr = ds[i]
166 | save_image(img, 'test_valid_dataset_image_{}.png'.format(i), normalize=True, scale_each=True)
167 | print('Path: {}; labels: {}'.format(ds.data.iloc[i, 1], attr))
168 |
169 |
--------------------------------------------------------------------------------
/dcgan.py:
--------------------------------------------------------------------------------
1 | """
2 | Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks
3 | https://arxiv.org/pdf/1511.06434
4 |
5 | Notes:
6 | Model architecture differs from paper:
7 | generator ends with Sigmoid
8 | inputs normalized to [0,1]
9 | learning rates differ
10 |
11 | """
12 |
13 |
14 | import os
15 | import argparse
16 | from tqdm import tqdm
17 | import time
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | import torch.distributions as dist
23 | from torch.utils.data import DataLoader
24 | from torchvision.datasets import MNIST
25 | import torchvision.transforms as T
26 | from torchvision.utils import save_image, make_grid
27 |
28 | import utils
29 |
30 | parser = argparse.ArgumentParser()
31 |
32 | # training params
33 | parser.add_argument('--batch_size', type=int, default=128)
34 | parser.add_argument('--n_epochs', type=int, default=1)
35 | parser.add_argument('--noise_dim', type=int, default=96, help='Size of the latent representation.')
36 | parser.add_argument('--g_lr', type=float, default=1e-3, help='Generator learning rate')
37 | parser.add_argument('--d_lr', type=float, default=1e-4, help='Discriminator learning rate')
38 | parser.add_argument('--log_interval', default=100)
39 | parser.add_argument('--cuda', type=int, help='Which cuda device to use')
40 | parser.add_argument('--mini_data', action='store_true')
41 | # eval params
42 | parser.add_argument('--evaluate_on_grid', action='store_true')
43 | # data paths
44 | parser.add_argument('--save_model', action='store_true')
45 | parser.add_argument('--data_dir', default='./data')
46 | parser.add_argument('--output_dir', default='./results/dcgan')
47 | parser.add_argument('--restore_file', help='Path to .pt checkpoint file for Discriminator and Generator')
48 |
49 |
50 |
51 |
52 | # --------------------
53 | # Data
54 | # --------------------
55 |
56 | def fetch_dataloader(args, train=True, download=True, mini_size=128):
57 | # load dataset and init in the dataloader
58 |
59 | transforms = T.Compose([T.ToTensor()])
60 | dataset = MNIST(root=args.data_dir, train=train, download=download, transform=transforms)
61 |
62 | # load dataset and init in the dataloader
63 | if args.mini_data:
64 | if train:
65 | dataset.train_data = dataset.train_data[:mini_size]
66 | dataset.train_labels = dataset.train_labels[:mini_size]
67 | else:
68 | dataset.test_data = dataset.test_data[:mini_size]
69 | dataset.test_labels = dataset.test_labels[:mini_size]
70 |
71 |
72 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.device.type == 'cuda' else {}
73 |
74 | dl = DataLoader(dataset, batch_size=args.batch_size, shuffle=train, drop_last=True, **kwargs)
75 |
76 | return dl
77 |
78 |
79 | # --------------------
80 | # Model
81 | # --------------------
82 |
83 | class Flatten(nn.Module):
84 | def forward(self, x):
85 | return x.view(x.shape[0], -1)
86 |
87 | class Unflatten(nn.Module):
88 | def __init__(self, B, C, H, W):
89 | super().__init__()
90 | self.B = B
91 | self.C = C
92 | self.H = H
93 | self.W = W
94 |
95 | def forward(self, x):
96 | return x.reshape(self.B, self.C, self.H, self.W)
97 |
98 | class Discriminator(nn.Module):
99 | def __init__(self):
100 | super().__init__()
101 | self.net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1), # out (B, 64, 14, 14)
102 | nn.LeakyReLU(0.2, True),
103 | nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False), # out (B, 128, 7, 7)
104 | nn.BatchNorm2d(128),
105 | nn.LeakyReLU(0.2, True),
106 | nn.Conv2d(128, 256, kernel_size=4, stride=1, padding=0, bias=False), # out (B, 256, 4, 4)
107 | nn.BatchNorm2d(256),
108 | nn.LeakyReLU(0.2, True),
109 | nn.Conv2d(256, 512, kernel_size=4, bias=False), # out (B, 512, 1, 1)
110 | nn.BatchNorm2d(512),
111 | nn.LeakyReLU(0.2, True),
112 | nn.Conv2d(512, 1, kernel_size=1, bias=False))
113 |
114 | def forward(self, x):
115 | return dist.Bernoulli(logits=self.net(x).squeeze())
116 |
117 |
118 | class Generator(nn.Module):
119 | def __init__(self, noise_dim):
120 | super().__init__()
121 | self.net = nn.Sequential(nn.ConvTranspose2d(noise_dim, 512, kernel_size=1, stride=1, padding=0, bias=False),
122 | nn.BatchNorm2d(512),
123 | nn.ReLU(True),
124 | nn.ConvTranspose2d(512, 256, kernel_size=4, bias=False),
125 | nn.BatchNorm2d(256),
126 | nn.ReLU(True),
127 | nn.ConvTranspose2d(256, 128, kernel_size=4, stride=1, padding=0, bias=False),
128 | nn.BatchNorm2d(128),
129 | nn.ReLU(True),
130 | nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
131 | nn.BatchNorm2d(64),
132 | nn.ReLU(True),
133 | nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
134 | nn.Sigmoid())
135 |
136 | def forward(self, x):
137 | return self.net(x)
138 |
139 |
140 | def initialize_weights(m, std=0.02):
141 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
142 | m.weight.data.normal_(mean=1., std=std)
143 | m.bias.data.fill_(0.)
144 | else:
145 | try:
146 | m.weight.data.normal_(std=std)
147 | except AttributeError: # skip activation layers
148 | pass
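# Illustrative usage (editorial sketch, not part of the original file): DCGAN-style init is
# typically applied to both networks before training, e.g.
#   D = Discriminator().to(args.device); D.apply(initialize_weights)
#   G = Generator(args.noise_dim).to(args.device); G.apply(initialize_weights)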
149 |
150 |
151 |
152 | # --------------------
153 | # Train
154 | # --------------------
155 |
156 | def sample_z(args):
157 | # generate samples from the prior
158 | return dist.Uniform(-1,1).sample((args.batch_size, args.noise_dim, 1, 1)).to(args.device)
159 |
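# The losses below are written with torch.distributions.Bernoulli rather than an explicit BCE call.
# A minimal sanity-check sketch of that equivalence -- a hypothetical helper, not used by the training loop.
def _check_bernoulli_bce_equivalence():
    import torch.nn.functional as F  # local import in case F is not imported at module level here
    logits = torch.randn(4, 1)
    labels = torch.ones(4, 1)
    nll = - dist.Bernoulli(logits=logits).log_prob(labels).mean()
    bce = F.binary_cross_entropy_with_logits(logits, labels)
    assert torch.allclose(nll, bce)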
160 |
161 | def train_epoch(D, G, dataloader, d_optimizer, g_optimizer, epoch, writer, args):
162 |
163 | fixed_z = sample_z(args)
164 |
165 | real_labels = torch.ones(args.batch_size, 1, device=args.device).requires_grad_(False)
166 | fake_labels = torch.zeros(args.batch_size, 1, device=args.device).requires_grad_(False)
167 |
168 | with tqdm(total=len(dataloader), desc='epoch {} of {}'.format(epoch+1, args.n_epochs)) as pbar:
169 | time.sleep(0.1)
170 |
171 | for i, (x, _) in enumerate(dataloader):
172 | D.train()
173 | G.train()
174 |
175 | x = x.to(args.device)
176 |
177 | # train generator
178 |
179 | # sample prior
180 | z = sample_z(args)
181 |
182 | # run through model
183 | generated = G(z)
184 | d_fake = D(generated)
185 |
186 | # calculate losses
187 | g_loss = - d_fake.log_prob(real_labels).mean()
188 |
189 | g_optimizer.zero_grad()
190 | g_loss.backward()
191 | g_optimizer.step()
192 |
193 |
194 | # train discriminator
195 | d_real = D(x)
196 | d_fake = D(generated.detach())
197 |
198 | # calculate losses
199 | d_loss = - d_real.log_prob(real_labels).mean() - d_fake.log_prob(fake_labels).mean()
200 |
201 | d_optimizer.zero_grad()
202 | d_loss.backward()
203 | d_optimizer.step()
204 |
205 |
206 | # update tracking
207 | pbar.set_postfix(d_loss='{:.3f}'.format(d_loss.item()),
208 | g_loss='{:.3f}'.format(g_loss.item()))
209 | pbar.update()
210 |
211 | if i % args.log_interval == 0:
212 | step = epoch
213 | writer.add_scalar('d_loss', d_loss.item(), step)
214 | writer.add_scalar('g_loss', g_loss.item(), step)
215 | # sample images
216 | with torch.no_grad():
217 | G.eval()
218 | fake_images = G(fixed_z)
219 | writer.add_image('generated', make_grid(fake_images[:10].cpu(), nrow=10, padding=1), step)
220 | save_image(fake_images[:10].cpu(),
221 | os.path.join(args.output_dir, 'generated_sample_epoch_{}.png'.format(epoch)),
222 | nrow=10)
223 |
224 |
225 | def train(D, G, dataloader, d_optimizer, g_optimizer, writer, args):
226 |
227 | print('Starting training with args:\n', args)
228 |
229 | start_epoch = 0
230 |
231 | if args.restore_file:
232 | print('Restoring parameters from {}'.format(args.restore_file))
233 | start_epoch = utils.load_checkpoint(args.restore_file, [D, G], [d_optimizer, g_optimizer])
234 | args.n_epochs += start_epoch - 1
235 | print('Resuming training from epoch {}'.format(start_epoch))
236 |
237 | for epoch in range(start_epoch, args.n_epochs):
238 | train_epoch(D, G, dataloader, d_optimizer, g_optimizer, epoch, writer, args)
239 |
240 | # snapshot at end of epoch
241 | if args.save_model:
242 | utils.save_checkpoint({'epoch': epoch + 1,
243 | 'model_state_dicts': [D.state_dict(), G.state_dict()],
244 | 'optimizer_state_dicts': [d_optimizer.state_dict(), g_optimizer.state_dict()]},
245 | checkpoint=args.output_dir,
246 | quiet=True)
247 |
248 | @torch.no_grad()
249 | def evaluate_on_grid(G, writer, args):
250 | # sample noise randomly
251 | z = torch.empty(100, args.noise_dim, 1, 1).uniform_(-1,1).to(args.device)
252 |
253 | fake_images = G(z)
254 | writer.add_image('generated grid', make_grid(fake_images.cpu(), nrow=10, normalize=True, padding=1))
255 | save_image(fake_images.cpu(),
256 | os.path.join(args.output_dir, 'latent_var_grid_sample_c1.png'),
257 | nrow=10)
258 |
259 |
260 |
261 | if __name__ == '__main__':
262 | args = parser.parse_args()
263 |
264 | if not os.path.isdir(args.output_dir):
265 | os.makedirs(args.output_dir)
266 |
267 | writer = utils.set_writer(args.output_dir, '_train')
268 |
269 | args.device = torch.device('cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda is not None else 'cpu')
270 |
271 | # set seed
272 | torch.manual_seed(11122018)
273 | if args.device.type == 'cuda': torch.cuda.manual_seed(11122018)
274 |
275 | # input
276 | dataloader = fetch_dataloader(args)
277 |
278 | # models
279 | D = Discriminator().to(args.device)
280 | G = Generator(args.noise_dim).to(args.device)
281 | D.apply(initialize_weights)
282 | G.apply(initialize_weights)
283 |
284 | # optimizers
285 | d_optimizer = torch.optim.Adam(D.parameters(), lr=args.d_lr, betas=(0.5, 0.999))
286 | g_optimizer = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.999))
287 |
288 |
289 | # eval
290 | if args.evaluate_on_grid:
291 | print('Restoring parameters from {}'.format(args.restore_file))
292 | _ = utils.load_checkpoint(args.restore_file, [D, G], [d_optimizer, g_optimizer])
293 | evaluate_on_grid(G, writer, args)
294 | # train
295 | else:
296 | dataloader = fetch_dataloader(args)
297 | train(D, G, dataloader, d_optimizer, g_optimizer, writer, args)
298 | evaluate_on_grid(G, writer, args)
299 | writer.close()
300 |
301 |
--------------------------------------------------------------------------------
/draw.py:
--------------------------------------------------------------------------------
1 | """
2 | DRAW: A Recurrent Neural Network For Image Generation
3 | https://arxiv.org/pdf/1502.04623.pdf
4 | """
5 |
6 | import os
7 | import argparse
8 | import time
9 | from tqdm import tqdm
10 | import pprint
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import torch.distributions as D
16 | import torchvision.transforms as T
17 | from torch.utils.data import DataLoader
18 | from torchvision.datasets import MNIST
19 | from torchvision.utils import save_image, make_grid
20 |
21 | import utils
22 |
23 | parser = argparse.ArgumentParser()
24 |
25 | # model parameters
26 | parser.add_argument('--image_dims', type=tuple, default=(1,28,28), help='Dimensions of a single datapoint (e.g. (1,28,28) for MNIST).')
27 | parser.add_argument('--time_steps', type=int, default=32, help='Number of time-steps T consumed by the network before performing reconstruction.')
28 | parser.add_argument('--z_size', type=int, default=100, help='Size of the latent representation.')
29 | parser.add_argument('--lstm_size', type=int, default=256, help='Size of the hidden layer in the encoder/decoder models.')
30 | parser.add_argument('--read_size', type=int, default=2, help='Size of the read operation visual field.')
31 | parser.add_argument('--write_size', type=int, default=5, help='Size of the write operation visual field.')
32 | parser.add_argument('--use_read_attn', action='store_true', help='Whether to use visual attention or not. If not, read/write field size is the full image.')
33 | parser.add_argument('--use_write_attn', action='store_true', help='Whether to use visual attention or not. If not, read/write field size is the full image.')
34 |
35 | # training params
36 | parser.add_argument('--train', action='store_true')
37 | parser.add_argument('--train_batch_size', type=int, default=128)
38 | parser.add_argument('--n_epochs', type=int, default=50, help='Number of epochs to train.')
39 | parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
40 | parser.add_argument('--log_interval', type=int, default=100, help='How often to write summary outputs.')
41 | parser.add_argument('--cuda', type=int, help='Which cuda device to use')
42 | parser.add_argument('--mini_data', action='store_true')
43 | parser.add_argument('--verbose', '-v', action='store_true', help='Extra monitoring of training + record forward/backward hooks on attn params')
44 |
45 | # eval params
46 | parser.add_argument('--evaluate', action='store_true')
47 | parser.add_argument('--test_batch_size', type=int, default=10, help='Batch size for evaluation')
48 |
49 | # generate params
50 | parser.add_argument('--generate', action='store_true')
51 |
52 | # data paths
53 | parser.add_argument('--save_model', action='store_true')
54 | parser.add_argument('--data_dir', default='./data')
55 | parser.add_argument('--output_dir', default='./results/{}'.format(os.path.splitext(__file__)[0]))
56 | parser.add_argument('--restore_file', help='Path to .pt checkpoint file.')
57 |
58 |
59 |
60 |
61 | # --------------------
62 | # Data
63 | # --------------------
64 |
65 | def fetch_dataloader(args, batch_size, train=True, download=False, mini_size=128):
66 |
67 | transforms = T.Compose([T.ToTensor()])
68 | dataset = MNIST(root=args.data_dir, train=train, download=download, transform=transforms)
69 |
70 | # optionally restrict to a small subset of the data for quick runs / debugging
71 | if args.mini_data:
72 | if train:
73 | dataset.train_data = dataset.train_data[:mini_size]
74 | dataset.train_labels = dataset.train_labels[:mini_size]
75 | else:
76 | dataset.test_data = dataset.test_data[:mini_size]
77 | dataset.test_labels = dataset.test_labels[:mini_size]
78 |
79 |
80 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.device.type == 'cuda' else {}
81 |
82 | return DataLoader(dataset, batch_size=batch_size, shuffle=train, drop_last=True, **kwargs)
83 |
84 |
85 | # --------------------
86 | # Model
87 | # --------------------
88 |
89 | class DRAW(nn.Module):
90 | def __init__(self, args):
91 | super().__init__()
92 |
93 | # save dimensions
94 | self.C, self.H, self.W = args.image_dims
95 | self.time_steps = args.time_steps
96 | self.lstm_size = args.lstm_size
97 | self.z_size = args.z_size
98 | self.use_read_attn = args.use_read_attn
99 | self.use_write_attn = args.use_write_attn
100 | if self.use_read_attn:
101 | self.read_size = args.read_size
102 | else:
103 | self.read_size = self.H
104 | if self.use_write_attn:
105 | self.write_size = args.write_size
106 | else:
107 | self.write_size = self.H
108 |
109 | # encoder - decoder layers
110 | self.encoder = nn.LSTMCell(2 * self.read_size * self.read_size + self.lstm_size, self.lstm_size)
111 | self.decoder = nn.LSTMCell(self.z_size, self.lstm_size)
112 |
113 | # latent space layer
114 | # outputs the parameters of the q distribution (mu and log sigma; here q is a Normal)
115 | self.z_linear = nn.Linear(self.lstm_size, 2 * self.z_size)
116 |
117 | # write layers
118 | self.write_linear = nn.Linear(self.lstm_size, self.write_size * self.write_size)
119 |
120 | # filter bank
121 | # outputs center location (g_x, g_y), stride delta, logvar of the gaussian filters, scalar intensity gamma (5 params in total)
122 | self.read_attention_linear = nn.Linear(self.lstm_size, 5)
123 | self.write_attention_linear = nn.Linear(self.lstm_size, 5)
124 |
125 | def compute_q_distribution(self, h_enc):
126 | mus, logsigmas = torch.split(self.z_linear(h_enc), [self.z_size, self.z_size], dim=1)
127 | sigmas = logsigmas.exp()
128 | z_dist = D.Normal(loc=mus, scale=sigmas)
129 | return z_dist.rsample(), mus, sigmas
130 |
131 | def read(self, x, x_hat, h_dec):
132 | if self.use_read_attn:
133 | # read filter bank -- eq 21
134 | g_x, g_y, logvar, logdelta, loggamma = self.read_attention_linear(h_dec).split(split_size=1, dim=1) # split returns column vecs
135 | # compute filter bank matrices -- eq 22 - 26
136 | g_x, g_y, delta, mu_x, mu_y, F_x, F_y = compute_filterbank_matrices(g_x, g_y, logvar, logdelta, self.H, self.W, self.read_size)
137 | # output reading with attention -- eq 27
138 | new_x = F_y @ x.view(-1, self.H, self.W) @ F_x.transpose(-2, -1) # out (B, N, N)
139 | new_x_hat = F_y @ x_hat.view(-1, self.H, self.W) @ F_x.transpose(-2, -1) # out (B, N, N)
140 | return loggamma.exp() * torch.cat([new_x.view(x.shape[0], -1), new_x_hat.view(x.shape[0], -1)], dim=1)
141 | else:
142 | # output reading without attention -- eq 17
143 | return torch.cat([x, x_hat], dim=1)
144 |
145 | def write(self, h_dec):
146 | w = self.write_linear(h_dec)
147 |
148 | if self.use_write_attn:
149 | # read filter bank -- eq 21
150 | g_x, g_y, logvar, logdelta, loggamma = self.write_attention_linear(h_dec).split(split_size=1, dim=1) # split returns column vecs
151 | # compute filter bank matrices -- eq 22 - 26
152 | g_x, g_y, delta, mu_x, mu_y, F_x, F_y = compute_filterbank_matrices(g_x, g_y, logvar, logdelta, self.H, self.W, self.write_size)
153 | # output write with attention -- eq 29
154 | w = F_y.transpose(-2, -1) @ w.view(-1, self.write_size, self.write_size) @ F_x
155 | return 1. / loggamma.exp() * w.view(w.shape[0], -1)
156 | else:
157 | # output write without attention -- eq 18
158 | return w
159 |
160 | def forward(self, x):
161 | batch_size = x.shape[0]
162 | device = x.device
163 |
164 | # record metrics for loss calculation
165 | mus, sigmas = [0]*self.time_steps, [0]*self.time_steps
166 |
167 | # initialize the canvas matrix c on same device and the hidden state and cell state for the encoder and decoder
168 | c = torch.zeros(*x.shape).to(device)
169 | h_enc = torch.zeros(batch_size, self.lstm_size).to(device)
170 | c_enc = torch.zeros_like(h_enc)
171 | h_dec = torch.zeros(batch_size, self.lstm_size).to(device)
172 | c_dec = torch.zeros_like(h_dec)
173 |
174 | # run model forward (cf DRAW eq 3 - 8)
175 | for t in range(self.time_steps):
176 | x_hat = x.to(c.device) - torch.sigmoid(c)
177 | r = self.read(x, x_hat, h_dec)
178 | h_enc, c_enc = self.encoder(torch.cat([r, h_dec], dim=1), (h_enc, c_enc))
179 | z_sample, mus[t], sigmas[t] = self.compute_q_distribution(h_enc)
180 | h_dec, c_dec = self.decoder(z_sample, (h_dec, c_dec))
181 | c = c + self.write(h_dec)
182 |
183 | # return
184 | # data likelihood; used to compute L_x loss -- shape (B, H*W)
185 | # sequence of latent distributions Q; used to compute L_z loss (here Normal of shape (B, z_size, time_steps)
186 | return D.Bernoulli(logits=c), D.Normal(torch.stack(mus, dim=-1), torch.stack(sigmas, dim=-1))
187 |
188 | @torch.no_grad()
189 | def generate(self, n_samples, args):
190 | samples_time_seq = []
191 |
192 | # initialize model
193 | c = torch.zeros(n_samples, self.C * self.H * self.W).to(args.device)
194 | h_dec = torch.zeros(n_samples, self.lstm_size).to(args.device)
195 | c_dec = torch.zeros_like(h_dec).to(args.device)
196 |
197 | # run for the number of time steps
198 | for t in range(self.time_steps):
199 | z_sample = D.Normal(0,1).sample((n_samples, self.z_size)).to(args.device)
200 | h_dec, c_dec = self.decoder(z_sample, (h_dec, c_dec))
201 | c = c + self.write(h_dec)
202 | x = D.Bernoulli(logits=c.view(n_samples, self.C, self.H, self.W)).probs
203 |
204 | samples_time_seq.append(x)
205 |
206 | return samples_time_seq
207 |
208 |
209 | def compute_filterbank_matrices(g_x, g_y, logvar, logdelta, H, W, attn_window_size):
210 | """ DRAW section 3.2 -- computes the parameters for an NxN grid of Gaussian filters over the input image.
211 | Args
212 | g_x, g_y -- tensors of shape (B, 1); unnormalized center coords for the attention window
213 | logvar -- tensor of shape (B, 1); log variance for the Gaussian filters (filterbank matrices) on the attention window
214 | logdelta -- tensor of shape (B, 1); unnormalized stride for the spacing of the filters in the attention window
215 | H, W -- scalars; original image dimensions
216 | attn_window_size -- scalar; size of the attention window (specified by the read_size / write_size input args)
217 |
218 | Returns
219 | g_x, g_y -- tensors of shape (B, 1); normalized center coords of the attention window;
220 | delta -- tensor of shape (B, 1); stride for the spacing of the filters in the attention window
221 | mu_x, mu_y -- tensors of shape (B, attn_window_size); mean locations of the filters along the x (columns) and y (rows) axes
222 | F_x, F_y -- tensors of shape (B, N, W) and (B, N, H) where N=attention_window_size; filterbank matrices
223 | """
224 |
225 | batch_size = g_x.shape[0]
226 | device = g_x.device
227 |
228 | # rescale attention window center coords and stride to ensure the initial patch covers the whole input image
229 | # eq 22 - 24
230 | g_x = 0.5 * (W + 1) * (g_x + 1) # (B, 1)
231 | g_y = 0.5 * (H + 1) * (g_y + 1) # (B, 1)
232 | delta = (max(H, W) - 1) / (attn_window_size - 1) * logdelta.exp() # (B, 1)
233 |
234 | # compute the means of the filter
235 | # eq 19 - 20
236 | mu_x = g_x + (torch.arange(1., 1. + attn_window_size).to(device) - 0.5*(attn_window_size + 1)) * delta # (B, N)
237 | mu_y = g_y + (torch.arange(1., 1. + attn_window_size).to(device) - 0.5*(attn_window_size + 1)) * delta # (B, N)
238 |
239 | # compute the filterbank matrices
240 | # B = batch dim; N = attn window size; H = original height; W = original width
241 | # eq 25 -- combines logvar=(B, 1, 1) * ( range=(B, 1, W) - mu=(B, N, 1) ) = out (B, N, W); then normalizes over W dimension;
242 | F_x = torch.exp(- 0.5 / logvar.exp().view(-1,1,1) * (torch.arange(1., 1. + W).repeat(batch_size, 1, 1).to(device) - mu_x.unsqueeze(-1))**2)
243 | F_x = F_x / torch.sum(F_x + 1e-8, dim=2, keepdim=True) # normalize over the coordinates of the input image
244 | # eq 26
245 | F_y = torch.exp(- 0.5 / logvar.exp().view(-1,1,1) * (torch.arange(1., 1. + H).repeat(batch_size, 1, 1).to(device) - mu_y.unsqueeze(-1))**2)
246 | F_y = F_y / torch.sum(F_y + 1e-8, dim=2, keepdim=True) # normalize over the coordinates of the input image
247 |
248 | # returns DRAW paper eq 22, 23, 24, 19, 20, 25, 26
249 | return g_x, g_y, delta, mu_x, mu_y, F_x, F_y
250 |
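# A small shape sketch for the filterbank above -- a hypothetical helper that is not called by the
# model; the batch size, image size and attention window size below are arbitrary.
def _filterbank_shape_demo(B=2, H=28, W=28, N=5):
    zeros = torch.zeros(B, 1)
    g_x, g_y, delta, mu_x, mu_y, F_x, F_y = compute_filterbank_matrices(zeros, zeros, zeros, zeros, H, W, N)
    assert F_x.shape == (B, N, W) and F_y.shape == (B, N, H)
    # reading an N x N patch (eq 27): (B, N, H) @ (B, H, W) @ (B, W, N) -> (B, N, N)
    patch = F_y @ torch.rand(B, H, W) @ F_x.transpose(-2, -1)
    assert patch.shape == (B, N, N)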
251 |
252 | def loss_fn(d, q, x, writer=None, step=None):
253 | """
254 | Args
255 | d -- data likelihood distribution output by the model (Bernoulli)
256 | q -- approximation distribution to the latent variable z (Normal)
257 | """
258 | # cf DRAW paper section 2
259 |
260 | # reconstruction loss L_x eq 9 -- negative log probability of x under the model d
261 | # (sum log probs over the pixels of each datapoint and mean over batch dim)
262 | # latent loss L_z eq 10 -- sum KL over temporal dimension (number of time steps) and mean over batch and z dims
263 | batch_size = x.shape[0]
264 | p_prior = D.Normal(torch.tensor(0., device=x.device), torch.tensor(1., device=x.device))
265 | loss_log_likelihood = - d.log_prob(x).sum(-1).mean(0) # sum over pixels (-1), mean over datapoints (0)
266 | loss_kl = D.kl.kl_divergence(q, p_prior).sum(dim=[-2,-1]).mean(0) # sum over time_steps (-1) and z (-2), mean over datapoints (0)
267 |
268 | if writer:
269 | writer.add_scalar('loss_log_likelihood', loss_log_likelihood, step)
270 | writer.add_scalar('loss_kl', loss_kl, step)
271 |
272 | return loss_log_likelihood + loss_kl
273 |
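# The KL term above relies on the analytic Gaussian KL. A hedged sanity-check sketch (a hypothetical
# helper, not part of training) against the closed form
# KL(N(mu, sigma) || N(0, 1)) = 0.5 * (mu^2 + sigma^2 - 1) - log(sigma):
def _check_gaussian_kl_closed_form():
    mu, sigma = torch.randn(3, 4), torch.rand(3, 4) + 0.1
    kl = D.kl.kl_divergence(D.Normal(mu, sigma), D.Normal(torch.zeros_like(mu), torch.ones_like(sigma)))
    assert torch.allclose(kl, 0.5 * (mu ** 2 + sigma ** 2 - 1) - sigma.log(), atol=1e-6)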
274 |
275 | # --------------------
276 | # Train and eval
277 | # --------------------
278 |
279 | def train_epoch(model, dataloader, loss_fn, optimizer, epoch, writer, args):
280 | model.train()
281 |
282 | with tqdm(total=len(dataloader), desc='epoch {} of {}'.format(epoch+1, args.n_epochs)) as pbar:
283 | time.sleep(0.1)
284 |
285 | for i, (x, _) in enumerate(dataloader):
286 | global_step = epoch * len(dataloader) + i + 1
287 |
288 | x = x.view(x.shape[0], -1).to(args.device)
289 |
290 | (d, q) = model(x)
291 |
292 | loss = loss_fn(d, q, x, writer, global_step)
293 |
294 | optimizer.zero_grad()
295 | loss.backward()
296 |
297 | # record grad norm and clip
298 | if args.verbose:
299 | grad_norm = 0
300 | for name, p in model.named_parameters():
301 | grad_norm += p.grad.norm().item() if p.grad is not None else 0
302 | writer.add_scalar('grad_norm', grad_norm, global_step)
303 | nn.utils.clip_grad_norm_(model.parameters(), 10)
304 |
305 | optimizer.step()
306 |
307 | # update tracking
308 | pbar.set_postfix(loss='{:.3f}'.format(loss.item()))
309 | pbar.update()
310 |
311 | if i % args.log_interval == 0:
312 | writer.add_scalar('loss', loss.item(), global_step)
313 |
314 |
315 | def train_and_evaluate(model, train_dataloader, test_dataloader, loss_fn, optimizer, writer, args):
316 | start_epoch = 0
317 |
318 | if args.restore_file:
319 | print('Restoring parameters from {}'.format(args.restore_file))
320 | start_epoch = utils.load_checkpoint(args.restore_file, [model], [optimizer], map_location=args.device.type)
321 | args.n_epochs += start_epoch - 1
322 | print('Resuming training from epoch {}'.format(start_epoch))
323 |
324 | for epoch in range(start_epoch, args.n_epochs):
325 | train_epoch(model, train_dataloader, loss_fn, optimizer, epoch, writer, args)
326 | # evaluate(model, test_dataloader, loss_fn, writer, args, epoch)
327 |
328 | # snapshot at end of epoch
329 | if args.save_model:
330 | utils.save_checkpoint({'epoch': epoch + 1,
331 | 'model_state_dicts': [model.state_dict()],
332 | 'optimizer_state_dicts': [optimizer.state_dict()]},
333 | checkpoint=args.output_dir,
334 | quiet=True)
335 |
336 |
337 | @torch.no_grad()
338 | def evaluate(model, dataloader, loss_fn, writer, args, epoch=None):
339 | model.eval()
340 |
341 | # sample the generation model
342 | samples_time_seq = model.generate(args.test_batch_size**2, args)
343 | samples = samples_time_seq[-1] # grab the final sample
344 | # pull targets to search closest neighbor to
345 | # right-most column of an n x n image grid where n is args.test_batch_size -- start at index n-1 and skip by n (e.g. 9, 19, ...)
346 | targets = samples[args.test_batch_size - 1 :: args.test_batch_size].view(args.test_batch_size, -1)
347 | # initialize a large max L2 pixel distance and tensor for l2-closest neighbors
348 | max_distances = 100 * samples[0].numel() * torch.ones(args.test_batch_size).to(args.device)
349 | closest_neighbors = torch.zeros_like(targets)
350 |
351 | # compute ELBO on dataset
352 | cum_loss = 0
353 | for x, _ in tqdm(dataloader):
354 | x = x.view(x.shape[0], -1).float().to(args.device)
355 |
356 | # run through model and aggregate loss
357 | d, q = model(x)
358 | # aggregate loss
359 | loss = loss_fn(d, q, x)
360 | cum_loss += loss.item()
361 |
362 | # find closest neighbors to the targets sampled above - l2 distance between targets and images in minibatch
363 | distances = F.pairwise_distance(x, targets)
364 | mask = distances < max_distances
365 | max_distances[mask] = distances[mask]
366 | closest_neighbors[mask] = x[mask]
367 |
368 | cum_loss /= len(dataloader)
369 | # output loss
370 | print('Evaluation ELBO: {:.2f}'.format(cum_loss))
371 | writer.add_scalar('Evaluation ELBO', cum_loss, epoch)
372 |
373 | # visualize generated samples and closest neighbors (cf DRAW paper fig 6)
374 | generated = make_grid(samples.cpu(), nrow=samples.shape[0]//args.test_batch_size)
375 | spacer = torch.ones_like(generated)[:,:,:2]
376 | neighbors = make_grid(closest_neighbors.view(-1, *args.image_dims).cpu(), nrow=1)
377 | images = torch.cat([generated, spacer, neighbors], dim=-1)
378 | save_image(images, os.path.join(args.output_dir,
379 | 'evaluation_sample' + (epoch!=None)*'_epoch_{}'.format(epoch) + '.png'))
380 | writer.add_image('generated images', images, epoch)
381 |
382 |
383 | @torch.no_grad()
384 | def generate(model, writer, args, n_samples=64):
385 | import math
386 |
387 | # generate samples
388 | samples_time_seq = model.generate(n_samples, args)
389 |
390 | # visualize generation sequence (cf DRAW paper fig 7)
391 | images = torch.stack(samples_time_seq, dim=1).view(-1, *args.image_dims) # reshape to (10*time_steps, 1, 28, 28)
392 | images = make_grid(images, nrow=len(samples_time_seq), padding=1, pad_value=1)
393 | save_name = 'generated_sequences_r{}_w{}_steps{}.png'.format(args.read_size, args.write_size, args.time_steps)
394 | save_image(images, os.path.join(args.output_dir, save_name))
395 | writer.add_image(save_name, images)
396 |
397 | # make gif
398 | for i in range(len(samples_time_seq)):
399 | # convert sequence of image tensors to 8x8 grid
400 | image = make_grid(samples_time_seq[i].cpu(), nrow=int(math.sqrt(n_samples)), padding=1, normalize=True, pad_value=1)
401 | # make into gif
402 | samples_time_seq[i] = image.data.numpy().transpose(1,2,0)
403 |
404 | import imageio
405 | imageio.mimsave(os.path.join(args.output_dir, 'generated_{}_time_steps.gif'.format(args.time_steps)), samples_time_seq)
406 |
407 |
408 | # --------------------
409 | # Monitor training
410 | # --------------------
411 |
412 | def record_attn_params(self, in_tensor, out_tensor, bank_name):
413 | g_x, g_y, logvar, logdelta, loggamma = out_tensor.cpu().split(split_size=1, dim=1)
414 | writer.add_scalar(bank_name + ' g_x', g_x.mean())
415 | writer.add_scalar(bank_name + ' g_y', g_y.mean())
416 | writer.add_scalar(bank_name + ' var', logvar.exp().mean())
417 | writer.add_scalar(bank_name + ' exp_logdelta', logdelta.exp().mean())
418 | writer.add_scalar(bank_name + ' gamma', loggamma.exp().mean())
419 |
420 | def record_attn_grads(self, in_tensor, out_tensor, bank_name):
421 | g_x, g_y, logvar, logdelta, loggamma = out_tensor[0].cpu().split(split_size=1, dim=1)
422 | writer.add_scalar(bank_name + ' grad_var', logvar.exp().mean())
423 | writer.add_scalar(bank_name + ' grad_logdelta', logdelta.exp().mean())
424 |
425 | def record_forward_backward_attn_hooks(model):
426 | from functools import partial
427 |
428 | model.read_attention_linear.register_forward_hook(partial(record_attn_params, bank_name='read'))
429 | model.write_attention_linear.register_forward_hook(partial(record_attn_params, bank_name='write'))
430 | model.write_attention_linear.register_backward_hook(partial(record_attn_grads, bank_name='write'))
431 |
432 |
433 | # --------------------
434 | # Main
435 | # --------------------
436 |
437 | if __name__ == '__main__':
438 | args = parser.parse_args()
439 |
440 | if not os.path.isdir(args.output_dir):
441 | os.makedirs(args.output_dir)
442 |
443 | writer = utils.set_writer(args.output_dir if args.restore_file is None else args.restore_file,
444 | (args.restore_file==None)*'', # suffix only when not restoring
445 | args.restore_file is not None)
446 | # update output_dir with the writer unique directory
447 | args.output_dir = writer.file_writer.get_logdir()
448 |
449 | args.device = torch.device('cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda is not None else 'cpu')
450 |
451 | # set seed
452 | torch.manual_seed(5)
453 | if args.device.type == 'cuda': torch.cuda.manual_seed(11192018)
454 |
455 | # set up model
456 | model = DRAW(args).to(args.device)
457 |
458 | if args.verbose:
459 | record_forward_backward_attn_hooks(model)
460 |
461 | # train
462 | if args.train:
463 | # optimizer
464 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
465 | # dataloaders
466 | train_dataloader = fetch_dataloader(args, args.train_batch_size, train=True)
467 | test_dataloader = fetch_dataloader(args, args.test_batch_size, train=False)
468 | # run training
469 | print('Starting training with args:\n', args)
470 | writer.add_text('Params', pprint.pformat(args.__dict__))
471 | with open(os.path.join(args.output_dir, 'params.txt'), 'w') as f:
472 | pprint.pprint(args.__dict__, f)
473 | train_and_evaluate(model, train_dataloader, test_dataloader, loss_fn, optimizer, writer, args)
474 |
475 | # eval
476 | if args.evaluate:
477 | print('Restoring parameters from {}'.format(args.restore_file))
478 | _ = utils.load_checkpoint(args.restore_file, [model])
479 | print('Evaluating model with args:\n', args)
480 | # get test dataloader
481 | dataloader = fetch_dataloader(args, args.test_batch_size, train=False)
482 | # evaluate
483 | evaluate(model, dataloader, loss_fn, writer, args)
484 |
485 | # generate
486 | if args.generate:
487 | print('Restoring parameters from {}'.format(args.restore_file))
488 | _ = utils.load_checkpoint(args.restore_file, [model])
489 | print('Generating images from model with args:\n', args)
490 | generate(model, writer, args)
491 |
492 |
493 | writer.close()
494 |
495 |
--------------------------------------------------------------------------------
/draw_test_attn.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torchvision.datasets import MNIST
5 | import torchvision.transforms as T
6 |
7 | import matplotlib.pyplot as plt
8 | from matplotlib.patches import Rectangle
9 | import matplotlib.gridspec as gridspec
10 | from PIL import Image
11 | import numpy as np
12 |
13 | from draw import compute_filterbank_matrices
14 |
15 |
16 | # --------------------
17 | # DRAW paper figure 3
18 | # --------------------
19 |
20 | def plot_attn_window(img_tensor, mu_x, mu_y, delta, sigma, attn_window_size, ax):
21 | ax.imshow(img_tensor.squeeze().data.numpy(), cmap='gray')
22 | mu_x = mu_x.flatten()
23 | mu_y = mu_y.flatten()
24 | x = mu_x[0]
25 | y = mu_y[0]
26 | w = mu_y[-1] - mu_y[0]
27 | h = mu_x[-1] - mu_x[0]
28 | ax.add_patch(Rectangle((x, y), w, h, facecolor='none', edgecolor='lime', linewidth=5*sigma, alpha=0.7))
29 |
30 | def plot_filtered_attn_window(img_tensor, F_x, F_y, g_x, g_y, H, W, attn_window_size, ax):
31 | ax.set_xlim(0, W)
32 | ax.set_ylim(0, H)
33 | ax.imshow((F_y @ img_tensor @ F_x.permute(0,2,1)).squeeze(), cmap='gray', extent=(g_x - attn_window_size/2,
34 | g_x + attn_window_size/2,
35 | g_y - attn_window_size/2,
36 | g_y + attn_window_size/2))
37 | def test_attn_window_params():
38 | dataset = MNIST(root='./data', transform=T.ToTensor(), train=True, download=False)
39 | img = dataset[0][0]
40 | batch_size, H, W = img.shape
41 |
42 | fig = plt.figure()  # figsize=(8,6)
43 | gs = gridspec.GridSpec(nrows=3, ncols=3, width_ratios=[4,2,1])
44 |
45 | # Figure 3.
46 | # Left: A 3 × 3 grid of filters superimposed on an image. The stride (δ) and centre location (gX , gY ) are indicated.
47 | attn_window_size = 3
48 | g_x = torch.tensor([[-0.2]])
49 | g_y = torch.tensor([[0.]])
50 | logvar = torch.tensor([[1.]])
51 | logdelta = torch.tensor([[-1.]])
52 | g_x, g_y, delta, mu_x, mu_y, F_x, F_y = compute_filterbank_matrices(g_x, g_y, logvar, logdelta, H, W, attn_window_size)
53 |
54 | ax = fig.add_subplot(gs[:,0])
55 | ax.imshow(img.squeeze().data.numpy(), cmap='gray')
56 | ax.scatter(g_x.numpy(), g_y.numpy(), s=150, color='orange', alpha=0.8)
57 | ax.scatter(mu_x.view(1, -1).repeat(attn_window_size, 1).numpy(),
58 | mu_y.view(-1,1).repeat(1, attn_window_size).numpy(), s=100, color='lime', alpha=0.8)
59 |
60 | # Right: Three N × N patches extracted from the image (N = 12).
61 | # The green rectangles on the left indicate the boundary and precision (σ) of the patches, while the patches themselves are shown to the right.
62 | # The top patch has a small δ and high σ, giving a zoomed-in but blurry view of the centre of the digit;
63 | # the middle patch has large δ and low σ, effectively downsampling the whole image;
64 | # and the bottom patch has high δ and σ.
65 | attn_window_size = 12
66 | logdeltas = [-1., -0.5, 0.]
67 | sigmas = [1., 0.5, 3.]
68 |
69 | for i, (logdelta, sigma) in enumerate(zip(logdeltas, sigmas)):
70 | g_x = torch.tensor([[-0.2]])
71 | g_y = torch.tensor([[0.]])
72 | logvar = torch.tensor(sigma**2).float().view(1,-1).log()
73 | logdelta = torch.tensor(logdelta).float().view(1,-1)
74 |
75 | g_x, g_y, delta, mu_x, mu_y, F_x, F_y = compute_filterbank_matrices(g_x, g_y, logvar, logdelta, H, W, attn_window_size)
76 |
77 | # plot attention window
78 | ax = fig.add_subplot(gs[i,1])
79 | plot_attn_window(img, mu_x, mu_y, delta, sigma, attn_window_size, ax)
80 |
81 | # plot attention window zoom in
82 | ax = fig.add_subplot(gs[i,2])
83 | plot_filtered_attn_window(img, F_x, F_y, g_x, g_y, H, W, attn_window_size, ax)
84 |
85 | for ax in fig.axes:
86 | ax.axis('off')
87 |
88 | plt.tight_layout()
89 | plt.savefig('images/draw_fig_3.png')
90 | plt.close()
91 |
92 |
93 | # --------------------
94 | # DRAW paper figure 4 -- Test read write attention on
95 | # --------------------
96 |
97 | def test_read_write_attn():
98 | #im = cv2.imread('elephant_r.png')
99 | im = np.asarray(Image.open('images/elephant.png'))
100 | img = torch.from_numpy(im).float()
101 | img /= 255. # normalize to 0-1
102 | img = img.permute(2,0,1) # to torch standard (C, H, W)
103 |
104 | print('image dims -- ', img.shape)
105 |
106 | # store dims
107 | C, H, W = img.shape
108 | attn_window_size = 12
109 |
110 |
111 | # filter params
112 | g_x = torch.tensor([[0.5]])
113 | g_y = torch.tensor([[0.5]])
114 | #logvar = torch.tensor([[1.]]).float().log()
115 | #logdelta = torch.tensor([[3.]]).float().log()
116 | logvar = torch.tensor([[1.]])
117 | logdelta = torch.tensor([[-1.]])
118 |
119 | g_x, g_y, delta, mu_x, mu_y, F_x, F_y = compute_filterbank_matrices(g_x, g_y, logvar, logdelta, H, W, attn_window_size)
120 |
121 | print('delta -- ', delta)
122 | print('mu_x -- ', mu_x)
123 | print('mu_y -- ', mu_y)
124 | print('F_x shape -- ', F_x.shape)
125 |
126 | mu_x = mu_x.flatten()
127 | mu_y = mu_y.flatten()
128 |
129 |
130 | # read image
131 | read = F_y @ img @ F_x.transpose(-2,-1)
132 | print('read image shape -- ', read.shape)
133 | # reconstruct image
134 | #read.fill_(1)
135 | recon_img = 10 * F_y.transpose(-2,-1) @ read @ F_x
136 |
137 | # plot
138 | fig = plt.figure(figsize=(9, 3))
139 | gs = gridspec.GridSpec(nrows=1, ncols=3, width_ratios=[W,attn_window_size,W])
140 |
141 |
142 | # show original image with attention bbox
143 | ax = fig.add_subplot(gs[0,0])
144 | ax.imshow(im)
145 | x = mu_x[0]
146 | y = mu_y[0]
147 | w = mu_y[-1] - mu_y[0]
148 | h = mu_x[-1] - mu_x[0]
149 | ax.add_patch(Rectangle((x, y), w, h, facecolor='none', edgecolor='lime', linewidth=5, alpha=0.7))
150 |
151 | # show attention patch
152 | ax = fig.add_subplot(gs[0, 1])
153 | ax.set_xlim(0, attn_window_size)
154 | ax.set_ylim(0, H)
155 | ax.imshow(read.squeeze().data.numpy().transpose(1,2,0), extent = (0, attn_window_size, H/2 - attn_window_size/2, H/2 + attn_window_size/2))
156 |
157 | # show reconstruction
158 | ax = fig.add_subplot(gs[0, 2])
159 | ax.imshow(recon_img.squeeze().data.numpy().transpose(1,2,0))
160 |
161 | plt.tight_layout()
162 | for ax in plt.gcf().axes:
163 | ax.axis('off')
164 | plt.subplots_adjust(top=1, bottom=0, right=1, left=0)
165 | plt.savefig('images/draw_fig_4.png')
166 | plt.close()
167 |
168 |
169 |
170 | if __name__ == '__main__':
171 | test_attn_window_params()
172 | test_read_write_attn()
173 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: generative_models
2 | channels:
3 | - pytorch
4 | - defaults
5 | - conda-forge
6 | dependencies:
7 | - blas=1.0
8 | - ca-certificates=2019.5.15
9 | - certifi=2019.6.16
10 | - cffi=1.11.5
11 | - cycler=0.10.0
12 | - freetype=2.9.1
13 | - intel-openmp=2019.1
14 | - jpeg=9b
15 | - kiwisolver=1.0.1
16 | - libedit=3.1.20170329
17 | - libffi=3.2.1
18 | - libpng=1.6.35
19 | - libtiff=4.0.9
20 | - matplotlib=3.0.1
21 | - mkl=2018.0.3
22 | - mkl_fft=1.0.6
23 | - mkl_random=1.0.1
24 | - ncurses=6.1
25 | - ninja=1.8.2
26 | - numpy=1.15.4
27 | - numpy-base=1.15.4
28 | - olefile=0.46
29 | - openssl=1.1.1c
30 | - pandas=0.24.2
31 | - pillow=5.3.0
32 | - pip=18.1
33 | - pycparser=2.19
34 | - pyparsing=2.3.0
35 | - python=3.7.1
36 | - python-dateutil=2.7.5
37 | - pytorch=1.1.0
38 | - pytz=2018.7
39 | - readline=7.0
40 | - scipy=1.1.0
41 | - setuptools=40.6.2
42 | - six=1.11.0
43 | - sqlite=3.25.3
44 | - tk=8.6.8
45 | - torchvision=0.3.0
46 | - tornado=5.1.1
47 | - tqdm=4.28.1
48 | - wheel=0.32.3
49 | - xz=5.2.4
50 | - zlib=1.2.11
51 | - pip:
52 | - chardet==3.0.4
53 | - idna==2.7
54 | - lmdb==0.96
55 | - observations==0.1.4
56 | - protobuf==3.6.1
57 | - requests==2.20.1
58 | - tensorboardx==1.4
59 | - urllib3==1.24.1
60 | prefix: /Users/Kamen/anaconda3/envs/generative_models
61 |
62 |
--------------------------------------------------------------------------------
/images/air/air_count.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/air/air_count.png
--------------------------------------------------------------------------------
/images/air/air_elbo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/air/air_elbo.png
--------------------------------------------------------------------------------
/images/air/image_recons_270.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/air/image_recons_270.png
--------------------------------------------------------------------------------
/images/basic_vae/reconstruction_at_epoch_24.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/basic_vae/reconstruction_at_epoch_24.png
--------------------------------------------------------------------------------
/images/basic_vae/sample_at_epoch_24.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/basic_vae/sample_at_epoch_24.png
--------------------------------------------------------------------------------
/images/basic_vae/tsne_embedding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/basic_vae/tsne_embedding.png
--------------------------------------------------------------------------------
/images/dcgan/latent_var_grid_sample_c1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/dcgan/latent_var_grid_sample_c1.png
--------------------------------------------------------------------------------
/images/draw/draw_fig_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/draw/draw_fig_3.png
--------------------------------------------------------------------------------
/images/draw/draw_fig_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/draw/draw_fig_4.png
--------------------------------------------------------------------------------
/images/draw/elephant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/draw/elephant.png
--------------------------------------------------------------------------------
/images/draw/generated_32_time_steps.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/draw/generated_32_time_steps.gif
--------------------------------------------------------------------------------
/images/infogan/latent_var_grid_sample_c1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/infogan/latent_var_grid_sample_c1.png
--------------------------------------------------------------------------------
/images/infogan/latent_var_grid_sample_c2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/infogan/latent_var_grid_sample_c2.png
--------------------------------------------------------------------------------
/images/ssvae/analogies_sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/ssvae/analogies_sample.png
--------------------------------------------------------------------------------
/images/ssvae/latent_var_grid_sample_c1_y2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/ssvae/latent_var_grid_sample_c1_y2.png
--------------------------------------------------------------------------------
/images/ssvae/latent_var_grid_sample_c2_y4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/ssvae/latent_var_grid_sample_c2_y4.png
--------------------------------------------------------------------------------
/images/vqvae2/128x128_bits3_eval_reconstruction_step_87300_bottom.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/vqvae2/128x128_bits3_eval_reconstruction_step_87300_bottom.png
--------------------------------------------------------------------------------
/images/vqvae2/128x128_bits3_eval_reconstruction_step_87300_original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/vqvae2/128x128_bits3_eval_reconstruction_step_87300_original.png
--------------------------------------------------------------------------------
/images/vqvae2/128x128_bits3_eval_reconstruction_step_87300_top.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/vqvae2/128x128_bits3_eval_reconstruction_step_87300_top.png
--------------------------------------------------------------------------------
/images/vqvae2/generation_sample_step_52440_top_b128_c128_outstack10_bottom_b16_c128_nres20_condstack10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/vqvae2/generation_sample_step_52440_top_b128_c128_outstack10_bottom_b16_c128_nres20_condstack10.png
--------------------------------------------------------------------------------
/infogan.py:
--------------------------------------------------------------------------------
1 | """
2 | InfoGAN -- https://arxiv.org/abs/1606.03657
3 |
4 | Follows the Tensorflow implementation at http://www.depthfirstlearning.com/2018/InfoGAN
5 |
6 | """
7 |
8 |
9 | import os
10 | import argparse
11 | from tqdm import tqdm
12 | import time
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import torch.distributions as dist
18 | from torch.distributions.one_hot_categorical import OneHotCategorical
19 | from torch.utils.data import DataLoader
20 | from torchvision.datasets import MNIST
21 | import torchvision.transforms as T
22 | from torchvision.utils import save_image, make_grid
23 |
24 | import utils
25 |
26 | parser = argparse.ArgumentParser()
27 |
28 | # training params
29 | parser.add_argument('--batch_size', type=int, default=128)
30 | parser.add_argument('--n_epochs', type=int, default=1)
31 | parser.add_argument('--noise_dim', type=int, default=62, help='Size of the incompressible noise latent representation')
32 | parser.add_argument('--cat_dim', type=int, default=10, help='Size of the categorical latent representation')
33 | parser.add_argument('--cont_dim', type=int, default=2, help='Size of the continuous latent representation')
34 | parser.add_argument('--info_reg_coeff', type=float, default=1., help='Weight of the mutual information regularization term')
35 | parser.add_argument('--g_lr', type=float, default=1e-3, help='Generator learning rate')
36 | parser.add_argument('--d_lr', type=float, default=2e-4, help='Discriminator learning rate')
37 | parser.add_argument('--log_interval', type=int, default=100)
38 | parser.add_argument('--cuda', type=int, help='Which cuda device to use')
39 | parser.add_argument('--mini_data', action='store_true')
40 | # eval params
41 | parser.add_argument('--evaluate_on_grid', action='store_true')
42 | # data paths
43 | parser.add_argument('--save_model', action='store_true')
44 | parser.add_argument('--data_dir', default='./data')
45 | parser.add_argument('--output_dir', default='./results/infogan')
46 | parser.add_argument('--restore_file', help='Path to .pt checkpoint file for Discriminator and Generator')
47 |
48 |
49 |
50 |
51 | # --------------------
52 | # Data
53 | # --------------------
54 |
55 | def fetch_dataloader(args, train=True, download=True, mini_size=128):
56 | # load dataset and init in the dataloader
57 |
58 | transforms = T.Compose([T.ToTensor()])
59 | dataset = MNIST(root=args.data_dir, train=train, download=download, transform=transforms)
60 |
61 | # optionally restrict to a small subset of the data for quick runs / debugging
62 | if args.mini_data:
63 | if train:
64 | dataset.train_data = dataset.train_data[:mini_size]
65 | dataset.train_labels = dataset.train_labels[:mini_size]
66 | else:
67 | dataset.test_data = dataset.test_data[:mini_size]
68 | dataset.test_labels = dataset.test_labels[:mini_size]
69 |
70 |
71 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.device.type == 'cuda' else {}
72 |
73 | dl = DataLoader(dataset, batch_size=args.batch_size, shuffle=train, drop_last=True, **kwargs)
74 |
75 | return dl
76 |
77 | # --------------------
78 | # Model
79 | # --------------------
80 |
81 | class Flatten(nn.Module):
82 | def forward(self, x):
83 | return x.view(x.shape[0], -1)
84 |
85 | class Unflatten(nn.Module):
86 | def __init__(self, B, C, H, W):
87 | super().__init__()
88 | self.B = B
89 | self.C = C
90 | self.H = H
91 | self.W = W
92 |
93 | def forward(self, x):
94 | return x.reshape(self.B, self.C, self.H, self.W)
95 |
96 |
97 | class Discriminator(nn.Module):
98 | """ base for the Discriminator (D) and latent recognition network (Q) """
99 | def __init__(self):
100 | super().__init__()
101 | # base network shared between discriminator D and recognition network Q
102 | self.base_net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1), # out (B, 64, 14, 14)
103 | nn.LeakyReLU(0.1, True),
104 | nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False), # out (B, 128, 7, 7)
105 | nn.BatchNorm2d(128),
106 | nn.LeakyReLU(0.1, True),
107 | Flatten(),
108 | nn.Linear(128*7*7, 1024, bias=False),
109 | nn.BatchNorm1d(1024),
110 | nn.LeakyReLU(0.1, True))
111 |
112 | # discriminator -- real vs fake binary output
113 | self.d = nn.Linear(1024, 1)
114 |
115 | def forward(self, x):
116 | x = self.base_net(x).squeeze()
117 | logits_real = self.d(x)
118 | # return feature representation and real vs fake prob
119 | return x, dist.Bernoulli(logits=logits_real)
120 |
121 |
122 | class Q(nn.Module):
123 | """ Latent space recognition network; shares base network of the discriminator """
124 | def __init__(self, cat_dim, cont_dim, fix_cont_std=True):
125 | super().__init__()
126 | self.cat_dim = cat_dim
127 | self.cont_dim = cont_dim
128 | self.fix_cont_std = fix_cont_std
129 |
130 | # recognition network for latent vars ie encoder, shared between the factors of q
131 | self.encoder = nn.Sequential(nn.Linear(1024, 128, bias=False),
132 | nn.BatchNorm1d(128),
133 | nn.LeakyReLU(0.1, True))
134 |
135 | # the factors of q -- 1 categorical and 2 continuous variables
136 | self.q = nn.Linear(128, cat_dim + 2 * cont_dim)
137 |
138 | def forward(self, x):
139 | # latent space encoding
140 | z = self.encoder(x)
141 |
142 | logits_cat, cont_mu, cont_var = torch.split(self.q(z), [self.cat_dim, self.cont_dim, self.cont_dim], dim=-1)
143 |
144 | if self.fix_cont_std:
145 | cont_sigma = torch.ones_like(cont_mu)
146 | else:
147 | cont_sigma = F.softplus(cont_var)
148 |
149 | q_cat = dist.Categorical(logits=logits_cat)
150 | q_cont = dist.Normal(loc=cont_mu, scale=cont_sigma)
151 |
152 | return q_cat, q_cont
153 |
154 |
155 | class Generator(nn.Module):
156 | def __init__(self):
157 | super().__init__()
158 | self.net = nn.Sequential(nn.Linear(74, 1024, bias=False),
159 | nn.BatchNorm1d(1024),
160 | nn.ReLU(True),
161 | nn.Linear(1024, 7*7*128),
162 | nn.BatchNorm1d(7*7*128),
163 | nn.ReLU(True),
164 | Unflatten(-1, 128, 7, 7),
165 | nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
166 | nn.BatchNorm2d(64),
167 | nn.ReLU(True),
168 | nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
169 | nn.Sigmoid())
170 |
171 | def forward(self, x):
172 | return self.net(x)
173 |
174 |
175 | def initialize_weights(m):
176 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
177 | m.weight.data.normal_(mean=1., std=0.02)
178 | m.bias.data.fill_(0.)
179 | else:
180 | try:
181 | m.weight.data.normal_(std=0.02)
182 | except AttributeError: # skip activation layers
183 | pass
184 |
185 |
186 | # --------------------
187 | # Train
188 | # --------------------
189 |
190 | def sample_z(args):
191 | # generate samples from the prior
192 | z_cat = OneHotCategorical(logits=torch.zeros(args.batch_size, args.cat_dim)).sample()
193 | z_noise = dist.Uniform(-1, 1).sample(torch.Size((args.batch_size, args.noise_dim)))
194 | z_cont = dist.Uniform(-1, 1).sample(torch.Size((args.batch_size, args.cont_dim)))
195 |
196 | # concatenate the incompressible noise, discrete latents, and continuous latents
197 | z = torch.cat([z_noise, z_cat, z_cont], dim=1)
198 |
199 | return z.to(args.device), z_cat.to(args.device), z_noise.to(args.device), z_cont.to(args.device)
200 |
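# The latent vector concatenates [incompressible noise | one-hot categorical | continuous], i.e.
# 62 + 10 + 2 = 74 dims with the defaults above, matching the Generator's first Linear(74, 1024) layer.
# A minimal sketch of that bookkeeping -- a hypothetical helper; `parsed_args` stands in for the parsed args.
def _check_latent_layout(parsed_args):
    z, z_cat, z_noise, z_cont = sample_z(parsed_args)
    assert z.shape == (parsed_args.batch_size, parsed_args.noise_dim + parsed_args.cat_dim + parsed_args.cont_dim)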
201 |
202 | def info_loss_fn(cat_fake, cont_fake, z_cat, z_cont, args):
203 | log_prob_cat = cat_fake.log_prob(z_cat.nonzero()[:,1]).mean() # equivalent to pytorch cross_entropy loss fn
204 | log_prob_cont = cont_fake.log_prob(z_cont).sum(1).mean()
205 |
206 | info_loss = - args.info_reg_coeff * (log_prob_cat + log_prob_cont)
207 | return log_prob_cat, log_prob_cont, info_loss
208 |
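# The categorical term above is a Categorical log-prob of the class index recovered from the one-hot
# sample. A hedged equivalence check against F.cross_entropy -- an illustrative helper only, never called.
def _check_categorical_log_prob():
    logits = torch.randn(4, 10)
    one_hot = OneHotCategorical(logits=torch.zeros(4, 10)).sample()
    idx = one_hot.nonzero()[:, 1]  # one-hot rows -> class indices
    log_p = dist.Categorical(logits=logits).log_prob(idx)
    assert torch.allclose(log_p, -F.cross_entropy(logits, idx, reduction='none'), atol=1e-6)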
209 |
210 |
211 | def train_epoch(D, Q, G, dataloader, d_optimizer, g_optimizer, epoch, writer, args):
212 |
213 | fixed_z, _, _, _ = sample_z(args)
214 |
215 | real_labels = torch.ones(args.batch_size, 1, device=args.device).requires_grad_(False)
216 | fake_labels = torch.zeros(args.batch_size, 1, device=args.device).requires_grad_(False)
217 |
218 | with tqdm(total=len(dataloader), desc='epoch {} of {}'.format(epoch+1, args.n_epochs)) as pbar:
219 | time.sleep(0.1)
220 |
221 | for i, (x, _) in enumerate(dataloader):
222 | D.train()
223 | G.train()
224 |
225 | x = x.to(args.device)
226 | # x = 2*x - 0.5
227 |
228 |
229 | # train Generator
230 | z, z_cat, z_noise, z_cont = sample_z(args)
231 |
232 | generated = G(z)
233 | x_pre_q, d_fake = D(generated)
234 | q_cat, q_cont = Q(x_pre_q)
235 |
236 | gan_g_loss = - d_fake.log_prob(real_labels).mean() # equivalent to pytorch binary_cross_entropy_with_logits loss fn
237 | log_prob_cat, log_prob_cont, info_loss = info_loss_fn(q_cat, q_cont, z_cat, z_cont, args)
238 |
239 | g_loss = gan_g_loss + info_loss
240 |
241 | g_optimizer.zero_grad()
242 | g_loss.backward()
243 | g_optimizer.step()
244 |
245 |
246 | # train Discriminator
247 | _, d_real = D(x)
248 | x_pre_q, d_fake = D(generated.detach())
249 | q_cat, q_cont = Q(x_pre_q)
250 |
251 | gan_d_loss = - d_real.log_prob(real_labels).mean() - d_fake.log_prob(fake_labels).mean()
252 | log_prob_cat, log_prob_cont, info_loss = info_loss_fn(q_cat, q_cont, z_cat, z_cont, args)
253 |
254 | d_loss = gan_d_loss + info_loss
255 |
256 | d_optimizer.zero_grad()
257 | d_loss.backward()
258 | d_optimizer.step()
259 |
260 |
261 | # update tracking
262 | pbar.set_postfix(log_prob_cat='{:.3f}'.format(log_prob_cat.item()),
263 | log_prob_cont='{:.3f}'.format(log_prob_cont.item()),
264 | d_loss='{:.3f}'.format(gan_d_loss.item()),
265 | g_loss='{:.3f}'.format(gan_g_loss.item()),
266 | i_loss='{:.3f}'.format(info_loss.item()))
267 | pbar.update()
268 |
269 | if i % args.log_interval == 0:
270 | step = epoch
271 | writer.add_scalar('gan_d_loss', gan_d_loss.item(), step)
272 | writer.add_scalar('gan_g_loss', gan_g_loss.item(), step)
273 | writer.add_scalar('info_loss', info_loss.item(), step)
274 | writer.add_scalar('log_prob_cat', log_prob_cat.item(), step)
275 | writer.add_scalar('log_prob_cont', log_prob_cont.item(), step)
276 | # sample images
277 | with torch.no_grad():
278 | G.eval()
279 | fake_images = G(fixed_z)
280 | writer.add_image('generated', make_grid(fake_images[:10].cpu(), nrow=10, normalize=True, padding=1), step)
281 | save_image(fake_images[:10].cpu(),
282 | os.path.join(args.output_dir, 'generated_sample_epoch_{}.png'.format(epoch)),
283 | nrow=10)
284 |
285 |
286 | def train(D, Q, G, dataloader, d_optimizer, g_optimizer, writer, args):
287 |
288 | print('Starting training with args:\n', args)
289 |
290 | start_epoch = 0
291 |
292 | if args.restore_file:
293 | print('Restoring parameters from {}'.format(args.restore_file))
294 | start_epoch = utils.load_checkpoint(args.restore_file, [D, Q, G], [d_optimizer, g_optimizer], map_location=args.device.type)
295 | args.n_epochs += start_epoch - 1
296 | print('Resuming training from epoch {}'.format(start_epoch))
297 |
298 | for epoch in range(start_epoch, args.n_epochs):
299 | train_epoch(D, Q, G, dataloader, d_optimizer, g_optimizer, epoch, writer, args)
300 |
301 | # snapshot at end of epoch
302 | if args.save_model:
303 | utils.save_checkpoint({'epoch': epoch + 1,
304 | 'model_state_dicts': [D.state_dict(), Q.state_dict(), G.state_dict()],
305 | 'optimizer_state_dicts': [d_optimizer.state_dict(), g_optimizer.state_dict()]},
306 | checkpoint=args.output_dir,
307 | quiet=True)
308 |
309 | # --------------------
310 | # Evaluate
311 | # --------------------
312 |
313 | @torch.no_grad()
314 | def evaluate_on_grid(G, writer, args):
315 | # sample noise randomly
316 | z_noise = torch.empty(100, args.noise_dim).uniform_(-1,1)
317 | # order the categorical latent
318 | z_cat = torch.eye(10).repeat(10,1)
319 | # order the first continuous latent
320 | c = torch.linspace(-2, 2, 10).view(-1,1).repeat(1,10).reshape(-1,1)
321 | z_cont = torch.cat([c, torch.zeros_like(c)], dim=1).reshape(100, 2)
322 |
323 | # combine into z and pass through generator
324 | z = torch.cat([z_noise, z_cat, z_cont], dim=1).to(args.device)
325 | fake_images = G(z)
326 | writer.add_image('c1 cont generated', make_grid(fake_images.cpu(), nrow=10, normalize=True, padding=1))
327 | save_image(fake_images.cpu(),
328 | os.path.join(args.output_dir, 'latent_var_grid_sample_c1.png'),
329 | nrow=10)
330 |
331 | # order second continuous latent; combine into z and pass through generator
332 | z_cont = z_cont.flip(1)
333 | z = torch.cat([z_noise, z_cat, z_cont], dim=1).to(args.device)
334 | fake_images = G(z)
335 | writer.add_image('c2 cont generated', make_grid(fake_images.cpu(), nrow=10, normalize=True, padding=1))
336 | save_image(fake_images.cpu(),
337 | os.path.join(args.output_dir, 'latent_var_grid_sample_c2.png'),
338 | nrow=10)
339 |
340 |
341 | # --------------------
342 | # Run
343 | # --------------------
344 |
345 | if __name__ == '__main__':
346 | args = parser.parse_args()
347 |
348 | if not os.path.isdir(args.output_dir):
349 | os.makedirs(args.output_dir)
350 |
351 | writer = utils.set_writer(args.output_dir, '_train')
352 |
353 | args.device = torch.device('cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda is not None else 'cpu')
354 |
355 | # set seed
356 | torch.manual_seed(11122018)
357 | if args.device.type == 'cuda': torch.cuda.manual_seed(11122018)
358 |
359 | # models
360 | D = Discriminator().to(args.device)
361 | Q = Q(args.cat_dim, args.cont_dim).to(args.device)
362 | G = Generator().to(args.device)
363 | D.apply(initialize_weights)
364 | Q.apply(initialize_weights)
365 | G.apply(initialize_weights)
366 |
367 | # optimizers
368 | g_optimizer = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.999))
369 | d_optimizer = torch.optim.Adam([{'params': D.parameters()},
370 | {'params': Q.parameters()}], lr=args.d_lr, betas=(0.5, 0.999))
371 |
372 | # eval
373 | if args.evaluate_on_grid:
374 | print('Restoring parameters from {}'.format(args.restore_file))
375 | _ = utils.load_checkpoint(args.restore_file, [D, Q, G], [d_optimizer, g_optimizer])
376 | evaluate_on_grid(G, writer, args)
377 | # train
378 | else:
379 | dataloader = fetch_dataloader(args)
380 | train(D, Q, G, dataloader, d_optimizer, g_optimizer, writer, args)
381 | evaluate_on_grid(G, writer, args)
382 |
383 | writer.close()
384 |
--------------------------------------------------------------------------------
/optim.py:
--------------------------------------------------------------------------------
1 | """ Wrapper of optimizers in torch.optim for computation of exponential moving average of parameters"""
2 |
3 | import torch
4 |
5 | def build_ema_optimizer(optimizer_cls):
6 | class Optimizer(optimizer_cls):
7 | def __init__(self, *args, polyak=0.0, **kwargs):
8 | if not 0.0 <= polyak <= 1.0:
9 | raise ValueError("Invalid polyak decay rate: {}".format(polyak))
10 | super().__init__(*args, **kwargs)
11 | self.defaults['polyak'] = polyak
12 | self.ema = False
13 |
14 | def step(self, closure=None):
15 | super().step(closure)
16 |
17 | # update exponential moving average after gradient update to parameters
18 | for group in self.param_groups:
19 | for p in group['params']:
20 | state = self.state[p]
21 |
22 | # state initialization
23 | if 'ema' not in state:
24 | state['ema'] = torch.zeros_like(p.data)
25 |
26 | # ema update
27 | state['ema'] -= (1 - self.defaults['polyak']) * (state['ema'] - p.data)
28 |
29 | def use_ema(self, use_ema_for_params=True):
30 | """ substitute exponential moving average values into parameter values """
31 | if self.ema ^ use_ema_for_params: # logical XOR; swap only when different;
32 | try:
33 | print('Swapping EMA and parameters values. Now using: ' + ('EMA' if use_ema_for_params else 'param values'))
34 | for group in self.param_groups:
35 | for p in group['params']:
36 | data = p.data
37 | state = self.state[p]
38 | p.data = state['ema']
39 | state['ema'] = data
40 | self.ema = use_ema_for_params
41 | except KeyError:
42 | print('Optimizer not initialized. No EMA values to swap to. Keeping parameter values.')
43 |
44 |
45 | def __repr__(self):
46 | s = super().__repr__()
47 | return self.__class__.__mro__[1].__name__ + ' (\npolyak: {}\n'.format(self.defaults['polyak']) + s.partition('\n')[2]
48 |
49 | return Optimizer
50 |
51 | Adam = build_ema_optimizer(torch.optim.Adam)
52 | RMSprop = build_ema_optimizer(torch.optim.RMSprop)
53 |
54 |
55 | if __name__ == '__main__':
56 | import copy
57 | torch.manual_seed(0)
58 | x = torch.randn(2,2)
59 | y = torch.rand(2,2)
60 | polyak = 0.9
61 | _m = torch.nn.Linear(2,2)
62 | for optim in [Adam, RMSprop]:
63 | m = copy.deepcopy(_m)
64 | o = optim(m.parameters(), lr=0.1, polyak=polyak)
65 | print('Testing: ', optim.__name__)
66 | print(o)
67 | print('init loss {:.3f}'.format(torch.mean((m(x) - y)**2).item()))
68 | p = torch.zeros_like(m.weight)
69 | for i in range(5):
70 | loss = torch.mean((m(x) - y)**2)
71 | print('step {}: loss {:.3f}'.format(i, loss.item()))
72 | o.zero_grad()
73 | loss.backward()
74 | o.step()
75 | # manual compute ema
76 | p -= (1 - polyak) * (p - m.weight.data)
77 | print('loss: {:.3f}'.format(torch.mean((m(x) - y)**2).item()))
78 | print('swapping ema values for params.')
79 | o.use_ema(True)
80 | assert torch.allclose(p, m.weight)
81 | print('loss: {:.3f}'.format(torch.mean((m(x) - y)**2).item()))
82 | print('swapping params for ema values.')
83 | o.use_ema(False)
84 | print('loss: {:.3f}'.format(torch.mean((m(x) - y)**2).item()))
85 | print()
86 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Collection of generative methods in pytorch.
6 |
7 | # Implemented models
8 | * [Generating Diverse High-Fidelity Images with VQ-VAE-2](https://arxiv.org/abs/1906.00446) / [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937)
9 | * [Attend, Infer, Repeat: Fast Scene Understanding with Generative Models](https://arxiv.org/abs/1603.08575v3)
10 | * [DRAW: A Recurrent Neural Network For Image Generation](https://arxiv.org/abs/1502.04623.pdf)
11 | * [Semi-Supervised Learning with Deep Generative Models](https://arxiv.org/abs/1406.5298)
12 | * [InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets](https://arxiv.org/abs/1606.03657)
13 | * [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434) (DCGAN)
14 | * [Auto-encoding Variational Bayes](https://arxiv.org/abs/1312.6114)
15 |
16 | The models are implemented for MNIST data; other datasets are a todo.
17 |
18 | ## Dependencies
19 | * python 3.6
20 | * pytorch 0.4.1+
21 | * numpy
22 | * matplotlib
23 | * tensorboardx
24 | * tqdm
25 |
26 | ###### Some of the models further require
27 | * observations
28 | * imageio
29 |
30 |
31 | ## VQ-VAE2
32 |
33 | Implementation of Generating Diverse High-Fidelity Images with VQ-VAE-2 (https://arxiv.org/abs/1906.00446), based on the Vector Quantised VAE of Neural Discrete Representation Learning (https://arxiv.org/abs/1711.00937), with a PixelCNN prior on the level 1 (bottom) discrete latent variables per Conditional Image Generation with PixelCNN Decoders (https://arxiv.org/abs/1606.05328) and a PixelSNAIL prior on the level 2 (top) discrete latent variables per PixelSNAIL: An Improved Autoregressive Generative Model (https://arxiv.org/abs/1712.09763).
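
A loose sketch of the resulting generation path (the `*_prior.sample` calls are hypothetical stand-ins for the ancestral sampling implemented in vqvae_prior.py `--generate`; `vqvae.embed` and `vqvae.decode` are the methods of the VQVAE2 model in vqvae.py):

```
# hypothetical ancestral sampling path, conditioned on a class label
top_codes    = top_prior.sample(cond=class_label)                   # PixelSNAIL over the top-level code map
bottom_codes = bottom_prior.sample(cond=(top_codes, class_label))   # PixelCNN conditioned on the top codes
z_q          = vqvae.embed((bottom_codes, top_codes))               # look up codebook vectors for both levels
image        = vqvae.decode(None, z_q)                              # VQ-VAE-2 decoder
```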
34 |
35 | #### Results
36 |
37 | Model reconstructions on [CheXpert](https://stanfordmlgroup.github.io/competitions/chexpert/) Chest X-Ray Dataset -- CheXpert is a large public dataset for chest radiograph interpretation, consisting of 224,316 chest radiographs of 65,240 patients at 320x320 pixels for the small version of the dataset. Reconstructions and samples below are for 128x128 images using a codebook of size 8 (3 bits), which, for single-channel 8-bit grayscale images from CheXpert, gives compression ratios similar to those listed in the paper for 8-bit RGB images.
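
For a rough sense of scale (a back-of-the-envelope sketch, assuming the encoder's default downsampling in vqvae.py: 4x to the bottom latents and a further 2x to the top latents):

```
image_bits  = 128 * 128 * 8               # single-channel 8-bit 128x128 image
latent_bits = 32 * 32 * 3 + 16 * 16 * 3   # bottom (32x32) + top (16x16) code maps at 3 bits per code
print(image_bits / latent_bits)           # ~34x compression
```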
38 |
39 | | Original | Bottom reconstruction (using top encoding) | Top reconstruction (using zeroed out bottom encoding) |
40 | | --- | --- | --- |
41 | |  |  |  |
42 |
43 |
44 | ##### Model samples from priors
45 | Both the top and bottom prior models are fairly heavy to train; the samples below come from priors trained for only 84k (top) and 140k (bottom) steps, using smaller model sizes than reported in the paper. The samples are class conditional along the rows for the classes (atelectasis, cardiomegaly, consolidation, edema, pleural effusion, no finding) -- much room for improvement with larger models and a higher computational budget.
46 |
47 | Model parameters:
48 | * bottom prior: n_channels 128, n_res_layers 20, n_cond_stack_layers 10, drop_rate 0.1, batch size 16, lr 5e-5
49 | * top prior: n_channels 128, n_res_layers 5, n_out_stack_layers 10, drop_rate 0.1, batch size 128, lr 5e-5; (attention params: layers 4, heads 8, dq 16, dk 16, dv 128)
50 |
51 | 
52 |
53 | #### Usage
54 |
55 | To train and evaluate/reconstruct from the VAE model with hyperparameters of the paper:
56 | ```
57 | python vqvae.py --train
58 | --n_embeddings [size of the latent space]
59 | --n_epochs [number of epochs to train]
60 | --ema [flag to use exponential moving average training for the embeddings]
61 | --cuda [cuda device to run on]
62 |
63 | python vqvae.py --evaluate
64 | --restore_dir [path to model directory with config.json and saved checkpoint]
65 | --n_samples [number of examples from the validation set to reconstruct]
66 | --cuda [cuda device to run on]
67 | ```
68 |
69 | To train the top and bottom priors on the latent codes using 4 GPUs and Pytorch DistributedDataParallels:
70 | * the latent codes are extracted for the full dataset and saved as a pytorch dataset object, which is then loaded into memory for training (a rough sketch of this extraction step follows the commands below)
71 | * hyperparameters not shown as options below are at the defaults given by the paper (e.g. kernel size, attention parameters)
72 |
73 | ```
74 | python -m torch.distributed.launch --nproc_per_node 4 --use_env \
75 | vqvae_prior.py --vqvae_dir [path to vae model used for encoding the dataset and decoding samples]
76 | --train
77 | --distributed [flag to use DistributedDataParallels]
78 | --n_epochs 20
79 | --batch_size 128
80 | --lr 0.00005
81 | --which_prior top [which prior to train]
82 | --n_cond_classes 5 [number of classes to condition on]
83 | --n_channels 128 [convolutional channels throughout the architecture]
84 | --n_res_layers 5 [number of residual layers]
85 | --n_out_stack_layers 10 [output convolutional stack (used only by top prior)]
86 | --n_cond_stack_layers 0 [input conditional stack (used only by bottom prior)]
87 | --drop_rate 0.1 [dropout rate used in the residual layers]
88 |
89 | python -m torch.distributed.launch --nproc_per_node 4 --use_env \
90 | vqvae_prior.py --vqvae_dir [path_to_vae_directory]
91 | --train
92 | --distributed
93 | --n_epochs 20
94 | --batch_size 16
95 | --lr 0.00005
96 | --which_prior bottom
97 | --n_cond_classes 5
98 | --n_channels 128
99 | --n_res_layers 20
100 | --n_out_stack_layers 0
101 | --n_cond_stack_layers 10
102 | --drop_rate 0.1
103 | ```
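
As noted in the first bullet above, the priors are not trained on images directly; the trained VQ-VAE encoder is run over the full dataset and the discrete codes are cached. A minimal sketch of that extraction step (illustrative only -- names such as `model`, `dataloader` and `device` are assumed, and the actual logic lives in vqvae_prior.py):

```
import torch
from torch.utils.data import TensorDataset

bottom_codes, top_codes, labels = [], [], []
with torch.no_grad():
    for x, y in dataloader:                                # dataloader over the image dataset
        z_e = model.encode(x.to(device))                   # continuous encoder outputs (bottom, top)
        (idx_bottom, idx_top), _ = model.quantize(z_e)     # discrete code maps, each (B,1,H,W)
        bottom_codes.append(idx_bottom.cpu())
        top_codes.append(idx_top.cpu())
        labels.append(y)

codes = TensorDataset(torch.cat(bottom_codes), torch.cat(top_codes), torch.cat(labels))
torch.save(codes, 'latent_codes.pt')                       # later loaded into memory to train the priors
```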
104 |
105 | To generate data from a trained model and priors:
106 | ```
107 |
108 | python -m torch.distributed.launch --nproc_per_node 4 --use_env \
109 | vqvae_prior.py --vqvae_dir [path_to_vae_directory]
110 | --restore_dir [path_to_bottom_prior_directory, path_to_top_prior_directory]
111 | --generate
112 | --distributed
113 | --n_samples [number of samples to generate (per gpu)]
114 | ```
115 |
116 | #### Useful resources
117 | * Official tensorflow implementation of the VQ layer and VAE model in Sonnet (https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/nets/vqvae.py and https://github.com/deepmind/sonnet/blob/master/sonnet/examples/vqvae_example.ipynb)
118 |
119 | ## AIR
120 |
121 | Reimplementation of the Attend, Infer, Repeat (AIR) architecture.
122 | https://arxiv.org/abs/1603.08575v3
123 |
124 | #### Results
125 | Model reconstructed data (top row is sample of original images, bottom row is AIR reconstruction; red attention window corresponds to first time step, green to second):
126 |
127 | 
128 |
129 | ELBO and object count accuracy after 300 epochs of training using RMSprop with the default hyperparameters discussed in the paper and linear annealing of the z_pres probability. Variance coming from the discrete z_pres is alleviated using NVIL ([Mnih & Gregor](https://arxiv.org/abs/1402.0030)) but can still be seen in the count accuracy over the first 50k training iterations.
130 |
131 |
132 | | Variational bound | Count accuracy |
133 | | --- | --- |
134 | |  | 
135 |
136 | #### Usage
137 | To train a model with hyperparameters of the paper:
138 | ```
139 | python air.py --train \
140 |               --cuda=[# of cuda device to run on]
141 | ```
142 |
143 | To evaluate model ELBO:
144 | ```
145 | python air.py --evaluate \
146 |               --restore_file=[path to .pt checkpoint] \
147 |               --cuda=[# of cuda device to run on]
148 | ```
149 |
150 | To generate data from a trained model:
151 | ```
152 | python air.py --generate \
153 |               --restore_file=[path to .pt checkpoint]
154 | ```
155 |
156 | #### Useful resources
157 | * tensorflow implementation https://github.com/akosiorek/attend_infer_repeat and by the same author Sequential AIR (a state-space model on top of AIR) (https://github.com/akosiorek/sqair/)
158 | * pyro implementation and walkthrough http://pyro.ai/examples/air.html
159 |
160 | ## DRAW
161 | Reimplementation of the Deep Recurrent Attentive Writer (DRAW) network architecture. https://arxiv.org/abs/1502.04623
162 |
163 | #### Results
164 | Model generated data:
165 |
166 | 
167 |
168 | Results were obtained by training with the parameters presented in the paper (except using 32 time steps) for 50 epochs.
169 |
170 | Visualizing the specific filterbank functions for read and write attention (cf Figure 3 & 4 in paper):
171 |
172 | | Extracted patches with grid filters | Applying transposed filters to reconstruct extracted image patch |
173 | | --- | --- |
174 | |  | 
175 |
176 | #### Usage
177 | To train a model with read and write attention (window sizes 2 and 5):
178 | ```
179 | python draw.py --train \
180 |                --use_read_attn \
181 |                --read_size=2 \
182 |                --use_write_attn \
183 |                --write_size=5 \
184 |                --[add'l options: e.g. n_epoch, z_size, lstm_size] \
185 |                --cuda=[# of cuda device to run on]
186 | ```
187 |
188 | To evaluate model ELBO:
189 | ```
190 | python draw.py --evaluate \
191 |                --restore_file=[path to .pt checkpoint] \
192 |                --[model parameters: read_size, write_size, lstm_size, z_size]
193 | ```
194 |
195 | To generate data from a trained model:
196 | ```
197 | python draw.py --generate \
198 |                --restore_file=[path to .pt checkpoint] \
199 |                --[model parameters: read_size, write_size, lstm_size, z_size]
200 | ```
201 |
202 | #### Useful resources
203 | * https://github.com/jbornschein/draw
204 | * https://github.com/ericjang/draw
205 |
206 |
207 | ## Semi-supervised Learning with Deep Generative Models
208 | https://arxiv.org/abs/1406.5298
209 |
210 | Reimplementation of M2 model on MNIST.
211 |
212 | #### Results
213 | Visualization of handwriting styles learned by the model (cf Figure 1 in paper). Column 1 shows a column of test images followed by model-generated analogies. Columns 2 and 3 show model-generated styles for a fixed label while linearly varying each component of a 2-d latent variable.
214 |
215 | | MNIST analogies | Varying 2-d latent z (z1) on number 2 | Varying 2-d latent z (z2) on number 4 |
216 | | --- | --- | --- |
217 | |  |  | 
218 |
219 | #### Usage
220 | To train a model:
221 | ```
222 | python ssvae.py --train \
223 |                 --n_labeled=[100 | 300 | 1000 | 3000] \
224 |                 --[add'l options: e.g. n_epochs, z_dim, hidden_dim] \
225 |                 --cuda=[# of cuda device to run on]
226 | ```
227 |
228 | To evaluate model accuracy:
229 | ```
230 | python ssvae.py --evaluate \
231 |                 --restore_file=[path to .pt checkpoint]
232 | ```
233 |
234 | To generate data from a trained model:
235 | ```
236 | python ssvae.py --generate \
237 |                 --restore_file=[path to .pt checkpoint]
238 | ```
239 |
240 | #### Useful resource
241 | * https://github.com/dpkingma/nips14-ssl
242 |
243 |
244 | ## InfoGAN
245 |
246 | Reimplementation of InfoGAN. https://arxiv.org/abs/1606.03657
247 | This closely follows the Tensorflow implementation by [Depth First Learning](http://www.depthfirstlearning.com/2018/InfoGAN) using tf.distribution, which makes the model quite intuitive.
248 |
249 | #### Results
250 |
251 | Visualizing model-generated data varying each component of a 2-d continuous latent variable:
252 |
253 | | Varying 2-d latent z (z1)| Varying 2-d latent z (z2) |
254 | | --- | --- |
255 | |  | 
256 |
257 | #### Usage
258 | To train a model:
259 | ```
260 | python infogan.py --n_epochs=[# epochs] \
261 |                   --cuda=[# of cuda device to run on] \
262 |                   --[add'l options: e.g. noise_dim, cat_dim, cont_dim]
263 | ```
264 |
265 | To evaluate model and visualize latents:
266 | ```
267 | python infogan.py --evaluate_on_grid \
268 |                   --restore_file=[path to .pt checkpoint]
269 | ```
270 |
271 | #### Useful resources
272 | * http://www.depthfirstlearning.com/2018/InfoGAN
273 |
274 |
275 | ## DCGAN
276 |
277 | Reimplementation of DCGAN. https://arxiv.org/abs/1511.06434
278 |
279 | #### Results
280 | Model generated data:
281 |
282 | 
283 |
284 | #### Usage
285 | To train a model:
286 | ```
287 | python dcgan.py --n_epochs=[# epochs] \
288 |                 --cuda=[# of cuda device to run on] \
289 |                 --[add'l options]
290 | ```
291 |
292 | To evaluate model and visualize latents:
293 | ```
294 | python dcgan.py --evaluate_on_grid \
295 |                 --restore_file=[path to .pt checkpoint]
296 | ```
297 |
298 | #### Useful resources
299 | * pytorch code examples https://github.com/pytorch/examples/
300 |
301 |
302 | ## Auto-encoding Variational Bayes
303 | Reimplementation of https://arxiv.org/abs/1312.6114
304 |
305 | #### Results
306 |
307 | Visualizing reconstruction (after training for 25 epochs):
308 |
309 | | Real samples (left) and model reconstruction (right) |
310 | | --- |
311 | |  |
312 |
313 | Visualizing model-generated data and TSNE embedding in latent space:
314 |
315 | | Model-generated data using Normal(0,1) prior | TSNE embedding in latent space |
316 | | --- | --- |
317 | |  |  |
318 |
319 |
320 | #### Usage
321 | To train and evaluate a model on MNIST:
322 | ```
323 | python basic_vae.py --n_epochs=[# epochs] mnist
324 | ```
325 |
326 | #### Useful resources
327 | * Implementation in Pyro and quick tutorial http://pyro.ai/examples/vae.html
328 |
--------------------------------------------------------------------------------
/ssvae.py:
--------------------------------------------------------------------------------
1 | """
2 | Semi-supervised Learning with Deep Generative Models
3 | https://arxiv.org/pdf/1406.5298.pdf
4 | """
5 |
6 | import os
7 | import argparse
8 | from tqdm import tqdm
9 | import pprint
10 | import copy
11 |
12 | import numpy as np
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import torch.distributions as D
18 | import torchvision.transforms as T
19 | from torch.utils.data import DataLoader
20 | from torchvision.datasets import MNIST
21 | from torchvision.utils import save_image, make_grid
22 |
23 | parser = argparse.ArgumentParser()
24 |
25 | # actions
26 | parser.add_argument('--train', action='store_true', help='Train a new or restored model.')
27 | parser.add_argument('--evaluate', action='store_true', help='Evaluate a model.')
28 | parser.add_argument('--generate', action='store_true', help='Generate samples from a model.')
29 | parser.add_argument('--vis_styles', action='store_true', help='Visualize styles manifold.')
30 | parser.add_argument('--cuda', type=int, help='Which cuda device to use')
31 | parser.add_argument('--seed', type=int, default=1, help='Random seed.')
32 |
33 | # file paths
34 | parser.add_argument('--restore_file', type=str, help='Path to model to restore.')
35 | parser.add_argument('--data_dir', default='./data/', help='Location of dataset.')
36 | parser.add_argument('--output_dir', default='./results/{}'.format(os.path.splitext(__file__)[0]))
37 | parser.add_argument('--results_file', default='results.txt', help='Filename where to store settings and test results.')
38 |
39 | # model parameters
40 | parser.add_argument('--image_dims', type=tuple, default=(1,28,28), help='Dimensions of a single datapoint (e.g. (1,28,28) for MNIST).')
41 | parser.add_argument('--z_dim', type=int, default=50, help='Size of the latent representation.')
42 | parser.add_argument('--y_dim', type=int, default=10, help='Size of the labels / output.')
43 | parser.add_argument('--hidden_dim', type=int, default=500, help='Size of the hidden layer.')
44 |
45 | # training params
46 | parser.add_argument('--n_labeled', type=int, default=3000, help='Number of labeled training examples in the dataset')
47 | parser.add_argument('--batch_size', type=int, default=100)
48 | parser.add_argument('--n_epochs', type=int, default=1, help='Number of epochs to train.')
49 | parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
50 | parser.add_argument('--alpha', type=float, default=0.1, help='Classifier loss multiplier controlling generative vs. discriminative learning.')
51 |
52 |
53 | # --------------------
54 | # Data
55 | # --------------------
56 |
57 | # create semi-supervised datasets of labeled and unlabeled data with equal number of labels from each class
58 | def create_semisupervised_datasets(dataset, n_labeled):
59 | # note this is only relevant for training the model
60 |     assert dataset.train == True, 'Dataset must be the training set; ensure dataset.train = True.'
61 |
62 |     # compile new x and y and replace dataset.train_data and train_labels with the labeled / unlabeled splits
63 | x = dataset.train_data
64 | y = dataset.train_labels
65 | n_x = x.shape[0]
66 | n_classes = len(torch.unique(y))
67 |
68 |     assert n_labeled % n_classes == 0, 'n_labeled not divisible by n_classes; cannot ensure class balance.'
69 | n_labeled_per_class = n_labeled // n_classes
70 |
71 | x_labeled = [0] * n_classes
72 | x_unlabeled = [0] * n_classes
73 | y_labeled = [0] * n_classes
74 | y_unlabeled = [0] * n_classes
75 |
76 | for i in range(n_classes):
77 | idxs = (y == i).nonzero().data.numpy()
78 | np.random.shuffle(idxs)
79 |
80 | x_labeled[i] = x[idxs][:n_labeled_per_class]
81 | y_labeled[i] = y[idxs][:n_labeled_per_class]
82 | x_unlabeled[i] = x[idxs][n_labeled_per_class:]
83 | y_unlabeled[i] = y[idxs][n_labeled_per_class:]
84 |
85 | # construct new labeled and unlabeled datasets
86 | labeled_dataset = copy.deepcopy(dataset)
87 | labeled_dataset.train_data = torch.cat(x_labeled, dim=0).squeeze()
88 | labeled_dataset.train_labels = torch.cat(y_labeled, dim=0)
89 |
90 | unlabeled_dataset = copy.deepcopy(dataset)
91 | unlabeled_dataset.train_data = torch.cat(x_unlabeled, dim=0).squeeze()
92 | unlabeled_dataset.train_labels = torch.cat(y_unlabeled, dim=0)
93 |
94 | del dataset
95 |
96 | return labeled_dataset, unlabeled_dataset
97 |
98 |
99 | def fetch_dataloaders(args):
100 |     assert args.n_labeled is not None, 'Must provide n_labeled number to split dataset.'
101 |
102 | transforms = T.Compose([T.ToTensor()])
103 |     kwargs = {'num_workers': 1, 'pin_memory': True} if args.device.type == 'cuda' else {}
104 |
105 | def get_dataset(train):
106 | return MNIST(root=args.data_dir, train=train, transform=transforms)
107 |
108 | def get_dl(dataset):
109 | return DataLoader(dataset, batch_size=args.batch_size, shuffle=dataset.train, drop_last=True, **kwargs)
110 |
111 | test_dataset = get_dataset(train=False)
112 | train_dataset = get_dataset(train=True)
113 | labeled_dataset, unlabeled_dataset = create_semisupervised_datasets(train_dataset, args.n_labeled)
114 |
115 | return get_dl(labeled_dataset), get_dl(unlabeled_dataset), get_dl(test_dataset)
116 |
117 |
118 | def one_hot(x, label_size):
119 | out = torch.zeros(len(x), label_size).to(x.device)
120 | out[torch.arange(len(x)), x.squeeze()] = 1
121 | return out
122 |
123 | # --------------------
124 | # Model
125 | # --------------------
126 |
127 | class SSVAE(nn.Module):
128 | """
129 | Data model (SSL paper eq 2):
130 | p(y) = Cat(y|pi)
131 | p(z) = Normal(z|0,1)
132 | p(x|y,z) = f(x; z,y,theta)
133 |
134 | Recognition model / approximate posterior q_phi (SSL paper eq 4):
135 | q(y|x) = Cat(y|pi_phi(x))
136 | q(z|x,y) = Normal(z|mu_phi(x,y), diag(sigma2_phi(x)))
137 |
138 |
139 | """
140 | def __init__(self, args):
141 | super().__init__()
142 | C, H, W = args.image_dims
143 | x_dim = C * H * W
144 |
145 | # --------------------
146 | # p model -- SSL paper generative semi supervised model M2
147 | # --------------------
148 |
149 | self.p_y = D.OneHotCategorical(probs=1 / args.y_dim * torch.ones(1,args.y_dim, device=args.device))
150 | self.p_z = D.Normal(torch.tensor(0., device=args.device), torch.tensor(1., device=args.device))
151 |
152 | # parametrized data likelihood p(x|y,z)
153 | self.decoder = nn.Sequential(nn.Linear(args.z_dim + args.y_dim, args.hidden_dim),
154 | nn.Softplus(),
155 | nn.Linear(args.hidden_dim, args.hidden_dim),
156 | nn.Softplus(),
157 | nn.Linear(args.hidden_dim, x_dim))
158 |
159 | # --------------------
160 | # q model -- SSL paper eq 4
161 | # --------------------
162 |
163 | # parametrized q(y|x) = Cat(y|pi_phi(x)) -- outputs parametrization of categorical distribution
164 | self.encoder_y = nn.Sequential(nn.Linear(x_dim, args.hidden_dim),
165 | nn.Softplus(),
166 | nn.Linear(args.hidden_dim, args.hidden_dim),
167 | nn.Softplus(),
168 | nn.Linear(args.hidden_dim, args.y_dim))
169 |
170 | # parametrized q(z|x,y) = Normal(z|mu_phi(x,y), diag(sigma2_phi(x))) -- output parametrizations for mean and diagonal variance of a Normal distribution
171 | self.encoder_z = nn.Sequential(nn.Linear(x_dim + args.y_dim, args.hidden_dim),
172 | nn.Softplus(),
173 | nn.Linear(args.hidden_dim, args.hidden_dim),
174 | nn.Softplus(),
175 | nn.Linear(args.hidden_dim, 2 * args.z_dim))
176 |
177 |
178 | # initialize weights to N(0, 0.001) and biases to 0 (cf SSL section 4.4)
179 | for p in self.parameters():
180 | p.data.normal_(0, 0.001)
181 | if p.ndimension() == 1: p.data.fill_(0.)
182 |
183 | # q(z|x,y) = Normal(z|mu_phi(x,y), diag(sigma2_phi(x))) -- SSL paper eq 4
184 | def encode_z(self, x, y):
185 | xy = torch.cat([x, y], dim=1)
186 | mu, logsigma = self.encoder_z(xy).chunk(2, dim=-1)
187 | return D.Normal(mu, logsigma.exp())
188 |
189 | # q(y|x) = Categorical(y|pi_phi(x)) -- SSL paper eq 4
190 | def encode_y(self, x):
191 | return D.OneHotCategorical(logits=self.encoder_y(x))
192 |
193 | # p(x|y,z) = Bernoulli
194 | def decode(self, y, z):
195 | yz = torch.cat([y,z], dim=1)
196 | return D.Bernoulli(logits=self.decoder(yz))
197 |
198 | # classification model q(y|x) using the trained q distribution
199 | def forward(self, x):
200 | y_probs = self.encode_y(x).probs
201 | return y_probs.max(dim=1)[1] # return pred labels = argmax
202 |
203 |
204 | def loss_components_fn(x, y, z, p_y, p_z, p_x_yz, q_z_xy):
205 |     # SSL paper eq 6 for a given y (observed or enumerated from q_y)
206 | return - p_x_yz.log_prob(x).sum(1) \
207 | - p_y.log_prob(y) \
208 | - p_z.log_prob(z).sum(1) \
209 | + q_z_xy.log_prob(z).sum(1)
210 |
211 |
212 | # --------------------
213 | # Train and eval
214 | # --------------------
215 |
216 | def train_epoch(model, labeled_dataloader, unlabeled_dataloader, loss_components_fn, optimizer, epoch, args):
217 | model.train()
218 |
219 | n_batches = len(labeled_dataloader) + len(unlabeled_dataloader)
220 | n_unlabeled_per_labeled = len(unlabeled_dataloader) // len(labeled_dataloader) + 1
221 |
222 | labeled_dataloader = iter(labeled_dataloader)
223 | unlabeled_dataloader = iter(unlabeled_dataloader)
224 |
225 | with tqdm(total=n_batches, desc='epoch {} of {}'.format(epoch+1, args.n_epochs)) as pbar:
226 | for i in range(n_batches):
227 | is_supervised = i % n_unlabeled_per_labeled == 0
228 |
229 | # get batch from respective dataloader
230 | if is_supervised:
231 | x, y = next(labeled_dataloader)
232 | y = one_hot(y, args.y_dim).to(args.device)
233 | else:
234 | x, y = next(unlabeled_dataloader)
235 | y = None
236 | x = x.to(args.device).view(x.shape[0], -1)
237 |
238 | # compute loss -- SSL paper eq 6, 7, 9
239 | q_y = model.encode_y(x)
240 | # labeled data loss -- SSL paper eq 6 and eq 9
241 | if y is not None:
242 | q_z_xy = model.encode_z(x, y)
243 | z = q_z_xy.rsample()
244 | p_x_yz = model.decode(y, z)
245 | loss = loss_components_fn(x, y, z, model.p_y, model.p_z, p_x_yz, q_z_xy)
246 | loss -= args.alpha * args.n_labeled * q_y.log_prob(y) # SSL eq 9
247 | # unlabeled data loss -- SSL paper eq 7
248 | else:
249 | # marginalize y according to q_y
250 | loss = - q_y.entropy()
251 | for y in q_y.enumerate_support():
252 | q_z_xy = model.encode_z(x, y)
253 | z = q_z_xy.rsample()
254 | p_x_yz = model.decode(y, z)
255 | L_xy = loss_components_fn(x, y, z, model.p_y, model.p_z, p_x_yz, q_z_xy)
256 | loss += q_y.log_prob(y).exp() * L_xy
257 | loss = loss.mean(0)
258 |
259 | optimizer.zero_grad()
260 | loss.backward()
261 | optimizer.step()
262 |
263 | # update trackers
264 | pbar.set_postfix(loss='{:.3f}'.format(loss.item()))
265 | pbar.update()
266 |
267 |
268 | @torch.no_grad()
269 | def evaluate(model, dataloader, epoch, args):
270 | model.eval()
271 |
272 | accurate_preds = 0
273 |
274 | with tqdm(total=len(dataloader), desc='eval') as pbar:
275 | for i, (x, y) in enumerate(dataloader):
276 | x = x.to(args.device).view(x.shape[0], -1)
277 | y = y.to(args.device)
278 | preds = model(x)
279 |
280 | accurate_preds += (preds == y).sum().item()
281 |
282 | pbar.set_postfix(accuracy='{:.3f}'.format(accurate_preds / ((i+1) * args.batch_size)))
283 | pbar.update()
284 |
285 | output = (epoch != None)*'Epoch {} -- '.format(epoch) + 'Test set accuracy: {:.3f}'.format(accurate_preds / (args.batch_size * len(dataloader)))
286 | print(output)
287 | print(output, file=open(args.results_file, 'a'))
288 |
289 |
290 | def train_and_evaluate(model, labeled_dataloader, unlabeled_dataloader, test_dataloader, loss_components_fn, optimizer, args):
291 | for epoch in range(args.n_epochs):
292 | train_epoch(model, labeled_dataloader, unlabeled_dataloader, loss_components_fn, optimizer, epoch, args)
293 | evaluate(model, test_dataloader, epoch, args)
294 |
295 | # save weights
296 | torch.save(model.state_dict(), os.path.join(args.output_dir, 'ssvae_model_state_hdim{}_zdim{}.pt'.format(
297 | args.hidden_dim, args.z_dim)))
298 |
299 | # show samples -- SSL paper Figure 1-b
300 | generate(model, test_dataloader.dataset, args, epoch)
301 |
302 |
303 | # --------------------
304 | # Visualize
305 | # --------------------
306 |
307 | @torch.no_grad()
308 | def generate(model, dataset, args, epoch=None, n_samples=10):
309 | n_samples_per_label = 10
310 |
311 | # some interesting samples per paper implementation
312 | idxs = [7910, 8150, 3623, 2645, 4066, 9660, 5083, 948, 2595, 2]
313 |
314 | x = torch.stack([dataset[i][0] for i in idxs], dim=0).to(args.device)
315 | y = torch.stack([dataset[i][1] for i in idxs], dim=0).to(args.device)
316 | y = one_hot(y, args.y_dim)
317 |
318 | q_z_xy = model.encode_z(x.view(n_samples_per_label, -1), y)
319 | z = q_z_xy.loc
320 | z = z.repeat(args.y_dim, 1, 1).transpose(0, 1).contiguous().view(-1, args.z_dim)
321 |
322 | # hold z constant and vary y:
323 | y = torch.eye(args.y_dim).repeat(n_samples_per_label, 1).to(args.device)
324 | generated_x = model.decode(y, z).probs.view(n_samples_per_label, args.y_dim, *args.image_dims)
325 | generated_x = generated_x.contiguous().view(-1, *args.image_dims) # out (n_samples * n_label, C, H, W)
326 |
327 | x = make_grid(x.cpu(), nrow=1)
328 | spacer = torch.ones(x.shape[0], x.shape[1], 5)
329 | generated_x = make_grid(generated_x.cpu(), nrow=args.y_dim)
330 | image = torch.cat([x, spacer, generated_x], dim=-1)
331 | save_image(image,
332 | os.path.join(args.output_dir, 'analogies_sample' + (epoch != None)*'_at_epoch_{}'.format(epoch) + '.png'),
333 | nrow=args.y_dim)
334 |
335 |
336 | @torch.no_grad()
337 | def vis_styles(model, args):
338 |     assert args.z_dim == 2, 'Style visualization requires z_dim=2'
339 |
340 | for y in range(2,5):
341 | y = one_hot(torch.tensor(y).unsqueeze(-1), args.y_dim).expand(100, args.y_dim).to(args.device)
342 |
343 | # order the first dim of the z latent
344 | c = torch.linspace(-5, 5, 10).view(-1,1).repeat(1,10).reshape(-1,1)
345 | z = torch.cat([c, torch.zeros_like(c)], dim=1).reshape(100, 2).to(args.device)
346 |
347 | # combine into z and pass through decoder
348 | x = model.decode(y, z).probs.view(y.shape[0], *args.image_dims)
349 | save_image(x.cpu(),
350 | os.path.join(args.output_dir, 'latent_var_grid_sample_c1_y{}.png'.format(y[0].nonzero().item())),
351 | nrow=10)
352 |
353 | # order second dim of latent and pass through decoder
354 | z = z.flip(1)
355 | x = model.decode(y, z).probs.view(y.shape[0], *args.image_dims)
356 | save_image(x.cpu(),
357 | os.path.join(args.output_dir, 'latent_var_grid_sample_c2_y{}.png'.format(y[0].nonzero().item())),
358 | nrow=10)
359 |
360 |
361 | # --------------------
362 | # Main
363 | # --------------------
364 |
365 | if __name__ == '__main__':
366 | args = parser.parse_args()
367 |
368 | if not os.path.isdir(args.output_dir):
369 | os.makedirs(args.output_dir)
370 |
371 |     args.device = torch.device('cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda is not None else 'cpu')
372 | torch.manual_seed(args.seed)
373 | if args.device.type == 'cuda': torch.cuda.manual_seed(args.seed)
374 |
375 | # dataloaders
376 | labeled_dataloader, unlabeled_dataloader, test_dataloader = fetch_dataloaders(args)
377 |
378 | # model
379 | model = SSVAE(args).to(args.device)
380 |
381 | # optimizer
382 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
383 |
384 | if args.restore_file:
385 | # load model and optimizer states
386 | state = torch.load(args.restore_file, map_location=args.device)
387 | model.load_state_dict(state)
388 | # set up paths
389 | args.output_dir = os.path.dirname(args.restore_file)
390 | args.results_file = os.path.join(args.output_dir, args.results_file)
391 |
392 | print('Loaded settings and model:')
393 | print(pprint.pformat(args.__dict__))
394 | print(model)
395 | print(pprint.pformat(args.__dict__), file=open(args.results_file, 'a'))
396 | print(model, file=open(args.results_file, 'a'))
397 |
398 | if args.train:
399 | train_and_evaluate(model, labeled_dataloader, unlabeled_dataloader, test_dataloader, loss_components_fn, optimizer, args)
400 |
401 | if args.evaluate:
402 | evaluate(model, test_dataloader, None, args)
403 |
404 | if args.generate:
405 | generate(model, test_dataloader.dataset, args)
406 |
407 | if args.vis_styles:
408 | vis_styles(model, args)
409 |
410 |
411 |
412 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import json
5 | from datetime import datetime
6 | import torch
7 |
8 | from tensorboardX import SummaryWriter
9 |
10 | from subprocess import check_call
11 |
12 |
13 | def set_writer(log_path, comment='', restore=False):
14 | """ setup a tensorboardx summarywriter """
15 | current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
16 | if restore:
17 | log_path = os.path.dirname(log_path)
18 | else:
19 | log_path = os.path.join(log_path, current_time + comment)
20 | writer = SummaryWriter(log_dir=log_path)
21 | return writer
22 |
23 |
24 | def save_checkpoint(state, checkpoint, is_best=None, quiet=False):
25 |     """ saves model and training params to checkpoint + '/state_checkpoint.pt'
26 |
27 | args
28 | state -- dict; with keys model_state_dict, optimizer_state_dict, epoch, scheduler_state_dict, etc
29 | is_best -- bool; true if best model seen so far
30 | checkpoint -- str; folder where params are to be saved
31 | """
32 |
33 | filepath = os.path.join(checkpoint, 'state_checkpoint.pt')
34 | if not os.path.exists(checkpoint):
35 | if not quiet:
36 |             print('Checkpoint directory does not exist. Making directory {}'.format(checkpoint))
37 | os.mkdir(checkpoint)
38 |
39 | torch.save(state, filepath)
40 |
41 | # if is_best:
42 | # shutil.copyfile(filepath, os.path.join(checkpoint, 'best_state_checkpoint.pt'))
43 |
44 | if not quiet:
45 | print('Checkpoint saved.')
46 |
47 |
48 | def load_checkpoint(checkpoint, models, optimizers=None, scheduler=None, best_metric=None, map_location='cpu'):
49 | """ loads model state_dict from filepath; if optimizer and lr_scheduler provided also loads them
50 |
51 | args
52 | checkpoint -- string of filename
53 | model -- torch nn.Module model
54 |         models -- list of torch nn.Module models to restore
55 |         optimizers -- list of torch.optim instances to resume from checkpoint
56 |         scheduler -- torch.optim.lr_scheduler instance to resume from checkpoint
57 |
58 | if not os.path.exists(checkpoint):
59 |         raise FileNotFoundError('File does not exist {}'.format(checkpoint))
60 |
61 | checkpoint = torch.load(checkpoint, map_location=map_location)
62 | models = [m.load_state_dict(checkpoint['model_state_dicts'][i]) for i, m in enumerate(models)]
63 |
64 | if optimizers:
65 | try:
66 | optimizers = [o.load_state_dict(checkpoint['optimizer_state_dicts'][i]) for i, o in enumerate(optimizers)]
67 | except KeyError:
68 | print('No optimizer state dict in checkpoint file')
69 |
70 | if best_metric:
71 | try:
72 | best_metric = checkpoint['best_val_acc']
73 | except KeyError:
74 | print('No best validation accuracy recorded in checkpoint file.')
75 |
76 | if scheduler:
77 | try:
78 | scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
79 | except KeyError:
80 | print('No lr scheduler state dict in checkpoint file')
81 |
82 | return checkpoint['epoch']
83 |
84 |
--------------------------------------------------------------------------------
/vqvae.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of VQ-VAE-2:
3 | -- van den Oord, 'Generating Diverse High-Fidelity Images with VQ-VAE-2' -- https://arxiv.org/abs/1906.00446
4 | -- van den Oord, 'Neural Discrete Representation Learning' -- https://arxiv.org/abs/1711.00937
5 | -- Roy, 'Theory and Experiments on Vector Quantized Autoencoders' -- https://arxiv.org/pdf/1805.11063.pdf
6 |
7 | Reference implementation of the vector quantized VAE:
8 | https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/nets/vqvae.py
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torch.utils.data import DataLoader
15 | import torchvision.transforms as T
16 | from torchvision.datasets import CIFAR10
17 | from torchvision.utils import save_image, make_grid
18 |
19 | from tensorboardX import SummaryWriter
20 | from tqdm import tqdm
21 |
22 | import os
23 | import argparse
24 | import time
25 | import json
26 | import pprint
27 | from functools import partial
28 |
29 | from datasets.chexpert import ChexpertDataset
30 |
31 | parser = argparse.ArgumentParser()
32 | # action
33 | parser.add_argument('--train', action='store_true', help='Train model.')
34 | parser.add_argument('--evaluate', action='store_true', help='Evaluate model.')
35 | parser.add_argument('--generate', action='store_true', help='Generate samples from a model.')
36 | parser.add_argument('--seed', type=int, default=0, help='Random seed to use.')
37 | parser.add_argument('--cuda', type=int, help='Which cuda device to use.')
38 | parser.add_argument('--mini_data', action='store_true', help='Truncate dataset to a single minibatch.')
39 | # model
40 | parser.add_argument('--n_embeddings', default=256, type=int, help='Size of discrete latent space (K-way categorical).')
41 | parser.add_argument('--embedding_dim', default=64, type=int, help='Dimensionality of each latent embedding vector.')
42 | parser.add_argument('--n_channels', default=128, type=int, help='Number of channels in the encoder and decoder.')
43 | parser.add_argument('--n_res_channels', default=64, type=int, help='Number of channels in the residual layers.')
44 | parser.add_argument('--n_res_layers', default=2, type=int, help='Number of residual layers inside the residual block.')
45 | parser.add_argument('--n_cond_classes', type=int, help='(NOT USED here; used in training prior but requires flag for dataloader) Number of classes if conditional model.')
46 | # data params
47 | parser.add_argument('--dataset', choices=['cifar10', 'chexpert'], default='chexpert')
48 | parser.add_argument('--data_dir', default='~/data/', help='Location of datasets.')
49 | parser.add_argument('--output_dir', type=str, help='Location where weights, logs, and sample should be saved.')
50 | parser.add_argument('--restore_dir', type=str, help='Path to model config and checkpoint to restore.')
51 | # training param
52 | parser.add_argument('--batch_size', type=int, default=128, help='Training batch size.')
53 | parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate.')
54 | parser.add_argument('--lr_decay', type=float, default=0.9999965, help='Learning rate decay (assume end lr = 1e-6 @ 2m iters for init lr 0.001).')
55 | parser.add_argument('--commitment_cost', type=float, default=0.25, help='Commitment cost term in loss function.')
56 | parser.add_argument('--ema', action='store_true', help='Use exponential moving average training for the codebook.')
57 | parser.add_argument('--ema_decay', type=float, default=0.99, help='EMA decay rate.')
58 | parser.add_argument('--ema_eps', type=float, default=1e-5, help='EMA epsilon.')
59 | parser.add_argument('--n_epochs', type=int, default=1, help='Number of epochs to train.')
60 | parser.add_argument('--step', type=int, default=0, help='Current step of training (number of minibatches processed).')
61 | parser.add_argument('--start_epoch', default=0, help='Starting epoch (for logging; to be overwritten when restoring file).')
62 | parser.add_argument('--log_interval', type=int, default=50, help='How often to show loss statistics and save samples.')
63 | parser.add_argument('--eval_interval', type=int, default=10, help='How often to evaluate and save samples.')
64 | # distributed training params
65 | parser.add_argument('--distributed', action='store_true', default=False, help='Whether to use DistributedDataParallels on multiple machines and GPUs.')
66 | # generation param
67 | parser.add_argument('--n_samples', type=int, default=64, help='Number of samples to generate.')
68 |
69 |
70 | # --------------------
71 | # Data and model loading
72 | # --------------------
73 |
74 | def fetch_vqvae_dataloader(args, train=True):
75 | if args.dataset == 'cifar10':
76 | # setup dataset and dataloader -- preprocess data to [-1, 1]
77 | dataset = CIFAR10(args.data_dir,
78 | train=train,
79 | transform=T.Compose([T.ToTensor(), lambda x: x.mul(2).sub(1)]),
80 | target_transform=(lambda y: torch.eye(args.n_cond_classes)[y]) if args.n_cond_classes else None)
81 | if not 'input_dims' in args: args.input_dims = (3,32,32)
82 | elif args.dataset == 'chexpert':
83 | dataset = ChexpertDataset(args.data_dir, train,
84 | transform=T.Compose([T.ToTensor(), lambda x: x.mul(2).sub(1)]))
85 | if not 'input_dims' in args: args.input_dims = dataset.input_dims
86 | args.n_cond_classes = len(dataset.attr_idxs)
87 |
88 | if args.mini_data:
89 | dataset.data = dataset.data[:args.batch_size]
90 | return DataLoader(dataset, args.batch_size, shuffle=train, num_workers=4, pin_memory=('cuda' in args.device))
91 |
92 | def load_model(model_cls, config, model_dir, args, restore=False, eval_mode=False, optimizer_cls=None, scheduler_cls=None, verbose=True):
93 | # load model config
94 | if config is None: config = load_json(os.path.join(model_dir, 'config_{}.json'.format(args.cuda)))
95 | # init model and distribute
96 | model = model_cls(**config).to(args.device)
97 | if args.distributed:
98 | # NOTE: DistributedDataParallel will divide and allocate batch_size to all available GPUs if device_ids are not set
99 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.cuda], output_device=args.cuda,
100 | find_unused_parameters=True)
101 | # init optimizer and scheduler
102 | optimizer = optimizer_cls(model.parameters()) if optimizer_cls else None
103 | scheduler = scheduler_cls(optimizer) if scheduler_cls else None
104 | if restore:
105 | checkpoint = torch.load(os.path.join(model_dir, 'checkpoint.pt'), map_location=args.device)
106 | if args.distributed:
107 | model.module.load_state_dict(checkpoint['state_dict'])
108 | else:
109 | model.load_state_dict(checkpoint['state_dict'])
110 | args.start_epoch = checkpoint['epoch'] + 1
111 | args.step = checkpoint['global_step']
112 | if optimizer: optimizer.load_state_dict(torch.load(model_dir + '/optim_checkpoint.pt', map_location=args.device))
113 | if scheduler: scheduler.load_state_dict(torch.load(model_dir + '/sched_checkpoint.pt', map_location=args.device))
114 | if eval_mode:
115 | model.eval()
116 | # if optimizer and restore: optimizer.use_ema(True)
117 | for p in model.parameters(): p.requires_grad_(False)
118 | if verbose:
119 | print('Loaded {}\n\tconfig and state dict loaded from {}'.format(model_cls.__name__, model_dir))
120 | print('\tmodel parameters: {:,}'.format(sum(p.numel() for p in model.parameters())))
121 | return model, optimizer, scheduler
122 |
123 | def save_json(data, filename, args):
124 | with open(os.path.join(args.output_dir, filename + '.json'), 'w') as f:
125 | json.dump(data, f, indent=4)
126 |
127 | def load_json(file_path):
128 | with open(file_path, 'r') as f:
129 | data = json.load(f)
130 | return data
131 |
132 | # --------------------
133 | # VQVAE components
134 | # --------------------
135 |
136 | class VQ(nn.Module):
137 | def __init__(self, n_embeddings, embedding_dim, ema=False, ema_decay=0.99, ema_eps=1e-5):
138 | super().__init__()
139 | self.n_embeddings = n_embeddings
140 | self.embedding_dim = embedding_dim
141 | self.ema = ema
142 | self.ema_decay = ema_decay
143 | self.ema_eps = ema_eps
144 |
145 | self.embedding = nn.Embedding(n_embeddings, embedding_dim)
146 | nn.init.kaiming_uniform_(self.embedding.weight, 1)
147 |
148 | if ema:
149 | self.embedding.weight.requires_grad_(False)
150 | # set up moving averages
151 | self.register_buffer('ema_cluster_size', torch.zeros(n_embeddings))
152 | self.register_buffer('ema_weight', self.embedding.weight.clone().detach())
153 |
154 | def embed(self, encoding_indices):
155 | return self.embedding(encoding_indices).permute(0,4,1,2,3).squeeze(2) # in (B,1,H,W); out (B,E,H,W)
156 |
157 | def forward(self, z):
158 | # input (B,E,H,W); permute and reshape to (B*H*W,E) to compute distances in E-space
159 | flat_z = z.permute(0,2,3,1).reshape(-1, self.embedding_dim) # (B*H*W,E)
160 | # compute distances to nearest embedding
161 | distances = flat_z.pow(2).sum(1, True) + self.embedding.weight.pow(2).sum(1) - 2 * flat_z.matmul(self.embedding.weight.t())
162 | # quantize z to nearest embedding
163 | encoding_indices = distances.argmin(1).reshape(z.shape[0], 1, *z.shape[2:]) # (B,1,H,W)
164 | z_q = self.embed(encoding_indices)
165 |
166 | # perform ema updates
167 | if self.ema and self.training:
168 | with torch.no_grad():
169 | # update cluster size
170 | encodings = F.one_hot(encoding_indices.flatten(), self.n_embeddings).float().to(z.device)
171 | self.ema_cluster_size -= (1 - self.ema_decay) * (self.ema_cluster_size - encodings.sum(0))
172 | # update weight
173 | dw = z.permute(1,0,2,3).flatten(1) @ encodings # (E,B*H*W) dot (B*H*W,n_embeddings)
174 | self.ema_weight -= (1 - self.ema_decay) * (self.ema_weight - dw.t())
175 | # update embedding weight with normalized ema_weight
176 | n = self.ema_cluster_size.sum()
177 | updated_cluster_size = (self.ema_cluster_size + self.ema_eps) / (n + self.n_embeddings * self.ema_eps) * n
178 | self.embedding.weight.data = self.ema_weight / updated_cluster_size.unsqueeze(1)
179 |
180 | return encoding_indices, z_q # out (B,1,H,W) codes and (B,E,H,W) embedded codes
181 |
182 |
183 | class ResidualLayer(nn.Sequential):
184 | def __init__(self, n_channels, n_res_channels):
185 | super().__init__(nn.Conv2d(n_channels, n_res_channels, kernel_size=3, padding=1),
186 | nn.ReLU(True),
187 | nn.Conv2d(n_res_channels, n_channels, kernel_size=1))
188 |
189 | def forward(self, x):
190 | return F.relu(x + super().forward(x), True)
191 |
192 | # --------------------
193 | # VQVAE2
194 | # --------------------
195 |
196 | class VQVAE2(nn.Module):
197 | def __init__(self, input_dims, n_embeddings, embedding_dim, n_channels, n_res_channels, n_res_layers,
198 | ema=False, ema_decay=0.99, ema_eps=1e-5, **kwargs): # keep kwargs so can load from config with arbitrary other args
199 | super().__init__()
200 | self.ema = ema
201 |
202 | self.enc1 = nn.Sequential(nn.Conv2d(input_dims[0], n_channels//2, kernel_size=4, stride=2, padding=1),
203 | nn.ReLU(True),
204 | nn.Conv2d(n_channels//2, n_channels, kernel_size=4, stride=2, padding=1),
205 | nn.ReLU(True),
206 | nn.Conv2d(n_channels, n_channels, kernel_size=3, padding=1),
207 | nn.ReLU(True),
208 | nn.Sequential(*[ResidualLayer(n_channels, n_res_channels) for _ in range(n_res_layers)]),
209 | nn.Conv2d(n_channels, embedding_dim, kernel_size=1))
210 |
211 | self.enc2 = nn.Sequential(nn.Conv2d(embedding_dim, n_channels//2, kernel_size=4, stride=2, padding=1),
212 | nn.ReLU(True),
213 | nn.Conv2d(n_channels//2, n_channels, kernel_size=3, padding=1),
214 | nn.ReLU(True),
215 | nn.Sequential(*[ResidualLayer(n_channels, n_res_channels) for _ in range(n_res_layers)]),
216 | nn.Conv2d(n_channels, embedding_dim, kernel_size=1))
217 |
218 | self.dec2 = nn.Sequential(nn.Conv2d(embedding_dim, n_channels, kernel_size=3, padding=1),
219 | nn.ReLU(True),
220 | nn.Sequential(*[ResidualLayer(n_channels, n_res_channels) for _ in range(n_res_layers)]),
221 | nn.ConvTranspose2d(n_channels, embedding_dim, kernel_size=4, stride=2, padding=1))
222 |
223 | self.dec1 = nn.Sequential(nn.Conv2d(2*embedding_dim, n_channels, kernel_size=3, padding=1),
224 | nn.ReLU(True),
225 | nn.Sequential(*[ResidualLayer(n_channels, n_res_channels) for _ in range(n_res_layers)]),
226 | nn.ConvTranspose2d(n_channels, n_channels//2, kernel_size=4, stride=2, padding=1),
227 | nn.ReLU(True),
228 | nn.ConvTranspose2d(n_channels//2, input_dims[0], kernel_size=4, stride=2, padding=1))
229 |
230 | self.proj_to_vq1 = nn.Conv2d(2*embedding_dim, embedding_dim, kernel_size=1)
231 | self.upsample_to_dec1 = nn.ConvTranspose2d(embedding_dim, embedding_dim, kernel_size=4, stride=2, padding=1)
232 |
233 | self.vq1 = VQ(n_embeddings, embedding_dim, ema, ema_decay, ema_eps)
234 | self.vq2 = VQ(n_embeddings, embedding_dim, ema, ema_decay, ema_eps)
235 |
236 | def encode(self, x):
237 | z1 = self.enc1(x)
238 | z2 = self.enc2(z1)
239 | return (z1, z2) # each is (B,E,H,W)
240 |
241 | def embed(self, encoding_indices):
242 | encoding_indices1, encoding_indices2 = encoding_indices
243 | return (self.vq1.embed(encoding_indices1), self.vq2.embed(encoding_indices2))
244 |
245 | def quantize(self, z_e):
246 | # unpack inputs
247 | z1, z2 = z_e
248 |
249 | # quantize top level
250 | encoding_indices2, zq2 = self.vq2(z2)
251 |
252 | # quantize bottom level conditioned on top level decoder and bottom level encoder
253 | # decode top level
254 | quantized2 = z2 + (zq2 - z2).detach() # stop decoder optimization from accessing the embedding
255 | dec2_out = self.dec2(quantized2)
256 | # condition on bottom encoder and top decoder
257 | vq1_input = torch.cat([z1, dec2_out], 1)
258 | vq1_input = self.proj_to_vq1(vq1_input)
259 | encoding_indices1, zq1 = self.vq1(vq1_input)
260 | return (encoding_indices1, encoding_indices2), (zq1, zq2)
261 |
262 | def decode(self, z_e, z_q):
263 | # unpack inputs
264 | zq1, zq2 = z_q
265 | if z_e is not None:
266 | z1, z2 = z_e
267 | # stop decoder optimization from accessing the embedding
268 | zq1 = z1 + (zq1 - z1).detach()
269 | zq2 = z2 + (zq2 - z2).detach()
270 |
271 | # upsample quantized2 to match spacial dim of quantized1
272 | zq2_upsampled = self.upsample_to_dec1(zq2)
273 | # decode
274 | combined_latents = torch.cat([zq1, zq2_upsampled], 1)
275 | return self.dec1(combined_latents)
276 |
277 | def forward(self, x, commitment_cost, writer=None):
278 | # Figure 2a in paper
279 | z_e = self.encode(x)
280 | encoding_indices, z_q = self.quantize(z_e)
281 | recon_x = self.decode(z_e, z_q)
282 |
283 | # compute loss over the hierarchy -- cf eq 2 in paper
284 | recon_loss = F.mse_loss(recon_x, x)
285 | q_latent_loss = sum(F.mse_loss(z_i.detach(), zq_i) for z_i, zq_i in zip(z_e, z_q)) if not self.ema else torch.zeros(1, device=x.device)
286 | e_latent_loss = sum(F.mse_loss(z_i, zq_i.detach()) for z_i, zq_i in zip(z_e, z_q))
287 | loss = recon_loss + q_latent_loss + commitment_cost * e_latent_loss
288 |
289 | if writer:
290 | # compute perplexity
291 | n_embeddings = self.vq1.embedding.num_embeddings
292 | avg_probs = lambda e: torch.histc(e.float(), bins=n_embeddings, max=n_embeddings).float().div(e.numel())
293 | perplexity = lambda avg_probs: torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
294 | # record training stats
295 | writer.add_scalar('loss', loss.item(), args.step)
296 | writer.add_scalar('loss_recon_train', recon_loss.item(), args.step)
297 | writer.add_scalar('loss_q_latent', q_latent_loss.item(), args.step)
298 | writer.add_scalar('loss_e_latent', e_latent_loss.item(), args.step)
299 | for i, e_i in enumerate(encoding_indices):
300 | writer.add_scalar('perplexity_{}'.format(i), perplexity(avg_probs(e_i)).item(), args.step)
301 |
302 | return loss
303 |
304 |
305 | # --------------------
306 | # Train, evaluate, reconstruct
307 | # --------------------
308 |
309 | def train_epoch(model, dataloader, optimizer, scheduler, epoch, writer, args):
310 | model.train()
311 |
312 | with tqdm(total=len(dataloader), desc='epoch {}/{}'.format(epoch, args.start_epoch + args.n_epochs)) as pbar:
313 | for x, _ in dataloader:
314 | args.step += 1
315 |
316 | loss = model(x.to(args.device), args.commitment_cost, writer if args.step % args.log_interval == 0 else None)
317 |
318 | optimizer.zero_grad()
319 | loss.backward()
320 | optimizer.step()
321 | if scheduler: scheduler.step()
322 |
323 | pbar.set_postfix(loss='{:.4f}'.format(loss.item()))
324 | pbar.update()
325 |
326 | def show_recons_from_hierarchy(model, n_samples, x, z_q, recon_x=None):
327 | # full reconstruction
328 | if recon_x is None:
329 | recon_x = model.decode(None, z_q)
330 | # top level only reconstruction -- no contribution from bottom-level (level1) latents
331 | recon_top = model.decode(None, (z_q[0].fill_(0), z_q[1]))
332 |
333 | # construct image grid
334 | x = make_grid(x[:n_samples].cpu(), normalize=True)
335 | recon_x = make_grid(recon_x[:n_samples].cpu(), normalize=True)
336 | recon_top = make_grid(recon_top[:n_samples].cpu(), normalize=True)
337 | separator = torch.zeros(x.shape[0], 4, x.shape[2])
338 | return torch.cat([x, separator, recon_x, separator, recon_top], dim=1)
339 |
340 | @torch.no_grad()
341 | def evaluate(model, dataloader, args):
342 | model.eval()
343 |
344 | recon_loss = 0
345 | for x, _ in tqdm(dataloader):
346 | x = x.to(args.device)
347 | z_e = model.encode(x)
348 | encoding_indices, z_q = model.quantize(z_e)
349 | recon_x = model.decode(z_e, z_q)
350 | recon_loss += F.mse_loss(recon_x, x).item()
351 | recon_loss /= len(dataloader)
352 |
353 | # reconstruct
354 | recon_image = show_recons_from_hierarchy(model, args.n_samples, x, z_q, recon_x)
355 | return recon_image, recon_loss
356 |
357 | def train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, writer, args):
358 | for epoch in range(args.start_epoch, args.start_epoch + args.n_epochs):
359 | train_epoch(model, train_dataloader, optimizer, scheduler, epoch, writer, args)
360 |
361 | # save model
362 | torch.save({'epoch': epoch,
363 | 'global_step': args.step,
364 | 'state_dict': model.state_dict()},
365 | os.path.join(args.output_dir, 'checkpoint.pt'))
366 | torch.save(optimizer.state_dict(), os.path.join(args.output_dir, 'optim_checkpoint.pt'))
367 |         if scheduler: torch.save(scheduler.state_dict(), os.path.join(args.output_dir, 'sched_checkpoint.pt'))
368 |
369 | if (epoch+1) % args.eval_interval == 0:
370 | # evaluate
371 | recon_image, recon_loss = evaluate(model, valid_dataloader, args)
372 | print('Evaluate -- recon loss: {:.4f}'.format(recon_loss))
373 | writer.add_scalar('loss_recon_eval', recon_loss, args.step)
374 | writer.add_image('eval_reconstructions', recon_image, args.step)
375 | save_image(recon_image, os.path.join(args.output_dir, 'eval_reconstruction_step_{}'.format(args.step) + '.png'))
376 |
377 |
378 | # --------------------
379 | # Main
380 | # --------------------
381 |
382 | if __name__ == '__main__':
383 | args = parser.parse_args()
384 | if args.restore_dir:
385 | args.output_dir = args.restore_dir
386 | if not args.output_dir: # if not given use results/file_name/time_stamp
387 | args.output_dir = './results/{}/{}'.format(os.path.splitext(__file__)[0], time.strftime('%Y-%m-%d_%H-%M-%S', time.gmtime()))
388 | writer = SummaryWriter(log_dir = args.output_dir)
389 |
390 | args.device = 'cuda:{}'.format(args.cuda) if args.cuda is not None and torch.cuda.is_available() else 'cpu'
391 |
392 | torch.manual_seed(args.seed)
393 |
394 | # setup dataset and dataloader -- preprocess data to [-1, 1]
395 | train_dataloader = fetch_vqvae_dataloader(args, train=True)
396 | valid_dataloader = fetch_vqvae_dataloader(args, train=False)
397 |
398 | # save config
399 | if not os.path.exists(os.path.join(args.output_dir, 'config_{}.json'.format(args.cuda))):
400 | save_json(args.__dict__, 'config_{}'.format(args.cuda), args)
401 |
402 | # setup model
403 |     model, optimizer, scheduler = load_model(VQVAE2, None, args.output_dir, args,  # config=None is loaded from the config json saved above
404 | restore=(args.restore_dir is not None),
405 | eval_mode=False,
406 | optimizer_cls=partial(torch.optim.Adam, lr=args.lr),
407 | scheduler_cls=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=args.lr_decay))
408 |
409 | # print and write config with update step and epoch from load_model
410 | writer.add_text('config', str(args.__dict__), args.step)
411 | pprint.pprint(args.__dict__)
412 |
413 | if args.train:
414 | train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, writer, args)
415 |
416 | if args.evaluate:
417 | recon_image, recon_loss = evaluate(model, valid_dataloader, args)
418 | print('Evaluate @ step {} -- recon loss: {:.4f}'.format(args.step, recon_loss))
419 | writer.add_scalar('loss_recon_eval', recon_loss, args.step)
420 | writer.add_image('eval_reconstructions', recon_image, args.step)
421 | save_image(recon_image, os.path.join(args.output_dir, 'eval_reconstruction_step_{}'.format(args.step) + '.png'))
422 |
423 |
--------------------------------------------------------------------------------
/vqvae_prior.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of VQ-VAE-2 priors:
3 | -- van den Oord, 'Generating Diverse High-Fidelity Images with VQ-VAE-2' -- https://arxiv.org/abs/1906.00446
4 | -- van den Oord, 'Conditional Image Generation with PixelCNN Decoders' -- https://arxiv.org/abs/1606.05328
5 | -- Xi Chen, 'PixelSNAIL: An Improved Autoregressive Generative Model' -- https://arxiv.org/abs/1712.09763
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.multiprocessing as mp
12 | from torch.utils.data import DataLoader, TensorDataset
13 | from torchvision.utils import save_image, make_grid
14 |
15 | import numpy as np
16 | from tensorboardX import SummaryWriter
17 | from tqdm import tqdm
18 |
19 | import os
20 | import argparse
21 | import time
22 | import pprint
23 | from functools import partial
24 |
25 | from vqvae import VQVAE2, fetch_vqvae_dataloader, load_model, save_json, load_json
26 | from optim import Adam, RMSprop
27 |
28 |
29 | parser = argparse.ArgumentParser()
30 |
31 | # action
32 | parser.add_argument('--train', action='store_true', help='Train model.')
33 | parser.add_argument('--evaluate', action='store_true', help='Evaluate model.')
34 | parser.add_argument('--generate', action='store_true', help='Generate samples from a model.')
35 | parser.add_argument('--seed', type=int, default=0, help='Random seed to use.')
36 | parser.add_argument('--cuda', type=int, help='Which cuda device to use.')
37 | parser.add_argument('--mini_data', action='store_true', help='Truncate dataset to a single minibatch.')
38 | # model
39 | parser.add_argument('--which_prior', choices=['bottom', 'top'], help='Which prior model to train.')
40 | parser.add_argument('--vqvae_dir', type=str, required=True, help='Path to VQVAE folder with config.json and checkpoint.pt files.')
41 | parser.add_argument('--n_channels', default=128, type=int, help='Number of channels for gated residual convolutional blocks.')
42 | parser.add_argument('--n_out_conv_channels', default=1024, type=int, help='Number of channels for outer 1x1 convolutional layers.')
43 | parser.add_argument('--n_res_layers', default=20, type=int, help='Number of Gated Residual Blocks.')
44 | parser.add_argument('--n_cond_classes', default=5, type=int, help='Number of classes if conditional model.')
45 | parser.add_argument('--n_cond_stack_layers', default=10, type=int, help='Number of conditioning stack residual blocks.')
46 | parser.add_argument('--n_out_stack_layers', default=10, type=int, help='Number of output stack layers.')
47 | parser.add_argument('--kernel_size', default=5, type=int, help='Kernel size for the gated residual convolutional blocks.')
48 | parser.add_argument('--drop_rate', default=0, type=float, help='Dropout for the Gated Residual Blocks.')
49 | # data params
50 | parser.add_argument('--output_dir', type=str, help='Location where weights, logs, and samples should be saved.')
51 | parser.add_argument('--restore_dir', nargs='+', help='Location(s) to restore configs and weights from: one dir when training/evaluating a single prior, or two dirs (bottom prior, top prior) when generating.')
52 | parser.add_argument('--n_bits', type=int, help='Number of bits of the discrete latent codes (log2 of the vqvae codebook size); set from the vqvae config.')
53 | # training param
54 | parser.add_argument('--lr', type=float, default=5e-4, help='Learning rate.')
55 | parser.add_argument('--lr_decay', type=float, default=0.99999, help='Per-step learning rate decay (0.99999 decays a 1e-3 lr to ~5e-5 over 300k iterations).')
56 | parser.add_argument('--polyak', type=float, default=0.9995, help='Polyak decay for exponential moving averaging.')
57 | parser.add_argument('--batch_size', type=int, default=16, help='Training batch size.')
58 | parser.add_argument('--n_epochs', type=int, default=1, help='Number of epochs to train.')
59 | parser.add_argument('--step', type=int, default=0, help='Current step of training (number of minibatches processed).')
60 | parser.add_argument('--start_epoch', default=0, type=int, help='Starting epoch (for logging; overwritten when restoring a checkpoint).')
61 | parser.add_argument('--log_interval', type=int, default=50, help='How often (in steps) to log training loss statistics.')
62 | parser.add_argument('--eval_interval', type=int, default=10, help='How often (in epochs) to evaluate and save samples.')
63 | parser.add_argument('--save_interval', type=int, default=300, help='How often (in steps) to save a model checkpoint.')
64 | # distributed training params
65 | parser.add_argument('--distributed', action='store_true', default=False, help='Whether to use DistributedDataParallels on multiple machines and GPUs.')
66 | parser.add_argument('--world_size', type=int, default=1)
67 | parser.add_argument('--rank', type=int, default=0)
68 | # generation param
69 | parser.add_argument('--n_samples', type=int, default=8, help='Number of samples to generate.')
70 |
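# Example invocations (illustrative only; directory paths are placeholders and other flags keep their defaults):
#   train the top-level prior against a trained vqvae:
#     python vqvae_prior.py --train --which_prior top --vqvae_dir ./results/vqvae/<run_dir> --cuda 0
#   generate samples using both trained priors (bottom prior dir first, then top prior dir):
#     python vqvae_prior.py --generate --vqvae_dir ./results/vqvae/<run_dir> --restore_dir <bottom_prior_dir> <top_prior_dir> --cuda 0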
71 |
72 |
73 | # --------------------
74 | # Data and model loading
75 | # --------------------
76 |
77 | @torch.no_grad()
78 | def extract_codes_from_dataloader(vqvae, dataloader, dataset_path):
79 | """ encode image inputs with vqvae and extract field of discrete latents (the embedding indices in the codebook with closest l2 distance) """
80 | device = next(vqvae.parameters()).device
81 | e1s, e2s, ys = [], [], []
82 | for x, y in tqdm(dataloader):
83 | z_e = vqvae.encode(x.to(device))
84 | encoding_indices, _ = vqvae.quantize(z_e) # tuple of (bottom, top encoding indices) where each is (B,1,H,W)
85 |
86 | e1, e2 = encoding_indices
87 | e1s.append(e1)
88 | e2s.append(e2)
89 | ys.append(y)
90 | return TensorDataset(torch.cat(e1s).cpu(), torch.cat(e2s).cpu(), torch.cat(ys))
91 |
92 | def maybe_extract_codes(vqvae, args, train):
93 | """ construct datasets of vqvae encodings and class conditional labels -- each dataset entry is [encodings level 1 (bottom), encodings level 2 (top), class label vector] """
94 | # paths to load/save as `chexpert_train_codes_mini_data.pt`
95 | dataset_path = os.path.join(args.vqvae_dir, '{}_{}_codes'.format(args.dataset, 'train' if train else 'valid') + args.mini_data*'_mini_data_{}'.format(args.batch_size) + '.pt')
96 | if not os.path.exists(dataset_path):
97 | print('Extracting codes for {} data ...'.format('train' if train else 'valid'))
98 | dataloader = fetch_vqvae_dataloader(args, train)
99 | dataset = extract_codes_from_dataloader(vqvae, dataloader, dataset_path)
100 | torch.save(dataset, dataset_path)
101 | else:
102 | dataset = torch.load(dataset_path)
103 | if args.on_main_process: print('Loaded {} codes dataset of size {}'.format('train' if train else 'valid', len(dataset)))
104 | return dataset
105 |
106 | def fetch_prior_dataloader(vqvae, args, train=True):
107 | dataset = maybe_extract_codes(vqvae, args, train)
108 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) if args.distributed and train else None
109 | return DataLoader(dataset, args.batch_size, shuffle=(train and sampler is None), sampler=sampler, num_workers=4, pin_memory=('cuda' in args.device))
110 |
111 | def preprocess(x, n_bits):
112 | """ preprosses discrete latents space [0, 2**n_bits) to model space [-1,1]; if size of the codebook ie n_embeddings = 512 = 2**9 -> n_bit=9 """
113 | # 1. convert data to float
114 | # 2. normalize to [0,1] given quantization
115 | # 3. shift to [-1,1]
116 | return x.float().div(2**n_bits - 1).mul(2).add(-1)
117 |
118 | def deprocess(x, n_bits):
119 | """ deprocess x from model space [-1,1] to discrete latents space [0, 2**n_bits) where 2**n_bits is size of the codebook """
120 | # 1. shift to [0,1]
121 | # 2. quantize to n_bits
122 | # 3. convert data to long
123 | return x.add(1).div(2).mul(2**n_bits - 1).long()
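
# For intuition, assuming a 512-entry codebook (n_bits = 9):
#   preprocess(torch.tensor([0, 511]), 9)  ->  tensor([-1., 1.])
#   deprocess(torch.tensor([-1., 1.]), 9)  ->  tensor([0, 511])
# i.e. the codebook index range maps linearly onto [-1, 1] and back.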
124 |
125 | # --------------------
126 | # PixelSNAIL -- top level prior conditioned on class labels
127 | # --------------------
128 |
129 | def down_shift(x):
130 | return F.pad(x, (0,0,1,0))[:,:,:-1,:]
131 |
132 | def right_shift(x):
133 | return F.pad(x, (1,0))[:,:,:,:-1]
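
# down_shift / right_shift pad a row / column of zeros at the top / left and drop the last
# row / column, so output position (i, j) holds the input from (i-1, j) / (i, j-1); applied to
# the first-layer activations in PixelSNAIL.forward, this removes the current pixel from its
# own receptive field and keeps the model autoregressive.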
134 |
135 | def concat_elu(x):
136 | return F.elu(torch.cat([x, -x], dim=1))
137 |
138 | class Conv2d(nn.Conv2d):
139 | def __init__(self, *args, **kwargs):
140 | super().__init__(*args, **kwargs)
141 | nn.utils.weight_norm(self)
142 |
143 | class DownShiftedConv2d(Conv2d):
144 | def forward(self, x):
145 | # pad H above and W on each side
146 | Hk, Wk = self.kernel_size
147 | x = F.pad(x, ((Wk-1)//2, (Wk-1)//2, Hk-1, 0))
148 | return super().forward(x)
149 |
150 | class DownRightShiftedConv2d(Conv2d):
151 | def forward(self, x):
152 | # pad above and on left (ie shift input down and right)
153 | Hk, Wk = self.kernel_size
154 | x = F.pad(x, (Wk-1, 0, Hk-1, 0))
155 | return super().forward(x)
156 |
157 | class GatedResidualLayer(nn.Module):
158 | def __init__(self, conv, n_channels, kernel_size, drop_rate=0, shortcut_channels=None, n_cond_classes=None, relu_fn=concat_elu):
159 | super().__init__()
160 | self.relu_fn = relu_fn
161 |
162 | self.c1 = conv(2*n_channels, n_channels, kernel_size)
163 | if shortcut_channels:
164 | self.c1c = Conv2d(2*shortcut_channels, n_channels, kernel_size=1)
165 | if drop_rate > 0:
166 | self.dropout = nn.Dropout(drop_rate)
167 | self.c2 = conv(2*n_channels, 2*n_channels, kernel_size)
168 | if n_cond_classes:
169 | self.proj_y = nn.Linear(n_cond_classes, 2*n_channels)
170 |
171 | def forward(self, x, a=None, y=None):
172 | c1 = self.c1(self.relu_fn(x))
173 | if a is not None: # shortcut connection if auxiliary input 'a' is given
174 | c1 = c1 + self.c1c(self.relu_fn(a))
175 | c1 = self.relu_fn(c1)
176 | if hasattr(self, 'dropout'):
177 | c1 = self.dropout(c1)
178 | c2 = self.c2(c1)
179 | if y is not None:
180 | c2 += self.proj_y(y)[:,:,None,None]
181 | a, b = c2.chunk(2,1)
182 | out = x + a * torch.sigmoid(b)
183 | return out
184 |
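# In GatedResidualLayer, concat_elu doubles the channel dimension (hence the 2*n_channels
# inputs to c1 / c2); c2's output is split into (a, b) and the layer returns x + a * sigmoid(b),
# a gated residual update, with the optional class projection proj_y added before the gate.
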
185 | def causal_attention(k, q, v, mask, nh, drop_rate, training):
186 | B, dq, H, W = q.shape
187 | _, dv, _, _ = v.shape
188 |
189 | # split channels into multiple heads, flatten H,W dims and scale q; out (B, nh, dkh or dvh, HW)
190 | flat_q = q.reshape(B, nh, dq//nh, H, W).flatten(3) * (dq//nh)**-0.5
191 | flat_k = k.reshape(B, nh, dq//nh, H, W).flatten(3)
192 | flat_v = v.reshape(B, nh, dv//nh, H, W).flatten(3)
193 |
194 | logits = torch.matmul(flat_q.transpose(2,3), flat_k) # (B,nh,HW,dq) dot (B,nh,dq,HW) = (B,nh,HW,HW)
195 | logits = F.dropout(logits, p=drop_rate, training=training, inplace=True)
196 | logits = logits.masked_fill(mask==0, -1e10)
197 | weights = F.softmax(logits, -1)
198 |
199 | attn_out = torch.matmul(weights, flat_v.transpose(2,3)) # (B,nh,HW,HW) dot (B,nh,HW,dvh) = (B,nh,HW,dvh)
200 | attn_out = attn_out.transpose(2,3) # (B,nh,dvh,HW)
201 | return attn_out.reshape(B, -1, H, W) # (B,dv,H,W)
202 |
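# With the PixelSNAIL defaults below (attn_nh=8, attn_dq=16, attn_dv=128), each of the 8 heads
# attends over the H*W flattened positions with dq//nh = 2 query/key channels and dv//nh = 16
# value channels per head; the mask restricts attention to strictly earlier (raster-order) positions.
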
203 | class AttentionGatedResidualLayer(nn.Module):
204 | def __init__(self, n_channels, n_background_ch, n_res_layers, n_cond_classes, drop_rate, nh, dq, dv, attn_drop_rate):
205 | super().__init__()
206 | # attn params
207 | self.nh = nh
208 | self.dq = dq
209 | self.dv = dv
210 | self.attn_drop_rate = attn_drop_rate
211 |
212 | self.input_gated_resnet = nn.ModuleList([
213 | *[GatedResidualLayer(DownRightShiftedConv2d, n_channels, (2,2), drop_rate, None, n_cond_classes) for _ in range(n_res_layers)]])
214 | self.in_proj_kv = nn.Sequential(GatedResidualLayer(Conv2d, 2*n_channels + n_background_ch, 1, drop_rate, None, n_cond_classes),
215 | Conv2d(2*n_channels + n_background_ch, dq+dv, 1))
216 | self.in_proj_q = nn.Sequential(GatedResidualLayer(Conv2d, n_channels + n_background_ch, 1, drop_rate, None, n_cond_classes),
217 | Conv2d(n_channels + n_background_ch, dq, 1))
218 | self.out_proj = GatedResidualLayer(Conv2d, n_channels, 1, drop_rate, dv, n_cond_classes)
219 |
220 | def forward(self, x, background, attn_mask, y=None):
221 | ul = x
222 | for m in self.input_gated_resnet:
223 | ul = m(ul, y=y)
224 |
225 | kv = self.in_proj_kv(torch.cat([x, ul, background], 1))
226 | k, v = kv.split([self.dq, self.dv], 1)
227 | q = self.in_proj_q(torch.cat([ul, background], 1))
228 | attn_out = causal_attention(k, q, v, attn_mask, self.nh, self.attn_drop_rate, self.training)
229 | return self.out_proj(ul, attn_out)
230 |
231 | class PixelSNAIL(nn.Module):
232 | def __init__(self, input_dims, n_channels, n_res_layers, n_out_stack_layers, n_cond_classes, n_bits,
233 | attn_n_layers=4, attn_nh=8, attn_dq=16, attn_dv=128, attn_drop_rate=0, drop_rate=0.5, **kwargs):
234 | super().__init__()
235 | H,W = input_dims[2]
236 | # init background
237 | background_v = ((torch.arange(H, dtype=torch.float) - H / 2) / 2).view(1,1,-1,1).expand(1,1,H,W)
238 | background_h = ((torch.arange(W, dtype=torch.float) - W / 2) / 2).view(1,1,1,-1).expand(1,1,H,W)
239 | self.register_buffer('background', torch.cat([background_v, background_h], 1))
240 | # init attention mask over current and future pixels
241 | attn_mask = torch.tril(torch.ones(1,1,H*W,H*W), diagonal=-1).byte() # 1s below diagonal -- attend to context only
242 | self.register_buffer('attn_mask', attn_mask)
243 |
244 | # input layers for `up` and `up and to the left` pixels
245 | self.ul_input_d = DownShiftedConv2d(2, n_channels, kernel_size=(1,3))
246 | self.ul_input_dr = DownRightShiftedConv2d(2, n_channels, kernel_size=(2,1))
247 | self.ul_modules = nn.ModuleList([
248 | *[AttentionGatedResidualLayer(n_channels, self.background.shape[1], n_res_layers, n_cond_classes, drop_rate,
249 | attn_nh, attn_dq, attn_dv, attn_drop_rate) for _ in range(attn_n_layers)]])
250 | self.output_stack = nn.Sequential(
251 | *[GatedResidualLayer(DownRightShiftedConv2d, n_channels, (2,2), drop_rate, None, n_cond_classes) \
252 | for _ in range(n_out_stack_layers)])
253 | self.output_conv = Conv2d(n_channels, 2**n_bits, kernel_size=1)
254 |
255 |
256 | def forward(self, x, y=None):
257 | # add channel of ones to distinguish image from padding later on
258 | x = F.pad(x, (0,0,0,0,0,1), value=1)
259 |
260 | ul = down_shift(self.ul_input_d(x)) + right_shift(self.ul_input_dr(x))
261 | for m in self.ul_modules:
262 | ul = m(ul, self.background.expand(x.shape[0],-1,-1,-1), self.attn_mask, y)
263 | ul = self.output_stack(ul)
264 | return self.output_conv(F.elu(ul)).unsqueeze(2) # out (B, 2**n_bits, 1, H, W)
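
# For example, with 128x128 inputs and an 8-entry codebook (n_bits=3), the top codes are 16x16:
# the top prior sees x of shape (B, 1, 16, 16) in [-1, 1] and returns logits of shape
# (B, 8, 1, 16, 16), which F.cross_entropy matches against integer targets of shape (B, 1, 16, 16).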
265 |
266 | # --------------------
267 | # PixelCNN -- bottom level prior conditioned on class labels and top level codes
268 | # --------------------
269 |
270 | def pixelcnn_gate(x):
271 | a, b = x.chunk(2,1)
272 | return torch.tanh(a) * torch.sigmoid(b)
273 |
274 | class MaskedConv2d(nn.Conv2d):
275 | def __init__(self, mask_type, *args, **kwargs):
276 | self.mask_type = mask_type
277 | super().__init__(*args, **kwargs)
278 |
279 | def apply_mask(self):
280 | H, W = self.kernel_size
281 | self.weight.data[:,:,H//2+1:,:].zero_() # mask out rows below the middle
282 | self.weight.data[:,:,H//2,W//2+1:].zero_() # mask out center row pixels right of middle
283 | if self.mask_type=='a':
284 | self.weight.data[:,:,H//2,W//2] = 0 # mask out center pixel
285 |
286 | def forward(self, x):
287 | self.apply_mask()
288 | return super().forward(x)
289 |
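# Illustration: for a 3x3 kernel, mask type 'a' (used for the input layer) keeps
#   [[1,1,1],
#    [1,0,0],
#    [0,0,0]]
# while any other mask type would also keep the center weight ([1,1,0] in the middle row),
# i.e. the standard PixelCNN 'a'/'b' masks.
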
290 | class GatedResidualBlock(nn.Module):
291 | """ Figure 2 in Conditional image generation with PixelCNN Decoders """
292 | def __init__(self, in_channels, out_channels, kernel_size, n_cond_channels, drop_rate):
293 | super().__init__()
294 | self.residual = (in_channels==out_channels)
295 | self.drop_rate = drop_rate
296 |
297 | self.v = nn.Conv2d(in_channels, 2*out_channels, kernel_size, padding=kernel_size//2) # vertical stack
298 | self.h = nn.Conv2d(in_channels, 2*out_channels, (1, kernel_size), padding=(0, kernel_size//2)) # horizontal stack
299 | self.v2h = nn.Conv2d(2*out_channels, 2*out_channels, kernel_size=1) # vertical to horizontal connection
300 | self.h2h = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False) # horizontal to horizontal
301 |
302 | if n_cond_channels:
303 | self.in_proj_y = nn.Conv2d(n_cond_channels, 2*out_channels, kernel_size=1)
304 |
305 | if self.drop_rate > 0:
306 | self.dropout_h = nn.Dropout(drop_rate)
307 |
308 | def apply_mask(self):
309 | self.v.weight.data[:,:,self.v.kernel_size[0]//2:,:].zero_() # mask out middle row and below
310 | self.h.weight.data[:,:,:,self.h.kernel_size[1]//2+1:].zero_() # mask out to the right of the central column
311 |
312 | def forward(self, x_v, x_h, y):
313 | self.apply_mask()
314 |
315 | # projection of y if included for conditional generation (cf paper section 2.3 -- added before the pixelcnn_gate)
316 | proj_y = self.in_proj_y(y)
317 |
318 | # vertical stack
319 | x_v_out = self.v(x_v)
320 | x_v2h = self.v2h(x_v_out) + proj_y
321 | x_v_out = pixelcnn_gate(x_v_out)
322 |
323 | # horizontal stack
324 | x_h_out = self.h(x_h) + x_v2h + proj_y
325 | x_h_out = pixelcnn_gate(x_h_out)
326 | if self.drop_rate:
327 | x_h_out = self.dropout_h(x_h_out)
328 | x_h_out = self.h2h(x_h_out)
329 |
330 | # residual connection
331 | if self.residual:
332 | x_h_out = x_h_out + x_h
333 |
334 | return x_v_out, x_h_out
335 |
336 | def extra_repr(self):
337 | return 'residual={}, drop_rate={}'.format(self.residual, self.drop_rate)
338 |
339 | class PixelCNN(nn.Module):
340 | def __init__(self, n_channels, n_out_conv_channels, kernel_size, n_res_layers, n_cond_stack_layers, n_cond_classes, n_bits,
341 | drop_rate=0, **kwargs):
342 | super().__init__()
343 | # conditioning layers (bottom prior conditioned on class labels and top-level code)
344 | self.in_proj_y = nn.Linear(n_cond_classes, 2*n_channels)
345 | self.in_proj_h = nn.ConvTranspose2d(1, n_channels, kernel_size=4, stride=2, padding=1) # upsample top codes to bottom-level spatial dim
346 | self.cond_layers = nn.ModuleList([
347 | GatedResidualLayer(partial(Conv2d, padding=kernel_size//2), n_channels, kernel_size, drop_rate, None, n_cond_classes) \
348 | for _ in range(n_cond_stack_layers)])
349 | self.out_proj_h = nn.Conv2d(n_channels, 2*n_channels, kernel_size=1) # double channels to apply pixelcnn_gate
350 |
351 | # pixelcnn layers
352 | self.input_conv = MaskedConv2d('a', 1, 2*n_channels, kernel_size=7, padding=3)
353 | self.res_layers = nn.ModuleList([
354 | GatedResidualBlock(n_channels, n_channels, kernel_size, 2*n_channels, drop_rate) for _ in range(n_res_layers)])
355 | self.conv_out1 = nn.Conv2d(n_channels, 2*n_out_conv_channels, kernel_size=1)
356 | self.conv_out2 = nn.Conv2d(n_out_conv_channels, 2*n_out_conv_channels, kernel_size=1)
357 | self.output = nn.Conv2d(n_out_conv_channels, 2**n_bits, kernel_size=1)
358 |
359 | def forward(self, x, h=None, y=None):
360 | # conditioning inputs -- h is top-level codes; y is class labels
361 | h = self.in_proj_h(h)
362 | for l in self.cond_layers:
363 | h = l(h, y=y)
364 | h = self.out_proj_h(h)
365 | y = self.in_proj_y(y)[:,:,None,None]
366 |
367 | # pixelcnn model
368 | x = pixelcnn_gate(self.input_conv(x) + h + y)
369 | x_v, x_h = x, x
370 | for l in self.res_layers:
371 | x_v, x_h = l(x_v, x_h, y)
372 | out = pixelcnn_gate(self.conv_out1(x_h))
373 | out = pixelcnn_gate(self.conv_out2(out))
374 | return self.output(out).unsqueeze(2) # (B, 2**n_bits, 1, H, W)
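
# Note: the bottom codes live at 1/4 of the image resolution and the top codes at 1/8 (see
# args.input_dims in __main__), so the stride-2 ConvTranspose2d in in_proj_h upsamples the
# conditioning top codes by exactly 2x to match the bottom-level spatial size.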
375 |
376 | # --------------------
377 | # Train and evaluate
378 | # --------------------
379 |
380 | def train_epoch(model, dataloader, optimizer, scheduler, epoch, writer, args):
381 | model.train()
382 |
383 | tic = time.time()
384 | if args.on_main_process: pbar = tqdm(total=len(dataloader), desc='epoch {}/{}'.format(epoch, args.start_epoch + args.n_epochs))
385 | for e1, e2, y in dataloader:
386 | args.step += args.world_size
387 |
388 | e1, e2, y = e1.to(args.device), e2.to(args.device), y.to(args.device)
389 |
390 | if args.which_prior == 'bottom':
391 | x = e1
392 | logits = model(preprocess(x, args.n_bits), preprocess(e2, args.n_bits), y)
393 | elif args.which_prior == 'top':
394 | x = e2
395 | logits = model(preprocess(x, args.n_bits), y)
396 | loss = F.cross_entropy(logits, x).mean(0)
397 |
398 | optimizer.zero_grad()
399 | loss.backward()
400 | nn.utils.clip_grad_value_(model.parameters(), 1)
401 | optimizer.step()
402 | if scheduler: scheduler.step()
403 |
404 | # record
405 | if args.on_main_process:
406 | pbar.set_postfix(loss='{:.4f}'.format(loss.item() / np.log(2)))
407 | pbar.update()
408 |
409 | if args.step % args.log_interval == 0 and args.on_main_process:
410 | writer.add_scalar('train_bits_per_dim', loss.item() / np.log(2), args.step)
411 |
412 | # save
413 | if args.step % args.save_interval == 0 and args.on_main_process:
414 | # save model
415 | torch.save({'epoch': epoch,
416 | 'global_step': args.step,
417 | 'state_dict': model.module.state_dict() if args.distributed else model.state_dict()},
418 | os.path.join(args.output_dir, 'checkpoint.pt'))
419 | torch.save(optimizer.state_dict(), os.path.join(args.output_dir, 'optim_checkpoint.pt'))
420 | if scheduler: torch.save(scheduler.state_dict(), os.path.join(args.output_dir, 'sched_checkpoint.pt'))
421 |
422 | if args.on_main_process: pbar.close()
423 |
424 | @torch.no_grad()
425 | def evaluate(model, dataloader, args):
426 | model.eval()
427 |
428 | losses = 0
429 | for e1, e2, y in dataloader:
430 | e1, e2, y = e1.to(args.device), e2.to(args.device), y.to(args.device)
431 | if args.which_prior == 'bottom':
432 | x = e1
433 | logits = model(preprocess(x, args.n_bits), preprocess(e2, args.n_bits), y)
434 | elif args.which_prior == 'top':
435 | x = e2
436 | logits = model(preprocess(x, args.n_bits), y)
437 | losses += F.cross_entropy(logits, x).mean(0).item()
438 | return losses / (len(dataloader) * np.log(2)) # to bits per dim
439 |
440 | def train_and_evaluate(model, vqvae, train_dataloader, valid_dataloader, optimizer, scheduler, writer, args):
441 | for epoch in range(args.start_epoch, args.start_epoch + args.n_epochs):
442 | train_epoch(model, train_dataloader, optimizer, scheduler, epoch, writer, args)
443 |
444 | if (epoch+1) % args.eval_interval == 0:
445 | # optimizer.use_ema(True)
446 |
447 | # evaluate
448 | eval_bpd = evaluate(model, valid_dataloader, args)
449 | if args.on_main_process:
450 | print('Evaluate bits per dim: {:.4f}'.format(eval_bpd))
451 | writer.add_scalar('eval_bits_per_dim', eval_bpd, args.step)
452 |
453 | # generate
454 | samples = generate_samples_in_training(model, vqvae, train_dataloader, args)
455 | samples = make_grid(samples, normalize=True, nrow=args.n_samples)
456 | if args.distributed:
457 | # collect samples tensor from all processes onto main process cpu
458 | tensors = [torch.empty(samples.shape, dtype=samples.dtype).cuda() for i in range(args.world_size)]
459 | torch.distributed.all_gather(tensors, samples)
460 | samples = torch.cat(tensors, 2)
461 | if args.on_main_process:
462 | samples = samples.cpu()
463 | writer.add_image('samples_' + args.which_prior, samples, args.step)
464 | save_image(samples, os.path.join(args.output_dir, 'samples_{}_step_{}.png'.format(args.which_prior, args.step)))
465 |
466 | # optimizer.use_ema(False)
467 |
468 | if args.on_main_process:
469 | # save model
470 | torch.save({'epoch': epoch,
471 | 'global_step': args.step,
472 | 'state_dict': model.module.state_dict() if args.distributed else model.state_dict()},
473 | os.path.join(args.output_dir, 'checkpoint.pt'))
474 | torch.save(optimizer.state_dict(), os.path.join(args.output_dir, 'optim_checkpoint.pt'))
475 | if scheduler: torch.save(scheduler.state_dict(), os.path.join(args.output_dir, 'sched_checkpoint.pt'))
476 |
477 |
478 | # --------------------
479 | # Sample and generate
480 | # --------------------
481 |
482 | def sample_prior(model, h, y, n_samples, input_dims, n_bits):
483 | model.eval()
484 |
485 | H,W = input_dims
486 | out = torch.zeros(n_samples, 1, H, W, device=next(model.parameters()).device)
487 | if args.on_main_process: pbar = tqdm(total=H*W, desc='Generating {} images'.format(n_samples))
488 | for hi in range(H):
489 | for wi in range(W):
490 | logits = model(out, y) if h is None else model(out, h, y)
491 | probs = F.softmax(logits, dim=1)
492 | sample = torch.multinomial(probs[:,:,:,hi,wi].squeeze(2), 1)
493 | out[:,:,hi,wi] = preprocess(sample, n_bits) # multinomial samples long tensor in [0, 2**n_bits), convert back to model space [-1,1]
494 | if args.on_main_process: pbar.update()
495 | del logits, probs, sample
496 | if args.on_main_process: pbar.close()
497 | return deprocess(out, n_bits) # out (B,1,H,W) field of latents in latent space [0, 2**n_bits)
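
# Note: sampling is sequential over the latent grid -- each of the H*W positions requires a full
# forward pass of the prior, so generating the top codes and then the bottom codes costs
# H_top*W_top + H_bottom*W_bottom model evaluations per batch of samples.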
498 |
499 | @torch.no_grad()
500 | def generate(vqvae, bottom_model, top_model, args, ys=None):
501 | samples = []
502 | for y in ys.unsqueeze(1): # condition on class one-hot labels; (n_samples, 1, n_cond_classes) when sliced on dim 0 returns (1,n_cond_classes)
503 | # sample top prior conditioned on class labels y
504 | top_samples = sample_prior(top_model, None, y, args.n_samples, args.input_dims[2], args.n_bits)
505 | # sample bottom prior conditioned on top_sample codes and class labels y
506 | bottom_samples = sample_prior(bottom_model, preprocess(top_samples, args.n_bits), y, args.n_samples, args.input_dims[1], args.n_bits)
507 | # decode
508 | samples += [vqvae.decode(None, vqvae.embed((bottom_samples, top_samples)))]
509 | samples = torch.cat(samples)
510 | return make_grid(samples, normalize=True, scale_each=True)
511 |
512 | def generate_samples_in_training(model, vqvae, dataloader, args):
513 | if args.which_prior == 'top':
514 | # zero out bottom samples so no contribution
515 | bottom_samples = torch.zeros(args.n_samples*(args.n_cond_classes+1),1,*args.input_dims[1], dtype=torch.long)
516 | # sample top prior
517 | top_samples = []
518 | for y in torch.eye(args.n_cond_classes + 1, args.n_cond_classes).unsqueeze(1).to(args.device): # note eg: torch.eye(3,2) = [[1,0],[0,1],[0,0]]
519 | top_samples += [sample_prior(model, None, y, args.n_samples, args.input_dims[2], args.n_bits).cpu()]
520 | top_samples = torch.cat(top_samples)
521 | # decode
522 | samples = vqvae.decode(z_e=None, z_q=vqvae.embed((bottom_samples.to(args.device), top_samples.to(args.device))))
523 |
524 | elif args.which_prior == 'bottom': # level 1
525 | # use the dataset ground truth top codes and only sample the bottom
526 | bottom_gt, top_gt, y = next(iter(dataloader)) # take e2 and y from dataloader output (e1,e2,y)
527 | bottom_gt, top_gt, y = bottom_gt[:args.n_samples].to(args.device), top_gt[:args.n_samples].to(args.device), y[:args.n_samples].to(args.device)
528 | # sample bottom prior
529 | bottom_samples = sample_prior(model, preprocess(top_gt, args.n_bits), y, args.n_samples, args.input_dims[1], args.n_bits)
530 | # decode
531 | # stack (1) recon using bottom+top actual latents,
532 | # (2) recon using top latents only,
533 | # (3) recon using top latent and bottom prior samples
534 | recon_actuals = vqvae.decode(z_e=None, z_q=vqvae.embed((bottom_gt, top_gt)))
535 | recon_top = vqvae.decode(z_e=None, z_q=vqvae.embed((bottom_gt.fill_(0), top_gt)))
536 | recon_samples = vqvae.decode(z_e=None, z_q=vqvae.embed((bottom_samples, top_gt)))
537 | samples = torch.cat([recon_actuals, recon_top, recon_samples])
538 |
539 | return samples
540 |
541 |
542 | # --------------------
543 | # Main
544 | # --------------------
545 |
546 | if __name__ == '__main__':
547 | args = parser.parse_args()
548 | if args.restore_dir and args.which_prior:
549 | args.output_dir = args.restore_dir[0]
550 | if not args.output_dir: # if not given or not set by restore_dir use results/file_name/time_stamp
551 | # name experiment 'vqvae_[vqvae_dir]_prior_[prior_args]_[timestamp]'
552 | exp_name = 'vqvae_' + args.vqvae_dir.strip('/').rpartition('/')[2] + \
553 | '_prior_{which_prior}' + args.mini_data*'_mini{}'.format(args.batch_size) + \
554 | '_b{batch_size}_c{n_channels}_outc{n_out_conv_channels}_nres{n_res_layers}_condstack{n_cond_stack_layers}' + \
555 | '_outstack{n_out_stack_layers}_drop{drop_rate}' + \
556 | '_{}'.format(time.strftime('%Y-%m-%d_%H-%M', time.gmtime()))
557 | args.output_dir = './results/{}/{}'.format(os.path.splitext(__file__)[0], exp_name.format(**args.__dict__))
558 | os.makedirs(args.output_dir, exist_ok=True)
559 |
560 | # setup device and distributed training
561 | if args.distributed:
562 | args.cuda = int(os.environ['LOCAL_RANK'])
563 | args.world_size = int(os.environ['WORLD_SIZE'])
564 | torch.cuda.set_device(args.cuda)
565 | args.device = 'cuda:{}'.format(args.cuda)
566 |
567 | # initialize
568 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
569 | else:
570 | args.device = 'cuda:{}'.format(args.cuda) if args.cuda is not None and torch.cuda.is_available() else 'cpu'
571 |
572 | # write ops only when on_main_process
573 | args.on_main_process = (args.distributed and args.cuda == 0) or not args.distributed
574 |
575 | # setup seed
576 | torch.manual_seed(args.seed)
577 | np.random.seed(args.seed)
578 | torch.backends.cudnn.deterministic = True
579 | torch.backends.cudnn.benchmark = False
580 |
581 | # load vqvae
582 | # load config; extract bits and input sizes throughout the hierarchy from the vqvae config
583 | vqvae_config = load_json(os.path.join(args.vqvae_dir, 'config.json'))
584 | img_dims = vqvae_config['input_dims'][1:]
585 | args.input_dims = [img_dims, [img_dims[0]//4, img_dims[1]//4], [img_dims[0]//8, img_dims[1]//8]]
586 | args.n_bits = int(np.log2(vqvae_config['n_embeddings']))
587 | args.dataset = vqvae_config['dataset']
588 | args.data_dir = vqvae_config['data_dir']
589 | # load model
590 | vqvae, _, _ = load_model(VQVAE2, vqvae_config, args.vqvae_dir, args, restore=True, eval_mode=True, verbose=args.on_main_process)
591 | # reset start_epoch and step after model loading
592 | args.start_epoch, args.step = 0, 0
593 | # expose functions
594 | if args.distributed:
595 | vqvae.encode = vqvae.module.encode
596 | vqvae.decode = vqvae.module.decode
597 | vqvae.embed = vqvae.module.embed
598 |
599 |
600 | # load prior model
601 | # save prior config to feed to load_model
602 | if not os.path.exists(os.path.join(args.output_dir, 'config_{}.json'.format(args.cuda))):
603 | save_json(args.__dict__, 'config_{}'.format(args.cuda), args)
604 | # load model + optimizers, scheduler if training
605 | if args.which_prior:
606 | model, optimizer, scheduler = load_model(PixelCNN if args.which_prior=='bottom' else PixelSNAIL,
607 | config=args.__dict__,
608 | model_dir=args.output_dir,
609 | args=args,
610 | restore=(args.restore_dir is not None),
611 | eval_mode=False,
612 | optimizer_cls=partial(RMSprop,
613 | lr=args.lr,
614 | polyak=args.polyak),
615 | scheduler_cls=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=args.lr_decay),
616 | verbose=args.on_main_process)
617 | else:
618 | assert args.restore_dir and len(args.restore_dir)==2, '`restore_dir` should specify the restore dirs of the bottom prior and the top prior (in that order).'
619 | # load both top and bottom to generate
620 | restore_bottom, restore_top = args.restore_dir
621 | bottom_model, _, _ = load_model(PixelCNN, config=None, model_dir=restore_bottom, args=args, restore=True, eval_mode=True,
622 | optimizer_cls=partial(RMSprop, lr=args.lr, polyak=args.polyak))
623 | top_model, _, _ = load_model(PixelSNAIL, config=None, model_dir=restore_top, args=args, restore=True, eval_mode=True,
624 | optimizer_cls=partial(RMSprop, lr=args.lr, polyak=args.polyak))
625 |
626 | # save and print config and setup writer on main process
627 | writer = None
628 | if args.on_main_process:
629 | pprint.pprint(args.__dict__)
630 | writer = SummaryWriter(log_dir = args.output_dir)
631 | writer.add_text('config', str(args.__dict__))
632 |
633 | if args.train:
634 | assert args.which_prior is not None, 'Must specify `which_prior` to train.'
635 | train_dataloader = fetch_prior_dataloader(vqvae, args, True)
636 | valid_dataloader = fetch_prior_dataloader(vqvae, args, False)
637 | train_and_evaluate(model, vqvae, train_dataloader, valid_dataloader, optimizer, scheduler, writer, args)
638 |
639 | if args.evaluate:
640 | assert args.which_prior is not None, 'Must specify `which_prior` to evaluate.'
641 | valid_dataloader = fetch_prior_dataloader(vqvae, args, False)
642 | # optimizer.use_ema(True)
643 | eval_bpd = evaluate(model, valid_dataloader, args)
644 | if args.on_main_process:
645 | print('Evaluate bits per dim: {:.4f}'.format(eval_bpd))
646 |
647 | if args.generate:
648 | assert args.which_prior is None, 'Remove `which_prior` to load both priors and generate'
649 | # optimizer.use_ema(True)
650 | samples = generate(vqvae, bottom_model, top_model, args, ys=torch.eye(args.n_cond_classes + 1, args.n_cond_classes).to(args.device))
651 | if args.distributed:
652 | torch.manual_seed(args.rank)
653 | # collect samples tensor from all processes onto main process cpu
654 | tensors = [torch.empty(samples.shape, dtype=samples.dtype).cuda() for i in range(args.world_size)]
655 | torch.distributed.all_gather(tensors, samples)
656 | samples = torch.cat(tensors, 2)
657 | if args.on_main_process:
658 | samples = samples.cpu()
659 | writer.add_image('samples', samples, args.step)
660 | save_image(samples, os.path.join(args.output_dir, 'generation_sample_step_{}.png'.format(args.step)))
661 |
662 |
--------------------------------------------------------------------------------