├── README.md
├── logger.py
├── main.py
├── main_aae.py
└── requirements.txt

/README.md:
--------------------------------------------------------------------------------
# PyTorch Adversarial Autoencoders

Replicates the results from [this blog post](https://blog.paperspace.com/adversarial-autoencoders-with-pytorch/) using PyTorch.

Uses [TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard) to view the training, with the logger taken from [this repo](https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/04-utils/tensorboard).

Autoencoders can be used to reduce the dimensionality of the data. This example uses the encoder to fit the data (the unsupervised step) and then uses the encoder's latent representation as "features" for training a classifier on the labels.

The result is not as good as training a simple NN on the raw features. The example is meant to demonstrate the AAE workflow: fit the autoencoder, then use its encoder output as features for a supervised step.
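The key lines of the two-step workflow, condensed from the two scripts (a sketch, not a runnable program; see `main_aae.py` and `main.py` for the full code):

```python
# Step 1 (main_aae.py): train the encoder Q unsupervised, with a
# reconstruction loss plus an adversarial loss, then save its weights.
torch.save(Q.state_dict(), 'Q_encoder_weights.pt')

# Step 2 (main.py): reload the trained encoder, freeze it, and train a small
# classifier on its 120-d latent codes instead of the 784 raw pixels.
Q.load_state_dict(torch.load('Q_encoder_weights.pt'))
Q.eval()                   # inference mode: dropout off
outputs = net(Q(images))   # only net's parameters are given to the optimizer
```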

## Usage

#### 1. Install the dependencies
```bash
$ pip install -r requirements.txt
```

#### 2. Train the AAE model & run the supervised step
```bash
$ python main_aae.py && python main.py
```

#### 3. Open TensorBoard to view training steps
To run TensorBoard, open a new terminal and run the command below. Then open http://localhost:6006/ in your web browser.
```bash
$ tensorboard --logdir='./logs' --port=6006
```

--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
# Note: uses the TensorFlow 1.x summary API (tf.Summary, tf.summary.FileWriter).
import tensorflow as tf
import numpy as np
import scipy.misc
try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO  # Python 3.x


class Logger(object):

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images."""

        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to an in-memory buffer (StringIO on Python 2,
            # BytesIO on Python 3 -- only one of the two names is defined).
            try:
                s = StringIO()
            except NameError:
                s = BytesIO()
            scipy.misc.toimage(img).save(s, format="png")

            # Create an Image object
            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            # Create a Summary value
            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        # Create and write Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""

        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values**2))

        # Drop the start of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()
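
if __name__ == '__main__':
    # Minimal smoke test, added for illustration (not part of the original
    # gist). Assumes TensorFlow 1.x and an old scipy with scipy.misc.toimage.
    # Writes dummy summaries to ./logs/demo, viewable with
    # `tensorboard --logdir=./logs --port=6006`.
    logger = Logger('./logs/demo')
    for step in range(1, 4):
        logger.scalar_summary('demo/loss', 1.0 / step, step)
        logger.histo_summary('demo/weights', np.random.randn(1000), step)
    # image_summary expects an iterable of HxW (or HxWxC) arrays
    logger.image_summary('demo/images', np.random.rand(2, 28, 28), step=3)
    logger.writer.flush()  # make sure the image summary hits disk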
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
#from logger import Logger

import torch.nn.functional as F


# MNIST Dataset
dataset = dsets.MNIST(root='./data',
                      train=True,
                      transform=transforms.ToTensor(),
                      download=True)

# Data Loader (Input Pipeline)
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=100,
                                          shuffle=True)

def to_np(x):
    return x.data.cpu().numpy()

def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)

# Encoder (must match the architecture trained in main_aae.py)
class Q_net(nn.Module):
    def __init__(self, X_dim, N, z_dim):
        super(Q_net, self).__init__()
        self.lin1 = nn.Linear(X_dim, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3gauss = nn.Linear(N, z_dim)

    def forward(self, x):
        x = F.dropout(self.lin1(x), p=0.25, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=0.25, training=self.training)
        x = F.relu(x)
        xgauss = self.lin3gauss(x)
        return xgauss

# Neural Network Model (1 hidden layer)
class Net(nn.Module):
    def __init__(self, input_size=784, hidden_size=500, num_classes=10):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        # Return raw logits: nn.CrossEntropyLoss applies log-softmax itself,
        # so applying F.log_softmax here as well would double-apply it.
        return self.fc2(out)

z_red_dims = 120
Q = Q_net(784, 1000, z_red_dims).cuda()
Q.load_state_dict(torch.load('Q_encoder_weights.pt'))
Q.eval()  # turn off dropout
net = Net(input_size=z_red_dims).cuda()
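# Note (added for clarity): Q is used purely as a fixed feature extractor.
# Its weights come from main_aae.py, it is kept in eval() mode, and the
# optimizer below is built over net.parameters() only -- so gradients flow
# *through* Q during backward(), but Q itself is never updated.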
# # Set the logger
# logger = Logger('./logs/encoder_fit_120_sm_eval')

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters())  #, lr=0.00001)

data_iter = iter(data_loader)
iter_per_epoch = len(data_loader)
total_step = 50000

# Start training
for step in range(total_step):

    # Reset the data_iter
    if (step+1) % iter_per_epoch == 0:
        data_iter = iter(data_loader)

    # Fetch the images and labels and convert them to variables
    images, labels = next(data_iter)
    images, labels = to_var(images.view(images.size(0), -1)), to_var(labels)

    # Forward, backward and optimize
    optimizer.zero_grad()  # zero the gradient buffer
    outputs = net(Q(images))   # classify the 120-d encoder features
    # outputs = net(images)    # (raw-pixel baseline; needs input_size=784)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # Compute accuracy
    _, argmax = torch.max(outputs, 1)
    accuracy = (labels == argmax.squeeze()).float().mean()

    if (step+1) % 100 == 0:
        print('Step [%d/%d], Loss: %.4f, Acc: %.2f'
              % (step+1, total_step, loss.data[0], accuracy.data[0]))

        #============ TensorBoard logging ============#
        # (1) Log the scalar values
        info = {
            'loss': loss.data[0],
            'accuracy': accuracy.data[0]
        }

        # for tag, value in info.items():
        #     logger.scalar_summary(tag, value, step+1)

        # # (2) Log values and gradients of the parameters (histogram)
        # for tag, value in net.named_parameters():
        #     tag = tag.replace('.', '/')
        #     logger.histo_summary(tag, to_np(value), step+1)
        #     logger.histo_summary(tag+'/grad', to_np(value.grad), step+1)

        # # (3) Log the images
        # info = {
        #     'images': to_np(images.view(-1, 28, 28)[:10])
        # }

        # for tag, images in info.items():
        #     logger.image_summary(tag, images, step+1)

# Test on the held-out MNIST test set
dataset_test = dsets.MNIST(root='./data',
                           train=False,
                           transform=transforms.ToTensor(),
                           download=True)

# Data Loader (Input Pipeline)
data_loader_test = torch.utils.data.DataLoader(dataset=dataset_test,
                                               batch_size=10000,
                                               shuffle=True)
data_iter_test = iter(data_loader_test)
# Fetch the images and labels and convert them to variables
images, labels = next(data_iter_test)
images, labels = to_var(images.view(images.size(0), -1)), to_var(labels)

outputs = net(Q(images))
# outputs = net(images)

# Compute accuracy
_, argmax = torch.max(outputs, 1)
accuracy = (labels == argmax.squeeze()).float().mean()

print(accuracy.data[0])

--------------------------------------------------------------------------------
/main_aae.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
from logger import Logger


# MNIST Dataset
dataset = dsets.MNIST(root='./data',
                      train=True,
                      transform=transforms.ToTensor(),
                      download=True)

# Data Loader (Input Pipeline)
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=100,
                                          shuffle=True)

def to_np(x):
    return x.data.cpu().numpy()

def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)


# Encoder
class Q_net(nn.Module):
    def __init__(self, X_dim, N, z_dim):
        super(Q_net, self).__init__()
        self.lin1 = nn.Linear(X_dim, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3gauss = nn.Linear(N, z_dim)

    def forward(self, x):
        x = F.dropout(self.lin1(x), p=0.25, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=0.25, training=self.training)
        x = F.relu(x)
        xgauss = self.lin3gauss(x)
        return xgauss

# Decoder
class P_net(nn.Module):
    def __init__(self, X_dim, N, z_dim):
        super(P_net, self).__init__()
        self.lin1 = nn.Linear(z_dim, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3 = nn.Linear(N, X_dim)

    def forward(self, x):
        x = F.dropout(self.lin1(x), p=0.25, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=0.25, training=self.training)
        x = self.lin3(x)
        return F.sigmoid(x)

# Discriminator
class D_net_gauss(nn.Module):
    def __init__(self, N, z_dim):
        super(D_net_gauss, self).__init__()
        self.lin1 = nn.Linear(z_dim, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3 = nn.Linear(N, 1)

    def forward(self, x):
        x = F.dropout(self.lin1(x), p=0.2, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=0.2, training=self.training)
        x = F.relu(x)
        return F.sigmoid(self.lin3(x))


EPS = 1e-15
z_red_dims = 120
Q = Q_net(784, 1000, z_red_dims).cuda()
P = P_net(784, 1000, z_red_dims).cuda()
D_gauss = D_net_gauss(500, z_red_dims).cuda()

# Set the logger
logger = Logger('./logs/z_120_fixed_LR_2')

# Set learning rates
gen_lr = 0.0001
reg_lr = 0.00005

# Encode/decode optimizers
optim_P = torch.optim.Adam(P.parameters(), lr=gen_lr)
optim_Q_enc = torch.optim.Adam(Q.parameters(), lr=gen_lr)
# Regularizing optimizers (Q deliberately appears in two optimizers: it is
# updated with gen_lr as the encoder and with reg_lr as the "generator")
optim_Q_gen = torch.optim.Adam(Q.parameters(), lr=reg_lr)
optim_D = torch.optim.Adam(D_gauss.parameters(), lr=reg_lr)
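# Each step below runs the three phases of standard AAE training:
#   1) Reconstruction -- update P and Q (optim_P / optim_Q_enc) to minimize
#      the autoencoder's reconstruction error.
#   2) Regularization, discriminator -- update D_gauss (optim_D) to separate
#      samples from the N(0, 5^2) prior from encoder outputs Q(x).
#   3) Regularization, generator -- update Q (optim_Q_gen) to fool D_gauss,
#      pulling the distribution of latent codes toward the prior.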
data_iter = iter(data_loader)
iter_per_epoch = len(data_loader)
total_step = 50000

# Start training
for step in range(total_step):

    # Reset the data_iter
    if (step+1) % iter_per_epoch == 0:
        data_iter = iter(data_loader)

    # Fetch the images and labels and convert them to variables
    images, labels = next(data_iter)
    images, labels = to_var(images.view(images.size(0), -1)), to_var(labels)

    # Reconstruction loss
    P.zero_grad()
    Q.zero_grad()
    D_gauss.zero_grad()

    z_sample = Q(images)    # encode to z
    X_sample = P(z_sample)  # decode to X reconstruction
    recon_loss = F.binary_cross_entropy(X_sample + EPS, images + EPS)  # EPS guards against log(0)

    recon_loss.backward()
    optim_P.step()
    optim_Q_enc.step()

    # Discriminator
    ## true prior is random normal (randn)
    ## this is constraining the Z-projection to be normal!
    Q.eval()  # dropout off while the encoder produces the "fake" samples
    z_real_gauss = Variable(torch.randn(images.size()[0], z_red_dims) * 5.).cuda()
    D_real_gauss = D_gauss(z_real_gauss)

    # detach() so that D_loss.backward() only trains D_gauss, not Q
    z_fake_gauss = Q(images).detach()
    D_fake_gauss = D_gauss(z_fake_gauss)

    D_loss = -torch.mean(torch.log(D_real_gauss + EPS) + torch.log(1 - D_fake_gauss + EPS))

    D_loss.backward()
    optim_D.step()

    # Generator
    Q.train()
    Q.zero_grad()  # discard the reconstruction-phase gradients still in Q's buffers
    z_fake_gauss = Q(images)
    D_fake_gauss = D_gauss(z_fake_gauss)

    G_loss = -torch.mean(torch.log(D_fake_gauss + EPS))

    G_loss.backward()
    optim_Q_gen.step()


    if (step+1) % 100 == 0:
        print('Step [%d/%d], recon_loss: %.4f, D_loss: %.4f, G_loss: %.4f'
              % (step+1, total_step, recon_loss.data[0], D_loss.data[0], G_loss.data[0]))

        #============ TensorBoard logging ============#
        # (1) Log the scalar values
        info = {
            'recon_loss': recon_loss.data[0],
            'discriminator_loss': D_loss.data[0],
            'generator_loss': G_loss.data[0]
        }

        for tag, value in info.items():
            logger.scalar_summary(tag, value, step+1)

        # (2) Log values and gradients of the parameters (histogram)
        for net, name in zip([P, Q, D_gauss], ['P_', 'Q_', 'D_']):
            for tag, value in net.named_parameters():
                tag = name + tag.replace('.', '/')
                logger.histo_summary(tag, to_np(value), step+1)
                logger.histo_summary(tag + '/grad', to_np(value.grad), step+1)

        # (3) Log the images
        info = {
            'images': to_np(images.view(-1, 28, 28)[:10])
        }

        for tag, imgs in info.items():
            logger.image_summary(tag, imgs, step+1)

# Save the encoder weights for the supervised step in main.py
torch.save(Q.state_dict(), 'Q_encoder_weights.pt')
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
# Note: the code targets the TensorFlow 1.x summary API (tf.summary.FileWriter)
# and pre-0.4 PyTorch idioms (Variable, loss.data[0]).
tensorflow
torch
torchvision
scipy
numpy
--------------------------------------------------------------------------------