├── data.py
├── main.py
└── modules.py

/data.py:
--------------------------------------------------------------------------------
import numpy as np
from pathlib import Path


def binarized_mnist(path=Path('./datasets/binarized_mnist')):
    """Load the statically binarized MNIST splits from .amat files and
    shuffle the rows of each split."""
    def lines_to_np_array(lines):
        return np.array([[int(i) for i in line.split()] for line in lines])

    data = {}
    for split in ['train', 'valid', 'test']:
        with open(path / 'binarized_mnist_{}.amat'.format(split)) as f:
            lines = f.readlines()
        data[split] = lines_to_np_array(lines).astype('float32')
        idxs = np.random.permutation(data[split].shape[0])
        data[split] = data[split][idxs]

    return data
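
# Minimal smoke test (added; not part of the original file). Assumes the
# .amat files of the standard binarized MNIST release are present under
# ./datasets/binarized_mnist/.
if __name__ == '__main__':
    data = binarized_mnist()
    for split, arr in data.items():
        print(split, arr.shape)  # expect e.g. train (50000, 784)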
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import time

import numpy as np
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from modules import VAE, VAE_NF
from data import binarized_mnist


def iterator(data, batch_size=32):
    # Yield (batch, label) pairs; there are no labels for this task.
    for i in range(0, data.shape[0], batch_size):
        yield torch.from_numpy(data[i: i + batch_size]), None


BATCH_SIZE = 32
N_EPOCHS = 100
PRINT_INTERVAL = 500
LR = 2e-4
MODEL = 'VAE'  # 'VAE-NF' | 'VAE'

N_FLOWS = 30
Z_DIM = 40

N_TRAIN = 50000  # size of the binarized MNIST training split


n_steps = 0
writer = SummaryWriter()
dataset = binarized_mnist()

if MODEL == 'VAE-NF':
    model = VAE_NF(N_FLOWS, Z_DIM).cuda()
else:
    model = VAE(Z_DIM).cuda()

print(model)
opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)


def train():
    global n_steps
    train_loss = []
    model.train()
    train_loader = iterator(dataset['train'], BATCH_SIZE)

    for batch_idx, (x, _) in enumerate(train_loader):
        start_time = time.time()
        x = x.cuda().view(-1, 784)

        x_tilde, kl_div = model(x)
        loss_recons = F.binary_cross_entropy(x_tilde, x, reduction='sum') / x.size(0)
        loss = loss_recons + kl_div

        opt.zero_grad()
        loss.backward()
        opt.step()

        train_loss.append([loss_recons.item(), kl_div.item()])
        writer.add_scalar('loss/train/ELBO', loss.item(), n_steps)
        writer.add_scalar('loss/train/reconstruction', loss_recons.item(), n_steps)
        writer.add_scalar('loss/train/KL', kl_div.item(), n_steps)

        if (batch_idx + 1) % PRINT_INTERVAL == 0:
            print('\tIter [{}/{} ({:.0f}%)]\tLoss: {} Time: {:5.3f} ms/batch'.format(
                (batch_idx + 1) * x.size(0), N_TRAIN,
                100. * (batch_idx + 1) * x.size(0) / N_TRAIN,
                np.asarray(train_loss)[-PRINT_INTERVAL:].mean(0),
                1000 * (time.time() - start_time)
            ))

        n_steps += 1


def evaluate(split='valid'):
    start_time = time.time()
    val_loss = []
    model.eval()
    eval_loader = iterator(dataset[split], BATCH_SIZE)

    with torch.no_grad():
        for batch_idx, (x, _) in enumerate(eval_loader):
            x = x.cuda().view(-1, 784)

            x_tilde, kl_div = model(x)
            loss_recons = F.binary_cross_entropy(x_tilde, x, reduction='sum') / x.size(0)
            loss = loss_recons + kl_div

            val_loss.append(loss.item())
            writer.add_scalar('loss/{}/ELBO'.format(split), loss.item(), n_steps)
            writer.add_scalar('loss/{}/reconstruction'.format(split), loss_recons.item(), n_steps)
            writer.add_scalar('loss/{}/KL'.format(split), kl_div.item(), n_steps)

    print('\nEvaluation Completed ({})!\tLoss: {:5.4f} Time: {:5.3f} s'.format(
        split,
        np.asarray(val_loss).mean(),
        time.time() - start_time
    ))
    return np.asarray(val_loss).mean()


def generate_reconstructions():
    model.eval()
    x = torch.from_numpy(dataset['test'][:32]).cuda().view(-1, 784)

    with torch.no_grad():
        x_tilde, _ = model(x)

    # Top rows: originals; bottom rows: reconstructions.
    x_cat = torch.cat([x, x_tilde], 0).view(-1, 1, 28, 28)
    save_image(x_cat.cpu(), 'samples/{}_reconstructions.png'.format(MODEL), nrow=8)


def generate_samples():
    model.eval()
    with torch.no_grad():
        # Sample from the prior and decode.
        z = torch.randn(64, Z_DIM).cuda()
        x_tilde = model.decoder(z).view(-1, 1, 28, 28)
    save_image(x_tilde.cpu(), 'samples/{}_samples.png'.format(MODEL), nrow=8)


os.makedirs('samples', exist_ok=True)
os.makedirs('models', exist_ok=True)

BEST_LOSS = float('inf')
LAST_SAVED = -1
for epoch in range(1, N_EPOCHS + 1):
    print("Epoch {}:".format(epoch))
    train()
    cur_loss = evaluate()

    if cur_loss <= BEST_LOSS:
        BEST_LOSS = cur_loss
        LAST_SAVED = epoch
        print("Saving model!")
        torch.save(model.state_dict(), 'models/{}.pt'.format(MODEL))
    else:
        print("Not saving model! Last saved: {}".format(LAST_SAVED))

    generate_reconstructions()
    generate_samples()
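
# Added (not in the original script): report the final model's test-split
# loss once training ends, then flush the TensorBoard logs.
evaluate(split='test')
writer.close()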
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class PlanarFlow(nn.Module):
    def __init__(self, D):
        super().__init__()
        self.D = D

    def forward(self, z, lamda):
        '''
        z     - latents from the previous layer, shape (B, D)
        lamda - flow parameters (b, w, u), shape (B, 2D + 1)
                b - scalar
                w - vector
                u - vector
        '''
        b = lamda[:, :1]
        w, u = lamda[:, 1:].chunk(2, dim=1)

        # Forward transform:
        # f(z) = z + u * tanh(w^T z + b)
        transf = torch.tanh(
            z.unsqueeze(1).bmm(w.unsqueeze(2))[:, 0] + b
        )
        f_z = z + u * transf

        # Log-determinant of the Jacobian:
        # psi(z) = tanh'(w^T z + b) * w, with tanh'(a) = 1 - tanh(a)^2
        # log |det df/dz| = log |1 + u^T psi(z)|
        # Note: invertibility of a planar flow requires w^T u >= -1, which
        # this amortized parametrization does not enforce.
        psi_z = (1 - transf ** 2) * w
        log_abs_det_jacobian = torch.log(
            (1 + psi_z.unsqueeze(1).bmm(u.unsqueeze(2))).abs()
        )

        return f_z, log_abs_det_jacobian


class NormalizingFlow(nn.Module):
    def __init__(self, K, D):
        super().__init__()
        self.flows = nn.ModuleList([PlanarFlow(D) for _ in range(K)])

    def forward(self, z_k, flow_params):
        # ladj -> log abs det jacobian, accumulated over the K flows
        sum_ladj = 0
        for i, flow in enumerate(self.flows):
            z_k, ladj_k = flow(z_k, flow_params[i])
            sum_ladj += ladj_k

        return z_k, sum_ladj


class VAE_NF(nn.Module):
    def __init__(self, K, D):
        super().__init__()
        self.dim = D
        self.K = K
        # The encoder outputs mu, log_var and, for each of the K flows,
        # its (2D + 1) parameters (b, w, u).
        self.encoder = nn.Sequential(
            nn.Linear(784, 400),
            nn.ReLU(True),
            nn.Linear(400, D * 2 + K * (D * 2 + 1))
        )

        self.decoder = nn.Sequential(
            nn.Linear(D, 400),
            nn.ReLU(True),
            nn.Linear(400, 784),
            nn.Sigmoid()
        )

        self.flows = NormalizingFlow(K, D)

    def forward(self, x):
        # Run the encoder and split off the per-flow parameters
        enc = self.encoder(x)
        mu = enc[:, :self.dim]
        log_var = enc[:, self.dim: self.dim * 2]
        flow_params = enc[:, 2 * self.dim:].chunk(self.K, dim=1)

        # Re-parametrize
        sigma = (log_var * .5).exp()
        z = mu + sigma * torch.randn_like(sigma)
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        # Construct a more expressive posterior with the normalizing flow
        z_k, sum_ladj = self.flows(z, flow_params)
        kl_div = kl_div / x.size(0) - sum_ladj.mean()  # mean over batch

        # Run the decoder
        x_prime = self.decoder(z_k)
        return x_prime, kl_div


class VAE(nn.Module):
    def __init__(self, D):
        super().__init__()
        self.dim = D
        self.encoder = nn.Sequential(
            nn.Linear(784, 400),
            nn.ReLU(True),
            nn.Linear(400, D * 2)
        )

        self.decoder = nn.Sequential(
            nn.Linear(D, 400),
            nn.ReLU(True),
            nn.Linear(400, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Run the encoder
        mu, log_var = self.encoder(x).chunk(2, dim=1)

        # Re-parametrize
        sigma = (log_var * .5).exp()
        z = mu + sigma * torch.randn_like(sigma)
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        kl_div = kl_div / x.size(0)  # mean over batch

        # Run the decoder
        x_prime = self.decoder(z)
        return x_prime, kl_div
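
# Added sanity sketch (not in the original module): check the planar flow's
# analytic log|det J| against autograd's Jacobian for a single sample.
if __name__ == '__main__':
    torch.manual_seed(0)
    D = 3
    flow = PlanarFlow(D)
    z = torch.randn(1, D)
    lamda = torch.randn(1, 2 * D + 1)

    f_z, ladj = flow(z, lamda)
    jac = torch.autograd.functional.jacobian(
        lambda v: flow(v, lamda)[0], z
    )[0, :, 0, :]  # (1, D, 1, D) -> (D, D)
    # The two printed values should agree up to numerical precision.
    print(ladj.item(), torch.linalg.slogdet(jac).logabsdet.item())
--------------------------------------------------------------------------------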