├── datasets
├── a
├── unitopatho_test.csv
├── camelyon16_test.csv
├── unitopatho_train.csv
├── camelyon17_seen.csv
├── camelyon16_total.csv
└── camelyon17.csv
├── visualization
├── SAC.png
├── ROCs.png
├── AttriMIL.png
├── Visualization.png
├── OODPerformance.png
└── AttributeScoring.png
├── env.yml
├── coords_to_feature.py
├── models
├── AttriMIL.py
├── resnet_custom_dep.py
├── ABMIL.py
├── TransMIL.py
├── DSMIL.py
├── MIL.py
└── S4MIL.py
├── constraints.py
├── create_3coords.py
├── .gitignore
├── tester_transmil.py
├── tester_mil.py
├── tester_dsmil.py
├── README.md
├── tester_attrimil_abmil.py
├── utils.py
├── trainer_mil.py
├── trainer_transmil.py
├── LICENSE
├── trainer_dsmil.py
├── trainer_attrimil_abmil.py
└── dataloader.py
/datasets/a:
--------------------------------------------------------------------------------
1 | a
2 |
--------------------------------------------------------------------------------
/visualization/SAC.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MedCAI/AttriMIL/HEAD/visualization/SAC.png
--------------------------------------------------------------------------------
/visualization/ROCs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MedCAI/AttriMIL/HEAD/visualization/ROCs.png
--------------------------------------------------------------------------------
/visualization/AttriMIL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MedCAI/AttriMIL/HEAD/visualization/AttriMIL.png
--------------------------------------------------------------------------------
/visualization/Visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MedCAI/AttriMIL/HEAD/visualization/Visualization.png
--------------------------------------------------------------------------------
/visualization/OODPerformance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MedCAI/AttriMIL/HEAD/visualization/OODPerformance.png
--------------------------------------------------------------------------------
/visualization/AttributeScoring.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MedCAI/AttriMIL/HEAD/visualization/AttributeScoring.png
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: attrimil
2 | dependencies:
3 | - python==3.10
4 | - conda-forge::openslide
5 | - pip
6 | - pip:
7 | - timm==0.9.8
8 | - torch
9 | - torchvision
10 | - h5py
11 | - pandas
12 | - PyYAML
13 | - opencv-python
14 | - matplotlib
15 | - scikit-learn
16 | - scipy
17 | - tqdm
18 | - openslide-python
19 | - git+https://github.com/oval-group/smooth-topk.git
20 | - tensorboardX
21 |
--------------------------------------------------------------------------------
/coords_to_feature.py:
--------------------------------------------------------------------------------
1 | import openslide
2 | import matplotlib.pyplot as plt
3 | import os
4 | import numpy as np
5 | import h5py
6 |
if __name__ == "__main__":
    # NOTE: adjust the paths (and patch_size) below before running.
    # For every *.h5 file in orgin_path, copy its 'coords' and 'features'
    # datasets, add the 'nearest' dataset from the same-named file in
    # coord_path, and write the merged result to save_path.
    orgin_path = '/data2/clh/NSCLC/resnet18_simclr/h5_files/'
    coord_path = '/data2/clh/NSCLC/coords/'
    save_path = '/data2/clh/NSCLC/resnet18_simclr/h5_coords_files/'
    patch_size = (256, 256)  # not used by this script
    start = 0  # skip the first `start` directory entries (resume support)
    # List the output directory once; re-listing it for every input file
    # made the loop quadratic in the number of slides.
    existing = set(os.listdir(save_path))
    for step, name in enumerate(os.listdir(orgin_path)):
        if step < start:
            continue
        if name in existing:
            print("exist:", name)
            continue

        if name.endswith('h5'):
            print("Loading:", name)
            # Inputs are opened read-only; `with` closes them even on error.
            with h5py.File(os.path.join(orgin_path, name), 'r') as src:
                coords = np.array(src['coords'])
                features = np.array(src['features'])
            with h5py.File(os.path.join(coord_path, name), 'r') as src:
                nearest = np.array(src['nearest'])
            with h5py.File(os.path.join(save_path, name), 'w') as dst:
                dst['coords'] = coords
                dst['features'] = features
                dst['nearest'] = nearest
            print("coords:{}, features:{}, nearest:{}".format(coords.shape, features.shape, nearest.shape))
37 |
--------------------------------------------------------------------------------
/models/AttriMIL.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
class Attn_Net_Gated(nn.Module):
    """Gated attention scorer.

    Computes ``attention_c(tanh(W_a x) * sigmoid(W_b x))``. An optional
    Dropout(0.25) follows each activation.

    Args:
        L: input feature size.
        D: hidden size of the two gating branches.
        dropout: append Dropout(0.25) to both branches when True.
        n_classes: number of attention scores produced per instance.

    forward(x) returns ``(scores, x)``: the scores have x's leading shape with
    a last dimension of ``n_classes``, and x is passed through unchanged.
    """

    def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):
        super(Attn_Net_Gated, self).__init__()
        tanh_layers = [nn.Linear(L, D), nn.Tanh()]
        sigmoid_layers = [nn.Linear(L, D), nn.Sigmoid()]
        if dropout:
            tanh_layers.append(nn.Dropout(0.25))
            sigmoid_layers.append(nn.Dropout(0.25))
        self.attention_a = nn.Sequential(*tanh_layers)
        self.attention_b = nn.Sequential(*sigmoid_layers)
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        gated = self.attention_a(x) * self.attention_b(x)
        scores = self.attention_c(gated)  # N x n_classes
        return scores, x
30 |
31 |
class AttriMIL(nn.Module):
    '''
    Multi-Branch ABMIL with constraints.

    Each class c has its own gated-attention branch and its own linear
    instance classifier. For an input bag h of shape (N, dim):
      * a_c = attention logits of branch c (length N),
      * s_c = instance scores of classifier c (length N),
      * attribute score = s_c * exp(a_c),
      * bag logit c = sum(s_c * exp(a_c)) / sum(exp(a_c)) + bias[c],
        i.e. an attention-weighted mean of instance scores plus a bias.

    Args:
        n_classes: number of classes / branches.
        dim: instance feature size.

    forward(h) returns (logits [1 x n_classes], Y_prob [softmax of logits],
    Y_hat [1 x 1 top-1 index], attribute_score [1 x n_classes x N], {}).
    '''
    def __init__(self, n_classes=2, dim=512):
        super().__init__()
        # Residual bottleneck adaptor applied to the input features.
        self.adaptor = nn.Sequential(nn.Linear(dim, dim//2),
                                     nn.ReLU(),
                                     nn.Linear(dim // 2 , dim))

        attention = []
        classifer = [nn.Linear(dim, 1) for i in range(n_classes)]
        for i in range(n_classes):
            attention.append(Attn_Net_Gated(L = dim, D = dim // 2,))
        self.attention_nets = nn.ModuleList(attention)
        self.classifiers = nn.ModuleList(classifer)
        self.n_classes = n_classes
        self.bias = nn.Parameter(torch.zeros(n_classes), requires_grad=True)

    def forward(self, h):
        h = h + self.adaptor(h)
        n_instances = h.size(0)
        # A_raw is multiplied with instance_score (on h.device) below, so it
        # must live on the same device; a CPU buffer fails for CUDA inputs.
        A_raw = torch.empty(self.n_classes, n_instances, device=h.device)  # n_classes x N
        instance_score = torch.empty(1, self.n_classes, n_instances).float().to(h.device)
        for c in range(self.n_classes):
            A, h = self.attention_nets[c](h)
            A = torch.transpose(A, 1, 0)  # 1 x N
            A_raw[c] = A
            instance_score[0, c] = self.classifiers[c](h)[:, 0]

        attribute_score = torch.empty(1, self.n_classes, n_instances).float().to(h.device)
        for c in range(self.n_classes):
            attribute_score[0, c] = instance_score[0, c] * torch.exp(A_raw[c])

        logits = torch.empty(1, self.n_classes).float().to(h.device)
        for c in range(self.n_classes):
            # Attention-weighted mean of the instance scores, plus a class bias.
            logits[0, c] = torch.sum(attribute_score[0, c], keepdim=True, dim=-1) / torch.sum(torch.exp(A_raw[c]), dim=-1) + self.bias[c]

        Y_hat = torch.topk(logits, 1, dim = 1)[1]
        Y_prob = F.softmax(logits, dim = 1)
        results_dict = {}
        return logits, Y_prob, Y_hat, attribute_score, results_dict
--------------------------------------------------------------------------------
/constraints.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from utils import *
4 | import os
5 | import queue
6 | from sklearn.preprocessing import label_binarize
7 | from sklearn.metrics import roc_auc_score, roc_curve
8 | from sklearn.metrics import auc as calc_auc
9 |
10 |
def spatial_constraint(A, n_classes, nearest, ks=3):
    """Spatial smoothness penalty on per-instance class scores.

    For every class c >= 1 (class 0 is skipped), each instance's score
    ``A[:, c]`` is compared with a "local prototype". The prototype is the
    score with the largest absolute value among the instances listed in that
    instance's row of ``nearest``. The penalty for class c is
    ``mean(|tanh(score - prototype)|)``; the class penalties are summed.

    Args:
        A: score tensor indexed as ``A[:, c]`` (presumably N x n_classes --
            confirm against the caller).
        n_classes: number of classes.
        nearest: 2-D integer index array (N x K) of row indices into ``A``.
        ks: unused; kept for interface compatibility.

    Returns:
        0-dim tensor on ``A``'s device.
    """
    # Accumulate on A's own device/dtype instead of relying on a global
    # `device` that this module only gets through `from utils import *`.
    loss_spatial = torch.zeros((), device=A.device, dtype=A.dtype)
    for c in range(1, n_classes):
        score = A[:, c]  # N
        nearest_score = score[nearest]  # N x K
        max_indices = torch.argmax(torch.abs(nearest_score), dim=1, keepdim=True)  # N x 1
        # squeeze only dim 1 so a single-instance bag keeps its N dimension
        local_prototype = nearest_score.gather(1, max_indices).squeeze(1)
        loss_spatial = loss_spatial + torch.mean(torch.abs(torch.tanh(score - local_prototype)))
    return loss_spatial
23 |
24 |
def rank_constraint(data, label, model, A, n_classes, label_positive_list, label_negative_list):
    """Ranking (hinge) loss between this bag's top instances and banked instances.

    For every class c, the instance with the highest score ``A[0, c]`` is
    selected. ``value`` is its score and ``h`` is its feature row
    ``data[i:i+1]``.

    * If the bag is labelled c, ``h`` is pushed into
      ``label_positive_list[c]``; when the bank is full, one entry is removed
      with ``get()`` first. If ``label_negative_list[c]`` is non-empty, one
      entry is taken with ``get()``, put back (rotating the bank), and
      re-scored with ``model``. Hinge terms (``torch.clamp(..., min=0.0)``)
      then push ``value`` above zero and the banked negative's score
      ``Ah[0, c]`` below zero. For c != 0 they also push ``value`` above
      ``Ah[0, c]``.
    * Otherwise ``h`` goes into ``label_negative_list[c]``, and the same
      procedure runs against ``label_positive_list[c]``. This pushes
      ``value`` below zero. For c != 0 it also pushes ``value`` below the
      banked positive's score; for c == 0 it pushes the banked positive's
      score above zero.

    Args:
        data: instance features, indexed along dim 0 (presumably N x dim --
            confirm with caller).
        label: bag label, compared with ``==`` against each class index.
        model: called as ``model(h)``; its 4th output is used as per-instance
            class scores indexed ``[0, c]``.
        A: per-instance class scores indexed ``A[0, c]`` (presumably
            1 x n_classes x N -- confirm with caller).
        n_classes: number of classes.
        label_positive_list, label_negative_list: per-class banks exposing
            full/get/put/empty (presumably ``queue.Queue`` objects with a
            maxsize -- confirm with caller).

    Returns:
        (loss_rank divided by n_classes, label_positive_list,
        label_negative_list). The banks are mutated in place and also returned.
    """
    # NOTE(review): `device` is not defined in this file; presumably it is
    # provided by `from utils import *` -- confirm.
    loss_rank = torch.tensor(0.0).to(device)
    for c in range(n_classes):
        if label == c:
            # Bag is labelled c: bank its top class-c instance as a positive.
            value, indice = torch.topk(A[0, c], k=1)
            h = data[indice.item(): indice.item() + 1] # top feature (1 x dim slice)
            if label_positive_list[c].full():
                _ = label_positive_list[c].get()  # make room before put()
            label_positive_list[c].put(h)
            if label_negative_list[c].empty():
                loss_rank = loss_rank + torch.tensor(0.0).to(device)
            else:
                # Rotate the negative bank and re-score one banked negative.
                h = label_negative_list[c].get()
                label_negative_list[c].put(h)
                _, _, _, Ah, _ = model(h.detach())
                if c != 0:
                    # relu(Ah - value) + relu(-value) + relu(Ah)
                    loss_rank = loss_rank + torch.clamp(torch.mean(Ah[0, c] - value), min=0.0) + torch.clamp(torch.mean(-value), min=0.0) + torch.clamp(torch.mean(Ah[0, c]), min=0.0)
                else:
                    # relu(-value) + relu(Ah)
                    loss_rank = loss_rank + torch.clamp(torch.mean(-value), min=0.0) + torch.clamp(torch.mean(Ah[0, c]), min=0.0)
        else:
            # Bag is not labelled c: bank its top class-c instance as a negative.
            value, indice = torch.topk(A[0, c], k=1)
            h = data[indice.item(): indice.item() + 1] # top feature (1 x dim slice)
            if label_negative_list[c].full():
                _ = label_negative_list[c].get()  # make room before put()
            label_negative_list[c].put(h)
            if label_positive_list[c].empty():
                loss_rank = loss_rank + torch.tensor(0.0).to(device)
            else:
                # Rotate the positive bank and re-score one banked positive.
                h = label_positive_list[c].get()
                label_positive_list[c].put(h)
                _, _, _, Ah, _ = model(h.detach())
                if c != 0:
                    # relu(value - Ah) + relu(value)
                    # NOTE(review): unlike the c == 0 case there is no relu(-Ah)
                    # term here; confirm this asymmetry is intended.
                    loss_rank = loss_rank + torch.clamp(torch.mean(value - Ah[0, c]), min=0.0) + torch.clamp(torch.mean(value), min=0.0)
                else:
                    # relu(value) + relu(-Ah)
                    loss_rank = loss_rank + torch.clamp(torch.mean(value), min=0.0) + torch.clamp(torch.mean(-Ah[0, c]), min=0.0)
    loss_rank = loss_rank / n_classes
    return loss_rank, label_positive_list, label_negative_list
62 |
--------------------------------------------------------------------------------
/models/resnet_custom_dep.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import timm
4 | import torch.nn.functional as F
5 | import torch
6 | import torchvision.models as models
7 |
8 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
9 |
class Network(torch.nn.Module):
    """timm ResNet-50 trunk truncated after ``layer3``, used as a patch feature extractor.

    Args:
        path: local checkpoint file passed to timm through
            ``pretrained_cfg_overlay``. Defaults to the SimCLR weights that
            were previously hard-coded here.
    """
    def __init__(self, path='./weights/simclr_resnet50.safetensors'):
        super().__init__()
        # timm only loads ``pretrained_cfg_overlay['file']`` when
        # ``pretrained=True``. With the former ``pretrained=False`` the
        # checkpoint was silently ignored and the trunk kept random weights.
        # NOTE(review): timm loads strictly here, so the checkpoint must
        # contain every resnet50 key (including the ``fc`` head). If it is
        # head-less, consider ``num_classes=0`` -- confirm against the file.
        model = timm.create_model('resnet50', pretrained=True, pretrained_cfg_overlay=dict(file=path))
        self.resnet = model
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Return globally average-pooled ``layer3`` features flattened to (batch, channels)."""
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.act1(x)
        x = self.resnet.maxpool(x)
        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.avgpool(x)
        x = x.view(x.shape[0], -1)
        return x
40 |
41 |
class ResNet18(torch.nn.Module):
    """ImageNet-pretrained torchvision ResNet-18 used as a patch feature extractor.

    Inputs are bilinearly resized to 256x256 and run through the whole
    convolutional trunk (stem plus layer1..layer4). The result is globally
    average-pooled and flattened to (batch, channels).
    """
    def __init__(self):
        super().__init__()
        self.resnet = models.resnet18(pretrained=True)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        net = self.resnet
        out = F.interpolate(x, (256, 256), mode='bilinear')
        trunk = (net.conv1, net.bn1, net.relu, net.maxpool,
                 net.layer1, net.layer2, net.layer3, net.layer4)
        for stage in trunk:
            out = stage(out)
        pooled = self.avgpool(out)
        return pooled.view(pooled.shape[0], -1)
79 |
class IClassifier(nn.Module):
    """Wrap a feature extractor and flatten its output to (batch, -1).

    Args:
        feature_extractor: module applied to the input; its output's first
            dimension is kept and all remaining dimensions are flattened.
    """
    def __init__(self, feature_extractor):
        super(IClassifier, self).__init__()
        self.feature_extractor = feature_extractor

    def forward(self, x):
        feats = self.feature_extractor(x)  # N x ...
        # reshape == view for contiguous outputs, and also accepts
        # non-contiguous ones (where view would raise).
        return feats.reshape(feats.shape[0], -1)
90 |
91 |
--------------------------------------------------------------------------------
/datasets/unitopatho_test.csv:
--------------------------------------------------------------------------------
1 | case_id,slide_id,label
2 | patient_1,120-B3-NORM_1,NORM
3 | patient_2,193-B4-NORM_1,NORM
4 | patient_3,155-B4-NORM_1,NORM
5 | patient_4,211-B5-NORM_1,NORM
6 | patient_5,186-B4-NORM_1,NORM
7 | patient_6,188-B4-NORM_1,NORM
8 | patient_7,263-B5-NORM_1,NORM
9 | patient_8,187-B4-NORM_1,NORM
10 | patient_9,248-B5-TVALG_5,TVA.LG
11 | patient_10,134-B3-TVALG_5,TVA.LG
12 | patient_11,243-B5-TVALG_5,TVA.LG
13 | patient_12,180-B4-TVALG_5,TVA.LG
14 | patient_13,254-B5-TVALG_5,TVA.LG
15 | patient_14,199-B4-TVALG_5,TVA.LG
16 | patient_15,TVA.LG CASO 3 - 2018-12-04 13.22.40_5,TVA.LG
17 | patient_16,197-B4-TVALG_5,TVA.LG
18 | patient_17,TA.HG CASO 2 - 2018-12-04 13.31.38_2,TA.HG
19 | patient_18,TA.HG CASO 12 - 2019-03-04 18.16.39_2,TA.HG
20 | patient_19,129-B3-TAHG_2,TA.HG
21 | patient_20,204-B5-TAHG_2,TA.HG
22 | patient_21,85-B2-TAHG_2,TA.HG
23 | patient_22,106-B3-TAHG_2,TA.HG
24 | patient_23,TA.HG CASO 4 - 2019-03-04 10.26.19_2,TA.HG
25 | patient_24,56-B2-TAHG_2,TA.HG
26 | patient_25,119-B3-TAHG_2,TA.HG
27 | patient_26,147-B3-HP_0,HP
28 | patient_27,110-B3-HP_0,HP
29 | patient_28,242-B5-HP_0,HP
30 | patient_29,HP CASO 27 - 2019-03-04 17.03.07_0,HP
31 | patient_30,192-B4-HP_0,HP
32 | patient_31,246-B5-HP_0,HP
33 | patient_32,160-B4-HP_0,HP
34 | patient_33,HP CASO 25 - 2019-03-04 09.16.56_0,HP
35 | patient_34,164-B4-HP_0,HP
36 | patient_35,202-B5-HP_0,HP
37 | patient_36,241-B5-TALG_3,TA.LG
38 | patient_37,TA.LG CASO 96_3,TA.LG
39 | patient_38,53-B2-TALG_3,TA.LG
40 | patient_39,TA.LG CASO 44 - 2019-03-01 08.25.42_3,TA.LG
41 | patient_40,TA.LG CASO 94_3,TA.LG
42 | patient_41,67-B2-TALG_3,TA.LG
43 | patient_42,222-B5-TALG_3,TA.LG
44 | patient_43,TA.LG CASO 90_3,TA.LG
45 | patient_44,64-B2-TALG_3,TA.LG
46 | patient_45,122-B3-TALG_3,TA.LG
47 | patient_46,TA.LG CASO 61 - 2019-03-04 16.54.59_3,TA.LG
48 | patient_47,TA.LG CASO 87_3,TA.LG
49 | patient_48,121-B3-TALG_3,TA.LG
50 | patient_49,11-B1TALG_3,TA.LG
51 | patient_50,TA.LG CASO 43 - 2019-03-01 08.01.10_3,TA.LG
52 | patient_51,63-B2-TALG_3,TA.LG
53 | patient_52,TA.LG CASO 1 - 2018-12-04 12.52.01_3,TA.LG
54 | patient_53,TA.LG CASO 63 - 2019-03-04 17.23.02_3,TA.LG
55 | patient_54,TA.LG CASO 98_3,TA.LG
56 | patient_55,TA.LG CASO 92 A1_3,TA.LG
57 | patient_56,TA.LG CASO 101 D1_3,TA.LG
58 | patient_57,123-B3-TALG_3,TA.LG
59 | patient_58,99-B2-TALG_3,TA.LG
60 | patient_59,175-B4-TALG_3,TA.LG
61 | patient_60,TA.LG CASO 95_3,TA.LG
62 | patient_61,71-B2-TALG_3,TA.LG
63 | patient_62,88-B2-TALG_3,TA.LG
64 | patient_63,TA.LG CASO 54 - 2019-03-04 09.33.07_3,TA.LG
65 | patient_64,126-B3-TALG_3,TA.LG
66 | patient_65,TA.LG CASO 48 - 2019-03-01 09.02.35_3,TA.LG
67 | patient_66,141-B3-TALG_3,TA.LG
68 | patient_67,214-B5-TALG_3,TA.LG
69 | patient_68,179-B4-TALG_3,TA.LG
70 | patient_69,TA.LG CASO 12 - 2018-12-04 13.48.28_3,TA.LG
71 | patient_70,77-B2-TALG_3,TA.LG
72 | patient_71,TA.LG CASO 103_3,TA.LG
73 | patient_72,97-B2-TALG_3,TA.LG
74 | patient_73,183-B4-TALG_3,TA.LG
75 | patient_74,107-B3-TALG_3,TA.LG
76 | patient_75,195-B4-TALG_3,TA.LG
77 | patient_76,78-B2-TALG_3,TA.LG
78 | patient_77,TA.LG CASO 46 - 2019-03-01 08.37.03_3,TA.LG
79 | patient_78,128-B3-TALG_3,TA.LG
80 | patient_79,247-B5-TALG_3,TA.LG
81 | patient_80,253-B5-TALG_3,TA.LG
82 | patient_81,TA.LG CASO 83_3,TA.LG
83 | patient_82,92-B2-TVAHG_4,TVA.HG
84 | patient_83,240-B5-TVAHG_4,TVA.HG
85 | patient_84,59-B2-TVAHG_4,TVA.HG
86 | patient_85,TVA.HG CASO 11_4,TVA.HG
87 | patient_86,TVA.HG CASO 13_4,TVA.HG
88 | patient_87,238-B5-TVAHG_4,TVA.HG
89 | patient_88,TVA.HG CASO 9 - 2019-03-04 10.19.16_4,TVA.HG
90 |
--------------------------------------------------------------------------------
/models/ABMIL.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
class Attn_Net_Gated(nn.Module):
    """Gated attention: ``attention_c(tanh(W_a x) * sigmoid(W_b x))``.

    Args:
        L: input feature size.
        D: hidden size of each gating branch.
        dropout: when True, each branch ends with Dropout(0.25).
        n_classes: number of attention scores per instance.

    forward(x) returns ``(scores, x)``: scores with last dimension
    ``n_classes``, and the unchanged input.
    """

    def __init__(self, L = 1024, D = 256, dropout = False, n_classes = 1):
        super(Attn_Net_Gated, self).__init__()

        def branch(activation):
            layers = [nn.Linear(L, D), activation]
            if dropout:
                layers.append(nn.Dropout(0.25))
            return nn.Sequential(*layers)

        self.attention_a = branch(nn.Tanh())
        self.attention_b = branch(nn.Sigmoid())
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        tanh_out = self.attention_a(x)
        gate = self.attention_b(x)
        return self.attention_c(tanh_out.mul(gate)), x  # N x n_classes
30 |
31 |
class ABMIL(nn.Module):
    """Single-branch gated-attention MIL with a residual feature adaptor.

    Args:
        n_classes: number of output classes.
        dim: instance feature size.

    forward(h) expects a 2-D bag (instances along dim 0). It returns
    (logits [1 x n_classes], Y_prob, Y_hat [1 x 1], A_raw [1 x N attention
    logits before softmax], {'features': M [1 x dim pooled embedding]}).
    """
    def __init__(self, n_classes=2, dim=512):
        super().__init__()
        self.adaptor = nn.Sequential(nn.Linear(dim, dim//2),
                                     nn.ReLU(),
                                     nn.Linear(dim // 2 , dim))

        self.attention_net = Attn_Net_Gated(L = dim, D = dim // 2, n_classes=1)
        self.classifier = nn.Linear(dim, n_classes)
        self.n_classes = n_classes

    def forward(self, h):
        h = h + self.adaptor(h)
        # (The former pre-allocated `logits` buffer was dead: it was always
        # overwritten by the classifier output below.)
        A, h = self.attention_net(h)
        A = torch.transpose(A, 1, 0)  # 1 x N
        A_raw = A  # pre-softmax attention, returned to the caller
        A = F.softmax(A, dim=1)  # softmax over N

        M = torch.mm(A, h)  # 1 x dim
        logits = self.classifier(M)  # 1 x n_classes
        Y_hat = torch.topk(logits, 1, dim = 1)[1]
        Y_prob = F.softmax(logits, dim = 1)
        results_dict = {'features': M}
        return logits, Y_prob, Y_hat, A_raw, results_dict
58 |
59 |
class ABMIL_MB(nn.Module):
    """Multi-branch ABMIL: one gated-attention branch and one linear head per class.

    Args:
        n_classes: number of classes / branches.
        dim: instance feature size.

    forward(h) expects a 2-D bag (instances along dim 0). It returns
    (logits [1 x n_classes], Y_prob, Y_hat [1 x 1], A_raw [n_classes x N],
    {'features': M}).
    Despite its name, A_raw holds the softmax-normalised attention of each
    branch.
    NOTE(review): 'features' is the pooled embedding of the last branch only.
    """
    def __init__(self, n_classes=2, dim=512):
        super().__init__()
        self.adaptor = nn.Sequential(nn.Linear(dim, dim//2),
                                     nn.ReLU(),
                                     nn.Linear(dim // 2 , dim))

        attention = []
        classifer = [nn.Linear(dim, 1) for i in range(n_classes)]
        for i in range(n_classes):
            attention.append(Attn_Net_Gated(L = dim, D = dim // 2,))
        self.attention_nets = nn.ModuleList(attention)
        self.classifiers = nn.ModuleList(classifer)
        self.n_classes = n_classes

    def forward(self, h):
        h = h + self.adaptor(h)
        logits = torch.empty(1, self.n_classes).float().to(h.device)
        # Allocate on h's device so the returned attention lives with the
        # other outputs (it used to be a CPU tensor even for CUDA inputs).
        A_raw = torch.empty(self.n_classes, h.size(0), device=h.device)  # n_classes x N
        for c in range(self.n_classes):
            A, h = self.attention_nets[c](h)
            A = torch.transpose(A, 1, 0)  # 1 x N
            A = F.softmax(A, dim=1)  # softmax over N
            A_raw[c] = A
            M = torch.mm(A, h)  # 1 x dim
            logits[0, c] = self.classifiers[c](M)
        Y_hat = torch.topk(logits, 1, dim = 1)[1]
        Y_prob = F.softmax(logits, dim = 1)
        results_dict = {'features': M}
        return logits, Y_prob, Y_hat, A_raw, results_dict
91 |
--------------------------------------------------------------------------------
/create_3coords.py:
--------------------------------------------------------------------------------
1 | import openslide
2 | import matplotlib.pyplot as plt
3 | import os
4 | import numpy as np
5 | import h5py
6 | from joblib import Parallel, delayed
7 |
8 |
def _neighbour_indices(coords, patch_size):
    """Return an (N, 9) int64 table: own index, then 8 neighbour indices (own index if absent)."""
    # Map each coordinate pair to its FIRST row index. This matches the old
    # `np.where(...)[0][0]` lookup when coordinates are duplicated.
    rows = coords.tolist()
    index_of = {}
    for idx, (a, b) in enumerate(rows):
        index_of.setdefault((a, b), idx)
    d0, d1 = patch_size
    # Column order (names from the original code): left, right, up, down,
    # left_up, left_down, right_up, right_down. "left" is (c0, c1 - patch_size[1]).
    offsets = ((0, -d1), (0, d1), (-d0, 0), (d0, 0),
               (-d0, -d1), (d0, -d1), (-d0, d1), (d0, d1))
    nearest = np.empty((len(rows), 1 + len(offsets)), dtype=np.int64)
    for idx, (a, b) in enumerate(rows):
        nearest[idx, 0] = idx
        for col, (da, db) in enumerate(offsets, start=1):
            nearest[idx, col] = index_of.get((a + da, b + db), idx)
    return nearest


def find_nearest(input_path,
                 output_path,
                 patch_size=(256, 256)):
    """Build the 3x3 patch-neighbourhood index table for one slide.

    Reads the 'coords' dataset (N x 2 patch coordinates) from ``input_path``.
    It writes 'coords' and 'nearest' to ``output_path``. 'nearest' is an
    N x 9 int64 array: column 0 is the patch's own index and columns 1-8 are
    the indices of the patches exactly one ``patch_size`` step away. A
    neighbour that does not exist is replaced by the patch's own index.

    Uses a dict lookup (O(N)) instead of scanning all coordinates for each of
    the 8 neighbours (O(N^2)).
    """
    # Previously printed a module-level `name` that is not a parameter
    # (NameError when imported, stale value under joblib).
    print("Loading:", os.path.basename(input_path))
    with h5py.File(input_path, 'r') as h5:
        coords = np.array(h5['coords'])

    nearest = _neighbour_indices(coords, patch_size)

    with h5py.File(output_path, 'w') as h5:
        h5['coords'] = coords
        h5['nearest'] = nearest
    return
85 |
86 |
if __name__ == "__main__":
    # NOTE: adjust the paths and patch_size below before running.
    orgin_path = '/data2/clh/NSCLC/LUAD/20X/patches/'
    save_path = '/data2/clh/NSCLC/coords/'
    patch_size = (256, 256)
    start = 0  # skip the first `start` directory entries (resume support)
    os.makedirs(save_path, exist_ok=True)
    name_list = []
    for step, name in enumerate(os.listdir(orgin_path)):
        if step >= start and name.endswith('h5'):
            name_list.append(name)
    # Pass patch_size explicitly. Before, the setting above was ignored and
    # find_nearest always used its own default.
    Parallel(n_jobs=32)(
        delayed(find_nearest)(os.path.join(orgin_path, slide),
                              os.path.join(save_path, slide),
                              patch_size=patch_size)
        for slide in name_list)
--------------------------------------------------------------------------------
/models/TransMIL.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from nystrom_attention import NystromAttention
6 |
7 |
class TransLayer(nn.Module):
    """Pre-norm Nystrom self-attention block with a residual connection.

    Uses 8 heads of size dim // 8, dim // 2 landmarks, 6 pseudo-inverse
    iterations, the value residual, and dropout 0.1.
    """

    def __init__(self, norm_layer=nn.LayerNorm, dim=512):
        super().__init__()
        self.norm = norm_layer(dim)
        self.attn = NystromAttention(
            dim=dim,
            dim_head=dim // 8,
            heads=8,
            num_landmarks=dim // 2,
            pinv_iterations=6,
            residual=True,
            dropout=0.1,
        )

    def forward(self, x):
        attended = self.attn(self.norm(x))
        return x + attended
28 |
29 |
class PPEG(nn.Module):
    """Pyramid position-encoding generator.

    Reshapes the patch tokens into an H x W grid. It adds depthwise 7x7, 5x5
    and 3x3 convolutions of the grid to the grid itself, then re-attaches the
    untouched class token in front.
    """
    def __init__(self, dim=512):
        super(PPEG, self).__init__()
        # Depthwise convs with 'same' padding (k // 2).
        self.proj = nn.Conv2d(dim, dim, 7, 1, 7 // 2, groups=dim)
        self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5 // 2, groups=dim)
        self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3 // 2, groups=dim)

    def forward(self, x, H, W):
        batch, _, channels = x.shape
        cls_token = x[:, :1]
        grid = x[:, 1:].transpose(1, 2).view(batch, channels, H, W)
        mixed = self.proj(grid) + grid + self.proj1(grid) + self.proj2(grid)
        tokens = mixed.flatten(2).transpose(1, 2)
        return torch.cat((cls_token, tokens), dim=1)
45 |
46 |
class TransMIL(nn.Module):
    """TransMIL bag classifier: Nystrom transformer layers with a PPEG position encoding.

    Args:
        dim: input instance feature size (projected to 256 internally).
        n_classes: number of output classes.

    forward(h) expects an (N, dim) bag and adds the batch dimension itself.
    It returns (logits [1 x n_classes], Y_prob, Y_hat [argmax, shape (1,)],
    h0 [1 x (grid + 1) x 256 token states after the second layer, before the
    final norm]).
    """
    def __init__(self, dim=512, n_classes=2):
        super(TransMIL, self).__init__()
        self.pos_layer = PPEG(dim=256)
        self._fc1 = nn.Sequential(nn.Linear(dim, 256), nn.ReLU())
        self.cls_token = nn.Parameter(torch.randn(1, 1, 256))
        self.n_classes = n_classes
        self.layer1 = TransLayer(dim=256)
        self.layer2 = TransLayer(dim=256)
        self.norm = nn.LayerNorm(256)
        self._fc2 = nn.Linear(256, self.n_classes)

    def forward(self, h):
        h = h.unsqueeze(0)  # [1, N, dim]
        h = self._fc1(h)  # [B, N, 256]

        # ---->pad: repeat leading tokens so the sequence fills a square grid
        H = h.shape[1]
        _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
        add_length = _H * _W - H
        h = torch.cat([h, h[:, :add_length, :]], dim=1)  # [B, _H*_W, 256]

        # ---->cls_token
        B = h.shape[0]
        # The parameter already lives on the model's device. The former
        # hard-coded `.cuda()` broke CPU use and non-default GPUs.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        h = torch.cat((cls_tokens, h), dim=1)

        # ---->Translayer x1
        h = self.layer1(h)  # [B, N+1, 256]

        # ---->PPEG
        h = self.pos_layer(h, _H, _W)  # [B, N+1, 256]

        # ---->Translayer x2
        h0 = self.layer2(h)  # [B, N+1, 256]

        # ---->cls_token
        h = self.norm(h0)[:, 0]

        # ---->predict
        logits = self._fc2(h)  # [B, n_classes]
        Y_hat = torch.argmax(logits, dim=1)
        Y_prob = F.softmax(logits, dim = 1)
        return logits, Y_prob, Y_hat, h0
93 |
if __name__ == "__main__":
    # Smoke test. forward() adds the batch dimension itself, so feed an
    # (N, dim) bag; a (1, N, dim) tensor made torch.cat fail inside forward().
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    data = torch.randn((6000, 512)).to(device)
    model = TransMIL(n_classes=2).to(device)
    print(model.eval())
    results_dict = model(data)
    print(results_dict)
100 |
--------------------------------------------------------------------------------
/models/DSMIL.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 |
class FCLayer(nn.Module):
    """Instance-level linear classifier that also passes its input through.

    forward(feats) returns ``(feats, logits)`` with logits = Linear(in_size, out_size)(feats).
    """
    def __init__(self, in_size, out_size=1):
        super(FCLayer, self).__init__()
        self.fc = nn.Sequential(nn.Linear(in_size, out_size))

    def forward(self, feats):
        return feats, self.fc(feats)
13 |
class IClassifier(nn.Module):
    """Feature extractor followed by a linear instance classifier.

    Args:
        feature_extractor: module whose output is flattened to (N, feature_size).
        feature_size: flattened feature size K.
        output_class: number of classes C.

    forward(x) returns ``(feats [N x K], logits [N x C])``.
    """
    def __init__(self, feature_extractor, feature_size, output_class):
        super(IClassifier, self).__init__()

        self.feature_extractor = feature_extractor
        self.fc = nn.Linear(feature_size, output_class)

    def forward(self, x):
        feats = self.feature_extractor(x)
        # Flatten once and reuse it. reshape == view for contiguous outputs
        # and also accepts non-contiguous ones.
        flat = feats.reshape(feats.shape[0], -1)  # N x K
        c = self.fc(flat)  # N x C
        return flat, c
27 |
class BClassifier(nn.Module):
    """DSMIL bag classifier (critical-instance attention).

    For each class, the instance with the highest instance score is the
    "critical" instance. Every instance's query is dot-multiplied with each
    critical query, and the result is scaled by sqrt(query size) and
    softmax-normalised over instances. It then pools values into one bag
    vector per class, and a Conv1d spanning the full vector gives the bag
    logits.

    forward(feats [N x K], c [N x C]) returns (C [1 x C logits],
    A_raw [N x C unscaled attention], B [1 x C x V bag vectors]).
    """
    def __init__(self, input_size, output_class, dropout_v=0.25, nonlinear=False, passing_v=False): # K, L, N
        super(BClassifier, self).__init__()
        self.q = (nn.Sequential(nn.Linear(input_size, 128), nn.ReLU(), nn.Linear(128, 128), nn.Tanh())
                  if nonlinear else nn.Linear(input_size, 128))
        self.v = (nn.Sequential(nn.Dropout(dropout_v), nn.Linear(input_size, input_size), nn.ReLU())
                  if passing_v else nn.Identity())

        # 1D convolution over the whole bag vector; works for any class count (binary included).
        self.fcc = nn.Conv1d(output_class, output_class, kernel_size=input_size)

    def forward(self, feats, c): # N x K, N x C
        n_inst = feats.shape[0]
        values = self.v(feats)  # N x V
        queries = self.q(feats).view(n_inst, -1)  # N x Q

        # Row 0 of the descending sort holds, per class, the top-scoring instance.
        _, order = torch.sort(c, 0, descending=True)
        critical = torch.index_select(feats, dim=0, index=order[0, :])  # C x K
        q_max = self.q(critical)  # C x Q
        A_raw = torch.mm(queries, q_max.transpose(0, 1))  # N x C
        scale = torch.sqrt(torch.tensor(queries.shape[1], dtype=torch.float32, device=feats.device))
        A = F.softmax(A_raw / scale, 0)  # normalised over instances
        B = torch.mm(A.transpose(0, 1), values)  # C x V

        B = B.view(1, B.shape[0], B.shape[1])  # 1 x C x V
        C = self.fcc(B).view(1, -1)  # 1 x C
        return C, A_raw, B
64 |
class MILNet(nn.Module):
    """DSMIL network: residual adapter, then instance classifier, then bag classifier.

    forward(x) returns (instance logits, bag logits, attention, bag vectors).
    """
    def __init__(self, feature_dim, n_classes):
        super(MILNet, self).__init__()
        hidden = feature_dim // 2
        self.adapter = nn.Sequential(
            nn.Linear(feature_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, feature_dim),
        )
        self.i_classifier = FCLayer(in_size=feature_dim, out_size=n_classes)
        self.b_classifier = BClassifier(input_size=feature_dim, output_class=n_classes)

    def forward(self, x):
        refined = self.adapter(x) + x
        feats, instance_logits = self.i_classifier(refined)
        bag_logits, attention, bag_vectors = self.b_classifier(feats, instance_logits)
        return instance_logits, bag_logits, attention, bag_vectors
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/models/MIL.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class MIL_MeanPooling(nn.Module):
    """Bag classifier that mean-pools projected instance features.

    Every instance vector is projected to ``embed_dim // 2`` dimensions
    (Linear -> ReLU -> Dropout(0.25)). The projections are averaged over the
    instance axis, and a linear head maps the pooled vector to
    ``n_classes`` logits.

    Args:
        n_classes: number of output classes. ``forward`` reads probability
            column 1, so at least 2 classes are required.
        top_k: number of rows picked from the single-row pooled logits; only
            1 works, because the picked indices are reshaped to one element.
        embed_dim: size of each input instance feature vector.
    """

    def __init__(self,
                 n_classes = 2,
                 top_k=1,
                 embed_dim=512):
        super().__init__()
        hidden_dim = embed_dim // 2
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.25),
        )
        self.classifier = nn.Linear(hidden_dim, n_classes)
        self.top_k = top_k

    def forward(self, h, return_features=False):
        """Classify one bag.

        Args:
            h: instance features; the first axis is averaged over, so it is
               presumably (num_instances, embed_dim) -- confirm with callers.
            return_features: when True, ``results_dict['features']`` holds the
                pooled (1, embed_dim // 2) bag representation.

        Returns:
            (top_instance, Y_prob, Y_hat, y_probs, results_dict):
            bag logits (1, n_classes), their softmax, the argmax class (1, 1),
            the bag softmax again (1, n_classes), and the optional features.
        """
        pooled = self.fc(h).mean(dim=0, keepdim=True)
        logits = self.classifier(pooled)
        y_probs = F.softmax(logits, dim=1)

        # The pooled bag is a single row, so this selection always yields row 0;
        # it is kept so the output layout matches the other MIL heads.
        _, best_row = torch.topk(y_probs[:, 1], self.top_k, dim=0)
        best_row = best_row.view(1,)
        top_instance = logits.index_select(0, best_row)
        Y_hat = top_instance.topk(1, dim=1)[1]
        Y_prob = F.softmax(top_instance, dim=1)

        results_dict = {}
        if return_features:
            results_dict['features'] = pooled.index_select(0, best_row)
        return top_instance, Y_prob, Y_hat, y_probs, results_dict
32 |
class MIL_MaxPooling(nn.Module):
    """Instance-level classifier whose bag prediction is the single most
    confident (instance, class) pair.

    Each instance is projected (Linear -> ReLU -> Dropout(0.25)) and
    classified. The bag takes the logits of the instance holding the largest
    softmax probability over all instances and classes, and predicts that
    class.

    Args:
        n_classes: number of output classes.
        top_k: must be 1 (enforced with an assert).
        embed_dim: size of each input instance feature vector.
    """

    def __init__(self,
                 n_classes = 2,
                 top_k=1,
                 embed_dim=512):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.25),
        )
        self.classifiers = nn.Linear(embed_dim // 2, n_classes)
        self.top_k = top_k
        self.n_classes = n_classes
        assert self.top_k == 1

    def forward(self, h, return_features=False):
        """Classify one bag.

        Args:
            h: instance features, presumably (num_instances, embed_dim).
            return_features: when True, ``results_dict['features']`` holds the
                projected features of the selected instance.

        Returns:
            (top_instance, Y_prob, Y_hat, y_probs, results_dict): the selected
            instance's logits (1, n_classes), its softmax (1, n_classes), the
            winning class index (shape (1,)), the per-instance softmax
            (num_instances, n_classes), and the optional features.
        """
        projected = self.fc(h)
        logits = self.classifiers(projected)
        y_probs = F.softmax(logits, dim=1)

        # Flatten (instance, class) into one axis, find the global maximum,
        # then recover its row (instance) and column (class).
        flat_idx = y_probs.view(1, -1).argmax(1)
        row = torch.div(flat_idx, self.n_classes, rounding_mode='floor')
        col = torch.remainder(flat_idx, self.n_classes)

        top_instance = logits[row]
        Y_hat = col
        Y_prob = y_probs[row]

        results_dict = {}
        if return_features:
            results_dict['features'] = projected.index_select(0, row)
        return top_instance, Y_prob, Y_hat, y_probs, results_dict
64 |
65 |
class MIL_RNN(nn.Module):
    """Bag classifier that runs a 3-layer bidirectional RNN over the instances.

    The bag's instances are fed to the RNN as one sequence (batch of 1). Each
    per-instance RNN output (2 directions * 128 = 256 features) is classified,
    and the instance with the highest class-1 probability determines the bag
    prediction.

    Args:
        n_classes: number of output classes. ``forward`` reads probability
            column 1, so at least 2 classes are required.
        embed_dim: size of each input instance feature vector.
    """

    def __init__(self, n_classes = 2, embed_dim = 512):
        super().__init__()
        self.dropout = nn.Dropout(0.25)
        self.rnn = nn.RNN(embed_dim, 128, 3, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(256, n_classes)
        self.top_k = 1

    def forward(self, h, return_features=False):
        """Classify one bag.

        Args:
            h: instance features, presumably (num_instances, embed_dim). A
               leading batch axis of size 1 is added before the RNN.
            return_features: when True, ``results_dict['features']`` holds the
                RNN output (1, 256) of the selected instance.

        Returns:
            (top_instance, Y_prob, Y_hat, y_probs, results_dict): the selected
            instance's logits (1, n_classes), their softmax, the argmax class
            (1, 1), the per-instance softmax (num_instances, n_classes), and
            the optional features.
        """
        num_directions = 2 if self.rnn.bidirectional else 1
        # Zero initial hidden state (num_layers * num_directions, batch=1, hidden),
        # created on h's device with h's dtype.
        h0 = h.new_zeros(self.rnn.num_layers * num_directions, 1, self.rnn.hidden_size)
        h = self.dropout(h)
        h, _ = self.rnn(h.unsqueeze(0), h0)
        # Drop the batch axis -> (num_instances, 256). This must be dim 0 on every
        # path: squeezing dim 1 leaves a 3-D tensor, so softmax would normalise
        # over instances and the top-instance selection below would fail.
        h = h.squeeze(0)
        logits = self.classifier(h)

        y_probs = F.softmax(logits, dim = 1)
        top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(1,)
        top_instance = torch.index_select(logits, dim=0, index=top_instance_idx)
        Y_hat = torch.topk(top_instance, 1, dim = 1)[1]
        Y_prob = F.softmax(top_instance, dim = 1)
        results_dict = {}

        if return_features:
            top_features = torch.index_select(h, dim=0, index=top_instance_idx)
            results_dict.update({'features': top_features})
        return top_instance, Y_prob, Y_hat, y_probs, results_dict
--------------------------------------------------------------------------------
/datasets/camelyon16_test.csv:
--------------------------------------------------------------------------------
1 | case_id,slide_id,label
2 | patient_1,test_001,tumor_tissue
3 | patient_2,test_002,tumor_tissue
4 | patient_3,test_003,normal_tissue
5 | patient_4,test_004,tumor_tissue
6 | patient_5,test_005,normal_tissue
7 | patient_6,test_006,normal_tissue
8 | patient_7,test_007,normal_tissue
9 | patient_8,test_008,tumor_tissue
10 | patient_9,test_009,normal_tissue
11 | patient_10,test_010,tumor_tissue
12 | patient_11,test_011,tumor_tissue
13 | patient_12,test_012,normal_tissue
14 | patient_13,test_013,tumor_tissue
15 | patient_14,test_014,normal_tissue
16 | patient_15,test_015,normal_tissue
17 | patient_16,test_016,tumor_tissue
18 | patient_17,test_017,normal_tissue
19 | patient_18,test_018,normal_tissue
20 | patient_19,test_019,normal_tissue
21 | patient_20,test_020,normal_tissue
22 | patient_21,test_021,tumor_tissue
23 | patient_22,test_022,normal_tissue
24 | patient_23,test_023,normal_tissue
25 | patient_24,test_024,normal_tissue
26 | patient_25,test_025,normal_tissue
27 | patient_26,test_026,tumor_tissue
28 | patient_27,test_027,tumor_tissue
29 | patient_28,test_028,normal_tissue
30 | patient_29,test_029,tumor_tissue
31 | patient_30,test_030,tumor_tissue
32 | patient_31,test_031,normal_tissue
33 | patient_32,test_032,normal_tissue
34 | patient_33,test_033,tumor_tissue
35 | patient_34,test_034,normal_tissue
36 | patient_35,test_035,normal_tissue
37 | patient_36,test_036,normal_tissue
38 | patient_37,test_037,normal_tissue
39 | patient_38,test_038,tumor_tissue
40 | patient_39,test_039,normal_tissue
41 | patient_40,test_040,tumor_tissue
42 | patient_41,test_041,normal_tissue
43 | patient_42,test_042,normal_tissue
44 | patient_43,test_043,normal_tissue
45 | patient_44,test_044,normal_tissue
46 | patient_45,test_045,normal_tissue
47 | patient_46,test_046,tumor_tissue
48 | patient_47,test_047,normal_tissue
49 | patient_48,test_048,tumor_tissue
50 | patient_49,test_050,normal_tissue
51 | patient_50,test_051,tumor_tissue
52 | patient_51,test_052,tumor_tissue
53 | patient_52,test_053,normal_tissue
54 | patient_53,test_054,normal_tissue
55 | patient_54,test_055,normal_tissue
56 | patient_55,test_056,normal_tissue
57 | patient_56,test_057,normal_tissue
58 | patient_57,test_058,normal_tissue
59 | patient_58,test_059,normal_tissue
60 | patient_59,test_060,normal_tissue
61 | patient_60,test_061,tumor_tissue
62 | patient_61,test_062,normal_tissue
63 | patient_62,test_063,normal_tissue
64 | patient_63,test_064,tumor_tissue
65 | patient_64,test_065,tumor_tissue
66 | patient_65,test_066,tumor_tissue
67 | patient_66,test_067,normal_tissue
68 | patient_67,test_068,tumor_tissue
69 | patient_68,test_069,tumor_tissue
70 | patient_69,test_070,normal_tissue
71 | patient_70,test_071,tumor_tissue
72 | patient_71,test_072,normal_tissue
73 | patient_72,test_073,tumor_tissue
74 | patient_73,test_074,tumor_tissue
75 | patient_74,test_075,tumor_tissue
76 | patient_75,test_076,normal_tissue
77 | patient_76,test_077,normal_tissue
78 | patient_77,test_078,normal_tissue
79 | patient_78,test_079,tumor_tissue
80 | patient_79,test_080,normal_tissue
81 | patient_80,test_081,normal_tissue
82 | patient_81,test_082,tumor_tissue
83 | patient_82,test_083,normal_tissue
84 | patient_83,test_084,tumor_tissue
85 | patient_84,test_085,normal_tissue
86 | patient_85,test_086,normal_tissue
87 | patient_86,test_087,normal_tissue
88 | patient_87,test_088,normal_tissue
89 | patient_88,test_089,normal_tissue
90 | patient_89,test_090,tumor_tissue
91 | patient_90,test_091,normal_tissue
92 | patient_91,test_092,tumor_tissue
93 | patient_92,test_093,normal_tissue
94 | patient_93,test_094,tumor_tissue
95 | patient_94,test_095,normal_tissue
96 | patient_95,test_096,normal_tissue
97 | patient_96,test_097,tumor_tissue
98 | patient_97,test_098,normal_tissue
99 | patient_98,test_099,tumor_tissue
100 | patient_99,test_100,normal_tissue
101 | patient_100,test_101,normal_tissue
102 | patient_101,test_102,tumor_tissue
103 | patient_102,test_103,normal_tissue
104 | patient_103,test_104,tumor_tissue
105 | patient_104,test_105,tumor_tissue
106 | patient_105,test_106,normal_tissue
107 | patient_106,test_107,normal_tissue
108 | patient_107,test_108,tumor_tissue
109 | patient_108,test_109,normal_tissue
110 | patient_109,test_110,tumor_tissue
111 | patient_110,test_111,normal_tissue
112 | patient_111,test_112,normal_tissue
113 | patient_112,test_113,tumor_tissue
114 | patient_113,test_114,tumor_tissue
115 | patient_114,test_115,normal_tissue
116 | patient_115,test_116,tumor_tissue
117 | patient_116,test_117,tumor_tissue
118 | patient_117,test_118,normal_tissue
119 | patient_118,test_119,normal_tissue
120 | patient_119,test_120,normal_tissue
121 | patient_120,test_121,tumor_tissue
122 | patient_121,test_122,tumor_tissue
123 | patient_122,test_123,normal_tissue
124 | patient_123,test_124,normal_tissue
125 | patient_124,test_125,normal_tissue
126 | patient_125,test_126,normal_tissue
127 | patient_126,test_127,normal_tissue
128 | patient_127,test_128,normal_tissue
129 | patient_128,test_129,normal_tissue
130 | patient_129,test_130,normal_tissue
131 |
--------------------------------------------------------------------------------
/tester_transmil.py:
--------------------------------------------------------------------------------
1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset
2 |
3 | import os
4 | import torch
5 | from torch import nn
6 | import torch.optim as optim
7 | import pdb
8 | import torch.nn.functional as F
9 |
10 | import pandas as pd
11 | import numpy as np
12 |
13 | from sklearn.preprocessing import label_binarize
14 | from sklearn.metrics import roc_auc_score, roc_curve
15 | from sklearn.metrics import auc as calc_auc
16 |
17 | from models.TransMIL import TransMIL
18 | from utils import *
19 |
20 |
class Accuracy_Logger(object):
    """Accumulate per-class prediction counts and hits.

    ``data[c]`` is a dict ``{"count": ..., "correct": ...}`` for ground-truth
    class ``c``.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.initialize()

    def initialize(self):
        """Reset the counters of every class to zero."""
        self.data = [{"count": 0, "correct": 0} for _ in range(self.n_classes)]

    def log(self, Y_hat, Y):
        """Record a single prediction ``Y_hat`` against its label ``Y``.

        Both are converted with ``int()``, so single-element tensors work.
        """
        pred, target = int(Y_hat), int(Y)
        bucket = self.data[target]
        bucket["count"] += 1
        bucket["correct"] += (pred == target)

    def log_batch(self, Y_hat, Y):
        """Record a batch of predictions against their labels."""
        preds = np.array(Y_hat).astype(int)
        targets = np.array(Y).astype(int)
        for cls in np.unique(targets):
            mask = targets == cls
            bucket = self.data[cls]
            bucket["count"] += mask.sum()
            bucket["correct"] += (preds[mask] == targets[mask]).sum()

    def get_summary(self, c):
        """Return ``(accuracy, correct, count)`` for class ``c``.

        The accuracy is None when class ``c`` has not been seen.
        """
        count, correct = self.data[c]["count"], self.data[c]["correct"]
        acc = None if count == 0 else float(correct) / count
        return acc, correct, count
55 |
def _compute_auc(all_labels, all_probs, n_classes):
    """Return the ROC AUC of bag-level predictions, or -1 when it is undefined.

    Args:
        all_labels: 1-D array of ground-truth labels, one per bag.
        all_probs: array (num_bags, n_classes) of predicted class probabilities.
        n_classes: number of classes.

    Returns:
        -1 if ``all_labels`` holds a single class. For binary tasks, the AUC of
        the class-1 probability column. Otherwise, the micro-averaged AUC over
        the one-vs-rest binarized labels.
    """
    if len(np.unique(all_labels)) == 1:
        return -1
    if n_classes == 2:
        return roc_auc_score(all_labels, all_probs[:, 1])
    binary_labels = label_binarize(all_labels, classes=list(range(n_classes)))
    fpr, tpr, _ = roc_curve(binary_labels.ravel(), all_probs.ravel())
    return calc_auc(fpr, tpr)


def summary(model, loader, n_classes):
    """Run ``model`` over every bag in ``loader`` and collect evaluation results.

    The i-th batch is matched to the i-th entry of
    ``loader.dataset.slide_data['slide_id']``. The loader is therefore assumed
    to yield one bag per batch, in dataset order.

    Returns:
        patient_results: dict slide_id -> {'slide_id', 'prob', 'label'}.
        test_error: per-bag ``calculate_error`` values averaged over bags.
        auc_score: result of ``_compute_auc`` (-1 if only one class present).
        df: per-slide DataFrame with columns slide_id, Y, Y_hat, p_0..p_{n-1}.
        acc_logger: per-class ``Accuracy_Logger``.
    """
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    model.eval()
    test_error = 0.

    num_bags = len(loader)
    all_probs = np.zeros((num_bags, n_classes))
    all_labels = np.zeros(num_bags)
    all_preds = np.zeros(num_bags)

    slide_ids = loader.dataset.slide_data['slide_id']
    patient_results = {}
    for batch_idx, (data, label) in enumerate(loader):
        # NOTE: `device` and `calculate_error` are not defined in this file;
        # presumably they come from `from utils import *`.
        data, label = data.to(device), label.to(device)
        slide_id = slide_ids.iloc[batch_idx]
        with torch.inference_mode():
            # Output is (logits, Y_prob, Y_hat, h0); only Y_prob and Y_hat are used.
            _, Y_prob, Y_hat, _ = model(data)

        acc_logger.log(Y_hat, label)

        probs = Y_prob.cpu().numpy()
        all_probs[batch_idx] = probs
        all_labels[batch_idx] = label.item()
        all_preds[batch_idx] = Y_hat.item()

        patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}})

        test_error += calculate_error(Y_hat, label)

        del data
    test_error /= num_bags

    auc_score = _compute_auc(all_labels, all_probs, n_classes)

    results_dict = {'slide_id': slide_ids, 'Y': all_labels, 'Y_hat': all_preds}
    for c in range(n_classes):
        results_dict['p_{}'.format(c)] = all_probs[:, c]
    df = pd.DataFrame(results_dict)
    return patient_results, test_error, auc_score, df, acc_logger
114 |
if __name__ == "__main__":
    # Evaluate each per-fold checkpoint on the whole dataset listed in
    # `csv_path`. Judging by the paths, these are CAMELYON16-trained TransMIL
    # weights applied to unseen CAMELYON17 slides. All paths are machine-specific.
    save_dir = './results/camelyon16to17unseen_transmil_simclr_100/'
    csv_path = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/camelyon17_unseen.csv'
    data_dir = '/data2/clh/camelyon17/resnet18_simclr/'
    weight_dir = '/data1/ceiling/workspace/MIL/AttriMIL/save_weights/camelyon16_transmil_simclr_100/'
    split_path = '/data1/ceiling/workspace/AttriMIL_v2/splits/unitopatho/'
    dataset = Generic_MIL_Dataset(csv_path = csv_path,
                                  data_dir = data_dir,
                                  shuffle = False,
                                  print_info = True,
                                  label_dict = {'normal_tissue':0, 'tumor_tissue':1},
                                  patient_strat=False,
                                  ignore=[])
    os.makedirs(save_dir, exist_ok=True)
    model = TransMIL(dim=512, n_classes=2).cuda()
    folds = [0, 1, 2, 3, 4]
    ckpt_paths = [os.path.join(weight_dir, 's_{}_checkpoint.pt'.format(fold)) for fold in folds]
    # One split file per fold. It is only needed by the commented-out
    # split-based evaluation below and is kept separate from `csv_path`.
    split_csvs = [split_path + 'splits_{}.csv'.format(fold) for fold in folds]
    all_results = []
    all_auc = []
    all_acc = []
    for fold, ckpt_path, split_csv in zip(folds, ckpt_paths, split_csvs):
        # To evaluate only each fold's test split instead of the whole dataset:
        # train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False, csv_path=split_csv)
        # loader = get_split_loader(test_dataset, testing = False)
        loader = get_simple_loader(dataset)
        model.load_state_dict(torch.load(ckpt_path))
        patient_results, test_error, auc, df, acc_logger = summary(model, loader, n_classes=2)
        # Keep this fold's per-slide results.
        all_results.append(patient_results)
        all_auc.append(auc)
        all_acc.append(1 - test_error)
        df.to_csv(os.path.join(save_dir, 'fold_{}.csv'.format(fold)), index=False)

    final_df = pd.DataFrame({'folds': folds, 'test_auc': all_auc, 'test_acc': all_acc})
    save_name = 'summary.csv'
    final_df.to_csv(os.path.join(save_dir, save_name))
150 |
--------------------------------------------------------------------------------
/tester_mil.py:
--------------------------------------------------------------------------------
1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset
2 |
3 | import os
4 | import torch
5 | from torch import nn
6 | import torch.optim as optim
7 | import pdb
8 | import torch.nn.functional as F
9 |
10 | import pandas as pd
11 | import numpy as np
12 |
13 | from sklearn.preprocessing import label_binarize
14 | from sklearn.metrics import roc_auc_score, roc_curve
15 | from sklearn.metrics import auc as calc_auc
16 |
17 | from models.MIL import *
18 | from utils import *
19 |
20 |
class Accuracy_Logger(object):
    """Accumulate per-class prediction counts and hits.

    ``data[c]`` is a dict ``{"count": ..., "correct": ...}`` for ground-truth
    class ``c``.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.initialize()

    def initialize(self):
        """Reset the counters of every class to zero."""
        self.data = [{"count": 0, "correct": 0} for _ in range(self.n_classes)]

    def log(self, Y_hat, Y):
        """Record a single prediction ``Y_hat`` against its label ``Y``.

        Both are converted with ``int()``, so single-element tensors work.
        """
        pred, target = int(Y_hat), int(Y)
        bucket = self.data[target]
        bucket["count"] += 1
        bucket["correct"] += (pred == target)

    def log_batch(self, Y_hat, Y):
        """Record a batch of predictions against their labels."""
        preds = np.array(Y_hat).astype(int)
        targets = np.array(Y).astype(int)
        for cls in np.unique(targets):
            mask = targets == cls
            bucket = self.data[cls]
            bucket["count"] += mask.sum()
            bucket["correct"] += (preds[mask] == targets[mask]).sum()

    def get_summary(self, c):
        """Return ``(accuracy, correct, count)`` for class ``c``.

        The accuracy is None when class ``c`` has not been seen.
        """
        count, correct = self.data[c]["count"], self.data[c]["correct"]
        acc = None if count == 0 else float(correct) / count
        return acc, correct, count
55 |
def _compute_auc(all_labels, all_probs, n_classes):
    """Return the ROC AUC of bag-level predictions, or -1 when it is undefined.

    Args:
        all_labels: 1-D array of ground-truth labels, one per bag.
        all_probs: array (num_bags, n_classes) of predicted class probabilities.
        n_classes: number of classes.

    Returns:
        -1 if ``all_labels`` holds a single class. For binary tasks, the AUC of
        the class-1 probability column. Otherwise, the micro-averaged AUC over
        the one-vs-rest binarized labels.
    """
    if len(np.unique(all_labels)) == 1:
        return -1
    if n_classes == 2:
        return roc_auc_score(all_labels, all_probs[:, 1])
    binary_labels = label_binarize(all_labels, classes=list(range(n_classes)))
    fpr, tpr, _ = roc_curve(binary_labels.ravel(), all_probs.ravel())
    return calc_auc(fpr, tpr)


def summary(model, loader, n_classes):
    """Run ``model`` over every bag in ``loader`` and collect evaluation results.

    The i-th batch is matched to the i-th entry of
    ``loader.dataset.slide_data['slide_id']``. The loader is therefore assumed
    to yield one bag per batch, in dataset order.

    Returns:
        patient_results: dict slide_id -> {'slide_id', 'prob', 'label'}.
        test_error: per-bag ``calculate_error`` values averaged over bags.
        auc_score: result of ``_compute_auc`` (-1 if only one class present).
        df: per-slide DataFrame with columns slide_id, Y, Y_hat, p_0..p_{n-1}.
        acc_logger: per-class ``Accuracy_Logger``.
    """
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    model.eval()
    test_error = 0.

    num_bags = len(loader)
    all_probs = np.zeros((num_bags, n_classes))
    all_labels = np.zeros(num_bags)
    all_preds = np.zeros(num_bags)

    slide_ids = loader.dataset.slide_data['slide_id']
    patient_results = {}
    for batch_idx, (data, label) in enumerate(loader):
        # NOTE: `device` and `calculate_error` are not defined in this file;
        # presumably they come from `from utils import *`.
        data, label = data.to(device), label.to(device)
        slide_id = slide_ids.iloc[batch_idx]
        with torch.inference_mode():
            # models.MIL heads return (top_instance, Y_prob, Y_hat, y_probs, results_dict).
            _, Y_prob, Y_hat, _, _ = model(data)
        acc_logger.log(Y_hat, label)

        probs = Y_prob.cpu().numpy()
        all_probs[batch_idx] = probs
        all_labels[batch_idx] = label.item()
        all_preds[batch_idx] = Y_hat.item()

        patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}})

        test_error += calculate_error(Y_hat, label)

        del data
    test_error /= num_bags

    auc_score = _compute_auc(all_labels, all_probs, n_classes)

    results_dict = {'slide_id': slide_ids, 'Y': all_labels, 'Y_hat': all_preds}
    for c in range(n_classes):
        results_dict['p_{}'.format(c)] = all_probs[:, c]
    df = pd.DataFrame(results_dict)
    return patient_results, test_error, auc_score, df, acc_logger
113 |
if __name__ == "__main__":
    # Evaluate each per-fold checkpoint on the whole dataset listed in
    # `csv_path`. Judging by the paths, these are CAMELYON16-trained MIL-RNN
    # weights applied to unseen CAMELYON17 slides. All paths are machine-specific.
    save_dir = './results/camelyon16to17unseen_rnn_simclr_100/'
    csv_path = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/camelyon17_unseen.csv'
    data_dir = '/data2/clh/camelyon17/resnet18_simclr/'
    weight_dir = '/data1/ceiling/workspace/MIL/AttriMIL/save_weights/camelyon16_rnn_simclr_100/'
    split_path = '/data1/ceiling/workspace/AttriMIL_v2/splits/unitopatho/'
    # Label map used for another dataset (UniToPatho):
    # {'NORM':0, 'HP':1, 'TA.HG':2,'TA.LG':3, 'TVA.HG':4, 'TVA.LG':5}
    dataset = Generic_MIL_Dataset(csv_path = csv_path,
                                  data_dir = data_dir,
                                  shuffle = False,
                                  print_info = True,
                                  label_dict = {'normal_tissue':0, 'tumor_tissue':1},
                                  patient_strat=False,
                                  ignore=[])
    os.makedirs(save_dir, exist_ok=True)
    model = MIL_RNN(embed_dim=512, n_classes=2).cuda()
    folds = [0, 1, 2, 3, 4]
    ckpt_paths = [os.path.join(weight_dir, 's_{}_checkpoint.pt'.format(fold)) for fold in folds]
    # One split file per fold. It is only needed by the commented-out
    # split-based evaluation below and is kept separate from `csv_path`.
    split_csvs = [split_path + 'splits_{}.csv'.format(fold) for fold in folds]
    all_results = []
    all_auc = []
    all_acc = []
    for fold, ckpt_path, split_csv in zip(folds, ckpt_paths, split_csvs):
        # To evaluate only each fold's test split instead of the whole dataset:
        # train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False, csv_path=split_csv)
        # loader = get_split_loader(test_dataset, testing = False)
        loader = get_simple_loader(dataset)
        model.load_state_dict(torch.load(ckpt_path))
        patient_results, test_error, auc, df, acc_logger = summary(model, loader, n_classes=2)
        # Keep this fold's per-slide results.
        all_results.append(patient_results)
        all_auc.append(auc)
        all_acc.append(1 - test_error)
        df.to_csv(os.path.join(save_dir, 'fold_{}.csv'.format(fold)), index=False)

    final_df = pd.DataFrame({'folds': folds, 'test_auc': all_auc, 'test_acc': all_acc})
    save_name = 'summary.csv'
    final_df.to_csv(os.path.join(save_dir, save_name))
150 |
--------------------------------------------------------------------------------
/tester_dsmil.py:
--------------------------------------------------------------------------------
1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset
2 |
3 | import os
4 | import torch
5 | from torch import nn
6 | import torch.optim as optim
7 | import pdb
8 | import torch.nn.functional as F
9 |
10 | import pandas as pd
11 | import numpy as np
12 |
13 | from sklearn.preprocessing import label_binarize
14 | from sklearn.metrics import roc_auc_score, roc_curve
15 | from sklearn.metrics import auc as calc_auc
16 |
17 | from models.DSMIL import MILNet
18 | from utils import *
19 |
20 |
class Accuracy_Logger(object):
    """Accumulate per-class prediction counts and hits.

    ``data[c]`` is a dict ``{"count": ..., "correct": ...}`` for ground-truth
    class ``c``.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.initialize()

    def initialize(self):
        """Reset the counters of every class to zero."""
        self.data = [{"count": 0, "correct": 0} for _ in range(self.n_classes)]

    def log(self, Y_hat, Y):
        """Record a single prediction ``Y_hat`` against its label ``Y``.

        Both are converted with ``int()``, so single-element tensors work.
        """
        pred, target = int(Y_hat), int(Y)
        bucket = self.data[target]
        bucket["count"] += 1
        bucket["correct"] += (pred == target)

    def log_batch(self, Y_hat, Y):
        """Record a batch of predictions against their labels."""
        preds = np.array(Y_hat).astype(int)
        targets = np.array(Y).astype(int)
        for cls in np.unique(targets):
            mask = targets == cls
            bucket = self.data[cls]
            bucket["count"] += mask.sum()
            bucket["correct"] += (preds[mask] == targets[mask]).sum()

    def get_summary(self, c):
        """Return ``(accuracy, correct, count)`` for class ``c``.

        The accuracy is None when class ``c`` has not been seen.
        """
        count, correct = self.data[c]["count"], self.data[c]["correct"]
        acc = None if count == 0 else float(correct) / count
        return acc, correct, count
55 |
def _compute_auc(all_labels, all_probs, n_classes):
    """Return the ROC AUC of bag-level predictions, or -1 when it is undefined.

    Args:
        all_labels: 1-D array of ground-truth labels, one per bag.
        all_probs: array (num_bags, n_classes) of predicted class probabilities.
        n_classes: number of classes.

    Returns:
        -1 if ``all_labels`` holds a single class. For binary tasks, the AUC of
        the class-1 probability column. Otherwise, the micro-averaged AUC over
        the one-vs-rest binarized labels.
    """
    if len(np.unique(all_labels)) == 1:
        return -1
    if n_classes == 2:
        return roc_auc_score(all_labels, all_probs[:, 1])
    binary_labels = label_binarize(all_labels, classes=list(range(n_classes)))
    fpr, tpr, _ = roc_curve(binary_labels.ravel(), all_probs.ravel())
    return calc_auc(fpr, tpr)


def summary(model, loader, n_classes):
    """Run a DSMIL model over every bag in ``loader`` and collect evaluation results.

    The bag prediction is the argmax of the model's bag-level output, and the
    reported probabilities are its softmax. The i-th batch is matched to the
    i-th entry of ``loader.dataset.slide_data['slide_id']``, so the loader is
    assumed to yield one bag per batch, in dataset order.

    Returns:
        patient_results: dict slide_id -> {'slide_id', 'prob', 'label'}.
        test_error: per-bag ``calculate_error`` values averaged over bags.
        auc_score: result of ``_compute_auc`` (-1 if only one class present).
        df: per-slide DataFrame with columns slide_id, Y, Y_hat, p_0..p_{n-1}.
        acc_logger: per-class ``Accuracy_Logger``.
    """
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    model.eval()
    test_error = 0.

    num_bags = len(loader)
    all_probs = np.zeros((num_bags, n_classes))
    all_labels = np.zeros(num_bags)
    all_preds = np.zeros(num_bags)

    slide_ids = loader.dataset.slide_data['slide_id']
    patient_results = {}
    for batch_idx, (data, label) in enumerate(loader):
        # NOTE: `device` and `calculate_error` are not defined in this file;
        # presumably they come from `from utils import *`.
        data, label = data.to(device), label.to(device)
        slide_id = slide_ids.iloc[batch_idx]
        with torch.inference_mode():
            # Output is (ins_prediction, bag_prediction, _, _); only the
            # bag-level prediction is used here.
            _, bag_prediction, _, _ = model(data)
            Y_hat = torch.topk(bag_prediction.view(1, -1), 1, dim = 1)[1]
            Y_prob = F.softmax(bag_prediction, dim=-1)
        acc_logger.log(Y_hat, label)

        probs = Y_prob.cpu().numpy()
        all_probs[batch_idx] = probs
        all_labels[batch_idx] = label.item()
        all_preds[batch_idx] = Y_hat.item()

        patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}})

        test_error += calculate_error(Y_hat, label)

        del data
    test_error /= num_bags

    auc_score = _compute_auc(all_labels, all_probs, n_classes)

    results_dict = {'slide_id': slide_ids, 'Y': all_labels, 'Y_hat': all_preds}
    for c in range(n_classes):
        results_dict['p_{}'.format(c)] = all_probs[:, c]
    df = pd.DataFrame(results_dict)
    return patient_results, test_error, auc_score, df, acc_logger
116 |
if __name__ == "__main__":
    # Evaluate each per-fold checkpoint on the whole dataset listed in
    # `csv_path`. Judging by the paths, these are CAMELYON16-trained DSMIL
    # weights applied to unseen CAMELYON17 slides. All paths are machine-specific.
    save_dir = './results/camelyon16to17unseen_dsmil_simclr_100/'
    csv_path = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/camelyon17_unseen.csv'
    data_dir = '/data2/clh/camelyon17/resnet18_simclr/'
    weight_dir = '/data1/ceiling/workspace/MIL/AttriMIL/save_weights/camelyon16_dsmil_simclr/'
    split_path = '/data1/ceiling/workspace/AttriMIL_v2/splits/unitopatho/'
    # Label maps used for other datasets:
    # {'NORM':0, 'HP':1, 'TA.HG':2,'TA.LG':3, 'TVA.HG':4, 'TVA.LG':5}
    # {'LUAD':0, 'LUSC':1}
    # {'normal_tissue':0, 'tumor_tissue':1}
    dataset = Generic_MIL_Dataset(csv_path = csv_path,
                                  data_dir = data_dir,
                                  shuffle = False,
                                  print_info = True,
                                  label_dict = {'normal_tissue':0, 'tumor_tissue':1},
                                  patient_strat=False,
                                  ignore=[])
    os.makedirs(save_dir, exist_ok=True)
    model = MILNet(feature_dim=512, n_classes=2).cuda()
    folds = [0, 1, 2, 3, 4]
    ckpt_paths = [os.path.join(weight_dir, 's_{}_checkpoint.pt'.format(fold)) for fold in folds]
    # One split file per fold. It is only needed by the commented-out
    # split-based evaluation below and is kept separate from `csv_path`.
    split_csvs = [split_path + 'splits_{}.csv'.format(fold) for fold in folds]
    all_results = []
    all_auc = []
    all_acc = []
    for fold, ckpt_path, split_csv in zip(folds, ckpt_paths, split_csvs):
        # To evaluate only each fold's test split instead of the whole dataset:
        # train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False, csv_path=split_csv)
        # loader = get_split_loader(test_dataset, testing = False)
        loader = get_simple_loader(dataset)
        model.load_state_dict(torch.load(ckpt_path))
        patient_results, test_error, auc, df, acc_logger = summary(model, loader, n_classes=2)
        # Keep this fold's per-slide results.
        all_results.append(patient_results)
        all_auc.append(auc)
        all_acc.append(1 - test_error)
        df.to_csv(os.path.join(save_dir, 'fold_{}.csv'.format(fold)), index=False)

    final_df = pd.DataFrame({'folds': folds, 'test_auc': all_auc, 'test_acc': all_acc})
    save_name = 'summary.csv'
    final_df.to_csv(os.path.join(save_dir, save_name))
155 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AttriMIL: Revisiting attention-based multiple instance learning for whole-slide pathological image classification from a perspective of instance attributes
2 | The official implementation of AttriMIL (published at _Medical Image Analysis 2025_).
3 |
4 | ## 1. Introduction
5 | ### 1.1 Background
6 | WSI classification typically requires the MIL framework to perform two key tasks: bag classification and instance discrimination, which correspond to clinical diagnosis and the localization of disease-positive regions, respectively. Among various MIL architectures, attention-based MIL frameworks address both tasks simultaneously under weak supervision and thus dominate pathology image analysis. However, attention-based MIL frameworks face two challenges:
7 |
8 | (i) The inaccurate measurement of pathological attributes based on attention, which may mislead diagnosis.
9 |
10 | (ii) The neglect of intra-slide and inter-slide interactions, which are essential for obtaining robust semantic representations of instances.
11 |
12 |
13 |
14 |
15 | Figure 1. Illustration of the workflow of attention-based MIL frameworks and the attribute scoring mechanism in AttriMIL.
16 |
17 |
18 |
19 | To overcome these issues, we propose a novel framework named attribute-aware multiple instance learning (AttriMIL) tailored for pathological image classification.
20 |
21 | (i) To identify the pathological attributes of instances, AttriMIL employs a multi-branch attribute scoring mechanism, where each branch integrates attention pooling with the classification head, yielding a precise estimate of each instance's contribution to the bag prediction.
22 |
23 | (ii) Considering the intrinsic correlations between image patches in WSIs, we introduce two constraints to enhance the MIL framework's sensitivity to instance attributes.
24 |
25 | (iii) Inspired by parameter-efficient fine-tuning techniques, we design a pathology adaptive learning strategy for efficient pathological feature extraction. This optimized backbone empowers AttriMIL to model instance correlations across multiple feature levels.
26 |
27 | ### 1.2. Framework
28 | Figure 2 presents an overview of AttriMIL, which comprises three main components: (1) a pathology adaptive backbone for extracting optimized instance-level features, (2) multi-branch attribute scoring mechanism with attribute constraints, and (3) score aggregation and bag prediction. In this section, we first revisit multiple instance learning and attention-based frameworks, followed by a detailed description of AttriMIL.
29 |
30 |
31 |
32 |
33 | Figure 2. Overview of the proposed AttriMIL.
34 |
35 |
36 |
37 | ### 1.3 Performance
38 | AttriMIL achieves state-of-the-art performance on four benchmarks, demonstrating superior bag classification performance, generalization ability, and instance localization capability. Additionally, AttriMIL is capable of identifying bags with a small proportion of target regions.
39 |
40 |
41 |
42 |
43 | Figure 3. Quantitative comparison of the state-of-the-art WSI classification algorithms.
44 |
45 |
46 |
47 | ## 2. Quick Start
48 | ### 2.1 Installation
49 | AttriMIL is extended from [CLAM](https://github.com/mahmoodlab/CLAM). We assume that PyTorch and TorchVision are already installed; if not, please follow the [official instructions](https://pytorch.org/) to install them first.
50 | Install the dependencies with:
51 | ``` sh
52 | conda env create -f env.yml
53 | ```
54 | The code is developed and tested using PyTorch 1.10.0. Other versions of PyTorch are not fully tested.
55 |
56 | ### 2.2 Data preparation
57 | Data preparation follows CLAM, including tissue segmentation, patching, and feature extraction. Compared with the conventional pipeline, we additionally introduce a neighborhood generation step and use pathology-adaptive learning for instance-level feature extraction.
58 | ``` sh
59 | python create_3coords.py # generate neighbor indices
60 | python coords_to_feature.py # incorporate the neighbor indices into the feature (h5) files
61 | ```
62 | The final data follow this structure:
63 | ```bash
64 | FEATURES_DIRECTORY/
65 | ├── h5_files
66 | ├── slide_1.h5
67 | ├── slide_2.h5
68 | └── ...
69 | └── pt_files
70 | ├── slide_1.pt
71 | ├── slide_2.pt
72 | └── ...
73 | ```
74 | where each .h5 file contains an array of extracted features along with their patch coordinates and neighbor indices.
75 |
76 | ### 2.3 Pretrained Weights
77 | Pretrained weights are based on ResNet18 (ImageNet) and ResNet18 (SimCLR). We follow [DSMIL](https://github.com/binli123/dsmil-wsi/tree/master) to generate the SSL features. You can also use other pre-trained models.
78 |
79 | ### 2.4 Training and Testing
80 | Training your AttriMIL:
81 | ``` sh
82 | python trainer_attrimil_abmil.py
83 | ```
84 | Note that AttriMIL+DSMIL and AttriMIL+TransMIL will be released soon.
85 | Testing your AttriMIL:
86 | ``` sh
87 | python tester_attrimil_abmil.py
88 | ```
89 | The visual results show that AttriMIL directly captures disease-positive regions.
90 |
91 |
92 |
93 |
94 | Figure 4. Qualitative comparison of the state-of-the-art WSI classification algorithms.
95 |
96 |
97 |
98 | ## 3. Citation
99 | If you find this work or code helpful in your research, please cite:
100 |
101 | ```
102 | @article{cai2025attrimil,
103 | title={AttriMIL: Revisiting attention-based multiple instance learning for whole-slide pathological image classification from a perspective of instance attributes},
104 | author={Cai, Linghan and Huang, Shenjin and Zhang, Ye and Lu, Jinpeng and Zhang, Yongbing},
105 | journal={Medical Image Analysis},
106 | pages={103631},
107 | year={2025},
108 | publisher={Elsevier}
109 | }
110 | ```
111 |
112 | ## 4. Contributing
113 | We thank the following projects, on which our work builds:
114 | - CLAM: [https://github.com/mahmoodlab/CLAM](https://github.com/mahmoodlab/CLAM)
115 | - DSMIL: [https://github.com/binli123/dsmil-wsi/tree/master](https://github.com/binli123/dsmil-wsi/tree/master)
116 | - MambaMIL: [https://github.com/isyangshu/MambaMIL](https://github.com/isyangshu/MambaMIL)
117 |
118 | ## 5. License
119 | Distributed under the Apache 2.0 License. See LICENSE for more information.
120 |
--------------------------------------------------------------------------------
/tester_attrimil_abmil.py:
--------------------------------------------------------------------------------
1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset
2 |
3 | import os
4 | import torch
5 | from torch import nn
6 | import torch.optim as optim
7 | import pdb
8 | import torch.nn.functional as F
9 |
10 | import pandas as pd
11 | import numpy as np
12 |
13 | from sklearn.preprocessing import label_binarize
14 | from sklearn.metrics import roc_auc_score, roc_curve
15 | from sklearn.metrics import auc as calc_auc
16 |
17 | from models.AttriMIL import AttriMIL
18 | from utils import *
19 |
20 |
class Accuracy_Logger(object):
    """Accumulate per-class prediction counts for accuracy reporting.

    ``self.data[c]`` is a dict ``{"count": seen, "correct": hits}`` for the
    ground-truth class ``c``.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.initialize()

    def initialize(self):
        """Reset the counters of every class to zero."""
        self.data = [dict(count=0, correct=0) for _ in range(self.n_classes)]

    def log(self, Y_hat, Y):
        """Record a single prediction ``Y_hat`` against its label ``Y``."""
        predicted, target = int(Y_hat), int(Y)
        entry = self.data[target]
        entry["count"] += 1
        entry["correct"] += (predicted == target)

    def log_batch(self, Y_hat, Y):
        """Record a batch of predictions against their labels."""
        predicted = np.array(Y_hat).astype(int)
        targets = np.array(Y).astype(int)
        for cls in np.unique(targets):
            in_cls = targets == cls
            entry = self.data[cls]
            entry["count"] += in_cls.sum()
            entry["correct"] += (predicted[in_cls] == targets[in_cls]).sum()

    def get_summary(self, c):
        """Return ``(accuracy, correct, count)`` for class ``c``.

        ``accuracy`` is ``None`` when nothing of class ``c`` was logged.
        """
        count, correct = self.data[c]["count"], self.data[c]["correct"]
        acc = None if count == 0 else float(correct) / count
        return acc, correct, count
55 |
def summary(model, loader, n_classes):
    """Evaluate ``model`` on every bag yielded by ``loader`` and gather metrics.

    Slide ids are looked up by batch position in
    ``loader.dataset.slide_data['slide_id']``, so the loader must visit the
    dataset in index order (e.g. ``get_simple_loader``).

    Args:
        model: network whose forward returns a 5-tuple whose 2nd and 3rd
            entries are the class probabilities and the predicted class.
        loader: DataLoader yielding ``(data, label)`` pairs, one bag per batch.
        n_classes: number of classes.

    Returns:
        ``(patient_results, test_error, auc_score, df, acc_logger)``:
        ``patient_results`` maps slide id -> ``{'slide_id', 'prob', 'label'}``;
        ``test_error`` is the mean per-bag error; ``auc_score`` is ``-1`` when
        only one class occurs in the labels, the binary ROC AUC when
        ``n_classes == 2``, and the micro-averaged one-vs-rest ROC AUC
        otherwise; ``df`` holds per-slide labels, predictions and
        probabilities.
    """
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    model.eval()
    test_error = 0.

    num_bags = len(loader)
    all_probs = np.zeros((num_bags, n_classes))
    all_labels = np.zeros(num_bags)
    all_preds = np.zeros(num_bags)

    slide_ids = loader.dataset.slide_data['slide_id']
    patient_results = {}
    for batch_idx, (data, label) in enumerate(loader):
        data, label = data.to(device), label.to(device)
        slide_id = slide_ids.iloc[batch_idx]
        with torch.inference_mode():
            _, Y_prob, Y_hat, _, _ = model(data)
        acc_logger.log(Y_hat, label)

        probs = Y_prob.cpu().numpy()
        all_probs[batch_idx] = probs
        all_labels[batch_idx] = label.item()
        all_preds[batch_idx] = Y_hat.item()

        patient_results[slide_id] = {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}
        test_error += calculate_error(Y_hat, label)
        del data

    test_error /= num_bags

    if len(np.unique(all_labels)) == 1:
        # ROC AUC is undefined when only one class is present.
        auc_score = -1
    elif n_classes == 2:
        auc_score = roc_auc_score(all_labels, all_probs[:, 1])
    else:
        # Micro-average: pool every (bag, class) one-vs-rest decision into one ROC curve.
        binary_labels = label_binarize(all_labels, classes=list(range(n_classes)))
        fpr, tpr, _ = roc_curve(binary_labels.ravel(), all_probs.ravel())
        auc_score = calc_auc(fpr, tpr)

    per_slide = {'slide_id': slide_ids, 'Y': all_labels, 'Y_hat': all_preds}
    for c in range(n_classes):
        per_slide['p_{}'.format(c)] = all_probs[:, c]
    df = pd.DataFrame(per_slide)
    return patient_results, test_error, auc_score, df, acc_logger
113 |
if __name__ == "__main__":
    import time
    start_time = time.time()
    save_dir = './results/timing_test/'
    csv_path = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/camelyon16_test.csv'
    data_dir = '/data2/clh/camelyon16/resnet18_simclr/'
    weight_dir = '/data1/ceiling/workspace/MIL/AttriMIL/save_weights/camelyon16_attrimil_simclr/'
    split_path = '/data1/ceiling/workspace/AttriMIL_v2/splits/camelyon16/'
    # Alternative label maps:
    # {'NORM':0, 'HP':1, 'TA.HG':2,'TA.LG':3, 'TVA.HG':4, 'TVA.LG':5}
    # {'LUAD':0, 'LUSC':1}
    # {'normal_tissue':0, 'tumor_tissue':1}
    dataset = Generic_MIL_Dataset(csv_path=csv_path,
                                  data_dir=data_dir,
                                  shuffle=False,
                                  print_info=True,
                                  label_dict={'normal_tissue': 0, 'tumor_tissue': 1},
                                  patient_strat=False,
                                  ignore=[])
    os.makedirs(save_dir, exist_ok=True)
    model = AttriMIL(dim=512, n_classes=2).cuda()
    # folds = [0, 1, 2, 3, 4]
    folds = [0]
    ckpt_paths = [os.path.join(weight_dir, 's_{}_checkpoint.pt'.format(fold)) for fold in folds]

    # Every checkpoint is evaluated on the same test csv, so one loader is reused.
    loader = get_simple_loader(dataset)
    all_results = []
    all_auc = []
    all_acc = []
    for fold, ckpt_path in zip(folds, ckpt_paths):
        # To evaluate each fold on its own test split instead:
        # split_csv = split_path + 'splits_{}.csv'.format(fold)
        # train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False, csv_path=split_csv)
        # loader = get_split_loader(test_dataset, testing = False)
        model.load_state_dict(torch.load(ckpt_path))
        patient_results, test_error, auc, df, acc_logger = summary(model, loader, n_classes=2)
        all_results.append(patient_results)
        all_auc.append(auc)
        all_acc.append(1 - test_error)
        df.to_csv(os.path.join(save_dir, 'fold_{}.csv'.format(fold)), index=False)

    final_df = pd.DataFrame({'folds': folds, 'test_auc': all_auc, 'test_acc': all_acc})
    save_name = 'summary.csv'
    final_df.to_csv(os.path.join(save_dir, save_name))
    elapsed_time = time.time() - start_time
    # Prints "Timing result: took <elapsed_time> seconds".
    print(f"计时结果: 耗时{elapsed_time}秒")
158 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import torch
3 | import numpy as np
4 | import torch.nn as nn
5 | import pdb
6 |
7 | import torch
8 | import numpy as np
9 | import torch.nn as nn
10 | from torchvision import transforms
11 | from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler, SequentialSampler, sampler
12 | import torch.optim as optim
13 | import pdb
14 | import torch.nn.functional as F
15 | import math
16 | from itertools import islice
17 | import collections
18 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
19 |
class SubsetSequentialSampler(Sampler):
    """Yield a fixed sequence of dataset indices in the order given.

    No shuffling and no de-duplication: every entry of ``indices`` is produced
    exactly once per pass.

    Arguments:
        indices (sequence): the dataset indices to visit, in order.
    """

    def __init__(self, indices):
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        yield from self.indices
34 |
def collate_MIL(batch):
    """Collate ``(features, label, ...)`` items into ``[bag, labels]``.

    Instance features are concatenated along dim 0 into a single bag tensor;
    labels are gathered into a ``LongTensor``.
    """
    bag = torch.cat([sample[0] for sample in batch], dim=0)
    labels = torch.LongTensor([sample[1] for sample in batch])
    return [bag, labels]
39 |
def collate_MIL_coords(batch):
    """Collate training items that also carry coordinates and neighbour indices.

    Training needs the coordinates as well. Each item is
    ``(features, label, coords, nearest)`` where ``coords`` and ``nearest``
    are numpy arrays; they are converted with ``torch.from_numpy`` and, like
    the features, concatenated along dim 0.

    Returns:
        ``[bag, labels, coords, nearest]`` with ``labels`` a ``LongTensor``.
    """
    def gather(field, convert=None):
        parts = [sample[field] if convert is None else convert(sample[field]) for sample in batch]
        return torch.cat(parts, dim=0)

    return [
        gather(0),
        torch.LongTensor([sample[1] for sample in batch]),
        gather(2, torch.from_numpy),
        gather(3, torch.from_numpy),
    ]
46 |
def collate_features(batch):
    """Collate ``(features, coords)`` items into ``[features, coords]``.

    Features are concatenated along dim 0 into a tensor; coordinates stay
    numpy and are stacked row-wise with ``np.vstack``.
    """
    feature_parts, coord_parts = [], []
    for sample in batch:
        feature_parts.append(sample[0])
        coord_parts.append(sample[1])
    return [torch.cat(feature_parts, dim=0), np.vstack(coord_parts)]
51 |
52 |
def get_simple_loader(dataset, batch_size=1, num_workers=1):
    """Build a DataLoader that visits ``dataset`` in index order with ``collate_MIL``.

    Args:
        dataset: dataset to iterate sequentially (no shuffling).
        batch_size: number of items per batch.
        num_workers: loader worker processes; only applied when ``device`` is
            CUDA (on CPU the DataLoader defaults are used).

    Returns:
        torch.utils.data.DataLoader
    """
    # A single 'num_workers' entry: the dict literal used to repeat the key,
    # so an earlier literal 4 was silently overridden by ``num_workers``.
    kwargs = {'num_workers': num_workers, 'pin_memory': False} if device.type == "cuda" else {}
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler.SequentialSampler(dataset), collate_fn=collate_MIL, **kwargs)
57 |
def get_split_loader(split_dataset, training = False, testing = False, weighted = False):
    """Return a DataLoader for one dataset split (batch size 1).

    - ``testing=True``: quick-check loader over a random ~10% subset of the
      split (drawn without replacement, visited in draw order), collated with
      ``collate_MIL``.
    - ``training=True``: random sampling, or class-balanced
      ``WeightedRandomSampler`` when ``weighted``; collated with
      ``collate_MIL_coords``.
    - otherwise: sequential order, collated with ``collate_MIL_coords``.
    """
    kwargs = {'num_workers': 4} if device.type == "cuda" else {}

    if testing:
        n_items = len(split_dataset)
        # floor(10%) of the split, but at least one item for non-empty splits
        # so the quick-check loader is never empty. (The indices used to be
        # drawn from ``np.arange(n, 0.1 * n)`` - an empty range - which made
        # this branch always raise.)
        subset_size = min(n_items, max(1, int(n_items * 0.1)))
        ids = np.random.choice(np.arange(n_items), subset_size, replace=False)
        return DataLoader(split_dataset, batch_size=1, sampler=SubsetSequentialSampler(ids), collate_fn=collate_MIL, **kwargs)

    if training:
        if weighted:
            weights = make_weights_for_balanced_classes_split(split_dataset)
            split_sampler = WeightedRandomSampler(weights, len(weights))
        else:
            split_sampler = RandomSampler(split_dataset)
    else:
        split_sampler = SequentialSampler(split_dataset)
    return DataLoader(split_dataset, batch_size=1, sampler=split_sampler, collate_fn=collate_MIL_coords, **kwargs)
78 |
def get_optim(model, args):
    """Create the optimizer named by ``args.opt`` over the trainable parameters.

    ``"adam"`` -> Adam, ``"sgd"`` -> SGD with momentum 0.9; both use
    ``args.lr`` and weight decay ``args.reg``. Parameters with
    ``requires_grad=False`` are excluded.

    Raises:
        NotImplementedError: if ``args.opt`` is neither ``"adam"`` nor ``"sgd"``.
    """
    def trainable():
        return (p for p in model.parameters() if p.requires_grad)

    if args.opt == "adam":
        return optim.Adam(trainable(), lr=args.lr, weight_decay=args.reg)
    if args.opt == 'sgd':
        return optim.SGD(trainable(), lr=args.lr, momentum=0.9, weight_decay=args.reg)
    raise NotImplementedError
87 |
def print_network(net):
    """Print ``net`` followed by its total and trainable parameter counts."""
    print(net)
    sizes = [(param.numel(), param.requires_grad) for param in net.parameters()]
    total = sum(size for size, _ in sizes)
    trainable = sum(size for size, needs_grad in sizes if needs_grad)
    print('Total number of parameters: %d' % total)
    print('Total number of trainable parameters: %d' % trainable)
101 |
102 |
def generate_split(cls_ids, val_num, test_num, samples, n_splits = 5,
                   seed = 7, label_frac = 1.0, custom_test_ids = None):
    """Yield ``n_splits`` class-stratified ``(train, val, test)`` id lists.

    For each class ``c``, ``val_num[c]`` validation ids are drawn without
    replacement from that class's ids (``cls_ids[c]``, restricted to
    ``range(samples)`` minus ``custom_test_ids``). Then, unless
    ``custom_test_ids`` is supplied (it is used as the test set of every
    split), ``test_num[c]`` test ids are drawn from the rest. The leftover
    ids go to training; with ``label_frac != 1`` only the first
    ``ceil(len * label_frac)`` leftover ids (sorted order) are kept.

    NumPy's global RNG is reseeded with ``seed`` once, when iteration starts.
    """
    pool = np.arange(samples).astype(int)
    if custom_test_ids is not None:
        pool = np.setdiff1d(pool, custom_test_ids)

    np.random.seed(seed)
    for _ in range(n_splits):
        train_ids = []
        val_ids_all = []
        # A pre-built test split is reused as-is and never sampled.
        test_ids_all = [] if custom_test_ids is None else list(custom_test_ids)

        for c, n_val in enumerate(val_num):
            candidates = np.intersect1d(cls_ids[c], pool)
            picked_val = np.random.choice(candidates, n_val, replace=False)
            leftover = np.setdiff1d(candidates, picked_val)
            val_ids_all.extend(picked_val)

            if custom_test_ids is None:
                picked_test = np.random.choice(leftover, test_num[c], replace=False)
                leftover = np.setdiff1d(leftover, picked_test)
                test_ids_all.extend(picked_test)

            if label_frac == 1:
                train_ids.extend(leftover)
            else:
                keep = math.ceil(len(leftover) * label_frac)
                train_ids.extend(leftover[np.arange(keep)])

        yield train_ids, val_ids_all, test_ids_all
141 |
142 |
def nth(iterator, n, default=None):
    """Return item ``n`` (0-based) of ``iterator``, or ``default`` if it runs out.

    With ``n=None`` the iterator is drained completely and an empty
    ``collections.deque`` is returned.
    """
    if n is not None:
        return next(islice(iterator, n, None), default)
    return collections.deque(iterator, maxlen=0)
148 |
def calculate_error(Y_hat, Y):
    """Return the misclassification rate ``1 - mean(Y_hat == Y)`` as a float.

    Predictions from ``torch.topk(logits, 1, dim=1)[1]`` have shape ``(B, 1)``
    while labels usually have shape ``(B,)``. Comparing them directly would
    broadcast to ``(B, B)`` and give a wrong rate whenever ``B > 1``. When both
    tensors hold the same number of elements they are flattened so they are
    compared element-wise. Otherwise the original broadcasting comparison is
    kept.
    """
    pred, target = Y_hat.float(), Y.float()
    if pred.numel() == target.numel():
        pred, target = pred.reshape(-1), target.reshape(-1)
    return 1. - pred.eq(target).float().mean().item()
153 |
def make_weights_for_balanced_classes_split(dataset):
    """Per-sample weights for a class-balanced ``WeightedRandomSampler``.

    A sample of class ``c`` gets weight ``N / n_c``, where ``N`` is
    ``len(dataset)`` and ``n_c`` is ``len(dataset.slide_cls_ids[c])``. A class
    with no samples gets weight 0 instead of raising ``ZeroDivisionError``;
    no sample carries that label, so its weight is never used.

    Args:
        dataset: object providing ``__len__``, ``slide_cls_ids`` (one
            collection of indices per class) and ``getlabel(idx)``.

    Returns:
        torch.DoubleTensor with one weight per sample, in index order.
    """
    N = float(len(dataset))
    weight_per_class = [N / len(ids) if len(ids) > 0 else 0. for ids in dataset.slide_cls_ids]
    weight = [weight_per_class[dataset.getlabel(idx)] for idx in range(len(dataset))]
    return torch.DoubleTensor(weight)
163 |
def initialize_weights(module):
    """Initialise ``module`` and all of its submodules in place.

    ``nn.Linear``: Xavier-normal weight, zero bias.
    ``nn.BatchNorm1d``: weight 1, bias 0.
    Layers built with ``bias=False`` (Linear) or ``affine=False`` (BatchNorm)
    have these tensors set to ``None``; they are skipped rather than raising.
    """
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                m.bias.data.zero_()

        elif isinstance(m, nn.BatchNorm1d):
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
173 |
174 |
--------------------------------------------------------------------------------
/models/S4MIL.py:
--------------------------------------------------------------------------------
1 | # This code is taken from the original S4 repository https://github.com/HazyResearch/state-spaces
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from einops import rearrange, repeat
7 | import opt_einsum as oe
8 |
9 | _c2r = torch.view_as_real
10 | _r2c = torch.view_as_complex
11 |
class DropoutNd(nn.Module):
    """Dropout that can share one keep/drop decision per (batch, channel).

    Args:
        p: drop probability; must satisfy ``0 <= p < 1``.
        tie: if True, the mask has shape ``(batch, dim, 1, ..., 1)`` so each
            (batch, channel) entry is kept or dropped along all length axes
            at once (Dropout1d/2d/3d behaviour); otherwise every element is
            masked independently.
        transposed: True for channel-first input ``(batch, dim, *lengths)``;
            False for channel-last input ``(batch, *lengths, dim)``.
    """

    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError(
                "dropout probability has to be in [0, 1), " "but got {}".format(p))
        self.p = p
        self.tie = tie
        self.transposed = transposed
        # Not used by forward(): sampling from it is very slow, so the mask is
        # drawn with torch.rand instead. Kept for attribute compatibility.
        self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p)

    def forward(self, X):
        """Apply inverted dropout in training mode; identity in eval mode."""
        if not self.training:
            return X
        if not self.transposed:
            # Channel-last -> channel-first: move the trailing channel axis to
            # position 1. (The previous code moved axis 1 to the end, which is
            # only equivalent for 3-D inputs and tied the mask over the wrong
            # axes for higher-rank inputs.)
            X = X.movedim(-1, 1)
        mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape
        mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p
        X = X * mask * (1.0/(1-self.p))
        if not self.transposed:
            X = X.movedim(1, -1)
        return X
40 |
41 |
class S4DKernel(nn.Module):
    """Generate the diagonal state-space convolution kernel used by ``S4D``.

    Per channel (``d_model`` rows) and per complex mode (``N // 2`` columns)
    it holds ``C`` (a real view of a complex tensor), ``log_dt``,
    ``log_A_real`` and ``A_imag``. The last three are registered through
    :meth:`register`.
    """

    def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, lr=None):
        super().__init__()
        n_modes = N // 2
        log_lo, log_hi = math.log(dt_min), math.log(dt_max)

        # Random draws keep this order (rand, then randn) so that seeded runs
        # reproduce the same initialisation.
        log_dt = torch.rand(d_model) * (log_hi - log_lo) + log_lo
        c_init = torch.randn(d_model, n_modes, dtype=torch.cfloat)
        self.C = nn.Parameter(torch.view_as_real(c_init))
        self.register("log_dt", log_dt, lr)

        # A = -0.5 + i * pi * n for mode n, identical for every channel.
        log_A_real = torch.log(0.5 * torch.ones(d_model, n_modes))
        A_imag = math.pi * torch.arange(n_modes).repeat(d_model, 1)
        self.register("log_A_real", log_A_real, lr)
        self.register("A_imag", A_imag, lr)

    def forward(self, L):
        """Return the real kernel of shape ``(d_model, L)``."""
        dt = self.log_dt.exp()                                   # (H)
        C = torch.view_as_complex(self.C)                        # (H, N/2)
        A = -self.log_A_real.exp() + 1j * self.A_imag            # (H, N/2)

        dtA = A * dt[:, None]                                    # (H, N/2)
        # Vandermonde terms exp(dtA * l) for l = 0 .. L-1.
        vandermonde = torch.exp(dtA[..., None] * torch.arange(L, device=A.device))  # (H, N/2, L)
        C_scaled = C * (torch.exp(dtA) - 1.) / A
        return 2 * torch.einsum('hn, hnl -> hl', C_scaled, vandermonde).real

    def register(self, name, tensor, lr=None):
        """Register ``tensor`` as a buffer when ``lr == 0.0``, else as a parameter.

        The registered tensor gets an ``_optim`` dict attribute holding
        ``weight_decay`` 0.0 and, when ``lr`` is not None, that ``lr``.
        """
        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

        hints = {"weight_decay": 0.0}
        if lr is not None:
            hints["lr"] = lr
        setattr(getattr(self, name), "_optim", hints)
93 |
94 |
class S4D(nn.Module):
    """Diagonal state-space layer: long FFT convolution, skip term, GLU mixing.

    Args:
        d_model: number of channels ``H``.
        d_state: state size ``N`` passed to ``S4DKernel``.
        dropout: dropout probability after the activation (0 disables it).
        transposed: True for input ``(B, H, L)``, False for ``(B, L, H)``;
            the output has the same layout as the input.
        **kernel_args: forwarded to ``S4DKernel``.
    """

    def __init__(self, d_model, d_state=64, dropout=0.0, transposed=True, **kernel_args):
        super().__init__()
        self.h = d_model
        self.n = d_state
        self.d_output = self.h
        self.transposed = transposed

        # Skip-connection ("D") weight, one per channel.
        self.D = nn.Parameter(torch.randn(self.h))
        self.kernel = S4DKernel(self.h, N=self.n, **kernel_args)

        self.activation = nn.GELU()
        # nn.Dropout2d is avoided (reported buggy in PyTorch 1.11).
        self.dropout = DropoutNd(dropout) if dropout > 0.0 else nn.Identity()

        # 1x1 conv to 2H channels + GLU over the channel axis mixes features.
        self.output_linear = nn.Sequential(
            nn.Conv1d(self.h, 2 * self.h, kernel_size=1),
            nn.GLU(dim=-2),
        )

    def forward(self, u, **kwargs):
        """Apply the layer; extra keyword arguments are accepted and ignored."""
        if not self.transposed:
            u = u.transpose(-1, -2)
        seq_len = u.size(-1)

        kernel = self.kernel(L=seq_len)                           # (H, L)
        # Zero-pad to 2L so the FFT product is a linear, not circular, convolution.
        fft_len = 2 * seq_len
        kernel_f = torch.fft.rfft(kernel, n=fft_len)
        signal_f = torch.fft.rfft(u.to(torch.float32), n=fft_len)
        y = torch.fft.irfft(signal_f * kernel_f, n=fft_len)[..., :seq_len]

        y = y + u * self.D.unsqueeze(-1)
        y = self.output_linear(self.dropout(self.activation(y)))
        return y if self.transposed else y.transpose(-1, -2)
144 |
145 |
class S4Model(nn.Module):
    """Bag classifier: Linear(in_dim->512) [+act, +dropout] -> LayerNorm+S4D -> max-pool -> Linear.

    ``forward`` takes one bag of instance features shaped ``(num_instances,
    in_dim)`` and treats it as a single sequence.
    """

    def __init__(self, in_dim, n_classes, dropout, act, survival = False):
        super(S4Model, self).__init__()
        self.n_classes = n_classes

        activation_cls = {'relu': nn.ReLU, 'gelu': nn.GELU}.get(act.lower())
        fc_layers = [nn.Linear(in_dim, 512)]
        if activation_cls is not None:
            fc_layers.append(activation_cls())
        if dropout:
            fc_layers.append(nn.Dropout(dropout))
            print("dropout: ", dropout)
        self._fc1 = nn.Sequential(*fc_layers)

        self.s4_block = nn.Sequential(nn.LayerNorm(512),
                                      S4D(d_model=512, d_state=32, transposed=False))
        self.classifier = nn.Linear(512, self.n_classes)
        self.survival = survival

    def forward(self, x):
        """Return ``(logits, Y_prob, Y_hat, None, None)``.

        In survival mode return ``(hazards, S, Y_hat, None, None)`` instead.
        """
        hidden = self.s4_block(self._fc1(x.unsqueeze(0)))       # (1, num_instances, 512)
        pooled = hidden.max(dim=1).values                         # max over instances
        logits = self.classifier(pooled)
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        if self.survival:
            hazards = torch.sigmoid(logits)
            S = torch.cumprod(1 - hazards, dim=1)
            return hazards, S, Y_hat, None, None
        Y_prob = F.softmax(logits, dim=1)
        return logits, Y_prob, Y_hat, None, None

    def relocate(self):
        """Move the sub-networks to CUDA when available, else CPU."""
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._fc1 = self._fc1.to(target)
        self.s4_block = self.s4_block.to(target)
        self.classifier = self.classifier.to(target)
189 |
if __name__ == "__main__":
    # Smoke test: push one random 6000-instance bag through the model.
    # ``Tensor.to`` is not in-place, so the result must be reassigned; the
    # previous ``data.to('cuda')`` was a no-op and failed on CPU-only hosts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = torch.randn((6000, 1024)).to(device)
    model = S4Model(in_dim=1024, n_classes=4, act='gelu', dropout=0.25).to(device)
    print(model)
    with torch.no_grad():
        outputs = model(data)
    print(outputs)
--------------------------------------------------------------------------------
/datasets/unitopatho_train.csv:
--------------------------------------------------------------------------------
1 | case_id,slide_id,label
2 | patient_1,NORM CASO 8 - 2019-03-04 08.51.07_1,NORM
3 | patient_2,208-B5-NORM_1,NORM
4 | patient_3,271-B5-NORM_1,NORM
5 | patient_4,196-B4-NORM_1,NORM
6 | patient_5,57-B2-NORM_1,NORM
7 | patient_6,215-B5-NORM_1,NORM
8 | patient_7,265-B5-NORM_1,NORM
9 | patient_8,216-B5-NORM_1,NORM
10 | patient_9,131-B3-NORM_1,NORM
11 | patient_10,264-B5-NORM_1,NORM
12 | patient_11,184-B4-NORM_1,NORM
13 | patient_12,185-B4-NORM_1,NORM
14 | patient_13,NORM CASO 7 - 2019-03-01 09.12.01_1,NORM
15 | patient_14,109-B3-TVALG_5,TVA.LG
16 | patient_15,TVA.LG CASO 2 - 2018-12-04 13.19.16_5,TVA.LG
17 | patient_16,93-B2-TVALG_5,TVA.LG
18 | patient_17,TVA.LG CASO 9_5,TVA.LG
19 | patient_18,209-B5-TVALG_5,TVA.LG
20 | patient_19,113-B3-TVALG_5,TVA.LG
21 | patient_20,178-B4-TVALG_5,TVA.LG
22 | patient_21,172-B4-TVALG_5,TVA.LG
23 | patient_22,TVA.LG CASO 1 - 2018-12-04 13.17.55_5,TVA.LG
24 | patient_23,114-B3-TVALG_5,TVA.LG
25 | patient_24,176-B4-TVALG_5,TVA.LG
26 | patient_25,118-B3-TVALG_5,TVA.LG
27 | patient_26,96-B2-TVALG_5,TVA.LG
28 | patient_27,205-B5-TVALG_5,TVA.LG
29 | patient_28,207-B5-TVALG_5,TVA.LG
30 | patient_29,142-B3-TVALG_5,TVA.LG
31 | patient_30,174-B4-TVALG_5,TVA.LG
32 | patient_31,80-B2-TVALG_5,TVA.LG
33 | patient_32,100-B2-TVALG_5,TVA.LG
34 | patient_33,252-B5-TVALG_5,TVA.LG
35 | patient_34,TVA.LG CASO 10_5,TVA.LG
36 | patient_35,189-B4-TVALG_5,TVA.LG
37 | patient_36,TVA.LG CASO 11_5,TVA.LG
38 | patient_37,218-B5-TVALG_5,TVA.LG
39 | patient_38,154-B4-TVALG_5,TVA.LG
40 | patient_39,124-B3-TVALG_5,TVA.LG
41 | patient_40,70-B2-TVALG_5,TVA.LG
42 | patient_41,TVA.LG CASO 4 - 2018-12-04 13.25.02_5,TVA.LG
43 | patient_42,198-B4-TVALG_5,TVA.LG
44 | patient_43,191-B4-TVALG_5,TVA.LG
45 | patient_44,270-B5-TAHG_2,TA.HG
46 | patient_45,226-B5-TAHG_2,TA.HG
47 | patient_46,62-B2-TAHG_2,TA.HG
48 | patient_47,221-B5-TAHG_2,TA.HG
49 | patient_48,84-B2-TAHG_2,TA.HG
50 | patient_49,TA.HG CASO 10 - 2019-03-04 09.24.36_2,TA.HG
51 | patient_50,219-B5-TAHG_2,TA.HG
52 | patient_51,TA.HG CASO 14_2,TA.HG
53 | patient_52,TA.HG CASO 10_2,TA.HG
54 | patient_53,TA.HG CASO 13_2,TA.HG
55 | patient_54,54-B2-TAHG_2,TA.HG
56 | patient_55,TA.HG CASO 1 - 2018-12-04 12.58.12_2,TA.HG
57 | patient_56,223-B5-TAHG_2,TA.HG
58 | patient_57,TA.HG CASO 15_2,TA.HG
59 | patient_58,TA.HG CASO 16_2,TA.HG
60 | patient_59,TA.HG CASO 17_2,TA.HG
61 | patient_60,137-B3-TAHG_2,TA.HG
62 | patient_61,HP CASO 1 - 2018-12-04 13.30.08_0,HP
63 | patient_62,HP CASO 28 - 2019-03-04 17.34.19_0,HP
64 | patient_63,HP CASO 47_0,HP
65 | patient_64,HP CASO 43 B1_0,HP
66 | patient_65,HP CASO 26 - 2019-03-04 09.51.06_0,HP
67 | patient_66,HP CASO 36_0,HP
68 | patient_67,149-B3-HP_0,HP
69 | patient_68,170-B4-HP_0,HP
70 | patient_69,103-B3-HP_0,HP
71 | patient_70,HP CASO 44_0,HP
72 | patient_71,158-B4-HP_0,HP
73 | patient_72,168-B4-HP_0,HP
74 | patient_73,HP CASO 40 B1_0,HP
75 | patient_74,250-B5-HP_0,HP
76 | patient_75,151-B4-HP_0,HP
77 | patient_76,224-B5-HP_0,HP
78 | patient_77,HP CASO 38_0,HP
79 | patient_78,148-B3-HP_0,HP
80 | patient_79,HP CASO 2 - 2018-12-04 13.56.35_0,HP
81 | patient_80,HP CASO 39_0,HP
82 | patient_81,HP CASO 46_0,HP
83 | patient_82,HP CASO 42_0,HP
84 | patient_83,HP CASO 43 D1_0,HP
85 | patient_84,169-B4-HP_0,HP
86 | patient_85,HP CASO 45_0,HP
87 | patient_86,165-B4-HP_0,HP
88 | patient_87,159-B4-HP_0,HP
89 | patient_88,HP CASO 41 B1_0,HP
90 | patient_89,167-B4-HP_0,HP
91 | patient_90,194-B4-HP_0,HP
92 | patient_91,HP CASO 24 - 2019-03-04 08.59.32_0,HP
93 | patient_92,TA.LG CASO 57 - 2019-03-04 10.09.38_3,TA.LG
94 | patient_93,245-B5-TALG_3,TA.LG
95 | patient_94,TA.LG CASO 102_3,TA.LG
96 | patient_95,TA.LG CASO 79_3,TA.LG
97 | patient_96,86-B2-TALG_3,TA.LG
98 | patient_97,87-B2-TALG_3,TA.LG
99 | patient_98,239-B5-TALG_3,TA.LG
100 | patient_99,TA.LG CASO 50 - 2019-03-04 08.13.45_3,TA.LG
101 | patient_100,TA.LG CASO 9 - 2018-12-04 13.41.36_3,TA.LG
102 | patient_101,TA.LG CASO 76_3,TA.LG
103 | patient_102,TA.LG CASO 45 - 2019-03-01 08.32.26_3,TA.LG
104 | patient_103,213-B5-TALG_3,TA.LG
105 | patient_104,51-B2-TALG_3,TA.LG
106 | patient_105,TA.LG CASO 97_3,TA.LG
107 | patient_106,125-B3-TALG_3,TA.LG
108 | patient_107,TA.LG CASO 4 - 2018-12-04 13.29.01_3,TA.LG
109 | patient_108,TA.LG CASO 106_3,TA.LG
110 | patient_109,266-B5-TALG_3,TA.LG
111 | patient_110,73-B2-TALG_3,TA.LG
112 | patient_111,58-B2-TALG_3,TA.LG
113 | patient_112,76-B2-TALG_3,TA.LG
114 | patient_113,104-B3-TALG_3,TA.LG
115 | patient_114,TA.LG CASO 52 - 2019-03-04 08.34.05_3,TA.LG
116 | patient_115,68-B2-TALG_3,TA.LG
117 | patient_116,91-B2-TALG_3,TA.LG
118 | patient_117,225-B5-TALG_3,TA.LG
119 | patient_118,TA.LG CASO 80 C1_3,TA.LG
120 | patient_119,144-B3-TALG_3,TA.LG
121 | patient_120,TA.LG CASO 107_3,TA.LG
122 | patient_121,75-B2-TALG_3,TA.LG
123 | patient_122,52-B2-TALG_3,TA.LG
124 | patient_123,TA.LG CASO 93_3,TA.LG
125 | patient_124,267-B5-TALG_3,TA.LG
126 | patient_125,TA.LG CASO 89_3,TA.LG
127 | patient_126,81-B2-TALG_3,TA.LG
128 | patient_127,227-B5-TALG_3,TA.LG
129 | patient_128,72-B2-TALG_3,TA.LG
130 | patient_129,89-B2-TALG_3,TA.LG
131 | patient_130,TA.LG CASO 49 - 2019-03-04 07.38.56_3,TA.LG
132 | patient_131,132-B3-TALG_3,TA.LG
133 | patient_132,TA.LG CASO 56 - 2019-03-04 10.00.26_3,TA.LG
134 | patient_133,TA.LG CASO 77_3,TA.LG
135 | patient_134,101-B2-TALG_3,TA.LG
136 | patient_135,82-B2-TALG_3,TA.LG
137 | patient_136,TA.LG CASO 51 - 2019-03-04 08.23.50_3,TA.LG
138 | patient_137,133-B3-TALG_3,TA.LG
139 | patient_138,TA.LG CASO 88 A1_3,TA.LG
140 | patient_139,105-B3-TALG_3,TA.LG
141 | patient_140,212-B5-TALG_3,TA.LG
142 | patient_141,136-B3-TALG_3,TA.LG
143 | patient_142,TA.LG CASO 14 - 2018-12-04 13.51.01_3,TA.LG
144 | patient_143,TA.LG CASO 64 - 2019-03-04 17.42.27_3,TA.LG
145 | patient_144,TA.LG CASO 75_3,TA.LG
146 | patient_145,TA.LG CASO 105_3,TA.LG
147 | patient_146,TA.LG CASO 65 - 2019-03-04 18.07.00_3,TA.LG
148 | patient_147,TA.LG CASO 62 - 2019-03-04 17.11.05_3,TA.LG
149 | patient_148,TA.LG CASO 74 8424.19_3,TA.LG
150 | patient_149,TA.LG CASO 7 - 2018-12-04 13.37.32_3,TA.LG
151 | patient_150,TA.LG CASO 104 A1_3,TA.LG
152 | patient_151,TA.LG CASO 85_3,TA.LG
153 | patient_152,74-B2-TALG_3,TA.LG
154 | patient_153,203-B5-TALG_3,TA.LG
155 | patient_154,90-B2-TALG_3,TA.LG
156 | patient_155,TA.LG CASO 60 - 2019-03-04 16.40.51_3,TA.LG
157 | patient_156,69-B2-TALG_3,TA.LG
158 | patient_157,TA.LG CASO 58 - 2019-03-04 16.24.00_3,TA.LG
159 | patient_158,TA.LG CASO 2 - 2018-12-04 12.55.38_3,TA.LG
160 | patient_159,TA.LG CASO 5 - 2018-12-04 13.34.36_3,TA.LG
161 | patient_160,244-B5-TALG_3,TA.LG
162 | patient_161,237-B5-TALG_3,TA.LG
163 | patient_162,94-B2-TALG_3,TA.LG
164 | patient_163,TA.LG CASO 100_3,TA.LG
165 | patient_164,55-B2-TALG_3,TA.LG
166 | patient_165,TA.LG CASO 53 - 2019-03-04 08.42.25_3,TA.LG
167 | patient_166,171-B4-TALG_3,TA.LG
168 | patient_167,TA.LG CASO 99_3,TA.LG
169 | patient_168,173-B4-TALG_3,TA.LG
170 | patient_169,TA.LG CASO 3 - 2018-12-04 13.14.40_3,TA.LG
171 | patient_170,217-B5-TALG_3,TA.LG
172 | patient_171,182-B4-TALG_3,TA.LG
173 | patient_172,TA.LG CASO 84_3,TA.LG
174 | patient_173,236-B5-TALG_3,TA.LG
175 | patient_174,TA.LG CASO 86_3,TA.LG
176 | patient_175,TA.LG CASO 91_3,TA.LG
177 | patient_176,TA.LG CASO 66 - 2019-03-04 18.29.55_3,TA.LG
178 | patient_177,130-B3-TALG_3,TA.LG
179 | patient_178,79-B2-TALG_3,TA.LG
180 | patient_179,TA.LG CASO 88 B1_3,TA.LG
181 | patient_180,127-B3-TALG_3,TA.LG
182 | patient_181,TA.LG CASO 78_3,TA.LG
183 | patient_182,TA.LG CASO 80 A1_3,TA.LG
184 | patient_183,TA.LG CASO 101 B1_3,TA.LG
185 | patient_184,269-B5-TALG_3,TA.LG
186 | patient_185,TA.LG CASO 55 - 2019-03-04 09.43.03_3,TA.LG
187 | patient_186,TA.LG CASO 92 B1_3,TA.LG
188 | patient_187,TA.LG CASO 47 - 2019-03-01 08.51.33_3,TA.LG
189 | patient_188,TA.LG CASO 11 - 2018-12-04 13.46.00_3,TA.LG
190 | patient_189,210-B5-TALG_3,TA.LG
191 | patient_190,TA.LG CASO 81_3,TA.LG
192 | patient_191,61-B2-TALG_3,TA.LG
193 | patient_192,201-B5-TVAHG_4,TVA.HG
194 | patient_193,117-B3-TVAHG_4,TVA.HG
195 | patient_194,268-B5-TVAHG_4,TVA.HG
196 | patient_195,249-B5-TVAHG_4,TVA.HG
197 | patient_196,200-B4-TVAHG_4,TVA.HG
198 | patient_197,108-B3-TVAHG_4,TVA.HG
199 | patient_198,143-B3-TVAHG_4,TVA.HG
200 | patient_199,220-B5-TVAHG_4,TVA.HG
201 | patient_200,TVA.HG CASO 4 - 2018-12-04 13.15.43_4,TVA.HG
202 | patient_201,251-B5-TVAHG_4,TVA.HG
203 | patient_202,190-B4-TVAHG_4,TVA.HG
204 | patient_203,TVA.HG CASO 2 - 2018-12-04 12.53.37_4,TVA.HG
205 | patient_204,181-B4-TVAHG_4,TVA.HG
206 |
--------------------------------------------------------------------------------
/trainer_mil.py:
--------------------------------------------------------------------------------
1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset
2 |
3 | import os
4 | import torch
5 | from torch import nn
6 | import torch.optim as optim
7 | import pdb
8 | import torch.nn.functional as F
9 |
10 | import pandas as pd
11 | import numpy as np
12 |
13 | from sklearn.preprocessing import label_binarize
14 | from sklearn.metrics import roc_auc_score, roc_curve
15 | from sklearn.metrics import auc as calc_auc
16 |
17 | from models.MIL import *
18 | from utils import *
19 |
20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21 |
22 |
class Accuracy_Logger(object):
    """Keep per-class counts of seen and correctly predicted samples.

    ``self.data[c]`` is ``{"count": seen, "correct": hits}`` for ground-truth
    class ``c``.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.initialize()

    def initialize(self):
        """Zero the counters of every class."""
        self.data = [dict(count=0, correct=0) for _ in range(self.n_classes)]

    def log(self, Y_hat, Y):
        """Record one prediction ``Y_hat`` for a sample with label ``Y``."""
        predicted, target = int(Y_hat), int(Y)
        entry = self.data[target]
        entry["count"] += 1
        entry["correct"] += (predicted == target)

    def log_batch(self, Y_hat, Y):
        """Record a batch of predictions and labels."""
        predicted = np.array(Y_hat).astype(int)
        targets = np.array(Y).astype(int)
        for cls in np.unique(targets):
            members = targets == cls
            entry = self.data[cls]
            entry["count"] += members.sum()
            entry["correct"] += (predicted[members] == targets[members]).sum()

    def get_summary(self, c):
        """Return ``(accuracy, correct, count)`` for class ``c``.

        ``accuracy`` is ``None`` if no sample of class ``c`` was logged.
        """
        stats = self.data[c]
        count, correct = stats["count"], stats["correct"]
        acc = float(correct) / count if count != 0 else None
        return acc, correct, count
57 |
def summary(model, loader, n_classes):
    """Evaluate ``model`` over ``loader`` and return per-slide results and metrics.

    Slide ids are read by batch position from
    ``loader.dataset.slide_data['slide_id']``, so the loader must visit the
    dataset in index order.

    Args:
        model: network whose forward returns a 5-tuple with class
            probabilities 2nd and the predicted class 3rd.
        loader: DataLoader yielding ``(data, label, coords, nearest)``, one bag
            per batch (coords / nearest are unused here).
        n_classes: number of classes.

    Returns:
        ``(patient_results, test_error, auc, acc_logger)``. ``auc`` is ``-1``
        when the labels contain a single class (AUC undefined, same convention
        as the tester), the binary ROC AUC when ``n_classes == 2``, otherwise
        the mean of the per-class one-vs-rest AUCs over classes that occur.
    """
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    model.eval()
    test_error = 0.

    num_bags = len(loader)
    all_probs = np.zeros((num_bags, n_classes))
    all_labels = np.zeros(num_bags)

    slide_ids = loader.dataset.slide_data['slide_id']
    patient_results = {}

    for batch_idx, (data, label, _coords, _nearest) in enumerate(loader):
        data, label = data.to(device), label.to(device)
        slide_id = slide_ids.iloc[batch_idx]
        with torch.inference_mode():
            _, Y_prob, Y_hat, _, _ = model(data)
        acc_logger.log(Y_hat, label)

        probs = Y_prob.cpu().numpy()
        all_probs[batch_idx] = probs
        all_labels[batch_idx] = label.item()

        patient_results[slide_id] = {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}
        test_error += calculate_error(Y_hat, label)

    test_error /= num_bags

    if len(np.unique(all_labels)) == 1:
        # roc_auc_score / roc_curve cannot score a single-class label set.
        auc = -1
    elif n_classes == 2:
        auc = roc_auc_score(all_labels, all_probs[:, 1])
    else:
        binary_labels = label_binarize(all_labels, classes=list(range(n_classes)))
        aucs = []
        for class_idx in range(n_classes):
            if class_idx in all_labels:
                fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx])
                aucs.append(calc_auc(fpr, tpr))
            else:
                aucs.append(float('nan'))
        auc = np.nanmean(np.array(aucs))

    return patient_results, test_error, auc, acc_logger
103 |
def train_mil(datasets,
              save_path='./save_weights/camelyon16_transmil_imagenet/',
              feature_dim=512,
              n_classes=2,
              fold=0,
              writer_flag=True,
              max_epoch=200,
              early_stopping=True,
              patience=20,
              min_epochs=50,
              ):
    """Train a mean-pooling MIL model on one fold and evaluate it on the test split.

    Args:
        datasets: ``(train_split, val_split, test_split)`` tuple.
        save_path: directory for checkpoints; TensorBoard logs go to
            ``save_path/<fold>``. NOTE(review): the default points at a
            "transmil" directory — pass an explicit path to avoid mixing runs.
        feature_dim: input feature dimension of each instance embedding.
        n_classes: number of output classes.
        fold: fold index, used in checkpoint and log names.
        writer_flag: when True, log scalars with tensorboardX.
        max_epoch: maximum number of training epochs.
        early_stopping: stop once validation loss has not improved for more than
            ``patience`` epochs and more than ``min_epochs`` epochs have run.
        patience: non-improving epochs tolerated before early stopping.
        min_epochs: stopping is only allowed after this epoch index.

    Returns:
        The ``summary`` result ``(patient_results, test_error, auc, acc_logger)``
        for the best validation-loss checkpoint, evaluated on the test split.
    """
    writer_dir = os.path.join(save_path, str(fold))
    os.makedirs(writer_dir, exist_ok=True)
    if writer_flag:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(writer_dir, flush_secs=15)
    else:
        writer = None

    try:
        print("\nInit train/val/test splits...")
        train_split, val_split, test_split = datasets
        print("Training on {} samples".format(len(train_split)))
        print("Validating on {} samples".format(len(val_split)))
        print("Testing on {} samples".format(len(test_split)))

        print("\nInit loss function...")
        loss_fn = nn.CrossEntropyLoss()
        model = MIL_MeanPooling(embed_dim=feature_dim,
                                n_classes=n_classes)
        _ = model.to(device)

        print("\nInit optimizer")
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4, momentum=0.9, weight_decay=1e-5)

        print('\nInit Loaders...', end=' ')
        train_loader = get_split_loader(train_split, training=True, testing=False, weighted=True)
        val_loader = get_split_loader(val_split, testing=False)
        test_loader = get_split_loader(test_split, testing=False)
        print('Done!')

        best_path = os.path.join(save_path, 's_{}_checkpoint.pt'.format(fold))
        # inf (not a large constant) so the first finite validation loss is always kept.
        mini_loss = float('inf')
        saved_best = False
        retain = 0
        for epoch in range(max_epoch):
            train_loop(epoch, model, train_loader, optimizer, n_classes, writer, loss_fn)
            loss = validate(epoch, model, val_loader, n_classes, writer, loss_fn)
            if epoch % 20 == 0:
                torch.save(model.state_dict(), os.path.join(save_path, 's_{}_checkpoint_{}.pt'.format(fold, epoch)))
            if loss < mini_loss:
                print("loss decrease from:{} to {}".format(mini_loss, loss))
                torch.save(model.state_dict(), best_path)
                mini_loss = loss
                saved_best = True
                retain = 0
            else:
                retain += 1
            print("Retain of early stopping: {} / {}".format(retain, patience))
            if early_stopping and retain > patience and epoch > min_epochs:
                print("Early stopping")
                break

        if not saved_best:
            # Validation loss never improved (e.g. NaN every epoch, or max_epoch == 0).
            # Save the final weights so we neither crash nor load a stale file from an older run.
            torch.save(model.state_dict(), best_path)

        model.load_state_dict(torch.load(best_path, map_location=device))
        results = summary(model, test_loader, n_classes)
        _, test_error, test_auc, _ = results
        print('\nTest Set, test_error: {:.4f}, auc: {:.4f}'.format(test_error, test_auc))
        return results
    finally:
        # Flush and release the TensorBoard event file even if training raises.
        if writer is not None:
            writer.close()
165 |
def train_loop(epoch, model, loader, optimizer, n_classes, writer, loss_fn):
    """Run one training epoch and report the mean loss and per-class accuracy.

    Args:
        epoch: current epoch index (used for logging only).
        model: MIL model whose forward returns ``(logits, Y_prob, Y_hat, y_probs, _)``.
        loader: iterable of ``(data, label, coords, nearest)``; only ``data`` and
            ``label`` are used. ``label.item()`` assumes one bag per batch — TODO confirm.
        optimizer: optimizer stepped once per batch.
        n_classes: number of classes tracked by the accuracy logger.
        writer: optional tensorboardX ``SummaryWriter``; falsy disables logging.
        loss_fn: criterion applied to ``(logits, label)``.
    """
    model.train()
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    train_loss = 0.
    bag_loss = 0.

    print('\n')
    for batch_idx, (data, label, coords, nearest) in enumerate(loader):
        data, label = data.to(device), label.to(device)

        logits, Y_prob, Y_hat, y_probs, _ = model(data)
        acc_logger.log(Y_hat, label)
        loss_bag = loss_fn(logits, label)
        loss = loss_bag

        loss_bag_value = loss_bag.item()
        loss_value = loss.item()

        train_loss += loss_value
        # Accumulate the Python float: adding the loss tensor itself would chain
        # every batch's autograd history into bag_loss for the whole epoch.
        bag_loss += loss_bag_value

        if (batch_idx + 1) % 20 == 0:
            print('batch {}, loss: {:.4f}, loss_bag: {:.4f}, label: {}, bag_size: {}'.format(
                batch_idx, loss_value, loss_bag_value, label.item(), data.size(0)))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Avoid ZeroDivisionError on an empty loader.
    n_batches = max(len(loader), 1)
    train_loss /= n_batches
    bag_loss /= n_batches

    print('Epoch: {}, train_loss: {:.4f}, bag_loss: {:.4f}, '.format(epoch, train_loss, bag_loss))

    for i in range(n_classes):
        acc, correct, count = acc_logger.get_summary(i)
        print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))
        # acc is None for a class with no samples this epoch; there is nothing to log.
        if writer and acc is not None:
            writer.add_scalar('train/class_{}_acc'.format(i), acc, epoch)
    if writer:
        writer.add_scalar('train/loss', train_loss, epoch)
        # Log the epoch mean, not the last batch's loss tensor.
        writer.add_scalar('train/loss_bag', bag_loss, epoch)
207 |
208 |
def validate(epoch, model, loader, n_classes, writer, loss_fn):
    """Evaluate ``model`` on the validation loader and return the mean loss.

    Also prints (and optionally logs) the validation ROC AUC and per-class accuracy.

    Args:
        epoch: current epoch index (used for logging only).
        model: MIL model whose forward returns ``(logits, Y_prob, Y_hat, y_probs, _)``.
        loader: iterable of ``(data, label, coords, nearest)``, assumed to hold one
            bag per batch (``label.item()``) — TODO confirm.
        n_classes: number of output classes.
        writer: optional tensorboardX ``SummaryWriter``; falsy disables logging.
        loss_fn: criterion applied to ``(logits, label)``.

    Returns:
        Mean validation loss over batches (used for checkpoint selection).
    """
    model.eval()
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    val_loss = 0.

    prob = np.zeros((len(loader), n_classes))
    labels = np.zeros(len(loader))

    with torch.no_grad():
        for batch_idx, (data, label, coords, nearest) in enumerate(loader):
            data, label = data.to(device, non_blocking=True), label.to(device, non_blocking=True)

            logits, Y_prob, Y_hat, y_probs, _ = model(data)
            acc_logger.log(Y_hat, label)

            loss = loss_fn(logits, label)

            prob[batch_idx] = Y_prob.cpu().numpy()
            labels[batch_idx] = label.item()

            val_loss += loss.item()

    val_loss /= len(loader)

    present = np.unique(labels)
    if len(present) < 2:
        # roc_auc_score raises when only one class is present; AUC is undefined here.
        auc = float('nan')
    elif n_classes == 2:
        auc = roc_auc_score(labels, prob[:, 1])
    else:
        # Macro one-vs-rest AUC over the classes present. This equals
        # roc_auc_score(labels, prob, multi_class='ovr') when every class occurs,
        # but does not raise when a small validation split is missing a class.
        binary_labels = label_binarize(labels, classes=list(range(n_classes)))
        aucs = [
            calc_auc(*roc_curve(binary_labels[:, class_idx], prob[:, class_idx])[:2])
            for class_idx in range(n_classes)
            if class_idx in present
        ]
        auc = float(np.mean(aucs))

    if writer:
        writer.add_scalar('val/loss', val_loss, epoch)
        writer.add_scalar('val/auc', auc, epoch)
    print('\nVal Set, val_loss: {:.4f}, auc: {:.4f}'.format(val_loss, auc))

    for i in range(n_classes):
        acc, correct, count = acc_logger.get_summary(i)
        print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))

    return val_loss
248 |
if __name__ == "__main__":
    # Hard-coded experiment configuration: 5-fold UniToPatho, 6 classes, mean-pooling MIL.
    dataset_csv = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/unitopatho_train.csv'
    feature_dir = '/data2/clh/unitopatho/resnet18_imagenet/'
    split_dir = '/data1/ceiling/workspace/AttriMIL_v2/splits/unitopatho/'
    save_dir = './save_weights/unitopatho_mean_imagenet/'
    label_map = {'NORM': 0, 'HP': 1, 'TA.HG': 2, 'TA.LG': 3, 'TVA.HG': 4, 'TVA.LG': 5}

    dataset = Generic_MIL_Dataset(csv_path=dataset_csv,
                                  data_dir=feature_dir,
                                  shuffle=False,
                                  seed=1,
                                  print_info=True,
                                  label_dict=label_map,
                                  patient_strat=False,
                                  ignore=[])

    # One fold per split file: splits_0.csv .. splits_4.csv.
    split_files = [split_dir + 'splits_{}.csv'.format(k) for k in range(5)]
    for fold_idx, split_csv in enumerate(split_files):
        train_part, val_part, test_part = dataset.return_splits(from_id=False, csv_path=split_csv)
        train_mil((train_part, val_part, test_part),
                  save_path=save_dir,
                  feature_dim=512,
                  n_classes=6,
                  fold=fold_idx,
                  writer_flag=True,
                  max_epoch=200,
                  early_stopping=True)
--------------------------------------------------------------------------------
/trainer_transmil.py:
--------------------------------------------------------------------------------
1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset
2 |
3 | import os
4 | import torch
5 | from torch import nn
6 | import torch.optim as optim
7 | import pdb
8 | import torch.nn.functional as F
9 |
10 | import pandas as pd
11 | import numpy as np
12 |
13 | from sklearn.preprocessing import label_binarize
14 | from sklearn.metrics import roc_auc_score, roc_curve
15 | from sklearn.metrics import auc as calc_auc
16 |
17 | from models.TransMIL import TransMIL
18 | from utils import *
19 |
20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21 |
22 |
class Accuracy_Logger(object):
    """Keep per-class tallies of how many samples were seen and how many were right.

    ``self.data[c]`` is a dict with two keys: ``"count"`` (samples whose true
    label is ``c``) and ``"correct"`` (how many of those were predicted as ``c``).
    """

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.initialize()

    def initialize(self):
        """Reset every class tally to zero."""
        self.data = [dict(count=0, correct=0) for _ in range(self.n_classes)]

    def log(self, Y_hat, Y):
        """Record one prediction ``Y_hat`` against its true label ``Y``."""
        pred = int(Y_hat)
        truth = int(Y)
        bucket = self.data[truth]
        bucket["count"] += 1
        bucket["correct"] += (pred == truth)

    def log_batch(self, Y_hat, Y):
        """Record a batch of predictions ``Y_hat`` against true labels ``Y``."""
        preds = np.array(Y_hat).astype(int)
        truths = np.array(Y).astype(int)
        for cls in np.unique(truths):
            in_cls = truths == cls
            bucket = self.data[cls]
            bucket["count"] += in_cls.sum()
            bucket["correct"] += (preds[in_cls] == truths[in_cls]).sum()

    def get_summary(self, c):
        """Return ``(accuracy, correct, count)`` for class ``c``.

        ``accuracy`` is ``None`` when no sample of class ``c`` has been logged.
        """
        tally = self.data[c]
        count, correct = tally["count"], tally["correct"]
        acc = None if count == 0 else float(correct) / count
        return acc, correct, count
57 |
def summary(model, loader, n_classes):
    """Evaluate ``model`` on every slide in ``loader`` and collect test metrics.

    Args:
        model: TransMIL model whose forward returns ``(logits, Y_prob, Y_hat, h0)``.
        loader: loader yielding ``(data, label, coords, nearest)``. Slide ids are
            read from ``loader.dataset.slide_data['slide_id']`` by batch position,
            which assumes one slide per batch in dataset order — TODO confirm.
        n_classes: number of output classes.

    Returns:
        ``(patient_results, test_error, auc, acc_logger)``:

        - ``patient_results`` maps slide id to ``{'slide_id', 'prob', 'label'}``.
        - ``test_error`` is the mean of ``calculate_error`` over slides.
        - ``auc`` is the binary ROC AUC (2 classes), or the mean one-vs-rest AUC
          over the classes present in the labels. It is NaN when fewer than two
          classes are present, because the AUC is undefined then.
        - ``acc_logger`` holds per-class accuracy.
    """
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    model.eval()
    test_error = 0.

    all_probs = np.zeros((len(loader), n_classes))
    all_labels = np.zeros(len(loader))

    slide_ids = loader.dataset.slide_data['slide_id']
    patient_results = {}

    for batch_idx, (data, label, coords, nearest) in enumerate(loader):
        data, label = data.to(device), label.to(device)
        slide_id = slide_ids.iloc[batch_idx]
        with torch.inference_mode():
            logits, Y_prob, Y_hat, h0 = model(data)
        acc_logger.log(Y_hat, label)

        probs = Y_prob.cpu().numpy()
        all_probs[batch_idx] = probs
        all_labels[batch_idx] = label.item()

        patient_results[slide_id] = {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}
        test_error += calculate_error(Y_hat, label)

    test_error /= len(loader)

    present = np.unique(all_labels)
    if len(present) < 2:
        # sklearn's roc_auc_score raises on a single-class split; report NaN instead.
        auc = float('nan')
    elif n_classes == 2:
        auc = roc_auc_score(all_labels, all_probs[:, 1])
    else:
        # Macro one-vs-rest AUC; classes absent from the labels are skipped (NaN).
        binary_labels = label_binarize(all_labels, classes=list(range(n_classes)))
        aucs = []
        for class_idx in range(n_classes):
            if class_idx in present:
                fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx])
                aucs.append(calc_auc(fpr, tpr))
            else:
                aucs.append(float('nan'))
        auc = np.nanmean(np.array(aucs))

    return patient_results, test_error, auc, acc_logger
103 |
def train_transmil(datasets,
                   save_path='./save_weights/camelyon16_transmil_imagenet/',
                   feature_dim=512,
                   n_classes=2,
                   fold=0,
                   writer_flag=True,
                   max_epoch=200,
                   early_stopping=True,
                   patience=20,
                   min_epochs=50,
                   ):
    """Train a TransMIL model on one fold and evaluate it on the test split.

    Args:
        datasets: ``(train_split, val_split, test_split)`` tuple.
        save_path: directory for checkpoints; TensorBoard logs go to ``save_path/<fold>``.
        feature_dim: input feature dimension of each instance embedding.
        n_classes: number of output classes.
        fold: fold index, used in checkpoint and log names.
        writer_flag: when True, log scalars with tensorboardX.
        max_epoch: maximum number of training epochs.
        early_stopping: stop once validation loss has not improved for more than
            ``patience`` epochs and more than ``min_epochs`` epochs have run.
        patience: non-improving epochs tolerated before early stopping.
        min_epochs: stopping is only allowed after this epoch index.

    Returns:
        The ``summary`` result ``(patient_results, test_error, auc, acc_logger)``
        for the best validation-loss checkpoint, evaluated on the test split.
    """
    writer_dir = os.path.join(save_path, str(fold))
    os.makedirs(writer_dir, exist_ok=True)
    if writer_flag:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(writer_dir, flush_secs=15)
    else:
        writer = None

    try:
        print("\nInit train/val/test splits...")
        train_split, val_split, test_split = datasets
        print("Training on {} samples".format(len(train_split)))
        print("Validating on {} samples".format(len(val_split)))
        print("Testing on {} samples".format(len(test_split)))

        print("\nInit loss function...")
        loss_fn = nn.CrossEntropyLoss()

        model = TransMIL(dim=feature_dim,
                         n_classes=n_classes)
        _ = model.to(device)

        print("\nInit optimizer")
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4, momentum=0.9, weight_decay=1e-5)

        print('\nInit Loaders...', end=' ')
        train_loader = get_split_loader(train_split, training=True, testing=False, weighted=True)
        val_loader = get_split_loader(val_split, testing=False)
        test_loader = get_split_loader(test_split, testing=False)
        print('Done!')

        best_path = os.path.join(save_path, 's_{}_checkpoint.pt'.format(fold))
        # inf (not a large constant) so the first finite validation loss is always kept.
        mini_loss = float('inf')
        saved_best = False
        retain = 0
        for epoch in range(max_epoch):
            train_loop(epoch, model, train_loader, optimizer, n_classes, writer, loss_fn)
            loss = validate(epoch, model, val_loader, n_classes, writer, loss_fn)
            if epoch % 20 == 0:
                torch.save(model.state_dict(), os.path.join(save_path, 's_{}_checkpoint_{}.pt'.format(fold, epoch)))
            if loss < mini_loss:
                print("loss decrease from:{} to {}".format(mini_loss, loss))
                torch.save(model.state_dict(), best_path)
                mini_loss = loss
                saved_best = True
                retain = 0
            else:
                retain += 1
            print("Retain of early stopping: {} / {}".format(retain, patience))
            if early_stopping and retain > patience and epoch > min_epochs:
                print("Early stopping")
                break

        if not saved_best:
            # Validation loss never improved (e.g. NaN every epoch, or max_epoch == 0).
            # Save the final weights so we neither crash nor load a stale file from an older run.
            torch.save(model.state_dict(), best_path)

        model.load_state_dict(torch.load(best_path, map_location=device))
        results = summary(model, test_loader, n_classes)
        _, test_error, test_auc, _ = results
        print('\nTest Set, test_error: {:.4f}, auc: {:.4f}'.format(test_error, test_auc))
        return results
    finally:
        # Flush and release the TensorBoard event file even if training raises.
        if writer is not None:
            writer.close()
166 |
def train_loop(epoch, model, loader, optimizer, n_classes, writer, loss_fn):
    """Run one training epoch and report the mean loss and per-class accuracy.

    Args:
        epoch: current epoch index (used for logging only).
        model: TransMIL model whose forward returns ``(logits, Y_prob, Y_hat, h0)``.
        loader: iterable of ``(data, label, coords, nearest)``; only ``data`` and
            ``label`` are used. ``label.item()`` assumes one bag per batch — TODO confirm.
        optimizer: optimizer stepped once per batch.
        n_classes: number of classes tracked by the accuracy logger.
        writer: optional tensorboardX ``SummaryWriter``; falsy disables logging.
        loss_fn: criterion applied to ``(logits, label)``.
    """
    model.train()
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    train_loss = 0.
    bag_loss = 0.

    print('\n')
    for batch_idx, (data, label, coords, nearest) in enumerate(loader):
        data, label = data.to(device), label.to(device)

        logits, Y_prob, Y_hat, h0 = model(data)
        acc_logger.log(Y_hat, label)
        loss_bag = loss_fn(logits, label)
        loss = loss_bag

        loss_bag_value = loss_bag.item()
        loss_value = loss.item()

        train_loss += loss_value
        # Accumulate the Python float: adding the loss tensor itself would chain
        # every batch's autograd history into bag_loss for the whole epoch.
        bag_loss += loss_bag_value

        if (batch_idx + 1) % 20 == 0:
            print('batch {}, loss: {:.4f}, loss_bag: {:.4f}, label: {}, bag_size: {}'.format(
                batch_idx, loss_value, loss_bag_value, label.item(), data.size(0)))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Avoid ZeroDivisionError on an empty loader.
    n_batches = max(len(loader), 1)
    train_loss /= n_batches
    bag_loss /= n_batches

    print('Epoch: {}, train_loss: {:.4f}, bag_loss: {:.4f}, '.format(epoch, train_loss, bag_loss))

    for i in range(n_classes):
        acc, correct, count = acc_logger.get_summary(i)
        print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))
        # acc is None for a class with no samples this epoch; there is nothing to log.
        if writer and acc is not None:
            writer.add_scalar('train/class_{}_acc'.format(i), acc, epoch)
    if writer:
        writer.add_scalar('train/loss', train_loss, epoch)
        # Log the epoch mean, not the last batch's loss tensor.
        writer.add_scalar('train/loss_bag', bag_loss, epoch)
208 |
209 |
def validate(epoch, model, loader, n_classes, writer, loss_fn):
    """Evaluate ``model`` on the validation loader and return the mean loss.

    Also prints (and optionally logs) the validation ROC AUC and per-class accuracy.

    Args:
        epoch: current epoch index (used for logging only).
        model: TransMIL model whose forward returns ``(logits, Y_prob, Y_hat, h0)``.
        loader: iterable of ``(data, label, coords, nearest)``, assumed to hold one
            bag per batch (``label.item()``) — TODO confirm.
        n_classes: number of output classes.
        writer: optional tensorboardX ``SummaryWriter``; falsy disables logging.
        loss_fn: criterion applied to ``(logits, label)``.

    Returns:
        Mean validation loss over batches (used for checkpoint selection).
    """
    model.eval()
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    val_loss = 0.

    prob = np.zeros((len(loader), n_classes))
    labels = np.zeros(len(loader))

    with torch.no_grad():
        for batch_idx, (data, label, coords, nearest) in enumerate(loader):
            data, label = data.to(device, non_blocking=True), label.to(device, non_blocking=True)

            logits, Y_prob, Y_hat, h0 = model(data)
            acc_logger.log(Y_hat, label)

            loss = loss_fn(logits, label)

            prob[batch_idx] = Y_prob.cpu().numpy()
            labels[batch_idx] = label.item()

            val_loss += loss.item()

    val_loss /= len(loader)

    present = np.unique(labels)
    if len(present) < 2:
        # roc_auc_score raises when only one class is present; AUC is undefined here.
        auc = float('nan')
    elif n_classes == 2:
        auc = roc_auc_score(labels, prob[:, 1])
    else:
        # Macro one-vs-rest AUC over the classes present. This equals
        # roc_auc_score(labels, prob, multi_class='ovr') when every class occurs,
        # but does not raise when a small validation split is missing a class.
        binary_labels = label_binarize(labels, classes=list(range(n_classes)))
        aucs = [
            calc_auc(*roc_curve(binary_labels[:, class_idx], prob[:, class_idx])[:2])
            for class_idx in range(n_classes)
            if class_idx in present
        ]
        auc = float(np.mean(aucs))

    if writer:
        writer.add_scalar('val/loss', val_loss, epoch)
        writer.add_scalar('val/auc', auc, epoch)
    print('\nVal Set, val_loss: {:.4f}, auc: {:.4f}'.format(val_loss, auc))

    for i in range(n_classes):
        acc, correct, count = acc_logger.get_summary(i)
        print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))

    return val_loss
249 |
if __name__ == "__main__":
    # Hard-coded experiment configuration: 5-fold UniToPatho, 6 classes, TransMIL.
    dataset_csv = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/unitopatho_train.csv'
    feature_dir = '/data2/clh/unitopatho/resnet18_imagenet/'
    split_dir = '/data1/ceiling/workspace/AttriMIL_v2/splits/unitopatho/'
    save_dir = './save_weights/unitopatho_transmil_imagenet/'
    label_map = {'NORM': 0, 'HP': 1, 'TA.HG': 2, 'TA.LG': 3, 'TVA.HG': 4, 'TVA.LG': 5}

    dataset = Generic_MIL_Dataset(csv_path=dataset_csv,
                                  data_dir=feature_dir,
                                  shuffle=False,
                                  seed=1,
                                  print_info=True,
                                  label_dict=label_map,
                                  patient_strat=False,
                                  ignore=[])

    # One fold per split file: splits_0.csv .. splits_4.csv.
    split_files = [split_dir + 'splits_{}.csv'.format(k) for k in range(5)]
    for fold_idx, split_csv in enumerate(split_files):
        train_part, val_part, test_part = dataset.return_splits(from_id=False, csv_path=split_csv)
        train_transmil((train_part, val_part, test_part),
                       save_path=save_dir,
                       feature_dim=512,
                       n_classes=6,
                       fold=fold_idx,
                       writer_flag=True,
                       max_epoch=200,
                       early_stopping=True)
--------------------------------------------------------------------------------
/datasets/camelyon17_seen.csv:
--------------------------------------------------------------------------------
1 | dir_name,case_id,slide_id,label
2 | /data/ceiling/data/MIL/camelyon17/,patient_070,patient_070_node_0,normal_tissue
3 | /data/ceiling/data/MIL/camelyon17/,patient_070,patient_070_node_1,normal_tissue
4 | /data/ceiling/data/MIL/camelyon17/,patient_070,patient_070_node_2,normal_tissue
5 | /data/ceiling/data/MIL/camelyon17/,patient_070,patient_070_node_3,tumor_tissue
6 | /data/ceiling/data/MIL/camelyon17/,patient_070,patient_070_node_4,normal_tissue
7 | /data/ceiling/data/MIL/camelyon17/,patient_071,patient_071_node_0,normal_tissue
8 | /data/ceiling/data/MIL/camelyon17/,patient_071,patient_071_node_1,normal_tissue
9 | /data/ceiling/data/MIL/camelyon17/,patient_071,patient_071_node_2,normal_tissue
10 | /data/ceiling/data/MIL/camelyon17/,patient_071,patient_071_node_3,normal_tissue
11 | /data/ceiling/data/MIL/camelyon17/,patient_071,patient_071_node_4,normal_tissue
12 | /data/ceiling/data/MIL/camelyon17/,patient_072,patient_072_node_1,tumor_tissue
13 | /data/ceiling/data/MIL/camelyon17/,patient_072,patient_072_node_2,tumor_tissue
14 | /data/ceiling/data/MIL/camelyon17/,patient_072,patient_072_node_3,normal_tissue
15 | /data/ceiling/data/MIL/camelyon17/,patient_072,patient_072_node_4,tumor_tissue
16 | /data/ceiling/data/MIL/camelyon17/,patient_073,patient_073_node_0,normal_tissue
17 | /data/ceiling/data/MIL/camelyon17/,patient_073,patient_073_node_1,tumor_tissue
18 | /data/ceiling/data/MIL/camelyon17/,patient_073,patient_073_node_2,normal_tissue
19 | /data/ceiling/data/MIL/camelyon17/,patient_073,patient_073_node_3,tumor_tissue
20 | /data/ceiling/data/MIL/camelyon17/,patient_073,patient_073_node_4,normal_tissue
21 | /data/ceiling/data/MIL/camelyon17/,patient_074,patient_074_node_0,normal_tissue
22 | /data/ceiling/data/MIL/camelyon17/,patient_074,patient_074_node_1,normal_tissue
23 | /data/ceiling/data/MIL/camelyon17/,patient_074,patient_074_node_2,normal_tissue
24 | /data/ceiling/data/MIL/camelyon17/,patient_074,patient_074_node_3,normal_tissue
25 | /data/ceiling/data/MIL/camelyon17/,patient_074,patient_074_node_4,tumor_tissue
26 | /data/ceiling/data/MIL/camelyon17/,patient_075,patient_075_node_0,normal_tissue
27 | /data/ceiling/data/MIL/camelyon17/,patient_075,patient_075_node_1,normal_tissue
28 | /data/ceiling/data/MIL/camelyon17/,patient_075,patient_075_node_2,normal_tissue
29 | /data/ceiling/data/MIL/camelyon17/,patient_075,patient_075_node_3,normal_tissue
30 | /data/ceiling/data/MIL/camelyon17/,patient_075,patient_075_node_4,tumor_tissue
31 | /data/ceiling/data/MIL/camelyon17/,patient_076,patient_076_node_0,normal_tissue
32 | /data/ceiling/data/MIL/camelyon17/,patient_076,patient_076_node_1,tumor_tissue
33 | /data/ceiling/data/MIL/camelyon17/,patient_076,patient_076_node_2,tumor_tissue
34 | /data/ceiling/data/MIL/camelyon17/,patient_076,patient_076_node_3,tumor_tissue
35 | /data/ceiling/data/MIL/camelyon17/,patient_077,patient_077_node_0,normal_tissue
36 | /data/ceiling/data/MIL/camelyon17/,patient_077,patient_077_node_1,normal_tissue
37 | /data/ceiling/data/MIL/camelyon17/,patient_077,patient_077_node_2,tumor_tissue
38 | /data/ceiling/data/MIL/camelyon17/,patient_077,patient_077_node_3,normal_tissue
39 | /data/ceiling/data/MIL/camelyon17/,patient_077,patient_077_node_4,tumor_tissue
40 | /data/ceiling/data/MIL/camelyon17/,patient_078,patient_078_node_0,normal_tissue
41 | /data/ceiling/data/MIL/camelyon17/,patient_078,patient_078_node_1,normal_tissue
42 | /data/ceiling/data/MIL/camelyon17/,patient_078,patient_078_node_2,normal_tissue
43 | /data/ceiling/data/MIL/camelyon17/,patient_078,patient_078_node_3,normal_tissue
44 | /data/ceiling/data/MIL/camelyon17/,patient_078,patient_078_node_4,normal_tissue
45 | /data/ceiling/data/MIL/camelyon17/,patient_079,patient_079_node_0,normal_tissue
46 | /data/ceiling/data/MIL/camelyon17/,patient_079,patient_079_node_1,normal_tissue
47 | /data/ceiling/data/MIL/camelyon17/,patient_079,patient_079_node_2,normal_tissue
48 | /data/ceiling/data/MIL/camelyon17/,patient_079,patient_079_node_3,normal_tissue
49 | /data/ceiling/data/MIL/camelyon17/,patient_079,patient_079_node_4,normal_tissue
50 | /data/ceiling/data/MIL/camelyon17/,patient_080,patient_080_node_2,tumor_tissue
51 | /data/ceiling/data/MIL/camelyon17/,patient_080,patient_080_node_3,tumor_tissue
52 | /data/ceiling/data/MIL/camelyon17/,patient_080,patient_080_node_4,tumor_tissue
53 | /data/ceiling/data/MIL/camelyon17/,patient_081,patient_081_node_0,normal_tissue
54 | /data/ceiling/data/MIL/camelyon17/,patient_081,patient_081_node_1,tumor_tissue
55 | /data/ceiling/data/MIL/camelyon17/,patient_081,patient_081_node_2,tumor_tissue
56 | /data/ceiling/data/MIL/camelyon17/,patient_081,patient_081_node_3,normal_tissue
57 | /data/ceiling/data/MIL/camelyon17/,patient_082,patient_082_node_0,normal_tissue
58 | /data/ceiling/data/MIL/camelyon17/,patient_082,patient_082_node_1,normal_tissue
59 | /data/ceiling/data/MIL/camelyon17/,patient_082,patient_082_node_2,normal_tissue
60 | /data/ceiling/data/MIL/camelyon17/,patient_082,patient_082_node_3,normal_tissue
61 | /data/ceiling/data/MIL/camelyon17/,patient_082,patient_082_node_4,normal_tissue
62 | /data/ceiling/data/MIL/camelyon17/,patient_083,patient_083_node_0,normal_tissue
63 | /data/ceiling/data/MIL/camelyon17/,patient_083,patient_083_node_1,normal_tissue
64 | /data/ceiling/data/MIL/camelyon17/,patient_083,patient_083_node_2,normal_tissue
65 | /data/ceiling/data/MIL/camelyon17/,patient_083,patient_083_node_3,normal_tissue
66 | /data/ceiling/data/MIL/camelyon17/,patient_083,patient_083_node_4,normal_tissue
67 | /data/ceiling/data/MIL/camelyon17/,patient_084,patient_084_node_0,normal_tissue
68 | /data/ceiling/data/MIL/camelyon17/,patient_084,patient_084_node_1,normal_tissue
69 | /data/ceiling/data/MIL/camelyon17/,patient_084,patient_084_node_2,tumor_tissue
70 | /data/ceiling/data/MIL/camelyon17/,patient_084,patient_084_node_3,tumor_tissue
71 | /data/ceiling/data/MIL/camelyon17/,patient_084,patient_084_node_4,normal_tissue
72 | /data/ceiling/data/MIL/camelyon17/,patient_085,patient_085_node_0,normal_tissue
73 | /data/ceiling/data/MIL/camelyon17/,patient_085,patient_085_node_1,normal_tissue
74 | /data/ceiling/data/MIL/camelyon17/,patient_085,patient_085_node_2,normal_tissue
75 | /data/ceiling/data/MIL/camelyon17/,patient_085,patient_085_node_3,normal_tissue
76 | /data/ceiling/data/MIL/camelyon17/,patient_085,patient_085_node_4,normal_tissue
77 | /data/ceiling/data/MIL/camelyon17/,patient_086,patient_086_node_1,normal_tissue
78 | /data/ceiling/data/MIL/camelyon17/,patient_086,patient_086_node_2,normal_tissue
79 | /data/ceiling/data/MIL/camelyon17/,patient_086,patient_086_node_3,normal_tissue
80 | /data/ceiling/data/MIL/camelyon17/,patient_087,patient_087_node_2,normal_tissue
81 | /data/ceiling/data/MIL/camelyon17/,patient_087,patient_087_node_3,normal_tissue
82 | /data/ceiling/data/MIL/camelyon17/,patient_087,patient_087_node_4,normal_tissue
83 | /data/ceiling/data/MIL/camelyon17/,patient_088,patient_088_node_0,normal_tissue
84 | /data/ceiling/data/MIL/camelyon17/,patient_088,patient_088_node_1,tumor_tissue
85 | /data/ceiling/data/MIL/camelyon17/,patient_088,patient_088_node_2,normal_tissue
86 | /data/ceiling/data/MIL/camelyon17/,patient_088,patient_088_node_3,normal_tissue
87 | /data/ceiling/data/MIL/camelyon17/,patient_089,patient_089_node_0,normal_tissue
88 | /data/ceiling/data/MIL/camelyon17/,patient_089,patient_089_node_1,normal_tissue
89 | /data/ceiling/data/MIL/camelyon17/,patient_089,patient_089_node_2,normal_tissue
90 | /data/ceiling/data/MIL/camelyon17/,patient_089,patient_089_node_3,tumor_tissue
91 | /data/ceiling/data/MIL/camelyon17/,patient_089,patient_089_node_4,normal_tissue
92 | /data/ceiling/data/MIL/camelyon17/,patient_090,patient_090_node_0,normal_tissue
93 | /data/ceiling/data/MIL/camelyon17/,patient_090,patient_090_node_1,normal_tissue
94 | /data/ceiling/data/MIL/camelyon17/,patient_090,patient_090_node_2,normal_tissue
95 | /data/ceiling/data/MIL/camelyon17/,patient_090,patient_090_node_3,normal_tissue
96 | /data/ceiling/data/MIL/camelyon17/,patient_090,patient_090_node_4,normal_tissue
97 | /data/ceiling/data/MIL/camelyon17/,patient_091,patient_091_node_0,normal_tissue
98 | /data/ceiling/data/MIL/camelyon17/,patient_091,patient_091_node_1,normal_tissue
99 | /data/ceiling/data/MIL/camelyon17/,patient_091,patient_091_node_2,tumor_tissue
100 | /data/ceiling/data/MIL/camelyon17/,patient_091,patient_091_node_3,tumor_tissue
101 | /data/ceiling/data/MIL/camelyon17/,patient_091,patient_091_node_4,tumor_tissue
102 | /data/ceiling/data/MIL/camelyon17/,patient_092,patient_092_node_0,tumor_tissue
103 | /data/ceiling/data/MIL/camelyon17/,patient_092,patient_092_node_1,tumor_tissue
104 | /data/ceiling/data/MIL/camelyon17/,patient_092,patient_092_node_2,normal_tissue
105 | /data/ceiling/data/MIL/camelyon17/,patient_092,patient_092_node_3,tumor_tissue
106 | /data/ceiling/data/MIL/camelyon17/,patient_092,patient_092_node_4,tumor_tissue
107 | /data/ceiling/data/MIL/camelyon17/,patient_093,patient_093_node_0,normal_tissue
108 | /data/ceiling/data/MIL/camelyon17/,patient_093,patient_093_node_1,normal_tissue
109 | /data/ceiling/data/MIL/camelyon17/,patient_093,patient_093_node_2,normal_tissue
110 | /data/ceiling/data/MIL/camelyon17/,patient_093,patient_093_node_3,normal_tissue
111 | /data/ceiling/data/MIL/camelyon17/,patient_093,patient_093_node_4,tumor_tissue
112 | /data/ceiling/data/MIL/camelyon17/,patient_094,patient_094_node_0,tumor_tissue
113 | /data/ceiling/data/MIL/camelyon17/,patient_094,patient_094_node_1,tumor_tissue
114 | /data/ceiling/data/MIL/camelyon17/,patient_094,patient_094_node_2,tumor_tissue
115 | /data/ceiling/data/MIL/camelyon17/,patient_094,patient_094_node_3,normal_tissue
116 | /data/ceiling/data/MIL/camelyon17/,patient_094,patient_094_node_4,tumor_tissue
117 | /data/ceiling/data/MIL/camelyon17/,patient_095,patient_095_node_0,tumor_tissue
118 | /data/ceiling/data/MIL/camelyon17/,patient_095,patient_095_node_1,normal_tissue
119 | /data/ceiling/data/MIL/camelyon17/,patient_095,patient_095_node_2,normal_tissue
120 | /data/ceiling/data/MIL/camelyon17/,patient_095,patient_095_node_3,normal_tissue
121 | /data/ceiling/data/MIL/camelyon17/,patient_095,patient_095_node_4,normal_tissue
122 | /data/ceiling/data/MIL/camelyon17/,patient_096,patient_096_node_0,tumor_tissue
123 | /data/ceiling/data/MIL/camelyon17/,patient_096,patient_096_node_1,normal_tissue
124 | /data/ceiling/data/MIL/camelyon17/,patient_096,patient_096_node_2,tumor_tissue
125 | /data/ceiling/data/MIL/camelyon17/,patient_096,patient_096_node_3,tumor_tissue
126 | /data/ceiling/data/MIL/camelyon17/,patient_096,patient_096_node_4,tumor_tissue
127 | /data/ceiling/data/MIL/camelyon17/,patient_097,patient_097_node_0,tumor_tissue
128 | /data/ceiling/data/MIL/camelyon17/,patient_097,patient_097_node_1,tumor_tissue
129 | /data/ceiling/data/MIL/camelyon17/,patient_097,patient_097_node_2,normal_tissue
130 | /data/ceiling/data/MIL/camelyon17/,patient_097,patient_097_node_3,tumor_tissue
131 | /data/ceiling/data/MIL/camelyon17/,patient_097,patient_097_node_4,tumor_tissue
132 | /data/ceiling/data/MIL/camelyon17/,patient_098,patient_098_node_0,normal_tissue
133 | /data/ceiling/data/MIL/camelyon17/,patient_098,patient_098_node_1,normal_tissue
134 | /data/ceiling/data/MIL/camelyon17/,patient_098,patient_098_node_2,normal_tissue
135 | /data/ceiling/data/MIL/camelyon17/,patient_098,patient_098_node_3,normal_tissue
136 | /data/ceiling/data/MIL/camelyon17/,patient_098,patient_098_node_4,normal_tissue
137 | /data/ceiling/data/MIL/camelyon17/,patient_099,patient_099_node_0,normal_tissue
138 | /data/ceiling/data/MIL/camelyon17/,patient_099,patient_099_node_1,normal_tissue
139 | /data/ceiling/data/MIL/camelyon17/,patient_099,patient_099_node_2,normal_tissue
140 | /data/ceiling/data/MIL/camelyon17/,patient_099,patient_099_node_3,normal_tissue
141 | /data/ceiling/data/MIL/camelyon17/,patient_099,patient_099_node_4,tumor_tissue
142 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/trainer_dsmil.py:
--------------------------------------------------------------------------------
1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset
2 |
3 | import os
4 | import torch
5 | from torch import nn
6 | import torch.optim as optim
7 | import pdb
8 | import torch.nn.functional as F
9 |
10 | import pandas as pd
11 | import numpy as np
12 |
13 | from sklearn.preprocessing import label_binarize
14 | from sklearn.metrics import roc_auc_score, roc_curve
15 | from sklearn.metrics import auc as calc_auc
16 |
17 | from models.DSMIL import MILNet
18 | from utils import *
19 |
20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21 |
22 |
class Accuracy_Logger(object):
    """Per-class prediction tally.

    Keeps one bucket per class index; each bucket counts how many samples
    carried that label (``count``) and how many of them were predicted
    correctly (``correct``).
    """

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.initialize()

    def initialize(self):
        """Reset every class bucket to zero."""
        self.data = [dict(count=0, correct=0) for _ in range(self.n_classes)]

    def log(self, Y_hat, Y):
        """Record one prediction ``Y_hat`` against its ground-truth label ``Y``."""
        predicted, target = int(Y_hat), int(Y)
        bucket = self.data[target]
        bucket["count"] += 1
        bucket["correct"] += (predicted == target)

    def log_batch(self, Y_hat, Y):
        """Record a batch of predictions against the matching batch of labels."""
        predicted = np.array(Y_hat).astype(int)
        targets = np.array(Y).astype(int)
        for cls in np.unique(targets):
            mask = targets == cls
            bucket = self.data[cls]
            bucket["count"] += mask.sum()
            bucket["correct"] += (predicted[mask] == targets[mask]).sum()

    def get_summary(self, c):
        """Return ``(accuracy, correct, count)`` for class ``c``.

        ``accuracy`` is ``None`` when no sample of class ``c`` has been logged.
        """
        bucket = self.data[c]
        count, correct = bucket["count"], bucket["correct"]
        acc = None if count == 0 else float(correct) / count
        return acc, correct, count
57 |
def summary(model, loader, n_classes):
    """Evaluate a DSMIL model on every bag in ``loader``.

    Args:
        model: network whose forward returns four values, the second being the
            bag-level logits (unpacked as ``ins_prediction, bag_prediction, _, _``).
        loader: DataLoader yielding ``(data, label, coords, nearest)``. Bag ``i``
            is matched to ``loader.dataset.slide_data['slide_id'].iloc[i]``, which
            assumes one bag per batch in dataset order -- TODO confirm.
        n_classes: number of output classes.

    Returns:
        ``(patient_results, test_error, auc, acc_logger)``:
        ``patient_results`` maps slide id -> ``{'slide_id', 'prob', 'label'}``;
        ``test_error`` is the mean of ``calculate_error`` over bags (NaN for an
        empty loader); ``auc`` is the binary ROC AUC, or the NaN-ignoring mean of
        per-class one-vs-rest AUCs when ``n_classes > 2`` (NaN when undefined);
        ``acc_logger`` holds the per-class accuracy counts.
    """
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    model.eval()
    test_error = 0.

    all_probs = np.zeros((len(loader), n_classes))
    all_labels = np.zeros(len(loader))

    slide_ids = loader.dataset.slide_data['slide_id']
    patient_results = {}

    for batch_idx, (data, label, coords, nearest) in enumerate(loader):
        data, label = data.to(device), label.to(device)
        slide_id = slide_ids.iloc[batch_idx]
        with torch.inference_mode():
            _, bag_prediction, _, _ = model(data)
            Y_hat = torch.topk(bag_prediction.view(1, -1), 1, dim=1)[1]
            acc_logger.log(Y_hat, label)
            probs = F.softmax(bag_prediction, dim=-1).cpu().numpy()
            error = calculate_error(Y_hat, label)

        all_probs[batch_idx] = probs
        all_labels[batch_idx] = label.item()
        patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}})
        test_error += error

    test_error = test_error / len(loader) if len(loader) else float('nan')

    if n_classes == 2:
        if len(np.unique(all_labels)) == 2:
            auc = roc_auc_score(all_labels, all_probs[:, 1])
        else:
            # ROC AUC is undefined unless both classes occur; report NaN (as the
            # multi-class branch does for absent classes) instead of raising.
            auc = float('nan')
    else:
        aucs = []
        binary_labels = label_binarize(all_labels, classes=list(range(n_classes)))
        for class_idx in range(n_classes):
            if class_idx in all_labels:
                fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx])
                aucs.append(calc_auc(fpr, tpr))
            else:
                aucs.append(float('nan'))
        auc = np.nanmean(np.array(aucs))

    return patient_results, test_error, auc, acc_logger
107 |
def train_dsmil(datasets,
                save_path='./save_weights/camelyon16_dsmil_imagenet/',
                feature_dim=512,
                n_classes=2,
                fold=0,
                writer_flag=True,
                max_epoch=200,
                early_stopping=True,
                patience=20,
                stop_epoch=50,
                ):
    """Train DSMIL on one fold, keep the lowest-validation-loss weights, test them.

    Args:
        datasets: ``(train_split, val_split, test_split)``.
        save_path: directory for checkpoints; TensorBoard logs go to
            ``save_path/<fold>``.
        feature_dim: per-instance feature size passed to ``MILNet``.
        n_classes: number of bag classes.
        fold: fold index used in checkpoint and log names.
        writer_flag: log to TensorBoard (``tensorboardX``) when true.
        max_epoch: maximum number of training epochs.
        early_stopping: stop once validation loss has not improved for more
            than ``patience`` epochs and ``epoch > stop_epoch``.
        patience: early-stopping patience in epochs.
        stop_epoch: early stopping never triggers at or before this epoch.

    Returns:
        ``(patient_results, test_error, auc, acc_logger)`` from ``summary`` on
        the test split, evaluated with the best checkpoint.
    """
    writer_dir = os.path.join(save_path, str(fold))
    os.makedirs(writer_dir, exist_ok=True)
    if writer_flag:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(writer_dir, flush_secs=15)
    else:
        writer = None

    print("\nInit train/val/test splits...")
    train_split, val_split, test_split = datasets
    print("Training on {} samples".format(len(train_split)))
    print("Validating on {} samples".format(len(val_split)))
    print("Testing on {} samples".format(len(test_split)))

    print("\nInit loss function...")
    loss_fn = nn.CrossEntropyLoss()

    model = MILNet(feature_dim=feature_dim, n_classes=n_classes)
    model.to(device)

    print("\nInit optimizer")
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4, momentum=0.9, weight_decay=1e-5)

    print('\nInit Loaders...', end=' ')
    train_loader = get_split_loader(train_split, training=True, testing=False, weighted=True)
    val_loader = get_split_loader(val_split, testing=False)
    test_loader = get_split_loader(test_split, testing=False)
    print('Done!')

    best_ckpt = os.path.join(save_path, 's_{}_checkpoint.pt'.format(fold))
    mini_loss = float('inf')
    retain = 0
    try:
        for epoch in range(max_epoch):
            train_loop(epoch, model, train_loader, optimizer, n_classes, writer, loss_fn)
            loss = validate(epoch, model, val_loader, n_classes, writer, loss_fn)
            if epoch % 10 == 0:
                # Periodic snapshot, independent of validation performance.
                torch.save(model.state_dict(), os.path.join(save_path, 's_{}_checkpoint_{}.pt'.format(fold, epoch)))
            if loss < mini_loss:
                print("loss decrease from:{} to {}".format(mini_loss, loss))
                torch.save(model.state_dict(), best_ckpt)
                mini_loss = loss
                retain = 0
            else:
                retain += 1
                print("Retain of early stopping: {} / {}".format(retain, patience))
            if early_stopping and retain > patience and epoch > stop_epoch:
                print("Early stopping")
                break
    finally:
        # Flush and release the event file even if training raises.
        if writer:
            writer.close()

    if os.path.isfile(best_ckpt):
        model.load_state_dict(torch.load(best_ckpt, map_location=device))
    else:
        # Happens only if no validation loss ever compared below +inf (e.g. all NaN).
        print("No best checkpoint at {}; evaluating the final weights".format(best_ckpt))

    patient_results, test_error, test_auc, acc_logger = summary(model, test_loader, n_classes)
    print('\nTest Set, test_error: {:.4f}, auc: {:.4f}'.format(test_error, test_auc))
    for i in range(n_classes):
        acc, correct, count = acc_logger.get_summary(i)
        print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))
    return patient_results, test_error, test_auc, acc_logger
170 |
def train_loop(epoch, model, loader, optimizer, n_classes, writer, loss_fn):
    """Train DSMIL for one epoch, one bag per optimizer step.

    The per-bag loss is ``0.5 * CE(bag logits) + 0.5 * CE(max-pooled instance
    logits)``. Epoch-mean losses and per-class accuracy are printed and, when
    ``writer`` is given, logged under ``train/``.
    """
    model.train()
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    train_loss = 0.
    bag_loss = 0.
    ins_loss = 0.

    print('\n')
    for batch_idx, (data, label, coords, nearest) in enumerate(loader):
        data, label = data.to(device), label.to(device)

        ins_prediction, bag_prediction, _, _ = model(data)
        Y_hat = torch.topk(bag_prediction.view(1, -1), 1, dim=1)[1]
        acc_logger.log(Y_hat, label)
        # Max over dim 0 of the instance predictions feeds the instance-branch loss.
        max_prediction, _ = torch.max(ins_prediction, 0)
        loss_bag = loss_fn(bag_prediction.view(1, -1), label)
        loss_ins = loss_fn(max_prediction.view(1, -1), label)
        loss = 0.5 * loss_bag + 0.5 * loss_ins

        # Accumulate plain floats: adding the loss tensors themselves would keep
        # every step's autograd history attached to the running totals.
        loss_bag_value = loss_bag.item()
        loss_ins_value = loss_ins.item()
        loss_value = loss.item()

        train_loss += loss_value
        bag_loss += loss_bag_value
        ins_loss += loss_ins_value

        if (batch_idx + 1) % 20 == 0:
            print('batch {}, loss: {:.4f}, loss_bag: {:.4f}, loss_ins: {:.4f}, label: {}, bag_size: {}'.format(
                batch_idx, loss_value, loss_bag_value, loss_ins_value, label.item(), data.size(0)))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    num_batches = max(len(loader), 1)
    train_loss /= num_batches
    bag_loss /= num_batches
    ins_loss /= num_batches

    print('Epoch: {}, train_loss: {:.4f}, bag_loss: {:.4f}, ins_loss: {:.4f}'.format(epoch, train_loss, bag_loss, ins_loss))

    for i in range(n_classes):
        acc, correct, count = acc_logger.get_summary(i)
        print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))
        # get_summary returns acc=None for classes absent this epoch; nothing to plot.
        if writer and acc is not None:
            writer.add_scalar('train/class_{}_acc'.format(i), acc, epoch)
    if writer:
        writer.add_scalar('train/loss', train_loss, epoch)
        writer.add_scalar('train/loss_bag', bag_loss, epoch)
        writer.add_scalar('train/loss_ins', ins_loss, epoch)
220 |
221 |
def validate(epoch, model, loader, n_classes, writer, loss_fn):
    """Run one validation pass over ``loader`` and return the mean loss.

    The per-bag loss mirrors training: ``0.5 * CE(bag logits) + 0.5 *
    CE(max-pooled instance logits)``. Mean loss and AUC (binary, or one-vs-rest
    when ``n_classes > 2``) are printed and, when ``writer`` is given, logged
    under ``val/``; per-class accuracy is printed.
    """
    model.eval()
    logger = Accuracy_Logger(n_classes=n_classes)
    total = 0.
    num_bags = len(loader)
    scores = np.zeros((num_bags, n_classes))
    targets = np.zeros(num_bags)

    with torch.no_grad():
        for idx, (data, label, coords, nearest) in enumerate(loader):
            data = data.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)

            ins_prediction, bag_prediction, _, _ = model(data)
            bag_logits = bag_prediction.view(1, -1)
            logger.log(torch.topk(bag_logits, 1, dim=1)[1], label)

            top_instance, _ = torch.max(ins_prediction, 0)
            bag_term = loss_fn(bag_logits, label)
            ins_term = loss_fn(top_instance.view(1, -1), label)
            combined = 0.5 * bag_term + 0.5 * ins_term

            scores[idx] = F.softmax(bag_prediction, dim=-1).cpu().numpy()
            targets[idx] = label.item()
            total += combined.item()

    val_loss = total / num_bags
    auc = (roc_auc_score(targets, scores[:, 1]) if n_classes == 2
           else roc_auc_score(targets, scores, multi_class='ovr'))
    if writer:
        writer.add_scalar('val/loss', val_loss, epoch)
        writer.add_scalar('val/auc', auc, epoch)
    print('\nVal Set, val_loss: {:.4f}, auc: {:.4f}'.format(val_loss, auc))

    for c in range(n_classes):
        acc, correct, count = logger.get_summary(c)
        print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))

    return val_loss
265 |
if __name__ == "__main__":
    # Experiment configuration (machine-specific paths).
    csv_path = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/camelyon16_total.csv'
    data_dir = '/data2/clh/camelyon16/resnet18_imagenet/'
    split_path = '/data1/ceiling/workspace/AttriMIL_v2/splits/camelyon16_100/'
    save_dir = './save_weights/camelyon16_dsmil_imagenet_100/'
    label_map = {'normal_tissue': 0, 'tumor_tissue': 1}

    dataset = Generic_MIL_Dataset(
        csv_path=csv_path,
        data_dir=data_dir,
        shuffle=False,
        seed=1,
        print_info=True,
        label_dict=label_map,
        patient_strat=False,
        ignore=[],
    )

    # One fold per split file: splits_0.csv ... splits_4.csv.
    split_files = [split_path + 'splits_{}.csv'.format(k) for k in range(5)]
    for fold_idx, split_file in enumerate(split_files):
        train_set, val_set, test_set = dataset.return_splits(from_id=False, csv_path=split_file)
        train_dsmil(
            (train_set, val_set, test_set),
            save_path=save_dir,
            feature_dim=512,
            n_classes=2,
            fold=fold_idx,
            writer_flag=True,
            max_epoch=200,
            early_stopping=False,
        )
--------------------------------------------------------------------------------
/trainer_attrimil_abmil.py:
--------------------------------------------------------------------------------
1 | from dataloader import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset
2 |
3 | import os
4 | import torch
5 | from torch import nn
6 | import torch.optim as optim
7 | import pdb
8 | import torch.nn.functional as F
9 |
10 | import pandas as pd
11 | import numpy as np
12 |
13 | from sklearn.preprocessing import label_binarize
14 | from sklearn.metrics import roc_auc_score, roc_curve
15 | from sklearn.metrics import auc as calc_auc
16 |
17 | from models.AttriMIL import AttriMIL
18 | from utils import *
19 |
20 | from constraints import spatial_constraint, rank_constraint
21 | import queue
22 |
23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24 |
25 |
class Accuracy_Logger(object):
    """Per-class prediction tally.

    Keeps one bucket per class index; each bucket counts how many samples
    carried that label (``count``) and how many of them were predicted
    correctly (``correct``).
    """

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.initialize()

    def initialize(self):
        """Reset every class bucket to zero."""
        self.data = [dict(count=0, correct=0) for _ in range(self.n_classes)]

    def log(self, Y_hat, Y):
        """Record one prediction ``Y_hat`` against its ground-truth label ``Y``."""
        predicted, target = int(Y_hat), int(Y)
        bucket = self.data[target]
        bucket["count"] += 1
        bucket["correct"] += (predicted == target)

    def log_batch(self, Y_hat, Y):
        """Record a batch of predictions against the matching batch of labels."""
        predicted = np.array(Y_hat).astype(int)
        targets = np.array(Y).astype(int)
        for cls in np.unique(targets):
            mask = targets == cls
            bucket = self.data[cls]
            bucket["count"] += mask.sum()
            bucket["correct"] += (predicted[mask] == targets[mask]).sum()

    def get_summary(self, c):
        """Return ``(accuracy, correct, count)`` for class ``c``.

        ``accuracy`` is ``None`` when no sample of class ``c`` has been logged.
        """
        bucket = self.data[c]
        count, correct = bucket["count"], bucket["correct"]
        acc = None if count == 0 else float(correct) / count
        return acc, correct, count
60 |
def summary(model, loader, n_classes):
    """Evaluate an AttriMIL model on every bag in ``loader``.

    Args:
        model: AttriMIL network; its forward returns five values
            ``(logits, Y_prob, Y_hat, attribute_score, results_dict)``, as
            unpacked elsewhere in this file.
        loader: DataLoader yielding ``(data, label, coords, nearest)``. Bag ``i``
            is matched to ``loader.dataset.slide_data['slide_id'].iloc[i]``, which
            assumes one bag per batch in dataset order -- TODO confirm.
        n_classes: number of output classes.

    Returns:
        ``(patient_results, test_error, auc, acc_logger)``:
        ``patient_results`` maps slide id -> ``{'slide_id', 'prob', 'label'}``;
        ``test_error`` is the mean of ``calculate_error`` over bags (NaN for an
        empty loader); ``auc`` is the binary ROC AUC, or the NaN-ignoring mean of
        per-class one-vs-rest AUCs when ``n_classes > 2`` (NaN when undefined);
        ``acc_logger`` holds the per-class accuracy counts.
    """
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    model.eval()
    test_error = 0.

    all_probs = np.zeros((len(loader), n_classes))
    all_labels = np.zeros(len(loader))

    slide_ids = loader.dataset.slide_data['slide_id']
    patient_results = {}

    for batch_idx, (data, label, coords, nearest) in enumerate(loader):
        data, label = data.to(device), label.to(device)
        slide_id = slide_ids.iloc[batch_idx]
        with torch.inference_mode():
            # AttriMIL's forward yields five outputs; the model's own Y_prob and
            # Y_hat are used directly, as validate() does.
            logits, Y_prob, Y_hat, _, _ = model(data)
            acc_logger.log(Y_hat, label)
            probs = Y_prob.cpu().numpy()
            error = calculate_error(Y_hat, label)

        all_probs[batch_idx] = probs
        all_labels[batch_idx] = label.item()
        patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}})
        test_error += error

    test_error = test_error / len(loader) if len(loader) else float('nan')

    if n_classes == 2:
        if len(np.unique(all_labels)) == 2:
            auc = roc_auc_score(all_labels, all_probs[:, 1])
        else:
            # ROC AUC is undefined unless both classes occur; report NaN (as the
            # multi-class branch does for absent classes) instead of raising.
            auc = float('nan')
    else:
        aucs = []
        binary_labels = label_binarize(all_labels, classes=list(range(n_classes)))
        for class_idx in range(n_classes):
            if class_idx in all_labels:
                fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx])
                aucs.append(calc_auc(fpr, tpr))
            else:
                aucs.append(float('nan'))
        auc = np.nanmean(np.array(aucs))

    return patient_results, test_error, auc, acc_logger
110 |
def train_abmil(datasets,
                save_path='./save_weights/camelyon16_abmil_imagenet/',
                feature_dim=512,
                n_classes=2,
                fold=0,
                writer_flag=True,
                max_epoch=200,
                early_stopping=True,
                patience=20,
                stop_epoch=50,
                ):
    """Train AttriMIL on one fold, keep the lowest-validation-loss weights, test them.

    Args:
        datasets: ``(train_split, val_split, test_split)``.
        save_path: directory for checkpoints; TensorBoard logs go to
            ``save_path/<fold>``.
        feature_dim: per-instance feature size passed to ``AttriMIL(dim=...)``.
        n_classes: number of bag classes.
        fold: fold index used in checkpoint and log names.
        writer_flag: log to TensorBoard (``tensorboardX``) when true.
        max_epoch: maximum number of training epochs.
        early_stopping: stop once validation loss has not improved for more
            than ``patience`` epochs and ``epoch > stop_epoch``.
        patience: early-stopping patience in epochs.
        stop_epoch: early stopping never triggers at or before this epoch.

    Returns:
        ``(patient_results, test_error, auc, acc_logger)`` from ``summary`` on
        the test split, evaluated with the best checkpoint.
    """
    writer_dir = os.path.join(save_path, str(fold))
    os.makedirs(writer_dir, exist_ok=True)
    if writer_flag:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(writer_dir, flush_secs=15)
    else:
        writer = None

    print("\nInit train/val/test splits...")
    train_split, val_split, test_split = datasets
    print("Training on {} samples".format(len(train_split)))
    print("Validating on {} samples".format(len(val_split)))
    print("Testing on {} samples".format(len(test_split)))

    print("\nInit loss function...")
    loss_fn = nn.CrossEntropyLoss()

    model = AttriMIL(dim=feature_dim, n_classes=n_classes)
    model.to(device)

    print("\nInit optimizer")
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-4, momentum=0.9, weight_decay=1e-5)

    print('\nInit Loaders...', end=' ')
    train_loader = get_split_loader(train_split, training=True, testing=False, weighted=True)
    val_loader = get_split_loader(val_split, testing=False)
    test_loader = get_split_loader(test_split, testing=False)
    print('Done!')

    best_ckpt = os.path.join(save_path, 's_{}_checkpoint.pt'.format(fold))
    mini_loss = float('inf')
    retain = 0
    try:
        for epoch in range(max_epoch):
            train_loop(epoch, model, train_loader, optimizer, n_classes, writer, loss_fn)
            loss = validate(epoch, model, val_loader, n_classes, writer, loss_fn)
            if epoch % 10 == 0:
                # Periodic snapshot, independent of validation performance.
                torch.save(model.state_dict(), os.path.join(save_path, 's_{}_checkpoint_{}.pt'.format(fold, epoch)))
            if loss < mini_loss:
                print("loss decrease from:{} to {}".format(mini_loss, loss))
                torch.save(model.state_dict(), best_ckpt)
                mini_loss = loss
                retain = 0
            else:
                retain += 1
                print("Retain of early stopping: {} / {}".format(retain, patience))
            if early_stopping and retain > patience and epoch > stop_epoch:
                print("Early stopping")
                break
    finally:
        # Flush and release the event file even if training raises.
        if writer:
            writer.close()

    if os.path.isfile(best_ckpt):
        model.load_state_dict(torch.load(best_ckpt, map_location=device))
    else:
        # Happens only if no validation loss ever compared below +inf (e.g. all NaN).
        print("No best checkpoint at {}; evaluating the final weights".format(best_ckpt))

    patient_results, test_error, test_auc, acc_logger = summary(model, test_loader, n_classes)
    print('\nTest Set, test_error: {:.4f}, auc: {:.4f}'.format(test_error, test_auc))
    for i in range(n_classes):
        acc, correct, count = acc_logger.get_summary(i)
        print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))
    return patient_results, test_error, test_auc, acc_logger
174 |
def train_loop(epoch, model, loader, optimizer, n_classes, writer, loss_fn):
    """Train AttriMIL for one epoch, one bag per optimizer step.

    The per-bag loss is ``CE(logits) + 1.0 * spatial_constraint + 5.0 *
    rank_constraint``. ``rank_constraint`` receives and returns two lists of
    per-class ``queue.Queue(maxsize=4)`` objects, recreated at the start of every
    epoch; what it stores in them is defined in ``constraints.py`` (not visible
    here). Epoch-mean losses and per-class accuracy are printed and, when
    ``writer`` is given, logged under ``train/``.
    """
    model.train()
    acc_logger = Accuracy_Logger(n_classes=n_classes)
    train_loss = 0.
    bag_loss = 0.
    spa_loss = 0.
    rank_loss = 0.

    print('\n')

    label_positive_list = [queue.Queue(maxsize=4) for _ in range(n_classes)]
    label_negative_list = [queue.Queue(maxsize=4) for _ in range(n_classes)]

    for batch_idx, (data, label, coords, nearest) in enumerate(loader):
        data, label = data.to(device), label.to(device)

        logits, Y_prob, Y_hat, attribute_score, results_dict = model(data)
        acc_logger.log(Y_hat, label)
        loss_bag = loss_fn(logits, label)
        loss_spa = spatial_constraint(attribute_score, n_classes, nearest, ks=3)
        loss_rank, label_positive_list, label_negative_list = rank_constraint(
            data, label, model, attribute_score, n_classes,
            label_positive_list, label_negative_list)

        loss = loss_bag + 1.0 * loss_spa + 5.0 * loss_rank

        # Accumulate plain floats so no autograd history stays attached to the
        # running totals. float() accepts both a one-element tensor and a plain
        # number, in case a constraint returns a constant (not verifiable here).
        loss_bag_value = loss_bag.item()
        loss_spa_value = float(loss_spa)
        loss_rank_value = float(loss_rank)
        loss_value = loss.item()

        train_loss += loss_value
        bag_loss += loss_bag_value
        spa_loss += loss_spa_value
        rank_loss += loss_rank_value

        if (batch_idx + 1) % 20 == 0:
            print('batch {}, loss: {:.4f}, loss_bag: {:.4f}, loss_spa: {:.4f}, loss_rank: {:.4f}, label: {}, bag_size: {}'.format(
                batch_idx, loss_value, loss_bag_value, loss_spa_value, loss_rank_value, label.item(), data.size(0)))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    num_batches = max(len(loader), 1)
    train_loss /= num_batches
    bag_loss /= num_batches
    spa_loss /= num_batches
    rank_loss /= num_batches

    print('Epoch: {}, train_loss: {:.4f}, bag_loss: {:.4f}, spa_loss: {:.4f}, rank_loss: {:.4f}'.format(
        epoch, train_loss, bag_loss, spa_loss, rank_loss))

    for i in range(n_classes):
        acc, correct, count = acc_logger.get_summary(i)
        print('class {}: acc {}, correct {}/{}'.format(i, acc, correct, count))
        # get_summary returns acc=None for classes absent this epoch; nothing to plot.
        if writer and acc is not None:
            writer.add_scalar('train/class_{}_acc'.format(i), acc, epoch)
    if writer:
        writer.add_scalar('train/loss', train_loss, epoch)
        writer.add_scalar('train/loss_bag', bag_loss, epoch)
        writer.add_scalar('train/loss_spa', spa_loss, epoch)
        writer.add_scalar('train/loss_rank', rank_loss, epoch)
232 |
233 |
def validate(epoch, model, loader, n_classes, writer, loss_fn):
    """Run one validation pass over ``loader`` and return the mean bag loss.

    Unlike ``train_loop``, only the bag-level ``loss_fn(logits, label)`` term is
    used here; no spatial or rank constraint is added. Mean loss and AUC (binary,
    or one-vs-rest when ``n_classes > 2``) are printed and, when ``writer`` is
    given, logged under ``val/``; per-class accuracy is printed.
    """
    model.eval()
    logger = Accuracy_Logger(n_classes=n_classes)
    total = 0.
    num_bags = len(loader)
    scores = np.zeros((num_bags, n_classes))
    targets = np.zeros(num_bags)

    with torch.no_grad():
        for idx, (data, label, coords, nearest) in enumerate(loader):
            data = data.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)

            logits, Y_prob, Y_hat, attribute_score, results_dict = model(data)
            logger.log(Y_hat, label)
            bag_term = loss_fn(logits, label)

            scores[idx] = Y_prob.cpu().numpy()
            targets[idx] = label.item()
            total += bag_term.item()

    val_loss = total / num_bags
    auc = (roc_auc_score(targets, scores[:, 1]) if n_classes == 2
           else roc_auc_score(targets, scores, multi_class='ovr'))
    if writer:
        writer.add_scalar('val/loss', val_loss, epoch)
        writer.add_scalar('val/auc', auc, epoch)
    print('\nVal Set, val_loss: {:.4f}, auc: {:.4f}'.format(val_loss, auc))

    for c in range(n_classes):
        acc, correct, count = logger.get_summary(c)
        print('class {}: acc {}, correct {}/{}'.format(c, acc, correct, count))

    return val_loss
273 |
if __name__ == "__main__":
    # Experiment configuration (machine-specific paths).
    csv_path = '/data1/ceiling/workspace/AttriMIL_v2/dataset_csv/camelyon16_total.csv'
    data_dir = '/data2/clh/camelyon16/resnet18_imagenet/'
    split_path = '/data1/ceiling/workspace/AttriMIL_v2/splits/camelyon16_100/'
    save_dir = './save_weights/camelyon16_attrimil_imagenet_100/'
    label_map = {'normal_tissue': 0, 'tumor_tissue': 1}

    dataset = Generic_MIL_Dataset(
        csv_path=csv_path,
        data_dir=data_dir,
        shuffle=False,
        seed=1,
        print_info=True,
        label_dict=label_map,
        patient_strat=False,
        ignore=[],
    )

    # One fold per split file: splits_0.csv ... splits_4.csv.
    split_files = [split_path + 'splits_{}.csv'.format(k) for k in range(5)]
    for fold_idx, split_file in enumerate(split_files):
        train_set, val_set, test_set = dataset.return_splits(from_id=False, csv_path=split_file)
        train_abmil(
            (train_set, val_set, test_set),
            save_path=save_dir,
            feature_dim=512,
            n_classes=2,
            fold=fold_idx,
            writer_flag=True,
            max_epoch=200,
            early_stopping=False,
        )
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import pandas as pd
5 | import math
6 | import re
7 | import pdb
8 | import pickle
9 | from scipy import stats
10 |
11 | from torch.utils.data import Dataset
12 | import h5py
13 |
def save_splits(split_datasets, column_keys, filename, boolean_style=False):
    """Write the slide ids of each split dataset to ``filename`` as CSV.

    Default layout: one column per split, named by ``column_keys``, holding that
    split's ``slide_data['slide_id']`` values (shorter splits are NaN-padded by
    the column-wise concat). With ``boolean_style=True`` the rows are the
    concatenated slide ids and the columns are one-hot ``train``/``val``/``test``
    membership flags; this layout assumes exactly three splits.
    """
    id_series = [split_datasets[k].slide_data['slide_id'] for k in range(len(split_datasets))]
    if boolean_style:
        stacked = pd.concat(id_series, ignore_index=True, axis=0)
        row_ids = stacked.values.tolist()
        sizes = [len(dset) for dset in split_datasets]
        membership = np.repeat(np.eye(len(split_datasets)).astype(bool), sizes, axis=0)
        df = pd.DataFrame(membership, index=row_ids, columns=['train', 'val', 'test'])
    else:
        df = pd.concat(id_series, ignore_index=True, axis=1)
        df.columns = column_keys

    df.to_csv(filename)
    print()
28 |
class Generic_WSI_Classification_Dataset(Dataset):
    """Slide-level classification dataset backed by a CSV manifest.

    The CSV must contain ``case_id`` and ``slide_id`` columns plus a label
    column (``label`` unless ``label_col`` says otherwise). Raw label values
    are mapped to class ids with ``label_dict``. This base class manages the
    metadata table, patient-level labels and train/val/test splits;
    ``__getitem__`` is left to subclasses.
    """

    def __init__(self,
                 csv_path='dataset_csv/ccrcc_clean.csv',
                 shuffle=False,
                 seed=7,
                 print_info=True,
                 label_dict=None,
                 filter_dict=None,
                 ignore=None,
                 patient_strat=False,
                 label_col=None,
                 patient_voting='max',
                 ):
        """
        Args:
            csv_path (string): Path to the csv file with annotations.
            shuffle (boolean): Whether to shuffle the slide rows.
            seed (int): Random seed used for shuffling.
            print_info (boolean): Whether to print a summary of the dataset.
            label_dict (dict): Maps raw label values to integer class ids.
            filter_dict (dict): Column name -> allowed values; rows failing
                any entry are dropped before labels are mapped.
            ignore (list): Raw label values whose rows are dropped.
            patient_strat (boolean): If True, ``len()`` counts patients
                rather than slides.
            label_col (string): Column holding the raw labels
                (defaults to 'label').
            patient_voting (string): How slide labels are combined into one
                patient label: 'max' or 'maj' (majority).
        """
        # None defaults: a literal {} / [] default would be shared by every call.
        self.label_dict = {} if label_dict is None else label_dict
        filter_dict = {} if filter_dict is None else filter_dict
        ignore = [] if ignore is None else ignore

        self.num_classes = len(set(self.label_dict.values()))
        self.seed = seed
        self.print_info = print_info
        self.patient_strat = patient_strat
        self.train_ids, self.val_ids, self.test_ids = (None, None, None)
        self.data_dir = None
        self.label_col = label_col if label_col else 'label'

        slide_data = pd.read_csv(csv_path)
        slide_data = self.filter_df(slide_data, filter_dict)
        slide_data = self.df_prep(slide_data, self.label_dict, ignore, self.label_col)

        if shuffle:
            # np.random.shuffle swaps x[i] and x[j], which on a DataFrame are
            # column lookups, so it fails instead of permuting rows. Sample
            # the rows and restore a 0..n-1 index, which the label-based
            # lookups (e.g. slide_data['slide_id'][idx]) rely on.
            slide_data = slide_data.sample(frac=1, random_state=seed).reset_index(drop=True)

        self.slide_data = slide_data

        self.patient_data_prep(patient_voting)
        self.cls_ids_prep()

        if print_info:
            self.summarize()

    def cls_ids_prep(self):
        """Store, per class id, the positions of matching patients and slides."""
        self.patient_cls_ids = [np.where(self.patient_data['label'] == i)[0]
                                for i in range(self.num_classes)]
        self.slide_cls_ids = [np.where(self.slide_data['label'] == i)[0]
                              for i in range(self.num_classes)]

    def patient_data_prep(self, patient_voting='max'):
        """Aggregate slide labels into one label per unique ``case_id``.

        'max' takes the largest slide label (MIL convention). 'maj' takes the
        most frequent label, with ties resolved to the smallest label.

        Raises:
            NotImplementedError: for any other ``patient_voting`` value.
        """
        if patient_voting not in ('max', 'maj'):
            raise NotImplementedError
        patients = np.unique(np.array(self.slide_data['case_id']))  # sorted unique patients
        patient_labels = []

        for p in patients:
            label = self.slide_data.loc[self.slide_data['case_id'] == p, 'label'].values
            assert len(label) > 0
            if patient_voting == 'max':
                label = label.max()
            else:
                # Majority vote via np.unique: returns a scalar on every
                # NumPy/SciPy version (scipy.stats.mode changed its return
                # shape and dtype support over time).
                values, counts = np.unique(label, return_counts=True)
                label = values[np.argmax(counts)]
            patient_labels.append(label)

        self.patient_data = {'case_id': patients, 'label': np.array(patient_labels)}

    @staticmethod
    def df_prep(data, label_dict, ignore, label_col):
        """Return a copy of ``data`` with ignored rows dropped and labels mapped.

        The raw labels are copied from ``label_col`` into ``label`` when they
        differ, rows whose raw label is in ``ignore`` are removed, the index
        is reset to 0..n-1, and each label is replaced by ``label_dict[label]``
        (an unknown label raises KeyError).
        """
        # Work on a copy so the caller's (possibly sliced) frame is never mutated.
        data = data.copy()
        if label_col != 'label':
            data['label'] = data[label_col].copy()

        data = data[~data['label'].isin(ignore)].reset_index(drop=True)
        data['label'] = [label_dict[key] for key in data['label']]
        return data

    def filter_df(self, df, filter_dict=None):
        """Keep only rows whose value in each ``filter_dict`` column is in the allowed list."""
        if filter_dict:
            filter_mask = pd.Series(True, index=df.index)
            for key, val in filter_dict.items():
                filter_mask &= df[key].isin(val)
            df = df[filter_mask]
        return df

    def __len__(self):
        """Number of patients if ``patient_strat`` else number of slides."""
        if self.patient_strat:
            return len(self.patient_data['case_id'])
        return len(self.slide_data)

    def summarize(self):
        """Print label configuration and per-class slide/patient counts."""
        print("label column: {}".format(self.label_col))
        print("label dictionary: {}".format(self.label_dict))
        print("number of classes: {}".format(self.num_classes))
        print("slide-level counts: ", '\n', self.slide_data['label'].value_counts(sort = False))
        for i in range(self.num_classes):
            print('Patient-LVL; Number of samples registered in class %d: %d' % (i, self.patient_cls_ids[i].shape[0]))
            print('Slide-LVL; Number of samples registered in class %d: %d' % (i, self.slide_cls_ids[i].shape[0]))

    def _split_from_slide_ids(self, slide_ids):
        """Wrap rows whose slide_id is in ``slide_ids`` as a Generic_Split; None if empty."""
        if len(slide_ids) == 0:
            return None
        mask = self.slide_data['slide_id'].isin(slide_ids)
        df_slice = self.slide_data[mask].reset_index(drop=True)
        return Generic_Split(df_slice, data_dir=self.data_dir, num_classes=self.num_classes)

    def _split_from_ids(self, ids):
        """Wrap rows at index labels ``ids`` as a Generic_Split; None if ids is None/empty."""
        if ids is None or len(ids) == 0:
            return None
        data = self.slide_data.loc[ids].reset_index(drop=True)
        return Generic_Split(data, data_dir=self.data_dir, num_classes=self.num_classes)

    def get_split_from_df(self, all_splits, split_key='train'):
        """Build the split listed in column ``split_key`` of ``all_splits`` (None if empty)."""
        return self._split_from_slide_ids(all_splits[split_key].dropna().tolist())

    def get_merged_split_from_df(self, all_splits, split_keys=('train',)):
        """Build one split from the union of the columns in ``split_keys`` (None if empty)."""
        merged_split = []
        for split_key in split_keys:
            merged_split.extend(all_splits[split_key].dropna().tolist())
        # Decide emptiness on the merged ids, not just on the last column read.
        return self._split_from_slide_ids(merged_split)

    def return_splits(self, from_id=True, csv_path=None):
        """Return ``(train, val, test)`` Generic_Split objects (None for empty splits).

        With ``from_id`` the splits come from ``train_ids`` / ``val_ids`` /
        ``test_ids`` (ids still unset are treated as empty); otherwise they
        are read from the train/val/test columns of the CSV at ``csv_path``.
        """
        if from_id:
            train_split = self._split_from_ids(self.train_ids)
            val_split = self._split_from_ids(self.val_ids)
            test_split = self._split_from_ids(self.test_ids)
        else:
            assert csv_path
            # Read with slide_data's slide_id dtype: otherwise read_csv turns
            # all-number columns numeric (dropping zero-padding), and the
            # isin() comparison against slide_data['slide_id'] silently fails.
            # See https://github.com/andrew-weisman/clam_analysis/tree/main/datatype_comparison_bug-2021-12-01.
            all_splits = pd.read_csv(csv_path, dtype=self.slide_data['slide_id'].dtype)
            train_split = self.get_split_from_df(all_splits, 'train')
            val_split = self.get_split_from_df(all_splits, 'val')
            test_split = self.get_split_from_df(all_splits, 'test')

        return train_split, val_split, test_split

    def get_list(self, ids):
        """Slide ids at index labels ``ids``."""
        return self.slide_data['slide_id'][ids]

    def getlabel(self, ids):
        """Class ids at index labels ``ids``."""
        return self.slide_data['label'][ids]

    def __getitem__(self, idx):
        # Implemented by subclasses.
        return None

    def test_split_gen(self, return_descriptor=False):
        """Print per-class counts of each split and check the splits are disjoint.

        Returns:
            When ``return_descriptor`` is True, a DataFrame of counts indexed
            by label name with 'train'/'val'/'test' columns; otherwise None.
        """
        if return_descriptor:
            keys = list(self.label_dict.keys())
            values = list(self.label_dict.values())
            # Row i holds the first label name mapped to class id i.
            index = [keys[values.index(i)] for i in range(self.num_classes)]
            columns = ['train', 'val', 'test']
            df = pd.DataFrame(np.full((len(index), len(columns)), 0, dtype=np.int32), index=index,
                              columns=columns)

        for column, title, ids in (('train', 'training', self.train_ids),
                                   ('val', 'val', self.val_ids),
                                   ('test', 'test', self.test_ids)):
            print('\nnumber of {} samples: {}'.format(title, len(ids)))
            unique, counts = np.unique(self.getlabel(ids), return_counts=True)
            for label, count in zip(unique, counts):
                print('number of samples in cls {}: {}'.format(label, count))
                if return_descriptor:
                    # Select the row by class id, not by its position among
                    # the classes present in this split.
                    df.loc[index[int(label)], column] = count

        assert len(np.intersect1d(self.train_ids, self.test_ids)) == 0
        assert len(np.intersect1d(self.train_ids, self.val_ids)) == 0
        assert len(np.intersect1d(self.val_ids, self.test_ids)) == 0

        if return_descriptor:
            return df

    def save_split(self, filename):
        """Write the train/val/test slide ids as side-by-side CSV columns."""
        train_split = self.get_list(self.train_ids)
        val_split = self.get_list(self.val_ids)
        test_split = self.get_list(self.test_ids)
        df_tr = pd.DataFrame({'train': train_split})
        df_v = pd.DataFrame({'val': val_split})
        df_t = pd.DataFrame({'test': test_split})
        df = pd.concat([df_tr, df_v, df_t], axis=1)
        df.to_csv(filename, index = False)
273 |
274 |
class Generic_MIL_Dataset(Generic_WSI_Classification_Dataset):
    """MIL dataset that loads per-slide patch features from disk.

    ``data_dir`` is either one directory or a dict mapping the value of the
    ``source`` column to a directory, for manifests mixing several sources.
    """

    def __init__(self,
                 data_dir,
                 **kwargs):
        super().__init__(**kwargs)
        self.data_dir = data_dir
        self.use_h5 = True

    def load_from_h5(self, toggle):
        """Select the .h5 loader (True) or the .pt loader (False)."""
        self.use_h5 = toggle

    def __getitem__(self, idx):
        """Return the sample for row ``idx`` of ``slide_data``.

        Returns:
            h5 mode: ``(features, label, coords, nearest)``, where ``features``
                is a torch tensor built from the file's 'features' dataset and
                ``coords`` / ``nearest`` are the arrays read from the file.
            .pt mode with a data_dir: ``(features, label)``.
            .pt mode without a data_dir: ``(slide_id, label)``.
        """
        slide_id = self.slide_data['slide_id'][idx]
        label = self.slide_data['label'][idx]
        # isinstance also accepts dict subclasses (e.g. OrderedDict).
        if isinstance(self.data_dir, dict):
            data_dir = self.data_dir[self.slide_data['source'][idx]]
        else:
            data_dir = self.data_dir

        if not self.use_h5:
            if not self.data_dir:
                return slide_id, label
            full_path = os.path.join(data_dir, 'pt_files', '{}.pt'.format(slide_id))
            # NOTE: torch.load unpickles the file; only load trusted feature files.
            features = torch.load(full_path)
            return features, label

        full_path = os.path.join(data_dir, 'h5_coords_files', '{}.h5'.format(slide_id))
        with h5py.File(full_path, 'r') as hdf5_file:
            features = hdf5_file['features'][:]
            coords = hdf5_file['coords'][:]
            nearest = hdf5_file['nearest'][:]
        return torch.from_numpy(features), label, coords, nearest
311 |
312 |
class Generic_Split(Generic_MIL_Dataset):
    """A dataset view over an already-prepared slice of ``slide_data``.

    The parent constructors are intentionally not called, so no CSV is read
    and no filtering or label mapping is repeated. The 'label' column is
    compared against integer class ids 0..num_classes-1.
    """

    def __init__(self, slide_data, data_dir=None, num_classes=2):
        self.use_h5 = True
        self.slide_data = slide_data
        self.data_dir = data_dir
        self.num_classes = num_classes
        labels = self.slide_data['label']
        # Row positions of the slides belonging to each class id.
        self.slide_cls_ids = [np.where(labels == cls)[0] for cls in range(num_classes)]

    def __len__(self):
        """Number of slides in this split."""
        return len(self.slide_data)
325 |
--------------------------------------------------------------------------------
/datasets/camelyon16_total.csv:
--------------------------------------------------------------------------------
1 | case_id,slide_id,label
2 | patient_1,normal_001,normal_tissue
3 | patient_2,normal_002,normal_tissue
4 | patient_3,normal_003,normal_tissue
5 | patient_4,normal_004,normal_tissue
6 | patient_5,normal_005,normal_tissue
7 | patient_6,normal_006,normal_tissue
8 | patient_7,normal_007,normal_tissue
9 | patient_8,normal_008,normal_tissue
10 | patient_9,normal_009,normal_tissue
11 | patient_10,normal_010,normal_tissue
12 | patient_11,normal_011,normal_tissue
13 | patient_12,normal_012,normal_tissue
14 | patient_13,normal_013,normal_tissue
15 | patient_14,normal_014,normal_tissue
16 | patient_15,normal_015,normal_tissue
17 | patient_16,normal_016,normal_tissue
18 | patient_17,normal_017,normal_tissue
19 | patient_18,normal_018,normal_tissue
20 | patient_19,normal_019,normal_tissue
21 | patient_20,normal_020,normal_tissue
22 | patient_21,normal_021,normal_tissue
23 | patient_22,normal_022,normal_tissue
24 | patient_23,normal_023,normal_tissue
25 | patient_24,normal_024,normal_tissue
26 | patient_25,normal_025,normal_tissue
27 | patient_26,normal_026,normal_tissue
28 | patient_27,normal_027,normal_tissue
29 | patient_28,normal_028,normal_tissue
30 | patient_29,normal_029,normal_tissue
31 | patient_30,normal_030,normal_tissue
32 | patient_31,normal_031,normal_tissue
33 | patient_32,normal_032,normal_tissue
34 | patient_33,normal_033,normal_tissue
35 | patient_34,normal_034,normal_tissue
36 | patient_35,normal_035,normal_tissue
37 | patient_36,normal_036,normal_tissue
38 | patient_37,normal_037,normal_tissue
39 | patient_38,normal_038,normal_tissue
40 | patient_39,normal_039,normal_tissue
41 | patient_40,normal_040,normal_tissue
42 | patient_41,normal_041,normal_tissue
43 | patient_42,normal_042,normal_tissue
44 | patient_43,normal_043,normal_tissue
45 | patient_44,normal_044,normal_tissue
46 | patient_45,normal_045,normal_tissue
47 | patient_46,normal_046,normal_tissue
48 | patient_47,normal_047,normal_tissue
49 | patient_48,normal_048,normal_tissue
50 | patient_49,normal_049,normal_tissue
51 | patient_50,normal_050,normal_tissue
52 | patient_51,normal_051,normal_tissue
53 | patient_52,normal_052,normal_tissue
54 | patient_53,normal_053,normal_tissue
55 | patient_54,normal_054,normal_tissue
56 | patient_55,normal_055,normal_tissue
57 | patient_56,normal_056,normal_tissue
58 | patient_57,normal_057,normal_tissue
59 | patient_58,normal_058,normal_tissue
60 | patient_59,normal_059,normal_tissue
61 | patient_60,normal_060,normal_tissue
62 | patient_61,normal_061,normal_tissue
63 | patient_62,normal_062,normal_tissue
64 | patient_63,normal_063,normal_tissue
65 | patient_64,normal_064,normal_tissue
66 | patient_65,normal_065,normal_tissue
67 | patient_66,normal_066,normal_tissue
68 | patient_67,normal_067,normal_tissue
69 | patient_68,normal_068,normal_tissue
70 | patient_69,normal_069,normal_tissue
71 | patient_70,normal_070,normal_tissue
72 | patient_71,normal_071,normal_tissue
73 | patient_72,normal_072,normal_tissue
74 | patient_73,normal_073,normal_tissue
75 | patient_74,normal_074,normal_tissue
76 | patient_75,normal_075,normal_tissue
77 | patient_76,normal_076,normal_tissue
78 | patient_77,normal_077,normal_tissue
79 | patient_78,normal_078,normal_tissue
80 | patient_79,normal_079,normal_tissue
81 | patient_80,normal_080,normal_tissue
82 | patient_81,normal_081,normal_tissue
83 | patient_82,normal_082,normal_tissue
84 | patient_83,normal_083,normal_tissue
85 | patient_84,normal_084,normal_tissue
86 | patient_85,normal_085,normal_tissue
87 | patient_86,normal_087,normal_tissue
88 | patient_87,normal_088,normal_tissue
89 | patient_88,normal_089,normal_tissue
90 | patient_89,normal_090,normal_tissue
91 | patient_90,normal_091,normal_tissue
92 | patient_91,normal_092,normal_tissue
93 | patient_92,normal_093,normal_tissue
94 | patient_93,normal_094,normal_tissue
95 | patient_94,normal_095,normal_tissue
96 | patient_95,normal_096,normal_tissue
97 | patient_96,normal_097,normal_tissue
98 | patient_97,normal_098,normal_tissue
99 | patient_98,normal_099,normal_tissue
100 | patient_99,normal_100,normal_tissue
101 | patient_100,normal_101,normal_tissue
102 | patient_101,normal_102,normal_tissue
103 | patient_102,normal_103,normal_tissue
104 | patient_103,normal_104,normal_tissue
105 | patient_104,normal_105,normal_tissue
106 | patient_105,normal_106,normal_tissue
107 | patient_106,normal_107,normal_tissue
108 | patient_107,normal_108,normal_tissue
109 | patient_108,normal_109,normal_tissue
110 | patient_109,normal_110,normal_tissue
111 | patient_110,normal_111,normal_tissue
112 | patient_111,normal_112,normal_tissue
113 | patient_112,normal_113,normal_tissue
114 | patient_113,normal_114,normal_tissue
115 | patient_114,normal_115,normal_tissue
116 | patient_115,normal_116,normal_tissue
117 | patient_116,normal_117,normal_tissue
118 | patient_117,normal_118,normal_tissue
119 | patient_118,normal_119,normal_tissue
120 | patient_119,normal_120,normal_tissue
121 | patient_120,normal_121,normal_tissue
122 | patient_121,normal_122,normal_tissue
123 | patient_122,normal_123,normal_tissue
124 | patient_123,normal_124,normal_tissue
125 | patient_124,normal_125,normal_tissue
126 | patient_125,normal_126,normal_tissue
127 | patient_126,normal_127,normal_tissue
128 | patient_127,normal_128,normal_tissue
129 | patient_128,normal_129,normal_tissue
130 | patient_129,normal_130,normal_tissue
131 | patient_130,normal_131,normal_tissue
132 | patient_131,normal_132,normal_tissue
133 | patient_132,normal_133,normal_tissue
134 | patient_133,normal_134,normal_tissue
135 | patient_134,normal_135,normal_tissue
136 | patient_135,normal_136,normal_tissue
137 | patient_136,normal_137,normal_tissue
138 | patient_137,normal_138,normal_tissue
139 | patient_138,normal_139,normal_tissue
140 | patient_139,normal_140,normal_tissue
141 | patient_140,normal_141,normal_tissue
142 | patient_141,normal_142,normal_tissue
143 | patient_142,normal_143,normal_tissue
144 | patient_143,normal_144,normal_tissue
145 | patient_144,normal_145,normal_tissue
146 | patient_145,normal_146,normal_tissue
147 | patient_146,normal_147,normal_tissue
148 | patient_147,normal_148,normal_tissue
149 | patient_148,normal_149,normal_tissue
150 | patient_149,normal_150,normal_tissue
151 | patient_150,normal_151,normal_tissue
152 | patient_151,normal_152,normal_tissue
153 | patient_152,normal_153,normal_tissue
154 | patient_153,normal_154,normal_tissue
155 | patient_154,normal_155,normal_tissue
156 | patient_155,normal_156,normal_tissue
157 | patient_156,normal_157,normal_tissue
158 | patient_157,normal_158,normal_tissue
159 | patient_158,normal_159,normal_tissue
160 | patient_159,normal_160,normal_tissue
161 | patient_160,tumor_001,tumor_tissue
162 | patient_161,tumor_002,tumor_tissue
163 | patient_162,tumor_003,tumor_tissue
164 | patient_163,tumor_004,tumor_tissue
165 | patient_164,tumor_005,tumor_tissue
166 | patient_165,tumor_006,tumor_tissue
167 | patient_166,tumor_007,tumor_tissue
168 | patient_167,tumor_008,tumor_tissue
169 | patient_168,tumor_009,tumor_tissue
170 | patient_169,tumor_010,tumor_tissue
171 | patient_170,tumor_011,tumor_tissue
172 | patient_171,tumor_012,tumor_tissue
173 | patient_172,tumor_013,tumor_tissue
174 | patient_173,tumor_014,tumor_tissue
175 | patient_174,tumor_015,tumor_tissue
176 | patient_175,tumor_016,tumor_tissue
177 | patient_176,tumor_017,tumor_tissue
178 | patient_177,tumor_018,tumor_tissue
179 | patient_178,tumor_019,tumor_tissue
180 | patient_179,tumor_020,tumor_tissue
181 | patient_180,tumor_021,tumor_tissue
182 | patient_181,tumor_022,tumor_tissue
183 | patient_182,tumor_023,tumor_tissue
184 | patient_183,tumor_024,tumor_tissue
185 | patient_184,tumor_025,tumor_tissue
186 | patient_185,tumor_026,tumor_tissue
187 | patient_186,tumor_027,tumor_tissue
188 | patient_187,tumor_028,tumor_tissue
189 | patient_188,tumor_029,tumor_tissue
190 | patient_189,tumor_030,tumor_tissue
191 | patient_190,tumor_031,tumor_tissue
192 | patient_191,tumor_032,tumor_tissue
193 | patient_192,tumor_033,tumor_tissue
194 | patient_193,tumor_034,tumor_tissue
195 | patient_194,tumor_035,tumor_tissue
196 | patient_195,tumor_036,tumor_tissue
197 | patient_196,tumor_037,tumor_tissue
198 | patient_197,tumor_038,tumor_tissue
199 | patient_198,tumor_039,tumor_tissue
200 | patient_199,tumor_040,tumor_tissue
201 | patient_200,tumor_041,tumor_tissue
202 | patient_201,tumor_042,tumor_tissue
203 | patient_202,tumor_043,tumor_tissue
204 | patient_203,tumor_044,tumor_tissue
205 | patient_204,tumor_045,tumor_tissue
206 | patient_205,tumor_046,tumor_tissue
207 | patient_206,tumor_047,tumor_tissue
208 | patient_207,tumor_048,tumor_tissue
209 | patient_208,tumor_049,tumor_tissue
210 | patient_209,tumor_050,tumor_tissue
211 | patient_210,tumor_051,tumor_tissue
212 | patient_211,tumor_052,tumor_tissue
213 | patient_212,tumor_053,tumor_tissue
214 | patient_213,tumor_054,tumor_tissue
215 | patient_214,tumor_055,tumor_tissue
216 | patient_215,tumor_056,tumor_tissue
217 | patient_216,tumor_057,tumor_tissue
218 | patient_217,tumor_058,tumor_tissue
219 | patient_218,tumor_059,tumor_tissue
220 | patient_219,tumor_060,tumor_tissue
221 | patient_220,tumor_061,tumor_tissue
222 | patient_221,tumor_062,tumor_tissue
223 | patient_222,tumor_063,tumor_tissue
224 | patient_223,tumor_064,tumor_tissue
225 | patient_224,tumor_065,tumor_tissue
226 | patient_225,tumor_066,tumor_tissue
227 | patient_226,tumor_067,tumor_tissue
228 | patient_227,tumor_068,tumor_tissue
229 | patient_228,tumor_069,tumor_tissue
230 | patient_229,tumor_070,tumor_tissue
231 | patient_230,tumor_071,tumor_tissue
232 | patient_231,tumor_072,tumor_tissue
233 | patient_232,tumor_073,tumor_tissue
234 | patient_233,tumor_074,tumor_tissue
235 | patient_234,tumor_075,tumor_tissue
236 | patient_235,tumor_076,tumor_tissue
237 | patient_236,tumor_077,tumor_tissue
238 | patient_237,tumor_078,tumor_tissue
239 | patient_238,tumor_079,tumor_tissue
240 | patient_239,tumor_080,tumor_tissue
241 | patient_240,tumor_081,tumor_tissue
242 | patient_241,tumor_082,tumor_tissue
243 | patient_242,tumor_083,tumor_tissue
244 | patient_243,tumor_084,tumor_tissue
245 | patient_244,tumor_085,tumor_tissue
246 | patient_245,tumor_086,tumor_tissue
247 | patient_246,tumor_087,tumor_tissue
248 | patient_247,tumor_088,tumor_tissue
249 | patient_248,tumor_089,tumor_tissue
250 | patient_249,tumor_090,tumor_tissue
251 | patient_250,tumor_091,tumor_tissue
252 | patient_251,tumor_092,tumor_tissue
253 | patient_252,tumor_093,tumor_tissue
254 | patient_253,tumor_094,tumor_tissue
255 | patient_254,tumor_095,tumor_tissue
256 | patient_255,tumor_096,tumor_tissue
257 | patient_256,tumor_097,tumor_tissue
258 | patient_257,tumor_098,tumor_tissue
259 | patient_258,tumor_099,tumor_tissue
260 | patient_259,tumor_100,tumor_tissue
261 | patient_260,tumor_101,tumor_tissue
262 | patient_261,tumor_102,tumor_tissue
263 | patient_262,tumor_103,tumor_tissue
264 | patient_263,tumor_104,tumor_tissue
265 | patient_264,tumor_105,tumor_tissue
266 | patient_265,tumor_106,tumor_tissue
267 | patient_266,tumor_107,tumor_tissue
268 | patient_267,tumor_108,tumor_tissue
269 | patient_268,tumor_109,tumor_tissue
270 | patient_269,tumor_110,tumor_tissue
271 | patient_270,tumor_111,tumor_tissue
272 | patient_271,test_001,tumor_tissue
273 | patient_272,test_002,tumor_tissue
274 | patient_273,test_003,normal_tissue
275 | patient_274,test_004,tumor_tissue
276 | patient_275,test_005,normal_tissue
277 | patient_276,test_006,normal_tissue
278 | patient_277,test_007,normal_tissue
279 | patient_278,test_008,tumor_tissue
280 | patient_279,test_009,normal_tissue
281 | patient_280,test_010,tumor_tissue
282 | patient_281,test_011,tumor_tissue
283 | patient_282,test_012,normal_tissue
284 | patient_283,test_013,tumor_tissue
285 | patient_284,test_014,normal_tissue
286 | patient_285,test_015,normal_tissue
287 | patient_286,test_016,tumor_tissue
288 | patient_287,test_017,normal_tissue
289 | patient_288,test_018,normal_tissue
290 | patient_289,test_019,normal_tissue
291 | patient_290,test_020,normal_tissue
292 | patient_291,test_021,tumor_tissue
293 | patient_292,test_022,normal_tissue
294 | patient_293,test_023,normal_tissue
295 | patient_294,test_024,normal_tissue
296 | patient_295,test_025,normal_tissue
297 | patient_296,test_026,tumor_tissue
298 | patient_297,test_027,tumor_tissue
299 | patient_298,test_028,normal_tissue
300 | patient_299,test_029,tumor_tissue
301 | patient_300,test_030,tumor_tissue
302 | patient_301,test_031,normal_tissue
303 | patient_302,test_032,normal_tissue
304 | patient_303,test_033,tumor_tissue
305 | patient_304,test_034,normal_tissue
306 | patient_305,test_035,normal_tissue
307 | patient_306,test_036,normal_tissue
308 | patient_307,test_037,normal_tissue
309 | patient_308,test_038,tumor_tissue
310 | patient_309,test_039,normal_tissue
311 | patient_310,test_040,tumor_tissue
312 | patient_311,test_041,normal_tissue
313 | patient_312,test_042,normal_tissue
314 | patient_313,test_043,normal_tissue
315 | patient_314,test_044,normal_tissue
316 | patient_315,test_045,normal_tissue
317 | patient_316,test_046,tumor_tissue
318 | patient_317,test_047,normal_tissue
319 | patient_318,test_048,tumor_tissue
320 | patient_319,test_050,normal_tissue
321 | patient_320,test_051,tumor_tissue
322 | patient_321,test_052,tumor_tissue
323 | patient_322,test_053,normal_tissue
324 | patient_323,test_054,normal_tissue
325 | patient_324,test_055,normal_tissue
326 | patient_325,test_056,normal_tissue
327 | patient_326,test_057,normal_tissue
328 | patient_327,test_058,normal_tissue
329 | patient_328,test_059,normal_tissue
330 | patient_329,test_060,normal_tissue
331 | patient_330,test_061,tumor_tissue
332 | patient_331,test_062,normal_tissue
333 | patient_332,test_063,normal_tissue
334 | patient_333,test_064,tumor_tissue
335 | patient_334,test_065,tumor_tissue
336 | patient_335,test_066,tumor_tissue
337 | patient_336,test_067,normal_tissue
338 | patient_337,test_068,tumor_tissue
339 | patient_338,test_069,tumor_tissue
340 | patient_339,test_070,normal_tissue
341 | patient_340,test_071,tumor_tissue
342 | patient_341,test_072,normal_tissue
343 | patient_342,test_073,tumor_tissue
344 | patient_343,test_074,tumor_tissue
345 | patient_344,test_075,tumor_tissue
346 | patient_345,test_076,normal_tissue
347 | patient_346,test_077,normal_tissue
348 | patient_347,test_078,normal_tissue
349 | patient_348,test_079,tumor_tissue
350 | patient_349,test_080,normal_tissue
351 | patient_350,test_081,normal_tissue
352 | patient_351,test_082,tumor_tissue
353 | patient_352,test_083,normal_tissue
354 | patient_353,test_084,tumor_tissue
355 | patient_354,test_085,normal_tissue
356 | patient_355,test_086,normal_tissue
357 | patient_356,test_087,normal_tissue
358 | patient_357,test_088,normal_tissue
359 | patient_358,test_089,normal_tissue
360 | patient_359,test_090,tumor_tissue
361 | patient_360,test_091,normal_tissue
362 | patient_361,test_092,tumor_tissue
363 | patient_362,test_093,normal_tissue
364 | patient_363,test_094,tumor_tissue
365 | patient_364,test_095,normal_tissue
366 | patient_365,test_096,normal_tissue
367 | patient_366,test_097,tumor_tissue
368 | patient_367,test_098,normal_tissue
369 | patient_368,test_099,tumor_tissue
370 | patient_369,test_100,normal_tissue
371 | patient_370,test_101,normal_tissue
372 | patient_371,test_102,tumor_tissue
373 | patient_372,test_103,normal_tissue
374 | patient_373,test_104,tumor_tissue
375 | patient_374,test_105,tumor_tissue
376 | patient_375,test_106,normal_tissue
377 | patient_376,test_107,normal_tissue
378 | patient_377,test_108,tumor_tissue
379 | patient_378,test_109,normal_tissue
380 | patient_379,test_110,tumor_tissue
381 | patient_380,test_111,normal_tissue
382 | patient_381,test_112,normal_tissue
383 | patient_382,test_113,tumor_tissue
384 | patient_383,test_114,tumor_tissue
385 | patient_384,test_115,normal_tissue
386 | patient_385,test_116,tumor_tissue
387 | patient_386,test_117,tumor_tissue
388 | patient_387,test_118,normal_tissue
389 | patient_388,test_119,normal_tissue
390 | patient_389,test_120,normal_tissue
391 | patient_390,test_121,tumor_tissue
392 | patient_391,test_122,tumor_tissue
393 | patient_392,test_123,normal_tissue
394 | patient_393,test_124,normal_tissue
395 | patient_394,test_125,normal_tissue
396 | patient_395,test_126,normal_tissue
397 | patient_396,test_127,normal_tissue
398 | patient_397,test_128,normal_tissue
399 | patient_398,test_129,normal_tissue
400 | patient_399,test_130,normal_tissue
401 |
--------------------------------------------------------------------------------
/datasets/camelyon17.csv:
--------------------------------------------------------------------------------
1 | case_id,slide_id,label
2 | patient_000,patient_000_node_0,normal_tissue
3 | patient_000,patient_000_node_1,normal_tissue
4 | patient_000,patient_000_node_2,normal_tissue
5 | patient_000,patient_000_node_3,normal_tissue
6 | patient_000,patient_000_node_4,normal_tissue
7 | patient_001,patient_001_node_0,normal_tissue
8 | patient_001,patient_001_node_1,normal_tissue
9 | patient_001,patient_001_node_2,normal_tissue
10 | patient_001,patient_001_node_3,normal_tissue
11 | patient_001,patient_001_node_4,normal_tissue
12 | patient_002,patient_002_node_0,normal_tissue
13 | patient_002,patient_002_node_1,normal_tissue
14 | patient_002,patient_002_node_2,normal_tissue
15 | patient_002,patient_002_node_3,normal_tissue
16 | patient_002,patient_002_node_4,normal_tissue
17 | patient_003,patient_003_node_0,normal_tissue
18 | patient_003,patient_003_node_1,normal_tissue
19 | patient_003,patient_003_node_2,normal_tissue
20 | patient_003,patient_003_node_3,normal_tissue
21 | patient_003,patient_003_node_4,normal_tissue
22 | patient_004,patient_004_node_0,normal_tissue
23 | patient_004,patient_004_node_1,normal_tissue
24 | patient_004,patient_004_node_2,normal_tissue
25 | patient_004,patient_004_node_3,normal_tissue
26 | patient_005,patient_005_node_0,normal_tissue
27 | patient_005,patient_005_node_1,normal_tissue
28 | patient_005,patient_005_node_2,normal_tissue
29 | patient_006,patient_006_node_0,normal_tissue
30 | patient_006,patient_006_node_1,normal_tissue
31 | patient_006,patient_006_node_3,normal_tissue
32 | patient_007,patient_007_node_0,normal_tissue
33 | patient_007,patient_007_node_1,normal_tissue
34 | patient_007,patient_007_node_2,normal_tissue
35 | patient_007,patient_007_node_3,normal_tissue
36 | patient_007,patient_007_node_4,tumor_tissue
37 | patient_008,patient_008_node_0,tumor_tissue
38 | patient_008,patient_008_node_1,normal_tissue
39 | patient_008,patient_008_node_2,normal_tissue
40 | patient_008,patient_008_node_3,normal_tissue
41 | patient_008,patient_008_node_4,normal_tissue
42 | patient_009,patient_009_node_1,tumor_tissue
43 | patient_009,patient_009_node_2,normal_tissue
44 | patient_009,patient_009_node_3,normal_tissue
45 | patient_009,patient_009_node_4,normal_tissue
46 | patient_010,patient_010_node_0,tumor_tissue
47 | patient_010,patient_010_node_1,normal_tissue
48 | patient_010,patient_010_node_2,normal_tissue
49 | patient_010,patient_010_node_3,normal_tissue
50 | patient_010,patient_010_node_4,tumor_tissue
51 | patient_011,patient_011_node_0,tumor_tissue
52 | patient_011,patient_011_node_1,normal_tissue
53 | patient_011,patient_011_node_2,normal_tissue
54 | patient_011,patient_011_node_3,normal_tissue
55 | patient_011,patient_011_node_4,normal_tissue
56 | patient_012,patient_012_node_0,tumor_tissue
57 | patient_012,patient_012_node_1,normal_tissue
58 | patient_012,patient_012_node_2,normal_tissue
59 | patient_012,patient_012_node_3,normal_tissue
60 | patient_012,patient_012_node_4,normal_tissue
61 | patient_013,patient_013_node_0,normal_tissue
62 | patient_013,patient_013_node_1,normal_tissue
63 | patient_013,patient_013_node_2,tumor_tissue
64 | patient_013,patient_013_node_3,tumor_tissue
65 | patient_013,patient_013_node_4,tumor_tissue
66 | patient_014,patient_014_node_0,normal_tissue
67 | patient_014,patient_014_node_1,normal_tissue
68 | patient_014,patient_014_node_2,tumor_tissue
69 | patient_014,patient_014_node_4,tumor_tissue
70 | patient_015,patient_015_node_0,normal_tissue
71 | patient_015,patient_015_node_1,tumor_tissue
72 | patient_015,patient_015_node_2,tumor_tissue
73 | patient_015,patient_015_node_3,normal_tissue
74 | patient_016,patient_016_node_0,tumor_tissue
75 | patient_016,patient_016_node_2,normal_tissue
76 | patient_016,patient_016_node_3,normal_tissue
77 | patient_016,patient_016_node_4,normal_tissue
78 | patient_017,patient_017_node_0,normal_tissue
79 | patient_017,patient_017_node_2,tumor_tissue
80 | patient_017,patient_017_node_3,tumor_tissue
81 | patient_017,patient_017_node_4,tumor_tissue
82 | patient_018,patient_018_node_0,normal_tissue
83 | patient_018,patient_018_node_1,tumor_tissue
84 | patient_018,patient_018_node_2,tumor_tissue
85 | patient_018,patient_018_node_3,tumor_tissue
86 | patient_019,patient_019_node_0,tumor_tissue
87 | patient_019,patient_019_node_1,tumor_tissue
88 | patient_019,patient_019_node_2,tumor_tissue
89 | patient_019,patient_019_node_3,tumor_tissue
90 | patient_019,patient_019_node_4,normal_tissue
91 | patient_020,patient_020_node_0,normal_tissue
92 | patient_020,patient_020_node_1,tumor_tissue
93 | patient_020,patient_020_node_2,tumor_tissue
94 | patient_020,patient_020_node_3,normal_tissue
95 | patient_020,patient_020_node_4,tumor_tissue
96 | patient_021,patient_021_node_0,tumor_tissue
97 | patient_021,patient_021_node_1,tumor_tissue
98 | patient_021,patient_021_node_2,normal_tissue
99 | patient_021,patient_021_node_3,tumor_tissue
100 | patient_021,patient_021_node_4,tumor_tissue
101 | patient_022,patient_022_node_0,tumor_tissue
102 | patient_022,patient_022_node_1,tumor_tissue
103 | patient_022,patient_022_node_2,normal_tissue
104 | patient_022,patient_022_node_3,normal_tissue
105 | patient_022,patient_022_node_4,tumor_tissue
106 | patient_023,patient_023_node_0,normal_tissue
107 | patient_023,patient_023_node_1,normal_tissue
108 | patient_023,patient_023_node_2,normal_tissue
109 | patient_023,patient_023_node_3,normal_tissue
110 | patient_023,patient_023_node_4,normal_tissue
111 | patient_024,patient_024_node_0,normal_tissue
112 | patient_024,patient_024_node_3,normal_tissue
113 | patient_025,patient_025_node_0,normal_tissue
114 | patient_025,patient_025_node_1,normal_tissue
115 | patient_025,patient_025_node_2,normal_tissue
116 | patient_025,patient_025_node_3,normal_tissue
117 | patient_025,patient_025_node_4,normal_tissue
118 | patient_026,patient_026_node_1,tumor_tissue
119 | patient_026,patient_026_node_2,tumor_tissue
120 | patient_026,patient_026_node_3,tumor_tissue
121 | patient_026,patient_026_node_4,tumor_tissue
122 | patient_027,patient_027_node_0,tumor_tissue
123 | patient_027,patient_027_node_1,normal_tissue
124 | patient_027,patient_027_node_2,tumor_tissue
125 | patient_027,patient_027_node_3,normal_tissue
126 | patient_027,patient_027_node_4,tumor_tissue
127 | patient_028,patient_028_node_0,tumor_tissue
128 | patient_028,patient_028_node_1,tumor_tissue
129 | patient_028,patient_028_node_2,tumor_tissue
130 | patient_028,patient_028_node_3,tumor_tissue
131 | patient_028,patient_028_node_4,normal_tissue
132 | patient_029,patient_029_node_0,tumor_tissue
133 | patient_029,patient_029_node_1,tumor_tissue
134 | patient_029,patient_029_node_3,tumor_tissue
135 | patient_029,patient_029_node_4,normal_tissue
136 | patient_030,patient_030_node_0,normal_tissue
137 | patient_030,patient_030_node_1,normal_tissue
138 | patient_030,patient_030_node_2,normal_tissue
139 | patient_030,patient_030_node_3,normal_tissue
140 | patient_030,patient_030_node_4,tumor_tissue
141 | patient_031,patient_031_node_0,normal_tissue
142 | patient_031,patient_031_node_1,normal_tissue
143 | patient_031,patient_031_node_2,normal_tissue
144 | patient_031,patient_031_node_3,normal_tissue
145 | patient_031,patient_031_node_4,normal_tissue
146 | patient_032,patient_032_node_0,normal_tissue
147 | patient_032,patient_032_node_1,tumor_tissue
148 | patient_032,patient_032_node_2,normal_tissue
149 | patient_032,patient_032_node_3,normal_tissue
150 | patient_032,patient_032_node_4,normal_tissue
151 | patient_033,patient_033_node_0,normal_tissue
152 | patient_033,patient_033_node_1,tumor_tissue
153 | patient_033,patient_033_node_2,tumor_tissue
154 | patient_033,patient_033_node_3,tumor_tissue
155 | patient_033,patient_033_node_4,normal_tissue
156 | patient_034,patient_034_node_0,normal_tissue
157 | patient_034,patient_034_node_1,normal_tissue
158 | patient_034,patient_034_node_2,normal_tissue
159 | patient_034,patient_034_node_3,tumor_tissue
160 | patient_034,patient_034_node_4,normal_tissue
161 | patient_035,patient_035_node_0,normal_tissue
162 | patient_035,patient_035_node_1,normal_tissue
163 | patient_035,patient_035_node_2,normal_tissue
164 | patient_035,patient_035_node_3,normal_tissue
165 | patient_035,patient_035_node_4,normal_tissue
166 | patient_036,patient_036_node_0,normal_tissue
167 | patient_036,patient_036_node_1,normal_tissue
168 | patient_036,patient_036_node_2,normal_tissue
169 | patient_036,patient_036_node_3,tumor_tissue
170 | patient_036,patient_036_node_4,normal_tissue
171 | patient_037,patient_037_node_0,normal_tissue
172 | patient_037,patient_037_node_2,normal_tissue
173 | patient_037,patient_037_node_3,normal_tissue
174 | patient_037,patient_037_node_4,normal_tissue
175 | patient_038,patient_038_node_0,tumor_tissue
176 | patient_038,patient_038_node_1,normal_tissue
177 | patient_038,patient_038_node_2,tumor_tissue
178 | patient_038,patient_038_node_3,normal_tissue
179 | patient_039,patient_039_node_0,tumor_tissue
180 | patient_039,patient_039_node_1,tumor_tissue
181 | patient_039,patient_039_node_2,normal_tissue
182 | patient_039,patient_039_node_3,normal_tissue
183 | patient_039,patient_039_node_4,normal_tissue
184 | patient_040,patient_040_node_0,normal_tissue
185 | patient_040,patient_040_node_1,normal_tissue
186 | patient_040,patient_040_node_3,normal_tissue
187 | patient_040,patient_040_node_4,normal_tissue
188 | patient_041,patient_041_node_1,normal_tissue
189 | patient_041,patient_041_node_2,normal_tissue
190 | patient_041,patient_041_node_3,normal_tissue
191 | patient_041,patient_041_node_4,normal_tissue
192 | patient_042,patient_042_node_0,normal_tissue
193 | patient_042,patient_042_node_1,normal_tissue
194 | patient_042,patient_042_node_2,normal_tissue
195 | patient_042,patient_042_node_3,tumor_tissue
196 | patient_042,patient_042_node_4,normal_tissue
197 | patient_043,patient_043_node_0,normal_tissue
198 | patient_043,patient_043_node_1,normal_tissue
199 | patient_043,patient_043_node_2,normal_tissue
200 | patient_043,patient_043_node_3,normal_tissue
201 | patient_043,patient_043_node_4,normal_tissue
202 | patient_044,patient_044_node_0,normal_tissue
203 | patient_044,patient_044_node_1,normal_tissue
204 | patient_044,patient_044_node_2,normal_tissue
205 | patient_044,patient_044_node_3,tumor_tissue
206 | patient_044,patient_044_node_4,tumor_tissue
207 | patient_045,patient_045_node_0,tumor_tissue
208 | patient_045,patient_045_node_1,tumor_tissue
209 | patient_045,patient_045_node_2,normal_tissue
210 | patient_045,patient_045_node_3,tumor_tissue
211 | patient_045,patient_045_node_4,tumor_tissue
212 | patient_046,patient_046_node_0,normal_tissue
213 | patient_046,patient_046_node_1,normal_tissue
214 | patient_046,patient_046_node_2,normal_tissue
215 | patient_046,patient_046_node_3,tumor_tissue
216 | patient_046,patient_046_node_4,tumor_tissue
217 | patient_047,patient_047_node_0,normal_tissue
218 | patient_047,patient_047_node_1,normal_tissue
219 | patient_047,patient_047_node_2,tumor_tissue
220 | patient_047,patient_047_node_3,normal_tissue
221 | patient_047,patient_047_node_4,normal_tissue
222 | patient_048,patient_048_node_0,tumor_tissue
223 | patient_048,patient_048_node_1,tumor_tissue
224 | patient_048,patient_048_node_2,normal_tissue
225 | patient_048,patient_048_node_3,normal_tissue
226 | patient_048,patient_048_node_4,normal_tissue
227 | patient_049,patient_049_node_0,normal_tissue
228 | patient_049,patient_049_node_1,tumor_tissue
229 | patient_049,patient_049_node_2,tumor_tissue
230 | patient_049,patient_049_node_3,normal_tissue
231 | patient_049,patient_049_node_4,normal_tissue
232 | patient_050,patient_050_node_0,normal_tissue
233 | patient_050,patient_050_node_1,normal_tissue
234 | patient_050,patient_050_node_2,normal_tissue
235 | patient_050,patient_050_node_3,normal_tissue
236 | patient_050,patient_050_node_4,tumor_tissue
237 | patient_051,patient_051_node_0,tumor_tissue
238 | patient_051,patient_051_node_1,tumor_tissue
239 | patient_051,patient_051_node_2,tumor_tissue
240 | patient_051,patient_051_node_3,normal_tissue
241 | patient_051,patient_051_node_4,tumor_tissue
242 | patient_052,patient_052_node_0,tumor_tissue
243 | patient_052,patient_052_node_1,tumor_tissue
244 | patient_052,patient_052_node_2,tumor_tissue
245 | patient_052,patient_052_node_3,tumor_tissue
246 | patient_052,patient_052_node_4,normal_tissue
247 | patient_053,patient_053_node_0,normal_tissue
248 | patient_053,patient_053_node_1,normal_tissue
249 | patient_053,patient_053_node_2,normal_tissue
250 | patient_053,patient_053_node_3,normal_tissue
251 | patient_053,patient_053_node_4,normal_tissue
252 | patient_054,patient_054_node_0,normal_tissue
253 | patient_054,patient_054_node_1,normal_tissue
254 | patient_054,patient_054_node_2,normal_tissue
255 | patient_054,patient_054_node_3,normal_tissue
256 | patient_054,patient_054_node_4,normal_tissue
257 | patient_055,patient_055_node_0,normal_tissue
258 | patient_055,patient_055_node_1,normal_tissue
259 | patient_055,patient_055_node_2,normal_tissue
260 | patient_055,patient_055_node_3,normal_tissue
261 | patient_055,patient_055_node_4,normal_tissue
262 | patient_056,patient_056_node_0,normal_tissue
263 | patient_056,patient_056_node_1,normal_tissue
264 | patient_056,patient_056_node_2,normal_tissue
265 | patient_056,patient_056_node_3,normal_tissue
266 | patient_056,patient_056_node_4,normal_tissue
267 | patient_057,patient_057_node_0,normal_tissue
268 | patient_057,patient_057_node_1,normal_tissue
269 | patient_057,patient_057_node_2,normal_tissue
270 | patient_057,patient_057_node_3,normal_tissue
271 | patient_057,patient_057_node_4,normal_tissue
272 | patient_058,patient_058_node_0,normal_tissue
273 | patient_058,patient_058_node_1,normal_tissue
274 | patient_058,patient_058_node_2,normal_tissue
275 | patient_058,patient_058_node_3,normal_tissue
276 | patient_058,patient_058_node_4,normal_tissue
277 | patient_059,patient_059_node_0,normal_tissue
278 | patient_059,patient_059_node_1,normal_tissue
279 | patient_059,patient_059_node_2,normal_tissue
280 | patient_059,patient_059_node_3,normal_tissue
281 | patient_059,patient_059_node_4,normal_tissue
282 | patient_060,patient_060_node_0,tumor_tissue
283 | patient_060,patient_060_node_1,normal_tissue
284 | patient_060,patient_060_node_2,normal_tissue
285 | patient_060,patient_060_node_4,normal_tissue
286 | patient_061,patient_061_node_0,normal_tissue
287 | patient_061,patient_061_node_1,normal_tissue
288 | patient_061,patient_061_node_2,normal_tissue
289 | patient_061,patient_061_node_3,normal_tissue
290 | patient_061,patient_061_node_4,tumor_tissue
291 | patient_062,patient_062_node_0,tumor_tissue
292 | patient_062,patient_062_node_1,tumor_tissue
293 | patient_062,patient_062_node_2,tumor_tissue
294 | patient_062,patient_062_node_3,tumor_tissue
295 | patient_062,patient_062_node_4,normal_tissue
296 | patient_063,patient_063_node_0,tumor_tissue
297 | patient_063,patient_063_node_1,normal_tissue
298 | patient_063,patient_063_node_2,tumor_tissue
299 | patient_063,patient_063_node_3,tumor_tissue
300 | patient_063,patient_063_node_4,tumor_tissue
301 | patient_064,patient_064_node_1,normal_tissue
302 | patient_064,patient_064_node_2,normal_tissue
303 | patient_064,patient_064_node_4,normal_tissue
304 | patient_065,patient_065_node_0,tumor_tissue
305 | patient_065,patient_065_node_1,normal_tissue
306 | patient_065,patient_065_node_2,normal_tissue
307 | patient_065,patient_065_node_3,normal_tissue
308 | patient_065,patient_065_node_4,tumor_tissue
309 | patient_066,patient_066_node_0,normal_tissue
310 | patient_066,patient_066_node_1,normal_tissue
311 | patient_066,patient_066_node_2,tumor_tissue
312 | patient_066,patient_066_node_3,tumor_tissue
313 | patient_066,patient_066_node_4,normal_tissue
314 | patient_067,patient_067_node_0,normal_tissue
315 | patient_067,patient_067_node_1,normal_tissue
316 | patient_067,patient_067_node_2,normal_tissue
317 | patient_067,patient_067_node_4,tumor_tissue
318 | patient_068,patient_068_node_0,normal_tissue
319 | patient_068,patient_068_node_2,normal_tissue
320 | patient_068,patient_068_node_3,normal_tissue
321 | patient_069,patient_069_node_0,tumor_tissue
322 | patient_069,patient_069_node_1,tumor_tissue
323 | patient_069,patient_069_node_2,tumor_tissue
324 | patient_069,patient_069_node_3,tumor_tissue
325 | patient_069,patient_069_node_4,normal_tissue
326 | patient_070,patient_070_node_0,normal_tissue
327 | patient_070,patient_070_node_1,normal_tissue
328 | patient_070,patient_070_node_2,normal_tissue
329 | patient_070,patient_070_node_3,tumor_tissue
330 | patient_070,patient_070_node_4,normal_tissue
331 | patient_071,patient_071_node_0,normal_tissue
332 | patient_071,patient_071_node_1,normal_tissue
333 | patient_071,patient_071_node_2,normal_tissue
334 | patient_071,patient_071_node_3,normal_tissue
335 | patient_071,patient_071_node_4,normal_tissue
336 | patient_072,patient_072_node_1,tumor_tissue
337 | patient_072,patient_072_node_2,tumor_tissue
338 | patient_072,patient_072_node_3,normal_tissue
339 | patient_072,patient_072_node_4,tumor_tissue
340 | patient_073,patient_073_node_0,normal_tissue
341 | patient_073,patient_073_node_1,tumor_tissue
342 | patient_073,patient_073_node_2,normal_tissue
343 | patient_073,patient_073_node_3,tumor_tissue
344 | patient_073,patient_073_node_4,normal_tissue
345 | patient_074,patient_074_node_0,normal_tissue
346 | patient_074,patient_074_node_1,normal_tissue
347 | patient_074,patient_074_node_2,normal_tissue
348 | patient_074,patient_074_node_3,normal_tissue
349 | patient_074,patient_074_node_4,tumor_tissue
350 | patient_075,patient_075_node_0,normal_tissue
351 | patient_075,patient_075_node_1,normal_tissue
352 | patient_075,patient_075_node_2,normal_tissue
353 | patient_075,patient_075_node_3,normal_tissue
354 | patient_075,patient_075_node_4,tumor_tissue
355 | patient_076,patient_076_node_0,normal_tissue
356 | patient_076,patient_076_node_1,tumor_tissue
357 | patient_076,patient_076_node_2,tumor_tissue
358 | patient_076,patient_076_node_3,tumor_tissue
359 | patient_077,patient_077_node_0,normal_tissue
360 | patient_077,patient_077_node_1,normal_tissue
361 | patient_077,patient_077_node_2,tumor_tissue
362 | patient_077,patient_077_node_3,normal_tissue
363 | patient_077,patient_077_node_4,tumor_tissue
364 | patient_078,patient_078_node_0,normal_tissue
365 | patient_078,patient_078_node_1,normal_tissue
366 | patient_078,patient_078_node_2,normal_tissue
367 | patient_078,patient_078_node_3,normal_tissue
368 | patient_078,patient_078_node_4,normal_tissue
369 | patient_079,patient_079_node_0,normal_tissue
370 | patient_079,patient_079_node_1,normal_tissue
371 | patient_079,patient_079_node_2,normal_tissue
372 | patient_079,patient_079_node_3,normal_tissue
373 | patient_079,patient_079_node_4,normal_tissue
374 | patient_080,patient_080_node_2,tumor_tissue
375 | patient_080,patient_080_node_3,tumor_tissue
376 | patient_080,patient_080_node_4,tumor_tissue
377 | patient_081,patient_081_node_0,normal_tissue
378 | patient_081,patient_081_node_1,tumor_tissue
379 | patient_081,patient_081_node_2,tumor_tissue
380 | patient_081,patient_081_node_3,normal_tissue
381 | patient_082,patient_082_node_0,normal_tissue
382 | patient_082,patient_082_node_1,normal_tissue
383 | patient_082,patient_082_node_2,normal_tissue
384 | patient_082,patient_082_node_3,normal_tissue
385 | patient_082,patient_082_node_4,normal_tissue
386 | patient_083,patient_083_node_0,normal_tissue
387 | patient_083,patient_083_node_1,normal_tissue
388 | patient_083,patient_083_node_2,normal_tissue
389 | patient_083,patient_083_node_3,normal_tissue
390 | patient_083,patient_083_node_4,normal_tissue
391 | patient_084,patient_084_node_0,normal_tissue
392 | patient_084,patient_084_node_1,normal_tissue
393 | patient_084,patient_084_node_2,tumor_tissue
394 | patient_084,patient_084_node_3,tumor_tissue
395 | patient_084,patient_084_node_4,normal_tissue
396 | patient_085,patient_085_node_0,normal_tissue
397 | patient_085,patient_085_node_1,normal_tissue
398 | patient_085,patient_085_node_2,normal_tissue
399 | patient_085,patient_085_node_3,normal_tissue
400 | patient_085,patient_085_node_4,normal_tissue
401 | patient_086,patient_086_node_1,normal_tissue
402 | patient_086,patient_086_node_2,normal_tissue
403 | patient_086,patient_086_node_3,normal_tissue
404 | patient_087,patient_087_node_2,normal_tissue
405 | patient_087,patient_087_node_3,normal_tissue
406 | patient_087,patient_087_node_4,normal_tissue
407 | patient_088,patient_088_node_0,normal_tissue
408 | patient_088,patient_088_node_1,tumor_tissue
409 | patient_088,patient_088_node_2,normal_tissue
410 | patient_088,patient_088_node_3,normal_tissue
411 | patient_089,patient_089_node_0,normal_tissue
412 | patient_089,patient_089_node_1,normal_tissue
413 | patient_089,patient_089_node_2,normal_tissue
414 | patient_089,patient_089_node_3,tumor_tissue
415 | patient_089,patient_089_node_4,normal_tissue
416 | patient_090,patient_090_node_0,normal_tissue
417 | patient_090,patient_090_node_1,normal_tissue
418 | patient_090,patient_090_node_2,normal_tissue
419 | patient_090,patient_090_node_3,normal_tissue
420 | patient_090,patient_090_node_4,normal_tissue
421 | patient_091,patient_091_node_0,normal_tissue
422 | patient_091,patient_091_node_1,normal_tissue
423 | patient_091,patient_091_node_2,tumor_tissue
424 | patient_091,patient_091_node_3,tumor_tissue
425 | patient_091,patient_091_node_4,tumor_tissue
426 | patient_092,patient_092_node_0,tumor_tissue
427 | patient_092,patient_092_node_1,tumor_tissue
428 | patient_092,patient_092_node_2,normal_tissue
429 | patient_092,patient_092_node_3,tumor_tissue
430 | patient_092,patient_092_node_4,tumor_tissue
431 | patient_093,patient_093_node_0,normal_tissue
432 | patient_093,patient_093_node_1,normal_tissue
433 | patient_093,patient_093_node_2,normal_tissue
434 | patient_093,patient_093_node_3,normal_tissue
435 | patient_093,patient_093_node_4,tumor_tissue
436 | patient_094,patient_094_node_0,tumor_tissue
437 | patient_094,patient_094_node_1,tumor_tissue
438 | patient_094,patient_094_node_2,tumor_tissue
439 | patient_094,patient_094_node_3,normal_tissue
440 | patient_094,patient_094_node_4,tumor_tissue
441 | patient_095,patient_095_node_0,tumor_tissue
442 | patient_095,patient_095_node_1,normal_tissue
443 | patient_095,patient_095_node_2,normal_tissue
444 | patient_095,patient_095_node_3,normal_tissue
445 | patient_095,patient_095_node_4,normal_tissue
446 | patient_096,patient_096_node_0,tumor_tissue
447 | patient_096,patient_096_node_1,normal_tissue
448 | patient_096,patient_096_node_2,tumor_tissue
449 | patient_096,patient_096_node_3,tumor_tissue
450 | patient_096,patient_096_node_4,tumor_tissue
451 | patient_097,patient_097_node_0,tumor_tissue
452 | patient_097,patient_097_node_1,tumor_tissue
453 | patient_097,patient_097_node_2,normal_tissue
454 | patient_097,patient_097_node_3,tumor_tissue
455 | patient_097,patient_097_node_4,tumor_tissue
456 | patient_098,patient_098_node_0,normal_tissue
457 | patient_098,patient_098_node_1,normal_tissue
458 | patient_098,patient_098_node_2,normal_tissue
459 | patient_098,patient_098_node_3,normal_tissue
460 | patient_098,patient_098_node_4,normal_tissue
461 | patient_099,patient_099_node_0,normal_tissue
462 | patient_099,patient_099_node_1,normal_tissue
463 | patient_099,patient_099_node_2,normal_tissue
464 | patient_099,patient_099_node_3,normal_tissue
465 | patient_099,patient_099_node_4,tumor_tissue
466 |
--------------------------------------------------------------------------------