├── .DS_Store
├── README.md
├── __pycache__
│   ├── utils.cpython-36.pyc
│   └── vcca.cpython-36.pyc
├── main.py
├── results
│   ├── .DS_Store
│   ├── final.png
│   ├── sample1_10.png
│   ├── sample1_15.png
│   ├── sample1_20.png
│   ├── sample1_25.png
│   ├── sample1_30.png
│   ├── sample1_35.png
│   ├── sample1_40.png
│   ├── sample1_45.png
│   ├── sample1_5.png
│   ├── sample1_50.png
│   ├── sample2_10.png
│   ├── sample2_15.png
│   ├── sample2_20.png
│   ├── sample2_25.png
│   ├── sample2_30.png
│   ├── sample2_35.png
│   ├── sample2_40.png
│   ├── sample2_45.png
│   ├── sample2_5.png
│   └── sample2_50.png
├── test.py
├── utils.py
└── vcca.py

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# VCCA: Variational Canonical Correlation Analysis

This is a PyTorch implementation of [Deep Variational Canonical Correlation Analysis (VCCA)](https://arxiv.org/abs/1610.03454).

## Variational CCA and Variational CCA Private [VCCA, VCCAP]

VCCA-private:

[Deep Variational Canonical Correlation Analysis](https://github.com/edchengg/VCCA-StudyNotes/blob/master/paper/DVCCA.pdf)

### Training
5 epochs:

view1 | view2
:-------------------------:|:-------------------------:
![](https://github.com/edchengg/VCCA_pytorch/blob/master/results/sample1_5.png) | ![](https://github.com/edchengg/VCCA_pytorch/blob/master/results/sample2_5.png)

50 epochs:

view1 | view2
:-------------------------:|:-------------------------:
![](https://github.com/edchengg/VCCA_pytorch/blob/master/results/sample1_50.png) | ![](https://github.com/edchengg/VCCA_pytorch/blob/master/results/sample2_50.png)

### Generation

![](https://github.com/edchengg/VCCA_pytorch/blob/master/results/final.png)

### Dataset
The model is evaluated on a noisy version of the MNIST dataset. [Vahid Noroozi](https://github.com/VahidooX/DeepCCA) built the dataset exactly as it is introduced in the paper. The train/validation/test split is the original split of MNIST.

The dataset is too large to upload to GitHub, so it is hosted on another server. You can download the two views from:

[view1](https://www2.cs.uic.edu/~vnoroozi/noisy-mnist/noisymnist_view1.gz)

[view2](https://www2.cs.uic.edu/~vnoroozi/noisy-mnist/noisymnist_view2.gz)

Save both files in the same directory as the Python code.
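
To sanity-check the download, you can load a view with the helper in `utils.py`. A minimal sketch (the printed shape assumes the standard 50,000/10,000/10,000 noisy-MNIST split):

```python
from utils import load_data

# each view unpacks into (train, valid, test) splits of (data, labels)
(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = load_data('noisymnist_view1.gz')
print(train_x.shape)  # expected: (50000, 784) -- flattened 28x28 images
```

Training is then just `python main.py`; note the hard-coded `os.chdir` path at the top of `main.py`, which you will likely need to edit.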

### Differences with the original paper
The following are the differences between my implementation and the original paper (they are small):

* I used a simple binary cross entropy loss for the two decoder networks.

### Other Implementations

The following is another implementation of VCCA, in TensorFlow:

* [TensorFlow implementation](http://ttic.uchicago.edu/~wwang5/papers/vcca_tf0.9_code.tgz) from Wang, Weiran's website (http://ttic.uchicago.edu/~wwang5/)

--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/vcca.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/__pycache__/vcca.cpython-36.pyc
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.utils.data
from torch import optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from vcca import VCCA
from utils import *

CUDA = False
SEED = 1
BATCH_SIZE = 128
LOG_INTERVAL = 10
EPOCHS = 50
ZDIMS = 20
PDIMS = 30
PRIVATE = True
# change into the directory that holds the noisymnist .gz files
os.chdir("/Users/edison/PycharmProjects/vcca_pytorch")

torch.manual_seed(SEED)

if CUDA:
    torch.cuda.manual_seed(SEED)

# DataLoader instances will load tensors directly into GPU memory
kwargs = {'num_workers': 1, 'pin_memory': True} if CUDA else {}

# Load the two noisy-MNIST views; the DataLoader reshuffles at every epoch
data1 = load_data('noisymnist_view1.gz')
data2 = load_data('noisymnist_view2.gz')

train_set_x1, _ = data1[0]
train_set_x2, _ = data2[0]

train_loader = torch.utils.data.DataLoader(
    ConcatDataset(
        train_set_x1,
        train_set_x2
    ),
    batch_size=BATCH_SIZE, shuffle=True)

# Same for test data
#test_loader = torch.utils.data.DataLoader(
#    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
#    batch_size=BATCH_SIZE, shuffle=True, **kwargs)

model = VCCA(PRIVATE)

if CUDA:
    model.cuda()


def loss_function(recon_x1, recon_x2, x1, x2, mu, logvar) -> Variable:
    # how well do input x and output recon_x agree?
    BCE1 = F.binary_cross_entropy(recon_x1, x1.view(-1, 784))
    BCE2 = F.binary_cross_entropy(recon_x2, x2.view(-1, 784))

    # KLD is the Kullback–Leibler divergence -- how much one learned
    # distribution deviates from another, in this specific case the
    # learned distribution from the unit Gaussian

    # see Appendix B from the VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # - D_{KL} = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # note the negative D_{KL} in appendix B of the paper
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Normalise by the same number of elements as in the reconstruction
    KLD /= BATCH_SIZE * 784

    # BCE tries to make our reconstruction as accurate as possible
    # KLD tries to push the distribution as close as possible to the unit Gaussian
    return BCE1 + KLD + BCE2
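
# A quick, commented-out sanity check of the closed-form KL above: when the
# encoder output matches the unit-Gaussian prior exactly, the KL term should
# vanish, since 1 + log(1) - 0 - 1 = 0 per dimension:
#   mu, logvar = torch.zeros(1, ZDIMS), torch.zeros(1, ZDIMS)
#   -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())  # -> tensor of 0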

def loss_function_private(recon_x1, recon_x2, x1, x2, mu, logvar, mu1, logvar1, mu2, logvar2) -> Variable:
    # how well do input x and output recon_x agree?
    BCE1 = F.binary_cross_entropy(recon_x1, x1.view(-1, 784))
    BCE2 = F.binary_cross_entropy(recon_x2, x2.view(-1, 784))

    # KL divergences for the shared latent z and both private latents,
    # using the same closed form as above; see Appendix B of
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Normalise by the same number of elements as in the reconstruction
    KLD /= BATCH_SIZE * 784

    KLD1 = -0.5 * torch.sum(1 + logvar1 - mu1.pow(2) - logvar1.exp())
    KLD1 /= BATCH_SIZE * 784

    KLD2 = -0.5 * torch.sum(1 + logvar2 - mu2.pow(2) - logvar2.exp())
    KLD2 /= BATCH_SIZE * 784

    # the BCE terms try to make our reconstructions as accurate as possible
    # the KLD terms push the distributions as close as possible to the unit Gaussian
    return BCE1 + KLD + KLD1 + KLD2 + BCE2


# Dr Diederik Kingma: as if VAEs weren't enough, he also gave us Adam!
optimizer = optim.Adam(model.parameters(), lr=0.0001)


def train(epoch):
    # toggle model to train mode
    model.train()
    train_loss = 0
    # each batch holds BATCH_SIZE flattened 784-dim samples per view
    for batch_idx, (data1, data2) in enumerate(train_loader):
        data1 = Variable(data1).float()
        data2 = Variable(data2).float()
        if CUDA:
            data1 = data1.cuda()
            data2 = data2.cuda()
        optimizer.zero_grad()

        if not model.private:
            # push the whole batch through VCCA.forward() to get the reconstructions
            recon_batch1, recon_batch2, mu, log_var = model(data1, data2)
            # calculate scalar loss
            loss = loss_function(recon_batch1, recon_batch2, data1, data2, mu, log_var)
        else:
            recon_batch1, recon_batch2, mu, log_var, mu1, log_var1, mu2, log_var2 = model(data1, data2)
            loss = loss_function_private(recon_batch1, recon_batch2, data1, data2, mu, log_var, mu1, log_var1, mu2, log_var2)
        # calculate the gradient of the loss w.r.t. the graph leaves
        # i.e. input variables -- by the power of pytorch!
        loss.backward()
        train_loss += loss.data[0]  # PyTorch 0.3-style; use loss.item() on 0.4+
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data1), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.data[0] / len(data1)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
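
# The loop below alternates a training epoch with prior sampling: draw 64
# latent vectors from N(0, I) and decode them through both decoders. The
# trained model is pickled at the end; it can be reloaded the way test.py
# does it:
#   with open('model.pt', 'rb') as f:
#       model = torch.load(f)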

for epoch in range(1, EPOCHS + 1):
    train(epoch)
    #test(epoch)
    model.eval()
    # 64 sets of random ZDIMS-float vectors, i.e. 64 locations / MNIST
    # digits in latent space
    if model.private:
        sample = Variable(torch.randn(64, PDIMS + ZDIMS))
    else:
        sample = Variable(torch.randn(64, ZDIMS))

    if CUDA:
        sample = sample.cuda()
    sample1 = model.decode_1(sample).cpu()
    sample2 = model.decode_2(sample).cpu()
    # save out as an 8x8 matrix of MNIST digits
    # this gives a visual idea of how well the latent space can generate
    # things that look like digits
    if epoch % 5 == 0:
        save_image(sample1.data.view(64, 1, 28, 28),
                   'results/sample1_' + str(epoch) + '.png')
        save_image(sample2.data.view(64, 1, 28, 28),
                   'results/sample2_' + str(epoch) + '.png')

with open('model.pt', 'wb') as f:
    torch.save(model, f)
--------------------------------------------------------------------------------
/results/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/.DS_Store
--------------------------------------------------------------------------------
/results/final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/final.png
--------------------------------------------------------------------------------
/results/sample1_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_10.png
--------------------------------------------------------------------------------
/results/sample1_15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_15.png
--------------------------------------------------------------------------------
/results/sample1_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_20.png
--------------------------------------------------------------------------------
/results/sample1_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_25.png
--------------------------------------------------------------------------------
/results/sample1_30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_30.png
--------------------------------------------------------------------------------
/results/sample1_35.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_35.png
--------------------------------------------------------------------------------
/results/sample1_40.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_40.png
--------------------------------------------------------------------------------
/results/sample1_45.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_45.png
--------------------------------------------------------------------------------
/results/sample1_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_5.png
--------------------------------------------------------------------------------
/results/sample1_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_50.png
--------------------------------------------------------------------------------
/results/sample2_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_10.png
--------------------------------------------------------------------------------
/results/sample2_15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_15.png
--------------------------------------------------------------------------------
/results/sample2_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_20.png
--------------------------------------------------------------------------------
/results/sample2_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_25.png
--------------------------------------------------------------------------------
/results/sample2_30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_30.png
--------------------------------------------------------------------------------
/results/sample2_35.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_35.png
--------------------------------------------------------------------------------
/results/sample2_40.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_40.png
--------------------------------------------------------------------------------
/results/sample2_45.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_45.png
--------------------------------------------------------------------------------
/results/sample2_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_5.png
--------------------------------------------------------------------------------
/results/sample2_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_50.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch.utils.data
from utils import *
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import matplotlib.pyplot as plt

# Qualitative check of the shared/private split: each row shows one input
# image followed by 7 decodings that keep its shared code but use private
# codes sampled from other inputs; the 8x8 grid is saved to results/final.png.

with open('model.pt', 'rb') as f:
    model = torch.load(f)


model.eval()

data1 = load_data('noisymnist_view1.gz')
data2 = load_data('noisymnist_view2.gz')
train_set_x1, label1 = data1[0]
train_set_x2, label2 = data2[0]

train_loader = torch.utils.data.DataLoader(
    ConcatDataset(
        train_set_x1,
        train_set_x2
    ),
    batch_size=1, shuffle=False)


print(label1[:10])
print(label2[:10])

for batch_idx, (data1, data2) in enumerate(train_loader):
    data1 = Variable(data1).float()
    data2 = Variable(data2).float()

    model.eval()

    mu_z, _ = model.encode(data1)
    sample1 = data1
    for batch_idx2, (data11, data22) in enumerate(train_loader):
        data11 = Variable(data11).float()

        p_mu, log_var = model.private_encoder1(data11)
        std = log_var.mul(0.5).exp_()  # type: Variable

        eps = Variable(std.data.new(std.size()).normal_())

        input = eps.mul(std).add_(p_mu)

        sample_tmp = torch.cat((mu_z, input), 1)
        sample_tmp = model.decode_1(sample_tmp).cpu()

        sample1 = torch.cat((sample1, sample_tmp), 1)
        if batch_idx2 == 6:
            break
    if batch_idx == 0:
        res = sample1
    else:
        res = torch.cat((res, sample1), 0)

    if batch_idx == 7:
        break
print(res.size())
save_image(res.data.view(64, 1, 28, 28),
           'results/final.png')
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import gzip
from sklearn import svm
from sklearn.metrics import accuracy_score
import numpy as np
import torch
import torch.utils.data

def load_data(data_file):
    """loads the data from the gzip pickled files, and converts to numpy arrays"""
    print('loading data ...')
    f = gzip.open(data_file, 'rb')
    train_set, valid_set, test_set = load_pickle(f)
    f.close()

    train_set_x, train_set_y = make_numpy_array(train_set)
    valid_set_x, valid_set_y = make_numpy_array(valid_set)
    test_set_x, test_set_y = make_numpy_array(test_set)

    return [(train_set_x, train_set_y), (valid_set_x, valid_set_y), (test_set_x, test_set_y)]


def make_numpy_array(data_xy):
    """converts the input to numpy arrays"""
    data_x, data_y = data_xy
    data_x = np.asarray(data_x)
    data_y = np.asarray(data_y, dtype='int32')
    return data_x, data_y


def svm_classify(data, C):
    """
    trains a linear SVM on the data
    input C specifies the penalty factor of the SVM
    note: expects three-element splits (data, _, label), unlike the pairs
    returned by load_data; unused by main.py and test.py
    """
    train_data, _, train_label = data[0]
    valid_data, _, valid_label = data[1]
    test_data, _, test_label = data[2]

    print('training SVM...')
    clf = svm.LinearSVC(C=C, dual=False)
    clf.fit(train_data, train_label.ravel())

    p = clf.predict(test_data)
    test_acc = accuracy_score(test_label, p)
    p = clf.predict(valid_data)
    valid_acc = accuracy_score(valid_label, p)

    return [test_acc, valid_acc]


def load_pickle(f):
    """
    loads and returns the content of a pickled file
    it handles the inconsistencies between the pickle packages available in Python 2 and 3
    """
    try:
        import cPickle as thepickle
    except ImportError:
        import _pickle as thepickle

    try:
        ret = thepickle.load(f, encoding='latin1')
    except TypeError:
        ret = thepickle.load(f)

    return ret

class ConcatDataset(torch.utils.data.Dataset):
    """zips several datasets so one index returns a tuple with one item from each"""
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        return min(len(d) for d in self.datasets)

--------------------------------------------------------------------------------
/vcca.py:
--------------------------------------------------------------------------------
from torch import nn, cat
from torch.autograd import Variable

ZDIMS = 20
PDIMS = 30

class VCCA(nn.Module):
    def __init__(self, private):
        super(VCCA, self).__init__()
        self.private = private
        # ENCODER for the shared latent z
        # 28 x 28 pixels = 784 input pixels, 1024 hidden units
        self.en_z_1 = nn.Linear(784, 1024)
        self.en_z_2 = nn.Linear(1024, 1024)
        self.en_z_3 = nn.Linear(1024, 1024)

        # rectified linear unit, max(0, x)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.en_z_4_mu = nn.Linear(1024, ZDIMS)  # mu layer
        self.en_z_4_sigma = nn.Linear(1024, ZDIMS)  # logvariance layer
        # this last layer bottlenecks through ZDIMS connections
        if self.private:
            # private encoders, one per view
            self.en_x_1 = nn.Linear(784, 1024)
            self.en_x_2 = nn.Linear(1024, 1024)
            self.en_x_3 = nn.Linear(1024, 1024)
            self.en_x_4_mu = nn.Linear(1024, PDIMS)
            self.en_x_4_sigma = nn.Linear(1024, PDIMS)

            self.en_y_1 = nn.Linear(784, 1024)
            self.en_y_2 = nn.Linear(1024, 1024)
            self.en_y_3 = nn.Linear(1024, 1024)
            self.en_y_4_mu = nn.Linear(1024, PDIMS)
            self.en_y_4_sigma = nn.Linear(1024, PDIMS)

        # DECODER 1
        # from bottleneck to 1024 hidden units
        if self.private:
            self.de_x_1 = nn.Linear(PDIMS + ZDIMS, 1024)
        else:
            self.de_x_1 = nn.Linear(ZDIMS, 1024)
        self.de_x_2 = nn.Linear(1024, 1024)
        self.de_x_3 = nn.Linear(1024, 1024)
        self.de_x_4 = nn.Linear(1024, 784)

        # DECODER 2
        if self.private:
            self.de_y_1 = nn.Linear(PDIMS + ZDIMS, 1024)
        else:
            self.de_y_1 = nn.Linear(ZDIMS, 1024)
        self.de_y_2 = nn.Linear(1024, 1024)
        self.de_y_3 = nn.Linear(1024, 1024)
        self.de_y_4 = nn.Linear(1024, 784)

        self.sigmoid = nn.Sigmoid()

    def encode(self, x: Variable) -> (Variable, Variable):
        """input x -> three fully connected ReLU layers (with dropout)
        -> (mu, logvar), each of size ZDIMS
        """
        h1 = self.relu(self.en_z_1(self.dropout(x)))
        h1 = self.relu(self.en_z_2(self.dropout(h1)))
        h1 = self.relu(self.en_z_3(self.dropout(h1)))
        return self.en_z_4_mu(self.dropout(h1)), self.en_z_4_sigma(self.dropout(h1))

    def private_encoder1(self, x: Variable):
        h1 = self.relu(self.en_x_1(self.dropout(x)))
        h1 = self.relu(self.en_x_2(self.dropout(h1)))
        h1 = self.relu(self.en_x_3(self.dropout(h1)))
        return self.en_x_4_mu(self.dropout(h1)), self.en_x_4_sigma(self.dropout(h1))

    def private_encoder2(self, y: Variable):
        h1 = self.relu(self.en_y_1(self.dropout(y)))
        h1 = self.relu(self.en_y_2(self.dropout(h1)))
        h1 = self.relu(self.en_y_3(self.dropout(h1)))
        return self.en_y_4_mu(self.dropout(h1)), self.en_y_4_sigma(self.dropout(h1))

    def reparameterize(self, mu: Variable, logvar: Variable) -> Variable:
        """THE REPARAMETERIZATION IDEA: sample z = mu + std * eps with
        eps ~ N(0, I), so sampling stays differentiable w.r.t. mu and logvar
        """

        if self.training:

            std = logvar.mul(0.5).exp_()  # type: Variable

            eps = Variable(std.data.new(std.size()).normal_())

            return eps.mul(std).add_(mu)

        else:
            # During inference, we simply spit out the mean of the
            # learned distribution for the current input. We could
            # use a random sample from the distribution, but mu of
            # course has the highest probability.
            return mu
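
    # Numeric illustration of the trick above (hypothetical values): if
    # logvar = log(4) then std = exp(0.5 * log(4)) = 2, so z = mu + 2 * eps
    # is a sample from N(mu, 4), while gradients flow through the
    # deterministic mu and logvar outputs rather than through the sampling.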

    def decode_1(self, z: Variable) -> Variable:
        h3 = self.relu(self.de_x_1(self.dropout(z)))
        h3 = self.relu(self.de_x_2(self.dropout(h3)))
        h3 = self.relu(self.de_x_3(self.dropout(h3)))
        return self.sigmoid(self.de_x_4(self.dropout(h3)))

    def decode_2(self, z: Variable) -> Variable:
        h3 = self.relu(self.de_y_1(self.dropout(z)))
        h3 = self.relu(self.de_y_2(self.dropout(h3)))
        h3 = self.relu(self.de_y_3(self.dropout(h3)))
        return self.sigmoid(self.de_y_4(self.dropout(h3)))

    def forward(self, x: Variable, y: Variable):
        # returns (recon1, recon2, mu, log_var) in the shared-only case, plus
        # the two private (mu, log_var) pairs when self.private is set
        mu, log_var = self.encode(x.view(-1, 784))

        if self.private:
            mu1, log_var1 = self.private_encoder1(x.view(-1, 784))
            mu2, log_var2 = self.private_encoder2(y.view(-1, 784))
            mu1_tmp = cat((mu, mu1), 1)
            log_var1_tmp = cat((log_var, log_var1), 1)
            mu2_tmp = cat((mu, mu2), 1)
            log_var2_tmp = cat((log_var, log_var2), 1)
            z1 = self.reparameterize(mu1_tmp, log_var1_tmp)
            z2 = self.reparameterize(mu2_tmp, log_var2_tmp)
            return self.decode_1(z1), self.decode_2(z2), mu, log_var, mu1, log_var1, mu2, log_var2

        z1 = self.reparameterize(mu, log_var)
        z2 = self.reparameterize(mu, log_var)
        return self.decode_1(z1), self.decode_2(z2), mu, log_var
--------------------------------------------------------------------------------