├── img ├── image.jpg ├── image-label.png ├── image-prediction.png └── u-net-architecture.png ├── testUnet.py ├── trainUnet.py ├── LICENSE ├── README.md ├── evaluate.py ├── model.py └── data.py /img/image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohandd/Unet-liverCT/HEAD/img/image.jpg -------------------------------------------------------------------------------- /img/image-label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohandd/Unet-liverCT/HEAD/img/image-label.png -------------------------------------------------------------------------------- /img/image-prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohandd/Unet-liverCT/HEAD/img/image-prediction.png -------------------------------------------------------------------------------- /img/u-net-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohandd/Unet-liverCT/HEAD/img/u-net-architecture.png -------------------------------------------------------------------------------- /testUnet.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | from data import * 3 | 4 | print "testing..........." 
5 | testGene = testGenerator("/ext/xhzhao/Unet-CT/data/liverCT/test") 6 | # model = unet("unet_membrane.hdf5") 7 | # model = unet("unet_liverCT_fulliamge.hdf5") 8 | model = unet("unet_liverCT_0.hdf5") 9 | # model.load_weights("unet_membrane.hdf5") 10 | results = model.predict_generator(testGene, 10, verbose=1) 11 | # results[results > 0.01] = 1 12 | # results[results <= 0.01] = 0 13 | saveResult("/ext/xhzhao/Unet-CT/data/liverCT_/liver1/pred", results) 14 | -------------------------------------------------------------------------------- /trainUnet.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | from data import * 3 | 4 | data_gen_args = dict(rotation_range=0.2, 5 | width_shift_range=0.05, 6 | height_shift_range=0.05, 7 | shear_range=0.05, 8 | zoom_range=0.05, 9 | horizontal_flip=True, 10 | fill_mode='nearest') 11 | myGene = trainGenerator(2,'data/liverCT/train','image','label',data_gen_args,save_to_dir = None) 12 | model = unet() 13 | 14 | model_checkpoint = ModelCheckpoint('unet_liverCT_fulliamge.hdf5', monitor='loss',verbose=1, save_best_only=True) 15 | print "training............" 
16 | model.fit_generator(myGene, steps_per_epoch=3000, epochs=50, callbacks=[model_checkpoint]) 17 | 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 赵晗 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unet-liverCT 2 | Unet network for liver CT image segmentation 3 | This work is based on [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/) 4 | AND [unet](https://github.com/zhixuhao/unet). I develop the whole project to solve the problem of Unet network segmenting liver CT. 
## Overview
### Data
The dataset is [3D-IRCADb (3D Image Reconstruction for Comparison of Algorithm Database)](https://www.ircad.fr/research/computer/). It
contains liver CT images of 20 patients, 15 of whom have tumors.
### Data Augmentation
I use keras.preprocessing.image to augment the data so that there are enough images to train the network. This step is optional.
### Model Architecture
![img/u-net-architecture.png](img/u-net-architecture.png)
### Train and Test
#### Dependencies
+ python == 2.7.15
+ tensorflow-gpu == 1.3.0
+ keras == 2.0.5

I have not tested other versions; feel free to try them.
#### Training
```
python trainUnet.py
```
#### Testing
```
python testUnet.py
```
#### Evaluating
```
python evaluate.py
```
The liver and tumor segmentation results are evaluated with the following metrics: **Dice coefficient**, **RVD (relative volume difference)**, and **VOE (volumetric overlap error)**.
27 | ### Result 28 | ![img/image.jpg](img/image.jpg) 29 | ![img/image-label.png](img/image-label.png) 30 | ![img/image-prediction.png](img/image-prediction.png) 31 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import skimage.io as io 3 | import os 4 | 5 | 6 | def dice_score(path1, path2, num_image): 7 | area_A = 0 8 | area_B = 0 9 | area_C = 0 10 | sum = 0 11 | for i in range(num_image): 12 | img1 = io.imread(os.path.join(path1,"%d.jpg"%i),as_grey = True) 13 | img2 = io.imread(os.path.join(path2,"%d.jpg"%i),as_grey = True) 14 | img2=img2.astype("float64") 15 | for a in range(512): 16 | for b in range(512): 17 | if img1[a, b] > 0.0: 18 | area_A = area_A + 1 19 | print area_A 20 | 21 | for a in range(512): 22 | for b in range(512): 23 | if img2[a, b] > 0.0: 24 | area_B = area_B + 1 25 | print area_B 26 | 27 | for a in range(512): 28 | for b in range(512): 29 | if img1[a, b] > 0.0 and img2[a, b] > 0.0: 30 | area_C = area_C + 1 31 | print area_C 32 | 33 | # if area_A == 0.0 and area_B == 0.0: 34 | # dice = 1 35 | 36 | dice = (2.0 * area_C)/(area_A + area_B + 0.0) 37 | sum = sum + dice 38 | print 'Dice_score{}: {}'.format(i, dice) 39 | 40 | avg = sum / num_image 41 | print 'Dice_avg: {}'.format(avg) 42 | print '--------------------------------------------------------' 43 | 44 | 45 | def voe_err(path1, path2, num_image): 46 | area_C = 0 47 | area_D = 0 48 | sum1 = 0 49 | for i in range(num_image): 50 | img1 = io.imread(os.path.join(path1,"%d.jpg"%i),as_grey = True) 51 | img2 = io.imread(os.path.join(path2,"%d.jpg"%i),as_grey = True) 52 | img2 = np.float64(img2) 53 | for a in range(512): 54 | for b in range(512): 55 | if img1[a, b] > 0.0 and img2[a, b] > 0.0: 56 | area_C = area_C + 1 57 | # print area_C 58 | 59 | for a in range(512): 60 | for b in range(512): 61 | if img1[a, b] > 0.0 or img2[a, b] > 0.0: 62 | area_D 
= area_D + 1 63 | # print area_D 64 | voe = 1.0 - area_C / (area_D + 0.0) 65 | sum1 = sum1 + voe 66 | 67 | print 'Voe_err{}: {}'.format(i, 1.0 - area_C / (area_D + 0.0)) 68 | print 'sum:{}'.format(sum1) 69 | avg = sum1 / num_image 70 | print 'Voe_avg: {}'.format(avg) 71 | print '--------------------------------------------------------' 72 | 73 | 74 | def rvd_err(path1, path2, num_image): 75 | area_A = 0 76 | area_B = 0 77 | sum = 0 78 | for i in range(num_image): 79 | img1 = io.imread(os.path.join(path1, "%d.jpg" % i), as_grey=True) 80 | img2 = io.imread(os.path.join(path2, "%d.jpg" % i), as_grey=True) 81 | img2 = np.float64(img2) 82 | for a in range(512): 83 | for b in range(512): 84 | if img1[a, b] > 0.0: 85 | area_A = area_A + 1 86 | # print area_A 87 | 88 | for a in range(512): 89 | for b in range(512): 90 | if img2[a, b] > 0.0: 91 | area_B = area_B + 1 92 | # print area_B 93 | rvd = (area_B - area_A + 0.0) / (area_A + 0.0) 94 | sum = sum + rvd 95 | 96 | print 'Rvd_err{}: {}'.format(i, (area_B - area_A + 0.0) / (area_A + 0.0)) 97 | avg = sum / num_image 98 | print 'Rvd_avg: {}'.format(avg) 99 | print '--------------------------------------------------------' 100 | 101 | 102 | if __name__ == '__main__': 103 | path1 = '/ext/xhzhao/Unet-CT/evaluation/full_tumor_prediction' 104 | path2 = '/ext/xhzhao/Unet-CT/evaluation/full_tumor_label' 105 | num_image = 30 106 | # dice_score(path1, path2, num_image) 107 | 108 | voe_err(path1, path2, num_image) 109 | 110 | # rvd_err(path1, path2, num_image) 111 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import skimage.io as io 4 | import skimage.transform as trans 5 | import numpy as np 6 | import tensorflow as tf 7 | import keras.backend.tensorflow_backend as KTF 8 | from keras.models import * 9 | from keras.layers import * 10 | from keras.optimizers import * 
11 | from keras.callbacks import ModelCheckpoint, LearningRateScheduler 12 | from keras import backend as keras 13 | # KTF.set_session(tf.Session(config=tf.ConfigProto(device_count={'gpu': 1}))) 14 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 15 | 16 | def unet(pretrained_weights = None, input_size = (512, 512, 1)): 17 | inputs = Input(input_size) 18 | conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs) 19 | conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1) 20 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 21 | conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1) 22 | conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2) 23 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) 24 | conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2) 25 | conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3) 26 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) 27 | conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3) 28 | conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4) 29 | drop4 = Dropout(0.5)(conv4) 30 | pool4 = MaxPooling2D(pool_size=(2, 2))(drop4) 31 | 32 | conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4) 33 | conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5) 34 | drop5 = Dropout(0.5)(conv5) 35 | 36 | up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5)) 37 | merge6 = concatenate([drop4,up6], axis = 3) 38 | conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 
'he_normal')(merge6) 39 | conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6) 40 | 41 | up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6)) 42 | merge7 = concatenate([conv3,up7], axis = 3) 43 | conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7) 44 | conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7) 45 | 46 | up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7)) 47 | merge8 = concatenate([conv2,up8], axis = 3) 48 | conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8) 49 | conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8) 50 | 51 | up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8)) 52 | merge9 = concatenate([conv1,up9], axis = 3) 53 | conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9) 54 | conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9) 55 | conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9) 56 | conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9) 57 | 58 | model = Model(inputs = inputs, outputs = conv10) 59 | 60 | model.compile(optimizer = Adam(lr = 1e-5), loss = 'binary_crossentropy', metrics = ['accuracy']) 61 | 62 | #model.summary() 63 | 64 | if(pretrained_weights): 65 | model.load_weights(pretrained_weights) 66 | 67 | return model 68 | 69 | 70 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from 
__future__ import print_function 2 | from keras.preprocessing.image import ImageDataGenerator 3 | import numpy as np 4 | import os 5 | import glob 6 | import skimage.io as io 7 | import skimage.transform as trans 8 | import itertools 9 | 10 | Sky = [128,128,128] 11 | Building = [128,0,0] 12 | Pole = [192,192,128] 13 | Road = [128,64,128] 14 | Pavement = [60,40,222] 15 | Tree = [128,128,0] 16 | SignSymbol = [192,128,128] 17 | Fence = [64,64,128] 18 | Car = [64,0,128] 19 | Pedestrian = [64,64,0] 20 | Bicyclist = [0,128,192] 21 | Unlabelled = [0,0,0] 22 | 23 | COLOR_DICT = np.array([Sky, Building, Pole, Road, Pavement, 24 | Tree, SignSymbol, Fence, Car, Pedestrian, Bicyclist, Unlabelled]) 25 | 26 | 27 | def adjustData(img,mask,flag_multi_class,num_class): 28 | if(flag_multi_class): 29 | # img = img / 255 30 | img /= 255.0 31 | mask = mask[:,:,:,0] if(len(mask.shape) == 4) else mask[:,:,0] 32 | new_mask = np.zeros(mask.shape + (num_class,)) 33 | for i in range(num_class): 34 | #for one pixel in the image, find the class in mask and convert it into one-hot vector 35 | #index = np.where(mask == i) 36 | #index_mask = (index[0],index[1],index[2],np.zeros(len(index[0]),dtype = np.int64) + i) if (len(mask.shape) == 4) else (index[0],index[1],np.zeros(len(index[0]),dtype = np.int64) + i) 37 | #new_mask[index_mask] = 1 38 | new_mask[mask == i,i] = 1 39 | new_mask = np.reshape(new_mask,(new_mask.shape[0],new_mask.shape[1]*new_mask.shape[2],new_mask.shape[3])) if flag_multi_class else np.reshape(new_mask,(new_mask.shape[0]*new_mask.shape[1],new_mask.shape[2])) 40 | mask = new_mask 41 | elif(np.max(img) > 1): 42 | # img = img / 255 43 | img /= 255.0 44 | # mask = mask / 255 45 | mask /= 255.0 46 | mask[mask > 0.5] = 1 47 | mask[mask <= 0.5] = 0 48 | return (img,mask) 49 | 50 | 51 | 52 | def trainGenerator(batch_size,train_path,image_folder,mask_folder,aug_dict,image_color_mode = "grayscale", 53 | mask_color_mode = "grayscale",image_save_prefix = "image",mask_save_prefix = "mask", 
54 | flag_multi_class = False,num_class = 2,save_to_dir = None,target_size = (512,512),seed = 1): 55 | ''' 56 | can generate image and mask at the same time 57 | use the same seed for image_datagen and mask_datagen to ensure the transformation for image and mask is the same 58 | if you want to visualize the results of generator, set save_to_dir = "your path" 59 | ''' 60 | image_datagen = ImageDataGenerator(**aug_dict) 61 | mask_datagen = ImageDataGenerator(**aug_dict) 62 | image_generator = image_datagen.flow_from_directory( 63 | train_path, 64 | classes = [image_folder], 65 | class_mode = None, 66 | color_mode = image_color_mode, 67 | target_size = target_size, 68 | batch_size = batch_size, 69 | save_to_dir = save_to_dir, 70 | save_prefix = image_save_prefix, 71 | seed = seed) 72 | mask_generator = mask_datagen.flow_from_directory( 73 | train_path, 74 | classes = [mask_folder], 75 | class_mode = None, 76 | color_mode = mask_color_mode, 77 | target_size = target_size, 78 | batch_size = batch_size, 79 | save_to_dir = save_to_dir, 80 | save_prefix = mask_save_prefix, 81 | seed = seed) 82 | # train_generator = zip(image_generator, mask_generator) 83 | train_generator = itertools.izip(image_generator, mask_generator) 84 | for (img,mask) in train_generator: 85 | img,mask = adjustData(img,mask,flag_multi_class,num_class) 86 | yield (img,mask) 87 | 88 | 89 | def testGenerator(test_path,num_image = 223, target_size = (512,512),flag_multi_class = False,as_gray = True): 90 | for i in range(num_image): 91 | img = io.imread(os.path.join(test_path,"%d.jpg"%i),as_gray = True) 92 | 93 | img = img / 255. 94 | # img /= 255. 
95 | 96 | img = trans.resize(img,target_size) 97 | img = np.reshape(img,img.shape+(1,)) if (not flag_multi_class) else img 98 | img = np.reshape(img,(1,)+img.shape) 99 | # img = (img * 255).astype(np.uint8) 100 | yield img 101 | 102 | 103 | def geneTrainNpy(image_path,mask_path,flag_multi_class = False,num_class = 2,image_prefix = "image",mask_prefix = "mask",image_as_gray = True,mask_as_gray = True): 104 | image_name_arr = glob.glob(os.path.join(image_path,"%s*.png"%image_prefix)) 105 | image_arr = [] 106 | mask_arr = [] 107 | for index,item in enumerate(image_name_arr): 108 | img = io.imread(item,as_gray = image_as_gray) 109 | img = np.reshape(img,img.shape + (1,)) if image_as_gray else img 110 | mask = io.imread(item.replace(image_path,mask_path).replace(image_prefix,mask_prefix),as_gray = mask_as_gray) 111 | mask = np.reshape(mask,mask.shape + (1,)) if mask_as_gray else mask 112 | img,mask = adjustData(img,mask,flag_multi_class,num_class) 113 | image_arr.append(img) 114 | mask_arr.append(mask) 115 | image_arr = np.array(image_arr) 116 | mask_arr = np.array(mask_arr) 117 | return image_arr,mask_arr 118 | 119 | 120 | def labelVisualize(num_class,color_dict,img): 121 | img = img[:,:,0] if len(img.shape) == 3 else img 122 | img_out = np.zeros(img.shape + (3,)) 123 | for i in range(num_class): 124 | img_out[img == i,:] = color_dict[i] 125 | return img_out / 255.0 126 | 127 | 128 | 129 | def saveResult(save_path,npyfile,flag_multi_class = False,num_class = 2): 130 | for i,item in enumerate(npyfile): 131 | img = labelVisualize(num_class,COLOR_DICT,item) if flag_multi_class else item[:,:,0] 132 | io.imsave(os.path.join(save_path,"%d.jpg"%i),img) 133 | --------------------------------------------------------------------------------