├── .gitignore ├── air.py ├── basic_vae.py ├── datasets └── chexpert.py ├── dcgan.py ├── draw.py ├── draw_test_attn.py ├── environment.yml ├── images ├── air │ ├── air_count.png │ ├── air_elbo.png │ └── image_recons_270.png ├── basic_vae │ ├── reconstruction_at_epoch_24.png │ ├── sample_at_epoch_24.png │ └── tsne_embedding.png ├── dcgan │ └── latent_var_grid_sample_c1.png ├── draw │ ├── draw_fig_3.png │ ├── draw_fig_4.png │ ├── elephant.png │ └── generated_32_time_steps.gif ├── infogan │ ├── latent_var_grid_sample_c1.png │ └── latent_var_grid_sample_c2.png ├── ssvae │ ├── analogies_sample.png │ ├── latent_var_grid_sample_c1_y2.png │ └── latent_var_grid_sample_c2_y4.png └── vqvae2 │ ├── 128x128_bits3_eval_reconstruction_step_87300_bottom.png │ ├── 128x128_bits3_eval_reconstruction_step_87300_original.png │ ├── 128x128_bits3_eval_reconstruction_step_87300_top.png │ └── generation_sample_step_52440_top_b128_c128_outstack10_bottom_b16_c128_nres20_condstack10.png ├── infogan.py ├── optim.py ├── readme.md ├── ssvae.py ├── utils.py ├── vqvae.py └── vqvae_prior.py /.gitignore: -------------------------------------------------------------------------------- 1 | results/ 2 | __pycache__/ 3 | data 4 | *.pyc 5 | -------------------------------------------------------------------------------- /air.py: -------------------------------------------------------------------------------- 1 | """ 2 | Attend, Infer, Repeat: 3 | Fast Scene Understanding with Generative Models 4 | https://arxiv.org/pdf/1603.08575v2.pdf 5 | """ 6 | 7 | import os 8 | import argparse 9 | import pprint 10 | import time 11 | from tqdm import tqdm 12 | 13 | import numpy as np 14 | from observations import multi_mnist 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import torch.distributions as D 20 | from torch.utils.data import DataLoader, Dataset 21 | from torchvision.utils import save_image, make_grid 22 | from tensorboardX import SummaryWriter 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | 27 | # actions 28 | parser.add_argument('--train', action='store_true', help='Train a model.') 29 | parser.add_argument('--evaluate', action='store_true', help='Evaluate a model.') 30 | parser.add_argument('--generate', action='store_true', help='Generate samples from a model.') 31 | parser.add_argument('--restore_file', type=str, help='Path to model to restore.') 32 | parser.add_argument('--data_dir', default='./data/', help='Location of datasets.') 33 | parser.add_argument('--output_dir', default='./results/{}'.format(os.path.splitext(__file__)[0])) 34 | parser.add_argument('--seed', type=int, default=2182019, help='Random seed to use.') 35 | parser.add_argument('--cuda', type=int, default=None, help='Which cuda device to use') 36 | parser.add_argument('--verbose', '-v', action='count', help='Verbose mode; send gradient stats to tensorboard.') 37 | # model params 38 | parser.add_argument('--image_dims', type=tuple, default=(1,50,50), help='Dimensions of a single datapoint (e.g. (1,50,50) for multi MNIST).') 39 | parser.add_argument('--z_what_size', type=int, default=50, help='Size of the z_what latent representation.') 40 | parser.add_argument('--z_where_size', type=int, default=3, help='Size of the z_where latent representation e.g. dim=3 for (s, tx, ty) affine parametrization.') 41 | parser.add_argument('--z_pres_size', type=int, default=1, help='Size of the z_pres latent representation, e.g. 
dim=1 for the probability of occurence of an object.') 42 | parser.add_argument('--enc_dec_size', type=int, default=200, help='Size of the encoder and decoder hidden layers.') 43 | parser.add_argument('--lstm_size', type=int, default=256, help='Size of the LSTM hidden layer for AIR.') 44 | parser.add_argument('--baseline_lstm_size', type=int, default=256, help='Size of the LSTM hidden layer for the gradient baseline estimator.') 45 | parser.add_argument('--attn_window_size', type=int, default=28, help='Size of the attention window of the decoder.') 46 | parser.add_argument('--max_steps', type=int, default=3, help='Maximum number of objects per image to sample a binomial from.') 47 | parser.add_argument('--likelihood_sigma', type=float, default=0.3, help='Sigma parameter for the likelihood function (a Normal distribution).') 48 | parser.add_argument('--z_pres_prior_success_prob', type=float, default=0.75, help='Prior probability of success for the num objects per image prior.') 49 | parser.add_argument('--z_pres_anneal_start_step', type=int, default=1000, help='Start step to begin annealing the num objects per image prior.') 50 | parser.add_argument('--z_pres_anneal_end_step', type=int, default=100000, help='End step to stop annealing the num objects per image prior.') 51 | parser.add_argument('--z_pres_anneal_start_value', type=float, default=0.99, help='Initial probability of success for the num objects per image prior.') 52 | parser.add_argument('--z_pres_anneal_end_value', type=float, default=1e-5, help='Final probility of successs value for the num objects per image prior.') 53 | parser.add_argument('--z_pres_init_encoder_bias', type=float, default=2., help='Add bias to the initialization of the z_pres encoder.') 54 | parser.add_argument('--decoder_bias', type=float, default=-2., help='Add preactivation bias to decoder.') 55 | # training params 56 | parser.add_argument('--batch_size', type=int, default=64) 57 | parser.add_argument('--n_epochs', type=int, default=1, help='Number of epochs to train.') 58 | parser.add_argument('--start_epoch', default=0, help='Starting epoch (for logging; to be overwritten when restoring file.') 59 | parser.add_argument('--model_lr', type=float, default=1e-4, help='Learning rate for AIR.') 60 | parser.add_argument('--baseline_lr', type=float, default=1e-3, help='Learning rate for the gradient baseline estimator.') 61 | parser.add_argument('--log_interval', type=int, default=100, help='Write loss and parameter stats to tensorboard.') 62 | parser.add_argument('--eval_interval', type=int, default=10, help='Number of epochs to eval model and save checkpoint.') 63 | parser.add_argument('--mini_data_size', type=int, default=None, help='Train only on this number of datapoints.') 64 | 65 | # -------------------- 66 | # Data 67 | # -------------------- 68 | 69 | class MultiMNIST(Dataset): 70 | def __init__(self, root, training=True, download=True, max_digits=2, canvas_size=50, seed=42, mini_data_size=None): 71 | self.root = os.path.expanduser(root) 72 | 73 | # check if multi mnist already compiled 74 | self.multi_mnist_filename = 'multi_mnist_{}_{}_{}'.format(max_digits, canvas_size, seed) 75 | 76 | if not self._check_processed_exists(): 77 | if self._check_raw_exists(): 78 | # process into pt file 79 | data = np.load(os.path.join(self.root, 'raw', self.multi_mnist_filename + '.npz')) 80 | train_data, train_labels, test_data, test_labels = [data[f] for f in data.files] 81 | self._process_and_save(train_data, train_labels, test_data, test_labels) 82 | else: 83 | if 
not download: 84 | raise RuntimeError('Dataset not found. Use download=True to download it.') 85 | else: 86 | (train_data, train_labels), (test_data, test_labels) = multi_mnist(root, max_digits, canvas_size, seed) 87 | self._process_and_save(train_data, train_labels, test_data, test_labels) 88 | else: 89 | data = torch.load(os.path.join(self.root, 'processed', self.multi_mnist_filename + '.pt')) 90 | self.train_data, self.train_labels, self.test_data, self.test_labels = \ 91 | data['train_data'], data['train_labels'], data['test_data'], data['test_labels'] 92 | 93 | if training: 94 | self.x, self.y = self.train_data, self.train_labels 95 | else: 96 | self.x, self.y = self.test_data, self.test_labels 97 | 98 | if mini_data_size != None: 99 | self.x = self.x[:mini_data_size] 100 | self.y = self.y[:mini_data_size] 101 | 102 | def __getitem__(self, idx): 103 | return self.x[idx].unsqueeze(0), self.y[idx] 104 | 105 | def __len__(self): 106 | return len(self.x) 107 | 108 | def _check_processed_exists(self): 109 | return os.path.exists(os.path.join(self.root, 'processed', self.multi_mnist_filename + '.pt')) 110 | 111 | def _check_raw_exists(self): 112 | return os.path.exists(os.path.join(self.root, 'raw', self.multi_mnist_filename + '.npz')) 113 | 114 | def _make_label_tensor(self, label_arr): 115 | out = torch.zeros(10) 116 | for l in label_arr: 117 | out[l] += 1 118 | return out 119 | 120 | def _process_and_save(self, train_data, train_labels, test_data, test_labels): 121 | self.train_data = torch.from_numpy(train_data).float() / 255 122 | self.train_labels = torch.stack([self._make_label_tensor(label) for label in train_labels]) 123 | self.test_data = torch.from_numpy(test_data).float() / 255 124 | self.test_labels = torch.stack([self._make_label_tensor(label) for label in test_labels]) 125 | # check folder exists 126 | if not os.path.exists(os.path.join(self.root, 'processed')): 127 | os.makedirs(os.path.join(self.root, 'processed')) 128 | with open(os.path.join(self.root, 'processed', self.multi_mnist_filename + '.pt'), 'wb') as f: 129 | torch.save({'train_data': self.train_data, 130 | 'train_labels': self.train_labels, 131 | 'test_data': self.test_data, 132 | 'test_labels': self.test_labels}, 133 | f) 134 | 135 | def fetch_dataloaders(args): 136 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.device.type is 'cuda' else {} 137 | dataset = MultiMNIST(root=args.data_dir, training=True, mini_data_size=args.mini_data_size) 138 | train_dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs) 139 | dataset = MultiMNIST(root=args.data_dir, training=False if args.mini_data_size is None else True, mini_data_size=args.mini_data_size) 140 | test_dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=True, **kwargs) 141 | return train_dataloader, test_dataloader 142 | 143 | 144 | # -------------------- 145 | # Model helper functions -- spatial tranformer 146 | # -------------------- 147 | 148 | def stn(image, z_where, out_dims, inverse=False, box_attn_window_color=None): 149 | """ spatial transformer network used to scale and shift input according to z_where in: 150 | 1/ x -> x_att -- shapes (H, W) -> (attn_window, attn_window) -- thus inverse = False 151 | 2/ y_att -> y -- (attn_window, attn_window) -> (H, W) -- thus inverse = True 152 | 153 | inverting the affine transform as follows: A_inv ( A * image ) = image 154 | A = [R | T] where R is rotation component of angle alpha, T is [tx, ty] translation component 155 | 
A_inv rotates by -alpha and translates by [-tx, -ty] 156 | 157 | if x' = R * x + T --> x = R_inv * (x' - T) = R_inv * x - R_inv * T 158 | 159 | here, z_where is 3-dim [scale, tx, ty] so inverse transform is [1/scale, -tx/scale, -ty/scale] 160 | R = [[s, 0], -> R_inv = [[1/s, 0], 161 | [0, s]] [0, 1/s]] 162 | """ 163 | 164 | if box_attn_window_color is not None: 165 | # draw a box around the attention window by overwriting the boundary pixels in the given color channel 166 | with torch.no_grad(): 167 | box = torch.zeros_like(image.expand(-1,3,-1,-1)) 168 | c = box_attn_window_color % 3 # write the color bbox in channel c, as model time steps 169 | box[:,c,:,0] = 1 170 | box[:,c,:,-1] = 1 171 | box[:,c,0,:] = 1 172 | box[:,c,-1,:] = 1 173 | # add box to image and clap at 1 if overlap 174 | image = torch.clamp(image + box, 0, 1) 175 | 176 | # 1. construct 2x3 affine matrix for each datapoint in the minibatch 177 | theta = torch.zeros(2,3).repeat(image.shape[0], 1, 1).to(image.device) 178 | # set scaling 179 | theta[:, 0, 0] = theta[:, 1, 1] = z_where[:,0] if not inverse else 1 / (z_where[:,0] + 1e-9) 180 | # set translation 181 | theta[:, :, -1] = z_where[:, 1:] if not inverse else - z_where[:,1:] / (z_where[:,0].view(-1,1) + 1e-9) 182 | # 2. construct sampling grid 183 | grid = F.affine_grid(theta, torch.Size(out_dims)) 184 | # 3. sample image from grid 185 | return F.grid_sample(image, grid) 186 | 187 | 188 | # -------------------- 189 | # Model helper functions -- distribution manupulations 190 | # -------------------- 191 | 192 | def compute_geometric_from_bernoulli(obj_probs): 193 | """ compute a normalized truncated geometric distribution from a table of bernoulli probs 194 | args 195 | obj_probs -- tensor of shape (N, max_steps) of Bernoulli success probabilities. 
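A worked example with illustrative values (not taken from the paper): for a single row
obj_probs = [0.9, 0.8, 0.5], the truncated geometric over object counts {0, 1, 2, 3} is
[1-0.9, 0.9*(1-0.8), 0.9*0.8*(1-0.5), 0.9*0.8*0.5] = [0.10, 0.18, 0.36, 0.36],
which already sums to 1, so the final normalization leaves it unchanged.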
196 | """ 197 | cum_succ_probs = obj_probs.cumprod(1) 198 | fail_probs = 1 - obj_probs 199 | geom = torch.cat([fail_probs[:,:1], fail_probs[:,1:] * cum_succ_probs[:,:-1], cum_succ_probs[:,-1:]], dim=1) 200 | return geom / geom.sum(1, True) 201 | 202 | def compute_z_pres_kl(q_z_pres_geom, p_z_pres, writer=None): 203 | """ compute kl divergence between truncated geom prior and tabular geom posterior 204 | args 205 | p_z_pres -- torch.distributions.Geometric object 206 | q_z_pres_geom -- torch tensor of shape (N, max_steps + 1) of a normalized geometric pdf 207 | """ 208 | # compute normalized truncated geometric 209 | p_z_pres_log_probs = p_z_pres.log_prob(torch.arange(q_z_pres_geom.shape[1], dtype=torch.float, device=q_z_pres_geom.device)) 210 | p_z_pres_normed_log_probs = p_z_pres_log_probs - p_z_pres_log_probs.logsumexp(dim=0) 211 | 212 | kl = q_z_pres_geom * (torch.log(q_z_pres_geom + 1e-8) - p_z_pres_normed_log_probs.expand_as(q_z_pres_geom)) 213 | return kl 214 | 215 | def anneal_z_pres_prob(prob, step, args): 216 | if args.z_pres_anneal_start_step < step < args.z_pres_anneal_end_step: 217 | slope = (args.z_pres_anneal_end_value - args.z_pres_anneal_start_value) / (args.z_pres_anneal_end_step - args.z_pres_anneal_start_step) 218 | prob = torch.tensor(args.z_pres_anneal_start_value + slope * (step - args.z_pres_anneal_start_step), device=prob.device) 219 | return prob 220 | 221 | 222 | # -------------------- 223 | # Model 224 | # -------------------- 225 | 226 | class AIR(nn.Module): 227 | def __init__(self, args): 228 | super().__init__() 229 | self.debug = False 230 | # record dims 231 | self.C, self.H, self.W = args.image_dims 232 | self.A = args.attn_window_size 233 | x_size = self.C * self.H * self.W 234 | self.lstm_size = args.lstm_size 235 | self.baseline_lstm_size = args.baseline_lstm_size 236 | self.z_what_size = args.z_what_size 237 | self.z_where_size = args.z_where_size 238 | self.max_steps = args.max_steps 239 | 240 | # -------------------- 241 | # p model -- cf AIR paper section 2 242 | # -------------------- 243 | 244 | # latent variable priors 245 | # z_pres ~ Ber(p) Geom(rho) discrete representation for the presence of a scene object 246 | # z_where ~ N(mu, scale); continuous 3-dim variable for pose (position and scale) 247 | # z_what ~ N(0,1); continuous representation for shape 248 | self.register_buffer('z_pres_prior', torch.tensor(args.z_pres_prior_success_prob)) # prior used for generation 249 | self.register_buffer('z_pres_prob', torch.tensor(args.z_pres_anneal_start_value)) # `current value` used for training and annealing 250 | self.register_buffer('z_what_mean', torch.zeros(args.z_what_size)) 251 | self.register_buffer('z_what_scale', torch.ones(args.z_what_size)) 252 | self.register_buffer('z_where_mean', torch.tensor([0.3, 0., 0.])) 253 | self.register_buffer('z_where_scale', torch.tensor([0.1, 1., 1.])) 254 | 255 | # likelihood = N(mu, sigma) 256 | self.register_buffer('likelihood_sigma', torch.tensor(args.likelihood_sigma)) 257 | 258 | # likelihood p(x|n,z) of the data given the latents 259 | self.decoder = nn.Sequential(nn.Linear(args.z_what_size, args.enc_dec_size), 260 | nn.ReLU(True), 261 | nn.Linear(args.enc_dec_size, self.C * self.A ** 2)) 262 | self.decoder_bias = args.decoder_bias # otherwise initial samples are heavily penalized by likelihood (cf Pyro implementation) 263 | 264 | # -------------------- 265 | # q model for approximating the posterior -- cf AIR paper section 2.1 266 | # -------------------- 267 | 268 | # encoder 269 | # rnn encodes 
the latents z_1:t over the number of steps where z_pres indicates presence of an object 270 | # q_z_pres encodes whether there is an object present in the image; q_z_pres = Bernoulli 271 | # q_z_what encodes the attention window; q_z_what = Normal(mu, sigma) 272 | # q_z_where encodes the affine transform of of the image > attn_window; q_z_where = Normal(0, cov) of dim = 3 for [scale, tx, ty] 273 | self.encoder = nn.ModuleDict({ 274 | 'rnn': nn.LSTMCell(x_size + args.z_where_size + args.z_what_size + args.z_pres_size, args.lstm_size), 275 | 'z_pres': nn.Linear(args.lstm_size, 1), 276 | 'z_what': nn.Sequential(nn.Linear(self.A ** 2 , args.enc_dec_size), 277 | nn.ReLU(True), 278 | nn.Linear(args.enc_dec_size, 2 * args.z_what_size)), 279 | 'z_where': nn.Linear(args.lstm_size, 2 * args.z_where_size)}) 280 | 281 | nn.init.constant_(self.encoder.z_pres.bias, args.z_pres_init_encoder_bias) # push initial num time steps probs higher 282 | 283 | # initialize STN to identity 284 | self.encoder.z_where.weight.data.zero_() 285 | self.encoder.z_where.bias.data = torch.cat([torch.zeros(args.z_where_size), -1.*torch.ones(args.z_where_size)],dim=0) 286 | 287 | # -------------------- 288 | # Baseline model for NVIL per Mnih & Gregor 289 | # -------------------- 290 | 291 | self.baseline = nn.ModuleDict({ 292 | 'rnn': nn.LSTMCell(x_size + args.z_where_size + args.z_what_size + args.z_pres_size, args.baseline_lstm_size), 293 | 'linear': nn.Linear(args.baseline_lstm_size, 1)}) 294 | 295 | @property 296 | def p_z_pres(self): 297 | return D.Geometric(probs=1-self.z_pres_prob) 298 | 299 | @property 300 | def p_z_what(self): 301 | return D.Normal(self.z_what_mean, self.z_what_scale) 302 | 303 | @property 304 | def p_z_where(self): 305 | return D.Normal(self.z_where_mean, self.z_where_scale) 306 | 307 | def forward(self, x, writer=None, box_attn_window_color=None): 308 | """ cf AIR paper Figure 3 (right) for model flow. 
309 | Computes (1) inference for z latents; 310 | (2) data reconstruction given the latents; 311 | (3) baseline for decreasing gradient variance; 312 | (4) losses 313 | Returns 314 | recon_x -- tensor of shape (B, C, H, W); reconstruction of data 315 | pred_counts -- teonsor of shape (B,); predicted number of object for each data point 316 | elbo -- tensor of shape (B,); variational lower bound 317 | loss -- tensor of shape (0) of the scalar objective loss 318 | baseline loss -- tensor of shape (0) of the scalar baseline loss (cf Mnih & Gregor NVIL) 319 | """ 320 | batch_size = x.shape[0] 321 | device = x.device 322 | 323 | # store for elbo computation 324 | pred_counts = torch.zeros(batch_size, self.max_steps, device=device) # store for object count accuracy 325 | obj_probs = torch.ones(batch_size, self.max_steps, device=device) # store for computing the geometric posterior 326 | baseline = torch.zeros(batch_size, device=device) 327 | kl_z_pres = torch.zeros(batch_size, device=device) 328 | kl_z_what = torch.zeros(batch_size, device=device) 329 | kl_z_where = torch.zeros(batch_size, device=device) 330 | 331 | # initialize canvas, encoder rnn, states of the latent variables, mask for z_pres, baseline rnn 332 | recon_x = torch.zeros(batch_size, 3 if box_attn_window_color is not None else self.C, self.H, self.W, device=device) 333 | h_enc = torch.zeros(batch_size, self.lstm_size, device=device) 334 | c_enc = torch.zeros_like(h_enc) 335 | z_pres = torch.ones(batch_size, 1, device=device) 336 | z_what = torch.zeros(batch_size, self.z_what_size, device=device) 337 | z_where = torch.rand(batch_size, self.z_where_size, device=device) 338 | h_baseline = torch.zeros(batch_size, self.baseline_lstm_size, device=device) 339 | c_baseline = torch.zeros_like(h_baseline) 340 | 341 | # run model forward up to a max number of reconstruction steps 342 | for i in range(self.max_steps): 343 | 344 | # -------------------- 345 | # Inference step -- AIR paper fig3 middle. 346 | # 1. compute 1-dimensional Bernoulli variable indicating the entity’s presence 347 | # 2. compute 3-dimensional vector specifying the affine parameters of its position and scale (ziwhere). 348 | # 3. compute C-dimensional distributed vector describing its class or appearance (ziwhat) 349 | # -------------------- 350 | 351 | # rnn encoder 352 | h_enc, c_enc = self.encoder.rnn(torch.cat([x, z_pres, z_what, z_where], dim=-1), (h_enc, c_enc)) 353 | 354 | # 1. compute 1-dimensional Bernoulli variable indicating the entity’s presence; note: if z_pres == 0, subsequent mask are zeroed 355 | q_z_pres = D.Bernoulli(probs = torch.clamp(z_pres * torch.sigmoid(self.encoder.z_pres(h_enc)), 1e-5, 1 - 1e-5)) # avoid probs that are exactly 0 or 1 356 | z_pres = q_z_pres.sample() 357 | 358 | # 2. compute 3-dimensional vector specifying the affine parameters of its position and scale (ziwhere). 359 | q_z_where_mean, q_z_where_scale = self.encoder.z_where(h_enc).chunk(2, -1) 360 | q_z_where = D.Normal(q_z_where_mean + self.z_where_mean, F.softplus(q_z_where_scale) * self.z_where_scale) 361 | z_where = q_z_where.rsample() 362 | 363 | # attend to a part of the image (using a spatial transformer) to produce x_i_att 364 | x_att = stn(x.view(batch_size, self.C, self.H, self.W), z_where, (batch_size, self.C, self.A, self.A), inverse=False) 365 | 366 | # 3. 
compute C-dimensional distributed vector describing its class or appearance (ziwhat) 367 | q_z_what_mean, q_z_what_scale = self.encoder.z_what(x_att.flatten(start_dim=1)).chunk(2, -1) 368 | q_z_what = D.Normal(q_z_what_mean, F.softplus(q_z_what_scale)) 369 | z_what = q_z_what.rsample() 370 | 371 | # -------------------- 372 | # Reconstruction step 373 | # 1. computes y_i_att reconstruction of the attention window x_att 374 | # 2. add to canvas over all timesteps 375 | # -------------------- 376 | 377 | # 1. compute reconstruction of the attention window 378 | y_att = torch.sigmoid(self.decoder(z_what).view(-1, self.C, self.A, self.A) + self.decoder_bias) 379 | 380 | # scale and shift y according to z_where 381 | y = stn(y_att, z_where, (batch_size, self.C, self.H, self.W), inverse=True, box_attn_window_color=i if box_attn_window_color is not None else None) 382 | 383 | # 2. add reconstruction to canvas 384 | recon_x += y * z_pres.view(-1,1,1,1) 385 | 386 | # -------------------- 387 | # Baseline step -- AIR paper cf's Mnih & Gregor NVIL; specifically sec 2.3 variance reduction 388 | # -------------------- 389 | 390 | # compute baseline; independent of the z latents (cf Mnih & Gregor NVIL) so detach from graph 391 | baseline_input = torch.cat([x, z_pres.detach(), z_what.detach(), z_where.detach()], dim=-1) 392 | h_baseline, c_baseline = self.baseline.rnn(baseline_input, (h_baseline, c_baseline)) 393 | baseline += self.baseline.linear(h_baseline).squeeze() # note: masking by z_pres give poorer results 394 | 395 | # -------------------- 396 | # Variational lower bound / loss components 397 | # -------------------- 398 | 399 | # compute kl(q||p) divergences -- sum over latent dim 400 | kl_z_what += D.kl.kl_divergence(q_z_what, self.p_z_what).sum(1) * z_pres.squeeze() 401 | kl_z_where += D.kl.kl_divergence(q_z_where, self.p_z_where).sum(1) * z_pres.squeeze() 402 | 403 | pred_counts[:,i] = z_pres.flatten() 404 | obj_probs[:,i] = q_z_pres.probs.flatten() 405 | 406 | q_z_pres = compute_geometric_from_bernoulli(obj_probs) 407 | score_fn = q_z_pres[torch.arange(batch_size), pred_counts.sum(1).long()].log() # log prob of num objects under the geometric 408 | kl_z_pres = compute_z_pres_kl(q_z_pres, self.p_z_pres, writer).sum(1) # note: mask by pred_counts makes no difference 409 | 410 | p_x_z = D.Normal(recon_x.flatten(1), self.likelihood_sigma) 411 | log_like = p_x_z.log_prob(x.view(-1, self.C, self.H, self.W).expand_as(recon_x).flatten(1)).sum(-1) # sum image dims (C, H, W) 412 | 413 | # -------------------- 414 | # Compute variational bound and loss function 415 | # -------------------- 416 | 417 | elbo = log_like - kl_z_pres - kl_z_what - kl_z_where # objective for loss function, but high variance 418 | loss = - torch.sum(elbo + (elbo - baseline).detach() * score_fn) # var reduction surrogate objective objective (cf Mnih & Gregor NVIL) 419 | baseline_loss = F.mse_loss(elbo.detach(), baseline) 420 | 421 | if writer: 422 | writer.add_scalar('log_like', log_like.mean(0).item(), writer.step) 423 | writer.add_scalar('kl_z_pres', kl_z_pres.mean(0).item(), writer.step) 424 | writer.add_scalar('kl_z_what', kl_z_what.mean(0).item(), writer.step) 425 | writer.add_scalar('kl_z_where', kl_z_where.mean(0).item(), writer.step) 426 | writer.add_scalar('elbo', elbo.mean(0).item(), writer.step) 427 | writer.add_scalar('baseline', baseline.mean(0).item(), writer.step) 428 | writer.add_scalar('score_function', score_fn.mean(0).item(), writer.step) 429 | writer.add_scalar('z_pres_prob', self.z_pres_prob.item(), 
writer.step) 430 | 431 | return recon_x, pred_counts, elbo, loss, baseline_loss 432 | 433 | @torch.no_grad() 434 | def generate(self, n_samples): 435 | """ AIR paper figure 3 left: 436 | 437 | The generative model draws n ∼ Geom(ρ) digits {y_i_att} of size 28 × 28 (two shown), scales andshifts them 438 | according to z_i_where ∼ N (0, Σ) using spatial transformers, and sums the results {y_i} to form a 50 × 50 image. 439 | Each digit is obtained by first sampling a latent code z_i_what from the prior z_i_what ∼ N (0, 1) and 440 | propagating it through the decoder network of a variational autoencoder. 441 | The learnable parameters θ of the generative model are the parameters of this decoder network. 442 | """ 443 | # sample z_pres ~ Geom(rho) -- this is the number of digits present in an image 444 | z_pres = D.Geometric(1 - self.z_pres_prior).sample((n_samples,)).clamp_(0, self.max_steps) 445 | 446 | # compute a mask on z_pres as e.g.: 447 | # z_pres = [1,4,2,0] 448 | # mask = [[1,0,0,0,0], 449 | # [1,1,1,1,0], 450 | # [1,1,0,0,0], 451 | # [0,0,0,0,0]] 452 | # thus network outputs more objects (sample z_what, z_where and decode) where z_pres is 1 453 | # and outputs nothing when z_pres is 0 454 | z_pres_mask = torch.arange(self.max_steps).float().to(z_pres.device).expand(n_samples, self.max_steps) < z_pres.view(-1,1) 455 | z_pres_mask = z_pres_mask.float().to(z_pres.device) 456 | 457 | # initialize image canvas 458 | x = torch.zeros(n_samples, self.C, self.H, self.W).to(z_pres.device) 459 | 460 | # generate digits 461 | for i in range(int(z_pres.max().item())): # up until the number of objects sampled via z_pres 462 | # sample priors 463 | z_what = self.p_z_what.sample((n_samples,)) 464 | z_where = self.p_z_where.sample((n_samples,)) 465 | 466 | # propagate through the decoder, scale and shift y_att according to z_where using spatial transformers 467 | y_att = torch.sigmoid(self.decoder(z_what).view(n_samples, self.C, self.A, self.A) + self.decoder_bias) 468 | y = stn(y_att, z_where, (n_samples, self.C, self.H, self.W), inverse=True, box_attn_window_color=i) 469 | 470 | # apply mask and sum results towards final image 471 | x = x + y * z_pres_mask[:,i].view(-1,1,1,1) 472 | return x 473 | 474 | 475 | # -------------------- 476 | # Train and evaluate 477 | # -------------------- 478 | 479 | def train_epoch(model, dataloader, model_optimizer, baseline_optimizer, anneal_z_pres_prob, epoch, writer, args): 480 | model.train() 481 | 482 | with tqdm(total=len(dataloader), desc='epoch {} / {}'.format(epoch+1, args.start_epoch + args.n_epochs)) as pbar: 483 | 484 | for i, (x, y) in enumerate(dataloader): 485 | writer.step += 1 # update global step 486 | 487 | x = x.view(x.shape[0], -1).to(args.device) 488 | 489 | # run through model and compute loss 490 | recon_x, pred_counts, elbo, loss, baseline_loss = model(x, writer if i % args.log_interval == 0 else None) # pass writer at logging intervals 491 | 492 | # anneal z_pres prior 493 | model.z_pres_prob = anneal_z_pres_prob(model.z_pres_prob, writer.step, args) 494 | 495 | model_optimizer.zero_grad() 496 | loss.backward() 497 | model_optimizer.step() 498 | 499 | baseline_optimizer.zero_grad() 500 | baseline_loss.backward() 501 | baseline_optimizer.step() 502 | 503 | # update tracking 504 | count_accuracy = torch.eq(pred_counts.sum(1).cpu(), y.sum(1)).float().mean() 505 | pbar.set_postfix(elbo='{:.3f}'.format(elbo.mean(0).item()), \ 506 | loss='{:.3f}'.format(loss.item()), \ 507 | count_acc='{:.2f}'.format(count_accuracy.item())) 508 | pbar.update() 
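# ------------------------------------------------------------------
# Illustrative aside (not part of the original training loop): a minimal,
# self-contained sketch of the score-function (REINFORCE) surrogate with a
# baseline -- the same variance-reduction idea behind `loss` and
# `baseline_loss` computed in AIR.forward and backpropagated above.
# All names and values below are hypothetical toy stand-ins, not the AIR
# quantities themselves.
import torch
import torch.distributions as D

logits = torch.zeros(8, 1, requires_grad=True)      # toy parameters of q(b)
q = D.Bernoulli(logits=logits)
b = q.sample()                                       # discrete sample: no pathwise gradient
reward = (b.squeeze(-1) - 0.5) ** 2                  # stand-in for the ELBO terms that depend on b
baseline = reward.mean().detach()                    # stand-in for the learned NVIL baseline
score_fn = q.log_prob(b).sum(-1)                     # log q(b), the score function
# The gradient of this surrogate w.r.t. `logits` is the REINFORCE estimate
# E[(reward - baseline) * d/dtheta log q(b)]; subtracting the baseline reduces its variance.
# (In AIR the first term also carries reparameterized gradients for the continuous latents;
# here it is a constant and contributes nothing to the gradient.)
surrogate = -(reward + (reward - baseline).detach() * score_fn).sum()
surrogate.backward()
# ------------------------------------------------------------------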
509 | 510 | if i % args.log_interval == 0: 511 | writer.add_scalar('loss', loss.item(), writer.step) 512 | writer.add_scalar('baseline_loss', baseline_loss.item(), writer.step) 513 | writer.add_scalar('count_accuracy_train', count_accuracy.item(), writer.step) 514 | 515 | if args.verbose == 1: 516 | print('z_pres prior:', model.p_z_pres.log_prob(torch.arange(args.max_steps + 1.).to(args.device)).exp(), \ 517 | 'post:', compute_geometric_from_bernoulli(pred_counts.mean(0).unsqueeze(0)).squeeze(), \ 518 | 'ber success:', pred_counts.mean(0)) 519 | 520 | @torch.no_grad() 521 | def evaluate(model, dataloader, args, n_samples=10): 522 | model.eval() 523 | 524 | # initialize trackers 525 | elbo = 0 526 | pred_counts = [] 527 | true_counts = [] 528 | 529 | # evaluate elbo 530 | for x, y in tqdm(dataloader): 531 | x = x.view(x.shape[0], -1).to(args.device) 532 | _, pred_counts_i, elbo_i, _, _ = model(x) 533 | elbo += elbo_i.sum(0).item() 534 | pred_counts += [pred_counts_i.cpu()] 535 | true_counts += [y] 536 | elbo /= (len(dataloader) * args.batch_size) 537 | 538 | # evaluate count accuracy; test dataset not shuffled to preds and true aligned sequentially 539 | pred_counts = torch.cat(pred_counts, dim=0) 540 | true_counts = torch.cat(true_counts, dim=0) 541 | count_accuracy = torch.eq(pred_counts.sum(1), true_counts.sum(1)).float().mean() 542 | 543 | # visualize reconstruction 544 | x = x[-n_samples:] # take last n_sample data points 545 | recon_x, _, _, _, _ = model(x, box_attn_window_color=True) 546 | image_recons = torch.cat([x.view(-1, *args.image_dims).expand_as(recon_x), recon_x], dim=0) 547 | image_recons = make_grid(image_recons.cpu(), nrow=n_samples, pad_value=1) 548 | 549 | return elbo, count_accuracy, image_recons 550 | 551 | def train_and_evaluate(model, train_dataloader, test_dataloader, model_optimizer, baseline_optimizer, anneal_z_pres_prob, writer, args): 552 | 553 | for epoch in range(args.start_epoch, args.start_epoch + args.n_epochs): 554 | # train 555 | train_epoch(model, train_dataloader, model_optimizer, baseline_optimizer, anneal_z_pres_prob, epoch, writer, args) 556 | 557 | # evaluate 558 | if epoch % args.eval_interval == 0: 559 | test_elbo, count_accuracy, image_recons = evaluate(model, test_dataloader, args) 560 | print('Evaluation at epoch {}: test elbo {:.3f}; count accuracy {:.3f}'.format(epoch, test_elbo, count_accuracy)) 561 | writer.add_scalar('test_elbo', test_elbo, epoch) 562 | writer.add_scalar('count_accuracy_test', count_accuracy, epoch) 563 | writer.add_image('image_reconstruction', image_recons, epoch) 564 | save_image(image_recons, os.path.join(args.output_dir, 'image_recons_{}.png'.format(epoch))) 565 | 566 | # generate samples 567 | samples = model.generate(n_samples=10) 568 | images = make_grid(samples, nrow=samples.shape[0], pad_value=1) 569 | save_image(images, os.path.join(args.output_dir, 'generated_sample_{}.png'.format(epoch))) 570 | writer.add_image('training_sample', images, epoch) 571 | 572 | # save training checkpoint 573 | torch.save({'epoch': epoch, 574 | 'global_step': writer.step, 575 | 'state_dict': model.state_dict()}, 576 | os.path.join(args.output_dir, 'checkpoint.pt')) 577 | 578 | 579 | # -------------------- 580 | # Main 581 | # -------------------- 582 | 583 | if __name__ == '__main__': 584 | args = parser.parse_args() 585 | 586 | # setup writer and output folders 587 | writer = SummaryWriter(log_dir = os.path.join(args.output_dir, time.strftime('%Y-%m-%d_%H-%M-%S', time.gmtime())) \ 588 | if not args.restore_file else 
os.path.dirname(args.restore_file)) 589 | writer.step = 0 590 | args.output_dir = writer.file_writer.get_logdir() # update output_dir with the writer unique directory 591 | 592 | # setup device 593 | args.device = torch.device('cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda is not None else 'cpu') 594 | torch.manual_seed(args.seed) 595 | if args.device.type == 'cuda': torch.cuda.manual_seed(args.seed) 596 | 597 | # load data 598 | train_dataloader, test_dataloader = fetch_dataloaders(args) 599 | 600 | # load model 601 | model = AIR(args).to(args.device) 602 | 603 | # load optimizers 604 | model_optimizer = torch.optim.RMSprop(model.parameters(), lr=args.model_lr, momentum=0.9) 605 | baseline_optimizer = torch.optim.RMSprop(model.parameters(), lr=args.baseline_lr, momentum=0.9) 606 | 607 | if args.restore_file: 608 | checkpoint = torch.load(args.restore_file, map_location=args.device) 609 | model.load_state_dict(checkpoint['state_dict']) 610 | writer.step = checkpoint['global_step'] 611 | args.start_epoch = checkpoint['epoch'] + 1 612 | # set up paths 613 | args.output_dir = os.path.dirname(args.restore_file) 614 | 615 | # save settings 616 | with open(os.path.join(args.output_dir, 'config.txt'), 'a') as f: 617 | print('Parsed args:\n', pprint.pformat(args.__dict__), file=f) 618 | print('\nModel:\n', model, file=f) 619 | 620 | if args.train: 621 | train_and_evaluate(model, train_dataloader, test_dataloader, model_optimizer, baseline_optimizer, anneal_z_pres_prob, writer, args) 622 | 623 | if args.evaluate: 624 | test_elbo, count_accuracy, image_recons = evaluate(model, test_dataloader, args) 625 | print('Evaluation: test elbo {:.3f}; {:.3f}'.format(test_elbo, count_accuracy)) 626 | save_image(image_recons, os.path.join(args.output_dir, 'image_recons.png')) 627 | 628 | if args.generate: 629 | samples = model.generate(n_samples=7) 630 | images = make_grid(samples, pad_value=1) 631 | save_image(images, os.path.join(args.output_dir, 'generated_sample.png')) 632 | writer.add_image('generated_sample', images) 633 | 634 | writer.close() 635 | -------------------------------------------------------------------------------- /basic_vae.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of Auto-encoding Variational Bayes 3 | https://arxiv.org/pdf/1312.6114.pdf 4 | 5 | Reference implementatoin in pytorch examples https://github.com/pytorch/examples/blob/master/vae/ 6 | Toy example per Adversarial Variational Bayes https://arxiv.org/abs/1701.04722 7 | 8 | """ 9 | 10 | import os 11 | import argparse 12 | from tqdm import tqdm 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.distributions as D 18 | import torchvision.transforms as T 19 | from torch.utils.data import DataLoader, Dataset 20 | from torchvision.datasets import MNIST 21 | from torchvision.utils import make_grid, save_image 22 | 23 | import matplotlib 24 | matplotlib.use('TkAgg') 25 | import matplotlib.pyplot as plt 26 | from sklearn.manifold import TSNE 27 | 28 | 29 | 30 | parser = argparse.ArgumentParser() 31 | subparsers = parser.add_subparsers(help='Dataset specific configs for input and latent dimensions.', dest='dataset') 32 | 33 | # training params 34 | parser.add_argument('--batch_size', type=int, default=128) 35 | parser.add_argument('--n_epochs', type=int, default=10) 36 | parser.add_argument('--seed', type=int, default=11272018) 37 | parser.add_argument('--save_model', action='store_true') 38 | 
parser.add_argument('--quiet', action='store_true') 39 | 40 | parser.add_argument('--data_dir', default='./data') 41 | parser.add_argument('--output_dir', default='./results/{}'.format(os.path.splitext(__file__)[0])) 42 | 43 | # model parameters 44 | toy_subparser = subparsers.add_parser('toy') 45 | toy_subparser.add_argument('--x_dim', type=int, default=4, help='Dimension of the input data.') 46 | toy_subparser.add_argument('--z_dim', type=int, default=2, help='Size of the latent space.') 47 | toy_subparser.add_argument('--hidden_dim', type=int, default=400, help='Size of the hidden layer.') 48 | 49 | mnist_subparser = subparsers.add_parser('mnist') 50 | mnist_subparser.add_argument('--x_dim', type=int, default=28*28, help='Dimension of the input data.') 51 | mnist_subparser.add_argument('--z_dim', type=int, default=100, help='Size of the latent space.') 52 | mnist_subparser.add_argument('--hidden_dim', type=int, default=400, help='Size of the hidden layer.') 53 | 54 | 55 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 56 | 57 | 58 | # -------------------- 59 | # Data 60 | # -------------------- 61 | 62 | def fetch_dataloader(args, train=True, download=False): 63 | 64 | transforms = T.Compose([T.ToTensor()]) 65 | dataset = MNIST(root=args.data_dir, train=train, download=download, transform=transforms) 66 | 67 | kwargs = {'num_workers': 1, 'pin_memory': True} if device.type is 'cuda' else {} 68 | 69 | return DataLoader(dataset, batch_size=args.batch_size, shuffle=train, drop_last=True, **kwargs) 70 | 71 | class ToyDataset(Dataset): 72 | def __init__(self, args): 73 | super().__init__() 74 | self.x_dim = args.x_dim 75 | self.batch_size = args.batch_size 76 | 77 | def __len__(self): 78 | return self.batch_size * 1000 79 | 80 | def __getitem__(self, i): 81 | one_hot = torch.zeros(self.x_dim) 82 | label = torch.randint(0, self.x_dim, (1, )).long() 83 | one_hot[label] = 1. 84 | return one_hot, label 85 | 86 | def fetch_toy_dataloader(args): 87 | return DataLoader(ToyDataset(args), batch_size=args.batch_size, shuffle=True) 88 | 89 | 90 | # -------------------- 91 | # Plotting helpers 92 | # -------------------- 93 | 94 | def plot_tsne(model, test_loader, args): 95 | data = test_loader.dataset.test_data.float() / 255. 
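# flatten each 28x28 test image into a 784-dim vector before encoding with the VAE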
96 | data = data.view(data.shape[0], -1) 97 | labels = test_loader.dataset.test_labels 98 | classes = torch.unique(labels, sorted=True).numpy() 99 | 100 | p_x_z, q_z_x = model(data) 101 | 102 | tsne = TSNE(n_components=2, random_state=0) 103 | z_embed = tsne.fit_transform(q_z_x.loc.cpu().numpy()) # map the posterior mean 104 | 105 | fig = plt.figure() 106 | for i in classes: 107 | mask = labels.cpu().numpy() == i 108 | plt.scatter(z_embed[mask, 0], z_embed[mask, 1], s=10, label=str(i)) 109 | 110 | plt.title('Latent variable T-SNE embedding per class') 111 | plt.legend() 112 | plt.gca().axis('off') 113 | fig.savefig(os.path.join(args.output_dir, 'tsne_embedding.png')) 114 | 115 | 116 | def plot_scatter(model, args): 117 | data = torch.eye(args.x_dim).repeat(args.batch_size, 1) 118 | labels = data @ torch.arange(args.x_dim).float() 119 | 120 | _, q_z_x = model(data) 121 | z = q_z_x.sample().numpy() 122 | plt.scatter(z[:,0], z[:,1], c=labels.data.numpy(), alpha=0.5) 123 | 124 | plt.title('Latent space embedding per class\n(n_iter = {})'.format(len(ToyDataset(args))*args.n_epochs)) 125 | plt.savefig(os.path.join(args.output_dir, 'latent_distribution_toy_example.png')) 126 | plt.close() 127 | 128 | # -------------------- 129 | # Model 130 | # -------------------- 131 | 132 | class VAE(nn.Module): 133 | def __init__(self, args):#in_dim=784, hidden_dim=400, z_dim=20): 134 | super().__init__() 135 | self.fc1 = nn.Linear(args.x_dim, args.hidden_dim) 136 | self.fc21 = nn.Linear(args.hidden_dim, args.z_dim) 137 | self.fc22 = nn.Linear(args.hidden_dim, args.z_dim) 138 | self.fc3 = nn.Linear(args.z_dim, args.hidden_dim) 139 | self.fc4 = nn.Linear(args.hidden_dim, args.x_dim) 140 | 141 | # q(z|x) parametrizes the approximate posterior as a Normal(mu, scale) 142 | def encode(self, x): 143 | h1 = F.relu(self.fc1(x)) 144 | mu = self.fc21(h1) 145 | scale = self.fc22(h1).exp() 146 | return D.Normal(mu, scale) 147 | 148 | # p(x|z) returns the likelihood of data given the latents 149 | def decode(self, z): 150 | h3 = F.relu(self.fc3(z)) 151 | logits = self.fc4(h3) 152 | return D.Bernoulli(logits=logits) 153 | 154 | def forward(self, x): 155 | q_z_x = self.encode(x.view(x.shape[0], -1)) # returns Normal 156 | p_x_z = self.decode(q_z_x.rsample()) # returns Bernoulli; note reparametrization when sampling the approximate 157 | return p_x_z, q_z_x 158 | 159 | 160 | # ELBO loss 161 | def loss_fn(p_x_z, q_z_x, x): 162 | # Equation 3 from Kingma & Welling -- Auto-Encoding Variational Bayes 163 | # ELBO = - KL( q(z|x), p(z) ) + Expectation_under_q(z|x)_[log p(x|z)] 164 | # this simplifies to eq 7 from Kingma nad Welling where the expectation is avg of z samples 165 | # signs are revered from paper as paper maximizes ELBO and here we min - ELBO 166 | # both KLD and BCE are summed over dim 1 (image H*W) and mean over dim 0 (batch) 167 | p_z = D.Normal(torch.FloatTensor([0], device=x.device), torch.FloatTensor([1], device=x.device)) 168 | KLD = D.kl.kl_divergence(q_z_x, p_z).sum(1).mean(0) # divergene of the approximate posterior from the prior 169 | BCE = - p_x_z.log_prob(x.view(x.shape[0], -1)).sum(1).mean(0) # expected negative reconstruction error; 170 | # prob density of data x under the generative model given by z 171 | return BCE + KLD 172 | 173 | 174 | # -------------------- 175 | # Train and eval 176 | # -------------------- 177 | 178 | def train_epoch(model, dataloader, loss_fn, optimizer, epoch, args): 179 | model.train() 180 | 181 | ELBO_loss = 0 182 | 183 | with tqdm(total=len(dataloader), desc='epoch {} of 
{}'.format(epoch+1, args.n_epochs)) as pbar: 184 | for i, (data, _) in enumerate(dataloader): 185 | data = data.to(device) 186 | 187 | p_x_z, q_z_x = model(data) 188 | loss = loss_fn(p_x_z, q_z_x, data) 189 | 190 | optimizer.zero_grad() 191 | loss.backward() 192 | optimizer.step() 193 | 194 | # update tracking 195 | pbar.set_postfix(loss='{:.3f}'.format(loss.item())) 196 | pbar.update() 197 | 198 | ELBO_loss += loss.item() 199 | 200 | print('Epoch: {} Average ELBO loss: {:.4f}'.format(epoch+1, ELBO_loss / (len(dataloader)))) 201 | 202 | 203 | @torch.no_grad() 204 | def evaluate(model, dataloader, loss_fn, epoch, args): 205 | model.eval() 206 | 207 | ELBO_loss = 0 208 | 209 | with tqdm(total=len(dataloader)) as pbar: 210 | for i, (data, _) in enumerate(dataloader): 211 | data = data.to(device) 212 | p_x_z, q_z_x = model(data) 213 | 214 | ELBO_loss += loss_fn(p_x_z, q_z_x, data).item() 215 | 216 | pbar.update() 217 | 218 | if i == 0 and args.dataset == 'mnist': 219 | nrow = 10 220 | n = min(data.size(0), nrow**2) 221 | real_data = make_grid(data[:n].cpu(), nrow) 222 | spacer = torch.ones(real_data.shape[0], real_data.shape[1], 5) 223 | generated_data = make_grid(p_x_z.probs.view(args.batch_size, 1, 28, 28)[:n].cpu(), nrow) 224 | image = torch.cat([real_data, spacer, generated_data], dim=-1) 225 | save_image(image, os.path.join(args.output_dir, 'reconstruction_at_epoch_' + str(epoch) + '.png'), nrow) 226 | 227 | print('Test set average ELBO loss: {:.4f}'.format(ELBO_loss / len(dataloader))) 228 | 229 | 230 | def train_and_evaluate(model, train_loader, test_loader, loss_fn, optimizer, args): 231 | for epoch in range(args.n_epochs): 232 | train_epoch(model, train_loader, loss_fn, optimizer, epoch, args) 233 | evaluate(model, test_loader, loss_fn, epoch, args) 234 | 235 | # save weights 236 | if args.save_model: 237 | torch.save(model.state_dict(), os.path.join(args.output_dir, 'vae_model_xdim{}_hdim{}_zdim{}.pt'.format( 238 | args.x_dim, args.hidden_dim, args.z_dim))) 239 | 240 | # show samples 241 | if args.dataset == 'mnist': 242 | with torch.no_grad(): 243 | # sample p(z) = Normal(0, 1) 244 | prior_sample = torch.randn(64, args.z_dim).to(device) 245 | # compute likelihood p(x|z) decoder; returns torch.distribution.Bernoulli 246 | likelihood = model.decode(prior_sample).probs 247 | save_image(likelihood.cpu().view(64, 1, 28, 28), os.path.join(args.output_dir, 'sample_at_epoch_' + str(epoch) + '.png')) 248 | 249 | 250 | 251 | if __name__ == '__main__': 252 | args = parser.parse_args() 253 | if not os.path.isdir(os.path.join(args.output_dir, args.dataset)): 254 | os.makedirs(os.path.join(args.output_dir, args.dataset)) 255 | args.output_dir = os.path.join(args.output_dir, args.dataset) 256 | 257 | torch.manual_seed(args.seed) 258 | 259 | # data 260 | if args.dataset == 'toy': 261 | train_loader = fetch_toy_dataloader(args) 262 | test_loader = train_loader 263 | else: 264 | train_loader = fetch_dataloader(args, train=True) 265 | test_loader = fetch_dataloader(args, train=False) 266 | 267 | # model 268 | model = VAE(args).to(device) 269 | 270 | # optimizer 271 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) 272 | 273 | # train and eval 274 | train_and_evaluate(model, train_loader, test_loader, loss_fn, optimizer, args) 275 | 276 | # visualize z space 277 | with torch.no_grad(): 278 | if args.dataset == 'toy': 279 | plot_scatter(model, args) 280 | else: 281 | pass 282 | plot_tsne(model, test_loader, args) 283 | 284 | 
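# ------------------------------------------------------------------
# Illustrative aside (not part of the original script): the KL term used in
# loss_fn above has the closed form from Kingma & Welling (Appendix B),
#   KL( N(mu, sigma^2) || N(0, 1) ) = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2),
# which is what torch.distributions computes. A quick check with toy values:
import torch
import torch.distributions as D

mu = torch.randn(16, 20)
scale = torch.rand(16, 20) + 0.1
q_z_x = D.Normal(mu, scale)
p_z = D.Normal(torch.zeros_like(mu), torch.ones_like(scale))

kl_closed_form = -0.5 * (1 + (scale ** 2).log() - mu ** 2 - scale ** 2).sum(1)
kl_library = D.kl.kl_divergence(q_z_x, p_z).sum(1)   # summed over the latent dim, as in loss_fn
assert torch.allclose(kl_closed_form, kl_library, atol=1e-5)
# ------------------------------------------------------------------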
-------------------------------------------------------------------------------- /datasets/chexpert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import glob 5 | from multiprocessing import Pool 6 | from functools import partial 7 | 8 | import pandas as pd 9 | from tqdm import tqdm 10 | from PIL import Image 11 | 12 | import torch 13 | from torch.utils.data import Dataset 14 | import torchvision.transforms as T 15 | 16 | 17 | class ChexpertDataset(Dataset): 18 | attr_all_names = ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 19 | 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 20 | 'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 21 | 'Fracture', 'Support Devices'] 22 | # subset of labels to use 23 | attr_names = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion'] 24 | 25 | def __init__(self, root, train=True, transform=None): 26 | self.root = os.path.expanduser(root) 27 | self.transform = transform 28 | self.input_dims = (1, *(int(i) for i in self.root.strip('/').rpartition('/')[2].split('_')[-1].split('x'))) 29 | 30 | if train: 31 | self.data = self._load_and_preprocess_training_data(os.path.join(self.root, 'train.csv')) 32 | else: 33 | self.data = pd.read_csv(os.path.join(self.root, 'valid.csv'), keep_default_na=True) 34 | 35 | # store index of the selected attributes in the columns of the data for faster indexing 36 | self.attr_idxs = [self.data.columns.tolist().index(a) for a in self.attr_names] 37 | 38 | def __getitem__(self, idx): 39 | # 1. select and load image 40 | img_path = self.data.iloc[idx, 1] # `Path` is the first column after index 41 | img = Image.open(os.path.join(self.root, img_path.partition('/')[2])) 42 | if self.transform is not None: 43 | img = self.transform(img) 44 | 45 | # 2. select attributes as targets 46 | attr = self.data.iloc[idx, self.attr_idxs].values.astype(float) 47 | attr = torch.from_numpy(attr).float() 48 | 49 | return img, attr 50 | 51 | def __len__(self): 52 | return len(self.data) 53 | 54 | def _load_and_preprocess_training_data(self, csv_path): 55 | # Dataset labels are: blank for unmentioned, 0 for negative, -1 for uncertain, and 1 for positive. 56 | # Process by: 57 | # 1. fill NAs (blanks for unmentioned) as 0 (negatives) 58 | # 2. fill -1 as 1 (U-Ones method described in paper) 59 | 60 | # load 61 | train_df = pd.read_csv(csv_path, keep_default_na=True) 62 | 63 | # 1. fill NAs (blanks for unmentioned) as 0 (negatives) 64 | # attr columns ['No Finding', ..., 'Support Devices']; note AP/PA remains with NAs for Lateral pictures 65 | train_df[self.attr_names] = train_df[self.attr_names].fillna(0) 66 | 67 | # 2. fill -1 as 1 (U-Ones method described in paper) 68 | train_df[self.attr_names] = train_df[self.attr_names].replace(-1,1) 69 | 70 | return train_df 71 | 72 | 73 | def compute_dataset_mean_and_std(dataset): 74 | m = 0 75 | s = 0 76 | k = 1 77 | for img, _, _ in tqdm(dataset): 78 | x = img.mean().item() 79 | new_m = m + (x - m)/k 80 | s += (x - m)*(x - new_m) 81 | m = new_m 82 | k += 1 83 | print('Number of datapoints: ', k) 84 | return m, math.sqrt(s/(k-1)) 85 | 86 | # -------------------- 87 | # Resize dataset 88 | # -------------------- 89 | 90 | def _process_entry(img_path, root, source_dir, target_dir, transforms): 91 | # img_path is e.g. 
`CheXpert-v1.0-small/valid/patient64541/study1/view1_frontal.jpg` 92 | subpath = img_path.partition('/')[-1] 93 | 94 | # make new img folders/subfolders 95 | os.makedirs(os.path.dirname(os.path.join(root, target_dir, subpath)), exist_ok=True) 96 | 97 | # save resized image 98 | img = Image.open(os.path.join(root, source_dir, subpath)) 99 | img = transforms(img) 100 | img.save(os.path.join(root, target_dir, subpath), quality=97) 101 | img.close() 102 | 103 | def make_resized_dataset(root, source_dir, size, n_workers): 104 | root = os.path.expanduser(root) 105 | target_dir = 'CheXpert_{}x{}'.format(size, size) 106 | 107 | assert (not os.path.exists(os.path.join(root, target_dir, 'train.csv')) and \ 108 | not os.path.exists(os.path.join(root, target_dir, 'valid.csv'))), 'Data exists at target dir.' 109 | print('Resizing dataset at root {}:\n\tsource {}\n\ttarget {}\n\tNew size: {}x{}'.format(root, source_dir, target_dir, size, size)) 110 | 111 | transforms = T.Compose([T.Resize(size, Image.BICUBIC), T.CenterCrop(size)]) 112 | 113 | 114 | for split in ['train', 'valid']: 115 | print('Processing ', split, ' split...') 116 | csv_path = os.path.join(root, source_dir, split + '.csv') 117 | 118 | # load data and preprocess NAs 119 | df = pd.read_csv(csv_path, keep_default_na=True) 120 | 121 | if split == 'train': 122 | # 1. fill NAs (blanks for unmentioned) as 0 (negatives) 123 | # attr columns ['No Finding', ..., 'Support Devices']; note AP/PA remains with NAs for Lateral pictures 124 | df[ChexpertDataset.attr_names] = df[ChexpertDataset.attr_names].fillna(0) 125 | 126 | # 2. fill -1 as 1 (U-Ones method described in paper) 127 | df[ChexpertDataset.attr_names] = df[ChexpertDataset.attr_names].replace(-1,1) 128 | 129 | # make new folders, resize image and store 130 | f = partial(_process_entry, root=root, source_dir=source_dir, target_dir=target_dir, transforms=transforms) 131 | with Pool(n_workers) as p: 132 | p.map(f, df['Path'].tolist()) 133 | 134 | # replace `CheXpert-v1.0-small` root with new root defined above 135 | df['Path'] = df['Path'].str.replace(source_dir, target_dir) 136 | 137 | # save new df 138 | df.to_csv(os.path.join(root, target_dir, split + '.csv')) 139 | 140 | 141 | # resize entire dataset 142 | if False: 143 | root = '/mnt/disks/chexpert-ssd' 144 | source_dir = 'CheXpert-v1.0-small' 145 | new_size = 64 146 | n_workers = 16 147 | 148 | make_resized_dataset(root, source_dir, new_size, n_workers) 149 | 150 | # compute dataset mean and std 151 | if False: 152 | ds = ChexpertSmall(root=args.data_dir, train=True, transform=T.Compose([T.CenterCrop(320), T.ToTensor()])) 153 | m, s = compute_mean_and_std(ds) 154 | print('Dataset mean: {}; dataset std {}'.format(m, s)) 155 | # Dataset mean: 0.533048452958796; dataset std 0.03490651403764978 156 | 157 | # output a few images from the validation set and display labels 158 | if False: 159 | import torchvision.transforms as T 160 | from torchvision.utils import save_image 161 | ds = ChexpertSmall(root=args.data_dir, train=False, 162 | transform=T.Compose([T.ToTensor(), T.Normalize(mean=[0.5330], std=[0.0349])])) 163 | print('Valid dataset loaded. 
Length: ', len(ds)) 164 | for i in range(10): 165 | img, attr, patient_id = ds[i] 166 | save_image(img, 'test_valid_dataset_image_{}.png'.format(i), normalize=True, scale_each=True) 167 | print('Patient id: {}; labels: {}'.format(patient_id, attr)) 168 | 169 | -------------------------------------------------------------------------------- /dcgan.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks 3 | https://arxiv.org/pdf/1511.06434 4 | 5 | Notes: 6 | Model architecture differs from paper: 7 | generator ends with Sigmoid 8 | inputs normalized to [0,1] 9 | learning rates differ 10 | 11 | """ 12 | 13 | 14 | import os 15 | import argparse 16 | from tqdm import tqdm 17 | import time 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | import torch.distributions as dist 23 | from torch.utils.data import DataLoader 24 | from torchvision.datasets import MNIST 25 | import torchvision.transforms as T 26 | from torchvision.utils import save_image, make_grid 27 | 28 | import utils 29 | 30 | parser = argparse.ArgumentParser() 31 | 32 | # training params 33 | parser.add_argument('--batch_size', type=int, default=128) 34 | parser.add_argument('--n_epochs', type=int, default=1) 35 | parser.add_argument('--noise_dim', type=int, default=96, help='Size of the latent representation.') 36 | parser.add_argument('--g_lr', type=float, default=1e-3, help='Generator learning rate') 37 | parser.add_argument('--d_lr', type=float, default=1e-4, help='Discriminator learning rate') 38 | parser.add_argument('--log_interval', default=100) 39 | parser.add_argument('--cuda', type=int, help='Which cuda device to use') 40 | parser.add_argument('--mini_data', action='store_true') 41 | # eval params 42 | parser.add_argument('--evaluate_on_grid', action='store_true') 43 | # data paths 44 | parser.add_argument('--save_model', action='store_true') 45 | parser.add_argument('--data_dir', default='./data') 46 | parser.add_argument('--output_dir', default='./results/dcgan') 47 | parser.add_argument('--restore_file', help='Path to .pt checkpoint file for Discriminator and Generator') 48 | 49 | 50 | 51 | 52 | # -------------------- 53 | # Data 54 | # -------------------- 55 | 56 | def fetch_dataloader(args, train=True, download=True, mini_size=128): 57 | # load dataset and init in the dataloader 58 | 59 | transforms = T.Compose([T.ToTensor()]) 60 | dataset = MNIST(root=args.data_dir, train=train, download=download, transform=transforms) 61 | 62 | # load dataset and init in the dataloader 63 | if args.mini_data: 64 | if train: 65 | dataset.train_data = dataset.train_data[:mini_size] 66 | dataset.train_labels = dataset.train_labels[:mini_size] 67 | else: 68 | dataset.test_data = dataset.test_data[:mini_size] 69 | dataset.test_labels = dataset.test_labels[:mini_size] 70 | 71 | 72 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.device.type is 'cuda' else {} 73 | 74 | dl = DataLoader(dataset, batch_size=args.batch_size, shuffle=train, drop_last=True, **kwargs) 75 | 76 | return dl 77 | 78 | 79 | # -------------------- 80 | # Model 81 | # -------------------- 82 | 83 | class Flatten(nn.Module): 84 | def forward(self, x): 85 | return x.view(x.shape[0], -1) 86 | 87 | class Unflatten(nn.Module): 88 | def __init__(self, B, C, H, W): 89 | super().__init__() 90 | self.B = B 91 | self.C = C 92 | self.H = H 93 | self.W = W 94 | 95 | def forward(self, x): 96 | return 
x.reshape(self.B, self.C, self.H, self.W) 97 | 98 | class Discriminator(nn.Module): 99 | def __init__(self): 100 | super().__init__() 101 | self.net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1), # out (B, 64, 14, 14) 102 | nn.LeakyReLU(0.2, True), 103 | nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False), # out (B, 128, 7, 7) 104 | nn.BatchNorm2d(128), 105 | nn.LeakyReLU(0.2, True), 106 | nn.Conv2d(128, 256, kernel_size=4, stride=1, padding=0, bias=False), # out (B, 128, 4, 4) 107 | nn.BatchNorm2d(256), 108 | nn.LeakyReLU(0.2, True), 109 | nn.Conv2d(256, 512, kernel_size=4, bias=False), # out (B, 256, 1, 1) 110 | nn.BatchNorm2d(512), 111 | nn.LeakyReLU(0.2, True), 112 | nn.Conv2d(512, 1, kernel_size=1, bias=False)) 113 | 114 | def forward(self, x): 115 | return dist.Bernoulli(logits=self.net(x).squeeze()) 116 | 117 | 118 | class Generator(nn.Module): 119 | def __init__(self, noise_dim): 120 | super().__init__() 121 | self.net = nn.Sequential(nn.ConvTranspose2d(noise_dim, 512, kernel_size=1, stride=1, padding=0, bias=False), 122 | nn.BatchNorm2d(512), 123 | nn.ReLU(True), 124 | nn.ConvTranspose2d(512, 256, kernel_size=4, bias=False), 125 | nn.BatchNorm2d(256), 126 | nn.ReLU(True), 127 | nn.ConvTranspose2d(256, 128, kernel_size=4, stride=1, padding=0, bias=False), 128 | nn.BatchNorm2d(128), 129 | nn.ReLU(True), 130 | nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False), 131 | nn.BatchNorm2d(64), 132 | nn.ReLU(True), 133 | nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False), 134 | nn.Sigmoid()) 135 | 136 | def forward(self, x): 137 | return self.net(x) 138 | 139 | 140 | def initialize_weights(m, std=0.02): 141 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 142 | m.weight.data.normal_(mean=1., std=std) 143 | m.bias.data.fill_(0.) 
144 | else: 145 | try: 146 | m.weight.data.normal_(std=std) 147 | except AttributeError: # skip activation layers 148 | pass 149 | 150 | 151 | 152 | # -------------------- 153 | # Train 154 | # -------------------- 155 | 156 | def sample_z(args): 157 | # generate samples from the prior 158 | return dist.Uniform(-1,1).sample((args.batch_size, args.noise_dim, 1, 1)).to(args.device) 159 | 160 | 161 | def train_epoch(D, G, dataloader, d_optimizer, g_optimizer, epoch, writer, args): 162 | 163 | fixed_z = sample_z(args) 164 | 165 | real_labels = torch.ones(args.batch_size, 1, device=args.device).requires_grad_(False) 166 | fake_labels = torch.zeros(args.batch_size, 1, device=args.device).requires_grad_(False) 167 | 168 | with tqdm(total=len(dataloader), desc='epoch {} of {}'.format(epoch+1, args.n_epochs)) as pbar: 169 | time.sleep(0.1) 170 | 171 | for i, (x, _) in enumerate(dataloader): 172 | D.train() 173 | G.train() 174 | 175 | x = x.to(args.device) 176 | 177 | # train generator 178 | 179 | # sample prior 180 | z = sample_z(args) 181 | 182 | # run through model 183 | generated = G(z) 184 | d_fake = D(generated) 185 | 186 | # calculate losses 187 | g_loss = - d_fake.log_prob(real_labels).mean() 188 | 189 | g_optimizer.zero_grad() 190 | g_loss.backward() 191 | g_optimizer.step() 192 | 193 | 194 | # train discriminator 195 | d_real = D(x) 196 | d_fake = D(generated.detach()) 197 | 198 | # calculate losses 199 | d_loss = - d_real.log_prob(real_labels).mean() - d_fake.log_prob(fake_labels).mean() 200 | 201 | d_optimizer.zero_grad() 202 | d_loss.backward() 203 | d_optimizer.step() 204 | 205 | 206 | # update tracking 207 | pbar.set_postfix(d_loss='{:.3f}'.format(d_loss.item()), 208 | g_loss='{:.3f}'.format(g_loss.item())) 209 | pbar.update() 210 | 211 | if i % args.log_interval == 0: 212 | step = epoch 213 | writer.add_scalar('d_loss', d_loss.item(), step) 214 | writer.add_scalar('g_loss', g_loss.item(), step) 215 | # sample images 216 | with torch.no_grad(): 217 | G.eval() 218 | fake_images = G(fixed_z) 219 | writer.add_image('generated', make_grid(fake_images[:10].cpu(), nrow=10, padding=1), step) 220 | save_image(fake_images[:10].cpu(), 221 | os.path.join(args.output_dir, 'generated_sample_epoch_{}.png'.format(epoch)), 222 | nrow=10) 223 | 224 | 225 | def train(D, G, dataloader, d_optimizer, g_optimizer, writer, args): 226 | 227 | print('Starting training with args:\n', args) 228 | 229 | start_epoch = 0 230 | 231 | if args.restore_file: 232 | print('Restoring parameters from {}'.format(args.restore_file)) 233 | start_epoch = utils.load_checkpoint(args.restore_file, [D, G], [d_optimizer, g_optimizer]) 234 | args.n_epochs += start_epoch - 1 235 | print('Resuming training from epoch {}'.format(start_epoch)) 236 | 237 | for epoch in range(start_epoch, args.n_epochs): 238 | train_epoch(D, G, dataloader, d_optimizer, g_optimizer, epoch, writer, args) 239 | 240 | # snapshot at end of epoch 241 | if args.save_model: 242 | utils.save_checkpoint({'epoch': epoch + 1, 243 | 'model_state_dicts': [D.state_dict(), G.state_dict()], 244 | 'optimizer_state_dicts': [d_optimizer.state_dict(), g_optimizer.state_dict()]}, 245 | checkpoint=args.output_dir, 246 | quiet=True) 247 | 248 | @torch.no_grad() 249 | def evaluate_on_grid(G, writer, args): 250 | # sample noise randomly 251 | z = torch.empty(100, args.noise_dim, 1, 1).uniform_(-1,1).to(args.device) 252 | 253 | fake_images = G(z) 254 | writer.add_image('generated grid', make_grid(fake_images.cpu(), nrow=10, normalize=True, padding=1)) 255 | 
save_image(fake_images.cpu(), 256 | os.path.join(args.output_dir, 'latent_var_grid_sample_c1.png'), 257 | nrow=10) 258 | 259 | 260 | 261 | if __name__ == '__main__': 262 | args = parser.parse_args() 263 | 264 | if not os.path.isdir(args.output_dir): 265 | os.makedirs(args.output_dir) 266 | 267 | writer = utils.set_writer(args.output_dir, '_train') 268 | 269 | args.device = torch.device('cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda is not None else 'cpu') 270 | 271 | # set seed 272 | torch.manual_seed(11122018) 273 | if args.device is torch.device('cuda'): torch.cuda.manual_seed(11122018) 274 | 275 | # input 276 | dataloader = fetch_dataloader(args) 277 | 278 | # models 279 | D = Discriminator().to(args.device) 280 | G = Generator(args.noise_dim).to(args.device) 281 | D.apply(initialize_weights) 282 | G.apply(initialize_weights) 283 | 284 | # optimizers 285 | d_optimizer = torch.optim.Adam(D.parameters(), lr=args.d_lr, betas=(0.5, 0.999)) 286 | g_optimizer = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.999)) 287 | 288 | # train 289 | # eval 290 | if args.evaluate_on_grid: 291 | print('Restoring parameters from {}'.format(args.restore_file)) 292 | _ = utils.load_checkpoint(args.restore_file, [D, G], [d_optimizer, g_optimizer]) 293 | evaluate_on_grid(G, writer, args) 294 | # train 295 | else: 296 | dataloader = fetch_dataloader(args) 297 | train(D, G, dataloader, d_optimizer, g_optimizer, writer, args) 298 | evaluate_on_grid(G, writer, args) 299 | writer.close() 300 | 301 | -------------------------------------------------------------------------------- /draw.py: -------------------------------------------------------------------------------- 1 | """ 2 | DRAW: A Recurrent Neural Network For Image Generation 3 | https://arxiv.org/pdf/1502.04623.pdf 4 | """ 5 | 6 | import os 7 | import argparse 8 | import time 9 | from tqdm import tqdm 10 | import pprint 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.distributions as D 16 | import torchvision.transforms as T 17 | from torch.utils.data import DataLoader 18 | from torchvision.datasets import MNIST 19 | from torchvision.utils import save_image, make_grid 20 | 21 | import utils 22 | 23 | parser = argparse.ArgumentParser() 24 | 25 | # model parameters 26 | parser.add_argument('--image_dims', type=tuple, default=(1,28,28), help='Dimensions of a single datapoint (e.g. (1,28,28) for MNIST).') 27 | parser.add_argument('--time_steps', type=int, default=32, help='Number of time-steps T consumed by the network before performing reconstruction.') 28 | parser.add_argument('--z_size', type=int, default=100, help='Size of the latent representation.') 29 | parser.add_argument('--lstm_size', type=int, default=256, help='Size of the hidden layer in the encoder/decoder models.') 30 | parser.add_argument('--read_size', type=int, default=2, help='Size of the read operation visual field.') 31 | parser.add_argument('--write_size', type=int, default=5, help='Size of the write operation visual field.') 32 | parser.add_argument('--use_read_attn', action='store_true', help='Whether to use visual attention or not. If not, read/write field size is the full image.') 33 | parser.add_argument('--use_write_attn', action='store_true', help='Whether to use visual attention or not. 
If not, read/write field size is the full image.') 34 | 35 | # training params 36 | parser.add_argument('--train', action='store_true') 37 | parser.add_argument('--train_batch_size', type=int, default=128) 38 | parser.add_argument('--n_epochs', type=int, default=50, help='Number of epochs to train.') 39 | parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate') 40 | parser.add_argument('--log_interval', default=100, help='How often to write summary outputs.') 41 | parser.add_argument('--cuda', type=int, help='Which cuda device to use') 42 | parser.add_argument('--mini_data', action='store_true') 43 | parser.add_argument('--verbose', '-v', action='store_true', help='Extra monitoring of training + record forward/backward hooks on attn params') 44 | 45 | # eval params 46 | parser.add_argument('--evaluate', action='store_true') 47 | parser.add_argument('--test_batch_size', type=int, default=10, help='Batch size for evaluation') 48 | 49 | # generate params 50 | parser.add_argument('--generate', action='store_true') 51 | 52 | # data paths 53 | parser.add_argument('--save_model', action='store_true') 54 | parser.add_argument('--data_dir', default='./data') 55 | parser.add_argument('--output_dir', default='./results/{}'.format(os.path.splitext(__file__)[0])) 56 | parser.add_argument('--restore_file', help='Path to .pt checkpoint file.') 57 | 58 | 59 | 60 | 61 | # -------------------- 62 | # Data 63 | # -------------------- 64 | 65 | def fetch_dataloader(args, batch_size, train=True, download=False, mini_size=128): 66 | 67 | transforms = T.Compose([T.ToTensor()]) 68 | dataset = MNIST(root=args.data_dir, train=train, download=download, transform=transforms) 69 | 70 | # load dataset and init in the dataloader 71 | if args.mini_data: 72 | if train: 73 | dataset.train_data = dataset.train_data[:mini_size] 74 | dataset.train_labels = dataset.train_labels[:mini_size] 75 | else: 76 | dataset.test_data = dataset.test_data[:mini_size] 77 | dataset.test_labels = dataset.test_labels[:mini_size] 78 | 79 | 80 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.device.type is 'cuda' else {} 81 | 82 | return DataLoader(dataset, batch_size=batch_size, shuffle=train, drop_last=True, **kwargs) 83 | 84 | 85 | # -------------------- 86 | # Model 87 | # -------------------- 88 | 89 | class DRAW(nn.Module): 90 | def __init__(self, args): 91 | super().__init__() 92 | 93 | # save dimensions 94 | self.C, self.H, self.W = args.image_dims 95 | self.time_steps = args.time_steps 96 | self.lstm_size = args.lstm_size 97 | self.z_size = args.z_size 98 | self.use_read_attn = args.use_read_attn 99 | self.use_write_attn = args.use_write_attn 100 | if self.use_read_attn: 101 | self.read_size = args.read_size 102 | else: 103 | self.read_size = self.H 104 | if self.use_write_attn: 105 | self.write_size = args.write_size 106 | else: 107 | self.write_size = self.H 108 | 109 | # encoder - decoder layers 110 | self.encoder = nn.LSTMCell(2 * self.read_size * self.read_size + self.lstm_size, self.lstm_size) 111 | self.decoder = nn.LSTMCell(self.z_size, self.lstm_size) 112 | 113 | # latent space layer 114 | # outputs the parameters of the q distribution (mu and var; here q it is a Normal) 115 | self.z_linear = nn.Linear(self.lstm_size, 2 * self.z_size) 116 | 117 | # write layers 118 | self.write_linear = nn.Linear(self.lstm_size, self.write_size * self.write_size) 119 | 120 | # filter bank 121 | # outputs center location (g_x, g_y), stride delta, logvar of the gaussian filters, scalar intensity gamma (5 params in 
total) 122 | self.read_attention_linear = nn.Linear(self.lstm_size, 5) 123 | self.write_attention_linear = nn.Linear(self.lstm_size, 5) 124 | 125 | def compute_q_distribution(self, h_enc): 126 | mus, logsigmas = torch.split(self.z_linear(h_enc), [self.z_size, self.z_size], dim=1) 127 | sigmas = logsigmas.exp() 128 | z_dist = D.Normal(loc=mus, scale=sigmas) 129 | return z_dist.rsample(), mus, sigmas 130 | 131 | def read(self, x, x_hat, h_dec): 132 | if self.use_read_attn: 133 | # read filter bank -- eq 21 134 | g_x, g_y, logvar, logdelta, loggamma = self.read_attention_linear(h_dec).split(split_size=1, dim=1) # split returns column vecs 135 | # compute filter bank matrices -- eq 22 - 26 136 | g_x, g_y, delta, mu_x, mu_y, F_x, F_y = compute_filterbank_matrices(g_x, g_y, logvar, logdelta, self.H, self.W, self.read_size) 137 | # output reading wth attention -- eq 27 138 | new_x = F_y @ x.view(-1, self.H, self.W) @ F_x.transpose(-2, -1) # out (B, N, N) 139 | new_x_hat = F_y @ x_hat.view(-1, self.H, self.W) @ F_x.transpose(-2, -1) # out (B, N, N) 140 | return loggamma.exp() * torch.cat([new_x.view(x.shape[0], -1), new_x_hat.view(x.shape[0], -1)], dim=1) 141 | else: 142 | # output reading without attention -- eq 17 143 | return torch.cat([x, x_hat], dim=1) 144 | 145 | def write(self, h_dec): 146 | w = self.write_linear(h_dec) 147 | 148 | if self.use_write_attn: 149 | # read filter bank -- eq 21 150 | g_x, g_y, logvar, logdelta, loggamma = self.write_attention_linear(h_dec).split(split_size=1, dim=1) # split returns column vecs 151 | # compute filter bank matrices -- eq 22 - 26 152 | g_x, g_y, delta, mu_x, mu_y, F_x, F_y = compute_filterbank_matrices(g_x, g_y, logvar, logdelta, self.H, self.W, self.write_size) 153 | # output write with attention -- eq 29 154 | w = F_y.transpose(-2, -1) @ w.view(-1, self.write_size, self.write_size) @ F_x 155 | return 1. 
/ loggamma.exp() * w.view(w.shape[0], -1) 156 | else: 157 | # output write without attention -- eq 18 158 | return w 159 | 160 | def forward(self, x): 161 | batch_size = x.shape[0] 162 | device = x.device 163 | 164 | # record metrics for loss calculation 165 | mus, sigmas = [0]*self.time_steps, [0]*self.time_steps 166 | 167 | # initialize the canvas matrix c on same device and the hidden state and cell state for the encoder and decoder 168 | c = torch.zeros(*x.shape).to(device) 169 | h_enc = torch.zeros(batch_size, self.lstm_size).to(device) 170 | c_enc = torch.zeros_like(h_enc) 171 | h_dec = torch.zeros(batch_size, self.lstm_size).to(device) 172 | c_dec = torch.zeros_like(h_dec) 173 | 174 | # run model forward (cf DRAW eq 3 - 8) 175 | for t in range(self.time_steps): 176 | x_hat = x.to(c.device) - torch.sigmoid(c) 177 | r = self.read(x, x_hat, h_dec) 178 | h_enc, c_enc = self.encoder(torch.cat([r, h_dec], dim=1), (h_enc, c_enc)) 179 | z_sample, mus[t], sigmas[t] = self.compute_q_distribution(h_enc) 180 | h_dec, c_dec = self.decoder(z_sample, (h_dec, c_dec)) 181 | c = c + self.write(h_dec) 182 | 183 | # return 184 | # data likelihood; used to compute L_x loss -- shape (B, H*W) 185 | # sequence of latent distributions Q; used to compute L_z loss (here Normal of shape (B, z_size, time_steps) 186 | return D.Bernoulli(logits=c), D.Normal(torch.stack(mus, dim=-1), torch.stack(sigmas, dim=-1)) 187 | 188 | @torch.no_grad() 189 | def generate(self, n_samples, args): 190 | samples_time_seq = [] 191 | 192 | # initialize model 193 | c = torch.zeros(n_samples, self.C * self.H * self.W).to(args.device) 194 | h_dec = torch.zeros(n_samples, self.lstm_size).to(args.device) 195 | c_dec = torch.zeros_like(h_dec).to(args.device) 196 | 197 | # run for the number of time steps 198 | for t in range(self.time_steps): 199 | z_sample = D.Normal(0,1).sample((n_samples, self.z_size)).to(args.device) 200 | h_dec, c_dec = self.decoder(z_sample, (h_dec, c_dec)) 201 | c = c + self.write(h_dec) 202 | x = D.Bernoulli(logits=c.view(n_samples, self.C, self.H, self.W)).probs 203 | 204 | samples_time_seq.append(x) 205 | 206 | return samples_time_seq 207 | 208 | 209 | def compute_filterbank_matrices(g_x, g_y, logvar, logdelta, H, W, attn_window_size): 210 | """ DRAW section 3.2 -- computes the parameters for an NxN grid of Gaussian filters over the input image. 
211 | Args 212 | g_x, g_y -- tensors of shape (B, 1); unnormalized center coords for the attention window 213 | logvar -- tensor of shape (B, 1); log variance for the Gaussian filters (filterbank matrices) on the attention window 214 | logdelta -- tensor of shape (B, 1); unnormalized stride for the spacing of the filters in the attention window 215 | H, W -- scalars; original image dimensions 216 | attn_window_size -- scalar; size of the attention window (specified by the read_size / write_size input args) 217 | 218 | Returns 219 | g_x, g_y -- tensors of shape (B, 1); normalized center coords of the attention window; 220 | delta -- tensor of shape (B, 1); stride for the spacing of the filters in the attention window 221 | mu_x, mu_y -- tensors of shape (B, attn_window_size); mean locations of the filters along the rows and columns 222 | F_x, F_y -- tensors of shape (B, N, W) and (B, N, H) where N=attention_window_size; filterbank matrices 223 | """ 224 | 225 | batch_size = g_x.shape[0] 226 | device = g_x.device 227 | 228 | # rescale attention window center coords and stride to ensure the initial patch covers the whole input image 229 | # eq 22 - 24 230 | g_x = 0.5 * (W + 1) * (g_x + 1) # (B, 1) 231 | g_y = 0.5 * (H + 1) * (g_y + 1) # (B, 1) 232 | delta = (max(H, W) - 1) / (attn_window_size - 1) * logdelta.exp() # (B, 1) 233 | 234 | # compute the means of the filter 235 | # eq 19 - 20 236 | mu_x = g_x + (torch.arange(1., 1. + attn_window_size).to(device) - 0.5*(attn_window_size + 1)) * delta # (B, N) 237 | mu_y = g_y + (torch.arange(1., 1. + attn_window_size).to(device) - 0.5*(attn_window_size + 1)) * delta # (B, N) 238 | 239 | # compute the filterbank matrices 240 | # B = batch dim; N = attn window size; H = original height; W = original width 241 | # eq 25 -- combines logvar=(B, 1, 1) * ( range=(B, 1, W) - mu=(B, N, 1) ) = out (B, N, W); then normalizes over W dimension; 242 | F_x = torch.exp(- 0.5 / logvar.exp().view(-1,1,1) * (torch.arange(1., 1. + W).repeat(batch_size, 1, 1).to(device) - mu_x.unsqueeze(-1))**2) 243 | F_x = F_x / torch.sum(F_x + 1e-8, dim=2, keepdim=True) # normalize over the coordinates of the input image 244 | # eq 26 245 | F_y = torch.exp(- 0.5 / logvar.exp().view(-1,1,1) * (torch.arange(1., 1.
+ H).repeat(batch_size, 1, 1).to(device) - mu_y.unsqueeze(-1))**2) 246 | F_y = F_y / torch.sum(F_y + 1e-8, dim=2, keepdim=True) # normalize over the coordinates of the input image 247 | 248 | # returns DRAW paper eq 22, 23, 24, 19, 20, 25, 26 249 | return g_x, g_y, delta, mu_x, mu_y, F_x, F_y 250 | 251 | 252 | def loss_fn(d, q, x, writer=None, step=None): 253 | """ 254 | Args 255 | d -- data likelihood distribution output by the model (Bernoulli) 256 | q -- approximation distribution to the latent variable z (Normal) 257 | """ 258 | # cf DRAW paper section 2 259 | 260 | # reconstruction loss L_x eq 9 -- negative log probability of x under the model d 261 | # (sum log probs over the pixels of each datapoint and mean over batch dim) 262 | # latent loss L_z eq 10 -- sum KL over temporal dimension (number of time steps) and mean over batch and z dims 263 | batch_size = x.shape[0] 264 | p_prior = D.Normal(torch.tensor(0., device=x.device), torch.tensor(1., device=x.device)) 265 | loss_log_likelihood = - d.log_prob(x).sum(-1).mean(0) # sum over pixels (-1), mean over datapoints (0) 266 | loss_kl = D.kl.kl_divergence(q, p_prior).sum(dim=[-2,-1]).mean(0) # sum over time_steps (-1) and z (-2), mean over datapoints (0) 267 | 268 | if writer: 269 | writer.add_scalar('loss_log_likelihood', loss_log_likelihood, step) 270 | writer.add_scalar('loss_kl', loss_kl, step) 271 | 272 | return loss_log_likelihood + loss_kl 273 | 274 | 275 | # -------------------- 276 | # Train and eval 277 | # -------------------- 278 | 279 | def train_epoch(model, dataloader, loss_fn, optimizer, epoch, writer, args): 280 | model.train() 281 | 282 | with tqdm(total=len(dataloader), desc='epoch {} of {}'.format(epoch+1, args.n_epochs)) as pbar: 283 | time.sleep(0.1) 284 | 285 | for i, (x, _) in enumerate(dataloader): 286 | global_step = epoch * len(dataloader) + i + 1 287 | 288 | x = x.view(x.shape[0], -1).to(args.device) 289 | 290 | (d, q) = model(x) 291 | 292 | loss = loss_fn(d, q, x, writer, global_step) 293 | 294 | optimizer.zero_grad() 295 | loss.backward() 296 | 297 | # record grad norm and clip 298 | if args.verbose: 299 | grad_norm = 0 300 | for name, p in model.named_parameters(): 301 | grad_norm += p.grad.norm().item() if p.grad is not None else 0 302 | writer.add_scalar('grad_norm', grad_norm, global_step) 303 | nn.utils.clip_grad_norm_(model.parameters(), 10) 304 | 305 | optimizer.step() 306 | 307 | # update tracking 308 | pbar.set_postfix(loss='{:.3f}'.format(loss.item())) 309 | pbar.update() 310 | 311 | if i % args.log_interval == 0: 312 | writer.add_scalar('loss', loss.item(), global_step) 313 | 314 | 315 | def train_and_evaluate(model, train_dataloader, test_dataloader, loss_fn, optimizer, writer, args): 316 | start_epoch = 0 317 | 318 | if args.restore_file: 319 | print('Restoring parameters from {}'.format(args.restore_file)) 320 | start_epoch = utils.load_checkpoint(args.restore_file, [model], [optimizer], map_location=args.device.type) 321 | args.n_epochs += start_epoch - 1 322 | print('Resuming training from epoch {}'.format(start_epoch)) 323 | 324 | for epoch in range(start_epoch, args.n_epochs): 325 | train_epoch(model, train_dataloader, loss_fn, optimizer, epoch, writer, args) 326 | # evaluate(model, test_dataloader, loss_fn, writer, args, epoch) 327 | 328 | # snapshot at end of epoch 329 | if args.save_model: 330 | utils.save_checkpoint({'epoch': epoch + 1, 331 | 'model_state_dicts': [model.state_dict()], 332 | 'optimizer_state_dicts': [optimizer.state_dict()]}, 333 | checkpoint=args.output_dir, 334 | 
quiet=True) 335 | 336 | 337 | @torch.no_grad() 338 | def evaluate(model, dataloader, loss_fn, writer, args, epoch=None): 339 | model.eval() 340 | 341 | # sample the generation model 342 | samples_time_seq = model.generate(args.test_batch_size**2, args) 343 | samples = samples_time_seq[-1] # grab the final sample 344 | # pull targets to search closest neighbor to 345 | # right-most column of a nxn image grid where n is args.test_batch_size -- start at index n-1 and skip by n (e.g 9, 19, ...) 346 | targets = samples[args.test_batch_size - 1 :: args.test_batch_size].view(args.test_batch_size, -1) 347 | # initialize a large max L2 pixel distance and tensor for l2-closest neighbors 348 | max_distances = 100 * samples[0].numel() * torch.ones(args.test_batch_size).to(args.device) 349 | closest_neighbors = torch.zeros_like(targets) 350 | 351 | # compute ELBO on dataset 352 | cum_loss = 0 353 | for x, _ in tqdm(dataloader): 354 | x = x.view(x.shape[0], -1).float().to(args.device) 355 | 356 | # run through model and aggregate loss 357 | d, q = model(x) 358 | # aggregate loss 359 | loss = loss_fn(d, q, x) 360 | cum_loss += loss.item() 361 | 362 | # find closest neighbors to the targets sampled above - l2 distance between targets and images in minibatch 363 | distances = F.pairwise_distance(x, targets) 364 | mask = distances < max_distances 365 | max_distances[mask] = distances[mask] 366 | closest_neighbors[mask] = x[mask] 367 | 368 | cum_loss /= len(dataloader) 369 | # output loss 370 | print('Evaluation ELBO: {:.2f}'.format(cum_loss)) 371 | writer.add_scalar('Evaluation ELBO', cum_loss, epoch) 372 | 373 | # visualize generated samples and closest neighbors (cf DRAW paper fig 6) 374 | generated = make_grid(samples.cpu(), nrow=samples.shape[0]//args.test_batch_size) 375 | spacer = torch.ones_like(generated)[:,:,:2] 376 | neighbors = make_grid(closest_neighbors.view(-1, *args.image_dims).cpu(), nrow=1) 377 | images = torch.cat([generated, spacer, neighbors], dim=-1) 378 | save_image(images, os.path.join(args.output_dir, 379 | 'evaluation_sample' + (epoch!=None)*'_epoch_{}'.format(epoch) + '.png')) 380 | writer.add_image('generated images', images, epoch) 381 | 382 | 383 | @torch.no_grad() 384 | def generate(model, writer, args, n_samples=64): 385 | import math 386 | 387 | # generate samples 388 | samples_time_seq = model.generate(n_samples, args) 389 | 390 | # visualize generation sequence (cf DRAW paper fig 7) 391 | images = torch.stack(samples_time_seq, dim=1).view(-1, *args.image_dims) # reshape to (10*time_steps, 1, 28, 28) 392 | images = make_grid(images, nrow=len(samples_time_seq), padding=1, pad_value=1) 393 | save_name = 'generated_sequences_r{}_w{}_steps{}.png'.format(args.read_size, args.write_size, args.time_steps) 394 | save_image(images, os.path.join(args.output_dir, save_name)) 395 | writer.add_image(save_name, images) 396 | 397 | # make gif 398 | for i in range(len(samples_time_seq)): 399 | # convert sequence of image tensors to 8x8 grid 400 | image = make_grid(samples_time_seq[i].cpu(), nrow=int(math.sqrt(n_samples)), padding=1, normalize=True, pad_value=1) 401 | # make into gif 402 | samples_time_seq[i] = image.data.numpy().transpose(1,2,0) 403 | 404 | import imageio 405 | imageio.mimsave(os.path.join(args.output_dir, 'generated_{}_time_steps.gif'.format(args.time_steps)), samples_time_seq) 406 | 407 | 408 | # -------------------- 409 | # Monitor training 410 | # -------------------- 411 | 412 | def record_attn_params(self, in_tensor, out_tensor, bank_name): 413 | g_x, g_y, logvar, 
logdelta, loggamma = out_tensor.cpu().split(split_size=1, dim=1) 414 | writer.add_scalar(bank_name + ' g_x', g_x.mean()) 415 | writer.add_scalar(bank_name + ' g_y', g_y.mean()) 416 | writer.add_scalar(bank_name + ' var', logvar.exp().mean()) 417 | writer.add_scalar(bank_name + ' exp_logdelta', logdelta.exp().mean()) 418 | writer.add_scalar(bank_name + ' gamma', loggamma.exp().mean()) 419 | 420 | def record_attn_grads(self, in_tensor, out_tensor, bank_name): 421 | g_x, g_y, logvar, logdelta, loggamma = out_tensor[0].cpu().split(split_size=1, dim=1) 422 | writer.add_scalar(bank_name + ' grad_var', logvar.exp().mean()) 423 | writer.add_scalar(bank_name + ' grad_logdelta', logdelta.exp().mean()) 424 | 425 | def record_forward_backward_attn_hooks(model): 426 | from functools import partial 427 | 428 | model.read_attention_linear.register_forward_hook(partial(record_attn_params, bank_name='read')) 429 | model.write_attention_linear.register_forward_hook(partial(record_attn_params, bank_name='write')) 430 | model.write_attention_linear.register_backward_hook(partial(record_attn_grads, bank_name='write')) 431 | 432 | 433 | # -------------------- 434 | # Main 435 | # -------------------- 436 | 437 | if __name__ == '__main__': 438 | args = parser.parse_args() 439 | 440 | if not os.path.isdir(args.output_dir): 441 | os.makedirs(args.output_dir) 442 | 443 | writer = utils.set_writer(args.output_dir if args.restore_file is None else args.restore_file, 444 | (args.restore_file==None)*'', # suffix only when not restoring 445 | args.restore_file is not None) 446 | # update output_dir with the writer unique directory 447 | args.output_dir = writer.file_writer.get_logdir() 448 | 449 | args.device = torch.device('cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda is not None else 'cpu') 450 | 451 | # set seed 452 | torch.manual_seed(5) 453 | if args.device.type is 'cuda': torch.cuda.manual_seed(11192018) 454 | 455 | # set up model 456 | model = DRAW(args).to(args.device) 457 | 458 | if args.verbose: 459 | record_forward_backward_attn_hooks(model) 460 | 461 | # train 462 | if args.train: 463 | # optimizer 464 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999)) 465 | # dataloaders 466 | train_dataloader = fetch_dataloader(args, args.train_batch_size, train=True) 467 | test_dataloader = fetch_dataloader(args, args.test_batch_size, train=False) 468 | # run training 469 | print('Starting training with args:\n', args) 470 | writer.add_text('Params', pprint.pformat(args.__dict__)) 471 | with open(os.path.join(args.output_dir, 'params.txt'), 'w') as f: 472 | pprint.pprint(args.__dict__, f) 473 | train_and_evaluate(model, train_dataloader, test_dataloader, loss_fn, optimizer, writer, args) 474 | 475 | # eval 476 | if args.evaluate: 477 | print('Restoring parameters from {}'.format(args.restore_file)) 478 | _ = utils.load_checkpoint(args.restore_file, [model]) 479 | print('Evaluating model with args:\n', args) 480 | # get test dataloader 481 | dataloader = fetch_dataloader(args, args.test_batch_size, train=False) 482 | # evaluate 483 | evaluate(model, dataloader, loss_fn, writer, args) 484 | 485 | # generate 486 | if args.generate: 487 | print('Restoring parameters from {}'.format(args.restore_file)) 488 | _ = utils.load_checkpoint(args.restore_file, [model]) 489 | print('Generating images from model with args:\n', args) 490 | generate(model, writer, args) 491 | 492 | 493 | writer.close() 494 | 495 | 
-------------------------------------------------------------------------------- /draw_test_attn.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torchvision.datasets import MNIST 5 | import torchvision.transforms as T 6 | 7 | import matplotlib.pyplot as plt 8 | from matplotlib.patches import Rectangle 9 | import matplotlib.gridspec as gridspec 10 | from PIL import Image 11 | import numpy as np 12 | 13 | from draw import compute_filterbank_matrices 14 | 15 | 16 | # -------------------- 17 | # DRAW paper figure 3 18 | # -------------------- 19 | 20 | def plot_attn_window(img_tensor, mu_x, mu_y, delta, sigma, attn_window_size, ax): 21 | ax.imshow(img_tensor.squeeze().data.numpy(), cmap='gray') 22 | mu_x = mu_x.flatten() 23 | mu_y = mu_y.flatten() 24 | x = mu_x[0] 25 | y = mu_y[0] 26 | w = mu_y[-1] - mu_y[0] 27 | h = mu_x[-1] - mu_x[0] 28 | ax.add_patch(Rectangle((x, y), w, h, facecolor='none', edgecolor='lime', linewidth=5*sigma, alpha=0.7)) 29 | 30 | def plot_filtered_attn_window(img_tensor, F_x, F_y, g_x, g_y, H, W, attn_window_size, ax): 31 | ax.set_xlim(0, W) 32 | ax.set_ylim(0, H) 33 | ax.imshow((F_y @ img_tensor @ F_x.permute(0,2,1)).squeeze(), cmap='gray', extent=(g_x - attn_window_size/2, 34 | g_x + attn_window_size/2, 35 | g_y - attn_window_size/2, 36 | g_y + attn_window_size/2)) 37 | def test_attn_window_params(): 38 | dataset = MNIST(root='./data', transform=T.ToTensor(), train=True, download=False) 39 | img = dataset[0][0] 40 | batch_size, H, W = img.shape 41 | 42 | fig = plt.figure()#figsize=(8,6)) 43 | gs = gridspec.GridSpec(nrows=3, ncols=3, width_ratios=[4,2,1]) 44 | 45 | # Figure 3. 46 | # Left: A 3 × 3 grid of filters superimposed on an image. The stride (δ) and centre location (gX , gY ) are indicated. 47 | attn_window_size = 3 48 | g_x = torch.tensor([[-0.2]]) 49 | g_y = torch.tensor([[0.]]) 50 | logvar = torch.tensor([[1.]]) 51 | logdelta = torch.tensor([[-1.]]) 52 | g_x, g_y, delta, mu_x, mu_y, F_x, F_y = compute_filterbank_matrices(g_x, g_y, logvar, logdelta, H, W, attn_window_size) 53 | 54 | ax = fig.add_subplot(gs[:,0]) 55 | ax.imshow(img.squeeze().data.numpy(), cmap='gray') 56 | ax.scatter(g_x.numpy(), g_y.numpy(), s=150, color='orange', alpha=0.8) 57 | ax.scatter(mu_x.view(1, -1).repeat(attn_window_size, 1).numpy(), 58 | mu_y.view(-1,1).repeat(1, attn_window_size).numpy(), s=100, color='lime', alpha=0.8) 59 | 60 | # Right: Three N × N patches extracted from the image (N = 12). 61 | # The green rectangles on the left indicate the boundary and precision (σ) of the patches, while the patches themselves are shown to the right. 62 | # The top patch has a small δ and high σ, giving a zoomed-in but blurry view of the centre of the digit; 63 | # the middle patch has large δ and low σ, effectively downsampling the whole image; 64 | # and the bottom patch has high δ and σ. 65 | attn_window_size = 12 66 | logdeltas = [-1., -0.5, 0.] 67 | sigmas = [1., 0.5, 3.] 
68 | 69 | for i, (logdelta, sigma) in enumerate(zip(logdeltas, sigmas)): 70 | g_x = torch.tensor([[-0.2]]) 71 | g_y = torch.tensor([[0.]]) 72 | logvar = torch.tensor(sigma**2).float().view(1,-1).log() 73 | logdelta = torch.tensor(logdelta).float().view(1,-1) 74 | 75 | g_x, g_y, delta, mu_x, mu_y, F_x, F_y = compute_filterbank_matrices(g_x, g_y, logvar, logdelta, H, W, attn_window_size) 76 | 77 | # plot attention window 78 | ax = fig.add_subplot(gs[i,1]) 79 | plot_attn_window(img, mu_x, mu_y, delta, sigma, attn_window_size, ax) 80 | 81 | # plot attention window zoom in 82 | ax = fig.add_subplot(gs[i,2]) 83 | plot_filtered_attn_window(img, F_x, F_y, g_x, g_y, H, W, attn_window_size, ax) 84 | 85 | for ax in fig.axes: 86 | ax.axis('off') 87 | 88 | plt.tight_layout() 89 | plt.savefig('images/draw_fig_3.png') 90 | plt.close() 91 | 92 | 93 | # -------------------- 94 | # DRAW paper figure 4 -- Test read write attention on 95 | # -------------------- 96 | 97 | def test_read_write_attn(): 98 | #im = cv2.imread('elephant_r.png') 99 | im = np.asarray(Image.open('images/elephant.png')) 100 | img = torch.from_numpy(im).float() 101 | img /= 255. # normalize to 0-1 102 | img = img.permute(2,0,1) # to torch standard (C, H, W) 103 | 104 | print('image dims -- ', img.shape) 105 | 106 | # store dims 107 | C, H, W = img.shape 108 | attn_window_size = 12 109 | 110 | 111 | # filter params 112 | g_x = torch.tensor([[0.5]]) 113 | g_y = torch.tensor([[0.5]]) 114 | #logvar = torch.tensor([[1.]]).float().log() 115 | #logdelta = torch.tensor([[3.]]).float().log() 116 | logvar = torch.tensor([[1.]]) 117 | logdelta = torch.tensor([[-1.]]) 118 | 119 | g_x, g_y, delta, mu_x, mu_y, F_x, F_y = compute_filterbank_matrices(g_x, g_y, logvar, logdelta, H, W, attn_window_size) 120 | 121 | print('delta -- ', delta) 122 | print('mu_x -- ', mu_x) 123 | print('mu_y -- ', mu_y) 124 | print('F_x shape -- ', F_x.shape) 125 | 126 | mu_x = mu_x.flatten() 127 | mu_y = mu_y.flatten() 128 | 129 | 130 | # read image 131 | read = F_y @ img @ F_x.transpose(-2,-1) 132 | print('read image shape -- ', read.shape) 133 | # reconstruct image 134 | #read.fill_(1) 135 | recon_img = 10 * F_y.transpose(-2,-1) @ read @ F_x 136 | 137 | # plot 138 | fig = plt.figure(figsize=(9, 3)) 139 | gs = gridspec.GridSpec(nrows=1, ncols=3, width_ratios=[W,attn_window_size,W]) 140 | 141 | 142 | # show original image with attention bbox 143 | ax = fig.add_subplot(gs[0,0]) 144 | ax.imshow(im) 145 | x = mu_x[0] 146 | y = mu_y[0] 147 | w = mu_y[-1] - mu_y[0] 148 | h = mu_x[-1] - mu_x[0] 149 | ax.add_patch(Rectangle((x, y), w, h, facecolor='none', edgecolor='lime', linewidth=5, alpha=0.7)) 150 | 151 | # show attention patch 152 | ax = fig.add_subplot(gs[0, 1]) 153 | ax.set_xlim(0, attn_window_size) 154 | ax.set_ylim(0, H) 155 | ax.imshow(read.squeeze().data.numpy().transpose(1,2,0), extent = (0, attn_window_size, H/2 - attn_window_size/2, H/2 + attn_window_size/2)) 156 | 157 | # show reconstruction 158 | ax = fig.add_subplot(gs[0, 2]) 159 | ax.imshow(recon_img.squeeze().data.numpy().transpose(1,2,0)) 160 | 161 | plt.tight_layout() 162 | for ax in plt.gcf().axes: 163 | ax.axis('off') 164 | plt.subplots_adjust(top=1, bottom=0, right=1, left=0) 165 | plt.savefig('images/draw_fig_4.png') 166 | plt.close() 167 | 168 | 169 | 170 | if __name__ == '__main__': 171 | test_attn_window_params() 172 | test_read_write_attn() 173 | -------------------------------------------------------------------------------- /environment.yml: 
-------------------------------------------------------------------------------- 1 | name: generative_models 2 | channels: 3 | - pytorch 4 | - defaults 5 | - conda-forge 6 | dependencies: 7 | - blas=1.0 8 | - ca-certificates=2019.5.15 9 | - certifi=2019.6.16 10 | - cffi=1.11.5 11 | - cycler=0.10.0 12 | - freetype=2.9.1 13 | - intel-openmp=2019.1 14 | - jpeg=9b 15 | - kiwisolver=1.0.1 16 | - libedit=3.1.20170329 17 | - libffi=3.2.1 18 | - libpng=1.6.35 19 | - libtiff=4.0.9 20 | - matplotlib=3.0.1 21 | - mkl=2018.0.3 22 | - mkl_fft=1.0.6 23 | - mkl_random=1.0.1 24 | - ncurses=6.1 25 | - ninja=1.8.2 26 | - numpy=1.15.4 27 | - numpy-base=1.15.4 28 | - olefile=0.46 29 | - openssl=1.1.1c 30 | - pandas=0.24.2 31 | - pillow=5.3.0 32 | - pip=18.1 33 | - pycparser=2.19 34 | - pyparsing=2.3.0 35 | - python=3.7.1 36 | - python-dateutil=2.7.5 37 | - pytorch=1.1.0 38 | - pytz=2018.7 39 | - readline=7.0 40 | - scipy=1.1.0 41 | - setuptools=40.6.2 42 | - six=1.11.0 43 | - sqlite=3.25.3 44 | - tk=8.6.8 45 | - torchvision=0.3.0 46 | - tornado=5.1.1 47 | - tqdm=4.28.1 48 | - wheel=0.32.3 49 | - xz=5.2.4 50 | - zlib=1.2.11 51 | - pip: 52 | - chardet==3.0.4 53 | - idna==2.7 54 | - lmdb==0.96 55 | - observations==0.1.4 56 | - protobuf==3.6.1 57 | - requests==2.20.1 58 | - tensorboardx==1.4 59 | - urllib3==1.24.1 60 | prefix: /Users/Kamen/anaconda3/envs/generative_models 61 | 62 | -------------------------------------------------------------------------------- /images/air/air_count.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/air/air_count.png -------------------------------------------------------------------------------- /images/air/air_elbo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/air/air_elbo.png -------------------------------------------------------------------------------- /images/air/image_recons_270.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/air/image_recons_270.png -------------------------------------------------------------------------------- /images/basic_vae/reconstruction_at_epoch_24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/basic_vae/reconstruction_at_epoch_24.png -------------------------------------------------------------------------------- /images/basic_vae/sample_at_epoch_24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/basic_vae/sample_at_epoch_24.png -------------------------------------------------------------------------------- /images/basic_vae/tsne_embedding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/basic_vae/tsne_embedding.png -------------------------------------------------------------------------------- /images/dcgan/latent_var_grid_sample_c1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/dcgan/latent_var_grid_sample_c1.png -------------------------------------------------------------------------------- /images/draw/draw_fig_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/draw/draw_fig_3.png -------------------------------------------------------------------------------- /images/draw/draw_fig_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/draw/draw_fig_4.png -------------------------------------------------------------------------------- /images/draw/elephant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/draw/elephant.png -------------------------------------------------------------------------------- /images/draw/generated_32_time_steps.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/draw/generated_32_time_steps.gif -------------------------------------------------------------------------------- /images/infogan/latent_var_grid_sample_c1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/infogan/latent_var_grid_sample_c1.png -------------------------------------------------------------------------------- /images/infogan/latent_var_grid_sample_c2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/infogan/latent_var_grid_sample_c2.png -------------------------------------------------------------------------------- /images/ssvae/analogies_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/ssvae/analogies_sample.png -------------------------------------------------------------------------------- /images/ssvae/latent_var_grid_sample_c1_y2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/ssvae/latent_var_grid_sample_c1_y2.png -------------------------------------------------------------------------------- /images/ssvae/latent_var_grid_sample_c2_y4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/ssvae/latent_var_grid_sample_c2_y4.png -------------------------------------------------------------------------------- /images/vqvae2/128x128_bits3_eval_reconstruction_step_87300_bottom.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/vqvae2/128x128_bits3_eval_reconstruction_step_87300_bottom.png -------------------------------------------------------------------------------- /images/vqvae2/128x128_bits3_eval_reconstruction_step_87300_original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/vqvae2/128x128_bits3_eval_reconstruction_step_87300_original.png -------------------------------------------------------------------------------- /images/vqvae2/128x128_bits3_eval_reconstruction_step_87300_top.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/vqvae2/128x128_bits3_eval_reconstruction_step_87300_top.png -------------------------------------------------------------------------------- /images/vqvae2/generation_sample_step_52440_top_b128_c128_outstack10_bottom_b16_c128_nres20_condstack10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kamenbliznashki/generative_models/428671c4966ebd7fa0c2a5ed58f69ac4e30c5e62/images/vqvae2/generation_sample_step_52440_top_b128_c128_outstack10_bottom_b16_c128_nres20_condstack10.png -------------------------------------------------------------------------------- /infogan.py: -------------------------------------------------------------------------------- 1 | """ 2 | InfoGAN -- https://arxiv.org/abs/1606.03657 3 | 4 | Follows the Tensorflow implementation at http://www.depthfirstlearning.com/2018/InfoGAN 5 | 6 | """ 7 | 8 | 9 | import os 10 | import argparse 11 | from tqdm import tqdm 12 | import time 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.distributions as dist 18 | from torch.distributions.one_hot_categorical import OneHotCategorical 19 | from torch.utils.data import DataLoader 20 | from torchvision.datasets import MNIST 21 | import torchvision.transforms as T 22 | from torchvision.utils import save_image, make_grid 23 | 24 | import utils 25 | 26 | parser = argparse.ArgumentParser() 27 | 28 | # training params 29 | parser.add_argument('--batch_size', type=int, default=128) 30 | parser.add_argument('--n_epochs', type=int, default=1) 31 | parser.add_argument('--noise_dim', type=int, default=62, help='Size of the categorical latent representation') 32 | parser.add_argument('--cat_dim', type=int, default=10, help='Size of the categorical latent representation') 33 | parser.add_argument('--cont_dim', type=int, default=2, help='Size of the continuous latent representation') 34 | parser.add_argument('--info_reg_coeff', default=1., help='The weight of the MI regularization hyperparameter') 35 | parser.add_argument('--g_lr', default=1e-3, help='Generator learning rate') 36 | parser.add_argument('--d_lr', default=2e-4, help='Discriminator learning rate') 37 | parser.add_argument('--log_interval', default=100) 38 | parser.add_argument('--cuda', type=int, help='Which cuda device to use') 39 | parser.add_argument('--mini_data', action='store_true') 40 | # eval params 41 | parser.add_argument('--evaluate_on_grid', action='store_true') 42 | # data paths 43 | 
parser.add_argument('--save_model', action='store_true') 44 | parser.add_argument('--data_dir', default='./data') 45 | parser.add_argument('--output_dir', default='./results/infogan') 46 | parser.add_argument('--restore_file', help='Path to .pt checkpoint file for Discriminator and Generator') 47 | 48 | 49 | 50 | 51 | # -------------------- 52 | # Data 53 | # -------------------- 54 | 55 | def fetch_dataloader(args, train=True, download=True, mini_size=128): 56 | # load dataset and init in the dataloader 57 | 58 | transforms = T.Compose([T.ToTensor()]) 59 | dataset = MNIST(root=args.data_dir, train=train, download=download, transform=transforms) 60 | 61 | # load dataset and init in the dataloader 62 | if args.mini_data: 63 | if train: 64 | dataset.train_data = dataset.train_data[:mini_size] 65 | dataset.train_labels = dataset.train_labels[:mini_size] 66 | else: 67 | dataset.test_data = dataset.test_data[:mini_size] 68 | dataset.test_labels = dataset.test_labels[:mini_size] 69 | 70 | 71 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.device.type is 'cuda' else {} 72 | 73 | dl = DataLoader(dataset, batch_size=args.batch_size, shuffle=train, drop_last=True, **kwargs) 74 | 75 | return dl 76 | 77 | # -------------------- 78 | # Model 79 | # -------------------- 80 | 81 | class Flatten(nn.Module): 82 | def forward(self, x): 83 | return x.view(x.shape[0], -1) 84 | 85 | class Unflatten(nn.Module): 86 | def __init__(self, B, C, H, W): 87 | super().__init__() 88 | self.B = B 89 | self.C = C 90 | self.H = H 91 | self.W = W 92 | 93 | def forward(self, x): 94 | return x.reshape(self.B, self.C, self.H, self.W) 95 | 96 | 97 | class Discriminator(nn.Module): 98 | """ base for the Discriminator (D) and latent recognition network (Q) """ 99 | def __init__(self): 100 | super().__init__() 101 | # base network shared between discriminator D and recognition network Q 102 | self.base_net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1), # out (B, 64, 14, 14) 103 | nn.LeakyReLU(0.1, True), 104 | nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False), # out (B, 128, 7, 7) 105 | nn.BatchNorm2d(128), 106 | nn.LeakyReLU(0.1, True), 107 | Flatten(), 108 | nn.Linear(128*7*7, 1024, bias=False), 109 | nn.BatchNorm1d(1024), 110 | nn.LeakyReLU(0.1, True)) 111 | 112 | # discriminator -- real vs fake binary output 113 | self.d = nn.Linear(1024, 1) 114 | 115 | def forward(self, x): 116 | x = self.base_net(x).squeeze() 117 | logits_real = self.d(x) 118 | # return feature representation and real vs fake prob 119 | return x, dist.Bernoulli(logits=logits_real) 120 | 121 | 122 | class Q(nn.Module): 123 | """ Latent space recognition network; shares base network of the discriminator """ 124 | def __init__(self, cat_dim, cont_dim, fix_cont_std=True): 125 | super().__init__() 126 | self.cat_dim = cat_dim 127 | self.cont_dim = cont_dim 128 | self.fix_cont_std = fix_cont_std 129 | 130 | # recognition network for latent vars ie encoder, shared between the factors of q 131 | self.encoder = nn.Sequential(nn.Linear(1024, 128, bias=False), 132 | nn.BatchNorm1d(128), 133 | nn.LeakyReLU(0.1, True)) 134 | 135 | # the factors of q -- 1 categorical and 2 continuous variables 136 | self.q = nn.Linear(128, cat_dim + 2 * cont_dim) 137 | 138 | def forward(self, x): 139 | # latent space encoding 140 | z = self.encoder(x) 141 | 142 | logits_cat, cont_mu, cont_var = torch.split(self.q(z), [self.cat_dim, self.cont_dim, self.cont_dim], dim=-1) 143 | 144 | if self.fix_cont_std: 145 | cont_sigma = 
torch.ones_like(cont_mu) 146 | else: 147 | cont_sigma = F.softplus(cont_var) 148 | 149 | q_cat = dist.Categorical(logits=logits_cat) 150 | q_cont = dist.Normal(loc=cont_mu, scale=cont_sigma) 151 | 152 | return q_cat, q_cont 153 | 154 | 155 | class Generator(nn.Module): 156 | def __init__(self): 157 | super().__init__() 158 | self.net = nn.Sequential(nn.Linear(74, 1024, bias=False), 159 | nn.BatchNorm1d(1024), 160 | nn.ReLU(True), 161 | nn.Linear(1024, 7*7*128), 162 | nn.BatchNorm1d(7*7*128), 163 | nn.ReLU(True), 164 | Unflatten(-1, 128, 7, 7), 165 | nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False), 166 | nn.BatchNorm2d(64), 167 | nn.ReLU(True), 168 | nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False), 169 | nn.Sigmoid()) 170 | 171 | def forward(self, x): 172 | return self.net(x) 173 | 174 | 175 | def initialize_weights(m): 176 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 177 | m.weight.data.normal_(mean=1., std=0.02) 178 | m.bias.data.fill_(0.) 179 | else: 180 | try: 181 | m.weight.data.normal_(std=0.02) 182 | except AttributeError: # skip activation layers 183 | pass 184 | 185 | 186 | # -------------------- 187 | # Train 188 | # -------------------- 189 | 190 | def sample_z(args): 191 | # generate samples from the prior 192 | z_cat = OneHotCategorical(logits=torch.zeros(args.batch_size, args.cat_dim)).sample() 193 | z_noise = dist.Uniform(-1, 1).sample(torch.Size((args.batch_size, args.noise_dim))) 194 | z_cont = dist.Uniform(-1, 1).sample(torch.Size((args.batch_size, args.cont_dim))) 195 | 196 | # concatenate the incompressible noise, discrete latest, and continuous latents 197 | z = torch.cat([z_noise, z_cat, z_cont], dim=1) 198 | 199 | return z.to(args.device), z_cat.to(args.device), z_noise.to(args.device), z_cont.to(args.device) 200 | 201 | 202 | def info_loss_fn(cat_fake, cont_fake, z_cat, z_cont, args): 203 | log_prob_cat = cat_fake.log_prob(z_cat.nonzero()[:,1]).mean() # equivalent to pytorch cross_entropy loss fn 204 | log_prob_cont = cont_fake.log_prob(z_cont).sum(1).mean() 205 | 206 | info_loss = - args.info_reg_coeff * (log_prob_cat + log_prob_cont) 207 | return log_prob_cat, log_prob_cont, info_loss 208 | 209 | 210 | 211 | def train_epoch(D, Q, G, dataloader, d_optimizer, g_optimizer, epoch, writer, args): 212 | 213 | fixed_z, _, _, _ = sample_z(args) 214 | 215 | real_labels = torch.ones(args.batch_size, 1, device=args.device).requires_grad_(False) 216 | fake_labels = torch.zeros(args.batch_size, 1, device=args.device).requires_grad_(False) 217 | 218 | with tqdm(total=len(dataloader), desc='epoch {} of {}'.format(epoch+1, args.n_epochs)) as pbar: 219 | time.sleep(0.1) 220 | 221 | for i, (x, _) in enumerate(dataloader): 222 | D.train() 223 | G.train() 224 | 225 | x = x.to(args.device) 226 | # x = 2*x - 0.5 227 | 228 | 229 | # train Generator 230 | z, z_cat, z_noise, z_cont = sample_z(args) 231 | 232 | generated = G(z) 233 | x_pre_q, d_fake = D(generated) 234 | q_cat, q_cont = Q(x_pre_q) 235 | 236 | gan_g_loss = - d_fake.log_prob(real_labels).mean() # equivalent to pytorch binary_cross_entropy_with_logits loss fn 237 | log_prob_cat, log_prob_cont, info_loss = info_loss_fn(q_cat, q_cont, z_cat, z_cont, args) 238 | 239 | g_loss = gan_g_loss + info_loss 240 | 241 | g_optimizer.zero_grad() 242 | g_loss.backward() 243 | g_optimizer.step() 244 | 245 | 246 | # train Discriminator 247 | _, d_real = D(x) 248 | x_pre_q, d_fake = D(generated.detach()) 249 | q_cat, q_cont = Q(x_pre_q) 250 | 251 | gan_d_loss = - 
d_real.log_prob(real_labels).mean() - d_fake.log_prob(fake_labels).mean() 252 | log_prob_cat, log_prob_cont, info_loss = info_loss_fn(q_cat, q_cont, z_cat, z_cont, args) 253 | 254 | d_loss = gan_d_loss + info_loss 255 | 256 | d_optimizer.zero_grad() 257 | d_loss.backward() 258 | d_optimizer.step() 259 | 260 | 261 | # update tracking 262 | pbar.set_postfix(log_prob_cat='{:.3f}'.format(log_prob_cat.item()), 263 | log_prob_cont='{:.3f}'.format(log_prob_cont.item()), 264 | d_loss='{:.3f}'.format(gan_d_loss.item()), 265 | g_loss='{:.3f}'.format(gan_g_loss.item()), 266 | i_loss='{:.3f}'.format(info_loss.item())) 267 | pbar.update() 268 | 269 | if i % args.log_interval == 0: 270 | step = epoch 271 | writer.add_scalar('gan_d_loss', gan_d_loss.item(), step) 272 | writer.add_scalar('gan_g_loss', gan_g_loss.item(), step) 273 | writer.add_scalar('info_loss', info_loss.item(), step) 274 | writer.add_scalar('log_prob_cat', log_prob_cat.item(), step) 275 | writer.add_scalar('log_prob_cont', log_prob_cont.item(), step) 276 | # sample images 277 | with torch.no_grad(): 278 | G.eval() 279 | fake_images = G(fixed_z) 280 | writer.add_image('generated', make_grid(fake_images[:10].cpu(), nrow=10, normalize=True, padding=1), step) 281 | save_image(fake_images[:10].cpu(), 282 | os.path.join(args.output_dir, 'generated_sample_epoch_{}.png'.format(epoch)), 283 | nrow=10) 284 | 285 | 286 | def train(D, Q, G, dataloader, d_optimizer, g_optimizer, writer, args): 287 | 288 | print('Starting training with args:\n', args) 289 | 290 | start_epoch = 0 291 | 292 | if args.restore_file: 293 | print('Restoring parameters from {}'.format(args.restore_file)) 294 | start_epoch = utils.load_checkpoint(args.restore_file, [D, Q, G], [d_optimizer, g_optimizer], map_location=args.device.type) 295 | args.n_epochs += start_epoch - 1 296 | print('Resuming training from epoch {}'.format(start_epoch)) 297 | 298 | for epoch in range(start_epoch, args.n_epochs): 299 | train_epoch(D, Q, G, dataloader, d_optimizer, g_optimizer, epoch, writer, args) 300 | 301 | # snapshot at end of epoch 302 | if args.save_model: 303 | utils.save_checkpoint({'epoch': epoch + 1, 304 | 'model_state_dicts': [D.state_dict(), Q.state_dict(), G.state_dict()], 305 | 'optimizer_state_dicts': [d_optimizer.state_dict(), g_optimizer.state_dict()]}, 306 | checkpoint=args.output_dir, 307 | quiet=True) 308 | 309 | # -------------------- 310 | # Evaluate 311 | # -------------------- 312 | 313 | @torch.no_grad() 314 | def evaluate_on_grid(G, writer, args): 315 | # sample noise randomly 316 | z_noise = torch.empty(100, args.noise_dim).uniform_(-1,1) 317 | # order the categorical latent 318 | z_cat = torch.eye(10).repeat(10,1) 319 | # order the first continuous latent 320 | c = torch.linspace(-2, 2, 10).view(-1,1).repeat(1,10).reshape(-1,1) 321 | z_cont = torch.cat([c, torch.zeros_like(c)], dim=1).reshape(100, 2) 322 | 323 | # combine into z and pass through generator 324 | z = torch.cat([z_noise, z_cat, z_cont], dim=1).to(args.device) 325 | fake_images = G(z) 326 | writer.add_image('c1 cont generated', make_grid(fake_images.cpu(), nrow=10, normalize=True, padding=1)) 327 | save_image(fake_images.cpu(), 328 | os.path.join(args.output_dir, 'latent_var_grid_sample_c1.png'), 329 | nrow=10) 330 | 331 | # order second continuous latent; combine into z and pass through generator 332 | z_cont = z_cont.flip(1) 333 | z = torch.cat([z_noise, z_cat, z_cont], dim=1).to(args.device) 334 | fake_images = G(z) 335 | writer.add_image('c2 cont generated', make_grid(fake_images.cpu(), nrow=10, 
normalize=True, padding=1)) 336 | save_image(fake_images.cpu(), 337 | os.path.join(args.output_dir, 'latent_var_grid_sample_c2.png'), 338 | nrow=10) 339 | 340 | 341 | # -------------------- 342 | # Run 343 | # -------------------- 344 | 345 | if __name__ == '__main__': 346 | args = parser.parse_args() 347 | 348 | if not os.path.isdir(args.output_dir): 349 | os.makedirs(args.output_dir) 350 | 351 | writer = utils.set_writer(args.output_dir, '_train') 352 | 353 | args.device = torch.device('cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda is not None else 'cpu') 354 | 355 | # set seed 356 | torch.manual_seed(11122018) 357 | if args.device.type is 'cuda': torch.cuda.manual_seed(11122018) 358 | 359 | # models 360 | D = Discriminator().to(args.device) 361 | Q = Q(args.cat_dim, args.cont_dim).to(args.device) 362 | G = Generator().to(args.device) 363 | D.apply(initialize_weights) 364 | Q.apply(initialize_weights) 365 | G.apply(initialize_weights) 366 | 367 | # optimizers 368 | g_optimizer = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.999)) 369 | d_optimizer = torch.optim.Adam([{'params': D.parameters()}, 370 | {'params': Q.parameters()}], lr=args.d_lr, betas=(0.5, 0.999)) 371 | 372 | # eval 373 | if args.evaluate_on_grid: 374 | print('Restoring parameters from {}'.format(args.restore_file)) 375 | _ = utils.load_checkpoint(args.restore_file, [D, Q, G], [d_optimizer, g_optimizer]) 376 | evaluate_on_grid(G, writer, args) 377 | # train 378 | else: 379 | dataloader = fetch_dataloader(args) 380 | train(D, Q, G, dataloader, d_optimizer, g_optimizer, writer, args) 381 | evaluate_on_grid(G, writer, args) 382 | 383 | writer.close() 384 | -------------------------------------------------------------------------------- /optim.py: -------------------------------------------------------------------------------- 1 | """ Wrapper of optimizers in torch.optim for computation of exponential moving average of parameters""" 2 | 3 | import torch 4 | 5 | def build_ema_optimizer(optimizer_cls): 6 | class Optimizer(optimizer_cls): 7 | def __init__(self, *args, polyak=0.0, **kwargs): 8 | if not 0.0 <= polyak <= 1.0: 9 | raise ValueError("Invalid polyak decay rate: {}".format(polyak)) 10 | super().__init__(*args, **kwargs) 11 | self.defaults['polyak'] = polyak 12 | self.ema = False 13 | 14 | def step(self, closure=None): 15 | super().step(closure) 16 | 17 | # update exponential moving average after gradient update to parameters 18 | for group in self.param_groups: 19 | for p in group['params']: 20 | state = self.state[p] 21 | 22 | # state initialization 23 | if 'ema' not in state: 24 | state['ema'] = torch.zeros_like(p.data) 25 | 26 | # ema update 27 | state['ema'] -= (1 - self.defaults['polyak']) * (state['ema'] - p.data) 28 | 29 | def use_ema(self, use_ema_for_params=True): 30 | """ substitute exponential moving average values into parameter values """ 31 | if self.ema ^ use_ema_for_params: # logical XOR; swap only when different; 32 | try: 33 | print('Swapping EMA and parameters values. Now using: ' + ('EMA' if use_ema_for_params else 'param values')) 34 | for group in self.param_groups: 35 | for p in group['params']: 36 | data = p.data 37 | state = self.state[p] 38 | p.data = state['ema'] 39 | state['ema'] = data 40 | self.ema = use_ema_for_params 41 | except KeyError: 42 | print('Optimizer not initialized. No EMA values to swap to. 
Keeping parameter values.') 43 | 44 | 45 | def __repr__(self): 46 | s = super().__repr__() 47 | return self.__class__.__mro__[1].__name__ + ' (\npolyak: {}\n'.format(self.defaults['polyak']) + s.partition('\n')[2] 48 | 49 | return Optimizer 50 | 51 | Adam = build_ema_optimizer(torch.optim.Adam) 52 | RMSprop = build_ema_optimizer(torch.optim.RMSprop) 53 | 54 | 55 | if __name__ == '__main__': 56 | import copy 57 | torch.manual_seed(0) 58 | x = torch.randn(2,2) 59 | y = torch.rand(2,2) 60 | polyak = 0.9 61 | _m = torch.nn.Linear(2,2) 62 | for optim in [Adam, RMSprop]: 63 | m = copy.deepcopy(_m) 64 | o = optim(m.parameters(), lr=0.1, polyak=polyak) 65 | print('Testing: ', optim.__name__) 66 | print(o) 67 | print('init loss {:.3f}'.format(torch.mean((m(x) - y)**2).item())) 68 | p = torch.zeros_like(m.weight) 69 | for i in range(5): 70 | loss = torch.mean((m(x) - y)**2) 71 | print('step {}: loss {:.3f}'.format(i, loss.item())) 72 | o.zero_grad() 73 | loss.backward() 74 | o.step() 75 | # manual compute ema 76 | p -= (1 - polyak) * (p - m.weight.data) 77 | print('loss: {:.3f}'.format(torch.mean((m(x) - y)**2).item())) 78 | print('swapping ema values for params.') 79 | o.use_ema(True) 80 | assert torch.allclose(p, m.weight) 81 | print('loss: {:.3f}'.format(torch.mean((m(x) - y)**2).item())) 82 | print('swapping params for ema values.') 83 | o.use_ema(False) 84 | print('loss: {:.3f}'.format(torch.mean((m(x) - y)**2).item())) 85 | print() 86 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Collection of generative methods in pytorch. 6 | 7 | # Implemented models 8 | * [Generating Diverse High-Fidelity Images with VQ-VAE-2](https://arxiv.org/abs/1906.00446) / [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937) 9 | * [Attend, Infer, Repeat: Fast Scene Understanding with Generative Models](https://arxiv.org/abs/1603.08575v3) 10 | * [DRAW: A Recurrent Neural Network For Image Generation](https://arxiv.org/abs/1502.04623.pdf) 11 | * [Semi-Supervised Learning with Deep Generative Models](https://arxiv.org/abs/1406.5298) 12 | * [InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets](https://arxiv.org/abs/1606.03657) 13 | * [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434) (DCGAN) 14 | * [Auto-encoding Variational Bayes](https://arxiv.org/abs/1312.6114) 15 | 16 | The models are implemented for MNIST data; other datasets are a todo. 17 | 18 | ## Dependencies 19 | * python 3.6 20 | * pytorch 0.4.1+ 21 | * numpy 22 | * matplotlib 23 | * tensorboardx 24 | * tqdm 25 | 26 | ###### Some of the models further require 27 | * observations 28 | * imageio 29 | 30 | 31 | ## VQ-VAE2 32 | 33 | Implementation of Generating Diverse High-Fidelity Images with VQ-VAE-2 (https://arxiv.org/abs/1906.00446) based on Vector Quantised VAE per Neural Discrete Representation Learning (https://arxiv.org/abs/1711.00937) with PixelCNN prior on the level 1 discrete latent variables per Conditional Image Generation with PixelCNN Decoders (https://arxiv.org/abs/1606.05328) and PixelSNAIL prior on the level 2 discrete latent variables per PixelSNAIL: An Improved Autoregressive Generative Model (https://arxiv.org/abs/1712.09763). 
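A minimal sketch of the discrete bottleneck shared by both papers -- a nearest-neighbour codebook lookup with a straight-through gradient and a commitment loss -- is shown below for orientation. This is illustrative only and not the `vqvae.py` implementation; the class name, the `beta` weight, and the MSE codebook loss (in place of EMA codebook updates) are assumptions.

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    # hypothetical illustration of the VQ-VAE quantization step, not the repo's module
    def __init__(self, n_embeddings, embedding_dim, beta=0.25):
        super().__init__()
        self.codebook = nn.Embedding(n_embeddings, embedding_dim)
        self.codebook.weight.data.uniform_(-1. / n_embeddings, 1. / n_embeddings)
        self.beta = beta

    def forward(self, z_e):  # z_e: (B, D, H, W) continuous encoder output
        B, D, H, W = z_e.shape
        flat = z_e.permute(0, 2, 3, 1).reshape(-1, D)  # (B*H*W, D)
        # squared L2 distance to every codebook vector, then nearest index
        dist = (flat.pow(2).sum(1, keepdim=True)
                - 2 * flat @ self.codebook.weight.t()
                + self.codebook.weight.pow(2).sum(1))
        idx = dist.argmin(dim=1)  # (B*H*W,) discrete codes the PixelCNN/PixelSNAIL priors are trained on
        z_q = self.codebook(idx).view(B, H, W, D).permute(0, 3, 1, 2)
        # codebook loss + commitment loss (EMA codebook updates would replace the first term)
        loss = F.mse_loss(z_q, z_e.detach()) + self.beta * F.mse_loss(z_e, z_q.detach())
        # straight-through estimator: gradients flow from z_q back to the encoder output z_e
        z_q = z_e + (z_q - z_e).detach()
        return z_q, idx.view(B, H, W), loss
```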
34 | 35 | #### Results 36 | 37 | Model reconstructions on the [CheXpert](https://stanfordmlgroup.github.io/competitions/chexpert/) Chest X-Ray Dataset -- CheXpert is a large public dataset for chest radiograph interpretation, consisting of 224,316 chest radiographs of 65,240 patients (at 320x320 pixels for the small version of the dataset). Reconstructions and samples below are for 128x128 images using a codebook of size 8 (3 bits), which for the single-channel 8-bit grayscale CheXpert images gives compression ratios similar to those listed in the paper for 8-bit RGB images. 38 | 39 | | Original | Bottom reconstruction
(using top encoding) | Top reconstruction
(using zeroed out bottom encoding) | 40 | | --- | --- | --- | 41 | | ![chexpertoriginal](images/vqvae2/128x128_bits3_eval_reconstruction_step_87300_original.png) | ![vqvae2bottomrecon](images/vqvae2/128x128_bits3_eval_reconstruction_step_87300_bottom.png) | ![vqvae2toprecon](images/vqvae2/128x128_bits3_eval_reconstruction_step_87300_top.png) | 42 | 43 | 44 | ##### Model samples from priors 45 | Both top and bottom prior models are pretty heavy to train; the samples below were trained only for 84k and 140k steps for top and bottom priors, respectively, using smaller model sizes than what was reported in the paper. The samples are class conditional along the rows for classes (atelectasis, cardiomegaly, consolidation, edema, pleural effusion, no finding) -- much more to be desired / improved with larger models and higher computational budget. 46 | 47 | Model parameters: 48 | * bottom prior: n_channels 128, n_res_layers 20, n_cond_stack_layers 10, drop_rate 0.1, batch size 16, lr 5e-5 49 | * top prior: n_channels 128, n_res_layers 5, n_out_stack_layers 10, drop_rate 0.1, batch size 128, lr 5e-5; (attention params: layers 4, heads 8, dq 16, dk 16, dv 128) 50 | 51 | ![vqvae2_sampes](images/vqvae2/generation_sample_step_52440_top_b128_c128_outstack10_bottom_b16_c128_nres20_condstack10.png) 52 | 53 | #### Usage 54 | 55 | To train and evaluate/reconstruct from the VAE model with hyperparameters of the paper: 56 | ``` 57 | python vqvae.py --train 58 | --n_embeddings [size of the latent space] 59 | --n_epochs [number of epochs to train] 60 | --ema [flag to use exponential moving average training for the embeddings] 61 | --cuda [cuda device to run on] 62 | 63 | python vqvae.py --evaluate 64 | --restore_dir [path to model directory with config.json and saved checkpoint] 65 | --n_samples [number of examples from the validation set to reconstruct] 66 | --cuda [cuda device to run on] 67 | ``` 68 | 69 | To train the top and bottom priors on the latent codes using 4 GPUs and Pytorch DistributedDataParallels: 70 | * the latent codes are extracted for the full dataset and saved as a pytorch dataset object, which is then loaded into memory for training 71 | * hyperparameters not shown as options below are at the defaults given by the paper (e.g. 
kernel size, attention parameters) 72 | 73 | ``` 74 | python -m torch.distributed.launch --nproc_per_node 4 --use_env \ 75 | vqvae_prior.py --vqvae_dir [path to vae model used for encoding the dataset and decoding samples] 76 | --train 77 | --distributed [flag to use DistributedDataParallels] 78 | --n_epochs 20 79 | --batch_size 128 80 | --lr 0.00005 81 | --which_prior top [which prior to train] 82 | --n_cond_classes 5 [number of classes to condition on] 83 | --n_channels 128 [convolutional channels throughout the architecture] 84 | --n_res_layers 5 [number of residual layers] 85 | --n_out_stack_layers 10 [output convolutional stack (used only by top prior)] 86 | --n_cond_stack_layers 0 [input conditional stack (used only by bottom prior)] 87 | --drop_rate 0.1 [dropout rate used in the residual layers] 88 | 89 | python -m torch.distributed.launch --nproc_per_node 4 --use_env \ 90 | vqvae_prior.py --vqvae_dir [path_to_vae_directory] 91 | --train 92 | --distributed 93 | --n_epochs 20 94 | --batch_size 16 95 | --lr 0.00005 96 | --which_prior bottom 97 | --n_cond_classes 5 98 | --n_channels 128 99 | --n_res_layers 20 100 | --n_out_stack_layers 0 101 | --n_cond_stack_layers 10 102 | --drop_rate 0.1 103 | ``` 104 | 105 | To generate data from a trained model and priors: 106 | ``` 107 | 108 | python -m torch.distributed.launch --nproc_per_node 4 --use_env \ 109 | vqvae_prior.py --vqvae_dir [path_to_vae_directory] 110 | --restore_dir [path_to_bottom_prior_directory, path_to_top_prior_directory] 111 | --generate 112 | --distributed 113 | --n_samples [number of samples to generate (per gpu)] 114 | ``` 115 | 116 | Useful resources 117 | * Official tensorflow implementation of the VQ layer and VAE model in Sonnet (https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/nets/vqvae.py and https://github.com/deepmind/sonnet/blob/master/sonnet/examples/vqvae_example.ipynb) 118 | 119 | ## AIR 120 | 121 | Reimplementation of the Attend, Infer, Repeat (AIR) architecture. 122 | https://arxiv.org/abs/1603.08575v3 123 | 124 | #### Results 125 | Model reconstructed data (top row is sample of original images, bottom row is AIR reconstruction; red attention window corresponds to first time step, green to second): 126 | 127 | ![air_recon](images/air/image_recons_270.png) 128 | 129 | EBLO and object count accuracy after 300 epochs of training using RMSprop with the default hyperparameters discussed in the paper and linear annealing of the z_pres probability. Variance coming from the discrete z_pres is alleviated using NVIL ([Mnih & Gregor](https://arxiv.org/abs/1402.0030)) but can still be seen in the count accuracy in the first 50k training iterations. 
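For reference, the variance reduction centers the score-function (REINFORCE) term for the discrete z_pres with a learned baseline before it multiplies log q(z_pres|x). The snippet below is only a schematic of that estimator with placeholder tensors standing in for quantities produced by the model and the baseline LSTM; it is not the exact code in air.py.

```
import torch
import torch.distributions as D

torch.manual_seed(0)
z_pres_prob = torch.rand(64, 1)                    # q(z_pres | x) success probability (placeholder)
elbo = torch.randn(64, 1)                          # per-example bound excluding the score-function term
baseline = torch.randn(64, 1, requires_grad=True)  # baseline prediction for the same inputs

q_z_pres = D.Bernoulli(probs=z_pres_prob.clamp(1e-6, 1 - 1e-6))
z_pres = q_z_pres.sample()                         # discrete, non-reparameterizable sample
log_q = q_z_pres.log_prob(z_pres)

# NVIL: center the learning signal with the baseline before weighting log q(z_pres|x)
learning_signal = (elbo - baseline).detach()
model_loss = -(elbo + learning_signal * log_q).mean()
baseline_loss = (elbo.detach() - baseline).pow(2).mean()   # baseline trained with its own optimizer
```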
130 | 131 | 132 | | Variational bound | Count accuracy | 133 | | --- | --- | 134 | | ![air_elbo](images/air/air_elbo.png) | ![air_count](images/air/air_count.png) 135 | 136 | #### Usage 137 | To train a model with hyperparameters of the paper: 138 | ``` 139 | python air.py -- train \ 140 | -- cuda=[# of cuda device to run on] 141 | ``` 142 | 143 | To evaluate model ELBO: 144 | ``` 145 | python air.py -- evaluate \ 146 | -- restore_file=[path to .pt checkpoint] 147 | -- cuda=[# of cuda device to run on] 148 | ``` 149 | 150 | To generate data from a trained model: 151 | ``` 152 | python air.py -- generate \ 153 | -- restore_file=[path to .pt checkpoint] 154 | ``` 155 | 156 | Useful resources 157 | * tensorflow implementation https://github.com/akosiorek/attend_infer_repeat and by the same author Sequential AIR (a state-space model on top of AIR) (https://github.com/akosiorek/sqair/) 158 | * pyro implmentation and walk through http://pyro.ai/examples/air.html 159 | 160 | ## DRAW 161 | Reimplementation of the Deep Recurrent Attentive Writer (DRAW) network architecture. https://arxiv.org/abs/1502.04623 162 | 163 | #### Results 164 | Model generated data: 165 | 166 | ![draw](images/draw/generated_32_time_steps.gif) 167 | 168 | Results were achieved training at the parameters presented in the paper (except at 32 time steps) for 50 epochs. 169 | 170 | Visualizing the specific filterbank functions for read and write attention (cf Figure 3 & 4 in paper): 171 | 172 | | Extracted patches with grid filters | Applying transposed filters to reconstruct extracted image patch | 173 | | --- | --- | 174 | | ![drawread](images/draw/draw_fig_3.png) | ![drawwrite](images/draw/draw_fig_4.png) 175 | 176 | #### Usage 177 | To train a model with read and write attention (window sizes 2 and 5): 178 | ``` 179 | python draw.py -- train \ 180 | -- use_read_attn \ 181 | -- read_size=2 \ 182 | -- use_write_attn \ 183 | -- write_size=5 \ 184 | -- [add'l options: e.g. n_epoch, z_size, lstm_size] \ 185 | -- cuda=[# of cuda device to run on] 186 | ``` 187 | 188 | To evaluate model ELBO: 189 | ``` 190 | python draw.py -- evaluate \ 191 | -- restore_file=[path to .pt checkpoint] 192 | -- [model parameters: read_size, write_size, lstm_size, z_size] 193 | ``` 194 | 195 | To generate data from a trained model: 196 | ``` 197 | python draw.py -- generate \ 198 | -- restore_file=[path to .pt checkpoint] 199 | -- [model parameters: read_size, write_size, lstm_size, z_size] 200 | ``` 201 | 202 | #### Useful resources 203 | * https://github.com/jbornschein/draw 204 | * https://github.com/ericjang/draw 205 | 206 | 207 | ## Semi-supervised Learning with Deep Generative Models 208 | https://arxiv.org/abs/1406.5298 209 | 210 | Reimplementation of M2 model on MNIST. 211 | 212 | #### Results 213 | Visualization of handwriting styles learned by the model (cf Figure 1 in paper). Column 1 shows an image column from the test data followed by model generated data. Columns 2 and 3 show model generated styles for a fixed label and a linear variation of each component of a 2-d latent variable. 
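Mechanically, the analogies are produced by inferring the style latent z from a test image with q(z|x,y) and then decoding that same z with every class label. A condensed sketch of this step, assuming a trained SSVAE `model` and a batch of flattened test images `x` with one-hot labels `y` already on the right device (cf. `generate` in ssvae.py):

```
import torch

z = model.encode_z(x, y).loc                       # style latent: posterior mean, not a sample
rows = []
for k in range(10):                                # hold z fixed and sweep the digit label
    y_k = torch.zeros(x.size(0), 10, device=x.device)
    y_k[:, k] = 1
    rows.append(model.decode(y_k, z).probs)        # Bernoulli means form the generated digits
```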
214 | 215 | | MNIST analogies | Varying 2-d latent z (z1) on number 2 | Varying 2-d latent z (z2) on number 4 | 216 | | --- | --- | --- | 217 | | ![analogies](images/ssvae/analogies_sample.png) | ![c1](images/ssvae/latent_var_grid_sample_c1_y2.png) | ![c2](images/ssvae/latent_var_grid_sample_c2_y4.png) 218 | 219 | #### Usage 220 | To train a model: 221 | ``` 222 | python ssvae.py -- train \ 223 | -- n_labeled=[100 | 300 | 1000 | 3000] \ 224 | -- [add'l options: e.g. n_epochs, z_dim, hidden_size] \ 225 | -- cuda=[# of cuda device to run on] 226 | ``` 227 | 228 | To evaluate model accuracy: 229 | ``` 230 | python ssvae.py -- evaluate \ 231 | -- restore_file=[path to .pt checkpoint] 232 | ``` 233 | 234 | To generate data from a trained model: 235 | ``` 236 | python ssvae.py -- generate \ 237 | -- restore_file=[path to .pt checkpoint] 238 | ``` 239 | 240 | #### Useful resource 241 | * https://github.com/dpkingma/nips14-ssl 242 | 243 | 244 | ## InfoGAN 245 | 246 | Reimplementation of InfoGan. https://arxiv.org/abs/1606.03657 247 | This follows closely the Tensorflow implementation by [Depth First Learning](http://www.depthfirstlearning.com/2018/InfoGAN) using tf.distribution, which make the model quite intuitive. 248 | 249 | #### Results 250 | 251 | Visualizing model-generated data varying each component of a 2-d continuous latent variable: 252 | 253 | | Varying 2-d latent z (z1)| Varying 2-d latent z (z2) | 254 | | --- | --- | 255 | | ![c1](images/infogan/latent_var_grid_sample_c1.png) | ![c2](images/infogan/latent_var_grid_sample_c2.png) 256 | 257 | #### Usage 258 | To train a model with read and write attention (window sizes 2 and 5): 259 | ``` 260 | python infogan.py -- n_epochs=[# epochs] \ 261 | -- cuda=[# of cuda device to run on] 262 | -- [add'l options: e.g. noise_dim, cat_dim, cont_dim] \ 263 | ``` 264 | 265 | To evaluate model and visualize latents: 266 | ``` 267 | python infogan.py -- evaluate_on_grid \ 268 | -- restore_file=[path to .pt checkpoint] 269 | ``` 270 | 271 | #### Useful resources 272 | * http://www.depthfirstlearning.com/2018/InfoGAN 273 | 274 | 275 | ## DCGAN 276 | 277 | Reimplementation of DCGAN. https://arxiv.org/abs/1511.06434 278 | 279 | #### Results 280 | Model generated data: 281 | 282 | ![dcgan](images/dcgan/latent_var_grid_sample_c1.png) 283 | 284 | #### Usage 285 | To train a model with read and write attention (window sizes 2 and 5): 286 | ``` 287 | python infogan.py -- n_epochs=[# epochs] \ 288 | -- cuda=[# of cuda device to run on] 289 | -- [add'l options: e.g. 
noise_dim, cat_dim, cont_dim] \ 290 | ``` 291 | 292 | To evaluate model and visualize latents: 293 | ``` 294 | python infogan.py -- evaluate_on_grid \ 295 | -- restore_file=[path to .pt checkpoint] 296 | ``` 297 | 298 | #### Useful resources 299 | * pytorch code examples https://github.com/pytorch/examples/ 300 | 301 | 302 | ## Auto-encoding Variational Bayes 303 | Reimplementation of https://arxiv.org/abs/1312.6114 304 | 305 | #### Results 306 | 307 | Visualizing reconstruction (after training for 25 epochs): 308 | 309 | | Real samples (left) and model reconstruction (right) | 310 | | --- | 311 | | ![vae_recon](images/basic_vae/reconstruction_at_epoch_24.png) | 312 | 313 | Visualizing model-generated data and TSNE embedding in latent space: 314 | 315 | | Model-generated data using Normal(0,1) prior | TSNE embedding in latent space | 316 | | --- | --- | 317 | | ![vae_sample](images/basic_vae/sample_at_epoch_24.png) | ![vae_tsne](images/basic_vae/tsne_embedding.png) | 318 | 319 | 320 | #### Usage 321 | To train and evaluate a model on MNIST: 322 | ``` 323 | python basic_vae.py -- n_epochs=[# epochs] mnist 324 | ``` 325 | 326 | #### Useful resources 327 | * Implementation in Pyro and quick tutorial http://pyro.ai/examples/vae.html 328 | -------------------------------------------------------------------------------- /ssvae.py: -------------------------------------------------------------------------------- 1 | """ 2 | Semi-supervised Learning with Deep Generative Models 3 | https://arxiv.org/pdf/1406.5298.pdf 4 | """ 5 | 6 | import os 7 | import argparse 8 | from tqdm import tqdm 9 | import pprint 10 | import copy 11 | 12 | import numpy as np 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.distributions as D 18 | import torchvision.transforms as T 19 | from torch.utils.data import DataLoader 20 | from torchvision.datasets import MNIST 21 | from torchvision.utils import save_image, make_grid 22 | 23 | parser = argparse.ArgumentParser() 24 | 25 | # actions 26 | parser.add_argument('--train', action='store_true', help='Train a new or restored model.') 27 | parser.add_argument('--evaluate', action='store_true', help='Evaluate a model.') 28 | parser.add_argument('--generate', action='store_true', help='Generate samples from a model.') 29 | parser.add_argument('--vis_styles', action='store_true', help='Visualize styles manifold.') 30 | parser.add_argument('--cuda', type=int, help='Which cuda device to use') 31 | parser.add_argument('--seed', type=int, default=1, help='Random seed.') 32 | 33 | # file paths 34 | parser.add_argument('--restore_file', type=str, help='Path to model to restore.') 35 | parser.add_argument('--data_dir', default='./data/', help='Location of dataset.') 36 | parser.add_argument('--output_dir', default='./results/{}'.format(os.path.splitext(__file__)[0])) 37 | parser.add_argument('--results_file', default='results.txt', help='Filename where to store settings and test results.') 38 | 39 | # model parameters 40 | parser.add_argument('--image_dims', type=tuple, default=(1,28,28), help='Dimensions of a single datapoint (e.g. 
(1,28,28) for MNIST).') 41 | parser.add_argument('--z_dim', type=int, default=50, help='Size of the latent representation.') 42 | parser.add_argument('--y_dim', type=int, default=10, help='Size of the labels / output.') 43 | parser.add_argument('--hidden_dim', type=int, default=500, help='Size of the hidden layer.') 44 | 45 | # training params 46 | parser.add_argument('--n_labeled', type=int, default=3000, help='Number of labeled training examples in the dataset') 47 | parser.add_argument('--batch_size', type=int, default=100) 48 | parser.add_argument('--n_epochs', type=int, default=1, help='Number of epochs to train.') 49 | parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate') 50 | parser.add_argument('--alpha', type=float, default=0.1, help='Classifier loss multiplier controlling generative vs. discriminative learning.') 51 | 52 | 53 | # -------------------- 54 | # Data 55 | # -------------------- 56 | 57 | # create semi-supervised datasets of labeled and unlabeled data with equal number of labels from each class 58 | def create_semisupervised_datasets(dataset, n_labeled): 59 | # note this is only relevant for training the model 60 | assert dataset.train == True, 'Dataset must be the training set; assure dataset.train = True.' 61 | 62 | # compile new x and y and replace the dataset.train_data and train_labels with the 63 | x = dataset.train_data 64 | y = dataset.train_labels 65 | n_x = x.shape[0] 66 | n_classes = len(torch.unique(y)) 67 | 68 | assert n_labeled % n_classes == 0, 'n_labeld not divisible by n_classes; cannot assure class balance.' 69 | n_labeled_per_class = n_labeled // n_classes 70 | 71 | x_labeled = [0] * n_classes 72 | x_unlabeled = [0] * n_classes 73 | y_labeled = [0] * n_classes 74 | y_unlabeled = [0] * n_classes 75 | 76 | for i in range(n_classes): 77 | idxs = (y == i).nonzero().data.numpy() 78 | np.random.shuffle(idxs) 79 | 80 | x_labeled[i] = x[idxs][:n_labeled_per_class] 81 | y_labeled[i] = y[idxs][:n_labeled_per_class] 82 | x_unlabeled[i] = x[idxs][n_labeled_per_class:] 83 | y_unlabeled[i] = y[idxs][n_labeled_per_class:] 84 | 85 | # construct new labeled and unlabeled datasets 86 | labeled_dataset = copy.deepcopy(dataset) 87 | labeled_dataset.train_data = torch.cat(x_labeled, dim=0).squeeze() 88 | labeled_dataset.train_labels = torch.cat(y_labeled, dim=0) 89 | 90 | unlabeled_dataset = copy.deepcopy(dataset) 91 | unlabeled_dataset.train_data = torch.cat(x_unlabeled, dim=0).squeeze() 92 | unlabeled_dataset.train_labels = torch.cat(y_unlabeled, dim=0) 93 | 94 | del dataset 95 | 96 | return labeled_dataset, unlabeled_dataset 97 | 98 | 99 | def fetch_dataloaders(args): 100 | assert args.n_labeled != None, 'Must provide n_labeled number to split dataset.' 
101 | 102 | transforms = T.Compose([T.ToTensor()]) 103 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.device.type is 'cuda' else {} 104 | 105 | def get_dataset(train): 106 | return MNIST(root=args.data_dir, train=train, transform=transforms) 107 | 108 | def get_dl(dataset): 109 | return DataLoader(dataset, batch_size=args.batch_size, shuffle=dataset.train, drop_last=True, **kwargs) 110 | 111 | test_dataset = get_dataset(train=False) 112 | train_dataset = get_dataset(train=True) 113 | labeled_dataset, unlabeled_dataset = create_semisupervised_datasets(train_dataset, args.n_labeled) 114 | 115 | return get_dl(labeled_dataset), get_dl(unlabeled_dataset), get_dl(test_dataset) 116 | 117 | 118 | def one_hot(x, label_size): 119 | out = torch.zeros(len(x), label_size).to(x.device) 120 | out[torch.arange(len(x)), x.squeeze()] = 1 121 | return out 122 | 123 | # -------------------- 124 | # Model 125 | # -------------------- 126 | 127 | class SSVAE(nn.Module): 128 | """ 129 | Data model (SSL paper eq 2): 130 | p(y) = Cat(y|pi) 131 | p(z) = Normal(z|0,1) 132 | p(x|y,z) = f(x; z,y,theta) 133 | 134 | Recognition model / approximate posterior q_phi (SSL paper eq 4): 135 | q(y|x) = Cat(y|pi_phi(x)) 136 | q(z|x,y) = Normal(z|mu_phi(x,y), diag(sigma2_phi(x))) 137 | 138 | 139 | """ 140 | def __init__(self, args): 141 | super().__init__() 142 | C, H, W = args.image_dims 143 | x_dim = C * H * W 144 | 145 | # -------------------- 146 | # p model -- SSL paper generative semi supervised model M2 147 | # -------------------- 148 | 149 | self.p_y = D.OneHotCategorical(probs=1 / args.y_dim * torch.ones(1,args.y_dim, device=args.device)) 150 | self.p_z = D.Normal(torch.tensor(0., device=args.device), torch.tensor(1., device=args.device)) 151 | 152 | # parametrized data likelihood p(x|y,z) 153 | self.decoder = nn.Sequential(nn.Linear(args.z_dim + args.y_dim, args.hidden_dim), 154 | nn.Softplus(), 155 | nn.Linear(args.hidden_dim, args.hidden_dim), 156 | nn.Softplus(), 157 | nn.Linear(args.hidden_dim, x_dim)) 158 | 159 | # -------------------- 160 | # q model -- SSL paper eq 4 161 | # -------------------- 162 | 163 | # parametrized q(y|x) = Cat(y|pi_phi(x)) -- outputs parametrization of categorical distribution 164 | self.encoder_y = nn.Sequential(nn.Linear(x_dim, args.hidden_dim), 165 | nn.Softplus(), 166 | nn.Linear(args.hidden_dim, args.hidden_dim), 167 | nn.Softplus(), 168 | nn.Linear(args.hidden_dim, args.y_dim)) 169 | 170 | # parametrized q(z|x,y) = Normal(z|mu_phi(x,y), diag(sigma2_phi(x))) -- output parametrizations for mean and diagonal variance of a Normal distribution 171 | self.encoder_z = nn.Sequential(nn.Linear(x_dim + args.y_dim, args.hidden_dim), 172 | nn.Softplus(), 173 | nn.Linear(args.hidden_dim, args.hidden_dim), 174 | nn.Softplus(), 175 | nn.Linear(args.hidden_dim, 2 * args.z_dim)) 176 | 177 | 178 | # initialize weights to N(0, 0.001) and biases to 0 (cf SSL section 4.4) 179 | for p in self.parameters(): 180 | p.data.normal_(0, 0.001) 181 | if p.ndimension() == 1: p.data.fill_(0.) 
182 | 183 | # q(z|x,y) = Normal(z|mu_phi(x,y), diag(sigma2_phi(x))) -- SSL paper eq 4 184 | def encode_z(self, x, y): 185 | xy = torch.cat([x, y], dim=1) 186 | mu, logsigma = self.encoder_z(xy).chunk(2, dim=-1) 187 | return D.Normal(mu, logsigma.exp()) 188 | 189 | # q(y|x) = Categorical(y|pi_phi(x)) -- SSL paper eq 4 190 | def encode_y(self, x): 191 | return D.OneHotCategorical(logits=self.encoder_y(x)) 192 | 193 | # p(x|y,z) = Bernoulli 194 | def decode(self, y, z): 195 | yz = torch.cat([y,z], dim=1) 196 | return D.Bernoulli(logits=self.decoder(yz)) 197 | 198 | # classification model q(y|x) using the trained q distribution 199 | def forward(self, x): 200 | y_probs = self.encode_y(x).probs 201 | return y_probs.max(dim=1)[1] # return pred labels = argmax 202 | 203 | 204 | def loss_components_fn(x, y, z, p_y, p_z, p_x_yz, q_z_xy): 205 | # SSL paper eq 6 for an given y (observed or enumerated from q_y) 206 | return - p_x_yz.log_prob(x).sum(1) \ 207 | - p_y.log_prob(y) \ 208 | - p_z.log_prob(z).sum(1) \ 209 | + q_z_xy.log_prob(z).sum(1) 210 | 211 | 212 | # -------------------- 213 | # Train and eval 214 | # -------------------- 215 | 216 | def train_epoch(model, labeled_dataloader, unlabeled_dataloader, loss_components_fn, optimizer, epoch, args): 217 | model.train() 218 | 219 | n_batches = len(labeled_dataloader) + len(unlabeled_dataloader) 220 | n_unlabeled_per_labeled = len(unlabeled_dataloader) // len(labeled_dataloader) + 1 221 | 222 | labeled_dataloader = iter(labeled_dataloader) 223 | unlabeled_dataloader = iter(unlabeled_dataloader) 224 | 225 | with tqdm(total=n_batches, desc='epoch {} of {}'.format(epoch+1, args.n_epochs)) as pbar: 226 | for i in range(n_batches): 227 | is_supervised = i % n_unlabeled_per_labeled == 0 228 | 229 | # get batch from respective dataloader 230 | if is_supervised: 231 | x, y = next(labeled_dataloader) 232 | y = one_hot(y, args.y_dim).to(args.device) 233 | else: 234 | x, y = next(unlabeled_dataloader) 235 | y = None 236 | x = x.to(args.device).view(x.shape[0], -1) 237 | 238 | # compute loss -- SSL paper eq 6, 7, 9 239 | q_y = model.encode_y(x) 240 | # labeled data loss -- SSL paper eq 6 and eq 9 241 | if y is not None: 242 | q_z_xy = model.encode_z(x, y) 243 | z = q_z_xy.rsample() 244 | p_x_yz = model.decode(y, z) 245 | loss = loss_components_fn(x, y, z, model.p_y, model.p_z, p_x_yz, q_z_xy) 246 | loss -= args.alpha * args.n_labeled * q_y.log_prob(y) # SSL eq 9 247 | # unlabeled data loss -- SSL paper eq 7 248 | else: 249 | # marginalize y according to q_y 250 | loss = - q_y.entropy() 251 | for y in q_y.enumerate_support(): 252 | q_z_xy = model.encode_z(x, y) 253 | z = q_z_xy.rsample() 254 | p_x_yz = model.decode(y, z) 255 | L_xy = loss_components_fn(x, y, z, model.p_y, model.p_z, p_x_yz, q_z_xy) 256 | loss += q_y.log_prob(y).exp() * L_xy 257 | loss = loss.mean(0) 258 | 259 | optimizer.zero_grad() 260 | loss.backward() 261 | optimizer.step() 262 | 263 | # update trackers 264 | pbar.set_postfix(loss='{:.3f}'.format(loss.item())) 265 | pbar.update() 266 | 267 | 268 | @torch.no_grad() 269 | def evaluate(model, dataloader, epoch, args): 270 | model.eval() 271 | 272 | accurate_preds = 0 273 | 274 | with tqdm(total=len(dataloader), desc='eval') as pbar: 275 | for i, (x, y) in enumerate(dataloader): 276 | x = x.to(args.device).view(x.shape[0], -1) 277 | y = y.to(args.device) 278 | preds = model(x) 279 | 280 | accurate_preds += (preds == y).sum().item() 281 | 282 | pbar.set_postfix(accuracy='{:.3f}'.format(accurate_preds / ((i+1) * args.batch_size))) 283 | 
pbar.update() 284 | 285 | output = (epoch != None)*'Epoch {} -- '.format(epoch) + 'Test set accuracy: {:.3f}'.format(accurate_preds / (args.batch_size * len(dataloader))) 286 | print(output) 287 | print(output, file=open(args.results_file, 'a')) 288 | 289 | 290 | def train_and_evaluate(model, labeled_dataloader, unlabeled_dataloader, test_dataloader, loss_components_fn, optimizer, args): 291 | for epoch in range(args.n_epochs): 292 | train_epoch(model, labeled_dataloader, unlabeled_dataloader, loss_components_fn, optimizer, epoch, args) 293 | evaluate(model, test_dataloader, epoch, args) 294 | 295 | # save weights 296 | torch.save(model.state_dict(), os.path.join(args.output_dir, 'ssvae_model_state_hdim{}_zdim{}.pt'.format( 297 | args.hidden_dim, args.z_dim))) 298 | 299 | # show samples -- SSL paper Figure 1-b 300 | generate(model, test_dataloader.dataset, args, epoch) 301 | 302 | 303 | # -------------------- 304 | # Visualize 305 | # -------------------- 306 | 307 | @torch.no_grad() 308 | def generate(model, dataset, args, epoch=None, n_samples=10): 309 | n_samples_per_label = 10 310 | 311 | # some interesting samples per paper implementation 312 | idxs = [7910, 8150, 3623, 2645, 4066, 9660, 5083, 948, 2595, 2] 313 | 314 | x = torch.stack([dataset[i][0] for i in idxs], dim=0).to(args.device) 315 | y = torch.stack([dataset[i][1] for i in idxs], dim=0).to(args.device) 316 | y = one_hot(y, args.y_dim) 317 | 318 | q_z_xy = model.encode_z(x.view(n_samples_per_label, -1), y) 319 | z = q_z_xy.loc 320 | z = z.repeat(args.y_dim, 1, 1).transpose(0, 1).contiguous().view(-1, args.z_dim) 321 | 322 | # hold z constant and vary y: 323 | y = torch.eye(args.y_dim).repeat(n_samples_per_label, 1).to(args.device) 324 | generated_x = model.decode(y, z).probs.view(n_samples_per_label, args.y_dim, *args.image_dims) 325 | generated_x = generated_x.contiguous().view(-1, *args.image_dims) # out (n_samples * n_label, C, H, W) 326 | 327 | x = make_grid(x.cpu(), nrow=1) 328 | spacer = torch.ones(x.shape[0], x.shape[1], 5) 329 | generated_x = make_grid(generated_x.cpu(), nrow=args.y_dim) 330 | image = torch.cat([x, spacer, generated_x], dim=-1) 331 | save_image(image, 332 | os.path.join(args.output_dir, 'analogies_sample' + (epoch != None)*'_at_epoch_{}'.format(epoch) + '.png'), 333 | nrow=args.y_dim) 334 | 335 | 336 | @torch.no_grad() 337 | def vis_styles(model, args): 338 | assert args.z_dim == 2, 'Style viualization requires z_dim=2' 339 | 340 | for y in range(2,5): 341 | y = one_hot(torch.tensor(y).unsqueeze(-1), args.y_dim).expand(100, args.y_dim).to(args.device) 342 | 343 | # order the first dim of the z latent 344 | c = torch.linspace(-5, 5, 10).view(-1,1).repeat(1,10).reshape(-1,1) 345 | z = torch.cat([c, torch.zeros_like(c)], dim=1).reshape(100, 2).to(args.device) 346 | 347 | # combine into z and pass through decoder 348 | x = model.decode(y, z).probs.view(y.shape[0], *args.image_dims) 349 | save_image(x.cpu(), 350 | os.path.join(args.output_dir, 'latent_var_grid_sample_c1_y{}.png'.format(y[0].nonzero().item())), 351 | nrow=10) 352 | 353 | # order second dim of latent and pass through decoder 354 | z = z.flip(1) 355 | x = model.decode(y, z).probs.view(y.shape[0], *args.image_dims) 356 | save_image(x.cpu(), 357 | os.path.join(args.output_dir, 'latent_var_grid_sample_c2_y{}.png'.format(y[0].nonzero().item())), 358 | nrow=10) 359 | 360 | 361 | # -------------------- 362 | # Main 363 | # -------------------- 364 | 365 | if __name__ == '__main__': 366 | args = parser.parse_args() 367 | 368 | if not 
os.path.isdir(args.output_dir): 369 | os.makedirs(args.output_dir) 370 | 371 | args.device = torch.device('cuda:{}'.format(args.cuda) if torch.cuda.is_available() and args.cuda != None else 'cpu') 372 | torch.manual_seed(args.seed) 373 | if args.device.type == 'cuda': torch.cuda.manual_seed(args.seed) 374 | 375 | # dataloaders 376 | labeled_dataloader, unlabeled_dataloader, test_dataloader = fetch_dataloaders(args) 377 | 378 | # model 379 | model = SSVAE(args).to(args.device) 380 | 381 | # optimizer 382 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 383 | 384 | if args.restore_file: 385 | # load model and optimizer states 386 | state = torch.load(args.restore_file, map_location=args.device) 387 | model.load_state_dict(state) 388 | # set up paths 389 | args.output_dir = os.path.dirname(args.restore_file) 390 | args.results_file = os.path.join(args.output_dir, args.results_file) 391 | 392 | print('Loaded settings and model:') 393 | print(pprint.pformat(args.__dict__)) 394 | print(model) 395 | print(pprint.pformat(args.__dict__), file=open(args.results_file, 'a')) 396 | print(model, file=open(args.results_file, 'a')) 397 | 398 | if args.train: 399 | train_and_evaluate(model, labeled_dataloader, unlabeled_dataloader, test_dataloader, loss_components_fn, optimizer, args) 400 | 401 | if args.evaluate: 402 | evaluate(model, test_dataloader, None, args) 403 | 404 | if args.generate: 405 | generate(model, test_dataloader.dataset, args) 406 | 407 | if args.vis_styles: 408 | vis_styles(model, args) 409 | 410 | 411 | 412 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import json 5 | from datetime import datetime 6 | import torch 7 | 8 | from tensorboardX import SummaryWriter 9 | 10 | from subprocess import check_call 11 | 12 | 13 | def set_writer(log_path, comment='', restore=False): 14 | """ setup a tensorboardx summarywriter """ 15 | current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 16 | if restore: 17 | log_path = os.path.dirname(log_path) 18 | else: 19 | log_path = os.path.join(log_path, current_time + comment) 20 | writer = SummaryWriter(log_dir=log_path) 21 | return writer 22 | 23 | 24 | def save_checkpoint(state, checkpoint, is_best=None, quiet=False): 25 | """ saves model and training params at checkpoint + 'last.pt'; if is_best also saves checkpoint + 'best.pt' 26 | 27 | args 28 | state -- dict; with keys model_state_dict, optimizer_state_dict, epoch, scheduler_state_dict, etc 29 | is_best -- bool; true if best model seen so far 30 | checkpoint -- str; folder where params are to be saved 31 | """ 32 | 33 | filepath = os.path.join(checkpoint, 'state_checkpoint.pt') 34 | if not os.path.exists(checkpoint): 35 | if not quiet: 36 | print('Checkpoint directory does not exist Making directory {}'.format(checkpoint)) 37 | os.mkdir(checkpoint) 38 | 39 | torch.save(state, filepath) 40 | 41 | # if is_best: 42 | # shutil.copyfile(filepath, os.path.join(checkpoint, 'best_state_checkpoint.pt')) 43 | 44 | if not quiet: 45 | print('Checkpoint saved.') 46 | 47 | 48 | def load_checkpoint(checkpoint, models, optimizers=None, scheduler=None, best_metric=None, map_location='cpu'): 49 | """ loads model state_dict from filepath; if optimizer and lr_scheduler provided also loads them 50 | 51 | args 52 | checkpoint -- string of filename 53 | model -- torch nn.Module model 54 | optimizer -- torch.optim instance to 
resume from checkpoint 55 | lr_scheduler -- torch.optim.lr_scheduler instance to resume from checkpoint 56 | """ 57 | 58 | if not os.path.exists(checkpoint): 59 | raise('File does not exist {}'.format(checkpoint)) 60 | 61 | checkpoint = torch.load(checkpoint, map_location=map_location) 62 | models = [m.load_state_dict(checkpoint['model_state_dicts'][i]) for i, m in enumerate(models)] 63 | 64 | if optimizers: 65 | try: 66 | optimizers = [o.load_state_dict(checkpoint['optimizer_state_dicts'][i]) for i, o in enumerate(optimizers)] 67 | except KeyError: 68 | print('No optimizer state dict in checkpoint file') 69 | 70 | if best_metric: 71 | try: 72 | best_metric = checkpoint['best_val_acc'] 73 | except KeyError: 74 | print('No best validation accuracy recorded in checkpoint file.') 75 | 76 | if scheduler: 77 | try: 78 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 79 | except KeyError: 80 | print('No lr scheduler state dict in checkpoint file') 81 | 82 | return checkpoint['epoch'] 83 | 84 | -------------------------------------------------------------------------------- /vqvae.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of VQ-VAE-2: 3 | -- van den Oord, 'Generating Diverse High-Fidelity Images with VQ-VAE-2' -- https://arxiv.org/abs/1906.00446 4 | -- van den Oord, 'Neural Discrete Representation Learning' -- https://arxiv.org/abs/1711.00937 5 | -- Roy, Theory and Experiments on Vector Quantized Autoencoders' -- https://arxiv.org/pdf/1805.11063.pdf 6 | 7 | Reference implementation of the vector quantized VAE: 8 | https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/nets/vqvae.py 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.utils.data import DataLoader 15 | import torchvision.transforms as T 16 | from torchvision.datasets import CIFAR10 17 | from torchvision.utils import save_image, make_grid 18 | 19 | from tensorboardX import SummaryWriter 20 | from tqdm import tqdm 21 | 22 | import os 23 | import argparse 24 | import time 25 | import json 26 | import pprint 27 | from functools import partial 28 | 29 | from datasets.chexpert import ChexpertDataset 30 | 31 | parser = argparse.ArgumentParser() 32 | # action 33 | parser.add_argument('--train', action='store_true', help='Train model.') 34 | parser.add_argument('--evaluate', action='store_true', help='Evaluate model.') 35 | parser.add_argument('--generate', action='store_true', help='Generate samples from a model.') 36 | parser.add_argument('--seed', type=int, default=0, help='Random seed to use.') 37 | parser.add_argument('--cuda', type=int, help='Which cuda device to use.') 38 | parser.add_argument('--mini_data', action='store_true', help='Truncate dataset to a single minibatch.') 39 | # model 40 | parser.add_argument('--n_embeddings', default=256, type=int, help='Size of discrete latent space (K-way categorical).') 41 | parser.add_argument('--embedding_dim', default=64, type=int, help='Dimensionality of each latent embedding vector.') 42 | parser.add_argument('--n_channels', default=128, type=int, help='Number of channels in the encoder and decoder.') 43 | parser.add_argument('--n_res_channels', default=64, type=int, help='Number of channels in the residual layers.') 44 | parser.add_argument('--n_res_layers', default=2, type=int, help='Number of residual layers inside the residual block.') 45 | parser.add_argument('--n_cond_classes', type=int, help='(NOT USED here; used in training 
prior but requires flag for dataloader) Number of classes if conditional model.') 46 | # data params 47 | parser.add_argument('--dataset', choices=['cifar10', 'chexpert'], default='chexpert') 48 | parser.add_argument('--data_dir', default='~/data/', help='Location of datasets.') 49 | parser.add_argument('--output_dir', type=str, help='Location where weights, logs, and sample should be saved.') 50 | parser.add_argument('--restore_dir', type=str, help='Path to model config and checkpoint to restore.') 51 | # training param 52 | parser.add_argument('--batch_size', type=int, default=128, help='Training batch size.') 53 | parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate.') 54 | parser.add_argument('--lr_decay', type=float, default=0.9999965, help='Learning rate decay (assume end lr = 1e-6 @ 2m iters for init lr 0.001).') 55 | parser.add_argument('--commitment_cost', type=float, default=0.25, help='Commitment cost term in loss function.') 56 | parser.add_argument('--ema', action='store_true', help='Use exponential moving average training for the codebook.') 57 | parser.add_argument('--ema_decay', type=float, default=0.99, help='EMA decay rate.') 58 | parser.add_argument('--ema_eps', type=float, default=1e-5, help='EMA epsilon.') 59 | parser.add_argument('--n_epochs', type=int, default=1, help='Number of epochs to train.') 60 | parser.add_argument('--step', type=int, default=0, help='Current step of training (number of minibatches processed).') 61 | parser.add_argument('--start_epoch', default=0, help='Starting epoch (for logging; to be overwritten when restoring file.') 62 | parser.add_argument('--log_interval', type=int, default=50, help='How often to show loss statistics and save samples.') 63 | parser.add_argument('--eval_interval', type=int, default=10, help='How often to evaluate and save samples.') 64 | # distributed training params 65 | parser.add_argument('--distributed', action='store_true', default=False, help='Whether to use DistributedDataParallels on multiple machines and GPUs.') 66 | # generation param 67 | parser.add_argument('--n_samples', type=int, default=64, help='Number of samples to generate.') 68 | 69 | 70 | # -------------------- 71 | # Data and model loading 72 | # -------------------- 73 | 74 | def fetch_vqvae_dataloader(args, train=True): 75 | if args.dataset == 'cifar10': 76 | # setup dataset and dataloader -- preprocess data to [-1, 1] 77 | dataset = CIFAR10(args.data_dir, 78 | train=train, 79 | transform=T.Compose([T.ToTensor(), lambda x: x.mul(2).sub(1)]), 80 | target_transform=(lambda y: torch.eye(args.n_cond_classes)[y]) if args.n_cond_classes else None) 81 | if not 'input_dims' in args: args.input_dims = (3,32,32) 82 | elif args.dataset == 'chexpert': 83 | dataset = ChexpertDataset(args.data_dir, train, 84 | transform=T.Compose([T.ToTensor(), lambda x: x.mul(2).sub(1)])) 85 | if not 'input_dims' in args: args.input_dims = dataset.input_dims 86 | args.n_cond_classes = len(dataset.attr_idxs) 87 | 88 | if args.mini_data: 89 | dataset.data = dataset.data[:args.batch_size] 90 | return DataLoader(dataset, args.batch_size, shuffle=train, num_workers=4, pin_memory=('cuda' in args.device)) 91 | 92 | def load_model(model_cls, config, model_dir, args, restore=False, eval_mode=False, optimizer_cls=None, scheduler_cls=None, verbose=True): 93 | # load model config 94 | if config is None: config = load_json(os.path.join(model_dir, 'config_{}.json'.format(args.cuda))) 95 | # init model and distribute 96 | model = model_cls(**config).to(args.device) 97 | 
if args.distributed: 98 | # NOTE: DistributedDataParallel will divide and allocate batch_size to all available GPUs if device_ids are not set 99 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.cuda], output_device=args.cuda, 100 | find_unused_parameters=True) 101 | # init optimizer and scheduler 102 | optimizer = optimizer_cls(model.parameters()) if optimizer_cls else None 103 | scheduler = scheduler_cls(optimizer) if scheduler_cls else None 104 | if restore: 105 | checkpoint = torch.load(os.path.join(model_dir, 'checkpoint.pt'), map_location=args.device) 106 | if args.distributed: 107 | model.module.load_state_dict(checkpoint['state_dict']) 108 | else: 109 | model.load_state_dict(checkpoint['state_dict']) 110 | args.start_epoch = checkpoint['epoch'] + 1 111 | args.step = checkpoint['global_step'] 112 | if optimizer: optimizer.load_state_dict(torch.load(model_dir + '/optim_checkpoint.pt', map_location=args.device)) 113 | if scheduler: scheduler.load_state_dict(torch.load(model_dir + '/sched_checkpoint.pt', map_location=args.device)) 114 | if eval_mode: 115 | model.eval() 116 | # if optimizer and restore: optimizer.use_ema(True) 117 | for p in model.parameters(): p.requires_grad_(False) 118 | if verbose: 119 | print('Loaded {}\n\tconfig and state dict loaded from {}'.format(model_cls.__name__, model_dir)) 120 | print('\tmodel parameters: {:,}'.format(sum(p.numel() for p in model.parameters()))) 121 | return model, optimizer, scheduler 122 | 123 | def save_json(data, filename, args): 124 | with open(os.path.join(args.output_dir, filename + '.json'), 'w') as f: 125 | json.dump(data, f, indent=4) 126 | 127 | def load_json(file_path): 128 | with open(file_path, 'r') as f: 129 | data = json.load(f) 130 | return data 131 | 132 | # -------------------- 133 | # VQVAE components 134 | # -------------------- 135 | 136 | class VQ(nn.Module): 137 | def __init__(self, n_embeddings, embedding_dim, ema=False, ema_decay=0.99, ema_eps=1e-5): 138 | super().__init__() 139 | self.n_embeddings = n_embeddings 140 | self.embedding_dim = embedding_dim 141 | self.ema = ema 142 | self.ema_decay = ema_decay 143 | self.ema_eps = ema_eps 144 | 145 | self.embedding = nn.Embedding(n_embeddings, embedding_dim) 146 | nn.init.kaiming_uniform_(self.embedding.weight, 1) 147 | 148 | if ema: 149 | self.embedding.weight.requires_grad_(False) 150 | # set up moving averages 151 | self.register_buffer('ema_cluster_size', torch.zeros(n_embeddings)) 152 | self.register_buffer('ema_weight', self.embedding.weight.clone().detach()) 153 | 154 | def embed(self, encoding_indices): 155 | return self.embedding(encoding_indices).permute(0,4,1,2,3).squeeze(2) # in (B,1,H,W); out (B,E,H,W) 156 | 157 | def forward(self, z): 158 | # input (B,E,H,W); permute and reshape to (B*H*W,E) to compute distances in E-space 159 | flat_z = z.permute(0,2,3,1).reshape(-1, self.embedding_dim) # (B*H*W,E) 160 | # compute distances to nearest embedding 161 | distances = flat_z.pow(2).sum(1, True) + self.embedding.weight.pow(2).sum(1) - 2 * flat_z.matmul(self.embedding.weight.t()) 162 | # quantize z to nearest embedding 163 | encoding_indices = distances.argmin(1).reshape(z.shape[0], 1, *z.shape[2:]) # (B,1,H,W) 164 | z_q = self.embed(encoding_indices) 165 | 166 | # perform ema updates 167 | if self.ema and self.training: 168 | with torch.no_grad(): 169 | # update cluster size 170 | encodings = F.one_hot(encoding_indices.flatten(), self.n_embeddings).float().to(z.device) 171 | self.ema_cluster_size -= (1 - self.ema_decay) * 
(self.ema_cluster_size - encodings.sum(0)) 172 | # update weight 173 | dw = z.permute(1,0,2,3).flatten(1) @ encodings # (E,B*H*W) dot (B*H*W,n_embeddings) 174 | self.ema_weight -= (1 - self.ema_decay) * (self.ema_weight - dw.t()) 175 | # update embedding weight with normalized ema_weight 176 | n = self.ema_cluster_size.sum() 177 | updated_cluster_size = (self.ema_cluster_size + self.ema_eps) / (n + self.n_embeddings * self.ema_eps) * n 178 | self.embedding.weight.data = self.ema_weight / updated_cluster_size.unsqueeze(1) 179 | 180 | return encoding_indices, z_q # out (B,1,H,W) codes and (B,E,H,W) embedded codes 181 | 182 | 183 | class ResidualLayer(nn.Sequential): 184 | def __init__(self, n_channels, n_res_channels): 185 | super().__init__(nn.Conv2d(n_channels, n_res_channels, kernel_size=3, padding=1), 186 | nn.ReLU(True), 187 | nn.Conv2d(n_res_channels, n_channels, kernel_size=1)) 188 | 189 | def forward(self, x): 190 | return F.relu(x + super().forward(x), True) 191 | 192 | # -------------------- 193 | # VQVAE2 194 | # -------------------- 195 | 196 | class VQVAE2(nn.Module): 197 | def __init__(self, input_dims, n_embeddings, embedding_dim, n_channels, n_res_channels, n_res_layers, 198 | ema=False, ema_decay=0.99, ema_eps=1e-5, **kwargs): # keep kwargs so can load from config with arbitrary other args 199 | super().__init__() 200 | self.ema = ema 201 | 202 | self.enc1 = nn.Sequential(nn.Conv2d(input_dims[0], n_channels//2, kernel_size=4, stride=2, padding=1), 203 | nn.ReLU(True), 204 | nn.Conv2d(n_channels//2, n_channels, kernel_size=4, stride=2, padding=1), 205 | nn.ReLU(True), 206 | nn.Conv2d(n_channels, n_channels, kernel_size=3, padding=1), 207 | nn.ReLU(True), 208 | nn.Sequential(*[ResidualLayer(n_channels, n_res_channels) for _ in range(n_res_layers)]), 209 | nn.Conv2d(n_channels, embedding_dim, kernel_size=1)) 210 | 211 | self.enc2 = nn.Sequential(nn.Conv2d(embedding_dim, n_channels//2, kernel_size=4, stride=2, padding=1), 212 | nn.ReLU(True), 213 | nn.Conv2d(n_channels//2, n_channels, kernel_size=3, padding=1), 214 | nn.ReLU(True), 215 | nn.Sequential(*[ResidualLayer(n_channels, n_res_channels) for _ in range(n_res_layers)]), 216 | nn.Conv2d(n_channels, embedding_dim, kernel_size=1)) 217 | 218 | self.dec2 = nn.Sequential(nn.Conv2d(embedding_dim, n_channels, kernel_size=3, padding=1), 219 | nn.ReLU(True), 220 | nn.Sequential(*[ResidualLayer(n_channels, n_res_channels) for _ in range(n_res_layers)]), 221 | nn.ConvTranspose2d(n_channels, embedding_dim, kernel_size=4, stride=2, padding=1)) 222 | 223 | self.dec1 = nn.Sequential(nn.Conv2d(2*embedding_dim, n_channels, kernel_size=3, padding=1), 224 | nn.ReLU(True), 225 | nn.Sequential(*[ResidualLayer(n_channels, n_res_channels) for _ in range(n_res_layers)]), 226 | nn.ConvTranspose2d(n_channels, n_channels//2, kernel_size=4, stride=2, padding=1), 227 | nn.ReLU(True), 228 | nn.ConvTranspose2d(n_channels//2, input_dims[0], kernel_size=4, stride=2, padding=1)) 229 | 230 | self.proj_to_vq1 = nn.Conv2d(2*embedding_dim, embedding_dim, kernel_size=1) 231 | self.upsample_to_dec1 = nn.ConvTranspose2d(embedding_dim, embedding_dim, kernel_size=4, stride=2, padding=1) 232 | 233 | self.vq1 = VQ(n_embeddings, embedding_dim, ema, ema_decay, ema_eps) 234 | self.vq2 = VQ(n_embeddings, embedding_dim, ema, ema_decay, ema_eps) 235 | 236 | def encode(self, x): 237 | z1 = self.enc1(x) 238 | z2 = self.enc2(z1) 239 | return (z1, z2) # each is (B,E,H,W) 240 | 241 | def embed(self, encoding_indices): 242 | encoding_indices1, encoding_indices2 = encoding_indices 
243 | return (self.vq1.embed(encoding_indices1), self.vq2.embed(encoding_indices2)) 244 | 245 | def quantize(self, z_e): 246 | # unpack inputs 247 | z1, z2 = z_e 248 | 249 | # quantize top level 250 | encoding_indices2, zq2 = self.vq2(z2) 251 | 252 | # quantize bottom level conditioned on top level decoder and bottom level encoder 253 | # decode top level 254 | quantized2 = z2 + (zq2 - z2).detach() # stop decoder optimization from accessing the embedding 255 | dec2_out = self.dec2(quantized2) 256 | # condition on bottom encoder and top decoder 257 | vq1_input = torch.cat([z1, dec2_out], 1) 258 | vq1_input = self.proj_to_vq1(vq1_input) 259 | encoding_indices1, zq1 = self.vq1(vq1_input) 260 | return (encoding_indices1, encoding_indices2), (zq1, zq2) 261 | 262 | def decode(self, z_e, z_q): 263 | # unpack inputs 264 | zq1, zq2 = z_q 265 | if z_e is not None: 266 | z1, z2 = z_e 267 | # stop decoder optimization from accessing the embedding 268 | zq1 = z1 + (zq1 - z1).detach() 269 | zq2 = z2 + (zq2 - z2).detach() 270 | 271 | # upsample quantized2 to match spacial dim of quantized1 272 | zq2_upsampled = self.upsample_to_dec1(zq2) 273 | # decode 274 | combined_latents = torch.cat([zq1, zq2_upsampled], 1) 275 | return self.dec1(combined_latents) 276 | 277 | def forward(self, x, commitment_cost, writer=None): 278 | # Figure 2a in paper 279 | z_e = self.encode(x) 280 | encoding_indices, z_q = self.quantize(z_e) 281 | recon_x = self.decode(z_e, z_q) 282 | 283 | # compute loss over the hierarchy -- cf eq 2 in paper 284 | recon_loss = F.mse_loss(recon_x, x) 285 | q_latent_loss = sum(F.mse_loss(z_i.detach(), zq_i) for z_i, zq_i in zip(z_e, z_q)) if not self.ema else torch.zeros(1, device=x.device) 286 | e_latent_loss = sum(F.mse_loss(z_i, zq_i.detach()) for z_i, zq_i in zip(z_e, z_q)) 287 | loss = recon_loss + q_latent_loss + commitment_cost * e_latent_loss 288 | 289 | if writer: 290 | # compute perplexity 291 | n_embeddings = self.vq1.embedding.num_embeddings 292 | avg_probs = lambda e: torch.histc(e.float(), bins=n_embeddings, max=n_embeddings).float().div(e.numel()) 293 | perplexity = lambda avg_probs: torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 294 | # record training stats 295 | writer.add_scalar('loss', loss.item(), args.step) 296 | writer.add_scalar('loss_recon_train', recon_loss.item(), args.step) 297 | writer.add_scalar('loss_q_latent', q_latent_loss.item(), args.step) 298 | writer.add_scalar('loss_e_latent', e_latent_loss.item(), args.step) 299 | for i, e_i in enumerate(encoding_indices): 300 | writer.add_scalar('perplexity_{}'.format(i), perplexity(avg_probs(e_i)).item(), args.step) 301 | 302 | return loss 303 | 304 | 305 | # -------------------- 306 | # Train, evaluate, reconstruct 307 | # -------------------- 308 | 309 | def train_epoch(model, dataloader, optimizer, scheduler, epoch, writer, args): 310 | model.train() 311 | 312 | with tqdm(total=len(dataloader), desc='epoch {}/{}'.format(epoch, args.start_epoch + args.n_epochs)) as pbar: 313 | for x, _ in dataloader: 314 | args.step += 1 315 | 316 | loss = model(x.to(args.device), args.commitment_cost, writer if args.step % args.log_interval == 0 else None) 317 | 318 | optimizer.zero_grad() 319 | loss.backward() 320 | optimizer.step() 321 | if scheduler: scheduler.step() 322 | 323 | pbar.set_postfix(loss='{:.4f}'.format(loss.item())) 324 | pbar.update() 325 | 326 | def show_recons_from_hierarchy(model, n_samples, x, z_q, recon_x=None): 327 | # full reconstruction 328 | if recon_x is None: 329 | recon_x = 
model.decode(None, z_q) 330 | # top level only reconstruction -- no contribution from bottom-level (level1) latents 331 | recon_top = model.decode(None, (z_q[0].fill_(0), z_q[1])) 332 | 333 | # construct image grid 334 | x = make_grid(x[:n_samples].cpu(), normalize=True) 335 | recon_x = make_grid(recon_x[:n_samples].cpu(), normalize=True) 336 | recon_top = make_grid(recon_top[:n_samples].cpu(), normalize=True) 337 | separator = torch.zeros(x.shape[0], 4, x.shape[2]) 338 | return torch.cat([x, separator, recon_x, separator, recon_top], dim=1) 339 | 340 | @torch.no_grad() 341 | def evaluate(model, dataloader, args): 342 | model.eval() 343 | 344 | recon_loss = 0 345 | for x, _ in tqdm(dataloader): 346 | x = x.to(args.device) 347 | z_e = model.encode(x) 348 | encoding_indices, z_q = model.quantize(z_e) 349 | recon_x = model.decode(z_e, z_q) 350 | recon_loss += F.mse_loss(recon_x, x).item() 351 | recon_loss /= len(dataloader) 352 | 353 | # reconstruct 354 | recon_image = show_recons_from_hierarchy(model, args.n_samples, x, z_q, recon_x) 355 | return recon_image, recon_loss 356 | 357 | def train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, writer, args): 358 | for epoch in range(args.start_epoch, args.start_epoch + args.n_epochs): 359 | train_epoch(model, train_dataloader, optimizer, scheduler, epoch, writer, args) 360 | 361 | # save model 362 | torch.save({'epoch': epoch, 363 | 'global_step': args.step, 364 | 'state_dict': model.state_dict()}, 365 | os.path.join(args.output_dir, 'checkpoint.pt')) 366 | torch.save(optimizer.state_dict(), os.path.join(args.output_dir, 'optim_checkpoint.pt')) 367 | if scheduler: torch.save(optimizer.state_dict(), os.path.join(args.output_dir, 'sched_checkpoint.pt')) 368 | 369 | if (epoch+1) % args.eval_interval == 0: 370 | # evaluate 371 | recon_image, recon_loss = evaluate(model, valid_dataloader, args) 372 | print('Evaluate -- recon loss: {:.4f}'.format(recon_loss)) 373 | writer.add_scalar('loss_recon_eval', recon_loss, args.step) 374 | writer.add_image('eval_reconstructions', recon_image, args.step) 375 | save_image(recon_image, os.path.join(args.output_dir, 'eval_reconstruction_step_{}'.format(args.step) + '.png')) 376 | 377 | 378 | # -------------------- 379 | # Main 380 | # -------------------- 381 | 382 | if __name__ == '__main__': 383 | args = parser.parse_args() 384 | if args.restore_dir: 385 | args.output_dir = args.restore_dir 386 | if not args.output_dir: # if not given use results/file_name/time_stamp 387 | args.output_dir = './results/{}/{}'.format(os.path.splitext(__file__)[0], time.strftime('%Y-%m-%d_%H-%M-%S', time.gmtime())) 388 | writer = SummaryWriter(log_dir = args.output_dir) 389 | 390 | args.device = 'cuda:{}'.format(args.cuda) if args.cuda is not None and torch.cuda.is_available() else 'cpu' 391 | 392 | torch.manual_seed(args.seed) 393 | 394 | # setup dataset and dataloader -- preprocess data to [-1, 1] 395 | train_dataloader = fetch_vqvae_dataloader(args, train=True) 396 | valid_dataloader = fetch_vqvae_dataloader(args, train=False) 397 | 398 | # save config 399 | if not os.path.exists(os.path.join(args.output_dir, 'config_{}.json'.format(args.cuda))): 400 | save_json(args.__dict__, 'config_{}'.format(args.cuda), args) 401 | 402 | # setup model 403 | model, optimizer, scheduler = load_model(VQVAE2, args.output_dir, args, 404 | restore=(args.restore_dir is not None), 405 | eval_mode=False, 406 | optimizer_cls=partial(torch.optim.Adam, lr=args.lr), 407 | 
scheduler_cls=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=args.lr_decay)) 408 | 409 | # print and write config with update step and epoch from load_model 410 | writer.add_text('config', str(args.__dict__), args.step) 411 | pprint.pprint(args.__dict__) 412 | 413 | if args.train: 414 | train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, writer, args) 415 | 416 | if args.evaluate: 417 | recon_image, recon_loss = evaluate(model, valid_dataloader, args) 418 | print('Evaluate @ step {} -- recon loss: {:.4f}'.format(args.step, recon_loss)) 419 | writer.add_scalar('loss_recon_eval', recon_loss, args.step) 420 | writer.add_image('eval_reconstructions', recon_image, args.step) 421 | save_image(recon_image, os.path.join(args.output_dir, 'eval_reconstruction_step_{}'.format(args.step) + '.png')) 422 | 423 | -------------------------------------------------------------------------------- /vqvae_prior.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of VQ-VAE-2 priors: 3 | -- van den Oord, 'Generating Diverse High-Fidelity Images with VQ-VAE-2' -- https://arxiv.org/abs/1906.00446 4 | -- van den Oord, 'Conditional Image Generation with PixelCNN Decoders' -- https://arxiv.org/abs/1606.05328 5 | -- Xi Chen, 'PixelSNAIL: An Improved Autoregressive Generative Model' -- https://arxiv.org/abs/1712.09763 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.multiprocessing as mp 12 | from torch.utils.data import DataLoader, TensorDataset 13 | from torchvision.utils import save_image, make_grid 14 | 15 | import numpy as np 16 | from tensorboardX import SummaryWriter 17 | from tqdm import tqdm 18 | 19 | import os 20 | import argparse 21 | import time 22 | import pprint 23 | from functools import partial 24 | 25 | from vqvae import VQVAE2, fetch_vqvae_dataloader, load_model, save_json, load_json 26 | from optim import Adam, RMSprop 27 | 28 | 29 | parser = argparse.ArgumentParser() 30 | 31 | # action 32 | parser.add_argument('--train', action='store_true', help='Train model.') 33 | parser.add_argument('--evaluate', action='store_true', help='Evaluate model.') 34 | parser.add_argument('--generate', action='store_true', help='Generate samples from a model.') 35 | parser.add_argument('--seed', type=int, default=0, help='Random seed to use.') 36 | parser.add_argument('--cuda', type=int, help='Which cuda device to use.') 37 | parser.add_argument('--mini_data', action='store_true', help='Truncate dataset to a single minibatch.') 38 | # model 39 | parser.add_argument('--which_prior', choices=['bottom', 'top'], help='Which prior model to train.') 40 | parser.add_argument('--vqvae_dir', type=str, required=True, help='Path to VQVAE folder with config.json and checkpoint.pt files.') 41 | parser.add_argument('--n_channels', default=128, type=int, help='Number of channels for gated residual convolutional blocks.') 42 | parser.add_argument('--n_out_conv_channels', default=1024, type=int, help='Number of channels for outer 1x1 convolutional layers.') 43 | parser.add_argument('--n_res_layers', default=20, type=int, help='Number of Gated Residual Blocks.') 44 | parser.add_argument('--n_cond_classes', default=5, type=int, help='Number of classes if conditional model.') 45 | parser.add_argument('--n_cond_stack_layers', default=10, type=int, help='Number of conditioning stack residual blocks.') 46 | parser.add_argument('--n_out_stack_layers', default=10, type=int, help='Number 
of output stack layers.') 47 | parser.add_argument('--kernel_size', default=5, type=int, help='Kernel size for the gated residual convolutional blocks.') 48 | parser.add_argument('--drop_rate', default=0, type=float, help='Dropout for the Gated Residual Blocks.') 49 | # data params 50 | parser.add_argument('--output_dir', type=str, help='Location where weights, logs, and sample should be saved.') 51 | parser.add_argument('--restore_dir', nargs='+', help='Location where configs and weights are to be restored from.') 52 | parser.add_argument('--n_bits', type=int, help='Number of bits of input data.') 53 | # training param 54 | parser.add_argument('--lr', type=float, default=5e-4, help='Learning rate.') 55 | parser.add_argument('--lr_decay', type=float, default=0.99999, help='Learning rate decay (assume 5e-5 @ 300k iters for lr 0.001).') 56 | parser.add_argument('--polyak', type=float, default=0.9995, help='Polyak decay for exponential moving averaging.') 57 | parser.add_argument('--batch_size', type=int, default=16, help='Training batch size.') 58 | parser.add_argument('--n_epochs', type=int, default=1, help='Number of epochs to train.') 59 | parser.add_argument('--step', type=int, default=0, help='Current step of training (number of minibatches processed).') 60 | parser.add_argument('--start_epoch', default=0, help='Starting epoch (for logging; to be overwritten when restoring file.') 61 | parser.add_argument('--log_interval', type=int, default=50, help='How often to show loss statistics and save samples.') 62 | parser.add_argument('--eval_interval', type=int, default=10, help='How often to evaluate and save samples.') 63 | parser.add_argument('--save_interval', type=int, default=300, help='How often to evaluate and save samples.') 64 | # distributed training params 65 | parser.add_argument('--distributed', action='store_true', default=False, help='Whether to use DistributedDataParallels on multiple machines and GPUs.') 66 | parser.add_argument('--world_size', type=int, default=1) 67 | parser.add_argument('--rank', type=int, default=0) 68 | # generation param 69 | parser.add_argument('--n_samples', type=int, default=8, help='Number of samples to generate.') 70 | 71 | 72 | 73 | # -------------------- 74 | # Data and model loading 75 | # -------------------- 76 | 77 | @torch.no_grad() 78 | def extract_codes_from_dataloader(vqvae, dataloader, dataset_path): 79 | """ encode image inputs with vqvae and extract field of discrete latents (the embedding indices in the codebook with closest l2 distance) """ 80 | device = next(vqvae.parameters()).device 81 | e1s, e2s, ys = [], [], [] 82 | for x, y in tqdm(dataloader): 83 | z_e = vqvae.encode(x.to(device)) 84 | encoding_indices, _ = vqvae.quantize(z_e) # tuple of (bottom, top encoding indices) where each is (B,1,H,W) 85 | 86 | e1, e2 = encoding_indices 87 | e1s.append(e1) 88 | e2s.append(e2) 89 | ys.append(y) 90 | return TensorDataset(torch.cat(e1s).cpu(), torch.cat(e2s).cpu(), torch.cat(ys)) 91 | 92 | def maybe_extract_codes(vqvae, args, train): 93 | """ construct datasets of vqvae encodings and class conditional labels -- each dataset entry is [encodings level 1 (bottom), encodings level 2 (top), class label vector] """ 94 | # paths to load/save as `chexpert_train_codes_mini_data.pt` 95 | dataset_path = os.path.join(args.vqvae_dir, '{}_{}_codes'.format(args.dataset, 'train' if train else 'valid') + args.mini_data*'_mini_data_{}'.format(args.batch_size) + '.pt') 96 | if not os.path.exists(dataset_path): 97 | print('Extracting codes for {} data 
...'.format('train' if train else 'valid')) 98 | dataloader = fetch_vqvae_dataloader(args, train) 99 | dataset = extract_codes_from_dataloader(vqvae, dataloader, dataset_path) 100 | torch.save(dataset, dataset_path) 101 | else: 102 | dataset = torch.load(dataset_path) 103 | if args.on_main_process: print('Loaded {} codes dataset of size {}'.format('train' if train else 'valid', len(dataset))) 104 | return dataset 105 | 106 | def fetch_prior_dataloader(vqvae, args, train=True): 107 | dataset = maybe_extract_codes(vqvae, args, train) 108 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) if args.distributed and train else None 109 | return DataLoader(dataset, args.batch_size, shuffle=(train and sampler is None), sampler=sampler, num_workers=4, pin_memory=('cuda' in args.device)) 110 | 111 | def preprocess(x, n_bits): 112 | """ preprocess discrete latent space [0, 2**n_bits) to model space [-1,1]; e.g. if the codebook size n_embeddings = 512 = 2**9 then n_bits = 9 """ 113 | # 1. convert data to float 114 | # 2. normalize to [0,1] given quantization 115 | # 3. shift to [-1,1] 116 | return x.float().div(2**n_bits - 1).mul(2).add(-1) 117 | 118 | def deprocess(x, n_bits): 119 | """ deprocess x from model space [-1,1] to discrete latent space [0, 2**n_bits) where 2**n_bits is the size of the codebook """ 120 | # 1. shift to [0,1] 121 | # 2. quantize to n_bits 122 | # 3. convert data to long 123 | return x.add(1).div(2).mul(2**n_bits - 1).long() 124 | 125 | # -------------------- 126 | # PixelSNAIL -- top level prior conditioned on class labels 127 | # -------------------- 128 | 129 | def down_shift(x): 130 | return F.pad(x, (0,0,1,0))[:,:,:-1,:] 131 | 132 | def right_shift(x): 133 | return F.pad(x, (1,0))[:,:,:,:-1] 134 | 135 | def concat_elu(x): 136 | return F.elu(torch.cat([x, -x], dim=1)) 137 | 138 | class Conv2d(nn.Conv2d): 139 | def __init__(self, *args, **kwargs): 140 | super().__init__(*args, **kwargs) 141 | nn.utils.weight_norm(self) 142 | 143 | class DownShiftedConv2d(Conv2d): 144 | def forward(self, x): 145 | # pad H above and W on each side 146 | Hk, Wk = self.kernel_size 147 | x = F.pad(x, ((Wk-1)//2, (Wk-1)//2, Hk-1, 0)) 148 | return super().forward(x) 149 | 150 | class DownRightShiftedConv2d(Conv2d): 151 | def forward(self, x): 152 | # pad above and on left (i.e. shift input down and right) 153 | Hk, Wk = self.kernel_size 154 | x = F.pad(x, (Wk-1, 0, Hk-1, 0)) 155 | return super().forward(x) 156 | 157 | class GatedResidualLayer(nn.Module): 158 | def __init__(self, conv, n_channels, kernel_size, drop_rate=0, shortcut_channels=None, n_cond_classes=None, relu_fn=concat_elu): 159 | super().__init__() 160 | self.relu_fn = relu_fn 161 | 162 | self.c1 = conv(2*n_channels, n_channels, kernel_size) 163 | if shortcut_channels: 164 | self.c1c = Conv2d(2*shortcut_channels, n_channels, kernel_size=1) 165 | if drop_rate > 0: 166 | self.dropout = nn.Dropout(drop_rate) 167 | self.c2 = conv(2*n_channels, 2*n_channels, kernel_size) 168 | if n_cond_classes: 169 | self.proj_y = nn.Linear(n_cond_classes, 2*n_channels) 170 | 171 | def forward(self, x, a=None, y=None): 172 | c1 = self.c1(self.relu_fn(x)) 173 | if a is not None: # shortcut connection if auxiliary input 'a' is given 174 | c1 = c1 + self.c1c(self.relu_fn(a)) 175 | c1 = self.relu_fn(c1) 176 | if hasattr(self, 'dropout'): 177 | c1 = self.dropout(c1) 178 | c2 = self.c2(c1) 179 | if y is not None: 180 | c2 += self.proj_y(y)[:,:,None,None] 181 | a, b = c2.chunk(2,1) 182 | out = x + a * torch.sigmoid(b) 183 | return out 184 | 185 | def 
causal_attention(k, q, v, mask, nh, drop_rate, training): 186 | B, dq, H, W = q.shape 187 | _, dv, _, _ = v.shape 188 | 189 | # split channels into multiple heads, flatten H,W dims and scale q; out (B, nh, dkh or dvh, HW) 190 | flat_q = q.reshape(B, nh, dq//nh, H, W).flatten(3) * (dq//nh)**-0.5 191 | flat_k = k.reshape(B, nh, dq//nh, H, W).flatten(3) 192 | flat_v = v.reshape(B, nh, dv//nh, H, W).flatten(3) 193 | 194 | logits = torch.matmul(flat_q.transpose(2,3), flat_k) # (B,nh,HW,dq) dot (B,nh,dq,HW) = (B,nh,HW,HW) 195 | logits = F.dropout(logits, p=drop_rate, training=training, inplace=True) 196 | logits = logits.masked_fill(mask==0, -1e10) 197 | weights = F.softmax(logits, -1) 198 | 199 | attn_out = torch.matmul(weights, flat_v.transpose(2,3)) # (B,nh,HW,HW) dot (B,nh,HW,dvh) = (B,nh,HW,dvh) 200 | attn_out = attn_out.transpose(2,3) # (B,nh,dvh,HW) 201 | return attn_out.reshape(B, -1, H, W) # (B,dv,H,W) 202 | 203 | class AttentionGatedResidualLayer(nn.Module): 204 | def __init__(self, n_channels, n_background_ch, n_res_layers, n_cond_classes, drop_rate, nh, dq, dv, attn_drop_rate): 205 | super().__init__() 206 | # attn params 207 | self.nh = nh 208 | self.dq = dq 209 | self.dv = dv 210 | self.attn_drop_rate = attn_drop_rate 211 | 212 | self.input_gated_resnet = nn.ModuleList([ 213 | *[GatedResidualLayer(DownRightShiftedConv2d, n_channels, (2,2), drop_rate, None, n_cond_classes) for _ in range(n_res_layers)]]) 214 | self.in_proj_kv = nn.Sequential(GatedResidualLayer(Conv2d, 2*n_channels + n_background_ch, 1, drop_rate, None, n_cond_classes), 215 | Conv2d(2*n_channels + n_background_ch, dq+dv, 1)) 216 | self.in_proj_q = nn.Sequential(GatedResidualLayer(Conv2d, n_channels + n_background_ch, 1, drop_rate, None, n_cond_classes), 217 | Conv2d(n_channels + n_background_ch, dq, 1)) 218 | self.out_proj = GatedResidualLayer(Conv2d, n_channels, 1, drop_rate, dv, n_cond_classes) 219 | 220 | def forward(self, x, background, attn_mask, y=None): 221 | ul = x 222 | for m in self.input_gated_resnet: 223 | ul = m(ul, y=y) 224 | 225 | kv = self.in_proj_kv(torch.cat([x, ul, background], 1)) 226 | k, v = kv.split([self.dq, self.dv], 1) 227 | q = self.in_proj_q(torch.cat([ul, background], 1)) 228 | attn_out = causal_attention(k, q, v, attn_mask, self.nh, self.attn_drop_rate, self.training) 229 | return self.out_proj(ul, attn_out) 230 | 231 | class PixelSNAIL(nn.Module): 232 | def __init__(self, input_dims, n_channels, n_res_layers, n_out_stack_layers, n_cond_classes, n_bits, 233 | attn_n_layers=4, attn_nh=8, attn_dq=16, attn_dv=128, attn_drop_rate=0, drop_rate=0.5, **kwargs): 234 | super().__init__() 235 | H,W = input_dims[2] 236 | # init background 237 | background_v = ((torch.arange(H, dtype=torch.float) - H / 2) / 2).view(1,1,-1,1).expand(1,1,H,W) 238 | background_h = ((torch.arange(W, dtype=torch.float) - W / 2) / 2).view(1,1,1,-1).expand(1,1,H,W) 239 | self.register_buffer('background', torch.cat([background_v, background_h], 1)) 240 | # init attention mask over current and future pixels 241 | attn_mask = torch.tril(torch.ones(1,1,H*W,H*W), diagonal=-1).byte() # 1s below diagonal -- attend to context only 242 | self.register_buffer('attn_mask', attn_mask) 243 | 244 | # input layers for `up` and `up and to the left` pixels 245 | self.ul_input_d = DownShiftedConv2d(2, n_channels, kernel_size=(1,3)) 246 | self.ul_input_dr = DownRightShiftedConv2d(2, n_channels, kernel_size=(2,1)) 247 | self.ul_modules = nn.ModuleList([ 248 | *[AttentionGatedResidualLayer(n_channels, self.background.shape[1], n_res_layers, 
n_cond_classes, drop_rate, 249 | attn_nh, attn_dq, attn_dv, attn_drop_rate) for _ in range(attn_n_layers)]]) 250 | self.output_stack = nn.Sequential( 251 | *[GatedResidualLayer(DownRightShiftedConv2d, n_channels, (2,2), drop_rate, None, n_cond_classes) \ 252 | for _ in range(n_out_stack_layers)]) 253 | self.output_conv = Conv2d(n_channels, 2**n_bits, kernel_size=1) 254 | 255 | 256 | def forward(self, x, y=None): 257 | # add channel of ones to distinguish image from padding later on 258 | x = F.pad(x, (0,0,0,0,0,1), value=1) 259 | 260 | ul = down_shift(self.ul_input_d(x)) + right_shift(self.ul_input_dr(x)) 261 | for m in self.ul_modules: 262 | ul = m(ul, self.background.expand(x.shape[0],-1,-1,-1), self.attn_mask, y) 263 | ul = self.output_stack(ul) 264 | return self.output_conv(F.elu(ul)).unsqueeze(2) # out (B, 2**n_bits, 1, H, W) 265 | 266 | # -------------------- 267 | # PixelCNN -- bottom level prior conditioned on class labels and top level codes 268 | # -------------------- 269 | 270 | def pixelcnn_gate(x): 271 | a, b = x.chunk(2,1) 272 | return torch.tanh(a) * torch.sigmoid(b) 273 | 274 | class MaskedConv2d(nn.Conv2d): 275 | def __init__(self, mask_type, *args, **kwargs): 276 | self.mask_type = mask_type 277 | super().__init__(*args, **kwargs) 278 | 279 | def apply_mask(self): 280 | H, W = self.kernel_size 281 | self.weight.data[:,:,H//2+1:,:].zero_() # mask out rows below the middle 282 | self.weight.data[:,:,H//2,W//2+1:].zero_() # mask out center row pixels right of middle 283 | if self.mask_type=='a': 284 | self.weight.data[:,:,H//2,W//2] = 0 # mask out center pixel 285 | 286 | def forward(self, x): 287 | self.apply_mask() 288 | return super().forward(x) 289 | 290 | class GatedResidualBlock(nn.Module): 291 | """ Figure 2 in Conditional image generation with PixelCNN Decoders """ 292 | def __init__(self, in_channels, out_channels, kernel_size, n_cond_channels, drop_rate): 293 | super().__init__() 294 | self.residual = (in_channels==out_channels) 295 | self.drop_rate = drop_rate 296 | 297 | self.v = nn.Conv2d(in_channels, 2*out_channels, kernel_size, padding=kernel_size//2) # vertical stack 298 | self.h = nn.Conv2d(in_channels, 2*out_channels, (1, kernel_size), padding=(0, kernel_size//2)) # horizontal stack 299 | self.v2h = nn.Conv2d(2*out_channels, 2*out_channels, kernel_size=1) # vertical to horizontal connection 300 | self.h2h = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False) # horizontal to horizontal 301 | 302 | if n_cond_channels: 303 | self.in_proj_y = nn.Conv2d(n_cond_channels, 2*out_channels, kernel_size=1) 304 | 305 | if self.drop_rate > 0: 306 | self.dropout_h = nn.Dropout(drop_rate) 307 | 308 | def apply_mask(self): 309 | self.v.weight.data[:,:,self.v.kernel_size[0]//2:,:].zero_() # mask out middle row and below 310 | self.h.weight.data[:,:,:,self.h.kernel_size[1]//2+1:].zero_() # mask out to the right of the central column 311 | 312 | def forward(self, x_v, x_h, y): 313 | self.apply_mask() 314 | 315 | # projection of y if included for conditional generation (cf paper section 2.3 -- added before the pixelcnn_gate) 316 | proj_y = self.in_proj_y(y) 317 | 318 | # vertical stack 319 | x_v_out = self.v(x_v) 320 | x_v2h = self.v2h(x_v_out) + proj_y 321 | x_v_out = pixelcnn_gate(x_v_out) 322 | 323 | # horizontal stack 324 | x_h_out = self.h(x_h) + x_v2h + proj_y 325 | x_h_out = pixelcnn_gate(x_h_out) 326 | if self.drop_rate: 327 | x_h_out = self.dropout_h(x_h_out) 328 | x_h_out = self.h2h(x_h_out) 329 | 330 | # residual connection 331 | if self.residual: 332 
| x_h_out = x_h_out + x_h 333 | 334 | return x_v_out, x_h_out 335 | 336 | def extra_repr(self): 337 | return 'residual={}, drop_rate={}'.format(self.residual, self.drop_rate) 338 | 339 | class PixelCNN(nn.Module): 340 | def __init__(self, n_channels, n_out_conv_channels, kernel_size, n_res_layers, n_cond_stack_layers, n_cond_classes, n_bits, 341 | drop_rate=0, **kwargs): 342 | super().__init__() 343 | # conditioning layers (bottom prior conditioned on class labels and top-level code) 344 | self.in_proj_y = nn.Linear(n_cond_classes, 2*n_channels) 345 | self.in_proj_h = nn.ConvTranspose2d(1, n_channels, kernel_size=4, stride=2, padding=1) # upsample top codes to bottom-level spatial dim 346 | self.cond_layers = nn.ModuleList([ 347 | GatedResidualLayer(partial(Conv2d, padding=kernel_size//2), n_channels, kernel_size, drop_rate, None, n_cond_classes) \ 348 | for _ in range(n_cond_stack_layers)]) 349 | self.out_proj_h = nn.Conv2d(n_channels, 2*n_channels, kernel_size=1) # double channels to apply pixelcnn_gate 350 | 351 | # pixelcnn layers 352 | self.input_conv = MaskedConv2d('a', 1, 2*n_channels, kernel_size=7, padding=3) 353 | self.res_layers = nn.ModuleList([ 354 | GatedResidualBlock(n_channels, n_channels, kernel_size, 2*n_channels, drop_rate) for _ in range(n_res_layers)]) 355 | self.conv_out1 = nn.Conv2d(n_channels, 2*n_out_conv_channels, kernel_size=1) 356 | self.conv_out2 = nn.Conv2d(n_out_conv_channels, 2*n_out_conv_channels, kernel_size=1) 357 | self.output = nn.Conv2d(n_out_conv_channels, 2**n_bits, kernel_size=1) 358 | 359 | def forward(self, x, h=None, y=None): 360 | # conditioning inputs -- h is top-level codes; y is class labels 361 | h = self.in_proj_h(h) 362 | for l in self.cond_layers: 363 | h = l(h, y=y) 364 | h = self.out_proj_h(h) 365 | y = self.in_proj_y(y)[:,:,None,None] 366 | 367 | # pixelcnn model 368 | x = pixelcnn_gate(self.input_conv(x) + h + y) 369 | x_v, x_h = x, x 370 | for l in self.res_layers: 371 | x_v, x_h = l(x_v, x_h, y) 372 | out = pixelcnn_gate(self.conv_out1(x_h)) 373 | out = pixelcnn_gate(self.conv_out2(out)) 374 | return self.output(out).unsqueeze(2) # (B, 2**n_bits, 1, H, W) 375 | 376 | # -------------------- 377 | # Train and evaluate 378 | # -------------------- 379 | 380 | def train_epoch(model, dataloader, optimizer, scheduler, epoch, writer, args): 381 | model.train() 382 | 383 | tic = time.time() 384 | if args.on_main_process: pbar = tqdm(total=len(dataloader), desc='epoch {}/{}'.format(epoch, args.start_epoch + args.n_epochs)) 385 | for e1, e2, y in dataloader: 386 | args.step += args.world_size 387 | 388 | e1, e2, y = e1.to(args.device), e2.to(args.device), y.to(args.device) 389 | 390 | if args.which_prior == 'bottom': 391 | x = e1 392 | logits = model(preprocess(x, args.n_bits), preprocess(e2, args.n_bits), y) 393 | elif args.which_prior == 'top': 394 | x = e2 395 | logits = model(preprocess(x, args.n_bits), y) 396 | loss = F.cross_entropy(logits, x).mean(0) 397 | 398 | optimizer.zero_grad() 399 | loss.backward() 400 | nn.utils.clip_grad_value_(model.parameters(), 1) 401 | optimizer.step() 402 | if scheduler: scheduler.step() 403 | 404 | # record 405 | if args.on_main_process: 406 | pbar.set_postfix(loss='{:.4f}'.format(loss.item() / np.log(2))) 407 | pbar.update() 408 | 409 | if args.step % args.log_interval == 0 and args.on_main_process: 410 | writer.add_scalar('train_bits_per_dim', loss.item() / np.log(2), args.step) 411 | 412 | # save 413 | if args.step % args.save_interval == 0 and args.on_main_process: 414 | # save model 415 | 
torch.save({'epoch': epoch, 416 | 'global_step': args.step, 417 | 'state_dict': model.module.state_dict() if args.distributed else model.state_dict()}, 418 | os.path.join(args.output_dir, 'checkpoint.pt')) 419 | torch.save(optimizer.state_dict(), os.path.join(args.output_dir, 'optim_checkpoint.pt')) 420 | if scheduler: torch.save(scheduler.state_dict(), os.path.join(args.output_dir, 'sched_checkpoint.pt')) 421 | 422 | if args.on_main_process: pbar.close() 423 | 424 | @torch.no_grad() 425 | def evaluate(model, dataloader, args): 426 | model.eval() 427 | 428 | losses = 0 429 | for e1, e2, y in dataloader: 430 | e1, e2, y = e1.to(args.device), e2.to(args.device), y.to(args.device) 431 | if args.which_prior == 'bottom': 432 | x = e1 433 | logits = model(preprocess(x, args.n_bits), preprocess(e2, args.n_bits), y) 434 | elif args.which_prior == 'top': 435 | x = e2 436 | logits = model(preprocess(x, args.n_bits), y) 437 | losses += F.cross_entropy(logits, x).mean(0).item() 438 | return losses / (len(dataloader) * np.log(2)) # to bits per dim 439 | 440 | def train_and_evaluate(model, vqvae, train_dataloader, valid_dataloader, optimizer, scheduler, writer, args): 441 | for epoch in range(args.start_epoch, args.start_epoch + args.n_epochs): 442 | train_epoch(model, train_dataloader, optimizer, scheduler, epoch, writer, args) 443 | 444 | if (epoch+1) % args.eval_interval == 0: 445 | # optimizer.use_ema(True) 446 | 447 | # evaluate 448 | eval_bpd = evaluate(model, valid_dataloader, args) 449 | if args.on_main_process: 450 | print('Evaluate bits per dim: {:.4f}'.format(eval_bpd)) 451 | writer.add_scalar('eval_bits_per_dim', eval_bpd, args.step) 452 | 453 | # generate 454 | samples = generate_samples_in_training(model, vqvae, train_dataloader, args) 455 | samples = make_grid(samples, normalize=True, nrow=args.n_samples) 456 | if args.distributed: 457 | # collect samples tensor from all processes onto main process cpu 458 | tensors = [torch.empty(samples.shape, dtype=samples.dtype).cuda() for i in range(args.world_size)] 459 | torch.distributed.all_gather(tensors, samples) 460 | samples = torch.cat(tensors, 2) 461 | if args.on_main_process: 462 | samples = samples.cpu() 463 | writer.add_image('samples_' + args.which_prior, samples, args.step) 464 | save_image(samples, os.path.join(args.output_dir, 'samples_{}_step_{}.png'.format(args.which_prior, args.step))) 465 | 466 | # optimizer.use_ema(False) 467 | 468 | if args.on_main_process: 469 | # save model 470 | torch.save({'epoch': epoch, 471 | 'global_step': args.step, 472 | 'state_dict': model.module.state_dict() if args.distributed else model.state_dict()}, 473 | os.path.join(args.output_dir, 'checkpoint.pt')) 474 | torch.save(optimizer.state_dict(), os.path.join(args.output_dir, 'optim_checkpoint.pt')) 475 | if scheduler: torch.save(scheduler.state_dict(), os.path.join(args.output_dir, 'sched_checkpoint.pt')) 476 | 477 | 478 | # -------------------- 479 | # Sample and generate 480 | # -------------------- 481 | 482 | def sample_prior(model, h, y, n_samples, input_dims, n_bits): 483 | model.eval() 484 | 485 | H,W = input_dims 486 | out = torch.zeros(n_samples, 1, H, W, device=next(model.parameters()).device) 487 | if args.on_main_process: pbar = tqdm(total=H*W, desc='Generating {} images'.format(n_samples)) 488 | for hi in range(H): 489 | for wi in range(W): 490 | logits = model(out, y) if h is None else model(out, h, y) 491 | probs = F.softmax(logits, dim=1) 492 | sample = torch.multinomial(probs[:,:,:,hi,wi].squeeze(2), 1) 493 | out[:,:,hi,wi] = 
preprocess(sample, n_bits) # multinomial samples long tensor in [0, 2**n_bits), convert back to model space [-1,1] 494 | if args.on_main_process: pbar.update() 495 | del logits, probs, sample 496 | if args.on_main_process: pbar.close() 497 | return deprocess(out, n_bits) # out (B,1,H,W) field of latents in latent space [0, 2**n_bits) 498 | 499 | @torch.no_grad() 500 | def generate(vqvae, bottom_model, top_model, args, ys=None): 501 | samples = [] 502 | for y in ys.unsqueeze(1): # condition on class one-hot labels; (n_samples, 1, n_cond_classes) when sliced on dim 0 returns (1,n_cond_classes) 503 | # sample top prior conditioned on class labels y 504 | top_samples = sample_prior(top_model, None, y, args.n_samples, args.input_dims[2], args.n_bits) 505 | # sample bottom prior conditioned on top_sample codes and class labels y 506 | bottom_samples = sample_prior(bottom_model, preprocess(top_samples, args.n_bits), y, args.n_samples, args.input_dims[1], args.n_bits) 507 | # decode 508 | samples += [vqvae.decode(None, vqvae.embed((bottom_samples, top_samples)))] 509 | samples = torch.cat(samples) 510 | return make_grid(samples, normalize=True, scale_each=True) 511 | 512 | def generate_samples_in_training(model, vqvae, dataloader, args): 513 | if args.which_prior == 'top': 514 | # zero out bottom samples so no contribution 515 | bottom_samples = torch.zeros(args.n_samples*(args.n_cond_classes+1),1,*args.input_dims[1], dtype=torch.long) 516 | # sample top prior 517 | top_samples = [] 518 | for y in torch.eye(args.n_cond_classes + 1, args.n_cond_classes).unsqueeze(1).to(args.device): # note eg: torch.eye(3,2) = [[1,0],[0,1],[0,0]] 519 | top_samples += [sample_prior(model, None, y, args.n_samples, args.input_dims[2], args.n_bits).cpu()] 520 | top_samples = torch.cat(top_samples) 521 | # decode 522 | samples = vqvae.decode(z_e=None, z_q=vqvae.embed((bottom_samples.to(args.device), top_samples.to(args.device)))) 523 | 524 | elif args.which_prior == 'bottom': # level 1 525 | # use the dataset ground truth top codes and only sample the bottom 526 | bottom_gt, top_gt, y = next(iter(dataloader)) # take e2 and y from dataloader output (e1,e2,y) 527 | bottom_gt, top_gt, y = bottom_gt[:args.n_samples].to(args.device), top_gt[:args.n_samples].to(args.device), y[:args.n_samples].to(args.device) 528 | # sample bottom prior 529 | bottom_samples = sample_prior(model, preprocess(top_gt, args.n_bits), y, args.n_samples, args.input_dims[1], args.n_bits) 530 | # decode 531 | # stack (1) recon using bottom+top actual latents, 532 | # (2) recon using top latents only, 533 | # (3) recon using top latent and bottom prior samples 534 | recon_actuals = vqvae.decode(z_e=None, z_q=vqvae.embed((bottom_gt, top_gt))) 535 | recon_top = vqvae.decode(z_e=None, z_q=vqvae.embed((bottom_gt.fill_(0), top_gt))) 536 | recon_samples = vqvae.decode(z_e=None, z_q=vqvae.embed((bottom_samples, top_gt))) 537 | samples = torch.cat([recon_actuals, recon_top, recon_samples]) 538 | 539 | return samples 540 | 541 | 542 | # -------------------- 543 | # Main 544 | # -------------------- 545 | 546 | if __name__ == '__main__': 547 | args = parser.parse_args() 548 | if args.restore_dir and args.which_prior: 549 | args.output_dir = args.restore_dir[0] 550 | if not args.output_dir: # if not given or not set by restore_dir use results/file_name/time_stamp 551 | # name experiment 'vqvae_[vqvae_dir]_prior_[prior_args]_[timestamp]' 552 | exp_name = 'vqvae_' + args.vqvae_dir.strip('/').rpartition('/')[2] + \ 553 | '_prior_{which_prior}' + 
args.mini_data*'_mini{}'.format(args.batch_size) + \ 554 | '_b{batch_size}_c{n_channels}_outc{n_out_conv_channels}_nres{n_res_layers}_condstack{n_cond_stack_layers}' + \ 555 | '_outstack{n_out_stack_layers}_drop{drop_rate}' + \ 556 | '_{}'.format(time.strftime('%Y-%m-%d_%H-%M', time.gmtime())) 557 | args.output_dir = './results/{}/{}'.format(os.path.splitext(__file__)[0], exp_name.format(**args.__dict__)) 558 | os.makedirs(args.output_dir, exist_ok=True) 559 | 560 | # setup device and distributed training 561 | if args.distributed: 562 | args.cuda = int(os.environ['LOCAL_RANK']) 563 | args.world_size = int(os.environ['WORLD_SIZE']) 564 | torch.cuda.set_device(args.cuda) 565 | args.device = 'cuda:{}'.format(args.cuda) 566 | 567 | # initialize 568 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 569 | else: 570 | args.device = 'cuda:{}'.format(args.cuda) if args.cuda is not None and torch.cuda.is_available() else 'cpu' 571 | 572 | # write ops only when on_main_process 573 | args.on_main_process = (args.distributed and args.cuda == 0) or not args.distributed 574 | 575 | # setup seed 576 | torch.manual_seed(args.seed) 577 | np.random.seed(args.seed) 578 | torch.backends.cudnn.deterministic = True 579 | torch.backends.cudnn.benchmark = False 580 | 581 | # load vqvae 582 | # load config; extract bits and input sizes throughout the hierarchy from the vqvae config 583 | vqvae_config = load_json(os.path.join(args.vqvae_dir, 'config.json')) 584 | img_dims = vqvae_config['input_dims'][1:] 585 | args.input_dims = [img_dims, [img_dims[0]//4, img_dims[1]//4], [img_dims[0]//8, img_dims[1]//8]] 586 | args.n_bits = int(np.log2(vqvae_config['n_embeddings'])) 587 | args.dataset = vqvae_config['dataset'] 588 | args.data_dir = vqvae_config['data_dir'] 589 | # load model 590 | vqvae, _, _ = load_model(VQVAE2, vqvae_config, args.vqvae_dir, args, restore=True, eval_mode=True, verbose=args.on_main_process) 591 | # reset start_epoch and step after model loading 592 | args.start_epoch, args.step = 0, 0 593 | # expose functions 594 | if args.distributed: 595 | vqvae.encode = vqvae.module.encode 596 | vqvae.decode = vqvae.module.decode 597 | vqvae.embed = vqvae.module.embed 598 | 599 | 600 | # load prior model 601 | # save prior config to feed to load_model 602 | if not os.path.exists(os.path.join(args.output_dir, 'config_{}.json'.format(args.cuda))): 603 | save_json(args.__dict__, 'config_{}'.format(args.cuda), args) 604 | # load model + optimizers, scheduler if training 605 | if args.which_prior: 606 | model, optimizer, scheduler = load_model(PixelCNN if args.which_prior=='bottom' else PixelSNAIL, 607 | config=args.__dict__, 608 | model_dir=args.output_dir, 609 | args=args, 610 | restore=(args.restore_dir is not None), 611 | eval_mode=False, 612 | optimizer_cls=partial(RMSprop, 613 | lr=args.lr, 614 | polyak=args.polyak), 615 | scheduler_cls=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=args.lr_decay), 616 | verbose=args.on_main_process) 617 | else: 618 | assert args.restore_dir and len(args.restore_dir)==2, '`restore_dir` should specify restore dir to bottom prior and top prior' 619 | # load both top and bottom to generate 620 | restore_bottom, restore_top = args.restore_dir 621 | bottom_model, _, _ = load_model(PixelCNN, config=None, model_dir=restore_bottom, args=args, restore=True, eval_mode=True, 622 | optimizer_cls=partial(RMSprop, lr=args.lr, polyak=args.polyak)) 623 | top_model, _, _ = load_model(PixelSNAIL, config=None, model_dir=restore_top, args=args, restore=True, 
eval_mode=True, 624 | optimizer_cls=partial(RMSprop, lr=args.lr, polyak=args.polyak)) 625 | 626 | # save and print config and setup writer on main process 627 | writer = None 628 | if args.on_main_process: 629 | pprint.pprint(args.__dict__) 630 | writer = SummaryWriter(log_dir = args.output_dir) 631 | writer.add_text('config', str(args.__dict__)) 632 | 633 | if args.train: 634 | assert args.which_prior is not None, 'Must specify `which_prior` to train.' 635 | train_dataloader = fetch_prior_dataloader(vqvae, args, True) 636 | valid_dataloader = fetch_prior_dataloader(vqvae, args, False) 637 | train_and_evaluate(model, vqvae, train_dataloader, valid_dataloader, optimizer, scheduler, writer, args) 638 | 639 | if args.evaluate: 640 | assert args.which_prior is not None, 'Must specify `which_prior` to evaluate.' 641 | valid_dataloader = fetch_prior_dataloader(vqvae, args, False) 642 | # optimizer.use_ema(True) 643 | eval_bpd = evaluate(model, valid_dataloader, args) 644 | if args.on_main_process: 645 | print('Evaluate bits per dim: {:.4f}'.format(eval_bpd)) 646 | 647 | if args.generate: 648 | assert args.which_prior is None, 'Remove `which_prior` to load both priors and generate.' 649 | # optimizer.use_ema(True) 650 | samples = generate(vqvae, bottom_model, top_model, args, ys=torch.eye(args.n_cond_classes + 1, args.n_cond_classes).to(args.device)) 651 | if args.distributed: 652 | torch.manual_seed(args.rank) 653 | # collect samples tensor from all processes onto main process cpu 654 | tensors = [torch.empty(samples.shape, dtype=samples.dtype).cuda() for i in range(args.world_size)] 655 | torch.distributed.all_gather(tensors, samples) 656 | samples = torch.cat(tensors, 2) 657 | if args.on_main_process: 658 | samples = samples.cpu() 659 | writer.add_image('samples', samples, args.step) 660 | save_image(samples, os.path.join(args.output_dir, 'generation_sample_step_{}.png'.format(args.step))) 661 | 662 | --------------------------------------------------------------------------------
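
The `preprocess`/`deprocess` helpers in vqvae_prior.py map the discrete codebook indices in [0, 2**n_bits) to the [-1, 1] model space the priors consume, and back. A minimal standalone sketch of that mapping, assuming a 512-entry codebook (n_bits = 9, as in the docstring example); the specific indices are only illustrative:

import torch

def preprocess(x, n_bits):
    # integer codes in [0, 2**n_bits) -> floats in [-1, 1]
    return x.float().div(2**n_bits - 1).mul(2).add(-1)

def deprocess(x, n_bits):
    # floats in [-1, 1] -> integer codes in [0, 2**n_bits)
    return x.add(1).div(2).mul(2**n_bits - 1).long()

n_bits = 9                            # 512-entry codebook -> 9 bits
codes = torch.tensor([0, 255, 511])
x = preprocess(codes, n_bits)         # tensor([-1.0000, -0.0020,  1.0000])
recovered = deprocess(x, n_bits)      # back to integer codes (note .long() truncates, so tiny float error can shift a code by one)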
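
`down_shift` and `right_shift` keep the two PixelSNAIL input streams causal: each moves the feature map by one pixel and zero-fills the vacated row or column, so the features at position (i, j) never include that pixel itself. A quick check of the shift behaviour (the toy 4-pixel inputs are just for illustration):

import torch
import torch.nn.functional as F

def down_shift(x):
    return F.pad(x, (0, 0, 1, 0))[:, :, :-1, :]   # pad one row of zeros on top, drop the bottom row

def right_shift(x):
    return F.pad(x, (1, 0))[:, :, :, :-1]         # pad one column of zeros on the left, drop the rightmost column

col = torch.arange(1., 5.).view(1, 1, 4, 1)       # a single column: 1,2,3,4
down_shift(col).flatten()                         # tensor([0., 1., 2., 3.])

row = torch.arange(1., 5.).view(1, 1, 1, 4)       # a single row: 1,2,3,4
right_shift(row).flatten()                        # tensor([0., 1., 2., 3.])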
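
`concat_elu` applies ELU to [x, -x] and therefore doubles the channel dimension, which is why the convolutions inside `GatedResidualLayer` take 2*n_channels input channels. A one-line shape check:

import torch
import torch.nn.functional as F

def concat_elu(x):
    return F.elu(torch.cat([x, -x], dim=1))

x = torch.randn(4, 128, 8, 8)
concat_elu(x).shape                               # torch.Size([4, 256, 8, 8])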
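
The attention mask registered in `PixelSNAIL` is `torch.tril(torch.ones(1,1,H*W,H*W), diagonal=-1)`, so in `causal_attention` each flattened position can attend only to strictly earlier positions in raster order (the zeros are filled with -1e10 before the softmax). For a 2x2 feature map the H*W x H*W pattern is:

import torch

H, W = 2, 2
attn_mask = torch.tril(torch.ones(H*W, H*W), diagonal=-1)
# tensor([[0., 0., 0., 0.],
#         [1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.]])
# row i marks the positions j < i that position i may attend to; position 0 attends to nothing.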
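
`MaskedConv2d.apply_mask` implements the standard PixelCNN weight masks: it zeroes the kernel rows below the centre and the entries to the right of the centre in the middle row, and the type 'a' mask (used for the input layer) additionally zeroes the centre weight so a pixel cannot condition on itself. A small sketch that reproduces just the masking pattern for a 3x3 kernel:

import torch

H, W = 3, 3
mask = torch.ones(H, W)
mask[H//2 + 1:, :] = 0          # rows below the middle
mask[H//2, W//2 + 1:] = 0       # middle row, right of centre
mask_b = mask.clone()           # type 'b': centre weight kept
mask_a = mask.clone()
mask_a[H//2, W//2] = 0          # type 'a': centre weight masked too
# mask_b: [[1., 1., 1.],        mask_a: [[1., 1., 1.],
#          [1., 1., 0.],                 [1., 0., 0.],
#          [0., 0., 0.]]                 [0., 0., 0.]]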