├── .gitignore
├── assets
├── ACAI-figure.png
├── interpolations_AE.png
├── interpolations_ACAI.png
└── interpolations_VAE_2.png
├── requirements.txt
├── utils.py
├── datasets.py
├── README.md
├── models
├── AE.py
└── VAE.py
├── train.py
└── architectures.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | data/
3 | results/
4 |
--------------------------------------------------------------------------------
/assets/ACAI-figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dariocazzani/pytorch-AE/HEAD/assets/ACAI-figure.png
--------------------------------------------------------------------------------
/assets/interpolations_AE.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dariocazzani/pytorch-AE/HEAD/assets/interpolations_AE.png
--------------------------------------------------------------------------------
/assets/interpolations_ACAI.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dariocazzani/pytorch-AE/HEAD/assets/interpolations_ACAI.png
--------------------------------------------------------------------------------
/assets/interpolations_VAE_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dariocazzani/pytorch-AE/HEAD/assets/interpolations_VAE_2.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.15.3
2 | Pillow==5.3.0
3 | pkg-resources==0.0.0
4 | six==1.11.0
5 | torch==0.4.1
6 | torchvision==0.2.1
7 | imageio
8 | scipy
9 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
def get_interpolations(args, model, device, images, images_per_row=20):
    """Build a grid of latent-space interpolations between consecutive images.

    The first ``images_per_row + 2`` images of the batch are used. For every
    consecutive pair, ``images_per_row - 2`` linearly spaced blends of their
    embeddings (both end points included) are decoded and framed by the two
    original images, giving rows of ``images_per_row`` entries. A final row
    goes from the last image used back to the first one, closing the loop.

    Args:
        args: namespace whose ``model`` attribute is ``'AE'`` or ``'VAE'``.
        model: network exposing ``eval``, ``encode`` and ``decode`` (plus
            ``reparameterize`` when ``args.model == 'VAE'``).
        device: device the latent codes are moved to before decoding and on
            which the result is returned.
        images: batch of images indexed along the first dimension; each one
            is flattened to 784 values.
        images_per_row: number of entries in each row of the grid.

    Returns:
        A tensor with ``(images_per_row + 2) * images_per_row`` rows of 784
        values, ordered row by row.

    Raises:
        ValueError: if ``args.model`` is not supported, ``images_per_row`` is
            smaller than 3, or the batch holds fewer than
            ``images_per_row + 2`` images.
    """
    if args.model not in ('AE', 'VAE'):
        raise ValueError(
            "Unsupported model for interpolation: {!r}".format(args.model))
    if images_per_row < 3:
        raise ValueError("images_per_row must be at least 3")

    # Pairs (0, 1) ... (images_per_row, images_per_row + 1) are interpolated,
    # so that many images (+1) must be available in the batch.
    last = images_per_row + 1
    if images.size(0) < last + 1:
        raise ValueError(
            "Need at least {} images to build the interpolations, got {}".format(
                last + 1, images.size(0)))

    model.eval()
    with torch.no_grad():
        flat_images = images.view(-1, 784)
        if args.model == 'VAE':
            mu, logvar = model.encode(flat_images)
            embeddings = model.reparameterize(mu, logvar)
        else:
            embeddings = model.encode(flat_images)

        # Blend weights as column vectors so that a whole row of latent codes
        # is produced by one broadcasted expression.
        alpha = np.linspace(0, 1, images_per_row - 2)
        weight_end = torch.tensor(
            alpha, dtype=embeddings.dtype, device=embeddings.device).view(-1, 1)
        weight_start = torch.tensor(
            1 - alpha, dtype=embeddings.dtype, device=embeddings.device).view(-1, 1)

        def build_row(start, end):
            codes = (weight_end * embeddings[end].view(1, -1)
                     + weight_start * embeddings[start].view(1, -1))
            decoded = model.decode(codes.to(device))
            return torch.cat((images[start].view(-1, 784),
                              decoded,
                              images[end].view(-1, 784)))

        rows = [build_row(i, i + 1) for i in range(last)]
        # Complete the loop: interpolate from the last image back to the first.
        rows.append(build_row(last, 0))

        interps = torch.cat(rows, 0).to(device)
    return interps
36 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import datasets, transforms
3 |
class MNIST(object):
    """Train and test ``DataLoader`` objects for MNIST stored in ``data/mnist``.

    The training split is downloaded if needed. Both loaders shuffle and use
    ``args.batch_size``; when ``args.cuda`` is set they also use one worker
    process and pinned memory.
    """

    def __init__(self, args):
        loader_options = {}
        if args.cuda:
            loader_options = {'num_workers': 1, 'pin_memory': True}

        train_set = datasets.MNIST('data/mnist', train=True, download=True,
                                   transform=transforms.ToTensor())
        self.train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=args.batch_size, shuffle=True,
            **loader_options)

        test_set = datasets.MNIST('data/mnist', train=False,
                                  transform=transforms.ToTensor())
        self.test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=args.batch_size, shuffle=True,
            **loader_options)
14 |
class EMNIST(object):
    """Train and test ``DataLoader`` objects for EMNIST (``byclass`` split).

    Data lives in ``data/emnist``; the training split is downloaded if
    needed. Both loaders shuffle and use ``args.batch_size``; when
    ``args.cuda`` is set they also use one worker process and pinned memory.
    """

    def __init__(self, args):
        loader_options = {}
        if args.cuda:
            loader_options = {'num_workers': 1, 'pin_memory': True}

        train_set = datasets.EMNIST('data/emnist', train=True, download=True,
                                    split='byclass',
                                    transform=transforms.ToTensor())
        self.train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=args.batch_size, shuffle=True,
            **loader_options)

        test_set = datasets.EMNIST('data/emnist', train=False,
                                   split='byclass',
                                   transform=transforms.ToTensor())
        self.test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=args.batch_size, shuffle=True,
            **loader_options)
26 |
class FashionMNIST(object):
    """Train and test ``DataLoader`` objects for Fashion-MNIST in ``data/fmnist``.

    The training split is downloaded if needed. Both loaders shuffle and use
    ``args.batch_size``; when ``args.cuda`` is set they also use one worker
    process and pinned memory.
    """

    def __init__(self, args):
        loader_options = {}
        if args.cuda:
            loader_options = {'num_workers': 1, 'pin_memory': True}

        train_set = datasets.FashionMNIST('data/fmnist', train=True,
                                          download=True,
                                          transform=transforms.ToTensor())
        self.train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=args.batch_size, shuffle=True,
            **loader_options)

        test_set = datasets.FashionMNIST('data/fmnist', train=False,
                                         transform=transforms.ToTensor())
        self.test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=args.batch_size, shuffle=True,
            **loader_options)
37 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AutoEncoders in PyTorch
2 |
3 | [PyTorch](https://pytorch.org/)
4 | 
5 |
6 | -------------------------
7 |
8 | ## Description
9 |
10 | This repo contains an implementation of the following AutoEncoders:
11 |
12 | * [Vanilla AutoEncoders - **AE**](http://ufldl.stanford.edu/tutorial/unsupervised/Autoencoders/):
13 | The most basic autoencoder structure is one which simply maps input data-points through a __bottleneck layer__ whose dimensionality is smaller than the input.
14 |
15 | * [Variational AutoEncoders - **VAE**](https://arxiv.org/pdf/1606.05908):
16 | The Variational Autoencoder introduces the constraint
17 | that the latent code `z` is a random variable distributed according to a prior distribution `p(z)`.
18 |
19 | * [Adversarially Constrained Autoencoder Interpolations - **ACAI**](https://arxiv.org/pdf/1807.07543):
20 | A critic network tries to predict the interpolation coefficient α corresponding to an interpolated datapoint. The autoencoder is
21 | trained to fool the critic into outputting α = 0.
22 | 
23 |
24 | -------------------------
25 |
26 | ## Setup
27 |
28 | ### Create a Python Virtual Environment
29 | ```
30 | mkvirtualenv --python=/usr/bin/python3 pytorch-AE
31 | ```
32 |
33 | ### Install dependencies
34 | ```
35 | pip install torch torchvision numpy scipy imageio
36 | ```
37 |
38 | -------------------------
39 |
40 | ## Training
41 | ```
42 | python train.py --help
43 | ```
44 |
45 | ### Training Options and some examples:
46 |
47 | * **Vanilla Autoencoder:**
48 | ```
49 | python train.py --model AE
50 | ```
51 |
52 | * **Variational Autoencoder:**
53 | ```
54 | python train.py --model VAE --batch-size 512 --dataset EMNIST --seed 42 --log-interval 500 --epochs 5 --embedding-size 128
55 | ```
56 |
57 | -------------------------
58 |
59 | ## Results
60 |
61 | | Vanilla AutoEncoders | Variational AutoEncoders | ACAI |
62 | |------------------------- |------------------------- | --------------------|
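63 | |![AE interpolations](assets/interpolations_AE.png) |![VAE interpolations](assets/interpolations_VAE_2.png) |![ACAI interpolations](assets/interpolations_ACAI.png) |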
64 |
65 |
66 |
67 | ### Contributing:
68 | If you have suggestions or ideas for contributions, file an issue or open a PR,
69 | and **don't forget to star the repository**.
70 |
71 | ### More projects:
72 | Feel free to check out my other repos with more work in Machine Learning:
73 |
74 | * [World Models in TensorFlow](https://github.com/dariocazzani/World-Models-TensorFlow)
75 | * [TensorBlob](https://github.com/dariocazzani/TensorBlob)
76 | * [banaNavigation](https://github.com/dariocazzani/banaNavigation)
77 |
--------------------------------------------------------------------------------
/models/AE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | from torch import nn, optim
4 | from torch.nn import functional as F
5 | from torchvision import datasets, transforms
6 |
7 | import sys
8 | sys.path.append('../')
9 | from architectures import FC_Encoder, FC_Decoder, CNN_Encoder, CNN_Decoder
10 | from datasets import MNIST, EMNIST, FashionMNIST
11 |
class Network(nn.Module):
    """Plain autoencoder: a CNN encoder followed by a CNN decoder.

    Both halves are sized by ``args.embedding_size``.
    """

    def __init__(self, args):
        super().__init__()
        embedding_size = args.embedding_size
        self.encoder = CNN_Encoder(embedding_size)
        self.decoder = CNN_Decoder(embedding_size)

    def encode(self, x):
        """Return the encoder output for ``x``."""
        return self.encoder(x)

    def decode(self, z):
        """Return the decoder output for the latent codes ``z``."""
        return self.decoder(z)

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values, encode them, then decode."""
        flattened = x.view(-1, 784)
        latent = self.encode(flattened)
        return self.decode(latent)
29 |
class AE(object):
    """Training/evaluation wrapper around the plain autoencoder ``Network``.

    Builds the data loaders for ``args.dataset``, the model on the selected
    device, and an Adam optimizer with learning rate 1e-3.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device("cuda" if args.cuda else "cpu")
        self._init_dataset()
        self.train_loader = self.data.train_loader
        self.test_loader = self.data.test_loader

        self.model = Network(args)
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)

    def _init_dataset(self):
        """Set ``self.data`` from ``args.dataset``.

        Raises:
            SystemExit: with a non-zero status when the dataset name is not
                one of ``MNIST``, ``EMNIST`` or ``FashionMNIST``.
        """
        if self.args.dataset == 'MNIST':
            self.data = MNIST(self.args)
        elif self.args.dataset == 'EMNIST':
            self.data = EMNIST(self.args)
        elif self.args.dataset == 'FashionMNIST':
            self.data = FashionMNIST(self.args)
        else:
            # Passing the message to sys.exit writes it to stderr and makes
            # the process exit status non-zero, so the failure is visible to
            # shells and scripts.
            sys.exit("Dataset not supported")

    def loss_function(self, recon_x, x):
        """Return the binary cross-entropy between ``recon_x`` and ``x``.

        ``x`` is flattened to rows of 784 values and the loss is summed over
        all elements of the batch.
        """
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
        return BCE

    def train(self, epoch):
        """Run one optimisation pass over the training loader.

        Progress is printed every ``args.log_interval`` batches.

        Returns:
            The summed loss divided by the number of training samples.
        """
        self.model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(self.train_loader):
            data = data.to(self.device)
            self.optimizer.zero_grad()
            recon_batch = self.model(data)
            loss = self.loss_function(recon_batch, data)
            loss.backward()
            train_loss += loss.item()
            self.optimizer.step()
            if batch_idx % self.args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader),
                    loss.item() / len(data)))

        average_loss = train_loss / len(self.train_loader.dataset)
        print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, average_loss))
        return average_loss

    def test(self, epoch):
        """Evaluate the model on the test loader without gradient tracking.

        Returns:
            The summed loss divided by the number of test samples.
        """
        self.model.eval()
        test_loss = 0
        with torch.no_grad():
            for data, _ in self.test_loader:
                data = data.to(self.device)
                recon_batch = self.model(data)
                test_loss += self.loss_function(recon_batch, data).item()

        test_loss /= len(self.test_loader.dataset)
        print('====> Test set loss: {:.4f}'.format(test_loss))
        return test_loss
88 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys
2 | import numpy as np
3 | import imageio
4 | from scipy import ndimage
5 |
6 | import torch
7 | from torchvision.utils import save_image
8 |
9 | from models.VAE import VAE
10 | from models.AE import AE
11 |
12 | from utils import get_interpolations
13 |
parser = argparse.ArgumentParser(
    description='Main function to call training for different AutoEncoders')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--embedding-size', type=int, default=32, metavar='N',
                    help='size of the latent embedding (default: 32)')
parser.add_argument('--results_path', type=str, default='results/', metavar='N',
                    help='Where to store images')
parser.add_argument('--model', type=str, default='AE', metavar='N',
                    help='Which architecture to use')
parser.add_argument('--dataset', type=str, default='MNIST', metavar='N',
                    help='Which dataset to use')

# Maps each supported --model value to its trainer class. Only the selected
# one is instantiated, so the dataset is loaded and a model built once.
architectures = {'AE': AE,
                 'VAE': VAE}


def main():
    """Train the selected autoencoder, then save samples and interpolations.

    Writes three files into ``args.results_path``: decoded random samples,
    a grid of latent interpolations, and an animated GIF of the same
    interpolations.
    """
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)

    print(args.model)
    os.makedirs(args.results_path, exist_ok=True)

    try:
        autoenc_cls = architectures[args.model]
    except KeyError:
        print('---------------------------------------------------------')
        print('Model architecture not supported. ', end='')
        print('Maybe you can implement it?')
        print('---------------------------------------------------------')
        sys.exit(1)
    autoenc = autoenc_cls(args)

    try:
        for epoch in range(1, args.epochs + 1):
            autoenc.train(epoch)
            autoenc.test(epoch)
    except (KeyboardInterrupt, SystemExit):
        print("Manual Interruption")

    with torch.no_grad():
        images, _ = next(iter(autoenc.test_loader))
        images = images.to(autoenc.device)
        images_per_row = 16
        # get_interpolations indexes images 0 .. images_per_row + 1.
        if images.size(0) < images_per_row + 2:
            sys.exit('--batch-size must be at least {} to build the '
                     'interpolations'.format(images_per_row + 2))
        interpolations = get_interpolations(
            args, autoenc.model, autoenc.device, images, images_per_row)

        sample = torch.randn(64, args.embedding_size).to(autoenc.device)
        sample = autoenc.model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28),
                   '{}/sample_{}_{}.png'.format(
                       args.results_path, args.model, args.dataset))
        save_image(interpolations.view(-1, 1, 28, 28),
                   '{}/interpolations_{}_{}.png'.format(
                       args.results_path, args.model, args.dataset),
                   nrow=images_per_row)

        frames = np.reshape(interpolations.detach().cpu().numpy(), (-1, 28, 28))
        frames = ndimage.zoom(frames, 5, order=1)
        # Scale to 0..255 and clip before the uint8 cast: multiplying by 256
        # would turn a pixel value of exactly 1.0 into 256, which wraps to 0.
        frames = np.clip(frames * 255, 0, 255).astype(np.uint8)
        imageio.mimsave(
            '{}/animation_{}_{}.gif'.format(
                args.results_path, args.model, args.dataset),
            frames)


if __name__ == "__main__":
    main()
84 |
--------------------------------------------------------------------------------
/models/VAE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | from torch import nn, optim
4 | from torch.nn import functional as F
5 |
6 | import sys
7 | sys.path.append('../')
8 | from architectures import FC_Encoder, FC_Decoder, CNN_Encoder, CNN_Decoder
9 | from datasets import MNIST, EMNIST, FashionMNIST
10 |
class Network(nn.Module):
    """Variational autoencoder network.

    A CNN encoder produces 512 features; two linear heads map them to the
    mean and the log-variance of the latent code, each of size
    ``args.embedding_size``. A CNN decoder maps latent codes back to images.
    """

    def __init__(self, args):
        super().__init__()
        hidden_size = 512
        latent_size = args.embedding_size
        self.encoder = CNN_Encoder(hidden_size)
        # ``var`` yields the log-variance consumed by ``reparameterize``.
        self.var = nn.Linear(hidden_size, latent_size)
        self.mu = nn.Linear(hidden_size, latent_size)

        self.decoder = CNN_Decoder(latent_size)

    def encode(self, x):
        """Return ``(mu, logvar)`` for the input ``x``."""
        features = self.encoder(x)
        mean = self.mu(features)
        log_variance = self.var(features)
        return mean, log_variance

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps`` drawn from N(0, 1)."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return noise * std + mu

    def decode(self, z):
        """Return the decoder output for the latent codes ``z``."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for ``x`` flattened to 784."""
        mu, logvar = self.encode(x.view(-1, 784))
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar
37 |
class VAE(object):
    """Training/evaluation wrapper around the variational ``Network``.

    Builds the data loaders for ``args.dataset``, the model on the selected
    device, and an Adam optimizer with learning rate 1e-3.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device("cuda" if args.cuda else "cpu")
        self._init_dataset()
        self.train_loader = self.data.train_loader
        self.test_loader = self.data.test_loader

        self.model = Network(args)
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)

    def _init_dataset(self):
        """Set ``self.data`` from ``args.dataset``.

        Raises:
            SystemExit: with a non-zero status when the dataset name is not
                one of ``MNIST``, ``EMNIST`` or ``FashionMNIST``.
        """
        if self.args.dataset == 'MNIST':
            self.data = MNIST(self.args)
        elif self.args.dataset == 'EMNIST':
            self.data = EMNIST(self.args)
        elif self.args.dataset == 'FashionMNIST':
            self.data = FashionMNIST(self.args)
        else:
            # Passing the message to sys.exit writes it to stderr and makes
            # the process exit status non-zero, so the failure is visible to
            # shells and scripts.
            sys.exit("Dataset not supported")

    # Reconstruction + KL divergence losses summed over all elements and batch
    def loss_function(self, recon_x, x, mu, logvar):
        """Return summed binary cross-entropy plus the KL divergence term.

        ``x`` is flattened to rows of 784 values before the reconstruction
        loss is computed.
        """
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD

    def train(self, epoch):
        """Run one optimisation pass over the training loader.

        Progress is printed every ``args.log_interval`` batches.

        Returns:
            The summed loss divided by the number of training samples.
        """
        self.model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(self.train_loader):
            data = data.to(self.device)
            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(data)
            loss = self.loss_function(recon_batch, data, mu, logvar)
            loss.backward()
            train_loss += loss.item()
            self.optimizer.step()
            if batch_idx % self.args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader),
                    loss.item() / len(data)))

        average_loss = train_loss / len(self.train_loader.dataset)
        print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, average_loss))
        return average_loss

    def test(self, epoch):
        """Evaluate the model on the test loader without gradient tracking.

        Returns:
            The summed loss divided by the number of test samples.
        """
        self.model.eval()
        test_loss = 0
        with torch.no_grad():
            for data, _ in self.test_loader:
                data = data.to(self.device)
                recon_batch, mu, logvar = self.model(data)
                test_loss += self.loss_function(recon_batch, data, mu, logvar).item()

        test_loss /= len(self.test_loader.dataset)
        print('====> Test set loss: {:.4f}'.format(test_loss))
        return test_loss
102 |
--------------------------------------------------------------------------------
/architectures.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from torch.autograd import Variable
5 |
6 | import numpy as np
7 |
class FC_Encoder(nn.Module):
    """One fully connected layer (784 -> ``output_size``) followed by ReLU."""

    def __init__(self, output_size):
        super().__init__()
        self.fc1 = nn.Linear(in_features=784, out_features=output_size)

    def forward(self, x):
        """Return ``relu(fc1(x))``."""
        return torch.relu(self.fc1(x))
16 |
class FC_Decoder(nn.Module):
    """Two fully connected layers mapping an embedding to 784 sigmoid outputs.

    ``embedding_size -> 1024`` with ReLU, then ``1024 -> 784`` with sigmoid.
    """

    def __init__(self, embedding_size):
        super().__init__()
        self.fc3 = nn.Linear(in_features=embedding_size, out_features=1024)
        self.fc4 = nn.Linear(in_features=1024, out_features=784)

    def forward(self, z):
        """Return ``sigmoid(fc4(relu(fc3(z))))``."""
        hidden = torch.relu(self.fc3(z))
        logits = self.fc4(hidden)
        return torch.sigmoid(logits)
26 |
class CNN_Encoder(nn.Module):
    """Convolutional encoder mapping images to ``output_size`` features.

    Five convolution stages (16, 32, 64, 128 and 256 channels) are followed
    by a linear layer with batch normalisation and LeakyReLU.

    Args:
        output_size: number of features produced per image.
        input_size: ``(channels, height, width)`` of one input image. The
            number of input channels of the first convolution is taken from
            here.
    """

    def __init__(self, output_size, input_size=(1, 28, 28)):
        super(CNN_Encoder, self).__init__()

        self.input_size = input_size
        self.channel_mult = 16

        # convolutions
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=input_size[0],
                      out_channels=self.channel_mult*1,
                      kernel_size=4,
                      stride=1,
                      padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.channel_mult*1, self.channel_mult*2, 4, 2, 1),
            nn.BatchNorm2d(self.channel_mult*2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.channel_mult*2, self.channel_mult*4, 4, 2, 1),
            nn.BatchNorm2d(self.channel_mult*4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.channel_mult*4, self.channel_mult*8, 4, 2, 1),
            nn.BatchNorm2d(self.channel_mult*8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(self.channel_mult*8, self.channel_mult*16, 3, 2, 1),
            nn.BatchNorm2d(self.channel_mult*16),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.flat_fts = self.get_flat_fts(self.conv)

        self.linear = nn.Sequential(
            nn.Linear(self.flat_fts, output_size),
            nn.BatchNorm1d(output_size),
            nn.LeakyReLU(0.2),
        )

    def get_flat_fts(self, fts):
        """Return the flattened feature count ``fts`` yields for one input.

        The probe runs in eval mode and without gradient tracking so that the
        dummy all-ones input does not alter the BatchNorm running statistics;
        the previous train/eval mode of ``fts`` is restored afterwards.
        """
        was_training = fts.training
        fts.eval()
        try:
            with torch.no_grad():
                f = fts(torch.ones(1, *self.input_size))
        finally:
            fts.train(was_training)
        return int(np.prod(f.size()[1:]))

    def forward(self, x):
        """Reshape ``x`` to ``(-1, *input_size)`` and return its features."""
        x = self.conv(x.view(-1, *self.input_size))
        x = x.view(-1, self.flat_fts)
        return self.linear(x)
72 |
class CNN_Decoder(nn.Module):
    """Transposed-convolution decoder from an embedding to a flat 28x28 image.

    A linear layer expands the embedding to 512 features, which four
    transposed convolutions upsample to one 28x28 channel squashed by a
    sigmoid. ``input_size`` is accepted but not used; the output geometry is
    fixed at 28x28.
    """

    def __init__(self, embedding_size, input_size=(1, 28, 28)):
        super().__init__()
        self.input_height = 28
        self.input_width = 28
        self.input_dim = embedding_size
        self.channel_mult = 16
        self.output_channels = 1
        self.fc_output_dim = 512

        hidden = self.fc_output_dim
        mult = self.channel_mult

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(True),
        )

        upsampling = [
            # (hidden) x 1 x 1  ->  (mult*4) x 4 x 4
            nn.ConvTranspose2d(hidden, mult * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(mult * 4),
            nn.ReLU(True),
            # (mult*4) x 4 x 4  ->  (mult*2) x 7 x 7
            nn.ConvTranspose2d(mult * 4, mult * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(mult * 2),
            nn.ReLU(True),
            # (mult*2) x 7 x 7  ->  (mult*1) x 14 x 14
            nn.ConvTranspose2d(mult * 2, mult * 1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(mult * 1),
            nn.ReLU(True),
            # (mult*1) x 14 x 14  ->  output_channels x 28 x 28
            nn.ConvTranspose2d(mult * 1, self.output_channels, 4, 2, 1,
                               bias=False),
            nn.Sigmoid(),
        ]
        self.deconv = nn.Sequential(*upsampling)

    def forward(self, x):
        """Decode embeddings ``x`` into rows of ``28 * 28`` pixel values."""
        features = self.fc(x).view(-1, self.fc_output_dim, 1, 1)
        pixels = self.deconv(features)
        return pixels.view(-1, self.input_width * self.input_height)
116 |
--------------------------------------------------------------------------------