├── README.md ├── catgan_cifar10.py ├── results └── cifar10 │ ├── cifar10.gif │ ├── errD.jpg │ ├── errD_fake.jpg │ ├── errD_real.jpg │ ├── errG.jpg │ ├── samples_0.jpg │ ├── samples_10.jpg │ ├── samples_12.jpg │ ├── samples_14.jpg │ ├── samples_16.jpg │ ├── samples_18.jpg │ ├── samples_2.jpg │ ├── samples_20.jpg │ ├── samples_22.jpg │ ├── samples_24.jpg │ ├── samples_26.jpg │ ├── samples_28.jpg │ ├── samples_30.jpg │ ├── samples_32.jpg │ ├── samples_34.jpg │ ├── samples_36.jpg │ ├── samples_38.jpg │ ├── samples_4.jpg │ ├── samples_40.jpg │ ├── samples_42.jpg │ ├── samples_44.jpg │ ├── samples_46.jpg │ ├── samples_48.jpg │ ├── samples_50.jpg │ ├── samples_52.jpg │ ├── samples_54.jpg │ ├── samples_56.jpg │ ├── samples_58.jpg │ ├── samples_6.jpg │ ├── samples_60.jpg │ ├── samples_62.jpg │ ├── samples_64.jpg │ ├── samples_66.jpg │ ├── samples_68.jpg │ ├── samples_70.jpg │ ├── samples_72.jpg │ ├── samples_74.jpg │ ├── samples_76.jpg │ ├── samples_78.jpg │ ├── samples_8.jpg │ ├── samples_80.jpg │ ├── samples_82.jpg │ ├── samples_84.jpg │ ├── samples_86.jpg │ ├── samples_88.jpg │ ├── samples_90.jpg │ ├── samples_92.jpg │ ├── samples_94.jpg │ ├── samples_96.jpg │ └── samples_98.jpg └── utils ├── plot.py └── utility.py /README.md: -------------------------------------------------------------------------------- 1 | # catGAN 2 | 3 | PyTorch implementation of [Unsupervised and Semi-supervised Learning with Categorical Generative Adversarial Networks](https://arxiv.org/abs/1511.06390) that was originally proposed by Jost Tobias Springenberg. 4 | 5 | 6 | 7 | 8 | ### Results on CIFAR10 9 | Note that in this repo, only the unsupervised version is implemented for now. I replaced the original architecture with DCGAN and the results are more colorful than the original one. 
10 | 11 | From 0 to 100 epochs: 12 | 13 | ![cifar10](results/cifar10/cifar10.gif) 14 | 15 | 16 | 17 | ## Prerequisites 18 | - Python 2.7 19 | - PyTorch v0.2.0 20 | - NumPy 21 | - SciPy 22 | - Matplotlib 23 | 24 | 25 | ## Getting Started 26 | ### Installation 27 | - Install [PyTorch](https://github.com/pytorch/pytorch) and the other dependencies 28 | - Clone this repo: 29 | ```bash 30 | git clone https://github.com/xinario/catgan_pytorch.git 31 | cd catgan_pytorch 32 | ``` 33 | 34 | ### Train 35 | - Download the cifar10 dataset (.png format from [kaggle](https://www.kaggle.com/c/cifar-10/data)) 36 | - Create a dataset folder to hold the images 37 | ```bash 38 | mkdir -p ./datasets/cifar10/images 39 | ``` 40 | - Move the extracted images into the newly created folder 41 | 42 | - Train a model: 43 | ```bash 44 | python catgan_cifar10.py --data_dir ./datasets/cifar10 --name cifar10 45 | ``` 46 | All the generated plots and samples can be found inside ./results/cifar10 47 | 48 | 49 | 50 | 51 | ### Training options 52 | ```bash 53 | optional arguments: 54 | 55 | --continue_train to continue training from the latest checkpoints if --netG and --netD are not specified 56 | --netG NETG path to netG (to continue training) 57 | --netD NETD path to netD (to continue training) 58 | --workers WORKERS number of data loading workers 59 | --num_epochs EPOCHS number of epochs to train for 60 | ``` 61 | More options can be found inside the training script. 
62 | 63 | 64 | 65 | ## Acknowledgments 66 | Some of code are inspired and borrowed from [wgan-gp](https://github.com/caogang/wgan-gp), [DCGAN](https://github.com/pytorch/examples/tree/master/dcgan), [catGAN chainer repo](https://github.com/smayru/catgan) 67 | -------------------------------------------------------------------------------- /catgan_cifar10.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append(os.getcwd()) 3 | 4 | import time 5 | 6 | from utils.utility import mkdir_p, generate_image 7 | from utils.plot import plot, flush 8 | 9 | import numpy as np 10 | 11 | 12 | import torch 13 | import torchvision 14 | import torchvision.datasets as dset 15 | import torchvision.transforms as transforms 16 | from torch import nn 17 | from torch import autograd 18 | from torch import optim 19 | import argparse 20 | import csv 21 | 22 | 23 | 24 | 25 | parser = argparse.ArgumentParser(description='parse the input options') 26 | 27 | parser.add_argument('--name', type=str, default='cifar10', help='name of the experiment. 
It decides where to store the results and checkpoints') 28 | parser.add_argument('--results_dir', type=str, default='./results', help='folder to store the results') 29 | parser.add_argument('--image_size', type=int, default=32, help='input image size, for cifar10 is 32x32') 30 | parser.add_argument('--batch_size', type=int, default=20, help='batch size') 31 | parser.add_argument('--workers', type=int, default=2, help='# of workers to load the dataset') 32 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='folder to store the model checkpoint') 33 | parser.add_argument('--noise_dim', type=int, default=100, help='input dim of noise') 34 | parser.add_argument('--dim', type=int, default=64, help='# of filters in first conv layer of both discrim and gen') 35 | parser.add_argument('--data_dir', required=True, help='folder of the dataset') 36 | parser.add_argument('--netG', type=str, default='', help='checkpoints of netG you wish to use in continuing the training') 37 | parser.add_argument('--netD', type=str, default='', help='checkpoints of netD you wish to use in continuing the training') 38 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 39 | parser.add_argument('--num_epochs', type=int, default=200, help='# of epochs to train') 40 | 41 | 42 | 43 | 44 | opt = parser.parse_args() 45 | 46 | 47 | 48 | dtype = torch.FloatTensor 49 | 50 | mkdir_p(os.path.join(opt.results_dir,opt.name)) 51 | mkdir_p(os.path.join(opt.checkpoints_dir,opt.name)) 52 | 53 | 54 | class Generator(nn.Module): 55 | def __init__(self): 56 | super(Generator, self).__init__() 57 | self.main = nn.Sequential( 58 | # input is Z, going into a convolution 59 | nn.ConvTranspose2d( opt.noise_dim, opt.dim * 4, 4, 1, 0, bias=False), 60 | nn.BatchNorm2d(opt.dim * 4), 61 | nn.LeakyReLU(0.2, inplace=True), 62 | # state size. 
(opt.dim*8) x 4 x 4 63 | nn.ConvTranspose2d(opt.dim * 4, opt.dim * 2, 4, 2, 1, bias=False), 64 | nn.BatchNorm2d(opt.dim * 2), 65 | nn.LeakyReLU(0.2, inplace=True), 66 | # state size. (opt.dim*4) x 8 x 8 67 | nn.ConvTranspose2d(opt.dim * 2, opt.dim, 4, 2, 1, bias=False), 68 | nn.BatchNorm2d(opt.dim), 69 | nn.LeakyReLU(0.2, inplace=True), 70 | # state size. (opt.dim*2) x 16 x 16 71 | nn.ConvTranspose2d(opt.dim, 3, 4, 2, 1, bias=False), 72 | nn.Tanh() 73 | # state size. (nc) x 32 x 32 74 | ) 75 | 76 | def forward(self, input): 77 | output = self.main(input) 78 | return output 79 | 80 | 81 | class Discriminator(nn.Module): 82 | def __init__(self): 83 | super(Discriminator, self).__init__() 84 | main = nn.Sequential( 85 | nn.Conv2d(3, opt.dim, 4, 2, 1, bias=False), 86 | nn.BatchNorm2d(opt.dim), 87 | nn.LeakyReLU(0.2, inplace=True), 88 | nn.Dropout(0.5),#64x16x16 89 | nn.Conv2d(opt.dim, 2 * opt.dim, 4, 2, 1, bias=False), 90 | nn.BatchNorm2d(2*opt.dim), 91 | nn.LeakyReLU(0.2, inplace=True), 92 | nn.Dropout(0.5),#128x8x8 93 | nn.Conv2d(2 * opt.dim, 4 * opt.dim, 4, 2, 1, bias=False), 94 | nn.BatchNorm2d(4*opt.dim), 95 | nn.LeakyReLU(0.2, inplace=True), 96 | nn.Dropout(0.5),#256x4x4 97 | nn.Conv2d(4*opt.dim, 4*opt.dim, 4), 98 | nn.BatchNorm2d(4*opt.dim), 99 | nn.LeakyReLU(0.2, inplace=True), 100 | nn.Dropout(0.5),#256x1x1 101 | nn.Conv2d(4*opt.dim, 10, 1) 102 | ) 103 | 104 | self.main = main 105 | self.softmax = nn.Softmax() 106 | 107 | def forward(self, input): 108 | output = self.main(input) 109 | output = output.view(-1, 10) 110 | output = self.softmax(output) 111 | return output 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | #marginalized entropy 120 | def entropy1(y): 121 | y1 = autograd.Variable(torch.randn(y.size(1)).type(dtype), requires_grad=True) 122 | y2 = autograd.Variable(torch.randn(1).type(dtype), requires_grad=True) 123 | y1 = y.mean(0) 124 | y2 = -torch.sum(y1*torch.log(y1+1e-6)) 125 | 126 | return y2 127 | 128 | 129 | 130 | # entropy 131 | def entropy2(y): 
132 | y1 = autograd.Variable(torch.randn(y.size()).type(dtype), requires_grad=True) 133 | y2 = autograd.Variable(torch.randn(1).type(dtype), requires_grad=True) 134 | y1 = -y*torch.log(y+1e-6) 135 | 136 | y2 = 1.0/opt.batch_size*y1.sum() 137 | return y2 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | netG = Generator() 147 | netD = Discriminator() 148 | 149 | #continue traning by loading the latest model or the model specified in --netG and --netD 150 | if opt.continue_train: 151 | if opt.netG != '': 152 | netG.load_state_dict(torch.load(opt.netG)) 153 | else: 154 | netG.load_state_dict(torch.load('%s/netG_latest.pth' % (os.path.join(opt.checkpoints_dir,opt.name)))) 155 | 156 | 157 | if opt.netD != '': 158 | netD.load_state_dict(torch.load(opt.netD)) 159 | else: 160 | netD.load_state_dict(torch.load('%s/netD_latest.pth' % (os.path.join(opt.checkpoints_dir,opt.name)))) 161 | 162 | 163 | 164 | print netG 165 | print netD 166 | 167 | use_cuda = torch.cuda.is_available() 168 | 169 | if use_cuda: 170 | netD = netD.cuda() 171 | netG = netG.cuda() 172 | 173 | one = torch.FloatTensor([1]) 174 | mone = one * -1 175 | 176 | if use_cuda: 177 | one = one.cuda() 178 | mone = mone.cuda() 179 | 180 | optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.9)) 181 | optimizerG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.9)) 182 | 183 | 184 | 185 | 186 | # Dataset iterator 187 | dataset = dset.ImageFolder(root=opt.data_dir, 188 | transform=transforms.Compose([ 189 | transforms.Scale(opt.image_size), 190 | transforms.ToTensor(), 191 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 192 | ])) 193 | 194 | 195 | assert dataset 196 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, 197 | shuffle=True, num_workers=int(opt.workers)) 198 | 199 | print("Start training on %s dataset which contains %d images..." 
%(opt.name, len(dataset))) 200 | 201 | 202 | 203 | 204 | 205 | 206 | iter_idx = 0 207 | with open(os.path.join(opt.results_dir, opt.name, 'log.csv'), 'wb') as log: 208 | log_writer = csv.writer(log, delimiter=',') 209 | 210 | for epoch in xrange(opt.num_epochs): 211 | start_time = time.time() 212 | 213 | for batch_idx, (real, labels) in enumerate(dataloader): 214 | ########################### 215 | # (1) Update D network 216 | ########################### 217 | 218 | #freeze G and update D 219 | for p in netD.parameters(): 220 | p.requires_grad = True 221 | for p in netG.parameters(): 222 | p.requires_grad = False 223 | netD.zero_grad() 224 | 225 | ################# 226 | # train with real 227 | ################# 228 | if use_cuda: 229 | real = autograd.Variable(real.cuda()) 230 | 231 | D_real = netD(real) 232 | # minimize entropy to make certain prediction of real sample 233 | entorpy2_real = entropy2(D_real) 234 | entorpy2_real.backward(one, retain_graph=True) 235 | 236 | # maximize marginalized entropy over real samples to ensure equal usage 237 | entropy1_real = entropy1(D_real) 238 | entropy1_real.backward(mone) 239 | 240 | ################# 241 | # train with fake 242 | ################# 243 | noise = torch.randn(opt.batch_size, opt.noise_dim, 1, 1) 244 | if use_cuda: 245 | noise = autograd.Variable(noise.cuda()) # totally freeze netG 246 | 247 | fake = netG(noise) 248 | D_fake = netD(fake) 249 | 250 | #minimize entropy to make uncertain prediction of fake sample 251 | entorpy2_fake = entropy2(D_fake) 252 | entorpy2_fake.backward(mone) 253 | 254 | 255 | D_cost = entropy1_real + entorpy2_real + entorpy2_fake 256 | optimizerD.step() 257 | ############################ 258 | # (2) Update G network 259 | ########################### 260 | 261 | #freeze D and update G 262 | for p in netD.parameters(): 263 | p.requires_grad = False 264 | for p in netG.parameters(): 265 | p.requires_grad = True 266 | netG.zero_grad() 267 | 268 | 269 | noise = 
torch.randn(opt.batch_size, opt.noise_dim, 1, 1) 270 | noise = autograd.Variable(noise.cuda()) 271 | fake = netG(noise) 272 | D_fake = netD(fake) 273 | 274 | #fool D to make it believe the generated samples are real 275 | entropy2_fake = entropy2(D_fake) 276 | entropy2_fake.backward(one, retain_graph=True) 277 | 278 | #ensure equal usage of fake samples 279 | entropy1_fake = entropy1(D_fake) 280 | entropy1_fake.backward(mone) 281 | 282 | G_cost = entropy2_fake + entropy1_fake 283 | optimizerG.step() 284 | 285 | 286 | D_cost = D_cost.cpu().data.numpy() 287 | G_cost = G_cost.cpu().data.numpy() 288 | entorpy2_real = entorpy2_real.cpu().data.numpy() 289 | entorpy2_fake = entorpy2_fake.cpu().data.numpy() 290 | 291 | #monitoring the loss 292 | plot('errD', D_cost, iter_idx) 293 | # plot('time', time.time() - start_time, iter_idx) 294 | plot('errG', G_cost, iter_idx) 295 | plot('errD_real', entorpy2_real, iter_idx) 296 | plot('errD_fake', entorpy2_fake, iter_idx) 297 | 298 | 299 | # Save plot every iter 300 | flush(os.path.join(opt.results_dir, opt.name)) 301 | 302 | # Write losses to logs 303 | log_writer.writerow([D_cost[0],G_cost[0],entorpy2_real[0],entorpy2_fake[0]]) 304 | 305 | print "iter%d[epoch %d]\t %s %.4f \t %s %.4f \t %s %.4f \t %s %.4f" % (iter_idx, epoch, 306 | 'errD', D_cost, 307 | 'errG', G_cost, 308 | 'errD_real', entorpy2_real, 309 | 'errD_fake', entorpy2_fake ) 310 | 311 | #checkpointing the latest model every 500 iteration 312 | if iter_idx % 500 == 0: 313 | torch.save(netG.state_dict(), '%s/netG_latest.pth' % (os.path.join(opt.checkpoints_dir,opt.name))) 314 | torch.save(netD.state_dict(), '%s/netD_latest.pth' % (os.path.join(opt.checkpoints_dir,opt.name))) 315 | 316 | iter_idx += 1 317 | 318 | 319 | # generate samples every 2 epochs for surveillance 320 | if epoch % 2 == 0: 321 | generate_image(epoch, netG, opt) 322 | 323 | 324 | # do checkpointing every 20 epochs 325 | if epoch % 20 == 0: 326 | torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % 
(os.path.join(opt.checkpoints_dir, opt.name), epoch)) 327 | torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (os.path.join(opt.checkpoints_dir, opt.name), epoch)) 328 | -------------------------------------------------------------------------------- /results/cifar10/cifar10.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/cifar10.gif -------------------------------------------------------------------------------- /results/cifar10/errD.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/errD.jpg -------------------------------------------------------------------------------- /results/cifar10/errD_fake.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/errD_fake.jpg -------------------------------------------------------------------------------- /results/cifar10/errD_real.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/errD_real.jpg -------------------------------------------------------------------------------- /results/cifar10/errG.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/errG.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_0.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_0.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_10.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_12.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_14.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_16.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_18.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_2.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_20.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_20.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_22.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_22.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_24.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_24.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_26.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_26.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_28.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_28.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_30.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_30.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_32.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_32.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_34.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_34.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_36.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_36.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_38.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_38.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_4.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_40.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_40.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_42.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_42.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_44.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_44.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_46.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_46.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_48.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_48.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_50.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_50.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_52.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_52.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_54.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_54.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_56.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_56.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_58.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_58.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_6.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_60.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_60.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_62.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_62.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_64.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_64.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_66.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_66.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_68.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_68.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_70.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_70.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_72.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_72.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_74.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_74.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_76.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_76.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_78.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_78.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_8.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_80.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_80.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_82.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_82.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_84.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_84.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_86.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_86.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_88.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_88.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_90.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_90.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_92.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_92.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_94.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_94.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_96.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_96.jpg -------------------------------------------------------------------------------- /results/cifar10/samples_98.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinario/catgan_pytorch/a02c67b3d9f1ee272f17e819cd34e473ab1d11f8/results/cifar10/samples_98.jpg -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import matplotlib 4 | matplotlib.use('Agg') 5 | import matplotlib.pyplot as plt 6 | 7 | import collections 8 | import time 9 | import os 10 | 11 | _since_beginning = collections.defaultdict(lambda: {}) 12 | _since_last_flush = collections.defaultdict(lambda: {}) 13 | 14 | 15 | 16 | def plot(name, value, iter): 17 | _since_last_flush[name][iter] = value 18 | 19 | def flush(save_dir): 20 | prints = [] 21 | 22 | for name, vals in _since_last_flush.items(): 23 | _since_beginning[name].update(vals) 24 | 25 | x_vals = np.sort(_since_beginning[name].keys()) 26 | y_vals = [_since_beginning[name][x] for x in x_vals] 27 | 28 | plt.clf() 29 | plt.plot(x_vals, y_vals) 30 | plt.xlabel('iteration') 31 | plt.ylabel(name) 32 | plt.savefig(os.path.join(save_dir, name.replace(' ', '_')+'.jpg')) 33 | 34 | _since_last_flush.clear() 35 | 36 | -------------------------------------------------------------------------------- /utils/utility.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image grid saver, based on color_grid_vis from github.com/Newmu 3 | """ 4 | 5 | import numpy as np 6 | import scipy.misc 7 | from scipy.misc import imsave 8 | import os 9 | import errno 10 | import torch 11 | from torch import autograd 12 | 13 | 14 | def save_images(X, save_path): 15 | # [0, 1] -> [0,255] 16 
| if isinstance(X.flatten()[0], np.floating): 17 | X = (255.99*X).astype('uint8') 18 | 19 | n_samples = X.shape[0] 20 | rows = int(np.sqrt(n_samples)) 21 | while n_samples % rows != 0: 22 | rows -= 1 23 | 24 | nh, nw = rows, n_samples/rows 25 | 26 | if X.ndim == 2: 27 | X = np.reshape(X, (X.shape[0], int(np.sqrt(X.shape[1])), int(np.sqrt(X.shape[1])))) 28 | 29 | if X.ndim == 4: 30 | # BCHW -> BHWC 31 | X = X.transpose(0,2,3,1) 32 | h, w = X[0].shape[:2] 33 | img = np.zeros((h*nh, w*nw, 3)) 34 | elif X.ndim == 3: 35 | h, w = X[0].shape[:2] 36 | img = np.zeros((h*nh, w*nw)) 37 | 38 | for n, x in enumerate(X): 39 | j = n/nw 40 | i = n%nw 41 | img[j*h:j*h+h, i*w:i*w+w] = x 42 | 43 | imsave(save_path, img) 44 | 45 | 46 | 47 | # For generating samples 48 | def generate_image(epoch, netG, opt): 49 | fixed_noise_128 = torch.randn(128, opt.noise_dim, 1, 1) 50 | if torch.cuda.is_available(): 51 | fixed_noise_128 = fixed_noise_128.cuda() 52 | noisev = autograd.Variable(fixed_noise_128, volatile=True) 53 | samples = netG(noisev) 54 | samples = samples.view(-1, 3, opt.image_size, opt.image_size) 55 | samples = samples.mul(0.5).add(0.5) 56 | samples = samples.cpu().data.numpy() 57 | 58 | save_images(samples, os.path.join(opt.results_dir, opt.name, 'samples_{}.jpg'.format(epoch))) 59 | 60 | 61 | # code borrowed from @tzot on stackoverflow.com 62 | def mkdir_p(path): 63 | try: 64 | os.makedirs(path) 65 | except OSError as exc: # Python >2.5 66 | if exc.errno == errno.EEXIST and os.path.isdir(path): 67 | pass 68 | else: 69 | raise --------------------------------------------------------------------------------