├── AE.py ├── README.md └── conv_vae.py /AE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.datasets as dsets 3 | import torchvision.transforms as transforms 4 | import torchvision 5 | from torch.autograd import Variable 6 | 7 | from time import time 8 | 9 | from AE import * 10 | 11 | 12 | num_epochs = 50 13 | batch_size = 100 14 | hidden_size = 30 15 | 16 | 17 | # MNIST dataset 18 | dataset = dsets.MNIST(root='../data', 19 | train=True, 20 | transform=transforms.ToTensor(), 21 | download=True) 22 | 23 | # Data loader 24 | data_loader = torch.utils.data.DataLoader(dataset=dataset, 25 | batch_size=batch_size, 26 | shuffle=True) 27 | 28 | def to_var(x): 29 | if torch.cuda.is_available(): 30 | x = x.cuda() 31 | return Variable(x) 32 | 33 | 34 | class Autoencoder(nn.Module): 35 | def __init__(self, in_dim=784, h_dim=400): 36 | super(Autoencoder, self).__init__() 37 | 38 | self.encoder = nn.Sequential( 39 | nn.Linear(in_dim, h_dim), 40 | nn.ReLU() 41 | ) 42 | 43 | self.decoder = nn.Sequential( 44 | nn.Linear(h_dim, in_dim), 45 | nn.Sigmoid() 46 | ) 47 | 48 | 49 | def forward(self, x): 50 | """ 51 | Note: image dimension conversion will be handled by external methods 52 | """ 53 | out = self.encoder(x) 54 | out = self.decoder(out) 55 | return out 56 | 57 | 58 | ae = Autoencoder(in_dim=784, h_dim=hidden_size) 59 | 60 | if torch.cuda.is_available(): 61 | ae.cuda() 62 | 63 | criterion = nn.BCELoss() 64 | optimizer = torch.optim.Adam(ae.parameters(), lr=0.001) 65 | iter_per_epoch = len(data_loader) 66 | data_iter = iter(data_loader) 67 | 68 | # save fixed inputs for debugging 69 | fixed_x, _ = next(data_iter) 70 | torchvision.utils.save_image(Variable(fixed_x).data.cpu(), './data/real_images.png') 71 | fixed_x = to_var(fixed_x.view(fixed_x.size(0), -1)) 72 | 73 | for epoch in range(num_epochs): 74 | t0 = time() 75 | for i, (images, _) in enumerate(data_loader): 76 | 77 | # flatten the image 78 | images = to_var(images.view(images.size(0), -1)) 79 | out = ae(images) 80 | loss = criterion(out, images) 81 | 82 | optimizer.zero_grad() 83 | loss.backward() 84 | optimizer.step() 85 | 86 | if (i+1) % 100 == 0: 87 | print ('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f Time: %.2fs' 88 | %(epoch+1, num_epochs, i+1, len(dataset)//batch_size, loss.data[0], time()-t0)) 89 | 90 | # save the reconstructed images 91 | reconst_images = ae(fixed_x) 92 | reconst_images = reconst_images.view(reconst_images.size(0), 1, 28, 28) 93 | torchvision.utils.save_image(reconst_images.data.cpu(), './data/reconst_images_%d.png' % (epoch+1)) 94 | 95 | 96 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Autoencoders in PyTorch # 2 | 3 | ### Update - Feb 4, 2018 ### 4 | 5 | * One layer vanilla autoencoder on MNIST 6 | * Variational autoencoder with Convolutional hidden layers on CIFAR-10 -------------------------------------------------------------------------------- /conv_vae.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.utils.data 5 | from torch import nn, optim 6 | from torch.autograd import Variable 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | from torchvision import datasets, transforms 10 | from torchvision.utils import save_image 11 | 12 | 13 | parser = argparse.ArgumentParser(description='VAE MNIST Example') 14 | parser.add_argument('--batch-size', type=int, default=128, metavar='N', 15 | help='input batch size for training (default: 128)') 16 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 17 | help='number of epochs to train (default: 10)') 18 | parser.add_argument('--no-cuda', action='store_true', default=False, 19 | help='enables CUDA training') 20 | parser.add_argument('--seed', type=int, default=1, metavar='S', 21 | help='random seed (default: 1)') 22 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 23 | help='how many batches to wait before logging training status') 24 | parser.add_argument('--hidden-size', type=int, default=20, metavar='N', 25 | help='how big is z') 26 | parser.add_argument('--intermediate-size', type=int, default=128, metavar='N', 27 | help='how big is linear around z') 28 | # parser.add_argument('--widen-factor', type=int, default=1, metavar='N', 29 | # help='how wide is the model') 30 | args = parser.parse_args() 31 | args.cuda = not args.no_cuda and torch.cuda.is_available() 32 | 33 | 34 | torch.manual_seed(args.seed) 35 | if args.cuda: 36 | torch.cuda.manual_seed(args.seed) 37 | 38 | 39 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 40 | train_loader = torch.utils.data.DataLoader( 41 | datasets.CIFAR10('../data', train=True, download=True, 42 | transform=transforms.ToTensor()), 43 | batch_size=args.batch_size, shuffle=True, **kwargs) 44 | test_loader = torch.utils.data.DataLoader( 45 | datasets.CIFAR10('../data', train=False, transform=transforms.ToTensor()), 46 | batch_size=args.batch_size, shuffle=False, **kwargs) 47 | 48 | 49 | class VAE(nn.Module): 50 | def __init__(self): 51 | super(VAE, self).__init__() 52 | 53 | # Encoder 54 | self.conv1 = nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1) 55 | self.conv2 = nn.Conv2d(3, 32, kernel_size=2, stride=2, padding=0) 56 | self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1) 57 | self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1) 58 | self.fc1 = nn.Linear(16 * 16 * 32, args.intermediate_size) 59 | 60 | # Latent space 61 | self.fc21 = nn.Linear(args.intermediate_size, args.hidden_size) 62 | self.fc22 = nn.Linear(args.intermediate_size, args.hidden_size) 63 | 64 | # Decoder 65 | self.fc3 = nn.Linear(args.hidden_size, args.intermediate_size) 66 | self.fc4 = nn.Linear(args.intermediate_size, 8192) 67 | self.deconv1 = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=1, padding=1) 68 | self.deconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=1, padding=1) 69 | self.deconv3 = nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2, padding=0) 70 | self.conv5 = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1) 71 | 72 | self.relu = nn.ReLU() 73 | self.sigmoid = nn.Sigmoid() 74 | 75 | def encode(self, x): 76 | out = self.relu(self.conv1(x)) 77 | out = self.relu(self.conv2(out)) 78 | out = self.relu(self.conv3(out)) 79 | out = self.relu(self.conv4(out)) 80 | out = out.view(out.size(0), -1) 81 | h1 = self.relu(self.fc1(out)) 82 | return self.fc21(h1), self.fc22(h1) 83 | 84 | def reparameterize(self, mu, logvar): 85 | if self.training: 86 | std = logvar.mul(0.5).exp_() 87 | eps = Variable(std.data.new(std.size()).normal_()) 88 | return eps.mul(std).add_(mu) 89 | else: 90 | return mu 91 | 92 | def decode(self, z): 93 | h3 = self.relu(self.fc3(z)) 94 | out = self.relu(self.fc4(h3)) 95 | # import pdb; pdb.set_trace() 96 | out = out.view(out.size(0), 32, 16, 16) 97 | out = self.relu(self.deconv1(out)) 98 | out = self.relu(self.deconv2(out)) 99 | out = self.relu(self.deconv3(out)) 100 | out = self.sigmoid(self.conv5(out)) 101 | return out 102 | 103 | def forward(self, x): 104 | mu, logvar = self.encode(x) 105 | z = self.reparameterize(mu, logvar) 106 | return self.decode(z), mu, logvar 107 | 108 | 109 | model = VAE() 110 | if args.cuda: 111 | model.cuda() 112 | optimizer = optim.RMSprop(model.parameters(), lr=1e-3) 113 | 114 | 115 | # Reconstruction + KL divergence losses summed over all elements and batch 116 | def loss_function(recon_x, x, mu, logvar): 117 | BCE = F.binary_cross_entropy(recon_x.view(-1, 32 * 32 * 3), 118 | x.view(-1, 32 * 32 * 3), size_average=False) 119 | 120 | # see Appendix B from VAE paper: 121 | # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 122 | # https://arxiv.org/abs/1312.6114 123 | # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 124 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 125 | 126 | return BCE + KLD 127 | 128 | 129 | def train(epoch): 130 | model.train() 131 | train_loss = 0 132 | for batch_idx, (data, _) in enumerate(train_loader): 133 | data = Variable(data) 134 | if args.cuda: 135 | data = data.cuda() 136 | optimizer.zero_grad() 137 | recon_batch, mu, logvar = model(data) 138 | loss = loss_function(recon_batch, data, mu, logvar) 139 | loss.backward() 140 | train_loss += loss.data[0] 141 | optimizer.step() 142 | if batch_idx % args.log_interval == 0: 143 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 144 | epoch, batch_idx * len(data), len(train_loader.dataset), 145 | 100. * batch_idx / len(train_loader), 146 | loss.data[0] / len(data))) 147 | 148 | print('====> Epoch: {} Average loss: {:.4f}'.format( 149 | epoch, train_loss / len(train_loader.dataset))) 150 | 151 | 152 | def test(epoch): 153 | model.eval() 154 | test_loss = 0 155 | for i, (data, _) in enumerate(test_loader): 156 | if args.cuda: 157 | data = data.cuda() 158 | data = Variable(data, volatile=True) 159 | recon_batch, mu, logvar = model(data) 160 | test_loss += loss_function(recon_batch, data, mu, logvar).data[0] 161 | if epoch == args.epochs and i == 0: 162 | n = min(data.size(0), 8) 163 | comparison = torch.cat([data[:n], 164 | recon_batch[:n]]) 165 | save_image(comparison.data.cpu(), 166 | 'snapshots/conv_vae/reconstruction_' + str(epoch) + 167 | '.png', nrow=n) 168 | 169 | test_loss /= len(test_loader.dataset) 170 | print('====> Test set loss: {:.4f}'.format(test_loss)) 171 | 172 | 173 | for epoch in range(1, args.epochs + 1): 174 | train(epoch) 175 | test(epoch) 176 | if epoch == args.epochs: 177 | sample = Variable(torch.randn(64, args.hidden_size)) 178 | if args.cuda: 179 | sample = sample.cuda() 180 | sample = model.decode(sample).cpu() 181 | save_image(sample.data.view(64, 3, 32, 32), 182 | 'snapshots/conv_vae/sample_' + str(epoch) + '.png') 183 | --------------------------------------------------------------------------------