├── util
├── __init__.py
├── util.py
└── data.py
├── .gitignore
├── images
├── gen_method.png
└── image_collage_extended.jpg
├── LICENSE
├── README.md
├── test.py
├── train.py
└── models.py
/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | data/
3 | log/
4 | results/
5 |
--------------------------------------------------------------------------------
/images/gen_method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/costapt/vess2ret/HEAD/images/gen_method.png
--------------------------------------------------------------------------------
/images/image_collage_extended.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/costapt/vess2ret/HEAD/images/image_collage_extended.jpg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Pedro Miguel Vendas da Costa
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Towards Adversarial Retinal Image Synthesis
2 |
3 | [arXiv](https://arxiv.org/abs/1701.08974) [Demo](http://vess2ret.inesctec.pt)
4 |
5 | We use an image-to-image translation technique based on the idea of adversarial learning to synthesize eye fundus images directly from data. We pair true eye fundus images with their respective vessel trees by means of a vessel segmentation technique. These pairs are then used to learn a mapping from a binary vessel tree to a new retinal image.
6 |
7 |
8 |
9 |
10 | ## How it works
11 | - Get pairs of binary retinal vessel trees and corresponding retinal images
12 | The user can provide their own vessel annotations.
13 | In our case, because a large enough manually annotated database was not available, we applied a DNN vessel segmentation method to the [Messidor database](http://www.adcis.net/en/Download-Third-Party/Messidor.html). For details, please refer to the [arXiv paper](https://arxiv.org/abs/1701.08974).
14 |
15 | - Train the image generator on the set of image pairs.
16 | The model was based on [pix2pix](https://github.com/phillipi/pix2pix). We use a Generative Adversarial Network and combine the adversarial loss with a global L1 loss (see the sketch after this list). Our images have a 512x512 pixel resolution. The implementation was developed in Python using Keras.
17 |
18 |
19 | - Test the model.
20 | The model is now able to synthesize a new retinal image from any given vessel tree.
21 |
22 |
23 |
24 |
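The generator described above is trained on a combined objective: the adversarial loss plus `alpha` times a global L1 reconstruction term (`alpha` defaults to 100 in `train.py`). Below is a minimal NumPy sketch of that objective, illustrative only; it follows the label convention used in `train.py`, where 1 marks fake pairs and 0 marks real ones, so the generator is pushed towards 0:

```python
import numpy as np

def generator_loss(d_out, b_real, b_fake, alpha=100.0, eps=1e-7):
    """Adversarial term plus a global L1 term, as combined in the model."""
    # Binary crossentropy of the discriminator output against the
    # "real" label 0 reduces to -log(1 - d_out).
    adv = -np.mean(np.log(1.0 - d_out + eps))
    # Global L1 loss between the real image B and the synthesized one.
    l1 = np.mean(np.abs(b_real - b_fake))
    return adv + alpha * l1
```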
25 |
26 | ## Setup
27 |
28 | ### Prerequisites
29 | - Keras (Theano or TensorFlow backend) with "image_dim_ordering" set to "th" (see the snippet below)
30 |
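`models.py` forces this ordering at import time with the same backend call shown below; the snippet is only a sanity check (assuming a Keras version contemporary to this code):

```python
from keras import backend as K

K.set_image_dim_ordering('th')  # models.py also does this at import time
print(K.image_dim_ordering())   # should print 'th'
```
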
31 | ### Set up directories
32 |
33 | The data must be organized into train, validation and test directories. By default the directory tree is:
34 |
35 | * 'data/unet_segmentations_binary'
36 | * 'train'
37 | * 'A', contains the binary segmentations
38 | * 'B', contains the retinal images
39 | * 'val'
40 | * 'A', contains the binary segmentations
41 | * 'B', contains the retinal images
42 | * 'test'
43 | * 'A', contains the binary segmentations
44 | * 'B', contains the retinal images
45 |
46 | The defaults can be changed by altering the parameters at run time:
47 | ```bash
48 | python train.py [--base_dir] [--train_dir] [--val_dir]
49 | ```
50 | Folders {A,B} contain the corresponding image pairs and must keep these default names. Each pair must share the same filename; images without a counterpart in the other folder are discarded.
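
A quick way to check that a split is correctly paired (a minimal sketch assuming the default layout; the iterator itself pairs files by intersecting the two folder listings):

```python
import os

base = 'data/unet_segmentations_binary/train'
a = set(os.listdir(os.path.join(base, 'A')))
b = set(os.listdir(os.path.join(base, 'B')))
print('%d pairs, %d unpaired files' % (len(a & b), len(a ^ b)))
```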
51 |
52 | ## Usage
53 |
54 | ### Model
55 |
56 | The model can be used with any vessel tree of the appropriate size (512x512 by default). You can download the pre-trained weights available [here](https://drive.google.com/drive/folders/0B_82R0TWezB9VExYbmt2ZUJSUmc?usp=sharing) and load them at test time. If you choose to do this, skip the training step.
57 |
58 | ### Train the model
59 |
60 | To train the model run:
61 |
62 | ```bash
63 | python train.py [--help]
64 | ```
65 | By default the model will be saved to a folder named 'log'.
66 |
67 | ### Test the model
68 |
69 | To test the model run:
70 |
71 | ```bash
72 | python test.py [--help]
73 | ```
74 | If you are running the test with the pre-trained weights downloaded from [here](https://drive.google.com/drive/folders/0B_82R0TWezB9VExYbmt2ZUJSUmc?usp=sharing), make sure both the weights and params.json are saved in the log folder.
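
`test.py` loads the generator weights under the filename defined in `util/util.py` (`atob_weights.h5`). A quick pre-flight check before running it:

```python
import os

log_dir = 'log'  # or 'log/<expt_name>' if --expt_name was used
for fname in ('atob_weights.h5', 'params.json'):
    found = os.path.isfile(os.path.join(log_dir, fname))
    print('%s: %s' % (fname, 'ok' if found else 'MISSING'))
```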
75 |
76 |
77 | ## Citation
78 | If you use this code for your research, please cite our paper [Towards Adversarial Retinal Image Synthesis](https://arxiv.org/abs/1701.08974):
79 |
80 | ```
81 | @article{costa_retinal_generation_2017,
82 |   title={Towards Adversarial Retinal Image Synthesis},
83 |   author={Costa, P. and Galdran, A. and Meyer, M.I. and Abràmoff, M.D. and Niemeijer, M. and Mendonca, A.M. and Campilho, A.},
84 |   journal={arXiv preprint arXiv:1701.08974},
85 |   year={2017},
86 |   doi={10.5281/zenodo.265508}
87 | }
88 | ```
89 |
90 | [DOI: 10.5281/zenodo.265508](https://doi.org/10.5281/zenodo.265508)
91 |
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | """Auxiliary methods."""
2 | import os
3 | import json
4 | from errno import EEXIST
5 |
6 | import numpy as np
7 | import seaborn as sns
8 | import cPickle as pickle
9 | import matplotlib.pyplot as plt
10 |
11 | sns.set()
12 |
13 | DEFAULT_LOG_DIR = 'log'
14 | ATOB_WEIGHTS_FILE = 'atob_weights.h5'
15 | D_WEIGHTS_FILE = 'd_weights.h5'
16 |
17 |
18 | class MyDict(dict):
19 | """
20 |     Dictionary that allows accessing elements with dot notation.
21 |
22 | ex:
23 | >> d = MyDict({'key': 'val'})
24 | >> d.key
25 | 'val'
26 | >> d.key2 = 'val2'
27 | >> d
28 | {'key2': 'val2', 'key': 'val'}
29 | """
30 |
31 | __getattr__ = dict.get
32 | __setattr__ = dict.__setitem__
33 |
34 |
35 | def convert_to_rgb(img, is_binary=False):
36 | """Given an image, make sure it has 3 channels and that it is between 0 and 1."""
37 | if len(img.shape) != 3:
38 | raise Exception("""Image must have 3 dimensions (channels x height x width). """
39 | """Given {0}""".format(len(img.shape)))
40 |
41 | img_ch, _, _ = img.shape
42 | if img_ch != 3 and img_ch != 1:
43 | raise Exception("""Unsupported number of channels. """
44 | """Must be 1 or 3, given {0}.""".format(img_ch))
45 |
46 | imgp = img
47 | if img_ch == 1:
48 | imgp = np.repeat(img, 3, axis=0)
49 |
50 | if not is_binary:
51 |         imgp = imgp * 127.5 + 127.5  # map tanh output from [-1, 1] to [0, 255]
52 |         imgp /= 255.  # then scale to [0, 1] for plotting
53 |
54 | return np.clip(imgp.transpose((1, 2, 0)), 0, 1)
55 |
56 |
57 | def compose_imgs(a, b, is_a_binary=True, is_b_binary=False):
58 | """Place a and b side by side to be plotted."""
59 | ap = convert_to_rgb(a, is_binary=is_a_binary)
60 | bp = convert_to_rgb(b, is_binary=is_b_binary)
61 |
62 | if ap.shape != bp.shape:
63 | raise Exception("""A and B must have the same size. """
64 | """{0} != {1}""".format(ap.shape, bp.shape))
65 |
66 | # ap.shape and bp.shape must have the same size here
67 | h, w, ch = ap.shape
68 | composed = np.zeros((h, 2*w, ch))
69 | composed[:, :w, :] = ap
70 | composed[:, w:, :] = bp
71 |
72 | return composed
73 |
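# Example (illustrative): log() below composes each A/B pair exactly like this:
#   img = compose_imgs(a[0], bp[0], is_a_binary=True, is_b_binary=False)
#   plt.imshow(img); plt.axis('off')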
74 |
75 | def get_log_dir(log_dir, expt_name):
76 | """Compose the log_dir with the experiment name."""
77 | if log_dir is None:
78 |         raise Exception('log_dir cannot be None.')
79 |
80 | if expt_name is not None:
81 | return os.path.join(log_dir, expt_name)
82 | return log_dir
83 |
84 |
85 | def mkdir(mypath):
86 | """Create a directory if it does not exist."""
87 | try:
88 | os.makedirs(mypath)
89 | except OSError as exc:
90 | if exc.errno == EEXIST and os.path.isdir(mypath):
91 | pass
92 | else:
93 | raise
94 |
95 |
96 | def create_expt_dir(params):
97 | """Create the experiment directory and return it."""
98 | expt_dir = get_log_dir(params.log_dir, params.expt_name)
99 |
100 | # Create directories if they do not exist
101 | mkdir(params.log_dir)
102 | mkdir(expt_dir)
103 |
104 | # Save the parameters
105 | json.dump(params, open(os.path.join(expt_dir, 'params.json'), 'wb'),
106 | indent=4, sort_keys=True)
107 |
108 | return expt_dir
109 |
110 |
111 | def plot_loss(loss, label, filename, log_dir):
112 | """Plot a loss function and save it in a file."""
113 | plt.figure(figsize=(5, 4))
114 | plt.plot(loss, label=label)
115 | plt.legend()
116 | plt.savefig(os.path.join(log_dir, filename))
117 | plt.clf()
118 |
119 |
120 | def log(losses, atob, it_val, N=4, log_dir=DEFAULT_LOG_DIR, expt_name=None,
121 | is_a_binary=True, is_b_binary=False):
122 | """Log losses and atob results."""
123 | log_dir = get_log_dir(log_dir, expt_name)
124 |
125 | # Save the losses for further inspection
126 | pickle.dump(losses, open(os.path.join(log_dir, 'losses.pkl'), 'wb'))
127 |
128 | ###########################################################################
129 | # PLOT THE LOSSES #
130 | ###########################################################################
131 | plot_loss(losses['d'], 'discriminator', 'd_loss.png', log_dir)
132 | plot_loss(losses['d_val'], 'discriminator validation', 'd_val_loss.png', log_dir)
133 |
134 | plot_loss(losses['p2p'], 'Pix2Pix', 'p2p_loss.png', log_dir)
135 | plot_loss(losses['p2p_val'], 'Pix2Pix validation', 'p2p_val_loss.png', log_dir)
136 |
137 | ###########################################################################
138 | # PLOT THE A->B RESULTS #
139 | ###########################################################################
140 | plt.figure(figsize=(10, 6))
141 | for i in range(N*N):
142 | a, _ = next(it_val)
143 |
144 | bp = atob.predict(a)
145 | img = compose_imgs(a[0], bp[0], is_a_binary=is_a_binary, is_b_binary=is_b_binary)
146 |
147 | plt.subplot(N, N, i+1)
148 | plt.imshow(img)
149 | plt.axis('off')
150 |
151 | plt.savefig(os.path.join(log_dir, 'atob.png'))
152 | plt.clf()
153 |
154 | # Make sure all the figures are closed.
155 | plt.close('all')
156 |
157 |
158 | def save_weights(models, log_dir=DEFAULT_LOG_DIR, expt_name=None):
159 | """Save the weights of the models into a file."""
160 | log_dir = get_log_dir(log_dir, expt_name)
161 |
162 | models.atob.save_weights(os.path.join(log_dir, ATOB_WEIGHTS_FILE), overwrite=True)
163 | models.d.save_weights(os.path.join(log_dir, D_WEIGHTS_FILE), overwrite=True)
164 |
165 |
166 | def load_weights(atob, d, log_dir=DEFAULT_LOG_DIR, expt_name=None):
167 | """Load the weights into the corresponding models."""
168 | log_dir = get_log_dir(log_dir, expt_name)
169 |
170 | atob.load_weights(os.path.join(log_dir, ATOB_WEIGHTS_FILE))
171 | d.load_weights(os.path.join(log_dir, D_WEIGHTS_FILE))
172 |
173 |
174 | def load_weights_of(m, weights_file, log_dir=DEFAULT_LOG_DIR, expt_name=None):
175 | """Load the weights of the model m."""
176 | log_dir = get_log_dir(log_dir, expt_name)
177 |
178 | m.load_weights(os.path.join(log_dir, weights_file))
179 |
180 |
181 | def load_losses(log_dir=DEFAULT_LOG_DIR, expt_name=None):
182 | """Load the losses of the given experiment."""
183 | log_dir = get_log_dir(log_dir, expt_name)
184 | losses = pickle.load(open(os.path.join(log_dir, 'losses.pkl'), 'rb'))
185 | return losses
186 |
187 |
188 | def load_params(params):
189 | """
190 | Load the parameters of an experiment and return them.
191 |
192 | The params passed as argument will be merged with the new params dict.
193 | If there is a conflict with a key, the params passed as argument prevails.
194 | """
195 | expt_dir = get_log_dir(params.log_dir, params.expt_name)
196 |
197 | expt_params = json.load(open(os.path.join(expt_dir, 'params.json'), 'rb'))
198 |
199 | # Update the loaded parameters with the current parameters. This will
200 | # override conflicting keys as expected.
201 | expt_params.update(params)
202 |
203 | return expt_params
204 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | """Script to test a trained model."""
2 | import os
3 | import sys
4 | import getopt
5 |
6 | import numpy as np
7 | import models as m
8 | import matplotlib.pyplot as plt
9 | import util.util as u
10 |
11 | from util.data import TwoImageIterator
12 | from util.util import MyDict, load_params, load_weights_of, compose_imgs, convert_to_rgb, mkdir, get_log_dir
13 |
14 |
15 | def print_help():
16 | """Print how to use this script."""
17 | print "Usage:"
18 | print "test.py [--help] [--results_dir] [--log_dir] [--base_dir] [--train_dir] [--val_dir] " \
19 | "[--test_dir] [--load_to_memory] [--expt_name] [--target_size] [--N]"
20 | print "--results_dir: Directory where to save the results."
21 | print "--log_dir': Directory where the experiment was logged."
22 | print "--base_dir: Directory that contains the data."
23 | print "--train_dir: Directory inside base_dir that contains training data."
24 | print "--val_dir: Directory inside base_dir that contains validation data."
25 | print "--test_dir: Directory inside base_dir that contains test data."
26 | print "--load_to_memory: Whether to load the images into memory."
27 | print "--expt_name: The name of the experiment to test."
28 | print "--target_size: The size of the images loaded by the iterator."
29 | print "--N: The number of samples to generate."
30 |
31 |
32 | def join_and_create_dir(*paths):
33 | """Join the paths provided as arguments, create the directory and return the path."""
34 | path = os.path.join(*paths)
35 | mkdir(path)
36 |
37 | return path
38 |
39 |
40 | def save_pix2pix(unet, it, path, params):
41 | """Save the results of the pix2pix model."""
42 | real_dir = join_and_create_dir(path, 'real')
43 | a_dir = join_and_create_dir(path, 'A')
44 | b_dir = join_and_create_dir(path, 'B')
45 | comp_dir = join_and_create_dir(path, 'composed')
46 |
47 | for i, filename in enumerate(it.filenames):
48 | a, b = next(it)
49 | bp = unet.predict(a)
50 | bp = convert_to_rgb(bp[0], is_binary=params.is_b_binary)
51 |
52 | img = compose_imgs(a[0], b[0], is_a_binary=params.is_a_binary, is_b_binary=params.is_b_binary)
53 | hi, wi, chi = img.shape
54 | hb, wb, chb = bp.shape
55 | if hi != hb or wi != 2*wb or chi != chb:
56 | raise Exception("Mismatch in img and bp dimensions {0} / {1}".format(img.shape, bp.shape))
57 |
58 | composed = np.zeros((hi, wi+wb, chi))
59 | composed[:, :wi, :] = img
60 | composed[:, wi:, :] = bp
61 |
62 | a = convert_to_rgb(a[0], is_binary=params.is_a_binary)
63 | b = convert_to_rgb(b[0], is_binary=params.is_b_binary)
64 |
65 | plt.imsave(open(os.path.join(real_dir, filename), 'wb+'), b)
66 | plt.imsave(open(os.path.join(b_dir, filename), 'wb+'), bp)
67 | plt.imsave(open(os.path.join(a_dir, filename), 'wb+'), a)
68 | plt.imsave(open(os.path.join(comp_dir, filename), 'wb+'), composed)
69 |
70 |
71 | def save_all_pix2pix(unet, it_train, it_val, it_test, params):
72 | """Save all the results of the pix2pix model."""
73 | expt_dir = get_log_dir(params.results_dir, params.expt_name)
74 |
75 |     # Create directories if they do not exist
76 | mkdir(params.results_dir)
77 | mkdir(expt_dir)
78 |
79 | train_dir = join_and_create_dir(expt_dir, params.train_dir)
80 | val_dir = join_and_create_dir(expt_dir, params.val_dir)
81 | test_dir = join_and_create_dir(expt_dir, params.test_dir)
82 |
83 | save_pix2pix(unet, it_train, train_dir, params)
84 | save_pix2pix(unet, it_val, val_dir, params)
85 | save_pix2pix(unet, it_test, test_dir, params)
86 |
87 |
88 | if __name__ == '__main__':
89 | a = sys.argv[1:]
90 |
91 | params = MyDict({
92 | 'results_dir': 'results', # Directory where to save the results
93 | 'log_dir': 'log', # Directory where the experiment was logged
94 | 'base_dir': 'data/unet_segmentations_binary', # Directory that contains the data
95 | 'train_dir': 'train', # Directory inside base_dir that contains training data
96 | 'val_dir': 'val', # Directory inside base_dir that contains validation data
97 | 'test_dir': 'test', # Directory inside base_dir that contains test data
98 | 'load_to_memory': True, # Whether to load the images into memory
99 | 'expt_name': None, # The name of the experiment to test
100 | 'target_size': 512, # The size of the images loaded by the iterator
101 | 'N': 100, # The number of samples to generate
102 | })
103 |
104 | param_names = [k + '=' for k in params.keys()] + ['help']
105 |
106 | try:
107 | opts, args = getopt.getopt(a, '', param_names)
108 | except getopt.GetoptError:
109 | print_help()
110 | sys.exit()
111 |
112 | for opt, arg in opts:
113 | if opt == '--help':
114 | print_help()
115 | sys.exit()
116 | elif opt in ('--target_size', '--N'):
117 | params[opt[2:]] = int(arg)
118 |         elif opt == '--load_to_memory':
119 | params[opt[2:]] = True if arg == 'True' else False
120 | elif opt in ('--results_dir', '--log_dir', '--base_dir', '--train_dir',
121 | '--val_dir', '--test_dir', '--expt_name'):
122 | params[opt[2:]] = arg
123 |
124 | params = load_params(params)
125 | params = MyDict(params)
126 |
127 | # Define the U-Net generator
128 | unet = m.g_unet(params.a_ch, params.b_ch, params.nfatob, is_binary=params.is_b_binary)
129 | load_weights_of(unet, u.ATOB_WEIGHTS_FILE, log_dir=params.log_dir, expt_name=params.expt_name)
130 |
131 | ts = params.target_size
132 | train_dir = os.path.join(params.base_dir, params.train_dir)
133 | it_train = TwoImageIterator(train_dir, is_a_binary=params.is_a_binary,
134 | is_a_grayscale=params.is_a_grayscale,
135 | is_b_grayscale=params.is_b_grayscale,
136 | is_b_binary=params.is_b_binary, batch_size=1,
137 | load_to_memory=params.load_to_memory,
138 | target_size=(ts, ts), shuffle=False)
139 | val_dir = os.path.join(params.base_dir, params.val_dir)
140 | it_val = TwoImageIterator(val_dir, is_a_binary=params.is_a_binary,
141 | is_b_binary=params.is_b_binary,
142 | is_a_grayscale=params.is_a_grayscale,
143 | is_b_grayscale=params.is_b_grayscale, batch_size=1,
144 | load_to_memory=params.load_to_memory,
145 | target_size=(ts, ts), shuffle=False)
146 | test_dir = os.path.join(params.base_dir, params.test_dir)
147 | it_test = TwoImageIterator(test_dir, is_a_binary=params.is_a_binary,
148 | is_b_binary=params.is_b_binary,
149 | is_a_grayscale=params.is_a_grayscale,
150 | is_b_grayscale=params.is_b_grayscale, batch_size=1,
151 | load_to_memory=params.load_to_memory,
152 | target_size=(ts, ts), shuffle=False)
153 |
154 | save_all_pix2pix(unet, it_train, it_val, it_test, params)
155 |
--------------------------------------------------------------------------------
/util/data.py:
--------------------------------------------------------------------------------
1 | """Auxiliar methods to deal with loading the dataset."""
2 | import os
3 | import random
4 |
5 | import numpy as np
6 |
7 | from keras.preprocessing.image import apply_transform, flip_axis
8 | from keras.preprocessing.image import transform_matrix_offset_center
9 | from keras.preprocessing.image import Iterator, load_img, img_to_array
10 |
11 |
12 | class TwoImageIterator(Iterator):
13 | """Class to iterate A and B images at the same time."""
14 |
15 | def __init__(self, directory, a_dir_name='A', b_dir_name='B', load_to_memory=False,
16 | is_a_binary=False, is_b_binary=False, is_a_grayscale=False,
17 | is_b_grayscale=False, target_size=(256, 256), rotation_range=0.,
18 | height_shift_range=0., width_shift_range=0., zoom_range=0.,
19 | fill_mode='constant', cval=0., horizontal_flip=False,
20 | vertical_flip=False, dim_ordering='default', N=-1,
21 | batch_size=32, shuffle=True, seed=None):
22 | """
23 | Iterate through two directories at the same time.
24 |
25 | Files under the directory A and B with the same name will be returned
26 | at the same time.
27 | Parameters:
28 | - directory: base directory of the dataset. Should contain two
29 | directories with name a_dir_name and b_dir_name;
30 | - a_dir_name: name of directory under directory that contains the A
31 | images;
32 | - b_dir_name: name of directory under directory that contains the B
33 | images;
34 | - load_to_memory: if true, loads the images to memory when creating the
35 | iterator;
36 | - is_a_binary: converts A images to binary images. Applies a threshold of 0.5.
37 | - is_b_binary: converts B images to binary images. Applies a threshold of 0.5.
38 | - is_a_grayscale: if True, A images will only have one channel.
39 | - is_b_grayscale: if True, B images will only have one channel.
40 | - N: if -1 uses the entire dataset. Otherwise only uses a subset;
41 | - batch_size: the size of the batches to create;
42 | - shuffle: if True the order of the images in X will be shuffled;
43 | - seed: seed for a random number generator.
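        Example (illustrative; assumes the default data layout from the README):
            it = TwoImageIterator('data/unet_segmentations_binary/train',
                                  batch_size=1, target_size=(512, 512))
            a, b = next(it)  # (batch, channels, height, width) in 'th' ordering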
44 | """
45 | self.directory = directory
46 |
47 | self.a_dir = os.path.join(directory, a_dir_name)
48 | self.b_dir = os.path.join(directory, b_dir_name)
49 |
50 | a_files = set(x for x in os.listdir(self.a_dir))
51 | b_files = set(x for x in os.listdir(self.b_dir))
52 | # Files inside a and b should have the same name. Images without a pair are discarded.
53 | self.filenames = list(a_files.intersection(b_files))
54 |
55 | # Use only a subset of the files. Good to easily overfit the model
56 | if N > 0:
57 | random.shuffle(self.filenames)
58 | self.filenames = self.filenames[:N]
59 | self.N = len(self.filenames)
60 | if self.N == 0:
61 | raise Exception("""Did not find any pair in the dataset. Please check that """
62 | """the names and extensions of the pairs are exactly the same. """
63 | """Searched inside folders: {0} and {1}""".format(self.a_dir, self.b_dir))
64 |
65 | self.dim_ordering = dim_ordering
66 | if self.dim_ordering not in ('th', 'default', 'tf'):
67 | raise Exception('dim_ordering should be one of "th", "tf" or "default". '
68 | 'Got {0}'.format(self.dim_ordering))
69 |
70 | self.target_size = target_size
71 |
72 | self.is_a_binary = is_a_binary
73 | self.is_b_binary = is_b_binary
74 | self.is_a_grayscale = is_a_grayscale
75 | self.is_b_grayscale = is_b_grayscale
76 |
77 | self.image_shape_a = self._get_image_shape(self.is_a_grayscale)
78 | self.image_shape_b = self._get_image_shape(self.is_b_grayscale)
79 |
80 | self.load_to_memory = load_to_memory
81 | if self.load_to_memory:
82 | self._load_imgs_to_memory()
83 |
84 | if self.dim_ordering in ('th', 'default'):
85 | self.channel_index = 1
86 | self.row_index = 2
87 | self.col_index = 3
88 | if dim_ordering == 'tf':
89 | self.channel_index = 3
90 | self.row_index = 1
91 | self.col_index = 2
92 |
93 | self.rotation_range = rotation_range
94 | self.height_shift_range = height_shift_range
95 | self.width_shift_range = width_shift_range
96 | self.fill_mode = fill_mode
97 | self.cval = cval
98 | self.horizontal_flip = horizontal_flip
99 | self.vertical_flip = vertical_flip
100 |
101 | if np.isscalar(zoom_range):
102 | self.zoom_range = [1 - zoom_range, 1 + zoom_range]
103 | elif len(zoom_range) == 2:
104 | self.zoom_range = [zoom_range[0], zoom_range[1]]
105 |
106 | super(TwoImageIterator, self).__init__(len(self.filenames), batch_size,
107 | shuffle, seed)
108 |
109 | def _get_image_shape(self, is_grayscale):
110 | """Auxiliar method to get the image shape given the color mode."""
111 | if is_grayscale:
112 | if self.dim_ordering == 'tf':
113 | return self.target_size + (1,)
114 | else:
115 | return (1,) + self.target_size
116 | else:
117 | if self.dim_ordering == 'tf':
118 | return self.target_size + (3,)
119 | else:
120 | return (3,) + self.target_size
121 |
122 | def _load_imgs_to_memory(self):
123 | """Load images to memory."""
124 | if not self.load_to_memory:
125 |             raise Exception('Cannot load images to memory: load_to_memory is False.')
126 |
127 | self.a = np.zeros((self.N,) + self.image_shape_a)
128 | self.b = np.zeros((self.N,) + self.image_shape_b)
129 |
130 | for idx in range(self.N):
131 | ai, bi = self._load_img_pair(idx, False)
132 | self.a[idx] = ai
133 | self.b[idx] = bi
134 |
135 | def _binarize(self, batch):
136 | """Make input binary images have 0 and 1 values only."""
137 | bin_batch = batch / 255.
138 | bin_batch[bin_batch >= 0.5] = 1
139 | bin_batch[bin_batch < 0.5] = 0
140 | return bin_batch
141 |
142 | def _normalize_for_tanh(self, batch):
143 | """Make input image values lie between -1 and 1."""
144 | tanh_batch = batch - 127.5
145 | tanh_batch /= 127.5
146 | return tanh_batch
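    # Note: _normalize_for_tanh maps pixel values from [0, 255] to [-1, 1] to
    # match the generator's tanh output; convert_to_rgb() in util/util.py
    # inverts this mapping (imgp * 127.5 + 127.5) when plotting.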
147 |
148 | def _load_img_pair(self, idx, load_from_memory):
149 | """Get a pair of images with index idx."""
150 | if load_from_memory:
151 | a = self.a[idx]
152 | b = self.b[idx]
153 | return a, b
154 |
155 | fname = self.filenames[idx]
156 |
157 | a = load_img(os.path.join(self.a_dir, fname),
158 | grayscale=self.is_a_grayscale,
159 | target_size=self.target_size)
160 | b = load_img(os.path.join(self.b_dir, fname),
161 | grayscale=self.is_b_grayscale,
162 | target_size=self.target_size)
163 |
164 | a = img_to_array(a, self.dim_ordering)
165 | b = img_to_array(b, self.dim_ordering)
166 |
167 | return a, b
168 |
169 | def _random_transform(self, a, b):
170 | """
171 | Random dataset augmentation.
172 |
173 | Adapted from https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py
174 | """
175 | # a and b are single images, so they don't have image number at index 0
176 | img_row_index = self.row_index - 1
177 | img_col_index = self.col_index - 1
178 | img_channel_index = self.channel_index - 1
179 |
180 | # use composition of homographies to generate final transform that needs to be applied
181 | if self.rotation_range:
182 | theta = np.pi / 180 * np.random.uniform(-self.rotation_range, self.rotation_range)
183 | else:
184 | theta = 0
185 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
186 | [np.sin(theta), np.cos(theta), 0],
187 | [0, 0, 1]])
188 | if self.height_shift_range:
189 | tx = np.random.uniform(-self.height_shift_range, self.height_shift_range) * a.shape[img_row_index]
190 | else:
191 | tx = 0
192 |
193 | if self.width_shift_range:
194 | ty = np.random.uniform(-self.width_shift_range, self.width_shift_range) * a.shape[img_col_index]
195 | else:
196 | ty = 0
197 |
198 | translation_matrix = np.array([[1, 0, tx],
199 | [0, 1, ty],
200 | [0, 0, 1]])
201 |
202 | if self.zoom_range[0] == 1 and self.zoom_range[1] == 1:
203 | zx, zy = 1, 1
204 | else:
205 | zx, zy = np.random.uniform(self.zoom_range[0], self.zoom_range[1], 2)
206 | zoom_matrix = np.array([[zx, 0, 0],
207 | [0, zy, 0],
208 | [0, 0, 1]])
209 |
210 | transform_matrix = np.dot(np.dot(rotation_matrix, translation_matrix), zoom_matrix)
211 |
212 | h, w = a.shape[img_row_index], a.shape[img_col_index]
213 | transform_matrix = transform_matrix_offset_center(transform_matrix, h, w)
214 | a = apply_transform(a, transform_matrix, img_channel_index,
215 | fill_mode=self.fill_mode, cval=self.cval)
216 | b = apply_transform(b, transform_matrix, img_channel_index,
217 | fill_mode=self.fill_mode, cval=self.cval)
218 |
219 | if self.horizontal_flip:
220 | if np.random.random() < 0.5:
221 | a = flip_axis(a, img_col_index)
222 | b = flip_axis(b, img_col_index)
223 |
224 | if self.vertical_flip:
225 | if np.random.random() < 0.5:
226 | a = flip_axis(a, img_row_index)
227 | b = flip_axis(b, img_row_index)
228 |
229 | return a, b
230 |
231 | def next(self):
232 | """Get the next pair of the sequence."""
233 | # Lock the iterator when the index is changed.
234 | with self.lock:
235 | index_array, _, current_batch_size = next(self.index_generator)
236 |
237 | batch_a = np.zeros((current_batch_size,) + self.image_shape_a)
238 | batch_b = np.zeros((current_batch_size,) + self.image_shape_b)
239 |
240 | for i, j in enumerate(index_array):
241 | a_img, b_img = self._load_img_pair(j, self.load_to_memory)
242 | a_img, b_img = self._random_transform(a_img, b_img)
243 |
244 | batch_a[i] = a_img
245 | batch_b[i] = b_img
246 |
247 | if self.is_a_binary:
248 | batch_a = self._binarize(batch_a)
249 | else:
250 | batch_a = self._normalize_for_tanh(batch_a)
251 |
252 | if self.is_b_binary:
253 | batch_b = self._binarize(batch_b)
254 | else:
255 | batch_b = self._normalize_for_tanh(batch_b)
256 |
257 | return [batch_a, batch_b]
258 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """The script used to train the model."""
2 | import os
3 | import sys
4 | import getopt
5 |
6 | import numpy as np
7 | import models as m
8 |
9 | from tqdm import tqdm
10 | from keras.optimizers import Adam
11 | from util.data import TwoImageIterator
12 | from util.util import MyDict, log, save_weights, load_weights, load_losses, create_expt_dir
13 |
14 |
15 | def print_help():
16 | """Print how to use this script."""
17 | print "Usage:"
18 | print "train.py [--help] [--nfd] [--nfatob] [--alpha] [--epochs] [batch_size] [--samples_per_batch] " \
19 | "[--save_every] [--lr] [--beta_1] [--continue_train] [--log_dir]" \
20 | "[--expt_name] [--base_dir] [--train_dir] [--val_dir] [--train_samples] " \
21 | "[--val_samples] [--load_to_memory] [--a_ch] [--b_ch] [--is_a_binary] " \
22 | "[--is_b_binary] [--is_a_grayscale] [--is_b_grayscale] [--target_size] " \
23 | "[--rotation_range] [--height_shift_range] [--width_shift_range] " \
24 | "[--horizontal_flip] [--vertical_flip] [--zoom_range]"
25 | print "--nfd: Number of filters of the first layer of the discriminator."
26 | print "--nfatob: Number of filters of the first layer of the AtoB model."
27 | print "--alpha: The weight of the reconstruction loss of the AtoB model."
28 | print "--epochs: Number of epochs to train the model."
29 | print "--batch_size: the size of the batch to train."
30 | print "--samples_per_batch: The number of samples to train each model on each iteration."
31 | print "--save_every: Save results every 'save_every' epochs on the log folder."
32 | print "--lr: The learning rate to train the models."
33 | print "--beta_1: The beta_1 value of the Adam optimizer."
34 | print "--continue_train: If it should continue the training from the last checkpoint."
35 | print "--log_dir: The directory to place the logs."
36 | print "--expt_name: The name of the experiment. Saves the logs into a folder with this name."
37 | print "--base_dir: Directory that contains the data."
38 | print "--train_dir: Directory inside base_dir that contains training data. " \
39 | "Must contain an A and B folder."
40 | print "--val_dir: Directory inside base_dir that contains validation data. " \
41 | "Must contain an A and B folder."
42 | print "--train_samples: The number of training samples. Set -1 to be the same as training examples."
43 | print "--val_samples: The number of validation samples. Set -1 to be the same as validation examples."
44 | print "--load_to_memory: Whether to load images into memory or read from the filesystem."
45 | print "--a_ch: Number of channels of images A."
46 | print "--b_ch: Number of channels of images B."
47 | print "--is_a_binary: If A is binary, its values will be 0 or 1. A threshold of 0.5 is used."
48 | print "--is_b_binary: If B is binary, the last layer of the atob model is " \
49 | "followed by a sigmoid. Otherwise, a tanh is used. When the sigmoid is " \
50 | "used, the binary crossentropy loss is used. For the tanh, the L1 is used. Also, " \
51 | "its values will be 0 or 1. A threshold of 0.5 is used."
52 | print "--is_a_grayscale: If A images should only have one channel. If they are color images, " \
53 | "they are converted to grayscale."
54 | print "--is_b_grayscale: If B images should only have one channel. If they are color images, " \
55 | "they are converted to grayscale."
56 | print "--target_size: The size of the images loaded by the iterator. THIS DOES NOT CHANGE THE MODELS. " \
57 | "If you want to accept images of different sizes you will need to update the models.py files."
58 | print "--rotation_range: The range to rotate training images for dataset augmentation."
59 | print "--height_shift_range: Percentage of height of the image to translate for dataset augmentation."
60 | print "--width_shift_range: Percentage of width of the image to translate for dataset augmentation."
61 | print "--horizontal_flip: If true performs random horizontal flips on the train set."
62 | print "--vertical_flip: If true performs random vertical flips on the train set."
63 | print "--zoom_range: Defines the range to scale the image for dataset augmentation."
64 |
65 |
66 | def discriminator_generator(it, atob, dout_size):
67 | """
68 | Generate batches for the discriminator.
69 |
70 | Parameters:
71 | - it: an iterator that returns a pair of images;
72 | - atob: the generator network that maps an image to another representation;
73 | - dout_size: the size of the output of the discriminator.
74 | """
75 | while True:
76 | # Fake pair
77 | a_fake, _ = next(it)
78 | b_fake = atob.predict(a_fake)
79 |
80 | # Real pair
81 | a_real, b_real = next(it)
82 |
83 | # Concatenate the channels. Images become (ch_a + ch_b) x 256 x 256
84 | fake = np.concatenate((a_fake, b_fake), axis=1)
85 | real = np.concatenate((a_real, b_real), axis=1)
86 |
87 | # Concatenate fake and real pairs into a single batch
88 | batch_x = np.concatenate((fake, real), axis=0)
89 |
90 | # 1 is fake, 0 is real
91 | batch_y = np.ones((batch_x.shape[0], 1) + dout_size)
92 | batch_y[fake.shape[0]:] = 0
93 |
94 | yield batch_x, batch_y
95 |
96 |
97 | def train_discriminator(d, it, samples_per_batch=20):
98 | """Train the discriminator network."""
99 | return d.fit_generator(it, samples_per_epoch=samples_per_batch*2, nb_epoch=1, verbose=False)
100 |
101 |
102 | def pix2pix_generator(it, dout_size):
103 | """
104 | Generate data for the generator network.
105 |
106 | Parameters:
107 | - it: an iterator that returns a pair of images;
108 | - dout_size: the size of the output of the discriminator.
109 | """
110 | for a, b in it:
111 | # 1 is fake, 0 is real
112 | y = np.zeros((a.shape[0], 1) + dout_size)
113 | yield [a, b], y
114 |
115 |
116 | def train_pix2pix(pix2pix, it, samples_per_batch=20):
117 | """Train the generator network."""
118 | return pix2pix.fit_generator(it, nb_epoch=1, samples_per_epoch=samples_per_batch, verbose=False)
119 |
120 |
121 | def evaluate(models, generators, losses, val_samples=192):
122 | """Evaluate and display the losses of the models."""
123 | # Get necessary generators
124 | d_gen = generators.d_gen_val
125 | p2p_gen = generators.p2p_gen_val
126 |
127 | # Get necessary models
128 | d = models.d
129 | p2p = models.p2p
130 |
131 | # Evaluate
132 | d_loss = d.evaluate_generator(d_gen, val_samples)
133 | p2p_loss = p2p.evaluate_generator(p2p_gen, val_samples)
134 |
135 | losses['d_val'].append(d_loss)
136 | losses['p2p_val'].append(p2p_loss)
137 |
138 | print ''
139 | print ('Train Losses of (D={0} / P2P={1});\n'
140 | 'Validation Losses of (D={2} / P2P={3})'.format(
141 | losses['d'][-1], losses['p2p'][-1], d_loss, p2p_loss))
142 |
143 | return d_loss, p2p_loss
144 |
145 |
146 | def model_creation(d, atob, params):
147 | """Create all the necessary models."""
148 | opt = Adam(lr=params.lr, beta_1=params.beta_1)
149 | p2p = m.pix2pix(atob, d, params.a_ch, params.b_ch, alpha=params.alpha, opt=opt,
150 | is_a_binary=params.is_a_binary, is_b_binary=params.is_b_binary)
151 |
152 | models = MyDict({
153 | 'atob': atob,
154 | 'd': d,
155 | 'p2p': p2p,
156 | })
157 |
158 | return models
159 |
160 |
161 | def generators_creation(it_train, it_val, models, dout_size):
162 | """Create all the necessary data generators."""
163 | # Discriminator data generators
164 | d_gen = discriminator_generator(it_train, models.atob, dout_size)
165 | d_gen_val = discriminator_generator(it_val, models.atob, dout_size)
166 |
167 | # Workaround to make tensorflow work. When atob.predict is called the first
168 | # time it calls tf.get_default_graph. This should be done on the main thread
169 | # and not inside fit_generator. See https://github.com/fchollet/keras/issues/2397
170 | next(d_gen)
171 |
172 | # pix2pix data generators
173 | p2p_gen = pix2pix_generator(it_train, dout_size)
174 | p2p_gen_val = pix2pix_generator(it_val, dout_size)
175 |
176 | generators = MyDict({
177 | 'd_gen': d_gen,
178 | 'd_gen_val': d_gen_val,
179 | 'p2p_gen': p2p_gen,
180 | 'p2p_gen_val': p2p_gen_val,
181 | })
182 |
183 | return generators
184 |
185 |
186 | def train_iteration(models, generators, losses, params):
187 | """Perform a train iteration."""
188 | # Get necessary generators
189 | d_gen = generators.d_gen
190 | p2p_gen = generators.p2p_gen
191 |
192 | # Get necessary models
193 | d = models.d
194 | p2p = models.p2p
195 |
196 |     # Update the discriminator
197 | dhist = train_discriminator(d, d_gen, samples_per_batch=params.samples_per_batch)
198 | losses['d'].extend(dhist.history['loss'])
199 |
200 | # Update the generator
201 | p2phist = train_pix2pix(p2p, p2p_gen, samples_per_batch=params.samples_per_batch)
202 | losses['p2p'].extend(p2phist.history['loss'])
203 |
204 |
205 | def train(models, it_train, it_val, params):
206 | """
207 | Train the model.
208 |
209 | Parameters:
210 |     - models: a dictionary with all the models:
211 |         - atob: the model that goes from A to B;
212 |         - d: the discriminator model;
213 |         - p2p: the composed Pix2Pix model.
214 |     - it_train: the iterator of the training data.
215 |     - it_val: the iterator of the validation data.
216 |     - params: parameters of the training procedure.
217 |     The discriminator output size (dout_size) is inferred from models.d.
218 |     """
219 | # Create the experiment folder and save the parameters
220 | create_expt_dir(params)
221 |
222 | # Get the output shape of the discriminator
223 |     dout_size = models.d.output_shape[-2:]
224 | # Define the data generators
225 | generators = generators_creation(it_train, it_val, models, dout_size)
226 |
227 | # Define the number of samples to use on each training epoch
228 | train_samples = params.train_samples
229 | if params.train_samples == -1:
230 | train_samples = it_train.N
231 | batches_per_epoch = train_samples // params.samples_per_batch
232 |
233 | # Define the number of samples to use for validation
234 | val_samples = params.val_samples
235 | if val_samples == -1:
236 | val_samples = it_val.N
237 |
238 | losses = {'p2p': [], 'd': [], 'p2p_val': [], 'd_val': []}
239 | if params.continue_train:
240 | losses = load_losses(log_dir=params.log_dir, expt_name=params.expt_name)
241 |
242 | for e in tqdm(range(params.epochs)):
243 |
244 | for b in range(batches_per_epoch):
245 | train_iteration(models, generators, losses, params)
246 |
247 | # Evaluate how the models is doing on the validation set.
248 | evaluate(models, generators, losses, val_samples=val_samples)
249 |
250 | if (e + 1) % params.save_every == 0:
251 | save_weights(models, log_dir=params.log_dir, expt_name=params.expt_name)
252 | log(losses, models.atob, it_val, log_dir=params.log_dir, expt_name=params.expt_name,
253 | is_a_binary=params.is_a_binary, is_b_binary=params.is_b_binary)
254 |
255 | if __name__ == '__main__':
256 | a = sys.argv[1:]
257 |
258 | params = MyDict({
259 | # Model
260 | 'nfd': 32, # Number of filters of the first layer of the discriminator
261 | 'nfatob': 64, # Number of filters of the first layer of the AtoB model
262 | 'alpha': 100, # The weight of the reconstruction loss of the atob model
263 | # Train
264 | 'epochs': 100, # Number of epochs to train the model
265 | 'batch_size': 1, # The batch size
266 | 'samples_per_batch': 20, # The number of samples to train each model on each iteration
267 | 'save_every': 10, # Save results every 'save_every' epochs on the log folder
268 | 'lr': 2e-4, # The learning rate to train the models
269 | 'beta_1': 0.5, # The beta_1 value of the Adam optimizer
270 | 'continue_train': False, # If it should continue the training from the last checkpoint
271 | # File system
272 | 'log_dir': 'log', # Directory to log
273 | 'expt_name': None, # The name of the experiment. Saves the logs into a folder with this name
274 | 'base_dir': 'data/unet_segmentations_binary', # Directory that contains the data
275 | 'train_dir': 'train', # Directory inside base_dir that contains training data
276 | 'val_dir': 'val', # Directory inside base_dir that contains validation data
277 | 'train_samples': -1, # The number of training samples. Set -1 to be the same as training examples
278 | 'val_samples': -1, # The number of validation samples. Set -1 to be the same as validation examples
279 | 'load_to_memory': True, # Whether to load the images into memory
280 | # Image
281 | 'a_ch': 1, # Number of channels of images A
282 | 'b_ch': 3, # Number of channels of images B
283 | 'is_a_binary': True, # If A is binary, its values will be either 0 or 1
284 | 'is_b_binary': False, # If B is binary, the last layer of the atob model is followed by a sigmoid
285 | 'is_a_grayscale': True, # If A is grayscale, the image will only have one channel
286 | 'is_b_grayscale': False, # If B is grayscale, the image will only have one channel
287 | 'target_size': 512, # The size of the images loaded by the iterator. DOES NOT CHANGE THE MODELS
288 | 'rotation_range': 0., # The range to rotate training images for dataset augmentation
289 | 'height_shift_range': 0., # Percentage of height of the image to translate for dataset augmentation
290 | 'width_shift_range': 0., # Percentage of width of the image to translate for dataset augmentation
291 | 'horizontal_flip': False, # If true performs random horizontal flips on the train set
292 | 'vertical_flip': False, # If true performs random vertical flips on the train set
293 | 'zoom_range': 0., # Defines the range to scale the image for dataset augmentation
294 | })
295 |
296 | param_names = [k + '=' for k in params.keys()] + ['help']
297 |
298 | try:
299 | opts, args = getopt.getopt(a, '', param_names)
300 | except getopt.GetoptError:
301 | print_help()
302 | sys.exit()
303 |
304 | for opt, arg in opts:
305 | if opt == '--help':
306 | print_help()
307 | sys.exit()
308 |         elif opt in ('--nfatob', '--nfd', '--a_ch', '--b_ch', '--epochs', '--batch_size',
309 | '--samples_per_batch', '--save_every', '--train_samples', '--val_samples',
310 | '--target_size'):
311 | params[opt[2:]] = int(arg)
312 | elif opt in ('--lr', '--beta_1', '--rotation_range', '--height_shift_range',
313 | '--width_shift_range', '--zoom_range', '--alpha'):
314 | params[opt[2:]] = float(arg)
315 | elif opt in ('--is_a_binary', '--is_b_binary', '--is_a_grayscale', '--is_b_grayscale',
316 | '--continue_train', '--horizontal_flip', '--vertical_flip',
317 | '--load_to_memory'):
318 | params[opt[2:]] = True if arg == 'True' else False
319 | elif opt in ('--base_dir', '--train_dir', '--val_dir', '--expt_name', '--log_dir'):
320 | params[opt[2:]] = arg
321 |
322 | dopt = Adam(lr=params.lr, beta_1=params.beta_1)
323 |
324 | # Define the U-Net generator
325 | unet = m.g_unet(params.a_ch, params.b_ch, params.nfatob,
326 | batch_size=params.batch_size, is_binary=params.is_b_binary)
327 |
328 | # Define the discriminator
329 | d = m.discriminator(params.a_ch, params.b_ch, params.nfd, opt=dopt)
330 |
331 | if params.continue_train:
332 | load_weights(unet, d, log_dir=params.log_dir, expt_name=params.expt_name)
333 |
334 | ts = params.target_size
335 | train_dir = os.path.join(params.base_dir, params.train_dir)
336 | it_train = TwoImageIterator(train_dir, is_a_binary=params.is_a_binary,
337 | is_a_grayscale=params.is_a_grayscale,
338 | is_b_grayscale=params.is_b_grayscale,
339 | is_b_binary=params.is_b_binary,
340 | batch_size=params.batch_size,
341 | load_to_memory=params.load_to_memory,
342 | rotation_range=params.rotation_range,
343 | height_shift_range=params.height_shift_range,
344 |                                 width_shift_range=params.width_shift_range,
345 | zoom_range=params.zoom_range,
346 | horizontal_flip=params.horizontal_flip,
347 | vertical_flip=params.vertical_flip,
348 | target_size=(ts, ts))
349 | val_dir = os.path.join(params.base_dir, params.val_dir)
350 | it_val = TwoImageIterator(val_dir, is_a_binary=params.is_a_binary,
351 | is_b_binary=params.is_b_binary,
352 | is_a_grayscale=params.is_a_grayscale,
353 | is_b_grayscale=params.is_b_grayscale,
354 | batch_size=params.batch_size,
355 | load_to_memory=params.load_to_memory,
356 | target_size=(ts, ts))
357 |
358 | models = model_creation(d, unet, params)
359 | train(models, it_train, it_val, params)
360 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | __doc__ = """The model definitions for the pix2pix network taken from the
2 | retina repository at https://github.com/costapt/vess2ret
3 | """
4 | import os
5 |
6 | import keras
7 | from keras import backend as K
8 | from keras import objectives
9 | from keras.layers import Input, merge
10 | from keras.layers.advanced_activations import LeakyReLU
11 | from keras.layers.convolutional import Convolution2D, Deconvolution2D
12 | from keras.layers.core import Activation, Dropout
13 | from keras.layers.normalization import BatchNormalization
14 | from keras.models import Model
15 | from keras.optimizers import Adam
16 |
17 | KERAS_2 = keras.__version__[0] == '2'
18 | try:
19 | # keras 2 imports
20 | from keras.layers.convolutional import Conv2DTranspose
21 | from keras.layers.merge import Concatenate
22 | except ImportError:
23 | print("Keras 2 layers could not be imported defaulting to keras1")
24 | KERAS_2 = False
25 |
26 | K.set_image_dim_ordering('th')
27 |
28 |
29 | def concatenate_layers(inputs, concat_axis, mode='concat'):
30 | if KERAS_2:
31 | assert mode == 'concat', "Only concatenation is supported in this wrapper"
32 | return Concatenate(axis=concat_axis)(inputs)
33 | else:
34 | return merge(inputs=inputs, concat_axis=concat_axis, mode=mode)
35 |
36 |
37 | def Convolution(f, k=3, s=2, border_mode='same', **kwargs):
38 | """Convenience method for Convolutions."""
39 | if KERAS_2:
40 | return Convolution2D(f,
41 | kernel_size=(k, k),
42 | padding=border_mode,
43 | strides=(s, s),
44 | **kwargs)
45 | else:
46 | return Convolution2D(f, k, k, border_mode=border_mode,
47 | subsample=(s, s),
48 | **kwargs)
49 |
50 |
51 | def Deconvolution(f, output_shape, k=2, s=2, **kwargs):
52 | """Convenience method for Transposed Convolutions."""
53 | if KERAS_2:
54 |         # Keras 2 infers the output shape; output_shape is only used by Keras 1
55 |         return Conv2DTranspose(f,
56 |                                kernel_size=(k, k),
57 |                                strides=(s, s),
58 |                                data_format=K.image_data_format(),
59 |                                **kwargs)
60 | else:
61 | return Deconvolution2D(f, k, k, output_shape=output_shape,
62 | subsample=(s, s), **kwargs)
63 |
64 |
65 | def BatchNorm(mode=2, axis=1, **kwargs):
66 | """Convenience method for BatchNormalization layers."""
67 | if KERAS_2:
68 | return BatchNormalization(axis=axis, **kwargs)
69 | else:
70 |         return BatchNormalization(mode=mode, axis=axis, **kwargs)
71 |
72 |
73 | def g_unet(in_ch, out_ch, nf, batch_size=1, is_binary=False, name='unet'):
74 | # type: (int, int, int, int, bool, str) -> keras.models.Model
75 | """Define a U-Net.
76 |
77 | Input has shape in_ch x 512 x 512
78 | Parameters:
79 | - in_ch: the number of input channels;
80 | - out_ch: the number of output channels;
81 | - nf: the number of filters of the first layer;
82 | - is_binary: if is_binary is true, the last layer is followed by a sigmoid
83 | activation function, otherwise, a tanh is used.
84 | >>> K.set_image_dim_ordering('th')
85 | >>> K.image_data_format()
86 | 'channels_first'
87 | >>> unet = g_unet(1, 2, 3, batch_size=5, is_binary=True)
88 | TheanoShapedU-NET
89 | >>> for ilay in unet.layers: ilay.name='_'.join(ilay.name.split('_')[:-1]) # remove layer id
90 | >>> unet.summary() #doctest: +NORMALIZE_WHITESPACE
91 | _________________________________________________________________
92 | Layer (type) Output Shape Param #
93 | =================================================================
94 | input (InputLayer) (None, 1, 512, 512) 0
95 | _________________________________________________________________
96 | conv2d (Conv2D) (None, 3, 256, 256) 30
97 | _________________________________________________________________
98 | batch_normalization (BatchNo (None, 3, 256, 256) 12
99 | _________________________________________________________________
100 | leaky_re_lu (LeakyReLU) (None, 3, 256, 256) 0
101 | _________________________________________________________________
102 | conv2d (Conv2D) (None, 6, 128, 128) 168
103 | _________________________________________________________________
104 | batch_normalization (BatchNo (None, 6, 128, 128) 24
105 | _________________________________________________________________
106 | leaky_re_lu (LeakyReLU) (None, 6, 128, 128) 0
107 | _________________________________________________________________
108 | conv2d (Conv2D) (None, 12, 64, 64) 660
109 | _________________________________________________________________
110 | batch_normalization (BatchNo (None, 12, 64, 64) 48
111 | _________________________________________________________________
112 | leaky_re_lu (LeakyReLU) (None, 12, 64, 64) 0
113 | _________________________________________________________________
114 | conv2d (Conv2D) (None, 24, 32, 32) 2616
115 | _________________________________________________________________
116 | batch_normalization (BatchNo (None, 24, 32, 32) 96
117 | _________________________________________________________________
118 | leaky_re_lu (LeakyReLU) (None, 24, 32, 32) 0
119 | _________________________________________________________________
120 | conv2d (Conv2D) (None, 24, 16, 16) 5208
121 | _________________________________________________________________
122 | batch_normalization (BatchNo (None, 24, 16, 16) 96
123 | _________________________________________________________________
124 | leaky_re_lu (LeakyReLU) (None, 24, 16, 16) 0
125 | _________________________________________________________________
126 | conv2d (Conv2D) (None, 24, 8, 8) 5208
127 | _________________________________________________________________
128 | batch_normalization (BatchNo (None, 24, 8, 8) 96
129 | _________________________________________________________________
130 | leaky_re_lu (LeakyReLU) (None, 24, 8, 8) 0
131 | _________________________________________________________________
132 | conv2d (Conv2D) (None, 24, 4, 4) 5208
133 | _________________________________________________________________
134 | batch_normalization (BatchNo (None, 24, 4, 4) 96
135 | _________________________________________________________________
136 | leaky_re_lu (LeakyReLU) (None, 24, 4, 4) 0
137 | _________________________________________________________________
138 | conv2d (Conv2D) (None, 24, 2, 2) 5208
139 | _________________________________________________________________
140 | batch_normalization (BatchNo (None, 24, 2, 2) 96
141 | _________________________________________________________________
142 | leaky_re_lu (LeakyReLU) (None, 24, 2, 2) 0
143 | _________________________________________________________________
144 | conv2d (Conv2D) (None, 24, 1, 1) 2328
145 | _________________________________________________________________
146 | batch_normalization (BatchNo (None, 24, 1, 1) 96
147 | _________________________________________________________________
148 | leaky_re_lu (LeakyReLU) (None, 24, 1, 1) 0
149 | _________________________________________________________________
150 | conv2d_transpose (Conv2DTran (None, 24, 2, 2) 2328
151 | _________________________________________________________________
152 | batch_normalization (BatchNo (None, 24, 2, 2) 96
153 | _________________________________________________________________
154 | dropout (Dropout) (None, 24, 2, 2) 0
155 | _________________________________________________________________
156 | concatenate (Concatenate) (None, 48, 2, 2) 0
157 | _________________________________________________________________
158 | leaky_re_lu (LeakyReLU) (None, 48, 2, 2) 0
159 | _________________________________________________________________
160 | conv2d_transpose (Conv2DTran (None, 24, 4, 4) 4632
161 | _________________________________________________________________
162 | batch_normalization (BatchNo (None, 24, 4, 4) 96
163 | _________________________________________________________________
164 | dropout (Dropout) (None, 24, 4, 4) 0
165 | _________________________________________________________________
166 | concatenate (Concatenate) (None, 48, 4, 4) 0
167 | _________________________________________________________________
168 | leaky_re_lu (LeakyReLU) (None, 48, 4, 4) 0
169 | _________________________________________________________________
170 | conv2d_transpose (Conv2DTran (None, 24, 8, 8) 4632
171 | _________________________________________________________________
172 | batch_normalization (BatchNo (None, 24, 8, 8) 96
173 | _________________________________________________________________
174 | dropout (Dropout) (None, 24, 8, 8) 0
175 | _________________________________________________________________
176 | concatenate (Concatenate) (None, 48, 8, 8) 0
177 | _________________________________________________________________
178 | leaky_re_lu (LeakyReLU) (None, 48, 8, 8) 0
179 | _________________________________________________________________
180 | conv2d_transpose (Conv2DTran (None, 24, 16, 16) 4632
181 | _________________________________________________________________
182 | batch_normalization (BatchNo (None, 24, 16, 16) 96
183 | _________________________________________________________________
184 | concatenate (Concatenate) (None, 48, 16, 16) 0
185 | _________________________________________________________________
186 | leaky_re_lu (LeakyReLU) (None, 48, 16, 16) 0
187 | _________________________________________________________________
188 | conv2d_transpose (Conv2DTran (None, 24, 32, 32) 4632
189 | _________________________________________________________________
190 | batch_normalization (BatchNo (None, 24, 32, 32) 96
191 | _________________________________________________________________
192 | concatenate (Concatenate) (None, 48, 32, 32) 0
193 | _________________________________________________________________
194 | leaky_re_lu (LeakyReLU) (None, 48, 32, 32) 0
195 | _________________________________________________________________
196 | conv2d_transpose (Conv2DTran (None, 12, 64, 64) 2316
197 | _________________________________________________________________
198 | batch_normalization (BatchNo (None, 12, 64, 64) 48
199 | _________________________________________________________________
200 | concatenate (Concatenate) (None, 24, 64, 64) 0
201 | _________________________________________________________________
202 | leaky_re_lu (LeakyReLU) (None, 24, 64, 64) 0
203 | _________________________________________________________________
204 | conv2d_transpose (Conv2DTran (None, 6, 128, 128) 582
205 | _________________________________________________________________
206 | batch_normalization (BatchNo (None, 6, 128, 128) 24
207 | _________________________________________________________________
208 | concatenate (Concatenate) (None, 12, 128, 128) 0
209 | _________________________________________________________________
210 | leaky_re_lu (LeakyReLU) (None, 12, 128, 128) 0
211 | _________________________________________________________________
212 | conv2d_transpose (Conv2DTran (None, 3, 256, 256) 147
213 | _________________________________________________________________
214 | batch_normalization (BatchNo (None, 3, 256, 256) 12
215 | _________________________________________________________________
216 | concatenate (Concatenate) (None, 6, 256, 256) 0
217 | _________________________________________________________________
218 | leaky_re_lu (LeakyReLU) (None, 6, 256, 256) 0
219 | _________________________________________________________________
220 | conv2d_transpose (Conv2DTran (None, 2, 512, 512) 50
221 | _________________________________________________________________
222 | activation (Activation) (None, 2, 512, 512) 0
223 | =================================================================
224 | Total params: 51,809.0
225 | Trainable params: 51,197.0
226 | Non-trainable params: 612.0
227 | _________________________________________________________________
228 | >>> K.set_image_dim_ordering('tf')
229 | >>> K.image_data_format()
230 | 'channels_last'
231 | >>> unet2=g_unet(3, 4, 2, batch_size=7, is_binary=False)
232 | TensorflowShapedU-NET
233 | >>> for ilay in unet2.layers: ilay.name='_'.join(ilay.name.split('_')[:-1]) # remove layer id
234 | >>> unet2.summary() #doctest: +NORMALIZE_WHITESPACE
235 | _________________________________________________________________
236 | Layer (type) Output Shape Param #
237 | =================================================================
238 | input (InputLayer) (None, 512, 512, 3) 0
239 | _________________________________________________________________
240 | conv2d (Conv2D) (None, 256, 256, 2) 56
241 | _________________________________________________________________
242 | batch_normalization (BatchNo (None, 256, 256, 2) 1024
243 | _________________________________________________________________
244 | leaky_re_lu (LeakyReLU) (None, 256, 256, 2) 0
245 | _________________________________________________________________
246 | conv2d (Conv2D) (None, 128, 128, 4) 76
247 | _________________________________________________________________
248 | batch_normalization (BatchNo (None, 128, 128, 4) 512
249 | _________________________________________________________________
250 | leaky_re_lu (LeakyReLU) (None, 128, 128, 4) 0
251 | _________________________________________________________________
252 | conv2d (Conv2D) (None, 64, 64, 8) 296
253 | _________________________________________________________________
254 | batch_normalization (BatchNo (None, 64, 64, 8) 256
255 | _________________________________________________________________
256 | leaky_re_lu (LeakyReLU) (None, 64, 64, 8) 0
257 | _________________________________________________________________
258 | conv2d (Conv2D) (None, 32, 32, 16) 1168
259 | _________________________________________________________________
260 | batch_normalization (BatchNo (None, 32, 32, 16) 128
261 | _________________________________________________________________
262 | leaky_re_lu (LeakyReLU) (None, 32, 32, 16) 0
263 | _________________________________________________________________
264 | conv2d (Conv2D) (None, 16, 16, 16) 2320
265 | _________________________________________________________________
266 | batch_normalization (BatchNo (None, 16, 16, 16) 64
267 | _________________________________________________________________
268 | leaky_re_lu (LeakyReLU) (None, 16, 16, 16) 0
269 | _________________________________________________________________
270 | conv2d (Conv2D) (None, 8, 8, 16) 2320
271 | _________________________________________________________________
272 | batch_normalization (BatchNo (None, 8, 8, 16) 32
273 | _________________________________________________________________
274 | leaky_re_lu (LeakyReLU) (None, 8, 8, 16) 0
275 | _________________________________________________________________
276 | conv2d (Conv2D) (None, 4, 4, 16) 2320
277 | _________________________________________________________________
278 | batch_normalization (BatchNo (None, 4, 4, 16) 16
279 | _________________________________________________________________
280 | leaky_re_lu (LeakyReLU) (None, 4, 4, 16) 0
281 | _________________________________________________________________
282 | conv2d (Conv2D) (None, 2, 2, 16) 2320
283 | _________________________________________________________________
284 | batch_normalization (BatchNo (None, 2, 2, 16) 8
285 | _________________________________________________________________
286 | leaky_re_lu (LeakyReLU) (None, 2, 2, 16) 0
287 | _________________________________________________________________
288 | conv2d (Conv2D) (None, 1, 1, 16) 1040
289 | _________________________________________________________________
290 | batch_normalization (BatchNo (None, 1, 1, 16) 4
291 | _________________________________________________________________
292 | leaky_re_lu (LeakyReLU) (None, 1, 1, 16) 0
293 | _________________________________________________________________
294 | conv2d_transpose (Conv2DTran (None, 2, 2, 16) 1040
295 | _________________________________________________________________
296 | batch_normalization (BatchNo (None, 2, 2, 16) 8
297 | _________________________________________________________________
298 | dropout (Dropout) (None, 2, 2, 16) 0
299 | _________________________________________________________________
300 | concatenate (Concatenate) (None, 2, 2, 32) 0
301 | _________________________________________________________________
302 | leaky_re_lu (LeakyReLU) (None, 2, 2, 32) 0
303 | _________________________________________________________________
304 | conv2d_transpose (Conv2DTran (None, 4, 4, 16) 2064
305 | _________________________________________________________________
306 | batch_normalization (BatchNo (None, 4, 4, 16) 16
307 | _________________________________________________________________
308 | dropout (Dropout) (None, 4, 4, 16) 0
309 | _________________________________________________________________
310 | concatenate (Concatenate) (None, 4, 4, 32) 0
311 | _________________________________________________________________
312 | leaky_re_lu (LeakyReLU) (None, 4, 4, 32) 0
313 | _________________________________________________________________
314 | conv2d_transpose (Conv2DTran (None, 8, 8, 16) 2064
315 | _________________________________________________________________
316 | batch_normalization (BatchNo (None, 8, 8, 16) 32
317 | _________________________________________________________________
318 | dropout (Dropout) (None, 8, 8, 16) 0
319 | _________________________________________________________________
320 | concatenate (Concatenate) (None, 8, 8, 32) 0
321 | _________________________________________________________________
322 | leaky_re_lu (LeakyReLU) (None, 8, 8, 32) 0
323 | _________________________________________________________________
324 | conv2d_transpose (Conv2DTran (None, 16, 16, 16) 2064
325 | _________________________________________________________________
326 | batch_normalization (BatchNo (None, 16, 16, 16) 64
327 | _________________________________________________________________
328 | concatenate (Concatenate) (None, 16, 16, 32) 0
329 | _________________________________________________________________
330 | leaky_re_lu (LeakyReLU) (None, 16, 16, 32) 0
331 | _________________________________________________________________
332 | conv2d_transpose (Conv2DTran (None, 32, 32, 16) 2064
333 | _________________________________________________________________
334 | batch_normalization (BatchNo (None, 32, 32, 16) 128
335 | _________________________________________________________________
336 | concatenate (Concatenate) (None, 32, 32, 32) 0
337 | _________________________________________________________________
338 | leaky_re_lu (LeakyReLU) (None, 32, 32, 32) 0
339 | _________________________________________________________________
340 | conv2d_transpose (Conv2DTran (None, 64, 64, 8) 1032
341 | _________________________________________________________________
342 | batch_normalization (BatchNo (None, 64, 64, 8) 256
343 | _________________________________________________________________
344 | concatenate (Concatenate) (None, 64, 64, 16) 0
345 | _________________________________________________________________
346 | leaky_re_lu (LeakyReLU) (None, 64, 64, 16) 0
347 | _________________________________________________________________
348 | conv2d_transpose (Conv2DTran (None, 128, 128, 4) 260
349 | _________________________________________________________________
350 | batch_normalization (BatchNo (None, 128, 128, 4) 512
351 | _________________________________________________________________
352 | concatenate (Concatenate) (None, 128, 128, 8) 0
353 | _________________________________________________________________
354 | leaky_re_lu (LeakyReLU) (None, 128, 128, 8) 0
355 | _________________________________________________________________
356 | conv2d_transpose (Conv2DTran (None, 256, 256, 2) 66
357 | _________________________________________________________________
358 | batch_normalization (BatchNo (None, 256, 256, 2) 1024
359 | _________________________________________________________________
360 | concatenate (Concatenate) (None, 256, 256, 4) 0
361 | _________________________________________________________________
362 | leaky_re_lu (LeakyReLU) (None, 256, 256, 4) 0
363 | _________________________________________________________________
364 | conv2d_transpose (Conv2DTran (None, 512, 512, 4) 68
365 | _________________________________________________________________
366 | activation (Activation) (None, 512, 512, 4) 0
367 | =================================================================
368 | Total params: 26,722.0
369 | Trainable params: 24,680.0
370 | Non-trainable params: 2,042.0
371 | _________________________________________________________________
372 | """
373 | merge_params = {
374 | 'mode': 'concat',
375 | 'concat_axis': 1
376 | }
377 | if K.image_dim_ordering() == 'th':
378 | print('TheanoShapedU-NET')
379 | i = Input(shape=(in_ch, 512, 512))
380 |
381 | def get_deconv_shape(samples, channels, x_dim, y_dim):
382 | return samples, channels, x_dim, y_dim
383 |
384 | elif K.image_dim_ordering() == 'tf':
385 | i = Input(shape=(512, 512, in_ch))
386 | print('TensorflowShapedU-NET')
387 |
388 | def get_deconv_shape(samples, channels, x_dim, y_dim):
389 | return samples, x_dim, y_dim, channels
390 |
391 | merge_params['concat_axis'] = 3
392 | else:
393 | raise ValueError(
394 | 'Keras dimension ordering not supported: {}'.format(
395 | K.image_dim_ordering()))
396 |
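    |     # Added annotation (not in the original source): get_deconv_shape
    |     # builds the full static output shape that the Deconvolution layers
    |     # below require (batch axis first, channels where the backend expects
    |     # them), and concat_axis tracks the channel axis (1 for 'th', 3 for
    |     # 'tf') so the skip connections merge feature maps along channels.
    | 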
397 | # in_ch x 512 x 512
398 | conv1 = Convolution(nf)(i)
399 | conv1 = BatchNorm()(conv1)
400 | x = LeakyReLU(0.2)(conv1)
401 | # nf x 256 x 256
402 |
403 | conv2 = Convolution(nf * 2)(x)
404 | conv2 = BatchNorm()(conv2)
405 | x = LeakyReLU(0.2)(conv2)
406 | # nf*2 x 128 x 128
407 |
408 | conv3 = Convolution(nf * 4)(x)
409 | conv3 = BatchNorm()(conv3)
410 | x = LeakyReLU(0.2)(conv3)
411 | # nf*4 x 64 x 64
412 |
413 | conv4 = Convolution(nf * 8)(x)
414 | conv4 = BatchNorm()(conv4)
415 | x = LeakyReLU(0.2)(conv4)
416 | # nf*8 x 32 x 32
417 |
418 | conv5 = Convolution(nf * 8)(x)
419 | conv5 = BatchNorm()(conv5)
420 | x = LeakyReLU(0.2)(conv5)
421 | # nf*8 x 16 x 16
422 |
423 | conv6 = Convolution(nf * 8)(x)
424 | conv6 = BatchNorm()(conv6)
425 | x = LeakyReLU(0.2)(conv6)
426 | # nf*8 x 8 x 8
427 |
428 | conv7 = Convolution(nf * 8)(x)
429 | conv7 = BatchNorm()(conv7)
430 | x = LeakyReLU(0.2)(conv7)
431 | # nf*8 x 4 x 4
432 |
433 | conv8 = Convolution(nf * 8)(x)
434 | conv8 = BatchNorm()(conv8)
435 | x = LeakyReLU(0.2)(conv8)
436 | # nf*8 x 2 x 2
437 |
438 | conv9 = Convolution(nf * 8, k=2, s=1, border_mode='valid')(x)
439 | conv9 = BatchNorm()(conv9)
440 | x = LeakyReLU(0.2)(conv9)
441 | # nf*8 x 1 x 1
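    |     # Added annotation (not in the original source): the 2x2 kernel with
    |     # stride 1 and 'valid' padding collapses the remaining 2x2 map into
    |     # the 1x1 bottleneck without zero padding.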
442 |
443 | dconv1 = Deconvolution(nf * 8,
444 | get_deconv_shape(batch_size, nf * 8, 2, 2),
445 | k=2, s=1)(x)
446 | dconv1 = BatchNorm()(dconv1)
447 | dconv1 = Dropout(0.5)(dconv1)
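    |     # Added annotation (not in the original source): dropout here and on
    |     # the next two decoder stages doubles as the generator's only source
    |     # of noise, as in the pix2pix formulation.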
448 |
449 | x = concatenate_layers([dconv1, conv8], **merge_params)
450 |
451 | x = LeakyReLU(0.2)(x)
452 | # nf*(8 + 8) x 2 x 2
453 |
454 | dconv2 = Deconvolution(nf * 8,
455 | get_deconv_shape(batch_size, nf * 8, 4, 4))(x)
456 | dconv2 = BatchNorm()(dconv2)
457 | dconv2 = Dropout(0.5)(dconv2)
458 | x = concatenate_layers([dconv2, conv7], **merge_params)
459 | x = LeakyReLU(0.2)(x)
460 | # nf*(8 + 8) x 4 x 4
461 |
462 | dconv3 = Deconvolution(nf * 8,
463 | get_deconv_shape(batch_size, nf * 8, 8, 8))(x)
464 | dconv3 = BatchNorm()(dconv3)
465 | dconv3 = Dropout(0.5)(dconv3)
466 | x = concatenate_layers([dconv3, conv6], **merge_params)
467 | x = LeakyReLU(0.2)(x)
468 | # nf*(8 + 8) x 8 x 8
469 |
470 | dconv4 = Deconvolution(nf * 8,
471 | get_deconv_shape(batch_size, nf * 8, 16, 16))(x)
472 | dconv4 = BatchNorm()(dconv4)
473 | x = concatenate_layers([dconv4, conv5], **merge_params)
474 | x = LeakyReLU(0.2)(x)
475 | # nf*(8 + 8) x 16 x 16
476 |
477 | dconv5 = Deconvolution(nf * 8,
478 | get_deconv_shape(batch_size, nf * 8, 32, 32))(x)
479 | dconv5 = BatchNorm()(dconv5)
480 | x = concatenate_layers([dconv5, conv4], **merge_params)
481 | x = LeakyReLU(0.2)(x)
482 | # nf*(8 + 8) x 32 x 32
483 |
484 | dconv6 = Deconvolution(nf * 4,
485 | get_deconv_shape(batch_size, nf * 4, 64, 64))(x)
486 | dconv6 = BatchNorm()(dconv6)
487 | x = concatenate_layers([dconv6, conv3], **merge_params)
488 | x = LeakyReLU(0.2)(x)
489 | # nf*(4 + 4) x 64 x 64
490 |
491 | dconv7 = Deconvolution(nf * 2,
492 | get_deconv_shape(batch_size, nf * 2, 128, 128))(x)
493 | dconv7 = BatchNorm()(dconv7)
494 | x = concatenate_layers([dconv7, conv2], **merge_params)
495 | x = LeakyReLU(0.2)(x)
496 | # nf*(2 + 2) x 128 x 128
497 |
498 | dconv8 = Deconvolution(nf,
499 | get_deconv_shape(batch_size, nf, 256, 256))(x)
500 | dconv8 = BatchNorm()(dconv8)
501 | x = concatenate_layers([dconv8, conv1], **merge_params)
502 | x = LeakyReLU(0.2)(x)
503 | # nf*(1 + 1) x 256 x 256
504 |
505 | dconv9 = Deconvolution(out_ch,
506 | get_deconv_shape(batch_size, out_ch, 512, 512))(x)
507 | # out_ch x 512 x 512
508 |
509 | act = 'sigmoid' if is_binary else 'tanh'
510 | out = Activation(act)(dconv9)
511 |
512 | unet = Model(i, out, name=name)
513 |
514 | return unet
515 |
516 |
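    | # A minimal usage sketch (added annotation, not part of the original
    | # repo): build a generator and map a dummy binary vessel tree to a
    | # synthetic fundus batch. The batch size handed to g_unet is baked into
    | # the Deconvolution output shapes, so prediction should use the same
    | # batch size. Assumes K.image_dim_ordering() == 'th' and that numpy is
    | # importable; sizes are illustrative.
    | def _example_generate(batch_size=4):
    |     import numpy as np  # only needed for this sketch
    |     gen = g_unet(1, 3, 64, batch_size=batch_size, is_binary=False)
    |     vessels = (np.random.rand(batch_size, 1, 512, 512) > 0.9).astype('float32')
    |     return gen.predict(vessels, batch_size=batch_size)  # (batch_size, 3, 512, 512)
    | 
    | 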
517 | def discriminator(a_ch, b_ch, nf, opt=Adam(lr=2e-4, beta_1=0.5), name='d'):
518 | """Define the discriminator network.
519 |
520 |     Parameters:
521 |     - a_ch: the number of channels of the first image;
522 |     - b_ch: the number of channels of the second image;
523 |     - nf: the number of filters of the first layer; opt/name set the optimizer and model name.
524 | >>> K.set_image_dim_ordering('th')
525 | >>> disc=discriminator(3,4,2)
526 | >>> for ilay in disc.layers: ilay.name='_'.join(ilay.name.split('_')[:-1]) # remove layer id
527 | >>> disc.summary() #doctest: +NORMALIZE_WHITESPACE
528 | _________________________________________________________________
529 | Layer (type) Output Shape Param #
530 | =================================================================
531 | input (InputLayer) (None, 7, 512, 512) 0
532 | _________________________________________________________________
533 | conv2d (Conv2D) (None, 2, 256, 256) 128
534 | _________________________________________________________________
535 | leaky_re_lu (LeakyReLU) (None, 2, 256, 256) 0
536 | _________________________________________________________________
537 | conv2d (Conv2D) (None, 4, 128, 128) 76
538 | _________________________________________________________________
539 | leaky_re_lu (LeakyReLU) (None, 4, 128, 128) 0
540 | _________________________________________________________________
541 | conv2d (Conv2D) (None, 8, 64, 64) 296
542 | _________________________________________________________________
543 | leaky_re_lu (LeakyReLU) (None, 8, 64, 64) 0
544 | _________________________________________________________________
545 | conv2d (Conv2D) (None, 16, 32, 32) 1168
546 | _________________________________________________________________
547 | leaky_re_lu (LeakyReLU) (None, 16, 32, 32) 0
548 | _________________________________________________________________
549 | conv2d (Conv2D) (None, 1, 16, 16) 145
550 | _________________________________________________________________
551 | activation (Activation) (None, 1, 16, 16) 0
552 | =================================================================
553 | Total params: 1,813.0
554 | Trainable params: 1,813.0
555 | Non-trainable params: 0.0
556 | _________________________________________________________________
557 | """
558 |     i = Input(shape=(a_ch + b_ch, 512, 512))  # unlike g_unet, channels-first ('th') ordering is hard-coded
559 |
560 | # (a_ch + b_ch) x 512 x 512
561 | conv1 = Convolution(nf)(i)
562 | x = LeakyReLU(0.2)(conv1)
563 | # nf x 256 x 256
564 |
565 | conv2 = Convolution(nf * 2)(x)
566 | x = LeakyReLU(0.2)(conv2)
567 | # nf*2 x 128 x 128
568 |
569 | conv3 = Convolution(nf * 4)(x)
570 | x = LeakyReLU(0.2)(conv3)
571 | # nf*4 x 64 x 64
572 |
573 | conv4 = Convolution(nf * 8)(x)
574 | x = LeakyReLU(0.2)(conv4)
575 | # nf*8 x 32 x 32
576 |
577 | conv5 = Convolution(1)(x)
578 | out = Activation('sigmoid')(conv5)
579 | # 1 x 16 x 16
580 |
581 | d = Model(i, out, name=name)
582 |
583 | def d_loss(y_true, y_pred):
584 | L = objectives.binary_crossentropy(K.batch_flatten(y_true),
585 | K.batch_flatten(y_pred))
586 | return L
587 |
588 | d.compile(optimizer=opt, loss=d_loss)
589 | return d
590 |
591 |
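    | # A hedged sketch (added annotation, not in the original code) of one
    | # discriminator update: the 1 x 16 x 16 output above is a PatchGAN-style
    | # grid of real/fake scores, so real (a, b) pairs get patch labels of ones
    | # and generated pairs get zeros. `atob` stands for a generator such as
    | # the g_unet above; assumes numpy is importable.
    | def _example_d_step(d, atob, a_real, b_real):
    |     import numpy as np  # only needed for this sketch
    |     n = a_real.shape[0]
    |     b_fake = atob.predict(a_real, batch_size=n)
    |     # Stack real pairs over fake pairs along the batch axis; each pair is
    |     # concatenated along the channel axis (1), as the model expects.
    |     x = np.concatenate([np.concatenate([a_real, b_real], axis=1),
    |                         np.concatenate([a_real, b_fake], axis=1)])
    |     y = np.concatenate([np.ones((n, 1, 16, 16)), np.zeros((n, 1, 16, 16))])
    |     return d.train_on_batch(x, y)
    | 
    | 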
592 | def pix2pix(atob, d, a_ch, b_ch, alpha=100, is_a_binary=False,
593 | is_b_binary=False, opt=Adam(lr=2e-4, beta_1=0.5), name='pix2pix'):
594 | # type: (...) -> keras.models.Model
595 | """
596 | Define the pix2pix network.
597 |     :param atob: the generator model mapping images from domain A to B
598 |     :param d: the discriminator model for (A, B) image pairs
599 |     :param a_ch: the number of channels of the A images
600 |     :param b_ch: the number of channels of the B images
601 |     :param alpha: weight of the reconstruction term in the generator loss
602 |     :param is_a_binary: whether the A images are binary
603 |     :param is_b_binary: whether the B images are binary (use BCE instead of L1)
604 |     :param opt: the optimizer used to compile the composite model
605 |     :param name: the name of the composite model
606 |     :return: the compiled model that trains the generator (d is frozen)
607 | >>> K.set_image_dim_ordering('th')
608 | >>> unet = g_unet(3, 4, 2, batch_size=8, is_binary=False)
609 | TheanoShapedU-NET
610 | >>> disc=discriminator(3,4,2)
611 | >>> pp_net=pix2pix(unet, disc, 3, 4)
612 | >>> for ilay in pp_net.layers: ilay.name='_'.join(ilay.name.split('_')[:-1]) # remove layer id
613 | >>> pp_net.summary() #doctest: +NORMALIZE_WHITESPACE
614 | _________________________________________________________________
615 | Layer (type) Output Shape Param #
616 | =================================================================
617 | input (InputLayer) (None, 3, 512, 512) 0
618 | _________________________________________________________________
619 | (Model) (None, 4, 512, 512) 23454
620 | _________________________________________________________________
621 | concatenate (Concatenate) (None, 7, 512, 512) 0
622 | _________________________________________________________________
623 | (Model) (None, 1, 16, 16) 1813
624 | =================================================================
625 | Total params: 25,267.0
626 | Trainable params: 24,859.0
627 | Non-trainable params: 408.0
628 | _________________________________________________________________
629 | """
630 |     a = Input(shape=(a_ch, 512, 512))  # channels-first ('th') ordering, matching the discriminator
631 | b = Input(shape=(b_ch, 512, 512))
632 |
633 | # A -> B'
634 | bp = atob(a)
635 |
636 | # Discriminator receives the pair of images
637 | d_in = concatenate_layers([a, bp], mode='concat', concat_axis=1)
638 |
639 | pix2pix = Model([a, b], d(d_in), name=name)
640 |
641 | def pix2pix_loss(y_true, y_pred):
642 | y_true_flat = K.batch_flatten(y_true)
643 | y_pred_flat = K.batch_flatten(y_pred)
644 |
645 | # Adversarial Loss
646 | L_adv = objectives.binary_crossentropy(y_true_flat, y_pred_flat)
647 |
648 | # A to B loss
649 | b_flat = K.batch_flatten(b)
650 | bp_flat = K.batch_flatten(bp)
651 | if is_b_binary:
652 | L_atob = objectives.binary_crossentropy(b_flat, bp_flat)
653 | else:
654 | L_atob = K.mean(K.abs(b_flat - bp_flat))
655 |
656 | return L_adv + alpha * L_atob
657 |
658 | # This network is used to train the generator. Freeze the discriminator part.
659 | pix2pix.get_layer('d').trainable = False
660 |
661 | pix2pix.compile(optimizer=opt, loss=pix2pix_loss)
662 | return pix2pix
663 |
664 |
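    | # A hedged sketch (added annotation, not in the original code) of the
    | # alternating updates implied above: the discriminator trains on real and
    | # fake pairs, then the composite model (with `d` frozen) updates the
    | # generator against patch labels of ones, minimising
    | # L_adv + alpha * L_atob as defined in pix2pix_loss. Reuses the
    | # hypothetical _example_d_step helper sketched above.
    | def _example_gan_step(d, atob, pp_net, a_real, b_real):
    |     import numpy as np  # only needed for this sketch
    |     n = a_real.shape[0]
    |     d_loss = _example_d_step(d, atob, a_real, b_real)
    |     g_loss = pp_net.train_on_batch([a_real, b_real], np.ones((n, 1, 16, 16)))
    |     return d_loss, g_loss
    | 
    | 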
665 | if __name__ == '__main__':
666 | import doctest
667 |
668 |     TEST_TF = True  # NOTE: KERAS_BACKEND only takes effect if set before Keras is first imported
669 |     if TEST_TF:
670 |         os.environ['KERAS_BACKEND'] = 'tensorflow'
671 |     else:
672 |         os.environ['KERAS_BACKEND'] = 'theano'
673 |     doctest.testmod(verbose=True, optionflags=doctest.ELLIPSIS)
674 |
--------------------------------------------------------------------------------