├── .gitignore ├── Ensemble_seg.py ├── README.md ├── Segmentation.py ├── Uncertainty.py ├── config ├── subtest_0.txt ├── subtest_1.txt ├── subtest_2.txt ├── subtest_3.txt ├── subtest_4.txt ├── subtest_5.txt └── train.txt ├── data_process ├── Preprocess.py ├── __init__.py ├── data_process_func.py ├── label_transfer.py └── transform.py ├── error_rate.py ├── fig ├── 1 └── summary.jpg ├── models ├── ._data_loader.py ├── ._data_process.py ├── LYC_data_loader.py ├── Struseg_dataset.py ├── Unet.py ├── Unet_Separate_3.py ├── layers.py ├── loss_function.py └── module.py ├── seg_img.py ├── train.py ├── util ├── ._.DS_Store ├── ._assd_evaluation.py ├── ._dice_evaluation.py ├── ._pre_process.py ├── ._train_test_func.py ├── ._visualize.py ├── Label_exist.py ├── assd_evaluation.py ├── binary.py ├── collect_organism_hist.py ├── data_augament.py ├── dice_evaluation.py ├── dump_data.py ├── grid_normal.py ├── make_3d_ground_truth_only.py ├── parse_config.py ├── pre_function.py ├── pre_process.py ├── train_test_func.py └── visualization │ ├── 3.py │ ├── evalution.py │ ├── show_Distance.py │ ├── show_Label_contours.py │ ├── show_boxplot.py │ ├── show_multi_hist.py │ ├── show_param.py │ └── visualize_loss.py └── weights_center_crop ├── multi_thresh_1 └── readme ├── multi_thresh_2 └── readme └── multi_thresh_3 └── readme /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to 
inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /Ensemble_seg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import absolute_import, print_function 4 | import time 5 | import os 6 | import shutil 7 | import torch.tensor 8 | from util.train_test_func import * 9 | from util.parse_config import parse_config 10 | from util.binary import assd, dc, hd95 11 | from data_process.data_process_func import save_array_as_nifty_volume 12 | from util.assd_evaluation import one_hot 13 | from skimage import morphology 14 | 15 | def ensemble(): 16 | config ={ 17 | 'data': { 18 | 'data_root': '/lyc/Head-Neck/MICCAI-19-StructSeg/HaN_OAR_center_crop/', 19 | 'save_root': '/lyc/Head-Neck/MICCAI-19-StructSeg/HaN_OAR_center_crop/', 20 | 'seg_name': ['subprob_0.nii.gz','subprob_1.nii.gz','subprob_2.nii.gz','subprob_3.nii.gz','subprob_4.nii.gz','subprob_5.nii.gz'], 21 | # , 'subseg_6.nii.gz', 'subseg_7.nii.gz', 'subseg_8.nii.gz'], 22 | 'label_name': 'crop_label.nii.gz', 23 | 'save_name': 'weighted_enseg.nii.gz', 24 | 'class_num': 23 25 | }, 26 | } 27 | config_data = config['data'] 28 | Mode = ['valid'] 29 | class_num = config_data['class_num'] 30 | save = True 31 | delete = False 32 | cal_dice = False 33 | cal_hd95 = False 34 | 35 | for mode in Mode: 36 | patient_list 
= os.listdir(config_data['data_root']+mode) 37 | patient_num = len(patient_list) 38 | dice_array = np.zeros([patient_num, class_num]) 39 | hd95_array = np.zeros([patient_num, class_num]) 40 | for patient_order in range(patient_num): 41 | patient_path = os.path.join(config_data['data_root'], mode, patient_list[patient_order]) 42 | label_path = os.path.join(patient_path, config_data['label_name']) 43 | save_path = os.path.join(config_data['save_root'], mode, patient_list[patient_order], config_data['save_name']) 44 | label = torch.from_numpy(load_nifty_volume_as_array(label_path, transpose=True)) 45 | seg = 0 46 | for seg_order in range(len(config_data['seg_name'])): 47 | seg_name = config_data['seg_name'][seg_order] 48 | seg_path = os.path.join(patient_path, seg_name) 49 | cur_seg = load_nifty_volume_as_array(seg_path, transpose=False).astype(np.uint16) 50 | # for ii in range(class_num): 51 | # cur_seg[ii] *= weight_0[ii, -seg_order-1] 52 | seg += cur_seg 53 | if delete: 54 | shutil.rmtree(seg_path) 55 | seg = np.argmax(seg, axis=0).astype(np.int16) 56 | onehot_seg = one_hot(seg, class_num) 57 | onehot_label = one_hot(label, class_num) 58 | if cal_dice: 59 | Dice = np.zeros(class_num) 60 | for i in range(class_num): 61 | Dice[i] = dc(onehot_seg[i], onehot_label[i]) 62 | dice_array[patient_order] = Dice 63 | print('patient order', patient_order, ' dice:', Dice) 64 | if cal_hd95: 65 | HD = np.zeros(class_num) 66 | for i in range(class_num): 67 | HD[i] = hd95(onehot_seg[i], onehot_label[i]) 68 | hd95_array[patient_order] = HD 69 | print('patient order', patient_order, ' dice:', HD) 70 | 71 | if save: 72 | save_array_as_nifty_volume(seg, save_path, transpose=True) 73 | 74 | if cal_dice: 75 | dice_array[:, 0] = np.mean(dice_array[:, 1::], 1) 76 | dice_mean = np.mean(dice_array, 0) 77 | dice_std = np.std(dice_array, 0) 78 | print('{0:} mode: mean dice:{1:}, std of dice:{2:}'.format(mode, dice_mean, dice_std)) 79 | 80 | if cal_hd95: 81 | hd95_array[:, 0] = 
np.mean(hd95_array[:, 1::], 1) 82 | hd95_mean = np.mean(hd95_array, 0) 83 | hd95_std = np.std(hd95_array, 0) 84 | print('{0:} mode: mean dice:{1:}, std of dice:{2:}'.format(mode, hd95_mean, hd95_std)) 85 | 86 | if __name__ == '__main__': 87 | ensemble() 88 | 89 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SepNet 2 | Code for Automatic Segmentation of Organs-at-Risk fromHead-and-Neck CT using Separable ConvolutionalNeural Network with Hard-Region-Weighted Loss, which won the third place of [StructSeg19 challenge task1](https://structseg2019.grand-challenge.org/Home/). 3 | ## Abstract 4 | Nasopharyngeal Carcinoma (NPC) is a leading form of Head-and-Neck (HAN) cancer in the Arctic, China, Southeast Asia, and the Middle East/North Africa. Accurate segmentation of Organs-at-Risk (OAR) from Computed Tomography (CT) images with uncertainty information is critical for effective planning of radiation therapy for NPC treatment. Despite the state-of-the-art performance achieved by Convolutional Neural Networks (CNNs) for automatic segmentation of OARs, existing methods do not provide uncertainty estimation of the segmentation results for treatment planning, and their accuracy is still limited by several factors, including the low contrast of soft tissues in CT, highly imbalanced sizes of OARs and large inter-slice spacing. To address these problems, we propose a novel framework for accurate OAR segmentation with reliable uncertainty estimation. First, we propose a Segmental Linear Function (SLF) to transform the intensity of CT images so that better visibility of different OARs is obtained to facilitate the segmentation task. Second, to deal with the large inter-slice spacing, we introduce a novel network (named as 3D-SepNet) based on spatially separated inter-slice convolution and intra-slice convolution. 
Thirdly, to deal with organs or regions that are hard to segment, we propose a hard voxel weighting strategy that automatically pays more attention to hard voxels for better segmentation. Finally, we use an ensemble of models trained with different loss functions and intensity transforms to obtain robust results, which also yields segmentation uncertainty without extra effort. Our method won third place in the HAN OAR segmentation task of the StructSeg 2019 challenge, achieving a weighted average Dice of 80.52% and a 95% Hausdorff Distance of 3.043 mm. Experimental results show that 1) our SLF for intensity transform helps to improve the accuracy of OAR segmentation from CT images; 2) with only 1/3 of the parameters of 3D UNet, our 3D-SepNet obtains better segmentation results for most OARs; 3) the proposed hard voxel weighting strategy used for training effectively improves the segmentation accuracy; 4) the segmentation uncertainty obtained by our method has a high correlation with mis-segmentations, which has the potential to assist more informed decisions in clinical practice. 5 | ![image](https://github.com/LWHYC/SepNet/blob/master/fig/summary.jpg) 6 | 7 | ## Requirements 8 | PyTorch >= 1.4, SimpleITK >= 1.2, scipy >= 1.3.1, nibabel >= 2.5.0 and some common packages. 9 | 10 | ## Usage 11 | - Prepare the StructSeg2019 task1 data and split it into two folders: train and valid. (Each patient's CT image and label should be in an individual folder inside the train or valid folder); 12 | - Preprocess the data with `data_process/Preprocess.py`; 13 | - Change the `data_root` in `config/train.txt` to your data root; 14 | - Run `python train.py`. Your model is saved to the path given by `model_save_prefix` in `config/train.txt`. 15 | - Sub-segmentations are obtained with `Segmentation.py`. Ensemble results are obtained with `Ensemble_seg.py`. 16 | - Use `Uncertainty.py` to obtain the uncertainty estimation.
17 | -------------------------------------------------------------------------------- /Segmentation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import, print_function 3 | import time 4 | from torch.nn import parallel 5 | import torch.tensor 6 | from models.LYC_data_loader import LYC_dataset 7 | from util.train_test_func import * 8 | from util.parse_config import parse_config 9 | from models.Unet import Unet 10 | from models.Unet_Separate import Unet_Separate 11 | from util.binary import assd, dc 12 | from data_process.data_process_func import save_array_as_nifty_volume 13 | from util.assd_evaluation import one_hot 14 | 15 | class NetFactory(object): 16 | @staticmethod 17 | def create(name): 18 | if name == 'Unet': 19 | return Unet 20 | 21 | if name == 'Unet_Separate': 22 | return Unet_Separate 23 | 24 | # add your own networks here 25 | print('unsupported network:', name) 26 | exit() 27 | 28 | 29 | def seg(config_file): 30 | # 1, load configuration parameters 31 | print('1.Load parameters') 32 | config = parse_config(config_file) 33 | config_data = config['data'] # config of data,e.g. data_shape,batch_size. 34 | config_net = config['network'] # config of net, e.g. net_name,base_feature_name,class_num. 35 | config_train = config['training'] 36 | random.seed(config_train.get('random_seed', 1)) 37 | output_feature = config_data.get('output_feature', False) 38 | net_type = config_net['net_type'] 39 | class_num = config_net['class_num'] 40 | save = False 41 | show = False 42 | cal_dice = True 43 | cal_assd = False 44 | 45 | # 2, load data 46 | print('2.Load data') 47 | Datamode = ['valid'] 48 | 49 | # 3. 
creat model 50 | print('3.Creat model') 51 | # dice_eval = TestDiceLoss(class_num) 52 | net_class = NetFactory.create(net_type) 53 | net = net_class(inc=config_net.get('input_channel', 1), 54 | n_classes = class_num, 55 | base_chns= config_net.get('base_feature_number', 16), 56 | droprate=config_net.get('drop_rate', 0.2), 57 | norm='in', 58 | depth=False, 59 | dilation=1 60 | ) 61 | 62 | net = torch.nn.DataParallel(net, device_ids=[0, 1]).cuda() 63 | if config_train['load_weight']: 64 | weight = torch.load(config_train['model_path'], map_location=lambda storage, loc: storage) 65 | net.load_state_dict(weight) 66 | print(torch.cuda.is_available()) 67 | 68 | 69 | # 4, start to seg 70 | print('''start to seg ''') 71 | net.eval() 72 | for mode in Datamode: 73 | Data = LYC_dataset(config_data, mode) 74 | patient_number = len(os.listdir(os.path.join(config_data['data_root'], mode))) 75 | with torch.no_grad(): 76 | t_array = np.zeros(patient_number) 77 | dice_array = np.zeros([patient_number, class_num]) 78 | assd_array = np.zeros([patient_number, class_num]) 79 | for patient_order in range(patient_number): 80 | t1 = time.time() 81 | valid_pair, patient_path = Data.get_list_img(patient_order) 82 | clip_number = len(valid_pair['images']) # 裁剪块数 83 | clip_height = config_data['test_data_shape'][0] # 裁剪图像的高度 84 | total_labels = valid_pair['labels'].cuda() 85 | predic_size = torch.Size([1, class_num]) + total_labels.size()[1::] 86 | totalpredic = torch.zeros(predic_size).cuda() # 完整预测 87 | if output_feature: 88 | outfeature_size = torch.Size([1, 2*config_net.get('base_feature_number')]) + total_labels.size()[1::] 89 | totalfeature = torch.zeros(outfeature_size).cuda() 90 | for i in range(clip_number): 91 | tempx = valid_pair['images'][i].cuda() 92 | if output_feature: 93 | pred, outfeature = net(tempx) 94 | else: 95 | pred = net(tempx) 96 | if i < clip_number - 1: 97 | totalpredic[:, :, i * clip_height:(i + 1) * clip_height] = pred 98 | else: 99 | totalpredic[:, :, 
-clip_height::] = pred 100 | if output_feature: 101 | if i < clip_number - 1: 102 | totalfeature[:, :, i * clip_height:(i + 1) * clip_height] = outfeature 103 | else: 104 | totalfeature[:, :, -clip_height::] = outfeature 105 | 106 | # torchdice = dice_eval(totalpredic, total_labels) 107 | # print('torch dice:', torchdice) 108 | totalpredic = torch.max(totalpredic, 1)[1].squeeze() 109 | totalpredic = np.uint8(totalpredic.cpu().data.numpy().squeeze()) 110 | totallabel = np.uint8(total_labels.cpu().data.numpy().squeeze()) 111 | if output_feature: 112 | totalfeature = totalpredic.cpu().data.numpy().squeeze() 113 | t2 = time.time() 114 | t = t2-t1 115 | t_array[patient_order] = t 116 | 117 | one_hot_label = one_hot(totallabel, class_num) 118 | one_hot_predic = one_hot(totalpredic, class_num) 119 | 120 | if cal_dice: 121 | Dice = np.zeros(class_num) 122 | for i in range(class_num): 123 | Dice[i] = dc(one_hot_predic[i], one_hot_label[i]) 124 | dice_array[patient_order] = Dice 125 | print('patient order', patient_order, ' dice:', Dice) 126 | 127 | if cal_assd: 128 | Assd = np.zeros(class_num) 129 | for i in range(class_num): 130 | Assd[i] = assd(one_hot_predic[i], one_hot_label[i], 1) 131 | assd_array[patient_order] = Assd 132 | 133 | if show: 134 | for i in np.arange(0, totalpredic.shape[0], 2): 135 | f, plots = plt.subplots(1, 2) 136 | plots[0].imshow(totalpredic[i]) 137 | plots[1].imshow(totallabel[i]) 138 | #plots[2].imshow(oriseg[i]) 139 | # plots[1, 0].imshow(totalfeature[0, i]) 140 | # plots[1, 1].imshow(totalfeature[5, i]) 141 | plt.show() 142 | if save : 143 | if output_feature: 144 | np.save(patient_path + '/Feature.npy', totalfeature) 145 | #np.save(patient_path + '/Seg_2.npy', totalpredic) 146 | save_array_as_nifty_volume(totalpredic, patient_path + '/Seg.nii.gz') 147 | # np.savetxt(patient_path+'/Dice.npy', Dice.squeeze()) 148 | # np.savetxt(patient_path+'/Assd.npy', Assd.squeeze()) 149 | 150 | if cal_dice: 151 | dice_array[:, 0] = np.mean(dice_array[:, 1::], 
1) 152 | dice_mean = np.mean(dice_array, 0) 153 | dice_std = np.std(dice_array, 0) 154 | print('{0:} mode: mean dice:{1:}, std of dice:{2:}'.format(mode, dice_mean, dice_std)) 155 | 156 | if cal_assd: 157 | assd_array[:, 0] = np.mean(assd_array[:, 1::], 1) 158 | assd_mean = np.mean(assd_array, 0) 159 | assd_std = np.std(assd_array, 0) 160 | print('{0:} mode: mean assd:{1:}, std of assd:{2:}'.format(mode, assd_mean, assd_std)) 161 | 162 | t_mean = [t_array.mean()] 163 | t_std = [t_array.std()] 164 | print('{0:} mode: mean time:{1:}, std of time:{2:}'.format(mode, t_mean, t_std)) 165 | 166 | config_file = str('config/pnet_train.txt') 167 | assert (os.path.isfile(config_file)) 168 | seg(config_file) 169 | -------------------------------------------------------------------------------- /Uncertainty.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import absolute_import, print_function 4 | import time 5 | import os 6 | import shutil 7 | import torch.tensor 8 | from util.train_test_func import * 9 | from util.parse_config import parse_config 10 | from util.binary import assd, dc 11 | from data_process.data_process_func import save_array_as_nifty_volume 12 | from util.assd_evaluation import one_hot 13 | from skimage import morphology 14 | 15 | ## 计算多模型uncertainty 16 | def uncertain(): 17 | config ={ 18 | 'data': { 19 | 'data_root': '/lyc/Head-Neck/MICCAI-19-StructSeg/HaN_OAR_center_crop/', 20 | 'save_root': '/lyc/Head-Neck/MICCAI-19-StructSeg/HaN_OAR_center_crop/', 21 | 'seg_name': ['subseg_0.nii.gz','subseg_1.nii.gz','subseg_2.nii.gz','subseg_3.nii.gz','subseg_4.nii.gz','subseg_5.nii.gz'], 22 | 'label_name': 'crop_label.nii.gz', 23 | 'save_name': 'uncertain.nii.gz', 24 | 'class_num': 23 25 | }, 26 | } 27 | config_data = config['data'] 28 | Mode = ['valid'] 29 | class_num = config_data['class_num'] 30 | save = True 31 | delete = False 32 | 33 | for mode in Mode: 
34 | patient_list = os.listdir(config_data['data_root']+mode) 35 | patient_num = len(patient_list) 36 | for patient_order in range(patient_num): 37 | patient_path = os.path.join(config_data['data_root'], mode, patient_list[patient_order]) 38 | save_path = os.path.join(config_data['save_root'], mode, patient_list[patient_order], config_data['save_name']) 39 | seg_freq = 0 40 | for seg_order in range(len(config_data['seg_name'])): 41 | seg_name = config_data['seg_name'][seg_order] 42 | seg_path = os.path.join(patient_path, seg_name) 43 | cur_seg = load_nifty_volume_as_array(seg_path, transpose=True) 44 | cur_seg = one_hot(cur_seg, class_num).astype(np.float) 45 | seg_freq += cur_seg 46 | if delete: 47 | shutil.rmtree(seg_path) 48 | seg_freq /= 6 49 | uncertain = np.sum(-seg_freq*np.log(seg_freq+0.00001), axis=0) 50 | 51 | print('计算完{0:}'.format(patient_list[patient_order])) 52 | if save: 53 | save_array_as_nifty_volume(uncertain, save_path, transpose=True) 54 | 55 | 56 | 57 | 58 | if __name__ == '__main__': 59 | uncertain() 60 | -------------------------------------------------------------------------------- /config/subtest_0.txt: -------------------------------------------------------------------------------- 1 | [data] 2 | data_root = /lyc/MICCAI-19-StructSeg/HaN_OAR_center_crop 3 | img_name = crop_data_multi_thresh_1.nii.gz 4 | label_name = crop_label.nii.gz 5 | seg_name = subseg_0.nii.gz 6 | label_exist_name = label_exist.npy 7 | batch_size = 4 8 | class_num = 23 9 | random_scale = False 10 | random_rotate = False 11 | subdata_shape = [2, 120, 120] 12 | test_data_shape = [2, 256, 256] 13 | output_feature = False 14 | overlap_num = 8 15 | 16 | [network] 17 | net_type = Unet_Separate_3 18 | net_name = Unet_Separate_3 19 | base_feature_number = 24 20 | compress_feature_number = 4 21 | drop_rate = 0.5 22 | with_bn = False 23 | depth = False 24 | dilation = 1 25 | slice_margin = 3 26 | class_num = 23 27 | input_channel = 1 28 | 29 | 30 | [testing] 31 | load_weight = 
True 32 | model_path = weights_center_crop/multi_thresh/Unet_Separate_3/Unet_Separate_3_24_atmexp_1_0.795.pkl 33 | -------------------------------------------------------------------------------- /config/subtest_1.txt: -------------------------------------------------------------------------------- 1 | [data] 2 | data_root = /lyc/MICCAI-19-StructSeg/HaN_OAR_center_crop 3 | img_name = crop_data_multi_thresh_1.nii.gz 4 | label_name = crop_label.nii.gz 5 | seg_name = subseg_1.nii.gz 6 | label_exist_name = label_exist.npy 7 | batch_size = 4 8 | class_num = 23 9 | random_scale = False 10 | random_rotate = False 11 | subdata_shape = [2, 120, 120] 12 | test_data_shape = [2, 256, 256] 13 | output_feature = False 14 | overlap_num = 8 15 | 16 | [network] 17 | net_type = Unet_Separate_3 18 | net_name = Unet_Separate_3 19 | base_feature_number = 24 20 | compress_feature_number = 4 21 | drop_rate = 0.5 22 | with_bn = False 23 | depth = False 24 | dilation = 1 25 | slice_margin = 3 26 | class_num = 23 27 | input_channel = 1 28 | 29 | 30 | [testing] 31 | load_weight = True 32 | model_path = weights_center_crop/multi_thresh/Unet_Separate_3/Unet_Separate_3_24_atmexp_0.5_0.798.pkl 33 | -------------------------------------------------------------------------------- /config/subtest_2.txt: -------------------------------------------------------------------------------- 1 | [data] 2 | data_root = /lyc/MICCAI-19-StructSeg/HaN_OAR_center_crop 3 | img_name = crop_data_multi_thresh_2.nii.gz 4 | label_name = crop_label.nii.gz 5 | seg_name = subseg_2.nii.gz 6 | label_exist_name = label_exist.npy 7 | batch_size = 4 8 | class_num = 23 9 | random_scale = False 10 | random_rotate = False 11 | subdata_shape = [2, 120, 120] 12 | test_data_shape = [2, 256, 256] 13 | output_feature = False 14 | overlap_num = 8 15 | 16 | [network] 17 | net_type = Unet_Separate_3 18 | net_name = Unet_Separate_3 19 | base_feature_number = 24 20 | compress_feature_number = 4 21 | drop_rate = 0.5 22 | with_bn = False 23 
| depth = False 24 | dilation = 1 25 | slice_margin = 3 26 | class_num = 23 27 | input_channel = 1 28 | 29 | 30 | [testing] 31 | load_weight = True 32 | model_path = weights_center_crop/multi_thresh_2/Unet_Separate_3/Unet_Separate_3_24_atmexp_1_0.795.pkl -------------------------------------------------------------------------------- /config/subtest_3.txt: -------------------------------------------------------------------------------- 1 | [data] 2 | data_root = /lyc/MICCAI-19-StructSeg/HaN_OAR_center_crop 3 | img_name = crop_data_multi_thresh_2.nii.gz 4 | label_name = crop_label.nii.gz 5 | seg_name = subseg_3.nii.gz 6 | label_exist_name = label_exist.npy 7 | batch_size = 4 8 | class_num = 23 9 | random_scale = False 10 | random_rotate = False 11 | subdata_shape = [2, 120, 120] 12 | test_data_shape = [2, 256, 256] 13 | output_feature = False 14 | overlap_num = 8 15 | 16 | [network] 17 | net_type = Unet_Separate_3 18 | net_name = Unet_Separate_3 19 | base_feature_number = 24 20 | compress_feature_number = 4 21 | drop_rate = 0.5 22 | with_bn = False 23 | depth = False 24 | dilation = 1 25 | slice_margin = 3 26 | class_num = 23 27 | input_channel = 1 28 | 29 | 30 | [testing] 31 | load_weight = True 32 | model_path = weights_center_crop/multi_thresh_2/Unet_Separate_3/Unet_Separate_3_24_atmexp_0.5_0.799.pkl -------------------------------------------------------------------------------- /config/subtest_4.txt: -------------------------------------------------------------------------------- 1 | [data] 2 | data_root = /lyc/MICCAI-19-StructSeg/HaN_OAR_center_crop 3 | img_name = crop_data_multi_thresh_3.nii.gz 4 | label_name = crop_label.nii.gz 5 | seg_name = subseg_4.nii.gz 6 | label_exist_name = label_exist.npy 7 | batch_size = 4 8 | class_num = 23 9 | random_scale = False 10 | random_rotate = False 11 | subdata_shape = [2, 120, 120] 12 | test_data_shape = [2, 256, 256] 13 | output_feature = False 14 | overlap_num = 8 15 | 16 | 17 | [network] 18 | net_type = 
Unet_Separate_3 19 | net_name = Unet_Separate_3 20 | base_feature_number = 24 21 | compress_feature_number = 4 22 | drop_rate = 0.5 23 | with_bn = False 24 | depth = False 25 | dilation = 1 26 | slice_margin = 3 27 | class_num = 23 28 | input_channel = 1 29 | 30 | 31 | [testing] 32 | load_weight = True 33 | model_path = weights_center_crop/multi_thresh_3/Unet_Separate_3/Unet_Separate_3_24_atmexp_1_0.794.pkl -------------------------------------------------------------------------------- /config/subtest_5.txt: -------------------------------------------------------------------------------- 1 | [data] 2 | data_root = /lyc/MICCAI-19-StructSeg/HaN_OAR_center_crop 3 | img_name = crop_data_multi_thresh_3.nii.gz 4 | label_name = crop_label.nii.gz 5 | seg_name = subseg_5.nii.gz 6 | label_exist_name = label_exist.npy 7 | batch_size = 4 8 | class_num = 23 9 | random_scale = False 10 | random_rotate = False 11 | subdata_shape = [2, 120, 120] 12 | test_data_shape = [2, 256, 256] 13 | output_feature = False 14 | overlap_num = 8 15 | 16 | 17 | [network] 18 | net_type = Unet_Separate_3 19 | net_name = Unet_Separate_3 20 | base_feature_number = 24 21 | compress_feature_number = 4 22 | drop_rate = 0.5 23 | with_bn = False 24 | depth = False 25 | dilation = 1 26 | slice_margin = 3 27 | class_num = 23 28 | input_channel = 1 29 | 30 | 31 | [testing] 32 | load_weight = True 33 | model_path = weights_center_crop/multi_thresh_3/Unet_Separate_3/Unet_Separate_3_24_atmexp_0.5_0.798.pkl -------------------------------------------------------------------------------- /config/train.txt: -------------------------------------------------------------------------------- 1 | [data] 2 | net_mode = Pnet 3 | data_root = /lyc/Head-Neck/MICCAI-19-StructSeg/HaN_OAR_center_crop 4 | img_name = crop_data_multi_thresh_1.nii.gz 5 | label_name = crop_label.nii.gz 6 | label_exist_name = label_exist.npy 7 | batch_size = 4 8 | random_scale = False 9 | random_rotate = False 10 | subdata_shape = [16, 120, 120] 11 | 
sublabel_shape = [16, 120, 120] 12 | test_data_shape = [16, 256, 256] 13 | test_label_shape = [16, 256, 256] 14 | label_convert_source = [0, 1, 2, 3, 4] 15 | label_convert_target = [0, 1, 1, 1, 1] 16 | zoom = False 17 | zoom_factor = [0, 1, 1] 18 | class_num = 23 19 | K_folder = 5 20 | I_folder = 1 21 | output_feature = False 22 | overlap_num = 8 23 | 24 | [network] 25 | net_type = Unet_Separate 26 | net_name = Unet_Separate 27 | base_feature_number = 24 28 | compress_feature_number = 4 29 | drop_rate = 0.5 30 | dilation = 1 31 | with_bn = False 32 | depth = False 33 | slice_margin = 3 34 | class_num = 23 35 | input_channel = 1 36 | 37 | 38 | [training] 39 | load_weight = False 40 | model_path = weights_center_crop/multi_thresh_1/Unet_Separate_4/Unet_Separate_4_24_ath_exp_0.5_0.785.pkl 41 | learning_rate = 1e-3 42 | decay = 1e-8 43 | maximal_epoch = 400 44 | snapshot_epoch = 10 45 | start_iteration = 0 46 | train_step = 100 47 | test_step = 100 48 | print_step = 10 49 | model_pre_trained = 50 | model_save_prefix = weights_center_crop/multi_thresh_1/Unet_Separate_3/Unet_Separate_3_24_ath_exp_0.5_sag 51 | best_dice = 0 52 | -------------------------------------------------------------------------------- /data_process/Preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import numpy as np 4 | from scipy import ndimage 5 | from data_process_func import * 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | data_root = '/lyc/MICCAI-19-StructSeg/HaN_OAR_center_crop' 10 | filename_list = ['data.nii.gz', 'label.nii.gz'] 11 | savename_list = ['crop_data.nii.gz', 'crop_label.nii.gz'] 12 | modelist = [ 'valid'] 13 | scale_num = [16, 16, 16] 14 | save_as_nifty = True 15 | respacing = False 16 | r = 128 17 | thresh_lis = [-500, -100, 400, 1500] 18 | norm_lis = [0, 0.1, 0.8, 1] 19 | normalize = img_multi_thresh_normalized 20 | 21 | 22 | for mode in modelist: 23 | filelist 
=os.listdir(os.path.join(data_root, mode)) 24 | filenum = len(filelist) 25 | for ii in range(filenum): 26 | data_path = os.path.join(data_root, mode, filelist[ii], filename_list[0]) 27 | data_crop_norm_save_path = os.path.join(data_root, mode, filelist[ii], savename_list[0]) 28 | 29 | label_path = os.path.join(data_root, mode, filelist[ii], filename_list[1]) 30 | label_crop_save_path = os.path.join(data_root, mode, filelist[ii], savename_list[1]) 31 | if respacing: 32 | data = load_and_respacing_nifty_volume_as_array(data_path, target_spacing=[3,1,1], order=2) 33 | label = np.int8(load_and_respacing_nifty_volume_as_array(label_path, mode='label', target_spacing=[3,1,1])) 34 | else: 35 | data = load_nifty_volume_as_array(data_path) 36 | label = np.int8(load_nifty_volume_as_array(label_path)) 37 | 38 | center = data.shape[1] // 2 39 | 40 | data_crop = data[:, center-r:center+r, center-r:center+r] 41 | data_crop_norm = normalize(data_crop, thresh_lis, norm_lis) 42 | label_crop = label[:, center-r:center+r, center-r:center+r] 43 | 44 | 45 | 46 | if save_as_nifty: 47 | save_array_as_nifty_volume(data_crop_norm, data_crop_norm_save_path) 48 | save_array_as_nifty_volume(label_crop, label_crop_save_path) 49 | else: 50 | np.save(data_crop_norm_save_path, data_crop_norm) 51 | np.save(label_crop_save_path, label_crop) 52 | print('成功储存', filelist[ii]) 53 | -------------------------------------------------------------------------------- /data_process/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_process_func import * 2 | -------------------------------------------------------------------------------- /data_process/data_process_func.py: -------------------------------------------------------------------------------- 1 | import os 2 | import nibabel 3 | import numpy as np 4 | import random 5 | from scipy import ndimage 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import mpu.ml 9 | import math 10 | import 
SimpleITK as sitk 11 | def mkdir(path): 12 | """ 13 | 创建path所给文件夹 14 | :param path: 15 | :return: 16 | """ 17 | folder = os.path.exists(path) 18 | 19 | if not folder: # 判断是否存在文件夹如果不存在则创建为文件夹 20 | os.makedirs(path) # makedirs 创建文件时如果路径不存在会创建这个路径 21 | print("--- new folder... ---") 22 | 23 | print("--- OK ---") 24 | 25 | else: 26 | print("--- There is this folder! ---") 27 | 28 | def search_file_in_folder_list(folder_list, file_name): 29 | """ 30 | Find the full filename from a list of folders 31 | inputs: 32 | folder_list: a list of folders 33 | file_name: filename 34 | outputs: 35 | full_file_name: the full filename 36 | """ 37 | file_exist = False 38 | for folder in folder_list: 39 | full_file_name = os.path.join(folder, file_name) 40 | if(os.path.isfile(full_file_name)): 41 | file_exist = True 42 | break 43 | if(file_exist == False): 44 | raise ValueError('{0:} is not found in {1:}'.format(file_name, folder)) 45 | return full_file_name 46 | 47 | def show_numpy(file_path, file_name, nrows, ncols, img_slip=1, dim=2): 48 | """ 49 | show numpy image ,shape = h*w*l or w*l 50 | :param file_path: e.g '/lyc/RTData/123.npy' 51 | :param file_name: list,contains the file name wanted 52 | :param mode: '2D'or'3D' 53 | """ 54 | img_list = [np.load(file_path+'/'+file) for file in file_name] 55 | f, plots = plt.subplots(nrows, ncols, figsize=(60, 60)) 56 | if dim == 2: 57 | for i in range(len(img_list)): 58 | assert len(img_list[i].shape) == 2 59 | plots[divmod(i, nrows)[0], divmod(i, nrows)[1]].imshow(img_list[i]) 60 | plots[divmod(i, nrows)[0], divmod(i, nrows)[1]].set_title(file_name[i]) 61 | plt.show() 62 | 63 | elif dim == 3: 64 | for ii in np.arange(0, img_list[0].shape[0], img_slip): 65 | for iii in range(len(img_list)): 66 | assert len(img_list[iii].shape) == 2 67 | plots[divmod(iii, nrows)[0], divmod(iii, nrows)[1]].imshow(img_list[iii][ii]) 68 | plots[divmod(iii, nrows)[0], divmod(iii, nrows)[1]].set_title(file_name[iii][ii]) 69 | plt.show() 70 | 71 | 72 | 73 | 74 | 
def save_array_as_nifty_volume(data, filename, transpose=True,pixel_spacing=[1,1,3]): 75 | """ 76 | save a numpy array as nifty image 77 | inputs: 78 | data: a numpy array with shape [Channel, Depth, Height, Width] 79 | filename: the ouput file name 80 | outputs: None 81 | """ 82 | if transpose: 83 | data = data.transpose(2, 1, 0) 84 | img = nibabel.Nifti1Image(data, None) 85 | img.header.set_zooms(pixel_spacing) 86 | nibabel.save(img, filename) 87 | 88 | def save_list_as_nifty_volume(data, filename, pixel_spacing=[1,1,3]): 89 | """ 90 | save a list as nifty image 91 | inputs: 92 | data: a list contains array [Channel, Depth, Height, Width] 93 | filename: the output file name 94 | :return: 95 | """ 96 | data = np.asarray(data) 97 | img = nibabel.Nifti1Image(data) 98 | img.header.set_zooms(pixel_spacing) 99 | nibabel.save(img, filename) 100 | 101 | def save_list_as_array_volume(data, filename): 102 | """ 103 | save a list as numpy array 104 | :param data: list contains array 105 | :param filename: where to save 106 | :return: 107 | """ 108 | data = np.asarray(data, dtype=np.float16) 109 | np.save(filename, data) 110 | 111 | def itensity_normalize_one_volume(volume): 112 | """ 113 | normalize the itensity of an nd volume based on the mean and std of nonzeor region 114 | inputs: 115 | volume: the input nd volume 116 | outputs: 117 | out: the normalized nd volume 118 | """ 119 | 120 | pixels = volume[volume > 0] 121 | mean = pixels.mean() 122 | std = pixels.std() 123 | out = (volume - mean)/std 124 | out_random = np.random.normal(0, 1, size = volume.shape) 125 | out[volume == 0] = out_random[volume == 0] 126 | return out 127 | 128 | def convert_label(in_volume, label_convert_source, label_convert_target): 129 | """ 130 | convert the label value in a volume 131 | inputs: 132 | in_volume: input nd volume with label set label_convert_source 133 | label_convert_source: a list of integers denoting input labels, e.g., [0, 1, 2, 4] 134 | label_convert_target: a list of 
integers denoting output labels, e.g.,[0, 1, 2, 3] 135 | outputs: 136 | out_volume: the output nd volume with label set label_convert_target 137 | """ 138 | mask_volume = np.zeros_like(in_volume) 139 | convert_volume = np.zeros_like(in_volume) 140 | for i in range(len(label_convert_source)): 141 | source_lab = label_convert_source[i] 142 | target_lab = label_convert_target[i] 143 | if(source_lab != target_lab): 144 | temp_source = np.asarray(in_volume == source_lab) 145 | temp_target = target_lab * temp_source 146 | mask_volume = mask_volume + temp_source 147 | convert_volume = convert_volume + temp_target 148 | out_volume = in_volume * 1 149 | out_volume[mask_volume>0] = convert_volume[mask_volume>0] 150 | return out_volume 151 | 152 | def fill_array(array, divisor): 153 | """ 154 | 由于下采样操作,需要对输入图像进行填充,使其满足采样比例的整数倍 155 | :param array: np array: depth*length*height 156 | :param divisor: The shape of the output file can be divided by divisor 157 | :return: 158 | """ 159 | shape = array.shape 160 | pad_num = [[0, divisor[i]-shape[i] % divisor[i]] for i in range(len(divisor)) ] 161 | pad_array = np.pad(array, pad_num, 'constant') 162 | return pad_array 163 | 164 | def get_random_roi_sampling_center(input_shape, output_shape, sample_mode, bounding_box = None): 165 | """ 166 | get a random coordinate representing the center of a roi for sampling 167 | inputs: 168 | input_shape: the shape of sampled volume 169 | output_shape: the desired roi shape 170 | sample_mode: 'full': the entire roi should be inside the input volume 171 | 'valid': only the roi centre should be inside the input volume 172 | bounding_box: the bounding box which the roi center should be limited to 173 | outputs: 174 | center: the output center coordinate of a roi 175 | """ 176 | center = [] 177 | for i in range(len(input_shape)): 178 | if(sample_mode[i] == 'full'): # 不同轴向的裁取方式不同,z轴为full,裁剪范围需全部在输入中 179 | if(bounding_box): 180 | x0 = bounding_box[i*2]; x1 = bounding_box[i*2 + 1] 181 | else: 182 | x0 = 
0; x1 = input_shape[i] 183 | else: 184 | if(bounding_box): 185 | x0 = bounding_box[i*2] + int(output_shape[i]/2) 186 | x1 = bounding_box[i*2+1] - int(output_shape[i]/2) 187 | else: 188 | x0 = int(output_shape[i]/2) 189 | x1 = input_shape[i] - x0 190 | if(x1 <= x0): # 如果输出大于输入,后期会随机填充或0填充 191 | centeri = int((x0 + x1)/2) 192 | else: 193 | centeri = random.randint(x0, x1) # 如输出小于输入,可在[x0,l-x0]范围内任选点 194 | center.append(centeri) 195 | return center 196 | 197 | def get_bound_coordinate(file, pad=[0,0,0]): 198 | ''' 199 | 输出array非0区域的各维度上下界坐标+-pad 200 | :param file: groundtruth图, 201 | :param pad: 各维度扩充的大小 202 | :return: bound: [min,max] 203 | ''' 204 | file_size = file.shape 205 | nonzeropoint = np.asarray(np.nonzero(file)) # 得到非0点坐标,输出为一个3*n的array,3代表3个维度,n代表n个非0点在对应维度上的坐标 206 | maxpoint = np.max(nonzeropoint, 1).tolist() 207 | minpoint = np.min(nonzeropoint, 1).tolist() 208 | for i in range(len(pad)): 209 | maxpoint[i] = min(maxpoint[i]+pad[i], file_size[i]) 210 | minpoint[i] = max(minpoint[i]-pad[i], 0) 211 | return [minpoint, maxpoint] 212 | 213 | def labeltrans(labelpair, file): 214 | ''' 215 | :param labelpair: labelpair list 216 | :param file: np array 217 | :return: 218 | ''' 219 | newfile = np.zeros_like(file) 220 | for label in labelpair: 221 | newfile[np.where(file == label[0])] = label[1] 222 | return newfile 223 | 224 | 225 | def load_nifty_volume_as_array(filename, transpose=True, return_spacing=False): 226 | """ 227 | load nifty image into numpy array, and transpose it based on the [z,y,x] axis order 228 | The output array shape is like [Depth, Height, Width] 229 | inputs: 230 | filename: the input file name, should be *.nii or *.nii.gz 231 | outputs: 232 | data: a numpy data array 233 | """ 234 | img = nibabel.load(filename) 235 | data = img.get_data() 236 | if transpose: 237 | data = data.transpose(2, 1, 0) 238 | if return_spacing: 239 | spacing = img.header.get_zooms() 240 | return data, spacing 241 | else: 242 | return data 243 | 244 | 245 | def 
load_origin_nifty_volume_as_array(filename): 246 | """ 247 | load nifty image into numpy array, and transpose it based on the [z,y,x] axis order 248 | The output array shape is like [Depth, Height, Width] 249 | inputs: 250 | filename: the input file name, should be *.nii or *.nii.gz 251 | outputs: 252 | data: a numpy data array 253 | zoomfactor: 254 | """ 255 | img = nibabel.load(filename) 256 | pixelspacing = img.header.get_zooms() 257 | zoomfactor = list(pixelspacing) 258 | zoomfactor.reverse() 259 | data = img.get_data() 260 | data = data.transpose(2, 1, 0) 261 | 262 | return data, zoomfactor 263 | 264 | def load_and_respacing_nifty_volume_as_array(filename, mode='img', target_spacing=1, order=3): 265 | img = nibabel.load(filename) 266 | pixelspacing = list(img.header.get_zooms()) 267 | pixelspacing.reverse() 268 | zoomfactor = list(np.array(pixelspacing)/np.array(target_spacing)) 269 | data = img.get_data() 270 | data = data.transpose(2, 1, 0) 271 | 272 | if mode !='img': 273 | order=0 274 | data = ndimage.zoom(data, zoom=zoomfactor, order=order) 275 | 276 | return data 277 | 278 | def img_normalized(file, upthresh=0, downthresh=0, norm=True, thresh=True): 279 | """ 280 | :param file: np array 281 | :param upthresh: 282 | :param downthresh: 283 | :param norm: norm or not 284 | :return: 285 | """ 286 | if thresh: 287 | assert upthresh > downthresh 288 | file[np.where(file > upthresh)] = upthresh 289 | file[np.where(file < downthresh)] = downthresh 290 | if norm: 291 | file = (file-downthresh)/(upthresh-downthresh) 292 | return file 293 | 294 | def img_multi_thresh_normalized(file, thresh_lis=[0], norm_lis=[0]): 295 | """ 296 | :param file: np array 297 | :param upthresh: 298 | :param downthresh: 299 | :param norm: norm or not 300 | :return: 301 | """ 302 | new_file = np.zeros_like(file).astype(np.float) 303 | 304 | for i in range(1, len(thresh_lis)): 305 | 306 | mask = np.where((file=thresh_lis[i-1])) 307 | k = 
(norm_lis[i]-norm_lis[i-1])/(thresh_lis[i]-thresh_lis[i-1]) 308 | b = norm_lis[i-1] 309 | new_file[mask] = file[mask]-thresh_lis[i-1] 310 | new_file[mask] = k*new_file[mask]+b 311 | new_file[np.where(file >= thresh_lis[-1])] = norm_lis[-1] 312 | return new_file 313 | 314 | 315 | def transpose_volumes(volumes, slice_direction): 316 | """ 317 | transpose a list of volumes 318 | inputs: 319 | volumes: a list of nd volumes 320 | slice_direction: 'axial', 'sagittal', or 'coronal' 321 | outputs: 322 | tr_volumes: a list of transposed volumes 323 | """ 324 | if (slice_direction == 'axial'): 325 | tr_volumes = volumes 326 | elif(slice_direction == 'sagittal'): 327 | tr_volumes = [np.transpose(x, (2, 0, 1)) for x in volumes] 328 | elif(slice_direction == 'coronal'): 329 | tr_volumes = [np.transpose(x, (1, 0, 2)) for x in volumes] 330 | else: 331 | print('undefined slice direction:', slice_direction) 332 | tr_volumes = volumes 333 | return tr_volumes 334 | 335 | 336 | def resize_ND_volume_to_given_shape(volume, zoom_factor, order = 3): 337 | """ 338 | resize an nd volume to a given shape 339 | inputs: 340 | volume: the input nd volume, an nd array 341 | out_shape: the desired output shape, a list 342 | order: the order of interpolation 343 | outputs: 344 | out_volume: the reized nd volume with given shape 345 | """ 346 | out_volume = ndimage.interpolation.zoom(volume, zoom_factor, order = order) 347 | return out_volume 348 | 349 | def resize_Multi_label_to_given_shape(volume, zoom_factor,class_number, order = 3): 350 | """ 351 | resize an multi class label to a given shape 352 | :param volume: the input label, an tensor 353 | :param zoom_factor: the zoom fatcor of z,x,y 354 | :param class_number: the number of classes 355 | :param order: the order of the interpolation 356 | :return: shape = zoom_factor*original shape z,x,y 357 | """ 358 | volume_one = convert_to_one_hot(volume, class_number) 359 | volum_one_reshape = [ndimage.interpolation.zoom(volume_one[i+1], zoom_factor, 
order=order) for i in range(class_number-1)] 360 | output = np.zeros_like(volum_one_reshape[0]) 361 | for i in range(class_number-1): 362 | output = np.rint(volum_one_reshape[i])*(i+1)+output 363 | return output 364 | 365 | def convert_to_one_hot(volume, class_number): 366 | ''' 367 | one hot编码 368 | :param volume: label 369 | :param C: class number 370 | :return: 371 | ''' 372 | shape = [class_number]+list(volume.shape) 373 | volume_one = np.eye(class_number)[volume.reshape(-1)].T 374 | volume_one = volume_one.reshape(shape) 375 | return volume_one 376 | 377 | def convert_one_hot_to_multi_class(one_hot,class_num): 378 | """ 379 | input size:1*class_num*h*w*l or class_num*h*w*l 380 | :param one_hot: the one hot coder array 381 | :return: h*w*l 382 | """ 383 | one_hot = one_hot.squeeze() 384 | assert (class_num == one_hot.shape[0]) 385 | img = np.ones(one_hot.shape[1::]) 386 | for i in range(class_num): 387 | img += one_hot[i]*i 388 | return img 389 | def extract_roi_from_volume(volume, in_center, output_shape, fill = 'random'): 390 | """ 391 | extract a roi from a 3d volume 392 | inputs: 393 | volume: the input 3D_train volume 394 | in_center: the center of the roi 395 | output_shape: the size of the roi 396 | fill: 'random' or 'zero', the mode to fill roi region where is outside of the input volume 397 | outputs: 398 | output: the roi volume 399 | """ 400 | input_shape = volume.shape 401 | if(fill == 'random'): 402 | output = np.random.normal(0, 1, size = output_shape) 403 | else: 404 | output = np.zeros(output_shape) 405 | r0max = [int(x/2) for x in output_shape] 406 | r1max = [output_shape[i] - r0max[i] for i in range(len(r0max))] 407 | r0 = [min(r0max[i], in_center[i]) for i in range(len(r0max))] 408 | r1 = [min(r1max[i], input_shape[i] - in_center[i]) for i in range(len(r0max))] 409 | out_center = r0max 410 | 411 | output[np.ix_(range(out_center[0] - r0[0], out_center[0] + r1[0]), 412 | range(out_center[1] - r0[1], out_center[1] + r1[1]), 413 | 
range(out_center[2] - r0[2], out_center[2] + r1[2]))] = \ 414 | volume[np.ix_(range(in_center[0] - r0[0], in_center[0] + r1[0]), 415 | range(in_center[1] - r0[1], in_center[1] + r1[1]), 416 | range(in_center[2] - r0[2], in_center[2] + r1[2]))] 417 | return output 418 | 419 | def set_roi_to_volume(volume, center, sub_volume): 420 | """ 421 | set the content of an roi of a 3d/4d volume to a sub volume 422 | inputs: 423 | volume: the input 3D_train/4D volume 424 | center: the center of the roi 425 | sub_volume: the content of sub volume 426 | outputs: 427 | output_volume: the output 3D_train/4D volume 428 | """ 429 | volume_shape = volume.shape 430 | patch_shape = sub_volume.shape 431 | output_volume = volume 432 | for i in range(len(center)): 433 | if(center[i] >= volume_shape[i]): 434 | return output_volume 435 | r0max = [int(x/2) for x in patch_shape] 436 | r1max = [patch_shape[i] - r0max[i] for i in range(len(r0max))] 437 | r0 = [min(r0max[i], center[i]) for i in range(len(r0max))] 438 | r1 = [min(r1max[i], volume_shape[i] - center[i]) for i in range(len(r0max))] 439 | patch_center = r0max 440 | 441 | if(len(center) == 3): 442 | output_volume[np.ix_(range(center[0] - r0[0], center[0] + r1[0]), 443 | range(center[1] - r0[1], center[1] + r1[1]), 444 | range(center[2] - r0[2], center[2] + r1[2]))] = \ 445 | sub_volume[np.ix_(range(patch_center[0] - r0[0], patch_center[0] + r1[0]), 446 | range(patch_center[1] - r0[1], patch_center[1] + r1[1]), 447 | range(patch_center[2] - r0[2], patch_center[2] + r1[2]))] 448 | elif(len(center) == 4): 449 | output_volume[np.ix_(range(center[0] - r0[0], center[0] + r1[0]), 450 | range(center[1] - r0[1], center[1] + r1[1]), 451 | range(center[2] - r0[2], center[2] + r1[2]), 452 | range(center[3] - r0[3], center[3] + r1[3]))] = \ 453 | sub_volume[np.ix_(range(patch_center[0] - r0[0], patch_center[0] + r1[0]), 454 | range(patch_center[1] - r0[1], patch_center[1] + r1[1]), 455 | range(patch_center[2] - r0[2], patch_center[2] + r1[2]), 456 
| range(patch_center[3] - r0[3], patch_center[3] + r1[3]))] 457 | else: 458 | raise ValueError("array dimension should be 3 or 4") 459 | return output_volume 460 | 461 | 462 | def get_roi(volume, margin): 463 | """ 464 | get the roi bounding box of a 3D_train volume 465 | inputs: 466 | volume: the input 3D_train volume 467 | margin: an integer margin along each axis 468 | output: 469 | [mind, maxd, minh, maxh, minw, maxw]: a list of lower and upper bound along each dimension 470 | """ 471 | [d_idxes, h_idxes, w_idxes] = np.nonzero(volume) 472 | [D, H, W] = volume.shape 473 | mind = max(d_idxes.min() - margin, 0) 474 | maxd = min(d_idxes.max() + margin, D) 475 | minh = max(h_idxes.min() - margin, 0) 476 | maxh = min(h_idxes.max() + margin, H) 477 | minw = max(w_idxes.min() - margin, 0) 478 | maxw = min(w_idxes.max() + margin, W) 479 | return [mind, maxd, minh, maxh, minw, maxw] 480 | 481 | def get_largest_two_component(img, print_info = False, threshold = None): 482 | """ 483 | Get the largest two components of a binary volume 484 | inputs: 485 | img: the input 3D_train volume 486 | threshold: a size threshold 487 | outputs: 488 | out_img: the output volume 489 | """ 490 | s = ndimage.generate_binary_structure(3,2) # iterate structure 491 | labeled_array, numpatches = ndimage.label(img,s) # labeling 492 | sizes = ndimage.sum(img,labeled_array,range(1,numpatches+1)) 493 | sizes_list = [sizes[i] for i in range(len(sizes))] 494 | sizes_list.sort() 495 | if(print_info): 496 | print('component size', sizes_list) 497 | if(len(sizes) == 1): 498 | out_img = img 499 | else: 500 | if(threshold): 501 | out_img = np.zeros_like(img) 502 | for temp_size in sizes_list: 503 | if(temp_size > threshold): 504 | temp_lab = np.where(sizes == temp_size)[0] + 1 505 | temp_cmp = labeled_array == temp_lab 506 | out_img = (out_img + temp_cmp) > 0 507 | return out_img 508 | else: 509 | max_size1 = sizes_list[-1] 510 | max_size2 = sizes_list[-2] 511 | max_label1 = np.where(sizes == 
max_size1)[0] + 1 512 | max_label2 = np.where(sizes == max_size2)[0] + 1 513 | component1 = labeled_array == max_label1 514 | component2 = labeled_array == max_label2 515 | if(max_size2*10 > max_size1): 516 | component1 = (component1 + component2) > 0 517 | out_img = component1 518 | return out_img 519 | 520 | def fill_holes(img): 521 | """ 522 | filling small holes of a binary volume with morphological operations 523 | """ 524 | neg = 1 - img 525 | s = ndimage.generate_binary_structure(3,1) # iterate structure 526 | labeled_array, numpatches = ndimage.label(neg,s) # labeling 527 | sizes = ndimage.sum(neg,labeled_array,range(1,numpatches+1)) 528 | sizes_list = [sizes[i] for i in range(len(sizes))] 529 | sizes_list.sort() 530 | max_size = sizes_list[-1] 531 | max_label = np.where(sizes == max_size)[0] + 1 532 | component = labeled_array == max_label 533 | return 1 - component 534 | 535 | 536 | def remove_external_core(lab_main, lab_ext): 537 | """ 538 | remove the core region that is outside of whole tumor 539 | """ 540 | 541 | # for each component of lab_ext, compute the overlap with lab_main 542 | s = ndimage.generate_binary_structure(3,2) # iterate structure 543 | labeled_array, numpatches = ndimage.label(lab_ext,s) # labeling 544 | sizes = ndimage.sum(lab_ext,labeled_array,range(1,numpatches+1)) 545 | sizes_list = [sizes[i] for i in range(len(sizes))] 546 | new_lab_ext = np.zeros_like(lab_ext) 547 | for i in range(len(sizes)): 548 | sizei = sizes_list[i] 549 | labeli = np.where(sizes == sizei)[0] + 1 550 | componenti = labeled_array == labeli 551 | overlap = componenti * lab_main 552 | if((overlap.sum()+ 0.0)/sizei >= 0.5): 553 | new_lab_ext = np.maximum(new_lab_ext, componenti) 554 | return new_lab_ext 555 | 556 | def binary_dice3d(s,g): 557 | """ 558 | dice score of 3d binary volumes 559 | inputs: 560 | s: segmentation volume 561 | g: ground truth volume 562 | outputs: 563 | dice: the dice score 564 | """ 565 | assert(len(s.shape)==3) 566 | [Ds, Hs, Ws] = 
s.shape 567 | [Dg, Hg, Wg] = g.shape 568 | assert(Ds==Dg and Hs==Hg and Ws==Wg) 569 | prod = np.multiply(s, g) 570 | s0 = prod.sum() 571 | s1 = s.sum() 572 | s2 = g.sum() 573 | dice = 2.0*s0/(s1 + s2 + 1e-10) 574 | return dice 575 | 576 | def make_overlap_weight(overlap_num): 577 | """ 578 | 考虑到网络感受野可能超过图像厚度,故子图边界预测结果相对不可信。 579 | 在叠加时应考虑加权,对每张图中心区域预测结果给予高权重,边界低权重。 580 | :return: 581 | """ 582 | 583 | if overlap_num%2==0: 584 | weight = [1 / (1+abs(i - overlap_num//2-0.5)) for i in range(1, overlap_num+1)] 585 | else: 586 | weight = [1 / (1 + abs(i - overlap_num // 2)) for i in range(1, overlap_num+1)] 587 | 588 | return weight 589 | 590 | def zoom_data(file, mode='img', zoom_factor=[1,1,1], class_number=0): 591 | """ 592 | 对数据进行插值并储存, 593 | :param data_root: 数据所在上层目录 594 | :param save_root: 存储的顶层目录 595 | :zoom_factor: 缩放倍数 596 | :return: 597 | """ 598 | 599 | if mode =='label': 600 | intfile = np.int16(file) 601 | #zoom_file = np.int16(resize_Multi_label_to_given_shape(intfile, zoom_factor, class_number, order=2)) 602 | zoom_file = ndimage.interpolation.zoom(file, zoom_factor, order=0) 603 | elif mode == 'img': 604 | zoom_file = ndimage.interpolation.zoom(file, zoom_factor, order = 3) 605 | else: 606 | KeyError('please choose img or label mode') 607 | return zoom_file -------------------------------------------------------------------------------- /data_process/label_transfer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import nibabel 4 | from data_process_func import * 5 | 6 | ori_label = [14, 15, 16, 17, 18, 19, 21, 22] 7 | data_root = '/lyc/MICCAI-19-StructSeg/HaN_OAR_normal_spacing' 8 | data_mode = ['train', 'valid'] 9 | 10 | 11 | ori_labelfile_name = 'crop_label.nii.gz' 12 | new_labelfile_name = 'part_crop_label.nii.gz' 13 | 14 | 15 | for mode in data_mode: 16 | cur_data_root = os.path.join(data_root, mode) 17 | file_list = os.listdir(cur_data_root) 18 | 19 | for file in 
file_list: 20 | cur_label_path = os.path.join(cur_data_root, file, ori_labelfile_name) 21 | new_label_path = os.path.join(cur_data_root, file, new_labelfile_name) 22 | 23 | cur_label, spacing = load_nifty_volume_as_array(cur_label_path, return_spacing=True) 24 | new_label = np.zeros_like(cur_label) 25 | for i in range(len(ori_label)): 26 | mask = np.where(cur_label==ori_label[i]) 27 | new_label[mask] = i+1 28 | save_array_as_nifty_volume(new_label, new_label_path, spacing) 29 | print('successfully proceed {0:}'.format(file)) -------------------------------------------------------------------------------- /data_process/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import PIL 4 | import numbers 5 | import numpy as np 6 | import torch.nn as nn 7 | import collections 8 | import matplotlib.pyplot as plt 9 | import torchvision.transforms as ts 10 | import torchvision.transforms.functional as TF 11 | from PIL import Image, ImageDraw 12 | from scipy import ndimage 13 | 14 | _pil_interpolation_to_str = { 15 | Image.NEAREST: 'PIL.Image.NEAREST', 16 | Image.BILINEAR: 'PIL.Image.BILINEAR', 17 | Image.BICUBIC: 'PIL.Image.BICUBIC', 18 | Image.LANCZOS: 'PIL.Image.LANCZOS', 19 | } 20 | 21 | 22 | def IVDM3Seg_transform(): 23 | 24 | transform = ts.Compose([ts.ToTensor()]) 25 | 26 | return transform 27 | 28 | 29 | def FetalBrain_transform(sample, train_type): 30 | image, label = Image.fromarray(np.uint8(sample['image'] * 255), mode='L'), \ 31 | Image.fromarray(np.uint8(sample['label']), mode='L') 32 | 33 | # print(image.shape, label.shape) 34 | # print(np.uint8(sample['label']).max(), np.uint8(sample['label']).min()) 35 | if train_type == 'train': 36 | # image, label = randomcrop(size=196)(image, label) 37 | image, label = randomflip_rotate(image, label, p=0.5, degrees=30) 38 | # else: 39 | # image, label = randomcrop(size=256)(image, label) 40 | 41 | image = ts.Compose([ts.ToTensor(), 42 | 
ts.Normalize(mean=(0.485, 0.456, 0.5), std=(0.225, 0.225, 0.5))])(image) 43 | # label = ts.ToTensor()(label) 44 | label = torch.from_numpy(np.asarray(label)).unsqueeze(dim=0) 45 | # print(label.max(), label.min()) 46 | 47 | return {'image': image, 'label': label} 48 | 49 | 50 | def Lung2dseg_transform(sample, train_type): 51 | image, label = sample['image'], sample['label'] 52 | image = 255*((image - np.min(image)) / (np.max(image) - np.min(image))) 53 | 54 | image, label = Image.fromarray(np.uint8(image), mode='L'), Image.fromarray(label, mode='L') 55 | 56 | # print(image.shape, label.shape) 57 | if train_type == 'train': 58 | # image, label = randomcrop(size=256)(image, label) 59 | image, label = randomflip_rotate(image, label, p=0.5, degrees=30) 60 | # else: 61 | # image, label = randomcrop(size=256)(image, label) 62 | 63 | image = ts.Compose([ts.ToTensor(), 64 | ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])(image) 65 | label = ts.ToTensor()(label) 66 | 67 | return {'image': image, 'label': label} 68 | 69 | 70 | def ISIC2018_transform(sample, train_type): 71 | image, label = Image.fromarray(np.uint8(sample['image']*255), mode='RGB'),\ 72 | Image.fromarray(np.uint8(sample['label']*255), mode='L') 73 | 74 | if train_type == 'train': 75 | image, label = randomcrop(size=(224, 300))(image, label) 76 | image, label = randomflip_rotate(image, label, p=0.5, degrees=30) 77 | else: 78 | image, label = resize(size=(224, 300))(image, label) 79 | 80 | image = ts.Compose([ts.ToTensor(), 81 | ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])(image) 82 | label = ts.ToTensor()(label) 83 | 84 | return {'image': image, 'label': label} 85 | 86 | 87 | # these are founctional function for transform 88 | def randomflip_rotate(img, lab, p=0.5, degrees=0): 89 | if random.random() < p: 90 | if isinstance(degrees, numbers.Number): 91 | if degrees < 0: 92 | raise ValueError("If degrees is a single number, it must be positive.") 93 | degrees = (-degrees, degrees) 94 | 
else: 95 | if len(degrees) != 2: 96 | raise ValueError("If degrees is a sequence, it must be of len 2.") 97 | degrees = degrees 98 | angle = random.uniform(degrees[0], degrees[1]) 99 | img = TF.rotate(img, angle) 100 | lab = TF.rotate(lab, angle) 101 | 102 | return img, lab 103 | 104 | def random_rotate(img, lab, p=0.5, axes=(0,1), degrees=0): 105 | if random.random() < p: 106 | if isinstance(degrees, numbers.Number): 107 | if degrees < 0: 108 | raise ValueError("If degrees is a single number, it must be positive.") 109 | degrees = (-degrees, degrees) 110 | else: 111 | if len(degrees) != 2: 112 | raise ValueError("If degrees is a sequence, it must be of len 2.") 113 | degrees = degrees 114 | if len(axes) !=2: 115 | axes = random.sample(axes, 2) 116 | else: 117 | axes = axes 118 | angle = random.uniform(degrees[0], degrees[1]) 119 | img = ndimage.rotate(img, angle, axes=axes, order=0, reshape=False) 120 | lab = ndimage.rotate(lab, angle, axes=axes, order=0, reshape=False) 121 | 122 | return img, lab 123 | 124 | 125 | class randomcrop(object): 126 | """Crop the given PIL Image and mask at a random location. 127 | 128 | Args: 129 | size (sequence or int): Desired output size of the crop. If size is an 130 | int instead of sequence like (h, w), a square crop (size, size) is 131 | made. 132 | padding (int or sequence, optional): Optional padding on each border 133 | of the image. Default is 0, i.e no padding. If a sequence of length 134 | 4 is provided, it is used to pad left, top, right, bottom borders 135 | respectively. 136 | pad_if_needed (boolean): It will pad the image if smaller than the 137 | desired size to avoid raising an exception. 
138 | """ 139 | 140 | def __init__(self, size, padding=0, pad_if_needed=False): 141 | if isinstance(size, numbers.Number): 142 | self.size = (int(size), int(size)) 143 | else: 144 | self.size = size 145 | self.padding = padding 146 | self.pad_if_needed = pad_if_needed 147 | 148 | @staticmethod 149 | def get_params(img, output_size): 150 | """Get parameters for ``crop`` for a random crop. 151 | 152 | Args: 153 | img (PIL Image): Image to be cropped. 154 | output_size (tuple): Expected output size of the crop. 155 | 156 | Returns: 157 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 158 | """ 159 | w, h = img.size 160 | th, tw = output_size 161 | if w == tw and h == th: 162 | return 0, 0, h, w 163 | 164 | i = random.randint(0, h - th) 165 | j = random.randint(0, w - tw) 166 | return i, j, th, tw 167 | 168 | def __call__(self, img, lab): 169 | """ 170 | Args: 171 | img (PIL Image): Image to be cropped. 172 | lab (PIL Image): Image to be cropped. 173 | 174 | Returns: 175 | PIL Image: Cropped image and mask. 176 | """ 177 | if self.padding > 0: 178 | img = TF.pad(img, self.padding) 179 | lab = TF.pad(lab, self.padding) 180 | 181 | # pad the width if needed 182 | if self.pad_if_needed and img.size[0] < self.size[1]: 183 | img = TF.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0)) 184 | lab = TF.pad(lab, (int((1 + self.size[1] - lab.size[0]) / 2), 0)) 185 | # pad the height if needed 186 | if self.pad_if_needed and img.size[1] < self.size[0]: 187 | img = TF.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2))) 188 | lab = TF.pad(lab, (0, int((1 + self.size[0] - lab.size[1]) / 2))) 189 | 190 | i, j, h, w = self.get_params(img, self.size) 191 | 192 | return TF.crop(img, i, j, h, w), TF.crop(lab, i, j, h, w) 193 | 194 | def __repr__(self): 195 | return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) 196 | 197 | 198 | class resize(object): 199 | """Resize the input PIL Image and mask to the given size. 
200 | 201 | Args: 202 | size (sequence or int): Desired output size. If size is a sequence like 203 | (h, w), output size will be matched to this. If size is an int, 204 | smaller edge of the image will be matched to this number. 205 | i.e, if height > width, then image will be rescaled to 206 | (size * height / width, size) 207 | interpolation (int, optional): Desired interpolation. Default is 208 | ``PIL.Image.BILINEAR`` 209 | """ 210 | 211 | def __init__(self, size, interpolation=Image.BILINEAR): 212 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 213 | self.size = size 214 | self.interpolation = interpolation 215 | 216 | def __call__(self, img, lab): 217 | """ 218 | Args: 219 | img (PIL Image): Image to be scaled. 220 | lab (PIL Image): Image to be scaled. 221 | 222 | Returns: 223 | PIL Image: Rescaled image and mask. 224 | """ 225 | return TF.resize(img, self.size, self.interpolation), TF.resize(lab, self.size, self.interpolation) 226 | 227 | def __repr__(self): 228 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 229 | return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str) 230 | 231 | 232 | def itensity_normalize(volume): 233 | """ 234 | normalize the itensity of an nd volume based on the mean and std of nonzeor region 235 | inputs: 236 | volume: the input nd volume 237 | outputs: 238 | out: the normalized n d volume 239 | """ 240 | 241 | # pixels = volume[volume > 0] 242 | mean = volume.mean() 243 | std = volume.std() 244 | out = (volume - mean) / std 245 | out_random = np.random.normal(0, 1, size=volume.shape) 246 | out[volume == 0] = out_random[volume == 0] 247 | 248 | return out -------------------------------------------------------------------------------- /error_rate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import nibabel 4 | 5 | def 
load_nifty_volume_as_array(filename): 6 | # input shape [W, H, D] 7 | # output shape [D, H, W] 8 | seg = nibabel.load(filename) 9 | data = seg.get_data() 10 | data = np.transpose(data, [2,1,0]) 11 | return data 12 | 13 | def save_array_as_nifty_volume(data, filename): 14 | # numpy data shape [D, H, W] 15 | # nifty image shape [W, H, W] 16 | data = np.transpose(data, [2,1,0]) 17 | seg = nibabel.Nifti1Image(data, np.eye(4)) 18 | nibabel.save(seg, filename) 19 | 20 | patientroot = '/lyc/Head-Neck/MICCAI-19-StructSeg/HaN_OAR_center_crop/valid/' #预测结果-+ 21 | uncertaininterval = [0,0.4505,0.6365,0.6931,1.011] 22 | data_name = ['enseg.nii.gz','uncertain.nii.gz', 'crop_label.nii.gz'] 23 | error_list = np.zeros([10, 5]) 24 | patient_number =0 25 | for patient in os.listdir(patientroot): 26 | print('segname is ',patient) 27 | ''' 28 | 根据原label得到新label/seg的名称,像素间距与储存路径 29 | ''' 30 | segname = data_name[0] 31 | labelname = data_name[2] 32 | uncertainname = data_name[1] 33 | 34 | segpath = os.path.join(patientroot, patient, segname) 35 | labelpath = os.path.join(patientroot,patient, labelname) 36 | uncertainpath = os.path.join(patientroot,patient, uncertainname) 37 | 38 | seg = load_nifty_volume_as_array(segpath) 39 | label = load_nifty_volume_as_array(labelpath) 40 | uncertainty = load_nifty_volume_as_array(uncertainpath) 41 | errormap = np.zeros_like(seg) 42 | errormap[np.where(label!=seg)]=1 43 | print(np.sum(errormap)) 44 | for i in range(len(uncertaininterval)): 45 | uncertainmap = np.zeros_like(seg) 46 | uncertainty_index = np.where(uncertainty==uncertaininterval[i]) 47 | uncertainmap[uncertainty_index]=1 48 | errorsum = np.sum(uncertainmap*errormap) 49 | error_list[patient_number, i] = errorsum/len(uncertainty_index[0]) 50 | print(error_list) 51 | patient_number+=1 52 | -------------------------------------------------------------------------------- /fig/1: -------------------------------------------------------------------------------- 1 | 1 2 | 
-------------------------------------------------------------------------------- /fig/summary.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/SepNet/934f71ba53a3c1906c5ff3558c3805f4ce790c9a/fig/summary.jpg -------------------------------------------------------------------------------- /models/._data_loader.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/SepNet/934f71ba53a3c1906c5ff3558c3805f4ce790c9a/models/._data_loader.py -------------------------------------------------------------------------------- /models/._data_process.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/SepNet/934f71ba53a3c1906c5ff3558c3805f4ce790c9a/models/._data_process.py -------------------------------------------------------------------------------- /models/LYC_data_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import, print_function 3 | 4 | from data_process.data_process_func import * 5 | from data_process.transform import * 6 | from scipy import ndimage 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import Dataset 10 | import os 11 | import os.path 12 | import random 13 | import math 14 | import nibabel 15 | import time 16 | from scipy import ndimage 17 | 18 | 19 | class LYC_dataset(Dataset): 20 | 21 | def __init__(self, config, stage): 22 | """ 23 | 用于分割的数据集 24 | """ 25 | self.config = config 26 | self.stage = stage 27 | self.data_root = config['data_root'] 28 | self.batchsize = self.config['batch_size'] 29 | self.subdatashape = self.config['subdata_shape'] 30 | self.testdatashape = self.config['test_data_shape'] 31 | self.classnum = int(self.config['class_num']) 32 | self.img_name = self.config['img_name'] 33 | 
self.label_name = self.config['label_name'] 34 | self.label_exist_name = self.config['label_exist_name'] 35 | self.patient_name = os.listdir(self.data_root+'/'+self.stage) 36 | self.patient_num = len(self.patient_name) 37 | self.output_feature = self.config['output_feature'] 38 | self.random_rotate = self.config['random_rotate'] 39 | self.random_scale = self.config['random_scale'] 40 | if self.output_feature: 41 | self.feature_name = self.config['feature_name'] 42 | self.overlap_num = self.config['overlap_num'] 43 | 44 | def get_subimage_batch(self): 45 | ''' 46 | just for train!!! 47 | :return: images and labels with subshape 48 | ''' 49 | patientchosen = random.sample(self.patient_name, self.batchsize) 50 | batch = {} 51 | databatch = [] 52 | labelbatch = [] 53 | labelexistbatch = [] 54 | multichannel = isinstance(self.img_name, list) 55 | if multichannel: 56 | chan_num = len(self.img_name) 57 | for i in range(self.batchsize): 58 | datapath = [] 59 | originaldata = [] 60 | labelpath = "{0:}/{1:}/{2:}/{3:}".format(self.data_root, self.stage, patientchosen[i], self.label_name) 61 | for sub_img_name in self.img_name: 62 | subdatapath = "{0:}/{1:}/{2:}/{3:}".format(self.data_root, self.stage, patientchosen[i], sub_img_name) 63 | datapath.append(subdatapath) 64 | if datapath[0].endswith('npy') or datapath[0].endswith('npz'): 65 | for subdatapath in datapath: 66 | originaldata.append(np.load(subdatapath)) 67 | originallabel = np.load(labelpath) 68 | elif datapath[0].endswith('nii') or datapath[0].endswith('nii.gz'): 69 | for subdatapath in datapath: 70 | originaldata=load_nifty_volume_as_array(subdatapath, transpose=False) 71 | originallabel = load_nifty_volume_as_array(labelpath) 72 | else: 73 | ValueError('please input correct file name! 
i.e.".nii" ".npy" ".nii.gz"') 74 | 75 | if self.random_scale: 76 | zoomfactor = random.choice([0.8, 1, 1.2]) 77 | originallabel = ndimage.interpolation.zoom(originallabel, zoomfactor, order = 0) 78 | for ii in range(chan_num): 79 | originaldata[ii] = ndimage.interpolation.zoom(originaldata[ii], zoomfactor, order = 3) 80 | 81 | originaldata = np.asarray(originaldata) 82 | shapemin = [random.randint(0, originaldata.shape[ii+1]-self.subdatashape[ii]) for ii in range(3)] 83 | subdata = originaldata[:, shapemin[0]:shapemin[0]+self.subdatashape[0], shapemin[1]:shapemin[1]+self.subdatashape[1], 84 | shapemin[2]:shapemin[2]+self.subdatashape[2]] 85 | sublabel = originallabel[shapemin[0]:shapemin[0]+self.subdatashape[0], shapemin[1]:shapemin[1]+self.subdatashape[1], 86 | shapemin[2]:shapemin[2]+self.subdatashape[2]] 87 | 88 | if self.random_rotate: 89 | subdata,sublabel = random_rotate(subdata, sublabel, p=0.9, degrees=[15, -15], axes=[0, 1, 2]) 90 | 91 | 92 | databatch.append(subdata) 93 | labelbatch.append(sublabel) 94 | 95 | else: 96 | for i in range(self.batchsize): 97 | datapath = "{0:}/{1:}/{2:}/{3:}".format(self.data_root, self.stage, patientchosen[i], self.img_name) 98 | labelpath = "{0:}/{1:}/{2:}/{3:}".format(self.data_root, self.stage, patientchosen[i], self.label_name) 99 | 100 | if datapath.endswith('npy') or datapath.endswith('npz'): 101 | originaldata = np.load(datapath) 102 | originallabel = np.load(labelpath) 103 | elif datapath.endswith('nii') or datapath.endswith('nii.gz'): 104 | originaldata = load_nifty_volume_as_array(datapath) 105 | originallabel = load_nifty_volume_as_array(labelpath) 106 | else: 107 | ValueError('please input correct file name! 
i.e.".nii" ".npy" ".nii.gz"') 108 | 109 | 110 | if self.random_scale: 111 | zoomfactor = random.choice([0.8,1,1.2]) 112 | originallabel = ndimage.interpolation.zoom(originallabel, zoomfactor, order = 0) 113 | originaldata = ndimage.interpolation.zoom(originaldata, zoomfactor, order = 3) 114 | 115 | 116 | shapemin = [random.randint(0, originaldata.shape[ii]-self.subdatashape[ii]) for ii in range(3)] 117 | subdata = originaldata[shapemin[0]:shapemin[0]+self.subdatashape[0], shapemin[1]:shapemin[1]+self.subdatashape[1], 118 | shapemin[2]:shapemin[2]+self.subdatashape[2]] 119 | sublabel = originallabel[shapemin[0]:shapemin[0]+self.subdatashape[0], shapemin[1]:shapemin[1]+self.subdatashape[1], 120 | shapemin[2]:shapemin[2]+self.subdatashape[2]] 121 | if self.random_rotate: 122 | subdata,sublabel = random_rotate(subdata, sublabel, p=0.5, degrees=[15, -15], axes=[0, 1, 2]) 123 | 124 | 125 | databatch.append(subdata[np.newaxis, :]) 126 | labelbatch.append(sublabel) 127 | 128 | 129 | batch['images'] = databatch 130 | batch['labels'] = labelbatch 131 | return batch 132 | 133 | 134 | 135 | def get_list_img(self, patient_order): 136 | ''' 137 | test与valid用,将给定patient的img与label按z轴切片,底端部分会有些重叠 138 | :patient_order:在patient_name里所选病人的order 139 | :return: images:N*1*H/N*W*L,N为所切割片数,1是为了满足预测时维数需要所加,相当于预测时batchsize为1 140 | labels:1*H*W*L 141 | ''' 142 | print(self.patient_name[patient_order]) 143 | patient_path = "{0:}/{1:}/{2:}".format(self.data_root, self.stage, self.patient_name[patient_order]) 144 | databatch = [] 145 | batch = {} 146 | multichannel = isinstance(self.img_name, list) 147 | 148 | if multichannel: 149 | datapath = [] 150 | originaldata = [] 151 | labelpath = os.path.join(patient_path, self.label_name) 152 | for sub_img_name in self.img_name: 153 | subdatapath =os.path.join(patient_path, sub_img_name) 154 | datapath.append(subdatapath) 155 | if datapath[0].endswith('npy') or datapath[0].endswith('npz'): 156 | for subdatapath in datapath: 157 | originaldata = 
np.load(subdatapath) 158 | originallabel = np.load(labelpath)[np.newaxis, :] 159 | elif datapath[0].endswith('nii') or datapath[0].endswith('nii.gz'): 160 | for subdatapath in datapath: 161 | originaldata = load_nifty_volume_as_array(subdatapath, transpose=False) 162 | originallabel = load_nifty_volume_as_array(labelpath)[np.newaxis, :] 163 | else: 164 | ValueError('please input correct file name! i.e.".nii" ".npy" ".nii.gz"') 165 | 166 | originaldata = np.asarray(originaldata) 167 | if self.output_feature: 168 | labelpath = os.path.join(patient_path, self.feature_name) 169 | originalfeature = np.load(labelpath) 170 | originaldata = np.concatenate((originaldata, originalfeature), 0) 171 | batch['originalshape'] = originaldata.shape[1::] 172 | 173 | 174 | 175 | img_number = int(math.ceil(originaldata.shape[1] / self.testdatashape[0])) # 在z轴上切成的块数,若为小数会向上取整 176 | for i in range(img_number-1): 177 | subdata = originaldata[:, i*self.testdatashape[0]:(i+1)*self.testdatashape[0]][np.newaxis, :] 178 | subdata = torch.from_numpy(subdata).float() 179 | databatch.append(subdata) 180 | subdata = originaldata[:, -self.testdatashape[0]::][np.newaxis, :] 181 | subdata = torch.from_numpy(subdata).float() 182 | databatch.append(subdata) 183 | else: 184 | datapath = os.path.join(patient_path, self.img_name) 185 | labelpath = os.path.join(patient_path, self.label_name) 186 | if datapath.endswith('npy') or datapath.endswith('npz'): 187 | originaldata = np.load(datapath) 188 | originallabel = np.load(labelpath)[np.newaxis, :] # 增加一个batchsize维度 189 | elif datapath.endswith('nii') or datapath.endswith('nii.gz'): 190 | originaldata = load_nifty_volume_as_array(datapath) 191 | originallabel = load_nifty_volume_as_array(labelpath)[np.newaxis, :] 192 | else: 193 | ValueError('please input correct file name! 
i.e.".nii" ".npy" ".nii.gz"') 194 | 195 | if self.output_feature: 196 | labelpath = os.path.join(patient_path, self.feature_name) 197 | originalfeature = np.load(labelpath) 198 | originaldata = np.concatenate((originaldata, originalfeature), 0) 199 | batch['originalshape'] = originaldata.shape 200 | 201 | img_number = int(math.ceil(originaldata.shape[0] / self.testdatashape[0])) # 在z轴上切成的块数,若为小数会向上取整 202 | for i in range(img_number - 1): 203 | subdata = originaldata[i * self.testdatashape[0]:(i + 1) * self.testdatashape[0]][np.newaxis, :][ 204 | np.newaxis, :] 205 | subdata = torch.from_numpy(subdata).float() 206 | databatch.append(subdata) 207 | subdata = originaldata[-self.testdatashape[0]::][np.newaxis, :][np.newaxis, :] 208 | subdata = torch.from_numpy(subdata).float() 209 | databatch.append(subdata) 210 | 211 | batch['images'] = databatch 212 | batch['labels'] = torch.from_numpy(np.int16(originallabel)).float() 213 | return batch, patient_path 214 | 215 | def get_list_overlap_img(self, patient_order): 216 | ''' 217 | test与valid用,将给定patient的img与label按z轴切片,底端部分会有些重叠 218 | :patient_order:在patient_name里所选病人的order 219 | :return: images:N*1*H/N*W*L,N为所切割片数,1是为了满足预测时维数需要所加,相当于预测时batchsize为1 220 | labels:1*H*W*L 221 | ''' 222 | print(self.patient_name[patient_order]) 223 | patient_path = "{0:}/{1:}/{2:}".format(self.data_root, self.stage, self.patient_name[patient_order]) 224 | databatch = [] 225 | batch = {} 226 | multichannel = isinstance(self.img_name, list) 227 | 228 | if multichannel: 229 | datapath = [] 230 | originaldata = [] 231 | labelpath = os.path.join(patient_path, self.label_name) 232 | for sub_img_name in self.img_name: 233 | subdatapath =os.path.join(patient_path, sub_img_name) 234 | datapath.append(subdatapath) 235 | if datapath[0].endswith('npy') or datapath[0].endswith('npz'): 236 | for subdatapath in datapath: 237 | originaldata = np.load(subdatapath) 238 | originallabel = np.load(labelpath)[np.newaxis, :] 239 | elif datapath[0].endswith('nii') or 
datapath[0].endswith('nii.gz'): 240 | for subdatapath in datapath: 241 | originaldata = load_nifty_volume_as_array(subdatapath, transpose=False) 242 | originallabel = load_nifty_volume_as_array(labelpath)[np.newaxis, :] 243 | else: 244 | ValueError('please input correct file name! i.e.".nii" ".npy" ".nii.gz"') 245 | 246 | originaldata = np.asarray(originaldata) 247 | if self.output_feature: 248 | labelpath = os.path.join(patient_path, self.feature_name) 249 | originalfeature = np.load(labelpath) 250 | originaldata = np.concatenate((originaldata, originalfeature), 0) 251 | batch['originalshape'] = originaldata.shape[1::] 252 | 253 | 254 | 255 | img_number = int(math.ceil(originaldata.shape[1] / self.testdatashape[0])) # 在z轴上切成的块数,若为小数会向上取整 256 | for i in range(img_number-1): 257 | subdata = originaldata[:, i*self.testdatashape[0]:(i+1)*self.testdatashape[0]][np.newaxis, :] 258 | subdata = torch.from_numpy(subdata).float() 259 | databatch.append(subdata) 260 | subdata = originaldata[:, -self.testdatashape[0]::][np.newaxis, :] 261 | subdata = torch.from_numpy(subdata).float() 262 | databatch.append(subdata) 263 | else: 264 | datapath = os.path.join(patient_path, self.img_name) 265 | labelpath = os.path.join(patient_path, self.label_name) 266 | if datapath.endswith('npy') or datapath.endswith('npz'): 267 | originaldata = np.load(datapath) 268 | originallabel = np.load(labelpath)[np.newaxis, :] # 增加一个batchsize维度 269 | elif datapath.endswith('nii') or datapath.endswith('nii.gz'): 270 | originaldata = load_nifty_volume_as_array(datapath) 271 | originallabel = load_nifty_volume_as_array(labelpath)[np.newaxis, :] 272 | else: 273 | ValueError('please input correct file name! 
i.e.".nii" ".npy" ".nii.gz"') 274 | 275 | if self.output_feature: 276 | labelpath = os.path.join(patient_path, self.feature_name) 277 | originalfeature = np.load(labelpath) 278 | originaldata = np.concatenate((originaldata, originalfeature), 0) 279 | batch['originalshape'] = originaldata.shape 280 | 281 | img_number = int(math.ceil(originaldata.shape[0] / self.testdatashape[0]))-self.overlap_num+1 # 在z轴上切成的块数,若为小数会向上取整 282 | for i in range(img_number - 1): 283 | subdata = originaldata[i * self.testdatashape[0]:(i + self.overlap_num) * self.testdatashape[0]][np.newaxis, :][ 284 | np.newaxis, :] 285 | subdata = torch.from_numpy(subdata).float() 286 | databatch.append(subdata) 287 | subdata = originaldata[-self.overlap_num*self.testdatashape[0]::][np.newaxis, :][np.newaxis, :] 288 | subdata = torch.from_numpy(subdata).float() 289 | databatch.append(subdata) 290 | 291 | batch['images'] = databatch 292 | batch['labels'] = torch.from_numpy(np.int16(originallabel)).float() 293 | return batch, patient_path 294 | 295 | 296 | def set_noclass_zero(labelexist, prediction, label): 297 | ''' 298 | 当前患者中未标记的器官对应label置0 299 | :param labelexist: 记录有哪些器官 300 | :param label: 301 | :return: 302 | ''' 303 | realprediction = torch.max(prediction, 1)[1] 304 | for i in range(labelexist.shape[0]): 305 | for classnum in range(1, labelexist.shape[1]): 306 | if labelexist[i, classnum] == 0: 307 | a = realprediction[i]==classnum # prediction中为当前class的位置 308 | b = label[i]!=0&classnum # label中其它器官的位置 309 | a = a*(1-b) 310 | label[i][a] = classnum 311 | return label 312 | -------------------------------------------------------------------------------- /models/Struseg_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import, print_function 3 | 4 | from data_process.data_process_func import * 5 | from data_process.transform import * 6 | from scipy import ndimage 7 | import glob 8 | import 
numpy as np 9 | import torch 10 | from torch.utils.data import Dataset, DataLoader 11 | import os 12 | import os.path 13 | import random 14 | import math 15 | from torchvision import transforms as T 16 | import nibabel 17 | import time 18 | from scipy import ndimage 19 | 20 | 21 | class Struseg_dataset(Dataset): 22 | 23 | def __init__(self, config, stage): 24 | """ 25 | 用于分割的数据集 26 | """ 27 | self.config = config 28 | self.stage = stage 29 | self.net_mode = config['net_mode'] 30 | self.data_root = config['data_root'] 31 | self.batchsize = self.config['batch_size'] 32 | self.subdatashape = self.config['subdata_shape'] 33 | self.sublabelshape = self.config['sublabel_shape'] 34 | self.testdatashape = self.config['test_data_shape'] 35 | self.testlabelshape = self.config['test_label_shape'] 36 | self.classnum = int(self.config['class_num']) 37 | self.img_name = self.config['img_name'] 38 | self.label_name = self.config['label_name'] 39 | self.label_exist_name = self.config['label_exist_name'] 40 | self.patient_path = os.path.join(self.data_root, self.stage) 41 | self.patient_num = len(self.patient_path) 42 | self.output_feature = self.config['output_feature'] 43 | self.random_rotate = self.config['random_rotate'] 44 | self.random_scale = self.config['random_scale'] 45 | if self.output_feature: 46 | self.feature_name = self.config['feature_name'] 47 | 48 | data_image = glob.glob(self.patient_path+'/*/'+self.img_name) 49 | self.data_image = data_image 50 | 51 | mask_image = glob.glob(self.patient_path+'/*/'+self.label_name) 52 | self.mask_image = mask_image 53 | 54 | if self.net_mode == 'train': 55 | self.transform = T.Compose(T.RandomRotation(15), 56 | T.RandomCrop(self.subdatashape), 57 | T.ToTensor) 58 | 59 | def __getitem__(self, index): 60 | data_image_path = self.data_image[index] 61 | mask_image_path = self.mask_image[index] 62 | 63 | if mask_image_path.endswith('npy') or data_image_path.endswith('npz'): 64 | originaldata = np.load(data_image_path) 65 | 
originallabel = np.load(mask_image_path) 66 | elif mask_image_path.endswith('nii') or mask_image_path.endswith('nii.gz'): 67 | originaldata = load_nifty_volume_as_array(data_image_path) 68 | originallabel = load_nifty_volume_as_array(mask_image_path) 69 | else: 70 | ValueError('please input correct file name! i.e.".nii" ".npy" ".nii.gz"') 71 | 72 | if self.random_scale: 73 | zoomfactor = random.choice([0.8, 1, 1.2]) 74 | originallabel = ndimage.interpolation.zoom(originallabel, zoomfactor, order=0) 75 | originaldata = ndimage.interpolation.zoom(originaldata, zoomfactor, order=3) 76 | 77 | if self.transform: 78 | image_data = self.transform(originaldata) 79 | mask_data = self.transform(originallabel) 80 | 81 | return image_data, mask_data 82 | 83 | def __len__(self): 84 | return len(self.data_image) 85 | 86 | 87 | def get_list_img(self, patient_order): 88 | ''' 89 | test与valid用,将给定patient的img与label按z轴切片,底端部分会有些重叠 90 | :patient_order:在patient_path里所选病人的order 91 | :return: images:N*1*H/N*W*L,N为所切割片数,1是为了满足预测时维数需要所加,相当于预测时batchsize为1 92 | labels:1*H*W*L 93 | ''' 94 | print(self.patient_path[patient_order]) 95 | patient_path = "{0:}/{1:}/{2:}".format(self.data_root, self.stage, self.patient_path[patient_order]) 96 | datapath = os.path.join(patient_path, self.img_name) 97 | labelpath = os.path.join(patient_path, self.label_name) 98 | databatch = [] 99 | labelbatch = [] 100 | batch = {} 101 | 102 | if datapath.endswith('npy') or datapath.endswith('npz'): 103 | originaldata = np.load(datapath) 104 | originallabel = np.load(labelpath)[np.newaxis, :] # 增加一个batchsize维度 105 | elif datapath.endswith('nii') or datapath.endswith('nii.gz'): 106 | originaldata = load_nifty_volume_as_array(datapath) 107 | originallabel = load_nifty_volume_as_array(labelpath)[np.newaxis, :] 108 | else: 109 | ValueError('please input correct file name! 
i.e.".nii" ".npy" ".nii.gz"') 110 | 111 | if self.output_feature: 112 | labelpath = os.path.join(patient_path, self.feature_name) 113 | originalfeature = np.load(labelpath) 114 | originaldata = np.concatenate((originaldata, originalfeature), 0) 115 | batch['originalshape'] = originaldata.shape 116 | 117 | 118 | 119 | img_number = int(math.ceil(originaldata.shape[0] / self.testdatashape[0])) # 在z轴上切成的块数,若为小数会向上取整 120 | for i in range(img_number-1): 121 | subdata = originaldata[i*self.testdatashape[0]:(i+1)*self.testdatashape[0]][np.newaxis, :][np.newaxis, :] 122 | subdata = torch.from_numpy(subdata).float() 123 | databatch.append(subdata) 124 | subdata = originaldata[-self.testdatashape[0]::][np.newaxis, :][np.newaxis, :] 125 | subdata = torch.from_numpy(subdata).float() 126 | databatch.append(subdata) 127 | 128 | 129 | batch['images'] = databatch 130 | batch['listlabels'] = labelbatch 131 | batch['labels'] = torch.from_numpy(np.int16(originallabel)).float() 132 | return batch, patient_path 133 | 134 | 135 | def set_noclass_zero(labelexist, prediction, label): 136 | ''' 137 | 当前患者中未标记的器官对应label置0 138 | :param labelexist: 记录有哪些器官 139 | :param label: 140 | :return: 141 | ''' 142 | realprediction = torch.max(prediction, 1)[1] 143 | for i in range(labelexist.shape[0]): 144 | for classnum in range(1, labelexist.shape[1]): 145 | if labelexist[i, classnum] == 0: 146 | a = realprediction[i]==classnum # prediction中为当前class的位置 147 | b = label[i]!=0&classnum # label中其它器官的位置 148 | a = a*(1-b) 149 | label[i][a] = classnum 150 | return label 151 | -------------------------------------------------------------------------------- /models/Unet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from models.module import Module 3 | import torch as t 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from models.layers import DoubleConv3D, SingleConv3D 7 | import math 8 | 9 | class 
Unet(Module): 10 | def __init__(self, inc=1, n_classes=5, base_chns=16, droprate=0, norm='in', depth=False, dilation=1): 11 | super(Unet, self).__init__() 12 | self.model_name = "seg" 13 | self.upsample = nn.Upsample(scale_factor=2, mode='trilinear') # 1/4(h,h) 14 | self.downsample = nn.MaxPool3d(2, 2) # 1/2(h,h) 15 | self.drop = nn.Dropout(droprate) 16 | 17 | self.conv1_1 = SingleConv3D(inc, base_chns, norm=norm, depth=depth, dilat=dilation, pad='same') 18 | self.conv1_2 = SingleConv3D(base_chns, 2*base_chns, norm=norm, depth=depth, dilat=dilation, pad='same') 19 | 20 | self.conv2_1 = SingleConv3D(2*base_chns, 2*base_chns, norm=norm, depth=depth, dilat=dilation, pad='same') 21 | self.conv2_2 = SingleConv3D(2 * base_chns, 4 * base_chns, norm=norm, depth=depth, dilat=dilation, pad='same') 22 | 23 | self.conv3_1 = SingleConv3D(4*base_chns, 4*base_chns, norm=norm, depth=depth, dilat=dilation, pad='same') 24 | self.conv3_2 = SingleConv3D(4 * base_chns, 8 * base_chns, norm=norm, depth=depth, dilat=dilation, pad='same') 25 | 26 | self.conv4_1 = SingleConv3D(8*base_chns, 8*base_chns, norm=norm, depth=depth, dilat=math.ceil(dilation/2), pad='same') 27 | self.conv4_2 = SingleConv3D(8 * base_chns, 16 * base_chns, norm=norm, depth=depth, dilat=math.ceil(dilation/2), pad='same') 28 | 29 | self.conv5_1 = SingleConv3D(24*base_chns, 8*base_chns, norm=norm, depth=depth, dilat=dilation, pad='same') 30 | self.conv5_2 = SingleConv3D(8 * base_chns, 8 * base_chns, norm=norm, depth=depth, dilat=dilation, pad='same') 31 | 32 | self.conv6_1 = SingleConv3D(12*base_chns, 4*base_chns, norm=norm, depth=depth, dilat=dilation, pad='same') 33 | self.conv6_2 = SingleConv3D(4 * base_chns, 4 * base_chns, norm=norm, depth=depth, dilat=dilation, pad='same') 34 | 35 | self.conv7_1 = SingleConv3D(6*base_chns, 2*base_chns, norm=norm, depth=depth, dilat=dilation, pad='same') 36 | self.conv7_2 = SingleConv3D(2 * base_chns, 2 * base_chns, norm=norm, depth=depth, dilat=dilation, pad='same') 37 | 38 | 
self.classification = nn.Sequential( 39 | nn.Dropout3d(p=0.1), 40 | nn.Conv3d(in_channels=2*base_chns, out_channels=n_classes, kernel_size=1), 41 | ) 42 | 43 | 44 | def forward(self, x): 45 | out = self.conv1_1(x) 46 | conv1 = self.conv1_2(out) 47 | out = self.downsample(conv1) # 1/2 48 | out = self.conv2_1(out) 49 | conv2 = self.conv2_2(out) # 50 | out = self.downsample(conv2) # 1/4 51 | out = self.conv3_1(out) 52 | conv3 = self.conv3_2(out) # 53 | out = self.downsample(conv3) # 1/8 54 | out = self.conv4_1(out) 55 | out = self.conv4_2(out) 56 | out = self.drop(out) 57 | 58 | up5 = self.upsample(out) # 1/4 59 | out = t.cat((up5, conv3), 1) 60 | out = self.conv5_1(out) 61 | out = self.conv5_2(out) 62 | 63 | up6 = self.upsample(out) # 1/2 64 | out = t.cat((up6, conv2), 1) 65 | out = self.conv6_1(out) 66 | out = self.conv6_2(out) 67 | 68 | up7 = self.upsample(out) 69 | out = t.cat((up7, conv1), 1) 70 | out = self.conv7_1(out) 71 | out = self.conv7_2(out) 72 | 73 | out = self.classification(out) 74 | predic = F.softmax(out, dim=1) 75 | return predic 76 | -------------------------------------------------------------------------------- /models/Unet_Separate_3.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from models.module import Module 3 | import torch as t 4 | import torch.nn as nn 5 | from models.layers import TriResSeparateConv3D 6 | import torch.nn.functional as F 7 | import math 8 | class Unet_Separate(Module): 9 | """ 10 | 相比普通U-net加入了res连接,并分离了3D卷积为3*3*1+1*1*3,最终几层宽度增加 11 | """ 12 | def __init__(self, inc=1, n_classes=5, base_chns=12, droprate=0, norm='in', depth = False, dilation=1, separate_direction='axial'): 13 | super(Unet_Separate, self).__init__() 14 | self.model_name = "seg" 15 | 16 | self.dropout = nn.Dropout(droprate) 17 | self.downsample = nn.MaxPool3d(2, 2) 18 | self.upsample = nn.Upsample(scale_factor=2, mode='trilinear') 19 | 20 | self.conv1 = 
TriResSeparateConv3D(inc, 2*base_chns, norm=norm, depth=depth, pad='same', dilat=dilation, separate_direction=separate_direction) 21 | 22 | self.conv2 = TriResSeparateConv3D(2*base_chns, 2 * base_chns, norm=norm, depth=depth, pad='same', dilat=dilation, separate_direction=separate_direction) 23 | 24 | self.conv3 = TriResSeparateConv3D(2 * base_chns, 4 * base_chns, norm=norm, depth=depth, pad='same', dilat=dilation, separate_direction=separate_direction) 25 | 26 | 27 | self.conv4 = TriResSeparateConv3D(4 * base_chns, 8 * base_chns, norm=norm, depth=depth, pad='same', dilat=dilation, separate_direction=separate_direction) 28 | self.conv5 = TriResSeparateConv3D(8 * base_chns, 4 * base_chns, norm=norm, depth=depth, pad='same', dilat=dilation, separate_direction=separate_direction) 29 | 30 | 31 | self.conv6_1 = TriResSeparateConv3D(8 * base_chns, 4 * base_chns, norm=norm, depth=depth, pad='same', dilat=dilation, separate_direction=separate_direction) 32 | self.conv6_2 = TriResSeparateConv3D(4 * base_chns, 2 * base_chns, norm=norm, depth=depth, pad='same', dilat=dilation, separate_direction=separate_direction) 33 | 34 | self.conv7_1 = TriResSeparateConv3D(4 * base_chns, 2 * base_chns, norm=norm, depth=depth, pad='same', dilat=dilation, separate_direction=separate_direction) 35 | self.conv7_2 = TriResSeparateConv3D(2 * base_chns, 2 * base_chns, norm=norm, depth=depth, pad='same', dilat=dilation, separate_direction=separate_direction) 36 | 37 | self.conv8_1 = TriResSeparateConv3D(4 * base_chns, 2*base_chns, norm=norm, depth=depth, pad='same', dilat=dilation, separate_direction=separate_direction) 38 | self.conv8_2 = TriResSeparateConv3D(2*base_chns, 2*base_chns, norm=norm, depth=depth, pad='same', dilat=dilation, separate_direction=separate_direction) 39 | 40 | self.classification = nn.Sequential( 41 | nn.ReLU(), 42 | nn.Dropout3d(p=0.1), 43 | nn.Conv3d(in_channels=2*base_chns, out_channels=n_classes, kernel_size=1), 44 | ) 45 | 46 | 47 | 48 | def forward(self, x): 49 | 
conv1 = self.conv1(x) 50 | out = self.downsample(conv1) # 1/2 51 | conv2 = self.conv2(out) # 52 | out = self.downsample(conv2) # 1/4 53 | conv3 = self.conv3(out) # 54 | out = self.downsample(conv3) # 1/8 55 | out = self.conv4(out) 56 | out = self.conv5(out) 57 | 58 | out = self.dropout(out) 59 | 60 | out = self.upsample(out) # 1/4 61 | out = t.cat((out, conv3), 1) 62 | out = self.conv6_1(out) 63 | out = self.conv6_2(out) 64 | 65 | out = self.upsample(out) # 1/2 66 | out = t.cat((out, conv2), 1) 67 | out = self.conv7_1(out) 68 | out = self.conv7_2(out) 69 | 70 | out = self.upsample(out) # 1/2 71 | out = t.cat((out, conv1), 1) 72 | out = self.conv8_1(out) 73 | out = self.conv8_2(out) 74 | 75 | out = self.classification(out) 76 | predic = F.softmax(out, dim=1) 77 | return predic 78 | -------------------------------------------------------------------------------- /models/loss_function.py: -------------------------------------------------------------------------------- 1 | # coding:utf8 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | from torch import eq, sum, gt # eq返回相同元素索引,gt返回大于给定值索引 7 | from torch.nn import init 8 | import numpy as np 9 | from torch.autograd import Variable 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | class TestDiceLoss(nn.Module): 14 | def __init__(self, n_class): 15 | super(TestDiceLoss, self).__init__() 16 | self.one_hot_encoder = One_Hot(n_class).forward 17 | self.n_class = n_class 18 | 19 | def forward(self, input, target, show=False): 20 | smooth = 0.00001 21 | batch_size = input.size(0) 22 | input = torch.max(input, 1)[1] 23 | input = self.one_hot_encoder(input).contiguous().view(batch_size, self.n_class, -1) 24 | target = self.one_hot_encoder(target).contiguous().view(batch_size, self.n_class, -1) 25 | inter = torch.sum(torch.sum(input * target, 2), 0) + smooth 26 | union1 = torch.sum(torch.sum(input, 2), 0) + smooth 27 | union2 = torch.sum(torch.sum(target, 
2), 0) + smooth 28 | 29 | 30 | ''' 31 | 为避免当前训练图像中未出现的器官影响dice,删除dice大于0.98的部分 32 | ''' 33 | andU = 2.0 * inter / (union1 + union2) 34 | score = andU 35 | 36 | return score.float() 37 | 38 | class SoftDiceLoss(nn.Module): 39 | def __init__(self, n_class): 40 | super(SoftDiceLoss, self).__init__() 41 | self.one_hot_encoder = One_Hot(n_class).forward 42 | self.n_class = n_class 43 | 44 | def forward(self, input, target): 45 | ''' 46 | :param input: the prediction, batchsize*n_class*depth*length*width 47 | :param target: the groundtruth, batchsize*depth*length*width 48 | :return: loss 49 | ''' 50 | smooth = 0.01 51 | batch_size = input.size(0) 52 | 53 | input = input.view(batch_size, self.n_class, -1) 54 | target = self.one_hot_encoder(target).contiguous().view(batch_size, self.n_class, -1) 55 | 56 | inter = torch.sum(input * target, 2) + smooth 57 | union1 = torch.sum(input, 2) + smooth 58 | union2 = torch.sum(target, 2) + smooth 59 | 60 | andU = torch.sum(2.0 * inter/(union1 + union2)) 61 | score = 1 - andU/(batch_size*self.n_class) 62 | 63 | return score 64 | 65 | class FocalLoss(nn.Module): 66 | def __init__(self, n_class): 67 | super(FocalLoss, self).__init__() 68 | self.one_hot_encoder = One_Hot(n_class).forward 69 | self.n_class = n_class 70 | 71 | def forward(self, input,target): 72 | ''' 73 | :param input: the prediction, batchsize*n_class*depth*length*width 74 | :param target: the groundtruth, batchsize*depth*length*width 75 | :return: loss 76 | ''' 77 | batch_size = input.size(0) 78 | input = input.view(batch_size, self.n_class, -1) 79 | target = self.one_hot_encoder(target).contiguous().view(batch_size, self.n_class, -1) 80 | volume = input.size(2) 81 | score = -torch.sum(target*(1-input)**2*torch.log10(input))/volume 82 | 83 | return score 84 | 85 | class Focal_and_Dice_loss(nn.Module): 86 | def __init__(self, n_class, lamda): 87 | super(Focal_and_Dice_loss, self).__init__() 88 | self.one_hot_encoder = One_Hot(n_class).forward 89 | self.n_class = n_class 
90 | self.lamda = lamda 91 | self.FocalLoss = FocalLoss(n_class) 92 | self.SoftDiceloss = SoftDiceLoss(n_class) 93 | 94 | def forward(self, input, target): 95 | ''' 96 | :param input: the prediction, batchsize*n_class*depth*length*width 97 | :param target: the groundtruth, batchsize*depth*length*width 98 | :return: loss 99 | ''' 100 | score = self.lamda*self.FocalLoss(input, target)+self.n_class*self.SoftDiceloss(input, target) 101 | return score 102 | 103 | 104 | class AttentionDiceLoss(nn.Module): 105 | def __init__(self, n_class, alpha): 106 | super(AttentionDiceLoss, self).__init__() 107 | self.one_hot_encoder = One_Hot(n_class).forward 108 | self.n_class = n_class 109 | self.alpha = alpha 110 | def forward(self, input, target): 111 | ''' 112 | :param input: the prediction, batchsize*n_class*depth*length*width 113 | :param target: the groundtruth, batchsize*depth*length*width 114 | :return: loss 115 | ''' 116 | smooth = 0.01 117 | batch_size = input.size(0) 118 | 119 | input = input.view(batch_size, self.n_class, -1) 120 | target = self.one_hot_encoder(target).contiguous().view(batch_size, self.n_class, -1) 121 | attentioninput = torch.exp((input - target) / self.alpha) * input 122 | inter = torch.sum(attentioninput * target, 2) + smooth 123 | union1 = torch.sum(attentioninput, 2) + smooth 124 | union2 = torch.sum(target, 2) + smooth 125 | 126 | andU = torch.sum(2.0 * inter / (union1 + union2)) 127 | score = batch_size * self.n_class - andU 128 | 129 | return score 130 | 131 | 132 | class ExpDiceLoss(nn.Module): 133 | def __init__(self, n_class, weights=[1, 1], gama=0.0001): 134 | super(ExpDiceLoss, self).__init__() 135 | self.one_hot_encoder = One_Hot(n_class).forward 136 | self.n_class = n_class 137 | self.gama = gama 138 | self.weight = weights 139 | smooth = 1 140 | self.Ldice = Ldice(n_class, smooth) 141 | self.Lcross = Lcross(n_class) 142 | def forward(self, input, target): 143 | ''' 144 | :param input: batch*class*depth*length*height or 
batch*calss*length*height 145 | :param target: batch*depth*length*height or batch*length*height 146 | :return: 147 | ''' 148 | smooth = 1 149 | batch_size = input.size(0) 150 | realinput = input 151 | realtarget = target 152 | 153 | input = input.view(batch_size, self.n_class, -1) 154 | target = self.one_hot_encoder(target).contiguous().view(batch_size, self.n_class, -1) 155 | label_sum = torch.sum(target[:, 1::], 2) + smooth # 非背景类label各自和 156 | Wl = (torch.sum(label_sum) / torch.sum(label_sum, 0))**0.5 # 各label占总非背景类label比值的开方 157 | Ldice = self.Ldice(input, target, batch_size) # 158 | Lcross = self.Lcross(realinput, realtarget, Wl, label_sum) 159 | Lexp = self.weight[0] * Ldice + self.weight[1] * Lcross 160 | return Lexp 161 | 162 | 163 | class AttentionExpDiceLoss(nn.Module): 164 | def __init__(self, n_class, alpha, gama=0.0001): 165 | super(AttentionExpDiceLoss, self).__init__() 166 | self.one_hot_encoder = One_Hot(n_class).forward 167 | self.n_class = n_class 168 | self.gama = gama 169 | self.weight = [1,1] 170 | self.alpha = alpha 171 | smooth = 1 172 | self.Ldice = Ldice(n_class-1, smooth) 173 | self.Lcross = Lcross(n_class) 174 | def forward(self, input, target): 175 | ''' 176 | :param input: batch*class*depth*length*height or batch*calss*length*height 177 | :param target: batch*depth*length*height or batch*length*height 178 | :param dis: batch*class*depth*length*height or batch*calss*length*height 179 | :return: 180 | ''' 181 | smooth = 1 182 | batch_size = input.size(0) 183 | realinput = input 184 | realtarget = target 185 | input = input.view(batch_size, self.n_class, -1)[:, 1::] 186 | target = self.one_hot_encoder(target).contiguous().view(batch_size, self.n_class, -1)[:, 1::] 187 | attentionseg = torch.exp((input - target)/self.alpha) * input 188 | label_sum = torch.sum(target, 2) + smooth # 非背景类label各自和 189 | Wl = (torch.sum(label_sum) / torch.sum(label_sum, 0))**0.5 # 各label占总非背景类label比值的开方 190 | Ldice = self.Ldice(attentionseg, target, batch_size) 
# 191 | Lcross = self.Lcross(realinput, realtarget, Wl, label_sum) 192 | Lexp = self.weight[0] * Ldice + self.weight[1] * Lcross 193 | return Lexp 194 | 195 | 196 | class BatchExpDiceLoss(nn.Module): 197 | def __init__(self, n_class, weights=[1, 1], gama=0.0001): 198 | super(BatchExpDiceLoss, self).__init__() 199 | self.one_hot_encoder = One_Hot(n_class).forward 200 | self.n_class = n_class 201 | self.gama = gama 202 | self.weight = weights 203 | smooth = 1 204 | self.BatchLdice = BatchLdice(n_class, smooth) 205 | self.Lcross = Lcross(n_class) 206 | def forward(self, input, target): 207 | ''' 208 | :param input: batch*class*depth*length*height or batch*calss*length*height 209 | :param target: batch*depth*length*height or batch*length*height 210 | :return: batch*ExpDice 211 | ''' 212 | smooth = 1 213 | batch_size = input.size(0) 214 | realinput = input 215 | realtarget = target 216 | 217 | input = input.view(batch_size, self.n_class, -1) 218 | target = self.one_hot_encoder(target).contiguous().view(batch_size, self.n_class, -1) 219 | label_sum = torch.sum(target[:, 1::], 2) + smooth # 非背景类label各自和 220 | Wl = (torch.sum(label_sum) / torch.sum(label_sum, 0))**0.5 221 | Ldice = self.BatchLdice(input, target, batch_size) 222 | Lcross = self.Lcross(realinput, realtarget, Wl, label_sum) 223 | Lexp = self.weight[0] * Ldice + self.weight[1] * Lcross 224 | return Lexp 225 | 226 | 227 | class One_Hot(nn.Module): 228 | def __init__(self, depth): 229 | super(One_Hot, self).__init__() 230 | self.depth = depth 231 | self.ones = torch.eye(depth).cuda() # torch.sparse.torch.eye 232 | # eye生成depth尺度的单位矩阵 233 | 234 | def forward(self, X_in): 235 | ''' 236 | :param X_in: batch*depth*length*height or batch*length*height 237 | :return: batch*class*depth*length*height or batch*calss*length*height 238 | ''' 239 | n_dim = X_in.dim() # 返回dimension数目 240 | output_size = X_in.size() + torch.Size([self.depth]) # 增加一个class通道 241 | num_element = X_in.numel() # 返回element总数 242 | X_in = 
X_in.data.long().view(num_element) # 将target拉伸为一行 243 | out1 = Variable(self.ones.index_select(0, X_in)) 244 | out = out1.view(output_size) 245 | return out.permute(0, -1, *range(1, n_dim)).squeeze(dim=2).float() # permute更改dimension顺序 246 | 247 | def __repr__(self): 248 | return self.__class__.__name__ + "({})".format(self.depth) 249 | 250 | def make_one_hot(input, num_classes): 251 | """Convert class index tensor to one hot encoding tensor. 252 | Args: 253 | input: A tensor of shape [N, 1, *] 254 | num_classes: An int of number of class 255 | Returns: 256 | A tensor of shape [N, num_classes, *] 257 | """ 258 | shape = np.array(input.shape) 259 | shape[1] = num_classes 260 | shape = tuple(shape) 261 | result = torch.zeros(shape) 262 | result = result.scatter_(1, input.cpu(), 1) 263 | 264 | return result 265 | 266 | 267 | class BinaryDiceLoss(nn.Module): 268 | """Dice loss of binary class 269 | Args: 270 | smooth: A float number to smooth loss, and avoid NaN error, default: 1 271 | p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2 272 | predict: A tensor of shape [N, *] 273 | target: A tensor of shape same with predict 274 | reduction: Reduction method to apply, return mean over batch if 'mean', 275 | return sum if 'sum', return a tensor of shape [N,] if 'none' 276 | Returns: 277 | Loss tensor according to arg reduction 278 | Raise: 279 | Exception if unexpected reduction 280 | """ 281 | def __init__(self, smooth=1, p=2, reduction='mean'): 282 | super(BinaryDiceLoss, self).__init__() 283 | self.smooth = smooth 284 | self.p = p 285 | self.reduction = reduction 286 | 287 | def forward(self, predict, target): 288 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match" 289 | predict = predict.contiguous().view(predict.shape[0], -1) 290 | target = target.contiguous().view(target.shape[0], -1) 291 | 292 | num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth 293 | den = torch.sum(predict.pow(self.p) + target.pow(self.p), 
dim=1) + self.smooth 294 | 295 | loss = 1 - num / den 296 | 297 | if self.reduction == 'mean': 298 | return loss.mean() 299 | elif self.reduction == 'sum': 300 | return loss.sum() 301 | elif self.reduction == 'none': 302 | return loss 303 | else: 304 | raise Exception('Unexpected reduction {}'.format(self.reduction)) 305 | 306 | 307 | class DiceLoss(nn.Module): 308 | """Dice loss, need one hot encode input 309 | Args: 310 | weight: An array of shape [num_classes,] 311 | ignore_index: class index to ignore 312 | predict: A tensor of shape [N, C, *] 313 | target: A tensor of same shape with predict 314 | other args pass to BinaryDiceLoss 315 | Return: 316 | same as BinaryDiceLoss 317 | """ 318 | def __init__(self, weight=None, ignore_index=None, **kwargs): 319 | super(DiceLoss, self).__init__() 320 | self.kwargs = kwargs 321 | self.weight = weight 322 | self.ignore_index = ignore_index 323 | 324 | def forward(self, predict, target): 325 | assert predict.shape == target.shape, 'predict & target shape do not match' 326 | dice = BinaryDiceLoss(**self.kwargs) 327 | total_loss = 0 328 | predict = F.softmax(predict, dim=1) 329 | 330 | for i in range(target.shape[1]): 331 | if i != self.ignore_index: 332 | dice_loss = dice(predict[:, i], target[:, i]) 333 | if self.weight is not None: 334 | assert self.weight.shape[0] == target.shape[1], \ 335 | 'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0]) 336 | dice_loss *= self.weights[i] 337 | total_loss += dice_loss 338 | 339 | return total_loss/target.shape[1] 340 | 341 | 342 | class Ldice(nn.Module): 343 | def __init__(self, smooth, n_class): 344 | super(Ldice, self).__init__() 345 | self.smooth = smooth 346 | self.n_class = n_class 347 | 348 | def forward(self, input, target, batch_size): 349 | ''' 350 | Ldice 351 | ''' 352 | inter = torch.sum(input * target, 2) + self.smooth 353 | union1 = torch.sum(input, 2) + self.smooth 354 | union2 = torch.sum(target, 2) + self.smooth 355 | dice = 2.0 * inter 
/ (union1 + union2) 356 | logdice = -torch.log(dice) 357 | expdice = torch.sum(logdice) # ** self.gama 358 | Ldice = expdice / (batch_size*self.n_class) 359 | return Ldice 360 | 361 | 362 | class Lcross(nn.Module): 363 | def __init__(self, n_class): 364 | super(Lcross, self).__init__() 365 | self.n_class = n_class 366 | def forward(self, realinput, realtarget, Wl, label_sum): 367 | ''' 368 | realinput: 369 | realtarget: 370 | Wl: 各label占总非背景类label比值的开方 371 | ''' 372 | Lcross = 0 373 | for i in range(1, self.n_class): 374 | mask = realtarget == i 375 | if torch.sum(mask).item() > 0: 376 | ProLabel = realinput[:, i][mask.detach()] 377 | LogLabel = -torch.log(ProLabel) 378 | ExpLabel = torch.sum(LogLabel) # **self.gama 379 | Lcross += Wl[i - 1] * ExpLabel 380 | Lcross = Lcross / torch.sum(label_sum) 381 | 382 | return Lcross 383 | 384 | -------------------------------------------------------------------------------- /models/module.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | import torch as t 3 | import time 4 | 5 | 6 | class Module(t.nn.Module): 7 | ''' 8 | 封装了nn.Module,主要是提供了save和load两个方法 9 | ''' 10 | 11 | def __init__(self): 12 | super(Module,self).__init__() 13 | self.model_name=str(type(self))# 默认名字 14 | 15 | def load(self, path): 16 | ''' 17 | 可加载指定路径的模型 18 | ''' 19 | self.load_state_dict(t.load(path)) 20 | 21 | def save(self, name=None): 22 | ''' 23 | 保存模型,默认使用“模型名字+时间”作为文件名, 24 | 如AlexNet_0710_23:57:29.pth 25 | ''' 26 | if name is None: 27 | prefix = 'checkpoints/' + self.model_name + '_' 28 | name = time.strftime(prefix + '%m%d_%H:%M:%S.pth') 29 | t.save(self.state_dict(), name) 30 | return name 31 | 32 | 33 | class Flat(t.nn.Module): 34 | ''' 35 | 把输入reshape成(batch_size,dim_length) 36 | ''' 37 | 38 | def __init__(self): 39 | super(Flat, self).__init__() 40 | #self.size = size 41 | 42 | def forward(self, x): 43 | return x.view(x.size(0), -1) 44 | 
-------------------------------------------------------------------------------- /seg_img.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import absolute_import, print_function 4 | import time 5 | import os 6 | from torch.nn import parallel 7 | import torch.tensor 8 | from models.LYC_data_loader import LYC_dataset 9 | from util.train_test_func import * 10 | from util.parse_config import parse_config 11 | from models.Unet_Separate_2 import Unet_Separate_2 12 | from models.Unet_Separate_3 import Unet_Separate_3 13 | from models.Unet_Separate_4 import Unet_Separate_4 14 | from models.UnetSE_Separate_3 import UnetSE_Separate_3 15 | from models.UnetDense_Separate_3 import UnetDense_Separate_3 16 | from models.UnetDense_Separate_4 import UnetDense_Separate_4 17 | from models.UnetDense_Separate_5 import UnetDense_Separate_5 18 | from models.Unet import Unet 19 | from models.Unet_Res import Unet_Res 20 | from models.SOLNet import SOLNet 21 | from models.Unet_Separate import Unet_Separate 22 | from models.DS_Unet_Separate_3 import DS_Unet_Separate_3 23 | from models.DS_Unet_Separate_4 import DS_Unet_Separate_4 24 | from models.loss_function import TestDiceLoss 25 | from util.binary import assd, dc, hd95 26 | from data_process.data_process_func import save_array_as_nifty_volume, make_overlap_weight 27 | from util.assd_evaluation import one_hot 28 | from skimage import morphology 29 | import pandas as pd 30 | 31 | class NetFactory(object): 32 | @staticmethod 33 | def create(name): 34 | if name == 'Unet_Separate_2': 35 | return Unet_Separate_2 36 | 37 | if name == 'Unet': 38 | return Unet 39 | 40 | if name == 'Unet_Res': 41 | return Unet_Res 42 | 43 | if name == 'Unet_Separate': 44 | return Unet_Separate 45 | 46 | if name == 'Unet_Separate_3': 47 | return Unet_Separate_3 48 | 49 | if name == 'Unet_Separate_4': 50 | return Unet_Separate_4 51 | 52 | if name == 
'UnetSE_Separate_3': 53 | return UnetSE_Separate_3 54 | 55 | if name == 'SOLNet': 56 | return SOLNet 57 | 58 | if name == 'UnetDense_Separate_3': 59 | return UnetDense_Separate_3 60 | 61 | if name == 'UnetDense_Separate_4': 62 | return UnetDense_Separate_4 63 | 64 | if name == 'UnetDense_Separate_5': 65 | return UnetDense_Separate_5 66 | 67 | if name == 'DS_Unet_Separate_3': 68 | return DS_Unet_Separate_3 69 | 70 | if name == 'DS_Unet_Separate_4': 71 | return DS_Unet_Separate_4 72 | 73 | # add your own networks here 74 | print('unsupported network:', name) 75 | exit() 76 | 77 | 78 | def seg(config_file): 79 | # 1, load configuration parameters 80 | print('1.Load parameters') 81 | config = parse_config(config_file) 82 | config_data = config['data'] # config of data,e.g. data_shape,batch_size. 83 | config_net = config['network'] # config of net, e.g. net_name,base_feature_name,class_num. 84 | config_test = config['testing'] 85 | overlap_num = config['data']['overlap_num'] 86 | random.seed(config_test.get('random_seed', 1)) 87 | subseg_name = config_data['seg_name'] 88 | subprob_name = subseg_name.replace('seg', 'prob') 89 | net_type = config_net['net_type'] 90 | class_num = config_net['class_num'] 91 | overlap_weight = make_overlap_weight(overlap_num) 92 | output_probability = False 93 | save = False 94 | save_array_as_xls = False 95 | show = False 96 | cal_dice = True 97 | cal_assd = False 98 | cal_hd95 = False 99 | show_hist = False 100 | overlap_bias = False 101 | class_weight = np.asarray([0, 100,100,50,50,80,80,50,80,80,80,50,50,70,70,70,70,60,60,100,100,100]) 102 | 103 | 104 | # 2, load data 105 | print('2.Load data') 106 | Datamode = ['valid'] 107 | 108 | # 3. 
creat model 109 | print('3.Creat model') 110 | net_class = NetFactory.create(net_type) 111 | net = net_class(inc=config_net.get('input_channel', 1), 112 | n_classes = class_num, 113 | base_chns= config_net.get('base_feature_number', 16), 114 | droprate=config_net.get('drop_rate', 0.2), 115 | norm='in', 116 | depth=False, 117 | dilation=config_net.get('dilation', 1) 118 | ) 119 | 120 | net = torch.nn.DataParallel(net, device_ids=[0]).cuda() 121 | if config_test['load_weight']: 122 | weight = torch.load(config_test['model_path'], map_location=lambda storage, loc: storage) 123 | net.load_state_dict(weight) 124 | print(torch.cuda.is_available()) 125 | 126 | 127 | # 4, start to seg 128 | print('''start to seg ''') 129 | net.eval() 130 | for mode in Datamode: 131 | Data = LYC_dataset(config_data, mode) 132 | patient_number = len(os.listdir(os.path.join(config_data['data_root'], mode))) 133 | with torch.no_grad(): 134 | t_array = np.zeros(patient_number) 135 | dice_array = np.zeros([patient_number, class_num]) 136 | assd_array = np.zeros([patient_number, class_num]) 137 | hd95_array = np.zeros([patient_number, class_num]) 138 | for patient_order in range(patient_number): 139 | t1 = time.time() 140 | valid_pair, patient_path = Data.get_list_overlap_img(patient_order) # 因为病人的数据无法一次完全预测,内存不够,所以裁剪成几块 141 | clip_number = len(valid_pair['images']) # 裁剪块数 142 | clip_height = config_data['test_data_shape'][0] # 裁剪图像的高度 143 | total_labels = valid_pair['labels'].cuda() 144 | predic_size = torch.Size([1, class_num]) + total_labels.size()[1::] 145 | totalpredic = torch.zeros(predic_size).cuda() # 完整预测 146 | outfeature_size = torch.Size([1, 2*config_net.get('base_feature_number')]) + total_labels.size()[1::] 147 | totalfeature = torch.zeros(outfeature_size).cuda() 148 | for i in range(clip_number): 149 | tempx = valid_pair['images'][i].cuda() 150 | pred = net(tempx) 151 | if overlap_bias: 152 | for j in range(overlap_num): 153 | pred[:,:,j*clip_height:(j+1)*clip_height]*= 
overlap_weight[j] 154 | if i < clip_number - 1: 155 | totalpredic[:, :, i * clip_height:(i + overlap_num) * clip_height] += pred 156 | else: 157 | totalpredic[:, :, -overlap_num*clip_height::] += pred 158 | 159 | if output_probability: 160 | totalfeature = (100*totalpredic.cpu().data.numpy().squeeze()).astype(np.uint16) 161 | totalpredic = torch.max(totalpredic, 1)[1].squeeze() 162 | totalpredic = np.uint8(totalpredic.cpu().data.numpy().squeeze()) 163 | totallabel = np.uint8(total_labels.cpu().data.numpy().squeeze()) 164 | 165 | t2 = time.time() 166 | t = t2-t1 167 | t_array[patient_order] = t 168 | 169 | one_hot_label = one_hot(totallabel, class_num) 170 | one_hot_predic = one_hot(totalpredic, class_num) 171 | # for i in range(one_hot_predic[20].shape[0]): 172 | # one_hot_predic[20, i] = morphology.erosion(one_hot_predic[20, i], np.ones([1, 1])) 173 | 174 | if cal_dice: 175 | Dice = np.zeros(class_num) 176 | for i in range(class_num): 177 | Dice[i] = dc(one_hot_predic[i], one_hot_label[i]) 178 | dice_array[patient_order] = Dice 179 | print('patient order', patient_order, ' dice:', Dice) 180 | 181 | if cal_assd: 182 | Assd = np.zeros(class_num) 183 | for i in range(class_num): 184 | Assd[i] = assd(one_hot_predic[i], one_hot_label[i], 1) 185 | assd_array[patient_order] = Assd 186 | print('patient order', patient_order, ' dice:', Assd) 187 | 188 | if cal_hd95: 189 | Hd95 = np.zeros(class_num) 190 | for i in range(class_num): 191 | Hd95[i] = hd95(one_hot_predic[i], one_hot_label[i], 1) 192 | hd95_array[patient_order] = Hd95 193 | print('patient order', patient_order, ' Hd95:', Hd95) 194 | 195 | 196 | if show: 197 | for i in np.arange(0, totalpredic.shape[0], 2): 198 | f, plots = plt.subplots(1, 2) 199 | plots[0].imshow(totalpredic[i]) 200 | plots[1].imshow(totallabel[i]) 201 | plt.show() 202 | if save : 203 | if output_probability: 204 | save_array_as_nifty_volume(totalfeature, patient_path+'/'+subprob_name, transpose=False, pixel_spacing=[1,1,1,1]) 205 | # 
np.save(patient_path+'/'+subseg_name, totalfeature) 206 | save_array_as_nifty_volume(totalpredic, patient_path + '/' +subseg_name) 207 | # np.savetxt(patient_path+'/Dice.npy', Dice.squeeze()) 208 | # np.savetxt(patient_path+'/Assd.npy', Assd.squeeze()) 209 | 210 | if cal_dice: 211 | dice_array[:, 0] = np.mean(dice_array[:, 1::], 1) 212 | dice_mean = np.mean(dice_array, 0) 213 | dice_std = np.std(dice_array, 0) 214 | # weight_score = np.inner(dice_mean, class_weight) 215 | print('{0:} mode: mean dice:{1:}, std of dice:{2:}'.format(mode, dice_mean, dice_std))#, weight_score)) 216 | if show_hist: 217 | plt.figure('hist') 218 | for i in range(class_num-1): 219 | plt.subplot(4, 6, i+1) 220 | plt.hist(dice_array[:, i+1], bins=10, range=(0, 1)) 221 | plt.show() 222 | if cal_assd: 223 | assd_array[:, 0] = np.mean(assd_array[:, 1::], 1) 224 | assd_mean = np.mean(assd_array, 0) 225 | assd_std = np.std(assd_array, 0) 226 | print('{0:} mode: mean assd:{1:}, std of assd:{2:}'.format(mode, assd_mean, assd_std)) 227 | 228 | if cal_hd95: 229 | hd95_array[:, 0] = np.mean(hd95_array[:, 1::], 1) 230 | hd95_mean = np.mean(hd95_array, 0) 231 | hd95_std = np.std(hd95_array, 0) 232 | if save_array_as_xls: 233 | mean = pd.DataFrame(hd95_mean) 234 | mean_writer = pd.ExcelWriter('hd95_mean.xlsx') 235 | mean.to_excel(mean_writer, 'page_1', float_format='%.3f') 236 | mean_writer.save() 237 | mean_writer.close() 238 | 239 | std = pd.DataFrame(hd95_std) 240 | std_writer = pd.ExcelWriter('hd95_std.xlsx') 241 | std.to_excel(std_writer, 'page_1', float_format='%.3f') 242 | std_writer.save() 243 | std_writer.close() 244 | print('{0:} mode: mean HD95:{1:}, std of HD95:{2:}'.format(mode, hd95_mean, hd95_std)) 245 | t_mean = [t_array.mean()] 246 | t_std = [t_array.std()] 247 | print('{0:} mode: mean time:{1:}, std of time:{2:}'.format(mode, t_mean, t_std)) 248 | 249 | 250 | 251 | if __name__ == '__main__': 252 | #for i in range(6): 253 | config_file = str('config/pnet_test.txt') 254 | assert 
(os.path.isfile(config_file)) 255 | seg(config_file) 256 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import absolute_import, print_function 4 | 5 | import time 6 | import torch.optim as optim 7 | from torch.nn import parallel 8 | import torch.tensor 9 | from models.LYC_data_loader import LYC_dataset, set_noclass_zero 10 | from util.train_test_func import * 11 | from util.parse_config import parse_config 12 | from models.Unet_Separate import Unet_Separate 13 | from models.Unet import Unet 14 | from models.loss_function import TestDiceLoss, SoftDiceLoss, ExpDiceLoss, AttentionExpDiceLoss, DiceLoss, make_one_hot 15 | from models.Pymic_loss import soft_dice_loss, get_soft_label 16 | from util.visualization.visualize_loss import loss_visualize 17 | from util.visualization.show_param import show_param 18 | 19 | class NetFactory(object): 20 | @staticmethod 21 | def create(name): 22 | if name == 'Unet': 23 | return Unet 24 | 25 | if name == 'Unet_Separate': 26 | return Unet_Separate 27 | 28 | print('unsupported network:', name) 29 | exit() 30 | 31 | def clip_gradient(optimizer, grad_clip): 32 | """ 33 | Clips gradients computed during backpropagation to avoid explosion of gradients. 
34 | 35 | :param optimizer: optimizer with the gradients to be clipped 36 | :param grad_clip: clip value 37 | """ 38 | for group in optimizer.param_groups: 39 | for param in group["params"]: 40 | if param.grad is not None: 41 | param.grad.data.clamp_(-grad_clip, grad_clip) 42 | 43 | 44 | 45 | def train(config_file): 46 | # 1, load configuration parameters 47 | print('1.Load parameters') 48 | config = parse_config(config_file) 49 | config_data = config['data'] # data config, like data_shape,batch_size, 50 | config_net = config['network'] # net config, like net_name,base_feature_name,class_num 51 | config_train = config['training'] 52 | 53 | random.seed(config_train.get('random_seed', 1)) 54 | 55 | valid_patient_number = len(os.listdir(config_data['data_root']+'/'+'valid')) 56 | net_type = config_net['net_type'] 57 | class_num = config_net['class_num'] 58 | batch_size = config_data.get('batch_size', 4) 59 | lr = config_train.get('learning_rate', 1e-3) 60 | best_dice = config_train.get('best_dice', 0.5) 61 | 62 | # 2, load data 63 | print('2.Load data') 64 | trainData = LYC_dataset(config_data, 'train') 65 | validData = LYC_dataset(config_data, 'valid') 66 | 67 | # 3. 
creat model 68 | print('3.Creat model') 69 | net_class = NetFactory.create(net_type) 70 | net = net_class(inc=config_net.get('input_channel', 1), 71 | n_classes = class_num, 72 | base_chns= config_net.get('base_feature_number', 16), 73 | droprate=config_net.get('drop_rate', 0.2), 74 | norm='in', 75 | depth=config_net.get('depth', False), 76 | dilation=config_net.get('dilation', 1), 77 | separate_direction='axial' 78 | ) 79 | net = torch.nn.DataParallel(net, device_ids=[0, 1]).cuda() 80 | if config_train['load_weight']: 81 | weight = torch.load(config_train['model_path'], map_location=lambda storage, loc: storage) 82 | net.load_state_dict(weight) 83 | 84 | show_param(net) 85 | 86 | dice_eval = TestDiceLoss(n_class=class_num) 87 | loss_func = AttentionExpDiceLoss(n_class=class_num, alpha=0.5) 88 | show_loss = loss_visualize(class_num) 89 | 90 | Adamoptimizer = optim.Adam(net.parameters(), lr=lr, weight_decay= config_train.get('decay', 1e-7)) 91 | Adamscheduler = torch.optim.lr_scheduler.StepLR(Adamoptimizer, step_size=10, gamma=0.9) 92 | 93 | # 4, start to train 94 | print('4.Start to train') 95 | dice_file = config_train['model_save_prefix'] + "_dice.txt" 96 | start_it = config_train.get('start_iteration', 0) 97 | dice_save= np.zeros([config_train['maximal_epoch'], 2+class_num]) 98 | for n in range(start_it, config_train['maximal_epoch']): 99 | train_loss_list, train_dice_list = np.zeros(config_train['train_step']//config_train['print_step']), np.zeros([config_train['train_step']//config_train['print_step'], class_num]) 100 | valid_loss_list, valid_dice_list = np.zeros(valid_patient_number), np.zeros([valid_patient_number, class_num]) 101 | 102 | 103 | optimizer = Adamoptimizer 104 | 105 | net.train() 106 | print('###train###\n') 107 | for step in range(config_train['train_step']): 108 | train_pair = trainData.get_subimage_batch() 109 | tempx = torch.FloatTensor(train_pair['images']).cuda() 110 | tempy = torch.FloatTensor(train_pair['labels']).cuda() 111 | # 
soft_tempy = get_soft_label(tempy.unsqueeze(1), class_num) 112 | predic = net(tempx) 113 | train_loss = loss_func(predic, tempy) 114 | optimizer.zero_grad() 115 | train_loss.backward() 116 | # torch.nn.utils.clip_grad_norm(net.parameters(), 10) 117 | optimizer.step() 118 | if step%config_train['print_step']==0: 119 | train_loss = train_loss.cpu().data.numpy() 120 | train_loss_list[step//config_train['print_step']] = train_loss 121 | train_dice = dice_eval(predic, tempy) 122 | train_dice = train_dice.cpu().data.numpy() 123 | train_dice_list[step//config_train['print_step']] = train_dice 124 | print('train loss:', train_loss, ' train dice:', train_dice) 125 | Adamscheduler.step() 126 | 127 | print('###test###\n') 128 | with torch.no_grad(): 129 | net.eval() 130 | for patient_order in range(valid_patient_number): 131 | valid_pair, patient_path = validData.get_list_img(patient_order) 132 | clip_number = len(valid_pair['images']) 133 | clip_height = config_data['test_data_shape'][0] 134 | total_labels = valid_pair['labels'].cuda() 135 | predic_size = torch.Size([1, class_num])+total_labels.size()[1::] 136 | totalpredic = torch.zeros(predic_size).cuda() 137 | 138 | for i in range(clip_number): 139 | tempx = valid_pair['images'][i].cuda() 140 | pred = net(tempx) 141 | # pred[:, 0][tempx[:, 0] <= 0.0001] = 1 142 | if i < clip_number-1: 143 | totalpredic[:, :, i * clip_height:(i + 1) * clip_height] = pred 144 | else: 145 | totalpredic[:, :, -clip_height::] = pred 146 | 147 | valid_dice = dice_eval(totalpredic, total_labels, show=True).cpu().data.numpy() 148 | valid_dice_list[patient_order] = valid_dice 149 | print(' valid dice:', valid_dice) 150 | 151 | 152 | batch_dice = [valid_dice_list.mean(axis=0), train_dice_list.mean(axis=0)] 153 | t = time.strftime('%X %x %Z') 154 | print(t, 'n', n, '\ndice:\n', batch_dice) 155 | show_loss.plot_loss(n, batch_dice) 156 | train_dice_mean = np.asarray([batch_dice[1][1::].mean(axis=0)]) 157 | valid_dice_classes = batch_dice[0][1::] 158 | 
valid_dice_mean = np.asarray([valid_dice_classes.mean(axis=0)]) 159 | batch_dice = np.append(np.append(train_dice_mean, 160 | valid_dice_mean), valid_dice_classes) 161 | dice_save[n] = np.append(n, batch_dice) 162 | 163 | if batch_dice[1] > best_dice: 164 | best_dice = batch_dice[1] 165 | torch.save(net.state_dict(), config_train['model_save_prefix'] + "_{0:}.pkl".format(batch_dice[1])) 166 | 167 | if __name__ == '__main__': 168 | config_file = str('config/train.txt') 169 | assert(os.path.isfile(config_file)) 170 | train(config_file) 171 | -------------------------------------------------------------------------------- /util/._.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/SepNet/934f71ba53a3c1906c5ff3558c3805f4ce790c9a/util/._.DS_Store -------------------------------------------------------------------------------- /util/._assd_evaluation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/SepNet/934f71ba53a3c1906c5ff3558c3805f4ce790c9a/util/._assd_evaluation.py -------------------------------------------------------------------------------- /util/._dice_evaluation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/SepNet/934f71ba53a3c1906c5ff3558c3805f4ce790c9a/util/._dice_evaluation.py -------------------------------------------------------------------------------- /util/._pre_process.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/SepNet/934f71ba53a3c1906c5ff3558c3805f4ce790c9a/util/._pre_process.py -------------------------------------------------------------------------------- /util/._train_test_func.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HiLab-git/SepNet/934f71ba53a3c1906c5ff3558c3805f4ce790c9a/util/._train_test_func.py -------------------------------------------------------------------------------- /util/._visualize.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/SepNet/934f71ba53a3c1906c5ff3558c3805f4ce790c9a/util/._visualize.py -------------------------------------------------------------------------------- /util/Label_exist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy.ndimage 4 | import nibabel 5 | from util.pre_process import * 6 | 7 | ### 由于肿瘤医院数据标注不全,我们需要另存一个文档指出当前数据具有哪些标注 8 | 9 | data_root = '/lyc/MICCAI-19-StructSeg/HaN_OAR/train' 10 | filename = 'crop_label.nii.gz' 11 | savefilename = 'label_exist.npy' 12 | file_list = os.listdir(data_root) 13 | classnum = 23 14 | 15 | for file in file_list: 16 | data_path = os.path.join(data_root, file, filename) 17 | save_path = os.path.join(data_root, file, savefilename) 18 | data = load_nifty_volume_as_array(data_path) 19 | label_exist = np.zeros(classnum) 20 | for i in range(classnum): 21 | if np.sum(np.where(data == i)) > 5: 22 | label_exist[i] = 1 23 | np.savetxt(save_path, label_exist) 24 | print('已储存', file) 25 | -------------------------------------------------------------------------------- /util/assd_evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import sys 4 | import csv 5 | import math 6 | import nibabel 7 | import random 8 | import numpy as np 9 | from scipy import ndimage 10 | 11 | def one_hot(img, nb_classes): 12 | hot_img = np.zeros([nb_classes]+list(img.shape)) 13 | for i in range(nb_classes): 14 | hot_img[i][np.where(img == i)] = 1 15 | return hot_img 16 | 17 | def binary_assd3d(s, g, spacing = [1.0, 1.0, 1.0]): 18 | 
assert(len(s.shape)==3) 19 | [Ds, Hs, Ws] = s.shape 20 | [Dg, Hg, Wg] = g.shape 21 | assert(Ds==Dg and Hs==Hg and Ws==Wg) 22 | scale = [1.0, 1/spacing[1], 1/spacing[2]] 23 | s_resample = ndimage.interpolation.zoom(s, scale, order = 0) 24 | g_resample = ndimage.interpolation.zoom(g, scale, order = 0) 25 | point_list_s = volume_to_surface(s_resample) # 含所有边界点的列表 26 | point_list_g = volume_to_surface(g_resample) 27 | new_spacing = [spacing[0], 1.0, 1.0] 28 | dis_array1 = assd_distance_from_one_surface_to_another(point_list_s, point_list_g, new_spacing) 29 | dis_array2 = assd_distance_from_one_surface_to_another(point_list_g, point_list_s, new_spacing) 30 | assd = (dis_array1.sum() + dis_array2.sum())/(len(dis_array1) + len(dis_array2)) 31 | return assd 32 | 33 | def assd_distance_from_one_surface_to_another(point_list_s, point_list_g, spacing): 34 | dis_square = 0.0 35 | n_max = 500 36 | if(len(point_list_s) > n_max): 37 | point_list_s = random.sample(point_list_s, n_max) 38 | distance_array = np.zeros(len(point_list_s)) 39 | for i in range(len(point_list_s)): 40 | ps = point_list_s[i] 41 | ps_nearest = 1e10 42 | for pg in point_list_g: 43 | dd = spacing[0]*(ps[0] - pg[0]) 44 | dh = spacing[1]*(ps[1] - pg[1]) 45 | dw = spacing[2]*(ps[2] - pg[2]) 46 | temp_dis_square = dd*dd + dh*dh + dw*dw 47 | if(temp_dis_square < ps_nearest): 48 | ps_nearest = temp_dis_square 49 | distance_array[i] = math.sqrt(ps_nearest) 50 | return distance_array 51 | 52 | def volume_to_surface(img): 53 | strt = ndimage.generate_binary_structure(3,2) 54 | img = ndimage.morphology.binary_closing(img, strt, 5) 55 | point_list = [] 56 | [D, H, W] = img.shape 57 | offset_d = [-1, 1, 0, 0, 0, 0] 58 | offset_h = [ 0, 0, -1, 1, 0, 0] 59 | offset_w = [ 0, 0, 0, 0, -1, 1] 60 | for d in range(1, D-1): 61 | for h in range(1, H-1): 62 | for w in range(1, W-1): 63 | if(img[d, h, w] > 0): 64 | edge_flag = False 65 | for idx in range(6): 66 | if(img[d + offset_d[idx], h + offset_h[idx], w + offset_w[idx]] == 0): 
# 在6个方向上迈一步只要有一个为0,该点就为边界点 67 | edge_flag = True 68 | break 69 | if(edge_flag): 70 | point_list.append([d, h, w]) 71 | return point_list 72 | 73 | 74 | 75 | 76 | 77 | # if __name__ == '__main__': 78 | # if(len(sys.argv) != 2): 79 | # print('Number of arguments should be 2. e.g.') 80 | # print(' python util/dice_evaluation.py config.txt') 81 | # exit() -------------------------------------------------------------------------------- /util/collect_organism_hist.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import os 4 | import matplotlib.pyplot as plt 5 | from matplotlib.path import Path 6 | import matplotlib.patches as patches 7 | import dicom 8 | import numpy as np 9 | 10 | def collect_organism_hist(path,Label_wanted): 11 | 12 | Windows_img = {} # 用来保存窗口图片 13 | maskpoint = np.ones([512, 512, 2]) # 生成各点坐标 14 | patient_count = 0 15 | for x in range(512): 16 | for y in range(512): 17 | maskpoint[x, y] = [x, y] 18 | 19 | for patient_file in os.listdir(path): 20 | try: 21 | patient_count += 1 22 | print('正在处理患者%s,这是第%d个病人' % (patient_file, patient_count)) 23 | ctslices = [] 24 | ctnumber = 0 # 记录CT个数 25 | for s in os.listdir(str(path) + '/' + patient_file): # 载入文件 26 | if 'CT' in s: 27 | ctnumber += 1 28 | ctslices.append(dicom.read_file(str(path)+'/'+patient_file + '/' + s)) # 判断是否为CT 29 | if 'RS' in s: 30 | rsslices = dicom.read_file(str(path)+'/'+patient_file + '/' + s) # 读入RS文件 31 | 32 | ctslices.sort(key=lambda x: int(x.ImagePositionPatient[2])) # 按z坐标从小到大排序 33 | origin = [s.ImagePositionPatient for s in ctslices] # 网格原点在世界坐标系的位置 34 | spacing = ctslices[0].PixelSpacing # 采样间隔 35 | intercept = ctslices[0].RescaleIntercept # 重采样截距 36 | slope = ctslices[0].RescaleSlope # 重采样斜率 37 | 38 | 39 | "提取患者的第I个靶区" 40 | for i in range(len(rsslices.RTROIObservationsSequence)): # 第i个勾画区域 41 | label = rsslices.RTROIObservationsSequence[i].ROIObservationLabel # ROIObservationLabel即该ROI是何器官 42 | label = 
label.lower() 43 | label = label.replace(' ', '') 44 | label = label.replace('-', '') 45 | label = label.replace('_', '') 46 | 47 | if label in Label_wanted: 48 | try: 49 | print(label) 50 | maskdata = np.zeros([ctnumber, 512, 512]) 51 | for j in range(len(rsslices.ROIContourSequence[i].ContourSequence)): # 第j层靶区曲线 52 | 53 | "提取靶区轮廓线坐标并转换为世界坐标" 54 | numberOfPoints = rsslices.ROIContourSequence[i].ContourSequence[j].NumberofContourPoints # 该层曲线上点数 55 | Data = rsslices.ROIContourSequence[i].ContourSequence[j].ContourData 56 | conData = np.zeros([numberOfPoints, 3]) # 存储靶区曲线各点的世界坐标 57 | pointdata = np.zeros([numberOfPoints, 2]) # 存储靶区曲线各点的体素坐标 58 | Z = Data[2] # 当前曲线的z坐标 59 | znumber = round((Z-origin[0][2])/3) # 当前靶线是第几层CT 60 | ctimg = np.array(ctslices[round(znumber)].pixel_array) 61 | ctimg[ctimg == -2000] = 0 62 | if slope != 1: 63 | ctimg = slope * ctimg.astype(np.float64) 64 | ctimg = ctimg.astype(np.int16) 65 | ctimg = ctimg.astype(np.int16) 66 | ctimg += np.int16(intercept) 67 | for jj in range(numberOfPoints): 68 | ii = jj * 3 69 | conData[jj, 0] = Data[ii ] # 轮廓世界坐标系 70 | conData[jj, 1] = Data[ii + 1] 71 | conData[jj, 2] = Data[ii + 2] 72 | pointdata[jj, 0] = round((conData[jj, 0] - origin[0][0])/spacing[0]) # 轮廓X坐标 73 | pointdata[jj, 1] = round((conData[jj, 1] - origin[0][1])/spacing[1]) # 轮廓Y坐标 74 | 75 | "生成靶区mask" 76 | pointdata = np.array(pointdata) 77 | polyline = Path(pointdata, closed=True) # 制成闭合的曲线 78 | maskpoint_reshape = maskpoint.reshape(512*512, 2) 79 | pointin = polyline.contains_points(maskpoint_reshape) 80 | maskpoint_reshape = maskpoint_reshape[pointin, :] 81 | for k in maskpoint_reshape: 82 | maskdata[znumber][int(k[0]), int(k[1])] = 1 83 | if label in Windows_img: 84 | Label_img = (maskdata[znumber] * ctimg).flatten() 85 | Label_cor = Label_img != 0 86 | Label_img = Label_img[Label_cor] 87 | np.concatenate((Windows_img[Label_wanted[label]], Label_img)) 88 | else: 89 | Label_img = (maskdata[znumber]*ctimg).flatten() 90 | Label_cor = 
Label_img != 0 91 | Label_img = Label_img[Label_cor] 92 | Windows_img[Label_wanted[label]] = Label_img 93 | 94 | except: 95 | print('患者%s的%s标注有点问题' % (patient_file, label)) 96 | except: 97 | print('患者%s的图像数据有些问题' % patient_file) 98 | 99 | # Windows_img[Label_wanted][Windows_img[Label_wanted] > 200 ] = 0 100 | # Windows_img[Label_wanted][Windows_img[Label_wanted] < -200] = 0 101 | for label in Windows_img: 102 | 103 | plt.hist(Windows_img[label], log=True) 104 | plt.title('%s histogram' % label) 105 | plt.show() 106 | 107 | return 108 | 109 | 110 | if __name__ == "__main__": 111 | 112 | path = '/lyc/RTData/Original CT/RT' 113 | Label_wanted = label_wanted 114 | collect_organism_hist(path, Label_wanted) -------------------------------------------------------------------------------- /util/data_augament.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import matplotlib.pyplot as plt 4 | import random 5 | import cv2 6 | from skimage import measure 7 | 8 | 9 | def flip(img_path, save_path, show_img=False, prob=0.8): 10 | 11 | img_num = len(os.listdir(img_path)) 12 | flipnum = 1 13 | for file in os.listdir(img_path): 14 | horizon_prob = random.random() 15 | vertical_prob = random.random() 16 | if horizon_prob < prob: 17 | img = np.load(img_path + '/' + file) 18 | img_horizon = img[:, :, ::-1] 19 | np.save(save_path + '/' + file[:-4] + '_horizon', img_horizon) 20 | 21 | if vertical_prob < prob: 22 | img = np.load(img_path + '/' + file) 23 | img_vertical = img[:, ::-1, :] 24 | np.save(save_path + '/' + file[:-4] + '_vertical', img_vertical) 25 | flipnum += 1 26 | 27 | if show_img and vertical_prob < prob and horizon_prob < prob: 28 | f, plots = plt.subplots(3, 2, figsize=[60, 60]) 29 | plots[0, 0].axis('off') 30 | plots[0, 0].set_title('img') 31 | plots[0, 0].imshow(img[0], cmap=plt.cm.bone) 32 | plots[0, 1].axis('off') 33 | plots[0, 1].set_title('ground truth') 34 | plots[0, 1].imshow(img[1], 
cmap=plt.cm.bone) 35 | plots[1, 0].axis('off') 36 | plots[1, 0].set_title('img horizon flip') 37 | plots[1, 0].imshow(img_horizon[0], cmap=plt.cm.bone) 38 | plots[1, 1].axis('off') 39 | plots[1, 1].set_title('ground truth flip') 40 | plots[1, 1].imshow(img_horizon[1], cmap=plt.cm.bone) 41 | plots[2, 0].axis('off') 42 | plots[2, 0].set_title('img vertical flip') 43 | plots[2, 0].imshow(img_vertical[0], cmap=plt.cm.bone) 44 | plots[2, 1].axis('off') 45 | plots[2, 1].set_title('ground truth flip') 46 | plots[2, 1].imshow(img_vertical[1], cmap=plt.cm.bone) 47 | plt.show() 48 | print('翻转已完成 %s/%s' % (flipnum, img_num)) 49 | 50 | return 51 | 52 | 53 | def rotation(img_path, save_path, show_img=False, prob=0.8): 54 | 55 | img_num = len(os.listdir(img_path)) 56 | rotatenum = 1 57 | for file in os.listdir(img_path): 58 | rotate_prob = random.random() 59 | if rotate_prob < prob: 60 | img = np.load(img_path + '/' + file) 61 | img_rotation_counter = rotate(img, 20) 62 | img_rotation = rotate(img, -20) 63 | # np.save(save_path + '/' + file[:-4] + '_rotation', img_rotation) 64 | # np.save(save_path + '/' + file[:-4] + '_counter_rotation', img_rotation_counter) 65 | rotatenum += 1 66 | if show_img and rotate_prob < prob: 67 | f, plots = plt.subplots(3, 2, figsize=[60, 60]) 68 | plots[0, 0].axis('off') 69 | plots[0, 0].set_title('img') 70 | plots[0, 0].imshow(img[0], cmap=plt.cm.bone) 71 | plots[0, 1].axis('off') 72 | plots[0, 1].set_title('ground truth') 73 | plots[0, 1].imshow(img[1], cmap=plt.cm.bone) 74 | plots[1, 0].axis('off') 75 | plots[1, 0].set_title('img counter rotation') 76 | plots[1, 0].imshow(img_rotation_counter[0], cmap=plt.cm.bone) 77 | plots[1, 1].axis('off') 78 | plots[1, 1].set_title('ground counter rotation') 79 | plots[1, 1].imshow(img_rotation_counter[1], cmap=plt.cm.bone) 80 | plots[2, 0].axis('off') 81 | plots[2, 0].set_title('img rotation') 82 | plots[2, 0].imshow(img_rotation[0], cmap=plt.cm.bone) 83 | plots[2, 1].axis('off') 84 | plots[2, 
1].set_title('ground rotation') 85 | plots[2, 1].imshow(img_rotation[1], cmap=plt.cm.bone) 86 | plt.show() 87 | 88 | print('旋转已完成 %s/%s' % (rotatenum, img_num)) 89 | 90 | return 91 | 92 | def translation(img_path, save_path, show_img=False, prob=0.8): 93 | """ 94 | 平移图像用 95 | :param img_path: 96 | :param save_path: 97 | :param show_img: 98 | :param prob: 小于就平移 99 | :return: 100 | """ 101 | 102 | img_num = len(os.listdir(img_path)) 103 | translatenum = 1 104 | for file in os.listdir(img_path): 105 | rotate_prob = random.random() 106 | if rotate_prob < prob: 107 | img = np.load(img_path + '/' + file) 108 | img_up = translate(img, 0, 10) 109 | img_down = translate(img, 0, -10) 110 | np.save(save_path + '/' + file[:-4] + '_up', img_up) 111 | np.save(save_path + '/' + file[:-4] + '_down', img_down) 112 | translatenum += 1 113 | if show_img and rotate_prob < prob: 114 | f, plots = plt.subplots(3, 2, figsize=[60, 60]) 115 | plots[0, 0].axis('off') 116 | plots[0, 0].set_title('img') 117 | plots[0, 0].imshow(img[0], cmap=plt.cm.bone) 118 | plots[0, 1].axis('off') 119 | plots[0, 1].set_title('ground truth') 120 | plots[0, 1].imshow(img[1], cmap=plt.cm.bone) 121 | plots[1, 0].axis('off') 122 | plots[1, 0].set_title('img counter rotation') 123 | plots[1, 0].imshow(img_up[0], cmap=plt.cm.bone) 124 | plots[1, 1].axis('off') 125 | plots[1, 1].set_title('ground counter rotation') 126 | plots[1, 1].imshow(img_up[1], cmap=plt.cm.bone) 127 | plots[2, 0].axis('off') 128 | plots[2, 0].set_title('img rotation') 129 | plots[2, 0].imshow(img_down[0], cmap=plt.cm.bone) 130 | plots[2, 1].axis('off') 131 | plots[2, 1].set_title('ground rotation') 132 | plots[2, 1].imshow(img_down[1], cmap=plt.cm.bone) 133 | plt.show() 134 | 135 | print('平移已完成 %s/%s' % (translatenum, img_num)) 136 | 137 | return 138 | 139 | def rotate(img, angle, center=None, scale=1.0): 140 | ''' 141 | 旋转图像 142 | :param img: 143 | :param angle: 逆时针旋转角度 144 | :param center: 旋转中心,不指定默认为图像中心 145 | :param scale: 尺度变化参数 146 | 
:return: 147 | ''' 148 | image = np.transpose(img, (1, 2, 0)) 149 | (h, w) = image.shape[:2] 150 | # 若未指定旋转中心,则将图像中心设为旋转中心 151 | if center is None: 152 | center = (w / 2, h / 2) 153 | # 执行旋转 154 | M = cv2.getRotationMatrix2D(center, angle, scale) 155 | rotated = cv2.warpAffine(image, M, (w, h)) 156 | rotated = np.transpose(rotated, (2, 0, 1)) 157 | return rotated 158 | 159 | 160 | def translate(img, x, y): 161 | ''' 162 | 图像平移函数 163 | 原numpy文件通道数在第0维,cv2操作前需先转到第二维 164 | :param img: 原文件 165 | :param x: x轴平移距离 166 | :param y: y轴平移距离 167 | :return: 平移后图像 168 | ''' 169 | 170 | image = np.transpose(img, (1, 2, 0)) 171 | # 定义平移矩阵 172 | M = np.float32([[1, 0, x], [0, 1, y]]) 173 | shifted = cv2.warpAffine(image, M, (image.shape[1], image.shape[0])) 174 | shifted = np.transpose(shifted, (2, 0, 1)) 175 | # 返回转换后的图像 176 | return shifted 177 | 178 | 179 | def clip(img, module=None, clip_module='test', box=None, patient_name=None, znumber=None, width=256, length=256, multipl=1, multinumber=4): 180 | ''' 181 | :param img: 182 | :param box: 指定身体区域,为行列最小与行列最大值 183 | :param module: 判断分割身体还是器官 184 | :param patient_name: 患者姓名 185 | :param znumber : 当前图像纵坐标 186 | :param width : 裁剪图像的宽 187 | :param length : 裁剪图像的高 188 | :param multipl: 判断是否要裁取多个图,即是否4倍增强. 
189 | :param multinumber: 增强倍数 190 | :return: 裁剪后图像 191 | ''' 192 | if module == 'body': 193 | center = [(box[0]+box[2])/2, (box[1]+box[3])/2] 194 | body_length = box[2] - box[0] 195 | body_width = box[3] - box[1] 196 | if body_width < width and body_length < length: 197 | if multipl == 1: 198 | height_random = min(box[0], length - body_length) 199 | width_random = min(box[1], width - body_width) 200 | for i in range(multinumber): 201 | rowmin = int(box[0] - random.randint(0, height_random)) # 裁取图像的行坐标起点 202 | colmin = int(box[1] - random.randint(0, width_random)) # 裁取图像的纵坐标起点 203 | Img = img[:, rowmin:rowmin + width, colmin:colmin + length] 204 | plt.imshow(Img[0], cmap='bone') 205 | plt.show() 206 | np.save('/lyc/RTData/Parotid256/train/%s_%s_multi%s' % (patient_name, znumber, i), Img) 207 | else: 208 | rowmin = int(center[0] - width) 209 | colmin = int(center[1] - width) 210 | Img = img[:, rowmin:rowmin + width, colmin:colmin + length] 211 | plt.imshow(img[0], cmap='bone') 212 | plt.show() 213 | np.save('/lyc/RTData/OpticChaism128/train/%s_%s' % (patient_name, znumber), Img) 214 | 215 | elif module == 'organism': 216 | labels = measure.label(img[1]) 217 | regions = measure.regionprops(labels) 218 | if len(regions) > 0: 219 | box = [512, 512, 0, 0] 220 | for labelnumber in regions: # 找出该ct中所有标注的最大最小横纵坐标值. 
221 | box[0], box[1] = min(labelnumber.bbox[0], box[0]), min(labelnumber.bbox[1], box[1]) 222 | box[2], box[3] = max(labelnumber.bbox[2], box[2]), max(labelnumber.bbox[3], box[3]) 223 | 224 | # 为了避免裁剪时让label太靠近边界,故框四边各向外扩张16 225 | box[0] -= 16 226 | box[1] -= 16 227 | box[2] += 16 228 | box[3] += 16 229 | 230 | label_height = box[2] - box[0] 231 | label_width = box[3] - box[1] 232 | if label_width < width and label_height < length: 233 | if multipl == 1: 234 | row_random_max = min(box[0], 512 - width) 235 | row_random_min = max(box[2]-width, 0) 236 | col_random_max = min(box[1], 512 - length) 237 | col_random_min = max(box[3]-length, 0) 238 | for i in range(multinumber): 239 | colmin = random.randint(col_random_min, col_random_max) # 裁取图像的行坐标起点 240 | rowmin = random.randint(row_random_min, row_random_max) # 裁取图像的纵坐标起点 241 | Img = img[:, rowmin:rowmin + width, colmin:colmin + length] 242 | # f, plots = plt.subplots(1, 4, figsize=(60, 60)) 243 | # plots[0].imshow(Img[0], cmap='bone') 244 | # plots[1].imshow(Img[1]) 245 | # plots[2].imshow(Img[1]*Img[0]) 246 | # plots[3].imshow(img[0]) 247 | # plt.show() 248 | np.save('/lyc/RTData/OpticNerve/%s/%s/%s_%s_multi%s' % (width, clip_module, patient_name, znumber, i), Img ) 249 | 250 | return 251 | 252 | def mkdir(path): 253 | """ 254 | 创建path所给文件夹 255 | :param path: 256 | :return: 257 | """ 258 | folder = os.path.exists(path) 259 | 260 | if not folder: # 判断是否存在文件夹如果不存在则创建为文件夹 261 | os.makedirs(path) # makedirs 创建文件时如果路径不存在会创建这个路径 262 | print("--- new folder... ---") 263 | 264 | print("--- OK ---") 265 | 266 | else: 267 | print("--- There is this folder! 
---") 268 | 269 | 270 | def ThreeDclip(img, label, module=None, clip_module='test', box=None, duplicate_path=None, 271 | patient_name=None, znumber=None, width=256, length=256, multipl=0, multinumber=4, show_label=False): 272 | ''' 273 | :param img: 274 | :param box: 指定身体区域,为行列最小与行列最大值 275 | :param module: 判断分割身体还是器官 276 | :param patient_name: 患者姓名 277 | :param znumber : 当前图像纵坐标 278 | :param width : 裁剪图像的宽 279 | :param length : 裁剪图像的高 280 | :param multipl: 判断是否要裁取多个图,即是否4倍增强. 281 | :param multinumber: 增强倍数 282 | :return: 裁剪后图像 283 | ''' 284 | save_path = duplicate_path + '/' + clip_module + '/' + patient_name 285 | mkdir(save_path) 286 | if module == 'body': 287 | body_length = max(box[:, 2]) - min(box[:, 0]) 288 | body_width = max(box[:, 3]) - min(box[:, 1]) 289 | if body_width > width or body_length > length: 290 | print('图像设置太小, patient: %s, 所需宽: %d, 所需长: %d' % (patient_name, body_width, body_length)) 291 | body_length = length 292 | body_width = width 293 | center = [256, 256] 294 | else: 295 | center = [(min(box[:, 0]) + max(box[:, 2])) / 2, (min(box[:, 1]) + max(box[:, 3])) / 2] 296 | if multipl == 1: 297 | height_random = min(box[0], length - body_length) 298 | width_random = min(box[1], width - body_width) 299 | for i in range(multinumber): 300 | rowmin = int(box[0] - random.randint(0, height_random)) # 裁取图像的行坐标起点 301 | colmin = int(box[1] - random.randint(0, width_random)) # 裁取图像的纵坐标起点 302 | Img = img[:, rowmin:rowmin + width, colmin:colmin + length] 303 | label = label[:, rowmin:rowmin + width, colmin:colmin + length] 304 | if show_label: 305 | for i in range(len(label)): 306 | if i % 5 == 0: 307 | f, plots = plt.subplots(1, 2, figsize=(60, 60)) 308 | plots[0].imshow(Img[i], cmap=plt.cm.bone) 309 | plots[1].imshow(Img[i] * label[i]) 310 | plt.show() 311 | np.save(save_path + '/' + 'Img.npy', Img) 312 | np.save(save_path + '/' + 'label.npy', label) 313 | else: 314 | rowmin = int(center[0] - width/2) 315 | colmin = int(center[1] - length/2) 316 | Img = 
img[:, rowmin:rowmin + width, colmin:colmin + length] 317 | label = label[:, rowmin:rowmin + width, colmin:colmin + length] 318 | if show_label: 319 | for i in range(len(label)): 320 | if i % 5 == 0: 321 | f, plots = plt.subplots(1, 2, figsize=(60, 60)) 322 | plots[0].imshow(Img[i], cmap=plt.cm.bone) 323 | plots[1].imshow(Img[i] * label[i]) 324 | plt.show() 325 | np.save(save_path + '/' + 'Img.npy', Img) 326 | np.save(save_path + '/' + 'label.npy', label) 327 | 328 | 329 | elif module == 'organism': 330 | labels = measure.label(img[1]) 331 | regions = measure.regionprops(labels) 332 | if len(regions) > 0: 333 | box = [512, 512, 0, 0] 334 | for labelnumber in regions: # 找出该ct中所有标注的最大最小横纵坐标值. 335 | box[0], box[1] = min(labelnumber.bbox[0], box[0]), min(labelnumber.bbox[1], box[1]) 336 | box[2], box[3] = max(labelnumber.bbox[2], box[2]), max(labelnumber.bbox[3], box[3]) 337 | 338 | # 为了避免裁剪时让label太靠近边界,故框四边各向外扩张16 339 | box[0] -= 16 340 | box[1] -= 16 341 | box[2] += 16 342 | box[3] += 16 343 | 344 | label_height = box[2] - box[0] 345 | label_width = box[3] - box[1] 346 | if label_width < width and label_height < length: 347 | if multipl == 1: 348 | row_random_max = min(box[0], 512 - width) 349 | row_random_min = max(box[2]-width, 0) 350 | col_random_max = min(box[1], 512 - length) 351 | col_random_min = max(box[3]-length, 0) 352 | for i in range(multinumber): 353 | colmin = random.randint(col_random_min, col_random_max) # 裁取图像的行坐标起点 354 | rowmin = random.randint(row_random_min, row_random_max) # 裁取图像的纵坐标起点 355 | Img = img[:, rowmin:rowmin + width, colmin:colmin + length] 356 | # f, plots = plt.subplots(1, 4, figsize=(60, 60)) 357 | # plots[0].imshow(Img[0], cmap='bone') 358 | # plots[1].imshow(Img[1]) 359 | # plots[2].imshow(Img[1]*Img[0]) 360 | # plots[3].imshow(img[0]) 361 | # plt.show() 362 | np.save('/lyc/RTData/OpticNerve/%s/%s/%s_%s_multi%s' % (width, clip_module, patient_name, znumber, i), Img ) 363 | 364 | return 365 | 366 | if __name__ == "__main__": 367 | 
img_path = '/lyc/RTData/OpticNerve/256/train' 368 | save_path = '/lyc/RTData/OpticNerve/256/aug' 369 | show_img = False 370 | 371 | flip(img_path, save_path, show_img=show_img, prob=0.8) 372 | rotation(img_path, save_path, show_img=show_img, prob=0.8) 373 | # translation(img_path, save_path, show_img=show_img, prob=0.8) -------------------------------------------------------------------------------- /util/dice_evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import, print_function 3 | import os 4 | import sys 5 | import numpy as np 6 | import mpu.ml 7 | from scipy import ndimage 8 | import matplotlib.pyplot as plt 9 | from Training.util.parse_config import parse_config 10 | from Training.data_process.data_process_func import load_nifty_volume_as_array 11 | from Training.util.binary import dc, assd 12 | 13 | def get_largest_component(img): 14 | s = ndimage.generate_binary_structure(3,1) # iterate structure 15 | labeled_array, numpatches = ndimage.label(img,s) # labeling 16 | sizes = ndimage.sum(img,labeled_array,range(1,numpatches+1)) 17 | max_label = np.where(sizes == sizes.max())[0] + 1 18 | return labeled_array == max_label 19 | 20 | def binary_dice3d(s,g): 21 | assert(len(s.shape)==2) 22 | prod = np.multiply(s, g) 23 | s0 = prod.sum(axis=-1) 24 | s1 = s.sum(axis=-1) 25 | s2 = g.sum(axis=-1) 26 | dice = 2.0*s0/(s1 + s2 + 0.00001) 27 | return dice[1::] 28 | 29 | def dice_of_binary_volumes(s_name, g_name): 30 | s = load_nifty_volume_as_array(s_name) 31 | g = load_nifty_volume_as_array(g_name) 32 | dice = binary_dice3d(s, g) 33 | return dice 34 | 35 | def one_hot(img, nb_classes): 36 | hot_img = np.zeros([nb_classes]+list(img.shape)) 37 | for i in range(nb_classes): 38 | hot_img[i][np.where(img == i)] = 1 39 | return hot_img 40 | 41 | def evaluation(folder, classnum=6, save=False): 42 | patient_list = os.listdir(folder) 43 | dice_all_data = [] 44 | 
assd_all_data = [] 45 | for patient in patient_list: 46 | s_name = os.path.join(folder, patient + '/label.nii.gz') 47 | g_name = os.path.join(folder, patient + '/OverlayInterSeg.nii.gz') 48 | #s_volume = np.int64(np.load(s_name)) 49 | s_volume = load_nifty_volume_as_array(s_name) 50 | g_volume = load_nifty_volume_as_array(g_name) 51 | s_volume = one_hot(s_volume, nb_classes=classnum) 52 | g_volume = one_hot(g_volume, nb_classes=classnum) 53 | dice_list=[] 54 | assd_list=[] 55 | for i in range(classnum): 56 | temp_dice = dc(s_volume[i], g_volume[i]) 57 | temp_assd = assd(s_volume[i], g_volume[i],voxelspacing=[3, 1, 1]) 58 | dice_list.append(temp_dice) 59 | assd_list.append(temp_assd) 60 | dice_all_data.append(dice_list) 61 | assd_all_data.append(assd_list) 62 | print(patient, dice_list) 63 | if save: 64 | np.savetxt(os.path.join(folder, patient + '/Inter_dice.txt'), np.asarray(dice_list)) 65 | dice_all_data = np.asarray(dice_all_data) 66 | dice_mean = [dice_all_data.mean(axis = 0)] 67 | dice_std = [dice_all_data.std(axis = 0)] 68 | if save: 69 | np.savetxt(folder + '/dice_all.txt', dice_all_data) 70 | np.savetxt(folder + '/dice_mean.txt', dice_mean) 71 | np.savetxt(folder + '/dice_std.txt', dice_std) 72 | print('dice mean ', dice_mean) 73 | print('dice std ', dice_std) 74 | assd_all_data = np.asarray(assd_all_data) 75 | assd_mean = [assd_all_data.mean(axis = 0)] 76 | assd_std = [assd_all_data.std(axis = 0)] 77 | print('assd mean ', assd_mean) 78 | print('assd std ', assd_std) 79 | if save: 80 | np.savetxt(folder + '/assd_all.txt', assd_all_data) 81 | np.savetxt(folder + '/assd_mean.txt', assd_mean) 82 | np.savetxt(folder + '/assd_std.txt', assd_std) 83 | 84 | # if __name__ == '__main__': 85 | # if(len(sys.argv) != 2): 86 | # print('Number of arguments should be 2. 
e.g.') 87 | # print(' python util/dice_evaluation.py config.txt') 88 | # exit() 89 | folder = '/lyc/Head-Neck-CT/3D_data/test/' 90 | evaluation(folder, classnum=5, save=False) 91 | -------------------------------------------------------------------------------- /util/dump_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | ### 把鼻咽癌的3D数据转存一下 4 | 5 | def mkdir(path): 6 | folder = os.path.exists(path) 7 | 8 | if not folder: # 判断是否存在文件夹如果不存在则创建为文件夹 9 | os.makedirs(path) # makedirs 创建文件时如果路径不存在会创建这个路径 10 | print("--- new folder... ---") 11 | 12 | print("--- OK ---") 13 | 14 | else: 15 | print("--- There is this folder! ---") 16 | 17 | 18 | 19 | def dumpdata(root, dumproot, filename): 20 | folder = os.path.exists(root) 21 | count = 0 22 | for fn in os.listdir(root): # fn 表示的是文件名 23 | count = count + 1 24 | print(count) 25 | if folder: 26 | for patient in os.listdir(root): 27 | count = count - 1 28 | dumppath = dumproot + patient 29 | mkdir(dumppath) 30 | for file in filename: 31 | filepath = root + patient + '/' + file 32 | dumpfile = np.load(filepath) 33 | np.save(dumppath + '/' + file, dumpfile) 34 | print("%d patients left" % count) 35 | 36 | root = "/lyc/RTData/Original CT/RT/" 37 | dumproot = "/lyc/RTData/3D_data/" 38 | filename = ["Img_norm.npy", "Mask.npy", "Patient_pixel.npy"] 39 | dumpdata(root, dumproot, filename) 40 | -------------------------------------------------------------------------------- /util/grid_normal.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | 5 | def grid_mean_3d(image, grid_size=[4,4,4], norm="mean"): 6 | w, h, d = image.shape 7 | grid_x, grid_y, grid_z = grid_size 8 | assert w % grid_x == 0 9 | assert h % grid_y == 0 10 | assert d % grid_z == 0 11 | round1 = w // grid_x 12 | round2 = h // grid_y 13 | round3 = d // grid_z 14 | all_grid_patchs = [] 15 | for indx in 
range(round1): 16 | for indy in range(round2): 17 | for indz in range(round3): 18 | patch = image[indx*grid_x:(indx+1)*grid_x, indy * 19 | grid_y:(indy+1)*grid_y, indz*grid_z:(indz+1)*grid_z] 20 | all_grid_patchs.append(patch) 21 | 22 | num = 0 23 | grid_mean_image = np.zeros((w, h, d), np.float32) 24 | for indx in range(round1): 25 | for indy in range(round2): 26 | for indz in range(round3): 27 | if norm == "mean": 28 | grid_mean_image[indx*grid_x:(indx+1)*grid_x, indy*grid_y:( 29 | indy+1)*grid_y, indz*grid_z:(indz+1)*grid_z] = all_grid_patchs[num].mean() 30 | if norm == "max": 31 | grid_mean_image[indx*grid_x:(indx+1)*grid_x, indy*grid_y:( 32 | indy+1)*grid_y, indz*grid_z:(indz+1)*grid_z] = all_grid_patchs[num].max() 33 | if norm == "min": 34 | grid_mean_image[indx*grid_x:(indx+1)*grid_x, indy*grid_y:( 35 | indy+1)*grid_y, indz*grid_z:(indz+1)*grid_z] = all_grid_patchs[num].min() 36 | if norm == "random": 37 | patch_value = [all_grid_patchs[num].min( 38 | ), all_grid_patchs[num].max(), all_grid_patchs[num].mean()] 39 | grid_mean_image[indx*grid_x:(indx+1)*grid_x, indy*grid_y:( 40 | indy+1)*grid_y, indz*grid_z:(indz+1)*grid_z] = patch_value[random.randint(0, 2)] 41 | num += 1 42 | return grid_mean_image 43 | -------------------------------------------------------------------------------- /util/make_3d_ground_truth_only.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import os 4 | import matplotlib.pyplot as plt 5 | from matplotlib.path import Path 6 | import matplotlib.patches as patches 7 | import dicom 8 | import numpy as np 9 | from util.pre_function import get_segmented_body 10 | from skimage import measure 11 | from util.data_augament import ThreeDclip 12 | 13 | def make_3d_groundtruth_only(patient_file_root, duplicate_path, label_wanted, label_aug, window_max, window_min, show_body = False, show_label = False): 14 | 15 | maskpoint = np.ones([512, 512, 2]) # 记录图像每点坐标 16 | patient_count = 
0 17 | 18 | for x in range(512): 19 | for y in range(512): 20 | maskpoint[x, y] = [x, y] 21 | patient_number = len(os.listdir(patient_file_root)) 22 | train_number = int(0.5*patient_number) 23 | test_number = int(0.7*patient_number) 24 | 25 | for patient_file in os.listdir(patient_file_root): 26 | 27 | zmax = 0 # 记录Label的最大最小z值 28 | zmin = 300 29 | patient_count += 1 30 | print('正在处理患者%s,这是第%d个病人' % (patient_file, patient_count)) 31 | ctslices = [] 32 | ctnumber = 0 33 | for s in os.listdir(str(patient_file_root) + '/' + patient_file): 34 | if 'CT' in s: 35 | ctnumber += 1 36 | ctslices.append(dicom.read_file(str(patient_file_root)+'/'+patient_file + '/' + s)) 37 | if 'RS' in s: 38 | rsslices = dicom.read_file(str(patient_file_root)+'/'+patient_file + '/' + s) # 读入RS文件 39 | 40 | ctslices.sort(key=lambda x : int(x.ImagePositionPatient[2])) # 按z坐标从小到大排序 41 | origin = [s.ImagePositionPatient for s in ctslices] # 网格原点在世界坐标系的位置 42 | spacing = ctslices[0].PixelSpacing # 采样间隔 43 | labeldata = np.zeros([ctnumber, 512, 512]) 44 | imgdata = np.zeros([ctnumber, 512, 512]) 45 | aug_label = np.zeros(ctnumber) 46 | body_box = np.zeros([ctnumber, 4]) 47 | intercept = ctslices[0].RescaleIntercept # 重采样截距 48 | slope = ctslices[0].RescaleSlope # 重采样斜率 49 | 50 | ''' 51 | 提取患者身体区域 52 | ''' 53 | znumber = 0 54 | for ct in ctslices: 55 | 56 | ctimg = np.array(ct.pixel_array) 57 | ctimg[ctimg == -2000] = 0 58 | if slope != 1: 59 | ctimg = slope * ctimg.astype(np.float64) 60 | ctimg = ctimg.astype(np.int64) 61 | ctimg = ctimg.astype(np.int64) 62 | ctimg += np.int64(intercept) 63 | 64 | body_img, body_mask = get_segmented_body(ctimg, window_max=window_max, window_min=window_min, 65 | window_length=0, 66 | show_body=show_body, znumber=znumber) 67 | labels = measure.label(body_mask) 68 | regions = measure.regionprops(labels)[0].bbox 69 | body_box[znumber] = regions 70 | imgdata[znumber] = body_img 71 | znumber += 1 72 | 73 | ''' 74 | 提取患者的靶区 75 | ''' 76 | for i in 
range(len(rsslices.RTROIObservationsSequence)): # 第i个勾画区域 77 | label = rsslices.RTROIObservationsSequence[i].ROIObservationLabel # ROIObservationLabel即该ROI是何器官 78 | label = label.lower() 79 | label = label.replace(' ', '') 80 | label = label.replace('-', '') 81 | label = label.replace('_', '') 82 | 83 | if label in label_wanted: 84 | print(label) 85 | for j in range(len(rsslices.ROIContourSequence[i].ContourSequence)): # 第j层靶区曲线 86 | 87 | "提取靶区轮廓线坐标并转换为世界坐标" 88 | numberOfPoints = rsslices.ROIContourSequence[i].ContourSequence[j].NumberofContourPoints # 该层曲线上点数 89 | Data = rsslices.ROIContourSequence[i].ContourSequence[j].ContourData 90 | conData = np.zeros([numberOfPoints, 3]) # 存储靶区曲线各点的世界坐标 91 | pointdata = np.zeros([numberOfPoints, 2]) # 存储靶区曲线各点的体素坐标 92 | Z = Data[2] 93 | znumber = round((Z-origin[0][2])/3) 94 | if znumber > zmax: 95 | zmax = znumber 96 | elif znumber < zmin: 97 | zmin = znumber 98 | for jj in range(numberOfPoints): 99 | ii = jj * 3 100 | conData[jj, 0] = Data[ii ] # 轮廓世界坐标系 101 | conData[jj, 1] = Data[ii + 1] 102 | conData[jj, 2] = Data[ii + 2] 103 | pointdata[jj, 0] = round((conData[jj, 0] - origin[0][0])/spacing[0]) # 轮廓X坐标 104 | pointdata[jj, 1] = round((conData[jj, 1] - origin[0][1])/spacing[1]) # 轮廓Y坐标 105 | 106 | "生成靶区mask" 107 | pointdata = np.array(pointdata) 108 | polyline = Path(pointdata, closed=True) # 制成闭合的曲线 109 | maskpoint_reshape = maskpoint.reshape(512*512, 2) 110 | pointin = polyline.contains_points(maskpoint_reshape) 111 | maskpoint_reshape = maskpoint_reshape[pointin, :] 112 | for k in maskpoint_reshape: 113 | labeldata[znumber, int(k[1]), int(k[0])] = label_wanted[label] 114 | # if label in label_aug: # 判断是否需要增强 115 | # aug_label[znumber] = 1 116 | imgdata = imgdata[zmin-5:zmax+5] 117 | labeldata = labeldata[zmin-5:zmax+5] 118 | body_box = body_box[zmin-5:zmax+5] 119 | if show_label: 120 | for i in range(len(labeldata)): 121 | if i % 5 == 0: 122 | f, plots = plt.subplots(1, 2, figsize=(60, 60)) 123 | 
plots[0].imshow(imgdata[i], cmap=plt.cm.bone) 124 | plots[1].imshow(imgdata[i]*labeldata[i]) 125 | plt.show() 126 | 127 | if patient_count <= train_number: 128 | ThreeDclip(imgdata, labeldata, 'body', 'train', body_box, duplicate_path, patient_file, 129 | multinumber=4, show_label=False) # 训练集需要裁剪增强 130 | 131 | elif train_number < patient_count <= test_number: 132 | ThreeDclip(imgdata, labeldata, 'body', 'test', body_box, duplicate_path, patient_file, multinumber=4) 133 | 134 | else: 135 | 136 | ThreeDclip(imgdata, labeldata, 'body', 'valid', body_box, duplicate_path, patient_file, multinumber=4) 137 | 138 | print('已成功存储患者%s的数据' % patient_file) 139 | 140 | 141 | 142 | 143 | return 144 | 145 | 146 | -------------------------------------------------------------------------------- /util/parse_config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import, print_function 3 | 4 | import configparser 5 | def is_int(val_str): 6 | start_digit = 0 7 | if(val_str[0] =='-'): 8 | start_digit = 1 9 | flag = True 10 | for i in range(start_digit, len(val_str)): 11 | if(str(val_str[i]) < '0' or str(val_str[i]) > '9'): 12 | flag = False 13 | break 14 | return flag 15 | 16 | def is_float(val_str): 17 | flag = False 18 | if('.' 
in val_str and len(val_str.split('.'))==2): 19 | if(is_int(val_str.split('.')[0]) and is_int(val_str.split('.')[1])): 20 | flag = True 21 | else: 22 | flag = False 23 | elif('e' in val_str and len(val_str.split('e'))==2): 24 | if(is_int(val_str.split('e')[0]) and is_int(val_str.split('e')[1])): 25 | flag = True 26 | else: 27 | flag = False 28 | else: 29 | flag = False 30 | return flag 31 | 32 | def is_bool(var_str): 33 | if( var_str=='True' or var_str == 'true' or var_str =='False' or var_str=='false'): 34 | return True 35 | else: 36 | return False 37 | 38 | def parse_bool(var_str): 39 | if(var_str=='True' or var_str == 'true' ): 40 | return True 41 | else: 42 | return False 43 | 44 | def is_list(val_str): 45 | if(val_str[0] == '[' and val_str[-1] == ']'): 46 | return True 47 | else: 48 | return False 49 | 50 | def parse_list(val_str): 51 | sub_str = val_str[1:-1] 52 | splits = sub_str.split(',') 53 | output = [] 54 | for item in splits: 55 | item = item.strip() 56 | if(is_int(item)): 57 | output.append(int(item)) 58 | elif(is_float(item)): 59 | output.append(float(item)) 60 | elif(is_bool(item)): 61 | output.append(parse_bool(item)) 62 | else: 63 | output.append(item) 64 | return output 65 | 66 | def parse_value_from_string(val_str): 67 | # val_str = val_str.encode('ascii','ignore') 68 | if(is_int(val_str)): 69 | val = int(val_str) 70 | elif(is_float(val_str)): 71 | val = float(val_str) 72 | elif(is_list(val_str)): 73 | val = parse_list(val_str) 74 | elif(is_bool(val_str)): 75 | val = parse_bool(val_str) 76 | else: 77 | val = val_str 78 | return val 79 | 80 | def parse_config(filename): 81 | config = configparser.ConfigParser() 82 | config.read(filename) 83 | output = {} 84 | for section in config.sections(): # Return a list of section names 85 | output[section] = {} 86 | for key in config[section]: 87 | val_str = str(config[section][key]) 88 | if(len(val_str)>0): 89 | val = parse_value_from_string(val_str) 90 | else: 91 | val = None 92 | print(section, key, 
val_str, val) 93 | output[section][key] = val 94 | return output 95 | 96 | if __name__ == "__main__": 97 | print(is_int('555')) 98 | print(is_float('555.10')) 99 | a='[1 ,2 ,3 ]' 100 | print(a) 101 | print(parse_list(a)) 102 | 103 | -------------------------------------------------------------------------------- /util/pre_function.py: -------------------------------------------------------------------------------- 1 | 2 | # -*- coding:utf-8 -*- 3 | 4 | 5 | ''' 6 | 该文件主要包含图像处理的函数 7 | ''' 8 | import os 9 | from skimage import morphology 10 | from skimage.measure import label, regionprops 11 | from skimage.filters import roberts 12 | from skimage import measure 13 | from scipy import ndimage as ndi 14 | import matplotlib.pyplot as plt 15 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 16 | import dicom 17 | import scipy.misc 18 | import numpy as np 19 | 20 | 21 | # Load the scans in given folder path 22 | def load_scan(path): 23 | ''' 24 | 该函数用于载入path下各患者的CT,并从中提取图像储存在各自的文件夹 25 | ''' 26 | Patient_count = 0 27 | for Patient_file in os.listdir(path): 28 | try: 29 | Patient_count += 1 30 | CTslices = [] 31 | CTnumber = 0 # 记录CT个数 32 | print('正在处理患者%s,这是第%d个病人' % (Patient_file, Patient_count)) 33 | for s in os.listdir(str(path) + '/' + Patient_file): # 载入文件 34 | if 'CT' in s: 35 | CTnumber += 1 36 | CTslices.append(dicom.read_file(str(path) + '/' + Patient_file + '/' + s)) # 判断是否为CT 37 | CTslices.sort(key=lambda x: int(x.ImagePositionPatient[2])) 38 | 39 | "提取CT间隔" 40 | print(CTslices[1].ImagePositionPatient[2], CTslices[2].ImagePositionPatient[2]) 41 | try: 42 | slice_thickness = np.abs(CTslices[0].ImagePositionPatient[2] - CTslices[1].ImagePositionPatient[2]) 43 | except: 44 | slice_thickness = np.abs(CTslices[0].SliceLocation - CTslices[1].SliceLocation) 45 | for s in CTslices: 46 | s.SliceThickness = slice_thickness 47 | 48 | CTimage = get_pixels_hu(CTslices) 49 | np.save(str(path) + '/' + Patient_file + '/Patient_pixel.npy', CTimage) 50 | except: 51 | 
print('提取图像中,患者%sCT图像有问题' % Patient_file) 52 | print('图像提取完成!') 53 | return 54 | 55 | 56 | def get_pixels_hu(slices): 57 | # 灰度值转换为HU单元 58 | 59 | image = np.stack([s.pixel_array for s in slices]) 60 | 61 | # Convert to int16 (from sometimes int16), 62 | # should be possible as values should always be low enough (<32k) 63 | image = image.astype(np.float64) 64 | 65 | # Set outside-of-scan pixels to 0 66 | # The intercept is usually -1024, so air is approximately 0 67 | image[image == -2000] = 0 68 | 69 | # Convert to Hounsfield units (HU) 70 | for slice_number in range(len(slices)): 71 | # 回到HU单元,乘以rescale比率并加上intercept(存储在扫描面的元数据中) 72 | intercept = slices[slice_number].RescaleIntercept # 截距 73 | slope = slices[slice_number].RescaleSlope # 斜率 74 | if slope != 1: 75 | image[slice_number] = slope * image[slice_number].astype(np.float64) 76 | image[slice_number] = image[slice_number].astype(np.float64) 77 | image[slice_number] += np.float64(intercept) 78 | return np.array(image, dtype=np.float64) 79 | 80 | 81 | def resample(image, scan, new_spacing=[1, 1, 1]): 82 | # 重采样 83 | # 不同扫描面的像素尺寸、粗细粒度是不同的。这不利于我们进行CNN任务,我们可以使用同构采样。 84 | # Determine current pixel spacing 85 | print('扫描面厚度,像素间距', [scan[0].SliceThickness] + scan[0].PixelSpacing) 86 | spacing = map(float, ([scan[0].SliceThickness] + scan[0].PixelSpacing)) 87 | spacing = np.array(list(spacing)) 88 | resize_factor = spacing / new_spacing # ??? 
89 | new_real_shape = image.shape * resize_factor 90 | new_shape = np.round(new_real_shape) 91 | print('new real shape is', new_real_shape, 'resize factor is', resize_factor) 92 | real_resize_factor = new_shape / image.shape 93 | new_spacing = spacing / real_resize_factor 94 | 95 | # 插值法上采样 96 | image = scipy.ndimage.interpolation.zoom(image, real_resize_factor, mode='nearest') 97 | 98 | return image, new_spacing 99 | 100 | 101 | def plot_3d(image, threshold=-300): 102 | # 使用matplotlib输出肺部扫描的3D图像方法。可能需要一两分钟 103 | # Position the scan upright, 104 | # so the head of the patient would be at the top facing the camera 105 | p = image.transpose(2, 1, 0) 106 | verts, faces, x, y = measure.marching_cubes(p, threshold) 107 | fig = plt.figure(figsize=(10, 10)) 108 | ax = fig.add_subplot(111, projection='3d') 109 | # Fancy indexing: `verts[faces]` to generate a collection of triangles 110 | mesh = Poly3DCollection(verts[faces], alpha=0.1) 111 | face_color = [0.5, 0.5, 1] 112 | mesh.set_facecolor(face_color) 113 | ax.add_collection3d(mesh) 114 | ax.set_xlim(0, p.shape[0]) 115 | ax.set_ylim(0, p.shape[1]) 116 | ax.set_zlim(0, p.shape[2]) 117 | plt.show() 118 | 119 | 120 | def plot_ct_scan(scan): 121 | # 输出一个病人scans中所有切面slices 122 | ''' 123 | plot a few more images of the slices 124 | :param scan: 125 | :return: 126 | ''' 127 | f, plots = plt.subplots(int(scan.shape[0] / 20), 4, figsize=(50, 50)) 128 | for i in range(0, scan.shape[0], 5): 129 | plots[int(i / 20), int((i % 20) / 5)].axis('on') 130 | plots[int(i / 20), int((i % 20) / 5)].imshow(scan[i], cmap=plt.cm.bone) 131 | plt.show() 132 | 133 | 134 | def get_segmented_body(img, window_max=250, window_min=-150, window_length=0, show_body=False, znumber=0): 135 | ''' 136 | 将身体与外部分离出来 137 | ''' 138 | 139 | mask = [] 140 | 141 | if znumber < 40: 142 | radius = [13, 6] 143 | else: 144 | radius = [6, 8] 145 | 146 | plot = False 147 | show_now = False 148 | if show_body: 149 | if znumber % 10 == 0: 150 | plot = True 151 | show_now = 
True 152 | 153 | if plot == True: 154 | f, plots = plt.subplots(2, 4, figsize=(60, 60)) 155 | 156 | ''' 157 | Step 1: Convert into a binary image.二值化,为确保所定阈值通过大多数 158 | ''' 159 | threshold = -600 160 | binary = np.where(img > threshold, 1.0, 0.0) # threshold the image 161 | 162 | if plot == True: 163 | plots[0, 0].axis('off') 164 | plots[0, 0].set_title('convert into a binary image,the the threshold%s' % threshold) 165 | plots[0, 0].imshow(binary, cmap=plt.cm.bone) 166 | ''' 167 | Step 2: Remove the blobs connected to the border of the image. 168 | 清除边界 169 | ''' 170 | # cleared = clear_border(binary,buffer_size=50) 171 | # if plot == True: 172 | # plots[0,1].axis('off') 173 | # plots[0,1].set_title('after clear border') 174 | # plots[0,1].imshow(cleared[0], cmap=plt.cm.bone) 175 | # print(cleared[0]) 176 | 177 | ''' 178 | Step 3: Erosion operation with a disk of radius 2. This operation is 179 | seperate the lung nodules attached to the blood vessels. 180 | 腐蚀操作,以2mm为半径去除 181 | ''' 182 | binary = morphology.erosion(binary, np.ones([radius[0], radius[0]])) 183 | if plot == True: 184 | plots[0, 1].axis('off') 185 | plots[0, 1].set_title('erosion operation') 186 | plots[0, 1].imshow(binary, cmap=plt.cm.bone) 187 | 188 | ''' 189 | Step 4: Closure operation with a disk of radius 10. 
This operation is 190 | to keep nodules attached to the lung wall.闭合运算 191 | ''' 192 | binary = morphology.dilation(binary, np.ones([radius[1], radius[1]])) 193 | if plot == True: 194 | plots[0, 2].axis('off') 195 | plots[0, 2].set_title('closure operation') 196 | plots[0, 2].imshow(binary, cmap=plt.cm.bone) 197 | 198 | ''' 199 | Step 5: Label the image.连通区域标记 200 | ''' 201 | label_image = label(binary) 202 | if plot == True: 203 | plots[0, 3].axis('off') 204 | plots[0, 3].set_title('found all connective graph') 205 | plots[0, 3].imshow(label_image) 206 | 207 | ''' 208 | Step 6: Keep the labels with the largest area.保留最大区域 209 | ''' 210 | areas = [r.area for r in regionprops(label_image)] 211 | areas.sort() 212 | if len(areas) > 1: 213 | for region in regionprops(label_image): 214 | if region.area < areas[-1]: 215 | for coordinates in region.coords: 216 | label_image[coordinates[0], coordinates[1]] = 0 217 | binary = label_image > 0 218 | if plot == True: 219 | plots[1, 0].axis('off') 220 | plots[1, 0].set_title('keep the largest area') 221 | plots[1, 0].imshow(binary, cmap=plt.cm.bone) 222 | 223 | ''' 224 | Step 7: Fill in the small holes inside the binary mask .孔洞填充 225 | ''' 226 | edges = roberts(binary) 227 | binary = ndi.binary_fill_holes(edges) 228 | if plot == True: 229 | plots[1, 1].axis('off') 230 | plots[1, 1].set_title('fill in the small holes') 231 | plots[1, 1].imshow(binary, cmap=plt.cm.bone) 232 | 233 | ''' 234 | Step 8: show the input image. 235 | ''' 236 | if plot == True: 237 | plots[1, 3].axis('off') 238 | plots[1, 3].set_title('input image') 239 | plots[1, 3].imshow(img, cmap='gray') 240 | 241 | ''' 242 | Step 9: Superimpose the binary mask on the input image. 
243 | ''' 244 | get_high_vals = binary == 0 245 | img[get_high_vals] = 0 246 | if plot == True: 247 | plots[1, 2].axis('off') 248 | plots[1, 2].set_title('superimpose the binary mask') 249 | plots[1, 2].imshow(img, cmap='gray') 250 | if show_now == True: 251 | plt.show() 252 | mask.append(binary) 253 | 254 | img[img > (window_max + window_length)] = window_max + window_length 255 | img[img < (window_min - window_length)] = window_min - window_length 256 | img = (img - window_min) / (window_max - window_min) 257 | img[get_high_vals] = 0 258 | if plot == True: 259 | fig = plt.figure() 260 | ax = fig.add_subplot(111) 261 | ax.imshow(img, cmap='gray') 262 | plt.show() 263 | 264 | return img, binary 265 | 266 | 267 | def largest_label_volume(im, bg=-1): 268 | vals, counts = np.unique(im, return_counts=True) 269 | counts = counts[vals != bg] 270 | vals = vals[vals != bg] 271 | if len(counts) > 0: 272 | return vals[np.argmax(counts)] 273 | else: 274 | return None 275 | 276 | 277 | def segment_lung_mask(image, fill_lung_structures=True): 278 | '''肺部图像分割 279 | 为了减少有问题的空间,我们可以分割肺部图像(有时候是附近的组织) 280 | 这包含一些步骤,包括区域增长和形态运算,此时,我们只分析相连组件 281 | ''' 282 | 283 | # 1是空气,2是肺部 284 | # not actually binary, but 1 and 2. 285 | # 0 is treated as background, which we do not want 286 | binary_image = np.array(image > -320, dtype=np.int8) + 1 287 | labels = measure.label(binary_image) # 连通区域标记 288 | 289 | # Pick the pixel in the very corner to determine which label is air. 
290 | # Improvement: Pick multiple background labels from around the patient 291 | # More resistant to "trays" on which the patient lays cutting the air 292 | # around the person in half 293 | background_label = labels[0, 0, 0] 294 | # Fill the air around the person 295 | binary_image[background_label == labels] = 2 296 | # Method of filling the lung structures (that is superior to something like 297 | # morphological closing) 298 | if fill_lung_structures: 299 | # For every slice we determine the largest solid structure 300 | for i, axial_slice in enumerate(binary_image): 301 | axial_slice = axial_slice - 1 302 | labeling = measure.label(axial_slice) 303 | l_max = largest_label_volume(labeling, bg=0) 304 | if l_max is not None: # This slice contains some lung 305 | binary_image[i][labeling != l_max] = 1 306 | binary_image -= 1 # Make the image actual binary 307 | binary_image = 1 - binary_image # Invert it, lungs are now 1 308 | # Remove other air pockets insided body 309 | labels = measure.label(binary_image, background=0) 310 | l_max = largest_label_volume(labels, bg=0) 311 | if l_max is not None: # There are air pockets 312 | binary_image[labels != l_max] = 0 313 | return binary_image 314 | -------------------------------------------------------------------------------- /util/pre_process.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import, print_function 3 | import os 4 | import nibabel 5 | import numpy as np 6 | import random 7 | #import vtk 8 | #from vtk.util.numpy_support import vtk_to_numpy, numpy_to_vtk 9 | 10 | def get_roi_size(inputVolume): 11 | [d_idxes, h_idxes, w_idxes] = np.nonzero(inputVolume) 12 | mind = d_idxes.min(); maxd = d_idxes.max() 13 | minh = h_idxes.min(); maxh = h_idxes.max() 14 | minw = w_idxes.min(); maxw = w_idxes.max() 15 | return [maxd - mind, maxh - minh, maxw - minw] 16 | 17 | def get_unique_image_name(img_name_list, subname): 
18 | img_name = [x for x in img_name_list if subname in x] 19 | assert(len(img_name) == 1) 20 | return img_name[0] 21 | 22 | def load_nifty_volume_as_array(filename): 23 | # input shape [W, H, D] 24 | # output shape [D, H, W] 25 | img = nibabel.load(filename) 26 | data = img.get_data() 27 | data = np.transpose(data, [2,1,0]) 28 | return data 29 | 30 | #def load_vtk_volume_as_array(imgName): 31 | # if(imgName.endswith('nii')): 32 | # reader=vtk.vtkNIFTIImageReader() 33 | # elif(imgName.endswith('mha')): 34 | # reader = vtk.vtkMetaImageReader() 35 | # else: 36 | # raise ValueError('could not open file {0:}'.format(imgName)) 37 | # reader.SetFileName(imgName) 38 | # reader.Update() 39 | # vtkImg = reader.GetOutput() 40 | # shape = vtkImg.GetDimensions() 41 | # sc = vtkImg.GetPointData().GetScalars() 42 | # img_np = np.array(vtk_to_numpy(sc).reshape([shape[2],shape[1],shape[0]])) 43 | # return img_np 44 | 45 | def save_array_as_nifty_volume(data, filename): 46 | # numpy data shape [D, H, W] 47 | # nifty image shape [W, H, W] 48 | data = np.transpose(data, [2, 1, 0]) 49 | img = nibabel.Nifti1Image(data, np.eye(4)) 50 | nibabel.save(img, filename) 51 | 52 | # for brats 17 53 | def load_all_modalities_in_one_folder(patient_dir, ground_truth = True): 54 | img_name_list = os.listdir(patient_dir) 55 | img_list = [] 56 | sub_name_list = ['flair.nii', 't1ce.nii', 't1.nii', 't2.nii'] 57 | if(ground_truth): 58 | sub_name_list.append('seg.nii') 59 | for sub_name in sub_name_list: 60 | img_name = get_unique_image_name(img_name_list, sub_name) 61 | img = load_nifty_volume_as_array(os.path.join(patient_dir, img_name)) 62 | img_list.append(img) 63 | return img_list 64 | 65 | # for brats15 66 | def load_all_modalities_in_one_folder_15(patient_dir, ground_truth = True): 67 | img_name_list = os.listdir(patient_dir) 68 | print('image names', img_name_list) 69 | img_list = [] 70 | sub_name_list = ['Flair.', 'T1.', 'T1c.', 'T2.'] 71 | if(ground_truth): 72 | sub_name_list.append('OT.') 73 | 
for sub_name in sub_name_list: 74 | for img_name in img_name_list: 75 | if(sub_name in img_name): 76 | full_img_name = patient_dir + '/' + img_name + '/' + img_name + '.mha' 77 | print(full_img_name) 78 | img = load_vtk_volume_as_array(full_img_name) 79 | img_list.append(img) 80 | return img_list 81 | 82 | def get_itensity_statistics(volume, n_pxl, iten_sum, iten_sq_sum): 83 | volume = np.asanyarray(volume, np.float32) 84 | pixels = volume[volume > 0] 85 | n_pxl = n_pxl + len(pixels) 86 | iten_sum = iten_sum + pixels.sum() 87 | iten_sq_sum = iten_sq_sum + np.square(pixels).sum() 88 | return n_pxl, iten_sum, iten_sq_sum 89 | 90 | def get_all_patients_dir(data_root): 91 | sub_sets = ['HGG/', 'LGG/'] 92 | all_patients_list = [] 93 | for sub_source in sub_sets: 94 | sub_source = data_root + sub_source 95 | patient_list = os.listdir(sub_source) 96 | patient_list = [sub_source + x for x in patient_list if 'Brats' in x] 97 | all_patients_list.extend(patient_list) 98 | print('patients for ', sub_source,len(patient_list)) 99 | print("total patients ", len(all_patients_list)) 100 | return all_patients_list 101 | 102 | def get_roi_range_in_one_dimention(x0, x1, L): 103 | margin = L - (x1 - x0) 104 | mg0 = margin/2 105 | mg1 = margin - mg0 106 | x0 = x0 - mg0 107 | x1 = x1 + mg1 108 | return [x0, x1] 109 | 110 | def get_roi_from_volumes(volumes): 111 | [outD, outH, outW] = [144, 176, 144] 112 | [d_idxes, h_idxes, w_idxes] = np.nonzero(volumes[0]) 113 | mind = d_idxes.min(); maxd = d_idxes.max() 114 | minh = h_idxes.min(); maxh = h_idxes.max() 115 | minw = w_idxes.min(); maxw = w_idxes.max() 116 | print(mind, maxd, minh, maxh, minw, maxw) 117 | [mind, maxd] = get_roi_range_in_one_dimention(mind, maxd, outD) 118 | [minh, maxh] = get_roi_range_in_one_dimention(minh, maxh, outH) 119 | [minw, maxw] = get_roi_range_in_one_dimention(minw, maxw, outW) 120 | print(mind, maxd, minh, maxh, minw, maxw) 121 | roi_volumes = [] 122 | for volume in volumes: 123 | roi_volume = 
volume[np.ix_(range(mind, maxd), range(minh, maxh), range(minw, maxw))] 124 | roi_volumes.append(roi_volume) 125 | print(roi_volume.shape) 126 | return roi_volumes, [mind, maxd, minh, maxh, minw, maxw] 127 | 128 | def get_training_set_statistics(): 129 | source_root = '/Users/guotaiwang/Documents/data/BRATS2017/BRATS17TrainingData/' 130 | all_patients_list = get_all_patients_dir(source_root) 131 | 132 | # get itensity mean and std 133 | # n_pxls = np.zeros([4], np.float32) 134 | # iten_sum = np.zeros([4], np.float32) 135 | # iten_sq_sum = np.zeros([4], np.float32) 136 | # for patient_dir in all_patients_list: 137 | # volumes = load_all_modalities_in_one_folder(patient_dir) 138 | # for i in range(4): 139 | # n_pxls[i], iten_sum[i], iten_sq_sum[i] = get_itensity_statistics( 140 | # volumes[i], n_pxls[i], iten_sum[i], iten_sq_sum[i]) 141 | # print patient_dir 142 | # print volumes[0][volumes[0]>0].mean(), volumes[1][volumes[1]>0].mean(), volumes[2][volumes[2]>0].mean(), volumes[3][volumes[3]>0].mean() 143 | # mean = np.divide(iten_sum, n_pxls) 144 | # sq_men = np.divide(iten_sq_sum, n_pxls) 145 | # std = np.sqrt(sq_men - np.square(mean)) 146 | # print mean, std 147 | 148 | roi_size = [] 149 | for patient_dir in all_patients_list: 150 | volumes = load_all_modalities_in_one_folder(patient_dir) 151 | for i in range(4): 152 | roi = get_roi_size(volumes[i]) 153 | roi_size.append(roi) 154 | roi_size = np.asarray(roi_size) 155 | print(roi_size.mean(axis = 0), roi_size.std(axis = 0)) 156 | 157 | def extract_roi_for_training_set(): 158 | source_root = '/Users/guotaiwang/Documents/data/BRATS2017/BRATS17TrainingData/' 159 | target_root = 'Training_extract' 160 | sub_sets = ['HGG/', 'LGG/'] 161 | modality_names = ['flair.nii.gz', 't1ce.nii.gz', 't1.nii.gz', 't2.nii.gz', 'seg.nii.gz'] 162 | all_patients_list = get_all_patients_dir(source_root) 163 | for patient_dir in all_patients_list: 164 | volumes = load_all_modalities_in_one_folder(patient_dir) 165 | roi_volumes, roi = 
get_roi_from_volumes(volumes) 166 | for i in range(len(roi_volumes)): 167 | save_patient_dir = patient_dir.replace("BRATS17TrainingData", target_root) 168 | print(save_patient_dir) 169 | if(not os.path.isdir(save_patient_dir)): 170 | os.mkdir(save_patient_dir) 171 | save_name = os.path.join(save_patient_dir, modality_names[i]) 172 | img = nibabel.Nifti1Image(roi_volumes[i], np.eye(4)) 173 | nibabel.save(img, save_name) 174 | 175 | def split_data(split_name, seed): 176 | source_root = '/Users/guotaiwang/Documents/data/BRATS2017/Training_extract/' 177 | all_patients_list = get_all_patients_dir(source_root) 178 | random.seed(seed) 179 | n = len(all_patients_list) 180 | n_test = 50 181 | test_mask = np.zeros([n]) 182 | test_idx = random.sample(range(n), n_test) 183 | test_mask[test_idx] = 1 184 | 185 | train_list = [] 186 | test_list = [] 187 | for i in range(n): 188 | patient_split = all_patients_list[i].split('/') 189 | patient = patient_split[-2] + '/' + patient_split[-1] 190 | if(test_mask[i]): 191 | test_list.append(patient) 192 | else: 193 | train_list.append(patient) 194 | print("train_list", len(train_list)) 195 | print("test_list ", len(test_list)) 196 | train_file = open(split_name + '/train.txt', 'w') 197 | for patient in train_list: 198 | train_file.write("%s\n" % patient) 199 | test_file = open(split_name + '/test.txt', 'w') 200 | for patient in test_list: 201 | test_file.write("%s\n" % patient) 202 | seed_file = open(split_name + '/seed.txt', 'w') 203 | seed_file.write("%d"%seed) 204 | 205 | def Brats17_data_set_crop_rename(source_folder, save_folder, crop, ground_truth): 206 | patient_list = os.listdir(source_folder) 207 | patient_list = [x for x in patient_list if 'Brats17' in x] 208 | margin = 5 209 | save_postfix = ['Flair', 'T1c', 'T1', 'T2'] 210 | if(ground_truth): 211 | save_postfix.append('Label') 212 | print('patient number ', len(patient_list)) 213 | for patient_dir in patient_list: 214 | print(patient_dir) 215 | continue 216 | full_patient_dir 
= os.path.join(source_folder, patient_dir) 217 | imgs = load_all_modalities_in_one_folder(full_patient_dir, ground_truth = ground_truth) 218 | assert(len(imgs) == len(save_postfix)) 219 | if(crop): 220 | [d_idxes, h_idxes, w_idxes] = np.nonzero(imgs[0]) 221 | mind = d_idxes.min() - margin; maxd = d_idxes.max() + margin 222 | minh = h_idxes.min() - margin; maxh = h_idxes.max() + margin 223 | minw = w_idxes.min() - margin; maxw = w_idxes.max() + margin 224 | for mod_idx in range(len(save_postfix)): 225 | if(crop): 226 | roi_volume = imgs[mod_idx][np.ix_(range(mind, maxd), range(minh, maxh), range(minw, maxw))] 227 | else: 228 | roi_volume = imgs[mod_idx] 229 | save_name = "{0:}_{1:}.nii.gz".format(patient_dir, save_postfix[mod_idx]) 230 | save_name = os.path.join(save_folder, save_name) 231 | save_array_as_nifty_volume(roi_volume, save_name) 232 | 233 | def Brats15_data_set_crop_rename(source_folder, save_folder, crop): 234 | patient_list = os.listdir(source_folder) 235 | patient_list = [x for x in patient_list if 'brats' in x] 236 | margin = 5 237 | save_postfix = ['Flair', 'T1', 'T1c', 'T2', 'Label'] 238 | for patient_dir in patient_list: 239 | print(patient_dir) 240 | full_patient_dir = os.path.join(source_folder, patient_dir) 241 | imgs = load_all_modalities_in_one_folder_15(full_patient_dir, ground_truth = True) 242 | assert(len(imgs) == len(save_postfix)) 243 | if(crop): 244 | [d_idxes, h_idxes, w_idxes] = np.nonzero(imgs[0]) 245 | mind = d_idxes.min() - margin; maxd = d_idxes.max() + margin 246 | minh = h_idxes.min() - margin; maxh = h_idxes.max() + margin 247 | minw = w_idxes.min() - margin; maxw = w_idxes.max() + margin 248 | for mod_idx in range(len(imgs)): 249 | if(crop): 250 | roi_volume = imgs[mod_idx][np.ix_(range(mind, maxd), range(minh, maxh), range(minw, maxw))] 251 | else: 252 | roi_volume = imgs[mod_idx] 253 | save_name = "{0:}_{1:}.nii.gz".format(patient_dir, save_postfix[mod_idx]) 254 | save_name = os.path.join(save_folder, save_name) 255 | 
save_array_as_nifty_volume(roi_volume, save_name) 256 | 257 | 258 | if __name__ == "__main__": 259 | # get_training_set_statistics() 260 | # brats 15 crop and rename 261 | validation_data_source = '/Users/guotaiwang/Documents/data/BRATS/Brats2015_Training/HGG' 262 | validation_data_save = '/Users/guotaiwang/Documents/data/BRATS/Brats2015_Train_croprename/HGG' 263 | Brats15_data_set_crop_rename(validation_data_source,validation_data_save, True) 264 | # brats 17 validation crop and rename, for validation data, no crop 265 | # validation_data_source = '/Users/guotaiwang/Documents/data/BRATS2017/Brats17TestingData' 266 | # validation_data_save = '/Users/guotaiwang/Documents/data/BRATS2017/Brats17TestingData_renamed' 267 | # Brats17_data_set_crop_rename(validation_data_source,validation_data_save, False, False) 268 | # 269 | # load_name = '/Users/guotaiwang/Documents/data/BRATS2017/Brats17TrainingData_crop_renamed/HGG/HGG1_FLAIR.nii.gz' 270 | # volume = load_nifty_volume_as_array(load_name) 271 | # print volume.shape 272 | # sub_volume = volume[0:100][:][:] 273 | # print sub_volume.shape 274 | # save_folder = '/Users/guotaiwang/Documents/workspace/tf_project/tf_brats/data_process/temp_data' 275 | # save_name = save_folder + '/Flair1sub.nii.gz' 276 | # save_array_as_nifty_volume(sub_volume, save_name) 277 | 278 | -------------------------------------------------------------------------------- /util/train_test_func.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import, print_function 3 | import tensorflow as tf 4 | from data_process.data_process_func import * 5 | 6 | def volume_probability_prediction(temp_imgs, data_shape, label_shape, data_channel, 7 | class_num, batch_size, sess, proby, x): 8 | ''' 9 | Test one image with sub regions along z-axis 10 | ''' 11 | [D, H, W] = temp_imgs[0].shape 12 | input_center = [int(D/2), int(H/2), int(W/2)] 13 | temp_prob = np.zeros([D, 
H, W, class_num]) 14 | sub_image_baches = [] 15 | for center_slice in range(int(label_shape[0]/2), D + int(label_shape[0]/2), label_shape[0]): 16 | center_slice = min(center_slice, D - int(label_shape[0]/2)) 17 | sub_image_bach = [] 18 | for chn in range(data_channel): 19 | temp_input_center = [center_slice, input_center[1], input_center[2]] 20 | sub_image = extract_roi_from_volume( 21 | temp_imgs[chn], temp_input_center, data_shape) 22 | sub_image_bach.append(sub_image) 23 | sub_image_bach = np.asanyarray(sub_image_bach, np.float32) 24 | sub_image_baches.append(sub_image_bach) 25 | total_batch = len(sub_image_baches) 26 | max_mini_batch = int((total_batch+batch_size-1)/batch_size) 27 | sub_label_idx = 0 28 | for mini_batch_idx in range(max_mini_batch): 29 | data_mini_batch = sub_image_baches[mini_batch_idx*batch_size: 30 | min((mini_batch_idx+1)*batch_size, total_batch)] 31 | if(mini_batch_idx == max_mini_batch - 1): 32 | for idx in range(batch_size - (total_batch - mini_batch_idx*batch_size)): 33 | data_mini_batch.append(np.random.normal(0, 1, size = [data_channel] + data_shape)) 34 | data_mini_batch = np.asarray(data_mini_batch, np.float32) 35 | data_mini_batch = np.transpose(data_mini_batch, [0, 2, 3, 4, 1]) 36 | prob_mini_batch = sess.run(proby, feed_dict = {x:data_mini_batch}) 37 | 38 | for batch_idx in range(prob_mini_batch.shape[0]): 39 | center_slice = sub_label_idx*label_shape[0] + int(label_shape[0]/2) 40 | center_slice = min(center_slice, D - int(label_shape[0]/2)) 41 | temp_input_center = [center_slice, input_center[1], input_center[2], int(class_num/2)] 42 | sub_prob = np.reshape(prob_mini_batch[batch_idx], label_shape + [class_num]) 43 | temp_prob = set_roi_to_volume(temp_prob, temp_input_center, sub_prob) 44 | sub_label_idx = sub_label_idx + 1 45 | return temp_prob 46 | 47 | 48 | def volume_probability_prediction_3d_roi(temp_imgs, data_shape, label_shape, data_channel, 49 | class_num, batch_size, sess, proby, x): 50 | ''' 51 | Test one image with 
sub regions along x, y, z axis 52 | ''' 53 | [D, H, W] = temp_imgs[0].shape 54 | temp_prob = np.zeros([D, H, W, class_num]) 55 | sub_image_batches = [] 56 | sub_image_centers = [] 57 | roid_half = int(label_shape[0]/2) 58 | roih_half = int(label_shape[1]/2) 59 | roiw_half = int(label_shape[2]/2) 60 | for centerd in range(roid_half, D + roid_half, label_shape[0]): 61 | centerd = min(centerd, D - roid_half) 62 | for centerh in range(roih_half, H + roih_half, label_shape[1]): 63 | centerh = min(centerh, H - roih_half) 64 | for centerw in range(roiw_half, W + roiw_half, label_shape[2]): 65 | centerw = min(centerw, W - roiw_half) 66 | temp_input_center = [centerd, centerh, centerw] 67 | sub_image_centers.append(temp_input_center) 68 | sub_image_batch = [] 69 | for chn in range(data_channel): 70 | sub_image = extract_roi_from_volume(temp_imgs[chn], temp_input_center, data_shape) 71 | sub_image_batch.append(sub_image) 72 | sub_image_bach = np.asanyarray(sub_image_batch, np.float32) 73 | sub_image_batches.append(sub_image_bach) 74 | 75 | total_batch = len(sub_image_batches) 76 | max_mini_batch = int((total_batch + batch_size - 1)/batch_size) 77 | sub_label_idx = 0 78 | for mini_batch_idx in range(max_mini_batch): 79 | data_mini_batch = sub_image_batches[mini_batch_idx*batch_size: 80 | min((mini_batch_idx+1)*batch_size, total_batch)] 81 | if(mini_batch_idx == max_mini_batch - 1): 82 | for idx in range(batch_size - (total_batch - mini_batch_idx*batch_size)): 83 | data_mini_batch.append(np.random.normal(0, 1, size = [data_channel] + data_shape)) 84 | data_mini_batch = np.asanyarray(data_mini_batch, np.float32) 85 | data_mini_batch = np.transpose(data_mini_batch, [0, 2, 3, 4, 1]) 86 | outprob_mini_batch = sess.run(proby, feed_dict = {x:data_mini_batch}) 87 | 88 | for batch_idx in range(batch_size): 89 | glb_batch_idx = batch_idx + mini_batch_idx * batch_size 90 | if(glb_batch_idx >= total_batch): 91 | continue 92 | temp_center = sub_image_centers[glb_batch_idx] 93 | temp_prob 
= set_roi_to_volume(temp_prob, temp_center + [1], outprob_mini_batch[batch_idx]) 94 | sub_label_idx = sub_label_idx + 1 95 | return temp_prob 96 | 97 | def volume_probability_prediction_dynamic_shape(temp_imgs, data_shape, label_shape, data_channel, 98 | class_num, batch_size, sess, net): 99 | ''' 100 | Test one image with sub regions along z-axis 101 | The height and width of input tensor is adapted to those of the input image 102 | ''' 103 | # construct graph 104 | [D, H, W] = temp_imgs[0].shape 105 | Hx = max(int((H+3)/4)*4, data_shape[1]) 106 | Wx = max(int((W+3)/4)*4, data_shape[2]) 107 | data_slice = data_shape[0] 108 | label_slice = label_shape[0] 109 | full_data_shape = [batch_size, data_slice, Hx, Wx, data_channel] 110 | x = tf.placeholder(tf.float32, full_data_shape) 111 | predicty = net(x, is_training = True) 112 | proby = tf.nn.softmax(predicty) 113 | 114 | new_data_shape = [data_slice, Hx, Wx] 115 | new_label_shape = [label_slice, Hx, Wx] 116 | temp_prob = volume_probability_prediction(temp_imgs, new_data_shape, new_label_shape, data_channel, 117 | class_num, batch_size, sess, proby, x) 118 | return temp_prob 119 | 120 | def test_one_image_three_nets_adaptive_shape(temp_imgs, data_shapes, label_shapes, data_channel, class_num, 121 | batch_size, sess, nets, outputs, inputs, shape_mode): 122 | ''' 123 | Test one image with three anisotropic networks with fixed or adaptable tensor height and width. 124 | These networks are used in axial, saggital and coronal view respectively. 
125 | shape_mode: 0: use fixed tensor shape in all direction 126 | 1: compare tensor shape and image shape and then select fixed or adaptive tensor shape 127 | 2: use adaptive tensor shape in all direction 128 | ''' 129 | [ax_data_shape, sg_data_shape, cr_data_shape] = data_shapes 130 | [ax_label_shape, sg_label_shape, cr_label_shape] = label_shapes 131 | [D, H, W] = temp_imgs[0].shape 132 | if(shape_mode == 0 or (shape_mode == 1 and (H <= ax_data_shape[1] and W <= ax_data_shape[2]))): 133 | prob = volume_probability_prediction(temp_imgs, ax_data_shape, ax_label_shape, data_channel, 134 | class_num, batch_size, sess, outputs[0], inputs[0]) 135 | else: 136 | prob = volume_probability_prediction_dynamic_shape(temp_imgs, ax_data_shape, ax_label_shape, data_channel, 137 | class_num, batch_size, sess, nets[0]) 138 | 139 | tr_volumes1 = transpose_volumes(temp_imgs, 'sagittal') 140 | [sgD, sgH, sgW] = tr_volumes1[0].shape 141 | if(shape_mode == 0 or (shape_mode == 1 and (sgH <= sg_data_shape[1] and sgW <= sg_data_shape[2]))): 142 | prob1 = volume_probability_prediction(tr_volumes1, sg_data_shape, sg_label_shape, data_channel, 143 | class_num, batch_size, sess, outputs[1], inputs[1]) 144 | else: 145 | prob1 = volume_probability_prediction_dynamic_shape(tr_volumes1, sg_data_shape, sg_label_shape, data_channel, 146 | class_num, batch_size, sess, nets[1]) 147 | prob1 = np.transpose(prob1, [1,2,0,3]) 148 | 149 | tr_volumes2 = transpose_volumes(temp_imgs, 'coronal') 150 | [trD, trH, trW] = tr_volumes2[0].shape 151 | if(shape_mode == 0 or (shape_mode == 1 and (trH <= cr_data_shape[1] and trW <= cr_data_shape[2]))): 152 | prob2 = volume_probability_prediction(tr_volumes2, cr_data_shape, cr_label_shape, data_channel, 153 | class_num, batch_size, sess, outputs[2], inputs[2]) 154 | else: 155 | prob2 = volume_probability_prediction_dynamic_shape(tr_volumes2, cr_data_shape, cr_label_shape, data_channel, 156 | class_num, batch_size, sess, nets[2]) 157 | prob2 = np.transpose(prob2, 
[1,0,2,3]) 158 | 159 | prob = (prob + prob1 + prob2)/3.0 160 | return prob 161 | -------------------------------------------------------------------------------- /util/visualization/3.py: -------------------------------------------------------------------------------- 1 | 3 2 | -------------------------------------------------------------------------------- /util/visualization/evalution.py: -------------------------------------------------------------------------------- 1 | from Training.util.binary import dc,assd 2 | import os 3 | import numpy as np 4 | from Training.data_process.data_process_func import load_nifty_volume_as_array 5 | 6 | def one_hot(img, nb_classes): 7 | hot_img = np.zeros([nb_classes]+list(img.shape)) 8 | for i in range(nb_classes): 9 | hot_img[i][np.where(img == i)] = 1 10 | return hot_img 11 | 12 | def evaluation(folder, evaluate_dice, evaluate_assd): 13 | patient_list = os.listdir(folder) 14 | dice_all_data = [] 15 | assd_all_data = [] 16 | for patient in patient_list: 17 | s_name = os.path.join(folder, patient + '/label.npy') 18 | g_name = os.path.join(folder, patient + '/InterSeg.nii.gz') 19 | s_volume = np.int64(np.load(s_name)) 20 | g_volume = load_nifty_volume_as_array(g_name) 21 | s_volume = one_hot(s_volume, nb_classes=5) 22 | g_volume = one_hot(g_volume, nb_classes=5) 23 | if evaluate_dice: 24 | dice=[] 25 | for i in range(5): 26 | dice.append(dc(g_volume[i], s_volume[i])) 27 | dice_all_data.append(dice) 28 | print(patient, dice) 29 | if evaluate_assd: 30 | Assd = [] 31 | for i in range(5): 32 | Assd.append(assd(g_volume[i], s_volume[i])) 33 | assd_all_data.append(Assd) 34 | print(patient, Assd) 35 | if evaluate_dice: 36 | dice_all_data = np.asarray(dice_all_data) 37 | dice_mean = [dice_all_data.mean(axis = 0)] 38 | dice_std = [dice_all_data.std(axis = 0)] 39 | np.savetxt(folder + '/dice_all.txt', dice_all_data) 40 | np.savetxt(folder + '/dice_mean.txt', dice_mean) 41 | np.savetxt(folder + '/dice_std.txt', dice_std) 42 | print('dice 
mean ', dice_mean) 43 | print('dice std ', dice_std) 44 | if evaluate_assd: 45 | assd_all_data = np.asarray(assd_all_data) 46 | assd_mean = [assd_all_data.mean(axis = 0)] 47 | assd_std = [assd_all_data.std(axis = 0)] 48 | np.savetxt(folder + '/dice_all.txt', assd_all_data) 49 | np.savetxt(folder + '/dice_mean.txt', assd_mean) 50 | np.savetxt(folder + '/dice_std.txt', assd_std) 51 | print('assd mean ', assd_mean) 52 | print('assd std ', assd_std) 53 | 54 | 55 | 56 | evaluate_dice=False 57 | evaluate_assd=True 58 | if __name__ =='__main__': 59 | folder = '/lyc/Head-Neck-CT/3D_data/valid' 60 | evaluation(folder, evaluate_dice, evaluate_assd) 61 | 62 | 63 | -------------------------------------------------------------------------------- /util/visualization/show_Distance.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from skimage import measure,draw 5 | CTpath = '/lyc/abdomen/data_zoom/test/btcv01/' 6 | file_wanted = ['imgzoom.npy', 'labelzoom.npy', 7 | 'Seg.npy', 'FSDis_combine.npy', 'ManualDis.npy'] 8 | 9 | # img = np.load(CTpath+file_wanted[0]) 10 | # label = np.load(CTpath+file_wanted[1]) 11 | # Seg = np.load(CTpath+file_wanted[2]) 12 | FSDis = np.float64(np.load(CTpath+file_wanted[3])) 13 | label = np.load(CTpath+file_wanted[1]) 14 | plt.figure(dpi=800) 15 | for i in np.arange(1, FSDis[0].shape[0]): 16 | f, plots = plt.subplots(2, 4, figsize=[60, 60]) 17 | plots[0, 0].imshow(FSDis[0][i], cmap='gray') 18 | plots[0, 0].set_title('image{0:}'.format(i)) 19 | plots[0, 1].imshow(FSDis[1][i], cmap='gray') 20 | plots[0, 2].imshow(label[i], cmap='gray') 21 | plots[0, 2].set_title('lab') 22 | plots[0, 3].imshow(FSDis[7][i], cmap='gray') 23 | plots[1, 0].imshow(FSDis[8][i], cmap='gray') 24 | plots[1, 1].imshow(FSDis[9][i], cmap='gray') 25 | plots[1, 2].imshow(FSDis[10][i], cmap='gray') 26 | plots[1, 3].imshow(FSDis[11][i], cmap='gray') 27 | plt.show() 
-------------------------------------------------------------------------------- /util/visualization/show_Label_contours.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from skimage import measure,draw 5 | 6 | 7 | def show_label_contours(label, img): 8 | ''' 9 | :param label: int,Length*Height 10 | :param img: float,Length*Height 11 | ''' 12 | for ii in range(img.shape[1]): 13 | plt.imshow(img[ii], zorder=10,cmap='gray') 14 | contours3 = measure.find_contours(label[ii], 0.1) 15 | for n, contour in enumerate(contours3): 16 | plt.plot(contour[:, 1], contour[:, 0], 'g', zorder=20) 17 | plt.show() 18 | 19 | 20 | CTpath = '/lyc/Head-Neck-CT/3D_data/valid/liming/' 21 | file_wanted = ['Img.npy', 'label.npy'] 22 | img = np.load(CTpath + file_wanted[0]) 23 | label = np.load(CTpath + file_wanted[1]) 24 | show_label_contours(label,img) -------------------------------------------------------------------------------- /util/visualization/show_boxplot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy import loadtxt 3 | import matplotlib.pyplot as plt 4 | import os 5 | def show_performance(method_names, Data, ylabel, output_name, classnum, itertime, edge_color, fill_color): 6 | fig, ax = plt.subplots(figsize=(2.5,2.8)) 7 | x=[i+1 for i in range(classnum)] 8 | for ii in range(itertime): 9 | flierprops = dict(marker='+', markerfacecolor=edge_color[ii], markersize=7, 10 | linestyle='none', markeredgecolor=edge_color[ii]) 11 | bp = ax.boxplot(Data[ii], patch_artist=True, flierprops=flierprops) 12 | for element in ['boxes', 'whiskers', 'means', 'caps']: 13 | plt.setp(bp[element], color=edge_color[ii]) 14 | for patch in bp['boxes']: 15 | patch.set(facecolor=fill_color) 16 | 17 | ylabel = ylabel 18 | plt.ylim(0, 1000) 19 | plt.ylabel(ylabel) 20 | plt.xticks(x, method_names,rotation=30, ha='right') 21 | 
plt.subplots_adjust(left=0.3, right=0.8, top=0.95, bottom=0.25) 22 | plt.show() 23 | # fig.savefig(output_name, dpi=400) 24 | 25 | data_root = '/lyc/Head-Neck-CT/3D_data/valid/' 26 | output_name = '/home/uestc-c1501b/paper769/pic/Time_compare.png' 27 | patient_list = os.listdir(data_root) 28 | dice_list = [[] for _ in range(4)] 29 | i = 0 30 | classnum = 3 31 | itertime = 1 32 | Data_list = [[[336,304,262,251,290,276,255,273,337,237,252,231,196,199,311], [308,224,208,219,202,288,231,188,269,236,169,182,142,155,156] 33 | , [668, 621, 423, 692, 562, 629, 717, 793, 672, 825, 756, 777, 623]]] 34 | method_names = ['Overlay', 'Combine', 'Normal'] 35 | edge_color = ['brown'] 36 | ylabel = 'User Time (s)' 37 | fill_color = 'white' 38 | all_data = np.asarray(Data_list) 39 | show_performance(method_names, all_data, ylabel, output_name, classnum=classnum, itertime=itertime, edge_color=edge_color, fill_color=fill_color) 40 | -------------------------------------------------------------------------------- /util/visualization/show_multi_hist.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /util/visualization/show_param.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def show_param(net): 4 | params = list(net.parameters()) 5 | k = 0 6 | for i in params: 7 | l = 1 8 | print("该层的结构:", str(list(i.size()))) 9 | for j in i.size(): 10 | l *= j 11 | print("该层参数和:", str(l)) 12 | k = k + l 13 | print("总参数数量和:" + str(k)) 14 | -------------------------------------------------------------------------------- /util/visualization/visualize_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import, print_function 3 | import scipy 4 | from scipy import ndimage 5 | from PIL import Image 6 | import numpy as np 7 
| from graphviz import Digraph 8 | import torch 9 | from torch.autograd import Variable 10 | from visdom import Visdom 11 | 12 | def add_countor(In, Seg, Color=(0, 255, 0)): 13 | Out = In.copy() 14 | [H, W] = In.size 15 | for i in range(H): 16 | for j in range(W): 17 | if(i==0 or i==H-1 or j==0 or j == W-1): 18 | if(Seg.getpixel((i,j))!=0): 19 | Out.putpixel((i,j), Color) 20 | elif(Seg.getpixel((i,j))!=0 and \ 21 | not(Seg.getpixel((i-1,j))!=0 and \ 22 | Seg.getpixel((i+1,j))!=0 and \ 23 | Seg.getpixel((i,j-1))!=0 and \ 24 | Seg.getpixel((i,j+1))!=0)): 25 | Out.putpixel((i,j), Color) 26 | return Out 27 | 28 | def add_segmentation(img, seg, Color=(0, 255, 0)): 29 | seg = np.asarray(seg) 30 | if(img.size[1] != seg.shape[0] or img.size[0] != seg.shape[1]): 31 | print('segmentation has been resized') 32 | seg = scipy.misc.imresize(seg, (img.size[1], img.size[0]), interp='nearest') 33 | strt = ndimage.generate_binary_structure(2, 1) 34 | seg = np.asarray(ndimage.morphology.binary_opening(seg, strt), np.uint8) 35 | seg = np.asarray(ndimage.morphology.binary_closing(seg, strt), np.uint8) 36 | 37 | img_show = add_countor(img, Image.fromarray(seg), Color) 38 | strt = ndimage.generate_binary_structure(2, 1) 39 | seg = np.asarray(ndimage.morphology.binary_dilation(seg, strt), np.uint8) 40 | img_show = add_countor(img_show, Image.fromarray(seg), Color) 41 | return img_show 42 | 43 | 44 | class loss_visualize(object): 45 | def __init__(self, class_num, env='loss'): 46 | self.viz = Visdom(env=env) 47 | epoch = 0 48 | self.loss = self.viz.line(X=np.array([epoch]), 49 | Y=np.zeros([1, class_num+1]), # +1是因为除去背景还有train与test 50 | opts=dict(showlegend=True)) 51 | 52 | def plot_loss(self, epoch, batch_dice): 53 | train_dice_mean = np.asarray([batch_dice[1][1::].mean(axis=0)]) 54 | valid_dice_classes = batch_dice[0][1::] 55 | valid_dice_mean = np.asarray([valid_dice_classes.mean(axis=0)]) 56 | dice = np.append(np.append(train_dice_mean, 57 | valid_dice_mean), 
valid_dice_classes)[np.newaxis, :] 58 | self.viz.line( 59 | X=np.array([epoch+1]), 60 | Y=dice, 61 | win=self.loss, # win要保持一致 62 | update='append') -------------------------------------------------------------------------------- /weights_center_crop/multi_thresh_1/readme: -------------------------------------------------------------------------------- 1 | weights for SLF1 2 | -------------------------------------------------------------------------------- /weights_center_crop/multi_thresh_2/readme: -------------------------------------------------------------------------------- 1 | weights for SLF2 2 | -------------------------------------------------------------------------------- /weights_center_crop/multi_thresh_3/readme: -------------------------------------------------------------------------------- 1 | weights for SLF3 2 | --------------------------------------------------------------------------------