├── README.md ├── models_img_skip.py ├── models_text_skip.py ├── train_img_skip.py └── train_text_skip.py /README.md: -------------------------------------------------------------------------------- 1 | # Skip-VAE 2 | Code for the paper: 3 | [Avoiding Latent Variable Collapse With Generative Skip Models](https://arxiv.org/pdf/1807.04863.pdf) 4 | Adji B. Dieng, Yoon Kim, Alexander M. Rush, David M. Blei. 5 | 6 | Our code/data is based on the [Semi-Amortized VAE repo](https://github.com/harvardnlp/sa-vae). 7 | Please refer to the above repo for dependencies, data processing, etc. 8 | 9 | ## Model 10 | After downloading the `sa-vae` repo, copy these files to the `sa-vae` folder: 11 | - `train_text_skip.py` 12 | - `models_text_skip.py` 13 | - `train_img_skip.py` 14 | - `models_img_skip.py` 15 | 16 | To run the text model: 17 | ``` 18 | python train_text_skip.py --train_file data/yahoo/yahoo-train.hdf5 --val_file data/yahoo/yahoo-val.hdf5 --gpu 1 --checkpoint_path model-path --skip 1 --model savae --svi_steps 20 --train_n2n 1 19 | ``` 20 | where `model-path` is the path to save the best model and the `*.hdf5` files are obtained from running `preprocess_text.py`. You can select which GPU to use by passing a different value to the `--gpu` flag. 
21 | 22 | To run the image model: 23 | ``` 24 | python train_img_skip.py --data_file data/omniglot/omniglot.pt --gpu 1 --checkpoint_path model-path --skip 1 --model savae --svi_steps 20 --train_n2n 1 25 | ``` 26 | ## License 27 | MIT -------------------------------------------------------------------------------- /models_img_skip.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import numpy as np 9 | 10 | def he_init(m): 11 | s = np.sqrt(2./ m.in_features) 12 | m.weight.data.normal_(0, s) 13 | 14 | class GatedMaskedConv2d(nn.Module): 15 | def __init__(self, in_dim, out_dim=None, kernel_size = 3, mask = 'B'): 16 | super(GatedMaskedConv2d, self).__init__() 17 | if out_dim is None: 18 | out_dim = in_dim 19 | self.dim = out_dim 20 | self.size = kernel_size 21 | self.mask = mask 22 | pad = self.size // 2 23 | 24 | #vertical stack 25 | self.v_conv = nn.Conv2d(in_dim, 2*self.dim, kernel_size=(pad+1, self.size)) 26 | self.v_pad1 = nn.ConstantPad2d((pad, pad, pad, 0), 0) 27 | self.v_pad2 = nn.ConstantPad2d((0, 0, 1, 0), 0) 28 | self.vh_conv = nn.Conv2d(2*self.dim, 2*self.dim, kernel_size = 1) 29 | 30 | #horizontal stack 31 | self.h_conv = nn.Conv2d(in_dim, 2*self.dim, kernel_size=(1, pad+1)) 32 | self.h_pad1 = nn.ConstantPad2d((self.size // 2, 0, 0, 0), 0) 33 | self.h_pad2 = nn.ConstantPad2d((1, 0, 0, 0), 0) 34 | self.h_conv_res = nn.Conv2d(self.dim, self.dim, 1) 35 | self.h_res = nn.Conv2d(in_dim, out_dim, 1) 36 | 37 | def forward(self, v_map, h_map): 38 | v_out = self.v_pad2(self.v_conv(self.v_pad1(v_map)))[:, :, :-1, :] 39 | v_map_out = F.tanh(v_out[:, :self.dim])*F.sigmoid(v_out[:, self.dim:]) 40 | vh = self.vh_conv(v_out) 41 | 42 | h_out = self.h_conv(self.h_pad1(h_map)) 43 | if self.mask == 'A': 44 | h_out = self.h_pad2(h_out)[:, :, :, :-1] 45 | h_out = h_out + vh 46 | h_out = 
F.tanh(h_out[:, :self.dim])*F.sigmoid(h_out[:, self.dim:]) 47 | h_map_out = self.h_conv_res(h_out) 48 | if self.mask == 'B': 49 | h_map_out = h_map_out + self.h_res(h_map) 50 | return v_map_out, h_map_out 51 | 52 | class StackedGatedMaskedConv2d(nn.Module): 53 | def __init__(self, 54 | img_size = [1, 28, 28], layers = [64,64,64], 55 | kernel_size = [7,7,7], latent_dim=64, latent_feature_map = 1, skip = 0): 56 | super(StackedGatedMaskedConv2d, self).__init__() 57 | self.skip = skip 58 | input_dim = img_size[0] 59 | self.conv_layers = [] 60 | self.z_linears = nn.ModuleList() 61 | if latent_feature_map > 0: 62 | self.latent_feature_map = latent_feature_map 63 | if self.skip == 0: 64 | add_dim = 0 65 | else: 66 | add_dim = latent_feature_map 67 | for i in range(len(kernel_size)): 68 | self.z_linears.append(nn.Linear(latent_dim, latent_feature_map*28*28) ) 69 | if i == 0: 70 | self.conv_layers.append(GatedMaskedConv2d(input_dim + latent_feature_map, 71 | layers[i], kernel_size[i], 'A')) 72 | else: 73 | self.conv_layers.append(GatedMaskedConv2d(layers[i-1] + add_dim, 74 | layers[i], kernel_size[i])) 75 | 76 | self.modules = nn.ModuleList(self.conv_layers) 77 | 78 | def forward(self, img, q_z=None): 79 | # if q_z is not None: 80 | # z_img = self.z_linear(q_z) 81 | # z_img = z_img.view(img.size(0), self.latent_feature_map, img.size(2), img.size(3)) 82 | v_map = img 83 | h_map = img 84 | for i in range(len(self.conv_layers)): 85 | z_img_i = self.z_linears[i](q_z).view(img.size(0), self.latent_feature_map, 28, 28) 86 | if i == 0 or self.skip == 1: 87 | v_map = torch.cat([v_map, z_img_i], 1) 88 | h_map = torch.cat([h_map, z_img_i], 1) 89 | # if i == 0: 90 | # if q_z is not None: 91 | # v_map = torch.cat([img, z_img], 1) 92 | # else: 93 | # v_map = img 94 | # h_map = v_map 95 | # print(i, v_map.size(), h_map.size()) 96 | v_map, h_map = self.conv_layers[i](v_map, h_map) 97 | return h_map 98 | 99 | class ResidualBlock(nn.Module): 100 | def __init__(self, in_dim, out_dim=None, 
with_residual=True, with_batchnorm=True, mask=None, 101 | kernel_size = 3, padding = 1): 102 | if out_dim is None: 103 | out_dim = in_dim 104 | super(ResidualBlock, self).__init__() 105 | if mask is None: 106 | self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, padding=padding) 107 | self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=kernel_size, padding=padding) 108 | else: 109 | self.conv1 = MaskedConv2d(mask, in_dim, out_dim, kernel_size=kernel_size, padding=padding) 110 | self.conv2 = MaskedConv2d(mask, out_dim, out_dim, kernel_size=kernel_size, padding=padding) 111 | self.with_batchnorm = with_batchnorm 112 | if with_batchnorm: 113 | self.bn1 = nn.BatchNorm2d(out_dim) 114 | self.bn2 = nn.BatchNorm2d(out_dim) 115 | self.with_residual = with_residual 116 | if in_dim == out_dim or not with_residual: 117 | self.proj = None 118 | else: 119 | self.proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) 120 | 121 | def forward(self, x): 122 | if self.with_batchnorm: 123 | out = F.relu(self.bn1(self.conv1(x))) 124 | out = self.bn2(self.conv2(out)) 125 | else: 126 | out = self.conv2(F.relu(self.conv1(x))) 127 | res = x if self.proj is None else self.proj(x) 128 | if self.with_residual: 129 | out = F.relu(res + out) 130 | else: 131 | out = F.relu(out) 132 | return out 133 | 134 | class MaskedConv2d(nn.Conv2d): 135 | def __init__(self, include_center=False, *args, **kwargs): 136 | super(MaskedConv2d, self).__init__(*args, **kwargs) 137 | self.register_buffer('mask', self.weight.data.clone()) 138 | _, _, kH, kW = self.weight.size() 139 | self.mask.fill_(1) 140 | self.mask[:, :, kH // 2, kW // 2 + (include_center == True):] = 0 141 | self.mask[:, :, kH // 2 + 1:] = 0 142 | 143 | def forward(self, x): 144 | self.weight.data *= self.mask.cuda() 145 | return super(MaskedConv2d, self).forward(x) 146 | 147 | class CNNVAE(nn.Module): 148 | def __init__(self, 149 | img_size = [1,28,28], 150 | latent_dim = 32, 151 | enc_layers = [64,64,64], 152 | dec_kernel_size = 
[7,7,7], 153 | dec_layers= [64,64,64], 154 | latent_feature_map = 4, 155 | skip = 0): 156 | super(CNNVAE, self).__init__() 157 | self.skip = skip 158 | enc_modules = [] 159 | img_h = img_size[1] 160 | img_w = img_size[2] 161 | for i in range(len(enc_layers)): 162 | if i == 0: 163 | input_dim = img_size[0] 164 | else: 165 | input_dim = enc_layers[i-1] 166 | enc_modules.append(ResidualBlock(input_dim, enc_layers[i])) 167 | enc_modules.append(nn.Conv2d(enc_layers[i], enc_layers[i], kernel_size=2, stride=2)) 168 | 169 | img_h //= 2 170 | img_w //= 2 171 | latent_in_dim = img_h*img_w*enc_layers[-1] 172 | self.z_linear = nn.Linear(latent_dim, 28*28) 173 | 174 | self.enc_cnn = nn.Sequential(*enc_modules) 175 | self.latent_linear_mean = nn.Linear(latent_in_dim, latent_dim) 176 | self.latent_linear_logvar = nn.Linear(latent_in_dim, latent_dim) 177 | self.enc = nn.ModuleList([self.enc_cnn, self.latent_linear_mean, self.latent_linear_logvar]) 178 | self.dec_cnn = StackedGatedMaskedConv2d(img_size=img_size, layers = dec_layers, 179 | latent_dim= latent_dim, kernel_size = dec_kernel_size, 180 | latent_feature_map = latent_feature_map, 181 | skip = self.skip) 182 | if self.skip == 0: 183 | self.dec_linear = nn.Conv2d(dec_layers[-1], img_size[0], kernel_size = 1) 184 | else: 185 | self.dec_linear = nn.Conv2d(dec_layers[-1]+ latent_feature_map, img_size[0], kernel_size = 1) 186 | self.dec = nn.ModuleList([self.dec_cnn, self.dec_linear]) 187 | for m in self.modules(): 188 | if isinstance(m, nn.Linear): 189 | he_init(m) 190 | 191 | def _enc_forward(self, img): 192 | img_code = self.enc_cnn(img) 193 | img_code = img_code.view(img.size(0), -1) 194 | self.img_code = img_code 195 | mean = self.latent_linear_mean(img_code) 196 | logvar = self.latent_linear_logvar(img_code) 197 | return mean, logvar 198 | 199 | def _reparameterize(self, mean, logvar, z = None): 200 | self.std = logvar.mul(0.5).exp() 201 | if z is None: 202 | self.z = Variable(torch.FloatTensor(self.std.size()).normal_(0, 
1).type_as(mean.data)) 203 | else: 204 | self.z = z 205 | self.q_z = self.z*self.std + mean 206 | return self.q_z 207 | 208 | def _dec_forward(self, img, q_z): 209 | dec_cnn_output = self.dec_cnn(img, q_z) 210 | if self.skip == 1: 211 | z_linear_last = self.z_linear(q_z).view(img.size(0), 1, 28, 28) 212 | dec_cnn_output = torch.cat([dec_cnn_output, z_linear_last], 1) 213 | pred = F.sigmoid(self.dec_linear(dec_cnn_output)) 214 | return pred 215 | 216 | class MLPVAE(nn.Module): 217 | def __init__(self, 218 | img_size = [1,28,28], 219 | latent_dim = 32, 220 | enc_layers = [64,64,64], 221 | dec_kernel_size = [7,7,7], 222 | dec_layers= [64,64,64], 223 | latent_feature_map = 4, 224 | skip = 0): 225 | super(MLPVAE, self).__init__() 226 | self.skip = skip 227 | h = 1024 228 | self.enc_mlp = nn.Sequential(nn.Linear(28*28, h), nn.ReLU(), 229 | nn.Linear(h, h), nn.ReLU()) 230 | self.latent_linear_mean = nn.Linear(h, latent_dim) 231 | self.latent_linear_logvar = nn.Linear(h, latent_dim) 232 | self.enc = nn.ModuleList([self.enc_mlp, self.latent_linear_mean, self.latent_linear_logvar]) 233 | self.num_layers = 2 234 | self.dec_linears = nn.ModuleList([nn.Linear(h,h) for _ in range(self.num_layers)]) 235 | self.z_linears = nn.ModuleList([nn.Linear(latent_dim, h) for _ in range(self.num_layers)]) 236 | self.dec_init = nn.Linear(latent_dim, h) 237 | self.dec_last = nn.Linear(h, 28*28) 238 | 239 | def _enc_forward(self, img): 240 | img_code = self.enc_mlp(img.view(img.size(0), -1)) 241 | mean = self.latent_linear_mean(img_code) 242 | logvar = self.latent_linear_logvar(img_code) 243 | return mean, logvar 244 | 245 | def _reparameterize(self, mean, logvar, z = None): 246 | self.std = logvar.mul(0.5).exp() 247 | if z is None: 248 | self.z = Variable(torch.FloatTensor(self.std.size()).normal_(0, 1).type_as(mean.data)) 249 | else: 250 | self.z = z 251 | self.q_z = self.z*self.std + mean 252 | return self.q_z 253 | 254 | def _dec_forward(self, img, q_z): 255 | x = self.dec_init(q_z) 256 | 
for i in range(self.num_layers): 257 | x = self.dec_linears[i](x) 258 | if self.skip == 1: 259 | z = self.z_linears[i](q_z) 260 | x = x + z 261 | x = F.relu(x) 262 | pred = self.dec_last(x) 263 | pred = F.sigmoid(pred) 264 | return pred.view(img.size(0), 1, 28, 28) 265 | -------------------------------------------------------------------------------- /models_text_skip.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import numpy as np 9 | 10 | class RNNVAE(nn.Module): 11 | def __init__(self, vocab_size=10000, 12 | enc_word_dim = 512, 13 | enc_h_dim = 1024, 14 | enc_num_layers = 1, 15 | dec_word_dim = 512, 16 | dec_h_dim = 1024, 17 | dec_num_layers = 1, 18 | dec_dropout = 0.5, 19 | latent_dim=32, 20 | mode='savae', 21 | skip = 0): 22 | super(RNNVAE, self).__init__() 23 | self.skip = skip 24 | self.enc_h_dim = enc_h_dim 25 | self.enc_num_layers = enc_num_layers 26 | self.dec_h_dim =dec_h_dim 27 | self.dec_num_layers = dec_num_layers 28 | 29 | if mode == 'savae' or mode == 'vae': 30 | self.enc_word_vecs = nn.Embedding(vocab_size, enc_word_dim) 31 | self.latent_linear_mean = nn.Linear(enc_h_dim, latent_dim) 32 | self.latent_linear_logvar = nn.Linear(enc_h_dim, latent_dim) 33 | self.enc_rnn = nn.LSTM(enc_word_dim, enc_h_dim, num_layers = enc_num_layers, 34 | batch_first = True) 35 | self.enc = nn.ModuleList([self.enc_word_vecs, self.enc_rnn, 36 | self.latent_linear_mean, self.latent_linear_logvar]) 37 | elif mode == 'autoreg': 38 | latent_dim = 0 39 | 40 | self.dec_word_vecs = nn.Embedding(vocab_size, dec_word_dim) 41 | dec_input_size = dec_word_dim 42 | dec_input_size += latent_dim 43 | if self.skip == 0: 44 | self.dec_linear = nn.Linear(dec_h_dim, vocab_size) 45 | else: 46 | self.dec_linear = nn.Linear(dec_h_dim + latent_dim, vocab_size) 47 | 48 | if self.skip == 0 or 
self.dec_num_layers == 1: 49 | self.dec_rnn = nn.LSTM(dec_input_size, dec_h_dim, num_layers = dec_num_layers, 50 | batch_first = True) 51 | self.dec = nn.ModuleList([self.dec_word_vecs, self.dec_rnn, self.dec_linear]) 52 | else: 53 | self.dec_rnn = nn.ModuleList([nn.LSTM(dec_input_size, dec_h_dim, batch_first = True) 54 | if k == 0 else 55 | nn.LSTM(dec_h_dim + latent_dim, dec_h_dim, batch_first=True) 56 | for k in range(self.dec_num_layers)]) 57 | self.dec = nn.ModuleList([self.dec_word_vecs, self.dec_linear, *self.dec_rnn]) 58 | 59 | self.dropout = nn.Dropout(dec_dropout) 60 | 61 | if latent_dim > 0: 62 | self.latent_hidden_linear = nn.Linear(latent_dim, dec_h_dim) 63 | self.dec.append(self.latent_hidden_linear) 64 | 65 | def _enc_forward(self, sent): 66 | word_vecs = self.enc_word_vecs(sent) 67 | 68 | h0 = Variable(torch.zeros(self.enc_num_layers, word_vecs.size(0), 69 | self.enc_h_dim).type_as(word_vecs.data)) 70 | c0 = Variable(torch.zeros(self.enc_num_layers, word_vecs.size(0), 71 | self.enc_h_dim).type_as(word_vecs.data)) 72 | enc_h_states, _ = self.enc_rnn(word_vecs, (h0, c0)) 73 | enc_h_states_last = enc_h_states[:, -1] 74 | mean = self.latent_linear_mean(enc_h_states_last) 75 | logvar = self.latent_linear_logvar(enc_h_states_last) 76 | return mean, logvar 77 | 78 | def _reparameterize(self, mean, logvar, z = None): 79 | std = logvar.mul(0.5).exp() 80 | if z is None: 81 | z = Variable(torch.cuda.FloatTensor(std.size()).normal_(0, 1)) 82 | return z.mul(std) + mean 83 | 84 | def _dec_forward(self, sent, q_z, init_h = True): 85 | self.word_vecs = self.dropout(self.dec_word_vecs(sent[:, :-1])) 86 | if init_h: 87 | self.h0 = Variable(torch.zeros(self.dec_num_layers, self.word_vecs.size(0), 88 | self.dec_h_dim).type_as(self.word_vecs.data), requires_grad = False) 89 | self.c0 = Variable(torch.zeros(self.dec_num_layers, self.word_vecs.size(0), 90 | self.dec_h_dim).type_as(self.word_vecs.data), requires_grad = False) 91 | else: 92 | self.h0.data.zero_() 93 | 
self.c0.data.zero_() 94 | 95 | if q_z is not None: 96 | q_z_expand = q_z.unsqueeze(1).expand(self.word_vecs.size(0), 97 | self.word_vecs.size(1), q_z.size(1)) 98 | dec_input = torch.cat([self.word_vecs, q_z_expand], 2) 99 | else: 100 | dec_input = self.word_vecs 101 | if q_z is not None: 102 | self.h0[-1] = self.latent_hidden_linear(q_z) 103 | if self.skip == 0 or self.dec_num_layers == 1: 104 | memory, _ = self.dec_rnn(dec_input, (self.h0, self.c0)) 105 | else: 106 | for k in range(self.dec_num_layers): 107 | memory, _ = self.dec_rnn[k](dec_input, (self.h0[k].unsqueeze(0), self.c0[k].unsqueeze(0))) 108 | dec_input = torch.cat([memory, q_z_expand], 2) 109 | 110 | memory = self.dropout(memory) 111 | if self.skip == 1: 112 | dec_linear_input = torch.cat([memory.contiguous(), q_z_expand], 2) 113 | else: 114 | dec_linear_input = memory.contiguous() 115 | preds = self.dec_linear(dec_linear_input) 116 | preds = F.log_softmax(preds, 2) 117 | return preds 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /train_img_skip.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import sys 3 | import os 4 | 5 | import argparse 6 | import json 7 | import random 8 | import shutil 9 | import copy 10 | 11 | import torch 12 | from torch import cuda 13 | import torch.nn as nn 14 | from torch.autograd import Variable 15 | from torch.nn.parameter import Parameter 16 | 17 | import torch.nn.functional as F 18 | import numpy as np 19 | import h5py 20 | import time 21 | import logging 22 | from models_img_skip import CNNVAE, MLPVAE 23 | from optim_n2n import OptimN2N 24 | import utils 25 | import torch.utils.data 26 | 27 | parser = argparse.ArgumentParser() 28 | 29 | # Input data 30 | parser.add_argument('--data_file', default='data/mnist/static_data.pt') 31 | parser.add_argument('--train_from', default='') 32 | parser.add_argument('--checkpoint_path', default='baseline.pt') 
33 | 34 | # Model options 35 | parser.add_argument('--img_size', default=[1,28,28]) 36 | parser.add_argument('--latent_dim', default=20, type=int) 37 | parser.add_argument('--enc_layers', default=[64,64,64]) 38 | parser.add_argument('--dec_kernel_size', default=[3,3,3,3,3,3,3,3], type=int) 39 | parser.add_argument('--dec_layers', default=[32,32,32,32,32,32,32,32]) 40 | parser.add_argument('--latent_feature_map', default=1, type=int) 41 | parser.add_argument('--model', default='vae', type=str, choices = ['vae', 'autoreg', 'savae', 'svi']) 42 | parser.add_argument('--train_kl', default=1, type=int) 43 | parser.add_argument('--train_n2n', default=1, type=int) 44 | 45 | # Optimization options 46 | parser.add_argument('--skip', default=0, type=int) 47 | parser.add_argument('--num_epochs', default=100, type=int) 48 | parser.add_argument('--svi_steps', default=20, type=int) 49 | parser.add_argument('--svi_lr1', default=1, type=float) 50 | parser.add_argument('--svi_lr2', default=1, type=float) 51 | parser.add_argument('--eps', default=1e-5, type=float) 52 | parser.add_argument('--momentum', default=0.5, type=float) 53 | parser.add_argument('--warmup', default=0, type=int) 54 | parser.add_argument('--lr', default=1e-3, type=float) 55 | parser.add_argument('--max_grad_norm', default=5, type=float) 56 | parser.add_argument('--svi_max_grad_norm', default=5, type=float) 57 | parser.add_argument('--gpu', default=2, type=int) 58 | parser.add_argument('--slurm', default=0, type=int) 59 | parser.add_argument('--batch_size', default=50, type=int) 60 | parser.add_argument('--seed', default=354, type=int) 61 | parser.add_argument('--print_every', type=int, default=500) 62 | parser.add_argument('--test', type=int, default=0) 63 | 64 | def main(args): 65 | np.random.seed(args.seed) 66 | torch.manual_seed(args.seed) 67 | all_data = torch.load(args.data_file) 68 | x_train, x_val, x_test = all_data 69 | y_size = 1 70 | y_train = torch.zeros(x_train.size(0), y_size) 71 | y_val = 
torch.zeros(x_val.size(0), y_size) 72 | y_test = torch.zeros(x_test.size(0), y_size) 73 | train = torch.utils.data.TensorDataset(x_train, y_train) 74 | val = torch.utils.data.TensorDataset(x_val, y_val) 75 | test = torch.utils.data.TensorDataset(x_test, y_test) 76 | 77 | train_loader = torch.utils.data.DataLoader(train, batch_size=args.batch_size, shuffle=True) 78 | val_loader = torch.utils.data.DataLoader(val, batch_size=args.batch_size, shuffle=True) 79 | test_loader = torch.utils.data.DataLoader(test, batch_size=args.batch_size, shuffle=True) 80 | print('Train data: %d batches' % len(train_loader)) 81 | print('Val data: %d batches' % len(val_loader)) 82 | print('Test data: %d batches' % len(test_loader)) 83 | if args.slurm == 0: 84 | cuda.set_device(args.gpu) 85 | if args.model == 'autoreg': 86 | args.latent_feature_map = 0 87 | if args.train_from == '': 88 | model = CNNVAE(img_size = args.img_size, 89 | latent_dim = args.latent_dim, 90 | enc_layers = args.enc_layers, 91 | dec_kernel_size = args.dec_kernel_size, 92 | dec_layers = args.dec_layers, 93 | latent_feature_map = args.latent_feature_map, 94 | skip = args.skip) 95 | else: 96 | print('loading model from ' + args.train_from) 97 | checkpoint = torch.load(args.train_from) 98 | model = checkpoint['model'] 99 | print("model architecture") 100 | print(model) 101 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999)) 102 | 103 | model.cuda() 104 | model.train() 105 | 106 | def variational_loss(input, img, model, z = None): 107 | mean, logvar = input 108 | z_samples = model._reparameterize(mean, logvar, z) 109 | preds = model._dec_forward(img, z_samples) 110 | nll = utils.log_bernoulli_loss(preds, img) 111 | kl = utils.kl_loss_diag(mean, logvar) 112 | return nll + args.beta*kl 113 | 114 | update_params = list(model.dec.parameters()) 115 | meta_optimizer = OptimN2N(variational_loss, model, update_params, eps = args.eps, 116 | lr = [args.svi_lr1, args.svi_lr2], 117 | iters = 
args.svi_steps, momentum = args.momentum, 118 | acc_param_grads= args.train_n2n == 1, 119 | max_grad_norm = args.svi_max_grad_norm) 120 | epoch = 0 121 | t = 0 122 | best_val_nll = 1e5 123 | best_epoch = 0 124 | loss_stats = [] 125 | if args.warmup == 0: 126 | args.beta = 1. 127 | else: 128 | args.beta = 0.1 129 | 130 | if args.test == 1: 131 | args.beta = 1 132 | agg_kl = get_agg_kl(test_loader, test_loader, model) 133 | eval(test_loader, model, meta_optimizer, agg_kl) 134 | exit() 135 | 136 | while epoch < args.num_epochs: 137 | start_time = time.time() 138 | epoch += 1 139 | print('Starting epoch %d' % epoch) 140 | train_nll_vae = 0. 141 | train_nll_autoreg = 0. 142 | train_kl_vae = 0. 143 | train_nll_svi = 0. 144 | train_kl_svi = 0. 145 | num_examples = 0 146 | for b, datum in enumerate(train_loader): 147 | if args.warmup > 0: 148 | args.beta = min(1, args.beta + 1./(args.warmup*len(train_loader))) 149 | img, _ = datum 150 | img = torch.bernoulli(img) 151 | batch_size = img.size(0) 152 | img = Variable(img.cuda()) 153 | t += 1 154 | optimizer.zero_grad() 155 | if args.model == 'autoreg': 156 | preds = model._dec_forward(img, None) 157 | nll_autoreg = utils.log_bernoulli_loss(preds, img) 158 | train_nll_autoreg += nll_autoreg.data[0]*batch_size 159 | nll_autoreg.backward() 160 | elif args.model == 'svi': 161 | mean_svi = Variable(0.1*torch.zeros(batch_size, args.latent_dim).cuda(), requires_grad = True) 162 | logvar_svi = Variable(0.1*torch.zeros(batch_size, args.latent_dim).cuda(), requires_grad = True) 163 | var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], img, 164 | t % args.print_every == 0) 165 | mean_svi_final, logvar_svi_final = var_params_svi 166 | z_samples = model._reparameterize(mean_svi_final.detach(), logvar_svi_final.detach()) 167 | preds = model._dec_forward(img, z_samples) 168 | nll_svi = utils.log_bernoulli_loss(preds, img) 169 | train_nll_svi += nll_svi.data[0]*batch_size 170 | kl_svi = utils.kl_loss_diag(mean_svi_final, 
logvar_svi_final) 171 | train_kl_svi += kl_svi.data[0]*batch_size 172 | var_loss = nll_svi + args.beta*kl_svi 173 | var_loss.backward() 174 | else: 175 | mean, logvar = model._enc_forward(img) 176 | z_samples = model._reparameterize(mean, logvar) 177 | preds = model._dec_forward(img, z_samples) 178 | nll_vae = utils.log_bernoulli_loss(preds, img) 179 | train_nll_vae += nll_vae.data[0]*batch_size 180 | kl_vae = utils.kl_loss_diag(mean, logvar) 181 | train_kl_vae += kl_vae.data[0]*batch_size 182 | if args.model == 'vae': 183 | vae_loss = nll_vae + args.beta*kl_vae 184 | vae_loss.backward(retain_graph = True) 185 | 186 | if args.model == 'savae': 187 | var_params = torch.cat([mean, logvar], 1) 188 | mean_svi = Variable(mean.data, requires_grad = True) 189 | logvar_svi = Variable(logvar.data, requires_grad = True) 190 | 191 | var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], img, 192 | t % args.print_every == 0) 193 | mean_svi_final, logvar_svi_final = var_params_svi 194 | z_samples = model._reparameterize(mean_svi_final, logvar_svi_final) 195 | preds = model._dec_forward(img, z_samples) 196 | nll_svi = utils.log_bernoulli_loss(preds, img) 197 | train_nll_svi += nll_svi.data[0]*batch_size 198 | kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final) 199 | train_kl_svi += kl_svi.data[0]*batch_size 200 | var_loss = nll_svi + args.beta*kl_svi 201 | var_loss.backward(retain_graph = True) 202 | if args.train_n2n == 0: 203 | if args.train_kl == 1: 204 | mean_final = mean_svi_final.detach() 205 | logvar_final = logvar_svi_final.detach() 206 | kl_init_final = utils.kl_loss(mean, logvar, mean_final, logvar_final) 207 | kl_init_final.backward(retain_graph = True) 208 | else: 209 | vae_loss = nll_vae + args.beta*kl_vae 210 | var_param_grads = torch.autograd.grad(vae_loss, [mean, logvar], retain_graph=True) 211 | var_param_grads = torch.cat(var_param_grads, 1) 212 | var_params.backward(var_param_grads, retain_graph=True) 213 | else: 214 | var_param_grads = 
meta_optimizer.backward([mean_svi_final.grad, logvar_svi_final.grad], 215 | t % args.print_every == 0) 216 | var_param_grads = torch.cat(var_param_grads, 1) 217 | var_params.backward(var_param_grads) 218 | if args.max_grad_norm > 0: 219 | torch.nn.utils.clip_grad_norm(model.parameters(), args.max_grad_norm) 220 | optimizer.step() 221 | num_examples += batch_size 222 | if t % args.print_every == 0: 223 | param_norm = sum([p.norm()**2 for p in model.parameters()]).data[0]**0.5 224 | print('Iters: %d, Epoch: %d, Batch: %d/%d, LR: %.4f, TrainARNLL: %.2f, TrainVAE_NLL: %.2f, TrainVAE_KL: %.4f, TrainVAE_NLLBnd: %.2f, TrainSVI_NLL: %.2f, TrainSVI_KL: %.4f, TrainSVI_NLLBnd: %.2f, |Param|: %.4f, BestValPerf: %.2f, BestEpoch: %d, Beta: %.3f, Throughput: %.2f examples/sec' % 225 | (t, epoch, b+1, len(train_loader), args.lr, train_nll_autoreg / num_examples, 226 | train_nll_vae/num_examples, train_kl_vae / num_examples, 227 | (train_nll_vae + train_kl_vae)/num_examples, 228 | train_nll_svi/num_examples, train_kl_svi/ num_examples, 229 | (train_nll_svi + train_kl_svi)/num_examples, 230 | param_norm, best_val_nll, best_epoch, args.beta, 231 | num_examples / (time.time() - start_time))) 232 | print('--------------------------------') 233 | print('Checking validation perf...') 234 | val_nll = eval(val_loader, model, meta_optimizer) 235 | loss_stats.append(val_nll) 236 | if val_nll < best_val_nll: 237 | best_val_nll = val_nll 238 | best_epoch = epoch 239 | checkpoint = { 240 | 'args': args.__dict__, 241 | 'model': model, 242 | 'optimizer': optimizer, 243 | 'loss_stats': loss_stats 244 | } 245 | print('Savaeng checkpoint to %s' % args.checkpoint_path) 246 | torch.save(checkpoint, args.checkpoint_path) 247 | 248 | def get_agg_kl(q_data, test_data, model): 249 | model.eval() 250 | means = [] 251 | logvars = [] 252 | all_z = [] 253 | for datum in q_data: 254 | img, _ = datum 255 | batch_size = img.size(0) 256 | img = Variable(img.cuda()) 257 | mean, logvar = model._enc_forward(img) 258 
| z_samples = model._reparameterize(mean, logvar) 259 | means.append(mean.data) 260 | logvars.append(logvar.data) 261 | all_z.append(z_samples.data) 262 | means = torch.cat(means, 0) 263 | logvars = torch.cat(logvars, 0) 264 | N = float(means.size(0)) 265 | mean_prior = torch.zeros(1, means.size(1)).cuda() 266 | logvar_prior = torch.zeros(1, means.size(1)).cuda() 267 | agg_kl = 0. 268 | count = 0. 269 | for datum in test_data: 270 | img, _ = datum 271 | batch_size = img.size(0) 272 | img = Variable(img.cuda()) 273 | mean, logvar = model._enc_forward(img) 274 | z_samples = model._reparameterize(mean, logvar).data 275 | for i in range(z_samples.size(0)): 276 | z_i = z_samples[i].unsqueeze(0).expand_as(means) 277 | log_agg_density = utils.log_gaussian(z_i, means, logvars) # log q(z|x) for all x 278 | log_q = utils.logsumexp(log_agg_density, 0) 279 | log_q = -np.log(N) + log_q 280 | log_p = utils.log_gaussian(z_samples[i].unsqueeze(0), mean_prior, logvar_prior) 281 | agg_kl += log_q.sum()- log_p.sum() 282 | count += 1 283 | mean_var = means.var(0) 284 | print('active units', (mean_var > 0.02).float().sum()) 285 | print(mean_var) 286 | return agg_kl / count 287 | 288 | def eval(data, model, meta_optimizer, agg_kl = 0): 289 | model.eval() 290 | num_examples = 0 291 | total_nll_autoreg = 0. 292 | total_nll_vae = 0. 293 | total_kl_vae = 0. 294 | total_nll_svi = 0. 295 | total_kl_svi = 0. 
296 | total_kl_dim = 0 297 | for datum in data: 298 | img, _ = datum 299 | batch_size = img.size(0) 300 | img = Variable(img.cuda()) 301 | if args.model == 'autoreg': 302 | preds = model._dec_forward(img, None) 303 | nll_autoreg = utils.log_bernoulli_loss(preds, img) 304 | total_nll_autoreg += nll_autoreg.data[0]*batch_size 305 | elif args.model == 'svi': 306 | mean_svi = Variable(0.1*torch.zeros(batch_size, args.latent_dim).cuda(), requires_grad = True) 307 | logvar_svi = Variable(0.1*torch.zeros(batch_size, args.latent_dim).cuda(), requires_grad = True) 308 | var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], img) 309 | mean_svi_final, logvar_svi_final = var_params_svi 310 | z_samples = model._reparameterize(mean_svi_final.detach(), logvar_svi_final.detach()) 311 | preds = model._dec_forward(img, z_samples) 312 | nll_svi = utils.log_bernoulli_loss(preds, img) 313 | total_nll_svi += nll_svi.data[0]*batch_size 314 | kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final) 315 | total_kl_svi += kl_svi.data[0]*batch_size 316 | else: 317 | mean, logvar = model._enc_forward(img) 318 | z_samples = model._reparameterize(mean, logvar) 319 | preds = model._dec_forward(img, z_samples) 320 | nll_vae = utils.log_bernoulli_loss(preds, img) 321 | total_nll_vae += nll_vae.data[0]*batch_size 322 | kl_vae = utils.kl_loss_diag(mean, logvar) 323 | total_kl_vae += kl_vae.data[0]*batch_size 324 | kl_dim = utils.kl_loss_dim(mean, logvar) 325 | total_kl_dim += kl_dim.sum(0).data 326 | if args.model == 'savae': 327 | mean_svi = Variable(mean.data, requires_grad = True) 328 | logvar_svi = Variable(logvar.data, requires_grad = True) 329 | var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], img) 330 | mean_svi_final, logvar_svi_final = var_params_svi 331 | z_samples = model._reparameterize(mean_svi_final, logvar_svi_final) 332 | preds = model._dec_forward(img, z_samples.detach()) 333 | nll_svi = utils.log_bernoulli_loss(preds, img) 334 | total_nll_svi += 
nll_svi.data[0]*batch_size 335 | kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final) 336 | total_kl_svi += kl_svi.data[0]*batch_size 337 | mean, logvar = mean_svi_final, logvar_svi_final 338 | num_examples += batch_size 339 | 340 | nll_autoreg = total_nll_autoreg / num_examples 341 | nll_vae = total_nll_vae/ num_examples 342 | kl_vae = total_kl_vae / num_examples 343 | nll_bound_vae = (total_nll_vae + total_kl_vae)/num_examples 344 | nll_svi = total_nll_svi/num_examples 345 | kl_svi = total_kl_svi/num_examples 346 | nll_bound_svi = (total_nll_svi + total_kl_svi)/num_examples 347 | kl_dim = total_kl_dim / num_examples 348 | print([ '%.4f' % e for e in list(kl_dim)]) 349 | print('') 350 | print('NEG ELBO: %.4f, KL: %.4f, AGG KL: %.4f, MI: %.4f ' % 351 | (nll_bound_vae, kl_vae, agg_kl, kl_vae - agg_kl)) 352 | print('') 353 | print('AR NLL: %.4f, VAE NLL: %.4f, VAE KL: %.4f, VAE NLL BOUND: %.4f, SVI PPL: %.4f, SVI KL: %.4f, SVI NLL BOUND: %.4f' % 354 | (nll_autoreg, nll_vae, kl_vae, nll_bound_vae, nll_svi, kl_svi, nll_bound_svi)) 355 | model.train() 356 | if args.model == 'autoreg': 357 | return nll_autoreg 358 | elif args.model == 'vae': 359 | return nll_bound_vae 360 | elif args.model == 'savae' or args.model == 'svi': 361 | return nll_bound_svi 362 | 363 | 364 | if __name__ == '__main__': 365 | args = parser.parse_args() 366 | main(args) 367 | -------------------------------------------------------------------------------- /train_text_skip.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | import os 5 | 6 | import argparse 7 | import json 8 | import random 9 | import shutil 10 | import copy 11 | 12 | import torch 13 | from torch import cuda 14 | import torch.nn as nn 15 | from torch.autograd import Variable 16 | from torch.nn.parameter import Parameter 17 | 18 | import torch.nn.functional as F 19 | import numpy as np 20 | import h5py 21 | import time 22 | from optim_n2n 
# NOTE(review): the previous physical line ends with "from optim_n2n" — this
# continues it, i.e. the original statement is `from optim_n2n import OptimN2N`
# (the amortized-inference optimizer from the sa-vae repo).
import OptimN2N
from data import Dataset                  # HDF5 batch iterator (sa-vae repo)
from models_text_skip import RNNVAE       # Skip-VAE text model
import utils                              # KL / log-density helpers (sa-vae repo)

# Command-line options for training the Skip-VAE text model.
parser = argparse.ArgumentParser()

# Input data: *.hdf5 files produced by preprocess_text.py (see README).
parser.add_argument('--train_file', default='../savi/data/yahoo/yahoo-train.hdf5')
parser.add_argument('--val_file', default='../savi/data/yahoo/yahoo-val.hdf5')
parser.add_argument('--test_file', default='../savi/data/yahoo/yahoo-test.hdf5')
parser.add_argument('--train_from', default='')  # checkpoint to resume from; '' = train from scratch

# Model options
parser.add_argument('--latent_dim', default=32, type=int)      # dimensionality of z
parser.add_argument('--enc_word_dim', default=512, type=int)   # encoder word embedding size
parser.add_argument('--enc_h_dim', default=1024, type=int)     # encoder RNN hidden size
parser.add_argument('--enc_num_layers', default=1, type=int)
parser.add_argument('--dec_word_dim', default=512, type=int)   # decoder word embedding size
parser.add_argument('--dec_h_dim', default=1024, type=int)     # decoder RNN hidden size
parser.add_argument('--skip', default=1, type=int)             # 1 = use generative skip connections (the paper's model)
parser.add_argument('--dec_num_layers', default=1, type=int)
parser.add_argument('--dec_dropout', default=0.5, type=float)
parser.add_argument('--model', default='vae', type=str, choices = ['vae', 'autoreg', 'savae', 'svi'])
parser.add_argument('--train_n2n', default=1, type=int)        # 1 = backprop through the SVI steps (savae only)
parser.add_argument('--train_kl', default=1, type=int)         # used when train_n2n == 0; see training loop

# Optimization options
parser.add_argument('--checkpoint_path', default='baseline.pt')
parser.add_argument('--slurm', default=0, type=int)            # 1 = skip cuda.set_device (device picked by scheduler)
parser.add_argument('--warmup', default=10, type=int)          # KL-annealing epochs; 0 disables warmup
parser.add_argument('--num_epochs', default=30, type=int)
parser.add_argument('--min_epochs', default=15, type=int)      # no LR decay before this epoch
parser.add_argument('--start_epoch', default=0, type=int)
parser.add_argument('--svi_steps', default=20, type=int)       # inner SVI refinement iterations
parser.add_argument('--svi_lr1', default=1, type=float)
parser.add_argument('--svi_lr2', default=1, type=float)
parser.add_argument('--eps', default=1e-5, type=float)
parser.add_argument('--decay', default=0, type=int)            # 1 = halve LR after each non-improving epoch
parser.add_argument('--momentum', default=0.5, type=float)     # momentum for the inner SVI updates
parser.add_argument('--lr', default=1, type=float)             # outer SGD learning rate
parser.add_argument('--max_grad_norm', default=5, type=float)
parser.add_argument('--svi_max_grad_norm', default=5, type=float)
parser.add_argument('--gpu', default=2, type=int)
parser.add_argument('--seed', default=3435, type=int)
parser.add_argument('--print_every', type=int, default=100)
parser.add_argument('--test', type=int, default=0)             # 1 = evaluate on the test set and exit

def main(args):
  """Train (or, with --test 1, evaluate) the Skip-VAE text model.

  Builds the datasets and the RNNVAE, then runs one of four regimes chosen
  by --model: 'autoreg' (decoder only), 'vae' (amortized inference),
  'svi' (stochastic variational inference from scratch per batch), or
  'savae' (amortized init refined by SVI steps via OptimN2N).
  The best model by validation NLL bound is checkpointed to
  args.checkpoint_path.
  """
  np.random.seed(args.seed)
  torch.manual_seed(args.seed)
  train_data = Dataset(args.train_file)
  val_data = Dataset(args.val_file)
  train_sents = train_data.batch_size.sum()
  vocab_size = int(train_data.vocab_size)
  print('Train data: %d batches' % len(train_data))
  print('Val data: %d batches' % len(val_data))
  print('Word vocab size: %d' % vocab_size)
  if args.slurm == 0:
    cuda.set_device(args.gpu)
  if args.train_from == '':
    model = RNNVAE(vocab_size = vocab_size,
                   enc_word_dim = args.enc_word_dim,
                   enc_h_dim = args.enc_h_dim,
                   enc_num_layers = args.enc_num_layers,
                   dec_word_dim = args.dec_word_dim,
                   dec_h_dim = args.dec_h_dim,
                   dec_num_layers = args.dec_num_layers,
                   dec_dropout = args.dec_dropout,
                   latent_dim = args.latent_dim,
                   mode = args.model,
                   skip = args.skip)
    for param in model.parameters():
      param.data.uniform_(-0.1, 0.1)
  else:
    print('loading model from ' + args.train_from)
    checkpoint = torch.load(args.train_from)
    model = checkpoint['model']

  print("model architecture")
  print(model)

  optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

  # KL annealing: beta ramps linearly from 0.1 to 1 over args.warmup epochs.
  if args.warmup == 0:
    args.beta = 1.
  else:
    args.beta = 0.1

  criterion = nn.NLLLoss()
  model.cuda()
  criterion.cuda()
  model.train()

  def variational_loss(input, sents, model, z = None):
    # Negative ELBO for one batch, given variational parameters (mean, logvar).
    # Used as the inner objective that OptimN2N refines with SVI steps.
    mean, logvar = input
    z_samples = model._reparameterize(mean, logvar, z)
    preds = model._dec_forward(sents, z_samples)
    nll = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(preds.size(1))])
    kl = utils.kl_loss_diag(mean, logvar)
    return nll + args.beta*kl

  update_params = list(model.dec.parameters())
  meta_optimizer = OptimN2N(variational_loss, model, update_params, eps = args.eps,
                            lr = [args.svi_lr1, args.svi_lr2],
                            iters = args.svi_steps, momentum = args.momentum,
                            acc_param_grads= args.train_n2n == 1,
                            max_grad_norm = args.svi_max_grad_norm)
  if args.test == 1:
    args.beta = 1
    test_data = Dataset(args.test_file)
    agg_kl = get_agg_kl(test_data, model, meta_optimizer)
    print('agg KL', agg_kl)
    eval(test_data, model, meta_optimizer, agg_kl)
    exit()

  t = 0  # NOTE(review): never incremented, so 'Iters' in the log is always 0
  best_val_nll = 1e5
  best_epoch = 0
  val_stats = []
  epoch = 0
  while epoch < args.num_epochs:
    start_time = time.time()
    epoch += 1
    print('Starting epoch %d' % epoch)
    train_nll_vae = 0.
    train_nll_autoreg = 0.
    train_kl_vae = 0.
    train_nll_svi = 0.
    train_kl_svi = 0.
    train_kl_init_final = 0.
    num_sents = 0
    num_words = 0
    b = 0

    for i in np.random.permutation(len(train_data)):
      if args.warmup > 0:
        args.beta = min(1, args.beta + 1./(args.warmup*len(train_data)))

      sents, length, batch_size = train_data[i]
      if args.gpu >= 0:
        sents = sents.cuda()
      b += 1

      optimizer.zero_grad()
      if args.model == 'autoreg':
        # Decoder-only baseline: no latent variable.
        preds = model._dec_forward(sents, None, True)
        nll_autoreg = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
        train_nll_autoreg += nll_autoreg.data[0]*batch_size
        nll_autoreg.backward()
      elif args.model == 'svi':
        # SVI from scratch: variational params start at zero each batch.
        # NOTE(review): 0.1*zeros is still zeros; eval() initializes with
        # 0.1*randn — presumably intended here too. Kept as-is; TODO confirm.
        mean_svi = Variable(0.1*torch.zeros(batch_size, args.latent_dim).cuda(), requires_grad = True)
        logvar_svi = Variable(0.1*torch.zeros(batch_size, args.latent_dim).cuda(), requires_grad = True)
        var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], sents,
                                                b % args.print_every == 0)
        mean_svi_final, logvar_svi_final = var_params_svi
        z_samples = model._reparameterize(mean_svi_final.detach(), logvar_svi_final.detach())
        preds = model._dec_forward(sents, z_samples)
        nll_svi = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
        train_nll_svi += nll_svi.data[0]*batch_size
        kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
        train_kl_svi += kl_svi.data[0]*batch_size
        var_loss = nll_svi + args.beta*kl_svi
        var_loss.backward(retain_graph = True)
      else:
        # 'vae' and 'savae' both start from the amortized encoder.
        mean, logvar = model._enc_forward(sents)
        z_samples = model._reparameterize(mean, logvar)
        preds = model._dec_forward(sents, z_samples)
        nll_vae = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
        train_nll_vae += nll_vae.data[0]*batch_size
        kl_vae = utils.kl_loss_diag(mean, logvar)
        train_kl_vae += kl_vae.data[0]*batch_size
        if args.model == 'vae':
          vae_loss = nll_vae + args.beta*kl_vae
          vae_loss.backward(retain_graph = True)
        if args.model == 'savae':
          # Refine the amortized (mean, logvar) with SVI steps, then
          # propagate gradients back into the encoder through var_params.
          var_params = torch.cat([mean, logvar], 1)
          mean_svi = Variable(mean.data, requires_grad = True)
          logvar_svi = Variable(logvar.data, requires_grad = True)
          var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], sents,
                                                  b % args.print_every == 0)
          mean_svi_final, logvar_svi_final = var_params_svi
          z_samples = model._reparameterize(mean_svi_final, logvar_svi_final)
          preds = model._dec_forward(sents, z_samples)
          nll_svi = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
          train_nll_svi += nll_svi.data[0]*batch_size
          kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
          train_kl_svi += kl_svi.data[0]*batch_size
          var_loss = nll_svi + args.beta*kl_svi
          var_loss.backward(retain_graph = True)
          if args.train_n2n == 0:
            if args.train_kl == 1:
              # Train the encoder to match the refined posterior via KL.
              mean_final = mean_svi_final.detach()
              logvar_final = logvar_svi_final.detach()
              kl_init_final = utils.kl_loss(mean, logvar, mean_final, logvar_final)
              train_kl_init_final += kl_init_final.data[0]*batch_size
              kl_init_final.backward(retain_graph = True)
            else:
              # Train the encoder on the plain VAE objective instead.
              vae_loss = nll_vae + args.beta*kl_vae
              var_param_grads = torch.autograd.grad(vae_loss, [mean, logvar], retain_graph=True)
              var_param_grads = torch.cat(var_param_grads, 1)
              var_params.backward(var_param_grads, retain_graph=True)
          else:
            # Backprop through the SVI steps themselves (the n2n path).
            var_param_grads = meta_optimizer.backward([mean_svi_final.grad, logvar_svi_final.grad],
                                                      b % args.print_every == 0)
            var_param_grads = torch.cat(var_param_grads, 1)
            var_params.backward(var_param_grads)
      if args.max_grad_norm > 0:
        torch.nn.utils.clip_grad_norm(model.parameters(), args.max_grad_norm)
      optimizer.step()
      num_sents += batch_size
      num_words += batch_size * length

      if b % args.print_every == 0:
        param_norm = sum([p.norm()**2 for p in model.parameters()]).data[0]**0.5
        print('Iters: %d, Epoch: %d, Batch: %d/%d, LR: %.4f, TrainARPPL: %.2f, TrainVAE_PPL: %.2f, TrainVAE_KL: %.4f, TrainVAE_PPLBnd: %.2f, TrainSVI_PPL: %.2f, TrainSVI_KL: %.4f, TrainSVI_PPLBnd: %.2f, KLInitFinal: %.2f, |Param|: %.4f, BestValPerf: %.2f, BestEpoch: %d, Beta: %.4f, Throughput: %.2f examples/sec' %
              (t, epoch, b+1, len(train_data), args.lr, np.exp(train_nll_autoreg / num_words),
               np.exp(train_nll_vae/num_words), train_kl_vae / num_sents,
               np.exp((train_nll_vae + train_kl_vae)/num_words),
               np.exp(train_nll_svi/num_words), train_kl_svi/ num_sents,
               np.exp((train_nll_svi + train_kl_svi)/num_words), train_kl_init_final / num_sents,
               param_norm, best_val_nll, best_epoch, args.beta,
               num_sents / (time.time() - start_time)))

    print('--------------------------------')
    print('Checking validation perf...')
    val_nll = eval(val_data, model, meta_optimizer)
    val_stats.append(val_nll)
    if val_nll < best_val_nll:
      best_val_nll = val_nll
      best_epoch = epoch
      model.cpu()
      checkpoint = {
        'args': args.__dict__,
        'model': model,
        'val_stats': val_stats
      }
      # Fixed typo in the log message ('Savaeng' -> 'Saving').
      print('Saving checkpoint to %s' % args.checkpoint_path)
      torch.save(checkpoint, args.checkpoint_path)
      model.cuda()
    else:
      if epoch >= args.min_epochs:
        args.decay = 1
    if args.decay == 1:
      args.lr = args.lr*0.5
      for param_group in optimizer.param_groups:
        param_group['lr'] = args.lr
      if args.lr < 0.03:
        break

def get_agg_kl(data, model, meta_optimizer):
  """Estimate KL(q_agg(z) || p(z)) between the aggregate posterior and the prior.

  First collects the per-example variational parameters (refined by SVI when
  --model savae) and one z sample per example; then, for each sample,
  estimates log q_agg(z) = log (1/N) sum_x q(z|x) via logsumexp over all
  collected posteriors and compares it to the N(0, I) prior density.
  Also prints the number of "active" latent units (posterior-mean variance
  across the last batch above 0.02). Reads the module-level `args`.
  """
  model.eval()
  criterion = nn.NLLLoss().cuda()
  means = []
  logvars = []
  all_z = []
  for i in range(len(data)):
    sents, length, batch_size = data[i]
    if args.gpu >= 0:
      sents = sents.cuda()
    mean, logvar = model._enc_forward(sents)
    z_samples = model._reparameterize(mean, logvar)
    if args.model == 'savae':
      mean_svi = Variable(mean.data, requires_grad = True)
      logvar_svi = Variable(logvar.data, requires_grad = True)
      var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], sents)
      mean_svi_final, logvar_svi_final = var_params_svi
      z_samples = model._reparameterize(mean_svi_final, logvar_svi_final)
      preds = model._dec_forward(sents, z_samples)
      nll_svi = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
      kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
      mean, logvar = mean_svi_final, logvar_svi_final
    means.append(mean.data)
    logvars.append(logvar.data)
    all_z.append(z_samples.data)
  means = torch.cat(means, 0)
  logvars = torch.cat(logvars, 0)
  all_z = torch.cat(all_z, 0)
  N = float(means.size(0))
  mean_prior = torch.zeros(1, means.size(1)).cuda()
  logvar_prior = torch.zeros(1, means.size(1)).cuda()
  agg_kl = 0.
  count = 0.
  for i in range(all_z.size(0)):
    z_i = all_z[i].unsqueeze(0).expand_as(means)
    log_agg_density = utils.log_gaussian(z_i, means, logvars) # log q(z|x) for all x
    log_q = utils.logsumexp(log_agg_density, 0)
    log_q = -np.log(N) + log_q
    log_p = utils.log_gaussian(all_z[i].unsqueeze(0), mean_prior, logvar_prior)
    agg_kl += log_q.sum()- log_p.sum()
    count += 1
  # mean here is the last batch's (possibly SVI-refined) posterior mean.
  mean_var = mean.var(0)
  print('active units', (mean_var > 0.02).float().sum())
  print(mean_var)

  return agg_kl / count

def eval(data, model, meta_optimizer, agg_kl = 0):
  """Evaluate on `data` and return the metric matching args.model.

  Returns AR perplexity for 'autoreg', the VAE PPL bound for 'vae', and the
  SVI PPL bound for 'savae'/'svi'. Prints per-dimension KL, ELBOs, and a
  summary line. `agg_kl` is accepted for interface parity with the image
  script but is not used here. Reads the module-level `args`.
  NOTE(review): this intentionally shadows the `eval` builtin (kept for
  backward compatibility with callers in this file).
  """
  model.eval()
  criterion = nn.NLLLoss().cuda()
  num_sents = 0
  num_words = 0
  total_nll_autoreg = 0.
  total_nll_vae = 0.
  total_kl_vae = 0.
  total_nll_svi = 0.
  total_kl_svi = 0.
  # (removed unused local `best_svi_loss`)
  total_kl_dim = 0
  for i in range(len(data)):
    sents, length, batch_size = data[i]
    num_words += batch_size*length
    num_sents += batch_size
    if args.gpu >= 0:
      sents = sents.cuda()
    if args.model == 'autoreg':
      preds = model._dec_forward(sents, None, True)
      nll_autoreg = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
      total_nll_autoreg += nll_autoreg.data[0]*batch_size
    elif args.model == 'svi':
      # Random small init here (cf. zeros in the training loop).
      mean_svi = Variable(0.1*torch.randn(batch_size, args.latent_dim).cuda(), requires_grad = True)
      logvar_svi = Variable(0.1*torch.randn(batch_size, args.latent_dim).cuda(), requires_grad = True)
      var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], sents)
      mean_svi_final, logvar_svi_final = var_params_svi
      z_samples = model._reparameterize(mean_svi_final.detach(), logvar_svi_final.detach())
      preds = model._dec_forward(sents, z_samples)
      nll_svi = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
      total_nll_svi += nll_svi.data[0]*batch_size
      kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
      total_kl_svi += kl_svi.data[0]*batch_size
      mean, logvar = mean_svi_final, logvar_svi_final
    else:
      mean, logvar = model._enc_forward(sents)
      z_samples = model._reparameterize(mean, logvar)
      preds = model._dec_forward(sents, z_samples)
      nll_vae = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
      total_nll_vae += nll_vae.data[0]*batch_size
      kl_vae = utils.kl_loss_diag(mean, logvar)
      kl_dim = utils.kl_loss_dim(mean, logvar)
      total_kl_dim += kl_dim.sum(0).data
      total_kl_vae += kl_vae.data[0]*batch_size
      if args.model == 'savae':
        mean_svi = Variable(mean.data, requires_grad = True)
        logvar_svi = Variable(logvar.data, requires_grad = True)
        var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], sents)
        mean_svi_final, logvar_svi_final = var_params_svi
        z_samples = model._reparameterize(mean_svi_final, logvar_svi_final)
        preds = model._dec_forward(sents, z_samples)
        nll_svi = sum([criterion(preds[:, l], sents[:, l+1]) for l in range(length)])
        total_nll_svi += nll_svi.data[0]*batch_size
        kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
        total_kl_svi += kl_svi.data[0]*batch_size
        mean, logvar = mean_svi_final, logvar_svi_final
  ppl_autoreg = np.exp(total_nll_autoreg / num_words)
  ppl_vae = np.exp(total_nll_vae/ num_words)
  kl_vae = total_kl_vae / num_sents
  ppl_bound_vae = np.exp((total_nll_vae + total_kl_vae)/num_words)
  ppl_svi = np.exp(total_nll_svi/num_words)
  kl_svi = total_kl_svi/num_sents
  kl_dim = total_kl_dim / num_sents
  print([ '%.4f' % e for e in list(kl_dim)])
  ppl_bound_svi = np.exp((total_nll_svi + total_kl_svi)/num_words)
  print('elbo vae', (total_nll_vae + total_kl_vae)/num_sents)
  print('elbo savi', (total_nll_svi + total_kl_svi)/num_sents)

  print('AR PPL: %.4f, VAE PPL: %.4f, VAE KL: %.4f, VAE PPL BOUND: %.4f, SVI PPL: %.4f, SVI KL: %.4f, SVI PPL BOUND: %.4f' %
        (ppl_autoreg, ppl_vae, kl_vae, ppl_bound_vae, ppl_svi, kl_svi, ppl_bound_svi))
  model.train()
  if args.model == 'autoreg':
    return ppl_autoreg
  elif args.model == 'vae':
    return ppl_bound_vae
  elif args.model == 'savae' or args.model == 'svi':
    return ppl_bound_svi


if __name__ == '__main__':
  args = parser.parse_args()
  main(args)