├── .gitignore ├── README.md ├── apply_patches_to_mask.ipynb ├── extract_patches.py ├── how_to_train.txt ├── images └── output_0_0.png ├── multi_gpu.py ├── train.py └── u_net.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.*~ 3 | .ipynb_checkpoints 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PatchesNet 2 | Carvana Boundaries Refinement 3 | 4 | 5 | ```python 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import glob 9 | import cv2 10 | import os 11 | 12 | N = 64 13 | 14 | ids = [os.path.basename(x) for x in glob.glob('/data/pavel/carv/train/*09*.jpg')] 15 | ids = [x.split('.')[0] for x in ids] 16 | 17 | for id in ids[15:16]: 18 | mask = cv2.imread('/data/pavel/carv/train_masks/%s_mask.png' % id, cv2.IMREAD_GRAYSCALE) / 255.0 19 | mask = np.pad(mask, ((0,0), (1,1)), 'constant') 20 | 21 | border = np.abs(np.gradient(mask)[1]) + np.abs(np.gradient(mask)[0]) 22 | border = np.select([border == 0.5, border != 0.5], [1.0, border]) 23 | 24 | img = cv2.imread('/data/pavel/carv/train/%s.jpg' % id) 25 | img = np.pad(img, ((0,0), (1,1), (0,0)), 'constant') 26 | 27 | height, width = mask.shape 28 | 29 | i = 0 30 | for x, y in zip(np.nonzero(border)[0], np.nonzero(border)[1]): 31 | if i%50 == 0: 32 | cv2.rectangle(img,(y-N-1,x-N-1),(y+N,x+N),(255,255,0),1) 33 | i = i +1 34 | 35 | plt.figure(figsize=(25, 25)) 36 | plt.imshow(img) 37 | plt.show() 38 | 39 | ``` 40 | 41 | 42 | ![png](images/output_0_0.png) 43 | 44 | 45 | 46 | ```python 47 | 48 | ``` 49 | -------------------------------------------------------------------------------- /extract_patches.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import glob 4 | import cv2 5 | import scipy.misc as misc 6 | import os 7 
# Side length, in pixels, of the square patches that are cut out.
N = 256
ROOT_DIR = '/data/pavel/carv'


def mkdir_p(path):
    """Create ``path`` and any missing parents, like ``mkdir -p``.

    An already existing directory is not an error; any other ``OSError``
    (including ``path`` existing as a non-directory) is re-raised.
    """
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        if exc.errno != errno.EEXIST or not os.path.isdir(path):
            raise


def _open_csv_for_writing(path):
    """Open ``path`` in the mode the running interpreter's ``csv`` expects.

    Python 2's csv writer needs a binary file, Python 3's needs a text file
    opened with ``newline=''``; the original unconditional ``'wb'`` only
    worked on Python 2.
    """
    import sys
    if sys.version_info[0] < 3:
        return open(path, 'wb')
    return open(path, 'w', newline='')


def _save_patch(writer, patch_filename, img, mask, x, y):
    """Save the N x N image and mask patch centred on row ``x``, column ``y``.

    Also appends ``[<patch>.jpg, y, x]`` to the CSV so the patch centre (in
    padded-image coordinates) can be recovered later.
    """
    rows = slice(x - N // 2, x + N // 2)
    cols = slice(y - N // 2, y + N // 2)
    misc.imsave(join(TRAIN_FOLDER_PATCHES, patch_filename + '.jpg'),
                img[rows, cols, :])
    misc.imsave(join(TRAIN_FOLDER_MASKS, patch_filename + '.png'),
                mask[rows, cols] * 255)
    writer.writerow([patch_filename + '.jpg', y, x])


TRAIN_FOLDER_PATCHES = join(ROOT_DIR, 'train_patches_' + str(N))
TRAIN_FOLDER_MASKS = join(ROOT_DIR, 'train_patches_masks_' + str(N))
CSVFILENAME = join(ROOT_DIR, 'train_patches_' + str(N) + ".csv")

mkdir_p(TRAIN_FOLDER_PATCHES)
mkdir_p(TRAIN_FOLDER_MASKS)

# Image ids are the file names in train_hq without their extension.
ids = [os.path.basename(x) for x in glob.glob(ROOT_DIR + '/train_hq/*.jpg')]
ids = [x.split('.')[0] for x in ids]
ids.sort()

with _open_csv_for_writing(CSVFILENAME) as csvfile:
    writer = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
    for image_id in tqdm(ids):
        mask = misc.imread(ROOT_DIR + '/train_masks/%s_mask.gif' %
                           image_id, cv2.IMREAD_GRAYSCALE)[..., 0] / 255.0
        # Pad by half a patch so patches centred near the image edge fit.
        mask = np.pad(mask, ((N // 2, N // 2), (N // 2, N // 2)), 'symmetric')

        # Border pixels are those where the mask changes along either axis.
        grad_rows, grad_cols = np.gradient(mask)
        border = np.abs(grad_cols) + np.abs(grad_rows)
        border = np.select([border == 0.5, border != 0.5], [1.0, border])

        img = cv2.imread(ROOT_DIR + '/train_hq/%s.jpg' % image_id)
        img = np.pad(
            img, ((N // 2, N // 2), (N // 2, N // 2), (0, 0)), 'symmetric')

        border_rows, border_cols = np.nonzero(border)

        # Keep every 50th border pixel whose patch lies fully inside the
        # padded image; the running index is part of the patch file name.
        for i, (x, y) in enumerate(zip(border_rows, border_cols)):
            fits = (x - N // 2 >= 0 and y - N // 2 >= 0 and
                    x + N // 2 < img.shape[0] and y + N // 2 < img.shape[1])
            if i % 50 == 0 and fits:
                _save_patch(writer, '%s_%s' % (image_id, i), img, mask, x, y)

        # write a random patch (maybe not touching edge to train patchesnet on false positives outside/inside car)
        # NOTE(review): the flattened source does not show indentation; this is
        # assumed to run once per image, named with the border-pixel count.
        x = random.randint(N // 2, img.shape[0] - N // 2)
        y = random.randint(N // 2, img.shape[1] - N // 2)
        _save_patch(writer, '%s_%s' % (image_id, len(border_rows)), img, mask, x, y)
def to_multi_gpu(model, n_gpus=2):
    """Wrap ``model`` so each batch is split across ``n_gpus`` GPUs.

    A new input is created on the CPU, ``slice_batch`` cuts the batch into
    ``n_gpus`` parts, ``model`` is applied to each part on its own
    ``/gpu:<index>`` device, and the tower outputs are concatenated along the
    batch axis on the CPU.

    Args:
        model: Keras model to replicate.
        n_gpus: Number of GPU towers; with ``1`` the model is returned
            unchanged.

    Returns:
        The wrapping model, whose ``save`` is redirected to the wrapped
        ``model`` so checkpoints contain the underlying network only.
    """
    if n_gpus == 1:
        return model

    with tf.device('/cpu:0'):
        x = Input(model.input_shape[1:])

    towers = []
    for g in range(n_gpus):
        with tf.device('/gpu:' + str(g)):
            slice_g = Lambda(slice_batch, lambda shape: shape,
                             arguments={'n_gpus': n_gpus, 'part': g})(x)
            towers.append(model(slice_g))

    with tf.device('/cpu:0'):
        merged = Concatenate(axis=0)(towers)

    new_model = Model(inputs=[x], outputs=merged)

    # monkeypatch the save to save just the underlying model
    def new_save(self_, filepath, overwrite=True, **kwargs):
        # Extra keyword arguments are forwarded untouched so callers can use
        # whatever options the wrapped model's own save() accepts.
        model.save(filepath, overwrite, **kwargs)

    bound_method_type = type(model.save)
    new_model.save = bound_method_type(new_save, new_model)
    return new_model
mkdir_p(path): 32 | """Utility function emulating mkdir -p.""" 33 | try: 34 | os.makedirs(path) 35 | except OSError as exc: # Python >2.5 36 | if exc.errno == errno.EEXIST and os.path.isdir(path): 37 | pass 38 | else: 39 | raise 40 | 41 | mkdir_p(WEIGHTS_DIR) 42 | 43 | parser = argparse.ArgumentParser() 44 | 45 | # train 46 | parser.add_argument('--max-epoch', type=int, default=200, help='Epoch to run') 47 | parser.add_argument('-l', '--learning-rate', type=float, default=1e-2, help='Initial learning rate, e.g. -l 1e-2') 48 | parser.add_argument('-lw', '--load-weights', type=str, help='load model weights (and continue training)') 49 | parser.add_argument('-lm', '--load-model', type=str, help='load model (and continue training)') 50 | parser.add_argument('-c', '--cpu', action='store_true', help='force CPU usage') 51 | parser.add_argument('-p', '--patch-size', type=int, default=384, help='Patch size, e.g -p 128') 52 | parser.add_argument('-i', '--input-size', type=int, default=384, help='Network input size, e.g -i 256') 53 | parser.add_argument('-ub', '--use-background', action='store_true', help='Use magic background as extra input to NN') 54 | parser.add_argument('-uc', '--use-coarse', action='store_true', help='Use coarse mask as extra input to NN') 55 | parser.add_argument('-o', '--optimizer', type=str, default='sgd', help='Optimizer to use: adam, nadam, sgd, e.g. -o adam') 56 | parser.add_argument('-at', '--augmentation-tps', action='store_true', help='TPS augmentation') 57 | parser.add_argument('-af', '--augmentation-flips', action='store_true', help='Flips augmentation') 58 | parser.add_argument('-s', '--suffix', type=str, default=None, help='Suffix for saving model name') 59 | parser.add_argument('-m', '--model', type=str, default='dilated_unet', help='Use model, e.g. -m dilated_unet -m unet_256, unet_bg_256, largekernels') 60 | parser.add_argument('-f', '--fractional-epoch', type=int, default=1, help='Reduce epoch steps by factor, e.g. 
-f 10 (after 10 epochs all samples would have been seen) ') 61 | 62 | # test / submission 63 | parser.add_argument('-t', '--test', action='store_true', help='Test/Submit') 64 | parser.add_argument('-tb', '--test-background', type=str, default=join(ROOT_DIR, 'test_background_hq_09970'), help='Magic backgrounds folder in PNG format for test, e.g. -tb /data/pavel/carv/test_backgroound_hq') 65 | parser.add_argument('-tc', '--test-coarse', type=str, help='Coarse mask folder in PNG format for test, e.g. -tc /data/pavel/carv/09967_test') 66 | parser.add_argument('-tf', '--test-folder', type=str, default=join(ROOT_DIR, 'test_hq'), help='Test folder e.g. -tc /data/pavel/carv/test_hq') 67 | parser.add_argument('-tppi', '--test-patches-per-image', type=int, default=128, help='Patches per image (rounded to multiple of batch size)') 68 | parser.add_argument('-tts', '--test-total-splits', type=int, default=1, help='Only do the Xth car, e.g. -tts 12 (needs to work with -tcs)') 69 | parser.add_argument('-tcs', '--test-current-split', type=int, default=0, help='Only do the Nth car out of Xth, e.g. -tts 12 -tcs 2') 70 | parser.add_argument('-tsps', '--test-smart-patch-selection', action='store_true', help='Use a smarter way of selecting patches') 71 | 72 | # common 73 | parser.add_argument('-g', '--gpus', type=int, default=1, help='Use GPUs, e.g -g 2') 74 | parser.add_argument('-b', '--batch-size', type=int, default=16, help='Batch Size during training/test, e.g. 
def preprocess_input_imagenet(img):
    """Return ``img`` as float32 with fixed per-channel offsets subtracted.

    The three offsets are broadcast over the last axis, so ``img`` must have
    three channels there. The input array is not modified.
    """
    channel_offsets = np.array([103.939, 116.779, 123.68], dtype=np.float32)
    as_float = img.astype(np.float32)
    return np.subtract(as_float, channel_offsets)
def randomShiftScaleRotate(image, mask,
                           shift_limit=(-0.0625, 0.0625),
                           scale_limit=(-0.1, 0.1),
                           rotate_limit=(-45, 45), aspect_limit=(0, 0),
                           borderMode=cv2.BORDER_CONSTANT, u=0.5):
    """Randomly shift, scale and rotate ``image`` and ``mask`` together.

    With probability ``u`` one random transform is drawn and applied to both
    arrays through ``cv2.warpPerspective``; otherwise both are returned
    untouched.

    Args:
        image: Array with three dimensions (height, width, channels).
        mask: Array warped with the same matrix as ``image``.
        shift_limit: (min, max) shift as a fraction of width / height.
        scale_limit: (min, max) offset added to a scale factor of 1.
        rotate_limit: (min, max) rotation angle in degrees.
        aspect_limit: (min, max) offset added to an aspect ratio of 1.
        borderMode: OpenCV border mode used for pixels warped in from outside.
        u: Probability of applying the transform.

    Returns:
        Tuple ``(image, mask)``.
    """
    # ``np.math`` was only an alias of the standard library module and has
    # been removed from NumPy, so the standard module is used directly.
    import math

    if np.random.random() < u:
        height, width, _ = image.shape

        angle = np.random.uniform(rotate_limit[0], rotate_limit[1])  # degree
        scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
        aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
        sx = scale * aspect / (aspect ** 0.5)
        sy = scale / (aspect ** 0.5)
        dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
        dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)

        cc = math.cos(angle / 180 * math.pi) * sx
        ss = math.sin(angle / 180 * math.pi) * sy
        rotate_matrix = np.array([[cc, -ss], [ss, cc]])

        # Map the four image corners through the rotation about the centre
        # plus the shift; the perspective matrix is fitted to those corners.
        box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ])
        box1 = box0 - np.array([width / 2, height / 2])
        box1 = np.dot(box1, rotate_matrix.T) + \
            np.array([width / 2 + dx, height / 2 + dy])

        box0 = box0.astype(np.float32)
        box1 = box1.astype(np.float32)
        mat = cv2.getPerspectiveTransform(box0, box1)
        image = cv2.warpPerspective(
            image, mat, (width, height), flags=cv2.INTER_CUBIC,
            borderMode=borderMode, borderValue=(0, 0, 0,))
        mask = cv2.warpPerspective(
            mask, mat, (width, height), flags=cv2.INTER_CUBIC,
            borderMode=borderMode, borderValue=(0, 0, 0,))
    return image, mask


def randomHorizontalFlip(image, mask, u=0.5):
    """Mirror ``image`` (and ``mask``, unless it is None) left-right.

    The flip happens with probability ``u``; the pair is returned either way.
    """
    if np.random.random() < u:
        image = np.fliplr(image)
        if mask is not None:
            mask = np.fliplr(mask)

    return image, mask
np.all(background != no_background_color, axis=-1) 221 | selected_background = background[background_index] 222 | background_l2 = np.expand_dims(255 - np.linalg.norm(background - img, axis=2) / np.sqrt(3.), axis=2) 223 | background_mask = np.zeros((PATCH_SIZE, PATCH_SIZE,1), dtype=np.uint8) 224 | background_mask[background_index] = 255 225 | selected_background_l2 = background_l2[background_index] 226 | 227 | if patch_id in stats_dict: 228 | selected_background_mean,selected_background_std = stats_dict[patch_id] 229 | else: 230 | if selected_background.size > 0: 231 | selected_background_mean = np.mean(selected_background_l2) 232 | selected_background_std = np.std(selected_background_l2) 233 | else: 234 | selected_background_mean = np.mean(img) 235 | selected_background_std = np.std(img) 236 | stats_dict[patch_id] = (selected_background_mean, selected_background_std) 237 | 238 | background_l2[~background_index] = \ 239 | np.random.normal(loc=selected_background_mean, scale=selected_background_std, size=(PATCH_SIZE**2-(selected_background.size//3),1)) 240 | 241 | img = np.concatenate([img, background_l2, background_mask], axis=2) 242 | 243 | if args.use_coarse: 244 | car_view_id = id.split('_')[:2] 245 | car_view_file = '{}_{}.png'.format(car_view_id[0], car_view_id[1]) 246 | if car_view_file in coarse_dict: 247 | all_coarse = coarse_dict[car_view_file] 248 | else: 249 | all_coarse = cv2.imread(join(COARSE_FOLDER, car_view_file), cv2.IMREAD_GRAYSCALE) 250 | 251 | if args.threshold_coarse: 252 | all_coarse = 255 * np.rint(all_coarse/255.).astype(np.uint8) 253 | 254 | all_coarse = np.pad(all_coarse, ((PATCH_SIZE // 2, PATCH_SIZE // 2), (PATCH_SIZE // 2, PATCH_SIZE // 2)), 'symmetric') 255 | 256 | #coarse_dict[car_view_file] = all_coarse 257 | patch_id = '{}.jpg'.format(id) 258 | x,y = patches_dict[patch_id] 259 | coarse = np.copy(all_coarse[y-PATCH_SIZE//2:y+PATCH_SIZE//2,x-PATCH_SIZE//2:x+PATCH_SIZE//2]) 260 | img = np.concatenate([img, np.expand_dims(coarse, 
def get_weighted_window(patch_size):
    """Build the per-pixel weights used when blending overlapping patches.

    The weight grows from the patch corners towards the centre. One quadrant
    is computed and mirrored into the other three, so the window is symmetric
    under horizontal and vertical flips.

    Args:
        patch_size: Side length of the square patch; must be even.

    Returns:
        float32 array of shape ``(patch_size, patch_size)``.

    Raises:
        ValueError: If ``patch_size`` is odd.
    """
    half = patch_size // 2
    if patch_size != 2 * half:
        raise ValueError('patch_size must be even, got %r' % (patch_size,))

    ramp = np.arange(1, half + 1, 1)
    squareX, squareY = np.meshgrid(ramp, ramp)
    grid = (squareX + squareY) // 2

    square = np.zeros((patch_size, patch_size), dtype=np.float32)
    square[:half, :half] = grid
    square[half:, :half] = grid[::-1, :]
    square[:half, half:] = grid[:, ::-1]
    # Mirror along both axes. The previous "half + 1 - grid" differs from the
    # mirrored grid wherever row + column is odd, which skewed this quadrant.
    square[half:, half:] = grid[::-1, ::-1]
    return np.sqrt(np.sqrt(square / half))


def rle_encode(pixels):
    """Run-length encode a mask as alternating (start, length) values.

    Values are rounded to the nearest integer and flattened in row-major
    order; starts are 1-based positions in that flattened order. The caller's
    array is left untouched.

    Args:
        pixels: Array-like mask of numeric values.

    Returns:
        1-D integer array ``[start0, length0, start1, length1, ...]``.
    """
    # Work on a private float copy so rounding and the corner fix below never
    # write back into the caller's data.
    flat = np.array(pixels, dtype=np.float64).ravel()
    if flat.size == 0:
        return np.zeros(0, dtype=np.int64)
    np.rint(flat, out=flat)

    # We avoid issues with '1' at the start or end (at the corners of
    # the original image) by setting those pixels to '0' explicitly.
    # We do not expect these to be non-zero for an accurate mask,
    # so this should not harm the score.
    flat[0] = 0
    flat[-1] = 0
    runs = np.where(flat[1:] != flat[:-1])[0] + 2
    # Odd positions hold run ends; turn them into run lengths.
    runs[1::2] = runs[1::2] - runs[:-1:2]
    return runs


def rle_to_string(runs):
    """Join the run values into one space-separated string."""
    return ' '.join(str(x) for x in runs)
= np.pad(all_coarse, ((PATCH_SIZE // 2, PATCH_SIZE // 2), (PATCH_SIZE // 2, PATCH_SIZE // 2)), 'symmetric') 337 | 338 | if args.threshold_coarse: 339 | all_coarse_mask = all_coarse 340 | else: 341 | all_coarse_mask = 255 * np.rint(all_coarse/255.).astype(np.uint8) 342 | 343 | border = np.abs(np.gradient(all_coarse_mask)[1]) + np.abs(np.gradient(all_coarse_mask)[0]) 344 | border = np.select([border == 0.5, border != 0.5], [1.0, border]) 345 | 346 | edges = np.nonzero(border) 347 | 348 | seed = random.randint(0,1000) 349 | edges_x, edges_y = edges[0], edges[1] 350 | n_patches = batch_size * (patches_per_image // batch_size) 351 | 352 | if args.test_smart_patch_selection: 353 | edge_probs = [] 354 | for y,x in zip(edges[0], edges[1]): 355 | B = PATCH_SIZE // 32 356 | x += PATCH_SIZE // 2 357 | y += PATCH_SIZE // 2 358 | edge_prob = 1. / (all_coarse_padded[idx-1, y-B//2:y+B//2,x-B//2:x+B//2].mean() + 1.) 359 | edge_probs.append(edge_prob) 360 | 361 | # TODO: normalize edge_probs so they add up to 1, otherwise -tsps would not work at all 362 | 363 | random.seed(seed) 364 | edges_x = np.random.choice(edges_x, size = n_patches, replace=None, p=edge_probs) 365 | random.seed(seed) 366 | edges_y = np.random.choice(edges_y, size = n_patches, replace=None, p=edge_probs) 367 | else: 368 | random.seed(seed) 369 | random.shuffle(edges_x) 370 | random.seed(seed) 371 | random.shuffle(edges_y) 372 | 373 | edges = edges_x[: n_patches], edges_y[: n_patches] 374 | 375 | i = 0 376 | img_batch = np.empty((batch_size, input_size, input_size, input_channels), dtype=np.float32) 377 | 378 | xy_batch = [] 379 | 380 | for y,x in zip(edges[0], edges[1]): 381 | x = x + PATCH_SIZE // 2 382 | y = y + PATCH_SIZE // 2 383 | 384 | x_l, x_r = x - PATCH_SIZE // 2, x + PATCH_SIZE // 2 385 | y_l, y_r = y - PATCH_SIZE // 2, y + PATCH_SIZE // 2 386 | 387 | img = img_padded[y_l:y_r, x_l:x_r, :] 388 | 389 | if X: 390 | background = np.copy(all_background_padded[y_l:y_r, x_l:x_r,:]) 391 | 392 | 
no_background_color = (255,0,255) 393 | background_index = np.all(background != no_background_color, axis=-1) 394 | selected_background = background[background_index] 395 | background_l2 = np.expand_dims(255 - np.linalg.norm(background - img, axis=2) / np.sqrt(3.), axis=2) 396 | background_mask = np.zeros((PATCH_SIZE, PATCH_SIZE,1), dtype=np.uint8) 397 | background_mask[background_index] = 255 398 | selected_background_l2 = background_l2[background_index] 399 | 400 | if selected_background.size > 0: 401 | selected_background_mean = np.mean(selected_background_l2) 402 | selected_background_std = np.std(selected_background_l2) 403 | else: 404 | selected_background_mean = np.mean(img) 405 | selected_background_std = np.std(img) 406 | 407 | background_l2[~background_index] = \ 408 | np.random.normal(loc=selected_background_mean, scale=selected_background_std, size=(PATCH_SIZE**2-(selected_background.size//3),1)) 409 | 410 | img = np.concatenate([img, background_l2, background_mask], axis=2) 411 | 412 | if CO: 413 | coarse = np.copy(all_coarse_padded_batch[idx-1, y_l:y_r, x_l:x_r]) 414 | img = np.concatenate([img, np.expand_dims(coarse, axis=2)], axis=2) 415 | 416 | if (input_size, input_size) != img.shape[:2]: 417 | img = cv2.resize(img, (input_size, input_size), interpolation=cv2.INTER_CUBIC) 418 | 419 | if img.shape[:2] != (PATCH_SIZE, PATCH_SIZE): 420 | print(id) 421 | 422 | img = preprocess_for_model(img.astype(np.float32)) 423 | flip = random.randint(0,1) 424 | if flip: 425 | img = np.fliplr(img) 426 | 427 | img_batch[i,...] 
= copy.deepcopy(img) 428 | xy_batch.append(copy.deepcopy((x_l, x_r, y_l, y_r, flip, patch_id))) 429 | 430 | i += 1 431 | patch_id += 1 432 | 433 | if i == batch_size: 434 | 435 | yield(img_batch) 436 | #print("Yield:", car_id, idx) 437 | i = 0 438 | 439 | car_xy_flip_batch.append(copy.deepcopy(xy_batch)) 440 | 441 | batch_lock.release() 442 | 443 | # this is a workaround to make Keras max queue size happy 444 | while True: 445 | yield(img_batch) 446 | 447 | 448 | weighted_window = get_weighted_window(PATCH_SIZE) 449 | 450 | split_preffix = '' 451 | if args.test_total_splits != 1: 452 | split_preffix = str(args.test_current_split) + "_of_" + str(args.test_total_splits) 453 | 454 | with open(split_preffix + csv_filename, 'wb') as csvfile: 455 | writer = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) 456 | 457 | if args.test_current_split == 0: 458 | writer.writerow(['img', 'rle_mask']) 459 | 460 | 461 | for car_id in tqdm(ids[args.test_current_split::args.test_total_splits]): 462 | 463 | _car_xy_flip = [] 464 | _all_coarse_padded = np.zeros((16,1280+PATCH_SIZE,1918+PATCH_SIZE), dtype=np.uint8) 465 | 466 | patches_probs = model.predict_generator( 467 | patches_generator(car_id, _all_coarse_padded, _car_xy_flip), 468 | steps = 16 * patches_per_image // batch_size, 469 | max_queue_size = 1) 470 | 471 | batch_lock.acquire() 472 | all_coarse_padded = copy.deepcopy(_all_coarse_padded) 473 | car_xy_flip = copy.deepcopy(_car_xy_flip) 474 | #print("FINISHED CALLING GEN") 475 | #print(patches_probs.shape) 476 | #print(car_xy_flip) 477 | #print(len(car_xy_flip)) 478 | 479 | idx = 1 480 | for xy_batch in car_xy_flip: 481 | 482 | car_view_file = '{}_{:02d}'.format(car_id, idx) 483 | 484 | probabilities_padded = np.zeros((1280+PATCH_SIZE,1918+PATCH_SIZE), dtype=np.float32) 485 | weights_padded = np.zeros((1280+PATCH_SIZE,1918+PATCH_SIZE), dtype=np.float32) 486 | 487 | for (x_l, x_r, y_l, y_r, flip, patch_id) in xy_batch: 488 | #print(x_l, x_r, y_l, 
y_r, flip) 489 | patch_probs = np.squeeze(patches_probs[patch_id], axis=2) 490 | if flip: 491 | patch_probs = np.fliplr(patch_probs) 492 | probabilities_padded[y_l:y_r, x_l:x_r] += np.multiply(patch_probs, weighted_window) 493 | weights_padded[y_l:y_r, x_l:x_r] += weighted_window 494 | 495 | zero_weights = (weights_padded == 0) 496 | weights_padded[zero_weights] = 1. 497 | probabilities_padded /= weights_padded 498 | probabilities_padded[zero_weights] = all_coarse_padded[idx-1, zero_weights] / 255. 499 | 500 | probabilities = probabilities_padded[PATCH_SIZE//2:-PATCH_SIZE//2, PATCH_SIZE//2:-PATCH_SIZE//2] 501 | 502 | cv2.imwrite(join(save_pngs_to_folder, car_view_file + ".png"), probabilities*255.) 503 | 504 | rle = rle_encode(probabilities) 505 | writer.writerow([car_view_file + ".jpg", rle_to_string(rle)]) 506 | idx += 1 507 | 508 | batch_lock.release() 509 | 510 | initial_epoch = 0 511 | 512 | if args.load_model: 513 | print("Loading model " + args.load_model) 514 | 515 | # monkey-patch loss so model loads ok 516 | # https://github.com/fchollet/keras/issues/5916#issuecomment-290344248 517 | import keras.losses 518 | import keras.metrics 519 | keras.losses.bce_dice_loss = bce_dice_loss 520 | keras.metrics.dice_loss = dice_loss 521 | keras.metrics.dice_loss100 = dice_loss100 522 | 523 | model = load_model(args.load_model, compile=False, custom_objects = { 'Scale' : Scale}) 524 | match = re.search(r'patchesnet-([_a-zA-Z]+)-epoch(\d+)-.*', args.load_model) 525 | model_name = match.group(1).split("__")[0] 526 | initial_epoch = int(match.group(2)) + 1 527 | 528 | input_dimensions = model.get_input_shape_at(0)[1:] 529 | print(input_dimensions) 530 | assert input_dimensions[:2] == (args.input_size, args.input_size) 531 | 532 | name_dict = { 533 | 'rgb' : (False, False, 3), 534 | 'rgbCO' : (False, True, 4), 535 | 'rgbX' : (True, False, 5), 536 | 'rgbXCO' : (True, True, 6) } 537 | 538 | X, CO, input_channels = name_dict[model.layers[1].name.split("_")[0]] 539 | 540 | if 
not args.test: 541 | assert args.use_background == X 542 | assert args.use_coarse == CO 543 | 544 | else: 545 | model_name = args.model 546 | input_channels = 3 547 | if args.use_background: 548 | input_channels += 2 549 | if args.use_coarse: 550 | input_channels += 1 551 | 552 | model = getattr(u_net, 'get_'+ model_name)(input_shape=(input_size, input_size, input_channels)) 553 | 554 | if args.load_weights: 555 | model.load_weights(args.load_weights, by_name=True) 556 | 557 | model.summary() 558 | 559 | if args.suffix is None: 560 | suffix = "__rgb" 561 | if args.use_background: 562 | suffix += "X" 563 | if args.use_coarse: 564 | suffix += "CO" 565 | else: 566 | suffix = "__" + args.suffix 567 | 568 | 569 | if args.gpus != 1: 570 | model = to_multi_gpu(model,n_gpus=args.gpus) 571 | 572 | if args.test: 573 | model_basename = args.load_model.split('/')[-1] 574 | mkdir_p('test_' + model_basename) 575 | test_model(model, ids, X, CO, 576 | patches_per_image = args.test_patches_per_image, 577 | batch_size = args.batch_size, 578 | csv_filename = model_basename + '.csv', 579 | save_pngs_to_folder = 'test_' + model_basename, 580 | input_channels = input_channels) 581 | 582 | else: 583 | 584 | callbacks = [ReduceLROnPlateau(monitor='val_dice_loss100', 585 | factor=0.5, 586 | patience=4, 587 | verbose=1, 588 | epsilon=1e-4, 589 | mode='max'), 590 | ModelCheckpoint(monitor='val_dice_loss100', 591 | filepath=join(WEIGHTS_DIR,"patchesnet-"+ model_name + suffix + "-epoch{epoch:02d}-val_dice{val_dice_loss:.6f}"), 592 | save_best_only=False, 593 | save_weights_only=False, 594 | mode='max')] 595 | 596 | 597 | 598 | if args.optimizer == 'adam': 599 | optimizer=Adam(lr=args.learning_rate) 600 | elif args.optimizer == 'nadam': 601 | optimizer=Nadam(lr=args.learning_rate) 602 | elif args.optimizer == 'rmsprop': 603 | optimizer=RMSprop(lr=args.learning_rate) 604 | elif args.optimizer == 'sgd': 605 | optimizer=SGD(lr=args.learning_rate, momentum=0.9) 606 | else: 607 | assert False 608 | 
def weighted_bce_loss(y_true, y_pred, weight):
    """Weighted binary cross-entropy, normalised by the total weight.

    ``y_pred`` holds probabilities. They are clipped away from 0 and 1 and
    turned back into logits so that the formulation used by
    ``tf.nn.weighted_cross_entropy_with_logits`` can be applied, with
    ``weight`` acting as the per-element positive-class weight.

    Returns a scalar tensor: sum of the per-element losses divided by the
    sum of ``weight``.
    """
    eps = 1e-7  # keeps the logit finite when the prediction saturates
    clipped = K.clip(y_pred, eps, 1. - eps)
    logits = K.log(clipped / (1. - clipped))

    # https://www.tensorflow.org/api_docs/python/tf/nn/weighted_cross_entropy_with_logits
    positive_scale = 1. + (weight - 1.) * y_true
    stable_softplus = K.log(1. + K.exp(-K.abs(logits))) + K.maximum(-logits, 0.)
    per_element = (1. - y_true) * logits + positive_scale * stable_softplus
    return K.sum(per_element) / K.sum(weight)


def weighted_dice_loss(y_true, y_pred, weight):
    """Return one minus the smoothed Dice score, weighting by ``weight ** 2``."""
    smooth = 1.
    squared_weight = weight * weight
    numerator = 2. * K.sum(squared_weight * (y_true * y_pred)) + smooth
    denominator = (K.sum(squared_weight * y_true) +
                   K.sum(squared_weight * y_pred) + smooth)
    return 1. - K.sum(numerator / denominator)


def weighted_bce_dice_loss(y_true, y_pred):
    """Dice loss that gives extra weight to elements near the mask boundary.

    The boundary is found by average-pooling the ground truth with an 11x11
    window: positions whose local mean is strictly between 0.005 and 0.995
    are treated as border and get weight 3 instead of 1. The weight map is
    then rescaled so that its sum equals the number of elements.

    Only the weighted Dice term contributes to the returned value; no
    cross-entropy term is added.
    """
    y_true = K.cast(y_true, 'float32')
    y_pred = K.cast(y_pred, 'float32')

    # An odd window size keeps the pooled map the same size as the input.
    local_mean = K.pool2d(y_true, pool_size=(11, 11), strides=(1, 1),
                          padding='same', pool_mode='avg')
    not_all_background = K.cast(K.greater(local_mean, 0.005), 'float32')
    not_all_foreground = K.cast(K.less(local_mean, 0.995), 'float32')
    border = not_all_background * not_all_foreground

    uniform = K.ones_like(local_mean)
    boosted = uniform + border * 2
    weight = boosted * (K.sum(uniform) / K.sum(boosted))

    return weighted_dice_loss(y_true, y_pred, weight)


def _smoothed_overlap_ratio(y_true, y_pred, numerator_scale):
    """Return ``(scale * |A.B| + 1) / (|A| + |B| + 1)`` over flattened tensors."""
    smooth = 1.
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    return (numerator_scale * overlap + smooth) / (K.sum(truth) + K.sum(pred) + smooth)


def dice_loss(y_true, y_pred):
    """Return the smoothed Dice coefficient of the whole batch.

    Despite the name this is the coefficient itself (larger means more
    overlap), not one minus it.
    """
    return _smoothed_overlap_ratio(y_true, y_pred, 2.)


def dice_loss100(y_true, y_pred):
    """Return the Dice coefficient with its intersection term scaled by 100.

    The smoothing constant is not scaled, so the value is close to, but not
    exactly, ``100 * dice_loss``.
    """
    return _smoothed_overlap_ratio(y_true, y_pred, 200.)
def jaccard_coef(y_true, y_pred):
    """Return the mean soft Jaccard index (intersection over union).

    Intersection and union are summed over axes 0, -1 and -2; the ratios
    that remain are averaged into a scalar. The raw (unrounded) predictions
    are used.
    """
    # __author__ = Vladimir Iglovikov
    smooth = 1e-12
    intersection = K.sum(y_true * y_pred, axis=[0, -1, -2])
    sum_ = K.sum(y_true + y_pred, axis=[0, -1, -2])

    jac = (intersection + smooth) / (sum_ - intersection + smooth)

    return K.mean(jac)


def bce_dice_loss(y_true, y_pred):
    """Return the training loss: one minus the smoothed Dice coefficient.

    Despite the name, no binary cross-entropy term is included.
    """
    return 1 - dice_loss(y_true, y_pred)


def get_unet_128(input_shape=(128, 128, 3),
                 num_classes=1):
    """Build and compile a U-Net for 128x128 inputs.

    The encoder has four stages (64, 128, 256, 512 filters), each with two
    3x3 conv + batch-norm + ReLU units followed by 2x2 max pooling. The
    bottleneck uses 1024 filters. Each decoder stage upsamples by 2,
    concatenates the matching encoder output on axis 3 and applies three
    conv units.

    Args:
        input_shape: shape passed to ``Input`` (without the batch axis).
        num_classes: number of output channels of the final 1x1 sigmoid conv.

    Returns:
        A Keras ``Model`` already compiled with SGD (lr=0.01, momentum=0.9),
        ``bce_dice_loss`` as loss and ``dice_loss`` / ``dice_loss100`` as
        metrics.
    """
    def _conv_bn_relu(tensor, filters, repeats):
        # Layers are created strictly in order so that Keras' automatically
        # generated layer names stay the same as with hand-unrolled code.
        for _ in range(repeats):
            tensor = Conv2D(filters, (3, 3), padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation('relu')(tensor)
        return tensor

    inputs = Input(shape=input_shape)

    # Encoder: 128 -> 64 -> 32 -> 16 -> 8
    net = inputs
    skips = []
    for filters in (64, 128, 256, 512):
        net = _conv_bn_relu(net, filters, 2)
        skips.append(net)
        net = MaxPooling2D((2, 2), strides=(2, 2))(net)

    # Bottleneck at 8x8
    net = _conv_bn_relu(net, 1024, 2)

    # Decoder: 8 -> 16 -> 32 -> 64 -> 128
    for filters in (512, 256, 128, 64):
        net = UpSampling2D((2, 2))(net)
        net = concatenate([skips.pop(), net], axis=3)
        net = _conv_bn_relu(net, filters, 3)

    classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(net)

    model = Model(inputs=inputs, outputs=classify)

    model.compile(optimizer=SGD(lr=0.01, momentum=0.9),
                  loss=bce_dice_loss, metrics=[dice_loss, dice_loss100])

    return model
def get_unet_256(input_shape=(256, 256, 3), naive_upsampling=True, num_classes=1):
    """Build an (uncompiled) U-Net for 256x256 patches.

    The input is instance-normalised, encoded through five stages
    (32, 64, 128, 256, 512 filters, two conv units each, 2x2 max pooling),
    passed through a 1024-filter bottleneck and decoded back with three conv
    units per stage. An extra three-unit, 32-filter branch computed directly
    from the normalised input is concatenated into the last decoder stage.

    Layers that touch the raw input are named with a prefix derived from the
    channel count ('rgb_', 'rgbCO_', 'rgbX_', 'rgbXCO_'); all other layers
    keep Keras' automatic names.

    Args:
        input_shape: (height, width, channels); channels must be 3, 4, 5 or 6.
        naive_upsampling: if true, decoder stages use ``UpSampling2D``;
            otherwise a 2x2, stride-2 ``Conv2DTranspose`` that keeps the
            channel count of its input.
        num_classes: output channels of the final 1x1 sigmoid conv. The
            body always referenced this name but the signature lacked it,
            so every call failed with ``NameError``; it is now a keyword
            argument defaulting to 1.

    Returns:
        An uncompiled Keras ``Model``.

    Raises:
        KeyError: if ``input_shape[2]`` is not one of 3, 4, 5, 6.
    """
    def _conv_bn_relu(tensor, filters, repeats, first_name=None):
        # Creation order matters: automatic layer names are numbered in the
        # order layers are instantiated.
        for idx in range(repeats):
            conv_kwargs = {}
            if idx == 0 and first_name:
                conv_kwargs['name'] = first_name
            tensor = Conv2D(filters, (3, 3), padding='same', **conv_kwargs)(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation('relu')(tensor)
        return tensor

    def _upsample(tensor, filters):
        if naive_upsampling:
            return UpSampling2D((2, 2))(tensor)
        return Conv2DTranspose(filters, kernel_size=(2, 2), strides=(2, 2))(tensor)

    inputs = Input(shape=input_shape)
    bg_preffix_dict = {
        3: 'rgb_',      # RGB
        4: 'rgbCO_',    # RGB + coarse mask
        5: 'rgbX_',     # RGB + BG
        6: 'rgbXCO_'}   # RGB + BG + coarse mask
    bg_preffix = bg_preffix_dict[input_shape[2]]

    # bn_axis is deliberately assigned at module scope, as before.
    # NOTE(review): presumably read by helpers elsewhere in this file - confirm.
    global bn_axis
    if K.image_dim_ordering() == 'tf':
        bn_axis = 3
    else:
        bn_axis = 1

    inputs_normalized = InstanceNormalization(
        axis=bn_axis, name=bg_preffix + 'instancenormalization1')(inputs)

    # Encoder: 256 -> 128 -> 64 -> 32 -> 16 -> 8
    net = inputs_normalized
    skips = []
    for depth, filters in enumerate((32, 64, 128, 256, 512)):
        first_name = bg_preffix + "conv1" if depth == 0 else None
        net = _conv_bn_relu(net, filters, 2, first_name=first_name)
        skips.append(net)
        net = MaxPooling2D((2, 2), strides=(2, 2))(net)

    # Bottleneck at 8x8
    net = _conv_bn_relu(net, 1024, 2)

    # Decoder up to 128x128; a transposed conv keeps its input's channel
    # count, which is twice the stage's conv width.
    for filters in (512, 256, 128, 64):
        net = _upsample(net, 2 * filters)
        net = concatenate([skips.pop(), net], axis=3)
        net = _conv_bn_relu(net, filters, 3)

    # Full-resolution side branch, created before the last upsampling step.
    down0xtra = _conv_bn_relu(inputs_normalized, 32, 3,
                              first_name=bg_preffix + "down0_xtra_conv1")

    # Final decoder stage at 256x256
    net = _upsample(net, 64)
    net = concatenate([skips.pop(), net, down0xtra], axis=3)
    net = _conv_bn_relu(net, 32, 3)

    classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(net)

    model = Model(inputs=inputs, outputs=classify)

    return model
padding='same')(up1) 347 | up1 = BatchNormalization()(up1) 348 | up1 = Activation('relu')(up1) 349 | # 128 350 | 351 | down0xtra = Conv2D(32, (3, 3), padding='same', name=bg_preffix + "down0_xtra_conv1")(inputs_normalized) 352 | down0xtra = BatchNormalization()(down0xtra) 353 | down0xtra = Activation('relu')(down0xtra) 354 | down0xtra = Conv2D(32, (3, 3), padding='same')(down0xtra) 355 | down0xtra = BatchNormalization()(down0xtra) 356 | down0xtra = Activation('relu')(down0xtra) 357 | down0xtra = Conv2D(32, (3, 3), padding='same')(down0xtra) 358 | down0xtra = BatchNormalization()(down0xtra) 359 | down0xtra = Activation('relu')(down0xtra) 360 | 361 | if naive_upsampling: 362 | up0 = UpSampling2D((2, 2))(up1) 363 | else: 364 | up0 = Conv2DTranspose(64, kernel_size=(2,2), strides=(2,2))(up1) 365 | 366 | up0 = concatenate([down0, up0, down0xtra], axis=3) 367 | up0 = Conv2D(32, (3, 3), padding='same')(up0) 368 | up0 = BatchNormalization()(up0) 369 | up0 = Activation('relu')(up0) 370 | up0 = Conv2D(32, (3, 3), padding='same')(up0) 371 | up0 = BatchNormalization()(up0) 372 | up0 = Activation('relu')(up0) 373 | up0 = Conv2D(32, (3, 3), padding='same')(up0) 374 | up0 = BatchNormalization()(up0) 375 | up0 = Activation('relu')(up0) 376 | # 256 377 | 378 | classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(up0) 379 | 380 | model = Model(inputs=inputs, outputs=classify) 381 | 382 | return model 383 | 384 | def get_resunet(input_shape=(256, 256, 3), naive_upsampling= True, full_residual = False): 385 | 386 | inputs = Input(shape=input_shape) 387 | bg_preffix_dict = { 388 | 3 : 'rgb_', # RGB 389 | 4 : 'rgbCO_', # RGB + coarse mask 390 | 5 : 'rgbX_', # RGB + BG 391 | 6 : 'rgbXCO_' } # RGB + BG + coarse mask 392 | bg_preffix = bg_preffix_dict[input_shape[2]] 393 | 394 | global bn_axis 395 | if K.image_dim_ordering() == 'tf': 396 | bn_axis = 3 397 | else: 398 | bn_axis = 1 399 | 400 | inputs_normalized = InstanceNormalization(axis=bn_axis, name=bg_preffix + 
'instancenormalization1')(inputs) 401 | 402 | down = inputs_normalized 403 | 404 | down_blocks = [32, 64, 128, 256, 512, 1024] 405 | skips = [] 406 | 407 | scale = False 408 | 409 | # 7 -> 14 + 7 -> 42 + 7 -> 98 +7 -> 210 + 7 = 217 410 | # 512 -> 256 -> 128 -> 64 -> 32 411 | 412 | trf = 0 413 | erf = 0 414 | residual_filter_factor = 4 if full_residual else 1 415 | for i, block in enumerate(down_blocks): 416 | first = i == 0 417 | last = i == (len(down_blocks) -1) 418 | 419 | dilation = [1,1,1] 420 | k = 3 421 | 422 | def unet_block(input, kernel_size, filters, stage, block, filter_start=0, preffix=""): 423 | x = input 424 | for i, nb_filter in enumerate(filters): 425 | suffix = str(stage) + "_" + block + "_" + str(i+filter_start) 426 | x = Conv2D(nb_filter, kernel_size, padding='same', name=preffix + 'conv_' + suffix)(x) 427 | x = BatchNormalization(name=preffix + 'bn_' + suffix)(x) 428 | x = Activation('relu', name=preffix + 'act_' + suffix)(x) 429 | return x 430 | 431 | if not full_residual: 432 | down = unet_block(down, k, [block, block], stage=i, block='down', preffix = bg_preffix if first else '') 433 | down = conv_block(down, k, [block, block, block*residual_filter_factor], stage=i, block='down', strides=(1, 1), preffix=bg_preffix if first else '', zero_padding=False, scale=scale) 434 | 435 | if full_residual: 436 | down = conv_block(down, k, [block, block, block*residual_filter_factor], stage=i, block='down', strides=(1, 1), preffix=bg_preffix if first else '', zero_padding=False, scale=scale) 437 | down = identity_block(down, k, [block, block, block*residual_filter_factor], stage=i, block='down0', preffix='', zero_padding=False, scale=scale) 438 | down = identity_block(down, k, [block, block, block*residual_filter_factor], stage=i, block='down1', preffix='', zero_padding=False, scale=scale) 439 | rf = 1 + (k-1) * dilation[0] + (k-1) * dilation[1]+ (k-1) * dilation[2] 440 | trf += rf 441 | erf += rf / math.sqrt(3) 442 | 443 | if not last: 444 | 
def get_unet_background_256(input_shape=(256, 256, 6),
                            num_classes=1):
    """Build an (uncompiled) U-Net for 256x256 inputs with six channels by default.

    The input is instance-normalised on axis 3, encoded through five stages
    (32 to 512 filters, two conv units each), passed through a 1024-filter
    bottleneck and decoded with plain ``UpSampling2D`` and three conv units
    per stage. A three-unit, 32-filter branch computed from the normalised
    input is concatenated into the last decoder stage.

    Args:
        input_shape: shape passed to ``Input`` (without the batch axis).
        num_classes: output channels of the final 1x1 sigmoid conv.

    Returns:
        An uncompiled Keras ``Model``.
    """
    def _stack(tensor, width, depth):
        # conv -> batch norm -> ReLU, repeated `depth` times
        for _ in range(depth):
            tensor = Conv2D(width, (3, 3), padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation('relu')(tensor)
        return tensor

    inputs = Input(shape=input_shape)
    normalized = InstanceNormalization(axis=3)(inputs)

    # Encoder: 256 -> 128 -> 64 -> 32 -> 16 -> 8
    encoder_outputs = []
    net = normalized
    for width in (32, 64, 128, 256, 512):
        net = _stack(net, width, 2)
        encoder_outputs.append(net)
        net = MaxPooling2D((2, 2), strides=(2, 2))(net)

    # Bottleneck at 8x8
    net = _stack(net, 1024, 2)

    # Decoder up to 128x128
    for width in (512, 256, 128, 64):
        net = UpSampling2D((2, 2))(net)
        net = concatenate([encoder_outputs.pop(), net], axis=3)
        net = _stack(net, width, 3)

    # Full-resolution side branch, built before the last upsampling layer
    side_branch = _stack(normalized, 32, 3)

    # Final decoder stage at 256x256
    net = UpSampling2D((2, 2))(net)
    net = concatenate([encoder_outputs.pop(), net, side_branch], axis=3)
    net = _stack(net, 32, 3)

    classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(net)

    return Model(inputs=inputs, outputs=classify)
def get_unet_512(input_shape=(512, 512, 3),
                 num_classes=1):
    """Build and compile a U-Net for 512x512 inputs.

    There are six encoder stages (32 to 1024 filters) with two conv units
    each, a 2048-filter bottleneck and six decoder stages with three conv
    units each. All kernels are 3x3 except the very first and the very last
    conv unit, which use 7x7.

    Args:
        input_shape: shape passed to ``Input`` (without the batch axis).
        num_classes: output channels of the final 1x1 sigmoid conv.

    Returns:
        A Keras ``Model`` compiled with SGD (lr=0.001, momentum=0.99),
        ``bce_dice_loss`` as loss and ``dice_loss`` as metric.
    """
    def _stack(tensor, width, kernel_sizes):
        # One conv -> batch norm -> ReLU unit per entry of `kernel_sizes`
        for size in kernel_sizes:
            tensor = Conv2D(width, (size, size), padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation('relu')(tensor)
        return tensor

    # (filters, kernel sizes) per stage, listed in creation order
    encoder_spec = [
        (32, (7, 3)),      # 512
        (64, (3, 3)),      # 256
        (128, (3, 3)),     # 128
        (256, (3, 3)),     # 64
        (512, (3, 3)),     # 32
        (1024, (3, 3)),    # 16
    ]
    decoder_spec = [
        (1024, (3, 3, 3)),  # 16
        (512, (3, 3, 3)),   # 32
        (256, (3, 3, 3)),   # 64
        (128, (3, 3, 3)),   # 128
        (64, (3, 3, 3)),    # 256
        (32, (3, 3, 7)),    # 512
    ]

    inputs = Input(shape=input_shape)

    net = inputs
    encoder_outputs = []
    for width, kernel_sizes in encoder_spec:
        net = _stack(net, width, kernel_sizes)
        encoder_outputs.append(net)
        net = MaxPooling2D((2, 2), strides=(2, 2))(net)

    # Bottleneck at 8x8
    net = _stack(net, 2048, (3, 3))

    for width, kernel_sizes in decoder_spec:
        net = UpSampling2D((2, 2))(net)
        net = concatenate([encoder_outputs.pop(), net], axis=3)
        net = _stack(net, width, kernel_sizes)

    classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(net)

    model = Model(inputs=inputs, outputs=classify)

    model.compile(optimizer=SGD(lr=0.001, momentum=0.99),
                  loss=bce_dice_loss, metrics=[dice_loss])

    return model
def get_unet_1024(input_shape=(1024, 1024, 3),
                  num_classes=1, mult=2):
    """Build and compile a U-Net for 1024x1024 inputs.

    The encoder has seven stages whose widths are 4, 8, 16, 32, 64, 128 and
    256 times `mult`; each stage applies two 3x3 conv -> batch-norm -> ReLU
    units and is followed by a 2x2 max pooling.  The bottleneck uses
    512 * mult filters.  Every decoder stage upsamples by 2, concatenates the
    matching encoder output on axis 3 and applies three conv units of the
    encoder stage's width.  Before the final 1x1 sigmoid convolution (layer
    name 'newsigmoid') the decoder output is concatenated with the raw
    network input.

    Arguments:
        input_shape - shape passed to the `Input` layer
        num_classes - number of filters of the final 1x1 convolution
        mult        - multiplier applied to every stage width

    Output:
        Model compiled with SGD(lr=0.001, momentum=0.9), `bce_dice_loss` as
        loss and `dice_loss` / `dice_loss100` as metrics.
    """
    def conv_stack(tensor, filters, repeats):
        # `repeats` consecutive 3x3 conv -> batch-norm -> ReLU units.
        for _ in range(repeats):
            tensor = Conv2D(filters, (3, 3), padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation('relu')(tensor)
        return tensor

    inputs = Input(shape=input_shape)

    # Encoder: remember each pre-pooling tensor for the skip connections.
    # Layers are created in the same order as a hand-unrolled network would
    # create them, so Keras' auto-generated layer names are unaffected.
    skip_connections = []
    net = inputs
    for width in (4, 8, 16, 32, 64, 128, 256):
        net = conv_stack(net, width * mult, 2)
        skip_connections.append((width, net))
        net = MaxPooling2D((2, 2), strides=(2, 2))(net)

    # Bottleneck.
    net = conv_stack(net, 512 * mult, 2)

    # Decoder: walk the skip connections from the deepest to the shallowest.
    for width, skip in reversed(skip_connections):
        net = UpSampling2D((2, 2))(net)
        net = concatenate([skip, net], axis=3)
        net = conv_stack(net, width * mult, 3)

    # NOTE: a chain of residual refinement blocks on the full-resolution
    # features was experimented with here and is currently disabled.
    net = concatenate([net, inputs])

    classify = Conv2D(
        num_classes, (1, 1), activation='sigmoid', name='newsigmoid')(net)

    model = Model(inputs=inputs, outputs=classify)

    # model = to_multi_gpu(model,n_gpus=2)

    model.compile(optimizer=SGD(lr=0.001, momentum=0.9),
                  loss=bce_dice_loss, metrics=[dice_loss, dice_loss100])

    return model
def downsampling_block(input_tensor, filters, padding='valid',
                       batchnorm=False, dropout=0.0):
    """Encoder stage: two 3x3 convolutions followed by a 2x2 max pooling.

    The first convolution uses dilation rate 1, the second dilation rate 2.
    Each convolution is followed by optional batch normalization, a ReLU and
    optional dropout.

    Arguments:
        input_tensor - 4-D tensor (its static shape is unpacked into four
                       values)
        filters      - number of filters of both convolutions
        padding      - padding mode passed to both convolutions
        batchnorm    - insert BatchNormalization before each activation
        dropout      - dropout rate; no Dropout layer is added when <= 0

    Output:
        Tuple (pooled, features): the max-pooled tensor and the tensor right
        before pooling (used as skip connection).
    """
    _, height, width, _ = K.int_shape(input_tensor)
    print(height, width)

    features = input_tensor
    for rate in (1, 2):
        features = Conv2D(filters, kernel_size=(3, 3), padding=padding,
                          dilation_rate=rate)(features)
        if batchnorm:
            features = BatchNormalization()(features)
        features = Activation('relu')(features)
        if dropout > 0:
            features = Dropout(dropout)(features)

    pooled = MaxPooling2D(pool_size=(2, 2))(features)
    return pooled, features
def upsampling_block(input_tensor, skip_tensor, filters, padding='valid',
                     batchnorm=False, dropout=0.0):
    """Decoder stage: transposed-conv upsampling, skip merge, two 3x3 convs.

    `input_tensor` is upsampled with a 2x2, stride-2 transposed convolution.
    `skip_tensor` is center-cropped to the upsampled size when it is larger,
    both are concatenated on axis 3, and two undilated 3x3 convolutions (each
    with optional batch normalization, ReLU and optional dropout) follow.

    Arguments:
        input_tensor - tensor coming from the deeper stage
        skip_tensor  - encoder tensor to merge; must not be spatially smaller
                       than the upsampled tensor (checked with `assert`)
        filters      - number of filters of the transposed and regular convs
        padding      - padding mode of the two 3x3 convolutions
        batchnorm    - insert BatchNormalization before each activation
        dropout      - dropout rate; no Dropout layer is added when <= 0

    Output:
        The tensor produced by the second convolution unit.
    """
    upsampled = Conv2DTranspose(filters, kernel_size=(2, 2),
                                strides=(2, 2))(input_tensor)

    # How much larger is the skip tensor than the upsampled one?
    _, up_height, up_width, _ = K.int_shape(upsampled)
    _, skip_height, skip_width, _ = K.int_shape(skip_tensor)
    extra_h = skip_height - up_height
    extra_w = skip_width - up_width
    assert extra_h >= 0
    assert extra_w >= 0

    if extra_h or extra_w:
        # Split the surplus as evenly as possible; the odd pixel, if any,
        # is removed from the bottom / right side.
        top = extra_h // 2
        left = extra_w // 2
        cropping = ((top, extra_h - top), (left, extra_w - left))
        skip = Cropping2D(cropping=cropping)(skip_tensor)
    else:
        skip = skip_tensor

    print(K.int_shape(upsampled))
    print(K.int_shape(skip))

    merged = Concatenate(axis=3)([upsampled, skip])

    # no dilation in upsampling convolutions
    for _ in range(2):
        merged = Conv2D(filters, kernel_size=(3, 3), padding=padding)(merged)
        if batchnorm:
            merged = BatchNormalization()(merged)
        merged = Activation('relu')(merged)
        if dropout > 0:
            merged = Dropout(dropout)(merged)

    return merged
def get_dilated_unet(input_shape=(256, 256, 3), features=64, depth=3,
                     temperature=1.0, padding='same', batchnorm=True,
                     dropout=0.0, dilation_layers=3):
    """Generate `dilated U-Net' model where the convolutions in the encoding and
    bottleneck are replaced by dilated convolutions. The second convolution in
    pair at a given scale in the encoder is dilated by 2. The number of
    dilation layers in the innermost bottleneck is controlled by the
    `dilation_layers' parameter -- this is the `context module' proposed by Yu,
    Koltun 2016 in "Multi-scale Context Aggregation by Dilated Convolutions"

    The input is passed through an InstanceNormalization layer (axis 3)
    before the first encoder block.

    Arguments:
        input_shape - shape passed to the `Input` layer, default (256, 256, 3)
        features - number of output features for first convolution.
            Number of features double after each down sampling block
        depth - number of downsampling (and matching upsampling) blocks
        temperature - accepted for interface compatibility; not used by the
            current implementation
        padding - 'valid' or 'same', forwarded to every 3x3 convolution
        batchnorm - include batch normalization layers before activations
        dropout - fraction of units to dropout, 0 to keep all units
        dilation_layers - number of dilated convolutions in innermost
            bottleneck; the dilation rate starts at 1 and doubles per layer

    Output:
        Uncompiled Keras `Model` whose output is a single-filter 1x1
        convolution with sigmoid activation.
    """
    inputs = Input(shape=input_shape)
    x = InstanceNormalization(axis=3)(inputs)

    # Encoder: keep every pre-pooling tensor for the skip connections.
    skips = []
    for _ in range(depth):
        x, skip = downsampling_block(x, features, padding,
                                     batchnorm, dropout)
        skips.append(skip)
        features *= 2

    # Context module: stacked convolutions with exponentially growing dilation.
    dilation_rate = 1
    for _ in range(dilation_layers):
        x = Conv2D(filters=features, kernel_size=(3, 3), padding=padding,
                   dilation_rate=dilation_rate)(x)
        if batchnorm:
            x = BatchNormalization()(x)
        x = Activation('relu')(x)
        if dropout > 0:
            x = Dropout(dropout)(x)
        dilation_rate *= 2

    # Decoder: consume the skip connections from deepest to shallowest.
    for skip in reversed(skips):
        features //= 2
        x = upsampling_block(x, skip, features, padding,
                             batchnorm, dropout)

    segmentation = Conv2D(filters=1, activation='sigmoid',
                          kernel_size=(1, 1))(x)

    return Model(inputs=inputs, outputs=segmentation)
def gcn(i, k, n, filters=1):
    """Global-convolution style block built from separable large kernels.

    Two branches are applied to `i` and summed: a (k, 1) convolution followed
    by a (1, k) convolution, and a (1, k) convolution followed by a (k, 1)
    convolution.  All four convolutions use 'same' padding and `filters`
    output channels.

    Arguments:
        i       - input tensor
        k       - length of the one-dimensional kernels
        n       - name prefix; layers are named n + '_l0', '_l1', '_r0', '_r1'
        filters - number of output channels of every convolution
    """
    left = Conv2D(filters, kernel_size=(k, 1), padding='same', name=n + '_l0')(i)
    left = Conv2D(filters, kernel_size=(1, k), padding='same', name=n + '_l1')(left)
    right = Conv2D(filters, kernel_size=(1, k), padding='same', name=n + '_r0')(i)
    right = Conv2D(filters, kernel_size=(k, 1), padding='same', name=n + '_r1')(right)
    return add([left, right])


def br(i, n, filters=1, activation='relu'):
    """Residual refinement block.

    Applies a 3x3 convolution with `activation`, a second 3x3 convolution,
    and a 1x1 convolution that restores the channel count of `i`; the result
    is added to `i`.

    Arguments:
        i          - input tensor
        n          - name prefix; layers are named n + '_c0', '_c1', '_c2'
        filters    - number of channels of the two 3x3 convolutions
        activation - activation of the first 3x3 convolution
    """
    # Use the public backend helper instead of the private `_keras_shape`
    # attribute to read the (static) channel count of the input.
    in_channels = K.int_shape(i)[-1]

    res = Conv2D(filters, kernel_size=(3, 3), padding='same',
                 activation=activation, name=n + '_c0')(i)
    res = Conv2D(filters, kernel_size=(3, 3), padding='same', name=n + '_c1')(res)
    res = Conv2D(in_channels, kernel_size=(1, 1), padding='same', name=n + '_c2')(res)
    return add([res, i])
class Scale(Layer):
    '''Learned per-feature affine transform (companion of BatchNormalization).

    Holds one scale (`gamma`) and one shift (`beta`) value per entry of the
    chosen axis and computes

        out = in * gamma + beta

    # Arguments
        weights: optional initial weights, a list of 2 Numpy arrays
            `[gamma, beta]`, applied once at the end of `build`.
        axis: integer, axis of the input whose size determines the number
            of gamma/beta values (e.g. the channels axis).
        momentum: stored on the layer and written to the config; the
            forward computation does not use it.
        beta_init: initializer (name or callable accepted by
            `initializers.get`) for the shift parameter.
        gamma_init: initializer (name or callable accepted by
            `initializers.get`) for the scale parameter.
    '''
    def __init__(self, weights=None, axis=-1, momentum = 0.9, beta_init='zero', gamma_init='one', **kwargs):
        self.axis = axis
        self.momentum = momentum
        self.gamma_init = initializers.get(gamma_init)
        self.beta_init = initializers.get(beta_init)
        self.initial_weights = weights
        super(Scale, self).__init__(**kwargs)

    def build(self, input_shape):
        '''Create gamma/beta sized after `input_shape[self.axis]`.'''
        self.input_spec = [InputSpec(shape=input_shape)]
        param_shape = (int(input_shape[self.axis]),)

        gamma_name = '%s_gamma'%self.name
        beta_name = '%s_beta'%self.name
        self.gamma = K.variable(self.gamma_init(param_shape), name=gamma_name)
        self.beta = K.variable(self.beta_init(param_shape), name=beta_name)
        self.trainable_weights = [self.gamma, self.beta]

        if self.initial_weights is None:
            return
        # Explicit initial weights are consumed exactly once.
        self.set_weights(self.initial_weights)
        del self.initial_weights

    def call(self, x, mask=None):
        '''Return `x * gamma + beta`, broadcasting along `self.axis`.'''
        full_shape = self.input_spec[0].shape
        # Shape that is 1 everywhere except on the scaled axis.
        target_shape = [1] * len(full_shape)
        target_shape[self.axis] = full_shape[self.axis]

        gamma = K.reshape(self.gamma, target_shape)
        beta = K.reshape(self.beta, target_shape)
        return gamma * x + beta

    def get_config(self):
        '''Return the base layer config extended with momentum and axis.'''
        merged = dict(super(Scale, self).get_config())
        merged.update({"momentum": self.momentum, "axis": self.axis})
        return merged
def identity_block(input_tensor, kernel_size, filters, stage, block, preffix='', activation='relu', zero_padding=False, scale=True):
    '''Residual block whose shortcut is the unmodified input.

    Main path: 1x1 conv -> kernel_size conv -> 1x1 conv, each followed by
    BatchNormalization and (optionally) a Scale layer; the first two are
    also followed by an activation.  The result is added to `input_tensor`
    and activated.  The normalization axis is read from the module-level
    `bn_axis`.

    # Arguments
        input_tensor: input tensor
        kernel_size: the kernel size of the middle conv layer of the main path
        filters: sequence of 3 integers, the filter counts of the 3 conv
            layers of the main path
        stage: current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
        preffix: string prepended to every generated layer name
        activation: activation used inside the block and after the addition
        zero_padding: when True the convolutions use 'valid' padding and a
            ZeroPadding2D((1, 1)) layer precedes the middle convolution;
            otherwise 'same' padding is used
        scale: add a Scale layer after every BatchNormalization
    '''
    eps = 1.1e-5
    nb_filter1, nb_filter2, nb_filter3 = filters
    label = str(stage) + block
    conv_name_base = preffix + 'res' + label + '_branch'
    bn_name_base = preffix + 'bn' + label + '_branch'
    scale_name_base = preffix + 'scale' + label + '_branch'

    padding = 'same'
    if zero_padding:
        padding = 'valid'

    def conv_bn(tensor, nb_filter, size, suffix):
        # conv -> batch-norm -> optional scale, all named with `suffix`.
        tensor = Conv2D(nb_filter, size, name=conv_name_base + suffix, use_bias=False, padding=padding)(tensor)
        tensor = BatchNormalization(epsilon=eps, axis=bn_axis, name=bn_name_base + suffix)(tensor)
        if scale:
            tensor = Scale(axis=bn_axis, name=scale_name_base + suffix)(tensor)
        return tensor

    out = conv_bn(input_tensor, nb_filter1, (1, 1), '2a')
    out = Activation(activation, name=conv_name_base + '2a_relu')(out)

    if zero_padding:
        out = ZeroPadding2D((1, 1), name=conv_name_base + '2b_zeropadding')(out)
    out = conv_bn(out, nb_filter2, (kernel_size, kernel_size), '2b')
    out = Activation(activation, name=conv_name_base + '2b_relu')(out)

    out = conv_bn(out, nb_filter3, (1, 1), '2c')

    merged_name = preffix + 'res' + label
    out = add([out, input_tensor], name=merged_name)
    return Activation(activation, name=merged_name + '_relu')(out)
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2), preffix='', activation='relu', zero_padding=True, scale=True):
    '''Residual block with a convolution on the shortcut branch.

    Main path: strided 1x1 conv -> kernel_size conv -> 1x1 conv, each
    followed by BatchNormalization and (optionally) a Scale layer; the first
    two are also followed by an activation.  The shortcut is a 1x1 conv with
    the same strides plus BatchNormalization / Scale.  Both paths are added
    and activated.  The normalization axis is read from the module-level
    `bn_axis`.

    # Arguments
        input_tensor: input tensor
        kernel_size: the kernel size of the middle conv layer of the main path
        filters: sequence of 3 integers, the filter counts of the 3 conv
            layers of the main path
        stage: current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
        strides: strides of the first main-path conv and of the shortcut conv
        preffix: string prepended to every generated layer name
        activation: activation used inside the block and after the addition
        zero_padding: when True the convolutions use 'valid' padding and a
            ZeroPadding2D((1, 1)) layer precedes the middle convolution;
            otherwise 'same' padding is used
        scale: add a Scale layer after every BatchNormalization
    '''
    eps = 1.1e-5
    nb_filter1, nb_filter2, nb_filter3 = filters
    label = str(stage) + block
    conv_name_base = preffix + 'res' + label + '_branch'
    bn_name_base = preffix + 'bn' + label + '_branch'
    scale_name_base = preffix + 'scale' + label + '_branch'

    padding = 'same'
    if zero_padding:
        padding = 'valid'

    def conv_bn(tensor, nb_filter, size, suffix, **conv_kwargs):
        # conv -> batch-norm -> optional scale, all named with `suffix`.
        tensor = Conv2D(nb_filter, size, name=conv_name_base + suffix, use_bias=False, padding=padding, **conv_kwargs)(tensor)
        tensor = BatchNormalization(epsilon=eps, axis=bn_axis, name=bn_name_base + suffix)(tensor)
        if scale:
            tensor = Scale(axis=bn_axis, name=scale_name_base + suffix)(tensor)
        return tensor

    # Main path.
    out = conv_bn(input_tensor, nb_filter1, (1, 1), '2a', strides=strides)
    out = Activation(activation, name=conv_name_base + '2a_relu')(out)

    if zero_padding:
        out = ZeroPadding2D((1, 1), name=conv_name_base + '2b_zeropadding')(out)
    out = conv_bn(out, nb_filter2, (kernel_size, kernel_size), '2b')
    out = Activation(activation, name=conv_name_base + '2b_relu')(out)

    out = conv_bn(out, nb_filter3, (1, 1), '2c')

    # Projection shortcut, created after the main path.
    shortcut = conv_bn(input_tensor, nb_filter3, (1, 1), '1', strides=strides)

    merged_name = preffix + 'res' + label
    out = add([out, shortcut], name=merged_name)
    return Activation(activation, name=merged_name + '_relu')(out)
def get_largekernels(input_shape=(256, 256, 3), k=15):
    '''Build a segmentation model: ResNet-style encoder + large-kernel decoder.

    The encoder follows a ResNet-152 style stage layout (stage 3 has
    8 blocks, stage 4 has 36 blocks).  A separate full-resolution branch
    ('lkm_' prefixed, stage 1) is computed directly from the input.  Each of
    the six feature levels goes through `gcn` (kernel length `k`) and `br`;
    the levels are then merged from deepest to shallowest with stride-2
    transposed convolutions, additions and further `br` blocks.  The output
    is a 1x1 sigmoid convolution.

    Only layers whose name (up to the first underscore) is 'lkm' or belongs
    to the stage-5 blocks a/b/c are left trainable; all others are frozen.

    Side effect: sets the module-level `bn_axis`, which `identity_block` and
    `conv_block` read.

    # Arguments
        input_shape: shape passed to the `Input` layer
        k: kernel length used by every `gcn` block
    # Returns
        An uncompiled Keras model instance.
    '''
    eps = 1.1e-5

    # Handle dimension ordering for different backends.  `image_data_format`
    # replaces the deprecated `image_dim_ordering` ('tf' == 'channels_last').
    global bn_axis
    bn_axis = 3 if K.image_data_format() == 'channels_last' else 1
    img_input = Input(shape=input_shape, name='data')

    act = 'relu'

    # Full-resolution branch (stride 1), level 1 of the decoder.
    res1 = Scale(axis=bn_axis, name='lkm_scale_conv1')(img_input)
    res1 = conv_block(res1, 3, [32, 32, 128], stage=1, block='a', strides=(1, 1), preffix='lkm_', activation=act)
    res1 = identity_block(res1, 3, [32, 32, 128], stage=1, block='b', preffix='lkm_', activation=act)
    res1 = identity_block(res1, 3, [32, 32, 128], stage=1, block='c', preffix='lkm_', activation=act)

    # Encoder stem.
    x = ZeroPadding2D((3, 3), name='conv1_zeropadding')(img_input)
    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1', use_bias=False)(x)
    x = BatchNormalization(epsilon=eps, axis=bn_axis, name='bn_conv1')(x)
    x = Scale(axis=bn_axis, name='scale_conv1')(x)
    x = Activation(act, name='conv1_relu')(x)
    res2 = x

    # Stage 2.
    x = MaxPooling2D((3, 3), strides=(2, 2), name='pool1', padding='same')(x)
    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1), activation=act)
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b', activation=act)
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c', activation=act)
    res3 = x

    # Stage 3.
    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a', activation=act)
    for i in range(1, 8):
        x = identity_block(x, 3, [128, 128, 512], stage=3, block='b' + str(i), activation=act)
    res4 = x

    # Stage 4.
    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a', activation=act)
    for i in range(1, 36):
        x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b' + str(i), activation=act)
    res5 = x

    # Stage 5.
    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a', activation=act)
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b', activation=act)
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c', activation=act)
    res6 = x

    ff = 64
    act = 'elu'

    # Per-level large-kernel heads.  (level, features, width exponent):
    # level L uses ff // 2**exponent channels; levels 6 and 5 share a width.
    heads = {}
    for level, feats, exponent in ((6, res6, 5), (5, res5, 5), (4, res4, 4),
                                   (3, res3, 3), (2, res2, 2), (1, res1, 1)):
        width = ff // (2 ** exponent)
        heads[level] = br(gcn(feats, k, n='lkm_gcn%d_' % level, filters=width),
                          filters=width, n='lkm_br%d_' % level, activation=act)

    # Top-down merge: upsample the running result by 2, add the next
    # shallower head and refine.
    merged = heads[6]
    for level, exponent in ((6, 5), (5, 4), (4, 3), (3, 2), (2, 1)):
        width = ff // (2 ** exponent)
        up = Conv2DTranspose(width, kernel_size=(2, 2), strides=(2, 2),
                             name='lkm_u%d_' % level)(merged)
        merged = br(add([up, heads[level - 1]]), n='lkm_br%da_' % (level - 1),
                    filters=width, activation=act)

    segmentation = Conv2D(1, (1, 1), activation='sigmoid', name='lkm_segmentation_')(merged)

    model = Model(inputs=img_input, outputs=segmentation)
    #model.load_weights("resnet152_weights_tf.h5", by_name=True)

    # Freeze everything except the new 'lkm' layers and encoder stage 5.
    trainable_prefixes = ('lkm', 'bn5a', 'bn5b', 'bn5c', 'res5a', 'res5b',
                          'res5c', 'scale5a', 'scale5b', 'scale5c')
    for layer in model.layers:
        layer.trainable = layer.name.split("_")[0] in trainable_prefixes

    return model