├── .DS_Store
├── README.md
├── __pycache__
│   ├── utils.cpython-36.pyc
│   └── vcca.cpython-36.pyc
├── main.py
├── results
│   ├── .DS_Store
│   ├── final.png
│   ├── sample1_10.png
│   ├── sample1_15.png
│   ├── sample1_20.png
│   ├── sample1_25.png
│   ├── sample1_30.png
│   ├── sample1_35.png
│   ├── sample1_40.png
│   ├── sample1_45.png
│   ├── sample1_5.png
│   ├── sample1_50.png
│   ├── sample2_10.png
│   ├── sample2_15.png
│   ├── sample2_20.png
│   ├── sample2_25.png
│   ├── sample2_30.png
│   ├── sample2_35.png
│   ├── sample2_40.png
│   ├── sample2_45.png
│   ├── sample2_5.png
│   └── sample2_50.png
├── test.py
├── utils.py
└── vcca.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VCCA: Variational Canonical Correlation Analysis
2 |
3 | This is a PyTorch implementation of [Deep Variational Canonical Correlation Analysis (VCCA)](https://arxiv.org/abs/1610.03454).
4 |
5 | ## Variational CCA and Variational CCA Private [VCCA, VCCAP]
6 |
7 | (VCCA architecture figure omitted)
8 |
9 | VCCA-private:
10 |
11 | (VCCA-private architecture figure omitted)
12 |
13 |
14 | [Deep Variational Canonical Correlation Analysis](https://github.com/edchengg/VCCA-StudyNotes/blob/master/paper/DVCCA.pdf)
15 |
16 |
17 | ### Training
18 | 5 epochs:
19 |
20 | view1 | view2
21 | :-------------------------:|:-------------------------:
22 | ![](results/sample1_5.png) | ![](results/sample2_5.png)
23 |
24 | 50 epochs:
25 |
26 | view1 | view2
27 | :-------------------------:|:-------------------------:
28 | ![](results/sample1_50.png) | ![](results/sample2_50.png)
29 | ### Generation
30 |
31 | ![](results/final.png)
32 |
33 | ### Dataset
34 | The model is evaluated on a noisy version of the MNIST dataset. [Vahid Noroozi](https://github.com/VahidooX/DeepCCA) built the dataset exactly as described in the paper. The train/validation/test split is the original MNIST split.
35 |
36 | The dataset is too large to host on GitHub, so it is hosted on another server. You can download the data from:
37 |
38 | [view1](https://www2.cs.uic.edu/~vnoroozi/noisy-mnist/noisymnist_view1.gz)
39 |
40 | [view2](https://www2.cs.uic.edu/~vnoroozi/noisy-mnist/noisymnist_view2.gz)
41 |
42 | Save both files in the same directory as the Python code.
43 |
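A minimal sketch of fetching and loading both views (assuming the download URLs above are still live; `load_data` is the helper in `utils.py`):

```python
import urllib.request
from utils import load_data

# fetch both views of noisy MNIST into the working directory
base = 'https://www2.cs.uic.edu/~vnoroozi/noisy-mnist/'
for name in ('noisymnist_view1.gz', 'noisymnist_view2.gz'):
    urllib.request.urlretrieve(base + name, name)

# each file unpickles to [(train_x, train_y), (valid_x, valid_y), (test_x, test_y)]
data1 = load_data('noisymnist_view1.gz')
train_x1, train_y1 = data1[0]
print(train_x1.shape)  # flattened 28x28 images, one 784-dim row per example
```
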
44 | ### Differences with the original paper
45 | The differences between this implementation and the original paper are small:
46 |
47 | * I used a simple binary cross-entropy loss for the two decoder networks.
48 |
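For reference, the objective implemented in `main.py` is the usual VAE-style bound with two reconstruction terms: a binary cross-entropy for each view plus the KL divergence of the shared posterior from the unit Gaussian,

$$\mathcal{L} = \mathrm{BCE}(\hat{x}_1, x_1) + \mathrm{BCE}(\hat{x}_2, x_2) - \tfrac{1}{2}\textstyle\sum_j \bigl(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\bigr),$$

and the VCCA-private variant adds the analogous KL terms for the two private encoders.
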
49 | ### Other Implementations
50 |
51 | Other implementations of VCCA in TensorFlow:
52 |
53 | * [TensorFlow implementation](http://ttic.uchicago.edu/~wwang5/papers/vcca_tf0.9_code.tgz) from [Weiran Wang's website](http://ttic.uchicago.edu/~wwang5/)
54 |
55 |
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/vcca.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/__pycache__/vcca.cpython-36.pyc
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data
4 | from torch import optim
5 | from torch.autograd import Variable
6 | from torch.nn import functional as F
7 | from torchvision import datasets, transforms
8 | from torchvision.utils import save_image
9 | from vcca import VCCA
10 | from utils import *
11 |
12 | CUDA = False
13 | SEED = 1
14 | BATCH_SIZE = 128
15 | LOG_INTERVAL = 10
16 | EPOCHS = 50
17 | ZDIMS = 20
18 | PDIMS = 30
19 | PRIVATE = True
20 | # I do this so that the noisy MNIST files are loaded from where I keep them
21 | os.chdir("/Users/edison/PycharmProjects/vcca_pytorch")
22 |
23 | torch.manual_seed(SEED)
24 |
25 | if CUDA:
26 | torch.cuda.manual_seed(SEED)
27 |
28 | # with pin_memory=True the DataLoader uses page-locked host memory for faster GPU transfers
29 | kwargs = {'num_workers': 1, 'pin_memory': True} if CUDA else {}
30 |
31 | # Download or load downloaded MNIST dataset
32 | # shuffle data at every epoch
33 |
34 | data1 = load_data('noisymnist_view1.gz')
35 | data2 = load_data('noisymnist_view2.gz')
36 |
37 | train_set_x1, _ = data1[0]
38 | train_set_x2, _ = data2[0]
39 |
40 | train_loader = torch.utils.data.DataLoader(
41 | ConcatDataset(
42 | train_set_x1,
43 | train_set_x2
44 | ),
45 |     batch_size=BATCH_SIZE, shuffle=True, **kwargs)
46 |
47 | # Same for test data
48 | #test_loader = torch.utils.data.DataLoader(
49 | # datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
50 | # batch_size=BATCH_SIZE, shuffle=True, **kwargs)
51 |
52 | model = VCCA(PRIVATE)
53 |
54 | if CUDA:
55 | model.cuda()
56 |
57 |
58 | def loss_function(recon_x1, recon_x2, x1, x2, mu, logvar) -> Variable:
59 | # how well do input x and output recon_x agree?
60 | BCE1 = F.binary_cross_entropy(recon_x1, x1.view(-1, 784))
61 | BCE2 = F.binary_cross_entropy(recon_x2, x2.view(-1, 784))
62 |
63 | # KLD is Kullback–Leibler divergence -- how much does one learned
64 | # distribution deviate from another, in this specific case the
65 | # learned distribution from the unit Gaussian
66 |
67 | # see Appendix B from VAE paper:
68 | # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
69 | # https://arxiv.org/abs/1312.6114
70 | # - D_{KL} = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
71 | # note the negative D_{KL} in appendix B of the paper
72 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
73 | # Normalise by same number of elements as in reconstruction
74 | KLD /= BATCH_SIZE * 784
75 |
76 | # BCE tries to make our reconstruction as accurate as possible
77 | # KLD tries to push the distributions as close as possible to unit Gaussian
78 | return BCE1 + KLD + BCE2
79 |
80 | def loss_function_private(recon_x1, recon_x2, x1, x2, mu, logvar, mu1, logvar1, mu2, logvar2) -> Variable:
81 | # how well do input x and output recon_x agree?
82 | BCE1 = F.binary_cross_entropy(recon_x1, x1.view(-1, 784))
83 | BCE2 = F.binary_cross_entropy(recon_x2, x2.view(-1, 784))
84 |
85 | # KLD is Kullback–Leibler divergence -- how much does one learned
86 | # distribution deviate from another, in this specific case the
87 | # learned distribution from the unit Gaussian
88 |
89 | # see Appendix B from VAE paper:
90 | # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
91 | # https://arxiv.org/abs/1312.6114
92 | # - D_{KL} = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
93 | # note the negative D_{KL} in appendix B of the paper
94 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
95 | # Normalise by same number of elements as in reconstruction
96 | KLD /= BATCH_SIZE * 784
97 |
98 | KLD1 = -0.5 * torch.sum(1 + logvar1 - mu1.pow(2) - logvar1.exp())
99 | # Normalise by same number of elements as in reconstruction
100 | KLD1 /= BATCH_SIZE * 784
101 |
102 | KLD2 = -0.5 * torch.sum(1 + logvar2 - mu2.pow(2) - logvar2.exp())
103 | # Normalise by same number of elements as in reconstruction
104 | KLD2 /= BATCH_SIZE * 784
105 |
106 | # BCE tries to make our reconstruction as accurate as possible
107 | # KLD tries to push the distributions as close as possible to unit Gaussian
108 | return BCE1 + KLD + KLD1 + KLD2 + BCE2
109 |
110 |
111 | # Dr Diederik Kingma: as if VAEs weren't enough, he also gave us Adam!
112 | optimizer = optim.Adam(model.parameters(), lr=0.0001)
113 |
114 |
115 | def train(epoch):
116 | # toggle model to train mode
117 | model.train()
118 | train_loss = 0
119 |     # each of `data1`/`data2` is a batch of BATCH_SIZE samples,
120 |     # one flattened 784-dim vector per image
121 | for batch_idx, (data1, data2) in enumerate(train_loader):
122 | data1 = Variable(data1).float()
123 | data2 = Variable(data2).float()
124 | if CUDA:
125 | data1 = data1.cuda()
126 | data2 = data2.cuda()
127 | optimizer.zero_grad()
128 |
129 | if not model.private:
130 | # push whole batch of data through VAE.forward() to get recon_loss
131 | recon_batch1, recon_batch2, mu, log_var = model(data1, data2)
132 | # calculate scalar loss
133 | loss = loss_function(recon_batch1, recon_batch2, data1, data2, mu, log_var)
134 | else:
135 | recon_batch1, recon_batch2, mu, log_var, mu1, log_var1, mu2, log_var2 = model(data1, data2)
136 | loss = loss_function_private(recon_batch1, recon_batch2, data1, data2, mu, log_var, mu1, log_var1, mu2, log_var2)
137 | # calculate the gradient of the loss w.r.t. the graph leaves
138 | # i.e. input variables -- by the power of pytorch!
139 | loss.backward()
140 | train_loss += loss.data[0]
141 | optimizer.step()
142 | if batch_idx % LOG_INTERVAL == 0:
143 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
144 | epoch, batch_idx * len(data1), len(train_loader.dataset),
145 | 100. * batch_idx / len(train_loader),
146 | loss.data[0] / len(data1)))
147 |
148 | print('====> Epoch: {} Average loss: {:.4f}'.format(
149 | epoch, train_loss / len(train_loader.dataset)))
150 |
151 |
152 |
153 | for epoch in range(1, EPOCHS + 1):
154 | train(epoch)
155 |     #test(epoch)
156 | model.eval()
157 |     # 64 random latent vectors, i.e. 64 locations in latent space
158 |     # (ZDIMS-dim shared, plus PDIMS-dim private when model.private)
159 | if model.private:
160 | sample = Variable(torch.randn(64, PDIMS+ZDIMS))
161 | else:
162 | sample = Variable(torch.randn(64, ZDIMS))
163 |
164 | if CUDA:
165 | sample = sample.cuda()
166 | sample1 = model.decode_1(sample).cpu()
167 | sample2 = model.decode_2(sample).cpu()
168 | # save out as an 8x8 matrix of MNIST digits
169 | # this will give you a visual idea of how well latent space can generate things
170 | # that look like digits
171 | if epoch % 5 == 0:
172 | save_image(sample1.data.view(64, 1, 28, 28),
173 | 'results/sample1_' + str(epoch) + '.png')
174 | save_image(sample2.data.view(64, 1, 28, 28),
175 | 'results/sample2_' + str(epoch) + '.png')
176 |
177 | with open('model.pt','wb') as f:
178 | torch.save(model, f)
--------------------------------------------------------------------------------
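As a usage note, once `main.py` has saved `model.pt`, the shared representation can be read off the encoder's posterior mean. A minimal sketch (batching and CUDA handling omitted; `load_data` comes from `utils.py`):

```python
import torch
from torch.autograd import Variable
from utils import load_data

with open('model.pt', 'rb') as f:
    model = torch.load(f)
model.eval()

# embed the test split of view 1 with the shared encoder's mean
test_x1, test_y1 = load_data('noisymnist_view1.gz')[2]
x = Variable(torch.from_numpy(test_x1)).float()
mu, logvar = model.encode(x.view(-1, 784))
print(mu.size())  # (N, ZDIMS) shared embedding
```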
/results/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/.DS_Store
--------------------------------------------------------------------------------
/results/final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/final.png
--------------------------------------------------------------------------------
/results/sample1_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_10.png
--------------------------------------------------------------------------------
/results/sample1_15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_15.png
--------------------------------------------------------------------------------
/results/sample1_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_20.png
--------------------------------------------------------------------------------
/results/sample1_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_25.png
--------------------------------------------------------------------------------
/results/sample1_30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_30.png
--------------------------------------------------------------------------------
/results/sample1_35.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_35.png
--------------------------------------------------------------------------------
/results/sample1_40.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_40.png
--------------------------------------------------------------------------------
/results/sample1_45.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_45.png
--------------------------------------------------------------------------------
/results/sample1_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_5.png
--------------------------------------------------------------------------------
/results/sample1_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample1_50.png
--------------------------------------------------------------------------------
/results/sample2_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_10.png
--------------------------------------------------------------------------------
/results/sample2_15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_15.png
--------------------------------------------------------------------------------
/results/sample2_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_20.png
--------------------------------------------------------------------------------
/results/sample2_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_25.png
--------------------------------------------------------------------------------
/results/sample2_30.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_30.png
--------------------------------------------------------------------------------
/results/sample2_35.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_35.png
--------------------------------------------------------------------------------
/results/sample2_40.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_40.png
--------------------------------------------------------------------------------
/results/sample2_45.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_45.png
--------------------------------------------------------------------------------
/results/sample2_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_5.png
--------------------------------------------------------------------------------
/results/sample2_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edchengg/VCCA_pytorch/85aa1f6e9ef247a36600f24d40513c44557c4373/results/sample2_50.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | from utils import *
4 | from torchvision.utils import save_image
5 | from torch.autograd import Variable
8 |
9 | with open('model.pt', 'rb') as f:
10 | model = torch.load(f)
11 |
12 |
13 | model.eval()
14 |
15 | data1 = load_data('noisymnist_view1.gz')
16 | data2 = load_data('noisymnist_view2.gz')
17 | train_set_x1, label1 = data1[0]
18 | train_set_x2, label2 = data2[0]
19 |
20 | train_loader = torch.utils.data.DataLoader(
21 | ConcatDataset(
22 | train_set_x1,
23 | train_set_x2
24 | ),
25 | batch_size=1, shuffle=False)
26 |
27 |
28 | print(label1[:10])
29 | print(label2[:10])
30 |
31 | for batch_idx, (data1, data2) in enumerate(train_loader):  # one grid row per image
32 |     data1 = Variable(data1).float()
33 |     data2 = Variable(data2).float()
34 |
35 |     model.eval()
36 |
37 |     mu_z, _ = model.encode(data1)  # shared latent (posterior mean) of this view-1 image
38 |     sample1 = data1  # the row starts with the original image
39 |     for batch_idx2, (data11, data22) in enumerate(train_loader):
40 |         data11 = Variable(data11).float()
41 |
42 |         p_mu, log_var = model.private_encoder1(data11)  # private latent of another image
43 |         std = log_var.mul(0.5).exp_()  # type: Variable
44 |
45 |         eps = Variable(std.data.new(std.size()).normal_())
46 |
47 |         z_private = eps.mul(std).add_(p_mu)  # reparameterized private sample
48 |
49 |         # decode the outer image's shared latent combined with this private sample
50 |         sample_tmp = torch.cat((mu_z, z_private), 1)
51 |         sample_tmp = model.decode_1(sample_tmp).cpu()
52 |
53 |         sample1 = torch.cat((sample1, sample_tmp), 1)  # append reconstruction to the row
54 |         if batch_idx2 == 6:  # seven private samples per row
55 |             break
56 |     if batch_idx == 0:  # stack completed rows
57 |         res = sample1
58 |     else:
59 |         res = torch.cat((res, sample1), 0)
60 |
61 |     if batch_idx == 7:  # 8 rows x 8 images = 64 panels
62 |         break
63 | print(res.size())
64 | save_image(res.data.view(64, 1, 28, 28),
65 | 'results/final.png')
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | from sklearn import svm
3 | from sklearn.metrics import accuracy_score
4 | import numpy as np
5 | import torch
6 | import torch.utils.data
7 |
8 | def load_data(data_file):
9 | """loads the data from the gzip pickled files, and converts to numpy arrays"""
10 | print('loading data ...')
11 | f = gzip.open(data_file, 'rb')
12 | train_set, valid_set, test_set = load_pickle(f)
13 | f.close()
14 |
15 | train_set_x, train_set_y = make_numpy_array(train_set)
16 | valid_set_x, valid_set_y = make_numpy_array(valid_set)
17 | test_set_x, test_set_y = make_numpy_array(test_set)
18 |
19 | return [(train_set_x, train_set_y), (valid_set_x, valid_set_y), (test_set_x, test_set_y)]
20 |
21 |
22 | def make_numpy_array(data_xy):
23 | """converts the input to numpy arrays"""
24 | data_x, data_y = data_xy
25 | data_x = np.asarray(data_x)
26 | data_y = np.asarray(data_y, dtype='int32')
27 | return data_x, data_y
28 |
29 |
30 | def svm_classify(data, C):
31 | """
32 | trains a linear SVM on the data
33 | input C specifies the penalty factor of SVM
34 | """
35 | train_data, _, train_label = data[0]
36 | valid_data, _, valid_label = data[1]
37 | test_data, _, test_label = data[2]
38 |
39 | print('training SVM...')
40 | clf = svm.LinearSVC(C=C, dual=False)
41 | clf.fit(train_data, train_label.ravel())
42 |
43 | p = clf.predict(test_data)
44 | test_acc = accuracy_score(test_label, p)
45 | p = clf.predict(valid_data)
46 | valid_acc = accuracy_score(valid_label, p)
47 |
48 | return [test_acc, valid_acc]
49 |
50 |
51 | def load_pickle(f):
52 | """
53 | loads and returns the content of a pickled file
54 | it handles the inconsistencies between the pickle packages available in Python 2 and 3
55 | """
56 | try:
57 | import cPickle as thepickle
58 | except ImportError:
59 | import _pickle as thepickle
60 |
61 | try:
62 | ret = thepickle.load(f, encoding='latin1')
63 | except TypeError:
64 | ret = thepickle.load(f)
65 |
66 | return ret
67 |
68 | class ConcatDataset(torch.utils.data.Dataset):
69 | def __init__(self, *datasets):
70 | self.datasets = datasets
71 |
72 | def __getitem__(self, i):
73 | return tuple(d[i] for d in self.datasets)
74 |
75 | def __len__(self):
76 | return min(len(d) for d in self.datasets)
77 |
78 |
--------------------------------------------------------------------------------
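`svm_classify` expects each split as a `(features, _, labels)` triple and ignores the middle element. A minimal sketch with synthetic data, just to show the call (shapes are hypothetical):

```python
import numpy as np
from utils import svm_classify

rng = np.random.RandomState(0)

def split(n):
    # (features, ignored, labels) -- svm_classify never touches the middle element
    return rng.rand(n, 20), None, rng.randint(0, 10, n)

data = [split(1000), split(200), split(200)]  # train / valid / test
test_acc, valid_acc = svm_classify(data, C=0.01)
print(test_acc, valid_acc)
```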
/vcca.py:
--------------------------------------------------------------------------------
1 | from torch import nn, cat
2 | from torch.autograd import Variable
3 | ZDIMS=20
4 | PDIMS=30
5 | class VCCA(nn.Module):
6 | def __init__(self, private):
7 | super(VCCA, self).__init__()
8 | self.private = private
9 | # ENCODER
10 |         # 28 x 28 pixels = 784 input pixels, 1024 hidden units
11 | self.en_z_1 = nn.Linear(784, 1024)
12 | self.en_z_2 = nn.Linear(1024, 1024)
13 | self.en_z_3 = nn.Linear(1024, 1024)
14 |
15 |         # rectified linear unit: max(0, x),
16 |         # shared by all hidden layers
17 | self.relu = nn.ReLU()
18 | self.dropout = nn.Dropout(0.1)
19 | self.en_z_4_mu = nn.Linear(1024, ZDIMS) # mu layer
20 | self.en_z_4_sigma = nn.Linear(1024, ZDIMS) # logvariance layer
21 | # this last layer bottlenecks through ZDIMS connections
22 | if self.private:
23 | self.en_x_1 = nn.Linear(784, 1024)
24 | self.en_x_2 = nn.Linear(1024, 1024)
25 | self.en_x_3 = nn.Linear(1024, 1024)
26 | self.en_x_4_mu = nn.Linear(1024, PDIMS)
27 | self.en_x_4_sigma = nn.Linear(1024, PDIMS)
28 |
29 | self.en_y_1 = nn.Linear(784, 1024)
30 | self.en_y_2 = nn.Linear(1024, 1024)
31 | self.en_y_3 = nn.Linear(1024, 1024)
32 | self.en_y_4_mu = nn.Linear(1024, PDIMS)
33 | self.en_y_4_sigma = nn.Linear(1024, PDIMS)
34 |
35 | # DECODER 1
36 | # from bottleneck to hidden 400
37 | if self.private:
38 | self.de_x_1 = nn.Linear(PDIMS+ZDIMS, 1024)
39 | else:
40 | self.de_x_1 = nn.Linear(ZDIMS, 1024)
41 | self.de_x_2 = nn.Linear(1024, 1024)
42 | self.de_x_3 = nn.Linear(1024, 1024)
43 | self.de_x_4 = nn.Linear(1024, 784)
44 |
45 | # DECODER 2
46 | if self.private:
47 | self.de_y_1 = nn.Linear(PDIMS+ZDIMS, 1024)
48 | else:
49 | self.de_y_1 = nn.Linear(ZDIMS, 1024)
50 | self.de_y_2 = nn.Linear(1024, 1024)
51 | self.de_y_3 = nn.Linear(1024, 1024)
52 | self.de_y_4 = nn.Linear(1024, 784)
53 |
54 | self.sigmoid = nn.Sigmoid()
55 |
56 | def encode(self, x: Variable) -> (Variable, Variable):
57 | """Input vector x -> fully connected 1 -> ReLU -> (fully connected
58 |
59 | """
60 | h1 = self.relu(self.en_z_1(self.dropout(x)))
61 | h1 = self.relu(self.en_z_2(self.dropout(h1)))
62 | h1 = self.relu(self.en_z_3(self.dropout(h1)))
63 | return self.en_z_4_mu(self.dropout(h1)), self.en_z_4_sigma(self.dropout(h1))
64 |
65 | def private_encoder1(self, x:Variable):
66 | h1 = self.relu(self.en_x_1(self.dropout(x)))
67 | h1 = self.relu(self.en_x_2(self.dropout(h1)))
68 | h1 = self.relu(self.en_x_3(self.dropout(h1)))
69 | return self.en_x_4_mu(self.dropout(h1)), self.en_x_4_sigma(self.dropout(h1))
70 |
71 | def private_encoder2(self, y:Variable):
72 | h1 = self.relu(self.en_y_1(self.dropout(y)))
73 | h1 = self.relu(self.en_y_2(self.dropout(h1)))
74 | h1 = self.relu(self.en_y_3(self.dropout(h1)))
75 | return self.en_y_4_mu(self.dropout(h1)), self.en_y_4_sigma(self.dropout(h1))
76 |
77 | def reparameterize(self, mu: Variable, logvar: Variable) -> Variable:
78 | """THE REPARAMETERIZATION IDEA:
79 |
80 | """
81 |
82 | if self.training:
83 |
84 | std = logvar.mul(0.5).exp_() # type: Variable
85 |
86 | eps = Variable(std.data.new(std.size()).normal_())
87 |
88 | return eps.mul(std).add_(mu)
89 |
90 | else:
91 | # During inference, we simply spit out the mean of the
92 | # learned distribution for the current input. We could
93 | # use a random sample from the distribution, but mu of
94 | # course has the highest probability.
95 | return mu
96 |
97 | def decode_1(self, z: Variable) -> Variable:
98 | h3 = self.relu(self.de_x_1(self.dropout(z)))
99 | h3 = self.relu(self.de_x_2(self.dropout(h3)))
100 | h3 = self.relu(self.de_x_3(self.dropout(h3)))
101 | return self.sigmoid(self.de_x_4(self.dropout(h3)))
102 |
103 | def decode_2(self, z: Variable) -> Variable:
104 | h3 = self.relu(self.de_y_1(self.dropout(z)))
105 | h3 = self.relu(self.de_y_2(self.dropout(h3)))
106 | h3 = self.relu(self.de_y_3(self.dropout(h3)))
107 | return self.sigmoid(self.de_y_4(self.dropout(h3)))
108 |
109 |     def forward(self, x: Variable, y: Variable):
110 | mu, log_var = self.encode(x.view(-1, 784))
111 |
112 | if self.private:
113 | mu1, log_var1 = self.private_encoder1(x.view(-1, 784))
114 | mu2, log_var2 = self.private_encoder2(y.view(-1, 784))
115 | mu1_tmp = cat((mu,mu1), 1)
116 | log_var1_tmp = cat((log_var,log_var1), 1)
117 | mu2_tmp = cat((mu, mu2), 1)
118 | log_var2_tmp = cat((log_var, log_var2), 1)
119 | z1 = self.reparameterize(mu1_tmp, log_var1_tmp)
120 | z2 = self.reparameterize(mu2_tmp, log_var2_tmp)
121 | return self.decode_1(z1), self.decode_2(z2), mu, log_var, mu1, log_var1, mu2, log_var2
122 |
123 | z1 = self.reparameterize(mu, log_var)
124 | z2 = self.reparameterize(mu, log_var)
125 | return self.decode_1(z1), self.decode_2(z2), mu, log_var
126 |
127 |
--------------------------------------------------------------------------------
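A minimal smoke test of the module's two modes (kept in the same `Variable` style as the rest of the code; shapes follow the definitions above):

```python
import torch
from torch.autograd import Variable
from vcca import VCCA

x = Variable(torch.rand(4, 784))  # a fake batch for each view
y = Variable(torch.rand(4, 784))

recon1, recon2, mu, logvar = VCCA(private=False)(x, y)
print(recon1.size(), mu.size())  # (4, 784) and (4, ZDIMS)

out = VCCA(private=True)(x, y)  # recons plus shared and private (mu, logvar) pairs
print(len(out))  # 8
```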