├── util
    ├── __init__.py
    ├── util.py
    └── data.py
├── .gitignore
├── images
    ├── gen_method.png
    └── image_collage_extended.jpg
├── LICENSE
├── README.md
├── test.py
├── train.py
└── models.py


/util/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
data/
log/
results/

--------------------------------------------------------------------------------
/images/gen_method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/costapt/vess2ret/HEAD/images/gen_method.png
--------------------------------------------------------------------------------
/images/image_collage_extended.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/costapt/vess2ret/HEAD/images/image_collage_extended.jpg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Pedro Miguel Vendas da Costa

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Towards Adversarial Retinal Image Synthesis

[Arxiv](https://arxiv.org/abs/1701.08974) [Demo](http://vess2ret.inesctec.pt)

We use an image-to-image translation technique based on the idea of adversarial learning to synthesize eye fundus images directly from data. We pair true eye fundus images with their respective vessel trees, obtained with a vessel segmentation technique. These pairs are then used to learn a mapping from a binary vessel tree to a new retinal image.

## How it works
- Get pairs of binary retinal vessel trees and corresponding retinal images.
The user can provide their own vessel annotations.
In our case, a large enough manually annotated database was not available, so we applied a DNN vessel segmentation method to the [Messidor database](http://www.adcis.net/en/Download-Third-Party/Messidor.html). For details, please refer to [arxiv](https://arxiv.org/abs/1701.08974).

- Train the image generator on the set of image pairs.
The model is based on [pix2pix](https://github.com/phillipi/pix2pix). We use a Generative Adversarial Network and combine the adversarial loss with a global L1 loss. Our images have a resolution of 512x512 pixels. The implementation was developed in Python using Keras.


- Test the model.
The model is now able to synthesize a new retinal image from any given vessel tree.

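The combined objective from the training step can be illustrated with a short NumPy sketch. This is not the loss defined in `models.py`, which is not reproduced here. The helper names `bce` and `generator_loss` are hypothetical. The label convention (1 = fake, 0 = real) follows `train.py`. The default weight `alpha=100` is an assumption borrowed from the pix2pix paper, so the repository's `--alpha` default may differ.

```python
import numpy as np


def bce(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped away from 0 and 1."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))))


def generator_loss(d_out_fake, b_real, b_fake, alpha=100.0):
    """Hypothetical helper: adversarial term plus an alpha-weighted global L1 term.

    d_out_fake: discriminator outputs on generated pairs (1 = fake, 0 = real).
    b_real, b_fake: real and synthesized retinal images, scaled to [-1, 1].
    alpha: assumed weight of the L1 term (not read from the repository).
    """
    # The generator wants the discriminator to label its outputs as real (0).
    adversarial = bce(np.zeros_like(d_out_fake), d_out_fake)
    # Global L1 term: mean absolute pixel difference over the whole image.
    l1 = float(np.mean(np.abs(b_real - b_fake)))
    return adversarial + alpha * l1


# Toy, channels-first arrays; real batches would be 3 x 512 x 512.
rng = np.random.RandomState(0)
b_real = rng.uniform(-1, 1, size=(1, 3, 64, 64))
b_fake = rng.uniform(-1, 1, size=(1, 3, 64, 64))
d_out = rng.uniform(0, 1, size=(1, 1, 4, 4))
print(generator_loss(d_out, b_real, b_fake))
```

Roughly speaking, a larger `alpha` pulls each synthesized image closer to its paired real image pixel by pixel. The adversarial term instead rewards outputs that the discriminator cannot tell apart from real fundus images.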



## Setup

### Prerequisites
- Python 2 (the scripts use Python 2 `print` statements and `cPickle`)
- Keras (Theano or TensorFlow backend) with "image_dim_ordering" set to "th"
- NumPy, Matplotlib, seaborn and tqdm, which the scripts also import

### Set up directories

The data must be organized into train, validation and test directories. By default, the directory tree is:

* 'data/unet_segmentations_binary'
    * 'train'
        * 'A', contains the binary segmentations
        * 'B', contains the retinal images
    * 'val'
        * 'A', contains the binary segmentations
        * 'B', contains the retinal images
    * 'test'
        * 'A', contains the binary segmentations
        * 'B', contains the retinal images

The defaults can be changed at run time:
```bash
python train.py [--base_dir] [--train_dir] [--val_dir]
```
Folders A and B contain corresponding pairs of images and must keep these default names. The two images of a pair must have exactly the same filename, including the extension. Images without a counterpart are ignored.

## Usage

### Model

The model can be applied to any vessel tree of the appropriate size (512x512 by default). You can download the pre-trained weights [here](https://drive.google.com/drive/folders/0B_82R0TWezB9VExYbmt2ZUJSUmc?usp=sharing) and load them at test time. If you do this, skip the training step.

### Train the model

To train the model, run:

```bash
python train.py [--help]
```
By default, the model is saved to a folder named 'log'.

### Test the model

To test the model, run:

```bash
python test.py [--help]
```
By default, the results are written to a folder named 'results'. If you are testing with the pre-trained weights downloaded from [here](https://drive.google.com/drive/folders/0B_82R0TWezB9VExYbmt2ZUJSUmc?usp=sharing), make sure both the weights and params.json are saved in the log folder (or in 'log/<expt_name>' if you pass `--expt_name`).
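`TwoImageIterator` (in `util/data.py`) keeps only the filenames present in both A and B and silently drops the rest. It may therefore be worth checking the pairing before training. Below is a minimal sketch of such a check. It is a hypothetical helper, not part of this repository, and it assumes the default `data/unet_segmentations_binary` layout described in "Set up directories".

```python
import os


def find_unpaired(split_dir, a_dir_name='A', b_dir_name='B'):
    """Return (only_in_a, only_in_b): sorted filenames that lack a partner.

    Mirrors the pairing rule of TwoImageIterator: a pair is two files with
    exactly the same name (extension included) in A and in B.
    """
    a_files = set(os.listdir(os.path.join(split_dir, a_dir_name)))
    b_files = set(os.listdir(os.path.join(split_dir, b_dir_name)))
    return sorted(a_files - b_files), sorted(b_files - a_files)


def report(base_dir='data/unet_segmentations_binary', splits=('train', 'val', 'test')):
    """Print a one-line pairing summary for each split directory."""
    for split in splits:
        split_dir = os.path.join(base_dir, split)
        if not os.path.isdir(split_dir):
            print('{0}: directory not found'.format(split_dir))
            continue
        only_a, only_b = find_unpaired(split_dir)
        print('{0}: {1} unpaired in A, {2} unpaired in B'.format(
            split_dir, len(only_a), len(only_b)))


if __name__ == '__main__':
    report()
```

Run it from the repository root. Any non-zero count means that some images will be skipped by `train.py` and `test.py`.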


## Citation
If you use this code for your research, please cite our paper [Towards Adversarial Retinal Image Synthesis](https://arxiv.org/abs/1701.08974):

```
@article{ costa_retinal_generation_2017,
  title={Towards Adversarial Retinal Image Synthesis},
  author={Costa, P. and Galdran, A. and Meyer, M.I. and Abràmoff, M.D. and Niemeijer, M. and Mendonca, A.M. and Campilho, A.},
  journal={arxiv},
  year={2017},
  doi={10.5281/zenodo.265508}
}
```

[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.265508.svg)](https://doi.org/10.5281/zenodo.265508)

--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
"""Auxiliary methods."""
import os
import json
from errno import EEXIST

import numpy as np
import seaborn as sns
import cPickle as pickle
import matplotlib.pyplot as plt

sns.set()

DEFAULT_LOG_DIR = 'log'
ATOB_WEIGHTS_FILE = 'atob_weights.h5'
D_WEIGHTS_FILE = 'd_weights.h5'


class MyDict(dict):
    """
    Dictionary that allows accessing elements with dot notation.

    Example:
        >> d = MyDict({'key': 'val'})
        >> d.key
        'val'
        >> d.key2 = 'val2'
        >> d
        {'key2': 'val2', 'key': 'val'}
    """

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__


def convert_to_rgb(img, is_binary=False):
    """Convert a channels-first image into a height x width x 3 array with values in [0, 1]."""
    if len(img.shape) != 3:
        raise Exception("""Image must have 3 dimensions (channels x height x width). """
                        """Given {0}""".format(len(img.shape)))

    img_ch, _, _ = img.shape
    if img_ch != 3 and img_ch != 1:
        raise Exception("""Unsupported number of channels. """
                        """Must be 1 or 3, given {0}.""".format(img_ch))

    imgp = img
    if img_ch == 1:
        # Replicate the single channel to obtain an RGB image.
        imgp = np.repeat(img, 3, axis=0)

    if not is_binary:
        # Undo the [-1, 1] (tanh) normalization.
        imgp = imgp * 127.5 + 127.5
        imgp /= 255.

    return np.clip(imgp.transpose((1, 2, 0)), 0, 1)


def compose_imgs(a, b, is_a_binary=True, is_b_binary=False):
    """Place a and b side by side to be plotted."""
    ap = convert_to_rgb(a, is_binary=is_a_binary)
    bp = convert_to_rgb(b, is_binary=is_b_binary)

    if ap.shape != bp.shape:
        raise Exception("""A and B must have the same size. """
                        """{0} != {1}""".format(ap.shape, bp.shape))

    # ap and bp are guaranteed to have the same shape here
    h, w, ch = ap.shape
    composed = np.zeros((h, 2*w, ch))
    composed[:, :w, :] = ap
    composed[:, w:, :] = bp

    return composed


def get_log_dir(log_dir, expt_name):
    """Compose the log_dir with the experiment name."""
    if log_dir is None:
        raise Exception('log_dir cannot be None.')

    if expt_name is not None:
        return os.path.join(log_dir, expt_name)
    return log_dir


def mkdir(mypath):
    """Create a directory if it does not exist."""
    try:
        os.makedirs(mypath)
    except OSError as exc:
        if exc.errno == EEXIST and os.path.isdir(mypath):
            pass
        else:
            raise


def create_expt_dir(params):
    """Create the experiment directory and return it."""
    expt_dir = get_log_dir(params.log_dir, params.expt_name)

    # Create directories if they do not exist
    mkdir(params.log_dir)
    mkdir(expt_dir)

    # Save the parameters
    json.dump(params, open(os.path.join(expt_dir, 'params.json'), 'wb'),
              indent=4, sort_keys=True)

    return expt_dir


def plot_loss(loss, label, filename, log_dir):
    """Plot a loss function and save it in a file."""
    plt.figure(figsize=(5, 4))
    plt.plot(loss, label=label)
    plt.legend()
    plt.savefig(os.path.join(log_dir, filename))
    plt.clf()


def log(losses, atob, it_val, N=4, log_dir=DEFAULT_LOG_DIR, expt_name=None,
        is_a_binary=True, is_b_binary=False):
    """Log losses and atob results."""
    log_dir = get_log_dir(log_dir, expt_name)

    # Save the losses for further inspection
    pickle.dump(losses, open(os.path.join(log_dir, 'losses.pkl'), 'wb'))

    ###########################################################################
    #                             PLOT THE LOSSES                             #
    ###########################################################################
    plot_loss(losses['d'], 'discriminator', 'd_loss.png', log_dir)
    plot_loss(losses['d_val'], 'discriminator validation', 'd_val_loss.png', log_dir)

    plot_loss(losses['p2p'], 'Pix2Pix', 'p2p_loss.png', log_dir)
    plot_loss(losses['p2p_val'], 'Pix2Pix validation', 'p2p_val_loss.png', log_dir)

    ###########################################################################
    #                          PLOT THE A->B RESULTS                          #
    ###########################################################################
    plt.figure(figsize=(10, 6))
    for i in range(N*N):
        a, _ = next(it_val)

        bp = atob.predict(a)
        img = compose_imgs(a[0], bp[0], is_a_binary=is_a_binary, is_b_binary=is_b_binary)

        plt.subplot(N, N, i+1)
        plt.imshow(img)
        plt.axis('off')

    plt.savefig(os.path.join(log_dir, 'atob.png'))
    plt.clf()

    # Make sure all the figures are closed.
    plt.close('all')


def save_weights(models, log_dir=DEFAULT_LOG_DIR, expt_name=None):
    """Save the weights of the models into files."""
    log_dir = get_log_dir(log_dir, expt_name)

    models.atob.save_weights(os.path.join(log_dir, ATOB_WEIGHTS_FILE), overwrite=True)
    models.d.save_weights(os.path.join(log_dir, D_WEIGHTS_FILE), overwrite=True)


def load_weights(atob, d, log_dir=DEFAULT_LOG_DIR, expt_name=None):
    """Load the weights into the corresponding models."""
    log_dir = get_log_dir(log_dir, expt_name)

    atob.load_weights(os.path.join(log_dir, ATOB_WEIGHTS_FILE))
    d.load_weights(os.path.join(log_dir, D_WEIGHTS_FILE))


def load_weights_of(m, weights_file, log_dir=DEFAULT_LOG_DIR, expt_name=None):
    """Load the weights of the model m."""
    log_dir = get_log_dir(log_dir, expt_name)

    m.load_weights(os.path.join(log_dir, weights_file))


def load_losses(log_dir=DEFAULT_LOG_DIR, expt_name=None):
    """Load the losses of the given experiment."""
    log_dir = get_log_dir(log_dir, expt_name)
    losses = pickle.load(open(os.path.join(log_dir, 'losses.pkl'), 'rb'))
    return losses


def load_params(params):
    """
    Load the parameters of an experiment and return them.

    The params passed as argument are merged with the params loaded from
    params.json. If a key is present in both, the params passed as argument prevail.
    """
    expt_dir = get_log_dir(params.log_dir, params.expt_name)

    expt_params = json.load(open(os.path.join(expt_dir, 'params.json'), 'rb'))

    # Update the loaded parameters with the current parameters. This will
    # override conflicting keys as expected.
    expt_params.update(params)

    return expt_params

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
"""Script to test a trained model."""
import os
import sys
import getopt

import numpy as np
import models as m
import matplotlib.pyplot as plt
import util.util as u

from util.data import TwoImageIterator
from util.util import MyDict, load_params, load_weights_of, compose_imgs, convert_to_rgb, mkdir, get_log_dir


def print_help():
    """Print how to use this script."""
    print "Usage:"
    print "test.py [--help] [--results_dir] [--log_dir] [--base_dir] [--train_dir] [--val_dir] " \
          "[--test_dir] [--load_to_memory] [--expt_name] [--target_size] [--N]"
    print "--results_dir: Directory in which to save the results."
    print "--log_dir: Directory where the experiment was logged."
    print "--base_dir: Directory that contains the data."
    print "--train_dir: Directory inside base_dir that contains training data."
    print "--val_dir: Directory inside base_dir that contains validation data."
    print "--test_dir: Directory inside base_dir that contains test data."
    print "--load_to_memory: Whether to load the images into memory."
    print "--expt_name: The name of the experiment to test."
    print "--target_size: The size of the images loaded by the iterator."
    print "--N: The number of samples to generate."


def join_and_create_dir(*paths):
    """Join the paths provided as arguments, create the directory and return the path."""
    path = os.path.join(*paths)
    mkdir(path)

    return path


def save_pix2pix(unet, it, path, params):
    """Save the results of the pix2pix model."""
    real_dir = join_and_create_dir(path, 'real')
    a_dir = join_and_create_dir(path, 'A')
    b_dir = join_and_create_dir(path, 'B')
    comp_dir = join_and_create_dir(path, 'composed')

    # The iterators are built with batch_size=1 and shuffle=False, so the
    # n-th batch corresponds to it.filenames[n].
    for filename in it.filenames:
        a, b = next(it)
        bp = unet.predict(a)
        bp = convert_to_rgb(bp[0], is_binary=params.is_b_binary)

        img = compose_imgs(a[0], b[0], is_a_binary=params.is_a_binary, is_b_binary=params.is_b_binary)
        hi, wi, chi = img.shape
        hb, wb, chb = bp.shape
        if hi != hb or wi != 2*wb or chi != chb:
            raise Exception("Mismatch in img and bp dimensions {0} / {1}".format(img.shape, bp.shape))

        # Layout of the composed image: A | real B | synthesized B
        composed = np.zeros((hi, wi+wb, chi))
        composed[:, :wi, :] = img
        composed[:, wi:, :] = bp

        a = convert_to_rgb(a[0], is_binary=params.is_a_binary)
        b = convert_to_rgb(b[0], is_binary=params.is_b_binary)

        plt.imsave(open(os.path.join(real_dir, filename), 'wb+'), b)
        plt.imsave(open(os.path.join(b_dir, filename), 'wb+'), bp)
        plt.imsave(open(os.path.join(a_dir, filename), 'wb+'), a)
        plt.imsave(open(os.path.join(comp_dir, filename), 'wb+'), composed)


def save_all_pix2pix(unet, it_train, it_val, it_test, params):
    """Save all the results of the pix2pix model."""
    expt_dir = get_log_dir(params.results_dir, params.expt_name)

    # Create directories if they do not exist
    mkdir(params.results_dir)
    mkdir(expt_dir)

    train_dir = join_and_create_dir(expt_dir, params.train_dir)
    val_dir = join_and_create_dir(expt_dir, params.val_dir)
    test_dir = join_and_create_dir(expt_dir, params.test_dir)

    save_pix2pix(unet, it_train, train_dir, params)
    save_pix2pix(unet, it_val, val_dir, params)
    save_pix2pix(unet, it_test, test_dir, params)


if __name__ == '__main__':
    a = sys.argv[1:]

    params = MyDict({
        'results_dir': 'results',  # Directory in which to save the results
        'log_dir': 'log',  # Directory where the experiment was logged
        'base_dir': 'data/unet_segmentations_binary',  # Directory that contains the data
        'train_dir': 'train',  # Directory inside base_dir that contains training data
        'val_dir': 'val',  # Directory inside base_dir that contains validation data
        'test_dir': 'test',  # Directory inside base_dir that contains test data
        'load_to_memory': True,  # Whether to load the images into memory
        'expt_name': None,  # The name of the experiment to test
        'target_size': 512,  # The size of the images loaded by the iterator
        'N': 100,  # The number of samples to generate
    })

    param_names = [k + '=' for k in params.keys()] + ['help']

    try:
        opts, args = getopt.getopt(a, '', param_names)
    except getopt.GetoptError:
        print_help()
        sys.exit()

    for opt, arg in opts:
        if opt == '--help':
            print_help()
            sys.exit()
        elif opt in ('--target_size', '--N'):
            params[opt[2:]] = int(arg)
        elif opt in ('--load_to_memory',):
            # A one-element tuple: a bare string would do a substring test.
            params[opt[2:]] = arg == 'True'
        elif opt in ('--results_dir', '--log_dir', '--base_dir', '--train_dir',
                     '--val_dir', '--test_dir', '--expt_name'):
            params[opt[2:]] = arg

    params = load_params(params)
    params = MyDict(params)

    # Define the U-Net generator
    unet = m.g_unet(params.a_ch, params.b_ch, params.nfatob, is_binary=params.is_b_binary)
    load_weights_of(unet, u.ATOB_WEIGHTS_FILE, log_dir=params.log_dir, expt_name=params.expt_name)

    ts = params.target_size
    train_dir = os.path.join(params.base_dir, params.train_dir)
    it_train = TwoImageIterator(train_dir, is_a_binary=params.is_a_binary,
                                is_a_grayscale=params.is_a_grayscale,
                                is_b_grayscale=params.is_b_grayscale,
                                is_b_binary=params.is_b_binary, batch_size=1,
                                load_to_memory=params.load_to_memory,
                                target_size=(ts, ts), shuffle=False)
    val_dir = os.path.join(params.base_dir, params.val_dir)
    it_val = TwoImageIterator(val_dir, is_a_binary=params.is_a_binary,
                              is_b_binary=params.is_b_binary,
                              is_a_grayscale=params.is_a_grayscale,
                              is_b_grayscale=params.is_b_grayscale, batch_size=1,
                              load_to_memory=params.load_to_memory,
                              target_size=(ts, ts), shuffle=False)
    test_dir = os.path.join(params.base_dir, params.test_dir)
    it_test = TwoImageIterator(test_dir, is_a_binary=params.is_a_binary,
                               is_b_binary=params.is_b_binary,
                               is_a_grayscale=params.is_a_grayscale,
                               is_b_grayscale=params.is_b_grayscale, batch_size=1,
                               load_to_memory=params.load_to_memory,
                               target_size=(ts, ts), shuffle=False)

    save_all_pix2pix(unet, it_train, it_val, it_test, params)

--------------------------------------------------------------------------------
/util/data.py:
--------------------------------------------------------------------------------
"""Auxiliary methods to deal with loading the dataset."""
import os
import random

import numpy as np

from keras.preprocessing.image import apply_transform, flip_axis
from keras.preprocessing.image import transform_matrix_offset_center
from keras.preprocessing.image import Iterator, load_img, img_to_array


class TwoImageIterator(Iterator):
    """Class to iterate A and B images at the same time."""

    def __init__(self, directory, a_dir_name='A', b_dir_name='B', load_to_memory=False,
                 is_a_binary=False, is_b_binary=False, is_a_grayscale=False,
                 is_b_grayscale=False, target_size=(256, 256), rotation_range=0.,
                 height_shift_range=0., width_shift_range=0., zoom_range=0.,
                 fill_mode='constant', cval=0., horizontal_flip=False,
                 vertical_flip=False, dim_ordering='default', N=-1,
                 batch_size=32, shuffle=True, seed=None):
        """
        Iterate through two directories at the same time.

        Files under the directories A and B with the same name are returned
        together.
        Parameters:
        - directory: base directory of the dataset. Should contain two
        directories named a_dir_name and b_dir_name;
        - a_dir_name: name of the directory under directory that contains the A
        images;
        - b_dir_name: name of the directory under directory that contains the B
        images;
        - load_to_memory: if True, loads the images into memory when creating the
        iterator;
        - is_a_binary: converts A images to binary images. Applies a threshold of 0.5.
        - is_b_binary: converts B images to binary images. Applies a threshold of 0.5.
        - is_a_grayscale: if True, A images will only have one channel.
        - is_b_grayscale: if True, B images will only have one channel.
        - rotation_range, height_shift_range, width_shift_range, zoom_range,
        fill_mode, cval, horizontal_flip, vertical_flip: random augmentation,
        applied identically to A and B (see _random_transform);
        - dim_ordering: 'th', 'tf' or 'default'; 'default' is handled like 'th'
        when computing image shapes;
        - N: if -1, uses the entire dataset. Otherwise, uses only a random subset of N pairs;
        - batch_size: the size of the batches to create;
        - shuffle: if True, the order of the image pairs is shuffled;
        - seed: seed for the random number generator.
        """
        self.directory = directory

        self.a_dir = os.path.join(directory, a_dir_name)
        self.b_dir = os.path.join(directory, b_dir_name)

        a_files = set(os.listdir(self.a_dir))
        b_files = set(os.listdir(self.b_dir))
        # Files inside A and B should have the same name. Images without a pair are discarded.
        self.filenames = list(a_files.intersection(b_files))

        # Optionally use only a random subset of the files. Useful to
        # deliberately overfit the model as a sanity check.
        if N > 0:
            random.shuffle(self.filenames)
            self.filenames = self.filenames[:N]
        self.N = len(self.filenames)
        if self.N == 0:
            raise Exception("""Did not find any pair in the dataset. Please check that """
                            """the names and extensions of the pairs are exactly the same. """
                            """Searched inside folders: {0} and {1}""".format(self.a_dir, self.b_dir))

        self.dim_ordering = dim_ordering
        if self.dim_ordering not in ('th', 'default', 'tf'):
            raise Exception('dim_ordering should be one of "th", "tf" or "default". '
                            'Got {0}'.format(self.dim_ordering))

        self.target_size = target_size

        self.is_a_binary = is_a_binary
        self.is_b_binary = is_b_binary
        self.is_a_grayscale = is_a_grayscale
        self.is_b_grayscale = is_b_grayscale

        self.image_shape_a = self._get_image_shape(self.is_a_grayscale)
        self.image_shape_b = self._get_image_shape(self.is_b_grayscale)

        self.load_to_memory = load_to_memory
        if self.load_to_memory:
            self._load_imgs_to_memory()

        if self.dim_ordering in ('th', 'default'):
            self.channel_index = 1
            self.row_index = 2
            self.col_index = 3
        elif self.dim_ordering == 'tf':
            self.channel_index = 3
            self.row_index = 1
            self.col_index = 2

        self.rotation_range = rotation_range
        self.height_shift_range = height_shift_range
        self.width_shift_range = width_shift_range
        self.fill_mode = fill_mode
        self.cval = cval
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip

        if np.isscalar(zoom_range):
            self.zoom_range = [1 - zoom_range, 1 + zoom_range]
        elif len(zoom_range) == 2:
            self.zoom_range = [zoom_range[0], zoom_range[1]]
        else:
            raise ValueError('zoom_range should be a float or a tuple or list of two floats. '
                             'Got {0}'.format(zoom_range))

        super(TwoImageIterator, self).__init__(len(self.filenames), batch_size,
                                               shuffle, seed)

    def _get_image_shape(self, is_grayscale):
        """Auxiliary method to get the image shape given the color mode."""
        if is_grayscale:
            if self.dim_ordering == 'tf':
                return self.target_size + (1,)
            else:
                return (1,) + self.target_size
        else:
            if self.dim_ordering == 'tf':
                return self.target_size + (3,)
            else:
                return (3,) + self.target_size

    def _load_imgs_to_memory(self):
        """Load images to memory."""
        if not self.load_to_memory:
            raise Exception('Cannot load images to memory. Reason: load_to_memory = False')

        self.a = np.zeros((self.N,) + self.image_shape_a)
        self.b = np.zeros((self.N,) + self.image_shape_b)

        for idx in range(self.N):
            ai, bi = self._load_img_pair(idx, False)
            self.a[idx] = ai
            self.b[idx] = bi

    def _binarize(self, batch):
        """Scale 8-bit images to [0, 1] and threshold them at 0.5 so only 0 and 1 remain."""
        bin_batch = batch / 255.
        bin_batch[bin_batch >= 0.5] = 1
        bin_batch[bin_batch < 0.5] = 0
        return bin_batch

    def _normalize_for_tanh(self, batch):
        """Map 8-bit pixel values from [0, 255] to [-1, 1]."""
        tanh_batch = batch - 127.5
        tanh_batch /= 127.5
        return tanh_batch

    def _load_img_pair(self, idx, load_from_memory):
        """Get a pair of images with index idx."""
        if load_from_memory:
            a = self.a[idx]
            b = self.b[idx]
            return a, b

        fname = self.filenames[idx]

        a = load_img(os.path.join(self.a_dir, fname),
                     grayscale=self.is_a_grayscale,
                     target_size=self.target_size)
        b = load_img(os.path.join(self.b_dir, fname),
                     grayscale=self.is_b_grayscale,
                     target_size=self.target_size)

        a = img_to_array(a, self.dim_ordering)
        b = img_to_array(b, self.dim_ordering)

        return a, b

    def _random_transform(self, a, b):
        """
        Random dataset augmentation.

        Adapted from https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py
        """
        # a and b are single images (no batch axis), so every axis index shifts down by one
        img_row_index = self.row_index - 1
        img_col_index = self.col_index - 1
        img_channel_index = self.channel_index - 1

        # Compose rotation, translation and zoom into a single transform that is
        # applied identically to a and b, so the pair stays aligned
        if self.rotation_range:
            theta = np.pi / 180 * np.random.uniform(-self.rotation_range, self.rotation_range)
        else:
            theta = 0
        rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
                                    [np.sin(theta), np.cos(theta), 0],
                                    [0, 0, 1]])
        if self.height_shift_range:
            tx = np.random.uniform(-self.height_shift_range, self.height_shift_range) * a.shape[img_row_index]
        else:
            tx = 0

        if self.width_shift_range:
            ty = np.random.uniform(-self.width_shift_range, self.width_shift_range) * a.shape[img_col_index]
        else:
            ty = 0

        translation_matrix = np.array([[1, 0, tx],
                                       [0, 1, ty],
                                       [0, 0, 1]])

        if self.zoom_range[0] == 1 and self.zoom_range[1] == 1:
            zx, zy = 1, 1
        else:
            zx, zy = np.random.uniform(self.zoom_range[0], self.zoom_range[1], 2)
        zoom_matrix = np.array([[zx, 0, 0],
                                [0, zy, 0],
                                [0, 0, 1]])

        transform_matrix = np.dot(np.dot(rotation_matrix, translation_matrix), zoom_matrix)

        h, w = a.shape[img_row_index], a.shape[img_col_index]
        transform_matrix = transform_matrix_offset_center(transform_matrix, h, w)
        a = apply_transform(a, transform_matrix, img_channel_index,
                            fill_mode=self.fill_mode, cval=self.cval)
        b = apply_transform(b, transform_matrix, img_channel_index,
                            fill_mode=self.fill_mode, cval=self.cval)

        # Flip a and b together so that they remain aligned
        if self.horizontal_flip:
            if np.random.random() < 0.5:
                a = flip_axis(a, img_col_index)
                b = flip_axis(b, img_col_index)

        if self.vertical_flip:
            if np.random.random() < 0.5:
                a = flip_axis(a, img_row_index)
                b = flip_axis(b, img_row_index)

        return a, b

    def next(self):
        """Return the next batch as [batch_a, batch_b]."""
        # Lock the iterator when the index is changed.
        with self.lock:
            index_array, _, current_batch_size = next(self.index_generator)

        batch_a = np.zeros((current_batch_size,) + self.image_shape_a)
        batch_b = np.zeros((current_batch_size,) + self.image_shape_b)

        for i, j in enumerate(index_array):
            a_img, b_img = self._load_img_pair(j, self.load_to_memory)
            a_img, b_img = self._random_transform(a_img, b_img)

            batch_a[i] = a_img
            batch_b[i] = b_img

        if self.is_a_binary:
            batch_a = self._binarize(batch_a)
        else:
            batch_a = self._normalize_for_tanh(batch_a)

        if self.is_b_binary:
            batch_b = self._binarize(batch_b)
        else:
            batch_b = self._normalize_for_tanh(batch_b)

        return [batch_a, batch_b]

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
"""The script used to train the model."""
import os
import sys
import getopt

import numpy as np
import models as m

from tqdm import tqdm
from keras.optimizers import Adam
from util.data import TwoImageIterator
from util.util import MyDict, log, save_weights, load_weights, load_losses, create_expt_dir


def print_help():
    """Print how to use this script."""
    print "Usage:"
    print "train.py [--help] [--nfd] [--nfatob] [--alpha] [--epochs] [--batch_size] [--samples_per_batch] " \
          "[--save_every] [--lr] [--beta_1] [--continue_train] [--log_dir] " \
          "[--expt_name] [--base_dir] [--train_dir] [--val_dir] [--train_samples] " \
          "[--val_samples] [--load_to_memory] [--a_ch] [--b_ch] [--is_a_binary] " \
          "[--is_b_binary] [--is_a_grayscale] [--is_b_grayscale] [--target_size] " \
          "[--rotation_range] [--height_shift_range] [--width_shift_range] " \
          "[--horizontal_flip] [--vertical_flip] [--zoom_range]"
    print "--nfd: Number of filters of the first layer of the discriminator."
    print "--nfatob: Number of filters of the first layer of the AtoB model."
    print "--alpha: The weight of the reconstruction loss of the AtoB model."
    print "--epochs: Number of epochs to train the model."
    print "--batch_size: The size of the training batch."
    print "--samples_per_batch: The number of samples to train each model on each iteration."
    print "--save_every: Save results to the log folder every 'save_every' epochs."
    print "--lr: The learning rate to train the models."
    print "--beta_1: The beta_1 value of the Adam optimizer."
    print "--continue_train: Whether to continue training from the last checkpoint."
    print "--log_dir: The directory in which to place the logs."
    print "--expt_name: The name of the experiment. Saves the logs into a folder with this name."
    print "--base_dir: Directory that contains the data."
    print "--train_dir: Directory inside base_dir that contains training data. " \
          "Must contain an A and B folder."
    print "--val_dir: Directory inside base_dir that contains validation data. " \
          "Must contain an A and B folder."
    print "--train_samples: The number of training samples. Set to -1 to use the number of training examples."
    print "--val_samples: The number of validation samples. Set to -1 to use the number of validation examples."
    print "--load_to_memory: Whether to load images into memory or read them from the filesystem."
    print "--a_ch: Number of channels of images A."
    print "--b_ch: Number of channels of images B."
    print "--is_a_binary: If A is binary, its values will be 0 or 1. A threshold of 0.5 is used."
48 | print "--is_b_binary: If B is binary, the last layer of the atob model is " \ 49 | "followed by a sigmoid. Otherwise, a tanh is used. When the sigmoid is " \ 50 | "used, the binary crossentropy loss is used; for the tanh, the L1 loss is used. Also, " \ 51 | "its values will be 0 or 1. A threshold of 0.5 is used." 52 | print "--is_a_grayscale: Whether A images should have only one channel. Color images " \ 53 | "are converted to grayscale." 54 | print "--is_b_grayscale: Whether B images should have only one channel. Color images " \ 55 | "are converted to grayscale." 56 | print "--target_size: The size of the images loaded by the iterator. THIS DOES NOT CHANGE THE MODELS. " \ 57 | "If you want to accept images of different sizes you will need to update models.py." 58 | print "--rotation_range: The range to rotate training images for dataset augmentation." 59 | print "--height_shift_range: Percentage of height of the image to translate for dataset augmentation." 60 | print "--width_shift_range: Percentage of width of the image to translate for dataset augmentation." 61 | print "--horizontal_flip: If true, performs random horizontal flips on the train set." 62 | print "--vertical_flip: If true, performs random vertical flips on the train set." 63 | print "--zoom_range: Defines the range to scale the image for dataset augmentation." 64 | 65 | 66 | def discriminator_generator(it, atob, dout_size): 67 | """ 68 | Generate batches for the discriminator. 69 | 70 | Parameters: 71 | - it: an iterator that returns a pair of images; 72 | - atob: the generator network that maps an image to another representation; 73 | - dout_size: the size of the output of the discriminator. 74 | """ 75 | while True: 76 | # Fake pair 77 | a_fake, _ = next(it) 78 | b_fake = atob.predict(a_fake) 79 | 80 | # Real pair 81 | a_real, b_real = next(it) 82 | 83 | # Concatenate the channels.
Images become (a_ch + b_ch) x 512 x 512 84 | fake = np.concatenate((a_fake, b_fake), axis=1) 85 | real = np.concatenate((a_real, b_real), axis=1) 86 | 87 | # Concatenate fake and real pairs into a single batch 88 | batch_x = np.concatenate((fake, real), axis=0) 89 | 90 | # 1 is fake, 0 is real 91 | batch_y = np.ones((batch_x.shape[0], 1) + dout_size) 92 | batch_y[fake.shape[0]:] = 0 93 | 94 | yield batch_x, batch_y 95 | 96 | 97 | def train_discriminator(d, it, samples_per_batch=20): 98 | """Train the discriminator network.""" 99 | return d.fit_generator(it, samples_per_epoch=samples_per_batch*2, nb_epoch=1, verbose=False) 100 | 101 | 102 | def pix2pix_generator(it, dout_size): 103 | """ 104 | Generate data for the generator network. 105 | 106 | Parameters: 107 | - it: an iterator that returns a pair of images; 108 | - dout_size: the size of the output of the discriminator. 109 | """ 110 | for a, b in it: 111 | # 1 is fake, 0 is real; the targets are all 0 so the generator learns to make its output look real 112 | y = np.zeros((a.shape[0], 1) + dout_size) 113 | yield [a, b], y 114 | 115 | 116 | def train_pix2pix(pix2pix, it, samples_per_batch=20): 117 | """Train the generator network.""" 118 | return pix2pix.fit_generator(it, nb_epoch=1, samples_per_epoch=samples_per_batch, verbose=False) 119 | 120 | 121 | def evaluate(models, generators, losses, val_samples=192): 122 | """Evaluate and display the losses of the models.""" 123 | # Get necessary generators 124 | d_gen = generators.d_gen_val 125 | p2p_gen = generators.p2p_gen_val 126 | 127 | # Get necessary models 128 | d = models.d 129 | p2p = models.p2p 130 | 131 | # Evaluate 132 | d_loss = d.evaluate_generator(d_gen, val_samples) 133 | p2p_loss = p2p.evaluate_generator(p2p_gen, val_samples) 134 | 135 | losses['d_val'].append(d_loss) 136 | losses['p2p_val'].append(p2p_loss) 137 | 138 | print '' 139 | print ('Train Losses of (D={0} / P2P={1});\n' 140 | 'Validation Losses of (D={2} / P2P={3})'.format( 141 | losses['d'][-1], losses['p2p'][-1], d_loss, p2p_loss)) 142 | 143 | return d_loss,
p2p_loss 144 | 145 | 146 | def model_creation(d, atob, params): 147 | """Create all the necessary models.""" 148 | opt = Adam(lr=params.lr, beta_1=params.beta_1) 149 | p2p = m.pix2pix(atob, d, params.a_ch, params.b_ch, alpha=params.alpha, opt=opt, 150 | is_a_binary=params.is_a_binary, is_b_binary=params.is_b_binary) 151 | 152 | models = MyDict({ 153 | 'atob': atob, 154 | 'd': d, 155 | 'p2p': p2p, 156 | }) 157 | 158 | return models 159 | 160 | 161 | def generators_creation(it_train, it_val, models, dout_size): 162 | """Create all the necessary data generators.""" 163 | # Discriminator data generators 164 | d_gen = discriminator_generator(it_train, models.atob, dout_size) 165 | d_gen_val = discriminator_generator(it_val, models.atob, dout_size) 166 | 167 | # Workaround to make TensorFlow work. When atob.predict is called for the first 168 | # time, it calls tf.get_default_graph. This should be done on the main thread 169 | # and not inside fit_generator. See https://github.com/fchollet/keras/issues/2397 170 | next(d_gen) 171 | 172 | # pix2pix data generators 173 | p2p_gen = pix2pix_generator(it_train, dout_size) 174 | p2p_gen_val = pix2pix_generator(it_val, dout_size) 175 | 176 | generators = MyDict({ 177 | 'd_gen': d_gen, 178 | 'd_gen_val': d_gen_val, 179 | 'p2p_gen': p2p_gen, 180 | 'p2p_gen_val': p2p_gen_val, 181 | }) 182 | 183 | return generators 184 | 185 | 186 | def train_iteration(models, generators, losses, params): 187 | """Perform one training iteration.""" 188 | # Get necessary generators 189 | d_gen = generators.d_gen 190 | p2p_gen = generators.p2p_gen 191 | 192 | # Get necessary models 193 | d = models.d 194 | p2p = models.p2p 195 | 196 | # Update the discriminator 197 | dhist = train_discriminator(d, d_gen, samples_per_batch=params.samples_per_batch) 198 | losses['d'].extend(dhist.history['loss']) 199 | 200 | # Update the generator 201 | p2phist = train_pix2pix(p2p, p2p_gen, samples_per_batch=params.samples_per_batch) 202 |
losses['p2p'].extend(p2phist.history['loss']) 203 | 204 | 205 | def train(models, it_train, it_val, params): 206 | """ 207 | Train the model. 208 | 209 | Parameters: 210 | - models: a dictionary with all the models. 211 | - atob: a model that goes from A to B. 212 | - d: the discriminator model. 213 | - p2p: a Pix2Pix model. 214 | - it_train: the iterator of the training data. 215 | - it_val: the iterator of the validation data. 216 | - params: parameters of the training procedure. 217 | The output size of the discriminator (dout_size) is read from models.d. 218 | """ 219 | # Create the experiment folder and save the parameters 220 | create_expt_dir(params) 221 | 222 | # Get the output shape of the discriminator 223 | dout_size = models.d.output_shape[-2:] 224 | # Define the data generators 225 | generators = generators_creation(it_train, it_val, models, dout_size) 226 | 227 | # Define the number of samples to use on each training epoch 228 | train_samples = params.train_samples 229 | if params.train_samples == -1: 230 | train_samples = it_train.N 231 | batches_per_epoch = train_samples // params.samples_per_batch 232 | 233 | # Define the number of samples to use for validation 234 | val_samples = params.val_samples 235 | if val_samples == -1: 236 | val_samples = it_val.N 237 | 238 | losses = {'p2p': [], 'd': [], 'p2p_val': [], 'd_val': []} 239 | if params.continue_train: 240 | losses = load_losses(log_dir=params.log_dir, expt_name=params.expt_name) 241 | 242 | for e in tqdm(range(params.epochs)): 243 | 244 | for b in range(batches_per_epoch): 245 | train_iteration(models, generators, losses, params) 246 | 247 | # Evaluate how the models are doing on the validation set.
248 | evaluate(models, generators, losses, val_samples=val_samples) 249 | 250 | if (e + 1) % params.save_every == 0: 251 | save_weights(models, log_dir=params.log_dir, expt_name=params.expt_name) 252 | log(losses, models.atob, it_val, log_dir=params.log_dir, expt_name=params.expt_name, 253 | is_a_binary=params.is_a_binary, is_b_binary=params.is_b_binary) 254 | 255 | if __name__ == '__main__': 256 | a = sys.argv[1:] 257 | 258 | params = MyDict({ 259 | # Model 260 | 'nfd': 32, # Number of filters of the first layer of the discriminator 261 | 'nfatob': 64, # Number of filters of the first layer of the AtoB model 262 | 'alpha': 100, # The weight of the reconstruction loss of the atob model 263 | # Train 264 | 'epochs': 100, # Number of epochs to train the model 265 | 'batch_size': 1, # The batch size 266 | 'samples_per_batch': 20, # The number of samples to train each model on each iteration 267 | 'save_every': 10, # Save results every 'save_every' epochs on the log folder 268 | 'lr': 2e-4, # The learning rate to train the models 269 | 'beta_1': 0.5, # The beta_1 value of the Adam optimizer 270 | 'continue_train': False, # If it should continue the training from the last checkpoint 271 | # File system 272 | 'log_dir': 'log', # Directory to log 273 | 'expt_name': None, # The name of the experiment. Saves the logs into a folder with this name 274 | 'base_dir': 'data/unet_segmentations_binary', # Directory that contains the data 275 | 'train_dir': 'train', # Directory inside base_dir that contains training data 276 | 'val_dir': 'val', # Directory inside base_dir that contains validation data 277 | 'train_samples': -1, # The number of training samples. Set -1 to be the same as training examples 278 | 'val_samples': -1, # The number of validation samples. 
Set -1 to be the same as validation examples 279 | 'load_to_memory': True, # Whether to load the images into memory 280 | # Image 281 | 'a_ch': 1, # Number of channels of images A 282 | 'b_ch': 3, # Number of channels of images B 283 | 'is_a_binary': True, # If A is binary, its values will be either 0 or 1 284 | 'is_b_binary': False, # If B is binary, the last layer of the atob model is followed by a sigmoid 285 | 'is_a_grayscale': True, # If A is grayscale, the image will only have one channel 286 | 'is_b_grayscale': False, # If B is grayscale, the image will only have one channel 287 | 'target_size': 512, # The size of the images loaded by the iterator. DOES NOT CHANGE THE MODELS 288 | 'rotation_range': 0., # The range to rotate training images for dataset augmentation 289 | 'height_shift_range': 0., # Percentage of height of the image to translate for dataset augmentation 290 | 'width_shift_range': 0., # Percentage of width of the image to translate for dataset augmentation 291 | 'horizontal_flip': False, # If true, performs random horizontal flips on the train set 292 | 'vertical_flip': False, # If true, performs random vertical flips on the train set 293 | 'zoom_range': 0., # Defines the range to scale the image for dataset augmentation 294 | }) 295 | 296 | param_names = [k + '=' for k in params.keys()] + ['help'] 297 | 298 | try: 299 | opts, args = getopt.getopt(a, '', param_names) 300 | except getopt.GetoptError: 301 | print_help() 302 | sys.exit(2) 303 | 304 | for opt, arg in opts: 305 | if opt == '--help': 306 | print_help() 307 | sys.exit() 308 | elif opt in ('--nfatob', '--nfd', '--a_ch', '--b_ch', '--epochs', '--batch_size', 309 | '--samples_per_batch', '--save_every', '--train_samples', '--val_samples', 310 | '--target_size'): 311 | params[opt[2:]] = int(arg) 312 | elif opt in ('--lr', '--beta_1', '--rotation_range', '--height_shift_range', 313 | '--width_shift_range', '--zoom_range', '--alpha'): 314 | params[opt[2:]] = float(arg) 315 | elif opt in
('--is_a_binary', '--is_b_binary', '--is_a_grayscale', '--is_b_grayscale', 316 | '--continue_train', '--horizontal_flip', '--vertical_flip', 317 | '--load_to_memory'): 318 | params[opt[2:]] = (arg == 'True') 319 | elif opt in ('--base_dir', '--train_dir', '--val_dir', '--expt_name', '--log_dir'): 320 | params[opt[2:]] = arg 321 | 322 | dopt = Adam(lr=params.lr, beta_1=params.beta_1) 323 | 324 | # Define the U-Net generator 325 | unet = m.g_unet(params.a_ch, params.b_ch, params.nfatob, 326 | batch_size=params.batch_size, is_binary=params.is_b_binary) 327 | 328 | # Define the discriminator 329 | d = m.discriminator(params.a_ch, params.b_ch, params.nfd, opt=dopt) 330 | 331 | if params.continue_train: 332 | load_weights(unet, d, log_dir=params.log_dir, expt_name=params.expt_name) 333 | 334 | ts = params.target_size 335 | train_dir = os.path.join(params.base_dir, params.train_dir) 336 | it_train = TwoImageIterator(train_dir, is_a_binary=params.is_a_binary, 337 | is_a_grayscale=params.is_a_grayscale, 338 | is_b_grayscale=params.is_b_grayscale, 339 | is_b_binary=params.is_b_binary, 340 | batch_size=params.batch_size, 341 | load_to_memory=params.load_to_memory, 342 | rotation_range=params.rotation_range, 343 | height_shift_range=params.height_shift_range, 344 | width_shift_range=params.width_shift_range, 345 | zoom_range=params.zoom_range, 346 | horizontal_flip=params.horizontal_flip, 347 | vertical_flip=params.vertical_flip, 348 | target_size=(ts, ts)) 349 | val_dir = os.path.join(params.base_dir, params.val_dir) 350 | it_val = TwoImageIterator(val_dir, is_a_binary=params.is_a_binary, 351 | is_b_binary=params.is_b_binary, 352 | is_a_grayscale=params.is_a_grayscale, 353 | is_b_grayscale=params.is_b_grayscale, 354 | batch_size=params.batch_size, 355 | load_to_memory=params.load_to_memory, 356 | target_size=(ts, ts)) 357 | 358 | models = model_creation(d, unet, params) 359 | train(models, it_train, it_val, params) 360 |
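The batch layout that `discriminator_generator` in `train.py` feeds to the discriminator can be reproduced outside Keras. The sketch below is a hypothetical standalone helper (`make_discriminator_batch` is not part of the repository) that repeats the same NumPy steps: each A image is stacked with its B image along the channel axis (axis 1, channels-first as set by `K.set_image_dim_ordering('th')`), fake pairs come before real pairs, and the label maps are 1 for fake and 0 for real. Random arrays stand in for real images and for `atob.predict`, and the toy sizes (8x8 images, a 2x2 discriminator output) are assumptions chosen only to keep it small.

```python
import numpy as np


def make_discriminator_batch(a_fake, b_fake, a_real, b_real, dout_size):
    """Hypothetical helper mirroring the body of discriminator_generator.

    All arrays are channels-first: (N, C, H, W).
    """
    # Stack each A image with its (generated or real) B image along the channel axis.
    fake = np.concatenate((a_fake, b_fake), axis=1)
    real = np.concatenate((a_real, b_real), axis=1)

    # Fake pairs first, real pairs second, in a single batch.
    batch_x = np.concatenate((fake, real), axis=0)

    # One label map per sample: 1 marks fake, 0 marks real.
    batch_y = np.ones((batch_x.shape[0], 1) + tuple(dout_size))
    batch_y[fake.shape[0]:] = 0
    return batch_x, batch_y


if __name__ == '__main__':
    rng = np.random.RandomState(0)
    # Assumed toy sizes: batch of 1, a_ch=1, b_ch=3, 8x8 images, 2x2 discriminator output.
    a_fake = (rng.rand(1, 1, 8, 8) > 0.5).astype('float32')
    b_fake = rng.uniform(-1, 1, size=(1, 3, 8, 8))  # stands in for atob.predict(a_fake)
    a_real = (rng.rand(1, 1, 8, 8) > 0.5).astype('float32')
    b_real = rng.uniform(-1, 1, size=(1, 3, 8, 8))
    x, y = make_discriminator_batch(a_fake, b_fake, a_real, b_real, (2, 2))
    print(x.shape, y.shape)  # (2, 4, 8, 8) (2, 1, 2, 2)
```

With the defaults in `train.py` (`a_ch=1`, `b_ch=3`, 512x512 images), the same steps yield a `batch_x` of shape `(2 * batch_size, 4, 512, 512)`.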
-------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | __doc__ = """The model definitions for the pix2pix network taken from the 2 | retina repository at https://github.com/costapt/vess2ret 3 | """ 4 | import os 5 | 6 | import keras 7 | from keras import backend as K 8 | from keras import objectives 9 | from keras.layers import Input, merge 10 | from keras.layers.advanced_activations import LeakyReLU 11 | from keras.layers.convolutional import Convolution2D, Deconvolution2D 12 | from keras.layers.core import Activation, Dropout 13 | from keras.layers.normalization import BatchNormalization 14 | from keras.models import Model 15 | from keras.optimizers import Adam 16 | 17 | KERAS_2 = keras.__version__[0] == '2' 18 | try: 19 | # keras 2 imports 20 | from keras.layers.convolutional import Conv2DTranspose 21 | from keras.layers.merge import Concatenate 22 | except ImportError: 23 | print("Keras 2 layers could not be imported defaulting to keras1") 24 | KERAS_2 = False 25 | 26 | K.set_image_dim_ordering('th') 27 | 28 | 29 | def concatenate_layers(inputs, concat_axis, mode='concat'): 30 | if KERAS_2: 31 | assert mode == 'concat', "Only concatenation is supported in this wrapper" 32 | return Concatenate(axis=concat_axis)(inputs) 33 | else: 34 | return merge(inputs=inputs, concat_axis=concat_axis, mode=mode) 35 | 36 | 37 | def Convolution(f, k=3, s=2, border_mode='same', **kwargs): 38 | """Convenience method for Convolutions.""" 39 | if KERAS_2: 40 | return Convolution2D(f, 41 | kernel_size=(k, k), 42 | padding=border_mode, 43 | strides=(s, s), 44 | **kwargs) 45 | else: 46 | return Convolution2D(f, k, k, border_mode=border_mode, 47 | subsample=(s, s), 48 | **kwargs) 49 | 50 | 51 | def Deconvolution(f, output_shape, k=2, s=2, **kwargs): 52 | """Convenience method for Transposed Convolutions.""" 53 | if KERAS_2: 54 | return Conv2DTranspose(f, 55 | 
kernel_size=(k, k), 56 | output_shape=output_shape, 57 | strides=(s, s), 58 | data_format=K.image_data_format(), 59 | **kwargs) 60 | else: 61 | return Deconvolution2D(f, k, k, output_shape=output_shape, 62 | subsample=(s, s), **kwargs) 63 | 64 | 65 | def BatchNorm(mode=2, axis=1, **kwargs): 66 | """Convenience method for BatchNormalization layers.""" 67 | if KERAS_2: 68 | return BatchNormalization(axis=axis, **kwargs) 69 | else: 70 | return BatchNormalization(mode=mode, axis=axis, **kwargs) 71 | 72 | 73 | def g_unet(in_ch, out_ch, nf, batch_size=1, is_binary=False, name='unet'): 74 | # type: (int, int, int, int, bool, str) -> keras.models.Model 75 | """Define a U-Net. 76 | 77 | Input has shape in_ch x 512 x 512 78 | Parameters: 79 | - in_ch: the number of input channels; 80 | - out_ch: the number of output channels; 81 | - nf: the number of filters of the first layer; 82 | - is_binary: if is_binary is true, the last layer is followed by a sigmoid 83 | activation function, otherwise, a tanh is used.
84 | >>> K.set_image_dim_ordering('th') 85 | >>> K.image_data_format() 86 | 'channels_first' 87 | >>> unet = g_unet(1, 2, 3, batch_size=5, is_binary=True) 88 | TheanoShapedU-NET 89 | >>> for ilay in unet.layers: ilay.name='_'.join(ilay.name.split('_')[:-1]) # remove layer id 90 | >>> unet.summary() #doctest: +NORMALIZE_WHITESPACE 91 | _________________________________________________________________ 92 | Layer (type) Output Shape Param # 93 | ================================================================= 94 | input (InputLayer) (None, 1, 512, 512) 0 95 | _________________________________________________________________ 96 | conv2d (Conv2D) (None, 3, 256, 256) 30 97 | _________________________________________________________________ 98 | batch_normalization (BatchNo (None, 3, 256, 256) 12 99 | _________________________________________________________________ 100 | leaky_re_lu (LeakyReLU) (None, 3, 256, 256) 0 101 | _________________________________________________________________ 102 | conv2d (Conv2D) (None, 6, 128, 128) 168 103 | _________________________________________________________________ 104 | batch_normalization (BatchNo (None, 6, 128, 128) 24 105 | _________________________________________________________________ 106 | leaky_re_lu (LeakyReLU) (None, 6, 128, 128) 0 107 | _________________________________________________________________ 108 | conv2d (Conv2D) (None, 12, 64, 64) 660 109 | _________________________________________________________________ 110 | batch_normalization (BatchNo (None, 12, 64, 64) 48 111 | _________________________________________________________________ 112 | leaky_re_lu (LeakyReLU) (None, 12, 64, 64) 0 113 | _________________________________________________________________ 114 | conv2d (Conv2D) (None, 24, 32, 32) 2616 115 | _________________________________________________________________ 116 | batch_normalization (BatchNo (None, 24, 32, 32) 96 117 | _________________________________________________________________ 118 | 
leaky_re_lu (LeakyReLU) (None, 24, 32, 32) 0 119 | _________________________________________________________________ 120 | conv2d (Conv2D) (None, 24, 16, 16) 5208 121 | _________________________________________________________________ 122 | batch_normalization (BatchNo (None, 24, 16, 16) 96 123 | _________________________________________________________________ 124 | leaky_re_lu (LeakyReLU) (None, 24, 16, 16) 0 125 | _________________________________________________________________ 126 | conv2d (Conv2D) (None, 24, 8, 8) 5208 127 | _________________________________________________________________ 128 | batch_normalization (BatchNo (None, 24, 8, 8) 96 129 | _________________________________________________________________ 130 | leaky_re_lu (LeakyReLU) (None, 24, 8, 8) 0 131 | _________________________________________________________________ 132 | conv2d (Conv2D) (None, 24, 4, 4) 5208 133 | _________________________________________________________________ 134 | batch_normalization (BatchNo (None, 24, 4, 4) 96 135 | _________________________________________________________________ 136 | leaky_re_lu (LeakyReLU) (None, 24, 4, 4) 0 137 | _________________________________________________________________ 138 | conv2d (Conv2D) (None, 24, 2, 2) 5208 139 | _________________________________________________________________ 140 | batch_normalization (BatchNo (None, 24, 2, 2) 96 141 | _________________________________________________________________ 142 | leaky_re_lu (LeakyReLU) (None, 24, 2, 2) 0 143 | _________________________________________________________________ 144 | conv2d (Conv2D) (None, 24, 1, 1) 2328 145 | _________________________________________________________________ 146 | batch_normalization (BatchNo (None, 24, 1, 1) 96 147 | _________________________________________________________________ 148 | leaky_re_lu (LeakyReLU) (None, 24, 1, 1) 0 149 | _________________________________________________________________ 150 | conv2d_transpose (Conv2DTran (None, 24, 2, 2) 2328 
151 | _________________________________________________________________ 152 | batch_normalization (BatchNo (None, 24, 2, 2) 96 153 | _________________________________________________________________ 154 | dropout (Dropout) (None, 24, 2, 2) 0 155 | _________________________________________________________________ 156 | concatenate (Concatenate) (None, 48, 2, 2) 0 157 | _________________________________________________________________ 158 | leaky_re_lu (LeakyReLU) (None, 48, 2, 2) 0 159 | _________________________________________________________________ 160 | conv2d_transpose (Conv2DTran (None, 24, 4, 4) 4632 161 | _________________________________________________________________ 162 | batch_normalization (BatchNo (None, 24, 4, 4) 96 163 | _________________________________________________________________ 164 | dropout (Dropout) (None, 24, 4, 4) 0 165 | _________________________________________________________________ 166 | concatenate (Concatenate) (None, 48, 4, 4) 0 167 | _________________________________________________________________ 168 | leaky_re_lu (LeakyReLU) (None, 48, 4, 4) 0 169 | _________________________________________________________________ 170 | conv2d_transpose (Conv2DTran (None, 24, 8, 8) 4632 171 | _________________________________________________________________ 172 | batch_normalization (BatchNo (None, 24, 8, 8) 96 173 | _________________________________________________________________ 174 | dropout (Dropout) (None, 24, 8, 8) 0 175 | _________________________________________________________________ 176 | concatenate (Concatenate) (None, 48, 8, 8) 0 177 | _________________________________________________________________ 178 | leaky_re_lu (LeakyReLU) (None, 48, 8, 8) 0 179 | _________________________________________________________________ 180 | conv2d_transpose (Conv2DTran (None, 24, 16, 16) 4632 181 | _________________________________________________________________ 182 | batch_normalization (BatchNo (None, 24, 16, 16) 96 183 | 
_________________________________________________________________ 184 | concatenate (Concatenate) (None, 48, 16, 16) 0 185 | _________________________________________________________________ 186 | leaky_re_lu (LeakyReLU) (None, 48, 16, 16) 0 187 | _________________________________________________________________ 188 | conv2d_transpose (Conv2DTran (None, 24, 32, 32) 4632 189 | _________________________________________________________________ 190 | batch_normalization (BatchNo (None, 24, 32, 32) 96 191 | _________________________________________________________________ 192 | concatenate (Concatenate) (None, 48, 32, 32) 0 193 | _________________________________________________________________ 194 | leaky_re_lu (LeakyReLU) (None, 48, 32, 32) 0 195 | _________________________________________________________________ 196 | conv2d_transpose (Conv2DTran (None, 12, 64, 64) 2316 197 | _________________________________________________________________ 198 | batch_normalization (BatchNo (None, 12, 64, 64) 48 199 | _________________________________________________________________ 200 | concatenate (Concatenate) (None, 24, 64, 64) 0 201 | _________________________________________________________________ 202 | leaky_re_lu (LeakyReLU) (None, 24, 64, 64) 0 203 | _________________________________________________________________ 204 | conv2d_transpose (Conv2DTran (None, 6, 128, 128) 582 205 | _________________________________________________________________ 206 | batch_normalization (BatchNo (None, 6, 128, 128) 24 207 | _________________________________________________________________ 208 | concatenate (Concatenate) (None, 12, 128, 128) 0 209 | _________________________________________________________________ 210 | leaky_re_lu (LeakyReLU) (None, 12, 128, 128) 0 211 | _________________________________________________________________ 212 | conv2d_transpose (Conv2DTran (None, 3, 256, 256) 147 213 | _________________________________________________________________ 214 | batch_normalization 
(BatchNo (None, 3, 256, 256) 12 215 | _________________________________________________________________ 216 | concatenate (Concatenate) (None, 6, 256, 256) 0 217 | _________________________________________________________________ 218 | leaky_re_lu (LeakyReLU) (None, 6, 256, 256) 0 219 | _________________________________________________________________ 220 | conv2d_transpose (Conv2DTran (None, 2, 512, 512) 50 221 | _________________________________________________________________ 222 | activation (Activation) (None, 2, 512, 512) 0 223 | ================================================================= 224 | Total params: 51,809.0 225 | Trainable params: 51,197.0 226 | Non-trainable params: 612.0 227 | _________________________________________________________________ 228 | >>> K.set_image_dim_ordering('tf') 229 | >>> K.image_data_format() 230 | 'channels_last' 231 | >>> unet2=g_unet(3, 4, 2, batch_size=7, is_binary=False) 232 | TensorflowShapedU-NET 233 | >>> for ilay in unet2.layers: ilay.name='_'.join(ilay.name.split('_')[:-1]) # remove layer id 234 | >>> unet2.summary() #doctest: +NORMALIZE_WHITESPACE 235 | _________________________________________________________________ 236 | Layer (type) Output Shape Param # 237 | ================================================================= 238 | input (InputLayer) (None, 512, 512, 3) 0 239 | _________________________________________________________________ 240 | conv2d (Conv2D) (None, 256, 256, 2) 56 241 | _________________________________________________________________ 242 | batch_normalization (BatchNo (None, 256, 256, 2) 1024 243 | _________________________________________________________________ 244 | leaky_re_lu (LeakyReLU) (None, 256, 256, 2) 0 245 | _________________________________________________________________ 246 | conv2d (Conv2D) (None, 128, 128, 4) 76 247 | _________________________________________________________________ 248 | batch_normalization (BatchNo (None, 128, 128, 4) 512 249 | 
_________________________________________________________________ 250 | leaky_re_lu (LeakyReLU) (None, 128, 128, 4) 0 251 | _________________________________________________________________ 252 | conv2d (Conv2D) (None, 64, 64, 8) 296 253 | _________________________________________________________________ 254 | batch_normalization (BatchNo (None, 64, 64, 8) 256 255 | _________________________________________________________________ 256 | leaky_re_lu (LeakyReLU) (None, 64, 64, 8) 0 257 | _________________________________________________________________ 258 | conv2d (Conv2D) (None, 32, 32, 16) 1168 259 | _________________________________________________________________ 260 | batch_normalization (BatchNo (None, 32, 32, 16) 128 261 | _________________________________________________________________ 262 | leaky_re_lu (LeakyReLU) (None, 32, 32, 16) 0 263 | _________________________________________________________________ 264 | conv2d (Conv2D) (None, 16, 16, 16) 2320 265 | _________________________________________________________________ 266 | batch_normalization (BatchNo (None, 16, 16, 16) 64 267 | _________________________________________________________________ 268 | leaky_re_lu (LeakyReLU) (None, 16, 16, 16) 0 269 | _________________________________________________________________ 270 | conv2d (Conv2D) (None, 8, 8, 16) 2320 271 | _________________________________________________________________ 272 | batch_normalization (BatchNo (None, 8, 8, 16) 32 273 | _________________________________________________________________ 274 | leaky_re_lu (LeakyReLU) (None, 8, 8, 16) 0 275 | _________________________________________________________________ 276 | conv2d (Conv2D) (None, 4, 4, 16) 2320 277 | _________________________________________________________________ 278 | batch_normalization (BatchNo (None, 4, 4, 16) 16 279 | _________________________________________________________________ 280 | leaky_re_lu (LeakyReLU) (None, 4, 4, 16) 0 281 | 
_________________________________________________________________ 282 | conv2d (Conv2D) (None, 2, 2, 16) 2320 283 | _________________________________________________________________ 284 | batch_normalization (BatchNo (None, 2, 2, 16) 8 285 | _________________________________________________________________ 286 | leaky_re_lu (LeakyReLU) (None, 2, 2, 16) 0 287 | _________________________________________________________________ 288 | conv2d (Conv2D) (None, 1, 1, 16) 1040 289 | _________________________________________________________________ 290 | batch_normalization (BatchNo (None, 1, 1, 16) 4 291 | _________________________________________________________________ 292 | leaky_re_lu (LeakyReLU) (None, 1, 1, 16) 0 293 | _________________________________________________________________ 294 | conv2d_transpose (Conv2DTran (None, 2, 2, 16) 1040 295 | _________________________________________________________________ 296 | batch_normalization (BatchNo (None, 2, 2, 16) 8 297 | _________________________________________________________________ 298 | dropout (Dropout) (None, 2, 2, 16) 0 299 | _________________________________________________________________ 300 | concatenate (Concatenate) (None, 2, 2, 32) 0 301 | _________________________________________________________________ 302 | leaky_re_lu (LeakyReLU) (None, 2, 2, 32) 0 303 | _________________________________________________________________ 304 | conv2d_transpose (Conv2DTran (None, 4, 4, 16) 2064 305 | _________________________________________________________________ 306 | batch_normalization (BatchNo (None, 4, 4, 16) 16 307 | _________________________________________________________________ 308 | dropout (Dropout) (None, 4, 4, 16) 0 309 | _________________________________________________________________ 310 | concatenate (Concatenate) (None, 4, 4, 32) 0 311 | _________________________________________________________________ 312 | leaky_re_lu (LeakyReLU) (None, 4, 4, 32) 0 313 | 
    _________________________________________________________________
    conv2d_transpose (Conv2DTran (None, 8, 8, 16)          2064
    _________________________________________________________________
    batch_normalization (BatchNo (None, 8, 8, 16)          32
    _________________________________________________________________
    dropout (Dropout)            (None, 8, 8, 16)          0
    _________________________________________________________________
    concatenate (Concatenate)    (None, 8, 8, 32)          0
    _________________________________________________________________
    leaky_re_lu (LeakyReLU)      (None, 8, 8, 32)          0
    _________________________________________________________________
    conv2d_transpose (Conv2DTran (None, 16, 16, 16)        2064
    _________________________________________________________________
    batch_normalization (BatchNo (None, 16, 16, 16)        64
    _________________________________________________________________
    concatenate (Concatenate)    (None, 16, 16, 32)        0
    _________________________________________________________________
    leaky_re_lu (LeakyReLU)      (None, 16, 16, 32)        0
    _________________________________________________________________
    conv2d_transpose (Conv2DTran (None, 32, 32, 16)        2064
    _________________________________________________________________
    batch_normalization (BatchNo (None, 32, 32, 16)        128
    _________________________________________________________________
    concatenate (Concatenate)    (None, 32, 32, 32)        0
    _________________________________________________________________
    leaky_re_lu (LeakyReLU)      (None, 32, 32, 32)        0
    _________________________________________________________________
    conv2d_transpose (Conv2DTran (None, 64, 64, 8)         1032
    _________________________________________________________________
    batch_normalization (BatchNo (None, 64, 64, 8)         256
    _________________________________________________________________
    concatenate (Concatenate)    (None, 64, 64, 16)        0
    _________________________________________________________________
    leaky_re_lu (LeakyReLU)      (None, 64, 64, 16)        0
    _________________________________________________________________
    conv2d_transpose (Conv2DTran (None, 128, 128, 4)       260
    _________________________________________________________________
    batch_normalization (BatchNo (None, 128, 128, 4)       512
    _________________________________________________________________
    concatenate (Concatenate)    (None, 128, 128, 8)       0
    _________________________________________________________________
    leaky_re_lu (LeakyReLU)      (None, 128, 128, 8)       0
    _________________________________________________________________
    conv2d_transpose (Conv2DTran (None, 256, 256, 2)       66
    _________________________________________________________________
    batch_normalization (BatchNo (None, 256, 256, 2)       1024
    _________________________________________________________________
    concatenate (Concatenate)    (None, 256, 256, 4)       0
    _________________________________________________________________
    leaky_re_lu (LeakyReLU)      (None, 256, 256, 4)       0
    _________________________________________________________________
    conv2d_transpose (Conv2DTran (None, 512, 512, 4)       68
    _________________________________________________________________
    activation (Activation)      (None, 512, 512, 4)       0
    =================================================================
    Total params: 26,722.0
    Trainable params: 24,680.0
    Non-trainable params: 2,042.0
    _________________________________________________________________
    """
    merge_params = {
        'mode': 'concat',
        'concat_axis': 1
    }
    if K.image_dim_ordering() == 'th':
        print('TheanoShapedU-NET')
        i = Input(shape=(in_ch, 512, 512))

        def get_deconv_shape(samples, channels, x_dim, y_dim):
            return samples, channels, x_dim, y_dim

    elif K.image_dim_ordering() == 'tf':
        i = Input(shape=(512, 512, in_ch))
        print('TensorflowShapedU-NET')

        def get_deconv_shape(samples, channels, x_dim, y_dim):
            return samples, x_dim, y_dim, channels

        merge_params['concat_axis'] = 3
    else:
        raise ValueError(
            'Keras dimension ordering not supported: {}'.format(
                K.image_dim_ordering()))

    # in_ch x 512 x 512
    conv1 = Convolution(nf)(i)
    conv1 = BatchNorm()(conv1)
    x = LeakyReLU(0.2)(conv1)
    # nf x 256 x 256

    conv2 = Convolution(nf * 2)(x)
    conv2 = BatchNorm()(conv2)
    x = LeakyReLU(0.2)(conv2)
    # nf*2 x 128 x 128

    conv3 = Convolution(nf * 4)(x)
    conv3 = BatchNorm()(conv3)
    x = LeakyReLU(0.2)(conv3)
    # nf*4 x 64 x 64

    conv4 = Convolution(nf * 8)(x)
    conv4 = BatchNorm()(conv4)
    x = LeakyReLU(0.2)(conv4)
    # nf*8 x 32 x 32

    conv5 = Convolution(nf * 8)(x)
    conv5 = BatchNorm()(conv5)
    x = LeakyReLU(0.2)(conv5)
    # nf*8 x 16 x 16

    conv6 = Convolution(nf * 8)(x)
    conv6 = BatchNorm()(conv6)
    x = LeakyReLU(0.2)(conv6)
    # nf*8 x 8 x 8

    conv7 = Convolution(nf * 8)(x)
    conv7 = BatchNorm()(conv7)
    x = LeakyReLU(0.2)(conv7)
    # nf*8 x 4 x 4

    conv8 = Convolution(nf * 8)(x)
    conv8 = BatchNorm()(conv8)
    x = LeakyReLU(0.2)(conv8)
    # nf*8 x 2 x 2

    conv9 = Convolution(nf * 8, k=2, s=1, border_mode='valid')(x)
    conv9 = BatchNorm()(conv9)
    x = LeakyReLU(0.2)(conv9)
    # nf*8 x 1 x 1

    dconv1 = Deconvolution(nf * 8,
                           get_deconv_shape(batch_size, nf * 8, 2, 2),
                           k=2, s=1)(x)
    dconv1 = BatchNorm()(dconv1)
    dconv1 = Dropout(0.5)(dconv1)

    x = concatenate_layers([dconv1, conv8], **merge_params)

    x = LeakyReLU(0.2)(x)
    # nf*(8 + 8) x 2 x 2

    dconv2 = Deconvolution(nf * 8,
                           get_deconv_shape(batch_size, nf * 8, 4, 4))(x)
    dconv2 = BatchNorm()(dconv2)
    dconv2 = Dropout(0.5)(dconv2)
    x = concatenate_layers([dconv2, conv7], **merge_params)
    x = LeakyReLU(0.2)(x)
    # nf*(8 + 8) x 4 x 4

    dconv3 = Deconvolution(nf * 8,
                           get_deconv_shape(batch_size, nf * 8, 8, 8))(x)
    dconv3 = BatchNorm()(dconv3)
    dconv3 = Dropout(0.5)(dconv3)
    x = concatenate_layers([dconv3, conv6], **merge_params)
    x = LeakyReLU(0.2)(x)
    # nf*(8 + 8) x 8 x 8

    dconv4 = Deconvolution(nf * 8,
                           get_deconv_shape(batch_size, nf * 8, 16, 16))(x)
    dconv4 = BatchNorm()(dconv4)
    x = concatenate_layers([dconv4, conv5], **merge_params)
    x = LeakyReLU(0.2)(x)
    # nf*(8 + 8) x 16 x 16

    dconv5 = Deconvolution(nf * 8,
                           get_deconv_shape(batch_size, nf * 8, 32, 32))(x)
    dconv5 = BatchNorm()(dconv5)
    x = concatenate_layers([dconv5, conv4], **merge_params)
    x = LeakyReLU(0.2)(x)
    # nf*(8 + 8) x 32 x 32

    dconv6 = Deconvolution(nf * 4,
                           get_deconv_shape(batch_size, nf * 4, 64, 64))(x)
    dconv6 = BatchNorm()(dconv6)
    x = concatenate_layers([dconv6, conv3], **merge_params)
    x = LeakyReLU(0.2)(x)
    # nf*(4 + 4) x 64 x 64

    dconv7 = Deconvolution(nf * 2,
                           get_deconv_shape(batch_size, nf * 2, 128, 128))(x)
    dconv7 = BatchNorm()(dconv7)
    x = concatenate_layers([dconv7, conv2], **merge_params)
    x = LeakyReLU(0.2)(x)
    # nf*(2 + 2) x 128 x 128

    dconv8 = Deconvolution(nf,
                           get_deconv_shape(batch_size, nf, 256, 256))(x)
    dconv8 = BatchNorm()(dconv8)
    x = concatenate_layers([dconv8, conv1], **merge_params)
    x = LeakyReLU(0.2)(x)
    # nf*(1 + 1) x 256 x 256

    dconv9 = Deconvolution(out_ch,
                           get_deconv_shape(batch_size, out_ch, 512, 512))(x)
    # out_ch x 512 x 512

    act = 'sigmoid' if is_binary else 'tanh'
    out = Activation(act)(dconv9)

    unet = Model(i, out, name=name)

    return unet


def discriminator(a_ch, b_ch, nf, opt=Adam(lr=2e-4, beta_1=0.5), name='d'):
    """Define the discriminator network.

    Parameters:
    - a_ch: the number of channels of the first image;
    - b_ch: the number of channels of the second image;
    - nf: the number of filters of the first layer;
    - opt: the optimizer used to compile the discriminator;
    - name: the name of the model.
    >>> K.set_image_dim_ordering('th')
    >>> disc=discriminator(3,4,2)
    >>> for ilay in disc.layers: ilay.name='_'.join(ilay.name.split('_')[:-1]) # remove layer id
    >>> disc.summary() #doctest: +NORMALIZE_WHITESPACE
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    input (InputLayer)           (None, 7, 512, 512)       0
    _________________________________________________________________
    conv2d (Conv2D)              (None, 2, 256, 256)       128
    _________________________________________________________________
    leaky_re_lu (LeakyReLU)      (None, 2, 256, 256)       0
    _________________________________________________________________
    conv2d (Conv2D)              (None, 4, 128, 128)       76
    _________________________________________________________________
    leaky_re_lu (LeakyReLU)      (None, 4, 128, 128)       0
    _________________________________________________________________
    conv2d (Conv2D)              (None, 8, 64, 64)         296
    _________________________________________________________________
    leaky_re_lu (LeakyReLU)      (None, 8, 64, 64)         0
    _________________________________________________________________
    conv2d (Conv2D)              (None, 16, 32, 32)        1168
    _________________________________________________________________
    leaky_re_lu (LeakyReLU)      (None, 16, 32, 32)        0
    _________________________________________________________________
    conv2d (Conv2D)              (None, 1, 16, 16)         145
    _________________________________________________________________
    activation (Activation)      (None, 1, 16, 16)         0
    =================================================================
    Total params: 1,813.0
    Trainable params: 1,813.0
    Non-trainable params: 0.0
    _________________________________________________________________
    """
    # Channels-first input: this network assumes 'th' dimension ordering.
    i = Input(shape=(a_ch + b_ch, 512, 512))

    # (a_ch + b_ch) x 512 x 512
    conv1 = Convolution(nf)(i)
    x = LeakyReLU(0.2)(conv1)
    # nf x 256 x 256

    conv2 = Convolution(nf * 2)(x)
    x = LeakyReLU(0.2)(conv2)
    # nf*2 x 128 x 128

    conv3 = Convolution(nf * 4)(x)
    x = LeakyReLU(0.2)(conv3)
    # nf*4 x 64 x 64

    conv4 = Convolution(nf * 8)(x)
    x = LeakyReLU(0.2)(conv4)
    # nf*8 x 32 x 32

    conv5 = Convolution(1)(x)
    out = Activation('sigmoid')(conv5)
    # 1 x 16 x 16

    d = Model(i, out, name=name)

    def d_loss(y_true, y_pred):
        # Per-sample binary cross-entropy over the flattened 16x16 patch map.
        L = objectives.binary_crossentropy(K.batch_flatten(y_true),
                                           K.batch_flatten(y_pred))
        return L

    d.compile(optimizer=opt, loss=d_loss)
    return d


def pix2pix(atob, d, a_ch, b_ch, alpha=100, is_a_binary=False,
            is_b_binary=False, opt=Adam(lr=2e-4, beta_1=0.5), name='pix2pix'):
    # type: (...) -> keras.models.Model
    """
    Define the pix2pix network.
    :param atob: the generator model that maps an image A to B';
    :param d: the discriminator model (expected to be named 'd', since it is
        frozen by name before compiling);
    :param a_ch: the number of channels of the A images;
    :param b_ch: the number of channels of the B images;
    :param alpha: the weight of the A-to-B loss relative to the adversarial loss;
    :param is_a_binary: whether the A images are binary (not used by this function);
    :param is_b_binary: if True, the A-to-B loss is binary cross-entropy,
        otherwise it is the mean absolute (L1) error;
    :param opt: the optimizer used to compile the model;
    :param name: the name of the model.
    :return: the compiled pix2pix model, used to train the generator.
    >>> K.set_image_dim_ordering('th')
    >>> unet = g_unet(3, 4, 2, batch_size=8, is_binary=False)
    TheanoShapedU-NET
    >>> disc=discriminator(3,4,2)
    >>> pp_net=pix2pix(unet, disc, 3, 4)
    >>> for ilay in pp_net.layers: ilay.name='_'.join(ilay.name.split('_')[:-1]) # remove layer id
    >>> pp_net.summary() #doctest: +NORMALIZE_WHITESPACE
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    input (InputLayer)           (None, 3, 512, 512)       0
    _________________________________________________________________
    (Model)                      (None, 4, 512, 512)       23454
    _________________________________________________________________
    concatenate (Concatenate)    (None, 7, 512, 512)       0
    _________________________________________________________________
    (Model)                      (None, 1, 16, 16)         1813
    =================================================================
    Total params: 25,267.0
    Trainable params: 24,859.0
    Non-trainable params: 408.0
    _________________________________________________________________
    """
    # Channels-first inputs ('th' ordering). b is not fed to the discriminator;
    # it is only used by the A-to-B term of the loss below.
    a = Input(shape=(a_ch, 512, 512))
    b = Input(shape=(b_ch, 512, 512))

    # A -> B'
    bp = atob(a)

    # Discriminator receives the pair of images
    d_in = concatenate_layers([a, bp], mode='concat', concat_axis=1)

    pix2pix = Model([a, b], d(d_in), name=name)

    def pix2pix_loss(y_true, y_pred):
        y_true_flat = K.batch_flatten(y_true)
        y_pred_flat = K.batch_flatten(y_pred)

        # Adversarial Loss
        L_adv = objectives.binary_crossentropy(y_true_flat, y_pred_flat)

        # A to B loss
        b_flat = K.batch_flatten(b)
        bp_flat = K.batch_flatten(bp)
        if is_b_binary:
            L_atob = objectives.binary_crossentropy(b_flat, bp_flat)
        else:
            # Global L1 loss: mean absolute error over the whole batch.
            L_atob = K.mean(K.abs(b_flat - bp_flat))

        return L_adv + alpha * L_atob

    # This network is used to train the generator. Freeze the discriminator part.
    pix2pix.get_layer('d').trainable = False

    pix2pix.compile(optimizer=opt, loss=pix2pix_loss)
    return pix2pix


if __name__ == '__main__':
    import doctest
    import os

    # Keras picks its backend from KERAS_BACKEND when it is first imported,
    # which has already happened at the top of this module, so the assignment
    # below does not switch backends. To test a specific backend, set the
    # variable in the shell instead, e.g. `KERAS_BACKEND=theano python models.py`.
    TEST_TF = True
    if TEST_TF:
        os.environ['KERAS_BACKEND'] = 'tensorflow'
    else:
        os.environ['KERAS_BACKEND'] = 'theano'
    # doctest.testsource() only converts one docstring into a script and takes
    # no verbose/optionflags arguments; testmod() runs every doctest in this module.
    doctest.testmod(verbose=True, optionflags=doctest.ELLIPSIS)
--------------------------------------------------------------------------------
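The generator objective in `pix2pix_loss` is the adversarial binary cross-entropy on the discriminator's 16x16 patch outputs plus `alpha` times an A-to-B term (a global L1 loss unless `is_b_binary` is set). The pure-Python sketch below restates that arithmetic outside Keras so the weighting is easy to inspect. It is an illustration, not the repository's code: `binary_crossentropy` and `pix2pix_generator_loss` here are hypothetical helpers that take already-flattened samples (what `K.batch_flatten` would produce), the clipping constant `eps=1e-7` is an assumption meant to approximate Keras's default epsilon, and using all-ones targets for the generator step is the usual pix2pix convention rather than something shown in `models.py`.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy of one flattened sample."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # clip so log() stays finite
        total -= t * math.log(p) + (1.0 - t) * math.log(1.0 - p)
    return total / len(y_true)


def pix2pix_generator_loss(y_true, y_pred, b, bp, alpha=100.0,
                           is_b_binary=False):
    """Hypothetical restatement of L_adv + alpha * L_atob for a batch.

    Each argument is a list holding one flattened sample per entry.
    Returns one loss value per sample.
    """
    l_adv = [binary_crossentropy(t, p) for t, p in zip(y_true, y_pred)]
    if is_b_binary:
        # Binary B images: per-sample cross-entropy.
        l_atob = [binary_crossentropy(t, p) for t, p in zip(b, bp)]
    else:
        # Global L1: one mean absolute error over the whole batch, shared
        # by every sample, like K.mean(K.abs(b_flat - bp_flat)).
        diffs = [abs(t - p) for bt, bpt in zip(b, bp) for t, p in zip(bt, bpt)]
        l1 = sum(diffs) / len(diffs)
        l_atob = [l1] * len(l_adv)
    return [adv + alpha * atob for adv, atob in zip(l_adv, l_atob)]


# Generator step (assumed convention): every patch should be judged "real".
n_patches = 16 * 16
y_true = [[1.0] * n_patches for _ in range(2)]
y_pred = [[0.5] * n_patches for _ in range(2)]
b = [[0.0] * 64 for _ in range(2)]
bp = [[0.1] * 64 for _ in range(2)]
print(pix2pix_generator_loss(y_true, y_pred, b, bp))  # about [10.693, 10.693]
```

With a discriminator output of 0.5 everywhere, the adversarial term is log(2), about 0.693, while an L1 error of 0.1 contributes 100 * 0.1 = 10, so with the default `alpha=100` the reconstruction term carries most of the loss in this example. In the non-binary branch the L1 term is a single batch-wide scalar added to each sample's adversarial loss, which is also how the Keras expression broadcasts.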
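The parameter counts in the doctest summaries of `models.py` follow from the usual convolution formula: `k * k * c_in * c_out` weights plus one bias per output filter. The sketch below reproduces the discriminator's per-layer counts and the U-Net decoder's transposed-convolution counts. Because `Convolution` and `Deconvolution` are defined outside this part of `models.py`, the kernel sizes used here (3x3 for `Convolution`, 2x2 for `Deconvolution`) are assumptions inferred from the doctest numbers, not read from the code; `conv_params` is a hypothetical helper.

```python
def conv_params(k, c_in, c_out):
    """Weights of a k x k (transposed) convolution plus one bias per filter."""
    return k * k * c_in * c_out + c_out


# Discriminator doctest: a_ch=3, b_ch=4, nf=2, assuming 3x3 kernels.
nf = 2
channels = [3 + 4, nf, nf * 2, nf * 4, nf * 8, 1]
disc = [conv_params(3, c_in, c_out)
        for c_in, c_out in zip(channels[:-1], channels[1:])]
print(disc, sum(disc))  # [128, 76, 296, 1168, 145] 1813

# U-Net decoder (g_unet doctest, nf=2, out_ch=4), assuming 2x2 kernels.
# Each transposed conv sees the previous decoder output concatenated with
# an encoder skip connection of the same width, hence the doubled inputs.
decoder = [
    conv_params(2, 16 + 16, nf * 8),  # dconv3, dconv4, dconv5 -> 16 channels
    conv_params(2, 16 + 16, nf * 4),  # dconv6 -> 64 x 64 x 8
    conv_params(2, 8 + 8, nf * 2),    # dconv7 -> 128 x 128 x 4
    conv_params(2, 4 + 4, nf),        # dconv8 -> 256 x 256 x 2
    conv_params(2, 2 + 2, 4),         # dconv9 -> 512 x 512 x 4
]
print(decoder)  # [2064, 1032, 260, 66, 68]
```

The same arithmetic also highlights something in the `g_unet` summary: a Keras `BatchNormalization` layer stores four values per normalized feature (gamma, beta, moving mean, moving variance), yet the rows there read 32, 64, 128, 256, 512 and 1024, which is four times the size of axis 1 (8 through 256) rather than of the channel axis. That is consistent with `BatchNorm` normalizing over axis 1 even under TensorFlow ordering, which may be worth checking.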