├── LICENSE.md
├── README.md
├── images
├── CIFAR10_epoch100.png
└── MNIST_epoch100.png
├── iterators.py
├── models.py
├── requirements.txt
└── train.py
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Preferred Networks, Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # chainer-LSGAN
2 | An implementation of [_Mao et al., "Least Squares Generative Adversarial Networks" 2017_](https://arxiv.org/abs/1611.04076) using the [Chainer framework](http://chainer.org/).
3 |
4 | Disclaimer: PFN provides no warranty or support for this implementation. Use it at your own risk. See [license](LICENSE.md) for details.
5 |
6 | Results
7 | -------
8 | CIFAR10 & MNIST for 100 epochs
9 |
10 |
11 |
12 |
13 | Usage
14 | -------
15 | Tested using `python 3.5.1`. Install the requirements first:
16 | ```
17 | pip install -r requirements.txt
18 | ```
19 |
20 | By default, the script trains on the CIFAR10 dataset and generates an image of a sample batch from the network after each epoch. Run the following:
21 | ```
22 | python train.py --device_id 0
23 | ```
24 | to train. By default, an output folder named `output` will be created in your current working directory (use `--output` to change it). Setting `--device_id` to -1 (the default) runs in CPU mode, whereas 0 runs on GPU number 0, 1 on GPU number 1, and so on. To train on MNIST, add the `--mnist` flag.
25 |
26 | License
27 | -------
28 | MIT License. Please see the LICENSE file for details.
29 |
--------------------------------------------------------------------------------
/images/CIFAR10_epoch100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/chainer-LSGAN/8276602e4eeac14f869904069bfeccbe08bd012b/images/CIFAR10_epoch100.png
--------------------------------------------------------------------------------
/images/MNIST_epoch100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/chainer-LSGAN/8276602e4eeac14f869904069bfeccbe08bd012b/images/MNIST_epoch100.png
--------------------------------------------------------------------------------
/iterators.py:
--------------------------------------------------------------------------------
1 | #https://github.com/hvy/chainer-wasserstein-gan/blob/master/iterators.py
2 |
3 | import numpy
4 | from chainer.dataset import iterator
5 |
6 |
def to_tuple(x):
    """Return ``x`` as a tuple.

    Indexable iterables (tuples, lists, 1-D arrays, ...) are converted
    element-wise so that the result can always be concatenated with another
    tuple, e.g. ``(batch_size,) + to_tuple(size)``.  Anything else -- plain
    ints as well as scalar objects that expose ``__getitem__`` but cannot be
    iterated -- is wrapped in a 1-tuple.
    """
    if hasattr(x, '__getitem__'):
        try:
            return tuple(x)
        except TypeError:
            # Indexable but not iterable (e.g. a NumPy scalar): treat it as
            # a single value rather than as a sequence.
            pass
    return (x,)
11 |
12 |
class UniformNoiseGenerator(object):
    """Callable producing batches of float32 uniform noise.

    Samples are drawn with ``numpy.random.uniform(low, high, ...)``; ``size``
    is the shape of a single sample and is normalised with ``to_tuple``.
    """

    def __init__(self, low, high, size):
        self.low, self.high = low, high
        self.size = to_tuple(size)

    def __call__(self, batch_size):
        """Return an array of shape ``(batch_size,) + self.size``."""
        shape = (batch_size,) + self.size
        noise = numpy.random.uniform(self.low, self.high, shape)
        return noise.astype(numpy.float32)
22 |
23 |
class GaussianNoiseGenerator(object):
    """Callable producing batches of float32 Gaussian noise.

    Samples are drawn with ``numpy.random.normal(loc, scale, ...)``; ``size``
    is the shape of a single sample and is normalised with ``to_tuple``.
    """

    def __init__(self, loc, scale, size):
        self.loc, self.scale = loc, scale
        self.size = to_tuple(size)

    def __call__(self, batch_size):
        """Return an array of shape ``(batch_size,) + self.size``."""
        shape = (batch_size,) + self.size
        noise = numpy.random.normal(self.loc, self.scale, shape)
        return noise.astype(numpy.float32)
33 |
34 |
class RandomNoiseIterator(iterator.Iterator):
    """Endless iterator that yields a fresh noise batch on every step.

    ``noise_generator`` is any callable taking a batch size and returning a
    batch (e.g. ``UniformNoiseGenerator``); ``batch_size`` is forwarded to it
    unchanged on each call.
    """

    def __init__(self, noise_generator, batch_size):
        self.noise_generator = noise_generator
        self.batch_size = batch_size

    def __next__(self):
        # No internal state is advanced: every call simply asks the
        # generator for a new batch.
        return self.noise_generator(self.batch_size)
43 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import chainer
2 | import chainer.functions as F
3 | import chainer.links as L
4 |
class GeneratorCIFAR(chainer.Chain):
    """Deconvolutional generator producing 3-channel images.

    The noise vector is reshaped to ``(batch, num_z, 1, 1)`` and upsampled by
    four 4x4 deconvolutions (1 -> 4 -> 8 -> 16 -> 32 pixels per side).  The
    first three layers are followed by batch normalisation and ReLU; the
    output layer uses tanh, so pixel values lie in ``[-1, 1]``.
    """

    def __init__(self, size=None):
        # ``size`` is accepted for interface compatibility but is not used.

        def deconv(in_ch, out_ch, stride, pad):
            return L.Deconvolution2D(in_ch, out_ch, 4, stride=stride,
                                     pad=pad, nobias=True)

        super().__init__(
            dc1=deconv(None, 256, 1, 0),
            dc2=deconv(256, 128, 2, 1),
            dc3=deconv(128, 64, 2, 1),
            dc4=deconv(64, 3, 2, 1),
            bn_dc1=L.BatchNormalization(256),
            bn_dc2=L.BatchNormalization(128),
            bn_dc3=L.BatchNormalization(64),
        )

    def __call__(self, z):
        """Map a batch of noise vectors ``z`` to a batch of images."""
        h = F.reshape(z, (z.shape[0], -1, 1, 1))
        hidden_layers = (
            (self.dc1, self.bn_dc1),
            (self.dc2, self.bn_dc2),
            (self.dc3, self.bn_dc3),
        )
        for deconv, batch_norm in hidden_layers:
            h = F.relu(batch_norm(deconv(h)))
        return F.tanh(self.dc4(h))
26 |
class GeneratorMNIST(chainer.Chain):
    """Deconvolutional generator producing 1-channel images.

    Same layout as ``GeneratorCIFAR`` except that the third deconvolution
    uses ``pad=2`` and the output has a single channel, giving the
    upsampling chain 1 -> 4 -> 8 -> 14 -> 28 pixels per side.  The output
    layer uses tanh, so pixel values lie in ``[-1, 1]``.
    """

    def __init__(self, size=None):
        # ``size`` is accepted for interface compatibility but is not used.

        def deconv(in_ch, out_ch, stride, pad):
            return L.Deconvolution2D(in_ch, out_ch, 4, stride=stride,
                                     pad=pad, nobias=True)

        super().__init__(
            dc1=deconv(None, 256, 1, 0),
            dc2=deconv(256, 128, 2, 1),
            dc3=deconv(128, 64, 2, 2),
            dc4=deconv(64, 1, 2, 1),
            bn_dc1=L.BatchNormalization(256),
            bn_dc2=L.BatchNormalization(128),
            bn_dc3=L.BatchNormalization(64),
        )

    def __call__(self, z):
        """Map a batch of noise vectors ``z`` to a batch of images."""
        h = F.reshape(z, (z.shape[0], -1, 1, 1))
        hidden_layers = (
            (self.dc1, self.bn_dc1),
            (self.dc2, self.bn_dc2),
            (self.dc3, self.bn_dc3),
        )
        for deconv, batch_norm in hidden_layers:
            h = F.relu(batch_norm(deconv(h)))
        return F.tanh(self.dc4(h))
48 |
class Discriminator(chainer.Chain):
    """Convolutional discriminator returning one raw score per image.

    Four stride-2 4x4 convolutions with leaky ReLU are followed by a linear
    layer with a single output.  No activation is applied to that output.
    """

    def __init__(self):

        def conv(in_ch, out_ch):
            return L.Convolution2D(in_ch, out_ch, 4, stride=2, pad=1,
                                   nobias=True)

        super().__init__(
            c0=conv(None, 64),
            c1=conv(64, 128),
            c2=conv(128, 256),
            c3=conv(256, 512),
            l4l=L.Linear(None, 1),
            # bn0 is registered but never applied in __call__; it is kept so
            # that the set of model parameters stays unchanged.
            bn0=L.BatchNormalization(64),
            bn1=L.BatchNormalization(128),
            bn2=L.BatchNormalization(256),
            bn3=L.BatchNormalization(512),
        )

    def __call__(self, x):
        """Return the discriminator output for the image batch ``x``."""
        # The first convolution is used without batch normalisation.
        h = F.leaky_relu(self.c0(x))
        normalised_layers = (
            (self.c1, self.bn1),
            (self.c2, self.bn2),
            (self.c3, self.bn3),
        )
        for conv, batch_norm in normalised_layers:
            h = F.leaky_relu(batch_norm(conv(h)))
        return self.l4l(h)
71 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | chainer<3
2 | cupy<2
3 | matplotlib==1.5.1
4 | numpy==1.11.0
5 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import argparse
5 | import math
6 |
7 | import matplotlib
8 | matplotlib.use('Agg')
9 |
10 | from matplotlib import pyplot as plt
11 | import numpy as np
12 |
13 | import chainer
14 | from chainer import cuda, Variable
15 | import chainer.functions as F
16 |
17 | from models import Discriminator, GeneratorMNIST, GeneratorCIFAR
18 | from iterators import RandomNoiseIterator, GaussianNoiseGenerator, UniformNoiseGenerator
19 |
def get_batch(iter, device_id):
    """Fetch the next batch from ``iter`` as a ``Variable`` on ``device_id``.

    The raw batch is collated and moved to the requested device by
    ``chainer.dataset.concat_examples`` before being wrapped.
    """
    raw_batch = next(iter)
    array = chainer.dataset.concat_examples(raw_batch, device=device_id)
    return Variable(array)
23 |
def update_model(opt, loss):
    """Run one optimisation step of ``opt`` for the given ``loss``.

    Gradients of the optimiser's target model are cleared first, then
    ``loss`` is back-propagated and the optimiser applies its update.
    """
    model = opt.target
    model.cleargrads()
    loss.backward()
    opt.update()
28 |
def save_ims(filename, ims, dpi=100):
    """Tile a batch of images into one figure and save it to ``filename``.

    Args:
        filename: Destination path passed to ``plt.savefig``.
        ims: Array of shape ``(n, channels, height, width)`` with values in
            ``[-1, 1]`` (NumPy or CuPy).  The array is not modified.
        dpi: Resolution used to size the figure; the file is written at
            twice this value.

    Raises:
        ValueError: If ``ims`` contains no images.
    """
    # to_cpu() hands host arrays back untouched, so no explicit array-module
    # check is needed (and CuPy does not have to be installed).
    ims = cuda.to_cpu(ims)

    # Map [-1, 1] to [0, 1] out of place so the caller's data stays intact.
    ims = (ims + 1.0) / 2.0

    n, c, h, w = ims.shape
    if n == 0:
        raise ValueError("save_ims() needs at least one image")

    # Near-square grid with just enough rows to hold every image.
    x_plots = int(math.ceil(math.sqrt(n)))
    y_plots = int(math.ceil(n / float(x_plots)))

    fig = plt.figure(figsize=(float(w) * x_plots / dpi,
                              float(h) * y_plots / dpi), dpi=dpi)
    try:
        for i, im in enumerate(ims):
            ax = plt.subplot(y_plots, x_plots, i + 1)

            if c == 1:
                ax.imshow(im[0], cmap=plt.cm.binary)
            else:
                # imshow expects channels last.
                ax.imshow(im.transpose((1, 2, 0)), interpolation="nearest")

            ax.axis('off')
            ax.set_xticks([])
            ax.set_yticks([])

        # Remove all margins and gaps once the grid is complete.
        plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0,
                            hspace=0)
        plt.savefig(filename, dpi=dpi*2, facecolor='black')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
59 |
def print_sample(name, noise_samples, opt_generator):
    """Generate images from ``noise_samples`` and save them to ``name``.

    The generator is taken from ``opt_generator.target``; the resulting
    image grid is written with ``save_ims`` and the path is reported on
    stdout.
    """
    generator = opt_generator.target
    images = generator(noise_samples).data
    save_ims(name, images)
    print(" Saved image to {}".format(name))
64 |
def training_step(args, train_iter, noise_iter, opt_generator, opt_discriminator):
    """Perform one discriminator update followed by one generator update.

    Both updates use least-squares losses normalised by ``args.batchsize``:
    the discriminator is pushed towards 1 on real samples and 0 on generated
    ones, the generator towards making the discriminator output 1.  When
    ``train_iter`` has just finished an epoch, the latest losses are printed
    and a sample image is written to ``args.output``.
    """
    generator = opt_generator.target
    discriminator = opt_discriminator.target
    device = args.device_id

    # --- discriminator update -------------------------------------------
    # Batches are drawn in the order noise -> data; keep it that way.
    z = get_batch(noise_iter, device)
    fake = generator(z)
    real = get_batch(train_iter, device)

    d_real = discriminator(real)
    d_fake = discriminator(fake)

    Dloss = 0.5 * (F.sum((d_real - 1.0)**2) + F.sum(d_fake**2)) / args.batchsize
    update_model(opt_discriminator, Dloss)

    # --- generator update -------------------------------------------------
    # A fresh noise batch is scored by the just-updated discriminator.
    z = get_batch(noise_iter, device)
    fake = generator(z)
    Gloss = 0.5 * F.sum((discriminator(fake) - 1.0)**2) / args.batchsize
    update_model(opt_generator, Gloss)

    if not train_iter.is_new_epoch:
        return

    epoch = train_iter.epoch
    print("[{}] Discriminator loss: {} Generator loss: {}".format(epoch, Dloss.data, Gloss.data))
    sample_path = os.path.join(args.output, "epoch_{}.png".format(epoch))
    print_sample(sample_path, z, opt_generator)
91 |
def parse_args():
    """Parse the command-line options of the training script.

    Returns the ``argparse`` namespace with ``device_id``, ``num_epochs``,
    ``batchsize``, ``num_z``, ``learning_rate``, ``output`` and ``mnist``.
    """
    parser = argparse.ArgumentParser()

    # Integer options: (long flag, short flag, default value).
    int_options = (
        ('--device_id', '-g', -1),
        ('--num_epochs', '-n', 100),
        ('--batchsize', '-b', 64),
        ('--num_z', '-z', 1024),
    )
    for long_flag, short_flag, default in int_options:
        parser.add_argument(long_flag, short_flag, type=int, default=default)

    parser.add_argument('--learning_rate', '-lr', type=float, default=0.001)
    parser.add_argument('--output', '-o', type=str, default="output")
    parser.add_argument('--mnist', '-m', action="store_true")

    return parser.parse_args()
102 |
103 |
def main(args):
    """Train an LSGAN on CIFAR10 (default) or MNIST.

    Args:
        args: Namespace as returned by ``parse_args``.  ``device_id`` selects
            the GPU (negative means CPU), ``mnist`` picks the dataset and
            matching generator, ``output`` is the folder receiving the
            per-epoch sample images.
    """
    # if we enabled GPU mode, set the GPU to use
    if args.device_id >= 0:
        chainer.cuda.get_device(args.device_id).use()

    # Load dataset (only the training split is used)
    if args.mnist:
        train, _ = chainer.datasets.get_mnist(withlabel=False, scale=2, ndim=3)
        generator = GeneratorMNIST()
    else:
        train, _ = chainer.datasets.get_cifar10(withlabel=False, scale=2, ndim=3)
        generator = GeneratorCIFAR()

    # subtracting 1, after scaling to 2 (done above) will make all pixels in the range [-1,1]
    train -= 1.0

    # make data iterators
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    # build optimizers and models
    opt_generator = chainer.optimizers.RMSprop(lr=args.learning_rate)
    opt_discriminator = chainer.optimizers.RMSprop(lr=args.learning_rate)

    opt_generator.setup(generator)
    opt_discriminator.setup(Discriminator())

    # make a random noise iterator (uniform noise between -1 and 1)
    noise_iter = RandomNoiseIterator(UniformNoiseGenerator(-1, 1, args.num_z), args.batchsize)

    # send to GPU
    if args.device_id >= 0:
        opt_generator.target.to_gpu()
        opt_discriminator.target.to_gpu()

    # make the output folder; exist_ok makes a separate existence check
    # unnecessary and avoids a check-then-create race
    os.makedirs(args.output, exist_ok=True)

    print("Starting training loop...")

    while train_iter.epoch < args.num_epochs:
        training_step(args, train_iter, noise_iter, opt_generator, opt_discriminator)

    print("Finished training.")
151 |
if __name__ == '__main__':
    # Script entry point: parse the command line and start training.
    main(parse_args())
155 |
--------------------------------------------------------------------------------