├── figs
│   ├── gan
│   │   ├── epoch_0.png
│   │   ├── epoch_25.png
│   │   ├── epoch_50.png
│   │   ├── epoch_75.png
│   │   └── epoch_99.png
│   ├── lsgan
│   │   ├── epoch_0.png
│   │   ├── epoch_25.png
│   │   ├── epoch_50.png
│   │   ├── epoch_75.png
│   │   └── epoch_99.png
│   └── wgan
│       ├── epoch_0.png
│       ├── epoch_25.png
│       ├── epoch_50.png
│       ├── epoch_75.png
│       └── epoch_99.png
├── main.py
├── readme.md
└── utils.py
/figs/gan/epoch_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/gan/epoch_0.png
--------------------------------------------------------------------------------
/figs/gan/epoch_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/gan/epoch_25.png
--------------------------------------------------------------------------------
/figs/gan/epoch_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/gan/epoch_50.png
--------------------------------------------------------------------------------
/figs/gan/epoch_75.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/gan/epoch_75.png
--------------------------------------------------------------------------------
/figs/gan/epoch_99.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/gan/epoch_99.png
--------------------------------------------------------------------------------
/figs/lsgan/epoch_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/lsgan/epoch_0.png
--------------------------------------------------------------------------------
/figs/lsgan/epoch_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/lsgan/epoch_25.png
--------------------------------------------------------------------------------
/figs/lsgan/epoch_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/lsgan/epoch_50.png
--------------------------------------------------------------------------------
/figs/lsgan/epoch_75.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/lsgan/epoch_75.png
--------------------------------------------------------------------------------
/figs/lsgan/epoch_99.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/lsgan/epoch_99.png
--------------------------------------------------------------------------------
/figs/wgan/epoch_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/wgan/epoch_0.png
--------------------------------------------------------------------------------
/figs/wgan/epoch_25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/wgan/epoch_25.png
--------------------------------------------------------------------------------
/figs/wgan/epoch_50.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/wgan/epoch_50.png
--------------------------------------------------------------------------------
/figs/wgan/epoch_75.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/wgan/epoch_75.png
--------------------------------------------------------------------------------
/figs/wgan/epoch_99.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/wgan/epoch_99.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 |
3 | import matplotlib, os
4 | matplotlib.use("Agg")
5 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
6 |
7 | import tensorflow as tf
8 | import numpy as np
9 | from matplotlib import pyplot as plt
10 | from utils import *
11 |
12 |
class GAN(BasicTrainFramework):
    """Train an MLP GAN whose generator learns to imitate samples from ``data``.

    Three objectives are selected with ``gan_type``:

    * ``'gan'``   -- sigmoid cross-entropy (original GAN).
    * ``'wgan'``  -- Wasserstein critic loss, optional weight clipping.
    * ``'lsgan'`` -- least-squares loss, optional weight clipping.

    Args:
        gan_type: 'gan', 'wgan' or 'lsgan'. Also used as the run name for the
            logs/, checkpoints/ and figs/ subdirectories.
        optim_type: 'adam' or 'rmsprop'.
        data: callable mapping a shape list to a numpy array of real samples
            (e.g. ``datamanager_gaussian``). ``None`` means N(0, 1).
        batch_size: samples per training step; the placeholders are built
            with this fixed batch dimension.
        noise_dim: width of each noise vector and of each real sample. It must
            equal the generator's output width (10 for ``Generator_MLP`` as
            constructed here), otherwise construction raises ValueError.
        learning_rate: optimizer learning rate.
        optim_num: Adam ``beta1``; unused by RMSProp.
        clip_num: bound for discriminator weight clipping, applied for
            'wgan' and 'lsgan' only; 0 disables clipping.
        critic_iter: the generator is updated once per ``critic_iter``
            discriminator updates; must be >= 1.

    Raises:
        ValueError: unknown ``gan_type``/``optim_type``, ``critic_iter`` < 1,
            or a generator output width different from ``noise_dim``.
    """

    def __init__(self,
                 gan_type='gan',
                 optim_type='adam',
                 data=None,
                 batch_size=64,
                 noise_dim=50,
                 learning_rate=2e-4,
                 optim_num=0.5,
                 clip_num=0.03,
                 critic_iter=5):
        if gan_type not in ('gan', 'wgan', 'lsgan'):
            raise ValueError("unknown gan_type: %r" % (gan_type,))
        if optim_type not in ('adam', 'rmsprop'):
            raise ValueError("unknown optim_type: %r" % (optim_type,))
        if critic_iter < 1:
            # train() uses `% critic_iter`; 0 would divide by zero.
            raise ValueError("critic_iter must be >= 1, got %r" % (critic_iter,))

        self.noise_dim = noise_dim
        self.clip_num = None if clip_num == 0 else clip_num
        self.lr = learning_rate
        self.optim_num = optim_num
        self.critic_iter = critic_iter
        super(GAN, self).__init__(batch_size, gan_type)

        self.gan_type = gan_type
        self.optim_type = optim_type

        # A fresh N(0, 1) sampler per instance rather than one default object
        # created once when the class body was executed.
        self.data = datamanager_gaussian(0, 1) if data is None else data

        # Fixed noise batch used by test(); seeding numpy's global RNG here
        # also fixes the later noise/data draws made with np.random.
        np.random.seed(233)
        self.sample_data = np.random.normal(size=(self.batch_size, self.noise_dim))

        self.generator = Generator_MLP(name='generator')
        self.discriminator = Discriminator_MLP(name='discriminator')

        self.build_placeholder()
        self.build_gan()
        self.build_optimizer(optim_type)
        self.build_summary()

        self.build_sess()
        self.build_dirs()

    def build_placeholder(self):
        """Create fixed-size float32 placeholders for noise and real samples."""
        self.noise = tf.placeholder(shape=(self.batch_size, self.noise_dim), dtype=tf.float32)
        self.source = tf.placeholder(shape=(self.batch_size, self.noise_dim), dtype=tf.float32)

    def build_gan(self):
        """Wire generator and discriminator and add batch mean/std tensors.

        Raises:
            ValueError: if the generator's output width differs from
                ``noise_dim`` (the width of the real samples).
        """
        self.G = self.generator(self.noise, is_training=True, reuse=False)
        self.G_test = self.generator(self.noise, is_training=False, reuse=True)

        # Real and generated batches share the discriminator's weights, so
        # their feature widths must agree; fail early with a clear message.
        gen_dim = self.G.get_shape().as_list()[-1]
        if gen_dim != self.noise_dim:
            raise ValueError(
                "generator output width (%s) must equal noise_dim (%d), the "
                "width of the real samples" % (gen_dim, self.noise_dim))

        self.logit_real, self.net_real = self.discriminator(self.source, is_training=True, reuse=False)
        self.logit_fake, self.net_fake = self.discriminator(self.G, is_training=True, reuse=True)

        # tf.nn.moments returns (mean, variance) over all entries of the batch.
        self.mean_real, var_real = tf.nn.moments(self.source, axes=[0, 1])
        self.mean_fake, var_fake = tf.nn.moments(self.G, axes=[0, 1])
        self.std_real = tf.sqrt(var_real)
        self.std_fake = tf.sqrt(var_fake)

    def build_optimizer(self, optim_type='adam'):
        """Define the losses for ``self.gan_type``, optional clip ops and solvers.

        Sets D_loss_real, D_loss_fake, D_loss, G_loss, D_clip (a list of
        assign ops, or None when no clipping applies), D_solver and G_solver.

        Raises:
            ValueError: unknown ``gan_type`` or ``optim_type``.
        """
        self.D_clip = None

        if self.gan_type == 'gan':
            self.D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.logit_real, labels=tf.ones_like(self.logit_real)))
            self.D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.logit_fake, labels=tf.zeros_like(self.logit_fake)))
            self.D_loss = self.D_loss_real + self.D_loss_fake
            self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.logit_fake, labels=tf.ones_like(self.logit_fake)))
        elif self.gan_type == 'wgan':
            self.D_loss_real = -tf.reduce_mean(self.logit_real)
            self.D_loss_fake = tf.reduce_mean(self.logit_fake)
            self.D_loss = self.D_loss_real + self.D_loss_fake
            self.G_loss = -self.D_loss_fake
        elif self.gan_type == 'lsgan':
            def mse_loss(pred, target):
                # Element-wise mean squared error: the least-squares loss of
                # LSGAN. (sqrt(2 * l2_loss(e)) would be the L2 norm of e,
                # which is not a squared error.)
                return tf.reduce_mean(tf.square(pred - target))
            self.D_loss_real = mse_loss(self.logit_real, tf.ones_like(self.logit_real))
            self.D_loss_fake = mse_loss(self.logit_fake, tf.zeros_like(self.logit_fake))
            self.D_loss = 0.5 * (self.D_loss_real + self.D_loss_fake)
            self.G_loss = mse_loss(self.logit_fake, tf.ones_like(self.logit_fake))
        else:
            raise ValueError("unknown gan_type: %r" % (self.gan_type,))

        # Weight clipping is only wired up for the WGAN and LSGAN objectives.
        # D_clip stays None otherwise, and train() only runs it when set.
        if self.clip_num and self.gan_type in ('wgan', 'lsgan'):
            print("GC")
            self.D_clip = [v.assign(tf.clip_by_value(v, -self.clip_num, self.clip_num))
                           for v in self.discriminator.weights]

        if optim_type == 'adam':
            self.D_solver = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=self.optim_num).minimize(
                self.D_loss, var_list=self.discriminator.vars)
            self.G_solver = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=self.optim_num).minimize(
                self.G_loss, var_list=self.generator.vars)
        elif optim_type == 'rmsprop':
            self.D_solver = tf.train.RMSPropOptimizer(learning_rate=self.lr).minimize(
                self.D_loss, var_list=self.discriminator.vars)
            self.G_solver = tf.train.RMSPropOptimizer(learning_rate=self.lr).minimize(
                self.G_loss, var_list=self.generator.vars)
        else:
            raise ValueError("unknown optim_type: %r" % (optim_type,))

    def build_summary(self):
        """Merge scalar summaries of the losses and batch statistics."""
        self.summary = tf.summary.merge([
            tf.summary.scalar("D_loss", self.D_loss),
            tf.summary.scalar("D_loss_real", self.D_loss_real),
            tf.summary.scalar("D_loss_fake", self.D_loss_fake),
            tf.summary.scalar("G_loss", self.G_loss),
            tf.summary.scalar("mean_real", self.mean_real),
            tf.summary.scalar("mean_fake", self.mean_fake),
            tf.summary.scalar("std_real", self.std_real),
            tf.summary.scalar("std_fake", self.std_fake),
        ])

    def test(self):
        """Print mean/std of the generator output on the fixed noise batch."""
        out = self.sess.run(self.G_test, feed_dict={self.noise: self.sample_data})
        print("mean=%.2f std=%.2f" % (np.mean(out), np.std(out)))

    def sample(self, epoch):
        """Save real-vs-generated histograms to ``<fig_dir>/epoch_<epoch>.png``."""
        # NOTE: the fixed [-6, 10] range suits the N(3.5, 0.7) target used in
        # __main__; other targets may fall partly outside it.
        edges = np.linspace(-6, 10, 200)
        # 200 edges -> 199 bins; plot each density at its bin centre.
        centers = 0.5 * (edges[:-1] + edges[1:])

        real = self.data([500000, self.noise_dim])
        pr, _ = np.histogram(real, bins=edges, density=True)
        plt.plot(centers, pr, label='real', color='g', linewidth=2)

        fake = [self.sess.run(self.G_test,
                              feed_dict={self.noise: np.random.normal(size=(self.batch_size, self.noise_dim))})
                for _ in range(500)]
        pf, _ = np.histogram(np.concatenate(fake), bins=edges, density=True)
        plt.plot(centers, pf, label='fake', color='r', linewidth=1.5)

        plt.title("epoch_{}".format(epoch))
        plt.legend()
        plt.savefig(os.path.join(self.fig_dir, "epoch_{}.png".format(epoch)))
        plt.clf()

    def train(self, epoches=1):
        """Run ``epoches`` epochs of 100 steps each, then sample and checkpoint.

        Every step updates the discriminator (then clips its weights when
        clip ops exist). The generator is updated on steps where
        ``(step - 1) % critic_iter == 0``. Losses are logged every 20 steps.
        ``test()`` runs after every epoch, and a histogram plot is saved for
        every 25th epoch plus the final one.
        """
        self.writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
        batches_per_epoch = 100
        cnt = 0

        try:
            for epoch in range(epoches):
                for idx in range(batches_per_epoch):
                    cnt = epoch * batches_per_epoch + idx

                    feed_dict = {
                        self.source: self.data([self.batch_size, self.noise_dim]),
                        self.noise: np.random.normal(size=(self.batch_size, self.noise_dim)),
                    }

                    # train D
                    self.sess.run(self.D_solver, feed_dict=feed_dict)
                    if self.D_clip is not None:
                        self.sess.run(self.D_clip)

                    # train G
                    if (cnt - 1) % self.critic_iter == 0:
                        self.sess.run(self.G_solver, feed_dict=feed_dict)

                    if cnt % 20 == 0:
                        d_loss, d_loss_r, d_loss_f, g_loss, sum_str = self.sess.run(
                            [self.D_loss, self.D_loss_real, self.D_loss_fake, self.G_loss, self.summary],
                            feed_dict=feed_dict)
                        print(self.version + " epoch [%3d/%3d] iter [%3d/%3d] D=%.3f Dr=%.3f Df=%.3f G=%.3f" %
                              (epoch, epoches, idx, batches_per_epoch, d_loss, d_loss_r, d_loss_f, g_loss))
                        self.writer.add_summary(sum_str, cnt)

                self.test()
                if epoch % 25 == 0:
                    self.sample(epoch)

            # Nothing was trained when epoches <= 0, so there is nothing to
            # plot or save (and no last epoch/step to name the outputs by).
            if epoches > 0:
                self.sample(epoches - 1)
                self.saver.save(self.sess, os.path.join(self.model_dir, 'model.ckpt'), global_step=cnt)
        finally:
            self.writer.close()
166 |
if __name__ == "__main__":

    # Target distribution: N(3.5, 0.7^2).
    data = datamanager_gaussian(mean=3.5, std=0.7)

    # (gan_type, clip_num) per experiment; clip_num=0 disables clipping.
    experiments = [('gan', 0), ('wgan', 0.1), ('lsgan', 0.1)]

    for gan_type, clip_num in experiments:
        # Every GAN creates its variables under the fixed variable scopes
        # 'generator' and 'discriminator' with reuse=False. Building a second
        # model in the same graph would try to create those variables again,
        # so each experiment starts from a fresh default graph.
        tf.reset_default_graph()
        gan = GAN(gan_type=gan_type, data=data, batch_size=64, noise_dim=10,
                  clip_num=clip_num, critic_iter=5)
        gan.train(100)
        gan.sess.close()
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # GAN to generate Normal Distribution
2 |
3 | Use GANs (vanilla GAN, WGAN and LSGAN) to generate samples from the Gaussian distribution N(3.5, 0.7²), i.e. mean 3.5 and standard deviation 0.7.
4 |
5 | # Usage
6 |
7 | ```bash
8 | python main.py
9 | ```
10 |
11 | # Random Generation
12 |
13 | *Name* | *Epoch 0* | *Epoch 25* | *Epoch 99*
14 | :---: | :---: | :---: | :---:
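15 | GAN | ![GAN epoch 0](figs/gan/epoch_0.png) | ![GAN epoch 25](figs/gan/epoch_25.png) | ![GAN epoch 99](figs/gan/epoch_99.png)
16 | WGAN | ![WGAN epoch 0](figs/wgan/epoch_0.png) | ![WGAN epoch 25](figs/wgan/epoch_25.png) | ![WGAN epoch 99](figs/wgan/epoch_99.png)
17 | LSGAN | ![LSGAN epoch 0](figs/lsgan/epoch_0.png) | ![LSGAN epoch 25](figs/lsgan/epoch_25.png) | ![LSGAN epoch 99](figs/lsgan/epoch_99.png)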
18 |
19 | # GAN Papers
20 |
21 | Name | Paper Link
22 | :---: | :---: |
23 | GAN | [Arxiv](https://arxiv.org/abs/1406.2661)
24 | WGAN | [Arxiv](https://arxiv.org/abs/1701.07875)
25 | LSGAN| [Arxiv](https://arxiv.org/abs/1611.04076)
26 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 |
3 | import tensorflow as tf
4 | import numpy as np
5 | import os
6 |
class BasicBlock(object):
    """Base class for network pieces that own a named TensorFlow variable scope.

    Subclasses (e.g. Generator_MLP, Discriminator_MLP) create their variables
    inside ``tf.variable_scope(self.name)``; ``vars`` collects the trainable
    ones back out of the graph.
    """

    def __init__(self, hidden_units, name):
        self.hidden_units = hidden_units
        self.name = name

    @property
    def vars(self):
        """Trainable variables whose names match this block's scope name.

        NOTE(review): tf.get_collection filters with a regex match on the
        start of each variable name, so a name that is a prefix of another
        scope (e.g. 'gen' vs 'generator') would also pick up that scope.
        """
        collection_key = tf.GraphKeys.TRAINABLE_VARIABLES
        return tf.get_collection(collection_key, scope=self.name)
14 |
class BasicTrainFramework(object):
    """Shared plumbing for trainers: output directories, session, checkpoints.

    Args:
        batch_size: training batch size, stored for subclasses.
        version: run name; used as the subdirectory under logs/, checkpoints/
            and figs/.
    """

    def __init__(self, batch_size, version):
        self.batch_size = batch_size
        self.version = version

    def build_dirs(self):
        """Set log/model/fig directories for this run and create missing ones."""
        self.log_dir = os.path.join('logs', self.version)
        self.model_dir = os.path.join('checkpoints', self.version)
        self.fig_dir = os.path.join('figs', self.version)
        for d in (self.log_dir, self.model_dir, self.fig_dir):
            if not os.path.exists(d):
                print("mkdir " + d)
                os.makedirs(d)

    def build_sess(self):
        """Open a session (GPU memory grows on demand), init variables, make a Saver."""
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver()

    def build_network(self):
        """Placeholder hook: reset the discriminator logit attributes to None."""
        self.D_logit_real = None
        self.D_logit_fake = None

    @staticmethod
    def _step_from_ckpt_name(ckpt_name):
        """Return the last run of digits in ``ckpt_name`` as an int, or 0 if none."""
        import re
        match = re.search(r"(\d+)(?!.*\d)", ckpt_name)
        return int(match.group(0)) if match else 0

    def load_model(self, checkpoint_dir=None, ckpt_name=None):
        """Restore variables from a checkpoint.

        Args:
            checkpoint_dir: directory holding the TensorFlow checkpoint index;
                defaults to ``self.model_dir`` (set by ``build_dirs``).
            ckpt_name: checkpoint prefix inside that directory; defaults to the
                one recorded as ``model_checkpoint_path`` in the index.

        Returns:
            ``(True, step)`` on success, where ``step`` is the last number in
            the checkpoint name (0 if the name has no digits), or
            ``(False, 0)`` when no checkpoint is found.
        """
        print("load checkpoints ...")
        checkpoint_dir = checkpoint_dir or self.model_dir
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print("Failed to find a checkpoint")
            return False, 0

        ckpt_name = ckpt_name or os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
        # A name without a trailing step (e.g. 'model.ckpt') used to raise
        # StopIteration here; it now maps to step 0.
        counter = self._step_from_ckpt_name(ckpt_name)
        print("Success to read {}".format(ckpt_name))
        return True, counter
53 |
def bn(x, is_training, name):
    """Batch-normalize ``x`` with ``tf.contrib.layers.batch_norm`` in scope ``name``.

    ``updates_collections=None`` makes the layer update its moving statistics
    in place, so callers need no extra UPDATE_OPS dependency. ``scale=True``
    adds a learned gamma.
    """
    bn_options = dict(
        decay=0.999,
        updates_collections=None,
        epsilon=0.001,
        scale=True,
        fused=False,
    )
    return tf.contrib.layers.batch_norm(x, is_training=is_training, scope=name, **bn_options)
63 |
def spectral_norm(w, iteration=10, name="sn"):
    """Return ``w`` divided by a power-iteration estimate of its largest singular value.

    Ref: https://github.com/taki0112/Spectral_Normalization-Tensorflow/blob/65218e8cc6916d24b49504c337981548685e1be1/spectral_norm.py

    ``w`` is flattened to 2-D ([product of leading dims, last dim]) for the
    estimate. ``iteration`` unrolled power-iteration rounds start from the
    non-trainable variable ``<name>_u``. Evaluating the returned tensor also
    writes the refined vector to ``<name>_u`` and the estimate to
    ``<name>_sigma``. The result has the original shape of ``w``.
    """
    orig_shape = w.shape.as_list()  # [KH, KW, Cin, Cout] or [H, W]
    out_dim = orig_shape[-1]
    w_mat = tf.reshape(w, [-1, out_dim])  # [KH*KW*Cin, Cout] or [H, W]

    u_var = tf.get_variable(name + "_u", [1, out_dim],
                            initializer=tf.random_normal_initializer(), trainable=False)
    sigma_var = tf.get_variable(name + "_sigma", [1, ],
                                initializer=tf.random_normal_initializer(), trainable=False)

    u_est = u_var  # [1, Cout] or [1, W]
    v_est = None
    for _ in range(iteration):
        v_est = tf.nn.l2_normalize(tf.matmul(u_est, tf.transpose(w_mat)))  # [1, KH*KW*Cin] or [1, H]
        u_est = tf.nn.l2_normalize(tf.matmul(v_est, w_mat))  # [1, Cout] or [1, W]

    # The singular vectors are treated as constants for backprop.
    u_est = tf.stop_gradient(u_est)
    v_est = tf.stop_gradient(v_est)

    sigma = tf.reshape(tf.matmul(tf.matmul(v_est, w_mat), tf.transpose(u_est)), (1,))  # [1,1] -> [1]

    state_updates = [u_var.assign(u_est), sigma_var.assign(sigma)]
    with tf.control_dependencies(state_updates):
        # Ops built here only run after both state variables are updated.
        w_normalized = tf.reshape(w_mat / sigma, orig_shape)

    return w_normalized
93 |
def dense(x, output_size, stddev=0.02, bias_start=0.0, activation=None, sn=False, reuse=False, name='dense'):
    """Fully connected layer ``activation(x @ W + b)`` built in variable scope ``name``.

    ``x`` must be 2-D; its second dimension is the input width. ``W``
    ('weights') is initialised from N(0, stddev^2) and ``b`` ('biases') to
    ``bias_start``. With ``sn=True`` the kernel is replaced by its
    spectrally normalised version. ``activation=None`` leaves the output linear.
    """
    in_width = x.get_shape().as_list()[1]
    with tf.variable_scope(name, reuse=reuse):
        kernel = tf.get_variable(
            'weights', [in_width, output_size],
            dtype=tf.float32,
            initializer=tf.random_normal_initializer(stddev=stddev))
        bias = tf.get_variable(
            'biases', [output_size],
            initializer=tf.constant_initializer(bias_start))
        if sn:
            kernel = spectral_norm(kernel, name="sn")
        pre_activation = tf.matmul(x, kernel) + bias
        return pre_activation if activation is None else activation(pre_activation)
111 |
def conv_cond_concat(x, y):
    """Tile a per-example label map over x's spatial grid and append it as channels.

    Args:
        x: 4-D tensor [N, H, W, C] with a static shape.
        y: 4-D tensor [N, 1, 1, d].

    Returns:
        Tensor [N, H, W, C + d].
    """
    x_dims = x.get_shape()
    label_depth = y.get_shape()[3]
    # Multiplying by ones broadcasts y across H and W.
    tiled_labels = y * tf.ones([x_dims[0], x_dims[1], x_dims[2], label_depth])
    return tf.concat([x, tiled_labels], axis=3)
118 |
class Generator_MLP(BasicBlock):
    """One-hidden-layer MLP generator: softplus(dense 64) -> dense(output_dim).

    Args:
        name: variable-scope name (default "Generator_MLP").
        output_dim: width of each generated sample. Defaults to 10, the width
            that was previously hard-coded.
    """

    def __init__(self, name=None, output_dim=10):
        super(Generator_MLP, self).__init__(None, name or "Generator_MLP")
        self.output_dim = output_dim

    def __call__(self, z, y=None, is_training=True, reuse=False):
        """Map noise ``z`` (optionally concatenated with condition ``y``) to samples.

        ``is_training`` is accepted for interface parity with the
        discriminator; this network has no training-dependent layers.
        """
        with tf.variable_scope(self.name, reuse=reuse):
            if y is not None:
                z = tf.concat([z, y], 1)
            net = tf.nn.softplus(dense(z, 64, name='g_fc1'))
            out = dense(net, self.output_dim, name='g_fc2')
            return out
130 |
class Discriminator_MLP(BasicBlock):
    """Two-hidden-layer tanh MLP discriminator/critic.

    Args:
        class_num: if given, an extra head of ``class_num`` logits is built.
        name: variable-scope name (default "Discriminator_MLP").
    """

    def __init__(self, class_num=None, name=None):
        super(Discriminator_MLP, self).__init__(None, name or "Discriminator_MLP")
        self.class_num = class_num

    def __call__(self, x, y=None, sn=False, is_training=True, reuse=False):
        """Score a 2-D batch ``x``.

        Args:
            x: 2-D tensor [batch, features].
            y: optional conditioning tensor; it is flattened to
                [batch, last dim] and appended to ``x`` as extra features.
            sn: apply spectral normalisation to every dense kernel.
            is_training: passed to the batch-norm layer.
            reuse: reuse the variables of this scope.

        Returns:
            ``(logit, features)``, or ``(logit, features, class_logits)`` when
            ``class_num`` is set.
        """
        with tf.variable_scope(self.name, reuse=reuse):
            if y is not None:
                # Condition the MLP by concatenating labels as features, as
                # Generator_MLP does. The image-style conv_cond_concat path
                # needs a 4-D x, which the dense layers below cannot consume.
                ydim = y.get_shape().as_list()[-1]
                x = tf.concat([x, tf.reshape(y, [-1, ydim])], 1)

            net = tf.nn.tanh(dense(x, 64, sn=sn, name='d_fc1'))
            net = tf.nn.tanh(bn(dense(net, 64, sn=sn, name='d_fc2'), is_training, name='d_bn2'))

            yd = dense(net, 1, sn=sn, name="D_dense")

            if self.class_num is None:
                return yd, net
            yc = dense(net, self.class_num, sn=sn, name='C_dense')
            return yd, net, yc

    @property
    def weights(self):
        """Trainable variables of this block whose name contains 'weights' (dense kernels)."""
        return [v for v in self.vars if "weights" in v.name]
163 |
class datamanager_gaussian(object):
    """Callable sampler for a normal distribution with fixed parameters.

    Args:
        mean: location of the distribution.
        std: standard deviation (numpy's ``scale``).
    """

    def __init__(self, mean=0.0, std=1.0):
        self.mean, self.std = mean, std

    def __call__(self, size):
        """Draw an array of shape ``size`` from N(mean, std**2) using numpy's global RNG."""
        loc, scale = self.mean, self.std
        return np.random.normal(loc, scale, size)
--------------------------------------------------------------------------------