├── README.md ├── losses.py ├── LICENSE ├── training_walkthrough.ipynb ├── config.py ├── utils.py ├── net.py └── model.py /README.md: -------------------------------------------------------------------------------- 1 | # progressive_growing_of_GANs 2 | Pure TensorFlow implementation of progressive growing of GANs [https://arxiv.org/abs/1710.10196] 3 | 4 | Includes: 5 | * Progressive growing of the network 6 | * Minibatch standard deviation 7 | * Pixelwise feature normalization 8 | * Equalized learning rate 9 | * Drift loss 10 | 11 | Tested on a private dataset. 12 | 13 | See `training_walkthrough.ipynb` for how to train on your own dataset. 14 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | # Reference : https://github.com/igul222/improved_wgan_training/blob/master/gan_cifar.py 3 | 4 | 5 | def js_loss(logits_real, logits_fake, smooth_factor=0.9): 6 | # discriminator loss for real/fake classification (one-sided label smoothing on real labels) 7 | d_loss_real = tf.reduce_mean( 8 | tf.nn.sigmoid_cross_entropy_with_logits( 9 | logits=logits_real, labels=tf.ones_like(logits_real) * smooth_factor)) 10 | d_loss_fake = tf.reduce_mean( 11 | tf.nn.sigmoid_cross_entropy_with_logits( 12 | logits=logits_fake, labels=tf.zeros_like(logits_fake))) 13 | d_loss = d_loss_real + d_loss_fake 14 | 15 | # generator loss for fooling the discriminator 16 | g_loss = tf.reduce_mean( 17 | tf.nn.sigmoid_cross_entropy_with_logits( 18 | logits=logits_fake, labels=tf.ones_like(logits_fake))) 19 | return d_loss, g_loss 20 | 21 | 22 | def wgan_loss(d_real, d_fake): 23 | # Standard WGAN loss 24 | g_loss = -tf.reduce_mean(d_fake) 25 | d_loss = tf.reduce_mean(d_fake) - tf.reduce_mean(d_real) 26 | return d_loss, g_loss 27 | 28 | 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Prerit Jaiswal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /training_walkthrough.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from __future__ import print_function\n", 10 | "from __future__ import division\n", 11 | "\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import numpy as np\n", 14 | "import math\n", 15 | "\n", 16 | "from config import cfg\n", 17 | "from net import DCGAN\n", 18 | "import tensorflow as tf" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "### Set parameters\n", 26 | "All the parameters can be found in `config.py`" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "cfg.beta1 = 0.\n", 36 | "cfg.beta2 = 0.99\n", 37 | "cfg.batch_size = 16\n", 38 | "cfg.save_period = 1000\n", 39 | "cfg.display_period = 100\n", 40 | "cfg.n_iters = 20000\n", 41 | "cfg.n_critic = 1\n", 42 | "cfg.learning_rate = 0.001\n", 43 | "cfg.norm_g = 'pixel_norm'\n", 44 | "cfg.norm_d = None\n", 45 | "cfg.weight_scale = True\n", 46 | "cfg.drift_loss = True\n", 47 | "cfg.loss_mode = 'wgan_gp'\n", 48 | "cfg.use_tanh = True\n", 49 | "cfg.save_images = True" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "### Step-1 : Train at 8x8 resolution\n", 57 | "The paper starts at 4x4 resolution, but the choice of starting resolution doesn't matter." 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "cfg.resolution = 8\n", 67 | "cfg.transition = False\n", 68 | "cfg.load_model = None\n", 69 | "\n", 70 | "model = DCGAN(cfg)\n", 71 | "model.train()" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "### Step-2 : Transition from 8x8 to 16x16 resolution (fading)\n", 79 | "Note that this step automatically loads the model trained at 8x8 resolution (`models/8x8`), so you can leave `cfg.load_model = None`." 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "cfg.resolution = 16\n", 89 | "cfg.transition = True\n", 90 | "cfg.load_model = None\n", 91 | "\n", 92 | "model = DCGAN(cfg)\n", 93 | "model.train()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "### Step-3 : Train at 16x16 resolution\n", 101 | "Now you must load the model from the previous step to continue. It is saved in `models/16x16_transition`, so we set `cfg.load_model = '16x16_transition'`." 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "cfg.resolution = 16\n", 111 | "cfg.transition = False\n", 112 | "cfg.load_model = '16x16_transition'\n", 113 | "\n", 114 | "model = DCGAN(cfg)\n", 115 | "model.train()" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "### Step-4 : Transition from 16x16 to 32x32 resolution (fading)\n", 123 | "Again, this step automatically loads the model from `models/16x16`."
124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "cfg.resolution = 32\n", 133 | "cfg.transition = True\n", 134 | "cfg.load_model = None\n", 135 | "\n", 136 | "model = DCGAN(cfg)\n", 137 | "model.train()" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "### Step-5 : Train at 32x32 resolution\n", 145 | "Now we load the model from the previous step, saved in `models/32x32_transition`." 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "cfg.resolution = 32\n", 155 | "cfg.transition = False\n", 156 | "cfg.load_model = '32x32_transition'\n", 157 | "\n", 158 | "model = DCGAN(cfg)\n", 159 | "model.train()" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "The pattern should be clear by now: alternate transition (fading) and stabilization steps to continue to higher resolutions." 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [] 175 | } 176 | ], 177 | "metadata": { 178 | "kernelspec": { 179 | "display_name": "Python 2", 180 | "language": "python", 181 | "name": "python2" 182 | }, 183 | "language_info": { 184 | "codemirror_mode": { 185 | "name": "ipython", 186 | "version": 2 187 | }, 188 | "file_extension": ".py", 189 | "mimetype": "text/x-python", 190 | "name": "python", 191 | "nbconvert_exporter": "python", 192 | "pygments_lexer": "ipython2", 193 | "version": "2.7.12" 194 | } 195 | }, 196 | "nbformat": 4, 197 | "nbformat_minor": 2 198 | } 199 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | 4 | # Training/validation parameters 5 | TRAIN = True # set to True for training + validation and False for testing 6 | N_ITERS = 5000 # number of iterations to train 7 | BATCH_SIZE = 64 # batch size for training 8 | DISPLAY_PERIOD = 5 # Interval to display loss 9 | SAVE_PERIOD = 1 # Interval to save model 10 | SAVE_IMAGES = True # if True, save images when model is saved 11 | SAVE_DIR = r'D:\Data\pgan\models' # path to save trained models 12 | IMAGE_SAVE_DIR = r'D:\Data\pgan\gan_images' # path to save GAN images 13 | SUMMARY_DIR = r'D:\Data\pgan\train' 14 | 15 | # Data Preprocessing and augmentation parameters 16 | INPUT_SHAPE = (512, 512, 3) # size of random crops used for training 17 | FLIP = True # applies random horizontal and vertical flips 18 | ROTATE = True # applies random rotations 19 | PREPROCESS = 'min-max' # can be 'min-max' or 'standard' 20 | 21 | # Adam optimizer parameters: 22 | LEARNING_RATE = 0.001 23 | BETA1 = 0. 24 | BETA2 = 0.99 25 | 26 | # Other parameters 27 | LEAKY_RELU_ALPHA = 0.2 # alpha in leaky relu 28 | SMOOTH_LABEL = True # uses 0.9 instead of 1 for positive labels 29 | NOISE_STDDEV = 0.01 # standard deviation for noise added to images 30 | Z_DIM = 128 # dim of latent space 31 | # LOSS_MODE can be 'js' or 'wgan_gp' 32 | # 'js' : Jensen-Shannon loss as in the original GAN paper 33 | # 'wgan_gp' : Wasserstein GAN loss with gradient penalty (Gulrajani et al.) 34 | LOSS_MODE = 'wgan_gp' 35 | N_CRITIC = 3 # number of times to train disc for every gen train step 36 | LAMBDA_GP = 10. # Gradient penalty lambda hyperparameter 37 | GAMMA_GP = 1. 
# Gradient penalty gamma hyperparameter 38 | MINIBATCH_STDDEV = True # include minibatch std deviation as feature 39 | NORM_D = None # options are None, pixel_norm, batch_norm and layer_norm 40 | NORM_G = 'pixel_norm' # options are None, pixel_norm, batch_norm and layer_norm 41 | USE_TANH = False # use tanh in the final layer of generator 42 | WEIGHT_SCALE = True # use weight scaling for equalized learning rate 43 | DRIFT_LOSS = True # add a drift loss term 44 | EPS_DRIFT = 0.001 # epsilon for drift loss term 45 | FADE_ALPHA = 0. # starting alpha to use for transition 46 | 47 | RESOLUTION = 16 # current output resolution of the network 48 | MIN_RESOLUTION = 4 # min spatial resolution of features 49 | NF_MIN = 32 # min depth of features 50 | NF_MAX = 512 # max depth of features 51 | 52 | TRANSITION = False # whether to train in transition mode 53 | LOAD_MODEL = None # if not None, specify load sub-dir e.g. '16x16_transition' 54 | 55 | 56 | ############################################################### 57 | # dataset specific parameters 58 | ############################################################### 59 | 60 | # RGB mean and standard deviation 61 | IMAGE_MEAN = [184.02, 157.45, 215.96] 62 | IMAGE_STDDEV = [42.37, 48.23, 29.98] 63 | 64 | # class encodings 65 | CLASSES = {'Normal': 0, 66 | 'Benign': 1, 67 | 'InSitu': 2, 68 | 'Invasive': 3} 69 | 70 | # abbreviations 71 | CLASS_ABR = {'Normal': 'n', 72 | 'Benign': 'b', 73 | 'InSitu': 'is', 74 | 'Invasive': 'iv'} 75 | 76 | # validation files 77 | VALIDATION_SET = {'Normal': ([i for i in range(46, 52)] 78 | + [i for i in range(61, 69)]), 79 | 'Benign': [i for i in range(45, 59)], 80 | 'InSitu': [i for i in range(40, 54)], 81 | 'Invasive': ([i for i in range(50, 54)] 82 | + [i for i in range(64, 74)])} 83 | 84 | # Data directory 85 | DATA_DIR = r'C:\Data\img' 86 | 87 | 88 | ################################################################# 89 | cfg = edict({'data_dir': DATA_DIR, 90 | 'summary_dir': SUMMARY_DIR, 91 | 'image_mean': IMAGE_MEAN, 92 | 'image_stddev': IMAGE_STDDEV, 93 | 'preprocess': PREPROCESS, 94 | 'classes': CLASSES, 95 | 'class_abr': CLASS_ABR, 96 | 'validation_set': VALIDATION_SET, 97 | 'train': TRAIN, 98 | 'input_shape': INPUT_SHAPE, 99 | 'flip': FLIP, 100 | 'rotate': ROTATE, 101 | 'smooth_label': SMOOTH_LABEL, 102 | 'noise_stddev': NOISE_STDDEV, 103 | 'z_dim': Z_DIM, 104 | 'loss_mode': LOSS_MODE, 105 | 'lambda_gp': LAMBDA_GP, 106 | 'gamma_gp': GAMMA_GP, 107 | 'n_iters': N_ITERS, 108 | 'batch_size': BATCH_SIZE, 109 | 'leakyRelu_alpha': LEAKY_RELU_ALPHA, 110 | 'learning_rate': LEARNING_RATE, 111 | 'beta1': BETA1, 112 | 'beta2': BETA2, 113 | 'norm_d': NORM_D, 114 | 'norm_g': NORM_G, 115 | 'weight_scale': WEIGHT_SCALE, 116 | 'drift_loss': DRIFT_LOSS, 117 | 'eps_drift': EPS_DRIFT, 118 | 'n_critic': N_CRITIC, 119 | 'use_tanh': USE_TANH, 120 | 'fade_alpha': FADE_ALPHA, 121 | 'resolution': RESOLUTION, 122 | 'min_resolution': MIN_RESOLUTION, 123 | 'nf_min': NF_MIN, 124 | 'nf_max': NF_MAX, 125 | 'transition': TRANSITION, 126 | 'load_model': LOAD_MODEL, 127 | 'minibatch_stddev': MINIBATCH_STDDEV, 128 | 'display_period': DISPLAY_PERIOD, 129 | 'save_images': SAVE_IMAGES, 130 | 'model_save_dir': SAVE_DIR, 131 | 'image_save_dir': IMAGE_SAVE_DIR, 132 | 'save_period': SAVE_PERIOD}) 133 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | import tensorflow as tf
5 | import os 6 | from glob import glob 7 | import numpy as np 8 | import math 9 | import cv2 10 | import matplotlib.pyplot as plt 11 | from tqdm import tqdm 12 | 13 | 14 | class ImageLoader(object): 15 | def __init__(self, cfg): 16 | self.cfg = cfg 17 | imgs = glob(cfg.data_dir + "/*.jpg") + \ 18 | glob(cfg.data_dir + "/*.png") + \ 19 | glob(cfg.data_dir + "/*.jpeg") + \ 20 | glob(cfg.data_dir + "/*.bmp") 21 | 22 | self.images = np.array(imgs) 23 | self.train_idx, self.val_idx = None, None 24 | self.train_test_split() 25 | if self.cfg.preprocess == 'min-max': 26 | self.img_mean = self.img_stddev = 127.5 27 | else: 28 | self.img_mean = self.cfg.image_mean 29 | self.img_stddev = self.cfg.image_stddev 30 | 31 | def train_test_split(self): 32 | # build validation set 33 | 34 | val_idx = range(0, len(self.images), 10) 35 | train_idx = [i for i, _ in enumerate(self.images) if i not in val_idx] 36 | self.train_idx = np.array(train_idx) 37 | self.val_idx = np.array(val_idx) 38 | print("Size of training set : ", self.train_idx.size) 39 | print("Size of validation set : ", self.val_idx.size) 40 | 41 | def preprocess_image(self, img): 42 | image = np.copy(img) 43 | if self.cfg.train: 44 | new_img = self.random_crop(image) 45 | if self.cfg.flip: 46 | new_img = self.random_flip(new_img) 47 | if self.cfg.rotate: 48 | new_img = self.random_rotate(new_img) 49 | return (new_img - self.img_mean) / self.img_stddev 50 | else: 51 | # Pick predefined crops in testing mode 52 | new_images = self.test_crop(image) 53 | return (new_images - self.img_mean) / self.img_stddev 54 | 55 | def postprocess_image(self, imgs): 56 | new_imgs = imgs * self.img_stddev + self.img_mean 57 | new_imgs[new_imgs < 0] = 0 58 | new_imgs[new_imgs > 255] = 255 59 | return new_imgs 60 | 61 | def random_crop(self, img): 62 | """ 63 | Applies random crops. 
64 | Final image size given by self.cfg.input_shape 65 | """ 66 | img_h, img_w, _ = img.shape 67 | new_h, new_w, _ = self.cfg.input_shape 68 | img = np.pad(img, [(0, max(0, new_h - img_h)), (0, max(0, new_w - img_w)), (0,0)], mode='mean') 69 | top = np.random.randint(0, max(0, img_h - new_h)+1) 70 | left = np.random.randint(0, max(0, img_w - new_w)+1) 71 | new_img = img[top:top + new_h, left:left + new_w, :] 72 | return new_img 73 | 74 | def random_flip(self, img): 75 | """Random horizontal and vertical flips""" 76 | new_img = np.copy(img) 77 | if np.random.uniform() > 0.5: 78 | new_img = cv2.flip(new_img, 1) 79 | if np.random.uniform() > 0.5: 80 | new_img = cv2.flip(new_img, 0) 81 | return new_img 82 | 83 | def random_rotate(self, img): 84 | """Random rotations by 0, 90, 180, or 270 degrees""" 85 | theta = np.random.choice([0, 90, 180, 270]) # 270 rather than 360: a full turn is a no-op 86 | if theta == 0: 87 | return img 88 | h, w, _ = img.shape 89 | mat = cv2.getRotationMatrix2D((w / 2, h / 2), theta, 1) 90 | return cv2.warpAffine(img, mat, (w, h)) 91 | 92 | def test_crop(self, img): # expects (y, x) crop offsets in self.cfg.test_crops, which is not defined in config.py 93 | new_images = [] 94 | h, w, _ = self.cfg.input_shape 95 | for y, x in self.cfg.test_crops: 96 | new_img = img[y:y + h, x:x + w, :] 97 | new_images.append(new_img) 98 | return np.array(new_images) 99 | 100 | def load_batch(self, idx): 101 | """Loads a batch of images 102 | Arguments: 103 | idx: List of indices 104 | Returns: 105 | images: array of preprocessed images corresponding to indices 106 | """ 107 | batch_imgs = [] 108 | for index in idx: 109 | img_file = self.images[index] 110 | img = plt.imread(img_file)[:,:,:3] # drop alpha channel; NB: plt.imread returns floats in [0, 1] for PNGs but uint8 for JPEGs 111 | img = self.preprocess_image(img) 112 | batch_imgs.append(img) 113 | return np.array(batch_imgs) 114 | 115 | def batch_generator(self): 116 | batch_size = self.cfg.batch_size 117 | for _ in range(self.cfg.n_iters): 118 | indices = np.random.randint(len(self.train_idx), size=batch_size) 119 | batch_idx = self.train_idx[indices] 120 | batch_imgs = self.load_batch(batch_idx) 121 | yield batch_imgs 122 | 123 | def create_batch_pipeline(self): 124 | images_names_tensor = tf.convert_to_tensor(self.images, dtype=tf.string) 125 | single_image_name, = tf.train.slice_input_producer([images_names_tensor], shuffle=True, capacity=128) 126 | single_image_content = tf.read_file(single_image_name) 127 | single_image = tf.image.decode_image(single_image_content, channels=3) 128 | single_image.set_shape([None, None, 3]) 129 | 130 | # Smart resize 131 | shp = tf.shape(single_image) 132 | r_size = shp[:2] 133 | dest_h = tf.random_uniform([1], 512, 1024, tf.int32) 134 | dest_h = tf.minimum(dest_h, r_size[0]) 135 | ratio = tf.to_float(dest_h) / tf.to_float(r_size[0]) 136 | n_size = tf.to_int32(tf.to_float(r_size) * ratio) 137 | single_image = tf.cast(tf.image.resize_images(single_image, n_size), np.uint8) 138 | 139 | # single_image = tf.image.random_brightness(single_image, .3) 140 | # single_image = tf.image.random_contrast(single_image, 0.9, 1.1) 141 | 142 | nH, nW = self.cfg.input_shape[:2] 143 | rH = tf.shape(single_image)[0] 144 | rW = tf.shape(single_image)[1] 145 | dH = tf.maximum(nH, rH) - rH 146 | dW = tf.maximum(nW, rW) - rW 147 | 148 | n = int(single_image.shape[-1]) 149 | single_image = tf.pad(single_image, 150 | tf.convert_to_tensor([[dH // 2, (dH + 1) // 2], [dW // 2, (dW + 1) // 2], [0, 0]])) 151 | single_image = tf.random_crop(single_image, [nH, nW, n], seed=123) 152 | single_image.set_shape([nH, nW, n]) 153 | 154 | angs = tf.to_float(tf.random_uniform([1], 0, 4, tf.int32)) * np.pi / 2 155 |
single_image = tf.contrib.image.rotate(single_image, angs[0]) 156 | single_image = tf.image.random_flip_left_right(single_image) 157 | 158 | single_image = (tf.to_float(single_image) - self.img_mean) / self.img_stddev 159 | 160 | image_batch = tf.train.batch( 161 | [single_image], 162 | batch_size=self.cfg.batch_size, 163 | num_threads=16, 164 | capacity=128) 165 | 166 | return image_batch 167 | 168 | def grid_batch_images(self, images): 169 | n, h, w, c = images.shape 170 | a = int(math.floor(np.sqrt(n))) 171 | # images = (((images - images.min()) * 255) / (images.max() - images.min())).astype(np.uint8) 172 | images = images.astype(np.uint8) 173 | images_in_square = np.reshape(images[:a * a], (a, a, h, w, c)) 174 | new_img = np.zeros((h * a, w * a, c), dtype=np.uint8) 175 | for col_i, col_images in enumerate(images_in_square): 176 | for row_i, image in enumerate(col_images): 177 | new_img[col_i * h: (1 + col_i) * h, row_i * w: (1 + row_i) * w] = image 178 | resolution = self.cfg.resolution 179 | if self.cfg.resolution != h: 180 | scale = resolution / h 181 | new_img = cv2.resize(new_img, None, fx=scale, fy=scale, 182 | interpolation=cv2.INTER_NEAREST) 183 | return new_img 184 | -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from model import Model 4 | 5 | 6 | class DCGAN(Model): 7 | def __init__(self, cfg): 8 | self.alpha = cfg.leakyRelu_alpha 9 | input_size, _, nc = cfg.input_shape 10 | self.res = cfg.resolution 11 | self.min_res = cfg.min_resolution 12 | # number of times to upsample/downsample for full resolution: 13 | self.n_scalings = int(np.log2(input_size / self.min_res)) 14 | # number of times to upsample/downsample for current resolution: 15 | self.n_layers = int(np.log2(self.res / self.min_res)) 16 | self.nf_min = cfg.nf_min # min feature depth 17 | self.nf_max = cfg.nf_max # max feature depth 18 | self.batch_size = cfg.batch_size 19 | Model.__init__(self, cfg) 20 | 21 | def leaky_relu(self, input_): 22 | return tf.maximum(self.alpha * input_, input_) 23 | 24 | def add_minibatch_stddev_feat(self, input_): 25 | _, h, w, _ = input_.get_shape().as_list() 26 | new_feat_shape = [self.cfg.batch_size, h, w, 1] 27 | 28 | mean, var = tf.nn.moments(input_, axes=[0], keep_dims=True) 29 | stddev = tf.sqrt(tf.reduce_mean(var, keep_dims=True)) 30 | new_feat = tf.tile(stddev, multiples=new_feat_shape) 31 | return tf.concat([input_, new_feat], axis=3) 32 | 33 | def pixelwise_norm(self, a): 34 | return a / tf.sqrt(tf.reduce_mean(a * a, axis=3, keep_dims=True) + 1e-8) 35 | 36 | def conv2d(self, input_, n_filters, k_size, padding='same'): 37 | if not self.cfg.weight_scale: 38 | return tf.layers.conv2d(input_, n_filters, k_size, padding=padding) 39 | 40 | n_feats_in = input_.get_shape().as_list()[-1] 41 | fan_in = k_size * k_size * n_feats_in 42 | c = tf.constant(np.sqrt(2. / fan_in), dtype=tf.float32) 43 | kernel_init = tf.random_normal_initializer(stddev=1.) 
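# Equalized learning rate (ProGAN): the kernel is initialized from N(0, 1)
# and scaled by the He constant c = sqrt(2 / fan_in) inside the graph, i.e.
# on every forward pass. Adam's per-parameter updates are insensitive to
# weight scale, so optimizing the N(0, 1) variables gives every layer the
# same effective learning rate instead of baking the scale into the init.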
44 | w_shape = [k_size, k_size, n_feats_in, n_filters] 45 | w = tf.get_variable('kernel', shape=w_shape, initializer=kernel_init) 46 | w = c * w 47 | strides = [1, 1, 1, 1] 48 | net = tf.nn.conv2d(input_, w, strides, padding=padding.upper()) 49 | b = tf.get_variable('bias', [n_filters], 50 | initializer=tf.constant_initializer(0.)) 51 | net = tf.nn.bias_add(net, b) 52 | return net 53 | 54 | def up_sample(self, input_): 55 | _, h, w, _ = input_.get_shape().as_list() 56 | new_size = [2 * h, 2 * w] 57 | return tf.image.resize_nearest_neighbor(input_, size=new_size) 58 | 59 | def down_sample(self, input_): 60 | return tf.layers.average_pooling2d(input_, 2, 2) 61 | 62 | def conv_module(self, input_, n_filters, training, k_sizes=None, 63 | norms=None, padding='same'): 64 | conv = input_ 65 | if k_sizes is None: 66 | k_sizes = [3] * len(n_filters) 67 | if norms is None: 68 | norms = [None, None] 69 | 70 | # series of conv + lRelu + norm 71 | for i, (nf, k_size, norm) in enumerate(zip(n_filters, k_sizes, norms)): 72 | var_scope = 'conv_block_' + str(i+1) 73 | with tf.variable_scope(var_scope): 74 | conv = self.conv2d(conv, nf, k_size, padding=padding) 75 | conv = self.leaky_relu(conv) 76 | if norm == 'batch_norm': 77 | conv = tf.layers.batch_normalization(conv, training=training) 78 | elif norm == 'pixel_norm': 79 | conv = self.pixelwise_norm(conv) 80 | elif norm == 'layer_norm': 81 | conv = tf.contrib.layers.layer_norm(conv) 82 | return conv 83 | 84 | def to_image(self, input_, resolution): 85 | nc = self.cfg.input_shape[-1] 86 | var_scope = '{0:}x{0:}'.format(resolution) 87 | with tf.variable_scope(var_scope + '/to_image'): 88 | out = self.conv2d(input_, nc, 1) 89 | return out 90 | 91 | def from_image(self, input_, n_filters, resolution): 92 | var_scope = '{0:}x{0:}'.format(resolution) 93 | with tf.variable_scope(var_scope + '/from_image'): 94 | out = self.conv2d(input_, n_filters, 1) 95 | return self.leaky_relu(out) 96 | 97 | def build_generator(self, training): 98 | z = self.tf_placeholders['z'] 99 | z_dim = self.cfg.z_dim 100 | feat_size = self.min_res 101 | norm = self.cfg.norm_g 102 | 103 | with tf.variable_scope('generator', reuse=(not training)): 104 | net = tf.reshape(z, (-1, 1, 1, z_dim)) 105 | padding = int(feat_size / 2) 106 | net = tf.pad(net, [[0, 0], [padding - 1, padding], 107 | [padding - 1, padding], [0, 0]]) 108 | feat_depth = min(self.nf_max, self.nf_min * 2 ** self.n_scalings) 109 | r = self.min_res 110 | var_scope = '{0:}x{0:}'.format(r) 111 | with tf.variable_scope(var_scope): 112 | net = self.conv_module(net, [feat_depth, feat_depth], 113 | training, k_sizes=[4, 3], 114 | norms=[None, norm]) 115 | layers = [] 116 | for i in range(self.n_layers): 117 | net = self.up_sample(net) 118 | n = self.nf_min * 2 ** (self.n_scalings - i - 1) 119 | feat_depth = min(self.nf_max, n) 120 | r *= 2 121 | var_scope = '{0:}x{0:}'.format(r) 122 | with tf.variable_scope(var_scope): 123 | net = self.conv_module(net, [feat_depth, feat_depth], 124 | training, norms=[norm, norm]) 125 | layers.append(net) 126 | 127 | # final layer: 128 | assert r == self.res, \ 129 | '{:} not equal to {:}'.format(r, self.res) 130 | net = self.to_image(net, self.res) 131 | if self.cfg.transition: 132 | alpha = self.tf_placeholders['alpha'] 133 | branch = layers[-2] 134 | branch = self.up_sample(branch) 135 | branch = self.to_image(branch, r / 2) 136 | net = alpha * net + (1. 
- alpha) * branch 137 | if self.cfg.use_tanh: 138 | net = tf.tanh(net) 139 | return net 140 | 141 | def build_discriminator(self, input_, reuse, training): 142 | norm = self.cfg.norm_d 143 | if (self.cfg.loss_mode == 'wgan_gp') and (norm == 'batch_norm'): 144 | norm = None 145 | with tf.variable_scope('discriminator', reuse=reuse): 146 | feat_depths = [min(self.nf_max, self.nf_min * 2 ** i) 147 | for i in range(self.n_scalings)] 148 | r = self.res 149 | net = self.from_image(input_, feat_depths[-self.n_layers], r) 150 | for i in range(self.n_layers): 151 | feat_depth_1 = feat_depths[-self.n_layers + i] 152 | feat_depth_2 = min(self.nf_max, 2 * feat_depth_1) 153 | var_scope = '{0:}x{0:}'.format(r) 154 | with tf.variable_scope(var_scope): 155 | net = self.conv_module(net, [feat_depth_1, feat_depth_2], 156 | training, norms=[norm, norm]) 157 | net = self.down_sample(net) 158 | r /= 2 159 | # add a transition branch if required 160 | if i == 0 and self.cfg.transition: 161 | alpha = self.tf_placeholders['alpha'] 162 | input_low = self.down_sample(input_) 163 | idx = -self.n_layers + 1 164 | branch = self.from_image(input_low, feat_depths[idx], 165 | self.res / 2) 166 | net = alpha * net + (1. - alpha) * branch 167 | 168 | # add final layer 169 | assert r == self.min_res, \ 170 | '{:} not equal to {:}'.format(r, self.min_res) 171 | net = self.add_minibatch_stddev_feat(net) 172 | feat_depth = min(self.nf_max, self.nf_min * 2 ** self.n_scalings) 173 | var_scope = '{0:}x{0:}'.format(r) 174 | with tf.variable_scope(var_scope): 175 | net = self.conv_module(net, [feat_depth, feat_depth], 176 | training, k_sizes=[3, 4], 177 | norms=[norm, None]) 178 | net = tf.reduce_mean(net, axis=[1, 2]) 179 | net = tf.reshape(net, [self.cfg.batch_size, feat_depth]) 180 | net = tf.layers.dense(net, 1) 181 | return net 182 | 183 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import sys 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import losses 7 | import time 8 | from utils import ImageLoader 9 | 10 | 11 | class Model(object): 12 | def __init__(self, cfg): 13 | self.cfg = cfg 14 | self.tf_placeholders = {} 15 | self.create_tf_placeholders() 16 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 17 | self.d_train_op, self.g_train_op = None, None 18 | self.ema_op, self.ema_vars = None, {} 19 | self.d_loss, self.g_loss = None, None 20 | self.gen_images, self.eval_op = None, None 21 | self.image_loader = ImageLoader(self.cfg) 22 | 23 | def create_tf_placeholders(self): 24 | h, w, c = self.cfg.input_shape 25 | z_dim = self.cfg.z_dim 26 | z = tf.placeholder(tf.float32, [None, z_dim]) 27 | learning_rate = tf.placeholder(tf.float32) 28 | alpha = tf.placeholder(tf.float32, shape=()) 29 | self.tf_placeholders = {'z': z, 30 | 'learning_rate': learning_rate, 31 | 'alpha': alpha} 32 | 33 | def resize_image(self, image): 34 | _, input_size, _, _ = image.get_shape().as_list() 35 | res = self.cfg.resolution 36 | if input_size == res: 37 | return image 38 | new_size = [res, res] 39 | new_img = tf.image.resize_nearest_neighbor(image, size=new_size) 40 | if self.cfg.transition: 41 | alpha = self.tf_placeholders['alpha'] 42 | low_res_img = tf.layers.average_pooling2d(new_img, 2, 2) 43 | low_res_img = \ 44 | tf.image.resize_nearest_neighbor(low_res_img, size=new_size) 45 | new_img = alpha * new_img + (1. 
- alpha) * low_res_img 46 | return new_img 47 | 48 | def build_generator(self, training): 49 | raise NotImplementedError("Not yet implemented") 50 | 51 | def build_encoder(self, training): 52 | raise NotImplementedError("Not yet implemented") 53 | 54 | def build_discriminator(self, input_, reuse, training): 55 | raise NotImplementedError("Not yet implemented") 56 | 57 | def make_train_op(self, images): 58 | images_real = images 59 | tf.summary.image('images_real_original_size', images_real, 8) 60 | images_real = self.resize_image(images_real) 61 | tf.summary.image('images_real', images_real, 8) 62 | 63 | d_real = self.build_discriminator(images_real, reuse=False, 64 | training=True) 65 | 66 | images_fake = self.build_generator(training=True) 67 | tf.summary.image('images_fake', images_fake, 8) 68 | 69 | d_fake = self.build_discriminator(images_fake, reuse=True, 70 | training=True) 71 | 72 | d_loss, g_loss = None, None 73 | if self.cfg.loss_mode == 'js': 74 | smooth_factor = 0.9 if self.cfg.smooth_label else 1. 75 | d_loss, g_loss = losses.js_loss(d_real, d_fake, smooth_factor) 76 | elif self.cfg.loss_mode == 'wgan_gp': 77 | d_loss, g_loss = losses.wgan_loss(d_real, d_fake) 78 | # Gradient penalty 79 | lambda_gp = self.cfg.lambda_gp 80 | gamma_gp = self.cfg.gamma_gp 81 | batch_size = self.cfg.batch_size 82 | nc = self.cfg.input_shape[-1] 83 | res = self.cfg.resolution 84 | input_shape = [batch_size, res, res, nc] 85 | alpha = tf.random_uniform(shape=input_shape, minval=0., maxval=1.) 86 | differences = images_fake - images_real 87 | interpolates = images_real + alpha * differences 88 | gradients = tf.gradients( 89 | self.build_discriminator(interpolates, reuse=True, training=True), 90 | [interpolates, ])[0] 91 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3])) 92 | gradient_penalty = \ 93 | lambda_gp * tf.reduce_mean((slopes / gamma_gp - 1.) ** 2) 94 | d_loss += gradient_penalty 95 | 96 | if self.cfg.drift_loss: 97 | eps = self.cfg.eps_drift 98 | drift_loss = eps * tf.reduce_mean(tf.nn.l2_loss(d_real)) 99 | d_loss += drift_loss 100 | 101 | t_vars = tf.trainable_variables() 102 | d_vars = [var for var in t_vars if var.name.startswith('discriminator')] 103 | g_vars = [var for var in t_vars if var.name.startswith('generator')] 104 | 105 | beta1 = self.cfg.beta1 106 | beta2 = self.cfg.beta2 107 | learning_rate = self.tf_placeholders['learning_rate'] 108 | d_solver = tf.train.AdamOptimizer(learning_rate, beta1=beta1, beta2=beta2) 109 | g_solver = tf.train.AdamOptimizer(learning_rate, beta1=beta1, beta2=beta2) 110 | ema = tf.train.ExponentialMovingAverage(decay=0.999) 111 | self.ema_op = ema.apply(g_vars) 112 | self.ema_vars = {ema.average_name(v): v for v in g_vars} 113 | 114 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 115 | with tf.control_dependencies(update_ops): 116 | self.d_train_op = d_solver.minimize(d_loss, var_list=d_vars, 117 | global_step=self.global_step) 118 | self.g_train_op = g_solver.minimize(g_loss, var_list=g_vars) 119 | self.d_loss, self.g_loss = d_loss, g_loss 120 | 121 | def train(self): 122 | """ Train the model. 
""" 123 | batch_size = self.cfg.batch_size 124 | n_iters = self.cfg.n_iters 125 | n_critic = self.cfg.n_critic 126 | z_dim = self.cfg.z_dim 127 | learning_rate = self.cfg.learning_rate 128 | display_period = self.cfg.display_period 129 | save_period = self.cfg.save_period 130 | image_loader = self.image_loader 131 | transition = self.cfg.transition 132 | # paths for save directories 133 | save_tag = '{0:}x{0:}'.format(self.cfg.resolution) 134 | if transition: 135 | save_tag += '_transition' 136 | img_save_dir = os.path.join(self.cfg.image_save_dir, save_tag) 137 | if not os.path.exists(img_save_dir): 138 | os.makedirs(img_save_dir) 139 | save_dir = os.path.join(self.cfg.model_save_dir, save_tag) 140 | if not os.path.exists(save_dir): 141 | os.makedirs(save_dir) 142 | save_dir = os.path.join(save_dir, 'model') 143 | 144 | with tf.device("/cpu:0"): 145 | image_batch = image_loader.create_batch_pipeline() 146 | 147 | self.make_train_op(image_batch) 148 | 149 | merged = tf.summary.merge_all() 150 | writer = tf.summary.FileWriter(os.path.join(self.cfg.summary_dir, time.strftime('%Y%m%d_%H%M%S'))) 151 | 152 | # Create ops in graph before Session is created 153 | init = tf.global_variables_initializer() 154 | saver = tf.train.Saver() 155 | with tf.Session() as sess: 156 | sess.run(init) 157 | tf.train.start_queue_runners(sess) 158 | load_model = self.cfg.load_model 159 | if self.cfg.load_model: 160 | self.load(sess, saver, load_model) 161 | elif transition: 162 | vars_to_load = [] 163 | all_vars = tf.trainable_variables() 164 | r = self.cfg.min_resolution 165 | while r < self.cfg.resolution: 166 | var_scope = '{0:}x{0:}'.format(r) 167 | vars_to_load += [v for v in all_vars if var_scope in v.name] 168 | r *= 2 169 | saver_restore = tf.train.Saver(vars_to_load) 170 | tag = '{0:}x{0:}'.format(self.cfg.resolution // 2) 171 | print(tag) 172 | self.load(sess, saver_restore, tag=tag) 173 | 174 | alpha = self.cfg.fade_alpha 175 | global_step = 0 176 | sum_g_loss, sum_d_loss = 0., 0. 177 | # batch_gen = image_loader.batch_generator() 178 | 179 | for i in range(self.cfg.n_iters): 180 | batch_z = np.random.normal(0, 1, size=(batch_size, z_dim)) 181 | feed_dict = {self.tf_placeholders['z']: batch_z, 182 | self.tf_placeholders['learning_rate']: learning_rate, 183 | self.tf_placeholders['alpha']: alpha} 184 | if global_step % display_period == 0: 185 | _, global_step, d_loss, merged_res = \ 186 | sess.run([self.d_train_op, self.global_step, self.d_loss, merged], 187 | feed_dict=feed_dict) 188 | else: 189 | _, global_step, d_loss = \ 190 | sess.run([self.d_train_op, self.global_step, self.d_loss], 191 | feed_dict=feed_dict) 192 | 193 | g_loss = 0. 194 | if global_step % n_critic == 0: 195 | _, _, g_loss = \ 196 | sess.run([self.g_train_op, self.ema_op, self.g_loss], 197 | feed_dict=feed_dict) 198 | sum_g_loss += g_loss 199 | sum_d_loss += d_loss 200 | if transition: 201 | alpha_step = 1. / n_iters 202 | alpha = min(1., self.cfg.fade_alpha+global_step*alpha_step) 203 | if global_step % display_period == 0: 204 | writer.add_summary(merged_res, global_step) 205 | print("After {} iterations".format(global_step), 206 | "Discriminator loss : {:3.5f} " 207 | .format(sum_d_loss / display_period), 208 | "Generator loss : {:3.5f}" 209 | .format(sum_g_loss / display_period * n_critic)) 210 | sum_g_loss, sum_d_loss = 0., 0. 
211 | if transition: 212 | print("Using alpha = ", alpha) 213 | if global_step % save_period == 0: 214 | print("Saving model in {}".format(save_dir)) 215 | saver.save(sess, save_dir, global_step) 216 | if self.cfg.save_images: 217 | gen_images = self.generate_images(save_tag, alpha=alpha) 218 | plt.figure(figsize=(10, 10)) 219 | grid = image_loader.grid_batch_images(gen_images) 220 | filename = os.path.join(img_save_dir, str(global_step) + '.png') 221 | plt.imsave(filename, grid) 222 | print("Saving model in {}".format(save_dir)) 223 | saver.save(sess, save_dir, global_step) 224 | 225 | def generate_images(self, model, batch_z=None, alpha=0.): 226 | """Runs the generator in a fresh session; `model` is the checkpoint tag (save sub-directory) to load""" 227 | batch_size = 64 # self.cfg.batch_size 228 | z_dim = self.cfg.z_dim 229 | if batch_z is None: 230 | batch_z = np.random.normal(0, 1, size=(batch_size, z_dim)) 231 | # saver = tf.train.Saver(self.ema_vars) 232 | saver = tf.train.Saver() 233 | feed_dict = {self.tf_placeholders['z']: batch_z, 234 | self.tf_placeholders['alpha']: alpha} 235 | image_loader = self.image_loader 236 | gen = self.build_generator(training=False) 237 | 238 | with tf.Session() as sess: 239 | self.load(sess, saver, model) 240 | gen_images = sess.run(gen, feed_dict=feed_dict) 241 | gen_images = image_loader.postprocess_image(gen_images) 242 | return gen_images 243 | 244 | def load(self, sess, saver, tag=None): 245 | """ Load the trained model. """ 246 | if tag is None: 247 | tag = '{0:}x{0:}'.format(self.cfg.input_shape[0]) 248 | 249 | load_dir = os.path.join(self.cfg.model_save_dir, tag, 'model') 250 | print("Loading model...") 251 | checkpoint = tf.train.get_checkpoint_state(os.path.dirname(load_dir)) 252 | if checkpoint is None: 253 | print("Error: No saved model found. Please train first.") 254 | sys.exit(1) # exit with a non-zero status on error 255 | saver.restore(sess, checkpoint.model_checkpoint_path) 256 | --------------------------------------------------------------------------------
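
The five walkthrough steps reduce to one loop: stabilize at the starting resolution, then for each doubling of resolution run a fading (transition) stage followed by a stabilization stage. Below is a minimal, illustrative schedule script, not a file in this repository: the helper name `train_stage` is ours, the sketch assumes the notebook's hyperparameters have already been applied to `cfg` and that checkpoints land in `models/<res>x<res>[_transition]` as above, and it resets the default graph between stages (rebuilding `DCGAN` twice in one graph would otherwise raise variable-scope conflicts).

import tensorflow as tf

from config import cfg
from net import DCGAN


def train_stage(resolution, transition, load_model=None):
    # each stage builds a fresh graph; weights carry over via checkpoints
    tf.reset_default_graph()
    cfg.resolution = resolution
    cfg.transition = transition
    # None => for transition stages, Model.train() itself restores the
    # matching lower-resolution layers from the previous stage's directory
    cfg.load_model = load_model
    model = DCGAN(cfg)
    model.train()


train_stage(8, transition=False)  # Step-1: stabilize at 8x8
for res in (16, 32, 64):
    train_stage(res, transition=True)  # fade in the new layers
    train_stage(res, transition=False,  # then stabilize at this resolution
                load_model='{0:}x{0:}_transition'.format(res))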