├── README.md
├── images
│   ├── real_image_100.png
│   ├── recon_image_100.png
│   └── sample_image_100.png
└── wae_for_mnist.py
/README.md:
--------------------------------------------------------------------------------
# PyTorch implementation of Wasserstein Auto-Encoders
The reference for this code is the original TensorFlow implementation by the paper's authors (Tolstikhin et al., *Wasserstein Auto-Encoders*, arXiv:1711.01558): https://github.com/tolstikhin/wae

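To reproduce these results, run the training script (a minimal invocation; it pretrains the encoder, then trains the full model, downloading MNIST into `data/` on first run):

```
python wae_for_mnist.py
```
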
##### The results at epoch 100 are as follows:
##### real images
![real images](images/real_image_100.png)
##### recon images
![recon images](images/recon_image_100.png)
##### sample images
![sample images](images/sample_image_100.png)
--------------------------------------------------------------------------------
/images/real_image_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wsnedy/WAE_Pytorch/c599aef1cbc6a73210b2fe045998b55bced22f01/images/real_image_100.png
--------------------------------------------------------------------------------
/images/recon_image_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wsnedy/WAE_Pytorch/c599aef1cbc6a73210b2fe045998b55bced22f01/images/recon_image_100.png
--------------------------------------------------------------------------------
/images/sample_image_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wsnedy/WAE_Pytorch/c599aef1cbc6a73210b2fe045998b55bced22f01/images/sample_image_100.png
--------------------------------------------------------------------------------
/wae_for_mnist.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision
import itertools
from torchvision.utils import save_image


class Encoder(nn.Module):
    def __init__(self, in_channels, num_filters, num_layers, z_size, is_training=False):
        super(Encoder, self).__init__()
        self.in_channels = in_channels
        self.is_training = is_training
        conv_list = []
        for i in range(num_layers):
            # channel count doubles at each layer, ending at num_filters
            scale = 2 ** (num_layers - i - 1)
            conv_i = nn.Sequential(
                nn.Conv2d(
                    in_channels=self.in_channels, out_channels=num_filters // scale,
                    kernel_size=4, stride=2, padding=2
                ),
                nn.BatchNorm2d(num_features=num_filters // scale),
                nn.ReLU(inplace=True)
            )
            conv_list.append(conv_i)
            self.in_channels = num_filters // scale
        self.conv = nn.Sequential(*conv_list)
        self.linear = nn.Linear(in_features=3 * 3 * num_filters, out_features=z_size)
        # initialize weights
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                m.weight.data.normal_(0.0, 0.0099999)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.normal_(0.0, 0.01)
                m.bias.data.zero_()

    def forward(self, x):
        if self.is_training:
            # perturb inputs with small Gaussian noise during training
            # (note: this flag is fixed at construction, independent of train()/eval())
            x = x + torch.randn_like(x) * 0.01
        conv_out = self.conv(x).view(x.size(0), -1)
        z = self.linear(conv_out)
        return z


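# Shape walk-through for the Encoder above (a sketch assuming the MNIST
# settings used in run(): in_channels=1, num_filters=1024, num_layers=4,
# z_size=8; every conv uses kernel 4, stride 2, padding 2):
#   (N, 1, 28, 28) -> (N, 128, 15, 15) -> (N, 256, 8, 8)
#   -> (N, 512, 5, 5) -> (N, 1024, 3, 3) -> flatten (N, 9216) -> z: (N, 8)

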
class Decoder(nn.Module):
    def __init__(self, num_filters, num_layers, z_size, output_shape):
        super(Decoder, self).__init__()
        # spatial size of the first feature map; the stride-2 deconvolutions
        # upsample it and the final stride-1 layer trims it to output_shape
        height = output_shape // 2 ** (num_layers - 1) + 1
        width = output_shape // 2 ** (num_layers - 1) + 1
        self.num_filters = num_filters
        self.height = height
        self.width = width
        self.linear1 = nn.Sequential(
            nn.Linear(in_features=z_size, out_features=num_filters * height * width),
            nn.ReLU(inplace=True)
        )
        self.in_channels = num_filters
        deconv_list = []
        for i in range(num_layers - 1):
            scale = 2 ** (i + 1)
            deconv_i = nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=self.in_channels, out_channels=num_filters // scale,
                    kernel_size=4, stride=2, padding=2, output_padding=1
                ),
                nn.BatchNorm2d(num_features=num_filters // scale),
                nn.ReLU(inplace=True)
            )
            deconv_list.append(deconv_i)
            self.in_channels = num_filters // scale
        self.deconv = nn.Sequential(*deconv_list)
        self.deconv_last = nn.ConvTranspose2d(
            in_channels=self.in_channels, out_channels=1,
            kernel_size=4, stride=1, padding=2
        )
        # initialize weights
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose2d, nn.Linear)):
                m.weight.data.normal_(0.0, 0.0099999)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.normal_(0.0, 0.01)
                m.bias.data.zero_()

    def forward(self, z):
        linear_out = self.linear1(z)
        deconv_input = linear_out.view(-1, self.num_filters, self.height, self.width)
        deconv_out = self.deconv(deconv_input)
        recon_x = self.deconv_last(deconv_out)
        # return both the image in [0, 1] and the raw pre-sigmoid logits
        return torch.sigmoid(recon_x), recon_x


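# Shape walk-through for the Decoder above (a sketch assuming the MNIST
# settings used in run(): num_filters=1024, num_layers=3, z_size=8,
# output_shape=28):
#   z: (N, 8) -> linear (N, 1024*8*8) -> reshape (N, 1024, 8, 8)
#   -> (N, 512, 15, 15) -> (N, 256, 29, 29) -> deconv_last -> (N, 1, 28, 28)

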
class Adversary_z(nn.Module):
    def __init__(self, num_filters, num_layers, z_size):
        super(Adversary_z, self).__init__()
        self.in_features = z_size
        linears = []
        for i in range(num_layers):
            linear_i = nn.Sequential(
                nn.Linear(in_features=self.in_features, out_features=num_filters),
                nn.ReLU(inplace=True)
            )
            linears.append(linear_i)
            self.in_features = num_filters
        self.linear = nn.Sequential(*linears)
        self.final_linear = nn.Linear(in_features=self.in_features, out_features=1)
        # initialize weights
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0.0, 0.0099999)
                m.bias.data.zero_()

    def forward(self, z):
        linear_out = self.linear(z)
        out = self.final_linear(linear_out)
        return out


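# Note: unlike an image-space GAN discriminator, Adversary_z above acts on
# latent vectors. It is trained to tell encoded codes q(z|x) apart from
# samples of the prior p(z), and it outputs a single unnormalized logit.

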
def run():
    # define the hyper-parameters
    batch_size = 100
    e_pretrain_batch_size = 1000
    pretrain_epochs = 200
    epochs = 100
    z_size = 8
    lam = 10  # lambda: weight of the latent adversarial penalty

    e_num_filters = 1024
    e_num_layers = 4

    g_num_filters = 1024
    g_num_layers = 3

    d_num_filters = 512
    d_num_layers = 4

    # download and load data
    train_data = MNIST(
        root='data/',
        train=True,
        download=True,
        transform=torchvision.transforms.ToTensor()
    )
    test_data = MNIST(
        root='data/',
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor()
    )
    # get the data batch
    pretrain_data_loader = DataLoader(dataset=train_data, batch_size=e_pretrain_batch_size, shuffle=True)
    pretest_data_loader = DataLoader(dataset=test_data, batch_size=e_pretrain_batch_size, shuffle=False)
    train_data_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
    test_data_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

    # models
    encoder = Encoder(
        in_channels=1, num_filters=e_num_filters, num_layers=e_num_layers, z_size=z_size, is_training=True
    )
    decoder = Decoder(
        num_filters=g_num_filters, num_layers=g_num_layers, z_size=z_size, output_shape=28
    )
    discriminator = Adversary_z(num_filters=d_num_filters, num_layers=d_num_layers, z_size=z_size)

    # load model parameters
    # encoder.load_state_dict(torch.load('Pretrain_Encoder_epoch200.pth'))
    # encoder.load_state_dict(torch.load('encoder_100.pth'))
    # decoder.load_state_dict(torch.load('decoder_100.pth'))
    # discriminator.load_state_dict(torch.load('discriminator_100.pth'))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder, decoder, discriminator = encoder.to(device), decoder.to(device), discriminator.to(device)

    # define the optimizers; the discriminator uses a smaller learning rate
    e_params = encoder.parameters()
    ae_params = itertools.chain(encoder.parameters(), decoder.parameters())
    d_params = discriminator.parameters()
    optimizer_e = optim.Adam(e_params, lr=1e-03, betas=(0.5, 0.999))
    optimizer_ae = optim.Adam(ae_params, lr=1e-03, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(d_params, lr=5e-04, betas=(0.5, 0.999))

    # pretrain model
    def sample_pz(batch_size=100, z_size=8):
        # draw a batch from the prior p(z) = N(0, I)
        return torch.randn(batch_size, z_size, device=device)

    def pretrain_loss(encoded, sample_noise):
        # for mean
        mean_qz = torch.mean(encoded, dim=0, keepdim=True)
        mean_pz = torch.mean(sample_noise, dim=0, keepdim=True)
        mean_loss = F.mse_loss(mean_qz, mean_pz)

        # for covariance
        cov_qz = torch.matmul((encoded - mean_qz).transpose(0, 1), encoded - mean_qz)
        cov_qz /= e_pretrain_batch_size - 1.
        cov_pz = torch.matmul((sample_noise - mean_pz).transpose(0, 1), sample_noise - mean_pz)
        cov_pz /= e_pretrain_batch_size - 1.
        cov_loss = F.mse_loss(cov_qz, cov_pz)

        return mean_loss + cov_loss

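    # Why moment matching works here (a brief note): if the encoder maps a
    # batch to codes with zero mean and identity covariance, both MSE terms
    # in pretrain_loss vanish, so driving this loss to zero aligns q(z) with
    # the first two moments of the prior p(z) = N(0, I).
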
    def encoder_pretrain(epoch):
        encoder.train()
        for batch_idx, (data, _) in enumerate(pretrain_data_loader):
            data = data.to(device)
            sample_noise = sample_pz(e_pretrain_batch_size, z_size)
            encoded = encoder(data)

            optimizer_e.zero_grad()
            loss_pretrain = pretrain_loss(encoded, sample_noise)
            loss_pretrain.backward()
            optimizer_e.step()
            if batch_idx % 10 == 0:
                print("Train Epoch: {} [{}/{} ({:.0f}%)] \tLoss: {:.6f}".format(
                    epoch, batch_idx * len(data), len(pretrain_data_loader.dataset),
                    100. * batch_idx / len(pretrain_data_loader), loss_pretrain.item()
                ))
        # save the pretrained model at the last epoch
        if epoch == pretrain_epochs:
            torch.save(encoder.state_dict(), 'Pretrain_Encoder_epoch{:02d}.pth'.format(epoch))

    # train models
    d_loss_function = nn.BCEWithLogitsLoss()

    def gan_loss(sample_qz, sample_pz):
        logits_qz = discriminator(sample_qz)
        logits_pz = discriminator(sample_pz)

        # discriminator losses: encoded codes are labeled fake (0),
        # prior samples real (1)
        loss_qz = d_loss_function(logits_qz, torch.zeros_like(logits_qz))
        loss_pz = d_loss_function(logits_pz, torch.ones_like(logits_pz))
        # non-saturating loss for the encoder: reward codes the
        # discriminator classifies as real
        loss_qz_trick = d_loss_function(logits_qz, torch.ones_like(logits_qz))
        loss_adversary = lam * (loss_qz + loss_pz)
        loss_penalty = loss_qz_trick
        return (loss_adversary, logits_qz, logits_pz), loss_penalty

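    # Objective in brief: the auto-encoder minimizes
    #     loss_wae = recon_loss + lam * loss_penalty,
    # while the discriminator separately minimizes
    #     loss_adversary = lam * (loss_qz + loss_pz);
    # this is the standard two-player WAE-GAN game, played in the latent space.
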
    def train(epoch):
        for param_group in optimizer_ae.param_groups:
            print(param_group['lr'], "learning rate for Auto-Encoder.")
        for param_group in optimizer_d.param_groups:
            print(param_group['lr'], "learning rate for Discriminator.")
        encoder.train(), decoder.train(), discriminator.train()
        for batch_idx, (data, _) in enumerate(train_data_loader):
            data = data.to(device)
            sample_noise = sample_pz(batch_size, z_size)

            encoded = encoder(data)
            # for reconstructed
            recon_x, recon_logits = decoder(encoded)
            # for sample
            decoded, decoded_logits = decoder(sample_noise)

            # losses
            recon_loss = F.mse_loss(recon_x, data)
            loss_gan, loss_penalty = gan_loss(encoded, sample_noise)
            loss_wae = recon_loss + lam * loss_penalty
            loss_adv = loss_gan[0]

            # optimize the auto-encoder; retain the graph so the
            # discriminator loss can be backpropagated afterwards
            encoder.zero_grad()
            decoder.zero_grad()
            loss_wae.backward(retain_graph=True)
            optimizer_ae.step()

            # optimize the adversary (gradients that leak into the encoder
            # here are cleared by zero_grad() on the next iteration)
            discriminator.zero_grad()
            loss_adv.backward()
            optimizer_d.step()

            if batch_idx % 10 == 0:
                print("Train Epoch: {} [{}/{} ({:.0f}%)] \tWAE_Loss: {:.6f}\tD_Loss: {:.6f}".format(
                    epoch, batch_idx * len(data), len(train_data_loader.dataset),
                    100. * batch_idx / len(train_data_loader), loss_wae.item(), loss_adv.item()
                ))
        # save images and save models
        save_image(data.cpu(), 'real_image_{:02d}.png'.format(epoch), nrow=10)
        save_image(recon_x.detach().cpu(), 'recon_image_{:02d}.png'.format(epoch), nrow=10)
        save_image(decoded.detach().cpu(), 'sample_image_{:02d}.png'.format(epoch), nrow=10)
        if epoch % 50 == 0:
            torch.save(encoder.state_dict(), 'encoder_{:02d}.pth'.format(epoch))
            torch.save(decoder.state_dict(), 'decoder_{:02d}.pth'.format(epoch))
            torch.save(discriminator.state_dict(), 'discriminator_{:02d}.pth'.format(epoch))

    # add the learning rate adjust function
    def adjust_learning_rate_manual(optimizer, epoch):
        for param_group in optimizer.param_groups:
            if epoch == 30:
                param_group['lr'] /= 2.
            elif epoch == 50:
                param_group['lr'] /= 2.5
            elif epoch == 100:
                param_group['lr'] /= 2.

    print("=========> Pretrain encoder")
    for epoch in range(1, pretrain_epochs + 1):
        encoder_pretrain(epoch)
    print("=========> Train models")
    for epoch in range(1, epochs + 1):
        # no need to adjust the learning rate on MNIST
        # adjust_learning_rate_manual(optimizer_ae, epoch)
        # adjust_learning_rate_manual(optimizer_d, epoch)
        train(epoch)


if __name__ == "__main__":
    run()
--------------------------------------------------------------------------------