├── models ├── __init__.py ├── caption.py ├── utils.py ├── position_encoding.py ├── backbone.py ├── transformer.py └── model.py ├── datasets ├── __init__.py ├── strip_list.pkl ├── thresholds.pkl ├── iu_xray_vocabulary.pkl ├── mimic_cxr_vocabulary.pkl ├── utils.py ├── tokenizers.py └── xray.py ├── docs └── EKAGen-framework.png ├── test_iu.sh ├── train_iu.sh ├── train_mimic.sh ├── test_mimic.sh ├── utils ├── stloss.py └── engine.py ├── ADM ├── generate_adm.py ├── model.py ├── adm_utils.py └── gradcam_utils.py ├── README.md ├── requirements.txt ├── main.py └── LICENSE /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/strip_list.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hnjzbss/EKAGen/HEAD/datasets/strip_list.pkl -------------------------------------------------------------------------------- /datasets/thresholds.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hnjzbss/EKAGen/HEAD/datasets/thresholds.pkl -------------------------------------------------------------------------------- /docs/EKAGen-framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hnjzbss/EKAGen/HEAD/docs/EKAGen-framework.png -------------------------------------------------------------------------------- /datasets/iu_xray_vocabulary.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hnjzbss/EKAGen/HEAD/datasets/iu_xray_vocabulary.pkl -------------------------------------------------------------------------------- /datasets/mimic_cxr_vocabulary.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hnjzbss/EKAGen/HEAD/datasets/mimic_cxr_vocabulary.pkl -------------------------------------------------------------------------------- /test_iu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=3 3 | python main.py --batch_size 16 --image_size 300 --vocab_size 760 --theta 0.4 --gamma 0.4 --beta 1.0 --delta 0.01 --dataset_name iu_xray --anno_path ../dataset/iu_xray/annotation.json --data_dir ../dataset/iu_xray/images --mode test --knowledge_prompt_path ./knowledge_path/knowledge_prompt_iu.pkl --test_path ./weight_path/iu_weight.pth -------------------------------------------------------------------------------- /train_iu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export CUDA_VISIBLE_DEVICES=2 3 | python main.py --epochs 50 --lr_backbone 1e-5 --lr 1e-4 --batch_size 8 --image_size 300 --vocab_size 760 --theta 0.4 --gamma 0.4 --beta 1.0 --delta 0.01 --dataset_name iu_xray --t_model_weight_path ./weight_path/iu_t_model.pth --anno_path ../dataset/iu_xray/annotation.json --data_dir ../dataset/iu_xray/images --mode train 4 | -------------------------------------------------------------------------------- /train_mimic.sh: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | python main.py --epochs 50 --lr_backbone 1e-5 --lr 1e-4 --batch_size 32 --image_size 300 --vocab_size 4253 --theta 0.4 --gamma 0.4 --beta 1.0 --delta 0.01 --dataset_name mimic_cxr --t_model_weight_path ./weight_path/mimic_t_model.pth --anno_path ../dataset/mimic_cxr/annotation.json --data_dir ../dataset/mimic_cxr/images300 --mode train 4 | -------------------------------------------------------------------------------- /test_mimic.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export CUDA_VISIBLE_DEVICES=1 3 | python main.py --batch_size 16 --image_size 300 --vocab_size 4253 --theta 0.4 --gamma 0.4 --beta 1.0 --delta 0.01 --dataset_name mimic_cxr --anno_path ../dataset/mimic_cxr/annotation.json --data_dir ../dataset/mimic_cxr/images300 --mode test --knowledge_prompt_path ./knowledge_path/knowledge_prompt_mimic.pkl --test_path ./weight_path/mimic_weight.pth -------------------------------------------------------------------------------- /utils/stloss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class SoftTarget(nn.Module): 6 | ''' 7 | Distilling the Knowledge in a Neural Network 8 | https://arxiv.org/pdf/1503.02531.pdf 9 | ''' 10 | 11 | def __init__(self, T): 12 | super(SoftTarget, self).__init__() 13 | self.T = T 14 | 15 | def forward(self, out_s, out_t): 16 | loss = F.kl_div(F.log_softmax(out_s / self.T, dim=2), 17 | F.softmax(out_t / self.T, dim=2), 18 | reduction='batchmean') * self.T * self.T 19 | 20 | return loss 21 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, List 3 | from torch import Tensor 4 | import json 5 | 6 | MAX_DIM = 300 7 | 8 | 9 | def read_json(file_name): 10 | with open(file_name) as handle: 11 | out = json.load(handle) 12 | return out 13 | 14 | 15 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor], max_dim): 16 | if tensor_list[0].ndim == 3: 17 | max_size = [3, max_dim, max_dim] 18 | batch_shape = [len(tensor_list)] + max_size 19 | b, c, h, w = batch_shape 20 | dtype = tensor_list[0].dtype 21 | device = tensor_list[0].device 22 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 23 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 24 | for img, pad_img, m in zip(tensor_list, tensor, mask): 25 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 26 | m[: img.shape[1], :img.shape[2]] = False 27 | else: 28 | raise ValueError('not supported') 29 | return NestedTensor(tensor, mask) 30 | 31 | 32 | class NestedTensor(object): 33 | def __init__(self, tensors, mask: Optional[Tensor]): 34 | self.tensors = tensors 35 | self.mask = mask 36 | 37 | def to(self, device): 38 | cast_tensor = self.tensors.to(device) 39 | mask = self.mask 40 | if mask is not None: 41 | assert mask is not None 42 | cast_mask = mask.to(device) 43 | else: 44 | cast_mask = None 45 | return NestedTensor(cast_tensor, cast_mask) 46 | 47 | def decompose(self): 48 | return self.tensors, self.mask 49 | 50 | def __repr__(self): 51 | return str(self.tensors) 52 | -------------------------------------------------------------------------------- /models/caption.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from .utils import NestedTensor, nested_tensor_from_tensor_list 5 | from .backbone import build_backbone 6 | from .transformer import build_transformer 7 | 8 | 9 | class Caption(nn.Module): 10 | def __init__(self, backbone, transformer, hidden_dim, vocab_size): 11 | super().__init__() 12 | self.backbone = backbone 13 | self.input_proj = nn.Conv2d( 14 | backbone.num_channels, hidden_dim, kernel_size=1) 15 | self.transformer = transformer 16 | self.mlp = MLP(hidden_dim, 512, vocab_size, 3) 17 | 18 | def forward(self, samples, target, target_mask, class_feature): 19 | if not isinstance(samples, NestedTensor): 20 | samples = nested_tensor_from_tensor_list(samples) 21 | 22 | features, pos = self.backbone(samples) 23 | src, mask = features[-1].decompose() 24 | 25 | assert mask is not None 26 | 27 | hs = self.transformer(self.input_proj(src), mask, 28 | pos[-1], target, target_mask, class_feature) 29 | out = self.mlp(hs.permute(1, 0, 2)) 30 | return out 31 | 32 | 33 | class MLP(nn.Module): 34 | """ Very simple multi-layer perceptron (also called FFN)""" 35 | 36 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 37 | super().__init__() 38 | self.num_layers = num_layers 39 | h = [hidden_dim] * (num_layers - 1) 40 | self.layers = nn.ModuleList(nn.Linear(n, k) 41 | for n, k in zip([input_dim] + h, h + [output_dim])) 42 | 43 | def forward(self, x): 44 | for i, layer in enumerate(self.layers): 45 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 46 | return x 47 | 48 | 49 | def build_model(config): 50 | backbone = build_backbone(config) 51 | transformer = build_transformer(config) 52 | 53 | model = Caption(backbone, transformer, config.hidden_dim, config.vocab_size) 54 | criterion = torch.nn.CrossEntropyLoss() 55 | 56 | return model, criterion 57 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.distributed as dist 5 | from torch import Tensor 6 | import pickle 7 | 8 | 9 | def _max_by_axis(the_list): 10 | maxes = the_list[0] 11 | for sublist in the_list[1:]: 12 | for index, item in enumerate(sublist): 13 | maxes[index] = max(maxes[index], item) 14 | return maxes 15 | 16 | 17 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 18 | if tensor_list[0].ndim == 3: 19 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 20 | batch_shape = [len(tensor_list)] + max_size 21 | b, c, h, w = batch_shape 22 | dtype = tensor_list[0].dtype 23 | device = tensor_list[0].device 24 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 25 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 26 | for img, pad_img, m in zip(tensor_list, tensor, mask): 27 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 28 | m[: img.shape[1], :img.shape[2]] = False 29 | else: 30 | raise ValueError('not supported') 31 | return NestedTensor(tensor, mask) 32 | 33 | 34 | class NestedTensor(object): 35 | def __init__(self, tensors, mask: Optional[Tensor]): 36 | self.tensors = tensors 37 | self.mask = mask 38 | 39 | def to(self, device): 40 | cast_tensor = self.tensors.to(device) 41 | mask = self.mask 42 | if mask is not None: 43 | assert mask is not None 44 | cast_mask = mask.to(device) 45 | else: 46 | cast_mask = None 47 | return NestedTensor(cast_tensor, cast_mask) 
48 | 49 | def decompose(self): 50 | return self.tensors, self.mask 51 | 52 | def __repr__(self): 53 | return str(self.tensors) 54 | 55 | 56 | def is_dist_avail_and_initialized(): 57 | if not dist.is_available(): 58 | return False 59 | if not dist.is_initialized(): 60 | return False 61 | return True 62 | 63 | 64 | def get_rank(): 65 | if not is_dist_avail_and_initialized(): 66 | return 0 67 | return dist.get_rank() 68 | 69 | 70 | def is_main_process(): 71 | return get_rank() == 0 72 | 73 | 74 | def get_knowledge(filename): 75 | with open(filename, 'rb') as f: 76 | knowledge_prompt = pickle.load(f) 77 | return knowledge_prompt 78 | -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | 5 | from .utils import NestedTensor 6 | 7 | 8 | class PositionEmbeddingSine(nn.Module): 9 | """ 10 | This is a more standard version of the position embedding, very similar to the one 11 | used by the Attention is all you need paper, generalized to work on images. 12 | """ 13 | 14 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 15 | super().__init__() 16 | self.num_pos_feats = num_pos_feats 17 | self.temperature = temperature 18 | self.normalize = normalize 19 | if scale is not None and normalize is False: 20 | raise ValueError("normalize should be True if scale is passed") 21 | if scale is None: 22 | scale = 2 * math.pi 23 | self.scale = scale 24 | 25 | def forward(self, tensor_list: NestedTensor): 26 | x = tensor_list.tensors 27 | mask = tensor_list.mask 28 | assert mask is not None 29 | not_mask = ~mask 30 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 31 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 32 | if self.normalize: 33 | eps = 1e-6 34 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 35 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 36 | 37 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 38 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 39 | 40 | pos_x = x_embed[:, :, :, None] / dim_t 41 | pos_y = y_embed[:, :, :, None] / dim_t 42 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 43 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 44 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 45 | return pos 46 | 47 | 48 | class PositionEmbeddingLearned(nn.Module): 49 | """ 50 | Absolute pos embedding, learned. 
51 | """ 52 | 53 | def __init__(self, num_pos_feats=256): 54 | super().__init__() 55 | self.row_embed = nn.Embedding(50, num_pos_feats) 56 | self.col_embed = nn.Embedding(50, num_pos_feats) 57 | self.reset_parameters() 58 | 59 | def reset_parameters(self): 60 | nn.init.uniform_(self.row_embed.weight) 61 | nn.init.uniform_(self.col_embed.weight) 62 | 63 | def forward(self, tensor_list: NestedTensor): 64 | x = tensor_list.tensors 65 | h, w = x.shape[-2:] 66 | i = torch.arange(w, device=x.device) 67 | j = torch.arange(h, device=x.device) 68 | x_emb = self.col_embed(i) 69 | y_emb = self.row_embed(j) 70 | pos = torch.cat([ 71 | x_emb.unsqueeze(0).repeat(h, 1, 1), 72 | y_emb.unsqueeze(1).repeat(1, w, 1), 73 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 74 | return pos 75 | 76 | 77 | def build_position_encoding(config): 78 | N_steps = config.hidden_dim // 2 79 | if config.position_embedding in ('v2', 'sine'): 80 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 81 | elif config.position_embedding in ('v3', 'learned'): 82 | position_embedding = PositionEmbeddingLearned(N_steps) 83 | else: 84 | raise ValueError(f"not supported {config.position_embedding}") 85 | 86 | return position_embedding 87 | -------------------------------------------------------------------------------- /ADM/generate_adm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | from torchvision import transforms 6 | from gradcam_utils import GradCAM, show_cam_on_image 7 | from model import resnet34 8 | import pickle 9 | import cv2 10 | import glob 11 | import tqdm 12 | 13 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 14 | 15 | if os.path.exists("datasets/thresholds.pkl"): 16 | with open("datasets/thresholds.pkl", "rb") as f: 17 | thresholds = pickle.load(f) 18 | 19 | 20 | def get_model(): 21 | model = resnet34(num_classes=14).cuda() 22 | model_weight_path = "./weights/MIMIC_best_weight.pth" 23 | model.load_state_dict(torch.load(model_weight_path, map_location="cpu")) 24 | return model.eval() 25 | 26 | 27 | def main(model, image_path, seg_path, mask_path, array_path): 28 | target_layers = [model.layer4] 29 | data_transform = transforms.Compose([transforms.Resize(300), 30 | transforms.ToTensor(), 31 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 32 | 33 | assert os.path.exists(image_path), "file: '{}' dose not exist.".format(image_path) 34 | img = Image.open(image_path).convert('RGB') 35 | img_np = np.array(img, dtype=np.uint8) 36 | img_tensor = data_transform(img) 37 | input_tensor = torch.unsqueeze(img_tensor, dim=0).cuda() 38 | logit = model(input_tensor) # [64, 1, 768] 39 | thresholded_predictions = 1 * (logit.detach().cpu().numpy() > thresholds) 40 | indices = np.where(thresholded_predictions[0] == 1)[0] 41 | 42 | cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False) 43 | mask_arr_ass = np.zeros((300, 300, 3)) 44 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)) 45 | 46 | for target_category in list(indices): 47 | grayscale_cam = cam(input_tensor=input_tensor, target_category=int(target_category)) 48 | 49 | grayscale_cam = grayscale_cam[0, :] 50 | _, heatmap = show_cam_on_image(img_np.astype(dtype=np.float32) / 255., 51 | grayscale_cam, 52 | use_rgb=True) 53 | threshold = 0.6 54 | mask = cv2.threshold(heatmap, threshold, 1, cv2.THRESH_BINARY)[1] 55 | mask = cv2.dilate(mask, kernel) 56 | mask = mask.astype(np.uint8) * 255 
57 | mask_arr = np.asarray(mask) 58 | mask_arr_ass += mask_arr 59 | mask_arr_ass = np.any(mask_arr_ass, axis=2) 60 | np.save(array_path, mask_arr_ass) 61 | mask = Image.fromarray((mask_arr_ass * 255).astype(np.uint8)).convert('L') 62 | mask.save(mask_path) 63 | img = Image.open(image_path) 64 | img = mask * np.asarray(img) 65 | img = Image.fromarray(img) 66 | img.save(seg_path) 67 | 68 | 69 | if __name__ == '__main__': 70 | image_list = glob.glob("../dataset/mimic_cxr/images300/*/*/*/*.jpg") 71 | model = get_model() 72 | bar = tqdm.tqdm(image_list) 73 | for image_path in bar: 74 | seg_path = image_path.replace("images300", "resnet34_300/images300_seg") 75 | mask_path = image_path.replace("images300", "resnet34_300/images300_mask") 76 | array_path = image_path.replace("images300", "resnet34_300/images300_array").replace(".jpg", ".npy") 77 | 78 | if not os.path.exists(os.path.dirname(seg_path)): 79 | os.makedirs(os.path.dirname(seg_path)) 80 | os.makedirs(os.path.dirname(mask_path)) 81 | os.makedirs(os.path.dirname(array_path)) 82 | if not os.path.exists(seg_path): 83 | main(model, image_path, seg_path, mask_path, array_path) 84 | -------------------------------------------------------------------------------- /ADM/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class BasicBlock(nn.Module): 6 | expansion = 1 7 | 8 | def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs): 9 | super(BasicBlock, self).__init__() 10 | self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, 11 | kernel_size=3, stride=stride, padding=1, 12 | bias=False) 13 | self.bn1 = nn.BatchNorm2d(out_channel) 14 | self.relu = nn.ReLU() 15 | self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, 16 | kernel_size=3, stride=1, padding=1, 17 | bias=False) 18 | self.bn2 = nn.BatchNorm2d(out_channel) 19 | self.downsample = downsample 20 | 21 | def forward(self, x): 22 | identity = x 23 | if self.downsample is not None: 24 | identity = self.downsample(x) 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | 33 | out += identity 34 | out = self.relu(out) 35 | 36 | return out 37 | 38 | 39 | class ResNet(nn.Module): 40 | 41 | def __init__(self, 42 | block, 43 | blocks_num, 44 | num_classes=1000, 45 | include_top=True, 46 | groups=1, 47 | width_per_group=64): 48 | super(ResNet, self).__init__() 49 | self.include_top = include_top 50 | self.in_channel = 64 51 | 52 | self.groups = groups 53 | self.width_per_group = width_per_group 54 | 55 | self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, 56 | padding=3, bias=False) 57 | 58 | self.bn1 = nn.BatchNorm2d(self.in_channel) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 61 | self.layer1 = self._make_layer(block, 64, blocks_num[0]) 62 | self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2) 63 | self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2) 64 | self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2) 65 | if self.include_top: 66 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 67 | self.fc = nn.Linear(512 * block.expansion, num_classes) 68 | 69 | for m in self.modules(): 70 | if isinstance(m, nn.Conv2d): 71 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 72 | 73 | def _make_layer(self, block, channel, 
block_num, stride=1): 74 | downsample = None 75 | if stride != 1 or self.in_channel != channel * block.expansion: 76 | downsample = nn.Sequential( 77 | nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False), 78 | nn.BatchNorm2d(channel * block.expansion)) 79 | 80 | layers = [] 81 | layers.append(block(self.in_channel, 82 | channel, 83 | downsample=downsample, 84 | stride=stride, 85 | groups=self.groups, 86 | width_per_group=self.width_per_group)) 87 | self.in_channel = channel * block.expansion 88 | 89 | for _ in range(1, block_num): 90 | layers.append(block(self.in_channel, 91 | channel, 92 | groups=self.groups, 93 | width_per_group=self.width_per_group)) 94 | 95 | return nn.Sequential(*layers) 96 | 97 | def forward(self, x): 98 | x = self.conv1(x) 99 | x = self.bn1(x) 100 | x = self.relu(x) 101 | x = self.maxpool(x) 102 | 103 | x = self.layer1(x) 104 | x = self.layer2(x) 105 | x = self.layer3(x) 106 | x = self.layer4(x) 107 | 108 | if self.include_top: 109 | x = self.avgpool(x) 110 | x = torch.flatten(x, 1) 111 | x = self.fc(x) 112 | x = torch.sigmoid(x) 113 | 114 | return x 115 | 116 | 117 | def resnet34(num_classes=1000, include_top=True): 118 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top) 119 | -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torchvision 4 | from torch import nn 5 | from torchvision.models._utils import IntermediateLayerGetter 6 | from typing import Dict, List 7 | from .utils import NestedTensor, is_main_process 8 | from .position_encoding import build_position_encoding 9 | 10 | 11 | class FrozenBatchNorm2d(torch.nn.Module): 12 | """ 13 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 14 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 15 | without which any other models than torchvision.models.resnet[18,34,50,101] 16 | produce nans. 
17 | """ 18 | 19 | def __init__(self, n): 20 | super(FrozenBatchNorm2d, self).__init__() 21 | self.register_buffer("weight", torch.ones(n)) 22 | self.register_buffer("bias", torch.zeros(n)) 23 | self.register_buffer("running_mean", torch.zeros(n)) 24 | self.register_buffer("running_var", torch.ones(n)) 25 | 26 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 27 | missing_keys, unexpected_keys, error_msgs): 28 | num_batches_tracked_key = prefix + 'num_batches_tracked' 29 | if num_batches_tracked_key in state_dict: 30 | del state_dict[num_batches_tracked_key] 31 | 32 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 33 | state_dict, prefix, local_metadata, strict, 34 | missing_keys, unexpected_keys, error_msgs) 35 | 36 | def forward(self, x): 37 | w = self.weight.reshape(1, -1, 1, 1) 38 | b = self.bias.reshape(1, -1, 1, 1) 39 | rv = self.running_var.reshape(1, -1, 1, 1) 40 | rm = self.running_mean.reshape(1, -1, 1, 1) 41 | eps = 1e-5 42 | scale = w * (rv + eps).rsqrt() 43 | bias = b - rm * scale 44 | return x * scale + bias 45 | 46 | 47 | class BackboneBase(nn.Module): 48 | 49 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 50 | super().__init__() 51 | for name, parameter in backbone.named_parameters(): 52 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 53 | parameter.requires_grad_(False) 54 | if return_interm_layers: 55 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 56 | else: 57 | return_layers = {'layer4': "0"} 58 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 59 | self.num_channels = num_channels 60 | 61 | def forward(self, tensor_list: NestedTensor): 62 | xs = self.body(tensor_list.tensors) 63 | out: Dict[str, NestedTensor] = {} 64 | for name, x in xs.items(): 65 | m = tensor_list.mask 66 | assert m is not None 67 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 68 | out[name] = NestedTensor(x, mask) 69 | return out 70 | 71 | 72 | class Backbone(BackboneBase): 73 | """ResNet backbone with frozen BatchNorm.""" 74 | 75 | def __init__(self, name: str, 76 | train_backbone: bool, 77 | return_interm_layers: bool, 78 | dilation: bool): 79 | backbone = getattr(torchvision.models, name)( 80 | replace_stride_with_dilation=[False, False, dilation], 81 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) 82 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 83 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 84 | 85 | 86 | class Joiner(nn.Sequential): 87 | def __init__(self, backbone, position_embedding): 88 | super().__init__(backbone, position_embedding) 89 | 90 | def forward(self, tensor_list: NestedTensor): 91 | xs = self[0](tensor_list) 92 | out: List[NestedTensor] = [] 93 | pos = [] 94 | for name, x in xs.items(): 95 | out.append(x) 96 | # position encoding 97 | pos.append(self[1](x).to(x.tensors.dtype)) 98 | 99 | return out, pos 100 | 101 | 102 | def build_backbone(config): 103 | position_embedding = build_position_encoding(config) 104 | train_backbone = config.lr_backbone > 0 105 | return_interm_layers = False 106 | backbone = Backbone(config.backbone, train_backbone, return_interm_layers, config.dilation) 107 | model = Joiner(backbone, position_embedding) 108 | model.num_channels = backbone.num_channels 109 | return model 110 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EKAGen 2 | Code for CVPR2024 paper: "**[Instance-level Expert Knowledge and Aggregate Discriminative Attention for Radiology Report Generation](https://openaccess.thecvf.com/content/CVPR2024/papers/Bu_Instance-level_Expert_Knowledge_and_Aggregate_Discriminative_Attention_for_Radiology_Report_CVPR_2024_paper.pdf)**". Shenshen Bu, Taiji Li, Yuedong Yang, Zhiming Dai. [**[Video](https://www.youtube.com/watch?v=QbcNQ2zuS-8)**] 3 | 4 |

 5 | ![EKAGen framework diagram](docs/EKAGen-framework.png)
 6 | 
 7 | 
 8 | > **Abstract:** Automatic radiology report generation can provide substantial advantages to clinical physicians by effectively reducing their workload and improving efficiency. Despite the promising potential of current methods, challenges persist in effectively extracting and preventing degradation of prominent features, as well as enhancing attention on pivotal regions. In this paper, we propose an Instance-level Expert Knowledge and Aggregate Discriminative Attention framework for radiology report generation. We convert expert reports into an embedding space and generate comprehensive representations for each disease, which serve as Preliminary Knowledge Support (PKS). To prevent feature disruption, we select the representations in the embedding space with the smallest distances to PKS as Rectified Knowledge Support (RKS). Then, EKAGen diagnoses the diseases and retrieves knowledge from RKS, creating Instance-level Expert Knowledge (IEK) for each query image, boosting generation. Additionally, we introduce Aggregate Discriminative Attention Map (ADM), which uses weak supervision to create maps of discriminative regions that highlight pivotal regions. For training, we propose a Global Information Self-Distillation (GID) strategy, using an iteratively optimized model to distill global knowledge into EKAGen. Extensive experiments and analyses on IU X-Ray and MIMIC-CXR datasets demonstrate that EKAGen outperforms previous state-of-the-art methods.
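As a concrete illustration of the GID objective, the repository adds a temperature-scaled soft-target distillation term (`utils/stloss.py`) to the usual cross-entropy loss (`utils/engine.py`). The sketch below is a minimal, self-contained example of how the two terms combine; the random tensors stand in for student/teacher decoder outputs, the temperature `T=4.0` is an assumed illustrative value, and the `0.01` weight follows the `--delta` setting in the training scripts.

```python
# Minimal sketch of the distillation + cross-entropy objective, following
# utils/stloss.py and utils/engine.py. Tensor shapes, the temperature T, and
# the dummy data below are illustrative assumptions, not values from the paper.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftTarget(nn.Module):
    """KL divergence between temperature-softened student and teacher logits."""

    def __init__(self, T):
        super().__init__()
        self.T = T

    def forward(self, out_s, out_t):
        return F.kl_div(F.log_softmax(out_s / self.T, dim=2),
                        F.softmax(out_t / self.T, dim=2),
                        reduction='batchmean') * self.T * self.T


batch, seq_len, vocab = 2, 127, 760                      # vocab_size 760 as in train_iu.sh
student_logits = torch.randn(batch, seq_len, vocab, requires_grad=True)  # student output stand-in
teacher_logits = torch.randn(batch, seq_len, vocab)      # frozen teacher output stand-in
targets = torch.randint(0, vocab, (batch, seq_len))      # ground-truth report token ids

criterion = nn.CrossEntropyLoss()
criterionKD = SoftTarget(T=4.0)                          # T is an assumed value

kd_loss = criterionKD(student_logits, teacher_logits.detach()) * 0.01    # --delta 0.01
ce_loss = criterion(student_logits.permute(0, 2, 1), targets)            # (batch, vocab, seq)
loss = ce_loss + kd_loss
loss.backward()
```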
9 | 10 | ---------- 11 | 12 | # Get Started 13 | 14 | ## 1) Requirement 15 | 16 | - Python 3.8.13 17 | - Pytorch 1.9.0 18 | - Torchvision 0.10.0 19 | - CUDA 11.8 20 | - NVIDIA RTX 4090 21 | 22 | ## 2) Data Preparation 23 | ### MIMIC-CXR 24 | - You must be a credential user defined in [PhysioNet](https://physionet.org/settings/credentialing/) to access the data. 25 | - Download chest X-rays from [MIMIC-CXR-JPG](https://physionet.org/content/mimic-cxr-jpg/2.0.0/) and reports from [MIMIC-CXR](https://physionet.org/content/mimic-cxr/2.0.0/) Database. 26 | 27 | ### IU X-Ray 28 | - You can download the processed reports and images for IU X-Ray by [Chen *et al.*](https://aclanthology.org/2021.acl-long.459.pdf) from [R2GenCMN](https://github.com/cuhksz-nlp/R2GenCMN). 29 | 30 | ## 3) Download Model Weights and Knowledge Base 31 | * Download the following model weights: 32 | | Model | Publicly Available | 33 | | ----- | ------------------- | 34 | | DiagnosisBot | [diagnosisbot.pth](https://huggingface.co/ShenshenBu/EKAGen/blob/main/diagnosisbot.pth) | 35 | | Generate ADM Model Weight | [MIMIC_best_weight.pth](https://huggingface.co/ShenshenBu/EKAGen/blob/main/MIMIC_best_weight.pth) | 36 | | IU X-Ray Teacher Model | [iu_t_model.pth](https://huggingface.co/ShenshenBu/EKAGen/blob/main/iu_t_model.pth) | 37 | | MIMIC-CXR Teacher Model | [mimic_t_model.pth](https://huggingface.co/ShenshenBu/EKAGen/blob/main/mimic_t_model.pth) | 38 | 39 | * Download the following knowledge base and attention maps: 40 | | Item | Publicly Available | 41 | | ----- | ------------------- | 42 | | IU X-Ray Knowledge Base | [knowledge_prompt_iu.pkl](https://huggingface.co/ShenshenBu/EKAGen/blob/main/knowledge_prompt_iu.pkl) | 43 | | MIMIC-CXR Knowledge Base | [knowledge_prompt_mimic.pkl](https://huggingface.co/ShenshenBu/EKAGen/blob/main/knowledge_prompt_mimic.pkl) | 44 | | IU X-Ray ADM | [iu_mask.tar.gz](https://huggingface.co/ShenshenBu/EKAGen/blob/main/iu_mask.tar.gz) | 45 | | MIMIC-CXR ADM | [mimic_mask.tar.gz](https://huggingface.co/ShenshenBu/EKAGen/blob/main/mimic_mask.tar.gz) | 46 | 47 | ---------- 48 | 49 | ## 4) Training 50 | 51 | ### IU X-Ray 52 | ``` bash 53 | bash train_iu.sh 54 | ``` 55 | 56 | ### MIMIC-CXR 57 | ``` bash 58 | bash train_mimic.sh 59 | ``` 60 | 61 | ## 5) Inference 62 | 63 | You can download our trained models for inference from [IU X-Ray](https://huggingface.co/ShenshenBu/EKAGen/blob/main/iu_weight.pth) and [MIMIC-CXR](https://huggingface.co/ShenshenBu/EKAGen/blob/main/mimic_weight.pth). 64 | 65 | ### IU X-Ray 66 | ``` bash 67 | bash test_iu.sh 68 | ``` 69 | 70 | ### MIMIC-CXR 71 | ``` bash 72 | bash test_mimic.sh 73 | ``` 74 | 75 | ## Citation 76 | 77 | If you find this work useful in your research, please cite: 78 | ```tex 79 | @InProceedings{Bu_2024_CVPR, 80 | author = {Bu, Shenshen and Li, Taiji and Yang, Yuedong and Dai, Zhiming}, 81 | title = {Instance-level Expert Knowledge and Aggregate Discriminative Attention for Radiology Report Generation}, 82 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 83 | month = {June}, 84 | year = {2024}, 85 | pages = {14194-14204} 86 | } 87 | ``` 88 | 89 | ## Contact Information 90 | 91 | If you have any suggestions or questions, you can contact us by: bushsh@alumni.sysu.edu.cn. Thank you for your attention! 
92 | -------------------------------------------------------------------------------- /datasets/tokenizers.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path 3 | import re 4 | from collections import Counter 5 | import pickle 6 | 7 | with open('./datasets/strip_list.pkl', 'rb') as file: 8 | strip = pickle.load(file) 9 | 10 | 11 | class Tokenizer(object): 12 | def __init__(self, ann_path, threshold, dataset_name, max_length=128): 13 | self.ann_path = ann_path 14 | self.threshold = threshold 15 | self.dataset_name = dataset_name 16 | self.vocabulary_path = os.path.join("datasets", self.dataset_name + "_vocabulary.pkl") 17 | self.max_length = max_length 18 | 19 | if self.dataset_name == 'iu_xray': 20 | self.clean_report = self.clean_report_iu_xray 21 | else: 22 | self.clean_report = self.clean_report_mimic_cxr 23 | self.ann = json.loads(open(self.ann_path, 'r').read()) 24 | if os.path.exists(self.vocabulary_path): 25 | with open(self.vocabulary_path, "rb") as f: 26 | self.token2idx, self.idx2token = pickle.load(f) 27 | else: 28 | self.token2idx, self.idx2token = self.create_vocabulary() 29 | 30 | def create_vocabulary(self): 31 | total_tokens = [] 32 | 33 | for example in self.ann['train']: 34 | tokens = self.clean_report(example['report']).split() 35 | for token in tokens: 36 | total_tokens.append(token) 37 | 38 | total_tokens = [item for item in total_tokens if item not in strip] 39 | 40 | counter = Counter(total_tokens) 41 | vocab = [k for k, v in counter.items() if v >= self.threshold] + [''] 42 | 43 | vocab.sort() 44 | token2idx, idx2token = {}, {} 45 | for idx, token in enumerate(vocab): 46 | token2idx[token] = idx + 3 47 | idx2token[idx + 3] = token 48 | with open(self.vocabulary_path, "wb") as f: 49 | pickle.dump([token2idx, idx2token], f) 50 | 51 | return token2idx, idx2token 52 | 53 | def clean_report_iu_xray(self, report): 54 | report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \ 55 | .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \ 56 | .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \ 57 | .strip().lower().split('. ') 58 | sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', ''). 59 | replace('\\', '').replace("'", '').strip().lower()) 60 | tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []] 61 | report = ' . '.join(tokens) + ' .' 62 | return report 63 | 64 | def clean_report_mimic_cxr(self, report): 65 | report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \ 66 | .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace(' ', ' ') \ 67 | .replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') \ 68 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \ 69 | .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \ 70 | .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \ 71 | .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \ 72 | .strip().lower().split('. 
') 73 | sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '') 74 | .replace('\\', '').replace("'", '').strip().lower()) 75 | tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []] 76 | report = ' . '.join(tokens) + ' .' 77 | return report 78 | 79 | def get_token_by_id(self, id): 80 | return self.idx2token[id] 81 | 82 | def get_id_by_token(self, token): 83 | if token not in self.token2idx: 84 | return self.token2idx[''] 85 | return self.token2idx[token] 86 | 87 | def get_vocab_size(self): 88 | return len(self.token2idx) 89 | 90 | def __call__(self, report): 91 | tokens = self.clean_report(report).split() 92 | ids = [] 93 | for token in tokens: 94 | ids.append(self.get_id_by_token(token)) 95 | ids = [1] + ids + [2] 96 | return ids 97 | 98 | def decode(self, ids): 99 | txt = '' 100 | for i, idx in enumerate(ids): 101 | if idx > 0: 102 | if i >= 1: 103 | txt += ' ' 104 | try: 105 | txt += self.idx2token[idx] 106 | except: 107 | txt += self.idx2token[idx.cpu().item()] 108 | else: 109 | break 110 | return txt 111 | 112 | def decode_batch(self, ids_batch): 113 | out = [] 114 | for ids in ids_batch: 115 | out.append(self.decode(ids)) 116 | return out 117 | 118 | def encode(self, report): 119 | tokens = self.clean_report(report).split() 120 | ids = [] 121 | for token in tokens: 122 | ids.append(self.get_id_by_token(token)) 123 | ids = [1] + ids + [2] 124 | return ids 125 | 126 | def encode_batch(self, report_batch): 127 | out = [] 128 | for ids in report_batch: 129 | out.append(self.encode(ids)[:self.max_length]) 130 | return out 131 | -------------------------------------------------------------------------------- /utils/engine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import sys 4 | import tqdm 5 | from pycocoevalcap.bleu.bleu import Bleu 6 | from pycocoevalcap.meteor.meteor import Meteor as meteor 7 | from pycocoevalcap.rouge.rouge import Rouge as rouge 8 | from models import utils 9 | 10 | 11 | def train_one_epoch(model, tmodel, class_model, criterion, criterionKD, data_loader, 12 | optimizer, device, max_norm, thresholds, tokenizer, config): 13 | model.train() 14 | criterion.train() 15 | class_model.eval() 16 | tmodel.eval() 17 | 18 | epoch_loss = 0.0 19 | total = len(data_loader) 20 | 21 | with tqdm.tqdm(total=total) as pbar: 22 | for images, masks, com_images, com_masks, caps, cap_masks, image_class in data_loader: 23 | samples = utils.NestedTensor(images, masks).to(device) 24 | com_samples = utils.NestedTensor(com_images, com_masks).to(device) 25 | caps = caps.to(device) 26 | cap_masks = cap_masks.to(device) 27 | 28 | logit = class_model(image_class.to(device)) 29 | thresholded_predictions = 1 * (logit.cpu().numpy() > thresholds) 30 | t_outputs = tmodel(samples, caps[:, :-1], cap_masks[:, :-1], [thresholded_predictions, tokenizer]) 31 | outputs = model(com_samples, caps[:, :-1], cap_masks[:, :-1], [thresholded_predictions, tokenizer]) 32 | kd_loss = criterionKD(outputs, t_outputs.detach()) * config.delta 33 | 34 | loss = criterion(outputs.permute(0, 2, 1), caps[:, 1:]) + kd_loss 35 | loss_value = loss.item() 36 | epoch_loss += loss_value 37 | 38 | if not math.isfinite(loss_value): 39 | print(f'Loss is {loss_value}, stopping training') 40 | sys.exit(1) 41 | 42 | optimizer.zero_grad() 43 | loss.backward() 44 | if max_norm > 0: 45 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 46 | optimizer.step() 47 | 48 | 
pbar.update(1) 49 | 50 | return epoch_loss / total 51 | 52 | 53 | def create_caption_and_mask(start_token, max_length, batch_size): 54 | caption_template = torch.zeros((batch_size, max_length), dtype=torch.long) 55 | mask_template = torch.ones((batch_size, max_length), dtype=torch.bool) 56 | 57 | caption_template[:, 0] = start_token 58 | mask_template[:, 0] = False 59 | 60 | return caption_template, mask_template 61 | 62 | 63 | def compute_scores(gts, res): 64 | scorers = [ 65 | (Bleu(4), ["BLEU_1", "BLEU_2", "BLEU_3", "BLEU_4"]), 66 | (meteor(), "METEOR"), 67 | (rouge(), "ROUGE_L") 68 | ] 69 | eval_res = {} 70 | for scorer, method in scorers: 71 | try: 72 | score, _ = scorer.compute_score(gts, res, verbose=0) 73 | except TypeError: 74 | score, _ = scorer.compute_score(gts, res) 75 | if type(method) == list: 76 | for sc, m in zip(score, method): 77 | eval_res[m] = sc 78 | else: 79 | eval_res[method] = score 80 | return eval_res 81 | 82 | 83 | @torch.no_grad() 84 | def evaluate(model, class_model, criterion, data_loader, device, config, thresholds, tokenizer): 85 | model.eval() 86 | criterion.eval() 87 | class_model.eval() 88 | total = len(data_loader) 89 | caption_list = [] 90 | caption_tokens_list = [] 91 | 92 | with tqdm.tqdm(total=total) as pbar: 93 | for images, masks, _, _, caps, _, image_class in data_loader: 94 | samples = utils.NestedTensor(images, masks).to(device) 95 | caption, cap_mask = create_caption_and_mask( 96 | config.start_token, config.max_position_embeddings, config.batch_size) 97 | try: 98 | for i in range(config.max_position_embeddings - 1): 99 | logit = class_model(image_class.to(device)) 100 | thresholded_predictions = 1 * (logit.cpu().numpy() > thresholds) 101 | predictions = model(samples.to(device), caption.to(device), cap_mask.to(device), 102 | [thresholded_predictions, tokenizer]) 103 | predictions = predictions[:, i, :] 104 | predicted_id = torch.argmax(predictions, axis=-1) 105 | if i == config.max_position_embeddings - 2: 106 | caption_list.extend(caption.cpu().numpy().tolist()) 107 | caption_tokens_list.extend(caps[:, 1:].cpu().numpy().tolist()) 108 | break 109 | caption[:, i + 1] = predicted_id 110 | cap_mask[:, i + 1] = False 111 | except: 112 | pass 113 | pbar.update(1) 114 | 115 | pred = caption_list 116 | report = caption_tokens_list 117 | preds_orign = [] 118 | preds = [] 119 | reports = [] 120 | for preds_sentence in pred: 121 | single_sentence = list() 122 | for item in preds_sentence: 123 | single_sentence.append(item) 124 | if item == 2: 125 | preds_orign.append(single_sentence) 126 | continue 127 | for preds_sentence in pred: 128 | preds.append([item for item in preds_sentence if item not in [config.start_token, config.end_token, 0]]) 129 | for reports_sentence in report: 130 | reports.append([item for item in reports_sentence if item not in [config.start_token, config.end_token, 0]]) 131 | ground_truth = [tokenizer.decode(item) for item in reports] 132 | pred_result = [tokenizer.decode(item) for item in preds] 133 | val_met = compute_scores({i: [gt] for i, gt in enumerate(ground_truth)}, 134 | {i: [re] for i, re in enumerate(pred_result)}) 135 | return val_met 136 | -------------------------------------------------------------------------------- /ADM/adm_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | class ActivationsAndGradients: 6 | def __init__(self, model, target_layers, reshape_transform): 7 | self.model = model 8 | self.gradients = [] 9 | 
self.activations = [] 10 | self.reshape_transform = reshape_transform 11 | self.handles = [] 12 | for target_layer in target_layers: 13 | self.handles.append( 14 | target_layer.register_forward_hook( 15 | self.save_activation)) 16 | if hasattr(target_layer, 'register_full_backward_hook'): 17 | self.handles.append( 18 | target_layer.register_full_backward_hook( 19 | self.save_gradient)) 20 | else: 21 | self.handles.append( 22 | target_layer.register_backward_hook( 23 | self.save_gradient)) 24 | 25 | def save_activation(self, module, input, output): 26 | activation = output 27 | if self.reshape_transform is not None: 28 | activation = self.reshape_transform(activation) 29 | self.activations.append(activation.cpu().detach()) 30 | 31 | def save_gradient(self, module, grad_input, grad_output): 32 | grad = grad_output[0] 33 | if self.reshape_transform is not None: 34 | grad = self.reshape_transform(grad) 35 | self.gradients = [grad.cpu().detach()] + self.gradients 36 | 37 | def __call__(self, x): 38 | self.gradients = [] 39 | self.activations = [] 40 | return self.model(x) 41 | 42 | def release(self): 43 | for handle in self.handles: 44 | handle.remove() 45 | 46 | 47 | class GradCAM: 48 | def __init__(self, 49 | model, 50 | target_layers, 51 | reshape_transform=None, 52 | use_cuda=False): 53 | self.model = model.eval() 54 | self.target_layers = target_layers 55 | self.reshape_transform = reshape_transform 56 | self.cuda = use_cuda 57 | if self.cuda: 58 | self.model = model.cuda() 59 | self.activations_and_grads = ActivationsAndGradients( 60 | self.model, target_layers, reshape_transform) 61 | 62 | @staticmethod 63 | def get_cam_weights(grads): 64 | return np.mean(grads, axis=(2, 3), keepdims=True) 65 | 66 | @staticmethod 67 | def get_loss(output, target_category): 68 | loss = 0 69 | for i in range(len(target_category)): 70 | loss = loss + output[i, target_category[i]] 71 | return loss 72 | 73 | def get_cam_image(self, activations, grads): 74 | weights = self.get_cam_weights(grads) 75 | weighted_activations = weights * activations 76 | cam = weighted_activations.sum(axis=1) 77 | 78 | return cam 79 | 80 | @staticmethod 81 | def get_target_width_height(input_tensor): 82 | width, height = input_tensor.size(-1), input_tensor.size(-2) 83 | return width, height 84 | 85 | def compute_cam_per_layer(self, input_tensor): 86 | activations_list = [a.cpu().data.numpy() 87 | for a in self.activations_and_grads.activations] 88 | grads_list = [g.cpu().data.numpy() 89 | for g in self.activations_and_grads.gradients] 90 | target_size = self.get_target_width_height(input_tensor) 91 | 92 | cam_per_target_layer = [] 93 | 94 | for layer_activations, layer_grads in zip(activations_list, grads_list): 95 | cam = self.get_cam_image(layer_activations, layer_grads) 96 | cam[cam < 0] = 0 97 | scaled = self.scale_cam_image(cam, target_size) 98 | cam_per_target_layer.append(scaled[:, None, :]) 99 | 100 | return cam_per_target_layer 101 | 102 | def aggregate_multi_layers(self, cam_per_target_layer): 103 | cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1) 104 | cam_per_target_layer = np.maximum(cam_per_target_layer, 0) 105 | result = np.mean(cam_per_target_layer, axis=1) 106 | return self.scale_cam_image(result) 107 | 108 | @staticmethod 109 | def scale_cam_image(cam, target_size=None): 110 | result = [] 111 | for img in cam: 112 | img = img - np.min(img) 113 | img = img / (1e-7 + np.max(img)) 114 | if target_size is not None: 115 | img = cv2.resize(img, target_size) 116 | result.append(img) 117 | result = 
np.float32(result) 118 | 119 | return result 120 | 121 | def __call__(self, input_tensor, target_category=None): 122 | 123 | if self.cuda: 124 | input_tensor = input_tensor.cuda() 125 | 126 | output = self.activations_and_grads(input_tensor) 127 | if isinstance(target_category, int): 128 | target_category = [target_category] * input_tensor.size(0) 129 | 130 | if target_category is None: 131 | target_category = np.argmax(output.cpu().data.numpy(), axis=-1) 132 | print(f"category id: {target_category}") 133 | else: 134 | assert (len(target_category) == input_tensor.size(0)) 135 | 136 | self.model.zero_grad() 137 | loss = self.get_loss(output, target_category) 138 | loss.backward(retain_graph=True) 139 | 140 | cam_per_layer = self.compute_cam_per_layer(input_tensor) 141 | return self.aggregate_multi_layers(cam_per_layer) 142 | 143 | def __del__(self): 144 | self.activations_and_grads.release() 145 | 146 | def __enter__(self): 147 | return self 148 | 149 | def __exit__(self, exc_type, exc_value, exc_tb): 150 | self.activations_and_grads.release() 151 | if isinstance(exc_value, IndexError): 152 | print( 153 | f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}") 154 | return True 155 | 156 | 157 | def show_cam_on_image(img: np.ndarray, 158 | mask: np.ndarray, 159 | use_rgb: bool = False, 160 | colormap: int = cv2.COLORMAP_JET) -> np.ndarray: 161 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap) 162 | if use_rgb: 163 | heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) 164 | heatmap = np.float32(heatmap) / 255 165 | 166 | if np.max(img) > 1: 167 | raise Exception( 168 | "The input image should np.float32 in the range [0, 1]") 169 | 170 | cam = heatmap + img 171 | cam = cam / np.max(cam) 172 | return np.uint8(255 * cam), heatmap 173 | 174 | 175 | def center_crop_img(img: np.ndarray, size: int): 176 | h, w, c = img.shape 177 | 178 | if w == h == size: 179 | return img 180 | 181 | if w < h: 182 | ratio = size / w 183 | new_w = size 184 | new_h = int(h * ratio) 185 | else: 186 | ratio = size / h 187 | new_h = size 188 | new_w = int(w * ratio) 189 | 190 | img = cv2.resize(img, dsize=(new_w, new_h)) 191 | 192 | if new_w == size: 193 | h = (new_h - size) // 2 194 | img = img[h: h + size] 195 | else: 196 | w = (new_w - size) // 2 197 | img = img[:, w: w + size] 198 | 199 | return img 200 | -------------------------------------------------------------------------------- /datasets/xray.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torchvision.transforms.functional as TF 3 | import torchvision as tv 4 | import os 5 | import torch 6 | import random 7 | import numpy as np 8 | from PIL import Image 9 | from .tokenizers import Tokenizer 10 | from .utils import nested_tensor_from_tensor_list, read_json 11 | 12 | 13 | class RandomRotation: 14 | def __init__(self, angles=[0, 90, 180, 270]): 15 | self.angles = angles 16 | 17 | def __call__(self, x): 18 | angle = random.choice(self.angles) 19 | return TF.rotate(x, angle, expand=True) 20 | 21 | 22 | def get_transform(MAX_DIM): 23 | def under_max(image): 24 | if image.mode != 'RGB': 25 | image = image.convert("RGB") 26 | 27 | shape = np.array(image.size, dtype=np.float) 28 | long_dim = max(shape) 29 | scale = MAX_DIM / long_dim 30 | 31 | new_shape = (shape * scale).astype(int) 32 | image = image.resize(new_shape) 33 | 34 | return image 35 | 36 | train_transform = tv.transforms.Compose([ 37 | RandomRotation(), 38 | 
tv.transforms.Lambda(under_max), 39 | tv.transforms.ColorJitter(brightness=[0.5, 1.3], contrast=[ 40 | 0.8, 1.5], saturation=[0.2, 1.5]), 41 | tv.transforms.RandomHorizontalFlip(), 42 | tv.transforms.ToTensor(), 43 | tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 44 | ]) 45 | 46 | val_transform = tv.transforms.Compose([ 47 | tv.transforms.Lambda(under_max), 48 | tv.transforms.ToTensor(), 49 | tv.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 50 | ]) 51 | return train_transform, val_transform 52 | 53 | 54 | transform_class = tv.transforms.Compose([ 55 | tv.transforms.Resize(224), 56 | tv.transforms.CenterCrop((224, 224)), 57 | tv.transforms.ToTensor(), 58 | tv.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 59 | ]) 60 | 61 | 62 | class XrayDataset(Dataset): 63 | def __init__(self, root, ann, max_length, limit, transform=None, transform_class=transform_class, 64 | mode='training', data_dir=None, dataset_name=None, image_size=None, 65 | theta=None, gamma=None, beta=None): 66 | super().__init__() 67 | 68 | self.root = root 69 | self.transform = transform 70 | self.transform_class = transform_class 71 | self.annot = ann 72 | 73 | self.data_dir = data_dir 74 | self.image_size = image_size 75 | 76 | self.theta = theta 77 | self.gamma = gamma 78 | self.beta = beta 79 | 80 | if mode == 'training': 81 | self.annot = self.annot[:] 82 | else: 83 | self.annot = self.annot[:] 84 | if dataset_name == "mimic_cxr": 85 | threshold = 10 86 | elif dataset_name == "iu_xray": 87 | threshold = 3 88 | self.data_name = dataset_name 89 | self.tokenizer = Tokenizer(ann_path=root, threshold=threshold, dataset_name=dataset_name) 90 | self.max_length = max_length + 1 91 | 92 | def _process(self, image_id): 93 | val = str(image_id).zfill(12) 94 | return val + '.jpg' 95 | 96 | def __len__(self): 97 | return len(self.annot) 98 | 99 | def __getitem__(self, idx): 100 | caption = self.annot[idx]["report"] 101 | image_path = self.annot[idx]['image_path'] 102 | image = Image.open(os.path.join(self.data_dir, image_path[0])).resize((300, 300)).convert('RGB') 103 | class_image = image 104 | com_image = image 105 | 106 | if self.data_name == "mimic_cxr": 107 | mask_arr = np.load(os.path.join(self.data_dir.strip("images300"), "images300_array", 108 | image_path[0].replace(".jpg", ".npy"))) 109 | else: 110 | mask_arr = np.load(os.path.join(self.data_dir.strip("images"), "images300_array", 111 | image_path[0].replace(".png", ".npy"))) 112 | 113 | if (np.sum(mask_arr) / 90000) > self.theta: 114 | image_arr = np.asarray(image) 115 | boost_arr = image_arr * np.expand_dims(mask_arr, 2) 116 | weak_arr = image_arr * np.expand_dims(1 - mask_arr, 2) 117 | image = Image.fromarray(boost_arr + (weak_arr * self.gamma).astype(np.uint8)) 118 | 119 | if self.transform: 120 | image = self.transform(image) 121 | com_image = self.transform(com_image) 122 | image = nested_tensor_from_tensor_list(image.unsqueeze(0), max_dim=self.image_size) 123 | com_image = nested_tensor_from_tensor_list(com_image.unsqueeze(0), max_dim=self.image_size) 124 | 125 | if self.transform_class: 126 | class_image = self.transform_class(class_image) 127 | 128 | caption = self.tokenizer(caption)[:self.max_length] 129 | cap_mask = [1] * len(caption) 130 | return image.tensors.squeeze(0), image.mask.squeeze(0), com_image.tensors.squeeze(0), com_image.mask.squeeze( 131 | 0), caption, cap_mask, class_image 132 | 133 | @staticmethod 134 | def collate_fn(data): 135 | max_length = 129 136 | image_batch, 
image_mask_batch, com_image_batch, com_image_mask_batch, report_ids_batch, report_masks_batch, class_image_batch = zip( 137 | *data) 138 | image_batch = torch.stack(image_batch, 0) 139 | image_mask_batch = torch.stack(image_mask_batch, 0) 140 | com_image_batch = torch.stack(com_image_batch, 0) 141 | com_image_mask_batch = torch.stack(com_image_mask_batch, 0) 142 | class_image_batch = torch.stack(class_image_batch, 0) 143 | target_batch = np.zeros((len(report_ids_batch), max_length), dtype=int) 144 | target_masks_batch = np.zeros((len(report_ids_batch), max_length), dtype=int) 145 | 146 | for i, report_ids in enumerate(report_ids_batch): 147 | target_batch[i, :len(report_ids)] = report_ids 148 | 149 | for i, report_masks in enumerate(report_masks_batch): 150 | target_masks_batch[i, :len(report_masks)] = report_masks 151 | target_masks_batch = 1 - target_masks_batch 152 | 153 | return image_batch, image_mask_batch, com_image_batch, com_image_mask_batch, torch.tensor( 154 | target_batch), torch.tensor(target_masks_batch, dtype=torch.bool), class_image_batch 155 | 156 | 157 | def build_dataset(config, mode='training', anno_path=None, data_dir=None, dataset_name=None, image_size=None, 158 | theta=None, gamma=None, beta=None): 159 | train_transform, val_transform = get_transform(MAX_DIM=image_size) 160 | if mode == 'training': 161 | train_file = anno_path 162 | data = XrayDataset(train_file, read_json( 163 | train_file)["train"], max_length=config.max_position_embeddings, limit=config.limit, 164 | transform=train_transform, 165 | mode='training', data_dir=data_dir, dataset_name=dataset_name, image_size=image_size, 166 | theta=theta, gamma=gamma, beta=beta) 167 | return data 168 | 169 | elif mode == 'validation': 170 | val_file = anno_path 171 | data = XrayDataset(val_file, read_json( 172 | val_file)["val"], max_length=config.max_position_embeddings, limit=config.limit, transform=val_transform, 173 | mode='validation', data_dir=data_dir, dataset_name=dataset_name, image_size=image_size, 174 | theta=theta, gamma=gamma, beta=beta) 175 | return data 176 | elif mode == 'test': 177 | test_file = anno_path 178 | data = XrayDataset(test_file, read_json( 179 | test_file)["test"], max_length=config.max_position_embeddings, limit=config.limit, transform=val_transform, 180 | mode='test', data_dir=data_dir, dataset_name=dataset_name, image_size=image_size, 181 | theta=theta, gamma=gamma, beta=beta) 182 | return data 183 | else: 184 | raise NotImplementedError(f"{mode} not supported") 185 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _openmp_mutex=5.1=1_gnu 6 | absl-py=1.3.0=pypi_0 7 | aiohttp=3.8.4=pypi_0 8 | aiosignal=1.3.1=pypi_0 9 | albumentations=1.3.0=pypi_0 10 | allennlp=2.8.0=pypi_0 11 | anyio=3.6.2=pypi_0 12 | argon2-cffi=21.3.0=pypi_0 13 | argon2-cffi-bindings=21.2.0=pypi_0 14 | arrow=1.2.3=pypi_0 15 | asttokens=2.2.1=pypi_0 16 | async-timeout=4.0.2=pypi_0 17 | attrs=22.2.0=pypi_0 18 | autocommand=2.2.2=pypi_0 19 | av=10.0.0=pypi_0 20 | backcall=0.2.0=pypi_0 21 | backports-csv=1.0.7=pypi_0 22 | base58=2.1.1=pypi_0 23 | beautifulsoup4=4.11.1=pypi_0 24 | bleach=6.0.0=pypi_0 25 | blessed=1.20.0=pypi_0 26 | blis=0.7.9=pypi_0 27 | boto3=1.26.79=pypi_0 28 | botocore=1.29.79=pypi_0 29 | ca-certificates=2022.07.19=h06a4308_0 
30 | cached-path=0.3.2=pypi_0 31 | cachetools=4.2.4=pypi_0 32 | catalogue=2.0.8=pypi_0 33 | certifi=2022.9.24=py38h06a4308_0 34 | cffi=1.15.1=pypi_0 35 | charset-normalizer=2.1.1=pypi_0 36 | checklist=0.0.11=pypi_0 37 | cheroot=9.0.0=pypi_0 38 | cherrypy=18.8.0=pypi_0 39 | click=8.1.3=pypi_0 40 | comm=0.1.3=pypi_0 41 | contourpy=1.0.7=pypi_0 42 | cryptography=40.0.1=pypi_0 43 | cycler=0.11.0=pypi_0 44 | cymem=2.0.7=pypi_0 45 | dassl=0.6.3=dev_0 46 | datasets=1.18.4=pypi_0 47 | debugpy=1.6.6=pypi_0 48 | decorator=5.1.1=pypi_0 49 | defusedxml=0.7.1=pypi_0 50 | dill=0.3.6=pypi_0 51 | docker-pycreds=0.4.0=pypi_0 52 | emoji=2.2.0=pypi_0 53 | entrypoints=0.3=pypi_0 54 | et-xmlfile=1.1.0=pypi_0 55 | exceptiongroup=1.1.1=pypi_0 56 | executing=1.2.0=pypi_0 57 | fairscale=0.4.0=pypi_0 58 | fastjsonschema=2.16.3=pypi_0 59 | feedparser=6.0.10=pypi_0 60 | filelock=3.3.2=pypi_0 61 | flake8=3.7.9=pypi_0 62 | fonttools=4.38.0=pypi_0 63 | fqdn=1.5.1=pypi_0 64 | frozenlist=1.3.3=pypi_0 65 | fsspec=2023.3.0=pypi_0 66 | ftfy=6.1.1=pypi_0 67 | future=0.18.2=pypi_0 68 | fvcore=0.1.5.post20221221=pypi_0 69 | gdown=4.5.1=pypi_0 70 | gitdb=4.0.10=pypi_0 71 | gitpython=3.1.31=pypi_0 72 | google-api-core=2.11.0=pypi_0 73 | google-auth=2.17.1=pypi_0 74 | google-auth-oauthlib=0.4.6=pypi_0 75 | google-cloud-core=2.3.2=pypi_0 76 | google-cloud-storage=1.44.0=pypi_0 77 | google-crc32c=1.5.0=pypi_0 78 | google-resumable-media=2.4.1=pypi_0 79 | googleapis-common-protos=1.59.0=pypi_0 80 | gpustat=1.1.1=pypi_0 81 | grpcio=1.49.1=pypi_0 82 | h5py=3.8.0=pypi_0 83 | huggingface-hub=0.1.2=pypi_0 84 | idna=3.4=pypi_0 85 | imageio=2.22.4=pypi_0 86 | importlib-metadata=5.0.0=pypi_0 87 | importlib-resources=5.12.0=pypi_0 88 | inflect=6.0.1=pypi_0 89 | iniconfig=2.0.0=pypi_0 90 | install=1.3.5=pypi_0 91 | iopath=0.1.10=pypi_0 92 | ipykernel=6.22.0=pypi_0 93 | ipython=8.11.0=pypi_0 94 | ipython-genutils=0.2.0=pypi_0 95 | ipywidgets=8.0.6=pypi_0 96 | iso-639=0.4.5=pypi_0 97 | isoduration=20.11.0=pypi_0 98 | isort=4.3.21=pypi_0 99 | jaraco-collections=4.0.0=pypi_0 100 | jaraco-context=4.3.0=pypi_0 101 | jaraco-functools=3.6.0=pypi_0 102 | jaraco-text=3.11.1=pypi_0 103 | jedi=0.18.2=pypi_0 104 | jinja2=3.1.2=pypi_0 105 | jmespath=1.0.1=pypi_0 106 | joblib=1.2.0=pypi_0 107 | jsonnet=0.19.1=pypi_0 108 | jsonpointer=2.3=pypi_0 109 | jsonschema=4.17.3=pypi_0 110 | jupyter=1.0.0=pypi_0 111 | jupyter-client=8.1.0=pypi_0 112 | jupyter-console=6.6.3=pypi_0 113 | jupyter-core=5.3.0=pypi_0 114 | jupyter-events=0.6.3=pypi_0 115 | jupyter-server=2.5.0=pypi_0 116 | jupyter-server-terminals=0.4.4=pypi_0 117 | jupyterlab-pygments=0.2.2=pypi_0 118 | jupyterlab-widgets=3.0.7=pypi_0 119 | kiwisolver=1.4.4=pypi_0 120 | ld_impl_linux-64=2.38=h1181459_1 121 | libffi=3.3=he6710b0_2 122 | libgcc-ng=11.2.0=h1234567_1 123 | libgomp=11.2.0=h1234567_1 124 | libstdcxx-ng=11.2.0=h1234567_1 125 | littleutils=0.2.2=pypi_0 126 | lmdb=1.3.0=pypi_0 127 | loguru=0.6.0=pypi_0 128 | lxml=4.9.2=pypi_0 129 | markdown=3.4.1=pypi_0 130 | markupsafe=2.1.1=pypi_0 131 | matplotlib=3.7.0=pypi_0 132 | matplotlib-inline=0.1.6=pypi_0 133 | maxflow=0.0.1=pypi_0 134 | mccabe=0.6.1=pypi_0 135 | mistune=2.0.5=pypi_0 136 | more-itertools=9.1.0=pypi_0 137 | multidict=6.0.4=pypi_0 138 | multiprocess=0.70.14=pypi_0 139 | munch=2.5.0=pypi_0 140 | murmurhash=1.0.9=pypi_0 141 | nbclassic=0.5.4=pypi_0 142 | nbclient=0.7.2=pypi_0 143 | nbconvert=7.2.10=pypi_0 144 | nbformat=5.8.0=pypi_0 145 | ncurses=6.3=h5eee18b_3 146 | nest-asyncio=1.5.6=pypi_0 147 | networkx=2.8.8=pypi_0 148 | nltk=3.8.1=pypi_0 
149 | notebook=6.5.3=pypi_0 150 | notebook-shim=0.2.2=pypi_0 151 | numpy=1.23.4=pypi_0 152 | nvidia-ml-py=12.535.133=pypi_0 153 | oauthlib=3.2.1=pypi_0 154 | ogb=1.3.4=pypi_0 155 | opencv-python=4.2.0.34=pypi_0 156 | opencv-python-headless=4.6.0.66=pypi_0 157 | openpyxl=3.1.1=pypi_0 158 | openssl=1.1.1q=h7f8727e_0 159 | outdated=0.2.1=pypi_0 160 | overrides=3.1.0=pypi_0 161 | packaging=21.3=pypi_0 162 | pandas=1.5.0=pypi_0 163 | pandocfilters=1.5.0=pypi_0 164 | parameterized=0.8.1=pypi_0 165 | parso=0.8.3=pypi_0 166 | pathtools=0.1.2=pypi_0 167 | pathy=0.10.1=pypi_0 168 | patternfork-nosql=3.6=pypi_0 169 | pdfminer-six=20221105=pypi_0 170 | pexpect=4.8.0=pypi_0 171 | pickleshare=0.7.5=pypi_0 172 | pillow=9.2.0=pypi_0 173 | pip=23.0.1=pypi_0 174 | pkgutil-resolve-name=1.3.10=pypi_0 175 | platformdirs=3.2.0=pypi_0 176 | pluggy=1.0.0=pypi_0 177 | portalocker=2.6.0=pypi_0 178 | portend=3.1.0=pypi_0 179 | preshed=3.0.8=pypi_0 180 | prometheus-client=0.16.0=pypi_0 181 | promise=2.3=pypi_0 182 | prompt-toolkit=3.0.38=pypi_0 183 | protobuf=3.20.3=pypi_0 184 | psutil=5.9.4=pypi_0 185 | ptyprocess=0.7.0=pypi_0 186 | pure-eval=0.2.2=pypi_0 187 | pyarrow=11.0.0=pypi_0 188 | pyasn1=0.4.8=pypi_0 189 | pyasn1-modules=0.2.8=pypi_0 190 | pycocoevalcap=1.2=pypi_0 191 | pycocotools=2.0.7=pypi_0 192 | pycodestyle=2.5.0=pypi_0 193 | pycparser=2.21=pypi_0 194 | pydantic=1.8.2=pypi_0 195 | pyflakes=2.1.1=pypi_0 196 | pygments=2.14.0=pypi_0 197 | pymaxflow=1.3.0=pypi_0 198 | pyparsing=3.0.9=pypi_0 199 | pyrsistent=0.19.3=pypi_0 200 | pysocks=1.7.1=pypi_0 201 | pytest=7.2.2=pypi_0 202 | python=3.8.13=h12debd9_0 203 | python-dateutil=2.8.2=pypi_0 204 | python-docx=0.8.11=pypi_0 205 | python-json-logger=2.0.7=pypi_0 206 | pytorchvideo=0.1.5=pypi_0 207 | pytz=2022.4=pypi_0 208 | pywavelets=1.4.1=pypi_0 209 | pyyaml=6.0=pypi_0 210 | pyzmq=25.0.2=pypi_0 211 | qtconsole=5.4.1=pypi_0 212 | qtpy=2.3.1=pypi_0 213 | qudida=0.0.4=pypi_0 214 | readline=8.1.2=h7f8727e_1 215 | regex=2022.9.13=pypi_0 216 | requests=2.28.1=pypi_0 217 | requests-oauthlib=1.3.1=pypi_0 218 | responses=0.18.0=pypi_0 219 | rfc3339-validator=0.1.4=pypi_0 220 | rfc3986-validator=0.1.1=pypi_0 221 | rsa=4.9=pypi_0 222 | s3transfer=0.6.0=pypi_0 223 | sacremoses=0.0.53=pypi_0 224 | scikit-image=0.19.3=pypi_0 225 | scikit-learn=1.1.2=pypi_0 226 | scipy=1.9.2=pypi_0 227 | seaborn=0.12.2=pypi_0 228 | send2trash=1.8.0=pypi_0 229 | sentencepiece=0.1.97=pypi_0 230 | sentry-sdk=1.18.0=pypi_0 231 | setproctitle=1.3.2=pypi_0 232 | setuptools=67.6.0=pypi_0 233 | sgmllib3k=1.0.0=pypi_0 234 | shortuuid=1.0.11=pypi_0 235 | simplecrf=0.2.1.1=pypi_0 236 | six=1.16.0=pypi_0 237 | sk-video=1.1.10=pypi_0 238 | smart-open=6.3.0=pypi_0 239 | smmap=5.0.0=pypi_0 240 | sniffio=1.3.0=pypi_0 241 | soupsieve=2.3.2.post1=pypi_0 242 | spacy=3.1.7=pypi_0 243 | spacy-legacy=3.0.12=pypi_0 244 | sqlite=3.39.3=h5082296_0 245 | sqlitedict=2.1.0=pypi_0 246 | srsly=2.4.6=pypi_0 247 | stack-data=0.6.2=pypi_0 248 | stanza=1.5.0=pypi_0 249 | tabulate=0.9.0=pypi_0 250 | tb-nightly=2.11.0a20221013=pypi_0 251 | tempora=5.2.1=pypi_0 252 | tensorboard=2.4.1=pypi_0 253 | tensorboard-data-server=0.6.1=pypi_0 254 | tensorboard-logger=0.1.0=pypi_0 255 | tensorboard-plugin-wit=1.8.1=pypi_0 256 | tensorboardx=2.6=pypi_0 257 | termcolor=1.1.0=pypi_0 258 | terminado=0.17.1=pypi_0 259 | thinc=8.0.17=pypi_0 260 | threadpoolctl=3.1.0=pypi_0 261 | tifffile=2022.10.10=pypi_0 262 | timm=0.6.12=pypi_0 263 | tinycss2=1.2.1=pypi_0 264 | tk=8.6.12=h1ccaba5_0 265 | tokenizers=0.10.3=pypi_0 266 | tomli=2.0.1=pypi_0 267 
| torch=1.9.0+cu111=pypi_0 268 | torchaudio=0.9.0=pypi_0 269 | torchtext=0.5.0=pypi_0 270 | torchvision=0.10.0+cu111=pypi_0 271 | tornado=6.2=pypi_0 272 | tqdm=4.62.3=pypi_0 273 | traitlets=5.9.0=pypi_0 274 | transformers=4.12.5=pypi_0 275 | typer=0.4.2=pypi_0 276 | typing-extensions=4.4.0=pypi_0 277 | uri-template=1.2.0=pypi_0 278 | urllib3=1.26.12=pypi_0 279 | wandb=0.12.21=pypi_0 280 | wasabi=0.10.1=pypi_0 281 | wcwidth=0.2.5=pypi_0 282 | webcolors=1.13=pypi_0 283 | webencodings=0.5.1=pypi_0 284 | websocket-client=1.5.1=pypi_0 285 | werkzeug=2.2.2=pypi_0 286 | wheel=0.40.0=pypi_0 287 | widgetsnbextension=4.0.7=pypi_0 288 | wilds=1.2.2=pypi_0 289 | xxhash=3.2.0=pypi_0 290 | xz=5.2.6=h5eee18b_0 291 | yacs=0.1.8=pypi_0 292 | yapf=0.29.0=pypi_0 293 | yarl=1.8.2=pypi_0 294 | zc-lockfile=3.0.post1=pypi_0 295 | zipp=3.9.0=pypi_0 296 | zlib=1.2.12=h5eee18b_3 297 | -------------------------------------------------------------------------------- /ADM/gradcam_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | class ActivationsAndGradients: 6 | """ Class for extracting activations and 7 | registering gradients from targeted intermediate layers """ 8 | 9 | def __init__(self, model, target_layers, reshape_transform): 10 | self.model = model 11 | self.gradients = [] 12 | self.activations = [] 13 | self.reshape_transform = reshape_transform 14 | self.handles = [] 15 | for target_layer in target_layers: 16 | self.handles.append( 17 | target_layer.register_forward_hook( 18 | self.save_activation)) 19 | # Backward compatibility with older pytorch versions: 20 | if hasattr(target_layer, 'register_full_backward_hook'): 21 | self.handles.append( 22 | target_layer.register_full_backward_hook( 23 | self.save_gradient)) 24 | else: 25 | self.handles.append( 26 | target_layer.register_backward_hook( 27 | self.save_gradient)) 28 | 29 | def save_activation(self, module, input, output): 30 | activation = output 31 | if self.reshape_transform is not None: 32 | activation = self.reshape_transform(activation) 33 | self.activations.append(activation.cpu().detach()) 34 | 35 | def save_gradient(self, module, grad_input, grad_output): 36 | # Gradients are computed in reverse order 37 | grad = grad_output[0] 38 | if self.reshape_transform is not None: 39 | grad = self.reshape_transform(grad) 40 | self.gradients = [grad.cpu().detach()] + self.gradients 41 | 42 | def __call__(self, x): 43 | self.gradients = [] 44 | self.activations = [] 45 | return self.model(x) 46 | 47 | def release(self): 48 | for handle in self.handles: 49 | handle.remove() 50 | 51 | 52 | class GradCAM: 53 | def __init__(self, 54 | model, 55 | target_layers, 56 | reshape_transform=None, 57 | use_cuda=False): 58 | self.model = model.eval() 59 | self.target_layers = target_layers 60 | self.reshape_transform = reshape_transform 61 | self.cuda = use_cuda 62 | if self.cuda: 63 | self.model = model.cuda() 64 | self.activations_and_grads = ActivationsAndGradients( 65 | self.model, target_layers, reshape_transform) 66 | 67 | """ Get a vector of weights for every channel in the target layer. 68 | Methods that return weights channels, 69 | will typically need to only implement this function. 
""" 70 | 71 | @staticmethod 72 | def get_cam_weights(grads): 73 | return np.mean(grads, axis=(2, 3), keepdims=True) 74 | 75 | @staticmethod 76 | def get_loss(output, target_category): 77 | loss = 0 78 | for i in range(len(target_category)): 79 | loss = loss + output[i, target_category[i]] 80 | return loss 81 | 82 | def get_cam_image(self, activations, grads): 83 | weights = self.get_cam_weights(grads) 84 | weighted_activations = weights * activations 85 | cam = weighted_activations.sum(axis=1) 86 | 87 | return cam 88 | 89 | @staticmethod 90 | def get_target_width_height(input_tensor): 91 | width, height = input_tensor.size(-1), input_tensor.size(-2) 92 | return width, height 93 | 94 | def compute_cam_per_layer(self, input_tensor): 95 | activations_list = [a.cpu().data.numpy() 96 | for a in self.activations_and_grads.activations] 97 | grads_list = [g.cpu().data.numpy() 98 | for g in self.activations_and_grads.gradients] 99 | target_size = self.get_target_width_height(input_tensor) 100 | 101 | cam_per_target_layer = [] 102 | # Loop over the saliency image from every layer 103 | 104 | for layer_activations, layer_grads in zip(activations_list, grads_list): 105 | cam = self.get_cam_image(layer_activations, layer_grads) 106 | cam[cam < 0] = 0 # works like mute the min-max scale in the function of scale_cam_image 107 | scaled = self.scale_cam_image(cam, target_size) 108 | cam_per_target_layer.append(scaled[:, None, :]) 109 | 110 | return cam_per_target_layer 111 | 112 | def aggregate_multi_layers(self, cam_per_target_layer): 113 | cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1) 114 | cam_per_target_layer = np.maximum(cam_per_target_layer, 0) 115 | result = np.mean(cam_per_target_layer, axis=1) 116 | return self.scale_cam_image(result) 117 | 118 | @staticmethod 119 | def scale_cam_image(cam, target_size=None): 120 | result = [] 121 | for img in cam: 122 | img = img - np.min(img) 123 | img = img / (1e-7 + np.max(img)) 124 | if target_size is not None: 125 | img = cv2.resize(img, target_size) 126 | result.append(img) 127 | result = np.float32(result) 128 | 129 | return result 130 | 131 | def __call__(self, input_tensor, target_category=None): 132 | 133 | if self.cuda: 134 | input_tensor = input_tensor.cuda() 135 | 136 | output = self.activations_and_grads(input_tensor) 137 | if isinstance(target_category, int): 138 | target_category = [target_category] * input_tensor.size(0) 139 | 140 | if target_category is None: 141 | target_category = np.argmax(output.cpu().data.numpy(), axis=-1) 142 | print(f"category id: {target_category}") 143 | else: 144 | assert (len(target_category) == input_tensor.size(0)) 145 | 146 | self.model.zero_grad() 147 | loss = self.get_loss(output, target_category) 148 | loss.backward(retain_graph=True) 149 | 150 | # In most of the saliency attribution papers, the saliency is 151 | # computed with a single target layer. 152 | # Commonly it is the last convolutional layer. 153 | # Here we support passing a list with multiple target layers. 154 | # It will compute the saliency image for every image, 155 | # and then aggregate them (with a default mean aggregation). 156 | # This gives you more flexibility in case you just want to 157 | # use all conv layers for example, all Batchnorm layers, 158 | # or something else. 
159 | cam_per_layer = self.compute_cam_per_layer(input_tensor) 160 | return self.aggregate_multi_layers(cam_per_layer) 161 | 162 | def __del__(self): 163 | self.activations_and_grads.release() 164 | 165 | def __enter__(self): 166 | return self 167 | 168 | def __exit__(self, exc_type, exc_value, exc_tb): 169 | self.activations_and_grads.release() 170 | if isinstance(exc_value, IndexError): 171 | # Handle IndexError here... 172 | print( 173 | f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}") 174 | return True 175 | 176 | 177 | def show_cam_on_image(img: np.ndarray, 178 | mask: np.ndarray, 179 | use_rgb: bool = False, 180 | colormap: int = cv2.COLORMAP_JET) -> np.ndarray: 181 | """ This function overlays the cam mask on the image as a heatmap. 182 | By default the heatmap is in BGR format. 183 | 184 | :param img: The base image in RGB or BGR format. 185 | :param mask: The cam mask. 186 | :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format. 187 | :param colormap: The OpenCV colormap to be used. 188 | :returns: The image with the cam overlay and the heatmap. 189 | """ 190 | 191 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap) 192 | if use_rgb: 193 | heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) 194 | heatmap = np.float32(heatmap) / 255 195 | 196 | if np.max(img) > 1: 197 | raise Exception( 198 | "The input image should be np.float32 in the range [0, 1]") 199 | 200 | cam = heatmap + img 201 | cam = cam / np.max(cam) 202 | return np.uint8(255 * cam), heatmap 203 | 204 | 205 | def center_crop_img(img: np.ndarray, size: int): 206 | h, w, c = img.shape 207 | 208 | if w == h == size: 209 | return img 210 | 211 | if w < h: 212 | ratio = size / w 213 | new_w = size 214 | new_h = int(h * ratio) 215 | else: 216 | ratio = size / h 217 | new_h = size 218 | new_w = int(w * ratio) 219 | 220 | img = cv2.resize(img, dsize=(new_w, new_h)) 221 | 222 | if new_w == size: 223 | h = (new_h - size) // 2 224 | img = img[h: h + size] 225 | else: 226 | w = (new_w - size) // 2 227 | img = img[:, w: w + size] 228 | 229 | return img 230 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import pickle 4 | import numpy as np 5 | import argparse 6 | import os 7 | from models import utils, caption 8 | from datasets import xray 9 | from utils.engine import train_one_epoch, evaluate 10 | from models.model import swin_tiny_patch4_window7_224 as create_model 11 | from utils.stloss import SoftTarget 12 | 13 | 14 | def build_diagnosisbot(num_classes, detector_weight_path): 15 | model = create_model(num_classes=num_classes) 16 | assert os.path.exists(detector_weight_path), "file: '{}' does not exist.".format(detector_weight_path) 17 | model.load_state_dict(torch.load(detector_weight_path, map_location=torch.device('cpu')), strict=True) 18 | for k, v in model.named_parameters(): 19 | v.requires_grad = False 20 | return model 21 | 22 | 23 | def build_tmodel(config, device): 24 | tmodel, _ = caption.build_model(config) 25 | print("Loading teacher model Checkpoint...") 26 | tcheckpoint = torch.load(config.t_model_weight_path, map_location='cpu') 27 | tmodel.load_state_dict(tcheckpoint['model']) 28 | tmodel.to(device) 29 | return tmodel 30 | 31 | 32 | def main(config): 33 | print(config) 34 | device = torch.device(config.device) 35 | print(f'Initializing Device: 
{device}') 36 | 37 | if os.path.exists(config.thresholds_path): 38 | with open(config.thresholds_path, "rb") as f: 39 | thresholds = pickle.load(f) 40 | 41 | seed = config.seed + utils.get_rank() 42 | torch.manual_seed(seed) 43 | np.random.seed(seed) 44 | 45 | detector = build_diagnosisbot(config.num_classes, config.detector_weight_path) 46 | detector.to(device) 47 | 48 | model, criterion = caption.build_model(config) 49 | criterionKD = SoftTarget(4.0) 50 | model.to(device) 51 | 52 | n_parameters = sum(p.numel() 53 | for p in model.parameters() if p.requires_grad) 54 | print(f"Number of params: {n_parameters}") 55 | 56 | param_dicts = [ 57 | {"params": [p for n, p in model.named_parameters( 58 | ) if "backbone" not in n and p.requires_grad]}, 59 | { 60 | "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], 61 | "lr": config.lr_backbone, 62 | }, 63 | ] 64 | optimizer = torch.optim.AdamW( 65 | param_dicts, lr=config.lr, weight_decay=config.weight_decay) 66 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config.lr_drop) 67 | 68 | dataset_train = xray.build_dataset(config, mode='training', anno_path=config.anno_path, data_dir=config.data_dir, 69 | dataset_name=config.dataset_name, image_size=config.image_size, 70 | theta=config.theta, gamma=config.gamma, beta=config.beta) 71 | dataset_val = xray.build_dataset(config, mode='validation', anno_path=config.anno_path, data_dir=config.data_dir, 72 | dataset_name=config.dataset_name, image_size=config.image_size, 73 | theta=config.theta, gamma=config.gamma, beta=config.beta) 74 | dataset_test = xray.build_dataset(config, mode='test', anno_path=config.anno_path, data_dir=config.data_dir, 75 | dataset_name=config.dataset_name, image_size=config.image_size, 76 | theta=config.theta, gamma=config.gamma, beta=config.beta) 77 | print(f"Train: {len(dataset_train)}") 78 | print(f"Valid: {len(dataset_val)}") 79 | print(f"Test: {len(dataset_test)}") 80 | 81 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 82 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 83 | sampler_test = torch.utils.data.SequentialSampler(dataset_test) 84 | 85 | batch_sampler_train = torch.utils.data.BatchSampler( 86 | sampler_train, config.batch_size, drop_last=True 87 | ) 88 | 89 | data_loader_train = DataLoader( 90 | dataset_train, batch_sampler=batch_sampler_train, num_workers=config.num_workers, 91 | collate_fn=dataset_train.collate_fn) 92 | data_loader_val = DataLoader(dataset_val, config.batch_size, 93 | sampler=sampler_val, drop_last=False, 94 | collate_fn=dataset_val.collate_fn) 95 | 96 | data_loader_test = DataLoader(dataset_test, config.batch_size, 97 | sampler=sampler_test, drop_last=False, 98 | collate_fn=dataset_test.collate_fn) 99 | if config.mode == "train": 100 | tmodel = build_tmodel(config, device) 101 | print("Start Training..") 102 | for epoch in range(config.start_epoch, config.epochs): 103 | print(f"Epoch: {epoch}") 104 | epoch_loss = train_one_epoch( 105 | model, tmodel, detector, criterion, criterionKD, data_loader_train, optimizer, device, 106 | config.clip_max_norm, thresholds=thresholds, tokenizer=dataset_train.tokenizer, config=config) 107 | lr_scheduler.step() 108 | print(f"Training Loss: {epoch_loss}") 109 | 110 | torch.save({ 111 | 'model': model.state_dict(), 112 | 'optimizer': optimizer.state_dict(), 113 | 'lr_scheduler': lr_scheduler.state_dict(), 114 | 'epoch': epoch, 115 | }, config.dataset_name + "_weight_epoch" + str(epoch) + "_.pth") 116 | 117 | validate_result = 
evaluate(model, detector, criterion, data_loader_val, device, config, 118 | thresholds=thresholds, tokenizer=dataset_val.tokenizer) 119 | print(f"validate_result: {validate_result}") 120 | test_result = evaluate(model, detector, criterion, data_loader_test, device, config, 121 | thresholds=thresholds, tokenizer=dataset_test.tokenizer) 122 | print(f"test_result: {test_result}") 123 | if config.mode == "test": 124 | if os.path.exists(config.test_path): 125 | weights_dict = torch.load(config.test_path, map_location='cpu')['model'] 126 | model.load_state_dict(weights_dict, strict=False) 127 | 128 | print("Start Testing..") 129 | test_result = evaluate(model, detector, criterion, data_loader_test, device, config, 130 | thresholds=thresholds, tokenizer=dataset_test.tokenizer) 131 | print(f"test_result: {test_result}") 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser() 136 | 137 | parser.add_argument('--epochs', type=int, default=50) 138 | parser.add_argument('--lr_drop', type=int, default=20) 139 | parser.add_argument('--start_epoch', type=int, default=0) 140 | parser.add_argument('--weight_decay', type=float, default=1e-4) 141 | 142 | # Backbone 143 | parser.add_argument('--backbone', type=str, default='resnet101') 144 | parser.add_argument('--position_embedding', type=str, default='sine') 145 | parser.add_argument('--dilation', type=bool, default=True) 146 | # Basic 147 | parser.add_argument('--lr_backbone', type=float, default=1e-5) 148 | parser.add_argument('--lr', type=float, default=1e-4) 149 | parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)') 150 | parser.add_argument('--seed', type=int, default=42) 151 | parser.add_argument('--batch_size', type=int, default=16) 152 | parser.add_argument('--num_workers', type=int, default=8) 153 | parser.add_argument('--clip_max_norm', type=float, default=0.1) 154 | 155 | # Transformer 156 | parser.add_argument('--hidden_dim', type=int, default=256) 157 | parser.add_argument('--pad_token_id', type=int, default=0) 158 | parser.add_argument('--max_position_embeddings', type=int, default=128) 159 | parser.add_argument('--layer_norm_eps', type=float, default=1e-12) 160 | parser.add_argument('--dropout', type=float, default=0.1) 161 | parser.add_argument('--vocab_size', type=int, default=4253) 162 | parser.add_argument('--start_token', type=int, default=1) 163 | parser.add_argument('--end_token', type=int, default=2) 164 | 165 | parser.add_argument('--enc_layers', type=int, default=6) 166 | parser.add_argument('--dec_layers', type=int, default=6) 167 | parser.add_argument('--dim_feedforward', type=int, default=2048) 168 | parser.add_argument('--nheads', type=int, default=8) 169 | parser.add_argument('--pre_norm', type=int, default=True) 170 | 171 | # diagnosisbot 172 | parser.add_argument('--num_classes', type=int, default=14) 173 | parser.add_argument('--thresholds_path', type=str, default="./datasets/thresholds.pkl") 174 | parser.add_argument('--detector_weight_path', type=str, default="./weight_path/diagnosisbot.pth") 175 | parser.add_argument('--t_model_weight_path', type=str, default="./weight_path/mimic_t_model.pth") 176 | parser.add_argument('--knowledge_prompt_path', type=str, default="./knowledge_path/knowledge_prompt_mimic.pkl") 177 | 178 | # ADA 179 | parser.add_argument('--theta', type=float, default=0.4) 180 | parser.add_argument('--gamma', type=float, default=0.4) 181 | parser.add_argument('--beta', type=float, default=1.0) 182 | 183 | # Delta 184 | 
parser.add_argument('--delta', type=float, default=0.01) 185 | 186 | # Dataset 187 | parser.add_argument('--image_size', type=int, default=300) 188 | parser.add_argument('--dataset_name', type=str, default='mimic_cxr') 189 | parser.add_argument('--anno_path', type=str, default='../dataset/mimic_cxr/annotation.json') 190 | parser.add_argument('--data_dir', type=str, default='../dataset/mimic_cxr/images300') 191 | parser.add_argument('--limit', type=int, default=-1) 192 | 193 | # mode 194 | parser.add_argument('--mode', type=str, default="train") 195 | parser.add_argument('--test_path', type=str, default="") 196 | 197 | config = parser.parse_args() 198 | main(config) 199 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
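To make the learning-rate split in main.py above concrete, the sketch below reproduces its optimizer setup with a stand-in model: parameters whose names contain "backbone" train with config.lr_backbone (1e-5 in the training scripts), everything else with config.lr (1e-4), and a StepLR scheduler decays both every lr_drop epochs. The DummyCaptioner module and the hard-coded values here are illustrative assumptions, not part of the repository.

import torch
from torch import nn

# Stand-in for the caption model: one "backbone" group and one head group.
class DummyCaptioner(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 16)
        self.head = nn.Linear(16, 4)

model = DummyCaptioner()

# Same grouping rule as main.py: "backbone" parameters get a smaller learning rate.
param_dicts = [
    {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
    {"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
     "lr": 1e-5},  # --lr_backbone
]
optimizer = torch.optim.AdamW(param_dicts, lr=1e-4, weight_decay=1e-4)  # --lr, --weight_decay
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20)  # --lr_drop

Splitting the groups by parameter name keeps the ResNet backbone (default resnet101) on a learning rate ten times smaller than the rest of the caption model.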
202 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, Tensor 6 | from models.utils import get_knowledge 7 | 8 | 9 | class Transformer(nn.Module): 10 | 11 | def __init__(self, config, d_model=512, nhead=8, num_encoder_layers=6, 12 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 13 | activation="relu", normalize_before=False, 14 | return_intermediate_dec=False): 15 | super().__init__() 16 | 17 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 18 | dropout, activation, normalize_before) 19 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 20 | self.encoder = TransformerEncoder( 21 | encoder_layer, num_encoder_layers, encoder_norm) 22 | 23 | self.embeddings = DecoderEmbeddings(config) 24 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 25 | dropout, activation, normalize_before) 26 | decoder_norm = nn.LayerNorm(d_model) 27 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, config, 28 | return_intermediate=return_intermediate_dec) 29 | 30 | self.knowledge_prompt = get_knowledge(config.knowledge_prompt_path) 31 | self._reset_parameters() 32 | 33 | self.d_model = d_model 34 | self.nhead = nhead 35 | 36 | def _reset_parameters(self): 37 | for p in self.parameters(): 38 | if p.dim() > 1: 39 | nn.init.xavier_uniform_(p) 40 | 41 | def forward(self, src, mask, pos_embed, tgt, tgt_mask, class_feature): 42 | bs, c, h, w = src.shape 43 | src = src.flatten(2).permute(2, 0, 1) 44 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 45 | mask = mask.flatten(1) 46 | class_feature = class_feature[0] 47 | batch_key_list = [tuple(item) for item in class_feature] 48 | class_feature = [self.knowledge_prompt[key] for key in batch_key_list] 49 | class_feature = torch.stack(class_feature).to(tgt.device) 50 | 51 | class_feature = class_feature.permute(1, 0, 2).detach() 52 | tgt = self.embeddings(tgt).permute(1, 0, 2) 53 | query_embed = self.embeddings.position_embeddings.weight.unsqueeze(1) 54 | query_embed = query_embed.repeat(1, bs, 1) 55 | 56 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 57 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, tgt_key_padding_mask=tgt_mask, 58 | pos=pos_embed, query_pos=query_embed, 59 | tgt_mask=generate_square_subsequent_mask(len(tgt)).to(tgt.device), 60 | class_feature=class_feature) 61 | return hs 62 | 63 | 64 | class TransformerEncoder(nn.Module): 65 | 66 | def __init__(self, encoder_layer, num_layers, norm=None): 67 | super().__init__() 68 | self.layers = _get_clones(encoder_layer, num_layers) 69 | self.num_layers = num_layers 70 | self.norm = norm 71 | 72 | def forward(self, src, 73 | mask: Optional[Tensor] = None, 74 | src_key_padding_mask: Optional[Tensor] = None, 75 | pos: Optional[Tensor] = None): 76 | output = src 77 | 78 | for layer in self.layers: 79 | output = layer(output, src_mask=mask, 80 | src_key_padding_mask=src_key_padding_mask, pos=pos) 81 | 82 | if self.norm is not None: 83 | output = self.norm(output) 84 | 85 | return output 86 | 87 | 88 | class TransformerDecoder(nn.Module): 89 | 90 | def __init__(self, decoder_layer, num_layers, norm=None, config=None, return_intermediate=False): 91 | super().__init__() 92 | self.layers = _get_clones(decoder_layer, 
num_layers) 93 | self.num_layers = num_layers 94 | self.norm = norm 95 | self.return_intermediate = return_intermediate 96 | self.fc1 = _get_clones(nn.Linear(config.hidden_dim, config.hidden_dim), config.dec_layers) 97 | self.fc2 = _get_clones(nn.Linear(config.hidden_dim, config.hidden_dim), config.dec_layers) 98 | self.fc3 = _get_clones(nn.Linear(config.hidden_dim * 2, config.hidden_dim), config.dec_layers) 99 | 100 | def forward(self, tgt, memory, 101 | tgt_mask: Optional[Tensor] = None, 102 | memory_mask: Optional[Tensor] = None, 103 | tgt_key_padding_mask: Optional[Tensor] = None, 104 | memory_key_padding_mask: Optional[Tensor] = None, 105 | pos: Optional[Tensor] = None, 106 | query_pos: Optional[Tensor] = None, 107 | class_feature=None): 108 | output = tgt 109 | intermediate = [] 110 | 111 | for i, layer in enumerate(self.layers): 112 | output = layer(output, memory, tgt_mask=tgt_mask, 113 | memory_mask=memory_mask, 114 | tgt_key_padding_mask=tgt_key_padding_mask, 115 | memory_key_padding_mask=memory_key_padding_mask, 116 | pos=pos, query_pos=query_pos) 117 | if class_feature is not None: 118 | # pass 119 | output = torch.cat((self.fc1[i](output), self.fc2[i](class_feature)), dim=2) 120 | output = self.fc3[i](output) 121 | if self.return_intermediate: 122 | intermediate.append(self.norm(output)) 123 | 124 | if self.norm is not None: 125 | output = self.norm(output) 126 | if self.return_intermediate: 127 | intermediate.pop() 128 | intermediate.append(output) 129 | 130 | if self.return_intermediate: 131 | return torch.stack(intermediate) 132 | 133 | return output 134 | 135 | 136 | class TransformerEncoderLayer(nn.Module): 137 | 138 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 139 | activation="relu", normalize_before=False): 140 | super().__init__() 141 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 142 | self.linear1 = nn.Linear(d_model, dim_feedforward) 143 | self.dropout = nn.Dropout(dropout) 144 | self.linear2 = nn.Linear(dim_feedforward, d_model) 145 | 146 | self.norm1 = nn.LayerNorm(d_model) 147 | self.norm2 = nn.LayerNorm(d_model) 148 | self.dropout1 = nn.Dropout(dropout) 149 | self.dropout2 = nn.Dropout(dropout) 150 | 151 | self.activation = _get_activation_fn(activation) 152 | self.normalize_before = normalize_before 153 | 154 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 155 | return tensor if pos is None else tensor + pos 156 | 157 | def forward_post(self, 158 | src, 159 | src_mask: Optional[Tensor] = None, 160 | src_key_padding_mask: Optional[Tensor] = None, 161 | pos: Optional[Tensor] = None): 162 | q = k = self.with_pos_embed(src, pos) 163 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 164 | key_padding_mask=src_key_padding_mask)[0] 165 | src = src + self.dropout1(src2) 166 | src = self.norm1(src) 167 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 168 | src = src + self.dropout2(src2) 169 | src = self.norm2(src) 170 | return src 171 | 172 | def forward_pre(self, src, 173 | src_mask: Optional[Tensor] = None, 174 | src_key_padding_mask: Optional[Tensor] = None, 175 | pos: Optional[Tensor] = None): 176 | src2 = self.norm1(src) 177 | q = k = self.with_pos_embed(src2, pos) 178 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 179 | key_padding_mask=src_key_padding_mask)[0] 180 | src = src + self.dropout1(src2) 181 | src2 = self.norm2(src) 182 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 183 | src = src + self.dropout2(src2) 184 | 
return src 185 | 186 | def forward(self, src, 187 | src_mask: Optional[Tensor] = None, 188 | src_key_padding_mask: Optional[Tensor] = None, 189 | pos: Optional[Tensor] = None): 190 | if self.normalize_before: 191 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 192 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 193 | 194 | 195 | class TransformerDecoderLayer(nn.Module): 196 | 197 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 198 | activation="relu", normalize_before=False): 199 | super().__init__() 200 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 201 | self.multihead_attn = nn.MultiheadAttention( 202 | d_model, nhead, dropout=dropout) 203 | self.linear1 = nn.Linear(d_model, dim_feedforward) 204 | self.dropout = nn.Dropout(dropout) 205 | self.linear2 = nn.Linear(dim_feedforward, d_model) 206 | 207 | self.norm1 = nn.LayerNorm(d_model) 208 | self.norm2 = nn.LayerNorm(d_model) 209 | self.norm3 = nn.LayerNorm(d_model) 210 | self.dropout1 = nn.Dropout(dropout) 211 | self.dropout2 = nn.Dropout(dropout) 212 | self.dropout3 = nn.Dropout(dropout) 213 | 214 | self.activation = _get_activation_fn(activation) 215 | self.normalize_before = normalize_before 216 | 217 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 218 | return tensor if pos is None else tensor + pos 219 | 220 | def forward_post(self, tgt, memory, 221 | tgt_mask: Optional[Tensor] = None, 222 | memory_mask: Optional[Tensor] = None, 223 | tgt_key_padding_mask: Optional[Tensor] = None, 224 | memory_key_padding_mask: Optional[Tensor] = None, 225 | pos: Optional[Tensor] = None, 226 | query_pos: Optional[Tensor] = None): 227 | q = k = self.with_pos_embed(tgt, query_pos) 228 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 229 | key_padding_mask=tgt_key_padding_mask)[0] 230 | tgt = tgt + self.dropout1(tgt2) 231 | tgt = self.norm1(tgt) 232 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 233 | key=self.with_pos_embed(memory, pos), 234 | value=memory, attn_mask=memory_mask, 235 | key_padding_mask=memory_key_padding_mask)[0] 236 | tgt = tgt + self.dropout2(tgt2) 237 | tgt = self.norm2(tgt) 238 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 239 | tgt = tgt + self.dropout3(tgt2) 240 | tgt = self.norm3(tgt) 241 | return tgt 242 | 243 | def forward_pre(self, tgt, memory, 244 | tgt_mask: Optional[Tensor] = None, 245 | memory_mask: Optional[Tensor] = None, 246 | tgt_key_padding_mask: Optional[Tensor] = None, 247 | memory_key_padding_mask: Optional[Tensor] = None, 248 | pos: Optional[Tensor] = None, 249 | query_pos: Optional[Tensor] = None): 250 | tgt2 = self.norm1(tgt) 251 | q = k = self.with_pos_embed(tgt2, query_pos) 252 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 253 | key_padding_mask=tgt_key_padding_mask)[0] 254 | tgt = tgt + self.dropout1(tgt2) 255 | tgt2 = self.norm2(tgt) 256 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 257 | key=self.with_pos_embed(memory, pos), 258 | value=memory, attn_mask=memory_mask, 259 | key_padding_mask=memory_key_padding_mask)[0] 260 | tgt = tgt + self.dropout2(tgt2) 261 | tgt2 = self.norm3(tgt) 262 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 263 | tgt = tgt + self.dropout3(tgt2) 264 | return tgt 265 | 266 | def forward(self, tgt, memory, 267 | tgt_mask: Optional[Tensor] = None, 268 | memory_mask: Optional[Tensor] = None, 269 | tgt_key_padding_mask: Optional[Tensor] = None, 270 | 
memory_key_padding_mask: Optional[Tensor] = None, 271 | pos: Optional[Tensor] = None, 272 | query_pos: Optional[Tensor] = None): 273 | if self.normalize_before: 274 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 275 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 276 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 277 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 278 | 279 | 280 | class DecoderEmbeddings(nn.Module): 281 | def __init__(self, config): 282 | super().__init__() 283 | self.word_embeddings = nn.Embedding( 284 | config.vocab_size, config.hidden_dim, padding_idx=config.pad_token_id) 285 | self.position_embeddings = nn.Embedding( 286 | config.max_position_embeddings, config.hidden_dim 287 | ) 288 | 289 | self.LayerNorm = torch.nn.LayerNorm( 290 | config.hidden_dim, eps=config.layer_norm_eps) 291 | self.dropout = nn.Dropout(config.dropout) 292 | 293 | def forward(self, x): 294 | input_shape = x.size() 295 | seq_length = input_shape[1] 296 | device = x.device 297 | 298 | position_ids = torch.arange( 299 | seq_length, dtype=torch.long, device=device) 300 | position_ids = position_ids.unsqueeze(0).expand(input_shape) 301 | 302 | input_embeds = self.word_embeddings(x) 303 | position_embeds = self.position_embeddings(position_ids) 304 | 305 | embeddings = input_embeds + position_embeds 306 | embeddings = self.LayerNorm(embeddings) 307 | embeddings = self.dropout(embeddings) 308 | 309 | return embeddings 310 | 311 | 312 | def _get_clones(module, N): 313 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 314 | 315 | 316 | def _get_activation_fn(activation): 317 | """Return an activation function given a string""" 318 | if activation == "relu": 319 | return F.relu 320 | if activation == "gelu": 321 | return F.gelu 322 | if activation == "glu": 323 | return F.glu 324 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 325 | 326 | 327 | def generate_square_subsequent_mask(sz): 328 | r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). 329 | Unmasked positions are filled with float(0.0). 
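    For illustration, worked by hand from the formula below (sz = 3):

        generate_square_subsequent_mask(3)
        # tensor([[0., -inf, -inf],
        #         [0., 0., -inf],
        #         [0., 0., 0.]])

    i.e. target position i may only attend to positions 0..i.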
330 | """ 331 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 332 | mask = mask.float().masked_fill(mask == 0, float( 333 | '-inf')).masked_fill(mask == 1, float(0.0)) 334 | return mask 335 | 336 | 337 | def build_transformer(config): 338 | return Transformer( 339 | config, 340 | d_model=config.hidden_dim, 341 | dropout=config.dropout, 342 | nhead=config.nheads, 343 | dim_feedforward=config.dim_feedforward, 344 | num_encoder_layers=config.enc_layers, 345 | num_decoder_layers=config.dec_layers, 346 | normalize_before=config.pre_norm, 347 | return_intermediate_dec=False, 348 | ) 349 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | """ Swin Transformer 2 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` 3 | - https://arxiv.org/pdf/2103.14030 4 | 5 | Code/weights from https://github.com/microsoft/Swin-Transformer 6 | 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.utils.checkpoint as checkpoint 13 | import numpy as np 14 | from typing import Optional 15 | 16 | 17 | def drop_path_f(x, drop_prob: float = 0., training: bool = False): 18 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 19 | 20 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 21 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 22 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 23 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 24 | 'survival rate' as the argument. 25 | 26 | """ 27 | if drop_prob == 0. or not training: 28 | return x 29 | keep_prob = 1 - drop_prob 30 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 31 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 32 | random_tensor.floor_() # binarize 33 | output = x.div(keep_prob) * random_tensor 34 | return output 35 | 36 | 37 | class DropPath(nn.Module): 38 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
39 | """ 40 | 41 | def __init__(self, drop_prob=None): 42 | super(DropPath, self).__init__() 43 | self.drop_prob = drop_prob 44 | 45 | def forward(self, x): 46 | return drop_path_f(x, self.drop_prob, self.training) 47 | 48 | 49 | def window_partition(x, window_size: int): 50 | B, H, W, C = x.shape 51 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 52 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 53 | return windows 54 | 55 | 56 | def window_reverse(windows, window_size: int, H: int, W: int): 57 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 58 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 59 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 60 | return x 61 | 62 | 63 | class PatchEmbed(nn.Module): 64 | """ 65 | 2D Image to Patch Embedding 66 | """ 67 | 68 | def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None): 69 | super().__init__() 70 | patch_size = (patch_size, patch_size) 71 | self.patch_size = patch_size 72 | self.in_chans = in_c 73 | self.embed_dim = embed_dim 74 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size) 75 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 76 | 77 | def forward(self, x): 78 | _, _, H, W = x.shape 79 | 80 | pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0) 81 | if pad_input: 82 | x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1], 83 | 0, self.patch_size[0] - H % self.patch_size[0], 84 | 0, 0)) 85 | 86 | x = self.proj(x) 87 | _, _, H, W = x.shape 88 | x = x.flatten(2).transpose(1, 2) 89 | x = self.norm(x) 90 | return x, H, W 91 | 92 | 93 | class PatchMerging(nn.Module): 94 | 95 | def __init__(self, dim, norm_layer=nn.LayerNorm): 96 | super().__init__() 97 | self.dim = dim 98 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 99 | self.norm = norm_layer(4 * dim) 100 | 101 | def forward(self, x, H, W): 102 | """ 103 | x: B, H*W, C 104 | """ 105 | B, L, C = x.shape 106 | assert L == H * W, "input feature has wrong size" 107 | 108 | x = x.view(B, H, W, C) 109 | 110 | pad_input = (H % 2 == 1) or (W % 2 == 1) 111 | if pad_input: 112 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) 113 | 114 | x0 = x[:, 0::2, 0::2, :] 115 | x1 = x[:, 1::2, 0::2, :] 116 | x2 = x[:, 0::2, 1::2, :] 117 | x3 = x[:, 1::2, 1::2, :] 118 | x = torch.cat([x0, x1, x2, x3], -1) 119 | x = x.view(B, -1, 4 * C) 120 | 121 | x = self.norm(x) 122 | x = self.reduction(x) 123 | 124 | return x 125 | 126 | 127 | class Mlp(nn.Module): 128 | 129 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 130 | super().__init__() 131 | out_features = out_features or in_features 132 | hidden_features = hidden_features or in_features 133 | 134 | self.fc1 = nn.Linear(in_features, hidden_features) 135 | self.act = act_layer() 136 | self.drop1 = nn.Dropout(drop) 137 | self.fc2 = nn.Linear(hidden_features, out_features) 138 | self.drop2 = nn.Dropout(drop) 139 | 140 | def forward(self, x): 141 | x = self.fc1(x) 142 | x = self.act(x) 143 | x = self.drop1(x) 144 | x = self.fc2(x) 145 | x = self.drop2(x) 146 | return x 147 | 148 | 149 | class WindowAttention(nn.Module): 150 | 151 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): 152 | 153 | super().__init__() 154 | self.dim = dim 155 | self.window_size = window_size 156 | self.num_heads = num_heads 157 | head_dim = dim // 
num_heads 158 | self.scale = head_dim ** -0.5 159 | 160 | self.relative_position_bias_table = nn.Parameter( 161 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) 162 | 163 | coords_h = torch.arange(self.window_size[0]) 164 | coords_w = torch.arange(self.window_size[1]) 165 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) 166 | coords_flatten = torch.flatten(coords, 1) 167 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 168 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() 169 | relative_coords[:, :, 0] += self.window_size[0] - 1 170 | relative_coords[:, :, 1] += self.window_size[1] - 1 171 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 172 | relative_position_index = relative_coords.sum(-1) 173 | self.register_buffer("relative_position_index", relative_position_index) 174 | 175 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 176 | self.attn_drop = nn.Dropout(attn_drop) 177 | self.proj = nn.Linear(dim, dim) 178 | self.proj_drop = nn.Dropout(proj_drop) 179 | 180 | nn.init.trunc_normal_(self.relative_position_bias_table, std=.02) 181 | self.softmax = nn.Softmax(dim=-1) 182 | 183 | def forward(self, x, mask: Optional[torch.Tensor] = None): 184 | 185 | B_, N, C = x.shape 186 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 187 | q, k, v = qkv.unbind(0) 188 | 189 | q = q * self.scale 190 | attn = (q @ k.transpose(-2, -1)) 191 | 192 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 193 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) 194 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() 195 | attn = attn + relative_position_bias.unsqueeze(0) 196 | 197 | if mask is not None: 198 | nW = mask.shape[0] # num_windows 199 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 200 | attn = attn.view(-1, self.num_heads, N, N) 201 | attn = self.softmax(attn) 202 | else: 203 | attn = self.softmax(attn) 204 | 205 | attn = self.attn_drop(attn) 206 | 207 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 208 | x = self.proj(x) 209 | x = self.proj_drop(x) 210 | return x 211 | 212 | 213 | class SwinTransformerBlock(nn.Module): 214 | r""" Swin Transformer Block. 215 | 216 | Args: 217 | dim (int): Number of input channels. 218 | num_heads (int): Number of attention heads. 219 | window_size (int): Window size. 220 | shift_size (int): Shift size for SW-MSA. 221 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 222 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 223 | drop (float, optional): Dropout rate. Default: 0.0 224 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 225 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 226 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 227 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 228 | """ 229 | 230 | def __init__(self, dim, num_heads, window_size=7, shift_size=0, 231 | mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., 232 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 233 | super().__init__() 234 | self.dim = dim 235 | self.num_heads = num_heads 236 | self.window_size = window_size 237 | self.shift_size = shift_size 238 | self.mlp_ratio = mlp_ratio 239 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 240 | 241 | self.norm1 = norm_layer(dim) 242 | self.attn = WindowAttention( 243 | dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, 244 | attn_drop=attn_drop, proj_drop=drop) 245 | 246 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 247 | self.norm2 = norm_layer(dim) 248 | mlp_hidden_dim = int(dim * mlp_ratio) 249 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 250 | 251 | def forward(self, x, attn_mask): 252 | H, W = self.H, self.W 253 | B, L, C = x.shape 254 | assert L == H * W, "input feature has wrong size" 255 | 256 | shortcut = x 257 | x = self.norm1(x) 258 | x = x.view(B, H, W, C) 259 | 260 | pad_l = pad_t = 0 261 | pad_r = (self.window_size - W % self.window_size) % self.window_size 262 | pad_b = (self.window_size - H % self.window_size) % self.window_size 263 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) 264 | _, Hp, Wp, _ = x.shape 265 | 266 | if self.shift_size > 0: 267 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 268 | else: 269 | shifted_x = x 270 | attn_mask = None 271 | 272 | x_windows = window_partition(shifted_x, self.window_size) 273 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) 274 | 275 | attn_windows = self.attn(x_windows, mask=attn_mask) 276 | 277 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 278 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) 279 | 280 | if self.shift_size > 0: 281 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 282 | else: 283 | x = shifted_x 284 | 285 | if pad_r > 0 or pad_b > 0: 286 | x = x[:, :H, :W, :].contiguous() 287 | 288 | x = x.view(B, H * W, C) 289 | 290 | x = shortcut + self.drop_path(x) 291 | x = x + self.drop_path(self.mlp(self.norm2(x))) 292 | 293 | return x 294 | 295 | 296 | class BasicLayer(nn.Module): 297 | """ 298 | A basic Swin Transformer layer for one stage. 299 | 300 | Args: 301 | dim (int): Number of input channels. 302 | depth (int): Number of blocks. 303 | num_heads (int): Number of attention heads. 304 | window_size (int): Local window size. 305 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 306 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 307 | drop (float, optional): Dropout rate. Default: 0.0 308 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 309 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 310 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 311 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 312 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
313 | """ 314 | 315 | def __init__(self, dim, depth, num_heads, window_size, 316 | mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., 317 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 318 | super().__init__() 319 | self.dim = dim 320 | self.depth = depth 321 | self.window_size = window_size 322 | self.use_checkpoint = use_checkpoint 323 | self.shift_size = window_size // 2 324 | 325 | # build blocks 326 | self.blocks = nn.ModuleList([ 327 | SwinTransformerBlock( 328 | dim=dim, 329 | num_heads=num_heads, 330 | window_size=window_size, 331 | shift_size=0 if (i % 2 == 0) else self.shift_size, 332 | mlp_ratio=mlp_ratio, 333 | qkv_bias=qkv_bias, 334 | drop=drop, 335 | attn_drop=attn_drop, 336 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 337 | norm_layer=norm_layer) 338 | for i in range(depth)]) 339 | 340 | if downsample is not None: 341 | self.downsample = downsample(dim=dim, norm_layer=norm_layer) 342 | else: 343 | self.downsample = None 344 | 345 | def create_mask(self, x, H, W): 346 | Hp = int(np.ceil(H / self.window_size)) * self.window_size 347 | Wp = int(np.ceil(W / self.window_size)) * self.window_size 348 | img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) 349 | h_slices = (slice(0, -self.window_size), 350 | slice(-self.window_size, -self.shift_size), 351 | slice(-self.shift_size, None)) 352 | w_slices = (slice(0, -self.window_size), 353 | slice(-self.window_size, -self.shift_size), 354 | slice(-self.shift_size, None)) 355 | cnt = 0 356 | for h in h_slices: 357 | for w in w_slices: 358 | img_mask[:, h, w, :] = cnt 359 | cnt += 1 360 | 361 | mask_windows = window_partition(img_mask, self.window_size) 362 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 363 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 364 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 365 | return attn_mask 366 | 367 | def forward(self, x, H, W): 368 | attn_mask = self.create_mask(x, H, W) 369 | for blk in self.blocks: 370 | blk.H, blk.W = H, W 371 | if not torch.jit.is_scripting() and self.use_checkpoint: 372 | x = checkpoint.checkpoint(blk, x, attn_mask) 373 | else: 374 | x = blk(x, attn_mask) 375 | if self.downsample is not None: 376 | x = self.downsample(x, H, W) 377 | H, W = (H + 1) // 2, (W + 1) // 2 378 | 379 | return x, H, W 380 | 381 | 382 | class SwinTransformer(nn.Module): 383 | r""" Swin Transformer 384 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 385 | https://arxiv.org/pdf/2103.14030 386 | 387 | Args: 388 | patch_size (int | tuple(int)): Patch size. Default: 4 389 | in_chans (int): Number of input image channels. Default: 3 390 | num_classes (int): Number of classes for classification head. Default: 1000 391 | embed_dim (int): Patch embedding dimension. Default: 96 392 | depths (tuple(int)): Depth of each Swin Transformer layer. 393 | num_heads (tuple(int)): Number of attention heads in different layers. 394 | window_size (int): Window size. Default: 7 395 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 396 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 397 | drop_rate (float): Dropout rate. Default: 0 398 | attn_drop_rate (float): Attention dropout rate. Default: 0 399 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 400 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 
401 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 402 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False 403 | """ 404 | 405 | def __init__(self, patch_size=4, in_chans=3, num_classes=1000, 406 | embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), 407 | window_size=7, mlp_ratio=4., qkv_bias=True, 408 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 409 | norm_layer=nn.LayerNorm, patch_norm=True, 410 | use_checkpoint=False, **kwargs): 411 | super().__init__() 412 | 413 | self.num_classes = num_classes 414 | self.num_layers = len(depths) 415 | self.embed_dim = embed_dim 416 | self.patch_norm = patch_norm 417 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 418 | self.mlp_ratio = mlp_ratio 419 | 420 | self.patch_embed = PatchEmbed( 421 | patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim, 422 | norm_layer=norm_layer if self.patch_norm else None) 423 | self.pos_drop = nn.Dropout(p=drop_rate) 424 | 425 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 426 | 427 | self.layers = nn.ModuleList() 428 | for i_layer in range(self.num_layers): 429 | layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 430 | depth=depths[i_layer], 431 | num_heads=num_heads[i_layer], 432 | window_size=window_size, 433 | mlp_ratio=self.mlp_ratio, 434 | qkv_bias=qkv_bias, 435 | drop=drop_rate, 436 | attn_drop=attn_drop_rate, 437 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 438 | norm_layer=norm_layer, 439 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 440 | use_checkpoint=use_checkpoint) 441 | self.layers.append(layers) 442 | 443 | self.norm = norm_layer(self.num_features) 444 | self.avgpool = nn.AdaptiveAvgPool1d(1) 445 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 446 | 447 | self.apply(self._init_weights) 448 | 449 | def _init_weights(self, m): 450 | if isinstance(m, nn.Linear): 451 | nn.init.trunc_normal_(m.weight, std=.02) 452 | if isinstance(m, nn.Linear) and m.bias is not None: 453 | nn.init.constant_(m.bias, 0) 454 | elif isinstance(m, nn.LayerNorm): 455 | nn.init.constant_(m.bias, 0) 456 | nn.init.constant_(m.weight, 1.0) 457 | 458 | def forward(self, x): 459 | # x: [B, L, C] 460 | x, H, W = self.patch_embed(x) 461 | x = self.pos_drop(x) 462 | 463 | for layer in self.layers: 464 | x, H, W = layer(x, H, W) 465 | 466 | x = self.norm(x) 467 | x = self.avgpool(x.transpose(1, 2)) 468 | x = torch.flatten(x, 1) 469 | x = self.head(x) 470 | x = torch.sigmoid(x) 471 | 472 | return x 473 | 474 | 475 | def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs): 476 | # trained ImageNet-1K 477 | # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth 478 | model = SwinTransformer(in_chans=3, 479 | patch_size=4, 480 | window_size=7, 481 | embed_dim=96, 482 | depths=(2, 2, 6, 2), 483 | num_heads=(3, 6, 12, 24), 484 | num_classes=num_classes, 485 | **kwargs) 486 | return model 487 | 488 | 489 | def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs): 490 | # trained ImageNet-1K 491 | # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth 492 | model = SwinTransformer(in_chans=3, 493 | patch_size=4, 494 | window_size=7, 495 | embed_dim=96, 496 | depths=(2, 2, 18, 2), 497 | num_heads=(3, 6, 12, 24), 498 | num_classes=num_classes, 499 | **kwargs) 500 | return model 501 | 502 | 503 | def 
swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs): 504 | # trained ImageNet-1K 505 | # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth 506 | model = SwinTransformer(in_chans=3, 507 | patch_size=4, 508 | window_size=7, 509 | embed_dim=128, 510 | depths=(2, 2, 18, 2), 511 | num_heads=(4, 8, 16, 32), 512 | num_classes=num_classes, 513 | **kwargs) 514 | return model 515 | 516 | 517 | def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs): 518 | # trained ImageNet-1K 519 | # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth 520 | model = SwinTransformer(in_chans=3, 521 | patch_size=4, 522 | window_size=12, 523 | embed_dim=128, 524 | depths=(2, 2, 18, 2), 525 | num_heads=(4, 8, 16, 32), 526 | num_classes=num_classes, 527 | **kwargs) 528 | return model 529 | 530 | 531 | def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs): 532 | # trained ImageNet-22K 533 | # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth 534 | model = SwinTransformer(in_chans=3, 535 | patch_size=4, 536 | window_size=7, 537 | embed_dim=128, 538 | depths=(2, 2, 18, 2), 539 | num_heads=(4, 8, 16, 32), 540 | num_classes=num_classes, 541 | **kwargs) 542 | return model 543 | 544 | 545 | def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs): 546 | # trained ImageNet-22K 547 | # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth 548 | model = SwinTransformer(in_chans=3, 549 | patch_size=4, 550 | window_size=12, 551 | embed_dim=128, 552 | depths=(2, 2, 18, 2), 553 | num_heads=(4, 8, 16, 32), 554 | num_classes=num_classes, 555 | **kwargs) 556 | return model 557 | 558 | 559 | def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs): 560 | # trained ImageNet-22K 561 | # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth 562 | model = SwinTransformer(in_chans=3, 563 | patch_size=4, 564 | window_size=7, 565 | embed_dim=192, 566 | depths=(2, 2, 18, 2), 567 | num_heads=(6, 12, 24, 48), 568 | num_classes=num_classes, 569 | **kwargs) 570 | return model 571 | 572 | 573 | def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs): 574 | # trained ImageNet-22K 575 | # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth 576 | model = SwinTransformer(in_chans=3, 577 | patch_size=4, 578 | window_size=12, 579 | embed_dim=192, 580 | depths=(2, 2, 18, 2), 581 | num_heads=(6, 12, 24, 48), 582 | num_classes=num_classes, 583 | **kwargs) 584 | return model 585 | --------------------------------------------------------------------------------
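The factory functions above only construct and return the model; nothing in this file shows it being called. As a minimal usage sketch (not part of the repository), the snippet below instantiates one variant and runs a dummy batch. The checkpoint filename and the 14-class head are assumptions for illustration only; substitute whatever weights and label count the surrounding pipeline actually uses.

import torch

# Hypothetical 14-label head; the sigmoid at the end of SwinTransformer.forward
# means the output is a vector of independent per-class probabilities.
model = swin_tiny_patch4_window7_224(num_classes=14)

# Optional: start from the ImageNet-1K weights linked in the comments above.
# The original 1000-class head is dropped because its shape no longer matches.
ckpt = torch.load("swin_tiny_patch4_window7_224.pth", map_location="cpu")  # assumed local path
state_dict = ckpt.get("model", ckpt)
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("head.")}
model.load_state_dict(state_dict, strict=False)

model.eval()
with torch.no_grad():
    dummy = torch.randn(2, 3, 224, 224)   # (B, C, H, W); 224 suits the window_size=7 variants
    probs = model(dummy)                  # shape (2, 14), values in [0, 1]
print(probs.shape)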