├── README.md ├── utils.py ├── ops.py ├── LSGAN.py ├── GAN.py ├── VAE.py ├── WGAN.py ├── CGAN.py ├── CVAE.py ├── WGAN_GP.py ├── BEGAN.py ├── EBGAN.py ├── ACGAN.py └── infoGAN.py
/README.md: --------------------------------------------------------------------------------
1 | # Description 2 | > This repository implements a collection of GAN models based on the TensorFlow 2.0 Keras API in eager mode. The code has only been validated on Windows in CPU mode. 3 | 4 | 5 | # How to use 6 | > [Install the tensorflow2.0-alpha version.](https://tensorflow.google.cn/install/pip) 7 | > An Anaconda virtual environment is recommended. 8 | > To support GPU mode, please refer to [this](https://tensorflow.google.cn/guide/using_gpu) (I haven't validated GPU mode because I have no GPU to use). 9 | 10 | # TODO 11 | - [x] Implement ACGAN, EBGAN, BEGAN, etc. 12 | - [x] Support running on GPU. 13 | 14 |
-------------------------------------------------------------------------------- /utils.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf 2 | import numpy as np 3 | import scipy.misc 4 | import os 5 | 6 | def load_mnist_data(batch_size=64,datasets='mnist',model_name=None): 7 | assert datasets in ['mnist','fashion_mnist'], "you should provide a dataset name in 'mnist' or 'fashion_mnist'" 8 | if datasets=='mnist': 9 | (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data() 10 | elif datasets=='fashion_mnist': 11 | (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data() 12 | # elif datasets=='cifar10': 13 | # (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data() 14 | # elif datasets=='cifar100': 15 | # (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar100.load_data() 16 | train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') 17 | BUFFER_SIZE=train_images.shape[0] 18 | if model_name=='WGAN' or model_name == 'WGAN_GP': 19 | train_images = (train_images-127.5)/127.5 20 | else: 21 | train_images=(train_images)/255.0 22 | train_labels=tf.one_hot(train_labels,depth=10) 23 | train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(batch_size,drop_remainder=True) 24 | return train_dataset 25 |
26 | def inverse_transform(images): 27 | return (images+1.0)/2.0 28 | 29 | 30 | def check_folder(log_dir): 31 | if not os.path.exists(log_dir): 32 | os.makedirs(log_dir) 33 | return log_dir 34 | 35 | 36 | def save_images(images, size, image_path): 37 | return imsave(inverse_transform(images), size, image_path) 38 | 39 | 40 | def imsave(images, size, path): 41 | image = np.squeeze(merge(images, size)) 42 | return scipy.misc.imsave(path, image) 43 | 44 |
45 | def merge(images, size): 46 | h, w = images.shape[1], images.shape[2] 47 | if (images.shape[3] in (3,4)): 48 | c = images.shape[3] 49 | img = np.zeros((h * size[0], w * size[1], c)) 50 | for idx, image in enumerate(images): 51 | i = idx % size[1] 52 | j = idx // size[1] 53 | img[j * h:j * h + h, i * w:i * w + w, :] = image 54 | return img 55 | elif images.shape[3]==1: 56 | img = np.zeros((h * size[0], w * size[1])) 57 | for idx, image in enumerate(images): 58 | i = idx % size[1] 59 | j = idx // size[1] 60 | img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0] 61 | return img 62 | else: 63 | raise ValueError('in merge(images,size) images parameter ''must have dimensions: 
HxW or HxWx3 or HxWx4') 64 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | from tensorflow.keras import layers 4 | 5 | 6 | class Sigmoid(layers.Layer): 7 | 8 | def __init__(self): 9 | super(Sigmoid, self).__init__() 10 | 11 | def call(self, inputs): 12 | return keras.activations.sigmoid(inputs) 13 | 14 | 15 | class Tanh(layers.Layer): 16 | def __init__(self): 17 | super(Tanh, self).__init__() 18 | 19 | def call(self, inputs): 20 | return keras.activations.tanh(inputs) 21 | 22 | 23 | class Conv2D(layers.Layer): 24 | def __init__(self, filters, kernel_size, strides=2): 25 | super(Conv2D, self).__init__() 26 | self.conv_op = layers.Conv2D(filters=filters, 27 | kernel_size=kernel_size, 28 | strides=strides, 29 | padding='same', 30 | kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02), 31 | use_bias=True, 32 | bias_initializer=keras.initializers.Constant(value=0.0)) 33 | 34 | def call(self, inputs): 35 | return self.conv_op(inputs) 36 | 37 | 38 | class BatchNorm(layers.Layer): 39 | def __init__(self, is_training=False): 40 | super(BatchNorm, self).__init__() 41 | self.bn = tf.keras.layers.BatchNormalization(epsilon=1e-5, 42 | momentum=0.9, 43 | scale=True, 44 | trainable=is_training) 45 | 46 | def call(self, inputs, training): 47 | x = self.bn(inputs, training=training) 48 | return x 49 | 50 | 51 | class DenseLayer(layers.Layer): 52 | def __init__(self, hidden_n, is_input=False): 53 | super(DenseLayer, self).__init__() 54 | 55 | self.fc_op = layers.Dense(hidden_n, 56 | kernel_initializer=keras.initializers.RandomNormal(stddev=0.02), 57 | bias_initializer=keras.initializers.Constant(value=0.0)) 58 | 59 | def call(self, inputs): 60 | x = self.fc_op(inputs) 61 | 62 | return x 63 | 64 | 65 | class UpConv2D(layers.Layer): 66 | def __init__(self, filters, kernel_size, strides): 67 | super(UpConv2D, self).__init__() 68 | self.up_conv_op = layers.Conv2DTranspose(filters, 69 | kernel_size=kernel_size, 70 | strides=strides, 71 | padding='same', 72 | kernel_initializer=keras.initializers.RandomNormal(stddev=0.02), 73 | use_bias=True, 74 | bias_initializer=keras.initializers.Constant(value=0.0)) 75 | 76 | def call(self, inputs): 77 | x = self.up_conv_op(inputs) 78 | return x 79 | 80 | 81 | def conv_cond_concat(x, y): 82 | """Concatenate conditioning vector on feature map axis.""" 83 | x_shapes = tf.shape(x) 84 | y_shapes = tf.shape(y) 85 | y = tf.reshape(y, [-1, 1, 1, y_shapes[1]]) 86 | y_shapes = tf.shape(y) 87 | return tf.concat([x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3) 88 | -------------------------------------------------------------------------------- /LSGAN.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import argparse 3 | import datetime 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import tensorflow as tf 7 | import scipy.misc 8 | from tensorflow import keras 9 | from tensorflow.keras import layers,optimizers,metrics 10 | 11 | from ops import * 12 | from utils import * 13 | 14 | 15 | class LSGAN(): 16 | def __init__(self,args): 17 | super(LSGAN, self).__init__() 18 | self.model_name = args.gan_type 19 | self.batch_size = args.batch_size 20 | self.z_dim = args.z_dim 21 | self.checkpoint_dir = check_folder(os.path.join(args.checkpoint_dir,self.model_name)) 22 | 
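# checkpoints are written under <checkpoint_dir>/<gan_type> ("checkpoint/LSGAN" with the defaults);
# generated sample grids go under <result_dir>/<model_dir> (see the model_dir property below)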
self.result_dir = args.result_dir 23 | self.datasets_name = args.datasets 24 | self.log_dir=args.log_dir 25 | self.learnning_rate=args.lr 26 | self.epoches=args.epoch 27 | self.datasets = load_mnist_data(datasets=self.datasets_name,batch_size=args.batch_size) 28 | self.g = self.make_generator_model(self.z_dim,is_training=True) 29 | self.d = self.make_discriminator_model(is_training=True) 30 | self.g_optimizer = optimizers.Adam(lr=5*self.learnning_rate, beta_1=0.5) 31 | self.d_optimizer = optimizers.Adam(lr=self.learnning_rate, beta_1=0.5) 32 | self.g_loss_metric = metrics.Mean('g_loss', dtype=tf.float32) 33 | self.d_loss_metric = metrics.Mean('d_loss', dtype=tf.float32) 34 | self.checkpoint = tf.train.Checkpoint(step=tf.Variable(0), 35 | generator_optimizer=self.g_optimizer, 36 | discriminator_optimizer=self.d_optimizer, 37 | generator=self.g, 38 | discriminator=self.d) 39 | self.manager = tf.train.CheckpointManager(self.checkpoint, self.checkpoint_dir, max_to_keep=3) 40 | 41 | 42 | 43 | # the network is based on https://github.com/hwalsuklee/tensorflow-generative-model-collections 44 | def make_discriminator_model(self,is_training): 45 | model = tf.keras.Sequential() 46 | model.add(Conv2D(64,4,2)) 47 | model.add(layers.LeakyReLU(alpha=0.2)) 48 | model.add(Conv2D(128,4,2)) 49 | model.add(BatchNorm(is_training=is_training)) 50 | model.add(layers.LeakyReLU(alpha=0.2)) 51 | model.add(layers.Flatten()) 52 | model.add(DenseLayer(1024)) 53 | model.add(BatchNorm(is_training=is_training)) 54 | model.add(layers.LeakyReLU(alpha=0.3)) 55 | model.add(DenseLayer(1)) 56 | return model 57 | 58 | 59 | def make_generator_model(self,z_dim,is_training): 60 | model = tf.keras.Sequential() 61 | model.add(DenseLayer(1024,z_dim)) 62 | model.add(BatchNorm(is_training=is_training)) 63 | model.add(keras.layers.ReLU()) 64 | model.add(DenseLayer(128*7*7)) 65 | model.add(BatchNorm(is_training=is_training)) 66 | model.add(keras.layers.ReLU()) 67 | model.add(layers.Reshape((7,7,128))) 68 | model.add(UpConv2D(64,4,2)) 69 | model.add(BatchNorm(is_training=is_training)) 70 | model.add(keras.layers.ReLU()) 71 | model.add(UpConv2D(1,4,2)) 72 | model.add(Sigmoid()) 73 | return model 74 | 75 | 76 | @property 77 | def model_dir(self): 78 | return "{}_{}_{}_{}".format( 79 | self.model_name, self.datasets_name, 80 | self.batch_size, self.z_dim) 81 | 82 | 83 | 84 | def mse_loss(self, pred, data): 85 | loss_val = tf.sqrt(2 * tf.nn.l2_loss(pred - data)) / self.batch_size 86 | return loss_val 87 | 88 | # training for one batch 89 | @tf.function 90 | def train_one_step(self,batch_images): 91 | batch_z = np.random.uniform(-1, 1,[self.batch_size, self.z_dim]).astype(np.float32) 92 | real_images = batch_images 93 | with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape: 94 | fake_imgs = self.g(batch_z, training=True) 95 | d_fake_logits = self.d(fake_imgs, training=True) 96 | d_real_logits = self.d(real_images, training=True) 97 | d_real_loss = tf.reduce_mean(self.mse_loss(d_real_logits, tf.ones_like(d_real_logits))) 98 | d_fake_loss = tf.reduce_mean(self.mse_loss(d_fake_logits, tf.zeros_like(d_fake_logits))) 99 | d_loss=0.5*(d_real_loss+d_fake_loss) 100 | g_loss = tf.reduce_mean(self.mse_loss(d_fake_logits, tf.ones_like(d_fake_logits))) 101 | gradients_of_d = d_tape.gradient(d_loss, self.d.trainable_variables) 102 | gradients_of_g = g_tape.gradient(g_loss, self.g.trainable_variables) 103 | self.d_optimizer.apply_gradients(zip(gradients_of_d, self.d.trainable_variables)) 104 | self.g_optimizer.apply_gradients(zip(gradients_of_g, 
self.g.trainable_variables)) 105 | self.g_loss_metric(g_loss) 106 | self.d_loss_metric(d_loss) 107 | 108 | 109 | 110 | def train(self, load=False): 111 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 112 | train_log_dir = os.path.join(self.log_dir, self.model_name, current_time) 113 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir) 114 | 115 | # if want to load a checkpoints,set load flags to be true 116 | if load: 117 | self.could_load = self.load_ckpt() 118 | ckpt_step = int(self.checkpoint.step) 119 | start_epoch=int((ckpt_step*self.batch_size)//60000) 120 | else: 121 | start_epoch=0 122 | 123 | 124 | for epoch in range(start_epoch,self.epoches): 125 | for batch_images, _ in self.datasets: 126 | self.train_one_step(batch_images) 127 | self.checkpoint.step.assign_add(1) 128 | step = int(self.checkpoint.step) 129 | 130 | 131 | # save generated images for every 50 batches training 132 | if step % 50 == 0: 133 | manifold_h = int(np.floor(np.sqrt(self.batch_size))) 134 | manifold_w = int(np.floor(np.sqrt(self.batch_size))) 135 | sample_z = tf.random.uniform(minval=-1,maxval= 1, shape=(self.batch_size, self.z_dim), 136 | dtype=tf.dtypes.float32) 137 | print ('step: {}, d_loss: {:.4f}, g_oss: {:.4F}'.format(step,self.d_loss_metric.result(), self.g_loss_metric.result())) 138 | result_to_display = self.g(sample_z, training=False) 139 | save_images(result_to_display[:manifold_h * manifold_w, :, :, :], 140 | [manifold_h, manifold_w], 141 | './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, int(step))) 142 | 143 | with self.train_summary_writer.as_default(): 144 | 145 | tf.summary.scalar('g_loss', self.g_loss_metric.result(), step=step) 146 | tf.summary.scalar('d_loss', self.d_loss_metric.result(), step=step) 147 | 148 | 149 | #save checkpoints for every 400 batches training 150 | if step % 400 ==0: 151 | save_path = self.manager.save() 152 | print("\n----------Saved checkpoint for step {}: {}-------------\n".format(step, save_path)) 153 | self.g_loss_metric.reset_states() 154 | self.d_loss_metric.reset_states() 155 | 156 | 157 | def load_ckpt(self): 158 | self.checkpoint.restore(self.manager.latest_checkpoint) 159 | if self.manager.latest_checkpoint: 160 | print("restore model from checkpoint: {}".format(self.manager.latest_checkpoint)) 161 | return True 162 | 163 | else: 164 | print("Initializing from scratch.") 165 | return False 166 | 167 | 168 | 169 | 170 | def parse_args(): 171 | desc = "Tensorflow implementation of GAN collections" 172 | parser = argparse.ArgumentParser(description=desc) 173 | parser.add_argument('--gan_type', type=str, default='LSGAN') 174 | parser.add_argument('--datasets', type=str, default='fashion_mnist') 175 | parser.add_argument('--lr', type=float, default=2e-4) 176 | parser.add_argument('--epoch', type=int, default=20, help='The number of epochs to run') 177 | parser.add_argument('--batch_size', type=int, default=64, help='The size of batch') 178 | parser.add_argument('--z_dim', type=int, default=62, help='Dimension of noise vector') 179 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 180 | help='Directory name to save the checkpoints') 181 | parser.add_argument('--result_dir', type=str, default='results', 182 | help='Directory name to save the generated images') 183 | parser.add_argument('--log_dir', type=str, default='logs', 184 | help='Directory name to save training logs') 185 | 186 | return 
check_args(parser.parse_args()) 187 | 188 | """checking arguments""" 189 | def check_args(args): 190 | # --checkpoint_dir 191 | check_folder(args.checkpoint_dir) 192 | 193 | # --result_dir 194 | check_folder(args.result_dir) 195 | 196 | # --result_dir 197 | check_folder(args.log_dir) 198 | 199 | # --epoch 200 | assert args.epoch >= 1, 'number of epochs must be larger than or equal to one' 201 | 202 | # --batch_size 203 | assert args.batch_size >= 1, 'batch size must be larger than or equal to one' 204 | 205 | # --z_dim 206 | assert args.z_dim >= 1, 'dimension of noise vector must be larger than or equal to one' 207 | 208 | return args 209 | 210 | 211 | def main(): 212 | args = parse_args() 213 | if args is None: 214 | exit() 215 | model = LSGAN(args) 216 | model.train(load=True) 217 | 218 | 219 | if __name__ == '__main__': 220 | main() 221 | 222 | -------------------------------------------------------------------------------- /GAN.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import argparse 3 | import datetime 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import tensorflow as tf 7 | import scipy.misc 8 | from tensorflow import keras 9 | from tensorflow.keras import layers,optimizers,metrics 10 | 11 | from ops import * 12 | from utils import * 13 | 14 | 15 | tf.debugging.set_log_device_placement(True) 16 | 17 | class GAN(): 18 | def __init__(self,args): 19 | super(GAN, self).__init__() 20 | self.model_name = args.gan_type 21 | self.batch_size = args.batch_size 22 | self.z_dim = args.z_dim 23 | self.checkpoint_dir =check_folder(os.path.join(args.checkpoint_dir,self.model_name)) 24 | self.result_dir = args.result_dir 25 | self.datasets_name = args.datasets 26 | self.log_dir=args.log_dir 27 | self.learnning_rate=args.lr 28 | self.epoches=args.epoch 29 | self.datasets = load_mnist_data(datasets=self.datasets_name,batch_size=args.batch_size) 30 | self.g = self.make_generator_model(is_training=True) 31 | self.d = self.make_discriminator_model(is_training=True) 32 | self.g_optimizer = optimizers.Adam(lr=5*self.learnning_rate, beta_1=0.5) 33 | self.d_optimizer = optimizers.Adam(lr=self.learnning_rate, beta_1=0.5) 34 | self.g_loss_metric = metrics.Mean('g_loss', dtype=tf.float32) 35 | self.d_loss_metric = metrics.Mean('d_loss', dtype=tf.float32) 36 | self.checkpoint = tf.train.Checkpoint(step=tf.Variable(0), 37 | generator_optimizer=self.g_optimizer, 38 | discriminator_optimizer=self.d_optimizer, 39 | generator=self.g, 40 | discriminator=self.d) 41 | self.manager = tf.train.CheckpointManager(self.checkpoint, self.checkpoint_dir, max_to_keep=3) 42 | 43 | 44 | 45 | # the network is based on https://github.com/hwalsuklee/tensorflow-generative-model-collections 46 | def make_discriminator_model(self,is_training): 47 | model = tf.keras.Sequential() 48 | model.add(Conv2D(64,4,2)) 49 | model.add(layers.LeakyReLU(alpha=0.2)) 50 | model.add(Conv2D(128,4,2)) 51 | model.add(BatchNorm(is_training=is_training)) 52 | model.add(layers.LeakyReLU(alpha=0.2)) 53 | model.add(layers.Flatten()) 54 | model.add(DenseLayer(1024)) 55 | model.add(BatchNorm(is_training=is_training)) 56 | model.add(layers.LeakyReLU(alpha=0.2)) 57 | model.add(DenseLayer(1)) 58 | return model 59 | 60 | 61 | def make_generator_model(self,is_training): 62 | model = tf.keras.Sequential() 63 | model.add(DenseLayer(1024)) 64 | model.add(BatchNorm(is_training=is_training)) 65 | model.add(keras.layers.ReLU()) 66 | model.add(DenseLayer(128*7*7)) 67 | 
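# the 128*7*7 units are reshaped to 7x7x128 below, then two stride-2 transposed convolutions
# upsample the feature maps to 14x14x64 and finally to a 28x28x1 image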
model.add(BatchNorm(is_training=is_training)) 68 | model.add(keras.layers.ReLU()) 69 | model.add(layers.Reshape((7,7,128))) 70 | model.add(UpConv2D(64,4,2)) 71 | model.add(BatchNorm(is_training=is_training)) 72 | model.add(keras.layers.ReLU()) 73 | model.add(UpConv2D(1,4,2)) 74 | model.add(Sigmoid()) 75 | return model 76 | 77 | 78 | @property 79 | def model_dir(self): 80 | return "{}_{}_{}_{}".format( 81 | self.model_name, self.datasets_name, 82 | self.batch_size, self.z_dim) 83 | 84 | def d_loss_fun(self, d_fake_logits, d_real_logits): 85 | d_loss_real = tf.reduce_mean( 86 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_real_logits), logits=d_real_logits)) 87 | d_loss_fake = tf.reduce_mean( 88 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_fake_logits), logits=d_fake_logits)) 89 | total_loss = d_loss_fake+d_loss_real 90 | return total_loss 91 | 92 | def g_loss_fun(self, logits): 93 | g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 94 | labels=tf.ones_like(logits), logits=logits)) 95 | return g_loss 96 | 97 | 98 | # train for one batch 99 | @tf.function 100 | def train_one_step(self,batch_images): 101 | batch_z = np.random.uniform(-1, 1,[self.batch_size, self.z_dim]).astype(np.float32) 102 | real_images = batch_images 103 | with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape: 104 | fake_imgs = self.g(batch_z, training=True) 105 | d_fake_logits = self.d(fake_imgs, training=True) 106 | d_real_logits = self.d(real_images, training=True) 107 | d_loss = self.d_loss_fun(d_fake_logits, d_real_logits) 108 | g_loss = self.g_loss_fun(d_fake_logits) 109 | gradients_of_d = d_tape.gradient(d_loss, self.d.trainable_variables) 110 | gradients_of_g = g_tape.gradient(g_loss, self.g.trainable_variables) 111 | self.d_optimizer.apply_gradients(zip(gradients_of_d, self.d.trainable_variables)) 112 | self.g_optimizer.apply_gradients(zip(gradients_of_g, self.g.trainable_variables)) 113 | self.g_loss_metric(g_loss) 114 | self.d_loss_metric(d_loss) 115 | 116 | 117 | 118 | def train(self, load=False): 119 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 120 | train_log_dir = self.log_dir+'/'+self.model_name+'/'+ current_time 121 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir) 122 | 123 | # if want to load a checkpoints,set load flags to be true 124 | if load: 125 | self.could_load = self.load_ckpt() 126 | ckpt_step = int(self.checkpoint.step) 127 | start_epoch=int((ckpt_step*self.batch_size)//60000) 128 | else: 129 | start_epoch=0 130 | 131 | 132 | for epoch in range(start_epoch,self.epoches): 133 | for batch_images, _ in self.datasets: 134 | self.train_one_step(batch_images) 135 | self.checkpoint.step.assign_add(1) 136 | step = int(self.checkpoint.step) 137 | 138 | # save generated images for every 50 batches training 139 | if step % 50 == 0: 140 | print ('step: {}, d_loss: {:.4f}, g_oss: {:.4F}'.format(step,self.d_loss_metric.result(), self.g_loss_metric.result())) 141 | manifold_h = int(np.floor(np.sqrt(self.batch_size))) 142 | manifold_w = int(np.floor(np.sqrt(self.batch_size))) 143 | sample_z = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim)) 144 | result_to_display = self.g(sample_z, training=False) 145 | save_images(result_to_display[:manifold_h * manifold_w, :, :, :], 146 | [manifold_h, manifold_w], 147 | './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, int(step))) 148 | 149 | with 
self.train_summary_writer.as_default(): 150 | # print("-----------write to summary-----------") 151 | tf.summary.scalar('g_loss', self.g_loss_metric.result(), step=step) 152 | tf.summary.scalar('d_loss', self.d_loss_metric.result(), step=step) 153 | 154 | 155 | #save checkpoints for every 400 batches training 156 | if step % 400 == 0: 157 | save_path = self.manager.save() 158 | 159 | print("\n----------Saved checkpoint for step {}: {}-------------\n".format(step, save_path)) 160 | 161 | self.g_loss_metric.reset_states() 162 | self.d_loss_metric.reset_states() 163 | 164 | def load_ckpt(self): 165 | self.checkpoint.restore(self.manager.latest_checkpoint) 166 | if self.manager.latest_checkpoint: 167 | print("restore model from checkpoint: {}".format(self.manager.latest_checkpoint)) 168 | return True 169 | 170 | else: 171 | print("Initializing from scratch.") 172 | return False 173 | 174 | 175 | 176 | def parse_args(): 177 | desc = "Tensorflow implementation of GAN collections" 178 | parser = argparse.ArgumentParser(description=desc) 179 | parser.add_argument('--gan_type', type=str, default='GAN') 180 | parser.add_argument('--datasets', type=str, default='fashion_mnist') 181 | parser.add_argument('--lr', type=float, default=2e-4) 182 | parser.add_argument('--epoch', type=int, default=20, help='The number of epochs to run') 183 | parser.add_argument('--batch_size', type=int, default=64, help='The size of batch') 184 | parser.add_argument('--z_dim', type=int, default=62, help='Dimension of noise vector') 185 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 186 | help='Directory name to save the checkpoints') 187 | parser.add_argument('--result_dir', type=str, default='results', 188 | help='Directory name to save the generated images') 189 | parser.add_argument('--log_dir', type=str, default='logs', 190 | help='Directory name to save training logs') 191 | 192 | return check_args(parser.parse_args()) 193 | 194 | """checking arguments""" 195 | def check_args(args): 196 | # --checkpoint_dir 197 | check_folder(args.checkpoint_dir) 198 | 199 | # --result_dir 200 | check_folder(args.result_dir) 201 | 202 | # --result_dir 203 | check_folder(args.log_dir) 204 | 205 | # --epoch 206 | assert args.epoch >= 1, 'number of epochs must be larger than or equal to one' 207 | 208 | # --batch_size 209 | assert args.batch_size >= 1, 'batch size must be larger than or equal to one' 210 | 211 | # --z_dim 212 | assert args.z_dim >= 1, 'dimension of noise vector must be larger than or equal to one' 213 | 214 | return args 215 | 216 | 217 | def main(): 218 | args = parse_args() 219 | if args is None: 220 | exit() 221 | model = GAN(args) 222 | model.train(load=True) 223 | 224 | 225 | if __name__ == '__main__': 226 | main() 227 | 228 | -------------------------------------------------------------------------------- /VAE.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import argparse 3 | import datetime 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import tensorflow as tf 7 | import scipy.misc 8 | from tensorflow import keras 9 | from tensorflow.keras import layers,optimizers,metrics 10 | 11 | from ops import * 12 | from utils import * 13 | 14 | 15 | class VAE(): 16 | def __init__(self,args): 17 | super(VAE, self).__init__() 18 | self.model_name = args.gan_type 19 | self.batch_size = args.batch_size 20 | self.z_dim = args.z_dim 21 | self.checkpoint_dir = os.path.join(args.checkpoint_dir, self.model_name) 22 
| self.result_dir = args.result_dir 23 | self.datasets_name = args.datasets 24 | self.log_dir=args.log_dir 25 | self.learnning_rate=args.lr 26 | self.epoches=args.epoch 27 | self.datasets = load_mnist_data(datasets=self.datasets_name,batch_size=args.batch_size) 28 | self.decoder = self.make_decoder_model(is_training=True) 29 | self.encoder = self.make_encoder_model(is_training=True) 30 | self.optimizer = keras.optimizers.Adam(lr=self.learnning_rate, beta_1=0.5) 31 | self.nll_loss_metric = tf.keras.metrics.Mean('nll_loss', dtype=tf.float32) 32 | self.kl_loss_metric = tf.keras.metrics.Mean('kl_loss', dtype=tf.float32) 33 | self.total_loss_metric = tf.keras.metrics.Mean('totol_loss', dtype=tf.float32) 34 | self.checkpoint = tf.train.Checkpoint(step=tf.Variable(0), 35 | optimizer=self.optimizer, 36 | encoder=self.encoder, 37 | decoder=self.decoder) 38 | self.manager = tf.train.CheckpointManager(self.checkpoint, self.checkpoint_dir, max_to_keep=3) 39 | 40 | 41 | 42 | # the network is based on https://github.com/hwalsuklee/tensorflow-generative-model-collections 43 | def make_encoder_model(self,is_training): 44 | model = tf.keras.Sequential() 45 | model.add(Conv2D(64,4,2)) 46 | model.add(layers.LeakyReLU(alpha=0.2)) 47 | model.add(Conv2D(128,4,2)) 48 | model.add(BatchNorm(is_training=is_training)) 49 | model.add(layers.LeakyReLU(alpha=0.2)) 50 | model.add(layers.Flatten()) 51 | model.add(DenseLayer(1024)) 52 | model.add(BatchNorm(is_training=is_training)) 53 | model.add(layers.LeakyReLU(alpha=0.2)) 54 | model.add(DenseLayer(2*self.z_dim)) 55 | return model 56 | 57 | 58 | 59 | def make_decoder_model(self,is_training): 60 | model = tf.keras.Sequential() 61 | model.add(DenseLayer(1024)) 62 | model.add(BatchNorm(is_training=is_training)) 63 | model.add(keras.layers.ReLU()) 64 | model.add(DenseLayer(128*7*7)) 65 | model.add(BatchNorm(is_training=is_training)) 66 | model.add(keras.layers.ReLU()) 67 | model.add(layers.Reshape((7,7,128))) 68 | model.add(UpConv2D(64,4,2)) 69 | model.add(BatchNorm(is_training=is_training)) 70 | model.add(keras.layers.ReLU()) 71 | model.add(UpConv2D(1,4,2)) 72 | model.add(Sigmoid()) 73 | return model 74 | 75 | 76 | @property 77 | def model_dir(self): 78 | return "{}_{}_{}_{}".format( 79 | self.model_name, self.datasets_name, 80 | self.batch_size, self.z_dim) 81 | 82 | 83 | # training for one batch 84 | @tf.function 85 | def train_one_step(self,batch_images): 86 | batch_z = np.random.uniform(-1, 1,[self.batch_size, self.z_dim]).astype(np.float32) 87 | real_images = batch_images 88 | with tf.GradientTape() as gradient_tape: 89 | gaussian_params=self.encoder(real_images,training=True) 90 | mu = gaussian_params[:, :self.z_dim] 91 | sigma = 1e-6 + tf.keras.activations.softplus(gaussian_params[:, self.z_dim:]) 92 | z = mu + sigma * tf.random.normal(tf.shape(mu), 0, 1, dtype=tf.float32) 93 | out = self.decoder(z, training=True) 94 | out = tf.clip_by_value(out, 1e-8, 1 - 1e-8) 95 | marginal_likelihood = tf.reduce_sum(real_images * tf.math.log(out) + (1. - real_images) * tf.math.log(1. 
- out),[1, 2]) 96 | KL_divergence = 0.5 * tf.reduce_sum(tf.math.square(mu) + tf.math.square(sigma) - tf.math.log(1e-8 + tf.math.square(sigma)) - 1, [1]) 97 | self.neg_loglikelihood = -tf.reduce_mean(marginal_likelihood) 98 | self.KL_divergence = tf.reduce_mean(KL_divergence) 99 | ELBO = -self.neg_loglikelihood - self.KL_divergence 100 | loss = -ELBO 101 | 102 | self.trainable_variables=self.decoder.trainable_variables+self.encoder.trainable_variables 103 | gradients = gradient_tape.gradient(loss,self.trainable_variables) 104 | self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) 105 | self.nll_loss_metric(self.neg_loglikelihood) 106 | self.kl_loss_metric(KL_divergence) 107 | self.total_loss_metric(loss) 108 | 109 | 110 | 111 | def train(self, load=False): 112 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 113 | train_log_dir = os.path.join(self.log_dir, self.model_name, current_time) 114 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir) 115 | 116 | # if want to load a checkpoints,set load flag to be true 117 | if load: 118 | self.could_load = self.load_ckpt() 119 | ckpt_step = int(self.checkpoint.step) 120 | start_epoch=int((ckpt_step*self.batch_size)//60000) 121 | else: 122 | start_epoch=0 123 | 124 | 125 | for epoch in range(start_epoch,self.epoches): 126 | for batch_images, _ in self.datasets: 127 | self.sample_z = np.random.normal(0., 1., (self.batch_size, self.z_dim)).astype(np.float32) 128 | self.train_one_step(batch_images) 129 | self.checkpoint.step.assign_add(1) 130 | step = int(self.checkpoint.step) 131 | 132 | # save generated images for every 50 batches training 133 | if step % 50 == 0: 134 | manifold_h = int(np.floor(np.sqrt(self.batch_size))) 135 | manifold_w = int(np.floor(np.sqrt(self.batch_size))) 136 | print ('step: {}, nll_loss: {:.4f}, kl_loss: {:.4F} ,total_loss: {:.4F}'.format(step,self.nll_loss_metric.result(), self.kl_loss_metric.result(),self.total_loss_metric.result())) 137 | result_to_display = self.decoder(self.sample_z, training=False) 138 | save_images(result_to_display[:manifold_h * manifold_w, :, :, :], 139 | [manifold_h, manifold_w], 140 | './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, int(step))) 141 | 142 | 143 | with self.train_summary_writer.as_default(): 144 | tf.summary.scalar('g_loss', self.nll_loss_metric.result(), step=step) 145 | tf.summary.scalar('d_loss', self.kl_loss_metric.result(), step=step) 146 | tf.summary.scalar('d_loss', self.total_loss_metric.result(), step=step) 147 | 148 | #save checkpoints for every 400 batches training 149 | if step % 400 ==0: 150 | save_path = self.manager.save() 151 | 152 | print("\n----------Saved checkpoint for step {}: {}-------------\n".format(step, save_path)) 153 | 154 | self.nll_loss_metric.reset_states() 155 | self.kl_loss_metric.reset_states() 156 | self.total_loss_metric.reset_states() 157 | 158 | 159 | 160 | def load_ckpt(self): 161 | self.checkpoint.restore(self.manager.latest_checkpoint) 162 | if self.manager.latest_checkpoint: 163 | print("restore model from checkpoint: {}".format(self.manager.latest_checkpoint)) 164 | return True 165 | 166 | else: 167 | print("Initializing from scratch.") 168 | return False 169 | 170 | 171 | 172 | 173 | 174 | def parse_args(): 175 | desc = "Tensorflow implementation of GAN collections" 176 | parser = argparse.ArgumentParser(description=desc) 177 | parser.add_argument('--gan_type', type=str, default='VAE') 178 | 
parser.add_argument('--datasets', type=str, default='fashion_mnist') 179 | parser.add_argument('--lr', type=float, default=2e-4) 180 | parser.add_argument('--epoch', type=int, default=20, help='The number of epochs to run') 181 | parser.add_argument('--batch_size', type=int, default=64, help='The size of batch') 182 | parser.add_argument('--z_dim', type=int, default=62, help='Dimension of noise vector') 183 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 184 | help='Directory name to save the checkpoints') 185 | parser.add_argument('--result_dir', type=str, default='results', 186 | help='Directory name to save the generated images') 187 | parser.add_argument('--log_dir', type=str, default='logs', 188 | help='Directory name to save training logs') 189 | 190 | return check_args(parser.parse_args()) 191 | 192 | """checking arguments""" 193 | def check_args(args): 194 | # --checkpoint_dir 195 | check_folder(args.checkpoint_dir) 196 | 197 | # --result_dir 198 | check_folder(args.result_dir) 199 | 200 | # --result_dir 201 | check_folder(args.log_dir) 202 | 203 | # --epoch 204 | assert args.epoch >= 1, 'number of epochs must be larger than or equal to one' 205 | 206 | # --batch_size 207 | assert args.batch_size >= 1, 'batch size must be larger than or equal to one' 208 | 209 | # --z_dim 210 | assert args.z_dim >= 1, 'dimension of noise vector must be larger than or equal to one' 211 | 212 | return args 213 | 214 | 215 | def main(): 216 | args = parse_args() 217 | if args is None: 218 | exit() 219 | model = VAE(args) 220 | model.train(load=True) 221 | 222 | 223 | if __name__ == '__main__': 224 | main() 225 | 226 | -------------------------------------------------------------------------------- /WGAN.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import argparse 4 | import datetime 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import tensorflow as tf 8 | import scipy.misc 9 | from tensorflow import keras 10 | from tensorflow.keras import layers,optimizers,metrics 11 | 12 | from ops import * 13 | from utils import * 14 | 15 | 16 | 17 | 18 | class WGAN(): 19 | def __init__(self,args): 20 | super(WGAN, self).__init__() 21 | self.model_name = args.gan_type 22 | self.batch_size = args.batch_size 23 | self.z_dim = args.z_dim 24 | self.checkpoint_dir = check_folder(os.path.join(args.checkpoint_dir,self.model_name)) 25 | self.result_dir = args.result_dir 26 | self.datasets_name = args.datasets 27 | self.log_dir=args.log_dir 28 | self.learnning_rate=args.lr 29 | self.epoches=args.epoch 30 | self.datasets = load_mnist_data(datasets=self.datasets_name,batch_size=args.batch_size,model_name='WGAN') 31 | self.g = self.make_generator_model(self.z_dim,is_training=True) 32 | self.d = self.make_discriminator_model(is_training=True) 33 | self.g_optimizer = keras.optimizers.RMSprop(lr=5*self.learnning_rate) 34 | self.d_optimizer = keras.optimizers.RMSprop(lr=self.learnning_rate) 35 | self.g_loss_metric = metrics.Mean('g_loss', dtype=tf.float32) 36 | self.critic_loss_metric = metrics.Mean('critic_loss', dtype=tf.float32) 37 | self.checkpoint = tf.train.Checkpoint(step=tf.Variable(0), 38 | generator_optimizer=self.g_optimizer, 39 | discriminator_optimizer=self.d_optimizer, 40 | generator=self.g, 41 | discriminator=self.d) 42 | self.manager = tf.train.CheckpointManager(self.checkpoint, self.checkpoint_dir, max_to_keep=3) 43 | 44 | 45 | 46 | # the network is based on 
https://github.com/hwalsuklee/tensorflow-generative-model-collections 47 | def make_discriminator_model(self,is_training): 48 | model = tf.keras.Sequential() 49 | model.add(Conv2D(64,4,2)) 50 | model.add(layers.LeakyReLU(alpha=0.2)) 51 | model.add(Conv2D(128,4,2)) 52 | model.add(BatchNorm(is_training=is_training)) 53 | model.add(layers.LeakyReLU(alpha=0.2)) 54 | model.add(layers.Flatten()) 55 | model.add(DenseLayer(1024)) 56 | model.add(BatchNorm(is_training=is_training)) 57 | model.add(layers.LeakyReLU(alpha=0.2)) 58 | model.add(DenseLayer(1)) 59 | return model 60 | 61 | 62 | def make_generator_model(self,z_dim,is_training): 63 | model = tf.keras.Sequential() 64 | model.add(DenseLayer(1024)) 65 | model.add(BatchNorm(is_training=is_training)) 66 | model.add(keras.layers.ReLU()) 67 | model.add(DenseLayer(128*7*7)) 68 | model.add(BatchNorm(is_training=is_training)) 69 | model.add(keras.layers.ReLU()) 70 | model.add(layers.Reshape((7,7,128))) 71 | model.add(UpConv2D(64,4,2)) 72 | model.add(BatchNorm(is_training=is_training)) 73 | model.add(keras.layers.ReLU()) 74 | model.add(UpConv2D(1,4,2)) 75 | model.add(Tanh()) 76 | return model 77 | 78 | 79 | @property 80 | def model_dir(self): 81 | return "{}_{}_{}_{}".format( 82 | self.model_name, self.datasets_name, 83 | self.batch_size, self.z_dim) 84 | 85 | def d_loss_fun(self, d_fake_logits, d_real_logits): 86 | d_loss_real = tf.reduce_mean( 87 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_real_logits), logits=d_real_logits)) 88 | d_loss_fake = tf.reduce_mean( 89 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_fake_logits), logits=d_fake_logits)) 90 | total_loss = d_loss_fake+d_loss_real 91 | return total_loss 92 | 93 | def g_loss_fun(self, logits): 94 | g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 95 | labels=tf.ones_like(logits), logits=logits)) 96 | return g_loss 97 | 98 | 99 | # training for one batch 100 | @tf.function 101 | def train_one_step(self,batch_images): 102 | batch_z = np.random.uniform(-1, 1,[self.batch_size, self.z_dim]).astype(np.float32) 103 | real_images = batch_images 104 | with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape: 105 | fake_imgs = self.g(batch_z, training=True) 106 | d_fake_logits= self.d(fake_imgs, training=True) 107 | d_real_logits= self.d(real_images, training=True) 108 | critic_loss=tf.reduce_mean(d_fake_logits-d_real_logits) 109 | g_loss = self.g_loss_fun(d_fake_logits) 110 | 111 | gradients_of_d = d_tape.gradient(critic_loss, self.d.trainable_variables) 112 | 113 | # for WGAN model all the gradients should clip to (-0.01,0.01) 114 | for idx,grad in enumerate(gradients_of_d) : 115 | gradients_of_d[idx]=tf.clip_by_value(grad,-0.01,0.01) 116 | gradients_of_g = g_tape.gradient(g_loss, self.g.trainable_variables) 117 | self.d_optimizer.apply_gradients(zip(gradients_of_d, self.d.trainable_variables)) 118 | self.g_optimizer.apply_gradients(zip(gradients_of_g, self.g.trainable_variables)) 119 | self.g_loss_metric(g_loss) 120 | self.critic_loss_metric(critic_loss) 121 | 122 | 123 | def train(self, load=False): 124 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 125 | train_log_dir = os.path.join(self.log_dir,self.model_name,current_time) 126 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir) 127 | 128 | # if want to load a checkpoint,set load flags to be true 129 | if load: 130 | self.could_load = self.load_ckpt() 131 | ckpt_step = int(self.checkpoint.step) 132 | 
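# recover the epoch to resume from: each checkpoint step corresponds to one batch,
# and MNIST / Fashion-MNIST both contain 60000 training images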
start_epoch=int((ckpt_step*self.batch_size)//60000) 133 | else: 134 | start_epoch=0 135 | 136 | 137 | for epoch in range(start_epoch,self.epoches): 138 | for batch_images, _ in self.datasets: 139 | self.train_one_step(batch_images) 140 | self.checkpoint.step.assign_add(1) 141 | step = int(self.checkpoint.step) 142 | 143 | 144 | # save generated images for every 50 batches training 145 | if step % 50 == 0: 146 | print ('step:{}, d_loss: {:.4f}, g_loss: {:.4F}'.format(step,self.critic_loss_metric.result(), self.g_loss_metric.result())) 147 | manifold_h = int(np.floor(np.sqrt(self.batch_size))) 148 | manifold_w = int(np.floor(np.sqrt(self.batch_size))) 149 | sample_z = tf.random.uniform(minval=-1,maxval= 1, shape=(self.batch_size, self.z_dim), 150 | dtype=tf.dtypes.float32) 151 | result_to_display = self.g(sample_z, training=False) 152 | save_images(result_to_display[:manifold_h * manifold_w, :, :, :], 153 | [manifold_h, manifold_w], 154 | './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, int(step))) 155 | 156 | with self.train_summary_writer.as_default(): 157 | tf.summary.scalar('g_loss', self.g_loss_metric.result(), step=step) 158 | tf.summary.scalar('d_loss', self.critic_loss_metric.result(), step=step) 159 | 160 | #save checkpoints for every 400 batches training 161 | if step % 400 ==0: 162 | save_path = self.manager.save() 163 | 164 | print("\n----------Saved checkpoint for step {}: {}-------------\n".format(step, save_path)) 165 | 166 | self.g_loss_metric.reset_states() 167 | self.critic_loss_metric.reset_states() 168 | 169 | def load_ckpt(self): 170 | self.checkpoint.restore(self.manager.latest_checkpoint) 171 | if self.manager.latest_checkpoint: 172 | print("restore model from checkpoint: {}".format(self.manager.latest_checkpoint)) 173 | return True 174 | 175 | else: 176 | print("Initializing from scratch.") 177 | return False 178 | 179 | 180 | 181 | 182 | 183 | def parse_args(): 184 | desc = "Tensorflow implementation of GAN collections" 185 | parser = argparse.ArgumentParser(description=desc) 186 | parser.add_argument('--gan_type', type=str, default='WGAN') 187 | parser.add_argument('--datasets', type=str, default='fashion_mnist') 188 | parser.add_argument('--lr', type=float, default=2e-4) 189 | parser.add_argument('--epoch', type=int, default=20, help='The number of epochs to run') 190 | parser.add_argument('--batch_size', type=int, default=64, help='The size of batch') 191 | parser.add_argument('--z_dim', type=int, default=62, help='Dimension of noise vector') 192 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 193 | help='Directory name to save the checkpoints') 194 | parser.add_argument('--result_dir', type=str, default='results', 195 | help='Directory name to save the generated images') 196 | parser.add_argument('--log_dir', type=str, default='logs', 197 | help='Directory name to save training logs') 198 | 199 | return check_args(parser.parse_args()) 200 | 201 | """checking arguments""" 202 | def check_args(args): 203 | # --checkpoint_dir 204 | check_folder(args.checkpoint_dir) 205 | 206 | # --result_dir 207 | check_folder(args.result_dir) 208 | 209 | # --result_dir 210 | check_folder(args.log_dir) 211 | 212 | # --epoch 213 | assert args.epoch >= 1, 'number of epochs must be larger than or equal to one' 214 | 215 | # --batch_size 216 | assert args.batch_size >= 1, 'batch size must be larger than or equal to one' 217 | 218 | # --z_dim 219 | assert args.z_dim >= 1, 'dimension of 
noise vector must be larger than or equal to one' 220 | 221 | return args 222 | 223 | 224 | def main(): 225 | args = parse_args() 226 | if args is None: 227 | exit() 228 | model = WGAN(args) 229 | model.train(load=True) 230 | 231 | 232 | if __name__ == '__main__': 233 | main() 234 | 235 | -------------------------------------------------------------------------------- /CGAN.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import argparse 3 | import datetime 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import tensorflow as tf 7 | import scipy.misc 8 | from tensorflow import keras 9 | from tensorflow.keras import layers,optimizers,metrics 10 | 11 | from ops import * 12 | from utils import * 13 | 14 | 15 | class CGAN(): 16 | def __init__(self,args): 17 | super(CGAN, self).__init__() 18 | self.model_name = args.gan_type 19 | self.batch_size = args.batch_size 20 | self.z_dim = args.z_dim 21 | self.y_dim=10 22 | self.checkpoint_dir = check_folder(os.path.join(args.checkpoint_dir,self.model_name)) 23 | self.result_dir = args.result_dir 24 | self.datasets_name = args.datasets 25 | self.log_dir=args.log_dir 26 | self.learnning_rate=args.lr 27 | self.epoches=args.epoch 28 | self.datasets = load_mnist_data(datasets=self.datasets_name,batch_size=args.batch_size) 29 | self.g = self.make_generator_model(is_training=True) 30 | self.d = self.make_discriminator_model(is_training=True) 31 | self.g_optimizer = optimizers.Adam(lr=5*self.learnning_rate, beta_1=0.5) 32 | self.d_optimizer = optimizers.Adam(lr=self.learnning_rate, beta_1=0.5) 33 | self.g_loss_metric = metrics.Mean('g_loss', dtype=tf.float32) 34 | self.d_loss_metric = metrics.Mean('d_loss', dtype=tf.float32) 35 | self.checkpoint = tf.train.Checkpoint(step=tf.Variable(0), 36 | generator_optimizer=self.g_optimizer, 37 | discriminator_optimizer=self.d_optimizer, 38 | generator=self.g, 39 | discriminator=self.d) 40 | self.manager = tf.train.CheckpointManager(self.checkpoint, self.checkpoint_dir, max_to_keep=3) 41 | 42 | 43 | 44 | # the network is based on https://github.com/hwalsuklee/tensorflow-generative-model-collections 45 | def make_discriminator_model(self,is_training): 46 | model = tf.keras.Sequential() 47 | model.add(Conv2D(64,4,2)) 48 | model.add(layers.LeakyReLU(alpha=0.2)) 49 | model.add(Conv2D(128,4,2)) 50 | model.add(BatchNorm(is_training=is_training)) 51 | model.add(layers.LeakyReLU(alpha=0.2)) 52 | model.add(layers.Flatten()) 53 | model.add(DenseLayer(1024)) 54 | model.add(BatchNorm(is_training=is_training)) 55 | model.add(layers.LeakyReLU(alpha=0.3)) 56 | model.add(DenseLayer(1)) 57 | return model 58 | 59 | 60 | def make_generator_model(self,is_training): 61 | model = tf.keras.Sequential() 62 | model.add(DenseLayer(1024)) 63 | model.add(BatchNorm(is_training=is_training)) 64 | model.add(keras.layers.ReLU()) 65 | model.add(DenseLayer(128*7*7)) 66 | model.add(BatchNorm(is_training=is_training)) 67 | model.add(keras.layers.ReLU()) 68 | model.add(layers.Reshape((7,7,128))) 69 | model.add(UpConv2D(64,4,2)) 70 | model.add(BatchNorm(is_training=is_training)) 71 | model.add(keras.layers.ReLU()) 72 | model.add(UpConv2D(1,4,2)) 73 | model.add(Sigmoid()) 74 | return model 75 | 76 | 77 | @property 78 | def model_dir(self): 79 | return "{}_{}_{}_{}".format( 80 | self.model_name, self.datasets_name, 81 | self.batch_size, self.z_dim) 82 | 83 | 84 | def d_loss_fun(self, d_fake_logits, d_real_logits): 85 | d_loss_real = tf.reduce_mean( 86 | 
tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_real_logits), logits=d_real_logits)) 87 | d_loss_fake = tf.reduce_mean( 88 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_fake_logits), logits=d_fake_logits)) 89 | total_loss = d_loss_fake+d_loss_real 90 | return total_loss 91 | 92 | 93 | def g_loss_fun(self, logits): 94 | g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 95 | labels=tf.ones_like(logits), logits=logits)) 96 | return g_loss 97 | 98 | 99 | # train for one batch 100 | @tf.function 101 | def train_one_step(self,batch_images,batch_labels): 102 | z = np.random.uniform(-1, 1,[self.batch_size, self.z_dim]).astype(np.float32) 103 | z_y = tf.concat([z, batch_labels], 1) 104 | real_images = conv_cond_concat(batch_images, batch_labels) 105 | with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape: 106 | fake_imgs = self.g(z_y, training=True) 107 | fake_imgs = conv_cond_concat(fake_imgs, batch_labels) 108 | d_fake_logits = self.d(fake_imgs, training=True) 109 | d_real_logits = self.d(real_images, training=True) 110 | d_loss = self.d_loss_fun(d_fake_logits, d_real_logits) 111 | g_loss = self.g_loss_fun(d_fake_logits) 112 | gradients_of_d = d_tape.gradient(d_loss, self.d.trainable_variables) 113 | gradients_of_g = g_tape.gradient(g_loss, self.g.trainable_variables) 114 | self.d_optimizer.apply_gradients(zip(gradients_of_d, self.d.trainable_variables)) 115 | self.g_optimizer.apply_gradients(zip(gradients_of_g, self.g.trainable_variables)) 116 | self.g_loss_metric(g_loss) 117 | self.d_loss_metric(d_loss) 118 | 119 | 120 | 121 | def train(self, load=False): 122 | self.sample_label=np.random.randint(0,self.y_dim-1,size=(self.batch_size)) 123 | self.sample_label=tf.one_hot(self.sample_label,depth=10) 124 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 125 | train_log_dir = os.path.join(self.log_dir, self.model_name, current_time) 126 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir) 127 | 128 | # if want to load a checkpoints,set load flags to be true 129 | if load: 130 | self.could_load = self.load_ckpt() 131 | ckpt_step = int(self.checkpoint.step) 132 | start_epoch=int((ckpt_step*self.batch_size)//60000) 133 | else: 134 | start_epoch=0 135 | 136 | 137 | for epoch in range(start_epoch,self.epoches): 138 | for batch_images, batch_labels in self.datasets: 139 | self.train_one_step(batch_images,batch_labels) 140 | self.checkpoint.step.assign_add(1) 141 | step = int(self.checkpoint.step) 142 | 143 | # save generated images for every 50 batches training 144 | if step % 50 == 0: 145 | print ('step: {}, d_loss: {:.4f}, g_oss: {:.4F}'.format(step,self.d_loss_metric.result(), self.g_loss_metric.result())) 146 | manifold_h = int(np.floor(np.sqrt(self.batch_size))) 147 | manifold_w = int(np.floor(np.sqrt(self.batch_size))) 148 | sample_z = np.random.uniform(-1., 1., size=(self.batch_size, self.z_dim)).astype(np.float32) 149 | sample_z_y=tf.concat([sample_z,self.sample_label],1) 150 | result_to_display = self.g(sample_z_y, training=False) 151 | save_images(result_to_display[:manifold_h * manifold_w, :, :, :], 152 | [manifold_h, manifold_w], 153 | './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, int(step))) 154 | 155 | with self.train_summary_writer.as_default(): 156 | tf.summary.scalar('g_loss', self.g_loss_metric.result(), step=step) 157 | tf.summary.scalar('d_loss', self.d_loss_metric.result(), step=step) 158 | 159 | 160 | #save 
checkpoints for every 400 batches training 161 | if step % 400 == 0: 162 | save_path = self.manager.save() 163 | 164 | print("\n----------Saved checkpoint for step {}: {}-------------\n".format(step, save_path)) 165 | 166 | self.g_loss_metric.reset_states() 167 | self.d_loss_metric.reset_states() 168 | 169 | def load_ckpt(self): 170 | self.checkpoint.restore(self.manager.latest_checkpoint) 171 | if self.manager.latest_checkpoint: 172 | print("restore model from checkpoint: {}".format(self.manager.latest_checkpoint)) 173 | return True 174 | 175 | else: 176 | print("Initializing from scratch.") 177 | return False 178 | 179 | 180 | def parse_args(): 181 | desc = "Tensorflow implementation of GAN collections" 182 | parser = argparse.ArgumentParser(description=desc) 183 | parser.add_argument('--gan_type', type=str, default='CGAN') 184 | parser.add_argument('--datasets', type=str, default='mnist') 185 | parser.add_argument('--lr', type=float, default=2e-4) 186 | parser.add_argument('--epoch', type=int, default=20, help='The number of epochs to run') 187 | parser.add_argument('--batch_size', type=int, default=64, help='The size of batch') 188 | parser.add_argument('--z_dim', type=int, default=62, help='Dimension of noise vector') 189 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 190 | help='Directory name to save the checkpoints') 191 | parser.add_argument('--result_dir', type=str, default='results', 192 | help='Directory name to save the generated images') 193 | parser.add_argument('--log_dir', type=str, default='logs', 194 | help='Directory name to save training logs') 195 | 196 | return check_args(parser.parse_args()) 197 | 198 | """checking arguments""" 199 | def check_args(args): 200 | # --checkpoint_dir 201 | check_folder(args.checkpoint_dir) 202 | 203 | # --result_dir 204 | check_folder(args.result_dir) 205 | 206 | # --result_dir 207 | check_folder(args.log_dir) 208 | 209 | # --epoch 210 | assert args.epoch >= 1, 'number of epochs must be larger than or equal to one' 211 | 212 | # --batch_size 213 | assert args.batch_size >= 1, 'batch size must be larger than or equal to one' 214 | 215 | # --z_dim 216 | assert args.z_dim >= 1, 'dimension of noise vector must be larger than or equal to one' 217 | 218 | return args 219 | 220 | 221 | def main(): 222 | args = parse_args() 223 | if args is None: 224 | exit() 225 | model = CGAN(args) 226 | model.train(load=True) 227 | 228 | 229 | if __name__ == '__main__': 230 | main() 231 | 232 | -------------------------------------------------------------------------------- /CVAE.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import argparse 3 | import datetime 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import tensorflow as tf 7 | import scipy.misc 8 | from tensorflow import keras 9 | from tensorflow.keras import layers,optimizers,metrics 10 | 11 | from ops import * 12 | from utils import * 13 | 14 | 15 | class CVAE(): 16 | def __init__(self,args): 17 | super(CVAE, self).__init__() 18 | self.model_name = args.gan_type 19 | self.batch_size = args.batch_size 20 | self.z_dim = args.z_dim 21 | self.checkpoint_dir = os.path.join(args.checkpoint_dir, self.model_name) 22 | self.result_dir = args.result_dir 23 | self.datasets_name = args.datasets 24 | self.log_dir=args.log_dir 25 | self.learnning_rate=args.lr 26 | self.epoches=args.epoch 27 | self.y_dim=10 28 | # self.sample_y=tf.one_hot(np.random.randint(0,9,size=(64)),depth=10) 29 | 
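# load_mnist_data also returns one-hot labels (y_dim = 10 classes); they are concatenated
# to the images for the encoder and to the latent code z for the decoder (conditioning)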
self.datasets = load_mnist_data(datasets=self.datasets_name,batch_size=self.batch_size) 30 | self.decoder = self.make_decoder_model(is_training=True) 31 | self.encoder = self.make_encoder_model(is_training=True) 32 | self.optimizer = keras.optimizers.Adam(lr=5*self.learnning_rate, beta_1=0.5) 33 | self.nll_loss_metric = tf.keras.metrics.Mean('nll_loss', dtype=tf.float32) 34 | self.kl_loss_metric = tf.keras.metrics.Mean('kl_loss', dtype=tf.float32) 35 | self.total_loss_metric = tf.keras.metrics.Mean('total_loss', dtype=tf.float32) 36 | self.checkpoint = tf.train.Checkpoint(step=tf.Variable(0), 37 | optimizer=self.optimizer, 38 | encoder=self.encoder, 39 | decoder=self.decoder) 40 | self.manager = tf.train.CheckpointManager(self.checkpoint, self.checkpoint_dir, max_to_keep=3) 41 | 42 | 43 | 44 | # the network is based on https://github.com/hwalsuklee/tensorflow-generative-model-collections 45 | def make_encoder_model(self,is_training): 46 | model = tf.keras.Sequential() 47 | model.add(Conv2D(64,4,2)) 48 | model.add(layers.LeakyReLU(alpha=0.2)) 49 | model.add(Conv2D(128,4,2)) 50 | model.add(BatchNorm(is_training=is_training)) 51 | model.add(layers.LeakyReLU(alpha=0.2)) 52 | model.add(layers.Flatten()) 53 | model.add(DenseLayer(1024)) 54 | model.add(BatchNorm(is_training=is_training)) 55 | model.add(layers.LeakyReLU(alpha=0.2)) 56 | model.add(DenseLayer(2*self.z_dim)) 57 | return model 58 | 59 | 60 | 61 | def make_decoder_model(self,is_training): 62 | model = tf.keras.Sequential() 63 | model.add(DenseLayer(1024)) 64 | model.add(BatchNorm(is_training=is_training)) 65 | model.add(keras.layers.ReLU()) 66 | model.add(DenseLayer(128*7*7)) 67 | model.add(BatchNorm(is_training=is_training)) 68 | model.add(keras.layers.ReLU()) 69 | model.add(layers.Reshape((7,7,128))) 70 | model.add(UpConv2D(64,4,2)) 71 | model.add(BatchNorm(is_training=is_training)) 72 | model.add(keras.layers.ReLU()) 73 | model.add(UpConv2D(1,4,2)) 74 | model.add(Sigmoid()) 75 | return model 76 | 77 | 78 | @property 79 | def model_dir(self): 80 | return "{}_{}_{}_{}".format( 81 | self.model_name, self.datasets_name, 82 | self.batch_size, self.z_dim) 83 | 84 | 85 | # training for one batch 86 | @tf.function 87 | def train_one_step(self,batch_images,batch_labels): 88 | with tf.GradientTape() as gradient_tape: 89 | batch_images_y=conv_cond_concat(batch_images,batch_labels) 90 | gaussian_params=self.encoder(batch_images_y,training=True) 91 | mu = gaussian_params[:, :self.z_dim] 92 | sigma = 1e-6 + tf.keras.activations.softplus(gaussian_params[:, self.z_dim:]) 93 | z = mu + sigma * tf.random.normal(tf.shape(mu), 0, 1, dtype=tf.float32) 94 | z_y = tf.concat([z, batch_labels], 1) 95 | 96 | out = self.decoder(z_y, training=True) 97 | out = tf.clip_by_value(out, 1e-8, 1 - 1e-8) 98 | marginal_likelihood = tf.reduce_sum(batch_images * tf.math.log(out) + (1. - batch_images) * tf.math.log(1. 
- out),[1, 2]) 99 | KL_divergence = 0.5 * tf.reduce_sum(tf.math.square(mu) + tf.math.square(sigma) - tf.math.log(1e-8 + tf.math.square(sigma)) - 1, [1]) 100 | self.neg_loglikelihood = -tf.reduce_mean(marginal_likelihood) 101 | self.KL_divergence = tf.reduce_mean(KL_divergence) 102 | ELBO = -self.neg_loglikelihood - self.KL_divergence 103 | loss = -ELBO 104 | 105 | self.trainable_variables=self.decoder.trainable_variables+self.encoder.trainable_variables 106 | gradients = gradient_tape.gradient(loss,self.trainable_variables) 107 | self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) 108 | self.nll_loss_metric(self.neg_loglikelihood) 109 | self.kl_loss_metric(KL_divergence) 110 | self.total_loss_metric(loss) 111 | 112 | 113 | 114 | def train(self, load=False): 115 | self.sample_label=np.random.randint(0,self.y_dim-1,size=(self.batch_size)) 116 | self.sample_label=tf.one_hot(self.sample_label,depth=10) 117 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 118 | train_log_dir = os.path.join(self.log_dir, self.model_name, current_time) 119 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir) 120 | 121 | # if want to load a checkpoints,set load flag to be true 122 | if load: 123 | self.could_load = self.load_ckpt() 124 | ckpt_step = int(self.checkpoint.step) 125 | start_epoch=int((ckpt_step*self.batch_size)//60000) 126 | else: 127 | start_epoch=0 128 | 129 | 130 | for epoch in range(start_epoch,self.epoches): 131 | for batch_images, batch_labels in self.datasets: 132 | self.train_one_step(batch_images,batch_labels) 133 | self.checkpoint.step.assign_add(1) 134 | step = int(self.checkpoint.step) 135 | 136 | # save generated images for every 50 batches training 137 | if step % 50 == 0: 138 | manifold_h = int(np.floor(np.sqrt(self.batch_size))) 139 | manifold_w = int(np.floor(np.sqrt(self.batch_size))) 140 | print ('step: {}, nll_loss: {:.4f}, kl_loss: {:.4F} ,total_loss: {:.4F}'.format(step,self.nll_loss_metric.result(), self.kl_loss_metric.result(),self.total_loss_metric.result())) 141 | sample_z = np.random.uniform(-1., 1., size=(self.batch_size, self.z_dim)).astype(np.float32) 142 | self.samples=tf.concat([sample_z,self.sample_label],1) 143 | result_to_display = self.decoder(self.samples, training=False) 144 | save_images(result_to_display[:manifold_h * manifold_w, :, :, :], 145 | [manifold_h, manifold_w], 146 | './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, int(step))) 147 | 148 | with self.train_summary_writer.as_default(): 149 | tf.summary.scalar('g_loss', self.nll_loss_metric.result(), step=step) 150 | tf.summary.scalar('d_loss', self.kl_loss_metric.result(), step=step) 151 | tf.summary.scalar('d_loss', self.total_loss_metric.result(), step=step) 152 | 153 | #save checkpoints for every 400 batches training 154 | if step % 400 ==0: 155 | save_path = self.manager.save() 156 | 157 | print("\n----------Saved checkpoint for step {}: {}-------------\n".format(step, save_path)) 158 | 159 | self.nll_loss_metric.reset_states() 160 | self.kl_loss_metric.reset_states() 161 | self.total_loss_metric.reset_states() 162 | 163 | 164 | 165 | def load_ckpt(self): 166 | self.checkpoint.restore(self.manager.latest_checkpoint) 167 | if self.manager.latest_checkpoint: 168 | print("load success! restore model from checkpoint: {}".format(self.manager.latest_checkpoint)) 169 | return True 170 | 171 | else: 172 | print("load failed! 
Initializing from scratch.") 173 | return False 174 | 175 | 176 | 177 | 178 | 179 | def parse_args(): 180 | desc = "Tensorflow implementation of GAN collections" 181 | parser = argparse.ArgumentParser(description=desc) 182 | parser.add_argument('--gan_type', type=str, default='CVAE') 183 | parser.add_argument('--datasets', type=str, default='fashion_mnist') 184 | parser.add_argument('--lr', type=float, default=2e-4) 185 | parser.add_argument('--epoch', type=int, default=20, help='The number of epochs to run') 186 | parser.add_argument('--batch_size', type=int, default=64, help='The size of batch') 187 | parser.add_argument('--z_dim', type=int, default=62, help='Dimension of noise vector') 188 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 189 | help='Directory name to save the checkpoints') 190 | parser.add_argument('--result_dir', type=str, default='results', 191 | help='Directory name to save the generated images') 192 | parser.add_argument('--log_dir', type=str, default='logs', 193 | help='Directory name to save training logs') 194 | 195 | return check_args(parser.parse_args()) 196 | 197 | """checking arguments""" 198 | def check_args(args): 199 | # --checkpoint_dir 200 | check_folder(args.checkpoint_dir) 201 | 202 | # --result_dir 203 | check_folder(args.result_dir) 204 | 205 | # --result_dir 206 | check_folder(args.log_dir) 207 | 208 | # --epoch 209 | assert args.epoch >= 1, 'number of epochs must be larger than or equal to one' 210 | 211 | # --batch_size 212 | assert args.batch_size >= 1, 'batch size must be larger than or equal to one' 213 | 214 | # --z_dim 215 | assert args.z_dim >= 1, 'dimension of noise vector must be larger than or equal to one' 216 | 217 | return args 218 | 219 | 220 | def main(): 221 | args = parse_args() 222 | if args is None: 223 | exit() 224 | model = CVAE(args) 225 | model.train(load=True) 226 | 227 | 228 | if __name__ == '__main__': 229 | main() 230 | 231 | -------------------------------------------------------------------------------- /WGAN_GP.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import argparse 4 | import datetime 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import tensorflow as tf 8 | import scipy.misc 9 | from tensorflow import keras 10 | from tensorflow.keras import layers,optimizers,metrics 11 | 12 | from ops import * 13 | from utils import * 14 | 15 | 16 | # tf.debugging.set_log_device_placement(True) 17 | 18 | 19 | 20 | 21 | class WGAN_GP(): 22 | def __init__(self,args): 23 | super(WGAN_GP, self).__init__() 24 | self.model_name = args.gan_type 25 | self.batch_size = args.batch_size 26 | self.z_dim = args.z_dim 27 | self.lam=10. 
28 | self.checkpoint_dir = check_folder(os.path.join(args.checkpoint_dir,self.model_name))
29 | self.result_dir = args.result_dir
30 | self.datasets_name = args.datasets
31 | self.log_dir=args.log_dir
32 | self.learnning_rate=args.lr
33 | self.epoches=args.epoch
34 | self.datasets = load_mnist_data(datasets=self.datasets_name, model_name='WGAN_GP')
35 | self.g = self.make_generator_model(self.z_dim,is_training=True)
36 | self.d = self.make_discriminator_model(is_training=True)
37 | self.g_optimizer = keras.optimizers.RMSprop(lr=5*self.learnning_rate)
38 | self.d_optimizer = keras.optimizers.RMSprop(lr=self.learnning_rate)
39 | self.g_loss_metric = metrics.Mean('g_loss', dtype=tf.float32)
40 | self.critic_loss_metric = metrics.Mean('critic_loss', dtype=tf.float32)
41 | self.checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
42 | generator_optimizer=self.g_optimizer,
43 | discriminator_optimizer=self.d_optimizer,
44 | generator=self.g,
45 | discriminator=self.d)
46 | self.manager = tf.train.CheckpointManager(self.checkpoint, self.checkpoint_dir, max_to_keep=3)
47 |
48 |
49 |
50 | # the network is based on https://github.com/hwalsuklee/tensorflow-generative-model-collections
51 | def make_discriminator_model(self,is_training):
52 | model = tf.keras.Sequential()
53 | model.add(Conv2D(64,4,2))
54 | model.add(layers.LeakyReLU(alpha=0.2))
55 | model.add(Conv2D(128,4,2))
56 | model.add(BatchNorm(is_training=is_training))
57 | model.add(layers.LeakyReLU(alpha=0.2))
58 | model.add(layers.Flatten())
59 | model.add(DenseLayer(1024))
60 | model.add(BatchNorm(is_training=is_training))
61 | model.add(layers.LeakyReLU(alpha=0.2))
62 | model.add(DenseLayer(1))
63 | return model
64 |
65 |
66 | def make_generator_model(self,z_dim,is_training):
67 | model = tf.keras.Sequential()
68 | model.add(DenseLayer(1024))
69 | model.add(BatchNorm(is_training=is_training))
70 | model.add(keras.layers.ReLU())
71 | model.add(DenseLayer(128*7*7))
72 | model.add(BatchNorm(is_training=is_training))
73 | model.add(keras.layers.ReLU())
74 | model.add(layers.Reshape((7,7,128)))
75 | model.add(UpConv2D(64,4,2))
76 | model.add(BatchNorm(is_training=is_training))
77 | model.add(keras.layers.ReLU())
78 | model.add(UpConv2D(1,4,2))
79 | model.add(Tanh())
80 | return model
81 |
82 |
83 | @property
84 | def model_dir(self):
85 | return "{}_{}_{}_{}".format(
86 | self.model_name, self.datasets_name,
87 | self.batch_size, self.z_dim)
88 |
89 |
90 | # training for one batch
91 | @tf.function
92 | def train_one_step(self,batch_images):
93 | batch_z = tf.random.uniform([self.batch_size, self.z_dim], minval=-1., maxval=1., dtype=tf.float32)  # tf.random rather than np.random, so fresh noise is drawn on every call inside tf.function
94 | real_images = batch_images
95 | with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
96 | fake_imgs = self.g(batch_z, training=True)
97 | d_fake_logits= self.d(fake_imgs, training=True)
98 | d_real_logits= self.d(real_images, training=True)
99 | critic_loss=tf.reduce_mean(d_fake_logits-d_real_logits)
100 | g_loss = -tf.reduce_mean(d_fake_logits)
101 |
102 |
103 | # see https://tensorflow.google.cn/tutorials/eager/automatic_differentiation for the higher-order differentiation method in tensorflow
104 | # calculate the gradient of D(x) with respect to x, where x are the interpolated images
105 | with tf.GradientTape() as penalty_tape:
106 | alpha=tf.random.uniform([self.batch_size],0.,1.,dtype=tf.float32)
107 | alpha=tf.reshape(alpha,(-1,1,1,1))
108 | interpolated=real_images+alpha*(fake_imgs-real_images)
109 | penalty_tape.watch(interpolated)
110 | inter_logits=self.d(interpolated,training=False)
111 | gradient=penalty_tape.gradient(inter_logits,interpolated)
112 | grad_l2 = tf.sqrt(tf.reduce_sum(tf.square(gradient), axis=[1,2,3]))
113 | gradient_penalty = tf.reduce_mean((grad_l2-1)**2)
114 | critic_loss +=self.lam*gradient_penalty
115 |
116 |
117 | # calculate the gradients of the generator and discriminator losses
118 | gradients_of_d = d_tape.gradient(critic_loss, self.d.trainable_variables)
119 | gradients_of_g = g_tape.gradient(g_loss, self.g.trainable_variables)
120 | self.d_optimizer.apply_gradients(zip(gradients_of_d, self.d.trainable_variables))
121 | self.g_optimizer.apply_gradients(zip(gradients_of_g, self.g.trainable_variables))
122 | self.g_loss_metric(g_loss)
123 | self.critic_loss_metric(critic_loss)
124 |
125 |
126 | def train(self, load=False):
127 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
128 | train_log_dir = os.path.join(self.log_dir, self.model_name, current_time)
129 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir)
130 |
131 | # to load from a checkpoint, set the load flag to True
132 | if load:
133 | self.could_load = self.load_ckpt()
134 | ckpt_step = int(self.checkpoint.step)
135 | start_epoch=int((ckpt_step*self.batch_size)//60000)
136 | else:
137 | start_epoch=0
138 |
139 |
140 | for epoch in range(start_epoch,self.epoches):
141 | for batch_images, _ in self.datasets:
142 | self.train_one_step(batch_images)
143 | self.checkpoint.step.assign_add(1)
144 | step = int(self.checkpoint.step)
145 |
146 | # save generated images every 50 training batches
147 | if step % 50 == 0:
148 | sample_z = tf.random.uniform(minval=-1,maxval= 1, shape=(self.batch_size, self.z_dim),
149 | dtype=tf.dtypes.float32)
150 | manifold_h = int(np.floor(np.sqrt(self.batch_size)))
151 | manifold_w = int(np.floor(np.sqrt(self.batch_size)))
152 | print('step: {}, d_loss: {:.4f}, g_loss: {:.4f}'.format(step, self.critic_loss_metric.result(), self.g_loss_metric.result()))
153 | result_to_display = self.g(sample_z, training=False)
154 | save_images(result_to_display[:manifold_h * manifold_w, :, :, :],
155 | [manifold_h, manifold_w],
156 | './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, int(step)))
157 |
158 | with self.train_summary_writer.as_default():
159 | tf.summary.scalar('g_loss', self.g_loss_metric.result(), step=step)
160 | tf.summary.scalar('d_loss', self.critic_loss_metric.result(), step=step)
161 |
162 | # save a checkpoint every 400 training batches
163 | if step % 400 ==0:
164 | save_path = self.manager.save()
165 |
166 | print("\n----------Saved checkpoint for step {}: {}-------------\n".format(step, save_path))
167 |
168 | self.g_loss_metric.reset_states()
169 | self.critic_loss_metric.reset_states()
170 |
171 | def load_ckpt(self):
172 | self.checkpoint.restore(self.manager.latest_checkpoint)
173 | if self.manager.latest_checkpoint:
174 | print("restore model from checkpoint: {}".format(self.manager.latest_checkpoint))
175 | return True
176 |
177 | else:
178 | print("Initializing from scratch.")
179 | return False
180 |
181 |
182 |
183 |
184 |
185 | def parse_args():
186 | desc = "Tensorflow implementation of GAN collections"
187 | parser = argparse.ArgumentParser(description=desc)
188 | parser.add_argument('--gan_type', type=str, default='WGAN_GP')
189 | parser.add_argument('--datasets', type=str, default='fashion_mnist')
190 | parser.add_argument('--lr', type=float, default=2e-4)
191 | parser.add_argument('--epoch', type=int, default=20,
help='The number of epochs to run') 192 | parser.add_argument('--batch_size', type=int, default=64, help='The size of batch') 193 | parser.add_argument('--z_dim', type=int, default=62, help='Dimension of noise vector') 194 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 195 | help='Directory name to save the checkpoints') 196 | parser.add_argument('--result_dir', type=str, default='results', 197 | help='Directory name to save the generated images') 198 | parser.add_argument('--log_dir', type=str, default='logs', 199 | help='Directory name to save training logs') 200 | 201 | return check_args(parser.parse_args()) 202 | 203 | """checking arguments""" 204 | def check_args(args): 205 | # --checkpoint_dir 206 | check_folder(args.checkpoint_dir) 207 | 208 | # --result_dir 209 | check_folder(args.result_dir) 210 | 211 | # --result_dir 212 | check_folder(args.log_dir) 213 | 214 | # --epoch 215 | assert args.epoch >= 1, 'number of epochs must be larger than or equal to one' 216 | 217 | # --batch_size 218 | assert args.batch_size >= 1, 'batch size must be larger than or equal to one' 219 | 220 | # --z_dim 221 | assert args.z_dim >= 1, 'dimension of noise vector must be larger than or equal to one' 222 | 223 | return args 224 | 225 | 226 | def main(): 227 | args = parse_args() 228 | if args is None: 229 | exit() 230 | model = WGAN_GP(args) 231 | model.train(load=True) 232 | 233 | 234 | if __name__ == '__main__': 235 | main() 236 | 237 | -------------------------------------------------------------------------------- /BEGAN.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import argparse 3 | import datetime 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import scipy.misc 7 | import tensorflow as tf 8 | from tensorflow import keras 9 | from tensorflow.keras import layers, optimizers, metrics 10 | 11 | from ops import * 12 | from utils import * 13 | 14 | # the network is based on https://github.com/hwalsuklee/tensorflow-generative-model-collections 15 | class Discriminator(tf.keras.Model): 16 | def __init__(self, batch_size=64, is_training=True): 17 | super(Discriminator, self).__init__(name='discriminator') 18 | self.batch_size = batch_size 19 | self.is_training = is_training 20 | self.bn_1 = BatchNorm(is_training=self.is_training) 21 | self.bn_2 = BatchNorm(is_training=self.is_training) 22 | self.fc_1 = DenseLayer(32) 23 | self.fc_2 = DenseLayer(64*14*14) 24 | self.conv_1 = Conv2D(64, 4, 2) 25 | self.up_conv_1 = UpConv2D(1, 4, 2) 26 | 27 | def call(self, inputs, training): 28 | x = self.conv_1(inputs) 29 | x = layers.ReLU()(x) 30 | x = layers.Flatten()(x) 31 | x = self.fc_1(x) 32 | x = self.bn_1(x, training) 33 | code = layers.ReLU()(x) 34 | x = self.fc_2(code) 35 | x = self.bn_2(x, training) 36 | x = layers.ReLU()(x) 37 | x = layers.Reshape((14, 14, 64))(x) 38 | x = self.up_conv_1(x) 39 | out = Sigmoid()(x) 40 | recon_error = tf.math.sqrt(2 * tf.nn.l2_loss(out - inputs)) / self.batch_size 41 | return out, recon_error, code 42 | 43 | 44 | class Generator(tf.keras.Model): 45 | def __init__(self, is_training=True): 46 | super(Generator, self).__init__(name='generator') 47 | self.is_training = is_training 48 | self.fc_1 = DenseLayer(1024) 49 | self.fc_2 = DenseLayer(128*7*7) 50 | self.bn_1 = BatchNorm(is_training=self.is_training) 51 | self.bn_2 = BatchNorm(is_training=self.is_training) 52 | self.bn_3 = BatchNorm(is_training=self.is_training) 53 | self.up_conv_1 = UpConv2D(64, 4, 2) 54 | 
self.up_conv_2 = UpConv2D(1, 4, 2) 55 | 56 | def call(self, inputs, training): 57 | x = self.fc_1(inputs) 58 | x = self.bn_1(x, training) 59 | x = layers.ReLU()(x) 60 | x = self.fc_2(x) 61 | x = self.bn_2(x, training) 62 | x = layers.ReLU()(x) 63 | x = layers.Reshape((7, 7, 128))(x) 64 | x = self.up_conv_1(x) 65 | x = self.bn_3(x, training) 66 | x = layers.ReLU()(x) 67 | x = self.up_conv_2(x) 68 | x = Sigmoid()(x) 69 | return x 70 | 71 | 72 | class BEGAN(): 73 | def __init__(self, args): 74 | super(BEGAN, self).__init__() 75 | self.model_name = args.gan_type 76 | self.batch_size = args.batch_size 77 | self.z_dim = args.z_dim 78 | self.sample_z = tf.random.uniform(minval=-1., maxval=1., shape=( 79 | self.batch_size, self.z_dim), dtype=tf.float32) 80 | self.y_dim = 10 81 | # BEGAN Parameter 82 | self.gamma = 0.75 83 | self.lamda = 0.001 84 | self.k = tf.Variable(0.0,trainable=False) 85 | self.checkpoint_dir = check_folder(os.path.join(args.checkpoint_dir, self.model_name)) 86 | self.result_dir = args.result_dir 87 | self.datasets_name = args.datasets 88 | self.log_dir = args.log_dir 89 | self.learnning_rate = args.lr 90 | self.epoches = args.epoch 91 | self.datasets = load_mnist_data(datasets=self.datasets_name, batch_size=args.batch_size) 92 | self.g = Generator(is_training=True) 93 | self.d = Discriminator(is_training=True) 94 | self.g_optimizer = keras.optimizers.Adam(lr=5*self.learnning_rate, beta_1=0.5) 95 | self.d_optimizer = keras.optimizers.Adam(lr=self.learnning_rate, beta_1=0.5) 96 | self.g_loss_metric = tf.keras.metrics.Mean('g_loss', dtype=tf.float32) 97 | self.d_loss_metric = tf.keras.metrics.Mean('d_loss', dtype=tf.float32) 98 | self.checkpoint = tf.train.Checkpoint(step=tf.Variable(0), 99 | generator_optimizer=self.g_optimizer, 100 | discriminator_optimizer=self.d_optimizer, 101 | generator=self.g, 102 | discriminator=self.d) 103 | self.manager = tf.train.CheckpointManager(self.checkpoint, self.checkpoint_dir, max_to_keep=3) 104 | 105 | @property 106 | def model_dir(self): 107 | return "{}_{}_{}_{}".format( 108 | self.model_name, self.datasets_name, 109 | self.batch_size, self.z_dim) 110 | 111 | 112 | 113 | # train for one batch 114 | # @tf.function 115 | def train_one_step(self, batch_images,step): 116 | batch_z = tf.random.uniform(minval=-1.,maxval= 1.,shape=(self.batch_size, self.z_dim),dtype=tf.float32) 117 | real_images = batch_images 118 | with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape: 119 | D_real_img, D_real_err, D_real_code = self.d(batch_images, training=True) 120 | fake_imgs = self.g(batch_z, training=True) 121 | D_fake_img, D_fake_err, D_fake_code = self.d(fake_imgs, training=True) 122 | 123 | # get loss for discriminator 124 | self.d_loss = D_real_err - self.k*D_fake_err 125 | 126 | # get loss for generator 127 | self.g_loss = D_fake_err 128 | 129 | # convergence metric 130 | self.M = D_real_err + tf.math.abs(self.gamma*D_real_err - D_fake_err) 131 | 132 | 133 | gradients_of_d = d_tape.gradient(self.d_loss, self.d.trainable_variables) 134 | gradients_of_g = g_tape.gradient(self.g_loss, self.g.trainable_variables) 135 | self.d_optimizer.apply_gradients(zip(gradients_of_d, self.d.trainable_variables)) 136 | self.g_optimizer.apply_gradients(zip(gradients_of_g, self.g.trainable_variables)) 137 | self.k = tf.clip_by_value(self.k + self.lamda*(self.gamma*D_real_err - D_fake_err), 0, 1) 138 | 139 | # operation for updating k 140 | 141 | M_sum = tf.summary.scalar("M", self.M,step=step) 142 | k_sum = tf.summary.scalar("k", self.k,step=step) 143 | 
self.d_loss_metric(self.d_loss) 144 | self.g_loss_metric(self.g_loss) 145 | 146 | 147 | 148 | 149 | def train(self, load=False): 150 | current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 151 | train_log_dir = os.path.join(self.log_dir, self.model_name, current_time) 152 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir) 153 | self.could_load = self.load_ckpt() 154 | ckpt_step = int(self.checkpoint.step) 155 | start_epoch = int((ckpt_step*self.batch_size)//60000) 156 | 157 | for epoch in range(start_epoch, self.epoches): 158 | for batch_images, _ in self.datasets: 159 | 160 | self.train_one_step(batch_images,int(self.checkpoint.step)) 161 | self.checkpoint.step.assign_add(1) 162 | step = int(self.checkpoint.step) 163 | 164 | # save generated images for every 50 batches training 165 | if step % 50 == 0: 166 | print('step:{}, k:{:.4f}, d_loss: {:.4f}, g_loss: {:.4F}'.format( 167 | step, self.k, self.d_loss_metric.result(), self.g_loss_metric.result())) 168 | manifold_h = int(np.floor(np.sqrt(self.batch_size))) 169 | manifold_w = int(np.floor(np.sqrt(self.batch_size))) 170 | result_to_display = self.g(self.sample_z, training=False) 171 | save_images(result_to_display[:manifold_h * manifold_w, :, :, :], 172 | [manifold_h, manifold_w], 173 | './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, int(step))) 174 | 175 | with self.train_summary_writer.as_default(): 176 | tf.summary.scalar('g_loss', self.g_loss_metric.result(), step=step) 177 | tf.summary.scalar('d_loss', self.d_loss_metric.result(), step=step) 178 | 179 | #save checkpoints for every 400 batches training 180 | if step % 400 == 0: 181 | save_path = self.manager.save() 182 | print("\n----------Saved checkpoint for step {}: {}-----------\n".format(step, save_path)) 183 | self.g_loss_metric.reset_states() 184 | self.d_loss_metric.reset_states() 185 | 186 | def load_ckpt(self): 187 | self.checkpoint.restore(self.manager.latest_checkpoint) 188 | if self.manager.latest_checkpoint: 189 | print("restore model from checkpoint: {}".format(self.manager.latest_checkpoint)) 190 | return True 191 | 192 | else: 193 | print("Initializing from scratch.") 194 | return False 195 | 196 | 197 | def parse_args(): 198 | desc = "Tensorflow implementation of GAN collections" 199 | parser = argparse.ArgumentParser(description=desc) 200 | parser.add_argument('--gan_type', type=str, default='BEGAN') 201 | parser.add_argument('--datasets', type=str, default='fashion_mnist') 202 | parser.add_argument('--lr', type=float, default=2e-4) 203 | parser.add_argument('--epoch', type=int, default=20, 204 | help='The number of epochs to run') 205 | parser.add_argument('--batch_size', type=int, 206 | default=64, help='The size of batch') 207 | parser.add_argument('--z_dim', type=int, default=62, 208 | help='Dimension of noise vector') 209 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 210 | help='Directory name to save the checkpoints') 211 | parser.add_argument('--result_dir', type=str, default='results', 212 | help='Directory name to save the generated images') 213 | parser.add_argument('--log_dir', type=str, default='logs', 214 | help='Directory name to save training logs') 215 | 216 | return check_args(parser.parse_args()) 217 | 218 | 219 | """checking arguments""" 220 | 221 | 222 | def check_args(args): 223 | # --checkpoint_dir 224 | check_folder(args.checkpoint_dir) 225 | 226 | # --result_dir 227 | check_folder(args.result_dir) 228 | 
229 | # --result_dir 230 | check_folder(args.log_dir) 231 | 232 | # --epoch 233 | assert args.epoch >= 1, 'number of epochs must be larger than or equal to one' 234 | 235 | # --batch_size 236 | assert args.batch_size >= 1, 'batch size must be larger than or equal to one' 237 | 238 | # --z_dim 239 | assert args.z_dim >= 1, 'dimension of noise vector must be larger than or equal to one' 240 | 241 | return args 242 | 243 | 244 | def main(): 245 | args = parse_args() 246 | if args is None: 247 | exit() 248 | model = BEGAN(args) 249 | model.train(load=True) 250 | 251 | 252 | if __name__ == '__main__': 253 | main() 254 | -------------------------------------------------------------------------------- /EBGAN.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import argparse 3 | import datetime 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import scipy.misc 7 | import tensorflow as tf 8 | from tensorflow import keras 9 | from tensorflow.keras import layers, optimizers, metrics 10 | 11 | from ops import * 12 | from utils import * 13 | 14 | # the network is based on https://github.com/hwalsuklee/tensorflow-generative-model-collections 15 | class Discriminator(tf.keras.Model): 16 | def __init__(self, batch_size=64, is_training=True): 17 | super(Discriminator, self).__init__(name='discriminator') 18 | self.batch_size = batch_size 19 | self.is_training = is_training 20 | self.bn_1 = BatchNorm(is_training=self.is_training) 21 | self.bn_2 = BatchNorm(is_training=self.is_training) 22 | self.fc_1 = DenseLayer(32) 23 | self.fc_2 = DenseLayer(64*14*14) 24 | self.conv_1 = Conv2D(64, 4, 2) 25 | self.up_conv_1 = UpConv2D(1, 4, 2) 26 | 27 | def call(self, inputs, training): 28 | x = self.conv_1(inputs) 29 | x = layers.ReLU()(x) 30 | x = layers.Flatten()(x) 31 | code = self.fc_1(x) 32 | x = self.fc_2(code) 33 | x = self.bn_1(x, training) 34 | x = layers.ReLU()(x) 35 | x = layers.Reshape((14, 14, 64))(x) 36 | x = self.up_conv_1(x) 37 | out = Sigmoid()(x) 38 | recon_error = tf.math.sqrt(2 * tf.nn.l2_loss(out - inputs)) / self.batch_size 39 | return out, recon_error, code 40 | 41 | 42 | class Generator(tf.keras.Model): 43 | def __init__(self, is_training=True): 44 | super(Generator, self).__init__(name='generator') 45 | self.is_training = is_training 46 | self.fc_1 = DenseLayer(1024) 47 | self.fc_2 = DenseLayer(128*7*7) 48 | self.bn_1 = BatchNorm(is_training=self.is_training) 49 | self.bn_2 = BatchNorm(is_training=self.is_training) 50 | self.bn_3 = BatchNorm(is_training=self.is_training) 51 | self.up_conv_1 = UpConv2D(64, 4, 2) 52 | self.up_conv_2 = UpConv2D(1, 4, 2) 53 | 54 | def call(self, inputs, training): 55 | x = self.fc_1(inputs) 56 | x = self.bn_1(x, training) 57 | x = layers.ReLU()(x) 58 | x = self.fc_2(x) 59 | x = self.bn_2(x, training) 60 | x = layers.ReLU()(x) 61 | x = layers.Reshape((7, 7, 128))(x) 62 | x = self.up_conv_1(x) 63 | x = self.bn_3(x, training) 64 | x = layers.ReLU()(x) 65 | x = self.up_conv_2(x) 66 | x = Sigmoid()(x) 67 | return x 68 | 69 | 70 | class EBGAN(): 71 | def __init__(self, args): 72 | super(EBGAN, self).__init__() 73 | self.model_name = args.gan_type 74 | self.batch_size = args.batch_size 75 | self.z_dim = args.z_dim 76 | self.sample_z = tf.random.uniform(minval=-1., maxval=1., shape=( 77 | self.batch_size, self.z_dim), dtype=tf.dtypes.float32) 78 | self.y_dim = 10 79 | # margin for loss function 80 | self.margin = max(1, self.batch_size/64.) 
81 | self.pt_loss_weight = 0.1 82 | # self.batch_size = 36 83 | self.checkpoint_dir = check_folder(os.path.join(args.checkpoint_dir, self.model_name)) 84 | self.result_dir = args.result_dir 85 | self.datasets_name = args.datasets 86 | self.log_dir = args.log_dir 87 | self.learnning_rate = args.lr 88 | self.epoches = args.epoch 89 | self.datasets = load_mnist_data(datasets=self.datasets_name, batch_size=args.batch_size) 90 | self.g = Generator(is_training=True) 91 | self.d = Discriminator(is_training=True) 92 | self.g_optimizer = keras.optimizers.Adam(lr=5*self.learnning_rate, beta_1=0.5) 93 | self.d_optimizer = keras.optimizers.Adam(lr=self.learnning_rate, beta_1=0.5) 94 | self.g_loss_metric = tf.keras.metrics.Mean('g_loss', dtype=tf.float32) 95 | self.d_loss_metric = tf.keras.metrics.Mean('d_loss', dtype=tf.float32) 96 | self.checkpoint = tf.train.Checkpoint(step=tf.Variable(0), 97 | generator_optimizer=self.g_optimizer, 98 | discriminator_optimizer=self.d_optimizer, 99 | generator=self.g, 100 | discriminator=self.d) 101 | self.manager = tf.train.CheckpointManager(self.checkpoint, self.checkpoint_dir, max_to_keep=3) 102 | 103 | @property 104 | def model_dir(self): 105 | return "{}_{}_{}_{}".format( 106 | self.model_name, self.datasets_name, 107 | self.batch_size, self.z_dim) 108 | 109 | def pullaway_loss(self, embeddings): 110 | """ 111 | Pull Away loss calculation 112 | :param embeddings: The embeddings to be orthogonalized for varied faces. Shape [batch_size, embeddings_dim] 113 | :return: pull away term loss 114 | """ 115 | norm = tf.sqrt(tf.math.reduce_sum(tf.square(embeddings), 1, keepdims=True)) 116 | normalized_embeddings = embeddings / norm 117 | similarity = tf.matmul(normalized_embeddings,normalized_embeddings, transpose_b=True) 118 | batch_size = tf.cast(tf.shape(embeddings)[0], tf.float32) 119 | pt_loss = (tf.reduce_sum(similarity) - batch_size) / (batch_size * (batch_size - 1)) 120 | return pt_loss 121 | 122 | # train for one batch 123 | @tf.function 124 | def train_one_step(self, batch_images): 125 | batch_z = tf.random.uniform(minval=-1,maxval= 1,shape=(self.batch_size, self.z_dim),dtype=tf.dtypes.float32) 126 | real_images = batch_images 127 | with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape: 128 | D_real_img, D_real_err, D_real_code = self.d(batch_images, training=True) 129 | fake_imgs = self.g(batch_z, training=True) 130 | D_fake_img, D_fake_err, D_fake_code = self.d(fake_imgs, training=True) 131 | 132 | # get loss for discriminator 133 | d_loss = D_real_err + tf.maximum(self.margin - D_fake_err, 0) 134 | 135 | # get loss for generator 136 | g_loss = D_fake_err + self.pt_loss_weight * self.pullaway_loss(D_fake_code) 137 | 138 | gradients_of_d = d_tape.gradient(d_loss, self.d.trainable_variables) 139 | gradients_of_g = g_tape.gradient(g_loss, self.g.trainable_variables) 140 | self.d_optimizer.apply_gradients(zip(gradients_of_d, self.d.trainable_variables)) 141 | self.g_optimizer.apply_gradients(zip(gradients_of_g, self.g.trainable_variables)) 142 | 143 | self.d_loss_metric(d_loss) 144 | self.g_loss_metric(g_loss) 145 | 146 | 147 | 148 | def train(self, load=False): 149 | current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 150 | train_log_dir = os.path.join(self.log_dir, self.model_name, current_time) 151 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir) 152 | self.could_load = self.load_ckpt() 153 | ckpt_step = int(self.checkpoint.step) 154 | start_epoch = int((ckpt_step*self.batch_size)//60000) 155 | 156 | for epoch in 
range(start_epoch, self.epoches): 157 | for batch_images, _ in self.datasets: 158 | self.train_one_step(batch_images) 159 | self.checkpoint.step.assign_add(1) 160 | step = int(self.checkpoint.step) 161 | 162 | # save generated images for every 50 batches training 163 | if step % 50 == 0: 164 | print('step: {}, d_loss: {:.4f}, g_loss: {:.4F}'.format( 165 | step, self.d_loss_metric.result(), self.g_loss_metric.result())) 166 | manifold_h = int(np.floor(np.sqrt(self.batch_size))) 167 | manifold_w = int(np.floor(np.sqrt(self.batch_size))) 168 | result_to_display = self.g(self.sample_z, training=False) 169 | save_images(result_to_display[:manifold_h * manifold_w, :, :, :], 170 | [manifold_h, manifold_w], 171 | './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, int(step))) 172 | 173 | with self.train_summary_writer.as_default(): 174 | tf.summary.scalar('g_loss', self.g_loss_metric.result(), step=step) 175 | tf.summary.scalar('d_loss', self.d_loss_metric.result(), step=step) 176 | 177 | #save checkpoints for every 400 batches training 178 | if step % 400 == 0: 179 | save_path = self.manager.save() 180 | print("\n----------Saved checkpoint for step {}: {}-----------\n".format(step, save_path)) 181 | self.g_loss_metric.reset_states() 182 | self.d_loss_metric.reset_states() 183 | 184 | def load_ckpt(self): 185 | self.checkpoint.restore(self.manager.latest_checkpoint) 186 | if self.manager.latest_checkpoint: 187 | print("restore model from checkpoint: {}".format(self.manager.latest_checkpoint)) 188 | return True 189 | 190 | else: 191 | print("Initializing from scratch.") 192 | return False 193 | 194 | 195 | def parse_args(): 196 | desc = "Tensorflow implementation of GAN collections" 197 | parser = argparse.ArgumentParser(description=desc) 198 | parser.add_argument('--gan_type', type=str, default='EBGAN') 199 | parser.add_argument('--datasets', type=str, default='mnist') 200 | parser.add_argument('--lr', type=float, default=2e-4) 201 | parser.add_argument('--epoch', type=int, default=20, 202 | help='The number of epochs to run') 203 | parser.add_argument('--batch_size', type=int, 204 | default=64, help='The size of batch') 205 | parser.add_argument('--z_dim', type=int, default=62, 206 | help='Dimension of noise vector') 207 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 208 | help='Directory name to save the checkpoints') 209 | parser.add_argument('--result_dir', type=str, default='results', 210 | help='Directory name to save the generated images') 211 | parser.add_argument('--log_dir', type=str, default='logs', 212 | help='Directory name to save training logs') 213 | 214 | return check_args(parser.parse_args()) 215 | 216 | 217 | """checking arguments""" 218 | 219 | 220 | def check_args(args): 221 | # --checkpoint_dir 222 | check_folder(args.checkpoint_dir) 223 | 224 | # --result_dir 225 | check_folder(args.result_dir) 226 | 227 | # --result_dir 228 | check_folder(args.log_dir) 229 | 230 | # --epoch 231 | assert args.epoch >= 1, 'number of epochs must be larger than or equal to one' 232 | 233 | # --batch_size 234 | assert args.batch_size >= 1, 'batch size must be larger than or equal to one' 235 | 236 | # --z_dim 237 | assert args.z_dim >= 1, 'dimension of noise vector must be larger than or equal to one' 238 | 239 | return args 240 | 241 | 242 | def main(): 243 | args = parse_args() 244 | if args is None: 245 | exit() 246 | model = EBGAN(args) 247 | model.train(load=True) 248 | 249 | 250 | if 
__name__ == '__main__': 251 | main() 252 | -------------------------------------------------------------------------------- /ACGAN.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import argparse 3 | import datetime 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import scipy.misc 7 | import tensorflow as tf 8 | from tensorflow import keras 9 | from tensorflow.keras import layers, optimizers, metrics 10 | 11 | from ops import * 12 | from utils import * 13 | 14 | 15 | # the network is based on https://github.com/hwalsuklee/tensorflow-generative-model-collections 16 | class Discriminator(tf.keras.Model): 17 | def __init__(self, is_training=True): 18 | super(Discriminator, self).__init__(name='discriminator') 19 | self.is_training = is_training 20 | self.conv_1 = Conv2D(64, 4, 2) 21 | self.conv_2 = Conv2D(128, 4, 2) 22 | self.bn_1=BatchNorm(is_training=self.is_training) 23 | self.bn_2=BatchNorm(is_training=self.is_training) 24 | self.fc_1 = DenseLayer(1024) 25 | self.fc_2 = DenseLayer(1) 26 | 27 | def call(self, inputs, training): 28 | x = self.conv_1(inputs) 29 | x = layers.LeakyReLU(alpha=0.2)(x) 30 | x = self.conv_2(x) 31 | x = self.bn_1(x, training) 32 | x = layers.LeakyReLU(alpha=0.2)(x) 33 | x = layers.Flatten()(x) 34 | x = self.fc_1(x) 35 | x = self.bn_2(x, training) 36 | x = layers.LeakyReLU(alpha=0.2)(x) 37 | out_logits = self.fc_2(x) 38 | out = keras.activations.sigmoid(out_logits) 39 | return out, out_logits, x 40 | 41 | 42 | class Generator(tf.keras.Model): 43 | def __init__(self, is_training=True): 44 | super(Generator, self).__init__(name='generator') 45 | self.is_training = is_training 46 | self.fc_1 = DenseLayer(1024) 47 | self.fc_2 = DenseLayer(128*7*7) 48 | self.bn_1=BatchNorm(is_training=self.is_training) 49 | self.bn_2=BatchNorm(is_training=self.is_training) 50 | self.bn_3=BatchNorm(is_training=self.is_training) 51 | self.up_conv_1 = UpConv2D(64, 4, 2) 52 | self.up_conv_2 = UpConv2D(1, 4, 2) 53 | 54 | def call(self, inputs, training): 55 | x = self.fc_1(inputs) 56 | x = self.bn_1(x, training) 57 | x = layers.ReLU()(x) 58 | x = self.fc_2(x) 59 | x = self.bn_2(x, training) 60 | x = layers.ReLU()(x) 61 | x = layers.Reshape((7, 7, 128))(x) 62 | x = self.up_conv_1(x) 63 | x = self.bn_3(x, training) 64 | x = layers.ReLU()(x) 65 | x = self.up_conv_2(x) 66 | x = keras.activations.sigmoid(x) 67 | return x 68 | 69 | 70 | class Classifier(tf.keras.Model): 71 | def __init__(self, y_dim, is_training=True): 72 | super(Classifier, self).__init__(name='classifier') 73 | self.is_training = is_training 74 | self.y_dim = y_dim 75 | self.fc_1 = DenseLayer(64) 76 | self.fc_2 = DenseLayer(self.y_dim) 77 | self.bn_1=BatchNorm(is_training=self.is_training) 78 | 79 | def call(self, inputs, training): 80 | x = self.fc_1(inputs) 81 | x = self.bn_1(x, training) 82 | x = layers.LeakyReLU(alpha=0.2)(x) 83 | out_logits = self.fc_2(x) 84 | out=keras.layers.Softmax()(out_logits) 85 | return out, out_logits 86 | 87 | 88 | class ACGAN(): 89 | def __init__(self, args): 90 | super(ACGAN, self).__init__() 91 | self.model_name = args.gan_type 92 | self.batch_size = args.batch_size 93 | self.z_dim = args.z_dim 94 | self.sample_z = np.random.uniform(-1, 1,size=(self.batch_size, self.z_dim)) 95 | self.y_dim = 10 96 | self.checkpoint_dir = check_folder(os.path.join(args.checkpoint_dir, self.model_name)) 97 | self.result_dir = args.result_dir 98 | self.datasets_name = args.datasets 99 | self.log_dir = args.log_dir 100 | 
self.learnning_rate = args.lr
101 | self.epoches = args.epoch
102 | self.datasets = load_mnist_data(datasets=self.datasets_name,batch_size=args.batch_size)
103 | self.g = Generator(is_training=True)
104 | self.d = Discriminator(is_training=True)
105 | self.c = Classifier(self.y_dim, is_training=True)
106 | self.g_optimizer = keras.optimizers.Adam(lr=5*self.learnning_rate, beta_1=0.5)
107 | self.d_optimizer = keras.optimizers.Adam(lr=self.learnning_rate, beta_1=0.5)
108 | self.q_optimizer = keras.optimizers.Adam(lr=5*self.learnning_rate, beta_1=0.5)
109 | self.g_loss_metric = tf.keras.metrics.Mean('g_loss', dtype=tf.float32)
110 | self.d_loss_metric = tf.keras.metrics.Mean('d_loss', dtype=tf.float32)
111 | self.q_loss_metric = tf.keras.metrics.Mean('q_loss', dtype=tf.float32)
112 | self.checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
113 | generator_optimizer=self.g_optimizer,
114 | discriminator_optimizer=self.d_optimizer,
115 | classifier_optimizer=self.q_optimizer,
116 | generator=self.g,
117 | discriminator=self.d,
118 | classifier=self.c)
119 | self.manager = tf.train.CheckpointManager(self.checkpoint, self.checkpoint_dir, max_to_keep=3)
120 |
121 | @property
122 | def model_dir(self):
123 | return "{}_{}_{}_{}".format(
124 | self.model_name, self.datasets_name,
125 | self.batch_size, self.z_dim)
126 |
127 | def d_loss_fun(self, d_fake_logits, d_real_logits):
128 | d_loss_real = tf.reduce_mean(
129 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_real_logits), logits=d_real_logits))
130 | d_loss_fake = tf.reduce_mean(
131 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_fake_logits), logits=d_fake_logits))
132 | total_loss = d_loss_fake+d_loss_real
133 | return total_loss
134 |
135 | def g_loss_fun(self, logits):
136 | g_loss = tf.reduce_mean(
137 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits))
138 | return g_loss
139 |
140 | def q_loss_fun(self, code_logit_real, code_logit_fake,batch_labels):
141 | q_real_loss = tf.reduce_mean(
142 | tf.nn.sigmoid_cross_entropy_with_logits(labels=batch_labels, logits=code_logit_real))
143 | q_fake_loss = tf.reduce_mean(
144 | tf.nn.sigmoid_cross_entropy_with_logits(labels=batch_labels, logits=code_logit_fake))
145 | q_loss = q_real_loss+q_fake_loss
146 | return q_loss
147 |
148 |
149 |
150 | # train for one batch
151 | @tf.function
152 | def train_one_step(self, batch_labels, batch_images):
153 | noises = tf.random.uniform([self.batch_size, self.z_dim], minval=-1., maxval=1., dtype=tf.float32)  # tf.random rather than np.random, so fresh noise is drawn on every call inside tf.function
154 | batch_z = tf.concat([noises, batch_labels], 1)
155 | real_images = batch_images
156 | with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape, tf.GradientTape() as q_tape:
157 | fake_imgs = self.g(batch_z, training=True)
158 | _, d_fake_logits, input4classifier_fake = self.d(fake_imgs, training=True)
159 | _, d_real_logits, input4classifier_real = self.d(real_images, training=True)
160 | d_loss = self.d_loss_fun(d_fake_logits, d_real_logits)
161 | g_loss = self.g_loss_fun(d_fake_logits)
162 | code_fake, code_logit_fake = self.c(input4classifier_fake, training=True)
163 | code_real, code_logit_real = self.c(input4classifier_real, training=True)
164 | q_loss=self.q_loss_fun(code_logit_real, code_logit_fake,batch_labels)
165 |
166 |
167 |
168 | gradients_of_d = d_tape.gradient(d_loss, self.d.trainable_variables)
169 | gradients_of_g = g_tape.gradient(g_loss, self.g.trainable_variables)
170 | # the q loss backpropagates into all trainable variables (classifier, discriminator and generator)
171 |
172 | 
trainable_variables_q=self.c.trainable_variables+self.d.trainable_variables+self.g.trainable_variables 173 | gradients_q = q_tape.gradient(q_loss,trainable_variables_q) 174 | 175 | 176 | self.d_optimizer.apply_gradients(zip(gradients_of_d, self.d.trainable_variables)) 177 | self.g_optimizer.apply_gradients(zip(gradients_of_g, self.g.trainable_variables)) 178 | self.q_optimizer.apply_gradients(zip(gradients_q, trainable_variables_q)) 179 | 180 | self.g_loss_metric(g_loss) 181 | self.d_loss_metric(d_loss) 182 | self.q_loss_metric(q_loss) 183 | 184 | 185 | def train(self, load=False): 186 | current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 187 | train_log_dir = os.path.join(self.log_dir, self.model_name, current_time) 188 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir) 189 | self.could_load = self.load_ckpt() 190 | ckpt_step = int(self.checkpoint.step) 191 | start_epoch = int((ckpt_step*self.batch_size)//60000) 192 | 193 | for epoch in range(start_epoch, self.epoches): 194 | for batch_images, batch_labels in self.datasets: 195 | if int(self.checkpoint.step) == 0 or self.could_load: 196 | self.could_load = False 197 | self.test_labels = batch_labels[0:self.batch_size] 198 | self.train_one_step(batch_labels, batch_images) 199 | self.checkpoint.step.assign_add(1) 200 | step = int(self.checkpoint.step) 201 | 202 | # save generated images for every 50 batches training 203 | if step % 50 == 0: 204 | print('step: {}, d_loss: {:.4f}, g_loss: {:.4F}, q_loss: {:.4F}'.format( 205 | step, self.d_loss_metric.result(), self.g_loss_metric.result(),self.q_loss_metric.result())) 206 | manifold_h = int(np.floor(np.sqrt(self.batch_size))) 207 | manifold_w = int(np.floor(np.sqrt(self.batch_size))) 208 | self.batch_z_to_disply = tf.concat([self.sample_z, self.test_labels[:self.batch_size, :]], 1) 209 | result_to_display = self.g(self.batch_z_to_disply, training=False) 210 | save_images(result_to_display[:manifold_h * manifold_w, :, :, :], 211 | [manifold_h, manifold_w], 212 | './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, int(step))) 213 | 214 | with self.train_summary_writer.as_default(): 215 | tf.summary.scalar('g_loss', self.g_loss_metric.result(), step=step) 216 | tf.summary.scalar('d_loss', self.d_loss_metric.result(), step=step) 217 | tf.summary.scalar('q_loss', self.q_loss_metric.result(), step=step) 218 | 219 | 220 | #save checkpoints for every 400 batches training 221 | if step % 400 == 0: 222 | save_path = self.manager.save() 223 | print("\n----------Saved checkpoint for step {}: {}-----------\n".format(step, save_path)) 224 | self.g_loss_metric.reset_states() 225 | self.d_loss_metric.reset_states() 226 | self.q_loss_metric.reset_states() 227 | 228 | def load_ckpt(self): 229 | self.checkpoint.restore(self.manager.latest_checkpoint) 230 | if self.manager.latest_checkpoint: 231 | print("restore model from checkpoint: {}".format( 232 | self.manager.latest_checkpoint)) 233 | return True 234 | 235 | else: 236 | print("Initializing from scratch.") 237 | return False 238 | 239 | 240 | def parse_args(): 241 | desc = "Tensorflow implementation of GAN collections" 242 | parser = argparse.ArgumentParser(description=desc) 243 | parser.add_argument('--gan_type', type=str, default='ACGAN') 244 | parser.add_argument('--datasets', type=str, default='mnist') 245 | parser.add_argument('--lr', type=float, default=2e-4) 246 | parser.add_argument('--epoch', type=int, default=20, 247 | help='The number of 
epochs to run') 248 | parser.add_argument('--batch_size', type=int, 249 | default=64, help='The size of batch') 250 | parser.add_argument('--z_dim', type=int, default=62, 251 | help='Dimension of noise vector') 252 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 253 | help='Directory name to save the checkpoints') 254 | parser.add_argument('--result_dir', type=str, default='results', 255 | help='Directory name to save the generated images') 256 | parser.add_argument('--log_dir', type=str, default='logs', 257 | help='Directory name to save training logs') 258 | 259 | return check_args(parser.parse_args()) 260 | 261 | 262 | """checking arguments""" 263 | 264 | 265 | def check_args(args): 266 | # --checkpoint_dir 267 | check_folder(args.checkpoint_dir) 268 | 269 | # --result_dir 270 | check_folder(args.result_dir) 271 | 272 | # --result_dir 273 | check_folder(args.log_dir) 274 | 275 | # --epoch 276 | assert args.epoch >= 1, 'number of epochs must be larger than or equal to one' 277 | 278 | # --batch_size 279 | assert args.batch_size >= 1, 'batch size must be larger than or equal to one' 280 | 281 | # --z_dim 282 | assert args.z_dim >= 1, 'dimension of noise vector must be larger than or equal to one' 283 | 284 | return args 285 | 286 | 287 | def main(): 288 | args = parse_args() 289 | if args is None: 290 | exit() 291 | model = ACGAN(args) 292 | model.train(load=True) 293 | 294 | 295 | if __name__ == '__main__': 296 | main() 297 | -------------------------------------------------------------------------------- /infoGAN.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import argparse 3 | import datetime 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import scipy.misc 7 | import tensorflow as tf 8 | from tensorflow import keras 9 | from tensorflow.keras import layers, optimizers, metrics 10 | 11 | from ops import * 12 | from utils import * 13 | 14 | # the network is based on https://github.com/hwalsuklee/tensorflow-generative-model-collections 15 | class Discriminator(tf.keras.Model): 16 | def __init__(self, is_training=True): 17 | super(Discriminator, self).__init__(name='discriminator') 18 | self.is_training = is_training 19 | self.conv_1 = Conv2D(64, 4, 2) 20 | self.conv_2 = Conv2D(128, 4, 2) 21 | self.bn_1 = BatchNorm(is_training=self.is_training) 22 | self.bn_2 = BatchNorm(is_training=self.is_training) 23 | self.fc_1 = DenseLayer(1024) 24 | self.fc_2 = DenseLayer(1) 25 | 26 | def call(self, inputs, training): 27 | x = self.conv_1(inputs) 28 | x = layers.LeakyReLU(alpha=0.2)(x) 29 | x = self.conv_2(x) 30 | x = self.bn_1(x, training) 31 | x = layers.LeakyReLU(alpha=0.2)(x) 32 | x = layers.Flatten()(x) 33 | x = self.fc_1(x) 34 | x = self.bn_2(x, training) 35 | x = layers.LeakyReLU(alpha=0.2)(x) 36 | out_logits = self.fc_2(x) 37 | out = keras.activations.sigmoid(out_logits) 38 | 39 | return out, out_logits, x 40 | 41 | 42 | class Generator(tf.keras.Model): 43 | def __init__(self, is_training=True): 44 | super(Generator, self).__init__(name='generator') 45 | self.is_training = is_training 46 | self.bn_1 = BatchNorm(is_training=self.is_training) 47 | self.bn_2 = BatchNorm(is_training=self.is_training) 48 | self.bn_3 = BatchNorm(is_training=self.is_training) 49 | self.fc_1 = DenseLayer(1024) 50 | self.fc_2 = DenseLayer(128*7*7) 51 | self.up_conv_1 = UpConv2D(64, 4, 2) 52 | self.up_conv_2 = UpConv2D(1, 4, 2) 53 | 54 | def call(self, inputs, training): 55 | x = self.fc_1(inputs) 56 
| x = self.bn_1(x, training) 57 | x = layers.ReLU()(x) 58 | x = self.fc_2(x) 59 | x = self.bn_2(x, training) 60 | x = layers.ReLU()(x) 61 | x = layers.Reshape((7, 7, 128))(x) 62 | x = self.up_conv_1(x) 63 | x = self.bn_3(x, training) 64 | x = layers.ReLU()(x) 65 | x = self.up_conv_2(x) 66 | x = keras.activations.sigmoid(x) 67 | return x 68 | 69 | 70 | class Classifier(tf.keras.Model): 71 | def __init__(self, y_dim, is_training=True): 72 | super(Classifier, self).__init__(name='classifier') 73 | self.is_training = is_training 74 | self.y_dim = y_dim 75 | self.bn_1 = BatchNorm(is_training=self.is_training) 76 | self.fc_1 = DenseLayer(64) 77 | self.fc_2 = DenseLayer(self.y_dim) 78 | 79 | def call(self, inputs, training): 80 | x = self.fc_1(inputs) 81 | x = self.bn_1(x, training) 82 | x = layers.LeakyReLU(alpha=0.2)(x) 83 | out_logits = self.fc_2(x) 84 | out = keras.layers.Softmax()(out_logits) 85 | return out, out_logits 86 | 87 | 88 | class infoGAN(): 89 | def __init__(self, args): 90 | super(infoGAN, self).__init__() 91 | self.model_name = args.gan_type 92 | self.batch_size = args.batch_size 93 | self.SUPERVISED = True # if it is true, label info is directly used for code 94 | self.z_dim = args.z_dim 95 | self.y_dim = 12 96 | self.len_discrete_code = 10 # categorical distribution (i.e. label) 97 | self.len_continuous_code = 2 # gaussian distribution (e.g. rotation, thickness) 98 | self.checkpoint_dir = os.path.join(args.checkpoint_dir, self.model_name) 99 | self.result_dir = args.result_dir 100 | self.datasets_name = args.datasets 101 | self.log_dir = args.log_dir 102 | self.learnning_rate = args.lr 103 | self.epoches = args.epoch 104 | self.sample_z = tf.random.uniform(minval=-1, maxval=1, shape=(self.batch_size, self.z_dim), 105 | dtype=tf.dtypes.float32) 106 | self.datasets = load_mnist_data(batch_size=self.batch_size, datasets=self.datasets_name) 107 | self.g = Generator(is_training=True) 108 | self.d = Discriminator(is_training=True) 109 | self.c = Classifier(y_dim=self.y_dim, is_training=True) 110 | self.g_optimizer = keras.optimizers.Adam(lr=5*self.learnning_rate, beta_1=0.5) 111 | self.d_optimizer = keras.optimizers.Adam(lr=self.learnning_rate, beta_1=0.5) 112 | self.q_optimizer = keras.optimizers.Adam(lr=5*self.learnning_rate, beta_1=0.5) 113 | self.g_loss_metric = tf.keras.metrics.Mean('g_loss', dtype=tf.float32) 114 | self.d_loss_metric = tf.keras.metrics.Mean('d_loss', dtype=tf.float32) 115 | self.q_loss_metric = tf.keras.metrics.Mean('q_loss', dtype=tf.float32) 116 | self.checkpoint = tf.train.Checkpoint(step=tf.Variable(0), 117 | generator_optimizer=self.g_optimizer, 118 | discriminator_optimizer=self.d_optimizer, 119 | classifier_optimizer=self.q_optimizer, 120 | generator=self.g, 121 | discriminator=self.d, 122 | classifier=self.c) 123 | self.manager = tf.train.CheckpointManager(self.checkpoint, self.checkpoint_dir, max_to_keep=3) 124 | 125 | @property 126 | def model_dir(self): 127 | return "{}_{}_{}_{}".format( 128 | self.model_name, self.datasets_name, 129 | self.batch_size, self.z_dim) 130 | 131 | def d_loss_fun(self, d_fake_logits, d_real_logits): 132 | d_loss_real = tf.reduce_mean( 133 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(d_real_logits), logits=d_real_logits)) 134 | d_loss_fake = tf.reduce_mean( 135 | tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(d_fake_logits), logits=d_fake_logits)) 136 | total_loss = d_loss_fake+d_loss_real 137 | return total_loss 138 | 139 | def g_loss_fun(self, logits): 140 | g_loss = tf.reduce_mean( 141 | 
tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits)) 142 | return g_loss 143 | 144 | def q_loss_fun(self, disc_code_est, disc_code_tg, cont_code_est, cont_code_tg): 145 | q_disc_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 146 | labels=disc_code_tg, logits=disc_code_est)) 147 | q_cont_loss = tf.reduce_mean(tf.reduce_sum( 148 | tf.square(cont_code_tg - cont_code_est), axis=1)) 149 | q_loss = q_disc_loss+q_cont_loss 150 | return q_loss 151 | 152 | 153 | # training for one batch 154 | @tf.function 155 | def train_one_step(self, batch_labels, batch_images): 156 | noises = tf.random.uniform(shape=(self.batch_size, self.z_dim), minval=-1, maxval=1, dtype=tf.dtypes.float32) 157 | code = tf.random.uniform(minval=-1, maxval=1, shape=(self.batch_size, self.len_continuous_code), dtype=tf.dtypes.float32) 158 | batch_codes = tf.concat((batch_labels, code), axis=1) 159 | batch_z = tf.concat([noises, batch_codes], 1) 160 | real_images = conv_cond_concat(batch_images, batch_codes) 161 | with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape, tf.GradientTape() as q_tape: 162 | fake_imgs = self.g(batch_z, training=True) 163 | fake_imgs = conv_cond_concat(fake_imgs, batch_codes) 164 | d_fake, d_fake_logits, input4classifier_fake = self.d(fake_imgs, training=True) 165 | d_real, d_real_logits, _ = self.d(real_images, training=True) 166 | d_loss = self.d_loss_fun(d_fake_logits, d_real_logits) 167 | g_loss = self.g_loss_fun(d_fake_logits) 168 | code_fake, code_logit_fake = self.c(input4classifier_fake, training=True) 169 | disc_code_est = code_logit_fake[:, :self.len_discrete_code] 170 | disc_code_tg = batch_codes[:, :self.len_discrete_code] 171 | cont_code_est = code_logit_fake[:, self.len_discrete_code:] 172 | cont_code_tg = batch_codes[:, self.len_discrete_code:] 173 | q_loss = self.q_loss_fun(disc_code_est, disc_code_tg, cont_code_est, cont_code_tg) 174 | 175 | gradients_of_d = d_tape.gradient(d_loss, self.d.trainable_variables) 176 | gradients_of_g = g_tape.gradient(g_loss, self.g.trainable_variables) 177 | gradients_of_q = q_tape.gradient(q_loss, self.c.trainable_variables) 178 | 179 | self.d_optimizer.apply_gradients(zip(gradients_of_d, self.d.trainable_variables)) 180 | self.g_optimizer.apply_gradients(zip(gradients_of_g, self.g.trainable_variables)) 181 | self.q_optimizer.apply_gradients(zip(gradients_of_q, self.c.trainable_variables)) 182 | 183 | self.g_loss_metric(g_loss) 184 | self.d_loss_metric(d_loss) 185 | self.q_loss_metric(q_loss) 186 | 187 | def train(self, load=False): 188 | 189 | self.sample_label=tf.cast(2*(tf.ones(shape=(self.batch_size,),dtype=tf.int32)),dtype=tf.int32) 190 | self.sample_label=tf.one_hot(self.sample_label,depth=10) 191 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 192 | train_log_dir = self.log_dir+'/'+self.model_name+'/' + current_time 193 | self.train_summary_writer = tf.summary.create_file_writer(train_log_dir) 194 | self.could_load = self.load_ckpt() 195 | ckpt_step = int(self.checkpoint.step) 196 | start_epoch = int((ckpt_step*self.batch_size)//60000) 197 | 198 | for epoch in range(start_epoch, self.epoches): 199 | for batch_images, batch_labels in self.datasets: 200 | if self.SUPERVISED == True: 201 | batch_labels = batch_labels 202 | else: 203 | batch_labels = np.random.multinomial(1, 204 | self.len_discrete_code * 205 | [float(1.0 / self.len_discrete_code)], 206 | size=[self.batch_size]) 207 | 208 | self.continuous_code=tf.random.uniform(shape=[self.batch_size, self.len_continuous_code], 
209 | minval=-1.0,maxval=1.0,dtype=tf.dtypes.float32) 210 | self.test_codes = tf.concat([self.sample_label,self.continuous_code],1) 211 | self.train_one_step(batch_labels, batch_images) 212 | self.checkpoint.step.assign_add(1) 213 | step = int(self.checkpoint.step) 214 | 215 | 216 | #save checkpoints for every 400 batches training 217 | if step % 50 == 0: 218 | print('step: {}, d_loss: {:.4f}, g_loss: {:.4F}, q_loss: {:.4F}' 219 | .format(step, self.d_loss_metric.result(), self.g_loss_metric.result(),self.q_loss_metric.result())) 220 | manifold_h = int(np.floor(np.sqrt(self.batch_size))) 221 | manifold_w = int(np.floor(np.sqrt(self.batch_size))) 222 | self.batch_z_to_disply = tf.concat([self.sample_z, self.test_codes[:self.batch_size, :]], 1) 223 | result_to_display = self.g(self.batch_z_to_disply, training=False) 224 | save_images(result_to_display[:manifold_h * manifold_w, :, :, :], 225 | [manifold_h, manifold_w], 226 | './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(epoch, int(step))) 227 | 228 | with self.train_summary_writer.as_default(): 229 | tf.summary.scalar('g_loss', self.g_loss_metric.result(), step=step) 230 | tf.summary.scalar('d_loss', self.d_loss_metric.result(), step=step) 231 | tf.summary.scalar('q_loss', self.q_loss_metric.result(), step=step) 232 | 233 | #save checkpoints for every 400 batches training 234 | if step % 400 == 0: 235 | save_path = self.manager.save() 236 | print("\n---------------Saved checkpoint for step {}: {}------------------\n".format(step, save_path)) 237 | self.g_loss_metric.reset_states() 238 | self.d_loss_metric.reset_states() 239 | self.q_loss_metric.reset_states() 240 | 241 | def load_ckpt(self): 242 | self.checkpoint.restore(self.manager.latest_checkpoint) 243 | if self.manager.latest_checkpoint: 244 | print("restore model from checkpoint: {}".format(self.manager.latest_checkpoint)) 245 | return True 246 | 247 | else: 248 | print("Initializing from scratch.") 249 | return False 250 | 251 | 252 | def parse_args(): 253 | desc = "Tensorflow implementation of GAN collections" 254 | parser = argparse.ArgumentParser(description=desc) 255 | parser.add_argument('--gan_type', type=str, default='infoGAN') 256 | parser.add_argument('--datasets', type=str, default='mnist') 257 | parser.add_argument('--lr', type=float, default=2e-4) 258 | parser.add_argument('--epoch', type=int, default=20, 259 | help='The number of epochs to run') 260 | parser.add_argument('--batch_size', type=int, 261 | default=64, help='The size of batch') 262 | parser.add_argument('--z_dim', type=int, default=62, 263 | help='Dimension of noise vector') 264 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 265 | help='Directory name to save the checkpoints') 266 | parser.add_argument('--result_dir', type=str, default='results', 267 | help='Directory name to save the generated images') 268 | parser.add_argument('--log_dir', type=str, default='logs', 269 | help='Directory name to save training logs') 270 | 271 | return check_args(parser.parse_args()) 272 | 273 | 274 | """checking arguments""" 275 | 276 | 277 | def check_args(args): 278 | # --checkpoint_dir 279 | check_folder(args.checkpoint_dir) 280 | 281 | # --result_dir 282 | check_folder(args.result_dir) 283 | 284 | # --result_dir 285 | check_folder(args.log_dir) 286 | 287 | # --epoch 288 | assert args.epoch >= 1, 'number of epochs must be larger than or equal to one' 289 | 290 | # --batch_size 291 | assert args.batch_size >= 1, 'batch size 
must be larger than or equal to one' 292 | 293 | # --z_dim 294 | assert args.z_dim >= 1, 'dimension of noise vector must be larger than or equal to one' 295 | 296 | return args 297 | 298 | 299 | def main(): 300 | args = parse_args() 301 | if args is None: 302 | exit() 303 | model = infoGAN(args) 304 | model.train(load=True) 305 | 306 | 307 | if __name__ == '__main__': 308 | main() 309 | --------------------------------------------------------------------------------
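All of the scripts above expose the same command-line interface, so each model can also be constructed programmatically. The snippet below is a minimal sketch that mimics the parse_args() defaults with an argparse.Namespace; the attribute names mirror the flags defined in each file, and the commented lines assume the repository root is on sys.path.

# Hypothetical usage sketch: build an args object equivalent to the parse_args()
# defaults and drive one of the models without the command line.
import argparse

args = argparse.Namespace(gan_type='WGAN_GP', datasets='fashion_mnist', lr=2e-4,
                          epoch=20, batch_size=64, z_dim=62,
                          checkpoint_dir='checkpoint', result_dir='results',
                          log_dir='logs')

# from WGAN_GP import WGAN_GP   # assumes the repository root is on sys.path
# model = WGAN_GP(args)
# model.train(load=False)       # set load=True to resume from the latest checkpoint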