├── Fm_eval.py ├── README.md ├── binary.py ├── checkpoint └── nullfile ├── data ├── BraTS2019 └── preprocessBraTS.py ├── image ├── F1N.png ├── Trust_E.gif └── nullfile ├── log └── nullfile ├── models ├── criterions.py ├── eval ├── lib │ ├── IntmdSequential.py │ ├── PositionalEncoding.py │ ├── TransBTS_downsample8x_skipconnection.py │ ├── Transformer.py │ ├── UNet3DZoo.py │ ├── Unet_skipconnection.py │ ├── VNet3D.py │ ├── nullfile │ └── seg_eval.py └── trustedseg.py ├── numpyfunctions.py ├── plot.py ├── predict.py ├── results └── nullfile ├── test_uncertainty.py ├── train.py └── trainTBraTS.py /Fm_eval.py: -------------------------------------------------------------------------------- 1 | # This is the eval for Sm Fm ... 2 | import os 3 | import argparse 4 | import tqdm 5 | import sys 6 | 7 | import numpy as np 8 | 9 | from PIL import Image 10 | from tabulate import tabulate 11 | 12 | filepath = os.path.split(os.path.abspath(__file__))[0] 13 | repopath = os.path.split(filepath)[0] 14 | sys.path.append(repopath) 15 | 16 | from utils.eval_functions import * 17 | from utils.utils import * 18 | 19 | def evaluate(opt, args): 20 | if os.path.isdir(opt.Eval.result_path) is False: 21 | os.makedirs(opt.Eval.result_path) 22 | 23 | method = os.path.split(opt.Eval.pred_root)[-1] 24 | Thresholds = np.linspace(1, 0, 256) 25 | headers = opt.Eval.metrics #['meanDic', 'meanIoU', 'wFm', 'Sm', 'meanEm', 'mae', 'maxEm', 'maxDic', 'maxIoU', 'meanSen', 'maxSen', 'meanSpe', 'maxSpe'] 26 | results = [] 27 | 28 | if args.verbose is True: 29 | print('#' * 20, 'Start Evaluation', '#' * 20) 30 | datasets = tqdm.tqdm(opt.Eval.datasets, desc='Expr - ' + method, total=len( 31 | opt.Eval.datasets), position=0, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}') 32 | else: 33 | datasets = opt.Eval.datasets 34 | 35 | for dataset in datasets: 36 | pred_root = os.path.join(opt.Eval.pred_root, dataset) 37 | gt_root = os.path.join(opt.Eval.gt_root, dataset, 'masks') 38 | 39 | preds = os.listdir(pred_root) 40 | gts = os.listdir(gt_root) 41 | 42 | preds.sort() 43 | gts.sort() 44 | 45 | threshold_Fmeasure = np.zeros((len(preds), len(Thresholds))) 46 | threshold_Emeasure = np.zeros((len(preds), len(Thresholds))) 47 | threshold_IoU = np.zeros((len(preds), len(Thresholds))) 48 | # threshold_Precision = np.zeros((len(preds), len(Thresholds))) 49 | # threshold_Recall = np.zeros((len(preds), len(Thresholds))) 50 | threshold_Sensitivity = np.zeros((len(preds), len(Thresholds))) 51 | threshold_Specificity = np.zeros((len(preds), len(Thresholds))) 52 | threshold_Dice = np.zeros((len(preds), len(Thresholds))) 53 | 54 | Smeasure = np.zeros(len(preds)) 55 | wFmeasure = np.zeros(len(preds)) 56 | MAE = np.zeros(len(preds)) 57 | 58 | if args.verbose is True: 59 | samples = tqdm.tqdm(enumerate(zip(preds, gts)), desc=dataset + ' - Evaluation', total=len( 60 | preds), position=1, leave=False, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}') 61 | else: 62 | samples = enumerate(zip(preds, gts)) 63 | 64 | for i, sample in samples: 65 | pred, gt = sample 66 | assert os.path.splitext(pred)[0] == os.path.splitext(gt)[0] 67 | 68 | pred_mask = np.array(Image.open(os.path.join(pred_root, pred))) 69 | gt_mask = np.array(Image.open(os.path.join(gt_root, gt))) 70 | 71 | if len(pred_mask.shape) != 2: 72 | pred_mask = pred_mask[:, :, 0] 73 | if len(gt_mask.shape) != 2: 74 | gt_mask = gt_mask[:, :, 0] 75 | 76 | assert pred_mask.shape == gt_mask.shape 77 | 78 | gt_mask = gt_mask.astype(np.float64) / 255 79 | gt_mask = (gt_mask > 0.5).astype(np.float64) 80 | 81 | pred_mask = pred_mask.astype(np.float64) / 255 82 | 83 | Smeasure[i] = StructureMeasure(pred_mask, gt_mask) 84 | wFmeasure[i] = original_WFb(pred_mask, gt_mask) 85 | MAE[i] = np.mean(np.abs(gt_mask - pred_mask)) 86 | 87 | threshold_E = np.zeros(len(Thresholds)) 88 | threshold_F = np.zeros(len(Thresholds)) 89 | threshold_Pr = np.zeros(len(Thresholds)) 90 | threshold_Rec = np.zeros(len(Thresholds)) 91 | threshold_Iou = np.zeros(len(Thresholds)) 92 | threshold_Spe = np.zeros(len(Thresholds)) 93 | threshold_Dic = np.zeros(len(Thresholds)) 94 | 95 | for j, threshold in enumerate(Thresholds): 96 | threshold_Pr[j], threshold_Rec[j], threshold_Spe[j], threshold_Dic[j], threshold_F[j], threshold_Iou[j] = Fmeasure_calu(pred_mask, gt_mask, threshold) 97 | 98 | Bi_pred = np.zeros_like(pred_mask) 99 | Bi_pred[pred_mask >= threshold] = 1 100 | threshold_E[j] = EnhancedMeasure(Bi_pred, gt_mask) 101 | 102 | threshold_Emeasure[i, :] = threshold_E 103 | threshold_Fmeasure[i, :] = threshold_F 104 | threshold_Sensitivity[i, :] = threshold_Rec 105 | threshold_Specificity[i, :] = threshold_Spe 106 | threshold_Dice[i, :] = threshold_Dic 107 | threshold_IoU[i, :] = threshold_Iou 108 | 109 | result = [] 110 | 111 | mae = np.mean(MAE) 112 | Sm = np.mean(Smeasure) 113 | wFm = np.mean(wFmeasure) 114 | 115 | column_E = np.mean(threshold_Emeasure, axis=0) 116 | meanEm = np.mean(column_E) 117 | maxEm = np.max(column_E) 118 | 119 | column_Sen = np.mean(threshold_Sensitivity, axis=0) 120 | meanSen = np.mean(column_Sen) 121 | maxSen = np.max(column_Sen) 122 | 123 | column_Spe = np.mean(threshold_Specificity, axis=0) 124 | meanSpe = np.mean(column_Spe) 125 | maxSpe = np.max(column_Spe) 126 | 127 | column_Dic = np.mean(threshold_Dice, axis=0) 128 | meanDic = np.mean(column_Dic) 129 | maxDic = np.max(column_Dic) 130 | 131 | column_IoU = np.mean(threshold_IoU, axis=0) 132 | meanIoU = np.mean(column_IoU) 133 | maxIoU = np.max(column_IoU) 134 | 135 | # result.extend([meanDic, meanIoU, wFm, Sm, meanEm, mae, maxEm, maxDic, maxIoU, meanSen, maxSen, meanSpe, maxSpe]) 136 | # results.append([dataset, *result]) 137 | 138 | out = [] 139 | for metric in opt.Eval.metrics: 140 | out.append(eval(metric)) 141 | 142 | result.extend(out) 143 | results.append([dataset, *result]) 144 | 145 | csv = os.path.join(opt.Eval.result_path, 'result_' + dataset + '.csv') 146 | if os.path.isfile(csv) is True: 147 | csv = open(csv, 'a') 148 | else: 149 | csv = open(csv, 'w') 150 | csv.write(', '.join(['method', *headers]) + '\n') 151 | 152 | out_str = method + ',' 153 | for metric in result: 154 | out_str += '{:.4f}'.format(metric) + ',' 155 | out_str += '\n' 156 | 157 | csv.write(out_str) 158 | csv.close() 159 | tab = tabulate(results, headers=['dataset', *headers], floatfmt=".3f") 160 | 161 | if args.verbose is True: 162 | print(tab) 163 | print("#"*20, "End Evaluation", "#"*20) 164 | 165 | return tab 166 | 167 | if __name__ == "__main__": 168 | args = parse_args() 169 | opt = load_config(args.config) 170 | evaluate(opt, args) 171 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TBraTS 2 | * This repository provides the code for our accepted MICCAI'2022 paper "TBraTS: Trusted Brain Tumor Segmentation" 3 | * Official implementation [TBraTS: Trusted Brain Tumor Segmentation](https://arxiv.org/abs/2206.09309) 4 | * Journal version [[Paper]](https://arxiv.org/abs/2301.00349) [[code]](https://github.com/Cocofeat/UMIS) 5 | * Official video in [MICS2022](https://aim.nuist.edu.cn/events/mics2022.htm) of [TBraTS: Trusted Brain Tumor Segmentation](https://www.bilibili.com/video/BV1nW4y1a7Qp/?spm_id_from=333.337.search-card.all.click&vd_source=6ab19d355475883daafd34a6daae54a5) (**3rd Prize**) 6 | 7 | ## Introduction 8 | Despite recent improvements in the accuracy of brain tumor segmentation, the results still exhibit low levels of confidence and robustness. Uncertainty estimation is one effective way to change this situation, as it provides a measure of confidence in the segmentation results. In this paper, we propose a trusted brain tumor segmentation network which can generate robust segmentation results and reliable uncertainty estimations without excessive computational burden and modification of the backbone network. In our method, uncertainty is modeled explicitly using subjective logic theory, which treats the predictions of backbone neural network as subjective opinions by parameterizing the class probabilities of the segmentation as a Dirichlet distribution. Meanwhile, the trusted segmentation framework learns the function that gathers reliable evidence from the feature leading to the final segmentation results. Overall, our unified trusted segmentation framework endows the model with reliability and robustness to out-of-distribution samples. To evaluate the effectiveness of our model in robustness and reliability, qualitative and quantitative experiments are conducted on the BraTS 2019 dataset. 9 | 10 |
Our TBraTS framework
11 | 12 | ## Requirements 13 | Some important required packages include: 14 | Pytorch version >=0.4.1. 15 | Visdom 16 | Python == 3.7 17 | Some basic python packages such as Numpy. 18 | 19 | ## Data Acquisition 20 | - The multimodal brain tumor datasets (**BraTS 2019**) could be acquired from [here](https://ipp.cbica.upenn.edu/). 21 | 22 | ## Data Preprocess 23 | After downloading the dataset from [here](https://ipp.cbica.upenn.edu/), data preprocessing is needed which is to convert the .nii files as .pkl files and realize date normalization. 24 | 25 | Follow the `python3 data/preprocessBraTS.py ` which is referenced from the [TransBTS](https://github.com/Wenxuan-1119/TransBTS/blob/main/data/preprocess.py) 26 | 27 | ## Training & Testing 28 | Run the `python3 trainTBraTS.py ` : your own backbone with our framework(U/V/AU/TransBTS) 29 | 30 | Run the `python3 train.py ` : the backbone without our framework 31 | 32 | ## :fire: NEWS :fire: 33 | * [09/17] More experiments on trustworthy medical image segmentation please refer to [UMIS](https://github.com/Cocofeat/UMIS). 34 | * [09/17] We released all the codes. 35 | * [06/05] We will release the code as soon as possible. 36 | * [06/13] We have uploaded the main part of our code. We will upload all the code after camera-ready. 37 | * [06/22] Our pre-printed version of the paper is available at [TBraTS: Trusted Brain Tumor Segmentation](https://arxiv.org/abs/2206.09309) 38 | ## Citation 39 | If you find our work is helpful for your research, please consider to cite: 40 | ``` 41 | @InProceedings{Coco2022TBraTS, 42 | author = {Zou, Ke and Yuan, Xuedong and Shen, Xiaojing and Wang, Meng and Fu, Huazhu}, 43 | booktitle = {Medical Image Computing and Computer Assisted Intervention -- MICCAI 2022}, 44 | title = {TBraTS: Trusted Brain Tumor Segmentation}, 45 | year = {2022}, 46 | address = {Cham}, 47 | pages = {503--513}, 48 | publisher = {Springer Nature Switzerland}, 49 | } 50 | ``` 51 | ## Acknowledgement 52 | Part of the code is revised from [TransBTS](https://github.com/Wenxuan-1119/TransBTS) and [TMC](https://github.com/hanmenghan/TMC) 53 | 54 | ## Contact 55 | * If you have any problems about our work, please contact [me](kezou8@gmail.com) 56 | * Project Link: [TBraTS](https://github.com/Cocofeat/TBraTS/) 57 | -------------------------------------------------------------------------------- /checkpoint/nullfile: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/BraTS2019: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | import random 5 | import numpy as np 6 | from torchvision.transforms import transforms 7 | from torch.utils.data import DataLoader 8 | import pickle 9 | from scipy import ndimage 10 | import argparse 11 | 12 | def pkload(fname): 13 | with open(fname, 'rb') as f: 14 | return pickle.load(f) 15 | 16 | 17 | class MaxMinNormalization(object): 18 | def __call__(self, sample): 19 | image = sample['image'] 20 | label = sample['label'] 21 | Max = np.max(image) 22 | Min = np.min(image) 23 | image = (image - Min) / (Max - Min) 24 | 25 | return {'image': image, 'label': label} 26 | 27 | 28 | class Random_Flip(object): 29 | def __call__(self, sample): 30 | image = sample['image'] 31 | label = sample['label'] 32 | if random.random() < 0.5: 33 | image = np.flip(image, 0) 34 | label = np.flip(label, 0) 35 | if random.random() < 0.5: 36 | image = np.flip(image, 1) 37 | label = np.flip(label, 1) 38 | if random.random() < 0.5: 39 | image = np.flip(image, 2) 40 | label = np.flip(label, 2) 41 | 42 | return {'image': image, 'label': label} 43 | 44 | 45 | class Random_Crop(object): 46 | def __call__(self, sample): 47 | image = sample['image'] 48 | label = sample['label'] 49 | H = random.randint(0, 240 - 128) 50 | W = random.randint(0, 240 - 128) 51 | D = random.randint(0, 160 - 128) 52 | 53 | image = image[H: H + 128, W: W + 128, D: D + 128, ...] 54 | label = label[..., H: H + 128, W: W + 128, D: D + 128] 55 | 56 | return {'image': image, 'label': label} 57 | 58 | 59 | class Random_intencity_shift(object): 60 | def __call__(self, sample, factor=0.1): 61 | image = sample['image'] 62 | label = sample['label'] 63 | 64 | scale_factor = np.random.uniform(1.0-factor, 1.0+factor, size=[1, image.shape[1], 1]) 65 | shift_factor = np.random.uniform(-factor, factor, size=[1, image.shape[1], 1]) 66 | 67 | image = image*scale_factor+shift_factor 68 | 69 | return {'image': image, 'label': label} 70 | 71 | class Random_intencity_shiftboth(object): 72 | def __call__(self, sample, factor=0.1): 73 | image = sample['image'] 74 | label = sample['label'] 75 | 76 | scale_factor = np.random.uniform(1.0-factor, 1.0+factor, size=[1, image.shape[1], 1, image.shape[-1]]) 77 | shift_factor = np.random.uniform(-factor, factor, size=[1, image.shape[1], 1, image.shape[-1]]) 78 | 79 | image = image*scale_factor+shift_factor 80 | 81 | return {'image': image, 'label': label} 82 | 83 | class Random_rotate(object): 84 | def __call__(self, sample): 85 | image = sample['image'] 86 | label = sample['label'] 87 | 88 | angle = round(np.random.uniform(-10, 10), 2) 89 | image = ndimage.rotate(image, angle, axes=(0, 1), reshape=False) 90 | label = ndimage.rotate(label, angle, axes=(0, 1), reshape=False) 91 | 92 | return {'image': image, 'label': label} 93 | 94 | 95 | class Pad(object): 96 | def __call__(self, sample): 97 | image = sample['image'] 98 | label = sample['label'] 99 | 100 | image = np.pad(image, ((0, 0), (0, 0), (0, 5)), mode='constant') 101 | label = np.pad(label, ((0, 0), (0, 0), (0, 5)), mode='constant') 102 | return {'image': image, 'label': label} 103 | #(240,240,155)>(240,240,160) 104 | class Padboth(object): 105 | def __call__(self, sample): 106 | image = sample['image'] 107 | label = sample['label'] 108 | 109 | image = np.pad(image, ((0, 0), (0, 0), (0, 5), (0, 0)), mode='constant') 110 | label = np.pad(label, ((0, 0), (0, 0), (0, 5)), mode='constant') 111 | return {'image': image, 'label': label} 112 | #(240,240,155,n)>(240,240,160,n) 113 | 114 | class ToTensor(object): 115 | """Convert ndarrays in sample to Tensors.""" 116 | def __call__(self, sample): 117 | image = sample['image'] 118 | image = np.ascontiguousarray(image) 119 | label = sample['label'] 120 | label = np.ascontiguousarray(label) 121 | 122 | image = torch.from_numpy(image).float().unsqueeze(0) 123 | label = torch.from_numpy(label).long() 124 | 125 | return {'image': image, 'label': label} 126 | 127 | class ToTensorboth(object): 128 | """Convert ndarrays in sample to Tensors.""" 129 | def __call__(self, sample): 130 | image = sample['image'] 131 | image = np.ascontiguousarray(image.transpose(3, 0, 1, 2)) 132 | label = sample['label'] 133 | label = np.ascontiguousarray(label) 134 | 135 | image = torch.from_numpy(image).float() 136 | label = torch.from_numpy(label).long() 137 | 138 | return {'image': image, 'label': label} 139 | 140 | def transform(sample): 141 | trans = transforms.Compose([ 142 | Pad(), 143 | # Random_rotate(), # time-consuming 144 | Random_Crop(), 145 | Random_Flip(), 146 | Random_intencity_shift(), 147 | ToTensor() 148 | ]) 149 | 150 | return trans(sample) 151 | 152 | 153 | def transform_valid(sample): 154 | trans = transforms.Compose([ 155 | Pad(), 156 | # MaxMinNormalization(), 157 | ToTensor() 158 | ]) 159 | 160 | return trans(sample) 161 | 162 | def transformboth(sample): 163 | trans = transforms.Compose([ 164 | Padboth(), 165 | # Random_rotate(), # time-consuming 166 | Random_Crop(), 167 | Random_Flip(), 168 | Random_intencity_shiftboth(), 169 | ToTensorboth() 170 | ]) 171 | 172 | return trans(sample) 173 | 174 | 175 | def transformboth_valid(sample): 176 | trans = transforms.Compose([ 177 | Padboth(), 178 | # MaxMinNormalization(), 179 | ToTensorboth() 180 | ]) 181 | 182 | return trans(sample) 183 | 184 | class BraTS(Dataset): 185 | def __init__(self, list_file, root='', mode='train', modal='t1'): 186 | self.lines = [] 187 | paths, names = [], [] 188 | with open(list_file) as f: 189 | for line in f: 190 | line = line.strip() 191 | name = line.split('/')[-1] 192 | names.append(name) 193 | path = os.path.join(root, line, name + '_') 194 | paths.append(path) 195 | self.lines.append(line) 196 | # changed bo coco 197 | # del paths[0:3] 198 | # del names[0:3] 199 | # del self.lines[0:3] 200 | self.mode = mode 201 | self.modal = modal 202 | self.names = names 203 | self.paths = paths 204 | 205 | def __getitem__(self, item): 206 | # input could be chosed with t1/t2/four modalities 207 | path = self.paths[item] 208 | if self.mode == 'train': 209 | image, label = pkload(path + 'data_f32b04M.pkl') 210 | # print(np.unique(label)) 211 | label[label==4]=3 212 | # print(np.unique(label)) 213 | if self.modal == 't1': 214 | sample = {'image': image[..., 0] , 'label': label} 215 | sample = transform(sample) 216 | elif self.modal =='t2': 217 | sample = {'image': image[..., 1], 'label': label} 218 | sample = transform(sample) 219 | else: 220 | sample = {'image': image, 'label': label} 221 | sample = transformboth(sample) 222 | return sample['image'], sample['label'] 223 | elif self.mode == 'valid': 224 | image, label = pkload(path + 'data_f32b04M.pkl') 225 | label[label == 4] = 3 226 | if self.modal == 't1': 227 | sample = {'image': image[..., 0], 'label': label} 228 | sample = transform_valid(sample) 229 | elif self.modal =='t2': 230 | sample = {'image': image[..., 1], 'label': label} 231 | sample = transform_valid(sample) 232 | else: 233 | sample = {'image': image, 'label': label} 234 | sample = transformboth_valid(sample) 235 | return sample['image'], sample['label'] 236 | else: 237 | image,label = pkload(path + 'data_f32b04M.pkl') 238 | label[label == 4] = 3 239 | if self.modal == 't1': 240 | sample = {'image': image[..., 0], 'label': label} 241 | sample = transform_valid(sample) 242 | elif self.modal =='t2': 243 | sample = {'image': image[..., 1], 'label': label} 244 | sample = transform_valid(sample) 245 | else: 246 | sample = {'image': image, 'label': label} 247 | sample = transformboth_valid(sample) 248 | return sample['image'], sample['label'] 249 | 250 | def __len__(self): 251 | return len(self.names) 252 | 253 | def collate(self, batch): 254 | return [torch.cat(v) for v in zip(*batch)] 255 | 256 | if __name__ == '__main__': 257 | parser = argparse.ArgumentParser() 258 | parser.add_argument('--root', default='E:/BraTSdata1/archive2019', type=str) 259 | parser.add_argument('--train_dir', default='MICCAI_BraTS_2019_Data_TTraining', type=str) 260 | parser.add_argument('--valid_dir', default='MICCAI_BraTS_2019_Data_TValidation', type=str) 261 | parser.add_argument('--test_dir', default='MICCAI_BraTS_2019_Data_TTest', type=str) 262 | parser.add_argument('--mode', default='train', type=str) 263 | parser.add_argument('--train_file', default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Ttrain_subject.txt', type=str) 264 | parser.add_argument('--valid_file', default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Tval_subject.txt', type=str) 265 | parser.add_argument('--test_file', default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Ttest_subject.txt', type=str) 266 | parser.add_argument('--dataset', default='brats', type=str) 267 | parser.add_argument('--num_gpu', default= 4, type=int) 268 | parser.add_argument('--num_workers', default=4, type=int) 269 | parser.add_argument('--batch_size', default=8, type=int) 270 | parser.add_argument('--modal', default='both', type=str) 271 | parser.add_argument('--Variance', default=0.1, type=int) 272 | args = parser.parse_args() 273 | train_list = os.path.join(args.root, args.train_dir, args.train_file) 274 | train_root = os.path.join(args.root, args.train_dir) 275 | val_list = os.path.join(args.root, args.valid_dir, args.valid_file) 276 | val_root = os.path.join(args.root, args.valid_dir) 277 | test_list = os.path.join(args.root, args.test_dir, args.test_file) 278 | test_root = os.path.join(args.root, args.test_dir) 279 | train_set = BraTS(train_list, train_root, args.mode,args.modal) 280 | val_set = BraTS(val_list, val_root, args.mode, args.modal) 281 | test_set = BraTS(test_list, test_root, args.mode, args.modal) 282 | # train_sampler = torch.utils.data.distributed.DistributedSampler(train_set) 283 | # train_loader = DataLoader(dataset=train_set, sampler=train_sampler, batch_size=args.batch_size // args.num_gpu, 284 | # drop_last=True, num_workers=args.num_workers, pin_memory=True) 285 | train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size) 286 | val_loader = DataLoader(dataset=val_set, batch_size=1) 287 | test_loader = DataLoader(dataset=test_set, batch_size=1) 288 | for i, data in enumerate(train_loader): 289 | x, target = data 290 | if args.mode == 'test': 291 | noise = torch.clamp(torch.randn_like(x) * args.Variance, -args.Variance * 2, args.Variance * 2) 292 | x += noise 293 | # x_no = np.unique(x.numpy()) 294 | # target_no = np.unique(target.numpy()) 295 | -------------------------------------------------------------------------------- /data/preprocessBraTS.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import numpy as np 4 | import nibabel as nib 5 | import shutil 6 | import random 7 | modalities = ('flair', 't1ce', 't1', 't2') 8 | twomodalities = ('t1', 't2') 9 | # this is the file to save the two-modality and four-modality img/mask 10 | # train 11 | train_set = { 12 | 'root': 'C:/Coco_file/BraTSdata/archive2019/MICCAI_BraTS_2019_Data_Training/', 13 | 'flist': 'train.txt', 14 | 'has_label': True 15 | } 16 | 17 | # test/validation data using in train dataset 18 | Ttrain_set = { 19 | 'root': 'C:/Coco_file/BraTSdata/archive2019/MICCAI_BraTS_2019_Data_TTraining/', 20 | 'flist': 'train.txt', 21 | 'has_label': True 22 | } 23 | Tvalid_set = { 24 | 'root': 'C:/Coco_file/BraTSdata/archive2019/MICCAI_BraTS_2019_Data_TValidation/', 25 | 'has_label': True 26 | } 27 | Ttest_set = { 28 | 'root': 'C:/Coco_file/BraTSdata/archive2019/MICCAI_BraTS_2019_Data_TTest/', 29 | 'has_label': True 30 | } 31 | # valid_set = { 32 | # 'root': 'path to valid set', 33 | # 'flist': 'valid.txt', 34 | # 'has_label': False 35 | # } 36 | # test_set = { 37 | # 'root': 'path to testing set', 38 | # 'flist': 'test.txt', 39 | # 'has_label': False 40 | # } 41 | 42 | def ensure_dir_exists(dir_name): 43 | """Makes sure the folder exists on disk. 44 | Args: 45 | dir_name: Path string to the folder we want to create. 46 | """ 47 | if not os.path.exists(dir_name): 48 | os.makedirs(dir_name) 49 | 50 | def save_subjects(root,dir_name,subjects): 51 | f = open(root+dir_name+"_subject.txt", "w") 52 | str = '\n' 53 | f.write(str.join(subjects)) 54 | f.close() 55 | 56 | def nib_load(file_name): 57 | if not os.path.exists(file_name): 58 | print('Invalid file name, can not find the file!') 59 | 60 | proxy = nib.load(file_name) 61 | data = proxy.get_data() 62 | proxy.uncache() 63 | return data 64 | 65 | 66 | def process_i16(path, has_label=True): 67 | """ Save the original 3D MRI images with dtype=int16. 68 | Noted that no normalization is used! """ 69 | label = np.array(nib_load(path + 'seg.nii.gz'), dtype='uint8', order='C') 70 | 71 | images = np.stack([ 72 | np.array(nib_load(path + modal + '.nii.gz'), dtype='int16', order='C') 73 | for modal in modalities], -1)# [240,240,155] 74 | 75 | output = path + 'data_i16.pkl' 76 | 77 | with open(output, 'wb') as f: 78 | print(output) 79 | print(images.shape, type(images), label.shape, type(label)) # (240,240,155,4) , (240,240,155) 80 | pickle.dump((images, label), f) 81 | 82 | if not has_label: 83 | return 84 | 85 | 86 | def process_f32b0(path, has_label=True): 87 | """ Save the data with dtype=float32. 88 | z-score is used but keep the background with zero! """ 89 | if has_label: 90 | label = np.array(nib_load(path + 'seg.nii'), dtype='uint8', order='C') 91 | images = np.stack([np.array(nib_load(path + modal + '.nii'), dtype='float32', order='C') for modal in modalities], -1) # [240,240,155] 92 | 93 | output = path + 'data_f32b0.pkl' 94 | mask = images.sum(-1) > 0 95 | for k in range(4): 96 | 97 | x = images[..., k] # 98 | y = x[mask] 99 | 100 | # 0.8885 101 | x[mask] -= y.mean() 102 | x[mask] /= y.std() 103 | 104 | images[..., k] = x 105 | 106 | with open(output, 'wb') as f: 107 | print(output) 108 | 109 | if has_label: 110 | pickle.dump((images, label), f) 111 | else: 112 | pickle.dump(images, f) 113 | 114 | if not has_label: 115 | return 116 | 117 | def process_f32b0twomodal(path, has_label=True): 118 | """ Save the data with dtype=float32. 119 | z-score is used but keep the background with zero! """ 120 | if has_label: 121 | label = np.array(nib_load(path + 'seg.nii'), dtype='uint8', order='C') 122 | images = np.stack([np.array(nib_load(path + modal + '.nii'), dtype='float32', order='C') for modal in twomodalities], -1) # [240,240,155] 123 | 124 | output = path + 'data_f32b0.pkl' 125 | mask = images.sum(-1) > 0 126 | for k in range(images.shape[3]): 127 | 128 | x = images[..., k] # 129 | y = x[mask] 130 | 131 | # 0.8885 132 | x[mask] -= y.mean() 133 | x[mask] /= y.std() 134 | 135 | images[..., k] = x 136 | 137 | with open(output, 'wb') as f: 138 | print(output) 139 | 140 | if has_label: 141 | pickle.dump((images, label), f) 142 | else: 143 | pickle.dump(images, f) 144 | 145 | if not has_label: 146 | return 147 | 148 | def doit(dset,args_modal): 149 | root, has_label = dset['root'], dset['has_label'] 150 | file_list = os.path.join(root, dset['flist']) 151 | subjects = open(file_list).read().splitlines() 152 | names = [sub.split('/')[-1] for sub in subjects] 153 | paths = [os.path.join(root, sub, name + '_') for sub, name in zip(subjects, names)] 154 | 155 | for path in paths: 156 | print(path) 157 | if args_modal =='2': 158 | process_f32b0twomodal(path, has_label) # two modal 159 | else: 160 | process_f32b0(path, has_label) # four modal 161 | 162 | def move_doit(dset,train_dir,valid_dir,test_dir): 163 | ensure_dir_exists(train_dir) 164 | ensure_dir_exists(valid_dir) 165 | ensure_dir_exists(test_dir) 166 | root, has_label = dset['root'], dset['has_label'] 167 | file_list = os.path.join(root, dset['flist']) 168 | subjects = open(file_list).read().splitlines() 169 | HGG_subjects = [] 170 | LGG_subjects = [] 171 | for sub in subjects: 172 | if "HGG" in sub: 173 | HGG_subjects.append(sub) 174 | else: 175 | LGG_subjects.append(sub) 176 | #val + test 177 | valtest_HGGsubjects = random.sample(HGG_subjects, int(len(HGG_subjects)*0.3)+1) 178 | valtest_LGGsubjects = random.sample(LGG_subjects, int(len(LGG_subjects) * 0.3)) 179 | # val test HGG 180 | val_HGGsubjects = random.sample(valtest_HGGsubjects, int(len(valtest_HGGsubjects)*0.5)) 181 | test_HGGsubjects = list(set(valtest_HGGsubjects).difference(set(val_HGGsubjects))) 182 | # val test LGG 183 | val_LGGsubjects = random.sample(valtest_LGGsubjects, int(len(valtest_LGGsubjects) * 0.5)) 184 | test_LGGsubjects = list(set(valtest_LGGsubjects).difference(set(val_LGGsubjects))) 185 | # val 186 | val_HGGsubjects.extend(val_LGGsubjects) 187 | val_subjects = val_HGGsubjects 188 | # test 189 | test_HGGsubjects.extend(test_LGGsubjects) 190 | test_subjects = test_HGGsubjects 191 | # val + test 192 | valtest_HGGsubjects.extend(valtest_LGGsubjects) 193 | valtest_subjects = valtest_HGGsubjects 194 | tran_subjects = list(set(subjects).difference(set(valtest_subjects))) 195 | # save subject name 196 | save_subjects(root, "Ttrain", tran_subjects) 197 | save_subjects(root, "Ttest", test_subjects) 198 | save_subjects(root, "Tval", val_subjects) 199 | # move training dataset to Ttrain,TTest and Tval dataset 200 | for subject in tran_subjects: 201 | shutil.move(root + subject, train_dir + subject) 202 | for subject_test in test_subjects: 203 | shutil.move(root + subject_test, test_dir + subject_test) 204 | for subject_val in val_subjects: 205 | shutil.move(root + subject_val, valid_dir + subject_val) 206 | 207 | def delete_doit(dir): 208 | rootdir = dir 209 | GG_filelist = os.listdir(rootdir) 210 | for file in GG_filelist: 211 | names = os.listdir(rootdir+file) 212 | for name in names: 213 | files = os.listdir(rootdir + '/' +file + '/' + name) 214 | for modal_file in files: 215 | if '.nii' in modal_file: 216 | del_file = rootdir + '/' +file + '/' + name + '/' + modal_file # 当代码和要删除的文件不在同一个文件夹时,必须使用绝对路径 217 | os.remove(del_file) # 删除文件 218 | print("已经删除:", del_file) 219 | 220 | if __name__ == '__main__': 221 | args_modal = '4' 222 | doit(train_set,args_modal) 223 | move_doit(train_set,Ttrain_set['root'],Tvalid_set['root'],Ttest_set['root']) 224 | delete_doit(Ttrain_set['root']) 225 | delete_doit(Tvalid_set['root']) 226 | delete_doit(Ttest_set['root']) 227 | # doit(valid_set) 228 | # doit(test_set) 229 | 230 | -------------------------------------------------------------------------------- /image/F1N.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cocofeat/TBraTS/5255b7c6c8d338c2c1af72368169054d889a977c/image/F1N.png -------------------------------------------------------------------------------- /image/Trust_E.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cocofeat/TBraTS/5255b7c6c8d338c2c1af72368169054d889a977c/image/Trust_E.gif -------------------------------------------------------------------------------- /image/nullfile: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /log/nullfile: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/criterions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch import nn 6 | import numpy as np 7 | 8 | def expand_target(x, n_class,mode='softmax'): 9 | """ 10 | Converts NxDxHxW label image to NxCxDxHxW, where each label is stored in a separate channel 11 | :param input: 4D input image (NxDxHxW) 12 | :param C: number of channels/labels 13 | :return: 5D output image (NxCxDxHxW) 14 | """ 15 | assert x.dim() == 4 16 | shape = list(x.size()) 17 | shape.insert(1, n_class) 18 | shape = tuple(shape) 19 | xx = torch.zeros(shape) 20 | if mode.lower() == 'softmax': 21 | xx[:, 1, :, :, :] = (x == 1) 22 | xx[:, 2, :, :, :] = (x == 2) 23 | xx[:, 3, :, :, :] = (x == 3) 24 | if mode.lower() == 'sigmoid': 25 | xx[:, 0, :, :, :] = (x == 1) 26 | xx[:, 1, :, :, :] = (x == 2) 27 | xx[:, 2, :, :, :] = (x == 3) 28 | return xx.to(x.device) 29 | 30 | def flatten(tensor): 31 | """Flattens a given tensor such that the channel axis is first. 32 | The shapes are transformed as follows: 33 | (N, C, D, H, W) -> (C, N * D * H * W) 34 | """ 35 | C = tensor.size(1) 36 | # new axis order 37 | axis_order = (1, 0) + tuple(range(2, tensor.dim())) 38 | # Transpose: (N, C, D, H, W) -> (C, N, D, H, W) 39 | transposed = tensor.permute(axis_order) 40 | # Flatten: (C, N, D, H, W) -> (C, N * D * H * W) 41 | return transposed.reshape(C, -1) 42 | 43 | def Dice(output, target, eps=1e-5): 44 | target = target.float() 45 | num = 2 * (output * target).sum() 46 | den = output.sum() + target.sum() + eps 47 | return 1.0 - num/den 48 | 49 | def sum_tensor(inp, axes, keepdim=False): 50 | axes = np.unique(axes).astype(int) 51 | if keepdim: 52 | for ax in axes: 53 | inp = inp.sum(int(ax), keepdim=True) 54 | else: 55 | for ax in sorted(axes, reverse=True): 56 | inp = inp.sum(int(ax)) 57 | return inp 58 | 59 | def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False): 60 | """ 61 | net_output must be (b, c, x, y(, z))) 62 | gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z)) 63 | if mask is provided it must have shape (b, 1, x, y(, z))) 64 | :param net_output: 65 | :param gt: 66 | :param axes: can be (, ) = no summation 67 | :param mask: mask must be 1 for valid pixels and 0 for invalid pixels 68 | :param square: if True then fp, tp and fn will be squared before summation 69 | :return: 70 | """ 71 | if axes is None: 72 | axes = tuple(range(2, len(net_output.size()))) 73 | 74 | shp_x = net_output.shape 75 | shp_y = gt.shape 76 | 77 | with torch.no_grad(): 78 | if len(shp_x) != len(shp_y): 79 | gt = gt.view((shp_y[0], 1, *shp_y[1:])) 80 | 81 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]): 82 | # if this is the case then gt is probably already a one hot encoding 83 | y_onehot = gt 84 | else: 85 | gt = gt.long() 86 | y_onehot = torch.zeros(shp_x) 87 | if net_output.device.type == "cuda": 88 | y_onehot = y_onehot.cuda(net_output.device.index) 89 | y_onehot.scatter_(1, gt, 1) 90 | 91 | tp = net_output * y_onehot 92 | fp = net_output * (1 - y_onehot) 93 | fn = (1 - net_output) * y_onehot 94 | tn = (1 - net_output) * (1 - y_onehot) 95 | 96 | if mask is not None: 97 | tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1) 98 | fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1) 99 | fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1) 100 | tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1) 101 | 102 | if square: 103 | tp = tp ** 2 104 | fp = fp ** 2 105 | fn = fn ** 2 106 | tn = tn ** 2 107 | 108 | if len(axes) > 0: 109 | tp = sum_tensor(tp, axes, keepdim=False) 110 | fp = sum_tensor(fp, axes, keepdim=False) 111 | fn = sum_tensor(fn, axes, keepdim=False) 112 | tn = sum_tensor(tn, axes, keepdim=False) 113 | 114 | return tp, fp, fn, tn 115 | 116 | class SoftDiceLoss(nn.Module): 117 | def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.): 118 | """ 119 | """ 120 | super(SoftDiceLoss, self).__init__() 121 | 122 | self.do_bg = do_bg 123 | self.batch_dice = batch_dice 124 | self.apply_nonlin = apply_nonlin 125 | self.smooth = smooth 126 | 127 | def forward(self, x, y, loss_mask=None): 128 | shp_x = x.shape 129 | 130 | if self.batch_dice: 131 | axes = [0] + list(range(2, len(shp_x))) 132 | else: 133 | axes = list(range(2, len(shp_x))) 134 | 135 | if self.apply_nonlin is not None: 136 | x = self.apply_nonlin(x) 137 | 138 | tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, axes, loss_mask, False) 139 | 140 | nominator = 2 * tp + self.smooth 141 | denominator = 2 * tp + fp + fn + self.smooth 142 | 143 | dc = 1 - (nominator / (denominator + 1e-5)) 144 | 145 | if not self.do_bg: 146 | if self.batch_dice: 147 | dc = dc[1:] 148 | else: 149 | dc = dc[:, 1:] 150 | dc = dc.mean() 151 | # changed by coco 152 | return dc 153 | 154 | class softBCE_dice(nn.Module): 155 | def __init__(self, aggregate="sum",weight_ce=0.5, weight_dice=0.5): 156 | """ 157 | DO NOT APPLY NONLINEARITY IN YOUR NETWORK! 158 | 159 | THIS LOSS IS INTENDED TO BE USED FOR BRATS REGIONS ONLY 160 | :param soft_dice_kwargs: 161 | :param bce_kwargs: 162 | :param aggregate: 163 | """ 164 | super(softBCE_dice, self).__init__() 165 | 166 | self.aggregate = aggregate 167 | # self.ce = nn.BCEWithLogitsLoss(**bce_kwargs) 168 | self.Bce = torch.nn.BCELoss() 169 | self.weight_dice = weight_dice 170 | self.weight_ce = weight_ce 171 | self.sigmoid = nn.Sigmoid() 172 | self.dc = SoftDiceLoss() 173 | def forward(self, output, target): 174 | # ce_loss = self.ce(net_output, target) 175 | Bce_loss1 = self.Bce(self.sigmoid(output[:, 1, ...]), (target == 1).float()) 176 | Bce_loss2 = self.Bce(self.sigmoid(output[:, 2, ...]), (target == 2).float()) 177 | Bce_loss3 = self.Bce(self.sigmoid(output[:, 3, ...]), (target == 3).float()) 178 | # Diceloss1 = Dice(output[:, 1, ...], (target == 1).float()) 179 | # Diceloss2 = Dice(output[:, 2, ...], (target == 2).float()) 180 | # Diceloss3 = Dice(output[:, 3, ...], (target == 4).float()) 181 | Diceloss1 = self.dc(output[:, 1, ...], (target == 1).float()) 182 | Diceloss2 = self.dc(output[:, 2, ...], (target == 2).float()) 183 | Diceloss3 = self.dc(output[:, 3, ...], (target == 3).float()) 184 | if self.aggregate == "sum": 185 | result1 = self.weight_ce * Bce_loss1 + self.weight_dice * Diceloss1 186 | result2 = self.weight_ce * Bce_loss2 + self.weight_dice * Diceloss2 187 | result3 = self.weight_ce * Bce_loss3 + self.weight_dice * Diceloss3 188 | else: 189 | raise NotImplementedError("nah son") # reserved for other stuff (later) 190 | 191 | return result1+result2+result3, 1-result1.data, 1-result2.data, 1-result3.data 192 | 193 | def softmaxBCE_dice(output, target): 194 | ''' 195 | The dice loss for using softmax activation function 196 | :param output: (b, num_class, d, h, w) 197 | :param target: (b, d, h, w) 198 | :return: softmax dice loss torch.nn.BCELoss() 199 | ''' 200 | Bce = torch.nn.BCELoss() 201 | Diceloss1 = Dice(output[:, 1, ...], (target == 1).float()) 202 | Bceloss1 = Bce(output[:, 1, ...], (target == 1).float()) 203 | loss1 = Diceloss1 + Bceloss1 204 | Diceloss2 = Dice(output[:, 2, ...], (target == 2).float()) 205 | Bceloss2 = Bce(output[:, 2, ...], (target == 2).float()) 206 | loss2 = Diceloss2 + Bceloss2 207 | Diceloss3 = Dice(output[:, 3, ...], (target == 3).float()) 208 | Bceloss3 = Bce(output[:, 3, ...], (target == 3).float()) 209 | loss3 = Diceloss3 + Bceloss3 210 | 211 | return loss1 + loss2 + loss3, 1-loss1.data, 1-loss2.data, 1-loss3.data 212 | 213 | # loss function 214 | 215 | 216 | 217 | def TDice(output, target,criterion_dl): 218 | dice = criterion_dl(output, target) 219 | return dice 220 | 221 | def TFocal(output, target,criterion_fl): 222 | focal = criterion_fl(output, target) 223 | return focal 224 | 225 | def focal_dce_eviloss(p, alpha, c, global_step, annealing_step): 226 | # dice focal loss 227 | criterion_dl = DiceLoss() 228 | L_dice = TDice(alpha,p,criterion_dl) 229 | criterion_fl = FocalLoss(4) 230 | L_focal = TFocal(alpha, p, criterion_fl) 231 | alpha = alpha.view(alpha.size(0), alpha.size(1), -1) # [N, C, HW] 232 | alpha = alpha.transpose(1, 2) # [N, HW, C] 233 | alpha = alpha.contiguous().view(-1, alpha.size(2)) 234 | S = torch.sum(alpha, dim=1, keepdim=True) 235 | E = alpha - 1 236 | label = F.one_hot(p, num_classes=c) 237 | label = label.view(-1, c) 238 | # digama loss 239 | L_ace = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True) 240 | # log loss 241 | # labelK = label * (torch.log(S) - torch.log(alpha)) 242 | # L_ace = torch.sum(label * (torch.log(S) - torch.log(alpha)), dim=1, keepdim=True) 243 | 244 | annealing_coef = min(1, global_step / annealing_step) 245 | alp = E * (1 - label) + 1 246 | L_KL = annealing_coef * KL(alp, c) 247 | 248 | return (L_ace + L_dice + L_focal + L_KL) 249 | 250 | def dce_eviloss(p, alpha, c, global_step, annealing_step): 251 | criterion_dl = DiceLoss() 252 | # L_dice = TDice(alpha,p,criterion_dl) 253 | L_dice,_,_,_ = softmax_dice(alpha, p) 254 | alpha = alpha.view(alpha.size(0), alpha.size(1), -1) # [N, C, HW] 255 | alpha = alpha.transpose(1, 2) # [N, HW, C] 256 | alpha = alpha.contiguous().view(-1, alpha.size(2)) 257 | S = torch.sum(alpha, dim=1, keepdim=True) 258 | E = alpha - 1 259 | label = F.one_hot(p, num_classes=c) 260 | label = label.view(-1, c) 261 | # digama loss 262 | L_ace = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True) 263 | # log loss 264 | # labelK = label * (torch.log(S) - torch.log(alpha)) 265 | # L_ace = torch.sum(label * (torch.log(S) - torch.log(alpha)), dim=1, keepdim=True) 266 | 267 | annealing_coef = min(1, global_step / annealing_step) 268 | alp = E * (1 - label) + 1 269 | L_KL = annealing_coef * KL(alp, c) 270 | 271 | return (L_ace + L_dice + L_KL) 272 | 273 | 274 | def dce_loss(p, alpha, c, global_step, annealing_step): 275 | criterion_dl = DiceLoss() 276 | L_dice = TDice(alpha,p,criterion_dl) 277 | 278 | return L_dice 279 | def ce_loss(p, alpha, c, global_step, annealing_step): 280 | alpha = alpha.view(alpha.size(0), alpha.size(1), -1) # [N, C, HW] 281 | alpha = alpha.transpose(1, 2) # [N, HW, C] 282 | alpha = alpha.contiguous().view(-1, alpha.size(2)) 283 | 284 | # 0.0 permute all 285 | # alpha = alpha.permute(0,2,3,4,1).view(-1, c) 286 | S = torch.sum(alpha, dim=1, keepdim=True) 287 | E = alpha - 1 288 | label = F.one_hot(p, num_classes=c) 289 | label = label.view(-1, c) 290 | # S = S.permute(0, 2, 3, 4, 1) 291 | # alpha = alpha.permute(0, 2, 3, 4, 1) 292 | # label_K = label * (torch.digamma(S) - torch.digamma(alpha)) 293 | # digama loss 294 | L_ace = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True) 295 | # log loss 296 | # labelK = label * (torch.log(S) - torch.log(alpha)) 297 | # L_ace = torch.sum(label * (torch.log(S) - torch.log(alpha)), dim=1, keepdim=True) 298 | 299 | annealing_coef = min(1, global_step / annealing_step) 300 | # label = label.permute(0, 4, 1, 2, 3) 301 | alp = E * (1 - label) + 1 302 | # alp = E.permute(0, 2, 3, 4, 1) * (1 - label) + 1 303 | L_KL = annealing_coef * KL(alp, c) 304 | 305 | return (L_ace + L_KL) 306 | # return L_ace 307 | 308 | def KL(alpha, c): 309 | S_alpha = torch.sum(alpha, dim=1, keepdim=True) 310 | beta = torch.ones((1, c)).cuda() 311 | # Mbeta = torch.ones((alpha.shape[0],c)).cuda() 312 | S_beta = torch.sum(beta, dim=1, keepdim=True) 313 | lnB = torch.lgamma(S_alpha) - torch.sum(torch.lgamma(alpha), dim=1, keepdim=True) 314 | lnB_uni = torch.sum(torch.lgamma(beta), dim=1, keepdim=True) - torch.lgamma(S_beta) 315 | dg0 = torch.digamma(S_alpha) 316 | dg1 = torch.digamma(alpha) 317 | kl = torch.sum((alpha - beta) * (dg1 - dg0), dim=1, keepdim=True) + lnB + lnB_uni 318 | return kl 319 | 320 | def mse_loss(p, alpha, c, global_step, annealing_step=1): 321 | S = torch.sum(alpha, dim=1, keepdim=True) 322 | E = alpha - 1 323 | m = alpha / S 324 | label = F.one_hot(p, num_classes=c) 325 | A = torch.sum((label - m) ** 2, dim=1, keepdim=True) 326 | B = torch.sum(alpha * (S - alpha) / (S * S * (S + 1)), dim=1, keepdim=True) 327 | annealing_coef = min(1, global_step / annealing_step) 328 | alp = E * (1 - label) + 1 329 | C = annealing_coef * KL(alp, c) 330 | return (A + B) + C 331 | 332 | class DiceLoss(nn.Module): 333 | 334 | def __init__(self, alpha=0.5, beta=0.5, size_average=True, reduce=True): 335 | super(DiceLoss, self).__init__() 336 | self.alpha = alpha 337 | self.beta = beta 338 | 339 | self.size_average = size_average 340 | self.reduce = reduce 341 | 342 | def forward(self, preds, targets, weight=False): 343 | N = preds.size(0) 344 | C = preds.size(1) 345 | 346 | preds = preds.permute(0, 2, 3, 4, 1).contiguous().view(-1, C) 347 | if targets.size(1)==4: 348 | targets = targets.permute(0, 2, 3, 4, 1).contiguous().view(-1, C) 349 | else: 350 | targets = targets.view(-1, 1) 351 | 352 | log_P = F.log_softmax(preds, dim=1) 353 | P = torch.exp(log_P) 354 | # P = F.softmax(preds, dim=1) 355 | smooth = torch.zeros(C, dtype=torch.float32).fill_(0.00001) 356 | 357 | class_mask = torch.zeros(preds.shape).to(preds.device) + 1e-8 358 | class_mask.scatter_(1, targets, 1.) 359 | 360 | ones = torch.ones(preds.shape).to(preds.device) 361 | P_ = ones - P 362 | class_mask_ = ones - class_mask 363 | 364 | TP = P * class_mask 365 | FP = P * class_mask_ 366 | FN = P_ * class_mask 367 | 368 | smooth = smooth.to(preds.device) 369 | self.alpha = FP.sum(dim=(0)) / ((FP.sum(dim=(0)) + FN.sum(dim=(0))) + smooth) 370 | 371 | self.alpha = torch.clamp(self.alpha, min=0.2, max=0.8) 372 | #print('alpha:', self.alpha) 373 | self.beta = 1 - self.alpha 374 | num = torch.sum(TP, dim=(0)).float() 375 | den = num + self.alpha * torch.sum(FP, dim=(0)).float() + self.beta * torch.sum(FN, dim=(0)).float() 376 | 377 | dice = num / (den + smooth) 378 | 379 | if not self.reduce: 380 | loss = torch.ones(C).to(dice.device) - dice 381 | return loss 382 | loss = 1 - dice 383 | if weight is not False: 384 | loss *= weight.squeeze(0) 385 | loss = loss.sum() 386 | if self.size_average: 387 | if weight is not False: 388 | loss /= weight.squeeze(0).sum() 389 | else: 390 | loss /= C 391 | 392 | return loss 393 | 394 | class FocalLoss(nn.Module): 395 | def __init__(self, class_num, alpha=None, gamma=2, size_average=True): 396 | super(FocalLoss, self).__init__() 397 | 398 | if alpha is None: 399 | self.alpha = torch.ones(class_num, 1).cuda() 400 | else: 401 | self.alpha = alpha 402 | 403 | self.gamma = gamma 404 | self.size_average = size_average 405 | 406 | def forward(self, preds, targets, weight=False): 407 | N = preds.size(0) 408 | C = preds.size(1) 409 | 410 | preds = preds.permute(0, 2, 3, 4, 1).contiguous().view(-1, C) 411 | targets = targets.view(-1, 1) 412 | 413 | log_P = F.log_softmax(preds, dim=1) 414 | P = torch.exp(log_P) 415 | # P = F.softmax(preds, dim=1) 416 | # log_P = F.log_softmax(preds, dim=1) 417 | # class_mask = torch.zeros(preds.shape).to(preds.device) + 1e-8 418 | class_mask = torch.zeros(preds.shape).to(preds.device) # problem 419 | class_mask.scatter_(1, targets, 1.) 420 | # number = torch.unique(targets) 421 | alpha = self.alpha[targets.data.view(-1)] # problem alpha: weight of data 422 | # alpha = self.alpha.gather(0, targets.view(-1)) 423 | 424 | probs = (P * class_mask).sum(1).view(-1, 1) # problem 425 | log_probs = (log_P * class_mask).sum(1).view(-1, 1) 426 | 427 | # probs = P.gather(1,targets.view(-1,1)) # 这部分实现nll_loss ( crossempty = log_softmax + nll ) 428 | # log_probs = log_P.gather(1,targets.view(-1,1)) 429 | 430 | batch_loss = -alpha * (1-probs).pow(self.gamma)*log_probs 431 | if weight is not False: 432 | element_weight = weight.squeeze(0)[targets.squeeze(0)] 433 | batch_loss = batch_loss * element_weight 434 | 435 | if self.size_average: 436 | loss = batch_loss.mean() 437 | else: 438 | loss = batch_loss.sum() 439 | 440 | return loss 441 | 442 | def softmax_dice(output, target): 443 | ''' 444 | The dice loss for using softmax activation function 445 | :param output: (b, num_class, d, h, w) 446 | :param target: (b, d, h, w) 447 | :return: softmax dice loss torch.nn.BCELoss() 448 | ''' 449 | 450 | loss1 = Dice(output[:, 1, ...], (target == 1).float()) 451 | loss2 = Dice(output[:, 2, ...], (target == 2).float()) 452 | loss3 = Dice(output[:, 3, ...], (target == 3).float()) 453 | 454 | return loss1 + loss2 + loss3, 1-loss1.data, 1-loss2.data, 1-loss3.data 455 | 456 | def softmax_dice2(output, target): 457 | ''' 458 | The dice loss for using softmax activation function 459 | :param output: (b, num_class, d, h, w) 460 | :param target: (b, d, h, w) 461 | :return: softmax dice loss 462 | ''' 463 | loss0 = Dice(output[:, 0, ...], (target == 0).float()) 464 | loss1 = Dice(output[:, 1, ...], (target == 1).float()) 465 | loss2 = Dice(output[:, 2, ...], (target == 2).float()) 466 | loss3 = Dice(output[:, 3, ...], (target == 3).float()) 467 | 468 | return loss1 + loss2 + loss3 + loss0, 1-loss1.data, 1-loss2.data, 1-loss3.data 469 | 470 | 471 | def sigmoid_dice(output, target): 472 | ''' 473 | The dice loss for using sigmoid activation function 474 | :param output: (b, num_class-1, d, h, w) 475 | :param target: (b, d, h, w) 476 | :return: 477 | ''' 478 | loss1 = Dice(output[:, 0, ...], (target == 1).float()) 479 | loss2 = Dice(output[:, 1, ...], (target == 2).float()) 480 | loss3 = Dice(output[:, 2, ...], (target == 3).float()) 481 | 482 | return loss1 + loss2 + loss3, 1-loss1.data, 1-loss2.data, 1-loss3.data 483 | 484 | 485 | def Generalized_dice(output, target, eps=1e-5, weight_type='square'): 486 | if target.dim() == 4: #(b, h, w, d) 487 | target[target == 4] = 3 #transfer label 4 to 3 488 | target = expand_target(target, n_class=output.size()[1]) #extend target from (b, h, w, d) to (b, c, h, w, d) 489 | 490 | output = flatten(output)[1:, ...] # transpose [N,4,H,W,D] -> [4,N,H,W,D] -> [3, N*H*W*D] voxels 491 | target = flatten(target)[1:, ...] # [class, N*H*W*D] 492 | 493 | target_sum = target.sum(-1) # sub_class_voxels [3,1] -> 3个voxels 494 | if weight_type == 'square': 495 | class_weights = 1. / (target_sum * target_sum + eps) 496 | elif weight_type == 'identity': 497 | class_weights = 1. / (target_sum + eps) 498 | elif weight_type == 'sqrt': 499 | class_weights = 1. / (torch.sqrt(target_sum) + eps) 500 | else: 501 | raise ValueError('Check out the weight_type :', weight_type) 502 | 503 | # print(class_weights) 504 | intersect = (output * target).sum(-1) 505 | intersect_sum = (intersect * class_weights).sum() 506 | denominator = (output + target).sum(-1) 507 | denominator_sum = (denominator * class_weights).sum() + eps 508 | 509 | loss1 = 2*intersect[0] / (denominator[0] + eps) 510 | loss2 = 2*intersect[1] / (denominator[1] + eps) 511 | loss3 = 2*intersect[2] / (denominator[2] + eps) 512 | 513 | return 1 - 2. * intersect_sum / denominator_sum, loss1, loss2, loss3 514 | -------------------------------------------------------------------------------- /models/eval: -------------------------------------------------------------------------------- 1 | # Forked from https://github.com/GewelsJI/VPS/blob/main/eval/metrics.py 2 | import numpy as np 3 | from scipy.ndimage import convolve, distance_transform_edt as bwdist 4 | 5 | 6 | _EPS = np.spacing(1) 7 | _TYPE = np.float64 8 | 9 | 10 | def _prepare_data(pred: np.ndarray, gt: np.ndarray) -> tuple: 11 | gt = gt > 128 12 | pred = pred / 255 13 | if pred.max() != pred.min(): 14 | pred = (pred - pred.min()) / (pred.max() - pred.min()) 15 | return pred, gt 16 | 17 | 18 | def _get_adaptive_threshold(matrix: np.ndarray, max_value: float = 1) -> float: 19 | return min(2 * matrix.mean(), max_value) 20 | 21 | 22 | class Fmeasure(object): 23 | def __init__(self, length, beta: float = 0.3): 24 | self.beta = beta 25 | self.precisions = [] 26 | self.recalls = [] 27 | self.adaptive_fms = [] 28 | self.changeable_fms = [] 29 | 30 | def step(self, pred: np.ndarray, gt: np.ndarray, idx): 31 | pred, gt = _prepare_data(pred, gt) 32 | 33 | adaptive_fm = self.cal_adaptive_fm(pred=pred, gt=gt) 34 | self.adaptive_fms.append(adaptive_fm) 35 | 36 | precisions, recalls, changeable_fms = self.cal_pr(pred=pred, gt=gt) 37 | self.precisions.append(precisions) 38 | self.recalls.append(recalls) 39 | self.changeable_fms.append(changeable_fms) 40 | 41 | def cal_adaptive_fm(self, pred: np.ndarray, gt: np.ndarray) -> float: 42 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) 43 | binary_predcition = pred >= adaptive_threshold 44 | area_intersection = binary_predcition[gt].sum() 45 | if area_intersection == 0: 46 | adaptive_fm = 0 47 | else: 48 | pre = area_intersection / np.count_nonzero(binary_predcition) 49 | rec = area_intersection / np.count_nonzero(gt) 50 | # F_beta measure 51 | adaptive_fm = (1 + self.beta) * pre * rec / (self.beta * pre + rec) 52 | return adaptive_fm 53 | 54 | def cal_pr(self, pred: np.ndarray, gt: np.ndarray) -> tuple: 55 | pred = (pred * 255).astype(np.uint8) 56 | bins = np.linspace(0, 256, 257) 57 | fg_hist, _ = np.histogram(pred[gt], bins=bins) 58 | bg_hist, _ = np.histogram(pred[~gt], bins=bins) 59 | fg_w_thrs = np.cumsum(np.flip(fg_hist), axis=0) 60 | bg_w_thrs = np.cumsum(np.flip(bg_hist), axis=0) 61 | TPs = fg_w_thrs 62 | Ps = fg_w_thrs + bg_w_thrs 63 | Ps[Ps == 0] = 1 64 | T = max(np.count_nonzero(gt), 1) 65 | precisions = TPs / Ps 66 | recalls = TPs / T 67 | numerator = (1 + self.beta) * precisions * recalls 68 | denominator = np.where(numerator == 0, 1, self.beta * precisions + recalls) 69 | changeable_fms = numerator / denominator 70 | return precisions, recalls, changeable_fms 71 | 72 | def get_results(self): 73 | adaptive_fm = np.mean(np.array(self.adaptive_fms, _TYPE)) 74 | # precision = np.mean(np.array(self.precisions, dtype=_TYPE), axis=0) # N, 256 75 | # recall = np.mean(np.array(self.recalls, dtype=_TYPE), axis=0) # N, 256 76 | changeable_fm = np.mean(np.array(self.changeable_fms, dtype=_TYPE), axis=0) 77 | # return dict(fm=dict(adp=adaptive_fm, curve=changeable_fm), 78 | # pr=dict(p=precision, r=recall)) 79 | return dict(adpFm=adaptive_fm, meanFm=changeable_fm, maxFm=changeable_fm) 80 | 81 | 82 | class MAE(object): 83 | def __init__(self, length): 84 | self.maes = [] 85 | 86 | def step(self, pred: np.ndarray, gt: np.ndarray, idx): 87 | pred, gt = _prepare_data(pred, gt) 88 | 89 | mae = self.cal_mae(pred, gt) 90 | self.maes.append(mae) 91 | 92 | def cal_mae(self, pred: np.ndarray, gt: np.ndarray) -> float: 93 | mae = np.mean(np.abs(pred - gt)) 94 | return mae 95 | 96 | def get_results(self): 97 | mae = np.mean(np.array(self.maes, _TYPE)) 98 | return dict(MAE=mae) 99 | 100 | 101 | class Smeasure(object): 102 | def __init__(self, length, alpha: float = 0.5): 103 | self.sms = [] 104 | self.alpha = alpha 105 | 106 | def step(self, pred: np.ndarray, gt: np.ndarray, idx): 107 | pred, gt = _prepare_data(pred=pred, gt=gt) 108 | 109 | sm = self.cal_sm(pred, gt) 110 | self.sms.append(sm) 111 | 112 | def cal_sm(self, pred: np.ndarray, gt: np.ndarray) -> float: 113 | y = np.mean(gt) 114 | if y == 0: 115 | sm = 1 - np.mean(pred) 116 | elif y == 1: 117 | sm = np.mean(pred) 118 | else: 119 | sm = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt) 120 | sm = max(0, sm) 121 | return sm 122 | 123 | def object(self, pred: np.ndarray, gt: np.ndarray) -> float: 124 | fg = pred * gt 125 | bg = (1 - pred) * (1 - gt) 126 | u = np.mean(gt) 127 | object_score = u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, 1 - gt) 128 | return object_score 129 | 130 | def s_object(self, pred: np.ndarray, gt: np.ndarray) -> float: 131 | x = np.mean(pred[gt == 1]) 132 | sigma_x = np.std(pred[gt == 1], ddof=1) 133 | score = 2 * x / (np.power(x, 2) + 1 + sigma_x + _EPS) 134 | return score 135 | 136 | def region(self, pred: np.ndarray, gt: np.ndarray) -> float: 137 | x, y = self.centroid(gt) 138 | part_info = self.divide_with_xy(pred, gt, x, y) 139 | w1, w2, w3, w4 = part_info['weight'] 140 | pred1, pred2, pred3, pred4 = part_info['pred'] 141 | gt1, gt2, gt3, gt4 = part_info['gt'] 142 | score1 = self.ssim(pred1, gt1) 143 | score2 = self.ssim(pred2, gt2) 144 | score3 = self.ssim(pred3, gt3) 145 | score4 = self.ssim(pred4, gt4) 146 | 147 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 148 | 149 | def centroid(self, matrix: np.ndarray) -> tuple: 150 | """ 151 | To ensure consistency with the matlab code, one is added to the centroid coordinate, 152 | so there is no need to use the redundant addition operation when dividing the region later, 153 | because the sequence generated by ``1:X`` in matlab will contain ``X``. 154 | :param matrix: a bool data array 155 | :return: the centroid coordinate 156 | """ 157 | h, w = matrix.shape 158 | area_object = np.count_nonzero(matrix) 159 | if area_object == 0: 160 | x = np.round(w / 2) 161 | y = np.round(h / 2) 162 | else: 163 | # More details can be found at: https://www.yuque.com/lart/blog/gpbigm 164 | y, x = np.argwhere(matrix).mean(axis=0).round() 165 | return int(x) + 1, int(y) + 1 166 | 167 | def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x, y) -> dict: 168 | h, w = gt.shape 169 | area = h * w 170 | 171 | gt_LT = gt[0:y, 0:x] 172 | gt_RT = gt[0:y, x:w] 173 | gt_LB = gt[y:h, 0:x] 174 | gt_RB = gt[y:h, x:w] 175 | 176 | pred_LT = pred[0:y, 0:x] 177 | pred_RT = pred[0:y, x:w] 178 | pred_LB = pred[y:h, 0:x] 179 | pred_RB = pred[y:h, x:w] 180 | 181 | w1 = x * y / area 182 | w2 = y * (w - x) / area 183 | w3 = (h - y) * x / area 184 | w4 = 1 - w1 - w2 - w3 185 | 186 | return dict(gt=(gt_LT, gt_RT, gt_LB, gt_RB), 187 | pred=(pred_LT, pred_RT, pred_LB, pred_RB), 188 | weight=(w1, w2, w3, w4)) 189 | 190 | def ssim(self, pred: np.ndarray, gt: np.ndarray) -> float: 191 | h, w = pred.shape 192 | N = h * w 193 | 194 | x = np.mean(pred) 195 | y = np.mean(gt) 196 | 197 | sigma_x = np.sum((pred - x) ** 2) / (N - 1) 198 | sigma_y = np.sum((gt - y) ** 2) / (N - 1) 199 | sigma_xy = np.sum((pred - x) * (gt - y)) / (N - 1) 200 | 201 | alpha = 4 * x * y * sigma_xy 202 | beta = (x ** 2 + y ** 2) * (sigma_x + sigma_y) 203 | 204 | if alpha != 0: 205 | score = alpha / (beta + _EPS) 206 | elif alpha == 0 and beta == 0: 207 | score = 1 208 | else: 209 | score = 0 210 | return score 211 | 212 | def get_results(self): 213 | sm = np.mean(np.array(self.sms, dtype=_TYPE)) 214 | return dict(Smeasure=sm) 215 | 216 | 217 | class Emeasure(object): 218 | def __init__(self, length): 219 | self.adaptive_ems = [] 220 | self.changeable_ems = [] 221 | 222 | def step(self, pred: np.ndarray, gt: np.ndarray, idx): 223 | pred, gt = _prepare_data(pred=pred, gt=gt) 224 | self.gt_fg_numel = np.count_nonzero(gt) 225 | self.gt_size = gt.shape[0] * gt.shape[1] 226 | 227 | changeable_ems = self.cal_changeable_em(pred, gt) 228 | self.changeable_ems.append(changeable_ems) 229 | adaptive_em = self.cal_adaptive_em(pred, gt) 230 | self.adaptive_ems.append(adaptive_em) 231 | 232 | def cal_adaptive_em(self, pred: np.ndarray, gt: np.ndarray) -> float: 233 | adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) 234 | adaptive_em = self.cal_em_with_threshold(pred, gt, threshold=adaptive_threshold) 235 | return adaptive_em 236 | 237 | def cal_changeable_em(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 238 | changeable_ems = self.cal_em_with_cumsumhistogram(pred, gt) 239 | return changeable_ems 240 | 241 | def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float: 242 | binarized_pred = pred >= threshold 243 | fg_fg_numel = np.count_nonzero(binarized_pred & gt) 244 | fg_bg_numel = np.count_nonzero(binarized_pred & ~gt) 245 | 246 | fg___numel = fg_fg_numel + fg_bg_numel 247 | bg___numel = self.gt_size - fg___numel 248 | 249 | if self.gt_fg_numel == 0: 250 | enhanced_matrix_sum = bg___numel 251 | elif self.gt_fg_numel == self.gt_size: 252 | enhanced_matrix_sum = fg___numel 253 | else: 254 | parts_numel, combinations = self.generate_parts_numel_combinations( 255 | fg_fg_numel=fg_fg_numel, fg_bg_numel=fg_bg_numel, 256 | pred_fg_numel=fg___numel, pred_bg_numel=bg___numel, 257 | ) 258 | 259 | results_parts = [] 260 | for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)): 261 | align_matrix_value = 2 * (combination[0] * combination[1]) / \ 262 | (combination[0] ** 2 + combination[1] ** 2 + _EPS) 263 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 264 | results_parts.append(enhanced_matrix_value * part_numel) 265 | enhanced_matrix_sum = sum(results_parts) 266 | 267 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) 268 | return em 269 | 270 | def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: 271 | pred = (pred * 255).astype(np.uint8) 272 | bins = np.linspace(0, 256, 257) 273 | fg_fg_hist, _ = np.histogram(pred[gt], bins=bins) 274 | fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins) 275 | fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0) 276 | fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0) 277 | 278 | fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs 279 | bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs 280 | 281 | if self.gt_fg_numel == 0: 282 | enhanced_matrix_sum = bg___numel_w_thrs 283 | elif self.gt_fg_numel == self.gt_size: 284 | enhanced_matrix_sum = fg___numel_w_thrs 285 | else: 286 | parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations( 287 | fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs, 288 | pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs, 289 | ) 290 | 291 | results_parts = np.empty(shape=(4, 256), dtype=np.float64) 292 | for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)): 293 | align_matrix_value = 2 * (combination[0] * combination[1]) / \ 294 | (combination[0] ** 2 + combination[1] ** 2 + _EPS) 295 | enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 296 | results_parts[i] = enhanced_matrix_value * part_numel 297 | enhanced_matrix_sum = results_parts.sum(axis=0) 298 | 299 | em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) 300 | return em 301 | 302 | def generate_parts_numel_combinations(self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel): 303 | bg_fg_numel = self.gt_fg_numel - fg_fg_numel 304 | bg_bg_numel = pred_bg_numel - bg_fg_numel 305 | 306 | parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel] 307 | 308 | mean_pred_value = pred_fg_numel / self.gt_size 309 | mean_gt_value = self.gt_fg_numel / self.gt_size 310 | 311 | demeaned_pred_fg_value = 1 - mean_pred_value 312 | demeaned_pred_bg_value = 0 - mean_pred_value 313 | demeaned_gt_fg_value = 1 - mean_gt_value 314 | demeaned_gt_bg_value = 0 - mean_gt_value 315 | 316 | combinations = [ 317 | (demeaned_pred_fg_value, demeaned_gt_fg_value), 318 | (demeaned_pred_fg_value, demeaned_gt_bg_value), 319 | (demeaned_pred_bg_value, demeaned_gt_fg_value), 320 | (demeaned_pred_bg_value, demeaned_gt_bg_value) 321 | ] 322 | return parts_numel, combinations 323 | 324 | def get_results(self): 325 | adaptive_em = np.mean(np.array(self.adaptive_ems, dtype=_TYPE)) 326 | changeable_em = np.mean(np.array(self.changeable_ems, dtype=_TYPE), axis=0) 327 | return dict(adpEm=adaptive_em, meanEm=changeable_em, maxEm=changeable_em) 328 | 329 | 330 | class WeightedFmeasure(object): 331 | def __init__(self, length, beta: float = 1): 332 | self.beta = beta 333 | self.weighted_fms = [] 334 | 335 | def step(self, pred: np.ndarray, gt: np.ndarray, idx): 336 | pred, gt = _prepare_data(pred=pred, gt=gt) 337 | 338 | if np.all(~gt): 339 | wfm = 0 340 | else: 341 | wfm = self.cal_wfm(pred, gt) 342 | self.weighted_fms.append(wfm) 343 | 344 | def cal_wfm(self, pred: np.ndarray, gt: np.ndarray) -> float: 345 | # [Dst,IDXT] = bwdist(dGT); 346 | Dst, Idxt = bwdist(gt == 0, return_indices=True) 347 | 348 | # %Pixel dependency 349 | # E = abs(FG-dGT); 350 | E = np.abs(pred - gt) 351 | Et = np.copy(E) 352 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]] 353 | 354 | # K = fspecial('gaussian',7,5); 355 | # EA = imfilter(Et,K); 356 | K = self.matlab_style_gauss2D((7, 7), sigma=5) 357 | EA = convolve(Et, weights=K, mode="constant", cval=0) 358 | # MIN_E_EA = E; 359 | # MIN_E_EA(GT & EA np.ndarray: 379 | """ 380 | 2D gaussian mask - should give the same result as MATLAB's 381 | fspecial('gaussian',[shape],[sigma]) 382 | """ 383 | m, n = [(ss - 1) / 2 for ss in shape] 384 | y, x = np.ogrid[-m: m + 1, -n: n + 1] 385 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) 386 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 387 | sumh = h.sum() 388 | if sumh != 0: 389 | h /= sumh 390 | return h 391 | 392 | def get_results(self): 393 | weighted_fm = np.mean(np.array(self.weighted_fms, dtype=_TYPE)) 394 | return dict(wFmeasure=weighted_fm) 395 | 396 | 397 | class Medical(object): 398 | def __init__(self, length): 399 | self.Thresholds = np.linspace(1, 0, 256) 400 | 401 | self.threshold_Sensitivity = np.zeros((length, len(self.Thresholds))) 402 | self.threshold_Specificity = np.zeros((length, len(self.Thresholds))) 403 | self.threshold_Dice = np.zeros((length, len(self.Thresholds))) 404 | self.threshold_IoU = np.zeros((length, len(self.Thresholds))) 405 | 406 | def Fmeasure_calu(self, pred, gt, threshold): 407 | if threshold > 1: 408 | threshold = 1 409 | 410 | Label3 = np.zeros_like(gt) 411 | Label3[pred >= threshold] = 1 412 | 413 | NumRec = np.sum(Label3 == 1) 414 | NumNoRec = np.sum(Label3 == 0) 415 | 416 | LabelAnd = (Label3 == 1) & (gt == 1) 417 | NumAnd = np.sum(LabelAnd == 1) 418 | num_obj = np.sum(gt) 419 | num_pred = np.sum(Label3) 420 | 421 | FN = num_obj - NumAnd 422 | FP = NumRec - NumAnd 423 | TN = NumNoRec - FN 424 | 425 | if NumAnd == 0: 426 | RecallFtem = 0 427 | Dice = 0 428 | SpecifTem = 0 429 | IoU = 0 430 | 431 | else: 432 | IoU = NumAnd / (FN + NumRec) 433 | RecallFtem = NumAnd / num_obj 434 | SpecifTem = TN / (TN + FP) 435 | Dice = 2 * NumAnd / (num_obj + num_pred) 436 | 437 | return RecallFtem, SpecifTem, Dice, IoU 438 | 439 | def step(self, pred, gt, idx): 440 | pred, gt = _prepare_data(pred=pred, gt=gt) 441 | 442 | threshold_Rec = np.zeros(len(self.Thresholds)) 443 | threshold_Iou = np.zeros(len(self.Thresholds)) 444 | threshold_Spe = np.zeros(len(self.Thresholds)) 445 | threshold_Dic = np.zeros(len(self.Thresholds)) 446 | 447 | for j, threshold in enumerate(self.Thresholds): 448 | threshold_Rec[j], threshold_Spe[j], threshold_Dic[j], \ 449 | threshold_Iou[j] = self.Fmeasure_calu(pred, gt, threshold) 450 | 451 | self.threshold_Sensitivity[idx, :] = threshold_Rec 452 | self.threshold_Specificity[idx, :] = threshold_Spe 453 | self.threshold_Dice[idx, :] = threshold_Dic 454 | self.threshold_IoU[idx, :] = threshold_Iou 455 | 456 | def get_results(self): 457 | column_Sen = np.mean(self.threshold_Sensitivity, axis=0) 458 | column_Spe = np.mean(self.threshold_Specificity, axis=0) 459 | column_Dic = np.mean(self.threshold_Dice, axis=0) 460 | column_IoU = np.mean(self.threshold_IoU, axis=0) 461 | 462 | return dict(meanSen=column_Sen, meanSpe=column_Spe, meanDice=column_Dic, meanIoU=column_IoU, 463 | maxSen=column_Sen, maxSpe=column_Spe, maxDice=column_Dic, maxIoU=column_IoU) 464 | -------------------------------------------------------------------------------- /models/lib/IntmdSequential.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class IntermediateSequential(nn.Sequential): 5 | def __init__(self, *args, return_intermediate=True): 6 | super().__init__(*args) 7 | self.return_intermediate = return_intermediate 8 | 9 | def forward(self, input): 10 | if not self.return_intermediate: 11 | return super().forward(input) 12 | 13 | intermediate_outputs = {} 14 | output = input 15 | for name, module in self.named_children(): 16 | output = intermediate_outputs[name] = module(output) 17 | 18 | return output, intermediate_outputs 19 | 20 | -------------------------------------------------------------------------------- /models/lib/PositionalEncoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class FixedPositionalEncoding(nn.Module): 5 | def __init__(self, embedding_dim, max_length=512): 6 | super(FixedPositionalEncoding, self).__init__() 7 | 8 | pe = torch.zeros(max_length, embedding_dim) 9 | position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1) 10 | div_term = torch.exp( 11 | torch.arange(0, embedding_dim, 2).float() 12 | * (-torch.log(torch.tensor(10000.0)) / embedding_dim) 13 | ) 14 | pe[:, 0::2] = torch.sin(position * div_term) 15 | pe[:, 1::2] = torch.cos(position * div_term) 16 | pe = pe.unsqueeze(0).transpose(0, 1) 17 | self.register_buffer('pe', pe) 18 | 19 | def forward(self, x): 20 | x = x + self.pe[: x.size(0), :] 21 | return x 22 | 23 | 24 | class LearnedPositionalEncoding(nn.Module): 25 | def __init__(self, max_position_embeddings, embedding_dim, seq_length): 26 | super(LearnedPositionalEncoding, self).__init__() 27 | 28 | self.position_embeddings = nn.Parameter(torch.zeros(1, 4096, 512)) #8x 29 | 30 | def forward(self, x, position_ids=None): 31 | 32 | position_embeddings = self.position_embeddings 33 | return x + position_embeddings 34 | -------------------------------------------------------------------------------- /models/lib/TransBTS_downsample8x_skipconnection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.lib.Transformer import TransformerModel 4 | from models.lib.PositionalEncoding import FixedPositionalEncoding,LearnedPositionalEncoding 5 | from models.lib.Unet_skipconnection import Unet 6 | 7 | 8 | class TransformerBTS(nn.Module): 9 | def __init__( 10 | self, 11 | img_dim, 12 | patch_dim, 13 | num_channels, 14 | embedding_dim, 15 | num_heads, 16 | num_layers, 17 | hidden_dim, 18 | dropout_rate=0.0, 19 | attn_dropout_rate=0.0, 20 | conv_patch_representation=True, 21 | positional_encoding_type="learned", 22 | ): 23 | super(TransformerBTS, self).__init__() 24 | 25 | assert embedding_dim % num_heads == 0 26 | assert img_dim % patch_dim == 0 27 | 28 | self.img_dim = img_dim 29 | self.embedding_dim = embedding_dim 30 | self.num_heads = num_heads 31 | self.patch_dim = patch_dim 32 | self.num_channels = num_channels 33 | self.dropout_rate = dropout_rate 34 | self.attn_dropout_rate = attn_dropout_rate 35 | self.conv_patch_representation = conv_patch_representation 36 | 37 | self.num_patches = int((img_dim // patch_dim) ** 3) 38 | self.seq_length = self.num_patches 39 | self.flatten_dim = 128 * num_channels 40 | 41 | self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim) 42 | if positional_encoding_type == "learned": 43 | self.position_encoding = LearnedPositionalEncoding( 44 | self.seq_length, self.embedding_dim, self.seq_length 45 | ) 46 | elif positional_encoding_type == "fixed": 47 | self.position_encoding = FixedPositionalEncoding( 48 | self.embedding_dim, 49 | ) 50 | 51 | self.pe_dropout = nn.Dropout(p=self.dropout_rate) 52 | 53 | self.transformer = TransformerModel( 54 | embedding_dim, 55 | num_layers, 56 | num_heads, 57 | hidden_dim, 58 | 59 | self.dropout_rate, 60 | self.attn_dropout_rate, 61 | ) 62 | self.pre_head_ln = nn.LayerNorm(embedding_dim) 63 | 64 | if self.conv_patch_representation: 65 | 66 | self.conv_x = nn.Conv3d( 67 | 128, 68 | self.embedding_dim, 69 | kernel_size=3, 70 | stride=1, 71 | padding=1 72 | ) 73 | 74 | self.Unet = Unet(in_channels=num_channels, base_channels=16, num_classes=4) 75 | self.bn = nn.BatchNorm3d(128) 76 | self.relu = nn.ReLU(inplace=True) 77 | 78 | 79 | def encode(self, x): 80 | if self.conv_patch_representation: 81 | # combine embedding with conv patch distribution 82 | x1_1, x2_1, x3_1, x = self.Unet(x) 83 | x = self.bn(x) 84 | x = self.relu(x) 85 | x = self.conv_x(x) 86 | x = x.permute(0, 2, 3, 4, 1).contiguous() 87 | x = x.view(x.size(0), -1, self.embedding_dim) 88 | 89 | else: 90 | x = self.Unet(x) 91 | x = self.bn(x) 92 | x = self.relu(x) 93 | x = ( 94 | x.unfold(2, 2, 2) 95 | .unfold(3, 2, 2) 96 | .unfold(4, 2, 2) 97 | .contiguous() 98 | ) 99 | x = x.view(x.size(0), x.size(1), -1, 8) 100 | x = x.permute(0, 2, 3, 1).contiguous() 101 | x = x.view(x.size(0), -1, self.flatten_dim) 102 | x = self.linear_encoding(x) 103 | 104 | x = self.position_encoding(x) 105 | x = self.pe_dropout(x) 106 | 107 | # apply transformer 108 | x, intmd_x = self.transformer(x) 109 | x = self.pre_head_ln(x) 110 | 111 | return x1_1, x2_1, x3_1, x, intmd_x 112 | 113 | def decode(self, x): 114 | raise NotImplementedError("Should be implemented in child class!!") 115 | 116 | def forward(self, x, auxillary_output_layers=[1, 2, 3, 4]): 117 | 118 | x1_1, x2_1, x3_1, encoder_output, intmd_encoder_outputs = self.encode(x) 119 | 120 | decoder_output = self.decode( 121 | x1_1, x2_1, x3_1, encoder_output, intmd_encoder_outputs, auxillary_output_layers 122 | ) 123 | 124 | if auxillary_output_layers is not None: 125 | auxillary_outputs = {} 126 | for i in auxillary_output_layers: 127 | val = str(2 * i - 1) 128 | _key = 'Z' + str(i) 129 | auxillary_outputs[_key] = intmd_encoder_outputs[val] 130 | 131 | return decoder_output 132 | 133 | return decoder_output 134 | 135 | def _get_padding(self, padding_type, kernel_size): 136 | assert padding_type in ['SAME', 'VALID'] 137 | if padding_type == 'SAME': 138 | _list = [(k - 1) // 2 for k in kernel_size] 139 | return tuple(_list) 140 | return tuple(0 for _ in kernel_size) 141 | 142 | def _reshape_output(self, x): 143 | x = x.view( 144 | x.size(0), 145 | int(self.img_dim / self.patch_dim), 146 | int(self.img_dim / self.patch_dim), 147 | int(self.img_dim / self.patch_dim), 148 | self.embedding_dim, 149 | ) 150 | x = x.permute(0, 4, 1, 2, 3).contiguous() 151 | 152 | return x 153 | 154 | 155 | class BTS(TransformerBTS): 156 | def __init__( 157 | self, 158 | img_dim, 159 | patch_dim, 160 | num_channels, 161 | num_classes, 162 | embedding_dim, 163 | num_heads, 164 | num_layers, 165 | hidden_dim, 166 | dropout_rate=0.0, 167 | attn_dropout_rate=0.0, 168 | conv_patch_representation=True, 169 | positional_encoding_type="learned", 170 | ): 171 | super(BTS, self).__init__( 172 | img_dim=img_dim, 173 | patch_dim=patch_dim, 174 | num_channels=num_channels, 175 | embedding_dim=embedding_dim, 176 | num_heads=num_heads, 177 | num_layers=num_layers, 178 | hidden_dim=hidden_dim, 179 | dropout_rate=dropout_rate, 180 | attn_dropout_rate=attn_dropout_rate, 181 | conv_patch_representation=conv_patch_representation, 182 | positional_encoding_type=positional_encoding_type, 183 | ) 184 | 185 | self.num_classes = num_classes 186 | 187 | self.Softmax = nn.Softmax(dim=1) 188 | 189 | self.Enblock8_1 = EnBlock1(in_channels=self.embedding_dim) 190 | self.Enblock8_2 = EnBlock2(in_channels=self.embedding_dim // 4) 191 | 192 | self.DeUp4 = DeUp_Cat(in_channels=self.embedding_dim//4, out_channels=self.embedding_dim//8) 193 | self.DeBlock4 = DeBlock(in_channels=self.embedding_dim//8) 194 | 195 | self.DeUp3 = DeUp_Cat(in_channels=self.embedding_dim//8, out_channels=self.embedding_dim//16) 196 | self.DeBlock3 = DeBlock(in_channels=self.embedding_dim//16) 197 | 198 | self.DeUp2 = DeUp_Cat(in_channels=self.embedding_dim//16, out_channels=self.embedding_dim//32) 199 | self.DeBlock2 = DeBlock(in_channels=self.embedding_dim//32) 200 | 201 | self.endconv = nn.Conv3d(self.embedding_dim // 32, 4, kernel_size=1) 202 | 203 | 204 | def decode(self, x1_1, x2_1, x3_1, x, intmd_x, intmd_layers=[1, 2, 3, 4]): 205 | 206 | assert intmd_layers is not None, "pass the intermediate layers for MLA" 207 | encoder_outputs = {} 208 | all_keys = [] 209 | for i in intmd_layers: 210 | val = str(2 * i - 1) 211 | _key = 'Z' + str(i) 212 | all_keys.append(_key) 213 | encoder_outputs[_key] = intmd_x[val] 214 | all_keys.reverse() 215 | 216 | x8 = encoder_outputs[all_keys[0]] 217 | x8 = self._reshape_output(x8) 218 | x8 = self.Enblock8_1(x8) 219 | x8 = self.Enblock8_2(x8) 220 | 221 | y4 = self.DeUp4(x8, x3_1) # (1, 64, 32, 32, 32) 222 | y4 = self.DeBlock4(y4) 223 | 224 | y3 = self.DeUp3(y4, x2_1) # (1, 32, 64, 64, 64) 225 | y3 = self.DeBlock3(y3) 226 | 227 | y2 = self.DeUp2(y3, x1_1) # (1, 16, 128, 128, 128) 228 | y2 = self.DeBlock2(y2) 229 | 230 | y = self.endconv(y2) # (1, 4, 128, 128, 128) 231 | y = self.Softmax(y) 232 | return y 233 | 234 | class EnBlock1(nn.Module): 235 | def __init__(self, in_channels): 236 | super(EnBlock1, self).__init__() 237 | 238 | self.bn1 = nn.BatchNorm3d(512 // 4) 239 | self.relu1 = nn.ReLU(inplace=True) 240 | self.bn2 = nn.BatchNorm3d(512 // 4) 241 | self.relu2 = nn.ReLU(inplace=True) 242 | self.conv1 = nn.Conv3d(in_channels, in_channels // 4, kernel_size=3, padding=1) 243 | self.conv2 = nn.Conv3d(in_channels // 4, in_channels // 4, kernel_size=3, padding=1) 244 | 245 | def forward(self, x): 246 | x1 = self.conv1(x) 247 | x1 = self.bn1(x1) 248 | x1 = self.relu1(x1) 249 | x1 = self.conv2(x1) 250 | x1 = self.bn2(x1) 251 | x1 = self.relu2(x1) 252 | 253 | return x1 254 | 255 | 256 | class EnBlock2(nn.Module): 257 | def __init__(self, in_channels): 258 | super(EnBlock2, self).__init__() 259 | 260 | self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) 261 | self.bn1 = nn.BatchNorm3d(512 // 4) 262 | self.relu1 = nn.ReLU(inplace=True) 263 | self.bn2 = nn.BatchNorm3d(512 // 4) 264 | self.relu2 = nn.ReLU(inplace=True) 265 | self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) 266 | 267 | def forward(self, x): 268 | x1 = self.conv1(x) 269 | x1 = self.bn1(x1) 270 | x1 = self.relu1(x1) 271 | x1 = self.conv2(x1) 272 | x1 = self.bn2(x1) 273 | x1 = self.relu2(x1) 274 | x1 = x1 + x 275 | 276 | return x1 277 | 278 | 279 | class DeUp_Cat(nn.Module): 280 | def __init__(self, in_channels, out_channels): 281 | super(DeUp_Cat, self).__init__() 282 | self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=1) 283 | self.conv2 = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2) 284 | self.conv3 = nn.Conv3d(out_channels*2, out_channels, kernel_size=1) 285 | 286 | def forward(self, x, prev): 287 | x1 = self.conv1(x) 288 | y = self.conv2(x1) 289 | # y = y + prev 290 | y = torch.cat((prev, y), dim=1) 291 | y = self.conv3(y) 292 | return y 293 | 294 | class DeBlock(nn.Module): 295 | def __init__(self, in_channels): 296 | super(DeBlock, self).__init__() 297 | 298 | self.bn1 = nn.BatchNorm3d(in_channels) 299 | self.relu1 = nn.ReLU(inplace=True) 300 | self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) 301 | self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) 302 | self.bn2 = nn.BatchNorm3d(in_channels) 303 | self.relu2 = nn.ReLU(inplace=True) 304 | 305 | def forward(self, x): 306 | x1 = self.conv1(x) 307 | x1 = self.bn1(x1) 308 | x1 = self.relu1(x1) 309 | x1 = self.conv2(x1) 310 | x1 = self.bn2(x1) 311 | x1 = self.relu2(x1) 312 | x1 = x1 + x 313 | 314 | return x1 315 | 316 | 317 | 318 | 319 | def TransBTS(input_dims, _conv_repr=True, _pe_type="learned"): 320 | img_dim = 128 321 | num_classes = 4 322 | if input_dims == 'four': 323 | num_channels = 4 # 4 2 1 324 | else: 325 | num_channels = 1 326 | patch_dim = 8 327 | aux_layers = [1, 2, 3, 4] 328 | model = BTS( 329 | img_dim, 330 | patch_dim, 331 | num_channels, 332 | num_classes, 333 | embedding_dim=512, 334 | num_heads=8, 335 | num_layers=4, 336 | hidden_dim=2048, # 4096 337 | dropout_rate=0.1, 338 | attn_dropout_rate=0.1, 339 | conv_patch_representation=_conv_repr, 340 | positional_encoding_type=_pe_type, 341 | ) 342 | 343 | return aux_layers, model 344 | 345 | 346 | if __name__ == '__main__': 347 | with torch.no_grad(): 348 | import os 349 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 350 | cuda0 = torch.device('cuda:0') 351 | x = torch.rand((1, 4, 128, 128, 128), device=cuda0) 352 | _, model = TransBTS(dataset='brats', _conv_repr=True, _pe_type="learned") 353 | model.cuda() 354 | y = model(x) 355 | print(y.shape) 356 | -------------------------------------------------------------------------------- /models/lib/Transformer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from models.lib.IntmdSequential import IntermediateSequential 3 | 4 | 5 | class SelfAttention(nn.Module): 6 | def __init__( 7 | self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0 8 | ): 9 | super().__init__() 10 | self.num_heads = heads 11 | head_dim = dim // heads 12 | self.scale = qk_scale or head_dim ** -0.5 13 | 14 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 15 | self.attn_drop = nn.Dropout(dropout_rate) 16 | self.proj = nn.Linear(dim, dim) 17 | self.proj_drop = nn.Dropout(dropout_rate) 18 | 19 | def forward(self, x): 20 | B, N, C = x.shape 21 | qkv = ( 22 | self.qkv(x) 23 | .reshape(B, N, 3, self.num_heads, C // self.num_heads) 24 | .permute(2, 0, 3, 1, 4) 25 | ) 26 | q, k, v = ( 27 | qkv[0], 28 | qkv[1], 29 | qkv[2], 30 | ) # make torchscript happy (cannot use tensor as tuple) 31 | 32 | attn = (q @ k.transpose(-2, -1)) * self.scale 33 | attn = attn.softmax(dim=-1) 34 | attn = self.attn_drop(attn) 35 | 36 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 37 | x = self.proj(x) 38 | x = self.proj_drop(x) 39 | return x 40 | 41 | 42 | class Residual(nn.Module): 43 | def __init__(self, fn): 44 | super().__init__() 45 | self.fn = fn 46 | 47 | def forward(self, x): 48 | return self.fn(x) + x 49 | 50 | 51 | class PreNorm(nn.Module): 52 | def __init__(self, dim, fn): 53 | super().__init__() 54 | self.norm = nn.LayerNorm(dim) 55 | self.fn = fn 56 | 57 | def forward(self, x): 58 | return self.fn(self.norm(x)) 59 | 60 | 61 | class PreNormDrop(nn.Module): 62 | def __init__(self, dim, dropout_rate, fn): 63 | super().__init__() 64 | self.norm = nn.LayerNorm(dim) 65 | self.dropout = nn.Dropout(p=dropout_rate) 66 | self.fn = fn 67 | 68 | def forward(self, x): 69 | return self.dropout(self.fn(self.norm(x))) 70 | 71 | 72 | class FeedForward(nn.Module): 73 | def __init__(self, dim, hidden_dim, dropout_rate): 74 | super().__init__() 75 | self.net = nn.Sequential( 76 | nn.Linear(dim, hidden_dim), 77 | nn.GELU(), 78 | nn.Dropout(p=dropout_rate), 79 | nn.Linear(hidden_dim, dim), 80 | nn.Dropout(p=dropout_rate), 81 | ) 82 | 83 | def forward(self, x): 84 | return self.net(x) 85 | 86 | 87 | class TransformerModel(nn.Module): 88 | def __init__( 89 | self, 90 | dim, 91 | depth, 92 | heads, 93 | mlp_dim, 94 | dropout_rate=0.1, 95 | attn_dropout_rate=0.1, 96 | ): 97 | super().__init__() 98 | layers = [] 99 | for _ in range(depth): 100 | layers.extend( 101 | [ 102 | Residual( 103 | PreNormDrop( 104 | dim, 105 | dropout_rate, 106 | SelfAttention(dim, heads=heads, dropout_rate=attn_dropout_rate), 107 | ) 108 | ), 109 | Residual( 110 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)) 111 | ), 112 | ] 113 | ) 114 | # dim = dim / 2 115 | self.net = IntermediateSequential(*layers) 116 | 117 | 118 | def forward(self, x): 119 | return self.net(x) 120 | -------------------------------------------------------------------------------- /models/lib/UNet3DZoo.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | # adapt from https://github.com/MIC-DKFZ/BraTS2017 6 | 7 | 8 | def normalization(planes, norm='gn'): 9 | if norm == 'bn': 10 | m = nn.BatchNorm3d(planes) 11 | elif norm == 'gn': 12 | m = nn.GroupNorm(8, planes) 13 | elif norm == 'in': 14 | m = nn.InstanceNorm3d(planes) 15 | else: 16 | raise ValueError('normalization type {} is not supported'.format(norm)) 17 | return m 18 | 19 | 20 | 21 | class InitConv(nn.Module): 22 | def __init__(self, in_channels=1, out_channels=16, dropout=0.2): 23 | super(InitConv, self).__init__() 24 | 25 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1) 26 | self.dropout = dropout 27 | 28 | def forward(self, x): 29 | y = self.conv(x) 30 | y = F.dropout3d(y, self.dropout) 31 | 32 | return y 33 | 34 | 35 | class EnBlock(nn.Module): 36 | def __init__(self, in_channels, norm='gn'): 37 | super(EnBlock, self).__init__() 38 | 39 | self.bn1 = normalization(in_channels, norm=norm) 40 | self.relu1 = nn.ReLU(inplace=True) 41 | self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) 42 | 43 | self.bn2 = normalization(in_channels, norm=norm) 44 | self.relu2 = nn.ReLU(inplace=True) 45 | self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) 46 | 47 | def forward(self, x): 48 | x1 = self.bn1(x) 49 | x1 = self.relu1(x1) 50 | x1 = self.conv1(x1) 51 | y = self.bn2(x1) 52 | y = self.relu2(y) 53 | y = self.conv2(y) 54 | y = y + x 55 | 56 | return y 57 | 58 | 59 | class EnDown(nn.Module): 60 | def __init__(self, in_channels, out_channels): 61 | super(EnDown, self).__init__() 62 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=2, padding=1) 63 | 64 | def forward(self, x): 65 | y = self.conv(x) 66 | 67 | return y 68 | 69 | class Attention_block(nn.Module): 70 | def __init__(self, F_g, F_l, F_int): 71 | super(Attention_block, self).__init__() 72 | self.W_g = nn.Sequential( 73 | nn.Conv3d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True), 74 | nn.BatchNorm3d(F_int) 75 | ) 76 | 77 | self.W_x = nn.Sequential( 78 | nn.Conv3d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True), 79 | nn.BatchNorm3d(F_int) 80 | ) 81 | 82 | self.psi = nn.Sequential( 83 | nn.Conv3d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True), 84 | nn.BatchNorm3d(1), 85 | nn.Sigmoid() 86 | ) 87 | 88 | self.relu = nn.ReLU(inplace=True) 89 | 90 | def forward(self, g, x): 91 | # 下采样的gating signal 卷积 92 | g1 = self.W_g(g) 93 | # 上采样的 l 卷积 94 | x1 = self.W_x(x) 95 | # concat + relu 96 | psi = self.relu(g1 + x1) 97 | # channel 减为1,并Sigmoid,得到权重矩阵 98 | psi = self.psi(psi) 99 | # 返回加权的 x 100 | return x * psi 101 | 102 | class De_Cat(nn.Module): 103 | def __init__(self, in_channels, out_channels): 104 | super(De_Cat, self).__init__() 105 | # self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=1) 106 | # self.conv2 = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2) 107 | self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=1) 108 | 109 | def forward(self, x, prev): 110 | # x1 = self.conv1(x) 111 | # y = self.conv2(x1) 112 | # y = y + prev 113 | y = torch.cat((x, prev), dim=1) 114 | y = self.conv1(y) 115 | return y 116 | 117 | class DeUp_Cat(nn.Module): 118 | def __init__(self, in_channels, out_channels): 119 | super(DeUp_Cat, self).__init__() 120 | self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=1) 121 | self.conv2 = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2) 122 | self.conv3 = nn.Conv3d(in_channels, out_channels, kernel_size=1) 123 | 124 | def forward(self, x, prev): 125 | x1 = self.conv1(x) 126 | y = self.conv2(x1) 127 | # y = y + prev 128 | y = torch.cat((prev, y), dim=1) 129 | y = self.conv3(y) 130 | return y 131 | 132 | class DeUp(nn.Module): 133 | def __init__(self, in_channels, out_channels): 134 | super(DeUp, self).__init__() 135 | self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=1) 136 | self.conv2 = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2) 137 | self.conv3 = nn.Conv3d(out_channels, out_channels, kernel_size=1) 138 | 139 | def forward(self, x): 140 | x = self.conv1(x) 141 | x = self.conv2(x) 142 | x = self.conv3(x) 143 | return x 144 | 145 | class DeBlock(nn.Module): 146 | def __init__(self, in_channels): 147 | super(DeBlock, self).__init__() 148 | 149 | self.bn1 = nn.BatchNorm3d(in_channels) 150 | self.relu1 = nn.ReLU(inplace=True) 151 | self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) 152 | self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) 153 | self.bn2 = nn.BatchNorm3d(in_channels) 154 | self.relu2 = nn.ReLU(inplace=True) 155 | 156 | def forward(self, x): 157 | x1 = self.conv1(x) 158 | x1 = self.bn1(x1) 159 | x1 = self.relu1(x1) 160 | x1 = self.conv2(x1) 161 | x1 = self.bn2(x1) 162 | x1 = self.relu2(x1) 163 | x1 = x1 + x 164 | 165 | return x1 166 | 167 | class AttUnet(nn.Module): 168 | def __init__(self, in_channels=1, base_channels=16, num_classes=4): 169 | super(AttUnet, self).__init__() 170 | 171 | self.InitConv = InitConv(in_channels=in_channels, out_channels=base_channels, dropout=0.2) 172 | self.EnBlock1 = EnBlock(in_channels=base_channels) 173 | self.EnDown1 = EnDown(in_channels=base_channels, out_channels=base_channels*2) 174 | 175 | self.EnBlock2_1 = EnBlock(in_channels=base_channels*2) 176 | self.EnBlock2_2 = EnBlock(in_channels=base_channels*2) 177 | self.EnDown2 = EnDown(in_channels=base_channels*2, out_channels=base_channels*4) 178 | 179 | self.EnBlock3_1 = EnBlock(in_channels=base_channels * 4) 180 | self.EnBlock3_2 = EnBlock(in_channels=base_channels * 4) 181 | self.EnDown3 = EnDown(in_channels=base_channels*4, out_channels=base_channels*8) 182 | self.Att4 = Attention_block(F_g=64, F_l=64, F_int=32) 183 | self.Att3 = Attention_block(F_g=32, F_l=32, F_int=16) 184 | self.Att2 = Attention_block(F_g=16, F_l=16, F_int=16) 185 | 186 | self.EnBlock4_1 = EnBlock(in_channels=base_channels * 8) 187 | self.EnBlock4_2 = EnBlock(in_channels=base_channels * 8) 188 | self.EnBlock4_3 = EnBlock(in_channels=base_channels * 8) 189 | self.EnBlock4_4 = EnBlock(in_channels=base_channels * 8) 190 | 191 | self.DeUp4 = DeUp(in_channels=base_channels*8, out_channels=base_channels*4) 192 | self.DeUpCat4 = De_Cat(in_channels=base_channels * 8, out_channels=base_channels * 4) 193 | self.DeBlock4 = DeBlock(in_channels=base_channels*4) 194 | 195 | self.DeUp3 = DeUp(in_channels=base_channels*4, out_channels=base_channels*2) 196 | self.DeUpCat3 = De_Cat(in_channels=base_channels * 4, out_channels=base_channels * 2) 197 | self.DeBlock3 = DeBlock(in_channels=base_channels*2) 198 | 199 | self.DeUp2 = DeUp(in_channels=base_channels*2, out_channels=base_channels) 200 | self.DeUpCat2 = De_Cat(in_channels=base_channels * 2, out_channels=base_channels) 201 | self.DeBlock2 = DeBlock(in_channels=base_channels) 202 | self.endconv = nn.Conv3d(base_channels, num_classes, kernel_size=1) 203 | 204 | def forward(self, x): 205 | x = self.InitConv(x) # (1, 16, 128, 128, 128) 206 | 207 | x1_1 = self.EnBlock1(x) 208 | x1_2 = self.EnDown1(x1_1) # (1, 32, 64, 64, 64) 209 | 210 | x2_1 = self.EnBlock2_1(x1_2) 211 | x2_1 = self.EnBlock2_2(x2_1) 212 | x2_2 = self.EnDown2(x2_1) # (1, 64, 32, 32, 32) 213 | 214 | x3_1 = self.EnBlock3_1(x2_2) 215 | x3_1 = self.EnBlock3_2(x3_1) 216 | x3_2 = self.EnDown3(x3_1) # (1, 128, 16, 16, 16) 217 | 218 | x4_1 = self.EnBlock4_1(x3_2) 219 | x4_2 = self.EnBlock4_2(x4_1) 220 | x4_3 = self.EnBlock4_3(x4_2) 221 | x4_4 = self.EnBlock4_4(x4_3) # (1, 128, 16, 16, 16) 222 | 223 | y4 = self.DeUp4(x4_4) # (1, 64, 32, 32, 32) 224 | x3_1 = self.Att4(g=y4, x=x3_1) 225 | y4 = self.DeUpCat4(x3_1, y4) 226 | y4 = self.DeBlock4(y4) # (1, 64, 32, 32, 32) 227 | 228 | y3 = self.DeUp3(y4) # (1, 32, 64, 64, 64) 229 | x2_1 = self.Att3(g=y3, x=x2_1) 230 | y3 = self.DeUpCat3(y3, x2_1) 231 | y3 = self.DeBlock3(y3) # (1, 32, 64, 64, 64) 232 | 233 | y2 = self.DeUp2(y3) 234 | x1_1 = self.Att2(g=y2, x=x1_1) 235 | y2 = self.DeUpCat2(y2, x1_1) # (1, 16, 128, 128, 128) 236 | y2 = self.DeBlock2(y2) 237 | y = self.endconv(y2) 238 | 239 | return y 240 | 241 | class Unet(nn.Module): 242 | def __init__(self, in_channels=1, base_channels=16, num_classes=4): 243 | super(Unet, self).__init__() 244 | 245 | self.InitConv = InitConv(in_channels=in_channels, out_channels=base_channels, dropout=0.2) 246 | self.EnBlock1 = EnBlock(in_channels=base_channels) 247 | self.EnDown1 = EnDown(in_channels=base_channels, out_channels=base_channels*2) 248 | 249 | self.EnBlock2_1 = EnBlock(in_channels=base_channels*2) 250 | self.EnBlock2_2 = EnBlock(in_channels=base_channels*2) 251 | self.EnDown2 = EnDown(in_channels=base_channels*2, out_channels=base_channels*4) 252 | 253 | self.EnBlock3_1 = EnBlock(in_channels=base_channels * 4) 254 | self.EnBlock3_2 = EnBlock(in_channels=base_channels * 4) 255 | self.EnDown3 = EnDown(in_channels=base_channels*4, out_channels=base_channels*8) 256 | 257 | self.EnBlock4_1 = EnBlock(in_channels=base_channels * 8) 258 | self.EnBlock4_2 = EnBlock(in_channels=base_channels * 8) 259 | self.EnBlock4_3 = EnBlock(in_channels=base_channels * 8) 260 | self.EnBlock4_4 = EnBlock(in_channels=base_channels * 8) 261 | 262 | self.DeUpCat4 = DeUp_Cat(in_channels=base_channels * 8, out_channels=base_channels * 4) 263 | self.DeBlock4 = DeBlock(in_channels=base_channels*4) 264 | 265 | self.DeUpCat3 = DeUp_Cat(in_channels=base_channels * 4, out_channels=base_channels * 2) 266 | self.DeBlock3 = DeBlock(in_channels=base_channels*2) 267 | 268 | self.DeUpCat2 = DeUp_Cat(in_channels=base_channels * 2, out_channels=base_channels) 269 | self.DeBlock2 = DeBlock(in_channels=base_channels) 270 | self.endconv = nn.Conv3d(base_channels, num_classes, kernel_size=1) 271 | 272 | def forward(self, x): 273 | x = self.InitConv(x) # (1, 16, 128, 128, 128) 274 | 275 | x1_1 = self.EnBlock1(x) 276 | x1_2 = self.EnDown1(x1_1) # (1, 32, 64, 64, 64) 277 | 278 | x2_1 = self.EnBlock2_1(x1_2) 279 | x2_1 = self.EnBlock2_2(x2_1) 280 | x2_2 = self.EnDown2(x2_1) # (1, 64, 32, 32, 32) 281 | 282 | x3_1 = self.EnBlock3_1(x2_2) 283 | x3_1 = self.EnBlock3_2(x3_1) 284 | x3_2 = self.EnDown3(x3_1) # (1, 128, 16, 16, 16) 285 | 286 | x4_1 = self.EnBlock4_1(x3_2) 287 | x4_2 = self.EnBlock4_2(x4_1) 288 | x4_3 = self.EnBlock4_3(x4_2) 289 | x4_4 = self.EnBlock4_4(x4_3) # (1, 128, 16, 16, 16) 290 | 291 | y4 = self.DeUpCat4(x4_4, x3_1) 292 | y4 = self.DeBlock4(y4) # (1, 64, 32, 32, 32) 293 | 294 | y3 = self.DeUpCat3(y4, x2_1) # (1, 32, 64, 64, 64) 295 | y3 = self.DeBlock3(y3) 296 | 297 | y2 = self.DeUpCat2(y3, x1_1) # (1, 16, 128, 128, 128) 298 | y2 = self.DeBlock2(y2) 299 | y = self.endconv(y2) 300 | 301 | return y 302 | class Unetdrop(nn.Module): 303 | def __init__(self, in_channels=1, base_channels=16, num_classes=4): 304 | super(Unetdrop, self).__init__() 305 | 306 | self.InitConv = InitConv(in_channels=in_channels, out_channels=base_channels, dropout=0.2) 307 | self.EnBlock1 = EnBlock(in_channels=base_channels) 308 | self.EnDown1 = EnDown(in_channels=base_channels, out_channels=base_channels*2) 309 | 310 | self.EnBlock2_1 = EnBlock(in_channels=base_channels*2) 311 | self.EnBlock2_2 = EnBlock(in_channels=base_channels*2) 312 | self.EnDown2 = EnDown(in_channels=base_channels*2, out_channels=base_channels*4) 313 | 314 | self.EnBlock3_1 = EnBlock(in_channels=base_channels * 4) 315 | self.EnBlock3_2 = EnBlock(in_channels=base_channels * 4) 316 | self.dropoutd1 = nn.Dropout(p=0.5) 317 | self.EnDown3 = EnDown(in_channels=base_channels*4, out_channels=base_channels*8) 318 | self.dropoutd2 = nn.Dropout(p=0.5) 319 | self.EnBlock4_1 = EnBlock(in_channels=base_channels * 8) 320 | self.EnBlock4_2 = EnBlock(in_channels=base_channels * 8) 321 | self.EnBlock4_3 = EnBlock(in_channels=base_channels * 8) 322 | self.EnBlock4_4 = EnBlock(in_channels=base_channels * 8) 323 | self.dropoutu1 = nn.Dropout(p=0.5) 324 | self.DeUpCat4 = DeUp_Cat(in_channels=base_channels * 8, out_channels=base_channels * 4) 325 | self.DeBlock4 = DeBlock(in_channels=base_channels*4) 326 | self.dropoutu2 = nn.Dropout(p=0.5) 327 | self.DeUpCat3 = DeUp_Cat(in_channels=base_channels * 4, out_channels=base_channels * 2) 328 | self.DeBlock3 = DeBlock(in_channels=base_channels*2) 329 | 330 | self.DeUpCat2 = DeUp_Cat(in_channels=base_channels * 2, out_channels=base_channels) 331 | self.DeBlock2 = DeBlock(in_channels=base_channels) 332 | self.endconv = nn.Conv3d(base_channels, num_classes, kernel_size=1) 333 | 334 | def forward(self, x): 335 | x = self.InitConv(x) # (1, 16, 128, 128, 128) 336 | 337 | x1_1 = self.EnBlock1(x) 338 | x1_2 = self.EnDown1(x1_1) # (1, 32, 64, 64, 64) 339 | 340 | x2_1 = self.EnBlock2_1(x1_2) 341 | x2_1 = self.EnBlock2_2(x2_1) 342 | x2_2 = self.EnDown2(x2_1) # (1, 64, 32, 32, 32) 343 | 344 | x3_1 = self.EnBlock3_1(x2_2) # (1, 64, 32, 32, 32) 345 | x3_1drop = self.EnBlock3_2(x3_1) 346 | x3_1drop = self.dropoutd1(x3_1drop) 347 | x3_2 = self.EnDown3(x3_1drop) # (1, 128, 16, 16, 16) 348 | 349 | x4_1 = self.EnBlock4_1(x3_2) 350 | x4_2 = self.EnBlock4_2(x4_1) 351 | x4_3 = self.EnBlock4_3(x4_2) 352 | x4_3 = self.dropoutd2(x4_3) 353 | x4_4 = self.EnBlock4_4(x4_3) # (1, 128, 16, 16, 16) 354 | 355 | y4 = self.DeUpCat4(x4_4, x3_1) 356 | y4 = self.dropoutu1(y4) 357 | y4 = self.DeBlock4(y4) # (1, 64, 32, 32, 32) 358 | 359 | y3 = self.DeUpCat3(y4, x2_1) # (1, 32, 64, 64, 64) 360 | y3 = self.dropoutu2(y3) 361 | y3 = self.DeBlock3(y3) 362 | 363 | y2 = self.DeUpCat2(y3, x1_1) # (1, 16, 128, 128, 128) 364 | y2 = self.DeBlock2(y2) 365 | y = self.endconv(y2) 366 | return y 367 | if __name__ == '__main__': 368 | with torch.no_grad(): 369 | import os 370 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 371 | cuda0 = torch.device('cuda:0') 372 | x = torch.rand((1, 1, 128, 128, 128), device=cuda0) 373 | model = AttUnet(in_channels=1, base_channels=16, num_classes=4) 374 | # model = Unet(in_channels=1, base_channels=16, num_classes=4) 375 | model.cuda() 376 | total = sum([param.nelement() for param in model.parameters()]) 377 | print("Number of model's parameter: %.2fM" % (total / 1e6)) 378 | output = model(x) 379 | print('output:', output.shape) 380 | -------------------------------------------------------------------------------- /models/lib/Unet_skipconnection.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | # adapt from https://github.com/MIC-DKFZ/BraTS2017 6 | 7 | 8 | def normalization(planes, norm='gn'): 9 | if norm == 'bn': 10 | m = nn.BatchNorm3d(planes) 11 | elif norm == 'gn': 12 | m = nn.GroupNorm(8, planes) 13 | elif norm == 'in': 14 | m = nn.InstanceNorm3d(planes) 15 | else: 16 | raise ValueError('normalization type {} is not supported'.format(norm)) 17 | return m 18 | 19 | 20 | 21 | class InitConv(nn.Module): 22 | def __init__(self, in_channels=4, out_channels=16, dropout=0.2): 23 | super(InitConv, self).__init__() 24 | 25 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1) 26 | self.dropout = dropout 27 | 28 | def forward(self, x): 29 | y = self.conv(x) 30 | y = F.dropout3d(y, self.dropout) 31 | 32 | return y 33 | 34 | 35 | class EnBlock(nn.Module): 36 | def __init__(self, in_channels, norm='gn'): 37 | super(EnBlock, self).__init__() 38 | 39 | self.bn1 = normalization(in_channels, norm=norm) 40 | self.relu1 = nn.ReLU(inplace=True) 41 | self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) 42 | 43 | self.bn2 = normalization(in_channels, norm=norm) 44 | self.relu2 = nn.ReLU(inplace=True) 45 | self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) 46 | 47 | def forward(self, x): 48 | x1 = self.bn1(x) 49 | x1 = self.relu1(x1) 50 | x1 = self.conv1(x1) 51 | y = self.bn2(x1) 52 | y = self.relu2(y) 53 | y = self.conv2(y) 54 | y = y + x 55 | 56 | return y 57 | 58 | 59 | class EnDown(nn.Module): 60 | def __init__(self, in_channels, out_channels): 61 | super(EnDown, self).__init__() 62 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=2, padding=1) 63 | 64 | def forward(self, x): 65 | y = self.conv(x) 66 | 67 | return y 68 | 69 | 70 | 71 | class Unet(nn.Module): 72 | def __init__(self, in_channels=4, base_channels=16, num_classes=4): 73 | super(Unet, self).__init__() 74 | 75 | self.InitConv = InitConv(in_channels=in_channels, out_channels=base_channels, dropout=0.2) 76 | self.EnBlock1 = EnBlock(in_channels=base_channels) 77 | self.EnDown1 = EnDown(in_channels=base_channels, out_channels=base_channels*2) 78 | 79 | self.EnBlock2_1 = EnBlock(in_channels=base_channels*2) 80 | self.EnBlock2_2 = EnBlock(in_channels=base_channels*2) 81 | self.EnDown2 = EnDown(in_channels=base_channels*2, out_channels=base_channels*4) 82 | 83 | self.EnBlock3_1 = EnBlock(in_channels=base_channels * 4) 84 | self.EnBlock3_2 = EnBlock(in_channels=base_channels * 4) 85 | self.EnDown3 = EnDown(in_channels=base_channels*4, out_channels=base_channels*8) 86 | 87 | self.EnBlock4_1 = EnBlock(in_channels=base_channels * 8) 88 | self.EnBlock4_2 = EnBlock(in_channels=base_channels * 8) 89 | self.EnBlock4_3 = EnBlock(in_channels=base_channels * 8) 90 | self.EnBlock4_4 = EnBlock(in_channels=base_channels * 8) 91 | 92 | def forward(self, x): 93 | x = self.InitConv(x) # (1, 16, 128, 128, 128) 94 | 95 | x1_1 = self.EnBlock1(x) 96 | x1_2 = self.EnDown1(x1_1) # (1, 32, 64, 64, 64) 97 | 98 | x2_1 = self.EnBlock2_1(x1_2) 99 | x2_1 = self.EnBlock2_2(x2_1) 100 | x2_2 = self.EnDown2(x2_1) # (1, 64, 32, 32, 32) 101 | 102 | x3_1 = self.EnBlock3_1(x2_2) 103 | x3_1 = self.EnBlock3_2(x3_1) 104 | x3_2 = self.EnDown3(x3_1) # (1, 128, 16, 16, 16) 105 | 106 | x4_1 = self.EnBlock4_1(x3_2) 107 | x4_2 = self.EnBlock4_2(x4_1) 108 | x4_3 = self.EnBlock4_3(x4_2) 109 | output = self.EnBlock4_4(x4_3) # (1, 128, 16, 16, 16) 110 | 111 | return x1_1,x2_1,x3_1,output 112 | 113 | 114 | if __name__ == '__main__': 115 | with torch.no_grad(): 116 | import os 117 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 118 | cuda0 = torch.device('cuda:0') 119 | x = torch.rand((1, 4, 128, 128, 128), device=cuda0) 120 | # model = Unet1(in_channels=4, base_channels=16, num_classes=4) 121 | model = Unet(in_channels=4, base_channels=16, num_classes=4) 122 | model.cuda() 123 | output = model(x) 124 | print('output:', output.shape) 125 | -------------------------------------------------------------------------------- /models/lib/VNet3D.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class ConvBlock(nn.Module): 6 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 7 | super(ConvBlock, self).__init__() 8 | 9 | ops = [] 10 | for i in range(n_stages): 11 | if i==0: 12 | input_channel = n_filters_in 13 | else: 14 | input_channel = n_filters_out 15 | 16 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 17 | if normalization == 'bn': 18 | ops.append(nn.BatchNorm3d(n_filters_out)) 19 | elif normalization == 'gn': 20 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 21 | elif normalization == 'in': 22 | ops.append(nn.InstanceNorm3d(n_filters_out)) 23 | elif normalization != 'none': 24 | assert False 25 | ops.append(nn.ReLU(inplace=True)) 26 | 27 | self.conv = nn.Sequential(*ops) 28 | 29 | def forward(self, x): 30 | x = self.conv(x) 31 | return x 32 | 33 | 34 | class ResidualConvBlock(nn.Module): 35 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 36 | super(ResidualConvBlock, self).__init__() 37 | 38 | ops = [] 39 | for i in range(n_stages): 40 | if i == 0: 41 | input_channel = n_filters_in 42 | else: 43 | input_channel = n_filters_out 44 | 45 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 46 | if normalization == 'bn': 47 | ops.append(nn.BatchNorm3d(n_filters_out)) 48 | elif normalization == 'gn': 49 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 50 | elif normalization == 'in': 51 | ops.append(nn.InstanceNorm3d(n_filters_out)) 52 | elif normalization != 'none': 53 | assert False 54 | 55 | if i != n_stages-1: 56 | ops.append(nn.ReLU(inplace=True)) 57 | 58 | self.conv = nn.Sequential(*ops) 59 | self.relu = nn.ReLU(inplace=True) 60 | 61 | def forward(self, x): 62 | x = (self.conv(x) + x) 63 | x = self.relu(x) 64 | return x 65 | 66 | 67 | class DownsamplingConvBlock(nn.Module): 68 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 69 | super(DownsamplingConvBlock, self).__init__() 70 | 71 | ops = [] 72 | if normalization != 'none': 73 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 74 | if normalization == 'bn': 75 | ops.append(nn.BatchNorm3d(n_filters_out)) 76 | elif normalization == 'gn': 77 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 78 | elif normalization == 'in': 79 | ops.append(nn.InstanceNorm3d(n_filters_out)) 80 | else: 81 | assert False 82 | else: 83 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 84 | 85 | ops.append(nn.ReLU(inplace=True)) 86 | 87 | self.conv = nn.Sequential(*ops) 88 | 89 | def forward(self, x): 90 | x = self.conv(x) 91 | return x 92 | 93 | 94 | class UpsamplingDeconvBlock(nn.Module): 95 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 96 | super(UpsamplingDeconvBlock, self).__init__() 97 | 98 | ops = [] 99 | if normalization != 'none': 100 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 101 | if normalization == 'bn': 102 | ops.append(nn.BatchNorm3d(n_filters_out)) 103 | elif normalization == 'gn': 104 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 105 | elif normalization == 'in': 106 | ops.append(nn.InstanceNorm3d(n_filters_out)) 107 | else: 108 | assert False 109 | else: 110 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 111 | 112 | ops.append(nn.ReLU(inplace=True)) 113 | 114 | self.conv = nn.Sequential(*ops) 115 | 116 | def forward(self, x): 117 | x = self.conv(x) 118 | return x 119 | 120 | 121 | class Upsampling(nn.Module): 122 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 123 | super(Upsampling, self).__init__() 124 | 125 | ops = [] 126 | ops.append(nn.Upsample(scale_factor=stride, mode='trilinear',align_corners=False)) 127 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1)) 128 | if normalization == 'bn': 129 | ops.append(nn.BatchNorm3d(n_filters_out)) 130 | elif normalization == 'gn': 131 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 132 | elif normalization == 'in': 133 | ops.append(nn.InstanceNorm3d(n_filters_out)) 134 | elif normalization != 'none': 135 | assert False 136 | ops.append(nn.ReLU(inplace=True)) 137 | 138 | self.conv = nn.Sequential(*ops) 139 | 140 | def forward(self, x): 141 | x = self.conv(x) 142 | return x 143 | 144 | 145 | class VNet(nn.Module): 146 | def __init__(self, n_channels=1, n_classes=4, n_filters=16, normalization='gn', has_dropout=False): 147 | super(VNet, self).__init__() 148 | self.has_dropout = has_dropout 149 | 150 | self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization) 151 | self.block_one_dw = DownsamplingConvBlock(n_filters, n_filters * 2, normalization=normalization) # 32 152 | 153 | self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 154 | self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization) # 64 155 | 156 | self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 157 | self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization) # 128 158 | 159 | # self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 160 | # self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization) # 256 161 | 162 | self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 163 | self.block_four_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization) 164 | 165 | self.block_five = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 166 | self.block_five_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization) 167 | 168 | self.block_six = ConvBlock(3, n_filters * 2, n_filters * 2, normalization=normalization) 169 | self.block_six_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization) 170 | 171 | # self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 172 | # self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization) 173 | 174 | self.block_seven = ConvBlock(1, n_filters, n_filters, normalization=normalization) 175 | self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0) 176 | 177 | self.dropout = nn.Dropout3d(p=0.5, inplace=False) 178 | # self.__init_weight() 179 | 180 | def encoder(self, input): 181 | x1 = self.block_one(input) 182 | x1_dw = self.block_one_dw(x1) 183 | 184 | x2 = self.block_two(x1_dw) 185 | x2_dw = self.block_two_dw(x2) 186 | 187 | x3 = self.block_three(x2_dw) 188 | x3_dw = self.block_three_dw(x3) 189 | 190 | x4 = self.block_four(x3_dw) 191 | 192 | # x5 = self.block_five(x4_dw) 193 | # x5 = F.dropout3d(x5, p=0.5, training=True) 194 | if self.has_dropout: 195 | x4 = self.dropout(x4) 196 | 197 | res = [x1, x2, x3, x4] 198 | 199 | return res 200 | 201 | def decoder(self, features): 202 | x1 = features[0] 203 | x2 = features[1] 204 | x3 = features[2] 205 | x4 = features[3] 206 | # x5 = features[4] 207 | 208 | x4_up = self.block_four_up(x4) 209 | x4_up = x4_up + x3 210 | 211 | x5 = self.block_five(x4_up) 212 | x5_up = self.block_five_up(x5) 213 | x5_up = x5_up + x2 214 | 215 | x6 = self.block_six(x5_up) 216 | x6_up = self.block_six_up(x6) 217 | x6_up = x6_up + x1 218 | 219 | x7 = self.block_seven(x6_up) 220 | # x9 = F.dropout3d(x9, p=0.5, training=True) 221 | if self.has_dropout: 222 | x7 = self.dropout(x7) 223 | out = self.out_conv(x7) 224 | return out 225 | 226 | 227 | def forward(self, input, turnoff_drop=False): 228 | if turnoff_drop: 229 | has_dropout = self.has_dropout 230 | self.has_dropout = False 231 | features = self.encoder(input) 232 | out = self.decoder(features) 233 | if turnoff_drop: 234 | self.has_dropout = has_dropout 235 | return out 236 | 237 | if __name__ == '__main__': 238 | with torch.no_grad(): 239 | import os 240 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 241 | cuda0 = torch.device('cuda:0') 242 | x = torch.rand((1, 1, 128, 128, 128), device=cuda0) 243 | # model = AttUnet(in_channels=1, base_channels=16, num_classes=4) 244 | model = VNet(n_channels=1, n_classes=4, n_filters=16, normalization='gn', has_dropout=False) 245 | model.cuda() 246 | total = sum([param.nelement() for param in model.parameters()]) 247 | print("Number of model's parameter: %.2fM" % (total / 1e6)) 248 | output = model(x) 249 | print('output:', output.shape) -------------------------------------------------------------------------------- /models/lib/nullfile: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/lib/seg_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | import nibabel as nib 4 | 5 | # calculate evaluation metrics for segmentation 6 | def seg_eval_metric(pred_label, gt_label, output_chn): 7 | class_n = np.unique(gt_label) 8 | # dice 9 | dice_c = dice_n_class(move_img=pred_label, refer_img=gt_label, output_chn=output_chn) 10 | return dice_c 11 | 12 | # dice value 13 | def dice_n_class(move_img, refer_img, output_chn): 14 | # list of classes 15 | c_list_old = np.unique(refer_img) 16 | # for those class not in the Gt, set dice to zero 17 | c_list = np.arange(output_chn) 18 | dice_c = [] 19 | for c in range(len(c_list)): 20 | # intersection 21 | ints = np.sum(((move_img == c_list[c]) * 1) * ((refer_img == c_list[c]) * 1)) 22 | # sum 23 | sums = np.sum(((move_img == c_list[c]) * 1) + ((refer_img == c_list[c]) * 1)) + 0.0001 24 | dice_c.append((2.0 * ints) / sums) 25 | 26 | return dice_c 27 | 28 | 29 | # conformity value 30 | def conform_n_class(move_img, refer_img): 31 | # list of classes 32 | c_list = np.unique(refer_img) 33 | 34 | conform_c = [] 35 | for c in range(len(c_list)): 36 | # intersection 37 | ints = np.sum(((move_img == c_list[c]) * 1) * ((refer_img == c_list[c]) * 1)) 38 | # sum 39 | sums = np.sum(((move_img == c_list[c]) * 1) + ((refer_img == c_list[c]) * 1)) + 0.0001 40 | # dice 41 | dice_temp = (2.0 * ints) / sums 42 | # conformity 43 | conform_temp = (3*dice_temp - 2) / dice_temp 44 | 45 | conform_c.append(conform_temp) 46 | 47 | return conform_c 48 | 49 | 50 | # Jaccard index 51 | def jaccard_n_class(move_img, refer_img, output_chn): 52 | # list of classes 53 | c_list_old = np.unique(refer_img) 54 | # c_list = [0, 1, 2, 3] 55 | c_list = np.arange(output_chn) 56 | 57 | jaccard_c = [] 58 | for c in range(len(c_list)): 59 | move_img_c = (move_img == c_list[c]) 60 | refer_img_c = (refer_img == c_list[c]) 61 | # intersection 62 | ints = np.sum(np.logical_and(move_img_c, refer_img_c)*1) 63 | # union 64 | uni = np.sum(np.logical_or(move_img_c, refer_img_c)*1) + 0.0001 65 | 66 | jaccard_c.append(ints / uni) 67 | 68 | return jaccard_c 69 | 70 | 71 | # precision and recall 72 | def precision_recall_n_class(move_img, refer_img): 73 | # list of classes 74 | c_list = np.unique(refer_img) 75 | 76 | precision_c = [] 77 | recall_c = [] 78 | for c in range(len(c_list)): 79 | move_img_c = (move_img == c_list[c]) 80 | refer_img_c = (refer_img == c_list[c]) 81 | # intersection 82 | ints = np.sum(np.logical_and(move_img_c, refer_img_c)*1) 83 | # precision 84 | prec = ints / (np.sum(move_img_c*1) + 0.001) 85 | # recall 86 | recall = ints / (np.sum(refer_img_c*1) + 0.001) 87 | 88 | precision_c.append(prec) 89 | recall_c.append(recall) 90 | 91 | return precision_c, recall_c 92 | 93 | # Sensitivity(recall of the positive) 94 | def sensitivity(pred, gt, output_chn): 95 | ''' 96 | calculate the sensitivity and the specificity 97 | :param pred: predictions 98 | :param gt: ground truth 99 | :param output_chn: categories (including background) 100 | :return: A list contains sensitivities and the specificity, the first item is specificity 101 | and the others are sensitivities of other categories 102 | ''' 103 | s_list = np.arange(output_chn) 104 | sensitivity_s = [] 105 | for s in range(output_chn): 106 | # TP 107 | TP = np.sum((pred == s_list[s])*(gt == s_list[s])) 108 | # FN 109 | FN = np.sum((pred != s_list[s])*(gt == s_list[s])) 110 | # sensitivity &specificity(for category 0 means specificity, while others means sensitivity) 111 | sensitivity = TP/(TP+FN+0.0001) 112 | sensitivity_s.append(sensitivity) 113 | return sensitivity_s 114 | 115 | if __name__ == '__main__': 116 | pred = np.array([[1,1,1,1],[1,1,1,1],[0,0,0,0]]) 117 | gt = np.array([[0,0,0,0],[1,1,1,1],[0,1,1,0]]) 118 | sensi = sensitivity(pred, gt, 2) 119 | a =1 -------------------------------------------------------------------------------- /models/trustedseg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | import torch.nn as nn 5 | import time 6 | import torch.nn.functional as F 7 | from models.criterions import softmaxBCE_dice,KL,ce_loss,mse_loss,dce_eviloss 8 | from predict import tailor_and_concat 9 | from models.lib.VNet3D import VNet 10 | from models.lib.UNet3DZoo import Unet,AttUnet 11 | from models.lib.TransBTS_downsample8x_skipconnection import TransBTS 12 | 13 | class TMSU(nn.Module): 14 | 15 | def __init__(self, classes, modes, model,input_dims,total_epochs,lambda_epochs=1): 16 | """ 17 | :param classes: Number of classification categories 18 | :param modes: Number of modes 19 | :param classifier_dims: Dimension of the classifier 20 | :param annealing_epoch: KL divergence annealing epoch during training 21 | """ 22 | super(TMSU, self).__init__() 23 | # ---- Net Backbone ---- 24 | if model == 'AU' and input_dims =='four': 25 | self.backbone = AttUnet(in_channels=4, base_channels=16, num_classes=classes) 26 | elif model == 'AU': 27 | self.backbone = AttUnet(in_channels=1, base_channels=16, num_classes=classes) 28 | elif model == 'V'and input_dims =='four': 29 | self.backbone = VNet(n_channels=4, n_classes=classes, n_filters=16, normalization='gn', has_dropout=False) 30 | elif model == 'V': 31 | self.backbone = VNet(n_channels=1, n_classes=classes, n_filters=16, normalization='gn', has_dropout=False) 32 | elif model =='TransU': 33 | _, self.backbone = TransBTS(input_dims=input_dims, _conv_repr=True, _pe_type="learned") 34 | elif model == 'U'and input_dims =='four': 35 | self.backbone = Unet(in_channels=4, base_channels=16, num_classes=classes) 36 | else: 37 | self.backbone = Unet(in_channels=1, base_channels=16, num_classes=classes) 38 | self.backbone.cuda() 39 | self.modes = modes 40 | self.classes = classes 41 | self.eps = 1e-10 42 | self.lambda_epochs = lambda_epochs 43 | self.total_epochs = total_epochs+1 44 | # self.Classifiers = nn.ModuleList([Classifier(classifier_dims[i], self.classes) for i in range(self.modes)]) 45 | 46 | def forward(self, X, y, global_step, mode, use_TTA=False): 47 | # X data 48 | # y target 49 | # global_step : epochs 50 | 51 | # step zero: backbone 52 | if mode == 'train': 53 | backbone_output = self.backbone(X) 54 | elif mode == 'val': 55 | backbone_output = tailor_and_concat(X, self.backbone) 56 | # backbone_X = F.softmax(backbone_X,dim=1) 57 | else: 58 | if not use_TTA: 59 | backbone_output = tailor_and_concat(X, self.backbone) 60 | # backbone_X = F.softmax(backbone_X,dim=1) 61 | else: 62 | x = X 63 | x = x[..., :155] 64 | logit = F.softmax(tailor_and_concat(x, self.backbone), 1) # no flip 65 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2,)), self.backbone).flip(dims=(2,)), 1) # flip H 66 | logit += F.softmax(tailor_and_concat(x.flip(dims=(3,)), self.backbone).flip(dims=(3,)), 1) # flip W 67 | logit += F.softmax(tailor_and_concat(x.flip(dims=(4,)), self.backbone).flip(dims=(4,)), 1) # flip D 68 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 3)), self.backbone).flip(dims=(2, 3)), 69 | 1) # flip H, W 70 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 4)), self.backbone).flip(dims=(2, 4)), 71 | 1) # flip H, D 72 | logit += F.softmax(tailor_and_concat(x.flip(dims=(3, 4)), self.backbone).flip(dims=(3, 4)), 73 | 1) # flip W, D 74 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 3, 4)), self.backbone).flip(dims=(2, 3, 4)), 75 | 1) # flip H, W, D 76 | backbone_output = logit / 8.0 # mean 77 | # backbone_X = F.softmax(backbone_X,dim=1) 78 | 79 | # step one 80 | evidence = self.infer(backbone_output) # batch_size * class * image_size 81 | 82 | # step two 83 | alpha = evidence + 1 84 | if mode == 'train' or mode == 'val': 85 | loss = dce_eviloss(y.to(torch.int64), alpha, self.classes, global_step, self.lambda_epochs) 86 | loss = torch.mean(loss) 87 | return evidence, loss 88 | else: 89 | return evidence 90 | 91 | def infer(self, input): 92 | """ 93 | :param input: modal data 94 | :return: evidence of modal data 95 | """ 96 | # evidence = (input-torch.min(input))/(torch.max(input)-torch.min(input)) 97 | evidence = F.softplus(input) 98 | # evidence[m_num] = torch.exp(torch.clamp(evidence, -10, 10)) 99 | # evidence = F.relu(evidence) 100 | return evidence 101 | 102 | -------------------------------------------------------------------------------- /numpyfunctions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn.metrics as metrics 3 | import pymia.evaluation.metric as m 4 | 5 | 6 | def ece_binary(probabilities, target, n_bins=10, threshold_range: tuple = None, mask=None, out_bins: dict = None, 7 | bin_weighting='proportion'): 8 | 9 | n_dim = target.ndim 10 | 11 | pos_frac, mean_confidence, bin_count, non_zero_bins = \ 12 | binary_calibration(probabilities, target, n_bins, threshold_range, mask) 13 | 14 | bin_proportions = _get_proportion(bin_weighting, bin_count, non_zero_bins, n_dim) 15 | 16 | if out_bins is not None: 17 | out_bins['bins_count'] = bin_count 18 | out_bins['bins_avg_confidence'] = mean_confidence 19 | out_bins['bins_positive_fraction'] = pos_frac 20 | out_bins['bins_non_zero'] = non_zero_bins 21 | 22 | ece = (np.abs(mean_confidence - pos_frac) * bin_proportions).sum() 23 | return ece 24 | 25 | 26 | def binary_calibration(probabilities, target, n_bins=10, threshold_range: tuple = None, mask=None): 27 | if probabilities.ndim > target.ndim: 28 | if probabilities.shape[-1] > 2: 29 | raise ValueError('can only evaluate the calibration for binary classification') 30 | elif probabilities.shape[-1] == 2: 31 | probabilities = probabilities[..., 1] 32 | else: 33 | probabilities = np.squeeze(probabilities, axis=-1) 34 | 35 | if mask is not None: 36 | probabilities = probabilities[mask] 37 | target = target[mask] 38 | 39 | if threshold_range is not None: 40 | low_thres, up_thres = threshold_range 41 | mask = np.logical_and(probabilities < up_thres, probabilities > low_thres) 42 | probabilities = probabilities[mask] 43 | target = target[mask] 44 | 45 | pos_frac, mean_confidence, bin_count, non_zero_bins = \ 46 | _binary_calibration(target.flatten(), probabilities.flatten(), n_bins) 47 | 48 | return pos_frac, mean_confidence, bin_count, non_zero_bins 49 | 50 | 51 | def _binary_calibration(target, probs_positive_cls, n_bins=10): 52 | # same as sklearn.calibration calibration_curve but with the bin_count returned 53 | bins = np.linspace(0., 1. + 1e-8, n_bins + 1) 54 | binids = np.digitize(probs_positive_cls, bins) - 1 55 | 56 | # # note: this is the original formulation which has always n_bins + 1 as length 57 | # bin_sums = np.bincount(binids, weights=probs_positive_cls, minlength=len(bins)) 58 | # bin_true = np.bincount(binids, weights=target, minlength=len(bins)) 59 | # bin_total = np.bincount(binids, minlength=len(bins)) 60 | 61 | bin_sums = np.bincount(binids, weights=probs_positive_cls, minlength=n_bins) 62 | bin_true = np.bincount(binids, weights=target, minlength=n_bins) 63 | bin_total = np.bincount(binids, minlength=n_bins) 64 | 65 | nonzero = bin_total != 0 66 | prob_true = (bin_true[nonzero] / bin_total[nonzero]) 67 | prob_pred = (bin_sums[nonzero] / bin_total[nonzero]) 68 | 69 | return prob_true, prob_pred, bin_total[nonzero], nonzero 70 | 71 | 72 | def _get_proportion(bin_weighting: str, bin_count: np.ndarray, non_zero_bins: np.ndarray, n_dim: int): 73 | if bin_weighting == 'proportion': 74 | bin_proportions = bin_count / bin_count.sum() 75 | elif bin_weighting == 'log_proportion': 76 | bin_proportions = np.log(bin_count) / np.log(bin_count).sum() 77 | elif bin_weighting == 'power_proportion': 78 | bin_proportions = bin_count**(1/n_dim) / (bin_count**(1/n_dim)).sum() 79 | elif bin_weighting == 'mean_proportion': 80 | bin_proportions = 1 / non_zero_bins.sum() 81 | else: 82 | raise ValueError('unknown bin weighting "{}"'.format(bin_weighting)) 83 | return bin_proportions 84 | 85 | 86 | def uncertainty(prediction, target, thresholded_uncertainty, mask=None): 87 | if mask is not None: 88 | prediction = prediction[mask] 89 | target = target[mask] 90 | thresholded_uncertainty = thresholded_uncertainty[mask] 91 | 92 | tps = np.logical_and(target, prediction) 93 | tns = np.logical_and(~target, ~prediction) 94 | fps = np.logical_and(~target, prediction) 95 | fns = np.logical_and(target, ~prediction) 96 | 97 | tpu = np.logical_and(tps, thresholded_uncertainty).sum() 98 | tnu = np.logical_and(tns, thresholded_uncertainty).sum() 99 | fpu = np.logical_and(fps, thresholded_uncertainty).sum() 100 | fnu = np.logical_and(fns, thresholded_uncertainty).sum() 101 | 102 | tp = tps.sum() 103 | tn = tns.sum() 104 | fp = fps.sum() 105 | fn = fns.sum() 106 | 107 | return tp, tn, fp, fn, tpu, tnu, fpu, fnu 108 | 109 | 110 | def error_dice(fp, fn, tpu, tnu, fpu, fnu): 111 | if ((fnu + fpu) == 0) and ((fn + fp + fnu + fpu + tnu + tpu) == 0): 112 | return 1. 113 | return (2 * (fnu + fpu)) / (fn + fp + fnu + fpu + tnu + tpu) 114 | 115 | 116 | def error_recall(fp, fn, fpu, fnu): 117 | if ((fnu + fpu) == 0) and ((fn + fp) == 0): 118 | return 1. 119 | return (fnu + fpu) / (fn + fp) 120 | 121 | 122 | def error_precision(tpu, tnu, fpu, fnu): 123 | if ((fnu + fpu) == 0) and ((fnu + fpu + tpu + tnu) == 0): 124 | return 1. 125 | return (fnu + fpu) / (fnu + fpu + tpu + tnu) 126 | 127 | 128 | def dice(prediction, target): 129 | _check_ndarray(prediction) 130 | _check_ndarray(target) 131 | 132 | d = m.DiceCoefficient() 133 | d.confusion_matrix = m.ConfusionMatrix(prediction, target) 134 | return d.calculate() 135 | 136 | 137 | def confusion_matrx(prediction, target): 138 | _check_ndarray(prediction) 139 | _check_ndarray(target) 140 | 141 | cm = m.ConfusionMatrix(prediction, target) 142 | return cm.tp, cm.tn, cm.fp, cm.fn, cm.n 143 | 144 | 145 | def accuracy(prediction, target): 146 | _check_ndarray(prediction) 147 | _check_ndarray(target) 148 | 149 | a = m.Accuracy() 150 | a.confusion_matrix = m.ConfusionMatrix(prediction, target) 151 | return a.calculate() 152 | 153 | 154 | def log_loss_sklearn(probabilities, target, labels=None): 155 | _check_ndarray(probabilities) 156 | _check_ndarray(target) 157 | 158 | if probabilities.shape[-1] != target.shape[-1]: 159 | probabilities = probabilities.reshape(-1, probabilities.shape[-1]) 160 | else: 161 | probabilities = probabilities.reshape(-1) 162 | target = target.reshape(-1) 163 | return metrics.log_loss(target, probabilities, labels=labels) 164 | 165 | 166 | def entropy(p, dim=-1, keepdims=False): 167 | # exactly the same as scipy.stats.entropy() 168 | return -np.where(p > 0, p * np.log(p), [0.0]).sum(axis=dim, keepdims=keepdims) 169 | 170 | 171 | def _check_ndarray(obj): 172 | if not isinstance(obj, np.ndarray): 173 | raise ValueError("object of type '{}' must be '{}'".format(type(obj).__name__, np.ndarray.__name__)) 174 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import os 3 | def loss_plot(args,loss): 4 | num = args.end_epoch 5 | x = [i for i in range(num)] 6 | plot_save_path = r'results/plot/' 7 | if not os.path.exists(plot_save_path): 8 | os.makedirs(plot_save_path) 9 | save_loss = plot_save_path+str(args.model_name)+'_'+str(args.batch_size)+'_'+str(args.dataset)+'_'+str(args.end_epoch)+'_loss.jpg' 10 | plt.figure() 11 | plt.plot(x,loss,label='loss') 12 | plt.legend() 13 | plt.savefig(save_loss) 14 | 15 | def metrics_plot(arg,name,*args): 16 | num = arg.end_epoch 17 | names = name.split('&') 18 | metrics_value = args 19 | i=0 20 | x = [i for i in range(num)] 21 | plot_save_path = r'results/plot/' 22 | if not os.path.exists(plot_save_path): 23 | os.makedirs(plot_save_path) 24 | save_metrics = plot_save_path + str(arg.model_name) + '_' + str(arg.batch_size) + '_' + str(arg.dataset) + '_' + str(arg.end_epoch) + '_'+name+'.jpg' 25 | plt.figure() 26 | for l in metrics_value: 27 | plt.plot(x,l,label=str(names[i])) 28 | #plt.scatter(x,l,label=str(l)) 29 | i+=1 30 | plt.legend() 31 | plt.savefig(save_metrics) -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.backends.cudnn as cudnn 8 | import SimpleITK as sitk 9 | import cv2 10 | import math 11 | from medpy.metric import binary 12 | from sklearn.externals import joblib 13 | from binary import assd 14 | cudnn.benchmark = True 15 | # import models.criterions as U_entropy 16 | import numpy as np 17 | import nibabel as nib 18 | import imageio 19 | from test_uncertainty import ece_binary,UncertaintyAndCorrectionEvalNumpy,Normalized_U 20 | 21 | def cal_ueo(to_evaluate,thresholds): 22 | UEO = [] 23 | for threshold in thresholds: 24 | results = dict() 25 | metric = UncertaintyAndCorrectionEvalNumpy(threshold) 26 | metric(to_evaluate,results) 27 | ueo = results['corrected_add_dice'] 28 | UEO.append(ueo) 29 | max_UEO = max(UEO) 30 | return max_UEO 31 | 32 | def cal_ece(logits,targets): 33 | # ece_total = 0 34 | logit = logits 35 | target = targets 36 | pred = F.softmax(logit, dim=0) 37 | pc = pred.cpu().detach().numpy() 38 | pc = pc.argmax(0) 39 | ece = ece_binary(pc, target) 40 | return ece 41 | 42 | def cal_ece_our(preds,targets): 43 | # ece_total = 0 44 | target = targets 45 | pc = preds.cpu().detach().numpy() 46 | ece = ece_binary(pc, target) 47 | return ece 48 | 49 | def Uentropy(logits,c): 50 | # c = 4 51 | # logits = torch.randn(1, 4, 240, 240,155).cuda() 52 | pc = F.softmax(logits, dim=1) # 1 4 240 240 155 53 | logits = F.log_softmax(logits, dim=1) # 1 4 240 240 155 54 | u_all = -pc * logits / math.log(c) 55 | NU = torch.sum(u_all[:,1:u_all.shape[1],:,:], dim=1) 56 | return NU 57 | 58 | def Uentropy_our(logits,c): 59 | # c = 4 60 | # logits = torch.randn(1, 4, 240, 240,155).cuda() 61 | pc = logits # 1 4 240 240 155 62 | logpc = torch.log(logits) # 1 4 240 240 155 63 | # u_all = -pc * logpc / c 64 | u_all = -pc * logpc / math.log(c) 65 | # max_u = torch.max(u_all) 66 | # min_u = torch.min(u_all) 67 | # NU1 = torch.sum(u_all, dim=1) 68 | # k = u_all.shape[1] 69 | # NU2 = torch.sum(u_all[:, 0:u_all.shape[1]-1, :, :], dim=1) 70 | NU = torch.sum(u_all[:,1:u_all.shape[1],:,:], dim=1) 71 | return NU 72 | 73 | def one_hot(ori, classes): 74 | 75 | batch, h, w, d = ori.size() 76 | new_gd = torch.zeros((batch, classes, h, w, d), dtype=ori.dtype).cuda() 77 | for j in range(classes): 78 | index_list = (ori == j).nonzero() 79 | 80 | for i in range(len(index_list)): 81 | batch, height, width, depth = index_list[i] 82 | new_gd[batch, j, height, width, depth] = 1 83 | 84 | return new_gd.float() 85 | 86 | def tailor_and_concat(x, model): 87 | temp = [] 88 | 89 | temp.append(x[..., :128, :128, :128]) 90 | temp.append(x[..., :128, 112:240, :128]) 91 | temp.append(x[..., 112:240, :128, :128]) 92 | temp.append(x[..., 112:240, 112:240, :128]) 93 | temp.append(x[..., :128, :128, 27:155]) 94 | temp.append(x[..., :128, 112:240, 27:155]) 95 | temp.append(x[..., 112:240, :128, 27:155]) 96 | temp.append(x[..., 112:240, 112:240, 27:155]) 97 | 98 | if x.shape[1] == 1: 99 | y = torch.cat((x.clone(), x.clone(), x.clone(), x.clone()), 1) 100 | elif x.shape[1] == 4: 101 | y = x.clone() 102 | else: 103 | y = torch.cat((x.clone(), x.clone()), 1) 104 | 105 | for i in range(len(temp)): 106 | temp[i] = model(temp[i]) 107 | # .squeeze(0) 108 | # l= temp[0].unsqueeze(0) 109 | y[..., :128, :128, :128] = temp[0] 110 | y[..., :128, 128:240, :128] = temp[1][..., :, 16:128, :] 111 | y[..., 128:240, :128, :128] = temp[2][..., 16:128, :, :] 112 | y[..., 128:240, 128:240, :128] = temp[3][..., 16:128, 16:128, :] 113 | y[..., :128, :128, 128:155] = temp[4][..., 96:123] 114 | y[..., :128, 128:240, 128:155] = temp[5][..., :, 16:128, 96:123] 115 | y[..., 128:240, :128, 128:155] = temp[6][..., 16:128, :, 96:123] 116 | y[..., 128:240, 128:240, 128:155] = temp[7][..., 16:128, 16:128, 96:123] 117 | 118 | return y[..., :155] 119 | 120 | def hausdorff_distance(lT,lP): 121 | labelPred=sitk.GetImageFromArray(lP, isVector=False) 122 | labelTrue=sitk.GetImageFromArray(lT, isVector=False) 123 | hausdorffcomputer=sitk.HausdorffDistanceImageFilter() 124 | hausdorffcomputer.Execute(labelTrue>0.5,labelPred>0.5) 125 | return hausdorffcomputer.GetAverageHausdorffDistance()#hausdorffcomputer.GetHausdorffDistance() 126 | 127 | def hd_score(o,t, eps=1e-8): 128 | if (o.sum()==0) | (t.sum()==0): 129 | hd = eps, 130 | else: 131 | #ret += hausdorff_distance(wt_mask, wt_pb), 132 | hd = binary.hd95(o, t, voxelspacing=None), 133 | 134 | return hd 135 | 136 | def dice_score(o, t, eps=1e-8): 137 | if (o.sum()==0) | (t.sum()==0): 138 | dice = eps 139 | else: 140 | num = 2*(o*t).sum() + eps 141 | den = o.sum() + t.sum() + eps 142 | dice = num/den 143 | return dice 144 | 145 | 146 | def mIOU(o, t, eps=1e-8): 147 | num = (o*t).sum() + eps 148 | den = (o | t).sum() + eps 149 | return num/den 150 | 151 | def assd_score(o, t): 152 | s = assd(o, t) 153 | return s 154 | 155 | def softmax_mIOU_score(output, target): 156 | mIOU_score = [] 157 | mIOU_score.append(mIOU(o=(output==1),t=(target==1))) 158 | mIOU_score.append(mIOU(o=(output==2),t=(target==2))) 159 | mIOU_score.append(mIOU(o=(output==3),t=(target==3))) 160 | return mIOU_score 161 | 162 | def softmax_output_hd(output, target): 163 | ret = [] 164 | 165 | # whole (label: 1 ,2 ,3) 166 | o = output > 0; t = target > 0 # ce 167 | ret += hd_score(o, t), 168 | # core (tumor core 1 and 3) 169 | o = (output == 1) | (output == 3) 170 | t = (target == 1) | (target == 3) 171 | ret += hd_score(o, t), 172 | # active (enhanccing tumor region 1 )# 3 173 | o = (output == 3);t = (target == 3) 174 | ret += hd_score(o, t), 175 | 176 | return ret 177 | 178 | def softmax_output_assd(output, target): 179 | ret = [] 180 | 181 | # whole (label: 1 ,2 ,3) 182 | wt_o = output > 0; wt_t = target > 0 # ce 183 | ret += assd_score(wt_o, wt_t), 184 | # core (tumor core 1 and 3) 185 | tc_o = (output == 1) | (output == 3) 186 | tc_t = (target == 1) | (target == 3) 187 | ret += assd_score(tc_o, tc_t), 188 | # active (enhanccing tumor region 1 )# 3 189 | et_o = (output == 3);et_t = (target == 3) 190 | ret += assd_score(et_o, et_t), 191 | 192 | return ret 193 | 194 | def softmax_output_dice(output, target): 195 | ret = [] 196 | 197 | # whole (label: 1 ,2 ,3) 198 | o = output > 0; t = target > 0 # ce 199 | # print(o.shape) 200 | # print(t.shape) 201 | ret += dice_score(o, t), 202 | # core (tumor core 1 and 3) 203 | o = (output == 1) | (output == 3) 204 | t = (target == 1) | (target == 3) 205 | ret += dice_score(o, t), 206 | # active (enhanccing tumor region 1 )# 3 207 | o = (output == 3);t = (target == 3) 208 | ret += dice_score(o, t), 209 | 210 | return ret 211 | 212 | 213 | keys = 'whole', 'core', 'enhancing', 'loss' 214 | 215 | def validate_softmax( 216 | save_dir, 217 | best_dice, 218 | current_epoch, 219 | end_epoch, 220 | save_freq, 221 | valid_loader, 222 | model, 223 | multimodel, 224 | Net_name, 225 | names=None,# The names of the patients orderly! 226 | ): 227 | 228 | H, W, T = 240, 240, 160 229 | 230 | runtimes = [] 231 | dice_total = 0 232 | iou_total = 0 233 | num = len(valid_loader) 234 | 235 | for i, data in enumerate(valid_loader): 236 | print('-------------------------------------------------------------------') 237 | msg = 'Subject {}/{}, '.format(i + 1, len(valid_loader)) 238 | x, target = data 239 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 240 | x = x.to(device) 241 | # target = target.to(device) 242 | target = torch.squeeze(target).cpu().numpy() 243 | torch.cuda.synchronize() # add the code synchronize() to correctly count the runtime. 244 | start_time = time.time() 245 | logit = tailor_and_concat(x, model) 246 | 247 | torch.cuda.synchronize() 248 | elapsed_time = time.time() - start_time 249 | logging.info('Single sample test time consumption {:.2f} minutes!'.format(elapsed_time / 60)) 250 | runtimes.append(elapsed_time) 251 | 252 | output = F.softmax(logit, dim=1) 253 | output = output[0, :, :H, :W, :T].cpu().detach().numpy() 254 | output = output.argmax(0) 255 | iou_res = softmax_mIOU_score(output, target[:, :, :155]) 256 | dice_res = softmax_output_dice(output,target[:,:,:155]) 257 | # hd_res = softmax_output_hd(output, target[:, :, :155]) 258 | dice_total += dice_res[1] 259 | iou_total += iou_res[1] 260 | name = str(i) 261 | if names: 262 | name = names[i] 263 | msg += '{:>20}, '.format(name) 264 | print(msg) 265 | print('current_dice:{}'.format(dice_res)) 266 | # print('current_dice:{},hd_res:{}'.format(dice_res,hd_res)) 267 | aver_dice = dice_total / num 268 | aver_iou = iou_total / num 269 | if (current_epoch + 1) % int(save_freq) == 0: 270 | if aver_dice > best_dice\ 271 | or (current_epoch + 1) % int(end_epoch - 1) == 0 \ 272 | or (current_epoch + 1) % int(end_epoch - 2) == 0 \ 273 | or (current_epoch + 1) % int(end_epoch - 3) == 0: 274 | print('aver_dice:{} > best_dice:{}'.format(aver_dice, best_dice)) 275 | logging.info('aver_dice:{} > best_dice:{}'.format(aver_dice, best_dice)) 276 | logging.info('===========>save best model!') 277 | best_dice = aver_dice 278 | print('===========>save best model!') 279 | file_name = os.path.join(save_dir, Net_name +'_' + multimodel + '_epoch_{}.pth'.format(current_epoch)) 280 | torch.save({ 281 | 'epoch': current_epoch, 282 | 'state_dict': model.state_dict(), 283 | }, 284 | file_name) 285 | print('runtimes:', sum(runtimes)/len(runtimes)) 286 | 287 | return best_dice,aver_dice,aver_iou 288 | 289 | def test_softmax( 290 | test_loader, 291 | model, 292 | multimodel, 293 | Net_name, 294 | Variance, 295 | load_file, 296 | savepath='', # when in validation set, you must specify the path to save the 'nii' segmentation results here 297 | names=None, # The names of the patients orderly! 298 | verbose=False, 299 | use_TTA=False, # Test time augmentation, False as default! 300 | save_format=None, # ['nii','npy'], use 'nii' as default. Its purpose is for submission. 301 | # snapshot=False, # for visualization. Default false. It is recommended to generate the visualized figures. 302 | # visual='', # the path to save visualization 303 | ): 304 | 305 | H, W, T = 240, 240, 155 306 | # model.eval() 307 | 308 | runtimes = [] 309 | dice_total_WT = 0 310 | dice_total_TC = 0 311 | dice_total_ET = 0 312 | hd_total_WT = 0 313 | hd_total_TC = 0 314 | hd_total_ET = 0 315 | assd_total_WT = 0 316 | assd_total_TC = 0 317 | assd_total_ET = 0 318 | 319 | noise_dice_total_WT = 0 320 | noise_dice_total_TC = 0 321 | noise_dice_total_ET = 0 322 | noise_hd_total_WT = 0 323 | noise_hd_total_TC = 0 324 | noise_hd_total_ET = 0 325 | noise_assd_total_WT = 0 326 | noise_assd_total_TC = 0 327 | noise_assd_total_ET = 0 328 | mean_uncertainty_total = 0 329 | noise_mean_uncertainty_total = 0 330 | certainty_total = 0 331 | noise_certainty_total = 0 332 | num = len(test_loader) 333 | mne_total = 0 334 | noise_mne_total = 0 335 | ece_total = 0 336 | noise_ece_total = 0 337 | ece = 0 338 | noise_ece = 0 339 | ueo_total = 0 340 | noise_ueo_total = 0 341 | for i, data in enumerate(test_loader): 342 | print('-------------------------------------------------------------------') 343 | msg = 'Subject {}/{}, '.format(i+1, len(test_loader)) 344 | x, target = data 345 | # noise_m = torch.randn_like(x) * Variance 346 | # noise = torch.clamp(torch.randn_like(x) * Variance, -Variance * 2, Variance * 2) 347 | # noise = torch.clamp(torch.randn_like(x) * Variance, -Variance, Variance) 348 | # noise = torch.clamp(torch.randn_like(x) * Variance) 349 | # noised_x = x + noise_m 350 | 351 | noise_m = torch.randn_like(x) * Variance 352 | noised_x = x + noise_m 353 | # if multimodel=='both': 354 | # noised_x[:, 0, ...] = x[:, 0, ...] 355 | x.cuda() 356 | noised_x.cuda() 357 | target = torch.squeeze(target).cpu().numpy() 358 | mean_uncertainty = torch.zeros(0) 359 | noised_mean_uncertainty = torch.zeros(0) 360 | # output = np.zeros((4, x.shape[2], x.shape[3], 155),dtype='float32') 361 | # noised_output = np.zeros((4, x.shape[2], x.shape[3], 155),dtype='float32') 362 | pc = np.zeros((x.shape[2], x.shape[3], 155),dtype='float32') 363 | noised_pc = np.zeros((x.shape[2], x.shape[3], 155),dtype='float32') 364 | if not use_TTA: 365 | # torch.cuda.synchronize() # add the code synchronize() to correctly count the runtime. 366 | # start_time = time.time() 367 | model_len = 2.0 # two modality or four modalityv 368 | if Net_name =='Udrop': 369 | T_drop = 2 370 | uncertainty = torch.zeros(1, x.shape[2], x.shape[3], 155) 371 | noised_uncertainty = torch.zeros(1, x.shape[2], x.shape[3], 155) 372 | for j in range(T_drop): 373 | print('dropout time:{}'.format(j)) 374 | logit = tailor_and_concat(x, model) # 1 4 240 240 155 375 | logit_noise = tailor_and_concat(noised_x, model) # 1 4 240 240 155 376 | uncertainty += Uentropy(logit, 4) 377 | noised_uncertainty += Uentropy(logit_noise, 4) 378 | logit = F.softmax(logit, dim=1) 379 | output = logit / model_len 380 | output = output[0, :, :H, :W, :T].cpu().detach().numpy() 381 | pc += output.argmax(0) 382 | # for noise 383 | logit_noise = F.softmax(logit_noise, dim=1) 384 | noised_output = logit_noise / model_len 385 | noised_output = noised_output[0, :, :H, :W, :T].cpu().detach().numpy() 386 | noised_pc += noised_output.argmax(0) 387 | pc = pc / T_drop 388 | noised_pc = noised_pc / T_drop 389 | uncertainty = torch.squeeze(uncertainty) / T_drop 390 | noised_uncertainty = torch.squeeze(noised_uncertainty) / T_drop 391 | 392 | # logit = logit/T_drop 393 | # logit_noise= logit_noise/T_drop 394 | # Udropout_uncertainty=joblib.load('Udropout_uncertainty.pkl') 395 | else: 396 | logit = tailor_and_concat(x, model) 397 | logit = F.softmax(logit, dim=1) 398 | output = logit / model_len 399 | output = output[0, :, :H, :W, :T].cpu().detach().numpy() 400 | pc = output.argmax(0) 401 | # for input noise 402 | logit_noise = tailor_and_concat(noised_x, model) 403 | logit_noise = F.softmax(logit_noise, dim=1) 404 | noised_output = logit_noise / model_len 405 | noised_output = noised_output[0, :, :H, :W, :T].cpu().detach().numpy() 406 | noised_pc = noised_output.argmax(0) 407 | uncertainty = Uentropy(logit, 4) 408 | noised_uncertainty = Uentropy(logit_noise, 4) 409 | mean_uncertainty = torch.mean(uncertainty) 410 | noised_mean_uncertainty = torch.mean(noised_uncertainty) 411 | print('current_mean_uncertainty:{} ; current_noised_mean_uncertainty:{}'.format(mean_uncertainty, noised_mean_uncertainty)) 412 | if Net_name == 'Udrop': 413 | joblib.dump({'pc': pc, 414 | 'noised_pc': noised_pc,'noised_uncertainty': noised_uncertainty, 415 | 'uncertainty': uncertainty}, 'Udropout_uncertainty_{}.pkl'.format(i)) 416 | 417 | # Udropout_uncertainty = joblib.load('Udropout_uncertainty.pkl') 418 | 419 | # lnear = F.softplus(logit) 420 | # torch.cuda.synchronize() 421 | # elapsed_time = time.time() - start_time 422 | # logging.info('Single sample test time consumption {:.2f} minutes!'.format(elapsed_time/60)) 423 | # runtimes.append(elapsed_time) 424 | 425 | 426 | # if multimodel == 'both': 427 | # model_len = 2.0 # two modality or four modality 428 | # # logit = F.softmax(logit, dim=1) 429 | # # output = logit / model_len 430 | # 431 | # # for noise 432 | # logit_noise = F.softmax(logit_noise, dim=1) 433 | # noised_output = logit_noise / model_len 434 | 435 | # load_file1 = load_file.replace('7998', '7996') 436 | # if os.path.isfile(load_file1): 437 | # checkpoint = torch.load(load_file1) 438 | # model.load_state_dict(checkpoint['state_dict']) 439 | # print('Successfully load checkpoint {}'.format(load_file1)) 440 | # logit = tailor_and_concat(x, model) 441 | # logit = F.softmax(logit, dim=1) 442 | # output += logit / model_len 443 | # load_file1 = load_file.replace('7998', '7997') 444 | # if os.path.isfile(load_file1): 445 | # checkpoint = torch.load(load_file1) 446 | # model.load_state_dict(checkpoint['state_dict']) 447 | # print('Successfully load checkpoint {}'.format(load_file1)) 448 | # logit = tailor_and_concat(x, model) 449 | # logit = F.softmax(logit, dim=1) 450 | # output += logit / model_len 451 | # load_file1 = load_file.replace('7998', '7999') 452 | # if os.path.isfile(load_file1): 453 | # checkpoint = torch.load(load_file1) 454 | # model.load_state_dict(checkpoint['state_dict']) 455 | # print('Successfully load checkpoint {}'.format(load_file1)) 456 | # logit = tailor_and_concat(x, model) 457 | # logit = F.softmax(logit, dim=1) 458 | # output += logit / model_len 459 | # else: 460 | # # output = F.softmax(logit, dim=1) 461 | # noised_output = F.softmax(logit_noise, dim=1) 462 | else: 463 | # x = x[..., :155] 464 | logit = F.softmax(tailor_and_concat(x, model), 1) # no flip 465 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2,)), model).flip(dims=(2,)), 1) # flip H 466 | logit += F.softmax(tailor_and_concat(x.flip(dims=(3,)), model).flip(dims=(3,)), 1) # flip W 467 | logit += F.softmax(tailor_and_concat(x.flip(dims=(4,)), model).flip(dims=(4,)), 1) # flip D 468 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 3)), model).flip(dims=(2, 3)), 1) # flip H, W 469 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 4)), model).flip(dims=(2, 4)), 1) # flip H, D 470 | logit += F.softmax(tailor_and_concat(x.flip(dims=(3, 4)), model).flip(dims=(3, 4)), 1) # flip W, D 471 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 3, 4)), model).flip(dims=(2, 3, 4)), 1) # flip H, W, D 472 | # for noise x 473 | noised_x = noised_x[..., :155] 474 | noised_logit = F.softmax(tailor_and_concat(noised_x, model), 1) # no flip 475 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2,)), model).flip(dims=(2,)), 1) # flip H 476 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(3,)), model).flip(dims=(3,)), 1) # flip W 477 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(4,)), model).flip(dims=(4,)), 1) # flip D 478 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2, 3)), model).flip(dims=(2, 3)), 1) # flip H, W 479 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2, 4)), model).flip(dims=(2, 4)), 1) # flip H, D 480 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(3, 4)), model).flip(dims=(3, 4)), 1) # flip W, D 481 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2, 3, 4)), model).flip(dims=(2, 3, 4)), 1) # flip H, W, D 482 | output = logit / 8.0 # mean 483 | noised_output = noised_logit / 8.0 # mean 484 | uncertainty = Uentropy(output, 4) 485 | noised_uncertainty = Uentropy(noised_output, 4) 486 | output = output[0, :, :H, :W, :T].cpu().detach().numpy() 487 | pc = output.argmax(0) 488 | noised_output = noised_output[0, :, :H, :W, :T].cpu().detach().numpy() 489 | noised_pc = noised_output.argmax(0) 490 | mean_uncertainty = torch.mean(uncertainty) 491 | noised_mean_uncertainty = torch.mean(noised_uncertainty) 492 | output = pc 493 | noised_output = noised_pc 494 | U_output = uncertainty.cpu().detach().numpy() 495 | NU_output = noised_uncertainty.cpu().detach().numpy() 496 | certainty_total += mean_uncertainty # mix _uncertainty mean_uncertainty mean_uncertainty_succ 497 | noise_certainty_total += noised_mean_uncertainty # noised_mix_uncertainty noised_mean_uncertainty noised_mean_uncertainty_succ 498 | # ece 499 | ece_total += ece 500 | noise_ece_total += noise_ece 501 | # ueo 502 | # target = torch.squeeze(target).cpu().numpy() 503 | thresholds = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95] 504 | to_evaluate = dict() 505 | to_evaluate['target'] = target[:, :, :155] 506 | u = torch.squeeze(uncertainty) 507 | U = u.cpu().detach().numpy() 508 | to_evaluate['prediction'] = output 509 | to_evaluate['uncertainty'] = U 510 | UEO = cal_ueo(to_evaluate, thresholds) 511 | ueo_total += UEO 512 | noise_to_evaluate = dict() 513 | noise_to_evaluate['target'] = target[:, :, :155] 514 | noise_u = torch.squeeze(noised_uncertainty) 515 | noise_U = noise_u.cpu().detach().numpy() 516 | noise_to_evaluate['prediction'] = noised_output 517 | noise_to_evaluate['uncertainty'] = noise_U 518 | noise_UEO = cal_ueo(noise_to_evaluate, thresholds) 519 | print('current_UEO:{};current_noise_UEO:{}; current_num:{}'.format(UEO, noise_UEO, i)) 520 | noise_ueo_total += noise_UEO 521 | # print(output.shape) 522 | # print(target.shape) 523 | # output = output[0, :, :H, :W, :T].cpu().detach().numpy() 524 | # output = output.argmax(0) 525 | # print(output.shape) 526 | # iou_res = softmax_mIOU_score(pc, target[:, :, :155]) 527 | hd_res = softmax_output_hd(pc, target[:, :, :155]) 528 | dice_res = softmax_output_dice(pc,target[:,:,:155]) 529 | assd_res = softmax_output_assd(pc,target[:, :, :155]) 530 | dice_total_WT += dice_res[0] 531 | dice_total_TC += dice_res[1] 532 | dice_total_ET += dice_res[2] 533 | hd_total_WT += hd_res[0][0] 534 | hd_total_TC += hd_res[1][0] 535 | hd_total_ET += hd_res[2][0] 536 | assd_total_WT += assd_res[0] 537 | assd_total_TC += assd_res[1] 538 | assd_total_ET += assd_res[2] 539 | 540 | # for noise_x 541 | noised_output = noised_pc 542 | # noise_iou_res = softmax_mIOU_score(noised_pc, target[:, :, :155]) 543 | noise_hd_res = softmax_output_hd(noised_pc, target[:, :, :155]) 544 | noise_dice_res = softmax_output_dice(noised_pc,target[:,:,:155]) 545 | noised_assd_res = softmax_output_assd(noised_pc, 546 | target[:, :, :155]) 547 | noise_dice_total_WT += noise_dice_res[0] 548 | noise_dice_total_TC += noise_dice_res[1] 549 | noise_dice_total_ET += noise_dice_res[2] 550 | noise_hd_total_WT += noise_hd_res[0][0] 551 | noise_hd_total_TC += noise_hd_res[1][0] 552 | noise_hd_total_ET += noise_hd_res[2][0] 553 | noise_assd_total_WT += noised_assd_res[0] 554 | noise_assd_total_TC += noised_assd_res[1] 555 | noise_assd_total_ET += noised_assd_res[2] 556 | mean_uncertainty_total += mean_uncertainty 557 | noise_mean_uncertainty_total += noised_mean_uncertainty 558 | name = str(i) 559 | if names: 560 | name = names[i] 561 | msg += '{:>20}, '.format(name) 562 | 563 | print(msg) 564 | snapshot= False # True 565 | if snapshot: 566 | """ --- grey figure---""" 567 | # Snapshot_img = np.zeros(shape=(H,W,T),dtype=np.uint8) 568 | # Snapshot_img[np.where(output[1,:,:,:]==1)] = 64 569 | # Snapshot_img[np.where(output[2,:,:,:]==1)] = 160 570 | # Snapshot_img[np.where(output[3,:,:,:]==1)] = 255 571 | """ --- colorful figure--- """ 572 | Snapshot_img = np.zeros(shape=(H, W, 3, T), dtype=np.float32) 573 | # K = [np.where(output[0,:,:,:] == 1)] 574 | Snapshot_img[:, :, 0, :][np.where(output == 1)] = 255 575 | Snapshot_img[:, :, 1, :][np.where(output == 2)] = 255 576 | Snapshot_img[:, :, 2, :][np.where(output == 3)] = 255 577 | 578 | Noise_Snapshot_img = np.zeros(shape=(H, W, 3, T), dtype=np.float32) 579 | # K = [np.where(output[0,:,:,:] == 1)] 580 | Noise_Snapshot_img[:, :, 0, :][np.where(noised_output == 1)] = 255 581 | Noise_Snapshot_img[:, :, 1, :][np.where(noised_output == 2)] = 255 582 | Noise_Snapshot_img[:, :, 2, :][np.where(noised_output == 3)] = 255 583 | # target_img = np.zeros(shape=(H, W, 3, T), dtype=np.float32) 584 | # K = [np.where(output[0,:,:,:] == 1)] 585 | # target_img[:, :, 0, :][np.where(Otarget == 1)] = 255 586 | # target_img[:, :, 1, :][np.where(Otarget == 2)] = 255 587 | # target_img[:, :, 2, :][np.where(Otarget == 3)] = 255 588 | 589 | for frame in range(T): 590 | if not os.path.exists(os.path.join(savepath, str(Net_name), str(Variance), name)): 591 | os.makedirs(os.path.join(savepath, str(Net_name), str(Variance), name)) 592 | 593 | # scipy.misc.imsave(os.path.join(visual, name, str(frame)+'.png'), Snapshot_img[:, :, :, frame]) 594 | imageio.imwrite(os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '.png'), 595 | Snapshot_img[:, :, :, frame]) 596 | imageio.imwrite(os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '_noised.png'), 597 | Noise_Snapshot_img[:, :, :, frame]) 598 | # imageio.imwrite(os.path.join(savepath, str(Net_name), name, str(frame) + '_gt.png'), 599 | # target_img[:, :, :, frame]) 600 | # im0 = Image.fromarray(U_output[:, :, frame]) 601 | # im1 = Image.fromarray(U_output[:, :, frame]) 602 | # im2 = Image.fromarray(U_output[:, :, frame]) 603 | # im0 = im0.convert('RGB') 604 | # im1 = im1.convert('RGB') 605 | # im2 = im2.convert('RGB') 606 | # im0.save(os.path.join(savepath, name, str(frame) + '_uncertainty.png')) 607 | # im1.save(os.path.join(savepath, name, str(frame) + '_input_T1.png')) 608 | # im2.save(os.path.join(savepath, name, str(frame) + '_input_T2.png')) 609 | # U_CV = cv2.cvtColor(U_output[:, :, frame], cv2.COLOR_GRAY2BGR) 610 | # U_heatmap = cv2.applyColorMap(U_CV, cv2.COLORMAP_JET) 611 | # cv2.imwrite(os.path.join(savepath, name, str(frame) + '_uncertainty.png'), 612 | # U_heatmap) 613 | # NU_CV = cv2.cvtColor(NU_output[:, :, frame], cv2.COLOR_GRAY2BGR) 614 | # NU_heatmap = cv2.applyColorMap(NU_CV, cv2.COLORMAP_JET) 615 | # cv2.imwrite(os.path.join(savepath, name, str(frame) + '_noised_uncertainty.png'), 616 | # NU_heatmap) 617 | imageio.imwrite(os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '_uncertainty.png'), 618 | U_output[:, :, frame]) 619 | imageio.imwrite( 620 | os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '_noised_uncertainty.png'), 621 | NU_output[:, :, frame]) 622 | U_img = cv2.imread(os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '_uncertainty.png')) 623 | U_heatmap = cv2.applyColorMap(U_img, cv2.COLORMAP_JET) 624 | cv2.imwrite( 625 | os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '_colormap_uncertainty.png'), 626 | U_heatmap) 627 | NU_img = cv2.imread( 628 | os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '_noised_uncertainty.png')) 629 | NU_heatmap = cv2.applyColorMap(NU_img, cv2.COLORMAP_JET) 630 | cv2.imwrite( 631 | os.path.join(savepath, str(Net_name), str(Variance), name, str(frame) + '_colormap_noised_uncertainty.png'), 632 | NU_heatmap) 633 | aver_ueo = ueo_total / num 634 | aver_noise_ueo = noise_ueo_total/num 635 | aver_certainty = certainty_total / num 636 | aver_noise_certainty = noise_certainty_total / num 637 | aver_dice_WT = dice_total_WT / num 638 | aver_dice_TC = dice_total_TC / num 639 | aver_dice_ET = dice_total_ET / num 640 | aver_hd_WT = hd_total_WT / num 641 | aver_hd_TC = hd_total_TC / num 642 | aver_hd_ET = hd_total_ET / num 643 | aver_noise_dice_WT = noise_dice_total_WT / num 644 | aver_noise_dice_TC = noise_dice_total_TC / num 645 | aver_noise_dice_ET = noise_dice_total_ET / num 646 | aver_noise_hd_WT = noise_hd_total_WT / num 647 | aver_noise_hd_TC = noise_hd_total_TC / num 648 | aver_noise_hd_ET = noise_hd_total_ET / num 649 | aver_assd_WT = assd_total_WT / num 650 | aver_assd_TC = assd_total_TC / num 651 | aver_assd_ET = assd_total_ET / num 652 | aver_noise_assd_WT = noise_assd_total_WT / num 653 | aver_noise_assd_TC = noise_assd_total_TC / num 654 | aver_noise_assd_ET = noise_assd_total_ET / num 655 | aver_noise_mean_uncertainty = noise_mean_uncertainty_total/num 656 | aver_mean_uncertainty = mean_uncertainty_total/num 657 | print('aver_dice_WT=%f,aver_dice_TC = %f,aver_dice_ET = %f' % (aver_dice_WT*100,aver_dice_TC*100, aver_dice_ET*100)) 658 | print('aver_noise_dice_WT=%f,aver_noise_dice_TC = %f,aver_noise_dice_ET = %f' % (aver_noise_dice_WT*100, aver_noise_dice_TC*100, aver_noise_dice_ET*100)) 659 | print('aver_hd_WT=%f,aver_hd_TC = %f,aver_hd_ET = %f' % (aver_hd_WT,aver_hd_TC, aver_hd_ET)) 660 | print('aver_noise_hd_WT=%f,aver_noise_hd_TC = %f,aver_noise_hd_ET = %f' % (aver_noise_hd_WT, aver_noise_hd_TC, aver_noise_hd_ET)) 661 | print('aver_assd_WT=%f,aver_assd_TC = %f,aver_assd_ET = %f' % (aver_assd_WT, aver_assd_TC, aver_assd_ET)) 662 | print('aver_noise_assd_WT=%f,aver_noise_assd_TC = %f,aver_noise_assd_ET = %f' % ( 663 | aver_noise_assd_WT, aver_noise_assd_TC, aver_noise_assd_ET)) 664 | # print('aver_noise_mean_uncertainty=%f,aver_mean_uncertainty = %f' % (aver_noise_mean_uncertainty,aver_mean_uncertainty)) 665 | print('aver_noise_mean_uncertainty=%f,aver_mean_uncertainty = %f' % (aver_noise_mean_uncertainty,aver_mean_uncertainty)) 666 | print('aver_certainty=%f,aver_noise_certainty = %f' % (aver_certainty, aver_noise_certainty)) 667 | # print('aver_mne=%f,aver_noise_mne = %f' % (aver_mne, aver_noise_mne)) 668 | # print('aver_ece=%f,aver_noise_ece = %f' % (aver_ece, aver_noise_ece)) 669 | print('aver_ueo=%f,aver_noise_ueo = %f' % (aver_ueo, aver_noise_ueo)) 670 | logging.info( 671 | 'aver_dice_WT=%f,aver_dice_TC = %f,aver_dice_ET = %f' % (aver_dice_WT*100, aver_dice_TC*100, aver_dice_ET*100)) 672 | logging.info('aver_noise_dice_WT=%f,aver_noise_dice_TC = %f,aver_noise_dice_ET = %f' % ( 673 | aver_noise_dice_WT*100, aver_noise_dice_TC*100, aver_noise_dice_ET*100)) 674 | logging.info('aver_hd_WT=%f,aver_hd_TC = %f,aver_hd_ET = %f' % ( 675 | aver_hd_WT, aver_hd_TC, aver_hd_ET)) 676 | logging.info('aver_noise_hd_WT=%f,aver_noise_hd_TC = %f,aver_noise_hd_ET = %f' % ( 677 | aver_noise_hd_WT, aver_noise_hd_TC, aver_noise_hd_ET)) 678 | logging.info('aver_assd_WT=%f,aver_assd_TC = %f,aver_assd_ET = %f' % (aver_assd_WT, aver_assd_TC, aver_assd_ET)) 679 | logging.info('aver_noise_assd_WT=%f,aver_noise_assd_TC = %f,aver_noise_assd_ET = %f' % ( 680 | aver_noise_assd_WT, aver_noise_assd_TC, aver_noise_assd_ET)) 681 | logging.info('aver_noise_mean_uncertainty=%f,aver_mean_uncertainty = %f' % ( 682 | aver_noise_mean_uncertainty, aver_mean_uncertainty)) 683 | logging.info('aver_ueo=%f,aver_noise_ueo = %f' % (aver_ueo, aver_noise_ueo)) 684 | # return [aver_dice_WT,aver_dice_TC,aver_dice_ET],[aver_noise_dice_WT,aver_noise_dice_TC,aver_noise_dice_ET] 685 | return [aver_dice_WT,aver_dice_TC,aver_dice_ET],[aver_noise_dice_WT,aver_noise_dice_TC,aver_noise_dice_ET],[aver_hd_WT,aver_hd_TC,aver_hd_ET],[aver_noise_hd_WT,aver_noise_hd_TC,aver_noise_hd_ET] 686 | 687 | def testensemblemax( 688 | test_loader, 689 | model, 690 | multimodel, 691 | Net_name, 692 | Variance, 693 | load_file, 694 | savepath='', # when in validation set, you must specify the path to save the 'nii' segmentation results here 695 | names=None, # The names of the patients orderly! 696 | verbose=False, 697 | use_TTA=False, # Test time augmentation, False as default! 698 | save_format=None, # ['nii','npy'], use 'nii' as default. Its purpose is for submission. 699 | # snapshot=False, # for visualization. Default false. It is recommended to generate the visualized figures. 700 | # visual='', # the path to save visualization 701 | ): 702 | 703 | H, W, T = 240, 240, 160 704 | # model.eval() 705 | 706 | runtimes = [] 707 | dice_total_WT = 0 708 | dice_total_TC = 0 709 | dice_total_ET = 0 710 | hd_total_WT = 0 711 | hd_total_TC = 0 712 | hd_total_ET = 0 713 | assd_total_WT = 0 714 | assd_total_TC = 0 715 | assd_total_ET = 0 716 | noise_dice_total_WT = 0 717 | noise_dice_total_TC = 0 718 | noise_dice_total_ET = 0 719 | noise_hd_total_WT = 0 720 | noise_hd_total_TC = 0 721 | noise_hd_total_ET = 0 722 | noise_assd_total_WT = 0 723 | noise_assd_total_TC = 0 724 | noise_assd_total_ET = 0 725 | mean_uncertainty_total = 0 726 | noise_mean_uncertainty_total = 0 727 | num = len(test_loader) 728 | 729 | for i, data in enumerate(test_loader): 730 | print('-------------------------------------------------------------------') 731 | msg = 'Subject {}/{}, '.format(i + 1, len(test_loader)) 732 | x, target = data 733 | noise_m = torch.randn_like(x) * Variance 734 | # noise = torch.clamp(torch.randn_like(x) * Variance, -Variance * 2, Variance * 2) 735 | # noise = torch.clamp(torch.randn_like(x) * Variance, -Variance, Variance) 736 | # noise = torch.clamp(torch.randn_like(x) * Variance) 737 | noised_x = x + noise_m 738 | x.cuda() 739 | noised_x.cuda() 740 | target = torch.squeeze(target).cpu().numpy() 741 | 742 | if not use_TTA: 743 | torch.cuda.synchronize() # add the code synchronize() to correctly count the runtime. 744 | start_time = time.time() 745 | logit = torch.zeros(x.shape[0], 4, x.shape[2], x.shape[3], 155) 746 | logit_noise = torch.zeros(x.shape[0], 4, x.shape[2], x.shape[3], 155) 747 | # load ensemble models 748 | for j in range(10): 749 | print('ensemble model:{}'.format(i)) 750 | logit += tailor_and_concat(x, model[j]) 751 | logit_noise += tailor_and_concat(noised_x, model[j]) 752 | # calculate ensemble uncertainty by normalized entropy 753 | uncertainty = Uentropy(logit/10, 4) 754 | noised_uncertainty = Uentropy(logit_noise/10, 4) 755 | 756 | U_output = torch.squeeze(uncertainty).cpu().detach().numpy() 757 | noised_U_output = torch.squeeze(noised_uncertainty).cpu().detach().numpy() 758 | joblib.dump({'logit': logit, 'logit_noise': logit_noise, 'uncertainty': uncertainty, 759 | 'noised_uncertainty': noised_uncertainty, 'U_output': U_output, 760 | 'noised_U_output': noised_U_output}, 'Uensemble_uncertainty_{}.pkl'.format(i)) 761 | # lnear = F.softplus(logit) 762 | torch.cuda.synchronize() 763 | elapsed_time = time.time() - start_time 764 | logging.info('Single sample test time consumption {:.2f} minutes!'.format(elapsed_time/60)) 765 | runtimes.append(elapsed_time) 766 | output = F.softmax(logit/10, dim=1) 767 | noised_output = F.softmax(logit_noise/10, dim=1) 768 | 769 | 770 | else: 771 | x = x[..., :155] 772 | logit = F.softmax(tailor_and_concat(x, model), 1) # no flip 773 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2,)), model).flip(dims=(2,)), 1) # flip H 774 | logit += F.softmax(tailor_and_concat(x.flip(dims=(3,)), model).flip(dims=(3,)), 1) # flip W 775 | logit += F.softmax(tailor_and_concat(x.flip(dims=(4,)), model).flip(dims=(4,)), 1) # flip D 776 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 3)), model).flip(dims=(2, 3)), 1) # flip H, W 777 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 4)), model).flip(dims=(2, 4)), 1) # flip H, D 778 | logit += F.softmax(tailor_and_concat(x.flip(dims=(3, 4)), model).flip(dims=(3, 4)), 1) # flip W, D 779 | logit += F.softmax(tailor_and_concat(x.flip(dims=(2, 3, 4)), model).flip(dims=(2, 3, 4)), 1) # flip H, W, D 780 | # for noise x 781 | noised_x = noised_x[..., :155] 782 | noised_logit = F.softmax(tailor_and_concat(noised_x, model), 1) # no flip 783 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2,)), model).flip(dims=(2,)), 1) # flip H 784 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(3,)), model).flip(dims=(3,)), 1) # flip W 785 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(4,)), model).flip(dims=(4,)), 1) # flip D 786 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2, 3)), model).flip(dims=(2, 3)), 1) # flip H, W 787 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2, 4)), model).flip(dims=(2, 4)), 1) # flip H, D 788 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(3, 4)), model).flip(dims=(3, 4)), 1) # flip W, D 789 | noised_logit += F.softmax(tailor_and_concat(noised_x.flip(dims=(2, 3, 4)), model).flip(dims=(2, 3, 4)), 1) # flip H, W, D 790 | output = logit / 8.0 # mean 791 | noised_output = noised_logit / 8.0 # mean 792 | uncertainty = Uentropy(logit, 4) 793 | noised_uncertainty = Uentropy(noised_output, 4) 794 | mean_uncertainty = torch.mean(uncertainty) 795 | noised_mean_uncertainty = torch.mean(noised_uncertainty) 796 | output = output[0, :, :H, :W, :T].cpu().detach().numpy() 797 | output = output.argmax(0) 798 | # iou_res = softmax_mIOU_score(output, target[:, :, :155]) 799 | hd_res = softmax_output_hd(output, target[:, :, :155]) 800 | dice_res = softmax_output_dice(output, target[:, :, :155]) 801 | assd_res = softmax_output_assd(output, target[:, :, :155]) 802 | dice_total_WT += dice_res[0] 803 | dice_total_TC += dice_res[1] 804 | dice_total_ET += dice_res[2] 805 | hd_total_WT += hd_res[0][0] 806 | hd_total_TC += hd_res[1][0] 807 | hd_total_ET += hd_res[2][0] 808 | assd_total_WT += assd_res[0] 809 | assd_total_TC += assd_res[1] 810 | assd_total_ET += assd_res[2] 811 | # for noise_x 812 | noised_output = noised_output[0, :, :H, :W, :T].cpu().detach().numpy() 813 | noised_output = noised_output.argmax(0) 814 | noise_assd_res = softmax_output_assd(noised_output, target[:, :, :155]) 815 | noise_hd_res = softmax_output_hd(noised_output, target[:, :, :155]) 816 | noise_dice_res = softmax_output_dice(noised_output, target[:, :, :155]) 817 | 818 | noise_dice_total_WT += noise_dice_res[0] 819 | noise_dice_total_TC += noise_dice_res[1] 820 | noise_dice_total_ET += noise_dice_res[2] 821 | noise_hd_total_WT += noise_hd_res[0][0] 822 | noise_hd_total_TC += noise_hd_res[1][0] 823 | noise_hd_total_ET += noise_hd_res[2][0] 824 | noise_assd_total_WT += noise_assd_res[0] 825 | noise_assd_total_TC += noise_assd_res[1] 826 | noise_assd_total_ET += noise_assd_res[2] 827 | mean_uncertainty_total += mean_uncertainty 828 | noise_mean_uncertainty_total += noised_mean_uncertainty 829 | name = str(i) 830 | if names: 831 | name = names[i] 832 | msg += '{:>20}, '.format(name) 833 | 834 | print(msg) 835 | print('current_dice:{} ; current_noised_dice:{}'.format(dice_res, noise_dice_res)) 836 | print('current_uncertainty:{} ; current_noised_uncertainty:{}'.format(uncertainty, noised_uncertainty)) 837 | # if savepath: 838 | # # .npy for further model ensemble 839 | # # .nii for directly model submission 840 | # assert save_format in ['npy', 'nii'] 841 | # if save_format == 'npy': 842 | # np.save(os.path.join(savepath, Net_name +'_'+ name + '_preds'), output) 843 | # if save_format == 'nii': 844 | # # raise NotImplementedError 845 | # oname = os.path.join(savepath, Net_name + '_'+ name + '.nii.gz') 846 | # seg_img = np.zeros(shape=(H, W, T), dtype=np.uint8) 847 | # 848 | # seg_img[np.where(output == 1)] = 1 849 | # seg_img[np.where(output == 2)] = 2 850 | # seg_img[np.where(output == 3)] = 3 851 | # if verbose: 852 | # print('1:', np.sum(seg_img == 1), ' | 2:', np.sum(seg_img == 2), ' | 4:', np.sum(seg_img == 4)) 853 | # print('WT:', np.sum((seg_img == 1) | (seg_img == 2) | (seg_img == 4)), ' | TC:', 854 | # np.sum((seg_img == 1) | (seg_img == 4)), ' | ET:', np.sum(seg_img == 4)) 855 | # nib.save(nib.Nifti1Image(seg_img, None), oname) 856 | # print('Successfully save {}'.format(oname)) 857 | 858 | # if snapshot: 859 | # """ --- grey figure---""" 860 | # # Snapshot_img = np.zeros(shape=(H,W,T),dtype=np.uint8) 861 | # # Snapshot_img[np.where(output[1,:,:,:]==1)] = 64 862 | # # Snapshot_img[np.where(output[2,:,:,:]==1)] = 160 863 | # # Snapshot_img[np.where(output[3,:,:,:]==1)] = 255 864 | # """ --- colorful figure--- """ 865 | # Snapshot_img = np.zeros(shape=(H, W, 3, T), dtype=np.uint8) 866 | # Snapshot_img[:, :, 0, :][np.where(output == 1)] = 255 867 | # Snapshot_img[:, :, 1, :][np.where(output == 2)] = 255 868 | # Snapshot_img[:, :, 2, :][np.where(output == 3)] = 255 869 | # 870 | # for frame in range(T): 871 | # if not os.path.exists(os.path.join(visual, name)): 872 | # os.makedirs(os.path.join(visual, name)) 873 | # # scipy.misc.imsave(os.path.join(visual, name, str(frame)+'.png'), Snapshot_img[:, :, :, frame]) 874 | # imageio.imwrite(os.path.join(visual, name, str(frame)+'.png'), Snapshot_img[:, :, :, frame]) 875 | 876 | aver_dice_WT = dice_total_WT / num 877 | aver_dice_TC = dice_total_TC / num 878 | aver_dice_ET = dice_total_ET / num 879 | aver_hd_WT = hd_total_WT / num 880 | aver_hd_TC = hd_total_TC / num 881 | aver_hd_ET = hd_total_ET / num 882 | aver_assd_WT = assd_total_WT / num 883 | aver_assd_TC = assd_total_TC / num 884 | aver_assd_ET = assd_total_ET / num 885 | aver_noise_dice_WT = noise_dice_total_WT / num 886 | aver_noise_dice_TC = noise_dice_total_TC / num 887 | aver_noise_dice_ET = noise_dice_total_ET / num 888 | aver_noise_hd_WT = noise_hd_total_WT / num 889 | aver_noise_hd_TC = noise_hd_total_TC / num 890 | aver_noise_hd_ET = noise_hd_total_ET / num 891 | aver_noise_assd_WT = noise_assd_total_WT / num 892 | aver_noise_assd_TC = noise_assd_total_TC / num 893 | aver_noise_assd_ET = noise_assd_total_ET / num 894 | aver_noise_mean_uncertainty = noise_mean_uncertainty_total/num 895 | aver_mean_uncertainty = mean_uncertainty_total/num 896 | print('aver_dice_WT=%f,aver_dice_TC = %f,aver_dice_ET = %f' % (aver_dice_WT*100,aver_dice_TC*100, aver_dice_ET*100)) 897 | print('aver_noise_dice_WT=%f,aver_noise_dice_TC = %f,aver_noise_dice_ET = %f' % (aver_noise_dice_WT*100, aver_noise_dice_TC*100, aver_noise_dice_ET*100)) 898 | print('aver_hd_WT=%f,aver_hd_TC = %f,aver_hd_ET = %f' % (aver_hd_WT,aver_hd_TC, aver_hd_ET)) 899 | print('aver_noise_hd_WT=%f,aver_noise_hd_TC = %f,aver_noise_hd_ET = %f' % (aver_noise_hd_WT, aver_noise_hd_TC, aver_noise_hd_ET)) 900 | print('aver_assd_WT=%f,aver_assd_TC = %f,aver_assd_ET = %f' % (aver_assd_WT,aver_assd_TC, aver_assd_ET)) 901 | print('aver_noise_assd_WT=%f,aver_noise_assd_TC = %f,aver_noise_assd_ET = %f' % (aver_noise_assd_WT, aver_noise_assd_TC, aver_noise_assd_ET)) 902 | logging.info( 903 | 'aver_dice_WT=%f,aver_dice_TC = %f,aver_dice_ET = %f' % (aver_dice_WT*100, aver_dice_TC*100, aver_dice_ET*100)) 904 | logging.info('aver_noise_dice_WT=%f,aver_noise_dice_TC = %f,aver_noise_dice_ET = %f' % ( 905 | aver_noise_dice_WT*100, aver_noise_dice_TC*100, aver_noise_dice_ET*100)) 906 | logging.info('aver_hd_WT=%f,aver_hd_TC = %f,aver_hd_ET = %f' % ( 907 | aver_hd_WT, aver_hd_TC, aver_hd_ET)) 908 | logging.info('aver_noise_hd_WT=%f,aver_noise_hd_TC = %f,aver_noise_hd_ET = %f' % ( 909 | aver_noise_hd_WT, aver_noise_hd_TC, aver_noise_hd_ET)) 910 | logging.info('aver_assd_WT=%f,aver_assd_TC = %f,aver_assd_ET = %f' % (aver_assd_WT, aver_assd_TC, aver_assd_ET)) 911 | logging.info('aver_noise_assd_WT=%f,aver_noise_assd_TC = %f,aver_noise_assd_ET = %f' % ( 912 | aver_noise_assd_WT, aver_noise_assd_TC, aver_noise_assd_ET)) 913 | logging.info('aver_noise_mean_uncertainty=%f,aver_mean_uncertainty = %f' % ( 914 | aver_noise_mean_uncertainty, aver_mean_uncertainty)) 915 | # return [aver_dice_WT,aver_dice_TC,aver_dice_ET],[aver_noise_dice_WT,aver_noise_dice_TC,aver_noise_dice_ET] 916 | return [aver_dice_WT,aver_dice_TC,aver_dice_ET],[aver_noise_dice_WT,aver_noise_dice_TC,aver_noise_dice_ET],[aver_hd_WT,aver_hd_TC,aver_hd_ET],[aver_noise_hd_WT,aver_noise_hd_TC,aver_noise_hd_ET] 917 | # return [aver_dice_WT,aver_dice_TC,aver_dice_ET],[aver_noise_dice_WT,aver_noise_dice_TC,aver_noise_dice_ET],[aver_iou_WT,aver_iou_TC,aver_iou_ET],[aver_noise_iou_WT,aver_noise_iou_TC,aver_noise_iou_ET] -------------------------------------------------------------------------------- /results/nullfile: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /test_uncertainty.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import torch 4 | import numpy as np 5 | import math 6 | import abc 7 | import joblib 8 | import joblib 9 | import torch.nn.functional as F 10 | # import torchfunctions as t_fn 11 | import numpyfunctions as np_fn 12 | 13 | def Uentropy(logits,c): 14 | # c = 4 15 | # logits = torch.randn(1, 4, 240, 240,155).cuda() 16 | pc = F.softmax(logits, dim=1) # 1 4 240 240 155 17 | logpc = F.log_softmax(logits, dim=1) # 1 4 240 240 155 18 | # u_all1 = -pc * logpc / c 19 | u_all = -pc * logpc / math.log(c) 20 | # max_u = torch.max(u_all) 21 | # min_u = torch.min(u_all) 22 | NU = torch.sum(u_all, dim=1) 23 | return NU 24 | 25 | class UncertaintyAndCorrectionEvalNumpy(): 26 | def __init__(self, uncertainty_threshold): 27 | super(UncertaintyAndCorrectionEvalNumpy, self).__init__() 28 | self.uncertainty_threshold = uncertainty_threshold 29 | 30 | def __call__(self, to_evaluate=None, results=None): 31 | # entropy should be normalized [0,1] before calling this evaluation 32 | target = to_evaluate['target'].astype(np.bool) 33 | prediction = to_evaluate['prediction'].astype(np.bool) 34 | uncertainty = to_evaluate['uncertainty'] 35 | 36 | thresholded_uncertainty = uncertainty > self.uncertainty_threshold 37 | tp, tn, fp, fn, tpu, tnu, fpu, fnu = \ 38 | np_fn.uncertainty(prediction, target, thresholded_uncertainty) 39 | 40 | results['tpu'] = tpu 41 | results['tnu'] = tnu 42 | results['fpu'] = fpu 43 | results['fnu'] = fnu 44 | 45 | results['tp'] = tp 46 | results['tn'] = tn 47 | results['fp'] = fp 48 | results['fn'] = fn 49 | 50 | tpu_fpu_ratio = results['tpu'] / results['fpu'] 51 | jaccard_index = results['tp'] / (results['tp'] + results['fp'] + results['fn']) 52 | results['dice_benefit'] = tpu_fpu_ratio < jaccard_index 53 | results['accuracy_benefit'] = tpu_fpu_ratio < 1 54 | 55 | results['dice'] = np_fn.dice(prediction, target) 56 | results['accuracy'] = np_fn.accuracy(prediction, target) 57 | 58 | corrected_prediction = prediction.copy() 59 | # correct to background 60 | corrected_prediction[thresholded_uncertainty] = 0 61 | 62 | results['corrected_dice'] = np_fn.dice(corrected_prediction, target) 63 | results['corrected_accuracy'] = np_fn.accuracy(corrected_prediction, target) 64 | 65 | results['dice_benefit_correct'] = (results['corrected_dice'] > results['dice']) == results['dice_benefit'] 66 | results['accuracy_benefit_correct'] = (results['corrected_accuracy'] > results['accuracy']) == results[ 67 | 'accuracy_benefit'] 68 | 69 | corrected_prediction = prediction.copy() 70 | # correct to foreground 71 | corrected_prediction[thresholded_uncertainty] = 1 72 | 73 | results['corrected_add_dice'] = np_fn.dice(corrected_prediction, target) 74 | results['corrected_add_accuracy'] = np_fn.accuracy(corrected_prediction, target) 75 | 76 | def binary_calibration(probabilities, target, n_bins=10, threshold_range = None, mask=None): 77 | if probabilities.ndim > target.ndim: 78 | if probabilities.shape[-1] > 2: 79 | raise ValueError('can only evaluate the calibration for binary classification') 80 | elif probabilities.shape[-1] == 2: 81 | probabilities = probabilities[..., 1] 82 | else: 83 | probabilities = np.squeeze(probabilities, axis=-1) 84 | 85 | if mask is not None: 86 | probabilities = probabilities[mask] 87 | target = target[mask] 88 | 89 | if threshold_range is not None: 90 | low_thres, up_thres = threshold_range 91 | mask = np.logical_and(probabilities < up_thres, probabilities > low_thres) 92 | probabilities = probabilities[mask] 93 | target = target[mask] 94 | 95 | pos_frac, mean_confidence, bin_count, non_zero_bins = \ 96 | _binary_calibration(target.flatten(), probabilities.flatten(), n_bins) 97 | 98 | return pos_frac, mean_confidence, bin_count, non_zero_bins 99 | 100 | def _binary_calibration(target, probs_positive_cls, n_bins=10): 101 | # same as sklearn.calibration calibration_curve but with the bin_count returned 102 | bins = np.linspace(0., 1. + 1e-8, n_bins + 1) 103 | binids = np.digitize(probs_positive_cls, bins) - 1 104 | 105 | # # note: this is the original formulation which has always n_bins + 1 as length 106 | # bin_sums = np.bincount(binids, weights=probs_positive_cls, minlength=len(bins)) 107 | # bin_true = np.bincount(binids, weights=target, minlength=len(bins)) 108 | # bin_total = np.bincount(binids, minlength=len(bins)) 109 | 110 | bin_sums = np.bincount(binids, weights=probs_positive_cls, minlength=n_bins) 111 | bin_true = np.bincount(binids, weights=target, minlength=n_bins) 112 | bin_total = np.bincount(binids, minlength=n_bins) 113 | 114 | nonzero = bin_total != 0 115 | prob_true = (bin_true[nonzero] / bin_total[nonzero]) 116 | prob_pred = (bin_sums[nonzero] / bin_total[nonzero]) 117 | 118 | return prob_true, prob_pred, bin_total[nonzero], nonzero 119 | 120 | def _get_proportion(bin_weighting, bin_count, non_zero_bins, n_dim): 121 | if bin_weighting == 'proportion': 122 | bin_proportions = bin_count / bin_count.sum() 123 | elif bin_weighting == 'log_proportion': 124 | bin_proportions = np.log(bin_count) / np.log(bin_count).sum() 125 | elif bin_weighting == 'power_proportion': 126 | bin_proportions = bin_count**(1/n_dim) / (bin_count**(1/n_dim)).sum() 127 | elif bin_weighting == 'mean_proportion': 128 | bin_proportions = 1 / non_zero_bins.sum() 129 | else: 130 | raise ValueError('unknown bin weighting "{}"'.format(bin_weighting)) 131 | return bin_proportions 132 | 133 | def ece_binary(probabilities, target, n_bins=10, threshold_range= None, mask=None, out_bins=None, 134 | bin_weighting='proportion'): 135 | # input: 1. probabilities (np) 2. target (np) 3. threshold_range (tuple[low,high]) 4. mask 136 | 137 | n_dim = target.ndim 138 | 139 | pos_frac, mean_confidence, bin_count, non_zero_bins = \ 140 | binary_calibration(probabilities, target, n_bins, threshold_range, mask) 141 | 142 | bin_proportions = _get_proportion(bin_weighting, bin_count, non_zero_bins, n_dim) 143 | 144 | if out_bins is not None: 145 | out_bins['bins_count'] = bin_count 146 | out_bins['bins_avg_confidence'] = mean_confidence 147 | out_bins['bins_positive_fraction'] = pos_frac 148 | out_bins['bins_non_zero'] = non_zero_bins 149 | 150 | ece = (np.abs(mean_confidence - pos_frac) * bin_proportions).sum() 151 | return ece 152 | 153 | def pkload(fname): 154 | with open(fname, 'rb') as f: 155 | return joblib.load(f) 156 | # return pickle.load(f) 157 | 158 | def cal_u(uncertainty): 159 | mean_uncertainty_total = 0 160 | sum_uncertainty_total = 0 161 | for i in range(len(uncertainty)): 162 | u = uncertainty[i][0] 163 | total_uncertainty = torch.sum(u, -1, keepdim=True) 164 | # print('current_sum_certainty:{} ; current_mean_certainty:{}'.format(torch.mean(total_uncertainty), 165 | # torch.mean(u))) 166 | sum_uncertainty_total += torch.mean(total_uncertainty) 167 | mean_uncertainty_total += torch.mean(u) 168 | num = len(uncertainty) 169 | return sum_uncertainty_total/num,mean_uncertainty_total/num 170 | 171 | def cal_ece(logits,targets): 172 | ece_total = 0 173 | for i in range(len(logits)): 174 | # pc = torch.zeros(1,1,128,128,128) 175 | logit = logits[i][0] 176 | target = targets[i][..., :155] 177 | pred = F.softmax(logit, dim=0) 178 | pc = pred.cpu().detach().numpy() 179 | pc = pc.argmax(0) 180 | ece_total += ece_binary(pc, target) 181 | num = len(logits) 182 | return ece_total/num 183 | 184 | def cal_ueo(to_evaluate,thresholds): 185 | UEO = [] 186 | for threshold in thresholds: 187 | metric = UncertaintyAndCorrectionEvalNumpy(threshold) 188 | UEO.append(list(metric(to_evaluate))) 189 | return UEO -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # python3 -m torch.distributed.launch --nproc_per_node=4 --master_port 20003 train_spup3.py 2 | 3 | import argparse 4 | import os 5 | import random 6 | import logging 7 | import numpy as np 8 | import time 9 | import setproctitle 10 | import torch 11 | import torch.optim 12 | from sklearn.externals import joblib 13 | # from models import criterions 14 | from models.lib.VNet3D import VNet 15 | from plot import loss_plot,metrics_plot 16 | from models.lib.UNet3DZoo import Unet,AttUnet,Unetdrop 17 | from models.criterions import softBCE_dice,softmax_dice,FocalLoss,DiceLoss 18 | from data.BraTS2019 import BraTS 19 | from torch.utils.data import DataLoader 20 | from tensorboardX import SummaryWriter 21 | from predict import validate_softmax,test_softmax,testensemblemax 22 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 23 | 24 | local_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 25 | 26 | parser = argparse.ArgumentParser() 27 | 28 | # Basic Information 29 | parser.add_argument('--user', default='name of user', type=str) 30 | parser.add_argument('--experiment', default='TransBTS', type=str) 31 | parser.add_argument('--date', default=local_time.split(' ')[0], type=str) 32 | 33 | parser.add_argument('--description', 34 | default='TransBTS,' 35 | 'training on train.txt!', 36 | type=str) 37 | 38 | # DataSet Information 39 | parser.add_argument('--root', default='E:/BraTSdata1/archive2019', type=str) # folder_data_path 40 | parser.add_argument('--train_dir', default='MICCAI_BraTS_2019_Data_TTraining', type=str) 41 | parser.add_argument('--valid_dir', default='MICCAI_BraTS_2019_Data_TValidation', type=str) 42 | parser.add_argument('--test_dir', default='MICCAI_BraTS_2019_Data_TTest', type=str) 43 | parser.add_argument("--mode", default="train", type=str, help="train/test/train&test") 44 | # parser.add_argument('--train_file', 45 | # default='C:/Coco_file/BraTSdata/archive2019/MICCAI_BraTS_2019_Data_Training/Ttrain_subject.txt', type=str) 46 | # parser.add_argument('--valid_file', default='C:/Coco_file/BraTSdata/archive2019/MICCAI_BraTS_2019_Data_Training/Tval_subject.txt', 47 | # type=str) 48 | # parser.add_argument('--test_file', default='C:/Coco_file/BraTSdata/archive2019/MICCAI_BraTS_2019_Data_Training/Ttest_subject.txt', 49 | # type=str) 50 | parser.add_argument('--train_file', 51 | default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Ttrain_subject.txt', 52 | type=str) 53 | parser.add_argument('--valid_file', 54 | default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Tval_subject.txt', 55 | type=str) 56 | parser.add_argument('--test_file', 57 | default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Ttest_subject.txt', 58 | type=str) 59 | parser.add_argument('--dataset', default='BraTS', type=str) 60 | parser.add_argument('--input_C', default=4, type=int) 61 | parser.add_argument('--input_H', default=240, type=int) 62 | parser.add_argument('--input_W', default=240, type=int) 63 | parser.add_argument('--input_D', default=160, type=int) # 155 64 | parser.add_argument('--crop_H', default=128, type=int) 65 | parser.add_argument('--crop_W', default=128, type=int) 66 | parser.add_argument('--crop_D', default=128, type=int) 67 | parser.add_argument('--output_D', default=155, type=int) 68 | parser.add_argument('--rlt', default=-1, type=float, 69 | help='relation between CE/FL and dice') 70 | # Training Information 71 | parser.add_argument('--lr', default=0.002, type=float) 72 | parser.add_argument('--weight_decay', default=1e-5, type=float) 73 | parser.add_argument('--amsgrad', default=True, type=bool) 74 | # parser.add_argument('--criterion', default='softmaxBCE_dice', type=str) 75 | parser.add_argument('--submission', default='./results', type=str) 76 | parser.add_argument('--visual', default='visualization', type=str) 77 | parser.add_argument('--num_class', default=4, type=int) 78 | parser.add_argument('--seed', default=1000, type=int) 79 | parser.add_argument('--no_cuda', default=False, type=bool) 80 | parser.add_argument('--batch_size', default=2, type=int, help="2/4/8") 81 | parser.add_argument('--start_epoch', default=0, type=int) 82 | parser.add_argument('--end_epoch', default=200, type=int) 83 | parser.add_argument('--save_freq', default=5, type=int) 84 | parser.add_argument('--resume', default='', type=str) 85 | parser.add_argument('--load', default=True, type=bool) 86 | parser.add_argument('--modal', default='t2', type=str) # multi-modal 87 | parser.add_argument('--model_name', default='V', type=str, help="AU/V/U") 88 | parser.add_argument('--Variance', default=2, type=int) # 1 2 89 | parser.add_argument('--use_TTA', default=False, type=bool, help="True/False") 90 | parser.add_argument('--save_format', default='nii', type=str) 91 | parser.add_argument('--test_date', default='2022-01-04', type=str) 92 | parser.add_argument('--test_epoch', default=184, type=int) 93 | args = parser.parse_args() 94 | 95 | def val(model,checkpoint_dir,epoch,best_dice): 96 | valid_list = os.path.join(args.root, args.valid_dir, args.valid_file) 97 | valid_root = os.path.join(args.root, args.valid_dir) 98 | valid_set = BraTS(valid_list, valid_root,'valid',args.modal) 99 | valid_loader = DataLoader(valid_set, batch_size=1) 100 | print('Samples for valid = {}'.format(len(valid_set))) 101 | 102 | start_time = time.time() 103 | model.eval() 104 | with torch.no_grad(): 105 | best_dice,aver_dice,aver_iou = validate_softmax(save_dir = checkpoint_dir, 106 | best_dice = best_dice, 107 | current_epoch = epoch, 108 | save_freq = args.save_freq, 109 | end_epoch = args.end_epoch, 110 | valid_loader = valid_loader, 111 | model = model, 112 | multimodel = args.modal, 113 | Net_name = args.model_name, 114 | names = valid_set.names, 115 | ) 116 | # dice_list.append(aver_dice) 117 | # iou_list.append(aver_iou) 118 | end_time = time.time() 119 | full_test_time = (end_time-start_time)/60 120 | average_time = full_test_time/len(valid_set) 121 | print('{:.2f} minutes!'.format(average_time)) 122 | return best_dice,aver_dice,aver_iou 123 | 124 | def test(model): 125 | for arg in vars(args): 126 | logging.info('{}={}'.format(arg, getattr(args, arg))) 127 | logging.info('----------------------------------------This is a halving line----------------------------------') 128 | logging.info('{}'.format(args.description)) 129 | test_list = os.path.join(args.root, args.test_dir, args.test_file) 130 | test_root = os.path.join(args.root, args.test_dir) 131 | test_set = BraTS(test_list, test_root,'test',args.modal) 132 | test_loader = DataLoader(test_set, batch_size=1) 133 | print('Samples for test = {}'.format(len(test_set))) 134 | 135 | logging.info('final test........') 136 | load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), 137 | 'checkpoint', args.experiment + args.test_date, args.model_name + '_' + args.modal + '_epoch_{}.pth'.format(args.test_epoch)) 138 | 139 | if os.path.exists(load_file): 140 | checkpoint = torch.load(load_file) 141 | model.load_state_dict(checkpoint['state_dict']) 142 | args.start_epoch = checkpoint['epoch'] 143 | print('Successfully load checkpoint {}'.format(os.path.join(args.experiment + args.test_date, args.model_name + '_' + args.modal + '_epoch_{}.pth'))) 144 | else: 145 | print('There is no resume file to load!') 146 | 147 | 148 | start_time = time.time() 149 | model.eval() 150 | with torch.no_grad(): 151 | aver_dice,aver_noise_dice,aver_hd,aver_noise_hd = test_softmax( test_loader = test_loader, 152 | model = model, 153 | multimodel = args.modal, 154 | Net_name=args.model_name, 155 | Variance = args.Variance, 156 | load_file=load_file, 157 | savepath = args.submission, 158 | names = test_set.names, 159 | use_TTA = args.use_TTA, 160 | save_format = args.save_format, 161 | ) 162 | end_time = time.time() 163 | full_test_time = (end_time-start_time)/60 164 | average_time = full_test_time/len(test_set) 165 | print('{:.2f} minutes!'.format(average_time)) 166 | logging.info('aver_dice_WT=%f,aver_dice_TC = %f,aver_dice_ET = %f' % (aver_dice[0],aver_dice[1],aver_dice[2])) 167 | logging.info('aver_noise_dice_WT=%f,aver_noise_dice_TC = %f,aver_noise_dice_ET = %f' % (aver_noise_dice[0], aver_noise_dice[1], aver_noise_dice[2])) 168 | logging.info('aver_hd_WT=%f,aver_hd_TC = %f,aver_hd_ET = %f' % (aver_hd[0],aver_hd[1],aver_hd[2])) 169 | logging.info('aver_noise_hd_WT=%f,aver_noise_hd_TC = %f,aver_noise_hd_ET = %f' % (aver_noise_hd[0], aver_noise_hd[1], aver_noise_hd[2])) 170 | 171 | 172 | def test_ensemble(model): 173 | 174 | test_list = os.path.join(args.root, args.test_dir, args.test_file) 175 | test_root = os.path.join(args.root, args.test_dir) 176 | test_set = BraTS(test_list, test_root,'test',args.modal) 177 | test_loader = DataLoader(test_set, batch_size=1) 178 | print('Samples for test = {}'.format(len(test_set))) 179 | 180 | logging.info('final test........') 181 | load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), 182 | 'checkpoint', args.experiment + args.test_date, args.model_name + '_' + args.modal + '_epoch_{}.pth'.format(args.test_epoch)) 183 | 184 | if os.path.exists(load_file): 185 | checkpoint = torch.load(load_file) 186 | model.load_state_dict(checkpoint['state_dict']) 187 | args.start_epoch = checkpoint['epoch'] 188 | print('Successfully load checkpoint {}'.format(os.path.join(args.experiment + args.test_date, args.model_name + '_' + args.modal + '_epoch_{}.pth'))) 189 | else: 190 | print('There is no resume file to load!') 191 | 192 | # load ensemble models 193 | load_model=[] 194 | for i in range(10): 195 | save_name1 = args.model_name + '_' + args.modal + '_epoch_' +'199' + 'e' + str(i) + '.pth' 196 | load_model[i] = torch.load(save_name1) 197 | model[i] = load_model[i]['state_dict'] 198 | 199 | start_time = time.time() 200 | model.eval() 201 | with torch.no_grad(): 202 | aver_dice, aver_noise_dice, aver_hd, aver_noise_hd = testensemblemax(test_loader=test_loader, 203 | model=model, 204 | multimodel=args.modal, 205 | Net_name=args.model_name, 206 | Variance=args.Variance, 207 | load_file=load_file, 208 | savepath=args.submission, 209 | names=test_set.names, 210 | use_TTA=args.use_TTA, 211 | save_format=args.save_format, 212 | ) 213 | end_time = time.time() 214 | full_test_time = (end_time - start_time) / 60 215 | average_time = full_test_time / len(test_set) 216 | print('{:.2f} minutes!'.format(average_time)) 217 | logging.info('aver_dice_WT=%f,aver_dice_TC = %f,aver_dice_ET = %f' % (aver_dice[0], aver_dice[1], aver_dice[2])) 218 | logging.info('aver_noise_dice_WT=%f,aver_noise_dice_TC = %f,aver_noise_dice_ET = %f' % ( 219 | aver_noise_dice[0], aver_noise_dice[1], aver_noise_dice[2])) 220 | logging.info('aver_iou_WT=%f,aver_iou_TC = %f,aver_iou_ET = %f' % (aver_hd[0], aver_hd[1], aver_hd[2])) 221 | logging.info('aver_noise_iou_WT=%f,aver_noise_iou_TC = %f,aver_noise_iou_ET = %f' % ( 222 | aver_noise_hd[0], aver_noise_hd[1], aver_noise_hd[2])) 223 | 224 | def train(criterion,model,criterion_fl,criterion_dl): 225 | # dataset 226 | train_list = os.path.join(args.root, args.train_dir, args.train_file) 227 | train_root = os.path.join(args.root, args.train_dir) 228 | train_set = BraTS(train_list, train_root, args.mode,args.modal) 229 | train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size) 230 | print('Samples for train = {}'.format(len(train_set))) 231 | 232 | logging.info('--------------------------------------This is all argsurations----------------------------------') 233 | for arg in vars(args): 234 | logging.info('{}={}'.format(arg, getattr(args, arg))) 235 | logging.info('----------------------------------------This is a halving line----------------------------------') 236 | logging.info('{}'.format(args.description)) 237 | 238 | torch.manual_seed(args.seed) 239 | torch.cuda.manual_seed(args.seed) 240 | random.seed(args.seed) 241 | np.random.seed(args.seed) 242 | 243 | model.cuda() 244 | model.train() 245 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, amsgrad=args.amsgrad) 246 | # criterion = getattr(criterions, args.criterion) 247 | 248 | checkpoint_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'checkpoint', args.experiment+args.date) 249 | if not os.path.exists(checkpoint_dir): 250 | os.makedirs(checkpoint_dir) 251 | 252 | resume = '' 253 | 254 | writer = SummaryWriter() 255 | 256 | if os.path.isfile(resume) and args.load: 257 | logging.info('loading checkpoint {}'.format(resume)) 258 | checkpoint = torch.load(resume, map_location=lambda storage, loc: storage) 259 | 260 | model.load_state_dict(checkpoint['state_dict']) 261 | 262 | logging.info('Successfully loading checkpoint {} and training from epoch: {}' 263 | .format(args.resume, args.start_epoch)) 264 | else: 265 | logging.info('re-training!!!') 266 | 267 | start_time = time.time() 268 | 269 | torch.set_grad_enabled(True) 270 | loss_list = [] 271 | dice_list = [] 272 | iou_list = [] 273 | best_dice =0 274 | for epoch in range(args.start_epoch, args.end_epoch): 275 | epoch_loss = 0 276 | loss = 0 277 | # loss1 = 0 278 | # loss2 = 0 279 | # loss3 = 0 280 | setproctitle.setproctitle('{}: {}/{}'.format(args.user, epoch+1, args.end_epoch)) 281 | start_epoch = time.time() 282 | for i, data in enumerate(train_loader): 283 | 284 | adjust_learning_rate(optimizer, epoch, args.end_epoch, args.lr) 285 | 286 | x, target = data 287 | x = x.cuda() 288 | target = target.cuda() 289 | output = model(x) 290 | 291 | if args.rlt > 0: 292 | loss = criterion_fl(output, target) + args.rlt * criterion_dl(output, target) 293 | else: 294 | loss = criterion_dl(output, target) 295 | 296 | # loss, loss1, loss2, loss3 = criterion(output, target) 297 | # loss1.requires_grad_(True) 298 | # loss2.requires_grad_(True) 299 | # loss3.requires_grad_(True) 300 | optimizer.zero_grad() 301 | loss.backward() 302 | # loss1.backward() 303 | # loss2.backward() 304 | # loss3.backward() 305 | optimizer.step() 306 | reduce_loss = loss.data.cpu().numpy() 307 | # reduce_loss1 = loss1.data.cpu().numpy() 308 | # reduce_loss2 = loss2.data.cpu().numpy() 309 | # reduce_loss3 = loss3.data.cpu().numpy() 310 | # logging.info('Epoch: {}_Iter:{} loss: {:.5f} || 1:{:.4f} | 2:{:.4f} | 3:{:.4f} ||' 311 | # .format(epoch, i, reduce_loss, reduce_loss1, reduce_loss2, reduce_loss3)) 312 | logging.info('Epoch: {}_Iter:{} loss: {:.5f}' 313 | .format(epoch, i, reduce_loss)) 314 | 315 | epoch_loss += reduce_loss 316 | end_epoch = time.time() 317 | loss_list.append(epoch_loss) 318 | 319 | writer.add_scalar('lr', optimizer.defaults['lr'], epoch) 320 | writer.add_scalar('loss', loss, epoch) 321 | # writer.add_scalar('loss1', loss1, epoch) 322 | # writer.add_scalar('loss2', loss2, epoch) 323 | # writer.add_scalar('loss3', loss3, epoch) 324 | 325 | epoch_time_minute = (end_epoch-start_epoch)/60 326 | remaining_time_hour = (args.end_epoch-epoch-1)*epoch_time_minute/60 327 | logging.info('Current epoch time consumption: {:.2f} minutes!'.format(epoch_time_minute)) 328 | logging.info('Estimated remaining training time: {:.2f} hours!'.format(remaining_time_hour)) 329 | best_dice,aver_dice,aver_iou = val(model,checkpoint_dir,epoch,best_dice) 330 | dice_list.append(aver_dice) 331 | iou_list.append(aver_iou) 332 | writer.close() 333 | # validation 334 | 335 | end_time = time.time() 336 | total_time = (end_time-start_time)/3600 337 | logging.info('The total training time is {:.2f} hours'.format(total_time)) 338 | logging.info('----------------------------------The training process finished!-----------------------------------') 339 | 340 | loss_plot(args, loss_list) 341 | metrics_plot(args, 'dice',dice_list) 342 | 343 | def adjust_learning_rate(optimizer, epoch, max_epoch, init_lr, power=0.9): 344 | for param_group in optimizer.param_groups: 345 | param_group['lr'] = round(init_lr * np.power(1-(epoch) / max_epoch, power), 8) 346 | 347 | 348 | def log_args(log_file): 349 | 350 | logger = logging.getLogger() 351 | logger.setLevel(logging.DEBUG) 352 | formatter = logging.Formatter( 353 | '%(asctime)s ===> %(message)s', 354 | datefmt='%Y-%m-%d %H:%M:%S') 355 | 356 | # args FileHandler to save log file 357 | fh = logging.FileHandler(log_file) 358 | fh.setLevel(logging.DEBUG) 359 | fh.setFormatter(formatter) 360 | 361 | # args StreamHandler to print log to console 362 | ch = logging.StreamHandler() 363 | ch.setLevel(logging.DEBUG) 364 | ch.setFormatter(formatter) 365 | 366 | # add the two Handler 367 | logger.addHandler(ch) 368 | logger.addHandler(fh) 369 | 370 | if __name__ == '__main__': 371 | # criterion = softBCE_dice(aggregate="sum") 372 | criterion = softmax_dice 373 | criterion_fl = FocalLoss(4) 374 | criterion_dl = DiceLoss() 375 | num = 2 376 | # _, model = TransBTS(dataset='brats', _conv_repr=True, _pe_type="learned") 377 | # x = [i for i in range(num)] 378 | # l = [i*random.random() for i in range(num)] 379 | # plt.figure() 380 | # plt.plot(x, l, label='dice') 381 | # log 382 | log_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'log', args.experiment + args.date) 383 | log_file = log_dir + '.txt' 384 | log_args(log_file) 385 | # Net model choose 386 | if args.model_name == 'AU' and args.modal == 'both': 387 | model = AttUnet(in_channels=2, base_channels=16, num_classes=4) 388 | elif args.model_name == 'AU': 389 | model = AttUnet(in_channels=1, base_channels=16, num_classes=4) 390 | elif args.model_name == 'V' and args.modal == 'both': 391 | model = VNet(n_channels=2, n_classes=4, n_filters=16, normalization='gn', has_dropout=False) 392 | elif args.model_name == 'V' : 393 | model = VNet(n_channels=1, n_classes=4, n_filters=16, normalization='gn', has_dropout=False) 394 | elif args.model_name == 'Udrop'and args.modal == 'both': 395 | model = Unetdrop(in_channels=2, base_channels=16, num_classes=4) 396 | elif args.model_name == 'Udrop': 397 | model = Unetdrop(in_channels=1, base_channels=16, num_classes=4) 398 | elif args.model_name == 'U' and args.modal == 'both': 399 | model = Unet(in_channels=2, base_channels=16, num_classes=4) 400 | else: 401 | model = Unet(in_channels=1, base_channels=16, num_classes=4) 402 | # if 'train' in args.mode: 403 | # train(criterion,model,criterion_fl,criterion_dl) 404 | args.mode = 'test' 405 | # Udropout_uncertainty = joblib.load('Udropout_uncertainty.pkl') 406 | test(model) 407 | # test_ensemble(model) 408 | -------------------------------------------------------------------------------- /trainTBraTS.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import time 5 | from torch.autograd import Variable 6 | from torch.utils.data import DataLoader 7 | from models.trustedseg import TMSU 8 | from data.BraTS2019 import BraTS 9 | from predict import tailor_and_concat,softmax_mIOU_score,softmax_output_dice 10 | import torch.nn.functional as F 11 | import warnings 12 | warnings.filterwarnings("ignore") 13 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 14 | import numpy as np 15 | 16 | class AverageMeter(object): 17 | """Computes and stores the average and current value""" 18 | 19 | def __init__(self): 20 | self.reset() 21 | 22 | def reset(self): 23 | self.val = 0 24 | self.avg = 0 25 | self.sum = 0 26 | self.count = 0 27 | 28 | def update(self, val, n=1): 29 | self.val = val 30 | self.sum += val * n 31 | self.count += n 32 | self.avg = self.sum / self.count 33 | 34 | 35 | if __name__ == "__main__": 36 | import argparse 37 | 38 | parser = argparse.ArgumentParser() 39 | # Basic Information 40 | local_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 41 | parser.add_argument('--user', default='name of user', type=str) 42 | parser.add_argument('--experiment', default='TBraTS', type=str) 43 | parser.add_argument('--date', default=local_time.split(' ')[0], type=str) 44 | parser.add_argument('--description', 45 | default='Trusted brain tumor segmentation by coco,' 46 | 'training on train.txt!', 47 | type=str) 48 | # training detalis 49 | parser.add_argument('--epochs', type=int, default=199, metavar='N', 50 | help='number of epochs to train [default: 500]') 51 | parser.add_argument('--test_epoch', type=int, default=198, metavar='N', 52 | help='best epoch') 53 | parser.add_argument('--lambda-epochs', type=int, default=50, metavar='N', 54 | help='gradually increase the value of lambda from 0 to 1') 55 | parser.add_argument('--save_freq', default=1, type=int) 56 | parser.add_argument('--lr', type=float, default=0.002, metavar='LR', 57 | help='learning rate') 58 | # DataSet Information 59 | parser.add_argument('--root', default='E:/BraTSdata1/archive2019', type=str) 60 | parser.add_argument('--save_dir', default='./results', type=str) 61 | parser.add_argument('--train_dir', default='MICCAI_BraTS_2019_Data_TTraining', type=str) 62 | parser.add_argument('--valid_dir', default='MICCAI_BraTS_2019_Data_TValidation', type=str) 63 | parser.add_argument('--test_dir', default='MICCAI_BraTS_2019_Data_TTest', type=str) 64 | parser.add_argument("--mode", default="train", type=str, help="train/test/train&test") 65 | parser.add_argument('--train_file', 66 | default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Ttrain_subject.txt', 67 | type=str) 68 | parser.add_argument('--valid_file', 69 | default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Tval_subject.txt', 70 | type=str) 71 | parser.add_argument('--test_file', 72 | default='E:/BraTSdata1/archive2019/MICCAI_BraTS_2019_Data_Training/Ttest_subject.txt', 73 | type=str) 74 | parser.add_argument('--dataset', default='brats', type=str) 75 | parser.add_argument('--classes', default=4, type=int)# brain tumor class 76 | parser.add_argument('--input_H', default=240, type=int) 77 | parser.add_argument('--input_W', default=240, type=int) 78 | parser.add_argument('--input_D', default=160, type=int) # 155 79 | parser.add_argument('--crop_H', default=128, type=int) 80 | parser.add_argument('--crop_W', default=128, type=int) 81 | parser.add_argument('--crop_D', default=128, type=int) 82 | parser.add_argument('--output_D', default=155, type=int) 83 | parser.add_argument('--batch_size', default=4, type=int, help="2/4/8") 84 | parser.add_argument('--input_dims', default='four', type=str) # multi-modal/Single-modal 85 | parser.add_argument('--model_name', default='V', type=str) # multi-modal V:168 AU:197 86 | parser.add_argument('--Variance', default=0.5, type=int) # noise level 87 | parser.add_argument('--use_TTA', default=True, type=bool, help="True/False") 88 | args = parser.parse_args() 89 | args.dims = [[240,240,160], [240,240,160]] 90 | args.modes = len(args.dims) 91 | 92 | train_list = os.path.join(args.root, args.train_dir, args.train_file) 93 | train_root = os.path.join(args.root, args.train_dir) 94 | train_set = BraTS(train_list, train_root, args.mode,args.input_dims) 95 | train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size) 96 | print('Samples for train = {}'.format(len(train_set))) 97 | valid_list = os.path.join(args.root, args.valid_dir, args.valid_file) 98 | valid_root = os.path.join(args.root, args.valid_dir) 99 | valid_set = BraTS(valid_list, valid_root,'valid',args.input_dims) 100 | valid_loader = DataLoader(valid_set, batch_size=1) 101 | print('Samples for valid = {}'.format(len(valid_set))) 102 | test_list = os.path.join(args.root, args.test_dir, args.test_file) 103 | test_root = os.path.join(args.root, args.test_dir) 104 | test_set = BraTS(test_list, test_root,'test',args.input_dims) 105 | test_loader = DataLoader(test_set, batch_size=1) 106 | print('Samples for test = {}'.format(len(test_set))) 107 | 108 | model = TMSU(args.classes, args.modes, args.model_name, args.input_dims,args.epochs, args.lambda_epochs) # lambda KL divergence 109 | total = sum([param.nelement() for param in model.parameters()]) 110 | print("Number of model's parameter: %.2fM" % (total / 1e6)) 111 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) 112 | 113 | model.cuda() 114 | 115 | def train(epoch): 116 | 117 | model.train() 118 | loss_meter = AverageMeter() 119 | step = 0 120 | dt_size = len(train_loader.dataset) 121 | for i, data in enumerate(train_loader): 122 | step += 1 123 | input, target = data 124 | x = input.cuda() # for multi-modal combine train 125 | target = target.cuda() 126 | args.mode = 'train' 127 | evidences, loss = model(x,target,epoch,args.mode) 128 | 129 | print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // train_loader.batch_size + 1, loss.item())) 130 | 131 | optimizer.zero_grad() 132 | loss.requires_grad_(True).backward() 133 | optimizer.step() 134 | 135 | loss_meter.update(loss.item()) 136 | return loss_meter.avg 137 | 138 | def val(args,current_epoch,best_dice): 139 | print('===========>Validation begining!===========') 140 | model.eval() 141 | loss_meter = AverageMeter() 142 | dice_total, iou_total = 0, 0 143 | step = 0 144 | # model.eval() 145 | for i, data in enumerate(valid_loader): 146 | step += 1 147 | input, target = data 148 | 149 | # add gaussian noise to input data 150 | x = dict() 151 | for m_num in range(input.shape[1]): 152 | x[m_num] = input[..., m_num, :, :, :, ].unsqueeze(1).cuda() 153 | target = target.cuda() 154 | 155 | with torch.no_grad(): 156 | args.mode = 'val' 157 | evidences, loss = model(x, target[:, :, :, :155], current_epoch,args.mode) # two modality or four modality 158 | # max 159 | _, predicted = torch.max(evidence.data, 1) 160 | output = predicted.cpu().detach().numpy() 161 | 162 | target = torch.squeeze(target).cpu().numpy() 163 | iou_res = softmax_mIOU_score(output, target[:, :, :155]) 164 | dice_res = softmax_output_dice(output, target[:, :, :155]) 165 | print('current_iou:{} ; current_dice:{}'.format(iou_res, dice_res)) 166 | dice_total += dice_res[1] 167 | iou_total += iou_res[1] 168 | # loss & noised loss 169 | loss_meter.update(loss.item()) 170 | aver_dice = dice_total / len(valid_loader) 171 | aver_iou = iou_total / len(valid_loader) 172 | if aver_dice > best_dice \ 173 | or (current_epoch + 1) % int(args.epochs - 1) == 0 \ 174 | or (current_epoch + 1) % int(args.epochs - 2) == 0 \ 175 | or (current_epoch + 1) % int(args.epochs - 3) == 0: 176 | print('aver_dice:{} > best_dice:{}'.format(aver_dice, best_dice)) 177 | best_dice = aver_dice 178 | print('===========>save best model!') 179 | file_name = os.path.join(args.save_dir, '_epoch_{}.pth'.format(current_epoch)) 180 | torch.save({ 181 | 'epoch': current_epoch, 182 | 'state_dict': model.state_dict(), 183 | }, 184 | file_name) 185 | return loss_meter.avg, best_dice 186 | 187 | def test(args): 188 | print('===========>Test begining!===========') 189 | 190 | loss_meter = AverageMeter() 191 | noised_loss_meter = AverageMeter() 192 | dice_total,iou_total = 0,0 193 | noised_dice_total,noised_iou_total = 0,0 194 | step = 0 195 | dt_size = len(test_loader.dataset) 196 | load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), 197 | args.save_dir, 198 | '_epoch_{}.pth'.format(args.test_epoch)) 199 | 200 | if os.path.exists(load_file): 201 | checkpoint = torch.load(load_file) 202 | model.load_state_dict(checkpoint['state_dict']) 203 | args.start_epoch = checkpoint['epoch'] 204 | print('Successfully load checkpoint {}'.format(os.path.join(args.save_dir + '/_epoch_' + str(args.test_epoch)))) 205 | else: 206 | print('There is no resume file to load!') 207 | model.eval() 208 | for i, data in enumerate(test_loader): 209 | step += 1 210 | input, target = data 211 | # add gaussian noise to input data 212 | noise_m = torch.randn_like(input) * args.Variance 213 | noised_input = input + noise_m 214 | # x = input.cuda() 215 | # noised_x = noised_input.cuda() 216 | x = dict() 217 | noised_x = dict() 218 | for m_num in range(input.shape[1]): 219 | x[m_num] = input[...,m_num,:,:,:,].unsqueeze(1).cuda() 220 | noised_x[m_num] = noised_input[...,m_num,:,:,:,].unsqueeze(1).cuda() 221 | target = target.cuda() 222 | 223 | with torch.no_grad(): 224 | args.mode = 'test' 225 | if not args.use_TTA: 226 | evidences, loss = model(x, target[:, :, :, :155], args.epochs,args.mode) 227 | noised_evidences, noised_loss = model(noised_x, target[:, :, :, :155], args.epochs,args.mode) 228 | else: 229 | evidences, loss = model(x, target[:, :, :, :155], args.epochs,args.mode,args.use_TTA) 230 | noised_evidences, noised_loss = model(noised_x, target[:, :, :, :155], args.epochs,args.mode,args.use_TTA) 231 | # results with TTA or not 232 | 233 | output = F.softmax(evidence, dim=1) 234 | # for input noise 235 | noised_output = F.softmax(noised_evidence, dim=1) 236 | 237 | # dice 238 | output = output[0, :, :args.input_H, :args.input_W, :args.input_D].cpu().detach().numpy() 239 | output = output.argmax(0) 240 | target = torch.squeeze(target).cpu().numpy() 241 | iou_res = softmax_mIOU_score(output, target[:, :, :155]) 242 | dice_res = softmax_output_dice(output, target[:, :, :155]) 243 | dice_total += dice_res[1] 244 | iou_total += iou_res[1] 245 | # for noise_x 246 | noised_output = noised_output[0, :, :args.input_H, :args.input_W, :args.input_D].cpu().detach().numpy() 247 | noised_output = noised_output.argmax(0) 248 | noised_iou_res = softmax_mIOU_score(noised_output, target[:, :, :155]) 249 | noised_dice_res = softmax_output_dice(noised_output, target[:, :, :155]) 250 | noised_dice_total += noised_dice_res[1] 251 | noised_iou_total += noised_iou_res[1] 252 | print('current_dice:{} ; current_noised_dice:{}'.format(dice_res, noised_dice_res)) 253 | # loss & noised loss 254 | loss_meter.update(loss.item()) 255 | noised_loss_meter.update(noised_loss.item()) 256 | noised_aver_dice = noised_dice_total / len(test_loader) 257 | aver_dice = dice_total / len(test_loader) 258 | print('====> noised_aver_dice: {:.4f}'.format(noised_aver_dice)) 259 | print('====> aver_dice: {:.4f}'.format(aver_dice)) 260 | return loss_meter.avg,noised_loss_meter.avg, aver_dice,noised_aver_dice 261 | 262 | 263 | epoch_loss = 0 264 | best_dice = 0 265 | for epoch in range(1, args.epochs + 1): 266 | print('===========Train begining!===========') 267 | print('Epoch {}/{}'.format(epoch, args.epochs - 1)) 268 | epoch_loss = train(epoch) 269 | print("epoch %d avg_loss:%0.3f" % (epoch, epoch_loss)) 270 | val_loss, best_dice = val(args,epoch,best_dice) 271 | test_loss,noised_test_loss, test_dice,noised_test_dice = test(args) 272 | --------------------------------------------------------------------------------