├── DenseCRF ├── 2D-CRF.py └── 3D-CRF.py ├── README.md ├── data_analysis ├── get _threshold.py ├── get_spacing.py ├── liver_slice_percentage.py └── liver_voxel_percentage.py ├── data_prepare └── get_training_set.py ├── dataset └── dataset.py ├── img ├── loss_curve.png └── segmentation-result.png ├── loss ├── BCE.py ├── Dice.py ├── ELDice.py ├── Hybrid.py ├── Jaccard.py ├── SS.py ├── Tversky.py └── WBCE.py ├── net └── ResUNet.py ├── parameter.py ├── requirements.txt ├── train_ds.py ├── utilities └── calculate_metrics.py └── val.py /DenseCRF/2D-CRF.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 二维全连接条件随机场后处理优化 4 | PS:三维CRF会消耗大量内存,所以如果内存不充裕可以考虑使用二维 5 | """ 6 | 7 | import os 8 | import sys 9 | sys.path.append(os.path.split(sys.path[0])[0]) 10 | 11 | import collections 12 | 13 | import numpy as np 14 | import pandas as pd 15 | from tqdm import tqdm 16 | import SimpleITK as sitk 17 | 18 | import pydensecrf.densecrf as dcrf 19 | from pydensecrf.utils import create_pairwise_bilateral, create_pairwise_gaussian, unary_from_softmax 20 | 21 | from utilities.calculate_metrics import Metirc 22 | 23 | import parameter as para 24 | 25 | 26 | file_name = [] # 文件名称 27 | 28 | # 定义评价指标 29 | liver_score = collections.OrderedDict() 30 | liver_score['dice'] = [] 31 | liver_score['jacard'] = [] 32 | liver_score['voe'] = [] 33 | liver_score['fnr'] = [] 34 | liver_score['fpr'] = [] 35 | liver_score['assd'] = [] 36 | liver_score['rmsd'] = [] 37 | liver_score['msd'] = [] 38 | 39 | # 为了计算dice_global定义的两个变量 40 | dice_intersection = 0.0 41 | dice_union = 0.0 42 | 43 | 44 | for file_index, file in enumerate(os.listdir(para.test_ct_path)): 45 | 46 | print('file index:', file_index, file, '--------------------------------------') 47 | 48 | file_name.append(file) 49 | 50 | ct = sitk.ReadImage(os.path.join(para.test_ct_path, file), sitk.sitkInt16) 51 | ct_array = sitk.GetArrayFromImage(ct) 52 | 53 | pred = sitk.ReadImage(os.path.join(para.pred_path, file.replace('volume', 'pred')), sitk.sitkUInt8) 54 | pred_array = sitk.GetArrayFromImage(pred) 55 | 56 | seg = sitk.ReadImage(os.path.join(para.test_seg_path, file.replace('volume', 'segmentation')), sitk.sitkUInt8) 57 | seg_array = sitk.GetArrayFromImage(seg) 58 | seg_array[seg_array > 0] = 1 59 | 60 | # 灰度截断 61 | ct_array[ct_array > para.upper] = para.upper 62 | ct_array[ct_array < para.lower] = para.lower 63 | 64 | # 切割出预测结果部分,减少crf处理难度 65 | z = np.any(pred_array, axis=(1, 2)) 66 | start_z, end_z = np.where(z)[0][[0, -1]] 67 | 68 | y = np.any(pred_array, axis=(0, 1)) 69 | start_y, end_y = np.where(y)[0][[0, -1]] 70 | 71 | x = np.any(pred_array, axis=(0, 2)) 72 | start_x, end_x = np.where(x)[0][[0, -1]] 73 | 74 | # 扩张 75 | start_z = max(0, start_z - para.z_expand) 76 | start_x = max(0, start_x - para.x_expand) 77 | start_y = max(0, start_y - para.y_expand) 78 | 79 | end_z = min(ct_array.shape[0], end_z + para.z_expand) 80 | end_x = min(ct_array.shape[1], end_x + para.x_expand) 81 | end_y = min(ct_array.shape[2], end_y + para.y_expand) 82 | 83 | new_ct_array = ct_array[start_z: end_z, start_x: end_x, start_y: end_y] 84 | new_pred_array = pred_array[start_z: end_z, start_x: end_x, start_y: end_y] 85 | 86 | print('old shape', ct_array.shape) 87 | print('new shape', new_ct_array.shape) 88 | print('shrink to:', np.prod(new_ct_array.shape) / np.prod(ct_array.shape), '%') 89 | 90 | res = np.zeros_like(new_pred_array) 91 | 92 | for slice_index in tqdm(range(new_ct_array.shape[0])): 93 | 94 | data_array = new_ct_array[slice_index] 
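# Per-slice 2D CRF: the hard 0/1 network prediction below is relaxed into pseudo-softmax
# probabilities (0.9 foreground / 0.1 background) before unary_from_softmax, so the CRF
# is still free to re-label voxels near the liver boundary using the CT intensities.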
95 | seg = new_pred_array[slice_index] 96 | 97 | # 定义条件随机场 98 | n_labels = 2 99 | d = dcrf.DenseCRF(data_array.shape[0] * data_array.shape[1], n_labels) 100 | 101 | # 获取一元势 102 | unary = np.zeros_like(seg, dtype=np.float32) 103 | unary[seg == 0] = 0.1 104 | unary[seg == 1] = 0.9 105 | 106 | U = np.stack((1 - unary, unary), axis=0) 107 | d.setUnaryEnergy(unary_from_softmax(U)) 108 | 109 | # 获取二元势 110 | # This creates the color-independent features and then add them to the CRF 111 | feats = create_pairwise_gaussian(sdims=(para.s1, para.s1), shape=data_array.shape) 112 | d.addPairwiseEnergy(feats, compat=3, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC) 113 | 114 | # This creates the color-dependent features and then add them to the CRF 115 | feats = create_pairwise_bilateral(sdims=(para.s2, para.s2), schan=(para.s3,), img=data_array) 116 | d.addPairwiseEnergy(feats, compat=10, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC) 117 | 118 | # 进行推理 119 | Q = d.inference(para.max_iter) 120 | 121 | # 获取预测标签结果 122 | MAP = np.argmax(np.array(Q), axis=0).reshape(seg.shape) 123 | res[slice_index] = MAP 124 | 125 | liver_seg = np.zeros_like(seg_array, dtype=np.uint8) 126 | liver_seg[start_z: end_z, start_x: end_x, start_y: end_y] = res.astype(np.uint8) 127 | 128 | # 计算分割评价指标 129 | liver_metric = Metirc(seg_array, liver_seg, ct.GetSpacing()) 130 | 131 | liver_score['dice'].append(liver_metric.get_dice_coefficient()[0]) 132 | liver_score['jacard'].append(liver_metric.get_jaccard_index()) 133 | liver_score['voe'].append(liver_metric.get_VOE()) 134 | liver_score['fnr'].append(liver_metric.get_FNR()) 135 | liver_score['fpr'].append(liver_metric.get_FPR()) 136 | liver_score['assd'].append(liver_metric.get_ASSD()) 137 | liver_score['rmsd'].append(liver_metric.get_RMSD()) 138 | liver_score['msd'].append(liver_metric.get_MSD()) 139 | 140 | dice_intersection += liver_metric.get_dice_coefficient()[1] 141 | dice_union += liver_metric.get_dice_coefficient()[2] 142 | 143 | # 将CRF后处理的结果保存为nii数据 144 | pred_seg = sitk.GetImageFromArray(liver_seg) 145 | pred_seg.SetDirection(ct.GetDirection()) 146 | pred_seg.SetOrigin(ct.GetOrigin()) 147 | pred_seg.SetSpacing(ct.GetSpacing()) 148 | 149 | sitk.WriteImage(pred_seg, os.path.join(para.crf_path, file.replace('volume', 'crf'))) 150 | 151 | print('dice:', liver_score['dice'][-1]) 152 | print('--------------------------------------------------------------') 153 | 154 | 155 | # 将评价指标写入到exel中 156 | liver_data = pd.DataFrame(liver_score, index=file_name) 157 | 158 | liver_statistics = pd.DataFrame(index=['mean', 'std', 'min', 'max'], columns=list(liver_data.columns)) 159 | liver_statistics.loc['mean'] = liver_data.mean() 160 | liver_statistics.loc['std'] = liver_data.std() 161 | liver_statistics.loc['min'] = liver_data.min() 162 | liver_statistics.loc['max'] = liver_data.max() 163 | 164 | writer = pd.ExcelWriter('./result-post-processing.xlsx') 165 | liver_data.to_excel(writer, 'liver') 166 | liver_statistics.to_excel(writer, 'liver_statistics') 167 | writer.save() 168 | 169 | # 打印dice global 170 | print('dice global:', dice_intersection / dice_union) 171 | -------------------------------------------------------------------------------- /DenseCRF/3D-CRF.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 三维全连接条件随机场后处理优化 4 | """ 5 | 6 | import os 7 | import sys 8 | sys.path.append(os.path.split(sys.path[0])[0]) 9 | 10 | import collections 11 | 12 | import numpy as np 13 | import pandas as pd 14 | from 
tqdm import tqdm 15 | import SimpleITK as sitk 16 | 17 | import pydensecrf.densecrf as dcrf 18 | from pydensecrf.utils import create_pairwise_bilateral, create_pairwise_gaussian, unary_from_softmax 19 | 20 | from utilities.calculate_metrics import Metirc 21 | 22 | import parameter as para 23 | 24 | 25 | file_name = [] # 文件名称 26 | 27 | # 定义评价指标 28 | liver_score = collections.OrderedDict() 29 | liver_score['dice'] = [] 30 | liver_score['jacard'] = [] 31 | liver_score['voe'] = [] 32 | liver_score['fnr'] = [] 33 | liver_score['fpr'] = [] 34 | liver_score['assd'] = [] 35 | liver_score['rmsd'] = [] 36 | liver_score['msd'] = [] 37 | 38 | # 为了计算dice_global定义的两个变量 39 | dice_intersection = 0.0 40 | dice_union = 0.0 41 | 42 | 43 | for file_index, file in enumerate(os.listdir(para.test_ct_path)): 44 | 45 | print('file index:', file_index, file, '--------------------------------------') 46 | 47 | file_name.append(file) 48 | 49 | ct = sitk.ReadImage(os.path.join(para.test_ct_path, file), sitk.sitkInt16) 50 | ct_array = sitk.GetArrayFromImage(ct) 51 | 52 | pred = sitk.ReadImage(os.path.join(para.pred_path, file.replace('volume', 'pred')), sitk.sitkUInt8) 53 | pred_array = sitk.GetArrayFromImage(pred) 54 | 55 | seg = sitk.ReadImage(os.path.join(para.test_seg_path, file.replace('volume', 'segmentation')), sitk.sitkUInt8) 56 | seg_array = sitk.GetArrayFromImage(seg) 57 | seg_array[seg_array > 0] = 1 58 | 59 | # 灰度截断 60 | ct_array[ct_array > para.upper] = para.upper 61 | ct_array[ct_array < para.lower] = para.lower 62 | 63 | # 切割出预测结果部分,减少crf处理难度 64 | z = np.any(pred_array, axis=(1, 2)) 65 | start_z, end_z = np.where(z)[0][[0, -1]] 66 | 67 | y = np.any(pred_array, axis=(0, 1)) 68 | start_y, end_y = np.where(y)[0][[0, -1]] 69 | 70 | x = np.any(pred_array, axis=(0, 2)) 71 | start_x, end_x = np.where(x)[0][[0, -1]] 72 | 73 | # 扩张 74 | start_z = max(0, start_z - para.z_expand) 75 | start_x = max(0, start_x - para.x_expand) 76 | start_y = max(0, start_y - para.y_expand) 77 | 78 | end_z = min(ct_array.shape[0], end_z + para.z_expand) 79 | end_x = min(ct_array.shape[1], end_x + para.x_expand) 80 | end_y = min(ct_array.shape[2], end_y + para.y_expand) 81 | 82 | new_ct_array = ct_array[start_z: end_z, start_x: end_x, start_y: end_y] 83 | new_pred_array = pred_array[start_z: end_z, start_x: end_x, start_y: end_y] 84 | 85 | print('old shape', ct_array.shape) 86 | print('new shape', new_ct_array.shape) 87 | print('shrink to:', np.prod(new_ct_array.shape) / np.prod(ct_array.shape), '%') 88 | 89 | # 定义条件随机场 90 | n_labels = 2 91 | d = dcrf.DenseCRF(np.prod(new_ct_array.shape), n_labels) 92 | 93 | # 获取一元势 94 | unary = np.zeros_like(new_pred_array, dtype=np.float32) 95 | unary[new_pred_array == 0] = 0.1 96 | unary[new_pred_array == 1] = 0.9 97 | 98 | U = np.stack((1 - unary, unary), axis=0) 99 | d.setUnaryEnergy(unary_from_softmax(U)) 100 | 101 | # 获取二元势 102 | # This creates the color-independent features and then add them to the CRF 103 | feats = create_pairwise_gaussian(sdims=(para.s1, para.s1, para.s1), shape=new_ct_array.shape) 104 | d.addPairwiseEnergy(feats, compat=3, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC) 105 | 106 | # This creates the color-dependent features and then add them to the CRF 107 | feats = create_pairwise_bilateral(sdims=(para.s2, para.s2, para.s2), schan=(para.s3,), img=new_ct_array) 108 | d.addPairwiseEnergy(feats, compat=10, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC) 109 | 110 | # 进行推理 111 | Q, tmp1, tmp2 = d.startInference() 112 | for i in 
tqdm(range(para.max_iter)): 113 | # print("KL-divergence at {}: {}".format(i, d.klDivergence(Q))) 114 | d.stepInference(Q, tmp1, tmp2) 115 | 116 | # 获取预测标签结果 117 | MAP = np.argmax(np.array(Q), axis=0).reshape(new_pred_array.shape) 118 | 119 | liver_seg = np.zeros_like(seg_array, dtype=np.uint8) 120 | liver_seg[start_z: end_z, start_x: end_x, start_y: end_y] = MAP.astype(np.uint8) 121 | 122 | # 计算分割评价指标 123 | liver_metric = Metirc(seg_array, liver_seg, ct.GetSpacing()) 124 | 125 | liver_score['dice'].append(liver_metric.get_dice_coefficient()[0]) 126 | liver_score['jacard'].append(liver_metric.get_jaccard_index()) 127 | liver_score['voe'].append(liver_metric.get_VOE()) 128 | liver_score['fnr'].append(liver_metric.get_FNR()) 129 | liver_score['fpr'].append(liver_metric.get_FPR()) 130 | liver_score['assd'].append(liver_metric.get_ASSD()) 131 | liver_score['rmsd'].append(liver_metric.get_RMSD()) 132 | liver_score['msd'].append(liver_metric.get_MSD()) 133 | 134 | dice_intersection += liver_metric.get_dice_coefficient()[1] 135 | dice_union += liver_metric.get_dice_coefficient()[2] 136 | 137 | # 将CRF后处理的结果保存为nii数据 138 | pred_seg = sitk.GetImageFromArray(liver_seg) 139 | pred_seg.SetDirection(ct.GetDirection()) 140 | pred_seg.SetOrigin(ct.GetOrigin()) 141 | pred_seg.SetSpacing(ct.GetSpacing()) 142 | 143 | sitk.WriteImage(pred_seg, os.path.join(para.crf_path, file.replace('volume', 'crf'))) 144 | 145 | print('dice:', liver_score['dice'][-1]) 146 | print('--------------------------------------------------------------') 147 | 148 | 149 | # 将评价指标写入到exel中 150 | liver_data = pd.DataFrame(liver_score, index=file_name) 151 | 152 | liver_statistics = pd.DataFrame(index=['mean', 'std', 'min', 'max'], columns=list(liver_data.columns)) 153 | liver_statistics.loc['mean'] = liver_data.mean() 154 | liver_statistics.loc['std'] = liver_data.std() 155 | liver_statistics.loc['min'] = liver_data.min() 156 | liver_statistics.loc['max'] = liver_data.max() 157 | 158 | writer = pd.ExcelWriter('./result-post-processing.xlsx') 159 | liver_data.to_excel(writer, 'liver') 160 | liver_statistics.to_excel(writer, 'liver_statistics') 161 | writer.save() 162 | 163 | # 打印dice global 164 | print('dice global:', dice_intersection / dice_union) 165 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Liver segmentation using deep learning 2 | We use a 3D ResUNet to segment the liver in CT images and DenseCRF as post-processing. I have written as many comments as possible and hope you will find this repo useful! 3 | 4 | ## Dataset 5 | The Liver Tumor Segmentation Challenge (LiTS) contains 131 contrast-enhanced CT scans provided by hospitals around the world. The 3DIRCADb dataset is a subset of LiTS covering case numbers 27 to 48. We train our model on the 111 LiTS cases left after removing the 3DIRCADb data and evaluate on the 3DIRCADb dataset. For more details about the dataset, you can check this link: https://competitions.codalab.org/competitions/17094 6 | 7 | ## Experiment 8 | The whole training process ran on three GTX-1080Ti GPUs with a batch size of three. We show some segmentation results of our 3D ResUNet evaluated on the 3DIRCADb dataset below. Data augmentation is not used during training, because we observed in our experiments that augmentations such as random rotation or elastic deformation lead to a decrease in accuracy. 9 | 10 |
<div align=center><img src="./img/segmentation-result.png" alt="segmentation result"></div>
11 | 12 | The loss curve shown below is drawn with visdom. 13 |
loss curve
14 | 15 | ## Usage 16 | I put all the parameters in **parameter.py**, so first set your own dataset paths and other options there. Then run **./data_prepare/get_training_set.py** to build the training set, and run **./train_ds.py** to train the network from scratch. Once the model is well trained, run **val.py** to test it on the test set. If you want to use DenseCRF as post-processing, run **./DenseCRF/3D-CRF.py** if you have enough memory, or **./DenseCRF/2D-CRF.py** otherwise. A minimal checkpoint-loading sketch is shown after the references. 17 | 18 | ## Main references: 19 | 1. Milletari F, Navab N, Ahmadi S A. V-net: Fully convolutional neural networks for volumetric medical image segmentation[C]//2016 Fourth International Conference on 3D Vision (3DV). IEEE, 2016: 565-571. 20 | 2. Wong K C L, Moradi M, Tang H, et al. 3D segmentation with exponential logarithmic loss for highly unbalanced object sizes[C]//International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2018: 612-619. 21 | 3. Yuan Y, Chao M, Lo Y C. Automatic skin lesion segmentation using deep fully convolutional networks with Jaccard distance[J]. IEEE Transactions on Medical Imaging, 2017, 36(9): 1876-1886. 22 | 4. Salehi S S M, Erdogmus D, Gholipour A. Tversky loss function for image segmentation using 3D fully convolutional deep networks[C]//International Workshop on Machine Learning in Medical Imaging. Springer, Cham, 2017: 379-387. 23 | 5. Brosch T, Yoo Y, Tang L Y W, et al. Deep convolutional encoder networks for multiple sclerosis lesion segmentation[C]//International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2015: 3-11. 24 | 6. Xu W, Liu H, Wang X, et al. Liver segmentation in CT based on ResUNet with 3D probabilistic and geometric post process[C]//2019 IEEE 4th International Conference on Signal and Image Processing (ICSIP). IEEE, 2019: 685-689. 25 | 7. Krähenbühl P, Koltun V. Efficient inference in fully connected CRFs with Gaussian edge potentials[C]//Advances in Neural Information Processing Systems. 2011: 109-117.
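For a quick sanity check that a trained checkpoint loads and runs, a minimal sketch along the following lines should work (it only assumes the settings in **parameter.py** and a checkpoint saved by **train_ds.py**; the random input tensor is purely illustrative):

```python
import torch

import parameter as para
from net.ResUNet import ResUNet

# build the network the same way val.py does and load the trained weights
net = torch.nn.DataParallel(ResUNet(training=False)).cuda()
net.load_state_dict(torch.load(para.module_path))
net.eval()

# one dummy patch: para.size (48) consecutive slices at half in-plane resolution (512 * down_scale = 256)
ct_patch = torch.randn(1, 1, para.size, 256, 256).cuda()

with torch.no_grad():
    prob = net(ct_patch)  # sigmoid probability map; map4 upsamples in-plane back to 512 x 512

print(prob.shape)  # expected: torch.Size([1, 1, 48, 512, 512])
```

The full sliding-window inference, thresholding and largest-connected-component post-processing are implemented in **val.py**.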
26 | -------------------------------------------------------------------------------- /data_analysis/get _threshold.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 选取合适的截断阈值 4 | """ 5 | 6 | import os 7 | 8 | from tqdm import tqdm 9 | import SimpleITK as sitk 10 | 11 | import sys 12 | sys.path.append(os.path.split(sys.path[0])[0]) 13 | 14 | import parameter as para 15 | 16 | 17 | num_point = 0.0 18 | num_inlier = 0.0 19 | 20 | for file in tqdm(os.listdir(para.train_ct_path)): 21 | 22 | ct = sitk.ReadImage(os.path.join(para.train_ct_path, file), sitk.sitkInt16) 23 | ct_array = sitk.GetArrayFromImage(ct) 24 | 25 | seg = sitk.ReadImage(os.path.join(para.train_seg_path, file.replace('volume', 'segmentation')), sitk.sitkUInt8) 26 | seg_array = sitk.GetArrayFromImage(seg) 27 | 28 | liver_roi = ct_array[seg_array > 0] 29 | inliers = ((liver_roi < para.upper) * (liver_roi > para.lower)).astype(int).sum() 30 | 31 | print('{:.4}%'.format(inliers / liver_roi.shape[0] * 100)) 32 | print('------------') 33 | 34 | num_point += liver_roi.shape[0] 35 | num_inlier += inliers 36 | 37 | print(num_inlier / num_point) 38 | 39 | # -200 到 200 的阈值对于肝脏:训练集99.49%, 测试集99..0% 40 | # -200 到 200 的阈值对于肿瘤:训练集99.95%, 测试集99.45% 41 | -------------------------------------------------------------------------------- /data_analysis/get_spacing.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 查看数据轴向spacing分布 4 | """ 5 | 6 | import os 7 | import sys 8 | sys.path.append(os.path.split(sys.path[0])[0]) 9 | 10 | from tqdm import tqdm 11 | import SimpleITK as sitk 12 | 13 | import parameter as para 14 | 15 | 16 | spacing_list = [] 17 | 18 | for file in tqdm(os.listdir(para.train_ct_path)): 19 | 20 | ct = sitk.ReadImage(os.path.join(para.train_ct_path, file), sitk.sitkInt16) 21 | temp = ct.GetSpacing()[-1] 22 | 23 | print('-----------------') 24 | print(temp) 25 | 26 | spacing_list.append(temp) 27 | 28 | print('mean:', sum(spacing_list) / len(spacing_list)) 29 | 30 | spacing_list.sort() 31 | print(spacing_list) 32 | 33 | # 训练集中的平均spacing是1.59mm 34 | # 测试集中的数据的spacing都是1mm 35 | -------------------------------------------------------------------------------- /data_analysis/liver_slice_percentage.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 查看肝脏区域slice占据整体slice的比例 4 | """ 5 | 6 | import os 7 | import sys 8 | sys.path.append(os.path.split(sys.path[0])[0]) 9 | 10 | from tqdm import tqdm 11 | import SimpleITK as sitk 12 | 13 | import parameter as para 14 | 15 | total_slice = 0.0 16 | total_liver_slice = 0.0 17 | 18 | for file in tqdm(os.listdir(para.test_seg_path)): 19 | 20 | seg = sitk.ReadImage(os.path.join(para.test_seg_path, file)) 21 | seg_array = sitk.GetArrayFromImage(seg) 22 | 23 | liver_slice = 0 24 | 25 | for slice in seg_array: 26 | if 1 in slice or 2 in slice: 27 | liver_slice += 1 28 | 29 | total_slice += seg_array.shape[0] 30 | total_liver_slice += liver_slice 31 | 32 | print('precent:{:.4f}'.format(liver_slice / seg_array.shape[0] * 100)) 33 | 34 | print(total_liver_slice / total_slice) 35 | 36 | # 训练集包含肝脏的slice整体占比: 30.61% 37 | # 测试集包含肝脏的slice整体占比: 73.46% 38 | -------------------------------------------------------------------------------- /data_analysis/liver_voxel_percentage.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 查看肝脏区域像素点个数占据只包含肝脏区域的slice的百分比 4 | """ 5 | 6 | import os 7 | import sys 8 | 
sys.path.append(os.path.split(sys.path[0])[0]) 9 | 10 | from tqdm import tqdm 11 | import SimpleITK as sitk 12 | 13 | import parameter as para 14 | 15 | total_point = 0.0 16 | total_liver_point = 0.0 17 | 18 | for seg_file in tqdm(os.listdir(para.train_seg_path)): 19 | 20 | seg = sitk.ReadImage(os.path.join(para.train_seg_path, seg_file), sitk.sitkUInt8) 21 | seg_array = sitk.GetArrayFromImage(seg) 22 | 23 | liver_slice = 0 24 | 25 | for slice in seg_array: 26 | if 1 in slice or 2 in slice: 27 | liver_slice += 1 28 | 29 | liver_point = (seg_array > 0).astype(int).sum() 30 | 31 | print('precent:{:.4f}'.format(liver_point / (liver_slice * 512 * 512) * 100)) 32 | 33 | total_point += (liver_slice * 512 * 512) 34 | total_liver_point += liver_point 35 | 36 | print(total_liver_point / total_point) 37 | 38 | # 训练集 6.99% 39 | # 测试集 6.97% 40 | -------------------------------------------------------------------------------- /data_prepare/get_training_set.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 获取可用于训练网络的训练数据集 4 | 需要四十分钟左右,产生的训练数据大小3G左右 5 | """ 6 | 7 | import os 8 | import sys 9 | sys.path.append(os.path.split(sys.path[0])[0]) 10 | import shutil 11 | from time import time 12 | 13 | import numpy as np 14 | from tqdm import tqdm 15 | import SimpleITK as sitk 16 | import scipy.ndimage as ndimage 17 | 18 | import parameter as para 19 | 20 | 21 | if os.path.exists(para.training_set_path): 22 | shutil.rmtree(para.training_set_path) 23 | 24 | new_ct_path = os.path.join(para.training_set_path, 'ct') 25 | new_seg_dir = os.path.join(para.training_set_path, 'seg') 26 | 27 | os.mkdir(para.training_set_path) 28 | os.mkdir(new_ct_path) 29 | os.mkdir(new_seg_dir) 30 | 31 | start = time() 32 | for file in tqdm(os.listdir(para.train_ct_path)): 33 | 34 | # 将CT和金标准入读内存 35 | ct = sitk.ReadImage(os.path.join(para.train_ct_path, file), sitk.sitkInt16) 36 | ct_array = sitk.GetArrayFromImage(ct) 37 | 38 | seg = sitk.ReadImage(os.path.join(para.train_seg_path, file.replace('volume', 'segmentation')), sitk.sitkUInt8) 39 | seg_array = sitk.GetArrayFromImage(seg) 40 | 41 | # 将金标准中肝脏和肝肿瘤的标签融合为一个 42 | seg_array[seg_array > 0] = 1 43 | 44 | # 将灰度值在阈值之外的截断掉 45 | ct_array[ct_array > para.upper] = para.upper 46 | ct_array[ct_array < para.lower] = para.lower 47 | 48 | # 对CT数据在横断面上进行降采样,并进行重采样,将所有数据的z轴的spacing调整到1mm 49 | ct_array = ndimage.zoom(ct_array, (ct.GetSpacing()[-1] / para.slice_thickness, para.down_scale, para.down_scale), order=3) 50 | seg_array = ndimage.zoom(seg_array, (ct.GetSpacing()[-1] / para.slice_thickness, 1, 1), order=0) 51 | 52 | # 找到肝脏区域开始和结束的slice,并各向外扩张slice 53 | z = np.any(seg_array, axis=(1, 2)) 54 | start_slice, end_slice = np.where(z)[0][[0, -1]] 55 | 56 | # 两个方向上各扩张slice 57 | start_slice = max(0, start_slice - para.expand_slice) 58 | end_slice = min(seg_array.shape[0] - 1, end_slice + para.expand_slice) 59 | 60 | # 如果这时候剩下的slice数量不足size,直接放弃该数据,这样的数据很少,所以不用担心 61 | if end_slice - start_slice + 1 < para.size: 62 | print('!!!!!!!!!!!!!!!!') 63 | print(file, 'have too little slice', ct_array.shape[0]) 64 | print('!!!!!!!!!!!!!!!!') 65 | continue 66 | 67 | ct_array = ct_array[start_slice:end_slice + 1, :, :] 68 | seg_array = seg_array[start_slice:end_slice + 1, :, :] 69 | 70 | # 最终将数据保存为nii 71 | new_ct = sitk.GetImageFromArray(ct_array) 72 | 73 | new_ct.SetDirection(ct.GetDirection()) 74 | new_ct.SetOrigin(ct.GetOrigin()) 75 | new_ct.SetSpacing((ct.GetSpacing()[0] * int(1 / para.down_scale), ct.GetSpacing()[1] * int(1 / para.down_scale), 
para.slice_thickness)) 76 | 77 | new_seg = sitk.GetImageFromArray(seg_array) 78 | 79 | new_seg.SetDirection(ct.GetDirection()) 80 | new_seg.SetOrigin(ct.GetOrigin()) 81 | new_seg.SetSpacing((ct.GetSpacing()[0], ct.GetSpacing()[1], para.slice_thickness)) 82 | 83 | sitk.WriteImage(new_ct, os.path.join(new_ct_path, file)) 84 | sitk.WriteImage(new_seg, os.path.join(new_seg_dir, file.replace('volume', 'segmentation').replace('.nii', '.nii.gz'))) 85 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | torch中的Dataset定义脚本 4 | """ 5 | 6 | import os 7 | import sys 8 | sys.path.append(os.path.split(sys.path[0])[0]) 9 | 10 | import random 11 | 12 | import numpy as np 13 | import SimpleITK as sitk 14 | 15 | import torch 16 | from torch.utils.data import Dataset as dataset 17 | 18 | import parameter as para 19 | 20 | 21 | class Dataset(dataset): 22 | def __init__(self, ct_dir, seg_dir): 23 | 24 | self.ct_list = os.listdir(ct_dir) 25 | self.seg_list = list(map(lambda x: x.replace('volume', 'segmentation').replace('.nii', '.nii.gz'), self.ct_list)) 26 | 27 | self.ct_list = list(map(lambda x: os.path.join(ct_dir, x), self.ct_list)) 28 | self.seg_list = list(map(lambda x: os.path.join(seg_dir, x), self.seg_list)) 29 | 30 | def __getitem__(self, index): 31 | 32 | ct_path = self.ct_list[index] 33 | seg_path = self.seg_list[index] 34 | 35 | # 将CT和金标准读入到内存中 36 | ct = sitk.ReadImage(ct_path, sitk.sitkInt16) 37 | seg = sitk.ReadImage(seg_path, sitk.sitkUInt8) 38 | 39 | ct_array = sitk.GetArrayFromImage(ct) 40 | seg_array = sitk.GetArrayFromImage(seg) 41 | 42 | # min max 归一化 43 | ct_array = ct_array.astype(np.float32) 44 | ct_array = ct_array / 200 45 | 46 | # 在slice平面内随机选取48张slice 47 | start_slice = random.randint(0, ct_array.shape[0] - para.size) 48 | end_slice = start_slice + para.size - 1 49 | 50 | ct_array = ct_array[start_slice:end_slice + 1, :, :] 51 | seg_array = seg_array[start_slice:end_slice + 1, :, :] 52 | 53 | # 处理完毕,将array转换为tensor 54 | ct_array = torch.FloatTensor(ct_array).unsqueeze(0) 55 | seg_array = torch.FloatTensor(seg_array) 56 | 57 | return ct_array, seg_array 58 | 59 | def __len__(self): 60 | 61 | return len(self.ct_list) 62 | -------------------------------------------------------------------------------- /img/loss_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assassint2017/MICCAI-LITS2017/7419c945557ea540b4f30b8284761270bc07c805/img/loss_curve.png -------------------------------------------------------------------------------- /img/segmentation-result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/assassint2017/MICCAI-LITS2017/7419c945557ea540b4f30b8284761270bc07c805/img/segmentation-result.png -------------------------------------------------------------------------------- /loss/BCE.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 二值交叉熵损失函数 4 | """ 5 | 6 | import torch.nn as nn 7 | 8 | 9 | class BCELoss(nn.Module): 10 | 11 | def __init__(self): 12 | super().__init__() 13 | 14 | self.bce_loss = nn.BCELoss() 15 | 16 | def forward(self, pred, target): 17 | 18 | pred = pred.squeeze(dim=1) 19 | 20 | return self.bce_loss(pred, target) 21 | -------------------------------------------------------------------------------- /loss/Dice.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | 3 | Dice loss 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class DiceLoss(nn.Module): 11 | 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, pred, target): 16 | 17 | pred = pred.squeeze(dim=1) 18 | 19 | smooth = 1 20 | 21 | # dice系数的定义 22 | dice = 2 * (pred * target).sum(dim=1).sum(dim=1).sum(dim=1) / (pred.pow(2).sum(dim=1).sum(dim=1).sum(dim=1) + 23 | target.pow(2).sum(dim=1).sum(dim=1).sum(dim=1) + smooth) 24 | 25 | # 返回的是dice距离 26 | return torch.clamp((1 - dice).mean(), 0, 1) 27 | -------------------------------------------------------------------------------- /loss/ELDice.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Exponential Logarithmic Dice loss 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class ELDiceLoss(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def forward(self, pred, target): 15 | 16 | pred = pred.squeeze(dim=1) 17 | 18 | smooth = 1 19 | 20 | # dice系数的定义 21 | dice = 2 * (pred * target).sum(dim=1).sum(dim=1).sum(dim=1) / (pred.pow(2).sum(dim=1).sum(dim=1).sum(dim=1) + 22 | target.pow(2).sum(dim=1).sum(dim=1).sum(dim=1) + smooth) 23 | 24 | # 返回的是dice距离 25 | return torch.clamp((torch.pow(-torch.log(dice + 1e-5), 0.3)).mean(), 0, 2) -------------------------------------------------------------------------------- /loss/Hybrid.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Dice loss + BCE loss 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class HybridLoss(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | self.bce_loss = nn.BCELoss() 15 | self.bce_weight = 1.0 16 | 17 | def forward(self, pred, target): 18 | 19 | pred = pred.squeeze(dim=1) 20 | 21 | smooth = 1 22 | 23 | # dice系数的定义 24 | dice = 2 * (pred * target).sum(dim=1).sum(dim=1).sum(dim=1) / (pred.pow(2).sum(dim=1).sum(dim=1).sum(dim=1) + 25 | target.pow(2).sum(dim=1).sum(dim=1).sum(dim=1) + smooth) 26 | 27 | # 返回的是dice距离 + 二值化交叉熵损失 28 | return torch.clamp((1 - dice).mean(), 0, 1) + self.bce_loss(pred, target) * self.bce_weight 29 | -------------------------------------------------------------------------------- /loss/Jaccard.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Jaccard loss 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class JaccardLoss(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def forward(self, pred, target): 15 | 16 | pred = pred.squeeze(dim=1) 17 | 18 | smooth = 1 19 | 20 | # jaccard系数的定义 21 | dice = (pred * target).sum(dim=1).sum(dim=1).sum(dim=1) / (pred.pow(2).sum(dim=1).sum(dim=1).sum(dim=1) + 22 | target.pow(2).sum(dim=1).sum(dim=1).sum(dim=1) - (pred * target).sum(dim=1).sum(dim=1).sum(dim=1) + smooth) 23 | 24 | # 返回的是jaccard距离 25 | return torch.clamp((1 - dice).mean(), 0, 1) 26 | -------------------------------------------------------------------------------- /loss/SS.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Sensitivity Specificity loss 4 | """ 5 | 6 | import torch.nn as nn 7 | 8 | 9 | class SSLoss(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, pred, target): 14 | 15 | pred = pred.squeeze(dim=1) 16 | 17 | smooth = 1 18 | 19 | # jaccard系数的定义 20 | s1 = ((pred - 
target).pow(2) * target).sum(dim=1).sum(dim=1).sum(dim=1) / (smooth + target.sum(dim=1).sum(dim=1).sum(dim=1)) 21 | 22 | s2 = ((pred - target).pow(2) * (1 - target)).sum(dim=1).sum(dim=1).sum(dim=1) / (smooth + (1 - target).sum(dim=1).sum(dim=1).sum(dim=1)) 23 | 24 | # 返回的是jaccard距离 25 | return (0.05 * s1 + 0.95 * s2).mean() 26 | -------------------------------------------------------------------------------- /loss/Tversky.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Tversky loss 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class TverskyLoss(nn.Module): 11 | 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, pred, target): 16 | 17 | pred = pred.squeeze(dim=1) 18 | 19 | smooth = 1 20 | 21 | # dice系数的定义 22 | dice = (pred * target).sum(dim=1).sum(dim=1).sum(dim=1) / ((pred * target).sum(dim=1).sum(dim=1).sum(dim=1)+ 23 | 0.3 * (pred * (1 - target)).sum(dim=1).sum(dim=1).sum(dim=1) + 0.7 * ((1 - pred) * target).sum(dim=1).sum(dim=1).sum(dim=1) + smooth) 24 | 25 | # 返回的是dice距离 26 | return torch.clamp((1 - dice).mean(), 0, 2) 27 | -------------------------------------------------------------------------------- /loss/WBCE.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 加权交叉熵损失函数 4 | 统计了一下训练集下的正负样本的比例,接近20:1 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class WCELoss(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | weight = torch.FloatTensor([0.05, 1]).cuda() 15 | self.ce_loss = nn.CrossEntropyLoss(weight) 16 | 17 | def forward(self, pred, target): 18 | pred_ = torch.ones_like(pred) - pred 19 | pred = torch.cat((pred_, pred), dim=1) 20 | 21 | target = torch.long() 22 | 23 | return self.ce_loss(pred, target) 24 | -------------------------------------------------------------------------------- /net/ResUNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 网络定义脚本 4 | """ 5 | 6 | import os 7 | import sys 8 | sys.path.append(os.path.split(sys.path[0])[0]) 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | import parameter as para 15 | 16 | 17 | class ResUNet(nn.Module): 18 | """ 19 | 20 | 共9498260个可训练的参数, 接近九百五十万 21 | """ 22 | def __init__(self, training): 23 | super().__init__() 24 | 25 | self.training = training 26 | 27 | self.encoder_stage1 = nn.Sequential( 28 | nn.Conv3d(1, 16, 3, 1, padding=1), 29 | nn.PReLU(16), 30 | 31 | nn.Conv3d(16, 16, 3, 1, padding=1), 32 | nn.PReLU(16), 33 | ) 34 | 35 | self.encoder_stage2 = nn.Sequential( 36 | nn.Conv3d(32, 32, 3, 1, padding=1), 37 | nn.PReLU(32), 38 | 39 | nn.Conv3d(32, 32, 3, 1, padding=1), 40 | nn.PReLU(32), 41 | 42 | nn.Conv3d(32, 32, 3, 1, padding=1), 43 | nn.PReLU(32), 44 | ) 45 | 46 | self.encoder_stage3 = nn.Sequential( 47 | nn.Conv3d(64, 64, 3, 1, padding=1), 48 | nn.PReLU(64), 49 | 50 | nn.Conv3d(64, 64, 3, 1, padding=2, dilation=2), 51 | nn.PReLU(64), 52 | 53 | nn.Conv3d(64, 64, 3, 1, padding=4, dilation=4), 54 | nn.PReLU(64), 55 | ) 56 | 57 | self.encoder_stage4 = nn.Sequential( 58 | nn.Conv3d(128, 128, 3, 1, padding=3, dilation=3), 59 | nn.PReLU(128), 60 | 61 | nn.Conv3d(128, 128, 3, 1, padding=4, dilation=4), 62 | nn.PReLU(128), 63 | 64 | nn.Conv3d(128, 128, 3, 1, padding=5, dilation=5), 65 | nn.PReLU(128), 66 | ) 67 | 68 | self.decoder_stage1 = nn.Sequential( 69 | nn.Conv3d(128, 256, 3, 1, padding=1), 70 | nn.PReLU(256), 71 | 72 | nn.Conv3d(256, 256, 3, 
1, padding=1), 73 | nn.PReLU(256), 74 | 75 | nn.Conv3d(256, 256, 3, 1, padding=1), 76 | nn.PReLU(256), 77 | ) 78 | 79 | self.decoder_stage2 = nn.Sequential( 80 | nn.Conv3d(128 + 64, 128, 3, 1, padding=1), 81 | nn.PReLU(128), 82 | 83 | nn.Conv3d(128, 128, 3, 1, padding=1), 84 | nn.PReLU(128), 85 | 86 | nn.Conv3d(128, 128, 3, 1, padding=1), 87 | nn.PReLU(128), 88 | ) 89 | 90 | self.decoder_stage3 = nn.Sequential( 91 | nn.Conv3d(64 + 32, 64, 3, 1, padding=1), 92 | nn.PReLU(64), 93 | 94 | nn.Conv3d(64, 64, 3, 1, padding=1), 95 | nn.PReLU(64), 96 | 97 | nn.Conv3d(64, 64, 3, 1, padding=1), 98 | nn.PReLU(64), 99 | ) 100 | 101 | self.decoder_stage4 = nn.Sequential( 102 | nn.Conv3d(32 + 16, 32, 3, 1, padding=1), 103 | nn.PReLU(32), 104 | 105 | nn.Conv3d(32, 32, 3, 1, padding=1), 106 | nn.PReLU(32), 107 | ) 108 | 109 | self.down_conv1 = nn.Sequential( 110 | nn.Conv3d(16, 32, 2, 2), 111 | nn.PReLU(32) 112 | ) 113 | 114 | self.down_conv2 = nn.Sequential( 115 | nn.Conv3d(32, 64, 2, 2), 116 | nn.PReLU(64) 117 | ) 118 | 119 | self.down_conv3 = nn.Sequential( 120 | nn.Conv3d(64, 128, 2, 2), 121 | nn.PReLU(128) 122 | ) 123 | 124 | self.down_conv4 = nn.Sequential( 125 | nn.Conv3d(128, 256, 3, 1, padding=1), 126 | nn.PReLU(256) 127 | ) 128 | 129 | self.up_conv2 = nn.Sequential( 130 | nn.ConvTranspose3d(256, 128, 2, 2), 131 | nn.PReLU(128) 132 | ) 133 | 134 | self.up_conv3 = nn.Sequential( 135 | nn.ConvTranspose3d(128, 64, 2, 2), 136 | nn.PReLU(64) 137 | ) 138 | 139 | self.up_conv4 = nn.Sequential( 140 | nn.ConvTranspose3d(64, 32, 2, 2), 141 | nn.PReLU(32) 142 | ) 143 | 144 | # 最后大尺度下的映射(256*256),下面的尺度依次递减 145 | self.map4 = nn.Sequential( 146 | nn.Conv3d(32, 1, 1, 1), 147 | nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear'), 148 | nn.Sigmoid() 149 | ) 150 | 151 | # 128*128 尺度下的映射 152 | self.map3 = nn.Sequential( 153 | nn.Conv3d(64, 1, 1, 1), 154 | nn.Upsample(scale_factor=(2, 4, 4), mode='trilinear'), 155 | nn.Sigmoid() 156 | ) 157 | 158 | # 64*64 尺度下的映射 159 | self.map2 = nn.Sequential( 160 | nn.Conv3d(128, 1, 1, 1), 161 | nn.Upsample(scale_factor=(4, 8, 8), mode='trilinear'), 162 | nn.Sigmoid() 163 | ) 164 | 165 | # 32*32 尺度下的映射 166 | self.map1 = nn.Sequential( 167 | nn.Conv3d(256, 1, 1, 1), 168 | nn.Upsample(scale_factor=(8, 16, 16), mode='trilinear'), 169 | nn.Sigmoid() 170 | ) 171 | 172 | def forward(self, inputs): 173 | 174 | long_range1 = self.encoder_stage1(inputs) + inputs 175 | 176 | short_range1 = self.down_conv1(long_range1) 177 | 178 | long_range2 = self.encoder_stage2(short_range1) + short_range1 179 | long_range2 = F.dropout(long_range2, para.drop_rate, self.training) 180 | 181 | short_range2 = self.down_conv2(long_range2) 182 | 183 | long_range3 = self.encoder_stage3(short_range2) + short_range2 184 | long_range3 = F.dropout(long_range3, para.drop_rate, self.training) 185 | 186 | short_range3 = self.down_conv3(long_range3) 187 | 188 | long_range4 = self.encoder_stage4(short_range3) + short_range3 189 | long_range4 = F.dropout(long_range4, para.drop_rate, self.training) 190 | 191 | short_range4 = self.down_conv4(long_range4) 192 | 193 | outputs = self.decoder_stage1(long_range4) + short_range4 194 | outputs = F.dropout(outputs, para.drop_rate, self.training) 195 | 196 | output1 = self.map1(outputs) 197 | 198 | short_range6 = self.up_conv2(outputs) 199 | 200 | outputs = self.decoder_stage2(torch.cat([short_range6, long_range3], dim=1)) + short_range6 201 | outputs = F.dropout(outputs, 0.3, self.training) 202 | 203 | output2 = self.map2(outputs) 204 | 205 | short_range7 = self.up_conv3(outputs) 
206 | 207 | outputs = self.decoder_stage3(torch.cat([short_range7, long_range2], dim=1)) + short_range7 208 | outputs = F.dropout(outputs, 0.3, self.training) 209 | 210 | output3 = self.map3(outputs) 211 | 212 | short_range8 = self.up_conv4(outputs) 213 | 214 | outputs = self.decoder_stage4(torch.cat([short_range8, long_range1], dim=1)) + short_range8 215 | 216 | output4 = self.map4(outputs) 217 | 218 | if self.training is True: 219 | return output1, output2, output3, output4 220 | else: 221 | return output4 222 | 223 | 224 | def init(module): 225 | if isinstance(module, nn.Conv3d) or isinstance(module, nn.ConvTranspose3d): 226 | nn.init.kaiming_normal_(module.weight.data, 0.25) 227 | nn.init.constant_(module.bias.data, 0) 228 | 229 | 230 | net = ResUNet(training=True) 231 | net.apply(init) 232 | 233 | # 计算网络参数 234 | print('net total parameters:', sum(param.numel() for param in net.parameters())) 235 | -------------------------------------------------------------------------------- /parameter.py: -------------------------------------------------------------------------------- 1 | # -----------------------路径相关参数--------------------------------------- 2 | 3 | train_ct_path = '/home/zcy/Desktop/dataset/MICCAI-LITS-2017/train/CT/' # 原始训练集CT数据路径 4 | 5 | train_seg_path = '/home/zcy/Desktop/dataset/MICCAI-LITS-2017/train/seg/' # 原始训练集标注数据路径 6 | 7 | test_ct_path = '/home/zcy/Desktop/dataset/MICCAI-LITS-2017/test/CT/' # 原始测试集CT数据路径 8 | 9 | test_seg_path = '/home/zcy/Desktop/dataset/MICCAI-LITS-2017/test/seg/' # 原始测试集标注数据路径 10 | 11 | training_set_path = './train/' # 用来训练网络的数据保存地址 12 | 13 | pred_path = '/home/zcy/Desktop/dataset/MICCAI-LITS-2017/test/liver_pred' # 网络预测结果保存路径 14 | 15 | crf_path = '/home/zcy/Desktop/dataset/MICCAI-LITS-2017/test/crf' # CRF优化结果保存路径 16 | 17 | module_path = './module/net550-0.028-0.022.pth' # 测试模型地址 18 | 19 | # -----------------------路径相关参数--------------- ------------------------ 20 | 21 | 22 | # ---------------------训练数据获取相关参数----------------------------------- 23 | 24 | size = 48 # 使用48张连续切片作为网络的输入 25 | 26 | down_scale = 0.5 # 横断面降采样因子 27 | 28 | expand_slice = 20 # 仅使用包含肝脏以及肝脏上下20张切片作为训练样本 29 | 30 | slice_thickness = 1 # 将所有数据在z轴的spacing归一化到1mm 31 | 32 | upper, lower = 200, -200 # CT数据灰度截断窗口 33 | 34 | # ---------------------训练数据获取相关参数----------------------------------- 35 | 36 | 37 | # -----------------------网络结构相关参数------------------------------------ 38 | 39 | drop_rate = 0.3 # dropout随机丢弃概率 40 | 41 | # -----------------------网络结构相关参数------------------------------------ 42 | 43 | 44 | # ---------------------网络训练相关参数-------------------------------------- 45 | 46 | gpu = '0' # 使用的显卡序号 47 | 48 | Epoch = 1000 49 | 50 | learning_rate = 1e-4 51 | 52 | learning_rate_decay = [500, 750] 53 | 54 | alpha = 0.33 # 深度监督衰减系数 55 | 56 | batch_size = 1 57 | 58 | num_workers = 3 59 | 60 | pin_memory = True 61 | 62 | cudnn_benchmark = True 63 | 64 | # ---------------------网络训练相关参数-------------------------------------- 65 | 66 | 67 | # ----------------------模型测试相关参数------------------------------------- 68 | 69 | threshold = 0.5 # 阈值度阈值 70 | 71 | stride = 12 # 滑动取样步长 72 | 73 | maximum_hole = 5e4 # 最大的空洞面积 74 | 75 | # ----------------------模型测试相关参数------------------------------------- 76 | 77 | 78 | # ---------------------CRF后处理优化相关参数---------------------------------- 79 | 80 | z_expand, x_expand, y_expand = 10, 30, 30 # 根据预测结果在三个方向上的扩展数量 81 | 82 | max_iter = 20 # CRF迭代次数 83 | 84 | s1, s2, s3 = 1, 10, 10 # CRF高斯核参数 85 | 86 | # 
---------------------CRF后处理优化相关参数---------------------------------- -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.14.2 2 | torch==1.0.1.post2 3 | visdom==0.1.8.8 4 | pandas==0.23.3 5 | scipy==1.0.0 6 | tqdm==4.40.2 7 | scikit-image==0.13.1 8 | SimpleITK==1.0.1 9 | pydensecrf==1.0rc3 -------------------------------------------------------------------------------- /train_ds.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 训练脚本 4 | """ 5 | 6 | import os 7 | from time import time 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | from torch.utils.data import DataLoader 14 | 15 | from visdom import Visdom 16 | 17 | from dataset.dataset import Dataset 18 | 19 | from loss.Dice import DiceLoss 20 | from loss.ELDice import ELDiceLoss 21 | from loss.WBCE import WCELoss 22 | from loss.Jaccard import JaccardLoss 23 | from loss.SS import SSLoss 24 | from loss.Tversky import TverskyLoss 25 | from loss.Hybrid import HybridLoss 26 | from loss.BCE import BCELoss 27 | 28 | from net.ResUNet import net 29 | 30 | import parameter as para 31 | 32 | # 设置visdom 33 | viz = Visdom(port=666) 34 | step_list = [0] 35 | win = viz.line(X=np.array([0]), Y=np.array([1.0]), opts=dict(title='loss')) 36 | 37 | # 设置显卡相关 38 | os.environ['CUDA_VISIBLE_DEVICES'] = para.gpu 39 | cudnn.benchmark = para.cudnn_benchmark 40 | 41 | # 定义网络 42 | net = torch.nn.DataParallel(net).cuda() 43 | net.train() 44 | 45 | # 定义Dateset 46 | train_ds = Dataset(os.path.join(para.training_set_path, 'ct'), os.path.join(para.training_set_path, 'seg')) 47 | 48 | # 定义数据加载 49 | train_dl = DataLoader(train_ds, para.batch_size, True, num_workers=para.num_workers, pin_memory=para.pin_memory) 50 | 51 | # 挑选损失函数 52 | loss_func_list = [DiceLoss(), ELDiceLoss(), WCELoss(), JaccardLoss(), SSLoss(), TverskyLoss(), HybridLoss(), BCELoss()] 53 | loss_func = loss_func_list[5] 54 | 55 | # 定义优化器 56 | opt = torch.optim.Adam(net.parameters(), lr=para.learning_rate) 57 | 58 | # 学习率衰减 59 | lr_decay = torch.optim.lr_scheduler.MultiStepLR(opt, para.learning_rate_decay) 60 | 61 | # 深度监督衰减系数 62 | alpha = para.alpha 63 | 64 | # 训练网络 65 | start = time() 66 | for epoch in range(para.Epoch): 67 | 68 | lr_decay.step() 69 | 70 | mean_loss = [] 71 | 72 | for step, (ct, seg) in enumerate(train_dl): 73 | 74 | ct = ct.cuda() 75 | seg = seg.cuda() 76 | 77 | outputs = net(ct) 78 | 79 | loss1 = loss_func(outputs[0], seg) 80 | loss2 = loss_func(outputs[1], seg) 81 | loss3 = loss_func(outputs[2], seg) 82 | loss4 = loss_func(outputs[3], seg) 83 | 84 | loss = (loss1 + loss2 + loss3) * alpha + loss4 85 | 86 | mean_loss.append(loss4.item()) 87 | 88 | opt.zero_grad() 89 | loss.backward() 90 | opt.step() 91 | 92 | if step % 5 is 0: 93 | 94 | step_list.append(step_list[-1] + 1) 95 | viz.line(X=np.array([step_list[-1]]), Y=np.array([loss4.item()]), win=win, update='append') 96 | 97 | print('epoch:{}, step:{}, loss1:{:.3f}, loss2:{:.3f}, loss3:{:.3f}, loss4:{:.3f}, time:{:.3f} min' 98 | .format(epoch, step, loss1.item(), loss2.item(), loss3.item(), loss4.item(), (time() - start) / 60)) 99 | 100 | mean_loss = sum(mean_loss) / len(mean_loss) 101 | 102 | # 保存模型 103 | if epoch % 50 is 0 and epoch is not 0: 104 | 105 | # 网络模型的命名方式为:epoch轮数+当前minibatch的loss+本轮epoch的平均loss 106 | torch.save(net.state_dict(), './module/net{}-{:.3f}-{:.3f}.pth'.format(epoch, loss, 
mean_loss)) 107 | 108 | # 对深度监督系数进行衰减 109 | if epoch % 40 is 0 and epoch is not 0: 110 | alpha *= 0.8 111 | 112 | # 深度监督的系数变化 113 | # 1.000 114 | # 0.800 115 | # 0.640 116 | # 0.512 117 | # 0.410 118 | # 0.328 119 | # 0.262 120 | # 0.210 121 | # 0.168 122 | # 0.134 123 | # 0.107 124 | # 0.086 125 | # 0.069 126 | # 0.055 127 | # 0.044 128 | # 0.035 129 | # 0.028 130 | # 0.023 131 | # 0.018 132 | # 0.014 133 | # 0.012 134 | # 0.009 135 | # 0.007 136 | # 0.006 137 | # 0.005 138 | # 0.004 139 | # 0.003 140 | # 0.002 141 | # 0.002 142 | # 0.002 143 | # 0.001 144 | # 0.001 145 | # 0.001 146 | # 0.001 147 | # 0.001 148 | # 0.000 149 | # 0.000 150 | -------------------------------------------------------------------------------- /utilities/calculate_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 计算基于重叠度和距离等九种分割常见评价指标 4 | """ 5 | 6 | import math 7 | 8 | import numpy as np 9 | import scipy.spatial as spatial 10 | import scipy.ndimage.morphology as morphology 11 | 12 | 13 | class Metirc(): 14 | 15 | def __init__(self, real_mask, pred_mask, voxel_spacing): 16 | """ 17 | 18 | :param real_mask: 金标准 19 | :param pred_mask: 预测结果 20 | :param voxel_spacing: 体数据的spacing 21 | """ 22 | self.real_mask = real_mask 23 | self.pred_mask = pred_mask 24 | self.voxel_sapcing = voxel_spacing 25 | 26 | self.real_mask_surface_pts = self.get_surface(real_mask, voxel_spacing) 27 | self.pred_mask_surface_pts = self.get_surface(pred_mask, voxel_spacing) 28 | 29 | self.real2pred_nn = self.get_real2pred_nn() 30 | self.pred2real_nn = self.get_pred2real_nn() 31 | 32 | # 下面三个是提取边界和计算最小距离的实用函数 33 | def get_surface(self, mask, voxel_spacing): 34 | """ 35 | 36 | :param mask: ndarray 37 | :param voxel_spacing: 体数据的spacing 38 | :return: 提取array的表面点的真实坐标(以mm为单位) 39 | """ 40 | 41 | # 卷积核采用的是三维18邻域 42 | 43 | kernel = morphology.generate_binary_structure(3, 2) 44 | surface = morphology.binary_erosion(mask, kernel) ^ mask 45 | 46 | surface_pts = surface.nonzero() 47 | 48 | surface_pts = np.array(list(zip(surface_pts[0], surface_pts[1], surface_pts[2]))) 49 | 50 | # (0.7808688879013062, 0.7808688879013062, 2.5) (88, 410, 512) 51 | # 读出来的数据spacing和shape不是对应的,所以需要反向 52 | return surface_pts * np.array(self.voxel_sapcing[::-1]).reshape(1, 3) 53 | 54 | def get_pred2real_nn(self): 55 | """ 56 | 57 | :return: 预测结果表面体素到金标准表面体素的最小距离 58 | """ 59 | 60 | tree = spatial.cKDTree(self.real_mask_surface_pts) 61 | nn, _ = tree.query(self.pred_mask_surface_pts) 62 | 63 | return nn 64 | 65 | def get_real2pred_nn(self): 66 | """ 67 | 68 | :return: 金标准表面体素到预测结果表面体素的最小距离 69 | """ 70 | tree = spatial.cKDTree(self.pred_mask_surface_pts) 71 | nn, _ = tree.query(self.real_mask_surface_pts) 72 | 73 | return nn 74 | 75 | # 下面的六个指标是基于重叠度的 76 | def get_dice_coefficient(self): 77 | """ 78 | 79 | :return: dice系数 dice系数的分子 dice系数的分母(后两者用于计算dice_global) 80 | """ 81 | intersection = (self.real_mask * self.pred_mask).sum() 82 | union = self.real_mask.sum() + self.pred_mask.sum() 83 | 84 | return 2 * intersection / union, 2 * intersection, union 85 | 86 | def get_jaccard_index(self): 87 | """ 88 | 89 | :return: 杰卡德系数 90 | """ 91 | intersection = (self.real_mask * self.pred_mask).sum() 92 | union = (self.real_mask | self.pred_mask).sum() 93 | 94 | return intersection / union 95 | 96 | def get_VOE(self): 97 | """ 98 | 99 | :return: 体素重叠误差 Volumetric Overlap Error 100 | """ 101 | 102 | return 1 - self.get_jaccard_index() 103 | 104 | def get_RVD(self): 105 | """ 106 | 107 | :return: 体素相对误差 Relative Volume Difference 108 
| """ 109 | 110 | return float(self.pred_mask.sum() - self.real_mask.sum()) / float(self.real_mask.sum()) 111 | 112 | def get_FNR(self): 113 | """ 114 | 115 | :return: 欠分割率 False negative rate 116 | """ 117 | fn = self.real_mask.sum() - (self.real_mask * self.pred_mask).sum() 118 | union = (self.real_mask | self.pred_mask).sum() 119 | 120 | return fn / union 121 | 122 | def get_FPR(self): 123 | """ 124 | 125 | :return: 过分割率 False positive rate 126 | """ 127 | fp = self.pred_mask.sum() - (self.real_mask * self.pred_mask).sum() 128 | union = (self.real_mask | self.pred_mask).sum() 129 | 130 | return fp / union 131 | 132 | # 下面的三个指标是基于距离的 133 | def get_ASSD(self): 134 | """ 135 | 136 | :return: 对称位置平均表面距离 Average Symmetric Surface Distance 137 | """ 138 | return (self.pred2real_nn.sum() + self.real2pred_nn.sum()) / \ 139 | (self.real_mask_surface_pts.shape[0] + self.pred_mask_surface_pts.shape[0]) 140 | 141 | def get_RMSD(self): 142 | """ 143 | 144 | :return: 对称位置表面距离的均方根 Root Mean Square symmetric Surface Distance 145 | """ 146 | return math.sqrt((np.power(self.pred2real_nn, 2).sum() + np.power(self.real2pred_nn, 2).sum()) / 147 | (self.real_mask_surface_pts.shape[0] + self.pred_mask_surface_pts.shape[0])) 148 | 149 | def get_MSD(self): 150 | """ 151 | 152 | :return: 对称位置的最大表面距离 Maximum Symmetric Surface Distance 153 | """ 154 | return max(self.pred2real_nn.max(), self.real2pred_nn.max()) 155 | -------------------------------------------------------------------------------- /val.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | 测试脚本 4 | """ 5 | 6 | import os 7 | import copy 8 | import collections 9 | from time import time 10 | 11 | import torch 12 | import numpy as np 13 | import pandas as pd 14 | import scipy.ndimage as ndimage 15 | import SimpleITK as sitk 16 | import skimage.measure as measure 17 | import skimage.morphology as morphology 18 | 19 | from net.ResUNet import ResUNet 20 | from utilities.calculate_metrics import Metirc 21 | 22 | import parameter as para 23 | 24 | os.environ['CUDA_VISIBLE_DEVICES'] = para.gpu 25 | 26 | # 为了计算dice_global定义的两个变量 27 | dice_intersection = 0.0 28 | dice_union = 0.0 29 | 30 | file_name = [] # 文件名称 31 | time_pre_case = [] # 单例数据消耗时间 32 | 33 | # 定义评价指标 34 | liver_score = collections.OrderedDict() 35 | liver_score['dice'] = [] 36 | liver_score['jacard'] = [] 37 | liver_score['voe'] = [] 38 | liver_score['fnr'] = [] 39 | liver_score['fpr'] = [] 40 | liver_score['assd'] = [] 41 | liver_score['rmsd'] = [] 42 | liver_score['msd'] = [] 43 | 44 | # 定义网络并加载参数 45 | net = torch.nn.DataParallel(ResUNet(training=False)).cuda() 46 | net.load_state_dict(torch.load(para.module_path)) 47 | net.eval() 48 | 49 | for file_index, file in enumerate(os.listdir(para.test_ct_path)): 50 | 51 | start = time() 52 | 53 | file_name.append(file) 54 | 55 | # 将CT读入内存 56 | ct = sitk.ReadImage(os.path.join(para.test_ct_path, file), sitk.sitkInt16) 57 | ct_array = sitk.GetArrayFromImage(ct) 58 | 59 | origin_shape = ct_array.shape 60 | 61 | # 将灰度值在阈值之外的截断掉 62 | ct_array[ct_array > para.upper] = para.upper 63 | ct_array[ct_array < para.lower] = para.lower 64 | 65 | # min max 归一化 66 | ct_array = ct_array.astype(np.float32) 67 | ct_array = ct_array / 200 68 | 69 | # 对CT使用双三次算法进行插值,插值之后的array依然是int16 70 | ct_array = ndimage.zoom(ct_array, (1, para.down_scale, para.down_scale), order=3) 71 | 72 | # 对slice过少的数据使用padding 73 | too_small = False 74 | if ct_array.shape[0] < para.size: 75 | depth = ct_array.shape[0] 76 | temp = np.ones((para.size, 
int(512 * para.down_scale), int(512 * para.down_scale))) * para.lower 77 | temp[0: depth] = ct_array 78 | ct_array = temp 79 | too_small = True 80 | 81 | # 滑动窗口取样预测 82 | start_slice = 0 83 | end_slice = start_slice + para.size - 1 84 | count = np.zeros((ct_array.shape[0], 512, 512), dtype=np.int16) 85 | probability_map = np.zeros((ct_array.shape[0], 512, 512), dtype=np.float32) 86 | 87 | with torch.no_grad(): 88 | while end_slice < ct_array.shape[0]: 89 | 90 | ct_tensor = torch.FloatTensor(ct_array[start_slice: end_slice + 1]).cuda() 91 | ct_tensor = ct_tensor.unsqueeze(dim=0).unsqueeze(dim=0) 92 | 93 | outputs = net(ct_tensor) 94 | 95 | count[start_slice: end_slice + 1] += 1 96 | probability_map[start_slice: end_slice + 1] += np.squeeze(outputs.cpu().detach().numpy()) 97 | 98 | # 由于显存不足,这里直接保留ndarray数据,并在保存之后直接销毁计算图 99 | del outputs 100 | 101 | start_slice += para.stride 102 | end_slice = start_slice + para.size - 1 103 | 104 | if end_slice != ct_array.shape[0] - 1: 105 | end_slice = ct_array.shape[0] - 1 106 | start_slice = end_slice - para.size + 1 107 | 108 | ct_tensor = torch.FloatTensor(ct_array[start_slice: end_slice + 1]).cuda() 109 | ct_tensor = ct_tensor.unsqueeze(dim=0).unsqueeze(dim=0) 110 | outputs = net(ct_tensor) 111 | 112 | count[start_slice: end_slice + 1] += 1 113 | probability_map[start_slice: end_slice + 1] += np.squeeze(outputs.cpu().detach().numpy()) 114 | 115 | del outputs 116 | 117 | pred_seg = np.zeros_like(probability_map) 118 | pred_seg[probability_map >= (para.threshold * count)] = 1 119 | 120 | if too_small: 121 | temp = np.zeros((depth, 512, 512), dtype=np.float32) 122 | temp += pred_seg[0: depth] 123 | pred_seg = temp 124 | 125 | # 将金标准读入内存 126 | seg = sitk.ReadImage(os.path.join(para.test_seg_path, file.replace('volume', 'segmentation')), sitk.sitkUInt8) 127 | seg_array = sitk.GetArrayFromImage(seg) 128 | seg_array[seg_array > 0] = 1 129 | 130 | # 对肝脏进行最大连通域提取,移除细小区域,并进行内部的空洞填充 131 | pred_seg = pred_seg.astype(np.uint8) 132 | liver_seg = copy.deepcopy(pred_seg) 133 | liver_seg = measure.label(liver_seg, 4) 134 | props = measure.regionprops(liver_seg) 135 | 136 | max_area = 0 137 | max_index = 0 138 | for index, prop in enumerate(props, start=1): 139 | if prop.area > max_area: 140 | max_area = prop.area 141 | max_index = index 142 | 143 | liver_seg[liver_seg != max_index] = 0 144 | liver_seg[liver_seg == max_index] = 1 145 | 146 | liver_seg = liver_seg.astype(np.bool) 147 | morphology.remove_small_holes(liver_seg, para.maximum_hole, connectivity=2, in_place=True) 148 | liver_seg = liver_seg.astype(np.uint8) 149 | 150 | # 计算分割评价指标 151 | liver_metric = Metirc(seg_array, liver_seg, ct.GetSpacing()) 152 | 153 | liver_score['dice'].append(liver_metric.get_dice_coefficient()[0]) 154 | liver_score['jacard'].append(liver_metric.get_jaccard_index()) 155 | liver_score['voe'].append(liver_metric.get_VOE()) 156 | liver_score['fnr'].append(liver_metric.get_FNR()) 157 | liver_score['fpr'].append(liver_metric.get_FPR()) 158 | liver_score['assd'].append(liver_metric.get_ASSD()) 159 | liver_score['rmsd'].append(liver_metric.get_RMSD()) 160 | liver_score['msd'].append(liver_metric.get_MSD()) 161 | 162 | dice_intersection += liver_metric.get_dice_coefficient()[1] 163 | dice_union += liver_metric.get_dice_coefficient()[2] 164 | 165 | # 将预测的结果保存为nii数据 166 | pred_seg = sitk.GetImageFromArray(liver_seg) 167 | 168 | pred_seg.SetDirection(ct.GetDirection()) 169 | pred_seg.SetOrigin(ct.GetOrigin()) 170 | pred_seg.SetSpacing(ct.GetSpacing()) 171 | 172 | sitk.WriteImage(pred_seg, 
os.path.join(para.pred_path, file.replace('volume', 'pred'))) 173 | 174 | speed = time() - start 175 | time_pre_case.append(speed) 176 | 177 | print(file_index, 'this case use {:.3f} s'.format(speed)) 178 | print('-----------------------') 179 | 180 | 181 | # 将评价指标写入到exel中 182 | liver_data = pd.DataFrame(liver_score, index=file_name) 183 | liver_data['time'] = time_pre_case 184 | 185 | liver_statistics = pd.DataFrame(index=['mean', 'std', 'min', 'max'], columns=list(liver_data.columns)) 186 | liver_statistics.loc['mean'] = liver_data.mean() 187 | liver_statistics.loc['std'] = liver_data.std() 188 | liver_statistics.loc['min'] = liver_data.min() 189 | liver_statistics.loc['max'] = liver_data.max() 190 | 191 | writer = pd.ExcelWriter('./result.xlsx') 192 | liver_data.to_excel(writer, 'liver') 193 | liver_statistics.to_excel(writer, 'liver_statistics') 194 | writer.save() 195 | 196 | # 打印dice global 197 | print('dice global:', dice_intersection / dice_union) 198 | --------------------------------------------------------------------------------
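As a small standalone example of reusing utilities/calculate_metrics.py outside of val.py (a sketch, not part of the repo; the file names are hypothetical and it assumes you run it from the repository root):

```python
import SimpleITK as sitk

from utilities.calculate_metrics import Metirc

# hypothetical pair of files: a LiTS-style ground truth and a prediction saved by val.py
gt = sitk.ReadImage('segmentation-0.nii', sitk.sitkUInt8)
pred = sitk.ReadImage('pred-0.nii', sitk.sitkUInt8)

gt_array = sitk.GetArrayFromImage(gt)
gt_array[gt_array > 0] = 1          # merge liver and tumour labels, as the repo scripts do
pred_array = sitk.GetArrayFromImage(pred)

metric = Metirc(gt_array, pred_array, gt.GetSpacing())
print('dice:', metric.get_dice_coefficient()[0])
print('assd:', metric.get_ASSD())
```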