├── unet
│   ├── __init__.py
│   └── unet_modular
│       ├── __init__.py
│       ├── utilities.py
│       ├── progbar.py
│       └── unet_base.py
├── .gitignore
├── images
│   ├── Vis_combined_label_1.png
│   ├── Vis_combined_label_2.png
│   ├── combined_label_train.png
│   ├── Vis_non_combined_label_1.png
│   ├── Vis_non_combined_label_2.png
│   └── non_combined_label_train.png
├── data_loader.py
├── README.md
└── segmentation.py

/unet/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/unet/unet_modular/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
logs/
*.log
.DS_Store
threeDunet_base.py
--------------------------------------------------------------------------------
/images/Vis_combined_label_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohitghosh/chromosome_segementation/HEAD/images/Vis_combined_label_1.png
--------------------------------------------------------------------------------
/images/Vis_combined_label_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohitghosh/chromosome_segementation/HEAD/images/Vis_combined_label_2.png
--------------------------------------------------------------------------------
/images/combined_label_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohitghosh/chromosome_segementation/HEAD/images/combined_label_train.png
--------------------------------------------------------------------------------
/images/Vis_non_combined_label_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohitghosh/chromosome_segementation/HEAD/images/Vis_non_combined_label_1.png
--------------------------------------------------------------------------------
/images/Vis_non_combined_label_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohitghosh/chromosome_segementation/HEAD/images/Vis_non_combined_label_2.png
--------------------------------------------------------------------------------
/images/non_combined_label_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rohitghosh/chromosome_segementation/HEAD/images/non_combined_label_train.png
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
import h5py
import numpy as np
import skimage as sk
# print sk.__version__
from skimage import io
from matplotlib import pyplot as plt
import random

h5f = h5py.File('/home/users/rohitg/LowRes_13434_overlapping_pairs.h5', 'r')
pairs = h5f['dataset_1'][:]
h5f.close()

train_valid_split = int(0.8 * pairs.shape[0])


def weights_assign(train, combine_label):
    # Per-pixel loss weights: background is down-weighted and the overlap
    # class is up-weighted to counter the class imbalance.
    if combine_label:
        weight_list = [0.1, 1, 10]
    else:
        weight_list = [0.1, 1, 1, 10]
    if train == 0:
        return weight_list[0]
    elif train == 1:
        return weight_list[1]
    elif train == 2:
        return weight_list[2]
    elif train == 3:
        return weight_list[3]


def labels_combine(train):
    # Merge the two single-chromosome classes ({1, 2} -> 1) and relabel the
    # overlap class (3 -> 2). Copy first so the caller's array is not
    # mutated in place.
    new_train = train.copy()
    new_train[train == 2] = 1
    new_train[train == 3] = 2
    return new_train


def weights_matrix_gen(train, combine_label):
    func = np.vectorize(weights_assign, otypes=[np.float64])
    weights = func(train, combine_label)
    return weights
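# Illustrative sanity check (ours, not part of the original pipeline): for a
# hypothetical 2x2 label patch covering every class, the helpers above behave
# as follows (numpy display formatting may differ by version):
#
#   >>> y = np.array([[0, 1], [2, 3]])
#   >>> labels_combine(y)                # {1, 2} merge, overlap becomes 2
#   array([[0, 1],
#          [1, 2]])
#   >>> weights_matrix_gen(y, False)     # overlap pixels weighted 10x
#   array([[ 0.1,  1. ],
#          [ 1. , 10. ]])
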
def train_data_loader(batch_size=10, combine_label=False):
    for i in range(0, train_valid_split, batch_size):
        # indices = random.sample(range(i, i+200), batch_size)
        indices = range(i, i + batch_size)
        X_train = pairs[indices, :, :, 0]
        Y_train = pairs[indices, :, :, 1]
        if combine_label:
            Y_train = labels_combine(Y_train)
        if np.amax(Y_train) > 3:
            continue
        weights = weights_matrix_gen(Y_train, combine_label)
        X_train = X_train.reshape((batch_size, 1, X_train.shape[1], X_train.shape[2]))
        Y_train = Y_train.reshape((batch_size, 1, Y_train.shape[1], Y_train.shape[2]))
        yield X_train, Y_train, weights


def valid_data_loader(nb_val_samples=200, batch_size=10, combine_label=False):
    for i in range(train_valid_split, train_valid_split + nb_val_samples, batch_size):
        indices = range(i, i + batch_size)
        X_valid = pairs[indices, :, :, 0]
        Y_valid = pairs[indices, :, :, 1]
        if combine_label:
            Y_valid = labels_combine(Y_valid)
        if np.amax(Y_valid) > 3:
            continue
        X_valid = X_valid.reshape((batch_size, 1, X_valid.shape[1], X_valid.shape[2]))
        Y_valid = Y_valid.reshape((batch_size, 1, Y_valid.shape[1], Y_valid.shape[2]))
        yield X_valid, Y_valid
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# chromosome_segementation

## Introduction

This repo tackles the segmentation of overlapping chromosomes, as described in the problem statement on the [AI ON website](http://ai-on.org/projects/visual-segmentation-of-chromosomal-preparations.html).

It uses [U-Net](https://arxiv.org/abs/1505.04597), a state-of-the-art segmentation network, to segment the overlapping chromosomes, implemented with Lasagne, a Theano-based library.

## Methodology

The data consists of 4 classes: class 0 is the background, classes 1 and 2 are the non-overlapping parts of the two chromosomes, and class 3 is the region common to both overlapping chromosomes.

The performance of the net was measured by the mean dice score, computed as `dice_score = 2*I/(GT + PL)`, where I is the number of non-background pixels predicted correctly, GT is the number of non-background pixels in the ground truth, and PL is the number of non-background pixels in the prediction.

Two methods of training were attempted:
- Treating all the classes independently (param `combine_label = False` in segmentation.py).
- Treating classes 1 and 2 as the same class, i.e. classes 1 and 2 become class 1 and class 3 becomes class 2 (param `combine_label = True` in segmentation.py). The assumption is that the non-overlapping parts of the two chromosomes are not inherently different, so conventional CV methods such as the watershed algorithm can then distinguish the non-overlapping blobs (see the sketch below).
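
A minimal sketch of that watershed post-processing idea (our illustration, not yet part of the pipeline; the TODO below tracks the real integration). It assumes a recent scikit-image, and `split_blobs` and `min_distance` are names we introduce:

```python
import numpy as np
from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage.segmentation import watershed  # skimage.morphology.watershed in older releases

def split_blobs(chromosome_mask, min_distance=10):
    """Split touching chromosome blobs in a boolean prediction mask."""
    distance = ndi.distance_transform_edt(chromosome_mask)
    # One marker per distance peak; min_distance is a tunable separation in pixels.
    coords = peak_local_max(distance, min_distance=min_distance, labels=chromosome_mask)
    markers = np.zeros(distance.shape, dtype=int)
    markers[tuple(coords.T)] = np.arange(1, len(coords) + 1)
    # Flood from the markers over the inverted distance map, restricted to the mask.
    return watershed(-distance, markers, mask=chromosome_mask)
```

Each returned label then marks one chromosome candidate, with the predicted overlap pixels assignable to both.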

## TODO
- Build the watershed post-processing step into the pipeline for the combined_label training
- Optimise the parameters better for the nets in each case
- Build a full pipeline for a test_data generator using the best validated model


## Results

With combined labels, the net reached a dice score as high as **0.97**. Without combined labels, it reached a dice score as high as **0.81**.

#### Predictions for combined_label
![predict_combined_new](/images/Vis_combined_label_2.png)
![predict_combined_new](/images/Vis_combined_label_1.png)
#### Predictions without combining labels
![predict_non_combined_new](/images/Vis_non_combined_label_1.png)
![predict_non_combined_new](/images/Vis_non_combined_label_2.png)
--------------------------------------------------------------------------------
/unet/unet_modular/utilities.py:
--------------------------------------------------------------------------------
import os
import sys
import lasagne


def gen_config(layers, inchannels, startchannels, outputs, act, ltype, optimiser):
    # Build the dict that describes the U-Net: encoder channel counts double
    # at each level, decoder counts halve back down.
    config = {}
    config['act'] = act
    config['layertype'] = ltype
    config['layers'] = {}
    config['layers']['output'] = outputs

    config['layers']['enc'] = []

    tmp = startchannels
    for i in range(layers):
        if i == 0:
            config['layers']['enc'].append(
                {"conv": [tmp, tmp],
                 "shape": [None, inchannels, None, None]}
            )
        else:
            config['layers']['enc'].append(
                {"conv": [tmp, tmp]}
            )
        tmp *= 2
    config['layers']['bottom'] = {
        "conv": [tmp, tmp]}
    tmp = tmp // 2  # integer division: channel counts must stay ints
    config['layers']['dec'] = []
    for i in range(layers):
        config['layers']['dec'].append(
            {"conv": [tmp, tmp, tmp]})
        tmp //= 2
    config['train_params'] = {
        'lr': 0.0001,
        'mom': 0.9
    }
    config['optimiser'] = optimiser
    return config


def get_activation(act):
    if act == 'relu':
        return lasagne.nonlinearities.rectify
    elif act == 'lrelu':
        return lasagne.nonlinearities.leaky_rectify
    elif act == 'vlrelu':
        return lasagne.nonlinearities.very_leaky_rectify
    elif act == 'elu':
        return lasagne.nonlinearities.elu
    elif act == 'sigmoid':
        return lasagne.nonlinearities.sigmoid
    else:
        return None


def get_updates(cost, params, optimiser, lr):
    if optimiser == 'nesterov_momentum':
        return lasagne.updates.nesterov_momentum(cost, params, learning_rate=lr, momentum=0.9)
    if optimiser == 'adagrad':
        return lasagne.updates.adagrad(cost, params, learning_rate=lr, epsilon=1e-06)
    if optimiser == 'rmsprop':
        return lasagne.updates.rmsprop(cost, params, learning_rate=lr, rho=0.9, epsilon=1e-06)
    if optimiser == 'adadelta':
        return lasagne.updates.adadelta(cost, params, learning_rate=lr, rho=0.95, epsilon=1e-06)
    if optimiser == 'adam':
        return lasagne.updates.adam(cost, params, learning_rate=lr, beta1=0.9, beta2=0.999, epsilon=1e-08)
    if optimiser == 'adamax':
        return lasagne.updates.adamax(cost, params, learning_rate=lr, beta1=0.9, beta2=0.999, epsilon=1e-08)
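# Illustrative (ours): tracing gen_config with the defaults used in
# segmentation.py, gen_config(3, 1, 64, 3, 'relu', 'normal', 'adam'),
# yields the following layer plan:
#
#   enc:    [{'conv': [64, 64], 'shape': [None, 1, None, None]},
#            {'conv': [128, 128]},
#            {'conv': [256, 256]}]
#   bottom: {'conv': [512, 512]}
#   dec:    [{'conv': [256, 256, 256]},
#            {'conv': [128, 128, 128]},
#            {'conv': [64, 64, 64]}]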
--------------------------------------------------------------------------------
/unet/unet_modular/progbar.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
import numpy as np
import time
import sys
import six


class Progbar(object):
    def __init__(self, target, width=30, verbose=1, interval=0.01):
        '''
        @param target: total number of steps expected
        @param interval: minimum visual progress update interval (in seconds)
        '''
        self.width = width
        self.target = target
        self.sum_values = {}
        self.unique_values = []
        self.start = time.time()
        self.last_update = 0
        self.interval = interval
        self.total_width = 0
        self.seen_so_far = 0
        self.verbose = verbose

    def update(self, current, values=[], force=False):
        '''
        @param current: index of current step
        @param values: list of tuples (name, value_for_last_step).
            The progress bar will display averages for these values.
        @param force: force visual progress update
        '''
        for k, v in values:
            if k not in self.sum_values:
                self.sum_values[k] = [v * (current - self.seen_so_far), current - self.seen_so_far]
                self.unique_values.append(k)
            else:
                self.sum_values[k][0] += v * (current - self.seen_so_far)
                self.sum_values[k][1] += (current - self.seen_so_far)
        self.seen_so_far = current

        now = time.time()
        if self.verbose == 1:
            if not force and (now - self.last_update) < self.interval:
                return

            prev_total_width = self.total_width
            sys.stdout.write("\b" * prev_total_width)
            sys.stdout.write("\r")

            numdigits = int(np.floor(np.log10(self.target))) + 1
            barstr = '%%%dd/%%%dd [' % (numdigits, numdigits)
            bar = barstr % (current, self.target)
            prog = float(current) / self.target
            prog_width = int(self.width * prog)
            if prog_width > 0:
                bar += ('=' * (prog_width - 1))
                if current < self.target:
                    bar += '>'
                else:
                    bar += '='
            bar += ('.' * (self.width - prog_width))
            bar += ']'
            sys.stdout.write(bar)
            self.total_width = len(bar)

            if current:
                time_per_unit = (now - self.start) / current
            else:
                time_per_unit = 0
            eta = time_per_unit * (self.target - current)
            info = ''
            if current < self.target:
                info += ' - ETA: %ds' % eta
            else:
                info += ' - %ds' % (now - self.start)
            for k in self.unique_values:
                info += ' - %s:' % k
                if type(self.sum_values[k]) is list:
                    avg = self.sum_values[k][0] / max(1, self.sum_values[k][1])
                    if abs(avg) > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                else:
                    info += ' %s' % self.sum_values[k]

            self.total_width += len(info)
            if prev_total_width > self.total_width:
                info += ((prev_total_width - self.total_width) * " ")

            sys.stdout.write(info)
            sys.stdout.flush()

            if current >= self.target:
                sys.stdout.write("\n")

        if self.verbose == 2:
            if current >= self.target:
                info = '%ds' % (now - self.start)
                for k in self.unique_values:
                    info += ' - %s:' % k
                    avg = self.sum_values[k][0] / max(1, self.sum_values[k][1])
                    if avg > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                sys.stdout.write(info + "\n")

        self.last_update = now

    def add(self, n, values=[]):
        self.update(self.seen_so_far + n, values)
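# Illustrative usage (ours), in the style of the training loop in
# segmentation.py; the displayed value is a running average over the epoch:
#
#   progbar = Progbar(target=1000)          # samples per epoch
#   seen = 0
#   for X, Y, w in batches:                 # hypothetical batch iterable
#       seen += X.shape[0]
#       progbar.update(seen, [('train_loss', 0.42)])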
--------------------------------------------------------------------------------
/unet/unet_modular/unet_base.py:
--------------------------------------------------------------------------------
import sys
import os
# import time
import json
# import cv2
import numpy as np
import theano
import theano.tensor as T
from theano.compile.nanguardmode import NanGuardMode
import lasagne
from lasagne.layers import batch_norm, ElemwiseSumLayer, NonlinearityLayer
from lasagne.regularization import regularize_network_params, l2, l1
from lasagne.layers.dnn import Conv2DDNNLayer as Conv2DLayer
from lasagne.layers.dnn import MaxPool2DDNNLayer as MaxPool2DLayer
from unet.unet_modular.utilities import *

lasagne.layers.Conv2DLayer = Conv2DLayer
lasagne.layers.MaxPool2DLayer = MaxPool2DLayer


def softmax(x):
    # Row-wise softmax over axis 1.
    return T.exp(x) / (T.exp(x).sum(1, keepdims=True))


def maxout(x, filters, kernel, maxout):
    x = batch_norm(lasagne.layers.Conv2DLayer(
        x, num_filters=filters * maxout, filter_size=(kernel, kernel),
        nonlinearity=None, pad=1,
        W=initf
    ))
    x = lasagne.layers.FeaturePoolLayer(x, pool_size=maxout)
    return x


def sftmax(x):
    # Flatten (B, C, H, W) predictions so a row-wise softmax gives per-pixel
    # class probabilities.
    sftmax = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
    sftmax = sftmax.dimshuffle((1, 0, 2))
    sftmax = sftmax.reshape((sftmax.shape[0], sftmax.shape[1] * sftmax.shape[2]))
    sftmax = softmax(T.transpose(sftmax))
    return sftmax


def loss(yp, yt, w):
    # Weighted negative log-likelihood over the flattened pixels.
    return -T.mean(T.log(yp)[T.arange(yp.shape[0]), yt] * w)
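# Shape walkthrough (ours), for a batch of B images with C classes and
# H x W pixels: the final conv emits (B, C, H, W); sftmax reshapes to
# (C, B*H*W), transposes to (B*H*W, C) and normalises each row, so loss()
# can pick one probability per pixel using the flattened integer targets
# yt of length B*H*W.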
def normal(ilayer, fmaps, activation, t='enc', ltype='normal'):
    if t == 'enc':
        x = batch_norm(lasagne.layers.Conv2DLayer(
            ilayer, num_filters=fmaps[0], filter_size=(3, 3),
            nonlinearity=None, pad=1,
            W=initf
        ))
    else:
        x = batch_norm(lasagne.layers.Conv2DLayer(
            ilayer, num_filters=fmaps[0], filter_size=(3, 3),
            nonlinearity=activation, pad=1,
            W=initf
        ))
    if ltype == 'normal':
        x = batch_norm(lasagne.layers.Conv2DLayer(
            x, num_filters=fmaps[1], filter_size=(3, 3),
            nonlinearity=activation, pad=1,
            W=initf
        ))
    elif ltype == 'residual':
        x = batch_norm(lasagne.layers.Conv2DLayer(
            x, num_filters=fmaps[1], filter_size=(3, 3),
            nonlinearity=None, pad=1,
            W=initf
        ))
        # 1x1 projection of the block input so the shortcut matches fmaps[1].
        y = lasagne.layers.Conv2DLayer(
            ilayer, num_filters=fmaps[1], filter_size=(1, 1),
            nonlinearity=None, pad='same', W=initf)
        x = ElemwiseSumLayer([x, y])
        x = NonlinearityLayer(x, nonlinearity=activation)
    return x


initf = lasagne.init.GlorotUniform()


def build_network(cfg, input_var):
    encs = cfg['layers']['enc']
    act = cfg['act']
    layertype = cfg['layertype']
    activation = get_activation(act)
    if activation is None:
        return
    inpLayer = encs[0]
    inpShape = tuple(inpLayer['shape'])
    enc_outputs = []
    x = lasagne.layers.InputLayer(
        shape=inpShape,
        input_var=input_var
    )
    x = normal(x, inpLayer['conv'], activation, ltype=layertype)
    enc_outputs.append(x)
    x = lasagne.layers.MaxPool2DLayer(
        x, pool_size=(2, 2)
    )

    for enc in encs[1:]:
        x = normal(x, enc['conv'], activation)
        enc_outputs.append(x)
        x = lasagne.layers.MaxPool2DLayer(
            x, pool_size=(2, 2)
        )

    x = normal(x, cfg['layers']['bottom']['conv'], activation, t='dec', ltype=layertype)
    x = lasagne.layers.DropoutLayer(
        x, p=0.5
    )
    decs = cfg['layers']['dec']

    for i, dec in enumerate(decs):
        # Skip connection from the matching encoder level, deepest first.
        enco = enc_outputs[-(i + 1)]
        x = lasagne.layers.Upscale2DLayer(
            x,
            scale_factor=(2, 2)
        )
        convf = dec['conv'][0]
        x = batch_norm(lasagne.layers.Conv2DLayer(
            x, num_filters=convf, filter_size=(2, 2),
            nonlinearity=None, pad='full',
            W=initf
        ))
        x = lasagne.layers.ConcatLayer(
            [x, enco],
            cropping=['center', None, 'center', 'center']
        )
        x = normal(x, dec['conv'][1:], activation, t='dec', ltype=layertype)

    ox = cfg['layers']['output']
    x = lasagne.layers.Conv2DLayer(
        x, num_filters=ox, filter_size=(1, 1),
        nonlinearity=sftmax,
        W=initf
    )

    return x


def get_functions(cfg):
    input_var = T.tensor4('inputs')
    target_var = T.ivector('targets')
    weights_var = T.vector('weights')
    lr = theano.shared(np.float32(0.0000))

    network = build_network(cfg, input_var)
    prediction = lasagne.layers.get_output(network)
    test_prediction = lasagne.layers.get_output(network, deterministic=True)
    output_shape = lasagne.layers.get_output_shape(network)
    l2_penalty = regularize_network_params(network, l2)
    l1_penalty = regularize_network_params(network, l1)
    cost = loss(prediction, target_var, weights_var) + 5 * 1e-6 * (l1_penalty + l2_penalty)
    params = lasagne.layers.get_all_params(network, trainable=True)
    # print (len(params))

    def save_params(path):
        np.savez(path, params)
        return

    def load_params(path):
        data = np.load(path)
        param_values = [x.get_value() for x in data['arr_0']]
        # print (len(param_values))
        lasagne.layers.set_all_param_values(network, param_values, trainable=True)
        return

    def set_lr(value):
        lr.set_value(value)
        return

    optimiser = cfg['optimiser']
    updates = get_updates(cost, params, optimiser, lr)

    def acc(yp, yt):
        output = T.argmax(yp, axis=1)
        return T.mean(T.eq(output, yt))

    accuracy = acc(prediction, target_var)

    train_fn = theano.function([input_var, target_var, weights_var], cost, updates=updates)
    val_fn = theano.function([input_var, target_var], accuracy)
    train_predict_fn = theano.function([input_var], prediction)
    test_predict_fn = theano.function([input_var], test_prediction)

    return train_fn, test_predict_fn, train_predict_fn, save_params, load_params, output_shape, set_lr
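# Illustrative usage (ours) -- mirrors how segmentation.py wires this up:
#
#   cfg = gen_config(3, 1, 64, 3, 'relu', 'normal', 'adam')
#   (train_fn, test_predict_fn, predict_fn,
#    save_params, load_params, output_shape, set_lr) = get_functions(cfg)
#   set_lr(np.float32(1e-4))   # shared variable is float32, so cast explicitly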
--------------------------------------------------------------------------------
/segmentation.py:
--------------------------------------------------------------------------------
from __future__ import division
from unet.unet_modular.unet_base import *
from unet.unet_modular.utilities import *
from unet.unet_modular.progbar import *
from data_loader import train_data_loader, valid_data_loader
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import math
import scipy as sp
from sklearn import metrics
import logging
# from datetime import datetime
import time
import csv
from pastalog import Log


seed = 123
rng = np.random.RandomState(seed)

combine_label = True
layers = 3
inputch = 1
filters = 64
if combine_label:
    outputs = 3
else:
    outputs = 4
act = 'relu'
ltype = 'normal'
lr = 1e-4
nb_epoch = 40
nb_samples_per_epoch = 1000
nb_val_samples = 100
patience = 20
optimiser = 'adam'
path = '/data/overalap-chromosomes/models/weights'
train_batch_size = 10
valid_batch_size = 10
# predicted_path = '/data2/processed/luna_segmentation/predicted'
save_path = '/data/overalap-chromosomes/demo/'


def sum_indices(arr, index_list):
    # Sum of arr over the given indices (here: the non-background classes).
    return sum([arr[i] for i in index_list])


def vis_detections(X_valid, Y_valid, Y_pred, save_image_path):
    # Side-by-side plot: input scan, ground truth, prediction.
    fig = plt.figure()
    a = fig.add_subplot(1, 3, 1)
    plt.imshow(X_valid)
    a.set_title('Original Scan')
    a = fig.add_subplot(1, 3, 2)
    imgplot = plt.imshow(Y_valid)
    a.set_title('True')
    a = fig.add_subplot(1, 3, 3)
    plt.imshow(Y_pred)
    a.set_title('Predicted')
    plt.savefig(save_image_path)
    plt.close()


# Returns the dice score given the flattened groundtruth and prediction numpy arrays
def get_dice_score(groundtruths, predictions):
    elements1, counts1 = np.unique(groundtruths, return_counts=True)
    elements2, counts2 = np.unique(predictions, return_counts=True)
    # assert (elements2 == elements1).all()
    unique_counts_groundtruths = np.array([0, 0, 0, 0, 0])
    unique_counts_predictions = np.array([0, 0, 0, 0, 0])
    for i in xrange(len(elements1)):
        unique_counts_groundtruths[elements1[i]] = counts1[i]
    for i in xrange(len(elements2)):
        unique_counts_predictions[elements2[i]] = counts2[i]

    # Per-class counts of pixels predicted with the correct class.
    intersection_counts = np.array([0, 0, 0, 0, 0])
    for i in xrange(0, len(groundtruths)):
        if groundtruths[i] == predictions[i]:
            intersection_counts[groundtruths[i]] += 1

    ct_arr = np.array([1, 2, 3, 4])
    dice_score = 2 * sum_indices(intersection_counts, ct_arr) / (
        sum_indices(unique_counts_groundtruths, ct_arr) + sum_indices(unique_counts_predictions, ct_arr))
    return dice_score
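# Worked example (ours): with flattened arrays
#   groundtruths = [0, 1, 1, 3] and predictions = [0, 1, 2, 3],
# the non-background intersection is I = 2 (the matching 1 and 3), the ground
# truth has GT = 3 non-background pixels and the prediction has PL = 3, so
#   dice = 2*2 / (3 + 3) = 0.667.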
def main_training(log_tuple, validation_set=0, threshold=0.5, layers=3, lr=1e-2, nb_epoch=5, nb_samples_per_epoch=100,
                  nb_val_samples=20, patience=20, path='models/weights'):
    best_val_loss = np.inf
    not_done_looping = True
    nb_perf_not_improved = 0
    demo_dict = {}
    log_train, log_valid = log_tuple
    for epoch in range(nb_epoch):
        print ("Epoch: {}/{}".format(epoch + 1, nb_epoch))
        if not_done_looping:
            progbar = Progbar(target=nb_samples_per_epoch)
            seen = 0
            count_train_samples = 0
            decay = math.pow(0.5, epoch / 50)
            lr = lr * decay
            set_lr(lr)
            mean_accuracy = 0
            mean_val_loss = 0
            mean_dice_score = 0
            mean_precision = 0
            mean_recall = 0
            count_valid_samples = 0
            no_of_patches_seen = 0
            mean_train_loss = 0
            mean_train_recall = 0
            mean_train_precision = 0
            mean_train_dice_score = 0

            for X_train, Y_train, weights in train_data_loader(train_batch_size, combine_label):
                if count_train_samples == nb_samples_per_epoch:
                    break
                if seen < nb_samples_per_epoch:
                    log_values = []
                    xs = X_train.shape[2]
                    ys = Y_train.shape[3]
                    # Flatten targets and weights: one entry per pixel for the
                    # flattened softmax output.
                    Y_train = Y_train.reshape((train_batch_size * xs * ys,))
                    weights = weights.reshape((train_batch_size * xs * ys,))
                    train_loss = train_fn(X_train.astype('float32'), Y_train.astype('int32'), weights.astype('float32'))
                    Y_pred = predict_fn(X_train.astype('float32'))
                    Y_pred_class = np.argmax(Y_pred, axis=1)
                    dice_score = get_dice_score(Y_train, Y_pred_class)
                    mean_train_loss += train_loss
                    mean_train_dice_score += dice_score
                    count_train_samples += X_train.shape[0]
                    seen += X_train.shape[0]
                    log_values.append(('train_loss', train_loss))
                    if seen < nb_samples_per_epoch:
                        progbar.update(seen, log_values)
            # Final forced refresh so the bar reflects the full epoch.
            progbar.update(seen, log_values, force=True)
            mean_train_loss = mean_train_loss / (nb_samples_per_epoch / train_batch_size)
            mean_train_dice_score = mean_train_dice_score / (nb_samples_per_epoch / train_batch_size)
            log_train.post('train_loss', mean_train_loss, epoch)
            log_train.post("mean_train_dice_score", mean_train_dice_score, epoch)

            if epoch % 5 == 0:
                validation_start = time.time()
                count_valid_samples = 0
                for X_valid, Y_valid in valid_data_loader(nb_val_samples, valid_batch_size, combine_label):
                    xs = X_valid.shape[2]
                    ys = Y_valid.shape[3]
                    Y_valid = Y_valid.reshape((valid_batch_size * xs * ys,))
                    Y_pred = test_predict_fn(X_valid.astype('float32'))
                    # Unweighted loss on validation (all pixel weights set to 1).
                    val_loss = loss(Y_pred.astype('float32'),
                                    Y_valid.astype('int32'),
                                    np.ones((Y_valid.shape[0],)).astype('float32')).eval()
                    Y_pred_class = np.argmax(Y_pred, axis=1)
                    dice_score = get_dice_score(Y_valid, Y_pred_class)
                    Y_pred = Y_pred_class.reshape(valid_batch_size, 1, xs, ys)
                    Y_valid = Y_valid.reshape(valid_batch_size, 1, xs, ys)
                    save_image_path = os.path.join(save_path, str(epoch), '{}.png'.format(count_valid_samples))
                    if not os.path.exists(os.path.join(save_path, str(epoch))):
                        os.makedirs(os.path.join(save_path, str(epoch)))
                    vis_detections(X_valid[5][0], Y_valid[5][0], Y_pred[5][0], save_image_path)
                    mean_val_loss += val_loss
                    mean_dice_score += dice_score
                    count_valid_samples += 1

                mean_val_loss = mean_val_loss / (nb_val_samples / valid_batch_size)
                mean_dice_score = mean_dice_score / (nb_val_samples / valid_batch_size)

                log_valid.post("val_loss", mean_val_loss, epoch)
                log_valid.post("mean_val_dice_score", mean_dice_score, epoch)

                print ("mean_val_loss: {} , mean_dice_score: {}".format(mean_val_loss, mean_dice_score))
                validation_end = time.time()
                validation_time = validation_end - validation_start
                print ('validation time : %ds' % validation_time)
                if mean_val_loss < best_val_loss:
                    best_val_loss = mean_val_loss
                    best_epoch = epoch
                    nb_perf_not_improved = 0
                    dpath = os.path.join(path, "Unet_vald_set_{}_val_loss_{}_epoch_{}".format(validation_set, best_val_loss, best_epoch))
                    save_params(dpath)
                else:
                    nb_perf_not_improved += 1
                    if nb_perf_not_improved > patience:
                        print ("Exiting training as performance not improving for {} loops".format(patience))
                        not_done_looping = False

    return best_val_loss, best_epoch
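# Note (ours): because this file does `from __future__ import division`,
# epoch/50 above is true division, so decay = 0.5 ** (epoch / 50) is below 1
# from epoch 1 onward, and since lr is reassigned in place the factors
# compound across epochs:
#
#   epoch 1: lr *= 0.5 ** (1/50)   ~ 0.9862
#   epoch 2: lr *= 0.5 ** (2/50)   ~ 0.9727
#   ...
#
# i.e. an accelerating decay schedule rather than a plain halving every
# 50 epochs.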
cfg = gen_config(layers, inputch, filters, outputs, act, ltype, optimiser)
train_fn, test_predict_fn, predict_fn, save_params, load_params, output_shape, set_lr = get_functions(cfg)
print ("Starting Training")
with open('logs/log_training_2DUnet_lr_{}_optimiser_{}.log'.format(lr, optimiser), 'w') as f:
    sys.stdout = f
    print ("------- Checking for lr = {} ---------- ".format(lr))
    log_train = Log('http://localhost:4152', '2DUnet_train')
    log_valid = Log('http://localhost:4152', '2DUnet_valid')
    log_tuple = (log_train, log_valid)

    best_val_loss, best_epoch = main_training(layers=layers, lr=lr, nb_epoch=nb_epoch, nb_samples_per_epoch=nb_samples_per_epoch,
                                              nb_val_samples=nb_val_samples, patience=patience, path=path, log_tuple=log_tuple)
    print ("---------------------------------------------------")
sys.stdout = sys.__stdout__

# model_path = os.path.join(path,"Unet_vald_set_{}_val_loss_{}_epoch_{}.npz".format(validation_set,best_val_loss,best_epoch))
# best_model = load_params(model_path)
# folder = 'subset'+str(validation_set)

# for X,seriesuid in test_data_generator(validation_set=validation_set):
#     for i in range(X.shape[0]):
#         X_test = X[i]
#         X_test = X_test[np.newaxis, np.newaxis,...]
#         xs = X_test.shape[2]
#         ys = X_test.shape[3]
#         Y_pred = test_predict_fn(X_test.astype('float32'))
#         Y_pred = (Y_pred[:,1] > threshold).astype('int')
#         Y_pred = Y_pred.reshape((1,xs,ys))
#         if i == 0:
#             Y_pred_final = Y_pred
#         else:
#             Y_pred_final = np.append(Y_pred_final, Y_pred, axis=0)
#
#     np.save(os.path.join(predicted_path, folder, 'Y_segmentation_{}.npy'.format(seriesuid)), Y_pred_final)
--------------------------------------------------------------------------------