├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── data ├── __init__.py ├── celeba.py ├── cifar10.py └── mnist.py ├── datagen.py ├── discriminators.py ├── encoders.py ├── generators.py ├── images └── mnist │ └── .gitignore ├── ops.py ├── plot.py ├── plots ├── celeba │ ├── samples_47299.png │ └── samples_51399.png ├── cifar10 │ └── samples_199999.jpg └── mnist │ └── samples_33099.jpg ├── requirements.txt ├── results ├── cifar10 │ ├── aae_samples_199999.jpg │ ├── ae_cost.jpg │ ├── ae_samples_80699.jpg │ ├── dev_disc_cost.jpg │ ├── disc_cost.jpg │ ├── gen_cost.jpg │ └── w1_distance.jpg └── mnist │ ├── ae_cost.jpg │ ├── ae_samples_10799.jpg │ ├── dev_disc_cost.jpg │ ├── disc_cost.jpg │ ├── gen_cost.jpg │ └── samples_37499.jpg ├── spectral_normalization.py ├── start.sh ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pkl 3 | *.gz 4 | *.bz2 5 | *.tar 6 | *.pt 7 | *.jpg 8 | models/ 9 | 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Neale Ratzlaff 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial-Autoencoder 2 | A convolutional adversarial autoencoder implementation in PyTorch using the WGAN with gradient penalty (WGAN-GP) framework. 3 | 4 | There's a lot to tweak here as far as balancing the adversarial vs. reconstruction loss, but this works and I'll update it as I go along. 5 | 6 | The MNIST GAN seems to converge at around 30K steps, while CIFAR10 arguably never outputs anything realistic (compared to ACGAN). Nonetheless, it starts to look OK at around 50K steps. 7 | 8 | The autoencoder components produce good reconstructions much faster than the GAN: ~10K steps on MNIST. The autoencoder currently performs poorly on CIFAR10 (under investigation). 9 | 10 | 11 | # Note 12 | There is a lot here that I want to add and fix with regard to image generation and large-scale training.
13 | 14 | However, I can't do anything until PyTorch fixes [this gradient penalty issue](https://github.com/pytorch/pytorch/issues/19024). 15 | 16 | 17 | ## MNIST Gaussian Samples (GAN) - 33k steps 18 | 19 | ![output image](plots/mnist/samples_33099.jpg) 20 | 21 | ## MNIST Reconstructions (AE) - 10k steps 22 | 23 | ![output_image](results/mnist/ae_samples_10799.jpg) 24 | 25 | ## CelebA 64x64 Gaussian Samples (GAN) - 50k steps 26 | 27 | ![output_image](plots/celeba/samples_51399.png) 28 | 29 | ## CIFAR10 Gaussian Samples (GAN) - 200k steps 30 | 31 | ![output image](plots/cifar10/samples_199999.jpg) 32 | 33 | ## CIFAR10 Reconstructions (AE) - 80k steps 34 | 35 | ![output image](results/cifar10/ae_samples_80699.jpg) 36 | 37 | These CIFAR10 reconstructions clearly still need fixing. 38 | 39 | ### Requirements 40 | 41 | * pytorch 0.2.0 42 | * python 3 (2.7 requires only some simple modifications) 43 | * matplotlib / numpy / scipy 44 | 45 | ### Usage 46 | 47 | To start training right away, just run 48 | 49 | `start.sh` 50 | 51 | 52 | To train on MNIST 53 | 54 | `python3 train.py --dataset mnist --batch_size 50 --dim 32 -o 784` 55 | 56 | To train on CIFAR10 57 | 58 | `python3 train.py --dataset cifar10 --batch_size 64 --dim 32 -o 3072` 59 | 60 | ### Acknowledgements 61 | 62 | For the WGAN-GP components I mostly used [caogang's](https://github.com/caogang/wgan-gp) nice implementation. 63 | 64 | 65 | ### TODO 66 | 67 | Provide pretrained model files in a Google Drive folder (push loader).
68 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import locale 3 | 4 | locale.setlocale(locale.LC_ALL, '') 5 | 6 | _params = {} 7 | _param_aliases = {} 8 | 9 | def params_with_name(name): 10 | return [p for n,p in _params.items() if name in n] 11 | 12 | def delete_all_params(): 13 | _params.clear() 14 | 15 | def alias_params(replace_dict): 16 | for old,new in replace_dict.items(): 17 | _param_aliases[old] = new 18 | 19 | def delete_param_aliases(): 20 | _param_aliases.clear() 21 | 22 | def print_model_settings(locals_): 23 | all_vars = [(k,v) for (k,v) in locals_.items() if (k.isupper() and k!='T' and k!='SETTINGS' and k!='ALL_SETTINGS')] 24 | all_vars = sorted(all_vars, key=lambda x: x[0]) 25 | for var_name, var_value in all_vars: 26 | print "\t{}: {}".format(var_name, var_value) 27 | 28 | 29 | def print_model_settings_dict(settings): 30 | all_vars = [(k,v) for (k,v) in settings.items()] 31 | all_vars = sorted(all_vars, key=lambda x: x[0]) 32 | for var_name, var_value in all_vars: 33 | print "\t{}: {}".format(var_name, var_value) 34 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/data/__init__.py -------------------------------------------------------------------------------- /data/celeba.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gzip 3 | import urllib 4 | import numpy as np 5 | import _pickle as pickle 6 | import glob 7 | from scipy.misc import imread 8 | 9 | def celeba_generator(batch_size, data_dir): 10 | all_data = [] 11 | 12 | paths = glob.glob(data_dir+'*.jpg') 13 | for fn in paths: 14 | 
all_data.append(imread(fn)) 15 | images = np.concatenate(all_data, axis=0) 16 | 17 | def get_epoch(): 18 | rng_state = np.random.get_state() 19 | np.random.shuffle(images) 20 | np.random.set_state(rng_state) 21 | 22 | for i in range(int(len(images) / batch_size)): 23 | yield np.copy(images[i*batch_size:(i+1)*batch_size]) 24 | 25 | return get_epoch 26 | 27 | 28 | def load(batch_size, data_dir): 29 | return celeba_generator(batch_size, data_dir) 30 | -------------------------------------------------------------------------------- /data/cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gzip 3 | import urllib 4 | import numpy as np 5 | import _pickle as pickle 6 | 7 | def unpickle(file): 8 | fo = open(file, 'rb') 9 | dict = pickle.load(fo, encoding='iso-8859-1') 10 | fo.close() 11 | return dict['data'], dict['labels'] 12 | 13 | def cifar_generator(filenames, batch_size, data_dir): 14 | all_data = [] 15 | all_labels = [] 16 | for filename in filenames: 17 | data, labels = unpickle(data_dir + '/' + filename) 18 | all_data.append(data) 19 | all_labels.append(labels) 20 | 21 | images = np.concatenate(all_data, axis=0) 22 | labels = np.concatenate(all_labels, axis=0) 23 | 24 | def get_epoch(): 25 | rng_state = np.random.get_state() 26 | np.random.shuffle(images) 27 | np.random.set_state(rng_state) 28 | np.random.shuffle(labels) 29 | 30 | for i in range(int(len(images) / batch_size)): 31 | yield (np.copy(images[i*batch_size:(i+1)*batch_size]), 32 | labels[i*batch_size:(i+1)*batch_size]) 33 | 34 | return get_epoch 35 | 36 | 37 | def load(batch_size, data_dir): 38 | return ( 39 | cifar_generator(['data_batch_1','data_batch_2','data_batch_3','data_batch_4','data_batch_5'], batch_size, data_dir), 40 | cifar_generator(['test_batch'], batch_size, data_dir) 41 | ) 42 | -------------------------------------------------------------------------------- /data/mnist.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import gzip 3 | import urllib.request 4 | import numpy as np 5 | import _pickle as pickle 6 | 7 | filepath = 'images/mnist/mnist.pkl.gz' 8 | 9 | def mnist_generator(data, batch_size, n_labelled, limit=None): 10 | if batch_size % 25 != 0: 11 | batch_size = 50 12 | images, targets = data 13 | rng_state = np.random.get_state() 14 | np.random.shuffle(images) 15 | np.random.set_state(rng_state) 16 | np.random.shuffle(targets) 17 | if limit is not None: 18 | print ("WARNING ONLY FIRST {} MNIST DIGITS".format(limit)) 19 | images = images.astype('float32')[:limit] 20 | targets = targets.astype('int32')[:limit] 21 | if n_labelled is not None: 22 | labelled = np.zeros(len(images), dtype='int32') 23 | labelled[:n_labelled] = 1 24 | 25 | def get_epoch(): 26 | rng_state = np.random.get_state() 27 | np.random.shuffle(images) 28 | np.random.set_state(rng_state) 29 | np.random.shuffle(targets) 30 | 31 | if n_labelled is not None: 32 | np.random.set_state(rng_state) 33 | np.random.shuffle(labelled) 34 | 35 | image_batches = images.reshape(-1, batch_size, 784) 36 | target_batches = targets.reshape(-1, batch_size) 37 | 38 | if n_labelled is not None: 39 | labelled_batches = labelled.reshape(-1, batch_size) 40 | 41 | for i in range(len(image_batches)): 42 | yield (np.copy(image_batches[i]), 43 | np.copy(target_batches[i]), 44 | np.copy(labelled)) 45 | 46 | else: 47 | for i in range(len(image_batches)): 48 | yield (np.copy(image_batches[i])) 49 | return get_epoch 50 | 51 | 52 | def load(batch_size, test_batch_size, n_labelled=None): 53 | url = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz' 54 | 55 | if not os.path.isfile(filepath): 56 | print ("Couldn't find MNIST dataset in "+filepath+", downloading...") 57 | urllib.request.urlretrieve(url, filepath) 58 | 59 | with gzip.open(filepath, 'rb') as f: 60 | train_data, dev_data, test_data = pickle.load(f, 
encoding='iso-8859-1') 61 | 62 | return ( 63 | mnist_generator(train_data, batch_size, n_labelled), 64 | mnist_generator(dev_data, test_batch_size, n_labelled), 65 | mnist_generator(test_data, test_batch_size, n_labelled) 66 | ) 67 | -------------------------------------------------------------------------------- /datagen.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import utils 4 | import torch 5 | import torchvision 6 | import pickle 7 | from torchvision import datasets, transforms 8 | from scipy.misc import imread,imresize 9 | from sklearn.model_selection import train_test_split 10 | from glob import glob 11 | 12 | 13 | def load_mnist(args): 14 | torch.cuda.manual_seed(1) 15 | kwargs = {'num_workers': 1, 'pin_memory': True, 'drop_last': False} 16 | path = 'data/mnist' 17 | train_loader = torch.utils.data.DataLoader( 18 | datasets.MNIST(path, train=True, download=True, 19 | transform=transforms.Compose([ 20 | transforms.ToTensor(), 21 | transforms.Normalize((0.1307,), (0.3081,)) 22 | ])), 23 | batch_size=args.batch_size, shuffle=False, **kwargs) 24 | test_loader = torch.utils.data.DataLoader( 25 | datasets.MNIST(path, train=False, transform=transforms.Compose([ 26 | transforms.ToTensor(), 27 | transforms.Normalize((0.1307,), (0.3081,)) 28 | ])), 29 | batch_size=100, shuffle=False, **kwargs) 30 | return train_loader, test_loader 31 | 32 | 33 | def load_fashion_mnist(args): 34 | path = 'data/fashion_mnist' 35 | torch.cuda.manual_seed(1) 36 | kwargs = {'num_workers': 1, 'pin_memory': True, 'drop_last': True} 37 | train_loader = torch.utils.data.DataLoader( 38 | datasets.FashionMNIST(path, train=True, download=True, 39 | transform=transforms.Compose([ 40 | transforms.ToTensor(), 41 | transforms.Normalize((0.1307,), (0.3081,)) 42 | ])), 43 | batch_size=args.batch_size, shuffle=True, **kwargs) 44 | test_loader = torch.utils.data.DataLoader( 45 | datasets.FashionMNIST(path, train=False, 
download=True, 46 | transform=transforms.Compose([ 47 | transforms.ToTensor(), 48 | transforms.Normalize((0.1307,), (0.3081,)) 49 | ])), 50 | batch_size=100, shuffle=False, **kwargs) 51 | return train_loader, test_loader 52 | 53 | 54 | def load_cifar(args): 55 | path = 'data/cifar' 56 | kwargs = {'num_workers': 1, 'pin_memory': True, 'drop_last': True} 57 | transform_train = transforms.Compose([ 58 | transforms.RandomCrop(32, padding=4), 59 | transforms.RandomHorizontalFlip(), 60 | transforms.ToTensor(), 61 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 62 | ]) 63 | transform_test = transforms.Compose([ 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 66 | ]) 67 | trainset = torchvision.datasets.CIFAR10(root=path, train=True, 68 | download=True, transform=transform_train) 69 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, 70 | shuffle=True, **kwargs) 71 | testset = torchvision.datasets.CIFAR10(root=path, train=False, 72 | download=True, transform=transform_test) 73 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, 74 | shuffle=True, **kwargs) 75 | return trainloader, testloader 76 | 77 | 78 | def load_cifar_hidden(args, c_idx): 79 | path = 'data/cifar' 80 | kwargs = {'num_workers': 2, 'pin_memory': True, 'drop_last': True} 81 | transform_train = transforms.Compose([ 82 | transforms.RandomCrop(32, padding=4), 83 | transforms.RandomHorizontalFlip(), 84 | transforms.ToTensor(), 85 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 86 | ]) 87 | transform_test = transforms.Compose([ 88 | transforms.ToTensor(), 89 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 90 | ]) 91 | def get_classes(target, labels): 92 | label_indices = [] 93 | for i in range(len(target)): 94 | if target[i][1] in labels: 95 | label_indices.append(i) 96 | return label_indices 97 | 98 | trainset = 
torchvision.datasets.CIFAR10(root=path, train=True, 99 | download=True, transform=transform_train) 100 | train_hidden = torch.utils.data.Subset(trainset, get_classes(trainset, c_idx)) 101 | trainloader = torch.utils.data.DataLoader(train_hidden, batch_size=args.batch_size, 102 | shuffle=True, **kwargs) 103 | 104 | testset = torchvision.datasets.CIFAR10(root=path, train=False, 105 | download=True, transform=transform_test) 106 | test_hidden = torch.utils.data.Subset(testset, get_classes(testset, c_idx)) 107 | testloader = torch.utils.data.DataLoader(test_hidden, batch_size=100, 108 | shuffle=True, **kwargs) 109 | return trainloader, testloader 110 | -------------------------------------------------------------------------------- /discriminators.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | from spectral_normalization import SpectralNorm as SN 4 | 5 | 6 | class CELEBAdiscriminator(nn.Module): 7 | def __init__(self, args): 8 | super(CELEBAdiscriminator, self).__init__() 9 | self._name = 'cifarD' 10 | self.shape = (64, 64, 3) 11 | self.dim = args.dim 12 | 13 | self.conv1 = SN(nn.Conv2d(3, self.dim, 3, 1, padding=1)) 14 | self.conv2 = SN(nn.Conv2d(self.dim, self.dim, 3, 2, padding=1)) 15 | self.conv3 = SN(nn.Conv2d(self.dim, 2 * self.dim, 3, 1, padding=1)) 16 | self.conv4 = SN(nn.Conv2d(2 * self.dim, 2 * self.dim, 3, 2, padding=1)) 17 | self.conv5 = SN(nn.Conv2d(2 * self.dim, 4 * self.dim, 3, 1, padding=1)) 18 | self.conv6 = SN(nn.Conv2d(4 * self.dim, 4 * self.dim, 3, 2, padding=1)) 19 | self.linear = SN(nn.Linear(4*4*4*self.dim, 1)) 20 | 21 | def forward(self, input): 22 | input = input.view(-1, 3, 64, 64) 23 | x = F.leaky_relu(self.conv1(input)) 24 | x = F.leaky_relu(self.conv2(x)) 25 | x = F.leaky_relu(self.conv3(x)) 26 | x = F.leaky_relu(self.conv4(x)) 27 | x = F.leaky_relu(self.conv5(x)) 28 | x = F.leaky_relu(self.conv6(x)) 29 | output = x.view(-1, 
4*4*4*self.dim) 30 | output = self.linear(output) 31 | return output 32 | 33 | 34 | class CIFARdiscriminator(nn.Module): 35 | def __init__(self, args): 36 | super(CIFARdiscriminator, self).__init__() 37 | self._name = 'cifarD' 38 | self.shape = (32, 32, 3) 39 | self.dim = args.dim 40 | convblock = nn.Sequential( 41 | nn.Conv2d(3, self.dim, 3, 2, padding=1), 42 | nn.LeakyReLU(), 43 | nn.Conv2d(self.dim, 2 * self.dim, 3, 2, padding=1), 44 | nn.LeakyReLU(), 45 | nn.Conv2d(2 * self.dim, 4 * self.dim, 3, 2, padding=1), 46 | nn.LeakyReLU(), 47 | ) 48 | self.main = convblock 49 | self.linear = nn.Linear(4*4*4*self.dim, 1) 50 | 51 | def forward(self, input): 52 | input = input.view(-1, 3, 32, 32) 53 | output = self.main(input) 54 | output = output.view(-1, 4*4*4*self.dim) 55 | output = self.linear(output) 56 | return output 57 | 58 | 59 | class MNISTdiscriminator(nn.Module): 60 | def __init__(self, args): 61 | super(MNISTdiscriminator, self).__init__() 62 | self._name = 'mnistD' 63 | self.shape = (1, 28, 28) 64 | self.dim = args.dim 65 | convblock = nn.Sequential( 66 | nn.Conv2d(1, self.dim, 5, stride=2, padding=2), 67 | nn.Dropout(p=0.3), 68 | nn.ReLU(True), 69 | nn.Conv2d(self.dim, 2*self.dim, 5, stride=2, padding=2), 70 | nn.Dropout(p=0.3), 71 | nn.ReLU(True), 72 | nn.Conv2d(2*self.dim, 4*self.dim, 5, stride=2, padding=2), 73 | nn.Dropout(p=0.3), 74 | nn.ReLU(True), 75 | ) 76 | self.main = convblock 77 | self.output = nn.Linear(4*4*4*self.dim, 1) 78 | 79 | def forward(self, input): 80 | input = input.view(-1, 1, 28, 28) 81 | out = self.main(input) 82 | out = out.view(-1, 4*4*4*self.dim) 83 | out = self.output(out) 84 | return out.view(-1) 85 | -------------------------------------------------------------------------------- /encoders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class CELEBAencoder(nn.Module): 7 | def __init__(self, 
args): 8 | super(CELEBAencoder, self).__init__() 9 | self._name = 'celebaE' 10 | self.shape = (64, 64, 3) 11 | self.dim = args.dim 12 | convblock = nn.Sequential( 13 | nn.Conv2d(3, self.dim, 3, 2, padding=1), 14 | nn.Dropout(p=0.3), 15 | nn.LeakyReLU(), 16 | nn.Conv2d(self.dim, 2 * self.dim, 3, 2, padding=1), 17 | nn.Dropout(p=0.3), 18 | nn.LeakyReLU(), 19 | nn.Conv2d(2 * self.dim, 4 * self.dim, 3, 2, padding=1), 20 | nn.Dropout(p=0.3), 21 | nn.LeakyReLU(), 22 | nn.Conv2d(4 * self.dim, 8 * self.dim, 3, 2, padding=1), 23 | nn.Dropout(p=0.3), 24 | nn.LeakyReLU(), 25 | nn.Conv2d(8 * self.dim, 16 * self.dim, 3, 2, padding=1), 26 | nn.Dropout(p=0.3), 27 | nn.LeakyReLU(), 28 | ) 29 | self.main = convblock 30 | self.linear = nn.Linear(4*4*4*self.dim, self.dim) 31 | 32 | def forward(self, input): 33 | input = input.view(-1, 3, 64, 64) 34 | output = self.main(input) 35 | output = output.view(-1, 4*4*4*self.dim) 36 | output = self.linear(output) 37 | return output.view(-1, self.dim) 38 | 39 | 40 | class CIFARencoder(nn.Module): 41 | def __init__(self, args): 42 | super(CIFARencoder, self).__init__() 43 | self._name = 'cifarE' 44 | self.shape = (32, 32, 3) 45 | self.dim = args.dim 46 | convblock = nn.Sequential( 47 | nn.Conv2d(3, self.dim, 3, 2, padding=1), 48 | nn.Dropout(p=0.3), 49 | nn.LeakyReLU(), 50 | nn.Conv2d(self.dim, 2 * self.dim, 3, 2, padding=1), 51 | nn.Dropout(p=0.3), 52 | nn.LeakyReLU(), 53 | nn.Conv2d(2 * self.dim, 4 * self.dim, 3, 2, padding=1), 54 | nn.Dropout(p=0.3), 55 | nn.LeakyReLU(), 56 | ) 57 | self.main = convblock 58 | self.linear = nn.Linear(4*4*4*self.dim, self.dim) 59 | 60 | def forward(self, input): 61 | input = input.view(-1, 3, 32, 32) 62 | output = self.main(input) 63 | output = output.view(-1, 4*4*4*self.dim) 64 | output = self.linear(output) 65 | return output.view(-1, self.dim) 66 | 67 | 68 | class MNISTencoder(nn.Module): 69 | def __init__(self, args): 70 | super(MNISTencoder, self).__init__() 71 | self._name = 'mnistE' 72 | self.shape = 
(1, 28, 28) 73 | self.dim = args.dim 74 | convblock = nn.Sequential( 75 | nn.Conv2d(1, self.dim, 5, stride=2, padding=2), 76 | nn.Dropout(p=0.3), 77 | nn.ReLU(True), 78 | nn.Conv2d(self.dim, 2*self.dim, 5, stride=2, padding=2), 79 | nn.Dropout(p=0.3), 80 | nn.ReLU(True), 81 | nn.Conv2d(2*self.dim, 4*self.dim, 5, stride=2, padding=2), 82 | nn.Dropout(p=0.3), 83 | nn.ReLU(True), 84 | ) 85 | self.main = convblock 86 | self.output = nn.Linear(4*4*4*self.dim, self.dim) 87 | 88 | def forward(self, input): 89 | input = input.view(-1, 1, 28, 28) 90 | out = self.main(input) 91 | out = out.view(-1, 4*4*4*self.dim) 92 | out = self.output(out) 93 | return out.view(-1, self.dim) 94 | -------------------------------------------------------------------------------- /generators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class CELEBAgenerator(nn.Module): 7 | def __init__(self, args): 8 | super(CELEBAgenerator, self).__init__() 9 | self._name = 'celebaG' 10 | self.shape = (64, 64, 3) 11 | self.dim = args.dim 12 | preprocess = nn.Sequential( 13 | nn.Linear(self.dim, 2* 4 * 4 * 4 * self.dim), 14 | nn.BatchNorm1d(2 * 4 * 4 * 4 * self.dim), 15 | nn.ReLU(True), 16 | ) 17 | block1 = nn.Sequential( 18 | nn.ConvTranspose2d(8 * self.dim, 4 * self.dim, 2, stride=2), 19 | nn.BatchNorm2d(4 * self.dim), 20 | nn.ReLU(True), 21 | ) 22 | block2 = nn.Sequential( 23 | nn.ConvTranspose2d(4 * self.dim, 2 * self.dim, 2, stride=2), 24 | nn.BatchNorm2d(2 * self.dim), 25 | nn.ReLU(True), 26 | ) 27 | block3 = nn.Sequential( 28 | nn.ConvTranspose2d(2 * self.dim, self.dim, 2, stride=2), 29 | nn.BatchNorm2d(self.dim), 30 | nn.ReLU(True), 31 | ) 32 | deconv_out = nn.ConvTranspose2d(self.dim, 3, 2, stride=2) 33 | 34 | self.preprocess = preprocess 35 | self.block1 = block1 36 | self.block2 = block2 37 | self.block3 = block3 38 | self.deconv_out = deconv_out 39 | self.tanh 
= nn.Tanh() 40 | 41 | def forward(self, input): 42 | output = self.preprocess(input) 43 | output = output.view(-1, 4 * 2 * self.dim, 4, 4) 44 | output = self.block1(output) 45 | output = self.block2(output) 46 | output = self.block3(output) 47 | output = self.deconv_out(output) 48 | output = self.tanh(output) 49 | output = output.view(-1, 3, 64, 64) 50 | return output 51 | 52 | 53 | class CIFARgenerator(nn.Module): 54 | def __init__(self, args): 55 | super(CIFARgenerator, self).__init__() 56 | self._name = 'cifarG' 57 | self.shape = (32, 32, 3) 58 | self.dim = args.dim 59 | preprocess = nn.Sequential( 60 | nn.Linear(self.dim, 4 * 4 * 4 * self.dim), 61 | nn.BatchNorm1d(4 * 4 * 4 * self.dim), 62 | nn.ReLU(True), 63 | ) 64 | block1 = nn.Sequential( 65 | nn.ConvTranspose2d(4 * self.dim, 2 * self.dim, 2, stride=2), 66 | nn.BatchNorm2d(2 * self.dim), 67 | nn.ReLU(True), 68 | ) 69 | block2 = nn.Sequential( 70 | nn.ConvTranspose2d(2 * self.dim, self.dim, 2, stride=2), 71 | nn.BatchNorm2d(self.dim), 72 | nn.ReLU(True), 73 | ) 74 | deconv_out = nn.ConvTranspose2d(self.dim, 3, 2, stride=2) 75 | 76 | self.preprocess = preprocess 77 | self.block1 = block1 78 | self.block2 = block2 79 | self.deconv_out = deconv_out 80 | self.tanh = nn.Tanh() 81 | 82 | def forward(self, input): 83 | output = self.preprocess(input) 84 | output = output.view(-1, 4 * self.dim, 4, 4) 85 | output = self.block1(output) 86 | output = self.block2(output) 87 | output = self.deconv_out(output) 88 | output = self.tanh(output) 89 | return output.view(-1, 3*32*32) 90 | 91 | 92 | class MNISTgenerator(nn.Module): 93 | def __init__(self, args): 94 | super(MNISTgenerator, self).__init__() 95 | self._name = 'mnistG' 96 | self.dim = args.dim 97 | self.in_shape = int(np.sqrt(args.dim)) 98 | self.shape = (self.in_shape, self.in_shape, 1) 99 | preprocess = nn.Sequential( 100 | nn.Linear(self.dim, 4*4*4*self.dim), 101 | nn.ReLU(True), 102 | ) 103 | block1 = nn.Sequential( 104 | nn.ConvTranspose2d(4*self.dim, 
2*self.dim, 5), 105 | nn.ReLU(True), 106 | ) 107 | block2 = nn.Sequential( 108 | nn.ConvTranspose2d(2*self.dim, self.dim, 5), 109 | nn.ReLU(True), 110 | ) 111 | deconv_out = nn.ConvTranspose2d(self.dim, 1, 8, stride=2) 112 | self.block1 = block1 113 | self.block2 = block2 114 | self.deconv_out = deconv_out 115 | self.preprocess = preprocess 116 | self.sigmoid = nn.Sigmoid() 117 | 118 | def forward(self, input): 119 | output = self.preprocess(input) 120 | output = output.view(-1, 4*self.dim, 4, 4) 121 | output = self.block1(output) 122 | output = output[:, :, :7, :7] 123 | output = self.block2(output) 124 | output = self.deconv_out(output) 125 | output = self.sigmoid(output) 126 | return output.view(-1, 784) 127 | -------------------------------------------------------------------------------- /images/mnist/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/images/mnist/.gitignore -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import scipy.misc 4 | import torch.autograd as autograd 5 | from scipy.misc import imsave 6 | 7 | 8 | 9 | def calc_gradient_penalty(args, model, real_data, gen_data): 10 | datashape = model.shape 11 | alpha = torch.rand(args.batch_size, 1) 12 | real_data = real_data.view(args.batch_size, -1) 13 | if args.dataset == 'mnist': 14 | alpha = alpha.expand(real_data.size()).cuda() 15 | else: 16 | alpha = alpha.expand(args.batch_size, real_data.nelement()//args.batch_size) 17 | alpha = alpha.contiguous().view(args.batch_size, -1).cuda() 18 | interpolates = alpha * real_data + ((1 - alpha) * gen_data) 19 | interpolates = interpolates.cuda() 20 | interpolates = autograd.Variable(interpolates, requires_grad=True) 21 | disc_interpolates = 
model(interpolates) 22 | gradients = autograd.grad(outputs=disc_interpolates, 23 | inputs=interpolates, 24 | grad_outputs=torch.ones(disc_interpolates.size()).cuda(), 25 | create_graph=True, 26 | retain_graph=True, 27 | only_inputs=True)[0] 28 | 29 | if args.dataset != 'mnist': 30 | gradients = gradients.view(gradients.size(0), -1) 31 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * args.gp 32 | return gradient_penalty 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import collections 4 | import time 5 | import _pickle as pickle 6 | 7 | _since_beginning = collections.defaultdict(lambda: {}) 8 | _since_last_flush = collections.defaultdict(lambda: {}) 9 | _iter = [0] 10 | 11 | 12 | def tick(): 13 | _iter[0] += 1 14 | 15 | def plot(d, name, value): 16 | _since_last_flush[d+name][_iter[0]] = value 17 | 18 | def flush(): 19 | prints = [] 20 | 21 | for name, vals in _since_last_flush.items(): 22 | prints.append("{}\t{}".format(name, 23 | np.mean(list(vals.values())))) 24 | _since_beginning[name].update(vals) 25 | 26 | x_vals = np.sort(list(_since_beginning[name].keys())) 27 | y_vals = [_since_beginning[name][x] for x in x_vals] 28 | 29 | plt.clf() 30 | plt.plot(x_vals, y_vals) 31 | plt.xlabel('iteration') 32 | plt.ylabel(name) 33 | # plt.savefig(name.replace(' ', '_')+'.jpg') 34 | 35 | print ("iter {}\t{}".format(_iter[0], "\t".join(prints))) 36 | _since_last_flush.clear() 37 | 38 | with open('log.pkl', 'wb') as f: 39 | pickle.dump(dict(_since_beginning), f, 3) 40 | -------------------------------------------------------------------------------- /plots/celeba/samples_47299.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/plots/celeba/samples_47299.png -------------------------------------------------------------------------------- /plots/celeba/samples_51399.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/plots/celeba/samples_51399.png -------------------------------------------------------------------------------- /plots/cifar10/samples_199999.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/plots/cifar10/samples_199999.jpg -------------------------------------------------------------------------------- /plots/mnist/samples_33099.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/plots/mnist/samples_33099.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | matplotlib 4 | numpy 5 | scipy 6 | -------------------------------------------------------------------------------- /results/cifar10/aae_samples_199999.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/results/cifar10/aae_samples_199999.jpg -------------------------------------------------------------------------------- /results/cifar10/ae_cost.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/results/cifar10/ae_cost.jpg -------------------------------------------------------------------------------- /results/cifar10/ae_samples_80699.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/results/cifar10/ae_samples_80699.jpg -------------------------------------------------------------------------------- /results/cifar10/dev_disc_cost.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/results/cifar10/dev_disc_cost.jpg -------------------------------------------------------------------------------- /results/cifar10/disc_cost.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/results/cifar10/disc_cost.jpg -------------------------------------------------------------------------------- /results/cifar10/gen_cost.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/results/cifar10/gen_cost.jpg -------------------------------------------------------------------------------- /results/cifar10/w1_distance.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/results/cifar10/w1_distance.jpg -------------------------------------------------------------------------------- /results/mnist/ae_cost.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/results/mnist/ae_cost.jpg -------------------------------------------------------------------------------- /results/mnist/ae_samples_10799.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/results/mnist/ae_samples_10799.jpg -------------------------------------------------------------------------------- /results/mnist/dev_disc_cost.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/results/mnist/dev_disc_cost.jpg -------------------------------------------------------------------------------- /results/mnist/disc_cost.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/results/mnist/disc_cost.jpg -------------------------------------------------------------------------------- /results/mnist/gen_cost.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/results/mnist/gen_cost.jpg -------------------------------------------------------------------------------- /results/mnist/samples_37499.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neale/Adversarial-Autoencoder/e1d5dbcc293862aabfcaecb26e8553f0710c04a0/results/mnist/samples_37499.jpg -------------------------------------------------------------------------------- /spectral_normalization.py: -------------------------------------------------------------------------------- 1 | 
##########################
## Mostly copied from github.com/christiancosgrove/pytorch-spectral-normalization-gan
##########################

import torch
from torch.optim.optimizer import Optimizer, required

from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from torch.nn import Parameter


def l2normalize(v, eps=1e-12):
    """Return ``v`` scaled to unit L2 norm; ``eps`` guards against division by zero."""
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    """Wrap ``module`` so its ``name`` weight is divided by an estimate of its spectral norm.

    The raw weight is re-registered on the wrapped module as ``<name>_bar``,
    together with power-iteration vectors ``<name>_u`` / ``<name>_v``
    (``requires_grad=False``). Every forward pass refines u/v
    ``power_iterations`` times, sets ``module.<name> = <name>_bar / sigma`` and
    then delegates to ``module.forward``.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Refine u/v by power iteration and install the normalized weight."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            # Power iteration on the weight flattened to (height, -1); done on
            # .data so the vector updates are not tracked by autograd.
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # sigma = u^T W v; computed on w (not w.data) so gradients flow through
        # the normalization into <name>_bar.
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True if <name>_u, <name>_v and <name>_bar are already registered."""
        return all(hasattr(self.module, self.name + suffix)
                   for suffix in ("_u", "_v", "_bar"))

    def _make_params(self):
        """Register random unit u/v vectors and <name>_bar; drop the original weight."""
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # Remove the original parameter so _update_u_v can install a plain
        # (normalized) tensor attribute under the same name.
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)


# --------------------------------------------------------------------------------
# /start.sh:
# --------------------------------------------------------------------------------
# #!/bin/bash
# pip install -r requirements.txt
# python3 train.py --dataset mnist --batch_size 50 --dim 32 --o 784

# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
import os
import sys
import time
import argparse
import numpy as np
# NOTE: the former `from scipy.misc import imshow` was unused here and raises
# ImportError on SciPy >= 1.2, so it has been dropped.

import torch
import torchvision
from torch import nn
from torch import autograd
from torch import optim
from torch.nn import functional as F

import ops
import plot
import utils
import datagen
import encoders
import generators
import discriminators
from data import mnist
from data import cifar10


def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no', '1'/'0'."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % (value,))


def load_args():
    """Parse the command-line options for AAE / WGAN-GP training."""
    parser = argparse.ArgumentParser(description='aae-wgan')
    parser.add_argument('-d', '--dim', default=100, type=int, help='latent space size')
    parser.add_argument('-l', '--gp', default=10, type=int, help='gradient penalty')
    parser.add_argument('-b', '--batch_size', default=64, type=int)
    parser.add_argument('-e', '--epochs', default=200, type=int)
    parser.add_argument('-o', '--output_dim', default=4096, type=int)
    parser.add_argument('--dataset', default='celeba')
    # Without a `type`, `--use_spectral_norm False` produced the truthy string 'False'.
    parser.add_argument('--use_spectral_norm', default=True, type=_str2bool)
    args = parser.parse_args()
    return args


def load_models(args):
    """Build (netG, netD, netE) for ``args.dataset`` and move them to the GPU.

    Raises:
        ValueError: if ``args.dataset`` has no matching model family.
    """
    if args.dataset in ['mnist', 'fmnist']:
        netG = generators.MNISTgenerator(args).cuda()
        netD = discriminators.MNISTdiscriminator(args).cuda()
        netE = encoders.MNISTencoder(args).cuda()
    elif args.dataset in ['cifar', 'cifar_hidden']:
        netG = generators.CIFARgenerator(args).cuda()
        netD = discriminators.CIFARdiscriminator(args).cuda()
        netE = encoders.CIFARencoder(args).cuda()
    elif args.dataset == 'celeba':
        netG = generators.CELEBAgenerator(args).cuda()
        netD = discriminators.CELEBAdiscriminator(args).cuda()
        netE = encoders.CELEBAencoder(args).cuda()
    else:
        raise ValueError('No models defined for dataset {!r}'.format(args.dataset))

    print (netG, netD, netE)
    return (netG, netD, netE)


def load_data(args):
    """Return the ``(train_gen, test_gen)`` loaders for ``args.dataset``.

    Raises:
        ValueError: if no loader is wired up for ``args.dataset``. Note that
            'celeba' (the CLI default) currently has no branch here.
    """
    if args.dataset == 'mnist':
        return datagen.load_mnist(args)
    if args.dataset == 'cifar':
        return datagen.load_cifar(args)
    if args.dataset == 'fmnist':
        return datagen.load_fashion_mnist(args)
    if args.dataset == 'cifar_hidden':
        class_list = [0]  ## just load class 0
        return datagen.load_cifar_hidden(args, class_list)
    # Previously this printed and returned None, which then failed with an
    # opaque "cannot unpack" TypeError in train().
    raise ValueError('Dataset not specified correctly: {!r}; choose --dataset from '
                     'mnist, fmnist, cifar, cifar_hidden'.format(args.dataset))


def train():
    """Train the adversarial autoencoder with a WGAN-GP critic.

    For every batch: one autoencoder (E + G) reconstruction step, five critic
    (D) steps on that same batch, and one generator adversarial step. Then log
    the costs. Every 100 iterations, evaluate the critic on ``test_gen`` and
    write samples. Checkpoint G and D when ``iteration % 100 == 0``.
    """
    args = load_args()
    train_gen, test_gen = load_data(args)
    torch.manual_seed(1)
    netG, netD, netE = load_models(args)

    if args.use_spectral_norm:
        # Presumably netD wraps layers in SpectralNorm, whose _u/_v parameters
        # have requires_grad=False; the filter keeps those out of the optimizer.
        optimizerD = optim.Adam(filter(lambda p: p.requires_grad,
                                       netD.parameters()), lr=2e-4, betas=(0.0, 0.9))
    else:
        optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=2e-4, betas=(0.5, 0.9))

    schedulerD = optim.lr_scheduler.ExponentialLR(optimizerD, gamma=0.99)
    schedulerG = optim.lr_scheduler.ExponentialLR(optimizerG, gamma=0.99)
    schedulerE = optim.lr_scheduler.ExponentialLR(optimizerE, gamma=0.99)

    ae_criterion = nn.MSELoss()
    save_dir = './plots/'+args.dataset
    model_dir = 'models/{}'.format(args.dataset)
    # 'models/' is git-ignored and e.g. 'plots/cifar' is not in the repo, so
    # torch.save / image writes would otherwise fail on a missing directory.
    for directory in (save_dir, model_dir):
        if not os.path.isdir(directory):
            os.makedirs(directory)

    iteration = 0
    for epoch in range(args.epochs):
        for data, targets in train_gen:
            # Use the actual batch size so a short final batch does not break
            # view() or produce noise of a mismatched size.
            batch_size = data.size(0)

            # --- Autoencoder (E + G) reconstruction step; critic frozen. ---
            for p in netD.parameters():
                p.requires_grad = False
            netG.zero_grad()
            netE.zero_grad()
            real_data_v = autograd.Variable(data).cuda()
            real_data_v = real_data_v.view(batch_size, -1)
            encoding = netE(real_data_v)
            fake = netG(encoding)
            ae_loss = ae_criterion(fake, real_data_v)
            # Plain backward(): passing a shape-[1] gradient to a 0-dim loss
            # raises on PyTorch >= 0.4.
            ae_loss.backward()
            optimizerE.step()
            optimizerG.step()

            # --- Critic (D) steps: WGAN-GP, reusing this batch's real data. ---
            for p in netD.parameters():
                p.requires_grad = True
            real_data_v = autograd.Variable(data).cuda()
            for _ in range(5):
                netD.zero_grad()
                # train with real data: maximize D(real) == minimize -D(real)
                D_real = netD(real_data_v).mean()
                (-D_real).backward()
                # train with fake data; G's output is detached (.data) so only
                # D receives gradients here.
                noise = torch.randn(batch_size, args.dim).cuda()
                noisev = autograd.Variable(noise, volatile=True)
                fake = autograd.Variable(netG(noisev).data)
                D_fake = netD(fake).mean()
                D_fake.backward()

                # train with gradient penalty
                gradient_penalty = ops.calc_gradient_penalty(args,
                    netD, real_data_v.data, fake.data)
                gradient_penalty.backward()

                D_cost = D_fake - D_real + gradient_penalty
                Wasserstein_D = D_real - D_fake
                optimizerD.step()

            # --- Generator (G) adversarial step. ---
            # Freeze D so this backward does not fill D's grads. Clear the
            # reconstruction gradients still held by netG; without zero_grad
            # the AE gradient was applied a second time by optimizerG here.
            for p in netD.parameters():
                p.requires_grad = False
            netG.zero_grad()
            noise = torch.randn(batch_size, args.dim).cuda()
            noisev = autograd.Variable(noise)
            fake = netG(noisev)
            G = netD(fake).mean()
            (-G).backward()
            G_cost = -G
            optimizerG.step()

            # Write logs and save samples
            plot.plot(save_dir, '/disc cost', D_cost.cpu().data.numpy())
            plot.plot(save_dir, '/gen cost', G_cost.cpu().data.numpy())
            plot.plot(save_dir, '/w1 distance', Wasserstein_D.cpu().data.numpy())
            plot.plot(save_dir, '/ae cost', ae_loss.data.cpu().numpy())

            # Calculate dev loss and generate samples every 100 iters
            if iteration % 100 == 99:
                dev_disc_costs = []
                for images, _ in test_gen:
                    imgs_v = autograd.Variable(images, volatile=True).cuda()
                    D = netD(imgs_v)
                    _dev_disc_cost = -D.mean().cpu().data.numpy()
                    dev_disc_costs.append(_dev_disc_cost)
                plot.plot(save_dir, '/dev disc cost', np.mean(dev_disc_costs))
                utils.generate_image(iteration, netG, save_dir, args)
                # utils.generate_ae_image(iteration, netE, netG, save_dir, args, real_data_v)

            # Save logs every 100 iters
            if (iteration < 5) or (iteration % 100 == 99):
                plot.flush()
            plot.tick()
            if iteration % 100 == 0:
                utils.save_model(netG, optimizerG, iteration,
                                 'models/{}/G_{}'.format(args.dataset, iteration))
                utils.save_model(netD, optimizerD, iteration,
                                 'models/{}/D_{}'.format(args.dataset, iteration))
            iteration += 1

        # Step once per epoch: ExponentialLR multiplies the LR by gamma on
        # every step, so stepping per batch would drive it to ~4e-5x its
        # initial value within 1000 iterations.
        schedulerD.step()
        schedulerG.step()
        schedulerE.step()


if __name__ == '__main__':
    train()

# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
import torch
import torch.autograd as autograd
import numpy as np
from data import mnist
from data import cifar10
from data import celeba
import matplotlib.pyplot as plt
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import PIL.Image as Image

try:
    from scipy.misc import imsave
except ImportError:
    # scipy.misc.imsave (and bytescale) were removed in SciPy 1.2. This
    # replica keeps the old behaviour: non-uint8 arrays are min/max scaled to
    # [0, 255], then saved with PIL.
    def imsave(name, arr):
        data = np.asarray(arr)
        if data.dtype != np.uint8:
            cmin, cmax = data.min(), data.max()
            cscale = cmax - cmin
            if cscale == 0:
                cscale = 1
            scaled = (data - cmin) * (255.0 / cscale)
            data = (scaled.clip(0, 255) + 0.5).astype(np.uint8)
        Image.fromarray(data).save(name)


def save_model(net, optim, epoch, path):
    """Save ``net`` and ``optim`` state dicts to ``path``, recording ``epoch + 1``."""
    state_dict = net.state_dict()
    torch.save({
        'epoch': epoch + 1,
        'state_dict': state_dict,
        'optimizer': optim.state_dict(),
        }, path)


def generate_ae_image(iter, netE, netG, save_path, args, real_data):
    """Save a grid of autoencoder reconstructions of ``real_data`` to ``save_path``."""
    batch_size = args.batch_size
    datashape = netE.shape
    encoding = netE(real_data)
    samples = netG(encoding)
    if netG._name == 'mnistG':
        samples = samples.view(batch_size, 28, 28)
    else:
        samples = samples.view(-1, *(datashape[::-1]))
        # map generator output from [-1, 1] to [0, 1]
        samples = samples.mul(0.5).add(0.5)
    samples = samples.cpu().data.numpy()
    save_images(samples, save_path+'/ae_samples_{}.jpg'.format(iter))


def generate_image(iter, model, save_path, args):
    """Sample fresh Gaussian noise, run ``model`` on it and save the image grid."""
    batch_size = args.batch_size
    datashape = model.shape
    if model._name == 'mnistG':
        fixed_noise_128 = torch.randn(batch_size, args.dim).cuda()
    else:
        fixed_noise_128 = torch.randn(128, args.dim).cuda()
    noisev = autograd.Variable(fixed_noise_128, volatile=True)
    samples = model(noisev)
    if model._name == 'mnistG':
        samples = samples.view(batch_size, 28, 28)
    else:
        samples = samples.view(-1, *(datashape[::-1]))
        # map generator output from [-1, 1] to [0, 1]
        samples = samples.mul(0.5).add(0.5)
    samples = samples.cpu().data.numpy()
    save_images(samples, save_path+'/samples_{}.jpg'.format(iter))


def save_images(X, save_path, use_np=False):
    """Tile a batch of images into one grid, display it, and save it.

    Args:
        X: numpy array of images, shaped (N, H*W) for flattened square images,
            (N, H, W) for grayscale, or (N, C, H, W) for channel-first.
            Floating arrays are treated as [0, 1] and scaled to [0, 255]
            (unless ``use_np``).
        save_path: destination file path.
        use_np: save the raw grid with ``np.save`` instead of as an image.

    Raises:
        ValueError: if ``X`` is empty or has an unsupported number of dims.
    """
    # [0, 1] -> [0,255]
    plt.ion()
    if not use_np:
        if isinstance(X.flatten()[0], np.floating):
            X = (255.99*X).astype('uint8')
    n_samples = X.shape[0]
    if n_samples == 0:
        raise ValueError('save_images needs at least one image')
    # Largest row count <= sqrt(n) that divides n exactly.
    rows = int(np.sqrt(n_samples))
    while n_samples % rows != 0:
        rows -= 1
    nh, nw = rows, n_samples // rows
    if X.ndim == 2:
        s = int(np.sqrt(X.shape[1]))
        X = np.reshape(X, (X.shape[0], s, s))
    if X.ndim == 4:
        X = X.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        h, w = X[0].shape[:2]
        img = np.zeros((h*nh, w*nw, 3))
    elif X.ndim == 3:
        h, w = X[0].shape[:2]
        img = np.zeros((h*nh, w*nw))
    else:
        raise ValueError('save_images expects 2-, 3- or 4-D input, got shape {}'.format(X.shape))
    for n, x in enumerate(X):
        j = n // nw
        i = n % nw
        img[j*h:j*h+h, i*w:i*w+w] = x

    plt.imshow(img, cmap='gray')
    plt.draw()
    plt.pause(0.001)

    if use_np:
        np.save(save_path, img)
    else:
        imsave(save_path, img)

# --------------------------------------------------------------------------------