├── .gitignore ├── utils.py ├── README.md ├── datasets.py ├── autoencoder.py ├── cnn.py ├── vae.py └── LearnDataAugCIFAR.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | *pyc 3 | *~ 4 | *._ 5 | .ipynb_checkpoints 6 | __pycache__ 7 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torchvision 7 | 8 | 9 | class CenterCrop(nn.Module): 10 | def __init__(self, height, width): 11 | super().__init__() 12 | self.height = height 13 | self.width = width 14 | 15 | def forward(self, img): 16 | bs, c, h, w = img.size() 17 | xy1 = (w - self.width) // 2 18 | xy2 = (h - self.height) // 2 19 | img = img[:, :, xy2:(xy2 + self.height), xy1:(xy1 + self.width)] 20 | return img 21 | 22 | def plot_tensor(img, fs=(10,10), title=""): 23 | if len(img.size()) == 4: 24 | img = img.squeeze(dim=0) 25 | npimg = img.numpy() 26 | plt.figure(figsize=fs) 27 | plt.imshow(np.transpose(npimg, (1, 2, 0)), cmap='gray') 28 | plt.title(title) 29 | plt.show() 30 | 31 | def plot_batch(samples, title="", fs=(10,10)): 32 | plot_tensor(torchvision.utils.make_grid(samples), fs=fs, title=title) 33 | 34 | def plot_metric(trn, tst, title): 35 | plt.plot(np.stack([trn, tst], 1)); 36 | plt.title(title) 37 | plt.show() 38 | 39 | def get_accuracy(preds, targets): 40 | correct = np.sum(preds==targets) 41 | return correct / len(targets) 42 | 43 | def get_argmax(output): 44 | val,idx = torch.max(output, dim=1) 45 | return idx.data.cpu().view(-1).numpy() 46 | 47 | def predict_batch(net, inputs): 48 | v = Variable(inputs.cuda(), volatile=True) 49 | return net(v).data.cpu().numpy() 50 | 51 | def get_probabilities(model, loader): 52 | model.eval() 53 | return np.vstack(predict_batch(model, data[0]) for data in loader) 54 | 55 | def get_predictions(probs, thresholds): 56 | preds = np.copy(probs) 57 | preds[preds >= thresholds] = 1 58 | preds[preds < thresholds] = 0 59 | return preds.astype('uint8') 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Data Augmentation 2 | OpenAI Request for Research - https://blog.openai.com/requests-for-research-2/ 3 | 4 | ## Requirements 5 | 6 | * Python 3 7 | * Anaconda (Numpy, Matplotlib, PIL, etc) 8 | * PyTorch 9 | * GPU, CUDA 10 | 11 | Tested on Ubuntu 16.04 with GeForce GTX 1080 Ti 12 | 13 | ## Papers 14 | 15 | * [RenderGAN: Generating Realistic Labeled Data](https://arxiv.org/abs/1611.01331) 16 | * [Dataset Augmentation in Feature Space](https://arxiv.org/abs/1702.05538) 17 | * [Learning to Compose Domain-Specific Transformations](https://arxiv.org/abs/1709.01643) 18 | * [Data Augmentation Generative Adversarial Networks](https://arxiv.org/abs/1711.04340) 19 | * [Data Augmentation in Emotion Classification Using GANs](https://arxiv.org/abs/1711.00648) 20 | * [Neural Augmentation and Cycle GAN](http://cs231n.stanford.edu/reports/2017/pdfs/300.pdf) 21 | * [Bayesian Data Augmentation Approach](https://arxiv.org/abs/1710.10564) 22 | * [Learning To Model the Tail](https://papers.nips.cc/paper/7278-learning-to-model-the-tail) 23 | * [Tutorial on Variational Autoencoders](https://arxiv.org/abs/1606.05908) 24 | 25 | 26 | ## Approaches 27 | 1. Autoencoder: Encode --> Peturb/Interpolate --> Decode 28 | 2. Synthetic Dataset Generation with CGAN 29 | 3. Domain Adaption / Image-to-Image translation (CycleGAN) 30 | 4. Augmentation / Preprocessing Net + backprop classification loss 31 | 32 | 33 | ## Posts 34 | 35 | ### Autoencoder 36 | 37 | * https://pgaleone.eu/neural-networks/2016/11/24/convolutional-autoencoders/ 38 | 39 | ### VAE 40 | 41 | * http://kvfrans.com/variational-autoencoders-explained/ 42 | * https://wiseodd.github.io/techblog/2016/12/10/variational-autoencoder/ 43 | * https://wiseodd.github.io/techblog/2016/12/17/conditional-vae/ 44 | 45 | ### GAN 46 | 47 | * http://blog.otoro.net/2016/04/01/generating-large-images-from-latent-vectors/ 48 | 49 | 50 | ## Code 51 | 52 | ### Autoencoder 53 | 54 | * https://blog.keras.io/building-autoencoders-in-keras.html 55 | 56 | ### VAE 57 | 58 | * https://github.com/kvfrans/variational-autoencoder 59 | * https://github.com/wiseodd/generative-models/blob/master/VAE/vanilla_vae/vae_pytorch.py 60 | * https://github.com/wiseodd/generative-models/blob/master/VAE/conditional_vae/cvae_pytorch.py 61 | 62 | ## Links 63 | * https://blog.openai.com/requests-for-research-2/ 64 | * http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html#4d4e495354 65 | * https://paper.dropbox.com/doc/VAE-QXVhUneHWqtvYELdjpR9d (my notes) 66 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | import torchvision 5 | from torchvision import transforms 6 | 7 | 8 | CIFAR_CLASSES = ( 9 | 'plane', 'car', 'bird', 'cat', 'deer', 10 | 'dog', 'frog', 'horse', 'ship', 'truck' 11 | ) 12 | 13 | TRN_TRANSFORM = transforms.Compose([ 14 | transforms.RandomHorizontalFlip(), 15 | transforms.ToTensor() 16 | ]) 17 | TST_TRANSFORM = transforms.Compose([ 18 | transforms.ToTensor() 19 | ]) 20 | 21 | class LearnedTransform(): 22 | def __init__(self, model): 23 | self.model = model 24 | 25 | def __call__(self, x): 26 | x = Variable(x) 27 | return self.model.transform(x) 28 | 29 | 30 | def get_cifar_dataset(trn_size=50000, tst_size=10000, 31 | trn_transform=TRN_TRANSFORM, tst_transform=TST_TRANSFORM): 32 | trainset = torchvision.datasets.CIFAR10( 33 | root='data/', train=True, download=True, transform=trn_transform) 34 | trainset.train_data = trainset.train_data[:trn_size] 35 | trainset.train_labels = trainset.train_labels[:trn_size] 36 | 37 | testset = torchvision.datasets.CIFAR10( 38 | root='data/', train=False, download=True, transform=tst_transform) 39 | testset.test_data = testset.test_data[:tst_size] 40 | testset.test_labels = testset.test_labels[:tst_size] 41 | return trainset, testset 42 | 43 | def get_cifar_loader(trainset, testset, batch_size=64): 44 | trainloader = torch.utils.data.DataLoader( 45 | trainset, batch_size=batch_size, shuffle=True, num_workers=2) 46 | testloader = torch.utils.data.DataLoader( 47 | testset, batch_size=batch_size, shuffle=False, num_workers=2) 48 | return trainloader, testloader 49 | 50 | 51 | def get_mnist_dataset(trn_size=60000, tst_size=10000): 52 | MNIST_MEAN = np.array([0.1307,]) 53 | MNIST_STD = np.array([0.3081,]) 54 | normTransform = transforms.Normalize(MNIST_MEAN, MNIST_STD) 55 | trainTransform = transforms.Compose([ 56 | transforms.ToTensor(), 57 | #normTransform 58 | ]) 59 | testTransform = transforms.Compose([ 60 | transforms.ToTensor(), 61 | #normTransform 62 | ]) 63 | 64 | trainset = torchvision.datasets.MNIST(root='../data', train=True, 65 | download=True, transform=trainTransform) 66 | trainset.train_data = trainset.train_data[:trn_size] 67 | trainset.train_labels = trainset.train_labels[:trn_size] 68 | testset = torchvision.datasets.MNIST(root='../data', train=False, 69 | download=True, transform=testTransform) 70 | testset.test_data = testset.test_data[:tst_size] 71 | testset.test_labels = testset.test_labels[:tst_size] 72 | return trainset, testset 73 | 74 | def get_mnist_loader(trainset, testset, batch_size=128): 75 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, 76 | shuffle=True, num_workers=2) 77 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, 78 | shuffle=False, num_workers=2) 79 | 80 | return trainloader, testloader 81 | -------------------------------------------------------------------------------- /autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | import torchvision 6 | 7 | import utils 8 | 9 | 10 | class AE(nn.Module): 11 | def __init__(self, in_shape): 12 | super().__init__() 13 | c,h,w = in_shape 14 | self.encoder = nn.Sequential( 15 | nn.Linear(c*h*w, 128), 16 | nn.ReLU(), 17 | nn.Linear(128, 64), 18 | nn.ReLU(), 19 | nn.Linear(64, 12), 20 | nn.ReLU(), 21 | nn.Linear(12, 3)) 22 | self.decoder = nn.Sequential( 23 | nn.Linear(3, 12), 24 | nn.ReLU(), 25 | nn.Linear(12, 64), 26 | nn.ReLU(), 27 | nn.Linear(64, 128), 28 | nn.ReLU(), 29 | nn.Linear(128, c*h*w), 30 | nn.Sigmoid()) 31 | 32 | def forward(self, x): 33 | bs,c,h,w = x.size() 34 | x = x.view(bs, -1) 35 | x = self.encoder(x) 36 | x = self.decoder(x) 37 | x = x.view(bs, c, h, w) 38 | return x 39 | 40 | 41 | class ConvAE(nn.Module): 42 | def __init__(self, in_shape): 43 | super().__init__() 44 | c,h,w = in_shape 45 | self.encoder = nn.Sequential( 46 | nn.Conv2d(c, 16, kernel_size=3, stride=1, padding=1), # b, 16, 32, 32 47 | nn.ReLU(), 48 | nn.MaxPool2d(kernel_size=2, stride=2), # b, 16, 16, 16 49 | nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1), # b, 8, 16, 16 50 | nn.ReLU(), 51 | nn.MaxPool2d(kernel_size=2, stride=2) # b, 8, 8, 8 52 | ) 53 | self.decoder = nn.Sequential( 54 | nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2, padding=0), # 16, 17, 17 55 | nn.ReLU(), 56 | nn.ConvTranspose2d(16, c, kernel_size=3, stride=2, padding=1), # 3, 33, 33 57 | utils.CenterCrop(h, w), 58 | nn.Sigmoid() 59 | ) 60 | 61 | def forward(self, x): 62 | x = self.encoder(x) 63 | x = self.decoder(x) 64 | return x 65 | 66 | 67 | def train(model, loader, criterion, optim): 68 | model.train() 69 | total_loss = 0 70 | for img, _ in loader: 71 | inputs = Variable(img.cuda()) 72 | 73 | output = model(inputs) 74 | loss = criterion(output, inputs) 75 | 76 | optim.zero_grad() 77 | loss.backward() 78 | optim.step() 79 | 80 | total_loss += loss.data[0] 81 | 82 | mean_loss = total_loss / len(loader) 83 | return mean_loss 84 | 85 | def predict(model, img): 86 | model.eval() 87 | if len(img.size()) == 3: 88 | c,h,w = img.size() 89 | img = img.view(1,c,h,w) 90 | img = Variable(img.cuda()) 91 | out = model(img).data.cpu() 92 | return out 93 | 94 | def predict_batch(model, loader): 95 | inputs, _ = next(iter(loader)) 96 | out = predict(model, inputs) 97 | return out 98 | 99 | def run(model, trn_loader, crit, optim, epochs, plot_interval=1000): 100 | losses = [] 101 | for epoch in range(epochs): 102 | loss = train(model, trn_loader, crit, optim) 103 | print('Epoch {:d} Loss: {:.4f}'.format(epoch+1, loss)) 104 | if epoch % plot_interval == 0: 105 | samples = predict_batch(model, trn_loader) 106 | utils.plot_batch(samples) 107 | losses.append(loss) 108 | samples = predict_batch(model, trn_loader) 109 | utils.plot_batch(samples) 110 | return losses 111 | -------------------------------------------------------------------------------- /cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | import torchvision 6 | 7 | import utils 8 | 9 | 10 | class CNN(nn.Module): 11 | def __init__(self, in_shape, n_classes): 12 | super().__init__() 13 | c, w, h = in_shape 14 | pool_layers = 3 15 | fc_h = int(h / 2**pool_layers) 16 | fc_w = int(w / 2**pool_layers) 17 | self.features = nn.Sequential( 18 | *conv_bn_relu(c, 16, kernel_size=1, stride=1, padding=0), 19 | *conv_bn_relu(16, 32, kernel_size=3, stride=1, padding=1), 20 | nn.MaxPool2d(kernel_size=2, stride=2), #size/2 21 | *conv_bn_relu(32, 64, kernel_size=3, stride=1, padding=1), 22 | nn.MaxPool2d(kernel_size=2, stride=2), #size/2 23 | *conv_bn_relu(64, 128, kernel_size=3, stride=1, padding=1), 24 | nn.MaxPool2d(kernel_size=2, stride=2), #size/2 25 | ) 26 | self.classifier = nn.Sequential( 27 | *linear_bn_relu_drop(128 * fc_h * fc_w, 128, dropout=0.5), 28 | nn.Linear(128, n_classes), 29 | nn.Softmax(dim=1) 30 | ) 31 | 32 | def forward(self, x): 33 | x = self.features(x) 34 | x = x.view(x.size(0), -1) 35 | x = self.classifier(x) 36 | return x 37 | 38 | 39 | class Trainer(): 40 | def __init__(self, optimizer, lr_adjuster=None, augmentor=None): 41 | self.metrics = { 42 | 'loss': { 43 | 'trn':[], 44 | 'tst':[] 45 | }, 46 | 'accuracy': { 47 | 'trn':[], 48 | 'tst':[] 49 | }, 50 | } 51 | self.optimizer = optimizer 52 | self.lr_adjuster = lr_adjuster 53 | self.augmentor = augmentor 54 | 55 | def run(self, model, trn_loader, tst_loader, criterion, epochs): 56 | for epoch in range(1, epochs+1): 57 | trn_loss, trn_acc = train(model, trn_loader, criterion, 58 | self.optimizer, self.lr_adjuster, 59 | self.augmentor) 60 | tst_loss, tst_acc = test(model, tst_loader, criterion) 61 | print('Epoch %d, TrnLoss: %.3f, TrnAcc: %.3f, TstLoss: %.3f, TstAcc: %.3f' % ( 62 | epoch, trn_loss, trn_acc, tst_loss, tst_acc)) 63 | self.metrics['loss']['trn'].append(trn_loss) 64 | self.metrics['loss']['tst'].append(tst_loss) 65 | self.metrics['accuracy']['trn'].append(trn_acc) 66 | self.metrics['accuracy']['tst'].append(tst_acc) 67 | 68 | 69 | def train(net, loader, crit, optim, lr_adjuster=None, augmentor=None): 70 | net.train() 71 | n_batches = len(loader) 72 | total_loss = 0 73 | total_acc = 0 74 | for inputs, targets in loader: 75 | inputs = Variable(inputs.cuda()) 76 | targets = Variable(targets.cuda()) 77 | 78 | if augmentor is not None: 79 | inputs = augmentor.transform(inputs) 80 | 81 | output = net(inputs) 82 | loss = crit(output, targets) 83 | 84 | optim.zero_grad() 85 | loss.backward() 86 | optim.step() 87 | 88 | preds = utils.get_argmax(output) 89 | accuracy = utils.get_accuracy(preds, targets.data.cpu().numpy()) 90 | 91 | total_loss += loss.data[0] 92 | total_acc += accuracy 93 | 94 | if lr_adjuster is not None: 95 | lr_adjuster.step() 96 | mean_loss = total_loss / n_batches 97 | mean_acc = total_acc / n_batches 98 | return mean_loss, mean_acc 99 | 100 | def test(net, tst_loader, criterion): 101 | net.eval() 102 | test_loss = 0 103 | test_acc = 0 104 | for data in tst_loader: 105 | inputs = Variable(data[0].cuda(), volatile=True) 106 | target = Variable(data[1].cuda()) 107 | output = net(inputs) 108 | test_loss += criterion(output, target).data[0] 109 | pred = utils.get_argmax(output) 110 | test_acc += utils.get_accuracy(pred, target.data.cpu().numpy()) 111 | test_loss /= len(tst_loader) 112 | test_acc /= len(tst_loader) 113 | return test_loss, test_acc 114 | 115 | def conv_bn_relu(in_chans, out_chans, kernel_size=3, stride=1, 116 | padding=1, bias=False): 117 | return [ 118 | nn.Conv2d(in_chans, out_chans, kernel_size=kernel_size, 119 | stride=stride, padding=padding, bias=bias), 120 | nn.BatchNorm2d(out_chans), 121 | nn.ReLU(inplace=True), 122 | ] 123 | 124 | def linear_bn_relu_drop(in_chans, out_chans, dropout=0.5, bias=False): 125 | layers = [ 126 | nn.Linear(in_chans, out_chans, bias=bias), 127 | nn.BatchNorm1d(out_chans), 128 | nn.ReLU(inplace=True) 129 | ] 130 | if dropout > 0: 131 | layers.append(nn.Dropout(dropout)) 132 | return layers 133 | -------------------------------------------------------------------------------- /vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | import torchvision 6 | 7 | import utils 8 | 9 | 10 | class VAE(nn.Module): 11 | def __init__(self, in_shape, n_latent): 12 | super().__init__() 13 | self.in_shape = in_shape 14 | self.n_latent = n_latent 15 | c,h,w = in_shape 16 | self.z_dim = h//2**2 # receptive field downsampled 3 times 17 | self.encoder = nn.Sequential( 18 | nn.BatchNorm2d(c), 19 | nn.Conv2d(c, 32, kernel_size=4, stride=2, padding=1), # 32, 16, 16 20 | nn.BatchNorm2d(32), 21 | nn.LeakyReLU(), 22 | nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), # 32, 8, 8 23 | nn.BatchNorm2d(64), 24 | nn.LeakyReLU(), 25 | ) 26 | self.z_mean = nn.Linear(64 * self.z_dim**2, n_latent) 27 | self.z_var = nn.Linear(64 * self.z_dim**2, n_latent) 28 | self.z_develop = nn.Linear(n_latent, 64 * self.z_dim**2) 29 | self.decoder = nn.Sequential( 30 | nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0), 31 | nn.BatchNorm2d(32), 32 | nn.ReLU(), 33 | nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1), 34 | utils.CenterCrop(h,w), 35 | nn.Sigmoid() 36 | ) 37 | 38 | def sample_z(self, mean, var): 39 | std = torch.exp(0.5 * var) 40 | eps = Variable(torch.randn(std.size())).cuda() 41 | return (eps * std) + mean 42 | 43 | def encode(self, x): 44 | x = self.encoder(x) 45 | x = x.view(x.size(0), -1) 46 | mean = self.z_mean(x) 47 | var = self.z_var(x) 48 | return mean, var 49 | 50 | def decode(self, z): 51 | out = self.z_develop(z) 52 | out = out.view(z.size(0), 64, self.z_dim, self.z_dim) 53 | out = self.decoder(out) 54 | return out 55 | 56 | def forward(self, x): 57 | mean, var = self.encode(x) 58 | z = self.sample_z(mean, var) 59 | out = self.decode(z) 60 | return out, mean, var 61 | 62 | def transform(self, x, eps=0.1): 63 | mean, var = self.encode(x) 64 | noise = Variable(1. + torch.randn(x.size(0), self.n_latent).cuda() * eps) 65 | z = self.sample_z(mean*noise, var*noise) 66 | out = self.decode(z) 67 | return out 68 | 69 | 70 | def vae_loss(output, input, mean, var, criterion): 71 | recon_loss = criterion(output, input) 72 | kl_loss = torch.mean(0.5 * torch.sum(torch.exp(var) + mean**2 - 1. - var, 1)) 73 | return recon_loss + kl_loss 74 | 75 | def generate(model, mean, var): 76 | model.eval() 77 | mean = Variable(mean.cuda()) 78 | var = Variable(var.cuda()) 79 | z = model.sample_z(mean, var) 80 | out = model.decode(z) 81 | return out.data.cpu() 82 | 83 | def predict(model, img): 84 | model.eval() 85 | if len(img.size()) == 3: 86 | c,h,w = img.size() 87 | img = img.view(1,c,h,w) 88 | img = Variable(img.cuda()) 89 | out, mean, var = model(img) 90 | return out.data.cpu(), mean.data.cpu(), var.data.cpu() 91 | 92 | def predict_batch(model, loader): 93 | inputs, _ = next(iter(loader)) 94 | out, mu, logvar = predict(model, inputs) 95 | return out, mu, logvar 96 | 97 | def train(model, dataloader, crit, optim): 98 | model.train() 99 | total_loss = 0 100 | for img, _ in dataloader: 101 | inputs = Variable(img.cuda()) 102 | 103 | output, mean, var = model(inputs) 104 | loss = vae_loss(output, inputs, mean, var, crit) 105 | 106 | optim.zero_grad() 107 | loss.backward() 108 | optim.step() 109 | 110 | total_loss += loss.data[0] 111 | 112 | return total_loss / len(dataloader) 113 | 114 | def test(model, tst_loader, crit): 115 | model.eval() 116 | test_loss = 0 117 | for inputs, targets in tst_loader: 118 | inputs = Variable(inputs.cuda(), volatile=True) 119 | target = Variable(targets.cuda()) 120 | 121 | output, mean, var = model(inputs) 122 | 123 | loss = vae_loss(output, inputs, mean, var, crit) 124 | test_loss += loss.data[0] 125 | 126 | test_loss /= len(tst_loader) 127 | return test_loss 128 | 129 | def run(model, trn_loader, tst_loader, crit, optim, epochs, plot_interval=1000): 130 | losses = {'trn': [], 'tst':[]} 131 | for epoch in range(epochs): 132 | trn_loss = train(model, trn_loader, crit, optim) 133 | tst_loss = test(model, tst_loader, crit) 134 | print('Epoch %d, TrnLoss: %.4f, TstLoss: %.4f' % ( 135 | epoch+1, trn_loss, tst_loss)) 136 | if epoch % plot_interval == 0: 137 | samples, mu, var = predict_batch(model, tst_loader) 138 | utils.plot_batch(samples) 139 | losses['trn'].append(trn_loss) 140 | losses['tst'].append(tst_loss) 141 | samples, mean, var = predict_batch(model, trn_loader) 142 | utils.plot_batch(samples) 143 | return losses 144 | -------------------------------------------------------------------------------- /LearnDataAugCIFAR.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Learning Data Augmentation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%matplotlib inline\n", 17 | "%load_ext autoreload\n", 18 | "%autoreload 2\n", 19 | "\n", 20 | "import numpy as np\n", 21 | "import torch\n", 22 | "from torch.autograd import Variable\n", 23 | "import torch.nn.functional as F\n", 24 | "import torch.nn as nn\n", 25 | "import torchvision\n", 26 | "\n", 27 | "import datasets\n", 28 | "import autoencoder\n", 29 | "import cnn\n", 30 | "import vae\n", 31 | "import utils" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "## Dataset" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "trn_dset, tst_dset = datasets.get_cifar_dataset(trn_size=5000, tst_size=5000)\n", 48 | "trn_loader, tst_loader = datasets.get_cifar_loader(trn_dset, tst_dset, batch_size=64)\n", 49 | "inputs,targets = next(iter(trn_loader))\n", 50 | "utils.plot_batch(inputs)\n", 51 | "print(\"Train:\", len(trn_loader.dataset), \"Test:\", len(tst_loader.dataset), \n", 52 | " \"Input:\", inputs.size(), \"Target:\", targets.size())" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "## Classifier" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "model = cnn.CNN(in_shape=(3,32,32), n_classes=10).cuda()" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "epochs = 20\n", 78 | "iters = epochs * len(trn_loader)\n", 79 | "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)\n", 80 | "criterion = nn.CrossEntropyLoss()\n", 81 | "lr_adjuster = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)\n", 82 | "trainer = cnn.Trainer(optimizer, lr_adjuster)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "trainer.run(model, trn_loader, tst_loader, criterion, epochs)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "utils.plot_metric(trainer.metrics['loss']['trn'], trainer.metrics['loss']['tst'], 'Loss')\n", 101 | "utils.plot_metric(trainer.metrics['accuracy']['trn'], trainer.metrics['accuracy']['tst'], 'Accuracy')" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "## Autoencoder" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "model = autoencoder.ConvAE(in_shape=(3,32,32)).cuda()" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "criterion = nn.MSELoss()\n", 127 | "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": { 134 | "scrolled": false 135 | }, 136 | "outputs": [], 137 | "source": [ 138 | "losses = autoencoder.run(model, trn_loader, criterion, optimizer, epochs=50)\n", 139 | "utils.plot_metric(losses, losses, 'Loss')" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "## VAE" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "model = vae.VAE(in_shape=(3,32,32), n_latent=100).cuda()" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "criterion = nn.MSELoss(size_average=False)\n", 165 | "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": { 172 | "scrolled": false 173 | }, 174 | "outputs": [], 175 | "source": [ 176 | "losses = vae.run(model, trn_loader, tst_loader, criterion, \n", 177 | " optimizer, epochs=50, plot_interval=25)\n", 178 | "utils.plot_metric(losses['trn'], losses['tst'], 'Loss')" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "# Single Image\n", 188 | "img_idx = 1\n", 189 | "noise = 1. + torch.randn(1) * 1e-1 \n", 190 | "recon, mean, var = vae.predict(model, inputs[img_idx])\n", 191 | "out = vae.generate(model, mean*noise, var*noise)\n", 192 | "utils.plot_tensor(inputs[img_idx], title=\"Input\", fs=(4,4))\n", 193 | "utils.plot_tensor(out, title=\"Generated\", fs=(4,4))" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "# Batch\n", 203 | "recon, mean, var = vae.predict(model, inputs)\n", 204 | "out = vae.generate(model, mean, var)\n", 205 | "utils.plot_batch(inputs)\n", 206 | "utils.plot_batch(out)" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": {}, 212 | "source": [ 213 | "## Classifier w VAE Augmentation" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "augmentor = model\n", 223 | "classifier = cnn.CNN(in_shape=(3,32,32), n_classes=10).cuda()" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "epochs = 50\n", 233 | "optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3, weight_decay=1e-4)\n", 234 | "criterion = nn.CrossEntropyLoss()\n", 235 | "lr_adjuster = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)\n", 236 | "trainer = cnn.Trainer(optimizer, lr_adjuster, augmentor)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "trainer.run(classifier, trn_loader, tst_loader, criterion, epochs)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "utils.plot_metric(trainer.metrics['loss']['trn'], trainer.metrics['loss']['tst'], 'Loss')\n", 255 | "utils.plot_metric(trainer.metrics['accuracy']['trn'], trainer.metrics['accuracy']['tst'], 'Accuracy')" 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "metadata": {}, 261 | "source": [ 262 | "## GAN" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [ 271 | "# https://arxiv.org/abs/1711.04340" 272 | ] 273 | } 274 | ], 275 | "metadata": { 276 | "kernelspec": { 277 | "display_name": "Python 3", 278 | "language": "python", 279 | "name": "python3" 280 | }, 281 | "language_info": { 282 | "codemirror_mode": { 283 | "name": "ipython", 284 | "version": 3 285 | }, 286 | "file_extension": ".py", 287 | "mimetype": "text/x-python", 288 | "name": "python", 289 | "nbconvert_exporter": "python", 290 | "pygments_lexer": "ipython3", 291 | "version": "3.6.4" 292 | }, 293 | "latex_envs": { 294 | "LaTeX_envs_menu_present": true, 295 | "autocomplete": true, 296 | "bibliofile": "biblio.bib", 297 | "cite_by": "apalike", 298 | "current_citInitial": 1, 299 | "eqLabelWithNumbers": true, 300 | "eqNumInitial": 1, 301 | "hotkeys": { 302 | "equation": "Ctrl-E", 303 | "itemize": "Ctrl-I" 304 | }, 305 | "labels_anchors": false, 306 | "latex_user_defs": false, 307 | "report_style_numbering": false, 308 | "user_envs_cfg": false 309 | }, 310 | "toc": { 311 | "colors": { 312 | "hover_highlight": "#DAA520", 313 | "navigate_num": "#000000", 314 | "navigate_text": "#333333", 315 | "running_highlight": "#FF0000", 316 | "selected_highlight": "#FFD700", 317 | "sidebar_border": "#EEEEEE", 318 | "wrapper_background": "#FFFFFF" 319 | }, 320 | "moveMenuLeft": true, 321 | "nav_menu": { 322 | "height": "48px", 323 | "width": "254px" 324 | }, 325 | "navigate_menu": true, 326 | "number_sections": true, 327 | "sideBar": true, 328 | "threshold": 4, 329 | "toc_cell": false, 330 | "toc_section_display": "block", 331 | "toc_window_display": false, 332 | "widenNotebook": false 333 | } 334 | }, 335 | "nbformat": 4, 336 | "nbformat_minor": 2 337 | } 338 | --------------------------------------------------------------------------------