├── figs ├── gan │ ├── epoch_0.png │ ├── epoch_25.png │ ├── epoch_50.png │ ├── epoch_75.png │ └── epoch_99.png ├── lsgan │ ├── epoch_0.png │ ├── epoch_25.png │ ├── epoch_50.png │ ├── epoch_75.png │ └── epoch_99.png └── wgan │ ├── epoch_0.png │ ├── epoch_25.png │ ├── epoch_50.png │ ├── epoch_75.png │ └── epoch_99.png ├── main.py ├── readme.md └── utils.py /figs/gan/epoch_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/gan/epoch_0.png -------------------------------------------------------------------------------- /figs/gan/epoch_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/gan/epoch_25.png -------------------------------------------------------------------------------- /figs/gan/epoch_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/gan/epoch_50.png -------------------------------------------------------------------------------- /figs/gan/epoch_75.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/gan/epoch_75.png -------------------------------------------------------------------------------- /figs/gan/epoch_99.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/gan/epoch_99.png -------------------------------------------------------------------------------- /figs/lsgan/epoch_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/lsgan/epoch_0.png -------------------------------------------------------------------------------- /figs/lsgan/epoch_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/lsgan/epoch_25.png -------------------------------------------------------------------------------- /figs/lsgan/epoch_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/lsgan/epoch_50.png -------------------------------------------------------------------------------- /figs/lsgan/epoch_75.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/lsgan/epoch_75.png -------------------------------------------------------------------------------- /figs/lsgan/epoch_99.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/lsgan/epoch_99.png -------------------------------------------------------------------------------- /figs/wgan/epoch_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/wgan/epoch_0.png -------------------------------------------------------------------------------- /figs/wgan/epoch_25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/wgan/epoch_25.png -------------------------------------------------------------------------------- /figs/wgan/epoch_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/wgan/epoch_50.png -------------------------------------------------------------------------------- /figs/wgan/epoch_75.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/wgan/epoch_75.png -------------------------------------------------------------------------------- /figs/wgan/epoch_99.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/generate_normal/8341bdf14806fbe9a8d702cb8226be421426ccf2/figs/wgan/epoch_99.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | 3 | import matplotlib, os 4 | matplotlib.use("Agg") 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 | from matplotlib import pyplot as plt 10 | from utils import * 11 | 12 | 13 | class GAN(BasicTrainFramework): 14 | def __init__(self, 15 | gan_type='gan', 16 | optim_type='adam', 17 | data=datamanager_gaussian(0,1), 18 | batch_size=64, 19 | noise_dim=50, 20 | learning_rate=2e-4, 21 | optim_num=0.5, 22 | clip_num=0.03, 23 | critic_iter=5 24 | ): 25 | 26 | self.noise_dim = noise_dim 27 | self.clip_num = None if clip_num==0 else clip_num 28 | self.lr = learning_rate 29 | self.optim_num = optim_num 30 | self.critic_iter = critic_iter 31 | super(GAN, self).__init__(batch_size, gan_type) 32 | 33 | self.gan_type = gan_type 34 | self.optim_type = optim_type 35 | 36 | self.data = data 37 | 38 | np.random.seed(233) 39 | # self.sample_data = np.random.uniform(-1.0, 1.0, (self.batch_size, self.noise_dim)) 40 | self.sample_data = np.random.normal(size=(self.batch_size, self.noise_dim)) 41 | 42 | self.generator = Generator_MLP(name='generator') 43 | self.discriminator = Discriminator_MLP(name='discriminator') 44 | 45 | self.build_placeholder() 46 | self.build_gan() 47 | self.build_optimizer(optim_type) 48 | self.build_summary() 49 | 50 | self.build_sess() 51 | self.build_dirs() 52 | 53 | def build_placeholder(self): 54 | self.noise = tf.placeholder(shape=(self.batch_size, self.noise_dim), dtype=tf.float32) 55 | self.source = tf.placeholder(shape=(self.batch_size, self.noise_dim), dtype=tf.float32) 56 | 57 | def build_gan(self): 58 | self.G = self.generator(self.noise, is_training=True, reuse=False) 59 | self.G_test = self.generator(self.noise, is_training=False, reuse=True) 60 | self.logit_real, self.net_real = self.discriminator(self.source, is_training=True, reuse=False) 61 | self.logit_fake, self.net_fake = self.discriminator(self.G, is_training=True, reuse=True) 62 | 63 | self.mean_real, self.std_real = tf.nn.moments(self.source, axes=[0,1]) 64 | self.mean_fake, self.std_fake = tf.nn.moments(self.G, axes=[0,1]) 65 | self.std_real = tf.sqrt(self.std_real) 66 | self.std_fake = tf.sqrt(self.std_fake) 67 | 68 | def build_optimizer(self, optim_type='adam'): 69 | if self.gan_type == 'gan': 70 | self.D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logit_real, labels=tf.ones_like(self.logit_real))) 71 | self.D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logit_fake, labels=tf.zeros_like(self.logit_fake))) 72 | self.D_loss = self.D_loss_real + self.D_loss_fake 73 | self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logit_fake, labels=tf.ones_like(self.logit_fake))) 74 | elif self.gan_type == 'wgan': 75 | self.D_loss_real = - tf.reduce_mean(self.logit_real) 76 | self.D_loss_fake = tf.reduce_mean(self.logit_fake) 77 | self.D_loss = self.D_loss_real + self.D_loss_fake 78 | self.G_loss = - self.D_loss_fake 79 | if self.clip_num: 80 | print "GC" 81 | self.D_clip = [v.assign(tf.clip_by_value(v, -self.clip_num, self.clip_num)) for v in self.discriminator.weights] 82 | elif self.gan_type == 'lsgan': 83 | def mse_loss(pred, data): 84 | return tf.sqrt(2 * tf.nn.l2_loss(pred - data)) / self.batch_size 85 | self.D_loss_real = tf.reduce_mean(mse_loss(self.logit_real, tf.ones_like(self.logit_real))) 86 | self.D_loss_fake = tf.reduce_mean(mse_loss(self.logit_fake, tf.zeros_like(self.logit_fake))) 87 | self.D_loss = 0.5 * (self.D_loss_real + self.D_loss_fake) 88 | self.G_loss = tf.reduce_mean(mse_loss(self.logit_fake, tf.ones_like(self.logit_fake))) 89 | if self.clip_num: 90 | print "GC" 91 | self.D_clip = [v.assign(tf.clip_by_value(v, -self.clip_num, self.clip_num)) for v in self.discriminator.weights] 92 | 93 | if optim_type == 'adam': 94 | self.D_solver = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=self.optim_num).minimize(self.D_loss, var_list=self.discriminator.vars) 95 | self.G_solver = tf.train.AdamOptimizer(learning_rate=self.lr, beta1=self.optim_num).minimize(self.G_loss, var_list=self.generator.vars) 96 | elif optim_type == 'rmsprop': 97 | self.D_solver = tf.train.RMSPropOptimizer(learning_rate=self.lr).minimize(self.D_loss, var_list=self.discriminator.vars) 98 | self.G_solver = tf.train.RMSPropOptimizer(learning_rate=self.lr).minimize(self.G_loss, var_list=self.generator.vars) 99 | 100 | def build_summary(self): 101 | D_sum = tf.summary.scalar("D_loss", self.D_loss) 102 | D_sum_real = tf.summary.scalar("D_loss_real", self.D_loss_real) 103 | D_sum_fake = tf.summary.scalar("D_loss_fake", self.D_loss_fake) 104 | G_sum = tf.summary.scalar("G_loss", self.G_loss) 105 | mean_sum_real = tf.summary.scalar("mean_real", self.mean_real) 106 | mean_sum_fake = tf.summary.scalar("mean_fake", self.mean_fake) 107 | std_sum_real = tf.summary.scalar("std_real", self.std_real) 108 | std_sum_fake = tf.summary.scalar("std_fake", self.std_fake) 109 | self.summary = tf.summary.merge([D_sum, D_sum_real, D_sum_fake, G_sum, 110 | mean_sum_real, mean_sum_fake, std_sum_real, std_sum_fake]) 111 | 112 | def test(self): 113 | out = self.sess.run(self.G_test, feed_dict={self.noise:self.sample_data}) 114 | print "mean=%.2f std=%.2f" % (np.mean(out), np.std(out)) 115 | 116 | def sample(self, epoch): 117 | real = self.data([500000, self.noise_dim]) 118 | pr, _ = np.histogram(real, bins=np.linspace(-6, 10, 200), density=True) 119 | plt.plot(np.linspace(-6, 10, len(pr)), pr, label='real', color='g', linewidth=2) 120 | fake = [] 121 | for i in range(500): 122 | out = self.sess.run(self.G_test, feed_dict={self.noise: np.random.normal(size=(self.batch_size, self.noise_dim))}) 123 | fake.append(out) 124 | pf, _ = np.histogram(np.concatenate(fake), bins=np.linspace(-6, 10, 200), density=True) 125 | plt.plot(np.linspace(-6, 10, len(pf)), pf, label='fake', color='r', linewidth=1.5) 126 | plt.title("epoch_{}".format(epoch)) 127 | plt.legend() 128 | plt.savefig(os.path.join(self.fig_dir, "epoch_{}.png".format(epoch))) 129 | plt.clf() 130 | 131 | def train(self, epoches=1): 132 | self.writer = tf.summary.FileWriter(self.log_dir, self.sess.graph) 133 | batches_per_epoch = 100 134 | 135 | for epoch in range(epoches): 136 | 137 | for idx in range(batches_per_epoch): 138 | cnt = epoch * batches_per_epoch + idx 139 | 140 | data = self.data([self.batch_size, self.noise_dim]) 141 | 142 | feed_dict = { 143 | self.source: data, 144 | self.noise: np.random.normal(size=(self.batch_size, self.noise_dim)) 145 | } 146 | 147 | # train D 148 | self.sess.run(self.D_solver, feed_dict=feed_dict) 149 | if self.clip_num: 150 | self.sess.run(self.D_clip) 151 | 152 | # train G 153 | if (cnt-1) % self.critic_iter == 0: 154 | self.sess.run(self.G_solver, feed_dict=feed_dict) 155 | 156 | if cnt % 20 == 0: 157 | d_loss, d_loss_r, d_loss_f, g_loss, sum_str = self.sess.run([self.D_loss, self.D_loss_real, self.D_loss_fake, self.G_loss, self.summary], feed_dict=feed_dict) 158 | print self.version + " epoch [%3d/%3d] iter [%3d/%3d] D=%.3f Dr=%.3f Df=%.3f G=%.3f" % \ 159 | (epoch, epoches, idx, batches_per_epoch, d_loss, d_loss_r, d_loss_f, g_loss) 160 | self.writer.add_summary(sum_str, cnt) 161 | self.test() 162 | if epoch % 25 == 0: 163 | self.sample(epoch) 164 | self.sample(epoch) 165 | self.saver.save(self.sess, os.path.join(self.model_dir, 'model.ckpt'), global_step=cnt) 166 | 167 | if __name__ == "__main__": 168 | 169 | data = datamanager_gaussian(mean=3.5, std=0.7) 170 | 171 | gan = GAN(gan_type='gan', data=data, batch_size=64, noise_dim=10, clip_num=0, critic_iter=5) 172 | gan.train(100) 173 | 174 | gan = GAN(gan_type='wgan', data=data, batch_size=64, noise_dim=10, clip_num=0.1, critic_iter=5) 175 | gan.train(100) 176 | 177 | gan = GAN(gan_type='lsgan', data=data, batch_size=64, noise_dim=10, clip_num=0.1, critic_iter=5) 178 | gan.train(100) -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # GAN to generate Normal Distribution 2 | 3 | Use GAN to generate Gaussian distribution N(3.5, 0.7). 4 | 5 | # Usage 6 | 7 | ```python 8 | python main.py 9 | ``` 10 | 11 | # Random Generation 12 | 13 | *Name* | *Epoch 0* | *Epoch 25* | *Epoch 99* 14 | :---: | :---: | :---: | :---: 15 | GAN | | | 16 | WGAN | | | 17 | LSGAN | | | 18 | 19 | # GAN Papers 20 | 21 | Name | Paper Link 22 | :---: | :---: | 23 | GAN | [Arxiv](https://arxiv.org/abs/1406.2661) 24 | WGAN | [Arxiv](https://arxiv.org/abs/1701.07875) 25 | LSGAN| [Arxiv](https://arxiv.org/abs/1611.04076) 26 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import os 6 | 7 | class BasicBlock(object): 8 | def __init__(self, hidden_units, name): 9 | self.name = name 10 | self.hidden_units = hidden_units 11 | @property 12 | def vars(self): 13 | return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.name) 14 | 15 | class BasicTrainFramework(object): 16 | def __init__(self, batch_size, version): 17 | self.batch_size = batch_size 18 | self.version = version 19 | 20 | def build_dirs(self): 21 | self.log_dir = os.path.join('logs', self.version) 22 | self.model_dir = os.path.join('checkpoints', self.version) 23 | self.fig_dir = os.path.join('figs', self.version) 24 | for d in [self.log_dir, self.model_dir, self.fig_dir]: 25 | if (d is not None) and (not os.path.exists(d)): 26 | print "mkdir " + d 27 | os.makedirs(d) 28 | 29 | def build_sess(self): 30 | gpu_options = tf.GPUOptions(allow_growth=True) 31 | self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 32 | self.sess.run(tf.global_variables_initializer()) 33 | self.saver = tf.train.Saver() 34 | 35 | def build_network(self): 36 | self.D_logit_real = None 37 | self.D_logit_fake = None 38 | 39 | def load_model(self, checkpoint_dir=None, ckpt_name=None): 40 | import re 41 | print "load checkpoints ..." 42 | checkpoint_dir = checkpoint_dir or self.model_dir 43 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 44 | if ckpt and ckpt.model_checkpoint_path: 45 | ckpt_name = ckpt_name or os.path.basename(ckpt.model_checkpoint_path) 46 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 47 | counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0)) 48 | print "Success to read {}".format(ckpt_name) 49 | return True, counter 50 | else: 51 | print "Failed to find a checkpoint" 52 | return False, 0 53 | 54 | def bn(x, is_training, name): 55 | return tf.contrib.layers.batch_norm(x, 56 | decay=0.999, 57 | updates_collections=None, 58 | epsilon=0.001, 59 | scale=True, 60 | fused=False, 61 | is_training=is_training, 62 | scope=name) 63 | 64 | def spectral_norm(w, iteration=10, name="sn"): 65 | ''' 66 | Ref: https://github.com/taki0112/Spectral_Normalization-Tensorflow/blob/65218e8cc6916d24b49504c337981548685e1be1/spectral_norm.py 67 | ''' 68 | w_shape = w.shape.as_list() # [KH, KW, Cin, Cout] or [H, W] 69 | w = tf.reshape(w, [-1, w_shape[-1]]) # [KH*KW*Cin, Cout] or [H, W] 70 | 71 | u = tf.get_variable(name+"_u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False) 72 | s = tf.get_variable(name+"_sigma", [1, ], initializer=tf.random_normal_initializer(), trainable=False) 73 | 74 | u_hat = u # [1, Cout] or [1, W] 75 | v_hat = None 76 | 77 | for _ in range(iteration): 78 | v_hat = tf.nn.l2_normalize(tf.matmul(u_hat, tf.transpose(w))) # [1, KH*KW*Cin] or [1, H] 79 | u_hat = tf.nn.l2_normalize(tf.matmul(v_hat, w)) # [1, Cout] or [1, W] 80 | 81 | u_hat = tf.stop_gradient(u_hat) 82 | v_hat = tf.stop_gradient(v_hat) 83 | 84 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) # [1,1] 85 | sigma = tf.reshape(sigma, (1,)) 86 | 87 | with tf.control_dependencies([u.assign(u_hat), s.assign(sigma)]): 88 | # ops here run after u.assign(u_hat) 89 | w_norm = w / sigma 90 | w_norm = tf.reshape(w_norm, w_shape) 91 | 92 | return w_norm 93 | 94 | def dense(x, output_size, stddev=0.02, bias_start=0.0, activation=None, sn=False, reuse=False, name='dense'): 95 | shape = x.get_shape().as_list() 96 | with tf.variable_scope(name, reuse=reuse): 97 | W = tf.get_variable( 98 | 'weights', [shape[1], output_size], 99 | tf.float32, 100 | tf.random_normal_initializer(stddev=stddev)) 101 | bias = tf.get_variable( 102 | 'biases', [output_size], 103 | initializer=tf.constant_initializer(bias_start)) 104 | if sn: 105 | W = spectral_norm(W, name="sn") 106 | out = tf.matmul(x, W) + bias 107 | if activation is not None: 108 | out = activation(out) 109 | 110 | return out 111 | 112 | def conv_cond_concat(x, y): 113 | # x: [N, H, W, C] 114 | # y: [N, 1, 1, d] 115 | x_shapes = x.get_shape() 116 | y_shapes = y.get_shape() 117 | return tf.concat([x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3) 118 | 119 | class Generator_MLP(BasicBlock): 120 | def __init__(self, name=None): 121 | super(Generator_MLP, self).__init__(None, name or "Generator_MLP") 122 | 123 | def __call__(self, z, y=None, is_training=True, reuse=False): 124 | with tf.variable_scope(self.name, reuse=reuse): 125 | if y is not None: 126 | z = tf.concat([z,y], 1) 127 | net = tf.nn.softplus(dense(z, 64, name='g_fc1')) 128 | out = dense(net, 10, name='g_fc2') 129 | return out 130 | 131 | class Discriminator_MLP(BasicBlock): 132 | def __init__(self, class_num=None, name=None): 133 | super(Discriminator_MLP, self).__init__(None, name or "Discriminator_MLP") 134 | self.class_num = class_num 135 | 136 | def __call__(self, x, y=None, sn=False, is_training=True, reuse=False): 137 | with tf.variable_scope(self.name, reuse=reuse): 138 | batch_size = x.get_shape().as_list()[0] 139 | if y is not None: 140 | ydim = y.get_shape().as_list()[-1] 141 | y = tf.reshape(y, [batch_size, 1, 1, ydim]) 142 | x = conv_cond_concat(x, y) # [bz, 28, 28, 11] 143 | 144 | net = tf.nn.tanh(dense(x, 64, sn=sn, name='d_fc1')) 145 | net = tf.nn.tanh(bn(dense(net, 64, sn=sn, name='d_fc2'), is_training, name='d_bn2')) 146 | 147 | yd = dense(net, 1, sn=sn, name="D_dense") 148 | 149 | if self.class_num is not None: 150 | print self.class_num 151 | yc = dense(net, self.class_num, sn=sn, name='C_dense') 152 | return yd, net, yc 153 | else: 154 | return yd, net 155 | 156 | @property 157 | def weights(self): 158 | res = [] 159 | for v in self.vars: 160 | if "weights" in v.name: 161 | res.append(v) 162 | return res 163 | 164 | class datamanager_gaussian(object): 165 | def __init__(self, mean=0.0, std=1.0): 166 | self.mean = mean 167 | self.std = std 168 | 169 | def __call__(self, size): 170 | return np.random.normal(self.mean, self.std, size=size) --------------------------------------------------------------------------------