├── Augmentations
│   ├── augmentations.ipynb
│   ├── augmentations.py
│   ├── data_split_save.py
│   ├── rgb2label.py
│   └── utils.py
├── Images
│   ├── architecture.jpg
│   ├── confusion_matrix.png
│   └── predictions.png
├── PostProcessing.py
├── Poster.pdf
├── README.md
├── SimpleSegNet.py
├── Testing prediction on CPU.ipynb
├── Training demo.ipynb
├── Training
│   ├── TrainingClass.py
│   ├── create_h5.py
│   ├── generate_parameters.py
│   └── train.py
├── UNet.py
├── XNet.py
├── _config.yml
└── requirements.txt

/Augmentations/augmentations.py:
--------------------------------------------------------------------------------
 1 | from __future__ import division
 2 | from sklearn.model_selection import train_test_split
 3 | import os, sys
 4 | import numpy as np
 5 | import matplotlib.pyplot as plt
 6 | from IPython.display import clear_output
 7 | import imgaug as ia
 8 | from imgaug import augmenters as iaa
 9 | from imgaug import parameters as iap
10 | import create_h5
11 | import cv2
12 | import glob
13 | from random import shuffle
14 | import h5py
15 | import argparse
16 | from keras.utils import to_categorical
17 | import random
18 | from keras.preprocessing.image import ImageDataGenerator
19 | from utils import random_crop
20 | from utils import shuffle_together
21 | from utils import balanced_test_val_split
22 | 
23 | import time
24 | 
25 | # ******************* PARAMETERS *************************#
26 | main_path = "Data"
27 | data_to_add = ['Humans','CT','Phantom']
28 | hdf5_path = "final"
29 | 
30 | EXAMPLES_PER_CATEGORY = 500
31 | image_size = 200
32 | train_size = 0.7
33 | n_classes = 3
34 | 
35 | 
36 | # output hdf5 file
37 | hdf5_name = '_'.join(data_to_add)
38 | 
39 | if(EXAMPLES_PER_CATEGORY == 0):
40 |     hdf5_name = hdf5_name + '_s' + str(image_size) + '.hdf5'
41 | 
42 | else:
43 |     hdf5_name = hdf5_name + '_s' + str(image_size) + '_a' + str(EXAMPLES_PER_CATEGORY) + '.hdf5'
44 | 
45 | 
46 | # ******************* TRAIN/TEST/VAL **********************#
47 | 
48 | # Get balanced body parts split into train, test and validation sets
49 | images_train, labels_train, body_train, filenames_train, images_test, labels_test, body_test, \
50 | filenames_test, images_val, labels_val, body_val, filenames_val = \
51 |     balanced_test_val_split(main_path, data_to_add, image_size, train_size, n_classes)
52 | 
53 | # ******************* AUGMENTATIONS **********************#
54 | 
55 | # Find the number of augmentations per image needed to obtain a balanced training set
56 | unique, counts = np.unique(body_train, return_counts=True)
57 | unique_per_category = dict(zip(unique, counts))
58 | augmentations_per_category = dict(unique_per_category)
59 | for key in unique_per_category:
60 |     augmentations_per_category[key] = int(EXAMPLES_PER_CATEGORY/unique_per_category[key])
61 | 
62 | # Augmentation templates
63 | translate_max = 0.01
64 | rotate_max = 15
65 | shear_max = 2
66 | 
67 | affine_transform = iaa.Affine( translate_percent={"x": (-translate_max, translate_max),
68 |                                "y": (-translate_max, translate_max)}, # translate by +-translate_max
69 |                                rotate=(-rotate_max, rotate_max), # rotate by -rotate_max to +rotate_max degrees
70 |                                shear=(-shear_max, shear_max), # shear by -shear_max to +shear_max degrees
71 |                                order=[1], # use bilinear interpolation (fast)
72 |                                cval=125, # if mode is constant, use a cval between 0 and 255
73 |                                mode="reflect",
74 |                                #mode = "",
75 |                                name="Affine",
76 |                                )
77 | 
78 | 
79 | spatial_aug = iaa.Sequential([iaa.Fliplr(0.5), iaa.Flipud(0.5), affine_transform])
80 | 
81 | other_aug = iaa.SomeOf((1, None),
82 |                        [
83 |                        iaa.OneOf([
84 | 
iaa.GaussianBlur((0, 0.4)), # blur images with a sigma between 0 and 1.0 85 | iaa.ElasticTransformation(alpha=(0.5, 1.5), sigma=0.25), # very few 86 | 87 | ]), 88 | 89 | ]) 90 | 91 | 92 | 93 | augmentator = [spatial_aug,other_aug] 94 | total_images=sum(augmentations_per_category[k]*unique_per_category[k] + unique_per_category[k] for k in augmentations_per_category) 95 | images_aug = np.zeros((total_images,images_train.shape[1],images_train.shape[2],images_train.shape[3])) 96 | labels_aug = np.zeros((total_images,labels_train.shape[1],labels_train.shape[2],labels_train.shape[3])) 97 | bodypart = np.empty((total_images),dtype = 'S10') 98 | filenames_aug = np.empty((total_images),dtype = 'S60') 99 | 100 | images_aug[:images_train.shape[0],...] = images_train 101 | labels_aug[:images_train.shape[0],...] = labels_train/255 102 | bodypart[:images_train.shape[0],...] = body_train 103 | filenames_aug[:images_train.shape[0],...] = filenames_train 104 | 105 | # Loop over the different bodyparts 106 | counter = images_train.shape[0] 107 | counter_block = 0 108 | for i, (k, v) in enumerate(augmentations_per_category.items()): 109 | # Indices of images with a given bodypart 110 | indices = np.array(np.where(body_train == k )[0]) 111 | # Number of augmentation per image 112 | N = int(v) 113 | 114 | for j in indices: 115 | for l in range(N): 116 | clear_output(wait=True) 117 | # Freeze randomization to apply same to labels 118 | spatial_det = augmentator[0].to_deterministic() 119 | other_det = augmentator[1] 120 | 121 | images_aug[counter,...] = spatial_det.augment_image(images_train[j]) 122 | 123 | labels_aug[counter,...] = spatial_det.augment_image(labels_train[j]) 124 | img_crop, label_crop = random_crop(images_aug[counter,...],labels_aug[counter,...],0.1,0.4) 125 | images_aug[counter,...] = other_det.augment_image(img_crop ) 126 | labels_aug[counter,...] = to_categorical(np.argmax(label_crop,axis=-1)) 127 | 128 | bodypart[counter] = k 129 | 130 | # Save names of the augmented images starting with aug_ 131 | filenames_aug[counter] = b'aug_' + filenames_train[j] 132 | sys.stdout.write('(Category %s) processing image %i/%i, augmented image %i/%i'%(k,counter_block, 133 | body_train.shape[0], 134 | l+1, N)) 135 | sys.stdout.flush() 136 | time.sleep(0.5) 137 | counter +=1 138 | counter_block +=1 139 | 140 | images_aug, labels_aug, bodypart, filenames_aug = shuffle_together(images_aug, labels_aug, bodypart, filenames_aug) 141 | 142 | images_test, labels_test, body_test, filenames_test = shuffle_together(images_test, labels_test, body_test, filenames_test) 143 | 144 | images_val, labels_val, body_val, filenames_val = shuffle_together(images_val, labels_val, body_val, filenames_val) 145 | 146 | print('Finished playing with cadavers ! 
') 147 | 148 | create_h5.write_h5(hdf5_path + hdf5_name, images_aug, labels_aug, bodypart,filenames_aug, images_test, labels_test/255,body_test,filenames_test,\ 149 | images_val, labels_val/255,body_val ,filenames_val) 150 | 151 | 152 | print('Saving the hdf5 at %s ...'%hdf5_name) 153 | -------------------------------------------------------------------------------- /Augmentations/data_split_save.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from sklearn.model_selection import train_test_split 3 | import os, sys 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from IPython.display import clear_output 7 | import imgaug as ia 8 | from imgaug import augmenters as iaa 9 | from imgaug import parameters as iap 10 | import create_h5 11 | import cv2 12 | import glob 13 | from random import shuffle 14 | import h5py 15 | import argparse 16 | from keras.utils import to_categorical 17 | import random 18 | from keras.preprocessing.image import ImageDataGenerator 19 | from utils import random_crop 20 | from utils import shuffle_together 21 | from utils import balanced_test_val_split 22 | import sys 23 | import time 24 | 25 | # ******************* PARAMETERS *************************# 26 | main_path = "Data" 27 | data_to_add = ['Humans','CT','Phantom'] 28 | 29 | image_size = 200 30 | train_size = 0.7 31 | n_classes = 3 32 | 33 | hdf5_path = "final" 34 | 35 | # output hdf5 file 36 | hdf5_name = '_'.join(data_to_add) 37 | 38 | hdf5_name = hdf5_name + '_s' + str(image_size) + '.hdf5' 39 | 40 | 41 | # ******************* TRAIN/TEST/VAL **********************# 42 | 43 | # Get balanced body parts split into train test and validation sets 44 | images_train, labels_train, body_train, filenames_train, images_test, labels_test, body_test, \ 45 | filenames_test, images_val, labels_val, body_val, filenames_val = \ 46 | balanced_test_val_split(main_path, data_to_add, image_size, train_size, n_classes) 47 | 48 | # Save hdf5 file without augmentations 49 | create_h5.write_h5(hdf5_name, images_train, labels_train/255, body_train,filenames_train, images_test, labels_test/255,body_test,filenames_test,\ 50 | images_val, labels_val/255,body_val ,filenames_val) 51 | 52 | -------------------------------------------------------------------------------- /Augmentations/rgb2label.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import rasterio 4 | import cv2 5 | import glob 6 | from keras.utils import to_categorical 7 | 8 | 9 | 10 | def get_mask_from_color( image, color ): 11 | """ Given one image and one color, returns a mask of the same shape as the image, with True values on the pixel positions with the same specified color""" 12 | rows, columns, channels = image.shape 13 | total_pixels = rows * columns 14 | image_flat = image.reshape(total_pixels, channels) 15 | color_array = np.array([color,] * total_pixels) 16 | channels_mask = np.isclose(image_flat, color_array, atol = 100) 17 | #combine channels 18 | mask = np.logical_and(channels_mask[:,0], channels_mask[:,1]) 19 | mask = np.logical_and(mask, channels_mask[:,2]) 20 | return mask.reshape(rows,columns) 21 | 22 | def get_012_label(image, n_colors = 3, colors = [[255,255,255], [255,255,0], [0,0,255]]): 23 | """ Given one image, returns labeling 0,1,2 for 3 colours.""" 24 | #color_0 = [255,255,255] 25 | #color_1 = [255,255,0] 26 | #color_2 = [0,0,255] 27 | 28 | label_012 = 
np.zeros((image.shape[0], image.shape[1])) 29 | 30 | if(n_colors == 2): 31 | mask = get_mask_from_color(image, colors[2]) 32 | label_012[mask] = 1 33 | 34 | elif(n_colors == 3): 35 | mask = get_mask_from_color(image, colors[1]) 36 | label_012[mask] = 1 37 | mask = get_mask_from_color(image, colors[2]) 38 | label_012[mask] = 2 39 | 40 | else: 41 | print("number of colors not implemented") 42 | return False 43 | 44 | return label_012 45 | 46 | def get_categorical_label(image, n_classes = 3): 47 | """ Given an image, computes the 012 label and uses keras to compute the categorical label""" 48 | label_012 = get_012_label(image, n_classes) 49 | return to_categorical(label_012, n_classes) 50 | -------------------------------------------------------------------------------- /Augmentations/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import rasterio 4 | import glob 5 | import cv2 6 | from random import shuffle 7 | import os 8 | import scipy.misc 9 | import rgb2label as gen_label 10 | 11 | 12 | # Utilities 13 | def balanced_test_val_split(main_path, data_to_add, image_size, train_size, n_classes): 14 | images_found = [] 15 | labels_found = [] 16 | for category in data_to_add: 17 | 18 | print('Checking labels and data match in %s folder ...'%category) 19 | data_path =os.path.join( main_path , 'Images' , category ) 20 | data_path += os.sep + '*.tif' 21 | 22 | labels_path = os.path.join(main_path, 'Labels', category) 23 | labels_path += os.sep + '*.jpg' 24 | 25 | images = glob.glob(data_path) 26 | labels = glob.glob(labels_path) 27 | assert len(labels) != 0 28 | #print('Checking if number of labeled files matches number of data image files....') 29 | # Check that number of labels corresponds to number of images 30 | 31 | assert len(labels) == len(images) 32 | 33 | # Check that they have the same names 34 | label_filename = [] 35 | img_filename = [] 36 | 37 | for (i, img) in enumerate(images): 38 | label_filename.append(labels[i].split(os.sep)[-1].split('.')[0].replace('onehot', '')) 39 | img_filename.append(img.split(os.sep)[-1].split('.')[0] ) 40 | 41 | 42 | label_filename = sorted(label_filename) 43 | img_filename = sorted(img_filename) 44 | 45 | for i in range(len(label_filename)): 46 | 47 | assert label_filename[i] == img_filename[i] 48 | images_found.append( os.path.join(main_path , 'Images' , category) + os.sep + img_filename[i] + '.tif') 49 | labels_found.append( os.path.join(main_path , 'Labels' , category) + os.sep + label_filename[i] + '.jpg') 50 | 51 | 52 | print('Names of labels and data in folder %s match perfectly, %d images found . '%(category, len(img_filename))) 53 | 54 | #shuffle images and labels 55 | c = list(zip(images_found,labels_found)) 56 | shuffle(c) 57 | images, labels = zip(*c) 58 | 59 | # Read and save all images + labels + bodypart 60 | images_read = np.zeros((len(images),image_size,image_size,1),dtype=np.float32) 61 | labels_read = np.zeros((len(labels), image_size, image_size,3),dtype=np.uint8) 62 | bodyparts = np.empty((len(images)),'S10') 63 | split_names = np.empty((len(images)),'S50') 64 | for i in range(len(images)): 65 | filename = images[i] 66 | img = rasterio.open(filename) 67 | img = img.read(1) 68 | images_read[i,...,0] = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_AREA) 69 | 70 | label_filename = labels[i] 71 | labels_images = cv2.imread(label_filename) 72 | 73 | labels_read[i,...] 
= scipy.misc.imresize(labels_images, (image_size,image_size,3), interp='nearest', mode=None) 74 | labels_read[i,...] = np.uint8(labels_read[i,...]) 75 | labels_read[i,...] = 255*gen_label.get_categorical_label(labels_read[i,...], n_classes) 76 | 77 | # Clean bodyparts names 78 | bodypart = filename.split(os.sep)[-1].split('_')[0].lower() 79 | split_names[i] = filename.split(os.sep)[-1].split('.')[0].lower() 80 | if((bodypart == 'left') or (bodypart == 'right') or (bodypart == 'asg')): 81 | bodypart = filename.split(os.sep)[-1].split('_')[1] 82 | if(bodypart == 'fractured'): 83 | bodypart = filename.split(os.sep)[-1].split('_')[2] 84 | if(bodypart == 'lower'): 85 | bodypart = filename.split(os.sep)[-1].split('_')[2] 86 | if((bodypart == 'belly') or (bodypart == 'plate')): 87 | bodypart = filename.split(os.sep)[-1].split('_')[1] 88 | if((bodypart == 'leg') and (filename.split(os.sep)[-1].split('_')[1] == 'lamb')): 89 | bodypart = filename.split(os.sep)[-1].split('_')[1] 90 | # Remove numbers 91 | bodypart = ''.join(i for i in bodypart if not i.isdigit()) 92 | if(bodypart == 'nof'): 93 | bodypart = 'neckoffemur' 94 | bodypart = bodypart.split('.')[0] 95 | if(bodypart == 'anke'): 96 | bodypart = 'ankle' 97 | 98 | if(bodypart == 'lumbar'): 99 | bodypart = 'lumbarspin' 100 | bodypart = bodypart.encode("ascii", "ignore") 101 | bodyparts[i] = bodypart 102 | 103 | 104 | unique, counts = np.unique(bodyparts, return_counts=True) 105 | unique_per_category = dict(zip(unique, counts)) 106 | 107 | #print('There are %d different bodyparts'%len(unique_per_category)) 108 | 109 | indices = np.arange(images_read.shape[0]) 110 | 111 | 112 | # Build balanced test and validation sets 113 | one_per_class = [] 114 | for i in unique_per_category: 115 | split_category = np.where(bodyparts==i)[0].tolist() 116 | #pick one from each category to be part of the test set 117 | chosen_one_per_class = random.choice(split_category) 118 | indices_to_remove = np.argwhere( indices ==chosen_one_per_class)[0].tolist() 119 | indices = np.delete(indices, indices_to_remove) 120 | one_per_class.append(chosen_one_per_class) 121 | 122 | bodyparts_cut = bodyparts[indices] 123 | unique, counts = np.unique(bodyparts_cut, return_counts=True) 124 | unique_per_category = dict(zip(unique, counts)) 125 | 126 | #print('Test that they are unique') 127 | #print(len(one_per_class) == len(set(one_per_class))) 128 | # From the different bodyparts left fill the test set from those that have more than one example 129 | # until test size is 0.3*total 130 | 131 | extra_need = int((1-train_size)*len(images)) - len(one_per_class) 132 | 133 | counter = 0 134 | test_extra = [] 135 | while ( counter < extra_need ): 136 | #reshuffle dictionary 137 | keys = list(unique_per_category.keys()) 138 | np.random.shuffle(keys) 139 | for bodypart in keys: 140 | if ( counter >= extra_need): 141 | break 142 | if( unique_per_category[bodypart] == 1 or unique_per_category[bodypart] == 0): 143 | continue 144 | 145 | #get random sample of that bodypart 146 | bodypart_indices = np.where(bodyparts[indices] == bodypart)[0].tolist() 147 | bodypart_choice = random.choice(indices[bodypart_indices]) 148 | test_extra.append(bodypart_choice) 149 | #remove bodypart index to avoid repetition 150 | unique_per_category[bodypart] -= 1 151 | remove_bodypart_index = np.argwhere( indices == bodypart_choice)[0].tolist() 152 | indices = np.delete(indices, remove_bodypart_index ) 153 | counter += 1 154 | 155 | test_indices = np.concatenate((one_per_class,test_extra)) 156 | 157 | images_train = 
images_read[indices,...] 158 | body_train = bodyparts[indices] 159 | split_names_train = split_names[indices] 160 | labels_train = labels_read[indices,...] 161 | 162 | random.shuffle(test_indices) 163 | 164 | images_test = images_read[test_indices[:int(len(test_indices)/2)],...] 165 | body_test = bodyparts[test_indices[:int(len(test_indices)/2)]] 166 | split_names_test = split_names[test_indices[:int(len(test_indices)/2)]] 167 | labels_test = labels_read[test_indices[:int(len(test_indices)/2)],...] 168 | 169 | images_val = images_read[test_indices[int(len(test_indices)/2):],...] 170 | body_val = bodyparts[test_indices[int(len(test_indices)/2):]] 171 | split_names_val = split_names[test_indices[int(len(test_indices)/2):]] 172 | labels_val = labels_read[test_indices[int(len(test_indices)/2):],...] 173 | 174 | 175 | #print(np.in1d(split_names_test, split_names_val, assume_unique=False, invert=False)) 176 | #print('FINAL SHAPES') 177 | #print('train set : %d images'%images_train.shape[0]) 178 | #print('test set : %d images'%images_test.shape[0]) 179 | #print('val set : %d images'%images_val.shape[0]) 180 | 181 | # Check that we didn't lose images on the way 182 | assert (images_train.shape[0] + images_test.shape[0] + images_val.shape[0]) == len(images) 183 | 184 | return images_train, labels_train, body_train, split_names_train, images_test, labels_test, body_test,\ 185 | split_names_test, images_val, labels_val, body_val, split_names_val 186 | 187 | 188 | 189 | 190 | def shuffle_together_simple(images, labels, bodyparts): 191 | 192 | c = list(zip(images,labels, bodyparts)) 193 | shuffle(c) 194 | images, labels, bodyparts = zip(*c) 195 | images = np.asarray(images) 196 | labels = np.asarray(labels) 197 | bodyparts = np.asarray(bodyparts) 198 | 199 | return images, labels, bodyparts 200 | 201 | def shuffle_together(images, labels, bodyparts, filenames): 202 | 203 | c = list(zip(images,labels, bodyparts,filenames)) 204 | shuffle(c) 205 | images, labels, bodyparts, filenames = zip(*c) 206 | images = np.asarray(images) 207 | labels = np.asarray(labels) 208 | bodyparts = np.asarray(bodyparts) 209 | filenames = np.asarray(filenames) 210 | 211 | return images, labels, bodyparts, filenames 212 | 213 | 214 | def random_crop(x, y, permin, permax): 215 | h, w, _ = x.shape 216 | per_h = random.uniform(permin, permax) 217 | per_w = random.uniform(permin, permax) 218 | crop_size = (int((1-per_h)*h),int((1-per_w)*w)) 219 | 220 | rangew = (w - crop_size[0]) // 2 if w>crop_size[0] else 0 221 | rangeh = (h - crop_size[1]) // 2 if h>crop_size[1] else 0 222 | offsetw = 0 if rangew == 0 else np.random.randint(rangew) 223 | offseth = 0 if rangeh == 0 else np.random.randint(rangeh) 224 | cropped_x = x[offseth:offseth+crop_size[0], offsetw:offsetw+crop_size[1], :] 225 | cropped_y = y[offseth:offseth+crop_size[0], offsetw:offsetw+crop_size[1], :] 226 | resize_x = cv2.resize(cropped_x, (h, w), interpolation=cv2.INTER_CUBIC) 227 | resize_y = cv2.resize(cropped_y, (h, w), interpolation=cv2.INTER_NEAREST) 228 | if cropped_y.shape[-1] == 0: 229 | return x, y 230 | else: 231 | return np.reshape(resize_x,(h,w,1)), resize_y 232 | 233 | -------------------------------------------------------------------------------- /Images/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JosephPB/XNet/56b982f8a16ff781dab1137dce7dcdd01954eefd/Images/architecture.jpg -------------------------------------------------------------------------------- 
/Images/confusion_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JosephPB/XNet/56b982f8a16ff781dab1137dce7dcdd01954eefd/Images/confusion_matrix.png -------------------------------------------------------------------------------- /Images/predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JosephPB/XNet/56b982f8a16ff781dab1137dce7dcdd01954eefd/Images/predictions.png -------------------------------------------------------------------------------- /PostProcessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib.image as mpimg 4 | import h5py 5 | import glob 6 | import pandas as pd 7 | #import PIL 8 | import tensorflow as tf 9 | import cv2 10 | from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img 11 | from keras.models import Model, Sequential 12 | from keras.layers import * 13 | from keras.optimizers import * 14 | from keras.callbacks import ModelCheckpoint, LearningRateScheduler 15 | from keras.metrics import categorical_accuracy 16 | from keras import backend as K 17 | from keras import losses 18 | from keras.models import load_model as keras_load_model 19 | from keras.utils import to_categorical 20 | from keras.preprocessing.image import ImageDataGenerator 21 | import sys 22 | from keras.utils.generic_utils import get_custom_objects 23 | from sklearn.metrics import roc_curve, auc 24 | 25 | 26 | sys.path.insert(0, '../') 27 | 28 | 29 | 30 | class PostProcessing: 31 | beam = 0 32 | tissue = 1 33 | bone = 2 34 | def __init__(self, model_path, dataset_path, loss = 'categorical_crossentropy', device = "cpu"): 35 | self.model_path = model_path 36 | self.dataset_path = dataset_path 37 | self.read_h5_file() 38 | print(loss) 39 | self.load_model(device = device, loss = loss) 40 | print('Model loaded.') 41 | self.prediction_prob_rs, self.prediction_argmax = self.predict(device=device) 42 | 43 | def read_h5_file(self): 44 | "Read data from h5file" 45 | dataset = h5py.File(self.dataset_path, 'r') 46 | self.train_images = dataset['train_img'] 47 | self.test_images = dataset['test_img'][:] 48 | self.val_images = dataset['val_img'][:] 49 | self.train_labels = dataset['train_label'] 50 | self.train_body = dataset['train_bodypart'][:] 51 | self.test_labels = dataset['test_label'][:] 52 | self.val_labels = dataset['val_label'][:] 53 | self.test_body = dataset['test_bodypart'][:] 54 | self.val_body = dataset['val_bodypart'][:] 55 | self.test_filenames = dataset['test_file'][:] 56 | self.val_filenames = dataset['val_file'][:] 57 | self.no_images_training, self.height, self.width, self.classes = self.train_labels.shape 58 | self.train_labels = np.reshape(self.train_labels, (-1,self.height*self.width ,self.classes)) 59 | self.test_labels = np.reshape(self.test_labels, (-1,self.height*self.width ,self.classes)) 60 | self.val_labels = np.reshape(self.val_labels, (-1,self.height*self.width ,self.classes)) 61 | self.test_images = np.concatenate((self.test_images, self.val_images)) 62 | self.test_labels = np.concatenate((self.test_labels, self.val_labels)) 63 | self.test_filenames = np.concatenate((self.test_filenames, self.val_filenames)) 64 | #REMOVE breast and Rectangles 65 | #mask1 = np.where((self.test_filenames != b'breast_phantom') & (self.test_filenames != b'pmmaandal')) 66 | #self.test_images = 
self.test_images[mask1] 67 | #self.test_labels = self.test_labels[mask1] 68 | 69 | dataset.close() 70 | 71 | def load_model(self, device = "cpu", optimizer = Adam(lr=1e-4), loss = "categorical_crossentropy",\ 72 | metrics = ['accuracy'] ): 73 | if(device == "cpu"): 74 | with tf.device("/cpu:0"): 75 | if(loss == "jaccard"): 76 | from jaccard_loss import jaccard_distance 77 | self.model = keras_load_model(self.model_path,custom_objects ={'jaccard_distance': jaccard_distance}) 78 | self.model.compile(optimizer, loss = jaccard_distance, metrics = metrics) 79 | elif(loss == "fancy"): 80 | from kerasfancyloss import fancy_loss 81 | self.model = keras_load_model(self.model_path,custom_objects ={'fancy_loss': fancy_loss}) 82 | self.model.compile(optimizer, loss =fancy_loss, metrics = metrics) 83 | else: 84 | self.model = keras_load_model(self.model_path) 85 | self.model.compile(optimizer, loss, metrics) 86 | elif(device == "gpu"): 87 | if(loss == "jaccard"): 88 | from jaccard_loss import jaccard_distance 89 | self.model = keras_load_model(self.model_path,custom_objects ={'jaccard_distance': jaccard_distance}) 90 | self.model.compile(optimizer, loss = jaccard_distance, metrics = metrics) 91 | elif(loss == "fancy"): 92 | from kerasfancyloss import fancy_loss 93 | self.model = keras_load_model(self.model_path,custom_objects ={'fancy_loss': fancy_loss}) 94 | self.model.compile(optimizer, loss =fancy_loss, metrics = metrics) 95 | else: 96 | self.model = keras_load_model(self.model_path) 97 | self.model.compile(optimizer, loss, metrics) 98 | else: 99 | print("Device not understood") 100 | return None 101 | 102 | def predict(self, device = "cpu", images = None): 103 | if(images is None): 104 | images = self.test_images 105 | if( device == "cpu"): 106 | with tf.device("/cpu:0"): 107 | prediction_prob = self.model.predict(images, batch_size=1) 108 | elif(device == "gpu"): 109 | prediction_prob = self.model.predict(images, batch_size=1) 110 | else: 111 | print("Device not found") 112 | return None 113 | prediction_prob_rs = prediction_prob.reshape((-1,self.classes)) 114 | prediction_argmax = np.argmax(prediction_prob_rs, axis = -1) 115 | return prediction_prob_rs, prediction_argmax 116 | 117 | def evaluate_overall(self, device = "gpu"): 118 | images = self.test_images 119 | labels = self.test_labels 120 | if(device == "cpu"): 121 | with tf.device("/cpu:0"): 122 | loss_test, accuracy_test = self.model.evaluate(images,labels, batch_size = 1) 123 | 124 | elif(device == "gpu"): 125 | loss_test, accuracy_test = self.model.evaluate(images,labels, batch_size = 1) 126 | else: 127 | print("Device not understood") 128 | return None 129 | 130 | print("Overall accuracy : \n") 131 | print ('On test set {}%'.format(round(accuracy_test,2)*100)) 132 | 133 | # Count number of trainable parameters 134 | trainable_count = int(np.sum([K.count_params(p) for p in set(self.model.trainable_weights)])) 135 | print('Trainable params: {:,}'.format(trainable_count)) 136 | return accuracy_test, trainable_count 137 | 138 | def evaluate_perclass(self, device = "gpu"): 139 | 140 | _, predictions = self.predict() 141 | labels = self.test_labels 142 | labels = np.argmax(labels, axis = -1) 143 | labels = labels.flatten() 144 | 145 | beam_gt = np.where(labels == self.beam)[0] 146 | beam_pred = np.where(predictions == self.beam)[0] 147 | beam_accuracy = float(len(np.intersect1d(beam_gt, beam_pred, assume_unique=True)))/float(len(beam_pred)) 148 | 149 | tissue_gt = np.where(labels == self.tissue)[0] 150 | tissue_pred = np.where(predictions == 
self.tissue)[0] 151 | tissue_accuracy = float(len(np.intersect1d(tissue_gt, tissue_pred, assume_unique=True)))/float(len(tissue_pred)) 152 | 153 | bone_gt = np.where(labels == self.bone)[0] 154 | bone_pred = np.where(predictions == self.bone)[0] 155 | bone_accuracy = float(len(np.intersect1d(bone_gt, bone_pred, assume_unique=True)))/float(len(bone_pred)) 156 | 157 | print('Accuracy on the different classes : \n') 158 | print('Open beam %f, Soft tissue %f, Bone %f'%(beam_accuracy,tissue_accuracy, bone_accuracy)) 159 | return beam_accuracy, tissue_accuracy, bone_accuracy 160 | 161 | def tpfp(self, predictions = None, single_index = -1): 162 | 163 | if (not (single_index == -1)): 164 | labels = self.test_labels[single_index] 165 | if(predictions is not None): 166 | prediction_argmax = predictions.reshape(-1,200,200) 167 | else: 168 | prediction_argmax = self.prediction_argmax.reshape(-1,200,200) 169 | 170 | prediction_argmax = prediction_argmax[single_index] 171 | prediction_argmax = prediction_argmax.flatten() 172 | else: 173 | if( predictions is not None): 174 | prediction_argmax = predictions 175 | else: 176 | prediction_argmax = self.prediction_argmax 177 | labels = self.test_labels 178 | 179 | labels = np.argmax(labels, axis = -1) 180 | labels = labels.flatten() 181 | 182 | beam_gt = np.where(labels == self.beam)[0] 183 | beam_pred = np.where(prediction_argmax == self.beam)[0] 184 | 185 | tissue_gt = np.where(labels == self.tissue)[0] 186 | tissue_pred = np.where(prediction_argmax == self.tissue)[0] 187 | 188 | if (len(tissue_pred) == 0): 189 | return 0,0 190 | 191 | 192 | bone_gt = np.where(labels == self.bone)[0] 193 | bone_pred = np.where(prediction_argmax == self.bone)[0] 194 | 195 | # FALSE POSITIVES 196 | false_positives = 0 197 | beam_as_tissue = float(len(np.intersect1d(beam_gt, tissue_pred, assume_unique=True)))/float(len(tissue_pred)) 198 | false_positives = beam_as_tissue 199 | bone_as_tissue = float(len(np.intersect1d(bone_gt, tissue_pred, assume_unique=True)))/float(len(tissue_pred)) 200 | false_positives += bone_as_tissue 201 | 202 | # TRUE POSITIVES 203 | 204 | true_positives = 0 205 | true_positives = len(np.intersect1d(tissue_gt, tissue_pred, assume_unique=True))/len(tissue_gt) 206 | 207 | # FALSE NEGATIVES 208 | 209 | false_negatives = 0 210 | tissue_as_beam = float(len(np.intersect1d(tissue_gt, beam_pred, assume_unique=True)))/float(len(tissue_gt)) 211 | false_negatives = tissue_as_beam 212 | tissue_as_bone = float(len(np.intersect1d(tissue_gt, bone_pred, assume_unique=True)))/float(len(tissue_gt)) 213 | false_negatives += tissue_as_bone 214 | 215 | 216 | return true_positives, false_positives 217 | 218 | 219 | def thresholding(self,threshold, device = "cpu"): 220 | prob_prediction_tissue = self.prediction_prob_rs[...,self.tissue] 221 | tissue_pred = np.where((prob_prediction_tissue > threshold))[0] 222 | 223 | prediction_improved = self.prediction_argmax 224 | prediction_improved[tissue_pred] = self.tissue 225 | 226 | tissue_notsure = np.where((prob_prediction_tissue <= threshold))[0] 227 | openbeam_bone = self.prediction_prob_rs[...,[0,2]] 228 | prediction_improved[tissue_notsure] = 2 * np.argmax(openbeam_bone[tissue_notsure], axis = -1) 229 | self.prediction_argmax = prediction_improved 230 | return prediction_improved 231 | 232 | def thresholding_bodypart(self): 233 | 234 | unique, counts = np.unique(self.test_body, return_counts=True) 235 | thresholds = 0.6*np.ones(len(unique)) 236 | thresholds_dict = dict(zip(unique, thresholds)) 237 | thresholds_dict[b'ankle'] = 
0.85 238 | thresholds_dict[b'hand'] = 0.99 239 | thresholds_dict[b'cropped'] = 0.99 240 | thresholds_dict[b'foils'] = 0.99 241 | thresholds_dict[b'lumbarspin'] = 0.99 242 | thresholds_dict[b'neckoffemu'] = 0.9 243 | prediction_prob_rs, prediction_argmax = self.predict() 244 | prediction_argmax = prediction_argmax.reshape(-1, self.height, self.width) 245 | prediction_improved = np.zeros_like(prediction_argmax) 246 | 247 | test_images = self.test_images[...,0] 248 | for i,image in enumerate(test_images): 249 | bodypart = self.test_body[i] 250 | threshold = thresholds_dict[bodypart] 251 | prediction_prob = prediction_prob_rs[i] 252 | prob_prediction_tissue = prediction_prob[...,self.tissue] 253 | tissue_pred = np.where((prob_prediction_tissue > threshold))[0] 254 | 255 | prediction_improved[i] = prediction_argmax[i] 256 | prediction_improved[tissue_pred] = self.tissue 257 | tissue_notsure = np.where((prob_prediction_tissue <= threshold))[0] 258 | openbeam_bone = prediction_prob[...,[0,2]] 259 | prediction_improved[tissue_notsure] = 2 * np.argmax(openbeam_bone[tissue_notsure], axis = -1) 260 | 261 | return prediction_improved 262 | 263 | 264 | 265 | def pixel_dilation(self, dilation_factor, predictions = None, both = False): 266 | '''Dilates pixels if bone and/or soft tissue. 267 | Input: 268 | prediction: argmaxed images shape = (height,width) 269 | dilation_factor: number of pixels by which to dilate 270 | both: bool, if True dilates both open beam and bone, with preference for bone, if False dilates bone''' 271 | 272 | if( predictions is None): 273 | _, predictions = self.predict() 274 | predictions = predictions.reshape((-1, self.height, self.width)) 275 | predictions = predictions.astype(np.float32) 276 | predictions_dilated = np.ones_like(predictions) 277 | 278 | prediction_bone = np.zeros_like(predictions) 279 | bone_indices = np.where(predictions == self.bone) 280 | prediction_bone[bone_indices] = self.bone 281 | prediction_bone = prediction_bone.reshape((-1, self.height, self.width)) 282 | prediction_bone = prediction_bone.astype(np.float32) 283 | prediction_bone_dilated = np.zeros_like(predictions) 284 | 285 | for i,prediction in enumerate(prediction_bone): 286 | #remove small groups of bone 287 | kernel_opening = np.ones((10,10), np.uint8) 288 | bone_pred = np.where(prediction == self.bone) 289 | opening = cv2.morphologyEx(prediction, cv2.MORPH_OPEN, kernel_opening) 290 | 291 | #dilate image 292 | kernel_dilate = np.ones((dilation_factor, dilation_factor), np.uint8 ) 293 | dilated = cv2.dilate(opening, kernel_dilate) 294 | prediction_bone_dilated[i,...] 
= dilated 295 | 296 | predictions_dilated[np.where(prediction_bone_dilated == self.bone)] = self.bone 297 | predictions_dilated[np.where(predictions == self.beam)] = self.beam 298 | 299 | return predictions_dilated 300 | 301 | 302 | # batch, height, width = prediction.shape 303 | # cp = np.copy(prediction) 304 | 305 | #for k in range(batch): 306 | # for i in range(dilation_factor,height-dilation_factor): 307 | # for j in range(dilation_factor,width-dilation_factor): 308 | # if prediction[k,i,j] == bone: 309 | # cp[k,i-dilation_factor:i+dilation_factor+1,j-dilation_factor:j+dilation_factor+1] = bone 310 | # if both == True: 311 | # if cp[k,i,j] == open_beam: 312 | # cp[k,i-dilation_factor:i+dilation_factor+1,j-dilation_factor:j+dilation_factor+1] = open_beam 313 | #return cp 314 | 315 | def plot(self,threshold, dilation_factor): 316 | probability_map, prediction = self.predict() 317 | prediction_threshold = self.thresholding(0.9) 318 | prediction_dilation = self.pixel_dilation(dilation_factor, prediction_threshold) 319 | ntestimages = len(self.test_images) 320 | left = 0.1 # the left side of the subplots of the figure 321 | right = 0.4 # the right side of the subplots of the figure 322 | bottom = 0.1 # the bottom of the subplots of the figure 323 | top = 0.9 # the top of the subplots of the figure 324 | wspace = 0.08 # the amount of width reserved for blank space between subplots 325 | hspace = 0.1 # the amount of height reserved for white space between subplots 326 | 327 | labels_plot = self.test_labels.reshape(-1, self.height, self.width, 3) * 255 328 | prediction = prediction.reshape(-1, self.height, self.width) 329 | prediction_threshold = prediction_threshold.reshape(-1, self.height, self.width) 330 | prediction_dilation = prediction_dilation.reshape(-1, self.height, self.width) 331 | 332 | for i, image in enumerate(self.test_images): 333 | print(self.test_filenames[i]) 334 | fig=plt.figure(figsize=(50, 50), dpi= 80, edgecolor='k',frameon=False) 335 | plt.subplots_adjust(left=left, bottom=bottom, right=right, top=top, wspace=wspace, hspace=hspace) 336 | index_show = i 337 | print(i) 338 | # Need TP/FP per image 339 | plt.subplot(ntestimages,5,1) 340 | plt.title('Image') 341 | plt.imshow(image[...,0],cmap='gray') 342 | plt.axis('off') 343 | 344 | plt.subplot(ntestimages,5,2) 345 | plt.title('Ground truth') 346 | plt.imshow(labels_plot[i]) 347 | plt.axis('off') 348 | 349 | plt.subplot(ntestimages,5,3) 350 | plt.title('Prediction') 351 | plt.imshow(prediction[i]) 352 | plt.axis('off') 353 | 354 | #plt.subplot(ntestimages,5,4) 355 | #plt.title('Probability map') 356 | #plt.imshow(probability_map[i]) 357 | #plt.axis('off') 358 | 359 | plt.subplot(ntestimages,5,4) 360 | plt.title('Threshold') 361 | plt.imshow(prediction_threshold[i]) 362 | plt.axis('off') 363 | 364 | plt.subplot(ntestimages,5,5) 365 | plt.title('Dilated') 366 | plt.imshow(prediction_dilation[i]) 367 | plt.axis('off') 368 | plt.show() 369 | 370 | def learning_curve(self, path_to_csv): 371 | 372 | csv_file = pd.read_csv(path_to_csv) 373 | self.csv = csv_file 374 | epochs = self.csv['epoch'] 375 | train_loss = self.csv['loss'] 376 | val_loss = self.csv['val_loss'] 377 | train_acc = self.csv['acc'] 378 | val_acc = self.csv['val_acc'] 379 | 380 | train_err = 1 - train_acc 381 | val_err = 1 - val_acc 382 | fig, ax = plt.subplots(2,1, figsize=(15,15)) 383 | ax[0].plot(epochs, train_err, color = 'blue', label = 'training error') 384 | ax[0].plot(epochs, val_err, color = 'orange', label = 'validation error') 385 | ax[0].plot(epochs, 
np.linspace(0.02,0.02,len(epochs)), color = 'green', label = 'desired error') 386 | ax[0].set_xlabel('number of epochs') 387 | ax[0].set_ylabel('error') 388 | ax[0].set_title('Error') 389 | ax[0].legend() 390 | 391 | ax[1].plot(epochs, train_loss, label = "training loss") 392 | ax[1].plot(epochs, val_loss, label = "validation loss") 393 | ax[1].legend() 394 | ax[1].set_xlabel("Number of epochs") 395 | ax[1].set_ylabel("Loss") 396 | ax[1].set_title("Loss") 397 | 398 | return fig, ax 399 | #plt.title('Epoch learning curve for Double Linked Network') 400 | #plt.savefig('Linked_epoch_LC.png', dpi = 250) 401 | 402 | 403 | def ROC_curve(self): 404 | # 1 for only tissue 405 | fpr, tpr, thresholds = roc_curve(self.test_labels[..., 1].reshape(-1), self.prediction_prob_rs[...,1].reshape(-1)) 406 | roc_auc = auc(fpr, tpr) 407 | fig, ax = plt.subplots() 408 | ax.plot(fpr, tpr, 409 | label='Tissue ROC curve (area = {0:0.2f})' 410 | ''.format(roc_auc), 411 | color='indianred', linestyle=':', linewidth=4) 412 | ax.plot([0, 1], [0, 1], 'k--') 413 | ax.set_xlim([0.0, 1.0]) 414 | ax.set_ylim([0.0, 1.05]) 415 | ax.set_xlabel('False Positive Rate') 416 | ax.set_ylabel('True Positive Rate') 417 | ax.set_title('Some extension of Receiver operating characteristic to multi-class') 418 | ax.legend(loc="lower right") 419 | return fig, ax 420 | -------------------------------------------------------------------------------- /Poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JosephPB/XNet/56b982f8a16ff781dab1137dce7dcdd01954eefd/Poster.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # XNet 2 | 3 | XNet is a Convolutional Neural Network designed for the segmentation 4 | of X-Ray images into bone, soft tissue and open beam 5 | regions. Specifically, it performs well on small datasets with the aim 6 | to minimise the number of false positives in the soft tissue class. 7 | 8 | This code accompanies the paper published in the SPIE Medical Imaging Conference Proceedings (2019) and can be found on the preprint arXiv at: [arXiv:1812.00548](https://arxiv.org/abs/1812.00548) 9 | 10 | Cite as: 11 | ``` 12 | @inproceedings{10.1117/12.2512451, 13 | author = {Joseph Bullock and Carolina Cuesta-Lázaro and Arnau Quera-Bofarull}, 14 | title = {{XNet: a convolutional neural network (CNN) implementation for medical x-ray image segmentation suitable for small datasets}}, 15 | volume = {10953}, 16 | booktitle = {Medical Imaging 2019: Biomedical Applications in Molecular, Structural, and Functional Imaging}, 17 | editor = {Barjor Gimi and Andrzej Krol}, 18 | organization = {International Society for Optics and Photonics}, 19 | publisher = {SPIE}, 20 | pages = {453 -- 463}, 21 | keywords = {machine learning, deep learning, X-Ray segmentation, neural network, small datasets}, 22 | year = {2019}, 23 | doi = {10.1117/12.2512451}, 24 | URL = {https://doi.org/10.1117/12.2512451} 25 | } 26 | ``` 27 | 28 | ## Architecture 29 | 30 | ![](./Images/architecture.jpg) 31 | 32 | * Built on a typical encoder-decoder architecture as 33 | inspired by [SegNet](http://mi.eng.cam.ac.uk/projects/segnet/). 34 | 35 | * Additional feature extraction stage, with weight sharing across some 36 | layers. 37 | 38 | * Fine and coarse grained feature preservation through concatenation 39 | of layers. 
 40 | 
 41 | * L2 regularisation at each of the convolutional layers, to decrease overfitting.
 42 | 
 43 | The architecture is described in the ```XNet.py``` file.
 44 | 
 45 | ## Output
 46 | 
 47 | XNet outputs a mask of equal size to the input images.
 48 | 
 49 | ![](./Images/predictions.png)
 50 | 
 51 | ## Training
 52 | 
 53 | To train a model:
 54 | 
 55 | 1. Open ```Training/generate_parameters.py``` and define your desired hyperparameters
 56 | 2. Run ```Training/generate_parameters.py``` to generate a ```parameters.txt``` file, which is read by ```Training/TrainingClass.py```
 57 | 3. Run ```Training/train.py```
 58 | 
 59 | XNet is trained on a small dataset which has undergone augmentation. Examples of this augmentation step can be found in the
 60 | ```Augmentations/augmentations.ipynb``` notebook. Similarly, the ```Training``` folder contains Python scripts that perform the necessary augmentations.
 61 | 
 62 | Running ```Training/train.py``` calls various other scripts to perform one of two possible ways of augmenting the images:
 63 | 
 64 | * 'On the fly' augmentation, where a new set of augmentations is generated at each epoch.
 65 | 
 66 | * Pre-augmented images.
 67 | 
 68 | To select which method to use, comment out the corresponding lines in the ```fit``` function in the ```Training/TrainingClass.py``` script.
 69 | 
 70 | ```train.py``` also performs postprocessing to fine-tune the results.
 71 | 
 72 | ## Benchmarking
 73 | 
 74 | XNet was benchmarked against two of the leading segmentation networks:
 75 | 
 76 | * Simplified [SegNet](https://arxiv.org/abs/1511.00561) (found in the
 77 | ```SimpleSegNet.py``` file)
 78 | 
 79 | * [UNet](https://arxiv.org/abs/1505.04597) (found in the ```UNet.py```
 80 | file)
 81 | 
 82 | ## Data
 83 | 
 84 | We trained on a dataset of:
 85 | 
 86 | * 150 X-Ray images.
 87 | 
 88 | * No scatter correction.
 89 | 
 90 | * 1500x1500 ```.tif``` images downsampled to 200x200.
 91 | 
 92 | * 20 human body part classes.
 93 | 
 94 | * Highly imbalanced.
 95 | 
 96 | As this work grew out of work with a corporation, we are sadly unable to share the proprietary data we used.
 97 | 
 98 | ## More information
 99 | 
100 | For more information and context see the conference poster
101 | ```Poster.pdf```.
102 | 
103 | Please note that some of the path variables may need to be corrected to match the current file structure. These are planned to be updated in the future.
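
## Quick inference sketch

The snippet below shows the end-to-end pattern: build the network the same way
```Training/TrainingClass.py``` does (by loading ```XNet.py``` as a module and
calling its ```model()``` factory), then run a single prediction. It is a sketch
only: the ```l2_lambda``` and ```filter_depth``` values and the weights path are
placeholders to adapt to your own setup.

```python
import importlib.util
import numpy as np

# Load XNet.py as a module, mirroring what Training/TrainingClass.py does
spec = importlib.util.spec_from_file_location("xnet", "XNet.py")
xnet = importlib.util.module_from_spec(spec)
spec.loader.exec_module(xnet)

# Placeholder hyperparameters -- match these to the ones the model was trained with
net = xnet.model(l2_lambda=1e-4, input_shape=(200, 200, 1), classes=3,
                 kernel_size=3, filter_depth=(64, 128, 256, 512, 1024))
net.load_weights("model.h5")  # path to trained weights (placeholder)

# The network outputs a flattened (height*width, classes) probability map
image = np.zeros((1, 200, 200, 1), dtype=np.float32)  # stand-in for a real X-ray
mask = net.predict(image).reshape(200, 200, 3)
```

The same ```importlib``` pattern is what lets ```Training/TrainingClass.py``` swap between ```XNet.py```, ```UNet.py``` and ```SimpleSegNet.py``` without code changes.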
104 | 
--------------------------------------------------------------------------------
/SimpleSegNet.py:
--------------------------------------------------------------------------------
 1 | from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
 2 | from keras.models import Model, Sequential
 3 | from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, UpSampling2D, Dropout, Cropping2D,Convolution2D
 4 | from keras.layers import BatchNormalization, Reshape, Layer
 5 | from keras.layers import Activation, Flatten, Dense, ConvLSTM2D, LeakyReLU
 6 | from keras.optimizers import *
 7 | from keras.callbacks import ModelCheckpoint, LearningRateScheduler
 8 | from keras.metrics import categorical_accuracy
 9 | from keras import backend as K
10 | from keras import losses
11 | from keras.models import load_model
12 | 
13 | def model(input_shape=(64,64,1), classes=3, kernel_size = 3, filter_depth = (64,128,256,512,0)):
14 | 
15 |     img_input = Input(shape=input_shape)
16 |     x = img_input
17 |     # Encoder
18 |     x = Conv2D(filter_depth[0], (kernel_size, kernel_size), padding="same")(x)
19 |     x = BatchNormalization()(x)
20 |     x = Activation("relu")(x)
21 |     x = MaxPooling2D(pool_size=(2, 2))(x)
22 | 
23 |     x = Conv2D(filter_depth[1], (kernel_size, kernel_size), padding="same")(x)
24 |     x = BatchNormalization()(x)
25 |     x = Activation("relu")(x)
26 |     x = MaxPooling2D(pool_size=(2, 2))(x)
27 |     #50x50
28 |     x = Conv2D(filter_depth[2], (kernel_size, kernel_size), padding="same")(x)
29 |     x = BatchNormalization()(x)
30 |     x = Activation("relu")(x)
31 |     x = MaxPooling2D(pool_size=(2, 2))(x)
32 |     #25x25
33 |     x = Conv2D(filter_depth[3], (kernel_size, kernel_size), padding="same")(x)
34 |     x = BatchNormalization()(x)
35 |     x = Activation("relu")(x)
36 | 
37 |     # Decoder
38 |     x = Conv2D(filter_depth[3], (kernel_size, kernel_size), padding="same")(x)
39 |     x = BatchNormalization()(x)
40 |     x = Activation("relu")(x)
41 |     #25x25
42 |     x = UpSampling2D(size=(2, 2))(x)
43 |     x = Conv2D(filter_depth[2], (kernel_size, kernel_size), padding="same")(x)
44 |     x = BatchNormalization()(x)
45 |     x = Activation("relu")(x)
46 |     #50x50
47 |     x = UpSampling2D(size=(2, 2))(x)
48 |     x = Conv2D(filter_depth[1], (kernel_size, kernel_size), padding="same")(x)
49 |     x = BatchNormalization()(x)
50 |     x = Activation("relu")(x)
51 |     #100x100
52 |     x = UpSampling2D(size=(2, 2))(x)
53 |     x = Conv2D(filter_depth[0], (kernel_size, kernel_size), padding="same")(x)
54 |     x = BatchNormalization()(x)
55 |     x = Activation("relu")(x)
56 | 
57 |     x = Conv2D(classes, (1,1), padding="valid")(x)
58 | 
59 | 
60 |     x = Reshape((input_shape[0]*input_shape[1],classes))(x)
61 |     x = Activation("softmax")(x)
62 | 
63 |     model = Model(img_input, x)
64 | 
65 |     return model
--------------------------------------------------------------------------------
/Training demo.ipynb:
--------------------------------------------------------------------------------
  1 | {
  2 |  "cells": [
  3 |   {
  4 |    "cell_type": "markdown",
  5 |    "metadata": {},
  6 |    "source": [
  7 |     "# Segmenting X-Ray Images using Neural Networks"
  8 |    ]
  9 |   },
 10 |   {
 11 |    "cell_type": "markdown",
 12 |    "metadata": {},
 13 |    "source": [
 14 |     "The aim of this notebook is to walk through the process of performing inference with a pretrained model to segment X-Ray images into bone, soft-tissue and open beam regions."
 15 |    ]
 16 |   },
 17 |   {
 18 |    "cell_type": "markdown",
 19 |    "metadata": {},
 20 |    "source": [
 21 |     "**Note:** Since we cannot release the fully trained model for proprietary reasons, this pretrained model has only been trained on ~10 images. Performance is therefore significantly hindered."
 22 |    ]
 23 |   },
 24 |   {
 25 |    "cell_type": "markdown",
 26 |    "metadata": {},
 27 |    "source": [
 28 |     "## Download data and model"
 29 |    ]
 30 |   },
 31 |   {
 32 |    "cell_type": "code",
 33 |    "execution_count": null,
 34 |    "metadata": {},
 35 |    "outputs": [],
 36 |    "source": [
 37 |     "!pip install googledrivedownloader\n",
 38 |     "from google_drive_downloader import GoogleDriveDownloader as gdd\n",
 39 |     "gdd.download_file_from_google_drive(file_id='1Wel_XsyE7HcEq0TkZWI61GABO4jOtj9C',\n",
 40 |     "                                    dest_path='./dataset.hdf5')\n",
 41 |     "gdd.download_file_from_google_drive(file_id='1cePD5E-T9mr5W0xPGuzEnUt8Glpvn23U',\n",
 42 |     "                                    dest_path='./model.h5')"
 43 |    ]
 44 |   },
 45 |   {
 46 |    "cell_type": "code",
 47 |    "execution_count": null,
 48 |    "metadata": {},
 49 |    "outputs": [],
 50 |    "source": [
 51 |     "import os, sys\n",
 52 |     "import numpy as np\n",
 53 |     "import h5py\n",
 54 |     "import matplotlib.pyplot as plt\n",
 55 |     "\n",
 56 |     "#import Keras sub-modules\n",
 57 |     "from keras.models import Model #functional API for Keras (best for greater flexibility)\n",
 58 |     "from keras.layers import Input, Concatenate, Conv2D, MaxPooling2D, UpSampling2D, Flatten, Dense #'main' layers\n",
 59 |     "from keras.layers import BatchNormalization, Dropout #regularisation layers\n",
 60 |     "from keras.layers import Activation\n",
 61 |     "from keras.optimizers import * #import all optimisers\n",
 62 |     "from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger #callbacks for model performance analysis\n",
 63 |     "from keras.metrics import categorical_accuracy #metrics for model performance\n",
 64 |     "from keras import backend as K #gives backend functionality\n",
 65 |     "from keras import losses #imports pre-defined loss functions\n",
 66 |     "from keras.models import load_model #allows pre-trained models to be called back"
 67 |    ]
 68 |   },
 69 |   {
 70 |    "cell_type": "markdown",
 71 |    "metadata": {},
 72 |    "source": [
 73 |     "h5 files compress the data for convenient storage, as well as offering accessibility through a 'key' system. This is a useful file type; however, it can become cumbersome for large datasets. Other storage mechanisms will still work with this model."
 74 |    ]
 75 |   },
 76 |   {
 77 |    "cell_type": "code",
 78 |    "execution_count": null,
 79 |    "metadata": {},
 80 |    "outputs": [],
 81 |    "source": [
 82 |     "hdf5_path = \"./dataset.hdf5\" ## this is our h5 file containing training and testing data\n",
 83 |     "dataset = h5py.File(hdf5_path , 'r')\n",
 84 |     "\n",
 85 |     "classes = 3\n",
 86 |     "\n",
 87 |     "test_images = dataset['test_img'][:]\n",
 88 |     "no_images, height, width, channels = test_images.shape\n",
 89 |     "\n",
 90 |     "test_labels = dataset['test_label'][:].reshape(-1, height*width, classes)\n",
 91 |     "\n",
 92 |     "dataset.close()"
 93 |    ]
 94 |   },
 95 |   {
 96 |    "cell_type": "markdown",
 97 |    "metadata": {},
 98 |    "source": [
 99 |     "## Load model"
100 |    ]
101 |   },
102 |   {
103 |    "cell_type": "code",
104 |    "execution_count": null,
105 |    "metadata": {},
106 |    "outputs": [],
107 |    "source": [
108 |     "model = load_model(\"./model.h5\")"
109 |    ]
110 |   },
111 |   {
112 |    "cell_type": "markdown",
113 |    "metadata": {},
114 |    "source": [
115 |     "The above should have been relatively straightforward. We simply use the existing data files we have created, load them in, and set up the model from this input. Before progressing further, we shall now investigate the model. By printing a model summary, Keras provides a user friendly output which shows the layers and their parameters.\n",
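    "\n",
    "As a quick sanity check you can also count the trainable parameters directly (a small sketch using the Keras backend, mirroring what `PostProcessing.py` does):\n",
    "\n",
    "```python\n",
    "trainable_count = int(np.sum([K.count_params(p) for p in set(model.trainable_weights)]))\n",
    "print('Trainable params: {:,}'.format(trainable_count))\n",
    "```"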
116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "model.summary()" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "With this summary in mind, go to the ```XNet.py``` script and have a look." 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "The process of training and saving a model is summarised in the ```Training/TrainingClass.py``` script. Have a look at that script to see what the different functions are doing. In the class you will see that all the augementations are being performed on the data, and so you just need to create an HDF5 file which contains non-augmented images and call it into the training class." 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "Once you are happy with your understanding of the training class we shall now try predicting from the model. In Keras this is very simple. Pass in the test image into the ```model.predict``` function and reshape the output to view it." 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "Choose an image to test on" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "test_index = 0" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "testing_image = test_images[test_index]\n", 171 | "\n", 172 | "#as we are only running one image, we must reshape to shape (batch, height, width, channels)\n", 173 | "testing_image = testing_image.reshape((1,200,200,1))" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "prediction = model.predict(testing_image)\n", 183 | "\n", 184 | "#the prediction is a flattened array and so must be reshaped.\n", 185 | "#there are 3 channels as we are actually outputting the probability map over all 3 classes.\n", 186 | "prediction = prediction.reshape((200,200,3))" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "plt.imshow(prediction)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "After performing prediction various postprocessing stages can be employed to fine tune the output. See the ```PostProcessing.py``` script for more details, which is then called in the ```Training/train.py``` script." 
203 |    ]
204 |   }
205 |  ],
206 |  "metadata": {
207 |   "kernelspec": {
208 |    "display_name": "Python 2",
209 |    "language": "python",
210 |    "name": "python2"
211 |   },
212 |   "language_info": {
213 |    "codemirror_mode": {
214 |     "name": "ipython",
215 |     "version": 2
216 |    },
217 |    "file_extension": ".py",
218 |    "mimetype": "text/x-python",
219 |    "name": "python",
220 |    "nbconvert_exporter": "python",
221 |    "pygments_lexer": "ipython2",
222 |    "version": "2.7.15"
223 |   }
224 |  },
225 |  "nbformat": 4,
226 |  "nbformat_minor": 2
227 | }
--------------------------------------------------------------------------------
/Training/TrainingClass.py:
--------------------------------------------------------------------------------
 1 | import tensorflow as tf
 2 | import h5py
 3 | import numpy as np
 4 | import os
 5 | import sys
 6 | import signal
 7 | import shutil
 8 | import importlib.util
 9 | import time
10 | 
11 | from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
12 | from keras.models import Model, Sequential, load_model
13 | from keras.layers import *
14 | from keras import backend as K
15 | from keras import losses
16 | from keras.optimizers import *
17 | from keras.callbacks import ModelCheckpoint, LearningRateScheduler, CSVLogger, TensorBoard, EarlyStopping
18 | from keras.metrics import categorical_accuracy
19 | from utils import shuffle_together_simple, random_crop
20 | from random import randint
21 | import imgaug as ia
22 | from keras.utils import to_categorical
23 | import matplotlib
24 | matplotlib.use('Agg')
25 | import matplotlib.pyplot as plt
26 | from imgaug import augmenters as iaa
27 | from imgaug import parameters as iap
28 | 
29 | def fancy_loss(y_true,y_pred):
30 |     "This function has been written in tensorflow, needs some small changes to work with keras"
31 |     y_pred = tf.reshape(y_pred,[-1,y_pred.shape[-1]])
32 |     y_true = tf.argmax(y_true, axis=-1)
33 |     y_true = tf.reshape(y_true,[-1])
34 |     return lovasz_softmax_flat(y_pred, y_true)
35 | 
36 | 
37 | 
38 | class TrainingClass:
39 | 
40 |     def __init__(self, name, model_path, data_path, save_folder, no_epochs, kernel_size, batch_size, filters, lrate = 1e-4, reg = 0.0001, loss = 'categorical_crossentropy', duplicate = True ):
41 |         self.name = name
42 |         self.model_path = model_path
43 |         self.data_path = data_path
44 |         self.save_folder = save_folder
45 |         self.kernel_size = kernel_size
46 |         self.batch_size = batch_size
47 |         self.filters = filters
48 |         self.lrate = lrate
49 |         self.reg = reg
50 |         self.no_epochs = no_epochs
51 |         if(loss == "fancy"):
52 |             from fancyloss import lovasz_softmax_flat  # note: fancy_loss above also needs this name in its own scope
53 |             self.loss = fancy_loss
54 |         elif(loss == "jaccard"):
55 |             from jaccard_loss import jaccard_distance
56 |             self.loss = jaccard_distance
57 |         else:
58 |             self.loss = loss
59 |         self.load_data()
60 |         if(duplicate == True):
61 |             self.train, self.train_label, self.train_bodypart = self.duplicate()
62 |             print("Finished duplicating images, the final size of your training set is %d images."%self.train.shape[0])
63 |         self.write_metadata()
64 |         self.compile()
65 | 
66 |     def load_data(self):
67 |         "Loads data from h5 file"
68 |         hf = h5py.File(self.data_path, 'r')
69 | 
70 |         self.train = hf['train_img']
71 |         self.no_images, self.height, self.width, self.channels = self.train.shape
72 |         self.train_label = hf['train_label']
73 |         self.train_bodypart = hf['train_bodypart'][:]
74 |         self.no_images, _, _, self.no_classes = self.train_label.shape
75 |         self.val = hf['val_img'][:]
76 |         self.val_label = hf['val_label'][:]
77 |         self.val_label = 
self.val_label.reshape((-1,self.height*self.width,self.no_classes)) 78 | print("Data loaded succesfully.") 79 | 80 | def write_metadata(self): 81 | "Writes metadata to a txt file, with all the training information" 82 | metafile_path = self.save_folder + "/metadata.txt" 83 | 84 | if (os.path.isfile(metafile_path)): 85 | confirm_metada = input("Warning metadata file exists, continue? (y/n) ") 86 | if(confirm_metada == "y"): 87 | shutil.rmtree(metafile_path) 88 | else: 89 | sys.exit() 90 | 91 | metadata = open(metafile_path, "w") 92 | metadata.write("name: %s \n"%self.name) 93 | metadata.write("Data: %s \n"%self.data_path) 94 | metadata.write("kernel_size: %d \n" %self.kernel_size) 95 | metadata.write("batch_size:%d \n" %self.batch_size) 96 | metadata.write("filters %s \n" %(self.filters,)) 97 | metadata.write("lrate: %f \n" %self.lrate) 98 | metadata.write("reg: %f \n" %self.reg) 99 | metadata.write("Loss function: %s \n" %self.loss) 100 | metadata.write("no_epochs: %d \n" %self.no_epochs) 101 | metadata.close() 102 | 103 | 104 | 105 | def generator(self): 106 | "This generator is used to feed the data to the training algorithm. Given a batch size, randomly divides the training data into batches. This function allows training even when all the data cannot be loaded into RAM memory." 107 | 108 | while True: 109 | indices = np.asarray(range(0, self.no_images)) 110 | np.random.shuffle(indices) 111 | for idx in range(0, len(indices), self.batch_size): 112 | batch_indices = indices[idx:idx+self.batch_size] 113 | batch_indices.sort() 114 | batch_indices = batch_indices.tolist() 115 | by = self.train_label[batch_indices] 116 | by = by.reshape(-1, self.width*self.height, self.no_classes) 117 | bx = self.train[batch_indices] 118 | 119 | yield(bx,by) 120 | def duplicate(self): 121 | "Since our dataset is highly imbalanced among bodyparts, duplicate images from underrepresented bodyparts" 122 | img_per_category, counts = np.unique(self.train_bodypart, return_counts=True) 123 | img_per_category = dict(zip(img_per_category, counts)) 124 | EXAMPLES_PER_CATEGORY = max(img_per_category.values()) 125 | duplications_per_category = dict(img_per_category) 126 | for key in img_per_category: 127 | duplications_per_category[key] = int(EXAMPLES_PER_CATEGORY/img_per_category[key]) 128 | 129 | duplicated_size = sum(duplications_per_category[k]*img_per_category[k] + img_per_category[k] \ 130 | for k in duplications_per_category) 131 | 132 | train_duplicated = np.zeros((duplicated_size,self.height,self.width,self.train.shape[3])) 133 | labels_duplicated = np.zeros((duplicated_size,self.height, self.width,self.no_classes)) 134 | bodypart_duplicated = np.empty((duplicated_size),dtype = 'S10') 135 | 136 | train_duplicated[:self.no_images,...] = self.train 137 | labels_duplicated[:self.no_images,...] = self.train_label 138 | bodypart_duplicated[:self.no_images,...] = self.train_bodypart 139 | 140 | # Loop over the different kind of bodyparts 141 | counter = self.no_images 142 | counter_block = 0 143 | for i, (k, v) in enumerate(duplications_per_category.items()): 144 | # Indices of images with a given bodypart 145 | indices = np.array(np.where(self.train_bodypart == k )[0]) 146 | counter_block += len(indices) 147 | # Number of augmentation per image 148 | N = int(v) 149 | for j in indices: 150 | for l in range(N): 151 | train_duplicated[counter,...] =self.train[j] 152 | labels_duplicated[counter,...] 
    def augmentator(self, index):
        "Defines the transformations to apply to the images and, where required, to the labels"

        translate_max = 0.01
        rotate_max = 15
        shear_max = 2

        affine_transform = iaa.Affine(translate_percent={"x": (-translate_max, translate_max),
                                                         "y": (-translate_max, translate_max)},  # translate by +-translate_max
                                      rotate=(-rotate_max, rotate_max),  # rotate by -rotate_max to +rotate_max degrees
                                      shear=(-shear_max, shear_max),  # shear by -shear_max to +shear_max degrees
                                      order=[1],  # bilinear interpolation (fast)
                                      cval=125,  # if mode is constant, use a cval between 0 and 255
                                      mode="reflect",
                                      name="Affine",
                                      )

        spatial_aug = iaa.Sequential([iaa.Fliplr(0.5), iaa.Flipud(0.5), affine_transform])

        other_aug = iaa.SomeOf((1, None),
                               [
                                   iaa.OneOf([
                                       iaa.GaussianBlur((0, 0.4)),  # blur images with a sigma between 0 and 0.4
                                       iaa.ElasticTransformation(alpha=(0.5, 1.5), sigma=0.25),  # very few
                                   ]),
                               ])

        '''
        Alternative pipeline with stronger elastic transformations:

        affine_transform = iaa.Affine(translate_percent={"x": (-translate_max, translate_max),
                                                         "y": (-translate_max, translate_max)},  # translate by +-translate_max
                                      rotate=(-rotate_max, rotate_max),  # rotate by -rotate_max to +rotate_max degrees
                                      shear=(-shear_max, shear_max),  # shear by -shear_max to +shear_max degrees
                                      order=[1],  # bilinear interpolation (fast)
                                      cval=125,  # if mode is constant, use a cval between 0 and 255
                                      mode="reflect",
                                      name="Affine",
                                      )

        spatial_aug = iaa.Sequential([iaa.Fliplr(0.5), iaa.Flipud(0.5), affine_transform])

        sometimes = lambda aug: iaa.Sometimes(0.5, aug)

        other_aug = iaa.SomeOf((1, None),
                               [
                                   iaa.OneOf([
                                       iaa.GaussianBlur((0, 0.4)),  # blur images with a sigma between 0 and 0.4
                                   ]),
                               ])

        elastic_aug = iaa.SomeOf((1, None),
                                 [
                                     iaa.OneOf([
                                         sometimes(iaa.ElasticTransformation(alpha=(50, 60), sigma=16)),  # move pixels locally around (with random strengths)
                                     ]),
                                 ])

        # Defines augmentations to perform on the images and their labels
        augmentators = [spatial_aug, other_aug, elastic_aug]
        # to_deterministic is needed to apply exactly the same spatial transformation to the data and the labels
        spatial_det = augmentators[0].to_deterministic()
        # When only adding noise there is no need to transform the label
        other_aug = augmentators[1]
        elastic_det = augmentators[2].to_deterministic()

        image_aug = spatial_det.augment_image(self.train[index])
        label_aug = spatial_det.augment_image(255*self.train_label[index])

        image_aug = elastic_det.augment_image(image_aug)
        label_aug = elastic_det.augment_image(label_aug)

        img_crop, label_crop = random_crop(image_aug, label_aug, 0., 0.4)
        image_aug = other_aug.augment_image(img_crop)

        label_aug = label_crop

        label_aug = to_categorical(np.argmax(label_aug, axis=-1), num_classes=3)  # only needed if performing elastic transformations
        # Otherwise careful: the labels come back as [255,0,0], not [1,0,0]!
        '''
        augmentator = [spatial_aug, other_aug]
        spatial_det = augmentator[0].to_deterministic()  # freeze the randomness so image and label get the same spatial transform
        other_det = augmentator[1]

        image_aug = spatial_det.augment_image(self.train[index])
        label_aug = spatial_det.augment_image(self.train_label[index])
        img_crop, label_crop = random_crop(image_aug, label_aug, 0.1, 0.4)
        image_aug = other_det.augment_image(img_crop)
        label_aug = to_categorical(np.argmax(label_crop, axis=-1), num_classes = self.no_classes)
        return image_aug, label_aug
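    # Sketch of what a single call returns, assuming the object is constructed:
    #
    #   img, lbl = training.augmentator(0)
    #   # img: (height, width, channels) augmented image
    #   # lbl: (height, width, no_classes) one-hot mask, re-binarized after the
    #   #      spatial transform and crop via argmax + to_categorical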
    def generator_with_augmentations(self):
        "Feeds the data to the training algorithm. Given a batch size, randomly divides the training data into batches and augments each image once, randomly."
        batch_images = np.zeros((self.batch_size, self.height, self.width, self.channels))
        batch_labels = np.zeros((self.batch_size, self.width*self.height, self.no_classes))  # labels are flattened to match the model's (H*W, classes) softmax output
        while True:
            indices = np.asarray(range(0, self.no_images))
            np.random.shuffle(indices)
            for idx in range(0, len(indices), self.batch_size):
                batch_indices = indices[idx:idx+self.batch_size]
                batch_indices.sort()
                batch_indices = batch_indices.tolist()
                for i, idx2 in enumerate(batch_indices):
                    augmented_image, augmented_label = self.augmentator(idx2)  # idx2, not idx: augment each image of the batch, not the same one repeatedly
                    augmented_label = augmented_label.reshape(self.width*self.height, self.no_classes)
                    batch_images[i] = augmented_image
                    batch_labels[i] = augmented_label

                yield (batch_images, batch_labels)

    def compile(self):
        spec = importlib.util.spec_from_file_location("module.name", self.model_path)
        print(self.model_path)
        self.model_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(self.model_module)
        self.model = self.model_module.model(l2_lambda = self.reg, input_shape = (self.height, self.width, self.channels), classes = self.no_classes, kernel_size = self.kernel_size, filter_depth = self.filters)
        self.model.compile(optimizer = RMSprop(lr = self.lrate, decay = 1e-6), loss = self.loss, metrics = ['accuracy'])
        #self.model.compile(optimizer = Adam(lr = self.lrate), loss = self.loss, metrics = ['accuracy'])
        #self.model.compile(optimizer = SGD(lr = self.lrate, momentum = 0.9, nesterov = True), loss = self.loss, metrics = ['accuracy'])
        print("Model loaded and compiled successfully.")

    def fit(self):
        csv_logger = CSVLogger(self.save_folder + "/" + self.name + ".csv")
        #save_path = self.name + "_{epoch:03d}.h5"
        save_path = self.name + ".h5"
        save_path = self.save_folder + "/" + save_path
        earlystop = EarlyStopping(monitor = "val_loss", min_delta = 0, patience = 20, verbose = 1, mode = 'min')
        checkpoint = ModelCheckpoint(save_path, monitor = "val_loss", verbose = 1, save_best_only = True, save_weights_only = False, mode = "auto", period = 1)
        #tb = TensorBoard(log_dir = os.path.join(self.save_folder, 'tboard'), batch_size = 1, write_graph = True, write_images = False)

        self.model.fit_generator(self.generator_with_augmentations(), steps_per_epoch = self.no_images // self.batch_size, epochs = self.no_epochs, callbacks = [csv_logger, checkpoint, earlystop], validation_data = (self.val, self.val_label))
        #self.model.fit_generator(self.generator(), steps_per_epoch = self.no_images // self.batch_size, epochs = self.no_epochs, callbacks = [csv_logger, checkpoint, earlystop], validation_data = (self.val, self.val_label))
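
# Minimal driver sketch, assuming a parameters file like the one
# Training/generate_parameters.py writes (Training/train.py is the full
# version with post-processing):
if __name__ == "__main__":
    import json
    with open("parameters.txt", "r") as fp:
        params = json.load(fp)
    training = TrainingClass(**params)
    training.fit()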
--------------------------------------------------------------------------------
/Training/create_h5.py:
--------------------------------------------------------------------------------
from __future__ import division
import os, sys
import numpy as np
import cv2
import glob
from random import shuffle
from IPython.display import clear_output
import h5py
from sklearn.model_selection import train_test_split


def write_h5(hdf5_name, images_train, labels_train, body_train, file_train, images_test, labels_test, body_test, \
             file_test, images_val, labels_val, body_val, file_val):

    hdf5_file = h5py.File(hdf5_name, mode='w')
    # Attributes
    hdf5_file.attrs['image_size'] = images_train.shape[2]
    hdf5_file.attrs['max_value'] = 1.
    hdf5_file.attrs['min_value'] = 0.
    print(body_train.shape)
    # Datasets
    hdf5_file.create_dataset("train_img", images_train.shape, np.float64)
    hdf5_file.create_dataset("train_label", labels_train.shape, np.int)
    hdf5_file.create_dataset("train_bodypart", body_train.shape, 'S10')
    hdf5_file.create_dataset("train_file", file_train.shape, 'S60')

    hdf5_file.create_dataset("test_img", images_test.shape, np.float64)
    hdf5_file.create_dataset("test_label", labels_test.shape, np.int)
    hdf5_file.create_dataset("test_bodypart", body_test.shape, 'S10')
    hdf5_file.create_dataset("test_file", file_test.shape, 'S60')

    hdf5_file.create_dataset("val_img", images_val.shape, np.float64)
    hdf5_file.create_dataset("val_label", labels_val.shape, np.int)
    hdf5_file.create_dataset("val_bodypart", body_val.shape, 'S10')
    hdf5_file.create_dataset("val_file", file_val.shape, 'S60')

    categories = ['train', 'test', 'val']
    images_split = [images_train, images_test, images_val]
    labels_split = [labels_train, labels_test, labels_val]
    bodys_split = [body_train, body_test, body_val]
    names_split = [file_train, file_test, file_val]
    for j in range(len(images_split)):
        for i in range(images_split[j].shape[0]):
            clear_output(wait=True)
            # Zero mean (note: the min-max rescaling below makes the result
            # independent of this shift, so it is kept only for clarity)
            img = images_split[j][i, ...] - np.mean(images_split[j][i, ...])
            # Normalization to [0, 1] -> perform after augmentation
            img = (img - np.min(img))/(np.max(img) - np.min(img))

            hdf5_file[categories[j] + '_img'][i, ...] = img
            # same for labels
            #labels_simple = label_generate.GenerateOutput(labels_split[j][i,...])
            #labels_onehot = onehot.OneHot(labels_simple)
            #hdf5_file[categories[j] + "_label"][i, ...] = labels_onehot
            hdf5_file[categories[j] + "_label"][i, ...] = labels_split[j][i, ...]
            hdf5_file[categories[j] + "_bodypart"][i] = bodys_split[j][i]
            hdf5_file[categories[j] + "_file"][i] = names_split[j][i]
            #print('Saving image %i/%i in %s path' %(i+1, images_split[j].shape[0], categories[j]))

    hdf5_file.close()
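# Shape sketch for write_h5, with made-up sizes and names: for a 10-image
# training split of 200x200 grayscale images with 3 classes, images_train is
# (10, 200, 200, 1) float, labels_train is (10, 200, 200, 3) one-hot int,
# body_train is a (10,) array of byte strings like b'head' and file_train a
# (10,) array of filenames; the test and val arguments follow the same pattern.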
--------------------------------------------------------------------------------
/Training/generate_parameters.py:
--------------------------------------------------------------------------------
import json
import os
import sys

params_name = "parameters.txt"
save_folder = "Model0"


model = "Architectures/DoubleLinked.py"
name = "DLs200_64"

lrate = 0.0001
reg = 0
batch_size = 5
kernel_size = 5
filter_list = [64, 128, 256, 512, 1024]
#loss = "fancy"
loss = "categorical_crossentropy"
#loss = "jaccard"
#data = "Humans_CT_Phantom_s224.hdf5"
data = "cv_dataset_s200.hdf5"
no_epochs = 5000
duplicate = True


d = {"name": name,
     "model_path": model,
     "data_path": data,
     "save_folder": save_folder,
     "kernel_size": kernel_size,
     "batch_size": batch_size,
     "filters": filter_list,
     "lrate": lrate,
     "reg": reg,
     "loss": loss,
     "no_epochs": no_epochs,
     "duplicate": duplicate}


if (os.path.isfile(params_name)):
    confirm_params = input("Warning: params file exists, continue? (y/n) ")
    if(confirm_params == "y"):
        os.remove(params_name)
    else:
        sys.exit()


with open(params_name, 'w') as fp:
    json.dump(d, fp)


# Grid-search stub for sweeping hyperparameters (it would still need a
# json.dump per combination to actually write the files):
#models = ["DoubleLinked_s224", "DoubleLinked_s200"]
#lrates = [1, 2, 3]
#regs = [1, 2, 3]
#batch_sizes = [5, 20]
#kernel_sizes = [3, 5]
#filter_list = [[64, 128, 256, 512, 1024, 2048], [32, 64, 128, 256, 512, 1024]]
#losses = ["categorical_crossentropy", "fancy"]
#im_size = ["Humans_CT_Phantom_s224.hdf5", "Humans_CT_Phantom_s200.hdf5"]
#
#counter = 0
#for size in im_size:
#    for loss in losses:
#        for filters in filter_list:
#            for lrate in lrates:
#                for reg in regs:
#                    for batch in batch_sizes:
#                        for kernel in kernel_sizes:
#
#                            d["name"] = "Model_" + str(counter)
#                            d["save_folder"] = "Model_" + str(counter)
#                            d["data_path"] = size
#                            d["lrate"] = lrate
#                            d["reg"] = reg
#                            d["batch_size"] = batch
#                            d["filters"] = filters
#                            d["kernel_size"] = kernel
#                            d["loss"] = loss
#
#                            counter += 1
#
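# For the values above, the resulting parameters.txt contains (pretty-printed
# here; json.dump writes it on a single line):
#
#   {"name": "DLs200_64", "model_path": "Architectures/DoubleLinked.py",
#    "data_path": "cv_dataset_s200.hdf5", "save_folder": "Model0",
#    "kernel_size": 5, "batch_size": 5, "filters": [64, 128, 256, 512, 1024],
#    "lrate": 0.0001, "reg": 0, "loss": "categorical_crossentropy",
#    "no_epochs": 5000, "duplicate": true}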
(y/n) ") 28 | if(rm_folder == "y"): 29 | shutil.rmtree(save_folder) 30 | else: 31 | sys.exit() 32 | 33 | os.mkdir(save_folder) 34 | 35 | #tensorboard stuff 36 | #tbdir = os.path.join(save_folder, "tboard") 37 | #os.mkdir(tbdir) 38 | #os.system("killall tensorboard") 39 | #os.system("tensorboard --logdir=" + tbdir + " --port=7007 &") 40 | training = TrainingClass.TrainingClass(**params) 41 | try: 42 | training.fit() 43 | except: 44 | print("\n Dying... \n") 45 | 46 | print("Running post training analysis...\n") 47 | 48 | h5_files = np.sort(glob.glob(os.path.join(params["save_folder"], "*.h5"))) 49 | try: 50 | pp = PostProcessing( h5_files[-1], params["data_path"], device = "gpu") 51 | except: 52 | print("You haven't trained anything?") 53 | continue 54 | 55 | pfile = open(os.path.join(params["save_folder"], "results.txt"), "w") 56 | pfile.write("Overall perfomance: \n") 57 | accuracy_test, trainable_count = pp.evaluate_overall(device = "gpu") 58 | pfile.write("Accuracy: {} \nTrainable parameters: {} \n\n".format(round(accuracy_test,2)*100, trainable_count) ) 59 | 60 | pfile.write("Performance per class:\n") 61 | beam_accuracy, tissue_accuracy, bone_accuracy = pp.evaluate_perclass() 62 | pfile.write(" Open beam: {} \n Soft Tissue: {} \n Bone: {}\n\n".format(round(beam_accuracy,2)*100, round(tissue_accuracy,2)*100, round(bone_accuracy,2)*100)) 63 | 64 | pfile.write("True Positives and False Positives:\n") 65 | tp, fp = pp.tpfp() 66 | pfile.write(" TP: {} \n FP {} \n\n ".format(round(tp,2)*100, round(fp,2)*100)) 67 | 68 | pfile.write("Threshold 90% \n") 69 | thresh90 = pp.thresholding(0.9) 70 | tp90, fp90 = pp.tpfp(thresh90) 71 | pfile.write(" TP90: {} \n FP90 {}\n \n ".format(round(tp90,2)*100, round(fp90,2)*100)) 72 | 73 | pfile.write("Threshold 99% \n") 74 | thresh99 = pp.thresholding(0.99) 75 | tp99, fp99 = pp.tpfp(thresh99) 76 | pfile.write(" TP99: {} \n FP99 {}\n \n ".format(round(tp99,2)*100, round(fp99,2)*100)) 77 | 78 | pfile.close() 79 | 80 | lc_fig, lc_ax = pp.learning_curve(os.path.join(params["save_folder"], params["name"] + ".csv")) 81 | lc_fig.savefig(os.path.join(params["save_folder"], "learning_curve.png")) 82 | 83 | rc_fig, rc_ax = pp.ROC_curve() 84 | rc_fig.savefig( os.path.join(params["save_folder"] , "roc_curve.png") ) 85 | 86 | 87 | -------------------------------------------------------------------------------- /UNet.py: -------------------------------------------------------------------------------- 1 | from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img 2 | from keras.models import Model 3 | from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, UpSampling2D, Dropout, Cropping2D, Reshape 4 | from keras.layers import Activation 5 | from keras.optimizers import * 6 | from keras.callbacks import ModelCheckpoint, LearningRateScheduler 7 | from keras import backend as keras 8 | 9 | def model(input_shape=(64,64,3), classes=3, kernel_size = 3, filter_depth = (64,128,256,512,1024)): 10 | 11 | img_input = Input(shape=input_shape) 12 | 13 | #Encoder 14 | conv1 = Conv2D(filter_depth[0], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(img_input) 15 | conv1 = Conv2D(filter_depth[0], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(conv1) 16 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 17 | 18 | conv2 = Conv2D(filter_depth[1], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(pool1) 19 | conv2 = Conv2D(filter_depth[1], (kernel_size,kernel_size), activation = 'relu', padding = 
--------------------------------------------------------------------------------
/UNet.py:
--------------------------------------------------------------------------------
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, UpSampling2D, Dropout, Cropping2D, Reshape
from keras.layers import Activation
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as keras

def model(input_shape=(64,64,3), classes=3, kernel_size = 3, filter_depth = (64,128,256,512,1024), l2_lambda = 0.):
    # l2_lambda is accepted for compatibility with TrainingClass.compile; it is not used in this architecture

    img_input = Input(shape=input_shape)

    #Encoder
    conv1 = Conv2D(filter_depth[0], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(img_input)
    conv1 = Conv2D(filter_depth[0], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(filter_depth[1], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(pool1)
    conv2 = Conv2D(filter_depth[1], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(filter_depth[2], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(pool2)
    conv3 = Conv2D(filter_depth[2], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(filter_depth[3], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(pool3)
    conv4 = Conv2D(filter_depth[3], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)  # pool the dropout output, not conv4

    conv5 = Conv2D(filter_depth[4], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(pool4)  # bottleneck takes the downsampled features
    conv5 = Conv2D(filter_depth[4], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(conv5)
    drop5 = Dropout(0.5)(conv5)

    #Decoder
    up6 = UpSampling2D(size=(2, 2))(drop5)  # upsample the bottleneck
    conv6 = Conv2D(filter_depth[3], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(up6)
    conv6 = Conv2D(filter_depth[3], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(conv6)

    up7 = UpSampling2D(size=(2, 2))(conv6)
    conv7 = Conv2D(filter_depth[2], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(up7)
    conv7 = Conv2D(filter_depth[2], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(conv7)

    up8 = UpSampling2D(size=(2, 2))(conv7)
    conv8 = Conv2D(filter_depth[1], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(up8)
    conv8 = Conv2D(filter_depth[1], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(conv8)

    up9 = UpSampling2D(size=(2, 2))(conv8)
    copy9 = concatenate([conv1, up9], axis=3)  # skip connection from the first encoder block
    conv9 = Conv2D(filter_depth[0], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(copy9)
    conv9 = Conv2D(filter_depth[0], (kernel_size,kernel_size), activation = 'relu', padding = 'same')(conv9)
    conv9 = Conv2D(2, (kernel_size,kernel_size), activation = 'relu', padding = 'same')(conv9)

    x = Conv2D(classes, (1,1), padding="valid")(conv9)

    x = Reshape((input_shape[0]*input_shape[1],classes))(x)
    x = Activation("softmax")(x)

    model = Model(img_input, x)

    return model
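
# Quick shape check, as a sketch; the spatial size must survive four 2x
# poolings, i.e. be divisible by 16 (224 works, 200 does not):
if __name__ == "__main__":
    m = model(input_shape=(224, 224, 1), classes=3)
    m.summary()  # final layer: (None, 224*224, 3) softmax over flattened pixels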
Activation("relu")(batch2) 25 | pool2 = MaxPooling2D(pool_size=(2, 2))(act2) 26 | #50x50 27 | 28 | conv3 = Conv2D(filter_depth[2], (kernel_size, kernel_size), padding="same")(pool2) 29 | batch3 = BatchNormalization()(conv3) 30 | act3 = Activation("relu")(batch3) 31 | pool3 = MaxPooling2D(pool_size=(2, 2))(act3) 32 | #25x25 33 | 34 | #Flat 35 | conv4 = Conv2D(filter_depth[3], (kernel_size, kernel_size), padding="same")(pool3) 36 | batch4 = BatchNormalization()(conv4) 37 | act4 = Activation("relu")(batch4) 38 | #25x25 39 | 40 | conv5 = Conv2D(filter_depth[3], (kernel_size, kernel_size), padding="same")(act4) 41 | batch5 = BatchNormalization()(conv5) 42 | act5 = Activation("relu")(batch5) 43 | #25x25 44 | 45 | #Up 46 | up6 = UpSampling2D(size=(2, 2))(act5) 47 | conv6 = Conv2D(filter_depth[2], (kernel_size, kernel_size), padding="same")(up6) 48 | batch6 = BatchNormalization()(conv6) 49 | act6 = Activation("relu")(batch6) 50 | concat6 = Concatenate()([act3,act6]) 51 | #50x50 52 | 53 | up7 = UpSampling2D(size=(2, 2))(concat6) 54 | conv7 = Conv2D(filter_depth[1], (kernel_size, kernel_size), padding="same")(up7) 55 | batch7 = BatchNormalization()(conv7) 56 | act7 = Activation("relu")(batch7) 57 | concat7 = Concatenate()([act2,act7]) 58 | #100x100 59 | 60 | #Down 61 | conv8 = Conv2D(filter_depth[1], (kernel_size, kernel_size), padding="same")(concat7) 62 | batch8 = BatchNormalization()(conv8) 63 | act8 = Activation("relu")(batch8) 64 | pool8 = MaxPooling2D(pool_size=(2, 2))(act8) 65 | #50x50 66 | 67 | conv9 = Conv2D(filter_depth[2], (kernel_size, kernel_size), padding="same")(pool8) 68 | batch9 = BatchNormalization()(conv9) 69 | act9 = Activation("relu")(batch9) 70 | pool9 = MaxPooling2D(pool_size=(2, 2))(act9) 71 | 72 | #25x25 73 | 74 | #Flat 75 | conv10 = Conv2D(filter_depth[3], (kernel_size, kernel_size), padding="same")(pool9) 76 | batch10 = BatchNormalization()(conv10) 77 | act10 = Activation("relu")(batch10) 78 | #25x25 79 | 80 | conv11 = Conv2D(filter_depth[3], (kernel_size, kernel_size), padding="same")(act10) 81 | batch11 = BatchNormalization()(conv11) 82 | act11 = Activation("relu")(batch11) 83 | #25x25 84 | 85 | #Encoder 86 | up12 = UpSampling2D(size=(2, 2))(act11) 87 | conv12 = Conv2D(filter_depth[2], (kernel_size, kernel_size), padding="same")(up12) 88 | batch12 = BatchNormalization()(conv12) 89 | act12 = Activation("relu")(batch12) 90 | concat12 = Concatenate()([act9,act12]) 91 | #50x50 92 | 93 | up13 = UpSampling2D(size=(2, 2))(concat12) 94 | conv13 = Conv2D(filter_depth[1], (kernel_size, kernel_size), padding="same")(up13) 95 | batch13 = BatchNormalization()(conv13) 96 | act13 = Activation("relu")(batch13) 97 | concat13 = Concatenate()([act8,act13]) 98 | #100x100 99 | 100 | up14 = UpSampling2D(size=(2, 2))(concat13) 101 | conv14 = Conv2D(filter_depth[0], (kernel_size, kernel_size), padding="same")(up14) 102 | batch14 = BatchNormalization()(conv14) 103 | act14 = Activation("relu")(batch14) 104 | concat14 = Concatenate()([act1,act14]) 105 | #200x200 106 | 107 | conv15 = Conv2D(classes, (1,1), padding="valid")(concat14) 108 | 109 | 110 | reshape15 = Reshape((input_shape[0]*input_shape[1],classes))(conv15) 111 | act15 = Activation("softmax")(reshape15) 112 | 113 | model = Model(img_input, act15) 114 | 115 | return model 116 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==0.7.0
affine==2.2.2
astor==0.7.1
attrs==18.2.0
Click==7.0
click-plugins==1.0.4
cligj==0.5.0
cloudpickle==0.7.0
cycler==0.10.0
dask==1.1.1
decorator==4.3.2
gast==0.2.2
grpcio==1.18.0
h5py==2.9.0
imageio==2.5.0
imgaug==0.2.8
Keras==2.2.4
Keras-Applications==1.0.7
Keras-Preprocessing==1.0.9
kiwisolver==1.0.1
Markdown==3.0.1
matplotlib==3.0.2
networkx==2.2
numpy==1.16.1
opencv-python==4.0.0.21
pandas==0.24.1
Pillow==5.4.1
protobuf==3.6.1
pyparsing==2.3.1
python-dateutil==2.8.0
pytz==2018.9
PyWavelets==1.0.1
pyyaml>=4.2b1
rasterio==1.0.18
scikit-image==0.14.2
scikit-learn==0.20.2
scipy==1.2.1
Shapely==1.6.4.post2
six==1.12.0
sklearn==0.0
snuggs==1.4.2
tensorboard==1.12.2
tensorflow==1.12.0
termcolor==1.1.0
tifffile==2019.1.30
toolz==0.9.0
Werkzeug==0.14.1
--------------------------------------------------------------------------------
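
A closing usage sketch: the checkpoints that TrainingClass.fit writes are plain Keras .h5 models whose output is a (H*W, classes) softmax, so predictions must be reshaped back to image form. The file names below follow the defaults in Training/generate_parameters.py and are assumptions; a model trained with the "fancy" or "jaccard" loss would additionally need custom_objects when loading.

    import h5py
    import numpy as np
    from keras.models import load_model

    model = load_model("Model0/DLs200_64.h5")  # best-val-loss checkpoint from fit()
    with h5py.File("cv_dataset_s200.hdf5", "r") as hf:
        img = hf["test_img"][0]                # (H, W, 1), already normalized to [0, 1]
    probs = model.predict(img[None, ...])[0]   # (H*W, classes) softmax scores
    mask = np.argmax(probs, axis=-1).reshape(img.shape[:2])  # per-pixel class map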