├── CycleGAN.py ├── MSGGAN.py ├── NST.py ├── PGGAN.py ├── README.md ├── SAGAN.py └── Summary-of-Different-GAN-Loss-Functions.png /CycleGAN.py: -------------------------------------------------------------------------------- 1 | ''' 2 | CycleGAN on Horse2Zebra dataset 3 | Horse2Zebra dataset: https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip 4 | ''' 5 | import tensorflow as tf 6 | from tensorflow.keras import layers, models, initializers, constraints, optimizers 7 | from tensorflow_addons.layers import InstanceNormalization 8 | import tensorflow.keras.backend as K 9 | from tensorflow.keras.preprocessing.image import load_img, img_to_array 10 | import numpy as np 11 | from random import random 12 | 13 | import os 14 | from os.path import isfile, isdir, join 15 | 16 | from matplotlib import pyplot 17 | 18 | # define the discriminator model 19 | def define_discriminator(image_shape): 20 | init = initializers.RandomNormal(stddev=0.02) 21 | in_image = layers.Input(shape=image_shape) 22 | # C64 23 | d = layers.Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(in_image) 24 | d = layers.LeakyReLU(alpha=0.2)(d) 25 | # C128 26 | d = layers.Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d) 27 | d = InstanceNormalization(axis=-1)(d) 28 | d = layers.LeakyReLU(alpha=0.2)(d) 29 | # C256 30 | d = layers.Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d) 31 | d = InstanceNormalization(axis=-1)(d) 32 | d = layers.LeakyReLU(alpha=0.2)(d) 33 | # C512 34 | d = layers.Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d) 35 | d = InstanceNormalization(axis=-1)(d) 36 | d = layers.LeakyReLU(alpha=0.2)(d) 37 | # second-to-last output layer 38 | d = layers.Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d) 39 | d = InstanceNormalization(axis=-1)(d) 40 | d = layers.LeakyReLU(alpha=0.2)(d) 41 | # patch output 42 | patch_out = layers.Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d) 43 | 44 | model = models.Model(in_image, patch_out) 45 | model.compile(loss='mse', optimizer=optimizers.Adam(lr=0.0002, beta_1=0.5), loss_weights=[0.5]) 46 | return model 47 | 48 | # generate a resnet block 49 | def resnet_block(n_filters, input_layer): 50 | init = initializers.RandomNormal(stddev=0.02) 51 | # first convolutional layer 52 | g = layers.Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(input_layer) 53 | g = InstanceNormalization(axis=-1)(g) 54 | g = layers.Activation('relu')(g) 55 | # second convolutional layer 56 | g = layers.Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(g) 57 | g = InstanceNormalization(axis=-1)(g) 58 | # concatenate channel-wise with the input layer 59 | g = layers.Concatenate()([g, input_layer]) 60 | return g 61 | 62 | # define the standalone generator model 63 | def define_generator(image_shape, n_resnet=9): 64 | init = initializers.RandomNormal(stddev=0.02) 65 | in_image = layers.Input(shape=image_shape) 66 | # c7s1-64 67 | g = layers.Conv2D(64, (7,7), padding='same', kernel_initializer=init)(in_image) 68 | g = InstanceNormalization(axis=-1)(g) 69 | g = layers.Activation('relu')(g) 70 | # d128 71 | g = layers.Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g) 72 | g = InstanceNormalization(axis=-1)(g) 73 | g = layers.Activation('relu')(g) 74 | # d256 75 | g = layers.Conv2D(256, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g) 76 | g = 
InstanceNormalization(axis=-1)(g) 77 | g = layers.Activation('relu')(g) 78 | # R256 79 | for _ in range(n_resnet): 80 | g = resnet_block(256, g) 81 | # u128 82 | g = layers.Conv2DTranspose(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g) 83 | g = InstanceNormalization(axis=-1)(g) 84 | g = layers.Activation('relu')(g) 85 | # u64 86 | g = layers.Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g) 87 | g = InstanceNormalization(axis=-1)(g) 88 | g = layers.Activation('relu')(g) 89 | # c7s1-3 90 | g = layers.Conv2D(3, (7,7), padding='same', kernel_initializer=init)(g) 91 | g = InstanceNormalization(axis=-1)(g) 92 | out_image = layers.Activation('tanh')(g) 93 | 94 | model = models.Model(in_image, out_image) 95 | return model 96 | 97 | # define a composite model for updating generators by adversarial and cycle loss 98 | def define_composite_model(g_model_1, d_model, g_model_2, image_shape): 99 | # ensure the model we're updating is trainable 100 | g_model_1.trainable = True 101 | # mark discriminator as not trainable 102 | d_model.trainable = False 103 | # mark other generator model as not trainable 104 | g_model_2.trainable = False 105 | # discriminator element 106 | input_gen = layers.Input(shape=image_shape) 107 | gen1_out = g_model_1(input_gen) 108 | output_d = d_model(gen1_out) 109 | # identity element 110 | input_id = layers.Input(shape=image_shape) 111 | output_id = g_model_1(input_id) 112 | # forward cycle 113 | output_f = g_model_2(gen1_out) 114 | # backward cycle 115 | gen2_out = g_model_2(input_id) 116 | output_b = g_model_1(gen2_out) 117 | 118 | model = models.Model([input_gen, input_id], [output_d, output_id, output_f, output_b]) 119 | opt = optimizers.Adam(lr=0.0002, beta_1=0.5) 120 | # compile model with weighting of least squares loss and L1 loss 121 | model.compile(loss=['mse', 'mae', 'mae', 'mae'], loss_weights=[1, 5, 10, 10], optimizer=opt) 122 | return model 123 | 124 | # load images from folder 125 | def load_real_samples(imgs_path, size=(256,256)): 126 | imgs = [] 127 | for img_name in os.listdir(imgs_path): 128 | img_path = join(imgs_path, img_name) 129 | if isfile(img_path): 130 | img = load_img(img_path, target_size=size) 131 | img = img_to_array(img) 132 | imgs.append(img) 133 | imgs = np.asarray(imgs) 134 | # scale from [0,255] to [-1,1] 135 | imgs = imgs / 127.5 - 1. 
136 | return imgs 137 | 138 | # select a batch of random samples, returns images and target 139 | def generate_real_samples(dataset, n_samples, patch_shape): 140 | # select random instances 141 | X = dataset[np.random.randint(0, dataset.shape[0], n_samples)] 142 | # generate 'real' class labels (1) 143 | y = np.ones((n_samples, patch_shape, patch_shape, 1)) 144 | return X, y 145 | 146 | # generate a batch of images, returns images and targets 147 | def generate_fake_samples(g_model, dataset, patch_shape): 148 | # generate fake instance 149 | X = g_model.predict(dataset) 150 | # create 'fake' class labels (0) 151 | y = np.zeros((len(X), patch_shape, patch_shape, 1)) 152 | return X, y 153 | 154 | # save the generator models to file 155 | def save_models(step, g_model_AtoB, g_model_BtoA): 156 | # save the first generator model 157 | filename1 = 'g_model_AtoB_%06d.h5' % (step+1) 158 | g_model_AtoB.save(filename1) 159 | # save the second generator model 160 | filename2 = 'g_model_BtoA_%06d.h5' % (step+1) 161 | g_model_BtoA.save(filename2) 162 | print('>Saved: %s and %s' % (filename1, filename2)) 163 | 164 | # generate samples and save them as a plot 165 | def summarize_performance(step, g_model, trainX, name, n_samples=5): 166 | # select a sample of input images 167 | X_in, _ = generate_real_samples(trainX, n_samples, 0) 168 | # generate translated images 169 | X_out, _ = generate_fake_samples(g_model, X_in, 0) 170 | # scale all pixels from [-1,1] to [0,1] 171 | X_in = (X_in + 1) / 2.0 172 | X_out = (X_out + 1) / 2.0 173 | # plot real images 174 | for i in range(n_samples): 175 | pyplot.subplot(2, n_samples, 1 + i) 176 | pyplot.axis('off') 177 | pyplot.imshow(X_in[i]) 178 | # plot translated images 179 | for i in range(n_samples): 180 | pyplot.subplot(2, n_samples, 1 + n_samples + i) 181 | pyplot.axis('off') 182 | pyplot.imshow(X_out[i]) 183 | # save plot to file 184 | filename1 = '%s_generated_plot_%06d.png' % (name, (step+1)) 185 | pyplot.savefig(filename1) 186 | pyplot.close() 187 | 188 | # update image pool for fake images 189 | def update_image_pool(pool, images, max_size=50): 190 | selected = list() 191 | for image in images: 192 | if len(pool) < max_size: 193 | # stock the pool 194 | pool.append(image) 195 | selected.append(image) 196 | elif random() < 0.5: 197 | # use image, but don't add it to the pool 198 | selected.append(image) 199 | else: 200 | # replace an existing image and use the replaced image 201 | ix = np.random.randint(0, len(pool)) 202 | selected.append(pool[ix]) 203 | pool[ix] = image 204 | return np.asarray(selected) 205 | 206 | # train cyclegan models 207 | def train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset): 208 | # define properties of the training run 209 | n_epochs, n_batch = 100, 1 210 | # determine the output square shape of the discriminator 211 | n_patch = d_model_A.output_shape[1] 212 | 213 | trainA, trainB = dataset 214 | # prepare image pool for fakes 215 | poolA, poolB = list(), list() 216 | # calculate the number of batches per training epoch 217 | bat_per_epo = int(len(trainA) / n_batch) 218 | # calculate the number of training iterations 219 | n_steps = bat_per_epo * n_epochs 220 | # manually enumerate epochs 221 | for i in range(n_steps): 222 | # select a batch of real samples 223 | X_realA, y_realA = generate_real_samples(trainA, n_batch, n_patch) 224 | X_realB, y_realB = generate_real_samples(trainB, n_batch, n_patch) 225 | # generate a batch of fake samples 226 | X_fakeA, y_fakeA = 
generate_fake_samples(g_model_BtoA, X_realB, n_patch) 227 | X_fakeB, y_fakeB = generate_fake_samples(g_model_AtoB, X_realA, n_patch) 228 | # update fakes from pool 229 | X_fakeA = update_image_pool(poolA, X_fakeA) 230 | X_fakeB = update_image_pool(poolB, X_fakeB) 231 | # update generator B->A via adversarial and cycle loss 232 | g_loss2, _, _, _, _ = c_model_BtoA.train_on_batch([X_realB, X_realA], [y_realA, X_realA, X_realB, X_realA]) 233 | # update discriminator for A -> [real/fake] 234 | dA_loss1 = d_model_A.train_on_batch(X_realA, y_realA) 235 | dA_loss2 = d_model_A.train_on_batch(X_fakeA, y_fakeA) 236 | # update generator A->B via adversarial and cycle loss 237 | g_loss1, _, _, _, _ = c_model_AtoB.train_on_batch([X_realA, X_realB], [y_realB, X_realB, X_realA, X_realB]) 238 | # update discriminator for B -> [real/fake] 239 | dB_loss1 = d_model_B.train_on_batch(X_realB, y_realB) 240 | dB_loss2 = d_model_B.train_on_batch(X_fakeB, y_fakeB) 241 | # summarize performance 242 | print('>%d, dA[%.3f,%.3f] dB[%.3f,%.3f] g[%.3f,%.3f]' % (i+1, dA_loss1,dA_loss2, dB_loss1,dB_loss2, g_loss1,g_loss2)) 243 | # evaluate the model performance every so often 244 | if (i+1) % (bat_per_epo * 1) == 0: 245 | # plot A->B translation 246 | summarize_performance(i, g_model_AtoB, trainA, 'AtoB') 247 | # plot B->A translation 248 | summarize_performance(i, g_model_BtoA, trainB, 'BtoA') 249 | if (i+1) % (bat_per_epo * 5) == 0: 250 | # save the models 251 | save_models(i, g_model_AtoB, g_model_BtoA) 252 | 253 | # call this function when using GPU with small memory 254 | def using_gpu_memory_growth(): 255 | physical_devices = tf.config.experimental.list_physical_devices('GPU') 256 | assert len(physical_devices) > 0, "Not enough GPU hardware devices available" 257 | config = tf.config.experimental.set_memory_growth(physical_devices[0], True) 258 | return config 259 | 260 | if __name__ == "__main__": 261 | using_gpu_memory_growth() 262 | # load two different style images 263 | trainA = load_real_samples("./Projects/GAN_Exp/horse2zebra/trainA/") 264 | trainB = load_real_samples("./Projects/GAN_Exp/horse2zebra/trainB/") 265 | dataset = [trainA, trainB] 266 | print('Loaded', dataset[0].shape, dataset[1].shape) 267 | image_shape = dataset[0].shape[1:] 268 | # generator: A -> B 269 | g_model_AtoB = define_generator(image_shape) 270 | # generator: B -> A 271 | g_model_BtoA = define_generator(image_shape) 272 | # discriminator: A -> [real/fake] 273 | d_model_A = define_discriminator(image_shape) 274 | # discriminator: B -> [real/fake] 275 | d_model_B = define_discriminator(image_shape) 276 | # composite: A -> B -> [real/fake, A] 277 | c_model_AtoB = define_composite_model(g_model_AtoB, d_model_B, g_model_BtoA, image_shape) 278 | # composite: B -> A -> [real/fake, B] 279 | c_model_BtoA = define_composite_model(g_model_BtoA, d_model_A, g_model_AtoB, image_shape) 280 | 281 | train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset) -------------------------------------------------------------------------------- /MSGGAN.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Multi-Scale Gradients GAN on celebrity faces dataset 3 | CelebA dataset: https://www.kaggle.com/jessicali9530/celeba-dataset 4 | ''' 5 | import tensorflow as tf 6 | from tensorflow_addons.layers import SpectralNormalization 7 | from tensorflow import keras 8 | from tensorflow.keras import layers 9 | import tensorflow.keras.backend as K 10 | import numpy as np 11 | import random 12 | 
13 | import os 14 | from os.path import isfile, isdir, join 15 | import time 16 | import datetime 17 | 18 | import cv2 as cv 19 | import imutils 20 | 21 | save_every_number_epochs = 50 22 | max_checkpoints_number = 3 23 | log_dir = "./Projects/GAN_Exp/logs/" 24 | saved_model_path = "./Projects/GAN_Exp/saved_models/" 25 | images_path = "./Projects/GAN_Exp/Face_64X64/" 26 | load_model_id = "" 27 | 28 | EPOCHS = 100 29 | BATCH_SIZE = 12 30 | LATENT_DIM = 256 31 | 32 | IS_MSG = True 33 | 34 | lr_g_max = 3e-4 35 | lr_g_min = 1e-4 36 | beta_1_g = 0. 37 | beta_2_g = 0.9 38 | lr_d_max = 3e-4 39 | lr_d_min = 1e-4 40 | beta_1_d = 0. 41 | beta_2_d = 0.9 42 | 43 | # mini-batch standard deviation layer 44 | class MinibatchStdev(layers.Layer): 45 | def __init__(self, **kwargs): 46 | super(MinibatchStdev, self).__init__(**kwargs) 47 | 48 | # calculate the mean standard deviation across each pixel coord 49 | def call(self, inputs): 50 | mean = K.mean(inputs, axis=0, keepdims=True) 51 | mean_sq_diff = K.mean(K.square(inputs - mean), axis=0, keepdims=True) + 1e-8 52 | mean_pix = K.mean(K.sqrt(mean_sq_diff), keepdims=True) 53 | shape = K.shape(inputs) 54 | output = K.tile(mean_pix, [shape[0], shape[1], shape[2], 1]) 55 | return K.concatenate([inputs, output], axis=-1) 56 | 57 | # define the output shape of the layer 58 | def compute_output_shape(self, input_shape): 59 | input_shape = list(input_shape) 60 | input_shape[-1] += 1 61 | return tuple(input_shape) 62 | 63 | # pixel-wise feature vector normalization layer 64 | class PixelNormalization(layers.Layer): 65 | # initialize the layer 66 | def __init__(self, **kwargs): 67 | super(PixelNormalization, self).__init__(**kwargs) 68 | 69 | # L2 norm for all the activations in the same image and at the same location across all channels 70 | def call(self, inputs): 71 | return inputs / K.sqrt(K.mean(inputs**2, axis=-1, keepdims=True) + 1e-8) 72 | 73 | # define the output shape of the layer 74 | def compute_output_shape(self, input_shape): 75 | return input_shape 76 | 77 | class MSGGAN: 78 | def __init__(self, latent_dim, beta_1_g, beta_2_g, beta_1_d, beta_2_d, is_msg): 79 | self.latent_dim = latent_dim 80 | self.is_msg = is_msg 81 | 82 | self.generator = self.make_generator_model() 83 | self.discriminator = self.make_discriminator_model() 84 | 85 | self.generator_optimizer = keras.optimizers.Adam(beta_1=beta_1_g, beta_2=beta_2_g) 86 | self.discriminator_optimizer = keras.optimizers.Adam(beta_1=beta_1_d, beta_2=beta_2_d) 87 | #self.generator_optimizer = keras.optimizers.RMSprop() 88 | #self.discriminator_optimizer = keras.optimizers.RMSprop() 89 | 90 | def stack_layer(self, filters, size, strides=1, padding='same', k_init_stddev=1., max_norm=0., power_iterations=-1, 91 | norm_func=None, dropout_rate=0., conv_func=layers.Conv2D, activation=tf.nn.leaky_relu): 92 | result = keras.Sequential() 93 | 94 | kernel_initializer = keras.initializers.RandomNormal(stddev=k_init_stddev) if k_init_stddev > 0 else 'glorot_uniform' 95 | kernel_constraint = keras.constraints.MaxNorm(max_norm) if max_norm > 0 else None 96 | conv = conv_func(filters, size, strides=strides, 97 | padding=padding, 98 | kernel_initializer=kernel_initializer, 99 | kernel_constraint=kernel_constraint) 100 | if power_iterations > 0: 101 | conv = SpectralNormalization(conv, power_iterations=power_iterations) 102 | result.add(conv) 103 | 104 | if norm_func is not None: 105 | result.add(norm_func()) 106 | 107 | if dropout_rate > 0: 108 | result.add(layers.Dropout(dropout_rate)) 109 | 110 | if activation is not None: 111 | 
result.add(layers.Activation(activation)) 112 | 113 | return result 114 | 115 | def dense_to_conv(self, x, conv_shape, k_init_stddev=1., max_norm=0.): 116 | kernel_initializer = keras.initializers.RandomNormal(stddev=k_init_stddev) if k_init_stddev > 0 else 'glorot_uniform' 117 | kernel_constraint = keras.constraints.MaxNorm(max_norm) if max_norm > 0 else None 118 | x = layers.Dense(tf.reduce_prod(conv_shape).numpy(), 119 | kernel_initializer=kernel_initializer, 120 | kernel_constraint=kernel_constraint)(x) 121 | x = layers.Reshape(conv_shape)(x) 122 | return x 123 | 124 | def make_generator_model(self): 125 | outputs = [] 126 | latent = layers.Input(shape=(self.latent_dim, )) 127 | x = self.dense_to_conv(latent, (4, 4, 512)) # (bs, 4, 4, 512) 128 | ''' 129 | x = layers.Reshape((1, 1, self.latent_dim))(latent) # (bs, 1, 1, latent_dim) 130 | x = self.stack_layer(512, 4, strides=4, conv_func=layers.Conv2DTranspose)(x) # (bs, 4, 4, 512) 131 | ''' 132 | x = self.stack_layer(512, 4, norm_func=PixelNormalization)(x) 133 | if self.is_msg: 134 | o = self.stack_layer(3, 1, activation=None)(x) # (bs, 4, 4, 3) 135 | outputs.append(o) 136 | 137 | x = layers.UpSampling2D(2)(x) # (bs, 8, 8, 512) 138 | x = self.stack_layer(512, 4, norm_func=PixelNormalization)(x) 139 | x = self.stack_layer(512, 4, norm_func=PixelNormalization)(x) 140 | if self.is_msg: 141 | o = self.stack_layer(3, 1, activation=None)(x) # (bs, 8, 8, 3) 142 | outputs.append(o) 143 | 144 | x = layers.UpSampling2D(2)(x) # (bs, 16, 16, 512) 145 | x = self.stack_layer(512, 4, norm_func=PixelNormalization)(x) 146 | x = self.stack_layer(512, 4, norm_func=PixelNormalization)(x) 147 | if self.is_msg: 148 | o = self.stack_layer(3, 1, activation=None)(x) # (bs, 16, 16, 3) 149 | outputs.append(o) 150 | 151 | x = layers.UpSampling2D(2)(x) # (bs, 32, 32, 512) 152 | x = self.stack_layer(512, 5, norm_func=PixelNormalization)(x) 153 | x = self.stack_layer(512, 5, norm_func=PixelNormalization)(x) 154 | if self.is_msg: 155 | o = self.stack_layer(3, 1, activation=None)(x) # (bs, 32, 32, 3) 156 | outputs.append(o) 157 | 158 | x = layers.UpSampling2D(2)(x) # (bs, 64, 64, 512) 159 | x = self.stack_layer(256, 5, norm_func=PixelNormalization)(x) # (bs, 64, 64, 256) 160 | x = self.stack_layer(256, 5, norm_func=PixelNormalization)(x) 161 | o = self.stack_layer(3, 1, activation=None)(x) # (bs, 64, 64, 3) 162 | outputs.append(o) 163 | ''' 164 | x = layers.UpSampling2D(2)(x) # (bs, 128, 128, 256) 165 | x = self.stack_layer(128, 3, norm_func=PixelNormalization)(x) # (bs, 128, 128, 128) 166 | x = self.stack_layer(128, 3, norm_func=PixelNormalization)(x) 167 | o = self.stack_layer(3, 1)(x) # (bs, 128, 128, 3) 168 | outputs.append(o) 169 | ''' 170 | 171 | return keras.Model(inputs=latent, outputs=outputs) 172 | 173 | def make_discriminator_model(self): 174 | inputs = [] 175 | ''' 176 | i = layers.Input(shape=(128, 128, 3)) 177 | inputs.append(i) 178 | x = MinibatchStdev()(i) 179 | x = self.stack_layer(128, 3)(x) # (bs, 128, 128, 128) 180 | x = self.stack_layer(256, 3)(x) # (bs, 128, 128, 256) 181 | x = layers.AvgPool2D(2)(x) # (bs, 64, 64, 256) 182 | ''' 183 | i = layers.Input(shape=(64, 64, 3)) 184 | inputs.append(i) 185 | #x = layers.Concatenate()([i, x]) 186 | x = MinibatchStdev()(i) 187 | x = self.stack_layer(256, 5, power_iterations=5)(x) # (bs, 64, 64, 256) 188 | x = self.stack_layer(512, 5, power_iterations=5)(x) # (bs, 64, 64, 512) 189 | x = layers.AvgPool2D(2)(x) # (bs, 32, 32, 512) 190 | 191 | if self.is_msg: 192 | i = layers.Input(shape=(32, 32, 3)) 193 | 
inputs.append(i) 194 | x = layers.Concatenate()([i, x]) 195 | x = MinibatchStdev()(x) 196 | x = self.stack_layer(512, 5, power_iterations=5)(x) # (bs, 32, 32, 512) 197 | x = self.stack_layer(512, 5, power_iterations=5)(x) # (bs, 32, 32, 512) 198 | x = layers.AvgPool2D(2)(x) # (bs, 16, 16, 512) 199 | 200 | if self.is_msg: 201 | i = layers.Input(shape=(16, 16, 3)) 202 | inputs.append(i) 203 | x = layers.Concatenate()([i, x]) 204 | x = MinibatchStdev()(x) 205 | x = self.stack_layer(512, 4, power_iterations=5)(x) # (bs, 16, 16, 512) 206 | x = self.stack_layer(512, 4, power_iterations=5)(x) # (bs, 16, 16, 512) 207 | x = layers.AvgPool2D(2)(x) # (bs, 8, 8, 512) 208 | 209 | if self.is_msg: 210 | i = layers.Input(shape=(8, 8, 3)) 211 | inputs.append(i) 212 | x = layers.Concatenate()([i, x]) 213 | x = MinibatchStdev()(x) 214 | x = self.stack_layer(512, 3, power_iterations=5)(x) # (bs, 8, 8, 512) 215 | x = self.stack_layer(512, 3, power_iterations=5)(x) # (bs, 8, 8, 512) 216 | x = layers.AvgPool2D(2)(x) # (bs, 4, 4, 512) 217 | 218 | if self.is_msg: 219 | i = layers.Input(shape=(4, 4, 3)) 220 | inputs.append(i) 221 | x = layers.Concatenate()([i, x]) 222 | x = MinibatchStdev()(x) 223 | x = self.stack_layer(512, 3, power_iterations=5)(x) # (bs, 4, 4, 512) 224 | x = self.stack_layer(512, 4, power_iterations=5, padding='valid')(x) # (bs, 1, 1, 512) 225 | x = layers.Flatten()(x) 226 | x = layers.Dense(1)(x) # (bs, 1, 1, 1) 227 | 228 | return keras.Model(inputs=inputs, outputs=x) 229 | 230 | def sigmoid_cross_entropy_loss(self, target, output): 231 | return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(tf.constant(target, dtype=tf.float32, shape=output.shape), output)) 232 | 233 | def wasserstein_loss(self, target, output): 234 | return -1. * tf.reduce_mean(target * output) 235 | 236 | def hinge_loss(self, target, output): 237 | return -1. * tf.reduce_mean(tf.minimum(0., -1. 
+ target * output)) 238 | 239 | @tf.function 240 | def train_step(self, real_images, batch_size, lr_g, lr_d): 241 | latent = self.generate_latent(batch_size) 242 | with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: 243 | gen_outputs = self.generator(latent, training=True) 244 | if self.is_msg: 245 | gen_outputs.reverse() 246 | 247 | disc_real_output = self.discriminator(real_images, training=True) 248 | disc_generated_output = self.discriminator(gen_outputs, training=True) 249 | 250 | #disc_loss_real = self.sigmoid_cross_entropy_loss(.9, disc_real_output) 251 | #disc_loss_gen = self.sigmoid_cross_entropy_loss(0., disc_generated_output) 252 | #disc_loss_real = self.wasserstein_loss(1., disc_real_output) 253 | #disc_loss_gen = self.wasserstein_loss(-1., disc_generated_output) 254 | disc_loss_real = self.hinge_loss(1., disc_real_output) 255 | disc_loss_gen = self.hinge_loss(-1., disc_generated_output) 256 | disc_total_loss = disc_loss_real + disc_loss_gen 257 | 258 | #gen_total_loss = self.sigmoid_cross_entropy_loss(.9, disc_generated_output) 259 | gen_total_loss = self.wasserstein_loss(1., disc_generated_output) 260 | 261 | self.discriminator_optimizer.learning_rate = lr_d 262 | self.generator_optimizer.learning_rate = lr_g 263 | discriminator_gradients = disc_tape.gradient(disc_total_loss, self.discriminator.trainable_variables) 264 | generator_gradients = gen_tape.gradient(gen_total_loss, self.generator.trainable_variables) 265 | self.discriminator_optimizer.apply_gradients(zip(discriminator_gradients, self.discriminator.trainable_variables)) 266 | self.generator_optimizer.apply_gradients(zip(generator_gradients, self.generator.trainable_variables)) 267 | 268 | train_dict = {"disc_total_loss": disc_total_loss, "gen_total_loss": gen_total_loss, 269 | "max disc_grads[-1]": tf.reduce_max(discriminator_gradients[-1]), "max disc_grads[0]": tf.reduce_max(discriminator_gradients[0]), 270 | "max gen_grads[-1]": tf.reduce_max(generator_gradients[-1]), "max gen_grads[0]": tf.reduce_max(generator_gradients[0])} 271 | return train_dict 272 | 273 | def generate_latent(self, batch_size=1): 274 | latent = tf.random.normal([batch_size, self.latent_dim]) 275 | #latent = tf.math.l2_normalize(latent, axis=1) 276 | return latent 277 | 278 | def generate_samples(self, batch_size=1): 279 | return self.generator(self.generate_latent(batch_size), training=True) 280 | 281 | def load_model(path, model, max_checkpoints_number): 282 | if path is not None: 283 | tf.print("Loading model from: {}".format(path)) 284 | ckpt = tf.train.Checkpoint(model=model) 285 | manager = tf.train.CheckpointManager(ckpt, path, max_to_keep=max_checkpoints_number) 286 | ckpt.restore(manager.latest_checkpoint) 287 | 288 | def save_model(path, model, max_checkpoints_number): 289 | if path is not None: 290 | tf.print("Saving model to: {}".format(path)) 291 | ckpt = tf.train.Checkpoint(model=model) 292 | manager = tf.train.CheckpointManager(ckpt, path, max_to_keep=max_checkpoints_number) 293 | manager.save() 294 | 295 | def load_images(imgs_path): 296 | imgs = [] 297 | for img_name in os.listdir(imgs_path): 298 | img_path = join(imgs_path, img_name) 299 | if isfile(img_path): 300 | img = cv.imread(img_path) 301 | img = np.float32(img) 302 | imgs.append(img) 303 | return imgs 304 | 305 | def image_normalization(img, img_min=0, img_max=255): 306 | """This is a typical image normalization function 307 | where the minimum and maximum of the image are needed 308 | source: 
https://en.wikipedia.org/wiki/Normalization_(image_processing) 309 | :param img: an image, grayscale or color 310 | :param img_min: minimum output value, default 0 311 | :param img_max: maximum output value, default 255 312 | :return: the normalized image as float32 313 | """ 314 | img = np.float32(img) 315 | epsilon = 1e-12 # guard against division by zero for constant images 316 | img = (img-np.min(img))*(img_max-img_min)/((np.max(img)-np.min(img))+epsilon)+img_min 317 | return img 318 | 319 | def generate_and_save_samples_l(gan, W=64, H=64, save_path="./", name="sample"): 320 | outputs = gan.generate_samples() 321 | concat_images = [] 322 | if gan.is_msg: 323 | for img in outputs: 324 | img = image_normalization(img)[0, :] 325 | img = cv.resize(img, (W, H), interpolation=cv.INTER_NEAREST) 326 | concat_images.append(img) 327 | concat_output = cv.hconcat(concat_images) 328 | else: 329 | concat_output = image_normalization(outputs)[0, :] 330 | cv.imwrite(save_path + name + ".png", concat_output) 331 | 332 | def resize_batch_images(imgs, img_min=-1., img_max=1., size_list=[(64, 64), (32, 32), (16, 16), (8, 8), (4, 4)]): 333 | batch_resized_images = [] 334 | for W, H in size_list: 335 | resized_list = [] 336 | for img in imgs: 337 | img = image_normalization(img, img_min=img_min, img_max=img_max) 338 | resized_img = cv.resize(img, (W, H), interpolation=cv.INTER_AREA) 339 | resized_list.append(resized_img) 340 | batch_resized_images.append(np.array(resized_list)) 341 | return batch_resized_images 342 | 343 | def data_augmentation(img, min_rot_angle=-180, max_rot_angle=180, crop_ratio=0.2, smooth_size=3, sharp_val=3, max_noise_scale=10): 344 | (H, W) = img.shape[:2] 345 | img_a = img 346 | 347 | all_func = ['flip', 'rotate', 'crop', 'smooth', 'sharp', 'noise'] 348 | do_func = np.random.choice(all_func, size=np.random.randint(1, len(all_func)), replace=False) 349 | #do_func = ['crop'] 350 | # Flip image, 0: vertically, 1: horizontally 351 | if 'flip' in do_func: 352 | img_a = cv.flip(img_a, np.random.choice([0, 1])) 353 | # Rotate image 354 | if 'rotate' in do_func: 355 | rot_ang = np.random.uniform(min_rot_angle, max_rot_angle) 356 | img_a = imutils.rotate_bound(img_a, rot_ang) 357 | # Crop image 358 | if 'crop' in do_func: 359 | (H_A, W_A) = img_a.shape[:2] 360 | start_x = np.random.randint(0, int(H_A * crop_ratio)) 361 | start_y = np.random.randint(0, int(W_A * crop_ratio)) 362 | end_x = np.random.randint(int(H_A * (1-crop_ratio)), H_A) 363 | end_y = np.random.randint(int(W_A * (1-crop_ratio)), W_A) 364 | 365 | img_a = img_a[start_x:end_x, start_y:end_y] 366 | # Smoothing 367 | if 'smooth' in do_func: 368 | img_a = cv.GaussianBlur(img_a, (smooth_size, smooth_size), 0) 369 | # Sharpening 370 | if 'sharp' in do_func: 371 | de_sharp_val = -(sharp_val - 1) / 8 372 | kernel = np.array([[de_sharp_val]*3, [de_sharp_val, sharp_val, de_sharp_val], [de_sharp_val]*3]) 373 | img_a = cv.filter2D(img_a, -1, kernel) 374 | # Add Gaussian noise to the image 375 | if 'noise' in do_func: 376 | noise_scale = np.random.uniform(0, max_noise_scale) 377 | gauss = np.random.normal(0, noise_scale, img_a.size) 378 | gauss = np.float32(gauss.reshape(img_a.shape[0], img_a.shape[1], img_a.shape[2])) 379 | img_a = cv.add(img_a, gauss) 380 | # Keep shape 381 | img_a = cv.resize(img_a, (W, H)) 382 | return np.float32(img_a) 383 | 384 | def train(): 385 | ds = load_images(images_path) 386 | 387 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 388 | summary_writer = tf.summary.create_file_writer(log_dir + "fit/" + current_time) 
389 | 390 | msgan = MSGGAN(LATENT_DIM, beta_1_g, beta_2_g, beta_1_d, beta_2_d, IS_MSG) 391 | 392 | size_list = [(64, 64), (32, 32), (16, 16), (8, 8), (4, 4)] if IS_MSG else [(64, 64)] 393 | 394 | ''' 395 | load_model(saved_model_path + "MSGGAN_Generator_" + load_model_id, msgan.generator, max_checkpoints_number) 396 | load_model(saved_model_path + "MSGGAN_Discriminator_" + load_model_id, msgan.discriminator, max_checkpoints_number) 397 | ''' 398 | ds_len = len(ds) 399 | for epoch in range(EPOCHS): 400 | start = time.time() 401 | 402 | random.shuffle(ds) 403 | tf.print("Epoch: ", epoch) 404 | 405 | l = epoch / EPOCHS 406 | lr_g = lr_g_max - l * (lr_g_max - lr_g_min) 407 | lr_d = lr_d_max - l * (lr_d_max - lr_d_min) 408 | # Train 409 | for n in range(0, ds_len, BATCH_SIZE): 410 | target_image_b = [] 411 | for i_n in range(n, n + BATCH_SIZE): 412 | tf.print('.', end='') 413 | i_n %= ds_len 414 | if (i_n + 1) % 100 == 0: 415 | name_id = "_e" + str(epoch) + "_s" + str(i_n + 1) 416 | generate_and_save_samples_l(msgan, save_path="./Projects/GAN_Exp/generated_samples/", name=name_id) 417 | tf.print("\n") 418 | 419 | target_image = ds[i_n] 420 | #target_image = data_augmentation(target_image) 421 | target_image_b.append(target_image) 422 | 423 | target_image_b = resize_batch_images(target_image_b, size_list=size_list) 424 | 425 | train_dict = msgan.train_step(target_image_b, BATCH_SIZE, lr_g, lr_d) 426 | 427 | with summary_writer.as_default(): 428 | for scalar_name in train_dict: 429 | tf.summary.scalar(scalar_name, train_dict[scalar_name], step=epoch) 430 | tf.print(scalar_name + ": {}".format(train_dict[scalar_name]), end=', ') 431 | tf.print("\n") 432 | 433 | #generate_and_save_samples_l(msgan, save_path="./Projects/GAN_Exp/generated_samples/") 434 | 435 | # Save the model every certain number of epochs 436 | if (epoch + 1) % save_every_number_epochs == 0: 437 | save_model(saved_model_path + "MSGGAN_Generator_" + current_time, msgan.generator, max_checkpoints_number) 438 | save_model(saved_model_path + "MSGGAN_Discriminator_" + current_time, msgan.discriminator, max_checkpoints_number) 439 | 440 | tf.print('\nTime taken for epoch {} is {} sec\n'.format(epoch + 1, time.time()-start)) 441 | 442 | # call this function when using GPU with small memory 443 | def using_gpu_memory_growth(): 444 | physical_devices = tf.config.experimental.list_physical_devices('GPU') 445 | assert len(physical_devices) > 0, "Not enough GPU hardware devices available" 446 | config = tf.config.experimental.set_memory_growth(physical_devices[0], True) 447 | return config 448 | 449 | if __name__ == "__main__": 450 | using_gpu_memory_growth() 451 | train() 452 | 453 | -------------------------------------------------------------------------------- /NST.py: -------------------------------------------------------------------------------- 1 | # Neural style transfer 2 | import tensorflow as tf 3 | from tensorflow import keras 4 | 5 | import os 6 | import time 7 | import numpy as np 8 | from PIL import Image 9 | import functools 10 | import matplotlib.pyplot as plt 11 | import matplotlib as mpl 12 | import IPython.display as display 13 | 14 | # Load compressed models from tensorflow_hub 15 | os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED' 16 | 17 | mpl.rcParams['figure.figsize'] = (12,12) 18 | mpl.rcParams['axes.grid'] = False 19 | 20 | content_layers = ['block5_conv2'] 21 | style_layers = ['block1_conv1', 22 | 'block2_conv1', 23 | 'block3_conv1', 24 | 'block4_conv1', 25 | 'block5_conv1'] 26 | 27 | epochs = 10 28 
| steps_per_epoch = 100 29 | 30 | style_weight = 1e-2 31 | content_weight = 1e4 32 | total_variation_weight = 30 33 | learning_rate = 0.02 34 | beta_1 = 0.99 35 | epsilon = 1e-1 36 | 37 | def load_img(path_to_img): 38 | max_dim = 512 39 | img = tf.io.read_file(path_to_img) 40 | img = tf.image.decode_image(img, channels=3) 41 | img = tf.image.convert_image_dtype(img, tf.float32) 42 | 43 | shape = tf.cast(tf.shape(img)[:-1], tf.float32) 44 | long_dim = max(shape) 45 | scale = max_dim / long_dim 46 | 47 | new_shape = tf.cast(shape * scale, tf.int32) 48 | 49 | img = tf.image.resize(img, new_shape) 50 | img = img[tf.newaxis, :] 51 | return img 52 | 53 | def imshow(image, title=None): 54 | if len(image.shape) > 3: 55 | image = tf.squeeze(image, axis=0) 56 | 57 | plt.imshow(image) 58 | if title: 59 | plt.title(title) 60 | 61 | class StyleContentModel(keras.models.Model): 62 | def __init__(self, style_layers, content_layers): 63 | super(StyleContentModel, self).__init__() 64 | self.vgg = self.vgg_layers(style_layers + content_layers) 65 | self.style_layers = style_layers 66 | self.content_layers = content_layers 67 | self.num_style_layers = len(style_layers) 68 | self.num_content_layers = len(content_layers) 69 | self.vgg.trainable = False 70 | self.opt = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=beta_1, epsilon=epsilon) 71 | 72 | def call(self, inputs): 73 | "Expects float input in [0,1]" 74 | inputs = inputs*255.0 75 | preprocessed_input = keras.applications.vgg19.preprocess_input(inputs) 76 | outputs = self.vgg(preprocessed_input) 77 | style_outputs, content_outputs = (outputs[:self.num_style_layers], 78 | outputs[self.num_style_layers:]) 79 | 80 | style_outputs = [self.gram_matrix(style_output) 81 | for style_output in style_outputs] 82 | 83 | content_dict = {content_name:value 84 | for content_name, value 85 | in zip(self.content_layers, content_outputs)} 86 | 87 | style_dict = {style_name:value 88 | for style_name, value 89 | in zip(self.style_layers, style_outputs)} 90 | 91 | return {'content':content_dict, 'style':style_dict} 92 | 93 | def gram_matrix(self, input_tensor): 94 | result = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor) 95 | input_shape = tf.shape(input_tensor) 96 | num_locations = tf.cast(input_shape[1]*input_shape[2], tf.float32) 97 | return result/(num_locations) 98 | 99 | def vgg_layers(self, layer_names): 100 | """ Creates a vgg model that returns a list of intermediate output values.""" 101 | # Load our model. 
Load pretrained VGG, trained on imagenet data 102 | 103 | vgg = keras.applications.VGG19(include_top=False, weights='imagenet') 104 | vgg.trainable = False 105 | 106 | outputs = [vgg.get_layer(name).output for name in layer_names] 107 | 108 | model = keras.Model([vgg.input], outputs) 109 | return model 110 | 111 | def style_content_loss(outputs, extractor, style_targets, content_targets): 112 | style_outputs = outputs['style'] 113 | content_outputs = outputs['content'] 114 | style_loss = tf.add_n([tf.reduce_mean((style_outputs[name]-style_targets[name])**2) 115 | for name in style_outputs.keys()]) 116 | style_loss *= style_weight / extractor.num_style_layers 117 | 118 | content_loss = tf.add_n([tf.reduce_mean((content_outputs[name]-content_targets[name])**2) 119 | for name in content_outputs.keys()]) 120 | content_loss *= content_weight / extractor.num_content_layers 121 | loss = style_loss + content_loss 122 | return loss 123 | 124 | @tf.function() 125 | def train_step(image, extractor, style_targets, content_targets): 126 | with tf.GradientTape() as tape: 127 | outputs = extractor(image) 128 | loss = style_content_loss(outputs, extractor, style_targets, content_targets) 129 | loss += total_variation_weight*tf.image.total_variation(image) 130 | 131 | grad = tape.gradient(loss, image) 132 | extractor.opt.apply_gradients([(grad, image)]) 133 | image.assign(tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0)) 134 | 135 | def tensor_to_image(tensor): 136 | tensor = tensor*255 137 | tensor = np.array(tensor, dtype=np.uint8) 138 | if np.ndim(tensor)>3: 139 | assert tensor.shape[0] == 1 140 | tensor = tensor[0] 141 | return Image.fromarray(tensor) 142 | 143 | def train(): 144 | # Load and show both content and style image 145 | content_path = keras.utils.get_file('YellowLabradorLooking_new.jpg', 'https://storage.googleapis.com/download.tensorflow.org/example_images/YellowLabradorLooking_new.jpg') 146 | style_path = keras.utils.get_file('kandinsky5.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/Vassily_Kandinsky%2C_1913_-_Composition_7.jpg') 147 | content_image = load_img(content_path) 148 | style_image = load_img(style_path) 149 | 150 | plt.subplot(1, 2, 1) 151 | imshow(content_image, 'Content Image') 152 | plt.subplot(1, 2, 2) 153 | imshow(style_image, 'Style Image') 154 | plt.show() 155 | 156 | start = time.time() 157 | generated_image = tf.Variable(content_image) 158 | 159 | extractor = StyleContentModel(style_layers, content_layers) 160 | style_targets = extractor(style_image)['style'] 161 | content_targets = extractor(content_image)['content'] 162 | 163 | step = 0 164 | for n in range(epochs): 165 | for m in range(steps_per_epoch): 166 | step += 1 167 | train_step(generated_image, extractor, style_targets, content_targets) 168 | print(".", end='') 169 | display.clear_output(wait=True) 170 | display.display(tensor_to_image(generated_image)) 171 | print("Train step: {}".format(step)) 172 | 173 | end = time.time() 174 | print("Total time: {:.1f}".format(end-start)) 175 | 176 | if __name__ == "__main__": 177 | train() 178 | -------------------------------------------------------------------------------- /PGGAN.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Progressive Growing GAN on celebrity faces dataset 3 | CelebA dataset: https://www.kaggle.com/jessicali9530/celeba-dataset 4 | ''' 5 | import tensorflow as tf 6 | from tensorflow.keras import layers, models, initializers, constraints, optimizers 7 | 
import tensorflow.keras.backend as K 8 | import numpy as np 9 | 10 | import os 11 | from os.path import isfile, isdir, join 12 | 13 | from skimage import io 14 | from skimage.transform import resize 15 | from matplotlib import pyplot 16 | 17 | # pixel-wise feature vector normalization layer 18 | class PixelNormalization(layers.Layer): 19 | # initialize the layer 20 | def __init__(self, **kwargs): 21 | super(PixelNormalization, self).__init__(**kwargs) 22 | 23 | # L2 norm for all the activations in the same image and at the same location across all channels 24 | def call(self, inputs): 25 | return inputs / K.sqrt(K.mean(inputs**2, axis=-1, keepdims=True) + 1e-8) 26 | 27 | # define the output shape of the layer 28 | def compute_output_shape(self, input_shape): 29 | return input_shape 30 | 31 | # mini-batch standard deviation layer 32 | class MinibatchStdev(layers.Layer): 33 | def __init__(self, **kwargs): 34 | super(MinibatchStdev, self).__init__(**kwargs) 35 | 36 | # calculate the mean standard deviation across each pixel coord 37 | def call(self, inputs): 38 | mean = K.mean(inputs, axis=0, keepdims=True) 39 | mean_sq_diff = K.mean(K.square(inputs - mean), axis=0, keepdims=True) + 1e-8 40 | mean_pix = K.mean(K.sqrt(mean_sq_diff), keepdims=True) 41 | shape = K.shape(inputs) 42 | output = K.tile(mean_pix, [shape[0], shape[1], shape[2], 1]) 43 | return K.concatenate([inputs, output], axis=-1) 44 | 45 | # define the output shape of the layer 46 | def compute_output_shape(self, input_shape): 47 | input_shape = list(input_shape) 48 | input_shape[-1] += 1 49 | return tuple(input_shape) 50 | 51 | # weighted sum output 52 | class WeightedSum(layers.Add): 53 | def __init__(self, **kwargs): 54 | super(WeightedSum, self).__init__(**kwargs) 55 | self.alpha = K.variable(0., name='ws_alpha') 56 | 57 | # output a weighted sum of inputs, only supports a weighted sum of two inputs 58 | def _merge_function(self, inputs): 59 | assert (len(inputs) == 2) 60 | output = ((1.0 - self.alpha) * inputs[0]) + (self.alpha * inputs[1]) 61 | return output 62 | 63 | def wasserstein_loss(y_true, y_pred): 64 | return K.mean(y_true * y_pred) 65 | 66 | def add_discriminator_block(old_model, n_input_layers=3): 67 | init = initializers.RandomNormal(stddev=0.02) 68 | const = constraints.max_norm(1.0) 69 | # get shape of existing model 70 | in_shape = list(old_model.input.shape) 71 | # define new input shape as double the size 72 | input_shape = (in_shape[-3]*2, in_shape[-2]*2, in_shape[-1]) 73 | in_image = layers.Input(shape=input_shape) 74 | # define new input processing layer 75 | d = layers.Conv2D(128, (1,1), padding='same', kernel_initializer=init, kernel_constraint=const)(in_image) 76 | d = layers.LeakyReLU(alpha=0.2)(d) 77 | # define new block 78 | d = layers.Conv2D(128, (3,3), padding='same', kernel_initializer=init, kernel_constraint=const)(d) 79 | d = layers.LeakyReLU(alpha=0.2)(d) 80 | d = layers.Conv2D(128, (3,3), padding='same', kernel_initializer=init, kernel_constraint=const)(d) 81 | d = layers.LeakyReLU(alpha=0.2)(d) 82 | d = layers.AveragePooling2D()(d) 83 | block_new = d 84 | # skip the input, 1x1 and activation for the old model 85 | for i in range(n_input_layers, len(old_model.layers)): 86 | d = old_model.layers[i](d) 87 | # define straight-through model 88 | model1 = models.Model(in_image, d) 89 | model1.compile(loss=wasserstein_loss, optimizer=optimizers.Adam(lr=0.001, beta_1=0, beta_2=0.99, epsilon=10e-8)) 90 | # downsample the new larger image 91 | downsample = layers.AveragePooling2D()(in_image) 92 | # 
connect old input processing to downsampled new input 93 | block_old = old_model.layers[1](downsample) 94 | block_old = old_model.layers[2](block_old) 95 | # fade in output of old model input layer with new input 96 | d = WeightedSum()([block_old, block_new]) 97 | # skip the input, 1x1 and activation for the old model 98 | for i in range(n_input_layers, len(old_model.layers)): 99 | d = old_model.layers[i](d) 100 | # define straight-through model 101 | model2 = models.Model(in_image, d) 102 | model2.compile(loss=wasserstein_loss, optimizer=optimizers.Adam(lr=0.001, beta_1=0, beta_2=0.99, epsilon=10e-8)) 103 | return [model1, model2] 104 | 105 | # define the discriminator models for each image resolution 106 | def define_discriminator(n_blocks, input_shape=(4,4,3)): 107 | init = initializers.RandomNormal(stddev=0.02) 108 | const = constraints.max_norm(1.0) 109 | model_list = list() 110 | in_image = layers.Input(shape=input_shape) 111 | # conv 1x1 112 | d = layers.Conv2D(128, (1,1), padding='same', kernel_initializer=init, kernel_constraint=const)(in_image) 113 | d = layers.LeakyReLU(alpha=0.2)(d) 114 | # conv 3x3 (output block) 115 | d = MinibatchStdev()(d) 116 | d = layers.Conv2D(128, (3,3), padding='same', kernel_initializer=init, kernel_constraint=const)(d) 117 | d = layers.LeakyReLU(alpha=0.2)(d) 118 | # conv 4x4 119 | d = layers.Conv2D(128, (4,4), padding='same', kernel_initializer=init, kernel_constraint=const)(d) 120 | d = layers.LeakyReLU(alpha=0.2)(d) 121 | # dense output layer 122 | d = layers.Flatten()(d) 123 | out_class = layers.Dense(1)(d) 124 | 125 | model = models.Model(in_image, out_class) 126 | model.compile(loss=wasserstein_loss, optimizer=optimizers.Adam(lr=0.001, beta_1=0, beta_2=0.99, epsilon=10e-8)) 127 | model_list.append([model, model]) 128 | # create submodels 129 | for i in range(1, n_blocks): 130 | # get prior model without the fade-on 131 | old_model = model_list[i-1][0] 132 | # create new model for next resolution 133 | new_models = add_discriminator_block(old_model) 134 | model_list.append(new_models) 135 | return model_list 136 | 137 | def add_generator_block(old_model): 138 | init = initializers.RandomNormal(stddev=0.02) 139 | const = constraints.max_norm(1.0) 140 | # get the end of the last block 141 | block_end = old_model.layers[-2].output 142 | # upsample, and define new block 143 | upsampling = layers.UpSampling2D()(block_end) 144 | g = layers.Conv2D(128, (3,3), padding='same', kernel_initializer=init, kernel_constraint=const)(upsampling) 145 | g = PixelNormalization()(g) 146 | g = layers.LeakyReLU(alpha=0.2)(g) 147 | g = layers.Conv2D(128, (3,3), padding='same', kernel_initializer=init, kernel_constraint=const)(g) 148 | g = PixelNormalization()(g) 149 | g = layers.LeakyReLU(alpha=0.2)(g) 150 | # add new output layer 151 | out_image = layers.Conv2D(3, (1,1), padding='same', kernel_initializer=init, kernel_constraint=const)(g) 152 | 153 | model1 = models.Model(old_model.input, out_image) 154 | # get the output layer from old model 155 | out_old = old_model.layers[-1] 156 | # connect the upsampling to the old output layer 157 | out_image2 = out_old(upsampling) 158 | # define new output image as the weighted sum of the old and new models 159 | merged = WeightedSum()([out_image2, out_image]) 160 | 161 | model2 = models.Model(old_model.input, merged) 162 | return [model1, model2] 163 | 164 | def define_generator(latent_dim, n_blocks, in_dim=4): 165 | init = initializers.RandomNormal(stddev=0.02) 166 | const = constraints.max_norm(1.0) 167 | model_list = list() 
168 | in_latent = layers.Input(shape=(latent_dim,)) 169 | # linear scale up to activation maps 170 | g = layers.Dense(128 * in_dim * in_dim, kernel_initializer=init, kernel_constraint=const)(in_latent) 171 | g = layers.Reshape((in_dim, in_dim, 128))(g) 172 | # conv 4x4, input block 173 | g = layers.Conv2D(128, (3,3), padding='same', kernel_initializer=init, kernel_constraint=const)(g) 174 | g = PixelNormalization()(g) 175 | g = layers.LeakyReLU(alpha=0.2)(g) 176 | # conv 3x3 177 | g = layers.Conv2D(128, (3,3), padding='same', kernel_initializer=init, kernel_constraint=const)(g) 178 | g = PixelNormalization()(g) 179 | g = layers.LeakyReLU(alpha=0.2)(g) 180 | # conv 1x1, output block 181 | out_image = layers.Conv2D(3, (1,1), padding='same', kernel_initializer=init, kernel_constraint=const)(g) 182 | 183 | model = models.Model(in_latent, out_image) 184 | model_list.append([model, model]) 185 | # create submodels 186 | for i in range(1, n_blocks): 187 | # get prior model without the fade-on 188 | old_model = model_list[i-1][0] 189 | # create new model for next resolution 190 | new_models = add_generator_block(old_model) 191 | model_list.append(new_models) 192 | return model_list 193 | 194 | # define composite models for training generators via discriminators 195 | def define_composite(discriminators, generators): 196 | model_list = list() 197 | for i in range(len(discriminators)): 198 | g_models, d_models = generators[i], discriminators[i] 199 | # straight-through model 200 | d_models[0].trainable = False 201 | model1 = models.Sequential() 202 | model1.add(g_models[0]) 203 | model1.add(d_models[0]) 204 | model1.compile(loss=wasserstein_loss, optimizer=optimizers.Adam(lr=0.001, beta_1=0, beta_2=0.99, epsilon=10e-8)) 205 | # fade-in model 206 | d_models[1].trainable = False 207 | model2 = models.Sequential() 208 | model2.add(g_models[1]) 209 | model2.add(d_models[1]) 210 | model2.compile(loss=wasserstein_loss, optimizer=optimizers.Adam(lr=0.001, beta_1=0, beta_2=0.99, epsilon=10e-8)) 211 | model_list.append([model1, model2]) 212 | return model_list 213 | 214 | # load images from folder 215 | def load_real_samples(imgs_path): 216 | imgs = [] 217 | for img_name in os.listdir(imgs_path): 218 | img_path = join(imgs_path, img_name) 219 | if isfile(img_path): 220 | img = io.imread(img_path) 221 | img = np.float32(img) 222 | imgs.append(img) 223 | imgs = np.array(imgs, dtype=np.float32) 224 | # scale from [0,255] to [-1,1] 225 | imgs = imgs / 127.5 - 1. 
226 | return imgs 227 | 228 | # select real samples 229 | def generate_real_samples(dataset, n_samples): 230 | # select random instances 231 | X = dataset[np.random.randint(0, dataset.shape[0], n_samples)] 232 | # generate class labels 233 | y = np.ones((n_samples, 1)) 234 | return X, y 235 | 236 | # generate points in latent space as input for the generator 237 | def generate_latent_points(latent_dim, n_samples): 238 | # generate points in the latent space 239 | x_input = np.random.randn(latent_dim * n_samples) 240 | # reshape into a batch of inputs for the network 241 | x_input = x_input.reshape(n_samples, latent_dim) 242 | return x_input 243 | 244 | # use the generator to generate n fake examples, with class labels 245 | def generate_fake_samples(generator, latent_dim, n_samples): 246 | # generate points in latent space 247 | x_input = generate_latent_points(latent_dim, n_samples) 248 | # predict outputs 249 | X = generator.predict(x_input) 250 | # create class labels 251 | y = -np.ones((n_samples, 1)) 252 | return X, y 253 | 254 | # update the alpha value on each instance of WeightedSum 255 | def update_fadein(models, step, n_steps): 256 | # calculate current alpha (linear from 0 to 1) 257 | alpha = step / float(n_steps - 1) 258 | # update the alpha for each model 259 | for model in models: 260 | for layer in model.layers: 261 | if isinstance(layer, WeightedSum): 262 | K.set_value(layer.alpha, alpha) 263 | 264 | # train a generator and discriminator 265 | def train_epochs(g_model, d_model, gan_model, dataset, n_epochs, n_batch, fadein=False): 266 | # calculate the number of batches per training epoch 267 | bat_per_epo = int(dataset.shape[0] / n_batch) 268 | # calculate the number of training iterations 269 | n_steps = bat_per_epo * n_epochs 270 | # calculate the size of half a batch of samples 271 | half_batch = int(n_batch / 2) 272 | # manually enumerate epochs 273 | for i in range(n_steps): 274 | # update alpha for all WeightedSum layers when fading in new blocks 275 | if fadein: 276 | update_fadein([g_model, d_model, gan_model], i, n_steps) 277 | # prepare real and fake samples 278 | X_real, y_real = generate_real_samples(dataset, half_batch) 279 | X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch) 280 | # update discriminator model 281 | d_loss1 = d_model.train_on_batch(X_real, y_real) 282 | d_loss2 = d_model.train_on_batch(X_fake, y_fake) 283 | # update the generator via the discriminator's error 284 | z_input = generate_latent_points(latent_dim, n_batch) 285 | y_real2 = np.ones((n_batch, 1)) 286 | g_loss = gan_model.train_on_batch(z_input, y_real2) 287 | # summarize loss on this batch 288 | print('>%d, d1=%.3f, d2=%.3f g=%.3f' % (i+1, d_loss1, d_loss2, g_loss)) 289 | 290 | # scale images to preferred size with nearest neighbor interpolation 291 | def scale_dataset(images, new_shape): 292 | images_list = list() 293 | for image in images: 294 | new_image = resize(image, new_shape, 0) 295 | images_list.append(new_image) 296 | return np.asarray(images_list) 297 | 298 | # generate samples and save as a plot and save the model 299 | def summarize_performance(status, g_model, latent_dim, n_samples=25): 300 | gen_shape = g_model.output_shape 301 | name = '%03dx%03d-%s' % (gen_shape[1], gen_shape[2], status) 302 | 303 | X, _ = generate_fake_samples(g_model, latent_dim, n_samples) 304 | # normalize pixel values to the range [0,1] 305 | X = (X - X.min()) / (X.max() - X.min()) 306 | # plot real images 307 | square = int(np.sqrt(n_samples)) 308 | for i in 
range(n_samples): 309 | pyplot.subplot(square, square, 1 + i) 310 | pyplot.axis('off') 311 | pyplot.imshow(X[i]) 312 | # save plot to file 313 | filename1 = 'plot_%s.png' % (name) 314 | pyplot.savefig(filename1) 315 | pyplot.close() 316 | # save the generator model 317 | filename2 = 'model_%s.h5' % (name) 318 | g_model.save(filename2) 319 | print('>Saved: %s and %s' % (filename1, filename2)) 320 | 321 | # train the generator and discriminator 322 | def train(g_models, d_models, gan_models, dataset, latent_dim, e_norm, e_fadein, n_batch): 323 | # fit the baseline model 324 | g_normal, d_normal, gan_normal = g_models[0][0], d_models[0][0], gan_models[0][0] 325 | # scale dataset to appropriate size 326 | gen_shape = g_normal.output_shape 327 | scaled_data = scale_dataset(dataset, gen_shape[1:]) 328 | print('Scaled Data', scaled_data.shape) 329 | # train normal or straight-through models 330 | train_epochs(g_normal, d_normal, gan_normal, scaled_data, e_norm[0], n_batch[0]) 331 | summarize_performance('tuned', g_normal, latent_dim) 332 | # process each level of growth 333 | for i in range(1, len(g_models)): 334 | # retrieve models for this level of growth 335 | [g_normal, g_fadein] = g_models[i] 336 | [d_normal, d_fadein] = d_models[i] 337 | [gan_normal, gan_fadein] = gan_models[i] 338 | # scale dataset to appropriate size 339 | gen_shape = g_normal.output_shape 340 | scaled_data = scale_dataset(dataset, gen_shape[1:]) 341 | print('Scaled Data', scaled_data.shape) 342 | # train fade-in models for next level of growth 343 | train_epochs(g_fadein, d_fadein, gan_fadein, scaled_data, e_fadein[i], n_batch[i], True) 344 | summarize_performance('faded', g_fadein, latent_dim) 345 | # train normal or straight-through models 346 | train_epochs(g_normal, d_normal, gan_normal, scaled_data, e_norm[i], n_batch[i]) 347 | summarize_performance('tuned', g_normal, latent_dim) 348 | 349 | # call this function when using GPU with small memory 350 | def using_gpu_memory_growth(): 351 | physical_devices = tf.config.experimental.list_physical_devices('GPU') 352 | assert len(physical_devices) > 0, "Not enough GPU hardware devices available" 353 | config = tf.config.experimental.set_memory_growth(physical_devices[0], True) 354 | return config 355 | 356 | if __name__ == "__main__": 357 | using_gpu_memory_growth() 358 | # number of growth phases, e.g. 
6 == [4, 8, 16, 32, 64, 128] 359 | n_blocks = 6 360 | # size of the latent space 361 | latent_dim = 100 362 | 363 | n_batch = [16, 16, 16, 8, 4, 4] 364 | # 1 epochs == 3K images per training phase 365 | n_epochs = [5, 8, 8, 10, 10, 10] 366 | 367 | dataset = load_real_samples('./Projects/GAN_Exp/Face_64X64/') 368 | print('Loaded', dataset.shape) 369 | 370 | d_models = define_discriminator(n_blocks) 371 | g_models = define_generator(latent_dim, n_blocks) 372 | gan_models = define_composite(d_models, g_models) 373 | 374 | train(g_models, d_models, gan_models, dataset, latent_dim, n_epochs, n_epochs, n_batch) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative Models Collection 2 | Generative Model related code and info 3 | 4 | | ML Models | Reference | 5 | | ------------- | ------------- | 6 | | [Self-Attention Generative Adversarial Networks](SAGAN.py) | [Arxiv](https://arxiv.org/abs/1805.08318) | 7 | | [Multi-Scale Gradients Generative Adversarial Networks](MSGGAN.py) | [Arxiv](https://arxiv.org/abs/1903.06048) | 8 | | [Cycle-Consistent Generative Adversarial Networks](CycleGAN.py) | [Machine Learning Mastery](https://machinelearningmastery.com/cyclegan-tutorial-with-keras/); [Arxiv](https://arxiv.org/abs/1703.10593) | 9 | | [Progressive Growing Generative Adversarial Networks](PGGAN.py) | [Machine Learning Mastery](https://machinelearningmastery.com/how-to-train-a-progressive-growing-gan-in-keras-for-synthesizing-faces/); [Arxiv](https://arxiv.org/abs/1710.10196) | 10 | | [Neural style transfer](NST.py) | [Tensorflow](https://www.tensorflow.org/tutorials/generative/style_transfer); [Arxiv](https://arxiv.org/abs/1508.06576) | 11 | 12 |
13 | 14 | ## Training & Stabilizing 15 | * Use a small dataset like [CIFAR](https://www.cs.toronto.edu/~kriz/cifar.html), [CelebA](https://www.kaggle.com/jessicali9530/celeba-dataset), or [Horse2Zebra](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip) to test the algorithm before scaling the whole thing up: save whatever time you can! 16 | 17 | * Normalize the inputs between -1 and 1 for both the generator and the discriminator (this and the next tip are illustrated in a sketch below) 18 | 19 | * Sample from a Gaussian distribution rather than a uniform distribution 20 | 21 | * All normalization methods are potentially helpful: **SpectralNorm**, **BatchNorm** (with labels use **ConditionalBatchNorm**, without labels use **SelfModulationBatchNorm**); consider other methods like **InstanceNorm** or **PixelNorm** when BatchNorm is not an option 22 | 23 | * Avoid sparse gradients (e.g. **ReLU**, **MaxPool**): use **LeakyReLU** in both the generator (plain **ReLU** can also work) and the discriminator; for downsampling, use **Average Pooling** or **Conv2d + stride**; for upsampling, use **NearestUpsampling**, **PixelShuffle**, or **ConvTranspose2d + stride** 24 | 25 | * A model architecture built around the usual **ResnetBlock** pattern is usually better than one without 26 | 27 | * The generator tends to overfit after many epochs of training: once it generates good samples, the discriminator has a hard time discriminating and starts producing misleading gradients for the generator; to avoid this, use the data augmentation technique called [**Adaptive Discriminator Augmentation (ADA)**](https://github.com/NVlabs/stylegan2-ada) 28 | 29 | * Use stability tricks from RL, e.g. **Experience Replay** (implemented as `update_image_pool` in [CycleGAN.py](CycleGAN.py)) 30 | 31 | * Try different loss functions (a TensorFlow sketch of the hinge loss is given below) 32 | 33 | ``` 34 | # usually hinge loss works well 35 | L_D = −E[min(0, −1 + D(x))] − E[min(0, −1 − D(G(z)))] 36 | L_G = −E[D(G(z))] 37 | ``` 38 | ![](Summary-of-Different-GAN-Loss-Functions.png) 39 | 40 | * The **Adam** and **RMSprop** optimizers usually work well 41 | 42 | * Track failures early using plots (e.g. TensorBoard); some common failure modes are: 43 | * the discriminator loss goes to 0 44 | * the gradients of the generator or discriminator are extremely large or extremely small 45 | * the loss of the generator steadily decreases, in which case it is likely fooling the discriminator with garbage images 46 | * the outputs of the generator have a large percentage of repetitions, which indicates mode collapse 47 | 48 | * Other methods for improving the convergence of GANs: 49 | * Feature matching: train the generator to match discriminator feature statistics on real data, also the basis of semi-supervised GAN training (see the sketch below) 50 | * Minibatch discrimination: develop features across multiple samples in a minibatch (a simplified variant, `MinibatchStdev`, appears in [PGGAN.py](PGGAN.py) and [MSGGAN.py](MSGGAN.py)) 51 | * Virtual batch normalization: calculate batch norm statistics using a reference batch of real images 52 | * One-sided soft labels for some loss functions (e.g. sigmoid cross-entropy loss): real=0.9, fake=0 (see the sketch below) 53 | * Two-timescale update rule (TTUR): use separate learning rates for the generator and the discriminator (usually lr_d > lr_g), making it possible to use fewer discriminator steps per generator step (see the sketch below) 54 | * All regularization and constraint methods are potentially helpful, e.g. use dropout of 50 percent during training and generation 55 |
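A minimal sketch of the first two tips; the function names here are illustrative, not taken from the scripts in this repo:

```
import numpy as np
import tensorflow as tf

def normalize_images(imgs):
    # scale pixels from [0, 255] to [-1, 1], matching the tanh output layer of a generator
    return np.asarray(imgs, dtype=np.float32) / 127.5 - 1.0

def sample_latent(batch_size, latent_dim):
    # draw latent vectors from a Gaussian rather than a uniform distribution
    return tf.random.normal([batch_size, latent_dim])
```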
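The hinge loss above translates to TensorFlow roughly as follows; `d_real` and `d_fake` stand for the discriminator's raw logits on real and generated batches. This is a sketch, equivalent to the `hinge_loss`/`wasserstein_loss` pair used in [MSGGAN.py](MSGGAN.py):

```
import tensorflow as tf

def d_hinge_loss(d_real, d_fake):
    # L_D = -E[min(0, -1 + D(x))] - E[min(0, -1 - D(G(z)))], since relu(1 - t) == -min(0, -1 + t)
    return tf.reduce_mean(tf.nn.relu(1. - d_real)) + tf.reduce_mean(tf.nn.relu(1. + d_fake))

def g_hinge_loss(d_fake):
    # L_G = -E[D(G(z))]
    return -tf.reduce_mean(d_fake)
```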
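A sketch of feature matching; `feature_extractor` is assumed to be a Keras model exposing an intermediate discriminator layer (it is not defined anywhere in this repo):

```
import tensorflow as tf

def feature_matching_loss(feature_extractor, real_images, fake_images):
    # match the mean intermediate-feature statistics of real and generated batches
    real_feats = tf.reduce_mean(feature_extractor(real_images), axis=0)
    fake_feats = tf.reduce_mean(feature_extractor(fake_images), axis=0)
    return tf.reduce_mean(tf.square(real_feats - fake_feats))
```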
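One-sided label smoothing with the sigmoid cross-entropy loss might look like this; a sketch in which only the real labels are softened:

```
import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def d_smoothed_loss(d_real, d_fake):
    real_loss = bce(tf.fill(tf.shape(d_real), 0.9), d_real)  # real -> 0.9 instead of 1.0
    fake_loss = bce(tf.zeros_like(d_fake), d_fake)           # fake labels stay at 0.0
    return real_loss + fake_loss
```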
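And a TTUR sketch; the 1e-4/4e-4 pair follows the SAGAN paper, but treat the exact values as assumptions to tune:

```
import tensorflow as tf

# TTUR: the discriminator gets the larger learning rate (lr_d > lr_g),
# so fewer discriminator steps are needed per generator step
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0., beta_2=0.9)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=4e-4, beta_1=0., beta_2=0.9)
```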
models, initializers, constraints, optimizers 11 | from tensorflow_addons.layers import InstanceNormalization 12 | import tensorflow.keras.backend as K 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | 16 | def l2_normalize(x, eps=1e-12): 17 | ''' 18 | Scale input by the inverse of its euclidean norm 19 | ''' 20 | return x / (tf.linalg.norm(x) + eps) 21 | 22 | class Spectral_Norm(constraints.Constraint): 23 | ''' 24 | Uses power iteration method to calculate a fast approximation 25 | of the spectral norm (Golub & Van der Vorst) 26 | The weights are then scaled by the inverse of the spectral norm 27 | ''' 28 | def __init__(self, power_iters=5): 29 | self.n_iters = power_iters 30 | 31 | def __call__(self, w): 32 | flattened_w = tf.reshape(w, [w.shape[0], -1]) 33 | u = tf.random.normal([flattened_w.shape[0]]) 34 | v = tf.random.normal([flattened_w.shape[1]]) 35 | for i in range(self.n_iters): 36 | v = tf.linalg.matvec(tf.transpose(flattened_w), u) 37 | v = l2_normalize(v) 38 | u = tf.linalg.matvec(flattened_w, v) 39 | u = l2_normalize(u) 40 | sigma = tf.tensordot(u, tf.linalg.matvec(flattened_w, v), axes=1) 41 | return w / sigma 42 | 43 | def get_config(self): 44 | return {'power_iters': self.n_iters} 45 | 46 | def make_discriminator_model(): 47 | model = tf.keras.Sequential() 48 | model.add(layers.Conv2D(64, kernel_size=(3,3), strides=(1,1), padding='same', 49 | kernel_initializer='glorot_uniform', 50 | kernel_constraint=Spectral_Norm())) 51 | model.add(layers.LeakyReLU(0.1)) 52 | model.add(layers.Conv2D(64, kernel_size=(4,4), strides=(2,2), padding='same', 53 | kernel_initializer='glorot_uniform', 54 | kernel_constraint=Spectral_Norm())) 55 | model.add(layers.LeakyReLU(0.1)) 56 | model.add(layers.Conv2D(128, kernel_size=(3,3), strides=(1,1), padding='same', 57 | kernel_initializer='glorot_uniform', 58 | kernel_constraint=Spectral_Norm())) 59 | model.add(layers.LeakyReLU(0.1)) 60 | model.add(layers.Conv2D(128, kernel_size=(4,4), strides=(2,2), padding='same', 61 | kernel_initializer='glorot_uniform', 62 | kernel_constraint=Spectral_Norm())) 63 | model.add(layers.LeakyReLU(0.1)) 64 | model.add(layers.Conv2D(256, kernel_size=(3,3), strides=(1,1), padding='same', 65 | kernel_initializer='glorot_uniform', 66 | kernel_constraint=Spectral_Norm())) 67 | model.add(layers.LeakyReLU(0.1)) 68 | model.add(layers.Conv2D(256, kernel_size=(4,4), strides=(2,2), padding='same', 69 | kernel_initializer='glorot_uniform', 70 | kernel_constraint=Spectral_Norm())) 71 | model.add(layers.LeakyReLU(0.1)) 72 | model.add(layers.Conv2D(512, kernel_size=(3,3), strides=(1,1), padding='same', 73 | kernel_initializer='glorot_uniform', 74 | kernel_constraint=Spectral_Norm())) 75 | model.add(layers.LeakyReLU(0.1)) 76 | model.add(layers.Flatten()) 77 | model.add(layers.Dense(1, kernel_constraint=Spectral_Norm())) 78 | return model 79 | 80 | class ResnetBlockGen(tf.keras.Model): 81 | def __init__(self, kernel_size, filters, pad='same'): 82 | super(ResnetBlockGen, self).__init__(name='') 83 | 84 | self.bn1 = tf.keras.layers.BatchNormalization() 85 | self.deconv2a = tf.keras.layers.Conv2DTranspose(filters, kernel_size, 86 | padding=pad) 87 | 88 | 89 | self.bn2 = tf.keras.layers.BatchNormalization() 90 | self.deconv2b = tf.keras.layers.Conv2DTranspose(filters, kernel_size, 91 | padding=pad) 92 | 93 | self.up_sample = tf.keras.layers.UpSampling2D(size=(2,2)) 94 | self.shortcut_conv = tf.keras.layers.Conv2DTranspose(filters, 95 | kernel_size=1, 96 | padding=pad) 97 | 98 | def call(self, input_tensor, training=False): 99
| x = self.bn1(input_tensor) 100 | x = tf.nn.relu(x) 101 | x = self.up_sample(x) 102 | x = self.deconv2a(x) 103 | 104 | x = self.bn2(x) 105 | x = tf.nn.relu(x) 106 | x = self.deconv2b(x) 107 | 108 | sc_x = self.up_sample(self.shortcut_conv(input_tensor)) 109 | return x + sc_x 110 | 111 | class ResnetBlockDisc(tf.keras.Model): 112 | def __init__(self, filters, kernel_size=3, downsample=False, pad='same'): 113 | super(ResnetBlockDisc, self).__init__(name='') 114 | self.conv1 = tf.keras.layers.Conv2D(filters, kernel_size, padding=pad, 115 | kernel_initializer='glorot_uniform', 116 | kernel_constraint=Spectral_Norm()) 117 | self.conv2 = tf.keras.layers.Conv2D(filters, kernel_size, padding=pad, 118 | kernel_initializer='glorot_uniform', 119 | kernel_constraint=Spectral_Norm()) 120 | self.shortcut_conv = tf.keras.layers.Conv2D(filters, kernel_size=(1,1), 121 | kernel_initializer='glorot_uniform', 122 | padding=pad, 123 | kernel_constraint=Spectral_Norm()) 124 | self.downsample_layer = tf.keras.layers.AvgPool2D((2,2)) 125 | self.downsample = downsample 126 | 127 | def residual(self, x): 128 | h = x 129 | h = tf.nn.relu(h) 130 | h = self.conv1(h) 131 | h = tf.nn.relu(h) 132 | h = self.conv2(h) 133 | if self.downsample: 134 | h = self.downsample_layer(h) 135 | return h 136 | 137 | def shortcut(self, x): 138 | h2 = x 139 | if self.downsample: 140 | h2 = self.downsample_layer(x) 141 | return self.shortcut_conv(h2) 142 | 143 | def call(self, x): 144 | return self.residual(x) + self.shortcut(x) 145 | 146 | class OptimizedBlock(tf.keras.Model): 147 | def __init__(self, filters, kernel_size=3, pad='same'): 148 | super(OptimizedBlock, self).__init__(name='') 149 | self.conv1 = tf.keras.layers.Conv2D(filters, kernel_size, padding=pad, 150 | kernel_initializer='glorot_uniform', 151 | kernel_constraint=Spectral_Norm()) 152 | self.conv2 = tf.keras.layers.Conv2D(filters, kernel_size, padding=pad, 153 | kernel_initializer='glorot_uniform', 154 | kernel_constraint=Spectral_Norm()) 155 | self.shortcut_conv = tf.keras.layers.Conv2D(filters, kernel_size=(1,1), 156 | kernel_initializer='glorot_uniform', 157 | padding=pad, 158 | kernel_constraint=Spectral_Norm()) 159 | self.downsample_layer = tf.keras.layers.AvgPool2D((2,2)) 160 | 161 | def residual(self, x): 162 | h = x 163 | h = self.conv1(h) 164 | h = tf.nn.relu(h) 165 | h = self.conv2(h) 166 | h = self.downsample_layer(h) 167 | return h 168 | 169 | def shortcut(self, x): 170 | return self.shortcut_conv(self.downsample_layer(x)) 171 | 172 | def call(self, x): 173 | return self.residual(x) + self.shortcut(x) 174 | 175 | class SelfAttentionBlock(tf.keras.Model): 176 | def __init__(self): 177 | super(SelfAttentionBlock, self).__init__() 178 | self.sigma = K.variable(0.0, name='sigma') 179 | self.phi_pool = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2) 180 | self.g_pool = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2) 181 | 182 | def build(self, inp): 183 | batch_size, h, w, n_channels = inp 184 | self.batch_size = batch_size 185 | self.n_channels = n_channels 186 | self.h = h 187 | self.w = w 188 | 189 | self.location_num = h*w 190 | self.downsampled_num = self.location_num // 4 191 | 192 | self.theta = tf.keras.layers.Conv2D(filters = n_channels // 8, 193 | kernel_size=[1, 1], 194 | padding= 'same', 195 | strides=(1,1)) 196 | self.phi = tf.keras.layers.Conv2D(filters = n_channels // 8, 197 | kernel_size=[1, 1], 198 | padding='same', 199 | strides=(1,1)) 200 | self.attn_conv = tf.keras.layers.Conv2D(self.n_channels, kernel_size=[1, 1]) 201 | 202 | 
self.g = tf.keras.layers.Conv2D(filters = n_channels // 2, 203 | kernel_size=[1, 1]) 204 | 205 | def call(self, x): 206 | theta = self.theta(x) 207 | theta = tf.reshape(theta, [self.batch_size, self.location_num, 208 | self.n_channels // 8]) 209 | 210 | phi = self.phi(x) 211 | phi = self.phi_pool(phi) 212 | phi = tf.reshape(phi, [self.batch_size, self.downsampled_num, 213 | self.n_channels // 8]) 214 | 215 | 216 | attn = tf.matmul(theta, phi, transpose_b=True) 217 | attn = tf.nn.softmax(attn) 218 | 219 | # g path 220 | g = self.g(x) 221 | g = self.g_pool(g) 222 | g = tf.reshape( 223 | g, [self.batch_size, self.downsampled_num, self.n_channels // 2]) 224 | 225 | attn_g = tf.matmul(attn, g) 226 | attn_g = tf.reshape(attn_g, [self.batch_size, self.h, self.w, 227 | self.n_channels // 2]) 228 | attn_g = self.attn_conv(attn_g) 229 | return x + (attn_g * self.sigma) 230 | 231 | def make_resnet_generator_model(): 232 | model = tf.keras.Sequential() 233 | model.add(tf.keras.layers.Dense(16*256, kernel_initializer='glorot_uniform')) 234 | model.add(tf.keras.layers.Reshape((4, 4, 256))) 235 | model.add(ResnetBlockGen(3, 256)) 236 | model.add(ResnetBlockGen(3, 256, pad='same')) 237 | model.add(SelfAttentionBlock()) 238 | model.add(ResnetBlockGen(3, 256, pad='same')) 239 | model.add(tf.keras.layers.BatchNormalization()) 240 | model.add(tf.keras.layers.Conv2D(filters=3, kernel_size=3, strides=(1,1), 241 | padding='same', activation='tanh')) 242 | return model 243 | 244 | def make_resnet_discriminator_model(): 245 | model = tf.keras.Sequential() 246 | model.add(OptimizedBlock(128)) 247 | model.add(ResnetBlockDisc(128, downsample=True)) 248 | model.add(ResnetBlockDisc(128, downsample=False)) 249 | model.add(ResnetBlockDisc(128, downsample=False)) 250 | model.add(tf.keras.layers.ReLU()) 251 | model.add(tf.keras.layers.GlobalAvgPool2D(data_format='channels_last' )) 252 | model.add(tf.keras.layers.Dense(1, kernel_initializer='glorot_uniform', 253 | kernel_constraint=Spectral_Norm())) 254 | return model 255 | 256 | class ConditionalBatchNorm(layers.Layer): 257 | def __init__(self, num_categories, decay_rate=0.999, 258 | center=True, scale=True): 259 | super(ConditionalBatchNorm, self).__init__() 260 | self.num_categories = num_categories 261 | self.center = center 262 | self.scale = scale 263 | self.decay_rate = decay_rate 264 | 265 | def build(self, input_size): 266 | self.inputs_shape = tf.TensorShape(input_size) 267 | params_shape = self.inputs_shape[-1:] 268 | axis = [0, 1, 2] 269 | shape = tf.TensorShape([self.num_categories]).concatenate(params_shape) 270 | moving_shape = tf.TensorShape([1,1,1]).concatenate(params_shape) 271 | 272 | self.gamma = self.add_variable(name='gamma', shape=shape, 273 | initializer='ones') 274 | self.beta = self.add_variable(name='beta', shape=shape, 275 | initializer='zeros') 276 | 277 | self.moving_mean = self.add_variable(name='mean', 278 | shape=moving_shape, 279 | initializer='zeros', 280 | trainable=False) 281 | self.moving_var = self.add_variable(name='var', 282 | shape=moving_shape, 283 | initializer='ones', 284 | trainable=False) 285 | 286 | 287 | def call(self, inputs, labels, is_training=True): 288 | beta = tf.gather(self.beta, labels) 289 | beta = tf.expand_dims(tf.expand_dims(beta, 1), 1) 290 | gamma = tf.gather(self.gamma, labels) 291 | gamma = tf.expand_dims(tf.expand_dims(gamma, 1), 1) 292 | decay = self.decay_rate 293 | variance_epsilon = 1e-5 294 | axis = [0, 1, 2] 295 | if is_training: 296 | mean, variance = tf.nn.moments(inputs, axis, keepdims=True) 297 | 
self.moving_mean.assign(self.moving_mean * decay + mean * (1 - decay)) 298 | self.moving_var.assign(self.moving_var * decay + variance * (1 - decay)) 299 | outputs = tf.nn.batch_normalization( 300 | inputs, mean, variance, beta, gamma, variance_epsilon) 301 | else: 302 | outputs = tf.nn.batch_normalization( 303 | inputs, self.moving_mean, self.moving_var, 304 | beta, gamma, variance_epsilon) 305 | outputs.set_shape(self.inputs_shape) 306 | return outputs 307 | 308 | class Generator_CBN(tf.keras.Model): 309 | def __init__(self, n_classes, training=True): 310 | super(Generator_CBN, self).__init__() 311 | self.linear1 = layers.Dense(4*4*512, use_bias=False, 312 | kernel_initializer='glorot_uniform') 313 | self.reshape = layers.Reshape([4, 4, 512]) 314 | 315 | self.cbn1 = ConditionalBatchNorm(n_classes) 316 | self.deconv1 = layers.Conv2DTranspose(filters=256, kernel_size=(4,4), 317 | strides=2, padding='same') 318 | 319 | self.cbn2 = ConditionalBatchNorm(n_classes) 320 | self.deconv2 = layers.Conv2DTranspose(filters=128, kernel_size=(4,4), 321 | strides=2, padding='same') 322 | 323 | self.cbn3 = ConditionalBatchNorm(n_classes) 324 | self.deconv3 = layers.Conv2DTranspose(filters=64, kernel_size=(4,4), 325 | strides=2, padding='same') 326 | 327 | self.cbn4 = ConditionalBatchNorm(n_classes) 328 | self.deconv4 = layers.Conv2DTranspose(filters=3, kernel_size=(3,3), 329 | strides=1, use_bias=True, 330 | padding='same') 331 | 332 | 333 | def call(self, inp, labels): 334 | x = self.linear1(inp) 335 | x = self.reshape(x) 336 | 337 | x = self.cbn1(x, labels) 338 | x = tf.nn.relu(x) 339 | x = self.deconv1(x) 340 | 341 | x = self.cbn2(x, labels) 342 | x = tf.nn.relu(x) 343 | x = self.deconv2(x) 344 | 345 | x = self.cbn3(x, labels) 346 | x = tf.nn.relu(x) 347 | x = self.deconv3(x) 348 | 349 | x = self.cbn4(x, labels) 350 | x = tf.nn.relu(x) 351 | x = self.deconv4(x) 352 | 353 | x = tf.nn.tanh(x) 354 | return x 355 | 356 | def make_cbn_generator_model(): 357 | model = Generator_CBN(n_classes=10) 358 | return model 359 | 360 | class SelfModulationBN(layers.Layer): 361 | def __init__(self, z_size=128, decay_rate=0.999, 362 | center=True, scale=True): 363 | super(SelfModulationBN, self).__init__() 364 | self.z_size = z_size 365 | self.center = center 366 | self.scale = scale 367 | self.decay_rate = decay_rate 368 | 369 | 370 | def build(self, input_size): 371 | self.inputs_shape = tf.TensorShape(input_size) 372 | params_shape = self.inputs_shape[-1:] 373 | batch_shape = self.inputs_shape[0] 374 | z_shape = tf.TensorShape(self.z_size) 375 | mlp_shape = z_shape.concatenate(params_shape) 376 | moving_shape = tf.TensorShape([batch_shape,1,1]).concatenate(params_shape) 377 | self.moving_mean = self.add_variable(name='mean', 378 | shape=moving_shape, 379 | initializer='zeros', 380 | trainable=False) 381 | self.moving_var = self.add_variable(name='var', 382 | shape=moving_shape, 383 | initializer='ones', 384 | trainable=False) 385 | self.ln_gamma = self.add_variable(name='gamma', shape=mlp_shape, 386 | initializer='glorot_normal') 387 | self.ln_beta = self.add_variable(name='beta', shape=mlp_shape, 388 | initializer='glorot_normal') 389 | 390 | def call(self, inputs, z, is_training=True): 391 | z = tf.squeeze(z, axis=[1,2]) 392 | gamma = tf.matmul(z, self.ln_gamma) 393 | gamma = tf.nn.relu(gamma) 394 | beta = tf.matmul(z, self.ln_beta) 395 | beta = tf.nn.relu(beta) 396 | gamma = tf.expand_dims(tf.expand_dims(gamma, axis=1), axis=1) 397 | beta = tf.expand_dims(tf.expand_dims(beta, axis=1), axis=1) 398 | decay = 
self.decay_rate 399 | variance_epsilon = 1e-5 400 | axis = [0, 1, 2] 401 | mean, variance = tf.nn.moments(inputs, axis, keepdims=True) 402 | self.moving_mean.assign(self.moving_mean * decay + mean * (1 - decay)) 403 | self.moving_var.assign(self.moving_var * decay + variance * (1 - decay)) 404 | outputs = tf.nn.batch_normalization( 405 | inputs, mean, variance, beta, gamma, variance_epsilon) 406 | outputs.set_shape(self.inputs_shape) 407 | return outputs 408 | 409 | class Generator_SBN(tf.keras.Model): 410 | def __init__(self, z_shape, training=True): 411 | super(Generator_SBN, self).__init__() 412 | self.linear1 = layers.Dense(4*4*512, use_bias=False, 413 | kernel_initializer='glorot_uniform') 414 | self.reshape = layers.Reshape([4, 4, 512]) 415 | 416 | self.sbn1 = SelfModulationBN(z_shape) 417 | self.deconv1 = layers.Conv2DTranspose(filters=256, kernel_size=(4,4), 418 | strides=2, padding='same') 419 | 420 | self.sbn2 = SelfModulationBN(z_shape) 421 | self.deconv2 = layers.Conv2DTranspose(filters=128, kernel_size=(4,4), 422 | strides=2, padding='same') 423 | 424 | self.sbn3 = SelfModulationBN(z_shape) 425 | self.deconv3 = layers.Conv2DTranspose(filters=64, kernel_size=(4,4), 426 | strides=2, padding='same') 427 | 428 | self.sbn4 = SelfModulationBN(z_shape) 429 | self.deconv4 = layers.Conv2DTranspose(filters=3, kernel_size=(3,3), 430 | strides=1, use_bias=True, 431 | padding='same') 432 | 433 | def call(self, inp): 434 | x = self.linear1(inp) 435 | x = self.reshape(x) 436 | 437 | x = self.sbn1(x, inp) 438 | x = tf.nn.relu(x) 439 | x = self.deconv1(x) 440 | 441 | x = self.sbn2(x, inp) 442 | x = tf.nn.relu(x) 443 | x = self.deconv2(x) 444 | 445 | x = self.sbn3(x, inp) 446 | x = tf.nn.relu(x) 447 | x = self.deconv3(x) 448 | 449 | x = self.sbn4(x, inp) 450 | x = tf.nn.relu(x) 451 | x = self.deconv4(x) 452 | 453 | x = tf.nn.tanh(x) 454 | return x 455 | 456 | def make_sbn_generator_model(): 457 | model = Generator_SBN(z_shape=128) 458 | return model 459 | 460 | def discriminator_loss(real_output, fake_output): 461 | L1 = K.mean(K.softplus(-real_output)) 462 | L2 = K.mean(K.softplus(fake_output)) 463 | loss = L1 + L2 464 | return loss 465 | 466 | def discriminator_hinge_loss(real_output, fake_output): 467 | loss = K.mean(K.relu(1. - real_output)) 468 | loss += K.mean(K.relu(1. 
+ fake_output)) 469 | return loss 470 | 471 | def generator_loss(fake_output): 472 | return K.mean(K.softplus(-fake_output)) 473 | 474 | def generator_hinge_loss(fake_output): 475 | return -1 * K.mean(fake_output) 476 | 477 | EPOCHS = 50 478 | BATCH_SIZE = 32 479 | noise_dim = 128 480 | num_examples_to_generate = 8 481 | seed = tf.random.normal([BATCH_SIZE, 1, 1, noise_dim]) 482 | 483 | def generate_and_save_images(model, epoch, test_input, bn_type): 484 | if bn_type == "cbn": 485 | label = tf.convert_to_tensor(np.random.randint(0, 10, BATCH_SIZE))  # one random class label per latent in test_input 486 | label = tf.squeeze(label) 487 | predictions = model(test_input, label) 488 | elif bn_type == "sbn": 489 | predictions = model(test_input) 490 | else: 491 | predictions = model(test_input, training=False) 492 | 493 | fig = plt.figure(figsize=(4,4)) 494 | 495 | for i in range(num_examples_to_generate): 496 | plt.subplot(4, 4, i+1) 497 | plt.imshow((predictions[i]+1)/2) 498 | plt.axis('off') 499 | 500 | plt.savefig('image_at_epoch_{:04d}.png'.format(epoch)) 501 | 502 | @tf.function 503 | def train_step(images, labels=None, sbn=False, disc_steps=1): 504 | for _ in range(disc_steps): 505 | with tf.GradientTape() as disc_tape: 506 | 507 | noise = tf.random.normal([BATCH_SIZE, 1, 1, noise_dim]) 508 | if labels is not None: 509 | generated_images = generator(noise, labels) 510 | else: 511 | if sbn: 512 | generated_images = generator(noise) 513 | else: 514 | generated_images = generator(noise, training=True) 515 | real_output = discriminator(images, training=True) 516 | fake_output = discriminator(generated_images, training=True) 517 | disc_loss = discriminator_hinge_loss(real_output, fake_output) 518 | gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables) 519 | discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)) 520 | 521 | with tf.GradientTape() as gen_tape: 522 | noise = tf.random.normal([BATCH_SIZE, 1, 1, noise_dim]) 523 | if labels is not None: 524 | generated_images = generator(noise, labels) 525 | else: 526 | if sbn: 527 | generated_images = generator(noise) 528 | else: 529 | generated_images = generator(noise, training=True) 530 | fake_output = discriminator(generated_images, training=False) 531 | gen_loss = generator_hinge_loss(fake_output) 532 | 533 | gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) 534 | generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)) 535 | singular_values = tf.linalg.svd(discriminator.trainable_variables[-2])[0]  # singular values of the final dense kernel 536 | condition_number = tf.reduce_max(singular_values)  # i.e. the spectral norm of that kernel 537 | train_stats = {'d_loss': disc_loss, "g_loss": gen_loss, 538 | 'd_grads': gradients_of_discriminator, 'g_grads': gradients_of_generator, 539 | 'cond_number': condition_number} 540 | return train_stats 541 | 542 | def train(dataset, epochs, bn_type=None): 543 | current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 544 | train_log_dir = 'logs/gradient_tape/' + current_time + '/train' 545 | train_summary_writer = tf.summary.create_file_writer(train_log_dir) 546 | 547 | checkpoint_dir = './training_checkpoints' 548 | checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") 549 | checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, 550 | discriminator_optimizer=discriminator_optimizer, 551 | generator=generator, 552 | discriminator=discriminator) 553 | 554 | for epoch in range(epochs): 555 | start = time.time() 556 | 557 | for batch in dataset: 558 | if bn_type == "cbn": 559 |
image_batch, label_batch = batch 560 | label_batch = tf.squeeze(label_batch) 561 | train_diction = train_step(image_batch, label_batch) 562 | elif bn_type == "sbn": 563 | image_batch = batch 564 | train_diction = train_step(image_batch, sbn=True) 565 | else: 566 | image_batch = batch 567 | train_diction = train_step(image_batch) 568 | 569 | 570 | with train_summary_writer.as_default(): 571 | tf.summary.scalar('disc_total_loss', train_diction['d_loss'], step=epoch) 572 | tf.summary.scalar('gen_total_loss', train_diction['g_loss'], step=epoch) 573 | tf.summary.scalar("max disc_grads[-1]: ", np.max(train_diction['d_grads'][-1].numpy()), step=epoch) 574 | tf.summary.scalar("max disc_grads[0]: ", np.max(train_diction['d_grads'][0].numpy()), step=epoch) 575 | tf.summary.scalar('gen_grads', np.max(train_diction['g_grads'][-1].numpy()), step=epoch) 576 | tf.summary.scalar('condition_number', train_diction['cond_number'], step=epoch) 577 | 578 | generate_and_save_images(generator, 579 | epoch + 1, 580 | seed, bn_type) 581 | 582 | # Save the model every 5 epochs 583 | if (epoch + 1) % 5 == 0: 584 | checkpoint.save(file_prefix = checkpoint_prefix) 585 | 586 | print('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start)) 587 | # Generate after the final epoch 588 | generate_and_save_images(generator, 589 | epochs, 590 | seed, bn_type) 591 | 592 | def SAGAN_train(): 593 | global generator, discriminator, generator_optimizer, discriminator_optimizer 594 | 595 | discriminator = make_resnet_discriminator_model() 596 | generator = make_resnet_generator_model() 597 | 598 | lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( 599 | 2e-4, decay_rate=0.99, decay_steps=50000*EPOCHS) 600 | 601 | generator_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, beta_1=0, beta_2=0.9) 602 | discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, beta_1=0, beta_2=0.9) 603 | 604 | train_dataset = load_cifar10_dataset() 605 | train(train_dataset, EPOCHS) 606 | # CBN variant: generator = make_cbn_generator_model(); train_dataset = load_cifar10_dataset(with_labels=True); train(train_dataset, EPOCHS, bn_type='cbn') 607 | # SBN variant: generator = make_sbn_generator_model(); train(train_dataset, EPOCHS, bn_type='sbn') 608 | # both variants keep make_resnet_discriminator_model and the hinge loss 609 | 610 | def load_cifar10_dataset(with_labels=False): 611 | (train_images, train_labels), (_, _) = tf.keras.datasets.cifar10.load_data() 612 | train_images = train_images.reshape(train_images.shape[0], 32, 32, 3).astype('float32') 613 | train_images = (train_images/255) * 2 - 1 614 | train_labels = train_labels.astype('int32') 615 | 616 | BUFFER_SIZE = 50000 617 | AUTOTUNE = tf.data.experimental.AUTOTUNE 618 | 619 | if with_labels:  # labeled (image, label) pairs are only needed for the CBN generator 620 | image_dataset = tf.data.Dataset.from_tensor_slices(train_images) 621 | label_dataset = tf.data.Dataset.from_tensor_slices(train_labels) 622 | train_dataset = tf.data.Dataset.zip((image_dataset, label_dataset)) 623 | train_dataset = train_dataset.shuffle(buffer_size=BUFFER_SIZE) 624 | train_dataset = train_dataset.batch(BATCH_SIZE, drop_remainder=True) 625 | train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE) 626 | else: 627 | train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) 628 | train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE) 629 | 630 | return train_dataset 631 | 632 | if __name__ == "__main__": 633 | SAGAN_train() -------------------------------------------------------------------------------- /Summary-of-Different-GAN-Loss-Functions.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrForExample/Generative_Models_Collection/9ec520636a1fac0b832326403f6440c9e78eec30/Summary-of-Different-GAN-Loss-Functions.png --------------------------------------------------------------------------------