├── .gitignore ├── LICENSE ├── README.md ├── notebooks └── .ipynb_checkpoints │ └── model evaluation-checkpoint.ipynb ├── outputs └── exp_1 │ └── results │ ├── 0.0_plot_accuracy.png │ ├── 0.0_plot_loss.png │ └── val_dice.txt ├── requirements.txt ├── scripts ├── config_files │ └── defaults.config └── train.py ├── setup.py └── tnseg ├── Augmentor ├── ImageSource.py ├── ImageUtilities.py ├── Operations.py ├── Pipeline.py └── __init__.py ├── __init__.py ├── dataset.py ├── evaluate.py ├── loss.py ├── models ├── __init__.py ├── dilated_densenet.py ├── dilated_unet.py ├── unet.py └── window_unet.py ├── opts.py └── ufarray.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | data_small/ 3 | *.pyc 4 | run.sh 5 | errors/ 6 | run_output/ 7 | outputs/ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 chuckyee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Thyroid Nodule Segmentation 2 | 3 | This repository contains code and models to segment thyroid nodules in ultrasound images. 4 | Dataset used: [Open-CAS Ultrasound Dataset](http://opencas.webarchiv.kit.edu/?q=node/29) 5 | 6 | ## Installation 7 | 8 | The main code is written as a Python package named 'tnseg'. After cloning this 9 | repository to your machine, install with: 10 | 11 | ```bash 12 | cd cloned/path 13 | pip install . 14 | ``` 15 | 16 | You should then be able to use the package in Python: 17 | 18 | ```python 19 | import matplotlib.pyplot as plt 20 | from tnseg import dataset, models, loss, opts, evaluate 21 | ``` 22 | 23 | ## Running models 24 | 25 | Scripts for model training and evaluation are located under scripts/. Run them from inside that directory, because the default config uses paths relative to it (datadir = ../data, outdir = ../outputs/exp_1). 26 | 27 | ```bash 28 | cd scripts && python -u train.py config_files/defaults.config 29 | ``` 30 | 31 | On running the model, the outputs are saved in the directory given by `outdir` in the config file (outputs/exp_1/ for the default experiment). The outputs include the following: 32 | 1. weights/ : Weights saved during training. 33 | 2. results/ : Loss and accuracy plots, and the validation Dice coefficients (val_dice.txt). 34 | 3. predictions/ : Predicted annotation maps for all the validation folders. 35 | 36 | Note: In this project, the dataset contains 16 folders. Because the dataset is small, we trained 8 models, each on 14 folders with the remaining 2 held out for validation, and thereby obtained a validation Dice coefficient for every folder.
37 | 38 | Note: this package is written with the TensorFlow backend in mind -- (batch, 39 | height, width, channels) ordering is assumed and is not portable to Theano. 40 | 41 | ## Models 42 | 43 | The implemented models are: 44 | 1. UNet 45 | 2. Window UNet 46 | 3. Dilated UNet 47 | 4. Dilated DenseNet 48 | 49 | 50 | -------------------------------------------------------------------------------- /notebooks/.ipynb_checkpoints/model evaluation-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /outputs/exp_1/results/0.0_plot_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suryatejadev/thyroid_segmentation/09c291a16f33490757f195057a64acd1ea17bd83/outputs/exp_1/results/0.0_plot_accuracy.png -------------------------------------------------------------------------------- /outputs/exp_1/results/0.0_plot_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suryatejadev/thyroid_segmentation/09c291a16f33490757f195057a64acd1ea17bd83/outputs/exp_1/results/0.0_plot_loss.png -------------------------------------------------------------------------------- /outputs/exp_1/results/val_dice.txt: -------------------------------------------------------------------------------- 1 | {'D2': (0.029046216800803805, array([[[0., 0., 0., ..., 0., 0., 0.], 2 | [0., 0., 0., ..., 0., 0., 0.], 3 | [0., 0., 0., ..., 0., 0., 0.], 4 | ..., 5 | [0., 0., 0., ..., 0., 0., 0.], 6 | [0., 0., 0., ..., 0., 0., 0.], 7 | [0., 0., 0., ..., 0., 0., 0.]], 8 | 9 | [[0., 0., 0., ..., 0., 0., 0.], 10 | [0., 0., 0., ..., 0., 0., 0.], 11 | [0., 0., 0., ..., 0., 0., 0.], 12 | ..., 13 | [0., 0., 0., ..., 0., 0., 0.], 14 | [0., 0., 0., ..., 0.,
0., 0.], 15 | [0., 0., 0., ..., 0., 0., 0.]], 16 | 17 | [[0., 1., 0., ..., 0., 0., 0.], 18 | [0., 1., 0., ..., 0., 0., 0.], 19 | [0., 1., 0., ..., 0., 0., 0.], 20 | ..., 21 | [0., 0., 0., ..., 0., 0., 0.], 22 | [0., 0., 0., ..., 0., 0., 0.], 23 | [0., 0., 0., ..., 0., 0., 0.]], 24 | 25 | ..., 26 | 27 | [[0., 0., 0., ..., 0., 0., 0.], 28 | [0., 0., 0., ..., 0., 0., 0.], 29 | [0., 0., 0., ..., 0., 0., 0.], 30 | ..., 31 | [0., 0., 0., ..., 0., 0., 0.], 32 | [0., 0., 0., ..., 0., 0., 0.], 33 | [0., 0., 0., ..., 0., 0., 0.]], 34 | 35 | [[0., 0., 0., ..., 0., 0., 0.], 36 | [0., 0., 0., ..., 0., 0., 0.], 37 | [0., 0., 0., ..., 0., 0., 0.], 38 | ..., 39 | [0., 0., 0., ..., 0., 0., 0.], 40 | [0., 0., 0., ..., 0., 0., 0.], 41 | [0., 0., 0., ..., 0., 0., 0.]], 42 | 43 | [[0., 0., 0., ..., 0., 0., 0.], 44 | [0., 0., 0., ..., 0., 0., 0.], 45 | [0., 0., 0., ..., 0., 0., 0.], 46 | ..., 47 | [0., 0., 0., ..., 0., 0., 0.], 48 | [0., 0., 0., ..., 0., 0., 0.], 49 | [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)), 'D1': (0.06668710834457738, array([[[0., 0., 0., ..., 0., 0., 0.], 50 | [0., 0., 0., ..., 0., 0., 0.], 51 | [0., 0., 0., ..., 0., 0., 0.], 52 | ..., 53 | [0., 0., 0., ..., 0., 0., 0.], 54 | [0., 0., 0., ..., 0., 0., 0.], 55 | [0., 0., 0., ..., 0., 0., 0.]], 56 | 57 | [[0., 0., 0., ..., 0., 0., 0.], 58 | [0., 0., 0., ..., 0., 0., 0.], 59 | [0., 0., 0., ..., 0., 0., 0.], 60 | ..., 61 | [0., 0., 0., ..., 0., 0., 0.], 62 | [0., 0., 0., ..., 0., 0., 0.], 63 | [0., 0., 0., ..., 0., 0., 0.]], 64 | 65 | [[0., 0., 0., ..., 0., 0., 0.], 66 | [0., 0., 0., ..., 0., 0., 0.], 67 | [0., 0., 0., ..., 0., 0., 0.], 68 | ..., 69 | [0., 0., 0., ..., 0., 0., 0.], 70 | [0., 0., 0., ..., 0., 0., 0.], 71 | [0., 0., 0., ..., 0., 0., 0.]], 72 | 73 | ..., 74 | 75 | [[0., 0., 0., ..., 0., 0., 0.], 76 | [0., 0., 0., ..., 0., 0., 0.], 77 | [0., 0., 0., ..., 0., 0., 0.], 78 | ..., 79 | [0., 0., 0., ..., 0., 0., 0.], 80 | [0., 0., 0., ..., 0., 0., 0.], 81 | [0., 0., 0., ..., 0., 0., 0.]], 82 | 83 | [[0., 
0., 0., ..., 0., 0., 0.], 84 | [0., 0., 0., ..., 0., 0., 0.], 85 | [0., 0., 0., ..., 0., 0., 0.], 86 | ..., 87 | [0., 0., 0., ..., 0., 0., 0.], 88 | [0., 0., 0., ..., 0., 0., 0.], 89 | [0., 0., 0., ..., 0., 0., 0.]], 90 | 91 | [[0., 0., 0., ..., 0., 0., 0.], 92 | [0., 0., 0., ..., 0., 0., 0.], 93 | [0., 0., 0., ..., 0., 0., 0.], 94 | ..., 95 | [0., 0., 0., ..., 0., 0., 0.], 96 | [0., 0., 0., ..., 0., 0., 0.], 97 | [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32))} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.13.3 2 | 3 | -------------------------------------------------------------------------------- /scripts/config_files/defaults.config: -------------------------------------------------------------------------------- 1 | [model] 2 | model = unet # Model: unet, dilated-unet, dilated-densenet or window-unet 3 | features = 32 # Number of feature maps after first convolutional layer 4 | depth = 4 # Number of downsampled convolutional blocks 5 | temperature = 1.0 # Temperature of final softmax layer in model 6 | padding = same # Padding in convolutional layers. Either `same' or `valid' 7 | batchnorm = False # Whether to apply batch normalization before activation layers 8 | dropout = 0.0 # Rate for dropout of activation units (set to zero to omit) 9 | window = 1 # Window size for window-unet (also passed to the data generators and evaluation) 10 | dilation = 1 1 # Dilation rates for the encoder (used only by the unet model) 11 | 12 | [data generator properties] 13 | zero_padding = 320 448 # zero_padding: None, [320, 448].
If None, zero padding is applied to each batch 14 | data_skew = False # Skew the probabilities of the batch samples in the ratio of percentage of TRUE pixels 15 | 16 | [loss] 17 | loss = dice # Loss function: `pixel' for pixel-wise cross entropy, 18 | # `dice' for Sorensen-Dice coefficient, 19 | # `jaccard' for intersection over union (not currently accepted by scripts/train.py) 20 | loss_weights = 0.5 0.5 # When using dice or jaccard loss, how much to weight each output class (not used by the current dice/pixel losses in scripts/train.py) 21 | 22 | [training] 23 | epochs = 1 # Number of epochs to train 24 | batch_size = 4 # Mini-batch size for training 25 | validation_split = 0.2 # Fraction of training data to hold out for validation 26 | optimizer = adam # Optimizer: sgd, rmsprop, adagrad, adadelta, adam, adamax, or nadam 27 | learning_rate = 1e-5 # Optimizer learning rate 28 | momentum = # Momentum for SGD optimizer 29 | decay = # Learning rate decay (for all optimizers except nadam) 30 | shuffle_train_val = True 31 | shuffle = True 32 | seed = 0 33 | train_steps_per_epoch = 20 34 | val_steps_per_epoch = 8 35 | 36 | [files] 37 | load_weights = # Name of file to load previously-saved model weights 38 | datadir = ../data # Directory containing list of patientXX/ subdirectories 39 | outdir = ../outputs/exp_1 # Where to write weights/, results/ and predictions/ 40 | checkpoint = True # Whether to output model weight checkpoint files 41 | ckpt_period = 10 # Period of epochs after which weights are saved 42 | 43 | [augmentation] 44 | data_augment = True # Whether to apply image augmentation to training set 45 | rotation_range = 180 # Rotation range (0-180 degrees) 46 | width_shift_range = 0.1 # Width shift range, as a float fraction of the width 47 | height_shift_range = 0.1 # Height shift range, as a float fraction of the height 48 | shear_range = 0.1 # Shear intensity (in radians) 49 | zoom_range = 0.05 # Amount of zoom. If a scalar z, zoom in [1-z, 1+z]. 50 | # Can also pass a pair of floats as the zoom range.
51 | fill_mode = nearest # Points outside boundaries are filled according to 52 | # mode: constant, nearest, reflect, or wrap) 53 | alpha = 500 # Random elastic distortion: magnitude of distortion 54 | sigma = 20 # Random elastic distortion: length scale 55 | normalize = True 56 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import division, print_function 4 | 5 | import os 6 | import argparse 7 | import logging 8 | import sys 9 | sys.path.append('..') 10 | import matplotlib as mpl 11 | mpl.use('Agg') 12 | 13 | from keras import losses, optimizers, utils 14 | from keras.optimizers import SGD, RMSprop, Adagrad, Adadelta, Adam, Adamax, Nadam 15 | from keras.callbacks import ModelCheckpoint 16 | from keras import backend as K 17 | 18 | from tnseg import dataset, models, loss, opts, evaluate 19 | 20 | # os.environ["CUDA_VISIBLE_DEVICES"] = "" 21 | 22 | def select_optimizer(optimizer_name, optimizer_args): 23 | optimizers = { 24 | 'sgd': SGD, 25 | 'rmsprop': RMSprop, 26 | 'adagrad': Adagrad, 27 | 'adadelta': Adadelta, 28 | 'adam': Adam, 29 | 'adamax': Adamax, 30 | 'nadam': Nadam, 31 | } 32 | if optimizer_name not in optimizers: 33 | raise Exception("Unknown optimizer ({}).".format(name)) 34 | 35 | return optimizers[optimizer_name](**optimizer_args) 36 | 37 | def train(validation_index, args): 38 | 39 | 40 | logging.info("Loading dataset...") 41 | augmentation_args = None 42 | if args.data_augment==True: 43 | augmentation_args = { 44 | 'rotation_range': args.rotation_range, 45 | 'width_shift_range': args.width_shift_range, 46 | 'height_shift_range': args.height_shift_range, 47 | 'shear_range': args.shear_range, 48 | 'zoom_range': args.zoom_range, 49 | 'fill_mode' : args.fill_mode, 50 | 'alpha': args.alpha, 51 | 'sigma': args.sigma, 52 | } 53 | 54 | # train_generator, 
train_steps_per_epoch, \ 55 | # val_generator, val_steps_per_epoch = dataset.create_generators( 56 | # args.datadir, args.batch_size, 57 | # validation_split=args.validation_split, 58 | # mask=args.classes, 59 | # shuffle_train_val=args.shuffle_train_val, 60 | # shuffle=args.shuffle, 61 | # seed=args.seed, 62 | # normalize_images=args.normalize, 63 | # augment_training=args.augment_training, 64 | # augment_validation=args.augment_validation, 65 | # augmentation_args=augmentation_args, # new arguments from here... 66 | # window_size=0, 67 | # adaptive_padding=True, 68 | # constant_padding_height=320, 69 | # constant_padding_width=448, 70 | # datagen_method='zeropad' 71 | # ) 72 | 73 | train_generator, val_generator = dataset.create_generators( 74 | args.datadir, args.batch_size, 75 | augmentation_args=augmentation_args, 76 | model=args.model, 77 | zero_padding=args.zero_padding, 78 | data_skew=args.data_skew, 79 | validation_index=validation_index, 80 | window=args.window 81 | ) 82 | 83 | if args.model=='unet': 84 | m = models.unet(height=None, width=None, channels=1, features=args.features, 85 | depth=args.depth, padding=args.padding, temperature=args.temperature, 86 | batchnorm=args.batchnorm, dropout=args.dropout, dilation=args.dilation) 87 | elif args.model=='dilated-unet': 88 | m = models.dilated_unet(height=None, width=None, channels=1, 89 | classes=2, features=args.features, depth=args.depth, 90 | temperature=args.temperature, padding=args.padding, 91 | batchnorm=args.batchnorm, dropout=args.dropout) 92 | elif args.model=='dilated-densenet': 93 | m = models.dilated_densenet(height=None, width=None, channels=1, 94 | classes=2, features=args.features, depth=args.depth, 95 | temperature=args.temperature, padding=args.padding, 96 | batchnorm=args.batchnorm,dropout=args.dropout) 97 | elif args.model=='window-unet': 98 | m = models.window_unet(height=None, width=None, 99 | features=args.features, padding=args.padding, 100 | dropout=args.dropout, 
batchnorm=args.batchnorm, window_size=args.window) 101 | else: 102 | raise ValueError('Model not supported. Please select from: unet,\ 103 | dilated-unet, dilated-densenet, window-unet') 104 | 105 | m.summary() 106 | 107 | if args.load_weights: 108 | logging.info("Loading saved weights from file: {}".format(args.load_weights)) 109 | m.load_weights(args.load_weights) 110 | 111 | # instantiate optimizer, and only keep args that have been set 112 | # (not all optimizers have args like `momentum' or `decay') 113 | optimizer_args = { 114 | 'lr': args.learning_rate, 115 | 'momentum': args.momentum, 116 | 'decay': args.decay 117 | } 118 | for k in list(optimizer_args): 119 | if optimizer_args[k] is None: 120 | del optimizer_args[k] 121 | optimizer = select_optimizer(args.optimizer, optimizer_args) 122 | 123 | # select loss function: pixel-wise crossentropy, soft dice or soft 124 | # jaccard coefficient 125 | # if args.loss == 'pixel': 126 | # def lossfunc(y_true, y_pred): 127 | # return loss.weighted_categorical_crossentropy( 128 | # y_true, y_pred, args.loss_weights) 129 | # elif args.loss == 'dice': 130 | # def lossfunc(y_true, y_pred): 131 | # return loss.sorensen_dice_loss(y_true, y_pred, args.loss_weights) 132 | # elif args.loss == 'jaccard': 133 | # def lossfunc(y_true, y_pred): 134 | # return loss.jaccard_loss(y_true, y_pred, args.loss_weights) 135 | # else: 136 | # raise Exception("Unknown loss ({})".format(args.loss)) 137 | 138 | # def dice(y_true, y_pred): 139 | # batch_dice_coefs = loss.sorensen_dice(y_true, y_pred, axis=[1, 2]) 140 | # dice_coefs = K.mean(batch_dice_coefs, axis=0) 141 | # return dice_coefs[1] # HACK for 2-class cas metrics = [loss.dice_coef] 142 | 143 | if args.loss == 'dice': 144 | lossfunc = lambda y_true, y_pred: loss.dice_coef_loss(y_true, y_pred) 145 | elif args.loss == 'pixel': 146 | lossfunc = lambda y_true, y_pred: loss.bin_crossentropy_loss(y_true, y_pred) 147 | else: 148 | raise Exception("Unknown loss ({})".format(args.loss)) 149 | 
150 | metrics = [loss.dice_coef] 151 | 152 | m.compile(optimizer=optimizer, loss=lossfunc, metrics=metrics) 153 | 154 | # automatic saving of model during training 155 | # if args.checkpoint: 156 | # if args.loss == 'pixel': 157 | # filepath = os.path.join( 158 | # args.outdir, "weights-{epoch:02d}-{val_acc:.4f}.hdf5") 159 | # monitor = 'val_acc' 160 | # mode = 'max' 161 | # elif args.loss == 'dice': 162 | # filepath = os.path.join( 163 | # args.outdir, "weights-{epoch:02d}-{val_dice:.4f}.hdf5") 164 | # monitor='val_dice' 165 | # mode = 'max' 166 | # elif args.loss == 'jaccard': 167 | # filepath = os.path.join( 168 | # args.outdir, "weights-{epoch:02d}-{val_jaccard:.4f}.hdf5") 169 | # monitor='val_jaccard' 170 | # mode = 'max' 171 | # checkpoint = ModelCheckpoint( 172 | # filepath, monitor=monitor, verbose=1, 173 | # save_best_only=True, mode=mode) 174 | # callbacks = [checkpoint] 175 | # else: 176 | # callbacks = [] 177 | 178 | # train 179 | if args.checkpoint: 180 | wt_index = str(validation_index[0]/2) 181 | wt_path = args.outdir + '/weights/wt-'+wt_index+'-{epoch:02d}-{val_dice_coef:.2f}.h5' 182 | checkpoint = ModelCheckpoint(wt_path, monitor='val_dice_coef', verbose=1, 183 | save_weights_only=True, period=args.ckpt_period) 184 | callbacks = [checkpoint] 185 | else: 186 | callbacks = [] 187 | logging.info("Begin training.") 188 | out = m.fit_generator(train_generator, 189 | epochs=args.epochs, 190 | steps_per_epoch=args.train_steps_per_epoch, 191 | validation_data=val_generator, 192 | validation_steps=args.val_steps_per_epoch, 193 | callbacks=callbacks, 194 | verbose=2) 195 | return m, out 196 | 197 | def evaluation(model, out, validation_index, args): 198 | iter_model = validation_index[0]/2 199 | 200 | # Paths to save predictions and results of the model --------------------- 201 | save_prediction_path = args.outdir+'/predictions/' 202 | save_results_path = args.outdir+'/results/' 203 | 204 | # Saving accuracy and error plots 
---------------------------------------- 205 | evaluate.eval_error_plots(out, save_results_path+str(iter_model)+'_') 206 | 207 | # Saving history --------------------------------------------------------- 208 | results_dict = {} 209 | results_dict['history_'+str(iter_model)] = out.history 210 | 211 | # Saving output annotation maps ------------------------------------------ 212 | folder_names = os.listdir(args.datadir+'/images/') 213 | dice_vals = [] 214 | for folder_index in validation_index: 215 | folder = folder_names[folder_index][:-4] 216 | folder_prediction_path = save_prediction_path+folder+'/' 217 | if os.path.exists(folder_prediction_path)==False: 218 | os.mkdir(folder_prediction_path) 219 | dice_vals.append(evaluate.evaluate_test_folder(model, folder_prediction_path, 220 | args.datadir+'/data_images/'+folder+'/', n_window=args.window)) 221 | return dice_vals 222 | 223 | #################################################################### 224 | # Training Methodology: 225 | # - The dataset has 16 DICOM videos 226 | # - Using each architecture, We are building 8 models, 227 | # using 2 videos for validation for each model 228 | # - Our output is the validation dice coefficient for the 16 videos 229 | # - This methodology is adopted due to the availability of less data 230 | ##########################################################3######### 231 | if __name__ == '__main__': 232 | 233 | logging.basicConfig(level=logging.INFO) 234 | args = opts.parse_arguments() 235 | 236 | # Creating experiment output folders 237 | if os.path.exists(args.outdir)==False: 238 | os.mkdir(args.outdir) 239 | for item in ['predictions/', 'results/', 'weights/']: 240 | path = args.outdir + '/' + item 241 | if os.path.exists(path) == False: 242 | os.mkdir(path) 243 | 244 | # Train and evaluate 245 | dice_vals = {} 246 | for iter_model in range(1): 247 | validation_index = [iter_model*2, iter_model*2+1] 248 | print('Validation Folders = ',validation_index) 249 | model, out = 
train(validation_index, args) 250 | dice_vals['D'+str(iter_model*2+1)],\ 251 | dice_vals['D'+str(iter_model*2+2)] = \ 252 | evaluation(model, out, validation_index, args) 253 | 254 | # Print the Folder Dice coefficients in a file 255 | dice_coef_path = args.outdir+'/results/val_dice.txt' 256 | with open(args.outdir+'/results/val_dice.txt','w') as f: 257 | f.write(str(dice_vals)) 258 | f.close() 259 | 260 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from setuptools import find_packages 3 | 4 | setup(name='tnseg', 5 | version='0.1', 6 | description='Thyorid Nodule Segmentation', 7 | url='http://github.com/suryatejadev/thyroid-segmentation', 8 | author='Surya Teja Devarakonda, Santhosh Vangapelli', 9 | author_email='suryatejadev@cs.umass.edu, svangapelli@cs.umass.edu', 10 | license='MIT', 11 | packages=['tnseg', 'tnseg.models'], 12 | zip_safe=False) 13 | -------------------------------------------------------------------------------- /tnseg/Augmentor/ImageSource.py: -------------------------------------------------------------------------------- 1 | from __future__ import (absolute_import, division, 2 | print_function, unicode_literals) 3 | from builtins import * 4 | 5 | import os 6 | import glob 7 | 8 | 9 | class ImageSource(object): 10 | """ 11 | The ImageSource class is used to search for and contain paths to images for augmentation. 
12 | """ 13 | def __init__(self, source_directory, recursive_scan=False): 14 | source_directory = os.path.abspath(source_directory) 15 | self.image_list = self.scan_directory(source_directory, recursive_scan) 16 | 17 | self.largest_file_dimensions = (800, 600) 18 | 19 | def scan_directory(self, source_directory, recusrive_scan=False): 20 | # TODO: Make this a static member somewhere later 21 | file_types = ['*.jpg', '*.bmp', '*.jpeg', '*.gif', '*.img', '*.png'] 22 | file_types.extend([str.upper(x) for x in file_types]) 23 | 24 | list_of_files = [] 25 | 26 | for file_type in file_types: 27 | list_of_files.extend(glob.glob(os.path.join(os.path.abspath(source_directory), file_type))) 28 | 29 | return list_of_files 30 | -------------------------------------------------------------------------------- /tnseg/Augmentor/ImageUtilities.py: -------------------------------------------------------------------------------- 1 | # ImageUtilities.py 2 | # Author: Marcus D. Bloice 3 | # Licensed under the terms of the MIT Licence. 4 | """ 5 | The ImageUtilities module provides a number of helper functions, as well as 6 | the main :class:`~Augmentor.ImageUtilities.AugmentorImage` class, that is used 7 | throughout the package as a container class for images to be augmented. 8 | """ 9 | from __future__ import (absolute_import, division, 10 | print_function, unicode_literals) 11 | from builtins import * 12 | 13 | import os 14 | import glob 15 | import numbers 16 | import random 17 | import warnings 18 | import numpy as np 19 | 20 | 21 | class AugmentorImage(object): 22 | """ 23 | Wrapper class containing paths to images, as well as a number of other 24 | parameters, that are used by the Pipeline and Operation modules to perform 25 | augmentation. 26 | 27 | Each image that is found by Augmentor during the initialisation of a 28 | Pipeline object is contained with a new AugmentorImage object. 
29 | """ 30 | def __init__(self, image_path, output_directory): 31 | """ 32 | To initialise an AugmentorImage object for any image, the image's 33 | file path is required, as well as that image's output directory, 34 | which defines where any augmented images are stored. 35 | 36 | :param image_path: The full path to an image. 37 | :param output_directory: The directory where augmented images for this 38 | image should be saved. 39 | """ 40 | # Just to stop Pylint complaining about initialising these outside 41 | # of __init__ which is not actually happening, as the are being 42 | # initialised in the setters from within init, but anyway I shall obey. 43 | self._ground_truth = None 44 | self._image_path = None 45 | self._output_directory = None 46 | self._file_format = None # TODO: pass this for each image. 47 | self._image_PIL = None 48 | self._class_label = None 49 | self._class_label_int = None 50 | self._label_pair = None 51 | self._categorical_label = None 52 | 53 | # Now we call the setters that we require. 54 | self.image_path = image_path 55 | self.output_directory = output_directory 56 | 57 | def __str__(self): 58 | return """ 59 | Image path: %s 60 | Ground truth path: %s 61 | File format (inferred from extension): %s 62 | Class label: %s 63 | Numerical class label (auto assigned): %s 64 | """ % (self.image_path, self.ground_truth, self.file_format, self.class_label, self.class_label_int) 65 | 66 | @property 67 | def output_directory(self): 68 | """ 69 | The :attr:`output_directory` property contains a path to the directory 70 | to which augmented images will be saved for this instance. 71 | 72 | :getter: Returns this image's output directory. 73 | :setter: Sets this image's output directory. 
74 | :type: String 75 | """ 76 | return self._output_directory 77 | 78 | @output_directory.setter 79 | def output_directory(self, value): 80 | self._output_directory = value 81 | 82 | @property 83 | def image_path(self): 84 | """ 85 | The :attr:`image_path` property contains the absolute file path to the 86 | image. 87 | 88 | :getter: Returns this image's image path. 89 | :setter: Sets this image's image path 90 | :type: String 91 | """ 92 | return self._image_path 93 | 94 | @image_path.setter 95 | def image_path(self, value): 96 | self._image_path = value 97 | #if os.path.exists(value): 98 | # self._image_path = value 99 | #else: 100 | # raise IOError("The file specified does not exist.") 101 | 102 | @property 103 | def image_PIL(self): 104 | return self._image_PIL 105 | 106 | @image_PIL.setter 107 | def image_PIL(self, value): 108 | self._image_PIL = value 109 | 110 | @property 111 | def image_file_name(self): 112 | """ 113 | The :attr:`image_file_name` property contains the **file name** of the 114 | image contained in this instance. **There is no setter for this 115 | property.** 116 | 117 | :getter: Returns this image's file name. 118 | :type: String 119 | """ 120 | return os.path.basename(self.image_path) 121 | 122 | @property 123 | def class_label(self): 124 | return self._class_label 125 | 126 | @class_label.setter 127 | def class_label(self, value): 128 | self._class_label = value 129 | 130 | @property 131 | def class_label_int(self): 132 | return self._class_label_int 133 | 134 | @class_label_int.setter 135 | def class_label_int(self, value): 136 | self._class_label_int = value 137 | 138 | @property 139 | def categorical_label(self): 140 | return self._categorical_label 141 | 142 | @categorical_label.setter 143 | def categorical_label(self, value): 144 | self._categorical_label = value 145 | 146 | @property 147 | def ground_truth(self): 148 | """ 149 | The :attr:`ground_truth` property contains an absolute path to the 150 | ground truth file for an image. 
151 | 152 | :getter: Returns this image's ground truth file path. 153 | :setter: Sets this image's ground truth file path. 154 | :type: String 155 | """ 156 | return self._ground_truth 157 | 158 | @ground_truth.setter 159 | def ground_truth(self, value): 160 | if os.path.isfile(value): 161 | self._ground_truth = value 162 | 163 | @property 164 | def label_pair(self): 165 | return self.class_label_int, self.class_label 166 | 167 | @property 168 | def file_format(self): 169 | return self._file_format 170 | 171 | @file_format.setter 172 | def file_format(self, value): 173 | self._file_format = value 174 | 175 | 176 | def parse_user_parameter(user_param): 177 | 178 | if isinstance(user_param, numbers.Real): 179 | return user_param 180 | elif isinstance(user_param, tuple): 181 | return random.sample(user_param, 1)[0] 182 | elif isinstance(user_param, list): 183 | return random.choice(np.arange(*user_param)) 184 | 185 | 186 | def extract_paths_and_extensions(image_path): 187 | """ 188 | Extract an image's file name, its extension, and its root path (the 189 | image's absolute path without the file name). 190 | 191 | :param image_path: The path to the image. 192 | :type image_path: String 193 | :return: A 3-tuple containing the image's file name, extension, and 194 | root path. 
195 | """ 196 | file_name, extension = os.path.splitext(image_path) 197 | root_path = os.path.dirname(image_path) 198 | 199 | return file_name, extension, root_path 200 | 201 | 202 | def scan(source_directory, output_directory): 203 | 204 | abs_output_directory = os.path.abspath(output_directory) 205 | files_and_directories = glob.glob(os.path.join(os.path.abspath(source_directory), '*')) 206 | 207 | directory_count = 0 208 | directories = [] 209 | 210 | class_labels = [] 211 | 212 | for f in files_and_directories: 213 | if os.path.isdir(f): 214 | if f != abs_output_directory: 215 | directories.append(f) 216 | directory_count += 1 217 | 218 | directories = sorted(directories) 219 | label_counter = 0 220 | 221 | if directory_count == 0: 222 | 223 | augmentor_images = [] 224 | # This was wrong 225 | # parent_directory_name = os.path.basename(os.path.abspath(os.path.join(source_directory, os.pardir))) 226 | parent_directory_name = os.path.basename(os.path.abspath(source_directory)) 227 | 228 | for image_path in scan_directory(source_directory): 229 | a = AugmentorImage(image_path=image_path, output_directory=abs_output_directory) 230 | a.class_label = parent_directory_name 231 | a.class_label_int = label_counter 232 | a.categorical_label = [label_counter] 233 | a.file_format = os.path.splitext(image_path)[1].split(".")[1] 234 | augmentor_images.append(a) 235 | 236 | class_labels.append((label_counter, parent_directory_name)) 237 | 238 | return augmentor_images, class_labels 239 | 240 | elif directory_count != 0: 241 | augmentor_images = [] 242 | 243 | for d in directories: 244 | output_directory = os.path.join(abs_output_directory, os.path.split(d)[1]) 245 | for image_path in scan_directory(d): 246 | categorical_label = np.zeros(directory_count, dtype=np.uint32) 247 | a = AugmentorImage(image_path=image_path, output_directory=output_directory) 248 | a.class_label = os.path.split(d)[1] 249 | a.class_label_int = label_counter 250 | categorical_label[label_counter] = 1 # 
Set to 1 with the index of the current class. 251 | a.categorical_label = categorical_label 252 | a.file_format = os.path.splitext(image_path)[1].split(".")[1] 253 | augmentor_images.append(a) 254 | class_labels.append((os.path.split(d)[1], label_counter)) 255 | label_counter += 1 256 | 257 | return augmentor_images, class_labels 258 | 259 | 260 | def scan_directory(source_directory): 261 | """ 262 | Scan a directory for images, returning any images found with the 263 | extensions ``.jpg``, ``.JPG``, ``.jpeg``, ``.JPEG``, ``.gif``, ``.GIF``, 264 | ``.img``, ``.IMG``, ``.png`` or ``.PNG``. 265 | 266 | :param source_directory: The directory to scan for images. 267 | :type source_directory: String 268 | :return: A list of images found in the :attr:`source_directory` 269 | """ 270 | # TODO: GIFs are highly problematic. It may make sense to drop GIF support. 271 | file_types = ['*.jpg', '*.bmp', '*.jpeg', '*.gif', '*.img', '*.png'] 272 | 273 | list_of_files = [] 274 | 275 | if os.name == "nt": 276 | for file_type in file_types: 277 | list_of_files.extend(glob.glob(os.path.join(os.path.abspath(source_directory), file_type))) 278 | else: 279 | file_types.extend([str.upper(str(x)) for x in file_types]) 280 | for file_type in file_types: 281 | list_of_files.extend(glob.glob(os.path.join(os.path.abspath(source_directory), file_type))) 282 | 283 | return list_of_files 284 | 285 | 286 | def scan_directory_with_classes(source_directory): 287 | warnings.warn("The scan_directory_with_classes() function has been deprecated.", DeprecationWarning) 288 | l = glob.glob(os.path.join(source_directory, '*')) 289 | 290 | directories = [] 291 | 292 | for f in l: 293 | if os.path.isdir(f): 294 | directories.append(f) 295 | 296 | list_of_files = {} 297 | 298 | for d in directories: 299 | list_of_files_current_folder = scan_directory(d) 300 | list_of_files[os.path.split(d)[1]] = list_of_files_current_folder 301 | 302 | return list_of_files 303 | 
-------------------------------------------------------------------------------- /tnseg/Augmentor/Pipeline.py: -------------------------------------------------------------------------------- 1 | # Pipeline.py 2 | # Author: Marcus D. Bloice 3 | # Licensed under the terms of the MIT Licence. 4 | """ 5 | The Pipeline module is the user facing API for the Augmentor package. It 6 | contains the :class:`~Augmentor.Pipeline.Pipeline` class which is used to 7 | create pipeline objects, which can be used to build an augmentation pipeline 8 | by adding operations to the pipeline object. 9 | 10 | For a good overview of how to use Augmentor, along with code samples and 11 | example images, can be seen in the :ref:`mainfeatures` section. 12 | """ 13 | from __future__ import (absolute_import, division, 14 | print_function, unicode_literals) 15 | 16 | from builtins import * 17 | 18 | from .Operations import * 19 | from .ImageUtilities import scan_directory, scan, AugmentorImage 20 | 21 | import os 22 | import sys 23 | import random 24 | import uuid 25 | import warnings 26 | import numbers 27 | import numpy as np 28 | 29 | from tqdm import tqdm 30 | from PIL import Image 31 | 32 | 33 | class Pipeline(object): 34 | """ 35 | The Pipeline class handles the creation of augmentation pipelines 36 | and the generation of augmented data by applying operations to 37 | this pipeline. 38 | """ 39 | 40 | # Some class variables we use often 41 | _probability_error_text = "The probability argument must be between 0 and 1." 42 | _threshold_error_text = "The value of threshold must be between 0 and 255." 43 | _valid_formats = ["PNG", "BMP", "GIF", "JPEG"] 44 | _legal_filters = ["NEAREST", "BICUBIC", "ANTIALIAS", "BILINEAR"] 45 | 46 | def __init__(self, source_directory=None, output_directory="output", save_format=None): 47 | """ 48 | Create a new Pipeline object pointing to a directory containing your 49 | original image dataset. 
50 | 51 | Create a new Pipeline object, using the :attr:`source_directory` 52 | parameter as a source directory where your original images are 53 | stored. This folder will be scanned, and any valid file files 54 | will be collected and used as the original dataset that should 55 | be augmented. The scan will find any image files with the extensions 56 | JPEG/JPG, PNG, and GIF (case insensitive). 57 | 58 | :param source_directory: A directory on your filesystem where your 59 | original images are stored. 60 | :param output_directory: Specifies where augmented images should be 61 | saved to the disk. Default is the directory **output** relative to 62 | the path where the original image set was specified. If it does not 63 | exist it will be created. 64 | :param save_format: The file format to use when saving newly created, 65 | augmented images. Default is JPEG. Legal options are BMP, PNG, and 66 | GIF. 67 | :return: A :class:`Pipeline` object. 68 | """ 69 | random.seed() 70 | 71 | # TODO: Allow a single image to be added when initialising. 72 | # Initialise some variables for the Pipeline object. 73 | self.image_counter = 0 74 | self.augmentor_images = [] 75 | self.distinct_dimensions = set() 76 | self.distinct_formats = set() 77 | self.save_format = save_format 78 | self.operations = [] 79 | self.class_labels = [] 80 | self.process_ground_truth_images = False 81 | 82 | # Now we populate some fields, which we may need to do again later if another 83 | # directory is added, so we place it all in a function of its own. 
84 | if source_directory is not None: 85 | self._populate(source_directory=source_directory, 86 | output_directory=output_directory, 87 | ground_truth_directory=None, 88 | ground_truth_output_directory=output_directory) 89 | 90 | def _populate(self, source_directory, output_directory, ground_truth_directory, ground_truth_output_directory): 91 | """ 92 | Private method for populating member variables with AugmentorImage 93 | objects for each of the images found in the source directory 94 | specified by the user. It also populates a number of fields such as 95 | the :attr:`output_directory` member variable, used later when saving 96 | images to disk. 97 | 98 | This method is used by :func:`__init__`. 99 | 100 | :param source_directory: The directory to scan for images. 101 | :param output_directory: The directory to set for saving files. 102 | Defaults to a directory named output relative to 103 | :attr:`source_directory`. 104 | :param ground_truth_directory: A directory containing ground truth 105 | files for the associated images in the :attr:`source_directory` 106 | directory. 107 | :param ground_truth_output_directory: A path to a directory to store 108 | the output of the operations on the ground truth data set. 109 | :type source_directory: String 110 | :type output_directory: String 111 | :type ground_truth_directory: String 112 | :type ground_truth_output_directory: String 113 | :return: None 114 | """ 115 | 116 | # Check if the source directory for the original images to augment exists at all 117 | if not os.path.exists(source_directory): 118 | raise IOError("The source directory you specified does not exist.") 119 | 120 | # If a ground truth directory is being specified we will check here if the path exists at all. 
121 | if ground_truth_directory: 122 | if not os.path.exists(ground_truth_directory): 123 | raise IOError("The ground truth source directory you specified does not exist.") 124 | 125 | # Get absolute path for output 126 | abs_output_directory = os.path.join(source_directory, output_directory) 127 | 128 | # Scan the directory that user supplied. 129 | self.augmentor_images, self.class_labels = scan(source_directory, abs_output_directory) 130 | 131 | # Make output directory/directories 132 | if len(set(self.class_labels)) <= 1: # Fixed bad bug by adding set() function here. 133 | if not os.path.exists(abs_output_directory): 134 | try: 135 | os.makedirs(abs_output_directory) 136 | except IOError: 137 | print("Insufficient rights to read or write output directory (%s)" % abs_output_directory) 138 | else: 139 | for class_label in self.class_labels: 140 | if not os.path.exists(os.path.join(abs_output_directory, str(class_label[0]))): 141 | try: 142 | os.makedirs(os.path.join(abs_output_directory, str(class_label[0]))) 143 | except IOError: 144 | print("Insufficient rights to read or write output directory (%s)" % abs_output_directory) 145 | 146 | # Check the images, read their dimensions, and remove them if they cannot be read 147 | # TODO: Do not throw an error here, just remove the image and continue. 148 | for augmentor_image in self.augmentor_images: 149 | try: 150 | with Image.open(augmentor_image.image_path) as opened_image: 151 | self.distinct_dimensions.add(opened_image.size) 152 | self.distinct_formats.add(opened_image.format) 153 | except IOError as e: 154 | print("There is a problem with image %s in your source directory: %s" % (augmentor_image.image_path, e.message)) 155 | self.augmentor_images.remove(augmentor_image) 156 | 157 | sys.stdout.write("Initialised with %s image(s) found.\n" % len(self.augmentor_images)) 158 | sys.stdout.write("Output directory set to %s." 
% abs_output_directory) 159 | 160 | def _execute(self, augmentor_image, save_to_disk=True, list=False): 161 | """ 162 | Private method. Used to pass an image through the current pipeline, 163 | and return the augmented image. 164 | 165 | The returned image can then either be saved to disk or simply passed 166 | back to the user. Currently this is fixed to True, as Augmentor 167 | has only been implemented to save to disk at present. 168 | 169 | :param augmentor_image: The image to pass through the pipeline. 170 | :param save_to_disk: Whether to save the image to disk. Currently 171 | fixed to true. 172 | :type augmentor_image: :class:`ImageUtilities.AugmentorImage` 173 | :type save_to_disk: Boolean 174 | :return: The augmented image. 175 | """ 176 | # self.image_counter += 1 # TODO: See if I can remove this... 177 | 178 | images = [] 179 | 180 | if augmentor_image.image_PIL is not None: 181 | images.append(augmentor_image.image_PIL) 182 | if augmentor_image.image_PIL_ground_truth is not None: 183 | images.append(augmentor_image.image_PIL_ground_truth) 184 | 185 | if augmentor_image.image_path is not None: 186 | images.append(Image.open(augmentor_image.image_path)) 187 | 188 | if augmentor_image.ground_truth is not None: 189 | if isinstance(augmentor_image.ground_truth, list): 190 | for image in augmentor_image.ground_truth: 191 | images.append(Image.open(image)) 192 | else: 193 | images.append(Image.open(augmentor_image.ground_truth)) 194 | 195 | for operation in self.operations: 196 | r = round(random.uniform(0, 1), 1) 197 | if r <= operation.probability: 198 | images = operation.perform_operation(images) 199 | 200 | if save_to_disk: 201 | file_name = str(uuid.uuid4()) 202 | try: 203 | # TODO: Add a 'coerce' parameter to force conversion to RGB for PNGA->JPEG saving. 
204 | # if image.mode != "RGB": 205 | # image = image.convert("RGB") 206 | for i in range(len(images)): 207 | if i == 0: 208 | save_name = augmentor_image.class_label + "_original_" + file_name \ 209 | + "." + (self.save_format if self.save_format else augmentor_image.file_format) 210 | images[i].save(os.path.join(augmentor_image.output_directory, save_name)) 211 | else: 212 | save_name = "_groundtruth_(" + str(i) + ")_" + augmentor_image.class_label + "_" + file_name \ 213 | + "." + (self.save_format if self.save_format else augmentor_image.file_format) 214 | images[i].save(os.path.join(augmentor_image.output_directory, save_name)) 215 | except IOError as e: 216 | print("Error writing %s, %s. Change save_format to PNG?" % (file_name, e.message)) 217 | print("You can change the save format using the set_save_format(save_format) function.") 218 | print("By passing save_format=\"auto\", Augmentor can save in the correct format automatically.") 219 | 220 | # Currently we return only the first image if it is a list 221 | # for the generator functions. This will be fixed in a future 222 | # version. 223 | return images[0] if list is False else images 224 | 225 | def _execute_with_array(self, image): 226 | """ 227 | Private method used to execute a pipeline on array or matrix data. 228 | :param image: The image to pass through the pipeline. 229 | :type image: Array like object. 230 | :return: The augmented image. 231 | """ 232 | 233 | pil_image = [Image.fromarray(image)] 234 | 235 | for operation in self.operations: 236 | r = round(random.uniform(0, 1), 1) 237 | if r <= operation.probability: 238 | pil_image = operation.perform_operation(pil_image) 239 | 240 | numpy_array = np.asarray(pil_image[0]) 241 | 242 | return numpy_array 243 | 244 | def set_save_format(self, save_format): 245 | """ 246 | Set the save format for the pipeline. 
Pass the value 247 | :attr:`save_format="auto"` to allow Augmentor to choose 248 | the correct save format based on each individual image's 249 | file extension. 250 | 251 | If :attr:`save_format` is set to, for example, 252 | :attr:`save_format="JPEG"` or :attr:`save_format="JPG"`, 253 | Augmentor will attempt to save the files using the 254 | JPEG format, which may result in errors if the file cannot 255 | be saved in this format, such as PNG images with an alpha 256 | channel. 257 | 258 | :param save_format: The save format to save the images 259 | when writing to disk. 260 | :return: None 261 | """ 262 | 263 | if save_format == "auto": 264 | self.save_format = None 265 | else: 266 | self.save_format = save_format 267 | 268 | def sample(self, n): 269 | """ 270 | Generate :attr:`n` number of samples from the current pipeline. 271 | 272 | This function samples from the pipeline, using the original images 273 | defined during instantiation. All images generated by the pipeline 274 | are by default stored in an ``output`` directory, relative to the 275 | path defined during the pipeline's instantiation. 276 | 277 | :param n: The number of new samples to produce. 278 | :type n: Integer 279 | :return: None 280 | """ 281 | if len(self.augmentor_images) == 0: 282 | raise IndexError("There are no images in the pipeline. " 283 | "Add a directory using add_directory(), " 284 | "pointing it to a directory containing images.") 285 | 286 | if len(self.operations) == 0: 287 | raise IndexError("There are no operations associated with this pipeline.") 288 | 289 | sample_count = 1 290 | 291 | progress_bar = tqdm(total=n, desc="Executing Pipeline", unit=' Samples', leave=False) 292 | while sample_count <= n: 293 | for augmentor_image in self.augmentor_images: 294 | if sample_count <= n: 295 | self._execute(augmentor_image) 296 | file_name_to_print = os.path.basename(augmentor_image.image_path) 297 | # This is just to shorten very long file names which obscure the progress bar. 
298 | if len(file_name_to_print) >= 30: 299 | file_name_to_print = file_name_to_print[0:10] + "..." + \ 300 | file_name_to_print[-10: len(file_name_to_print)] 301 | progress_bar.set_description("Processing %s" % file_name_to_print) 302 | progress_bar.update(1) 303 | sample_count += 1 304 | progress_bar.close() 305 | 306 | def sample_with_array(self, image_array, ground_truth=None, save_to_disk=False, mode = 'RGB'): 307 | """ 308 | Generate images using a single image in array-like format. 309 | 310 | .. seealso:: 311 | See :func:`keras_image_generator_without_replacement()` for 312 | 313 | :param image_array: The image to pass through the pipeline. 314 | :param save_to_disk: Whether to save to disk or not (default). 315 | :return: 316 | """ 317 | a = AugmentorImage(image_path=None, output_directory=None) 318 | a.image_PIL = Image.fromarray(image_array, mode) 319 | a.image_PIL_ground_truth = Image.fromarray(ground_truth) if ground_truth is not None else None 320 | 321 | return self._execute(a, save_to_disk, list=True) 322 | 323 | @staticmethod 324 | def categorical_labels(numerical_labels): 325 | """ 326 | Return categorical labels for an array of 0-based numerical labels. 327 | 328 | :param numerical_labels: The numerical labels. 329 | :type numerical_labels: Array-like list. 330 | :return: The categorical labels. 331 | """ 332 | # class_labels_np = np.array([x.class_label_int for x in numerical_labels]) 333 | class_labels_np = np.array(numerical_labels) 334 | one_hot_encoding = np.zeros((class_labels_np.size, class_labels_np.max() + 1)) 335 | one_hot_encoding[np.arange(class_labels_np.size), class_labels_np] = 1 336 | one_hot_encoding = one_hot_encoding.astype(np.uint) 337 | 338 | return one_hot_encoding 339 | 340 | def image_generator(self): 341 | while True: 342 | im_index = random.randint(0, len(self.augmentor_images)-1) # Fix for issue 52. 
343 | yield self._execute(self.augmentor_images[im_index], save_to_disk=False), \ 344 | self.augmentor_images[im_index].class_label_int 345 | 346 | def keras_generator(self, batch_size, scaled=True, image_data_format="channels_last"): 347 | """ 348 | Returns an image generator that will sample from the current pipeline 349 | indefinitely, as long as it is called. 350 | 351 | .. warning:: 352 | This function returns images from the current pipeline 353 | **with replacement**. 354 | 355 | You must configure the generator to provide data in the same 356 | format that Keras is configured for. You can use the functions 357 | :func:`keras.backend.image_data_format()` and 358 | :func:`keras.backend.set_image_data_format()` to get and set 359 | Keras' image format at runtime. 360 | 361 | .. code-block:: python 362 | 363 | >>> from keras import backend as K 364 | >>> K.image_data_format() 365 | 'channels_first' 366 | >>> K.set_image_data_format('channels_last') 367 | >>> K.image_data_format() 368 | 'channels_last' 369 | 370 | By default, Augmentor uses ``'channels_last'``. 371 | 372 | :param batch_size: The number of images to return per batch. 373 | :type batch_size: Integer 374 | :param scaled: True (default) if pixels are to be converted 375 | to float32 values between 0 and 1, or False if pixels should be 376 | integer values between 0-255. 377 | :type scaled: Boolean 378 | :param image_data_format: Either ``'channels_last'`` (default) or 379 | ``'channels_first'``. 380 | :type image_data_format: String 381 | :return: An image generator. 382 | """ 383 | 384 | if image_data_format not in ["channels_first", "channels_last"]: 385 | warnings.warn("To work with Keras, must be one of channels_first or channels_last.") 386 | 387 | while True: 388 | 389 | # Randomly select 25 images for augmentation and yield the 390 | # augmented images. 
391 | # X = np.array([]) 392 | # y = np.array([]) 393 | # The correct thing to do here is to pre-allocate 394 | # batch = np.ndarray((batch_size, 28, 28, 1)) 395 | 396 | X = [] 397 | y = [] 398 | 399 | for i in range(batch_size): 400 | 401 | # Pre-allocate 402 | # batch[i:i+28] 403 | 404 | # Select random image, get image array and label 405 | random_image_index = random.randint(0, len(self.augmentor_images)-1) 406 | numpy_array = np.asarray(self._execute(self.augmentor_images[random_image_index], save_to_disk=False)) 407 | label = self.augmentor_images[random_image_index].categorical_label 408 | 409 | # Reshape 410 | w = numpy_array.shape[0] 411 | h = numpy_array.shape[1] 412 | 413 | if np.ndim(numpy_array) == 2: 414 | l = 1 415 | else: 416 | l = np.shape(numpy_array)[2] 417 | 418 | if image_data_format == "channels_last": 419 | numpy_array = numpy_array.reshape(w, h, l) 420 | elif image_data_format == "channels_first": 421 | numpy_array = numpy_array.reshape(l, w, h) 422 | 423 | X.append(numpy_array) 424 | y.append(label) 425 | 426 | X = np.asarray(X) 427 | y = np.asarray(y) 428 | 429 | if scaled: 430 | X = X.astype('float32') 431 | X /= 255 432 | 433 | yield (X, y) 434 | 435 | def keras_generator_from_array(self, images, labels, batch_size, scaled=True, image_data_format="channels_last"): 436 | """ 437 | Returns an image generator that will sample from the current pipeline 438 | indefinitely, as long as it is called. 439 | 440 | .. warning:: 441 | This function returns images from :attr:`images` 442 | **with replacement**. 443 | 444 | You must configure the generator to provide data in the same 445 | format that Keras is configured for. You can use the functions 446 | :func:`keras.backend.image_data_format()` and 447 | :func:`keras.backend.set_image_data_format()` to get and set 448 | Keras' image format at runtime. 449 | 450 | .. 
code-block:: python 451 | 452 | >>> from keras import backend as K 453 | >>> K.image_data_format() 454 | 'channels_first' 455 | >>> K.set_image_data_format('channels_last') 456 | >>> K.image_data_format() 457 | 'channels_last' 458 | 459 | By default, Augmentor uses ``'channels_last'``. 460 | 461 | :param images: The images to augment using the current pipeline. 462 | :type images: Array-like matrix. For greyscale images they can be 463 | in the form ``(l, x, y)`` or ``(l, x, y, 1)``, where 464 | :attr:`l` is the number of images, :attr:`x` is the image width 465 | and :attr:`y` is the image height. For RGB/A images, the matrix 466 | should be in the form ``(l, x, y, n)``, where :attr:`n` is the 467 | number of layers, e.g. 3 for RGB or 4 for RGBA and CMYK. 468 | :param labels: The label associated with each image in :attr:`images`. 469 | :type labels: List. 470 | :param batch_size: The number of images to return per batch. 471 | :type batch_size: Integer 472 | :param scaled: True (default) if pixels are to be converted 473 | to float32 values between 0 and 1, or False if pixels should be 474 | integer values between 0-255. 475 | :type scaled: Boolean 476 | :param image_data_format: Either ``'channels_last'`` (default) or 477 | ``'channels_first'``. When ``'channels_last'`` is specified the 478 | returned batch is in the form ``(batch_size, x, y, num_channels)``, 479 | while for ``'channels_last'`` the batch is returned in the form 480 | ``(batch_size, num_channels, x, y)``. 481 | :param image_data_format: String 482 | :return: An image generator. 
483 | """ 484 | 485 | # Here, we will expect an matrix in the shape (l, x, y) 486 | # where l is the number of images 487 | 488 | # Check if the labels and images align 489 | if len(images) != len(labels): 490 | raise IndexError("The number of images does not match the number of labels.") 491 | 492 | while True: 493 | 494 | X = [] 495 | y = [] 496 | 497 | for i in range(batch_size): 498 | 499 | random_image_index = random.randint(0, len(images)-1) 500 | 501 | # Before passing the image we must format it in a shape that 502 | # Pillow can understand, that is either (w, h) for greyscale 503 | # or (w, h, num_channels) for RGB, RGBA, or CMYK images. 504 | # PIL expects greyscale or B&W images in the form (w, h) 505 | # and RGB(A) images images in the form (w, h, n) where n is 506 | # the number of channels, which is 3 or 4. 507 | # However, Keras often works with greyscale/B&W images in the 508 | # form (w, h, 1). We will convert all images to (w, h) if they 509 | # are not RGB, otherwise we will use (w, h, n). 510 | if np.ndim(images) == 3: 511 | l = 1 512 | else: 513 | l = np.shape(images)[-1] 514 | 515 | w = images[random_image_index].shape[0] 516 | h = images[random_image_index].shape[1] 517 | 518 | if l == 1: 519 | numpy_array = self._execute_with_array(np.reshape(images[random_image_index], (w, h))) 520 | else: 521 | numpy_array = self._execute_with_array(np.reshape(images[random_image_index], (w, h, l))) 522 | 523 | if image_data_format == "channels_first": 524 | numpy_array = numpy_array.reshape(l, w, h) 525 | elif image_data_format == "channels_last": 526 | numpy_array = numpy_array.reshape(w, h, l) 527 | 528 | X.append(numpy_array) 529 | y.append(labels[random_image_index]) 530 | 531 | X = np.asarray(X) 532 | y = np.asarray(y) 533 | 534 | if scaled: 535 | X = X.astype('float32') 536 | X /= 255 537 | 538 | yield(X, y) 539 | 540 | def torch_transform(self): 541 | """ 542 | Returns the pipeline as a function that can be used with torchvision. 543 | 544 | .. 
code-block:: python 545 | 546 | >>> import Augmentor 547 | >>> import torchvision 548 | >>> p = Augmentor.Pipeline() 549 | >>> p.rotate(probability=0.7, max_left_rotate=10, max_right_rotate=10) 550 | >>> p.zoom(probability=0.5, min_factor=1.1, max_factor=1.5) 551 | >>> transforms = torchvision.transforms.Compose([ 552 | >>> p.torch_transform(), 553 | >>> torchvision.transforms.ToTensor(), 554 | >>> ]) 555 | 556 | :return: The pipeline as a function. 557 | """ 558 | def _transform(image): 559 | for operation in self.operations: 560 | r = round(random.uniform(0, 1), 1) 561 | if r <= operation.probability: 562 | image = [image] 563 | image = operation.perform_operation(image) 564 | 565 | return image 566 | 567 | return _transform 568 | 569 | def add_operation(self, operation): 570 | """ 571 | Add an operation directly to the pipeline. Can be used to add custom 572 | operations to a pipeline. 573 | 574 | To add custom operations to a pipeline, subclass from the 575 | Operation abstract base class, overload its methods, and insert the 576 | new object into the pipeline using this method. 577 | 578 | .. seealso:: The :class:`.Operation` class. 579 | 580 | :param operation: An object of the operation you wish to add to the 581 | pipeline. Will accept custom operations written at run-time. 582 | :type operation: Operation 583 | :return: None 584 | """ 585 | if isinstance(operation, Operation): 586 | self.operations.append(operation) 587 | else: 588 | raise TypeError("Must be of type Operation to be added to the pipeline.") 589 | 590 | def remove_operation(self, operation_index=-1): 591 | """ 592 | Remove the operation specified by :attr:`operation_index`, if 593 | supplied, otherwise it will remove the latest operation added to the 594 | pipeline. 595 | 596 | .. seealso:: Use the :func:`status` function to find an operation's 597 | index. 598 | 599 | :param operation_index: The index of the operation to remove. 
600 | :type operation_index: Integer 601 | :return: The removed operation. You can reinsert this at end of the 602 | pipeline using :func:`add_operation` if required. 603 | """ 604 | 605 | # Python's own List exceptions can handle erroneous user input. 606 | self.operations.pop(operation_index) 607 | 608 | def add_further_directory(self, new_source_directory, new_output_directory="output"): 609 | """ 610 | Add a further directory containing images you wish to scan for augmentation. 611 | 612 | :param new_source_directory: The directory to scan for images. 613 | :param new_output_directory: The directory to use for outputted, 614 | augmented images. 615 | :type new_source_directory: String 616 | :type new_output_directory: String 617 | :return: None 618 | """ 619 | if not os.path.exists(new_source_directory): 620 | raise IOError("The path does not appear to exist.") 621 | 622 | self._populate(source_directory=new_source_directory, 623 | output_directory=new_output_directory, 624 | ground_truth_directory=None, 625 | ground_truth_output_directory=new_output_directory) 626 | 627 | def status(self): 628 | """ 629 | Prints the status of the pipeline to the console. If you want to 630 | remove an operation, use the index shown and the 631 | :func:`remove_operation` method. 632 | 633 | .. seealso:: The :func:`remove_operation` function. 634 | 635 | .. seealso:: The :func:`add_operation` function. 636 | 637 | The status includes the number of operations currently attached to 638 | the pipeline, each operation's parameters, the number of images in the 639 | pipeline, and a summary of the images' properties, such as their 640 | dimensions and formats. 
641 | 642 | :return: None 643 | """ 644 | # TODO: Return this as a dictionary of some kind and print from the dict if in console 645 | print("Operations: %s" % len(self.operations)) 646 | 647 | if len(self.operations) != 0: 648 | operation_index = 0 649 | for operation in self.operations: 650 | print("\t%s: %s (" % (operation_index, operation), end="") 651 | for operation_attribute, operation_value in operation.__dict__.items(): 652 | print("%s=%s " % (operation_attribute, operation_value), end="") 653 | print(")") 654 | operation_index += 1 655 | 656 | print("Images: %s" % len(self.augmentor_images)) 657 | 658 | label_pairs = sorted(set([x.label_pair for x in self.augmentor_images])) 659 | 660 | print("Classes: %s" % len(label_pairs)) 661 | 662 | for label_pair in label_pairs: 663 | print ("\tClass index: %s Class label: %s " % (label_pair[0], label_pair[1])) 664 | 665 | if len(self.augmentor_images) != 0: 666 | print("Dimensions: %s" % len(self.distinct_dimensions)) 667 | for distinct_dimension in self.distinct_dimensions: 668 | print("\tWidth: %s Height: %s" % (distinct_dimension[0], distinct_dimension[1])) 669 | print("Formats: %s" % len(self.distinct_formats)) 670 | for distinct_format in self.distinct_formats: 671 | print("\t %s" % distinct_format) 672 | 673 | print("\nYou can remove operations using the appropriate index and the remove_operation(index) function.") 674 | 675 | @staticmethod 676 | def set_seed(seed): 677 | """ 678 | Set the seed of Python's internal random number generator. 679 | 680 | :param seed: The seed to use. Strings or other objects will be hashed. 
681 | :type seed: Integer 682 | :return: None 683 | """ 684 | random.seed(seed) 685 | 686 | # TODO: Implement 687 | # def subtract_mean(self, probability=1): 688 | # # For implementation example, see bottom of: 689 | # # https://patrykchrabaszcz.github.io/Imagenet32/ 690 | # self.add_operation(Mean(probability=probability)) 691 | 692 | def rotate90(self, probability): 693 | """ 694 | Rotate an image by 90 degrees. 695 | 696 | The operation will rotate an image by 90 degrees, and will be 697 | performed with a probability of that specified by the 698 | :attr:`probability` parameter. 699 | 700 | :param probability: A value between 0 and 1 representing the 701 | probability that the operation should be performed. 702 | :type probability: Float 703 | :return: None 704 | """ 705 | if not 0 < probability <= 1: 706 | raise ValueError(Pipeline._probability_error_text) 707 | else: 708 | self.add_operation(Rotate(probability=probability, rotation=90)) 709 | 710 | def rotate180(self, probability): 711 | """ 712 | Rotate an image by 180 degrees. 713 | 714 | The operation will rotate an image by 180 degrees, and will be 715 | performed with a probability of that specified by the 716 | :attr:`probability` parameter. 717 | 718 | :param probability: A value between 0 and 1 representing the 719 | probability that the operation should be performed. 720 | :type probability: Float 721 | :return: None 722 | """ 723 | if not 0 < probability <= 1: 724 | raise ValueError(Pipeline._probability_error_text) 725 | else: 726 | self.add_operation(Rotate(probability=probability, rotation=180)) 727 | 728 | def rotate270(self, probability): 729 | """ 730 | Rotate an image by 270 degrees. 731 | 732 | The operation will rotate an image by 270 degrees, and will be 733 | performed with a probability of that specified by the 734 | :attr:`probability` parameter. 735 | 736 | :param probability: A value between 0 and 1 representing the 737 | probability that the operation should be performed. 
738 | :type probability: Float 739 | :return: None 740 | """ 741 | if not 0 < probability <= 1: 742 | raise ValueError(Pipeline._probability_error_text) 743 | else: 744 | self.add_operation(Rotate(probability=probability, rotation=270)) 745 | 746 | def rotate_random_90(self, probability): 747 | """ 748 | Rotate an image by either 90, 180, or 270 degrees, selected randomly. 749 | 750 | This function will rotate by either 90, 180, or 270 degrees. This is 751 | useful to avoid scenarios where images may be rotated back to their 752 | original positions (such as a :func:`rotate90` and a :func:`rotate270` 753 | being performed directly afterwards. The random rotation is chosen 754 | uniformly from 90, 180, or 270 degrees. The probability controls the 755 | chance of the operation being performed at all, and does not affect 756 | the rotation degree. 757 | 758 | :param probability: A value between 0 and 1 representing the 759 | probability that the operation should be performed. 760 | :type probability: Float 761 | :return: None 762 | """ 763 | if not 0 < probability <= 1: 764 | raise ValueError(Pipeline._probability_error_text) 765 | else: 766 | self.add_operation(Rotate(probability=probability, rotation=-1)) 767 | 768 | def rotate(self, probability, max_left_rotation, max_right_rotation): 769 | """ 770 | Rotate an image by an arbitrary amount. 771 | 772 | The operation will rotate an image by an random amount, within a range 773 | specified. The parameters :attr:`max_left_rotation` and 774 | :attr:`max_right_rotation` allow you to control this range. If you 775 | wish to rotate the images by an exact number of degrees, set both 776 | :attr:`max_left_rotation` and :attr:`max_right_rotation` to the same 777 | value. 778 | 779 | .. note:: This function will rotate **in place**, and crop the largest 780 | possible rectangle from the rotated image. 
781 | 782 | In practice, angles larger than 25 degrees result in images that 783 | do not render correctly, therefore there is a limit of 25 degrees 784 | for this function. 785 | 786 | If this function returns images that are not rendered correctly, then 787 | you must reduce the :attr:`max_left_rotation` and 788 | :attr:`max_right_rotation` arguments! 789 | 790 | :param probability: A value between 0 and 1 representing the 791 | probability that the operation should be performed. 792 | :param max_left_rotation: The maximum number of degrees the image can 793 | be rotated to the left. 794 | :param max_right_rotation: The maximum number of degrees the image can 795 | be rotated to the right. 796 | :type probability: Float 797 | :type max_left_rotation: Integer 798 | :type max_right_rotation: Integer 799 | :return: None 800 | """ 801 | if not 0 < probability <= 1: 802 | raise ValueError(Pipeline._probability_error_text) 803 | if not 0 <= max_left_rotation <= 25: 804 | raise ValueError("The max_left_rotation argument must be between 0 and 25.") 805 | if not 0 <= max_right_rotation <= 25: 806 | raise ValueError("The max_right_rotation argument must be between 0 and 25.") 807 | else: 808 | self.add_operation(RotateRange(probability=probability, max_left_rotation=ceil(max_left_rotation), 809 | max_right_rotation=ceil(max_right_rotation))) 810 | 811 | def rotate_without_crop(self, probability, max_left_rotation, max_right_rotation, expand=False): 812 | """ 813 | Rotate an image without automatically cropping. 814 | 815 | The :attr:`expand` parameter controls whether the image is enlarged 816 | to contain the new rotated images, or if the image size is maintained 817 | Defaults to :attr:`false` so that images maintain their dimensions 818 | when using this function. 819 | 820 | :param probability: A value between 0 and 1 representing the 821 | probability that the operation should be performed. 
822 | :param max_left_rotation: The maximum number of degrees the image can 823 | be rotated to the left. 824 | :param max_right_rotation: The maximum number of degrees the image can 825 | be rotated to the right. 826 | :type probability: Float 827 | :type max_left_rotation: Integer 828 | :type max_right_rotation: Integer 829 | :param expand: Controls whether the image's size should be 830 | increased to accommodate the rotation. Defaults to :attr:`false` 831 | so that images maintain their original dimensions after rotation. 832 | :return: None 833 | """ 834 | self.add_operation(RotateStandard(probability=probability, max_left_rotation=ceil(max_left_rotation), 835 | max_right_rotation=ceil(max_right_rotation), expand=expand)) 836 | 837 | def flip_top_bottom(self, probability): 838 | """ 839 | Flip (mirror) the image along its vertical axis, i.e. from top to 840 | bottom. 841 | 842 | .. seealso:: The :func:`flip_left_right` function. 843 | 844 | :param probability: A value between 0 and 1 representing the 845 | probability that the operation should be performed. 846 | :type probability: Float 847 | :return: None 848 | """ 849 | if not 0 < probability <= 1: 850 | raise ValueError(Pipeline._probability_error_text) 851 | else: 852 | self.add_operation(Flip(probability=probability, top_bottom_left_right="TOP_BOTTOM")) 853 | 854 | def flip_left_right(self, probability): 855 | """ 856 | Flip (mirror) the image along its horizontal axis, i.e. from left to 857 | right. 858 | 859 | .. seealso:: The :func:`flip_top_bottom` function. 860 | 861 | :param probability: A value between 0 and 1 representing the 862 | probability that the operation should be performed. 
863 | :type probability: Float 864 | :return: None 865 | """ 866 | if not 0 < probability <= 1: 867 | raise ValueError(Pipeline._probability_error_text) 868 | else: 869 | self.add_operation(Flip(probability=probability, top_bottom_left_right="LEFT_RIGHT")) 870 | 871 | def flip_random(self, probability): 872 | """ 873 | Flip (mirror) the image along **either** its horizontal or vertical 874 | axis. 875 | 876 | This function mirrors the image along either the horizontal axis or 877 | the vertical axis. The axis is selected randomly. 878 | 879 | :param probability: A value between 0 and 1 representing the 880 | probability that the operation should be performed. 881 | :type probability: Float 882 | :return: None 883 | """ 884 | if not 0 < probability <= 1: 885 | raise ValueError(Pipeline._probability_error_text) 886 | else: 887 | self.add_operation(Flip(probability=probability, top_bottom_left_right="RANDOM")) 888 | 889 | def random_distortion(self, probability, grid_width, grid_height, magnitude): 890 | """ 891 | Performs a random, elastic distortion on an image. 892 | 893 | This function performs a randomised, elastic distortion controlled 894 | by the parameters specified. The grid width and height controls how 895 | fine the distortions are. Smaller sizes will result in larger, more 896 | pronounced, and less granular distortions. Larger numbers will result 897 | in finer, more granular distortions. The magnitude of the distortions 898 | can be controlled using magnitude. This can be random or fixed. 899 | 900 | *Good* values for parameters are between 2 and 10 for the grid 901 | width and height, with a magnitude of between 1 and 10. Using values 902 | outside of these approximate ranges may result in unpredictable 903 | behaviour. 904 | 905 | :param probability: A value between 0 and 1 representing the 906 | probability that the operation should be performed. 907 | :param grid_width: The number of rectangles in the grid's horizontal 908 | axis.
909 | :param grid_height: The number of rectangles in the grid's vertical 910 | axis. 911 | :param magnitude: The magnitude of the distortions. 912 | :type probability: Float 913 | :type grid_width: Integer 914 | :type grid_height: Integer 915 | :type magnitude: Integer 916 | :return: None 917 | """ 918 | if not 0 < probability <= 1: 919 | raise ValueError(Pipeline._probability_error_text) 920 | else: 921 | self.add_operation(Distort(probability=probability, grid_width=grid_width, 922 | grid_height=grid_height, magnitude=magnitude)) 923 | 924 | def gaussian_distortion(self, probability, grid_width, grid_height, magnitude, corner, method, mex=0.5, mey=0.5, 925 | sdx=0.05, sdy=0.05): 926 | """ 927 | Performs a random, elastic gaussian distortion on an image. 928 | 929 | This function performs a randomised, elastic gaussian distortion controlled 930 | by the parameters specified. The grid width and height controls how 931 | fine the distortions are. Smaller sizes will result in larger, more 932 | pronounced, and less granular distortions. Larger numbers will result 933 | in finer, more granular distortions. The magnitude of the distortions 934 | can be controlled using magnitude. This can be random or fixed. 935 | 936 | *Good* values for parameters are between 2 and 10 for the grid 937 | width and height, with a magnitude of between 1 and 10. Using values 938 | outside of these approximate ranges may result in unpredictable 939 | behaviour. 940 | 941 | :param probability: A value between 0 and 1 representing the 942 | probability that the operation should be performed. 943 | :param grid_width: The number of rectangles in the grid's horizontal 944 | axis. 945 | :param grid_height: The number of rectangles in the grid's vertical 946 | axis. 947 | :param magnitude: The magnitude of the distortions. 948 | :param corner: which corner of picture to distort. 
949 | Possible values: "bell"(circular surface applied), "ul"(upper left), 950 | "ur"(upper right), "dl"(down left), "dr"(down right). 951 | :param method: possible values: "in"(apply max magnitude to the chosen 952 | corner), "out"(inverse of method in). 953 | :param mex: used to generate 3d surface for similar distortions. 954 | Surface is based on normal distribution. 955 | :param mey: used to generate 3d surface for similar distortions. 956 | Surface is based on normal distribution. 957 | :param sdx: used to generate 3d surface for similar distortions. 958 | Surface is based on normal distribution. 959 | :param sdy: used to generate 3d surface for similar distortions. 960 | Surface is based on normal distribution. 961 | :type probability: Float 962 | :type grid_width: Integer 963 | :type grid_height: Integer 964 | :type magnitude: Integer 965 | :type corner: String 966 | :type method: String 967 | :type mex: Float 968 | :type mey: Float 969 | :type sdx: Float 970 | :type sdy: Float 971 | :return: None 972 | 973 | For values :attr:`mex`, :attr:`mey`, :attr:`sdx`, and :attr:`sdy` the 974 | surface is based on the normal distribution: 975 | 976 | .. math:: 977 | 978 | e^{- \Big( \\frac{(x-\\text{mex})^2}{\\text{sdx}} + \\frac{(y-\\text{mey})^2}{\\text{sdy}} \Big) } 979 | """ 980 | if not 0 < probability <= 1: 981 | raise ValueError(Pipeline._probability_error_text) 982 | else: 983 | self.add_operation(GaussianDistortion(probability=probability, grid_width=grid_width, 984 | grid_height=grid_height, 985 | magnitude=magnitude, corner=corner, 986 | method=method, mex=mex, 987 | mey=mey, sdx=sdx, sdy=sdy)) 988 | 989 | def zoom(self, probability, min_factor, max_factor): 990 | """ 991 | Zoom in to an image, while **maintaining its size**. The amount by 992 | which the image is zoomed is a randomly chosen value between 993 | :attr:`min_factor` and :attr:`max_factor`. 994 | 995 | Typical values may be ``min_factor=1.1`` and ``max_factor=1.5``. 
996 | 997 | To zoom by a constant amount, set :attr:`min_factor` and 998 | :attr:`max_factor` to the same value. 999 | 1000 | .. seealso:: See :func:`zoom_random` for zooming into random areas 1001 | of the image. 1002 | 1003 | :param probability: A value between 0 and 1 representing the 1004 | probability that the operation should be performed. 1005 | :param min_factor: The minimum factor by which to zoom the image. 1006 | :param max_factor: The maximum factor by which to zoom the image. 1007 | :type probability: Float 1008 | :type min_factor: Float 1009 | :type max_factor: Float 1010 | :return: None 1011 | """ 1012 | if not 0 < probability <= 1: 1013 | raise ValueError(Pipeline._probability_error_text) 1014 | elif min_factor <= 0: 1015 | raise ValueError("The min_factor argument must be greater than 0.") 1016 | else: 1017 | self.add_operation(Zoom(probability=probability, min_factor=min_factor, max_factor=max_factor)) 1018 | 1019 | def zoom_random(self, probability, percentage_area, randomise_percentage_area=False): 1020 | """ 1021 | Zooms into an image at a random location within the image. 1022 | 1023 | You can randomise the zoom level by setting the 1024 | :attr:`randomise_percentage_area` argument to true. 1025 | 1026 | .. seealso:: See :func:`zoom` for zooming into the centre of images. 1027 | 1028 | :param probability: The probability that the function will execute 1029 | when the image is passed through the pipeline. 1030 | :param percentage_area: The area, as a percentage of the current 1031 | image's area, to crop. 1032 | :param randomise_percentage_area: If True, will use 1033 | :attr:`percentage_area` as an upper bound and randomise the crop from 1034 | between 0 and :attr:`percentage_area`. 
1035 | :return: None 1036 | """ 1037 | if not 0 < probability <= 1: 1038 | raise ValueError(Pipeline._probability_error_text) 1039 | elif not 0.1 <= percentage_area < 1: 1040 | raise ValueError("The percentage_area argument must be greater than 0.1 and less than 1.") 1041 | elif not isinstance(randomise_percentage_area, bool): 1042 | raise ValueError("The randomise_percentage_area argument must be True or False.") 1043 | else: 1044 | self.add_operation(ZoomRandom(probability=probability, percentage_area=percentage_area, randomise=randomise_percentage_area)) 1045 | 1046 | def crop_by_size(self, probability, width, height, centre=True): 1047 | """ 1048 | Crop an image according to a set of dimensions. 1049 | 1050 | Crop each image according to :attr:`width` and :attr:`height`, by 1051 | default at the centre of each image, otherwise at a random location 1052 | within the image. 1053 | 1054 | .. seealso:: See :func:`crop_random` to crop a random, non-centred 1055 | area of the image. 1056 | 1057 | If the crop area exceeds the size of the image, this function will 1058 | crop the entire area of the image. 1059 | 1060 | :param probability: The probability that the function will execute 1061 | when the image is passed through the pipeline. 1062 | :param width: The width of the desired crop. 1063 | :param height: The height of the desired crop. 1064 | :param centre: If **True**, crops from the centre of the image, 1065 | otherwise crops at a random location within the image, maintaining 1066 | the dimensions specified. 
1067 | :type probability: Float 1068 | :type width: Integer 1069 | :type height: Integer 1070 | :type centre: Boolean 1071 | :return: None 1072 | """ 1073 | if not 0 < probability <= 1: 1074 | raise ValueError(Pipeline._probability_error_text) 1075 | elif width <= 1: 1076 | raise ValueError("The width argument must be greater than 1.") 1077 | elif height <= 1: 1078 | raise ValueError("The height argument must be greater than 1.") 1079 | elif not isinstance(centre, bool): 1080 | raise ValueError("The centre argument must be True or False.") 1081 | else: 1082 | self.add_operation(Crop(probability=probability, width=width, height=height, centre=centre)) 1083 | 1084 | def crop_centre(self, probability, percentage_area, randomise_percentage_area=False): 1085 | """ 1086 | Crops the centre of an image as a percentage of the image's area. 1087 | 1088 | :param probability: The probability that the function will execute 1089 | when the image is passed through the pipeline. 1090 | :param percentage_area: The area, as a percentage of the current 1091 | image's area, to crop. 1092 | :param randomise_percentage_area: If True, will use 1093 | :attr:`percentage_area` as an upper bound and randomise the crop from 1094 | between 0 and :attr:`percentage_area`. 
1095 | :type probability: Float 1096 | :type percentage_area: Float 1097 | :type randomise_percentage_area: Boolean 1098 | :return: None 1099 | """ 1100 | if not 0 < probability <= 1: 1101 | raise ValueError(Pipeline._probability_error_text) 1102 | elif not 0.1 <= percentage_area < 1: 1103 | raise ValueError("The percentage_area argument must be greater than 0.1 and less than 1.") 1104 | elif not isinstance(randomise_percentage_area, bool): 1105 | raise ValueError("The randomise_percentage_area argument must be True or False.") 1106 | else: 1107 | self.add_operation(CropPercentage(probability=probability, percentage_area=percentage_area, centre=True, 1108 | randomise_percentage_area=randomise_percentage_area)) 1109 | 1110 | def crop_random(self, probability, percentage_area, randomise_percentage_area=False): 1111 | """ 1112 | Crop a random area of an image, based on the percentage area to be 1113 | returned. 1114 | 1115 | This function crops a random area from an image, based on the area you 1116 | specify using :attr:`percentage_area`. 1117 | 1118 | :param probability: The probability that the function will execute 1119 | when the image is passed through the pipeline. 1120 | :param percentage_area: The area, as a percentage of the current 1121 | image's area, to crop. 1122 | :param randomise_percentage_area: If True, will use 1123 | :attr:`percentage_area` as an upper bound and randomise the crop from 1124 | between 0 and :attr:`percentage_area`. 
1125 | :type probability: Float 1126 | :type percentage_area: Float 1127 | :type randomise_percentage_area: Boolean 1128 | :return: None 1129 | """ 1130 | if not 0 < probability <= 1: 1131 | raise ValueError(Pipeline._probability_error_text) 1132 | elif not 0.1 <= percentage_area < 1: 1133 | raise ValueError("The percentage_area argument must be greater than 0.1 and less than 1.") 1134 | elif not isinstance(randomise_percentage_area, bool): 1135 | raise ValueError("The randomise_percentage_area argument must be True or False.") 1136 | else: 1137 | self.add_operation(CropPercentage(probability=probability, percentage_area=percentage_area, centre=False, 1138 | randomise_percentage_area=randomise_percentage_area)) 1139 | 1140 | def histogram_equalisation(self, probability=1.0): 1141 | """ 1142 | Apply histogram equalisation to the image. 1143 | 1144 | :param probability: A value between 0 and 1 representing the 1145 | probability that the operation should be performed. For histogram 1146 | equalisation, it is recommended that the probability be set to 1. 1147 | :type probability: Float 1148 | :return: None 1149 | """ 1150 | if not 0 < probability <= 1: 1151 | raise ValueError(Pipeline._probability_error_text) 1152 | else: 1153 | self.add_operation(HistogramEqualisation(probability=probability)) 1154 | 1155 | def scale(self, probability, scale_factor): 1156 | """ 1157 | Scale (enlarge) an image, while maintaining its aspect ratio. This 1158 | returns an image with larger dimensions than the original image. 1159 | 1160 | Use :func:`resize` to resize an image to absolute pixel values. 1161 | 1162 | :param probability: A value between 0 and 1 representing the 1163 | probability that the operation should be performed. 1164 | :param scale_factor: The factor to scale by, which must be greater 1165 | than 1.0.
1166 | :type probability: Float 1167 | :type scale_factor: Float 1168 | :return: None 1169 | """ 1170 | if not 0 < probability <= 1: 1171 | raise ValueError(Pipeline._probability_error_text) 1172 | elif scale_factor <= 1.0: 1173 | raise ValueError("The scale_factor argument must be greater than 1.") 1174 | else: 1175 | self.add_operation(Scale(probability=probability, scale_factor=scale_factor)) 1176 | 1177 | def resize(self, probability, width, height, resample_filter="BICUBIC"): 1178 | """ 1179 | Resize an image according to a set of dimensions specified by the 1180 | user in pixels. 1181 | 1182 | :param probability: A value between 0 and 1 representing the 1183 | probability that the operation should be performed. For resizing, 1184 | it is recommended that the probability be set to 1. 1185 | :param width: The new width that the image should be resized to. 1186 | :param height: The new height that the image should be resized to. 1187 | :param resample_filter: The resampling filter to use. Must be one of 1188 | BICUBIC, BILINEAR, ANTIALIAS, or NEAREST. 1189 | :type probability: Float 1190 | :type width: Integer 1191 | :type height: Integer 1192 | :type resample_filter: String 1193 | :return: None 1194 | """ 1195 | if not 0 < probability <= 1: 1196 | raise ValueError(Pipeline._probability_error_text) 1197 | elif not width > 1: 1198 | raise ValueError("Width must be greater than 1.") 1199 | elif not height > 1: 1200 | raise ValueError("Height must be greater than 1.") 1201 | elif resample_filter not in Pipeline._legal_filters: 1202 | raise ValueError("The save_filter argument must be one of %s." % Pipeline._legal_filters) 1203 | else: 1204 | self.add_operation(Resize(probability=probability, width=width, height=height, resample_filter=resample_filter)) 1205 | 1206 | def skew_left_right(self, probability, magnitude=1): 1207 | """ 1208 | Skew an image by tilting it left or right by a random amount. 
The 1209 | magnitude of this skew can be set to a maximum using the 1210 | magnitude parameter. This can be either a scalar representing the 1211 | maximum tilt, or vector representing a range. 1212 | 1213 | To see examples of the various skews, see :ref:`perspectiveskewing`. 1214 | 1215 | :param probability: A value between 0 and 1 representing the 1216 | probability that the operation should be performed. 1217 | :param magnitude: The maximum tilt, which must be value between 0.1 1218 | and 1.0, where 1 represents a tilt of 45 degrees. 1219 | :type probability: Float 1220 | :type magnitude: Float 1221 | :return: None 1222 | """ 1223 | if not 0 < probability <= 1: 1224 | raise ValueError(Pipeline._probability_error_text) 1225 | elif not 0 < magnitude <= 1: 1226 | raise ValueError("The magnitude argument must be greater than 0 and less than or equal to 1.") 1227 | else: 1228 | self.add_operation(Skew(probability=probability, skew_type="TILT_LEFT_RIGHT", magnitude=magnitude)) 1229 | 1230 | def skew_top_bottom(self, probability, magnitude=1): 1231 | """ 1232 | Skew an image by tilting it forwards or backwards by a random amount. 1233 | The magnitude of this skew can be set to a maximum using the 1234 | magnitude parameter. This can be either a scalar representing the 1235 | maximum tilt, or vector representing a range. 1236 | 1237 | To see examples of the various skews, see :ref:`perspectiveskewing`. 1238 | 1239 | :param probability: A value between 0 and 1 representing the 1240 | probability that the operation should be performed. 1241 | :param magnitude: The maximum tilt, which must be value between 0.1 1242 | and 1.0, where 1 represents a tilt of 45 degrees. 
1243 | :type probability: Float 1244 | :type magnitude: Float 1245 | :return: None 1246 | """ 1247 | if not 0 < probability <= 1: 1248 | raise ValueError(Pipeline._probability_error_text) 1249 | elif not 0 < magnitude <= 1: 1250 | raise ValueError("The magnitude argument must be greater than 0 and less than or equal to 1.") 1251 | else: 1252 | self.add_operation(Skew(probability=probability, 1253 | skew_type="TILT_TOP_BOTTOM", 1254 | magnitude=magnitude)) 1255 | 1256 | def skew_tilt(self, probability, magnitude=1): 1257 | """ 1258 | Skew an image by tilting in a random direction, either forwards, 1259 | backwards, left, or right, by a random amount. The magnitude of 1260 | this skew can be set to a maximum using the magnitude parameter. 1261 | This can be either a scalar representing the maximum tilt, or 1262 | vector representing a range. 1263 | 1264 | To see examples of the various skews, see :ref:`perspectiveskewing`. 1265 | 1266 | :param probability: A value between 0 and 1 representing the 1267 | probability that the operation should be performed. 1268 | :param magnitude: The maximum tilt, which must be value between 0.1 1269 | and 1.0, where 1 represents a tilt of 45 degrees. 1270 | :type probability: Float 1271 | :type magnitude: Float 1272 | :return: None 1273 | """ 1274 | if not 0 < probability <= 1: 1275 | raise ValueError(Pipeline._probability_error_text) 1276 | elif not 0 < magnitude <= 1: 1277 | raise ValueError("The magnitude argument must be greater than 0 and less than or equal to 1.") 1278 | else: 1279 | self.add_operation(Skew(probability=probability, 1280 | skew_type="TILT", 1281 | magnitude=magnitude)) 1282 | 1283 | def skew_corner(self, probability, magnitude=1): 1284 | """ 1285 | Skew an image towards one corner, randomly by a random magnitude. 1286 | 1287 | To see examples of the various skews, see :ref:`perspectiveskewing`. 
1288 | 1289 | :param probability: A value between 0 and 1 representing the 1290 | probability that the operation should be performed. 1291 | :param magnitude: The maximum skew, which must be value between 0.1 1292 | and 1.0. 1293 | :return: 1294 | """ 1295 | if not 0 < probability <= 1: 1296 | raise ValueError(Pipeline._probability_error_text) 1297 | elif not 0 < magnitude <= 1: 1298 | raise ValueError("The magnitude argument must be greater than 0 and less than or equal to 1.") 1299 | else: 1300 | self.add_operation(Skew(probability=probability, 1301 | skew_type="CORNER", 1302 | magnitude=magnitude)) 1303 | 1304 | def skew(self, probability, magnitude=1): 1305 | """ 1306 | Skew an image in a random direction, either left to right, 1307 | top to bottom, or one of 8 corner directions. 1308 | 1309 | To see examples of all the skew types, see :ref:`perspectiveskewing`. 1310 | 1311 | :param probability: A value between 0 and 1 representing the 1312 | probability that the operation should be performed. 1313 | :param magnitude: The maximum skew, which must be value between 0.1 1314 | and 1.0. 1315 | :type probability: Float 1316 | :type magnitude: Float 1317 | :return: None 1318 | """ 1319 | if not 0 < probability <= 1: 1320 | raise ValueError(Pipeline._probability_error_text) 1321 | elif not 0 < magnitude <= 1: 1322 | raise ValueError("The magnitude argument must be greater than 0 and less than or equal to 1.") 1323 | else: 1324 | self.add_operation(Skew(probability=probability, 1325 | skew_type="RANDOM", 1326 | magnitude=magnitude)) 1327 | 1328 | def shear(self, probability, max_shear_left, max_shear_right): 1329 | """ 1330 | Shear the image by a specified number of degrees. 1331 | 1332 | In practice, shear angles of more than 25 degrees can cause 1333 | unpredictable behaviour. If you are observing images that are 1334 | incorrectly rendered (e.g. they do not contain any information) 1335 | then reduce the shear angles. 
1336 | 1337 | :param probability: The probability that the operation is performed. 1338 | :param max_shear_left: The max number of degrees to shear to the left. 1339 | Cannot be larger than 25 degrees. 1340 | :param max_shear_right: The max number of degrees to shear to the 1341 | right. Cannot be larger than 25 degrees. 1342 | :return: None 1343 | """ 1344 | if not 0 < probability <= 1: 1345 | raise ValueError(Pipeline._probability_error_text) 1346 | elif not 0 < max_shear_left <= 25: 1347 | raise ValueError("The max_shear_left argument must be between 0 and 25.") 1348 | elif not 0 < max_shear_right <= 25: 1349 | raise ValueError("The max_shear_right argument must be between 0 and 25.") 1350 | else: 1351 | self.add_operation(Shear(probability=probability, 1352 | max_shear_left=max_shear_left, 1353 | max_shear_right=max_shear_right)) 1354 | 1355 | def greyscale(self, probability): 1356 | """ 1357 | Convert images to greyscale. For this operation, setting the 1358 | :attr:`probability` to 1.0 is recommended. 1359 | 1360 | .. seealso:: The :func:`black_and_white` function. 1361 | 1362 | :param probability: A value between 0 and 1 representing the 1363 | probability that the operation should be performed. For this operation, 1364 | it is recommended that the probability be set to 1. 1365 | :type probability: Float 1366 | :return: None 1367 | """ 1368 | if not 0 < probability <= 1: 1369 | raise ValueError(Pipeline._probability_error_text) 1370 | else: 1371 | self.add_operation(Greyscale(probability=probability)) 1372 | 1373 | def black_and_white(self, probability, threshold=128): 1374 | """ 1375 | Convert images to black and white. In other words convert the image 1376 | to use a 1-bit, binary palette. The threshold defaults to 128, 1377 | but can be controlled using the :attr:`threshold` parameter. 1378 | 1379 | .. seealso:: The :func:`greyscale` function.
1380 | 1381 | :param probability: A value between 0 and 1 representing the 1382 | probability that the operation should be performed. For this operation, 1383 | it is recommended that the probability be set to 1. 1384 | :param threshold: A value between 0 and 255 which controls the 1385 | threshold point at which each pixel is converted to either black 1386 | or white. Any values above this threshold are converted to white, and 1387 | any values below this threshold are converted to black. 1388 | :type probability: Float 1389 | :type threshold: Integer 1390 | :return: None 1391 | """ 1392 | if not 0 < probability <= 1: 1393 | raise ValueError(Pipeline._probability_error_text) 1394 | elif not 0 <= threshold <= 255: 1395 | raise ValueError("The threshold must be between 0 and 255.") 1396 | else: 1397 | self.add_operation(BlackAndWhite(probability=probability, threshold=threshold)) 1398 | 1399 | def invert(self, probability): 1400 | """ 1401 | Invert an image. For this operation, setting the 1402 | :attr:`probability` to 1.0 is recommended. 1403 | 1404 | .. warning:: This function will cause errors if used on binary, 1-bit 1405 | palette images (e.g. black and white). 1406 | 1407 | :param probability: A value between 0 and 1 representing the 1408 | probability that the operation should be performed. For this operation, 1409 | it is recommended that the probability be set to 1. 1410 | :return: None 1411 | """ 1412 | if not 0 < probability <= 1: 1413 | raise ValueError(Pipeline._probability_error_text) 1414 | else: 1415 | self.add_operation(Invert(probability=probability)) 1416 | 1417 | def random_erasing(self, probability, rectangle_area): 1418 | """ 1419 | Work in progress. This operation performs a Random Erasing operation, 1420 | as described in 1421 | `https://arxiv.org/abs/1708.04896 `_ 1422 | by Zhong et al. 1423 | 1424 | Its purpose is to make models robust to occlusion, by randomly 1425 | replacing rectangular regions with random pixel values.
1426 | 1427 | For greyscale images the random pixel values will also be greyscale, 1428 | and for RGB images the random pixel values will be in RGB. 1429 | 1430 | This operation is subject to change; the original work describes 1431 | several ways of filling the random regions, including a random 1432 | solid colour or greyscale value. Currently this operation uses 1433 | the method which yielded the best results in the tests performed 1434 | by Zhong et al. 1435 | 1436 | :param probability: A value between 0 and 1 representing the 1437 | probability that the operation should be performed. 1438 | :param rectangle_area: The percentage area of the image to occlude 1439 | with the random rectangle; must be greater than 0.1 and at most 1. 1440 | :return: None 1441 | """ 1442 | if not 0 < probability <= 1: 1443 | raise ValueError(Pipeline._probability_error_text) 1444 | elif not 0.1 < rectangle_area <= 1: 1445 | raise ValueError("The rectangle_area must be between 0.1 and 1.") 1446 | else: 1447 | self.add_operation(RandomErasing(probability=probability, rectangle_area=rectangle_area)) 1448 | 1449 | def ground_truth(self, ground_truth_directory): 1450 | """ 1451 | Specifies a directory containing corresponding images that 1452 | constitute respective ground truth images for the images 1453 | in the current pipeline. 1454 | 1455 | This function will search the directory specified by 1456 | :attr:`ground_truth_directory` and will associate each ground truth 1457 | image with the images in the pipeline by file name. 1458 | 1459 | Therefore, an image titled ``cat321.jpg`` will match with the 1460 | image ``cat321.jpg`` in the :attr:`ground_truth_directory`. 1461 | The function respects each image's label, therefore the image 1462 | named ``cat321.jpg`` with the label ``cat`` will match the image 1463 | ``cat321.jpg`` in the subdirectory ``cat`` relative to 1464 | :attr:`ground_truth_directory`.
1465 | 1466 | Typically used to specify a set of ground truth or gold standard 1467 | images that should be augmented alongside the original images 1468 | of a dataset, such as image masks or semantic segmentation ground 1469 | truth images. 1470 | 1471 | A number of such data sets are openly available, see for example 1472 | `https://arxiv.org/pdf/1704.06857.pdf `_ 1473 | (Garcia-Garcia et al., 2017). 1474 | 1475 | :param ground_truth_directory: A directory containing the 1476 | ground truth images that correspond to the images in the 1477 | current pipeline. 1478 | :type ground_truth_directory: String 1479 | :return: None. 1480 | """ 1481 | 1482 | num_of_ground_truth_images_added = 0 1483 | 1484 | # Progress bar 1485 | progress_bar = tqdm(total=len(self.augmentor_images), desc="Processing", unit=' Images', leave=False) 1486 | 1487 | if len(self.class_labels) == 1: 1488 | for augmentor_image_idx in range(len(self.augmentor_images)): 1489 | ground_truth_image = os.path.join(ground_truth_directory, 1490 | self.augmentor_images[augmentor_image_idx].image_file_name) 1491 | if os.path.isfile(ground_truth_image): 1492 | self.augmentor_images[augmentor_image_idx].ground_truth = ground_truth_image 1493 | num_of_ground_truth_images_added += 1 1494 | else: 1495 | for i in range(len(self.class_labels)): 1496 | for augmentor_image_idx in range(len(self.augmentor_images)): 1497 | ground_truth_image = os.path.join(ground_truth_directory, 1498 | self.augmentor_images[augmentor_image_idx].class_label, 1499 | self.augmentor_images[augmentor_image_idx].image_file_name) 1500 | if os.path.isfile(ground_truth_image): 1501 | if self.augmentor_images[augmentor_image_idx].class_label == self.class_labels[i][0]: 1502 | # Check files are the same size. There may be a better way to do this. 
1503 | original_image_dimensions = \ 1504 | Image.open(self.augmentor_images[augmentor_image_idx].image_path).size 1505 | ground_image_dimensions = Image.open(ground_truth_image).size 1506 | if original_image_dimensions == ground_image_dimensions: 1507 | self.augmentor_images[augmentor_image_idx].ground_truth = ground_truth_image 1508 | num_of_ground_truth_images_added += 1 1509 | progress_bar.update(1) 1510 | 1511 | progress_bar.close() 1512 | 1513 | # May not be required after all, check later. 1514 | if num_of_ground_truth_images_added != 0: 1515 | self.process_ground_truth_images = True 1516 | 1517 | print("%s ground truth image(s) found." % num_of_ground_truth_images_added) 1518 | 1519 | def get_ground_truth_paths(self): 1520 | """ 1521 | Returns a list of image and ground truth image path pairs. Used for 1522 | verification purposes to ensure the ground truth images match the 1523 | images contained in the pipeline. Each pair is also printed. 1524 | 1525 | :return: A list of tuples containing the image path and ground truth 1526 | path pairs. 1527 | """ 1528 | paths = [] 1529 | 1530 | for augmentor_image in self.augmentor_images: 1531 | print("Image path: %s\nGround truth path: %s\n---\n" % (augmentor_image.image_path, augmentor_image.ground_truth)) 1532 | paths.append((augmentor_image.image_path, augmentor_image.ground_truth)) 1533 | 1534 | return paths 1535 | -------------------------------------------------------------------------------- /tnseg/Augmentor/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The Augmentor image augmentation library. 3 | 4 | Augmentor is a software package for augmenting image data. It provides a number of utilities that aid augmentation \ 5 | in an automated manner. The aim of the package is to make augmentation for machine learning tasks less prone to \ 6 | error, more reproducible, more efficient, and easier to perform. 7 | 8 | .. moduleauthor:: Marcus D.
Bloice 9 | :platform: Windows, Linux, Macintosh 10 | :synopsis: An image augmentation library for Machine Learning. 11 | 12 | """ 13 | 14 | from .Pipeline import Pipeline 15 | 16 | __author__ = """Marcus D. Bloice""" 17 | __email__ = 'marcus.bloice@medunigraz.at' 18 | __version__ = '0.2.0' 19 | 20 | __all__ = ['Pipeline'] 21 | -------------------------------------------------------------------------------- /tnseg/__init__.py: -------------------------------------------------------------------------------- 1 | from . import opts 2 | -------------------------------------------------------------------------------- /tnseg/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import os 4 | import pydicom as pyd 5 | from glob import glob 6 | from keras.preprocessing.image import ImageDataGenerator 7 | from keras.utils import to_categorical 8 | import bisect 9 | import random 10 | import math 11 | from Augmentor import * 12 | 13 | def import_dicom_data(path): 14 | data_path = path + '/images/' 15 | annot_path = path + '/labels/' 16 | data_list = glob(data_path + '*.dcm') 17 | annot_list = glob(annot_path + '*.dcm') 18 | N = len(data_list) 19 | data = [] 20 | annot = [] 21 | annot_frames = np.zeros((N)) 22 | print('Data Image Resolutions') 23 | for i in range(N): 24 | x = pyd.read_file(data_list[i]).pixel_array 25 | x = x[:len(x) / 2] 26 | y = pyd.read_file(annot_list[i]).pixel_array 27 | y = y[:len(y) / 2] 28 | n_frame = 0 29 | for j in range(y.shape[0]): 30 | if np.where(y[j] == 1)[0].size > 0: 31 | n_frame += 1 32 | annot_frames[i] = n_frame 33 | print(x.shape, n_frame) 34 | data.append(x) 35 | annot.append(y) 36 | return data, annot 37 | 38 | def zeropad(data, annot, h_max, w_max): 39 | # If the data is a list of images of different resolutions 40 | # useful in testing 41 | if isinstance(data, list): 42 | n = len(data) 43 | data_pad = np.zeros((n, h_max, w_max)) 44 | 
annot_pad = np.zeros((n, h_max, w_max)) 45 | for i in range(n): 46 | pad_l1 = (h_max - data[i].shape[0]) // 2 47 | pad_l2 = (h_max - data[i].shape[0]) - (h_max - data[i].shape[0]) // 2 48 | pad_h1 = (w_max - data[i].shape[1]) // 2 49 | pad_h2 = (w_max - data[i].shape[1]) - (w_max - data[i].shape[1]) // 2 50 | data_pad[i] = np.pad(data[i], ((pad_l1, pad_l2), (pad_h1, pad_h2)), 'constant', 51 | constant_values=((0, 0), (0, 0))) 52 | annot_pad[i] = np.pad(annot[i], ((pad_l1, pad_l2), (pad_h1, pad_h2)), 'constant', 53 | constant_values=((0, 0), (0, 0))) 54 | # If data is a numpy array with images of same resolution 55 | else: 56 | pad_l1 = (h_max - data.shape[1]) // 2 57 | pad_l2 = (h_max - data.shape[1]) - (h_max - data.shape[1]) // 2 58 | pad_h1 = (w_max - data.shape[2]) // 2 59 | pad_h2 = (w_max - data.shape[2]) - (w_max - data.shape[2]) // 2 60 | 61 | data_pad = np.pad(data, ((0, 0), (pad_l1, pad_l2), (pad_h1, pad_h2)), 'constant', 62 | constant_values=((0, 0), (0, 0), (0, 0))) 63 | annot_pad = np.pad(annot, ((0, 0), (pad_l1, pad_l2), (pad_h1, pad_h2)), 'constant', 64 | constant_values=((0, 0), (0, 0), (0, 0))) 65 | return data_pad, annot_pad 66 | 67 | def data_augment(imgs, lb): 68 | p = Pipeline() 69 | p.rotate(probability=0.7, max_left_rotation=10, max_right_rotation=10) 70 | imgs_temp, lb_temp = np.zeros(imgs.shape), np.zeros(imgs.shape) 71 | for i in range(imgs.shape[0]): 72 | pil_images = p.sample_with_array(imgs[i], ground_truth=lb[i], mode='L') 73 | imgs_temp[i], lb_temp[i] = np.asarray(pil_images[0]), np.asarray(pil_images[1]) 74 | 75 | return imgs_temp, lb_temp 76 | 77 | def get_weighted_batch(imgs, labels, batch_size, data_aug, high_skew=False): 78 | while 1: 79 | thy_re = [np.count_nonzero(labels[i] == 1) * 1.0 / np.prod(labels[i].shape) for i in range(imgs.shape[0])] 80 | if high_skew==True: 81 | thy_re = [el**2 for el in thy_re] 82 | cumul = [thy_re[0]] 83 | for item in thy_re[1:]: cumul.append(cumul[-1] + item) 84 | total_prob = sum(thy_re) 85 | 86 | 
ar_inds = [bisect.bisect_right(cumul, random.uniform(0, total_prob)) for i in range(batch_size)] 87 | lb, batch_imgs = labels[ar_inds], imgs[ar_inds] 88 | l, r, t, b = 0, batch_imgs.shape[1], 0, batch_imgs.shape[2] 89 | for i in range(batch_imgs.shape[1]): 90 | if np.all(batch_imgs[:, i, :] == 0): 91 | l = i + 1 92 | else: 93 | break 94 | for i in range(batch_imgs.shape[1] - 1, -1, -1): 95 | if np.all(batch_imgs[:, i, :] == 0): 96 | r = i 97 | else: 98 | break 99 | for i in range(batch_imgs.shape[2]): 100 | if np.all(batch_imgs[:, :, i] == 0): 101 | t = i + 1 102 | else: 103 | break 104 | for i in range(batch_imgs.shape[2] - 1, -1, -1): 105 | if np.all(batch_imgs[:, :, i] == 0): 106 | b = i 107 | else: 108 | break 109 | l, r, t, b = (l // 16) * 16, math.ceil(r * 1.0 / 16) * 16, (t // 16) * 16, math.ceil(b * 1.0 / 16) * 16 110 | l, r, t, b = int(l), int(r), int(t), int(b) 111 | batch_imgs, lb = batch_imgs[:, l:r, t:b], lb[:, l:r, t:b] 112 | if (data_aug): 113 | batch_imgs, lb = data_augment(batch_imgs, lb) 114 | yield np.expand_dims(batch_imgs, axis=3),np.expand_dims(lb, axis=3) 115 | 116 | def get_weighted_batch_window_2d(imgs, labels, batch_size, data_aug, n_window=0, high_skew=False): 117 | # a=0 118 | # if a==0: 119 | # print('datagen') 120 | while 1: 121 | thy_re = [np.count_nonzero(labels[i] == 1) * 1.0 / np.prod(labels[i].shape) for i in range(imgs.shape[0])] 122 | if high_skew==True: 123 | thy_re = [el**2 for el in thy_re] 124 | cumul = [thy_re[0]] 125 | for item in thy_re[1:]: cumul.append(cumul[-1] + item) 126 | total_prob = sum(thy_re) 127 | 128 | ar_inds = [bisect.bisect_right(cumul, random.uniform(0, total_prob)) for i in range(batch_size)] 129 | if n_window==0: 130 | batch_imgs = imgs[ar_inds] 131 | # Get n_window frames per index. 
132 | else: 133 | batch_imgs = np.zeros((batch_size*n_window,imgs.shape[1],imgs.shape[2])) 134 | for i in range(batch_size): 135 | if ar_inds[i]==0: 136 | ar_inds[i] = 1 137 | elif ar_inds[i] == len(imgs)-1: 138 | ar_inds[i] -= 1 139 | batch_imgs[n_window*i:n_window*(i+1)] = imgs[ar_inds[i]-1:ar_inds[i]+2] 140 | lb = labels[ar_inds] 141 | l, r, t, b = 0, batch_imgs.shape[1], 0, batch_imgs.shape[2] 142 | for i in range(batch_imgs.shape[1]): 143 | if np.all(batch_imgs[:, i, :] == 0): 144 | l = i + 1 145 | else: 146 | break 147 | for i in range(batch_imgs.shape[1] - 1, -1, -1): 148 | if np.all(batch_imgs[:, i, :] == 0): 149 | r = i 150 | else: 151 | break 152 | for i in range(batch_imgs.shape[2]): 153 | if np.all(batch_imgs[:, :, i] == 0): 154 | t = i + 1 155 | else: 156 | break 157 | for i in range(batch_imgs.shape[2] - 1, -1, -1): 158 | if np.all(batch_imgs[:, :, i] == 0): 159 | b = i 160 | else: 161 | break 162 | l, r, t, b = (l // 16) * 16, math.ceil(r * 1.0 / 16) * 16, (t // 16) * 16, math.ceil(b * 1.0 / 16) * 16 163 | l, r, t, b = int(l), int(r), int(t), int(b) 164 | batch_imgs, lb = batch_imgs[:, l:r, t:b], lb[:, l:r, t:b] 165 | # batch_imgs_3d = np.zeros((batch_size,imgs.shape[1], imgs.shape[2], n_window)) 166 | # k=0 167 | # for i in range(batch_size): 168 | # for j in range(n_window): 169 | # batch_imgs_3d[i,:,:,j] = batch_imgs[k,:,:] 170 | # k += 1 171 | batch_imgs = np.array([np.rollaxis(batch_imgs[n_window*i:n_window*(i+1)],0,3) for i in range(batch_size)]) 172 | if (data_aug): 173 | batch_imgs, lb = batch_imgs, lb#data_augment(batch_imgs, lb) 174 | # print('batch = ',batch_imgs.shape, lb.shape) 175 | yield batch_imgs,np.expand_dims(lb, axis=3) 176 | 177 | def get_max_dimensions(data_list): 178 | return 320, 448 179 | 180 | def create_generators(datadir=None, batch_size=64, augmentation_args=None,\ 181 | model='unet', zero_padding=[0,0], data_skew=False, validation_index=None, window=0): 182 | 183 | # Load data from the data directory 184 | if 
datadir==None: 185 | raise Exception("Data directory not specified") 186 | data_list, annot_list = import_dicom_data(datadir) 187 | print(len(data_list)) 188 | 189 | # Get the max dimensions of the DICOM frames, and zeropad all images 190 | h_max, w_max = get_max_dimensions(data_list) 191 | for i, data in enumerate(data_list): 192 | data_list[i], annot_list[i] = zeropad(data_list[i], annot_list[i], h_max, w_max) 193 | 194 | # Get train and validation data 195 | N = len(data_list) 196 | if validation_index==None: 197 | raise Exception("Please specify validation indices") 198 | else: 199 | trn_imgs = [] 200 | trn_labels = [] 201 | val_imgs = [] 202 | val_labels = [] 203 | for i in range(len(data_list)): 204 | if i in validation_index: 205 | val_imgs.append(data_list[i]) 206 | val_labels.append(annot_list[i]) 207 | else: 208 | trn_imgs.append(data_list[i]) 209 | trn_labels.append(annot_list[i]) 210 | val_imgs = np.concatenate(val_imgs,axis=0) 211 | val_labels = np.concatenate(val_labels,axis=0) 212 | trn_imgs = np.concatenate(trn_imgs,axis=0) 213 | trn_labels = np.concatenate(trn_labels,axis=0) 214 | print(val_imgs.shape, val_labels.shape, trn_imgs.shape, trn_labels.shape) 215 | 216 | # Data generator for augmentation 217 | if augmentation_args !=None: 218 | data_augment=True 219 | datagen = ImageDataGenerator( 220 | rotation_range=augmentation_args['rotation_range'], 221 | width_shift_range=augmentation_args['width_shift_range'], 222 | height_shift_range=augmentation_args['height_shift_range'], 223 | shear_range=augmentation_args['shear_range'], 224 | zoom_range=augmentation_args['zoom_range'], 225 | fill_mode=augmentation_args['fill_mode']) 226 | else: 227 | data_augment=False 228 | datagen = ImageDataGenerator( 229 | rotation_range=0., 230 | width_shift_range=0., 231 | height_shift_range=0., 232 | shear_range=0., 233 | zoom_range=0., 234 | horizontal_flip=0., 235 | fill_mode=0.) 
236 | 237 | # Get model specific data generators 238 | if model in ['unet', 'dilated-unet', 'dilated-densenet']: 239 | if data_skew==True: 240 | train_generator = get_weighted_batch(trn_imgs, trn_labels, batch_size, data_augment) 241 | val_generator = get_weighted_batch(val_imgs, val_labels, batch_size, data_augment, high_skew=True) 242 | else: 243 | train_generator = datagen.flow(x=np.expand_dims(trn_imgs, axis=3), y=np.expand_dims(trn_labels, axis=3), batch_size=16) 244 | val_generator = datagen.flow(x=np.expand_dims(val_imgs, axis=3), y=np.expand_dims(val_labels, axis=3), batch_size=16) 245 | elif model=='window-unet': 246 | train_generator = get_weighted_batch_window_2d(trn_imgs, trn_labels, batch_size, data_augment, window) 247 | val_generator = get_weighted_batch_window_2d(val_imgs, val_labels, batch_size, data_augment, window, high_skew=True) 248 | 249 | return train_generator, val_generator 250 | -------------------------------------------------------------------------------- /tnseg/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | sys.path.append('.') 4 | 5 | from dataset import * 6 | from scipy.misc import imsave 7 | import matplotlib.pyplot as plt 8 | import os 9 | import pdb 10 | import pydensecrf.densecrf as dcrf 11 | from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral, create_pairwise_gaussian 12 | from loss import * 13 | from tqdm import tqdm 14 | 15 | from ufarray import * 16 | sys.setrecursionlimit(10000) 17 | 18 | # save error plots 19 | def eval_error_plots(out, output_dir): 20 | # get train and val acc and loss 21 | loss = out.history['loss'] 22 | val_loss = out.history['val_loss'] 23 | acc_key = [i for i in out.history.keys() if ('val' not in i and 'loss' not in i)][0] 24 | acc = out.history[acc_key] 25 | val_acc = out.history['val_' + acc_key] 26 | 27 | # Plot and save them 28 | plt.figure() 29 | plt.plot(loss, 'b', label='Training') 30 
| plt.plot(val_loss, 'r', label='Validation') 31 | plt.title('Training vs Validation loss') 32 | plt.legend() 33 | plt.savefig(output_dir + 'plot_loss.png', dpi=300) 34 | plt.close() 35 | plt.figure() 36 | plt.plot(acc, 'b', label='Training') 37 | plt.plot(val_acc, 'r', label='Validation') 38 | plt.title('Training vs Validation ' + acc_key) 39 | plt.legend() 40 | plt.savefig(output_dir + 'plot_accuracy.png', dpi=300) 41 | plt.close() 42 | 43 | 44 | def post_processing(data, probas): 45 | [n,h,w] = data.shape 46 | n_labels = 2 47 | pred_maps = np.zeros(data.shape) 48 | print 'postprocessing:', data.shape, probas.shape 49 | for i in tqdm(range(n)): 50 | img = data[i][...,np.newaxis] 51 | proba = probas[i] 52 | labels = np.zeros((2,img.shape[0],img.shape[1])) 53 | labels[0] = 1-proba 54 | labels[1] = proba 55 | 56 | U = unary_from_softmax(labels) # note: num classes is first dim 57 | pairwise_energy = create_pairwise_bilateral(sdims=(50,50), schan=(5,), img=img, chdim=2) 58 | pairwise_gaussian = create_pairwise_gaussian(sdims=(3, 3), shape=img.shape[:2]) 59 | 60 | d = dcrf.DenseCRF2D(w, h, n_labels) 61 | d.setUnaryEnergy(U) 62 | d.addPairwiseEnergy(pairwise_gaussian, compat=3, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC) 63 | d.addPairwiseEnergy(pairwise_energy, compat=5, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC) # `compat` is the "strength" of this potential. 
64 | 65 | Q = d.inference(50) 66 | pred_maps[i] = np.argmax(Q, axis=0).reshape((h,w)) 67 | return pred_maps 68 | 69 | 70 | def remove_smaller_components(im): 71 | if im.max()==0.0: 72 | return im 73 | sizes = {} 74 | im_ = im.copy() 75 | def dfs(i,j, root, key_elem, change_to): 76 | if i>=0 and i=0 and j< im_.shape[1] and im_[i,j] ==key_elem: 77 | im_[i][j]=change_to 78 | if root in sizes: 79 | sizes[root] += 1 80 | else: 81 | sizes[root] =0 82 | dfs(i-1,j,root,key_elem, change_to) 83 | dfs(i+1,j,root,key_elem, change_to) 84 | dfs(i,j-1,root,key_elem, change_to) 85 | dfs(i,j+1,root,key_elem, change_to) 86 | 87 | for i in range(im_.shape[0]): 88 | for j in range(im_.shape[1]): 89 | dfs(i,j, tuple((i,j)),1,2) 90 | 91 | big_comp = max(sizes, key=sizes.get) 92 | dfs(big_comp[0], big_comp[1], big_comp, 2,1) 93 | im_[im_>1] = 0 94 | return im_ 95 | 96 | def evaluate_test_folder(model, save_path=None, test_path=None, postproc=False, n_window=3): 97 | # Convert the data into input for the UNet 98 | img_path_list = [path for path in os.listdir(test_path + 'images/')] 99 | data = np.array([plt.imread(test_path + 'images/' + path) for path in img_path_list]) 100 | annot = np.array([plt.imread(test_path + 'groundtruth/' + path) for path in img_path_list]) 101 | # print data.min(), data.max(), annot.max() 102 | n_frame = len(data) 103 | # print(data.shape, annot.shape) 104 | if(data.shape[1]%16!=0 or data.shape[2]%16!=0): 105 | pad_width_h1 = int(np.floor((16-data.shape[1]%16)/2)) 106 | pad_width_h2 = 16 - data.shape[1]%16 - pad_width_h1 107 | pad_width_w1 = int(np.floor((16-data.shape[2]%16)/2)) 108 | pad_width_w2 = 16 - data.shape[2]%16 - pad_width_w1 109 | data = np.pad(data,((0,0),(pad_width_h1,pad_width_h2),(pad_width_w1,pad_width_w2)),'constant') 110 | annot = np.pad(annot,((0,0),(pad_width_h1,pad_width_h2),(pad_width_w1,pad_width_w2)),'constant') 111 | # print(data.shape, annot.shape) 112 | #data_pad, annot_pad = zeropad(data, annot, h_max, w_max) 113 | if n_window==0: 
114 | probas = model.predict(data[...,np.newaxis]*255., batch_size=8)[...,0] 115 | else: 116 | data_window = np.zeros((n_frame, data.shape[1], data.shape[2], n_window)) 117 | n_window_half = int((n_window-1)/2) 118 | for i in range(n_window_half,n_frame-n_window_half): 119 | data_window[i] = np.rollaxis(data[i-n_window_half:i+n_window_half+1],0,3) 120 | # print(data_window.shape) 121 | probas = model.predict(data_window*255.)[...,0] 122 | if postproc==True: 123 | probas = post_processing(data*255., probas) 124 | 125 | # Threshold predictions 126 | thresh = 0.5 127 | pred_maps = probas.copy() 128 | pred_maps[probas>=thresh] = 1#255 129 | pred_maps[probas0.04: 150 | # annot_.append(annot[i]) 151 | # annot_pred_.append(annot_pred[i]) 152 | # dice_coef_.append(dice_coef[i]) 153 | # return np.mean(np.array(dice_coef_)) 154 | 155 | dice_coef_avg = 0.0 156 | for i in range(n_frame): 157 | dice_coef_avg += dice_coef_numpy(annot[i], pred_maps[i]) 158 | dice_coef_avg /= n_frame 159 | print('Folder dice coef pred maps= ',dice_coef_avg) 160 | dice_coef = dice_coef_numpy(annot, probas) 161 | print('Folder dice coef = ',dice_coef) 162 | 163 | # Save the images onto disk 164 | if save_path !=None: 165 | for i in range(n_frame): 166 | plt.figure() 167 | ax = plt.subplot('131') 168 | ax.imshow(data[i], cmap='gray') 169 | ax.set_title('Actual Image') 170 | ax = plt.subplot('132') 171 | ax.imshow(annot[i], cmap='gray') 172 | ax.set_title('True Annotation') 173 | ax = plt.subplot('133') 174 | ax.imshow(pred_maps[i], cmap='gray') 175 | ax.set_title('Predicted Annotation') 176 | plt.savefig(save_path + img_path_list[i]) 177 | plt.close() 178 | return dice_coef, pred_maps 179 | 180 | -------------------------------------------------------------------------------- /tnseg/loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from keras import backend as K 3 | import numpy as np 4 | import tensorflow as tf 5 | from 
keras.losses import binary_crossentropy 6 | 7 | def dice_coef(y_true, y_pred): 8 | smooth = 1. 9 | y_true_f = K.flatten(y_true) 10 | y_pred_f = K.flatten(y_pred) 11 | intersection = K.sum(y_true_f * y_pred_f) 12 | return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) 13 | 14 | def dice_coef_numpy(y_true, y_pred): 15 | smooth = 1. 16 | y_true_f = np.ndarray.flatten(y_true) 17 | y_pred_f = np.ndarray.flatten(y_pred) 18 | intersection = np.sum(y_true_f * y_pred_f) 19 | return (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth) 20 | 21 | def bin_crossentropy_loss(y_true, y_pred): 22 | return binary_crossentropy(y_true, y_pred) 23 | 24 | def iou_score(y_true, y_pred): 25 | smooth = 1. 26 | y_true_f = K.flatten(y_true) 27 | y_pred_f = K.flatten(y_pred) 28 | intersection = K.sum(y_true_f * y_pred_f) 29 | return (1. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) 30 | 31 | def dice_coef_loss(y_true, y_pred): 32 | return -dice_coef(y_true, y_pred) 33 | 34 | def dice_coef_log_loss(y_true, y_pred): 35 | return -K.log(dice_coef(y_true, y_pred)) 36 | 37 | def iou_score_loss(y_true, y_pred): 38 | return -iou_score(y_true, y_pred) 39 | 40 | 41 | -------------------------------------------------------------------------------- /tnseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet import unet 2 | from .dilated_unet import dilated_unet 3 | from .dilated_densenet import dilated_densenet 4 | from .window_unet import window_unet 5 | -------------------------------------------------------------------------------- /tnseg/models/dilated_densenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | from keras.models import Model 4 | from keras.layers import Input, concatenate, Concatenate, Conv2D, Conv3D, MaxPooling2D, Conv2DTranspose, Dropout, \ 
5 | BatchNormalization, merge, UpSampling2D, Cropping2D, ZeroPadding2D, Reshape, core, Convolution2D, Activation, Lambda 6 | from keras import backend as K 7 | K.set_image_data_format('channels_last') # TF dimension ordering in this code 8 | 9 | def dilated_densenet1(height, width, channels, classes, features=12, depth=4, 10 | temperature=1.0, padding='same', batchnorm=False, 11 | dropout=0.0, dilation=dilation): 12 | dilation = tuple(map(int, dilation)) 13 | x = Input(shape=(height, width, channels)) 14 | inputs = x 15 | 16 | # initial convolution 17 | x = Conv2D(features, kernel_size=(5,5), padding=padding)(x) 18 | 19 | maps = [inputs] 20 | dilation_rate = 1 21 | kernel_size = (3,3) 22 | for n in range(depth): 23 | maps.append(x) 24 | x = Concatenate()(maps) 25 | x = BatchNormalization()(x) 26 | x = Activation('relu')(x) 27 | x = Conv2D(features, kernel_size, dilation_rate=dilation_rate, 28 | padding=padding)(x) 29 | dilation_rate *= 2 30 | 31 | probabilities = Conv2D(1, kernel_size=(1,1), activation='sigmoid')(x) 32 | 33 | model = Model(inputs=inputs, outputs=probabilities) 34 | return model 35 | 36 | def dilated_densenet2(height, width, channels, classes, features=12, depth=4, 37 | temperature=1.0, padding='same', batchnorm=False, 38 | dropout=0.0, dilation=dilation): 39 | dilation = tuple(map(int, dilation)) 40 | x = Input(shape=(height, width, channels)) 41 | inputs = x 42 | 43 | # initial convolution 44 | x = Conv2D(features, kernel_size=(5,5), padding=padding)(x) 45 | 46 | maps = [inputs] 47 | dilation_rate = 1 48 | kernel_size = (3,3) 49 | for n in range(depth): 50 | maps.append(x) 51 | x = Concatenate()(maps) 52 | 53 | x = BatchNormalization()(x) 54 | x = Activation('relu')(x) 55 | x = Conv2D(4*features, kernel_size=1)(x) 56 | 57 | x = BatchNormalization()(x) 58 | x = Activation('relu')(x) 59 | x = Conv2D(features, kernel_size, dilation_rate=dilation_rate, 60 | padding=padding)(x) 61 | dilation_rate *= 2 62 | 63 | probabilities = Conv2D(1, 
kernel_size=(1,1), activation='sigmoid')(x) 64 | 65 | model = Model(inputs=inputs, outputs=probabilities) 66 | return model 67 | 68 | def dilated_densenet(height, width, channels, classes, features=12, depth=4, 69 | temperature=1.0, padding='same', batchnorm=False, 70 | dropout=0.0, dilation=dilation): 71 | dilation = tuple(map(int, dilation)) 72 | x = Input(shape=(height, width, channels)) 73 | inputs = x 74 | 75 | # initial convolution 76 | x = Conv2D(features, kernel_size=(5,5), padding=padding)(x) 77 | 78 | maps = [inputs] 79 | dilation_rate = 1 80 | kernel_size = (3,3) 81 | for n in range(depth): 82 | maps.append(x) 83 | x = Concatenate()(maps) 84 | x = BatchNormalization()(x) 85 | x = Activation('relu')(x) 86 | x = Conv2D(features, kernel_size, dilation_rate=dilation_rate, 87 | padding=padding)(x) 88 | dilation_rate *= 2 89 | 90 | # Additional 2 layers to help generate segmentation mask 91 | x = Conv2D(features, kernel_size=(3,3), activation='relu', padding=padding)(x) 92 | x = Conv2D(features, kernel_size=(3,3), activation='relu', padding=padding)(x) 93 | 94 | probabilities = Conv2D(1, kernel_size=(1,1), activation='sigmoid')(x) 95 | 96 | model = Model(inputs=inputs, outputs=probabilities) 97 | return model 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /tnseg/models/dilated_unet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | from keras.models import Model 4 | from keras.layers import Input, concatenate, Concatenate, Conv2D, Conv3D, MaxPooling2D, Conv2DTranspose, Dropout, \ 5 | BatchNormalization, merge, UpSampling2D, Cropping2D, ZeroPadding2D, Reshape, core, Convolution2D, Activation, Lambda 6 | from keras import backend as K 7 | 8 | K.set_image_data_format('channels_last') # TF dimension ordering in this code 9 | 10 | # Dialted unet 11 | def downsampling_block(input_tensor, filters, 
padding='same', 12 | batchnorm=False, dropout=0.0): 13 | _, height, width, _ = K.int_shape(input_tensor) 14 | # assert height % 2 == 0 15 | # assert width % 2 == 0 16 | 17 | x = Conv2D(filters, kernel_size=(3,3), padding=padding, 18 | dilation_rate=1)(input_tensor) 19 | x = BatchNormalization()(x) if batchnorm else x 20 | x = Activation('relu')(x) 21 | x = Dropout(dropout)(x) if dropout > 0 else x 22 | 23 | x = Conv2D(filters, kernel_size=(3,3), padding=padding, dilation_rate=2)(x) 24 | x = BatchNormalization()(x) if batchnorm else x 25 | x = Activation('relu')(x) 26 | x = Dropout(dropout)(x) if dropout > 0 else x 27 | 28 | return MaxPooling2D(pool_size=(2,2))(x), x 29 | 30 | def upsampling_block(input_tensor, skip_tensor, filters, padding='same', 31 | batchnorm=False, dropout=0.0): 32 | x = Conv2DTranspose(filters, kernel_size=(2,2), strides=(2,2))(input_tensor) 33 | 34 | # compute amount of cropping needed for skip_tensor 35 | # _, x_height, x_width, _ = K.int_shape(x) 36 | # _, s_height, s_width, _ = K.int_shape(skip_tensor) 37 | h_crop = 0 # s_height - x_height 38 | w_crop = 0 # s_width - x_width 39 | # assert h_crop >= 0 40 | # assert w_crop >= 0 41 | if h_crop == 0 and w_crop == 0: 42 | y = skip_tensor 43 | else: 44 | cropping = ((h_crop//2, h_crop - h_crop//2), (w_crop//2, w_crop - w_crop//2)) 45 | y = Cropping2D(cropping=cropping)(skip_tensor) 46 | 47 | # x = Concatenate()([x, y]) 48 | x = concatenate([x, y], axis=3) 49 | 50 | # no dilation in upsampling convolutions 51 | x = Conv2D(filters, kernel_size=(3,3), padding=padding)(x) 52 | x = BatchNormalization()(x) if batchnorm else x 53 | x = Activation('relu')(x) 54 | x = Dropout(dropout)(x) if dropout > 0 else x 55 | 56 | x = Conv2D(filters, kernel_size=(3,3), padding=padding)(x) 57 | x = BatchNormalization()(x) if batchnorm else x 58 | x = Activation('relu')(x) 59 | x = Dropout(dropout)(x) if dropout > 0 else x 60 | 61 | return x 62 | 63 | def dilated_unet(height, width, channels, classes, features=64, 
depth=4, 64 | temperature=1.0, padding='same', batchnorm=False, 65 | dropout=0.0, dilation_layers=5, dilation=dilation): 66 | """Generate `dilated U-Net' model where the convolutions in the encoding and 67 | bottleneck are replaced by dilated convolutions. The second convolution in 68 | pair at a given scale in the encoder is dilated by 2. The number of 69 | dilation layers in the innermost bottleneck is controlled by the 70 | `dilation_layers' parameter -- this is the `context module' proposed by Yu, 71 | Koltun 2016 in "Multi-scale Context Aggregation by Dilated Convolutions" 72 | 73 | Arbitrary number of input channels and output classes are supported. 74 | 75 | Arguments: 76 | height - input image height (pixels) 77 | width - input image width (pixels) 78 | channels - input image features (1 for grayscale, 3 for RGB) 79 | classes - number of output classes (2 in paper) 80 | features - number of output features for first convolution (64 in paper) 81 | Number of features double after each down sampling block 82 | depth - number of downsampling operations (4 in paper) 83 | padding - 'valid' (used in paper) or 'same' 84 | batchnorm - include batch normalization layers before activations 85 | dropout - fraction of units to dropout, 0 to keep all units 86 | dilation_layers - number of dilated convolutions in innermost bottleneck 87 | 88 | Output: 89 | Dilated U-Net model expecting input shape (height, width, maps) and 90 | generates output with shape (output_height, output_width, classes). 91 | If padding is 'same', then output_height = height and 92 | output_width = width. 
93 | 94 | """ 95 | dilation = tuple(map(int, dilation)) 96 | x = Input(shape=(height, width, channels)) 97 | inputs = x 98 | 99 | skips = [] 100 | for i in range(depth): 101 | x, x0 = downsampling_block(x, features, padding, 102 | batchnorm, dropout) 103 | skips.append(x0) 104 | features *= 2 105 | 106 | dilation_rate = 1 107 | for n in range(dilation_layers): 108 | x = Conv2D(filters=features, kernel_size=(3,3), padding=padding, 109 | dilation_rate=dilation_rate)(x) 110 | x = BatchNormalization()(x) if batchnorm else x 111 | x = Activation('relu')(x) 112 | x = Dropout(dropout)(x) if dropout > 0 else x 113 | dilation_rate *= 2 114 | 115 | for i in reversed(range(depth)): 116 | features //= 2 117 | x = upsampling_block(x, skips[i], features, padding, 118 | batchnorm, dropout) 119 | 120 | x = Conv2D(filters=1, kernel_size=(1,1))(x) 121 | 122 | logits = Lambda(lambda z: z/temperature)(x) 123 | probabilities = Activation('sigmoid')(logits) 124 | 125 | model = Model(inputs=inputs, outputs=probabilities) 126 | return model 127 | 128 | 129 | -------------------------------------------------------------------------------- /tnseg/models/unet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | from keras.models import Model 4 | from keras.layers import Input, concatenate, Concatenate, Conv2D, Conv3D, MaxPooling2D, Conv2DTranspose, Dropout, \ 5 | BatchNormalization, merge, UpSampling2D, Cropping2D, ZeroPadding2D, Reshape, core, Convolution2D, Activation, Lambda 6 | from keras import backend as K 7 | 8 | K.set_image_data_format('channels_last') # TF dimension ordering in this code 9 | 10 | def get_crop_shape(target, refer): 11 | # width, the 3rd dimension 12 | cw = (target.get_shape()[2] - refer.get_shape()[2]).value 13 | assert (cw >= 0) 14 | if cw % 2 != 0: 15 | cw1, cw2 = int(cw/2), int(cw/2) + 1 16 | else: 17 | cw1, cw2 = int(cw/2), int(cw/2) 18 | # height, the 2nd 
dimension 19 | ch = (target.get_shape()[1] - refer.get_shape()[1]).value 20 | assert (ch >= 0) 21 | if ch % 2 != 0: 22 | ch1, ch2 = int(ch/2), int(ch/2) + 1 23 | else: 24 | ch1, ch2 = int(ch/2), int(ch/2) 25 | 26 | return (ch1, ch2), (cw1, cw2) 27 | 28 | def conv_block(x, n_filt=64, padding='same', dropout=0.0, batchnorm=False, dilation=(1,1), pool=True): # Two (3x3 Conv2D -> ReLU -> optional BatchNorm -> optional Dropout) layers, then an optional 2x2 max-pool; returns (conv, pooled), where pooled is conv itself when pool is False 29 | def conv_l(inp): 30 | 31 | conv = Conv2D(n_filt, (3, 3), padding=padding, dilation_rate=dilation)(inp) 32 | conv = Activation('relu')(conv) 33 | conv = BatchNormalization()(conv) if batchnorm else conv 34 | conv = Dropout(dropout)(conv) if dropout>0.0 else conv 35 | return conv 36 | 37 | conv = conv_l(x) 38 | conv = conv_l(conv) 39 | pool = MaxPooling2D(pool_size=(2, 2))(conv) if pool else conv 40 | return conv,pool 41 | 42 | def upconv_block(x, x_conv, n_filt, padding='same', dropout=0.0, batchnorm=False): # 2x2 stride-2 Conv2DTranspose, concatenated with the skip tensor x_conv (cropped first when padding='valid'), then two 3x3 Conv2D -> ReLU -> optional BN/Dropout layers 43 | #up_conv = UpSampling2D(size=(2, 2), data_format="channels_last")(x) 44 | up_conv = Conv2DTranspose(n_filt, (2, 2), strides=(2, 2), padding=padding)(x) 45 | # crop x_conv 46 | if padding=='valid': 47 | ch, cw = get_crop_shape(x_conv, up_conv) 48 | x_conv = Cropping2D(cropping=(ch,cw), data_format="channels_last")(x_conv) 49 | up = concatenate([up_conv, x_conv], axis=3) 50 | 51 | conv = Conv2D(n_filt, (3, 3), padding=padding, dilation_rate=(1,1))(up) 52 | conv = Activation('relu')(conv) 53 | conv = BatchNormalization()(conv) if batchnorm else conv 54 | conv = Dropout(dropout)(conv) if dropout>0.0 else conv 55 | 56 | conv = Conv2D(n_filt, (3, 3), padding=padding, dilation_rate=(1,1))(conv) 57 | conv = Activation('relu')(conv) 58 | conv = BatchNormalization()(conv) if batchnorm else conv 59 | conv = Dropout(dropout)(conv) if dropout>0.0 else conv 60 | return conv 61 | 62 | def unet(height=None, width=None, channels=1, features=32, 63 | depth=4, padding='same', temperature=1.0, 64 | batchnorm=False, dropout=0.0, dilation=(1,1)): # NOTE(review): `depth` and `temperature` are accepted but never used; the network is hard-wired to four pooling levels with a sigmoid output 65 | 66 | # Define the input 67 | dilation = tuple(map(int, dilation)) 68 | 
inputs = Input((height, width, channels)) 69 | pool = True 70 | 71 | # Contracting path 72 | conv1, pool1 = conv_block(inputs, features , padding, dropout, batchnorm, dilation, pool) 73 | conv2, pool2 = conv_block(pool1 , features*2 , padding, dropout, batchnorm, dilation, pool) 74 | conv3, pool3 = conv_block(pool2 , features*4 , padding, dropout, batchnorm, dilation, pool) 75 | conv4, pool4 = conv_block(pool3 , features*8 , padding, dropout, batchnorm, dilation, pool) 76 | conv5, _ = conv_block(pool4 , features*16, padding, dropout, batchnorm, dilation, pool) # bottleneck; its pooled output is discarded 77 | 78 | # Expanding path 79 | conv6 = upconv_block(conv5, conv4, features*8, padding, dropout, batchnorm) 80 | conv7 = upconv_block(conv6, conv3, features*4, padding, dropout, batchnorm) 81 | conv8 = upconv_block(conv7, conv2, features*2, padding, dropout, batchnorm) 82 | conv9 = upconv_block(conv8, conv1, features*1, padding, dropout, batchnorm) 83 | conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9) 84 | 85 | return Model(inputs=[inputs], outputs=[conv10]) 86 | 87 | 88 | -------------------------------------------------------------------------------- /tnseg/models/window_unet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | from keras.models import Model 4 | from keras.layers import Input, concatenate, Concatenate, Conv2D, Conv3D, MaxPooling2D, Conv2DTranspose, Dropout, \ 5 | BatchNormalization, merge, UpSampling2D, Cropping2D, ZeroPadding2D, Reshape, core, Convolution2D, Activation, Lambda 6 | from keras import backend as K 7 | 8 | K.set_image_data_format('channels_last') # TF dimension ordering in this code 9 | 10 | def conv_block(x, n_filt=64, padding='same', dropout=0.0, batchnorm=False, dilation=(1,1), pool=True): # same layers as conv_block in tnseg/models/unet.py 11 | def conv_l(inp): 12 | 13 | conv = Conv2D(n_filt, (3, 3), padding=padding, dilation_rate=dilation)(inp) 14 | conv = Activation('relu')(conv) 15 | conv = 
BatchNormalization()(conv) if batchnorm else conv 16 | conv = Dropout(dropout)(conv) if dropout>0.0 else conv 17 | return conv 18 | 19 | conv = conv_l(x) 20 | conv = conv_l(conv) 21 | pool = MaxPooling2D(pool_size=(2, 2))(conv) if pool else conv 22 | return conv,pool 23 | 24 | def upconv_block(x, x_conv, n_filt, padding='same', dropout=0.0, batchnorm=False): 25 | #up_conv = UpSampling2D(size=(2, 2), data_format="channels_last")(x) 26 | up_conv = Conv2DTranspose(n_filt, (2, 2), strides=(2, 2), padding=padding)(x) 27 | # crop x_conv 28 | if padding=='valid': 29 | ch, cw = get_crop_shape(x_conv, up_conv) # NOTE(review): get_crop_shape is neither defined nor imported in window_unet.py, so padding='valid' raises NameError here 30 | x_conv = Cropping2D(cropping=(ch,cw), data_format="channels_last")(x_conv) 31 | up = concatenate([up_conv, x_conv], axis=3) 32 | 33 | conv = Conv2D(n_filt, (3, 3), padding=padding, dilation_rate=(1,1))(up) 34 | conv = Activation('relu')(conv) 35 | conv = BatchNormalization()(conv) if batchnorm else conv 36 | conv = Dropout(dropout)(conv) if dropout>0.0 else conv 37 | 38 | conv = Conv2D(n_filt, (3, 3), padding=padding, dilation_rate=(1,1))(conv) 39 | conv = Activation('relu')(conv) 40 | conv = BatchNormalization()(conv) if batchnorm else conv 41 | conv = Dropout(dropout)(conv) if dropout>0.0 else conv 42 | return conv 43 | 44 | def conv_block_window(x, n_filt=64, padding='same', dropout=0.0, batchnorm=False, dilation=(1,1), pool=True, window_size=3): # Only the first Conv3D (3x3xn_window) layer is active; the remaining layers are commented out and 0 is returned in place of the pooled output 45 | def conv_l(inp): 46 | conv = Conv2D(n_filt, (3, 3), padding=padding, dilation_rate=dilation)(inp) 47 | conv = Activation('relu')(conv) 48 | conv = BatchNormalization()(conv) if batchnorm else conv 49 | conv = Dropout(dropout)(conv) if dropout>0.0 else conv 50 | return conv 51 | 52 | def conv_l_window(inp,n_window): 53 | conv = Conv3D(n_filt, (3, 3, n_window), padding=padding, dilation_rate=1)(inp) 54 | conv = Activation('relu')(conv) 55 | conv = BatchNormalization()(conv) if batchnorm else conv 56 | conv = Dropout(dropout)(conv) if dropout>0.0 else conv 57 | return conv 58 | 59 | conv = 
conv_l_window(x,window_size) 60 | # conv = Lambda(lambda y: K.squeeze(y, axis=3))(conv) 61 | # conv = conv_l(conv) 62 | # pool = MaxPooling2D(pool_size=(2, 2))(conv) if pool else conv 63 | return conv,0#pool 64 | 65 | # Window UNet implementation using 3d kernel on 1st layer 66 | def unet_window_3d(img_rows, img_cols, init_filt=32, padding='same', dropout=0.0, batchnorm=False, dilation=(1,1), pool=True, window_size=3): 67 | 68 | # Define the input 69 | inputs = Input((320,448, window_size, 1)) # NOTE(review): input size is fixed at 320x448; img_rows and img_cols are ignored 70 | 71 | # Contracting path 72 | conv1, pool1 = conv_block_window(inputs, init_filt , padding, dropout, batchnorm, dilation, pool, window_size) 73 | # conv2, pool2 = conv_block(pool1 , init_filt*2 , padding, dropout, batchnorm, dilation, pool) 74 | # conv3, pool3 = conv_block(pool2 , init_filt*4 , padding, dropout, batchnorm, dilation, pool) 75 | # conv4, pool4 = conv_block(pool3 , init_filt*8 , padding, dropout, batchnorm, dilation, pool) 76 | # conv5, _ = conv_block(pool4 , init_filt*16, padding, dropout, batchnorm, dilation, pool) 77 | # 78 | # # Expanding path 79 | # conv6 = upconv_block(conv5, conv4, init_filt*8, padding, dropout, batchnorm) 80 | # conv7 = upconv_block(conv6, conv3, init_filt*4, padding, dropout, batchnorm) 81 | # conv8 = upconv_block(conv7, conv2, init_filt*2, padding, dropout, batchnorm) 82 | # conv9 = upconv_block(conv8, conv1, init_filt*1, padding, dropout, batchnorm) 83 | # conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9) 84 | 85 | model = Model(inputs=[inputs], outputs=[conv1]) 86 | 87 | model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef]) # NOTE(review): Adam, dice_coef_loss and dice_coef are not imported in window_unet.py, so calling unet_window_3d raises NameError 88 | #model.compile(optimizer=Adam(lr=1e-5), loss=sparse_categorical_crossentropy, metrics=[dice_coef]) 89 | 90 | return model 91 | 92 | # Window UNet implementation 2D (use 2D kernel and sum conv outputs of frames to get feature maps) 93 | def window_unet(height=None, width=None, features=32, padding='same', dropout=0.0, batchnorm=False, dilation=(1,1), pool=True,
window_size=3, dilation=dilation): 94 | 95 | # Define the input 96 | dilation = tuple(map(int, dilation)) 97 | inputs = Input((None, None, window_size)) 98 | 99 | # Contracting path 100 | conv1, pool1 = conv_block(inputs, features , padding, dropout, batchnorm, dilation, pool) 101 | conv2, pool2 = conv_block(pool1 , features*2 , padding, dropout, batchnorm, dilation, pool) 102 | conv3, pool3 = conv_block(pool2 , features*4 , padding, dropout, batchnorm, dilation, pool) 103 | conv4, pool4 = conv_block(pool3 , features*8 , padding, dropout, batchnorm, dilation, pool) 104 | conv5, _ = conv_block(pool4 , features*16, padding, dropout, batchnorm, dilation, pool) 105 | 106 | # Expanding path 107 | conv6 = upconv_block(conv5, conv4, features*8, padding, dropout, batchnorm) 108 | conv7 = upconv_block(conv6, conv3, features*4, padding, dropout, batchnorm) 109 | conv8 = upconv_block(conv7, conv2, features*2, padding, dropout, batchnorm) 110 | conv9 = upconv_block(conv8, conv1, features*1, padding, dropout, batchnorm) 111 | conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9) 112 | 113 | model = Model(inputs=[inputs], outputs=[conv10]) 114 | 115 | return model 116 | 117 | 118 | -------------------------------------------------------------------------------- /tnseg/opts.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | 3 | import os 4 | import argparse 5 | import configparser 6 | import logging 7 | 8 | import pdb 9 | 10 | definitions = [ 11 | # model type default help 12 | ('model', (str, 'unet', "Model: unet, dilated-unet, dilated-densenet, window-unet, masktrack-unet")), 13 | ('features', (int, 32, "Number of features maps after first convolutional layer.")), 14 | ('depth', (int, 4, "Number of downsampled convolutional blocks.")), 15 | ('temperature', (float, 1.0, "Temperature of final softmax layer in model.")), 16 | ('padding', (str, 'same', "Padding in convolutional layers. 
Either `same' or `valid'.")), 17 | ('dropout', (float, 0.0, "Rate for dropout of activation units.")), 18 | # ('classes', (str, 'inner', "One of `inner' (endocardium), `outer' (epicardium), or `both'.")), 19 | ('batchnorm', {'default': False, 'action': 'store_true', 20 | 'help': "Apply batch normalization before nonlinearities."}), 21 | ('window', (int, 0, "Window size for Window UNet")), 22 | ('dilation', (int, [1,1], "Dilation parameter for the encoder architecture")), 23 | 24 | # data generator properties 25 | ('zero_padding', {'type': int, 'nargs': '+', 'default': [0,0], 26 | 'help': "zero_padding: None, [320, 448]. If None, zero padding is applied to each batch"}), 27 | ('data_skew', {'default': False, 'action': 'store_true', 28 | 'help': "Skew the probabilities of the batch samples in the ratio of percentage of TRUE pixels"}), 29 | 30 | # loss 31 | ('loss', (str, 'dice', "Loss function: `pixel' for pixel-wise cross entropy, `dice' for dice coefficient.")), 32 | ('loss-weights', {'type': float, 'nargs': '+', 'default': [0.1, 0.9], 33 | 'help': "When using dice or jaccard loss, how much to weight each output class."}), 34 | 35 | # training 36 | ('epochs', (int, 20, "Number of epochs to train.")), 37 | ('batch-size', (int, 32, "Mini-batch size for training.")), 38 | ('validation-split', (float, 0.2, "Percentage of training data to hold out for validation.")), 39 | ('optimizer', (str, 'adam', "Optimizer: sgd, rmsprop, adagrad, adadelta, adam, adamax, or nadam.")), 40 | ('learning-rate', (float, 1e-5, "Optimizer learning rate.")), 41 | ('momentum', (float, None, "Momentum for SGD optimizer.")), 42 | ('decay', (float, None, "Learning rate decay (not applicable for nadam).")), 43 | ('shuffle_train_val', {'default': False, 'action': 'store_true', 44 | 'help': "Shuffle images before splitting into train vs. 
val."}), 45 | ('shuffle', {'default': False, 'action': 'store_true', 46 | 'help': "Shuffle images before each training epoch."}), 47 | ('seed', (int, None, "Seed for numpy RandomState")), 48 | ('train_steps_per_epoch', (int, 20, "Number of train steps in one epoch.")), 49 | ('val_steps_per_epoch', (int, 8, "Number of validation steps in one epoch.")), 50 | 51 | # files 52 | ('datadir', (str, '../data/', "Directory containing patientXX/ directories.")), 53 | ('outdir', (str, '../output/', "Directory to write output data.")), 54 | # ('outfile', (str, 'weights-final.hdf5', "File to write final model weights.")), 55 | ('load-weights', (str, '', "Load model weights from specified file to initialize training.")), 56 | ('checkpoint', {'default': False, 'action': 'store_true', 57 | 'help': "Write model weights after each epoch if validation accuracy improves."}), 58 | ('ckpt_period', (int, 10, "Period of epochs after which weights are saved")), 59 | 60 | # augmentation 61 | ('data-augment', {'default': False, 'action': 'store_true', 62 | 'help': "Whether to apply image augmentation to training set."}), 63 | # ('augment-training', {'default': False, 'action': 'store_true', 64 | # 'help': "Whether to apply image augmentation to training set."}), 65 | # ('augment-validation', {'default': False, 'action': 'store_true', 66 | # 'help': "Whether to apply image augmentation to validation set."}), 67 | ('rotation-range', (float, 180, "Rotation range (0-180 degrees)")), 68 | ('width-shift-range', (float, 0.1, "Width shift range, as a float fraction of the width")), 69 | ('height-shift-range', (float, 0.1, "Height shift range, as a float fraction of the height")), 70 | ('shear-range', (float, 0.1, "Shear intensity (in radians)")), 71 | ('zoom-range', (float, 0.05, "Amount of zoom. If a scalar z, zoom in [1-z, 1+z]. 
Can also pass a pair of floats as the zoom range.")), 72 | ('fill-mode', (str, 'nearest', "Points outside boundaries are filled according to mode: constant, nearest, reflect, or wrap")), 73 | ('alpha', (float, 500, "Random elastic distortion: magnitude of distortion")), 74 | ('sigma', (float, 20, "Random elastic distortion: length scale")), 75 | ('normalize', {'default': False, 'action': 'store_true', 76 | 'help': "Subtract mean and divide by std dev from each image."}), 77 | ] 78 | 79 | noninitialized = { 80 | 'learning_rate': 'getfloat', 81 | 'momentum': 'getfloat', 82 | 'decay': 'getfloat', 83 | 'seed': 'getint', 84 | } 85 | 86 | def update_from_configfile(args, default, config, section, key): 87 | # Point of this function is to update the args Namespace. 88 | value = config.get(section, key) 89 | if value == '' or value is None: 90 | return 91 | 92 | # Command-line arguments override config file values 93 | if getattr(args, key) != default: 94 | return 95 | 96 | # Config files always store values as strings -- get correct type 97 | if isinstance(default, bool): 98 | value = config.getboolean(section, key) 99 | elif isinstance(default, int): 100 | value = config.getint(section, key) 101 | elif isinstance(default, float): 102 | value = config.getfloat(section, key) 103 | elif isinstance(default, str): 104 | value = config.get(section, key) 105 | elif isinstance(default, list): 106 | # special case (HACK): loss-weights is list of floats 107 | string = config.get(section, key) 108 | print(string, string.split()) 109 | value = [float(x) for x in string.split()] 110 | elif default is None: 111 | # values which aren't initialized 112 | getter = getattr(config, noninitialized[key]) 113 | value = getter(section, key) 114 | setattr(args, key, value) 115 | 116 | def parse_arguments(): 117 | parser = argparse.ArgumentParser( 118 | description="Train U-Net to segment thyroid nodules from ultrasound images") 119 | 120 | for argname, kwargs in definitions: 121 | d = kwargs 
122 | if isinstance(kwargs, tuple): 123 | d = dict(zip(['type', 'default', 'help'], kwargs)) 124 | parser.add_argument('--' + argname, **d) 125 | 126 | # allow user to input configuration file 127 | parser.add_argument( 128 | 'configfile', nargs='?', type=str, help="Load options from config " 129 | "file (command line arguments take precedence).") 130 | 131 | args = parser.parse_args() 132 | 133 | if args.configfile: 134 | logging.info("Loading options from config file: {}".format(args.configfile)) 135 | config = configparser.ConfigParser( 136 | inline_comment_prefixes=['#', ';'], allow_no_value=True) 137 | config.read(args.configfile) 138 | for section in config: 139 | for key in config[section]: 140 | if key not in args: 141 | raise Exception("Unknown option {} in config file.".format(key)) 142 | update_from_configfile(args, parser.get_default(key), 143 | config, section, key) 144 | 145 | for k,v in vars(args).items(): 146 | logging.info("{:20s} = {}".format(k, v)) 147 | 148 | return args 149 | 150 | -------------------------------------------------------------------------------- /tnseg/ufarray.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math, random 3 | from itertools import product 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | def get_max_blob(img): 8 | img = np.invert(img.astype(int)) 9 | labels = run(img) 10 | max_val = max(labels.values(), key=labels.values().count) 11 | img_new = np.zeros(img.shape) 12 | for pixel in labels.keys(): 13 | if labels[pixel]==max_val: 14 | img_new[pixel] = 255 15 | return img_new 16 | 17 | 18 | 19 | # Array based union find data structure 20 | 21 | # P: The array, which encodes the set membership of all the elements 22 | 23 | class UFarray: 24 | def __init__(self): 25 | # Array which holds label -> set equivalences 26 | self.P = [] 27 | 28 | # Name of the next label, when one is created 29 | self.label = 0 30 | 31 | def makeLabel(self): 32 | 
r = self.label 33 | self.label += 1 34 | self.P.append(r) 35 | return r 36 | 37 | # Makes all nodes "in the path of node i" point to root 38 | def setRoot(self, i, root): 39 | while self.P[i] < i: 40 | j = self.P[i] 41 | self.P[i] = root 42 | i = j 43 | self.P[i] = root 44 | 45 | # Finds the root node of the tree containing node i 46 | def findRoot(self, i): 47 | while self.P[i] < i: 48 | i = self.P[i] 49 | return i 50 | 51 | # Finds the root of the tree containing node i 52 | # Simultaneously compresses the tree 53 | def find(self, i): 54 | root = self.findRoot(i) 55 | self.setRoot(i, root) 56 | return root 57 | 58 | # Joins the two trees containing nodes i and j 59 | # Modified to be less agressive about compressing paths 60 | # because performance was suffering some from over-compression 61 | def union(self, i, j): 62 | if i != j: 63 | root = self.findRoot(i) 64 | rootj = self.findRoot(j) 65 | if root > rootj: root = rootj 66 | self.setRoot(j, root) 67 | self.setRoot(i, root) 68 | 69 | def flatten(self): 70 | for i in range(1, len(self.P)): 71 | self.P[i] = self.P[self.P[i]] 72 | 73 | def flattenL(self): 74 | k = 1 75 | for i in range(1, len(self.P)): 76 | if self.P[i] < i: 77 | self.P[i] = self.P[self.P[i]] 78 | else: 79 | self.P[i] = k 80 | k += 1 81 | 82 | def run(img): 83 | data = img.copy() 84 | width, height = img.shape 85 | 86 | # Union find data structure 87 | uf = UFarray() 88 | 89 | # 90 | # First pass 91 | # 92 | 93 | # Dictionary of point:label pairs 94 | labels = {} 95 | 96 | for y, x in product(range(height), range(width)): 97 | 98 | # 99 | # Pixel names were chosen as shown: 100 | # 101 | # ------------- 102 | # | a | b | c | 103 | # ------------- 104 | # | d | e | | 105 | # ------------- 106 | # | | | | 107 | # ------------- 108 | # 109 | # The current pixel is e 110 | # a, b, c, and d are its neighbors of interest 111 | # 112 | # 255 is white, 0 is black 113 | # White pixels part of the background, so they are ignored 114 | # If a pixel lies 
outside the bounds of the image, it default to white 115 | # 116 | 117 | # If the current pixel is white, it's obviously not a component... 118 | if data[x, y] == 255: 119 | pass 120 | 121 | # If pixel b is in the image and black: 122 | # a, d, and c are its neighbors, so they are all part of the same component 123 | # Therefore, there is no reason to check their labels 124 | # so simply assign b's label to e 125 | elif y > 0 and data[x, y-1] == 0: 126 | labels[x, y] = labels[(x, y-1)] 127 | 128 | # If pixel c is in the image and black: 129 | # b is its neighbor, but a and d are not 130 | # Therefore, we must check a and d's labels 131 | elif x+1 < width and y > 0 and data[x+1, y-1] == 0: 132 | 133 | c = labels[(x+1, y-1)] 134 | labels[x, y] = c 135 | 136 | # If pixel a is in the image and black: 137 | # Then a and c are connected through e 138 | # Therefore, we must union their sets 139 | if x > 0 and data[x-1, y-1] == 0: 140 | a = labels[(x-1, y-1)] 141 | uf.union(c, a) 142 | 143 | # If pixel d is in the image and black: 144 | # Then d and c are connected through e 145 | # Therefore we must union their sets 146 | elif x > 0 and data[x-1, y] == 0: 147 | d = labels[(x-1, y)] 148 | uf.union(c, d) 149 | 150 | # If pixel a is in the image and black: 151 | # We already know b and c are white 152 | # d is a's neighbor, so they already have the same label 153 | # So simply assign a's label to e 154 | elif x > 0 and y > 0 and data[x-1, y-1] == 0: 155 | labels[x, y] = labels[(x-1, y-1)] 156 | 157 | # If pixel d is in the image and black 158 | # We already know a, b, and c are white 159 | # so simpy assign d's label to e 160 | elif x > 0 and data[x-1, y] == 0: 161 | labels[x, y] = labels[(x-1, y)] 162 | 163 | # All the neighboring pixels are white, 164 | # Therefore the current pixel is a new component 165 | else: 166 | labels[x, y] = uf.makeLabel() 167 | return labels 168 | 169 | 170 | --------------------------------------------------------------------------------