├── README.md
├── train_wgan_gp.py
├── preprocess.py
├── net.py
└── wgan_gp.py

/README.md:
--------------------------------------------------------------------------------
# Improved Training of Wasserstein GANs
Code for reproducing [Improved Training of Wasserstein GANs](https://arxiv.org/abs/1704.00028) (WGAN-GP) in Keras/TensorFlow.

# Requirements
Python 3.5.2
Keras 2.0.2
TensorFlow 1.x (the code uses `tf.Session` and `tf.train.AdamOptimizer`)
Pillow

# Usage
```
python train_wgan_gp.py --load_dir {}
```
`--batch_size` (default 64) and `--image_size` (default 64) can also be passed. The expected layout of `--load_dir` is described below.
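# Data layout
`--load_dir` is read with Keras' `ImageDataGenerator.flow_from_directory`, which expects the images to sit in at least one sub-directory of the given folder (the class labels themselves are ignored by this code). With the default `--load_dir data`, a layout like the following would work; the directory and file names are only an illustration:
```
data/
└── images/
    ├── 000001.jpg
    ├── 000002.jpg
    └── ...
```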
--------------------------------------------------------------------------------
/train_wgan_gp.py:
--------------------------------------------------------------------------------
import argparse

import numpy as np
import tensorflow as tf

from net import Net
from preprocess import DataFeeder
from wgan_gp import WganGP

parser = argparse.ArgumentParser()
parser.add_argument('--load_dir', type=str, default='data')
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--image_size', type=int, default=64)
args = parser.parse_args()
load_dir = args.load_dir
batch_size = args.batch_size
size = (args.image_size, args.image_size)

# NOTE: the architectures in net.py (and the placeholders in wgan_gp.py) are
# hard-coded for 64x64 RGB images, so --image_size should stay at 64.
net = Net()
data_feeder = DataFeeder(load_dir=load_dir, batch_size=batch_size, size=size)
sess = tf.Session()
wgan_gp = WganGP(net, data_feeder, sess, batch_size, size[0])

# Train for 100 epochs, then write a single concatenated sample sheet
# ('tekitou' is just the output file name).
wgan_gp.train(100)
wgan_gp.generate_image(np.random.normal(size=[10, 128]), 'tekitou', concat=True)
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
import os

from keras.preprocessing.image import ImageDataGenerator, array_to_img
from PIL import Image


class DataFeeder(object):
    def __init__(self, load_dir, batch_size=64, size=(64, 64)):
        self.load_dir = load_dir
        self.batch_size = batch_size
        self.size = size
        # flow_from_directory expects at least one class sub-directory under
        # load_dir; the class labels themselves are never used.
        self.generator = ImageDataGenerator(data_format='channels_first').flow_from_directory(
            self.load_dir, target_size=size, batch_size=batch_size)

    def fetch_data(self):
        data, _ = next(self.generator)
        if data.shape[0] == self.batch_size:
            # Scale to [-1, 1] so real images match the generator's tanh output range.
            return data / 127.5 - 1.
        else:
            # The last batch of an epoch can be smaller than batch_size;
            # skip it and draw a fresh one.
            return self.fetch_data()

    def save_images(self, arrays, names, concat, save_dir='save'):
        if not isinstance(names, list):
            names = [names]
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        if not concat:
            for array, name in zip(arrays, names):
                image = array_to_img(array, data_format='channels_first').resize((60, 60))
                image.save(os.path.join(save_dir, "{}.png".format(name)), quality=100)
        else:
            canvas = Image.new('RGB', (60 * len(arrays), 60), (255, 255, 255))
            for i, array in enumerate(arrays):
                image = array_to_img(array, data_format='channels_first').resize((60, 60))
                canvas.paste(image, (i * 60, 0))
            canvas.save(os.path.join(save_dir, "{}.png".format(names[0])), quality=100)


if __name__ == '__main__':
    data_feeder = DataFeeder('./dir_test', batch_size=2)
    data = data_feeder.fetch_data()
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
import os

from keras.layers import (Activation, BatchNormalization, Conv2D, Conv2DTranspose,
                          Dense, Flatten, LeakyReLU, Reshape)
from keras.models import Sequential


class Net(object):
    def __init__(self, z_dim=128, dim=64, gen_model=None, dis_model=None):
        # Generator: latent vector of size z_dim -> 3x64x64 image in [-1, 1] (tanh).
        if gen_model is None:
            gen_model = Sequential()
            gen_model.add(Dense(12 * 12 * 512, input_dim=z_dim))
            gen_model.add(Activation('relu'))
            gen_model.add(BatchNormalization(axis=1))
            gen_model.add(Reshape([512, 12, 12]))
            gen_model.add(Conv2DTranspose(256, 3, data_format='channels_first'))  # 12x12 -> 14x14
            gen_model.add(BatchNormalization(axis=1))
            gen_model.add(Activation('relu'))
            gen_model.add(Conv2DTranspose(128, 3, data_format='channels_first'))  # 14x14 -> 16x16
            gen_model.add(BatchNormalization(axis=1))
            gen_model.add(Activation('relu'))
            gen_model.add(Conv2DTranspose(64, 4, strides=2, padding='same', data_format='channels_first'))  # -> 32x32
            gen_model.add(BatchNormalization(axis=1))
            gen_model.add(Activation('relu'))
            gen_model.add(Conv2DTranspose(3, 4, strides=2, padding='same', data_format='channels_first'))  # -> 64x64
            gen_model.add(Activation('tanh'))
        self.generator = gen_model

        # Critic: 3x64x64 image -> unbounded scalar score (no sigmoid, as in WGAN).
        if dis_model is None:
            dis_model = Sequential()
            dis_model.add(Conv2D(dim, 5, strides=2, padding='same', data_format='channels_first', input_shape=[3, 64, 64]))
            dis_model.add(LeakyReLU(0.2))
            dis_model.add(Conv2D(dim * 2, 5, strides=2, data_format='channels_first'))
            dis_model.add(LeakyReLU(0.2))
            dis_model.add(Flatten())
            dis_model.add(Dense(256))
            dis_model.add(LeakyReLU(0.2))
            dis_model.add(Dense(1))
        self.discriminator = dis_model

    def save_models(self, name, save_dir='save'):
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        self.generator.save(os.path.join(save_dir, "generator_{}.h5".format(name)))
        self.discriminator.save(os.path.join(save_dir, "discriminator_{}.h5".format(name)))


if __name__ == '__main__':
    net = Net()
--------------------------------------------------------------------------------
/wgan_gp.py:
--------------------------------------------------------------------------------
from keras import backend as K
import numpy as np
import tensorflow as tf


class WganGP(object):
    def __init__(self, net, data_feeder, sess, batch_size=64, dim=64, z_dim=128,
                 lambda_gp=10, dis_lr=1e-4, gen_lr=1e-4, n_critic=5):
        self.net = net
        self.data_feeder = data_feeder
        self.sess = sess
        self.batch_size = batch_size
        self.dim = dim
        self.z_dim = z_dim
        self.lambda_gp = lambda_gp
        self.dis_lr = dis_lr
        self.gen_lr = gen_lr
        self.n_critic = n_critic
        self.built = False

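    # For reference, build() below implements the WGAN-GP objective from the
    # paper linked in the README (Gulrajani et al., 2017):
    #
    #   critic loss:     L_D = E[D(G(z))] - E[D(x)] + lambda_gp * E[(||grad_xhat D(xhat)||_2 - 1)^2]
    #   generator loss:  L_G = -E[D(G(z))]
    #
    # where xhat = x + alpha * (G(z) - x) with alpha ~ U[0, 1], drawn per sample.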
    def build(self):
        self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim])
        self.real_image = tf.placeholder(tf.float32, shape=[None, 3, 64, 64])
        self.fake_image = self.net.generator(self.z)
        self.dis_fake = self.net.discriminator(self.fake_image)
        self.dis_real = self.net.discriminator(self.real_image)
        self.gen_loss = -K.mean(self.dis_fake)
        self.dis_loss = K.mean(self.dis_fake) - K.mean(self.dis_real)

        # Gradient penalty: score random interpolations between real and fake
        # images and push the critic's gradient norm towards 1.
        alpha = K.random_uniform(shape=[K.shape(self.z)[0], 1, 1, 1])
        diff = self.fake_image - self.real_image
        interp = self.real_image + alpha * diff
        gradients = K.gradients(self.net.discriminator(interp), [interp])[0]
        # Per-sample L2 norm of the gradient, taken over all non-batch axes.
        slopes = K.sqrt(K.sum(K.square(K.batch_flatten(gradients)), axis=1))
        gp = K.mean(K.square(slopes - 1.))
        self.dis_loss += self.lambda_gp * gp

        self.dis_updater = tf.train.AdamOptimizer(learning_rate=self.dis_lr).minimize(
            self.dis_loss, var_list=self.net.discriminator.trainable_weights)
        self.gen_updater = tf.train.AdamOptimizer(learning_rate=self.gen_lr).minimize(
            self.gen_loss, var_list=self.net.generator.trainable_weights)
        self.sess.run(tf.global_variables_initializer())
        self.built = True

    def train(self, epoch):
        if not self.built:
            self.build()
        for i in range(epoch):
            # One "epoch" here is n_critic critic updates followed by a single
            # generator update, not a full pass over the dataset.
            for _ in range(self.n_critic):
                images = self.data_feeder.fetch_data()
                feed_in = {self.z: np.random.normal(size=[self.batch_size, self.z_dim]),
                           self.real_image: images}
                self.sess.run(self.dis_updater, feed_in)
            self.sess.run(self.gen_updater, {self.z: np.random.normal(size=[self.batch_size, self.z_dim])})
            print("epoch: {}, gen_loss: {}, dis_loss: {}".format(i, *self.sess.run([self.gen_loss, self.dis_loss], feed_in)))

    def generate_image(self, z, names, concat, save_dir='save'):
        if not self.built:
            self.build()
        self.data_feeder.save_images(self.sess.run(self.fake_image, {self.z: z}), names, concat, save_dir)

    def save_models(self, name, save_dir='save'):
        self.net.save_models(name, save_dir)
--------------------------------------------------------------------------------
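Note: `train_wgan_gp.py` trains the networks and writes a single sample sheet, but it never calls `WganGP.save_models`, so nothing is checkpointed by default. If the models are saved after training (for example with `wgan_gp.save_models('final')`), the generator can later be reloaded and sampled along the lines of the minimal sketch below. This script is not part of the repository; the checkpoint name `final` and the output file names are assumptions.
```
# Hypothetical follow-up script: reload a generator written by
# Net.save_models(...) and draw a few samples from it.
import numpy as np
from keras.models import load_model
from keras.preprocessing.image import array_to_img

generator = load_model('save/generator_final.h5')  # assumed checkpoint path

z = np.random.normal(size=[10, 128])   # latent vectors, z_dim=128 as in Net
images = generator.predict(z)          # shape (10, 3, 64, 64), tanh output in [-1, 1]

for i, array in enumerate(images):
    # array_to_img rescales the array to [0, 255] by default (scale=True)
    array_to_img(array, data_format='channels_first').save('sample_{}.png'.format(i))
```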