├── .gitignore
├── README.md
├── gan_train.py
├── install_env.sh
└── requirements.txt

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.sw*

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Introduction to Generative Adversarial Networks
-------------------------------------------------

This repo is based on [John Glover's blog post](http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/) about approximating a 1-dimensional Gaussian distribution with a Generative Adversarial Network (GAN).

He also released [TensorFlow code](https://github.com/AYLIEN/gan-intro) for his experiments.

I reimplemented his experiments in PyTorch, reusing some code from the [PyTorch DCGAN example](https://github.com/pytorch/examples/blob/master/dcgan/main.py).
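
## Usage

A minimal way to run the experiment (the script uses the GPU when one is available and falls back to the CPU otherwise):

```bash
bash install_env.sh
python gan_train.py
```

Intermediate distribution plots are written to a timestamped directory under `/tmp`.
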
--------------------------------------------------------------------------------
/gan_train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 10 14:24:54 2017

@author: ptrblck

This code is based on the blog post by John Glover:
http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/

The PyTorch code was partly taken from:
https://github.com/pytorch/examples/blob/master/dcgan/main.py
"""

import datetime
import os

import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np

import torch
from torch.autograd import Variable


# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 1.0)
        m.bias.data.fill_(0.0)


def update_learning_rate(optimizer, epoch, init_lr, decay_rate, lr_decay_epochs):
    '''Step decay: multiply the learning rate by decay_rate every lr_decay_epochs.'''
    lr = init_lr * (decay_rate ** (epoch // lr_decay_epochs))

    if epoch % lr_decay_epochs == 0:
        print('LR set to {}'.format(lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer


def samples(
    models,
    data,
    sample_range,
    batch_size,
    num_points=10000,
    num_bins=100
):
    '''
    Return a tuple (db, pd, pg), where db is the current decision
    boundary, pd is a histogram of samples from the data distribution,
    and pg is a histogram of generated samples.

    Taken from https://github.com/AYLIEN/gan-intro/blob/master/gan.py
    '''
    netD, netG = models
    use_cuda = torch.cuda.is_available()
    xs = np.linspace(-sample_range, sample_range, num_points)
    bins = np.linspace(-sample_range, sample_range, num_bins)

    # decision boundary
    db = np.zeros((num_points, 1))
    for i in range(num_points // batch_size):
        sample = torch.FloatTensor(
            np.reshape(xs[batch_size * i:batch_size * (i + 1)], (batch_size, 1)))
        if use_cuda:
            sample = sample.cuda()
        db[batch_size * i:batch_size * (i + 1)] = netD(Variable(sample)).cpu().data.numpy()

    # data distribution
    d = data.sample(num_points)
    pd, _ = np.histogram(d, bins=bins, density=True)

    # generated samples
    zs = np.linspace(-sample_range, sample_range, num_points)
    g = np.zeros((num_points, 1))
    for i in range(num_points // batch_size):
        sample = torch.FloatTensor(
            np.reshape(zs[batch_size * i:batch_size * (i + 1)], (batch_size, 1)))
        if use_cuda:
            sample = sample.cuda()
        g[batch_size * i:batch_size * (i + 1)] = netG(Variable(sample)).cpu().data.numpy()

    pg, _ = np.histogram(g, bins=bins, density=True)

    return db, pd, pg


def plot_distributions(samps, sample_range, ax, save_img_name):
    ax.clear()
    db, pd, pg = samps
    db_x = np.linspace(-sample_range, sample_range, len(db))
    p_x = np.linspace(-sample_range, sample_range, len(pd))
    ax.plot(db_x, db, label='decision boundary')
    ax.set_ylim(0, 1)
    ax.plot(p_x, pd, label='real data')
    ax.plot(p_x, pg, label='generated data')
    ax.set_title('1D Generative Adversarial Network')
    ax.set_xlabel('Data values')
    ax.set_ylabel('Probability density')
    ax.legend()
    ax.figure.savefig(save_img_name)


class DataDistribution(object):
    '''Real data distribution: a 1D Gaussian with mean 4 and stddev 0.5.'''
    def __init__(self):
        self.mu = 4
        self.sigma = 0.5

    def sample(self, N):
        samples = np.random.normal(self.mu, self.sigma, N)
        samples.sort()
        return np.reshape(samples, (-1, 1))


class GeneratorDistribution(object):
    '''Generator input noise: stratified samples over [-range, range].'''
    def __init__(self, range):
        self.range = range

    def sample(self, N):
        samples = np.linspace(-self.range, self.range, N) + \
            np.random.random(N) * 0.01
        return samples


class Generator(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Generator, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        h_relu = self.linear1(x).clamp(min=0)  # ReLU
        y_pred = self.linear2(h_relu)
        return y_pred


class Discriminator(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Discriminator, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H * 2)
        self.linear2 = torch.nn.Linear(H * 2, H * 2)
        self.linear3 = torch.nn.Linear(H * 2, H * 2)
        self.linear4 = torch.nn.Linear(H * 2, D_out)

    def forward(self, x):
        h0 = torch.tanh(self.linear1(x))
        h1 = torch.tanh(self.linear2(h0))
        h2 = torch.tanh(self.linear3(h1))
        out = torch.sigmoid(self.linear4(h2))
        return out
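
# Shape note (illustrative): both networks map a batch of scalars of shape
# (N, 1) to an output of shape (N, 1). The generator is a small ReLU MLP that
# warps its input noise; the discriminator is a deeper tanh MLP whose sigmoid
# output is read as the probability that the input came from the real data.
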
def save_animation(anim_frames, anim_path, sample_range):
    f, ax = plt.subplots(figsize=(6, 4))
    f.suptitle('1D Generative Adversarial Network', fontsize=15)
    ax.set_xlabel('Data values')
    ax.set_ylabel('Probability density')
    ax.set_xlim(-6, 6)
    ax.set_ylim(0, 1.4)
    line_db, = ax.plot([], [], label='decision boundary')
    line_pd, = ax.plot([], [], label='real data')
    line_pg, = ax.plot([], [], label='generated data')
    frame_number = ax.text(
        0.02,
        0.95,
        '',
        horizontalalignment='left',
        verticalalignment='top',
        transform=ax.transAxes
    )
    ax.legend()

    db, pd, _ = anim_frames[0]
    db_x = np.linspace(-sample_range, sample_range, len(db))
    p_x = np.linspace(-sample_range, sample_range, len(pd))

    def init():
        line_db.set_data([], [])
        line_pd.set_data([], [])
        line_pg.set_data([], [])
        frame_number.set_text('')
        return (line_db, line_pd, line_pg, frame_number)

    def animate(i):
        frame_number.set_text(
            'Frame: {}/{}'.format(i, len(anim_frames))
        )
        db, pd, pg = anim_frames[i]
        line_db.set_data(db_x, db)
        line_pd.set_data(p_x, pd)
        line_pg.set_data(p_x, pg)
        return (line_db, line_pd, line_pg, frame_number)

    anim = animation.FuncAnimation(
        f,
        animate,
        init_func=init,
        frames=len(anim_frames),
        blit=True
    )
    anim.save(anim_path, fps=30, extra_args=['-vcodec', 'libx264'])


N = 8                     # batch size
D_in = 1                  # input size of D
H = 4                     # number of hidden neurons
D_out = 1
learning_rate = 0.005
epochs = 10000
plot_every_epochs = 1000
current_time = datetime.datetime.now().strftime('%Y%m%d_%H%M_%S')
output_path = '/tmp/{}'.format(current_time)
if not os.path.exists(output_path):
    os.makedirs(output_path)
# anim.save expects a file name, not a directory
anim_path = os.path.join(output_path, 'gan.mp4')


anim_frames = []

use_cuda = torch.cuda.is_available()

data_dist = DataDistribution()
gen_dist = GeneratorDistribution(range=8)

x = torch.FloatTensor(N, 1)
label = torch.FloatTensor(N, 1)  # (N, 1) to match the discriminator output
real_label = 1
fake_label = 0

netD = Discriminator(D_in=D_in, H=H, D_out=D_out)
netG = Generator(D_in=D_in, H=H, D_out=D_out)
netD.apply(weights_init)
netG.apply(weights_init)

criterion = torch.nn.BCELoss()

if use_cuda:
    x, label = x.cuda(), label.cuda()
    netD.cuda()
    netG.cuda()
    criterion.cuda()

optimizerD = torch.optim.SGD(netD.parameters(), lr=learning_rate)
optimizerG = torch.optim.SGD(netG.parameters(), lr=learning_rate)

# Create figure
f, ax = plt.subplots(1)
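
# The loop below alternates two SGD updates per epoch, following the standard
# GAN game: the discriminator is trained to maximize log(D(x)) + log(1 - D(G(z))),
# then the generator is trained on relabeled fake samples to maximize
# log(D(G(z))) (the non-saturating generator loss, which yields stronger
# gradients early in training than minimizing log(1 - D(G(z))) directly).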
for epoch in range(epochs):
    ############################
    # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
    ############################
    # train with real: maximize log(D(x))
    netD.zero_grad()
    real_cpu = torch.FloatTensor(data_dist.sample(N))
    if use_cuda:
        real_cpu = real_cpu.cuda()
    x.copy_(real_cpu)
    label.fill_(real_label)
    xv = Variable(x)
    labelv = Variable(label)

    output = netD(xv)  # D(x)
    errD_real = criterion(output, labelv)
    errD_real.backward()
    D_x = output.data.mean()

    # train with fake: maximize log(1 - D(G(z)))
    z = torch.FloatTensor(gen_dist.sample(N))[..., None]  # (N_sample, N_channel)
    if use_cuda:
        z = z.cuda()
    zv = Variable(z)
    fake = netG(zv)  # G(z)
    labelv = Variable(label.fill_(fake_label))
    output = netD(fake.detach())  # D(G(z)); detach blocks gradients into G
    errD_fake = criterion(output, labelv)
    errD_fake.backward()
    D_G_z1 = output.data.mean()
    errD = errD_real + errD_fake
    optimizerD.step()
    optimizerD = update_learning_rate(optimizer=optimizerD,
                                      epoch=epoch,
                                      init_lr=learning_rate,
                                      decay_rate=0.95,
                                      lr_decay_epochs=150)

    ############################
    # (2) Update G network: maximize log(D(G(z))), i.e. label G(z) as real (1)
    # so that G is pushed to make D predict "real" on generated samples
    ############################
    netG.zero_grad()
    labelv = Variable(label.fill_(real_label))
    output = netD(fake)  # D(G(z))
    errG = criterion(output, labelv)
    errG.backward()
    D_G_z2 = output.data.mean()
    optimizerG.step()
    optimizerG = update_learning_rate(optimizer=optimizerG,
                                      epoch=epoch,
                                      init_lr=learning_rate,
                                      decay_rate=0.95,
                                      lr_decay_epochs=150)

    print('[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
          % (epoch, epochs, errD.data[0], errG.data[0], D_x, D_G_z1, D_G_z2))

    if epoch % plot_every_epochs == 0:
        # plot current decision boundary and distributions
        samps = samples([netD, netG], data_dist, gen_dist.range, N)
        anim_frames.append(samps)
        plot_distributions(samps, gen_dist.range, ax,
                           save_img_name=output_path + '/{:06}.png'.format(epoch))

# save_animation(anim_frames, anim_path, gen_dist.range)

--------------------------------------------------------------------------------
/install_env.sh:
--------------------------------------------------------------------------------
pip install -r requirements.txt
pip install torch torchvision  # or use `conda install pytorch torchvision -c pytorch`

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
matplotlib==1.5.3
numpy==1.11.3
scipy==0.17.0
seaborn==0.7.1
--------------------------------------------------------------------------------