├── .gitignore
├── LICENSE.md
├── README.md
├── datasets
│   ├── __init__.py
│   └── toy.py
├── losses
│   ├── BCELoss.py
│   ├── HingeLoss.py
│   ├── LeakyHingeLoss.py
│   ├── SumLoss.py
│   └── __init__.py
├── main.py
├── models
│   ├── __init__.py
│   ├── dcgan.py
│   ├── mlp.py
│   ├── toy.py
│   └── toy4.py
├── plot_log.py
├── requirements.txt
├── scripts
│   ├── exp1a.toy.all.sh
│   ├── exp1b.toy.diffC.sh
│   ├── exp2.mnist.sh
│   ├── exp3.celeba.sh
│   ├── exp4.lsun.sh
│   └── plot.example.sh
└── utils
    ├── __init__.py
    └── plot.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.sw*
3 | *.bak
4 | logs
5 | logs/*
6 | samples
7 | samples/*
8 | bak
9 | bak/*
10 | 
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Original work Copyright (c) 2017, Martin Arjovsky (NYU), Soumith Chintala (Facebook), Leon Bottou (Facebook)
4 | Modified work Copyright (c) 2017, Jae Hyun Lim (ETRI), Jong Chul Ye (KAIST)
5 | 
6 | All rights reserved.
7 | 
8 | Redistribution and use in source and binary forms, with or without
9 | modification, are permitted provided that the following conditions are met:
10 | 
11 | * Redistributions of source code must retain the above copyright notice, this
12 |   list of conditions and the following disclaimer.
13 | 
14 | * Redistributions in binary form must reproduce the above copyright notice,
15 |   this list of conditions and the following disclaimer in the documentation
16 |   and/or other materials provided with the distribution.
17 | 
18 | * Neither the name of the copyright holder nor the names of its
19 |   contributors may be used to endorse or promote products derived from
20 |   this software without specific prior written permission.
21 | 
22 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
23 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
24 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
25 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
26 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
27 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
28 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
29 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
30 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Geometric GAN
2 | ===============
3 | 
4 | Code accompanying the paper ["Geometric GAN"](https://arxiv.org/abs/1705.02894). \
5 | (This code is modified from https://github.com/martinarjovsky/WassersteinGAN)
6 | 
7 | 
8 | [Prerequisites](#prerequisites) \
9 | [Datasets](#datasets) \
10 | [Reproducing Experiments](#reproducing-experiments) \
11 | [Generated Samples](#generated-samples) \
12 | [Plot Losses](#plot-losses)
13 | 
14 | 
15 | ## Prerequisites
16 | 
17 | - Computer with Linux or OSX
18 | - [PyTorch](http://pytorch.org)
19 | - For training, an NVIDIA GPU is strongly recommended for speed. CPU is supported, but training is very slow.
20 | 
21 | ## Datasets
22 | ### MNIST
23 | 
24 | Make an empty folder at `//`.
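(Optional) MNIST can also be fetched ahead of time with `torchvision` itself; a minimal sketch, assuming the `data/mnist` symlink set up below already points at your MNIST folder:

```
python -c "import torchvision.datasets as dset; dset.MNIST(root='data/mnist', download=True)"
```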
25 | 
26 | Set symbolic link as follows:
27 | ```
28 | mkdir data
29 | ln -s // data/mnist
30 | ```
31 | 
32 | Note: you can leave the folder empty, since `torchvision` will automatically download the MNIST dataset.
33 | 
34 | ### CelebA
35 | 
36 | Download the Align&Cropped Images of the CelebA dataset, i.e. `img_align_celeba.zip`, from https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg into `//`.
37 | 
38 | ```
39 | unzip img_align_celeba.zip
40 | ```
41 | 
42 | Then you have:
43 | 
44 | ```
45 | //
46 | ├── img_align_celeba.zip
47 | └── img_align_celeba
48 | ```
49 | 
50 | Set symbolic link as follows:
51 | ```
52 | mkdir data
53 | ln -s // data/celeba
54 | ```
55 | 
56 | ### LSUN
57 | 
58 | Download the LSUN bedroom dataset using https://github.com/fyu/lsun into `//`.
59 | 
60 | ```
61 | unzip bedroom_train_lmdb.zip
62 | ```
63 | 
64 | Then you have:
65 | 
66 | ```
67 | //
68 | ├── bedroom_train_lmdb.zip
69 | ├── bedroom_train_lmdb
70 | ...
71 | ```
72 | 
73 | Set symbolic link as follows:
74 | ```
75 | mkdir data
76 | ln -s // data/lsun
77 | ```
78 | 
79 | ## Reproducing Experiments
80 | ### Exp1: Mixture of Gaussians
81 | ```
82 | python main.py standard geogan --cuda --dataset toy4 --dataroot '' --lrD 0.001 --lrG 0.001 --nc 2 --nz 4 --ngf 128 --ndf 128 --model_G toy4 --model_D toy4 --batchSize 500 --experiment samples/toy4_geogan_toy4_rmsprop_lr001_c1 --niter 500 --ndisplay 100 --nsave 50
83 | ```
84 | 
85 | or execute the following scripts from the root of this repo:
86 | ```
87 | ./scripts/exp1a.toy.all.sh
88 | ./scripts/exp1b.toy.diffC.sh
89 | ```
90 | 
91 | ### Exp2: MNIST
92 | ```
93 | python main.py standard geogan --cuda --dataset mnist --dataroot data/mnist --imageSize 64 --nc 1 --lrD 0.0002 --lrG 0.0002 --model_G dcgan --model_D dcgan --ndf 128 --ngf 128 --Giters 10 --niter 25 --ndisplay 100 --nsave 5 --experiment samples/mnist_geogan_dcgan128_rmsprop_lr0002_kg10_c1
94 | ```
95 | 
96 | or execute the following script from the root of this repo:
97 | ```
98 | ./scripts/exp2.mnist.sh
99 | ```
100 | 
101 | ### Exp3: CelebA
102 | ```
103 | python main.py standard geogan --cuda --dataset folder --dataroot data/celeba --loadSize 96 --imageSize 64 --lrD 0.0002 --lrG 0.0002 --model_G dcgan --model_D dcgan --ndf 128 --ngf 128 --Giters 10 --niter 50 --ndisplay 500 --nsave 5 --experiment samples/celeba_geogan_dcgan128_rmsprop_lr0002_kg10_c1
104 | ```
105 | 
106 | or execute the following script from the root of this repo:
107 | ```
108 | ./scripts/exp3.celeba.sh
109 | ```
110 | 
111 | ### Exp4: LSUN
112 | ```
113 | python main.py standard geogan --cuda --dataset lsun --dataroot data/lsun --imageSize 64 --lrD 0.0002 --lrG 0.0002 --model_G dcgan --model_D dcgan --ndf 128 --ngf 128 --Giters 10 --niter 5 --nsave 1 --ndisplay 500 --experiment samples/lsun_geogan_dcgan128_rmsprop_lr0002_kg10_c1
114 | ```
115 | 
116 | or execute the following script from the root of this repo:
117 | ```
118 | ./scripts/exp4.lsun.sh
119 | ```
120 | 
121 | 
122 | ## Generated Samples
123 | Generated samples will be in the `samples` folder.
124 | 
125 | 
126 | ## Plot Losses
127 | Logs will be in the `logs` folder (if you use the aforementioned scripts).
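If you invoke `main.py` directly instead of through the scripts, you can capture an equivalent log yourself by redirecting stdout; for example (the log filename is arbitrary):

```
mkdir -p logs
python main.py standard geogan --cuda --dataset mnist --dataroot data/mnist --imageSize 64 --nc 1 --lrD 0.0002 --lrG 0.0002 --model_G dcgan --model_D dcgan --ndf 128 --ngf 128 --Giters 10 --niter 25 --ndisplay 100 --nsave 5 --experiment samples/mnist_geogan_dcgan128_rmsprop_lr0002_kg10_c1 2>&1 | tee logs/mnist_geogan.log
```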
128 | 
129 | Use `plot_log.py`; example usages are in `scripts/plot.example.sh`.
130 | 
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lim0606/pytorch-geometric-gan/eb84feb5cae1d6963c075aa6fb4c0c3a18eeec41/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/toy.py:
--------------------------------------------------------------------------------
1 | '''
2 | This code is borrowed and modified from https://github.com/torch/demos/blob/master/train-a-digit-classifier/dataset-mnist.lua
3 | and from https://github.com/gcr/torch-residual-networks/blob/master/data/mnist-dataset.lua
4 | '''
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import math
9 | from scipy.stats import multivariate_normal
10 | import numpy as np
11 | from torch.autograd import Variable
12 | 
13 | import warnings
14 | warnings.filterwarnings("ignore", category=FutureWarning)
15 | #import numpy as np
16 | import matplotlib
17 | matplotlib.use('Agg')
18 | import matplotlib.cm as cm
19 | import matplotlib.mlab as mlab
20 | import matplotlib.pyplot as plt
21 | 
22 | # data generating function
23 | # exp1: mixture of 4 gaussians
24 | def exp1(num_data=1000):
25 |     if num_data % 4 != 0:
26 |         raise ValueError('num_data should be a multiple of 4. num_data = {}'.format(num_data))
27 | 
28 |     center = 8
29 |     sigma = 1 #math.sqrt(3)
30 | 
31 |     # init data
32 |     d1x = torch.FloatTensor(num_data//4, 1)
33 |     d1y = torch.FloatTensor(num_data//4, 1)
34 |     d1x.normal_(center, sigma * 3)
35 |     d1y.normal_(center, sigma * 1)
36 | 
37 |     d2x = torch.FloatTensor(num_data//4, 1)
38 |     d2y = torch.FloatTensor(num_data//4, 1)
39 |     d2x.normal_(-center, sigma * 1)
40 |     d2y.normal_(center, sigma * 3)
41 | 
42 |     d3x = torch.FloatTensor(num_data//4, 1)
43 |     d3y = torch.FloatTensor(num_data//4, 1)
44 |     d3x.normal_(center, sigma * 3)
45 |     d3y.normal_(-center, sigma * 2)
46 | 
47 |     d4x = torch.FloatTensor(num_data//4, 1)
48 |     d4y = torch.FloatTensor(num_data//4, 1)
49 |     d4x.normal_(-center, sigma * 2)
50 |     d4y.normal_(-center, sigma * 2)
51 | 
52 |     d1 = torch.cat((d1x, d1y), 1)
53 |     d2 = torch.cat((d2x, d2y), 1)
54 |     d3 = torch.cat((d3x, d3y), 1)
55 |     d4 = torch.cat((d4x, d4y), 1)
56 | 
57 |     d = torch.cat((d1, d2, d3, d4), 0)
58 | 
59 |     # label
60 |     label = torch.IntTensor(num_data).zero_()
61 |     for i in range(4):
62 |         label[i*(num_data//4):(i+1)*(num_data//4)] = i
63 | 
64 |     # shuffle
65 |     #shuffle = torch.randperm(d.size()[0])
66 |     #d = torch.index_select(d, 0, shuffle)
67 |     #label = torch.index_select(label, 0, shuffle)
68 | 
69 |     # pdf
70 |     rv1 = multivariate_normal([ center,  center], [[math.pow(sigma * 3, 2), 0.0], [0.0, math.pow(sigma * 1, 2)]])
71 |     rv2 = multivariate_normal([-center,  center], [[math.pow(sigma * 1, 2), 0.0], [0.0, math.pow(sigma * 3, 2)]])
72 |     rv3 = multivariate_normal([ center, -center], [[math.pow(sigma * 3, 2), 0.0], [0.0, math.pow(sigma * 2, 2)]])
73 |     rv4 = multivariate_normal([-center, -center], [[math.pow(sigma * 2, 2), 0.0], [0.0, math.pow(sigma * 2, 2)]])
74 | 
75 |     def pdf(x):
76 |         prob = 0.25 * rv1.pdf(x) + 0.25 * rv2.pdf(x) + 0.25 * rv3.pdf(x) + 0.25 * rv4.pdf(x)
77 |         return prob
78 | 
79 |     def sumloglikelihood(x):
80 |         return np.sum(np.log((pdf(x) + 1e-10)))
81 | 
82 |     return d, label, sumloglikelihood
83 | 
84 | 
85 | # exp2: two spirals
86 | def exp2(num_data=1000):
87 |     '''
88 |     This function is borrowed
from http://stackoverflow.com/questions/16146599/create-artificial-data-in-matlab 89 | ''' 90 | 91 | degrees = 450 #570 92 | start = 90 93 | #noise = 0 #0.2 94 | deg2rad = (2*math.pi)/360 95 | radius = 1.8 96 | start = start * deg2rad; 97 | 98 | N_mixtures = 100 99 | N = 2 * N_mixtures 100 | N1 = N_mixtures #math.floor(N/2) 101 | N2 = N_mixtures #N-N1 102 | if num_data % N_mixtures != 0: 103 | raise ValueError('num_data should be multiple of {} (num_data = {})'.format(2*N_mixtures, num_data)) 104 | 105 | n = (start + 106 | torch.sqrt(torch.linspace(0.075,1,N2).view(N2,1)).mul_(degrees) 107 | ).mul_(deg2rad) 108 | mu1 = torch.cat((torch.mul(-torch.cos(n), n).mul_(radius), 109 | torch.mul(torch.sin(n), n).mul_(radius)), 1) 110 | 111 | n = (start + 112 | torch.sqrt(torch.linspace(0.075,1,N1).view(N1,1)).mul_(degrees) 113 | ).mul_(deg2rad) 114 | mu2 = torch.cat((torch.mul(torch.cos(n), n).mul_(radius), 115 | torch.mul(-torch.sin(n), n).mul_(radius)), 1) 116 | 117 | mu = torch.cat((mu1, mu2), 0) 118 | num_data_per_mixture = num_data / (2*N_mixtures) 119 | sigma = math.sqrt(0.6) 120 | x = torch.zeros(num_data, 2) 121 | for i in range(2*N_mixtures): 122 | xx = x[i*num_data_per_mixture:(i+1)*num_data_per_mixture, :] 123 | xx.copy_(torch.cat( 124 | (torch.FloatTensor(num_data_per_mixture).normal_(mu[i,0], sigma).view(num_data_per_mixture, 1), 125 | torch.FloatTensor(num_data_per_mixture).normal_(mu[i,1], sigma).view(num_data_per_mixture, 1)), 1)) 126 | 127 | # label 128 | label = torch.IntTensor(num_data).zero_() 129 | label[0:num_data/2] = 0 130 | label[num_data/2:] = 1 131 | 132 | # shuffle 133 | #shuffle = torch.randperm(x.size()[0]) 134 | #x = torch.index_select(x, 0, shuffle) 135 | #label = torch.index_select(label, 0, shuffle) 136 | 137 | # pdf 138 | rv_list = [] 139 | for i in range(2 * N_mixtures): 140 | rv = multivariate_normal([mu[i,0], mu[i,1]], [[math.pow(sigma, 2), 0.0], [0.0, math.pow(sigma, 2)]]) 141 | rv_list.append(rv) 142 | 143 | def pdf(x): 144 | prob = 1 / (2*N_mixtures) * rv_list[0].pdf(x) 145 | for i in range(1, 2 * N_mixtures): 146 | prob += (1.0 / float(2*N_mixtures)) * rv_list[i].pdf(x) 147 | return prob 148 | 149 | def sumloglikelihood(x): 150 | return np.sum(np.log((pdf(x) + 1e-10))) 151 | 152 | return x, label, sumloglikelihood 153 | 154 | # exp3: mixture of 2 gaussians with high bias 155 | def exp3(num_data=1000): 156 | if num_data < 2: 157 | raise ValueError('num_data should be larger than 2. 
(num_data = {})'.format(num_data)) 158 | 159 | center = 6.2 160 | sigma = 1 #math.sqrt(3) 161 | 162 | n1 = int(round(num_data * 0.9)) 163 | n2 = num_data - n1 164 | 165 | # init data 166 | d1x = torch.FloatTensor(n1, 1) 167 | d1y = torch.FloatTensor(n1, 1) 168 | d1x.normal_(center, sigma * 5) 169 | d1y.normal_(center, sigma * 5) 170 | 171 | d2x = torch.FloatTensor(n2, 1) 172 | d2y = torch.FloatTensor(n2, 1) 173 | d2x.normal_(-center, sigma * 1) 174 | d2y.normal_(-center, sigma * 1) 175 | 176 | d1 = torch.cat((d1x, d1y), 1) 177 | d2 = torch.cat((d2x, d2y), 1) 178 | 179 | d = torch.cat((d1, d2), 0) 180 | 181 | # label 182 | label = torch.IntTensor(num_data).zero_() 183 | label[0:n1] = 0 184 | label[n1:] = 1 185 | 186 | # shuffle 187 | #shuffle = torch.randperm(d.size()[0]) 188 | #d = torch.index_select(d, 0, shuffle) 189 | #label = torch.index_select(label, 0, shuffle) 190 | 191 | # pdf 192 | rv1 = multivariate_normal([ center, center], [[math.pow(sigma * 5, 2), 0.0], [0.0, math.pow(sigma * 5, 2)]]) 193 | rv2 = multivariate_normal([-center, -center], [[math.pow(sigma * 1, 2), 0.0], [0.0, math.pow(sigma * 1, 2)]]) 194 | 195 | def pdf(x): 196 | prob = (float(n1) / float(num_data)) * rv1.pdf(x) + (float(n2) / float(num_data)) * rv2.pdf(x) 197 | return prob 198 | 199 | def sumloglikelihood(x): 200 | return np.sum(np.log((pdf(x) + 1e-10))) 201 | 202 | return d, label, sumloglikelihood 203 | 204 | # exp4: grid shapes 205 | def exp4(num_data=1000): 206 | 207 | var = 0.1 208 | max_x = 21 209 | max_y = 21 210 | min_x = -max_x 211 | min_y = -max_y 212 | n = 5 213 | 214 | # init 215 | nx, ny = (n, n) 216 | x = np.linspace(min_x, max_x, nx) 217 | y = np.linspace(min_y, max_y, ny) 218 | xv, yv = np.meshgrid(x, y) 219 | N = xv.size 220 | if num_data % N != 0: 221 | raise ValueError('num_data should be multiple of {} (num_data = {})'.format(N, num_data)) 222 | 223 | # data and label 224 | mu = np.concatenate((xv.reshape(N,1), yv.reshape(N,1)), axis=1) 225 | mu = torch.FloatTensor(mu) 226 | num_data_per_mixture = num_data / N 227 | sigma = math.sqrt(var) 228 | x = torch.zeros(num_data, 2) 229 | label = torch.IntTensor(num_data).zero_() 230 | for i in range(N): 231 | xx = x[i*num_data_per_mixture:(i+1)*num_data_per_mixture, :] 232 | xx.copy_(torch.cat( 233 | (torch.FloatTensor(num_data_per_mixture).normal_(mu[i,0], sigma).view(num_data_per_mixture, 1), 234 | torch.FloatTensor(num_data_per_mixture).normal_(mu[i,1], sigma).view(num_data_per_mixture, 1)), 1)) 235 | label[i*num_data_per_mixture:(i+1)*num_data_per_mixture] = i 236 | 237 | # shuffle 238 | #shuffle = torch.randperm(x.size()[0]) 239 | #x = torch.index_select(x, 0, shuffle) 240 | #label = torch.index_select(label, 0, shuffle) 241 | 242 | # pdf 243 | rv_list = [] 244 | for i in range(N): 245 | rv = multivariate_normal([mu[i,0], mu[i,1]], [[math.pow(sigma, 2), 0.0], [0.0, math.pow(sigma, 2)]]) 246 | rv_list.append(rv) 247 | 248 | def pdf(x): 249 | prob = 1 / (N) * rv_list[0].pdf(x) 250 | for i in range(1, N): 251 | prob += (1.0 / float(N)) * rv_list[i].pdf(x) 252 | return prob 253 | 254 | def sumloglikelihood(x): 255 | return np.sum(np.log((pdf(x) + 1e-10))) 256 | 257 | return x, label, sumloglikelihood 258 | 259 | # exp5: mixture of 2 gaussians with high bias 260 | def exp5(num_data=1000): 261 | if num_data < 2: 262 | raise ValueError('num_data should be larger than 2. 
(num_data = {})'.format(num_data)) 263 | 264 | center = -5 265 | sigma_x = 0.5 266 | sigma_y = 7 267 | 268 | n1 = num_data 269 | 270 | # init data 271 | d1x = torch.FloatTensor(n1, 1) 272 | d1y = torch.FloatTensor(n1, 1) 273 | d1x.normal_(center, sigma_x) 274 | d1y.normal_(center, sigma_y) 275 | 276 | d1 = torch.cat((d1x, d1y), 1) 277 | 278 | d = d1 279 | 280 | # label 281 | label = torch.IntTensor(num_data).zero_() 282 | label[:] = 0 283 | 284 | # shuffle 285 | #shuffle = torch.randperm(d.size()[0]) 286 | #d = torch.index_select(d, 0, shuffle) 287 | #label = torch.index_select(label, 0, shuffle) 288 | 289 | # pdf 290 | rv1 = multivariate_normal([ center, center], [[math.pow(sigma_x, 2), 0.0], [0.0, math.pow(sigma_y, 2)]]) 291 | 292 | def pdf(x): 293 | prob = (float(n1) / float(num_data)) * rv1.pdf(x) 294 | return prob 295 | 296 | def sumloglikelihood(x): 297 | return np.sum(np.log((pdf(x) + 1e-10))) 298 | 299 | return d, label, sumloglikelihood 300 | 301 | # exp6: mixture of 2 gaussians with high bias 302 | def exp6(num_data=1000): 303 | if num_data < 2: 304 | raise ValueError('num_data should be larger than 2. (num_data = {})'.format(num_data)) 305 | 306 | center = -5 307 | sigma_x = 7 308 | sigma_y = 7 309 | 310 | n1 = num_data 311 | 312 | # init data 313 | d1x = torch.FloatTensor(n1, 1) 314 | d1y = torch.FloatTensor(n1, 1) 315 | d1x.normal_(center, sigma_x) 316 | d1y.normal_(center, sigma_y) 317 | 318 | d1 = torch.cat((d1x, d1y), 1) 319 | 320 | d = d1 321 | 322 | # label 323 | label = torch.IntTensor(num_data).zero_() 324 | label[:] = 0 325 | 326 | # shuffle 327 | #shuffle = torch.randperm(d.size()[0]) 328 | #d = torch.index_select(d, 0, shuffle) 329 | #label = torch.index_select(label, 0, shuffle) 330 | 331 | # pdf 332 | rv1 = multivariate_normal([ center, center], [[math.pow(sigma_x, 2), 0.0], [0.0, math.pow(sigma_y, 2)]]) 333 | 334 | def pdf(x): 335 | prob = (float(n1) / float(num_data)) * rv1.pdf(x) 336 | return prob 337 | 338 | def sumloglikelihood(x): 339 | return np.sum(np.log((pdf(x) + 1e-10))) 340 | 341 | return d, label, sumloglikelihood 342 | 343 | 344 | def exp(exp_num='toy1', num_data=1000): 345 | if exp_num == 'toy1': 346 | return exp1(num_data) 347 | elif exp_num == 'toy2': 348 | return exp2(num_data) 349 | elif exp_num == 'toy3': 350 | return exp3(num_data) 351 | elif exp_num == 'toy4': 352 | return exp4(num_data) 353 | elif exp_num == 'toy5': 354 | return exp5(num_data) 355 | elif exp_num == 'toy6': 356 | return exp6(num_data) 357 | else: 358 | raise ValueError('unknown experiment {}'.format(exp_num)) 359 | 360 | def save_image_fake(fake_data, filename): 361 | #import warnings 362 | #warnings.filterwarnings("ignore", category=FutureWarning) 363 | #import numpy as np 364 | #import matplotlib 365 | #matplotlib.use('Agg') 366 | #import matplotlib.pyplot as plt 367 | 368 | fig, ax = plt.subplots() 369 | #plt.scatter(real_data[:,0], real_data[:,1], color='blue', label='real') 370 | plt.scatter(fake_data[:,0], fake_data[:,1], color='red', label='fake') 371 | plt.axis('equal') 372 | #plt.legend(loc='upper right', fancybox=True, shadow=True, fontsize=11) 373 | plt.grid(True) 374 | plt.xlim(-25, 25) 375 | plt.ylim(-25, 25) 376 | plt.minorticks_on() 377 | plt.xlabel('x', fontsize=14, color='black') 378 | plt.ylabel('y', fontsize=14, color='black') 379 | #plt.title('Toy dataset') 380 | plt.savefig(filename) 381 | plt.close() 382 | 383 | def save_image_real(real_data, filename): 384 | #import warnings 385 | #warnings.filterwarnings("ignore", category=FutureWarning) 386 | #import 
numpy as np 387 | #import matplotlib 388 | #matplotlib.use('Agg') 389 | #import matplotlib.pyplot as plt 390 | 391 | fig, ax = plt.subplots() 392 | plt.scatter(real_data[:,0], real_data[:,1], color='blue', label='real') 393 | #plt.scatter(fake_data[:,0], fake_data[:,1], color='red', label='fake') 394 | plt.axis('equal') 395 | #plt.legend(loc='upper right', fancybox=True, shadow=True, fontsize=11) 396 | plt.grid(True) 397 | plt.xlim(-25, 25) 398 | plt.ylim(-25, 25) 399 | plt.minorticks_on() 400 | plt.xlabel('x', fontsize=14, color='black') 401 | plt.ylabel('y', fontsize=14, color='black') 402 | #plt.title('Toy dataset') 403 | plt.savefig(filename) 404 | plt.close() 405 | 406 | def save_image(real_data, fake_data, filename): 407 | #import warnings 408 | #warnings.filterwarnings("ignore", category=FutureWarning) 409 | #import numpy as np 410 | #import matplotlib 411 | #matplotlib.use('Agg') 412 | #import matplotlib.pyplot as plt 413 | 414 | fig, ax = plt.subplots() 415 | plt.scatter(real_data[:,0], real_data[:,1], color='blue', label='real') 416 | plt.scatter(fake_data[:,0], fake_data[:,1], color='red', label='fake') 417 | #plt.axis('equal') 418 | plt.legend(loc='upper right', fancybox=True, shadow=True, fontsize=11) 419 | plt.grid(True) 420 | plt.xlim(-25, 25) 421 | plt.ylim(-25, 25) 422 | plt.minorticks_on() 423 | plt.xlabel('x', fontsize=14, color='black') 424 | plt.ylabel('y', fontsize=14, color='black') 425 | plt.title('Toy dataset') 426 | plt.savefig(filename) 427 | plt.close() 428 | 429 | def save_contour(netD, filename, cuda=False): 430 | #import warnings 431 | #warnings.filterwarnings("ignore", category=FutureWarning) 432 | #import numpy as np 433 | #import matplotlib 434 | #matplotlib.use('Agg') 435 | #import matplotlib.cm as cm 436 | #import matplotlib.mlab as mlab 437 | #import matplotlib.pyplot as plt 438 | 439 | matplotlib.rcParams['xtick.direction'] = 'out' 440 | matplotlib.rcParams['ytick.direction'] = 'out' 441 | matplotlib.rcParams['contour.negative_linestyle'] = 'solid' 442 | 443 | # gen grid 444 | delta = 0.1 445 | x = np.arange(-25.0, 25.0, delta) 446 | y = np.arange(-25.0, 25.0, delta) 447 | X, Y = np.meshgrid(x, y) 448 | 449 | # convert numpy array to to torch variable 450 | (h, w) = X.shape 451 | XY = np.concatenate((X.reshape((h*w, 1, 1, 1)), Y.reshape((h*w, 1, 1, 1))), axis=1) 452 | input = torch.Tensor(XY) 453 | input = Variable(input) 454 | if cuda: 455 | input = input.cuda() 456 | 457 | # forward 458 | output = netD(input) 459 | 460 | # convert torch variable to numpy array 461 | Z = output.data.cpu().view(-1).numpy().reshape(h, w) 462 | 463 | # plot and save 464 | plt.figure() 465 | CS1 = plt.contourf(X, Y, Z) 466 | CS2 = plt.contour(X, Y, Z, alpha=.7, colors='k') 467 | plt.clabel(CS2, inline=1, fontsize=10, colors='k') 468 | plt.title('Simplest default with labels') 469 | plt.savefig(filename) 470 | plt.close() 471 | 472 | 473 | ''' 474 | ### test 475 | import numpy as np 476 | import matplotlib 477 | import matplotlib.pyplot as plt 478 | 479 | num_data = 10000 480 | exp_name = 'exp6' 481 | 482 | if exp_name == 'exp1': 483 | data, label, sumloglikelihood = exp1(num_data) 484 | elif exp_name == 'exp2': 485 | data, label, sumloglikelihood = exp2(num_data) 486 | elif exp_name == 'exp3': 487 | data, label, sumloglikelihood = exp3(num_data) 488 | elif exp_name == 'exp4': 489 | data, label, sumloglikelihood = exp4(num_data) 490 | elif exp_name == 'exp5': 491 | data, label, sumloglikelihood = exp5(num_data) 492 | elif exp_name == 'exp6': 493 | data, label, 
sumloglikelihood = exp6(num_data) 494 | else: 495 | raise ValueError('known exp: {}'.format(exp_name)) 496 | data = data.numpy() 497 | label = label.numpy() 498 | colors = ['red','purple','green','blue'] 499 | #print(data) 500 | #print(data.shape) 501 | #print(label) 502 | #print(label.shape) 503 | 504 | fig, ax = plt.subplots() 505 | #plt.scatter(data[:,0], data[:,1], c=label, alpha=0.01, label=exp_name, cmap=matplotlib.colors.ListedColormap(colors)) 506 | plt.scatter(data[:,0], data[:,1], c=label, alpha=0.1, label=exp_name, cmap=matplotlib.colors.ListedColormap(colors)) 507 | plt.axis('equal') 508 | plt.minorticks_on() 509 | plt.grid(True) 510 | plt.xlabel('x', fontsize=14, color='black') 511 | plt.ylabel('y', fontsize=14, color='black') 512 | plt.title('Toy dataset') 513 | plt.savefig('toy.png') 514 | ''' 515 | -------------------------------------------------------------------------------- /losses/BCELoss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import torch 6 | import torch.nn as nn 7 | 8 | class BCELoss(nn.Module): 9 | def __init__(self, sign=1): 10 | super(BCELoss, self).__init__() 11 | self.sign = sign 12 | self.main = nn.BCELoss() 13 | 14 | def forward(self, input, target): 15 | output = self.main(input, target) 16 | output = torch.mul(output, self.sign) 17 | return output 18 | 19 | def cuda(self, device_id=None): 20 | super(BCELoss, self).cuda(device_id) 21 | self.main.cuda() 22 | -------------------------------------------------------------------------------- /losses/HingeLoss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | 9 | 10 | class HingeLoss(nn.Module): 11 | def __init__(self, margin=1.0, size_average=True, sign=1.0): 12 | super(HingeLoss, self).__init__() 13 | self.sign = sign 14 | self.margin = margin 15 | self.size_average = size_average 16 | 17 | def forward(self, input, target): 18 | # 19 | input = input.view(-1) 20 | 21 | # 22 | assert input.dim() == target.dim() 23 | for i in range(input.dim()): 24 | assert input.size(i) == target.size(i) 25 | 26 | # 27 | output = self.margin - torch.mul(target, input) 28 | 29 | # 30 | if 'cuda' in input.data.type(): 31 | mask = torch.cuda.FloatTensor(input.size()).zero_() 32 | else: 33 | mask = torch.FloatTensor(input.size()).zero_() 34 | mask = Variable(mask) 35 | mask[torch.gt(output, 0.0)] = 1.0 36 | 37 | # 38 | output = torch.mul(output, mask) 39 | 40 | # size average 41 | if self.size_average: 42 | output = torch.mul(output, 1.0 / input.nelement()) 43 | 44 | # sum 45 | output = output.sum() 46 | 47 | # apply sign 48 | output = torch.mul(output, self.sign) 49 | return output 50 | -------------------------------------------------------------------------------- /losses/LeakyHingeLoss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | 9 | 10 | class 
LeakyHingeLoss(nn.Module): 11 | def __init__(self, margin=1.0, slope=0.1, size_average=True, sign=1.0): 12 | super(LeakyHingeLoss, self).__init__() 13 | self.sign = sign 14 | self.margin = margin 15 | self.slope = slope 16 | self.size_average = size_average 17 | self.leakyrelu = nn.LeakyReLU(self.slope) 18 | 19 | def forward(self, input, target): 20 | # 21 | input = input.view(-1) 22 | 23 | # 24 | assert input.dim() == target.dim() 25 | for i in range(input.dim()): 26 | assert input.size(i) == target.size(i) 27 | 28 | # 29 | output = self.margin - torch.mul(target, input) 30 | 31 | # 32 | output = self.leakyrelu(output) 33 | 34 | # size average 35 | if self.size_average: 36 | output = torch.mul(output, 1.0 / input.nelement()) 37 | 38 | # sum 39 | output = output.sum() 40 | 41 | # apply sign 42 | output = torch.mul(output, self.sign) 43 | 44 | return output 45 | 46 | def cuda(self, device_id=None): 47 | super(LeakyHingeLoss, self).cuda(device_id) 48 | self.leakyrelu.cuda() 49 | -------------------------------------------------------------------------------- /losses/SumLoss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import torch 6 | import torch.nn as nn 7 | 8 | class SumLoss(nn.Module): 9 | def __init__(self, sign=1.0): 10 | super(SumLoss, self).__init__() 11 | self.sign = sign 12 | 13 | def forward(self, input, target=0): 14 | output = torch.mul(input, self.sign) 15 | if input.dim() == 4: 16 | output = output.view(input.size(0), 17 | input.size(1) * input.size(2) * input.size(3)) 18 | elif input.dim() == 3: 19 | output = output.view(input.size(0), 20 | input.size(1) * input.size(2)) 21 | output = output.mean(0) 22 | return output.view(1) 23 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lim0606/pytorch-geometric-gan/eb84feb5cae1d6963c075aa6fb4c0c3a18eeec41/losses/__init__.py -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim as optim 9 | import torch.utils.data 10 | import torchvision.datasets as dset 11 | import torchvision.transforms as transforms 12 | import torchvision.utils as vutils 13 | from torch.autograd import Variable 14 | import os 15 | import time 16 | from scipy.stats import multivariate_normal 17 | import numpy as np 18 | 19 | import models.dcgan as dcgan 20 | import models.mlp as mlp 21 | import models.toy as toy 22 | import models.toy4 as toy4 23 | import losses.SumLoss as sumloss 24 | import losses.HingeLoss as hingeloss 25 | import losses.LeakyHingeLoss as leakyhingeloss 26 | import losses.BCELoss as bceloss 27 | import utils.plot as plt 28 | 29 | 30 | parent_parser = argparse.ArgumentParser(add_help=False) 31 | parent_parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw | toy1~toy4') 32 | parent_parser.add_argument('--dataroot', required=True, help='path to dataset') 33 | parent_parser.add_argument('--workers', type=int, 
help='number of data loading workers', default=2)
34 | parent_parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
35 | parent_parser.add_argument('--loadSize', type=int, default=64, help='the height / width of the input image (it will be cropped)')
36 | parent_parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
37 | parent_parser.add_argument('--nc', type=int, default=3, help='number of channels in input (image)')
38 | parent_parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
39 | parent_parser.add_argument('--ngf', type=int, default=64)
40 | parent_parser.add_argument('--ndf', type=int, default=64)
41 | parent_parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
42 | parent_parser.add_argument('--nsave', type=int, default=1, help='save model checkpoints every nsave epochs')
43 | parent_parser.add_argument('--lrD', type=float, default=0.00005, help='learning rate for Critic, default=0.00005')
44 | parent_parser.add_argument('--lrG', type=float, default=0.00005, help='learning rate for Generator, default=0.00005')
45 | parent_parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
46 | parent_parser.add_argument('--weight_decay_D', type=float, default=0, help='weight_decay for discriminator. default=0')
47 | parent_parser.add_argument('--weight_decay_G', type=float, default=0, help='weight_decay for generator. default=0')
48 | parent_parser.add_argument('--cuda', action='store_true', help='enables cuda')
49 | parent_parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
50 | parent_parser.add_argument('--netG', default='', help="path to netG (to continue training)")
51 | parent_parser.add_argument('--netD', default='', help="path to netD (to continue training)")
52 | parent_parser.add_argument('--Diters', type=int, default=1, help='number of D iters per loop')
53 | parent_parser.add_argument('--Giters', type=int, default=1, help='number of G iters per loop')
54 | parent_parser.add_argument('--noBN', action='store_true', help='use batchnorm or not (only for DCGAN)')
55 | parent_parser.add_argument('--model_G', default='dcgan', help='model for G: dcgan | mlp | toy')
56 | parent_parser.add_argument('--model_D', default='dcgan', help='model for D: dcgan | mlp | toy')
57 | parent_parser.add_argument('--n_extra_layers', type=int, default=0, help='Number of extra layers on gen and disc')
58 | parent_parser.add_argument('--experiment', default=None, help='Where to store samples and models')
59 | parent_parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
60 | 
61 | # arguments for weight clipping
62 | parent_parser.add_argument('--wclip_lower', type=float, default=-0.01)
63 | parent_parser.add_argument('--wclip_upper', type=float, default=0.01)
64 | wclip_parser = parent_parser.add_mutually_exclusive_group(required=False)
65 | wclip_parser.add_argument('--wclip', dest='wclip', action='store_true', help='flag for wclip. for wgan, it is required.')
66 | wclip_parser.add_argument('--no-wclip', dest='wclip', action='store_false', help='flag for wclip. 
for wgan, it is required.')
67 | parent_parser.set_defaults(wclip=False)
68 | 
69 | # arguments for weight projection
70 | parent_parser.add_argument('--wproj_upper', type=float, default=1.0)
71 | wproj_parser = parent_parser.add_mutually_exclusive_group(required=False)
72 | wproj_parser.add_argument('--wproj', dest='wproj', action='store_true', help='flag for wproj. for wgan, it is required.')
73 | wproj_parser.add_argument('--no-wproj', dest='wproj', action='store_false', help='flag for wproj. for wgan, it is required.')
74 | parent_parser.set_defaults(wproj=False)
75 | 
76 | # display setting
77 | display_parser = parent_parser.add_mutually_exclusive_group(required=False)
78 | display_parser.add_argument('--display', dest='display', action='store_true', help='flag for display. for toy1~toy4, it should be off.')
79 | display_parser.add_argument('--no-display', dest='display', action='store_false', help='flag for display. for toy1~toy4, it should be off.')
80 | parent_parser.set_defaults(display=True)
81 | parent_parser.add_argument('--ndisplay', type=int, default=500, help='display samples every ndisplay generator iterations')
82 | 
83 | # arguments for training criterion
84 | def add_criterion(mode_parser, parent_parser):
85 |     criterion_subparser = mode_parser.add_subparsers(title='criterion method: gan | wgan | meangan | geogan | ebgan',
86 |                                                      dest='criterion')
87 | 
88 |     # wgan
89 |     wgan_parser = criterion_subparser.add_parser('wgan', help='train using WGAN',
90 |                                                  parents=[parent_parser])
91 | 
92 |     # meangan
93 |     meangan_parser = criterion_subparser.add_parser('meangan', help='train using mean matching GAN',
94 |                                                     parents=[parent_parser])
95 | 
96 |     # geogan
97 |     geogan_parser = criterion_subparser.add_parser('geogan', help='train using geoGAN',
98 |                                                    parents=[parent_parser])
99 |     geogan_parser.add_argument('--C', type=float, default=1, help='tuning parameter C in 0.5 * ||w||^2 + C * hinge_loss(x)')
100 |     geogan_parser.add_argument('--margin', type=float, default=1, help='margin m of the hinge loss max(0, m - c * x) used for the generator loss')
101 |     gtrain_parser = geogan_parser.add_mutually_exclusive_group()
102 |     gtrain_parser.add_argument('--theory', action='store_const', dest='gtrain', const='theory',
103 |                                help='For D, real_label = 1, fake_label = -1, and minimize svm primal loss. For G, fake_label = -1, and move perpendicular to hyperplane')
104 |     gtrain_parser.add_argument('--leaky', action='store_const', dest='gtrain', const='leaky',
105 |                                help='For D, real_label = 1, fake_label = -1, and minimize svm primal loss. For G, fake_label = 1, and minimize leaky svm primal loss with flipped labels.')
106 |     geogan_parser.set_defaults(gtrain='theory')
107 | 
108 |     # ebgan
109 |     ebgan_parser = criterion_subparser.add_parser('ebgan', help='train using EBGAN',
110 |                                                   parents=[parent_parser])
111 |     ebgan_parser.add_argument('--margin', type=float, default=1, help='slack margin constant in discriminator loss for fake data.')
112 | 
113 |     # gan
114 |     gan_parser = criterion_subparser.add_parser('gan', help='train using GAN',
115 |                                                 parents=[parent_parser])
116 |     gtrain_parser = gan_parser.add_mutually_exclusive_group()
117 |     gtrain_parser.add_argument('--theory', action='store_const', dest='gtrain', const='theory',
118 |                                help='real_label = 1, fake_label = 0; thus, for D, min_D E_data[-log(D(x))] + E_gen[-log(1-D(G(z)))]. for G, min_G E_gen[log(1-D(G(z)))]')
119 |     gtrain_parser.add_argument('--practice', action='store_const', dest='gtrain', const='practice',
120 |                                help='for D, min_D E_data[-log(D(x))] + E_gen[-log(1-D(G(z)))]. 
for G, min_G E_gen[-log(D(G(z)))]') 121 | gtrain_parser.add_argument('--flip', action='store_const', dest='gtrain', const='flip', 122 | help='real_label = 0, fake_label = 1.') 123 | gan_parser.set_defaults(gtrain='practice') 124 | 125 | # main parser and training mode 126 | main_parser = argparse.ArgumentParser() 127 | mode_subparsers = main_parser.add_subparsers(title='training mode: standard | bigan | ali', 128 | dest='mode') 129 | mode_standard_parser = mode_subparsers.add_parser('standard', help='train as standard implicit modeling') 130 | add_criterion(mode_standard_parser, parent_parser) 131 | #mode_bigan_parser = mode_subparsers.add_parser('bigan', help='train as BiGAN') 132 | #add_criterion(mode_bigan_parser, parent_parser) 133 | #mode_ali_parser = mode_subparsers.add_parser('ali', help='train as ALI') 134 | #add_criterion(mode_ali_parser, parent_parser) 135 | 136 | # parse arguments 137 | opt = main_parser.parse_args() 138 | print(opt) 139 | 140 | # generate cache folder 141 | os.system('mkdir samples') 142 | if opt.experiment is None: 143 | opt.experiment = 'samples/experiment' 144 | os.system('mkdir -p {0}'.format(opt.experiment)) 145 | 146 | # set random seed 147 | opt.manualSeed = random.randint(1, 10000) # fix seed 148 | print("Random Seed: ", opt.manualSeed) 149 | random.seed(opt.manualSeed) 150 | torch.manual_seed(opt.manualSeed) 151 | 152 | # apply cudnn option 153 | cudnn.benchmark = True 154 | 155 | # diagnose cuda option 156 | if torch.cuda.is_available() and not opt.cuda: 157 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 158 | 159 | 160 | # load dataset 161 | if opt.dataset in ['imagenet', 'folder', 'lfw']: 162 | # folder dataset 163 | dataset = dset.ImageFolder(root=opt.dataroot, 164 | transform=transforms.Compose([ 165 | transforms.Scale(opt.loadSize), 166 | transforms.CenterCrop(opt.imageSize), 167 | transforms.ToTensor(), 168 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 169 | ])) 170 | elif opt.dataset == 'lsun': 171 | dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'], 172 | transform=transforms.Compose([ 173 | transforms.Scale(opt.loadSize), 174 | transforms.CenterCrop(opt.imageSize), 175 | transforms.ToTensor(), 176 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 177 | ])) 178 | elif opt.dataset == 'cifar10': 179 | dataset = dset.CIFAR10(root=opt.dataroot, download=True, 180 | transform=transforms.Compose([ 181 | transforms.Scale(opt.imageSize), 182 | transforms.ToTensor(), 183 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 184 | ])) 185 | elif opt.dataset == 'mnist': 186 | dataset = dset.MNIST(root=opt.dataroot, download=True, 187 | transform=transforms.Compose([ 188 | transforms.Scale(opt.imageSize), 189 | transforms.ToTensor(), 190 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 191 | ])) 192 | elif 'toy' in opt.dataset: #opt.dataset in ['toy1', 'toy2', 'toy3', 'toy4', 'toy5', 'toy6']: 193 | if opt.nc != 2: 194 | raise ValueError('nc should be 2 for simulated dataset. 
(opt.nc = {})'.format(opt.nc))
195 |     import datasets.toy as tdset
196 |     num_data = 100000
197 |     data_tensor, target_tensor, x_sumloglikelihood = tdset.exp(opt.dataset, num_data)
198 |     data_tensor = data_tensor.view(num_data, 2, 1, 1).contiguous()
199 |     dataset = torch.utils.data.TensorDataset(data_tensor, target_tensor)
200 | assert dataset
201 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
202 |                                          shuffle=True, num_workers=int(opt.workers))
203 | 
204 | # init model parameters
205 | ngpu = int(opt.ngpu)
206 | nz = int(opt.nz)
207 | ngf = int(opt.ngf)
208 | ndf = int(opt.ndf)
209 | nc = opt.nc
210 | n_extra_layers = int(opt.n_extra_layers)
211 | 
212 | # custom function to project weights into the l2-norm unit ball
213 | def weight_proj_l2norm(param):
214 |     norm = torch.norm(param.data, p=2) + 1e-8
215 |     coeff = min(opt.wproj_upper, 1.0/norm)
216 |     param.data.mul_(coeff)
217 | 
218 | # custom weights initialization called on netG and netD
219 | def weights_init_dcgan(m):
220 |     classname = m.__class__.__name__
221 |     if classname.find('Conv') != -1:
222 |         m.weight.data.normal_(0.0, 0.02)
223 |     elif classname.find('BatchNorm') != -1:
224 |         m.weight.data.normal_(1.0, 0.02)
225 |         m.bias.data.fill_(0)
226 | 
227 | def weights_init_mlp(m):
228 |     classname = m.__class__.__name__
229 |     if classname.find('Linear') != -1:
230 |         m.weight.data.normal_(0.0, 0.01)
231 |         m.bias.data.fill_(0)
232 | 
233 | def weights_init_toy(m):
234 |     classname = m.__class__.__name__
235 |     if classname.find('Linear') != -1:
236 |         m.weight.data.normal_(0.0, 0.01)
237 |         if m.bias is not None:
238 |             m.bias.data.fill_(0)
239 |     elif classname.find('BatchNorm') != -1:
240 |         m.weight.data.normal_(1.0, 0.01)
241 |         m.bias.data.fill_(0)
242 | 
243 | # model initialization: generator
244 | if opt.model_G == 'dcgan':
245 |     if opt.noBN:
246 |         netG = dcgan.DCGAN_G_nobn(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
247 |     else:
248 |         netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
249 |     netG.apply(weights_init_dcgan)
250 | elif opt.model_G == 'mlp':
251 |     netG = mlp.MLP_G(opt.imageSize, nz, nc, ngf, ngpu)
252 |     netG.apply(weights_init_mlp)
253 | elif opt.model_G == 'toy':
254 |     netG = toy.MLP_G(1, nz, 2, ngf, ngpu)
255 |     netG.apply(weights_init_toy)
256 | elif opt.model_G == 'toy4':
257 |     netG = toy4.MLP_G(1, nz, 2, ngf, ngpu)
258 |     netG.apply(weights_init_toy)
259 | else:
260 |     raise ValueError('unknown model: {}'.format(opt.model_G))
261 | 
262 | if opt.netG != '': # load checkpoint if needed
263 |     netG.load_state_dict(torch.load(opt.netG))
264 | print(netG)
265 | 
266 | # model initialization: discriminator
267 | if opt.model_D == 'dcgan':
268 |     netD = dcgan.DCGAN_D(opt.imageSize, nz, nc, ndf, ngpu, n_extra_layers)
269 |     netD.apply(weights_init_dcgan)
270 | elif opt.model_D == 'mlp':
271 |     netD = mlp.MLP_D(opt.imageSize, nz, nc, ndf, ngpu)
272 |     netD.apply(weights_init_mlp)
273 | elif opt.model_D == 'toy':
274 |     netD = toy.MLP_D(1, nz, 2, ndf, ngpu)
275 |     netD.apply(weights_init_toy)
276 | elif opt.model_D == 'toy4':
277 |     netD = toy4.MLP_D(1, nz, 2, ndf, ngpu)
278 |     netD.apply(weights_init_toy)
279 | else:
280 |     raise ValueError('unknown model: {}'.format(opt.model_D))
281 | 
282 | if opt.criterion == 'gan':
283 |     # add sigmoid activation function for gan
284 |     netD.main.add_module('sigmoid',
285 |                          nn.Sigmoid())
286 | 
287 | if opt.netD != '':
288 |     netD.load_state_dict(torch.load(opt.netD))
289 | print(netD)
290 | 
291 | # set type of adversarial training
292 | if opt.criterion == 'gan':
293 |     criterion_R =
nn.BCELoss() 294 | criterion_F = nn.BCELoss() 295 | if opt.gtrain == 'theory' or opt.gtrain == 'flip': 296 | criterion_G = bceloss.BCELoss(-1) 297 | else: #opt.gtrain == 'practice': 298 | criterion_G = nn.BCELoss() 299 | elif opt.criterion == 'wgan' or opt.criterion == 'meangan': 300 | criterion_R = sumloss.SumLoss() 301 | criterion_F = sumloss.SumLoss(-1) 302 | criterion_G = sumloss.SumLoss() 303 | elif opt.criterion == 'geogan': 304 | criterion_R = hingeloss.HingeLoss() 305 | criterion_F = hingeloss.HingeLoss() 306 | if opt.gtrain == 'theory': 307 | criterion_G = sumloss.SumLoss(sign=-1.0) 308 | elif opt.gtrain == 'leaky': 309 | criterion_G = leakyhingeloss.LeakyHingeLoss(margin=opt.margin) 310 | else: 311 | raise NotImplementedError('unknown opt.gtrain: {}'.format(opt.gtrain)) 312 | elif opt.criterion == 'ebgan': 313 | criterion_R = sumloss.SumLoss(sign=1.0) 314 | criterion_F = hingeloss.HingeLoss(margin=opt.margin) 315 | criterion_G = sumloss.SumLoss(sign=1.0) 316 | else: 317 | raise ValueError('unknown criterion: {}'.format(opt.criterion)) 318 | 319 | 320 | # init variables 321 | input = torch.FloatTensor(opt.batchSize, nc, opt.imageSize, opt.imageSize) 322 | noise = torch.FloatTensor(opt.batchSize, nz, 1, 1) 323 | fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1) 324 | label = torch.FloatTensor(opt.batchSize) 325 | if opt.criterion == 'gan' and opt.gtrain == 'theory': 326 | real_label = 1 327 | fake_label = 0 328 | gen_label = fake_label 329 | elif opt.criterion == 'gan' and opt.gtrain == 'flip': 330 | real_label = 0 331 | fake_label = 1 332 | gen_label = fake_label 333 | elif opt.criterion == 'geogan' and opt.gtrain == 'theory': 334 | real_label = 1 335 | fake_label = -1 336 | gen_label = fake_label 337 | elif opt.criterion == 'geogan' and opt.gtrain == 'leaky': 338 | real_label = 1 339 | fake_label = -1 340 | gen_label = real_label 341 | elif opt.criterion == 'ebgan': 342 | real_label = -1 343 | fake_label = 1 344 | gen_label = fake_label 345 | else: # opt.gtrain == 'practice' 346 | real_label = 1 347 | fake_label = 0 348 | gen_label = real_label 349 | 350 | 351 | # init cuda 352 | if opt.cuda: 353 | netD.cuda() 354 | netG.cuda() 355 | criterion_R.cuda() 356 | criterion_F.cuda() 357 | criterion_G.cuda() 358 | input, label = input.cuda(), label.cuda() 359 | noise, fixed_noise = noise.cuda(), fixed_noise.cuda() 360 | 361 | 362 | # convert to autograd variable 363 | input = Variable(input) 364 | label = Variable(label) 365 | noise = Variable(noise) 366 | fixed_noise = Variable(fixed_noise) 367 | 368 | 369 | # setup optimizer 370 | if opt.criterion == 'geogan': 371 | paramsD = [ 372 | {'params': filter(lambda p: p.cls_weight, netD.parameters()), 'weight_decay': 1.0 / (float(opt.batchSize) * float(opt.C)) }, # assign weight decay for geogan to cls layer only 373 | {'params': filter(lambda p: p.cls_bias, netD.parameters()) }, # no weight decay to the bias of cls layer 374 | {'params': filter(lambda p: not p.cls, netD.parameters()), 'weight_decay': opt.weight_decay_D } 375 | ] 376 | else: 377 | paramsD = [ 378 | {'params': filter(lambda p: p.cls, netD.parameters()) }, # no weight decay to the bias of cls layer 379 | {'params': filter(lambda p: not p.cls, netD.parameters()), 'weight_decay': opt.weight_decay_D } 380 | ] 381 | #paramsD = [ 382 | # {'params': netD.parameters(), 'weight_decay': opt.weight_decay_D }, 383 | #] 384 | if opt.adam: 385 | optimizerD = optim.Adam(paramsD, lr=opt.lrD, betas=(opt.beta1, 0.999))#, weight_decay=opt.weight_decay_D) 386 | optimizerG = 
optim.Adam(netG.parameters(), lr=opt.lrG, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay_G) 387 | else: 388 | optimizerD = optim.RMSprop(paramsD, lr=opt.lrD)#, weight_decay=opt.weight_decay_D) 389 | optimizerG = optim.RMSprop(netG.parameters(), lr = opt.lrG, weight_decay=opt.weight_decay_G) 390 | 391 | 392 | # training 393 | gen_iterations = 0 394 | disc_iterations = 0 395 | errM_print = -float('inf') 396 | errM_real_print = -float('inf') 397 | errM_fake_print = -float('inf') 398 | for epoch in range(opt.niter): 399 | data_iter = iter(dataloader) 400 | i = 0 401 | 402 | while i < len(dataloader): 403 | tm_start = time.time() 404 | 405 | ############################ 406 | # (1) Update D network 407 | ############################ 408 | for p in netD.parameters(): # reset requires_grad 409 | p.requires_grad = True # they are set to False below in netG update 410 | for p in netG.parameters(): 411 | p.requires_grad = False # to avoid computation 412 | 413 | # train the discriminator Diters times 414 | if opt.wclip and (gen_iterations < 25 or gen_iterations % 500 == 0): 415 | Diters = 100 416 | else: 417 | Diters = opt.Diters 418 | j = 0 419 | while j < Diters and i < len(dataloader): 420 | j += 1 421 | disc_iterations += 1 422 | 423 | ##### weight clipping 424 | # wclip parameters to a cube 425 | if opt.wclip: 426 | for p in netD.parameters(): 427 | if not p.cls:# or opt.criterion != 'geogan': 428 | p.data.clamp_(opt.wclip_lower, opt.wclip_upper) 429 | 430 | # wclip parameters to a cube for the last linear layer of disc if opt.criterion == 'wgan' 431 | if opt.criterion == 'wgan': 432 | for p in netD.parameters(): 433 | if p.cls: 434 | p.data.clamp_(opt.wclip_lower, opt.wclip_upper) 435 | 436 | ##### weight projection 437 | # weight projection to a cube for parameters 438 | if opt.wproj: 439 | for p in netD.parameters(): 440 | if not p.cls:# or opt.criterion != 'geogan': 441 | weight_proj_l2norm(p) 442 | 443 | # wproj parameters to a cube for the last linear layer of disc if opt.criterion == 'meangan' 444 | if opt.criterion == 'meangan': 445 | for p in netD.parameters(): 446 | if p.cls: 447 | weight_proj_l2norm(p) 448 | 449 | data_tm_start = time.time() 450 | data = data_iter.next() 451 | data_tm_end = time.time() 452 | i += 1 453 | 454 | # train with real 455 | real_cpu, _ = data 456 | netD.zero_grad() 457 | batch_size = real_cpu.size(0) 458 | input.data.resize_(real_cpu.size()).copy_(real_cpu) 459 | label.data.resize_(batch_size).fill_(real_label) 460 | outD_real = netD(input) 461 | errD_real = criterion_R(outD_real, label) 462 | errD_real.backward() 463 | 464 | # train with fake 465 | noise.data.resize_(batch_size, nz, 1, 1) 466 | noise.data.normal_(0, 1) 467 | fake = netG(noise) 468 | label.data.fill_(fake_label) 469 | input.data.copy_(fake.data) 470 | outD_fake = netD(input) 471 | errD_fake = criterion_F(outD_fake, label) 472 | errD_fake.backward() 473 | errD = errD_real + errD_fake 474 | optimizerD.step() 475 | 476 | 477 | ############################ 478 | # (2) Update G network 479 | ############################ 480 | for p in netD.parameters(): 481 | p.requires_grad = False # to avoid computation 482 | for p in netG.parameters(): 483 | p.requires_grad = True # reset requires_grad 484 | 485 | j = 0 486 | while j < opt.Giters: 487 | j += 1 488 | gen_iterations += 1 489 | 490 | netG.zero_grad() 491 | 492 | # in case our last batch was the tail batch of the dataloader, 493 | # make sure we feed a full batch of noise 494 | label.data.resize_(opt.batchSize).fill_(gen_label) 495 | 
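# NOTE: gen_label follows the criterion chosen above: with 'geogan --gtrain leaky',
# G is trained against real_label (+1) under LeakyHingeLoss (flipped labels), while
# with '--gtrain theory' criterion_G is SumLoss(sign=-1), so G simply pushes D's
# score on fakes upward, i.e. moves fakes perpendicular to the decision hyperplane.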
noise.data.resize_(opt.batchSize, nz, 1, 1) 496 | noise.data.normal_(0, 1) 497 | 498 | # forward G 499 | fake = netG(noise) 500 | 501 | # forward D (backward from D) 502 | outG = netD(fake) 503 | errG = criterion_G(outG, label) 504 | errG.backward() 505 | 506 | # update G 507 | optimizerG.step() 508 | 509 | 510 | ############################ 511 | # Display results 512 | ############################ 513 | if opt.display and (gen_iterations % opt.ndisplay == 0): 514 | if 'toy' in opt.dataset: 515 | fake = netG(fixed_noise) 516 | tdset.save_image(real_cpu.view(-1,2).numpy(), 517 | fake.data.cpu().view(-1,2).numpy(), 518 | '{0}/real_fake_samples_{1}.png'.format(opt.experiment, gen_iterations)) 519 | #tdset.save_contour(netD, 520 | # '{0}/disc_contour_{1}.png'.format(opt.experiment, gen_iterations), 521 | # cuda=opt.cuda) 522 | else: 523 | vutils.save_image(real_cpu, '{0}/real_samples.png'.format(opt.experiment), normalize=True) 524 | fake = netG(fixed_noise) 525 | vutils.save_image(fake.data, '{0}/fake_samples_{1}.png'.format(opt.experiment, gen_iterations), normalize=True) 526 | 527 | tm_end = time.time() 528 | 529 | if 'toy' in opt.dataset: 530 | print('Epoch: [%d][%d/%d][%d]\t Time: %.3f DataTime: %.3f Loss_G: %f Loss_D: %f Loss_D_real: %f Loss_D_fake: %f x_real_sll: %f x_fake_sll: %f' 531 | % (epoch, i, len(dataloader), gen_iterations, 532 | tm_end-tm_start, data_tm_end-data_tm_start, 533 | errG.data[0], errD.data[0], errD_real.data[0], errD_fake.data[0], 534 | x_sumloglikelihood(real_cpu.view(-1,2).numpy()), x_sumloglikelihood(fake.data.cpu().view(-1,2).numpy()))) 535 | else: 536 | print('Epoch: [%d][%d/%d][%d]\t Time: %.3f DataTime: %.3f Loss_G: %f Loss_D: %f Loss_D_real: %f Loss_D_fake: %f' 537 | % (epoch, i, len(dataloader), gen_iterations, 538 | tm_end-tm_start, data_tm_end-data_tm_start, 539 | errG.data[0], errD.data[0], errD_real.data[0], errD_fake.data[0])) 540 | 541 | 542 | ############################ 543 | # Detect errors 544 | ############################ 545 | if np.isnan(errG.data[0]) or np.isnan(errD.data[0]) or np.isnan(errD_real.data[0]) or np.isnan(errD_fake.data[0]): 546 | raise ValueError('nan detected.') 547 | if np.isinf(errG.data[0]) or np.isinf(errD.data[0]) or np.isinf(errD_real.data[0]) or np.isinf(errD_fake.data[0]): 548 | raise ValueError('inf detected.') 549 | 550 | 551 | # do checkpointing 552 | if (epoch+1) % opt.nsave == 0: 553 | torch.save(netG.state_dict(), '{0}/netG_epoch_{1}.pth'.format(opt.experiment, epoch)) 554 | torch.save(optimizerG.state_dict(), '{0}/optG_epoch_{1}.pth'.format(opt.experiment, epoch)) 555 | torch.save(netD.state_dict(), '{0}/netD_epoch_{1}.pth'.format(opt.experiment, epoch)) 556 | torch.save(optimizerD.state_dict(), '{0}/optD_epoch_{1}.pth'.format(opt.experiment, epoch)) 557 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lim0606/pytorch-geometric-gan/eb84feb5cae1d6963c075aa6fb4c0c3a18eeec41/models/__init__.py -------------------------------------------------------------------------------- /models/dcgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | 5 | class DCGAN_D(nn.Module): 6 | def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0): 7 | super(DCGAN_D, self).__init__() 8 | self.ngpu = ngpu 9 | assert isize % 16 == 0, "isize has to be a 
multiple of 16" 10 | 11 | main = nn.Sequential() 12 | # input is nc x isize x isize 13 | main.add_module('initial.conv.{0}-{1}'.format(nc, ndf), 14 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False)) 15 | main.add_module('initial.relu.{0}'.format(ndf), 16 | nn.LeakyReLU(0.2, inplace=True)) 17 | csize, cndf = isize / 2, ndf 18 | 19 | # Extra layers 20 | for t in range(n_extra_layers): 21 | main.add_module('extra-layers-{0}.{1}.conv'.format(t, cndf), 22 | nn.Conv2d(cndf, cndf, 3, 1, 1, bias=False)) 23 | main.add_module('extra-layers-{0}.{1}.batchnorm'.format(t, cndf), 24 | nn.BatchNorm2d(cndf)) 25 | main.add_module('extra-layers-{0}.{1}.relu'.format(t, cndf), 26 | nn.LeakyReLU(0.2, inplace=True)) 27 | 28 | while csize > 4: 29 | in_feat = cndf 30 | out_feat = cndf * 2 31 | main.add_module('pyramid.{0}-{1}.conv'.format(in_feat, out_feat), 32 | nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False)) 33 | main.add_module('pyramid.{0}.batchnorm'.format(out_feat), 34 | nn.BatchNorm2d(out_feat)) 35 | main.add_module('pyramid.{0}.relu'.format(out_feat), 36 | nn.LeakyReLU(0.2, inplace=True)) 37 | cndf = cndf * 2 38 | csize = csize / 2 39 | 40 | # state size. K x 4 x 4 41 | cls = nn.Conv2d(cndf, 1, 4, 1, 0, bias=False) 42 | main.add_module('final.{0}-{1}.conv'.format(cndf, 1), 43 | cls) 44 | 45 | self.main = main 46 | 47 | # assign cls flag 48 | for p in self.main.parameters(): 49 | p.cls = False 50 | p.cls_weight = False 51 | p.cls_bias = False 52 | for p in cls.parameters(): 53 | p.cls = True 54 | if p.data.ndimension() == cls.weight.data.ndimension(): 55 | p.cls_weight = True 56 | else: 57 | p.cls_bias = True 58 | 59 | def forward(self, input): 60 | #gpu_ids = None 61 | #if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: 62 | # gpu_ids = range(self.ngpu) 63 | #return nn.parallel.data_parallel(self.main, input, gpu_ids) 64 | return self.main(input) 65 | 66 | class DCGAN_G(nn.Module): 67 | def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0): 68 | super(DCGAN_G, self).__init__() 69 | self.ngpu = ngpu 70 | assert isize % 16 == 0, "isize has to be a multiple of 16" 71 | 72 | cngf, tisize = ngf//2, 4 73 | while tisize != isize: 74 | cngf = cngf * 2 75 | tisize = tisize * 2 76 | 77 | main = nn.Sequential() 78 | # input is Z, going into a convolution 79 | main.add_module('initial.{0}-{1}.convt'.format(nz, cngf), 80 | nn.ConvTranspose2d(nz, cngf, 4, 1, 0, bias=False)) 81 | main.add_module('initial.{0}.batchnorm'.format(cngf), 82 | nn.BatchNorm2d(cngf)) 83 | main.add_module('initial.{0}.relu'.format(cngf), 84 | nn.ReLU(True)) 85 | 86 | csize, cndf = 4, cngf 87 | while csize < isize//2: 88 | main.add_module('pyramid.{0}-{1}.convt'.format(cngf, cngf//2), 89 | nn.ConvTranspose2d(cngf, cngf//2, 4, 2, 1, bias=False)) 90 | main.add_module('pyramid.{0}.batchnorm'.format(cngf//2), 91 | nn.BatchNorm2d(cngf//2)) 92 | main.add_module('pyramid.{0}.relu'.format(cngf//2), 93 | nn.ReLU(True)) 94 | cngf = cngf // 2 95 | csize = csize * 2 96 | 97 | # Extra layers 98 | for t in range(n_extra_layers): 99 | main.add_module('extra-layers-{0}.{1}.conv'.format(t, cngf), 100 | nn.Conv2d(cngf, cngf, 3, 1, 1, bias=False)) 101 | main.add_module('extra-layers-{0}.{1}.batchnorm'.format(t, cngf), 102 | nn.BatchNorm2d(cngf)) 103 | main.add_module('extra-layers-{0}.{1}.relu'.format(t, cngf), 104 | nn.ReLU(True)) 105 | 106 | main.add_module('final.{0}-{1}.convt'.format(cngf, nc), 107 | nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False)) 108 | main.add_module('final.{0}.tanh'.format(nc), 109 | nn.Tanh()) 110 | self.main = main 

    def forward(self, input):
        #gpu_ids = None
        #if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
        #    gpu_ids = range(self.ngpu)
        #return nn.parallel.data_parallel(self.main, input, gpu_ids)
        return self.main(input)


###############################################################################
class DCGAN_D_nobn(nn.Module):
    def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0):
        super(DCGAN_D_nobn, self).__init__()
        self.ngpu = ngpu
        assert isize % 16 == 0, "isize has to be a multiple of 16"

        main = nn.Sequential()
        # input is nc x isize x isize
        main.add_module('initial.conv.{0}-{1}'.format(nc, ndf),
                        nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))
        main.add_module('initial.relu.{0}'.format(ndf),
                        nn.LeakyReLU(0.2, inplace=True))
        csize, cndf = isize // 2, ndf

        # Extra layers
        for t in range(n_extra_layers):
            main.add_module('extra-layers-{0}.{1}.conv'.format(t, cndf),
                            nn.Conv2d(cndf, cndf, 3, 1, 1, bias=False))
            main.add_module('extra-layers-{0}.{1}.relu'.format(t, cndf),
                            nn.LeakyReLU(0.2, inplace=True))

        while csize > 4:
            in_feat = cndf
            out_feat = cndf * 2
            main.add_module('pyramid.{0}-{1}.conv'.format(in_feat, out_feat),
                            nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False))
            main.add_module('pyramid.{0}.relu'.format(out_feat),
                            nn.LeakyReLU(0.2, inplace=True))
            cndf = cndf * 2
            csize = csize // 2

        # state size. cndf x 4 x 4; the final conv acts as a linear classifier
        cls = nn.Conv2d(cndf, 1, 4, 1, 0, bias=False)
        main.add_module('final.{0}-{1}.conv'.format(cndf, 1),
                        cls)

        self.main = main

        # assign cls flags (same convention as DCGAN_D above)
        for p in self.main.parameters():
            p.cls = False
            p.cls_weight = False
            p.cls_bias = False
        for p in cls.parameters():
            p.cls = True
            if p.data.ndimension() == cls.weight.data.ndimension():
                p.cls_weight = True
            else:
                p.cls_bias = True

    def forward(self, input):
        #gpu_ids = None
        #if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
        #    gpu_ids = range(self.ngpu)
        ##output = nn.parallel.data_parallel(self.main, input, gpu_ids)
        ##output = output.mean(0)
        ##return output.view(1)
        #return nn.parallel.data_parallel(self.main, input, gpu_ids)
        return self.main(input)

class DCGAN_G_nobn(nn.Module):
    def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0):
        super(DCGAN_G_nobn, self).__init__()
        self.ngpu = ngpu
        assert isize % 16 == 0, "isize has to be a multiple of 16"

        cngf, tisize = ngf//2, 4
        while tisize != isize:
            cngf = cngf * 2
            tisize = tisize * 2

        main = nn.Sequential()
        main.add_module('initial.{0}-{1}.convt'.format(nz, cngf),
                        nn.ConvTranspose2d(nz, cngf, 4, 1, 0, bias=False))
        main.add_module('initial.{0}.relu'.format(cngf),
                        nn.ReLU(True))

        csize, cndf = 4, cngf
        while csize < isize//2:
            main.add_module('pyramid.{0}-{1}.convt'.format(cngf, cngf//2),
                            nn.ConvTranspose2d(cngf, cngf//2, 4, 2, 1, bias=False))
            main.add_module('pyramid.{0}.relu'.format(cngf//2),
                            nn.ReLU(True))
            cngf = cngf // 2
            csize = csize * 2

        # Extra layers
        for t in range(n_extra_layers):
            main.add_module('extra-layers-{0}.{1}.conv'.format(t, cngf),
                            nn.Conv2d(cngf,
                                      cngf, 3, 1, 1, bias=False))
            main.add_module('extra-layers-{0}.{1}.relu'.format(t, cngf),
                            nn.ReLU(True))

        main.add_module('final.{0}-{1}.convt'.format(cngf, nc),
                        nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False))
        main.add_module('final.{0}.tanh'.format(nc),
                        nn.Tanh())
        self.main = main

    def forward(self, input):
        #gpu_ids = None
        #if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
        #    gpu_ids = range(self.ngpu)
        #return nn.parallel.data_parallel(self.main, input, gpu_ids)
        return self.main(input)


--------------------------------------------------------------------------------
/models/mlp.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
import torch.nn as nn

class MLP_G(nn.Module):
    def __init__(self, isize, nz, nc, ngf, ngpu):
        super(MLP_G, self).__init__()
        self.ngpu = ngpu

        main = nn.Sequential(
            # Z goes into a linear of size: ngf
            nn.Linear(nz, ngf),
            nn.ReLU(True),
            nn.Linear(ngf, ngf),
            nn.ReLU(True),
            nn.Linear(ngf, ngf),
            nn.ReLU(True),
            nn.Linear(ngf, nc * isize * isize),
        )
        self.main = main
        self.nc = nc
        self.isize = isize
        self.nz = nz

    def forward(self, input):
        input = input.view(input.size(0), input.size(1))
        gpu_ids = None
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            gpu_ids = range(self.ngpu)
        out = nn.parallel.data_parallel(self.main, input, gpu_ids)
        return out.view(out.size(0), self.nc, self.isize, self.isize)


class MLP_D(nn.Module):
    def __init__(self, isize, nz, nc, ndf, ngpu):
        super(MLP_D, self).__init__()
        self.ngpu = ngpu

        cls = nn.Linear(ndf, 1)
        main = nn.Sequential(
            # the flattened image (nc * isize * isize) goes into a linear of size: ndf
            nn.Linear(nc * isize * isize, ndf),
            nn.ReLU(True),
            nn.Linear(ndf, ndf),
            nn.ReLU(True),
            nn.Linear(ndf, ndf),
            nn.ReLU(True),
        )
        main.add_module('cls', cls)

        self.main = main
        self.nc = nc
        self.isize = isize
        self.nz = nz

        # assign cls flags (same convention as models/dcgan.py)
        for p in self.main.parameters():
            p.cls = False
            p.cls_weight = False
            p.cls_bias = False
        for p in cls.parameters():
            p.cls = True
            if p.data.ndimension() == cls.weight.data.ndimension():
                p.cls_weight = True
            else:
                p.cls_bias = True

    def forward(self, input):
        input = input.view(input.size(0),
                           input.size(1) * input.size(2) * input.size(3))
        gpu_ids = None
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            gpu_ids = range(self.ngpu)
        return nn.parallel.data_parallel(self.main, input, gpu_ids)


--------------------------------------------------------------------------------
/models/toy.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
import torch.nn as nn

class MLP_G(nn.Module):
    #def __init__(self, isize=1, nz=4, nc=2, ngf=128, ngpu):
    def __init__(self, isize, nz, nc, ngf, ngpu):
        super(MLP_G,
              self).__init__()
        self.ngpu = ngpu

        main = nn.Sequential(
            # Z goes into a linear of size: ngf
            nn.Linear(nz, ngf, bias=False),
            nn.BatchNorm1d(ngf),
            nn.ReLU(True),
            nn.Linear(ngf, ngf, bias=False),
            nn.BatchNorm1d(ngf),
            nn.ReLU(True),
            nn.Linear(ngf, nc * isize * isize),
        )
        self.main = main
        self.nc = nc
        self.isize = isize
        self.nz = nz

    def forward(self, input):
        input = input.view(input.size(0), input.size(1))
        #gpu_ids = None
        #if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
        #    gpu_ids = range(self.ngpu)
        #out = nn.parallel.data_parallel(self.main, input, gpu_ids)
        out = self.main(input)
        return out.view(out.size(0), self.nc, self.isize, self.isize)


class MLP_D(nn.Module):
    #def __init__(self, isize=1, nz=4, nc=1, ndf=128, ngpu):
    def __init__(self, isize, nz, nc, ndf, ngpu):
        super(MLP_D, self).__init__()
        self.ngpu = ngpu

        cls = nn.Linear(ndf, 1)
        main = nn.Sequential(
            # the flattened input (nc * isize * isize) goes into a linear of size: ndf
            nn.Linear(nc * isize * isize, ndf),
            nn.ReLU(True),
            nn.Linear(ndf, ndf),
            nn.ReLU(True),
        )
        main.add_module('cls', cls)

        self.main = main
        self.nc = nc
        self.isize = isize
        self.nz = nz

        # assign cls flags (same convention as models/dcgan.py)
        for p in self.main.parameters():
            p.cls = False
            p.cls_weight = False
            p.cls_bias = False
        for p in cls.parameters():
            p.cls = True
            if p.data.ndimension() == cls.weight.data.ndimension():
                p.cls_weight = True
            else:
                p.cls_bias = True

    def forward(self, input):
        input = input.view(input.size(0),
                           input.size(1) * input.size(2) * input.size(3))
        #gpu_ids = None
        #if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
        #    gpu_ids = range(self.ngpu)
        #return nn.parallel.data_parallel(self.main, input, gpu_ids)
        return self.main(input)

--------------------------------------------------------------------------------
/models/toy4.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
import torch.nn as nn

class MLP_G(nn.Module):
    #def __init__(self, isize=1, nz=4, nc=2, ngf=128, ngpu):
    def __init__(self, isize, nz, nc, ngf, ngpu):
        super(MLP_G, self).__init__()
        self.ngpu = ngpu

        main = nn.Sequential(
            # Z goes into a linear of size: ngf
            nn.Linear(nz, ngf, bias=False),
            nn.BatchNorm1d(ngf),
            nn.ReLU(True),
            nn.Linear(ngf, ngf, bias=False),
            nn.BatchNorm1d(ngf),
            nn.ReLU(True),
            nn.Linear(ngf, ngf, bias=False),
            nn.BatchNorm1d(ngf),
            nn.ReLU(True),
            nn.Linear(ngf, nc * isize * isize),
        )
        self.main = main
        self.nc = nc
        self.isize = isize
        self.nz = nz

    def forward(self, input):
        input = input.view(input.size(0), input.size(1))
        #gpu_ids = None
        #if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
        #    gpu_ids = range(self.ngpu)
        #out = nn.parallel.data_parallel(self.main, input, gpu_ids)
        out = self.main(input)
        return out.view(out.size(0), self.nc, self.isize, self.isize)


class MLP_D(nn.Module):
    #def __init__(self, isize=1, nz=4, nc=1, ndf=128, ngpu):
    def __init__(self, isize, nz,
                 nc, ndf, ngpu):
        super(MLP_D, self).__init__()
        self.ngpu = ngpu

        cls = nn.Linear(ndf, 1)
        main = nn.Sequential(
            # the flattened input (nc * isize * isize) goes into a linear of size: ndf
            nn.Linear(nc * isize * isize, ndf),
            nn.ReLU(True),
            nn.Linear(ndf, ndf),
            nn.ReLU(True),
            nn.Linear(ndf, ndf),
            nn.ReLU(True),
        )
        main.add_module('cls', cls)

        self.main = main
        self.nc = nc
        self.isize = isize
        self.nz = nz

        # assign cls flags (same convention as models/dcgan.py)
        for p in self.main.parameters():
            p.cls = False
            p.cls_weight = False
            p.cls_bias = False
        for p in cls.parameters():
            p.cls = True
            if p.data.ndimension() == cls.weight.data.ndimension():
                p.cls_weight = True
            else:
                p.cls_bias = True

    def forward(self, input):
        input = input.view(input.size(0),
                           input.size(1) * input.size(2) * input.size(3))
        #gpu_ids = None
        #if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
        #    gpu_ids = range(self.ngpu)
        #return nn.parallel.data_parallel(self.main, input, gpu_ids)
        return self.main(input)

--------------------------------------------------------------------------------
/plot_log.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from parse import *
import progressbar
import math
from matplotlib.ticker import MultipleLocator, FormatStrFormatter
import pickle
import os.path
import scipy
import scipy.signal

import argparse
parser = argparse.ArgumentParser()
parser.add_argument("output_prefix", help="output prefix. output images will be <output_prefix>_disc_loss.png, <output_prefix>_real_loss.png, <output_prefix>_fake_loss.png, <output_prefix>_gen_loss.png")
parser.add_argument("-d", "--data", nargs=2, action='append',
                    help="