├── LICENSE ├── README.md ├── code ├── README.md ├── dataloaders │ ├── la_heart.py │ ├── livertumor.py │ └── utils.py ├── networks │ ├── vnet.py │ ├── vnet_multi_head.py │ ├── vnet_multi_task.py │ ├── vnet_rec.py │ └── vnet_sdf.py ├── test_LA.py ├── test_LA_AAAISDF.py ├── test_LA_MultiHead_FGDTM.py ├── test_LA_MultiHead_SDF.py ├── test_LA_Rec_FGDTM.py ├── test_LA_Rec_SDF.py ├── test_LITS.py ├── test_LITS_MultiHead_SDF.py ├── test_LITS_Rec_SDF.py ├── test_util.py ├── train_LA.py ├── train_LA_AAAISDF.py ├── train_LA_AAAISDF_L1.py ├── train_LA_BD.py ├── train_LA_HD.py ├── train_LA_MultiHead_FGDTM_L1.py ├── train_LA_MultiHead_FGDTM_L1PlusL2.py ├── train_LA_MultiHead_FGDTM_L2.py ├── train_LA_MultiHead_SDF_L1.py ├── train_LA_MultiHead_SDF_L1PlusL2.py ├── train_LA_MultiHead_SDF_L2.py ├── train_LA_Rec_FGDTM_L1.py ├── train_LA_Rec_FGDTM_L1PlusL2.py ├── train_LA_Rec_FGDTM_L2.py ├── train_LA_Rec_SDF_L1.py ├── train_LA_Rec_SDF_L1PlusL2.py ├── train_LA_Rec_SDF_L2.py ├── train_LA_SDF.py ├── train_LITS.py ├── train_LITS_BD.py ├── train_LITS_HD.py ├── train_LITS_Rec_SDF_L1.py ├── train_LITS_Rec_SDF_L1PlusL2.py ├── train_LITS_Rec_SDF_L2.py ├── train_LiTS_MultiHead_SDF_L1.py ├── train_LiTS_MultiHead_SDF_L2.py └── utils │ ├── losses.py │ ├── ramps.py │ └── util.py ├── data ├── README.md ├── test.list ├── test_unlabel.list └── train.list └── overview.PNG /README.md: -------------------------------------------------------------------------------- 1 | # 3D Medical Image Segmentation With Distance Transform Maps 2 | 3 | ## Motivation: How Distance Transform Maps Boost Segmentation CNNs [(MIDL 2020)](https://2020.midl.io/papers/ma20a.html) 4 | 5 | Incorporating distance transform maps of image segmentation labels into CNN-based segmentation pipelines received significant attention in 2019. These methods fall into two main classes according to how the distance transform map is used: 6 | 7 | - Designing new loss functions 8 | - Adding an auxiliary task, e.g. distance map regression 9 | 10 | ![Overview](https://github.com/JunMa11/SegWithDistMap/blob/master/overview.PNG) 11 | 12 | However, given the diversity of the specific implementations and of the dataset-related challenges, it is hard to tell which of these designs generalizes well beyond the experiments in the original papers. 13 | In this repository, we re-implement these methods (published in 2019) and evaluate them on the same two 3D segmentation tasks: left atrium (heart) segmentation and liver tumor segmentation. 
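Both kinds of target maps can be derived from a binary ground-truth mask with `scipy.ndimage.distance_transform_edt` (SciPy is already listed in the requirements). Below is a minimal sketch of the two targets for illustration only; the training scripts implement their own normalized variants (e.g. `compute_sdf1_1` in `train_LA_BD.py`):

```
import numpy as np
from scipy.ndimage import distance_transform_edt as edt

def fg_dtm(mask):
    # foreground distance transform map: distance of each foreground
    # voxel to the object boundary; 0 on the background
    return edt(mask.astype(bool))

def sdf(mask):
    # signed distance function: negative inside the object, positive
    # outside (degenerate all-foreground/all-background masks not handled)
    fg = mask.astype(bool)
    return edt(~fg) - edt(fg)
```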
14 | 15 | ## Experiments 16 | 17 | | Task | LA Contributor | GPU | LiTS Contributor | GPU | 18 | | -------------------------------------- | ------------- | ---------- | ------------ | ---------- | 19 | | Boundary loss | [Yiwen Zhang](https://github.com/whisney) | 2080ti | [Mengzhang Li](https://github.com/MengzhangLI) | TITAN RTX | 20 | | Hausdorff loss | [Yiwen Zhang](https://github.com/whisney) | 2080ti | [Mengzhang Li](https://github.com/MengzhangLI) | TITAN RTX | 21 | | Signed distance map loss (AAAI 2020) | [Zhan Wei](https://github.com/zhanwei33) | 1080ti | cancelled | - | 22 | | Multi-Head: FG DTM regression-L1 | [Yiwen Zhang](https://github.com/whisney) | 2080ti | cancelled | - | 23 | | Multi-Head: FG DTM regression-L2 | [Jianan Liu]() | 2080ti | cancelled | - | 24 | | Multi-Head: FG DTM regression-L1 + L2 | [Gaoxiang Chen](https://github.com/AMSTLHX) | 2080ti | cancelled | - | 25 | | Multi-Head: SDF regression-L1 | [Feng Cheng](836155475@qq.com) | TITAN X | [Chao Peng](https://github.com/AMSTLHX) | TITAN RTX | 26 | | Multi-Head: SDF regression-L2 | [Rongfei Lv](https://github.com/lrfdl) | TITAN RTX | [Rongfei Lv](https://github.com/lrfdl) | TITAN RTX | 27 | | Multi-Head: SDF regression-L1+L2 | [Yixin Wang](https://github.com/Wangyixinxin) | P100 | cancelled | - | 28 | | Add-Branch: FG DTM regression-L1 | [Yaliang Zhao](441926980) | TITAN RTX | cancelled | - | 29 | | Add-Branch: FG DTM regression-L2 | [Mengzhang Li](https://github.com/MengzhangLI) | TITAN RTX | cancelled | - | 30 | | Add-Branch: FG DTM regression-L1+L2 | [Yixin Wang](https://github.com/Wangyixinxin) | P100 | cancelled | - | 31 | | Add-Branch: SDF regression-L1 | [Feng Cheng](836155475@qq.com) | TITAN X | [Yixin Wang](https://github.com/Wangyixinxin) | TITAN RTX | 32 | | Add-Branch: SDF regression-L2 | [Feng Cheng](836155475@qq.com) | TITAN X | [Yixin Wang](https://github.com/Wangyixinxin) | P100 | 33 | | Add-Branch: SDF regression-L1+L2 | [Yixin Wang](https://github.com/Wangyixinxin) | P100 | [Yunpeng Wang]() | TITAN XP | 34 | 35 | > [Here](https://github.com/JunMa11/SegWithDistMap/tree/master/code) is the code, and trained models can be downloaded from [Baidu Disk](https://pan.baidu.com/s/1E9SlHw4DXuvsqFQRD_HHag) (pw:mgn0). 
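As a concrete example of the "new loss functions" class evaluated above, the boundary loss of Kervadec et al. weights the network's softmax probabilities by the precomputed signed distance map of the ground truth. This is a minimal PyTorch sketch of the idea behind the `train_*_BD.py` scripts, assuming the SDF target is precomputed by a `compute_sdf`-style helper with the convention that it is negative inside the object:

```
import torch

def boundary_loss(probs_fg, gt_sdf):
    # probs_fg: (B, X, Y, Z) softmax probability of the foreground class
    # gt_sdf:   (B, X, Y, Z) signed distance map of the ground-truth label,
    #           negative inside the object, so confident foreground predictions
    #           inside the true object drive the loss down
    return torch.mean(probs_fg * gt_sdf)
```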
36 | 37 | 38 | 39 | ## Related Work in 2019 40 | 41 | ### New loss functions 42 | 43 | | Date | First author | Title | Official Code | Publication | 44 | | ---- | ------------- | --------------------------- | -------------- | ------------------------------ | 45 | | 2019 | Yuan Xue | Shape-Aware Organ Segmentation by Predicting Signed Distance Maps | None | [AAAI 2020](https://www.aaai.org/Papers/AAAI/2020GB/AAAI-XueY.1482.pdf) | 46 | | 2019 | [Hoel Kervadec](https://scholar.google.com.hk/citations?user=yeFGhfgAAAAJ&hl=zh-CN&oi=sra) | Boundary loss for highly unbalanced segmentation | [PyTorch](https://github.com/LIVIAETS/surface-loss) | [MIDL 2019](http://proceedings.mlr.press/v102/kervadec19a.html) | 47 | | 2019 | Davood Karimi | Reducing the Hausdorff Distance in Medical Image Segmentation with Convolutional Neural Networks [(arXiv)](https://arxiv.org/abs/1904.10030) | None | [TMI 2019](https://ieeexplore.ieee.org/document/8767031) | 48 | 49 | 50 | 51 | ### Auxiliary tasks 52 | 53 | | Date | First author | Title | Official Code | Publication | 54 | | ---- | ------------- | --------------------------- | -------------- | ------------------------------ | 55 | | 2019 | Yan Wang | Deep Distance Transform for Tubular Structure Segmentation in CT Scans | None | [CVPR 2020](http://openaccess.thecvf.com/content_CVPR_2020/html/Wang_Deep_Distance_Transform_for_Tubular_Structure_Segmentation_in_CT_Scans_CVPR_2020_paper.html) | 56 | | 2019 | [Shusil Dangi](https://scholar.google.com.hk/citations?user=h12ifugAAAAJ&hl=zh-CN&oi=sra) | A Distance Map Regularized CNN for Cardiac Cine MR Image Segmentation [(arXiv)](https://arxiv.org/abs/1901.01238) | None | [Medical Physics](https://aapm.onlinelibrary.wiley.com/doi/abs/10.1002/mp.13853) | 57 | | 2019 | [Fernando Navarro](https://scholar.google.com.hk/citations?user=rRKrhrwAAAAJ&hl=zh-CN&oi=sra) | Shape-Aware Complementary-Task Learning for Multi-organ Segmentation [(arXiv)](https://arxiv.org/abs/1908.05099) | None | [MICCAI MLMI 2019](https://link.springer.com/chapter/10.1007/978-3-030-32692-0_71) | 58 | | 2019 | [Balamurali Murugesan](https://scholar.google.com.hk/citations?user=TmuKf44AAAAJ&hl=en&oi=sra) | Psi-Net: Shape and boundary aware joint multi-task deep network for medical image segmentation [(arXiv)](https://arxiv.org/abs/1902.04099) | None | [EMBC 2019](https://ieeexplore.ieee.org/abstract/document/8857339) | 59 | | 2019 | [Balamurali Murugesan](https://scholar.google.com.hk/citations?user=TmuKf44AAAAJ&hl=en&oi=sra) | Conv-MCD: A Plug-and-Play Multi-task Module for Medical Image Segmentation [(arXiv)](https://arxiv.org/abs/1908.05311) | [PyTorch](https://github.com/Bala93/Multi-task-deep-network) | [MICCAI MLMI 2019](https://link.springer.com/chapter/10.1007/978-3-030-32692-0_34) | 60 | 61 | 62 | ## Acknowledgments 63 | 64 | The authors would like to thank the organization teams of the MICCAI 2017 Liver Tumor Segmentation Challenge and the MICCAI 2018 Left Atrial Segmentation Challenge for the publicly available datasets. 65 | We also thank the reviewers for their valuable comments and suggestions. 66 | We appreciate Cheng Chen, Feng Cheng, Mengzhang Li, Chengwei Su, Chengfeng Zhou and Yaliang Zhao for helping us finish some of the experiments. 67 | Last but not least, we thank Lequan Yu for his great PyTorch implementation of [V-Net](https://github.com/yulequan/UA-MT) and Fabian Isensee for his great PyTorch implementation of [nnU-Net](https://github.com/MIC-DKFZ/nnUNet). 68 | 69 | 70 | ## Citation 
71 | Including the following citation in your work would be highly appreciated. 72 | ``` 73 | @inproceedings{ma-MIDL2020-SegWithDist, 74 | title={How Distance Transform Maps Boost Segmentation CNNs: An Empirical Study}, 75 | author={Ma, Jun and Wei, Zhan and Zhang, Yiwen and Wang, Yixin and Lv, Rongfei and Zhu, Cheng and Chen, Gaoxiang and Liu, Jianan and Peng, Chao and Wang, Lei and Wang, Yunpeng and Chen, Jianan}, 76 | booktitle={Medical Imaging with Deep Learning}, 77 | pages = {479--492}, 78 | volume = {121}, 79 | month = {06--08 Jul}, 80 | year={2020}, 81 | series = {Proceedings of Machine Learning Research}, 82 | editor = {Tal Arbel and Ismail Ben Ayed and Marleen de Bruijne and Maxime Descoteaux and Herve Lombaert and Christopher Pal}, 83 | publisher = {PMLR}, 84 | url = {http://proceedings.mlr.press/v121/ma20b.html} 85 | } 86 | ``` 87 | -------------------------------------------------------------------------------- /code/README.md: -------------------------------------------------------------------------------- 1 | ### Requirements 2 | - pytorch>=1.0 3 | - tensorboardX 4 | - scikit-image 5 | - scipy 6 | - tqdm 7 | 8 | **Note: the code has been tested on Ubuntu; it has not been verified on Windows.** 9 | 10 | 11 | ## V-Net with different loss functions 12 | ### V-Net Training 13 | - LA Heart MRI dataset: run `python train_LA.py` 14 | - Liver tumor CT dataset: run `python train_LITS.py` 15 | 16 | ### V-Net with boundary loss 17 | - LA Heart MRI dataset: run `python train_LA_BD.py` 18 | - Liver tumor CT dataset: run `python train_LITS_BD.py` 19 | 20 | > You need to set `--exp` properly. Both [compute_sdf](https://github.com/JunMa11/SegWithDistMap/blob/ed55b65889a4ba4cf9f7532e63124fe9ba10fcf0/code/train_LA_BD.py#L94) and [compute_sdf1_1](https://github.com/JunMa11/SegWithDistMap/blob/ed55b65889a4ba4cf9f7532e63124fe9ba10fcf0/code/train_LA_BD.py#L63) are worth trying. 21 | 22 | ### V-Net with Hausdorff distance loss 23 | - LA Heart MRI dataset: run `python train_LA_HD.py` 24 | - Liver tumor CT dataset: run `python train_LITS_HD.py` 25 | 26 | > You need to set `--exp` properly. Both [compute_dtm](https://github.com/JunMa11/SegWithDistMap/blob/ed55b65889a4ba4cf9f7532e63124fe9ba10fcf0/code/train_LA_HD.py#L86) and [compute_dtm01](https://github.com/JunMa11/SegWithDistMap/blob/ed55b65889a4ba4cf9f7532e63124fe9ba10fcf0/code/train_LA_HD.py#L63) are worth trying. 27 | 28 | ### Testing 29 | - LA heart MRI dataset: run `python test_LA.py` 30 | - Liver tumor CT dataset: run `python test_LITS.py` 31 | 32 | 33 | ## [Signed distance map loss](https://arxiv.org/abs/1912.03849) 34 | 35 | > Xue et al. Shape-Aware Organ Segmentation by Predicting Signed Distance Maps [arxiv](https://arxiv.org/abs/1912.03849) 36 | 37 | ### Training 38 | 39 | - run `python train_LA_AAAISDF.py` 40 | - run `python train_LA_AAAISDF_L1.py` 41 | 42 | ### Testing 43 | - run `python test_LA_AAAISDF.py` 44 | 45 | 46 | ## V-Net with additional heads 47 | > Wang et al. Deep Distance Transform for Tubular Structure Segmentation in CT Scans [arxiv](https://arxiv.org/abs/1912.03383) 48 | 49 | > Navarro et al. 
Shape-Aware Complementary-Task Learning for Multi-organ Segmentation [arxiv](https://arxiv.org/abs/1908.05099) 50 | 51 | ### Training 52 | 53 | - run `python train_LA_MultiHead_FGDTM_L1.py` to regress the foreground distance transform map 54 | 55 | > L1 can be replaced with L2 or L1PlusL2 56 | 57 | - run `python train_LA_MultiHead_SDF_L1.py` to regress the signed distance function 58 | 59 | > L1 can be replaced with L2 or L1PlusL2 60 | 61 | ### Testing 62 | 63 | - run `python test_LA_MultiHead_FGDTM.py` 64 | - run `python test_LA_MultiHead_SDF.py` 65 | 66 | ## V-Net with additional reconstruction branch 67 | ### Training 68 | 69 | - run `python train_LA_Rec_FGDTM_L1.py` to regress the foreground distance transform map 70 | 71 | > L1 can be replaced with L2 or L1PlusL2 72 | 73 | - run `python train_LA_Rec_SDF_L1.py` to regress the signed distance function 74 | 75 | > L1 can be replaced with L2 or L1PlusL2 76 | 77 | 78 | ### Testing 79 | 80 | - run `python test_LA_Rec_FGDTM.py` 81 | - run `python test_LA_Rec_SDF.py` 82 | 83 | ## Tips 84 | - `--model` can be used to specify the model name 85 | - `--epoch_num` can be used to specify the checkpoint 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /code/dataloaders/la_heart.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from glob import glob 5 | from torch.utils.data import Dataset 6 | import h5py 7 | import itertools 8 | from torch.utils.data.sampler import Sampler 9 | 10 | class LAHeart(Dataset): 11 | """ LA Dataset """ 12 | def __init__(self, base_dir=None, split='train', num=None, transform=None): 13 | self._base_dir = base_dir 14 | self.transform = transform 15 | self.sample_list = [] 16 | if split=='train': 17 | with open(self._base_dir+'/../train.list', 'r') as f: 18 | self.image_list = f.readlines() 19 | elif split == 'test': 20 | with open(self._base_dir+'/../test.list', 'r') as f: 21 | self.image_list = f.readlines() 22 | self.image_list = [item.replace('\n','') for item in self.image_list] 23 | if num is not None: 24 | self.image_list = self.image_list[:num] 25 | print("total {} samples".format(len(self.image_list))) 26 | 27 | def __len__(self): 28 | return len(self.image_list) 29 | 30 | def __getitem__(self, idx): 31 | image_name = self.image_list[idx] 32 | h5f = h5py.File(self._base_dir+"/"+image_name+"/mri_norm2.h5", 'r') 33 | image = h5f['image'][:] 34 | label = h5f['label'][:] 35 | sample = {'image': image, 'label': label} 36 | if self.transform: 37 | sample = self.transform(sample) 38 | 39 | return sample 40 | 41 | class CenterCrop(object): 42 | def __init__(self, output_size): 43 | self.output_size = output_size 44 | 45 | def __call__(self, sample): 46 | image, label = sample['image'], sample['label'] 47 | 48 | # pad the sample if necessary 49 | if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \ 50 | self.output_size[2]: 51 | pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0) 52 | ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0) 53 | pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0) 54 | image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 55 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 56 | 57 | (w, h, d) = image.shape 58 | 59 | w1 = int(round((w - self.output_size[0]) / 2.)) 60 | h1 = int(round((h - self.output_size[1]) / 2.)) 61 | d1 = 
int(round((d - self.output_size[2]) / 2.)) 62 | 63 | label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 64 | image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 65 | 66 | return {'image': image, 'label': label} 67 | 68 | 69 | class RandomCrop(object): 70 | """ 71 | Crop randomly the image in a sample 72 | Args: 73 | output_size (int): Desired output size 74 | """ 75 | 76 | def __init__(self, output_size): 77 | self.output_size = output_size 78 | 79 | def __call__(self, sample): 80 | image, label = sample['image'], sample['label'] 81 | 82 | # pad the sample if necessary 83 | if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \ 84 | self.output_size[2]: 85 | pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0) 86 | ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0) 87 | pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0) 88 | image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 89 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 90 | 91 | (w, h, d) = image.shape 92 | # if np.random.uniform() > 0.33: 93 | # w1 = np.random.randint((w - self.output_size[0])//4, 3*(w - self.output_size[0])//4) 94 | # h1 = np.random.randint((h - self.output_size[1])//4, 3*(h - self.output_size[1])//4) 95 | # else: 96 | w1 = np.random.randint(0, w - self.output_size[0]) 97 | h1 = np.random.randint(0, h - self.output_size[1]) 98 | d1 = np.random.randint(0, d - self.output_size[2]) 99 | 100 | label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 101 | image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 102 | return {'image': image, 'label': label} 103 | 104 | 105 | class RandomRotFlip(object): 106 | """ 107 | Crop randomly flip the dataset in a sample 108 | Args: 109 | output_size (int): Desired output size 110 | """ 111 | 112 | def __call__(self, sample): 113 | image, label = sample['image'], sample['label'] 114 | k = np.random.randint(0, 4) 115 | image = np.rot90(image, k) 116 | label = np.rot90(label, k) 117 | axis = np.random.randint(0, 2) 118 | image = np.flip(image, axis=axis).copy() 119 | label = np.flip(label, axis=axis).copy() 120 | 121 | return {'image': image, 'label': label} 122 | 123 | 124 | class RandomNoise(object): 125 | def __init__(self, mu=0, sigma=0.1): 126 | self.mu = mu 127 | self.sigma = sigma 128 | 129 | def __call__(self, sample): 130 | image, label = sample['image'], sample['label'] 131 | noise = np.clip(self.sigma * np.random.randn(image.shape[0], image.shape[1], image.shape[2]), -2*self.sigma, 2*self.sigma) 132 | noise = noise + self.mu 133 | image = image + noise 134 | return {'image': image, 'label': label} 135 | 136 | 137 | class CreateOnehotLabel(object): 138 | def __init__(self, num_classes): 139 | self.num_classes = num_classes 140 | 141 | def __call__(self, sample): 142 | image, label = sample['image'], sample['label'] 143 | onehot_label = np.zeros((self.num_classes, label.shape[0], label.shape[1], label.shape[2]), dtype=np.float32) 144 | for i in range(self.num_classes): 145 | onehot_label[i, :, :, :] = (label == i).astype(np.float32) 146 | return {'image': image, 'label': label,'onehot_label':onehot_label} 147 | 148 | 149 | class ToTensor(object): 150 | """Convert ndarrays in sample to Tensors.""" 151 | 152 | def __call__(self, 
sample): 153 | image = sample['image'] 154 | image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32) 155 | if 'onehot_label' in sample: 156 | return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long(), 157 | 'onehot_label': torch.from_numpy(sample['onehot_label']).long()} 158 | else: 159 | return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long()} 160 | 161 | 162 | class TwoStreamBatchSampler(Sampler): 163 | """Iterate two sets of indices 164 | 165 | An 'epoch' is one iteration through the primary indices. 166 | During the epoch, the secondary indices are iterated through 167 | as many times as needed. 168 | """ 169 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 170 | self.primary_indices = primary_indices 171 | self.secondary_indices = secondary_indices 172 | self.secondary_batch_size = secondary_batch_size 173 | self.primary_batch_size = batch_size - secondary_batch_size 174 | 175 | assert len(self.primary_indices) >= self.primary_batch_size > 0 176 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 177 | 178 | def __iter__(self): 179 | primary_iter = iterate_once(self.primary_indices) 180 | secondary_iter = iterate_eternally(self.secondary_indices) 181 | return ( 182 | primary_batch + secondary_batch 183 | for (primary_batch, secondary_batch) 184 | in zip(grouper(primary_iter, self.primary_batch_size), 185 | grouper(secondary_iter, self.secondary_batch_size)) 186 | ) 187 | 188 | def __len__(self): 189 | return len(self.primary_indices) // self.primary_batch_size 190 | 191 | def iterate_once(iterable): 192 | return np.random.permutation(iterable) 193 | 194 | 195 | def iterate_eternally(indices): 196 | def infinite_shuffles(): 197 | while True: 198 | yield np.random.permutation(indices) 199 | return itertools.chain.from_iterable(infinite_shuffles()) 200 | 201 | 202 | def grouper(iterable, n): 203 | "Collect data into fixed-length chunks or blocks" 204 | # grouper('ABCDEFG', 3) --> ABC DEF 205 | args = [iter(iterable)] * n 206 | return zip(*args) 207 | -------------------------------------------------------------------------------- /code/dataloaders/livertumor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from glob import glob 5 | from torch.utils.data import Dataset 6 | import h5py 7 | import itertools 8 | from torch.utils.data.sampler import Sampler 9 | 10 | """ 11 | patch size: (96, 128, 160) 12 | median patient size (147, 216, 243) 13 | """ 14 | 15 | class LiverTumor(Dataset): 16 | """ LITS Dataset """ 17 | def __init__(self, base_dir=None, split='train', num=None, transform=None): 18 | self._base_dir = base_dir 19 | self.transform = transform 20 | self.sample_list = [] 21 | if split=='train': 22 | with open(self._base_dir+'/LITS_train.list', 'r') as f: 23 | self.image_list = f.readlines() 24 | elif split == 'test': 25 | with open(self._base_dir+'/LITS_test.list', 'r') as f: 26 | self.image_list = f.readlines() 27 | self.image_list = [item.replace('\n','') for item in self.image_list] 28 | if num is not None: 29 | self.image_list = self.image_list[:num] 30 | print("total {} samples".format(len(self.image_list))) 31 | 32 | def __len__(self): 33 | return len(self.image_list) 34 | 35 | def __getitem__(self, idx): 36 | image_name = self.image_list[idx] 37 | h5f = h5py.File(self._base_dir+"/h5/"+image_name, 'r') 38 | image = 
h5f['image'][:] 39 | label = h5f['label'][:] 40 | sample = {'image': image, 'label': label} 41 | if self.transform: 42 | sample = self.transform(sample) 43 | 44 | return sample 45 | 46 | class CenterCrop(object): 47 | def __init__(self, output_size=(96, 128, 160)): 48 | self.output_size = output_size 49 | 50 | def __call__(self, sample): 51 | image, label = sample['image'], sample['label'] 52 | 53 | # pad the sample if necessary 54 | if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \ 55 | self.output_size[2]: 56 | pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0) 57 | ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0) 58 | pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0) 59 | image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 60 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 61 | 62 | (w, h, d) = image.shape 63 | 64 | w1 = int(round((w - self.output_size[0]) / 2.)) 65 | h1 = int(round((h - self.output_size[1]) / 2.)) 66 | d1 = int(round((d - self.output_size[2]) / 2.)) 67 | 68 | label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 69 | image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 70 | 71 | return {'image': image, 'label': label} 72 | 73 | 74 | class RandomCrop(object): 75 | """ 76 | Crop randomly the image in a sample 77 | Args: 78 | output_size (int): Desired output size 79 | """ 80 | 81 | def __init__(self, output_size=(96, 128, 160)): 82 | self.output_size = output_size 83 | 84 | def __call__(self, sample): 85 | image, label = sample['image'], sample['label'] 86 | 87 | # pad the sample if necessary 88 | if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \ 89 | self.output_size[2]: 90 | pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0) 91 | ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0) 92 | pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0) 93 | image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 94 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 95 | 96 | (w, h, d) = image.shape 97 | # if np.random.uniform() > 0.33: 98 | # w1 = np.random.randint((w - self.output_size[0])//4, 3*(w - self.output_size[0])//4) 99 | # h1 = np.random.randint((h - self.output_size[1])//4, 3*(h - self.output_size[1])//4) 100 | # else: 101 | w1 = np.random.randint(0, w - self.output_size[0]) 102 | h1 = np.random.randint(0, h - self.output_size[1]) 103 | d1 = np.random.randint(0, d - self.output_size[2]) 104 | 105 | label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 106 | image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 107 | return {'image': image, 'label': label} 108 | 109 | 110 | class RandomRotFlip(object): 111 | """ 112 | Crop randomly flip the dataset in a sample 113 | Args: 114 | output_size (int): Desired output size 115 | """ 116 | 117 | def __call__(self, sample): 118 | image, label = sample['image'], sample['label'] 119 | k = np.random.randint(0, 4) 120 | image = np.rot90(image, k) 121 | label = np.rot90(label, k) 122 | axis = np.random.randint(0, 2) 123 | image = np.flip(image, axis=axis).copy() 124 | label = np.flip(label, axis=axis).copy() 125 | 126 | return 
{'image': image, 'label': label} 127 | 128 | 129 | class RandomNoise(object): 130 | def __init__(self, mu=0, sigma=0.1): 131 | self.mu = mu 132 | self.sigma = sigma 133 | 134 | def __call__(self, sample): 135 | image, label = sample['image'], sample['label'] 136 | noise = np.clip(self.sigma * np.random.randn(image.shape[0], image.shape[1], image.shape[2]), -2*self.sigma, 2*self.sigma) 137 | noise = noise + self.mu 138 | image = image + noise 139 | return {'image': image, 'label': label} 140 | 141 | 142 | class CreateOnehotLabel(object): 143 | def __init__(self, num_classes): 144 | self.num_classes = num_classes 145 | 146 | def __call__(self, sample): 147 | image, label = sample['image'], sample['label'] 148 | onehot_label = np.zeros((self.num_classes, label.shape[0], label.shape[1], label.shape[2]), dtype=np.float32) 149 | for i in range(self.num_classes): 150 | onehot_label[i, :, :, :] = (label == i).astype(np.float32) 151 | return {'image': image, 'label': label,'onehot_label':onehot_label} 152 | 153 | 154 | class ToTensor(object): 155 | """Convert ndarrays in sample to Tensors.""" 156 | 157 | def __call__(self, sample): 158 | image = sample['image'] 159 | image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32) 160 | if 'onehot_label' in sample: 161 | return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long(), 162 | 'onehot_label': torch.from_numpy(sample['onehot_label']).long()} 163 | else: 164 | return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long()} 165 | 166 | 167 | class TwoStreamBatchSampler(Sampler): 168 | """Iterate two sets of indices 169 | TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size-labeled_bs) 170 | 171 | An 'epoch' is one iteration through the primary indices. 172 | During the epoch, the secondary indices are iterated through 173 | as many times as needed. 
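For example, with batch_size=4 and secondary_batch_size=2, each batch contains 2 primary (labeled) and 2 secondary (unlabeled) indices.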
174 | """ 175 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 176 | self.primary_indices = primary_indices 177 | self.secondary_indices = secondary_indices 178 | self.secondary_batch_size = secondary_batch_size 179 | self.primary_batch_size = batch_size - secondary_batch_size 180 | 181 | assert len(self.primary_indices) >= self.primary_batch_size > 0 182 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 183 | 184 | def __iter__(self): 185 | primary_iter = iterate_once(self.primary_indices) 186 | secondary_iter = iterate_eternally(self.secondary_indices) 187 | return ( 188 | primary_batch + secondary_batch 189 | for (primary_batch, secondary_batch) 190 | in zip(grouper(primary_iter, self.primary_batch_size), 191 | grouper(secondary_iter, self.secondary_batch_size)) 192 | ) 193 | 194 | def __len__(self): 195 | return len(self.primary_indices) // self.primary_batch_size 196 | 197 | def iterate_once(iterable): 198 | return np.random.permutation(iterable) 199 | 200 | 201 | def iterate_eternally(indices): 202 | def infinite_shuffles(): 203 | while True: 204 | yield np.random.permutation(indices) 205 | return itertools.chain.from_iterable(infinite_shuffles()) 206 | 207 | 208 | def grouper(iterable, n): 209 | "Collect data into fixed-length chunks or blocks" 210 | # grouper('ABCDEFG', 3) --> ABC DEF" 211 | args = [iter(iterable)] * n 212 | return zip(*args) 213 | -------------------------------------------------------------------------------- /code/dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import matplotlib.pyplot as plt 6 | from skimage import measure 7 | import scipy.ndimage as nd 8 | 9 | 10 | def recursive_glob(rootdir='.', suffix=''): 11 | """Performs recursive glob with given suffix and rootdir 12 | :param rootdir is the root directory 13 | :param suffix is the suffix to be searched 14 | """ 15 | return [os.path.join(looproot, filename) 16 | for looproot, _, filenames in os.walk(rootdir) 17 | for filename in filenames if filename.endswith(suffix)] 18 | 19 | def get_cityscapes_labels(): 20 | return np.array([ 21 | # [ 0, 0, 0], 22 | [128, 64, 128], 23 | [244, 35, 232], 24 | [70, 70, 70], 25 | [102, 102, 156], 26 | [190, 153, 153], 27 | [153, 153, 153], 28 | [250, 170, 30], 29 | [220, 220, 0], 30 | [107, 142, 35], 31 | [152, 251, 152], 32 | [0, 130, 180], 33 | [220, 20, 60], 34 | [255, 0, 0], 35 | [0, 0, 142], 36 | [0, 0, 70], 37 | [0, 60, 100], 38 | [0, 80, 100], 39 | [0, 0, 230], 40 | [119, 11, 32]]) 41 | 42 | def get_pascal_labels(): 43 | """Load the mapping that associates pascal classes with label colors 44 | Returns: 45 | np.ndarray with dimensions (21, 3) 46 | """ 47 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 48 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 49 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 50 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 51 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 52 | [0, 64, 128]]) 53 | 54 | 55 | def encode_segmap(mask): 56 | """Encode segmentation label images as pascal classes 57 | Args: 58 | mask (np.ndarray): raw segmentation label image of dimension 59 | (M, N, 3), in which the Pascal classes are encoded as colours. 
60 | Returns: 61 | (np.ndarray): class map with dimensions (M,N), where the value at 62 | a given location is the integer denoting the class index. 63 | """ 64 | mask = mask.astype(int) 65 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 66 | for ii, label in enumerate(get_pascal_labels()): 67 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 68 | label_mask = label_mask.astype(int) 69 | return label_mask 70 | 71 | 72 | def decode_seg_map_sequence(label_masks, dataset='pascal'): 73 | rgb_masks = [] 74 | for label_mask in label_masks: 75 | rgb_mask = decode_segmap(label_mask, dataset) 76 | rgb_masks.append(rgb_mask) 77 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 78 | return rgb_masks 79 | 80 | def decode_segmap(label_mask, dataset, plot=False): 81 | """Decode segmentation class labels into a color image 82 | Args: 83 | label_mask (np.ndarray): an (M,N) array of integer values denoting 84 | the class label at each spatial location. 85 | plot (bool, optional): whether to show the resulting color image 86 | in a figure. 87 | Returns: 88 | (np.ndarray, optional): the resulting decoded color image. 89 | """ 90 | if dataset == 'pascal': 91 | n_classes = 21 92 | label_colours = get_pascal_labels() 93 | elif dataset == 'cityscapes': 94 | n_classes = 19 95 | label_colours = get_cityscapes_labels() 96 | else: 97 | raise NotImplementedError 98 | 99 | r = label_mask.copy() 100 | g = label_mask.copy() 101 | b = label_mask.copy() 102 | for ll in range(0, n_classes): 103 | r[label_mask == ll] = label_colours[ll, 0] 104 | g[label_mask == ll] = label_colours[ll, 1] 105 | b[label_mask == ll] = label_colours[ll, 2] 106 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 107 | rgb[:, :, 0] = r / 255.0 108 | rgb[:, :, 1] = g / 255.0 109 | rgb[:, :, 2] = b / 255.0 110 | if plot: 111 | plt.imshow(rgb) 112 | plt.show() 113 | else: 114 | return rgb 115 | 116 | def generate_param_report(logfile, param): 117 | log_file = open(logfile, 'w') 118 | # for key, val in param.items(): 119 | # log_file.write(key + ':' + str(val) + '\n') 120 | log_file.write(str(param)) 121 | log_file.close() 122 | 123 | def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True): 124 | n, c, h, w = logit.size() 125 | # logit = logit.permute(0, 2, 3, 1) 126 | target = target.squeeze(1) 127 | if weight is None: 128 | criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, size_average=False) 129 | else: 130 | criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(), ignore_index=ignore_index, size_average=False) 131 | loss = criterion(logit, target.long()) 132 | 133 | if size_average: 134 | loss /= (h * w) 135 | 136 | if batch_average: 137 | loss /= n 138 | 139 | return loss 140 | 141 | def lr_poly(base_lr, iter_, max_iter=100, power=0.9): 142 | return base_lr * ((1 - float(iter_) / max_iter) ** power) 143 | 144 | 145 | def get_iou(pred, gt, n_classes=21): 146 | total_iou = 0.0 147 | for i in range(len(pred)): 148 | pred_tmp = pred[i] 149 | gt_tmp = gt[i] 150 | 151 | intersect = [0] * n_classes 152 | union = [0] * n_classes 153 | for j in range(n_classes): 154 | match = (pred_tmp == j) + (gt_tmp == j) 155 | 156 | it = torch.sum(match == 2).item() 157 | un = torch.sum(match > 0).item() 158 | 159 | intersect[j] += it 160 | union[j] += un 161 | 162 | iou = [] 163 | for k in range(n_classes): 164 | if union[k] == 0: 165 | continue 166 | iou.append(intersect[k] / union[k]) 
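# classes with an empty union were skipped above, so img_iou below averages the IoU over the classes actually present in this image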
167 | 168 | img_iou = (sum(iou) / len(iou)) 169 | total_iou += img_iou 170 | 171 | return total_iou 172 | 173 | def get_dice(pred, gt): 174 | total_dice = 0.0 175 | pred = pred.long() 176 | gt = gt.long() 177 | for i in range(len(pred)): 178 | pred_tmp = pred[i] 179 | gt_tmp = gt[i] 180 | dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item() 181 | print(dice) 182 | total_dice += dice 183 | 184 | return total_dice 185 | 186 | def get_mc_dice(pred, gt, num=2): 187 | # num is the total number of classes, include the background 188 | total_dice = np.zeros(num-1) 189 | pred = pred.long() 190 | gt = gt.long() 191 | for i in range(len(pred)): 192 | for j in range(1, num): 193 | pred_tmp = (pred[i]==j) 194 | gt_tmp = (gt[i]==j) 195 | dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item() 196 | total_dice[j-1] +=dice 197 | return total_dice 198 | 199 | def post_processing(prediction): 200 | prediction = nd.binary_fill_holes(prediction) 201 | label_cc, num_cc = measure.label(prediction,return_num=True) 202 | total_cc = np.sum(prediction) 203 | measure.regionprops(label_cc) 204 | for cc in range(1,num_cc+1): 205 | single_cc = (label_cc==cc) 206 | single_vol = np.sum(single_cc) 207 | if single_vol/total_cc<0.2: 208 | prediction[single_cc]=0 209 | 210 | return prediction 211 | 212 | 213 | 214 | 215 | -------------------------------------------------------------------------------- /code/test_LA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from networks.vnet import VNet 5 | from test_util import test_all_case 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--root_path', type=str, default='../data/2018LA_Seg_Training Set/', help='Name of Experiment') 9 | parser.add_argument('--model', type=str, default='vnet_supervisedonly_dp', help='model_name') 10 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 11 | parser.add_argument('--epoch_num', type=int, default='6000', help='checkpoint to use') 12 | FLAGS = parser.parse_args() 13 | 14 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu 15 | snapshot_path = "../model_la/"+FLAGS.model+"/" 16 | test_save_path = "../model_la/prediction/"+FLAGS.model+"_post/" 17 | if not os.path.exists(test_save_path): 18 | os.makedirs(test_save_path) 19 | 20 | num_classes = 2 21 | 22 | with open(FLAGS.root_path + '/../test.list', 'r') as f: 23 | image_list = f.readlines() 24 | image_list = [FLAGS.root_path +item.replace('\n', '')+"/mri_norm2.h5" for item in image_list] 25 | # print(image_list) 26 | 27 | def test_calculate_metric(epoch_num): 28 | net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False).cuda() 29 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth') 30 | net.load_state_dict(torch.load(save_mode_path)) 31 | print("init weight from {}".format(save_mode_path)) 32 | net.eval() 33 | 34 | avg_metric = test_all_case(net, image_list, num_classes=num_classes, 35 | patch_size=(112, 112, 80), stride_xy=18, stride_z=4, 36 | save_result=True, test_save_path=test_save_path) 37 | 38 | return avg_metric 39 | 40 | 41 | if __name__ == '__main__': 42 | metric = test_calculate_metric(FLAGS.epoch_num) 43 | # print(metric) 44 | -------------------------------------------------------------------------------- /code/test_LA_AAAISDF.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from networks.vnet_sdf import VNet 5 | import h5py 6 | import math 7 | import nibabel as nib 8 | import numpy as np 9 | from medpy import metric 10 | import torch.nn.functional as F 11 | from tqdm import tqdm 12 | import os 13 | import pandas as pd 14 | from collections import OrderedDict 15 | import pdb 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--root_path', type=str, default='../data/2018LA_Seg_Training Set/', help='Name of Experiment') 19 | parser.add_argument('--model', type=str, default='vnet_dp_la_AAAISDFL1', help='model_name') 20 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 21 | parser.add_argument('--epoch_num', type=int, default='6000', help='model to use') 22 | FLAGS = parser.parse_args() 23 | 24 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu 25 | snapshot_path = "../model_la/"+FLAGS.model+"/" 26 | test_save_path = "../model_la/prediction/"+FLAGS.model+"_post/" 27 | if not os.path.exists(test_save_path): 28 | os.makedirs(test_save_path) 29 | 30 | num_classes = 2 31 | 32 | with open(FLAGS.root_path + '/../test.list', 'r') as f: 33 | image_list = f.readlines() 34 | image_list = [FLAGS.root_path +item.replace('\n', '')+"/mri_norm2.h5" for item in image_list] 35 | # print(image_list) 36 | 37 | def test_calculate_metric(epoch_num): 38 | net = VNet(n_channels=1, n_classes=num_classes-1, normalization='batchnorm', has_dropout=False).cuda() 39 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth') 40 | net.load_state_dict(torch.load(save_mode_path)) 41 | print("init weight from {}".format(save_mode_path)) 42 | net.eval() 43 | 44 | avg_metric = dist_test_all_case(net, image_list, num_classes=num_classes, 45 | patch_size=(112, 112, 80), stride_xy=18, stride_z=4, 46 | save_result=True, test_save_path=test_save_path) 47 | 48 | return avg_metric 49 | 50 | def dist_test_all_case(net, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None): 51 | total_metric = 0.0 52 | metric_dict = OrderedDict() 53 | metric_dict['name'] = list() 54 | metric_dict['dice'] = list() 55 | metric_dict['jaccard'] = list() 56 | metric_dict['asd'] = list() 57 | metric_dict['95hd'] = list() 58 | for image_path in tqdm(image_list): 59 | case_name = image_path.split('/')[-2] 60 | id = image_path.split('/')[-1] 61 | h5f = h5py.File(image_path, 'r') 62 | image = h5f['image'][:] 63 | label = h5f['label'][:] 64 | if preproc_fn is not None: 65 | image = preproc_fn(image) 66 | prediction, score_map, pred_dist = test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 67 | 68 | if np.sum(prediction)==0: 69 | single_metric = (0,0,0,0) 70 | else: 71 | single_metric = calculate_metric_percase(prediction, label[:]) 72 | metric_dict['name'].append(case_name) 73 | metric_dict['dice'].append(single_metric[0]) 74 | metric_dict['jaccard'].append(single_metric[1]) 75 | metric_dict['asd'].append(single_metric[2]) 76 | metric_dict['95hd'].append(single_metric[3]) 77 | # print(metric_dict) 78 | 79 | 80 | total_metric += np.asarray(single_metric) 81 | 82 | if save_result: 83 | test_save_path_temp = os.path.join(test_save_path, case_name) 84 | if not os.path.exists(test_save_path_temp): 85 | os.makedirs(test_save_path_temp) 86 | nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path_temp + 
'/' + id + "_pred.nii.gz") 87 | nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_img.nii.gz") 88 | nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_gt.nii.gz") 89 | nib.save(nib.Nifti1Image(pred_dist[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_dist.nii.gz") 90 | avg_metric = total_metric / len(image_list) 91 | metric_csv = pd.DataFrame(metric_dict) 92 | metric_csv.to_csv(test_save_path + '/metric_'+str(FLAGS.epoch_num)+'.csv', index=False) 93 | print('average metric is {}'.format(avg_metric)) 94 | 95 | return avg_metric 96 | 97 | 98 | 99 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 100 | w, h, d = image.shape 101 | 102 | # if the size of image is less than patch_size, then padding it 103 | add_pad = False 104 | if w < patch_size[0]: 105 | w_pad = patch_size[0]-w 106 | add_pad = True 107 | else: 108 | w_pad = 0 109 | if h < patch_size[1]: 110 | h_pad = patch_size[1]-h 111 | add_pad = True 112 | else: 113 | h_pad = 0 114 | if d < patch_size[2]: 115 | d_pad = patch_size[2]-d 116 | add_pad = True 117 | else: 118 | d_pad = 0 119 | wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2 120 | hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2 121 | dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2 122 | if add_pad: 123 | image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) 124 | ww,hh,dd = image.shape 125 | 126 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 127 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 128 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 129 | # print("{}, {}, {}".format(sx, sy, sz)) 130 | score_map = np.zeros((num_classes-1, ) + image.shape).astype(np.float32) 131 | cnt = np.zeros(image.shape).astype(np.float32) 132 | pred_dist = np.zeros(image.shape).astype(np.float32) 133 | 134 | for x in range(0, sx): 135 | xs = min(stride_xy*x, ww-patch_size[0]) 136 | for y in range(0, sy): 137 | ys = min(stride_xy * y,hh-patch_size[1]) 138 | for z in range(0, sz): 139 | zs = min(stride_z * z, dd-patch_size[2]) 140 | test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] 141 | test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32) 142 | test_patch = torch.from_numpy(test_patch).cuda() 143 | out_dist = net(test_patch) 144 | y = torch.sigmoid(-1500*out_dist) 145 | # print(y.shape, out_dist.shape) # ([1, 1, 112, 112, 80]) ([1, 1, 112, 112, 80]) 146 | # pdb.set_trace() 147 | y = y.cpu().data.numpy() 148 | y = y[0,:,:,:,:] 149 | out_dist = out_dist.cpu().data.numpy() 150 | pred_dist[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] = out_dist[0,0,:,:,:] 151 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 152 | = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 153 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 154 | = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 155 | 156 | 157 | score_map = score_map/np.expand_dims(cnt,axis=0) 158 | label_map = (score_map>0.5).astype(score_map.dtype).squeeze() 159 | if add_pad: 160 | label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 161 | score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 162 | pred_dist = pred_dist[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 163 | return label_map, score_map, pred_dist 164 | 165 | 
def cal_dice(prediction, label, num=2): 166 | total_dice = np.zeros(num-1) 167 | for i in range(1, num): 168 | prediction_tmp = (prediction==i) 169 | label_tmp = (label==i) 170 | prediction_tmp = prediction_tmp.astype(float) 171 | label_tmp = label_tmp.astype(float) 172 | 173 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp)) 174 | total_dice[i - 1] += dice 175 | 176 | return total_dice 177 | 178 | 179 | def calculate_metric_percase(pred, gt): 180 | dice = metric.binary.dc(pred, gt) 181 | jc = metric.binary.jc(pred, gt) 182 | hd = metric.binary.hd95(pred, gt) 183 | asd = metric.binary.asd(pred, gt) 184 | 185 | return dice, jc, asd, hd # order matches the callers: (dice, jaccard, asd, 95hd) 186 | 187 | 188 | if __name__ == '__main__': 189 | metric = test_calculate_metric(FLAGS.epoch_num) 190 | # print(metric) -------------------------------------------------------------------------------- /code/test_LA_MultiHead_FGDTM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from networks.vnet_multi_head import VNetMultiHead 5 | import h5py 6 | import math 7 | import nibabel as nib 8 | import numpy as np 9 | from medpy import metric 10 | import torch.nn.functional as F 11 | from tqdm import tqdm 12 | import os 13 | import pandas as pd 14 | from collections import OrderedDict 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--root_path', type=str, default='../data/2018LA_Seg_Training Set/', help='Name of Experiment') 19 | parser.add_argument('--model', type=str, default='vnet_dp_la_MH_FGDTM_L1PlusL2', help='model_name') 20 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 21 | parser.add_argument('--epoch_num', type=int, default='6000', help='checkpoint to use') 22 | FLAGS = parser.parse_args() 23 | 24 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu 25 | snapshot_path = "../model_la/"+FLAGS.model+"/" 26 | test_save_path = "../model_la/prediction/"+FLAGS.model+"_post/" 27 | if not os.path.exists(test_save_path): 28 | os.makedirs(test_save_path) 29 | 30 | num_classes = 2 31 | 32 | with open(FLAGS.root_path + '/../test.list', 'r') as f: 33 | image_list = f.readlines() 34 | image_list = [FLAGS.root_path +item.replace('\n', '')+"/mri_norm2.h5" for item in image_list] 35 | # print(image_list) 36 | 37 | def test_calculate_metric(epoch_num): 38 | net = VNetMultiHead(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False).cuda() 39 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth') 40 | net.load_state_dict(torch.load(save_mode_path)) 41 | print("init weight from {}".format(save_mode_path)) 42 | net.eval() 43 | 44 | avg_metric = dist_test_all_case(net, image_list, num_classes=num_classes, 45 | patch_size=(112, 112, 80), stride_xy=18, stride_z=4, 46 | save_result=True, test_save_path=test_save_path) 47 | 48 | return avg_metric 49 | 50 | def dist_test_all_case(net, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None): 51 | total_metric = 0.0 52 | metric_dict = OrderedDict() 53 | metric_dict['name'] = list() 54 | metric_dict['dice'] = list() 55 | metric_dict['jaccard'] = list() 56 | metric_dict['asd'] = list() 57 | metric_dict['95hd'] = list() 58 | for image_path in tqdm(image_list): 59 | case_name = image_path.split('/')[-2] 60 | id = image_path.split('/')[-1] 61 | h5f = h5py.File(image_path, 'r') 62 | image = h5f['image'][:] 63 | label = 
h5f['label'][:] 64 | if preproc_fn is not None: 65 | image = preproc_fn(image) 66 | prediction, score_map, pred_dist = test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 67 | 68 | if np.sum(prediction)==0: 69 | single_metric = (0,0,0,0) 70 | else: 71 | single_metric = calculate_metric_percase(prediction, label[:]) 72 | metric_dict['name'].append(case_name) 73 | metric_dict['dice'].append(single_metric[0]) 74 | metric_dict['jaccard'].append(single_metric[1]) 75 | metric_dict['asd'].append(single_metric[2]) 76 | metric_dict['95hd'].append(single_metric[3]) 77 | # print(metric_dict) 78 | 79 | 80 | total_metric += np.asarray(single_metric) 81 | 82 | if save_result: 83 | test_save_path_temp = os.path.join(test_save_path, case_name) 84 | if not os.path.exists(test_save_path_temp): 85 | os.makedirs(test_save_path_temp) 86 | nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_pred.nii.gz") 87 | nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_img.nii.gz") 88 | nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_gt.nii.gz") 89 | nib.save(nib.Nifti1Image(pred_dist[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_dist.nii.gz") 90 | avg_metric = total_metric / len(image_list) 91 | metric_csv = pd.DataFrame(metric_dict) 92 | metric_csv.to_csv(test_save_path + '/metric_'+str(FLAGS.epoch_num)+'.csv', index=False) 93 | print('average metric is {}'.format(avg_metric)) 94 | 95 | return avg_metric 96 | 97 | 98 | 99 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 100 | w, h, d = image.shape 101 | 102 | # if the size of image is less than patch_size, then padding it 103 | add_pad = False 104 | if w < patch_size[0]: 105 | w_pad = patch_size[0]-w 106 | add_pad = True 107 | else: 108 | w_pad = 0 109 | if h < patch_size[1]: 110 | h_pad = patch_size[1]-h 111 | add_pad = True 112 | else: 113 | h_pad = 0 114 | if d < patch_size[2]: 115 | d_pad = patch_size[2]-d 116 | add_pad = True 117 | else: 118 | d_pad = 0 119 | wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2 120 | hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2 121 | dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2 122 | if add_pad: 123 | image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) 124 | ww,hh,dd = image.shape 125 | 126 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 127 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 128 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 129 | # print("{}, {}, {}".format(sx, sy, sz)) 130 | score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) 131 | cnt = np.zeros(image.shape).astype(np.float32) 132 | pred_dist = np.zeros(image.shape).astype(np.float32) 133 | 134 | for x in range(0, sx): 135 | xs = min(stride_xy*x, ww-patch_size[0]) 136 | for y in range(0, sy): 137 | ys = min(stride_xy * y,hh-patch_size[1]) 138 | for z in range(0, sz): 139 | zs = min(stride_z * z, dd-patch_size[2]) 140 | test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] 141 | test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32) 142 | test_patch = torch.from_numpy(test_patch).cuda() 143 | y1, out_dist = net(test_patch) 144 | # print(y1.shape, out_dist.shape) # ([1, 2, 112, 112, 80]) ([1, 1, 112, 112, 80]) 145 | y = F.softmax(y1, dim=1) 146 | y = y.cpu().data.numpy() 147 | y = 
y[0,:,:,:,:] 148 | out_dist = out_dist.cpu().data.numpy() 149 | pred_dist[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 150 | = pred_dist[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + out_dist[0,0,:,:,:] 151 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 152 | = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 153 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 154 | = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 155 | 156 | 157 | score_map = score_map/np.expand_dims(cnt,axis=0) 158 | pred_dist = pred_dist/cnt 159 | label_map = np.argmax(score_map, axis = 0) 160 | if add_pad: 161 | label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 162 | score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 163 | pred_dist = pred_dist[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 164 | return label_map, score_map, pred_dist 165 | 166 | def cal_dice(prediction, label, num=2): 167 | total_dice = np.zeros(num-1) 168 | for i in range(1, num): 169 | prediction_tmp = (prediction==i) 170 | label_tmp = (label==i) 171 | prediction_tmp = prediction_tmp.astype(float) 172 | label_tmp = label_tmp.astype(float) 173 | 174 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp)) 175 | total_dice[i - 1] += dice 176 | 177 | return total_dice 178 | 179 | 180 | def calculate_metric_percase(pred, gt): 181 | dice = metric.binary.dc(pred, gt) 182 | jc = metric.binary.jc(pred, gt) 183 | hd = metric.binary.hd95(pred, gt) 184 | asd = metric.binary.asd(pred, gt) 185 | 186 | return dice, jc, asd, hd # order matches the callers: (dice, jaccard, asd, 95hd) 187 | 188 | 189 | if __name__ == '__main__': 190 | metric = test_calculate_metric(FLAGS.epoch_num) 191 | # print(metric) -------------------------------------------------------------------------------- /code/test_LA_MultiHead_SDF.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from networks.vnet_multi_head import VNetMultiHead 5 | import h5py 6 | import math 7 | import nibabel as nib 8 | import numpy as np 9 | from medpy import metric 10 | import torch.nn.functional as F 11 | from tqdm import tqdm 12 | import os 13 | import pandas as pd 14 | from collections import OrderedDict 15 | 16 | """ 17 | Testing 18 | Ref: 19 | Shape-Aware Organ Segmentation by Predicting Signed Distance Maps 20 | https://arxiv.org/abs/1912.03849 21 | """ 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--root_path', type=str, default='../data/2018LA_Seg_Training Set/', help='Name of Experiment') 25 | parser.add_argument('--model', type=str, default='vnet_dp_la_MH_SDFL1PlusL2', help='model_name') 26 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 27 | parser.add_argument('--epoch_num', type=int, default='6000', help='checkpoint to use') 28 | FLAGS = parser.parse_args() 29 | 30 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu 31 | snapshot_path = "../model_la/"+FLAGS.model+"/" 32 | test_save_path = "../model_la/prediction/"+FLAGS.model+"_post/" 33 | if not os.path.exists(test_save_path): 34 | os.makedirs(test_save_path) 35 | 36 | num_classes = 2 37 | 38 | with open(FLAGS.root_path + '/../test.list', 'r') as f: 39 | image_list = f.readlines() 40 | image_list = [FLAGS.root_path +item.replace('\n', '')+"/mri_norm2.h5" for item in image_list] 41 | # print(image_list) 42 | 43 | def 
test_calculate_metric(epoch_num): 44 | net = VNetMultiHead(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False).cuda() 45 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth') 46 | net.load_state_dict(torch.load(save_mode_path)) 47 | print("init weight from {}".format(save_mode_path)) 48 | net.eval() 49 | 50 | avg_metric = dist_test_all_case(net, image_list, num_classes=num_classes, 51 | patch_size=(112, 112, 80), stride_xy=18, stride_z=4, 52 | save_result=True, test_save_path=test_save_path) 53 | 54 | return avg_metric 55 | 56 | def dist_test_all_case(net, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None): 57 | total_metric = 0.0 58 | metric_dict = OrderedDict() 59 | metric_dict['name'] = list() 60 | metric_dict['dice'] = list() 61 | metric_dict['jaccard'] = list() 62 | metric_dict['asd'] = list() 63 | metric_dict['95hd'] = list() 64 | for image_path in tqdm(image_list): 65 | case_name = image_path.split('/')[-2] 66 | id = image_path.split('/')[-1] 67 | h5f = h5py.File(image_path, 'r') 68 | image = h5f['image'][:] 69 | label = h5f['label'][:] 70 | if preproc_fn is not None: 71 | image = preproc_fn(image) 72 | prediction, score_map, pred_dist = test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 73 | 74 | if np.sum(prediction)==0: 75 | single_metric = (0,0,0,0) 76 | else: 77 | single_metric = calculate_metric_percase(prediction, label[:]) 78 | metric_dict['name'].append(case_name) 79 | metric_dict['dice'].append(single_metric[0]) 80 | metric_dict['jaccard'].append(single_metric[1]) 81 | metric_dict['asd'].append(single_metric[2]) 82 | metric_dict['95hd'].append(single_metric[3]) 83 | # print(metric_dict) 84 | 85 | 86 | total_metric += np.asarray(single_metric) 87 | 88 | if save_result: 89 | test_save_path_temp = os.path.join(test_save_path, case_name) 90 | if not os.path.exists(test_save_path_temp): 91 | os.makedirs(test_save_path_temp) 92 | nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_pred.nii.gz") 93 | nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_img.nii.gz") 94 | nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_gt.nii.gz") 95 | nib.save(nib.Nifti1Image(pred_dist[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_dist.nii.gz") 96 | avg_metric = total_metric / len(image_list) 97 | metric_csv = pd.DataFrame(metric_dict) 98 | metric_csv.to_csv(test_save_path + '/metric_'+str(FLAGS.epoch_num)+'.csv', index=False) 99 | print('average metric is {}'.format(avg_metric)) 100 | 101 | return avg_metric 102 | 103 | 104 | 105 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 106 | w, h, d = image.shape 107 | 108 | # if the size of image is less than patch_size, then padding it 109 | add_pad = False 110 | if w < patch_size[0]: 111 | w_pad = patch_size[0]-w 112 | add_pad = True 113 | else: 114 | w_pad = 0 115 | if h < patch_size[1]: 116 | h_pad = patch_size[1]-h 117 | add_pad = True 118 | else: 119 | h_pad = 0 120 | if d < patch_size[2]: 121 | d_pad = patch_size[2]-d 122 | add_pad = True 123 | else: 124 | d_pad = 0 125 | wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2 126 | hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2 127 | dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2 128 | if add_pad: 129 | image = np.pad(image, 
[(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) 130 | ww,hh,dd = image.shape 131 | 132 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 133 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 134 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 135 | # print("{}, {}, {}".format(sx, sy, sz)) 136 | score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) 137 | cnt = np.zeros(image.shape).astype(np.float32) 138 | pred_dist = np.zeros(image.shape).astype(np.float32) 139 | 140 | for x in range(0, sx): 141 | xs = min(stride_xy*x, ww-patch_size[0]) 142 | for y in range(0, sy): 143 | ys = min(stride_xy * y,hh-patch_size[1]) 144 | for z in range(0, sz): 145 | zs = min(stride_z * z, dd-patch_size[2]) 146 | test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] 147 | test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32) 148 | test_patch = torch.from_numpy(test_patch).cuda() 149 | y1, out_dist = net(test_patch) 150 | # print(y1.shape, out_dist.shape) # ([1, 2, 112, 112, 80]) ([1, 1, 112, 112, 80]) 151 | y = F.softmax(y1, dim=1) 152 | y = y.cpu().data.numpy() 153 | y = y[0,:,:,:,:] 154 | out_dist = torch.tanh(out_dist) 155 | out_dist = out_dist.cpu().data.numpy() 156 | pred_dist[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 157 | = pred_dist[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + out_dist[0,0,:,:,:] 158 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 159 | = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 160 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 161 | = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 162 | 163 | 164 | score_map = score_map/np.expand_dims(cnt,axis=0) 165 | pred_dist = pred_dist/cnt 166 | label_map = np.argmax(score_map, axis = 0) 167 | if add_pad: 168 | label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 169 | score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 170 | pred_dist = pred_dist[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 171 | return label_map, score_map, pred_dist 172 | 173 | def cal_dice(prediction, label, num=2): 174 | total_dice = np.zeros(num-1) 175 | for i in range(1, num): 176 | prediction_tmp = (prediction==i) 177 | label_tmp = (label==i) 178 | prediction_tmp = prediction_tmp.astype(np.float) 179 | label_tmp = label_tmp.astype(np.float) 180 | 181 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp)) 182 | total_dice[i - 1] += dice 183 | 184 | return total_dice 185 | 186 | 187 | def calculate_metric_percase(pred, gt): 188 | dice = metric.binary.dc(pred, gt) 189 | jc = metric.binary.jc(pred, gt) 190 | hd = metric.binary.hd95(pred, gt) 191 | asd = metric.binary.asd(pred, gt) 192 | 193 | return dice, jc, hd, asd 194 | 195 | 196 | if __name__ == '__main__': 197 | metric = test_calculate_metric(FLAGS.epoch_num) 198 | # print(metric) -------------------------------------------------------------------------------- /code/test_LA_Rec_FGDTM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from networks.vnet_rec import VNetRec 5 | import h5py 6 | import math 7 | import nibabel as nib 8 | import numpy as np 9 | from medpy import metric 10 | import torch.nn.functional as F 11 | from tqdm import tqdm 12 | import os 13 
| import pandas as pd 14 | from collections import OrderedDict 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--root_path', type=str, default='../data/2018LA_Seg_Training Set/', help='Name of Experiment') 19 | parser.add_argument('--model', type=str, default='vnet_dp_la_Rec_FGDTM_L1PlusL2', help='model_name') 20 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 21 | parser.add_argument('--epoch_num', type=int, default='6000', help='checkpoint to use') 22 | FLAGS = parser.parse_args() 23 | 24 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu 25 | snapshot_path = "../model_la/"+FLAGS.model+"/" 26 | test_save_path = "../model_la/prediction/"+FLAGS.model+"_post/" 27 | if not os.path.exists(test_save_path): 28 | os.makedirs(test_save_path) 29 | 30 | num_classes = 2 31 | 32 | with open(FLAGS.root_path + '/../test.list', 'r') as f: 33 | image_list = f.readlines() 34 | image_list = [FLAGS.root_path +item.replace('\n', '')+"/mri_norm2.h5" for item in image_list] 35 | # print(image_list) 36 | 37 | def test_calculate_metric(epoch_num): 38 | net = VNetRec(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False).cuda() 39 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth') 40 | net.load_state_dict(torch.load(save_mode_path)) 41 | print("init weight from {}".format(save_mode_path)) 42 | net.eval() 43 | 44 | avg_metric = dist_test_all_case(net, image_list, num_classes=num_classes, 45 | patch_size=(112, 112, 80), stride_xy=18, stride_z=4, 46 | save_result=True, test_save_path=test_save_path) 47 | 48 | return avg_metric 49 | 50 | def dist_test_all_case(net, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None): 51 | total_metric = 0.0 52 | metric_dict = OrderedDict() 53 | metric_dict['name'] = list() 54 | metric_dict['dice'] = list() 55 | metric_dict['jaccard'] = list() 56 | metric_dict['asd'] = list() 57 | metric_dict['95hd'] = list() 58 | for image_path in tqdm(image_list): 59 | case_name = image_path.split('/')[-2] 60 | id = image_path.split('/')[-1] 61 | h5f = h5py.File(image_path, 'r') 62 | image = h5f['image'][:] 63 | label = h5f['label'][:] 64 | if preproc_fn is not None: 65 | image = preproc_fn(image) 66 | prediction, score_map, pred_dist = test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 67 | 68 | if np.sum(prediction)==0: 69 | single_metric = (0,0,0,0) 70 | else: 71 | single_metric = calculate_metric_percase(prediction, label[:]) 72 | metric_dict['name'].append(case_name) 73 | metric_dict['dice'].append(single_metric[0]) 74 | metric_dict['jaccard'].append(single_metric[1]) 75 | metric_dict['asd'].append(single_metric[2]) 76 | metric_dict['95hd'].append(single_metric[3]) 77 | # print(metric_dict) 78 | 79 | 80 | total_metric += np.asarray(single_metric) 81 | 82 | if save_result: 83 | test_save_path_temp = os.path.join(test_save_path, case_name) 84 | if not os.path.exists(test_save_path_temp): 85 | os.makedirs(test_save_path_temp) 86 | nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_pred.nii.gz") 87 | nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_img.nii.gz") 88 | nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_gt.nii.gz") 89 | nib.save(nib.Nifti1Image(pred_dist[:].astype(np.float32), np.eye(4)), test_save_path_temp + 
'/' + id + "_dist.nii.gz") 90 | avg_metric = total_metric / len(image_list) 91 | metric_csv = pd.DataFrame(metric_dict) 92 | metric_csv.to_csv(test_save_path + '/metric_'+str(FLAGS.epoch_num)+'.csv', index=False) 93 | print('average metric is {}'.format(avg_metric)) 94 | 95 | return avg_metric 96 | 97 | 98 | 99 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 100 | w, h, d = image.shape 101 | 102 | # if the size of image is less than patch_size, then padding it 103 | add_pad = False 104 | if w < patch_size[0]: 105 | w_pad = patch_size[0]-w 106 | add_pad = True 107 | else: 108 | w_pad = 0 109 | if h < patch_size[1]: 110 | h_pad = patch_size[1]-h 111 | add_pad = True 112 | else: 113 | h_pad = 0 114 | if d < patch_size[2]: 115 | d_pad = patch_size[2]-d 116 | add_pad = True 117 | else: 118 | d_pad = 0 119 | wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2 120 | hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2 121 | dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2 122 | if add_pad: 123 | image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) 124 | ww,hh,dd = image.shape 125 | 126 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 127 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 128 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 129 | # print("{}, {}, {}".format(sx, sy, sz)) 130 | score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) 131 | cnt = np.zeros(image.shape).astype(np.float32) 132 | pred_dist = np.zeros(image.shape).astype(np.float32) 133 | 134 | for x in range(0, sx): 135 | xs = min(stride_xy*x, ww-patch_size[0]) 136 | for y in range(0, sy): 137 | ys = min(stride_xy * y,hh-patch_size[1]) 138 | for z in range(0, sz): 139 | zs = min(stride_z * z, dd-patch_size[2]) 140 | test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] 141 | test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32) 142 | test_patch = torch.from_numpy(test_patch).cuda() 143 | y1, out_dist = net(test_patch) 144 | # print(y1.shape, out_dist.shape) # ([1, 2, 112, 112, 80]) ([1, 1, 112, 112, 80]) 145 | y = F.softmax(y1, dim=1) 146 | y = y.cpu().data.numpy() 147 | y = y[0,:,:,:,:] 148 | out_dist = out_dist.cpu().data.numpy() 149 | pred_dist[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 150 | = pred_dist[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + out_dist[0,0,:,:,:] 151 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 152 | = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 153 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 154 | = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 155 | 156 | 157 | score_map = score_map/np.expand_dims(cnt,axis=0) 158 | pred_dist = pred_dist/cnt 159 | label_map = np.argmax(score_map, axis = 0) 160 | if add_pad: 161 | label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 162 | score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 163 | pred_dist = pred_dist[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 164 | return label_map, score_map, pred_dist 165 | 166 | def cal_dice(prediction, label, num=2): 167 | total_dice = np.zeros(num-1) 168 | for i in range(1, num): 169 | prediction_tmp = (prediction==i) 170 | label_tmp = (label==i) 171 | prediction_tmp = prediction_tmp.astype(np.float) 172 | label_tmp = label_tmp.astype(np.float) 173 | 174 | dice = 2 * 
np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp)) 175 | total_dice[i - 1] += dice 176 | 177 | return total_dice 178 | 179 | 180 | def calculate_metric_percase(pred, gt): 181 | dice = metric.binary.dc(pred, gt) 182 | jc = metric.binary.jc(pred, gt) 183 | hd = metric.binary.hd95(pred, gt) 184 | asd = metric.binary.asd(pred, gt) 185 | 186 | return dice, jc, hd, asd 187 | 188 | 189 | if __name__ == '__main__': 190 | metric = test_calculate_metric(FLAGS.epoch_num) 191 | # print(metric) -------------------------------------------------------------------------------- /code/test_LA_Rec_SDF.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from networks.vnet_rec import VNetRec 5 | import h5py 6 | import math 7 | import nibabel as nib 8 | import numpy as np 9 | from medpy import metric 10 | import torch.nn.functional as F 11 | from tqdm import tqdm 12 | import os 13 | import pandas as pd 14 | from collections import OrderedDict 15 | 16 | """ 17 | Testing 18 | Adding reconstruction branch to V-Net 19 | Ref: 20 | A Distance Map Regularized CNN for Cardiac Cine MR Image Segmentation 21 | https://arxiv.org/abs/1901.01238 22 | """ 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--root_path', type=str, default='../data/2018LA_Seg_Training Set/', help='Name of Experiment') 26 | parser.add_argument('--model', type=str, default='vnet_dp_la_Rec_SDF_L1PlusL2', help='model_name') 27 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 28 | parser.add_argument('--epoch_num', type=int, default='6000', help='checkpoint to use') 29 | FLAGS = parser.parse_args() 30 | 31 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu 32 | snapshot_path = "../model_la/"+FLAGS.model+"/" 33 | test_save_path = "../model_la/prediction/"+FLAGS.model+"_post/" 34 | if not os.path.exists(test_save_path): 35 | os.makedirs(test_save_path) 36 | 37 | num_classes = 2 38 | 39 | with open(FLAGS.root_path + '/../test.list', 'r') as f: 40 | image_list = f.readlines() 41 | image_list = [FLAGS.root_path +item.replace('\n', '')+"/mri_norm2.h5" for item in image_list] 42 | # print(image_list) 43 | 44 | def test_calculate_metric(epoch_num): 45 | net = VNetRec(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False).cuda() 46 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth') 47 | net.load_state_dict(torch.load(save_mode_path)) 48 | print("init weight from {}".format(save_mode_path)) 49 | net.eval() 50 | 51 | avg_metric = dist_test_all_case(net, image_list, num_classes=num_classes, 52 | patch_size=(112, 112, 80), stride_xy=18, stride_z=4, 53 | save_result=True, test_save_path=test_save_path) 54 | 55 | return avg_metric 56 | 57 | def dist_test_all_case(net, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None): 58 | total_metric = 0.0 59 | metric_dict = OrderedDict() 60 | metric_dict['name'] = list() 61 | metric_dict['dice'] = list() 62 | metric_dict['jaccard'] = list() 63 | metric_dict['asd'] = list() 64 | metric_dict['95hd'] = list() 65 | for image_path in tqdm(image_list): 66 | case_name = image_path.split('/')[-2] 67 | id = image_path.split('/')[-1] 68 | h5f = h5py.File(image_path, 'r') 69 | image = h5f['image'][:] 70 | label = h5f['label'][:] 71 | if preproc_fn is not None: 72 | image = preproc_fn(image) 73 | prediction, score_map, pred_dist = 
test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 74 | 75 | if np.sum(prediction)==0: 76 | single_metric = (0,0,0,0) 77 | else: 78 | single_metric = calculate_metric_percase(prediction, label[:]) 79 | metric_dict['name'].append(case_name) 80 | metric_dict['dice'].append(single_metric[0]) 81 | metric_dict['jaccard'].append(single_metric[1]) 82 | metric_dict['asd'].append(single_metric[2]) 83 | metric_dict['95hd'].append(single_metric[3]) 84 | # print(metric_dict) 85 | 86 | 87 | total_metric += np.asarray(single_metric) 88 | 89 | if save_result: 90 | test_save_path_temp = os.path.join(test_save_path, case_name) 91 | if not os.path.exists(test_save_path_temp): 92 | os.makedirs(test_save_path_temp) 93 | nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_pred.nii.gz") 94 | nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_img.nii.gz") 95 | nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_gt.nii.gz") 96 | nib.save(nib.Nifti1Image(pred_dist[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_dist.nii.gz") 97 | avg_metric = total_metric / len(image_list) 98 | metric_csv = pd.DataFrame(metric_dict) 99 | metric_csv.to_csv(test_save_path + '/metric_'+str(FLAGS.epoch_num)+'.csv', index=False) 100 | print('average metric is {}'.format(avg_metric)) 101 | 102 | return avg_metric 103 | 104 | 105 | 106 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 107 | w, h, d = image.shape 108 | 109 | # if the size of image is less than patch_size, then padding it 110 | add_pad = False 111 | if w < patch_size[0]: 112 | w_pad = patch_size[0]-w 113 | add_pad = True 114 | else: 115 | w_pad = 0 116 | if h < patch_size[1]: 117 | h_pad = patch_size[1]-h 118 | add_pad = True 119 | else: 120 | h_pad = 0 121 | if d < patch_size[2]: 122 | d_pad = patch_size[2]-d 123 | add_pad = True 124 | else: 125 | d_pad = 0 126 | wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2 127 | hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2 128 | dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2 129 | if add_pad: 130 | image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) 131 | ww,hh,dd = image.shape 132 | 133 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 134 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 135 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 136 | # print("{}, {}, {}".format(sx, sy, sz)) 137 | score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) 138 | cnt = np.zeros(image.shape).astype(np.float32) 139 | pred_dist = np.zeros(image.shape).astype(np.float32) 140 | 141 | for x in range(0, sx): 142 | xs = min(stride_xy*x, ww-patch_size[0]) 143 | for y in range(0, sy): 144 | ys = min(stride_xy * y,hh-patch_size[1]) 145 | for z in range(0, sz): 146 | zs = min(stride_z * z, dd-patch_size[2]) 147 | test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] 148 | test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32) 149 | test_patch = torch.from_numpy(test_patch).cuda() 150 | y1, out_dist = net(test_patch) 151 | # print(y1.shape, out_dist.shape) # ([1, 2, 112, 112, 80]) ([1, 1, 112, 112, 80]) 152 | y = F.softmax(y1, dim=1) 153 | y = y.cpu().data.numpy() 154 | y = y[0,:,:,:,:] 155 | out_dist = torch.tanh(out_dist) 156 | out_dist = out_dist.cpu().data.numpy() 157 | 
pred_dist[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 158 | = pred_dist[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + out_dist[0,0,:,:,:] 159 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 160 | = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 161 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 162 | = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 163 | 164 | 165 | score_map = score_map/np.expand_dims(cnt,axis=0) 166 | pred_dist = pred_dist/cnt 167 | label_map = np.argmax(score_map, axis = 0) 168 | if add_pad: 169 | label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 170 | score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 171 | pred_dist = pred_dist[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 172 | return label_map, score_map, pred_dist 173 | 174 | def cal_dice(prediction, label, num=2): 175 | total_dice = np.zeros(num-1) 176 | for i in range(1, num): 177 | prediction_tmp = (prediction==i) 178 | label_tmp = (label==i) 179 | prediction_tmp = prediction_tmp.astype(np.float) 180 | label_tmp = label_tmp.astype(np.float) 181 | 182 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp)) 183 | total_dice[i - 1] += dice 184 | 185 | return total_dice 186 | 187 | 188 | def calculate_metric_percase(pred, gt): 189 | dice = metric.binary.dc(pred, gt) 190 | jc = metric.binary.jc(pred, gt) 191 | hd = metric.binary.hd95(pred, gt) 192 | asd = metric.binary.asd(pred, gt) 193 | 194 | return dice, jc, hd, asd 195 | 196 | 197 | if __name__ == '__main__': 198 | metric = test_calculate_metric(FLAGS.epoch_num) 199 | # print(metric) -------------------------------------------------------------------------------- /code/test_LITS.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import argparse 4 | import torch 5 | from networks.vnet import VNet 6 | import h5py 7 | import math 8 | import nibabel as nib 9 | import numpy as np 10 | from medpy import metric 11 | import torch.nn.functional as F 12 | from tqdm import tqdm 13 | import pandas as pd 14 | from collections import OrderedDict 15 | import SimpleITK as sitk 16 | # import pdb 17 | 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--root_path', type=str, default='../data/LITS', help='Name of Experiment') 21 | parser.add_argument('--model', type=str, default='vnet_supervisedonly_dp', help='model_name') 22 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 23 | FLAGS = parser.parse_args() 24 | 25 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu 26 | snapshot_path = "../model_lits/"+FLAGS.model+"/" 27 | test_save_path = "../model_lits/prediction/"+FLAGS.model+"_post/" 28 | if not os.path.exists(test_save_path): 29 | os.makedirs(test_save_path) 30 | 31 | num_classes = 2 32 | 33 | with open(FLAGS.root_path + '/LITS_test.list', 'r') as f: 34 | image_list = f.readlines() 35 | image_list = [FLAGS.root_path +'/h5/' + item.replace('\n', '') for item in image_list] 36 | # print(image_list) 37 | 38 | def test_calculate_metric(epoch_num): 39 | net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False).cuda() 40 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth') 41 | net.load_state_dict(torch.load(save_mode_path)) 42 | print("init weight from {}".format(save_mode_path)) 
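# Hedged sketch, not part of the original script: torch.load() above assumes the
# checkpoint was saved from a bare single-GPU model. If a checkpoint had instead
# been saved from an nn.DataParallel wrapper, every key would carry a 'module.'
# prefix and load_state_dict() would raise; a tolerant load would look like:
#     state = torch.load(save_mode_path, map_location='cuda')
#     state = {k.replace('module.', '', 1): v for k, v in state.items()}
#     net.load_state_dict(state)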
43 | net.eval() 44 | 45 | avg_metric = test_all_case(net, image_list, num_classes=num_classes, 46 | patch_size=(96, 128, 160), stride_xy=18, stride_z=4, 47 | save_result=True, test_save_path=test_save_path) 48 | 49 | return avg_metric 50 | 51 | 52 | 53 | def test_all_case(net, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None): 54 | total_metric = 0.0 55 | metric_dict = OrderedDict() 56 | metric_dict['name'] = list() 57 | metric_dict['dice'] = list() 58 | metric_dict['jaccard'] = list() 59 | metric_dict['asd'] = list() 60 | metric_dict['95hd'] = list() 61 | for image_path in tqdm(image_list): 62 | case_name = image_path.split('/')[-1] 63 | id = re.findall('\d+', case_name)[0] 64 | h5f = h5py.File(image_path, 'r') 65 | image = h5f['image'][:] 66 | label = h5f['label'][:] 67 | if preproc_fn is not None: 68 | image = preproc_fn(image) 69 | prediction, score_map = test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 70 | 71 | if np.sum(prediction)==0: 72 | single_metric = (0,0,0,0) 73 | else: 74 | single_metric = calculate_metric_percase(prediction, label[:]) 75 | metric_dict['name'].append(case_name) 76 | metric_dict['dice'].append(single_metric[0]) 77 | metric_dict['jaccard'].append(single_metric[1]) 78 | metric_dict['asd'].append(single_metric[2]) 79 | metric_dict['95hd'].append(single_metric[3]) 80 | # print(metric_dict) 81 | 82 | 83 | total_metric += np.asarray(single_metric) 84 | 85 | if save_result: 86 | test_save_path_temp = os.path.join(test_save_path, case_name.split('.h5')[0]) 87 | if not os.path.exists(test_save_path_temp): 88 | os.makedirs(test_save_path_temp) 89 | sitk.WriteImage(sitk.GetImageFromArray(prediction.astype(np.float32)), test_save_path_temp + '/' + id + "_pred.nii.gz") 90 | sitk.WriteImage(sitk.GetImageFromArray(image[:].astype(np.float32)), test_save_path_temp + '/' + id + "_img.nii.gz") 91 | sitk.WriteImage(sitk.GetImageFromArray(label[:].astype(np.float32)), test_save_path_temp + '/' + id + "_gt.nii.gz") 92 | avg_metric = total_metric / len(image_list) 93 | metric_csv = pd.DataFrame(metric_dict) 94 | metric_csv.to_csv(test_save_path + '/metric.csv', index=False) 95 | print('average metric is {}'.format(avg_metric)) 96 | 97 | return avg_metric 98 | 99 | 100 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 101 | w, h, d = image.shape 102 | 103 | # if the size of image is less than patch_size, then padding it 104 | add_pad = False 105 | if w < patch_size[0]: 106 | w_pad = patch_size[0]-w 107 | add_pad = True 108 | else: 109 | w_pad = 0 110 | if h < patch_size[1]: 111 | h_pad = patch_size[1]-h 112 | add_pad = True 113 | else: 114 | h_pad = 0 115 | if d < patch_size[2]: 116 | d_pad = patch_size[2]-d 117 | add_pad = True 118 | else: 119 | d_pad = 0 120 | wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2 121 | hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2 122 | dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2 123 | if add_pad: 124 | image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) 125 | ww,hh,dd = image.shape 126 | 127 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 128 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 129 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 130 | # print("{}, {}, {}".format(sx, sy, sz)) 131 | score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) 132 | cnt = np.zeros(image.shape).astype(np.float32) 133 | 134 | for x in 
range(0, sx): 135 | xs = min(stride_xy*x, ww-patch_size[0]) 136 | for y in range(0, sy): 137 | ys = min(stride_xy * y,hh-patch_size[1]) 138 | for z in range(0, sz): 139 | zs = min(stride_z * z, dd-patch_size[2]) 140 | test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] 141 | test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32) 142 | test_patch = torch.from_numpy(test_patch).cuda() 143 | y1 = net(test_patch) 144 | y = F.softmax(y1, dim=1) 145 | y = y.cpu().data.numpy() 146 | y = y[0,:,:,:,:] 147 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 148 | = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 149 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 150 | = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 151 | score_map = score_map/np.expand_dims(cnt,axis=0) 152 | label_map = np.argmax(score_map, axis = 0) 153 | if add_pad: 154 | label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 155 | score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 156 | return label_map, score_map 157 | 158 | def cal_dice(prediction, label, num=2): 159 | total_dice = np.zeros(num-1) 160 | for i in range(1, num): 161 | prediction_tmp = (prediction==i) 162 | label_tmp = (label==i) 163 | prediction_tmp = prediction_tmp.astype(np.float) 164 | label_tmp = label_tmp.astype(np.float) 165 | 166 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp)) 167 | total_dice[i - 1] += dice 168 | 169 | return total_dice 170 | 171 | 172 | def calculate_metric_percase(pred, gt): 173 | dice = metric.binary.dc(pred, gt) 174 | jc = metric.binary.jc(pred, gt) 175 | hd = metric.binary.hd95(pred, gt) 176 | asd = metric.binary.asd(pred, gt) 177 | 178 | return dice, jc, hd, asd 179 | 180 | 181 | 182 | 183 | 184 | if __name__ == '__main__': 185 | metric = test_calculate_metric(4000) 186 | # print(metric) 187 | -------------------------------------------------------------------------------- /code/test_LITS_MultiHead_SDF.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import argparse 4 | import torch 5 | from networks.vnet_multi_head import VNetMultiHead 6 | import h5py 7 | import math 8 | import nibabel as nib 9 | import numpy as np 10 | from medpy import metric 11 | import torch.nn.functional as F 12 | from tqdm import tqdm 13 | import pandas as pd 14 | from collections import OrderedDict 15 | 16 | """ 17 | Testing 18 | Ref: 19 | Shape-Aware Organ Segmentation by Predicting Signed Distance Maps 20 | https://arxiv.org/abs/1912.03849 21 | """ 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--root_path', type=str, default='../data/LITS', help='Name of Experiment') 25 | parser.add_argument('--model', type=str, default='vnet_lits_MH_SDFL1_lr01', help='model_name') 26 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 27 | parser.add_argument('--epoch_num', type=int, default='10000', help='checkpoint to use') 28 | FLAGS = parser.parse_args() 29 | 30 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu 31 | snapshot_path = "../model_lits/"+FLAGS.model+"/" 32 | test_save_path = "../model_lits/prediction/"+FLAGS.model+"_post/" 33 | if not os.path.exists(test_save_path): 34 | os.makedirs(test_save_path) 35 | 36 | num_classes = 2 37 | 38 | with open(FLAGS.root_path + '/LITS_test.list', 'r') as f: 39 
| image_list = f.readlines() 40 | image_list = [FLAGS.root_path +'/h5/' + item.replace('\n', '') for item in image_list] 41 | 42 | 43 | def test_calculate_metric(epoch_num): 44 | net = VNetMultiHead(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False).cuda() 45 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth') 46 | net.load_state_dict(torch.load(save_mode_path)) 47 | print("init weight from {}".format(save_mode_path)) 48 | net.eval() 49 | 50 | avg_metric = dist_test_all_case(net, image_list, num_classes=num_classes, 51 | patch_size=(96, 128, 160), stride_xy=18, stride_z=4, 52 | save_result=True, test_save_path=test_save_path) 53 | 54 | return avg_metric 55 | 56 | def dist_test_all_case(net, image_list, num_classes, patch_size=(96, 128, 160), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None): 57 | total_metric = 0.0 58 | metric_dict = OrderedDict() 59 | metric_dict['name'] = list() 60 | metric_dict['dice'] = list() 61 | metric_dict['jaccard'] = list() 62 | metric_dict['asd'] = list() 63 | metric_dict['95hd'] = list() 64 | for image_path in tqdm(image_list): 65 | case_name = image_path.split('/')[-1].split('.h5')[0] # format: train_id 66 | id = re.findall('\d+', case_name)[0] 67 | h5f = h5py.File(image_path, 'r') 68 | image = h5f['image'][:] 69 | label = h5f['label'][:] 70 | if preproc_fn is not None: 71 | image = preproc_fn(image) 72 | prediction, score_map, pred_dist = test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 73 | 74 | if np.sum(prediction)==0: 75 | single_metric = (0,0,0,0) 76 | else: 77 | single_metric = calculate_metric_percase(prediction, label[:]) 78 | metric_dict['name'].append(case_name) 79 | metric_dict['dice'].append(single_metric[0]) 80 | metric_dict['jaccard'].append(single_metric[1]) 81 | metric_dict['asd'].append(single_metric[2]) 82 | metric_dict['95hd'].append(single_metric[3]) 83 | # print(metric_dict) 84 | 85 | 86 | total_metric += np.asarray(single_metric) 87 | 88 | if save_result: 89 | test_save_path_temp = os.path.join(test_save_path, case_name) 90 | if not os.path.exists(test_save_path_temp): 91 | os.makedirs(test_save_path_temp) 92 | # for simplification, we set the spacing as (1,1,1) which is not the original image spacing. It has no effect on metrics. 
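# Hedged sketch, not part of the original script: the identity affine np.eye(4)
# used below is safe here because medpy's hd95/asd are called without
# `voxelspacing`, so all metrics are reported in voxel units. If the true CT
# spacing were available, e.g. as an assumed tuple spacing = (sx, sy, sz) in mm,
# the saved volumes could keep it through the NIfTI affine:
#     affine = np.diag([spacing[0], spacing[1], spacing[2], 1.0])
#     nib.save(nib.Nifti1Image(prediction.astype(np.float32), affine), out_path)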
93 | nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_pred.nii.gz") 94 | nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_img.nii.gz") 95 | nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_gt.nii.gz") 96 | nib.save(nib.Nifti1Image(pred_dist[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_dist.nii.gz") 97 | avg_metric = total_metric / len(image_list) 98 | metric_csv = pd.DataFrame(metric_dict) 99 | metric_csv.to_csv(test_save_path + '/metric_'+str(FLAGS.epoch_num)+'.csv', index=False) 100 | print('average metric is {}'.format(avg_metric)) 101 | 102 | return avg_metric 103 | 104 | 105 | 106 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 107 | w, h, d = image.shape 108 | 109 | # if the size of image is less than patch_size, then padding it 110 | add_pad = False 111 | if w < patch_size[0]: 112 | w_pad = patch_size[0]-w 113 | add_pad = True 114 | else: 115 | w_pad = 0 116 | if h < patch_size[1]: 117 | h_pad = patch_size[1]-h 118 | add_pad = True 119 | else: 120 | h_pad = 0 121 | if d < patch_size[2]: 122 | d_pad = patch_size[2]-d 123 | add_pad = True 124 | else: 125 | d_pad = 0 126 | wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2 127 | hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2 128 | dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2 129 | if add_pad: 130 | image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) 131 | ww,hh,dd = image.shape 132 | 133 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 134 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 135 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 136 | # print("{}, {}, {}".format(sx, sy, sz)) 137 | score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) 138 | cnt = np.zeros(image.shape).astype(np.float32) 139 | pred_dist = np.zeros(image.shape).astype(np.float32) 140 | 141 | for x in range(0, sx): 142 | xs = min(stride_xy*x, ww-patch_size[0]) 143 | for y in range(0, sy): 144 | ys = min(stride_xy * y,hh-patch_size[1]) 145 | for z in range(0, sz): 146 | zs = min(stride_z * z, dd-patch_size[2]) 147 | test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] 148 | test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32) 149 | test_patch = torch.from_numpy(test_patch).cuda() 150 | y1, out_dist = net(test_patch) 151 | # print(y1.shape, out_dist.shape) 152 | y = F.softmax(y1, dim=1) 153 | y = y.cpu().data.numpy() 154 | y = y[0,:,:,:,:] 155 | out_dist = torch.tanh(out_dist) 156 | out_dist = out_dist.cpu().data.numpy() 157 | pred_dist[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 158 | = pred_dist[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + out_dist[0,0,:,:,:] 159 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 160 | = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 161 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 162 | = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 163 | 164 | 165 | score_map = score_map/np.expand_dims(cnt,axis=0) 166 | pred_dist = pred_dist/cnt 167 | label_map = np.argmax(score_map, axis = 0) 168 | if add_pad: 169 | label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 170 | score_map = 
score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 171 | pred_dist = pred_dist[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 172 | return label_map, score_map, pred_dist 173 | 174 | def cal_dice(prediction, label, num=2): 175 | total_dice = np.zeros(num-1) 176 | for i in range(1, num): 177 | prediction_tmp = (prediction==i) 178 | label_tmp = (label==i) 179 | prediction_tmp = prediction_tmp.astype(np.float) 180 | label_tmp = label_tmp.astype(np.float) 181 | 182 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp)) 183 | total_dice[i - 1] += dice 184 | 185 | return total_dice 186 | 187 | 188 | def calculate_metric_percase(pred, gt): 189 | dice = metric.binary.dc(pred, gt) 190 | jc = metric.binary.jc(pred, gt) 191 | hd = metric.binary.hd95(pred, gt) 192 | asd = metric.binary.asd(pred, gt) 193 | 194 | return dice, jc, hd, asd 195 | 196 | 197 | if __name__ == '__main__': 198 | metric = test_calculate_metric(FLAGS.epoch_num) 199 | # print(metric) -------------------------------------------------------------------------------- /code/test_LITS_Rec_SDF.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import argparse 4 | import torch 5 | from networks.vnet_rec import VNetRec 6 | import h5py 7 | import math 8 | import nibabel as nib 9 | import numpy as np 10 | from medpy import metric 11 | import torch.nn.functional as F 12 | from tqdm import tqdm 13 | import pandas as pd 14 | from collections import OrderedDict 15 | 16 | """ 17 | Testing 18 | Adding reconstruction branch to V-Net 19 | Ref: 20 | A Distance Map Regularized CNN for Cardiac Cine MR Image Segmentation 21 | https://arxiv.org/abs/1901.01238 22 | """ 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--root_path', type=str, default='../data/LITS', help='Name of Experiment') 26 | parser.add_argument('--model', type=str, default='vnet_lits_Rec_SDF_L1PlusL2_lr01', help='model_name') 27 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 28 | parser.add_argument('--epoch_num', type=int, default='10000', help='checkpoint to use') 29 | FLAGS = parser.parse_args() 30 | 31 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu 32 | snapshot_path = "../model_lits/"+FLAGS.model+"/" 33 | test_save_path = "../model_lits/prediction/"+FLAGS.model+"_post/" 34 | if not os.path.exists(test_save_path): 35 | os.makedirs(test_save_path) 36 | 37 | num_classes = 2 38 | 39 | with open(FLAGS.root_path + '/LITS_test.list', 'r') as f: 40 | image_list = f.readlines() 41 | image_list = [FLAGS.root_path +'/h5/' + item.replace('\n', '') for item in image_list] 42 | 43 | 44 | def test_calculate_metric(epoch_num): 45 | net = VNetRec(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False).cuda() 46 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth') 47 | net.load_state_dict(torch.load(save_mode_path)) 48 | print("init weight from {}".format(save_mode_path)) 49 | net.eval() 50 | 51 | avg_metric = dist_test_all_case(net, image_list, num_classes=num_classes, 52 | patch_size=(96, 128, 160), stride_xy=18, stride_z=4, 53 | save_result=True, test_save_path=test_save_path) 54 | 55 | return avg_metric 56 | 57 | def dist_test_all_case(net, image_list, num_classes, patch_size=(96, 128, 160), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None): 58 | total_metric = 0.0 59 | metric_dict = OrderedDict() 60 | metric_dict['name'] = list() 61 | 
metric_dict['dice'] = list() 62 | metric_dict['jaccard'] = list() 63 | metric_dict['asd'] = list() 64 | metric_dict['95hd'] = list() 65 | for image_path in tqdm(image_list): 66 | case_name = image_path.split('/')[-1].split('.h5')[0] # format: train_id 67 | id = re.findall('\d+', case_name)[0] 68 | h5f = h5py.File(image_path, 'r') 69 | image = h5f['image'][:] 70 | label = h5f['label'][:] 71 | if preproc_fn is not None: 72 | image = preproc_fn(image) 73 | prediction, score_map, pred_dist = test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 74 | 75 | if np.sum(prediction)==0: 76 | single_metric = (0,0,0,0) 77 | else: 78 | single_metric = calculate_metric_percase(prediction, label[:]) 79 | metric_dict['name'].append(case_name) 80 | metric_dict['dice'].append(single_metric[0]) 81 | metric_dict['jaccard'].append(single_metric[1]) 82 | metric_dict['asd'].append(single_metric[2]) 83 | metric_dict['95hd'].append(single_metric[3]) 84 | # print(metric_dict) 85 | 86 | 87 | total_metric += np.asarray(single_metric) 88 | 89 | if save_result: 90 | test_save_path_temp = os.path.join(test_save_path, case_name) 91 | if not os.path.exists(test_save_path_temp): 92 | os.makedirs(test_save_path_temp) 93 | nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_pred.nii.gz") 94 | nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_img.nii.gz") 95 | nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_gt.nii.gz") 96 | nib.save(nib.Nifti1Image(pred_dist[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_dist.nii.gz") 97 | avg_metric = total_metric / len(image_list) 98 | metric_csv = pd.DataFrame(metric_dict) 99 | metric_csv.to_csv(test_save_path + '/metric_'+str(FLAGS.epoch_num)+'.csv', index=False) 100 | print('average metric is {}'.format(avg_metric)) 101 | 102 | return avg_metric 103 | 104 | 105 | 106 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 107 | w, h, d = image.shape 108 | 109 | # if the size of image is less than patch_size, then padding it 110 | add_pad = False 111 | if w < patch_size[0]: 112 | w_pad = patch_size[0]-w 113 | add_pad = True 114 | else: 115 | w_pad = 0 116 | if h < patch_size[1]: 117 | h_pad = patch_size[1]-h 118 | add_pad = True 119 | else: 120 | h_pad = 0 121 | if d < patch_size[2]: 122 | d_pad = patch_size[2]-d 123 | add_pad = True 124 | else: 125 | d_pad = 0 126 | wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2 127 | hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2 128 | dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2 129 | if add_pad: 130 | image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) 131 | ww,hh,dd = image.shape 132 | 133 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 134 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 135 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 136 | # print("{}, {}, {}".format(sx, sy, sz)) 137 | score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) 138 | cnt = np.zeros(image.shape).astype(np.float32) 139 | pred_dist = np.zeros(image.shape).astype(np.float32) 140 | 141 | for x in range(0, sx): 142 | xs = min(stride_xy*x, ww-patch_size[0]) 143 | for y in range(0, sy): 144 | ys = min(stride_xy * y,hh-patch_size[1]) 145 | for z in range(0, sz): 146 | zs = min(stride_z * z, dd-patch_size[2]) 147 | test_patch = image[xs:xs+patch_size[0], 
ys:ys+patch_size[1], zs:zs+patch_size[2]] 148 | test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32) 149 | test_patch = torch.from_numpy(test_patch).cuda() 150 | y1, out_dist = net(test_patch) 151 | # print(y1.shape, out_dist.shape) # ([1, 2, 112, 112, 80]) ([1, 1, 112, 112, 80]) 152 | y = F.softmax(y1, dim=1) 153 | y = y.cpu().data.numpy() 154 | y = y[0,:,:,:,:] 155 | out_dist = torch.tanh(out_dist) 156 | out_dist = out_dist.cpu().data.numpy() 157 | pred_dist[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 158 | = pred_dist[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + out_dist[0,0,:,:,:] 159 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 160 | = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 161 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 162 | = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 163 | 164 | 165 | score_map = score_map/np.expand_dims(cnt,axis=0) 166 | pred_dist = pred_dist/cnt 167 | label_map = np.argmax(score_map, axis = 0) 168 | if add_pad: 169 | label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 170 | score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 171 | pred_dist = pred_dist[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 172 | return label_map, score_map, pred_dist 173 | 174 | def cal_dice(prediction, label, num=2): 175 | total_dice = np.zeros(num-1) 176 | for i in range(1, num): 177 | prediction_tmp = (prediction==i) 178 | label_tmp = (label==i) 179 | prediction_tmp = prediction_tmp.astype(np.float) 180 | label_tmp = label_tmp.astype(np.float) 181 | 182 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp)) 183 | total_dice[i - 1] += dice 184 | 185 | return total_dice 186 | 187 | 188 | def calculate_metric_percase(pred, gt): 189 | dice = metric.binary.dc(pred, gt) 190 | jc = metric.binary.jc(pred, gt) 191 | hd = metric.binary.hd95(pred, gt) 192 | asd = metric.binary.asd(pred, gt) 193 | 194 | return dice, jc, hd, asd 195 | 196 | 197 | if __name__ == '__main__': 198 | metric = test_calculate_metric(FLAGS.epoch_num) 199 | # print(metric) -------------------------------------------------------------------------------- /code/test_util.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import math 3 | import nibabel as nib 4 | import numpy as np 5 | from medpy import metric 6 | import torch 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import os 10 | import pandas as pd 11 | from collections import OrderedDict 12 | 13 | def test_all_case(net, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None): 14 | total_metric = 0.0 15 | metric_dict = OrderedDict() 16 | metric_dict['name'] = list() 17 | metric_dict['dice'] = list() 18 | metric_dict['jaccard'] = list() 19 | metric_dict['asd'] = list() 20 | metric_dict['95hd'] = list() 21 | for image_path in tqdm(image_list): 22 | case_name = image_path.split('/')[-2] 23 | id = image_path.split('/')[-1] 24 | h5f = h5py.File(image_path, 'r') 25 | image = h5f['image'][:] 26 | label = h5f['label'][:] 27 | if preproc_fn is not None: 28 | image = preproc_fn(image) 29 | prediction, score_map = test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 30 | 31 | if np.sum(prediction)==0: 32 | 
single_metric = (0,0,0,0) 33 | else: 34 | single_metric = calculate_metric_percase(prediction, label[:]) 35 | metric_dict['name'].append(case_name) 36 | metric_dict['dice'].append(single_metric[0]) 37 | metric_dict['jaccard'].append(single_metric[1]) 38 | metric_dict['asd'].append(single_metric[2]) 39 | metric_dict['95hd'].append(single_metric[3]) 40 | # print(metric_dict) 41 | 42 | 43 | total_metric += np.asarray(single_metric) 44 | 45 | if save_result: 46 | test_save_path_temp = os.path.join(test_save_path, case_name) 47 | if not os.path.exists(test_save_path_temp): 48 | os.makedirs(test_save_path_temp) 49 | nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_pred.nii.gz") 50 | nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_img.nii.gz") 51 | nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path_temp + '/' + id + "_gt.nii.gz") 52 | avg_metric = total_metric / len(image_list) 53 | metric_csv = pd.DataFrame(metric_dict) 54 | metric_csv.to_csv(test_save_path + '/metric.csv', index=False) 55 | print('average metric is {}'.format(avg_metric)) 56 | 57 | return avg_metric 58 | 59 | 60 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 61 | w, h, d = image.shape 62 | 63 | # if the size of image is less than patch_size, then padding it 64 | add_pad = False 65 | if w < patch_size[0]: 66 | w_pad = patch_size[0]-w 67 | add_pad = True 68 | else: 69 | w_pad = 0 70 | if h < patch_size[1]: 71 | h_pad = patch_size[1]-h 72 | add_pad = True 73 | else: 74 | h_pad = 0 75 | if d < patch_size[2]: 76 | d_pad = patch_size[2]-d 77 | add_pad = True 78 | else: 79 | d_pad = 0 80 | wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2 81 | hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2 82 | dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2 83 | if add_pad: 84 | image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) 85 | ww,hh,dd = image.shape 86 | 87 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 88 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 89 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 90 | print("{}, {}, {}".format(sx, sy, sz)) 91 | score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) 92 | cnt = np.zeros(image.shape).astype(np.float32) 93 | 94 | for x in range(0, sx): 95 | xs = min(stride_xy*x, ww-patch_size[0]) 96 | for y in range(0, sy): 97 | ys = min(stride_xy * y,hh-patch_size[1]) 98 | for z in range(0, sz): 99 | zs = min(stride_z * z, dd-patch_size[2]) 100 | test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] 101 | test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32) 102 | test_patch = torch.from_numpy(test_patch).cuda() 103 | y1 = net(test_patch) 104 | y = F.softmax(y1, dim=1) 105 | y = y.cpu().data.numpy() 106 | y = y[0,:,:,:,:] 107 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 108 | = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 109 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 110 | = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 111 | score_map = score_map/np.expand_dims(cnt,axis=0) 112 | label_map = np.argmax(score_map, axis = 0) 113 | if add_pad: 114 | label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 115 | score_map = 
score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 116 | return label_map, score_map 117 | 118 | def cal_dice(prediction, label, num=2): 119 | total_dice = np.zeros(num-1) 120 | for i in range(1, num): 121 | prediction_tmp = (prediction==i) 122 | label_tmp = (label==i) 123 | prediction_tmp = prediction_tmp.astype(np.float) 124 | label_tmp = label_tmp.astype(np.float) 125 | 126 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp)) 127 | total_dice[i - 1] += dice 128 | 129 | return total_dice 130 | 131 | 132 | def calculate_metric_percase(pred, gt): 133 | dice = metric.binary.dc(pred, gt) 134 | jc = metric.binary.jc(pred, gt) 135 | hd = metric.binary.hd95(pred, gt) 136 | asd = metric.binary.asd(pred, gt) 137 | 138 | return dice, jc, hd, asd 139 | -------------------------------------------------------------------------------- /code/train_LA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from tqdm import tqdm 4 | from tensorboardX import SummaryWriter 5 | import shutil 6 | import argparse 7 | import logging 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | import torch 13 | import torch.optim as optim 14 | from torchvision import transforms 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | from torch.utils.data import DataLoader 18 | from torchvision.utils import make_grid 19 | 20 | from networks.vnet import VNet 21 | from utils.losses import dice_loss 22 | from dataloaders.la_heart import LAHeart, RandomCrop, CenterCrop, RandomRotFlip, ToTensor, TwoStreamBatchSampler 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--root_path', type=str, default='../data/2018LA_Seg_Training Set/', help='Name of Experiment') 27 | parser.add_argument('--exp', type=str, default='vnet_supervisedonly', help='model_name') 28 | parser.add_argument('--max_iterations', type=int, default=6000, help='maximum epoch number to train') 29 | parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu') 30 | parser.add_argument('--base_lr', type=float, default=0.01, help='maximum epoch number to train') 31 | parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training') 32 | parser.add_argument('--seed', type=int, default=2019, help='random seed') 33 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 34 | args = parser.parse_args() 35 | 36 | train_data_path = args.root_path 37 | snapshot_path = "../model_la/" + args.exp + "/" 38 | 39 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 40 | batch_size = args.batch_size * len(args.gpu.split(',')) 41 | max_iterations = args.max_iterations 42 | base_lr = args.base_lr 43 | 44 | if args.deterministic: 45 | cudnn.benchmark = False 46 | cudnn.deterministic = True 47 | random.seed(args.seed) 48 | np.random.seed(args.seed) 49 | torch.manual_seed(args.seed) 50 | torch.cuda.manual_seed(args.seed) 51 | 52 | patch_size = (112, 112, 80) 53 | num_classes = 2 54 | 55 | if __name__ == "__main__": 56 | ## make logger file 57 | if not os.path.exists(snapshot_path): 58 | os.makedirs(snapshot_path) 59 | if os.path.exists(snapshot_path + '/code'): 60 | shutil.rmtree(snapshot_path + '/code') 61 | shutil.copytree('.', snapshot_path + '/code', shutil.ignore_patterns(['.git','__pycache__'])) 62 | 63 | logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, 64 | format='[%(asctime)s.%(msecs)03d] 
%(message)s', datefmt='%H:%M:%S') 65 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 66 | logging.info(str(args)) 67 | 68 | net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True) 69 | net = net.cuda() 70 | 71 | db_train = LAHeart(base_dir=train_data_path, 72 | split='train', 73 | num=16, 74 | transform = transforms.Compose([ 75 | RandomRotFlip(), 76 | RandomCrop(patch_size), 77 | ToTensor(), 78 | ])) 79 | db_test = LAHeart(base_dir=train_data_path, 80 | split='test', 81 | transform = transforms.Compose([ 82 | CenterCrop(patch_size), 83 | ToTensor() 84 | ])) 85 | def worker_init_fn(worker_id): 86 | random.seed(args.seed+worker_id) 87 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn) 88 | 89 | net.train() 90 | optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001) 91 | 92 | writer = SummaryWriter(snapshot_path+'/log') 93 | logging.info("{} itertations per epoch".format(len(trainloader))) 94 | 95 | iter_num = 0 96 | max_epoch = max_iterations//len(trainloader)+1 97 | lr_ = base_lr 98 | net.train() 99 | for epoch_num in tqdm(range(max_epoch), ncols=70): 100 | time1 = time.time() 101 | for i_batch, sampled_batch in enumerate(trainloader): 102 | time2 = time.time() 103 | # print('fetch data cost {}'.format(time2-time1)) 104 | volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 105 | volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 106 | outputs = net(volume_batch) 107 | 108 | loss_seg = F.cross_entropy(outputs, label_batch) 109 | outputs_soft = F.softmax(outputs, dim=1) 110 | loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1) 111 | loss = 0.5*(loss_seg+loss_seg_dice) 112 | 113 | optimizer.zero_grad() 114 | loss.backward() 115 | optimizer.step() 116 | 117 | iter_num = iter_num + 1 118 | writer.add_scalar('lr', lr_, iter_num) 119 | writer.add_scalar('loss/loss_seg', loss_seg, iter_num) 120 | writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num) 121 | writer.add_scalar('loss/loss', loss, iter_num) 122 | logging.info('iteration %d : loss : %f' % (iter_num, loss.item())) 123 | if iter_num % 50 == 0: 124 | image = volume_batch[0, 0:1, :, :, 20:61:10].permute(3,0,1,2).repeat(1,3,1,1) 125 | grid_image = make_grid(image, 5, normalize=True) 126 | writer.add_image('train/Image', grid_image, iter_num) 127 | 128 | outputs_soft = F.softmax(outputs, 1) 129 | image = outputs_soft[0, 1:2, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 130 | grid_image = make_grid(image, 5, normalize=False) 131 | writer.add_image('train/Predicted_label', grid_image, iter_num) 132 | 133 | image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 134 | grid_image = make_grid(image, 5, normalize=False) 135 | writer.add_image('train/Groundtruth_label', grid_image, iter_num) 136 | 137 | ## change lr 138 | if iter_num % 2500 == 0: 139 | lr_ = base_lr * 0.1 ** (iter_num // 2500) 140 | for param_group in optimizer.param_groups: 141 | param_group['lr'] = lr_ 142 | if iter_num % 1000 == 0: 143 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') 144 | torch.save(net.state_dict(), save_mode_path) 145 | logging.info("save model to {}".format(save_mode_path)) 146 | 147 | if iter_num > max_iterations: 148 | break 149 | time1 = time.time() 150 | if iter_num > max_iterations: 151 | break 152 | save_mode_path = 
os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth') 153 | torch.save(net.state_dict(), save_mode_path) 154 | logging.info("save model to {}".format(save_mode_path)) 155 | writer.close() 156 | -------------------------------------------------------------------------------- /code/train_LA_MultiHead_FGDTM_L1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from tqdm import tqdm 4 | from tensorboardX import SummaryWriter 5 | import shutil 6 | import argparse 7 | import logging 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | import torch 13 | import torch.optim as optim 14 | from torchvision import transforms 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | from torch.utils.data import DataLoader 18 | from torchvision.utils import make_grid 19 | 20 | from networks.vnet_multi_head import VNetMultiHead 21 | from dataloaders.la_heart import LAHeart, RandomCrop, CenterCrop, RandomRotFlip, ToTensor, TwoStreamBatchSampler 22 | from scipy.ndimage import distance_transform_edt as distance 23 | 24 | 25 | """ 26 | Train a multi-head V-Net that outputs 27 | 1) the predicted segmentation 28 | 2) a regression of the foreground distance transform map 29 | e.g. 30 | Deep Distance Transform for Tubular Structure Segmentation in CT Scans 31 | https://arxiv.org/abs/1912.03383 32 | Shape-Aware Complementary-Task Learning for Multi-Organ Segmentation 33 | https://arxiv.org/abs/1908.05099 34 | """ 35 | 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--root_path', type=str, default='../data/2018LA_Seg_Training Set/', help='path to the training data') 38 | parser.add_argument('--exp', type=str, default='vnet_dp_la_MH_FGDTM_L1', help='model name; dp: add dropout; MH: multi-head') 39 | parser.add_argument('--max_iterations', type=int, default=10000, help='maximum number of iterations to train') 40 | parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu') 41 | parser.add_argument('--base_lr', type=float, default=0.01, help='base learning rate') 42 | parser.add_argument('--deterministic', type=int, default=1, help='whether to use deterministic training') 43 | parser.add_argument('--seed', type=int, default=2019, help='random seed') 44 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 45 | args = parser.parse_args() 46 | 47 | train_data_path = args.root_path 48 | snapshot_path = "../model_la/" + args.exp + "/" 49 | 50 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 51 | batch_size = args.batch_size * len(args.gpu.split(',')) 52 | max_iterations = args.max_iterations 53 | base_lr = args.base_lr 54 | 55 | if args.deterministic: 56 | cudnn.benchmark = False 57 | cudnn.deterministic = True 58 | random.seed(args.seed) 59 | np.random.seed(args.seed) 60 | torch.manual_seed(args.seed) 61 | torch.cuda.manual_seed(args.seed) 62 | 63 | patch_size = (112, 112, 80) 64 | num_classes = 2 65 | 66 | def dice_loss(score, target): 67 | target = target.float() 68 | smooth = 1e-5 69 | intersect = torch.sum(score * target) 70 | y_sum = torch.sum(target * target) 71 | z_sum = torch.sum(score * score) 72 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 73 | loss = 1 - loss 74 | return loss 75 | 76 | def compute_dtm(img_gt, out_shape): 77 | """ 78 | compute the distance transform map of the foreground in a binary mask 79 | input: segmentation, shape = (batch_size, x, y, z) 80 | output: the foreground distance transform map (DTM) 81 | dtm(x) = 0 if x is on the segmentation boundary; 82 | 
inf_{y in boundary} ||x-y|| if x is inside the segmentation 83 | """ 84 | 85 | fg_dtm = np.zeros(out_shape) 86 | 87 | for b in range(out_shape[0]): # batch size 88 | for c in range(out_shape[1]): 89 | posmask = img_gt[b].astype(bool) # builtin bool: np.bool is deprecated 90 | if posmask.any(): 91 | posdis = distance(posmask) 92 | fg_dtm[b][c] = posdis 93 | 94 | return fg_dtm 95 | 96 | 97 | 98 | if __name__ == "__main__": 99 | ## make logger file 100 | if not os.path.exists(snapshot_path): 101 | os.makedirs(snapshot_path) 102 | if os.path.exists(snapshot_path + '/code'): 103 | shutil.rmtree(snapshot_path + '/code') 104 | shutil.copytree('.', snapshot_path + '/code', shutil.ignore_patterns('.git', '__pycache__')) # ignore_patterns takes patterns as separate args, not a list 105 | 106 | logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, 107 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 108 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 109 | logging.info(str(args)) 110 | 111 | net = VNetMultiHead(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True) 112 | net = net.cuda() 113 | 114 | db_train = LAHeart(base_dir=train_data_path, 115 | split='train', 116 | num=16, 117 | transform = transforms.Compose([ 118 | RandomRotFlip(), 119 | RandomCrop(patch_size), 120 | ToTensor(), 121 | ])) 122 | 123 | def worker_init_fn(worker_id): 124 | random.seed(args.seed+worker_id) 125 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn) 126 | 127 | net.train() 128 | optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001) 129 | 130 | writer = SummaryWriter(snapshot_path+'/log', flush_secs=2) 131 | logging.info("{} iterations per epoch".format(len(trainloader))) 132 | 133 | iter_num = 0 134 | max_epoch = max_iterations//len(trainloader)+1 135 | lr_ = base_lr 136 | net.train() 137 | for epoch_num in tqdm(range(max_epoch), ncols=70): 138 | for i_batch, sampled_batch in enumerate(trainloader): 139 | # fetch the paired input 140 | volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 141 | volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 142 | outputs, out_dis = net(volume_batch) 143 | 144 | with torch.no_grad(): 145 | gt_dis = compute_dtm(label_batch.cpu().numpy(), out_dis.shape) 146 | gt_dis = torch.from_numpy(gt_dis).float().cuda() 147 | 148 | # compute CE + Dice loss 149 | loss_ce = F.cross_entropy(outputs, label_batch) 150 | outputs_soft = F.softmax(outputs, dim=1) 151 | loss_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1) 152 | # compute L1 loss on the regressed DTM 153 | loss_dist = torch.norm(out_dis-gt_dis, 1)/torch.numel(out_dis) 154 | 155 | loss = loss_ce + loss_dice + loss_dist 156 | 157 | optimizer.zero_grad() 158 | loss.backward() 159 | optimizer.step() 160 | 161 | iter_num = iter_num + 1 162 | writer.add_scalar('lr', lr_, iter_num) 163 | writer.add_scalar('loss/loss_ce', loss_ce, iter_num) 164 | writer.add_scalar('loss/loss_dice', loss_dice, iter_num) 165 | writer.add_scalar('loss/loss_dist', loss_dist, iter_num) 166 | writer.add_scalar('loss/loss', loss, iter_num) 167 | logging.info('iteration %d : loss_dist : %f' % (iter_num, loss_dist.item())) 168 | logging.info('iteration %d : loss_dice : %f' % (iter_num, loss_dice.item())) 169 | logging.info('iteration %d : loss : %f' % (iter_num, loss.item())) 170 | if iter_num % 2 == 0: 171 | image = volume_batch[0, 0:1, :, :, 20:61:10].permute(3,0,1,2).repeat(1,3,1,1) 172 | grid_image = make_grid(image, 5, normalize=True) 173 | 
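# NOTE (editor's illustration, not part of the original script): a minimal
# standalone check of what compute_dtm above returns for a toy mask. It
# assumes only numpy/scipy; the shapes mirror the (batch, x, y, z) labels
# and the (batch, channel, x, y, z) regression head used in this script.
import numpy as np
from scipy.ndimage import distance_transform_edt as distance

toy_mask = np.zeros((1, 8, 8, 8), dtype=np.uint8)        # (batch, x, y, z)
toy_mask[0, 2:6, 2:6, 2:6] = 1                           # a 4x4x4 foreground cube
toy_dtm = compute_dtm(toy_mask, (1, 1, 8, 8, 8))         # (batch, channel, x, y, z)
assert toy_dtm[0, 0][toy_mask[0] == 0].max() == 0        # background stays 0
assert toy_dtm[0, 0, 3, 3, 3] > toy_dtm[0, 0, 2, 2, 2]   # deeper voxels lie farther from the background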
writer.add_image('train/Image', grid_image, iter_num) 174 | 175 | outputs_soft = F.softmax(outputs, 1) 176 | image = outputs_soft[0, 1:2, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 177 | grid_image = make_grid(image, 5, normalize=False) 178 | writer.add_image('train/Predicted_label', grid_image, iter_num) 179 | 180 | image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 181 | grid_image = make_grid(image, 5, normalize=False) 182 | writer.add_image('train/Groundtruth_label', grid_image, iter_num) 183 | 184 | out_dis_slice = out_dis[0, 0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 185 | grid_image = make_grid(out_dis_slice, 5, normalize=False) 186 | writer.add_image('train/out_dis_map', grid_image, iter_num) 187 | 188 | gt_dis_slice = gt_dis[0, 0,:, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 189 | grid_image = make_grid(gt_dis_slice, 5, normalize=False) 190 | writer.add_image('train/gt_dis_map', grid_image, iter_num) 191 | ## change lr 192 | if iter_num % 2500 == 0: 193 | lr_ = base_lr * 0.1 ** (iter_num // 1000) 194 | for param_group in optimizer.param_groups: 195 | param_group['lr'] = lr_ 196 | if iter_num % 1000 == 0: 197 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') 198 | torch.save(net.state_dict(), save_mode_path) 199 | logging.info("save model to {}".format(save_mode_path)) 200 | 201 | if iter_num > max_iterations: 202 | break 203 | time1 = time.time() 204 | if iter_num > max_iterations: 205 | break 206 | save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth') 207 | torch.save(net.state_dict(), save_mode_path) 208 | logging.info("save model to {}".format(save_mode_path)) 209 | writer.close() 210 | -------------------------------------------------------------------------------- /code/train_LA_MultiHead_FGDTM_L1PlusL2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from tqdm import tqdm 4 | from tensorboardX import SummaryWriter 5 | import shutil 6 | import argparse 7 | import logging 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | import torch 13 | import torch.optim as optim 14 | from torchvision import transforms 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | from torch.utils.data import DataLoader 18 | from torchvision.utils import make_grid 19 | 20 | from networks.vnet_multi_head import VNetMultiHead 21 | from dataloaders.la_heart import LAHeart, RandomCrop, CenterCrop, RandomRotFlip, ToTensor, TwoStreamBatchSampler 22 | from scipy.ndimage import distance_transform_edt as distance 23 | 24 | 25 | """ 26 | Train a multi-head vnet to output 27 | 1) predicted segmentation 28 | 2) regress the distance transform map 29 | e.g. 
30 | Deep Distance Transform for Tubular Structure Segmentation in CT Scans 31 | https://arxiv.org/abs/1912.03383 32 | Shape-Aware Complementary-Task Learning for Multi-Organ Segmentation 33 | https://arxiv.org/abs/1908.05099 34 | """ 35 | 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--root_path', type=str, default='../data/2018LA_Seg_Training Set/', help='Name of Experiment') 38 | parser.add_argument('--exp', type=str, default='vnet_dp_la_MH_FGDTM_L1PlusL2', help='model_name;dp:add dropout; MH:multi-head') 39 | parser.add_argument('--max_iterations', type=int, default=10000, help='maximum epoch number to train') 40 | parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu') 41 | parser.add_argument('--base_lr', type=float, default=0.01, help='maximum epoch number to train') 42 | parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training') 43 | parser.add_argument('--seed', type=int, default=2019, help='random seed') 44 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 45 | args = parser.parse_args() 46 | 47 | train_data_path = args.root_path 48 | snapshot_path = "../model_la/" + args.exp + "/" 49 | 50 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 51 | batch_size = args.batch_size * len(args.gpu.split(',')) 52 | max_iterations = args.max_iterations 53 | base_lr = args.base_lr 54 | 55 | if args.deterministic: 56 | cudnn.benchmark = False 57 | cudnn.deterministic = True 58 | random.seed(args.seed) 59 | np.random.seed(args.seed) 60 | torch.manual_seed(args.seed) 61 | torch.cuda.manual_seed(args.seed) 62 | 63 | patch_size = (112, 112, 80) 64 | num_classes = 2 65 | 66 | def dice_loss(score, target): 67 | target = target.float() 68 | smooth = 1e-5 69 | intersect = torch.sum(score * target) 70 | y_sum = torch.sum(target * target) 71 | z_sum = torch.sum(score * score) 72 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 73 | loss = 1 - loss 74 | return loss 75 | 76 | def compute_dtm(img_gt, out_shape): 77 | """ 78 | compute the distance transform map of foreground in binary mask 79 | input: segmentation, shape = (batch_size, x, y, z) 80 | output: the foreground Distance Map (SDM) 81 | dtm(x) = 0; x in segmentation boundary 82 | inf|x-y|; x in segmentation 83 | """ 84 | fg_dtm = np.zeros(out_shape) 85 | 86 | for b in range(out_shape[0]): # batch size 87 | for c in range(out_shape[1]): 88 | posmask = img_gt[b].astype(np.bool) 89 | if posmask.any(): 90 | posdis = distance(posmask) 91 | fg_dtm[b][c] = posdis 92 | 93 | return fg_dtm 94 | 95 | 96 | 97 | if __name__ == "__main__": 98 | ## make logger file 99 | if not os.path.exists(snapshot_path): 100 | os.makedirs(snapshot_path) 101 | if os.path.exists(snapshot_path + '/code'): 102 | shutil.rmtree(snapshot_path + '/code') 103 | shutil.copytree('.', snapshot_path + '/code', shutil.ignore_patterns(['.git','__pycache__'])) 104 | 105 | logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, 106 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 107 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 108 | logging.info(str(args)) 109 | 110 | net = VNetMultiHead(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True) 111 | net = net.cuda() 112 | 113 | db_train = LAHeart(base_dir=train_data_path, 114 | split='train', 115 | num=16, 116 | transform = transforms.Compose([ 117 | RandomRotFlip(), 118 | RandomCrop(patch_size), 119 | ToTensor(), 120 | ])) 121 | 122 
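# NOTE (editor's illustration, not part of the original script): the
# distance-regression term used in the training loop below. The normalized
# 1-norm torch.norm(a - b, 1) / torch.numel(a) is just the mean absolute
# error, i.e. F.l1_loss with its default 'mean' reduction; the L1+L2 variant
# then adds F.mse_loss on top.
import torch
import torch.nn.functional as F

pred = torch.randn(2, 1, 4, 4, 4)
tgt = torch.randn(2, 1, 4, 4, 4)
l1_norm = torch.norm(pred - tgt, 1) / torch.numel(pred)
assert torch.allclose(l1_norm, F.l1_loss(pred, tgt))
loss_l1_plus_l2 = l1_norm + F.mse_loss(pred, tgt)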
| def worker_init_fn(worker_id): 123 | random.seed(args.seed+worker_id) 124 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn) 125 | 126 | net.train() 127 | optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001) 128 | 129 | writer = SummaryWriter(snapshot_path+'/log', flush_secs=2) 130 | logging.info("{} itertations per epoch".format(len(trainloader))) 131 | 132 | iter_num = 0 133 | max_epoch = max_iterations//len(trainloader)+1 134 | lr_ = base_lr 135 | net.train() 136 | for epoch_num in tqdm(range(max_epoch), ncols=70): 137 | for i_batch, sampled_batch in enumerate(trainloader): 138 | # generate paired iput 139 | volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 140 | volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 141 | outputs, out_dis = net(volume_batch) 142 | 143 | with torch.no_grad(): 144 | gt_dis = compute_dtm(label_batch.cpu().numpy(), out_dis.shape) 145 | gt_dis = torch.from_numpy(gt_dis).float().cuda() 146 | 147 | # compute CE + Dice loss 148 | loss_ce = F.cross_entropy(outputs, label_batch) 149 | outputs_soft = F.softmax(outputs, dim=1) 150 | loss_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1) 151 | # compute L1 + L2 Loss 152 | loss_dist = torch.norm(out_dis-gt_dis, 1)/torch.numel(out_dis) + F.mse_loss(out_dis, gt_dis) 153 | 154 | loss = loss_ce + loss_dice + loss_dist 155 | 156 | optimizer.zero_grad() 157 | loss.backward() 158 | optimizer.step() 159 | 160 | iter_num = iter_num + 1 161 | writer.add_scalar('lr', lr_, iter_num) 162 | writer.add_scalar('loss/loss_ce', loss_ce, iter_num) 163 | writer.add_scalar('loss/loss_dice', loss_dice, iter_num) 164 | writer.add_scalar('loss/loss_dist', loss_dist, iter_num) 165 | writer.add_scalar('loss/loss', loss, iter_num) 166 | logging.info('iteration %d : loss_dist : %f' % (iter_num, loss_dist.item())) 167 | logging.info('iteration %d : loss_dice : %f' % (iter_num, loss_dice.item())) 168 | logging.info('iteration %d : loss : %f' % (iter_num, loss.item())) 169 | if iter_num % 2 == 0: 170 | image = volume_batch[0, 0:1, :, :, 20:61:10].permute(3,0,1,2).repeat(1,3,1,1) 171 | grid_image = make_grid(image, 5, normalize=True) 172 | writer.add_image('train/Image', grid_image, iter_num) 173 | 174 | outputs_soft = F.softmax(outputs, 1) 175 | image = outputs_soft[0, 1:2, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 176 | grid_image = make_grid(image, 5, normalize=False) 177 | writer.add_image('train/Predicted_label', grid_image, iter_num) 178 | 179 | image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 180 | grid_image = make_grid(image, 5, normalize=False) 181 | writer.add_image('train/Groundtruth_label', grid_image, iter_num) 182 | 183 | out_dis_slice = out_dis[0, 0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 184 | grid_image = make_grid(out_dis_slice, 5, normalize=False) 185 | writer.add_image('train/out_dis_map', grid_image, iter_num) 186 | 187 | gt_dis_slice = gt_dis[0, 0,:, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 188 | grid_image = make_grid(gt_dis_slice, 5, normalize=False) 189 | writer.add_image('train/gt_dis_map', grid_image, iter_num) 190 | ## change lr 191 | if iter_num % 2500 == 0: 192 | lr_ = base_lr * 0.1 ** (iter_num // 1000) 193 | for param_group in optimizer.param_groups: 194 | param_group['lr'] = lr_ 195 | if iter_num % 1000 == 0: 196 | save_mode_path = 
os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') 197 | torch.save(net.state_dict(), save_mode_path) 198 | logging.info("save model to {}".format(save_mode_path)) 199 | 200 | if iter_num > max_iterations: 201 | break 202 | time1 = time.time() 203 | if iter_num > max_iterations: 204 | break 205 | save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth') 206 | torch.save(net.state_dict(), save_mode_path) 207 | logging.info("save model to {}".format(save_mode_path)) 208 | writer.close() 209 | -------------------------------------------------------------------------------- /code/train_LA_MultiHead_FGDTM_L2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from tqdm import tqdm 4 | from tensorboardX import SummaryWriter 5 | import shutil 6 | import argparse 7 | import logging 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | import torch 13 | import torch.optim as optim 14 | from torchvision import transforms 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | from torch.utils.data import DataLoader 18 | from torchvision.utils import make_grid 19 | 20 | from networks.vnet_multi_head import VNetMultiHead 21 | from dataloaders.la_heart import LAHeart, RandomCrop, CenterCrop, RandomRotFlip, ToTensor, TwoStreamBatchSampler 22 | from scipy.ndimage import distance_transform_edt as distance 23 | 24 | 25 | """ 26 | Train a multi-head vnet to output 27 | 1) predicted segmentation 28 | 2) regress the distance transform map 29 | e.g. 30 | Deep Distance Transform for Tubular Structure Segmentation in CT Scans 31 | https://arxiv.org/abs/1912.03383 32 | Shape-Aware Complementary-Task Learning for Multi-Organ Segmentation 33 | https://arxiv.org/abs/1908.05099 34 | """ 35 | 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--root_path', type=str, default='../data/2018LA_Seg_Training Set/', help='Name of Experiment') 38 | parser.add_argument('--exp', type=str, default='vnet_dp_la_MH_FGDTM_L2', help='model_name;dp:add dropout; MH:multi-head') 39 | parser.add_argument('--max_iterations', type=int, default=10000, help='maximum epoch number to train') 40 | parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu') 41 | parser.add_argument('--base_lr', type=float, default=0.01, help='maximum epoch number to train') 42 | parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training') 43 | parser.add_argument('--seed', type=int, default=2019, help='random seed') 44 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 45 | args = parser.parse_args() 46 | 47 | train_data_path = args.root_path 48 | snapshot_path = "../model_la/" + args.exp + "/" 49 | 50 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 51 | batch_size = args.batch_size * len(args.gpu.split(',')) 52 | max_iterations = args.max_iterations 53 | base_lr = args.base_lr 54 | 55 | if args.deterministic: 56 | cudnn.benchmark = False 57 | cudnn.deterministic = True 58 | random.seed(args.seed) 59 | np.random.seed(args.seed) 60 | torch.manual_seed(args.seed) 61 | torch.cuda.manual_seed(args.seed) 62 | 63 | patch_size = (112, 112, 80) 64 | num_classes = 2 65 | 66 | def dice_loss(score, target): 67 | target = target.float() 68 | smooth = 1e-5 69 | intersect = torch.sum(score * target) 70 | y_sum = torch.sum(target * target) 71 | z_sum = torch.sum(score * score) 72 | loss = (2 * intersect + smooth) / (z_sum + y_sum + 
smooth) 73 | loss = 1 - loss 74 | return loss 75 | 76 | def compute_dtm(img_gt, out_shape): 77 | """ 78 | compute the distance transform map of foreground in binary mask 79 | input: segmentation, shape = (batch_size, x, y, z) 80 | output: the foreground Distance Map (SDM) 81 | dtm(x) = 0; x in segmentation boundary 82 | inf|x-y|; x in segmentation 83 | """ 84 | 85 | fg_dtm = np.zeros(out_shape) 86 | 87 | for b in range(out_shape[0]): # batch size 88 | for c in range(out_shape[1]): 89 | posmask = img_gt[b].astype(np.bool) 90 | if posmask.any(): 91 | posdis = distance(posmask) 92 | fg_dtm[b][c] = posdis 93 | 94 | return fg_dtm 95 | 96 | 97 | 98 | if __name__ == "__main__": 99 | ## make logger file 100 | if not os.path.exists(snapshot_path): 101 | os.makedirs(snapshot_path) 102 | if os.path.exists(snapshot_path + '/code'): 103 | shutil.rmtree(snapshot_path + '/code') 104 | shutil.copytree('.', snapshot_path + '/code', shutil.ignore_patterns(['.git','__pycache__'])) 105 | 106 | logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, 107 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 108 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 109 | logging.info(str(args)) 110 | 111 | net = VNetMultiHead(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True) 112 | net = net.cuda() 113 | 114 | db_train = LAHeart(base_dir=train_data_path, 115 | split='train', 116 | num=16, 117 | transform = transforms.Compose([ 118 | RandomRotFlip(), 119 | RandomCrop(patch_size), 120 | ToTensor(), 121 | ])) 122 | 123 | def worker_init_fn(worker_id): 124 | random.seed(args.seed+worker_id) 125 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn) 126 | 127 | net.train() 128 | optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001) 129 | 130 | writer = SummaryWriter(snapshot_path+'/log', flush_secs=2) 131 | logging.info("{} itertations per epoch".format(len(trainloader))) 132 | 133 | iter_num = 0 134 | max_epoch = max_iterations//len(trainloader)+1 135 | lr_ = base_lr 136 | net.train() 137 | for epoch_num in tqdm(range(max_epoch), ncols=70): 138 | for i_batch, sampled_batch in enumerate(trainloader): 139 | # generate paired iput 140 | volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 141 | volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 142 | outputs, out_dis = net(volume_batch) 143 | 144 | with torch.no_grad(): 145 | gt_dis = compute_dtm(label_batch.cpu().numpy(), out_dis.shape) 146 | gt_dis = torch.from_numpy(gt_dis).float().cuda() 147 | 148 | # compute CE + Dice loss 149 | loss_ce = F.cross_entropy(outputs, label_batch) 150 | outputs_soft = F.softmax(outputs, dim=1) 151 | loss_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1) 152 | # compute L2 Loss 153 | loss_dist = F.mse_loss(out_dis, gt_dis) 154 | 155 | loss = loss_ce + loss_dice + loss_dist 156 | 157 | optimizer.zero_grad() 158 | loss.backward() 159 | optimizer.step() 160 | 161 | iter_num = iter_num + 1 162 | writer.add_scalar('lr', lr_, iter_num) 163 | writer.add_scalar('loss/loss_ce', loss_ce, iter_num) 164 | writer.add_scalar('loss/loss_dice', loss_dice, iter_num) 165 | writer.add_scalar('loss/loss_dist', loss_dist, iter_num) 166 | writer.add_scalar('loss/loss', loss, iter_num) 167 | logging.info('iteration %d : loss_dist : %f' % (iter_num, loss_dist.item())) 168 | logging.info('iteration %d : loss_dice : 
%f' % (iter_num, loss_dice.item())) 169 | logging.info('iteration %d : loss : %f' % (iter_num, loss.item())) 170 | if iter_num % 2 == 0: 171 | image = volume_batch[0, 0:1, :, :, 20:61:10].permute(3,0,1,2).repeat(1,3,1,1) 172 | grid_image = make_grid(image, 5, normalize=True) 173 | writer.add_image('train/Image', grid_image, iter_num) 174 | 175 | outputs_soft = F.softmax(outputs, 1) 176 | image = outputs_soft[0, 1:2, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 177 | grid_image = make_grid(image, 5, normalize=False) 178 | writer.add_image('train/Predicted_label', grid_image, iter_num) 179 | 180 | image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 181 | grid_image = make_grid(image, 5, normalize=False) 182 | writer.add_image('train/Groundtruth_label', grid_image, iter_num) 183 | 184 | out_dis_slice = out_dis[0, 0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 185 | grid_image = make_grid(out_dis_slice, 5, normalize=False) 186 | writer.add_image('train/out_dis_map', grid_image, iter_num) 187 | 188 | gt_dis_slice = gt_dis[0, 0,:, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 189 | grid_image = make_grid(gt_dis_slice, 5, normalize=False) 190 | writer.add_image('train/gt_dis_map', grid_image, iter_num) 191 | ## change lr 192 | if iter_num % 2500 == 0: 193 | lr_ = base_lr * 0.1 ** (iter_num // 1000) 194 | for param_group in optimizer.param_groups: 195 | param_group['lr'] = lr_ 196 | if iter_num % 1000 == 0: 197 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') 198 | torch.save(net.state_dict(), save_mode_path) 199 | logging.info("save model to {}".format(save_mode_path)) 200 | 201 | if iter_num > max_iterations: 202 | break 203 | time1 = time.time() 204 | if iter_num > max_iterations: 205 | break 206 | save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth') 207 | torch.save(net.state_dict(), save_mode_path) 208 | logging.info("save model to {}".format(save_mode_path)) 209 | writer.close() 210 | -------------------------------------------------------------------------------- /code/train_LA_Rec_FGDTM_L1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from tqdm import tqdm 4 | from tensorboardX import SummaryWriter 5 | import shutil 6 | import argparse 7 | import logging 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | import torch 13 | import torch.optim as optim 14 | from torchvision import transforms 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | from torch.utils.data import DataLoader 18 | from torchvision.utils import make_grid 19 | 20 | from networks.vnet_rec import VNetRec 21 | from dataloaders.la_heart import LAHeart, RandomCrop, CenterCrop, RandomRotFlip, ToTensor, TwoStreamBatchSampler 22 | from scipy.ndimage import distance_transform_edt as distance 23 | 24 | 25 | """ 26 | Adding reconstruction branch to V-Net 27 | Ref: 28 | A Distance Map Regularized CNN for Cardiac Cine MR Image Segmentation 29 | https://arxiv.org/abs/1901.01238 30 | """ 31 | 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--root_path', type=str, default='../data/2018LA_Seg_Training Set/', help='Name of Experiment') 34 | parser.add_argument('--exp', type=str, default='vnet_dp_la_Rec_FGDTM_L1', help='model_name;dp:add dropout; Rec:Reconstruction') 35 | parser.add_argument('--max_iterations', type=int, default=10000, help='maximum 
epoch number to train') 36 | parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu') 37 | parser.add_argument('--base_lr', type=float, default=0.01, help='maximum epoch number to train') 38 | parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training') 39 | parser.add_argument('--seed', type=int, default=2019, help='random seed') 40 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 41 | args = parser.parse_args() 42 | 43 | train_data_path = args.root_path 44 | snapshot_path = "../model_la/" + args.exp + "/" 45 | 46 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 47 | batch_size = args.batch_size * len(args.gpu.split(',')) 48 | max_iterations = args.max_iterations 49 | base_lr = args.base_lr 50 | 51 | if args.deterministic: 52 | cudnn.benchmark = False 53 | cudnn.deterministic = True 54 | random.seed(args.seed) 55 | np.random.seed(args.seed) 56 | torch.manual_seed(args.seed) 57 | torch.cuda.manual_seed(args.seed) 58 | 59 | patch_size = (112, 112, 80) 60 | num_classes = 2 61 | 62 | def dice_loss(score, target): 63 | target = target.float() 64 | smooth = 1e-5 65 | intersect = torch.sum(score * target) 66 | y_sum = torch.sum(target * target) 67 | z_sum = torch.sum(score * score) 68 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 69 | loss = 1 - loss 70 | return loss 71 | 72 | def compute_dtm(img_gt, out_shape): 73 | """ 74 | compute the distance transform map of foreground in binary mask 75 | input: segmentation, shape = (batch_size, x, y, z) 76 | output: the foreground Distance Map (SDM) 77 | dtm(x) = 0; x in segmentation boundary 78 | inf|x-y|; x in segmentation 79 | """ 80 | fg_dtm = np.zeros(out_shape) 81 | 82 | for b in range(out_shape[0]): # batch size 83 | for c in range(out_shape[1]): 84 | posmask = img_gt[b].astype(np.bool) 85 | if posmask.any(): 86 | posdis = distance(posmask) 87 | fg_dtm[b][c] = posdis 88 | 89 | return fg_dtm 90 | 91 | 92 | 93 | if __name__ == "__main__": 94 | ## make logger file 95 | if not os.path.exists(snapshot_path): 96 | os.makedirs(snapshot_path) 97 | if os.path.exists(snapshot_path + '/code'): 98 | shutil.rmtree(snapshot_path + '/code') 99 | shutil.copytree('.', snapshot_path + '/code', shutil.ignore_patterns(['.git','__pycache__'])) 100 | 101 | logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, 102 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 103 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 104 | logging.info(str(args)) 105 | 106 | net = VNetRec(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True) 107 | net = net.cuda() 108 | 109 | db_train = LAHeart(base_dir=train_data_path, 110 | split='train', 111 | num=16, 112 | transform = transforms.Compose([ 113 | RandomRotFlip(), 114 | RandomCrop(patch_size), 115 | ToTensor(), 116 | ])) 117 | 118 | def worker_init_fn(worker_id): 119 | random.seed(args.seed+worker_id) 120 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn) 121 | 122 | net.train() 123 | optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001) 124 | 125 | writer = SummaryWriter(snapshot_path+'/log', flush_secs=2) 126 | logging.info("{} itertations per epoch".format(len(trainloader))) 127 | 128 | iter_num = 0 129 | max_epoch = max_iterations//len(trainloader)+1 130 | lr_ = base_lr 131 | net.train() 132 | for epoch_num in 
tqdm(range(max_epoch), ncols=70): 133 | for i_batch, sampled_batch in enumerate(trainloader): 134 | # generate paired iput 135 | volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 136 | volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 137 | outputs, out_dis = net(volume_batch) 138 | 139 | with torch.no_grad(): 140 | gt_dis = compute_dtm(label_batch.cpu().numpy(), out_dis.shape) 141 | # debug: check distance map 142 | # import matplotlib.pyplot as plt 143 | # plt.figure() 144 | # plt.imshow(gt_dis[0,0,:,:,40]); plt.axis('off'); plt.colorbar() 145 | # plt.show() 146 | gt_dis = torch.from_numpy(gt_dis).float().cuda() 147 | 148 | # compute CE + Dice loss 149 | loss_ce = F.cross_entropy(outputs, label_batch) 150 | outputs_soft = F.softmax(outputs, dim=1) 151 | loss_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1) 152 | # compute L1 Loss 153 | loss_dist = torch.norm(out_dis-gt_dis, 1)/torch.numel(out_dis) 154 | 155 | loss = loss_ce + loss_dice + loss_dist 156 | 157 | optimizer.zero_grad() 158 | loss.backward() 159 | optimizer.step() 160 | 161 | iter_num = iter_num + 1 162 | writer.add_scalar('lr', lr_, iter_num) 163 | writer.add_scalar('loss/loss_ce', loss_ce, iter_num) 164 | writer.add_scalar('loss/loss_dice', loss_dice, iter_num) 165 | writer.add_scalar('loss/loss_dist', loss_dist, iter_num) 166 | writer.add_scalar('loss/loss', loss, iter_num) 167 | logging.info('iteration %d : loss_dist : %f' % (iter_num, loss_dist.item())) 168 | logging.info('iteration %d : loss_dice : %f' % (iter_num, loss_dice.item())) 169 | logging.info('iteration %d : loss : %f' % (iter_num, loss.item())) 170 | if iter_num % 2 == 0: 171 | image = volume_batch[0, 0:1, :, :, 20:61:10].permute(3,0,1,2).repeat(1,3,1,1) 172 | grid_image = make_grid(image, 5, normalize=True) 173 | writer.add_image('train/Image', grid_image, iter_num) 174 | 175 | outputs_soft = F.softmax(outputs, 1) 176 | image = outputs_soft[0, 1:2, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 177 | grid_image = make_grid(image, 5, normalize=False) 178 | writer.add_image('train/Predicted_label', grid_image, iter_num) 179 | 180 | image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 181 | grid_image = make_grid(image, 5, normalize=False) 182 | writer.add_image('train/Groundtruth_label', grid_image, iter_num) 183 | 184 | out_dis_slice = out_dis[0, 0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 185 | grid_image = make_grid(out_dis_slice, 5, normalize=True) 186 | writer.add_image('train/out_dis_map', grid_image, iter_num) 187 | 188 | gt_dis_slice = gt_dis[0, 0,:, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 189 | grid_image = make_grid(gt_dis_slice, 5, normalize=True) 190 | writer.add_image('train/gt_dis_map', grid_image, iter_num) 191 | ## change lr 192 | if iter_num % 2500 == 0: 193 | lr_ = base_lr * 0.1 ** (iter_num // 1000) 194 | for param_group in optimizer.param_groups: 195 | param_group['lr'] = lr_ 196 | if iter_num % 1000 == 0: 197 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') 198 | torch.save(net.state_dict(), save_mode_path) 199 | logging.info("save model to {}".format(save_mode_path)) 200 | 201 | if iter_num > max_iterations: 202 | break 203 | time1 = time.time() 204 | if iter_num > max_iterations: 205 | break 206 | save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth') 207 | torch.save(net.state_dict(), save_mode_path) 208 | logging.info("save 
model to {}".format(save_mode_path)) 209 | writer.close() 210 | -------------------------------------------------------------------------------- /code/train_LA_Rec_FGDTM_L1PlusL2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from tqdm import tqdm 4 | from tensorboardX import SummaryWriter 5 | import shutil 6 | import argparse 7 | import logging 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | import torch 13 | import torch.optim as optim 14 | from torchvision import transforms 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | from torch.utils.data import DataLoader 18 | from torchvision.utils import make_grid 19 | 20 | from networks.vnet_rec import VNetRec 21 | from dataloaders.la_heart import LAHeart, RandomCrop, CenterCrop, RandomRotFlip, ToTensor, TwoStreamBatchSampler 22 | from scipy.ndimage import distance_transform_edt as distance 23 | 24 | 25 | """ 26 | Adding reconstruction branch to V-Net 27 | Ref: 28 | A Distance Map Regularized CNN for Cardiac Cine MR Image Segmentation 29 | https://arxiv.org/abs/1901.01238 30 | """ 31 | 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--root_path', type=str, default='../data/2018LA_Seg_Training Set/', help='Name of Experiment') 34 | parser.add_argument('--exp', type=str, default='vnet_dp_la_Rec_FGDTM_L1PlusL2', help='model_name;dp:add dropout; Rec:Reconstruction') 35 | parser.add_argument('--max_iterations', type=int, default=10000, help='maximum epoch number to train') 36 | parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu') 37 | parser.add_argument('--base_lr', type=float, default=0.01, help='maximum epoch number to train') 38 | parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training') 39 | parser.add_argument('--seed', type=int, default=2019, help='random seed') 40 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 41 | args = parser.parse_args() 42 | 43 | train_data_path = args.root_path 44 | snapshot_path = "../model_la/" + args.exp + "/" 45 | 46 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 47 | batch_size = args.batch_size * len(args.gpu.split(',')) 48 | max_iterations = args.max_iterations 49 | base_lr = args.base_lr 50 | 51 | if args.deterministic: 52 | cudnn.benchmark = False 53 | cudnn.deterministic = True 54 | random.seed(args.seed) 55 | np.random.seed(args.seed) 56 | torch.manual_seed(args.seed) 57 | torch.cuda.manual_seed(args.seed) 58 | 59 | patch_size = (112, 112, 80) 60 | num_classes = 2 61 | 62 | def dice_loss(score, target): 63 | target = target.float() 64 | smooth = 1e-5 65 | intersect = torch.sum(score * target) 66 | y_sum = torch.sum(target * target) 67 | z_sum = torch.sum(score * score) 68 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 69 | loss = 1 - loss 70 | return loss 71 | 72 | def compute_dtm(img_gt, out_shape): 73 | """ 74 | compute the distance transform map of foreground in binary mask 75 | input: segmentation, shape = (batch_size, x, y, z) 76 | output: the foreground Distance Map (SDM) 77 | dtm(x) = 0; x in segmentation boundary 78 | inf|x-y|; x in segmentation 79 | """ 80 | fg_dtm = np.zeros(out_shape) 81 | 82 | for b in range(out_shape[0]): # batch size 83 | for c in range(out_shape[1]): 84 | posmask = img_gt[b].astype(np.bool) 85 | if posmask.any(): 86 | posdis = distance(posmask) 87 | fg_dtm[b][c] = posdis 88 | 89 | return fg_dtm 90 | 91 | 92 | 93 | if 
__name__ == "__main__": 94 | ## make logger file 95 | if not os.path.exists(snapshot_path): 96 | os.makedirs(snapshot_path) 97 | if os.path.exists(snapshot_path + '/code'): 98 | shutil.rmtree(snapshot_path + '/code') 99 | shutil.copytree('.', snapshot_path + '/code', shutil.ignore_patterns(['.git','__pycache__'])) 100 | 101 | logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, 102 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 103 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 104 | logging.info(str(args)) 105 | 106 | net = VNetRec(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True) 107 | net = net.cuda() 108 | 109 | db_train = LAHeart(base_dir=train_data_path, 110 | split='train', 111 | num=16, 112 | transform = transforms.Compose([ 113 | RandomRotFlip(), 114 | RandomCrop(patch_size), 115 | ToTensor(), 116 | ])) 117 | 118 | def worker_init_fn(worker_id): 119 | random.seed(args.seed+worker_id) 120 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn) 121 | 122 | net.train() 123 | optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001) 124 | 125 | writer = SummaryWriter(snapshot_path+'/log', flush_secs=2) 126 | logging.info("{} itertations per epoch".format(len(trainloader))) 127 | 128 | iter_num = 0 129 | max_epoch = max_iterations//len(trainloader)+1 130 | lr_ = base_lr 131 | net.train() 132 | for epoch_num in tqdm(range(max_epoch), ncols=70): 133 | for i_batch, sampled_batch in enumerate(trainloader): 134 | # generate paired iput 135 | volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 136 | volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 137 | outputs, out_dis = net(volume_batch) 138 | 139 | with torch.no_grad(): 140 | gt_dis = compute_dtm(label_batch.cpu().numpy(), out_dis.shape) 141 | gt_dis = torch.from_numpy(gt_dis).float().cuda() 142 | 143 | # compute CE + Dice loss 144 | loss_ce = F.cross_entropy(outputs, label_batch) 145 | outputs_soft = F.softmax(outputs, dim=1) 146 | loss_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1) 147 | # compute L1 + L2 Loss 148 | loss_dist = torch.norm(out_dis-gt_dis, 1)/torch.numel(out_dis) + F.mse_loss(out_dis, gt_dis) 149 | 150 | loss = loss_ce + loss_dice + loss_dist 151 | 152 | optimizer.zero_grad() 153 | loss.backward() 154 | optimizer.step() 155 | 156 | iter_num = iter_num + 1 157 | writer.add_scalar('lr', lr_, iter_num) 158 | writer.add_scalar('loss/loss_ce', loss_ce, iter_num) 159 | writer.add_scalar('loss/loss_dice', loss_dice, iter_num) 160 | writer.add_scalar('loss/loss_dist', loss_dist, iter_num) 161 | writer.add_scalar('loss/loss', loss, iter_num) 162 | logging.info('iteration %d : loss_dist : %f' % (iter_num, loss_dist.item())) 163 | logging.info('iteration %d : loss_dice : %f' % (iter_num, loss_dice.item())) 164 | logging.info('iteration %d : loss : %f' % (iter_num, loss.item())) 165 | if iter_num % 2 == 0: 166 | image = volume_batch[0, 0:1, :, :, 20:61:10].permute(3,0,1,2).repeat(1,3,1,1) 167 | grid_image = make_grid(image, 5, normalize=True) 168 | writer.add_image('train/Image', grid_image, iter_num) 169 | 170 | outputs_soft = F.softmax(outputs, 1) 171 | image = outputs_soft[0, 1:2, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 172 | grid_image = make_grid(image, 5, normalize=False) 173 | writer.add_image('train/Predicted_label', grid_image, iter_num) 174 | 175 
| image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 176 | grid_image = make_grid(image, 5, normalize=False) 177 | writer.add_image('train/Groundtruth_label', grid_image, iter_num) 178 | 179 | out_dis_slice = out_dis[0, 0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 180 | grid_image = make_grid(out_dis_slice, 5, normalize=False) 181 | writer.add_image('train/out_dis_map', grid_image, iter_num) 182 | 183 | gt_dis_slice = gt_dis[0, 0,:, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 184 | grid_image = make_grid(gt_dis_slice, 5, normalize=False) 185 | writer.add_image('train/gt_dis_map', grid_image, iter_num) 186 | ## change lr 187 | if iter_num % 2500 == 0: 188 | lr_ = base_lr * 0.1 ** (iter_num // 1000) 189 | for param_group in optimizer.param_groups: 190 | param_group['lr'] = lr_ 191 | if iter_num % 1000 == 0: 192 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') 193 | torch.save(net.state_dict(), save_mode_path) 194 | logging.info("save model to {}".format(save_mode_path)) 195 | 196 | if iter_num > max_iterations: 197 | break 198 | time1 = time.time() 199 | if iter_num > max_iterations: 200 | break 201 | save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth') 202 | torch.save(net.state_dict(), save_mode_path) 203 | logging.info("save model to {}".format(save_mode_path)) 204 | writer.close() 205 | -------------------------------------------------------------------------------- /code/train_LA_Rec_FGDTM_L2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from tqdm import tqdm 4 | from tensorboardX import SummaryWriter 5 | import shutil 6 | import argparse 7 | import logging 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | import torch 13 | import torch.optim as optim 14 | from torchvision import transforms 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | from torch.utils.data import DataLoader 18 | from torchvision.utils import make_grid 19 | 20 | from networks.vnet_rec import VNetRec 21 | from dataloaders.la_heart import LAHeart, RandomCrop, CenterCrop, RandomRotFlip, ToTensor, TwoStreamBatchSampler 22 | from scipy.ndimage import distance_transform_edt as distance 23 | 24 | 25 | """ 26 | Adding reconstruction branch to V-Net 27 | Ref: 28 | A Distance Map Regularized CNN for Cardiac Cine MR Image Segmentation 29 | https://arxiv.org/abs/1901.01238 30 | """ 31 | 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--root_path', type=str, default='../data/2018LA_Seg_Training Set/', help='Name of Experiment') 34 | parser.add_argument('--exp', type=str, default='vnet_dp_la_Rec_FGDTM_L2', help='model_name;dp:add dropout; Rec:Reconstruction') 35 | parser.add_argument('--max_iterations', type=int, default=10000, help='maximum epoch number to train') 36 | parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu') 37 | parser.add_argument('--base_lr', type=float, default=0.01, help='maximum epoch number to train') 38 | parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training') 39 | parser.add_argument('--seed', type=int, default=2019, help='random seed') 40 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 41 | args = parser.parse_args() 42 | 43 | train_data_path = args.root_path 44 | snapshot_path = "../model_la/" + args.exp + "/" 45 | 46 
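# NOTE (editor's illustration, not part of the original script): the staircase
# learning-rate decay these scripts apply in the training loop below. Several
# of the re-implemented scripts pair the `iter_num % 2500 == 0` trigger with
# an `iter_num // 1000` exponent, which decays much faster than the 2500-step
# staircase the trigger suggests; with a consistent divisor the schedule is:
base_lr_demo = 0.01
for it in (2500, 5000, 7500, 10000):
    print(it, base_lr_demo * 0.1 ** (it // 2500))   # 0.001, 0.0001, 1e-05, 1e-06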
| os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 47 | batch_size = args.batch_size * len(args.gpu.split(',')) 48 | max_iterations = args.max_iterations 49 | base_lr = args.base_lr 50 | 51 | if args.deterministic: 52 | cudnn.benchmark = False 53 | cudnn.deterministic = True 54 | random.seed(args.seed) 55 | np.random.seed(args.seed) 56 | torch.manual_seed(args.seed) 57 | torch.cuda.manual_seed(args.seed) 58 | 59 | patch_size = (112, 112, 80) 60 | num_classes = 2 61 | 62 | def dice_loss(score, target): 63 | target = target.float() 64 | smooth = 1e-5 65 | intersect = torch.sum(score * target) 66 | y_sum = torch.sum(target * target) 67 | z_sum = torch.sum(score * score) 68 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 69 | loss = 1 - loss 70 | return loss 71 | 72 | def compute_dtm(img_gt, out_shape): 73 | """ 74 | compute the distance transform map of the foreground in a binary mask 75 | input: segmentation, shape = (batch_size, x, y, z); 76 | out_shape = (batch_size, 1, x, y, z) 77 | output: the foreground distance transform map (DTM) 78 | dtm(x) = 0 if x is on the segmentation boundary; 79 | inf_{y in boundary} ||x-y|| if x is inside the segmentation 80 | """ 81 | fg_dtm = np.zeros(out_shape) 82 | 83 | for b in range(out_shape[0]): # batch size 84 | for c in range(out_shape[1]): 85 | posmask = img_gt[b].astype(bool) # builtin bool: np.bool is deprecated 86 | if posmask.any(): 87 | posdis = distance(posmask) 88 | fg_dtm[b][c] = posdis 89 | 90 | return fg_dtm 91 | 92 | 93 | 94 | if __name__ == "__main__": 95 | ## make logger file 96 | if not os.path.exists(snapshot_path): 97 | os.makedirs(snapshot_path) 98 | if os.path.exists(snapshot_path + '/code'): 99 | shutil.rmtree(snapshot_path + '/code') 100 | shutil.copytree('.', snapshot_path + '/code', shutil.ignore_patterns('.git', '__pycache__')) # ignore_patterns takes patterns as separate args, not a list 101 | 102 | logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, 103 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 104 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 105 | logging.info(str(args)) 106 | 107 | net = VNetRec(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True) 108 | net = net.cuda() 109 | 110 | db_train = LAHeart(base_dir=train_data_path, 111 | split='train', 112 | num=16, 113 | transform = transforms.Compose([ 114 | RandomRotFlip(), 115 | RandomCrop(patch_size), 116 | ToTensor(), 117 | ])) 118 | 119 | def worker_init_fn(worker_id): 120 | random.seed(args.seed+worker_id) 121 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn) 122 | 123 | net.train() 124 | optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001) 125 | 126 | writer = SummaryWriter(snapshot_path+'/log', flush_secs=2) 127 | logging.info("{} iterations per epoch".format(len(trainloader))) 128 | 129 | iter_num = 0 130 | max_epoch = max_iterations//len(trainloader)+1 131 | lr_ = base_lr 132 | net.train() 133 | for epoch_num in tqdm(range(max_epoch), ncols=70): 134 | for i_batch, sampled_batch in enumerate(trainloader): 135 | # fetch the paired input 136 | volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 137 | volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 138 | outputs, out_dis = net(volume_batch) 139 | 140 | with torch.no_grad(): 141 | gt_dis = compute_dtm(label_batch.cpu().numpy(), out_dis.shape) 142 | gt_dis = torch.from_numpy(gt_dis).float().cuda() 143 | 144 | # compute CE + Dice loss 145 | loss_ce = F.cross_entropy(outputs, label_batch) 146 | 
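# NOTE (editor's illustration, not part of the original script): a quick
# sanity check of the soft Dice loss defined above,
# loss = 1 - (2*sum(s*t) + eps) / (sum(s*s) + sum(t*t) + eps):
# a perfect prediction scores ~0 and a fully disjoint one ~1.
import torch

t = torch.tensor([[0., 1., 1., 0.]])
print(dice_loss(t, t))      # ~0: perfect overlap
print(dice_loss(1 - t, t))  # ~1: no overlap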
outputs_soft = F.softmax(outputs, dim=1) 147 | loss_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1) 148 | # compute L2 Loss 149 | loss_dist = F.mse_loss(out_dis, gt_dis) 150 | 151 | loss = loss_ce + loss_dice + loss_dist 152 | 153 | optimizer.zero_grad() 154 | loss.backward() 155 | optimizer.step() 156 | 157 | iter_num = iter_num + 1 158 | writer.add_scalar('lr', lr_, iter_num) 159 | writer.add_scalar('loss/loss_ce', loss_ce, iter_num) 160 | writer.add_scalar('loss/loss_dice', loss_dice, iter_num) 161 | writer.add_scalar('loss/loss_dist', loss_dist, iter_num) 162 | writer.add_scalar('loss/loss', loss, iter_num) 163 | logging.info('iteration %d : loss_dist : %f' % (iter_num, loss_dist.item())) 164 | logging.info('iteration %d : loss_dice : %f' % (iter_num, loss_dice.item())) 165 | logging.info('iteration %d : loss : %f' % (iter_num, loss.item())) 166 | if iter_num % 2 == 0: 167 | image = volume_batch[0, 0:1, :, :, 20:61:10].permute(3,0,1,2).repeat(1,3,1,1) 168 | grid_image = make_grid(image, 5, normalize=True) 169 | writer.add_image('train/Image', grid_image, iter_num) 170 | 171 | outputs_soft = F.softmax(outputs, 1) 172 | image = outputs_soft[0, 1:2, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 173 | grid_image = make_grid(image, 5, normalize=False) 174 | writer.add_image('train/Predicted_label', grid_image, iter_num) 175 | 176 | image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 177 | grid_image = make_grid(image, 5, normalize=False) 178 | writer.add_image('train/Groundtruth_label', grid_image, iter_num) 179 | 180 | out_dis_slice = out_dis[0, 0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 181 | grid_image = make_grid(out_dis_slice, 5, normalize=True) 182 | writer.add_image('train/out_dis_map', grid_image, iter_num) 183 | 184 | gt_dis_slice = gt_dis[0, 0,:, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 185 | grid_image = make_grid(gt_dis_slice, 5, normalize=True) 186 | writer.add_image('train/gt_dis_map', grid_image, iter_num) 187 | ## change lr 188 | if iter_num % 2500 == 0: 189 | lr_ = base_lr * 0.1 ** (iter_num // 1000) 190 | for param_group in optimizer.param_groups: 191 | param_group['lr'] = lr_ 192 | if iter_num % 1000 == 0: 193 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') 194 | torch.save(net.state_dict(), save_mode_path) 195 | logging.info("save model to {}".format(save_mode_path)) 196 | 197 | if iter_num > max_iterations: 198 | break 199 | time1 = time.time() 200 | if iter_num > max_iterations: 201 | break 202 | save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth') 203 | torch.save(net.state_dict(), save_mode_path) 204 | logging.info("save model to {}".format(save_mode_path)) 205 | writer.close() 206 | -------------------------------------------------------------------------------- /code/train_LITS.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from tqdm import tqdm 4 | from tensorboardX import SummaryWriter 5 | import shutil 6 | import argparse 7 | import logging 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | import torch 13 | import torch.optim as optim 14 | from torchvision import transforms 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | from torch.utils.data import DataLoader 18 | from torchvision.utils import make_grid 19 | 20 | from networks.vnet import VNet 21 | from 
utils.losses import dice_loss 22 | from dataloaders.livertumor import LiverTumor, RandomCrop, CenterCrop, RandomRotFlip, ToTensor, TwoStreamBatchSampler 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--root_path', type=str, default='../data/LITS', help='Name of Experiment') 27 | parser.add_argument('--exp', type=str, default='vnet_supervisedonly_dp', help='model_name') 28 | parser.add_argument('--max_iterations', type=int, default=50000, help='maximum epoch number to train') 29 | parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu') 30 | parser.add_argument('--base_lr', type=float, default=0.01, help='maximum epoch number to train') 31 | parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training') 32 | parser.add_argument('--seed', type=int, default=2019, help='random seed') 33 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 34 | args = parser.parse_args() 35 | 36 | train_data_path = args.root_path 37 | snapshot_path = "../model_lits/" + args.exp + "/" 38 | 39 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 40 | batch_size = args.batch_size * len(args.gpu.split(',')) 41 | max_iterations = args.max_iterations 42 | base_lr = args.base_lr 43 | 44 | if args.deterministic: 45 | cudnn.benchmark = False 46 | cudnn.deterministic = True 47 | random.seed(args.seed) 48 | np.random.seed(args.seed) 49 | torch.manual_seed(args.seed) 50 | torch.cuda.manual_seed(args.seed) 51 | 52 | patch_size = (96, 128, 160) 53 | num_classes = 2 54 | 55 | if __name__ == "__main__": 56 | ## make logger file 57 | if not os.path.exists(snapshot_path): 58 | os.makedirs(snapshot_path) 59 | if os.path.exists(snapshot_path + '/code'): 60 | shutil.rmtree(snapshot_path + '/code') 61 | shutil.copytree('.', snapshot_path + '/code', shutil.ignore_patterns(['.git','__pycache__'])) 62 | 63 | logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, 64 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 65 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 66 | logging.info(str(args)) 67 | 68 | net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True) 69 | net = net.cuda() 70 | 71 | db_train = LiverTumor(base_dir=train_data_path, 72 | split='train', 73 | transform = transforms.Compose([ 74 | RandomRotFlip(), 75 | RandomCrop(patch_size), 76 | ToTensor(), 77 | ])) 78 | # db_test = LiverTumor(base_dir=train_data_path, 79 | # split='test', 80 | # transform = transforms.Compose([ 81 | # CenterCrop(patch_size), 82 | # ToTensor() 83 | # ])) 84 | def worker_init_fn(worker_id): 85 | random.seed(args.seed+worker_id) 86 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn) 87 | 88 | net.train() 89 | optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001) 90 | 91 | writer = SummaryWriter(snapshot_path+'/log', flush_secs=2) 92 | logging.info("{} itertations per epoch".format(len(trainloader))) 93 | 94 | iter_num = 0 95 | max_epoch = max_iterations//len(trainloader)+1 96 | lr_ = base_lr 97 | net.train() 98 | for epoch_num in tqdm(range(max_epoch), ncols=70): 99 | time1 = time.time() 100 | for i_batch, sampled_batch in enumerate(trainloader): 101 | time2 = time.time() 102 | # print('fetch data cost {}'.format(time2-time1)) 103 | volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 104 | volume_batch, 
label_batch = volume_batch.cuda(), label_batch.cuda() 105 | outputs = net(volume_batch) 106 | 107 | loss_ce = F.cross_entropy(outputs, label_batch) 108 | outputs_soft = F.softmax(outputs, dim=1) 109 | loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1) 110 | loss = loss_ce+loss_seg_dice 111 | 112 | optimizer.zero_grad() 113 | loss.backward() 114 | optimizer.step() 115 | 116 | iter_num = iter_num + 1 117 | writer.add_scalar('lr', lr_, iter_num) 118 | writer.add_scalar('loss/loss_ce', loss_ce, iter_num) 119 | writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num) 120 | writer.add_scalar('loss/loss', loss, iter_num) 121 | logging.info('iteration %d : loss_seg_dice : %f' % (iter_num, loss_seg_dice.item())) 122 | logging.info('iteration %d : loss : %f' % (iter_num, loss.item())) 123 | if iter_num % 2 == 0: 124 | image = volume_batch[0, 0:1, 30:71:10, :, :].permute(1, 0, 2, 3).repeat(1,3,1,1) 125 | grid_image = make_grid(image, 5, normalize=True) 126 | writer.add_image('train/Image', grid_image, iter_num) 127 | 128 | outputs_soft = F.softmax(outputs, 1) 129 | image = outputs_soft[0, 1:2, 30:71:10, :, :].permute(1, 0, 2, 3).repeat(1, 3, 1, 1) 130 | grid_image = make_grid(image, 5, normalize=False) 131 | writer.add_image('train/Predicted_label', grid_image, iter_num) 132 | 133 | image = label_batch[0, 30:71:10, :, :].unsqueeze(0).permute(1, 0, 2, 3).repeat(1, 3, 1, 1) 134 | grid_image = make_grid(image, 5, normalize=False) 135 | writer.add_image('train/Groundtruth_label', grid_image, iter_num) 136 | 137 | ## change lr 138 | if iter_num % 2500 == 0: 139 | lr_ = base_lr * 0.1 ** (iter_num // 2500) 140 | for param_group in optimizer.param_groups: 141 | param_group['lr'] = lr_ 142 | if iter_num % 1000 == 0: 143 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth') 144 | torch.save(net.state_dict(), save_mode_path) 145 | logging.info("save model to {}".format(save_mode_path)) 146 | 147 | if iter_num > max_iterations: 148 | break 149 | time1 = time.time() 150 | if iter_num > max_iterations: 151 | break 152 | save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth') 153 | torch.save(net.state_dict(), save_mode_path) 154 | logging.info("save model to {}".format(save_mode_path)) 155 | writer.close() 156 | -------------------------------------------------------------------------------- /code/train_LITS_Rec_SDF_L1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from tqdm import tqdm 4 | from tensorboardX import SummaryWriter 5 | import shutil 6 | import argparse 7 | import logging 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | import torch 13 | import torch.optim as optim 14 | from torchvision import transforms 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | from torch.utils.data import DataLoader 18 | from torchvision.utils import make_grid 19 | 20 | from networks.vnet_rec import VNetRec 21 | from dataloaders.livertumor import LiverTumor, RandomCrop, CenterCrop, RandomRotFlip, ToTensor, TwoStreamBatchSampler 22 | from scipy.ndimage import distance_transform_edt as distance 23 | from skimage import segmentation as skimage_seg 24 | 25 | """ 26 | Adding reconstruction branch to V-Net 27 | Ref: 28 | A Distance Map Regularized CNN for Cardiac Cine MR Image Segmentation 29 | https://arxiv.org/abs/1901.01238 30 | """ 31 | 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--root_path', 
type=str, default='../data/LITS', help='path to the training data') 34 | parser.add_argument('--exp', type=str, default='vnet_lits_Rec_SDF_L1_lr01', help='model name; dp: add dropout; Rec: reconstruction branch') 35 | parser.add_argument('--max_iterations', type=int, default=20000, help='maximum number of iterations to train') 36 | parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu') 37 | parser.add_argument('--base_lr', type=float, default=0.01, help='base learning rate') 38 | parser.add_argument('--deterministic', type=int, default=1, help='whether to use deterministic training') 39 | parser.add_argument('--seed', type=int, default=2019, help='random seed') 40 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 41 | args = parser.parse_args() 42 | 43 | train_data_path = args.root_path 44 | snapshot_path = "../model_lits/" + args.exp + "/" 45 | 46 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 47 | batch_size = args.batch_size * len(args.gpu.split(',')) 48 | max_iterations = args.max_iterations 49 | base_lr = args.base_lr 50 | 51 | if args.deterministic: 52 | cudnn.benchmark = False 53 | cudnn.deterministic = True 54 | random.seed(args.seed) 55 | np.random.seed(args.seed) 56 | torch.manual_seed(args.seed) 57 | torch.cuda.manual_seed(args.seed) 58 | 59 | patch_size = (96, 128, 160) 60 | num_classes = 2 61 | 62 | def dice_loss(score, target): 63 | target = target.float() 64 | smooth = 1e-5 65 | intersect = torch.sum(score * target) 66 | y_sum = torch.sum(target * target) 67 | z_sum = torch.sum(score * score) 68 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 69 | loss = 1 - loss 70 | return loss 71 | 72 | def compute_sdf(img_gt, out_shape): 73 | """ 74 | compute the signed distance map of a binary mask 75 | input: segmentation, shape = (batch_size, x, y, z) 76 | output: the Signed Distance Map (SDM) 77 | sdf(x) = 0 if x is on the segmentation boundary; 78 | -inf_{y in boundary} ||x-y|| if x is inside the segmentation; 79 | +inf_{y in boundary} ||x-y|| if x is outside the segmentation; 80 | the positive and negative parts are each min-max normalized, so the sdf lies in [-1,1] 81 | 82 | """ 83 | 84 | img_gt = img_gt.astype(np.uint8) 85 | normalized_sdf = np.zeros(out_shape) 86 | 87 | for b in range(out_shape[0]): # batch size 88 | for c in range(out_shape[1]): 89 | posmask = img_gt[b].astype(bool) # builtin bool: np.bool is deprecated 90 | if posmask.any(): 91 | negmask = ~posmask 92 | posdis = distance(posmask) 93 | negdis = distance(negmask) 94 | boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8) 95 | sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis)) # assumes the crop contains background voxels as well 96 | sdf[boundary==1] = 0 97 | normalized_sdf[b][c] = sdf 98 | 99 | return normalized_sdf 100 | 101 | if __name__ == "__main__": 102 | ## make logger file 103 | if not os.path.exists(snapshot_path): 104 | os.makedirs(snapshot_path) 105 | if os.path.exists(snapshot_path + '/code'): 106 | shutil.rmtree(snapshot_path + '/code') 107 | shutil.copytree('.', snapshot_path + '/code', shutil.ignore_patterns('.git', '__pycache__')) # ignore_patterns takes patterns as separate args, not a list 108 | 109 | logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, 110 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 111 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 112 | logging.info(str(args)) 113 | 114 | net = VNetRec(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False) 115 | net = net.cuda() 116 | 117 | db_train = LiverTumor(base_dir=train_data_path, 118 | split='train', 119 | transform = transforms.Compose([ 120 | RandomRotFlip(), 121 | 
122 |                               ToTensor(),
123 |                           ]))
124 | 
125 |     def worker_init_fn(worker_id):
126 |         random.seed(args.seed+worker_id)
127 |     trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
128 | 
129 |     net.train()
130 |     optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
131 | 
132 |     writer = SummaryWriter(snapshot_path+'/log', flush_secs=2)
133 |     logging.info("{} iterations per epoch".format(len(trainloader)))
134 | 
135 |     iter_num = 0
136 |     max_epoch = max_iterations//len(trainloader)+1
137 |     lr_ = base_lr
138 |     net.train()
139 |     for epoch_num in tqdm(range(max_epoch), ncols=70):
140 |         for i_batch, sampled_batch in enumerate(trainloader):
141 |             # generate paired input
142 |             volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
143 |             volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
144 |             outputs, out_dis = net(volume_batch)
145 |             out_dis = torch.tanh(out_dis)
146 | 
147 |             with torch.no_grad():
148 |                 gt_dis = compute_sdf(label_batch.cpu().numpy(), out_dis.shape)
149 |                 gt_dis = torch.from_numpy(gt_dis).float().cuda()
150 | 
151 |             # compute CE + Dice loss
152 |             loss_ce = F.cross_entropy(outputs, label_batch)
153 |             outputs_soft = F.softmax(outputs, dim=1)
154 |             loss_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1)
155 |             # compute L1 loss on the predicted SDM
156 |             loss_dist = torch.norm(out_dis-gt_dis, 1)/torch.numel(out_dis)
157 | 
158 |             loss = loss_ce + loss_dice + loss_dist
159 | 
160 |             optimizer.zero_grad()
161 |             loss.backward()
162 |             optimizer.step()
163 | 
164 |             iter_num = iter_num + 1
165 |             writer.add_scalar('lr', lr_, iter_num)
166 |             writer.add_scalar('loss/loss_ce', loss_ce, iter_num)
167 |             writer.add_scalar('loss/loss_dice', loss_dice, iter_num)
168 |             writer.add_scalar('loss/loss_dist', loss_dist, iter_num)
169 |             writer.add_scalar('loss/loss', loss, iter_num)
170 |             logging.info('iteration %d : loss_dist : %f' % (iter_num, loss_dist.item()))
171 |             logging.info('iteration %d : loss_dice : %f' % (iter_num, loss_dice.item()))
172 |             logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
173 |             if iter_num % 2 == 0:
174 |                 image = volume_batch[0, 0:1, 30:71:10, :, :].permute(1, 0, 2, 3).repeat(1,3,1,1)
175 |                 grid_image = make_grid(image, 5, normalize=True)
176 |                 writer.add_image('train/Image', grid_image, iter_num)
177 | 
178 |                 outputs_soft = F.softmax(outputs, 1)
179 |                 image = outputs_soft[0, 1:2, 30:71:10, :, :].permute(1, 0, 2, 3).repeat(1, 3, 1, 1)
180 |                 grid_image = make_grid(image, 5, normalize=False)
181 |                 writer.add_image('train/Predicted_label', grid_image, iter_num)
182 | 
183 |                 image = label_batch[0, 30:71:10, :, :].unsqueeze(0).permute(1, 0, 2, 3).repeat(1, 3, 1, 1)
184 |                 grid_image = make_grid(image, 5, normalize=False)
185 |                 writer.add_image('train/Groundtruth_label', grid_image, iter_num)
186 | 
187 |                 out_dis_slice = out_dis[0, 0, 30:71:10, :, :].unsqueeze(0).permute(1, 0, 2, 3).repeat(1, 3, 1, 1)
188 |                 grid_image = make_grid(out_dis_slice, 5, normalize=False)
189 |                 writer.add_image('train/out_dis_map', grid_image, iter_num)
190 | 
191 |                 gt_dis_slice = gt_dis[0, 0, 30:71:10, :, :].unsqueeze(0).permute(1, 0, 2, 3).repeat(1, 3, 1, 1)
192 |                 grid_image = make_grid(gt_dis_slice, 5, normalize=False)
193 |                 writer.add_image('train/gt_dis_map', grid_image, iter_num)
194 |             ## change lr
195 |             if iter_num % 2500 == 0:
196 |                 lr_ = base_lr * 0.1 ** (iter_num // 2500)
197 |                 for param_group in optimizer.param_groups:
198 |                     param_group['lr'] = lr_
199 |             if iter_num % 1000 == 0:
200 |                 save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth')
201 |                 torch.save(net.state_dict(), save_mode_path)
202 |                 logging.info("save model to {}".format(save_mode_path))
203 | 
204 |             if iter_num > max_iterations:
205 |                 break
206 |             time1 = time.time()
207 |         if iter_num > max_iterations:
208 |             break
209 |     save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth')
210 |     torch.save(net.state_dict(), save_mode_path)
211 |     logging.info("save model to {}".format(save_mode_path))
212 |     writer.close()
213 | 
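> `compute_sdf` above normalizes each distance term before subtraction, so the ground-truth map lives on the same `[-1, 1]` scale as the `tanh`-activated network output. A minimal sanity check (illustrative only, not part of the repository; it assumes `compute_sdf` from the script above is in scope):

```python
import numpy as np

# toy (b, c, x, y, z) label volume with a 4x4x4 foreground cube
gt = np.zeros((1, 1, 8, 8, 8), dtype=np.uint8)
gt[0, 0, 2:6, 2:6, 2:6] = 1

# the training scripts pass (b, x, y, z) labels plus the (b, 1, x, y, z) output shape
sdm = compute_sdf(gt[:, 0], gt.shape)

assert -1.0 <= sdm.min() <= sdm.max() <= 1.0  # normalized to [-1, 1]
assert (sdm[gt == 1] <= 0).all()              # non-positive inside the mask
assert (sdm[gt == 0] >= 0).all()              # non-negative outside the mask
```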
--------------------------------------------------------------------------------
/code/train_LITS_Rec_SDF_L1PlusL2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from tqdm import tqdm
4 | from tensorboardX import SummaryWriter
5 | import shutil
6 | import argparse
7 | import logging
8 | import time
9 | import random
10 | import numpy as np
11 | 
12 | import torch
13 | import torch.optim as optim
14 | from torchvision import transforms
15 | import torch.nn.functional as F
16 | import torch.backends.cudnn as cudnn
17 | from torch.utils.data import DataLoader
18 | from torchvision.utils import make_grid
19 | 
20 | from networks.vnet_rec import VNetRec
21 | from dataloaders.livertumor import LiverTumor, RandomCrop, CenterCrop, RandomRotFlip, ToTensor, TwoStreamBatchSampler
22 | from scipy.ndimage import distance_transform_edt as distance
23 | from skimage import segmentation as skimage_seg
24 | 
25 | """
26 | Adding a reconstruction branch to V-Net
27 | Ref:
28 | A Distance Map Regularized CNN for Cardiac Cine MR Image Segmentation
29 | https://arxiv.org/abs/1901.01238
30 | """
31 | 
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--root_path', type=str, default='../data/LITS', help='path of the training data')
34 | parser.add_argument('--exp', type=str, default='vnet_lits_Rec_SDF_L1PlusL2_lr01', help='model name; dp: add dropout; Rec: reconstruction branch')
35 | parser.add_argument('--max_iterations', type=int, default=20000, help='maximum iteration number to train')
36 | parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu')
37 | parser.add_argument('--base_lr', type=float, default=0.01, help='base learning rate')
38 | parser.add_argument('--deterministic', type=int, default=1, help='whether to use deterministic training')
39 | parser.add_argument('--seed', type=int, default=2019, help='random seed')
40 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use')
41 | args = parser.parse_args()
42 | 
43 | train_data_path = args.root_path
44 | snapshot_path = "../model_lits/" + args.exp + "/"
45 | 
46 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
47 | batch_size = args.batch_size * len(args.gpu.split(','))
48 | max_iterations = args.max_iterations
49 | base_lr = args.base_lr
50 | 
51 | if args.deterministic:
52 |     cudnn.benchmark = False
53 |     cudnn.deterministic = True
54 |     random.seed(args.seed)
55 |     np.random.seed(args.seed)
56 |     torch.manual_seed(args.seed)
57 |     torch.cuda.manual_seed(args.seed)
58 | 
59 | patch_size = (96, 128, 160)
60 | num_classes = 2
61 | 
62 | def dice_loss(score, target):
63 |     target = target.float()
64 |     smooth = 1e-5
65 |     intersect = torch.sum(score * target)
66 |     y_sum = torch.sum(target * target)
67 |     z_sum = torch.sum(score * score)
68 |     loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
69 |     loss = 1 - loss
70 |     return loss
71 | 
72 | def compute_sdf(img_gt, out_shape):
73 |     """
74 |     compute the signed distance map of the binary mask
75 |     input: segmentation of shape (batch_size, x, y, z); out_shape = (batch_size, c, x, y, z)
76 |     output: the Signed Distance Map (SDM) of shape out_shape
77 |     sdf(x) = 0                          if x is on the segmentation boundary
78 |            = -min_{y on boundary}|x-y|  if x is inside the segmentation
79 |            = +min_{y on boundary}|x-y|  if x is outside the segmentation
80 |     both distance terms are normalized to [0, 1] first, so the SDM lies in [-1, 1]
81 | 
82 |     """
83 | 
84 |     img_gt = img_gt.astype(np.uint8)
85 |     normalized_sdf = np.zeros(out_shape)
86 | 
87 |     for b in range(out_shape[0]):  # batch size
88 |         for c in range(out_shape[1]):  # output channels (1 in these scripts)
89 |             posmask = img_gt[b].astype(bool)
90 |             if posmask.any():
91 |                 negmask = ~posmask
92 |                 posdis = distance(posmask)
93 |                 negdis = distance(negmask)
94 |                 boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
95 |                 sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis))
96 |                 sdf[boundary==1] = 0
97 |                 normalized_sdf[b][c] = sdf
98 | 
99 |     return normalized_sdf
100 | 
101 | if __name__ == "__main__":
102 |     ## make logger file
103 |     if not os.path.exists(snapshot_path):
104 |         os.makedirs(snapshot_path)
105 |     if os.path.exists(snapshot_path + '/code'):
106 |         shutil.rmtree(snapshot_path + '/code')
107 |     shutil.copytree('.', snapshot_path + '/code', ignore=shutil.ignore_patterns('.git', '__pycache__'))
108 | 
109 |     logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO,
110 |                         format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
111 |     logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
112 |     logging.info(str(args))
113 | 
114 |     net = VNetRec(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False)
115 |     net = net.cuda()
116 | 
117 |     db_train = LiverTumor(base_dir=train_data_path,
118 |                           split='train',
119 |                           transform = transforms.Compose([
120 |                               RandomRotFlip(),
121 |                               RandomCrop(patch_size),
122 |                               ToTensor(),
123 |                           ]))
124 | 
125 |     def worker_init_fn(worker_id):
126 |         random.seed(args.seed+worker_id)
127 |     trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
128 | 
129 |     net.train()
130 |     optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
131 | 
132 |     writer = SummaryWriter(snapshot_path+'/log', flush_secs=2)
133 |     logging.info("{} iterations per epoch".format(len(trainloader)))
134 | 
135 |     iter_num = 0
136 |     max_epoch = max_iterations//len(trainloader)+1
137 |     lr_ = base_lr
138 |     net.train()
139 |     for epoch_num in tqdm(range(max_epoch), ncols=70):
140 |         for i_batch, sampled_batch in enumerate(trainloader):
141 |             # generate paired input
142 |             volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
143 |             volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
144 |             outputs, out_dis = net(volume_batch)
145 |             out_dis = torch.tanh(out_dis)
146 | 
147 |             with torch.no_grad():
148 |                 gt_dis = compute_sdf(label_batch.cpu().numpy(), out_dis.shape)
149 |                 gt_dis = torch.from_numpy(gt_dis).float().cuda()
150 | 
151 |             # compute CE + Dice loss
152 |             loss_ce = F.cross_entropy(outputs, label_batch)
153 |             outputs_soft = F.softmax(outputs, dim=1)
154 |             loss_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1)
155 |             # compute L1 + L2 loss on the predicted SDM
156 |             loss_dist = torch.norm(out_dis-gt_dis, 1)/torch.numel(out_dis) + F.mse_loss(out_dis, gt_dis)
157 | 
158 |             loss = loss_ce + loss_dice + loss_dist
159 | 
160 |             optimizer.zero_grad()
161 |             loss.backward()
162 |             optimizer.step()
163 | 
164 |             iter_num = iter_num + 1
165 |             writer.add_scalar('lr', lr_, iter_num)
166 |             writer.add_scalar('loss/loss_ce', loss_ce, iter_num)
167 |             writer.add_scalar('loss/loss_dice', loss_dice, iter_num)
168 |             writer.add_scalar('loss/loss_dist', loss_dist, iter_num)
169 |             writer.add_scalar('loss/loss', loss, iter_num)
170 |             logging.info('iteration %d : loss_dist : %f' % (iter_num, loss_dist.item()))
171 |             logging.info('iteration %d : loss_dice : %f' % (iter_num, loss_dice.item()))
172 |             logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
173 |             if iter_num % 2 == 0:
174 |                 image = volume_batch[0, 0:1, 30:71:10, :, :].permute(1, 0, 2, 3).repeat(1,3,1,1)
175 |                 grid_image = make_grid(image, 5, normalize=True)
176 |                 writer.add_image('train/Image', grid_image, iter_num)
177 | 
178 |                 outputs_soft = F.softmax(outputs, 1)
179 |                 image = outputs_soft[0, 1:2, 30:71:10, :, :].permute(1, 0, 2, 3).repeat(1, 3, 1, 1)
180 |                 grid_image = make_grid(image, 5, normalize=False)
181 |                 writer.add_image('train/Predicted_label', grid_image, iter_num)
182 | 
183 |                 image = label_batch[0, 30:71:10, :, :].unsqueeze(0).permute(1, 0, 2, 3).repeat(1, 3, 1, 1)
184 |                 grid_image = make_grid(image, 5, normalize=False)
185 |                 writer.add_image('train/Groundtruth_label', grid_image, iter_num)
186 | 
187 |                 out_dis_slice = out_dis[0, 0, 30:71:10, :, :].unsqueeze(0).permute(1, 0, 2, 3).repeat(1, 3, 1, 1)
188 |                 grid_image = make_grid(out_dis_slice, 5, normalize=False)
189 |                 writer.add_image('train/out_dis_map', grid_image, iter_num)
190 | 
191 |                 gt_dis_slice = gt_dis[0, 0, 30:71:10, :, :].unsqueeze(0).permute(1, 0, 2, 3).repeat(1, 3, 1, 1)
192 |                 grid_image = make_grid(gt_dis_slice, 5, normalize=False)
193 |                 writer.add_image('train/gt_dis_map', grid_image, iter_num)
194 |             ## change lr
195 |             if iter_num % 2500 == 0:
196 |                 lr_ = base_lr * 0.1 ** (iter_num // 2500)
197 |                 for param_group in optimizer.param_groups:
198 |                     param_group['lr'] = lr_
199 |             if iter_num % 1000 == 0:
200 |                 save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth')
201 |                 torch.save(net.state_dict(), save_mode_path)
202 |                 logging.info("save model to {}".format(save_mode_path))
203 | 
204 |             if iter_num > max_iterations:
205 |                 break
206 |             time1 = time.time()
207 |         if iter_num > max_iterations:
208 |             break
209 |     save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth')
210 |     torch.save(net.state_dict(), save_mode_path)
211 |     logging.info("save model to {}".format(save_mode_path))
212 |     writer.close()
213 | 
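> The three `train_LITS_Rec_SDF_*` scripts are identical except for the distance regression term `loss_dist`. For quick comparison, here are the three variants side by side (a summary sketch, not a file from the repository; `out_dis` is the `tanh` output of the reconstruction branch, `gt_dis` the ground-truth SDM):

```python
import torch
import torch.nn.functional as F

def sdf_loss_l1(out_dis: torch.Tensor, gt_dis: torch.Tensor) -> torch.Tensor:
    # train_LITS_Rec_SDF_L1.py: mean absolute error over all voxels
    return torch.norm(out_dis - gt_dis, 1) / torch.numel(out_dis)

def sdf_loss_l2(out_dis: torch.Tensor, gt_dis: torch.Tensor) -> torch.Tensor:
    # train_LITS_Rec_SDF_L2.py: mean squared error
    return F.mse_loss(out_dis, gt_dis)

def sdf_loss_l1_plus_l2(out_dis: torch.Tensor, gt_dis: torch.Tensor) -> torch.Tensor:
    # train_LITS_Rec_SDF_L1PlusL2.py: sum of the two terms above
    return sdf_loss_l1(out_dis, gt_dis) + sdf_loss_l2(out_dis, gt_dis)
```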
--------------------------------------------------------------------------------
/code/train_LITS_Rec_SDF_L2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from tqdm import tqdm
4 | from tensorboardX import SummaryWriter
5 | import shutil
6 | import argparse
7 | import logging
8 | import time
9 | import random
10 | import numpy as np
11 | 
12 | import torch
13 | import torch.optim as optim
14 | from torchvision import transforms
15 | import torch.nn.functional as F
16 | import torch.backends.cudnn as cudnn
17 | from torch.utils.data import DataLoader
18 | from torchvision.utils import make_grid
19 | 
20 | from networks.vnet_rec import VNetRec
21 | from dataloaders.livertumor import LiverTumor, RandomCrop, CenterCrop, RandomRotFlip, ToTensor, TwoStreamBatchSampler
22 | from scipy.ndimage import distance_transform_edt as distance
23 | from skimage import segmentation as skimage_seg
24 | 
25 | """
26 | Adding a reconstruction branch to V-Net
27 | Ref:
28 | A Distance Map Regularized CNN for Cardiac Cine MR Image Segmentation
29 | https://arxiv.org/abs/1901.01238
30 | """
31 | 
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--root_path', type=str, default='../data/LITS', help='path of the training data')
34 | parser.add_argument('--exp', type=str, default='vnet_lits_Rec_SDF_L2_lr01', help='model name; dp: add dropout; Rec: reconstruction branch')
35 | parser.add_argument('--max_iterations', type=int, default=20000, help='maximum iteration number to train')
36 | parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu')
37 | parser.add_argument('--base_lr', type=float, default=0.01, help='base learning rate')
38 | parser.add_argument('--deterministic', type=int, default=1, help='whether to use deterministic training')
39 | parser.add_argument('--seed', type=int, default=2019, help='random seed')
40 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use')
41 | args = parser.parse_args()
42 | 
43 | train_data_path = args.root_path
44 | snapshot_path = "../model_lits/" + args.exp + "/"
45 | 
46 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
47 | batch_size = args.batch_size * len(args.gpu.split(','))
48 | max_iterations = args.max_iterations
49 | base_lr = args.base_lr
50 | 
51 | if args.deterministic:
52 |     cudnn.benchmark = False
53 |     cudnn.deterministic = True
54 |     random.seed(args.seed)
55 |     np.random.seed(args.seed)
56 |     torch.manual_seed(args.seed)
57 |     torch.cuda.manual_seed(args.seed)
58 | 
59 | patch_size = (96, 128, 160)
60 | num_classes = 2
61 | 
62 | def dice_loss(score, target):
63 |     target = target.float()
64 |     smooth = 1e-5
65 |     intersect = torch.sum(score * target)
66 |     y_sum = torch.sum(target * target)
67 |     z_sum = torch.sum(score * score)
68 |     loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
69 |     loss = 1 - loss
70 |     return loss
71 | 
72 | def compute_sdf(img_gt, out_shape):
73 |     """
74 |     compute the signed distance map of the binary mask
75 |     input: segmentation of shape (batch_size, x, y, z); out_shape = (batch_size, c, x, y, z)
76 |     output: the Signed Distance Map (SDM) of shape out_shape
77 |     sdf(x) = 0                          if x is on the segmentation boundary
78 |            = -min_{y on boundary}|x-y|  if x is inside the segmentation
79 |            = +min_{y on boundary}|x-y|  if x is outside the segmentation
80 |     both distance terms are normalized to [0, 1] first, so the SDM lies in [-1, 1]
81 | 
82 |     """
83 | 
84 |     img_gt = img_gt.astype(np.uint8)
85 |     normalized_sdf = np.zeros(out_shape)
86 | 
87 |     for b in range(out_shape[0]):  # batch size
88 |         for c in range(out_shape[1]):  # output channels (1 in these scripts)
89 |             posmask = img_gt[b].astype(bool)
90 |             if posmask.any():
91 |                 negmask = ~posmask
92 |                 posdis = distance(posmask)
93 |                 negdis = distance(negmask)
94 |                 boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
95 |                 sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis))
96 |                 sdf[boundary==1] = 0
97 |                 normalized_sdf[b][c] = sdf
98 | 
99 |     return normalized_sdf
100 | 
101 | if __name__ == "__main__":
102 |     ## make logger file
103 |     if not os.path.exists(snapshot_path):
104 |         os.makedirs(snapshot_path)
105 |     if os.path.exists(snapshot_path + '/code'):
106 |         shutil.rmtree(snapshot_path + '/code')
107 |     shutil.copytree('.', snapshot_path + '/code', ignore=shutil.ignore_patterns('.git', '__pycache__'))
108 | 
109 |     logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO,
110 |                         format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
111 |     logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
112 |     logging.info(str(args))
113 | 
114 |     net = VNetRec(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False)
115 |     net = net.cuda()
116 | 
117 |     db_train = LiverTumor(base_dir=train_data_path,
118 |                           split='train',
119 |                           transform = transforms.Compose([
120 |                               RandomRotFlip(),
121 |                               RandomCrop(patch_size),
122 |                               ToTensor(),
123 |                           ]))
124 | 
125 |     def worker_init_fn(worker_id):
126 |         random.seed(args.seed+worker_id)
127 |     trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
128 | 
129 |     net.train()
130 |     optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
131 | 
132 |     writer = SummaryWriter(snapshot_path+'/log', flush_secs=2)
133 |     logging.info("{} iterations per epoch".format(len(trainloader)))
134 | 
135 |     iter_num = 0
136 |     max_epoch = max_iterations//len(trainloader)+1
137 |     lr_ = base_lr
138 |     net.train()
139 |     for epoch_num in tqdm(range(max_epoch), ncols=70):
140 |         for i_batch, sampled_batch in enumerate(trainloader):
141 |             # generate paired input
142 |             volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
143 |             volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
144 |             outputs, out_dis = net(volume_batch)
145 |             out_dis = torch.tanh(out_dis)
146 | 
147 |             with torch.no_grad():
148 |                 gt_dis = compute_sdf(label_batch.cpu().numpy(), out_dis.shape)
149 |                 gt_dis = torch.from_numpy(gt_dis).float().cuda()
150 | 
151 |             # compute CE + Dice loss
152 |             loss_ce = F.cross_entropy(outputs, label_batch)
153 |             outputs_soft = F.softmax(outputs, dim=1)
154 |             loss_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1)
155 |             # compute L2 loss on the predicted SDM
156 |             loss_dist = F.mse_loss(out_dis, gt_dis)
157 | 
158 |             loss = loss_ce + loss_dice + loss_dist
159 | 
160 |             optimizer.zero_grad()
161 |             loss.backward()
162 |             optimizer.step()
163 | 
164 |             iter_num = iter_num + 1
165 |             writer.add_scalar('lr', lr_, iter_num)
166 |             writer.add_scalar('loss/loss_ce', loss_ce, iter_num)
167 |             writer.add_scalar('loss/loss_dice', loss_dice, iter_num)
168 |             writer.add_scalar('loss/loss_dist', loss_dist, iter_num)
169 |             writer.add_scalar('loss/loss', loss, iter_num)
170 |             logging.info('iteration %d : loss_dist : %f' % (iter_num, loss_dist.item()))
171 |             logging.info('iteration %d : loss_dice : %f' % (iter_num, loss_dice.item()))
172 |             logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
173 |             if iter_num % 2 == 0:
174 |                 image = volume_batch[0, 0:1, 30:71:10, :, :].permute(1, 0, 2, 3).repeat(1,3,1,1)
175 |                 grid_image = make_grid(image, 5, normalize=True)
176 |                 writer.add_image('train/Image', grid_image, iter_num)
177 | 
178 |                 outputs_soft = F.softmax(outputs, 1)
179 |                 image = outputs_soft[0, 1:2, 30:71:10, :, :].permute(1, 0, 2, 3).repeat(1, 3, 1, 1)
180 |                 grid_image = make_grid(image, 5, normalize=False)
181 |                 writer.add_image('train/Predicted_label', grid_image, iter_num)
182 | 
183 |                 image = label_batch[0, 30:71:10, :, :].unsqueeze(0).permute(1, 0, 2, 3).repeat(1, 3, 1, 1)
184 |                 grid_image = make_grid(image, 5, normalize=False)
185 |                 writer.add_image('train/Groundtruth_label', grid_image, iter_num)
186 | 
187 |                 out_dis_slice = out_dis[0, 0, 30:71:10, :, :].unsqueeze(0).permute(1, 0, 2, 3).repeat(1, 3, 1, 1)
188 |                 grid_image = make_grid(out_dis_slice, 5, normalize=False)
189 |                 writer.add_image('train/out_dis_map', grid_image, iter_num)
190 | 
191 |                 gt_dis_slice = gt_dis[0, 0, 30:71:10, :, :].unsqueeze(0).permute(1, 0, 2, 3).repeat(1, 3, 1, 1)
192 |                 grid_image = make_grid(gt_dis_slice, 5, normalize=False)
193 |                 writer.add_image('train/gt_dis_map', grid_image, iter_num)
194 |             ## change lr
195 |             if iter_num % 2500 == 0:
196 |                 lr_ = base_lr * 0.1 ** (iter_num // 2500)
197 |                 for param_group in optimizer.param_groups:
198 |                     param_group['lr'] = lr_
199 |             if iter_num % 1000 == 0:
200 |                 save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth')
201 |                 torch.save(net.state_dict(), save_mode_path)
202 |                 logging.info("save model to {}".format(save_mode_path))
203 | 
204 |             if iter_num > max_iterations:
205 |                 break
206 |             time1 = time.time()
207 |         if iter_num > max_iterations:
208 |             break
209 |     save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth')
210 |     torch.save(net.state_dict(), save_mode_path)
211 |     logging.info("save model to {}".format(save_mode_path))
212 |     writer.close()
213 | 
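> All of these scripts share the same optimization recipe: SGD with momentum 0.9 and weight decay 1e-4, with the learning rate stepped down by a factor of 10 every 2500 iterations. The schedule in isolation (a sketch of the in-loop update above, not repository code):

```python
def step_lr(base_lr: float, iter_num: int, step: int = 2500) -> float:
    # base_lr * 0.1 ** (iter_num // step), applied whenever iter_num % step == 0
    return base_lr * 0.1 ** (iter_num // step)

# with base_lr = 0.01: iterations 1-2499 run at 0.01,
# 2500-4999 at 0.001, 5000-7499 at 0.0001, and so on.
```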
--------------------------------------------------------------------------------
/code/utils/ramps.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 | 
8 | """Functions for ramping hyperparameters up or down
9 | 
10 | Each function takes the current training step or epoch, and the
11 | ramp length in the same format, and returns a multiplier between
12 | 0 and 1.
13 | """
14 | 
15 | 
16 | import numpy as np
17 | 
18 | 
19 | def sigmoid_rampup(current, rampup_length):
20 |     """Exponential rampup from https://arxiv.org/abs/1610.02242"""
21 |     if rampup_length == 0:
22 |         return 1.0
23 |     else:
24 |         current = np.clip(current, 0.0, rampup_length)
25 |         phase = 1.0 - current / rampup_length
26 |         return float(np.exp(-5.0 * phase * phase))
27 | 
28 | 
29 | def linear_rampup(current, rampup_length):
30 |     """Linear rampup"""
31 |     assert current >= 0 and rampup_length >= 0
32 |     if current >= rampup_length:
33 |         return 1.0
34 |     else:
35 |         return current / rampup_length
36 | 
37 | 
38 | def cosine_rampdown(current, rampdown_length):
39 |     """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
40 |     assert 0 <= current <= rampdown_length
41 |     return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
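> `ramps.py` is third-party utility code (note the Curious AI copyright header); none of the training scripts shown above call it. Its typical role is ramping an auxiliary-loss weight over training, e.g. (hypothetical names and constants, illustrative only):

```python
from utils import ramps

# hypothetical schedule: ramp an auxiliary-loss weight from 0 up to 0.1
# over the first 40 epochs (both constants are illustrative)
MAX_WEIGHT, RAMPUP_EPOCHS = 0.1, 40

def current_aux_weight(epoch: int) -> float:
    return MAX_WEIGHT * ramps.sigmoid_rampup(epoch, RAMPUP_EPOCHS)
```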
--------------------------------------------------------------------------------
/code/utils/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import os
8 | import pickle
9 | 
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torch.utils.data.sampler import Sampler
15 | 
16 | import networks
17 | 
18 | def load_model(path):
19 |     """Loads a model and returns it without the DataParallel wrapper."""
20 |     if os.path.isfile(path):
21 |         print("=> loading checkpoint '{}'".format(path))
22 |         checkpoint = torch.load(path)
23 | 
24 |         # size of the top layer
25 |         N = checkpoint['state_dict']['top_layer.bias'].size()
26 | 
27 |         # build skeleton of the model
28 |         sob = 'sobel.0.weight' in checkpoint['state_dict'].keys()
29 |         model = networks.__dict__[checkpoint['arch']](sobel=sob, out=int(N[0]))
30 | 
31 |         # deal with a dataparallel table
32 |         def rename_key(key):
33 |             if not 'module' in key:
34 |                 return key
35 |             return ''.join(key.split('.module'))
36 | 
37 |         checkpoint['state_dict'] = {rename_key(key): val
38 |                                     for key, val
39 |                                     in checkpoint['state_dict'].items()}
40 | 
41 |         # load weights
42 |         model.load_state_dict(checkpoint['state_dict'])
43 |         print("Loaded")
44 |     else:
45 |         model = None
46 |         print("=> no checkpoint found at '{}'".format(path))
47 |     return model
48 | 
49 | 
50 | class UnifLabelSampler(Sampler):
51 |     """Samples elements uniformly across pseudolabels.
52 |     Args:
53 |         N (int): size of returned iterator.
54 |         images_lists: dict of key (target), value (list of data with this target)
55 |     """
56 | 
57 |     def __init__(self, N, images_lists):
58 |         self.N = N
59 |         self.images_lists = images_lists
60 |         self.indexes = self.generate_indexes_epoch()
61 | 
62 |     def generate_indexes_epoch(self):
63 |         size_per_pseudolabel = int(self.N / len(self.images_lists)) + 1
64 |         res = np.zeros(size_per_pseudolabel * len(self.images_lists))
65 | 
66 |         for i in range(len(self.images_lists)):
67 |             indexes = np.random.choice(
68 |                 self.images_lists[i],
69 |                 size_per_pseudolabel,
70 |                 replace=(len(self.images_lists[i]) <= size_per_pseudolabel)
71 |             )
72 |             res[i * size_per_pseudolabel: (i + 1) * size_per_pseudolabel] = indexes
73 | 
74 |         np.random.shuffle(res)
75 |         return res[:self.N].astype('int')
76 | 
77 |     def __iter__(self):
78 |         return iter(self.indexes)
79 | 
80 |     def __len__(self):
81 |         return self.N
82 | 
83 | 
84 | class AverageMeter(object):
85 |     """Computes and stores the average and current value"""
86 |     def __init__(self):
87 |         self.reset()
88 | 
89 |     def reset(self):
90 |         self.val = 0
91 |         self.avg = 0
92 |         self.sum = 0
93 |         self.count = 0
94 | 
95 |     def update(self, val, n=1):
96 |         self.val = val
97 |         self.sum += val * n
98 |         self.count += n
99 |         self.avg = self.sum / self.count
100 | 
101 | 
102 | def learning_rate_decay(optimizer, t, lr_0):
103 |     for param_group in optimizer.param_groups:
104 |         lr = lr_0 / np.sqrt(1 + lr_0 * param_group['weight_decay'] * t)
105 |         param_group['lr'] = lr
106 | 
107 | 
108 | class Logger():
109 |     """Class to update every epoch to keep track of the results
110 |     Methods:
111 |         - log() log and save
112 |     """
113 | 
114 |     def __init__(self, path):
115 |         self.path = path
116 |         self.data = []
117 | 
118 |     def log(self, train_point):
119 |         self.data.append(train_point)
120 |         with open(os.path.join(self.path), 'wb') as fp:
121 |             pickle.dump(self.data, fp, -1)
122 | 
123 | 
124 | 
125 | 
126 | 
127 | def norm_ip(img, min, max):
128 |     out = torch.clamp(img, min=min, max=max)
129 |     out = (out - min) / (max - min + 1e-5)
130 |     return out
131 | 
132 | 
133 | def norm_range(t, range=None):
134 |     if range is not None:
135 |         return norm_ip(t, range[0], range[1])
136 |     else:
137 | 
return norm_ip(t, float(t.min()), float(t.max())) 138 | 139 | 140 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | - Download heart MRI data [MICCAI 2018 Atrial Segmentation Challenge](http://atriaseg2018.cardiacatlas.org/data/) at [here](https://share.weiyun.com/IjRZfaUg). 2 | - Download liver tumor CT data [here](https://pan.baidu.com/s/1YzM1i0ZaZa5NaUdaSJlz5A). Password: g3fz 3 | 4 | > All the data have been converted to `h5` format. 5 | -------------------------------------------------------------------------------- /data/test.list: -------------------------------------------------------------------------------- 1 | UPT6DX9IQY9JAZ7HJKA7 2 | UTBUJIWZMKP64E3N73YC 3 | ULHWPWKKLTE921LQLH1P 4 | V0MZOWJ6MU3RMRCV9EXR 5 | VDOF02M8ZHEAADFMS6NP 6 | VG4C826RAAKVMV9BQLVD 7 | VIXBEFTNVHZWKAKURJBN 8 | VQ2L3WM8KEVF6L44E6G9 9 | WBG9WYZ1B25WDT5WAT8T 10 | WMDG2EFA6L2SNDZXIRU0 11 | WNPKE0W404QE9AELX1LR 12 | WSJB9P4JCXUVHBOYFVWL 13 | WW8F5CO4S4K5IM5Z7EXX 14 | X18LU5AOBNNDMLTA0JZL 15 | XYDLYJ5CS19FDBVLJIPI 16 | Y7ZU0B2APPF54WG6PDMF 17 | YDKD1HVHSME6NVMA8I39 18 | Z9GMG63CJLL0VW893BB1 19 | ZIJLJAVQV3FJ6JSQOH1E 20 | ZQPMJ4XEC5A4BISD45P1 21 | -------------------------------------------------------------------------------- /data/test_unlabel.list: -------------------------------------------------------------------------------- 1 | 5FKQL4K14KCB72Y8YMC2 2 | 5HH0WPWIY06DLAFOBQ4M 3 | 5QFK2PMHNX7UALK52NNA 4 | 5UB5KFD2PK38Z4LS6W80 5 | 6799D6LEBH3NSRV1KH27 6 | 78NJ5YFQF72BGC8RO51C 7 | 7FUCNXB39F78WTOP5K71 8 | 8GYK8A9MBRC9TV0FVSRA 9 | 8M99G0JLAXG9GLPV0O8G 10 | 8RE90C8H5DKF4V6HO8UU 11 | 8ZG2TRZ81MAWHZPN9KKG 12 | 9DCM2IB45SK6YKQNYUQY 13 | 9DHWWP5Y66VDMPXISZ13 14 | 9DQYTIU00I4JC0OEOKQQ 15 | A11O45O3NAXWM7T2H8CH 16 | A4R1S23KR0KU2WSYHK2X 17 | A5RNNK0A891WUSC2V624 18 | AT5CRO5JUDBWD4RUPXSQ 19 | BNK95S2SJXEGSW7VAKYU 20 | BXJWOUYP2J3EN4U92517 21 | BYSRSI3H4YTWKMM3MADP 22 | BZUFJX66T0W6ZPVTL9DU 23 | CB5P5W7X310NIIVU7UZV 24 | CBIJFVZ5L9BS0LKWE8YL 25 | CCGAKN4EDT72KC8TTJ76 26 | CLXFYOBQDCVXQ9P7YC07 27 | CMPXO4J23G58J53Q98SZ 28 | CZPMV6KWZ4I7IJJP9FOK 29 | DLKXBV73A55ZTSZ0QQI2 30 | DQ5UYBGR5QP6L692QSG6 31 | DYXSCIWHLSUOZIDDSZ40 32 | E2ZMO66WGS74UKXTZPPQ 33 | EJ5V7SPR4961JWD6SS8V 34 | FGM5NIWN3URY4HF4WNUW 35 | GSC9KNY0VEZXFSGWNF25 36 | HVE7DR3CUA2IM3RC6OMA 37 | HZZ4O0BRKF8S0YX3NNF7 38 | I2VZ7N8H9QYNYT7ZZF1Y 39 | IDWWHGWJ5STOQXSDT6GU 40 | IIY6TYJMTJIZRIZLB9YW 41 | IJJY51YW3W4YJJ7DTVTK 42 | IQYKPTWXVV9H0IHB8YXC 43 | JEC6HJ7SQJXBKVREX03F 44 | JGFOLWJF7YCYD8DPHQNH 45 | K32FD6LRSUSSXGS1YUOX 46 | KM5RYAMP4P4ZP6XWP3Q2 47 | KSNYHUBHHUJTYJ14UQZR 48 | LH4FVU3TQDEC87YGN6FL 49 | LJSDNMND9SHKM7Q4IRHJ 50 | MFTDVMBWFNQ3F5KHBRDR 51 | MJHV7F65TB2A76CQLOC3 52 | MVKIPGBKTNSENNP1S4HB 53 | O5TSIKRD4AIB8K84WIR9 54 | OIRDLE32TXZX942FVZMM 55 | P1OTI3IWJUIB5NRLULLH 56 | PVNXUK681N9BY14K4Z86 57 | Q0MEX9ZIKAGJORSPLQ3Y 58 | Q7J0WYM695R9MA285ZW0 59 | QZC1W0FNR19KJFLOCFLH 60 | R8ER97O9UUN77C02VE2J 61 | RSZY41MT2FGDKHWWL5L2 62 | SN4LF8SGBSRQUPTDSX78 63 | TDDI6L3Y0L9VVFP9MNFS 64 | UZUZZT2W9IUSHL6ASOX3 65 | -------------------------------------------------------------------------------- /data/train.list: -------------------------------------------------------------------------------- 1 | 06SR5RBREL16DQ6M8LWS 2 | 0RZDK210BSMWAA6467LU 3 | 1D7CUD1955YZPGK8XHJX 4 | 1GU15S0GJ6PFNARO469W 5 | 1MHBF3G6DCPWHSKG7XCP 6 | 23X6SY44VT9KFHR7S7OC 7 | 2XL5HSFSE93RMOJDRGR4 8 | 38CWS74285MFGZZXR09Z 9 | 3C2QTUNI0852XV7ZH4Q1 10 | 3DA0T2V6JJ2NLUAV6FWM 
11 | 4498CA6DZWELOXCBRYRF 12 | 45C45I6IXAFGNRO067W9 13 | 4CHFJGF6ZUM7CMZTNFQF 14 | 4EPVTT1HPA8U60CDUKXE 15 | 57SGAJMLCTCH92QUA0EE 16 | 5BHTH9RHH3PQT913I59W 17 | 5FKQL4K14KCB72Y8YMC2 18 | 5HH0WPWIY06DLAFOBQ4M 19 | 5QFK2PMHNX7UALK52NNA 20 | 5UB5KFD2PK38Z4LS6W80 21 | 6799D6LEBH3NSRV1KH27 22 | 78NJ5YFQF72BGC8RO51C 23 | 7FUCNXB39F78WTOP5K71 24 | 8GYK8A9MBRC9TV0FVSRA 25 | 8M99G0JLAXG9GLPV0O8G 26 | 8RE90C8H5DKF4V6HO8UU 27 | 8ZG2TRZ81MAWHZPN9KKG 28 | 9DCM2IB45SK6YKQNYUQY 29 | 9DHWWP5Y66VDMPXISZ13 30 | 9DQYTIU00I4JC0OEOKQQ 31 | A11O45O3NAXWM7T2H8CH 32 | A4R1S23KR0KU2WSYHK2X 33 | A5RNNK0A891WUSC2V624 34 | AT5CRO5JUDBWD4RUPXSQ 35 | BNK95S2SJXEGSW7VAKYU 36 | BXJWOUYP2J3EN4U92517 37 | BYSRSI3H4YTWKMM3MADP 38 | BZUFJX66T0W6ZPVTL9DU 39 | CB5P5W7X310NIIVU7UZV 40 | CBIJFVZ5L9BS0LKWE8YL 41 | CCGAKN4EDT72KC8TTJ76 42 | CLXFYOBQDCVXQ9P7YC07 43 | CMPXO4J23G58J53Q98SZ 44 | CZPMV6KWZ4I7IJJP9FOK 45 | DLKXBV73A55ZTSZ0QQI2 46 | DQ5UYBGR5QP6L692QSG6 47 | DYXSCIWHLSUOZIDDSZ40 48 | E2ZMO66WGS74UKXTZPPQ 49 | EJ5V7SPR4961JWD6SS8V 50 | FGM5NIWN3URY4HF4WNUW 51 | GSC9KNY0VEZXFSGWNF25 52 | HVE7DR3CUA2IM3RC6OMA 53 | HZZ4O0BRKF8S0YX3NNF7 54 | I2VZ7N8H9QYNYT7ZZF1Y 55 | IDWWHGWJ5STOQXSDT6GU 56 | IIY6TYJMTJIZRIZLB9YW 57 | IJJY51YW3W4YJJ7DTVTK 58 | IQYKPTWXVV9H0IHB8YXC 59 | JEC6HJ7SQJXBKVREX03F 60 | JGFOLWJF7YCYD8DPHQNH 61 | K32FD6LRSUSSXGS1YUOX 62 | KM5RYAMP4P4ZP6XWP3Q2 63 | KSNYHUBHHUJTYJ14UQZR 64 | LH4FVU3TQDEC87YGN6FL 65 | LJSDNMND9SHKM7Q4IRHJ 66 | MFTDVMBWFNQ3F5KHBRDR 67 | MJHV7F65TB2A76CQLOC3 68 | MVKIPGBKTNSENNP1S4HB 69 | O5TSIKRD4AIB8K84WIR9 70 | OIRDLE32TXZX942FVZMM 71 | P1OTI3IWJUIB5NRLULLH 72 | PVNXUK681N9BY14K4Z86 73 | Q0MEX9ZIKAGJORSPLQ3Y 74 | Q7J0WYM695R9MA285ZW0 75 | QZC1W0FNR19KJFLOCFLH 76 | R8ER97O9UUN77C02VE2J 77 | RSZY41MT2FGDKHWWL5L2 78 | SN4LF8SGBSRQUPTDSX78 79 | TDDI6L3Y0L9VVFP9MNFS 80 | UZUZZT2W9IUSHL6ASOX3 81 | -------------------------------------------------------------------------------- /overview.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunMa11/SegWithDistMap/153dabf3bc5d9e48058e1497857ac6b00c7abab8/overview.PNG --------------------------------------------------------------------------------