├── datasets ├── a ├── unitopatho_test.csv ├── camelyon16_test.csv ├── unitopatho_train.csv ├── camelyon17_seen.csv ├── camelyon16_total.csv └── camelyon17.csv ├── visualization ├── SAC.png ├── ROCs.png ├── AttriMIL.png ├── Visualization.png ├── OODPerformance.png └── AttributeScoring.png ├── env.yml ├── coords_to_feature.py ├── models ├── AttriMIL.py ├── resnet_custom_dep.py ├── ABMIL.py ├── TransMIL.py ├── DSMIL.py ├── MIL.py └── S4MIL.py ├── constraints.py ├── create_3coords.py ├── .gitignore ├── tester_transmil.py ├── tester_mil.py ├── tester_dsmil.py ├── README.md ├── tester_attrimil_abmil.py ├── utils.py ├── trainer_mil.py ├── trainer_transmil.py ├── LICENSE ├── trainer_dsmil.py ├── trainer_attrimil_abmil.py └── dataloader.py /datasets/a: -------------------------------------------------------------------------------- 1 | a 2 | -------------------------------------------------------------------------------- /visualization/SAC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MedCAI/AttriMIL/HEAD/visualization/SAC.png -------------------------------------------------------------------------------- /visualization/ROCs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MedCAI/AttriMIL/HEAD/visualization/ROCs.png -------------------------------------------------------------------------------- /visualization/AttriMIL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MedCAI/AttriMIL/HEAD/visualization/AttriMIL.png -------------------------------------------------------------------------------- /visualization/Visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MedCAI/AttriMIL/HEAD/visualization/Visualization.png -------------------------------------------------------------------------------- /visualization/OODPerformance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MedCAI/AttriMIL/HEAD/visualization/OODPerformance.png -------------------------------------------------------------------------------- /visualization/AttributeScoring.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MedCAI/AttriMIL/HEAD/visualization/AttributeScoring.png -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: attrimil 2 | dependencies: 3 | - python==3.10 4 | - conda-forge::openslide 5 | - pip 6 | - pip: 7 | - timm==0.9.8 8 | - torch 9 | - torchvision 10 | - h5py 11 | - pandas 12 | - PyYAML 13 | - opencv-python 14 | - matplotlib 15 | - scikit-learn 16 | - scipy 17 | - tqdm 18 | - openslide-python 19 | - git+https://github.com/oval-group/smooth-topk.git 20 | - tensorboardX 21 | -------------------------------------------------------------------------------- /coords_to_feature.py: -------------------------------------------------------------------------------- 1 | import openslide 2 | import matplotlib.pyplot as plt 3 | import os 4 | import numpy as np 5 | import h5py 6 | 7 | if __name__ == "__main__": 8 | # 需要改保存路径,和patch_size!! 9 | orgin_path = '/data2/clh/NSCLC/resnet18_simclr/h5_files/' 10 | coord_path = '/data2/clh/NSCLC/coords/' 11 | save_path = '/data2/clh/NSCLC/resnet18_simclr/h5_coords_files/' 12 | patch_size = (256, 256) 13 | start = 0 14 | for step, name in enumerate(os.listdir(orgin_path)): 15 | if step < start: 16 | continue 17 | if name in os.listdir(save_path): 18 | print("exist:", name) 19 | continue 20 | 21 | if name.endswith('h5'): 22 | # 读取文件 23 | print("Loading:", name) 24 | h5 = h5py.File(orgin_path + name) 25 | coords = np.array(h5['coords']) 26 | features = np.array(h5['features']) 27 | h5.close() 28 | h5 = h5py.File(coord_path + name) 29 | nearest = np.array(h5['nearest']) 30 | h5.close() 31 | h5 = h5py.File(save_path + name, 'w') #写入文件 32 | h5['coords'] = coords 33 | h5['features'] = features 34 | h5['nearest'] = nearest #名称为image 35 | h5.close() #关闭文件 36 | print("coords:{}, features:{}, nearest:{}".format(coords.shape, features.shape, nearest.shape)) 37 | -------------------------------------------------------------------------------- /models/AttriMIL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class Attn_Net_Gated(nn.Module): 7 | def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1): 8 | super(Attn_Net_Gated, self).__init__() 9 | self.attention_a = [ 10 | nn.Linear(L, D), 11 | nn.Tanh()] 12 | 13 | self.attention_b = [nn.Linear(L, D), 14 | nn.Sigmoid()] 15 | if dropout: 16 | self.attention_a.append(nn.Dropout(0.25)) 17 | self.attention_b.append(nn.Dropout(0.25)) 18 | 19 | self.attention_a = nn.Sequential(*self.attention_a) 20 | self.attention_b = nn.Sequential(*self.attention_b) 21 | 22 | self.attention_c = nn.Linear(D, n_classes) 23 | 24 | def forward(self, x): 25 | a = self.attention_a(x) 26 | b = self.attention_b(x) 27 | A = a.mul(b) 28 | A = self.attention_c(A) # N x n_classes 29 | return A, x 30 | 31 | 32 | class AttriMIL(nn.Module): 33 | ''' 34 | Multi-Branch ABMIL with constraints 35 | ''' 36 | def __init__(self, n_classes=2, dim=512): 37 | super().__init__() 38 | self.adaptor = nn.Sequential(nn.Linear(dim, dim//2), 39 | nn.ReLU(), 40 | nn.Linear(dim // 2 , dim)) 41 | 42 | attention = [] 43 | classifer = [nn.Linear(dim, 1) for i in range(n_classes)] 44 | for i in range(n_classes): 45 | attention.append(Attn_Net_Gated(L = dim, D = dim // 2,)) 46 | self.attention_nets = nn.ModuleList(attention) 47 | self.classifiers = nn.ModuleList(classifer) 48 | self.n_classes = n_classes 49 | self.bias = nn.Parameter(torch.zeros(n_classes), requires_grad=True) 50 | 51 | def forward(self, h): 52 | h = h + self.adaptor(h) 53 | A_raw = torch.empty(self.n_classes, h.size(0), ) # N x 1 54 | instance_score = torch.empty(1, self.n_classes, h.size(0)).float().to(h.device) 55 | for c in range(self.n_classes): 56 | A, h = self.attention_nets[c](h) 57 | A = torch.transpose(A, 1, 0) # 1 x N 58 | A_raw[c] = A 59 | instance_score[0, c] = self.classifiers[c](h)[:, 0] 60 | attribute_score = torch.empty(1, self.n_classes, h.size(0)).float().to(h.device) 61 | for c in range(self.n_classes): 62 | attribute_score[0, c] = instance_score[0, c] * torch.exp(A_raw[c]) 63 | 64 | logits = torch.empty(1, self.n_classes).float().to(h.device) 65 | for c in range(self.n_classes): 66 | logits[0, c] = torch.sum(attribute_score[0, c], keepdim=True, dim=-1) / torch.sum(torch.exp(A_raw[c]), dim=-1) + self.bias[c] 67 | 68 | Y_hat = torch.topk(logits, 1, dim = 1)[1] 69 | Y_prob = F.softmax(logits, dim = 1) 70 | results_dict = {} 71 | return logits, Y_prob, Y_hat, attribute_score, results_dict -------------------------------------------------------------------------------- /constraints.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from utils import * 4 | import os 5 | import queue 6 | from sklearn.preprocessing import label_binarize 7 | from sklearn.metrics import roc_auc_score, roc_curve 8 | from sklearn.metrics import auc as calc_auc 9 | 10 | 11 | def spatial_constraint(A, n_classes, nearest, ks=3): 12 | loss_spatial = torch.tensor(0.0).to(device) 13 | # N = A.shape[-1] 14 | for c in range(1, n_classes): 15 | score = A[:, c] # N 16 | nearest_score = score[nearest] # N ks^2-1 17 | abs_nearest = torch.abs(nearest_score) 18 | max_indices = torch.argmax(abs_nearest, dim=1) 19 | local_prototype = nearest_score.gather(1, max_indices.view(-1, 1)).squeeze() 20 | # print(local_prototype[:10]) 21 | loss_spatial += torch.mean(torch.abs(torch.tanh(score - local_prototype))) 22 | return loss_spatial 23 | 24 | 25 | def rank_constraint(data, label, model, A, n_classes, label_positive_list, label_negative_list): 26 | loss_rank = torch.tensor(0.0).to(device) 27 | for c in range(n_classes): 28 | if label == c: 29 | value, indice = torch.topk(A[0, c], k=1) 30 | h = data[indice.item(): indice.item() + 1] # top feature 31 | if label_positive_list[c].full(): 32 | _ = label_positive_list[c].get() 33 | label_positive_list[c].put(h) 34 | if label_negative_list[c].empty(): 35 | loss_rank = loss_rank + torch.tensor(0.0).to(device) 36 | else: 37 | h = label_negative_list[c].get() 38 | label_negative_list[c].put(h) 39 | _, _, _, Ah, _ = model(h.detach()) 40 | if c != 0: 41 | loss_rank = loss_rank + torch.clamp(torch.mean(Ah[0, c] - value), min=0.0) + torch.clamp(torch.mean(-value), min=0.0) + torch.clamp(torch.mean(Ah[0, c]), min=0.0) 42 | else: 43 | loss_rank = loss_rank + torch.clamp(torch.mean(-value), min=0.0) + torch.clamp(torch.mean(Ah[0, c]), min=0.0) 44 | else: 45 | value, indice = torch.topk(A[0, c], k=1) 46 | h = data[indice.item(): indice.item() + 1] # top feature 47 | if label_negative_list[c].full(): 48 | _ = label_negative_list[c].get() 49 | label_negative_list[c].put(h) 50 | if label_positive_list[c].empty(): 51 | loss_rank = loss_rank + torch.tensor(0.0).to(device) 52 | else: 53 | h = label_positive_list[c].get() 54 | label_positive_list[c].put(h) 55 | _, _, _, Ah, _ = model(h.detach()) 56 | if c != 0: 57 | loss_rank = loss_rank + torch.clamp(torch.mean(value - Ah[0, c]), min=0.0) + torch.clamp(torch.mean(value), min=0.0) 58 | else: 59 | loss_rank = loss_rank + torch.clamp(torch.mean(value), min=0.0) + torch.clamp(torch.mean(-Ah[0, c]), min=0.0) 60 | loss_rank = loss_rank / n_classes 61 | return loss_rank, label_positive_list, label_negative_list 62 | -------------------------------------------------------------------------------- /models/resnet_custom_dep.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | import torch.nn.functional as F 5 | import torch 6 | import torchvision.models as models 7 | 8 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 9 | 10 | class Network(torch.nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | # model = timm.create_model('resnet50_gn', num_classes=1000, pretrained=False) 14 | # path = './weights/resnet50_gn_a1h2-8fe6c4d0.pth' 15 | # model.load_state_dict(torch.load(path)) 16 | path = './weights/simclr_resnet50.safetensors' 17 | model = timm.create_model('resnet50', pretrained=False, pretrained_cfg_overlay=dict(file=path)) 18 | self.resnet = model 19 | self.avgpool = nn.AdaptiveAvgPool2d(1) 20 | 21 | def forward(self, x): 22 | x = self.resnet.conv1(x) 23 | x = self.resnet.bn1(x) 24 | x = self.resnet.act1(x) 25 | x = self.resnet.maxpool(x) 26 | x = self.resnet.layer1(x) 27 | # x = rearrange(h, 'b c h w -> b h w c') 28 | # x = self.adapter(x) 29 | # x = rearrange(x, 'b h w c -> b c h w') 30 | # x = x + h 31 | x = self.resnet.layer2(x) 32 | # x = rearrange(h, 'b c h w -> b h w c') 33 | # x = self.adapter(x) 34 | # x = rearrange(x, 'b h w c -> b c h w') 35 | # x = x + h 36 | x = self.resnet.layer3(x) 37 | x = self.avgpool(x) 38 | x = x.view(x.shape[0], -1) 39 | return x 40 | 41 | 42 | class ResNet18(torch.nn.Module): 43 | def __init__(self): 44 | super().__init__() 45 | # model = timm.create_model('resnet50_gn', num_classes=1000, pretrained=False) 46 | # path = './weights/resnet50_gn_a1h2-8fe6c4d0.pth' 47 | # model.load_state_dict(torch.load(path)) 48 | # path = './weights/resnet18_imagenet.safetensors' 49 | # model = timm.create_model('resnet18', pretrained=False, pretrained_cfg_overlay=dict(file=path)) 50 | resnet18 = models.resnet18(pretrained=True) 51 | self.resnet = resnet18 52 | self.avgpool = nn.AdaptiveAvgPool2d(1) 53 | 54 | def forward(self, x): 55 | x = F.interpolate(x, (256, 256), mode='bilinear') 56 | x = self.resnet.conv1(x) 57 | x = self.resnet.bn1(x) 58 | x = self.resnet.relu(x) 59 | x = self.resnet.maxpool(x) 60 | x = self.resnet.layer1(x) 61 | # x = rearrange(h, 'b c h w -> b h w c') 62 | # x = self.adapter(x) 63 | # x = rearrange(x, 'b h w c -> b c h w') 64 | # x = x + h 65 | x = self.resnet.layer2(x) 66 | # x = rearrange(h, 'b c h w -> b h w c') 67 | # x = self.adapter(x) 68 | # x = rearrange(x, 'b h w c -> b c h w') 69 | # x = x + h 70 | x = self.resnet.layer3(x) 71 | # x = rearrange(h, 'b c h w -> b h w c') 72 | # x = self.adapter(x) 73 | # x = rearrange(x, 'b h w c -> b c h w') 74 | # x = x + h 75 | x = self.resnet.layer4(x) 76 | x = self.avgpool(x) 77 | x = x.view(x.shape[0], -1) 78 | return x 79 | 80 | class IClassifier(nn.Module): 81 | def __init__(self, feature_extractor,): 82 | super(IClassifier, self).__init__() 83 | 84 | self.feature_extractor = feature_extractor 85 | 86 | def forward(self, x): 87 | device = x.device 88 | feats = self.feature_extractor(x) # N x K 89 | return feats.view(feats.shape[0], -1) 90 | 91 | -------------------------------------------------------------------------------- /datasets/unitopatho_test.csv: -------------------------------------------------------------------------------- 1 | case_id,slide_id,label 2 | patient_1,120-B3-NORM_1,NORM 3 | patient_2,193-B4-NORM_1,NORM 4 | patient_3,155-B4-NORM_1,NORM 5 | patient_4,211-B5-NORM_1,NORM 6 | patient_5,186-B4-NORM_1,NORM 7 | patient_6,188-B4-NORM_1,NORM 8 | patient_7,263-B5-NORM_1,NORM 9 | patient_8,187-B4-NORM_1,NORM 10 | patient_9,248-B5-TVALG_5,TVA.LG 11 | patient_10,134-B3-TVALG_5,TVA.LG 12 | patient_11,243-B5-TVALG_5,TVA.LG 13 | patient_12,180-B4-TVALG_5,TVA.LG 14 | patient_13,254-B5-TVALG_5,TVA.LG 15 | patient_14,199-B4-TVALG_5,TVA.LG 16 | patient_15,TVA.LG CASO 3 - 2018-12-04 13.22.40_5,TVA.LG 17 | patient_16,197-B4-TVALG_5,TVA.LG 18 | patient_17,TA.HG CASO 2 - 2018-12-04 13.31.38_2,TA.HG 19 | patient_18,TA.HG CASO 12 - 2019-03-04 18.16.39_2,TA.HG 20 | patient_19,129-B3-TAHG_2,TA.HG 21 | patient_20,204-B5-TAHG_2,TA.HG 22 | patient_21,85-B2-TAHG_2,TA.HG 23 | patient_22,106-B3-TAHG_2,TA.HG 24 | patient_23,TA.HG CASO 4 - 2019-03-04 10.26.19_2,TA.HG 25 | patient_24,56-B2-TAHG_2,TA.HG 26 | patient_25,119-B3-TAHG_2,TA.HG 27 | patient_26,147-B3-HP_0,HP 28 | patient_27,110-B3-HP_0,HP 29 | patient_28,242-B5-HP_0,HP 30 | patient_29,HP CASO 27 - 2019-03-04 17.03.07_0,HP 31 | patient_30,192-B4-HP_0,HP 32 | patient_31,246-B5-HP_0,HP 33 | patient_32,160-B4-HP_0,HP 34 | patient_33,HP CASO 25 - 2019-03-04 09.16.56_0,HP 35 | patient_34,164-B4-HP_0,HP 36 | patient_35,202-B5-HP_0,HP 37 | patient_36,241-B5-TALG_3,TA.LG 38 | patient_37,TA.LG CASO 96_3,TA.LG 39 | patient_38,53-B2-TALG_3,TA.LG 40 | patient_39,TA.LG CASO 44 - 2019-03-01 08.25.42_3,TA.LG 41 | patient_40,TA.LG CASO 94_3,TA.LG 42 | patient_41,67-B2-TALG_3,TA.LG 43 | patient_42,222-B5-TALG_3,TA.LG 44 | patient_43,TA.LG CASO 90_3,TA.LG 45 | patient_44,64-B2-TALG_3,TA.LG 46 | patient_45,122-B3-TALG_3,TA.LG 47 | patient_46,TA.LG CASO 61 - 2019-03-04 16.54.59_3,TA.LG 48 | patient_47,TA.LG CASO 87_3,TA.LG 49 | patient_48,121-B3-TALG_3,TA.LG 50 | patient_49,11-B1TALG_3,TA.LG 51 | patient_50,TA.LG CASO 43 - 2019-03-01 08.01.10_3,TA.LG 52 | patient_51,63-B2-TALG_3,TA.LG 53 | patient_52,TA.LG CASO 1 - 2018-12-04 12.52.01_3,TA.LG 54 | patient_53,TA.LG CASO 63 - 2019-03-04 17.23.02_3,TA.LG 55 | patient_54,TA.LG CASO 98_3,TA.LG 56 | patient_55,TA.LG CASO 92 A1_3,TA.LG 57 | patient_56,TA.LG CASO 101 D1_3,TA.LG 58 | patient_57,123-B3-TALG_3,TA.LG 59 | patient_58,99-B2-TALG_3,TA.LG 60 | patient_59,175-B4-TALG_3,TA.LG 61 | patient_60,TA.LG CASO 95_3,TA.LG 62 | patient_61,71-B2-TALG_3,TA.LG 63 | patient_62,88-B2-TALG_3,TA.LG 64 | patient_63,TA.LG CASO 54 - 2019-03-04 09.33.07_3,TA.LG 65 | patient_64,126-B3-TALG_3,TA.LG 66 | patient_65,TA.LG CASO 48 - 2019-03-01 09.02.35_3,TA.LG 67 | patient_66,141-B3-TALG_3,TA.LG 68 | patient_67,214-B5-TALG_3,TA.LG 69 | patient_68,179-B4-TALG_3,TA.LG 70 | patient_69,TA.LG CASO 12 - 2018-12-04 13.48.28_3,TA.LG 71 | patient_70,77-B2-TALG_3,TA.LG 72 | patient_71,TA.LG CASO 103_3,TA.LG 73 | patient_72,97-B2-TALG_3,TA.LG 74 | patient_73,183-B4-TALG_3,TA.LG 75 | patient_74,107-B3-TALG_3,TA.LG 76 | patient_75,195-B4-TALG_3,TA.LG 77 | patient_76,78-B2-TALG_3,TA.LG 78 | patient_77,TA.LG CASO 46 - 2019-03-01 08.37.03_3,TA.LG 79 | patient_78,128-B3-TALG_3,TA.LG 80 | patient_79,247-B5-TALG_3,TA.LG 81 | patient_80,253-B5-TALG_3,TA.LG 82 | patient_81,TA.LG CASO 83_3,TA.LG 83 | patient_82,92-B2-TVAHG_4,TVA.HG 84 | patient_83,240-B5-TVAHG_4,TVA.HG 85 | patient_84,59-B2-TVAHG_4,TVA.HG 86 | patient_85,TVA.HG CASO 11_4,TVA.HG 87 | patient_86,TVA.HG CASO 13_4,TVA.HG 88 | patient_87,238-B5-TVAHG_4,TVA.HG 89 | patient_88,TVA.HG CASO 9 - 2019-03-04 10.19.16_4,TVA.HG 90 | -------------------------------------------------------------------------------- /models/ABMIL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class Attn_Net_Gated(nn.Module): 7 | def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1): 8 | super(Attn_Net_Gated, self).__init__() 9 | self.attention_a = [ 10 | nn.Linear(L, D), 11 | nn.Tanh()] 12 | 13 | self.attention_b = [nn.Linear(L, D), 14 | nn.Sigmoid()] 15 | if dropout: 16 | self.attention_a.append(nn.Dropout(0.25)) 17 | self.attention_b.append(nn.Dropout(0.25)) 18 | 19 | self.attention_a = nn.Sequential(*self.attention_a) 20 | self.attention_b = nn.Sequential(*self.attention_b) 21 | 22 | self.attention_c = nn.Linear(D, n_classes) 23 | 24 | def forward(self, x): 25 | a = self.attention_a(x) 26 | b = self.attention_b(x) 27 | A = a.mul(b) 28 | A = self.attention_c(A) # N x n_classes 29 | return A, x 30 | 31 | 32 | class ABMIL(nn.Module): 33 | def __init__(self, n_classes=2, dim=512): 34 | super().__init__() 35 | self.adaptor = nn.Sequential(nn.Linear(dim, dim//2), 36 | nn.ReLU(), 37 | nn.Linear(dim // 2 , dim)) 38 | 39 | self.attention_net = Attn_Net_Gated(L = dim, D = dim // 2, n_classes=1) 40 | self.classifier = nn.Linear(dim, n_classes) 41 | self.n_classes = n_classes 42 | 43 | def forward(self, h): 44 | h = h + self.adaptor(h) 45 | logits = torch.empty(1, self.n_classes).float().to(h.device) 46 | A, h = self.attention_net(h) 47 | A = torch.transpose(A, 1, 0) # 1 x N 48 | A_raw = A 49 | A = F.softmax(A, dim=1) # softmax over N 50 | 51 | M = torch.mm(A, h) # 1 x dim 52 | logits = self.classifier(M) 53 | Y_hat = torch.topk(logits, 1, dim = 1)[1] 54 | Y_prob = F.softmax(logits, dim = 1) 55 | results_dict = {} 56 | results_dict.update({'features': M}) 57 | return logits, Y_prob, Y_hat, A_raw, results_dict 58 | 59 | 60 | class ABMIL_MB(nn.Module): 61 | def __init__(self, n_classes=2, dim=512): 62 | super().__init__() 63 | self.adaptor = nn.Sequential(nn.Linear(dim, dim//2), 64 | nn.ReLU(), 65 | nn.Linear(dim // 2 , dim)) 66 | 67 | attention = [] 68 | classifer = [nn.Linear(dim, 1) for i in range(n_classes)] 69 | for i in range(n_classes): 70 | attention.append(Attn_Net_Gated(L = dim, D = dim // 2,)) 71 | self.attention_nets = nn.ModuleList(attention) 72 | self.classifiers = nn.ModuleList(classifer) 73 | self.n_classes = n_classes 74 | 75 | def forward(self, h): 76 | h = h + self.adaptor(h) 77 | logits = torch.empty(1, self.n_classes).float().to(h.device) 78 | A_raw = torch.empty(self.n_classes, h.size(0), ) # N x 1 79 | for c in range(self.n_classes): 80 | A, h = self.attention_nets[c](h) 81 | A = torch.transpose(A, 1, 0) # 1 x N 82 | A = F.softmax(A, dim=1) # softmax over N 83 | A_raw[c] = A 84 | M = torch.mm(A, h) 85 | logits[0, c] = self.classifiers[c](M) 86 | Y_hat = torch.topk(logits, 1, dim = 1)[1] 87 | Y_prob = F.softmax(logits, dim = 1) 88 | results_dict = {} 89 | results_dict.update({'features': M}) 90 | return logits, Y_prob, Y_hat, A_raw, results_dict 91 | -------------------------------------------------------------------------------- /create_3coords.py: -------------------------------------------------------------------------------- 1 | import openslide 2 | import matplotlib.pyplot as plt 3 | import os 4 | import numpy as np 5 | import h5py 6 | from joblib import Parallel, delayed 7 | 8 | 9 | def find_nearest(input_path, 10 | output_path, 11 | patch_size=(256, 256)): 12 | print("Loading:", name) 13 | h5 = h5py.File(input_path) 14 | coords = np.array(h5['coords']) 15 | # features = np.array(h5['features']) 16 | h5.close() 17 | 18 | nearest = [] 19 | # left, right, up, down, left_up, left_down, right_up, right_down 20 | for step, p in enumerate(coords): 21 | exists = [np.array(step)] 22 | left = np.array([p[0], p[1] - patch_size[1]]) 23 | loc = np.where(np.sum(coords == left, axis=1) == 2)[0] 24 | if len(loc) != 0: 25 | exists.append(loc[0]) 26 | else: 27 | exists.append(np.array(step)) 28 | 29 | right = np.array([p[0], p[1] + patch_size[1]]) 30 | loc = np.where(np.sum(coords == right, axis=1) == 2)[0] 31 | if len(loc) != 0: 32 | exists.append(loc[0]) 33 | else: 34 | exists.append(np.array(step)) 35 | 36 | up = np.array([p[0] - patch_size[0], p[1]]) 37 | loc = np.where(np.sum(coords == up, axis=1) == 2)[0] 38 | if len(loc) != 0: 39 | exists.append(loc[0]) 40 | else: 41 | exists.append(np.array(step)) 42 | 43 | down = np.array([p[0] + patch_size[0], p[1]]) 44 | loc = np.where(np.sum(coords == down, axis=1) == 2)[0] 45 | if len(loc) != 0: 46 | exists.append(loc[0]) 47 | else: 48 | exists.append(np.array(step)) 49 | 50 | left_up = np.array([p[0] - patch_size[0], p[1] - patch_size[1]]) 51 | loc = np.where(np.sum(coords == left_up, axis=1) == 2)[0] 52 | if len(loc) != 0: 53 | exists.append(loc[0]) 54 | else: 55 | exists.append(np.array(step)) 56 | 57 | left_down = np.array([p[0] + patch_size[0], p[1] - patch_size[1]]) 58 | loc = np.where(np.sum(coords == left_down, axis=1) == 2)[0] 59 | if len(loc) != 0: 60 | exists.append(loc[0]) 61 | else: 62 | exists.append(np.array(step)) 63 | 64 | right_up = np.array([p[0] - patch_size[0], p[1] + patch_size[1]]) 65 | loc = np.where(np.sum(coords == right_up, axis=1) == 2)[0] 66 | if len(loc) != 0: 67 | exists.append(loc[0]) 68 | else: 69 | exists.append(np.array(step)) 70 | 71 | right_down = np.array([p[0] + patch_size[0], p[1] + patch_size[1]]) 72 | loc = np.where(np.sum(coords == right_down, axis=1) == 2)[0] 73 | if len(loc) != 0: 74 | exists.append(loc[0]) 75 | else: 76 | exists.append(np.array(step)) 77 | nearest.append(exists) 78 | 79 | h5 = h5py.File(output_path, 'w') # 写入文件 80 | h5['coords'] = coords 81 | # h5['features'] = features 82 | h5['nearest'] = nearest # 名称为 image 83 | h5.close() #关闭文件 84 | return 85 | 86 | 87 | if __name__ == "__main__": 88 | # 需要改保存路径,和patch_size!! 89 | orgin_path = '/data2/clh/NSCLC/LUAD/20X/patches/' 90 | save_path = '/data2/clh/NSCLC/coords/' 91 | patch_size = (256, 256) 92 | start = 0 93 | name_list = [] 94 | for step, name in enumerate(os.listdir(orgin_path)): 95 | if step < start: 96 | continue 97 | if name.endswith('h5'): 98 | name_list.append(name) 99 | Parallel(n_jobs=32)(delayed(find_nearest)(os.path.join(orgin_path, slide), os.path.join(save_path, slide)) for slide in name_list) -------------------------------------------------------------------------------- /models/TransMIL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from nystrom_attention import NystromAttention 6 | 7 | 8 | class TransLayer(nn.Module): 9 | 10 | def __init__(self, norm_layer=nn.LayerNorm, dim=512): 11 | super().__init__() 12 | self.norm = norm_layer(dim) 13 | self.attn = NystromAttention( 14 | dim = dim, 15 | dim_head = dim//8, 16 | heads = 8, 17 | num_landmarks = dim//2, # number of landmarks 18 | pinv_iterations = 6, # number of moore-penrose iterations for approximating pinverse. 6 was recommended by the paper 19 | residual = True, # whether to do an extra residual with the value or not. supposedly faster convergence if turned on 20 | dropout=0.1 21 | ) 22 | 23 | def forward(self, x): 24 | out = self.attn(self.norm(x)) 25 | x = x + out 26 | 27 | return x 28 | 29 | 30 | class PPEG(nn.Module): 31 | def __init__(self, dim=512): 32 | super(PPEG, self).__init__() 33 | self.proj = nn.Conv2d(dim, dim, 7, 1, 7//2, groups=dim) 34 | self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5//2, groups=dim) 35 | self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3//2, groups=dim) 36 | 37 | def forward(self, x, H, W): 38 | B, _, C = x.shape 39 | cls_token, feat_token = x[:, 0], x[:, 1:] 40 | cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W) 41 | x = self.proj(cnn_feat)+cnn_feat+self.proj1(cnn_feat)+self.proj2(cnn_feat) 42 | x = x.flatten(2).transpose(1, 2) 43 | x = torch.cat((cls_token.unsqueeze(1), x), dim=1) 44 | return x 45 | 46 | 47 | class TransMIL(nn.Module): 48 | def __init__(self, dim=512, n_classes=2): 49 | super(TransMIL, self).__init__() 50 | self.pos_layer = PPEG(dim=256) 51 | self._fc1 = nn.Sequential(nn.Linear(dim, 256), nn.ReLU()) 52 | self.cls_token = nn.Parameter(torch.randn(1, 1, 256)) 53 | self.n_classes = n_classes 54 | self.layer1 = TransLayer(dim=256) 55 | self.layer2 = TransLayer(dim=256) 56 | self.norm = nn.LayerNorm(256) 57 | self._fc2 = nn.Linear(256, self.n_classes) 58 | 59 | 60 | def forward(self, h): 61 | h = h.unsqueeze(0) 62 | h = self._fc1(h) #[B, n, 512] 63 | #---->pad 64 | H = h.shape[1] 65 | _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H))) 66 | add_length = _H * _W - H 67 | h = torch.cat([h, h[:,:add_length,:]],dim = 1) #[B, N, 512] 68 | 69 | #---->cls_token 70 | B = h.shape[0] 71 | cls_tokens = self.cls_token.expand(B, -1, -1).cuda() 72 | # print(h.shape, cls_tokens.shape) 73 | h = torch.cat((cls_tokens, h), dim=1) 74 | 75 | #---->Translayer x1 76 | h = self.layer1(h) #[B, N+1, 512] 77 | 78 | #---->PPEG 79 | h = self.pos_layer(h, _H, _W) #[B, N+1, 512] 80 | 81 | #---->Translayer x2 82 | h0 = self.layer2(h) #[B, N+1, 512] 83 | 84 | #---->cls_token 85 | h = self.norm(h0)[:,0] 86 | 87 | #---->predict 88 | logits = self._fc2(h) #[B, n_classes] 89 | Y_hat = torch.argmax(logits, dim=1) 90 | Y_prob = F.softmax(logits, dim = 1) 91 | # results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat} 92 | return logits, Y_prob, Y_hat, h0 93 | 94 | if __name__ == "__main__": 95 | data = torch.randn((1, 6000, 512)).cuda() 96 | model = TransMIL(n_classes=2).cuda() 97 | print(model.eval()) 98 | results_dict = model(data) 99 | print(results_dict) 100 | -------------------------------------------------------------------------------- /models/DSMIL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | class FCLayer(nn.Module): 7 | def __init__(self, in_size, out_size=1): 8 | super(FCLayer, self).__init__() 9 | self.fc = nn.Sequential(nn.Linear(in_size, out_size)) 10 | def forward(self, feats): 11 | x = self.fc(feats) 12 | return feats, x 13 | 14 | class IClassifier(nn.Module): 15 | def __init__(self, feature_extractor, feature_size, output_class): 16 | super(IClassifier, self).__init__() 17 | 18 | self.feature_extractor = feature_extractor 19 | self.fc = nn.Linear(feature_size, output_class) 20 | 21 | 22 | def forward(self, x): 23 | device = x.device 24 | feats = self.feature_extractor(x) # N x K 25 | c = self.fc(feats.view(feats.shape[0], -1)) # N x C 26 | return feats.view(feats.shape[0], -1), c 27 | 28 | class BClassifier(nn.Module): 29 | def __init__(self, input_size, output_class, dropout_v=0.25, nonlinear=False, passing_v=False): # K, L, N 30 | super(BClassifier, self).__init__() 31 | if nonlinear: 32 | self.q = nn.Sequential(nn.Linear(input_size, 128), nn.ReLU(), nn.Linear(128, 128), nn.Tanh()) 33 | else: 34 | self.q = nn.Linear(input_size, 128) 35 | if passing_v: 36 | self.v = nn.Sequential( 37 | nn.Dropout(dropout_v), 38 | nn.Linear(input_size, input_size), 39 | nn.ReLU() 40 | ) 41 | else: 42 | self.v = nn.Identity() 43 | 44 | ### 1D convolutional layer that can handle multiple class (including binary) 45 | self.fcc = nn.Conv1d(output_class, output_class, kernel_size=input_size) 46 | 47 | def forward(self, feats, c): # N x K, N x C 48 | device = feats.device 49 | V = self.v(feats) # N x V, unsorted 50 | Q = self.q(feats).view(feats.shape[0], -1) # N x Q, unsorted 51 | 52 | # handle multiple classes without for loop 53 | _, m_indices = torch.sort(c, 0, descending=True) # sort class scores along the instance dimension, m_indices in shape N x C 54 | m_feats = torch.index_select(feats, dim=0, index=m_indices[0, :]) # select critical instances, m_feats in shape C x K 55 | q_max = self.q(m_feats) # compute queries of critical instances, q_max in shape C x Q 56 | A_raw = torch.mm(Q, q_max.transpose(0, 1)) # compute inner product of Q to each entry of q_max, A in shape N x C, each column contains unnormalized attention scores 57 | A = F.softmax( A_raw / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device)), 0) # normalize attention scores, A in shape N x C, 58 | B = torch.mm(A.transpose(0, 1), V) # compute bag representation, B in shape C x V 59 | 60 | B = B.view(1, B.shape[0], B.shape[1]) # 1 x C x V 61 | C = self.fcc(B) # 1 x C x 1 62 | C = C.view(1, -1) 63 | return C, A_raw, B 64 | 65 | class MILNet(nn.Module): 66 | def __init__(self, feature_dim, n_classes): 67 | super(MILNet, self).__init__() 68 | self.adapter = nn.Sequential(nn.Linear(feature_dim, feature_dim//2), 69 | nn.ReLU(), 70 | nn.Linear(feature_dim//2, feature_dim)) 71 | self.i_classifier = FCLayer(in_size=feature_dim, out_size=n_classes) 72 | self.b_classifier = BClassifier(input_size=feature_dim, output_class=n_classes) 73 | 74 | def forward(self, x): 75 | x = self.adapter(x) + x 76 | feats, classes = self.i_classifier(x) 77 | prediction_bag, A, B = self.b_classifier(feats, classes) 78 | 79 | return classes, prediction_bag, A, B -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /models/MIL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class MIL_MeanPooling(nn.Module): 6 | def __init__(self, 7 | n_classes = 2, 8 | top_k=1, 9 | embed_dim=512): 10 | super().__init__() 11 | fc = [nn.Linear(embed_dim, embed_dim // 2), nn.ReLU(), nn.Dropout(0.25)] 12 | self.fc = nn.Sequential(*fc) 13 | self.classifier= nn.Linear(embed_dim // 2, n_classes) 14 | self.top_k=top_k 15 | 16 | def forward(self, h, return_features=False): 17 | h = self.fc(h) 18 | h = torch.mean(h, dim=0, keepdim=True) 19 | logits = self.classifier(h) 20 | 21 | y_probs = F.softmax(logits, dim = 1) 22 | top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(1,) 23 | top_instance = torch.index_select(logits, dim=0, index=top_instance_idx) 24 | Y_hat = torch.topk(top_instance, 1, dim = 1)[1] 25 | Y_prob = F.softmax(top_instance, dim = 1) 26 | results_dict = {} 27 | 28 | if return_features: 29 | top_features = torch.index_select(h, dim=0, index=top_instance_idx) 30 | results_dict.update({'features': top_features}) 31 | return top_instance, Y_prob, Y_hat, y_probs, results_dict 32 | 33 | class MIL_MaxPooling(nn.Module): 34 | def __init__(self, 35 | n_classes = 2, 36 | top_k=1, 37 | embed_dim=512): 38 | super().__init__() 39 | fc = [nn.Linear(embed_dim, embed_dim//2), nn.ReLU(), nn.Dropout(0.25)] 40 | self.fc = nn.Sequential(*fc) 41 | self.classifiers = nn.Linear(embed_dim//2, n_classes) 42 | self.top_k=top_k 43 | self.n_classes = n_classes 44 | assert self.top_k == 1 45 | 46 | def forward(self, h, return_features=False): 47 | h = self.fc(h) 48 | logits = self.classifiers(h) 49 | 50 | y_probs = F.softmax(logits, dim = 1) 51 | m = y_probs.view(1, -1).argmax(1) 52 | top_indices = torch.cat(((m // self.n_classes).view(-1, 1), (m % self.n_classes).view(-1, 1)), dim=1).view(-1, 1) 53 | top_instance = logits[top_indices[0]] 54 | 55 | Y_hat = top_indices[1] 56 | Y_prob = y_probs[top_indices[0]] 57 | 58 | results_dict = {} 59 | 60 | if return_features: 61 | top_features = torch.index_select(h, dim=0, index=top_indices[0]) 62 | results_dict.update({'features': top_features}) 63 | return top_instance, Y_prob, Y_hat, y_probs, results_dict 64 | 65 | 66 | class MIL_RNN(nn.Module): 67 | def __init__(self, n_classes = 2, embed_dim = 512): 68 | super().__init__() 69 | self.dropout = nn.Dropout(0.25) 70 | self.rnn = nn.RNN(embed_dim, 128, 3, batch_first=True, bidirectional=True) 71 | self.classifier = nn.Linear(256, n_classes) 72 | self.top_k = 1 73 | 74 | def forward(self, h, return_features=False): 75 | h0 = torch.zeros(6, 1, 128).to(h.device) 76 | h = self.dropout(h) 77 | h = h.unsqueeze(0) 78 | if return_features: 79 | h, _ = self.rnn(h, h0) 80 | h = h.squeeze(1) 81 | logits = self.classifier(h) 82 | else: 83 | h, _ = self.rnn(h, h0) 84 | h = h.squeeze(0) 85 | logits = self.classifier(h) 86 | y_probs = F.softmax(logits, dim = 1) 87 | top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(1,) 88 | top_instance = torch.index_select(logits, dim=0, index=top_instance_idx) 89 | Y_hat = torch.topk(top_instance, 1, dim = 1)[1] 90 | Y_prob = F.softmax(top_instance, dim = 1) 91 | results_dict = {} 92 | 93 | if return_features: 94 | top_features = torch.index_select(h, dim=0, index=top_instance_idx) 95 | results_dict.update({'features': top_features}) 96 | return top_instance, Y_prob, Y_hat, y_probs, results_dict -------------------------------------------------------------------------------- /datasets/camelyon16_test.csv: -------------------------------------------------------------------------------- 1 | case_id,slide_id,label 2 | patient_1,test_001,tumor_tissue 3 | patient_2,test_002,tumor_tissue 4 | patient_3,test_003,normal_tissue 5 | patient_4,test_004,tumor_tissue 6 | patient_5,test_005,normal_tissue 7 | patient_6,test_006,normal_tissue 8 | patient_7,test_007,normal_tissue 9 | patient_8,test_008,tumor_tissue 10 | patient_9,test_009,normal_tissue 11 | patient_10,test_010,tumor_tissue 12 | patient_11,test_011,tumor_tissue 13 | patient_12,test_012,normal_tissue 14 | patient_13,test_013,tumor_tissue 15 | patient_14,test_014,normal_tissue 16 | patient_15,test_015,normal_tissue 17 | patient_16,test_016,tumor_tissue 18 | patient_17,test_017,normal_tissue 19 | patient_18,test_018,normal_tissue 20 | patient_19,test_019,normal_tissue 21 | patient_20,test_020,normal_tissue 22 | patient_21,test_021,tumor_tissue 23 | patient_22,test_022,normal_tissue 24 | patient_23,test_023,normal_tissue 25 | patient_24,test_024,normal_tissue 26 | patient_25,test_025,normal_tissue 27 | patient_26,test_026,tumor_tissue 28 | patient_27,test_027,tumor_tissue 29 | patient_28,test_028,normal_tissue 30 | patient_29,test_029,tumor_tissue 31 | patient_30,test_030,tumor_tissue 32 | patient_31,test_031,normal_tissue 33 | patient_32,test_032,normal_tissue 34 | patient_33,test_033,tumor_tissue 35 | patient_34,test_034,normal_tissue 36 | patient_35,test_035,normal_tissue 37 | patient_36,test_036,normal_tissue 38 | patient_37,test_037,normal_tissue 39 | patient_38,test_038,tumor_tissue 40 | patient_39,test_039,normal_tissue 41 | patient_40,test_040,tumor_tissue 42 | patient_41,test_041,normal_tissue 43 | patient_42,test_042,normal_tissue 44 | patient_43,test_043,normal_tissue 45 | patient_44,test_044,normal_tissue 46 | patient_45,test_045,normal_tissue 47 | patient_46,test_046,tumor_tissue 48 | patient_47,test_047,normal_tissue 49 | patient_48,test_048,tumor_tissue 50 | patient_49,test_050,normal_tissue 51 | patient_50,test_051,tumor_tissue 52 | patient_51,test_052,tumor_tissue 53 | patient_52,test_053,normal_tissue 54 | patient_53,test_054,normal_tissue 55 | patient_54,test_055,normal_tissue 56 | patient_55,test_056,normal_tissue 57 | patient_56,test_057,normal_tissue 58 | patient_57,test_058,normal_tissue 59 | patient_58,test_059,normal_tissue 60 | patient_59,test_060,normal_tissue 61 | patient_60,test_061,tumor_tissue 62 | patient_61,test_062,normal_tissue 63 | patient_62,test_063,normal_tissue 64 | patient_63,test_064,tumor_tissue 65 | patient_64,test_065,tumor_tissue 66 | patient_65,test_066,tumor_tissue 67 | patient_66,test_067,normal_tissue 68 | patient_67,test_068,tumor_tissue 69 | patient_68,test_069,tumor_tissue 70 | patient_69,test_070,normal_tissue 71 | patient_70,test_071,tumor_tissue 72 | patient_71,test_072,normal_tissue 73 | patient_72,test_073,tumor_tissue 74 | patient_73,test_074,tumor_tissue 75 | patient_74,test_075,tumor_tissue 76 | patient_75,test_076,normal_tissue 77 | patient_76,test_077,normal_tissue 78 | patient_77,test_078,normal_tissue 79 | patient_78,test_079,tumor_tissue 80 | patient_79,test_080,normal_tissue 81 | patient_80,test_081,normal_tissue 82 | patient_81,test_082,tumor_tissue 83 | patient_82,test_083,normal_tissue 84 | patient_83,test_084,tumor_tissue 85 | patient_84,test_085,normal_tissue 86 | patient_85,test_086,normal_tissue 87 | patient_86,test_087,normal_tissue 88 | patient_87,test_088,normal_tissue 89 | patient_88,test_089,normal_tissue 90 | patient_89,test_090,tumor_tissue 91 | patient_90,test_091,normal_tissue 92 | patient_91,test_092,tumor_tissue 93 | patient_92,test_093,normal_tissue 94 | patient_93,test_094,tumor_tissue 95 | patient_94,test_095,normal_tissue 96 | patient_95,test_096,normal_tissue 97 | patient_96,test_097,tumor_tissue 98 | patient_97,test_098,normal_tissue 99 | patient_98,test_099,tumor_tissue 100 | patient_99,test_100,normal_tissue 101 | patient_100,test_101,normal_tissue 102 | patient_101,test_102,tumor_tissue 103 | patient_102,test_103,normal_tissue 104 | patient_103,test_104,tumor_tissue 105 | patient_104,test_105,tumor_tissue 106 | patient_105,test_106,normal_tissue 107 | patient_106,test_107,normal_tissue 108 | patient_107,test_108,tumor_tissue 109 | patient_108,test_109,normal_tissue 110 | patient_109,test_110,tumor_tissue 111 | patient_110,test_111,normal_tissue 112 | patient_111,test_112,normal_tissue 113 | patient_112,test_113,tumor_tissue 114 | patient_113,test_114,tumor_tissue 115 | patient_114,test_115,normal_tissue 116 | patient_115,test_116,tumor_tissue 117 | patient_116,test_117,tumor_tissue 118 | patient_117,test_118,normal_tissue 119 | patient_118,test_119,normal_tissue 120 | patient_119,test_120,normal_tissue 121 | patient_120,test_121,tumor_tissue 122 | patient_121,test_122,tumor_tissue 123 | patient_122,test_123,normal_tissue 124 | patient_123,test_124,normal_tissue 125 | patient_124,test_125,normal_tissue 126 | patient_125,test_126,normal_tissue 127 | patient_126,test_127,normal_tissue 128 | patient_127,test_128,normal_tissue 129 | patient_128,test_129,normal_tissue 130 | patient_129,test_130,normal_tissue 131 | -------------------------------------------------------------------------------- /tester_transmil.py: -------------------------------------------------------------------------------- 1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | import torch.optim as optim 7 | import pdb 8 | import torch.nn.functional as F 9 | 10 | import pandas as pd 11 | import numpy as np 12 | 13 | from sklearn.preprocessing import label_binarize 14 | from sklearn.metrics import roc_auc_score, roc_curve 15 | from sklearn.metrics import auc as calc_auc 16 | 17 | from models.TransMIL import TransMIL 18 | from utils import * 19 | 20 | 21 | class Accuracy_Logger(object): 22 | """Accuracy logger""" 23 | def __init__(self, n_classes): 24 | super().__init__() 25 | self.n_classes = n_classes 26 | self.initialize() 27 | 28 | def initialize(self): 29 | self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] 30 | 31 | def log(self, Y_hat, Y): 32 | Y_hat = int(Y_hat) 33 | Y = int(Y) 34 | self.data[Y]["count"] += 1 35 | self.data[Y]["correct"] += (Y_hat == Y) 36 | 37 | def log_batch(self, Y_hat, Y): 38 | Y_hat = np.array(Y_hat).astype(int) 39 | Y = np.array(Y).astype(int) 40 | for label_class in np.unique(Y): 41 | cls_mask = Y == label_class 42 | self.data[label_class]["count"] += cls_mask.sum() 43 | self.data[label_class]["correct"] += (Y_hat[cls_mask] == Y[cls_mask]).sum() 44 | 45 | def get_summary(self, c): 46 | count = self.data[c]["count"] 47 | correct = self.data[c]["correct"] 48 | 49 | if count == 0: 50 | acc = None 51 | else: 52 | acc = float(correct) / count 53 | 54 | return acc, correct, count 55 | 56 | def summary(model, loader, n_classes): 57 | acc_logger = Accuracy_Logger(n_classes=n_classes) 58 | model.eval() 59 | test_loss = 0. 60 | test_error = 0. 61 | 62 | all_probs = np.zeros((len(loader), n_classes)) 63 | all_labels = np.zeros(len(loader)) 64 | all_preds = np.zeros(len(loader)) 65 | 66 | slide_ids = loader.dataset.slide_data['slide_id'] 67 | patient_results = {} 68 | for batch_idx, (data, label) in enumerate(loader): 69 | data, label = data.to(device), label.to(device) 70 | slide_id = slide_ids.iloc[batch_idx] 71 | with torch.inference_mode(): 72 | logits, Y_prob, Y_hat, h0 = model(data) 73 | 74 | acc_logger.log(Y_hat, label) 75 | 76 | probs = Y_prob.cpu().numpy() 77 | 78 | all_probs[batch_idx] = probs 79 | all_labels[batch_idx] = label.item() 80 | all_preds[batch_idx] = Y_hat.item() 81 | 82 | patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}}) 83 | 84 | error = calculate_error(Y_hat, label) 85 | test_error += error 86 | 87 | del data 88 | test_error /= len(loader) 89 | 90 | aucs = [] 91 | if len(np.unique(all_labels)) == 1: 92 | auc_score = -1 93 | 94 | else: 95 | if n_classes == 2: 96 | auc_score = roc_auc_score(all_labels, all_probs[:, 1]) 97 | else: 98 | binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)]) 99 | for class_idx in range(n_classes): 100 | if class_idx in all_labels: 101 | fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx]) 102 | aucs.append(calc_auc(fpr, tpr)) 103 | else: 104 | aucs.append(float('nan')) 105 | binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)]) 106 | fpr, tpr, _ = roc_curve(binary_labels.ravel(), all_probs.ravel()) 107 | auc_score = calc_auc(fpr, tpr) 108 | 109 | results_dict = {'slide_id': slide_ids, 'Y': all_labels, 'Y_hat': all_preds} 110 | for c in range(n_classes): 111 | results_dict.update({'p_{}'.format(c): all_probs[:,c]}) 112 | df = pd.DataFrame(results_dict) 113 | return patient_results, test_error, auc_score, df, acc_logger 114 | 115 | if __name__ == "__main__": 116 | save_dir = './results/camelyon16to17unseen_transmil_simclr_100/' 117 | csv_path = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/camelyon17_unseen.csv' 118 | data_dir = '/data2/clh/camelyon17/resnet18_simclr/' 119 | weight_dir = '/data1/ceiling/workspace/MIL/AttriMIL/save_weights/camelyon16_transmil_simclr_100/' 120 | split_path = '/data1/ceiling/workspace/AttriMIL_v2/splits/unitopatho/' 121 | dataset = Generic_MIL_Dataset(csv_path = csv_path, 122 | data_dir = data_dir, 123 | shuffle = False, 124 | print_info = True, 125 | label_dict = {'normal_tissue':0, 'tumor_tissue':1}, 126 | patient_strat=False, 127 | ignore=[]) 128 | os.makedirs(save_dir, exist_ok=True) 129 | model = TransMIL(dim=512, n_classes=2).cuda() 130 | folds = [0, 1, 2, 3, 4] 131 | ckpt_paths = [os.path.join(weight_dir, 's_{}_checkpoint.pt'.format(fold)) for fold in folds] 132 | all_results = [] 133 | all_auc = [] 134 | all_acc = [] 135 | csv_path = [split_path + 'splits_{}.csv'.format(i) for i in range(4)] 136 | for ckpt_idx in range(len(ckpt_paths)): 137 | # train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False, csv_path=csv_path[ckpt_idx]) 138 | # loader = get_split_loader(test_dataset, testing = False) 139 | loader = get_simple_loader(dataset) 140 | model.load_state_dict(torch.load(ckpt_paths[ckpt_idx])) 141 | patient_results, test_error, auc, df, acc_logger = summary(model, loader, n_classes=2) 142 | all_results.append(all_results) 143 | all_auc.append(auc) 144 | all_acc.append(1-test_error) 145 | df.to_csv(os.path.join(save_dir, 'fold_{}.csv'.format(folds[ckpt_idx])), index=False) 146 | 147 | final_df = pd.DataFrame({'folds': folds, 'test_auc': all_auc, 'test_acc': all_acc}) 148 | save_name = 'summary.csv' 149 | final_df.to_csv(os.path.join(save_dir, save_name)) 150 | -------------------------------------------------------------------------------- /tester_mil.py: -------------------------------------------------------------------------------- 1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | import torch.optim as optim 7 | import pdb 8 | import torch.nn.functional as F 9 | 10 | import pandas as pd 11 | import numpy as np 12 | 13 | from sklearn.preprocessing import label_binarize 14 | from sklearn.metrics import roc_auc_score, roc_curve 15 | from sklearn.metrics import auc as calc_auc 16 | 17 | from models.MIL import * 18 | from utils import * 19 | 20 | 21 | class Accuracy_Logger(object): 22 | """Accuracy logger""" 23 | def __init__(self, n_classes): 24 | super().__init__() 25 | self.n_classes = n_classes 26 | self.initialize() 27 | 28 | def initialize(self): 29 | self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] 30 | 31 | def log(self, Y_hat, Y): 32 | Y_hat = int(Y_hat) 33 | Y = int(Y) 34 | self.data[Y]["count"] += 1 35 | self.data[Y]["correct"] += (Y_hat == Y) 36 | 37 | def log_batch(self, Y_hat, Y): 38 | Y_hat = np.array(Y_hat).astype(int) 39 | Y = np.array(Y).astype(int) 40 | for label_class in np.unique(Y): 41 | cls_mask = Y == label_class 42 | self.data[label_class]["count"] += cls_mask.sum() 43 | self.data[label_class]["correct"] += (Y_hat[cls_mask] == Y[cls_mask]).sum() 44 | 45 | def get_summary(self, c): 46 | count = self.data[c]["count"] 47 | correct = self.data[c]["correct"] 48 | 49 | if count == 0: 50 | acc = None 51 | else: 52 | acc = float(correct) / count 53 | 54 | return acc, correct, count 55 | 56 | def summary(model, loader, n_classes): 57 | acc_logger = Accuracy_Logger(n_classes=n_classes) 58 | model.eval() 59 | test_loss = 0. 60 | test_error = 0. 61 | 62 | all_probs = np.zeros((len(loader), n_classes)) 63 | all_labels = np.zeros(len(loader)) 64 | all_preds = np.zeros(len(loader)) 65 | 66 | slide_ids = loader.dataset.slide_data['slide_id'] 67 | patient_results = {} 68 | for batch_idx, (data, label) in enumerate(loader): 69 | data, label = data.to(device), label.to(device) 70 | slide_id = slide_ids.iloc[batch_idx] 71 | with torch.inference_mode(): 72 | logits, Y_prob, Y_hat, y_probs, _ = model(data) 73 | acc_logger.log(Y_hat, label) 74 | 75 | probs = Y_prob.cpu().numpy() 76 | 77 | all_probs[batch_idx] = probs 78 | all_labels[batch_idx] = label.item() 79 | all_preds[batch_idx] = Y_hat.item() 80 | 81 | patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}}) 82 | 83 | error = calculate_error(Y_hat, label) 84 | test_error += error 85 | 86 | del data 87 | test_error /= len(loader) 88 | 89 | aucs = [] 90 | if len(np.unique(all_labels)) == 1: 91 | auc_score = -1 92 | 93 | else: 94 | if n_classes == 2: 95 | auc_score = roc_auc_score(all_labels, all_probs[:, 1]) 96 | else: 97 | binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)]) 98 | for class_idx in range(n_classes): 99 | if class_idx in all_labels: 100 | fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx]) 101 | aucs.append(calc_auc(fpr, tpr)) 102 | else: 103 | aucs.append(float('nan')) 104 | binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)]) 105 | fpr, tpr, _ = roc_curve(binary_labels.ravel(), all_probs.ravel()) 106 | auc_score = calc_auc(fpr, tpr) 107 | 108 | results_dict = {'slide_id': slide_ids, 'Y': all_labels, 'Y_hat': all_preds} 109 | for c in range(n_classes): 110 | results_dict.update({'p_{}'.format(c): all_probs[:,c]}) 111 | df = pd.DataFrame(results_dict) 112 | return patient_results, test_error, auc_score, df, acc_logger 113 | 114 | if __name__ == "__main__": 115 | save_dir = './results/camelyon16to17unseen_rnn_simclr_100/' 116 | csv_path = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/camelyon17_unseen.csv' 117 | data_dir = '/data2/clh/camelyon17/resnet18_simclr/' 118 | weight_dir = '/data1/ceiling/workspace/MIL/AttriMIL/save_weights/camelyon16_rnn_simclr_100/' 119 | split_path = '/data1/ceiling/workspace/AttriMIL_v2/splits/unitopatho/' 120 | # {'NORM':0, 'HP':1, 'TA.HG':2,'TA.LG':3, 'TVA.HG':4, 'TVA.LG':5} 121 | dataset = Generic_MIL_Dataset(csv_path = csv_path, 122 | data_dir = data_dir, 123 | shuffle = False, 124 | print_info = True, 125 | label_dict = {'normal_tissue':0, 'tumor_tissue':1}, 126 | patient_strat=False, 127 | ignore=[]) 128 | os.makedirs(save_dir, exist_ok=True) 129 | model = MIL_RNN(embed_dim=512, n_classes=2).cuda() 130 | folds = [0, 1, 2, 3, 4] 131 | ckpt_paths = [os.path.join(weight_dir, 's_{}_checkpoint.pt'.format(fold)) for fold in folds] 132 | all_results = [] 133 | all_auc = [] 134 | all_acc = [] 135 | csv_path = [split_path + 'splits_{}.csv'.format(i) for i in range(4)] 136 | for ckpt_idx in range(len(ckpt_paths)): 137 | # train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False, csv_path=csv_path[ckpt_idx]) 138 | # loader = get_split_loader(test_dataset, testing = False) 139 | loader = get_simple_loader(dataset) 140 | model.load_state_dict(torch.load(ckpt_paths[ckpt_idx])) 141 | patient_results, test_error, auc, df, acc_logger = summary(model, loader, n_classes=2) 142 | all_results.append(all_results) 143 | all_auc.append(auc) 144 | all_acc.append(1-test_error) 145 | df.to_csv(os.path.join(save_dir, 'fold_{}.csv'.format(folds[ckpt_idx])), index=False) 146 | 147 | final_df = pd.DataFrame({'folds': folds, 'test_auc': all_auc, 'test_acc': all_acc}) 148 | save_name = 'summary.csv' 149 | final_df.to_csv(os.path.join(save_dir, save_name)) 150 | -------------------------------------------------------------------------------- /tester_dsmil.py: -------------------------------------------------------------------------------- 1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | import torch.optim as optim 7 | import pdb 8 | import torch.nn.functional as F 9 | 10 | import pandas as pd 11 | import numpy as np 12 | 13 | from sklearn.preprocessing import label_binarize 14 | from sklearn.metrics import roc_auc_score, roc_curve 15 | from sklearn.metrics import auc as calc_auc 16 | 17 | from models.DSMIL import MILNet 18 | from utils import * 19 | 20 | 21 | class Accuracy_Logger(object): 22 | """Accuracy logger""" 23 | def __init__(self, n_classes): 24 | super().__init__() 25 | self.n_classes = n_classes 26 | self.initialize() 27 | 28 | def initialize(self): 29 | self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] 30 | 31 | def log(self, Y_hat, Y): 32 | Y_hat = int(Y_hat) 33 | Y = int(Y) 34 | self.data[Y]["count"] += 1 35 | self.data[Y]["correct"] += (Y_hat == Y) 36 | 37 | def log_batch(self, Y_hat, Y): 38 | Y_hat = np.array(Y_hat).astype(int) 39 | Y = np.array(Y).astype(int) 40 | for label_class in np.unique(Y): 41 | cls_mask = Y == label_class 42 | self.data[label_class]["count"] += cls_mask.sum() 43 | self.data[label_class]["correct"] += (Y_hat[cls_mask] == Y[cls_mask]).sum() 44 | 45 | def get_summary(self, c): 46 | count = self.data[c]["count"] 47 | correct = self.data[c]["correct"] 48 | 49 | if count == 0: 50 | acc = None 51 | else: 52 | acc = float(correct) / count 53 | 54 | return acc, correct, count 55 | 56 | def summary(model, loader, n_classes): 57 | acc_logger = Accuracy_Logger(n_classes=n_classes) 58 | model.eval() 59 | test_loss = 0. 60 | test_error = 0. 61 | 62 | all_probs = np.zeros((len(loader), n_classes)) 63 | all_labels = np.zeros(len(loader)) 64 | all_preds = np.zeros(len(loader)) 65 | 66 | slide_ids = loader.dataset.slide_data['slide_id'] 67 | patient_results = {} 68 | for batch_idx, (data, label) in enumerate(loader): 69 | data, label = data.to(device), label.to(device) 70 | slide_id = slide_ids.iloc[batch_idx] 71 | with torch.inference_mode(): 72 | ins_prediction, bag_prediction, _, _ = model(data) 73 | Y_hat = torch.topk(bag_prediction.view(1, -1), 1, dim = 1)[1] 74 | acc_logger.log(Y_hat, label) 75 | 76 | max_prediction, _ = torch.max(ins_prediction, 0) 77 | Y_prob = F.softmax(bag_prediction, dim=-1) 78 | probs = Y_prob.cpu().numpy() 79 | 80 | all_probs[batch_idx] = probs 81 | all_labels[batch_idx] = label.item() 82 | all_preds[batch_idx] = Y_hat.item() 83 | 84 | patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}}) 85 | 86 | error = calculate_error(Y_hat, label) 87 | test_error += error 88 | 89 | del data 90 | test_error /= len(loader) 91 | 92 | aucs = [] 93 | if len(np.unique(all_labels)) == 1: 94 | auc_score = -1 95 | 96 | else: 97 | if n_classes == 2: 98 | auc_score = roc_auc_score(all_labels, all_probs[:, 1]) 99 | else: 100 | binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)]) 101 | for class_idx in range(n_classes): 102 | if class_idx in all_labels: 103 | fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx]) 104 | aucs.append(calc_auc(fpr, tpr)) 105 | else: 106 | aucs.append(float('nan')) 107 | binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)]) 108 | fpr, tpr, _ = roc_curve(binary_labels.ravel(), all_probs.ravel()) 109 | auc_score = calc_auc(fpr, tpr) 110 | 111 | results_dict = {'slide_id': slide_ids, 'Y': all_labels, 'Y_hat': all_preds} 112 | for c in range(n_classes): 113 | results_dict.update({'p_{}'.format(c): all_probs[:,c]}) 114 | df = pd.DataFrame(results_dict) 115 | return patient_results, test_error, auc_score, df, acc_logger 116 | 117 | if __name__ == "__main__": 118 | save_dir = './results/camelyon16to17unseen_dsmil_simclr_100/' 119 | csv_path = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/camelyon17_unseen.csv' 120 | data_dir = '/data2/clh/camelyon17/resnet18_simclr/' 121 | weight_dir = '/data1/ceiling/workspace/MIL/AttriMIL/save_weights/camelyon16_dsmil_simclr/' 122 | split_path = '/data1/ceiling/workspace/AttriMIL_v2/splits/unitopatho/' 123 | # {'NORM':0, 'HP':1, 'TA.HG':2,'TA.LG':3, 'TVA.HG':4, 'TVA.LG':5} 124 | # {'LUAD':0, 'LUSC':1} 125 | # {'normal_tissue':0, 'tumor_tissue':1} 126 | dataset = Generic_MIL_Dataset(csv_path = csv_path, 127 | data_dir = data_dir, 128 | shuffle = False, 129 | print_info = True, 130 | label_dict = {'normal_tissue':0, 'tumor_tissue':1}, 131 | patient_strat=False, 132 | ignore=[]) 133 | os.makedirs(save_dir, exist_ok=True) 134 | model = MILNet(feature_dim=512, n_classes=2).cuda() 135 | folds = [0, 1, 2, 3, 4] 136 | ckpt_paths = [os.path.join(weight_dir, 's_{}_checkpoint.pt'.format(fold)) for fold in folds] 137 | all_results = [] 138 | all_auc = [] 139 | all_acc = [] 140 | csv_path = [split_path + 'splits_{}.csv'.format(i) for i in range(4)] 141 | for ckpt_idx in range(len(ckpt_paths)): 142 | # train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False, csv_path=csv_path[ckpt_idx]) 143 | # loader = get_split_loader(test_dataset, testing = False) 144 | loader = get_simple_loader(dataset) 145 | model.load_state_dict(torch.load(ckpt_paths[ckpt_idx])) 146 | patient_results, test_error, auc, df, acc_logger = summary(model, loader, n_classes=2) 147 | all_results.append(all_results) 148 | all_auc.append(auc) 149 | all_acc.append(1-test_error) 150 | df.to_csv(os.path.join(save_dir, 'fold_{}.csv'.format(folds[ckpt_idx])), index=False) 151 | 152 | final_df = pd.DataFrame({'folds': folds, 'test_auc': all_auc, 'test_acc': all_acc}) 153 | save_name = 'summary.csv' 154 | final_df.to_csv(os.path.join(save_dir, save_name)) 155 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AttriMIL: Revisiting attention-based multiple instance learning for whole-slide pathological image classification from a perspective of instance attributes 2 | The official implementation of AttriMIL (published at _Medical Image Analysis 2025_). 3 | 4 | ## 1. Introduction 5 | ### 1.1 Background 6 | WSI classification typically requires the MIL framework to perform two key tasks: bag classification and instance discrimination, which correspond to clinical diagnosis and the localization of disease-positive regions, respectively. Among various MIL architectures, attention-based MIL frameworks address both tasks simultaneously under weak supervision and thus dominate pathology image analysis. However, attention-based MIL frameworks face two challenges: 7 | 8 | (i) The incorrect measure of pathological attributes based on attention, which may confuse diagnosis. 9 | 10 | (ii) The negligence of modeling intra-slide and inter-slide interaction, which is essential to obtain robust semantic representation of instances. 11 | 12 |

13 |
14 | 15 | Figure 1. Illustration of the workflow of attention-based MIL frameworks and the attribute scoring mechanism in AttriMIL. 16 | 17 |

18 | 19 | To overcome these issues, we propose a novel framework named attribute-aware multiple instance learning (AttriMIL) tailored for pathological image classification. 20 | 21 | (i) To identify the pathological attributes of instances, AttriMIL employs a multi-branch attribute scoring mechanism, where each branch integrates attention pooling with the classification head, deriving precise estimation of each instance's contribution to the bag prediction. 22 | 23 | (ii) Considering the intrinsic correlations between image patches in WSIs, we introduce two constraints to enhance the MIL framework's sensitivity to instance attributes. 24 | 25 | (iii) Inspired by parameter-efficient fine-tuning techniques, we design a pathology adaptive learning strategy for efficient pathological feature extraction. This optimized backbone empowers AttriMIL to model instance correlations across multiple feature levels. 26 | 27 | ### 1.2. Framework 28 | Figure 2 presents an overview of AttriMIL, which comprises three main components: (1) a pathology adaptive backbone for extracting optimized instance-level features, (2) multi-branch attribute scoring mechanism with attribute constraints, and (3) score aggregation and bag prediction. In this section, we first revisit multiple instance learning and attention-based frameworks, followed by a detailed description of AttriMIL. 29 | 30 |

31 |
32 | 33 | Figure 2. Overview of the proposed AttriMIL. 34 | 35 |

36 | 37 | ### 1.3 Performance 38 | AttriMIL achieves the state-of-the-art performance on four benchmarks, showcasing the superior bag classification performance, generalization ability, and instance localization capability. Additionally, AttriMIL is capable of identifying bag with a small proportion of target regions. 39 | 40 |

41 |
42 | 43 | Figure 3. Quantative comparison of the state-of-the-art WSI classification algorithms. 44 | 45 |

46 | 47 | ## 2. Quick Start 48 | ### 2.1 Installation 49 | AttriMIL is extended from [CLAM]([https://github.com/microsoft/CvT](https://github.com/mahmoodlab/CLAM)).Assuming that you have installed PyTorch and TorchVision, if not, please follow the [officiall instruction](https://pytorch.org/) to install them firstly. 50 | Intall the dependencies using cmd: 51 | ``` sh 52 | conda env create -f env.yml 53 | ``` 54 | The code is developed and tested using pytorch 1.10.0. Other versions of pytorch are not fully tested. 55 | 56 | ### 2.2 Data preparation 57 | Data preparation based on CLAM, including tissue segmentation, patching, and feature extraction. In comparison to traditional process, we introduce a neiboorhood generation process and use pathology-adaptive learning for instance-level feature extraction. 58 | ``` sh 59 | python create_3coords.py # generate neighbor indices 60 | python coord_to_feature.py # incorporate the indices to feature (h5) files 61 | ``` 62 | The final data following the structure: 63 | ```bash 64 | FEATURES_DIRECTORY/ 65 | ├── h5_files 66 | ├── slide_1.h5 67 | ├── slide_2.h5 68 | └── ... 69 | └── pt_files 70 | ├── slide_1.pt 71 | ├── slide_2.pt 72 | └── ... 73 | ``` 74 | where each .h5 file contains an array of extracted features along with their patch coordinates and neighbor indices. 75 | 76 | ### 2.3 Pretrained Weights 77 | Pretrained weights are based on ResNet18 ImageNet, ResNet18 SimCLR. We follow [DSMIL](https://github.com/binli123/dsmil-wsi/tree/master) for generating the SSL features. Surely, you can also use other pre-trained models. 78 | 79 | ### 2.4 Training and Testing 80 | Training your AttriMIL: 81 | ``` sh 82 | python trainer_attrimil_abmil.py 83 | ``` 84 | Note that, the AttriMIL+DSMIL and AttriMIL+TransMIL will be released soon. 85 | 86 | ``` sh 87 | python tester_attrimil_abmil.py 88 | ``` 89 | The visual results can directly capture the disease-positive regions, which is encouraging. 90 | 91 |

92 |
93 | 94 | Figure 4. Qualitative comparison of the state-of-the-art WSI classification algorithms. 95 | 96 |

97 | 98 | ## 3. Citation 99 | If you find this work or code is helpful in your research, please cite: 100 | 101 | ``` 102 | @article{cai2025attrimil, 103 | title={AttriMIL: Revisiting attention-based multiple instance learning for whole-slide pathological image classification from a perspective of instance attributes}, 104 | author={Cai, Linghan and Huang, Shenjin and Zhang, Ye and Lu, Jinpeng and Zhang, Yongbing}, 105 | journal={Medical Image Analysis}, 106 | pages={103631}, 107 | year={2025}, 108 | publisher={Elsevier} 109 | } 110 | ``` 111 | 112 | ## 4. Contributing 113 | Thanks to the following work for improving our project: 114 | - CLAM: [https://github.com/mahmoodlab/CLAM](https://github.com/mahmoodlab/CLAM) 115 | - DSMIL: [https://github.com/binli123/dsmil-wsi/tree/master](https://github.com/binli123/dsmil-wsi/tree/master) 116 | - MambaMIL: [https://github.com/isyangshu/MambaMIL](https://github.com/isyangshu/MambaMIL) 117 | 118 | ## 5. License 119 | Distributed under the Apache 2.0 License. See LICENSE for more information. 120 | -------------------------------------------------------------------------------- /tester_attrimil_abmil.py: -------------------------------------------------------------------------------- 1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | import torch.optim as optim 7 | import pdb 8 | import torch.nn.functional as F 9 | 10 | import pandas as pd 11 | import numpy as np 12 | 13 | from sklearn.preprocessing import label_binarize 14 | from sklearn.metrics import roc_auc_score, roc_curve 15 | from sklearn.metrics import auc as calc_auc 16 | 17 | from models.AttriMIL import AttriMIL 18 | from utils import * 19 | 20 | 21 | class Accuracy_Logger(object): 22 | """Accuracy logger""" 23 | def __init__(self, n_classes): 24 | super().__init__() 25 | self.n_classes = n_classes 26 | self.initialize() 27 | 28 | def initialize(self): 29 | self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] 30 | 31 | def log(self, Y_hat, Y): 32 | Y_hat = int(Y_hat) 33 | Y = int(Y) 34 | self.data[Y]["count"] += 1 35 | self.data[Y]["correct"] += (Y_hat == Y) 36 | 37 | def log_batch(self, Y_hat, Y): 38 | Y_hat = np.array(Y_hat).astype(int) 39 | Y = np.array(Y).astype(int) 40 | for label_class in np.unique(Y): 41 | cls_mask = Y == label_class 42 | self.data[label_class]["count"] += cls_mask.sum() 43 | self.data[label_class]["correct"] += (Y_hat[cls_mask] == Y[cls_mask]).sum() 44 | 45 | def get_summary(self, c): 46 | count = self.data[c]["count"] 47 | correct = self.data[c]["correct"] 48 | 49 | if count == 0: 50 | acc = None 51 | else: 52 | acc = float(correct) / count 53 | 54 | return acc, correct, count 55 | 56 | def summary(model, loader, n_classes): 57 | acc_logger = Accuracy_Logger(n_classes=n_classes) 58 | model.eval() 59 | test_loss = 0. 60 | test_error = 0. 61 | 62 | all_probs = np.zeros((len(loader), n_classes)) 63 | all_labels = np.zeros(len(loader)) 64 | all_preds = np.zeros(len(loader)) 65 | 66 | slide_ids = loader.dataset.slide_data['slide_id'] 67 | patient_results = {} 68 | for batch_idx, (data, label) in enumerate(loader): 69 | data, label = data.to(device), label.to(device) 70 | slide_id = slide_ids.iloc[batch_idx] 71 | with torch.inference_mode(): 72 | logits, Y_prob, Y_hat, attribute_score, results_dict = model(data) 73 | acc_logger.log(Y_hat, label) 74 | 75 | probs = Y_prob.cpu().numpy() 76 | 77 | all_probs[batch_idx] = probs 78 | all_labels[batch_idx] = label.item() 79 | all_preds[batch_idx] = Y_hat.item() 80 | 81 | patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}}) 82 | 83 | error = calculate_error(Y_hat, label) 84 | test_error += error 85 | 86 | del data 87 | test_error /= len(loader) 88 | 89 | aucs = [] 90 | if len(np.unique(all_labels)) == 1: 91 | auc_score = -1 92 | 93 | else: 94 | if n_classes == 2: 95 | auc_score = roc_auc_score(all_labels, all_probs[:, 1]) 96 | else: 97 | binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)]) 98 | for class_idx in range(n_classes): 99 | if class_idx in all_labels: 100 | fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx]) 101 | aucs.append(calc_auc(fpr, tpr)) 102 | else: 103 | aucs.append(float('nan')) 104 | binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)]) 105 | fpr, tpr, _ = roc_curve(binary_labels.ravel(), all_probs.ravel()) 106 | auc_score = calc_auc(fpr, tpr) 107 | 108 | results_dict = {'slide_id': slide_ids, 'Y': all_labels, 'Y_hat': all_preds} 109 | for c in range(n_classes): 110 | results_dict.update({'p_{}'.format(c): all_probs[:,c]}) 111 | df = pd.DataFrame(results_dict) 112 | return patient_results, test_error, auc_score, df, acc_logger 113 | 114 | if __name__ == "__main__": 115 | import time 116 | start_time = time.time() 117 | save_dir = './results/timing_test/' 118 | csv_path = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/camelyon16_test.csv' 119 | data_dir = '/data2/clh/camelyon16/resnet18_simclr/' 120 | weight_dir = '/data1/ceiling/workspace/MIL/AttriMIL/save_weights/camelyon16_attrimil_simclr/' 121 | split_path = '/data1/ceiling/workspace/AttriMIL_v2/splits/camelyon16/' 122 | # {'NORM':0, 'HP':1, 'TA.HG':2,'TA.LG':3, 'TVA.HG':4, 'TVA.LG':5} 123 | # {'LUAD':0, 'LUSC':1} 124 | # {'normal_tissue':0, 'tumor_tissue':1} 125 | dataset = Generic_MIL_Dataset(csv_path = csv_path, 126 | data_dir = data_dir, 127 | shuffle = False, 128 | print_info = True, 129 | label_dict = {'normal_tissue':0, 'tumor_tissue':1}, 130 | patient_strat=False, 131 | ignore=[]) 132 | os.makedirs(save_dir, exist_ok=True) 133 | model = AttriMIL(dim=512, n_classes=2).cuda() 134 | # folds = [0, 1, 2, 3, 4] 135 | folds = [0] 136 | ckpt_paths = [os.path.join(weight_dir, 's_{}_checkpoint.pt'.format(fold)) for fold in folds] 137 | all_results = [] 138 | all_auc = [] 139 | all_acc = [] 140 | csv_path = [split_path + 'splits_{}.csv'.format(i) for i in range(4)] 141 | for ckpt_idx in range(len(ckpt_paths)): 142 | # train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False, csv_path=csv_path[ckpt_idx]) 143 | # loader = get_split_loader(test_dataset, testing = False) 144 | loader = get_simple_loader(dataset) 145 | model.load_state_dict(torch.load(ckpt_paths[ckpt_idx])) 146 | patient_results, test_error, auc, df, acc_logger = summary(model, loader, n_classes=2) 147 | all_results.append(all_results) 148 | all_auc.append(auc) 149 | all_acc.append(1-test_error) 150 | df.to_csv(os.path.join(save_dir, 'fold_{}.csv'.format(folds[ckpt_idx])), index=False) 151 | 152 | final_df = pd.DataFrame({'folds': folds, 'test_auc': all_auc, 'test_acc': all_acc}) 153 | save_name = 'summary.csv' 154 | final_df.to_csv(os.path.join(save_dir, save_name)) 155 | end_time = time.time() 156 | elapsed_time = end_time - start_time 157 | print(f"计时结果: 耗时{elapsed_time}秒") 158 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import pdb 6 | 7 | import torch 8 | import numpy as np 9 | import torch.nn as nn 10 | from torchvision import transforms 11 | from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler, SequentialSampler, sampler 12 | import torch.optim as optim 13 | import pdb 14 | import torch.nn.functional as F 15 | import math 16 | from itertools import islice 17 | import collections 18 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | class SubsetSequentialSampler(Sampler): 21 | """Samples elements sequentially from a given list of indices, without replacement. 22 | 23 | Arguments: 24 | indices (sequence): a sequence of indices 25 | """ 26 | def __init__(self, indices): 27 | self.indices = indices 28 | 29 | def __iter__(self): 30 | return iter(self.indices) 31 | 32 | def __len__(self): 33 | return len(self.indices) 34 | 35 | def collate_MIL(batch): 36 | img = torch.cat([item[0] for item in batch], dim = 0) 37 | label = torch.LongTensor([item[1] for item in batch]) 38 | return [img, label] 39 | 40 | def collate_MIL_coords(batch): # 训练时候,我需要有coords函数 41 | img = torch.cat([item[0] for item in batch], dim=0) 42 | label = torch.LongTensor([item[1] for item in batch]) 43 | coords = torch.cat([torch.from_numpy(item[2]) for item in batch], dim=0) 44 | nearest = torch.cat([torch.from_numpy(item[3]) for item in batch], dim=0) 45 | return [img, label, coords, nearest] 46 | 47 | def collate_features(batch): 48 | img = torch.cat([item[0] for item in batch], dim = 0) 49 | coords = np.vstack([item[1] for item in batch]) 50 | return [img, coords] 51 | 52 | 53 | def get_simple_loader(dataset, batch_size=1, num_workers=1): 54 | kwargs = {'num_workers': 4, 'pin_memory': False, 'num_workers': num_workers} if device.type == "cuda" else {} 55 | loader = DataLoader(dataset, batch_size=batch_size, sampler = sampler.SequentialSampler(dataset), collate_fn = collate_MIL, **kwargs) 56 | return loader 57 | 58 | def get_split_loader(split_dataset, training = False, testing = False, weighted = False): 59 | """ 60 | return either the validation loader or training loader 61 | """ 62 | kwargs = {'num_workers': 4} if device.type == "cuda" else {} 63 | if not testing: 64 | if training: 65 | if weighted: 66 | weights = make_weights_for_balanced_classes_split(split_dataset) 67 | loader = DataLoader(split_dataset, batch_size=1, sampler = WeightedRandomSampler(weights, len(weights)), collate_fn = collate_MIL_coords, **kwargs) 68 | else: 69 | loader = DataLoader(split_dataset, batch_size=1, sampler = RandomSampler(split_dataset), collate_fn = collate_MIL_coords, **kwargs) 70 | else: 71 | loader = DataLoader(split_dataset, batch_size=1, sampler = SequentialSampler(split_dataset), collate_fn = collate_MIL_coords, **kwargs) 72 | 73 | else: 74 | ids = np.random.choice(np.arange(len(split_dataset), int(len(split_dataset)*0.1)), replace = False) 75 | loader = DataLoader(split_dataset, batch_size=1, sampler = SubsetSequentialSampler(ids), collate_fn = collate_MIL, **kwargs ) 76 | 77 | return loader 78 | 79 | def get_optim(model, args): 80 | if args.opt == "adam": 81 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.reg) 82 | elif args.opt == 'sgd': 83 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.reg) 84 | else: 85 | raise NotImplementedError 86 | return optimizer 87 | 88 | def print_network(net): 89 | num_params = 0 90 | num_params_train = 0 91 | print(net) 92 | 93 | for param in net.parameters(): 94 | n = param.numel() 95 | num_params += n 96 | if param.requires_grad: 97 | num_params_train += n 98 | 99 | print('Total number of parameters: %d' % num_params) 100 | print('Total number of trainable parameters: %d' % num_params_train) 101 | 102 | 103 | def generate_split(cls_ids, val_num, test_num, samples, n_splits = 5, 104 | seed = 7, label_frac = 1.0, custom_test_ids = None): 105 | indices = np.arange(samples).astype(int) 106 | 107 | if custom_test_ids is not None: 108 | indices = np.setdiff1d(indices, custom_test_ids) 109 | 110 | np.random.seed(seed) 111 | for i in range(n_splits): 112 | all_val_ids = [] 113 | all_test_ids = [] 114 | sampled_train_ids = [] 115 | 116 | if custom_test_ids is not None: # pre-built test split, do not need to sample 117 | all_test_ids.extend(custom_test_ids) 118 | 119 | for c in range(len(val_num)): 120 | possible_indices = np.intersect1d(cls_ids[c], indices) #all indices of this class 121 | val_ids = np.random.choice(possible_indices, val_num[c], replace = False) # validation ids 122 | 123 | remaining_ids = np.setdiff1d(possible_indices, val_ids) #indices of this class left after validation 124 | all_val_ids.extend(val_ids) 125 | 126 | if custom_test_ids is None: # sample test split 127 | 128 | test_ids = np.random.choice(remaining_ids, test_num[c], replace = False) 129 | remaining_ids = np.setdiff1d(remaining_ids, test_ids) 130 | all_test_ids.extend(test_ids) 131 | 132 | if label_frac == 1: 133 | sampled_train_ids.extend(remaining_ids) 134 | 135 | else: 136 | sample_num = math.ceil(len(remaining_ids) * label_frac) 137 | slice_ids = np.arange(sample_num) 138 | sampled_train_ids.extend(remaining_ids[slice_ids]) 139 | 140 | yield sampled_train_ids, all_val_ids, all_test_ids 141 | 142 | 143 | def nth(iterator, n, default=None): 144 | if n is None: 145 | return collections.deque(iterator, maxlen=0) 146 | else: 147 | return next(islice(iterator,n, None), default) 148 | 149 | def calculate_error(Y_hat, Y): 150 | error = 1. - Y_hat.float().eq(Y.float()).float().mean().item() 151 | 152 | return error 153 | 154 | def make_weights_for_balanced_classes_split(dataset): 155 | N = float(len(dataset)) 156 | weight_per_class = [N/len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids))] 157 | weight = [0] * int(N) 158 | for idx in range(len(dataset)): 159 | y = dataset.getlabel(idx) 160 | weight[idx] = weight_per_class[y] 161 | 162 | return torch.DoubleTensor(weight) 163 | 164 | def initialize_weights(module): 165 | for m in module.modules(): 166 | if isinstance(m, nn.Linear): 167 | nn.init.xavier_normal_(m.weight) 168 | m.bias.data.zero_() 169 | 170 | elif isinstance(m, nn.BatchNorm1d): 171 | nn.init.constant_(m.weight, 1) 172 | nn.init.constant_(m.bias, 0) 173 | 174 | -------------------------------------------------------------------------------- /models/S4MIL.py: -------------------------------------------------------------------------------- 1 | # This code is taken from the original S4 repository https://github.com/HazyResearch/state-spaces 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange, repeat 7 | import opt_einsum as oe 8 | 9 | _c2r = torch.view_as_real 10 | _r2c = torch.view_as_complex 11 | 12 | class DropoutNd(nn.Module): 13 | def __init__(self, p: float = 0.5, tie=True, transposed=True): 14 | """ 15 | tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) 16 | """ 17 | super().__init__() 18 | if p < 0 or p >= 1: 19 | raise ValueError( 20 | "dropout probability has to be in [0, 1), " "but got {}".format(p)) 21 | self.p = p 22 | self.tie = tie 23 | self.transposed = transposed 24 | self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p) 25 | 26 | def forward(self, X): 27 | """ X: (batch, dim, lengths...) """ 28 | if self.training: 29 | if not self.transposed: 30 | X = rearrange(X, 'b d ... -> b ... d') 31 | # binomial = torch.distributions.binomial.Binomial(probs=1-self.p) # This is incredibly slow 32 | mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape 33 | # mask = self.binomial.sample(mask_shape) 34 | mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p 35 | X = X * mask * (1.0/(1-self.p)) 36 | if not self.transposed: 37 | X = rearrange(X, 'b ... d -> b d ...') 38 | return X 39 | return X 40 | 41 | 42 | class S4DKernel(nn.Module): 43 | """Wrapper around SSKernelDiag that generates the diagonal SSM parameters 44 | """ 45 | 46 | def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, lr=None): 47 | super().__init__() 48 | # Generate dt 49 | H = d_model 50 | log_dt = torch.rand(H) * ( 51 | math.log(dt_max) - math.log(dt_min) 52 | ) + math.log(dt_min) 53 | 54 | C = torch.randn(H, N // 2, dtype=torch.cfloat) 55 | self.C = nn.Parameter(_c2r(C)) 56 | self.register("log_dt", log_dt, lr) 57 | 58 | log_A_real = torch.log(0.5 * torch.ones(H, N//2)) 59 | A_imag = math.pi * repeat(torch.arange(N//2), 'n -> h n', h=H) 60 | self.register("log_A_real", log_A_real, lr) 61 | self.register("A_imag", A_imag, lr) 62 | 63 | def forward(self, L): 64 | """ 65 | returns: (..., c, L) where c is number of channels (default 1) 66 | """ 67 | 68 | # Materialize parameters 69 | dt = torch.exp(self.log_dt) # (H) 70 | C = _r2c(self.C) # (H N) 71 | A = -torch.exp(self.log_A_real) + 1j * self.A_imag # (H N) 72 | 73 | # Vandermonde multiplication 74 | dtA = A * dt.unsqueeze(-1) # (H N) 75 | K = dtA.unsqueeze(-1) * torch.arange(L, device=A.device) # (H N L) 76 | C = C * (torch.exp(dtA)-1.) / A 77 | K = 2 * torch.einsum('hn, hnl -> hl', C, torch.exp(K)).real 78 | 79 | return K 80 | 81 | def register(self, name, tensor, lr=None): 82 | """Register a tensor with a configurable learning rate and 0 weight decay""" 83 | 84 | if lr == 0.0: 85 | self.register_buffer(name, tensor) 86 | else: 87 | self.register_parameter(name, nn.Parameter(tensor)) 88 | 89 | optim = {"weight_decay": 0.0} 90 | if lr is not None: 91 | optim["lr"] = lr 92 | setattr(getattr(self, name), "_optim", optim) 93 | 94 | 95 | class S4D(nn.Module): 96 | 97 | def __init__(self, d_model, d_state=64, dropout=0.0, transposed=True, **kernel_args): 98 | super().__init__() 99 | 100 | self.h = d_model 101 | self.n = d_state 102 | self.d_output = self.h 103 | self.transposed = transposed 104 | 105 | self.D = nn.Parameter(torch.randn(self.h)) 106 | 107 | # SSM Kernel 108 | self.kernel = S4DKernel(self.h, N=self.n, **kernel_args) 109 | 110 | # Pointwise 111 | self.activation = nn.GELU() 112 | # dropout_fn = nn.Dropout2d # NOTE: bugged in PyTorch 1.11 113 | dropout_fn = DropoutNd 114 | self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() 115 | 116 | # position-wise output transform to mix features 117 | self.output_linear = nn.Sequential( 118 | nn.Conv1d(self.h, 2*self.h, kernel_size=1), 119 | nn.GLU(dim=-2), 120 | ) 121 | 122 | def forward(self, u, **kwargs): # absorbs return_output and transformer src mask 123 | """ Input and output shape (B, H, L) """ 124 | if not self.transposed: 125 | u = u.transpose(-1, -2) 126 | L = u.size(-1) 127 | 128 | # Compute SSM Kernel 129 | k = self.kernel(L=L) # (H L) 130 | 131 | # Convolution 132 | k_f = torch.fft.rfft(k, n=2*L) # (H L) 133 | u_f = torch.fft.rfft(u.to(torch.float32), n=2*L) # (B H L) 134 | y = torch.fft.irfft(u_f*k_f, n=2*L)[..., :L] # (B H L) 135 | 136 | # Compute D term in state space equation - essentially a skip connection 137 | y = y + u * self.D.unsqueeze(-1) 138 | 139 | y = self.dropout(self.activation(y)) 140 | y = self.output_linear(y) 141 | if not self.transposed: 142 | y = y.transpose(-1, -2) 143 | return y 144 | 145 | 146 | class S4Model(nn.Module): 147 | def __init__(self, in_dim, n_classes, dropout, act, survival = False): 148 | super(S4Model, self).__init__() 149 | self.n_classes = n_classes 150 | self._fc1 = [nn.Linear(in_dim, 512)] 151 | if act.lower() == 'relu': 152 | self._fc1 += [nn.ReLU()] 153 | elif act.lower() == 'gelu': 154 | self._fc1 += [nn.GELU()] 155 | if dropout: 156 | self._fc1 += [nn.Dropout(dropout)] 157 | print("dropout: ", dropout) 158 | self._fc1 = nn.Sequential(*self._fc1) 159 | self.s4_block = nn.Sequential(nn.LayerNorm(512), 160 | S4D(d_model=512, d_state=32, transposed=False)) 161 | 162 | self.classifier = nn.Linear(512, self.n_classes) 163 | self.survival = survival 164 | def forward(self, x): 165 | x = x.unsqueeze(0) 166 | # print(x.shape) 167 | x = self._fc1(x) 168 | x = self.s4_block(x) 169 | x = torch.max(x, axis=1).values 170 | # print(x.shape) 171 | logits = self.classifier(x) 172 | Y_prob = F.softmax(logits, dim=1) 173 | Y_hat = torch.topk(logits, 1, dim=1)[1] 174 | A_raw = None 175 | results_dict = None 176 | if self.survival: 177 | Y_hat = torch.topk(logits, 1, dim = 1)[1] 178 | hazards = torch.sigmoid(logits) 179 | S = torch.cumprod(1 - hazards, dim=1) 180 | return hazards, S, Y_hat, None, None 181 | return logits, Y_prob, Y_hat, A_raw, results_dict 182 | 183 | 184 | def relocate(self): 185 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 186 | self._fc1 = self._fc1.to(device) 187 | self.s4_block = self.s4_block .to(device) 188 | self.classifier = self.classifier.to(device) 189 | 190 | if __name__ == "__main__": 191 | data = torch.randn((6000, 1024)) 192 | data.to('cuda') 193 | model = S4Model(in_dim = 1024, n_classes = 4, act = 'gelu', dropout = 0.25) 194 | print(model) 195 | results_dict = model(data) 196 | print(results_dict) -------------------------------------------------------------------------------- /datasets/unitopatho_train.csv: -------------------------------------------------------------------------------- 1 | case_id,slide_id,label 2 | patient_1,NORM CASO 8 - 2019-03-04 08.51.07_1,NORM 3 | patient_2,208-B5-NORM_1,NORM 4 | patient_3,271-B5-NORM_1,NORM 5 | patient_4,196-B4-NORM_1,NORM 6 | patient_5,57-B2-NORM_1,NORM 7 | patient_6,215-B5-NORM_1,NORM 8 | patient_7,265-B5-NORM_1,NORM 9 | patient_8,216-B5-NORM_1,NORM 10 | patient_9,131-B3-NORM_1,NORM 11 | patient_10,264-B5-NORM_1,NORM 12 | patient_11,184-B4-NORM_1,NORM 13 | patient_12,185-B4-NORM_1,NORM 14 | patient_13,NORM CASO 7 - 2019-03-01 09.12.01_1,NORM 15 | patient_14,109-B3-TVALG_5,TVA.LG 16 | patient_15,TVA.LG CASO 2 - 2018-12-04 13.19.16_5,TVA.LG 17 | patient_16,93-B2-TVALG_5,TVA.LG 18 | patient_17,TVA.LG CASO 9_5,TVA.LG 19 | patient_18,209-B5-TVALG_5,TVA.LG 20 | patient_19,113-B3-TVALG_5,TVA.LG 21 | patient_20,178-B4-TVALG_5,TVA.LG 22 | patient_21,172-B4-TVALG_5,TVA.LG 23 | patient_22,TVA.LG CASO 1 - 2018-12-04 13.17.55_5,TVA.LG 24 | patient_23,114-B3-TVALG_5,TVA.LG 25 | patient_24,176-B4-TVALG_5,TVA.LG 26 | patient_25,118-B3-TVALG_5,TVA.LG 27 | patient_26,96-B2-TVALG_5,TVA.LG 28 | patient_27,205-B5-TVALG_5,TVA.LG 29 | patient_28,207-B5-TVALG_5,TVA.LG 30 | patient_29,142-B3-TVALG_5,TVA.LG 31 | patient_30,174-B4-TVALG_5,TVA.LG 32 | patient_31,80-B2-TVALG_5,TVA.LG 33 | patient_32,100-B2-TVALG_5,TVA.LG 34 | patient_33,252-B5-TVALG_5,TVA.LG 35 | patient_34,TVA.LG CASO 10_5,TVA.LG 36 | patient_35,189-B4-TVALG_5,TVA.LG 37 | patient_36,TVA.LG CASO 11_5,TVA.LG 38 | patient_37,218-B5-TVALG_5,TVA.LG 39 | patient_38,154-B4-TVALG_5,TVA.LG 40 | patient_39,124-B3-TVALG_5,TVA.LG 41 | patient_40,70-B2-TVALG_5,TVA.LG 42 | patient_41,TVA.LG CASO 4 - 2018-12-04 13.25.02_5,TVA.LG 43 | patient_42,198-B4-TVALG_5,TVA.LG 44 | patient_43,191-B4-TVALG_5,TVA.LG 45 | patient_44,270-B5-TAHG_2,TA.HG 46 | patient_45,226-B5-TAHG_2,TA.HG 47 | patient_46,62-B2-TAHG_2,TA.HG 48 | patient_47,221-B5-TAHG_2,TA.HG 49 | patient_48,84-B2-TAHG_2,TA.HG 50 | patient_49,TA.HG CASO 10 - 2019-03-04 09.24.36_2,TA.HG 51 | patient_50,219-B5-TAHG_2,TA.HG 52 | patient_51,TA.HG CASO 14_2,TA.HG 53 | patient_52,TA.HG CASO 10_2,TA.HG 54 | patient_53,TA.HG CASO 13_2,TA.HG 55 | patient_54,54-B2-TAHG_2,TA.HG 56 | patient_55,TA.HG CASO 1 - 2018-12-04 12.58.12_2,TA.HG 57 | patient_56,223-B5-TAHG_2,TA.HG 58 | patient_57,TA.HG CASO 15_2,TA.HG 59 | patient_58,TA.HG CASO 16_2,TA.HG 60 | patient_59,TA.HG CASO 17_2,TA.HG 61 | patient_60,137-B3-TAHG_2,TA.HG 62 | patient_61,HP CASO 1 - 2018-12-04 13.30.08_0,HP 63 | patient_62,HP CASO 28 - 2019-03-04 17.34.19_0,HP 64 | patient_63,HP CASO 47_0,HP 65 | patient_64,HP CASO 43 B1_0,HP 66 | patient_65,HP CASO 26 - 2019-03-04 09.51.06_0,HP 67 | patient_66,HP CASO 36_0,HP 68 | patient_67,149-B3-HP_0,HP 69 | patient_68,170-B4-HP_0,HP 70 | patient_69,103-B3-HP_0,HP 71 | patient_70,HP CASO 44_0,HP 72 | patient_71,158-B4-HP_0,HP 73 | patient_72,168-B4-HP_0,HP 74 | patient_73,HP CASO 40 B1_0,HP 75 | patient_74,250-B5-HP_0,HP 76 | patient_75,151-B4-HP_0,HP 77 | patient_76,224-B5-HP_0,HP 78 | patient_77,HP CASO 38_0,HP 79 | patient_78,148-B3-HP_0,HP 80 | patient_79,HP CASO 2 - 2018-12-04 13.56.35_0,HP 81 | patient_80,HP CASO 39_0,HP 82 | patient_81,HP CASO 46_0,HP 83 | patient_82,HP CASO 42_0,HP 84 | patient_83,HP CASO 43 D1_0,HP 85 | patient_84,169-B4-HP_0,HP 86 | patient_85,HP CASO 45_0,HP 87 | patient_86,165-B4-HP_0,HP 88 | patient_87,159-B4-HP_0,HP 89 | patient_88,HP CASO 41 B1_0,HP 90 | patient_89,167-B4-HP_0,HP 91 | patient_90,194-B4-HP_0,HP 92 | patient_91,HP CASO 24 - 2019-03-04 08.59.32_0,HP 93 | patient_92,TA.LG CASO 57 - 2019-03-04 10.09.38_3,TA.LG 94 | patient_93,245-B5-TALG_3,TA.LG 95 | patient_94,TA.LG CASO 102_3,TA.LG 96 | patient_95,TA.LG CASO 79_3,TA.LG 97 | patient_96,86-B2-TALG_3,TA.LG 98 | patient_97,87-B2-TALG_3,TA.LG 99 | patient_98,239-B5-TALG_3,TA.LG 100 | patient_99,TA.LG CASO 50 - 2019-03-04 08.13.45_3,TA.LG 101 | patient_100,TA.LG CASO 9 - 2018-12-04 13.41.36_3,TA.LG 102 | patient_101,TA.LG CASO 76_3,TA.LG 103 | patient_102,TA.LG CASO 45 - 2019-03-01 08.32.26_3,TA.LG 104 | patient_103,213-B5-TALG_3,TA.LG 105 | patient_104,51-B2-TALG_3,TA.LG 106 | patient_105,TA.LG CASO 97_3,TA.LG 107 | patient_106,125-B3-TALG_3,TA.LG 108 | patient_107,TA.LG CASO 4 - 2018-12-04 13.29.01_3,TA.LG 109 | patient_108,TA.LG CASO 106_3,TA.LG 110 | patient_109,266-B5-TALG_3,TA.LG 111 | patient_110,73-B2-TALG_3,TA.LG 112 | patient_111,58-B2-TALG_3,TA.LG 113 | patient_112,76-B2-TALG_3,TA.LG 114 | patient_113,104-B3-TALG_3,TA.LG 115 | patient_114,TA.LG CASO 52 - 2019-03-04 08.34.05_3,TA.LG 116 | patient_115,68-B2-TALG_3,TA.LG 117 | patient_116,91-B2-TALG_3,TA.LG 118 | patient_117,225-B5-TALG_3,TA.LG 119 | patient_118,TA.LG CASO 80 C1_3,TA.LG 120 | patient_119,144-B3-TALG_3,TA.LG 121 | patient_120,TA.LG CASO 107_3,TA.LG 122 | patient_121,75-B2-TALG_3,TA.LG 123 | patient_122,52-B2-TALG_3,TA.LG 124 | patient_123,TA.LG CASO 93_3,TA.LG 125 | patient_124,267-B5-TALG_3,TA.LG 126 | patient_125,TA.LG CASO 89_3,TA.LG 127 | patient_126,81-B2-TALG_3,TA.LG 128 | patient_127,227-B5-TALG_3,TA.LG 129 | patient_128,72-B2-TALG_3,TA.LG 130 | patient_129,89-B2-TALG_3,TA.LG 131 | patient_130,TA.LG CASO 49 - 2019-03-04 07.38.56_3,TA.LG 132 | patient_131,132-B3-TALG_3,TA.LG 133 | patient_132,TA.LG CASO 56 - 2019-03-04 10.00.26_3,TA.LG 134 | patient_133,TA.LG CASO 77_3,TA.LG 135 | patient_134,101-B2-TALG_3,TA.LG 136 | patient_135,82-B2-TALG_3,TA.LG 137 | patient_136,TA.LG CASO 51 - 2019-03-04 08.23.50_3,TA.LG 138 | patient_137,133-B3-TALG_3,TA.LG 139 | patient_138,TA.LG CASO 88 A1_3,TA.LG 140 | patient_139,105-B3-TALG_3,TA.LG 141 | patient_140,212-B5-TALG_3,TA.LG 142 | patient_141,136-B3-TALG_3,TA.LG 143 | patient_142,TA.LG CASO 14 - 2018-12-04 13.51.01_3,TA.LG 144 | patient_143,TA.LG CASO 64 - 2019-03-04 17.42.27_3,TA.LG 145 | patient_144,TA.LG CASO 75_3,TA.LG 146 | patient_145,TA.LG CASO 105_3,TA.LG 147 | patient_146,TA.LG CASO 65 - 2019-03-04 18.07.00_3,TA.LG 148 | patient_147,TA.LG CASO 62 - 2019-03-04 17.11.05_3,TA.LG 149 | patient_148,TA.LG CASO 74 8424.19_3,TA.LG 150 | patient_149,TA.LG CASO 7 - 2018-12-04 13.37.32_3,TA.LG 151 | patient_150,TA.LG CASO 104 A1_3,TA.LG 152 | patient_151,TA.LG CASO 85_3,TA.LG 153 | patient_152,74-B2-TALG_3,TA.LG 154 | patient_153,203-B5-TALG_3,TA.LG 155 | patient_154,90-B2-TALG_3,TA.LG 156 | patient_155,TA.LG CASO 60 - 2019-03-04 16.40.51_3,TA.LG 157 | patient_156,69-B2-TALG_3,TA.LG 158 | patient_157,TA.LG CASO 58 - 2019-03-04 16.24.00_3,TA.LG 159 | patient_158,TA.LG CASO 2 - 2018-12-04 12.55.38_3,TA.LG 160 | patient_159,TA.LG CASO 5 - 2018-12-04 13.34.36_3,TA.LG 161 | patient_160,244-B5-TALG_3,TA.LG 162 | patient_161,237-B5-TALG_3,TA.LG 163 | patient_162,94-B2-TALG_3,TA.LG 164 | patient_163,TA.LG CASO 100_3,TA.LG 165 | patient_164,55-B2-TALG_3,TA.LG 166 | patient_165,TA.LG CASO 53 - 2019-03-04 08.42.25_3,TA.LG 167 | patient_166,171-B4-TALG_3,TA.LG 168 | patient_167,TA.LG CASO 99_3,TA.LG 169 | patient_168,173-B4-TALG_3,TA.LG 170 | patient_169,TA.LG CASO 3 - 2018-12-04 13.14.40_3,TA.LG 171 | patient_170,217-B5-TALG_3,TA.LG 172 | patient_171,182-B4-TALG_3,TA.LG 173 | patient_172,TA.LG CASO 84_3,TA.LG 174 | patient_173,236-B5-TALG_3,TA.LG 175 | patient_174,TA.LG CASO 86_3,TA.LG 176 | patient_175,TA.LG CASO 91_3,TA.LG 177 | patient_176,TA.LG CASO 66 - 2019-03-04 18.29.55_3,TA.LG 178 | patient_177,130-B3-TALG_3,TA.LG 179 | patient_178,79-B2-TALG_3,TA.LG 180 | patient_179,TA.LG CASO 88 B1_3,TA.LG 181 | patient_180,127-B3-TALG_3,TA.LG 182 | patient_181,TA.LG CASO 78_3,TA.LG 183 | patient_182,TA.LG CASO 80 A1_3,TA.LG 184 | patient_183,TA.LG CASO 101 B1_3,TA.LG 185 | patient_184,269-B5-TALG_3,TA.LG 186 | patient_185,TA.LG CASO 55 - 2019-03-04 09.43.03_3,TA.LG 187 | patient_186,TA.LG CASO 92 B1_3,TA.LG 188 | patient_187,TA.LG CASO 47 - 2019-03-01 08.51.33_3,TA.LG 189 | patient_188,TA.LG CASO 11 - 2018-12-04 13.46.00_3,TA.LG 190 | patient_189,210-B5-TALG_3,TA.LG 191 | patient_190,TA.LG CASO 81_3,TA.LG 192 | patient_191,61-B2-TALG_3,TA.LG 193 | patient_192,201-B5-TVAHG_4,TVA.HG 194 | patient_193,117-B3-TVAHG_4,TVA.HG 195 | patient_194,268-B5-TVAHG_4,TVA.HG 196 | patient_195,249-B5-TVAHG_4,TVA.HG 197 | patient_196,200-B4-TVAHG_4,TVA.HG 198 | patient_197,108-B3-TVAHG_4,TVA.HG 199 | patient_198,143-B3-TVAHG_4,TVA.HG 200 | patient_199,220-B5-TVAHG_4,TVA.HG 201 | patient_200,TVA.HG CASO 4 - 2018-12-04 13.15.43_4,TVA.HG 202 | patient_201,251-B5-TVAHG_4,TVA.HG 203 | patient_202,190-B4-TVAHG_4,TVA.HG 204 | patient_203,TVA.HG CASO 2 - 2018-12-04 12.53.37_4,TVA.HG 205 | patient_204,181-B4-TVAHG_4,TVA.HG 206 | -------------------------------------------------------------------------------- /trainer_mil.py: -------------------------------------------------------------------------------- 1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | import torch.optim as optim 7 | import pdb 8 | import torch.nn.functional as F 9 | 10 | import pandas as pd 11 | import numpy as np 12 | 13 | from sklearn.preprocessing import label_binarize 14 | from sklearn.metrics import roc_auc_score, roc_curve 15 | from sklearn.metrics import auc as calc_auc 16 | 17 | from models.MIL import * 18 | from utils import * 19 | 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | 23 | class Accuracy_Logger(object): 24 | """Accuracy logger""" 25 | def __init__(self, n_classes): 26 | super().__init__() 27 | self.n_classes = n_classes 28 | self.initialize() 29 | 30 | def initialize(self): 31 | self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] 32 | 33 | def log(self, Y_hat, Y): 34 | Y_hat = int(Y_hat) 35 | Y = int(Y) 36 | self.data[Y]["count"] += 1 37 | self.data[Y]["correct"] += (Y_hat == Y) 38 | 39 | def log_batch(self, Y_hat, Y): 40 | Y_hat = np.array(Y_hat).astype(int) 41 | Y = np.array(Y).astype(int) 42 | for label_class in np.unique(Y): 43 | cls_mask = Y == label_class 44 | self.data[label_class]["count"] += cls_mask.sum() 45 | self.data[label_class]["correct"] += (Y_hat[cls_mask] == Y[cls_mask]).sum() 46 | 47 | def get_summary(self, c): 48 | count = self.data[c]["count"] 49 | correct = self.data[c]["correct"] 50 | 51 | if count == 0: 52 | acc = None 53 | else: 54 | acc = float(correct) / count 55 | 56 | return acc, correct, count 57 | 58 | def summary(model, loader, n_classes): 59 | acc_logger = Accuracy_Logger(n_classes=n_classes) 60 | model.eval() 61 | test_loss = 0. 62 | test_error = 0. 63 | 64 | all_probs = np.zeros((len(loader), n_classes)) 65 | all_labels = np.zeros(len(loader)) 66 | 67 | slide_ids = loader.dataset.slide_data['slide_id'] 68 | patient_results = {} 69 | 70 | for batch_idx, (data, label, coords, nearest) in enumerate(loader): 71 | data, label = data.to(device), label.to(device) 72 | slide_id = slide_ids.iloc[batch_idx] 73 | with torch.inference_mode(): 74 | logits, Y_prob, Y_hat, y_probs, _ = model(data) 75 | acc_logger.log(Y_hat, label) 76 | 77 | probs = Y_prob.cpu().numpy() 78 | all_probs[batch_idx] = probs 79 | all_labels[batch_idx] = label.item() 80 | 81 | patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}}) 82 | error = calculate_error(Y_hat, label) 83 | test_error += error 84 | 85 | test_error /= len(loader) 86 | 87 | if n_classes == 2: 88 | auc = roc_auc_score(all_labels, all_probs[:, 1]) 89 | aucs = [] 90 | else: 91 | aucs = [] 92 | binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)]) 93 | for class_idx in range(n_classes): 94 | if class_idx in all_labels: 95 | fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx]) 96 | aucs.append(calc_auc(fpr, tpr)) 97 | else: 98 | aucs.append(float('nan')) 99 | 100 | auc = np.nanmean(np.array(aucs)) 101 | 102 | return patient_results, test_error, auc, acc_logger 103 | 104 | def train_mil(datasets, 105 | save_path='./save_weights/camelyon16_transmil_imagenet/', 106 | feature_dim = 512, 107 | n_classes = 2, 108 | fold = 0, 109 | writer_flag = True, 110 | max_epoch = 200, 111 | early_stopping = True, 112 | ): 113 | writer_dir = os.path.join(save_path, str(fold)) 114 | if not os.path.isdir(writer_dir): 115 | os.makedirs(writer_dir) 116 | if writer_flag: 117 | from tensorboardX import SummaryWriter 118 | writer = SummaryWriter(writer_dir, flush_secs=15) 119 | else: 120 | writer = None 121 | 122 | print("\nInit train/val/test splits...") 123 | train_split, val_split, test_split = datasets 124 | print("Training on {} samples".format(len(train_split))) 125 | print("Validating on {} samples".format(len(val_split))) 126 | print("Testing on {} samples".format(len(test_split))) 127 | 128 | print("\nInit loss function...") 129 | loss_fn = nn.CrossEntropyLoss() 130 | model = MIL_MeanPooling(embed_dim=feature_dim, 131 | n_classes=n_classes) 132 | _ = model.to(device) 133 | 134 | print("\nInit optimizer") 135 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4, momentum=0.9, weight_decay=1e-5) 136 | 137 | print('\nInit Loaders...', end=' ') 138 | train_loader = get_split_loader(train_split, training=True, testing = False, weighted = True) 139 | val_loader = get_split_loader(val_split, testing = False) 140 | test_loader = get_split_loader(test_split, testing = False) 141 | print('Done!') 142 | 143 | mini_loss = 10000 144 | retain = 0 145 | for epoch in range(max_epoch): 146 | train_loop(epoch, model, train_loader, optimizer, n_classes, writer, loss_fn) 147 | loss = validate(epoch, model, val_loader, n_classes, writer, loss_fn) 148 | if epoch % 20 == 0: 149 | torch.save(model.state_dict(), os.path.join(save_path, 's_{}_checkpoint_{}.pt'.format(fold, epoch))) 150 | if loss < mini_loss: 151 | print("loss decrease from:{} to {}".format(mini_loss, loss)) 152 | torch.save(model.state_dict(), os.path.join(save_path, 's_{}_checkpoint.pt'.format(fold))) 153 | mini_loss = loss 154 | retain = 0 155 | else: 156 | retain += 1 157 | print("Retain of early stopping: {} / {}".format(retain, 20)) 158 | if early_stopping: 159 | if retain > 20 and epoch > 50: 160 | print("Early stopping") 161 | break 162 | 163 | model.load_state_dict(torch.load(os.path.join(save_path, 's_{}_checkpoint.pt'.format(fold)))) 164 | summary(model, test_loader, n_classes) 165 | 166 | def train_loop(epoch, model, loader, optimizer, n_classes, writer, loss_fn): 167 | model.train() 168 | acc_logger = Accuracy_Logger(n_classes=n_classes) 169 | train_loss = 0. 170 | bag_loss = 0. 171 | 172 | print('\n') 173 | for batch_idx, (data, label, coords, nearest) in enumerate(loader): 174 | data, label = data.to(device), label.to(device) 175 | 176 | logits, Y_prob, Y_hat, y_probs, _ = model(data) 177 | acc_logger.log(Y_hat, label) 178 | loss_bag = loss_fn(logits, label) 179 | loss = loss_bag 180 | 181 | loss_bag_value = loss_bag.item() 182 | loss_value = loss.item() 183 | 184 | train_loss += loss_value 185 | bag_loss += loss_bag 186 | 187 | if (batch_idx + 1) % 20 == 0: 188 | print('batch {}, loss: {:.4f}, loss_bag: {:.4f}, bag_size: {}'.format(batch_idx, loss_value, loss_bag_value, label.item(), data.size(0))) 189 | 190 | loss.backward() 191 | optimizer.step() 192 | optimizer.zero_grad() 193 | 194 | train_loss /= len(loader) 195 | bag_loss /= len(loader) 196 | 197 | print('Epoch: {}, train_loss: {:.4f}, bag_loss: {:.4f}, '.format(epoch, train_loss, bag_loss)) 198 | 199 | for i in range(n_classes): 200 | acc, correct, count = acc_logger.get_summary(i) 201 | print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count)) 202 | if writer: 203 | writer.add_scalar('train/class_{}_acc'.format(i), acc, epoch) 204 | if writer: 205 | writer.add_scalar('train/loss', train_loss, epoch) 206 | writer.add_scalar('train/loss_bag', loss_bag, epoch) 207 | 208 | 209 | def validate(epoch, model, loader, n_classes, writer, loss_fn): 210 | model.eval() 211 | acc_logger = Accuracy_Logger(n_classes=n_classes) 212 | val_loss = 0. 213 | 214 | prob = np.zeros((len(loader), n_classes)) 215 | labels = np.zeros(len(loader)) 216 | 217 | with torch.no_grad(): 218 | for batch_idx, (data, label, coords, nearest) in enumerate(loader): 219 | data, label = data.to(device, non_blocking=True), label.to(device, non_blocking=True) 220 | 221 | logits, Y_prob, Y_hat, y_probs, _ = model(data) 222 | acc_logger.log(Y_hat, label) 223 | 224 | loss_bag = loss_fn(logits, label) 225 | loss = loss_bag 226 | 227 | prob[batch_idx] = Y_prob.cpu().numpy() 228 | labels[batch_idx] = label.item() 229 | 230 | val_loss += loss.item() 231 | 232 | val_loss /= len(loader) 233 | if n_classes == 2: 234 | auc = roc_auc_score(labels, prob[:, 1]) 235 | 236 | else: 237 | auc = roc_auc_score(labels, prob, multi_class='ovr') 238 | if writer: 239 | writer.add_scalar('val/loss', val_loss, epoch) 240 | writer.add_scalar('val/auc', auc, epoch) 241 | print('\nVal Set, val_loss: {:.4f}, auc: {:.4f}'.format(val_loss, auc)) 242 | 243 | for i in range(n_classes): 244 | acc, correct, count = acc_logger.get_summary(i) 245 | print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count)) 246 | 247 | return val_loss 248 | 249 | if __name__ == "__main__": 250 | csv_path = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/unitopatho_train.csv' 251 | data_dir = '/data2/clh/unitopatho/resnet18_imagenet/' 252 | split_path = '/data1/ceiling/workspace/AttriMIL_v2/splits/unitopatho/' 253 | save_dir = './save_weights/unitopatho_mean_imagenet/' 254 | 255 | dataset = Generic_MIL_Dataset(csv_path = csv_path, 256 | data_dir = data_dir, 257 | shuffle = False, 258 | seed = 1, 259 | print_info = True, 260 | label_dict = {'NORM':0, 'HP':1, 'TA.HG':2,'TA.LG':3, 'TVA.HG':4, 'TVA.LG':5}, 261 | patient_strat=False, 262 | ignore=[]) 263 | csv_path = [split_path + 'splits_{}.csv'.format(i) for i in range(5)] 264 | for step, name in enumerate(csv_path): 265 | train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False, csv_path=name) 266 | train_mil((train_dataset, val_dataset, test_dataset), 267 | save_path=save_dir, 268 | feature_dim = 512, 269 | n_classes = 6, 270 | fold = step, 271 | writer_flag = True, 272 | max_epoch = 200, 273 | early_stopping = True,) -------------------------------------------------------------------------------- /trainer_transmil.py: -------------------------------------------------------------------------------- 1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | import torch.optim as optim 7 | import pdb 8 | import torch.nn.functional as F 9 | 10 | import pandas as pd 11 | import numpy as np 12 | 13 | from sklearn.preprocessing import label_binarize 14 | from sklearn.metrics import roc_auc_score, roc_curve 15 | from sklearn.metrics import auc as calc_auc 16 | 17 | from models.TransMIL import TransMIL 18 | from utils import * 19 | 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | 23 | class Accuracy_Logger(object): 24 | """Accuracy logger""" 25 | def __init__(self, n_classes): 26 | super().__init__() 27 | self.n_classes = n_classes 28 | self.initialize() 29 | 30 | def initialize(self): 31 | self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] 32 | 33 | def log(self, Y_hat, Y): 34 | Y_hat = int(Y_hat) 35 | Y = int(Y) 36 | self.data[Y]["count"] += 1 37 | self.data[Y]["correct"] += (Y_hat == Y) 38 | 39 | def log_batch(self, Y_hat, Y): 40 | Y_hat = np.array(Y_hat).astype(int) 41 | Y = np.array(Y).astype(int) 42 | for label_class in np.unique(Y): 43 | cls_mask = Y == label_class 44 | self.data[label_class]["count"] += cls_mask.sum() 45 | self.data[label_class]["correct"] += (Y_hat[cls_mask] == Y[cls_mask]).sum() 46 | 47 | def get_summary(self, c): 48 | count = self.data[c]["count"] 49 | correct = self.data[c]["correct"] 50 | 51 | if count == 0: 52 | acc = None 53 | else: 54 | acc = float(correct) / count 55 | 56 | return acc, correct, count 57 | 58 | def summary(model, loader, n_classes): 59 | acc_logger = Accuracy_Logger(n_classes=n_classes) 60 | model.eval() 61 | test_loss = 0. 62 | test_error = 0. 63 | 64 | all_probs = np.zeros((len(loader), n_classes)) 65 | all_labels = np.zeros(len(loader)) 66 | 67 | slide_ids = loader.dataset.slide_data['slide_id'] 68 | patient_results = {} 69 | 70 | for batch_idx, (data, label, coords, nearest) in enumerate(loader): 71 | data, label = data.to(device), label.to(device) 72 | slide_id = slide_ids.iloc[batch_idx] 73 | with torch.inference_mode(): 74 | logits, Y_prob, Y_hat, h0 = model(data) 75 | acc_logger.log(Y_hat, label) 76 | 77 | probs = Y_prob.cpu().numpy() 78 | all_probs[batch_idx] = probs 79 | all_labels[batch_idx] = label.item() 80 | 81 | patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}}) 82 | error = calculate_error(Y_hat, label) 83 | test_error += error 84 | 85 | test_error /= len(loader) 86 | 87 | if n_classes == 2: 88 | auc = roc_auc_score(all_labels, all_probs[:, 1]) 89 | aucs = [] 90 | else: 91 | aucs = [] 92 | binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)]) 93 | for class_idx in range(n_classes): 94 | if class_idx in all_labels: 95 | fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx]) 96 | aucs.append(calc_auc(fpr, tpr)) 97 | else: 98 | aucs.append(float('nan')) 99 | 100 | auc = np.nanmean(np.array(aucs)) 101 | 102 | return patient_results, test_error, auc, acc_logger 103 | 104 | def train_transmil(datasets, 105 | save_path='./save_weights/camelyon16_transmil_imagenet/', 106 | feature_dim = 512, 107 | n_classes = 2, 108 | fold = 0, 109 | writer_flag = True, 110 | max_epoch = 200, 111 | early_stopping = True, 112 | ): 113 | writer_dir = os.path.join(save_path, str(fold)) 114 | if not os.path.isdir(writer_dir): 115 | os.makedirs(writer_dir) 116 | if writer_flag: 117 | from tensorboardX import SummaryWriter 118 | writer = SummaryWriter(writer_dir, flush_secs=15) 119 | else: 120 | writer = None 121 | 122 | print("\nInit train/val/test splits...") 123 | train_split, val_split, test_split = datasets 124 | print("Training on {} samples".format(len(train_split))) 125 | print("Validating on {} samples".format(len(val_split))) 126 | print("Testing on {} samples".format(len(test_split))) 127 | 128 | print("\nInit loss function...") 129 | loss_fn = nn.CrossEntropyLoss() 130 | 131 | model = TransMIL(dim=feature_dim, 132 | n_classes=n_classes) 133 | _ = model.to(device) 134 | 135 | print("\nInit optimizer") 136 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4, momentum=0.9, weight_decay=1e-5) 137 | 138 | print('\nInit Loaders...', end=' ') 139 | train_loader = get_split_loader(train_split, training=True, testing = False, weighted = True) 140 | val_loader = get_split_loader(val_split, testing = False) 141 | test_loader = get_split_loader(test_split, testing = False) 142 | print('Done!') 143 | 144 | mini_loss = 10000 145 | retain = 0 146 | for epoch in range(max_epoch): 147 | train_loop(epoch, model, train_loader, optimizer, n_classes, writer, loss_fn) 148 | loss = validate(epoch, model, val_loader, n_classes, writer, loss_fn) 149 | if epoch % 20 == 0: 150 | torch.save(model.state_dict(), os.path.join(save_path, 's_{}_checkpoint_{}.pt'.format(fold, epoch))) 151 | if loss < mini_loss: 152 | print("loss decrease from:{} to {}".format(mini_loss, loss)) 153 | torch.save(model.state_dict(), os.path.join(save_path, 's_{}_checkpoint.pt'.format(fold))) 154 | mini_loss = loss 155 | retain = 0 156 | else: 157 | retain += 1 158 | print("Retain of early stopping: {} / {}".format(retain, 20)) 159 | if early_stopping: 160 | if retain > 20 and epoch > 50: 161 | print("Early stopping") 162 | break 163 | 164 | model.load_state_dict(torch.load(os.path.join(save_path, 's_{}_checkpoint.pt'.format(fold)))) 165 | summary(model, test_loader, n_classes) 166 | 167 | def train_loop(epoch, model, loader, optimizer, n_classes, writer, loss_fn): 168 | model.train() 169 | acc_logger = Accuracy_Logger(n_classes=n_classes) 170 | train_loss = 0. 171 | bag_loss = 0. 172 | 173 | print('\n') 174 | for batch_idx, (data, label, coords, nearest) in enumerate(loader): 175 | data, label = data.to(device), label.to(device) 176 | 177 | logits, Y_prob, Y_hat, h0 = model(data) 178 | acc_logger.log(Y_hat, label) 179 | loss_bag = loss_fn(logits, label) 180 | loss = loss_bag 181 | 182 | loss_bag_value = loss_bag.item() 183 | loss_value = loss.item() 184 | 185 | train_loss += loss_value 186 | bag_loss += loss_bag 187 | 188 | if (batch_idx + 1) % 20 == 0: 189 | print('batch {}, loss: {:.4f}, loss_bag: {:.4f}, bag_size: {}'.format(batch_idx, loss_value, loss_bag_value, label.item(), data.size(0))) 190 | 191 | loss.backward() 192 | optimizer.step() 193 | optimizer.zero_grad() 194 | 195 | train_loss /= len(loader) 196 | bag_loss /= len(loader) 197 | 198 | print('Epoch: {}, train_loss: {:.4f}, bag_loss: {:.4f}, '.format(epoch, train_loss, bag_loss)) 199 | 200 | for i in range(n_classes): 201 | acc, correct, count = acc_logger.get_summary(i) 202 | print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count)) 203 | if writer: 204 | writer.add_scalar('train/class_{}_acc'.format(i), acc, epoch) 205 | if writer: 206 | writer.add_scalar('train/loss', train_loss, epoch) 207 | writer.add_scalar('train/loss_bag', loss_bag, epoch) 208 | 209 | 210 | def validate(epoch, model, loader, n_classes, writer, loss_fn): 211 | model.eval() 212 | acc_logger = Accuracy_Logger(n_classes=n_classes) 213 | val_loss = 0. 214 | 215 | prob = np.zeros((len(loader), n_classes)) 216 | labels = np.zeros(len(loader)) 217 | 218 | with torch.no_grad(): 219 | for batch_idx, (data, label, coords, nearest) in enumerate(loader): 220 | data, label = data.to(device, non_blocking=True), label.to(device, non_blocking=True) 221 | 222 | logits, Y_prob, Y_hat, h0 = model(data) 223 | acc_logger.log(Y_hat, label) 224 | 225 | loss_bag = loss_fn(logits, label) 226 | loss = loss_bag 227 | 228 | prob[batch_idx] = Y_prob.cpu().numpy() 229 | labels[batch_idx] = label.item() 230 | 231 | val_loss += loss.item() 232 | 233 | val_loss /= len(loader) 234 | if n_classes == 2: 235 | auc = roc_auc_score(labels, prob[:, 1]) 236 | 237 | else: 238 | auc = roc_auc_score(labels, prob, multi_class='ovr') 239 | if writer: 240 | writer.add_scalar('val/loss', val_loss, epoch) 241 | writer.add_scalar('val/auc', auc, epoch) 242 | print('\nVal Set, val_loss: {:.4f}, auc: {:.4f}'.format(val_loss, auc)) 243 | 244 | for i in range(n_classes): 245 | acc, correct, count = acc_logger.get_summary(i) 246 | print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count)) 247 | 248 | return val_loss 249 | 250 | if __name__ == "__main__": 251 | csv_path = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/unitopatho_train.csv' 252 | data_dir = '/data2/clh/unitopatho/resnet18_imagenet/' 253 | split_path = '/data1/ceiling/workspace/AttriMIL_v2/splits/unitopatho/' 254 | save_dir = './save_weights/unitopatho_transmil_imagenet/' 255 | 256 | dataset = Generic_MIL_Dataset(csv_path = csv_path, 257 | data_dir = data_dir, 258 | shuffle = False, 259 | seed = 1, 260 | print_info = True, 261 | label_dict = {'NORM':0, 'HP':1, 'TA.HG':2,'TA.LG':3, 'TVA.HG':4, 'TVA.LG':5}, 262 | patient_strat=False, 263 | ignore=[]) 264 | csv_path = [split_path + 'splits_{}.csv'.format(i) for i in range(5)] 265 | for step, name in enumerate(csv_path): 266 | train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False, csv_path=name) 267 | train_transmil((train_dataset, val_dataset, test_dataset), 268 | save_path=save_dir, 269 | feature_dim = 512, 270 | n_classes = 6, 271 | fold = step, 272 | writer_flag = True, 273 | max_epoch = 200, 274 | early_stopping = True,) -------------------------------------------------------------------------------- /datasets/camelyon17_seen.csv: -------------------------------------------------------------------------------- 1 | dir_name,case_id,slide_id,label 2 | /data/ceiling/data/MIL/camelyon17/,patient_070,patient_070_node_0,normal_tissue 3 | /data/ceiling/data/MIL/camelyon17/,patient_070,patient_070_node_1,normal_tissue 4 | /data/ceiling/data/MIL/camelyon17/,patient_070,patient_070_node_2,normal_tissue 5 | /data/ceiling/data/MIL/camelyon17/,patient_070,patient_070_node_3,tumor_tissue 6 | /data/ceiling/data/MIL/camelyon17/,patient_070,patient_070_node_4,normal_tissue 7 | /data/ceiling/data/MIL/camelyon17/,patient_071,patient_071_node_0,normal_tissue 8 | /data/ceiling/data/MIL/camelyon17/,patient_071,patient_071_node_1,normal_tissue 9 | /data/ceiling/data/MIL/camelyon17/,patient_071,patient_071_node_2,normal_tissue 10 | /data/ceiling/data/MIL/camelyon17/,patient_071,patient_071_node_3,normal_tissue 11 | /data/ceiling/data/MIL/camelyon17/,patient_071,patient_071_node_4,normal_tissue 12 | /data/ceiling/data/MIL/camelyon17/,patient_072,patient_072_node_1,tumor_tissue 13 | /data/ceiling/data/MIL/camelyon17/,patient_072,patient_072_node_2,tumor_tissue 14 | /data/ceiling/data/MIL/camelyon17/,patient_072,patient_072_node_3,normal_tissue 15 | /data/ceiling/data/MIL/camelyon17/,patient_072,patient_072_node_4,tumor_tissue 16 | /data/ceiling/data/MIL/camelyon17/,patient_073,patient_073_node_0,normal_tissue 17 | /data/ceiling/data/MIL/camelyon17/,patient_073,patient_073_node_1,tumor_tissue 18 | /data/ceiling/data/MIL/camelyon17/,patient_073,patient_073_node_2,normal_tissue 19 | /data/ceiling/data/MIL/camelyon17/,patient_073,patient_073_node_3,tumor_tissue 20 | /data/ceiling/data/MIL/camelyon17/,patient_073,patient_073_node_4,normal_tissue 21 | /data/ceiling/data/MIL/camelyon17/,patient_074,patient_074_node_0,normal_tissue 22 | /data/ceiling/data/MIL/camelyon17/,patient_074,patient_074_node_1,normal_tissue 23 | /data/ceiling/data/MIL/camelyon17/,patient_074,patient_074_node_2,normal_tissue 24 | /data/ceiling/data/MIL/camelyon17/,patient_074,patient_074_node_3,normal_tissue 25 | /data/ceiling/data/MIL/camelyon17/,patient_074,patient_074_node_4,tumor_tissue 26 | /data/ceiling/data/MIL/camelyon17/,patient_075,patient_075_node_0,normal_tissue 27 | /data/ceiling/data/MIL/camelyon17/,patient_075,patient_075_node_1,normal_tissue 28 | /data/ceiling/data/MIL/camelyon17/,patient_075,patient_075_node_2,normal_tissue 29 | /data/ceiling/data/MIL/camelyon17/,patient_075,patient_075_node_3,normal_tissue 30 | /data/ceiling/data/MIL/camelyon17/,patient_075,patient_075_node_4,tumor_tissue 31 | /data/ceiling/data/MIL/camelyon17/,patient_076,patient_076_node_0,normal_tissue 32 | /data/ceiling/data/MIL/camelyon17/,patient_076,patient_076_node_1,tumor_tissue 33 | /data/ceiling/data/MIL/camelyon17/,patient_076,patient_076_node_2,tumor_tissue 34 | /data/ceiling/data/MIL/camelyon17/,patient_076,patient_076_node_3,tumor_tissue 35 | /data/ceiling/data/MIL/camelyon17/,patient_077,patient_077_node_0,normal_tissue 36 | /data/ceiling/data/MIL/camelyon17/,patient_077,patient_077_node_1,normal_tissue 37 | /data/ceiling/data/MIL/camelyon17/,patient_077,patient_077_node_2,tumor_tissue 38 | /data/ceiling/data/MIL/camelyon17/,patient_077,patient_077_node_3,normal_tissue 39 | /data/ceiling/data/MIL/camelyon17/,patient_077,patient_077_node_4,tumor_tissue 40 | /data/ceiling/data/MIL/camelyon17/,patient_078,patient_078_node_0,normal_tissue 41 | /data/ceiling/data/MIL/camelyon17/,patient_078,patient_078_node_1,normal_tissue 42 | /data/ceiling/data/MIL/camelyon17/,patient_078,patient_078_node_2,normal_tissue 43 | /data/ceiling/data/MIL/camelyon17/,patient_078,patient_078_node_3,normal_tissue 44 | /data/ceiling/data/MIL/camelyon17/,patient_078,patient_078_node_4,normal_tissue 45 | /data/ceiling/data/MIL/camelyon17/,patient_079,patient_079_node_0,normal_tissue 46 | /data/ceiling/data/MIL/camelyon17/,patient_079,patient_079_node_1,normal_tissue 47 | /data/ceiling/data/MIL/camelyon17/,patient_079,patient_079_node_2,normal_tissue 48 | /data/ceiling/data/MIL/camelyon17/,patient_079,patient_079_node_3,normal_tissue 49 | /data/ceiling/data/MIL/camelyon17/,patient_079,patient_079_node_4,normal_tissue 50 | /data/ceiling/data/MIL/camelyon17/,patient_080,patient_080_node_2,tumor_tissue 51 | /data/ceiling/data/MIL/camelyon17/,patient_080,patient_080_node_3,tumor_tissue 52 | /data/ceiling/data/MIL/camelyon17/,patient_080,patient_080_node_4,tumor_tissue 53 | /data/ceiling/data/MIL/camelyon17/,patient_081,patient_081_node_0,normal_tissue 54 | /data/ceiling/data/MIL/camelyon17/,patient_081,patient_081_node_1,tumor_tissue 55 | /data/ceiling/data/MIL/camelyon17/,patient_081,patient_081_node_2,tumor_tissue 56 | /data/ceiling/data/MIL/camelyon17/,patient_081,patient_081_node_3,normal_tissue 57 | /data/ceiling/data/MIL/camelyon17/,patient_082,patient_082_node_0,normal_tissue 58 | /data/ceiling/data/MIL/camelyon17/,patient_082,patient_082_node_1,normal_tissue 59 | /data/ceiling/data/MIL/camelyon17/,patient_082,patient_082_node_2,normal_tissue 60 | /data/ceiling/data/MIL/camelyon17/,patient_082,patient_082_node_3,normal_tissue 61 | /data/ceiling/data/MIL/camelyon17/,patient_082,patient_082_node_4,normal_tissue 62 | /data/ceiling/data/MIL/camelyon17/,patient_083,patient_083_node_0,normal_tissue 63 | /data/ceiling/data/MIL/camelyon17/,patient_083,patient_083_node_1,normal_tissue 64 | /data/ceiling/data/MIL/camelyon17/,patient_083,patient_083_node_2,normal_tissue 65 | /data/ceiling/data/MIL/camelyon17/,patient_083,patient_083_node_3,normal_tissue 66 | /data/ceiling/data/MIL/camelyon17/,patient_083,patient_083_node_4,normal_tissue 67 | /data/ceiling/data/MIL/camelyon17/,patient_084,patient_084_node_0,normal_tissue 68 | /data/ceiling/data/MIL/camelyon17/,patient_084,patient_084_node_1,normal_tissue 69 | /data/ceiling/data/MIL/camelyon17/,patient_084,patient_084_node_2,tumor_tissue 70 | /data/ceiling/data/MIL/camelyon17/,patient_084,patient_084_node_3,tumor_tissue 71 | /data/ceiling/data/MIL/camelyon17/,patient_084,patient_084_node_4,normal_tissue 72 | /data/ceiling/data/MIL/camelyon17/,patient_085,patient_085_node_0,normal_tissue 73 | /data/ceiling/data/MIL/camelyon17/,patient_085,patient_085_node_1,normal_tissue 74 | /data/ceiling/data/MIL/camelyon17/,patient_085,patient_085_node_2,normal_tissue 75 | /data/ceiling/data/MIL/camelyon17/,patient_085,patient_085_node_3,normal_tissue 76 | /data/ceiling/data/MIL/camelyon17/,patient_085,patient_085_node_4,normal_tissue 77 | /data/ceiling/data/MIL/camelyon17/,patient_086,patient_086_node_1,normal_tissue 78 | /data/ceiling/data/MIL/camelyon17/,patient_086,patient_086_node_2,normal_tissue 79 | /data/ceiling/data/MIL/camelyon17/,patient_086,patient_086_node_3,normal_tissue 80 | /data/ceiling/data/MIL/camelyon17/,patient_087,patient_087_node_2,normal_tissue 81 | /data/ceiling/data/MIL/camelyon17/,patient_087,patient_087_node_3,normal_tissue 82 | /data/ceiling/data/MIL/camelyon17/,patient_087,patient_087_node_4,normal_tissue 83 | /data/ceiling/data/MIL/camelyon17/,patient_088,patient_088_node_0,normal_tissue 84 | /data/ceiling/data/MIL/camelyon17/,patient_088,patient_088_node_1,tumor_tissue 85 | /data/ceiling/data/MIL/camelyon17/,patient_088,patient_088_node_2,normal_tissue 86 | /data/ceiling/data/MIL/camelyon17/,patient_088,patient_088_node_3,normal_tissue 87 | /data/ceiling/data/MIL/camelyon17/,patient_089,patient_089_node_0,normal_tissue 88 | /data/ceiling/data/MIL/camelyon17/,patient_089,patient_089_node_1,normal_tissue 89 | /data/ceiling/data/MIL/camelyon17/,patient_089,patient_089_node_2,normal_tissue 90 | /data/ceiling/data/MIL/camelyon17/,patient_089,patient_089_node_3,tumor_tissue 91 | /data/ceiling/data/MIL/camelyon17/,patient_089,patient_089_node_4,normal_tissue 92 | /data/ceiling/data/MIL/camelyon17/,patient_090,patient_090_node_0,normal_tissue 93 | /data/ceiling/data/MIL/camelyon17/,patient_090,patient_090_node_1,normal_tissue 94 | /data/ceiling/data/MIL/camelyon17/,patient_090,patient_090_node_2,normal_tissue 95 | /data/ceiling/data/MIL/camelyon17/,patient_090,patient_090_node_3,normal_tissue 96 | /data/ceiling/data/MIL/camelyon17/,patient_090,patient_090_node_4,normal_tissue 97 | /data/ceiling/data/MIL/camelyon17/,patient_091,patient_091_node_0,normal_tissue 98 | /data/ceiling/data/MIL/camelyon17/,patient_091,patient_091_node_1,normal_tissue 99 | /data/ceiling/data/MIL/camelyon17/,patient_091,patient_091_node_2,tumor_tissue 100 | /data/ceiling/data/MIL/camelyon17/,patient_091,patient_091_node_3,tumor_tissue 101 | /data/ceiling/data/MIL/camelyon17/,patient_091,patient_091_node_4,tumor_tissue 102 | /data/ceiling/data/MIL/camelyon17/,patient_092,patient_092_node_0,tumor_tissue 103 | /data/ceiling/data/MIL/camelyon17/,patient_092,patient_092_node_1,tumor_tissue 104 | /data/ceiling/data/MIL/camelyon17/,patient_092,patient_092_node_2,normal_tissue 105 | /data/ceiling/data/MIL/camelyon17/,patient_092,patient_092_node_3,tumor_tissue 106 | /data/ceiling/data/MIL/camelyon17/,patient_092,patient_092_node_4,tumor_tissue 107 | /data/ceiling/data/MIL/camelyon17/,patient_093,patient_093_node_0,normal_tissue 108 | /data/ceiling/data/MIL/camelyon17/,patient_093,patient_093_node_1,normal_tissue 109 | /data/ceiling/data/MIL/camelyon17/,patient_093,patient_093_node_2,normal_tissue 110 | /data/ceiling/data/MIL/camelyon17/,patient_093,patient_093_node_3,normal_tissue 111 | /data/ceiling/data/MIL/camelyon17/,patient_093,patient_093_node_4,tumor_tissue 112 | /data/ceiling/data/MIL/camelyon17/,patient_094,patient_094_node_0,tumor_tissue 113 | /data/ceiling/data/MIL/camelyon17/,patient_094,patient_094_node_1,tumor_tissue 114 | /data/ceiling/data/MIL/camelyon17/,patient_094,patient_094_node_2,tumor_tissue 115 | /data/ceiling/data/MIL/camelyon17/,patient_094,patient_094_node_3,normal_tissue 116 | /data/ceiling/data/MIL/camelyon17/,patient_094,patient_094_node_4,tumor_tissue 117 | /data/ceiling/data/MIL/camelyon17/,patient_095,patient_095_node_0,tumor_tissue 118 | /data/ceiling/data/MIL/camelyon17/,patient_095,patient_095_node_1,normal_tissue 119 | /data/ceiling/data/MIL/camelyon17/,patient_095,patient_095_node_2,normal_tissue 120 | /data/ceiling/data/MIL/camelyon17/,patient_095,patient_095_node_3,normal_tissue 121 | /data/ceiling/data/MIL/camelyon17/,patient_095,patient_095_node_4,normal_tissue 122 | /data/ceiling/data/MIL/camelyon17/,patient_096,patient_096_node_0,tumor_tissue 123 | /data/ceiling/data/MIL/camelyon17/,patient_096,patient_096_node_1,normal_tissue 124 | /data/ceiling/data/MIL/camelyon17/,patient_096,patient_096_node_2,tumor_tissue 125 | /data/ceiling/data/MIL/camelyon17/,patient_096,patient_096_node_3,tumor_tissue 126 | /data/ceiling/data/MIL/camelyon17/,patient_096,patient_096_node_4,tumor_tissue 127 | /data/ceiling/data/MIL/camelyon17/,patient_097,patient_097_node_0,tumor_tissue 128 | /data/ceiling/data/MIL/camelyon17/,patient_097,patient_097_node_1,tumor_tissue 129 | /data/ceiling/data/MIL/camelyon17/,patient_097,patient_097_node_2,normal_tissue 130 | /data/ceiling/data/MIL/camelyon17/,patient_097,patient_097_node_3,tumor_tissue 131 | /data/ceiling/data/MIL/camelyon17/,patient_097,patient_097_node_4,tumor_tissue 132 | /data/ceiling/data/MIL/camelyon17/,patient_098,patient_098_node_0,normal_tissue 133 | /data/ceiling/data/MIL/camelyon17/,patient_098,patient_098_node_1,normal_tissue 134 | /data/ceiling/data/MIL/camelyon17/,patient_098,patient_098_node_2,normal_tissue 135 | /data/ceiling/data/MIL/camelyon17/,patient_098,patient_098_node_3,normal_tissue 136 | /data/ceiling/data/MIL/camelyon17/,patient_098,patient_098_node_4,normal_tissue 137 | /data/ceiling/data/MIL/camelyon17/,patient_099,patient_099_node_0,normal_tissue 138 | /data/ceiling/data/MIL/camelyon17/,patient_099,patient_099_node_1,normal_tissue 139 | /data/ceiling/data/MIL/camelyon17/,patient_099,patient_099_node_2,normal_tissue 140 | /data/ceiling/data/MIL/camelyon17/,patient_099,patient_099_node_3,normal_tissue 141 | /data/ceiling/data/MIL/camelyon17/,patient_099,patient_099_node_4,tumor_tissue 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /trainer_dsmil.py: -------------------------------------------------------------------------------- 1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | import torch.optim as optim 7 | import pdb 8 | import torch.nn.functional as F 9 | 10 | import pandas as pd 11 | import numpy as np 12 | 13 | from sklearn.preprocessing import label_binarize 14 | from sklearn.metrics import roc_auc_score, roc_curve 15 | from sklearn.metrics import auc as calc_auc 16 | 17 | from models.DSMIL import MILNet 18 | from utils import * 19 | 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | 23 | class Accuracy_Logger(object): 24 | """Accuracy logger""" 25 | def __init__(self, n_classes): 26 | super().__init__() 27 | self.n_classes = n_classes 28 | self.initialize() 29 | 30 | def initialize(self): 31 | self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] 32 | 33 | def log(self, Y_hat, Y): 34 | Y_hat = int(Y_hat) 35 | Y = int(Y) 36 | self.data[Y]["count"] += 1 37 | self.data[Y]["correct"] += (Y_hat == Y) 38 | 39 | def log_batch(self, Y_hat, Y): 40 | Y_hat = np.array(Y_hat).astype(int) 41 | Y = np.array(Y).astype(int) 42 | for label_class in np.unique(Y): 43 | cls_mask = Y == label_class 44 | self.data[label_class]["count"] += cls_mask.sum() 45 | self.data[label_class]["correct"] += (Y_hat[cls_mask] == Y[cls_mask]).sum() 46 | 47 | def get_summary(self, c): 48 | count = self.data[c]["count"] 49 | correct = self.data[c]["correct"] 50 | 51 | if count == 0: 52 | acc = None 53 | else: 54 | acc = float(correct) / count 55 | 56 | return acc, correct, count 57 | 58 | def summary(model, loader, n_classes): 59 | acc_logger = Accuracy_Logger(n_classes=n_classes) 60 | model.eval() 61 | test_loss = 0. 62 | test_error = 0. 63 | 64 | all_probs = np.zeros((len(loader), n_classes)) 65 | all_labels = np.zeros(len(loader)) 66 | 67 | slide_ids = loader.dataset.slide_data['slide_id'] 68 | patient_results = {} 69 | 70 | for batch_idx, (data, label, coords, nearest) in enumerate(loader): 71 | data, label = data.to(device), label.to(device) 72 | slide_id = slide_ids.iloc[batch_idx] 73 | with torch.inference_mode(): 74 | ins_prediction, bag_prediction, _, _ = model(data) 75 | Y_hat = torch.topk(bag_prediction.view(1, -1), 1, dim = 1)[1] 76 | acc_logger.log(Y_hat, label) 77 | 78 | max_prediction, _ = torch.max(ins_prediction, 0) 79 | Y_prob = F.softmax(bag_prediction, dim=-1) 80 | 81 | probs = Y_prob.cpu().numpy() 82 | all_probs[batch_idx] = probs 83 | all_labels[batch_idx] = label.item() 84 | 85 | patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}}) 86 | error = calculate_error(Y_hat, label) 87 | test_error += error 88 | 89 | test_error /= len(loader) 90 | 91 | if n_classes == 2: 92 | auc = roc_auc_score(all_labels, all_probs[:, 1]) 93 | aucs = [] 94 | else: 95 | aucs = [] 96 | binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)]) 97 | for class_idx in range(n_classes): 98 | if class_idx in all_labels: 99 | fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx]) 100 | aucs.append(calc_auc(fpr, tpr)) 101 | else: 102 | aucs.append(float('nan')) 103 | 104 | auc = np.nanmean(np.array(aucs)) 105 | 106 | return patient_results, test_error, auc, acc_logger 107 | 108 | def train_dsmil(datasets, 109 | save_path='./save_weights/camelyon16_dsmil_imagenet/', 110 | feature_dim = 512, 111 | n_classes = 2, 112 | fold = 0, 113 | writer_flag = True, 114 | max_epoch = 200, 115 | early_stopping = True, 116 | ): 117 | writer_dir = os.path.join(save_path, str(fold)) 118 | if not os.path.isdir(writer_dir): 119 | os.makedirs(writer_dir) 120 | if writer_flag: 121 | from tensorboardX import SummaryWriter 122 | writer = SummaryWriter(writer_dir, flush_secs=15) 123 | else: 124 | writer = None 125 | 126 | print("\nInit train/val/test splits...") 127 | train_split, val_split, test_split = datasets 128 | print("Training on {} samples".format(len(train_split))) 129 | print("Validating on {} samples".format(len(val_split))) 130 | print("Testing on {} samples".format(len(test_split))) 131 | 132 | print("\nInit loss function...") 133 | loss_fn = nn.CrossEntropyLoss() 134 | 135 | model = MILNet(feature_dim=feature_dim, 136 | n_classes=n_classes) 137 | _ = model.to(device) 138 | 139 | print("\nInit optimizer") 140 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4, momentum=0.9, weight_decay=1e-5) 141 | 142 | print('\nInit Loaders...', end=' ') 143 | train_loader = get_split_loader(train_split, training=True, testing = False, weighted = True) 144 | val_loader = get_split_loader(val_split, testing = False) 145 | test_loader = get_split_loader(test_split, testing = False) 146 | print('Done!') 147 | 148 | mini_loss = 10000 149 | retain = 0 150 | for epoch in range(max_epoch): 151 | train_loop(epoch, model, train_loader, optimizer, n_classes, writer, loss_fn) 152 | loss = validate(epoch, model, val_loader, n_classes, writer, loss_fn) 153 | if epoch % 10 == 0: 154 | torch.save(model.state_dict(), os.path.join(save_path, 's_{}_checkpoint_{}.pt'.format(fold, epoch))) 155 | if loss < mini_loss: 156 | print("loss decrease from:{} to {}".format(mini_loss, loss)) 157 | torch.save(model.state_dict(), os.path.join(save_path, 's_{}_checkpoint.pt'.format(fold))) 158 | mini_loss = loss 159 | retain = 0 160 | else: 161 | retain += 1 162 | print("Retain of early stopping: {} / {}".format(retain, 20)) 163 | if early_stopping: 164 | if retain > 20 and epoch > 50: 165 | print("Early stopping") 166 | break 167 | 168 | model.load_state_dict(torch.load(os.path.join(save_path, 's_{}_checkpoint.pt'.format(fold)))) 169 | summary(model, test_loader, n_classes) 170 | 171 | def train_loop(epoch, model, loader, optimizer, n_classes, writer, loss_fn): 172 | model.train() 173 | acc_logger = Accuracy_Logger(n_classes=n_classes) 174 | train_loss = 0. 175 | bag_loss = 0. 176 | ins_loss = 0. 177 | 178 | print('\n') 179 | for batch_idx, (data, label, coords, nearest) in enumerate(loader): 180 | data, label = data.to(device), label.to(device) 181 | 182 | ins_prediction, bag_prediction, _, _ = model(data) 183 | Y_hat = torch.topk(bag_prediction.view(1, -1), 1, dim = 1)[1] 184 | acc_logger.log(Y_hat, label) 185 | max_prediction, _ = torch.max(ins_prediction, 0) 186 | loss_bag = loss_fn(bag_prediction.view(1, -1), label) 187 | loss_ins = loss_fn(max_prediction.view(1, -1), label) 188 | loss = 0.5 * loss_bag + 0.5 * loss_ins 189 | 190 | loss_bag_value = loss_bag.item() 191 | loss_ins_value = loss_ins.item() 192 | loss_value = loss.item() 193 | 194 | train_loss += loss_value 195 | bag_loss += loss_bag 196 | ins_loss += loss_ins 197 | 198 | if (batch_idx + 1) % 20 == 0: 199 | print('batch {}, loss: {:.4f}, loss_bag: {:.4f}, loss_ins: {:.4f}, bag_size: {}'.format(batch_idx, loss_value, loss_bag_value, loss_ins_value, label.item(), data.size(0))) 200 | 201 | loss.backward() 202 | optimizer.step() 203 | optimizer.zero_grad() 204 | 205 | train_loss /= len(loader) 206 | bag_loss /= len(loader) 207 | ins_loss /= len(loader) 208 | 209 | print('Epoch: {}, train_loss: {:.4f}, bag_loss: {:.4f}, ins_loss: {:.4f}'.format(epoch, train_loss, bag_loss, ins_loss)) 210 | 211 | for i in range(n_classes): 212 | acc, correct, count = acc_logger.get_summary(i) 213 | print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count)) 214 | if writer: 215 | writer.add_scalar('train/class_{}_acc'.format(i), acc, epoch) 216 | if writer: 217 | writer.add_scalar('train/loss', train_loss, epoch) 218 | writer.add_scalar('train/loss_bag', loss_bag, epoch) 219 | writer.add_scalar('train/loss_ins', loss_ins, epoch) 220 | 221 | 222 | def validate(epoch, model, loader, n_classes, writer, loss_fn): 223 | model.eval() 224 | acc_logger = Accuracy_Logger(n_classes=n_classes) 225 | val_loss = 0. 226 | 227 | prob = np.zeros((len(loader), n_classes)) 228 | labels = np.zeros(len(loader)) 229 | 230 | with torch.no_grad(): 231 | for batch_idx, (data, label, coords, nearest) in enumerate(loader): 232 | data, label = data.to(device, non_blocking=True), label.to(device, non_blocking=True) 233 | 234 | ins_prediction, bag_prediction, _, _ = model(data) 235 | Y_hat = torch.topk(bag_prediction.view(1, -1), 1, dim = 1)[1] 236 | acc_logger.log(Y_hat, label) 237 | 238 | max_prediction, _ = torch.max(ins_prediction, 0) 239 | loss_bag = loss_fn(bag_prediction.view(1, -1), label) 240 | loss_ins = loss_fn(max_prediction.view(1, -1), label) 241 | loss = 0.5 * loss_bag + 0.5 * loss_ins 242 | Y_prob = F.softmax(bag_prediction, dim=-1) 243 | 244 | prob[batch_idx] = Y_prob.cpu().numpy() 245 | labels[batch_idx] = label.item() 246 | 247 | val_loss += loss.item() 248 | 249 | val_loss /= len(loader) 250 | if n_classes == 2: 251 | auc = roc_auc_score(labels, prob[:, 1]) 252 | 253 | else: 254 | auc = roc_auc_score(labels, prob, multi_class='ovr') 255 | if writer: 256 | writer.add_scalar('val/loss', val_loss, epoch) 257 | writer.add_scalar('val/auc', auc, epoch) 258 | print('\nVal Set, val_loss: {:.4f}, auc: {:.4f}'.format(val_loss, auc)) 259 | 260 | for i in range(n_classes): 261 | acc, correct, count = acc_logger.get_summary(i) 262 | print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count)) 263 | 264 | return val_loss 265 | 266 | if __name__ == "__main__": 267 | csv_path = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/camelyon16_total.csv' 268 | data_dir = '/data2/clh/camelyon16/resnet18_imagenet/' 269 | split_path = '/data1/ceiling/workspace/AttriMIL_v2/splits/camelyon16_100/' 270 | save_dir = './save_weights/camelyon16_dsmil_imagenet_100/' 271 | # {'normal_tissue':0, 'tumor_tissue':1} 272 | 273 | dataset = Generic_MIL_Dataset(csv_path = csv_path, 274 | data_dir = data_dir, 275 | shuffle = False, 276 | seed = 1, 277 | print_info = True, 278 | label_dict = {'normal_tissue':0, 'tumor_tissue':1}, 279 | patient_strat=False, 280 | ignore=[]) 281 | csv_path = [split_path + 'splits_{}.csv'.format(i) for i in range(5)] 282 | for step, name in enumerate(csv_path): 283 | train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False, csv_path=name) 284 | train_dsmil((train_dataset, val_dataset, test_dataset), 285 | save_path=save_dir, 286 | feature_dim = 512, 287 | n_classes = 2, 288 | fold = step, 289 | writer_flag = True, 290 | max_epoch = 200, 291 | early_stopping = False,) -------------------------------------------------------------------------------- /trainer_attrimil_abmil.py: -------------------------------------------------------------------------------- 1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | import torch.optim as optim 7 | import pdb 8 | import torch.nn.functional as F 9 | 10 | import pandas as pd 11 | import numpy as np 12 | 13 | from sklearn.preprocessing import label_binarize 14 | from sklearn.metrics import roc_auc_score, roc_curve 15 | from sklearn.metrics import auc as calc_auc 16 | 17 | from models.AttriMIL import AttriMIL 18 | from utils import * 19 | 20 | from constraints import spatial_constraint, rank_constraint 21 | import queue 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | 25 | 26 | class Accuracy_Logger(object): 27 | """Accuracy logger""" 28 | def __init__(self, n_classes): 29 | super().__init__() 30 | self.n_classes = n_classes 31 | self.initialize() 32 | 33 | def initialize(self): 34 | self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)] 35 | 36 | def log(self, Y_hat, Y): 37 | Y_hat = int(Y_hat) 38 | Y = int(Y) 39 | self.data[Y]["count"] += 1 40 | self.data[Y]["correct"] += (Y_hat == Y) 41 | 42 | def log_batch(self, Y_hat, Y): 43 | Y_hat = np.array(Y_hat).astype(int) 44 | Y = np.array(Y).astype(int) 45 | for label_class in np.unique(Y): 46 | cls_mask = Y == label_class 47 | self.data[label_class]["count"] += cls_mask.sum() 48 | self.data[label_class]["correct"] += (Y_hat[cls_mask] == Y[cls_mask]).sum() 49 | 50 | def get_summary(self, c): 51 | count = self.data[c]["count"] 52 | correct = self.data[c]["correct"] 53 | 54 | if count == 0: 55 | acc = None 56 | else: 57 | acc = float(correct) / count 58 | 59 | return acc, correct, count 60 | 61 | def summary(model, loader, n_classes): 62 | acc_logger = Accuracy_Logger(n_classes=n_classes) 63 | model.eval() 64 | test_loss = 0. 65 | test_error = 0. 66 | 67 | all_probs = np.zeros((len(loader), n_classes)) 68 | all_labels = np.zeros(len(loader)) 69 | 70 | slide_ids = loader.dataset.slide_data['slide_id'] 71 | patient_results = {} 72 | 73 | for batch_idx, (data, label, coords, nearest) in enumerate(loader): 74 | data, label = data.to(device), label.to(device) 75 | slide_id = slide_ids.iloc[batch_idx] 76 | with torch.inference_mode(): 77 | ins_prediction, bag_prediction, _, _ = model(data) 78 | Y_hat = torch.topk(bag_prediction.view(1, -1), 1, dim = 1)[1] 79 | acc_logger.log(Y_hat, label) 80 | 81 | max_prediction, _ = torch.max(ins_prediction, 0) 82 | Y_prob = F.softmax(bag_prediction, dim=-1) 83 | 84 | probs = Y_prob.cpu().numpy() 85 | all_probs[batch_idx] = probs 86 | all_labels[batch_idx] = label.item() 87 | 88 | patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}}) 89 | error = calculate_error(Y_hat, label) 90 | test_error += error 91 | 92 | test_error /= len(loader) 93 | 94 | if n_classes == 2: 95 | auc = roc_auc_score(all_labels, all_probs[:, 1]) 96 | aucs = [] 97 | else: 98 | aucs = [] 99 | binary_labels = label_binarize(all_labels, classes=[i for i in range(n_classes)]) 100 | for class_idx in range(n_classes): 101 | if class_idx in all_labels: 102 | fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx]) 103 | aucs.append(calc_auc(fpr, tpr)) 104 | else: 105 | aucs.append(float('nan')) 106 | 107 | auc = np.nanmean(np.array(aucs)) 108 | 109 | return patient_results, test_error, auc, acc_logger 110 | 111 | def train_abmil(datasets, 112 | save_path='./save_weights/camelyon16_abmil_imagenet/', 113 | feature_dim = 512, 114 | n_classes = 2, 115 | fold = 0, 116 | writer_flag = True, 117 | max_epoch = 200, 118 | early_stopping = True,): 119 | 120 | writer_dir = os.path.join(save_path, str(fold)) 121 | if not os.path.isdir(writer_dir): 122 | os.makedirs(writer_dir) 123 | if writer_flag: 124 | from tensorboardX import SummaryWriter 125 | writer = SummaryWriter(writer_dir, flush_secs=15) 126 | else: 127 | writer = None 128 | 129 | print("\nInit train/val/test splits...") 130 | train_split, val_split, test_split = datasets 131 | print("Training on {} samples".format(len(train_split))) 132 | print("Validating on {} samples".format(len(val_split))) 133 | print("Testing on {} samples".format(len(test_split))) 134 | 135 | print("\nInit loss function...") 136 | loss_fn = nn.CrossEntropyLoss() 137 | 138 | model = AttriMIL(dim=feature_dim, 139 | n_classes=n_classes) 140 | _ = model.to(device) 141 | 142 | print("\nInit optimizer") 143 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4, momentum=0.9, weight_decay=1e-5) 144 | 145 | print('\nInit Loaders...', end=' ') 146 | train_loader = get_split_loader(train_split, training=True, testing = False, weighted = True) 147 | val_loader = get_split_loader(val_split, testing = False) 148 | test_loader = get_split_loader(test_split, testing = False) 149 | print('Done!') 150 | 151 | mini_loss = 10000 152 | retain = 0 153 | 154 | for epoch in range(max_epoch): 155 | train_loop(epoch, model, train_loader, optimizer, n_classes, writer, loss_fn) 156 | loss = validate(epoch, model, val_loader, n_classes, writer, loss_fn) 157 | if epoch % 10 == 0: 158 | torch.save(model.state_dict(), os.path.join(save_path, 's_{}_checkpoint_{}.pt'.format(fold, epoch))) 159 | if loss < mini_loss: 160 | print("loss decrease from:{} to {}".format(mini_loss, loss)) 161 | torch.save(model.state_dict(), os.path.join(save_path, 's_{}_checkpoint.pt'.format(fold))) 162 | mini_loss = loss 163 | retain = 0 164 | else: 165 | retain += 1 166 | print("Retain of early stopping: {} / {}".format(retain, 20)) 167 | if early_stopping: 168 | if retain > 20 and epoch > 50: 169 | print("Early stopping") 170 | break 171 | 172 | model.load_state_dict(torch.load(os.path.join(save_path, 's_{}_checkpoint.pt'.format(fold)))) 173 | summary(model, test_loader, n_classes) 174 | 175 | def train_loop(epoch, model, loader, optimizer, n_classes, writer, loss_fn): 176 | model.train() 177 | acc_logger = Accuracy_Logger(n_classes=n_classes) 178 | train_loss = 0. 179 | bag_loss = 0. 180 | ins_loss = 0. 181 | 182 | print('\n') 183 | 184 | 185 | label_positive_list = [] 186 | label_negative_list = [] 187 | for i in range(n_classes): 188 | label_positive_list.append(queue.Queue(maxsize=4)) 189 | label_negative_list.append(queue.Queue(maxsize=4)) 190 | 191 | for batch_idx, (data, label, coords, nearest) in enumerate(loader): 192 | data, label = data.to(device), label.to(device) 193 | 194 | logits, Y_prob, Y_hat, attribute_score, results_dict = model(data) 195 | acc_logger.log(Y_hat, label) 196 | loss_bag = loss_fn(logits, label) 197 | loss_spa = spatial_constraint(attribute_score, n_classes, nearest, ks=3) 198 | loss_rank, label_positive_list, label_negative_list = rank_constraint(data, label, model, attribute_score, n_classes, label_positive_list, label_negative_list) 199 | 200 | loss = loss_bag + 1.0 * loss_spa + 5.0 * loss_rank 201 | 202 | loss_bag_value = loss_bag.item() 203 | loss_ins_value = loss_ins.item() 204 | loss_spa_value = loss_spa.item() 205 | loss_rank_value = loss_rank.item() 206 | loss_value = loss.item() 207 | 208 | train_loss += loss_value 209 | bag_loss += loss_bag 210 | 211 | if (batch_idx + 1) % 20 == 0: 212 | print('batch {}, loss: {:.4f}, loss_bag: {:.4f}, loss_spa: {:.4f}, bag_size: {}'.format(batch_idx, loss_value, loss_bag_value, loss_spa_value, label.item(), data.size(0))) 213 | 214 | loss.backward() 215 | optimizer.step() 216 | optimizer.zero_grad() 217 | 218 | train_loss /= len(loader) 219 | bag_loss /= len(loader) 220 | ins_loss /= len(loader) 221 | 222 | print('Epoch: {}, train_loss: {:.4f}, bag_loss: {:.4f}'.format(epoch, train_loss, bag_loss)) 223 | 224 | for i in range(n_classes): 225 | acc, correct, count = acc_logger.get_summary(i) 226 | print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count)) 227 | if writer: 228 | writer.add_scalar('train/class_{}_acc'.format(i), acc, epoch) 229 | if writer: 230 | writer.add_scalar('train/loss', train_loss, epoch) 231 | writer.add_scalar('train/loss_bag', loss_bag, epoch) 232 | 233 | 234 | def validate(epoch, model, loader, n_classes, writer, loss_fn): 235 | model.eval() 236 | acc_logger = Accuracy_Logger(n_classes=n_classes) 237 | val_loss = 0. 238 | 239 | prob = np.zeros((len(loader), n_classes)) 240 | labels = np.zeros(len(loader)) 241 | 242 | with torch.no_grad(): 243 | for batch_idx, (data, label, coords, nearest) in enumerate(loader): 244 | data, label = data.to(device, non_blocking=True), label.to(device, non_blocking=True) 245 | 246 | logits, Y_prob, Y_hat, attribute_score, results_dict = model(data) 247 | acc_logger.log(Y_hat, label) 248 | loss_bag = loss_fn(logits, label) 249 | 250 | loss = loss_bag 251 | 252 | prob[batch_idx] = Y_prob.cpu().numpy() 253 | labels[batch_idx] = label.item() 254 | 255 | val_loss += loss.item() 256 | 257 | val_loss /= len(loader) 258 | if n_classes == 2: 259 | auc = roc_auc_score(labels, prob[:, 1]) 260 | 261 | else: 262 | auc = roc_auc_score(labels, prob, multi_class='ovr') 263 | if writer: 264 | writer.add_scalar('val/loss', val_loss, epoch) 265 | writer.add_scalar('val/auc', auc, epoch) 266 | print('\nVal Set, val_loss: {:.4f}, auc: {:.4f}'.format(val_loss, auc)) 267 | 268 | for i in range(n_classes): 269 | acc, correct, count = acc_logger.get_summary(i) 270 | print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count)) 271 | 272 | return val_loss 273 | 274 | if __name__ == "__main__": 275 | csv_path = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/camelyon16_total.csv' 276 | data_dir = '/data2/clh/camelyon16/resnet18_imagenet/' 277 | split_path = '/data1/ceiling/workspace/AttriMIL_v2/splits/camelyon16_100/' 278 | save_dir = './save_weights/camelyon16_attrimil_imagenet_100/' 279 | # {'normal_tissue':0, 'tumor_tissue':1} 280 | 281 | dataset = Generic_MIL_Dataset(csv_path = csv_path, 282 | data_dir = data_dir, 283 | shuffle = False, 284 | seed = 1, 285 | print_info = True, 286 | label_dict = {'normal_tissue':0, 'tumor_tissue':1}, 287 | patient_strat=False, 288 | ignore=[]) 289 | csv_path = [split_path + 'splits_{}.csv'.format(i) for i in range(5)] 290 | for step, name in enumerate(csv_path): 291 | train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False, csv_path=name) 292 | train_abmil((train_dataset, val_dataset, test_dataset), 293 | save_path=save_dir, 294 | feature_dim = 512, 295 | n_classes = 2, 296 | fold = step, 297 | writer_flag = True, 298 | max_epoch = 200, 299 | early_stopping = False,) -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import pandas as pd 5 | import math 6 | import re 7 | import pdb 8 | import pickle 9 | from scipy import stats 10 | 11 | from torch.utils.data import Dataset 12 | import h5py 13 | 14 | def save_splits(split_datasets, column_keys, filename, boolean_style=False): 15 | splits = [split_datasets[i].slide_data['slide_id'] for i in range(len(split_datasets))] 16 | if not boolean_style: 17 | df = pd.concat(splits, ignore_index=True, axis=1) 18 | df.columns = column_keys 19 | else: 20 | df = pd.concat(splits, ignore_index = True, axis=0) 21 | index = df.values.tolist() 22 | one_hot = np.eye(len(split_datasets)).astype(bool) 23 | bool_array = np.repeat(one_hot, [len(dset) for dset in split_datasets], axis=0) 24 | df = pd.DataFrame(bool_array, index=index, columns = ['train', 'val', 'test']) 25 | 26 | df.to_csv(filename) 27 | print() 28 | 29 | class Generic_WSI_Classification_Dataset(Dataset): 30 | def __init__(self, 31 | csv_path = 'dataset_csv/ccrcc_clean.csv', 32 | shuffle = False, 33 | seed = 7, 34 | print_info = True, 35 | label_dict = {}, 36 | filter_dict = {}, 37 | ignore=[], 38 | patient_strat=False, 39 | label_col = None, 40 | patient_voting = 'max', 41 | ): 42 | """ 43 | Args: 44 | csv_file (string): Path to the csv file with annotations. 45 | shuffle (boolean): Whether to shuffle 46 | seed (int): random seed for shuffling the data 47 | print_info (boolean): Whether to print a summary of the dataset 48 | label_dict (dict): Dictionary with key, value pairs for converting str labels to int 49 | ignore (list): List containing class labels to ignore 50 | """ 51 | self.label_dict = label_dict 52 | self.num_classes = len(set(self.label_dict.values())) 53 | self.seed = seed 54 | self.print_info = print_info 55 | self.patient_strat = patient_strat 56 | self.train_ids, self.val_ids, self.test_ids = (None, None, None) 57 | self.data_dir = None 58 | if not label_col: 59 | label_col = 'label' 60 | self.label_col = label_col 61 | 62 | slide_data = pd.read_csv(csv_path) 63 | slide_data = self.filter_df(slide_data, filter_dict) 64 | slide_data = self.df_prep(slide_data, self.label_dict, ignore, self.label_col) 65 | 66 | ###shuffle data 67 | if shuffle: 68 | np.random.seed(seed) 69 | np.random.shuffle(slide_data) 70 | 71 | self.slide_data = slide_data 72 | 73 | self.patient_data_prep(patient_voting) 74 | self.cls_ids_prep() 75 | 76 | if print_info: 77 | self.summarize() 78 | 79 | def cls_ids_prep(self): 80 | # store ids corresponding each class at the patient or case level 81 | self.patient_cls_ids = [[] for i in range(self.num_classes)] 82 | for i in range(self.num_classes): 83 | self.patient_cls_ids[i] = np.where(self.patient_data['label'] == i)[0] 84 | 85 | # store ids corresponding each class at the slide level 86 | self.slide_cls_ids = [[] for i in range(self.num_classes)] 87 | for i in range(self.num_classes): 88 | self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0] 89 | 90 | def patient_data_prep(self, patient_voting='max'): 91 | patients = np.unique(np.array(self.slide_data['case_id'])) # get unique patients 92 | patient_labels = [] 93 | 94 | for p in patients: 95 | locations = self.slide_data[self.slide_data['case_id'] == p].index.tolist() 96 | assert len(locations) > 0 97 | label = self.slide_data['label'][locations].values 98 | if patient_voting == 'max': 99 | label = label.max() # get patient label (MIL convention) 100 | elif patient_voting == 'maj': 101 | label = stats.mode(label)[0] 102 | else: 103 | raise NotImplementedError 104 | patient_labels.append(label) 105 | 106 | self.patient_data = {'case_id':patients, 'label':np.array(patient_labels)} 107 | 108 | @staticmethod 109 | def df_prep(data, label_dict, ignore, label_col): 110 | if label_col != 'label': 111 | data['label'] = data[label_col].copy() 112 | 113 | mask = data['label'].isin(ignore) 114 | data = data[~mask] 115 | data.reset_index(drop=True, inplace=True) 116 | for i in data.index: 117 | key = data.loc[i, 'label'] 118 | data.at[i, 'label'] = label_dict[key] 119 | 120 | return data 121 | 122 | def filter_df(self, df, filter_dict={}): 123 | if len(filter_dict) > 0: 124 | filter_mask = np.full(len(df), True, bool) 125 | # assert 'label' not in filter_dict.keys() 126 | for key, val in filter_dict.items(): 127 | mask = df[key].isin(val) 128 | filter_mask = np.logical_and(filter_mask, mask) 129 | df = df[filter_mask] 130 | return df 131 | 132 | def __len__(self): 133 | if self.patient_strat: 134 | return len(self.patient_data['case_id']) 135 | 136 | else: 137 | return len(self.slide_data) 138 | 139 | def summarize(self): 140 | print("label column: {}".format(self.label_col)) 141 | print("label dictionary: {}".format(self.label_dict)) 142 | print("number of classes: {}".format(self.num_classes)) 143 | print("slide-level counts: ", '\n', self.slide_data['label'].value_counts(sort = False)) 144 | for i in range(self.num_classes): 145 | print('Patient-LVL; Number of samples registered in class %d: %d' % (i, self.patient_cls_ids[i].shape[0])) 146 | print('Slide-LVL; Number of samples registered in class %d: %d' % (i, self.slide_cls_ids[i].shape[0])) 147 | 148 | def get_split_from_df(self, all_splits, split_key='train'): 149 | split = all_splits[split_key] 150 | split = split.dropna().reset_index(drop=True) 151 | 152 | if len(split) > 0: 153 | mask = self.slide_data['slide_id'].isin(split.tolist()) 154 | df_slice = self.slide_data[mask].reset_index(drop=True) 155 | split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes) 156 | else: 157 | split = None 158 | 159 | return split 160 | 161 | def get_merged_split_from_df(self, all_splits, split_keys=['train']): 162 | merged_split = [] 163 | for split_key in split_keys: 164 | split = all_splits[split_key] 165 | split = split.dropna().reset_index(drop=True).tolist() 166 | merged_split.extend(split) 167 | 168 | if len(split) > 0: 169 | mask = self.slide_data['slide_id'].isin(merged_split) 170 | df_slice = self.slide_data[mask].reset_index(drop=True) 171 | split = Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes) 172 | else: 173 | split = None 174 | 175 | return split 176 | 177 | 178 | def return_splits(self, from_id=True, csv_path=None): 179 | 180 | 181 | if from_id: 182 | if len(self.train_ids) > 0: 183 | train_data = self.slide_data.loc[self.train_ids].reset_index(drop=True) 184 | train_split = Generic_Split(train_data, data_dir=self.data_dir, num_classes=self.num_classes) 185 | 186 | else: 187 | train_split = None 188 | 189 | if len(self.val_ids) > 0: 190 | val_data = self.slide_data.loc[self.val_ids].reset_index(drop=True) 191 | val_split = Generic_Split(val_data, data_dir=self.data_dir, num_classes=self.num_classes) 192 | 193 | else: 194 | val_split = None 195 | 196 | if len(self.test_ids) > 0: 197 | test_data = self.slide_data.loc[self.test_ids].reset_index(drop=True) 198 | test_split = Generic_Split(test_data, data_dir=self.data_dir, num_classes=self.num_classes) 199 | 200 | else: 201 | test_split = None 202 | 203 | 204 | else: 205 | assert csv_path 206 | all_splits = pd.read_csv(csv_path, dtype=self.slide_data['slide_id'].dtype) # Without "dtype=self.slide_data['slide_id'].dtype", read_csv() will convert all-number columns to a numerical type. Even if we convert numerical columns back to objects later, we may lose zero-padding in the process; the columns must be correctly read in from the get-go. When we compare the individual train/val/test columns to self.slide_data['slide_id'] in the get_split_from_df() method, we cannot compare objects (strings) to numbers or even to incorrectly zero-padded objects/strings. An example of this breaking is shown in https://github.com/andrew-weisman/clam_analysis/tree/main/datatype_comparison_bug-2021-12-01. 207 | train_split = self.get_split_from_df(all_splits, 'train') 208 | val_split = self.get_split_from_df(all_splits, 'val') 209 | test_split = self.get_split_from_df(all_splits, 'test') 210 | 211 | return train_split, val_split, test_split 212 | 213 | def get_list(self, ids): 214 | return self.slide_data['slide_id'][ids] 215 | 216 | def getlabel(self, ids): 217 | return self.slide_data['label'][ids] 218 | 219 | def __getitem__(self, idx): 220 | return None 221 | 222 | def test_split_gen(self, return_descriptor=False): 223 | 224 | if return_descriptor: 225 | index = [list(self.label_dict.keys())[list(self.label_dict.values()).index(i)] for i in range(self.num_classes)] 226 | columns = ['train', 'val', 'test'] 227 | df = pd.DataFrame(np.full((len(index), len(columns)), 0, dtype=np.int32), index= index, 228 | columns= columns) 229 | 230 | count = len(self.train_ids) 231 | print('\nnumber of training samples: {}'.format(count)) 232 | labels = self.getlabel(self.train_ids) 233 | unique, counts = np.unique(labels, return_counts=True) 234 | for u in range(len(unique)): 235 | print('number of samples in cls {}: {}'.format(unique[u], counts[u])) 236 | if return_descriptor: 237 | df.loc[index[u], 'train'] = counts[u] 238 | 239 | count = len(self.val_ids) 240 | print('\nnumber of val samples: {}'.format(count)) 241 | labels = self.getlabel(self.val_ids) 242 | unique, counts = np.unique(labels, return_counts=True) 243 | for u in range(len(unique)): 244 | print('number of samples in cls {}: {}'.format(unique[u], counts[u])) 245 | if return_descriptor: 246 | df.loc[index[u], 'val'] = counts[u] 247 | 248 | count = len(self.test_ids) 249 | print('\nnumber of test samples: {}'.format(count)) 250 | labels = self.getlabel(self.test_ids) 251 | unique, counts = np.unique(labels, return_counts=True) 252 | for u in range(len(unique)): 253 | print('number of samples in cls {}: {}'.format(unique[u], counts[u])) 254 | if return_descriptor: 255 | df.loc[index[u], 'test'] = counts[u] 256 | 257 | assert len(np.intersect1d(self.train_ids, self.test_ids)) == 0 258 | assert len(np.intersect1d(self.train_ids, self.val_ids)) == 0 259 | assert len(np.intersect1d(self.val_ids, self.test_ids)) == 0 260 | 261 | if return_descriptor: 262 | return df 263 | 264 | def save_split(self, filename): 265 | train_split = self.get_list(self.train_ids) 266 | val_split = self.get_list(self.val_ids) 267 | test_split = self.get_list(self.test_ids) 268 | df_tr = pd.DataFrame({'train': train_split}) 269 | df_v = pd.DataFrame({'val': val_split}) 270 | df_t = pd.DataFrame({'test': test_split}) 271 | df = pd.concat([df_tr, df_v, df_t], axis=1) 272 | df.to_csv(filename, index = False) 273 | 274 | 275 | class Generic_MIL_Dataset(Generic_WSI_Classification_Dataset): 276 | def __init__(self, 277 | data_dir, 278 | **kwargs): 279 | 280 | super(Generic_MIL_Dataset, self).__init__(**kwargs) 281 | self.data_dir = data_dir 282 | self.use_h5 = True 283 | 284 | def load_from_h5(self, toggle): 285 | self.use_h5 = toggle 286 | 287 | def __getitem__(self, idx): 288 | slide_id = self.slide_data['slide_id'][idx] 289 | label = self.slide_data['label'][idx] 290 | if type(self.data_dir) == dict: 291 | source = self.slide_data['source'][idx] 292 | data_dir = self.data_dir[source] 293 | else: 294 | data_dir = self.data_dir 295 | if not self.use_h5: 296 | if self.data_dir: 297 | full_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id)) 298 | features = torch.load(full_path) 299 | return features, label 300 | 301 | else: 302 | return slide_id, label 303 | else: 304 | full_path = os.path.join(data_dir,'h5_coords_files','{}.h5'.format(slide_id)) 305 | with h5py.File(full_path,'r') as hdf5_file: 306 | features = hdf5_file['features'][:] 307 | coords = hdf5_file['coords'][:] 308 | nearest = hdf5_file['nearest'][:] 309 | features = torch.from_numpy(features) 310 | return features, label, coords, nearest 311 | 312 | 313 | class Generic_Split(Generic_MIL_Dataset): 314 | def __init__(self, slide_data, data_dir=None, num_classes=2): 315 | self.use_h5 = True 316 | self.slide_data = slide_data 317 | self.data_dir = data_dir 318 | self.num_classes = num_classes 319 | self.slide_cls_ids = [[] for i in range(self.num_classes)] 320 | for i in range(self.num_classes): 321 | self.slide_cls_ids[i] = np.where(self.slide_data['label'] == i)[0] 322 | 323 | def __len__(self): 324 | return len(self.slide_data) 325 | -------------------------------------------------------------------------------- /datasets/camelyon16_total.csv: -------------------------------------------------------------------------------- 1 | case_id,slide_id,label 2 | patient_1,normal_001,normal_tissue 3 | patient_2,normal_002,normal_tissue 4 | patient_3,normal_003,normal_tissue 5 | patient_4,normal_004,normal_tissue 6 | patient_5,normal_005,normal_tissue 7 | patient_6,normal_006,normal_tissue 8 | patient_7,normal_007,normal_tissue 9 | patient_8,normal_008,normal_tissue 10 | patient_9,normal_009,normal_tissue 11 | patient_10,normal_010,normal_tissue 12 | patient_11,normal_011,normal_tissue 13 | patient_12,normal_012,normal_tissue 14 | patient_13,normal_013,normal_tissue 15 | patient_14,normal_014,normal_tissue 16 | patient_15,normal_015,normal_tissue 17 | patient_16,normal_016,normal_tissue 18 | patient_17,normal_017,normal_tissue 19 | patient_18,normal_018,normal_tissue 20 | patient_19,normal_019,normal_tissue 21 | patient_20,normal_020,normal_tissue 22 | patient_21,normal_021,normal_tissue 23 | patient_22,normal_022,normal_tissue 24 | patient_23,normal_023,normal_tissue 25 | patient_24,normal_024,normal_tissue 26 | patient_25,normal_025,normal_tissue 27 | patient_26,normal_026,normal_tissue 28 | patient_27,normal_027,normal_tissue 29 | patient_28,normal_028,normal_tissue 30 | patient_29,normal_029,normal_tissue 31 | patient_30,normal_030,normal_tissue 32 | patient_31,normal_031,normal_tissue 33 | patient_32,normal_032,normal_tissue 34 | patient_33,normal_033,normal_tissue 35 | patient_34,normal_034,normal_tissue 36 | patient_35,normal_035,normal_tissue 37 | patient_36,normal_036,normal_tissue 38 | patient_37,normal_037,normal_tissue 39 | patient_38,normal_038,normal_tissue 40 | patient_39,normal_039,normal_tissue 41 | patient_40,normal_040,normal_tissue 42 | patient_41,normal_041,normal_tissue 43 | patient_42,normal_042,normal_tissue 44 | patient_43,normal_043,normal_tissue 45 | patient_44,normal_044,normal_tissue 46 | patient_45,normal_045,normal_tissue 47 | patient_46,normal_046,normal_tissue 48 | patient_47,normal_047,normal_tissue 49 | patient_48,normal_048,normal_tissue 50 | patient_49,normal_049,normal_tissue 51 | patient_50,normal_050,normal_tissue 52 | patient_51,normal_051,normal_tissue 53 | patient_52,normal_052,normal_tissue 54 | patient_53,normal_053,normal_tissue 55 | patient_54,normal_054,normal_tissue 56 | patient_55,normal_055,normal_tissue 57 | patient_56,normal_056,normal_tissue 58 | patient_57,normal_057,normal_tissue 59 | patient_58,normal_058,normal_tissue 60 | patient_59,normal_059,normal_tissue 61 | patient_60,normal_060,normal_tissue 62 | patient_61,normal_061,normal_tissue 63 | patient_62,normal_062,normal_tissue 64 | patient_63,normal_063,normal_tissue 65 | patient_64,normal_064,normal_tissue 66 | patient_65,normal_065,normal_tissue 67 | patient_66,normal_066,normal_tissue 68 | patient_67,normal_067,normal_tissue 69 | patient_68,normal_068,normal_tissue 70 | patient_69,normal_069,normal_tissue 71 | patient_70,normal_070,normal_tissue 72 | patient_71,normal_071,normal_tissue 73 | patient_72,normal_072,normal_tissue 74 | patient_73,normal_073,normal_tissue 75 | patient_74,normal_074,normal_tissue 76 | patient_75,normal_075,normal_tissue 77 | patient_76,normal_076,normal_tissue 78 | patient_77,normal_077,normal_tissue 79 | patient_78,normal_078,normal_tissue 80 | patient_79,normal_079,normal_tissue 81 | patient_80,normal_080,normal_tissue 82 | patient_81,normal_081,normal_tissue 83 | patient_82,normal_082,normal_tissue 84 | patient_83,normal_083,normal_tissue 85 | patient_84,normal_084,normal_tissue 86 | patient_85,normal_085,normal_tissue 87 | patient_86,normal_087,normal_tissue 88 | patient_87,normal_088,normal_tissue 89 | patient_88,normal_089,normal_tissue 90 | patient_89,normal_090,normal_tissue 91 | patient_90,normal_091,normal_tissue 92 | patient_91,normal_092,normal_tissue 93 | patient_92,normal_093,normal_tissue 94 | patient_93,normal_094,normal_tissue 95 | patient_94,normal_095,normal_tissue 96 | patient_95,normal_096,normal_tissue 97 | patient_96,normal_097,normal_tissue 98 | patient_97,normal_098,normal_tissue 99 | patient_98,normal_099,normal_tissue 100 | patient_99,normal_100,normal_tissue 101 | patient_100,normal_101,normal_tissue 102 | patient_101,normal_102,normal_tissue 103 | patient_102,normal_103,normal_tissue 104 | patient_103,normal_104,normal_tissue 105 | patient_104,normal_105,normal_tissue 106 | patient_105,normal_106,normal_tissue 107 | patient_106,normal_107,normal_tissue 108 | patient_107,normal_108,normal_tissue 109 | patient_108,normal_109,normal_tissue 110 | patient_109,normal_110,normal_tissue 111 | patient_110,normal_111,normal_tissue 112 | patient_111,normal_112,normal_tissue 113 | patient_112,normal_113,normal_tissue 114 | patient_113,normal_114,normal_tissue 115 | patient_114,normal_115,normal_tissue 116 | patient_115,normal_116,normal_tissue 117 | patient_116,normal_117,normal_tissue 118 | patient_117,normal_118,normal_tissue 119 | patient_118,normal_119,normal_tissue 120 | patient_119,normal_120,normal_tissue 121 | patient_120,normal_121,normal_tissue 122 | patient_121,normal_122,normal_tissue 123 | patient_122,normal_123,normal_tissue 124 | patient_123,normal_124,normal_tissue 125 | patient_124,normal_125,normal_tissue 126 | patient_125,normal_126,normal_tissue 127 | patient_126,normal_127,normal_tissue 128 | patient_127,normal_128,normal_tissue 129 | patient_128,normal_129,normal_tissue 130 | patient_129,normal_130,normal_tissue 131 | patient_130,normal_131,normal_tissue 132 | patient_131,normal_132,normal_tissue 133 | patient_132,normal_133,normal_tissue 134 | patient_133,normal_134,normal_tissue 135 | patient_134,normal_135,normal_tissue 136 | patient_135,normal_136,normal_tissue 137 | patient_136,normal_137,normal_tissue 138 | patient_137,normal_138,normal_tissue 139 | patient_138,normal_139,normal_tissue 140 | patient_139,normal_140,normal_tissue 141 | patient_140,normal_141,normal_tissue 142 | patient_141,normal_142,normal_tissue 143 | patient_142,normal_143,normal_tissue 144 | patient_143,normal_144,normal_tissue 145 | patient_144,normal_145,normal_tissue 146 | patient_145,normal_146,normal_tissue 147 | patient_146,normal_147,normal_tissue 148 | patient_147,normal_148,normal_tissue 149 | patient_148,normal_149,normal_tissue 150 | patient_149,normal_150,normal_tissue 151 | patient_150,normal_151,normal_tissue 152 | patient_151,normal_152,normal_tissue 153 | patient_152,normal_153,normal_tissue 154 | patient_153,normal_154,normal_tissue 155 | patient_154,normal_155,normal_tissue 156 | patient_155,normal_156,normal_tissue 157 | patient_156,normal_157,normal_tissue 158 | patient_157,normal_158,normal_tissue 159 | patient_158,normal_159,normal_tissue 160 | patient_159,normal_160,normal_tissue 161 | patient_160,tumor_001,tumor_tissue 162 | patient_161,tumor_002,tumor_tissue 163 | patient_162,tumor_003,tumor_tissue 164 | patient_163,tumor_004,tumor_tissue 165 | patient_164,tumor_005,tumor_tissue 166 | patient_165,tumor_006,tumor_tissue 167 | patient_166,tumor_007,tumor_tissue 168 | patient_167,tumor_008,tumor_tissue 169 | patient_168,tumor_009,tumor_tissue 170 | patient_169,tumor_010,tumor_tissue 171 | patient_170,tumor_011,tumor_tissue 172 | patient_171,tumor_012,tumor_tissue 173 | patient_172,tumor_013,tumor_tissue 174 | patient_173,tumor_014,tumor_tissue 175 | patient_174,tumor_015,tumor_tissue 176 | patient_175,tumor_016,tumor_tissue 177 | patient_176,tumor_017,tumor_tissue 178 | patient_177,tumor_018,tumor_tissue 179 | patient_178,tumor_019,tumor_tissue 180 | patient_179,tumor_020,tumor_tissue 181 | patient_180,tumor_021,tumor_tissue 182 | patient_181,tumor_022,tumor_tissue 183 | patient_182,tumor_023,tumor_tissue 184 | patient_183,tumor_024,tumor_tissue 185 | patient_184,tumor_025,tumor_tissue 186 | patient_185,tumor_026,tumor_tissue 187 | patient_186,tumor_027,tumor_tissue 188 | patient_187,tumor_028,tumor_tissue 189 | patient_188,tumor_029,tumor_tissue 190 | patient_189,tumor_030,tumor_tissue 191 | patient_190,tumor_031,tumor_tissue 192 | patient_191,tumor_032,tumor_tissue 193 | patient_192,tumor_033,tumor_tissue 194 | patient_193,tumor_034,tumor_tissue 195 | patient_194,tumor_035,tumor_tissue 196 | patient_195,tumor_036,tumor_tissue 197 | patient_196,tumor_037,tumor_tissue 198 | patient_197,tumor_038,tumor_tissue 199 | patient_198,tumor_039,tumor_tissue 200 | patient_199,tumor_040,tumor_tissue 201 | patient_200,tumor_041,tumor_tissue 202 | patient_201,tumor_042,tumor_tissue 203 | patient_202,tumor_043,tumor_tissue 204 | patient_203,tumor_044,tumor_tissue 205 | patient_204,tumor_045,tumor_tissue 206 | patient_205,tumor_046,tumor_tissue 207 | patient_206,tumor_047,tumor_tissue 208 | patient_207,tumor_048,tumor_tissue 209 | patient_208,tumor_049,tumor_tissue 210 | patient_209,tumor_050,tumor_tissue 211 | patient_210,tumor_051,tumor_tissue 212 | patient_211,tumor_052,tumor_tissue 213 | patient_212,tumor_053,tumor_tissue 214 | patient_213,tumor_054,tumor_tissue 215 | patient_214,tumor_055,tumor_tissue 216 | patient_215,tumor_056,tumor_tissue 217 | patient_216,tumor_057,tumor_tissue 218 | patient_217,tumor_058,tumor_tissue 219 | patient_218,tumor_059,tumor_tissue 220 | patient_219,tumor_060,tumor_tissue 221 | patient_220,tumor_061,tumor_tissue 222 | patient_221,tumor_062,tumor_tissue 223 | patient_222,tumor_063,tumor_tissue 224 | patient_223,tumor_064,tumor_tissue 225 | patient_224,tumor_065,tumor_tissue 226 | patient_225,tumor_066,tumor_tissue 227 | patient_226,tumor_067,tumor_tissue 228 | patient_227,tumor_068,tumor_tissue 229 | patient_228,tumor_069,tumor_tissue 230 | patient_229,tumor_070,tumor_tissue 231 | patient_230,tumor_071,tumor_tissue 232 | patient_231,tumor_072,tumor_tissue 233 | patient_232,tumor_073,tumor_tissue 234 | patient_233,tumor_074,tumor_tissue 235 | patient_234,tumor_075,tumor_tissue 236 | patient_235,tumor_076,tumor_tissue 237 | patient_236,tumor_077,tumor_tissue 238 | patient_237,tumor_078,tumor_tissue 239 | patient_238,tumor_079,tumor_tissue 240 | patient_239,tumor_080,tumor_tissue 241 | patient_240,tumor_081,tumor_tissue 242 | patient_241,tumor_082,tumor_tissue 243 | patient_242,tumor_083,tumor_tissue 244 | patient_243,tumor_084,tumor_tissue 245 | patient_244,tumor_085,tumor_tissue 246 | patient_245,tumor_086,tumor_tissue 247 | patient_246,tumor_087,tumor_tissue 248 | patient_247,tumor_088,tumor_tissue 249 | patient_248,tumor_089,tumor_tissue 250 | patient_249,tumor_090,tumor_tissue 251 | patient_250,tumor_091,tumor_tissue 252 | patient_251,tumor_092,tumor_tissue 253 | patient_252,tumor_093,tumor_tissue 254 | patient_253,tumor_094,tumor_tissue 255 | patient_254,tumor_095,tumor_tissue 256 | patient_255,tumor_096,tumor_tissue 257 | patient_256,tumor_097,tumor_tissue 258 | patient_257,tumor_098,tumor_tissue 259 | patient_258,tumor_099,tumor_tissue 260 | patient_259,tumor_100,tumor_tissue 261 | patient_260,tumor_101,tumor_tissue 262 | patient_261,tumor_102,tumor_tissue 263 | patient_262,tumor_103,tumor_tissue 264 | patient_263,tumor_104,tumor_tissue 265 | patient_264,tumor_105,tumor_tissue 266 | patient_265,tumor_106,tumor_tissue 267 | patient_266,tumor_107,tumor_tissue 268 | patient_267,tumor_108,tumor_tissue 269 | patient_268,tumor_109,tumor_tissue 270 | patient_269,tumor_110,tumor_tissue 271 | patient_270,tumor_111,tumor_tissue 272 | patient_271,test_001,tumor_tissue 273 | patient_272,test_002,tumor_tissue 274 | patient_273,test_003,normal_tissue 275 | patient_274,test_004,tumor_tissue 276 | patient_275,test_005,normal_tissue 277 | patient_276,test_006,normal_tissue 278 | patient_277,test_007,normal_tissue 279 | patient_278,test_008,tumor_tissue 280 | patient_279,test_009,normal_tissue 281 | patient_280,test_010,tumor_tissue 282 | patient_281,test_011,tumor_tissue 283 | patient_282,test_012,normal_tissue 284 | patient_283,test_013,tumor_tissue 285 | patient_284,test_014,normal_tissue 286 | patient_285,test_015,normal_tissue 287 | patient_286,test_016,tumor_tissue 288 | patient_287,test_017,normal_tissue 289 | patient_288,test_018,normal_tissue 290 | patient_289,test_019,normal_tissue 291 | patient_290,test_020,normal_tissue 292 | patient_291,test_021,tumor_tissue 293 | patient_292,test_022,normal_tissue 294 | patient_293,test_023,normal_tissue 295 | patient_294,test_024,normal_tissue 296 | patient_295,test_025,normal_tissue 297 | patient_296,test_026,tumor_tissue 298 | patient_297,test_027,tumor_tissue 299 | patient_298,test_028,normal_tissue 300 | patient_299,test_029,tumor_tissue 301 | patient_300,test_030,tumor_tissue 302 | patient_301,test_031,normal_tissue 303 | patient_302,test_032,normal_tissue 304 | patient_303,test_033,tumor_tissue 305 | patient_304,test_034,normal_tissue 306 | patient_305,test_035,normal_tissue 307 | patient_306,test_036,normal_tissue 308 | patient_307,test_037,normal_tissue 309 | patient_308,test_038,tumor_tissue 310 | patient_309,test_039,normal_tissue 311 | patient_310,test_040,tumor_tissue 312 | patient_311,test_041,normal_tissue 313 | patient_312,test_042,normal_tissue 314 | patient_313,test_043,normal_tissue 315 | patient_314,test_044,normal_tissue 316 | patient_315,test_045,normal_tissue 317 | patient_316,test_046,tumor_tissue 318 | patient_317,test_047,normal_tissue 319 | patient_318,test_048,tumor_tissue 320 | patient_319,test_050,normal_tissue 321 | patient_320,test_051,tumor_tissue 322 | patient_321,test_052,tumor_tissue 323 | patient_322,test_053,normal_tissue 324 | patient_323,test_054,normal_tissue 325 | patient_324,test_055,normal_tissue 326 | patient_325,test_056,normal_tissue 327 | patient_326,test_057,normal_tissue 328 | patient_327,test_058,normal_tissue 329 | patient_328,test_059,normal_tissue 330 | patient_329,test_060,normal_tissue 331 | patient_330,test_061,tumor_tissue 332 | patient_331,test_062,normal_tissue 333 | patient_332,test_063,normal_tissue 334 | patient_333,test_064,tumor_tissue 335 | patient_334,test_065,tumor_tissue 336 | patient_335,test_066,tumor_tissue 337 | patient_336,test_067,normal_tissue 338 | patient_337,test_068,tumor_tissue 339 | patient_338,test_069,tumor_tissue 340 | patient_339,test_070,normal_tissue 341 | patient_340,test_071,tumor_tissue 342 | patient_341,test_072,normal_tissue 343 | patient_342,test_073,tumor_tissue 344 | patient_343,test_074,tumor_tissue 345 | patient_344,test_075,tumor_tissue 346 | patient_345,test_076,normal_tissue 347 | patient_346,test_077,normal_tissue 348 | patient_347,test_078,normal_tissue 349 | patient_348,test_079,tumor_tissue 350 | patient_349,test_080,normal_tissue 351 | patient_350,test_081,normal_tissue 352 | patient_351,test_082,tumor_tissue 353 | patient_352,test_083,normal_tissue 354 | patient_353,test_084,tumor_tissue 355 | patient_354,test_085,normal_tissue 356 | patient_355,test_086,normal_tissue 357 | patient_356,test_087,normal_tissue 358 | patient_357,test_088,normal_tissue 359 | patient_358,test_089,normal_tissue 360 | patient_359,test_090,tumor_tissue 361 | patient_360,test_091,normal_tissue 362 | patient_361,test_092,tumor_tissue 363 | patient_362,test_093,normal_tissue 364 | patient_363,test_094,tumor_tissue 365 | patient_364,test_095,normal_tissue 366 | patient_365,test_096,normal_tissue 367 | patient_366,test_097,tumor_tissue 368 | patient_367,test_098,normal_tissue 369 | patient_368,test_099,tumor_tissue 370 | patient_369,test_100,normal_tissue 371 | patient_370,test_101,normal_tissue 372 | patient_371,test_102,tumor_tissue 373 | patient_372,test_103,normal_tissue 374 | patient_373,test_104,tumor_tissue 375 | patient_374,test_105,tumor_tissue 376 | patient_375,test_106,normal_tissue 377 | patient_376,test_107,normal_tissue 378 | patient_377,test_108,tumor_tissue 379 | patient_378,test_109,normal_tissue 380 | patient_379,test_110,tumor_tissue 381 | patient_380,test_111,normal_tissue 382 | patient_381,test_112,normal_tissue 383 | patient_382,test_113,tumor_tissue 384 | patient_383,test_114,tumor_tissue 385 | patient_384,test_115,normal_tissue 386 | patient_385,test_116,tumor_tissue 387 | patient_386,test_117,tumor_tissue 388 | patient_387,test_118,normal_tissue 389 | patient_388,test_119,normal_tissue 390 | patient_389,test_120,normal_tissue 391 | patient_390,test_121,tumor_tissue 392 | patient_391,test_122,tumor_tissue 393 | patient_392,test_123,normal_tissue 394 | patient_393,test_124,normal_tissue 395 | patient_394,test_125,normal_tissue 396 | patient_395,test_126,normal_tissue 397 | patient_396,test_127,normal_tissue 398 | patient_397,test_128,normal_tissue 399 | patient_398,test_129,normal_tissue 400 | patient_399,test_130,normal_tissue 401 | -------------------------------------------------------------------------------- /datasets/camelyon17.csv: -------------------------------------------------------------------------------- 1 | case_id,slide_id,label 2 | patient_000,patient_000_node_0,normal_tissue 3 | patient_000,patient_000_node_1,normal_tissue 4 | patient_000,patient_000_node_2,normal_tissue 5 | patient_000,patient_000_node_3,normal_tissue 6 | patient_000,patient_000_node_4,normal_tissue 7 | patient_001,patient_001_node_0,normal_tissue 8 | patient_001,patient_001_node_1,normal_tissue 9 | patient_001,patient_001_node_2,normal_tissue 10 | patient_001,patient_001_node_3,normal_tissue 11 | patient_001,patient_001_node_4,normal_tissue 12 | patient_002,patient_002_node_0,normal_tissue 13 | patient_002,patient_002_node_1,normal_tissue 14 | patient_002,patient_002_node_2,normal_tissue 15 | patient_002,patient_002_node_3,normal_tissue 16 | patient_002,patient_002_node_4,normal_tissue 17 | patient_003,patient_003_node_0,normal_tissue 18 | patient_003,patient_003_node_1,normal_tissue 19 | patient_003,patient_003_node_2,normal_tissue 20 | patient_003,patient_003_node_3,normal_tissue 21 | patient_003,patient_003_node_4,normal_tissue 22 | patient_004,patient_004_node_0,normal_tissue 23 | patient_004,patient_004_node_1,normal_tissue 24 | patient_004,patient_004_node_2,normal_tissue 25 | patient_004,patient_004_node_3,normal_tissue 26 | patient_005,patient_005_node_0,normal_tissue 27 | patient_005,patient_005_node_1,normal_tissue 28 | patient_005,patient_005_node_2,normal_tissue 29 | patient_006,patient_006_node_0,normal_tissue 30 | patient_006,patient_006_node_1,normal_tissue 31 | patient_006,patient_006_node_3,normal_tissue 32 | patient_007,patient_007_node_0,normal_tissue 33 | patient_007,patient_007_node_1,normal_tissue 34 | patient_007,patient_007_node_2,normal_tissue 35 | patient_007,patient_007_node_3,normal_tissue 36 | patient_007,patient_007_node_4,tumor_tissue 37 | patient_008,patient_008_node_0,tumor_tissue 38 | patient_008,patient_008_node_1,normal_tissue 39 | patient_008,patient_008_node_2,normal_tissue 40 | patient_008,patient_008_node_3,normal_tissue 41 | patient_008,patient_008_node_4,normal_tissue 42 | patient_009,patient_009_node_1,tumor_tissue 43 | patient_009,patient_009_node_2,normal_tissue 44 | patient_009,patient_009_node_3,normal_tissue 45 | patient_009,patient_009_node_4,normal_tissue 46 | patient_010,patient_010_node_0,tumor_tissue 47 | patient_010,patient_010_node_1,normal_tissue 48 | patient_010,patient_010_node_2,normal_tissue 49 | patient_010,patient_010_node_3,normal_tissue 50 | patient_010,patient_010_node_4,tumor_tissue 51 | patient_011,patient_011_node_0,tumor_tissue 52 | patient_011,patient_011_node_1,normal_tissue 53 | patient_011,patient_011_node_2,normal_tissue 54 | patient_011,patient_011_node_3,normal_tissue 55 | patient_011,patient_011_node_4,normal_tissue 56 | patient_012,patient_012_node_0,tumor_tissue 57 | patient_012,patient_012_node_1,normal_tissue 58 | patient_012,patient_012_node_2,normal_tissue 59 | patient_012,patient_012_node_3,normal_tissue 60 | patient_012,patient_012_node_4,normal_tissue 61 | patient_013,patient_013_node_0,normal_tissue 62 | patient_013,patient_013_node_1,normal_tissue 63 | patient_013,patient_013_node_2,tumor_tissue 64 | patient_013,patient_013_node_3,tumor_tissue 65 | patient_013,patient_013_node_4,tumor_tissue 66 | patient_014,patient_014_node_0,normal_tissue 67 | patient_014,patient_014_node_1,normal_tissue 68 | patient_014,patient_014_node_2,tumor_tissue 69 | patient_014,patient_014_node_4,tumor_tissue 70 | patient_015,patient_015_node_0,normal_tissue 71 | patient_015,patient_015_node_1,tumor_tissue 72 | patient_015,patient_015_node_2,tumor_tissue 73 | patient_015,patient_015_node_3,normal_tissue 74 | patient_016,patient_016_node_0,tumor_tissue 75 | patient_016,patient_016_node_2,normal_tissue 76 | patient_016,patient_016_node_3,normal_tissue 77 | patient_016,patient_016_node_4,normal_tissue 78 | patient_017,patient_017_node_0,normal_tissue 79 | patient_017,patient_017_node_2,tumor_tissue 80 | patient_017,patient_017_node_3,tumor_tissue 81 | patient_017,patient_017_node_4,tumor_tissue 82 | patient_018,patient_018_node_0,normal_tissue 83 | patient_018,patient_018_node_1,tumor_tissue 84 | patient_018,patient_018_node_2,tumor_tissue 85 | patient_018,patient_018_node_3,tumor_tissue 86 | patient_019,patient_019_node_0,tumor_tissue 87 | patient_019,patient_019_node_1,tumor_tissue 88 | patient_019,patient_019_node_2,tumor_tissue 89 | patient_019,patient_019_node_3,tumor_tissue 90 | patient_019,patient_019_node_4,normal_tissue 91 | patient_020,patient_020_node_0,normal_tissue 92 | patient_020,patient_020_node_1,tumor_tissue 93 | patient_020,patient_020_node_2,tumor_tissue 94 | patient_020,patient_020_node_3,normal_tissue 95 | patient_020,patient_020_node_4,tumor_tissue 96 | patient_021,patient_021_node_0,tumor_tissue 97 | patient_021,patient_021_node_1,tumor_tissue 98 | patient_021,patient_021_node_2,normal_tissue 99 | patient_021,patient_021_node_3,tumor_tissue 100 | patient_021,patient_021_node_4,tumor_tissue 101 | patient_022,patient_022_node_0,tumor_tissue 102 | patient_022,patient_022_node_1,tumor_tissue 103 | patient_022,patient_022_node_2,normal_tissue 104 | patient_022,patient_022_node_3,normal_tissue 105 | patient_022,patient_022_node_4,tumor_tissue 106 | patient_023,patient_023_node_0,normal_tissue 107 | patient_023,patient_023_node_1,normal_tissue 108 | patient_023,patient_023_node_2,normal_tissue 109 | patient_023,patient_023_node_3,normal_tissue 110 | patient_023,patient_023_node_4,normal_tissue 111 | patient_024,patient_024_node_0,normal_tissue 112 | patient_024,patient_024_node_3,normal_tissue 113 | patient_025,patient_025_node_0,normal_tissue 114 | patient_025,patient_025_node_1,normal_tissue 115 | patient_025,patient_025_node_2,normal_tissue 116 | patient_025,patient_025_node_3,normal_tissue 117 | patient_025,patient_025_node_4,normal_tissue 118 | patient_026,patient_026_node_1,tumor_tissue 119 | patient_026,patient_026_node_2,tumor_tissue 120 | patient_026,patient_026_node_3,tumor_tissue 121 | patient_026,patient_026_node_4,tumor_tissue 122 | patient_027,patient_027_node_0,tumor_tissue 123 | patient_027,patient_027_node_1,normal_tissue 124 | patient_027,patient_027_node_2,tumor_tissue 125 | patient_027,patient_027_node_3,normal_tissue 126 | patient_027,patient_027_node_4,tumor_tissue 127 | patient_028,patient_028_node_0,tumor_tissue 128 | patient_028,patient_028_node_1,tumor_tissue 129 | patient_028,patient_028_node_2,tumor_tissue 130 | patient_028,patient_028_node_3,tumor_tissue 131 | patient_028,patient_028_node_4,normal_tissue 132 | patient_029,patient_029_node_0,tumor_tissue 133 | patient_029,patient_029_node_1,tumor_tissue 134 | patient_029,patient_029_node_3,tumor_tissue 135 | patient_029,patient_029_node_4,normal_tissue 136 | patient_030,patient_030_node_0,normal_tissue 137 | patient_030,patient_030_node_1,normal_tissue 138 | patient_030,patient_030_node_2,normal_tissue 139 | patient_030,patient_030_node_3,normal_tissue 140 | patient_030,patient_030_node_4,tumor_tissue 141 | patient_031,patient_031_node_0,normal_tissue 142 | patient_031,patient_031_node_1,normal_tissue 143 | patient_031,patient_031_node_2,normal_tissue 144 | patient_031,patient_031_node_3,normal_tissue 145 | patient_031,patient_031_node_4,normal_tissue 146 | patient_032,patient_032_node_0,normal_tissue 147 | patient_032,patient_032_node_1,tumor_tissue 148 | patient_032,patient_032_node_2,normal_tissue 149 | patient_032,patient_032_node_3,normal_tissue 150 | patient_032,patient_032_node_4,normal_tissue 151 | patient_033,patient_033_node_0,normal_tissue 152 | patient_033,patient_033_node_1,tumor_tissue 153 | patient_033,patient_033_node_2,tumor_tissue 154 | patient_033,patient_033_node_3,tumor_tissue 155 | patient_033,patient_033_node_4,normal_tissue 156 | patient_034,patient_034_node_0,normal_tissue 157 | patient_034,patient_034_node_1,normal_tissue 158 | patient_034,patient_034_node_2,normal_tissue 159 | patient_034,patient_034_node_3,tumor_tissue 160 | patient_034,patient_034_node_4,normal_tissue 161 | patient_035,patient_035_node_0,normal_tissue 162 | patient_035,patient_035_node_1,normal_tissue 163 | patient_035,patient_035_node_2,normal_tissue 164 | patient_035,patient_035_node_3,normal_tissue 165 | patient_035,patient_035_node_4,normal_tissue 166 | patient_036,patient_036_node_0,normal_tissue 167 | patient_036,patient_036_node_1,normal_tissue 168 | patient_036,patient_036_node_2,normal_tissue 169 | patient_036,patient_036_node_3,tumor_tissue 170 | patient_036,patient_036_node_4,normal_tissue 171 | patient_037,patient_037_node_0,normal_tissue 172 | patient_037,patient_037_node_2,normal_tissue 173 | patient_037,patient_037_node_3,normal_tissue 174 | patient_037,patient_037_node_4,normal_tissue 175 | patient_038,patient_038_node_0,tumor_tissue 176 | patient_038,patient_038_node_1,normal_tissue 177 | patient_038,patient_038_node_2,tumor_tissue 178 | patient_038,patient_038_node_3,normal_tissue 179 | patient_039,patient_039_node_0,tumor_tissue 180 | patient_039,patient_039_node_1,tumor_tissue 181 | patient_039,patient_039_node_2,normal_tissue 182 | patient_039,patient_039_node_3,normal_tissue 183 | patient_039,patient_039_node_4,normal_tissue 184 | patient_040,patient_040_node_0,normal_tissue 185 | patient_040,patient_040_node_1,normal_tissue 186 | patient_040,patient_040_node_3,normal_tissue 187 | patient_040,patient_040_node_4,normal_tissue 188 | patient_041,patient_041_node_1,normal_tissue 189 | patient_041,patient_041_node_2,normal_tissue 190 | patient_041,patient_041_node_3,normal_tissue 191 | patient_041,patient_041_node_4,normal_tissue 192 | patient_042,patient_042_node_0,normal_tissue 193 | patient_042,patient_042_node_1,normal_tissue 194 | patient_042,patient_042_node_2,normal_tissue 195 | patient_042,patient_042_node_3,tumor_tissue 196 | patient_042,patient_042_node_4,normal_tissue 197 | patient_043,patient_043_node_0,normal_tissue 198 | patient_043,patient_043_node_1,normal_tissue 199 | patient_043,patient_043_node_2,normal_tissue 200 | patient_043,patient_043_node_3,normal_tissue 201 | patient_043,patient_043_node_4,normal_tissue 202 | patient_044,patient_044_node_0,normal_tissue 203 | patient_044,patient_044_node_1,normal_tissue 204 | patient_044,patient_044_node_2,normal_tissue 205 | patient_044,patient_044_node_3,tumor_tissue 206 | patient_044,patient_044_node_4,tumor_tissue 207 | patient_045,patient_045_node_0,tumor_tissue 208 | patient_045,patient_045_node_1,tumor_tissue 209 | patient_045,patient_045_node_2,normal_tissue 210 | patient_045,patient_045_node_3,tumor_tissue 211 | patient_045,patient_045_node_4,tumor_tissue 212 | patient_046,patient_046_node_0,normal_tissue 213 | patient_046,patient_046_node_1,normal_tissue 214 | patient_046,patient_046_node_2,normal_tissue 215 | patient_046,patient_046_node_3,tumor_tissue 216 | patient_046,patient_046_node_4,tumor_tissue 217 | patient_047,patient_047_node_0,normal_tissue 218 | patient_047,patient_047_node_1,normal_tissue 219 | patient_047,patient_047_node_2,tumor_tissue 220 | patient_047,patient_047_node_3,normal_tissue 221 | patient_047,patient_047_node_4,normal_tissue 222 | patient_048,patient_048_node_0,tumor_tissue 223 | patient_048,patient_048_node_1,tumor_tissue 224 | patient_048,patient_048_node_2,normal_tissue 225 | patient_048,patient_048_node_3,normal_tissue 226 | patient_048,patient_048_node_4,normal_tissue 227 | patient_049,patient_049_node_0,normal_tissue 228 | patient_049,patient_049_node_1,tumor_tissue 229 | patient_049,patient_049_node_2,tumor_tissue 230 | patient_049,patient_049_node_3,normal_tissue 231 | patient_049,patient_049_node_4,normal_tissue 232 | patient_050,patient_050_node_0,normal_tissue 233 | patient_050,patient_050_node_1,normal_tissue 234 | patient_050,patient_050_node_2,normal_tissue 235 | patient_050,patient_050_node_3,normal_tissue 236 | patient_050,patient_050_node_4,tumor_tissue 237 | patient_051,patient_051_node_0,tumor_tissue 238 | patient_051,patient_051_node_1,tumor_tissue 239 | patient_051,patient_051_node_2,tumor_tissue 240 | patient_051,patient_051_node_3,normal_tissue 241 | patient_051,patient_051_node_4,tumor_tissue 242 | patient_052,patient_052_node_0,tumor_tissue 243 | patient_052,patient_052_node_1,tumor_tissue 244 | patient_052,patient_052_node_2,tumor_tissue 245 | patient_052,patient_052_node_3,tumor_tissue 246 | patient_052,patient_052_node_4,normal_tissue 247 | patient_053,patient_053_node_0,normal_tissue 248 | patient_053,patient_053_node_1,normal_tissue 249 | patient_053,patient_053_node_2,normal_tissue 250 | patient_053,patient_053_node_3,normal_tissue 251 | patient_053,patient_053_node_4,normal_tissue 252 | patient_054,patient_054_node_0,normal_tissue 253 | patient_054,patient_054_node_1,normal_tissue 254 | patient_054,patient_054_node_2,normal_tissue 255 | patient_054,patient_054_node_3,normal_tissue 256 | patient_054,patient_054_node_4,normal_tissue 257 | patient_055,patient_055_node_0,normal_tissue 258 | patient_055,patient_055_node_1,normal_tissue 259 | patient_055,patient_055_node_2,normal_tissue 260 | patient_055,patient_055_node_3,normal_tissue 261 | patient_055,patient_055_node_4,normal_tissue 262 | patient_056,patient_056_node_0,normal_tissue 263 | patient_056,patient_056_node_1,normal_tissue 264 | patient_056,patient_056_node_2,normal_tissue 265 | patient_056,patient_056_node_3,normal_tissue 266 | patient_056,patient_056_node_4,normal_tissue 267 | patient_057,patient_057_node_0,normal_tissue 268 | patient_057,patient_057_node_1,normal_tissue 269 | patient_057,patient_057_node_2,normal_tissue 270 | patient_057,patient_057_node_3,normal_tissue 271 | patient_057,patient_057_node_4,normal_tissue 272 | patient_058,patient_058_node_0,normal_tissue 273 | patient_058,patient_058_node_1,normal_tissue 274 | patient_058,patient_058_node_2,normal_tissue 275 | patient_058,patient_058_node_3,normal_tissue 276 | patient_058,patient_058_node_4,normal_tissue 277 | patient_059,patient_059_node_0,normal_tissue 278 | patient_059,patient_059_node_1,normal_tissue 279 | patient_059,patient_059_node_2,normal_tissue 280 | patient_059,patient_059_node_3,normal_tissue 281 | patient_059,patient_059_node_4,normal_tissue 282 | patient_060,patient_060_node_0,tumor_tissue 283 | patient_060,patient_060_node_1,normal_tissue 284 | patient_060,patient_060_node_2,normal_tissue 285 | patient_060,patient_060_node_4,normal_tissue 286 | patient_061,patient_061_node_0,normal_tissue 287 | patient_061,patient_061_node_1,normal_tissue 288 | patient_061,patient_061_node_2,normal_tissue 289 | patient_061,patient_061_node_3,normal_tissue 290 | patient_061,patient_061_node_4,tumor_tissue 291 | patient_062,patient_062_node_0,tumor_tissue 292 | patient_062,patient_062_node_1,tumor_tissue 293 | patient_062,patient_062_node_2,tumor_tissue 294 | patient_062,patient_062_node_3,tumor_tissue 295 | patient_062,patient_062_node_4,normal_tissue 296 | patient_063,patient_063_node_0,tumor_tissue 297 | patient_063,patient_063_node_1,normal_tissue 298 | patient_063,patient_063_node_2,tumor_tissue 299 | patient_063,patient_063_node_3,tumor_tissue 300 | patient_063,patient_063_node_4,tumor_tissue 301 | patient_064,patient_064_node_1,normal_tissue 302 | patient_064,patient_064_node_2,normal_tissue 303 | patient_064,patient_064_node_4,normal_tissue 304 | patient_065,patient_065_node_0,tumor_tissue 305 | patient_065,patient_065_node_1,normal_tissue 306 | patient_065,patient_065_node_2,normal_tissue 307 | patient_065,patient_065_node_3,normal_tissue 308 | patient_065,patient_065_node_4,tumor_tissue 309 | patient_066,patient_066_node_0,normal_tissue 310 | patient_066,patient_066_node_1,normal_tissue 311 | patient_066,patient_066_node_2,tumor_tissue 312 | patient_066,patient_066_node_3,tumor_tissue 313 | patient_066,patient_066_node_4,normal_tissue 314 | patient_067,patient_067_node_0,normal_tissue 315 | patient_067,patient_067_node_1,normal_tissue 316 | patient_067,patient_067_node_2,normal_tissue 317 | patient_067,patient_067_node_4,tumor_tissue 318 | patient_068,patient_068_node_0,normal_tissue 319 | patient_068,patient_068_node_2,normal_tissue 320 | patient_068,patient_068_node_3,normal_tissue 321 | patient_069,patient_069_node_0,tumor_tissue 322 | patient_069,patient_069_node_1,tumor_tissue 323 | patient_069,patient_069_node_2,tumor_tissue 324 | patient_069,patient_069_node_3,tumor_tissue 325 | patient_069,patient_069_node_4,normal_tissue 326 | patient_070,patient_070_node_0,normal_tissue 327 | patient_070,patient_070_node_1,normal_tissue 328 | patient_070,patient_070_node_2,normal_tissue 329 | patient_070,patient_070_node_3,tumor_tissue 330 | patient_070,patient_070_node_4,normal_tissue 331 | patient_071,patient_071_node_0,normal_tissue 332 | patient_071,patient_071_node_1,normal_tissue 333 | patient_071,patient_071_node_2,normal_tissue 334 | patient_071,patient_071_node_3,normal_tissue 335 | patient_071,patient_071_node_4,normal_tissue 336 | patient_072,patient_072_node_1,tumor_tissue 337 | patient_072,patient_072_node_2,tumor_tissue 338 | patient_072,patient_072_node_3,normal_tissue 339 | patient_072,patient_072_node_4,tumor_tissue 340 | patient_073,patient_073_node_0,normal_tissue 341 | patient_073,patient_073_node_1,tumor_tissue 342 | patient_073,patient_073_node_2,normal_tissue 343 | patient_073,patient_073_node_3,tumor_tissue 344 | patient_073,patient_073_node_4,normal_tissue 345 | patient_074,patient_074_node_0,normal_tissue 346 | patient_074,patient_074_node_1,normal_tissue 347 | patient_074,patient_074_node_2,normal_tissue 348 | patient_074,patient_074_node_3,normal_tissue 349 | patient_074,patient_074_node_4,tumor_tissue 350 | patient_075,patient_075_node_0,normal_tissue 351 | patient_075,patient_075_node_1,normal_tissue 352 | patient_075,patient_075_node_2,normal_tissue 353 | patient_075,patient_075_node_3,normal_tissue 354 | patient_075,patient_075_node_4,tumor_tissue 355 | patient_076,patient_076_node_0,normal_tissue 356 | patient_076,patient_076_node_1,tumor_tissue 357 | patient_076,patient_076_node_2,tumor_tissue 358 | patient_076,patient_076_node_3,tumor_tissue 359 | patient_077,patient_077_node_0,normal_tissue 360 | patient_077,patient_077_node_1,normal_tissue 361 | patient_077,patient_077_node_2,tumor_tissue 362 | patient_077,patient_077_node_3,normal_tissue 363 | patient_077,patient_077_node_4,tumor_tissue 364 | patient_078,patient_078_node_0,normal_tissue 365 | patient_078,patient_078_node_1,normal_tissue 366 | patient_078,patient_078_node_2,normal_tissue 367 | patient_078,patient_078_node_3,normal_tissue 368 | patient_078,patient_078_node_4,normal_tissue 369 | patient_079,patient_079_node_0,normal_tissue 370 | patient_079,patient_079_node_1,normal_tissue 371 | patient_079,patient_079_node_2,normal_tissue 372 | patient_079,patient_079_node_3,normal_tissue 373 | patient_079,patient_079_node_4,normal_tissue 374 | patient_080,patient_080_node_2,tumor_tissue 375 | patient_080,patient_080_node_3,tumor_tissue 376 | patient_080,patient_080_node_4,tumor_tissue 377 | patient_081,patient_081_node_0,normal_tissue 378 | patient_081,patient_081_node_1,tumor_tissue 379 | patient_081,patient_081_node_2,tumor_tissue 380 | patient_081,patient_081_node_3,normal_tissue 381 | patient_082,patient_082_node_0,normal_tissue 382 | patient_082,patient_082_node_1,normal_tissue 383 | patient_082,patient_082_node_2,normal_tissue 384 | patient_082,patient_082_node_3,normal_tissue 385 | patient_082,patient_082_node_4,normal_tissue 386 | patient_083,patient_083_node_0,normal_tissue 387 | patient_083,patient_083_node_1,normal_tissue 388 | patient_083,patient_083_node_2,normal_tissue 389 | patient_083,patient_083_node_3,normal_tissue 390 | patient_083,patient_083_node_4,normal_tissue 391 | patient_084,patient_084_node_0,normal_tissue 392 | patient_084,patient_084_node_1,normal_tissue 393 | patient_084,patient_084_node_2,tumor_tissue 394 | patient_084,patient_084_node_3,tumor_tissue 395 | patient_084,patient_084_node_4,normal_tissue 396 | patient_085,patient_085_node_0,normal_tissue 397 | patient_085,patient_085_node_1,normal_tissue 398 | patient_085,patient_085_node_2,normal_tissue 399 | patient_085,patient_085_node_3,normal_tissue 400 | patient_085,patient_085_node_4,normal_tissue 401 | patient_086,patient_086_node_1,normal_tissue 402 | patient_086,patient_086_node_2,normal_tissue 403 | patient_086,patient_086_node_3,normal_tissue 404 | patient_087,patient_087_node_2,normal_tissue 405 | patient_087,patient_087_node_3,normal_tissue 406 | patient_087,patient_087_node_4,normal_tissue 407 | patient_088,patient_088_node_0,normal_tissue 408 | patient_088,patient_088_node_1,tumor_tissue 409 | patient_088,patient_088_node_2,normal_tissue 410 | patient_088,patient_088_node_3,normal_tissue 411 | patient_089,patient_089_node_0,normal_tissue 412 | patient_089,patient_089_node_1,normal_tissue 413 | patient_089,patient_089_node_2,normal_tissue 414 | patient_089,patient_089_node_3,tumor_tissue 415 | patient_089,patient_089_node_4,normal_tissue 416 | patient_090,patient_090_node_0,normal_tissue 417 | patient_090,patient_090_node_1,normal_tissue 418 | patient_090,patient_090_node_2,normal_tissue 419 | patient_090,patient_090_node_3,normal_tissue 420 | patient_090,patient_090_node_4,normal_tissue 421 | patient_091,patient_091_node_0,normal_tissue 422 | patient_091,patient_091_node_1,normal_tissue 423 | patient_091,patient_091_node_2,tumor_tissue 424 | patient_091,patient_091_node_3,tumor_tissue 425 | patient_091,patient_091_node_4,tumor_tissue 426 | patient_092,patient_092_node_0,tumor_tissue 427 | patient_092,patient_092_node_1,tumor_tissue 428 | patient_092,patient_092_node_2,normal_tissue 429 | patient_092,patient_092_node_3,tumor_tissue 430 | patient_092,patient_092_node_4,tumor_tissue 431 | patient_093,patient_093_node_0,normal_tissue 432 | patient_093,patient_093_node_1,normal_tissue 433 | patient_093,patient_093_node_2,normal_tissue 434 | patient_093,patient_093_node_3,normal_tissue 435 | patient_093,patient_093_node_4,tumor_tissue 436 | patient_094,patient_094_node_0,tumor_tissue 437 | patient_094,patient_094_node_1,tumor_tissue 438 | patient_094,patient_094_node_2,tumor_tissue 439 | patient_094,patient_094_node_3,normal_tissue 440 | patient_094,patient_094_node_4,tumor_tissue 441 | patient_095,patient_095_node_0,tumor_tissue 442 | patient_095,patient_095_node_1,normal_tissue 443 | patient_095,patient_095_node_2,normal_tissue 444 | patient_095,patient_095_node_3,normal_tissue 445 | patient_095,patient_095_node_4,normal_tissue 446 | patient_096,patient_096_node_0,tumor_tissue 447 | patient_096,patient_096_node_1,normal_tissue 448 | patient_096,patient_096_node_2,tumor_tissue 449 | patient_096,patient_096_node_3,tumor_tissue 450 | patient_096,patient_096_node_4,tumor_tissue 451 | patient_097,patient_097_node_0,tumor_tissue 452 | patient_097,patient_097_node_1,tumor_tissue 453 | patient_097,patient_097_node_2,normal_tissue 454 | patient_097,patient_097_node_3,tumor_tissue 455 | patient_097,patient_097_node_4,tumor_tissue 456 | patient_098,patient_098_node_0,normal_tissue 457 | patient_098,patient_098_node_1,normal_tissue 458 | patient_098,patient_098_node_2,normal_tissue 459 | patient_098,patient_098_node_3,normal_tissue 460 | patient_098,patient_098_node_4,normal_tissue 461 | patient_099,patient_099_node_0,normal_tissue 462 | patient_099,patient_099_node_1,normal_tissue 463 | patient_099,patient_099_node_2,normal_tissue 464 | patient_099,patient_099_node_3,normal_tissue 465 | patient_099,patient_099_node_4,tumor_tissue 466 | --------------------------------------------------------------------------------