├── .gitignore
├── README.md
├── concept_dataset.py
├── dataset
│   ├── __init__.py
│   └── isic_dataset.py
├── dataset_utils
│   ├── build_training_label.py
│   └── build_training_npy.py
├── fig
│   └── framework.png
├── model
│   ├── __init__.py
│   ├── explicd.py
│   └── utils.py
├── requirements.txt
├── train_blackbox.py
├── train_explicd.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
*.pyc
*_bkp.py
flops.py
checkpoint/
log*/
exp/
initmodel/
result/
*.swp
test*

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Explicd: Explainable Language-Informed Criteria-based Diagnosis

This repository contains the official implementation of "[Aligning Human Knowledge with Visual Concepts Towards Explainable Medical Image Classification (MICCAI 2024 Early Accept)](https://arxiv.org/abs/2406.05596)"

## Abstract

Although explainability is essential in clinical diagnosis, most deep learning models still function as black boxes without elucidating their decision-making process. In this study, we investigate explainable model development that can mimic the decision-making process of human experts by fusing the domain knowledge of explicit diagnostic criteria. We introduce a simple yet effective framework, Explicd, towards Explainable language-informed criteria-based diagnosis. Explicd begins by querying domain knowledge from either large language models (LLMs) or human experts to establish diagnostic criteria across various concept axes (e.g., color, shape, texture, or specific patterns of diseases). By leveraging a pretrained vision-language model, Explicd injects these criteria into the embedding space as knowledge anchors, thereby facilitating the learning of the corresponding visual concepts within medical images. The final diagnostic outcome is determined by the similarity scores between the encoded visual concepts and the textual criteria embeddings. Through extensive evaluation on five medical image classification benchmarks, Explicd demonstrates inherent explainability while improving classification performance over traditional black-box models.
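At its core, the diagnosis step is a cosine-similarity lookup between learned visual concept tokens and frozen criteria text embeddings. A minimal PyTorch sketch of this scoring step (illustrative only; the tensor names and random values below are ours, not the repository's API):

```python
import torch
import torch.nn.functional as F

# One learned visual token per concept axis, projected into the shared CLIP space.
visual_concepts = F.normalize(torch.randn(2, 7, 512), dim=-1)   # (batch, axes, dim)
# Frozen text embeddings of the criteria for one axis, e.g. the four
# "shape" descriptions encoded by BiomedCLIP's text tower.
shape_anchors = F.normalize(torch.randn(4, 512), dim=-1)        # (num_criteria, dim)

logit_scale = 100.0  # CLIP-style temperature
# Similarity of the "shape" token (axis index 1) to each shape criterion.
shape_logits = logit_scale * visual_concepts[:, 1, :] @ shape_anchors.T  # (batch, 4)
# Concatenating such logits across all axes feeds a small linear head
# that produces the final diagnosis.
```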

![framework](./fig/framework.png)

## Key Features

- Queries domain knowledge from LLMs or human experts to establish diagnostic criteria
- Utilizes a pretrained vision-language model to inject the criteria into the embedding space as knowledge anchors
- Introduces a visual concept learning module with a criteria-anchor contrastive loss
- Provides explainable diagnoses based on the similarity scores between encoded visual concepts and textual criteria embeddings

## Usage

Python version >= 3.9

```bash
git clone https://github.com/yhygao/Explicd.git
cd Explicd
pip install -r requirements.txt
```

Download the dataset from the ISIC 2018 website and process it with the scripts under `dataset_utils`, or use our processed files from [Google Drive](https://drive.google.com/drive/folders/1vf6X44zALelFXQNCAmg0_VizT4yxRkse?usp=drive_link).

To train a black-box baseline model such as ResNet or ViT, use
```bash
python train_blackbox.py --model resnet50.a1_in1k --data-path path_to_the_dataset --gpu 0
```
To train Explicd, use
```bash
python train_explicd.py --data-path path_to_the_dataset --gpu 0
```

The diagnostic criteria and concepts queried from GPT are defined in `concept_dataset.py`.

## Citation

If you find this work useful in your research, please consider citing:

```bibtex
@article{gao2024aligning,
  title={Aligning Human Knowledge with Visual Concepts Towards Explainable Medical Image Classification},
  author={Gao, Yunhe and Gu, Difei and Zhou, Mu and Metaxas, Dimitris},
  journal={arXiv preprint arXiv:2406.05596},
  year={2024}
}
```
--------------------------------------------------------------------------------
/concept_dataset.py:
--------------------------------------------------------------------------------
explicid_isic_dict = {
    'color': ['highly variable, often with multiple colors (black, brown, red, white, blue)', 'uniformly tan, brown, or black', 'translucent, pearly white, sometimes with blue, brown, or black areas', 'red, pink, or brown, often with a scale', 'light brown to black', 'pink, brown, or red', 'red, purple, or blue'],
    'shape': ['irregular', 'round', 'round to irregular', 'variable'],
    'border': ['often blurry and irregular', 'sharp and well-defined', 'rolled edges, often indistinct'],
    'dermoscopic patterns': ['atypical pigment network, irregular streaks, blue-whitish veil, irregular', 'regular pigment network, symmetric dots and globules', 'arborizing vessels, leaf-like areas, blue-gray ovoid nests', 'strawberry pattern, glomerular vessels, scale', 'cerebriform pattern, milia-like cysts, comedo-like openings', 'central white patch, peripheral pigment network', 'depends on type (e.g., cherry angiomas have red lacunae; spider angiomas have a central red dot with radiating legs)'],
    'texture': ['a raised or ulcerated surface', 'smooth', 'smooth, possibly with telangiectasias', 'rough, scaly', 'warty or greasy surface', 'firm, may dimple when pinched'],
    'symmetry': ['asymmetrical', 'symmetrical', 'can be symmetrical or asymmetrical depending on type'],
    'elevation': ['flat to raised', 'raised with possible central ulceration', 'slightly raised', 'slightly raised maybe thick']

}

--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhygao/Explicd/5702c8bb7d70a3a7b78a5c89ab8b318e0bb53db2/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/isic_dataset.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import math
import pdb
from utils import GaussianLayer


class SkinDataset(Dataset):

    def __init__(self, dataset_dir, mode='train', transforms=None, flag=0, debug=False, config=None, return_concept_label=False):
        self.mode = mode
        if debug:
            dataset_dir = dataset_dir + 'debug'
            print(dataset_dir)
        self.dataset_dir = dataset_dir
        self.transforms = transforms
        self.flag = flag
        self.args = config
        self.return_concept_label = return_concept_label
        # Gaussian smoother used by random_deform() below to smooth the
        # displacement fields (GaussianLayer defaults to sigma=8)
        self.smooth = GaussianLayer()

        print('start loading %s dataset' % mode)

        data = np.load(dataset_dir + 'dataList.npy', mmap_mode='r', allow_pickle=False)
        label = np.load(dataset_dir + 'labelList.npy', mmap_mode='r', allow_pickle=False)

        rng = np.random.default_rng(29)
        shuffled_indices = rng.permutation(data.shape[0])

        self.dataList = data
        self.labelList = label

        shuffled_label = label[shuffled_indices]

        self.origin_size = [data.shape[1], data.shape[2]]

        # 5-fold split per class: fold `flag` is held out and divided evenly
        # into validation and test sets; the remaining samples form the
        # training set
        index_list = np.zeros((0), dtype=np.int64)
        for i in range(7):
            print('load class %d' % i)
            num = (shuffled_label == i).sum()
            num = num // 5

            index = np.where(shuffled_label == i)[0]
            test_index = index[num*flag:num*(flag+1)]
            train_index = np.array(list(set(index) - set(test_index)))

            num_val = len(test_index) // 2
            if self.mode == 'train':
                index_list = np.concatenate((index_list, shuffled_indices[train_index]), axis=0)
            elif self.mode == 'val':
                index_list = np.concatenate((index_list, shuffled_indices[test_index[:num_val]]), axis=0)
            else:
                index_list = np.concatenate((index_list, shuffled_indices[test_index[num_val:]]), axis=0)

        self.index_list = index_list

        # For each diagnosis class, the index of the matching criterion along
        # each of the 7 concept axes of explicid_isic_dict (color, shape,
        # border, dermoscopic patterns, texture, symmetry, elevation)
        self.concept_label_map = [
            [0, 0, 0, 0, 0, 0, 0],  # MEL
            [1, 1, 1, 1, 1, 1, 0],  # NV
            [2, 0, 2, 2, 2, 0, 1],  # BCC
            [3, 0, 0, 3, 3, 0, 2],  # AKIEC
            [4, 2, 1, 4, 4, 1, 3],  # BKL
            [5, 1, 1, 5, 5, 1, 0],  # DF
            [6, 3, 1, 6, 1, 2, 0],  # VASC
        ]

        print('load done')

    def __len__(self):
        return self.index_list.shape[0]

    def __getitem__(self, idx):
        real_idx = self.index_list[idx]
        img = self.dataList[real_idx]
        label = int(self.labelList[real_idx])

        if self.transforms is not None:
            img = self.transforms(img)

        if self.return_concept_label:
            return img, label, np.array(self.concept_label_map[label])
        else:
            return img, label

    def random_deform(self, img):

        img = img.unsqueeze(0)

        dx = (torch.randn_like(img[0, 0, :, :])*2-1) * self.args.deform_scale
        dy = (torch.randn_like(img[0, 0, :, :])*2-1) * self.args.deform_scale

        dx = self.smooth(dx.unsqueeze(0).unsqueeze(0))[0, 0, :, :].detach()
        dy = self.smooth(dy.unsqueeze(0).unsqueeze(0))[0, 0, :, :].detach()

        x_line = torch.linspace(-1, 1, steps=img.shape[2]).unsqueeze(0)
        x_line = x_line.repeat(img.shape[2], 1)
        y_line =
torch.linspace(-1, 1, steps=img.shape[3]).unsqueeze(1) 107 | y_line = y_line.repeat(1, img.shape[3]) 108 | 109 | x_line += dx 110 | y_line += dy 111 | 112 | grid = torch.stack((x_line, y_line), dim=2) 113 | grid = grid.unsqueeze(0) 114 | 115 | img = F.grid_sample(img, grid, mode='bilinear', padding_mode='zeros', align_corners=True) 116 | 117 | 118 | img = img.squeeze(0) 119 | 120 | return img 121 | 122 | 123 | def random_affine(self, img): 124 | 125 | img = img.unsqueeze(0) 126 | 127 | scale_x = np.random.random() * (2*self.args.scale) + 1 - self.args.scale 128 | scale_y = np.random.random() * (2*self.args.scale) + 1 - self.args.scale 129 | 130 | shear_x = np.random.random() * self.args.scale 131 | shear_y = np.random.random() * self.args.scale 132 | #shear_x = 0 133 | #shear_y = 0 134 | 135 | 136 | angle = np.random.randint(-self.args.angle, self.args.angle) 137 | angle = (angle / 180.) * math.pi 138 | 139 | theta_scale = torch.tensor([[scale_x, shear_x, 0], 140 | [shear_y, scale_y, 0], 141 | [0, 0, 1]]).float() 142 | theta_rotate = torch.tensor([[math.cos(angle), -math.sin(angle), 0], 143 | [math.sin(angle), math.cos(angle), 0], 144 | [0, 0, 1]]).float() 145 | 146 | theta = torch.mm(theta_scale, theta_rotate)[0:2, :] 147 | grid = F.affine_grid(theta.unsqueeze(0), img.size(), align_corners=True) 148 | 149 | img = F.grid_sample(img, grid, mode='bilinear', padding_mode='zeros', align_corners=True) 150 | img = img.squeeze(0) 151 | 152 | return img 153 | 154 | -------------------------------------------------------------------------------- /dataset_utils/build_training_label.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import csv 4 | 5 | import pandas as pd 6 | import numpy as np 7 | import pdb 8 | 9 | 10 | 11 | reader = pd.read_csv('./train_label.csv') 12 | cls_list = ['MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC'] 13 | 14 | cls_num = np.zeros(7) 15 | 16 | with open('training_onehot_label.csv', 'w') as csvfile: 17 | 18 | writer = csv.writer(csvfile) 19 | 20 | for i in range(len(reader)): 21 | name = reader['image'][i] 22 | 23 | oh_label = np.zeros((7)) 24 | for j in range(7): 25 | oh_label[j] = reader[cls_list[j]][i].astype(np.uint8) 26 | 27 | 28 | label = np.argmax(oh_label) 29 | 30 | img = Image.open('./train/%s.jpg'%name) 31 | npimg = np.array(img) 32 | print(name, npimg.shape, label) 33 | 34 | cls_num[label] += 1 35 | 36 | 37 | writer.writerow([name, label]) 38 | 39 | 40 | print(cls_num) 41 | -------------------------------------------------------------------------------- /dataset_utils/build_training_npy.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import csv 4 | 5 | import numpy as np 6 | import pdb 7 | 8 | 9 | 10 | 11 | with open('training_onehot_label.csv', 'r') as csv_file: 12 | 13 | 14 | csv_reader = list(csv.reader(csv_file)) 15 | length = len(csv_reader) 16 | 17 | 18 | 19 | dataList = [] 20 | labelList = [] 21 | 22 | for i in range(length): 23 | 24 | row = csv_reader[i] 25 | name = row[0] 26 | label = row[1] 27 | 28 | 29 | img = Image.open('./train/%s.jpg'%name) 30 | npimg = np.array(img) 31 | 32 | dataList.append(npimg) 33 | labelList.append(int(label)) 34 | 35 | print(name) 36 | 37 | 38 | dataList = np.array(dataList, dtype=np.uint8) 39 | labelList = np.array(labelList, dtype=np.uint8) 40 | 41 | 42 | np.save('./dataList.npy', dataList) 43 | np.save('./labelList.npy', labelList) 44 | 45 | 46 | print('done') 47 
|
--------------------------------------------------------------------------------
/fig/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhygao/Explicd/5702c8bb7d70a3a7b78a5c89ab8b318e0bb53db2/fig/framework.png
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
from .explicd import ExpLICD


--------------------------------------------------------------------------------
/model/explicd.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from open_clip import create_model_from_pretrained, get_tokenizer
from torchvision import transforms
from .utils import FFN

import pdb


class ExpLICD(nn.Module):
    def __init__(self, concept_list, model_name='biomedclip', config=None):
        super().__init__()

        self.concept_list = concept_list
        self.model_name = model_name
        self.config = config

        if self.model_name in ['biomedclip', 'openclip']:
            if self.model_name == 'biomedclip':
                self.model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
                self.tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
            elif self.model_name == 'openclip':
                self.model, preprocess = create_model_from_pretrained('hf-hub:laion/CLIP-ViT-L-14-laion2B-s32B-b82K')
                self.tokenizer = get_tokenizer('hf-hub:laion/CLIP-ViT-L-14-laion2B-s32B-b82K')

            #self.model, preprocess = create_model_from_pretrained('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
            #self.tokenizer = get_tokenizer('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')

            config.preprocess = preprocess

            self.model.cuda()

            concept_keys = list(concept_list.keys())

            # Encode every textual criterion once with the text tower; the
            # detached embeddings serve as fixed knowledge anchors
            self.concept_token_dict = {}
            for key in concept_keys:
                if config.dataset == 'isic2018':
                    prefix = f"this is a dermoscopic image, the {key} of the lesion is "
                    attr_concept_list = concept_list[key]
                    prefix_attr_concept_list = [prefix + concept for concept in attr_concept_list]
                    tmp_concept_text = self.tokenizer(prefix_attr_concept_list).cuda()
                    _, tmp_concept_feats, logit_scale = self.model(None, tmp_concept_text)
                    self.concept_token_dict[key] = tmp_concept_feats.detach()

            self.logit_scale = logit_scale.detach()

        self.visual_features = []

        self.hook_list = []
        def hook_fn(module, input, output):
            # note: the hooked output is not detached here, so gradients can
            # flow back into the visual backbone; detach it if the backbone
            # is meant to stay frozen
            self.visual_features.append(output)
        # hook the last transformer block to capture the patch-token feature map
        layers = [self.model.visual.trunk.blocks[11]]
        for layer in layers:
            self.hook_list.append(layer.register_forward_hook(hook_fn))

        # one learnable query token per concept axis (7 axes for ISIC 2018)
        self.visual_tokens = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(7, 768)))

        self.cross_attn = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)
        self.ffn = FFN(768, 768*4)
        self.norm = nn.LayerNorm(768)
        self.proj = nn.Linear(in_features=768, out_features=512, bias=False)

        # 34 = total number of criteria across the 7 concept axes (7+4+3+7+6+3+4)
        self.cls_head = nn.Linear(in_features=34, out_features=config.num_class)

        for param in self.model.text.parameters():
            param.requires_grad
= False 73 | for param in self.model.visual.trunk.parameters(): 74 | param.requires_grad = True 75 | 76 | self.visual_tokens.requires_grad = True 77 | 78 | def get_backbone_params(self): 79 | return self.model.visual.trunk.parameters() 80 | def get_bridge_params(self): 81 | param_list = [] 82 | 83 | param_list.append(self.visual_tokens) 84 | for param in self.cross_attn.parameters(): 85 | param_list.append(param) 86 | for param in self.ffn.parameters(): 87 | param_list.append(param) 88 | for param in self.norm.parameters(): 89 | param_list.append(param) 90 | for param in self.proj.parameters(): 91 | param_list.append(param) 92 | for param in self.cls_head.parameters(): 93 | param_list.append(param) 94 | 95 | 96 | #param_list.append(self.ffn.parameters()) 97 | #param_list.append(self.norm.parameters()) 98 | #param_list.append(self.proj.parameters()) 99 | #param_list.append(self.cls_head.parameters()) 100 | 101 | return param_list 102 | 103 | 104 | def forward(self, imgs): 105 | 106 | self.visual_features.clear() 107 | #with torch.no_grad(): 108 | # img_feats, _, _ = self.model(imgs, None) 109 | img_feats, _, _ = self.model(imgs, None) 110 | img_feat_map = self.visual_features[0][:, 1:, :] 111 | 112 | B, _, _ = img_feat_map.shape 113 | visual_tokens = self.visual_tokens.repeat(B, 1, 1) 114 | 115 | agg_visual_tokens, _ = self.cross_attn(visual_tokens, img_feat_map, img_feat_map) 116 | agg_visual_tokens = self.proj(self.norm(self.ffn(agg_visual_tokens))) 117 | 118 | agg_visual_tokens = F.normalize(agg_visual_tokens, dim=-1) 119 | 120 | image_logits_dict = {} 121 | idx = 0 122 | for key in self.concept_token_dict.keys(): 123 | image_logits_dict[key] = (self.logit_scale * agg_visual_tokens[:, idx:idx+1, :] @ self.concept_token_dict[key].repeat(B, 1, 1).permute(0, 2, 1)).squeeze(1) 124 | idx += 1 125 | 126 | 127 | image_logits_list = [] 128 | for key in image_logits_dict.keys(): 129 | image_logits_list.append(image_logits_dict[key]) 130 | 131 | image_logits = torch.cat(image_logits_list, dim=-1) 132 | cls_logits = self.cls_head(image_logits) 133 | 134 | return cls_logits, image_logits_dict 135 | 136 | 137 | 138 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import math 3 | import collections 4 | from functools import partial 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | 10 | 11 | class FFN(nn.Module): 12 | def __init__(self, input_dim, ff_dim): 13 | super().__init__() 14 | 15 | self.linear1 = nn.Linear(input_dim, ff_dim) 16 | self.linear2 = nn.Linear(ff_dim, input_dim) 17 | self.relu = nn.ReLU() 18 | 19 | def forward(self, x): 20 | x = self.relu(self.linear1(x)) 21 | x = self.linear2(x) 22 | 23 | return x 24 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | matplotlib 3 | open-clip-torch 4 | PyYAML 5 | scikit-image 6 | scikit-learn 7 | scipy 8 | tensorboard 9 | timm 10 | tokenizers 11 | torch 12 | torchaudio 13 | torchvision 14 | tqdm 15 | transformers -------------------------------------------------------------------------------- /train_blackbox.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import optim 5 | import numpy as np 
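# Note on preprocessing: for ISIC 2018 the transform pipelines below insert
# utils.gray_world(), a gray-world color-constancy step that rescales the red
# and blue channels to the green-channel mean before normalization; color
# constancy is a common preprocessing step for dermoscopy images.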
6 | 7 | import timm 8 | from dataset.isic_dataset import SkinDataset 9 | 10 | from torchvision import transforms, models 11 | from sklearn.metrics import balanced_accuracy_score 12 | 13 | import copy 14 | from torch.utils.data import DataLoader 15 | from optparse import OptionParser 16 | from torch.utils.tensorboard import SummaryWriter 17 | from torchvision.utils import make_grid, save_image 18 | import utils 19 | import matplotlib.pyplot as plt 20 | import os 21 | import sys 22 | import time 23 | import math 24 | import pdb 25 | 26 | DEBUG = False 27 | 28 | 29 | 30 | dataset_dict = { 31 | 'isic2018': SkinDataset, 32 | } 33 | 34 | def train_net(model, config): 35 | 36 | print(config.unique_name) 37 | 38 | data_cfg = timm.data.resolve_data_config(model.pretrained_cfg) 39 | transform = timm.data.create_transform(**data_cfg) 40 | 41 | transform_list = [transforms.ToPILImage()] 42 | transform_list.append(transforms.RandomResizedCrop(size=data_cfg['input_size'][-1], scale=(0.75, 1.0), ratio=(0.75, 1.33), interpolation=utils.get_interpolation_mode(data_cfg['interpolation']))) 43 | transform_list.append(transforms.RandomHorizontalFlip()) 44 | transform_list.append(transforms.RandomVerticalFlip()) 45 | transform_list.append(transforms.ToTensor()) 46 | 47 | if config.dataset == 'isic2018': 48 | transform_list.append(utils.gray_world()) 49 | transform_list.append(transforms.Normalize(mean=data_cfg['mean'], std=data_cfg['std'])) 50 | 51 | train_transforms = transforms.Compose(transform_list) 52 | val_transforms = transforms.Compose([ 53 | transforms.ToPILImage(), 54 | transforms.Resize(size=int(data_cfg['input_size'][-1]/data_cfg['crop_pct']), interpolation=utils.get_interpolation_mode(data_cfg['interpolation'])), 55 | transforms.CenterCrop(size=data_cfg['input_size'][-1]), 56 | transforms.ToTensor(), 57 | utils.gray_world() if config.dataset=='isic2018' else utils.identity(), 58 | transforms.Normalize(mean=data_cfg['mean'], 59 | std=data_cfg['std'])] 60 | ) 61 | 62 | 63 | trainset = dataset_dict[config.dataset](config.data_path, mode='train', transforms=train_transforms, flag=config.flag, debug=DEBUG, config=config) 64 | trainLoader = DataLoader(trainset, batch_size=config.batch_size, shuffle=True, num_workers=8, drop_last=True) 65 | 66 | valset = dataset_dict[config.dataset](config.data_path, mode='val', transforms=val_transforms, flag=config.flag, debug=DEBUG, config=config) 67 | valLoader = DataLoader(valset, batch_size=config.batch_size, shuffle=False, num_workers=2, drop_last=False) 68 | 69 | testset = dataset_dict[config.dataset](config.data_path, mode='test', transforms=val_transforms, flag=config.flag, debug=DEBUG, config=config) 70 | testLoader = DataLoader(testset, batch_size=config.batch_size, shuffle=False, num_workers=2, drop_last=False) 71 | 72 | 73 | writer = SummaryWriter(config.log_path+config.unique_name) 74 | 75 | if config.cls_weight == None: 76 | criterion = nn.CrossEntropyLoss().cuda() 77 | else: 78 | lesion_weight = torch.FloatTensor(config.cls_weight).cuda() 79 | criterion = nn.CrossEntropyLoss(weight=lesion_weight).cuda() 80 | 81 | if config.optimizer == 'sgd': 82 | optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=0.9, weight_decay=0.0005) 83 | elif config.optimizer == 'adam': 84 | optimizer = optim.Adam(model.parameters(), lr=config.lr) 85 | elif config.optimizer == 'adamw': 86 | optimizer = optim.AdamW(model.parameters(), lr=config.lr) 87 | 88 | scaler = torch.cuda.amp.GradScaler() if config.amp else None 89 | 90 | BMAC, acc, _ = validation(model, 
valLoader, criterion) 91 | print('BMAC: %.5f, Acc: %.5f'%(BMAC, acc)) 92 | 93 | best_acc = 0 94 | for epoch in range(config.epochs): 95 | print('Starting epoch {}/{}'.format(epoch+1, config.epochs)) 96 | batch_time = 0 97 | epoch_loss = 0 98 | 99 | 100 | model.train() 101 | 102 | end = time.time() 103 | 104 | exp_scheduler = utils.exp_lr_scheduler_with_warmup(optimizer, init_lr=config.lr, epoch=epoch, warmup_epoch=config.warmup_epoch, max_epoch=config.epochs) 105 | 106 | for i, (data, label) in enumerate(trainLoader, 0): 107 | x1, target1 = data.float().cuda(), label.long().cuda() 108 | 109 | optimizer.zero_grad() 110 | 111 | if config.amp: 112 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16): 113 | output = model(x1) 114 | 115 | loss = criterion(output, target1) 116 | 117 | scaler.scale(loss).backward() 118 | scaler.step(optimizer) 119 | scaler.update() 120 | 121 | else: 122 | output = model(x1) 123 | 124 | loss = criterion(output, target1) 125 | loss.backward() 126 | optimizer.step() 127 | 128 | 129 | epoch_loss += loss.item() 130 | 131 | batch_time = time.time() - end 132 | 133 | end = time.time() 134 | 135 | 136 | print(i, 'loss: %.5f, batch_time: %.5f' % (loss.item(), batch_time)) 137 | 138 | print('[epoch %d] epoch loss: %.5f' % (epoch+1, epoch_loss/(i+1) )) 139 | 140 | writer.add_scalar('Train/Loss', epoch_loss/(i+1), epoch+1) 141 | 142 | 143 | 144 | if not os.path.isdir('%s%s/'%(config.cp_path, config.unique_name)): 145 | os.makedirs('%s%s/'%(config.cp_path, config.unique_name)) 146 | 147 | if (epoch+1) % 50 == 0: 148 | torch.save(model.state_dict(), '%s%s/CP%d.pth'%(config.cp_path, config.unique_name, epoch+1)) 149 | 150 | val_BMAC, val_acc, val_loss = validation(model, valLoader, criterion) 151 | writer.add_scalar('Val/BMAC', val_BMAC, epoch+1) 152 | writer.add_scalar('Val/Acc', val_acc, epoch+1) 153 | writer.add_scalar('Val/val_loss', val_loss, epoch+1) 154 | 155 | test_BMAC, test_acc, test_loss = validation(model, testLoader, criterion) 156 | writer.add_scalar('Test/BMAC', test_BMAC, epoch+1) 157 | writer.add_scalar('Test/Acc', test_acc, epoch+1) 158 | writer.add_scalar('Test/test_loss', test_loss, epoch+1) 159 | 160 | lr = optimizer.param_groups[0]['lr'] 161 | writer.add_scalar('LR/lr', lr, epoch+1) 162 | 163 | 164 | if val_BMAC >= best_acc: 165 | best_acc = val_BMAC 166 | if not os.path.exists(config.cp_path): 167 | os.makedirs(config.cp_path) 168 | torch.save(model.state_dict(), '%s%s/best.pth'%(config.cp_path, config.unique_name)) 169 | 170 | 171 | print('save done') 172 | print('BMAC: %.5f/best BMAC: %.5f, Acc: %.5f'%(val_BMAC, best_acc, val_acc)) 173 | 174 | 175 | 176 | def validation(model, dataloader, criterion): 177 | 178 | net = model 179 | 180 | net.eval() 181 | 182 | losses = 0 183 | 184 | pred_list = np.zeros((0), dtype=np.uint8) 185 | gt_list = np.zeros((0), dtype=np.uint8) 186 | 187 | with torch.no_grad(): 188 | for i, (data, label) in enumerate(dataloader): 189 | data, label = data.float(), label.long() 190 | 191 | inputs, labels = data.cuda(), label.cuda() 192 | pred = net(inputs) 193 | 194 | loss = criterion(pred, labels) 195 | losses += loss.item() 196 | 197 | _, label_pred = torch.max(pred, dim=1) 198 | 199 | 200 | pred_list = np.concatenate((pred_list, label_pred.cpu().numpy().astype(np.uint8)), axis=0) 201 | gt_list = np.concatenate((gt_list, label.cpu().numpy().astype(np.uint8)), axis=0) 202 | 203 | BMAC = balanced_accuracy_score(gt_list, pred_list) 204 | correct = np.sum(gt_list == pred_list) 205 | acc = 100 * correct / len(pred_list) 206 
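    # Balanced accuracy (BMAC) is the macro-average of per-class recall; it is
    # the model-selection metric here because the ISIC 2018 class distribution
    # is heavily imbalanced (hence also the class weights passed to the loss).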
| 207 | return BMAC, acc, losses/(i+1) 208 | 209 | 210 | 211 | 212 | if __name__ == '__main__': 213 | parser = OptionParser() 214 | parser.add_option('-e', '--epochs', dest='epochs', default=150, type='int', 215 | help='number of epochs') 216 | parser.add_option('-b', '--batch_size', dest='batch_size', default=128, 217 | type='int', help='batch size') 218 | parser.add_option('--warmup_epoch', dest='warmup_epoch', default=5, type='int') 219 | parser.add_option('--optimizer', dest='optimizer', default='sgd', type='str') 220 | parser.add_option('-l', '--lr', dest='lr', default=0.01, 221 | type='float', help='learning rate') 222 | parser.add_option('-c', '--resume', type='str', dest='load', default=False, 223 | help='load pretrained model') 224 | parser.add_option('-p', '--checkpoint-path', type='str', dest='cp_path', 225 | default='./checkpoint/', help='checkpoint path') 226 | parser.add_option('-o', '--log-path', type='str', dest='log_path', 227 | default='./log/', help='log path') 228 | parser.add_option('-m', '--model', type='str', dest='model', 229 | default='resnet50.a1_in1k', help='use which model in [vit_base_patch16_224.orig_in21k, resnet50.a1_in1k]') # We find vit.orig_in21k is better than CLIP weights 230 | parser.add_option('--linear-probe', dest='linear_probe', action='store_true', help='if use linear probe finetuning') 231 | parser.add_option('-d', '--dataset', type='str', dest='dataset', 232 | default='isic2018', help='name of datasets') 233 | parser.add_option('--data-path', type='str', dest='data_path', 234 | default='/data/local/yg397/dataset/isic2018/', help='the path of the dataset') 235 | parser.add_option('-u', '--unique_name', type='str', dest='unique_name', 236 | default='test', help='name prefix') 237 | parser.add_option('--flag', type='int', dest='flag', default=2, help='fold for cross-validation') 238 | parser.add_option('--gpu', type='str', dest='gpu', default='0') 239 | parser.add_option('--amp', action='store_true', help='if use mixed precision training') 240 | 241 | (config, args) = parser.parse_args() 242 | 243 | os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu 244 | 245 | config.log_path = config.log_path + config.dataset + '/' 246 | config.cp_path = config.cp_path + config.dataset + '/' 247 | 248 | print('use model:', config.model) 249 | 250 | num_class_dict = { 251 | 'isic2018': 7, 252 | } 253 | cls_weight_dict = { 254 | 'isic2018': [1, 0.5, 1.2, 1.3, 1, 2, 2], 255 | } 256 | 257 | config.cls_weight = cls_weight_dict[config.dataset] 258 | config.num_class = num_class_dict[config.dataset] 259 | 260 | 261 | net = timm.create_model(config.model, pretrained=True, num_classes=config.num_class) 262 | if config.linear_probe: 263 | for name, param in net.named_parameters(): 264 | if 'fc' in name and 'resnet' in config.model: 265 | param.requires_grad = True 266 | elif 'head' in name and 'vit' in config.model: 267 | param.requires_grad = True 268 | else: 269 | param.requires_grad = False 270 | 271 | print('num of params', sum(p.numel() for p in net.parameters() if p.requires_grad)) 272 | 273 | 274 | if config.load: 275 | net.load_state_dict(torch.load(config.load)) 276 | print('Model loaded from {}'.format(config.load)) 277 | 278 | net.cuda() 279 | 280 | 281 | train_net(net, config) 282 | 283 | print('done') 284 | 285 | 286 | -------------------------------------------------------------------------------- /train_explicd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.nn.functional as F 4 | from torch import optim 5 | import numpy as np 6 | import timm 7 | from dataset.isic_dataset import SkinDataset 8 | from model import ExpLICD 9 | 10 | from torchvision import transforms, models 11 | from sklearn.metrics import balanced_accuracy_score 12 | 13 | import copy 14 | from torch.utils.data import DataLoader 15 | from optparse import OptionParser 16 | from torch.utils.tensorboard import SummaryWriter 17 | from torchvision.utils import make_grid, save_image 18 | import utils 19 | import matplotlib.pyplot as plt 20 | import os 21 | import sys 22 | import time 23 | import math 24 | import pdb 25 | DEBUG = False 26 | 27 | 28 | 29 | dataset_dict = { 30 | 'isic2018': SkinDataset, 31 | } 32 | 33 | def train_net(model, config): 34 | 35 | print(config.unique_name) 36 | 37 | train_transforms = copy.deepcopy(config.preprocess) 38 | train_transforms.transforms.pop(0) 39 | if model.model_name != 'clip': 40 | train_transforms.transforms.pop(0) 41 | train_transforms.transforms.insert(0, transforms.RandomVerticalFlip()) 42 | train_transforms.transforms.insert(0, transforms.RandomHorizontalFlip()) 43 | train_transforms.transforms.insert(0, transforms.RandomResizedCrop(size=(224,224), scale=(0.75, 1.0), ratio=(0.75, 1.33), interpolation=utils.get_interpolation_mode('bicubic'))) 44 | train_transforms.transforms.insert(0, transforms.ToPILImage()) 45 | #if config.dataset == 'isic2018': 46 | # train_transforms.transforms.insert(-1, utils.gray_world()) 47 | 48 | 49 | val_transforms = copy.deepcopy(config.preprocess) 50 | val_transforms.transforms.insert(0, transforms.ToPILImage()) 51 | #if config.dataset == 'isic2018': 52 | # val_transforms.transforms.insert(-1, utils.gray_world()) 53 | 54 | 55 | trainset = dataset_dict[config.dataset](config.data_path, mode='train', transforms=train_transforms, flag=config.flag, debug=DEBUG, config=config, return_concept_label=True) 56 | trainLoader = DataLoader(trainset, batch_size=config.batch_size, shuffle=True, num_workers=8, drop_last=True) 57 | 58 | valset = dataset_dict[config.dataset](config.data_path, mode='val', transforms=val_transforms, flag=config.flag, debug=DEBUG, config=config, return_concept_label=True) 59 | valLoader = DataLoader(valset, batch_size=config.batch_size, shuffle=False, num_workers=2, drop_last=False) 60 | 61 | testset = dataset_dict[config.dataset](config.data_path, mode='test', transforms=val_transforms, flag=config.flag, debug=DEBUG, config=config, return_concept_label=True) 62 | testLoader = DataLoader(testset, batch_size=config.batch_size, shuffle=False, num_workers=2, drop_last=False) 63 | 64 | 65 | 66 | writer = SummaryWriter(config.log_path+config.unique_name) 67 | 68 | if config.cls_weight == None: 69 | criterion = nn.CrossEntropyLoss().cuda() 70 | else: 71 | lesion_weight = torch.FloatTensor(config.cls_weight).cuda() 72 | criterion = nn.CrossEntropyLoss(weight=lesion_weight).cuda() 73 | 74 | if config.optimizer == 'sgd': 75 | optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=0.9, weight_decay=0.0005) 76 | elif config.optimizer == 'adam': 77 | optimizer = optim.Adam(model.parameters(), lr=config.lr) 78 | elif config.optimizer == 'adamw': 79 | optimizer = optim.AdamW([ 80 | {'params': model.get_backbone_params(), 'lr': config.lr * 0.1}, 81 | {'params': model.get_bridge_params(), 'lr': config.lr}, 82 | ]) 83 | 84 | scaler = torch.cuda.amp.GradScaler() if config.amp else None 85 | 86 | BMAC, acc, _, _ = validation(model, valLoader, criterion) 87 | print('BMAC: %.5f, Acc: %.5f'%(BMAC, acc)) 
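    # Training objective used in the loop below: weighted cross-entropy on the
    # diagnosis logits plus the mean of the per-axis concept losses, i.e.
    #   loss = loss_cls + (1/K) * sum_k CE(image_logits[k], concept_label[:, k])
    # with K = 7 concept axes for ISIC 2018.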

    best_acc = 0
    for epoch in range(config.epochs):
        print('Starting epoch {}/{}'.format(epoch+1, config.epochs))
        batch_time = 0
        epoch_loss_cls = 0
        epoch_loss_concept = 0

        model.train()

        end = time.time()

        exp_scheduler = utils.exp_lr_scheduler_with_warmup(optimizer, init_lr=config.lr, epoch=epoch, warmup_epoch=config.warmup_epoch, max_epoch=config.epochs)

        for i, (data, label, concept_label) in enumerate(trainLoader, 0):
            x, target = data.float().cuda(), label.long().cuda()
            concept_label = concept_label.long().cuda()

            optimizer.zero_grad()

            if config.amp:
                with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                    cls_logits, image_logits_dict = model(x)

                    loss_cls = criterion(cls_logits, target)

                    loss_concepts = 0
                    idx = 0
                    for key in model.concept_token_dict.keys():
                        image_concept_loss = F.cross_entropy(image_logits_dict[key], concept_label[:, idx])
                        loss_concepts += image_concept_loss
                        idx += 1

                    loss = loss_cls + loss_concepts / idx

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

            else:
                cls_logits, image_logits_dict = model(x)

                loss_cls = criterion(cls_logits, target)

                loss_concepts = 0
                idx = 0
                for key in model.concept_token_dict.keys():
                    image_concept_loss = F.cross_entropy(image_logits_dict[key], concept_label[:, idx])
                    loss_concepts += image_concept_loss
                    idx += 1

                loss = loss_cls + loss_concepts / idx

                loss.backward()
                optimizer.step()

            epoch_loss_cls += loss_cls.item()
            epoch_loss_concept += loss_concepts.item()

            batch_time = time.time() - end

            end = time.time()

            print(i, 'loss_cls: %.5f, loss_concept: %.5f, batch_time: %.5f' % (loss_cls.item(), loss_concepts.item(), batch_time))

        print('[epoch %d] epoch loss_cls: %.5f, epoch_loss_concept: %.5f' % (epoch+1, epoch_loss_cls/(i+1), epoch_loss_concept/(i+1)))

        writer.add_scalar('Train/Loss_cls', epoch_loss_cls/(i+1), epoch+1)
        writer.add_scalar('Train/Loss_concept', epoch_loss_concept/(i+1), epoch+1)

        if not os.path.isdir('%s%s/'%(config.cp_path, config.unique_name)):
            os.makedirs('%s%s/'%(config.cp_path, config.unique_name))

        if (epoch+1) % 50 == 0:
            torch.save(model.state_dict(), '%s%s/CP%d.pth'%(config.cp_path, config.unique_name, epoch+1))

        val_BMAC, val_acc, val_loss_cls, val_loss_concept = validation(model, valLoader, criterion)
        writer.add_scalar('Val/BMAC', val_BMAC, epoch+1)
        writer.add_scalar('Val/Acc', val_acc, epoch+1)
        writer.add_scalar('Val/val_loss_cls', val_loss_cls, epoch+1)
        writer.add_scalar('Val/val_loss_concept', val_loss_concept, epoch+1)

        test_BMAC, test_acc, test_loss_cls, test_loss_concept = validation(model, testLoader, criterion)
        writer.add_scalar('Test/BMAC', test_BMAC, epoch+1)
        writer.add_scalar('Test/Acc', test_acc, epoch+1)
        writer.add_scalar('Test/test_loss_cls', test_loss_cls, epoch+1)
        writer.add_scalar('Test/test_loss_concept', test_loss_concept, epoch+1)

        lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('LR/lr', lr, epoch+1)

        if val_BMAC >= best_acc:
            best_acc = val_BMAC
            if not os.path.exists(config.cp_path):
                os.makedirs(config.cp_path)
torch.save(model.state_dict(), '%s%s/best.pth'%(config.cp_path, config.unique_name)) 190 | 191 | 192 | print('save done') 193 | print('BMAC: %.5f/best BMAC: %.5f, Acc: %.5f'%(val_BMAC, best_acc, val_acc)) 194 | 195 | 196 | 197 | def validation(model, dataloader, criterion): 198 | 199 | net = model 200 | 201 | net.eval() 202 | 203 | losses_cls = 0 204 | losses_concepts = 0 205 | 206 | pred_list = np.zeros((0), dtype=np.uint8) 207 | gt_list = np.zeros((0), dtype=np.uint8) 208 | 209 | with torch.no_grad(): 210 | for i, (data, label, concept_label) in enumerate(dataloader): 211 | 212 | data, label = data.cuda(), label.long().cuda() 213 | concept_label = concept_label.long().cuda() 214 | cls_logits, image_logits_dict = net(data) 215 | 216 | loss_cls = criterion(cls_logits, label) 217 | losses_cls += loss_cls.item() 218 | 219 | tmp_loss_concepts = 0 220 | idx = 0 221 | for key in net.concept_token_dict.keys(): 222 | image_concept_loss = F.cross_entropy(image_logits_dict[key], concept_label[:, idx]) 223 | tmp_loss_concepts += image_concept_loss.item() 224 | idx += 1 225 | 226 | losses_concepts += tmp_loss_concepts / len(list(net.concept_token_dict.keys())) 227 | 228 | _, label_pred = torch.max(cls_logits, dim=1) 229 | 230 | 231 | pred_list = np.concatenate((pred_list, label_pred.cpu().numpy().astype(np.uint8)), axis=0) 232 | gt_list = np.concatenate((gt_list, label.cpu().numpy().astype(np.uint8)), axis=0) 233 | 234 | BMAC = balanced_accuracy_score(gt_list, pred_list) 235 | correct = np.sum(gt_list == pred_list) 236 | acc = 100 * correct / len(pred_list) 237 | 238 | return BMAC, acc, losses_cls/(i+1), losses_concepts/(i+1) 239 | 240 | 241 | 242 | 243 | if __name__ == '__main__': 244 | parser = OptionParser() 245 | parser.add_option('-e', '--epochs', dest='epochs', default=150, type='int', 246 | help='number of epochs') 247 | parser.add_option('-b', '--batch_size', dest='batch_size', default=128, 248 | type='int', help='batch size') 249 | parser.add_option('--warmup_epoch', dest='warmup_epoch', default=5, type='int') 250 | parser.add_option('--optimizer', dest='optimizer', default='adamw', type='str') 251 | parser.add_option('-l', '--lr', dest='lr', default=0.0001, 252 | type='float', help='learning rate') 253 | parser.add_option('-c', '--resume', type='str', dest='load', default=False, 254 | help='load pretrained model') 255 | parser.add_option('-p', '--checkpoint-path', type='str', dest='cp_path', 256 | #default='/data/yunhe/Liver/auto-aug/checkpoint/', help='checkpoint path') 257 | default='./checkpoint/', help='checkpoint path') 258 | parser.add_option('-o', '--log-path', type='str', dest='log_path', 259 | default='./log/', help='log path') 260 | parser.add_option('-m', '--model', type='str', dest='model', 261 | default='explicd', help='use which model') 262 | parser.add_option('--linear-probe', dest='linear_probe', action='store_true', help='if use linear probe finetuning') 263 | parser.add_option('-d', '--dataset', type='str', dest='dataset', 264 | default='isic2018', help='name of dataset') 265 | parser.add_option('--data-path', type='str', dest='data_path', 266 | default='/data/local/yg397/dataset/isic2018/', help='the path of the dataset') 267 | parser.add_option('-u', '--unique_name', type='str', dest='unique_name', 268 | default='test', help='name prefix') 269 | 270 | 271 | parser.add_option('--flag', type='int', dest='flag', default=2) 272 | 273 | parser.add_option('--gpu', type='str', dest='gpu', 274 | default='0') 275 | parser.add_option('--amp', action='store_true', help='if use 
mixed precision training') 276 | 277 | (config, args) = parser.parse_args() 278 | 279 | os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu 280 | 281 | config.log_path = config.log_path + config.dataset + '/' 282 | config.cp_path = config.cp_path + config.dataset + '/' 283 | 284 | print('use model:', config.model) 285 | 286 | num_class_dict = { 287 | 'isic2018': 7, 288 | } 289 | 290 | cls_weight_dict = { 291 | 'isic2018': [1, 0.5, 1.2, 1.3, 1, 2, 2], 292 | } 293 | 294 | config.cls_weight = cls_weight_dict[config.dataset] 295 | config.num_class = num_class_dict[config.dataset] 296 | 297 | 298 | from concept_dataset import explicid_isic_dict 299 | concept_list = explicid_isic_dict 300 | net = ExpLICD(concept_list=concept_list, model_name='biomedclip', config=config) 301 | 302 | # We find using orig_in21k vit weights works better than biomedclip vit weights 303 | # Delete the following if want to use biomedclip weights 304 | vit = timm.create_model('vit_base_patch16_224.orig_in21k', pretrained=True, num_classes=config.num_class) 305 | vit.head = nn.Identity() 306 | net.model.visual.trunk.load_state_dict(vit.state_dict()) 307 | 308 | 309 | 310 | if config.load: 311 | net.load_state_dict(torch.load(config.load)) 312 | print('Model loaded from {}'.format(config.load)) 313 | 314 | net.cuda() 315 | 316 | 317 | train_net(net, config) 318 | 319 | print('done') 320 | 321 | 322 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | from torchvision.utils import make_grid, save_image 7 | import pdb 8 | import os 9 | import scipy.ndimage 10 | import types as Types 11 | import torchvision.transforms as transforms 12 | 13 | class gray_world(object): 14 | 15 | def __call__(self, img): 16 | mu_g = img[1].mean() 17 | img[0] = img[0] * (mu_g / img[0].mean()) 18 | img[2] = img[2] * (mu_g / img[2].mean()) 19 | 20 | img = torch.clamp(img, 0, 1) 21 | 22 | return img 23 | 24 | def __repr__(self): 25 | return self.__class__.__name__ + '()' 26 | 27 | class identity(object): 28 | 29 | def __call__(self, img): 30 | 31 | return img 32 | 33 | def __repr__(self): 34 | return self.__class__.__name__ + '()' 35 | 36 | 37 | class my_resize(object): 38 | 39 | def __init__(self, size): 40 | self.size = size 41 | 42 | def __call__(self, img): 43 | 44 | img = img.unsqueeze(0) 45 | img = F.interpolate(img, size=self.size, mode='bilinear', align_corners=True) 46 | 47 | img = img.squeeze(0) 48 | 49 | 50 | return img 51 | 52 | def __repr__(self): 53 | return self.__class__.__name__ + '()' 54 | 55 | def get_interpolation_mode(interpolation_str): 56 | interpolation_mapping = { 57 | 'nearest': transforms.InterpolationMode.NEAREST, 58 | 'lanczos': transforms.InterpolationMode.LANCZOS, 59 | 'bilinear': transforms.InterpolationMode.BILINEAR, 60 | 'bicubic': transforms.InterpolationMode.BICUBIC, 61 | 'box': transforms.InterpolationMode.BOX, 62 | 'hamming': transforms.InterpolationMode.HAMMING 63 | } 64 | return interpolation_mapping.get(interpolation_str) 65 | 66 | 67 | class GaussianLayer(nn.Module): 68 | def __init__(self, sigma=8): 69 | super(GaussianLayer, self).__init__() 70 | self.seq = nn.Sequential( 71 | nn.ReflectionPad2d(10), 72 | nn.Conv2d(1, 1, 21, stride=1, padding=0, bias=None) 73 | ) 74 | self.weights_init(sigma) 75 | #self.seq = self.seq.cuda() 76 | 77 | 78 | def forward(self, x): 79 | 
return self.seq(x) 80 | 81 | def weights_init(self, sigma): 82 | n = np.zeros((21, 21)) 83 | n[10, 10] = 1 84 | k = scipy.ndimage.gaussian_filter(n, sigma=sigma) 85 | for name, f in self.named_parameters(): 86 | f.data.copy_(torch.from_numpy(k)) 87 | f.requires_grad = False 88 | 89 | 90 | 91 | def save_vis_imgs_3(model, imgs_vis, imgs_vis_label, writer, epoch, vis_dir, config): 92 | 93 | noise = torch.randn(imgs_vis.size(0), config.noise_dim).cuda() 94 | imgs_vis_aug = model.aug_net(noise, imgs_vis.cuda(), imgs_vis_label.cuda()) 95 | imgs_vis_aug = imgs_vis_aug.cpu() 96 | 97 | grid = make_grid(imgs_vis_aug[:imgs_vis.size(0)], nrow=int(math.sqrt(imgs_vis.size(0))), normalize=True, padding=1, pad_value=1) 98 | 99 | 100 | save_image(grid, os.path.join(vis_dir, 'aug_imgs_%d.png'%(epoch+1))) 101 | 102 | def save_vis_imgs_4(model, imgs_vis, imgs_vis_label, writer, epoch, vis_dir, config): 103 | 104 | noise = torch.randn(imgs_vis.size(0), config.noise_dim).cuda() 105 | imgs_vis_aug, imgs_vis_label_aug = model.aug_net(noise, imgs_vis.cuda(), imgs_vis_label.cuda()) 106 | imgs_vis_aug = imgs_vis_aug.cpu() 107 | imgs_vis_label_aug = imgs_vis_label_aug.cpu() 108 | 109 | grid1 = make_grid(imgs_vis_aug[:imgs_vis.size(0)], nrow=int(math.sqrt(imgs_vis.size(0))), normalize=True, padding=1, pad_value=1) 110 | grid2 = make_grid(imgs_vis_aug[imgs_vis.size(0):], nrow=int(math.sqrt(imgs_vis.size(0))), normalize=True, padding=1, pad_value=1) 111 | 112 | grid = torch.cat([grid1, grid2], dim=2) 113 | 114 | 115 | save_image(grid, os.path.join(vis_dir, 'aug_imgs_%d.png'%(epoch+1))) 116 | 117 | def save_vis_imgs_5(model, imgs_vis, imgs_vis_label, writer, epoch, vis_dir, config): 118 | 119 | noise = torch.randn(imgs_vis.size(0), config.noise_dim).cuda() 120 | imgs_vis_aug, _= model.aug_net(noise, imgs_vis.cuda()) 121 | imgs_vis_aug = imgs_vis_aug.cpu() 122 | 123 | grid = make_grid(imgs_vis_aug[:imgs_vis.size(0)], nrow=int(math.sqrt(imgs_vis.size(0))), normalize=True, padding=1, pad_value=1) 124 | 125 | 126 | save_image(grid, os.path.join(vis_dir, 'aug_imgs_%d.png'%(epoch+1))) 127 | 128 | 129 | 130 | def multistep_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch, lr_decay_epoch, max_epoch, gamma=0.1): 131 | 132 | if epoch >= 0 and epoch <= warmup_epoch: 133 | lr = init_lr * 2.718 ** (10*(float(epoch) / float(warmup_epoch) - 1.)) 134 | if epoch == warmup_epoch: 135 | lr = init_lr 136 | for param_group in optimizer.param_groups: 137 | param_group['lr'] = lr 138 | 139 | return lr 140 | 141 | flag = False 142 | for i in range(len(lr_decay_epoch)): 143 | if epoch == lr_decay_epoch[i]: 144 | flag = True 145 | break 146 | 147 | if flag == True: 148 | lr = init_lr * gamma**(i+1) 149 | for param_group in optimizer.param_groups: 150 | param_group['lr'] = lr 151 | 152 | else: 153 | return optimizer.param_groups[0]['lr'] 154 | 155 | return lr 156 | 157 | def exp_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch, max_epoch): 158 | 159 | if epoch >= 0 and epoch <= warmup_epoch: 160 | lr = init_lr * 2.718 ** (10*(float(epoch) / float(warmup_epoch) - 1.)) 161 | if epoch == warmup_epoch: 162 | lr = init_lr 163 | for param_group in optimizer.param_groups: 164 | param_group['lr'] = lr 165 | 166 | return lr 167 | 168 | else: 169 | lr = init_lr * (1 - epoch / max_epoch)**0.9 170 | for param_group in optimizer.param_groups: 171 | param_group['lr'] = lr 172 | 173 | return lr 174 | 175 | class Exp_LR_Scheduler_with_Warmup(): 176 | def __init__(self, optimizer, init_lr, warmup_epoch, max_epoch): 177 | 
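        """Class-based variant of exp_lr_scheduler_with_warmup above:
        exponential warm-up to init_lr over warmup_epoch epochs, followed by
        polynomial decay with power 0.9 towards zero at max_epoch."""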
self.optimizer = optimizer 178 | self.init_lr = init_lr 179 | self.max_epoch = max_epoch 180 | self.warmup_epoch = warmup_epoch 181 | self.current_epoch = 0 182 | 183 | lr = self.init_lr * 2.718 ** (10*(float(self.current_epoch) / float(warmup_epoch) - 1.)) 184 | self.set_lr(lr) 185 | 186 | def set_lr(self, lr): 187 | for param_group in self.optimizer.param_groups: 188 | param_group['lr'] = lr 189 | 190 | def step(self): 191 | 192 | self.current_epoch += 1 193 | 194 | if self.current_epoch >= 0 and self.current_epoch <= self.warmup_epoch: 195 | lr = self.init_lr * 2.718 ** (10*(float(self.current_epoch) / float(self.warmup_epoch) - 1.)) 196 | if self.current_epoch == self.warmup_epoch: 197 | lr = self.init_lr 198 | else: 199 | lr = self.init_lr * (1 - self.current_epoch / self.max_epoch) ** 0.9 200 | 201 | self.set_lr(lr) 202 | 203 | class MultiBatchNorm(nn.Module): 204 | def __init__(self, num_features, momentum=None, eps=1e-05, dim='2d', types=['base', 'aug']) : 205 | assert isinstance(types, list) and len(types) > 1 206 | assert 'base' in types 207 | assert dim in ('1d', '2d') 208 | super(MultiBatchNorm, self).__init__() 209 | self.types = types 210 | 211 | if dim == '1d': 212 | if momentum is not None: 213 | self.bns = nn.ModuleDict([[t, nn.BatchNorm1d(num_features, momentum=momentum, eps=eps)] for t in types]) 214 | else: 215 | self.bns = nn.ModuleDict([[t, nn.BatchNorm1d(num_features, eps=eps)] for t in types]) 216 | elif dim == '2d': 217 | if momentum is not None: 218 | self.bns = nn.ModuleDict([[t, nn.BatchNorm2d(num_features, momentum=momentum, eps=eps)] for t in types]) 219 | else: 220 | self.bns = nn.ModuleDict([[t, nn.BatchNorm2d(num_features, eps=eps)] for t in types]) 221 | 222 | self.t = 'base' 223 | def forward(self, x): 224 | # print('bn type: {}'.format(self.t)) 225 | assert self.t in self.types 226 | out = self.bns[self.t](x) 227 | self.t = 'base' 228 | return out 229 | def replace_bn_with_multibn(model, types=['base', 'aug']): 230 | def convert(model): 231 | conversion_count = 0 232 | for name, module in reversed(model._modules.items()): 233 | if len(list(module.children())) > 0: 234 | 235 | model._modules[name], num_converted = convert(module) 236 | conversion_count += num_converted 237 | 238 | if type(module) == nn.BatchNorm2d: 239 | 240 | layer_old = module 241 | num_features = module.num_features 242 | eps = module.eps 243 | momentum = module.momentum 244 | layer_new = MultiBatchNorm(num_features=num_features, eps=eps, momentum=momentum, types=types) 245 | 246 | state_dict = module.state_dict() 247 | for t in types: 248 | layer_new.bns[t].load_state_dict(state_dict) 249 | 250 | model._modules[name] = layer_new 251 | conversion_count += 1 252 | return model, conversion_count 253 | 254 | 255 | def set_bn_type(self, t): 256 | for m in self.modules(): 257 | if isinstance(m, MultiBatchNorm): 258 | m.t = t 259 | 260 | model, _ = convert(model) 261 | model.set_bn_type = Types.MethodType(set_bn_type, model) 262 | 263 | return model 264 | 265 | def replace_bn_with_layer_norm(model, types='layer'): 266 | def convert_bn(model): 267 | for name, module in reversed(model._modules.items()): 268 | if len(list(module.children())) > 0: 269 | 270 | model._modules[name] = convert_bn(module) 271 | if type(module) == nn.BatchNorm2d: 272 | 273 | layer_old = module 274 | num_features = module.num_features 275 | eps = module.eps 276 | momentum = module.momentum 277 | 278 | if types == 'no': 279 | layer_new = nn.Sequential(nn.Identity()) 280 | model._modules[name] = layer_new 281 | return 
model 282 | model = convert_bn(model) 283 | 284 | return model 285 | 286 | 287 | def cal_dice(pred, target, C): 288 | N = pred.shape[0] 289 | target_mask = target.data.new(N, C).fill_(0) 290 | target_mask.scatter_(1, target, 1.) 291 | 292 | pred_mask = pred.data.new(N, C).fill_(0) 293 | pred_mask.scatter_(1, pred, 1.) 294 | 295 | intersection= pred_mask * target_mask 296 | summ = pred_mask + target_mask 297 | 298 | intersection = intersection.sum(0).type(torch.float32) 299 | summ = summ.sum(0).type(torch.float32) 300 | 301 | eps = torch.rand(C, dtype=torch.float32) 302 | eps = eps.fill_(1e-7) 303 | 304 | summ += eps.cuda() 305 | dice = 2 * intersection / summ 306 | 307 | return dice, intersection, summ 308 | 309 | --------------------------------------------------------------------------------