├── MedIA’24
    ├── null
    ├── Models
    │   ├── null
    │   ├── fundus_swin_network.py
    │   ├── generate_model.py
    │   ├── resnet.py
    │   ├── unetr.py
    │   ├── res2net.py
    │   └── swin_transformer.py
    ├── results
    │   └── null
    ├── pretrain
    │   └── readme
    ├── readme
    ├── README.md
    ├── main_train2.sh
    ├── metrics.py
    ├── metrics2.py
    ├── baseline_models.py
    └── train3_trans.py
├── MICCAI23
    ├── Models
    │   ├── null
    │   ├── generate_model.py
    │   ├── resnet.py
    │   └── res2net.py
    ├── README.md
    └── data.py
└── README.md
/MedIA’24/null:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/MICCAI23/Models/null:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/MedIA’24/Models/null:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/MedIA’24/results/null:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/MedIA’24/pretrain/readme:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/MedIA’24/readme:
--------------------------------------------------------------------------------
1 | # EyeMoSt+
2 | * This repository provides the code for our paper 【Medical Image Analysis submission 2024】 "Confidence-aware multi-modality learning for eye disease screening"
3 | * Since our article is still under review, we will release all the code once we submit our revision.
4 | * Current official implementation of [EyeMoSt](https://arxiv.org/abs/2303.09790)
5 | * You can refer to part of the code in the conference paper [EyeMoSt](https://github.com/Cocofeat/EyeMoSt/tree/main/MICCAI23).
6 | 
--------------------------------------------------------------------------------
/MICCAI23/README.md:
--------------------------------------------------------------------------------
1 | # EyeMoSt
2 | * This repository provides the code for our paper 【MICCAI 2023 Early Accept】 "Reliable Multimodality Eye Disease Screening via Mixture of Student's t Distributions"
3 | * Current official implementation of [EyeMoSt](https://arxiv.org/abs/2303.09790)
4 | * All of the code is released in [EyeMoSt+](https://github.com/Cocofeat/EyeMoSt/tree/main/MedIA%E2%80%9924).
5 | 
6 | ## Datasets
7 | * [GAMMA dataset](https://gamma.grand-challenge.org/)
8 | 
9 | ## Code Usage
10 | ### 1. Prepare dataset
11 | * Download the datasets and change the dataset path:
12 | * [GAMMA dataset basepath and datapath](https://github.com/Cocofeat/EyeMoSt/blob/fb471c67beafe70dfb4d67f896d3220ec0a48df3/MedIA%E2%80%9924/train3_trans.py#L431)
13 | 
14 | ### 2. Pretrained models
15 | * Download pretrained models and put them in ./pretrain/
16 | * Fundus (2D): [Res2Net](https://github.com/LeiJiangJNU/Res2Net)
17 | * OCT (3D): [Med3d](https://github.com/cshwhale/Med3D)
18 | 
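The downloaded checkpoints are consumed by the code under `Models/`; for example, `Models/generate_model.py` builds the 3D OCT encoder from the Med3D checkpoint. A minimal sketch of that call (the checkpoint file name follows the default `pretrain_path` argument in `generate_model.py`; the class count is illustrative):

```python
from Models.generate_model import generate_model

# 3D ResNet-50 OCT branch, initialized from the Med3D checkpoint placed in ./pretrain/
oct_encoder = generate_model(model_type='resnet', model_depth=50,
                             input_W=224, input_H=224, input_D=224,
                             pretrain_path='pretrain/resnet_50.pth',
                             nb_class=3)  # e.g. 3 GAMMA grades (illustrative)
```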
19 | ### 3. Train & Test
20 | Run the script ```main_train2.sh python train.py``` to train and test our model (change ```model_name``` & ```mode```)
21 | 
--------------------------------------------------------------------------------
/MedIA’24/README.md:
--------------------------------------------------------------------------------
1 | # Confidence-aware multi-modality learning for eye disease screening
2 | ## 1. Requirement
3 | - PyTorch 1.3.0
4 | - Python 3
5 | - sklearn
6 | - numpy
7 | - scipy
8 | ## 2. Prepare dataset
9 | * Download the datasets and change the dataset path:
10 | * [OLIVES dataset path](https://github.com/Cocofeat/EyeMoSt/blob/fb471c67beafe70dfb4d67f896d3220ec0a48df3/MedIA%E2%80%9924/train3_trans.py#L409)
11 | * [GAMMA dataset basepath and datapath](https://github.com/Cocofeat/EyeMoSt/blob/fb471c67beafe70dfb4d67f896d3220ec0a48df3/MedIA%E2%80%9924/train3_trans.py#L431)
12 | 
13 | ## 3. Pretrained models
14 | * Download pretrained models and put them in ./pretrain/
15 | 
16 | ### 3.1 CNN-based
17 | * Fundus (2D): [Res2Net](https://github.com/LeiJiangJNU/Res2Net)
18 | * OCT (3D): [Med3d](https://github.com/cshwhale/Med3D)
19 | ### 3.2 Transformer-based
20 | * Fundus (2D): [Swin-Transformer](https://github.com/microsoft/Swin-Transformer)
21 | * OCT (3D): [UNETR](https://github.com/Project-MONAI/research-contributions/tree/main/UNETR)
22 | 
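For reference, the Transformer-based encoders correspond to modules in this folder: `Models/fundus_swin_network.py` (Swin-B fundus branch) and `Models/unetr.py` (UNETR OCT branch). A minimal instantiation sketch; note that both modules load their checkpoints from absolute paths hard-coded inside them, which will likely need to be pointed at `./pretrain/` (the class count below is illustrative):

```python
from Models.fundus_swin_network import build_model   # 2D fundus encoder (Swin-B, 384x384 input)
from Models.unetr import UNETR_base_3DNet             # 3D OCT encoder (UNETR ViT backbone)

fundus_encoder = build_model()                  # loads swin_base_patch4_window12_384 weights
oct_encoder = UNETR_base_3DNet(num_classes=3)   # loads UNETR_model_best_acc.pth via UNETR_pretrain
```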
23 | ## 4. Train
24 | ### 4.1 Train Baseline
25 | Run the script ```main_train2.sh``` (```python baseline_train3_trans.py```) to train the baselines (change ```model_name``` & ```mode```); models will be saved in the ```results``` folder
26 | ### 4.2 Train Our Model
27 | Run the script ```main_train2.sh``` (```python train3_trans.py```) to train our model (change ```model_name```); models will be saved in the ```results``` folder
28 | 
29 | ## 5. Test
30 | ### 5.1 Test Baseline
31 | Run the script ```main_train2.sh``` (```python baseline_train3_trans.py```) to test the baselines (change ```model_name``` & ```mode```)
32 | ### 5.2 Test Our Model
33 | Run the script ```main_train2.sh``` (```python train3_trans.py```) to test our model (change ```model_name``` & ```mode```)
34 | 
--------------------------------------------------------------------------------
/MedIA’24/main_train2.sh:
--------------------------------------------------------------------------------
1 | # 1.train&test GAMMA
2 | # 1.1 train GAMMA
3 | # model name: EyeMost_Plus_transformer,EyeMost_Plus,EyeMost,EyeMost_prior;TMC
4 | # EyeMost_Plus means EyeMost (CNN)
5 | # EyeMost_Plus_transformer means EyeMost (Transformer)
6 | 
7 | # model_base: cnn/Transformer
8 | # dataset: MGamma/OLIVES/MMOCTF
9 | # condition: noise/normal
10 | # EyeMost_Plus means EyeMost (CNN)
11 | # EyeMost_Plus_transformer means EyeMost (Transformer)
12 | 
13 | 
14 | # 1.1 train GAMMA
15 | # CUDA_VISIBLE_DEVICES=1 python train3_trans.py \
16 | # --folder "folder0"\
17 | # --mode "train&test"\
18 | # --model_name "EyeMost_Plus_transformer"\
19 | # --model_base "transformer"\
20 | # --dataset "MGamma"\
21 | # --condition "normal"
22 | # 1.2 test GAMMA
23 | # CUDA_VISIBLE_DEVICES=1 python train3_trans.py \
24 | # --folder "folder2"\
25 | # --mode "test"\
26 | # --model_base "transformer"\
27 | # --model_name "EyeMost_Plus_transformer"\
28 | # --dataset "MGamma"\
29 | # --condition "noise"
30 | 
31 | # 1.3 train baseline Base_transformer/Res2Net2D/ResNet3D/Multi_ResNet(B-CNN)/Multi_EF_ResNet(B-EF)/Multi_CBAM_ResNet(M2LC)/Multi_dropout_ResNet(MCDO)
32 | CUDA_VISIBLE_DEVICES=0 python baseline_train3_trans.py \
33 | --folder "folder0"\
34 | --mode "train&test"\
35 | --model_name "Base_transformer"\
36 | --model_base "transformer"\
37 | --dataset "MGamma"\
38 | --condition "normal"
39 | 
40 | # # 1.4 test baseline Base_transformer/BIF
41 | # CUDA_VISIBLE_DEVICES=0 python baseline_train3_trans.py \
42 | # --folder "folder0"\
43 | # --mode "test"\
44 | # --model_name "Base_transformer"\
45 | # --model_base "transformer"\
46 | # --dataset "MGamma"\
47 | # --condition "noise"
--------------------------------------------------------------------------------
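The flags passed above (`--folder`, `--mode`, `--model_name`, `--model_base`, `--dataset`, `--condition`) are parsed inside `train3_trans.py` / `baseline_train3_trans.py`, which are not included in this excerpt. As an orientation sketch only (option names are taken from the script above; defaults and comments are illustrative), a compatible parser might look like:

```python
import argparse

def get_args():
    parser = argparse.ArgumentParser(description='EyeMoSt+ training/testing (sketch)')
    parser.add_argument('--folder', default='folder0')          # cross-validation fold, e.g. "folder0", "folder2"
    parser.add_argument('--mode', default='train&test')         # 'train&test' or 'test'
    parser.add_argument('--model_name', default='EyeMost_Plus_transformer')
    parser.add_argument('--model_base', default='transformer')  # 'cnn' or 'transformer'
    parser.add_argument('--dataset', default='MGamma')          # MGamma / OLIVES / MMOCTF
    parser.add_argument('--condition', default='normal')        # 'normal' or 'noise'
    return parser.parse_args()
```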
/README.md:
--------------------------------------------------------------------------------
1 | # 【EyeMoSt & EyeMoSt+】
2 | * This repository provides the code for our papers 【MICCAI 2023 Early Accept】 "Reliable Multimodality Eye Disease Screening via Mixture of Student's t Distributions" and 【Medical Image Analysis submission 2024】 "Confidence-aware multi-modality learning for eye disease screening"
3 | * Current official implementation of [EyeMoSt](https://arxiv.org/abs/2303.09790)
4 | * All of the code is released in [EyeMoSt+](https://github.com/Cocofeat/EyeMoSt/tree/main/MedIA%E2%80%9924).
5 | 
6 | ## Requirement
7 | - PyTorch 1.3.0
8 | - Python 3
9 | - sklearn
10 | - numpy
11 | - scipy
12 | - ...
13 | 
14 | ## Datasets
15 | * [GAMMA dataset](https://gamma.grand-challenge.org/)
16 | * [OLIVES dataset](https://doi.org/10.5281/zenodo.7105232)
17 | 
18 | ## Code Usage
19 | ### 1. Prepare dataset
20 | * Download the datasets and change the dataset path:
21 | * [OLIVES dataset path](https://github.com/Cocofeat/EyeMoSt/blob/fb471c67beafe70dfb4d67f896d3220ec0a48df3/MedIA%E2%80%9924/train3_trans.py#L409)
22 | * [GAMMA dataset basepath and datapath](https://github.com/Cocofeat/EyeMoSt/blob/fb471c67beafe70dfb4d67f896d3220ec0a48df3/MedIA%E2%80%9924/train3_trans.py#L431)
23 | 
24 | ### 2. Pretrained models
25 | * Download pretrained models and put them in ./pretrain/
26 | 
27 | #### 2.1 CNN-based
28 | * Fundus (2D): [Res2Net](https://github.com/LeiJiangJNU/Res2Net)
29 | * OCT (3D): [Med3d](https://github.com/cshwhale/Med3D)
30 | #### 2.2 Transformer-based
31 | * Fundus (2D): [Swin-Transformer](https://github.com/microsoft/Swin-Transformer)
32 | * OCT (3D): [UNETR](https://github.com/Project-MONAI/research-contributions/tree/main/UNETR)
33 | 
34 | ### 3. Train
35 | #### 3.1 Train Baseline
36 | Run the script [main_train2.sh](https://github.com/Cocofeat/EyeMoSt/blob/main/MedIA%E2%80%9924/main_train2.sh) (```python baseline_train3_trans.py```) to train the baselines (change ```model_name``` & ```mode```); models will be saved in the ```results``` folder
37 | #### 3.2 Train Our Model
38 | Run the script [main_train2.sh](https://github.com/Cocofeat/EyeMoSt/blob/main/MedIA%E2%80%9924/main_train2.sh) (```python train3_trans.py```) to train our model (change ```model_name```); models will be saved in the ```results``` folder
39 | ### 4. Test
40 | #### 4.1 Test Baseline
41 | Run the script [main_train2.sh](https://github.com/Cocofeat/EyeMoSt/blob/main/MedIA%E2%80%9924/main_train2.sh) (```python baseline_train3_trans.py```) to test the baselines (change ```model_name``` & ```mode```)
42 | #### 4.2 Test Our Model
43 | Run the script [main_train2.sh](https://github.com/Cocofeat/EyeMoSt/blob/main/MedIA%E2%80%9924/main_train2.sh) (```python train3_trans.py```) to test our model (change ```model_name``` & ```mode```)
44 | 
45 | ## Citation
46 | If you find EyeMoSt helps your research, please cite our paper:
47 | ```
48 | @article{EyeMoSt_Zou_2024,
49 | author="Zou, Ke
50 | and Lin, Tian
51 | and Yuan, Xuedong
52 | and Chen, Haoyu
53 | and Shen, Xiaojing
54 | and Wang, Meng
55 | and Fu, Huazhu",
56 | title="Reliable Multimodality Eye Disease Screening via Mixture of Student's t Distributions",
57 | journal={arXiv preprint arXiv:2404.06798},
58 | year={2024}
59 | }
60 | ```
61 | 
--------------------------------------------------------------------------------
/MedIA’24/Models/fundus_swin_network.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 | 
8 | from .swin_transformer import SwinTransformer
9 | #from opts import opts
10 | #opt = opts().parse()
11 | import torch
12 | 
13 | def build_model():
14 |     #model_type = config.MODEL.TYPE
15 | 
16 |     img_size = 384 #224
17 |     patch_size = 4
18 |     in_chans = 3
19 |     num_classes = 3
20 |     embed_dim = 128
21 |     depths = [2,2,18,2]
22 |     num_heads = [4, 8, 16, 32]
23 |     window_size = 12 #7
24 |     mlp_ratio = 4
25 |     qkv_bias = True
26 |     qk_scale = None
27 |     drop_rate = 0.0
28 |     drop_path_rate = 0.5
29 |     ape = False
30 |     patch_norm = True
31 |     use_checkpoint = False
32 |     #import pdb;pdb.set_trace()
33 |     if
True:#opt.model_type == 'swin': 34 | model = SwinTransformer(img_size= img_size,#config.DATA.IMG_SIZE, 35 | patch_size= patch_size, #config.MODEL.SWIN.PATCH_SIZE, 36 | in_chans= in_chans,#config.MODEL.SWIN.IN_CHANS, 37 | num_classes=num_classes, #config.MODEL.NUM_CLASSES, 38 | embed_dim=embed_dim,#config.MODEL.SWIN.EMBED_DIM, 39 | depths=depths,#config.MODEL.SWIN.DEPTHS, 40 | num_heads=num_heads,#config.MODEL.SWIN.NUM_HEADS, 41 | window_size=window_size,#config.MODEL.SWIN.WINDOW_SIZE, 42 | mlp_ratio=mlp_ratio,#config.MODEL.SWIN.MLP_RATIO, 43 | qkv_bias=qkv_bias,#config.MODEL.SWIN.QKV_BIAS, 44 | qk_scale=qk_scale,#config.MODEL.SWIN.QK_SCALE, 45 | drop_rate=drop_rate,#config.MODEL.DROP_RATE, 46 | drop_path_rate=drop_path_rate,#config.MODEL.DROP_PATH_RATE, 47 | ape=ape,#config.MODEL.SWIN.APE, 48 | patch_norm=patch_norm,#config.MODEL.SWIN.PATCH_NORM, 49 | use_checkpoint=False)#config.TRAIN.USE_CHECKPOINT) 50 | #else: 51 | # raise NotImplementedError(f"Unkown model: {model_type}") 52 | #import pdb;pdb.set_trace() 53 | 54 | #snapshot_name = '../data/trained_models/swin_base_patch4_window7_224.pth' 55 | # snapshot_name = './pretrain/swin_base_patch4_window12_384_22k.pth' 56 | # snapshot_name = '/data/zou_ke/projects/2021-gamma-main/src/pretrain/swin_base_patch4_window12_384.pth' 57 | snapshot_name = '/data/zou_ke/projects/2021-gamma-main/src/pretrain/swin_base_patch4_window12_384.pth' 58 | pre_state_dict = torch.load(snapshot_name) 59 | print("load model OK.") 60 | pre_state_dict = pre_state_dict['model'] 61 | cnt = 0 62 | state_dict = model.state_dict() 63 | for key_old in pre_state_dict.keys(): 64 | key = key_old 65 | if key not in state_dict: 66 | continue 67 | value = pre_state_dict[key_old] 68 | if not isinstance(value, torch.FloatTensor): 69 | value = value.data 70 | state_dict[key] = value 71 | cnt += 1 72 | print('Load para num:', cnt) 73 | model.load_state_dict(state_dict) 74 | 75 | return model 76 | 77 | -------------------------------------------------------------------------------- /MICCAI23/Models/generate_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from Models import resnet 4 | 5 | 6 | def generate_model(model_type='resnet', model_depth=50, 7 | input_W=224, input_H=224, input_D=224, resnet_shortcut='B', 8 | no_cuda=False, gpu_id=[0], 9 | pretrain_path = 'pretrain/resnet_50.pth', 10 | nb_class=1): 11 | assert model_type in [ 12 | 'resnet' 13 | ] 14 | 15 | if model_type == 'resnet': 16 | assert model_depth in [10, 18, 34, 50, 101, 152, 200] 17 | 18 | if model_depth == 10: 19 | model = resnet.resnet10( 20 | sample_input_W=input_W, 21 | sample_input_H=input_H, 22 | sample_input_D=input_D, 23 | shortcut_type=resnet_shortcut, 24 | no_cuda=no_cuda, 25 | num_seg_classes=nb_class) 26 | fc_input = 256 27 | elif model_depth == 18: 28 | model = resnet.resnet18( 29 | sample_input_W=input_W, 30 | sample_input_H=input_H, 31 | sample_input_D=input_D, 32 | shortcut_type=resnet_shortcut, 33 | no_cuda=no_cuda, 34 | num_seg_classes=nb_class) 35 | fc_input = 512 36 | elif model_depth == 34: 37 | model = resnet.resnet34( 38 | sample_input_W=input_W, 39 | sample_input_H=input_H, 40 | sample_input_D=input_D, 41 | shortcut_type=resnet_shortcut, 42 | no_cuda=no_cuda, 43 | num_seg_classes=nb_class) 44 | fc_input = 512 45 | elif model_depth == 50: 46 | model = resnet.resnet50( 47 | sample_input_W=input_W, 48 | sample_input_H=input_H, 49 | sample_input_D=input_D, 50 | shortcut_type=resnet_shortcut, 51 | no_cuda=no_cuda, 
52 | num_seg_classes=nb_class) 53 | fc_input = 2048 54 | elif model_depth == 101: 55 | model = resnet.resnet101( 56 | sample_input_W=input_W, 57 | sample_input_H=input_H, 58 | sample_input_D=input_D, 59 | shortcut_type=resnet_shortcut, 60 | no_cuda=no_cuda, 61 | num_seg_classes=nb_class) 62 | fc_input = 2048 63 | elif model_depth == 152: 64 | model = resnet.resnet152( 65 | sample_input_W=input_W, 66 | sample_input_H=input_H, 67 | sample_input_D=input_D, 68 | shortcut_type=resnet_shortcut, 69 | no_cuda=no_cuda, 70 | num_seg_classes=nb_class) 71 | fc_input = 2048 72 | elif model_depth == 200: 73 | model = resnet.resnet200( 74 | sample_input_W=input_W, 75 | sample_input_H=input_H, 76 | sample_input_D=input_D, 77 | shortcut_type=resnet_shortcut, 78 | no_cuda=no_cuda, 79 | num_seg_classes=nb_class) 80 | fc_input = 2048 81 | else: 82 | model = resnet.resnet10( 83 | sample_input_W=input_W, 84 | sample_input_H=input_H, 85 | sample_input_D=input_D, 86 | shortcut_type=resnet_shortcut, 87 | no_cuda=no_cuda, 88 | num_seg_classes=nb_class) 89 | fc_input = 256 90 | 91 | model.conv_seg = nn.Sequential(nn.AdaptiveAvgPool3d((1, 1, 1)), nn.Flatten(), 92 | nn.Linear(in_features=fc_input, out_features=nb_class, bias=True)) 93 | 94 | if not no_cuda: 95 | if len(gpu_id) > 1: 96 | model = model.cuda() 97 | model = nn.DataParallel(model, device_ids=gpu_id) 98 | net_dict = model.state_dict() 99 | else: 100 | import os 101 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id[0]) 102 | model = model.cuda() 103 | model = nn.DataParallel(model, device_ids=None) 104 | net_dict = model.state_dict() 105 | else: 106 | net_dict = model.state_dict() 107 | 108 | print('loading pretrained model {}'.format(pretrain_path)) 109 | pretrain = torch.load(pretrain_path) 110 | pretrain_dict = {k: v for k, v in pretrain['state_dict'].items() if k in net_dict.keys()} 111 | # k 是每一层的名称,v是权重数值 112 | net_dict.update(pretrain_dict) # 字典 dict2 的键/值对更新到 dict 里。 113 | model.load_state_dict(net_dict) # model.load_state_dict()函数把加载的权重复制到模型的权重中去 114 | 115 | print("-------- pre-train model load successfully --------") 116 | 117 | return model 118 | -------------------------------------------------------------------------------- /MedIA’24/Models/generate_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from Models import resnet 4 | 5 | 6 | def generate_model(model_type='resnet', model_depth=50, 7 | input_W=224, input_H=224, input_D=224, resnet_shortcut='B', 8 | no_cuda=False, gpu_id=[0], 9 | pretrain_path = 'pretrain/resnet_50.pth', 10 | nb_class=1): 11 | assert model_type in [ 12 | 'resnet' 13 | ] 14 | 15 | if model_type == 'resnet': 16 | assert model_depth in [10, 18, 34, 50, 101, 152, 200] 17 | 18 | if model_depth == 10: 19 | model = resnet.resnet10( 20 | sample_input_W=input_W, 21 | sample_input_H=input_H, 22 | sample_input_D=input_D, 23 | shortcut_type=resnet_shortcut, 24 | no_cuda=no_cuda, 25 | num_seg_classes=nb_class) 26 | fc_input = 256 27 | elif model_depth == 18: 28 | model = resnet.resnet18( 29 | sample_input_W=input_W, 30 | sample_input_H=input_H, 31 | sample_input_D=input_D, 32 | shortcut_type=resnet_shortcut, 33 | no_cuda=no_cuda, 34 | num_seg_classes=nb_class) 35 | fc_input = 512 36 | elif model_depth == 34: 37 | model = resnet.resnet34( 38 | sample_input_W=input_W, 39 | sample_input_H=input_H, 40 | sample_input_D=input_D, 41 | shortcut_type=resnet_shortcut, 42 | no_cuda=no_cuda, 43 | num_seg_classes=nb_class) 44 | fc_input = 512 45 | elif 
model_depth == 50: 46 | model = resnet.resnet50( 47 | sample_input_W=input_W, 48 | sample_input_H=input_H, 49 | sample_input_D=input_D, 50 | shortcut_type=resnet_shortcut, 51 | no_cuda=no_cuda, 52 | num_seg_classes=nb_class) 53 | fc_input = 2048 54 | elif model_depth == 101: 55 | model = resnet.resnet101( 56 | sample_input_W=input_W, 57 | sample_input_H=input_H, 58 | sample_input_D=input_D, 59 | shortcut_type=resnet_shortcut, 60 | no_cuda=no_cuda, 61 | num_seg_classes=nb_class) 62 | fc_input = 2048 63 | elif model_depth == 152: 64 | model = resnet.resnet152( 65 | sample_input_W=input_W, 66 | sample_input_H=input_H, 67 | sample_input_D=input_D, 68 | shortcut_type=resnet_shortcut, 69 | no_cuda=no_cuda, 70 | num_seg_classes=nb_class) 71 | fc_input = 2048 72 | elif model_depth == 200: 73 | model = resnet.resnet200( 74 | sample_input_W=input_W, 75 | sample_input_H=input_H, 76 | sample_input_D=input_D, 77 | shortcut_type=resnet_shortcut, 78 | no_cuda=no_cuda, 79 | num_seg_classes=nb_class) 80 | fc_input = 2048 81 | else: 82 | model = resnet.resnet10( 83 | sample_input_W=input_W, 84 | sample_input_H=input_H, 85 | sample_input_D=input_D, 86 | shortcut_type=resnet_shortcut, 87 | no_cuda=no_cuda, 88 | num_seg_classes=nb_class) 89 | fc_input = 256 90 | 91 | model.conv_seg = nn.Sequential(nn.AdaptiveAvgPool3d((1, 1, 1)), nn.Flatten(), 92 | nn.Linear(in_features=fc_input, out_features=nb_class, bias=True)) 93 | 94 | if not no_cuda: 95 | if len(gpu_id) > 1: 96 | model = model.cuda() 97 | model = nn.DataParallel(model, device_ids=gpu_id) 98 | net_dict = model.state_dict() 99 | else: 100 | import os 101 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id[0]) 102 | model = model.cuda() 103 | model = nn.DataParallel(model, device_ids=None) 104 | net_dict = model.state_dict() 105 | else: 106 | net_dict = model.state_dict() 107 | 108 | print('loading pretrained model {}'.format(pretrain_path)) 109 | pretrain = torch.load(pretrain_path) 110 | pretrain_dict = {k: v for k, v in pretrain['state_dict'].items() if k in net_dict.keys()} 111 | # k 是每一层的名称,v是权重数值 112 | net_dict.update(pretrain_dict) # 字典 dict2 的键/值对更新到 dict 里。 113 | model.load_state_dict(net_dict) # model.load_state_dict()函数把加载的权重复制到模型的权重中去 114 | 115 | print("-------- pre-train model load successfully --------") 116 | 117 | return model 118 | -------------------------------------------------------------------------------- /MedIA’24/metrics.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import numpy as np 3 | import math 4 | import torch 5 | 6 | def binary_calibration(probabilities, target, n_bins=10, threshold_range = None, mask=None): 7 | if probabilities.ndim > target.ndim: 8 | if probabilities.shape[-1] > 2: 9 | raise ValueError('can only evaluate the calibration for binary classification') 10 | elif probabilities.shape[-1] == 2: 11 | probabilities = probabilities[..., 1] 12 | else: 13 | probabilities = np.squeeze(probabilities, axis=-1) 14 | 15 | if mask is not None: 16 | probabilities = probabilities[mask] 17 | target = target[mask] 18 | 19 | if threshold_range is not None: 20 | low_thres, up_thres = threshold_range 21 | mask = np.logical_and(probabilities < up_thres, probabilities > low_thres) 22 | probabilities = probabilities[mask] 23 | target = target[mask] 24 | 25 | pos_frac, mean_confidence, bin_count, non_zero_bins = \ 26 | _binary_calibration(target.flatten(), probabilities.flatten(), n_bins) 27 | 28 | return pos_frac, mean_confidence, bin_count, non_zero_bins 29 
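# --- Illustrative usage sketch (added for clarity; not part of the original metrics.py) ---
# Shows how the reliability-diagram quantities returned by binary_calibration() above can be
# combined into an ECE-style score, mirroring the 'proportion' bin weighting used by
# ece_binary() further down in this module. The probabilities and labels are made-up toy values.
def _demo_binary_calibration_sketch():
    probs = np.array([0.95, 0.80, 0.65, 0.40, 0.20, 0.05])  # predicted P(y = 1)
    labels = np.array([1, 1, 0, 1, 0, 0])                   # binary ground truth
    pos_frac, mean_conf, bin_count, _ = binary_calibration(probs, labels, n_bins=5)
    bin_proportions = bin_count / bin_count.sum()            # weight each non-empty bin by its share of samples
    return (np.abs(mean_conf - pos_frac) * bin_proportions).sum()
# -------------------------------------------------------------------------------------------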
| 30 | def _binary_calibration(target, probs_positive_cls, n_bins=10): 31 | # same as sklearn.calibration calibration_curve but with the bin_count returned 32 | bins = np.linspace(0., 1. + 1e-8, n_bins + 1) 33 | binids = np.digitize(probs_positive_cls, bins) - 1 34 | 35 | # # note: this is the original formulation which has always n_bins + 1 as length 36 | # bin_sums = np.bincount(binids, weights=probs_positive_cls, minlength=len(bins)) 37 | # bin_true = np.bincount(binids, weights=target, minlength=len(bins)) 38 | # bin_total = np.bincount(binids, minlength=len(bins)) 39 | 40 | bin_sums = np.bincount(binids, weights=probs_positive_cls, minlength=n_bins) 41 | bin_true = np.bincount(binids, weights=target, minlength=n_bins) 42 | bin_total = np.bincount(binids, minlength=n_bins) 43 | 44 | nonzero = bin_total != 0 45 | prob_true = (bin_true[nonzero] / bin_total[nonzero]) 46 | prob_pred = (bin_sums[nonzero] / bin_total[nonzero]) 47 | 48 | return prob_true, prob_pred, bin_total[nonzero], nonzero 49 | 50 | def _get_proportion(bin_weighting, bin_count, non_zero_bins, n_dim): 51 | if bin_weighting == 'proportion': 52 | bin_proportions = bin_count / bin_count.sum() 53 | elif bin_weighting == 'log_proportion': 54 | bin_proportions = np.log(bin_count) / np.log(bin_count).sum() 55 | elif bin_weighting == 'power_proportion': 56 | bin_proportions = bin_count**(1/n_dim) / (bin_count**(1/n_dim)).sum() 57 | elif bin_weighting == 'mean_proportion': 58 | bin_proportions = 1 / non_zero_bins.sum() 59 | else: 60 | raise ValueError('unknown bin weighting "{}"'.format(bin_weighting)) 61 | return bin_proportions 62 | 63 | def ece_binary(probabilities, target, n_bins=10, threshold_range= None, mask=None, out_bins=None, 64 | bin_weighting='proportion'): 65 | # input: 1. probabilities (np) 2. target (np) 3. threshold_range (tuple[low,high]) 4. 
mask 66 | 67 | n_dim = target.ndim 68 | 69 | pos_frac, mean_confidence, bin_count, non_zero_bins = \ 70 | binary_calibration(probabilities, target, n_bins, threshold_range, mask) 71 | 72 | bin_proportions = _get_proportion(bin_weighting, bin_count, non_zero_bins, n_dim) 73 | 74 | if out_bins is not None: 75 | out_bins['bins_count'] = bin_count 76 | out_bins['bins_avg_confidence'] = mean_confidence 77 | out_bins['bins_positive_fraction'] = pos_frac 78 | out_bins['bins_non_zero'] = non_zero_bins 79 | 80 | ece = (np.abs(mean_confidence - pos_frac) * bin_proportions).sum() 81 | return ece 82 | 83 | def cal_ece(logits,targets): 84 | # ece_total = 0 85 | logit = logits 86 | target = targets.cpu().detach().numpy() 87 | pred = F.softmax(logit, dim=0) 88 | pc = pred.cpu().detach().numpy() 89 | pc = pc.argmax(0) 90 | ece = ece_binary(pc, target) 91 | return ece 92 | 93 | def cal_ece_our(preds,targets): 94 | # ece_total = 0 95 | target = targets.cpu().detach().numpy() 96 | pc = preds.cpu().detach().numpy() 97 | pc = pc.argmax(0) 98 | ece = ece_binary(pc, target) 99 | return ece 100 | 101 | def Uentropy(logits,c): 102 | # c = 4 103 | # logits = torch.randn(1, 4, 240, 240,155).cuda() 104 | pc = F.softmax(logits, dim=1) # 1 4 240 240 155 105 | logpc = F.log_softmax(logits, dim=1) # 1 4 240 240 155 106 | # u_all1 = -pc * logpc / c 107 | u_all = -pc * logpc / math.log(c) 108 | # max_u = torch.max(u_all) 109 | # min_u = torch.min(u_all) 110 | # NU1 = torch.sum(u_all, dim=1) 111 | # k = u_all.shape[1] 112 | # NU2 = torch.sum(u_all[:, 0:u_all.shape[1]-1, :, :], dim=1) 113 | NU = torch.sum(u_all[:,1:u_all.shape[1],:,:], dim=1) 114 | return NU 115 | 116 | def Uentropy_our(logits,c): 117 | # c = 4 118 | # logits = torch.randn(1, 4, 240, 240,155).cuda() 119 | pc = logits # 1 4 240 240 155 120 | logpc = torch.log(logits) # 1 4 240 240 155 121 | # u_all1 = -pc * logpc / c 122 | u_all = -pc * logpc / math.log(c) 123 | # max_u = torch.max(u_all) 124 | # min_u = torch.min(u_all) 125 | # NU1 = torch.sum(u_all, dim=1) 126 | # k = u_all.shape[1] 127 | # NU2 = torch.sum(u_all[:, 0:u_all.shape[1]-1, :, :], dim=1) 128 | NU = torch.sum(u_all[:,1:u_all.shape[1],:,:], dim=1) 129 | return NU 130 | -------------------------------------------------------------------------------- /MedIA’24/metrics2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from sklearn import metrics 5 | from sklearn.preprocessing import OneHotEncoder 6 | 7 | 8 | def calc_metrics_for_CPM(predict, softmax, logit, label): 9 | label_onehot = torch.nn.functional.one_hot(torch.from_numpy(label)) 10 | label_onehot = label_onehot.squeeze().cpu().detach().numpy() 11 | correct = (predict == label) 12 | # print(predict[:20]) 13 | # print(label[:20]) 14 | # exit() 15 | acc = np.count_nonzero(correct) / len(predict) 16 | aurc, eaurc = calc_aurc_eaurc(softmax, correct) 17 | aupr, fpr = calc_fpr_aupr(softmax, correct) 18 | ece = calc_ece(softmax, label, bins=15) 19 | nll, brier = calc_nll_brier(softmax, logit, label, label_onehot) 20 | 21 | return acc, aurc, eaurc, aupr, fpr, ece, nll, brier 22 | 23 | 24 | def calc_metrics(loader, label, label_onehot, model, criterion, t): 25 | acc, softmax, correct, logit = get_metric_values(loader, model, criterion, t) 26 | # aurc, eaurc 27 | aurc, eaurc = calc_aurc_eaurc(softmax, correct) 28 | # fpr, aupr 29 | aupr, fpr = calc_fpr_aupr(softmax, correct) 30 | # calibration measure ece , mce, rmsce 31 | ece = 
calc_ece(softmax, label, bins=15) 32 | # brier, nll 33 | nll, brier = calc_nll_brier(softmax, logit, label, label_onehot) 34 | 35 | return acc, aurc, eaurc, aupr, fpr, ece, nll, brier 36 | 37 | 38 | # AURC, EAURC 39 | def calc_aurc_eaurc(softmax, correct): 40 | softmax = np.array(softmax) 41 | correctness = np.array(correct) 42 | softmax_max = np.max(softmax, 1) 43 | sort_values = sorted(zip(softmax_max[:], correctness[:]), key=lambda x: x[0], reverse=True) 44 | sort_softmax_max, sort_correctness = zip(*sort_values) 45 | risk_li, coverage_li = coverage_risk(sort_softmax_max, sort_correctness) 46 | aurc, eaurc = aurc_eaurc(risk_li) 47 | 48 | return aurc, eaurc 49 | 50 | 51 | # AUPR ERROR 52 | def calc_fpr_aupr(softmax, correct): 53 | softmax = np.array(softmax) 54 | correctness = np.array(correct) 55 | softmax_max = np.max(softmax, 1) 56 | 57 | fpr, tpr, thresholds = metrics.roc_curve(correctness, softmax_max) 58 | idx_tpr_95 = np.argmin(np.abs(tpr - 0.95)) 59 | fpr_in_tpr_95 = fpr[idx_tpr_95] 60 | 61 | aupr_err = metrics.average_precision_score(correctness, softmax_max) 62 | 63 | # print("AUPR {0:.2f}".format(aupr_err * 100)) 64 | # print('FPR {0:.2f}'.format(fpr_in_tpr_95 * 100)) 65 | 66 | return aupr_err, fpr_in_tpr_95 67 | 68 | 69 | # ECE 70 | def calc_ece(softmax, label, bins=15): 71 | bin_boundaries = torch.linspace(0, 1, bins + 1) 72 | bin_lowers = bin_boundaries[:-1] 73 | bin_uppers = bin_boundaries[1:] 74 | 75 | softmax = torch.tensor(softmax) 76 | labels = torch.tensor(label) 77 | 78 | softmax_max, predictions = torch.max(softmax, 1) 79 | correctness = predictions.eq(labels) 80 | 81 | ece = torch.zeros(1) 82 | 83 | for bin_lower, bin_upper in zip(bin_lowers, bin_uppers): 84 | in_bin = softmax_max.gt(bin_lower.item()) * softmax_max.le(bin_upper.item()) 85 | prop_in_bin = in_bin.float().mean() 86 | 87 | if prop_in_bin.item() > 0.0: 88 | accuracy_in_bin = correctness[in_bin].float().mean() 89 | avg_confidence_in_bin = softmax_max[in_bin].mean() 90 | 91 | ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin 92 | 93 | # print("ECE {0:.2f} ".format(ece.item() * 100)) 94 | 95 | return ece.item() 96 | 97 | 98 | # NLL & Brier Score 99 | def calc_nll_brier(softmax, logit, label, label_onehot): 100 | brier_score = np.mean(np.sum((softmax - label_onehot) ** 2, axis=1)) 101 | 102 | logit = torch.tensor(logit, dtype=torch.float) 103 | label = torch.tensor(label, dtype=torch.int) 104 | logsoftmax = torch.nn.LogSoftmax(dim=1) 105 | 106 | log_softmax = logsoftmax(logit) 107 | nll = calc_nll(log_softmax, label.long()) 108 | 109 | # print("NLL {0:.2f} ".format(nll.item() * 10)) 110 | # print('Brier {0:.2f}'.format(brier_score * 100)) 111 | 112 | return nll.item()* 10, brier_score* 100 113 | 114 | 115 | # Calc NLL 116 | def calc_nll(log_softmax, label): 117 | out = torch.zeros_like(label, dtype=torch.float) 118 | for i in range(len(label)): 119 | out[i] = log_softmax[i][label[i]] 120 | 121 | return -out.sum() / len(out) 122 | 123 | 124 | # Calc coverage, risk 125 | def coverage_risk(confidence, correctness): 126 | risk_list = [] 127 | coverage_list = [] 128 | risk = 0 129 | for i in range(len(confidence)): 130 | coverage = (i + 1) / len(confidence) 131 | coverage_list.append(coverage) 132 | 133 | if correctness[i] == 0: 134 | risk += 1 135 | 136 | risk_list.append(risk / (i + 1)) 137 | 138 | return risk_list, coverage_list 139 | 140 | 141 | # Calc aurc, eaurc 142 | def aurc_eaurc(risk_list): 143 | r = risk_list[-1] 144 | risk_coverage_curve_area = 0 145 | optimal_risk_area = r + (1 
- r) * np.log(1 - r) 146 | for risk_value in risk_list: 147 | risk_coverage_curve_area += risk_value * (1 / len(risk_list)) 148 | 149 | aurc = risk_coverage_curve_area 150 | eaurc = risk_coverage_curve_area - optimal_risk_area 151 | 152 | # print("AURC {0:.2f}".format(aurc * 1000)) 153 | # print("EAURC {0:.2f}".format(eaurc * 1000)) 154 | 155 | return aurc, eaurc 156 | 157 | 158 | # Get softmax, logit 159 | def get_metric_values(loader, model, criterion, t): 160 | model.eval() 161 | with torch.no_grad(): 162 | total_loss = 0 163 | total_acc = 0 164 | accuracy = 0 165 | 166 | list_softmax = [] 167 | list_correct = [] 168 | list_logit = [] 169 | 170 | for input, target, idx in loader: 171 | input = input.cuda() 172 | target = target.cuda() 173 | 174 | output = model(input) 175 | output = output / t 176 | 177 | loss = criterion(output, target.long()).cuda() 178 | 179 | total_loss += loss.mean().item() 180 | pred = output.data.max(1, keepdim=True)[1] 181 | 182 | total_acc += pred.eq(target.data.view_as(pred)).sum() 183 | 184 | for i in output: 185 | list_logit.append(i.cpu().data.numpy()) 186 | 187 | list_softmax.extend(F.softmax(output).cpu().data.numpy()) 188 | 189 | for j in range(len(pred)): 190 | if pred[j] == target[j]: 191 | accuracy += 1 192 | cor = 1 193 | else: 194 | cor = 0 195 | list_correct.append(cor) 196 | 197 | total_loss /= len(loader) 198 | total_acc = 100. * total_acc / len(loader.dataset) 199 | 200 | print('Accuracy {:.2f}'.format(total_acc)) 201 | 202 | return total_acc.item(), list_softmax, list_correct, list_logit -------------------------------------------------------------------------------- /MICCAI23/Models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = [ 9 | 'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152', 'resnet200' 11 | ] 12 | 13 | 14 | def conv3x3x3(in_planes, out_planes, stride=1, dilation=1): 15 | # 3x3x3 convolution with padding 16 | return nn.Conv3d( 17 | in_planes, 18 | out_planes, 19 | kernel_size=3, 20 | dilation=dilation, 21 | stride=stride, 22 | padding=dilation, 23 | bias=False) 24 | 25 | 26 | def downsample_basic_block(x, planes, stride, no_cuda=False): 27 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 28 | zero_pads = torch.Tensor( 29 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 30 | out.size(4)).zero_() 31 | if not no_cuda: 32 | if isinstance(out.data, torch.cuda.FloatTensor): 33 | zero_pads = zero_pads.cuda() 34 | 35 | out = Variable(torch.cat([out.data, zero_pads.cuda()], dim=1)) 36 | 37 | return out 38 | 39 | 40 | class BasicBlock(nn.Module): 41 | expansion = 1 42 | 43 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 44 | super(BasicBlock, self).__init__() 45 | self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation) 46 | self.bn1 = nn.BatchNorm3d(planes) 47 | self.relu = nn.ReLU(inplace=True) 48 | self.conv2 = conv3x3x3(planes, planes, dilation=dilation) 49 | self.bn2 = nn.BatchNorm3d(planes) 50 | self.downsample = downsample 51 | self.stride = stride 52 | self.dilation = dilation 53 | 54 | def forward(self, x): 55 | residual = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | out = self.conv2(out) 61 | out = self.bn2(out) 62 | 63 | if self.downsample is not None: 64 | 
residual = self.downsample(x) 65 | 66 | out += residual 67 | out = self.relu(out) 68 | 69 | return out 70 | 71 | 72 | class Bottleneck(nn.Module): 73 | expansion = 4 74 | 75 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 76 | super(Bottleneck, self).__init__() 77 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 78 | self.bn1 = nn.BatchNorm3d(planes) 79 | self.conv2 = nn.Conv3d( 80 | planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False) 81 | self.bn2 = nn.BatchNorm3d(planes) 82 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) 83 | self.bn3 = nn.BatchNorm3d(planes * 4) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.downsample = downsample 86 | self.stride = stride 87 | self.dilation = dilation 88 | 89 | def forward(self, x): 90 | residual = x 91 | 92 | out = self.conv1(x) 93 | out = self.bn1(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv2(out) 97 | out = self.bn2(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv3(out) 101 | out = self.bn3(out) 102 | 103 | if self.downsample is not None: 104 | residual = self.downsample(x) 105 | 106 | out += residual 107 | out = self.relu(out) 108 | 109 | return out 110 | 111 | 112 | class ResNet(nn.Module): 113 | 114 | def __init__(self, 115 | block, 116 | layers, 117 | sample_input_D, 118 | sample_input_H, 119 | sample_input_W, 120 | num_seg_classes, 121 | shortcut_type='B', 122 | no_cuda=False): 123 | self.inplanes = 64 124 | self.no_cuda = no_cuda 125 | super(ResNet, self).__init__() 126 | self.conv1 = nn.Conv3d( 127 | 1, 128 | 64, 129 | kernel_size=7, 130 | stride=(2, 2, 2), 131 | padding=(3, 3, 3), 132 | bias=False) 133 | 134 | self.bn1 = nn.BatchNorm3d(64) 135 | self.relu = nn.ReLU(inplace=True) 136 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 137 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) 138 | self.layer2 = self._make_layer( 139 | block, 128, layers[1], shortcut_type, stride=2) 140 | self.layer3 = self._make_layer( 141 | block, 256, layers[2], shortcut_type, stride=1, dilation=2) 142 | self.layer4 = self._make_layer( 143 | block, 512, layers[3], shortcut_type, stride=1, dilation=4) 144 | self.avgpool = nn.AdaptiveAvgPool2d(1) 145 | self.fc = nn.Linear(512 * block.expansion, num_seg_classes) 146 | # self.conv_seg = nn.Sequential( 147 | # nn.ConvTranspose3d( 148 | # 512 * block.expansion, 149 | # 32, 150 | # 2, 151 | # stride=2 152 | # ), 153 | # nn.BatchNorm3d(32), 154 | # nn.ReLU(inplace=True), 155 | # nn.Conv3d( 156 | # 32, 157 | # 32, 158 | # kernel_size=3, 159 | # stride=(1, 1, 1), 160 | # padding=(1, 1, 1), 161 | # bias=False), 162 | # nn.BatchNorm3d(32), 163 | # nn.ReLU(inplace=True), 164 | # nn.Conv3d( 165 | # 32, 166 | # num_seg_classes, 167 | # kernel_size=1, 168 | # stride=(1, 1, 1), 169 | # bias=False) 170 | # ) 171 | 172 | for m in self.modules(): 173 | if isinstance(m, nn.Conv3d): 174 | m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out') 175 | elif isinstance(m, nn.BatchNorm3d): 176 | m.weight.data.fill_(1) 177 | m.bias.data.zero_() 178 | 179 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1): 180 | downsample = None 181 | if stride != 1 or self.inplanes != planes * block.expansion: 182 | if shortcut_type == 'A': 183 | downsample = partial( 184 | downsample_basic_block, 185 | planes=planes * block.expansion, 186 | stride=stride, 187 | no_cuda=self.no_cuda) 188 | else: 189 | downsample = nn.Sequential( 190 | 
nn.Conv3d( 191 | self.inplanes, 192 | planes * block.expansion, 193 | kernel_size=1, 194 | stride=stride, 195 | bias=False), nn.BatchNorm3d(planes * block.expansion)) 196 | 197 | layers = [] 198 | layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample)) 199 | self.inplanes = planes * block.expansion 200 | for i in range(1, blocks): 201 | layers.append(block(self.inplanes, planes, dilation=dilation)) 202 | 203 | return nn.Sequential(*layers) 204 | 205 | def forward(self, x): 206 | x = self.conv1(x) 207 | x = self.bn1(x) 208 | x = self.relu(x) 209 | x = self.maxpool(x) 210 | x = self.layer1(x) 211 | x = self.layer2(x) 212 | x = self.layer3(x) 213 | x = self.layer4(x) 214 | x = self.avgpool(x) 215 | x = x.view(x.size(0), -1) 216 | x = self.fc(x) 217 | # x = self.conv_seg(x) 218 | 219 | return x 220 | 221 | 222 | def resnet10(**kwargs): 223 | """Constructs a ResNet-18 model. 224 | """ 225 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 226 | return model 227 | 228 | 229 | def resnet18(**kwargs): 230 | """Constructs a ResNet-18 model. 231 | """ 232 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 233 | return model 234 | 235 | 236 | def resnet34(**kwargs): 237 | """Constructs a ResNet-34 model. 238 | """ 239 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 240 | return model 241 | 242 | 243 | def resnet50(**kwargs): 244 | """Constructs a ResNet-50 model. 245 | """ 246 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 247 | return model 248 | 249 | 250 | def resnet101(**kwargs): 251 | """Constructs a ResNet-101 model. 252 | """ 253 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 254 | return model 255 | 256 | 257 | def resnet152(**kwargs): 258 | """Constructs a ResNet-101 model. 259 | """ 260 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 261 | return model 262 | 263 | 264 | def resnet200(**kwargs): 265 | """Constructs a ResNet-101 model. 
266 | """ 267 | model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs) 268 | return model -------------------------------------------------------------------------------- /MedIA’24/Models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = [ 9 | 'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152', 'resnet200' 11 | ] 12 | 13 | 14 | def conv3x3x3(in_planes, out_planes, stride=1, dilation=1): 15 | # 3x3x3 convolution with padding 16 | return nn.Conv3d( 17 | in_planes, 18 | out_planes, 19 | kernel_size=3, 20 | dilation=dilation, 21 | stride=stride, 22 | padding=dilation, 23 | bias=False) 24 | 25 | 26 | def downsample_basic_block(x, planes, stride, no_cuda=False): 27 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 28 | zero_pads = torch.Tensor( 29 | out.size(0), planes - out.size(1), out.size(2), out.size(3), 30 | out.size(4)).zero_() 31 | if not no_cuda: 32 | if isinstance(out.data, torch.cuda.FloatTensor): 33 | zero_pads = zero_pads.cuda() 34 | 35 | out = Variable(torch.cat([out.data, zero_pads.cuda()], dim=1)) 36 | 37 | return out 38 | 39 | 40 | class BasicBlock(nn.Module): 41 | expansion = 1 42 | 43 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 44 | super(BasicBlock, self).__init__() 45 | self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation) 46 | self.bn1 = nn.BatchNorm3d(planes) 47 | self.relu = nn.ReLU(inplace=True) 48 | self.conv2 = conv3x3x3(planes, planes, dilation=dilation) 49 | self.bn2 = nn.BatchNorm3d(planes) 50 | self.downsample = downsample 51 | self.stride = stride 52 | self.dilation = dilation 53 | 54 | def forward(self, x): 55 | residual = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | out = self.conv2(out) 61 | out = self.bn2(out) 62 | 63 | if self.downsample is not None: 64 | residual = self.downsample(x) 65 | 66 | out += residual 67 | out = self.relu(out) 68 | 69 | return out 70 | 71 | 72 | class Bottleneck(nn.Module): 73 | expansion = 4 74 | 75 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 76 | super(Bottleneck, self).__init__() 77 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 78 | self.bn1 = nn.BatchNorm3d(planes) 79 | self.conv2 = nn.Conv3d( 80 | planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False) 81 | self.bn2 = nn.BatchNorm3d(planes) 82 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) 83 | self.bn3 = nn.BatchNorm3d(planes * 4) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.downsample = downsample 86 | self.stride = stride 87 | self.dilation = dilation 88 | 89 | def forward(self, x): 90 | residual = x 91 | 92 | out = self.conv1(x) 93 | out = self.bn1(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv2(out) 97 | out = self.bn2(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv3(out) 101 | out = self.bn3(out) 102 | 103 | if self.downsample is not None: 104 | residual = self.downsample(x) 105 | 106 | out += residual 107 | out = self.relu(out) 108 | 109 | return out 110 | 111 | 112 | class ResNet(nn.Module): 113 | 114 | def __init__(self, 115 | block, 116 | layers, 117 | sample_input_D, 118 | sample_input_H, 119 | sample_input_W, 120 | num_seg_classes, 121 | shortcut_type='B', 122 | 
no_cuda=False): 123 | self.inplanes = 64 124 | self.no_cuda = no_cuda 125 | super(ResNet, self).__init__() 126 | self.conv1 = nn.Conv3d( 127 | 1, 128 | 64, 129 | kernel_size=7, 130 | stride=(2, 2, 2), 131 | padding=(3, 3, 3), 132 | bias=False) 133 | 134 | self.bn1 = nn.BatchNorm3d(64) 135 | self.relu = nn.ReLU(inplace=True) 136 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 137 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) 138 | self.layer2 = self._make_layer( 139 | block, 128, layers[1], shortcut_type, stride=2) 140 | self.layer3 = self._make_layer( 141 | block, 256, layers[2], shortcut_type, stride=1, dilation=2) 142 | self.layer4 = self._make_layer( 143 | block, 512, layers[3], shortcut_type, stride=1, dilation=4) 144 | self.avgpool = nn.AdaptiveAvgPool2d(1) 145 | self.fc = nn.Linear(512 * block.expansion, num_seg_classes) 146 | # self.conv_seg = nn.Sequential( 147 | # nn.ConvTranspose3d( 148 | # 512 * block.expansion, 149 | # 32, 150 | # 2, 151 | # stride=2 152 | # ), 153 | # nn.BatchNorm3d(32), 154 | # nn.ReLU(inplace=True), 155 | # nn.Conv3d( 156 | # 32, 157 | # 32, 158 | # kernel_size=3, 159 | # stride=(1, 1, 1), 160 | # padding=(1, 1, 1), 161 | # bias=False), 162 | # nn.BatchNorm3d(32), 163 | # nn.ReLU(inplace=True), 164 | # nn.Conv3d( 165 | # 32, 166 | # num_seg_classes, 167 | # kernel_size=1, 168 | # stride=(1, 1, 1), 169 | # bias=False) 170 | # ) 171 | 172 | for m in self.modules(): 173 | if isinstance(m, nn.Conv3d): 174 | m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out') 175 | elif isinstance(m, nn.BatchNorm3d): 176 | m.weight.data.fill_(1) 177 | m.bias.data.zero_() 178 | 179 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1): 180 | downsample = None 181 | if stride != 1 or self.inplanes != planes * block.expansion: 182 | if shortcut_type == 'A': 183 | downsample = partial( 184 | downsample_basic_block, 185 | planes=planes * block.expansion, 186 | stride=stride, 187 | no_cuda=self.no_cuda) 188 | else: 189 | downsample = nn.Sequential( 190 | nn.Conv3d( 191 | self.inplanes, 192 | planes * block.expansion, 193 | kernel_size=1, 194 | stride=stride, 195 | bias=False), nn.BatchNorm3d(planes * block.expansion)) 196 | 197 | layers = [] 198 | layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample)) 199 | self.inplanes = planes * block.expansion 200 | for i in range(1, blocks): 201 | layers.append(block(self.inplanes, planes, dilation=dilation)) 202 | 203 | return nn.Sequential(*layers) 204 | 205 | def forward(self, x): 206 | x = self.conv1(x) 207 | x = self.bn1(x) 208 | x = self.relu(x) 209 | x = self.maxpool(x) 210 | x = self.layer1(x) 211 | x = self.layer2(x) 212 | x = self.layer3(x) 213 | x = self.layer4(x) 214 | x = self.avgpool(x) 215 | x = x.view(x.size(0), -1) 216 | x = self.fc(x) 217 | # x = self.conv_seg(x) 218 | 219 | return x 220 | 221 | 222 | def resnet10(**kwargs): 223 | """Constructs a ResNet-18 model. 224 | """ 225 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 226 | return model 227 | 228 | 229 | def resnet18(**kwargs): 230 | """Constructs a ResNet-18 model. 231 | """ 232 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 233 | return model 234 | 235 | 236 | def resnet34(**kwargs): 237 | """Constructs a ResNet-34 model. 238 | """ 239 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 240 | return model 241 | 242 | 243 | def resnet50(**kwargs): 244 | """Constructs a ResNet-50 model. 
245 | """ 246 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 247 | return model 248 | 249 | 250 | def resnet101(**kwargs): 251 | """Constructs a ResNet-101 model. 252 | """ 253 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 254 | return model 255 | 256 | 257 | def resnet152(**kwargs): 258 | """Constructs a ResNet-101 model. 259 | """ 260 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 261 | return model 262 | 263 | 264 | def resnet200(**kwargs): 265 | """Constructs a ResNet-101 model. 266 | """ 267 | model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs) 268 | return model -------------------------------------------------------------------------------- /MedIA’24/Models/unetr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from typing import Tuple, Union 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock 18 | from monai.networks.blocks.dynunet_block import UnetOutBlock 19 | from monai.networks.nets import ViT 20 | 21 | 22 | class UNETR_base_3DNet(nn.Module): 23 | # res2net based encoder decoder 24 | def __init__(self, num_classes=10): 25 | super(UNETR_base_3DNet, self).__init__() 26 | # ---- ResNet Backbone ---- 27 | self.UNETR = UNETR_pretrain(pretrained=True) 28 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.UNETR.hidden_size)) 29 | self.classification_head = nn.Sequential(nn.Linear(self.UNETR.hidden_size, num_classes), nn.Tanh()) 30 | self.norm = nn.LayerNorm(self.UNETR.hidden_size) 31 | self.avgpool = nn.AdaptiveAvgPool1d(1) 32 | def forward(self, x): 33 | #origanal x do: 34 | # x1, hidden_states_out = self.UNETR.vit(x) 35 | x, hidden_states_out = self.UNETR.vit(x) 36 | # cls_token = self.cls_token.expand(x.shape[0], -1, -1) 37 | # x = torch.cat((cls_token, x), dim=1) 38 | x = self.norm(x) 39 | # x = self.classification_head(x[:, 0]) 40 | x = self.avgpool(x.transpose(1, 2)) 41 | x = torch.flatten(x, 1) 42 | return x 43 | 44 | def UNETR_pretrain(pretrained=False, **kwargs): 45 | """Constructs a Res2Net-50_v1b_26w_4s lib. 
46 | Args: 47 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 48 | """ 49 | model = UNETR( 50 | in_channels=1, 51 | out_channels=14, 52 | img_size=(96, 96, 96), 53 | feature_size=16, 54 | hidden_size=768, 55 | mlp_dim=3072, 56 | num_heads=12, 57 | pos_embed='perceptron', 58 | norm_name='instance', 59 | conv_block=True, 60 | res_block=True, 61 | dropout_rate=0.0) 62 | # model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, num_classes=2) # changed by coco 63 | 64 | if pretrained: 65 | model_state = torch.load('/data/zou_ke/projects/TMC_ICLR/pretrain/UNETR_model_best_acc.pth') 66 | model.load_state_dict(model_state) 67 | # model.load_state_dict(model_state,strict=False) 68 | # model.load_from(model_state) 69 | return model 70 | 71 | class UNETR(nn.Module): 72 | """ 73 | UNETR based on: "Hatamizadeh et al., 74 | UNETR: Transformers for 3D Medical Image Segmentation " 75 | """ 76 | 77 | def __init__( 78 | self, 79 | in_channels: int, 80 | out_channels: int, 81 | img_size: Tuple[int, int, int], 82 | feature_size: int = 16, 83 | hidden_size: int = 768, 84 | mlp_dim: int = 3072, 85 | num_heads: int = 12, 86 | pos_embed: str = "perceptron", 87 | norm_name: Union[Tuple, str] = "instance", 88 | conv_block: bool = True, 89 | res_block: bool = True, 90 | dropout_rate: float = 0.0, 91 | ) -> None: 92 | """ 93 | Args: 94 | in_channels: dimension of input channels. 95 | out_channels: dimension of output channels. 96 | img_size: dimension of input image. 97 | feature_size: dimension of network feature size. 98 | hidden_size: dimension of hidden layer. 99 | mlp_dim: dimension of feedforward layer. 100 | num_heads: number of attention heads. 101 | pos_embed: position embedding layer type. 102 | norm_name: feature normalization type and arguments. 103 | conv_block: bool argument to determine if convolutional block is used. 104 | res_block: bool argument to determine if residual block is used. 105 | dropout_rate: faction of the input units to drop. 
106 | 107 | Examples:: 108 | 109 | # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm 110 | >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch') 111 | 112 | # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm 113 | >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance') 114 | 115 | """ 116 | 117 | super().__init__() 118 | 119 | if not (0 <= dropout_rate <= 1): 120 | raise AssertionError("dropout_rate should be between 0 and 1.") 121 | 122 | if hidden_size % num_heads != 0: 123 | raise AssertionError("hidden size should be divisible by num_heads.") 124 | 125 | if pos_embed not in ["conv", "perceptron"]: 126 | raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") 127 | 128 | self.num_layers = 12 129 | self.patch_size = (16, 16, 16) 130 | self.feat_size = ( 131 | img_size[0] // self.patch_size[0], 132 | img_size[1] // self.patch_size[1], 133 | img_size[2] // self.patch_size[2], 134 | ) 135 | self.hidden_size = hidden_size 136 | self.classification = False 137 | self.vit = ViT( 138 | in_channels=in_channels, 139 | img_size=img_size, 140 | patch_size=self.patch_size, 141 | hidden_size=hidden_size, 142 | mlp_dim=mlp_dim, 143 | num_layers=self.num_layers, 144 | num_heads=num_heads, 145 | pos_embed=pos_embed, 146 | classification=self.classification, 147 | dropout_rate=dropout_rate, 148 | ) 149 | self.encoder1 = UnetrBasicBlock( 150 | spatial_dims=3, 151 | in_channels=in_channels, 152 | out_channels=feature_size, 153 | kernel_size=3, 154 | stride=1, 155 | norm_name=norm_name, 156 | res_block=res_block, 157 | ) 158 | self.encoder2 = UnetrPrUpBlock( 159 | spatial_dims=3, 160 | in_channels=hidden_size, 161 | out_channels=feature_size * 2, 162 | num_layer=2, 163 | kernel_size=3, 164 | stride=1, 165 | upsample_kernel_size=2, 166 | norm_name=norm_name, 167 | conv_block=conv_block, 168 | res_block=res_block, 169 | ) 170 | self.encoder3 = UnetrPrUpBlock( 171 | spatial_dims=3, 172 | in_channels=hidden_size, 173 | out_channels=feature_size * 4, 174 | num_layer=1, 175 | kernel_size=3, 176 | stride=1, 177 | upsample_kernel_size=2, 178 | norm_name=norm_name, 179 | conv_block=conv_block, 180 | res_block=res_block, 181 | ) 182 | self.encoder4 = UnetrPrUpBlock( 183 | spatial_dims=3, 184 | in_channels=hidden_size, 185 | out_channels=feature_size * 8, 186 | num_layer=0, 187 | kernel_size=3, 188 | stride=1, 189 | upsample_kernel_size=2, 190 | norm_name=norm_name, 191 | conv_block=conv_block, 192 | res_block=res_block, 193 | ) 194 | self.decoder5 = UnetrUpBlock( 195 | spatial_dims=3, 196 | in_channels=hidden_size, 197 | out_channels=feature_size * 8, 198 | kernel_size=3, 199 | upsample_kernel_size=2, 200 | norm_name=norm_name, 201 | res_block=res_block, 202 | ) 203 | self.decoder4 = UnetrUpBlock( 204 | spatial_dims=3, 205 | in_channels=feature_size * 8, 206 | out_channels=feature_size * 4, 207 | kernel_size=3, 208 | upsample_kernel_size=2, 209 | norm_name=norm_name, 210 | res_block=res_block, 211 | ) 212 | self.decoder3 = UnetrUpBlock( 213 | spatial_dims=3, 214 | in_channels=feature_size * 4, 215 | out_channels=feature_size * 2, 216 | kernel_size=3, 217 | upsample_kernel_size=2, 218 | norm_name=norm_name, 219 | res_block=res_block, 220 | ) 221 | self.decoder2 = UnetrUpBlock( 222 | spatial_dims=3, 223 | in_channels=feature_size * 2, 224 | out_channels=feature_size, 
225 | kernel_size=3, 226 | upsample_kernel_size=2, 227 | norm_name=norm_name, 228 | res_block=res_block, 229 | ) 230 | self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) # type: ignore 231 | 232 | def proj_feat(self, x, hidden_size, feat_size): 233 | x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) 234 | x = x.permute(0, 4, 1, 2, 3).contiguous() 235 | return x 236 | 237 | def load_from(self, weights): 238 | with torch.no_grad(): 239 | res_weight = weights 240 | # copy weights from patch embedding 241 | # for i in weights["state_dict"]: 242 | for i in weights["state_dict"]: 243 | print(i) 244 | self.vit.patch_embedding.position_embeddings.copy_( 245 | weights["state_dict"]["module.transformer.patch_embedding.position_embeddings_3d"] 246 | ) 247 | self.vit.patch_embedding.cls_token.copy_( 248 | weights["state_dict"]["module.transformer.patch_embedding.cls_token"] 249 | ) 250 | self.vit.patch_embedding.patch_embeddings[1].weight.copy_( 251 | weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.weight"] 252 | ) 253 | self.vit.patch_embedding.patch_embeddings[1].bias.copy_( 254 | weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.bias"] 255 | ) 256 | 257 | # copy weights from encoding blocks (default: num of blocks: 12) 258 | for bname, block in self.vit.blocks.named_children(): 259 | print(block) 260 | block.loadFrom(weights, n_block=bname) 261 | # last norm layer of transformer 262 | self.vit.norm.weight.copy_(weights["state_dict"]["module.transformer.norm.weight"]) 263 | self.vit.norm.bias.copy_(weights["state_dict"]["module.transformer.norm.bias"]) 264 | 265 | def forward(self, x_in): 266 | x, hidden_states_out = self.vit(x_in) 267 | enc1 = self.encoder1(x_in) 268 | x2 = hidden_states_out[3] 269 | enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size)) 270 | x3 = hidden_states_out[6] 271 | enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size)) 272 | x4 = hidden_states_out[9] 273 | enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size)) 274 | dec4 = self.proj_feat(x, self.hidden_size, self.feat_size) 275 | dec3 = self.decoder5(dec4, enc4) 276 | dec2 = self.decoder4(dec3, enc3) 277 | dec1 = self.decoder3(dec2, enc2) 278 | out = self.decoder2(dec1, enc1) 279 | logits = self.out(out) 280 | return logits -------------------------------------------------------------------------------- /MedIA’24/Models/res2net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | 8 | __all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b', 'res2net50_v1b_26w_4s'] 9 | 10 | model_urls = { 11 | 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth', 12 | 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth', 13 | } 14 | 15 | 16 | class Bottle2neck(nn.Module): 17 | expansion = 4 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'): 20 | """ Constructor 21 | Args: 22 | inplanes: input channel dimensionality 23 | planes: output channel dimensionality 24 | stride: conv stride. Replaces pooling layer. 
25 | downsample: None when stride = 1 26 | baseWidth: basic width of conv3x3 27 | scale: number of scale. 28 | type: 'normal': normal set. 'stage': first block of a new stage. 29 | """ 30 | super(Bottle2neck, self).__init__() 31 | 32 | width = int(math.floor(planes * (baseWidth / 64.0))) 33 | self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False) 34 | self.bn1 = nn.BatchNorm2d(width * scale) 35 | 36 | if scale == 1: 37 | self.nums = 1 38 | else: 39 | self.nums = scale - 1 40 | if stype == 'stage': 41 | self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) 42 | convs = [] 43 | bns = [] 44 | for i in range(self.nums): 45 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False)) 46 | bns.append(nn.BatchNorm2d(width)) 47 | self.convs = nn.ModuleList(convs) 48 | self.bns = nn.ModuleList(bns) 49 | 50 | self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 52 | 53 | self.relu = nn.ReLU(inplace=True) 54 | self.downsample = downsample 55 | self.stype = stype 56 | self.scale = scale 57 | self.width = width 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | spx = torch.split(out, self.width, 1) 67 | for i in range(self.nums): 68 | if i == 0 or self.stype == 'stage': 69 | sp = spx[i] 70 | else: 71 | sp = sp + spx[i] 72 | sp = self.convs[i](sp) 73 | sp = self.relu(self.bns[i](sp)) 74 | if i == 0: 75 | out = sp 76 | else: 77 | out = torch.cat((out, sp), 1) 78 | if self.scale != 1 and self.stype == 'normal': 79 | out = torch.cat((out, spx[self.nums]), 1) 80 | elif self.scale != 1 and self.stype == 'stage': 81 | out = torch.cat((out, self.pool(spx[self.nums])), 1) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | class Res2Net2(nn.Module): 95 | 96 | def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000): 97 | self.inplanes = 64 98 | super(Res2Net2, self).__init__() 99 | self.baseWidth = baseWidth 100 | self.scale = scale 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 110 | self.avgpool = nn.AdaptiveAvgPool2d(1) 111 | self.fc = nn.Linear(512 * block.expansion, num_classes) 112 | 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 116 | elif isinstance(m, nn.BatchNorm2d): 117 | nn.init.constant_(m.weight, 1) 118 | nn.init.constant_(m.bias, 0) 119 | 120 | def _make_layer(self, block, planes, blocks, stride=1): 121 | downsample = None 122 | if stride != 1 or self.inplanes != planes * block.expansion: 123 | downsample = nn.Sequential( 124 | nn.Conv2d(self.inplanes, planes * block.expansion, 125 | kernel_size=1, stride=stride, bias=False), 126 | nn.BatchNorm2d(planes * block.expansion), 127 | ) 128 | 129 | layers = [] 130 | 
layers.append(block(self.inplanes, planes, stride, downsample=downsample, 131 | stype='stage', baseWidth = self.baseWidth, scale=self.scale)) 132 | self.inplanes = planes * block.expansion 133 | for i in range(1, blocks): 134 | layers.append(block(self.inplanes, planes, baseWidth = self.baseWidth, scale=self.scale)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | x = self.bn1(x) 141 | x = self.relu(x) 142 | x = self.maxpool(x) 143 | 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | 149 | x = self.avgpool(x) 150 | x = x.view(x.size(0), -1) 151 | x = self.fc(x) 152 | 153 | return x 154 | 155 | class Res2Net(nn.Module): 156 | 157 | def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000): 158 | self.inplanes = 64 159 | super(Res2Net, self).__init__() 160 | self.baseWidth = baseWidth 161 | self.scale = scale 162 | self.conv1 = nn.Sequential( 163 | nn.Conv2d(3, 32, 3, 2, 1, bias=False), 164 | nn.BatchNorm2d(32), 165 | nn.ReLU(inplace=True), 166 | nn.Conv2d(32, 32, 3, 1, 1, bias=False), 167 | nn.BatchNorm2d(32), 168 | nn.ReLU(inplace=True), 169 | nn.Conv2d(32, 64, 3, 1, 1, bias=False) 170 | ) 171 | self.bn1 = nn.BatchNorm2d(64) 172 | self.relu = nn.ReLU() 173 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 174 | self.layer1 = self._make_layer(block, 64, layers[0]) 175 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 176 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 177 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 178 | self.avgpool = nn.AdaptiveAvgPool2d(1) 179 | self.fc = nn.Linear(512 * block.expansion, num_classes) 180 | 181 | for m in self.modules(): 182 | if isinstance(m, nn.Conv2d): 183 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 184 | elif isinstance(m, nn.BatchNorm2d): 185 | nn.init.constant_(m.weight, 1) 186 | nn.init.constant_(m.bias, 0) 187 | 188 | def _make_layer(self, block, planes, blocks, stride=1): 189 | downsample = None 190 | if stride != 1 or self.inplanes != planes * block.expansion: 191 | downsample = nn.Sequential( 192 | nn.AvgPool2d(kernel_size=stride, stride=stride, 193 | ceil_mode=True, count_include_pad=False), 194 | nn.Conv2d(self.inplanes, planes * block.expansion, 195 | kernel_size=1, stride=1, bias=False), 196 | nn.BatchNorm2d(planes * block.expansion), 197 | ) 198 | 199 | layers = [] 200 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 201 | stype='stage', baseWidth=self.baseWidth, scale=self.scale)) 202 | self.inplanes = planes * block.expansion 203 | for i in range(1, blocks): 204 | layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)) 205 | 206 | return nn.Sequential(*layers) 207 | 208 | def forward(self, x): 209 | x = self.conv1(x) 210 | x = self.bn1(x) 211 | x = self.relu(x) 212 | x = self.maxpool(x) 213 | 214 | x = self.layer1(x) 215 | x = self.layer2(x) 216 | x = self.layer3(x) 217 | x = self.layer4(x) 218 | 219 | x = self.avgpool(x) 220 | x = x.view(x.size(0), -1) 221 | x = self.fc(x) 222 | 223 | return x 224 | 225 | 226 | def res2net50_v1b(pretrained=False, **kwargs): 227 | """Constructs a Res2Net-50_v1b lib. 228 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s. 
229 | Args: 230 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 231 | """ 232 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 233 | if pretrained: 234 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 235 | return model 236 | 237 | 238 | def res2net101_v1b(pretrained=False, **kwargs): 239 | """Constructs a Res2Net-50_v1b_26w_4s lib. 240 | Args: 241 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 242 | """ 243 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 244 | if pretrained: 245 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 246 | return model 247 | 248 | 249 | def res2net50_v1b_26w_4s(pretrained=False, **kwargs): 250 | """Constructs a Res2Net-50_v1b_26w_4s lib. 251 | Args: 252 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 253 | """ 254 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 255 | # model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, num_classes=2) # changed by coco 256 | 257 | if pretrained: 258 | model_state = torch.load('/data/zou_ke/projects/TMC_ICLR/pretrain/res2net50_v1b_26w_4s-3cf99910.pth') 259 | model.load_state_dict(model_state) 260 | # lib.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 261 | return model 262 | 263 | def res2net50_v1b_14w_8s(pretrained=False, **kwargs): 264 | """Constructs a Res2Net-50_v1b_26w_4s lib. 265 | Args: 266 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 267 | """ 268 | model = Res2Net2(Bottle2neck, [3, 4, 6, 3], baseWidth=14, scale=8, **kwargs) 269 | # model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, num_classes=2) # changed by coco 270 | 271 | if pretrained: 272 | model_state = torch.load('/data/zou_ke/projects/TMC_ICLR/pretrain/res2net50_14w_8s-6527dddc.pth') 273 | model.load_state_dict(model_state) 274 | # lib.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 275 | return model 276 | 277 | # def res2net50_v1b_14w_8s(pretrained=False, **kwargs): 278 | # """Constructs a Res2Net-50_14w_8s model. 279 | # Args: 280 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 281 | # """ 282 | # model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 14, scale = 8, **kwargs) 283 | # if pretrained: 284 | # # model = nn.DataParallel(model).cuda() 285 | # model_state = torch.load('/home/zou_ke/projects/TMC/TMC_ICLR/pretrain/res2net50_14w_8s-6527dddc.pth') 286 | # # new_state_dict = OrderedDict() 287 | # # for k, v in model_state.items(): 288 | # # name = k[7:] # module字段在最前面,从第7个字符开始就可以去掉module 289 | # # new_state_dict[name] = v # 新字典的key值对应的value一一对应 290 | # # kk = model_state.OrderedDict 291 | # model.load_state_dict(model_state, strict=False) 292 | # # model_state = torch.load('/home/zou_ke/projects/TMC/TMC_ICLR/pretrain/res2net50_14w_8s-6527dddc.pth') 293 | # # new_state_dict = OrderedDict() 294 | # # for k, v in model_state.items(): # k为module.xxx.weight, v为权重 295 | # # name = k.split('.')[0] # 截取`module.`后面的xxx.weight 296 | # # new_state_dict[name] = v 297 | # # # load params 298 | # # model.load_state_dict(new_state_dict) 299 | # return model 300 | 301 | 302 | def res2net101_v1b_26w_4s(pretrained=False, **kwargs): 303 | """Constructs a Res2Net-50_v1b_26w_4s lib. 
304 | Args: 305 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 306 | """ 307 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 308 | if pretrained: 309 | # model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 310 | model_state = torch.load('/home/zou_ke/projects/TMC_ICLR/pretrain/res2net101_v1b_26w_4s-0812c246.pth') 311 | model.load_state_dict(model_state) 312 | return model 313 | 314 | 315 | def res2net152_v1b_26w_4s(pretrained=False, **kwargs): 316 | """Constructs a Res2Net-50_v1b_26w_4s lib. 317 | Args: 318 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 319 | """ 320 | model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth=26, scale=4, **kwargs) 321 | if pretrained: 322 | model.load_state_dict(model_zoo.load_url(model_urls['res2net152_v1b_26w_4s'])) 323 | return model 324 | 325 | 326 | if __name__ == '__main__': 327 | images = torch.rand(1, 3, 224, 224).cuda(0) 328 | model = res2net50_v1b_26w_4s(pretrained=True) 329 | model = model.cuda(0) 330 | print(model(images).size()) -------------------------------------------------------------------------------- /MICCAI23/Models/res2net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | 8 | __all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b', 'res2net50_v1b_26w_4s'] 9 | 10 | model_urls = { 11 | 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth', 12 | 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth', 13 | } 14 | 15 | 16 | class Bottle2neck(nn.Module): 17 | expansion = 4 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'): 20 | """ Constructor 21 | Args: 22 | inplanes: input channel dimensionality 23 | planes: output channel dimensionality 24 | stride: conv stride. Replaces pooling layer. 25 | downsample: None when stride = 1 26 | baseWidth: basic width of conv3x3 27 | scale: number of scale. 28 | type: 'normal': normal set. 'stage': first block of a new stage. 
29 | """ 30 | super(Bottle2neck, self).__init__() 31 | 32 | width = int(math.floor(planes * (baseWidth / 64.0))) 33 | self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False) 34 | self.bn1 = nn.BatchNorm2d(width * scale) 35 | 36 | if scale == 1: 37 | self.nums = 1 38 | else: 39 | self.nums = scale - 1 40 | if stype == 'stage': 41 | self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) 42 | convs = [] 43 | bns = [] 44 | for i in range(self.nums): 45 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False)) 46 | bns.append(nn.BatchNorm2d(width)) 47 | self.convs = nn.ModuleList(convs) 48 | self.bns = nn.ModuleList(bns) 49 | 50 | self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 52 | 53 | self.relu = nn.ReLU(inplace=True) 54 | self.downsample = downsample 55 | self.stype = stype 56 | self.scale = scale 57 | self.width = width 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | spx = torch.split(out, self.width, 1) 67 | for i in range(self.nums): 68 | if i == 0 or self.stype == 'stage': 69 | sp = spx[i] 70 | else: 71 | sp = sp + spx[i] 72 | sp = self.convs[i](sp) 73 | sp = self.relu(self.bns[i](sp)) 74 | if i == 0: 75 | out = sp 76 | else: 77 | out = torch.cat((out, sp), 1) 78 | if self.scale != 1 and self.stype == 'normal': 79 | out = torch.cat((out, spx[self.nums]), 1) 80 | elif self.scale != 1 and self.stype == 'stage': 81 | out = torch.cat((out, self.pool(spx[self.nums])), 1) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | class Res2Net2(nn.Module): 95 | 96 | def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000): 97 | self.inplanes = 64 98 | super(Res2Net2, self).__init__() 99 | self.baseWidth = baseWidth 100 | self.scale = scale 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 110 | self.avgpool = nn.AdaptiveAvgPool2d(1) 111 | self.fc = nn.Linear(512 * block.expansion, num_classes) 112 | 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 116 | elif isinstance(m, nn.BatchNorm2d): 117 | nn.init.constant_(m.weight, 1) 118 | nn.init.constant_(m.bias, 0) 119 | 120 | def _make_layer(self, block, planes, blocks, stride=1): 121 | downsample = None 122 | if stride != 1 or self.inplanes != planes * block.expansion: 123 | downsample = nn.Sequential( 124 | nn.Conv2d(self.inplanes, planes * block.expansion, 125 | kernel_size=1, stride=stride, bias=False), 126 | nn.BatchNorm2d(planes * block.expansion), 127 | ) 128 | 129 | layers = [] 130 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 131 | stype='stage', baseWidth = self.baseWidth, scale=self.scale)) 132 | self.inplanes = planes * block.expansion 133 
| for i in range(1, blocks): 134 | layers.append(block(self.inplanes, planes, baseWidth = self.baseWidth, scale=self.scale)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | x = self.bn1(x) 141 | x = self.relu(x) 142 | x = self.maxpool(x) 143 | 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | 149 | x = self.avgpool(x) 150 | x = x.view(x.size(0), -1) 151 | x = self.fc(x) 152 | 153 | return x 154 | 155 | class Res2Net(nn.Module): 156 | 157 | def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000): 158 | self.inplanes = 64 159 | super(Res2Net, self).__init__() 160 | self.baseWidth = baseWidth 161 | self.scale = scale 162 | self.conv1 = nn.Sequential( 163 | nn.Conv2d(3, 32, 3, 2, 1, bias=False), 164 | nn.BatchNorm2d(32), 165 | nn.ReLU(inplace=True), 166 | nn.Conv2d(32, 32, 3, 1, 1, bias=False), 167 | nn.BatchNorm2d(32), 168 | nn.ReLU(inplace=True), 169 | nn.Conv2d(32, 64, 3, 1, 1, bias=False) 170 | ) 171 | self.bn1 = nn.BatchNorm2d(64) 172 | self.relu = nn.ReLU() 173 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 174 | self.layer1 = self._make_layer(block, 64, layers[0]) 175 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 176 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 177 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 178 | self.avgpool = nn.AdaptiveAvgPool2d(1) 179 | self.fc = nn.Linear(512 * block.expansion, num_classes) 180 | 181 | for m in self.modules(): 182 | if isinstance(m, nn.Conv2d): 183 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 184 | elif isinstance(m, nn.BatchNorm2d): 185 | nn.init.constant_(m.weight, 1) 186 | nn.init.constant_(m.bias, 0) 187 | 188 | def _make_layer(self, block, planes, blocks, stride=1): 189 | downsample = None 190 | if stride != 1 or self.inplanes != planes * block.expansion: 191 | downsample = nn.Sequential( 192 | nn.AvgPool2d(kernel_size=stride, stride=stride, 193 | ceil_mode=True, count_include_pad=False), 194 | nn.Conv2d(self.inplanes, planes * block.expansion, 195 | kernel_size=1, stride=1, bias=False), 196 | nn.BatchNorm2d(planes * block.expansion), 197 | ) 198 | 199 | layers = [] 200 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 201 | stype='stage', baseWidth=self.baseWidth, scale=self.scale)) 202 | self.inplanes = planes * block.expansion 203 | for i in range(1, blocks): 204 | layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)) 205 | 206 | return nn.Sequential(*layers) 207 | 208 | def forward(self, x): 209 | x = self.conv1(x) 210 | x = self.bn1(x) 211 | x = self.relu(x) 212 | x = self.maxpool(x) 213 | 214 | x = self.layer1(x) 215 | x = self.layer2(x) 216 | x = self.layer3(x) 217 | x = self.layer4(x) 218 | 219 | x = self.avgpool(x) 220 | x = x.view(x.size(0), -1) 221 | x = self.fc(x) 222 | 223 | return x 224 | 225 | 226 | def res2net50_v1b(pretrained=False, **kwargs): 227 | """Constructs a Res2Net-50_v1b lib. 228 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s. 
229 | Args: 230 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 231 | """ 232 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 233 | if pretrained: 234 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 235 | return model 236 | 237 | 238 | def res2net101_v1b(pretrained=False, **kwargs): 239 | """Constructs a Res2Net-50_v1b_26w_4s lib. 240 | Args: 241 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 242 | """ 243 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 244 | if pretrained: 245 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 246 | return model 247 | 248 | 249 | def res2net50_v1b_26w_4s(pretrained=False, **kwargs): 250 | """Constructs a Res2Net-50_v1b_26w_4s lib. 251 | Args: 252 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 253 | """ 254 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 255 | # model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, num_classes=2) # changed by coco 256 | 257 | if pretrained: 258 | model_state = torch.load('/home/zou_ke/projects/TMC/TMC_ICLR/pretrain/res2net50_v1b_26w_4s-3cf99910.pth') 259 | model.load_state_dict(model_state) 260 | # lib.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 261 | return model 262 | 263 | def res2net50_v1b_14w_8s(pretrained=False, **kwargs): 264 | """Constructs a Res2Net-50_v1b_26w_4s lib. 265 | Args: 266 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 267 | """ 268 | model = Res2Net2(Bottle2neck, [3, 4, 6, 3], baseWidth=14, scale=8, **kwargs) 269 | # model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, num_classes=2) # changed by coco 270 | 271 | if pretrained: 272 | model_state = torch.load('/home/zou_ke/projects/TMC/TMC_ICLR/pretrain/res2net50_14w_8s-6527dddc.pth') 273 | model.load_state_dict(model_state) 274 | # lib.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 275 | return model 276 | 277 | # def res2net50_v1b_14w_8s(pretrained=False, **kwargs): 278 | # """Constructs a Res2Net-50_14w_8s model. 279 | # Args: 280 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 281 | # """ 282 | # model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 14, scale = 8, **kwargs) 283 | # if pretrained: 284 | # # model = nn.DataParallel(model).cuda() 285 | # model_state = torch.load('/home/zou_ke/projects/TMC/TMC_ICLR/pretrain/res2net50_14w_8s-6527dddc.pth') 286 | # # new_state_dict = OrderedDict() 287 | # # for k, v in model_state.items(): 288 | # # name = k[7:] # module字段在最前面,从第7个字符开始就可以去掉module 289 | # # new_state_dict[name] = v # 新字典的key值对应的value一一对应 290 | # # kk = model_state.OrderedDict 291 | # model.load_state_dict(model_state, strict=False) 292 | # # model_state = torch.load('/home/zou_ke/projects/TMC/TMC_ICLR/pretrain/res2net50_14w_8s-6527dddc.pth') 293 | # # new_state_dict = OrderedDict() 294 | # # for k, v in model_state.items(): # k为module.xxx.weight, v为权重 295 | # # name = k.split('.')[0] # 截取`module.`后面的xxx.weight 296 | # # new_state_dict[name] = v 297 | # # # load params 298 | # # model.load_state_dict(new_state_dict) 299 | # return model 300 | 301 | 302 | def res2net101_v1b_26w_4s(pretrained=False, **kwargs): 303 | """Constructs a Res2Net-50_v1b_26w_4s lib. 
304 | Args: 305 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 306 | """ 307 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 308 | if pretrained: 309 | # model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 310 | model_state = torch.load('/home/zou_ke/projects/TMC/TMC_ICLR/pretrain/res2net101_v1b_26w_4s-0812c246.pth') 311 | model.load_state_dict(model_state) 312 | return model 313 | 314 | 315 | def res2net152_v1b_26w_4s(pretrained=False, **kwargs): 316 | """Constructs a Res2Net-50_v1b_26w_4s lib. 317 | Args: 318 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 319 | """ 320 | model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth=26, scale=4, **kwargs) 321 | if pretrained: 322 | model.load_state_dict(model_zoo.load_url(model_urls['res2net152_v1b_26w_4s'])) 323 | return model 324 | 325 | 326 | if __name__ == '__main__': 327 | images = torch.rand(1, 3, 224, 224).cuda(0) 328 | model = res2net50_v1b_26w_4s(pretrained=True) 329 | model = model.cuda(0) 330 | print(model(images).size()) -------------------------------------------------------------------------------- /MedIA’24/Models/swin_transformer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.checkpoint as checkpoint 11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 12 | # from utils import utils 13 | # args = utils.parse_command() 14 | 15 | class Mlp(nn.Module): 16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 17 | super().__init__() 18 | out_features = out_features or in_features 19 | hidden_features = hidden_features or in_features 20 | self.fc1 = nn.Linear(in_features, hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | def forward(self, x): 26 | x = self.fc1(x) 27 | x = self.act(x) 28 | x = self.drop(x) 29 | x = self.fc2(x) 30 | x = self.drop(x) 31 | return x 32 | 33 | def window_partition(x, window_size): 34 | """ 35 | Args: 36 | x: (B, H, W, C) 37 | window_size (int): window size 38 | Returns: 39 | windows: (num_windows*B, window_size, window_size, C) 40 | """ 41 | 42 | B, H, W, C = x.shape 43 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 44 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 45 | return windows 46 | 47 | def window_reverse(windows, window_size, H, W): 48 | """ 49 | Args: 50 | windows: (num_windows*B, window_size, window_size, C) 51 | window_size (int): Window size 52 | H (int): Height of image 53 | W (int): Width of image 54 | Returns: 55 | x: (B, H, W, C) 56 | """ 57 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 58 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 59 | 60 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 61 | 62 | return x 63 | 64 | class WindowAttention(nn.Module): 65 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 66 | It supports both of shifted and non-shifted window. 
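    Note (added for clarity, derived from the forward pass below): each window computes
    scaled dot-product attention with a learned relative position bias, i.e.
    softmax(q k^T * head_dim**-0.5 + B [+ mask]) v, where B is gathered from
    `relative_position_bias_table` using `relative_position_index`.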
67 | Args: 68 | dim (int): Number of input channels. 69 | window_size (tuple[int]): The height and width of the window. 70 | num_heads (int): Number of attention heads. 71 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 72 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 73 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 74 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 75 | """ 76 | 77 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 78 | super().__init__() 79 | self.dim = dim 80 | self.window_size = window_size # Wh, Ww 81 | self.num_heads = num_heads 82 | head_dim = dim // num_heads 83 | self.scale = qk_scale or head_dim ** -0.5 84 | 85 | # define a parameter table of relative position bias 86 | self.relative_position_bias_table = nn.Parameter( 87 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 88 | 89 | # get pair-wise relative position index for each token inside the window 90 | coords_h = torch.arange(self.window_size[0]) 91 | coords_w = torch.arange(self.window_size[1]) 92 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 93 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 94 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 95 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 96 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 97 | relative_coords[:, :, 1] += self.window_size[1] - 1 98 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 99 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 100 | self.register_buffer("relative_position_index", relative_position_index) 101 | 102 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 103 | self.attn_drop = nn.Dropout(attn_drop) 104 | self.proj = nn.Linear(dim, dim) 105 | 106 | self.proj_drop = nn.Dropout(proj_drop) 107 | 108 | trunc_normal_(self.relative_position_bias_table, std=.02) 109 | self.softmax = nn.Softmax(dim=-1) 110 | 111 | def forward(self, x, mask=None): 112 | """ 113 | Args: 114 | x: input features with shape of (num_windows*B, N, C) 115 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 116 | """ 117 | 118 | B_, N, C = x.shape 119 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 120 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 121 | 122 | q = q * self.scale 123 | attn = torch.matmul(q,k.transpose(-2,-1)) #(q @ k.transpose(-2, -1)) 124 | 125 | 126 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 127 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 128 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 129 | attn = attn + relative_position_bias.unsqueeze(0) 130 | 131 | if mask is not None: 132 | nW = mask.shape[0] 133 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 134 | attn = attn.view(-1, self.num_heads, N, N) 135 | attn = self.softmax(attn) 136 | else: 137 | attn = self.softmax(attn) 138 | 139 | attn = self.attn_drop(attn) 140 | 141 | #x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 142 | x = 
(torch.matmul(attn, v).transpose(1, 2).reshape(B_, N, C) ) 143 | x = self.proj(x) 144 | 145 | x = self.proj_drop(x) 146 | return x 147 | 148 | #def extra_repr(self) -> str: 149 | # return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 150 | 151 | 152 | def flops(self, N): 153 | # calculate flops for 1 window with token length of N 154 | flops = 0 155 | # qkv = self.qkv(x) 156 | flops += N * self.dim * 3 * self.dim 157 | # attn = (q @ k.transpose(-2, -1)) 158 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 159 | # x = (attn @ v) 160 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 161 | # x = self.proj(x) 162 | flops += N * self.dim * self.dim 163 | return flops 164 | 165 | class SwinTransformerBlock(nn.Module): 166 | r""" Swin Transformer Block. 167 | Args: 168 | dim (int): Number of input channels. 169 | input_resolution (tuple[int]): Input resulotion. 170 | num_heads (int): Number of attention heads. 171 | window_size (int): Window size. 172 | shift_size (int): Shift size for SW-MSA. 173 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 174 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 175 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 176 | drop (float, optional): Dropout rate. Default: 0.0 177 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 178 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 179 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 180 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 181 | """ 182 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 183 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 184 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 185 | super().__init__() 186 | self.dim = dim 187 | self.input_resolution = input_resolution 188 | self.num_heads = num_heads 189 | self.window_size = window_size 190 | self.shift_size = shift_size 191 | self.mlp_ratio = mlp_ratio 192 | if min(self.input_resolution) <= self.window_size: 193 | 194 | # if window size is larger than input resolution, we don't partition windows 195 | 196 | self.shift_size = 0 197 | 198 | self.window_size = min(self.input_resolution) 199 | 200 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 201 | 202 | self.norm1 = norm_layer(dim) 203 | self.attn = WindowAttention( 204 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 205 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 206 | 207 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 208 | self.norm2 = norm_layer(dim) 209 | mlp_hidden_dim = int(dim * mlp_ratio) 210 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 211 | 212 | if self.shift_size > 0: 213 | # calculate attention mask for SW-MSA 214 | H, W = self.input_resolution 215 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 216 | h_slices = (slice(0, -self.window_size), 217 | slice(-self.window_size, -self.shift_size), 218 | slice(-self.shift_size, None)) 219 | w_slices = (slice(0, -self.window_size), 220 | slice(-self.window_size, -self.shift_size), 221 | slice(-self.shift_size, None)) 222 | cnt = 0 223 | for h in h_slices: 224 | 225 | for w in w_slices: 226 | 227 | img_mask[:, h, w, :] = cnt 228 | 229 | cnt += 1 230 | 231 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 232 | 233 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 234 | 235 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 236 | 237 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 238 | 239 | else: 240 | 241 | attn_mask = None 242 | 243 | 244 | 245 | self.register_buffer("attn_mask", attn_mask) 246 | 247 | 248 | 249 | def forward(self, x): 250 | 251 | H, W = self.input_resolution 252 | 253 | B, L, C = x.shape 254 | 255 | assert L == H * W, "input feature has wrong size" 256 | 257 | 258 | 259 | shortcut = x 260 | 261 | x = self.norm1(x) 262 | 263 | x = x.view(B, H, W, C) 264 | 265 | 266 | 267 | # cyclic shift 268 | 269 | if self.shift_size > 0: 270 | 271 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 272 | 273 | else: 274 | 275 | shifted_x = x 276 | 277 | 278 | 279 | # partition windows 280 | 281 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 282 | 283 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 284 | 285 | 286 | 287 | # W-MSA/SW-MSA 288 | 289 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 290 | 291 | 292 | 293 | # merge windows 294 | 295 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 296 | 297 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 298 | 299 | 300 | 301 | # reverse cyclic shift 302 | 303 | if self.shift_size > 0: 304 | 305 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 306 | 307 | else: 308 | 309 | x = shifted_x 310 | 311 | x = x.view(B, H * W, C) 312 | 313 | 314 | 315 | # FFN 316 | 317 | x = shortcut + self.drop_path(x) 318 | 319 | x = x + self.drop_path(self.mlp(self.norm2(x))) 320 | 321 | 322 | 323 | return x 324 | 325 | 326 | 327 | #def extra_repr(self) -> str: 328 | # return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 329 | # f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 330 | 331 | 332 | 333 | def flops(self): 334 | 335 | flops = 0 336 | 337 | H, W = self.input_resolution 338 | 339 | # norm1 340 | 341 | flops += self.dim * H * W 342 | 343 | # W-MSA/SW-MSA 344 | 345 | nW = H * W / self.window_size / self.window_size 346 | 347 | flops += nW * self.attn.flops(self.window_size * self.window_size) 348 | 349 | # mlp 350 | 351 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 352 | 353 | # norm2 354 | 355 | flops += self.dim * H * W 
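        # Added note: the counts above are rough multiply-accumulate estimates; softmax,
        # dropout and the relative-position-bias addition are not included, so treat the
        # returned value as an approximation rather than an exact FLOP count.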
356 | 357 | return flops 358 | 359 | 360 | 361 | 362 | 363 | class PatchMerging(nn.Module): 364 | 365 | r""" Patch Merging Layer. 366 | 367 | 368 | 369 | Args: 370 | 371 | input_resolution (tuple[int]): Resolution of input feature. 372 | 373 | dim (int): Number of input channels. 374 | 375 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 376 | 377 | """ 378 | 379 | 380 | 381 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 382 | 383 | super().__init__() 384 | 385 | self.input_resolution = input_resolution 386 | 387 | self.dim = dim 388 | 389 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 390 | 391 | self.norm = norm_layer(4 * dim) 392 | 393 | 394 | 395 | def forward(self, x): 396 | 397 | """ 398 | 399 | x: B, H*W, C 400 | 401 | """ 402 | 403 | H, W = self.input_resolution 404 | 405 | B, L, C = x.shape 406 | 407 | assert L == H * W, "input feature has wrong size" 408 | 409 | assert H % 2 == 0 and W % 2 == 0, "x size ({H}*{W}) are not even." 410 | 411 | 412 | 413 | x = x.view(B, H, W, C) 414 | 415 | 416 | 417 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 418 | 419 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 420 | 421 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 422 | 423 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 424 | 425 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 426 | 427 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 428 | 429 | 430 | 431 | x = self.norm(x) 432 | 433 | x = self.reduction(x) 434 | 435 | 436 | 437 | return x 438 | 439 | 440 | 441 | #def extra_repr(self) -> str: 442 | # return f"input_resolution={self.input_resolution}, dim={self.dim}" 443 | 444 | 445 | 446 | def flops(self): 447 | 448 | H, W = self.input_resolution 449 | 450 | flops = H * W * self.dim 451 | 452 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 453 | 454 | return flops 455 | 456 | 457 | 458 | 459 | 460 | class BasicLayer(nn.Module): 461 | 462 | """ A basic Swin Transformer layer for one stage. 463 | 464 | 465 | 466 | Args: 467 | 468 | dim (int): Number of input channels. 469 | 470 | input_resolution (tuple[int]): Input resolution. 471 | 472 | depth (int): Number of blocks. 473 | 474 | num_heads (int): Number of attention heads. 475 | 476 | window_size (int): Local window size. 477 | 478 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 479 | 480 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 481 | 482 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 483 | 484 | drop (float, optional): Dropout rate. Default: 0.0 485 | 486 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 487 | 488 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 489 | 490 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 491 | 492 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 493 | 494 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
495 | 496 | """ 497 | 498 | 499 | 500 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 501 | 502 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 503 | 504 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 505 | 506 | 507 | 508 | super().__init__() 509 | 510 | self.dim = dim 511 | 512 | self.input_resolution = input_resolution 513 | 514 | self.depth = depth 515 | 516 | self.use_checkpoint = use_checkpoint 517 | 518 | 519 | 520 | # build blocks 521 | 522 | self.blocks = nn.ModuleList([ 523 | 524 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 525 | 526 | num_heads=num_heads, window_size=window_size, 527 | 528 | shift_size=0 if (i % 2 == 0) else window_size // 2, 529 | 530 | mlp_ratio=mlp_ratio, 531 | 532 | qkv_bias=qkv_bias, qk_scale=qk_scale, 533 | 534 | drop=drop, attn_drop=attn_drop, 535 | 536 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 537 | 538 | norm_layer=norm_layer) 539 | 540 | for i in range(depth)]) 541 | 542 | 543 | 544 | # patch merging layer 545 | 546 | if downsample is not None: 547 | 548 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 549 | 550 | else: 551 | 552 | self.downsample = None 553 | 554 | 555 | 556 | def forward(self, x): 557 | 558 | for blk in self.blocks: 559 | 560 | if self.use_checkpoint: 561 | 562 | x = checkpoint.checkpoint(blk, x) 563 | 564 | else: 565 | 566 | x = blk(x) 567 | 568 | if self.downsample is not None: 569 | 570 | x = self.downsample(x) 571 | 572 | return x 573 | 574 | 575 | 576 | #def extra_repr(self) -> str: 577 | # return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 578 | 579 | def flops(self): 580 | flops = 0 581 | for blk in self.blocks: 582 | flops += blk.flops() 583 | if self.downsample is not None: 584 | flops += self.downsample.flops() 585 | return flops 586 | 587 | 588 | class PatchEmbed(nn.Module): 589 | r""" Image to Patch Embedding 590 | Args: 591 | img_size (int): Image size. Default: 224. 592 | patch_size (int): Patch token size. Default: 4. 593 | in_chans (int): Number of input image channels. Default: 3. 594 | embed_dim (int): Number of linear projection output channels. Default: 96. 595 | norm_layer (nn.Module, optional): Normalization layer. Default: None 596 | """ 597 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 598 | super().__init__() 599 | 600 | img_size = to_2tuple(img_size) 601 | patch_size = to_2tuple(patch_size) 602 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 603 | self.img_size = img_size 604 | self.patch_size = patch_size 605 | self.patches_resolution = patches_resolution 606 | self.num_patches = patches_resolution[0] * patches_resolution[1] 607 | 608 | self.in_chans = in_chans 609 | self.embed_dim = embed_dim 610 | 611 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 612 | 613 | if norm_layer is not None: 614 | self.norm = norm_layer(embed_dim) 615 | else: 616 | self.norm = None 617 | 618 | def forward(self, x): 619 | B, C, H, W = x.shape 620 | 621 | # FIXME look at relaxing size constraints 622 | assert H == self.img_size[0] and W == self.img_size[1], \ 623 | "Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
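        # Added shape walkthrough (assuming the defaults img_size=224, patch_size=4,
        # embed_dim=96): the strided conv below maps (B, 3, 224, 224) -> (B, 96, 56, 56);
        # flatten(2).transpose(1, 2) then yields the (B, 56*56, 96) token sequence
        # consumed by the Swin Transformer blocks.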
624 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 625 | if self.norm is not None: 626 | x = self.norm(x) 627 | return x 628 | 629 | def flops(self): 630 | Ho, Wo = self.patches_resolution 631 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 632 | if self.norm is not None: 633 | flops += Ho * Wo * self.embed_dim 634 | return flops 635 | 636 | 637 | class SwinTransformer(nn.Module): 638 | r""" Swin Transformer 639 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 640 | https://arxiv.org/pdf/2103.14030 641 | Args: 642 | img_size (int | tuple(int)): Input image size. Default 224 643 | patch_size (int | tuple(int)): Patch size. Default: 4 644 | in_chans (int): Number of input image channels. Default: 3 645 | num_classes (int): Number of classes for classification head. Default: 1000 646 | embed_dim (int): Patch embedding dimension. Default: 96 647 | depths (tuple(int)): Depth of each Swin Transformer layer. 648 | num_heads (tuple(int)): Number of attention heads in different layers. 649 | window_size (int): Window size. Default: 7 650 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 651 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 652 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 653 | drop_rate (float): Dropout rate. Default: 0 654 | attn_drop_rate (float): Attention dropout rate. Default: 0 655 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 656 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 657 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 658 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 659 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 660 | """ 661 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 662 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 663 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 664 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 665 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 666 | use_checkpoint=False, **kwargs): 667 | super().__init__() 668 | 669 | 670 | self.num_classes = num_classes 671 | self.num_layers = len(depths) 672 | self.embed_dim = embed_dim 673 | 674 | self.ape = ape 675 | self.patch_norm = patch_norm 676 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 677 | self.mlp_ratio = mlp_ratio 678 | 679 | 680 | # split image into non-overlapping patches 681 | self.patch_embed = PatchEmbed( 682 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 683 | norm_layer=norm_layer if self.patch_norm else None) 684 | num_patches = self.patch_embed.num_patches 685 | patches_resolution = self.patch_embed.patches_resolution 686 | self.patches_resolution = patches_resolution 687 | 688 | # absolute position embedding 689 | if self.ape: 690 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 691 | trunc_normal_(self.absolute_pos_embed, std=.02) 692 | 693 | self.pos_drop = nn.Dropout(p=drop_rate) 694 | 695 | 696 | # stochastic depth 697 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 698 | 699 | 700 | # build layers 701 | self.layers = nn.ModuleList() 702 | for i_layer in range(self.num_layers): 703 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 704 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 705 | patches_resolution[1] // (2 ** i_layer)), 706 | depth=depths[i_layer], 707 | num_heads=num_heads[i_layer], 708 | window_size=window_size, 709 | mlp_ratio=self.mlp_ratio, 710 | qkv_bias=qkv_bias, qk_scale=qk_scale, 711 | drop=drop_rate, attn_drop=attn_drop_rate, 712 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 713 | norm_layer=norm_layer, 714 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 715 | use_checkpoint=use_checkpoint) 716 | self.layers.append(layer) 717 | 718 | self.norm = norm_layer(self.num_features) 719 | self.avgpool = nn.AdaptiveAvgPool1d(1) 720 | self.head = nn.Linear(self.num_features, 1000) #if num_classes > 0 else nn.Identity() 721 | self.head_new = nn.Linear(self.num_features, num_classes) #if num_classes > 0 else nn.Identity() 722 | 723 | self.apply(self._init_weights) 724 | 725 | 726 | def _init_weights(self, m): 727 | 728 | if isinstance(m, nn.Linear): 729 | 730 | trunc_normal_(m.weight, std=.02) 731 | if isinstance(m, nn.Linear) and m.bias is not None: 732 | nn.init.constant_(m.bias, 0) 733 | elif isinstance(m, nn.LayerNorm): 734 | nn.init.constant_(m.bias, 0) 735 | nn.init.constant_(m.weight, 1.0) 736 | 737 | @torch.jit.ignore 738 | def no_weight_decay(self): 739 | return {'absolute_pos_embed'} 740 | 741 | @torch.jit.ignore 742 | def no_weight_decay_keywords(self): 743 | return {'relative_position_bias_table'} 744 | 745 | def forward_features(self, x): 746 | x = self.patch_embed(x) 747 | 748 | if self.ape: 749 | x = x + self.absolute_pos_embed 750 | x = self.pos_drop(x) 751 | 752 | for layer in self.layers: 753 | x = layer(x) 754 | 755 | x = self.norm(x) # B L C 756 | x = self.avgpool(x.transpose(1, 2)) # B C 1 757 | x = torch.flatten(x, 1) 758 | return x 759 | 760 | def forward(self, x): 761 | x = 
self.forward_features(x) 762 | # if args.dataset == 'GAMMA-all': 763 | # return x 764 | # x = self.head_new(x) 765 | return x 766 | 767 | def flops(self): 768 | flops = 0 769 | flops += self.patch_embed.flops() 770 | for i, layer in enumerate(self.layers): 771 | flops += layer.flops() 772 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 773 | flops += self.num_features * self.num_classes 774 | return flops 775 | 776 | -------------------------------------------------------------------------------- /MedIA’24/baseline_models.py: -------------------------------------------------------------------------------- 1 | from Models.generate_model import * 2 | from Models.res2net import res2net50_v1b_26w_4s,res2net50_v1b_14w_8s,res2net101_v1b_26w_4s 3 | import torch.nn.functional as F 4 | class Medical_feature_2DNet(nn.Module): 5 | # res2net based encoder decoder 6 | def __init__(self, num_classes=10): 7 | super(Medical_feature_2DNet, self).__init__() 8 | # ---- ResNet Backbone ---- 9 | self.res2net = res2net50_v1b_26w_4s(pretrained=True) 10 | def forward(self, x): 11 | #origanal x do: 12 | x = self.res2net.conv1(x) 13 | x = self.res2net.bn1(x) 14 | x = self.res2net.relu(x) 15 | x = self.res2net.maxpool(x) # bs, 64, 128, 128 16 | # ---- low-level features ---- 17 | x1 = self.res2net.layer1(x) # bs, 256, 128, 128 18 | x2 = self.res2net.layer2(x1) # bs, 512, 64, 64 19 | x3 = self.res2net.layer3(x2) # bs, 1024, 32, 32 20 | x4 = self.res2net.layer4(x3) # bs, 2048, 16, 16 21 | return x4 22 | 23 | 24 | class Medical_base_2DNet(nn.Module): 25 | # res2net based encoder decoder 26 | def __init__(self, num_classes=10): 27 | super(Medical_base_2DNet, self).__init__() 28 | # ---- ResNet Backbone ---- 29 | self.res2net = res2net50_v1b_26w_4s(pretrained=True) 30 | def forward(self, x): 31 | #origanal x do: 32 | x = self.res2net.conv1(x) 33 | x = self.res2net.bn1(x) 34 | x = self.res2net.relu(x) 35 | x = self.res2net.maxpool(x) # bs, 64, 128, 128 36 | # ---- low-level features ---- 37 | x1 = self.res2net.layer1(x) # bs, 256, 128, 128 38 | x2 = self.res2net.layer2(x1) # bs, 512, 64, 64 39 | x3 = self.res2net.layer3(x2) # bs, 1024, 32, 32 40 | x4 = self.res2net.layer4(x3) # bs, 2048, 16, 16 41 | x4 = self.res2net.avgpool(x4) # bs, 2048, 1, 1 42 | x4 = x4.view(x4.size(0), -1) # bs, 1, 2048, 43 | return x4 44 | 45 | class Medical_base2_2DNet(nn.Module): 46 | # res2net based encoder decoder 47 | def __init__(self, num_classes=10): 48 | super(Medical_base2_2DNet, self).__init__() 49 | # ---- ResNet Backbone ---- 50 | self.res2net = res2net50_v1b_14w_8s(pretrained=True) 51 | # self.res2net = res2net50_v1b_26w_4s(pretrained=True) 52 | # self.res2net = res2net101_v1b_26w_4s(pretrained=True) 53 | 54 | def forward(self, x): 55 | #origanal x do: 56 | x = self.res2net.conv1(x) 57 | x = self.res2net.bn1(x) 58 | x = self.res2net.relu(x) 59 | x = self.res2net.maxpool(x) # bs, 64, 64, 64 60 | # ---- low-level features ---- 61 | x1 = self.res2net.layer1(x) # bs, 256, 64, 64 62 | x2 = self.res2net.layer2(x1) # bs, 512, 32, 32 63 | x3 = self.res2net.layer3(x2) # bs, 1024, 16, 16 64 | x4 = self.res2net.layer4(x3) # bs, 2048, 8, 8 65 | x4 = self.res2net.avgpool(x4) 66 | x4 = x4.view(x4.size(0), -1) 67 | return x4 68 | 69 | class Medical_base_dropout_2DNet(nn.Module): 70 | # res2net based encoder decoder 71 | def __init__(self, num_classes=10): 72 | super(Medical_base_dropout_2DNet, self).__init__() 73 | # ---- ResNet Backbone ---- 74 | self.res2net = 
res2net50_v1b_26w_4s(pretrained=True) 75 | def forward(self, x): 76 | #origanal x do: 77 | x = self.res2net.conv1(x) 78 | x = self.res2net.bn1(x) 79 | x = self.res2net.relu(x) 80 | x = self.res2net.maxpool(x) # bs, 64, 64, 64 81 | # dropout layer 82 | x = F.dropout(x, p=0.2) 83 | # ---- low-level features ---- 84 | x1 = self.res2net.layer1(x) # bs, 256, 64, 64 85 | x2 = self.res2net.layer2(x1) # bs, 512, 32, 32 86 | x3 = self.res2net.layer3(x2) # bs, 1024, 16, 16 87 | x4 = self.res2net.layer4(x3) # bs, 2048, 8, 8 88 | # dropout layer 89 | x4 = F.dropout(x4, p=0.2) 90 | x4 = self.res2net.avgpool(x4) 91 | x4 = x4.view(x4.size(0), -1) 92 | return x4 93 | 94 | class Medical_2DNet(nn.Module): 95 | # res2net based encoder decoder 96 | def __init__(self, num_classes=10): 97 | super(Medical_2DNet, self).__init__() 98 | # ---- ResNet Backbone ---- 99 | self.res2net = res2net50_v1b_26w_4s(pretrained=True) 100 | self.fc = nn.Linear(2048, num_classes) 101 | 102 | def forward(self, x): 103 | #origanal x do: 104 | x = self.res2net.conv1(x) 105 | x = self.res2net.bn1(x) 106 | x = self.res2net.relu(x) 107 | x = self.res2net.maxpool(x) # bs, 64, 64, 64 108 | # ---- low-level features ---- 109 | x1 = self.res2net.layer1(x) # bs, 256, 64, 64 110 | x2 = self.res2net.layer2(x1) # bs, 512, 32, 32 111 | x3 = self.res2net.layer3(x2) # bs, 1024, 16, 16 112 | x4 = self.res2net.layer4(x3) # bs, 2048, 8, 8 113 | x4 = self.res2net.avgpool(x4) 114 | x4 = x4.view(x4.size(0), -1) 115 | out = self.fc(x4) 116 | return out 117 | 118 | 119 | class Medical_3DNet(nn.Module): 120 | # res2net based encoder decoder 121 | def __init__(self, classifier_OCT_dims,num_classes=10): 122 | super(Medical_3DNet, self).__init__() 123 | # ---- ResNet Backbone ---- 124 | self.resnet_3DNet = generate_model(model_type='resnet', model_depth=10, input_W=classifier_OCT_dims[0][0], 125 | input_H=classifier_OCT_dims[0][1], input_D=classifier_OCT_dims[0][2], 126 | resnet_shortcut='B', 127 | no_cuda=True, gpu_id=[0], 128 | pretrain_path='./pretrain/resnet_10_23dataset.pth', nb_class=num_classes) 129 | if classifier_OCT_dims[0][0] == 128: 130 | self.fc = nn.Linear(8192, num_classes) # MMOCTF 131 | else: 132 | self.fc = nn.Linear(3072, num_classes) # OLIVES 133 | 134 | def forward(self, x): 135 | 136 | x = self.resnet_3DNet.conv1(x) 137 | x = self.resnet_3DNet.bn1(x) 138 | x = self.resnet_3DNet.relu(x) 139 | x = self.resnet_3DNet.maxpool(x) # bs, 64, 64, 64 140 | # ---- low-level features ---- 141 | x1 = self.resnet_3DNet.layer1(x) # bs, 256, 64, 64 142 | x2 = self.resnet_3DNet.layer2(x1) # bs, 512, 32, 32 143 | x3 = self.resnet_3DNet.layer3(x2) # bs, 1024, 16, 16 144 | x4 = self.resnet_3DNet.layer4(x3) # bs, 2048, 8, 8 145 | x4 = self.resnet_3DNet.avgpool(x4) 146 | x4 = x4.view(x4.size(0), -1) 147 | out = self.fc(x4) 148 | return out 149 | 150 | class Medical_base_3DNet(nn.Module): 151 | # res2net based encoder decoder 152 | def __init__(self, classifier_OCT_dims,num_classes=10): 153 | super(Medical_base_3DNet, self).__init__() 154 | # ---- ResNet Backbone ---- 155 | self.resnet_3DNet = generate_model(model_type='resnet', model_depth=10, input_W=classifier_OCT_dims[0][0], 156 | input_H=classifier_OCT_dims[0][1], input_D=classifier_OCT_dims[0][2], 157 | resnet_shortcut='B', 158 | no_cuda=True, gpu_id=[0], 159 | pretrain_path='./pretrain/resnet_10_23dataset.pth', nb_class=num_classes) 160 | 161 | def forward(self, x): 162 | 163 | x = self.resnet_3DNet.conv1(x) 164 | x = self.resnet_3DNet.bn1(x) 165 | x = self.resnet_3DNet.relu(x) 166 | x = 
self.resnet_3DNet.maxpool(x) # bs, 64, 32, 32,64 167 | # ---- low-level features ---- 168 | x1 = self.resnet_3DNet.layer1(x) # bs, 64, 32, 32,64 169 | x2 = self.resnet_3DNet.layer2(x1) # bs, 128, 16, 16,32 170 | x3 = self.resnet_3DNet.layer3(x2) # bs, 256, 16, 16,32 171 | x4 = self.resnet_3DNet.layer4(x3) # bs, 512, 16, 16,32 172 | x4 = self.resnet_3DNet.avgpool(x4) # bs, 512, 16, 1,1 173 | x4 = x4.view(x4.size(0), -1) # 8192 174 | return x4 175 | 176 | class Medical_feature_3DNet(nn.Module): 177 | # res2net based encoder decoder 178 | def __init__(self, classifier_OCT_dims,num_classes=10): 179 | super(Medical_feature_3DNet, self).__init__() 180 | # ---- ResNet Backbone ---- 181 | self.resnet_3DNet = generate_model(model_type='resnet', model_depth=10, input_W=classifier_OCT_dims[0][0], 182 | input_H=classifier_OCT_dims[0][1], input_D=classifier_OCT_dims[0][2], 183 | resnet_shortcut='B', 184 | no_cuda=True, gpu_id=[0], 185 | pretrain_path='./pretrain/resnet_10_23dataset.pth', nb_class=num_classes) 186 | 187 | def forward(self, x): 188 | 189 | x = self.resnet_3DNet.conv1(x) 190 | x = self.resnet_3DNet.bn1(x) 191 | x = self.resnet_3DNet.relu(x) 192 | x = self.resnet_3DNet.maxpool(x) # bs, 64, 32, 32,64 193 | # ---- low-level features ---- 194 | x1 = self.resnet_3DNet.layer1(x) # bs, 64, 32, 32,64 195 | x2 = self.resnet_3DNet.layer2(x1) # bs, 128, 16, 16,32 196 | x3 = self.resnet_3DNet.layer3(x2) # bs, 256, 16, 16,32 197 | x4 = self.resnet_3DNet.layer4(x3) # bs, 512, 16, 16,32 198 | return x4 199 | 200 | class Medical_base2_3DNet(nn.Module): 201 | # res2net based encoder decoder 202 | def __init__(self, classifier_OCT_dims,num_classes=10): 203 | super(Medical_base2_3DNet, self).__init__() 204 | # ---- ResNet Backbone ---- 205 | self.resnet_3DNet = generate_model(model_type='resnet', model_depth=18, input_W=classifier_OCT_dims[0][0], 206 | input_H=classifier_OCT_dims[0][1], input_D=classifier_OCT_dims[0][2], 207 | resnet_shortcut='A', 208 | no_cuda=True, gpu_id=[0], 209 | pretrain_path='./pretrain/resnet_18_23dataset.pth', nb_class=num_classes) 210 | 211 | def forward(self, x): 212 | 213 | x = self.resnet_3DNet.conv1(x) 214 | x = self.resnet_3DNet.bn1(x) 215 | x = self.resnet_3DNet.relu(x) 216 | x = self.resnet_3DNet.maxpool(x) # bs, 64, 32, 32,64 217 | # ---- low-level features ---- 218 | x1 = self.resnet_3DNet.layer1(x) ## bs, 64, 32, 32,64 219 | x2 = self.resnet_3DNet.layer2(x1) # bs, 128, 16, 16,32 220 | x3 = self.resnet_3DNet.layer3(x2) # bs, 256, 16, 16,32 221 | x4 = self.resnet_3DNet.layer4(x3) # bs, 512, 16, 16,32 222 | x4 = self.resnet_3DNet.avgpool(x4) # bs, 512, 16, 1,1 223 | x4 = x4.view(x4.size(0), -1) # 8192 224 | return x4 225 | 226 | class Medical_base_dropout_3DNet(nn.Module): 227 | # res2net based encoder decoder 228 | def __init__(self, classifier_OCT_dims,num_classes=10): 229 | super(Medical_base_dropout_3DNet, self).__init__() 230 | # ---- ResNet Backbone ---- 231 | self.resnet_3DNet = generate_model(model_type='resnet', model_depth=10, input_W=classifier_OCT_dims[0][0], 232 | input_H=classifier_OCT_dims[0][1], input_D=classifier_OCT_dims[0][2], 233 | resnet_shortcut='B', 234 | no_cuda=True, gpu_id=[0], 235 | pretrain_path='./pretrain/resnet_10_23dataset.pth', nb_class=num_classes) 236 | 237 | def forward(self, x): 238 | 239 | x = self.resnet_3DNet.conv1(x) 240 | x = self.resnet_3DNet.bn1(x) 241 | x = self.resnet_3DNet.relu(x) 242 | x = self.resnet_3DNet.maxpool(x) # bs, 64, 64, 64 243 | # dropout layer 244 | x = F.dropout(x, p=0.2) 245 | # ---- low-level features ---- 246 
| x1 = self.resnet_3DNet.layer1(x) # bs, 256, 64, 64 247 | x2 = self.resnet_3DNet.layer2(x1) # bs, 512, 32, 32 248 | x3 = self.resnet_3DNet.layer3(x2) # bs, 1024, 16, 16 249 | x4 = self.resnet_3DNet.layer4(x3) # bs, 2048, 8, 8 250 | x4 = self.resnet_3DNet.avgpool(x4) 251 | # dropout layer 252 | x4 = F.dropout(x4, p=0.2) 253 | x4 = x4.view(x4.size(0), -1) 254 | return x4 255 | 256 | class ResNet3D(nn.Module): 257 | 258 | def __init__(self, classes, modalties, classifiers_dims, lambda_epochs=1): 259 | """ 260 | :param classes: Number of classification categories 261 | :param views: Number of modalties 262 | :param classifier_dims: Dimension of the classifier 263 | :param annealing_epoch: KL divergence annealing epoch during training 264 | """ 265 | super(ResNet3D, self).__init__() 266 | self.modalties = modalties 267 | self.classes = classes 268 | self.lambda_epochs = lambda_epochs 269 | 270 | # ---- 3D ResNet Backbone ---- 271 | classifier_OCT_dims = classifiers_dims 272 | self.resnet_3DNet = Medical_3DNet(classifier_OCT_dims,num_classes=self.classes) 273 | self.Classifiers= nn.ModuleList([self.resnet_3DNet]) 274 | self.bce_loss = nn.BCELoss() 275 | self.ce_loss = nn.CrossEntropyLoss() 276 | self.sfm = nn.Softmax() 277 | 278 | def forward(self, X, y): 279 | output = self.infer(X[1]) 280 | loss = 0 281 | for v_num in range(self.modalties): 282 | pred = output[v_num] 283 | # label = F.one_hot(y, num_classes=self.classes) 284 | # loss = self.ce_loss(label, pred) 285 | loss = self.ce_loss(pred, y) 286 | 287 | loss = torch.mean(loss) 288 | return pred, loss 289 | 290 | def infer(self, input): 291 | """ 292 | :param input: Multi-view data 293 | :return: evidence of every view 294 | """ 295 | evidence = dict() 296 | for m_num in range(self.modalties): 297 | backbone_output = self.Classifiers[m_num](input) 298 | evidence[m_num] = self.sfm(backbone_output) 299 | return evidence 300 | 301 | class Res2Net2D(nn.Module): 302 | 303 | def __init__(self, classes, modalties, classifiers_dims, lambda_epochs=1): 304 | """ 305 | :param classes: Number of classification categories 306 | :param views: Number of modalties 307 | :param classifier_dims: Dimension of the classifier 308 | :param annealing_epoch: KL divergence annealing epoch during training 309 | """ 310 | super(Res2Net2D, self).__init__() 311 | self.modalties = modalties 312 | self.classes = classes 313 | self.lambda_epochs = lambda_epochs 314 | # ---- 2D Res2Net Backbone ---- 315 | classifier_Fundus_dims = classifiers_dims[0] 316 | self.res2net_2DNet = Medical_2DNet(num_classes=self.classes) 317 | self.Classifiers= nn.ModuleList([self.res2net_2DNet]) 318 | self.bce_loss = nn.BCELoss() 319 | self.ce_loss = nn.CrossEntropyLoss() 320 | self.sfm = nn.Softmax() 321 | 322 | def forward(self, X, y): 323 | output = self.infer(X[0]) 324 | loss = 0 325 | for v_num in range(self.modalties): 326 | pred = output[v_num] 327 | # label = F.one_hot(y, num_classes=self.classes) 328 | # loss = self.ce_loss(label, pred) 329 | loss = self.ce_loss(pred, y) 330 | 331 | loss = torch.mean(loss) 332 | return pred, loss 333 | 334 | def infer(self, input): 335 | """ 336 | :param input: Multi-view data 337 | :return: evidence of every view 338 | """ 339 | evidence = dict() 340 | for m_num in range(self.modalties): 341 | backbone_output = self.Classifiers[m_num](input) 342 | evidence[m_num] = self.sfm(backbone_output) 343 | return evidence 344 | 345 | class Multi_ResNet(nn.Module): 346 | 347 | def __init__(self, classes, modalties, classifiers_dims, lambda_epochs=1): 348 | """ 
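Intermediate (feature-level) fusion baseline. Summary added for clarity, inferred from the code
below: 2D Res2Net fundus features and 3D ResNet OCT features are concatenated and classified by a
single linear layer trained with cross-entropy.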
349 | :param classes: Number of classification categories 350 | :param views: Number of modalties 351 | :param classifier_dims: Dimension of the classifier 352 | :param annealing_epoch: KL divergence annealing epoch during training 353 | """ 354 | super(Multi_ResNet, self).__init__() 355 | self.modalties = modalties 356 | self.classes = classes 357 | self.lambda_epochs = lambda_epochs 358 | # ---- 2D Res2Net Backbone ---- 359 | self.res2net_2DNet = Medical_base_2DNet(num_classes=self.classes) 360 | 361 | # ---- 3D ResNet Backbone ---- 362 | classifier_OCT_dims = classifiers_dims[0] 363 | self.resnet_3DNet = Medical_base_3DNet(classifier_OCT_dims,num_classes=self.classes) 364 | self.sp = nn.Softplus() 365 | if classifier_OCT_dims[0][0] == 128: 366 | self.fc = nn.Linear(2048 + 8192, classes) # MMOCTF 367 | else: 368 | self.fc = nn.Linear(2048 + 3072, classes) # OLIVES 369 | 370 | # self.fc = nn.Linear(2048 + 8192, classes) # MMOCTF 371 | # self.fc = nn.Linear(2048 + 3072, classes) #OLIVES 372 | 373 | self.ce_loss = nn.CrossEntropyLoss() 374 | 375 | 376 | def forward(self, X, y): 377 | backboneout_1 = self.res2net_2DNet(X[0]) 378 | backboneout_2 = self.resnet_3DNet(X[1]) 379 | combine_features = torch.cat([backboneout_1,backboneout_2],1) 380 | pred = self.fc(combine_features) 381 | loss = self.ce_loss(pred, y) 382 | 383 | loss = torch.mean(loss) 384 | return pred, loss 385 | 386 | class Multi_EF_ResNet(nn.Module): 387 | 388 | def __init__(self, classes, modalties, classifiers_dims, lambda_epochs=1): 389 | """ 390 | :param classes: Number of classification categories 391 | :param views: Number of modalties 392 | :param classifier_dims: Dimension of the classifier 393 | :param annealing_epoch: KL divergence annealing epoch during training 394 | """ 395 | super(Multi_EF_ResNet, self).__init__() 396 | self.modalties = modalties 397 | self.classes = classes 398 | self.lambda_epochs = lambda_epochs 399 | # ---- 2D Res2Net Backbone ---- 400 | self.res2net_2DNet = Medical_base_2DNet(num_classes=self.classes) 401 | 402 | # ---- 3D ResNet Backbone ---- 403 | classifier_OCT_dims = classifiers_dims[0] 404 | 405 | # --- 2D early fusion conv 406 | if classifier_OCT_dims[0][-1] == 248: 407 | self.ef_conv = nn.Sequential( 408 | nn.AvgPool2d(kernel_size=1, stride=[2,2], 409 | ceil_mode=True, count_include_pad=False), 410 | nn.Conv2d(3, 3, 1, 1)) 411 | self.fc = nn.Linear(3584, classes) 412 | 413 | else: 414 | self.ef_conv = nn.Sequential( 415 | nn.AvgPool2d(kernel_size=1, stride=[4,2], 416 | ceil_mode=True, count_include_pad=False), 417 | nn.Conv2d(3, 3, 1, 1)) 418 | self.fc = nn.Linear(8704, classes) 419 | 420 | self.resnet_3DNet = Medical_base_3DNet(classifier_OCT_dims,num_classes=self.classes) 421 | self.sp = nn.Softplus() 422 | self.ce_loss = nn.CrossEntropyLoss() 423 | 424 | 425 | def forward(self, X, y): 426 | X0_features = self.ef_conv(X[0]) 427 | if self.classes == 2: 428 | if X[1].shape[-1] == 248: 429 | X[1].resize_(X[1].shape[0],X[1].shape[1],X[1].shape[2],X0_features.shape[-2],X0_features.shape[-1]) 430 | combine_features = torch.cat([X0_features.unsqueeze(1),X[1].permute(0,1,2,4,3)],2) 431 | 432 | else: 433 | combine_features = torch.cat([X0_features.unsqueeze(1),X[1].permute(0,1,2,4,3)],2) 434 | else: 435 | combine_features = torch.cat([X0_features.unsqueeze(1),X[1]],2) 436 | 437 | # backboneout_1 = self.res2net_2DNet(X[0]) 438 | backboneout_2 = self.resnet_3DNet(combine_features) 439 | pred = self.fc(backboneout_2) 440 | loss = self.ce_loss(pred, y) 441 | 442 | loss = torch.mean(loss) 443 | 
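# Early-fusion note (editor's hedged reading of the forward pass above): the fundus image is
# downsampled by ef_conv and concatenated with the OCT volume along dim=2, i.e. treated as extra
# B-scan slices, so only the 3D ResNet backbone produces the prediction; the 2D branch created in
# __init__ is not used in this baseline's forward pass.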
return pred, loss 444 | 445 | 446 | class CBAM2D(nn.Module): 447 | def __init__(self, channel, reduction=16, spatial_kernel=7): 448 | super(CBAM2D, self).__init__() 449 | 450 | # channel attention: squeeze H and W to 1 451 | self.max_pool = nn.AdaptiveMaxPool2d(1) 452 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 453 | 454 | # shared MLP 455 | self.mlp = nn.Sequential( 456 | # a 1x1 Conv2d is easier to work with here than nn.Linear 457 | # nn.Linear(channel, channel // reduction, bias=False) 458 | nn.Conv2d(channel, channel // reduction, 1, bias=False), 459 | # inplace=True overwrites the input tensor to save memory 460 | nn.ReLU(inplace=True), 461 | # nn.Linear(channel // reduction, channel,bias=False) 462 | nn.Conv2d(channel // reduction, channel, 1, bias=False) 463 | ) 464 | 465 | # spatial attention 466 | self.conv = nn.Conv2d(2, 1, kernel_size=spatial_kernel, 467 | padding=spatial_kernel // 2, bias=False) 468 | self.sigmoid = nn.Sigmoid() 469 | 470 | def forward(self, x): 471 | max_out = self.mlp(self.max_pool(x)) 472 | avg_out = self.mlp(self.avg_pool(x)) 473 | channel_out = self.sigmoid(max_out + avg_out) 474 | x = channel_out * x 475 | 476 | max_out, _ = torch.max(x, dim=1, keepdim=True) 477 | avg_out = torch.mean(x, dim=1, keepdim=True) 478 | spatial_out = self.sigmoid(self.conv(torch.cat([max_out, avg_out], dim=1))) 479 | x = spatial_out * x 480 | return x 481 | 482 | class CBAM3D(nn.Module): 483 | def __init__(self, channel, reduction=16, spatial_kernel=7): 484 | super(CBAM3D, self).__init__() 485 | 486 | # channel attention: squeeze D, H and W to 1 487 | self.max_pool = nn.AdaptiveMaxPool3d(1) 488 | self.avg_pool = nn.AdaptiveAvgPool3d(1) 489 | 490 | # shared MLP 491 | self.mlp = nn.Sequential( 492 | # a 1x1 Conv3d is easier to work with here than nn.Linear 493 | # nn.Linear(channel, channel // reduction, bias=False) 494 | nn.Conv3d(channel, channel // reduction, 1, bias=False), 495 | # inplace=True overwrites the input tensor to save memory 496 | nn.ReLU(inplace=True), 497 | # nn.Linear(channel // reduction, channel,bias=False) 498 | nn.Conv3d(channel // reduction, channel, 1, bias=False) 499 | ) 500 | 501 | # spatial attention 502 | self.conv = nn.Conv3d(2, 1, kernel_size=spatial_kernel, 503 | padding=spatial_kernel // 2, bias=False) 504 | self.sigmoid = nn.Sigmoid() 505 | 506 | def forward(self, x): 507 | max_out = self.mlp(self.max_pool(x)) 508 | avg_out = self.mlp(self.avg_pool(x)) 509 | channel_out = self.sigmoid(max_out + avg_out) 510 | x = channel_out * x 511 | 512 | max_out, _ = torch.max(x, dim=1, keepdim=True) 513 | avg_out = torch.mean(x, dim=1, keepdim=True) 514 | spatial_out = self.sigmoid(self.conv(torch.cat([max_out, avg_out], dim=1))) 515 | x = spatial_out * x 516 | return x 517 | 518 | class Multi_CBAM_ResNet(nn.Module): 519 | 520 | def __init__(self, classes, modalties, classifiers_dims, lambda_epochs=1): 521 | """ 522 | :param classes: Number of classification categories 523 | :param views: Number of modalties 524 | :param classifier_dims: Dimension of the classifier 525 | :param annealing_epoch: KL divergence annealing epoch during training 526 | """ 527 | super(Multi_CBAM_ResNet, self).__init__() 528 | self.modalties = modalties 529 | self.classes = classes 530 | self.lambda_epochs = lambda_epochs 531 | # ---- 2D Res2Net Backbone ---- 532 | self.res2net_2DNet = Medical_feature_2DNet(num_classes=self.classes) 533 | 534 | # ---- 3D ResNet Backbone ---- 535 | classifier_OCT_dims = classifiers_dims[0] 536 | 537 | # ---- CBAM Layer---- 538 | self.CBAM2D_layer = CBAM2D(2048) 539 | self.CBAM3D_layer = CBAM3D(512) 540 | # GAP 541 | self.avgpool = nn.AdaptiveAvgPool2d(1) 542 | 543 | self.resnet_3DNet = 
Medical_feature_3DNet(classifier_OCT_dims,num_classes=self.classes) 544 | self.sp = nn.Softplus() 545 | if classifier_OCT_dims[0][0] == 128: 546 | self.fc = nn.Linear(2048 + 8192, classes) # MMOCTF 547 | else: 548 | self.fc = nn.Linear(2048 + 3072, classes) # OLIVES 549 | 550 | self.ce_loss = nn.CrossEntropyLoss() 551 | 552 | def forward(self, X, y): 553 | backboneout_1 = self.res2net_2DNet(X[0]) 554 | backboneout_2 = self.resnet_3DNet(X[1]) 555 | backboneout_1_CBAM = self.CBAM2D_layer(backboneout_1) 556 | backboneout_2_CBAM = self.CBAM3D_layer(backboneout_2) 557 | backboneout_1_CBAM_GAP = self.avgpool(backboneout_1_CBAM) 558 | backboneout_2_CBAM_GAP = self.avgpool(backboneout_2_CBAM) 559 | backboneout_1_CBAM_GAP = backboneout_1_CBAM_GAP.view(backboneout_1_CBAM_GAP.size(0), -1) 560 | backboneout_2_CBAM_GAP = backboneout_2_CBAM_GAP.view(backboneout_2_CBAM_GAP.size(0), -1) 561 | combine_features = torch.cat([backboneout_1_CBAM_GAP,backboneout_2_CBAM_GAP],1) 562 | pred = self.fc(combine_features) 563 | loss = self.ce_loss(pred, y) 564 | 565 | loss = torch.mean(loss) 566 | return pred, loss 567 | 568 | 569 | class Multi_ensemble_ResNet(nn.Module): 570 | 571 | def __init__(self, classes, modalties, classifiers_dims, lambda_epochs=1): 572 | """ 573 | :param classes: Number of classification categories 574 | :param views: Number of modalties 575 | :param classifier_dims: Dimension of the classifier 576 | :param annealing_epoch: KL divergence annealing epoch during training 577 | """ 578 | super(Multi_ensemble_ResNet, self).__init__() 579 | self.modalties = modalties 580 | self.classes = classes 581 | self.lambda_epochs = lambda_epochs 582 | # ---- 2D Res2Net Backbone ---- 583 | self.res2net_2DNet = Medical_base2_2DNet(num_classes=self.classes) 584 | 585 | # ---- 3D ResNet Backbone ---- 586 | classifier_OCT_dims = classifiers_dims[0] 587 | self.resnet_3DNet = Medical_base_3DNet(classifier_OCT_dims,num_classes=self.classes) 588 | self.sp = nn.Softplus() 589 | self.fc = nn.Linear(2048 + 8192, classes) 590 | self.ce_loss = nn.CrossEntropyLoss() 591 | 592 | 593 | def forward(self, X, y): 594 | backboneout_1 = self.res2net_2DNet(X[0]) 595 | backboneout_2 = self.resnet_3DNet(X[1]) 596 | combine_features = torch.cat([backboneout_1,backboneout_2],1) 597 | pred = self.fc(combine_features) 598 | loss = self.ce_loss(pred, y) 599 | 600 | loss = torch.mean(loss) 601 | return pred, loss 602 | 603 | class Multi_ensemble_3D_ResNet(nn.Module): 604 | 605 | def __init__(self, classes, modalties, classifiers_dims, lambda_epochs=1): 606 | """ 607 | :param classes: Number of classification categories 608 | :param views: Number of modalties 609 | :param classifier_dims: Dimension of the classifier 610 | :param annealing_epoch: KL divergence annealing epoch during training 611 | """ 612 | super(Multi_ensemble_3D_ResNet, self).__init__() 613 | self.modalties = modalties 614 | self.classes = classes 615 | self.lambda_epochs = lambda_epochs 616 | # ---- 2D Res2Net Backbone ---- 617 | self.res2net_2DNet = Medical_base_2DNet(num_classes=self.classes) 618 | 619 | # ---- 3D ResNet Backbone ---- 620 | classifier_OCT_dims = classifiers_dims[0] 621 | self.resnet_3DNet = Medical_base2_3DNet(classifier_OCT_dims,num_classes=self.classes) 622 | self.sp = nn.Softplus() 623 | self.fc = nn.Linear(2048 + 8192, classes) 624 | self.ce_loss = nn.CrossEntropyLoss() 625 | 626 | 627 | def forward(self, X, y): 628 | backboneout_1 = self.res2net_2DNet(X[0]) 629 | backboneout_2 = self.resnet_3DNet(X[1]) 630 | combine_features = 
torch.cat([backboneout_1,backboneout_2],1) 631 | pred = self.fc(combine_features) 632 | loss = self.ce_loss(pred, y) 633 | 634 | loss = torch.mean(loss) 635 | return pred, loss 636 | 637 | class Multi_dropout_ResNet(nn.Module): 638 | 639 | def __init__(self, classes, modalties, classifiers_dims, lambda_epochs=1): 640 | """ 641 | :param classes: Number of classification categories 642 | :param views: Number of modalties 643 | :param classifier_dims: Dimension of the classifier 644 | :param annealing_epoch: KL divergence annealing epoch during training 645 | """ 646 | super(Multi_dropout_ResNet, self).__init__() 647 | self.modalties = modalties 648 | self.classes = classes 649 | self.lambda_epochs = lambda_epochs 650 | # ---- 2D Res2Net Backbone ---- 651 | self.res2net_2DNet = Medical_base_dropout_2DNet(num_classes=self.classes) 652 | 653 | # ---- 3D ResNet Backbone ---- 654 | classifier_OCT_dims = classifiers_dims[0] 655 | self.resnet_3DNet = Medical_base_dropout_3DNet(classifier_OCT_dims,num_classes=self.classes) 656 | self.sp = nn.Softplus() 657 | 658 | if classifier_OCT_dims[0][0] == 128: 659 | self.fc = nn.Linear(2048 + 8192, classes) # MMOCTF 660 | else: 661 | self.fc = nn.Linear(2048 + 3072, classes) # OLIVES 662 | 663 | self.ce_loss = nn.CrossEntropyLoss() 664 | 665 | 666 | def forward(self, X, y): 667 | backboneout_1 = self.res2net_2DNet(X[0]) 668 | backboneout_2 = self.resnet_3DNet(X[1]) 669 | combine_features = torch.cat([backboneout_1,backboneout_2],1) 670 | pred = self.fc(combine_features) 671 | loss = self.ce_loss(pred, y) 672 | 673 | loss = torch.mean(loss) 674 | return pred, loss -------------------------------------------------------------------------------- /MedIA’24/train3_trans.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | from torch.utils.data import DataLoader 6 | from model import TMC,EyeMost_Plus_transformer,EyeMost_Plus,EyeMost,EyeMost_prior 7 | from sklearn.model_selection import KFold 8 | from data import Multi_modal_data,GAMMA_dataset,OLIVES_dataset 9 | import warnings 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | from metrics import cal_ece,cal_ece_our 13 | from metrics2 import calc_aurc_eaurc,calc_nll_brier 14 | from scipy.io import loadmat 15 | from sklearn import metrics 16 | from sklearn.metrics import f1_score 17 | from sklearn.metrics import recall_score 18 | from sklearn.metrics import cohen_kappa_score 19 | import torch.nn as nn 20 | import seaborn as sns 21 | import torch.nn.functional as F 22 | import math 23 | warnings.filterwarnings("ignore") 24 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1" 25 | import time 26 | from numpy import log, sqrt 27 | from scipy.special import psi, beta 28 | import logging 29 | def log_args(log_file): 30 | 31 | logger = logging.getLogger() 32 | logger.setLevel(logging.DEBUG) 33 | formatter = logging.Formatter( 34 | '%(asctime)s ===> %(message)s', 35 | datefmt='%Y-%m-%d %H:%M:%S') 36 | 37 | # args FileHandler to save log file 38 | fh = logging.FileHandler(log_file) 39 | fh.setLevel(logging.DEBUG) 40 | fh.setFormatter(formatter) 41 | 42 | # args StreamHandler to print log to console 43 | ch = logging.StreamHandler() 44 | ch.setLevel(logging.DEBUG) 45 | ch.setFormatter(formatter) 46 | 47 | # add the two Handler 48 | logger.addHandler(ch) 49 | logger.addHandler(fh) 50 | 51 | def Uentropy(logits,c): 52 | 53 | pc = F.softmax(logits, dim=1) 54 | logits = 
F.log_softmax(logits, dim=1) 55 | u_all = -pc * logits / math.log(c) 56 | NU = torch.sum(u_all[:,1:u_all.shape[1]], dim=1) 57 | return NU 58 | 59 | def entropy(sigma_2,predict_id): 60 | id = int(predict_id) 61 | entropy_list = 0.5 * np.log(2 * np.pi * np.exp(1.) * (sigma_2.cpu().detach().float().numpy())) 62 | # entropy_list = np.log(sigma_2.cpu().detach().float().numpy()) 63 | 64 | NU = entropy_list[0,id] 65 | return NU 66 | 67 | def tensor_id_entropy(logits,predict_id): 68 | id = int(predict_id) 69 | p = F.softmax(logits, dim=1) 70 | logp = F.log_softmax(logits, dim=1) 71 | plogp = p * logp 72 | entropy = -plogp[0][id] 73 | return entropy 74 | 75 | def loss_plot(args,loss): 76 | num = args.end_epochs 77 | x = [i for i in range(num)] 78 | plot_save_path = r'results/plot/' 79 | if not os.path.exists(plot_save_path): 80 | os.makedirs(plot_save_path) 81 | save_loss = plot_save_path+str(args.model_name)+'_'+str(args.batch_size)+'_'+str(args.dataset)+'_'+str(args.end_epochs)+'_loss.jpg' 82 | list_loss = list(loss) 83 | plt.figure() 84 | plt.plot(x,loss,label='loss') 85 | plt.legend() 86 | plt.savefig(save_loss) 87 | 88 | def metrics_plot(arg,name,*args): 89 | num = arg.end_epochs 90 | names = name.split('&') 91 | metrics_value = args 92 | i=0 93 | x = [i for i in range(num)] 94 | plot_save_path = r'results/plot/' 95 | if not os.path.exists(plot_save_path): 96 | os.makedirs(plot_save_path) 97 | save_metrics = plot_save_path + str(arg.model_name) + '_' + str(arg.batch_size) + '_' + str(arg.dataset) + '_' + str(arg.end_epochs) + '_'+name+'.jpg' 98 | plt.figure() 99 | for l in metrics_value: 100 | plt.plot(x,l,label=str(names[i])) 101 | #plt.scatter(x,l,label=str(l)) 102 | i+=1 103 | plt.legend() 104 | plt.savefig(save_metrics) 105 | 106 | class AverageMeter(object): 107 | """Computes and stores the average and current value""" 108 | 109 | def __init__(self): 110 | self.reset() 111 | 112 | def reset(self): 113 | self.val = 0 114 | self.avg = 0 115 | self.sum = 0 116 | self.count = 0 117 | 118 | def update(self, val, n=1): 119 | self.val = val 120 | self.sum += val * n 121 | self.count += n 122 | self.avg = self.sum / self.count 123 | 124 | def train(epoch,train_loader,model): 125 | model.train() 126 | loss_meter = AverageMeter() 127 | # loss_list = [] 128 | for batch_idx, (data, target) in enumerate(train_loader): 129 | for v_num in range(len(data)): 130 | data[v_num] = Variable(data[v_num].cuda()) 131 | target = Variable(target.long().cuda()) 132 | # target = Variable(np.array(target)).cuda()) 133 | 134 | # refresh the optimizer 135 | optimizer.zero_grad() 136 | evidences, evidence_a, loss, _ = model(data, target, epoch) 137 | print("total loss %f"%loss) 138 | # compute gradients and take step 139 | loss.backward() 140 | optimizer.step() 141 | loss_meter.update(loss.item()) 142 | # for i in range(0,len(loss_meter)): 143 | # loss_list = loss_list.append(loss_meter[i].avg) 144 | return loss_meter 145 | 146 | def val(current_epoch,val_loader,model,best_acc): 147 | model.eval() 148 | loss_meter = AverageMeter() 149 | correct_num, data_num = 0, 0 150 | for batch_idx, (data, target) in enumerate(val_loader): 151 | for m_num in range(len(data)): 152 | data[m_num] = Variable(data[m_num].float().cuda()) 153 | data_num += target.size(0) 154 | with torch.no_grad(): 155 | target = Variable(target.long().cuda()) 156 | evidences, evidence_a, loss,_ = model(data, target, epoch) 157 | _, predicted = torch.max(evidence_a.data, 1) 158 | correct_num += (predicted == target).sum().item() 159 | 
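# (hedged note) evidence_a is the fused multi-modality output; the running correct_num / data_num
# ratio over this validation loader is compared with best_acc below to decide when to checkpoint.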
loss_meter.update(loss.item()) 160 | aver_acc = correct_num / data_num 161 | print('====> acc: {:.4f}'.format(aver_acc)) 162 | if evidence_a.shape[1] >2: 163 | if aver_acc > best_acc: 164 | print('aver_acc:{} > best_acc:{}'.format(aver_acc, best_acc)) 165 | best_acc = aver_acc 166 | print('===========>save best model!') 167 | file_name = os.path.join(args.save_dir, 168 | args.model_name + '_' + args.dataset + '_' + args.folder + '_best_epoch.pth') 169 | torch.save({ 170 | 'epoch': current_epoch, 171 | 'state_dict': model.state_dict(), 172 | }, 173 | file_name) 174 | return loss_meter.avg, best_acc 175 | 176 | else: 177 | if (current_epoch + 1) % int(args.end_epochs - 1) == 0 \ 178 | or (current_epoch + 1) % int(args.end_epochs - 2) == 0 \ 179 | or (current_epoch + 1) % int(args.end_epochs - 3) == 0: 180 | file_name = os.path.join(args.save_dir, 181 | args.model_name + '_' + args.dataset + '_' + args.folder + '_epoch_{}.pth'.format( 182 | current_epoch)) 183 | torch.save({ 184 | 'epoch': current_epoch, 185 | 'state_dict': model.state_dict(), 186 | }, 187 | file_name) 188 | if aver_acc > best_acc: 189 | print('aver_acc:{} > best_acc:{}'.format(aver_acc, best_acc)) 190 | best_acc = aver_acc 191 | print('===========>save best model!') 192 | file_name = os.path.join(args.save_dir, 193 | args.model_name + '_' + args.dataset + '_' + args.folder + '_best_epoch.pth') 194 | torch.save({ 'epoch': current_epoch,'state_dict': model.state_dict()},file_name) 195 | # if (current_epoch + 1) % int(args.end_epochs - 1) == 0 \ 196 | # or (current_epoch + 1) % int(args.end_epochs - 2) == 0 \ 197 | # or (current_epoch + 1) % int(args.end_epochs - 3) == 0: 198 | # file_name = os.path.join(args.save_dir, 199 | # args.model_name + '_' + args.dataset + '_' + args.folder + '_epoch_{}.pth'.format( 200 | # current_epoch)) 201 | # torch.save({ 202 | # 'epoch': current_epoch, 203 | # 'state_dict': model.state_dict(), 204 | # }, 205 | # file_name) 206 | # if aver_acc > best_acc: 207 | # print('aver_acc:{} > best_acc:{}'.format(aver_acc, best_acc)) 208 | # best_acc = aver_acc 209 | # print('===========>save best model!') 210 | # file_name = os.path.join(args.save_dir, 211 | # args.model_name + '_' + args.dataset + '_' + args.folder + '_best_epoch.pth') 212 | # torch.save({ 'epoch': current_epoch,'state_dict': model.state_dict()},file_name) 213 | return loss_meter.avg, best_acc 214 | 215 | def test(args, test_loader,model,epoch): 216 | if args.num_classes == 2: 217 | if args.test_epoch > 200: 218 | load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), 219 | args.save_dir, 220 | args.model_name + '_' + args.dataset +'_'+ args.folder + '_best_epoch.pth') 221 | else: 222 | load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), 223 | args.save_dir, 224 | args.model_name + '_' + args.dataset +'_'+ args.folder + '_epoch_{}.pth'.format(args.test_epoch)) 225 | else: 226 | if args.test_epoch > 200: 227 | load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), 228 | args.save_dir, 229 | args.model_name + '_' + args.dataset +'_'+ args.folder + '_best_epoch.pth') 230 | else: 231 | load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)), 232 | args.save_dir, 233 | args.model_name + '_' + args.dataset +'_'+ args.folder + '_epoch_{}.pth'.format(args.test_epoch)) 234 | 235 | if os.path.exists(load_file): 236 | checkpoint = torch.load(load_file) 237 | model.load_state_dict(checkpoint['state_dict']) 238 | args.start_epoch = checkpoint['epoch'] 239 | print('Successfully load checkpoint 
{}'.format( 240 | os.path.join(args.save_dir + '/' + args.model_name +'_'+args.dataset+ '_epoch_' + str(args.test_epoch)))) 241 | else: 242 | print('There is no resume file to load!') 243 | model.eval() 244 | list_acc = [] 245 | 246 | Fundus_confi_list= [] 247 | OCT_confi_list= [] 248 | fusion_confi_list = [] 249 | 250 | correct_list = [] 251 | entropy_list =[] 252 | ece_list = [] 253 | nll_list = [] 254 | brier_list = [] 255 | label_list = [] 256 | prediction_list = [] 257 | probability_list = [] 258 | evidence_list = [] 259 | one_hot_label_list = [] 260 | one_hot_probability_list = [] 261 | correct_num, data_num = 0, 0 262 | 263 | for batch_idx, (data, target) in enumerate(test_loader): 264 | for v_num in range(len(data)): 265 | data[v_num] = Variable(data[v_num].float().cuda()) 266 | data_num += target.size(0) 267 | with torch.no_grad(): 268 | target = Variable(target.long().cuda()) 269 | if args.model_name =="ResNet_TMC": 270 | evidences, evidence_a, _, u_a= model(data, target, epoch) 271 | else: 272 | evidences, evidence_a, _, u_a, gamma, v, alpha, beta = model(data, target, epoch) 273 | 274 | probability_fu_m = F.softmax(evidence_a) 275 | confidence_fu_m, _ = torch.max(probability_fu_m, axis=1) 276 | 277 | # probability_list.append(b_a.cpu().detach().float().numpy()) 278 | correct_pred, predicted = torch.max(evidence_a.data, 1) 279 | correct_num += (predicted == target).sum().item() 280 | correct = (predicted == target) 281 | 282 | list_acc.append((predicted == target).sum().item()) 283 | probability = torch.softmax(evidence_a, dim=1).cpu().detach().float().numpy() 284 | one_hot_label = F.one_hot(target, num_classes=args.num_classes).squeeze(dim=0).cpu().detach().float().numpy() 285 | # NLL brier 286 | nll, brier = calc_nll_brier(probability, evidence_a, target, one_hot_label) 287 | nll_list.append(nll) 288 | brier_list.append(brier) 289 | prediction_list.append(predicted.cpu().detach().float().numpy()) 290 | label_list.append(target.cpu().detach().float().numpy()) 291 | correct_list.append(correct.cpu().detach().float().numpy()) 292 | one_hot_label_list.append(F.one_hot(target, num_classes=args.num_classes).squeeze(dim=0).cpu().detach().float().numpy()) 293 | probability_list.append(torch.softmax(evidence_a, dim=1).cpu().detach().float().numpy()[:,1]) 294 | evidence_list.append(evidence_a.cpu().detach().float().numpy()) 295 | one_hot_probability_list.append(torch.softmax(evidence_a, dim=1).data.squeeze(dim=0).cpu().detach().float().numpy()) 296 | # ece 297 | ece_list.append(cal_ece_our(torch.squeeze(evidence_a), target)) 298 | # entropy 299 | entropy_list.append(tensor_id_entropy(evidence_a,predicted.cpu().detach().float().numpy())) 300 | 301 | if args.num_classes > 2: 302 | epoch_auc = metrics.roc_auc_score(one_hot_label_list,one_hot_probability_list, multi_class='ovo') 303 | else: 304 | epoch_auc = metrics.roc_auc_score(label_list,probability_list) 305 | 306 | # correct_list = (prediction_list==label_list) 307 | aurc, eaurc = calc_aurc_eaurc(probability_list, correct_list) 308 | # np.savez(r'./results/OLIVES_Fundus_noise01.npz', Fundus_confi_list=Fundus_confi_list,OCT_confi_list=OCT_confi_list,fusion_confi_list=fusion_confi_list,np_u_list=np_u_list, np_entropy_list=np_entropy_list,np_OCT_au_list=np_OCT_au_list,np_Fundus_au_list=np_Fundus_au_list,np_OCT_eu_list=np_OCT_eu_list,np_Fundus_eu_list=np_Fundus_eu_list, np_OCT_uall=np_OCT_uall,np_Fundus_uall=np_Fundus_uall) 309 | np.savez(r'./results/' + args.dataset + "_" + "noise" + "_" + str(Condition_G_Variance[0])+".npz", 
Fundus_confi_list=Fundus_confi_list,OCT_confi_list=OCT_confi_list,fusion_confi_list=fusion_confi_list,np_u_list=np_u_list, np_entropy_list=np_entropy_list,np_OCT_au_list=np_OCT_au_list,np_Fundus_au_list=np_Fundus_au_list,np_OCT_eu_list=np_OCT_eu_list,np_Fundus_eu_list=np_Fundus_eu_list, np_OCT_uall=np_OCT_uall,np_Fundus_uall=np_Fundus_uall) 310 | 311 | avg_acc = correct_num/data_num 312 | avg_ece = sum(ece_list)/len(ece_list) 313 | avg_nll = sum(nll_list)/len(nll_list) 314 | avg_brier = sum(brier_list)/len(brier_list) 315 | 316 | avg_kappa = cohen_kappa_score(prediction_list, label_list) 317 | F1_Score = f1_score(y_true=label_list, y_pred=prediction_list, average='weighted') 318 | Recall_Score = recall_score(y_true=label_list, y_pred=prediction_list, average='weighted') 319 | 320 | if not os.path.exists(os.path.join(args.save_dir, "{}_{}_{}".format(args.model_name,args.dataset,args.folder))): 321 | os.makedirs(os.path.join(args.save_dir, "{}_{}_{}".format(args.model_name,args.dataset,args.folder))) 322 | 323 | with open(os.path.join(args.save_dir,"{}_{}_{}_Metric.txt".format(args.model_name,args.dataset,args.folder)),'w') as Txt: 324 | Txt.write("Acc: {}, AUC: {}, AURC: {}, EAURC: {}, NLL: {}, BRIER: {}, F1_Score: {}, Recall_Score: {}, Kappa_Score: {}, ECE: {}\n".format( 325 | round(avg_acc,6),round(epoch_auc,6),round(aurc,6),round(eaurc,6),round(avg_nll,6),round(avg_brier,6),round(F1_Score,6),round(Recall_Score,6),round(avg_kappa,6),round(avg_ece,6) 326 | )) 327 | # print( 328 | # "Acc: {:.4f}, AUC: {:.4f}, AURC: {:.4f}, EAURC: {:.4f}, NLL: {:.4f}, BRIER: {:.4f}, F1_Score: {:.4f}, Recall_Score: {:.4f}, kappa: {:.4f}, ECE: {:.4f}".format( 329 | # avg_acc, epoch_auc, aurc, eaurc, avg_nll, avg_brier, F1_Score, Recall_Score, avg_kappa, avg_ece)) 330 | # print('====> mean_inc_u: {:.4f},mean_c_u: {:.4f}\n'.format(mean_inc_u, mean_c_u)) 331 | return avg_acc,epoch_auc,aurc, eaurc, avg_nll,avg_brier,F1_Score,Recall_Score,avg_kappa,avg_ece 332 | 333 | if __name__ == "__main__": 334 | 335 | # kkk= math.log(math.pi) 336 | # filename = './datasets/handwritten_6views.mat' 337 | # image = loadmat(filename) 338 | import argparse 339 | 340 | parser = argparse.ArgumentParser() 341 | parser.add_argument('--batch-size', type=int, default=16, metavar='N', 342 | help='input batch size for training [default: 100]') 343 | parser.add_argument('--start_epoch', type=int, default=1, metavar='N', 344 | help='number of epochs to train [default: 500]') 345 | parser.add_argument('--end_epochs', type=int, default=100, metavar='N', 346 | help='number of epochs to train [default: 500]') 347 | parser.add_argument('--test_epoch', type=int, default=98, metavar='N', 348 | help='number of epochs to train [default: 500]') 349 | parser.add_argument('--lambda_epochs', type=int, default=50, metavar='N', 350 | help='gradually increase the value of lambda from 0 to 1') 351 | parser.add_argument('--modal_number', type=int, default=2, metavar='N', 352 | help='modalties number') 353 | parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', 354 | help='learning rate') # 0.0001 ResNet_ECNP_ST_Beta_fusion_C_DF_CML_loss ResNet_ECNP_ST_Beta_fusion_C_DF_CML_loss_evi ResNet_ECNP_ST_Beta_fusion_C_DF_CML ResNet_ECNP_ST_Beta_fusion_C_DF_CML_2D ResNet_ECNP_ST_Beta_fusion_C_DF_CML_3D 355 | parser.add_argument('--save_dir', default='./results', type=str) # ResNet_MMST /ResNet_MMST_Reg /ResNet_MMST_Reg_Ce/ResNet_MMST_Reg/ResNet_MMST_Reg_CML//ResNet_MMST_Reg_CMLU /ResNet_MMST_Reg_Ce/ ResNet_TMC / ResNet_ECNP / ResNet_ECNP_Beta_CML / 
ResNet_ECNP_Beta_Reg /ResNet_ECNP_Beta_Reg_Ce / ResNet_ECNP_Beta_Reg_CML / ResNet_ECNP_Beta_Reg_CMLU/ ResNet_MMST_Reg_Ce_CMLU/ ResNet_MMST_Reg_Ce_CML 356 | parser.add_argument("--model_name", default="EyeMost_PLUS_transformer", type=str, help="ResNet_TMC/Base_transformer/ResNet_ECNP_ST_Beta_fusion3/ResNet_ECNP_ST_Beta_fusion_C_DF/ResNet_ECNP_ST_Beta_fusion_C_DF_CML/ResNet_ECNP_ST_Beta_fusion_C_DF_CML_lamdaC05/ResNet_ECNP_ST_Beta_fusion_C_DF_CML_lamdaC15") 357 | parser.add_argument("--model_base", default="transformer", type=str, help="transformer/cnn") 358 | parser.add_argument("--dataset", default="MGamma", type=str, help="MMOCTF/Gamma/MGamma/OLIVES") 359 | # parser.add_argument('--num_classes', type=int, default=2, metavar='N', 360 | # help='class number: MMOCTF: 2 /Gamma: 3') 361 | parser.add_argument("--condition", default="normal", type=str, help="noise/normal") 362 | parser.add_argument("--condition_name", default="Gaussian", type=str, help="Gaussian/SaltPepper/All") 363 | parser.add_argument("--Condition_SP_Variance", default=0.005, type=int, help="Variance: 0.01/0.1") 364 | parser.add_argument("--Condition_G_Variance", default=0.05, type=int, help="Variance: 15/1/0.1") 365 | parser.add_argument("--folder", default="folder0", type=str, help="folder0/folder1/folder2/folder3/folder4") 366 | parser.add_argument("--mode", default="test", type=str, help="test/train&test") 367 | # -- for ECNP parameters 368 | parser.add_argument('-rps', '--representation_size', type=int, default=128, help='Representation size for context') 369 | parser.add_argument('-hs', '--hidden_size', type=int, default=128, help='Model hidden size') 370 | parser.add_argument('-ev_dec_beta_min', '--ev_dec_beta_min', type=float, default=0.2, 371 | help="EDL Decoder beta minimum value") 372 | parser.add_argument('-ev_dec_alpha_max', '--ev_dec_alpha_max', type=float, default=20.0, 373 | help="EDL output alpha maximum value") 374 | parser.add_argument('-ev_dec_v_max', '--ev_dec_v_max', type=float, default=20.0, help="EDL output v maximum value") 375 | parser.add_argument('-nig_nll_reg_coef', '--nig_nll_reg_coef', type=float, default=0.1, 376 | help="EDL nll reg balancing factor") 377 | parser.add_argument('-nig_nll_ker_reg_coef', '--nig_nll_ker_reg_coef', type=float, default=1.0, 378 | help='EDL kernel reg balancing factor') 379 | parser.add_argument('-ev_st_u_min', '--ev_st_u_min', type=float, default=0.0001, 380 | help="EDL st output sigma minnum value") 381 | parser.add_argument('-ev_st_sigma_min', '--ev_st_sigma_min', type=float, default=0.2, 382 | help="EDL st output sigma minnum value") 383 | parser.add_argument('-ev_st_v_max', '--ev_st_v_max', type=float, default=30.0, help="EDL output v maximum value") 384 | 385 | 386 | Condition_G_Variance = [0,0.01, 0.03, 0.05, 0.07, 0.1,0.3,0.5] # OCT & GAMMA 387 | 388 | seed_num = list(range(1,11)) 389 | condition_level = ['normal','noise'] 390 | 391 | args = parser.parse_args() 392 | args.seed_idx = 11 393 | 394 | if args.dataset =="MMOCTF": 395 | args.data_path = '/data/zou_ke/projects_data/Multi-OF/2000/' 396 | args.modalties_name = ["FUN", "OCT"] 397 | args.num_classes = 2 398 | args.dims = [[(128, 256, 128)], [(512, 512)]] 399 | args.modalties = len(args.dims) 400 | train_loader = torch.utils.data.DataLoader( 401 | Multi_modal_data(args.data_path, args.modal_number,args.modalties_name, 'train',args.condition,args, folder=args.folder), batch_size=args.batch_size) 402 | val_loader = torch.utils.data.DataLoader( 403 | Multi_modal_data(args.data_path, args.modal_number, 
args.modalties_name, 'val',args.condition,args, folder=args.folder), batch_size=1) 404 | test_loader = torch.utils.data.DataLoader( 405 | Multi_modal_data(args.data_path, args.modal_number, args.modalties_name, 'test',args.condition,args, folder=args.folder), batch_size=1) 406 | N_mini_batches = len(train_loader) 407 | print('The number of training images = %d' % N_mini_batches) 408 | elif args.dataset =="OLIVES": 409 | args.data_path = '/data/zou_ke/projects_data/OLIVES/OLIVES/' 410 | # args.data_path = '/data/zou_ke/projects_data/OLIVES2/OLIVES/' 411 | # args.data_path = '/data/zou_ke/projects_data/OLIVES3/OLIVES/' 412 | 413 | args.modalties_name = ["FUN", "OCT"] 414 | args.num_classes = 2 415 | args.dims = [[(48, 248, 248)], [(512, 512)]] 416 | args.modalties = len(args.dims) 417 | train_loader = torch.utils.data.DataLoader( 418 | OLIVES_dataset(args.data_path, args.modal_number,args.modalties_name, 'train',args.condition,args, folder=args.folder), batch_size=args.batch_size) 419 | val_loader = torch.utils.data.DataLoader( 420 | OLIVES_dataset(args.data_path, args.modal_number, args.modalties_name, 'val',args.condition,args, folder=args.folder), batch_size=1) 421 | test_loader = torch.utils.data.DataLoader( 422 | OLIVES_dataset(args.data_path, args.modal_number, args.modalties_name, 'test',args.condition,args, folder=args.folder), batch_size=1) 423 | N_mini_batches = len(train_loader) 424 | print('The number of training images = %d' % N_mini_batches) 425 | elif args.dataset =="MGamma": 426 | args.modalties_name = ["FUN", "OCT"] 427 | args.dims = [[(128, 256, 128)], [(512, 512)]] 428 | args.num_classes = 3 429 | args.modalties = len(args.dims) 430 | args.base_path = '/data/zou_ke/projects_data/Multi-OF/Gamma/' 431 | args.data_path = '/data/zou_ke/projects_data/Multi-OF/MGamma/' 432 | filelists = os.listdir(args.data_path) 433 | # kf = KFold(n_splits=5, shuffle=True, random_state=10) 434 | kf = KFold(n_splits=5, shuffle=True, random_state=10) 435 | 436 | y = kf.split(filelists) 437 | count = 0 438 | train_filelists = [[], [], [], [], []] 439 | val_filelists = [[], [], [], [], []] 440 | for tidx, vidx in y: 441 | train_filelists[count], val_filelists[count] = np.array(filelists)[tidx], np.array(filelists)[vidx] 442 | count = count + 1 443 | f_folder = int(args.folder[-1]) 444 | train_dataset = GAMMA_dataset(args,dataset_root = args.data_path, 445 | oct_img_size = args.dims[0], 446 | fundus_img_size = args.dims[1], 447 | mode = 'train', 448 | label_file = args.base_path + 'glaucoma_grading_training_GT.xlsx', 449 | filelists = np.array(train_filelists[f_folder])) 450 | 451 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size) 452 | val_dataset = GAMMA_dataset(args,dataset_root=args.data_path, 453 | oct_img_size = args.dims[0], 454 | fundus_img_size = args.dims[1], 455 | mode = 'val', 456 | label_file = args.base_path + 'glaucoma_grading_training_GT.xlsx', 457 | filelists = np.array(val_filelists[f_folder]),) 458 | 459 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1) 460 | test_dataset = val_dataset 461 | test_loader = val_loader 462 | else: 463 | print('There is no this dataset name') 464 | raise NameError 465 | 466 | if args.model_name =="ResNet_TMC": 467 | model = TMC(args.num_classes, args.modalties, args.dims, args.lambda_epochs) 468 | elif args.model_name == "EyeMost_prior": # fusion rule for mean Our 469 | model = EyeMost_prior(args.num_classes, args.modalties, args.dims, args, args.lambda_epochs) 470 | elif args.model_name == 
"EyeMost": # fusion rule for mean Our 471 | model = EyeMost(args.num_classes, args.modalties, args.dims, args, args.lambda_epochs) 472 | elif args.model_name == "EyeMost_Plus": # lamdaC = 10 473 | model = EyeMost_Plus(args.num_classes, args.modalties, args.dims, args, args.lambda_epochs) 474 | elif args.model_name == "EyeMost_Plus_transformer": # lamdaC = 10 475 | model = EyeMost_Plus_transformer(args.num_classes, args.modalties, args.dims, args, args.lambda_epochs) 476 | else: 477 | print('There is no this model name') 478 | raise NameError 479 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) 480 | # optimizer = optim.AdamW(model.parameters(), lr=1e-5, betas=(0.9, 0.999), 481 | # weight_decay=0.0005) 482 | model.cuda() 483 | best_acc = 0 484 | loss_list = [] 485 | acc_list = [] 486 | 487 | if args.mode =='train&test': 488 | epoch = 0 489 | for epoch in range(args.start_epoch, args.end_epochs + 1): 490 | print('===========Train begining!===========') 491 | print('Epoch {}/{}'.format(epoch, args.end_epochs - 1)) 492 | epoch_loss = train(epoch,train_loader,model) 493 | print("epoch %d avg_loss:%0.3f" % (epoch, epoch_loss.avg)) 494 | val_loss, best_acc = val(epoch,val_loader,model,best_acc) 495 | loss_list.append(epoch_loss.avg) 496 | acc_list.append(best_acc) 497 | loss_plot(args, loss_list) 498 | metrics_plot(args, 'acc', acc_list) 499 | test_acc,test_acclist = test(args,test_loader,model,epoch) 500 | 501 | elif args.mode == 'test': 502 | epoch = args.test_epoch 503 | for i in range(len(Condition_G_Variance)): 504 | args.Condition_G_Variance = Condition_G_Variance[i] 505 | print("Gaussian noise: %f" % args.Condition_G_Variance) 506 | acc_list,auc_list,aurc_list,eaurc_list,nll_list, brier_list,\ 507 | F1_list,Rec_list,kap_list,ECE_list = [],[],[],[],[],[],[],[],[],[] 508 | 509 | for j in range(len(seed_num)): 510 | # for j in range(1): 511 | 512 | args.seed_idx = seed_num[j] 513 | # print("seed_idx: %d" % args.seed_idx) 514 | 515 | if args.dataset == "MMOCTF": 516 | args.data_path = '/data/zou_ke/projects_data/Multi-OF/2000/' 517 | args.modalties_name = ["FUN", "OCT"] 518 | args.num_classes = 2 519 | args.dims = [[(128, 256, 128)], [(512, 512)]] 520 | args.modalties = len(args.dims) 521 | 522 | test_loader = torch.utils.data.DataLoader( 523 | Multi_modal_data(args.data_path, args.modal_number, args.modalties_name, 'test', args.condition, 524 | args, folder=args.folder), batch_size=1) 525 | N_mini_batches = len(test_loader) 526 | # print('The number of testing images = %d' % N_mini_batches) 527 | elif args.dataset == "OLIVES": 528 | args.data_path = '/data/zou_ke/projects_data/OLIVES/OLIVES/' 529 | # args.data_path = '/data/zou_ke/projects_data/OLIVES2/OLIVES/' 530 | # args.data_path = '/data/zou_ke/projects_data/OLIVES3/OLIVES/' 531 | 532 | args.modalties_name = ["FUN", "OCT"] 533 | args.num_classes = 2 534 | args.dims = [[(48, 248, 248)], [(512, 512)]] 535 | args.modalties = len(args.dims) 536 | 537 | test_loader = torch.utils.data.DataLoader( 538 | OLIVES_dataset(args.data_path, args.modal_number, args.modalties_name, 'test', args.condition, 539 | args, folder=args.folder), batch_size=1) 540 | N_mini_batches = len(test_loader) 541 | # print('The number of testing images = %d' % N_mini_batches) 542 | elif args.dataset == "MGamma": 543 | args.modalties_name = ["FUN", "OCT"] 544 | args.dims = [[(128, 256, 128)], [(512, 512)]] 545 | args.num_classes = 3 546 | args.modalties = len(args.dims) 547 | args.base_path = '/data/zou_ke/projects_data/Multi-OF/Gamma/' 548 | 
args.data_path = '/data/zou_ke/projects_data/Multi-OF/MGamma/' 549 | filelists = os.listdir(args.data_path) 550 | kf = KFold(n_splits=5, shuffle=True, random_state=10) 551 | y = kf.split(filelists) 552 | count = 0 553 | train_filelists = [[], [], [], [], []] 554 | val_filelists = [[], [], [], [], []] 555 | for tidx, vidx in y: 556 | train_filelists[count], val_filelists[count] = np.array(filelists)[tidx], np.array(filelists)[ 557 | vidx] 558 | count = count + 1 559 | f_folder = int(args.folder[-1]) 560 | val_dataset = GAMMA_dataset(args, dataset_root=args.data_path, 561 | oct_img_size=args.dims[0], 562 | fundus_img_size=args.dims[1], 563 | mode='val', 564 | label_file=args.base_path + 'glaucoma_grading_training_GT.xlsx', 565 | filelists=np.array(val_filelists[f_folder]), ) 566 | 567 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1) 568 | test_dataset = val_dataset 569 | test_loader = val_loader 570 | N_mini_batches = len(test_loader) 571 | # print('The number of testing images = %d' % N_mini_batches) 572 | else: 573 | print('There is no this dataset name') 574 | raise NameError 575 | # start_time = time.time() 576 | test_acc, test_auc, test_aurc, test_eaurc, test_nll, test_brier,test_F1, test_Rec, test_kappa, test_ece\ 577 | = test(args, test_loader, model, epoch) 578 | # print('time cost:' + str(int(time.time()-start_time)) + 's') 579 | acc_list.append(test_acc) 580 | auc_list.append(test_auc) 581 | aurc_list.append(test_aurc) 582 | eaurc_list.append(test_eaurc) 583 | nll_list.append(test_nll) 584 | brier_list.append(test_brier) 585 | F1_list.append(test_F1) 586 | Rec_list.append(test_Rec) 587 | kap_list.append(test_kappa) 588 | ECE_list.append(test_ece) 589 | 590 | acc_list_mean,acc_list_std = np.mean(acc_list),np.std(acc_list) 591 | auc_list_mean,auc_list_std = np.mean(auc_list),np.std(auc_list) 592 | aurc_list_mean,aurc_list_std = np.mean(aurc_list),np.std(aurc_list) 593 | eaurc_list_mean,eaurc_list_std = np.mean(eaurc_list),np.std(eaurc_list) 594 | nll_list_mean,nll_list_std = np.mean(nll_list),np.std(nll_list) 595 | brier_list_mean,brier_list_std = np.mean(brier_list),np.std(brier_list) 596 | F1_list_mean,F1_list_std = np.mean(F1_list),np.std(F1_list) 597 | Rec_list_mean,Rec_list_std = np.mean(Rec_list),np.std(Rec_list) 598 | kap_list_mean,kap_list_std = np.mean(kap_list),np.std(kap_list) 599 | ECE_list_mean,ECE_list_std = np.mean(ECE_list),np.std(ECE_list) 600 | print( 601 | "Mean_Std_Acc: {:.4f} +- {:.4f}, Mean_Std_AUC: {:.4f} +- {:.4f},Mean_Std_AURC: {:.4f} +- {:.4f}, " 602 | "Mean_Std_EAURC: {:.4f} +- {:.4f}, Mean_Std_nll: {:.4f} +- {:.4f}, Mean_Std_brier: {:.4f} +- {:.4f}, Mean_Std_F1_Score: {:.4f} +- " 603 | "{:.4f}, Mean_Std_Recall_Score: {:.4f} +- {:.6f}, Mean_Std_kappa: {:.4f} +- {:.4f}, Mean_Std_ECE: {:.4f} +- {:.4f}".format( 604 | acc_list_mean, acc_list_std, auc_list_mean,auc_list_std, aurc_list_mean, aurc_list_std, eaurc_list_mean, eaurc_list_std, 605 | nll_list_mean, nll_list_std,brier_list_mean,brier_list_std,F1_list_mean, F1_list_std,Rec_list_mean,Rec_list_std,kap_list_mean,kap_list_std,ECE_list_mean,ECE_list_std)) 606 | logging.info( 607 | "Mean_Std_Acc: {:.4f} +- {:.4f}, Mean_Std_AUC: {:.4f} +- {:.4f},Mean_Std_AURC: {:.4f} +- {:.4f}, " 608 | "Mean_Std_EAURC: {:.4f} +- {:.4f}, Mean_Std_nll: {:.4f} +- {:.4f}, Mean_Std_brier: {:.4f} +- {:.4f}, Mean_Std_F1_Score: {:.4f} +- " 609 | "{:.4f}, Mean_Std_Recall_Score: {:.4f} +- {:.4f}, Mean_Std_kappa: {:.4f} +- {:.4f}, Mean_Std_ECE: {:.4f} +- {:.4f}".format( 610 | acc_list_mean, acc_list_std, 
auc_list_mean,auc_list_std, aurc_list_mean, aurc_list_std, eaurc_list_mean, eaurc_list_std, 611 | nll_list_mean, nll_list_std,brier_list_mean,brier_list_std,F1_list_mean, F1_list_std,Rec_list_mean,Rec_list_std,kap_list_mean,kap_list_std,ECE_list_mean,ECE_list_std)) 612 | -------------------------------------------------------------------------------- /MICCAI23/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io as sio 3 | from torch.utils.data import Dataset 4 | # from sklearn.preprocessing import MinMaxScaler 5 | from os.path import join 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import torch.nn as nn 9 | from PIL import Image 10 | import os 11 | import sys 12 | import argparse 13 | import time 14 | import math 15 | import pandas as pd 16 | from sklearn.model_selection import KFold 17 | import cv2 18 | from torchvision import transforms 19 | from scipy import ndimage 20 | 21 | # def addsalt_pepper(img, SNR): 22 | # img_ = img.copy() 23 | # c, h, w = img_.shape 24 | # mask = np.random.choice((0, 1, 2), size=(1, h, w), p=[SNR, (1 - SNR) / 2., (1 - SNR) / 2.]) 25 | # mask = np.repeat(mask, c, axis=0) # repeat along the channel axis so the mask matches img's shape 26 | # img_[mask == 1] = 255 # salt noise 27 | # img_[mask == 2] = 0 # pepper noise 28 | # 29 | # return img_ 30 | def add_salt_peper_3D(image,amout): 31 | # ratio of salt vs. pepper noise 32 | s_vs_p = 0.5 33 | noisy_img = np.copy(image) 34 | # add salt noise 35 | num_salt = np.ceil(amout * image.size * s_vs_p) 36 | # pick random coordinates for the noisy pixels 37 | coords = [np.random.randint(0, i - 1, int(num_salt)) for i in image.shape] 38 | noisy_img[coords[0], coords[1]] = 1. 39 | # add pepper noise 40 | num_pepper = np.ceil(amout * image.size * (1. - s_vs_p)) 41 | # pick random coordinates for the noisy pixels 42 | coords = [np.random.randint(0, i - 1, int(num_pepper)) for i in image.shape] 43 | noisy_img[coords[0], coords[1]] = 0. 44 | # out_img = noisy_img 45 | return noisy_img 46 | 47 | def add_salt_peper(image,amout): 48 | # ratio of salt vs. pepper noise 49 | s_vs_p = 0.5 50 | noisy_img = np.copy(image) 51 | # add salt noise 52 | num_salt = np.ceil(amout * image.shape[0] * image.shape[1] * s_vs_p) 53 | # pick random coordinates for the noisy pixels 54 | coords = [np.random.randint(0, i - 1, int(num_salt)) for i in image.shape] 55 | noisy_img[coords[0], coords[1], :] = 1. 56 | # add pepper noise 57 | num_pepper = np.ceil(amout * image.shape[0] * image.shape[1] * (1. - s_vs_p)) 58 | # pick random coordinates for the noisy pixels 59 | coords = [np.random.randint(0, i - 1, int(num_pepper)) for i in image.shape] 60 | noisy_img[coords[0], coords[1], :] = 0. 
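# Usage sketch (editor's assumption about the expected input): for a float fundus image `img` of
# shape (H, W, C) with values in [0, 1],
#     noisy = add_salt_peper(img, amout=0.01)
# sets roughly 1% of the (row, col) positions to 1.0 (salt) or 0.0 (pepper) across all channels;
# add_salt_peper_3D is applied the same way, slice by slice, in the (commented-out) OCT noise code
# of the dataset classes below.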
61 | # out_img = noisy_img 62 | return noisy_img 63 | 64 | class GAMMA_sub1_dataset(Dataset): 65 | def __init__(self, 66 | dataset_root, 67 | oct_img_size, 68 | fundus_img_size, 69 | mode='train', 70 | label_file='', 71 | filelists=None, 72 | ): 73 | 74 | self.dataset_root = dataset_root 75 | self.input_D = oct_img_size[0][0] 76 | self.input_H = oct_img_size[0][1] 77 | self.input_W = oct_img_size[0][2] 78 | mean = (0.3163843, 0.86174834, 0.3641431) 79 | std = (0.24608557, 0.11123227, 0.26710403) 80 | normalize = transforms.Normalize(mean=mean, std=std) 81 | 82 | self.fundus_train_transforms = transforms.Compose([ 83 | transforms.ToTensor(), 84 | transforms.RandomApply([ 85 | transforms.ColorJitter(0.2, 0.2, 0.2, 0.1) 86 | ], p=0.8), 87 | transforms.RandomGrayscale(p=0.2), 88 | transforms.CenterCrop(600), 89 | transforms.Resize(fundus_img_size[0][0]), 90 | transforms.RandomHorizontalFlip(), 91 | normalize, 92 | ]) 93 | 94 | self.oct_train_transforms = transforms.Compose([ 95 | transforms.ToTensor(), 96 | transforms.RandomHorizontalFlip(), 97 | ]) 98 | 99 | self.fundus_val_transforms = transforms.Compose([ 100 | transforms.ToTensor(), 101 | transforms.Resize(fundus_img_size[0][0]) 102 | ]) 103 | 104 | self.oct_val_transforms = transforms.Compose([ 105 | transforms.ToTensor(), 106 | ]) 107 | 108 | self.mode = mode.lower() 109 | label = {row['data']: row[1:].values 110 | for _, row in pd.read_excel(label_file).iterrows()} 111 | # if train is all 112 | self.file_list = [] 113 | for f in filelists: 114 | self.file_list.append([f, label[int(f)]]) 115 | 116 | # if only for test 117 | # if self.mode == 'train': 118 | # label = {row['data']: row[1:].values 119 | # for _, row in pd.read_excel(label_file).iterrows()} 120 | # 121 | # self.file_list = [[f, label[int(f)]] for f in os.listdir(dataset_root)] 122 | # elif self.mode == "test" or self.mode == "val" : 123 | # self.file_list = [[f, None] for f in os.listdir(dataset_root)] 124 | 125 | # if filelists is not None: 126 | # self.file_list = [item for item in self.file_list if item[0] in filelists] 127 | def __getitem__(self, idx): 128 | data = dict() 129 | 130 | real_index, label = self.file_list[idx] 131 | 132 | # Fundus read 133 | fundus_img_path = os.path.join(self.dataset_root, real_index,real_index +".jpg") 134 | fundus_img = cv2.imread(fundus_img_path)[:, :, ::-1] # BGR -> RGB 135 | # OCT read 136 | # oct_series_list = sorted(os.listdir(os.path.join(self.dataset_root, real_index, real_index)), 137 | # key=lambda x: int(x.strip("_")[0])) 138 | oct_series_list = os.listdir(os.path.join(self.dataset_root, real_index, real_index)) 139 | oct_series_0 = cv2.imread(os.path.join(self.dataset_root, real_index, real_index, oct_series_list[0]), 140 | cv2.IMREAD_GRAYSCALE) 141 | oct_img = np.zeros((len(oct_series_list), oct_series_0.shape[0], oct_series_0.shape[1], 1), dtype="uint8") 142 | for k, p in enumerate(oct_series_list): 143 | oct_img[k] = cv2.imread( 144 | os.path.join(self.dataset_root, real_index, real_index, p), cv2.IMREAD_GRAYSCALE)[..., np.newaxis] 145 | 146 | # Fundus clip 147 | if fundus_img.shape[0] == 2000: 148 | fundus_img = fundus_img[1000 - 967:1000 + 967, 1496 - 978:1496 + 978, :] 149 | 150 | fundus_img = fundus_img.copy() 151 | oct_img = self.__resize_oct_data__(oct_img) 152 | fundus_img = fundus_img / 255.0 153 | oct_img = oct_img / 255.0 154 | if self.mode == "train": 155 | fundus_img = self.fundus_train_transforms(fundus_img.astype(np.float32)) 156 | oct_img = self.oct_train_transforms(oct_img.astype(np.float32)) 157 | else: 158 
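# val/test branch (descriptive comment added): only the deterministic transforms defined in
# __init__ are applied here (ToTensor + Resize for the fundus image, ToTensor for the OCT volume),
# so evaluation does not use the random augmentation of the training branch.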
| fundus_img = self.fundus_val_transforms(fundus_img) 159 | oct_img = self.oct_val_transforms(oct_img) 160 | # data[0] = fundus_img.transpose(2, 0, 1) # H, W, C -> C, H, W 161 | # data[1] = oct_img.squeeze(-1) # D, H, W, 1 -> D, H, W 162 | data[0] = fundus_img 163 | data[1] = oct_img.unsqueeze(0) 164 | 165 | label = label.argmax() 166 | 167 | return data, label 168 | 169 | def __len__(self): 170 | return len(self.file_list) 171 | 172 | def __resize_oct_data__(self, data): 173 | """ 174 | Resize the data to the input size 175 | """ 176 | data = data.squeeze() 177 | [depth, height, width] = data.shape 178 | scale = [self.input_D*1.0/depth, self.input_H *1.0/height, self.input_W*1.0/width] 179 | data = ndimage.interpolation.zoom(data, scale, order=0) 180 | # data = data.unsqueeze() 181 | return data 182 | 183 | class OLIVES_dataset(Dataset): 184 | """ 185 | load multi-view data 186 | """ 187 | 188 | def __init__(self, root, modal_number,modalties,mode,condition,args, folder='folder0'): 189 | """ 190 | :param root: data name and path 191 | :param train: load training set or test set 192 | """ 193 | super(OLIVES_dataset, self).__init__() 194 | self.root = root 195 | self.mode = mode 196 | self.data_path = self.root + folder + "/" 197 | self.modalties = modalties 198 | self.condition = condition 199 | self.condition_name = args.condition_name 200 | self.seed_idx = args.seed_idx 201 | self.Condition_SP_Variance = args.Condition_SP_Variance 202 | self.Condition_G_Variance = args.Condition_G_Variance 203 | y_files = [] 204 | 205 | self.X = dict() 206 | for m_num in range(modal_number): 207 | x_files = [] 208 | c_m = modalties[m_num] 209 | with open(join(self.data_path, self.mode +"_" + c_m + '.txt'), 210 | 'r',encoding="gb18030",errors="ignore") as fx: 211 | files = fx.readlines() 212 | for file in files: 213 | file = file.replace('\n', '') 214 | x_files.append(file) 215 | self.X[m_num] =x_files 216 | with open(join(self.data_path, self.mode + '_GT.txt'), 217 | 'r') as fy: 218 | yfiles = fy.readlines() 219 | for yfile in yfiles: 220 | yfile = yfile.replace('\n', '') 221 | y_files.append(yfile) 222 | self.y = y_files 223 | 224 | def __getitem__(self, file_num): 225 | data = dict() 226 | np.random.seed(self.seed_idx) 227 | for m_num in range(len(self.X)): 228 | if self.modalties[m_num] == "FUN": 229 | fundus_data = np.load(self.X[m_num][file_num]).astype(np.float32) 230 | # first 231 | # data_PIL = Image.fromarray(fundus_data/255.0) 232 | # data_PIL = data_PIL.convert("RGB") 233 | # np_data = np.array(data_PIL).transpose((2, 1, 0)) 234 | # data[m_num] = np_data 235 | 236 | # right 237 | data_PIL = Image.fromarray(fundus_data) 238 | data_PIL = data_PIL.convert("RGB") 239 | np_data = np.array(data_PIL).transpose((2, 1, 0)) 240 | data[m_num] = np_data/255.0 241 | # resize to 256*256 242 | # np_km = np.load(self.X[m_num][file_num]).astype(np.float32) 243 | # Image_km = Image.fromarray(np.uint8(np_km.transpose(1, 2, 0))) 244 | # resize_km = Image_km.resize((256,256)) 245 | # data[m_num] = np.array(resize_km).transpose(2, 1, 0).astype(np.float32) 246 | # plt.figure(5) 247 | # # plt.imshow(data[m_num].transpose(1,2,0).astype(np.uint8)) 248 | # plt.imshow(data[m_num].transpose(1,2,0)) 249 | # plt.axis('off') 250 | # plt.show() 251 | noise_data = data[m_num].copy() 252 | 253 | # ## Noise begin 254 | if self.condition == 'noise': 255 | if self.condition_name == "SaltPepper": 256 | # data[m_num] = addsalt_pepper(data[m_num], self.Condition_SP_Variance) # c, 257 | noise_data = 
add_salt_peper(noise_data.transpose(1, 2, 0), self.Condition_SP_Variance) # c, 258 | noise_data = noise_data.transpose(2, 0, 1) 259 | # data[m_num] = data[m_num] + noise_data.astype(np.float32) 260 | # data[m_num] = data[m_num] 261 | elif self.condition_name == "Gaussian": 262 | 263 | noise_add = np.random.normal(0, self.Condition_G_Variance, noise_data.shape) 264 | # noise_add = np.random.random(noise_data.shape) * self.Condition_G_Variance 265 | noise_data = noise_data + noise_add 266 | noise_data = np.clip(noise_data, 0.0, 1.0) 267 | 268 | else: 269 | # noise_add = np.random.random(noise_data.shape) * self.Condition_G_Variance 270 | noise_add = np.random.normal(0, self.Condition_G_Variance, noise_data.shape) 271 | noise_data = noise_data + noise_add 272 | noise_data = np.clip(noise_data, 0.0, 1.0) 273 | noise_data = add_salt_peper(noise_data, self.Condition_SP_Variance) # c, 274 | # data[m_num] = noise_data 275 | ## plt.figure(6) 276 | ## plt.imshow(noise_data.transpose(1, 2, 0)) 277 | ## plt.axis('off') 278 | ## plt.show() 279 | data[m_num] = noise_data.astype(np.float32) 280 | ## Noise end 281 | 282 | else: 283 | kk = np.load(self.X[m_num][file_num]).astype(np.float32) 284 | kk = kk / 255.0 285 | # kk = np.load(self.X[m_num][file_num]) 286 | noise_kk = kk.copy() 287 | # plt.figure(1) 288 | # plt.imshow(kk[0, :, :], cmap="gray") 289 | # plt.axis('off') 290 | # plt.show() 291 | # plt.figure(2) 292 | # plt.imshow(kk[127, :, :], cmap="gray") 293 | # plt.axis('off') 294 | # plt.show() 295 | 296 | # Noise begin 297 | # if self.condition == 'noise': 298 | # ## plt.figure(1) 299 | # ## PIL_kk = Image.fromarray(kk[60,:,:]) 300 | # ## plt.imshow(PIL_kk) 301 | # ## plt.show() 302 | # if self.condition_name == "SaltPepper": 303 | # for i in range(kk.shape[0]): 304 | # noise_kk[i,:,:] = add_salt_peper_3D(kk[i,:,:], self.Condition_SP_Variance) # c, 305 | # 306 | # elif self.condition_name == "Gaussian": 307 | # # noise_data = np.random.random(kk.shape) * self.Condition_G_Variance 308 | # noise_add = np.random.normal(0, self.Condition_G_Variance, kk.shape) 309 | # # noise_kk = kk + noise_data.astype(np.float32) 310 | # noise_kk = noise_kk + noise_add 311 | # # if noise_kk.min() < 0: 312 | # # low_clip = -1. 313 | # # else: 314 | # # low_clip = 0. 
315 |                 #         noise_kk = np.clip(noise_kk, 0.0, 1.0)
316 |                 #     else:
317 |                 #         noise_add = np.random.normal(0, self.Condition_G_Variance, kk.shape)
318 |                 #         # noise_kk = kk + noise_data.astype(np.float32)
319 |                 #         noise_kk = noise_kk + noise_add
320 |                 #         for i in range(kk.shape[0]):
321 |                 #             noise_kk[i,:,:] = add_salt_peper_3D(kk[i,:,:], self.Condition_SP_Variance)  # c,
322 |                 # Noise End
323 | 
324 |                 # plt.figure(3)
325 |                 # plt.imshow(noise_kk[0, :, :], cmap="gray")
326 |                 # plt.axis('off')
327 |                 # plt.show()
328 |                 # plt.figure(4)
329 |                 # plt.imshow(noise_kk[127, :, :], cmap="gray")
330 |                 # plt.axis('off')
331 |                 # plt.show()
332 |                 data[m_num] = np.expand_dims(noise_kk.astype(np.float32), axis=0)
333 |                 # data[m_num] = np.expand_dims(kk, axis=0)
334 |                 # data[m_num] = self.__itensity_normalize_one_volume__(data[m_num])
335 |                 # data[m_num] = data[m_num] / 255.0
336 |                 # plt.figure(2)
337 |                 # PIL_noise = Image.fromarray(data[m_num][0, 60,:,:])
338 |                 # plt.imshow(PIL_noise)
339 |                 # plt.show()
340 | 
341 |         # plt.figure(1)
342 |         # plt.imshow(data[0].transpose(1,2,0))
343 |         # plt.show()
344 |         # plt.figure(2)
345 |         # plt.imshow(data[1][60,:,:].transpose(1,2,0))
346 |         # plt.show()
347 |         target_y = int(self.y[file_num])
348 |         target_y = np.array(target_y)
349 |         target = torch.from_numpy(target_y)
350 |         return data, target
351 | 
352 |     def __itensity_normalize_one_volume__(self, volume):
353 |         """
354 |         Normalize the intensity of an ND volume based on the mean and std of the nonzero region.
355 |         inputs:
356 |             volume: the input ND volume
357 |         outputs:
358 |             out: the normalized ND volume
359 |         """
360 | 
361 |         pixels = volume[volume > 0]
362 |         mean = pixels.mean()
363 |         std = pixels.std()
364 |         out = (volume - mean) / std
365 |         out_random = np.random.normal(0, 1, size=volume.shape)
366 |         out[volume == 0] = out_random[volume == 0]
367 |         return out
368 | 
369 |     def __len__(self):
370 |         return len(self.X[0])
371 | 
372 | class Multi_modal_data(Dataset):
373 |     """
374 |     Load multi-modality data (one file list per modality).
375 |     """
376 | 
377 |     def __init__(self, root, modal_number, modalties, mode, condition, args, folder='folder0'):
378 |         """
379 |         :param root: data name and path
380 |         :param mode: which split to load (train / val / test)
381 |         """
382 |         super(Multi_modal_data, self).__init__()
383 |         self.root = root
384 |         self.mode = mode
385 |         self.data_path = self.root + folder + "/"
386 |         self.modalties = modalties
387 |         self.condition = condition
388 |         self.condition_name = args.condition_name
389 |         self.seed_idx = args.seed_idx
390 |         self.Condition_SP_Variance = args.Condition_SP_Variance
391 |         self.Condition_G_Variance = args.Condition_G_Variance
392 |         y_files = []
393 | 
394 |         self.X = dict()
395 |         for m_num in range(modal_number):
396 |             x_files = []
397 |             c_m = modalties[m_num]
398 |             with open(join(self.data_path, self.mode + "_" + c_m + '.txt'),
399 |                       'r', encoding="gb18030", errors="ignore") as fx:
400 |                 files = fx.readlines()
401 |                 for file in files:
402 |                     file = file.replace('\n', '')
403 |                     x_files.append(file)
404 |             self.X[m_num] = x_files
405 |         with open(join(self.data_path, self.mode + '_GT.txt'),
406 |                   'r') as fy:
407 |             yfiles = fy.readlines()
408 |             for yfile in yfiles:
409 |                 yfile = yfile.replace('\n', '')
410 |                 y_files.append(yfile)
411 |         self.y = y_files
412 | 
413 |     def __getitem__(self, file_num):
414 |         data = dict()
415 |         np.random.seed(self.seed_idx)
416 |         for m_num in range(len(self.X)):
417 |             if self.modalties[m_num] == "FUN":
418 |                 data[m_num] = np.load(self.X[m_num][file_num]).astype(np.float32)
419 |                 # plt.figure(4)
420 |                 # 
plt.imshow(data[m_num].transpose(1,2,0).astype(np.uint8)) 421 | # plt.axis('off') 422 | # plt.show() 423 | data[m_num] = data[m_num]/255.0 424 | # resize to 256*256 425 | # np_km = np.load(self.X[m_num][file_num]).astype(np.float32) 426 | # Image_km = Image.fromarray(np.uint8(np_km.transpose(1, 2, 0))) 427 | # resize_km = Image_km.resize((256,256)) 428 | # data[m_num] = np.array(resize_km).transpose(2, 1, 0).astype(np.float32) 429 | # plt.figure(5) 430 | # plt.imshow(data[m_num].transpose(1,2,0)) 431 | # plt.axis('off') 432 | # plt.show() 433 | noise_data = data[m_num].copy() 434 | if self.condition == 'noise': 435 | if self.condition_name == "SaltPepper": 436 | # data[m_num] = addsalt_pepper(data[m_num], self.Condition_SP_Variance) # c, 437 | noise_data = add_salt_peper(noise_data.transpose(1, 2, 0), self.Condition_SP_Variance) # c, 438 | noise_data = noise_data.transpose(2, 0, 1) 439 | # data[m_num] = data[m_num] + noise_data.astype(np.float32) 440 | # data[m_num] = data[m_num] 441 | elif self.condition_name == "Gaussian": 442 | 443 | noise_add = np.random.normal(0, self.Condition_G_Variance, noise_data.shape) 444 | # noise_add = np.random.random(noise_data.shape) * self.Condition_G_Variance 445 | noise_data = noise_data + noise_add 446 | noise_data = np.clip(noise_data, 0.0, 1.0) 447 | 448 | else: 449 | # noise_add = np.random.random(noise_data.shape) * self.Condition_G_Variance 450 | noise_add = np.random.normal(0, self.Condition_G_Variance, noise_data.shape) 451 | noise_data = noise_data + noise_add 452 | noise_data = np.clip(noise_data, 0.0, 1.0) 453 | noise_data = add_salt_peper(noise_data, self.Condition_SP_Variance) # c, 454 | # data[m_num] = noise_data 455 | # plt.figure(6) 456 | # plt.imshow(noise_data.transpose(1, 2, 0)) 457 | # plt.axis('off') 458 | # plt.show() 459 | data[m_num] = noise_data.astype(np.float32) 460 | 461 | else: 462 | kk = np.load(self.X[m_num][file_num]).astype(np.float32) 463 | # plt.figure(1) 464 | # plt.imshow(kk[0, :, :], cmap="gray") 465 | # plt.axis('off') 466 | # plt.show() 467 | # plt.figure(2) 468 | # plt.imshow(kk[127, :, :], cmap="gray") 469 | # plt.axis('off') 470 | # plt.show() 471 | kk = kk / 255.0 472 | # kk = np.load(self.X[m_num][file_num]) 473 | noise_kk = kk.copy() 474 | # plt.figure(3) 475 | # plt.imshow(kk[0, :, :], cmap="gray") 476 | # plt.axis('off') 477 | # plt.show() 478 | # plt.figure(4) 479 | # plt.imshow(kk[127, :, :], cmap="gray") 480 | # plt.axis('off') 481 | # plt.show() 482 | # if self.condition == 'noise': 483 | # ## plt.figure(1) 484 | # ## PIL_kk = Image.fromarray(kk[60,:,:]) 485 | # ## plt.imshow(PIL_kk) 486 | # ## plt.show() 487 | # if self.condition_name == "SaltPepper": 488 | # for i in range(kk.shape[0]): 489 | # noise_kk[i,:,:] = add_salt_peper_3D(kk[i,:,:], self.Condition_SP_Variance) # c, 490 | # 491 | # elif self.condition_name == "Gaussian": 492 | # # noise_data = np.random.random(kk.shape) * self.Condition_G_Variance 493 | # noise_add = np.random.normal(0, self.Condition_G_Variance, kk.shape) 494 | # # noise_kk = kk + noise_data.astype(np.float32) 495 | # noise_kk = noise_kk + noise_add 496 | # # if noise_kk.min() < 0: 497 | # # low_clip = -1. 498 | # # else: 499 | # # low_clip = 0. 
500 |                 #         noise_kk = np.clip(noise_kk, 0.0, 1.0)
501 |                 #     else:
502 |                 #         noise_add = np.random.normal(0, self.Condition_G_Variance, kk.shape)
503 |                 #         # noise_kk = kk + noise_data.astype(np.float32)
504 |                 #         noise_kk = noise_kk + noise_add
505 |                 #         for i in range(kk.shape[0]):
506 |                 #             noise_kk[i,:,:] = add_salt_peper_3D(kk[i,:,:], self.Condition_SP_Variance)  # c,
507 |                 # plt.figure(1)
508 |                 # plt.imshow(noise_kk[0, :, :], cmap="gray")
509 |                 # plt.axis('off')
510 |                 # plt.show()
511 |                 # plt.figure(2)
512 |                 # plt.imshow(noise_kk[127, :, :], cmap="gray")
513 |                 # plt.axis('off')
514 |                 # plt.show()
515 |                 data[m_num] = np.expand_dims(noise_kk.astype(np.float32), axis=0)
516 |                 # data[m_num] = np.expand_dims(kk, axis=0)
517 |                 # data[m_num] = self.__itensity_normalize_one_volume__(data[m_num])
518 |                 # data[m_num] = data[m_num] / 255.0
519 |                 # plt.figure(2)
520 |                 # PIL_noise = Image.fromarray(data[m_num][0, 60,:,:])
521 |                 # plt.imshow(PIL_noise)
522 |                 # plt.show()
523 | 
524 |         # plt.figure(1)
525 |         # plt.imshow(data[0].transpose(1,2,0))
526 |         # plt.show()
527 |         # plt.figure(2)
528 |         # plt.imshow(data[1][60,:,:].transpose(1,2,0))
529 |         # plt.show()
530 |         target_y = int(self.y[file_num])
531 |         target_y = np.array(target_y)
532 |         target = torch.from_numpy(target_y)
533 |         return data, target
534 | 
535 |     def __itensity_normalize_one_volume__(self, volume):
536 |         """
537 |         Normalize the intensity of an ND volume based on the mean and std of the nonzero region.
538 |         inputs:
539 |             volume: the input ND volume
540 |         outputs:
541 |             out: the normalized ND volume
542 |         """
543 | 
544 |         pixels = volume[volume > 0]
545 |         mean = pixels.mean()
546 |         std = pixels.std()
547 |         out = (volume - mean) / std
548 |         out_random = np.random.normal(0, 1, size=volume.shape)
549 |         out[volume == 0] = out_random[volume == 0]
550 |         return out
551 | 
552 |     def __len__(self):
553 |         return len(self.X[0])
554 | 
555 | 
556 | class GAMMA_dataset(Dataset):
557 |     def __init__(self,
558 |                  args,
559 |                  dataset_root,
560 |                  oct_img_size,
561 |                  fundus_img_size,
562 |                  mode='train',
563 |                  label_file='',
564 |                  filelists=None,
565 |                  ):
566 |         self.condition = args.condition
567 |         self.condition_name = args.condition_name
568 |         self.Condition_SP_Variance = args.Condition_SP_Variance
569 |         self.Condition_G_Variance = args.Condition_G_Variance
570 |         self.seed_idx = args.seed_idx
571 | 
572 |         self.dataset_root = dataset_root
573 |         self.input_D = oct_img_size[0][0]
574 |         self.input_H = oct_img_size[0][1]
575 |         self.input_W = oct_img_size[0][2]
576 |         # mean = (0.3163843, 0.86174834, 0.3641431)
577 |         # std = (0.24608557, 0.11123227, 0.26710403)
578 |         # normalize = transforms.Normalize(mean=mean, std=std)
579 | 
580 |         self.fundus_train_transforms = transforms.Compose([
581 |             transforms.ToTensor(),
582 |             transforms.RandomApply([
583 |                 transforms.ColorJitter(0.2, 0.2, 0.2, 0.1)
584 |             ], p=0.8),
585 |             transforms.RandomGrayscale(p=0.2),
586 |             transforms.RandomHorizontalFlip(),
587 |             # normalize,
588 |         ])
589 | 
590 |         self.oct_train_transforms = transforms.Compose([
591 |             transforms.ToTensor(),
592 |             transforms.RandomHorizontalFlip(),
593 |         ])
594 | 
595 |         self.fundus_val_transforms = transforms.Compose([
596 |             transforms.ToTensor(),
597 |         ])
598 | 
599 |         self.oct_val_transforms = transforms.Compose([
600 |             transforms.ToTensor(),
601 |         ])
602 | 
603 |         self.mode = mode.lower()
604 |         label = {row['data']: row[1:].values
605 |                  for _, row in pd.read_excel(label_file).iterrows()}
606 |         # build the file list directly from the provided filelists
607 |         self.file_list = []
608 |         for f in filelists:
609 |             self.file_list.append([f, 
label[int(f)]]) 610 | 611 | # if only for test 612 | # if self.mode == 'train': 613 | # label = {row['data']: row[1:].values 614 | # for _, row in pd.read_excel(label_file).iterrows()} 615 | # 616 | # self.file_list = [[f, label[int(f)]] for f in os.listdir(dataset_root)] 617 | # elif self.mode == "test" or self.mode == "val" : 618 | # self.file_list = [[f, None] for f in os.listdir(dataset_root)] 619 | 620 | # if filelists is not None: 621 | # self.file_list = [item for item in self.file_list if item[0] in filelists] 622 | def __getitem__(self, idx): 623 | data = dict() 624 | 625 | real_index, label = self.file_list[idx] 626 | 627 | # Fundus read 628 | fundus_img = np.load(self.dataset_root+"/"+real_index+"/"+real_index+".npy").astype(np.float32) 629 | fundus_img = fundus_img.transpose(1,2,0) / 255.0 630 | 631 | # OCT read 632 | # oct_series_list = sorted(os.listdir(os.path.join(self.dataset_root, real_index, real_index)), 633 | # key=lambda x: int(x.strip("_")[0])) 634 | oct_series_list = os.listdir(os.path.join(self.dataset_root, real_index, real_index)) 635 | oct_img = np.load(os.path.join(self.dataset_root+"/"+real_index+"/"+real_index+"/"+oct_series_list[0])) 636 | oct_img = self.__resize_oct_data__(oct_img) 637 | 638 | oct_img = oct_img / 255.0 639 | np.random.seed(self.seed_idx) 640 | 641 | # add noise on fundus & OCT 642 | if self.condition == 'noise': 643 | if self.condition_name == "SaltPepper": 644 | fundus_img = add_salt_peper(fundus_img.transpose(1, 2, 0), self.Condition_SP_Variance) # c, 645 | fundus_img = fundus_img.transpose(2, 0, 1) 646 | for i in range(oct_img.shape[0]): 647 | oct_img[i, :, :] = add_salt_peper_3D(oct_img[i, :, :], self.Condition_SP_Variance) # c, 648 | 649 | elif self.condition_name == "Gaussian": 650 | # noise_add = np.random.normal(0, self.Condition_G_Variance, fundus_img.shape) 651 | # ## noise_add = np.random.random(noise_data.shape) * self.Condition_G_Variance 652 | # fundus_img = fundus_img + noise_add 653 | # fundus_img = np.clip(fundus_img, 0.0, 1.0) 654 | 655 | noise_add = np.random.normal(0, self.Condition_G_Variance, oct_img.shape) 656 | oct_img = oct_img + noise_add 657 | oct_img = np.clip(oct_img, 0.0, 1.0) 658 | 659 | else: 660 | # noise_add = np.random.random(noise_data.shape) * self.Condition_G_Variance 661 | noise_add = np.random.normal(0, self.Condition_G_Variance, fundus_img.shape) 662 | fundus_img = fundus_img + noise_add 663 | fundus_img = np.clip(fundus_img, 0.0, 1.0) 664 | 665 | noise_add = np.random.normal(0, self.Condition_G_Variance, oct_img.shape) 666 | oct_img = oct_img + noise_add 667 | oct_img = np.clip(oct_img, 0.0, 1.0) 668 | 669 | fundus_img = add_salt_peper(fundus_img, self.Condition_SP_Variance) # c, 670 | 671 | for i in range(oct_img.shape[0]): 672 | oct_img[i, :, :] = add_salt_peper_3D(oct_img[i, :, :], self.Condition_SP_Variance) # c, 673 | 674 | if self.mode == "train": 675 | fundus_img = self.fundus_train_transforms(fundus_img.astype(np.float32)) 676 | oct_img = self.oct_train_transforms(oct_img.astype(np.float32)) 677 | else: 678 | fundus_img = self.fundus_val_transforms(fundus_img) 679 | oct_img = self.oct_val_transforms(oct_img) 680 | # data[0] = fundus_img.transpose(2, 0, 1) # H, W, C -> C, H, W 681 | # data[1] = oct_img.squeeze(-1) # D, H, W, 1 -> D, H, W 682 | data[0] = fundus_img 683 | data[1] = oct_img.unsqueeze(0) 684 | 685 | label = label.argmax() 686 | return data, label 687 | 688 | def __len__(self): 689 | return len(self.file_list) 690 | 691 | def __resize_oct_data__(self, data): 692 | """ 693 | Resize 
the data to the input size 694 | """ 695 | data = data.squeeze() 696 | [depth, height, width] = data.shape 697 | scale = [self.input_D*1.0/depth, self.input_H *1.0/height, self.input_W*1.0/width] 698 | data = ndimage.interpolation.zoom(data, scale, order=0) 699 | # data = data.unsqueeze() 700 | return data 701 | 702 | class Multi_modal_data_OOD(Dataset): 703 | """ 704 | load multi-view data 705 | """ 706 | 707 | def __init__(self, root, ood_datapath,oodfile_list,ood_dataclass,modal_number,modalties,mode,condition,args, folder='folder0'): 708 | """ 709 | :param root: data name and path 710 | :param train: load training set or test set 711 | """ 712 | super(Multi_modal_data_OOD, self).__init__() 713 | self.root = root 714 | self.mode = mode 715 | self.data_path = self.root + folder + "/" 716 | self.ood_data_path = ood_datapath 717 | self.ood_dataclass = ood_dataclass 718 | self.oodfile_list = oodfile_list 719 | self.modalties = modalties 720 | self.condition = condition 721 | self.seed_idx = args.seed_idx 722 | 723 | y_files = [] 724 | 725 | self.X = dict() 726 | for m_num in range(modal_number): 727 | x_files = [] 728 | c_m = modalties[m_num] 729 | if c_m == self.ood_dataclass: 730 | for real_index in self.oodfile_list: 731 | file = self.ood_data_path + "/" + real_index + "/" + real_index + ".npy" 732 | x_files.append(file) 733 | self.X[m_num] = x_files 734 | else: 735 | with open(join(self.data_path, self.mode + "_" + c_m + '.txt'), 736 | 'r', encoding="gb18030", errors="ignore") as fx: 737 | files = fx.readlines() 738 | for file in files: 739 | file = file.replace('\n', '') 740 | x_files.append(file) 741 | self.X[m_num] = x_files 742 | with open(join(self.data_path, self.mode + '_GT.txt'), 743 | 'r') as fy: 744 | yfiles = fy.readlines() 745 | for yfile in yfiles: 746 | yfile = yfile.replace('\n', '') 747 | y_files.append(yfile) 748 | self.y = y_files 749 | 750 | def __getitem__(self, file_num): 751 | data = dict() 752 | np.random.seed(self.seed_idx) 753 | for m_num in range(len(self.X)): 754 | if self.modalties[m_num] == "FUN": 755 | # # Fundus read 756 | # fundus_img = np.load(self.dataset_root + "/" + real_index + "/" + real_index + ".npy").astype( 757 | # np.float32) 758 | # fundus_img = fundus_img.transpose(1, 2, 0) / 255.0 759 | data[m_num] = np.load(self.X[m_num][file_num]).astype(np.float32) 760 | # plt.figure(4) 761 | # plt.imshow(data[m_num].transpose(1,2,0).astype(np.uint8)) 762 | # plt.axis('off') 763 | # plt.show() 764 | data[m_num] = data[m_num]/255.0 765 | # resize to 256*256 766 | # np_km = np.load(self.X[m_num][file_num]).astype(np.float32) 767 | # Image_km = Image.fromarray(np.uint8(np_km.transpose(1, 2, 0))) 768 | # resize_km = Image_km.resize((256,256)) 769 | # data[m_num] = np.array(resize_km).transpose(2, 1, 0).astype(np.float32) 770 | # plt.figure(5) 771 | # plt.imshow(data[m_num].transpose(1,2,0)) 772 | # plt.axis('off') 773 | # plt.show() 774 | 775 | else: 776 | kk = np.load(self.X[m_num][file_num]).astype(np.float32) 777 | # plt.figure(1) 778 | # plt.imshow(kk[0, :, :], cmap="gray") 779 | # plt.axis('off') 780 | # plt.show() 781 | # plt.figure(2) 782 | # plt.imshow(kk[127, :, :], cmap="gray") 783 | # plt.axis('off') 784 | # plt.show() 785 | kk = kk / 255.0 786 | # kk = np.load(self.X[m_num][file_num]) 787 | # data[m_num] = np.expand_dims(kk, axis=0) 788 | # data[m_num] = self.__itensity_normalize_one_volume__(data[m_num]) 789 | # data[m_num] = data[m_num] / 255.0 790 | # plt.figure(2) 791 | # PIL_noise = Image.fromarray(data[m_num][0, 60,:,:]) 792 | # 
plt.imshow(PIL_noise)
793 |                 # plt.show()
794 |                 data[m_num] = np.expand_dims(kk.astype(np.float32), axis=0)
795 | 
796 |         # plt.figure(1)
797 |         # plt.imshow(data[0].transpose(1,2,0))
798 |         # plt.show()
799 |         # plt.figure(2)
800 |         # plt.imshow(data[1][60,:,:].transpose(1,2,0))
801 |         # plt.show()
802 |         target_y = int(self.y[file_num])
803 |         target_y = np.array(target_y)
804 |         target = torch.from_numpy(target_y)
805 |         return data, target
806 | 
807 |     def __itensity_normalize_one_volume__(self, volume):
808 |         """
809 |         Normalize the intensity of an ND volume based on the mean and std of the nonzero region.
810 |         inputs:
811 |             volume: the input ND volume
812 |         outputs:
813 |             out: the normalized ND volume
814 |         """
815 | 
816 |         pixels = volume[volume > 0]
817 |         mean = pixels.mean()
818 |         std = pixels.std()
819 |         out = (volume - mean) / std
820 |         out_random = np.random.normal(0, 1, size=volume.shape)
821 |         out[volume == 0] = out_random[volume == 0]
822 |         return out
823 | 
824 |     def __len__(self):
825 |         return len(self.X[1])
826 | # def normalize(x, min=0):
827 | #     if min == 0:
828 | #         scaler = MinMaxScaler([0, 1])
829 | #     else:  # min=-1
830 | #         scaler = MinMaxScaler((-1, 1))
831 | #     norm_x = scaler.fit_transform(x)
832 | #     return norm_x
833 | 
--------------------------------------------------------------------------------
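A minimal, hypothetical usage sketch for the dataset classes above, roughly how a `train3_trans.py`-style script might consume them. The dataset root, folder name, the `"OCT"` modality tag, the noise variances, and the batch size are illustrative assumptions; only the constructor arguments and the `args` fields (`condition`, `condition_name`, `Condition_SP_Variance`, `Condition_G_Variance`, `seed_idx`) mirror what the code above actually reads.

```python
# Hypothetical usage sketch only -- paths, folder names, the "OCT" tag and the
# noise variances below are illustrative assumptions, not the paper's settings.
from argparse import Namespace

from torch.utils.data import DataLoader

from data import Multi_modal_data  # assuming this file is importable as `data`

args = Namespace(
    condition="normal",           # "noise" activates the noise-injection branches
    condition_name="Gaussian",    # "SaltPepper" or "Gaussian"
    Condition_SP_Variance=0.05,   # salt-and-pepper amount (assumed value)
    Condition_G_Variance=0.1,     # Gaussian noise std (assumed value)
    seed_idx=0,                   # fixed seed so injected noise is reproducible
)

# Multi_modal_data expects "<mode>_<modality>.txt" and "<mode>_GT.txt" file lists
# under <root>/<folder>/; "FUN" selects the 2D fundus branch, any other tag is
# loaded as a 3D volume (e.g. OCT).
train_set = Multi_modal_data(root="./data/GAMMA/", modal_number=2,
                             modalties=["FUN", "OCT"], mode="train",
                             condition=args.condition, args=args,
                             folder="folder0")
train_loader = DataLoader(train_set, batch_size=2, shuffle=True, num_workers=0)

for data_batch, target in train_loader:
    fundus, oct_volume = data_batch[0], data_batch[1]  # dict keyed by modality index
    print(fundus.shape, oct_volume.shape, target.shape)
    break
```

`GAMMA_dataset` is constructed differently: it takes the `args` object plus a per-sample folder root, an Excel label file, and a `filelists` split, as its `__init__` above shows.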