├── ACGAN.py
├── BEGAN.py
├── CGAN.py
├── DRAGAN.py
├── EBGAN.py
├── GAN.py
├── LSGAN.py
├── README.md
├── WGAN.py
├── WGAN_GP.py
├── assets
├── celebA_results
│ ├── BEGAN_epoch001.png
│ ├── BEGAN_epoch010.png
│ ├── BEGAN_epoch025.png
│ ├── BEGAN_generate_animation.gif
│ ├── DRAGAN_epoch001.png
│ ├── DRAGAN_epoch010.png
│ ├── DRAGAN_epoch025.png
│ ├── DRAGAN_generate_animation.gif
│ ├── EBGAN_epoch001.png
│ ├── EBGAN_epoch010.png
│ ├── EBGAN_epoch025.png
│ ├── EBGAN_generate_animation.gif
│ ├── GAN_epoch001.png
│ ├── GAN_epoch010.png
│ ├── GAN_epoch025.png
│ ├── GAN_generate_animation.gif
│ ├── LSGAN_epoch001.png
│ ├── LSGAN_epoch010.png
│ ├── LSGAN_epoch025.png
│ ├── LSGAN_generate_animation.gif
│ ├── WGAN_GP_epoch001.png
│ ├── WGAN_GP_epoch010.png
│ ├── WGAN_GP_epoch025.png
│ ├── WGAN_GP_generate_animation.gif
│ ├── WGAN_epoch001.png
│ ├── WGAN_epoch010.png
│ ├── WGAN_epoch025.png
│ └── WGAN_generate_animation.gif
├── equations
│ ├── ACGAN.png
│ ├── BEGAN.png
│ ├── CGAN.png
│ ├── DRAGAN.png
│ ├── EBGAN.png
│ ├── GAN.png
│ ├── LSGAN.png
│ ├── WGAN.png
│ ├── WGAN_GP.png
│ └── infoGAN.png
├── etc
│ └── GAN_structure.png
├── fashion_mnist_results
│ ├── ACGAN_epoch001.png
│ ├── ACGAN_epoch025.png
│ ├── ACGAN_epoch050.png
│ ├── ACGAN_generate_animation.gif
│ ├── ACGAN_loss.png
│ ├── BEGAN_epoch001.png
│ ├── BEGAN_epoch025.png
│ ├── BEGAN_epoch050.png
│ ├── BEGAN_generate_animation.gif
│ ├── BEGAN_loss.png
│ ├── CGAN_epoch001.png
│ ├── CGAN_epoch025.png
│ ├── CGAN_epoch050.png
│ ├── CGAN_generate_animation.gif
│ ├── CGAN_loss.png
│ ├── DRAGAN_epoch001.png
│ ├── DRAGAN_epoch025.png
│ ├── DRAGAN_epoch050.png
│ ├── DRAGAN_generate_animation.gif
│ ├── DRAGAN_loss.png
│ ├── EBGAN_epoch001.png
│ ├── EBGAN_epoch025.png
│ ├── EBGAN_epoch050.png
│ ├── EBGAN_generate_animation.gif
│ ├── EBGAN_loss.png
│ ├── GAN_epoch001.png
│ ├── GAN_epoch025.png
│ ├── GAN_epoch050.png
│ ├── GAN_generate_animation.gif
│ ├── GAN_loss.png
│ ├── LSGAN_epoch001.png
│ ├── LSGAN_epoch025.png
│ ├── LSGAN_epoch050.png
│ ├── LSGAN_generate_animation.gif
│ ├── LSGAN_loss.png
│ ├── WGAN_GP_epoch001.png
│ ├── WGAN_GP_epoch025.png
│ ├── WGAN_GP_epoch050.png
│ ├── WGAN_GP_generate_animation.gif
│ ├── WGAN_GP_loss.png
│ ├── WGAN_epoch001.png
│ ├── WGAN_epoch025.png
│ ├── WGAN_epoch050.png
│ ├── WGAN_generate_animation.gif
│ ├── WGAN_loss.png
│ ├── infoGAN_cont_epoch001.png
│ ├── infoGAN_cont_epoch025.png
│ ├── infoGAN_cont_epoch050.png
│ ├── infoGAN_cont_generate_animation.gif
│ ├── infoGAN_epoch001.png
│ ├── infoGAN_epoch025.png
│ ├── infoGAN_epoch050.png
│ ├── infoGAN_generate_animation.gif
│ └── infoGAN_loss.png
├── mnist_results
│ ├── ACGAN_epoch001.png
│ ├── ACGAN_epoch025.png
│ ├── ACGAN_epoch050.png
│ ├── ACGAN_generate_animation.gif
│ ├── ACGAN_loss.png
│ ├── BEGAN_epoch001.png
│ ├── BEGAN_epoch025.png
│ ├── BEGAN_epoch050.png
│ ├── BEGAN_generate_animation.gif
│ ├── BEGAN_loss.png
│ ├── CGAN_epoch001.png
│ ├── CGAN_epoch025.png
│ ├── CGAN_epoch050.png
│ ├── CGAN_generate_animation.gif
│ ├── CGAN_loss.png
│ ├── DRAGAN_epoch001.png
│ ├── DRAGAN_epoch025.png
│ ├── DRAGAN_epoch050.png
│ ├── DRAGAN_generate_animation.gif
│ ├── DRAGAN_loss.png
│ ├── EBGAN_epoch001.png
│ ├── EBGAN_epoch025.png
│ ├── EBGAN_epoch050.png
│ ├── EBGAN_generate_animation.gif
│ ├── EBGAN_loss.png
│ ├── GAN_epoch001.png
│ ├── GAN_epoch025.png
│ ├── GAN_epoch050.png
│ ├── GAN_generate_animation.gif
│ ├── GAN_loss.png
│ ├── LSGAN_epoch001.png
│ ├── LSGAN_epoch025.png
│ ├── LSGAN_epoch050.png
│ ├── LSGAN_generate_animation.gif
│ ├── LSGAN_loss.png
│ ├── WGAN_GP_epoch001.png
│ ├── WGAN_GP_epoch025.png
│ ├── WGAN_GP_epoch050.png
│ ├── WGAN_GP_generate_animation.gif
│ ├── WGAN_GP_loss.png
│ ├── WGAN_epoch001.png
│ ├── WGAN_epoch025.png
│ ├── WGAN_epoch050.png
│ ├── WGAN_generate_animation.gif
│ ├── WGAN_loss.png
│ ├── infoGAN_cont_epoch001.png
│ ├── infoGAN_cont_epoch025.png
│ ├── infoGAN_cont_epoch050.png
│ ├── infoGAN_cont_generate_animation.gif
│ ├── infoGAN_epoch001.png
│ ├── infoGAN_epoch025.png
│ ├── infoGAN_epoch050.png
│ ├── infoGAN_generate_animation.gif
│ └── infoGAN_loss.png
├── dataloader.py
├── infoGAN.py
├── main.py
└── utils.py
/ACGAN.py:
--------------------------------------------------------------------------------
1 | import utils, torch, time, os, pickle
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from dataloader import dataloader
6 |
7 | class generator(nn.Module):
8 | # Network architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
9 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
10 | def __init__(self, input_dim=100, output_dim=1, input_size=32, class_num=10):
11 | super(generator, self).__init__()
12 | self.input_dim = input_dim
13 | self.output_dim = output_dim
14 | self.input_size = input_size
15 | self.class_num = class_num
16 |
17 | self.fc = nn.Sequential(
18 | nn.Linear(self.input_dim + self.class_num, 1024),
19 | nn.BatchNorm1d(1024),
20 | nn.ReLU(),
21 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
22 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
23 | nn.ReLU(),
24 | )
25 | self.deconv = nn.Sequential(
26 | nn.ConvTranspose2d(128, 64, 4, 2, 1),
27 | nn.BatchNorm2d(64),
28 | nn.ReLU(),
29 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
30 | nn.Tanh(),
31 | )
32 | utils.initialize_weights(self)
33 |
34 | def forward(self, input, label):
35 | x = torch.cat([input, label], 1)
36 | x = self.fc(x)
37 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
38 | x = self.deconv(x)
39 |
40 | return x
41 |
42 | class discriminator(nn.Module):
43 | # Network architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
44 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
45 | def __init__(self, input_dim=1, output_dim=1, input_size=32, class_num=10):
46 | super(discriminator, self).__init__()
47 | self.input_dim = input_dim
48 | self.output_dim = output_dim
49 | self.input_size = input_size
50 | self.class_num = class_num
51 |
52 | self.conv = nn.Sequential(
53 | nn.Conv2d(self.input_dim, 64, 4, 2, 1),
54 | nn.LeakyReLU(0.2),
55 | nn.Conv2d(64, 128, 4, 2, 1),
56 | nn.BatchNorm2d(128),
57 | nn.LeakyReLU(0.2),
58 | )
59 | self.fc1 = nn.Sequential(
60 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024),
61 | nn.BatchNorm1d(1024),
62 | nn.LeakyReLU(0.2),
63 | )
64 | self.dc = nn.Sequential(
65 | nn.Linear(1024, self.output_dim),
66 | nn.Sigmoid(),
67 | )
68 | self.cl = nn.Sequential(
69 | nn.Linear(1024, self.class_num),
70 | )
71 | utils.initialize_weights(self)
72 |
73 | def forward(self, input):
74 | x = self.conv(input)
75 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
76 | x = self.fc1(x)
77 | d = self.dc(x)
78 | c = self.cl(x)
79 |
80 | return d, c
81 |
82 | class ACGAN(object):
83 | def __init__(self, args):
84 | # parameters
85 | self.epoch = args.epoch
86 | self.sample_num = 100
87 | self.batch_size = args.batch_size
88 | self.save_dir = args.save_dir
89 | self.result_dir = args.result_dir
90 | self.dataset = args.dataset
91 | self.log_dir = args.log_dir
92 | self.gpu_mode = args.gpu_mode
93 | self.model_name = args.gan_type
94 | self.input_size = args.input_size
95 | self.z_dim = 62
96 | self.class_num = 10
97 | self.sample_num = self.class_num ** 2  # class_num x class_num grid: each row shares z, columns sweep the classes
98 |
99 | # load dataset
100 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)
101 | data = next(iter(self.data_loader))[0]  # peek at one batch to infer the number of image channels
102 |
103 | # networks init
104 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size)
105 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size)
106 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
107 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))
108 |
109 | if self.gpu_mode:
110 | self.G.cuda()
111 | self.D.cuda()
112 | self.BCE_loss = nn.BCELoss().cuda()
113 | self.CE_loss = nn.CrossEntropyLoss().cuda()
114 | else:
115 | self.BCE_loss = nn.BCELoss()
116 | self.CE_loss = nn.CrossEntropyLoss()
117 |
118 | print('---------- Networks architecture -------------')
119 | utils.print_network(self.G)
120 | utils.print_network(self.D)
121 | print('-----------------------------------------------')
122 |
123 | # fixed noise & condition
124 | self.sample_z_ = torch.zeros((self.sample_num, self.z_dim))
125 | for i in range(self.class_num):
126 | self.sample_z_[i*self.class_num] = torch.rand(1, self.z_dim)
127 | for j in range(1, self.class_num):
128 | self.sample_z_[i*self.class_num + j] = self.sample_z_[i*self.class_num]
129 |
130 | temp = torch.zeros((self.class_num, 1))
131 | for i in range(self.class_num):
132 | temp[i, 0] = i
133 |
134 | temp_y = torch.zeros((self.sample_num, 1))
135 | for i in range(self.class_num):
136 | temp_y[i*self.class_num: (i+1)*self.class_num] = temp
137 |
138 | self.sample_y_ = torch.zeros((self.sample_num, self.class_num)).scatter_(1, temp_y.type(torch.LongTensor), 1)
139 | if self.gpu_mode:
140 | self.sample_z_, self.sample_y_ = self.sample_z_.cuda(), self.sample_y_.cuda()
141 |
142 | def train(self):
143 | self.train_hist = {}
144 | self.train_hist['D_loss'] = []
145 | self.train_hist['G_loss'] = []
146 | self.train_hist['per_epoch_time'] = []
147 | self.train_hist['total_time'] = []
148 |
149 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
150 | if self.gpu_mode:
151 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()
152 |
153 | self.D.train()
154 | print('training start!!')
155 | start_time = time.time()
156 | for epoch in range(self.epoch):
157 | self.G.train()
158 | epoch_start_time = time.time()
159 | for iter, (x_, y_) in enumerate(self.data_loader):
160 | if iter == self.data_loader.dataset.__len__() // self.batch_size:
161 | break
162 | z_ = torch.rand((self.batch_size, self.z_dim))
163 | y_vec_ = torch.zeros((self.batch_size, self.class_num)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1)
164 | if self.gpu_mode:
165 | x_, z_, y_vec_ = x_.cuda(), z_.cuda(), y_vec_.cuda()
166 |
167 | # update D network
168 | self.D_optimizer.zero_grad()
169 |
170 | D_real, C_real = self.D(x_)
171 | D_real_loss = self.BCE_loss(D_real, self.y_real_)
172 | C_real_loss = self.CE_loss(C_real, torch.max(y_vec_, 1)[1])
173 |
174 | G_ = self.G(z_, y_vec_)
175 | D_fake, C_fake = self.D(G_.detach())  # detach so the D update does not backpropagate into G
176 | D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)
177 | C_fake_loss = self.CE_loss(C_fake, torch.max(y_vec_, 1)[1])
178 |
179 | D_loss = D_real_loss + C_real_loss + D_fake_loss + C_fake_loss
180 | self.train_hist['D_loss'].append(D_loss.item())
181 |
182 | D_loss.backward()
183 | self.D_optimizer.step()
184 |
185 | # update G network
186 | self.G_optimizer.zero_grad()
187 |
188 | G_ = self.G(z_, y_vec_)
189 | D_fake, C_fake = self.D(G_)
190 |
191 | G_loss = self.BCE_loss(D_fake, self.y_real_)
192 | C_fake_loss = self.CE_loss(C_fake, torch.max(y_vec_, 1)[1])
193 |
194 | G_loss += C_fake_loss
195 | self.train_hist['G_loss'].append(G_loss.item())
196 |
197 | G_loss.backward()
198 | self.G_optimizer.step()
199 |
200 | if ((iter + 1) % 100) == 0:
201 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
202 | ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item()))
203 |
204 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
205 | with torch.no_grad():
206 | self.visualize_results((epoch+1))
207 |
208 | self.train_hist['total_time'].append(time.time() - start_time)
209 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
210 | self.epoch, self.train_hist['total_time'][0]))
211 | print("Training finish!... save training results")
212 |
213 | self.save()
214 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
215 | self.epoch)
216 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
217 |
218 | def visualize_results(self, epoch, fix=True):
219 | self.G.eval()
220 |
221 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
222 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)
223 |
224 | image_frame_dim = int(np.floor(np.sqrt(self.sample_num)))
225 |
226 | if fix:
227 | """ fixed noise """
228 | samples = self.G(self.sample_z_, self.sample_y_)
229 | else:
230 | """ random noise """
231 | sample_y_ = torch.zeros(self.batch_size, self.class_num).scatter_(1, torch.randint(0, self.class_num, (self.batch_size, 1)).type(torch.LongTensor), 1)  # randint's upper bound is exclusive, so class_num (not class_num - 1) covers all classes
232 | sample_z_ = torch.rand((self.batch_size, self.z_dim))
233 | if self.gpu_mode:
234 | sample_z_, sample_y_ = sample_z_.cuda(), sample_y_.cuda()
235 |
236 | samples = self.G(sample_z_, sample_y_)
237 |
238 | if self.gpu_mode:
239 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
240 | else:
241 | samples = samples.data.numpy().transpose(0, 2, 3, 1)
242 |
243 | samples = (samples + 1) / 2
244 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
245 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')
246 |
247 | def save(self):
248 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
249 |
250 | if not os.path.exists(save_dir):
251 | os.makedirs(save_dir)
252 |
253 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
254 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))
255 |
256 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
257 | pickle.dump(self.train_hist, f)
258 |
259 | def load(self):
260 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
261 |
262 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
263 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
264 |
--------------------------------------------------------------------------------
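Note on ACGAN.py: the discriminator returns two heads, d = self.dc(x) (real/fake probability) and c = self.cl(x) (class logits). As a sketch, in the notation of the ACGAN paper (https://arxiv.org/abs/1610.09585) the code above optimizes

    L_S = E[\log P(S = \mathrm{real} \mid X_{real})] + E[\log P(S = \mathrm{fake} \mid X_{fake})]
    L_C = E[\log P(C = c \mid X_{real})] + E[\log P(C = c \mid X_{fake})]

with D trained to maximize L_S + L_C and G trained to maximize L_C - L_S; D_loss and G_loss in train() are the corresponding BCE/cross-entropy minimization forms.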
/BEGAN.py:
--------------------------------------------------------------------------------
1 | import utils, torch, time, os, pickle
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from dataloader import dataloader
6 |
7 | class generator(nn.Module):
8 | # Network architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
9 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
10 | def __init__(self, input_dim=100, output_dim=1, input_size=32):
11 | super(generator, self).__init__()
12 | self.input_dim = input_dim
13 | self.output_dim = output_dim
14 | self.input_size = input_size
15 |
16 | self.fc = nn.Sequential(
17 | nn.Linear(self.input_dim, 1024),
18 | nn.BatchNorm1d(1024),
19 | nn.ReLU(),
20 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
21 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
22 | nn.ReLU(),
23 | )
24 | self.deconv = nn.Sequential(
25 | nn.ConvTranspose2d(128, 64, 4, 2, 1),
26 | nn.BatchNorm2d(64),
27 | nn.ReLU(),
28 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
29 | nn.Tanh(),
30 | )
31 | utils.initialize_weights(self)
32 |
33 | def forward(self, input):
34 | x = self.fc(input)
35 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
36 | x = self.deconv(x)
37 |
38 | return x
39 |
40 | class discriminator(nn.Module):
41 | # It must be an auto-encoder style architecture
42 | # Architecture : (64)4c2s-FC32-FC64*14*14_BR-(1)4dc2s_S
43 | def __init__(self, input_dim=1, output_dim=1, input_size=32):
44 | super(discriminator, self).__init__()
45 | self.input_dim = input_dim
46 | self.output_dim = output_dim
47 | self.input_size = input_size
48 |
49 | self.conv = nn.Sequential(
50 | nn.Conv2d(self.input_dim, 64, 4, 2, 1),
51 | nn.ReLU(),
52 | )
53 | self.fc = nn.Sequential(
54 | nn.Linear(64 * (self.input_size // 2) * (self.input_size // 2), 32),
55 | nn.Linear(32, 64 * (self.input_size // 2) * (self.input_size // 2)),
56 | )
57 | self.deconv = nn.Sequential(
58 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
59 | #nn.Sigmoid(),
60 | )
61 |
62 | utils.initialize_weights(self)
63 |
64 | def forward(self, input):
65 | x = self.conv(input)
66 | x = x.view(x.size()[0], -1)
67 | x = self.fc(x)
68 | x = x.view(-1, 64, (self.input_size // 2), (self.input_size // 2))
69 | x = self.deconv(x)
70 |
71 | return x
72 |
73 | class BEGAN(object):
74 | def __init__(self, args):
75 | # parameters
76 | self.epoch = args.epoch
77 | self.sample_num = 100
78 | self.batch_size = args.batch_size
79 | self.save_dir = args.save_dir
80 | self.result_dir = args.result_dir
81 | self.dataset = args.dataset
82 | self.log_dir = args.log_dir
83 | self.gpu_mode = args.gpu_mode
84 | self.model_name = args.gan_type
85 | self.input_size = args.input_size
86 | self.z_dim = 62
87 | self.gamma = 1
88 | self.lambda_ = 0.001
89 | self.k = 0.0
90 | self.lr_lower_boundary = 0.00002
91 |
92 | # load dataset
93 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)
94 | data = next(iter(self.data_loader))[0]  # peek at one batch to infer the number of image channels
95 |
96 | # networks init
97 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size)
98 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size)
99 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=0.0002, betas=(args.beta1, args.beta2))
100 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=0.0002, betas=(args.beta1, args.beta2))
101 |
102 | if self.gpu_mode:
103 | self.G.cuda()
104 | self.D.cuda()
105 |
106 | print('---------- Networks architecture -------------')
107 | utils.print_network(self.G)
108 | utils.print_network(self.D)
109 | print('-----------------------------------------------')
110 |
111 | # fixed noise
112 | self.sample_z_ = torch.rand((self.batch_size, self.z_dim))
113 | if self.gpu_mode:
114 | self.sample_z_ = self.sample_z_.cuda()
115 |
116 | def train(self):
117 | self.train_hist = {}
118 | self.train_hist['D_loss'] = []
119 | self.train_hist['G_loss'] = []
120 | self.train_hist['per_epoch_time'] = []
121 | self.train_hist['total_time'] = []
122 | self.M = {}
123 | self.M['pre'] = []
124 | self.M['pre'].append(1)
125 | self.M['cur'] = []
126 |
127 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
128 | if self.gpu_mode:
129 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()
130 |
131 | self.D.train()
132 | print('training start!!')
133 | start_time = time.time()
134 | for epoch in range(self.epoch):
135 | self.G.train()
136 | epoch_start_time = time.time()
137 | for iter, (x_, _) in enumerate(self.data_loader):
138 | if iter == self.data_loader.dataset.__len__() // self.batch_size:
139 | break
140 |
141 | z_ = torch.rand((self.batch_size, self.z_dim))
142 |
143 | if self.gpu_mode:
144 | x_, z_ = x_.cuda(), z_.cuda()
145 |
146 | # update D network
147 | self.D_optimizer.zero_grad()
148 |
149 | D_real = self.D(x_)
150 | D_real_loss = torch.mean(torch.abs(D_real - x_))
151 |
152 | G_ = self.G(z_)
153 | D_fake = self.D(G_.detach())  # detach G_ so the D update does not backpropagate into G
154 | D_fake_loss = torch.mean(torch.abs(D_fake - G_.detach()))
155 |
156 | D_loss = D_real_loss - self.k * D_fake_loss
157 | self.train_hist['D_loss'].append(D_loss.item())
158 |
159 | D_loss.backward()
160 | self.D_optimizer.step()
161 |
162 | # update G network
163 | self.G_optimizer.zero_grad()
164 |
165 | G_ = self.G(z_)
166 | D_fake = self.D(G_)
167 | D_fake_loss = torch.mean(torch.abs(D_fake - G_))
168 |
169 | G_loss = D_fake_loss
170 | self.train_hist['G_loss'].append(G_loss.item())
171 |
172 | G_loss.backward()
173 | self.G_optimizer.step()
174 |
175 | # convergence metric
176 | temp_M = D_real_loss + torch.abs(self.gamma * D_real_loss - G_loss)
177 |
178 | # operation for updating k
179 | temp_k = self.k + self.lambda_ * (self.gamma * D_real_loss - G_loss)
180 | temp_k = temp_k.item()
181 |
182 | self.k = min(max(temp_k, 0), 1)
183 | self.M['cur'].append(temp_M.item())  # accumulate per-iteration M; np.mean below then gives the epoch average
184 |
185 | if ((iter + 1) % 100) == 0:
186 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f, M: %.8f, k: %.8f" %
187 | ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item(), temp_M.item(), self.k))
188 |
189 |
190 | # if epoch == 0:
191 | # self.M['pre'] = self.M['cur']
192 | # self.M['cur'] = []
193 | # else:
194 | if np.mean(self.M['pre']) < np.mean(self.M['cur']):
195 | pre_lr = self.G_optimizer.param_groups[0]['lr']
196 | self.G_optimizer.param_groups[0]['lr'] = max(self.G_optimizer.param_groups[0]['lr'] / 2.0,
197 | self.lr_lower_boundary)
198 | self.D_optimizer.param_groups[0]['lr'] = max(self.D_optimizer.param_groups[0]['lr'] / 2.0,
199 | self.lr_lower_boundary)
200 | print('M_pre: ' + str(np.mean(self.M['pre'])) + ', M_cur: ' + str(
201 | np.mean(self.M['cur'])) + ', lr: ' + str(pre_lr) + ' --> ' + str(
202 | self.G_optimizer.param_groups[0]['lr']))
203 | else:
204 | print('M_pre: ' + str(np.mean(self.M['pre'])) + ', M_cur: ' + str(np.mean(self.M['cur'])))
205 | self.M['pre'] = self.M['cur']
206 |
207 | self.M['cur'] = []
208 |
209 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
210 | with torch.no_grad():
211 | self.visualize_results((epoch+1))
212 |
213 | self.train_hist['total_time'].append(time.time() - start_time)
214 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
215 | self.epoch, self.train_hist['total_time'][0]))
216 | print("Training finish!... save training results")
217 |
218 | self.save()
219 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
220 | self.epoch)
221 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
222 |
223 | def visualize_results(self, epoch, fix=True):
224 | self.G.eval()
225 |
226 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
227 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)
228 |
229 | tot_num_samples = min(self.sample_num, self.batch_size)
230 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
231 |
232 | if fix:
233 | """ fixed noise """
234 | samples = self.G(self.sample_z_)
235 | else:
236 | """ random noise """
237 | sample_z_ = torch.rand((self.batch_size, self.z_dim))
238 | if self.gpu_mode:
239 | sample_z_ = sample_z_.cuda()
240 |
241 | samples = self.G(sample_z_)
242 |
243 | if self.gpu_mode:
244 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
245 | else:
246 | samples = samples.data.numpy().transpose(0, 2, 3, 1)
247 |
248 | samples = (samples + 1) / 2
249 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
250 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')
251 |
252 | def save(self):
253 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
254 |
255 | if not os.path.exists(save_dir):
256 | os.makedirs(save_dir)
257 |
258 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
259 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))
260 |
261 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
262 | pickle.dump(self.train_hist, f)
263 |
264 | def load(self):
265 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
266 |
267 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
268 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
--------------------------------------------------------------------------------
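Note on BEGAN.py: the discriminator is an auto-encoder, and L(v) = |v - D(v)| is its per-pixel reconstruction error (the torch.mean(torch.abs(...)) terms above). The training step then follows the BEGAN paper (https://arxiv.org/abs/1703.10717):

    L_D = L(x) - k_t \, L(G(z))
    L_G = L(G(z))
    k_{t+1} = \mathrm{clip}_{[0,1]}\big(k_t + \lambda_k (\gamma L(x) - L(G(z)))\big)
    M_{global} = L(x) + |\gamma L(x) - L(G(z))|

with k_t = self.k, \lambda_k = self.lambda_ and \gamma = self.gamma; M_global is the convergence measure used at the end of each epoch to halve the learning rate (down to lr_lower_boundary) when it stops decreasing.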
/CGAN.py:
--------------------------------------------------------------------------------
1 | import utils, torch, time, os, pickle
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from dataloader import dataloader
6 |
7 | class generator(nn.Module):
8 | # Network architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
9 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
10 | def __init__(self, input_dim=100, output_dim=1, input_size=32, class_num=10):
11 | super(generator, self).__init__()
12 | self.input_dim = input_dim
13 | self.output_dim = output_dim
14 | self.input_size = input_size
15 | self.class_num = class_num
16 |
17 | self.fc = nn.Sequential(
18 | nn.Linear(self.input_dim + self.class_num, 1024),
19 | nn.BatchNorm1d(1024),
20 | nn.ReLU(),
21 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
22 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
23 | nn.ReLU(),
24 | )
25 | self.deconv = nn.Sequential(
26 | nn.ConvTranspose2d(128, 64, 4, 2, 1),
27 | nn.BatchNorm2d(64),
28 | nn.ReLU(),
29 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
30 | nn.Tanh(),
31 | )
32 | utils.initialize_weights(self)
33 |
34 | def forward(self, input, label):
35 | x = torch.cat([input, label], 1)
36 | x = self.fc(x)
37 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
38 | x = self.deconv(x)
39 |
40 | return x
41 |
42 | class discriminator(nn.Module):
43 | # Network architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
44 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
45 | def __init__(self, input_dim=1, output_dim=1, input_size=32, class_num=10):
46 | super(discriminator, self).__init__()
47 | self.input_dim = input_dim
48 | self.output_dim = output_dim
49 | self.input_size = input_size
50 | self.class_num = class_num
51 |
52 | self.conv = nn.Sequential(
53 | nn.Conv2d(self.input_dim + self.class_num, 64, 4, 2, 1),
54 | nn.LeakyReLU(0.2),
55 | nn.Conv2d(64, 128, 4, 2, 1),
56 | nn.BatchNorm2d(128),
57 | nn.LeakyReLU(0.2),
58 | )
59 | self.fc = nn.Sequential(
60 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024),
61 | nn.BatchNorm1d(1024),
62 | nn.LeakyReLU(0.2),
63 | nn.Linear(1024, self.output_dim),
64 | nn.Sigmoid(),
65 | )
66 | utils.initialize_weights(self)
67 |
68 | def forward(self, input, label):
69 | x = torch.cat([input, label], 1)
70 | x = self.conv(x)
71 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
72 | x = self.fc(x)
73 |
74 | return x
75 |
76 | class CGAN(object):
77 | def __init__(self, args):
78 | # parameters
79 | self.epoch = args.epoch
80 | self.batch_size = args.batch_size
81 | self.save_dir = args.save_dir
82 | self.result_dir = args.result_dir
83 | self.dataset = args.dataset
84 | self.log_dir = args.log_dir
85 | self.gpu_mode = args.gpu_mode
86 | self.model_name = args.gan_type
87 | self.input_size = args.input_size
88 | self.z_dim = 62
89 | self.class_num = 10
90 | self.sample_num = self.class_num ** 2
91 |
92 | # load dataset
93 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)
94 | data = next(iter(self.data_loader))[0]  # peek at one batch to infer the number of image channels
95 |
96 | # networks init
97 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size, class_num=self.class_num)
98 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size, class_num=self.class_num)
99 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
100 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))
101 |
102 | if self.gpu_mode:
103 | self.G.cuda()
104 | self.D.cuda()
105 | self.BCE_loss = nn.BCELoss().cuda()
106 | else:
107 | self.BCE_loss = nn.BCELoss()
108 |
109 | print('---------- Networks architecture -------------')
110 | utils.print_network(self.G)
111 | utils.print_network(self.D)
112 | print('-----------------------------------------------')
113 |
114 | # fixed noise & condition
115 | self.sample_z_ = torch.zeros((self.sample_num, self.z_dim))
116 | for i in range(self.class_num):
117 | self.sample_z_[i*self.class_num] = torch.rand(1, self.z_dim)
118 | for j in range(1, self.class_num):
119 | self.sample_z_[i*self.class_num + j] = self.sample_z_[i*self.class_num]
120 |
121 | temp = torch.zeros((self.class_num, 1))
122 | for i in range(self.class_num):
123 | temp[i, 0] = i
124 |
125 | temp_y = torch.zeros((self.sample_num, 1))
126 | for i in range(self.class_num):
127 | temp_y[i*self.class_num: (i+1)*self.class_num] = temp
128 |
129 | self.sample_y_ = torch.zeros((self.sample_num, self.class_num)).scatter_(1, temp_y.type(torch.LongTensor), 1)
130 | if self.gpu_mode:
131 | self.sample_z_, self.sample_y_ = self.sample_z_.cuda(), self.sample_y_.cuda()
132 |
133 | def train(self):
134 | self.train_hist = {}
135 | self.train_hist['D_loss'] = []
136 | self.train_hist['G_loss'] = []
137 | self.train_hist['per_epoch_time'] = []
138 | self.train_hist['total_time'] = []
139 |
140 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
141 | if self.gpu_mode:
142 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()
143 |
144 | self.D.train()
145 | print('training start!!')
146 | start_time = time.time()
147 | for epoch in range(self.epoch):
148 | self.G.train()
149 | epoch_start_time = time.time()
150 | for iter, (x_, y_) in enumerate(self.data_loader):
151 | if iter == self.data_loader.dataset.__len__() // self.batch_size:
152 | break
153 |
154 | z_ = torch.rand((self.batch_size, self.z_dim))
155 | y_vec_ = torch.zeros((self.batch_size, self.class_num)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1)
156 | y_fill_ = y_vec_.unsqueeze(2).unsqueeze(3).expand(self.batch_size, self.class_num, self.input_size, self.input_size)
157 | if self.gpu_mode:
158 | x_, z_, y_vec_, y_fill_ = x_.cuda(), z_.cuda(), y_vec_.cuda(), y_fill_.cuda()
159 |
160 | # update D network
161 | self.D_optimizer.zero_grad()
162 |
163 | D_real = self.D(x_, y_fill_)
164 | D_real_loss = self.BCE_loss(D_real, self.y_real_)
165 |
166 | G_ = self.G(z_, y_vec_)
167 | D_fake = self.D(G_.detach(), y_fill_)  # detach so the D update does not backpropagate into G
168 | D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)
169 |
170 | D_loss = D_real_loss + D_fake_loss
171 | self.train_hist['D_loss'].append(D_loss.item())
172 |
173 | D_loss.backward()
174 | self.D_optimizer.step()
175 |
176 | # update G network
177 | self.G_optimizer.zero_grad()
178 |
179 | G_ = self.G(z_, y_vec_)
180 | D_fake = self.D(G_, y_fill_)
181 | G_loss = self.BCE_loss(D_fake, self.y_real_)
182 | self.train_hist['G_loss'].append(G_loss.item())
183 |
184 | G_loss.backward()
185 | self.G_optimizer.step()
186 |
187 | if ((iter + 1) % 100) == 0:
188 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
189 | ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item()))
190 |
191 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
192 | with torch.no_grad():
193 | self.visualize_results((epoch+1))
194 |
195 | self.train_hist['total_time'].append(time.time() - start_time)
196 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
197 | self.epoch, self.train_hist['total_time'][0]))
198 | print("Training finish!... save training results")
199 |
200 | self.save()
201 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
202 | self.epoch)
203 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
204 |
205 | def visualize_results(self, epoch, fix=True):
206 | self.G.eval()
207 |
208 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
209 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)
210 |
211 | image_frame_dim = int(np.floor(np.sqrt(self.sample_num)))
212 |
213 | if fix:
214 | """ fixed noise """
215 | samples = self.G(self.sample_z_, self.sample_y_)
216 | else:
217 | """ random noise """
218 | sample_y_ = torch.zeros(self.batch_size, self.class_num).scatter_(1, torch.randint(0, self.class_num, (self.batch_size, 1)).type(torch.LongTensor), 1)  # randint's upper bound is exclusive, so class_num (not class_num - 1) covers all classes
219 | sample_z_ = torch.rand((self.batch_size, self.z_dim))
220 | if self.gpu_mode:
221 | sample_z_, sample_y_ = sample_z_.cuda(), sample_y_.cuda()
222 |
223 | samples = self.G(sample_z_, sample_y_)
224 |
225 | if self.gpu_mode:
226 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
227 | else:
228 | samples = samples.data.numpy().transpose(0, 2, 3, 1)
229 |
230 | samples = (samples + 1) / 2
231 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
232 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')
233 |
234 | def save(self):
235 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
236 |
237 | if not os.path.exists(save_dir):
238 | os.makedirs(save_dir)
239 |
240 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
241 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))
242 |
243 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
244 | pickle.dump(self.train_hist, f)
245 |
246 | def load(self):
247 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
248 |
249 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
250 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
--------------------------------------------------------------------------------
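Note on CGAN.py: the label enters twice, as a one-hot vector concatenated with z in the generator and as a spatially broadcast one-hot map concatenated with the image channels in the discriminator. A minimal, self-contained sketch of the two encodings built in train() (batch_size, class_num and input_size here are example values, not taken from args):

import torch

batch_size, class_num, input_size = 4, 10, 32
y_ = torch.randint(0, class_num, (batch_size,))     # integer labels, as yielded by the dataloader

# one-hot vectors for generator.forward (concatenated with z along dim 1)
y_vec_ = torch.zeros(batch_size, class_num).scatter_(1, y_.unsqueeze(1), 1)

# the same one-hot broadcast to (N, class_num, H, W) for discriminator.forward
y_fill_ = y_vec_.unsqueeze(2).unsqueeze(3).expand(batch_size, class_num, input_size, input_size)

print(y_vec_.shape)   # torch.Size([4, 10])
print(y_fill_.shape)  # torch.Size([4, 10, 32, 32])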
/DRAGAN.py:
--------------------------------------------------------------------------------
1 | import utils, torch, time, os, pickle
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from torch.autograd import grad
6 | from dataloader import dataloader
7 |
8 | class generator(nn.Module):
9 | # Network architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
10 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
11 | def __init__(self, input_dim=100, output_dim=1, input_size=32):
12 | super(generator, self).__init__()
13 | self.input_dim = input_dim
14 | self.output_dim = output_dim
15 | self.input_size = input_size
16 |
17 | self.fc = nn.Sequential(
18 | nn.Linear(self.input_dim, 1024),
19 | nn.BatchNorm1d(1024),
20 | nn.ReLU(),
21 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
22 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
23 | nn.ReLU(),
24 | )
25 | self.deconv = nn.Sequential(
26 | nn.ConvTranspose2d(128, 64, 4, 2, 1),
27 | nn.BatchNorm2d(64),
28 | nn.ReLU(),
29 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
30 | nn.Tanh(),
31 | )
32 | utils.initialize_weights(self)
33 |
34 | def forward(self, input):
35 | x = self.fc(input)
36 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
37 | x = self.deconv(x)
38 |
39 | return x
40 |
41 | class discriminator(nn.Module):
42 | # Network architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
43 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
44 | def __init__(self, input_dim=1, output_dim=1, input_size=32):
45 | super(discriminator, self).__init__()
46 | self.input_dim = input_dim
47 | self.output_dim = output_dim
48 | self.input_size = input_size
49 |
50 | self.conv = nn.Sequential(
51 | nn.Conv2d(self.input_dim, 64, 4, 2, 1),
52 | nn.LeakyReLU(0.2),
53 | nn.Conv2d(64, 128, 4, 2, 1),
54 | nn.BatchNorm2d(128),
55 | nn.LeakyReLU(0.2),
56 | )
57 | self.fc = nn.Sequential(
58 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024),
59 | nn.BatchNorm1d(1024),
60 | nn.LeakyReLU(0.2),
61 | nn.Linear(1024, self.output_dim),
62 | nn.Sigmoid(),
63 | )
64 | utils.initialize_weights(self)
65 |
66 | def forward(self, input):
67 | x = self.conv(input)
68 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
69 | x = self.fc(x)
70 |
71 | return x
72 |
73 | class DRAGAN(object):
74 | def __init__(self, args):
75 | # parameters
76 | self.epoch = args.epoch
77 | self.sample_num = 100
78 | self.batch_size = args.batch_size
79 | self.save_dir = args.save_dir
80 | self.result_dir = args.result_dir
81 | self.dataset = args.dataset
82 | self.log_dir = args.log_dir
83 | self.gpu_mode = args.gpu_mode
84 | self.model_name = args.gan_type
85 | self.input_size = args.input_size
86 | self.z_dim = 62
87 | self.lambda_ = 0.25
88 |
89 | # load dataset
90 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)
91 | data = next(iter(self.data_loader))[0]  # peek at one batch to infer the number of image channels
92 |
93 | # networks init
94 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size)
95 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size)
96 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
97 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))
98 |
99 | if self.gpu_mode:
100 | self.G.cuda()
101 | self.D.cuda()
102 | self.BCE_loss = nn.BCELoss().cuda()
103 | else:
104 | self.BCE_loss = nn.BCELoss()
105 |
106 | print('---------- Networks architecture -------------')
107 | utils.print_network(self.G)
108 | utils.print_network(self.D)
109 | print('-----------------------------------------------')
110 |
111 | # fixed noise
112 | self.sample_z_ = torch.rand((self.batch_size, self.z_dim))
113 | if self.gpu_mode:
114 | self.sample_z_ = self.sample_z_.cuda()
115 |
116 | def train(self):
117 | self.train_hist = {}
118 | self.train_hist['D_loss'] = []
119 | self.train_hist['G_loss'] = []
120 | self.train_hist['per_epoch_time'] = []
121 | self.train_hist['total_time'] = []
122 |
123 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
124 | if self.gpu_mode:
125 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()
126 |
127 | self.D.train()
128 | print('training start!!')
129 | start_time = time.time()
130 | for epoch in range(self.epoch):
131 | epoch_start_time = time.time()
132 | self.G.train()
133 | for iter, (x_, _) in enumerate(self.data_loader):
134 | if iter == self.data_loader.dataset.__len__() // self.batch_size:
135 | break
136 |
137 | z_ = torch.rand((self.batch_size, self.z_dim))
138 | if self.gpu_mode:
139 | x_, z_ = x_.cuda(), z_.cuda()
140 |
141 | # update D network
142 | self.D_optimizer.zero_grad()
143 |
144 | D_real = self.D(x_)
145 | D_real_loss = self.BCE_loss(D_real, self.y_real_)
146 |
147 | G_ = self.G(z_)
148 | D_fake = self.D(G_.detach())  # detach so the D update does not backpropagate into G
149 | D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)
150 |
151 | """ DRAGAN Loss (Gradient penalty) """
152 | # This is borrowed from https://github.com/kodalinaveen3/DRAGAN/blob/master/DRAGAN.ipynb
153 | alpha = torch.rand(self.batch_size, 1, 1, 1)  # created on CPU; moved to the GPU below only when gpu_mode is set
154 | if self.gpu_mode:
155 | alpha = alpha.cuda()
156 | x_p = x_ + 0.5 * x_.std() * torch.rand(x_.size()).cuda()
157 | else:
158 | x_p = x_ + 0.5 * x_.std() * torch.rand(x_.size())
159 | differences = x_p - x_
160 | interpolates = x_ + (alpha * differences)
161 | interpolates.requires_grad = True
162 | pred_hat = self.D(interpolates)
163 | if self.gpu_mode:
164 | gradients = grad(outputs=pred_hat, inputs=interpolates, grad_outputs=torch.ones(pred_hat.size()).cuda(),
165 | create_graph=True, retain_graph=True, only_inputs=True)[0]
166 | else:
167 | gradients = grad(outputs=pred_hat, inputs=interpolates, grad_outputs=torch.ones(pred_hat.size()),
168 | create_graph=True, retain_graph=True, only_inputs=True)[0]
169 |
170 | gradient_penalty = self.lambda_ * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean()
171 |
172 | D_loss = D_real_loss + D_fake_loss + gradient_penalty
173 | self.train_hist['D_loss'].append(D_loss.item())
174 | D_loss.backward()
175 | self.D_optimizer.step()
176 |
177 | # update G network
178 | self.G_optimizer.zero_grad()
179 |
180 | G_ = self.G(z_)
181 | D_fake = self.D(G_)
182 |
183 | G_loss = self.BCE_loss(D_fake, self.y_real_)
184 | self.train_hist['G_loss'].append(G_loss.item())
185 |
186 | G_loss.backward()
187 | self.G_optimizer.step()
188 |
189 | if ((iter + 1) % 100) == 0:
190 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
191 | ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item()))
192 |
193 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
194 | with torch.no_grad():
195 | self.visualize_results((epoch+1))
196 |
197 | self.train_hist['total_time'].append(time.time() - start_time)
198 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
199 | self.epoch, self.train_hist['total_time'][0]))
200 | print("Training finish!... save training results")
201 |
202 | self.save()
203 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, self.epoch)
204 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
205 |
206 | def visualize_results(self, epoch, fix=True):
207 | self.G.eval()
208 |
209 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
210 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)
211 |
212 | tot_num_samples = min(self.sample_num, self.batch_size)
213 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
214 |
215 | if fix:
216 | """ fixed noise """
217 | samples = self.G(self.sample_z_)
218 | else:
219 | """ random noise """
220 | sample_z_ = torch.rand((self.batch_size, self.z_dim))
221 | if self.gpu_mode:
222 | sample_z_ = sample_z_.cuda()
223 |
224 | samples = self.G(sample_z_)
225 |
226 | if self.gpu_mode:
227 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
228 | else:
229 | samples = samples.data.numpy().transpose(0, 2, 3, 1)
230 |
231 | samples = (samples + 1) / 2
232 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
233 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')
234 |
235 | def save(self):
236 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
237 |
238 | if not os.path.exists(save_dir):
239 | os.makedirs(save_dir)
240 |
241 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
242 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))
243 |
244 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
245 | pickle.dump(self.train_hist, f)
246 |
247 | def load(self):
248 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
249 |
250 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
251 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
--------------------------------------------------------------------------------
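Note on DRAGAN.py: unlike WGAN-GP, the gradient penalty here is computed only in a neighborhood of the real data. With x_p = x + 0.5 * std(x) * u (u ~ U[0, 1)) and interpolates x_hat = x + alpha * (x_p - x), the code adds

    \lambda \, E_{\hat{x}} \big[ (\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2 \big]

to the ordinary BCE discriminator loss. lambda_ = 0.25 is this repository's setting; larger values such as 10 are common in the literature.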
/EBGAN.py:
--------------------------------------------------------------------------------
1 | import utils, torch, time, os, pickle
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from torch.autograd import Variable
6 | from torch.utils.data import DataLoader
7 | from torchvision import datasets, transforms
8 | from dataloader import dataloader
9 |
10 | class generator(nn.Module):
11 | # Network architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
12 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
13 | def __init__(self, input_dim=100, output_dim=1, input_size=32):
14 | super(generator, self).__init__()
15 | self.input_dim = input_dim
16 | self.output_dim = output_dim
17 | self.input_size = input_size
18 |
19 | self.fc = nn.Sequential(
20 | nn.Linear(self.input_dim, 1024),
21 | nn.BatchNorm1d(1024),
22 | nn.ReLU(),
23 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
24 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
25 | nn.ReLU(),
26 | )
27 | self.deconv = nn.Sequential(
28 | nn.ConvTranspose2d(128, 64, 4, 2, 1),
29 | nn.BatchNorm2d(64),
30 | nn.ReLU(),
31 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
32 | nn.Tanh(),
33 | )
34 | utils.initialize_weights(self)
35 |
36 | def forward(self, input):
37 | x = self.fc(input)
38 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
39 | x = self.deconv(x)
40 |
41 | return x
42 |
43 | class discriminator(nn.Module):
44 | # It must be an auto-encoder style architecture
45 | # Architecture : (64)4c2s-FC32-FC64*14*14_BR-(1)4dc2s_S
46 | def __init__(self, input_dim=1, output_dim=1, input_size=32):
47 | super(discriminator, self).__init__()
48 | self.input_dim = input_dim
49 | self.output_dim = output_dim
50 | self.input_size = input_size
51 |
52 | self.conv = nn.Sequential(
53 | nn.Conv2d(self.input_dim, 64, 4, 2, 1),
54 | nn.ReLU(),
55 | )
56 | self.code = nn.Sequential(
57 | nn.Linear(64 * (self.input_size // 2) * (self.input_size // 2), 32), # bn and relu are excluded since code is used in pullaway_loss
58 | )
59 | self.fc = nn.Sequential(
60 | nn.Linear(32, 64 * (self.input_size // 2) * (self.input_size // 2)),
61 | nn.BatchNorm1d(64 * (self.input_size // 2) * (self.input_size // 2)),
62 | nn.ReLU(),
63 | )
64 | self.deconv = nn.Sequential(
65 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
66 | # nn.Sigmoid(),
67 | )
68 | utils.initialize_weights(self)
69 |
70 | def forward(self, input):
71 | x = self.conv(input)
72 | x = x.view(x.size()[0], -1)
73 | code = self.code(x)
74 | x = self.fc(code)
75 | x = x.view(-1, 64, (self.input_size // 2), (self.input_size // 2))
76 | x = self.deconv(x)
77 |
78 | return x, code
79 |
80 | class EBGAN(object):
81 | def __init__(self, args):
82 | # parameters
83 | self.epoch = args.epoch
84 | self.sample_num = 100
85 | self.batch_size = args.batch_size
86 | self.save_dir = args.save_dir
87 | self.result_dir = args.result_dir
88 | self.dataset = args.dataset
89 | self.log_dir = args.log_dir
90 | self.gpu_mode = args.gpu_mode
91 | self.model_name = args.gan_type
92 | self.input_size = args.input_size
93 | self.z_dim = 62
94 | self.pt_loss_weight = 0.1
95 | self.margin = 1
96 |
97 | # load dataset
98 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)
99 | data = next(iter(self.data_loader))[0]  # peek at one batch to infer the number of image channels
100 |
101 | # networks init
102 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size)
103 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size)
104 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
105 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))
106 |
107 | if self.gpu_mode:
108 | self.G.cuda()
109 | self.D.cuda()
110 | self.MSE_loss = nn.MSELoss().cuda()
111 | else:
112 | self.MSE_loss = nn.MSELoss()
113 |
114 | print('---------- Networks architecture -------------')
115 | utils.print_network(self.G)
116 | utils.print_network(self.D)
117 | print('-----------------------------------------------')
118 |
119 | # fixed noise
120 | self.sample_z_ = torch.rand((self.batch_size, self.z_dim))
121 | if self.gpu_mode:
122 | self.sample_z_ = self.sample_z_.cuda()
123 |
124 | def train(self):
125 | self.train_hist = {}
126 | self.train_hist['D_loss'] = []
127 | self.train_hist['G_loss'] = []
128 | self.train_hist['per_epoch_time'] = []
129 | self.train_hist['total_time'] = []
130 |
131 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
132 | if self.gpu_mode:
133 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()
134 |
135 | self.D.train()
136 | print('training start!!')
137 | start_time = time.time()
138 | for epoch in range(self.epoch):
139 | self.G.train()
140 | epoch_start_time = time.time()
141 | for iter, (x_, _) in enumerate(self.data_loader):
142 | if iter == self.data_loader.dataset.__len__() // self.batch_size:
143 | break
144 |
145 | z_ = torch.rand((self.batch_size, self.z_dim))
146 | if self.gpu_mode:
147 | x_, z_ = x_.cuda(), z_.cuda()
148 |
149 | # update D network
150 | self.D_optimizer.zero_grad()
151 |
152 | D_real, _ = self.D(x_)
153 | D_real_loss = self.MSE_loss(D_real, x_)
154 |
155 | G_ = self.G(z_)
156 | D_fake, _ = self.D(G_.detach())  # detach G_ so the D update does not backpropagate into G
157 | D_fake_loss = self.MSE_loss(D_fake, G_.detach())
158 |
159 | D_loss = D_real_loss + torch.clamp(self.margin - D_fake_loss, min=0)
160 | self.train_hist['D_loss'].append(D_loss.item())
161 |
162 | D_loss.backward()
163 | self.D_optimizer.step()
164 |
165 | # update G network
166 | self.G_optimizer.zero_grad()
167 |
168 | G_ = self.G(z_)
169 | D_fake, D_fake_code = self.D(G_)
170 | D_fake_loss = self.MSE_loss(D_fake, G_.detach())
171 | G_loss = D_fake_loss + self.pt_loss_weight * self.pullaway_loss(D_fake_code.view(self.batch_size, -1))
172 | self.train_hist['G_loss'].append(G_loss.item())
173 |
174 | G_loss.backward()
175 | self.G_optimizer.step()
176 |
177 | if ((iter + 1) % 100) == 0:
178 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
179 | ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item()))
180 |
181 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
182 | with torch.no_grad():
183 | self.visualize_results((epoch+1))
184 |
185 | self.train_hist['total_time'].append(time.time() - start_time)
186 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
187 | self.epoch, self.train_hist['total_time'][0]))
188 | print("Training finish!... save training results")
189 |
190 | self.save()
191 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
192 | self.epoch)
193 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
194 |
195 | def pullaway_loss(self, embeddings):
196 | """ pullaway_loss tensorflow version code
197 |
198 | norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
199 | normalized_embeddings = embeddings / norm
200 | similarity = tf.matmul(
201 | normalized_embeddings, normalized_embeddings, transpose_b=True)
202 | batch_size = tf.cast(tf.shape(embeddings)[0], tf.float32)
203 | pt_loss = (tf.reduce_sum(similarity) - batch_size) / (batch_size * (batch_size - 1))
204 | return pt_loss
205 |
206 | """
207 | # norm = torch.sqrt(torch.sum(embeddings ** 2, 1, keepdim=True))
208 | # normalized_embeddings = embeddings / norm
209 | # similarity = torch.matmul(normalized_embeddings, normalized_embeddings.transpose(1, 0))
210 | # batch_size = embeddings.size()[0]
211 | # pt_loss = (torch.sum(similarity) - batch_size) / (batch_size * (batch_size - 1))
212 |
213 | norm = torch.norm(embeddings, p=2, dim=1, keepdim=True)  # per-row L2 norms, matching the TF reference above
214 | normalized_embeddings = embeddings / norm
215 | similarity = torch.matmul(normalized_embeddings, normalized_embeddings.transpose(1, 0)) ** 2
216 | batch_size = embeddings.size()[0]
217 | pt_loss = (torch.sum(similarity) - batch_size) / (batch_size * (batch_size - 1))
218 |
219 | return pt_loss
220 |
221 |
222 | def visualize_results(self, epoch, fix=True):
223 | self.G.eval()
224 |
225 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
226 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)
227 |
228 | tot_num_samples = min(self.sample_num, self.batch_size)
229 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
230 |
231 | if fix:
232 | """ fixed noise """
233 | samples = self.G(self.sample_z_)
234 | else:
235 | """ random noise """
236 | sample_z_ = torch.rand((self.batch_size, self.z_dim))
237 | if self.gpu_mode:
238 | sample_z_ = sample_z_.cuda()
239 |
240 | samples = self.G(sample_z_)
241 |
242 | if self.gpu_mode:
243 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
244 | else:
245 | samples = samples.data.numpy().transpose(0, 2, 3, 1)
246 |
247 | samples = (samples + 1) / 2
248 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
249 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')
250 |
251 | def save(self):
252 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
253 |
254 | if not os.path.exists(save_dir):
255 | os.makedirs(save_dir)
256 |
257 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
258 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))
259 |
260 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
261 | pickle.dump(self.train_hist, f)
262 |
263 | def load(self):
264 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
265 |
266 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
267 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
--------------------------------------------------------------------------------
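Note on EBGAN.py: with energy D(v) defined as the auto-encoder reconstruction error (the MSE terms above), the losses match the EBGAN paper (https://arxiv.org/abs/1609.03126):

    L_D = D(x) + [m - D(G(z))]^+                      (m = self.margin, [.]^+ = torch.clamp(., min=0))
    L_G = D(G(z)) + pt_loss_weight \cdot f_{PT}(S)

    f_{PT}(S) = \frac{1}{N(N-1)} \sum_{i} \sum_{j \ne i} \left( \frac{S_i^\top S_j}{\|S_i\| \, \|S_j\|} \right)^2

where S is the batch of encoder codes returned by discriminator.forward; the squared cosine similarities push codes apart. A self-contained sketch of the pulling-away term on random codes (illustrative values only):

import torch

S = torch.randn(8, 32)                        # batch of 8 codes, dim 32
Sn = S / S.norm(p=2, dim=1, keepdim=True)     # row-normalize
sim2 = (Sn @ Sn.t()) ** 2                     # squared cosine similarities
N = S.size(0)
f_pt = (sim2.sum() - N) / (N * (N - 1))       # drop the diagonal (each entry is 1)
print(f_pt)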
/GAN.py:
--------------------------------------------------------------------------------
1 | import utils, torch, time, os, pickle
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from dataloader import dataloader
6 |
7 | class generator(nn.Module):
8 | # Network architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
9 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
10 | def __init__(self, input_dim=100, output_dim=1, input_size=32):
11 | super(generator, self).__init__()
12 | self.input_dim = input_dim
13 | self.output_dim = output_dim
14 | self.input_size = input_size
15 |
16 | self.fc = nn.Sequential(
17 | nn.Linear(self.input_dim, 1024),
18 | nn.BatchNorm1d(1024),
19 | nn.ReLU(),
20 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
21 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
22 | nn.ReLU(),
23 | )
24 | self.deconv = nn.Sequential(
25 | nn.ConvTranspose2d(128, 64, 4, 2, 1),
26 | nn.BatchNorm2d(64),
27 | nn.ReLU(),
28 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
29 | nn.Tanh(),
30 | )
31 | utils.initialize_weights(self)
32 |
33 | def forward(self, input):
34 | x = self.fc(input)
35 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
36 | x = self.deconv(x)
37 |
38 | return x
39 |
40 | class discriminator(nn.Module):
41 | # Network architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
42 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
43 | def __init__(self, input_dim=1, output_dim=1, input_size=32):
44 | super(discriminator, self).__init__()
45 | self.input_dim = input_dim
46 | self.output_dim = output_dim
47 | self.input_size = input_size
48 |
49 | self.conv = nn.Sequential(
50 | nn.Conv2d(self.input_dim, 64, 4, 2, 1),
51 | nn.LeakyReLU(0.2),
52 | nn.Conv2d(64, 128, 4, 2, 1),
53 | nn.BatchNorm2d(128),
54 | nn.LeakyReLU(0.2),
55 | )
56 | self.fc = nn.Sequential(
57 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024),
58 | nn.BatchNorm1d(1024),
59 | nn.LeakyReLU(0.2),
60 | nn.Linear(1024, self.output_dim),
61 | nn.Sigmoid(),
62 | )
63 | utils.initialize_weights(self)
64 |
65 | def forward(self, input):
66 | x = self.conv(input)
67 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
68 | x = self.fc(x)
69 |
70 | return x
71 |
72 | class GAN(object):
73 | def __init__(self, args):
74 | # parameters
75 | self.epoch = args.epoch
76 | self.sample_num = 100
77 | self.batch_size = args.batch_size
78 | self.save_dir = args.save_dir
79 | self.result_dir = args.result_dir
80 | self.dataset = args.dataset
81 | self.log_dir = args.log_dir
82 | self.gpu_mode = args.gpu_mode
83 | self.model_name = args.gan_type
84 | self.input_size = args.input_size
85 | self.z_dim = 62
86 |
87 | # load dataset
88 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)
89 | data = next(iter(self.data_loader))[0]  # peek at one batch to infer the number of image channels
90 |
91 | # networks init
92 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size)
93 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size)
94 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
95 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))
96 |
97 | if self.gpu_mode:
98 | self.G.cuda()
99 | self.D.cuda()
100 | self.BCE_loss = nn.BCELoss().cuda()
101 | else:
102 | self.BCE_loss = nn.BCELoss()
103 |
104 | print('---------- Networks architecture -------------')
105 | utils.print_network(self.G)
106 | utils.print_network(self.D)
107 | print('-----------------------------------------------')
108 |
109 |
110 | # fixed noise
111 | self.sample_z_ = torch.rand((self.batch_size, self.z_dim))
112 | if self.gpu_mode:
113 | self.sample_z_ = self.sample_z_.cuda()
114 |
115 |
116 | def train(self):
117 | self.train_hist = {}
118 | self.train_hist['D_loss'] = []
119 | self.train_hist['G_loss'] = []
120 | self.train_hist['per_epoch_time'] = []
121 | self.train_hist['total_time'] = []
122 |
123 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
124 | if self.gpu_mode:
125 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()
126 |
127 | self.D.train()
128 | print('training start!!')
129 | start_time = time.time()
130 | for epoch in range(self.epoch):
131 | self.G.train()
132 | epoch_start_time = time.time()
133 | for iter, (x_, _) in enumerate(self.data_loader):
134 | if iter == len(self.data_loader.dataset) // self.batch_size:  # skip the final incomplete batch
135 | break
136 |
137 | z_ = torch.rand((self.batch_size, self.z_dim))
138 | if self.gpu_mode:
139 | x_, z_ = x_.cuda(), z_.cuda()
140 |
141 | # update D network
142 | self.D_optimizer.zero_grad()
143 |
144 | D_real = self.D(x_)
145 | D_real_loss = self.BCE_loss(D_real, self.y_real_)
146 |
147 | G_ = self.G(z_)
148 | D_fake = self.D(G_)
149 | D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)
150 |
151 | D_loss = D_real_loss + D_fake_loss
152 | self.train_hist['D_loss'].append(D_loss.item())
153 |
154 | D_loss.backward()
155 | self.D_optimizer.step()
156 |
157 | # update G network
158 | self.G_optimizer.zero_grad()
159 |
160 | G_ = self.G(z_)
161 | D_fake = self.D(G_)
162 | G_loss = self.BCE_loss(D_fake, self.y_real_)  # non-saturating loss: maximize log D(G(z)) instead of minimizing log(1 - D(G(z)))
163 | self.train_hist['G_loss'].append(G_loss.item())
164 |
165 | G_loss.backward()
166 | self.G_optimizer.step()
167 |
168 | if ((iter + 1) % 100) == 0:
169 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
170 | ((epoch + 1), (iter + 1), len(self.data_loader.dataset) // self.batch_size, D_loss.item(), G_loss.item()))
171 |
172 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
173 | with torch.no_grad():
174 | self.visualize_results((epoch+1))
175 |
176 | self.train_hist['total_time'].append(time.time() - start_time)
177 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
178 | self.epoch, self.train_hist['total_time'][0]))
179 | print("Training finish!... save training results")
180 |
181 | self.save()
182 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
183 | self.epoch)
184 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
185 |
186 | def visualize_results(self, epoch, fix=True):
187 | self.G.eval()
188 |
189 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
190 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)
191 |
192 | tot_num_samples = min(self.sample_num, self.batch_size)
193 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
194 |
195 | if fix:
196 | """ fixed noise """
197 | samples = self.G(self.sample_z_)
198 | else:
199 | """ random noise """
200 | sample_z_ = torch.rand((self.batch_size, self.z_dim))
201 | if self.gpu_mode:
202 | sample_z_ = sample_z_.cuda()
203 |
204 | samples = self.G(sample_z_)
205 |
206 | if self.gpu_mode:
207 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
208 | else:
209 | samples = samples.data.numpy().transpose(0, 2, 3, 1)
210 |
211 | samples = (samples + 1) / 2
212 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
213 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')
214 |
215 | def save(self):
216 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
217 |
218 | if not os.path.exists(save_dir):
219 | os.makedirs(save_dir)
220 |
221 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
222 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))
223 |
224 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
225 | pickle.dump(self.train_hist, f)
226 |
227 | def load(self):
228 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
229 |
230 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
231 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
--------------------------------------------------------------------------------
/LSGAN.py:
--------------------------------------------------------------------------------
1 | import utils, torch, time, os, pickle
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from dataloader import dataloader
6 |
7 | class generator(nn.Module):
8 | # Network Architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
9 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
10 | def __init__(self, input_dim=100, output_dim=1, input_size=32):
11 | super(generator, self).__init__()
12 | self.input_dim = input_dim
13 | self.output_dim = output_dim
14 | self.input_size = input_size
15 |
16 | self.fc = nn.Sequential(
17 | nn.Linear(self.input_dim, 1024),
18 | nn.BatchNorm1d(1024),
19 | nn.ReLU(),
20 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
21 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
22 | nn.ReLU(),
23 | )
24 | self.deconv = nn.Sequential(
25 | nn.ConvTranspose2d(128, 64, 4, 2, 1),
26 | nn.BatchNorm2d(64),
27 | nn.ReLU(),
28 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
29 | nn.Tanh(),
30 | )
31 | utils.initialize_weights(self)
32 |
33 | def forward(self, input):
34 | x = self.fc(input)
35 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
36 | x = self.deconv(x)
37 |
38 | return x
39 |
40 | class discriminator(nn.Module):
41 | # Network Architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
42 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
43 | def __init__(self, input_dim=1, output_dim=1, input_size=32):
44 | super(discriminator, self).__init__()
45 | self.input_dim = input_dim
46 | self.output_dim = output_dim
47 | self.input_size = input_size
48 |
49 | self.conv = nn.Sequential(
50 | nn.Conv2d(self.input_dim, 64, 4, 2, 1),
51 | nn.LeakyReLU(0.2),
52 | nn.Conv2d(64, 128, 4, 2, 1),
53 | nn.BatchNorm2d(128),
54 | nn.LeakyReLU(0.2),
55 | )
56 | self.fc = nn.Sequential(
57 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024),
58 | nn.BatchNorm1d(1024),
59 | nn.LeakyReLU(0.2),
60 | nn.Linear(1024, self.output_dim),
61 | # nn.Sigmoid(),  # intentionally disabled: LSGAN regresses raw scores with an MSE loss
62 | )
63 | utils.initialize_weights(self)
64 |
65 | def forward(self, input):
66 | x = self.conv(input)
67 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
68 | x = self.fc(x)
69 |
70 | return x
71 |
72 | class LSGAN(object):
73 | def __init__(self, args):
74 | # parameters
75 | self.epoch = args.epoch
76 | self.sample_num = 100
77 | self.batch_size = args.batch_size
78 | self.save_dir = args.save_dir
79 | self.result_dir = args.result_dir
80 | self.dataset = args.dataset
81 | self.log_dir = args.log_dir
82 | self.gpu_mode = args.gpu_mode
83 | self.model_name = args.gan_type
84 | self.input_size = args.input_size
85 | self.z_dim = 62
86 |
87 | # load dataset
88 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)
89 | data = next(iter(self.data_loader))[0]  # peek at one batch to infer the number of image channels
90 |
91 | # networks init
92 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size)
93 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size)
94 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
95 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))
96 |
97 | if self.gpu_mode:
98 | self.G.cuda()
99 | self.D.cuda()
100 | self.MSE_loss = nn.MSELoss().cuda()
101 | else:
102 | self.MSE_loss = nn.MSELoss()
103 |
104 | print('---------- Networks architecture -------------')
105 | utils.print_network(self.G)
106 | utils.print_network(self.D)
107 | print('-----------------------------------------------')
108 |
109 | # fixed noise
110 | self.sample_z_ = torch.rand((self.batch_size, self.z_dim))
111 | if self.gpu_mode:
112 | self.sample_z_ = self.sample_z_.cuda()
113 |
114 | def train(self):
115 | self.train_hist = {}
116 | self.train_hist['D_loss'] = []
117 | self.train_hist['G_loss'] = []
118 | self.train_hist['per_epoch_time'] = []
119 | self.train_hist['total_time'] = []
120 |
121 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
122 | if self.gpu_mode:
123 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()
124 |
125 | self.D.train()
126 | print('training start!!')
127 | start_time = time.time()
128 | for epoch in range(self.epoch):
129 | self.G.train()
130 | epoch_start_time = time.time()
131 | for iter, (x_, _) in enumerate(self.data_loader):
132 | if iter == len(self.data_loader.dataset) // self.batch_size:  # skip the final incomplete batch
133 | break
134 |
135 | z_ = torch.rand((self.batch_size, self.z_dim))
136 | if self.gpu_mode:
137 | x_, z_ = x_.cuda(), z_.cuda()
138 |
139 | # update D network
140 | self.D_optimizer.zero_grad()
141 |
142 | D_real = self.D(x_)
143 | D_real_loss = self.MSE_loss(D_real, self.y_real_)
144 |
145 | G_ = self.G(z_)
146 | D_fake = self.D(G_)
147 | D_fake_loss = self.MSE_loss(D_fake, self.y_fake_)
148 |
149 | D_loss = D_real_loss + D_fake_loss  # least-squares loss with targets b=1 (real) and a=0 (fake)
150 | self.train_hist['D_loss'].append(D_loss.item())
151 |
152 | D_loss.backward()
153 | self.D_optimizer.step()
154 |
155 | # update G network
156 | self.G_optimizer.zero_grad()
157 |
158 | G_ = self.G(z_)
159 | D_fake = self.D(G_)
160 | G_loss = self.MSE_loss(D_fake, self.y_real_)  # push fake scores toward the "real" target c=1
161 | self.train_hist['G_loss'].append(G_loss.item())
162 |
163 | G_loss.backward()
164 | self.G_optimizer.step()
165 |
166 | if ((iter + 1) % 100) == 0:
167 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
168 | ((epoch + 1), (iter + 1), len(self.data_loader.dataset) // self.batch_size, D_loss.item(), G_loss.item()))
169 |
170 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
171 | with torch.no_grad():
172 | self.visualize_results((epoch+1))
173 |
174 | self.train_hist['total_time'].append(time.time() - start_time)
175 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
176 | self.epoch, self.train_hist['total_time'][0]))
177 | print("Training finish!... save training results")
178 |
179 | self.save()
180 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
181 | self.epoch)
182 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
183 |
184 | def visualize_results(self, epoch, fix=True):
185 | self.G.eval()
186 |
187 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
188 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)
189 |
190 | tot_num_samples = min(self.sample_num, self.batch_size)
191 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
192 |
193 | if fix:
194 | """ fixed noise """
195 | samples = self.G(self.sample_z_)
196 | else:
197 | """ random noise """
198 | sample_z_ = torch.rand((self.batch_size, self.z_dim))
199 | if self.gpu_mode:
200 | sample_z_ = sample_z_.cuda()
201 |
202 | samples = self.G(sample_z_)
203 |
204 | if self.gpu_mode:
205 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
206 | else:
207 | samples = samples.data.numpy().transpose(0, 2, 3, 1)
208 |
209 | samples = (samples + 1) / 2
210 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
211 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')
212 |
213 | def save(self):
214 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
215 |
216 | if not os.path.exists(save_dir):
217 | os.makedirs(save_dir)
218 |
219 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
220 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))
221 |
222 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
223 | pickle.dump(self.train_hist, f)
224 |
225 | def load(self):
226 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
227 |
228 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
229 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-generative-model-collections
2 | Original : [[Tensorflow version]](https://github.com/hwalsuklee/tensorflow-generative-model-collections)
3 |
4 | Pytorch implementation of various GANs.
5 |
6 | This repository was re-implemented with reference to [tensorflow-generative-model-collections](https://github.com/hwalsuklee/tensorflow-generative-model-collections) by [Hwalsuk Lee](https://github.com/hwalsuklee).
7 |
8 | I tried to keep this implementation as close as possible to [tensorflow-generative-model-collections](https://github.com/hwalsuklee/tensorflow-generative-model-collections), but some models differ slightly.
9 |
10 | This repository includes code for CPU-mode PyTorch, but it is untested; I tested only with GPU-mode PyTorch.
11 |
12 | ## Dataset
13 |
14 | - MNIST
15 | - Fashion-MNIST
16 | - CIFAR10
17 | - SVHN
18 | - STL10
19 | - LSUN-bed
20 | #### I only tested the code on MNIST and Fashion-MNIST.
21 |
22 | ## Generative Adversarial Networks (GANs)
23 | ### Lists (Table is borrowed from [tensorflow-generative-model-collections](https://github.com/hwalsuklee/tensorflow-generative-model-collections))
24 |
25 | *Name* | *Paper Link* | *Value Function*
26 | :---: | :---: | :--- |
27 | **GAN** | [Arxiv](https://arxiv.org/abs/1406.2661) | ![](assets/equations/GAN.png)
28 | **LSGAN** | [Arxiv](https://arxiv.org/abs/1611.04076) | ![](assets/equations/LSGAN.png)
29 | **WGAN** | [Arxiv](https://arxiv.org/abs/1701.07875) | ![](assets/equations/WGAN.png)
30 | **WGAN_GP** | [Arxiv](https://arxiv.org/abs/1704.00028) | ![](assets/equations/WGAN_GP.png)
31 | **DRAGAN** | [Arxiv](https://arxiv.org/abs/1705.07215) | ![](assets/equations/DRAGAN.png)
32 | **CGAN** | [Arxiv](https://arxiv.org/abs/1411.1784) | ![](assets/equations/CGAN.png)
33 | **infoGAN** | [Arxiv](https://arxiv.org/abs/1606.03657) | ![](assets/equations/infoGAN.png)
34 | **ACGAN** | [Arxiv](https://arxiv.org/abs/1610.09585) | ![](assets/equations/ACGAN.png)
35 | **EBGAN** | [Arxiv](https://arxiv.org/abs/1609.03126) | ![](assets/equations/EBGAN.png)
36 | **BEGAN** | [Arxiv](https://arxiv.org/abs/1703.10717) | ![](assets/equations/BEGAN.png)
37 |
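For quick reference (my addition; the table embeds each value function as an image), the standard objectives for the vanilla GAN and the WGAN-GP critic, as given in the linked papers, are:

$$\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$

$$L_D^{\text{WGAN-GP}} = \mathbb{E}_{z \sim p_z}\big[D(G(z))\big] - \mathbb{E}_{x \sim p_{\text{data}}}\big[D(x)\big] + \lambda\,\mathbb{E}_{\hat{x}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]$$

where $\hat{x}$ is sampled uniformly along straight lines between real and generated samples, which is exactly what `WGAN_GP.py` implements with `alpha` and `x_hat` ($\lambda$ = `lambda_` = 10).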
38 | #### Variants of GAN structure (Figures are borrowed from [tensorflow-generative-model-collections](https://github.com/hwalsuklee/tensorflow-generative-model-collections))
39 | ![GAN structure](assets/etc/GAN_structure.png)
40 |
41 | ### Results for mnist
42 | The network architectures of the generator and discriminator are exactly the same as in the [infoGAN paper](https://arxiv.org/abs/1606.03657).
43 | For a fair comparison of the core ideas across all GAN variants, the network architecture is kept identical in every implementation except EBGAN and BEGAN. A small modification is made for EBGAN/BEGAN, since they adopt an auto-encoder structure for the discriminator, but I tried to keep the discriminator capacity comparable.
44 |
45 | The following results can be reproduced with command:
46 | ```
47 | python main.py --dataset mnist --gan_type <TYPE> --epoch 50 --batch_size 64
48 | ```
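For example, `python main.py --dataset mnist --gan_type GAN --epoch 50 --batch_size 64` trains the vanilla GAN; `<TYPE>` must be one of the model names listed in the table above (GAN, LSGAN, WGAN, WGAN_GP, DRAGAN, CGAN, infoGAN, ACGAN, EBGAN, BEGAN).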
49 |
50 | #### Fixed generation
51 | All results are generated from the fixed noise vector.
52 |
53 | *Name* | *Epoch 1* | *Epoch 25* | *Epoch 50* | *GIF*
54 | :---: | :---: | :---: | :---: | :---: |
55 | GAN | ![](assets/mnist_results/GAN_epoch001.png) | ![](assets/mnist_results/GAN_epoch025.png) | ![](assets/mnist_results/GAN_epoch050.png) | ![](assets/mnist_results/GAN_generate_animation.gif)
56 | LSGAN | ![](assets/mnist_results/LSGAN_epoch001.png) | ![](assets/mnist_results/LSGAN_epoch025.png) | ![](assets/mnist_results/LSGAN_epoch050.png) | ![](assets/mnist_results/LSGAN_generate_animation.gif)
57 | WGAN | ![](assets/mnist_results/WGAN_epoch001.png) | ![](assets/mnist_results/WGAN_epoch025.png) | ![](assets/mnist_results/WGAN_epoch050.png) | ![](assets/mnist_results/WGAN_generate_animation.gif)
58 | WGAN_GP | ![](assets/mnist_results/WGAN_GP_epoch001.png) | ![](assets/mnist_results/WGAN_GP_epoch025.png) | ![](assets/mnist_results/WGAN_GP_epoch050.png) | ![](assets/mnist_results/WGAN_GP_generate_animation.gif)
59 | DRAGAN | ![](assets/mnist_results/DRAGAN_epoch001.png) | ![](assets/mnist_results/DRAGAN_epoch025.png) | ![](assets/mnist_results/DRAGAN_epoch050.png) | ![](assets/mnist_results/DRAGAN_generate_animation.gif)
60 | EBGAN | ![](assets/mnist_results/EBGAN_epoch001.png) | ![](assets/mnist_results/EBGAN_epoch025.png) | ![](assets/mnist_results/EBGAN_epoch050.png) | ![](assets/mnist_results/EBGAN_generate_animation.gif)
61 | BEGAN | ![](assets/mnist_results/BEGAN_epoch001.png) | ![](assets/mnist_results/BEGAN_epoch025.png) | ![](assets/mnist_results/BEGAN_epoch050.png) | ![](assets/mnist_results/BEGAN_generate_animation.gif)
62 |
63 | #### Conditional generation
64 | Each row has the same noise vector and each column has the same label condition.
65 |
66 | *Name* | *Epoch 1* | *Epoch 25* | *Epoch 50* | *GIF*
67 | :---: | :---: | :---: | :---: | :---: |
68 | CGAN | ![](assets/mnist_results/CGAN_epoch001.png) | ![](assets/mnist_results/CGAN_epoch025.png) | ![](assets/mnist_results/CGAN_epoch050.png) | ![](assets/mnist_results/CGAN_generate_animation.gif)
69 | ACGAN | ![](assets/mnist_results/ACGAN_epoch001.png) | ![](assets/mnist_results/ACGAN_epoch025.png) | ![](assets/mnist_results/ACGAN_epoch050.png) | ![](assets/mnist_results/ACGAN_generate_animation.gif)
70 | infoGAN | ![](assets/mnist_results/infoGAN_epoch001.png) | ![](assets/mnist_results/infoGAN_epoch025.png) | ![](assets/mnist_results/infoGAN_epoch050.png) | ![](assets/mnist_results/infoGAN_generate_animation.gif)
71 |
72 | #### InfoGAN : Manipulating two continuous codes
73 | All results have the same noise vector and label condition, but different continuous codes.
74 |
75 | *Name* | *Epoch 1* | *Epoch 25* | *Epoch 50* | *GIF*
76 | :---: | :---: | :---: | :---: | :---: |
77 | infoGAN | ![](assets/mnist_results/infoGAN_cont_epoch001.png) | ![](assets/mnist_results/infoGAN_cont_epoch025.png) | ![](assets/mnist_results/infoGAN_cont_epoch050.png) | ![](assets/mnist_results/infoGAN_cont_generate_animation.gif)
78 |
79 | #### Loss plot
80 |
81 | *Name* | *Loss*
82 | :---: | :---: |
83 | GAN | ![](assets/mnist_results/GAN_loss.png)
84 | LSGAN | ![](assets/mnist_results/LSGAN_loss.png)
85 | WGAN | ![](assets/mnist_results/WGAN_loss.png)
86 | WGAN_GP | ![](assets/mnist_results/WGAN_GP_loss.png)
87 | DRAGAN | ![](assets/mnist_results/DRAGAN_loss.png)
88 | EBGAN | ![](assets/mnist_results/EBGAN_loss.png)
89 | BEGAN | ![](assets/mnist_results/BEGAN_loss.png)
90 | CGAN | ![](assets/mnist_results/CGAN_loss.png)
91 | ACGAN | ![](assets/mnist_results/ACGAN_loss.png)
92 | infoGAN | ![](assets/mnist_results/infoGAN_loss.png)
93 |
94 | ### Results for fashion-mnist
95 | The comments on the network architecture for MNIST also apply here.
96 | [Fashion-mnist](https://github.com/zalandoresearch/fashion-mnist) is a recently proposed dataset consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes. (T-shirt/top, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot)
97 |
98 | The following results can be reproduced with command:
99 | ```
100 | python main.py --dataset fashion-mnist --gan_type <TYPE> --epoch 50 --batch_size 64
101 | ```
102 |
103 | #### Fixed generation
104 | All results are generated from the fixed noise vector.
105 |
106 | *Name* | *Epoch 1* | *Epoch 25* | *Epoch 50* | *GIF*
107 | :---: | :---: | :---: | :---: | :---: |
108 | GAN | ![](assets/fashion_mnist_results/GAN_epoch001.png) | ![](assets/fashion_mnist_results/GAN_epoch025.png) | ![](assets/fashion_mnist_results/GAN_epoch050.png) | ![](assets/fashion_mnist_results/GAN_generate_animation.gif)
109 | LSGAN | ![](assets/fashion_mnist_results/LSGAN_epoch001.png) | ![](assets/fashion_mnist_results/LSGAN_epoch025.png) | ![](assets/fashion_mnist_results/LSGAN_epoch050.png) | ![](assets/fashion_mnist_results/LSGAN_generate_animation.gif)
110 | WGAN | ![](assets/fashion_mnist_results/WGAN_epoch001.png) | ![](assets/fashion_mnist_results/WGAN_epoch025.png) | ![](assets/fashion_mnist_results/WGAN_epoch050.png) | ![](assets/fashion_mnist_results/WGAN_generate_animation.gif)
111 | WGAN_GP | ![](assets/fashion_mnist_results/WGAN_GP_epoch001.png) | ![](assets/fashion_mnist_results/WGAN_GP_epoch025.png) | ![](assets/fashion_mnist_results/WGAN_GP_epoch050.png) | ![](assets/fashion_mnist_results/WGAN_GP_generate_animation.gif)
112 | DRAGAN | ![](assets/fashion_mnist_results/DRAGAN_epoch001.png) | ![](assets/fashion_mnist_results/DRAGAN_epoch025.png) | ![](assets/fashion_mnist_results/DRAGAN_epoch050.png) | ![](assets/fashion_mnist_results/DRAGAN_generate_animation.gif)
113 | EBGAN | ![](assets/fashion_mnist_results/EBGAN_epoch001.png) | ![](assets/fashion_mnist_results/EBGAN_epoch025.png) | ![](assets/fashion_mnist_results/EBGAN_epoch050.png) | ![](assets/fashion_mnist_results/EBGAN_generate_animation.gif)
114 | BEGAN | ![](assets/fashion_mnist_results/BEGAN_epoch001.png) | ![](assets/fashion_mnist_results/BEGAN_epoch025.png) | ![](assets/fashion_mnist_results/BEGAN_epoch050.png) | ![](assets/fashion_mnist_results/BEGAN_generate_animation.gif)
115 |
116 | #### Conditional generation
117 | Each row has the same noise vector and each column has the same label condition.
118 |
119 | *Name* | *Epoch 1* | *Epoch 25* | *Epoch 50* | *GIF*
120 | :---: | :---: | :---: | :---: | :---: |
121 | CGAN | ![](assets/fashion_mnist_results/CGAN_epoch001.png) | ![](assets/fashion_mnist_results/CGAN_epoch025.png) | ![](assets/fashion_mnist_results/CGAN_epoch050.png) | ![](assets/fashion_mnist_results/CGAN_generate_animation.gif)
122 | ACGAN | ![](assets/fashion_mnist_results/ACGAN_epoch001.png) | ![](assets/fashion_mnist_results/ACGAN_epoch025.png) | ![](assets/fashion_mnist_results/ACGAN_epoch050.png) | ![](assets/fashion_mnist_results/ACGAN_generate_animation.gif)
123 | infoGAN | ![](assets/fashion_mnist_results/infoGAN_epoch001.png) | ![](assets/fashion_mnist_results/infoGAN_epoch025.png) | ![](assets/fashion_mnist_results/infoGAN_epoch050.png) | ![](assets/fashion_mnist_results/infoGAN_generate_animation.gif)
124 |
125 | - ACGAN tends to fall into mode collapse in [tensorflow-generative-model-collections](https://github.com/hwalsuklee/tensorflow-generative-model-collections), but this PyTorch ACGAN does not.
126 |
127 | #### InfoGAN : Manipulating two continuous codes
128 | All results have the same noise vector and label condition, but different continuous codes.
129 |
130 | *Name* | *Epoch 1* | *Epoch 25* | *Epoch 50* | *GIF*
131 | :---: | :---: | :---: | :---: | :---: |
132 | infoGAN | ![](assets/fashion_mnist_results/infoGAN_cont_epoch001.png) | ![](assets/fashion_mnist_results/infoGAN_cont_epoch025.png) | ![](assets/fashion_mnist_results/infoGAN_cont_epoch050.png) | ![](assets/fashion_mnist_results/infoGAN_cont_generate_animation.gif)
133 |
134 | #### Loss plot
135 |
136 | *Name* | *Loss*
137 | :---: | :---: |
138 | GAN | ![](assets/fashion_mnist_results/GAN_loss.png)
139 | LSGAN | ![](assets/fashion_mnist_results/LSGAN_loss.png)
140 | WGAN | ![](assets/fashion_mnist_results/WGAN_loss.png)
141 | WGAN_GP | ![](assets/fashion_mnist_results/WGAN_GP_loss.png)
142 | DRAGAN | ![](assets/fashion_mnist_results/DRAGAN_loss.png)
143 | EBGAN | ![](assets/fashion_mnist_results/EBGAN_loss.png)
144 | BEGAN | ![](assets/fashion_mnist_results/BEGAN_loss.png)
145 | CGAN | ![](assets/fashion_mnist_results/CGAN_loss.png)
146 | ACGAN | ![](assets/fashion_mnist_results/ACGAN_loss.png)
147 | infoGAN | ![](assets/fashion_mnist_results/infoGAN_loss.png)
148 |
149 | ## Folder structure
150 | The following shows the basic folder structure.
151 | ```
152 | ├── main.py # gateway
153 | ├── data
154 | │ ├── mnist # mnist data (not included in this repo)
155 | │ ├── ...
156 | │ ├── ...
157 | │ └── fashion-mnist # fashion-mnist data (not included in this repo)
158 | │
159 | ├── GAN.py # vanilla GAN
160 | ├── utils.py # utils
161 | ├── dataloader.py # dataloader
162 | ├── models # model files to be saved here
163 | └── results # generation results to be saved here
164 | ```
165 |
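As a usage sketch (my addition, not part of the repo): after training, `save()` writes the generator weights to `models/<dataset>/<gan_type>/<gan_type>_G.pkl` (assuming the default `--save_dir models`), so a trained generator can be reloaded and sampled roughly like this, assuming the model was trained on MNIST with `--input_size 28`:

```python
# Sketch only: reload a trained vanilla-GAN generator and draw samples,
# reusing the generator class and the save paths defined in GAN.py.
import torch
from GAN import generator

G = generator(input_dim=62, output_dim=1, input_size=28)  # z_dim = 62, as in GAN.py
G.load_state_dict(torch.load('models/mnist/GAN/GAN_G.pkl', map_location='cpu'))
G.eval()  # put the BatchNorm layers into inference mode

with torch.no_grad():
    z = torch.rand(25, 62)        # noise drawn as in training: uniform on [0, 1)
    samples = G(z)                # (25, 1, 28, 28), in [-1, 1] because of the final Tanh
    samples = (samples + 1) / 2   # rescale to [0, 1] before saving or plotting
```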
166 | ## Development Environment
167 | * Ubuntu 16.04 LTS
168 | * NVIDIA GTX 1080 ti
169 | * cuda 9.0
170 | * Python 3.5.2
171 | * pytorch 0.4.0
172 | * torchvision 0.2.1
173 | * numpy 1.14.3
174 | * matplotlib 2.2.2
175 | * imageio 2.3.0
176 | * scipy 1.1.0
177 |
178 | ## Acknowledgements
179 | This implementation is based on [tensorflow-generative-model-collections](https://github.com/hwalsuklee/tensorflow-generative-model-collections) and was tested with PyTorch 0.4.0 on Ubuntu 16.04 using a GPU.
180 |
181 |
--------------------------------------------------------------------------------
/WGAN.py:
--------------------------------------------------------------------------------
1 | import utils, torch, time, os, pickle
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from dataloader import dataloader
6 |
7 | class generator(nn.Module):
8 | # Network Architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
9 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
10 | def __init__(self, input_dim=100, output_dim=1, input_size=32):
11 | super(generator, self).__init__()
12 | self.input_dim = input_dim
13 | self.output_dim = output_dim
14 | self.input_size = input_size
15 |
16 | self.fc = nn.Sequential(
17 | nn.Linear(self.input_dim, 1024),
18 | nn.BatchNorm1d(1024),
19 | nn.ReLU(),
20 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
21 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
22 | nn.ReLU(),
23 | )
24 | self.deconv = nn.Sequential(
25 | nn.ConvTranspose2d(128, 64, 4, 2, 1),
26 | nn.BatchNorm2d(64),
27 | nn.ReLU(),
28 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
29 | nn.Tanh(),
30 | )
31 | utils.initialize_weights(self)
32 |
33 | def forward(self, input):
34 | x = self.fc(input)
35 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
36 | x = self.deconv(x)
37 |
38 | return x
39 |
40 | class discriminator(nn.Module):
41 | # Network Architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
42 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
43 | def __init__(self, input_dim=1, output_dim=1, input_size=32):
44 | super(discriminator, self).__init__()
45 | self.input_dim = input_dim
46 | self.output_dim = output_dim
47 | self.input_size = input_size
48 |
49 | self.conv = nn.Sequential(
50 | nn.Conv2d(self.input_dim, 64, 4, 2, 1),
51 | nn.LeakyReLU(0.2),
52 | nn.Conv2d(64, 128, 4, 2, 1),
53 | nn.BatchNorm2d(128),
54 | nn.LeakyReLU(0.2),
55 | )
56 | self.fc = nn.Sequential(
57 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024),
58 | nn.BatchNorm1d(1024),
59 | nn.LeakyReLU(0.2),
60 | nn.Linear(1024, self.output_dim),
61 | # nn.Sigmoid(),  # intentionally disabled: the WGAN critic outputs an unbounded score
62 | )
63 | utils.initialize_weights(self)
64 |
65 | def forward(self, input):
66 | x = self.conv(input)
67 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
68 | x = self.fc(x)
69 |
70 | return x
71 |
72 | class WGAN(object):
73 | def __init__(self, args):
74 | # parameters
75 | self.epoch = args.epoch
76 | self.sample_num = 100
77 | self.batch_size = args.batch_size
78 | self.save_dir = args.save_dir
79 | self.result_dir = args.result_dir
80 | self.dataset = args.dataset
81 | self.log_dir = args.log_dir
82 | self.gpu_mode = args.gpu_mode
83 | self.model_name = args.gan_type
84 | self.input_size = args.input_size
85 | self.z_dim = 62
86 | self.c = 0.01 # clipping value
87 | self.n_critic = 5 # the number of iterations of the critic per generator iteration
88 |
89 | # load dataset
90 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)
91 | data = next(iter(self.data_loader))[0]  # peek at one batch to infer the number of image channels
92 |
93 | # networks init
94 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size)
95 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size)
96 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
97 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))
98 |
99 | if self.gpu_mode:
100 | self.G.cuda()
101 | self.D.cuda()
102 |
103 | print('---------- Networks architecture -------------')
104 | utils.print_network(self.G)
105 | utils.print_network(self.D)
106 | print('-----------------------------------------------')
107 |
108 | # fixed noise
109 | self.sample_z_ = torch.rand((self.batch_size, self.z_dim))
110 | if self.gpu_mode:
111 | self.sample_z_ = self.sample_z_.cuda()
112 |
113 | def train(self):
114 | self.train_hist = {}
115 | self.train_hist['D_loss'] = []
116 | self.train_hist['G_loss'] = []
117 | self.train_hist['per_epoch_time'] = []
118 | self.train_hist['total_time'] = []
119 |
120 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)  # not used by the WGAN losses below; kept for template parity
121 | if self.gpu_mode:
122 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()
123 |
124 | self.D.train()
125 | print('training start!!')
126 | start_time = time.time()
127 | for epoch in range(self.epoch):
128 | self.G.train()
129 | epoch_start_time = time.time()
130 | for iter, (x_, _) in enumerate(self.data_loader):
131 | if iter == len(self.data_loader.dataset) // self.batch_size:  # skip the final incomplete batch
132 | break
133 |
134 | z_ = torch.rand((self.batch_size, self.z_dim))
135 | if self.gpu_mode:
136 | x_, z_ = x_.cuda(), z_.cuda()
137 |
138 | # update D network
139 | self.D_optimizer.zero_grad()
140 |
141 | D_real = self.D(x_)
142 | D_real_loss = -torch.mean(D_real)
143 |
144 | G_ = self.G(z_)
145 | D_fake = self.D(G_)
146 | D_fake_loss = torch.mean(D_fake)
147 |
148 | D_loss = D_real_loss + D_fake_loss  # = -(E[D(real)] - E[D(fake)]), an estimate of the negative Wasserstein distance
149 |
150 | D_loss.backward()
151 | self.D_optimizer.step()
152 |
153 | # clipping D
154 | for p in self.D.parameters():
155 | p.data.clamp_(-self.c, self.c)
156 |
157 | if ((iter+1) % self.n_critic) == 0:
158 | # update G network
159 | self.G_optimizer.zero_grad()
160 |
161 | G_ = self.G(z_)
162 | D_fake = self.D(G_)
163 | G_loss = -torch.mean(D_fake)  # generator maximizes the critic score on fakes
164 | self.train_hist['G_loss'].append(G_loss.item())
165 |
166 | G_loss.backward()
167 | self.G_optimizer.step()
168 |
169 | self.train_hist['D_loss'].append(D_loss.item())
170 |
171 | if ((iter + 1) % 100) == 0:
172 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
173 | ((epoch + 1), (iter + 1), len(self.data_loader.dataset) // self.batch_size, D_loss.item(), G_loss.item()))
174 |
175 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
176 | with torch.no_grad():
177 | self.visualize_results((epoch+1))
178 |
179 | self.train_hist['total_time'].append(time.time() - start_time)
180 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
181 | self.epoch, self.train_hist['total_time'][0]))
182 | print("Training finish!... save training results")
183 |
184 | self.save()
185 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
186 | self.epoch)
187 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
188 |
189 | def visualize_results(self, epoch, fix=True):
190 | self.G.eval()
191 |
192 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
193 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)
194 |
195 | tot_num_samples = min(self.sample_num, self.batch_size)
196 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
197 |
198 | if fix:
199 | """ fixed noise """
200 | samples = self.G(self.sample_z_)
201 | else:
202 | """ random noise """
203 | sample_z_ = torch.rand((self.batch_size, self.z_dim))
204 | if self.gpu_mode:
205 | sample_z_ = sample_z_.cuda()
206 |
207 | samples = self.G(sample_z_)
208 |
209 | if self.gpu_mode:
210 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
211 | else:
212 | samples = samples.data.numpy().transpose(0, 2, 3, 1)
213 |
214 | samples = (samples + 1) / 2
215 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
216 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')
217 |
218 | def save(self):
219 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
220 |
221 | if not os.path.exists(save_dir):
222 | os.makedirs(save_dir)
223 |
224 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
225 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))
226 |
227 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
228 | pickle.dump(self.train_hist, f)
229 |
230 | def load(self):
231 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
232 |
233 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
234 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
--------------------------------------------------------------------------------
/WGAN_GP.py:
--------------------------------------------------------------------------------
1 | import utils, torch, time, os, pickle
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from torch.autograd import grad
6 | from dataloader import dataloader
7 |
8 | class generator(nn.Module):
9 | # Network Architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
10 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
11 | def __init__(self, input_dim=100, output_dim=1, input_size=32):
12 | super(generator, self).__init__()
13 | self.input_dim = input_dim
14 | self.output_dim = output_dim
15 | self.input_size = input_size
16 |
17 | self.fc = nn.Sequential(
18 | nn.Linear(self.input_dim, 1024),
19 | nn.BatchNorm1d(1024),
20 | nn.ReLU(),
21 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
22 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
23 | nn.ReLU(),
24 | )
25 | self.deconv = nn.Sequential(
26 | nn.ConvTranspose2d(128, 64, 4, 2, 1),
27 | nn.BatchNorm2d(64),
28 | nn.ReLU(),
29 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
30 | nn.Tanh(),
31 | )
32 | utils.initialize_weights(self)
33 |
34 | def forward(self, input):
35 | x = self.fc(input)
36 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
37 | x = self.deconv(x)
38 |
39 | return x
40 |
41 | class discriminator(nn.Module):
42 | # Network Architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
43 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
44 | def __init__(self, input_dim=1, output_dim=1, input_size=32):
45 | super(discriminator, self).__init__()
46 | self.input_dim = input_dim
47 | self.output_dim = output_dim
48 | self.input_size = input_size
49 |
50 | self.conv = nn.Sequential(
51 | nn.Conv2d(self.input_dim, 64, 4, 2, 1),
52 | nn.LeakyReLU(0.2),
53 | nn.Conv2d(64, 128, 4, 2, 1),
54 | nn.BatchNorm2d(128),
55 | nn.LeakyReLU(0.2),
56 | )
57 | self.fc = nn.Sequential(
58 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024),
59 | nn.BatchNorm1d(1024),
60 | nn.LeakyReLU(0.2),
61 | nn.Linear(1024, self.output_dim),
62 | # nn.Sigmoid(),  # intentionally disabled: the WGAN-GP critic outputs an unbounded score
63 | )
64 | utils.initialize_weights(self)
65 |
66 | def forward(self, input):
67 | x = self.conv(input)
68 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
69 | x = self.fc(x)
70 |
71 | return x
72 |
73 | class WGAN_GP(object):
74 | def __init__(self, args):
75 | # parameters
76 | self.epoch = args.epoch
77 | self.sample_num = 100
78 | self.batch_size = args.batch_size
79 | self.save_dir = args.save_dir
80 | self.result_dir = args.result_dir
81 | self.dataset = args.dataset
82 | self.log_dir = args.log_dir
83 | self.gpu_mode = args.gpu_mode
84 | self.model_name = args.gan_type
85 | self.input_size = args.input_size
86 | self.z_dim = 62
87 | self.lambda_ = 10  # gradient penalty coefficient (lambda in the WGAN-GP paper)
88 | self.n_critic = 5 # the number of iterations of the critic per generator iteration
89 |
90 | # load dataset
91 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)
92 | data = next(iter(self.data_loader))[0]  # peek at one batch to infer the number of image channels
93 |
94 | # networks init
95 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size)
96 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size)
97 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
98 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))
99 |
100 | if self.gpu_mode:
101 | self.G.cuda()
102 | self.D.cuda()
103 |
104 | print('---------- Networks architecture -------------')
105 | utils.print_network(self.G)
106 | utils.print_network(self.D)
107 | print('-----------------------------------------------')
108 |
109 | # fixed noise
110 | self.sample_z_ = torch.rand((self.batch_size, self.z_dim))
111 | if self.gpu_mode:
112 | self.sample_z_ = self.sample_z_.cuda()
113 |
114 | def train(self):
115 | self.train_hist = {}
116 | self.train_hist['D_loss'] = []
117 | self.train_hist['G_loss'] = []
118 | self.train_hist['per_epoch_time'] = []
119 | self.train_hist['total_time'] = []
120 |
121 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)  # not used by the WGAN-GP losses below; kept for template parity
122 | if self.gpu_mode:
123 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()
124 |
125 | self.D.train()
126 | print('training start!!')
127 | start_time = time.time()
128 | for epoch in range(self.epoch):
129 | self.G.train()
130 | epoch_start_time = time.time()
131 | for iter, (x_, _) in enumerate(self.data_loader):
132 | if iter == len(self.data_loader.dataset) // self.batch_size:  # skip the final incomplete batch
133 | break
134 |
135 | z_ = torch.rand((self.batch_size, self.z_dim))
136 | if self.gpu_mode:
137 | x_, z_ = x_.cuda(), z_.cuda()
138 |
139 | # update D network
140 | self.D_optimizer.zero_grad()
141 |
142 | D_real = self.D(x_)
143 | D_real_loss = -torch.mean(D_real)
144 |
145 | G_ = self.G(z_)
146 | D_fake = self.D(G_)
147 | D_fake_loss = torch.mean(D_fake)
148 |
149 | # gradient penalty
150 | alpha = torch.rand((self.batch_size, 1, 1, 1))
151 | if self.gpu_mode:
152 | alpha = alpha.cuda()
153 |
154 | x_hat = alpha * x_.data + (1 - alpha) * G_.data  # random point on the line between a real and a fake sample
155 | x_hat.requires_grad = True
156 |
157 | pred_hat = self.D(x_hat)
158 | if self.gpu_mode:
159 | gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).cuda(),
160 | create_graph=True, retain_graph=True, only_inputs=True)[0]
161 | else:
162 | gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()),
163 | create_graph=True, retain_graph=True, only_inputs=True)[0]
164 |
165 | gradient_penalty = self.lambda_ * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean()  # penalize the critic's gradient norm for deviating from 1
166 |
167 | D_loss = D_real_loss + D_fake_loss + gradient_penalty
168 |
169 | D_loss.backward()
170 | self.D_optimizer.step()
171 |
172 | if ((iter+1) % self.n_critic) == 0:
173 | # update G network
174 | self.G_optimizer.zero_grad()
175 |
176 | G_ = self.G(z_)
177 | D_fake = self.D(G_)
178 | G_loss = -torch.mean(D_fake)
179 | self.train_hist['G_loss'].append(G_loss.item())
180 |
181 | G_loss.backward()
182 | self.G_optimizer.step()
183 |
184 | self.train_hist['D_loss'].append(D_loss.item())
185 |
186 | if ((iter + 1) % 100) == 0:
187 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
188 | ((epoch + 1), (iter + 1), len(self.data_loader.dataset) // self.batch_size, D_loss.item(), G_loss.item()))
189 |
190 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
191 | with torch.no_grad():
192 | self.visualize_results((epoch+1))
193 |
194 | self.train_hist['total_time'].append(time.time() - start_time)
195 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
196 | self.epoch, self.train_hist['total_time'][0]))
197 | print("Training finish!... save training results")
198 |
199 | self.save()
200 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
201 | self.epoch)
202 | utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
203 |
204 | def visualize_results(self, epoch, fix=True):
205 | self.G.eval()
206 |
207 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
208 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)
209 |
210 | tot_num_samples = min(self.sample_num, self.batch_size)
211 | image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
212 |
213 | if fix:
214 | """ fixed noise """
215 | samples = self.G(self.sample_z_)
216 | else:
217 | """ random noise """
218 | sample_z_ = torch.rand((self.batch_size, self.z_dim))
219 | if self.gpu_mode:
220 | sample_z_ = sample_z_.cuda()
221 |
222 | samples = self.G(sample_z_)
223 |
224 | if self.gpu_mode:
225 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
226 | else:
227 | samples = samples.data.numpy().transpose(0, 2, 3, 1)
228 |
229 | samples = (samples + 1) / 2
230 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
231 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')
232 |
233 | def save(self):
234 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
235 |
236 | if not os.path.exists(save_dir):
237 | os.makedirs(save_dir)
238 |
239 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
240 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))
241 |
242 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
243 | pickle.dump(self.train_hist, f)
244 |
245 | def load(self):
246 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
247 |
248 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
249 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
--------------------------------------------------------------------------------
/assets/celebA_results/BEGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/BEGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/celebA_results/BEGAN_epoch010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/BEGAN_epoch010.png
--------------------------------------------------------------------------------
/assets/celebA_results/BEGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/BEGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/celebA_results/BEGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/BEGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/celebA_results/DRAGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/DRAGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/celebA_results/DRAGAN_epoch010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/DRAGAN_epoch010.png
--------------------------------------------------------------------------------
/assets/celebA_results/DRAGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/DRAGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/celebA_results/DRAGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/DRAGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/celebA_results/EBGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/EBGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/celebA_results/EBGAN_epoch010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/EBGAN_epoch010.png
--------------------------------------------------------------------------------
/assets/celebA_results/EBGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/EBGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/celebA_results/EBGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/EBGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/celebA_results/GAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/GAN_epoch001.png
--------------------------------------------------------------------------------
/assets/celebA_results/GAN_epoch010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/GAN_epoch010.png
--------------------------------------------------------------------------------
/assets/celebA_results/GAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/GAN_epoch025.png
--------------------------------------------------------------------------------
/assets/celebA_results/GAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/GAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/celebA_results/LSGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/LSGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/celebA_results/LSGAN_epoch010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/LSGAN_epoch010.png
--------------------------------------------------------------------------------
/assets/celebA_results/LSGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/LSGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/celebA_results/LSGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/LSGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/celebA_results/WGAN_GP_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_GP_epoch001.png
--------------------------------------------------------------------------------
/assets/celebA_results/WGAN_GP_epoch010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_GP_epoch010.png
--------------------------------------------------------------------------------
/assets/celebA_results/WGAN_GP_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_GP_epoch025.png
--------------------------------------------------------------------------------
/assets/celebA_results/WGAN_GP_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_GP_generate_animation.gif
--------------------------------------------------------------------------------
/assets/celebA_results/WGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/celebA_results/WGAN_epoch010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_epoch010.png
--------------------------------------------------------------------------------
/assets/celebA_results/WGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/celebA_results/WGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/celebA_results/WGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/equations/ACGAN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/ACGAN.png
--------------------------------------------------------------------------------
/assets/equations/BEGAN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/BEGAN.png
--------------------------------------------------------------------------------
/assets/equations/CGAN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/CGAN.png
--------------------------------------------------------------------------------
/assets/equations/DRAGAN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/DRAGAN.png
--------------------------------------------------------------------------------
/assets/equations/EBGAN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/EBGAN.png
--------------------------------------------------------------------------------
/assets/equations/GAN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/GAN.png
--------------------------------------------------------------------------------
/assets/equations/LSGAN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/LSGAN.png
--------------------------------------------------------------------------------
/assets/equations/WGAN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/WGAN.png
--------------------------------------------------------------------------------
/assets/equations/WGAN_GP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/WGAN_GP.png
--------------------------------------------------------------------------------
/assets/equations/infoGAN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/equations/infoGAN.png
--------------------------------------------------------------------------------
/assets/etc/GAN_structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/etc/GAN_structure.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/ACGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/ACGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/ACGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/ACGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/ACGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/ACGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/ACGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/ACGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/ACGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/ACGAN_loss.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/BEGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/BEGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/BEGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/BEGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/BEGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/BEGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/BEGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/BEGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/BEGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/BEGAN_loss.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/CGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/CGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/CGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/CGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/CGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/CGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/CGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/CGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/CGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/CGAN_loss.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/DRAGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/DRAGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/DRAGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/DRAGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/DRAGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/DRAGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/DRAGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/DRAGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/DRAGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/DRAGAN_loss.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/EBGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/EBGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/EBGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/EBGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/EBGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/EBGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/EBGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/EBGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/EBGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/EBGAN_loss.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/GAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/GAN_epoch001.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/GAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/GAN_epoch025.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/GAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/GAN_epoch050.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/GAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/GAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/GAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/GAN_loss.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/LSGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/LSGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/LSGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/LSGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/LSGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/LSGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/LSGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/LSGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/LSGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/LSGAN_loss.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/WGAN_GP_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_GP_epoch001.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/WGAN_GP_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_GP_epoch025.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/WGAN_GP_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_GP_epoch050.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/WGAN_GP_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_GP_generate_animation.gif
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/WGAN_GP_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_GP_loss.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/WGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/WGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/WGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/WGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/WGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/WGAN_loss.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/infoGAN_cont_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_cont_epoch001.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/infoGAN_cont_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_cont_epoch025.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/infoGAN_cont_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_cont_epoch050.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/infoGAN_cont_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_cont_generate_animation.gif
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/infoGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/infoGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/infoGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/infoGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/fashion_mnist_results/infoGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/fashion_mnist_results/infoGAN_loss.png
--------------------------------------------------------------------------------
/assets/mnist_results/ACGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/ACGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/mnist_results/ACGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/ACGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/mnist_results/ACGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/ACGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/mnist_results/ACGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/ACGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/mnist_results/ACGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/ACGAN_loss.png
--------------------------------------------------------------------------------
/assets/mnist_results/BEGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/BEGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/mnist_results/BEGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/BEGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/mnist_results/BEGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/BEGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/mnist_results/BEGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/BEGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/mnist_results/BEGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/BEGAN_loss.png
--------------------------------------------------------------------------------
/assets/mnist_results/CGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/CGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/mnist_results/CGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/CGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/mnist_results/CGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/CGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/mnist_results/CGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/CGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/mnist_results/CGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/CGAN_loss.png
--------------------------------------------------------------------------------
/assets/mnist_results/DRAGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/DRAGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/mnist_results/DRAGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/DRAGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/mnist_results/DRAGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/DRAGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/mnist_results/DRAGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/DRAGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/mnist_results/DRAGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/DRAGAN_loss.png
--------------------------------------------------------------------------------
/assets/mnist_results/EBGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/EBGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/mnist_results/EBGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/EBGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/mnist_results/EBGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/EBGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/mnist_results/EBGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/EBGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/mnist_results/EBGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/EBGAN_loss.png
--------------------------------------------------------------------------------
/assets/mnist_results/GAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/GAN_epoch001.png
--------------------------------------------------------------------------------
/assets/mnist_results/GAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/GAN_epoch025.png
--------------------------------------------------------------------------------
/assets/mnist_results/GAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/GAN_epoch050.png
--------------------------------------------------------------------------------
/assets/mnist_results/GAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/GAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/mnist_results/GAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/GAN_loss.png
--------------------------------------------------------------------------------
/assets/mnist_results/LSGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/LSGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/mnist_results/LSGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/LSGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/mnist_results/LSGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/LSGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/mnist_results/LSGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/LSGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/mnist_results/LSGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/LSGAN_loss.png
--------------------------------------------------------------------------------
/assets/mnist_results/WGAN_GP_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_GP_epoch001.png
--------------------------------------------------------------------------------
/assets/mnist_results/WGAN_GP_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_GP_epoch025.png
--------------------------------------------------------------------------------
/assets/mnist_results/WGAN_GP_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_GP_epoch050.png
--------------------------------------------------------------------------------
/assets/mnist_results/WGAN_GP_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_GP_generate_animation.gif
--------------------------------------------------------------------------------
/assets/mnist_results/WGAN_GP_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_GP_loss.png
--------------------------------------------------------------------------------
/assets/mnist_results/WGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/mnist_results/WGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/mnist_results/WGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/mnist_results/WGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/mnist_results/WGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/WGAN_loss.png
--------------------------------------------------------------------------------
/assets/mnist_results/infoGAN_cont_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_cont_epoch001.png
--------------------------------------------------------------------------------
/assets/mnist_results/infoGAN_cont_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_cont_epoch025.png
--------------------------------------------------------------------------------
/assets/mnist_results/infoGAN_cont_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_cont_epoch050.png
--------------------------------------------------------------------------------
/assets/mnist_results/infoGAN_cont_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_cont_generate_animation.gif
--------------------------------------------------------------------------------
/assets/mnist_results/infoGAN_epoch001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_epoch001.png
--------------------------------------------------------------------------------
/assets/mnist_results/infoGAN_epoch025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_epoch025.png
--------------------------------------------------------------------------------
/assets/mnist_results/infoGAN_epoch050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_epoch050.png
--------------------------------------------------------------------------------
/assets/mnist_results/infoGAN_generate_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_generate_animation.gif
--------------------------------------------------------------------------------
/assets/mnist_results/infoGAN_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/znxlwm/pytorch-generative-model-collections/0d183bb5ea2fbe069e1c6806c4a9a1fd8e81656f/assets/mnist_results/infoGAN_loss.png
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from torchvision import datasets, transforms
3 |
4 | def dataloader(dataset, input_size, batch_size, split='train'):
5 |     # MNIST and Fashion-MNIST are single-channel; the other datasets are RGB.
6 |     # Normalizing a 1-channel tensor with a 3-channel mean/std fails in recent
7 |     # torchvision, so pick the channel count per dataset.
8 |     n_channels = 1 if dataset in ('mnist', 'fashion-mnist') else 3
9 |     transform = transforms.Compose([
10 |         transforms.Resize((input_size, input_size)),
11 |         transforms.ToTensor(),
12 |         transforms.Normalize(mean=[0.5] * n_channels, std=[0.5] * n_channels)])
13 |     if dataset == 'mnist':
14 |         data_loader = DataLoader(
15 |             datasets.MNIST('data/mnist', train=True, download=True, transform=transform),
16 |             batch_size=batch_size, shuffle=True)
17 |     elif dataset == 'fashion-mnist':
18 |         data_loader = DataLoader(
19 |             datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transform),
20 |             batch_size=batch_size, shuffle=True)
21 |     elif dataset == 'cifar10':
22 |         data_loader = DataLoader(
23 |             datasets.CIFAR10('data/cifar10', train=True, download=True, transform=transform),
24 |             batch_size=batch_size, shuffle=True)
25 |     elif dataset == 'cifar100':
26 |         # main.py lists cifar100 as a choice, so handle it here as well.
27 |         data_loader = DataLoader(
28 |             datasets.CIFAR100('data/cifar100', train=True, download=True, transform=transform),
29 |             batch_size=batch_size, shuffle=True)
30 |     elif dataset == 'svhn':
31 |         data_loader = DataLoader(
32 |             datasets.SVHN('data/svhn', split=split, download=True, transform=transform),
33 |             batch_size=batch_size, shuffle=True)
34 |     elif dataset == 'stl10':
35 |         data_loader = DataLoader(
36 |             datasets.STL10('data/stl10', split=split, download=True, transform=transform),
37 |             batch_size=batch_size, shuffle=True)
38 |     elif dataset == 'lsun-bed':
39 |         data_loader = DataLoader(
40 |             datasets.LSUN('data/lsun', classes=['bedroom_train'], transform=transform),
41 |             batch_size=batch_size, shuffle=True)
42 |     else:
43 |         # main.py validates --dataset, but fail loudly for direct callers too.
44 |         raise ValueError('unknown dataset: ' + dataset)
45 |
46 |     return data_loader
--------------------------------------------------------------------------------
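A minimal usage sketch of the loader above (not part of the repository; it assumes the datasets download into ./data and follows torchvision's NCHW batch convention):

    from dataloader import dataloader

    # Grayscale MNIST, resized to 28x28, in batches of 64.
    loader = dataloader('mnist', input_size=28, batch_size=64)
    x, y = next(iter(loader))
    print(x.shape)  # torch.Size([64, 1, 28, 28]); pixels normalized to [-1, 1]
    print(y.shape)  # torch.Size([64]); integer class labels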
/infoGAN.py:
--------------------------------------------------------------------------------
1 | import utils, torch, time, os, pickle, itertools
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import matplotlib.pyplot as plt
7 | from dataloader import dataloader
8 |
9 | class generator(nn.Module):
10 |     # Network Architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
11 | # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
12 | def __init__(self, input_dim=100, output_dim=1, input_size=32, len_discrete_code=10, len_continuous_code=2):
13 | super(generator, self).__init__()
14 | self.input_dim = input_dim
15 | self.output_dim = output_dim
16 | self.input_size = input_size
17 | self.len_discrete_code = len_discrete_code # categorical distribution (i.e. label)
18 | self.len_continuous_code = len_continuous_code # gaussian distribution (e.g. rotation, thickness)
19 |
20 | self.fc = nn.Sequential(
21 | nn.Linear(self.input_dim + self.len_discrete_code + self.len_continuous_code, 1024),
22 | nn.BatchNorm1d(1024),
23 | nn.ReLU(),
24 | nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
25 | nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
26 | nn.ReLU(),
27 | )
28 | self.deconv = nn.Sequential(
29 | nn.ConvTranspose2d(128, 64, 4, 2, 1),
30 | nn.BatchNorm2d(64),
31 | nn.ReLU(),
32 | nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
33 | nn.Tanh(),
34 | )
35 | utils.initialize_weights(self)
36 |
37 | def forward(self, input, cont_code, dist_code):
38 | x = torch.cat([input, cont_code, dist_code], 1)
39 | x = self.fc(x)
40 | x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
41 | x = self.deconv(x)
42 |
43 | return x
44 |
45 | class discriminator(nn.Module):
46 |     # Network Architecture is exactly the same as in infoGAN (https://arxiv.org/abs/1606.03657)
47 | # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
48 | def __init__(self, input_dim=1, output_dim=1, input_size=32, len_discrete_code=10, len_continuous_code=2):
49 | super(discriminator, self).__init__()
50 | self.input_dim = input_dim
51 | self.output_dim = output_dim
52 | self.input_size = input_size
53 | self.len_discrete_code = len_discrete_code # categorical distribution (i.e. label)
54 | self.len_continuous_code = len_continuous_code # gaussian distribution (e.g. rotation, thickness)
55 |
56 | self.conv = nn.Sequential(
57 | nn.Conv2d(self.input_dim, 64, 4, 2, 1),
58 | nn.LeakyReLU(0.2),
59 | nn.Conv2d(64, 128, 4, 2, 1),
60 | nn.BatchNorm2d(128),
61 | nn.LeakyReLU(0.2),
62 | )
63 | self.fc = nn.Sequential(
64 | nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024),
65 | nn.BatchNorm1d(1024),
66 | nn.LeakyReLU(0.2),
67 | nn.Linear(1024, self.output_dim + self.len_continuous_code + self.len_discrete_code),
68 | # nn.Sigmoid(),
69 | )
70 | utils.initialize_weights(self)
71 |
72 | def forward(self, input):
73 | x = self.conv(input)
74 | x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
75 | x = self.fc(x)
76 |         a = torch.sigmoid(x[:, :self.output_dim])  # real/fake head: first output_dim column(s); torch.sigmoid replaces deprecated F.sigmoid
77 | b = x[:, self.output_dim:self.output_dim + self.len_continuous_code]
78 | c = x[:, self.output_dim + self.len_continuous_code:]
79 |
80 | return a, b, c
81 |
82 | class infoGAN(object):
83 | def __init__(self, args, SUPERVISED=True):
84 | # parameters
85 | self.epoch = args.epoch
86 | self.batch_size = args.batch_size
87 | self.save_dir = args.save_dir
88 | self.result_dir = args.result_dir
89 | self.dataset = args.dataset
90 | self.log_dir = args.log_dir
91 | self.gpu_mode = args.gpu_mode
92 | self.model_name = args.gan_type
93 | self.input_size = args.input_size
94 | self.z_dim = 62
95 | self.SUPERVISED = SUPERVISED # if it is true, label info is directly used for code
96 | self.len_discrete_code = 10 # categorical distribution (i.e. label)
97 | self.len_continuous_code = 2 # gaussian distribution (e.g. rotation, thickness)
98 | self.sample_num = self.len_discrete_code ** 2
99 |
100 | # load dataset
101 | self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)
102 |         data = next(iter(self.data_loader))[0]
103 |
104 | # networks init
105 | self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size, len_discrete_code=self.len_discrete_code, len_continuous_code=self.len_continuous_code)
106 | self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size, len_discrete_code=self.len_discrete_code, len_continuous_code=self.len_continuous_code)
107 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
108 | self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))
109 | self.info_optimizer = optim.Adam(itertools.chain(self.G.parameters(), self.D.parameters()), lr=args.lrD, betas=(args.beta1, args.beta2))
110 |
111 | if self.gpu_mode:
112 | self.G.cuda()
113 | self.D.cuda()
114 | self.BCE_loss = nn.BCELoss().cuda()
115 | self.CE_loss = nn.CrossEntropyLoss().cuda()
116 | self.MSE_loss = nn.MSELoss().cuda()
117 | else:
118 | self.BCE_loss = nn.BCELoss()
119 | self.CE_loss = nn.CrossEntropyLoss()
120 | self.MSE_loss = nn.MSELoss()
121 |
122 | print('---------- Networks architecture -------------')
123 | utils.print_network(self.G)
124 | utils.print_network(self.D)
125 | print('-----------------------------------------------')
126 |
127 | # fixed noise & condition
128 | self.sample_z_ = torch.zeros((self.sample_num, self.z_dim))
129 | for i in range(self.len_discrete_code):
130 | self.sample_z_[i * self.len_discrete_code] = torch.rand(1, self.z_dim)
131 | for j in range(1, self.len_discrete_code):
132 | self.sample_z_[i * self.len_discrete_code + j] = self.sample_z_[i * self.len_discrete_code]
133 |
134 | temp = torch.zeros((self.len_discrete_code, 1))
135 | for i in range(self.len_discrete_code):
136 | temp[i, 0] = i
137 |
138 | temp_y = torch.zeros((self.sample_num, 1))
139 | for i in range(self.len_discrete_code):
140 | temp_y[i * self.len_discrete_code: (i + 1) * self.len_discrete_code] = temp
141 |
142 | self.sample_y_ = torch.zeros((self.sample_num, self.len_discrete_code)).scatter_(1, temp_y.type(torch.LongTensor), 1)
143 | self.sample_c_ = torch.zeros((self.sample_num, self.len_continuous_code))
144 |
145 |         # manipulating the two continuous codes
146 | self.sample_z2_ = torch.rand((1, self.z_dim)).expand(self.sample_num, self.z_dim)
147 | self.sample_y2_ = torch.zeros(self.sample_num, self.len_discrete_code)
148 | self.sample_y2_[:, 0] = 1
149 |
150 | temp_c = torch.linspace(-1, 1, 10)
151 | self.sample_c2_ = torch.zeros((self.sample_num, 2))
152 | for i in range(self.len_discrete_code):
153 | for j in range(self.len_discrete_code):
154 | self.sample_c2_[i*self.len_discrete_code+j, 0] = temp_c[i]
155 | self.sample_c2_[i*self.len_discrete_code+j, 1] = temp_c[j]
156 |
157 | if self.gpu_mode:
158 | self.sample_z_, self.sample_y_, self.sample_c_, self.sample_z2_, self.sample_y2_, self.sample_c2_ = \
159 | self.sample_z_.cuda(), self.sample_y_.cuda(), self.sample_c_.cuda(), self.sample_z2_.cuda(), \
160 | self.sample_y2_.cuda(), self.sample_c2_.cuda()
161 |
162 | def train(self):
163 | self.train_hist = {}
164 | self.train_hist['D_loss'] = []
165 | self.train_hist['G_loss'] = []
166 | self.train_hist['info_loss'] = []
167 | self.train_hist['per_epoch_time'] = []
168 | self.train_hist['total_time'] = []
169 |
170 | self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
171 | if self.gpu_mode:
172 | self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()
173 |
174 | self.D.train()
175 | print('training start!!')
176 | start_time = time.time()
177 | for epoch in range(self.epoch):
178 | self.G.train()
179 | epoch_start_time = time.time()
180 | for iter, (x_, y_) in enumerate(self.data_loader):
181 |                 if iter == len(self.data_loader.dataset) // self.batch_size:
182 | break
183 | z_ = torch.rand((self.batch_size, self.z_dim))
184 |                 if self.SUPERVISED:
185 | y_disc_ = torch.zeros((self.batch_size, self.len_discrete_code)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1)
186 | else:
187 | y_disc_ = torch.from_numpy(
188 | np.random.multinomial(1, self.len_discrete_code * [float(1.0 / self.len_discrete_code)],
189 | size=[self.batch_size])).type(torch.FloatTensor)
190 |
191 | y_cont_ = torch.from_numpy(np.random.uniform(-1, 1, size=(self.batch_size, 2))).type(torch.FloatTensor)
192 |
193 | if self.gpu_mode:
194 | x_, z_, y_disc_, y_cont_ = x_.cuda(), z_.cuda(), y_disc_.cuda(), y_cont_.cuda()
195 |
196 | # update D network
197 | self.D_optimizer.zero_grad()
198 |
199 | D_real, _, _ = self.D(x_)
200 | D_real_loss = self.BCE_loss(D_real, self.y_real_)
201 |
202 | G_ = self.G(z_, y_cont_, y_disc_)
203 | D_fake, _, _ = self.D(G_)
204 | D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)
205 |
206 | D_loss = D_real_loss + D_fake_loss
207 | self.train_hist['D_loss'].append(D_loss.item())
208 |
209 | D_loss.backward(retain_graph=True)
210 | self.D_optimizer.step()
211 |
212 | # update G network
213 | self.G_optimizer.zero_grad()
214 |
215 | G_ = self.G(z_, y_cont_, y_disc_)
216 | D_fake, D_cont, D_disc = self.D(G_)
217 |
218 | G_loss = self.BCE_loss(D_fake, self.y_real_)
219 | self.train_hist['G_loss'].append(G_loss.item())
220 |
221 | G_loss.backward(retain_graph=True)
222 | self.G_optimizer.step()
223 |
224 | # information loss
225 | disc_loss = self.CE_loss(D_disc, torch.max(y_disc_, 1)[1])
226 | cont_loss = self.MSE_loss(D_cont, y_cont_)
227 | info_loss = disc_loss + cont_loss
228 | self.train_hist['info_loss'].append(info_loss.item())
229 |
230 | info_loss.backward()
231 | self.info_optimizer.step()
232 |
233 |
234 | if ((iter + 1) % 100) == 0:
235 | print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f, info_loss: %.8f" %
236 |                           ((epoch + 1), (iter + 1), len(self.data_loader.dataset) // self.batch_size, D_loss.item(), G_loss.item(), info_loss.item()))
237 |
238 | self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
239 | with torch.no_grad():
240 | self.visualize_results((epoch+1))
241 |
242 | self.train_hist['total_time'].append(time.time() - start_time)
243 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
244 | self.epoch, self.train_hist['total_time'][0]))
245 | print("Training finish!... save training results")
246 |
247 | self.save()
248 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
249 | self.epoch)
250 | utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_cont',
251 | self.epoch)
252 | self.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
253 |
254 | def visualize_results(self, epoch):
255 | self.G.eval()
256 |
257 | if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
258 | os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)
259 |
260 | image_frame_dim = int(np.floor(np.sqrt(self.sample_num)))
261 |
262 | """ style by class """
263 | samples = self.G(self.sample_z_, self.sample_c_, self.sample_y_)
264 | if self.gpu_mode:
265 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
266 | else:
267 | samples = samples.data.numpy().transpose(0, 2, 3, 1)
268 |
269 | samples = (samples + 1) / 2
270 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
271 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')
272 |
273 | """ manipulating two continous codes """
274 | samples = self.G(self.sample_z2_, self.sample_c2_, self.sample_y2_)
275 | if self.gpu_mode:
276 | samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
277 | else:
278 | samples = samples.data.numpy().transpose(0, 2, 3, 1)
279 |
280 | samples = (samples + 1) / 2
281 | utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
282 | self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_cont_epoch%03d' % epoch + '.png')
283 |
284 | def save(self):
285 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
286 |
287 | if not os.path.exists(save_dir):
288 | os.makedirs(save_dir)
289 |
290 | torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
291 | torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))
292 |
293 | with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
294 | pickle.dump(self.train_hist, f)
295 |
296 | def load(self):
297 | save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
298 |
299 | self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
300 | self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
301 |
302 | def loss_plot(self, hist, path='Train_hist.png', model_name=''):
303 | x = range(len(hist['D_loss']))
304 |
305 | y1 = hist['D_loss']
306 | y2 = hist['G_loss']
307 | y3 = hist['info_loss']
308 |
309 | plt.plot(x, y1, label='D_loss')
310 | plt.plot(x, y2, label='G_loss')
311 | plt.plot(x, y3, label='info_loss')
312 |
313 | plt.xlabel('Iter')
314 | plt.ylabel('Loss')
315 |
316 | plt.legend(loc=4)
317 | plt.grid(True)
318 | plt.tight_layout()
319 |
320 | path = os.path.join(path, model_name + '_loss.png')
321 |
322 | plt.savefig(path)
--------------------------------------------------------------------------------
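A quick shape check for the two networks above (a sketch, not part of the repository; it uses the class defaults in this file: z_dim 62, 10 discrete codes, 2 continuous codes, 32x32 single-channel output):

    import torch
    from infoGAN import generator, discriminator

    G = generator(input_dim=62, output_dim=1, input_size=32)
    D = discriminator(input_dim=1, output_dim=1, input_size=32)

    z = torch.rand(4, 62)                                # noise batch
    c_cont = torch.empty(4, 2).uniform_(-1, 1)           # continuous latent codes
    c_disc = torch.eye(10)[torch.randint(0, 10, (4,))]   # one-hot discrete codes

    fake = G(z, c_cont, c_disc)    # (4, 1, 32, 32), tanh output in [-1, 1]
    d, q_cont, q_disc = D(fake)    # (4, 1) real/fake prob, (4, 2) continuous, (4, 10) discrete logits

The information loss in train() then scores q_disc against the sampled one-hot codes with cross-entropy and q_cont against the sampled continuous codes with MSE, which is the variational lower bound on mutual information from the InfoGAN paper.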
/main.py:
--------------------------------------------------------------------------------
1 | import argparse, os, torch
2 | from GAN import GAN
3 | from CGAN import CGAN
4 | from LSGAN import LSGAN
5 | from DRAGAN import DRAGAN
6 | from ACGAN import ACGAN
7 | from WGAN import WGAN
8 | from WGAN_GP import WGAN_GP
9 | from infoGAN import infoGAN
10 | from EBGAN import EBGAN
11 | from BEGAN import BEGAN
12 |
13 | """parsing and configuration"""
14 | def parse_args():
15 | desc = "Pytorch implementation of GAN collections"
16 | parser = argparse.ArgumentParser(description=desc)
17 |
18 | parser.add_argument('--gan_type', type=str, default='GAN',
19 | choices=['GAN', 'CGAN', 'infoGAN', 'ACGAN', 'EBGAN', 'BEGAN', 'WGAN', 'WGAN_GP', 'DRAGAN', 'LSGAN'],
20 | help='The type of GAN')
21 | parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist', 'fashion-mnist', 'cifar10', 'cifar100', 'svhn', 'stl10', 'lsun-bed'],
22 | help='The name of dataset')
23 | parser.add_argument('--split', type=str, default='', help='The split flag for svhn and stl10')
24 | parser.add_argument('--epoch', type=int, default=50, help='The number of epochs to run')
25 | parser.add_argument('--batch_size', type=int, default=64, help='The size of batch')
26 | parser.add_argument('--input_size', type=int, default=28, help='The size of input image')
27 | parser.add_argument('--save_dir', type=str, default='models',
28 | help='Directory name to save the model')
29 | parser.add_argument('--result_dir', type=str, default='results', help='Directory name to save the generated images')
30 | parser.add_argument('--log_dir', type=str, default='logs', help='Directory name to save training logs')
31 | parser.add_argument('--lrG', type=float, default=0.0002)
32 | parser.add_argument('--lrD', type=float, default=0.0002)
33 | parser.add_argument('--beta1', type=float, default=0.5)
34 | parser.add_argument('--beta2', type=float, default=0.999)
35 |     parser.add_argument('--gpu_mode', type=lambda v: str(v).lower() in ('true', '1'), default=True)  # type=bool would parse '--gpu_mode False' as True
36 |     parser.add_argument('--benchmark_mode', type=lambda v: str(v).lower() in ('true', '1'), default=True)
37 |
38 | return check_args(parser.parse_args())
39 |
40 | """checking arguments"""
41 | def check_args(args):
42 | # --save_dir
43 | if not os.path.exists(args.save_dir):
44 | os.makedirs(args.save_dir)
45 |
46 | # --result_dir
47 | if not os.path.exists(args.result_dir):
48 | os.makedirs(args.result_dir)
49 |
50 |     # --log_dir
51 | if not os.path.exists(args.log_dir):
52 | os.makedirs(args.log_dir)
53 |
54 | # --epoch
55 | try:
56 | assert args.epoch >= 1
57 |     except AssertionError:
58 | print('number of epochs must be larger than or equal to one')
59 |
60 | # --batch_size
61 | try:
62 | assert args.batch_size >= 1
63 |     except AssertionError:
64 | print('batch size must be larger than or equal to one')
65 |
66 | return args
67 |
68 | """main"""
69 | def main():
70 | # parse arguments
71 | args = parse_args()
72 | if args is None:
73 | exit()
74 |
75 | if args.benchmark_mode:
76 | torch.backends.cudnn.benchmark = True
77 |
78 | # declare instance for GAN
79 | if args.gan_type == 'GAN':
80 | gan = GAN(args)
81 | elif args.gan_type == 'CGAN':
82 | gan = CGAN(args)
83 | elif args.gan_type == 'ACGAN':
84 | gan = ACGAN(args)
85 | elif args.gan_type == 'infoGAN':
86 | gan = infoGAN(args, SUPERVISED=False)
87 | elif args.gan_type == 'EBGAN':
88 | gan = EBGAN(args)
89 | elif args.gan_type == 'WGAN':
90 | gan = WGAN(args)
91 | elif args.gan_type == 'WGAN_GP':
92 | gan = WGAN_GP(args)
93 | elif args.gan_type == 'DRAGAN':
94 | gan = DRAGAN(args)
95 | elif args.gan_type == 'LSGAN':
96 | gan = LSGAN(args)
97 | elif args.gan_type == 'BEGAN':
98 | gan = BEGAN(args)
99 | else:
100 | raise Exception("[!] There is no option for " + args.gan_type)
101 |
102 |     # train the selected model
103 | gan.train()
104 | print(" [*] Training finished!")
105 |
106 | # visualize learned generator
107 | gan.visualize_results(args.epoch)
108 | print(" [*] Testing finished!")
109 |
110 | if __name__ == '__main__':
111 | main()
112 |
--------------------------------------------------------------------------------
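The entry point above is normally run from the command line, e.g. "python main.py --gan_type infoGAN --dataset fashion-mnist --epoch 50". A sketch of driving the same pipeline programmatically (not part of the repository; the Namespace fields mirror parse_args, and the values are illustrative, with gpu_mode=False keeping everything on CPU):

    import argparse
    from infoGAN import infoGAN

    args = argparse.Namespace(
        gan_type='infoGAN', dataset='mnist', epoch=50, batch_size=64,
        input_size=28, save_dir='models', result_dir='results', log_dir='logs',
        lrG=0.0002, lrD=0.0002, beta1=0.5, beta2=0.999,
        gpu_mode=False, benchmark_mode=False)

    gan = infoGAN(args, SUPERVISED=False)  # unsupervised: the discrete code is sampled, not taken from labels
    gan.train()                            # also writes per-epoch sample grids and the loss plot
    gan.visualize_results(args.epoch)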
/utils.py:
--------------------------------------------------------------------------------
1 | import os, gzip, torch
2 | import torch.nn as nn
3 | import numpy as np
4 | # (scipy.misc.imsave was removed from SciPy; imageio below handles image I/O)
5 | import imageio
6 | import matplotlib.pyplot as plt
7 | from torchvision import datasets, transforms
8 |
9 | def load_mnist(dataset):
10 | data_dir = os.path.join("./data", dataset)
11 |
12 | def extract_data(filename, num_data, head_size, data_size):
13 | with gzip.open(filename) as bytestream:
14 | bytestream.read(head_size)
15 | buf = bytestream.read(data_size * num_data)
16 |             data = np.frombuffer(buf, dtype=np.uint8).astype(np.float64)  # np.float alias was removed in NumPy 1.24
17 | return data
18 |
19 | data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28)
20 | trX = data.reshape((60000, 28, 28, 1))
21 |
22 | data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1)
23 | trY = data.reshape((60000))
24 |
25 | data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28)
26 | teX = data.reshape((10000, 28, 28, 1))
27 |
28 | data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1)
29 | teY = data.reshape((10000))
30 |
31 | trY = np.asarray(trY).astype(np.int)
32 | teY = np.asarray(teY)
33 |
34 | X = np.concatenate((trX, teX), axis=0)
35 | y = np.concatenate((trY, teY), axis=0).astype(np.int)
36 |
37 | seed = 547
38 | np.random.seed(seed)
39 | np.random.shuffle(X)
40 | np.random.seed(seed)
41 | np.random.shuffle(y)
42 |
43 | y_vec = np.zeros((len(y), 10), dtype=np.float)
44 | for i, label in enumerate(y):
45 | y_vec[i, y[i]] = 1
46 |
47 | X = X.transpose(0, 3, 1, 2) / 255.
48 | # y_vec = y_vec.transpose(0, 3, 1, 2)
49 |
50 | X = torch.from_numpy(X).type(torch.FloatTensor)
51 | y_vec = torch.from_numpy(y_vec).type(torch.FloatTensor)
52 | return X, y_vec
53 |
def load_celebA(dir, transform, batch_size, shuffle):
    # transform = transforms.Compose([
    #     transforms.CenterCrop(160),
    #     transforms.Resize(64),  # transforms.Scale was renamed to Resize
    #     transforms.ToTensor(),
    #     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    # ])

    # data_dir = 'data/celebA'  # this path depends on your computer
    # NOTE: ImageFolder expects the images to sit in at least one subdirectory
    # of `dir` (its class-folder convention), e.g. data/celebA/img/
    dset = datasets.ImageFolder(dir, transform)
    data_loader = torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=shuffle)

    return data_loader

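# Illustrative call (the path is machine-dependent, as noted above):
#   loader = load_celebA('data/celebA', transform, batch_size=64, shuffle=True)
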
def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)

def save_images(images, size, image_path):
    return imsave(images, size, image_path)

def imsave(images, size, path):
    image = np.squeeze(merge(images, size))
    # scipy.misc.imsave was removed in SciPy >= 1.2; write with imageio instead.
    # Pixel values are assumed to lie in [0, 1], so rescale to uint8 explicitly.
    return imageio.imwrite(path, (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8))

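# Illustrative call: tile 25 samples into a 5x5 grid ('samples' is assumed to be
# a numpy array of shape (25, H, W, C) with values in [0, 1]):
#   save_images(samples, [5, 5], 'results/sample_grid.png')
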
def merge(images, size):
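    """Tile a batch of images of shape (N, H, W, C) into a single image laid out
    as a size[0] x size[1] grid (rows x columns), filled row by row; N should
    equal size[0] * size[1]."""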
    h, w = images.shape[1], images.shape[2]
    if images.shape[3] in (3, 4):
        c = images.shape[3]
        img = np.zeros((h * size[0], w * size[1], c))
        for idx, image in enumerate(images):
            i = idx % size[1]   # column
            j = idx // size[1]  # row
            img[j * h:j * h + h, i * w:i * w + w, :] = image
        return img
    elif images.shape[3] == 1:
        img = np.zeros((h * size[0], w * size[1]))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w] = image[:, :, 0]
        return img
    else:
        raise ValueError('in merge(images, size) the images parameter must have dimensions NxHxWx1, NxHxWx3 or NxHxWx4')

def generate_animation(path, num):
    images = []
    for e in range(num):
        img_name = path + '_epoch%03d' % (e + 1) + '.png'
        images.append(imageio.imread(img_name))
    # NOTE: newer imageio releases dropped the `fps` argument for GIFs;
    # pass `duration` (milliseconds per frame) there instead.
    imageio.mimsave(path + '_generate_animation.gif', images, fps=5)

def loss_plot(hist, path='.', model_name=''):
    # `path` is the directory the loss curve is written into
    x = range(len(hist['D_loss']))

    y1 = hist['D_loss']
    y2 = hist['G_loss']

    plt.plot(x, y1, label='D_loss')
    plt.plot(x, y2, label='G_loss')

    plt.xlabel('Iter')
    plt.ylabel('Loss')

    plt.legend(loc='lower right')
    plt.grid(True)
    plt.tight_layout()

    path = os.path.join(path, model_name + '_loss.png')

    plt.savefig(path)

    plt.close()

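# Illustrative call; the keys match what loss_plot reads from `hist`:
#   train_hist = {'D_loss': [...], 'G_loss': [...]}
#   loss_plot(train_hist, path='results', model_name='GAN')
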
def initialize_weights(net):
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            m.weight.data.normal_(0, 0.02)
            if m.bias is not None:  # layers built with bias=False have no bias
                m.bias.data.zero_()
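
# The N(0, 0.02) initialization above follows the DCGAN convention (zero-centered
# normal with standard deviation 0.02). Illustrative usage, with a hypothetical
# module name:
#   net = Generator()
#   initialize_weights(net)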
--------------------------------------------------------------------------------