├── LICENSE.md ├── README.md ├── images ├── CIFAR10_epoch100.png └── MNIST_epoch100.png ├── iterators.py ├── models.py ├── requirements.txt └── train.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Preferred Networks, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # chainer-LSGAN 2 | An implementation of [_Mao et al., "Least Squares Generative Adversarial Networks" 2017_](https://arxiv.org/abs/1611.04076) using the [Chainer framework](http://chainer.org/). 3 | 4 | Disclaimer: PFN provides no warranty or support for this implementation. Use it at your own risk. See [license](LICENSE.md) for details. 5 | 6 | Results 7 | ------- 8 | CIFAR10 & MNIST for 100 epochs 9 |

10 | CIFAR10 MNIST 11 |

12 | 13 | Usage 14 | ------- 15 | Tested using `python 3.5.1`. Install the requirements first: 16 | ``` 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | Trains on the CIFAR10 dataset by default, and will generate an image of a sample batch from the network after each epoch. Run the following: 21 | ``` 22 | python train.py --device_id 0 23 | ``` 24 | to train. By default, an output folder will be created in your current working directory. Setting `--device_id` to -1 will run in CPU mode, whereas 0 will run on GPU number 0 etc. To train on MNIST, use the flag `--mnist`. 25 | 26 | License 27 | ------- 28 | MIT License. Please see the LICENSE file for details. 29 | -------------------------------------------------------------------------------- /images/CIFAR10_epoch100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/chainer-LSGAN/8276602e4eeac14f869904069bfeccbe08bd012b/images/CIFAR10_epoch100.png -------------------------------------------------------------------------------- /images/MNIST_epoch100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/chainer-LSGAN/8276602e4eeac14f869904069bfeccbe08bd012b/images/MNIST_epoch100.png -------------------------------------------------------------------------------- /iterators.py: -------------------------------------------------------------------------------- 1 | #https://github.com/hvy/chainer-wasserstein-gan/blob/master/iterators.py 2 | 3 | import numpy 4 | from chainer.dataset import iterator 5 | 6 | 7 | def to_tuple(x): 8 | if hasattr(x, '__getitem__'): 9 | return x 10 | return x, 11 | 12 | 13 | class UniformNoiseGenerator(object): 14 | def __init__(self, low, high, size): 15 | self.low = low 16 | self.high = high 17 | self.size = to_tuple(size) 18 | 19 | def __call__(self, batch_size): 20 | return numpy.random.uniform(self.low, self.high, (batch_size,) + 21 | self.size).astype(numpy.float32) 22 | 23 | 24 | class GaussianNoiseGenerator(object): 25 | def __init__(self, loc, scale, size): 26 | self.loc = loc 27 | self.scale = scale 28 | self.size = to_tuple(size) 29 | 30 | def __call__(self, batch_size): 31 | return numpy.random.normal(self.loc, self.scale, (batch_size,) + 32 | self.size).astype(numpy.float32) 33 | 34 | 35 | class RandomNoiseIterator(iterator.Iterator): 36 | def __init__(self, noise_generator, batch_size): 37 | self.noise_generator = noise_generator 38 | self.batch_size = batch_size 39 | 40 | def __next__(self): 41 | batch = self.noise_generator(self.batch_size) 42 | return batch 43 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | import chainer.functions as F 3 | import chainer.links as L 4 | 5 | class GeneratorCIFAR(chainer.Chain): 6 | 7 | def __init__(self, size=None): 8 | 9 | super().__init__( 10 | dc1=L.Deconvolution2D(None, 256, 4, stride=1, pad=0, nobias=True), 11 | dc2=L.Deconvolution2D(256, 128, 4, stride=2, pad=1, nobias=True), 12 | dc3=L.Deconvolution2D(128, 64, 4, stride=2, pad=1, nobias=True), 13 | dc4=L.Deconvolution2D(64, 3, 4, stride=2, pad=1, nobias=True), 14 | bn_dc1=L.BatchNormalization(256), 15 | bn_dc2=L.BatchNormalization(128), 16 | bn_dc3=L.BatchNormalization(64) 17 | ) 18 | 19 | def __call__(self, z): 20 | h = F.reshape(z, (z.shape[0], -1, 1, 1)) 21 | h = F.relu(self.bn_dc1(self.dc1(h))) 22 | h = F.relu(self.bn_dc2(self.dc2(h))) 23 | h = F.relu(self.bn_dc3(self.dc3(h))) 24 | h = F.tanh(self.dc4(h)) 25 | return h 26 | 27 | class GeneratorMNIST(chainer.Chain): 28 | 29 | def __init__(self, size=None): 30 | 31 | super().__init__( 32 | dc1=L.Deconvolution2D(None, 256, 4, stride=1, pad=0, nobias=True), 33 | dc2=L.Deconvolution2D(256, 128, 4, stride=2, pad=1, nobias=True), 34 | dc3=L.Deconvolution2D(128, 64, 4, stride=2, pad=2, nobias=True), 35 | dc4=L.Deconvolution2D(64, 1, 4, stride=2, pad=1, nobias=True), 36 | bn_dc1=L.BatchNormalization(256), 37 | bn_dc2=L.BatchNormalization(128), 38 | bn_dc3=L.BatchNormalization(64) 39 | ) 40 | 41 | def __call__(self, z): 42 | h = F.reshape(z, (z.shape[0], -1, 1, 1)) 43 | h = F.relu(self.bn_dc1(self.dc1(h))) 44 | h = F.relu(self.bn_dc2(self.dc2(h))) 45 | h = F.relu(self.bn_dc3(self.dc3(h))) 46 | h = F.tanh(self.dc4(h)) 47 | return h 48 | 49 | class Discriminator(chainer.Chain): 50 | 51 | def __init__(self): 52 | super().__init__( 53 | c0 = L.Convolution2D(None, 64, 4, stride=2, pad=1, nobias=True), 54 | c1 = L.Convolution2D(64, 128, 4, stride=2, pad=1, nobias=True), 55 | c2 = L.Convolution2D(128, 256, 4, stride=2, pad=1, nobias=True), 56 | c3 = L.Convolution2D(256, 512, 4, stride=2, pad=1, nobias=True), 57 | l4l = L.Linear(None, 1), 58 | bn0 = L.BatchNormalization(64), 59 | bn1 = L.BatchNormalization(128), 60 | bn2 = L.BatchNormalization(256), 61 | bn3 = L.BatchNormalization(512) 62 | ) 63 | 64 | def __call__(self, x): 65 | h = F.leaky_relu(self.c0(x)) 66 | h = F.leaky_relu(self.bn1(self.c1(h))) 67 | h = F.leaky_relu(self.bn2(self.c2(h))) 68 | h = F.leaky_relu(self.bn3(self.c3(h))) 69 | l = self.l4l(h) 70 | return l 71 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | chainer<3 2 | cupy<2 3 | matplotlib==1.5.1 4 | numpy==1.11.0 5 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import argparse 5 | import math 6 | 7 | import matplotlib 8 | matplotlib.use('Agg') 9 | 10 | from matplotlib import pyplot as plt 11 | import numpy as np 12 | 13 | import chainer 14 | from chainer import cuda, Variable 15 | import chainer.functions as F 16 | 17 | from models import Discriminator, GeneratorMNIST, GeneratorCIFAR 18 | from iterators import RandomNoiseIterator, GaussianNoiseGenerator, UniformNoiseGenerator 19 | 20 | def get_batch(iter, device_id): 21 | batch = chainer.dataset.concat_examples(next(iter), device=device_id) 22 | return Variable(batch) 23 | 24 | def update_model(opt, loss): 25 | opt.target.cleargrads() 26 | loss.backward() 27 | opt.update() 28 | 29 | def save_ims(filename, ims, dpi=100): 30 | 31 | ims += 1.0 32 | ims /= 2.0 33 | 34 | if cuda.get_array_module(ims) == cuda.cupy: 35 | ims = cuda.to_cpu(ims) 36 | 37 | n, c, w, h = ims.shape 38 | x_plots = math.ceil(math.sqrt(n)) 39 | y_plots = x_plots if n % x_plots == 0 else x_plots - 1 40 | plt.figure(figsize=(w*x_plots/dpi, h*y_plots/dpi), dpi=dpi) 41 | 42 | for i, im in enumerate(ims): 43 | plt.subplot(y_plots, x_plots, i+1) 44 | 45 | if c == 1: 46 | plt.imshow(im[0], cmap=plt.cm.binary) 47 | else: 48 | plt.imshow(im.transpose((1, 2, 0)), interpolation="nearest") 49 | 50 | plt.axis('off') 51 | plt.gca().set_xticks([]) 52 | plt.gca().set_yticks([]) 53 | plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, 54 | hspace=0) 55 | 56 | plt.savefig(filename, dpi=dpi*2, facecolor='black') 57 | plt.clf() 58 | plt.close() 59 | 60 | def print_sample(name, noise_samples, opt_generator): 61 | generated = opt_generator.target(noise_samples) 62 | save_ims(name, generated.data) 63 | print(" Saved image to {}".format(name)) 64 | 65 | def training_step(args, train_iter, noise_iter, opt_generator, opt_discriminator): 66 | 67 | noise_samples = get_batch(noise_iter, args.device_id) 68 | 69 | # generate an image 70 | generated = opt_generator.target(noise_samples) 71 | 72 | # get a batch of the dataset 73 | train_samples = get_batch(train_iter, args.device_id) 74 | 75 | # update the discriminator 76 | Dreal = opt_discriminator.target(train_samples) 77 | Dgen = opt_discriminator.target(generated) 78 | 79 | Dloss = 0.5 * (F.sum((Dreal - 1.0)**2) + F.sum(Dgen**2)) / args.batchsize 80 | update_model(opt_discriminator, Dloss) 81 | 82 | # update the generator 83 | noise_samples = get_batch(noise_iter, args.device_id) 84 | generated = opt_generator.target(noise_samples) 85 | Gloss = 0.5 * F.sum((opt_discriminator.target(generated) - 1.0)**2) / args.batchsize 86 | update_model(opt_generator, Gloss) 87 | 88 | if train_iter.is_new_epoch: 89 | print("[{}] Discriminator loss: {} Generator loss: {}".format(train_iter.epoch, Dloss.data, Gloss.data)) 90 | print_sample(os.path.join(args.output, "epoch_{}.png".format(train_iter.epoch)), noise_samples, opt_generator) 91 | 92 | def parse_args(): 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument('--device_id', '-g', type=int, default=-1) 95 | parser.add_argument('--num_epochs', '-n', type=int, default=100) 96 | parser.add_argument('--batchsize', '-b', type=int, default=64) 97 | parser.add_argument('--num_z', '-z', type=int, default=1024) 98 | parser.add_argument('--learning_rate', '-lr', type=float, default=0.001) 99 | parser.add_argument('--output', '-o', type=str, default="output") 100 | parser.add_argument('--mnist', '-m', action="store_true") 101 | return parser.parse_args() 102 | 103 | 104 | def main(args): 105 | 106 | # if we enabled GPU mode, set the GPU to use 107 | if args.device_id >= 0: 108 | chainer.cuda.get_device(args.device_id).use() 109 | 110 | # Load dataset (we will only use the training set) 111 | if args.mnist: 112 | train, test = chainer.datasets.get_mnist(withlabel=False, scale=2, ndim=3) 113 | generator = GeneratorMNIST() 114 | else: 115 | train, test = chainer.datasets.get_cifar10(withlabel=False, scale=2, ndim=3) 116 | generator = GeneratorCIFAR() 117 | 118 | # subtracting 1, after scaling to 2 (done above) will make all pixels in the range [-1,1] 119 | train -= 1.0 120 | 121 | num_training_samples = train.shape[0] 122 | 123 | # make data iterators 124 | train_iter = chainer.iterators.SerialIterator(train, args.batchsize) 125 | 126 | # build optimizers and models 127 | opt_generator = chainer.optimizers.RMSprop(lr=args.learning_rate) 128 | opt_discriminator = chainer.optimizers.RMSprop(lr=args.learning_rate) 129 | 130 | opt_generator.setup(generator) 131 | opt_discriminator.setup(Discriminator()) 132 | 133 | # make a random noise iterator (uniform noise between -1 and 1) 134 | noise_iter = RandomNoiseIterator(UniformNoiseGenerator(-1, 1, args.num_z), args.batchsize) 135 | 136 | # send to GPU 137 | if args.device_id >= 0: 138 | opt_generator.target.to_gpu() 139 | opt_discriminator.target.to_gpu() 140 | 141 | # make the output folder 142 | if not os.path.exists(args.output): 143 | os.makedirs(args.output, exist_ok=True) 144 | 145 | print("Starting training loop...") 146 | 147 | while train_iter.epoch < args.num_epochs: 148 | training_step(args, train_iter, noise_iter, opt_generator, opt_discriminator) 149 | 150 | print("Finished training.") 151 | 152 | if __name__=='__main__': 153 | args = parse_args() 154 | main(args) 155 | --------------------------------------------------------------------------------