├── Fm_eval.py
├── README.md
├── binary.py
├── checkpoint
│   └── nullfile
├── data
│   ├── BraTS2019
│   └── preprocessBraTS.py
├── image
│   ├── F1N.png
│   ├── Trust_E.gif
│   └── nullfile
├── log
│   └── nullfile
├── models
│   ├── criterions.py
│   ├── eval
│   ├── lib
│   │   ├── IntmdSequential.py
│   │   ├── PositionalEncoding.py
│   │   ├── TransBTS_downsample8x_skipconnection.py
│   │   ├── Transformer.py
│   │   ├── UNet3DZoo.py
│   │   ├── Unet_skipconnection.py
│   │   ├── VNet3D.py
│   │   ├── nullfile
│   │   └── seg_eval.py
│   └── trustedseg.py
├── numpyfunctions.py
├── plot.py
├── predict.py
├── results
│   └── nullfile
├── test_uncertainty.py
├── train.py
└── trainTBraTS.py
/Fm_eval.py:
--------------------------------------------------------------------------------
1 | # Evaluation script for S-measure (Sm), F-measure (Fm) and related metrics.
2 | import os
3 | import argparse
4 | import tqdm
5 | import sys
6 |
7 | import numpy as np
8 |
9 | from PIL import Image
10 | from tabulate import tabulate
11 |
12 | filepath = os.path.split(os.path.abspath(__file__))[0]
13 | repopath = os.path.split(filepath)[0]
14 | sys.path.append(repopath)
15 |
16 | from utils.eval_functions import *
17 | from utils.utils import *
18 |
19 | def evaluate(opt, args):
20 | if os.path.isdir(opt.Eval.result_path) is False:
21 | os.makedirs(opt.Eval.result_path)
22 |
23 | method = os.path.split(opt.Eval.pred_root)[-1]
24 | Thresholds = np.linspace(1, 0, 256)
25 | headers = opt.Eval.metrics #['meanDic', 'meanIoU', 'wFm', 'Sm', 'meanEm', 'mae', 'maxEm', 'maxDic', 'maxIoU', 'meanSen', 'maxSen', 'meanSpe', 'maxSpe']
26 | results = []
27 |
28 | if args.verbose is True:
29 | print('#' * 20, 'Start Evaluation', '#' * 20)
30 | datasets = tqdm.tqdm(opt.Eval.datasets, desc='Expr - ' + method, total=len(
31 | opt.Eval.datasets), position=0, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
32 | else:
33 | datasets = opt.Eval.datasets
34 |
35 | for dataset in datasets:
36 | pred_root = os.path.join(opt.Eval.pred_root, dataset)
37 | gt_root = os.path.join(opt.Eval.gt_root, dataset, 'masks')
38 |
39 | preds = os.listdir(pred_root)
40 | gts = os.listdir(gt_root)
41 |
42 | preds.sort()
43 | gts.sort()
44 |
45 | threshold_Fmeasure = np.zeros((len(preds), len(Thresholds)))
46 | threshold_Emeasure = np.zeros((len(preds), len(Thresholds)))
47 | threshold_IoU = np.zeros((len(preds), len(Thresholds)))
48 | # threshold_Precision = np.zeros((len(preds), len(Thresholds)))
49 | # threshold_Recall = np.zeros((len(preds), len(Thresholds)))
50 | threshold_Sensitivity = np.zeros((len(preds), len(Thresholds)))
51 | threshold_Specificity = np.zeros((len(preds), len(Thresholds)))
52 | threshold_Dice = np.zeros((len(preds), len(Thresholds)))
53 |
54 | Smeasure = np.zeros(len(preds))
55 | wFmeasure = np.zeros(len(preds))
56 | MAE = np.zeros(len(preds))
57 |
58 | if args.verbose is True:
59 | samples = tqdm.tqdm(enumerate(zip(preds, gts)), desc=dataset + ' - Evaluation', total=len(
60 | preds), position=1, leave=False, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}')
61 | else:
62 | samples = enumerate(zip(preds, gts))
63 |
64 | for i, sample in samples:
65 | pred, gt = sample
66 | assert os.path.splitext(pred)[0] == os.path.splitext(gt)[0]
67 |
68 | pred_mask = np.array(Image.open(os.path.join(pred_root, pred)))
69 | gt_mask = np.array(Image.open(os.path.join(gt_root, gt)))
70 |
71 | if len(pred_mask.shape) != 2:
72 | pred_mask = pred_mask[:, :, 0]
73 | if len(gt_mask.shape) != 2:
74 | gt_mask = gt_mask[:, :, 0]
75 |
76 | assert pred_mask.shape == gt_mask.shape
77 |
78 | gt_mask = gt_mask.astype(np.float64) / 255
79 | gt_mask = (gt_mask > 0.5).astype(np.float64)
80 |
81 | pred_mask = pred_mask.astype(np.float64) / 255
82 |
83 | Smeasure[i] = StructureMeasure(pred_mask, gt_mask)
84 | wFmeasure[i] = original_WFb(pred_mask, gt_mask)
85 | MAE[i] = np.mean(np.abs(gt_mask - pred_mask))
86 |
87 | threshold_E = np.zeros(len(Thresholds))
88 | threshold_F = np.zeros(len(Thresholds))
89 | threshold_Pr = np.zeros(len(Thresholds))
90 | threshold_Rec = np.zeros(len(Thresholds))
91 | threshold_Iou = np.zeros(len(Thresholds))
92 | threshold_Spe = np.zeros(len(Thresholds))
93 | threshold_Dic = np.zeros(len(Thresholds))
94 |
95 | for j, threshold in enumerate(Thresholds):
96 | threshold_Pr[j], threshold_Rec[j], threshold_Spe[j], threshold_Dic[j], threshold_F[j], threshold_Iou[j] = Fmeasure_calu(pred_mask, gt_mask, threshold)
97 |
98 | Bi_pred = np.zeros_like(pred_mask)
99 | Bi_pred[pred_mask >= threshold] = 1
100 | threshold_E[j] = EnhancedMeasure(Bi_pred, gt_mask)
101 |
102 | threshold_Emeasure[i, :] = threshold_E
103 | threshold_Fmeasure[i, :] = threshold_F
104 | threshold_Sensitivity[i, :] = threshold_Rec
105 | threshold_Specificity[i, :] = threshold_Spe
106 | threshold_Dice[i, :] = threshold_Dic
107 | threshold_IoU[i, :] = threshold_Iou
108 |
109 | result = []
110 |
111 | mae = np.mean(MAE)
112 | Sm = np.mean(Smeasure)
113 | wFm = np.mean(wFmeasure)
114 |
115 | column_E = np.mean(threshold_Emeasure, axis=0)
116 | meanEm = np.mean(column_E)
117 | maxEm = np.max(column_E)
118 |
119 | column_Sen = np.mean(threshold_Sensitivity, axis=0)
120 | meanSen = np.mean(column_Sen)
121 | maxSen = np.max(column_Sen)
122 |
123 | column_Spe = np.mean(threshold_Specificity, axis=0)
124 | meanSpe = np.mean(column_Spe)
125 | maxSpe = np.max(column_Spe)
126 |
127 | column_Dic = np.mean(threshold_Dice, axis=0)
128 | meanDic = np.mean(column_Dic)
129 | maxDic = np.max(column_Dic)
130 |
131 | column_IoU = np.mean(threshold_IoU, axis=0)
132 | meanIoU = np.mean(column_IoU)
133 | maxIoU = np.max(column_IoU)
134 |
135 | # result.extend([meanDic, meanIoU, wFm, Sm, meanEm, mae, maxEm, maxDic, maxIoU, meanSen, maxSen, meanSpe, maxSpe])
136 | # results.append([dataset, *result])
137 |
138 | out = []
139 | for metric in opt.Eval.metrics:
140 | out.append(eval(metric))
141 |
142 | result.extend(out)
143 | results.append([dataset, *result])
144 |
145 | csv = os.path.join(opt.Eval.result_path, 'result_' + dataset + '.csv')
146 | if os.path.isfile(csv) is True:
147 | csv = open(csv, 'a')
148 | else:
149 | csv = open(csv, 'w')
150 | csv.write(','.join(['method', *headers]) + '\n')
151 |
152 | out_str = method + ','
153 | for metric in result:
154 | out_str += '{:.4f}'.format(metric) + ','
155 | out_str += '\n'
156 |
157 | csv.write(out_str)
158 | csv.close()
159 | tab = tabulate(results, headers=['dataset', *headers], floatfmt=".3f")
160 |
161 | if args.verbose is True:
162 | print(tab)
163 | print("#"*20, "End Evaluation", "#"*20)
164 |
165 | return tab
166 |
167 | if __name__ == "__main__":
168 | args = parse_args()
169 | opt = load_config(args.config)
170 | evaluate(opt, args)
171 |
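172 | # Assumed invocation (parse_args and load_config come from utils/, not shown here):
173 | #   python Fm_eval.py --config <eval_config>.yaml --verbose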
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TBraTS
2 | * This repository provides the code for our accepted MICCAI'2022 paper "TBraTS: Trusted Brain Tumor Segmentation"
3 | * Official implementation [TBraTS: Trusted Brain Tumor Segmentation](https://arxiv.org/abs/2206.09309)
4 | * Journal version [[Paper]](https://arxiv.org/abs/2301.00349) [[code]](https://github.com/Cocofeat/UMIS)
5 | * Official video at [MICS2022](https://aim.nuist.edu.cn/events/mics2022.htm) of [TBraTS: Trusted Brain Tumor Segmentation](https://www.bilibili.com/video/BV1nW4y1a7Qp/?spm_id_from=333.337.search-card.all.click&vd_source=6ab19d355475883daafd34a6daae54a5) (**3rd Prize**)
6 |
7 | ## Introduction
8 | Despite recent improvements in the accuracy of brain tumor segmentation, the results still exhibit low levels of confidence and robustness. Uncertainty estimation is one effective way to change this situation, as it provides a measure of confidence in the segmentation results. In this paper, we propose a trusted brain tumor segmentation network which can generate robust segmentation results and reliable uncertainty estimations without excessive computational burden or modification of the backbone network. In our method, uncertainty is modeled explicitly using subjective logic theory, which treats the predictions of the backbone neural network as subjective opinions by parameterizing the class probabilities of the segmentation as a Dirichlet distribution. Meanwhile, the trusted segmentation framework learns a function that gathers reliable evidence from the features and leads to the final segmentation results. Overall, our unified trusted segmentation framework endows the model with reliability and robustness to out-of-distribution samples. To evaluate the effectiveness of our model in terms of robustness and reliability, qualitative and quantitative experiments are conducted on the BraTS 2019 dataset.
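
In code, the subjective-logic step is compact. A minimal sketch (softplus evidence and K=4 classes are assumptions for illustration; the actual implementation in `models/trustedseg.py` may differ):

```python
import torch
import torch.nn.functional as F

def dirichlet_opinion(logits, num_classes=4):
    evidence = F.softplus(logits)        # non-negative evidence per class
    alpha = evidence + 1                 # Dirichlet concentration parameters
    S = alpha.sum(dim=1, keepdim=True)   # Dirichlet strength
    belief = evidence / S                # per-class belief mass
    uncertainty = num_classes / S        # voxel-wise uncertainty (belief + u = 1)
    prob = alpha / S                     # expected class probabilities
    return belief, uncertainty, prob

# toy logits for a (batch, class, D, H, W) segmentation volume
b, u, p = dirichlet_opinion(torch.randn(1, 4, 8, 8, 8))
```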
9 |
10 | ![TBraTS framework overview](image/F1N.png)
11 | 
12 | ## Requirements
13 | Some important required packages include:
14 | * PyTorch >= 0.4.1
15 | * Visdom
16 | * Python == 3.7
17 | * Basic Python packages such as NumPy
18 |
19 | ## Data Acquisition
20 | - The multimodal brain tumor datasets (**BraTS 2019**) could be acquired from [here](https://ipp.cbica.upenn.edu/).
21 |
22 | ## Data Preprocess
23 | After downloading the dataset from [here](https://ipp.cbica.upenn.edu/), the data need to be preprocessed: convert the .nii files to .pkl files and perform data normalization.
24 |
25 | Run `python3 data/preprocessBraTS.py`, which is adapted from the [TransBTS](https://github.com/Wenxuan-1119/TransBTS/blob/main/data/preprocess.py) preprocessing script; a sketch of the normalization is shown below.
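
The normalization is a per-modality z-score computed over brain voxels only, leaving the background at zero (see `process_f32b0` in `data/preprocessBraTS.py`); a minimal sketch:

```python
import numpy as np

def zscore_keep_background(images):
    """images: (H, W, D, M) float32 stack of M modalities."""
    mask = images.sum(-1) > 0               # brain voxels (non-zero in any modality)
    for k in range(images.shape[-1]):
        x = images[..., k]                  # view into `images`
        y = x[mask]
        x[mask] = (y - y.mean()) / y.std()  # z-score inside the brain mask only
    return images
```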
26 |
27 | ## Training & Testing
28 | Run `python3 trainTBraTS.py`: train your own backbone (U-Net/V-Net/Attention U-Net/TransBTS) within our trusted framework.
29 |
30 | Run `python3 train.py`: train the backbone alone, without our framework.
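
For the robustness experiments, test inputs can be perturbed with clipped Gaussian noise, as done in the `__main__` block of `data/BraTS2019`:

```python
noise = torch.clamp(torch.randn_like(x) * args.Variance,
                    -args.Variance * 2, args.Variance * 2)
x += noise
```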
31 |
32 | ## :fire: NEWS :fire:
33 | * [09/17] For more experiments on trustworthy medical image segmentation, please refer to [UMIS](https://github.com/Cocofeat/UMIS).
34 | * [09/17] We released all the codes.
35 | * [06/05] We will release the code as soon as possible.
36 | * [06/13] We have uploaded the main part of our code. We will upload all the code after camera-ready.
37 | * [06/22] The preprint of our paper is available at [TBraTS: Trusted Brain Tumor Segmentation](https://arxiv.org/abs/2206.09309)
38 | ## Citation
39 | If you find our work helpful for your research, please consider citing:
40 | ```
41 | @InProceedings{Coco2022TBraTS,
42 | author = {Zou, Ke and Yuan, Xuedong and Shen, Xiaojing and Wang, Meng and Fu, Huazhu},
43 | booktitle = {Medical Image Computing and Computer Assisted Intervention -- MICCAI 2022},
44 | title = {TBraTS: Trusted Brain Tumor Segmentation},
45 | year = {2022},
46 | address = {Cham},
47 | pages = {503--513},
48 | publisher = {Springer Nature Switzerland},
49 | }
50 | ```
51 | ## Acknowledgement
52 | Part of the code is adapted from [TransBTS](https://github.com/Wenxuan-1119/TransBTS) and [TMC](https://github.com/hanmenghan/TMC).
53 |
54 | ## Contact
55 | * If you have any questions about our work, please contact [me](mailto:kezou8@gmail.com)
56 | * Project Link: [TBraTS](https://github.com/Cocofeat/TBraTS/)
57 |
--------------------------------------------------------------------------------
/checkpoint/nullfile:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data/BraTS2019:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import Dataset
4 | import random
5 | import numpy as np
6 | from torchvision.transforms import transforms
7 | from torch.utils.data import DataLoader
8 | import pickle
9 | from scipy import ndimage
10 | import argparse
11 |
12 | def pkload(fname):
13 | with open(fname, 'rb') as f:
14 | return pickle.load(f)
15 |
16 |
17 | class MaxMinNormalization(object):
18 | def __call__(self, sample):
19 | image = sample['image']
20 | label = sample['label']
21 | Max = np.max(image)
22 | Min = np.min(image)
23 | image = (image - Min) / (Max - Min)
24 |
25 | return {'image': image, 'label': label}
26 |
27 |
28 | class Random_Flip(object):
29 | def __call__(self, sample):
30 | image = sample['image']
31 | label = sample['label']
32 | if random.random() < 0.5:
33 | image = np.flip(image, 0)
34 | label = np.flip(label, 0)
35 | if random.random() < 0.5:
36 | image = np.flip(image, 1)
37 | label = np.flip(label, 1)
38 | if random.random() < 0.5:
39 | image = np.flip(image, 2)
40 | label = np.flip(label, 2)
41 |
42 | return {'image': image, 'label': label}
43 |
44 |
45 | class Random_Crop(object):
46 | def __call__(self, sample):
47 | image = sample['image']
48 | label = sample['label']
49 | H = random.randint(0, 240 - 128)
50 | W = random.randint(0, 240 - 128)
51 | D = random.randint(0, 160 - 128)
52 |
53 | image = image[H: H + 128, W: W + 128, D: D + 128, ...]
54 | label = label[..., H: H + 128, W: W + 128, D: D + 128]
55 |
56 | return {'image': image, 'label': label}
57 |
58 |
59 | class Random_intencity_shift(object):
60 | def __call__(self, sample, factor=0.1):
61 | image = sample['image']
62 | label = sample['label']
63 |
64 | scale_factor = np.random.uniform(1.0-factor, 1.0+factor, size=[1, image.shape[1], 1])
65 | shift_factor = np.random.uniform(-factor, factor, size=[1, image.shape[1], 1])
66 |
67 | image = image*scale_factor+shift_factor
68 |
69 | return {'image': image, 'label': label}
70 |
71 | class Random_intencity_shiftboth(object):
72 | def __call__(self, sample, factor=0.1):
73 | image = sample['image']
74 | label = sample['label']
75 |
76 | scale_factor = np.random.uniform(1.0-factor, 1.0+factor, size=[1, image.shape[1], 1, image.shape[-1]])
77 | shift_factor = np.random.uniform(-factor, factor, size=[1, image.shape[1], 1, image.shape[-1]])
78 |
79 | image = image*scale_factor+shift_factor
80 |
81 | return {'image': image, 'label': label}
82 |
83 | class Random_rotate(object):
84 | def __call__(self, sample):
85 | image = sample['image']
86 | label = sample['label']
87 |
88 | angle = round(np.random.uniform(-10, 10), 2)
89 | image = ndimage.rotate(image, angle, axes=(0, 1), reshape=False)
90 | label = ndimage.rotate(label, angle, axes=(0, 1), reshape=False)
91 |
92 | return {'image': image, 'label': label}
93 |
94 |
95 | class Pad(object):
96 | def __call__(self, sample):
97 | image = sample['image']
98 | label = sample['label']
99 |
100 | image = np.pad(image, ((0, 0), (0, 0), (0, 5)), mode='constant')
101 | label = np.pad(label, ((0, 0), (0, 0), (0, 5)), mode='constant')
102 | return {'image': image, 'label': label}
103 | # Pad: (240, 240, 155) -> (240, 240, 160)
104 | class Padboth(object):
105 | def __call__(self, sample):
106 | image = sample['image']
107 | label = sample['label']
108 |
109 | image = np.pad(image, ((0, 0), (0, 0), (0, 5), (0, 0)), mode='constant')
110 | label = np.pad(label, ((0, 0), (0, 0), (0, 5)), mode='constant')
111 | return {'image': image, 'label': label}
112 | # Padboth: (240, 240, 155, n) -> (240, 240, 160, n)
113 |
114 | class ToTensor(object):
115 | """Convert ndarrays in sample to Tensors."""
116 | def __call__(self, sample):
117 | image = sample['image']
118 | image = np.ascontiguousarray(image)
119 | label = sample['label']
120 | label = np.ascontiguousarray(label)
121 |
122 | image = torch.from_numpy(image).float().unsqueeze(0)
123 | label = torch.from_numpy(label).long()
124 |
125 | return {'image': image, 'label': label}
126 |
127 | class ToTensorboth(object):
128 | """Convert ndarrays in sample to Tensors."""
129 | def __call__(self, sample):
130 | image = sample['image']
131 | image = np.ascontiguousarray(image.transpose(3, 0, 1, 2))
132 | label = sample['label']
133 | label = np.ascontiguousarray(label)
134 |
135 | image = torch.from_numpy(image).float()
136 | label = torch.from_numpy(label).long()
137 |
138 | return {'image': image, 'label': label}
139 |
140 | def transform(sample):
141 | trans = transforms.Compose([
142 | Pad(),
143 | # Random_rotate(), # time-consuming
144 | Random_Crop(),
145 | Random_Flip(),
146 | Random_intencity_shift(),
147 | ToTensor()
148 | ])
149 |
150 | return trans(sample)
151 |
152 |
153 | def transform_valid(sample):
154 | trans = transforms.Compose([
155 | Pad(),
156 | # MaxMinNormalization(),
157 | ToTensor()
158 | ])
159 |
160 | return trans(sample)
161 |
162 | def transformboth(sample):
163 | trans = transforms.Compose([
164 | Padboth(),
165 | # Random_rotate(), # time-consuming
166 | Random_Crop(),
167 | Random_Flip(),
168 | Random_intencity_shiftboth(),
169 | ToTensorboth()
170 | ])
171 |
172 | return trans(sample)
173 |
174 |
175 | def transformboth_valid(sample):
176 | trans = transforms.Compose([
177 | Padboth(),
178 | # MaxMinNormalization(),
179 | ToTensorboth()
180 | ])
181 |
182 | return trans(sample)
183 |
184 | class BraTS(Dataset):
185 | def __init__(self, list_file, root='', mode='train', modal='t1'):
186 | self.lines = []
187 | paths, names = [], []
188 | with open(list_file) as f:
189 | for line in f:
190 | line = line.strip()
191 | name = line.split('/')[-1]
192 | names.append(name)
193 | path = os.path.join(root, line, name + '_')
194 | paths.append(path)
195 | self.lines.append(line)
196 | # changed by coco
197 | # del paths[0:3]
198 | # del names[0:3]
199 | # del self.lines[0:3]
200 | self.mode = mode
201 | self.modal = modal
202 | self.names = names
203 | self.paths = paths
204 |
205 | def __getitem__(self, item):
206 | # the input modality can be chosen: t1, t2, or all four modalities
207 | path = self.paths[item]
208 | if self.mode == 'train':
209 | image, label = pkload(path + 'data_f32b04M.pkl')
210 | # print(np.unique(label))
211 | label[label==4]=3
212 | # print(np.unique(label))
213 | if self.modal == 't1':
214 | sample = {'image': image[..., 0] , 'label': label}
215 | sample = transform(sample)
216 | elif self.modal =='t2':
217 | sample = {'image': image[..., 1], 'label': label}
218 | sample = transform(sample)
219 | else:
220 | sample = {'image': image, 'label': label}
221 | sample = transformboth(sample)
222 | return sample['image'], sample['label']
223 | elif self.mode == 'valid':
224 | image, label = pkload(path + 'data_f32b04M.pkl')
225 | label[label == 4] = 3
226 | if self.modal == 't1':
227 | sample = {'image': image[..., 0], 'label': label}
228 | sample = transform_valid(sample)
229 | elif self.modal =='t2':
230 | sample = {'image': image[..., 1], 'label': label}
231 | sample = transform_valid(sample)
232 | else:
233 | sample = {'image': image, 'label': label}
234 | sample = transformboth_valid(sample)
235 | return sample['image'], sample['label']
236 | else:
237 | image,label = pkload(path + 'data_f32b04M.pkl')
238 | label[label == 4] = 3
239 | if self.modal == 't1':
240 | sample = {'image': image[..., 0], 'label': label}
241 | sample = transform_valid(sample)
242 | elif self.modal =='t2':
243 | sample = {'image': image[..., 1], 'label': label}
244 | sample = transform_valid(sample)
245 | else:
246 | sample = {'image': image, 'label': label}
247 | sample = transformboth_valid(sample)
248 | return sample['image'], sample['label']
249 |
250 | def __len__(self):
251 | return len(self.names)
252 |
253 | def collate(self, batch):
254 | return [torch.cat(v) for v in zip(*batch)]
255 |
256 | if __name__ == '__main__':
257 | parser = argparse.ArgumentParser()
258 | parser.add_argument('--root', default='E:/BraTSdata1/archive2019', type=str)
259 | parser.add_argument('--train_dir', default='MICCAI_BraTS_2019_Data_TTraining', type=str)
260 | parser.add_argument('--valid_dir', default='MICCAI_BraTS_2019_Data_TValidation', type=str)
261 | parser.add_argument('--test_dir', default='MICCAI_BraTS_2019_Data_TTest', type=str)
262 | parser.add_argument('--mode', default='train', type=str)
263 | parser.add_argument('--train_file', default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Ttrain_subject.txt', type=str)
264 | parser.add_argument('--valid_file', default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Tval_subject.txt', type=str)
265 | parser.add_argument('--test_file', default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Ttest_subject.txt', type=str)
266 | parser.add_argument('--dataset', default='brats', type=str)
267 | parser.add_argument('--num_gpu', default= 4, type=int)
268 | parser.add_argument('--num_workers', default=4, type=int)
269 | parser.add_argument('--batch_size', default=8, type=int)
270 | parser.add_argument('--modal', default='both', type=str)
271 | parser.add_argument('--Variance', default=0.1, type=float)
272 | args = parser.parse_args()
273 | train_list = os.path.join(args.root, args.train_dir, args.train_file)
274 | train_root = os.path.join(args.root, args.train_dir)
275 | val_list = os.path.join(args.root, args.valid_dir, args.valid_file)
276 | val_root = os.path.join(args.root, args.valid_dir)
277 | test_list = os.path.join(args.root, args.test_dir, args.test_file)
278 | test_root = os.path.join(args.root, args.test_dir)
279 | train_set = BraTS(train_list, train_root, args.mode,args.modal)
280 | val_set = BraTS(val_list, val_root, args.mode, args.modal)
281 | test_set = BraTS(test_list, test_root, args.mode, args.modal)
282 | # train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
283 | # train_loader = DataLoader(dataset=train_set, sampler=train_sampler, batch_size=args.batch_size // args.num_gpu,
284 | # drop_last=True, num_workers=args.num_workers, pin_memory=True)
285 | train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size)
286 | val_loader = DataLoader(dataset=val_set, batch_size=1)
287 | test_loader = DataLoader(dataset=test_set, batch_size=1)
288 | for i, data in enumerate(train_loader):
289 | x, target = data
290 | if args.mode == 'test':
291 | noise = torch.clamp(torch.randn_like(x) * args.Variance, -args.Variance * 2, args.Variance * 2)
292 | x += noise
293 | # x_no = np.unique(x.numpy())
294 | # target_no = np.unique(target.numpy())
295 |
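296 | # Usage sketch (matches the __main__ above): BraTS(list_file, root, mode='train',
297 | # modal='both') yields (image, label) with image shaped (4, 128, 128, 128) and
298 | # label shaped (128, 128, 128) after the 'both' training transforms.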
--------------------------------------------------------------------------------
/data/preprocessBraTS.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | import numpy as np
4 | import nibabel as nib
5 | import shutil
6 | import random
7 | modalities = ('flair', 't1ce', 't1', 't2')
8 | twomodalities = ('t1', 't2')
9 | # This script saves the two-modality and four-modality image/mask pairs as .pkl files.
10 | # train
11 | train_set = {
12 | 'root': 'C:/Coco_file/BraTSdata/archive2019/MICCAI_BraTS_2019_Data_Training/',
13 | 'flist': 'train.txt',
14 | 'has_label': True
15 | }
16 |
17 | # test/validation splits drawn from the training dataset
18 | Ttrain_set = {
19 | 'root': 'C:/Coco_file/BraTSdata/archive2019/MICCAI_BraTS_2019_Data_TTraining/',
20 | 'flist': 'train.txt',
21 | 'has_label': True
22 | }
23 | Tvalid_set = {
24 | 'root': 'C:/Coco_file/BraTSdata/archive2019/MICCAI_BraTS_2019_Data_TValidation/',
25 | 'has_label': True
26 | }
27 | Ttest_set = {
28 | 'root': 'C:/Coco_file/BraTSdata/archive2019/MICCAI_BraTS_2019_Data_TTest/',
29 | 'has_label': True
30 | }
31 | # valid_set = {
32 | # 'root': 'path to valid set',
33 | # 'flist': 'valid.txt',
34 | # 'has_label': False
35 | # }
36 | # test_set = {
37 | # 'root': 'path to testing set',
38 | # 'flist': 'test.txt',
39 | # 'has_label': False
40 | # }
41 |
42 | def ensure_dir_exists(dir_name):
43 | """Makes sure the folder exists on disk.
44 | Args:
45 | dir_name: Path string to the folder we want to create.
46 | """
47 | if not os.path.exists(dir_name):
48 | os.makedirs(dir_name)
49 |
50 | def save_subjects(root, dir_name, subjects):
51 | f = open(root + dir_name + "_subject.txt", "w")
52 | # one subject path per line (avoid shadowing the builtin `str`)
53 | f.write('\n'.join(subjects))
54 | f.close()
55 |
56 | def nib_load(file_name):
57 | if not os.path.exists(file_name):
58 | raise FileNotFoundError('Invalid file name, cannot find the file: ' + file_name)
59 |
60 | proxy = nib.load(file_name)
61 | data = proxy.get_data()
62 | proxy.uncache()
63 | return data
64 |
65 |
66 | def process_i16(path, has_label=True):
67 | """ Save the original 3D MRI images with dtype=int16.
68 | Noted that no normalization is used! """
69 | label = np.array(nib_load(path + 'seg.nii.gz'), dtype='uint8', order='C')
70 |
71 | images = np.stack([
72 | np.array(nib_load(path + modal + '.nii.gz'), dtype='int16', order='C')
73 | for modal in modalities], -1)  # (240, 240, 155, 4)
74 |
75 | output = path + 'data_i16.pkl'
76 |
77 | with open(output, 'wb') as f:
78 | print(output)
79 | print(images.shape, type(images), label.shape, type(label)) # (240,240,155,4) , (240,240,155)
80 | pickle.dump((images, label), f)
81 |
82 | if not has_label:
83 | return
84 |
85 |
86 | def process_f32b0(path, has_label=True):
87 | """ Save the data with dtype=float32.
88 | z-score is used but keep the background with zero! """
89 | if has_label:
90 | label = np.array(nib_load(path + 'seg.nii'), dtype='uint8', order='C')
91 | images = np.stack([np.array(nib_load(path + modal + '.nii'), dtype='float32', order='C') for modal in modalities], -1)  # (240, 240, 155, 4)
92 |
93 | output = path + 'data_f32b0.pkl'
94 | mask = images.sum(-1) > 0
95 | for k in range(4):
96 |
97 | x = images[..., k] #
98 | y = x[mask]
99 |
100 | # 0.8885
101 | x[mask] -= y.mean()
102 | x[mask] /= y.std()
103 |
104 | images[..., k] = x
105 |
106 | with open(output, 'wb') as f:
107 | print(output)
108 |
109 | if has_label:
110 | pickle.dump((images, label), f)
111 | else:
112 | pickle.dump(images, f)
113 |
114 | if not has_label:
115 | return
116 |
117 | def process_f32b0twomodal(path, has_label=True):
118 | """ Save the data with dtype=float32.
119 | z-score is used but keep the background with zero! """
120 | if has_label:
121 | label = np.array(nib_load(path + 'seg.nii'), dtype='uint8', order='C')
122 | images = np.stack([np.array(nib_load(path + modal + '.nii'), dtype='float32', order='C') for modal in twomodalities], -1)  # (240, 240, 155, 2)
123 |
124 | output = path + 'data_f32b0.pkl'
125 | mask = images.sum(-1) > 0
126 | for k in range(images.shape[3]):
127 |
128 | x = images[..., k] #
129 | y = x[mask]
130 |
131 | # 0.8885
132 | x[mask] -= y.mean()
133 | x[mask] /= y.std()
134 |
135 | images[..., k] = x
136 |
137 | with open(output, 'wb') as f:
138 | print(output)
139 |
140 | if has_label:
141 | pickle.dump((images, label), f)
142 | else:
143 | pickle.dump(images, f)
144 |
145 | if not has_label:
146 | return
147 |
148 | def doit(dset,args_modal):
149 | root, has_label = dset['root'], dset['has_label']
150 | file_list = os.path.join(root, dset['flist'])
151 | subjects = open(file_list).read().splitlines()
152 | names = [sub.split('/')[-1] for sub in subjects]
153 | paths = [os.path.join(root, sub, name + '_') for sub, name in zip(subjects, names)]
154 |
155 | for path in paths:
156 | print(path)
157 | if args_modal =='2':
158 | process_f32b0twomodal(path, has_label) # two modal
159 | else:
160 | process_f32b0(path, has_label) # four modal
161 |
162 | def move_doit(dset,train_dir,valid_dir,test_dir):
163 | ensure_dir_exists(train_dir)
164 | ensure_dir_exists(valid_dir)
165 | ensure_dir_exists(test_dir)
166 | root, has_label = dset['root'], dset['has_label']
167 | file_list = os.path.join(root, dset['flist'])
168 | subjects = open(file_list).read().splitlines()
169 | HGG_subjects = []
170 | LGG_subjects = []
171 | for sub in subjects:
172 | if "HGG" in sub:
173 | HGG_subjects.append(sub)
174 | else:
175 | LGG_subjects.append(sub)
176 | #val + test
177 | valtest_HGGsubjects = random.sample(HGG_subjects, int(len(HGG_subjects)*0.3)+1)
178 | valtest_LGGsubjects = random.sample(LGG_subjects, int(len(LGG_subjects) * 0.3))
179 | # val test HGG
180 | val_HGGsubjects = random.sample(valtest_HGGsubjects, int(len(valtest_HGGsubjects)*0.5))
181 | test_HGGsubjects = list(set(valtest_HGGsubjects).difference(set(val_HGGsubjects)))
182 | # val test LGG
183 | val_LGGsubjects = random.sample(valtest_LGGsubjects, int(len(valtest_LGGsubjects) * 0.5))
184 | test_LGGsubjects = list(set(valtest_LGGsubjects).difference(set(val_LGGsubjects)))
185 | # val
186 | val_HGGsubjects.extend(val_LGGsubjects)
187 | val_subjects = val_HGGsubjects
188 | # test
189 | test_HGGsubjects.extend(test_LGGsubjects)
190 | test_subjects = test_HGGsubjects
191 | # val + test
192 | valtest_HGGsubjects.extend(valtest_LGGsubjects)
193 | valtest_subjects = valtest_HGGsubjects
194 | tran_subjects = list(set(subjects).difference(set(valtest_subjects)))
195 | # save subject name
196 | save_subjects(root, "Ttrain", tran_subjects)
197 | save_subjects(root, "Ttest", test_subjects)
198 | save_subjects(root, "Tval", val_subjects)
199 | # move training dataset to Ttrain,TTest and Tval dataset
200 | for subject in tran_subjects:
201 | shutil.move(root + subject, train_dir + subject)
202 | for subject_test in test_subjects:
203 | shutil.move(root + subject_test, test_dir + subject_test)
204 | for subject_val in val_subjects:
205 | shutil.move(root + subject_val, valid_dir + subject_val)
206 |
207 | def delete_doit(dir):
208 | rootdir = dir
209 | GG_filelist = os.listdir(rootdir)
210 | for file in GG_filelist:
211 | names = os.listdir(rootdir+file)
212 | for name in names:
213 | files = os.listdir(rootdir + '/' +file + '/' + name)
214 | for modal_file in files:
215 | if '.nii' in modal_file:
216 | del_file = rootdir + '/' + file + '/' + name + '/' + modal_file  # use an absolute path in case the script runs from a different folder
217 | os.remove(del_file)  # delete the file
218 | print("Deleted:", del_file)
219 |
220 | if __name__ == '__main__':
221 | args_modal = '4'
222 | doit(train_set,args_modal)
223 | move_doit(train_set,Ttrain_set['root'],Tvalid_set['root'],Ttest_set['root'])
224 | delete_doit(Ttrain_set['root'])
225 | delete_doit(Tvalid_set['root'])
226 | delete_doit(Ttest_set['root'])
227 | # doit(valid_set)
228 | # doit(test_set)
229 |
230 |
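231 | # Pipeline summary (as run in __main__): doit() writes a data_f32b0.pkl next to
232 | # each case, move_doit() splits the training set into TTraining/TValidation/TTest
233 | # subsets, and delete_doit() removes the original .nii files from the moved subsets.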
--------------------------------------------------------------------------------
/image/F1N.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cocofeat/TBraTS/5255b7c6c8d338c2c1af72368169054d889a977c/image/F1N.png
--------------------------------------------------------------------------------
/image/Trust_E.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cocofeat/TBraTS/5255b7c6c8d338c2c1af72368169054d889a977c/image/Trust_E.gif
--------------------------------------------------------------------------------
/image/nullfile:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/log/nullfile:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/models/criterions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from torch import nn
6 | import numpy as np
7 |
8 | def expand_target(x, n_class,mode='softmax'):
9 | """
10 | Converts an NxDxHxW label image to NxCxDxHxW, with each label stored in a separate channel
11 | :param x: 4D input label image (NxDxHxW)
12 | :param n_class: number of channels/labels
13 | :return: 5D one-hot output (NxCxDxHxW)
14 | """
15 | assert x.dim() == 4
16 | shape = list(x.size())
17 | shape.insert(1, n_class)
18 | shape = tuple(shape)
19 | xx = torch.zeros(shape)
20 | if mode.lower() == 'softmax':
21 | xx[:, 1, :, :, :] = (x == 1)
22 | xx[:, 2, :, :, :] = (x == 2)
23 | xx[:, 3, :, :, :] = (x == 3)
24 | if mode.lower() == 'sigmoid':
25 | xx[:, 0, :, :, :] = (x == 1)
26 | xx[:, 1, :, :, :] = (x == 2)
27 | xx[:, 2, :, :, :] = (x == 3)
28 | return xx.to(x.device)
29 |
30 | def flatten(tensor):
31 | """Flattens a given tensor such that the channel axis is first.
32 | The shapes are transformed as follows:
33 | (N, C, D, H, W) -> (C, N * D * H * W)
34 | """
35 | C = tensor.size(1)
36 | # new axis order
37 | axis_order = (1, 0) + tuple(range(2, tensor.dim()))
38 | # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
39 | transposed = tensor.permute(axis_order)
40 | # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
41 | return transposed.reshape(C, -1)
42 |
43 | def Dice(output, target, eps=1e-5):
44 | target = target.float()
45 | num = 2 * (output * target).sum()
46 | den = output.sum() + target.sum() + eps
47 | return 1.0 - num/den
48 |
49 | def sum_tensor(inp, axes, keepdim=False):
50 | axes = np.unique(axes).astype(int)
51 | if keepdim:
52 | for ax in axes:
53 | inp = inp.sum(int(ax), keepdim=True)
54 | else:
55 | for ax in sorted(axes, reverse=True):
56 | inp = inp.sum(int(ax))
57 | return inp
58 |
59 | def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
60 | """
61 | net_output must be (b, c, x, y(, z)))
62 | gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
63 | if mask is provided it must have shape (b, 1, x, y(, z)))
64 | :param net_output:
65 | :param gt:
66 | :param axes: can be (, ) = no summation
67 | :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
68 | :param square: if True then fp, tp and fn will be squared before summation
69 | :return:
70 | """
71 | if axes is None:
72 | axes = tuple(range(2, len(net_output.size())))
73 |
74 | shp_x = net_output.shape
75 | shp_y = gt.shape
76 |
77 | with torch.no_grad():
78 | if len(shp_x) != len(shp_y):
79 | gt = gt.view((shp_y[0], 1, *shp_y[1:]))
80 |
81 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
82 | # if this is the case then gt is probably already a one hot encoding
83 | y_onehot = gt
84 | else:
85 | gt = gt.long()
86 | y_onehot = torch.zeros(shp_x)
87 | if net_output.device.type == "cuda":
88 | y_onehot = y_onehot.cuda(net_output.device.index)
89 | y_onehot.scatter_(1, gt, 1)
90 |
91 | tp = net_output * y_onehot
92 | fp = net_output * (1 - y_onehot)
93 | fn = (1 - net_output) * y_onehot
94 | tn = (1 - net_output) * (1 - y_onehot)
95 |
96 | if mask is not None:
97 | tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
98 | fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
99 | fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
100 | tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1)
101 |
102 | if square:
103 | tp = tp ** 2
104 | fp = fp ** 2
105 | fn = fn ** 2
106 | tn = tn ** 2
107 |
108 | if len(axes) > 0:
109 | tp = sum_tensor(tp, axes, keepdim=False)
110 | fp = sum_tensor(fp, axes, keepdim=False)
111 | fn = sum_tensor(fn, axes, keepdim=False)
112 | tn = sum_tensor(tn, axes, keepdim=False)
113 |
114 | return tp, fp, fn, tn
115 |
116 | class SoftDiceLoss(nn.Module):
117 | def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.):
118 | """
119 | """
120 | super(SoftDiceLoss, self).__init__()
121 |
122 | self.do_bg = do_bg
123 | self.batch_dice = batch_dice
124 | self.apply_nonlin = apply_nonlin
125 | self.smooth = smooth
126 |
127 | def forward(self, x, y, loss_mask=None):
128 | shp_x = x.shape
129 |
130 | if self.batch_dice:
131 | axes = [0] + list(range(2, len(shp_x)))
132 | else:
133 | axes = list(range(2, len(shp_x)))
134 |
135 | if self.apply_nonlin is not None:
136 | x = self.apply_nonlin(x)
137 |
138 | tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, axes, loss_mask, False)
139 |
140 | nominator = 2 * tp + self.smooth
141 | denominator = 2 * tp + fp + fn + self.smooth
142 |
143 | dc = 1 - (nominator / (denominator + 1e-5))
144 |
145 | if not self.do_bg:
146 | if self.batch_dice:
147 | dc = dc[1:]
148 | else:
149 | dc = dc[:, 1:]
150 | dc = dc.mean()
151 | # changed by coco
152 | return dc
153 |
154 | class softBCE_dice(nn.Module):
155 | def __init__(self, aggregate="sum",weight_ce=0.5, weight_dice=0.5):
156 | """
157 | DO NOT APPLY NONLINEARITY IN YOUR NETWORK!
158 |
159 | THIS LOSS IS INTENDED TO BE USED FOR BRATS REGIONS ONLY
160 | :param soft_dice_kwargs:
161 | :param bce_kwargs:
162 | :param aggregate:
163 | """
164 | super(softBCE_dice, self).__init__()
165 |
166 | self.aggregate = aggregate
167 | # self.ce = nn.BCEWithLogitsLoss(**bce_kwargs)
168 | self.Bce = torch.nn.BCELoss()
169 | self.weight_dice = weight_dice
170 | self.weight_ce = weight_ce
171 | self.sigmoid = nn.Sigmoid()
172 | self.dc = SoftDiceLoss()
173 | def forward(self, output, target):
174 | # ce_loss = self.ce(net_output, target)
175 | Bce_loss1 = self.Bce(self.sigmoid(output[:, 1, ...]), (target == 1).float())
176 | Bce_loss2 = self.Bce(self.sigmoid(output[:, 2, ...]), (target == 2).float())
177 | Bce_loss3 = self.Bce(self.sigmoid(output[:, 3, ...]), (target == 3).float())
178 | # Diceloss1 = Dice(output[:, 1, ...], (target == 1).float())
179 | # Diceloss2 = Dice(output[:, 2, ...], (target == 2).float())
180 | # Diceloss3 = Dice(output[:, 3, ...], (target == 4).float())
181 | Diceloss1 = self.dc(output[:, 1, ...], (target == 1).float())
182 | Diceloss2 = self.dc(output[:, 2, ...], (target == 2).float())
183 | Diceloss3 = self.dc(output[:, 3, ...], (target == 3).float())
184 | if self.aggregate == "sum":
185 | result1 = self.weight_ce * Bce_loss1 + self.weight_dice * Diceloss1
186 | result2 = self.weight_ce * Bce_loss2 + self.weight_dice * Diceloss2
187 | result3 = self.weight_ce * Bce_loss3 + self.weight_dice * Diceloss3
188 | else:
189 | raise NotImplementedError("aggregate mode '%s' is not implemented" % self.aggregate)  # reserved for other aggregation modes
190 |
191 | return result1+result2+result3, 1-result1.data, 1-result2.data, 1-result3.data
192 |
193 | def softmaxBCE_dice(output, target):
194 | '''
195 | The dice loss for using softmax activation function
196 | :param output: (b, num_class, d, h, w)
197 | :param target: (b, d, h, w)
198 | :return: softmax dice loss torch.nn.BCELoss()
199 | '''
200 | Bce = torch.nn.BCELoss()
201 | Diceloss1 = Dice(output[:, 1, ...], (target == 1).float())
202 | Bceloss1 = Bce(output[:, 1, ...], (target == 1).float())
203 | loss1 = Diceloss1 + Bceloss1
204 | Diceloss2 = Dice(output[:, 2, ...], (target == 2).float())
205 | Bceloss2 = Bce(output[:, 2, ...], (target == 2).float())
206 | loss2 = Diceloss2 + Bceloss2
207 | Diceloss3 = Dice(output[:, 3, ...], (target == 3).float())
208 | Bceloss3 = Bce(output[:, 3, ...], (target == 3).float())
209 | loss3 = Diceloss3 + Bceloss3
210 |
211 | return loss1 + loss2 + loss3, 1-loss1.data, 1-loss2.data, 1-loss3.data
212 |
213 | # loss function
214 |
215 |
216 |
217 | def TDice(output, target,criterion_dl):
218 | dice = criterion_dl(output, target)
219 | return dice
220 |
221 | def TFocal(output, target,criterion_fl):
222 | focal = criterion_fl(output, target)
223 | return focal
224 |
225 | def focal_dce_eviloss(p, alpha, c, global_step, annealing_step):
226 | # dice focal loss
227 | criterion_dl = DiceLoss()
228 | L_dice = TDice(alpha,p,criterion_dl)
229 | criterion_fl = FocalLoss(4)
230 | L_focal = TFocal(alpha, p, criterion_fl)
231 | alpha = alpha.view(alpha.size(0), alpha.size(1), -1) # [N, C, HW]
232 | alpha = alpha.transpose(1, 2) # [N, HW, C]
233 | alpha = alpha.contiguous().view(-1, alpha.size(2))
234 | S = torch.sum(alpha, dim=1, keepdim=True)
235 | E = alpha - 1
236 | label = F.one_hot(p, num_classes=c)
237 | label = label.view(-1, c)
238 | # digama loss
239 | L_ace = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True)
240 | # log loss
241 | # labelK = label * (torch.log(S) - torch.log(alpha))
242 | # L_ace = torch.sum(label * (torch.log(S) - torch.log(alpha)), dim=1, keepdim=True)
243 |
244 | annealing_coef = min(1, global_step / annealing_step)
245 | alp = E * (1 - label) + 1
246 | L_KL = annealing_coef * KL(alp, c)
247 |
248 | return (L_ace + L_dice + L_focal + L_KL)
249 |
250 | def dce_eviloss(p, alpha, c, global_step, annealing_step):
251 | criterion_dl = DiceLoss()
252 | # L_dice = TDice(alpha,p,criterion_dl)
253 | L_dice,_,_,_ = softmax_dice(alpha, p)
254 | alpha = alpha.view(alpha.size(0), alpha.size(1), -1) # [N, C, HW]
255 | alpha = alpha.transpose(1, 2) # [N, HW, C]
256 | alpha = alpha.contiguous().view(-1, alpha.size(2))
257 | S = torch.sum(alpha, dim=1, keepdim=True)
258 | E = alpha - 1
259 | label = F.one_hot(p, num_classes=c)
260 | label = label.view(-1, c)
261 | # digama loss
262 | L_ace = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True)
263 | # log loss
264 | # labelK = label * (torch.log(S) - torch.log(alpha))
265 | # L_ace = torch.sum(label * (torch.log(S) - torch.log(alpha)), dim=1, keepdim=True)
266 |
267 | annealing_coef = min(1, global_step / annealing_step)
268 | alp = E * (1 - label) + 1
269 | L_KL = annealing_coef * KL(alp, c)
270 |
271 | return (L_ace + L_dice + L_KL)
272 |
273 |
274 | def dce_loss(p, alpha, c, global_step, annealing_step):
275 | criterion_dl = DiceLoss()
276 | L_dice = TDice(alpha,p,criterion_dl)
277 |
278 | return L_dice
279 | def ce_loss(p, alpha, c, global_step, annealing_step):
280 | alpha = alpha.view(alpha.size(0), alpha.size(1), -1) # [N, C, HW]
281 | alpha = alpha.transpose(1, 2) # [N, HW, C]
282 | alpha = alpha.contiguous().view(-1, alpha.size(2))
283 |
284 | # 0.0 permute all
285 | # alpha = alpha.permute(0,2,3,4,1).view(-1, c)
286 | S = torch.sum(alpha, dim=1, keepdim=True)
287 | E = alpha - 1
288 | label = F.one_hot(p, num_classes=c)
289 | label = label.view(-1, c)
290 | # S = S.permute(0, 2, 3, 4, 1)
291 | # alpha = alpha.permute(0, 2, 3, 4, 1)
292 | # label_K = label * (torch.digamma(S) - torch.digamma(alpha))
293 | # digama loss
294 | L_ace = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True)
295 | # log loss
296 | # labelK = label * (torch.log(S) - torch.log(alpha))
297 | # L_ace = torch.sum(label * (torch.log(S) - torch.log(alpha)), dim=1, keepdim=True)
298 |
299 | annealing_coef = min(1, global_step / annealing_step)
300 | # label = label.permute(0, 4, 1, 2, 3)
301 | alp = E * (1 - label) + 1
302 | # alp = E.permute(0, 2, 3, 4, 1) * (1 - label) + 1
303 | L_KL = annealing_coef * KL(alp, c)
304 |
305 | return (L_ace + L_KL)
306 | # return L_ace
307 |
308 | def KL(alpha, c):
309 | S_alpha = torch.sum(alpha, dim=1, keepdim=True)
310 | beta = torch.ones((1, c)).cuda()
311 | # Mbeta = torch.ones((alpha.shape[0],c)).cuda()
312 | S_beta = torch.sum(beta, dim=1, keepdim=True)
313 | lnB = torch.lgamma(S_alpha) - torch.sum(torch.lgamma(alpha), dim=1, keepdim=True)
314 | lnB_uni = torch.sum(torch.lgamma(beta), dim=1, keepdim=True) - torch.lgamma(S_beta)
315 | dg0 = torch.digamma(S_alpha)
316 | dg1 = torch.digamma(alpha)
317 | kl = torch.sum((alpha - beta) * (dg1 - dg0), dim=1, keepdim=True) + lnB + lnB_uni
318 | return kl
319 |
320 | def mse_loss(p, alpha, c, global_step, annealing_step=1):
321 | S = torch.sum(alpha, dim=1, keepdim=True)
322 | E = alpha - 1
323 | m = alpha / S
324 | label = F.one_hot(p, num_classes=c)
325 | A = torch.sum((label - m) ** 2, dim=1, keepdim=True)
326 | B = torch.sum(alpha * (S - alpha) / (S * S * (S + 1)), dim=1, keepdim=True)
327 | annealing_coef = min(1, global_step / annealing_step)
328 | alp = E * (1 - label) + 1
329 | C = annealing_coef * KL(alp, c)
330 | return (A + B) + C
331 |
332 | class DiceLoss(nn.Module):
333 |
334 | def __init__(self, alpha=0.5, beta=0.5, size_average=True, reduce=True):
335 | super(DiceLoss, self).__init__()
336 | self.alpha = alpha
337 | self.beta = beta
338 |
339 | self.size_average = size_average
340 | self.reduce = reduce
341 |
342 | def forward(self, preds, targets, weight=False):
343 | N = preds.size(0)
344 | C = preds.size(1)
345 |
346 | preds = preds.permute(0, 2, 3, 4, 1).contiguous().view(-1, C)
347 | if targets.size(1)==4:
348 | targets = targets.permute(0, 2, 3, 4, 1).contiguous().view(-1, C)
349 | else:
350 | targets = targets.view(-1, 1)
351 |
352 | log_P = F.log_softmax(preds, dim=1)
353 | P = torch.exp(log_P)
354 | # P = F.softmax(preds, dim=1)
355 | smooth = torch.zeros(C, dtype=torch.float32).fill_(0.00001)
356 |
357 | class_mask = torch.zeros(preds.shape).to(preds.device) + 1e-8
358 | class_mask.scatter_(1, targets, 1.)
359 |
360 | ones = torch.ones(preds.shape).to(preds.device)
361 | P_ = ones - P
362 | class_mask_ = ones - class_mask
363 |
364 | TP = P * class_mask
365 | FP = P * class_mask_
366 | FN = P_ * class_mask
367 |
368 | smooth = smooth.to(preds.device)
369 | self.alpha = FP.sum(dim=(0)) / ((FP.sum(dim=(0)) + FN.sum(dim=(0))) + smooth)
370 |
371 | self.alpha = torch.clamp(self.alpha, min=0.2, max=0.8)
372 | #print('alpha:', self.alpha)
373 | self.beta = 1 - self.alpha
374 | num = torch.sum(TP, dim=(0)).float()
375 | den = num + self.alpha * torch.sum(FP, dim=(0)).float() + self.beta * torch.sum(FN, dim=(0)).float()
376 |
377 | dice = num / (den + smooth)
378 |
379 | if not self.reduce:
380 | loss = torch.ones(C).to(dice.device) - dice
381 | return loss
382 | loss = 1 - dice
383 | if weight is not False:
384 | loss *= weight.squeeze(0)
385 | loss = loss.sum()
386 | if self.size_average:
387 | if weight is not False:
388 | loss /= weight.squeeze(0).sum()
389 | else:
390 | loss /= C
391 |
392 | return loss
393 |
394 | class FocalLoss(nn.Module):
395 | def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
396 | super(FocalLoss, self).__init__()
397 |
398 | if alpha is None:
399 | self.alpha = torch.ones(class_num, 1).cuda()
400 | else:
401 | self.alpha = alpha
402 |
403 | self.gamma = gamma
404 | self.size_average = size_average
405 |
406 | def forward(self, preds, targets, weight=False):
407 | N = preds.size(0)
408 | C = preds.size(1)
409 |
410 | preds = preds.permute(0, 2, 3, 4, 1).contiguous().view(-1, C)
411 | targets = targets.view(-1, 1)
412 |
413 | log_P = F.log_softmax(preds, dim=1)
414 | P = torch.exp(log_P)
415 | # P = F.softmax(preds, dim=1)
416 | # log_P = F.log_softmax(preds, dim=1)
417 | # class_mask = torch.zeros(preds.shape).to(preds.device) + 1e-8
418 | class_mask = torch.zeros(preds.shape).to(preds.device) # problem
419 | class_mask.scatter_(1, targets, 1.)
420 | # number = torch.unique(targets)
421 | alpha = self.alpha[targets.data.view(-1)] # problem alpha: weight of data
422 | # alpha = self.alpha.gather(0, targets.view(-1))
423 |
424 | probs = (P * class_mask).sum(1).view(-1, 1) # problem
425 | log_probs = (log_P * class_mask).sum(1).view(-1, 1)
426 |
427 | # probs = P.gather(1, targets.view(-1, 1))  # this implements nll_loss (cross-entropy = log_softmax + nll)
428 | # log_probs = log_P.gather(1,targets.view(-1,1))
429 |
430 | batch_loss = -alpha * (1-probs).pow(self.gamma)*log_probs
431 | if weight is not False:
432 | element_weight = weight.squeeze(0)[targets.squeeze(0)]
433 | batch_loss = batch_loss * element_weight
434 |
435 | if self.size_average:
436 | loss = batch_loss.mean()
437 | else:
438 | loss = batch_loss.sum()
439 |
440 | return loss
441 |
442 | def softmax_dice(output, target):
443 | '''
444 | The dice loss for using softmax activation function
445 | :param output: (b, num_class, d, h, w)
446 | :param target: (b, d, h, w)
447 | :return: softmax dice loss torch.nn.BCELoss()
448 | '''
449 |
450 | loss1 = Dice(output[:, 1, ...], (target == 1).float())
451 | loss2 = Dice(output[:, 2, ...], (target == 2).float())
452 | loss3 = Dice(output[:, 3, ...], (target == 3).float())
453 |
454 | return loss1 + loss2 + loss3, 1-loss1.data, 1-loss2.data, 1-loss3.data
455 |
456 | def softmax_dice2(output, target):
457 | '''
458 | The dice loss for using softmax activation function
459 | :param output: (b, num_class, d, h, w)
460 | :param target: (b, d, h, w)
461 | :return: softmax dice loss
462 | '''
463 | loss0 = Dice(output[:, 0, ...], (target == 0).float())
464 | loss1 = Dice(output[:, 1, ...], (target == 1).float())
465 | loss2 = Dice(output[:, 2, ...], (target == 2).float())
466 | loss3 = Dice(output[:, 3, ...], (target == 3).float())
467 |
468 | return loss1 + loss2 + loss3 + loss0, 1-loss1.data, 1-loss2.data, 1-loss3.data
469 |
470 |
471 | def sigmoid_dice(output, target):
472 | '''
473 | The dice loss for using sigmoid activation function
474 | :param output: (b, num_class-1, d, h, w)
475 | :param target: (b, d, h, w)
476 | :return:
477 | '''
478 | loss1 = Dice(output[:, 0, ...], (target == 1).float())
479 | loss2 = Dice(output[:, 1, ...], (target == 2).float())
480 | loss3 = Dice(output[:, 2, ...], (target == 3).float())
481 |
482 | return loss1 + loss2 + loss3, 1-loss1.data, 1-loss2.data, 1-loss3.data
483 |
484 |
485 | def Generalized_dice(output, target, eps=1e-5, weight_type='square'):
486 | if target.dim() == 4: #(b, h, w, d)
487 | target[target == 4] = 3 #transfer label 4 to 3
488 | target = expand_target(target, n_class=output.size()[1]) #extend target from (b, h, w, d) to (b, c, h, w, d)
489 |
490 | output = flatten(output)[1:, ...] # transpose [N,4,H,W,D] -> [4,N,H,W,D] -> [3, N*H*W*D] voxels
491 | target = flatten(target)[1:, ...] # [class, N*H*W*D]
492 |
493 | target_sum = target.sum(-1)  # per-class foreground voxel counts, shape [3]
494 | if weight_type == 'square':
495 | class_weights = 1. / (target_sum * target_sum + eps)
496 | elif weight_type == 'identity':
497 | class_weights = 1. / (target_sum + eps)
498 | elif weight_type == 'sqrt':
499 | class_weights = 1. / (torch.sqrt(target_sum) + eps)
500 | else:
501 | raise ValueError('Check out the weight_type :', weight_type)
502 |
503 | # print(class_weights)
504 | intersect = (output * target).sum(-1)
505 | intersect_sum = (intersect * class_weights).sum()
506 | denominator = (output + target).sum(-1)
507 | denominator_sum = (denominator * class_weights).sum() + eps
508 |
509 | loss1 = 2*intersect[0] / (denominator[0] + eps)
510 | loss2 = 2*intersect[1] / (denominator[1] + eps)
511 | loss3 = 2*intersect[2] / (denominator[2] + eps)
512 |
513 | return 1 - 2. * intersect_sum / denominator_sum, loss1, loss2, loss3
514 |
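515 | # Usage sketch (an assumption; the training scripts are not shown in this section):
516 | #   evidence = F.softplus(backbone_logits); alpha = evidence + 1
517 | #   loss = dce_eviloss(target, alpha, c=4, global_step=step, annealing_step=total_steps)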
--------------------------------------------------------------------------------
/models/eval:
--------------------------------------------------------------------------------
1 | # Forked from https://github.com/GewelsJI/VPS/blob/main/eval/metrics.py
2 | import numpy as np
3 | from scipy.ndimage import convolve, distance_transform_edt as bwdist
4 |
5 |
6 | _EPS = np.spacing(1)
7 | _TYPE = np.float64
8 |
9 |
10 | def _prepare_data(pred: np.ndarray, gt: np.ndarray) -> tuple:
11 | gt = gt > 128
12 | pred = pred / 255
13 | if pred.max() != pred.min():
14 | pred = (pred - pred.min()) / (pred.max() - pred.min())
15 | return pred, gt
16 |
17 |
18 | def _get_adaptive_threshold(matrix: np.ndarray, max_value: float = 1) -> float:
19 | return min(2 * matrix.mean(), max_value)
20 |
21 |
22 | class Fmeasure(object):
23 | def __init__(self, length, beta: float = 0.3):
24 | self.beta = beta
25 | self.precisions = []
26 | self.recalls = []
27 | self.adaptive_fms = []
28 | self.changeable_fms = []
29 |
30 | def step(self, pred: np.ndarray, gt: np.ndarray, idx):
31 | pred, gt = _prepare_data(pred, gt)
32 |
33 | adaptive_fm = self.cal_adaptive_fm(pred=pred, gt=gt)
34 | self.adaptive_fms.append(adaptive_fm)
35 |
36 | precisions, recalls, changeable_fms = self.cal_pr(pred=pred, gt=gt)
37 | self.precisions.append(precisions)
38 | self.recalls.append(recalls)
39 | self.changeable_fms.append(changeable_fms)
40 |
41 | def cal_adaptive_fm(self, pred: np.ndarray, gt: np.ndarray) -> float:
42 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1)
43 | binary_prediction = pred >= adaptive_threshold
44 | area_intersection = binary_prediction[gt].sum()
45 | if area_intersection == 0:
46 | adaptive_fm = 0
47 | else:
48 | pre = area_intersection / np.count_nonzero(binary_prediction)
49 | rec = area_intersection / np.count_nonzero(gt)
50 | # F_beta measure
51 | adaptive_fm = (1 + self.beta) * pre * rec / (self.beta * pre + rec)
52 | return adaptive_fm
53 |
54 | def cal_pr(self, pred: np.ndarray, gt: np.ndarray) -> tuple:
55 | pred = (pred * 255).astype(np.uint8)
56 | bins = np.linspace(0, 256, 257)
57 | fg_hist, _ = np.histogram(pred[gt], bins=bins)
58 | bg_hist, _ = np.histogram(pred[~gt], bins=bins)
59 | fg_w_thrs = np.cumsum(np.flip(fg_hist), axis=0)
60 | bg_w_thrs = np.cumsum(np.flip(bg_hist), axis=0)
61 | TPs = fg_w_thrs
62 | Ps = fg_w_thrs + bg_w_thrs
63 | Ps[Ps == 0] = 1
64 | T = max(np.count_nonzero(gt), 1)
65 | precisions = TPs / Ps
66 | recalls = TPs / T
67 | numerator = (1 + self.beta) * precisions * recalls
68 | denominator = np.where(numerator == 0, 1, self.beta * precisions + recalls)
69 | changeable_fms = numerator / denominator
70 | return precisions, recalls, changeable_fms
71 |
72 | def get_results(self):
73 | adaptive_fm = np.mean(np.array(self.adaptive_fms, _TYPE))
74 | # precision = np.mean(np.array(self.precisions, dtype=_TYPE), axis=0) # N, 256
75 | # recall = np.mean(np.array(self.recalls, dtype=_TYPE), axis=0) # N, 256
76 | changeable_fm = np.mean(np.array(self.changeable_fms, dtype=_TYPE), axis=0)
77 | # return dict(fm=dict(adp=adaptive_fm, curve=changeable_fm),
78 | # pr=dict(p=precision, r=recall))
79 | return dict(adpFm=adaptive_fm, meanFm=changeable_fm, maxFm=changeable_fm)
80 |
81 |
82 | class MAE(object):
83 | def __init__(self, length):
84 | self.maes = []
85 |
86 | def step(self, pred: np.ndarray, gt: np.ndarray, idx):
87 | pred, gt = _prepare_data(pred, gt)
88 |
89 | mae = self.cal_mae(pred, gt)
90 | self.maes.append(mae)
91 |
92 | def cal_mae(self, pred: np.ndarray, gt: np.ndarray) -> float:
93 | mae = np.mean(np.abs(pred - gt))
94 | return mae
95 |
96 | def get_results(self):
97 | mae = np.mean(np.array(self.maes, _TYPE))
98 | return dict(MAE=mae)
99 |
100 |
101 | class Smeasure(object):
102 | def __init__(self, length, alpha: float = 0.5):
103 | self.sms = []
104 | self.alpha = alpha
105 |
106 | def step(self, pred: np.ndarray, gt: np.ndarray, idx):
107 | pred, gt = _prepare_data(pred=pred, gt=gt)
108 |
109 | sm = self.cal_sm(pred, gt)
110 | self.sms.append(sm)
111 |
112 | def cal_sm(self, pred: np.ndarray, gt: np.ndarray) -> float:
113 | y = np.mean(gt)
114 | if y == 0:
115 | sm = 1 - np.mean(pred)
116 | elif y == 1:
117 | sm = np.mean(pred)
118 | else:
119 | sm = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt)
120 | sm = max(0, sm)
121 | return sm
122 |
123 | def object(self, pred: np.ndarray, gt: np.ndarray) -> float:
124 | fg = pred * gt
125 | bg = (1 - pred) * (1 - gt)
126 | u = np.mean(gt)
127 | object_score = u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, 1 - gt)
128 | return object_score
129 |
130 | def s_object(self, pred: np.ndarray, gt: np.ndarray) -> float:
131 | x = np.mean(pred[gt == 1])
132 | sigma_x = np.std(pred[gt == 1], ddof=1)
133 | score = 2 * x / (np.power(x, 2) + 1 + sigma_x + _EPS)
134 | return score
135 |
136 | def region(self, pred: np.ndarray, gt: np.ndarray) -> float:
137 | x, y = self.centroid(gt)
138 | part_info = self.divide_with_xy(pred, gt, x, y)
139 | w1, w2, w3, w4 = part_info['weight']
140 | pred1, pred2, pred3, pred4 = part_info['pred']
141 | gt1, gt2, gt3, gt4 = part_info['gt']
142 | score1 = self.ssim(pred1, gt1)
143 | score2 = self.ssim(pred2, gt2)
144 | score3 = self.ssim(pred3, gt3)
145 | score4 = self.ssim(pred4, gt4)
146 |
147 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4
148 |
149 | def centroid(self, matrix: np.ndarray) -> tuple:
150 | """
151 | To ensure consistency with the matlab code, one is added to the centroid coordinate,
152 | so there is no need to use the redundant addition operation when dividing the region later,
153 | because the sequence generated by ``1:X`` in matlab will contain ``X``.
154 | :param matrix: a bool data array
155 | :return: the centroid coordinate
156 | """
157 | h, w = matrix.shape
158 | area_object = np.count_nonzero(matrix)
159 | if area_object == 0:
160 | x = np.round(w / 2)
161 | y = np.round(h / 2)
162 | else:
163 | # More details can be found at: https://www.yuque.com/lart/blog/gpbigm
164 | y, x = np.argwhere(matrix).mean(axis=0).round()
165 | return int(x) + 1, int(y) + 1
166 |
167 | def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x, y) -> dict:
168 | h, w = gt.shape
169 | area = h * w
170 |
171 | gt_LT = gt[0:y, 0:x]
172 | gt_RT = gt[0:y, x:w]
173 | gt_LB = gt[y:h, 0:x]
174 | gt_RB = gt[y:h, x:w]
175 |
176 | pred_LT = pred[0:y, 0:x]
177 | pred_RT = pred[0:y, x:w]
178 | pred_LB = pred[y:h, 0:x]
179 | pred_RB = pred[y:h, x:w]
180 |
181 | w1 = x * y / area
182 | w2 = y * (w - x) / area
183 | w3 = (h - y) * x / area
184 | w4 = 1 - w1 - w2 - w3
185 |
186 | return dict(gt=(gt_LT, gt_RT, gt_LB, gt_RB),
187 | pred=(pred_LT, pred_RT, pred_LB, pred_RB),
188 | weight=(w1, w2, w3, w4))
189 |
190 | def ssim(self, pred: np.ndarray, gt: np.ndarray) -> float:
191 | h, w = pred.shape
192 | N = h * w
193 |
194 | x = np.mean(pred)
195 | y = np.mean(gt)
196 |
197 | sigma_x = np.sum((pred - x) ** 2) / (N - 1)
198 | sigma_y = np.sum((gt - y) ** 2) / (N - 1)
199 | sigma_xy = np.sum((pred - x) * (gt - y)) / (N - 1)
200 |
201 | alpha = 4 * x * y * sigma_xy
202 | beta = (x ** 2 + y ** 2) * (sigma_x + sigma_y)
203 |
204 | if alpha != 0:
205 | score = alpha / (beta + _EPS)
206 | elif alpha == 0 and beta == 0:
207 | score = 1
208 | else:
209 | score = 0
210 | return score
211 |
212 | def get_results(self):
213 | sm = np.mean(np.array(self.sms, dtype=_TYPE))
214 | return dict(Smeasure=sm)
215 |
216 |
217 | class Emeasure(object):
218 | def __init__(self, length):
219 | self.adaptive_ems = []
220 | self.changeable_ems = []
221 |
222 | def step(self, pred: np.ndarray, gt: np.ndarray, idx):
223 | pred, gt = _prepare_data(pred=pred, gt=gt)
224 | self.gt_fg_numel = np.count_nonzero(gt)
225 | self.gt_size = gt.shape[0] * gt.shape[1]
226 |
227 | changeable_ems = self.cal_changeable_em(pred, gt)
228 | self.changeable_ems.append(changeable_ems)
229 | adaptive_em = self.cal_adaptive_em(pred, gt)
230 | self.adaptive_ems.append(adaptive_em)
231 |
232 | def cal_adaptive_em(self, pred: np.ndarray, gt: np.ndarray) -> float:
233 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1)
234 | adaptive_em = self.cal_em_with_threshold(pred, gt, threshold=adaptive_threshold)
235 | return adaptive_em
236 |
237 | def cal_changeable_em(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
238 | changeable_ems = self.cal_em_with_cumsumhistogram(pred, gt)
239 | return changeable_ems
240 |
241 | def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float:
242 | binarized_pred = pred >= threshold
243 | fg_fg_numel = np.count_nonzero(binarized_pred & gt)
244 | fg_bg_numel = np.count_nonzero(binarized_pred & ~gt)
245 |
246 | fg___numel = fg_fg_numel + fg_bg_numel
247 | bg___numel = self.gt_size - fg___numel
248 |
249 | if self.gt_fg_numel == 0:
250 | enhanced_matrix_sum = bg___numel
251 | elif self.gt_fg_numel == self.gt_size:
252 | enhanced_matrix_sum = fg___numel
253 | else:
254 | parts_numel, combinations = self.generate_parts_numel_combinations(
255 | fg_fg_numel=fg_fg_numel, fg_bg_numel=fg_bg_numel,
256 | pred_fg_numel=fg___numel, pred_bg_numel=bg___numel,
257 | )
258 |
259 | results_parts = []
260 | for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)):
261 | align_matrix_value = 2 * (combination[0] * combination[1]) / \
262 | (combination[0] ** 2 + combination[1] ** 2 + _EPS)
263 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
264 | results_parts.append(enhanced_matrix_value * part_numel)
265 | enhanced_matrix_sum = sum(results_parts)
266 |
267 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
268 | return em
269 |
270 | def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
271 | pred = (pred * 255).astype(np.uint8)
272 | bins = np.linspace(0, 256, 257)
273 | fg_fg_hist, _ = np.histogram(pred[gt], bins=bins)
274 | fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins)
275 | fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0)
276 | fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0)
277 |
278 | fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs
279 | bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs
280 |
281 | if self.gt_fg_numel == 0:
282 | enhanced_matrix_sum = bg___numel_w_thrs
283 | elif self.gt_fg_numel == self.gt_size:
284 | enhanced_matrix_sum = fg___numel_w_thrs
285 | else:
286 | parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations(
287 | fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs,
288 | pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs,
289 | )
290 |
291 | results_parts = np.empty(shape=(4, 256), dtype=np.float64)
292 | for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)):
293 | align_matrix_value = 2 * (combination[0] * combination[1]) / \
294 | (combination[0] ** 2 + combination[1] ** 2 + _EPS)
295 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
296 | results_parts[i] = enhanced_matrix_value * part_numel
297 | enhanced_matrix_sum = results_parts.sum(axis=0)
298 |
299 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
300 | return em
301 |
302 | def generate_parts_numel_combinations(self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel):
303 | bg_fg_numel = self.gt_fg_numel - fg_fg_numel
304 | bg_bg_numel = pred_bg_numel - bg_fg_numel
305 |
306 | parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel]
307 |
308 | mean_pred_value = pred_fg_numel / self.gt_size
309 | mean_gt_value = self.gt_fg_numel / self.gt_size
310 |
311 | demeaned_pred_fg_value = 1 - mean_pred_value
312 | demeaned_pred_bg_value = 0 - mean_pred_value
313 | demeaned_gt_fg_value = 1 - mean_gt_value
314 | demeaned_gt_bg_value = 0 - mean_gt_value
315 |
316 | combinations = [
317 | (demeaned_pred_fg_value, demeaned_gt_fg_value),
318 | (demeaned_pred_fg_value, demeaned_gt_bg_value),
319 | (demeaned_pred_bg_value, demeaned_gt_fg_value),
320 | (demeaned_pred_bg_value, demeaned_gt_bg_value)
321 | ]
322 | return parts_numel, combinations
323 |
324 | def get_results(self):
325 | adaptive_em = np.mean(np.array(self.adaptive_ems, dtype=_TYPE))
326 | changeable_em = np.mean(np.array(self.changeable_ems, dtype=_TYPE), axis=0)
327 | return dict(adpEm=adaptive_em, meanEm=changeable_em, maxEm=changeable_em)
328 |
329 |
330 | class WeightedFmeasure(object):
331 | def __init__(self, length, beta: float = 1):
332 | self.beta = beta
333 | self.weighted_fms = []
334 |
335 | def step(self, pred: np.ndarray, gt: np.ndarray, idx):
336 | pred, gt = _prepare_data(pred=pred, gt=gt)
337 |
338 | if np.all(~gt):
339 | wfm = 0
340 | else:
341 | wfm = self.cal_wfm(pred, gt)
342 | self.weighted_fms.append(wfm)
343 |
344 | def cal_wfm(self, pred: np.ndarray, gt: np.ndarray) -> float:
345 | # [Dst,IDXT] = bwdist(dGT);
346 | Dst, Idxt = bwdist(gt == 0, return_indices=True)
347 |
348 | # %Pixel dependency
349 | # E = abs(FG-dGT);
350 | E = np.abs(pred - gt)
351 | Et = np.copy(E)
352 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]]
353 |
354 | # K = fspecial('gaussian',7,5);
355 | # EA = imfilter(Et,K);
356 | K = self.matlab_style_gauss2D((7, 7), sigma=5)
357 | EA = convolve(Et, weights=K, mode="constant", cval=0)
358 | # MIN_E_EA = E;
359 | # MIN_E_EA(GT & EA < E) = EA(GT & EA < E);
360 | MIN_E_EA = np.where(gt & (EA < E), EA, E)
361 |
362 | # %Pixel importance
363 | # B(~GT) = 2-1*exp(log(1-0.5)/5.*Dst(~GT)); Ew = MIN_E_EA.*B;
364 | B = np.where(gt == 0, 2 - np.exp(np.log(0.5) / 5 * Dst), np.ones_like(gt))
365 | Ew = MIN_E_EA * B
366 |
367 | # TPw = sum(dGT(:)) - sum(sum(Ew(GT))); FPw = sum(sum(Ew(~GT)));
368 | TPw = np.sum(gt) - np.sum(Ew[gt == 1])
369 | FPw = np.sum(Ew[gt == 0])
370 |
371 | # R = 1 - mean2(Ew(GT)); P = TPw./(eps+TPw+FPw);
372 | R = 1 - np.mean(Ew[gt == 1])
373 | P = TPw / (TPw + FPw + _EPS)
374 | # Q = (1+Beta^2)*(R*P)./(eps+R+(Beta.*P));
375 | Q = (1 + self.beta) * R * P / (R + self.beta * P + _EPS)
376 | return Q
377 |
378 | def matlab_style_gauss2D(self, shape: tuple = (7, 7), sigma: int = 5) -> np.ndarray:
379 | """
380 | 2D gaussian mask - should give the same result as MATLAB's
381 | fspecial('gaussian',[shape],[sigma])
382 | """
383 | m, n = [(ss - 1) / 2 for ss in shape]
384 | y, x = np.ogrid[-m: m + 1, -n: n + 1]
385 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
386 | h[h < np.finfo(h.dtype).eps * h.max()] = 0
387 | sumh = h.sum()
388 | if sumh != 0:
389 | h /= sumh
390 | return h
391 |
392 | def get_results(self):
393 | weighted_fm = np.mean(np.array(self.weighted_fms, dtype=_TYPE))
394 | return dict(wFmeasure=weighted_fm)
395 |
396 |
397 | class Medical(object):
398 | def __init__(self, length):
399 | self.Thresholds = np.linspace(1, 0, 256)
400 |
401 | self.threshold_Sensitivity = np.zeros((length, len(self.Thresholds)))
402 | self.threshold_Specificity = np.zeros((length, len(self.Thresholds)))
403 | self.threshold_Dice = np.zeros((length, len(self.Thresholds)))
404 | self.threshold_IoU = np.zeros((length, len(self.Thresholds)))
405 |
406 | def Fmeasure_calu(self, pred, gt, threshold):
407 | if threshold > 1:
408 | threshold = 1
409 |
410 | Label3 = np.zeros_like(gt)
411 | Label3[pred >= threshold] = 1
412 |
413 | NumRec = np.sum(Label3 == 1)
414 | NumNoRec = np.sum(Label3 == 0)
415 |
416 | LabelAnd = (Label3 == 1) & (gt == 1)
417 | NumAnd = np.sum(LabelAnd == 1)
418 | num_obj = np.sum(gt)
419 | num_pred = np.sum(Label3)
420 |
421 | FN = num_obj - NumAnd
422 | FP = NumRec - NumAnd
423 | TN = NumNoRec - FN
424 |
425 | if NumAnd == 0:
426 | RecallFtem = 0
427 | Dice = 0
428 | SpecifTem = 0
429 | IoU = 0
430 |
431 | else:
432 | IoU = NumAnd / (FN + NumRec)
433 | RecallFtem = NumAnd / num_obj
434 | SpecifTem = TN / (TN + FP)
435 | Dice = 2 * NumAnd / (num_obj + num_pred)
436 |
437 | return RecallFtem, SpecifTem, Dice, IoU
438 |
439 | def step(self, pred, gt, idx):
440 | pred, gt = _prepare_data(pred=pred, gt=gt)
441 |
442 | threshold_Rec = np.zeros(len(self.Thresholds))
443 | threshold_Iou = np.zeros(len(self.Thresholds))
444 | threshold_Spe = np.zeros(len(self.Thresholds))
445 | threshold_Dic = np.zeros(len(self.Thresholds))
446 |
447 | for j, threshold in enumerate(self.Thresholds):
448 | threshold_Rec[j], threshold_Spe[j], threshold_Dic[j], \
449 | threshold_Iou[j] = self.Fmeasure_calu(pred, gt, threshold)
450 |
451 | self.threshold_Sensitivity[idx, :] = threshold_Rec
452 | self.threshold_Specificity[idx, :] = threshold_Spe
453 | self.threshold_Dice[idx, :] = threshold_Dic
454 | self.threshold_IoU[idx, :] = threshold_Iou
455 |
456 | def get_results(self):
457 | column_Sen = np.mean(self.threshold_Sensitivity, axis=0)
458 | column_Spe = np.mean(self.threshold_Specificity, axis=0)
459 | column_Dic = np.mean(self.threshold_Dice, axis=0)
460 | column_IoU = np.mean(self.threshold_IoU, axis=0)
461 |
462 | return dict(meanSen=column_Sen, meanSpe=column_Spe, meanDice=column_Dic, meanIoU=column_IoU,
463 | maxSen=column_Sen, maxSpe=column_Spe, maxDice=column_Dic, maxIoU=column_IoU)
464 |
--------------------------------------------------------------------------------
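A minimal usage sketch of the metric classes above (illustrative, not part of the repository); it assumes the module-level helpers `_prepare_data`, `_EPS` and `_TYPE` defined earlier in Fm_eval.py, with `_prepare_data` normalizing `pred` to floats in [0, 1] and binarizing `gt`:

import numpy as np

preds = [np.random.rand(64, 64)]              # prediction maps in [0, 1]
gts = [(np.random.rand(64, 64) > 0.5) * 255]  # binary masks stored as 0/255 images

sm, em, wfm, med = (Smeasure(len(preds)), Emeasure(len(preds)),
                    WeightedFmeasure(len(preds)), Medical(len(preds)))
for i, (p, g) in enumerate(zip(preds, gts)):
    for metric in (sm, em, wfm, med):
        metric.step(pred=p, gt=g, idx=i)

print(sm.get_results()['Smeasure'], wfm.get_results()['wFmeasure'])
print(em.get_results()['adpEm'], med.get_results()['meanDice'].max())
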
/models/lib/IntmdSequential.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class IntermediateSequential(nn.Sequential):
5 | def __init__(self, *args, return_intermediate=True):
6 | super().__init__(*args)
7 | self.return_intermediate = return_intermediate
8 |
9 | def forward(self, input):
10 | if not self.return_intermediate:
11 | return super().forward(input)
12 |
13 | intermediate_outputs = {}
14 | output = input
15 | for name, module in self.named_children():
16 | output = intermediate_outputs[name] = module(output)
17 |
18 | return output, intermediate_outputs
19 |
20 |
--------------------------------------------------------------------------------
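A quick sanity check for the wrapper above (illustrative sketch): with the default `return_intermediate=True`, the forward pass returns both the final output and a dict keyed by child-module name ('0', '1', ...), which is what TransBTS later taps for its intermediate features:

import torch
import torch.nn as nn

seq = IntermediateSequential(nn.Linear(8, 8), nn.Linear(8, 8))
out, intmd = seq(torch.randn(2, 8))
print(out.shape, sorted(intmd.keys()))  # torch.Size([2, 8]) ['0', '1']
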
/models/lib/PositionalEncoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class FixedPositionalEncoding(nn.Module):
5 | def __init__(self, embedding_dim, max_length=512):
6 | super(FixedPositionalEncoding, self).__init__()
7 |
8 | pe = torch.zeros(max_length, embedding_dim)
9 | position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
10 | div_term = torch.exp(
11 | torch.arange(0, embedding_dim, 2).float()
12 | * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
13 | )
14 | pe[:, 0::2] = torch.sin(position * div_term)
15 | pe[:, 1::2] = torch.cos(position * div_term)
16 | pe = pe.unsqueeze(0).transpose(0, 1)
17 | self.register_buffer('pe', pe)
18 |
19 | def forward(self, x):
20 | x = x + self.pe[: x.size(0), :]
21 | return x
22 |
23 |
24 | class LearnedPositionalEncoding(nn.Module):
25 | def __init__(self, max_position_embeddings, embedding_dim, seq_length):
26 | super(LearnedPositionalEncoding, self).__init__()
27 |
28 | self.position_embeddings = nn.Parameter(torch.zeros(1, 4096, 512))  # 8x downsampling: 4096 = (128/8)^3 patch tokens, dim 512
29 |
30 | def forward(self, x, position_ids=None):
31 |
32 | position_embeddings = self.position_embeddings
33 | return x + position_embeddings
34 |
--------------------------------------------------------------------------------
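Shape conventions for the two encodings above (illustrative sketch): `FixedPositionalEncoding` slices its buffer with `x.size(0)`, so it expects `(seq_len, batch, dim)` input, while `LearnedPositionalEncoding` hard-codes a `(1, 4096, 512)` table (the 4096 = 16^3 patch tokens of an 8x-downsampled 128^3 volume, embedding dim 512) and effectively ignores its constructor arguments:

import torch

fixed = FixedPositionalEncoding(embedding_dim=512)
print(fixed(torch.zeros(10, 2, 512)).shape)      # torch.Size([10, 2, 512])

learned = LearnedPositionalEncoding(4096, 512, 4096)
print(learned(torch.zeros(1, 4096, 512)).shape)  # torch.Size([1, 4096, 512])
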
/models/lib/TransBTS_downsample8x_skipconnection.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.lib.Transformer import TransformerModel
4 | from models.lib.PositionalEncoding import FixedPositionalEncoding,LearnedPositionalEncoding
5 | from models.lib.Unet_skipconnection import Unet
6 |
7 |
8 | class TransformerBTS(nn.Module):
9 | def __init__(
10 | self,
11 | img_dim,
12 | patch_dim,
13 | num_channels,
14 | embedding_dim,
15 | num_heads,
16 | num_layers,
17 | hidden_dim,
18 | dropout_rate=0.0,
19 | attn_dropout_rate=0.0,
20 | conv_patch_representation=True,
21 | positional_encoding_type="learned",
22 | ):
23 | super(TransformerBTS, self).__init__()
24 |
25 | assert embedding_dim % num_heads == 0
26 | assert img_dim % patch_dim == 0
27 |
28 | self.img_dim = img_dim
29 | self.embedding_dim = embedding_dim
30 | self.num_heads = num_heads
31 | self.patch_dim = patch_dim
32 | self.num_channels = num_channels
33 | self.dropout_rate = dropout_rate
34 | self.attn_dropout_rate = attn_dropout_rate
35 | self.conv_patch_representation = conv_patch_representation
36 |
37 | self.num_patches = int((img_dim // patch_dim) ** 3)
38 | self.seq_length = self.num_patches
39 | self.flatten_dim = 128 * num_channels
40 |
41 | self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)
42 | if positional_encoding_type == "learned":
43 | self.position_encoding = LearnedPositionalEncoding(
44 | self.seq_length, self.embedding_dim, self.seq_length
45 | )
46 | elif positional_encoding_type == "fixed":
47 | self.position_encoding = FixedPositionalEncoding(
48 | self.embedding_dim,
49 | )
50 |
51 | self.pe_dropout = nn.Dropout(p=self.dropout_rate)
52 |
53 | self.transformer = TransformerModel(
54 | embedding_dim,
55 | num_layers,
56 | num_heads,
57 | hidden_dim,
58 |
59 | self.dropout_rate,
60 | self.attn_dropout_rate,
61 | )
62 | self.pre_head_ln = nn.LayerNorm(embedding_dim)
63 |
64 | if self.conv_patch_representation:
65 |
66 | self.conv_x = nn.Conv3d(
67 | 128,
68 | self.embedding_dim,
69 | kernel_size=3,
70 | stride=1,
71 | padding=1
72 | )
73 |
74 | self.Unet = Unet(in_channels=num_channels, base_channels=16, num_classes=4)
75 | self.bn = nn.BatchNorm3d(128)
76 | self.relu = nn.ReLU(inplace=True)
77 |
78 |
79 | def encode(self, x):
80 | if self.conv_patch_representation:
81 | # combine embedding with conv patch distribution
82 | x1_1, x2_1, x3_1, x = self.Unet(x)
83 | x = self.bn(x)
84 | x = self.relu(x)
85 | x = self.conv_x(x)
86 | x = x.permute(0, 2, 3, 4, 1).contiguous()
87 | x = x.view(x.size(0), -1, self.embedding_dim)
88 |
89 | else:
90 | x = self.Unet(x)
91 | x = self.bn(x)
92 | x = self.relu(x)
93 | x = (
94 | x.unfold(2, 2, 2)
95 | .unfold(3, 2, 2)
96 | .unfold(4, 2, 2)
97 | .contiguous()
98 | )
99 | x = x.view(x.size(0), x.size(1), -1, 8)
100 | x = x.permute(0, 2, 3, 1).contiguous()
101 | x = x.view(x.size(0), -1, self.flatten_dim)
102 | x = self.linear_encoding(x)
103 |
104 | x = self.position_encoding(x)
105 | x = self.pe_dropout(x)
106 |
107 | # apply transformer
108 | x, intmd_x = self.transformer(x)
109 | x = self.pre_head_ln(x)
110 |
111 | return x1_1, x2_1, x3_1, x, intmd_x
112 |
113 | def decode(self, x):
114 | raise NotImplementedError("Should be implemented in child class!!")
115 |
116 | def forward(self, x, auxillary_output_layers=[1, 2, 3, 4]):
117 |
118 | x1_1, x2_1, x3_1, encoder_output, intmd_encoder_outputs = self.encode(x)
119 |
120 | decoder_output = self.decode(
121 | x1_1, x2_1, x3_1, encoder_output, intmd_encoder_outputs, auxillary_output_layers
122 | )
123 |
124 | if auxillary_output_layers is not None:
125 | auxillary_outputs = {}
126 | for i in auxillary_output_layers:
127 | val = str(2 * i - 1)
128 | _key = 'Z' + str(i)
129 | auxillary_outputs[_key] = intmd_encoder_outputs[val]
130 |
131 | return decoder_output
132 |
133 | return decoder_output
134 |
135 | def _get_padding(self, padding_type, kernel_size):
136 | assert padding_type in ['SAME', 'VALID']
137 | if padding_type == 'SAME':
138 | _list = [(k - 1) // 2 for k in kernel_size]
139 | return tuple(_list)
140 | return tuple(0 for _ in kernel_size)
141 |
142 | def _reshape_output(self, x):
143 | x = x.view(
144 | x.size(0),
145 | int(self.img_dim / self.patch_dim),
146 | int(self.img_dim / self.patch_dim),
147 | int(self.img_dim / self.patch_dim),
148 | self.embedding_dim,
149 | )
150 | x = x.permute(0, 4, 1, 2, 3).contiguous()
151 |
152 | return x
153 |
154 |
155 | class BTS(TransformerBTS):
156 | def __init__(
157 | self,
158 | img_dim,
159 | patch_dim,
160 | num_channels,
161 | num_classes,
162 | embedding_dim,
163 | num_heads,
164 | num_layers,
165 | hidden_dim,
166 | dropout_rate=0.0,
167 | attn_dropout_rate=0.0,
168 | conv_patch_representation=True,
169 | positional_encoding_type="learned",
170 | ):
171 | super(BTS, self).__init__(
172 | img_dim=img_dim,
173 | patch_dim=patch_dim,
174 | num_channels=num_channels,
175 | embedding_dim=embedding_dim,
176 | num_heads=num_heads,
177 | num_layers=num_layers,
178 | hidden_dim=hidden_dim,
179 | dropout_rate=dropout_rate,
180 | attn_dropout_rate=attn_dropout_rate,
181 | conv_patch_representation=conv_patch_representation,
182 | positional_encoding_type=positional_encoding_type,
183 | )
184 |
185 | self.num_classes = num_classes
186 |
187 | self.Softmax = nn.Softmax(dim=1)
188 |
189 | self.Enblock8_1 = EnBlock1(in_channels=self.embedding_dim)
190 | self.Enblock8_2 = EnBlock2(in_channels=self.embedding_dim // 4)
191 |
192 | self.DeUp4 = DeUp_Cat(in_channels=self.embedding_dim//4, out_channels=self.embedding_dim//8)
193 | self.DeBlock4 = DeBlock(in_channels=self.embedding_dim//8)
194 |
195 | self.DeUp3 = DeUp_Cat(in_channels=self.embedding_dim//8, out_channels=self.embedding_dim//16)
196 | self.DeBlock3 = DeBlock(in_channels=self.embedding_dim//16)
197 |
198 | self.DeUp2 = DeUp_Cat(in_channels=self.embedding_dim//16, out_channels=self.embedding_dim//32)
199 | self.DeBlock2 = DeBlock(in_channels=self.embedding_dim//32)
200 |
201 | self.endconv = nn.Conv3d(self.embedding_dim // 32, 4, kernel_size=1)
202 |
203 |
204 | def decode(self, x1_1, x2_1, x3_1, x, intmd_x, intmd_layers=[1, 2, 3, 4]):
205 |
206 | assert intmd_layers is not None, "pass the intermediate layers for MLA"
207 | encoder_outputs = {}
208 | all_keys = []
209 | for i in intmd_layers:
210 | val = str(2 * i - 1)
211 | _key = 'Z' + str(i)
212 | all_keys.append(_key)
213 | encoder_outputs[_key] = intmd_x[val]
214 | all_keys.reverse()
215 |
216 | x8 = encoder_outputs[all_keys[0]]
217 | x8 = self._reshape_output(x8)
218 | x8 = self.Enblock8_1(x8)
219 | x8 = self.Enblock8_2(x8)
220 |
221 | y4 = self.DeUp4(x8, x3_1) # (1, 64, 32, 32, 32)
222 | y4 = self.DeBlock4(y4)
223 |
224 | y3 = self.DeUp3(y4, x2_1) # (1, 32, 64, 64, 64)
225 | y3 = self.DeBlock3(y3)
226 |
227 | y2 = self.DeUp2(y3, x1_1) # (1, 16, 128, 128, 128)
228 | y2 = self.DeBlock2(y2)
229 |
230 | y = self.endconv(y2) # (1, 4, 128, 128, 128)
231 | y = self.Softmax(y)
232 | return y
233 |
234 | class EnBlock1(nn.Module):
235 | def __init__(self, in_channels):
236 | super(EnBlock1, self).__init__()
237 |
238 | self.bn1 = nn.BatchNorm3d(512 // 4)
239 | self.relu1 = nn.ReLU(inplace=True)
240 | self.bn2 = nn.BatchNorm3d(512 // 4)
241 | self.relu2 = nn.ReLU(inplace=True)
242 | self.conv1 = nn.Conv3d(in_channels, in_channels // 4, kernel_size=3, padding=1)
243 | self.conv2 = nn.Conv3d(in_channels // 4, in_channels // 4, kernel_size=3, padding=1)
244 |
245 | def forward(self, x):
246 | x1 = self.conv1(x)
247 | x1 = self.bn1(x1)
248 | x1 = self.relu1(x1)
249 | x1 = self.conv2(x1)
250 | x1 = self.bn2(x1)
251 | x1 = self.relu2(x1)
252 |
253 | return x1
254 |
255 |
256 | class EnBlock2(nn.Module):
257 | def __init__(self, in_channels):
258 | super(EnBlock2, self).__init__()
259 |
260 | self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
261 | self.bn1 = nn.BatchNorm3d(512 // 4)
262 | self.relu1 = nn.ReLU(inplace=True)
263 | self.bn2 = nn.BatchNorm3d(512 // 4)
264 | self.relu2 = nn.ReLU(inplace=True)
265 | self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
266 |
267 | def forward(self, x):
268 | x1 = self.conv1(x)
269 | x1 = self.bn1(x1)
270 | x1 = self.relu1(x1)
271 | x1 = self.conv2(x1)
272 | x1 = self.bn2(x1)
273 | x1 = self.relu2(x1)
274 | x1 = x1 + x
275 |
276 | return x1
277 |
278 |
279 | class DeUp_Cat(nn.Module):
280 | def __init__(self, in_channels, out_channels):
281 | super(DeUp_Cat, self).__init__()
282 | self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=1)
283 | self.conv2 = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2)
284 | self.conv3 = nn.Conv3d(out_channels*2, out_channels, kernel_size=1)
285 |
286 | def forward(self, x, prev):
287 | x1 = self.conv1(x)
288 | y = self.conv2(x1)
289 | # y = y + prev
290 | y = torch.cat((prev, y), dim=1)
291 | y = self.conv3(y)
292 | return y
293 |
294 | class DeBlock(nn.Module):
295 | def __init__(self, in_channels):
296 | super(DeBlock, self).__init__()
297 |
298 | self.bn1 = nn.BatchNorm3d(in_channels)
299 | self.relu1 = nn.ReLU(inplace=True)
300 | self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
301 | self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
302 | self.bn2 = nn.BatchNorm3d(in_channels)
303 | self.relu2 = nn.ReLU(inplace=True)
304 |
305 | def forward(self, x):
306 | x1 = self.conv1(x)
307 | x1 = self.bn1(x1)
308 | x1 = self.relu1(x1)
309 | x1 = self.conv2(x1)
310 | x1 = self.bn2(x1)
311 | x1 = self.relu2(x1)
312 | x1 = x1 + x
313 |
314 | return x1
315 |
316 |
317 |
318 |
319 | def TransBTS(input_dims, _conv_repr=True, _pe_type="learned"):
320 | img_dim = 128
321 | num_classes = 4
322 | if input_dims == 'four':
323 | num_channels = 4 # 4 2 1
324 | else:
325 | num_channels = 1
326 | patch_dim = 8
327 | aux_layers = [1, 2, 3, 4]
328 | model = BTS(
329 | img_dim,
330 | patch_dim,
331 | num_channels,
332 | num_classes,
333 | embedding_dim=512,
334 | num_heads=8,
335 | num_layers=4,
336 | hidden_dim=2048, # 4096
337 | dropout_rate=0.1,
338 | attn_dropout_rate=0.1,
339 | conv_patch_representation=_conv_repr,
340 | positional_encoding_type=_pe_type,
341 | )
342 |
343 | return aux_layers, model
344 |
345 |
346 | if __name__ == '__main__':
347 | with torch.no_grad():
348 | import os
349 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
350 | cuda0 = torch.device('cuda:0')
351 | x = torch.rand((1, 4, 128, 128, 128), device=cuda0)
352 | _, model = TransBTS(input_dims='four', _conv_repr=True, _pe_type="learned")
353 | model.cuda()
354 | y = model(x)
355 | print(y.shape)
356 |
--------------------------------------------------------------------------------
/models/lib/Transformer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from models.lib.IntmdSequential import IntermediateSequential
3 |
4 |
5 | class SelfAttention(nn.Module):
6 | def __init__(
7 | self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0
8 | ):
9 | super().__init__()
10 | self.num_heads = heads
11 | head_dim = dim // heads
12 | self.scale = qk_scale or head_dim ** -0.5
13 |
14 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
15 | self.attn_drop = nn.Dropout(dropout_rate)
16 | self.proj = nn.Linear(dim, dim)
17 | self.proj_drop = nn.Dropout(dropout_rate)
18 |
19 | def forward(self, x):
20 | B, N, C = x.shape
21 | qkv = (
22 | self.qkv(x)
23 | .reshape(B, N, 3, self.num_heads, C // self.num_heads)
24 | .permute(2, 0, 3, 1, 4)
25 | )
26 | q, k, v = (
27 | qkv[0],
28 | qkv[1],
29 | qkv[2],
30 | ) # make torchscript happy (cannot use tensor as tuple)
31 |
32 | attn = (q @ k.transpose(-2, -1)) * self.scale
33 | attn = attn.softmax(dim=-1)
34 | attn = self.attn_drop(attn)
35 |
36 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
37 | x = self.proj(x)
38 | x = self.proj_drop(x)
39 | return x
40 |
41 |
42 | class Residual(nn.Module):
43 | def __init__(self, fn):
44 | super().__init__()
45 | self.fn = fn
46 |
47 | def forward(self, x):
48 | return self.fn(x) + x
49 |
50 |
51 | class PreNorm(nn.Module):
52 | def __init__(self, dim, fn):
53 | super().__init__()
54 | self.norm = nn.LayerNorm(dim)
55 | self.fn = fn
56 |
57 | def forward(self, x):
58 | return self.fn(self.norm(x))
59 |
60 |
61 | class PreNormDrop(nn.Module):
62 | def __init__(self, dim, dropout_rate, fn):
63 | super().__init__()
64 | self.norm = nn.LayerNorm(dim)
65 | self.dropout = nn.Dropout(p=dropout_rate)
66 | self.fn = fn
67 |
68 | def forward(self, x):
69 | return self.dropout(self.fn(self.norm(x)))
70 |
71 |
72 | class FeedForward(nn.Module):
73 | def __init__(self, dim, hidden_dim, dropout_rate):
74 | super().__init__()
75 | self.net = nn.Sequential(
76 | nn.Linear(dim, hidden_dim),
77 | nn.GELU(),
78 | nn.Dropout(p=dropout_rate),
79 | nn.Linear(hidden_dim, dim),
80 | nn.Dropout(p=dropout_rate),
81 | )
82 |
83 | def forward(self, x):
84 | return self.net(x)
85 |
86 |
87 | class TransformerModel(nn.Module):
88 | def __init__(
89 | self,
90 | dim,
91 | depth,
92 | heads,
93 | mlp_dim,
94 | dropout_rate=0.1,
95 | attn_dropout_rate=0.1,
96 | ):
97 | super().__init__()
98 | layers = []
99 | for _ in range(depth):
100 | layers.extend(
101 | [
102 | Residual(
103 | PreNormDrop(
104 | dim,
105 | dropout_rate,
106 | SelfAttention(dim, heads=heads, dropout_rate=attn_dropout_rate),
107 | )
108 | ),
109 | Residual(
110 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))
111 | ),
112 | ]
113 | )
114 | # dim = dim / 2
115 | self.net = IntermediateSequential(*layers)
116 |
117 |
118 | def forward(self, x):
119 | return self.net(x)
120 |
--------------------------------------------------------------------------------
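Because `TransformerModel` wraps its layers in `IntermediateSequential`, its forward pass returns `(tokens, intermediate_outputs)`; with depth 4 there are eight residual children, and TransBTS reads the odd keys '1', '3', '5', '7' (the output after each attention + feed-forward pair). A small illustrative sketch:

import torch

model = TransformerModel(dim=512, depth=4, heads=8, mlp_dim=2048)
tokens, intmd = model(torch.randn(1, 64, 512))
print(tokens.shape)                   # torch.Size([1, 64, 512])
print(sorted(intmd.keys(), key=int))  # ['0', '1', ..., '7']
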
/models/lib/UNet3DZoo.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 |
5 | # adapted from https://github.com/MIC-DKFZ/BraTS2017
6 |
7 |
8 | def normalization(planes, norm='gn'):
9 | if norm == 'bn':
10 | m = nn.BatchNorm3d(planes)
11 | elif norm == 'gn':
12 | m = nn.GroupNorm(8, planes)
13 | elif norm == 'in':
14 | m = nn.InstanceNorm3d(planes)
15 | else:
16 | raise ValueError('normalization type {} is not supported'.format(norm))
17 | return m
18 |
19 |
20 |
21 | class InitConv(nn.Module):
22 | def __init__(self, in_channels=1, out_channels=16, dropout=0.2):
23 | super(InitConv, self).__init__()
24 |
25 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
26 | self.dropout = dropout
27 |
28 | def forward(self, x):
29 | y = self.conv(x)
30 | y = F.dropout3d(y, self.dropout)
31 |
32 | return y
33 |
34 |
35 | class EnBlock(nn.Module):
36 | def __init__(self, in_channels, norm='gn'):
37 | super(EnBlock, self).__init__()
38 |
39 | self.bn1 = normalization(in_channels, norm=norm)
40 | self.relu1 = nn.ReLU(inplace=True)
41 | self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
42 |
43 | self.bn2 = normalization(in_channels, norm=norm)
44 | self.relu2 = nn.ReLU(inplace=True)
45 | self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
46 |
47 | def forward(self, x):
48 | x1 = self.bn1(x)
49 | x1 = self.relu1(x1)
50 | x1 = self.conv1(x1)
51 | y = self.bn2(x1)
52 | y = self.relu2(y)
53 | y = self.conv2(y)
54 | y = y + x
55 |
56 | return y
57 |
58 |
59 | class EnDown(nn.Module):
60 | def __init__(self, in_channels, out_channels):
61 | super(EnDown, self).__init__()
62 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
63 |
64 | def forward(self, x):
65 | y = self.conv(x)
66 |
67 | return y
68 |
69 | class Attention_block(nn.Module):
70 | def __init__(self, F_g, F_l, F_int):
71 | super(Attention_block, self).__init__()
72 | self.W_g = nn.Sequential(
73 | nn.Conv3d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
74 | nn.BatchNorm3d(F_int)
75 | )
76 |
77 | self.W_x = nn.Sequential(
78 | nn.Conv3d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
79 | nn.BatchNorm3d(F_int)
80 | )
81 |
82 | self.psi = nn.Sequential(
83 | nn.Conv3d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
84 | nn.BatchNorm3d(1),
85 | nn.Sigmoid()
86 | )
87 |
88 | self.relu = nn.ReLU(inplace=True)
89 |
90 | def forward(self, g, x):
91 | # convolve the (downsampled) gating signal g
92 | g1 = self.W_g(g)
93 | # convolve the (upsampled) skip feature x
94 | x1 = self.W_x(x)
95 | # element-wise sum + relu
96 | psi = self.relu(g1 + x1)
97 | # reduce channels to 1, apply Sigmoid to get the attention weight map
98 | psi = self.psi(psi)
99 | # return the re-weighted skip feature x
100 | return x * psi
101 |
102 | class De_Cat(nn.Module):
103 | def __init__(self, in_channels, out_channels):
104 | super(De_Cat, self).__init__()
105 | # self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=1)
106 | # self.conv2 = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2)
107 | self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=1)
108 |
109 | def forward(self, x, prev):
110 | # x1 = self.conv1(x)
111 | # y = self.conv2(x1)
112 | # y = y + prev
113 | y = torch.cat((x, prev), dim=1)
114 | y = self.conv1(y)
115 | return y
116 |
117 | class DeUp_Cat(nn.Module):
118 | def __init__(self, in_channels, out_channels):
119 | super(DeUp_Cat, self).__init__()
120 | self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=1)
121 | self.conv2 = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2)
122 | self.conv3 = nn.Conv3d(in_channels, out_channels, kernel_size=1)
123 |
124 | def forward(self, x, prev):
125 | x1 = self.conv1(x)
126 | y = self.conv2(x1)
127 | # y = y + prev
128 | y = torch.cat((prev, y), dim=1)
129 | y = self.conv3(y)
130 | return y
131 |
132 | class DeUp(nn.Module):
133 | def __init__(self, in_channels, out_channels):
134 | super(DeUp, self).__init__()
135 | self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=1)
136 | self.conv2 = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2)
137 | self.conv3 = nn.Conv3d(out_channels, out_channels, kernel_size=1)
138 |
139 | def forward(self, x):
140 | x = self.conv1(x)
141 | x = self.conv2(x)
142 | x = self.conv3(x)
143 | return x
144 |
145 | class DeBlock(nn.Module):
146 | def __init__(self, in_channels):
147 | super(DeBlock, self).__init__()
148 |
149 | self.bn1 = nn.BatchNorm3d(in_channels)
150 | self.relu1 = nn.ReLU(inplace=True)
151 | self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
152 | self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
153 | self.bn2 = nn.BatchNorm3d(in_channels)
154 | self.relu2 = nn.ReLU(inplace=True)
155 |
156 | def forward(self, x):
157 | x1 = self.conv1(x)
158 | x1 = self.bn1(x1)
159 | x1 = self.relu1(x1)
160 | x1 = self.conv2(x1)
161 | x1 = self.bn2(x1)
162 | x1 = self.relu2(x1)
163 | x1 = x1 + x
164 |
165 | return x1
166 |
167 | class AttUnet(nn.Module):
168 | def __init__(self, in_channels=1, base_channels=16, num_classes=4):
169 | super(AttUnet, self).__init__()
170 |
171 | self.InitConv = InitConv(in_channels=in_channels, out_channels=base_channels, dropout=0.2)
172 | self.EnBlock1 = EnBlock(in_channels=base_channels)
173 | self.EnDown1 = EnDown(in_channels=base_channels, out_channels=base_channels*2)
174 |
175 | self.EnBlock2_1 = EnBlock(in_channels=base_channels*2)
176 | self.EnBlock2_2 = EnBlock(in_channels=base_channels*2)
177 | self.EnDown2 = EnDown(in_channels=base_channels*2, out_channels=base_channels*4)
178 |
179 | self.EnBlock3_1 = EnBlock(in_channels=base_channels * 4)
180 | self.EnBlock3_2 = EnBlock(in_channels=base_channels * 4)
181 | self.EnDown3 = EnDown(in_channels=base_channels*4, out_channels=base_channels*8)
182 | self.Att4 = Attention_block(F_g=64, F_l=64, F_int=32)
183 | self.Att3 = Attention_block(F_g=32, F_l=32, F_int=16)
184 | self.Att2 = Attention_block(F_g=16, F_l=16, F_int=16)
185 |
186 | self.EnBlock4_1 = EnBlock(in_channels=base_channels * 8)
187 | self.EnBlock4_2 = EnBlock(in_channels=base_channels * 8)
188 | self.EnBlock4_3 = EnBlock(in_channels=base_channels * 8)
189 | self.EnBlock4_4 = EnBlock(in_channels=base_channels * 8)
190 |
191 | self.DeUp4 = DeUp(in_channels=base_channels*8, out_channels=base_channels*4)
192 | self.DeUpCat4 = De_Cat(in_channels=base_channels * 8, out_channels=base_channels * 4)
193 | self.DeBlock4 = DeBlock(in_channels=base_channels*4)
194 |
195 | self.DeUp3 = DeUp(in_channels=base_channels*4, out_channels=base_channels*2)
196 | self.DeUpCat3 = De_Cat(in_channels=base_channels * 4, out_channels=base_channels * 2)
197 | self.DeBlock3 = DeBlock(in_channels=base_channels*2)
198 |
199 | self.DeUp2 = DeUp(in_channels=base_channels*2, out_channels=base_channels)
200 | self.DeUpCat2 = De_Cat(in_channels=base_channels * 2, out_channels=base_channels)
201 | self.DeBlock2 = DeBlock(in_channels=base_channels)
202 | self.endconv = nn.Conv3d(base_channels, num_classes, kernel_size=1)
203 |
204 | def forward(self, x):
205 | x = self.InitConv(x) # (1, 16, 128, 128, 128)
206 |
207 | x1_1 = self.EnBlock1(x)
208 | x1_2 = self.EnDown1(x1_1) # (1, 32, 64, 64, 64)
209 |
210 | x2_1 = self.EnBlock2_1(x1_2)
211 | x2_1 = self.EnBlock2_2(x2_1)
212 | x2_2 = self.EnDown2(x2_1) # (1, 64, 32, 32, 32)
213 |
214 | x3_1 = self.EnBlock3_1(x2_2)
215 | x3_1 = self.EnBlock3_2(x3_1)
216 | x3_2 = self.EnDown3(x3_1) # (1, 128, 16, 16, 16)
217 |
218 | x4_1 = self.EnBlock4_1(x3_2)
219 | x4_2 = self.EnBlock4_2(x4_1)
220 | x4_3 = self.EnBlock4_3(x4_2)
221 | x4_4 = self.EnBlock4_4(x4_3) # (1, 128, 16, 16, 16)
222 |
223 | y4 = self.DeUp4(x4_4) # (1, 64, 32, 32, 32)
224 | x3_1 = self.Att4(g=y4, x=x3_1)
225 | y4 = self.DeUpCat4(x3_1, y4)
226 | y4 = self.DeBlock4(y4) # (1, 64, 32, 32, 32)
227 |
228 | y3 = self.DeUp3(y4) # (1, 32, 64, 64, 64)
229 | x2_1 = self.Att3(g=y3, x=x2_1)
230 | y3 = self.DeUpCat3(y3, x2_1)
231 | y3 = self.DeBlock3(y3) # (1, 32, 64, 64, 64)
232 |
233 | y2 = self.DeUp2(y3)
234 | x1_1 = self.Att2(g=y2, x=x1_1)
235 | y2 = self.DeUpCat2(y2, x1_1) # (1, 16, 128, 128, 128)
236 | y2 = self.DeBlock2(y2)
237 | y = self.endconv(y2)
238 |
239 | return y
240 |
241 | class Unet(nn.Module):
242 | def __init__(self, in_channels=1, base_channels=16, num_classes=4):
243 | super(Unet, self).__init__()
244 |
245 | self.InitConv = InitConv(in_channels=in_channels, out_channels=base_channels, dropout=0.2)
246 | self.EnBlock1 = EnBlock(in_channels=base_channels)
247 | self.EnDown1 = EnDown(in_channels=base_channels, out_channels=base_channels*2)
248 |
249 | self.EnBlock2_1 = EnBlock(in_channels=base_channels*2)
250 | self.EnBlock2_2 = EnBlock(in_channels=base_channels*2)
251 | self.EnDown2 = EnDown(in_channels=base_channels*2, out_channels=base_channels*4)
252 |
253 | self.EnBlock3_1 = EnBlock(in_channels=base_channels * 4)
254 | self.EnBlock3_2 = EnBlock(in_channels=base_channels * 4)
255 | self.EnDown3 = EnDown(in_channels=base_channels*4, out_channels=base_channels*8)
256 |
257 | self.EnBlock4_1 = EnBlock(in_channels=base_channels * 8)
258 | self.EnBlock4_2 = EnBlock(in_channels=base_channels * 8)
259 | self.EnBlock4_3 = EnBlock(in_channels=base_channels * 8)
260 | self.EnBlock4_4 = EnBlock(in_channels=base_channels * 8)
261 |
262 | self.DeUpCat4 = DeUp_Cat(in_channels=base_channels * 8, out_channels=base_channels * 4)
263 | self.DeBlock4 = DeBlock(in_channels=base_channels*4)
264 |
265 | self.DeUpCat3 = DeUp_Cat(in_channels=base_channels * 4, out_channels=base_channels * 2)
266 | self.DeBlock3 = DeBlock(in_channels=base_channels*2)
267 |
268 | self.DeUpCat2 = DeUp_Cat(in_channels=base_channels * 2, out_channels=base_channels)
269 | self.DeBlock2 = DeBlock(in_channels=base_channels)
270 | self.endconv = nn.Conv3d(base_channels, num_classes, kernel_size=1)
271 |
272 | def forward(self, x):
273 | x = self.InitConv(x) # (1, 16, 128, 128, 128)
274 |
275 | x1_1 = self.EnBlock1(x)
276 | x1_2 = self.EnDown1(x1_1) # (1, 32, 64, 64, 64)
277 |
278 | x2_1 = self.EnBlock2_1(x1_2)
279 | x2_1 = self.EnBlock2_2(x2_1)
280 | x2_2 = self.EnDown2(x2_1) # (1, 64, 32, 32, 32)
281 |
282 | x3_1 = self.EnBlock3_1(x2_2)
283 | x3_1 = self.EnBlock3_2(x3_1)
284 | x3_2 = self.EnDown3(x3_1) # (1, 128, 16, 16, 16)
285 |
286 | x4_1 = self.EnBlock4_1(x3_2)
287 | x4_2 = self.EnBlock4_2(x4_1)
288 | x4_3 = self.EnBlock4_3(x4_2)
289 | x4_4 = self.EnBlock4_4(x4_3) # (1, 128, 16, 16, 16)
290 |
291 | y4 = self.DeUpCat4(x4_4, x3_1)
292 | y4 = self.DeBlock4(y4) # (1, 64, 32, 32, 32)
293 |
294 | y3 = self.DeUpCat3(y4, x2_1) # (1, 32, 64, 64, 64)
295 | y3 = self.DeBlock3(y3)
296 |
297 | y2 = self.DeUpCat2(y3, x1_1) # (1, 16, 128, 128, 128)
298 | y2 = self.DeBlock2(y2)
299 | y = self.endconv(y2)
300 |
301 | return y
302 | class Unetdrop(nn.Module):
303 | def __init__(self, in_channels=1, base_channels=16, num_classes=4):
304 | super(Unetdrop, self).__init__()
305 |
306 | self.InitConv = InitConv(in_channels=in_channels, out_channels=base_channels, dropout=0.2)
307 | self.EnBlock1 = EnBlock(in_channels=base_channels)
308 | self.EnDown1 = EnDown(in_channels=base_channels, out_channels=base_channels*2)
309 |
310 | self.EnBlock2_1 = EnBlock(in_channels=base_channels*2)
311 | self.EnBlock2_2 = EnBlock(in_channels=base_channels*2)
312 | self.EnDown2 = EnDown(in_channels=base_channels*2, out_channels=base_channels*4)
313 |
314 | self.EnBlock3_1 = EnBlock(in_channels=base_channels * 4)
315 | self.EnBlock3_2 = EnBlock(in_channels=base_channels * 4)
316 | self.dropoutd1 = nn.Dropout(p=0.5)
317 | self.EnDown3 = EnDown(in_channels=base_channels*4, out_channels=base_channels*8)
318 | self.dropoutd2 = nn.Dropout(p=0.5)
319 | self.EnBlock4_1 = EnBlock(in_channels=base_channels * 8)
320 | self.EnBlock4_2 = EnBlock(in_channels=base_channels * 8)
321 | self.EnBlock4_3 = EnBlock(in_channels=base_channels * 8)
322 | self.EnBlock4_4 = EnBlock(in_channels=base_channels * 8)
323 | self.dropoutu1 = nn.Dropout(p=0.5)
324 | self.DeUpCat4 = DeUp_Cat(in_channels=base_channels * 8, out_channels=base_channels * 4)
325 | self.DeBlock4 = DeBlock(in_channels=base_channels*4)
326 | self.dropoutu2 = nn.Dropout(p=0.5)
327 | self.DeUpCat3 = DeUp_Cat(in_channels=base_channels * 4, out_channels=base_channels * 2)
328 | self.DeBlock3 = DeBlock(in_channels=base_channels*2)
329 |
330 | self.DeUpCat2 = DeUp_Cat(in_channels=base_channels * 2, out_channels=base_channels)
331 | self.DeBlock2 = DeBlock(in_channels=base_channels)
332 | self.endconv = nn.Conv3d(base_channels, num_classes, kernel_size=1)
333 |
334 | def forward(self, x):
335 | x = self.InitConv(x) # (1, 16, 128, 128, 128)
336 |
337 | x1_1 = self.EnBlock1(x)
338 | x1_2 = self.EnDown1(x1_1) # (1, 32, 64, 64, 64)
339 |
340 | x2_1 = self.EnBlock2_1(x1_2)
341 | x2_1 = self.EnBlock2_2(x2_1)
342 | x2_2 = self.EnDown2(x2_1) # (1, 64, 32, 32, 32)
343 |
344 | x3_1 = self.EnBlock3_1(x2_2) # (1, 64, 32, 32, 32)
345 | x3_1drop = self.EnBlock3_2(x3_1)
346 | x3_1drop = self.dropoutd1(x3_1drop)
347 | x3_2 = self.EnDown3(x3_1drop) # (1, 128, 16, 16, 16)
348 |
349 | x4_1 = self.EnBlock4_1(x3_2)
350 | x4_2 = self.EnBlock4_2(x4_1)
351 | x4_3 = self.EnBlock4_3(x4_2)
352 | x4_3 = self.dropoutd2(x4_3)
353 | x4_4 = self.EnBlock4_4(x4_3) # (1, 128, 16, 16, 16)
354 |
355 | y4 = self.DeUpCat4(x4_4, x3_1)
356 | y4 = self.dropoutu1(y4)
357 | y4 = self.DeBlock4(y4) # (1, 64, 32, 32, 32)
358 |
359 | y3 = self.DeUpCat3(y4, x2_1) # (1, 32, 64, 64, 64)
360 | y3 = self.dropoutu2(y3)
361 | y3 = self.DeBlock3(y3)
362 |
363 | y2 = self.DeUpCat2(y3, x1_1) # (1, 16, 128, 128, 128)
364 | y2 = self.DeBlock2(y2)
365 | y = self.endconv(y2)
366 | return y
367 | if __name__ == '__main__':
368 | with torch.no_grad():
369 | import os
370 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
371 | cuda0 = torch.device('cuda:0')
372 | x = torch.rand((1, 1, 128, 128, 128), device=cuda0)
373 | model = AttUnet(in_channels=1, base_channels=16, num_classes=4)
374 | # model = Unet(in_channels=1, base_channels=16, num_classes=4)
375 | model.cuda()
376 | total = sum([param.nelement() for param in model.parameters()])
377 | print("Number of model's parameter: %.2fM" % (total / 1e6))
378 | output = model(x)
379 | print('output:', output.shape)
380 |
--------------------------------------------------------------------------------
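An illustrative sketch of the attention gate in `UNet3DZoo.py` above: the gating signal `g` and the skip feature `x` must already share channel count and spatial size, and `psi` is a per-voxel weight in (0, 1) that re-weights the skip:

import torch

att = Attention_block(F_g=32, F_l=32, F_int=16)
g = torch.randn(1, 32, 16, 16, 16)  # decoder (gating) features
x = torch.randn(1, 32, 16, 16, 16)  # encoder skip features
print(att(g=g, x=x).shape)          # torch.Size([1, 32, 16, 16, 16])
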
/models/lib/Unet_skipconnection.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 |
5 | # adapted from https://github.com/MIC-DKFZ/BraTS2017
6 |
7 |
8 | def normalization(planes, norm='gn'):
9 | if norm == 'bn':
10 | m = nn.BatchNorm3d(planes)
11 | elif norm == 'gn':
12 | m = nn.GroupNorm(8, planes)
13 | elif norm == 'in':
14 | m = nn.InstanceNorm3d(planes)
15 | else:
16 | raise ValueError('normalization type {} is not supported'.format(norm))
17 | return m
18 |
19 |
20 |
21 | class InitConv(nn.Module):
22 | def __init__(self, in_channels=4, out_channels=16, dropout=0.2):
23 | super(InitConv, self).__init__()
24 |
25 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
26 | self.dropout = dropout
27 |
28 | def forward(self, x):
29 | y = self.conv(x)
30 | y = F.dropout3d(y, self.dropout)
31 |
32 | return y
33 |
34 |
35 | class EnBlock(nn.Module):
36 | def __init__(self, in_channels, norm='gn'):
37 | super(EnBlock, self).__init__()
38 |
39 | self.bn1 = normalization(in_channels, norm=norm)
40 | self.relu1 = nn.ReLU(inplace=True)
41 | self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
42 |
43 | self.bn2 = normalization(in_channels, norm=norm)
44 | self.relu2 = nn.ReLU(inplace=True)
45 | self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
46 |
47 | def forward(self, x):
48 | x1 = self.bn1(x)
49 | x1 = self.relu1(x1)
50 | x1 = self.conv1(x1)
51 | y = self.bn2(x1)
52 | y = self.relu2(y)
53 | y = self.conv2(y)
54 | y = y + x
55 |
56 | return y
57 |
58 |
59 | class EnDown(nn.Module):
60 | def __init__(self, in_channels, out_channels):
61 | super(EnDown, self).__init__()
62 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
63 |
64 | def forward(self, x):
65 | y = self.conv(x)
66 |
67 | return y
68 |
69 |
70 |
71 | class Unet(nn.Module):
72 | def __init__(self, in_channels=4, base_channels=16, num_classes=4):
73 | super(Unet, self).__init__()
74 |
75 | self.InitConv = InitConv(in_channels=in_channels, out_channels=base_channels, dropout=0.2)
76 | self.EnBlock1 = EnBlock(in_channels=base_channels)
77 | self.EnDown1 = EnDown(in_channels=base_channels, out_channels=base_channels*2)
78 |
79 | self.EnBlock2_1 = EnBlock(in_channels=base_channels*2)
80 | self.EnBlock2_2 = EnBlock(in_channels=base_channels*2)
81 | self.EnDown2 = EnDown(in_channels=base_channels*2, out_channels=base_channels*4)
82 |
83 | self.EnBlock3_1 = EnBlock(in_channels=base_channels * 4)
84 | self.EnBlock3_2 = EnBlock(in_channels=base_channels * 4)
85 | self.EnDown3 = EnDown(in_channels=base_channels*4, out_channels=base_channels*8)
86 |
87 | self.EnBlock4_1 = EnBlock(in_channels=base_channels * 8)
88 | self.EnBlock4_2 = EnBlock(in_channels=base_channels * 8)
89 | self.EnBlock4_3 = EnBlock(in_channels=base_channels * 8)
90 | self.EnBlock4_4 = EnBlock(in_channels=base_channels * 8)
91 |
92 | def forward(self, x):
93 | x = self.InitConv(x) # (1, 16, 128, 128, 128)
94 |
95 | x1_1 = self.EnBlock1(x)
96 | x1_2 = self.EnDown1(x1_1) # (1, 32, 64, 64, 64)
97 |
98 | x2_1 = self.EnBlock2_1(x1_2)
99 | x2_1 = self.EnBlock2_2(x2_1)
100 | x2_2 = self.EnDown2(x2_1) # (1, 64, 32, 32, 32)
101 |
102 | x3_1 = self.EnBlock3_1(x2_2)
103 | x3_1 = self.EnBlock3_2(x3_1)
104 | x3_2 = self.EnDown3(x3_1) # (1, 128, 16, 16, 16)
105 |
106 | x4_1 = self.EnBlock4_1(x3_2)
107 | x4_2 = self.EnBlock4_2(x4_1)
108 | x4_3 = self.EnBlock4_3(x4_2)
109 | output = self.EnBlock4_4(x4_3) # (1, 128, 16, 16, 16)
110 |
111 | return x1_1,x2_1,x3_1,output
112 |
113 |
114 | if __name__ == '__main__':
115 | with torch.no_grad():
116 | import os
117 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
118 | cuda0 = torch.device('cuda:0')
119 | x = torch.rand((1, 4, 128, 128, 128), device=cuda0)
120 | # model = Unet1(in_channels=4, base_channels=16, num_classes=4)
121 | model = Unet(in_channels=4, base_channels=16, num_classes=4)
122 | model.cuda()
123 | x1_1, x2_1, x3_1, output = model(x)
124 | print('output:', output.shape)
125 |
--------------------------------------------------------------------------------
/models/lib/VNet3D.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | class ConvBlock(nn.Module):
6 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
7 | super(ConvBlock, self).__init__()
8 |
9 | ops = []
10 | for i in range(n_stages):
11 | if i==0:
12 | input_channel = n_filters_in
13 | else:
14 | input_channel = n_filters_out
15 |
16 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
17 | if normalization == 'bn':
18 | ops.append(nn.BatchNorm3d(n_filters_out))
19 | elif normalization == 'gn':
20 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
21 | elif normalization == 'in':
22 | ops.append(nn.InstanceNorm3d(n_filters_out))
23 | elif normalization != 'none':
24 | raise ValueError('normalization type {} is not supported'.format(normalization))
25 | ops.append(nn.ReLU(inplace=True))
26 |
27 | self.conv = nn.Sequential(*ops)
28 |
29 | def forward(self, x):
30 | x = self.conv(x)
31 | return x
32 |
33 |
34 | class ResidualConvBlock(nn.Module):
35 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
36 | super(ResidualConvBlock, self).__init__()
37 |
38 | ops = []
39 | for i in range(n_stages):
40 | if i == 0:
41 | input_channel = n_filters_in
42 | else:
43 | input_channel = n_filters_out
44 |
45 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
46 | if normalization == 'bn':
47 | ops.append(nn.BatchNorm3d(n_filters_out))
48 | elif normalization == 'gn':
49 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
50 | elif normalization == 'in':
51 | ops.append(nn.InstanceNorm3d(n_filters_out))
52 | elif normalization != 'none':
53 | raise ValueError('normalization type {} is not supported'.format(normalization))
54 |
55 | if i != n_stages-1:
56 | ops.append(nn.ReLU(inplace=True))
57 |
58 | self.conv = nn.Sequential(*ops)
59 | self.relu = nn.ReLU(inplace=True)
60 |
61 | def forward(self, x):
62 | x = (self.conv(x) + x)
63 | x = self.relu(x)
64 | return x
65 |
66 |
67 | class DownsamplingConvBlock(nn.Module):
68 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
69 | super(DownsamplingConvBlock, self).__init__()
70 |
71 | ops = []
72 | if normalization != 'none':
73 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
74 | if normalization == 'bn':
75 | ops.append(nn.BatchNorm3d(n_filters_out))
76 | elif normalization == 'gn':
77 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
78 | elif normalization == 'in':
79 | ops.append(nn.InstanceNorm3d(n_filters_out))
80 | else:
81 | raise ValueError('normalization type {} is not supported'.format(normalization))
82 | else:
83 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
84 |
85 | ops.append(nn.ReLU(inplace=True))
86 |
87 | self.conv = nn.Sequential(*ops)
88 |
89 | def forward(self, x):
90 | x = self.conv(x)
91 | return x
92 |
93 |
94 | class UpsamplingDeconvBlock(nn.Module):
95 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
96 | super(UpsamplingDeconvBlock, self).__init__()
97 |
98 | ops = []
99 | if normalization != 'none':
100 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
101 | if normalization == 'bn':
102 | ops.append(nn.BatchNorm3d(n_filters_out))
103 | elif normalization == 'gn':
104 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
105 | elif normalization == 'in':
106 | ops.append(nn.InstanceNorm3d(n_filters_out))
107 | else:
108 | raise ValueError('normalization type {} is not supported'.format(normalization))
109 | else:
110 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
111 |
112 | ops.append(nn.ReLU(inplace=True))
113 |
114 | self.conv = nn.Sequential(*ops)
115 |
116 | def forward(self, x):
117 | x = self.conv(x)
118 | return x
119 |
120 |
121 | class Upsampling(nn.Module):
122 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
123 | super(Upsampling, self).__init__()
124 |
125 | ops = []
126 | ops.append(nn.Upsample(scale_factor=stride, mode='trilinear',align_corners=False))
127 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
128 | if normalization == 'bn':
129 | ops.append(nn.BatchNorm3d(n_filters_out))
130 | elif normalization == 'gn':
131 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
132 | elif normalization == 'in':
133 | ops.append(nn.InstanceNorm3d(n_filters_out))
134 | elif normalization != 'none':
135 | raise ValueError('normalization type {} is not supported'.format(normalization))
136 | ops.append(nn.ReLU(inplace=True))
137 |
138 | self.conv = nn.Sequential(*ops)
139 |
140 | def forward(self, x):
141 | x = self.conv(x)
142 | return x
143 |
144 |
145 | class VNet(nn.Module):
146 | def __init__(self, n_channels=1, n_classes=4, n_filters=16, normalization='gn', has_dropout=False):
147 | super(VNet, self).__init__()
148 | self.has_dropout = has_dropout
149 |
150 | self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
151 | self.block_one_dw = DownsamplingConvBlock(n_filters, n_filters * 2, normalization=normalization) # 32
152 |
153 | self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
154 | self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization) # 64
155 |
156 | self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
157 | self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization) # 128
158 |
159 | # self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
160 | # self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization) # 256
161 |
162 | self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
163 | self.block_four_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
164 |
165 | self.block_five = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
166 | self.block_five_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
167 |
168 | self.block_six = ConvBlock(3, n_filters * 2, n_filters * 2, normalization=normalization)
169 | self.block_six_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
170 |
171 | # self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
172 | # self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
173 |
174 | self.block_seven = ConvBlock(1, n_filters, n_filters, normalization=normalization)
175 | self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
176 |
177 | self.dropout = nn.Dropout3d(p=0.5, inplace=False)
178 | # self.__init_weight()
179 |
180 | def encoder(self, input):
181 | x1 = self.block_one(input)
182 | x1_dw = self.block_one_dw(x1)
183 |
184 | x2 = self.block_two(x1_dw)
185 | x2_dw = self.block_two_dw(x2)
186 |
187 | x3 = self.block_three(x2_dw)
188 | x3_dw = self.block_three_dw(x3)
189 |
190 | x4 = self.block_four(x3_dw)
191 |
192 | # x5 = self.block_five(x4_dw)
193 | # x5 = F.dropout3d(x5, p=0.5, training=True)
194 | if self.has_dropout:
195 | x4 = self.dropout(x4)
196 |
197 | res = [x1, x2, x3, x4]
198 |
199 | return res
200 |
201 | def decoder(self, features):
202 | x1 = features[0]
203 | x2 = features[1]
204 | x3 = features[2]
205 | x4 = features[3]
206 | # x5 = features[4]
207 |
208 | x4_up = self.block_four_up(x4)
209 | x4_up = x4_up + x3
210 |
211 | x5 = self.block_five(x4_up)
212 | x5_up = self.block_five_up(x5)
213 | x5_up = x5_up + x2
214 |
215 | x6 = self.block_six(x5_up)
216 | x6_up = self.block_six_up(x6)
217 | x6_up = x6_up + x1
218 |
219 | x7 = self.block_seven(x6_up)
220 | # x9 = F.dropout3d(x9, p=0.5, training=True)
221 | if self.has_dropout:
222 | x7 = self.dropout(x7)
223 | out = self.out_conv(x7)
224 | return out
225 |
226 |
227 | def forward(self, input, turnoff_drop=False):
228 | if turnoff_drop:
229 | has_dropout = self.has_dropout
230 | self.has_dropout = False
231 | features = self.encoder(input)
232 | out = self.decoder(features)
233 | if turnoff_drop:
234 | self.has_dropout = has_dropout
235 | return out
236 |
237 | if __name__ == '__main__':
238 | with torch.no_grad():
239 | import os
240 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
241 | cuda0 = torch.device('cuda:0')
242 | x = torch.rand((1, 1, 128, 128, 128), device=cuda0)
243 | # model = AttUnet(in_channels=1, base_channels=16, num_classes=4)
244 | model = VNet(n_channels=1, n_classes=4, n_filters=16, normalization='gn', has_dropout=False)
245 | model.cuda()
246 | total = sum([param.nelement() for param in model.parameters()])
247 | print("Number of model's parameter: %.2fM" % (total / 1e6))
248 | output = model(x)
249 | print('output:', output.shape)
--------------------------------------------------------------------------------
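A note on `turnoff_drop` in the VNet above: it disables dropout for a single deterministic pass, the usual way an MC-dropout model is evaluated once without sampling. A minimal illustrative sketch:

import torch

model = VNet(n_channels=1, n_classes=4, n_filters=16, normalization='gn', has_dropout=True)
x = torch.rand(1, 1, 32, 32, 32)
with torch.no_grad():
    out = model(x, turnoff_drop=True)  # dropout bypassed for this call only
print(out.shape)                       # torch.Size([1, 4, 32, 32, 32])
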
/models/lib/nullfile:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/models/lib/seg_eval.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import copy
3 | import nibabel as nib
4 |
5 | # calculate evaluation metrics for segmentation
6 | def seg_eval_metric(pred_label, gt_label, output_chn):
7 | class_n = np.unique(gt_label)
8 | # dice
9 | dice_c = dice_n_class(move_img=pred_label, refer_img=gt_label, output_chn=output_chn)
10 | return dice_c
11 |
12 | # dice value
13 | def dice_n_class(move_img, refer_img, output_chn):
14 | # list of classes
15 | c_list_old = np.unique(refer_img)
16 | # for those class not in the Gt, set dice to zero
17 | c_list = np.arange(output_chn)
18 | dice_c = []
19 | for c in range(len(c_list)):
20 | # intersection
21 | ints = np.sum(((move_img == c_list[c]) * 1) * ((refer_img == c_list[c]) * 1))
22 | # sum
23 | sums = np.sum(((move_img == c_list[c]) * 1) + ((refer_img == c_list[c]) * 1)) + 0.0001
24 | dice_c.append((2.0 * ints) / sums)
25 |
26 | return dice_c
27 |
28 |
29 | # conformity value
30 | def conform_n_class(move_img, refer_img):
31 | # list of classes
32 | c_list = np.unique(refer_img)
33 |
34 | conform_c = []
35 | for c in range(len(c_list)):
36 | # intersection
37 | ints = np.sum(((move_img == c_list[c]) * 1) * ((refer_img == c_list[c]) * 1))
38 | # sum
39 | sums = np.sum(((move_img == c_list[c]) * 1) + ((refer_img == c_list[c]) * 1)) + 0.0001
40 | # dice
41 | dice_temp = (2.0 * ints) / sums
42 | # conformity
43 | conform_temp = (3*dice_temp - 2) / dice_temp
44 |
45 | conform_c.append(conform_temp)
46 |
47 | return conform_c
48 |
49 |
50 | # Jaccard index
51 | def jaccard_n_class(move_img, refer_img, output_chn):
52 | # list of classes
53 | c_list_old = np.unique(refer_img)
54 | # c_list = [0, 1, 2, 3]
55 | c_list = np.arange(output_chn)
56 |
57 | jaccard_c = []
58 | for c in range(len(c_list)):
59 | move_img_c = (move_img == c_list[c])
60 | refer_img_c = (refer_img == c_list[c])
61 | # intersection
62 | ints = np.sum(np.logical_and(move_img_c, refer_img_c)*1)
63 | # union
64 | uni = np.sum(np.logical_or(move_img_c, refer_img_c)*1) + 0.0001
65 |
66 | jaccard_c.append(ints / uni)
67 |
68 | return jaccard_c
69 |
70 |
71 | # precision and recall
72 | def precision_recall_n_class(move_img, refer_img):
73 | # list of classes
74 | c_list = np.unique(refer_img)
75 |
76 | precision_c = []
77 | recall_c = []
78 | for c in range(len(c_list)):
79 | move_img_c = (move_img == c_list[c])
80 | refer_img_c = (refer_img == c_list[c])
81 | # intersection
82 | ints = np.sum(np.logical_and(move_img_c, refer_img_c)*1)
83 | # precision
84 | prec = ints / (np.sum(move_img_c*1) + 0.001)
85 | # recall
86 | recall = ints / (np.sum(refer_img_c*1) + 0.001)
87 |
88 | precision_c.append(prec)
89 | recall_c.append(recall)
90 |
91 | return precision_c, recall_c
92 |
93 | # Sensitivity(recall of the positive)
94 | def sensitivity(pred, gt, output_chn):
95 | '''
96 | calculate the sensitivity and the specificity
97 | :param pred: predictions
98 | :param gt: ground truth
99 | :param output_chn: categories (including background)
100 | :return: A list containing the specificity and the sensitivities; the first item is the
101 | specificity (background category) and the rest are the sensitivities of the other categories
102 | '''
103 | s_list = np.arange(output_chn)
104 | sensitivity_s = []
105 | for s in range(output_chn):
106 | # TP
107 | TP = np.sum((pred == s_list[s])*(gt == s_list[s]))
108 | # FN
109 | FN = np.sum((pred != s_list[s])*(gt == s_list[s]))
110 | # for category 0 this value is the specificity; for the other categories it is the sensitivity
111 | sensitivity = TP/(TP+FN+0.0001)
112 | sensitivity_s.append(sensitivity)
113 | return sensitivity_s
114 |
115 | if __name__ == '__main__':
116 | pred = np.array([[1,1,1,1],[1,1,1,1],[0,0,0,0]])
117 | gt = np.array([[0,0,0,0],[1,1,1,1],[0,1,1,0]])
118 | sensi = sensitivity(pred, gt, 2)
119 | print('sensitivity:', sensi)
--------------------------------------------------------------------------------
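A worked example for `dice_n_class` in `seg_eval.py` above (illustrative): every label in `range(output_chn)` is scored, so a class absent from both volumes gets a Dice of ~0 instead of being skipped:

import numpy as np

pred = np.array([[1, 1], [0, 0]])
gt = np.array([[1, 0], [0, 0]])
print(dice_n_class(pred, gt, output_chn=3))  # ~[0.80, 0.67, 0.0]
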
/models/trustedseg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import os
4 | import torch.nn as nn
5 | import time
6 | import torch.nn.functional as F
7 | from models.criterions import softmaxBCE_dice, KL, ce_loss, mse_loss, dce_eviloss
8 | from predict import tailor_and_concat
9 | from models.lib.VNet3D import VNet
10 | from models.lib.UNet3DZoo import Unet, AttUnet
11 | from models.lib.TransBTS_downsample8x_skipconnection import TransBTS
12 |
13 | class TMSU(nn.Module):
14 |
15 |     def __init__(self, classes, modes, model, input_dims, total_epochs, lambda_epochs=1):
16 |         """
17 |         :param classes: number of segmentation categories
18 |         :param modes: number of input modalities
19 |         :param model: backbone name ('U', 'AU', 'V' or 'TransU'); input_dims='four' selects four-channel input
20 |         :param total_epochs: total number of training epochs
21 |         :param lambda_epochs: KL-divergence annealing epochs during training
22 |         """
22 | super(TMSU, self).__init__()
23 | # ---- Net Backbone ----
24 |         if model == 'AU' and input_dims == 'four':
25 |             self.backbone = AttUnet(in_channels=4, base_channels=16, num_classes=classes)
26 |         elif model == 'AU':
27 |             self.backbone = AttUnet(in_channels=1, base_channels=16, num_classes=classes)
28 |         elif model == 'V' and input_dims == 'four':
29 |             self.backbone = VNet(n_channels=4, n_classes=classes, n_filters=16, normalization='gn', has_dropout=False)
30 |         elif model == 'V':
31 |             self.backbone = VNet(n_channels=1, n_classes=classes, n_filters=16, normalization='gn', has_dropout=False)
32 |         elif model == 'TransU':
33 |             _, self.backbone = TransBTS(input_dims=input_dims, _conv_repr=True, _pe_type="learned")
34 |         elif model == 'U' and input_dims == 'four':
35 |             self.backbone = Unet(in_channels=4, base_channels=16, num_classes=classes)
36 |         else:
37 |             self.backbone = Unet(in_channels=1, base_channels=16, num_classes=classes)
38 | self.backbone.cuda()
39 | self.modes = modes
40 | self.classes = classes
41 | self.eps = 1e-10
42 | self.lambda_epochs = lambda_epochs
43 | self.total_epochs = total_epochs+1
44 | # self.Classifiers = nn.ModuleList([Classifier(classifier_dims[i], self.classes) for i in range(self.modes)])
45 |
46 | def forward(self, X, y, global_step, mode, use_TTA=False):
47 |         # X: input volume(s)
48 |         # y: segmentation target
49 |         # global_step: current epoch, used for KL annealing
50 |
51 | # step zero: backbone
52 | if mode == 'train':
53 | backbone_output = self.backbone(X)
54 | elif mode == 'val':
55 | backbone_output = tailor_and_concat(X, self.backbone)
56 | # backbone_X = F.softmax(backbone_X,dim=1)
57 | else:
58 | if not use_TTA:
59 | backbone_output = tailor_and_concat(X, self.backbone)
60 | # backbone_X = F.softmax(backbone_X,dim=1)
61 | else:
62 | x = X
63 | x = x[..., :155]
64 | logit = F.softmax(tailor_and_concat(x, self.backbone), 1) # no flip
65 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2,)), self.backbone).flip(dims=(2,)), 1) # flip H
66 | logit += F.softmax(tailor_and_concat(x.flip(dims=(3,)), self.backbone).flip(dims=(3,)), 1) # flip W
67 | logit += F.softmax(tailor_and_concat(x.flip(dims=(4,)), self.backbone).flip(dims=(4,)), 1) # flip D
68 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 3)), self.backbone).flip(dims=(2, 3)),
69 | 1) # flip H, W
70 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 4)), self.backbone).flip(dims=(2, 4)),
71 | 1) # flip H, D
72 | logit += F.softmax(tailor_and_concat(x.flip(dims=(3, 4)), self.backbone).flip(dims=(3, 4)),
73 | 1) # flip W, D
74 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 3, 4)), self.backbone).flip(dims=(2, 3, 4)),
75 | 1) # flip H, W, D
76 | backbone_output = logit / 8.0 # mean
77 | # backbone_X = F.softmax(backbone_X,dim=1)
78 |
79 | # step one
80 | evidence = self.infer(backbone_output) # batch_size * class * image_size
81 |
82 | # step two
83 | alpha = evidence + 1
84 | if mode == 'train' or mode == 'val':
85 | loss = dce_eviloss(y.to(torch.int64), alpha, self.classes, global_step, self.lambda_epochs)
86 | loss = torch.mean(loss)
87 | return evidence, loss
88 | else:
89 | return evidence
90 |
91 | def infer(self, input):
92 | """
93 | :param input: modal data
94 | :return: evidence of modal data
95 | """
96 | # evidence = (input-torch.min(input))/(torch.max(input)-torch.min(input))
97 | evidence = F.softplus(input)
98 | # evidence[m_num] = torch.exp(torch.clamp(evidence, -10, 10))
99 | # evidence = F.relu(evidence)
100 | return evidence
101 |
102 |
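
# Hedged usage sketch (added for illustration; assumes a CUDA device and
# BraTS-sized crops). In 'train' mode forward() returns (evidence, loss);
# in test mode it returns the evidence alone, from which the Dirichlet
# parameters and a per-voxel vacuity uncertainty can be derived.
def _demo_tmsu_train_step():
    model = TMSU(classes=4, modes=4, model='U', input_dims='four',
                 total_epochs=200, lambda_epochs=50)
    x = torch.randn(1, 4, 128, 128, 128).cuda()   # one four-modality crop
    y = torch.zeros(1, 128, 128, 128, dtype=torch.long).cuda()
    evidence, loss = model(x, y, global_step=1, mode='train')
    alpha = evidence + 1                          # Dirichlet parameters
    u = model.classes / torch.sum(alpha, dim=1)   # per-voxel vacuity
    return loss, u
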
--------------------------------------------------------------------------------
/numpyfunctions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sklearn.metrics as metrics
3 | import pymia.evaluation.metric as m
4 |
5 |
6 | def ece_binary(probabilities, target, n_bins=10, threshold_range: tuple = None, mask=None, out_bins: dict = None,
7 | bin_weighting='proportion'):
8 |
9 | n_dim = target.ndim
10 |
11 | pos_frac, mean_confidence, bin_count, non_zero_bins = \
12 | binary_calibration(probabilities, target, n_bins, threshold_range, mask)
13 |
14 | bin_proportions = _get_proportion(bin_weighting, bin_count, non_zero_bins, n_dim)
15 |
16 | if out_bins is not None:
17 | out_bins['bins_count'] = bin_count
18 | out_bins['bins_avg_confidence'] = mean_confidence
19 | out_bins['bins_positive_fraction'] = pos_frac
20 | out_bins['bins_non_zero'] = non_zero_bins
21 |
22 | ece = (np.abs(mean_confidence - pos_frac) * bin_proportions).sum()
23 | return ece
24 |
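
# Illustrative sketch (added): with the default 10 bins, confident-and-correct
# probabilities give a small ECE, confident-and-wrong ones a large ECE.
def _demo_ece_binary():
    target = np.array([0, 0, 1, 1])
    probs_good = np.array([0.1, 0.2, 0.8, 0.9])  # confident and correct
    probs_bad = np.array([0.9, 0.8, 0.2, 0.1])   # confident and wrong
    print(ece_binary(probs_good, target))  # 0.15
    print(ece_binary(probs_bad, target))   # 0.85
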
25 |
26 | def binary_calibration(probabilities, target, n_bins=10, threshold_range: tuple = None, mask=None):
27 | if probabilities.ndim > target.ndim:
28 | if probabilities.shape[-1] > 2:
29 | raise ValueError('can only evaluate the calibration for binary classification')
30 | elif probabilities.shape[-1] == 2:
31 | probabilities = probabilities[..., 1]
32 | else:
33 | probabilities = np.squeeze(probabilities, axis=-1)
34 |
35 | if mask is not None:
36 | probabilities = probabilities[mask]
37 | target = target[mask]
38 |
39 | if threshold_range is not None:
40 | low_thres, up_thres = threshold_range
41 | mask = np.logical_and(probabilities < up_thres, probabilities > low_thres)
42 | probabilities = probabilities[mask]
43 | target = target[mask]
44 |
45 | pos_frac, mean_confidence, bin_count, non_zero_bins = \
46 | _binary_calibration(target.flatten(), probabilities.flatten(), n_bins)
47 |
48 | return pos_frac, mean_confidence, bin_count, non_zero_bins
49 |
50 |
51 | def _binary_calibration(target, probs_positive_cls, n_bins=10):
52 | # same as sklearn.calibration calibration_curve but with the bin_count returned
53 | bins = np.linspace(0., 1. + 1e-8, n_bins + 1)
54 | binids = np.digitize(probs_positive_cls, bins) - 1
55 |
56 | # # note: this is the original formulation which has always n_bins + 1 as length
57 | # bin_sums = np.bincount(binids, weights=probs_positive_cls, minlength=len(bins))
58 | # bin_true = np.bincount(binids, weights=target, minlength=len(bins))
59 | # bin_total = np.bincount(binids, minlength=len(bins))
60 |
61 | bin_sums = np.bincount(binids, weights=probs_positive_cls, minlength=n_bins)
62 | bin_true = np.bincount(binids, weights=target, minlength=n_bins)
63 | bin_total = np.bincount(binids, minlength=n_bins)
64 |
65 | nonzero = bin_total != 0
66 | prob_true = (bin_true[nonzero] / bin_total[nonzero])
67 | prob_pred = (bin_sums[nonzero] / bin_total[nonzero])
68 |
69 | return prob_true, prob_pred, bin_total[nonzero], nonzero
70 |
71 |
72 | def _get_proportion(bin_weighting: str, bin_count: np.ndarray, non_zero_bins: np.ndarray, n_dim: int):
73 | if bin_weighting == 'proportion':
74 | bin_proportions = bin_count / bin_count.sum()
75 | elif bin_weighting == 'log_proportion':
76 | bin_proportions = np.log(bin_count) / np.log(bin_count).sum()
77 | elif bin_weighting == 'power_proportion':
78 | bin_proportions = bin_count**(1/n_dim) / (bin_count**(1/n_dim)).sum()
79 | elif bin_weighting == 'mean_proportion':
80 | bin_proportions = 1 / non_zero_bins.sum()
81 | else:
82 | raise ValueError('unknown bin weighting "{}"'.format(bin_weighting))
83 | return bin_proportions
84 |
85 |
86 | def uncertainty(prediction, target, thresholded_uncertainty, mask=None):
87 | if mask is not None:
88 | prediction = prediction[mask]
89 | target = target[mask]
90 | thresholded_uncertainty = thresholded_uncertainty[mask]
91 |
92 | tps = np.logical_and(target, prediction)
93 | tns = np.logical_and(~target, ~prediction)
94 | fps = np.logical_and(~target, prediction)
95 | fns = np.logical_and(target, ~prediction)
96 |
97 | tpu = np.logical_and(tps, thresholded_uncertainty).sum()
98 | tnu = np.logical_and(tns, thresholded_uncertainty).sum()
99 | fpu = np.logical_and(fps, thresholded_uncertainty).sum()
100 | fnu = np.logical_and(fns, thresholded_uncertainty).sum()
101 |
102 | tp = tps.sum()
103 | tn = tns.sum()
104 | fp = fps.sum()
105 | fn = fns.sum()
106 |
107 | return tp, tn, fp, fn, tpu, tnu, fpu, fnu
108 |
109 |
110 | def error_dice(fp, fn, tpu, tnu, fpu, fnu):
111 | if ((fnu + fpu) == 0) and ((fn + fp + fnu + fpu + tnu + tpu) == 0):
112 | return 1.
113 | return (2 * (fnu + fpu)) / (fn + fp + fnu + fpu + tnu + tpu)
114 |
115 |
116 | def error_recall(fp, fn, fpu, fnu):
117 | if ((fnu + fpu) == 0) and ((fn + fp) == 0):
118 | return 1.
119 | return (fnu + fpu) / (fn + fp)
120 |
121 |
122 | def error_precision(tpu, tnu, fpu, fnu):
123 | if ((fnu + fpu) == 0) and ((fnu + fpu + tpu + tnu) == 0):
124 | return 1.
125 | return (fnu + fpu) / (fnu + fpu + tpu + tnu)
126 |
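
# Illustrative sketch (added): chaining uncertainty() into the error_* metrics.
# High error recall means most mis-segmented voxels were flagged as uncertain;
# high error precision means few correct voxels were flagged.
def _demo_error_metrics():
    target = np.array([1, 1, 0, 0], dtype=bool)
    prediction = np.array([1, 0, 1, 0], dtype=bool)
    flagged = np.array([0, 1, 1, 0], dtype=bool)  # thresholded uncertainty
    tp, tn, fp, fn, tpu, tnu, fpu, fnu = uncertainty(prediction, target, flagged)
    print(error_recall(fp, fn, fpu, fnu))       # 1.0: both errors were flagged
    print(error_precision(tpu, tnu, fpu, fnu))  # 1.0: only errors were flagged
    print(error_dice(fp, fn, tpu, tnu, fpu, fnu))
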
127 |
128 | def dice(prediction, target):
129 | _check_ndarray(prediction)
130 | _check_ndarray(target)
131 |
132 | d = m.DiceCoefficient()
133 | d.confusion_matrix = m.ConfusionMatrix(prediction, target)
134 | return d.calculate()
135 |
136 |
137 | def confusion_matrx(prediction, target):
138 | _check_ndarray(prediction)
139 | _check_ndarray(target)
140 |
141 | cm = m.ConfusionMatrix(prediction, target)
142 | return cm.tp, cm.tn, cm.fp, cm.fn, cm.n
143 |
144 |
145 | def accuracy(prediction, target):
146 | _check_ndarray(prediction)
147 | _check_ndarray(target)
148 |
149 | a = m.Accuracy()
150 | a.confusion_matrix = m.ConfusionMatrix(prediction, target)
151 | return a.calculate()
152 |
153 |
154 | def log_loss_sklearn(probabilities, target, labels=None):
155 | _check_ndarray(probabilities)
156 | _check_ndarray(target)
157 |
158 | if probabilities.shape[-1] != target.shape[-1]:
159 | probabilities = probabilities.reshape(-1, probabilities.shape[-1])
160 | else:
161 | probabilities = probabilities.reshape(-1)
162 | target = target.reshape(-1)
163 | return metrics.log_loss(target, probabilities, labels=labels)
164 |
165 |
166 | def entropy(p, dim=-1, keepdims=False):
167 |     # matches scipy.stats.entropy() when p is already normalized (scipy normalizes its input first)
168 | return -np.where(p > 0, p * np.log(p), [0.0]).sum(axis=dim, keepdims=keepdims)
169 |
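
# Illustrative sketch (added): a uniform distribution over n outcomes has
# entropy log(n); a one-hot distribution has entropy 0.
def _demo_entropy():
    print(entropy(np.array([0.25, 0.25, 0.25, 0.25])))  # log(4) ~ 1.386
    print(entropy(np.array([1.0, 0.0, 0.0, 0.0])))      # 0.0
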
170 |
171 | def _check_ndarray(obj):
172 | if not isinstance(obj, np.ndarray):
173 | raise ValueError("object of type '{}' must be '{}'".format(type(obj).__name__, np.ndarray.__name__))
174 |
--------------------------------------------------------------------------------
/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import os
3 | def loss_plot(args, loss):
4 |     num = args.end_epoch
5 |     x = list(range(num))
6 |     plot_save_path = r'results/plot/'
7 |     if not os.path.exists(plot_save_path):
8 |         os.makedirs(plot_save_path)
9 |     save_loss = plot_save_path + str(args.model_name) + '_' + str(args.batch_size) + '_' + str(args.dataset) + '_' + str(args.end_epoch) + '_loss.jpg'
10 |     plt.figure()
11 |     plt.plot(x, loss, label='loss')
12 |     plt.legend()
13 |     plt.savefig(save_loss)
14 |     plt.close()  # free the figure so repeated calls do not accumulate open figures

15 | def metrics_plot(arg, name, *args):
16 |     num = arg.end_epoch
17 |     names = name.split('&')
18 |     metrics_value = args
19 |     x = list(range(num))
20 |     plot_save_path = r'results/plot/'
21 |     if not os.path.exists(plot_save_path):
22 |         os.makedirs(plot_save_path)
23 |     save_metrics = plot_save_path + str(arg.model_name) + '_' + str(arg.batch_size) + '_' + str(arg.dataset) + '_' + str(arg.end_epoch) + '_' + name + '.jpg'
24 |     plt.figure()
25 |     for i, l in enumerate(metrics_value):
26 |         plt.plot(x, l, label=str(names[i]))
27 |     plt.legend()
28 |     plt.savefig(save_metrics)
29 |     plt.close()
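
# Hedged usage sketch (added; assumes an argparse-style namespace carrying the
# fields the plotters read: model_name, batch_size, dataset, end_epoch).
def _demo_plots():
    from types import SimpleNamespace
    args = SimpleNamespace(model_name='TBraTS', batch_size=2,
                           dataset='BraTS2019', end_epoch=3)
    loss_plot(args, [0.9, 0.5, 0.3])
    metrics_plot(args, 'dice&iou', [0.5, 0.7, 0.8], [0.4, 0.6, 0.7])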
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.backends.cudnn as cudnn
8 | import SimpleITK as sitk
9 | import cv2
10 | import math
11 | from medpy.metric import binary
12 | import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23
13 | from binary import assd
14 | cudnn.benchmark = True
15 | # import models.criterions as U_entropy
16 | import numpy as np
17 | import nibabel as nib
18 | import imageio
19 | from test_uncertainty import ece_binary,UncertaintyAndCorrectionEvalNumpy,Normalized_U
20 |
21 | def cal_ueo(to_evaluate,thresholds):
22 | UEO = []
23 | for threshold in thresholds:
24 | results = dict()
25 | metric = UncertaintyAndCorrectionEvalNumpy(threshold)
26 | metric(to_evaluate,results)
27 | ueo = results['corrected_add_dice']
28 | UEO.append(ueo)
29 | max_UEO = max(UEO)
30 | return max_UEO
31 |
32 | def cal_ece(logits,targets):
33 | # ece_total = 0
34 | logit = logits
35 | target = targets
36 | pred = F.softmax(logit, dim=0)
37 | pc = pred.cpu().detach().numpy()
38 | pc = pc.argmax(0)
39 | ece = ece_binary(pc, target)
40 | return ece
41 |
42 | def cal_ece_our(preds,targets):
43 | # ece_total = 0
44 | target = targets
45 | pc = preds.cpu().detach().numpy()
46 | ece = ece_binary(pc, target)
47 | return ece
48 |
49 | def Uentropy(logits, c):
50 |     # normalized entropy of the softmax distribution over c classes
51 |     pc = F.softmax(logits, dim=1)         # 1 x 4 x 240 x 240 x 155
52 |     logpc = F.log_softmax(logits, dim=1)  # 1 x 4 x 240 x 240 x 155
53 |     u_all = -pc * logpc / math.log(c)     # division by log(c) scales each term to [0, 1]
54 |     # sum over the foreground channels; background channel 0 is excluded
55 |     NU = torch.sum(u_all[:, 1:u_all.shape[1], :, :], dim=1)
56 |     return NU
57 |
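
# Illustrative sketch (added): Uentropy on random logits. Division by log(c)
# normalizes the per-channel entropy to [0, 1]; only the foreground channels
# are summed, so the result has the spatial shape of one channel.
def _demo_uentropy():
    logits = torch.randn(1, 4, 8, 8, 8)  # small stand-in for (1, 4, 240, 240, 155)
    nu = Uentropy(logits, 4)
    print(nu.shape)                      # torch.Size([1, 8, 8, 8])
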
58 | def Uentropy_our(logits, c):
59 |     # same normalized entropy, but the input is already a probability map
60 |     pc = logits                 # 1 x 4 x 240 x 240 x 155
61 |     logpc = torch.log(logits)   # 1 x 4 x 240 x 240 x 155
62 |     u_all = -pc * logpc / math.log(c)
63 |     # sum over the foreground channels; background channel 0 is excluded
64 |     NU = torch.sum(u_all[:, 1:u_all.shape[1], :, :], dim=1)
65 |     return NU
72 |
73 | def one_hot(ori, classes):
74 |     # one-hot encode an integer label volume: (B, H, W, D) -> (B, classes, H, W, D)
75 |     batch, h, w, d = ori.size()
76 |     new_gd = torch.zeros((batch, classes, h, w, d), dtype=ori.dtype).cuda()
77 |     for j in range(classes):
78 |         new_gd[:, j][ori == j] = 1  # vectorized; equivalent to the original per-voxel loop
79 |     return new_gd.float()
85 |
86 | def tailor_and_concat(x, model):
87 | temp = []
88 |
89 | temp.append(x[..., :128, :128, :128])
90 | temp.append(x[..., :128, 112:240, :128])
91 | temp.append(x[..., 112:240, :128, :128])
92 | temp.append(x[..., 112:240, 112:240, :128])
93 | temp.append(x[..., :128, :128, 27:155])
94 | temp.append(x[..., :128, 112:240, 27:155])
95 | temp.append(x[..., 112:240, :128, 27:155])
96 | temp.append(x[..., 112:240, 112:240, 27:155])
97 |
98 | if x.shape[1] == 1:
99 | y = torch.cat((x.clone(), x.clone(), x.clone(), x.clone()), 1)
100 | elif x.shape[1] == 4:
101 | y = x.clone()
102 | else:
103 | y = torch.cat((x.clone(), x.clone()), 1)
104 |
105 | for i in range(len(temp)):
106 | temp[i] = model(temp[i])
107 | # .squeeze(0)
108 | # l= temp[0].unsqueeze(0)
109 | y[..., :128, :128, :128] = temp[0]
110 | y[..., :128, 128:240, :128] = temp[1][..., :, 16:128, :]
111 | y[..., 128:240, :128, :128] = temp[2][..., 16:128, :, :]
112 | y[..., 128:240, 128:240, :128] = temp[3][..., 16:128, 16:128, :]
113 |     y[..., :128, :128, 128:155] = temp[4][..., 101:128]  # depth crops start at slice 27, so absolute 128:155 is local 101:128 (96:123 was off by five)
114 |     y[..., :128, 128:240, 128:155] = temp[5][..., :, 16:128, 101:128]
115 |     y[..., 128:240, :128, 128:155] = temp[6][..., 16:128, :, 101:128]
116 |     y[..., 128:240, 128:240, 128:155] = temp[7][..., 16:128, 16:128, 101:128]
117 |
118 | return y[..., :155]
119 |
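
# Hedged sanity check (added; runs on CPU): with an identity "model",
# tailor_and_concat stitches the eight overlapping 128^3 crops back into the
# full 240 x 240 x 155 volume.
def _demo_tailor_and_concat():
    x = torch.zeros(1, 4, 240, 240, 160)  # depth padded past 155, as in the loaders
    y = tailor_and_concat(x, lambda crop: crop)
    print(y.shape)                        # torch.Size([1, 4, 240, 240, 155])
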
120 | def hausdorff_distance(lT,lP):
121 | labelPred=sitk.GetImageFromArray(lP, isVector=False)
122 | labelTrue=sitk.GetImageFromArray(lT, isVector=False)
123 | hausdorffcomputer=sitk.HausdorffDistanceImageFilter()
124 | hausdorffcomputer.Execute(labelTrue>0.5,labelPred>0.5)
125 | return hausdorffcomputer.GetAverageHausdorffDistance()#hausdorffcomputer.GetHausdorffDistance()
126 |
127 | def hd_score(o,t, eps=1e-8):
128 | if (o.sum()==0) | (t.sum()==0):
129 | hd = eps,
130 | else:
131 | #ret += hausdorff_distance(wt_mask, wt_pb),
132 | hd = binary.hd95(o, t, voxelspacing=None),
133 |
134 | return hd
135 |
136 | def dice_score(o, t, eps=1e-8):
137 | if (o.sum()==0) | (t.sum()==0):
138 | dice = eps
139 | else:
140 | num = 2*(o*t).sum() + eps
141 | den = o.sum() + t.sum() + eps
142 | dice = num/den
143 | return dice
144 |
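
# Illustrative sketch (added): dice_score on boolean masks. The eps terms keep
# the ratio finite, and an empty prediction or target returns eps by convention.
def _demo_dice_score():
    o = np.array([1, 1, 0, 0], dtype=bool)
    t = np.array([1, 0, 1, 0], dtype=bool)
    print(dice_score(o, t))  # 2*1 / (2 + 2) = 0.5 (up to eps)
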
145 |
146 | def mIOU(o, t, eps=1e-8):
147 | num = (o*t).sum() + eps
148 | den = (o | t).sum() + eps
149 | return num/den
150 |
151 | def assd_score(o, t):
152 | s = assd(o, t)
153 | return s
154 |
155 | def softmax_mIOU_score(output, target):
156 | mIOU_score = []
157 | mIOU_score.append(mIOU(o=(output==1),t=(target==1)))
158 | mIOU_score.append(mIOU(o=(output==2),t=(target==2)))
159 | mIOU_score.append(mIOU(o=(output==3),t=(target==3)))
160 | return mIOU_score
161 |
162 | def softmax_output_hd(output, target):
163 | ret = []
164 |
165 | # whole (label: 1 ,2 ,3)
166 | o = output > 0; t = target > 0 # ce
167 | ret += hd_score(o, t),
168 | # core (tumor core 1 and 3)
169 | o = (output == 1) | (output == 3)
170 | t = (target == 1) | (target == 3)
171 | ret += hd_score(o, t),
172 |     # active (enhancing tumor region, label 3)
173 | o = (output == 3);t = (target == 3)
174 | ret += hd_score(o, t),
175 |
176 | return ret
177 |
178 | def softmax_output_assd(output, target):
179 | ret = []
180 |
181 | # whole (label: 1 ,2 ,3)
182 | wt_o = output > 0; wt_t = target > 0 # ce
183 | ret += assd_score(wt_o, wt_t),
184 | # core (tumor core 1 and 3)
185 | tc_o = (output == 1) | (output == 3)
186 | tc_t = (target == 1) | (target == 3)
187 | ret += assd_score(tc_o, tc_t),
188 |     # active (enhancing tumor region, label 3)
189 | et_o = (output == 3);et_t = (target == 3)
190 | ret += assd_score(et_o, et_t),
191 |
192 | return ret
193 |
194 | def softmax_output_dice(output, target):
195 | ret = []
196 |
197 | # whole (label: 1 ,2 ,3)
198 | o = output > 0; t = target > 0 # ce
199 | # print(o.shape)
200 | # print(t.shape)
201 | ret += dice_score(o, t),
202 | # core (tumor core 1 and 3)
203 | o = (output == 1) | (output == 3)
204 | t = (target == 1) | (target == 3)
205 | ret += dice_score(o, t),
206 |     # active (enhancing tumor region, label 3)
207 | o = (output == 3);t = (target == 3)
208 | ret += dice_score(o, t),
209 |
210 | return ret
211 |
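
# Illustrative sketch (added): softmax_output_dice scores the three nested
# BraTS regions of a label map with classes {0, 1, 2, 3}: whole tumor
# (label > 0), tumor core (labels 1 and 3) and enhancing tumor (label 3).
def _demo_region_dice():
    output = np.array([0, 1, 2, 3])
    target = np.array([0, 1, 2, 3])
    print(softmax_output_dice(output, target))  # [~1.0, ~1.0, ~1.0]
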
212 |
213 | keys = 'whole', 'core', 'enhancing', 'loss'
214 |
215 | def validate_softmax(
216 | save_dir,
217 | best_dice,
218 | current_epoch,
219 | end_epoch,
220 | save_freq,
221 | valid_loader,
222 | model,
223 | multimodel,
224 | Net_name,
225 |         names=None,  # the patient names, in order
226 | ):
227 |
228 | H, W, T = 240, 240, 160
229 |
230 | runtimes = []
231 | dice_total = 0
232 | iou_total = 0
233 | num = len(valid_loader)
234 |
235 | for i, data in enumerate(valid_loader):
236 | print('-------------------------------------------------------------------')
237 | msg = 'Subject {}/{}, '.format(i + 1, len(valid_loader))
238 | x, target = data
239 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
240 | x = x.to(device)
241 | # target = target.to(device)
242 | target = torch.squeeze(target).cpu().numpy()
243 | torch.cuda.synchronize() # add the code synchronize() to correctly count the runtime.
244 | start_time = time.time()
245 | logit = tailor_and_concat(x, model)
246 |
247 | torch.cuda.synchronize()
248 | elapsed_time = time.time() - start_time
249 | logging.info('Single sample test time consumption {:.2f} minutes!'.format(elapsed_time / 60))
250 | runtimes.append(elapsed_time)
251 |
252 | output = F.softmax(logit, dim=1)
253 | output = output[0, :, :H, :W, :T].cpu().detach().numpy()
254 | output = output.argmax(0)
255 | iou_res = softmax_mIOU_score(output, target[:, :, :155])
256 | dice_res = softmax_output_dice(output,target[:,:,:155])
257 | # hd_res = softmax_output_hd(output, target[:, :, :155])
258 | dice_total += dice_res[1]
259 | iou_total += iou_res[1]
260 | name = str(i)
261 | if names:
262 | name = names[i]
263 | msg += '{:>20}, '.format(name)
264 | print(msg)
265 | print('current_dice:{}'.format(dice_res))
266 | # print('current_dice:{},hd_res:{}'.format(dice_res,hd_res))
267 | aver_dice = dice_total / num
268 | aver_iou = iou_total / num
269 | if (current_epoch + 1) % int(save_freq) == 0:
270 | if aver_dice > best_dice\
271 | or (current_epoch + 1) % int(end_epoch - 1) == 0 \
272 | or (current_epoch + 1) % int(end_epoch - 2) == 0 \
273 | or (current_epoch + 1) % int(end_epoch - 3) == 0:
274 | print('aver_dice:{} > best_dice:{}'.format(aver_dice, best_dice))
275 | logging.info('aver_dice:{} > best_dice:{}'.format(aver_dice, best_dice))
276 | logging.info('===========>save best model!')
277 | best_dice = aver_dice
278 | print('===========>save best model!')
279 | file_name = os.path.join(save_dir, Net_name +'_' + multimodel + '_epoch_{}.pth'.format(current_epoch))
280 | torch.save({
281 | 'epoch': current_epoch,
282 | 'state_dict': model.state_dict(),
283 | },
284 | file_name)
285 |         print('mean runtime per subject (s):', sum(runtimes)/len(runtimes))
286 |
287 | return best_dice,aver_dice,aver_iou
288 |
289 | def test_softmax(
290 | test_loader,
291 | model,
292 | multimodel,
293 | Net_name,
294 | Variance,
295 | load_file,
296 |         savepath='',  # for the validation set, specify the path to save the 'nii' segmentation results here
297 |         names=None,  # the patient names, in order
298 | verbose=False,
299 | use_TTA=False, # Test time augmentation, False as default!
300 | save_format=None, # ['nii','npy'], use 'nii' as default. Its purpose is for submission.
301 | # snapshot=False, # for visualization. Default false. It is recommended to generate the visualized figures.
302 | # visual='', # the path to save visualization
303 | ):
304 |
305 | H, W, T = 240, 240, 155
306 | # model.eval()
307 |
308 | runtimes = []
309 | dice_total_WT = 0
310 | dice_total_TC = 0
311 | dice_total_ET = 0
312 | hd_total_WT = 0
313 | hd_total_TC = 0
314 | hd_total_ET = 0
315 | assd_total_WT = 0
316 | assd_total_TC = 0
317 | assd_total_ET = 0
318 |
319 | noise_dice_total_WT = 0
320 | noise_dice_total_TC = 0
321 | noise_dice_total_ET = 0
322 | noise_hd_total_WT = 0
323 | noise_hd_total_TC = 0
324 | noise_hd_total_ET = 0
325 | noise_assd_total_WT = 0
326 | noise_assd_total_TC = 0
327 | noise_assd_total_ET = 0
328 | mean_uncertainty_total = 0
329 | noise_mean_uncertainty_total = 0
330 | certainty_total = 0
331 | noise_certainty_total = 0
332 | num = len(test_loader)
333 | mne_total = 0
334 | noise_mne_total = 0
335 | ece_total = 0
336 | noise_ece_total = 0
337 |     ece = 0        # note: ECE is never updated in this function, so the ECE totals below stay zero
338 |     noise_ece = 0
339 | ueo_total = 0
340 | noise_ueo_total = 0
341 | for i, data in enumerate(test_loader):
342 | print('-------------------------------------------------------------------')
343 | msg = 'Subject {}/{}, '.format(i+1, len(test_loader))
344 | x, target = data
345 | # noise_m = torch.randn_like(x) * Variance
346 | # noise = torch.clamp(torch.randn_like(x) * Variance, -Variance * 2, Variance * 2)
347 | # noise = torch.clamp(torch.randn_like(x) * Variance, -Variance, Variance)
348 | # noise = torch.clamp(torch.randn_like(x) * Variance)
349 | # noised_x = x + noise_m
350 |
351 | noise_m = torch.randn_like(x) * Variance
352 | noised_x = x + noise_m
353 | # if multimodel=='both':
354 | # noised_x[:, 0, ...] = x[:, 0, ...]
355 |         x = x.cuda()  # .cuda() is not in-place; the result must be reassigned
356 |         noised_x = noised_x.cuda()
357 | target = torch.squeeze(target).cpu().numpy()
358 | mean_uncertainty = torch.zeros(0)
359 | noised_mean_uncertainty = torch.zeros(0)
360 | # output = np.zeros((4, x.shape[2], x.shape[3], 155),dtype='float32')
361 | # noised_output = np.zeros((4, x.shape[2], x.shape[3], 155),dtype='float32')
362 | pc = np.zeros((x.shape[2], x.shape[3], 155),dtype='float32')
363 | noised_pc = np.zeros((x.shape[2], x.shape[3], 155),dtype='float32')
364 | if not use_TTA:
365 | # torch.cuda.synchronize() # add the code synchronize() to correctly count the runtime.
366 | # start_time = time.time()
367 |             model_len = 2.0  # two-modality or four-modality input
368 | if Net_name =='Udrop':
369 | T_drop = 2
370 | uncertainty = torch.zeros(1, x.shape[2], x.shape[3], 155)
371 | noised_uncertainty = torch.zeros(1, x.shape[2], x.shape[3], 155)
372 | for j in range(T_drop):
373 | print('dropout time:{}'.format(j))
374 | logit = tailor_and_concat(x, model) # 1 4 240 240 155
375 | logit_noise = tailor_and_concat(noised_x, model) # 1 4 240 240 155
376 | uncertainty += Uentropy(logit, 4)
377 | noised_uncertainty += Uentropy(logit_noise, 4)
378 | logit = F.softmax(logit, dim=1)
379 | output = logit / model_len
380 | output = output[0, :, :H, :W, :T].cpu().detach().numpy()
381 | pc += output.argmax(0)
382 | # for noise
383 | logit_noise = F.softmax(logit_noise, dim=1)
384 | noised_output = logit_noise / model_len
385 | noised_output = noised_output[0, :, :H, :W, :T].cpu().detach().numpy()
386 | noised_pc += noised_output.argmax(0)
387 | pc = pc / T_drop
388 | noised_pc = noised_pc / T_drop
389 | uncertainty = torch.squeeze(uncertainty) / T_drop
390 | noised_uncertainty = torch.squeeze(noised_uncertainty) / T_drop
391 |
392 | # logit = logit/T_drop
393 | # logit_noise= logit_noise/T_drop
394 | # Udropout_uncertainty=joblib.load('Udropout_uncertainty.pkl')
395 | else:
396 | logit = tailor_and_concat(x, model)
397 | logit = F.softmax(logit, dim=1)
398 | output = logit / model_len
399 | output = output[0, :, :H, :W, :T].cpu().detach().numpy()
400 | pc = output.argmax(0)
401 | # for input noise
402 | logit_noise = tailor_and_concat(noised_x, model)
403 | logit_noise = F.softmax(logit_noise, dim=1)
404 | noised_output = logit_noise / model_len
405 | noised_output = noised_output[0, :, :H, :W, :T].cpu().detach().numpy()
406 | noised_pc = noised_output.argmax(0)
407 | uncertainty = Uentropy(logit, 4)
408 | noised_uncertainty = Uentropy(logit_noise, 4)
409 | mean_uncertainty = torch.mean(uncertainty)
410 | noised_mean_uncertainty = torch.mean(noised_uncertainty)
411 | print('current_mean_uncertainty:{} ; current_noised_mean_uncertainty:{}'.format(mean_uncertainty, noised_mean_uncertainty))
412 | if Net_name == 'Udrop':
413 | joblib.dump({'pc': pc,
414 | 'noised_pc': noised_pc,'noised_uncertainty': noised_uncertainty,
415 | 'uncertainty': uncertainty}, 'Udropout_uncertainty_{}.pkl'.format(i))
416 |
417 | # Udropout_uncertainty = joblib.load('Udropout_uncertainty.pkl')
418 |
419 | # lnear = F.softplus(logit)
420 | # torch.cuda.synchronize()
421 | # elapsed_time = time.time() - start_time
422 | # logging.info('Single sample test time consumption {:.2f} minutes!'.format(elapsed_time/60))
423 | # runtimes.append(elapsed_time)
424 |
425 |
426 | # if multimodel == 'both':
427 | # model_len = 2.0 # two modality or four modality
428 | # # logit = F.softmax(logit, dim=1)
429 | # # output = logit / model_len
430 | #
431 | # # for noise
432 | # logit_noise = F.softmax(logit_noise, dim=1)
433 | # noised_output = logit_noise / model_len
434 |
435 | # load_file1 = load_file.replace('7998', '7996')
436 | # if os.path.isfile(load_file1):
437 | # checkpoint = torch.load(load_file1)
438 | # model.load_state_dict(checkpoint['state_dict'])
439 | # print('Successfully load checkpoint {}'.format(load_file1))
440 | # logit = tailor_and_concat(x, model)
441 | # logit = F.softmax(logit, dim=1)
442 | # output += logit / model_len
443 | # load_file1 = load_file.replace('7998', '7997')
444 | # if os.path.isfile(load_file1):
445 | # checkpoint = torch.load(load_file1)
446 | # model.load_state_dict(checkpoint['state_dict'])
447 | # print('Successfully load checkpoint {}'.format(load_file1))
448 | # logit = tailor_and_concat(x, model)
449 | # logit = F.softmax(logit, dim=1)
450 | # output += logit / model_len
451 | # load_file1 = load_file.replace('7998', '7999')
452 | # if os.path.isfile(load_file1):
453 | # checkpoint = torch.load(load_file1)
454 | # model.load_state_dict(checkpoint['state_dict'])
455 | # print('Successfully load checkpoint {}'.format(load_file1))
456 | # logit = tailor_and_concat(x, model)
457 | # logit = F.softmax(logit, dim=1)
458 | # output += logit / model_len
459 | # else:
460 | # # output = F.softmax(logit, dim=1)
461 | # noised_output = F.softmax(logit_noise, dim=1)
462 | else:
463 | # x = x[..., :155]
464 | logit = F.softmax(tailor_and_concat(x, model), 1) # no flip
465 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2,)), model).flip(dims=(2,)), 1) # flip H
466 | logit += F.softmax(tailor_and_concat(x.flip(dims=(3,)), model).flip(dims=(3,)), 1) # flip W
467 | logit += F.softmax(tailor_and_concat(x.flip(dims=(4,)), model).flip(dims=(4,)), 1) # flip D
468 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 3)), model).flip(dims=(2, 3)), 1) # flip H, W
469 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 4)), model).flip(dims=(2, 4)), 1) # flip H, D
470 | logit += F.softmax(tailor_and_concat(x.flip(dims=(3, 4)), model).flip(dims=(3, 4)), 1) # flip W, D
471 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 3, 4)), model).flip(dims=(2, 3, 4)), 1) # flip H, W, D
472 | # for noise x
473 | noised_x = noised_x[..., :155]
474 | noised_logit = F.softmax(tailor_and_concat(noised_x, model), 1) # no flip
475 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2,)), model).flip(dims=(2,)), 1) # flip H
476 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(3,)), model).flip(dims=(3,)), 1) # flip W
477 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(4,)), model).flip(dims=(4,)), 1) # flip D
478 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2, 3)), model).flip(dims=(2, 3)), 1) # flip H, W
479 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2, 4)), model).flip(dims=(2, 4)), 1) # flip H, D
480 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(3, 4)), model).flip(dims=(3, 4)), 1) # flip W, D
481 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2, 3, 4)), model).flip(dims=(2, 3, 4)), 1) # flip H, W, D
482 | output = logit / 8.0 # mean
483 | noised_output = noised_logit / 8.0 # mean
484 | uncertainty = Uentropy(output, 4)
485 | noised_uncertainty = Uentropy(noised_output, 4)
486 | output = output[0, :, :H, :W, :T].cpu().detach().numpy()
487 | pc = output.argmax(0)
488 | noised_output = noised_output[0, :, :H, :W, :T].cpu().detach().numpy()
489 | noised_pc = noised_output.argmax(0)
490 | mean_uncertainty = torch.mean(uncertainty)
491 | noised_mean_uncertainty = torch.mean(noised_uncertainty)
492 | output = pc
493 | noised_output = noised_pc
494 | U_output = uncertainty.cpu().detach().numpy()
495 | NU_output = noised_uncertainty.cpu().detach().numpy()
496 |         certainty_total += mean_uncertainty
497 |         noise_certainty_total += noised_mean_uncertainty
498 | # ece
499 | ece_total += ece
500 | noise_ece_total += noise_ece
501 | # ueo
502 | # target = torch.squeeze(target).cpu().numpy()
503 | thresholds = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]
504 | to_evaluate = dict()
505 | to_evaluate['target'] = target[:, :, :155]
506 | u = torch.squeeze(uncertainty)
507 | U = u.cpu().detach().numpy()
508 | to_evaluate['prediction'] = output
509 | to_evaluate['uncertainty'] = U
510 | UEO = cal_ueo(to_evaluate, thresholds)
511 | ueo_total += UEO
512 | noise_to_evaluate = dict()
513 | noise_to_evaluate['target'] = target[:, :, :155]
514 | noise_u = torch.squeeze(noised_uncertainty)
515 | noise_U = noise_u.cpu().detach().numpy()
516 | noise_to_evaluate['prediction'] = noised_output
517 | noise_to_evaluate['uncertainty'] = noise_U
518 | noise_UEO = cal_ueo(noise_to_evaluate, thresholds)
519 | print('current_UEO:{};current_noise_UEO:{}; current_num:{}'.format(UEO, noise_UEO, i))
520 | noise_ueo_total += noise_UEO
521 | # print(output.shape)
522 | # print(target.shape)
523 | # output = output[0, :, :H, :W, :T].cpu().detach().numpy()
524 | # output = output.argmax(0)
525 | # print(output.shape)
526 | # iou_res = softmax_mIOU_score(pc, target[:, :, :155])
527 | hd_res = softmax_output_hd(pc, target[:, :, :155])
528 | dice_res = softmax_output_dice(pc,target[:,:,:155])
529 | assd_res = softmax_output_assd(pc,target[:, :, :155])
530 | dice_total_WT += dice_res[0]
531 | dice_total_TC += dice_res[1]
532 | dice_total_ET += dice_res[2]
533 | hd_total_WT += hd_res[0][0]
534 | hd_total_TC += hd_res[1][0]
535 | hd_total_ET += hd_res[2][0]
536 | assd_total_WT += assd_res[0]
537 | assd_total_TC += assd_res[1]
538 | assd_total_ET += assd_res[2]
539 |
540 | # for noise_x
541 | noised_output = noised_pc
542 | # noise_iou_res = softmax_mIOU_score(noised_pc, target[:, :, :155])
543 | noise_hd_res = softmax_output_hd(noised_pc, target[:, :, :155])
544 | noise_dice_res = softmax_output_dice(noised_pc,target[:,:,:155])
545 | noised_assd_res = softmax_output_assd(noised_pc,
546 | target[:, :, :155])
547 | noise_dice_total_WT += noise_dice_res[0]
548 | noise_dice_total_TC += noise_dice_res[1]
549 | noise_dice_total_ET += noise_dice_res[2]
550 | noise_hd_total_WT += noise_hd_res[0][0]
551 | noise_hd_total_TC += noise_hd_res[1][0]
552 | noise_hd_total_ET += noise_hd_res[2][0]
553 | noise_assd_total_WT += noised_assd_res[0]
554 | noise_assd_total_TC += noised_assd_res[1]
555 | noise_assd_total_ET += noised_assd_res[2]
556 | mean_uncertainty_total += mean_uncertainty
557 | noise_mean_uncertainty_total += noised_mean_uncertainty
558 | name = str(i)
559 | if names:
560 | name = names[i]
561 | msg += '{:>20}, '.format(name)
562 |
563 | print(msg)
564 |         snapshot = False  # set to True to dump per-slice segmentation and uncertainty PNGs
565 | if snapshot:
566 | """ --- grey figure---"""
567 | # Snapshot_img = np.zeros(shape=(H,W,T),dtype=np.uint8)
568 | # Snapshot_img[np.where(output[1,:,:,:]==1)] = 64
569 | # Snapshot_img[np.where(output[2,:,:,:]==1)] = 160
570 | # Snapshot_img[np.where(output[3,:,:,:]==1)] = 255
571 | """ --- colorful figure--- """
572 | Snapshot_img = np.zeros(shape=(H, W, 3, T), dtype=np.float32)
573 | # K = [np.where(output[0,:,:,:] == 1)]
574 | Snapshot_img[:, :, 0, :][np.where(output == 1)] = 255
575 | Snapshot_img[:, :, 1, :][np.where(output == 2)] = 255
576 | Snapshot_img[:, :, 2, :][np.where(output == 3)] = 255
577 |
578 | Noise_Snapshot_img = np.zeros(shape=(H, W, 3, T), dtype=np.float32)
579 | # K = [np.where(output[0,:,:,:] == 1)]
580 | Noise_Snapshot_img[:, :, 0, :][np.where(noised_output == 1)] = 255
581 | Noise_Snapshot_img[:, :, 1, :][np.where(noised_output == 2)] = 255
582 | Noise_Snapshot_img[:, :, 2, :][np.where(noised_output == 3)] = 255
583 | # target_img = np.zeros(shape=(H, W, 3, T), dtype=np.float32)
584 | # K = [np.where(output[0,:,:,:] == 1)]
585 | # target_img[:, :, 0, :][np.where(Otarget == 1)] = 255
586 | # target_img[:, :, 1, :][np.where(Otarget == 2)] = 255
587 | # target_img[:, :, 2, :][np.where(Otarget == 3)] = 255
588 |
589 | for frame in range(T):
590 | if not os.path.exists(os.path.join(savepath, str(Net_name), str(Variance), name)):
591 | os.makedirs(os.path.join(savepath, str(Net_name), str(Variance), name))
592 |
593 | # scipy.misc.imsave(os.path.join(visual, name, str(frame)+'.png'), Snapshot_img[:, :, :, frame])
594 | imageio.imwrite(os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '.png'),
595 | Snapshot_img[:, :, :, frame])
596 | imageio.imwrite(os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '_noised.png'),
597 | Noise_Snapshot_img[:, :, :, frame])
598 | # imageio.imwrite(os.path.join(savepath, str(Net_name), name, str(frame) + '_gt.png'),
599 | # target_img[:, :, :, frame])
600 | # im0 = Image.fromarray(U_output[:, :, frame])
601 | # im1 = Image.fromarray(U_output[:, :, frame])
602 | # im2 = Image.fromarray(U_output[:, :, frame])
603 | # im0 = im0.convert('RGB')
604 | # im1 = im1.convert('RGB')
605 | # im2 = im2.convert('RGB')
606 | # im0.save(os.path.join(savepath, name, str(frame) + '_uncertainty.png'))
607 | # im1.save(os.path.join(savepath, name, str(frame) + '_input_T1.png'))
608 | # im2.save(os.path.join(savepath, name, str(frame) + '_input_T2.png'))
609 | # U_CV = cv2.cvtColor(U_output[:, :, frame], cv2.COLOR_GRAY2BGR)
610 | # U_heatmap = cv2.applyColorMap(U_CV, cv2.COLORMAP_JET)
611 | # cv2.imwrite(os.path.join(savepath, name, str(frame) + '_uncertainty.png'),
612 | # U_heatmap)
613 | # NU_CV = cv2.cvtColor(NU_output[:, :, frame], cv2.COLOR_GRAY2BGR)
614 | # NU_heatmap = cv2.applyColorMap(NU_CV, cv2.COLORMAP_JET)
615 | # cv2.imwrite(os.path.join(savepath, name, str(frame) + '_noised_uncertainty.png'),
616 | # NU_heatmap)
617 | imageio.imwrite(os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '_uncertainty.png'),
618 | U_output[:, :, frame])
619 | imageio.imwrite(
620 | os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '_noised_uncertainty.png'),
621 | NU_output[:, :, frame])
622 | U_img = cv2.imread(os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '_uncertainty.png'))
623 | U_heatmap = cv2.applyColorMap(U_img, cv2.COLORMAP_JET)
624 | cv2.imwrite(
625 | os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '_colormap_uncertainty.png'),
626 | U_heatmap)
627 | NU_img = cv2.imread(
628 | os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '_noised_uncertainty.png'))
629 | NU_heatmap = cv2.applyColorMap(NU_img, cv2.COLORMAP_JET)
630 | cv2.imwrite(
631 | os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '_colormap_noised_uncertainty.png'),
632 | NU_heatmap)
633 | aver_ueo = ueo_total / num
634 | aver_noise_ueo = noise_ueo_total/num
635 | aver_certainty = certainty_total / num
636 | aver_noise_certainty = noise_certainty_total / num
637 | aver_dice_WT = dice_total_WT / num
638 | aver_dice_TC = dice_total_TC / num
639 | aver_dice_ET = dice_total_ET / num
640 | aver_hd_WT = hd_total_WT / num
641 | aver_hd_TC = hd_total_TC / num
642 | aver_hd_ET = hd_total_ET / num
643 | aver_noise_dice_WT = noise_dice_total_WT / num
644 | aver_noise_dice_TC = noise_dice_total_TC / num
645 | aver_noise_dice_ET = noise_dice_total_ET / num
646 | aver_noise_hd_WT = noise_hd_total_WT / num
647 | aver_noise_hd_TC = noise_hd_total_TC / num
648 | aver_noise_hd_ET = noise_hd_total_ET / num
649 | aver_assd_WT = assd_total_WT / num
650 | aver_assd_TC = assd_total_TC / num
651 | aver_assd_ET = assd_total_ET / num
652 | aver_noise_assd_WT = noise_assd_total_WT / num
653 | aver_noise_assd_TC = noise_assd_total_TC / num
654 | aver_noise_assd_ET = noise_assd_total_ET / num
655 | aver_noise_mean_uncertainty = noise_mean_uncertainty_total/num
656 | aver_mean_uncertainty = mean_uncertainty_total/num
657 | print('aver_dice_WT=%f,aver_dice_TC = %f,aver_dice_ET = %f' % (aver_dice_WT*100,aver_dice_TC*100, aver_dice_ET*100))
658 | print('aver_noise_dice_WT=%f,aver_noise_dice_TC = %f,aver_noise_dice_ET = %f' % (aver_noise_dice_WT*100, aver_noise_dice_TC*100, aver_noise_dice_ET*100))
659 | print('aver_hd_WT=%f,aver_hd_TC = %f,aver_hd_ET = %f' % (aver_hd_WT,aver_hd_TC, aver_hd_ET))
660 | print('aver_noise_hd_WT=%f,aver_noise_hd_TC = %f,aver_noise_hd_ET = %f' % (aver_noise_hd_WT, aver_noise_hd_TC, aver_noise_hd_ET))
661 | print('aver_assd_WT=%f,aver_assd_TC = %f,aver_assd_ET = %f' % (aver_assd_WT, aver_assd_TC, aver_assd_ET))
662 | print('aver_noise_assd_WT=%f,aver_noise_assd_TC = %f,aver_noise_assd_ET = %f' % (
663 | aver_noise_assd_WT, aver_noise_assd_TC, aver_noise_assd_ET))
665 | print('aver_noise_mean_uncertainty=%f,aver_mean_uncertainty = %f' % (aver_noise_mean_uncertainty,aver_mean_uncertainty))
666 | print('aver_certainty=%f,aver_noise_certainty = %f' % (aver_certainty, aver_noise_certainty))
667 | # print('aver_mne=%f,aver_noise_mne = %f' % (aver_mne, aver_noise_mne))
668 | # print('aver_ece=%f,aver_noise_ece = %f' % (aver_ece, aver_noise_ece))
669 | print('aver_ueo=%f,aver_noise_ueo = %f' % (aver_ueo, aver_noise_ueo))
670 | logging.info(
671 | 'aver_dice_WT=%f,aver_dice_TC = %f,aver_dice_ET = %f' % (aver_dice_WT*100, aver_dice_TC*100, aver_dice_ET*100))
672 | logging.info('aver_noise_dice_WT=%f,aver_noise_dice_TC = %f,aver_noise_dice_ET = %f' % (
673 | aver_noise_dice_WT*100, aver_noise_dice_TC*100, aver_noise_dice_ET*100))
674 | logging.info('aver_hd_WT=%f,aver_hd_TC = %f,aver_hd_ET = %f' % (
675 | aver_hd_WT, aver_hd_TC, aver_hd_ET))
676 | logging.info('aver_noise_hd_WT=%f,aver_noise_hd_TC = %f,aver_noise_hd_ET = %f' % (
677 | aver_noise_hd_WT, aver_noise_hd_TC, aver_noise_hd_ET))
678 | logging.info('aver_assd_WT=%f,aver_assd_TC = %f,aver_assd_ET = %f' % (aver_assd_WT, aver_assd_TC, aver_assd_ET))
679 | logging.info('aver_noise_assd_WT=%f,aver_noise_assd_TC = %f,aver_noise_assd_ET = %f' % (
680 | aver_noise_assd_WT, aver_noise_assd_TC, aver_noise_assd_ET))
681 | logging.info('aver_noise_mean_uncertainty=%f,aver_mean_uncertainty = %f' % (
682 | aver_noise_mean_uncertainty, aver_mean_uncertainty))
683 | logging.info('aver_ueo=%f,aver_noise_ueo = %f' % (aver_ueo, aver_noise_ueo))
684 | # return [aver_dice_WT,aver_dice_TC,aver_dice_ET],[aver_noise_dice_WT,aver_noise_dice_TC,aver_noise_dice_ET]
685 | return [aver_dice_WT,aver_dice_TC,aver_dice_ET],[aver_noise_dice_WT,aver_noise_dice_TC,aver_noise_dice_ET],[aver_hd_WT,aver_hd_TC,aver_hd_ET],[aver_noise_hd_WT,aver_noise_hd_TC,aver_noise_hd_ET]
686 |
687 | def testensemblemax(
688 | test_loader,
689 | model,
690 | multimodel,
691 | Net_name,
692 | Variance,
693 | load_file,
694 |         savepath='',  # for the validation set, specify the path to save the 'nii' segmentation results here
695 |         names=None,  # the patient names, in order
696 | verbose=False,
697 | use_TTA=False, # Test time augmentation, False as default!
698 | save_format=None, # ['nii','npy'], use 'nii' as default. Its purpose is for submission.
699 | # snapshot=False, # for visualization. Default false. It is recommended to generate the visualized figures.
700 | # visual='', # the path to save visualization
701 | ):
702 |
703 | H, W, T = 240, 240, 160
704 | # model.eval()
705 |
706 | runtimes = []
707 | dice_total_WT = 0
708 | dice_total_TC = 0
709 | dice_total_ET = 0
710 | hd_total_WT = 0
711 | hd_total_TC = 0
712 | hd_total_ET = 0
713 | assd_total_WT = 0
714 | assd_total_TC = 0
715 | assd_total_ET = 0
716 | noise_dice_total_WT = 0
717 | noise_dice_total_TC = 0
718 | noise_dice_total_ET = 0
719 | noise_hd_total_WT = 0
720 | noise_hd_total_TC = 0
721 | noise_hd_total_ET = 0
722 | noise_assd_total_WT = 0
723 | noise_assd_total_TC = 0
724 | noise_assd_total_ET = 0
725 | mean_uncertainty_total = 0
726 | noise_mean_uncertainty_total = 0
727 | num = len(test_loader)
728 |
729 | for i, data in enumerate(test_loader):
730 | print('-------------------------------------------------------------------')
731 | msg = 'Subject {}/{}, '.format(i + 1, len(test_loader))
732 | x, target = data
733 | noise_m = torch.randn_like(x) * Variance
734 | # noise = torch.clamp(torch.randn_like(x) * Variance, -Variance * 2, Variance * 2)
735 | # noise = torch.clamp(torch.randn_like(x) * Variance, -Variance, Variance)
736 | # noise = torch.clamp(torch.randn_like(x) * Variance)
737 | noised_x = x + noise_m
738 |         x = x.cuda()  # reassignment required; .cuda() is not in-place
739 |         noised_x = noised_x.cuda()
740 | target = torch.squeeze(target).cpu().numpy()
741 |
742 | if not use_TTA:
743 | torch.cuda.synchronize() # add the code synchronize() to correctly count the runtime.
744 | start_time = time.time()
745 | logit = torch.zeros(x.shape[0], 4, x.shape[2], x.shape[3], 155)
746 | logit_noise = torch.zeros(x.shape[0], 4, x.shape[2], x.shape[3], 155)
747 | # load ensemble models
748 | for j in range(10):
749 |             print('ensemble model:{}'.format(j))  # j indexes the ensemble member, not the subject
750 | logit += tailor_and_concat(x, model[j])
751 | logit_noise += tailor_and_concat(noised_x, model[j])
752 | # calculate ensemble uncertainty by normalized entropy
753 | uncertainty = Uentropy(logit/10, 4)
754 | noised_uncertainty = Uentropy(logit_noise/10, 4)
755 |
756 | U_output = torch.squeeze(uncertainty).cpu().detach().numpy()
757 | noised_U_output = torch.squeeze(noised_uncertainty).cpu().detach().numpy()
758 | joblib.dump({'logit': logit, 'logit_noise': logit_noise, 'uncertainty': uncertainty,
759 | 'noised_uncertainty': noised_uncertainty, 'U_output': U_output,
760 | 'noised_U_output': noised_U_output}, 'Uensemble_uncertainty_{}.pkl'.format(i))
761 | # lnear = F.softplus(logit)
762 | torch.cuda.synchronize()
763 | elapsed_time = time.time() - start_time
764 | logging.info('Single sample test time consumption {:.2f} minutes!'.format(elapsed_time/60))
765 | runtimes.append(elapsed_time)
766 | output = F.softmax(logit/10, dim=1)
767 | noised_output = F.softmax(logit_noise/10, dim=1)
768 |
769 |
770 | else:
771 | x = x[..., :155]
772 | logit = F.softmax(tailor_and_concat(x, model), 1) # no flip
773 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2,)), model).flip(dims=(2,)), 1) # flip H
774 | logit += F.softmax(tailor_and_concat(x.flip(dims=(3,)), model).flip(dims=(3,)), 1) # flip W
775 | logit += F.softmax(tailor_and_concat(x.flip(dims=(4,)), model).flip(dims=(4,)), 1) # flip D
776 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 3)), model).flip(dims=(2, 3)), 1) # flip H, W
777 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 4)), model).flip(dims=(2, 4)), 1) # flip H, D
778 | logit += F.softmax(tailor_and_concat(x.flip(dims=(3, 4)), model).flip(dims=(3, 4)), 1) # flip W, D
779 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 3, 4)), model).flip(dims=(2, 3, 4)), 1) # flip H, W, D
780 | # for noise x
781 | noised_x = noised_x[..., :155]
782 | noised_logit = F.softmax(tailor_and_concat(noised_x, model), 1) # no flip
783 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2,)), model).flip(dims=(2,)), 1) # flip H
784 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(3,)), model).flip(dims=(3,)), 1) # flip W
785 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(4,)), model).flip(dims=(4,)), 1) # flip D
786 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2, 3)), model).flip(dims=(2, 3)), 1) # flip H, W
787 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2, 4)), model).flip(dims=(2, 4)), 1) # flip H, D
788 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(3, 4)), model).flip(dims=(3, 4)), 1) # flip W, D
789 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2, 3, 4)), model).flip(dims=(2, 3, 4)), 1) # flip H, W, D
790 | output = logit / 8.0 # mean
791 | noised_output = noised_logit / 8.0 # mean
792 |             uncertainty = Uentropy(output, 4)  # use the TTA-averaged output, matching the noised branch
793 | noised_uncertainty = Uentropy(noised_output, 4)
794 | mean_uncertainty = torch.mean(uncertainty)
795 | noised_mean_uncertainty = torch.mean(noised_uncertainty)
796 | output = output[0, :, :H, :W, :T].cpu().detach().numpy()
797 | output = output.argmax(0)
798 | # iou_res = softmax_mIOU_score(output, target[:, :, :155])
799 | hd_res = softmax_output_hd(output, target[:, :, :155])
800 | dice_res = softmax_output_dice(output, target[:, :, :155])
801 | assd_res = softmax_output_assd(output, target[:, :, :155])
802 | dice_total_WT += dice_res[0]
803 | dice_total_TC += dice_res[1]
804 | dice_total_ET += dice_res[2]
805 | hd_total_WT += hd_res[0][0]
806 | hd_total_TC += hd_res[1][0]
807 | hd_total_ET += hd_res[2][0]
808 | assd_total_WT += assd_res[0]
809 | assd_total_TC += assd_res[1]
810 | assd_total_ET += assd_res[2]
811 | # for noise_x
812 | noised_output = noised_output[0, :, :H, :W, :T].cpu().detach().numpy()
813 | noised_output = noised_output.argmax(0)
814 | noise_assd_res = softmax_output_assd(noised_output, target[:, :, :155])
815 | noise_hd_res = softmax_output_hd(noised_output, target[:, :, :155])
816 | noise_dice_res = softmax_output_dice(noised_output, target[:, :, :155])
817 |
818 | noise_dice_total_WT += noise_dice_res[0]
819 | noise_dice_total_TC += noise_dice_res[1]
820 | noise_dice_total_ET += noise_dice_res[2]
821 | noise_hd_total_WT += noise_hd_res[0][0]
822 | noise_hd_total_TC += noise_hd_res[1][0]
823 | noise_hd_total_ET += noise_hd_res[2][0]
824 | noise_assd_total_WT += noise_assd_res[0]
825 | noise_assd_total_TC += noise_assd_res[1]
826 | noise_assd_total_ET += noise_assd_res[2]
827 | mean_uncertainty_total += mean_uncertainty
828 | noise_mean_uncertainty_total += noised_mean_uncertainty
829 | name = str(i)
830 | if names:
831 | name = names[i]
832 | msg += '{:>20}, '.format(name)
833 |
834 | print(msg)
835 | print('current_dice:{} ; current_noised_dice:{}'.format(dice_res, noise_dice_res))
836 | print('current_uncertainty:{} ; current_noised_uncertainty:{}'.format(uncertainty, noised_uncertainty))
837 | # if savepath:
838 | # # .npy for further model ensemble
839 | # # .nii for directly model submission
840 | # assert save_format in ['npy', 'nii']
841 | # if save_format == 'npy':
842 | # np.save(os.path.join(savepath, Net_name +'_'+ name + '_preds'), output)
843 | # if save_format == 'nii':
844 | # # raise NotImplementedError
845 | # oname = os.path.join(savepath, Net_name + '_'+ name + '.nii.gz')
846 | # seg_img = np.zeros(shape=(H, W, T), dtype=np.uint8)
847 | #
848 | # seg_img[np.where(output == 1)] = 1
849 | # seg_img[np.where(output == 2)] = 2
850 | # seg_img[np.where(output == 3)] = 3
851 | # if verbose:
852 | # print('1:', np.sum(seg_img == 1), ' | 2:', np.sum(seg_img == 2), ' | 4:', np.sum(seg_img == 4))
853 | # print('WT:', np.sum((seg_img == 1) | (seg_img == 2) | (seg_img == 4)), ' | TC:',
854 | # np.sum((seg_img == 1) | (seg_img == 4)), ' | ET:', np.sum(seg_img == 4))
855 | # nib.save(nib.Nifti1Image(seg_img, None), oname)
856 | # print('Successfully save {}'.format(oname))
857 |
858 | # if snapshot:
859 | # """ --- grey figure---"""
860 | # # Snapshot_img = np.zeros(shape=(H,W,T),dtype=np.uint8)
861 | # # Snapshot_img[np.where(output[1,:,:,:]==1)] = 64
862 | # # Snapshot_img[np.where(output[2,:,:,:]==1)] = 160
863 | # # Snapshot_img[np.where(output[3,:,:,:]==1)] = 255
864 | # """ --- colorful figure--- """
865 | # Snapshot_img = np.zeros(shape=(H, W, 3, T), dtype=np.uint8)
866 | # Snapshot_img[:, :, 0, :][np.where(output == 1)] = 255
867 | # Snapshot_img[:, :, 1, :][np.where(output == 2)] = 255
868 | # Snapshot_img[:, :, 2, :][np.where(output == 3)] = 255
869 | #
870 | # for frame in range(T):
871 | # if not os.path.exists(os.path.join(visual, name)):
872 | # os.makedirs(os.path.join(visual, name))
873 | # # scipy.misc.imsave(os.path.join(visual, name, str(frame)+'.png'), Snapshot_img[:, :, :, frame])
874 | # imageio.imwrite(os.path.join(visual, name, str(frame)+'.png'), Snapshot_img[:, :, :, frame])
875 |
876 | aver_dice_WT = dice_total_WT / num
877 | aver_dice_TC = dice_total_TC / num
878 | aver_dice_ET = dice_total_ET / num
879 | aver_hd_WT = hd_total_WT / num
880 | aver_hd_TC = hd_total_TC / num
881 | aver_hd_ET = hd_total_ET / num
882 | aver_assd_WT = assd_total_WT / num
883 | aver_assd_TC = assd_total_TC / num
884 | aver_assd_ET = assd_total_ET / num
885 | aver_noise_dice_WT = noise_dice_total_WT / num
886 | aver_noise_dice_TC = noise_dice_total_TC / num
887 | aver_noise_dice_ET = noise_dice_total_ET / num
888 | aver_noise_hd_WT = noise_hd_total_WT / num
889 | aver_noise_hd_TC = noise_hd_total_TC / num
890 | aver_noise_hd_ET = noise_hd_total_ET / num
891 | aver_noise_assd_WT = noise_assd_total_WT / num
892 | aver_noise_assd_TC = noise_assd_total_TC / num
893 | aver_noise_assd_ET = noise_assd_total_ET / num
894 | aver_noise_mean_uncertainty = noise_mean_uncertainty_total/num
895 | aver_mean_uncertainty = mean_uncertainty_total/num
896 | print('aver_dice_WT=%f,aver_dice_TC = %f,aver_dice_ET = %f' % (aver_dice_WT*100,aver_dice_TC*100, aver_dice_ET*100))
897 | print('aver_noise_dice_WT=%f,aver_noise_dice_TC = %f,aver_noise_dice_ET = %f' % (aver_noise_dice_WT*100, aver_noise_dice_TC*100, aver_noise_dice_ET*100))
898 | print('aver_hd_WT=%f,aver_hd_TC = %f,aver_hd_ET = %f' % (aver_hd_WT,aver_hd_TC, aver_hd_ET))
899 | print('aver_noise_hd_WT=%f,aver_noise_hd_TC = %f,aver_noise_hd_ET = %f' % (aver_noise_hd_WT, aver_noise_hd_TC, aver_noise_hd_ET))
900 | print('aver_assd_WT=%f,aver_assd_TC = %f,aver_assd_ET = %f' % (aver_assd_WT,aver_assd_TC, aver_assd_ET))
901 | print('aver_noise_assd_WT=%f,aver_noise_assd_TC = %f,aver_noise_assd_ET = %f' % (aver_noise_assd_WT, aver_noise_assd_TC, aver_noise_assd_ET))
902 | logging.info(
903 | 'aver_dice_WT=%f,aver_dice_TC = %f,aver_dice_ET = %f' % (aver_dice_WT*100, aver_dice_TC*100, aver_dice_ET*100))
904 | logging.info('aver_noise_dice_WT=%f,aver_noise_dice_TC = %f,aver_noise_dice_ET = %f' % (
905 | aver_noise_dice_WT*100, aver_noise_dice_TC*100, aver_noise_dice_ET*100))
906 | logging.info('aver_hd_WT=%f,aver_hd_TC = %f,aver_hd_ET = %f' % (
907 | aver_hd_WT, aver_hd_TC, aver_hd_ET))
908 | logging.info('aver_noise_hd_WT=%f,aver_noise_hd_TC = %f,aver_noise_hd_ET = %f' % (
909 | aver_noise_hd_WT, aver_noise_hd_TC, aver_noise_hd_ET))
910 | logging.info('aver_assd_WT=%f,aver_assd_TC = %f,aver_assd_ET = %f' % (aver_assd_WT, aver_assd_TC, aver_assd_ET))
911 | logging.info('aver_noise_assd_WT=%f,aver_noise_assd_TC = %f,aver_noise_assd_ET = %f' % (
912 | aver_noise_assd_WT, aver_noise_assd_TC, aver_noise_assd_ET))
913 | logging.info('aver_noise_mean_uncertainty=%f,aver_mean_uncertainty = %f' % (
914 | aver_noise_mean_uncertainty, aver_mean_uncertainty))
915 | # return [aver_dice_WT,aver_dice_TC,aver_dice_ET],[aver_noise_dice_WT,aver_noise_dice_TC,aver_noise_dice_ET]
916 | return [aver_dice_WT,aver_dice_TC,aver_dice_ET],[aver_noise_dice_WT,aver_noise_dice_TC,aver_noise_dice_ET],[aver_hd_WT,aver_hd_TC,aver_hd_ET],[aver_noise_hd_WT,aver_noise_hd_TC,aver_noise_hd_ET]
917 | # return [aver_dice_WT,aver_dice_TC,aver_dice_ET],[aver_noise_dice_WT,aver_noise_dice_TC,aver_noise_dice_ET],[aver_iou_WT,aver_iou_TC,aver_iou_ET],[aver_noise_iou_WT,aver_noise_iou_TC,aver_noise_iou_ET]
--------------------------------------------------------------------------------
/results/nullfile:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/test_uncertainty.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | import torch
4 | import numpy as np
5 | import math
6 | import abc
7 | import joblib
9 | import torch.nn.functional as F
10 | # import torchfunctions as t_fn
11 | import numpyfunctions as np_fn
12 |
13 | def Uentropy(logits,c):
14 | # c = 4
15 | # logits = torch.randn(1, 4, 240, 240,155).cuda()
16 | pc = F.softmax(logits, dim=1) # 1 4 240 240 155
17 | logpc = F.log_softmax(logits, dim=1) # 1 4 240 240 155
18 | # u_all1 = -pc * logpc / c
19 | u_all = -pc * logpc / math.log(c)
20 | # max_u = torch.max(u_all)
21 | # min_u = torch.min(u_all)
22 | NU = torch.sum(u_all, dim=1)
23 | return NU
24 |
25 | class UncertaintyAndCorrectionEvalNumpy():
26 | def __init__(self, uncertainty_threshold):
27 | super(UncertaintyAndCorrectionEvalNumpy, self).__init__()
28 | self.uncertainty_threshold = uncertainty_threshold
29 |
30 | def __call__(self, to_evaluate=None, results=None):
31 | # entropy should be normalized [0,1] before calling this evaluation
32 |         target = to_evaluate['target'].astype(bool)       # np.bool was removed in NumPy 1.24
33 |         prediction = to_evaluate['prediction'].astype(bool)
34 | uncertainty = to_evaluate['uncertainty']
35 |
36 | thresholded_uncertainty = uncertainty > self.uncertainty_threshold
37 | tp, tn, fp, fn, tpu, tnu, fpu, fnu = \
38 | np_fn.uncertainty(prediction, target, thresholded_uncertainty)
39 |
40 | results['tpu'] = tpu
41 | results['tnu'] = tnu
42 | results['fpu'] = fpu
43 | results['fnu'] = fnu
44 |
45 | results['tp'] = tp
46 | results['tn'] = tn
47 | results['fp'] = fp
48 | results['fn'] = fn
49 |
50 | tpu_fpu_ratio = results['tpu'] / results['fpu']
51 | jaccard_index = results['tp'] / (results['tp'] + results['fp'] + results['fn'])
52 | results['dice_benefit'] = tpu_fpu_ratio < jaccard_index
53 | results['accuracy_benefit'] = tpu_fpu_ratio < 1
54 |
55 | results['dice'] = np_fn.dice(prediction, target)
56 | results['accuracy'] = np_fn.accuracy(prediction, target)
57 |
58 | corrected_prediction = prediction.copy()
59 | # correct to background
60 | corrected_prediction[thresholded_uncertainty] = 0
61 |
62 | results['corrected_dice'] = np_fn.dice(corrected_prediction, target)
63 | results['corrected_accuracy'] = np_fn.accuracy(corrected_prediction, target)
64 |
65 | results['dice_benefit_correct'] = (results['corrected_dice'] > results['dice']) == results['dice_benefit']
66 | results['accuracy_benefit_correct'] = (results['corrected_accuracy'] > results['accuracy']) == results[
67 | 'accuracy_benefit']
68 |
69 | corrected_prediction = prediction.copy()
70 | # correct to foreground
71 | corrected_prediction[thresholded_uncertainty] = 1
72 |
73 | results['corrected_add_dice'] = np_fn.dice(corrected_prediction, target)
74 |         results['corrected_add_accuracy'] = np_fn.accuracy(corrected_prediction, target)
75 |         return results
76 | def binary_calibration(probabilities, target, n_bins=10, threshold_range=None, mask=None):
77 | if probabilities.ndim > target.ndim:
78 | if probabilities.shape[-1] > 2:
79 | raise ValueError('can only evaluate the calibration for binary classification')
80 | elif probabilities.shape[-1] == 2:
81 | probabilities = probabilities[..., 1]
82 | else:
83 | probabilities = np.squeeze(probabilities, axis=-1)
84 |
85 | if mask is not None:
86 | probabilities = probabilities[mask]
87 | target = target[mask]
88 |
89 | if threshold_range is not None:
90 | low_thres, up_thres = threshold_range
91 | mask = np.logical_and(probabilities < up_thres, probabilities > low_thres)
92 | probabilities = probabilities[mask]
93 | target = target[mask]
94 |
95 | pos_frac, mean_confidence, bin_count, non_zero_bins = \
96 | _binary_calibration(target.flatten(), probabilities.flatten(), n_bins)
97 |
98 | return pos_frac, mean_confidence, bin_count, non_zero_bins
99 |
100 | def _binary_calibration(target, probs_positive_cls, n_bins=10):
101 | # same as sklearn.calibration calibration_curve but with the bin_count returned
102 | bins = np.linspace(0., 1. + 1e-8, n_bins + 1)
103 | binids = np.digitize(probs_positive_cls, bins) - 1
104 |
105 | # # note: this is the original formulation which has always n_bins + 1 as length
106 | # bin_sums = np.bincount(binids, weights=probs_positive_cls, minlength=len(bins))
107 | # bin_true = np.bincount(binids, weights=target, minlength=len(bins))
108 | # bin_total = np.bincount(binids, minlength=len(bins))
109 |
110 | bin_sums = np.bincount(binids, weights=probs_positive_cls, minlength=n_bins)
111 | bin_true = np.bincount(binids, weights=target, minlength=n_bins)
112 | bin_total = np.bincount(binids, minlength=n_bins)
113 |
114 | nonzero = bin_total != 0
115 | prob_true = (bin_true[nonzero] / bin_total[nonzero])
116 | prob_pred = (bin_sums[nonzero] / bin_total[nonzero])
117 |
118 | return prob_true, prob_pred, bin_total[nonzero], nonzero
119 |
120 | def _get_proportion(bin_weighting, bin_count, non_zero_bins, n_dim):
121 | if bin_weighting == 'proportion':
122 | bin_proportions = bin_count / bin_count.sum()
123 | elif bin_weighting == 'log_proportion':
124 | bin_proportions = np.log(bin_count) / np.log(bin_count).sum()
125 | elif bin_weighting == 'power_proportion':
126 | bin_proportions = bin_count**(1/n_dim) / (bin_count**(1/n_dim)).sum()
127 | elif bin_weighting == 'mean_proportion':
128 | bin_proportions = 1 / non_zero_bins.sum()
129 | else:
130 | raise ValueError('unknown bin weighting "{}"'.format(bin_weighting))
131 | return bin_proportions
132 |
133 | def ece_binary(probabilities, target, n_bins=10, threshold_range=None, mask=None, out_bins=None,
134 |                bin_weighting='proportion'):
135 | # input: 1. probabilities (np) 2. target (np) 3. threshold_range (tuple[low,high]) 4. mask
136 |
137 | n_dim = target.ndim
138 |
139 | pos_frac, mean_confidence, bin_count, non_zero_bins = \
140 | binary_calibration(probabilities, target, n_bins, threshold_range, mask)
141 |
142 | bin_proportions = _get_proportion(bin_weighting, bin_count, non_zero_bins, n_dim)
143 |
144 | if out_bins is not None:
145 | out_bins['bins_count'] = bin_count
146 | out_bins['bins_avg_confidence'] = mean_confidence
147 | out_bins['bins_positive_fraction'] = pos_frac
148 | out_bins['bins_non_zero'] = non_zero_bins
149 |
150 | ece = (np.abs(mean_confidence - pos_frac) * bin_proportions).sum()
151 | return ece
152 |
153 | def pkload(fname):
154 | with open(fname, 'rb') as f:
155 | return joblib.load(f)
156 | # return pickle.load(f)
157 |
158 | def cal_u(uncertainty):
159 | mean_uncertainty_total = 0
160 | sum_uncertainty_total = 0
161 | for i in range(len(uncertainty)):
162 | u = uncertainty[i][0]
163 | total_uncertainty = torch.sum(u, -1, keepdim=True)
164 | # print('current_sum_certainty:{} ; current_mean_certainty:{}'.format(torch.mean(total_uncertainty),
165 | # torch.mean(u)))
166 | sum_uncertainty_total += torch.mean(total_uncertainty)
167 | mean_uncertainty_total += torch.mean(u)
168 | num = len(uncertainty)
169 | return sum_uncertainty_total/num,mean_uncertainty_total/num
170 |
171 | def cal_ece(logits,targets):
172 | ece_total = 0
173 | for i in range(len(logits)):
174 | # pc = torch.zeros(1,1,128,128,128)
175 | logit = logits[i][0]
176 | target = targets[i][..., :155]
177 | pred = F.softmax(logit, dim=0)
178 | pc = pred.cpu().detach().numpy()
179 | pc = pc.argmax(0)
180 | ece_total += ece_binary(pc, target)
181 | num = len(logits)
182 | return ece_total/num
183 |
184 | def cal_ueo(to_evaluate,thresholds):
185 | UEO = []
186 | for threshold in thresholds:
187 | metric = UncertaintyAndCorrectionEvalNumpy(threshold)
188 |         UEO.append(metric(to_evaluate))  # each entry is a dict of metrics for this threshold
189 | return UEO
--------------------------------------------------------------------------------
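For orientation, here is a minimal, self-contained sketch of how the helpers in test_uncertainty.py compose. It runs on small synthetic tensors rather than BraTS volumes; the shapes, thresholds, and the crude foreground-mask construction are illustrative assumptions, and np_fn above refers to the repo's numpyfunctions module:

import numpy as np
import torch

from test_uncertainty import Uentropy, cal_ueo, ece_binary

# synthetic 4-class logits for a single 8x8x8 volume (batch of 1)
logits = torch.randn(1, 4, 8, 8, 8)
nu = Uentropy(logits, c=4)  # normalized voxel-wise entropy in [0, 1], shape (1, 8, 8, 8)

prediction = (logits.argmax(dim=1) > 0).squeeze(0).numpy()  # crude binary foreground mask
target = (torch.rand(8, 8, 8) > 0.5).numpy()
uncertainty = nu.squeeze(0).numpy()

# sweep uncertainty thresholds; each entry is a dict of uncertainty/correction metrics
to_evaluate = {'target': target, 'prediction': prediction, 'uncertainty': uncertainty}
ueo = cal_ueo(to_evaluate, thresholds=[0.1, 0.3, 0.5])

# expected calibration error of the class-1 probability against the same mask
probs = torch.softmax(logits, dim=1)[0, 1].numpy()
print(ece_binary(probs, target.astype(np.float64)))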
/train.py:
--------------------------------------------------------------------------------
1 | # python3 -m torch.distributed.launch --nproc_per_node=4 --master_port 20003 train_spup3.py
2 |
3 | import argparse
4 | import os
5 | import random
6 | import logging
7 | import numpy as np
8 | import time
9 | import setproctitle
10 | import torch
11 | import torch.optim
12 | import joblib  # sklearn.externals.joblib was removed in modern scikit-learn
13 | # from models import criterions
14 | from models.lib.VNet3D import VNet
15 | from plot import loss_plot,metrics_plot
16 | from models.lib.UNet3DZoo import Unet,AttUnet,Unetdrop
17 | from models.criterions import softBCE_dice,softmax_dice,FocalLoss,DiceLoss
18 | from data.BraTS2019 import BraTS
19 | from torch.utils.data import DataLoader
20 | from tensorboardX import SummaryWriter
21 | from predict import validate_softmax,test_softmax,testensemblemax
22 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
23 |
24 | local_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
25 |
26 | parser = argparse.ArgumentParser()
27 |
28 | # Basic Information
29 | parser.add_argument('--user', default='name of user', type=str)
30 | parser.add_argument('--experiment', default='TransBTS', type=str)
31 | parser.add_argument('--date', default=local_time.split(' ')[0], type=str)
32 |
33 | parser.add_argument('--description',
34 | default='TransBTS,'
35 | 'training on train.txt!',
36 | type=str)
37 |
38 | # DataSet Information
39 | parser.add_argument('--root', default='E:/BraTSdata1/archive2019', type=str) # folder_data_path
40 | parser.add_argument('--train_dir', default='MICCAI_BraTS_2019_Data_TTraining', type=str)
41 | parser.add_argument('--valid_dir', default='MICCAI_BraTS_2019_Data_TValidation', type=str)
42 | parser.add_argument('--test_dir', default='MICCAI_BraTS_2019_Data_TTest', type=str)
43 | parser.add_argument("--mode", default="train", type=str, help="train/test/train&test")
44 | # parser.add_argument('--train_file',
45 | # default='C:/Coco_file/BraTSdata/archive2019/MICCAI_BraTS_2019_Data_Training/Ttrain_subject.txt', type=str)
46 | # parser.add_argument('--valid_file', default='C:/Coco_file/BraTSdata/archive2019/MICCAI_BraTS_2019_Data_Training/Tval_subject.txt',
47 | # type=str)
48 | # parser.add_argument('--test_file', default='C:/Coco_file/BraTSdata/archive2019/MICCAI_BraTS_2019_Data_Training/Ttest_subject.txt',
49 | # type=str)
50 | parser.add_argument('--train_file',
51 | default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Ttrain_subject.txt',
52 | type=str)
53 | parser.add_argument('--valid_file',
54 | default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Tval_subject.txt',
55 | type=str)
56 | parser.add_argument('--test_file',
57 | default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Ttest_subject.txt',
58 | type=str)
59 | parser.add_argument('--dataset', default='BraTS', type=str)
60 | parser.add_argument('--input_C', default=4, type=int)
61 | parser.add_argument('--input_H', default=240, type=int)
62 | parser.add_argument('--input_W', default=240, type=int)
63 | parser.add_argument('--input_D', default=160, type=int) # 155
64 | parser.add_argument('--crop_H', default=128, type=int)
65 | parser.add_argument('--crop_W', default=128, type=int)
66 | parser.add_argument('--crop_D', default=128, type=int)
67 | parser.add_argument('--output_D', default=155, type=int)
68 | parser.add_argument('--rlt', default=-1, type=float,
69 |                     help='weight between the CE/Focal term and the Dice term')
70 | # Training Information
71 | parser.add_argument('--lr', default=0.002, type=float)
72 | parser.add_argument('--weight_decay', default=1e-5, type=float)
73 | parser.add_argument('--amsgrad', default=True, type=bool)
74 | # parser.add_argument('--criterion', default='softmaxBCE_dice', type=str)
75 | parser.add_argument('--submission', default='./results', type=str)
76 | parser.add_argument('--visual', default='visualization', type=str)
77 | parser.add_argument('--num_class', default=4, type=int)
78 | parser.add_argument('--seed', default=1000, type=int)
79 | parser.add_argument('--no_cuda', default=False, type=bool)
80 | parser.add_argument('--batch_size', default=2, type=int, help="2/4/8")
81 | parser.add_argument('--start_epoch', default=0, type=int)
82 | parser.add_argument('--end_epoch', default=200, type=int)
83 | parser.add_argument('--save_freq', default=5, type=int)
84 | parser.add_argument('--resume', default='', type=str)
85 | parser.add_argument('--load', default=True, type=bool)
86 | parser.add_argument('--modal', default='t2', type=str) # multi-modal
87 | parser.add_argument('--model_name', default='V', type=str, help="AU/V/U/Udrop")
88 | parser.add_argument('--Variance', default=2, type=int) # 1 2
89 | parser.add_argument('--use_TTA', default=False, type=bool, help="True/False")
90 | parser.add_argument('--save_format', default='nii', type=str)
91 | parser.add_argument('--test_date', default='2022-01-04', type=str)
92 | parser.add_argument('--test_epoch', default=184, type=int)
93 | args = parser.parse_args()
94 |
95 | def val(model,checkpoint_dir,epoch,best_dice):
96 | valid_list = os.path.join(args.root, args.valid_dir, args.valid_file)
97 | valid_root = os.path.join(args.root, args.valid_dir)
98 | valid_set = BraTS(valid_list, valid_root,'valid',args.modal)
99 | valid_loader = DataLoader(valid_set, batch_size=1)
100 | print('Samples for valid = {}'.format(len(valid_set)))
101 |
102 | start_time = time.time()
103 | model.eval()
104 | with torch.no_grad():
105 | best_dice,aver_dice,aver_iou = validate_softmax(save_dir = checkpoint_dir,
106 | best_dice = best_dice,
107 | current_epoch = epoch,
108 | save_freq = args.save_freq,
109 | end_epoch = args.end_epoch,
110 | valid_loader = valid_loader,
111 | model = model,
112 | multimodel = args.modal,
113 | Net_name = args.model_name,
114 | names = valid_set.names,
115 | )
116 | # dice_list.append(aver_dice)
117 | # iou_list.append(aver_iou)
118 | end_time = time.time()
119 | full_test_time = (end_time-start_time)/60
120 | average_time = full_test_time/len(valid_set)
121 | print('{:.2f} minutes!'.format(average_time))
122 | return best_dice,aver_dice,aver_iou
123 |
124 | def test(model):
125 | for arg in vars(args):
126 | logging.info('{}={}'.format(arg, getattr(args, arg)))
127 |     logging.info('----------------------------------------This is a dividing line----------------------------------')
128 | logging.info('{}'.format(args.description))
129 | test_list = os.path.join(args.root, args.test_dir, args.test_file)
130 | test_root = os.path.join(args.root, args.test_dir)
131 | test_set = BraTS(test_list, test_root,'test',args.modal)
132 | test_loader = DataLoader(test_set, batch_size=1)
133 | print('Samples for test = {}'.format(len(test_set)))
134 |
135 | logging.info('final test........')
136 | load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)),
137 | 'checkpoint', args.experiment + args.test_date, args.model_name + '_' + args.modal + '_epoch_{}.pth'.format(args.test_epoch))
138 |
139 | if os.path.exists(load_file):
140 | checkpoint = torch.load(load_file)
141 | model.load_state_dict(checkpoint['state_dict'])
142 | args.start_epoch = checkpoint['epoch']
143 |         print('Successfully loaded checkpoint {}'.format(os.path.join(args.experiment + args.test_date, args.model_name + '_' + args.modal + '_epoch_{}.pth'.format(args.test_epoch))))
144 | else:
145 | print('There is no resume file to load!')
146 |
147 |
148 | start_time = time.time()
149 | model.eval()
150 | with torch.no_grad():
151 | aver_dice,aver_noise_dice,aver_hd,aver_noise_hd = test_softmax( test_loader = test_loader,
152 | model = model,
153 | multimodel = args.modal,
154 | Net_name=args.model_name,
155 | Variance = args.Variance,
156 | load_file=load_file,
157 | savepath = args.submission,
158 | names = test_set.names,
159 | use_TTA = args.use_TTA,
160 | save_format = args.save_format,
161 | )
162 | end_time = time.time()
163 | full_test_time = (end_time-start_time)/60
164 | average_time = full_test_time/len(test_set)
165 | print('{:.2f} minutes!'.format(average_time))
166 | logging.info('aver_dice_WT=%f,aver_dice_TC = %f,aver_dice_ET = %f' % (aver_dice[0],aver_dice[1],aver_dice[2]))
167 | logging.info('aver_noise_dice_WT=%f,aver_noise_dice_TC = %f,aver_noise_dice_ET = %f' % (aver_noise_dice[0], aver_noise_dice[1], aver_noise_dice[2]))
168 | logging.info('aver_hd_WT=%f,aver_hd_TC = %f,aver_hd_ET = %f' % (aver_hd[0],aver_hd[1],aver_hd[2]))
169 | logging.info('aver_noise_hd_WT=%f,aver_noise_hd_TC = %f,aver_noise_hd_ET = %f' % (aver_noise_hd[0], aver_noise_hd[1], aver_noise_hd[2]))
170 |
171 |
172 | def test_ensemble(model):
173 |
174 | test_list = os.path.join(args.root, args.test_dir, args.test_file)
175 | test_root = os.path.join(args.root, args.test_dir)
176 | test_set = BraTS(test_list, test_root,'test',args.modal)
177 | test_loader = DataLoader(test_set, batch_size=1)
178 | print('Samples for test = {}'.format(len(test_set)))
179 |
180 | logging.info('final test........')
181 | load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)),
182 | 'checkpoint', args.experiment + args.test_date, args.model_name + '_' + args.modal + '_epoch_{}.pth'.format(args.test_epoch))
183 |
184 | if os.path.exists(load_file):
185 | checkpoint = torch.load(load_file)
186 | model.load_state_dict(checkpoint['state_dict'])
187 | args.start_epoch = checkpoint['epoch']
188 |         print('Successfully loaded checkpoint {}'.format(os.path.join(args.experiment + args.test_date, args.model_name + '_' + args.modal + '_epoch_{}.pth'.format(args.test_epoch))))
189 | else:
190 | print('There is no resume file to load!')
191 |
192 |     # load ensemble models; assumes `model` is an indexable collection of 10 networks
193 |     load_model = []
194 |     for i in range(10):
195 |         save_name1 = args.model_name + '_' + args.modal + '_epoch_' + '199' + 'e' + str(i) + '.pth'
196 |         load_model.append(torch.load(save_name1))
197 |         model[i].load_state_dict(load_model[i]['state_dict'])
198 |
199 | start_time = time.time()
200 |     for m in model: m.eval()  # put every ensemble member in eval mode
201 | with torch.no_grad():
202 | aver_dice, aver_noise_dice, aver_hd, aver_noise_hd = testensemblemax(test_loader=test_loader,
203 | model=model,
204 | multimodel=args.modal,
205 | Net_name=args.model_name,
206 | Variance=args.Variance,
207 | load_file=load_file,
208 | savepath=args.submission,
209 | names=test_set.names,
210 | use_TTA=args.use_TTA,
211 | save_format=args.save_format,
212 | )
213 | end_time = time.time()
214 | full_test_time = (end_time - start_time) / 60
215 | average_time = full_test_time / len(test_set)
216 | print('{:.2f} minutes!'.format(average_time))
217 | logging.info('aver_dice_WT=%f,aver_dice_TC = %f,aver_dice_ET = %f' % (aver_dice[0], aver_dice[1], aver_dice[2]))
218 | logging.info('aver_noise_dice_WT=%f,aver_noise_dice_TC = %f,aver_noise_dice_ET = %f' % (
219 | aver_noise_dice[0], aver_noise_dice[1], aver_noise_dice[2]))
220 |     logging.info('aver_hd_WT=%f,aver_hd_TC = %f,aver_hd_ET = %f' % (aver_hd[0], aver_hd[1], aver_hd[2]))
221 |     logging.info('aver_noise_hd_WT=%f,aver_noise_hd_TC = %f,aver_noise_hd_ET = %f' % (
222 | aver_noise_hd[0], aver_noise_hd[1], aver_noise_hd[2]))
223 |
224 | def train(criterion,model,criterion_fl,criterion_dl):
225 | # dataset
226 | train_list = os.path.join(args.root, args.train_dir, args.train_file)
227 | train_root = os.path.join(args.root, args.train_dir)
228 | train_set = BraTS(train_list, train_root, args.mode,args.modal)
229 | train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size)
230 | print('Samples for train = {}'.format(len(train_set)))
231 |
232 |     logging.info('--------------------------------------These are all configurations----------------------------------')
233 | for arg in vars(args):
234 | logging.info('{}={}'.format(arg, getattr(args, arg)))
235 |     logging.info('----------------------------------------This is a dividing line----------------------------------')
236 | logging.info('{}'.format(args.description))
237 |
238 | torch.manual_seed(args.seed)
239 | torch.cuda.manual_seed(args.seed)
240 | random.seed(args.seed)
241 | np.random.seed(args.seed)
242 |
243 | model.cuda()
244 | model.train()
245 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, amsgrad=args.amsgrad)
246 | # criterion = getattr(criterions, args.criterion)
247 |
248 | checkpoint_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'checkpoint', args.experiment+args.date)
249 | if not os.path.exists(checkpoint_dir):
250 | os.makedirs(checkpoint_dir)
251 |
252 | resume = ''
253 |
254 | writer = SummaryWriter()
255 |
256 | if os.path.isfile(resume) and args.load:
257 | logging.info('loading checkpoint {}'.format(resume))
258 | checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
259 |
260 | model.load_state_dict(checkpoint['state_dict'])
261 |
262 |         logging.info('Successfully loaded checkpoint {}; training from epoch: {}'
263 |                      .format(args.resume, args.start_epoch))
264 | else:
265 |         logging.info('no checkpoint found, training from scratch!')
266 |
267 | start_time = time.time()
268 |
269 | torch.set_grad_enabled(True)
270 | loss_list = []
271 | dice_list = []
272 | iou_list = []
273 | best_dice =0
274 | for epoch in range(args.start_epoch, args.end_epoch):
275 | epoch_loss = 0
276 | loss = 0
277 | # loss1 = 0
278 | # loss2 = 0
279 | # loss3 = 0
280 | setproctitle.setproctitle('{}: {}/{}'.format(args.user, epoch+1, args.end_epoch))
281 | start_epoch = time.time()
282 | for i, data in enumerate(train_loader):
283 |
284 | adjust_learning_rate(optimizer, epoch, args.end_epoch, args.lr)
285 |
286 | x, target = data
287 | x = x.cuda()
288 | target = target.cuda()
289 | output = model(x)
290 |
291 | if args.rlt > 0:
292 | loss = criterion_fl(output, target) + args.rlt * criterion_dl(output, target)
293 | else:
294 | loss = criterion_dl(output, target)
295 |
296 | # loss, loss1, loss2, loss3 = criterion(output, target)
297 | # loss1.requires_grad_(True)
298 | # loss2.requires_grad_(True)
299 | # loss3.requires_grad_(True)
300 | optimizer.zero_grad()
301 | loss.backward()
302 | # loss1.backward()
303 | # loss2.backward()
304 | # loss3.backward()
305 | optimizer.step()
306 | reduce_loss = loss.data.cpu().numpy()
307 | # reduce_loss1 = loss1.data.cpu().numpy()
308 | # reduce_loss2 = loss2.data.cpu().numpy()
309 | # reduce_loss3 = loss3.data.cpu().numpy()
310 | # logging.info('Epoch: {}_Iter:{} loss: {:.5f} || 1:{:.4f} | 2:{:.4f} | 3:{:.4f} ||'
311 | # .format(epoch, i, reduce_loss, reduce_loss1, reduce_loss2, reduce_loss3))
312 | logging.info('Epoch: {}_Iter:{} loss: {:.5f}'
313 | .format(epoch, i, reduce_loss))
314 |
315 | epoch_loss += reduce_loss
316 | end_epoch = time.time()
317 | loss_list.append(epoch_loss)
318 |
319 |         writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)  # log the decayed lr, not the initial default
320 | writer.add_scalar('loss', loss, epoch)
321 | # writer.add_scalar('loss1', loss1, epoch)
322 | # writer.add_scalar('loss2', loss2, epoch)
323 | # writer.add_scalar('loss3', loss3, epoch)
324 |
325 | epoch_time_minute = (end_epoch-start_epoch)/60
326 | remaining_time_hour = (args.end_epoch-epoch-1)*epoch_time_minute/60
327 | logging.info('Current epoch time consumption: {:.2f} minutes!'.format(epoch_time_minute))
328 | logging.info('Estimated remaining training time: {:.2f} hours!'.format(remaining_time_hour))
329 | best_dice,aver_dice,aver_iou = val(model,checkpoint_dir,epoch,best_dice)
330 | dice_list.append(aver_dice)
331 | iou_list.append(aver_iou)
332 | writer.close()
333 | # validation
334 |
335 | end_time = time.time()
336 | total_time = (end_time-start_time)/3600
337 | logging.info('The total training time is {:.2f} hours'.format(total_time))
338 | logging.info('----------------------------------The training process finished!-----------------------------------')
339 |
340 | loss_plot(args, loss_list)
341 | metrics_plot(args, 'dice',dice_list)
342 |
343 | def adjust_learning_rate(optimizer, epoch, max_epoch, init_lr, power=0.9):
344 | for param_group in optimizer.param_groups:
345 | param_group['lr'] = round(init_lr * np.power(1-(epoch) / max_epoch, power), 8)
346 |
347 |
348 | def log_args(log_file):
349 |
350 | logger = logging.getLogger()
351 | logger.setLevel(logging.DEBUG)
352 | formatter = logging.Formatter(
353 | '%(asctime)s ===> %(message)s',
354 | datefmt='%Y-%m-%d %H:%M:%S')
355 |
356 | # args FileHandler to save log file
357 | fh = logging.FileHandler(log_file)
358 | fh.setLevel(logging.DEBUG)
359 | fh.setFormatter(formatter)
360 |
361 | # args StreamHandler to print log to console
362 | ch = logging.StreamHandler()
363 | ch.setLevel(logging.DEBUG)
364 | ch.setFormatter(formatter)
365 |
366 | # add the two Handler
367 | logger.addHandler(ch)
368 | logger.addHandler(fh)
369 |
370 | if __name__ == '__main__':
371 | # criterion = softBCE_dice(aggregate="sum")
372 | criterion = softmax_dice
373 | criterion_fl = FocalLoss(4)
374 | criterion_dl = DiceLoss()
375 | num = 2
376 | # _, model = TransBTS(dataset='brats', _conv_repr=True, _pe_type="learned")
377 | # x = [i for i in range(num)]
378 | # l = [i*random.random() for i in range(num)]
379 | # plt.figure()
380 | # plt.plot(x, l, label='dice')
381 | # log
382 | log_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'log', args.experiment + args.date)
383 | log_file = log_dir + '.txt'
384 | log_args(log_file)
385 | # choose the network model
386 | if args.model_name == 'AU' and args.modal == 'both':
387 | model = AttUnet(in_channels=2, base_channels=16, num_classes=4)
388 | elif args.model_name == 'AU':
389 | model = AttUnet(in_channels=1, base_channels=16, num_classes=4)
390 | elif args.model_name == 'V' and args.modal == 'both':
391 | model = VNet(n_channels=2, n_classes=4, n_filters=16, normalization='gn', has_dropout=False)
392 | elif args.model_name == 'V' :
393 | model = VNet(n_channels=1, n_classes=4, n_filters=16, normalization='gn', has_dropout=False)
394 | elif args.model_name == 'Udrop'and args.modal == 'both':
395 | model = Unetdrop(in_channels=2, base_channels=16, num_classes=4)
396 | elif args.model_name == 'Udrop':
397 | model = Unetdrop(in_channels=1, base_channels=16, num_classes=4)
398 | elif args.model_name == 'U' and args.modal == 'both':
399 | model = Unet(in_channels=2, base_channels=16, num_classes=4)
400 | else:
401 | model = Unet(in_channels=1, base_channels=16, num_classes=4)
402 | # if 'train' in args.mode:
403 | # train(criterion,model,criterion_fl,criterion_dl)
404 | args.mode = 'test'
405 | # Udropout_uncertainty = joblib.load('Udropout_uncertainty.pkl')
406 | test(model)
407 | # test_ensemble(model)
408 |
--------------------------------------------------------------------------------
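train.py anneals the learning rate with a polynomial ("poly") schedule, lr_t = lr_0 * (1 - t / T)^0.9, re-applied to every param group at each iteration via adjust_learning_rate. A standalone sketch of the resulting curve; the printed values are just the formula evaluated, not numbers from a training run:

import numpy as np

def poly_lr(init_lr, epoch, max_epoch, power=0.9):
    # the same rule adjust_learning_rate applies in train.py
    return round(init_lr * np.power(1 - epoch / max_epoch, power), 8)

for epoch in (0, 50, 100, 150, 199):
    print(epoch, poly_lr(0.002, epoch, 200))
# decays smoothly from 0.002 toward 0 as epoch approaches max_epoch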
/trainTBraTS.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.optim as optim
4 | import time
5 | from torch.autograd import Variable
6 | from torch.utils.data import DataLoader
7 | from models.trustedseg import TMSU
8 | from data.BraTS2019 import BraTS
9 | from predict import tailor_and_concat,softmax_mIOU_score,softmax_output_dice
10 | import torch.nn.functional as F
11 | import warnings
12 | warnings.filterwarnings("ignore")
13 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
14 | import numpy as np
15 |
16 | class AverageMeter(object):
17 | """Computes and stores the average and current value"""
18 |
19 | def __init__(self):
20 | self.reset()
21 |
22 | def reset(self):
23 | self.val = 0
24 | self.avg = 0
25 | self.sum = 0
26 | self.count = 0
27 |
28 | def update(self, val, n=1):
29 | self.val = val
30 | self.sum += val * n
31 | self.count += n
32 | self.avg = self.sum / self.count
33 |
34 |
35 | if __name__ == "__main__":
36 | import argparse
37 |
38 | parser = argparse.ArgumentParser()
39 | # Basic Information
40 | local_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
41 | parser.add_argument('--user', default='name of user', type=str)
42 | parser.add_argument('--experiment', default='TBraTS', type=str)
43 | parser.add_argument('--date', default=local_time.split(' ')[0], type=str)
44 | parser.add_argument('--description',
45 | default='Trusted brain tumor segmentation by coco,'
46 | 'training on train.txt!',
47 | type=str)
48 |     # training details
49 |     parser.add_argument('--epochs', type=int, default=199, metavar='N',
50 |                         help='number of epochs to train [default: 199]')
51 | parser.add_argument('--test_epoch', type=int, default=198, metavar='N',
52 | help='best epoch')
53 | parser.add_argument('--lambda-epochs', type=int, default=50, metavar='N',
54 | help='gradually increase the value of lambda from 0 to 1')
55 | parser.add_argument('--save_freq', default=1, type=int)
56 | parser.add_argument('--lr', type=float, default=0.002, metavar='LR',
57 | help='learning rate')
58 | # DataSet Information
59 | parser.add_argument('--root', default='E:/BraTSdata1/archive2019', type=str)
60 | parser.add_argument('--save_dir', default='./results', type=str)
61 | parser.add_argument('--train_dir', default='MICCAI_BraTS_2019_Data_TTraining', type=str)
62 | parser.add_argument('--valid_dir', default='MICCAI_BraTS_2019_Data_TValidation', type=str)
63 | parser.add_argument('--test_dir', default='MICCAI_BraTS_2019_Data_TTest', type=str)
64 | parser.add_argument("--mode", default="train", type=str, help="train/test/train&test")
65 | parser.add_argument('--train_file',
66 | default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Ttrain_subject.txt',
67 | type=str)
68 | parser.add_argument('--valid_file',
69 | default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Tval_subject.txt',
70 | type=str)
71 | parser.add_argument('--test_file',
72 | default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Ttest_subject.txt',
73 | type=str)
74 | parser.add_argument('--dataset', default='brats', type=str)
75 | parser.add_argument('--classes', default=4, type=int)# brain tumor class
76 | parser.add_argument('--input_H', default=240, type=int)
77 | parser.add_argument('--input_W', default=240, type=int)
78 | parser.add_argument('--input_D', default=160, type=int) # 155
79 | parser.add_argument('--crop_H', default=128, type=int)
80 | parser.add_argument('--crop_W', default=128, type=int)
81 | parser.add_argument('--crop_D', default=128, type=int)
82 | parser.add_argument('--output_D', default=155, type=int)
83 | parser.add_argument('--batch_size', default=4, type=int, help="2/4/8")
84 | parser.add_argument('--input_dims', default='four', type=str) # multi-modal/Single-modal
85 | parser.add_argument('--model_name', default='V', type=str) # multi-modal V:168 AU:197
86 |     parser.add_argument('--Variance', default=0.5, type=float)  # noise level
87 | parser.add_argument('--use_TTA', default=True, type=bool, help="True/False")
88 | args = parser.parse_args()
89 | args.dims = [[240,240,160], [240,240,160]]
90 | args.modes = len(args.dims)
91 |
92 | train_list = os.path.join(args.root, args.train_dir, args.train_file)
93 | train_root = os.path.join(args.root, args.train_dir)
94 | train_set = BraTS(train_list, train_root, args.mode,args.input_dims)
95 | train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size)
96 | print('Samples for train = {}'.format(len(train_set)))
97 | valid_list = os.path.join(args.root, args.valid_dir, args.valid_file)
98 | valid_root = os.path.join(args.root, args.valid_dir)
99 | valid_set = BraTS(valid_list, valid_root,'valid',args.input_dims)
100 | valid_loader = DataLoader(valid_set, batch_size=1)
101 | print('Samples for valid = {}'.format(len(valid_set)))
102 | test_list = os.path.join(args.root, args.test_dir, args.test_file)
103 | test_root = os.path.join(args.root, args.test_dir)
104 | test_set = BraTS(test_list, test_root,'test',args.input_dims)
105 | test_loader = DataLoader(test_set, batch_size=1)
106 | print('Samples for test = {}'.format(len(test_set)))
107 |
108 | model = TMSU(args.classes, args.modes, args.model_name, args.input_dims,args.epochs, args.lambda_epochs) # lambda KL divergence
109 | total = sum([param.nelement() for param in model.parameters()])
110 | print("Number of model's parameter: %.2fM" % (total / 1e6))
111 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
112 |
113 | model.cuda()
114 |
115 | def train(epoch):
116 |
117 | model.train()
118 | loss_meter = AverageMeter()
119 | step = 0
120 | dt_size = len(train_loader.dataset)
121 | for i, data in enumerate(train_loader):
122 | step += 1
123 | input, target = data
124 | x = input.cuda() # for multi-modal combine train
125 | target = target.cuda()
126 | args.mode = 'train'
127 | evidences, loss = model(x,target,epoch,args.mode)
128 |
129 | print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // train_loader.batch_size + 1, loss.item()))
130 |
131 | optimizer.zero_grad()
132 |             loss.backward()
133 | optimizer.step()
134 |
135 | loss_meter.update(loss.item())
136 | return loss_meter.avg
137 |
138 | def val(args,current_epoch,best_dice):
139 |         print('===========>Validation beginning!===========')
140 | model.eval()
141 | loss_meter = AverageMeter()
142 | dice_total, iou_total = 0, 0
143 | step = 0
144 | # model.eval()
145 | for i, data in enumerate(valid_loader):
146 | step += 1
147 | input, target = data
148 |
149 |             # split the input into one single-modality tensor per channel (no noise in validation)
150 | x = dict()
151 | for m_num in range(input.shape[1]):
152 | x[m_num] = input[..., m_num, :, :, :, ].unsqueeze(1).cuda()
153 | target = target.cuda()
154 |
155 | with torch.no_grad():
156 | args.mode = 'val'
157 |                 evidences, loss = model(x, target[:, :, :, :155], current_epoch,args.mode) # two or four modalities
158 |             # max
159 |             _, predicted = torch.max(evidences.data, 1)
160 | output = predicted.cpu().detach().numpy()
161 |
162 | target = torch.squeeze(target).cpu().numpy()
163 | iou_res = softmax_mIOU_score(output, target[:, :, :155])
164 | dice_res = softmax_output_dice(output, target[:, :, :155])
165 | print('current_iou:{} ; current_dice:{}'.format(iou_res, dice_res))
166 | dice_total += dice_res[1]
167 | iou_total += iou_res[1]
168 | # loss & noised loss
169 | loss_meter.update(loss.item())
170 | aver_dice = dice_total / len(valid_loader)
171 | aver_iou = iou_total / len(valid_loader)
172 |         # save when dice improves; also always keep the last few epochs
173 |         if aver_dice > best_dice \
174 |                 or (current_epoch + 1) >= args.epochs - 2:
175 |             if aver_dice > best_dice:
176 |                 print('aver_dice:{} > best_dice:{}'.format(aver_dice, best_dice))
177 |                 best_dice = aver_dice
178 |             print('===========>save best model!')
179 | file_name = os.path.join(args.save_dir, '_epoch_{}.pth'.format(current_epoch))
180 | torch.save({
181 | 'epoch': current_epoch,
182 | 'state_dict': model.state_dict(),
183 | },
184 | file_name)
185 | return loss_meter.avg, best_dice
186 |
187 | def test(args):
188 |         print('===========>Test beginning!===========')
189 |
190 | loss_meter = AverageMeter()
191 | noised_loss_meter = AverageMeter()
192 | dice_total,iou_total = 0,0
193 | noised_dice_total,noised_iou_total = 0,0
194 | step = 0
195 | dt_size = len(test_loader.dataset)
196 | load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)),
197 | args.save_dir,
198 | '_epoch_{}.pth'.format(args.test_epoch))
199 |
200 | if os.path.exists(load_file):
201 | checkpoint = torch.load(load_file)
202 | model.load_state_dict(checkpoint['state_dict'])
203 | args.start_epoch = checkpoint['epoch']
204 |             print('Successfully loaded checkpoint {}'.format(os.path.join(args.save_dir, '_epoch_{}.pth'.format(args.test_epoch))))
205 | else:
206 | print('There is no resume file to load!')
207 | model.eval()
208 | for i, data in enumerate(test_loader):
209 | step += 1
210 | input, target = data
211 | # add gaussian noise to input data
212 | noise_m = torch.randn_like(input) * args.Variance
213 | noised_input = input + noise_m
214 | # x = input.cuda()
215 | # noised_x = noised_input.cuda()
216 | x = dict()
217 | noised_x = dict()
218 | for m_num in range(input.shape[1]):
219 | x[m_num] = input[...,m_num,:,:,:,].unsqueeze(1).cuda()
220 | noised_x[m_num] = noised_input[...,m_num,:,:,:,].unsqueeze(1).cuda()
221 | target = target.cuda()
222 |
223 | with torch.no_grad():
224 | args.mode = 'test'
225 | if not args.use_TTA:
226 | evidences, loss = model(x, target[:, :, :, :155], args.epochs,args.mode)
227 | noised_evidences, noised_loss = model(noised_x, target[:, :, :, :155], args.epochs,args.mode)
228 | else:
229 | evidences, loss = model(x, target[:, :, :, :155], args.epochs,args.mode,args.use_TTA)
230 | noised_evidences, noised_loss = model(noised_x, target[:, :, :, :155], args.epochs,args.mode,args.use_TTA)
231 | # results with TTA or not
232 |
233 |             output = F.softmax(evidences, dim=1)
234 |             # for input noise
235 |             noised_output = F.softmax(noised_evidences, dim=1)
236 |
237 | # dice
238 | output = output[0, :, :args.input_H, :args.input_W, :args.input_D].cpu().detach().numpy()
239 | output = output.argmax(0)
240 | target = torch.squeeze(target).cpu().numpy()
241 | iou_res = softmax_mIOU_score(output, target[:, :, :155])
242 | dice_res = softmax_output_dice(output, target[:, :, :155])
243 | dice_total += dice_res[1]
244 | iou_total += iou_res[1]
245 | # for noise_x
246 | noised_output = noised_output[0, :, :args.input_H, :args.input_W, :args.input_D].cpu().detach().numpy()
247 | noised_output = noised_output.argmax(0)
248 | noised_iou_res = softmax_mIOU_score(noised_output, target[:, :, :155])
249 | noised_dice_res = softmax_output_dice(noised_output, target[:, :, :155])
250 | noised_dice_total += noised_dice_res[1]
251 | noised_iou_total += noised_iou_res[1]
252 | print('current_dice:{} ; current_noised_dice:{}'.format(dice_res, noised_dice_res))
253 | # loss & noised loss
254 | loss_meter.update(loss.item())
255 | noised_loss_meter.update(noised_loss.item())
256 | noised_aver_dice = noised_dice_total / len(test_loader)
257 | aver_dice = dice_total / len(test_loader)
258 | print('====> noised_aver_dice: {:.4f}'.format(noised_aver_dice))
259 | print('====> aver_dice: {:.4f}'.format(aver_dice))
260 | return loss_meter.avg,noised_loss_meter.avg, aver_dice,noised_aver_dice
261 |
262 |
263 | epoch_loss = 0
264 | best_dice = 0
265 | for epoch in range(1, args.epochs + 1):
266 |         print('===========>Train beginning!===========')
267 |         print('Epoch {}/{}'.format(epoch, args.epochs))
268 | epoch_loss = train(epoch)
269 | print("epoch %d avg_loss:%0.3f" % (epoch, epoch_loss))
270 | val_loss, best_dice = val(args,epoch,best_dice)
271 | test_loss,noised_test_loss, test_dice,noised_test_dice = test(args)
272 |
--------------------------------------------------------------------------------
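A closing note on the data layout above: val() and test() in trainTBraTS.py hand TMSU a dict with one single-channel tensor per modality instead of a stacked batch. A minimal sketch of that reshaping, assuming a (B, M, H, W, D) input batch with illustrative sizes (.cuda() calls omitted):

import torch

batch = torch.randn(1, 2, 240, 240, 160)  # (batch, modalities, H, W, D)

x = {}
for m_num in range(batch.shape[1]):
    # each modality becomes its own (B, 1, H, W, D) volume, keyed by its index
    x[m_num] = batch[:, m_num].unsqueeze(1)

print({k: tuple(v.shape) for k, v in x.items()})
# {0: (1, 1, 240, 240, 160), 1: (1, 1, 240, 240, 160)}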