├── figure2.png
├── figure3.jpg
├── models
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── alexnet.cpython-36.pyc
│   │   ├── resnetv1.cpython-36.pyc
│   │   └── resnetv2.cpython-36.pyc
│   ├── __init__.py
│   ├── resnetv2.py
│   └── alexnet.py
├── README.md
├── Datasets_loader
│   ├── dataset_CAMELYON16_BasedOnFeat.py
│   ├── dataset_CAMELYON16.py
│   ├── dataset_TCGA_LungCancer.py
│   ├── dataset_MIL_CIFAR.py
│   ├── dataset_CervicalCancer.py
│   └── dataset_MIL_NCTCRCHE.py
├── util.py
├── utliz.py
├── train_TCGAFeat_BagDistillationDSMIL_SharedEnc_Similarity_StuFilterSmoothed_DropPos.py
├── train_CervicalFeat_BagDistillationDSMIL_SharedEnc_Similarity_StuFilterSmoothed_DropPos.py
└── train_CAMELYONFeat_BagDistillationDSMIL_SharedEnc_Similarity_StuFilterSmoothed_DropPos.py
/figure2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/miccaiif/WENO/HEAD/figure2.png
--------------------------------------------------------------------------------
/figure3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/miccaiif/WENO/HEAD/figure3.jpg
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/miccaiif/WENO/HEAD/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/alexnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/miccaiif/WENO/HEAD/models/__pycache__/alexnet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnetv1.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/miccaiif/WENO/HEAD/models/__pycache__/resnetv1.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnetv2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/miccaiif/WENO/HEAD/models/__pycache__/resnetv2.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | from .alexnet import *
8 | from .resnetv2 import *
9 | from .resnetv1 import *
10 |
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # :camel: WENO
2 | Official PyTorch implementation of our NeurIPS 2022 paper: **[Bi-directional Weakly Supervised Knowledge Distillation for Whole Slide Image Classification](https://arxiv.org/abs/2210.03664)**. We propose an end-to-end weakly supervised knowledge distillation framework (**WENO**) for WSI classification, which integrates a bag classifier and an instance classifier in a knowledge distillation framework to mutually improve the performance of both classifiers. WENO is a plug-and-play framework that can be easily applied to any existing attention-based bag classification methods.
3 |
4 |
5 |
6 |
7 |
8 | ### Frequently Asked Questions
9 |
10 | * Regarding the preprocessing
11 |
12 | For specific preprocessing, the settings of different MIL experiments vary (patch size, magnification, etc.), so patching should be performed according to your own experimental setup. The [DSMIL](https://github.com/binli123/dsmil-wsi) repository provides a good example for reference (and is also referenced in this work). Because uploading all of the extracted feature files would require substantial time and storage, we have open-sourced the main and key code and models; together with the training details in the paper, they are sufficient to reproduce this work. Thank you again for your attention! You are welcome to contact and cite us!
13 |
14 | ### Citation
15 | If this work is helpful to you, please cite it as:
16 | ```
17 | @article{qu2022bi,
18 | title={Bi-directional weakly supervised knowledge distillation for whole slide image classification},
19 | author={Qu, Linhao and Wang, Manning and Song, Zhijian and others},
20 | journal={Advances in Neural Information Processing Systems},
21 | volume={35},
22 | pages={15368--15381},
23 | year={2022}
24 | }
25 | ```
26 |
27 | ### Contact Information
28 | If you have any questions, please email me at [lhqu20@fudan.edu.cn](mailto:lhqu20@fudan.edu.cn).
29 |
--------------------------------------------------------------------------------
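
As a companion to the FAQ above, here is a minimal patching sketch in the spirit of the DSMIL-style preprocessing the README points to. It is illustrative only, not the repository's actual preprocessing script: the OpenSlide level, tile size, background threshold, and output layout are all assumptions to adapt to your own experimental settings.

# Hypothetical WSI patching sketch (not part of this repository).
# Requires openslide-python; level, tile_size and bg_thresh are placeholders.
import os
import numpy as np
import openslide

def tile_slide(slide_path, out_dir, tile_size=224, level=1, bg_thresh=0.8):
    """Cut one whole slide image into non-overlapping tiles, skipping background."""
    slide = openslide.OpenSlide(slide_path)
    w, h = slide.level_dimensions[level]
    scale = int(slide.level_downsamples[level])
    os.makedirs(out_dir, exist_ok=True)
    for y in range(0, h - tile_size + 1, tile_size):
        for x in range(0, w - tile_size + 1, tile_size):
            # read_region expects level-0 coordinates
            tile = slide.read_region((x * scale, y * scale), level,
                                     (tile_size, tile_size)).convert('RGB')
            arr = np.asarray(tile)
            if (arr.mean(axis=2) > 220).mean() > bg_thresh:
                continue  # skip mostly-white background tiles
            tile.save(os.path.join(out_dir, "{}_{}.jpg".format(x, y)))
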
/models/resnetv2.py:
--------------------------------------------------------------------------------
1 | """ Pre-activation ResNet in PyTorch.
2 | also called ResNet v2.
3 |
4 | adapted from https://github.com/kuangliu/pytorch-cifar/edit/master/models/preact_resnet.py
5 | Reference:
6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027
8 | """
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import os
13 |
14 | __all__ = ['resnetv2']
15 |
16 | class PreActBottleneck(nn.Module):
17 | '''Pre-activation version of the original Bottleneck module.'''
18 | def __init__(self, in_planes, planes, stride=1, expansion=4):
19 | super(PreActBottleneck, self).__init__()
20 | self.expansion = expansion
21 | self.bn1 = nn.BatchNorm2d(in_planes)
22 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
23 | self.bn2 = nn.BatchNorm2d(planes)
24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
25 | self.bn3 = nn.BatchNorm2d(planes)
26 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
27 |
28 | if stride != 1 or in_planes != self.expansion*planes:
29 | self.shortcut = nn.Sequential(
30 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
31 | )
32 |
33 | def forward(self, x):
34 | out = F.relu(self.bn1(x))
35 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
36 | out = self.conv1(out)
37 | out = self.conv2(F.relu(self.bn2(out)))
38 | out = self.conv3(F.relu(self.bn3(out)))
39 | out += shortcut
40 | return out
41 |
42 |
43 | class PreActResNet(nn.Module):
44 | def __init__(self, block, num_blocks, num_classes=[10], expansion=4):  # num_classes is a list of head sizes
45 | super(PreActResNet, self).__init__()
46 | self.in_planes = 16*expansion
47 |
48 | self.features = nn.Sequential(*[
49 | nn.Conv2d(3, self.in_planes, kernel_size=7, stride=2, padding=3, bias=False),
50 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
51 | self._make_layer(block, 16*expansion, num_blocks[0], stride=1, expansion=4),
52 | self._make_layer(block, 2*16*expansion, num_blocks[1], stride=2, expansion=4),
53 | self._make_layer(block, 4*16*expansion, num_blocks[2], stride=2, expansion=4),
54 | self._make_layer(block, 8*16*expansion, num_blocks[3], stride=2, expansion=4),
55 | nn.AdaptiveAvgPool2d((1, 1))
56 | ])
57 | self.headcount = len(num_classes)
58 | if len(num_classes) == 1:
59 | self.top_layer = nn.Sequential(*[nn.Linear(512*expansion, num_classes[0])]) # for later compatib.
60 | else:
61 | for a,i in enumerate(num_classes):
62 | setattr(self, "top_layer%d" % a, nn.Linear(512*expansion, i))
63 | self.top_layer = None # this way headcount can act as switch.
64 |
65 | def _make_layer(self, block, planes, num_blocks, stride, expansion):
66 | strides = [stride] + [1]*(num_blocks-1)
67 | layers = []
68 | for stride in strides:
69 | layers.append(block(self.in_planes, planes, stride, expansion))
70 | self.in_planes = planes * expansion
71 | return nn.Sequential(*layers)
72 |
73 | def forward(self, x):
74 | out = self.features(x)
75 | out = out.view(out.size(0), -1)
76 | if self.headcount == 1:
77 | if self.top_layer:
78 | out = self.top_layer(out)
79 | return out
80 | else:
81 | outp = []
82 | for i in range(self.headcount):
83 | outp.append(getattr(self, "top_layer%d" % i)(out))
84 | return outp
85 |
86 |
87 |
88 | def PreActResNet50(num_classes):
89 | return PreActResNet(PreActBottleneck, [3,4,6,3],num_classes)
90 |
91 | def resnetv2(nlayers=50, num_classes=[1000], expansion=1):
92 | if nlayers == 50:
93 | return PreActResNet(PreActBottleneck, [3,4,6,3], num_classes, expansion=4*expansion)
94 | else:
95 | raise NotImplementedError
96 |
97 |
98 | if __name__ == '__main__':
99 | import torch
100 | model = resnetv2(num_classes=[500]*3)
101 | print([ k.shape for k in model(torch.randn(64,3,224,224))])
102 |
--------------------------------------------------------------------------------
/Datasets_loader/dataset_CAMELYON16_BasedOnFeat.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import torch.utils.data as data_utils
5 | from torchvision import datasets, transforms
6 | from PIL import Image
7 | import os
8 | import glob
9 | from skimage import io
10 | from tqdm import tqdm
11 |
12 |
13 | def statistics_slide(slide_path_list):
14 | num_pos_patch_allPosSlide = 0
15 | num_patch_allPosSlide = 0
16 | num_neg_patch_allNegSlide = 0
17 | num_all_slide = len(slide_path_list)
18 |
19 | for i in slide_path_list:
20 | if 'pos' in i.split('/')[-1]: # pos slide
21 | num_pos_patch = len(glob.glob(i + "/*_pos.jpg"))
22 | num_patch = len(glob.glob(i + "/*.jpg"))
23 | num_pos_patch_allPosSlide = num_pos_patch_allPosSlide + num_pos_patch
24 | num_patch_allPosSlide = num_patch_allPosSlide + num_patch
25 | else: # neg slide
26 | num_neg_patch = len(glob.glob(i + "/*.jpg"))
27 | num_neg_patch_allNegSlide = num_neg_patch_allNegSlide + num_neg_patch
28 |
29 | print("[DATA INFO] {} slides totally".format(num_all_slide))
30 | print("[DATA INFO] pos_patch_ratio in pos slide: {:.4f}({}/{})".format(
31 | num_pos_patch_allPosSlide / num_patch_allPosSlide, num_pos_patch_allPosSlide, num_patch_allPosSlide))
32 | print("[DATA INFO] num of patches: {} ({} from pos slide, {} from neg slide)".format(
33 | num_patch_allPosSlide+num_neg_patch_allNegSlide, num_patch_allPosSlide, num_neg_patch_allNegSlide))
34 | return num_patch_allPosSlide+num_neg_patch_allNegSlide
35 |
36 |
37 | class CAMELYON_16_feat(torch.utils.data.Dataset):
38 | # @profile
39 | def __init__(self, root_dir='',
40 | train=True, transform=None, downsample=1.0, drop_threshold=0.0, preload=True, return_bag=False):
41 | self.root_dir = root_dir
42 | self.train = train
43 | self.transform = transform
44 | self.downsample = downsample
45 | self.drop_threshold = drop_threshold # drop pos slides whose positive patch ratio is below this threshold
46 | self.preload = preload
47 | self.return_bag = return_bag
48 | if train:
49 | self.root_dir = os.path.join(self.root_dir, "training")
50 | else:
51 | self.root_dir = os.path.join(self.root_dir, "testing")
52 |
53 | all_slides = glob.glob(self.root_dir + "/*")
54 | # 1. filter out pos slides that contain no pos patches (ratio <= drop_threshold)
55 | all_pos_slides = glob.glob(self.root_dir + "/*_pos")
56 |
57 | for i in all_pos_slides:
58 | num_pos_patch = len(glob.glob(i + "/*_pos.jpg"))
59 | num_patch = len(glob.glob(i + "/*.jpg"))
60 | if num_pos_patch/num_patch <= self.drop_threshold:
61 | all_slides.remove(i)
62 | print("[DATA] {} of positive patch ratio {:.4f}({}/{}) is removed".format(
63 | i, num_pos_patch/num_patch, num_pos_patch, num_patch))
64 | statistics_slide(all_slides)
65 | # 1.1 down sample the slides
66 | print("================ Down sample ================")
67 | np.random.shuffle(all_slides)
68 | all_slides = all_slides[:int(len(all_slides)*self.downsample)]
69 | self.num_slides = len(all_slides)
70 | self.num_patches = statistics_slide(all_slides)
71 |
72 | # 2. load all pre-trained patch features (by SimCLR in DSMIL)
73 | all_slides_name = [i.split('/')[-1] for i in all_slides]
74 | if train:
75 | all_slides_feat_file = glob.glob("")
76 | else:
77 | all_slides_feat_file = glob.glob("")
78 |
79 | self.slide_feat_all = np.zeros([self.num_patches, 512], dtype=np.float32)
80 | self.slide_patch_label_all = np.zeros([self.num_patches], dtype=np.long)
81 | self.patch_corresponding_slide_label = np.zeros([self.num_patches], dtype=np.long)
82 | self.patch_corresponding_slide_index = np.zeros([self.num_patches], dtype=np.long)
84 | self.patch_corresponding_slide_name = np.zeros([self.num_patches], dtype='<U64')  # string dtype; the original width is truncated in this dump
--------------------------------------------------------------------------------
/Datasets_loader/dataset_TCGA_LungCancer.py:
--------------------------------------------------------------------------------
61 | if len(all_patches_slide_i) > 3:  # guard reconstructed from context; the dump fused two files here and lost the lines between
62 | all_patches_slide_i = sample(all_patches_slide_i, int(len(all_patches_slide_i)*self.patch_downsample))
63 | for j in all_patches_slide_i:
64 | if self.preload:
65 | self.all_patches[cnt_patch, :, :, :] = io.imread(j)
66 | else:
67 | self.all_patches.append(j)
68 | self.patch_corresponding_slide_label.append(int('LUSC' in i.split('/')[-4]))
69 | self.patch_corresponding_slide_index.append(cnt_slide)
70 | self.patch_corresponding_slide_name.append(i.split('/')[-1])
71 | cnt_patch = cnt_patch + 1
72 | cnt_slide = cnt_slide + 1
73 | if not self.preload:
74 | self.all_patches = np.array(self.all_patches)
75 | self.patch_corresponding_slide_label = np.array(self.patch_corresponding_slide_label)
76 | self.patch_corresponding_slide_index = np.array(self.patch_corresponding_slide_index)
77 | self.patch_corresponding_slide_name = np.array(self.patch_corresponding_slide_name)
78 |
79 | self.num_patches = len(self.all_patches)
80 | # 3.do some statistics
81 | print("[DATA INFO] num_slide is {}; num_patches is {}\n".format(self.num_slides, self.num_patches))
82 |
83 | def __getitem__(self, index):
84 | if self.preload:
85 | patch_image = self.all_patches[index]
86 | else:
87 | patch_image = io.imread(self.all_patches[index])
88 | patch_corresponding_slide_label = self.patch_corresponding_slide_label[index]
89 | patch_corresponding_slide_index = self.patch_corresponding_slide_index[index]
90 | patch_corresponding_slide_name = self.patch_corresponding_slide_name[index]
91 |
92 | patch_image = self.transform(Image.fromarray(np.uint8(patch_image), 'RGB'))
93 | patch_label = 0 # patch_label is not available in TCGA
94 | return patch_image, [patch_label, patch_corresponding_slide_label, patch_corresponding_slide_index,
95 | patch_corresponding_slide_name], index
96 |
97 | def __len__(self):
98 | return self.num_patches
99 |
100 |
101 | class TCGA_LungCancer_Feat(torch.utils.data.Dataset):
102 | # @profile
103 | def __init__(self, train=True, downsample=1.0, return_bag=False):
104 | self.train = train
105 | self.return_bag = return_bag
106 | bags_csv = ''
107 | bags_path = pd.read_csv(bags_csv)
108 | train_path = bags_path.iloc[0:int(len(bags_path) * 0.8), :]
109 | test_path = bags_path.iloc[int(len(bags_path) * 0.8):, :]
110 | train_path = shuffle(train_path).reset_index(drop=True)
111 | test_path = shuffle(test_path).reset_index(drop=True)
112 |
113 | if downsample < 1.0:
114 | train_path = train_path.iloc[0:int(len(train_path) * downsample), :]
115 | test_path = test_path.iloc[0:int(len(test_path) * downsample), :]
116 |
117 | self.patch_feat_all = []
118 | self.patch_corresponding_slide_label = []
119 | self.patch_corresponding_slide_index = []
120 | self.patch_corresponding_slide_name = []
121 | if self.train:
122 | for i in tqdm(range(len(train_path)), desc='loading data'):
123 | label, feats = get_bag_feats(train_path.iloc[i])
124 | self.patch_feat_all.append(feats)
125 | self.patch_corresponding_slide_label.append(np.ones(feats.shape[0]) * label)
126 | self.patch_corresponding_slide_index.append(np.ones(feats.shape[0]) * i)
127 | self.patch_corresponding_slide_name.append(np.ones(feats.shape[0]) * i)
128 | else:
129 | for i in tqdm(range(len(test_path)), desc='loading data'):
130 | label, feats = get_bag_feats(test_path.iloc[i])
131 | self.patch_feat_all.append(feats)
132 | self.patch_corresponding_slide_label.append(np.ones(feats.shape[0]) * label)
133 | self.patch_corresponding_slide_index.append(np.ones(feats.shape[0]) * i)
134 | self.patch_corresponding_slide_name.append(np.ones(feats.shape[0]) * i)
135 |
136 | self.patch_feat_all = np.concatenate(self.patch_feat_all, axis=0).astype(np.float32)
137 | self.patch_corresponding_slide_label = np.concatenate(self.patch_corresponding_slide_label).astype(np.long)
138 | self.patch_corresponding_slide_index = np.concatenate(self.patch_corresponding_slide_index).astype(np.long)
139 | self.patch_corresponding_slide_name = np.concatenate(self.patch_corresponding_slide_name)
140 |
141 | self.num_patches = self.patch_feat_all.shape[0]
142 | self.patch_label_all = np.zeros([self.patch_feat_all.shape[0]], dtype=np.long) # Patch label is not available and set to 0 !
143 | # 3.do some statistics
144 | print("[DATA INFO] num_slide is {}; num_patches is {}\n".format(len(train_path) if self.train else len(test_path), self.num_patches))
145 |
146 | def __getitem__(self, index):
147 | if self.return_bag:
148 | idx_patch_from_slide_i = np.where(self.patch_corresponding_slide_index == index)[0]
149 | bag = self.patch_feat_all[idx_patch_from_slide_i, :]
150 | patch_labels = self.patch_label_all[idx_patch_from_slide_i] # Patch label is not available and set to 0 !
151 | slide_label = self.patch_corresponding_slide_label[idx_patch_from_slide_i][0]
152 | slide_index = self.patch_corresponding_slide_index[idx_patch_from_slide_i][0]
153 | slide_name = self.patch_corresponding_slide_name[idx_patch_from_slide_i][0]
154 |
155 | # check data
156 | if self.patch_corresponding_slide_label[idx_patch_from_slide_i].max() != self.patch_corresponding_slide_label[idx_patch_from_slide_i].min():
157 | raise ValueError("inconsistent slide labels within bag {}".format(index))
158 | if self.patch_corresponding_slide_index[idx_patch_from_slide_i].max() != self.patch_corresponding_slide_index[idx_patch_from_slide_i].min():
159 | raise ValueError("inconsistent slide indices within bag {}".format(index))
160 | return bag, [patch_labels, slide_label, slide_index, slide_name], index
161 | else:
162 | patch_image = self.patch_feat_all[index]
163 | patch_corresponding_slide_label = self.patch_corresponding_slide_label[index]
164 | patch_corresponding_slide_index = self.patch_corresponding_slide_index[index]
165 | patch_corresponding_slide_name = self.patch_corresponding_slide_name[index]
166 |
167 | patch_label = self.patch_label_all[index] # Patch label is not available and set to 0 !
168 | return patch_image, [patch_label, patch_corresponding_slide_label, patch_corresponding_slide_index,
169 | patch_corresponding_slide_name], index
170 |
171 | def __len__(self):
172 | if self.return_bag:
173 | return self.patch_corresponding_slide_index.max() + 1
174 | else:
175 | return self.num_patches
176 |
177 |
178 | def get_bag_feats(csv_file_df):
179 | feats_csv_path = '' + csv_file_df.iloc[0].split('/')[1] + '.csv'
180 | df = pd.read_csv(feats_csv_path)
181 | feats = shuffle(df).reset_index(drop=True)
182 | feats = feats.to_numpy()
183 | label = np.zeros(1)
184 | label[0] = csv_file_df.iloc[1]
185 | return label, feats
186 |
187 |
188 | if __name__ == '__main__':
189 | train_ds_feat = TCGA_LungCancer_Feat(train=True, downsample=1.0)
190 | test_ds_feat = TCGA_LungCancer_Feat(train=False, downsample=1.0)
191 |
192 | trans = transforms.Compose([
193 | transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2),
194 | transforms.RandomHorizontalFlip(p=0.5),
195 | transforms.RandomVerticalFlip(p=0.5),
196 | transforms.RandomRotation(degrees=90),
197 | transforms.ToTensor()])
198 | train_ds = TCGA_LungCancer(train=True, transform=None, downsample=0.1, drop_threshold=0, preload=False)
199 | val_ds = TCGA_LungCancer(train=False, transform=None, downsample=0.1, drop_threshold=0, preload=False)
200 | train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64,
201 | shuffle=True, num_workers=0, drop_last=False, pin_memory=True)
202 | val_loader = torch.utils.data.DataLoader(val_ds, batch_size=64,
203 | shuffle=False, num_workers=0, drop_last=False, pin_memory=True)
204 | for data in tqdm(train_loader, desc='loading'):
205 | patch_img = data[0]
206 | label_patch = data[1][0]
207 | label_bag = data[1][1]
208 | idx = data[-1]
209 | print("END")
210 |
--------------------------------------------------------------------------------
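
Because the datasets above return one variable-length bag per slide when return_bag=True, the DataLoader must use batch_size=1. A minimal loading sketch, assuming the blank feature paths inside TCGA_LungCancer_Feat have been filled in:

import torch
from Datasets_loader.dataset_TCGA_LungCancer import TCGA_LungCancer_Feat

train_ds = TCGA_LungCancer_Feat(train=True, downsample=1.0, return_bag=True)
loader = torch.utils.data.DataLoader(train_ds, batch_size=1, shuffle=True)
for bag, (patch_labels, slide_label, slide_index, slide_name), idx in loader:
    bag = bag.squeeze(0)  # [num_patches_in_slide, feat_dim]; leading dim is the batch of size 1
    # slide_label is the bag-level label; patch_labels are all 0 (unavailable for TCGA)
    break
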
/Datasets_loader/dataset_MIL_CIFAR.py:
--------------------------------------------------------------------------------
1 | """Pytorch Dataset object that loads perfectly balanced MNIST dataset in bag form."""
2 |
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import torch
6 | import torch.utils.data as data_utils
7 | from torchvision import datasets, transforms
8 | from PIL import Image
9 |
10 | import numpy as np
11 | from six.moves import cPickle as pickle
12 | import os
13 | import platform
14 | classes = ('plane', 'car', 'bird', 'cat',
15 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
16 |
17 | img_rows, img_cols = 32, 32
18 | input_shape = (img_rows, img_cols, 3)
19 |
20 |
21 | def load_pickle(f):
22 | version = platform.python_version_tuple()
23 | if version[0] == '2':
24 | return pickle.load(f)
25 | elif version[0] == '3':
26 | return pickle.load(f, encoding='latin1')
27 | raise ValueError("invalid python version: {}".format(version))
28 |
29 |
30 | def load_CIFAR_batch(filename):
31 | """ load single batch of cifar """
32 | with open(filename, 'rb') as f:
33 | datadict = load_pickle(f)
34 | X = datadict['data']
35 | Y = datadict['labels']
36 | X = X.reshape(10000, 3, 32, 32)
37 | Y = np.array(Y)
38 | return X, Y
39 |
40 |
41 | def load_CIFAR10(ROOT):
42 | """ load all of cifar """
43 | xs = []
44 | ys = []
45 | for b in range(1,6):
46 | f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
47 | X, Y = load_CIFAR_batch(f)
48 | xs.append(X)
49 | ys.append(Y)
50 | Xtr = np.concatenate(xs, axis=0)
51 | Ytr = np.concatenate(ys)
52 | del X, Y
53 | Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
54 | return Xtr, Ytr, Xte, Yte
55 |
56 |
57 | def get_CIFAR10_data():
58 | # Load the raw CIFAR-10 data
59 | cifar10_dir = ''
60 | x_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
61 | x_train, y_train, X_test, y_test = torch.from_numpy(x_train), torch.from_numpy(y_train), \
62 | torch.from_numpy(X_test), torch.from_numpy(y_test)
63 | return x_train, y_train, X_test, y_test
64 |
65 |
66 | def random_shuffle(input_tensor):
67 | length = input_tensor.shape[0]
68 | random_idx = torch.randperm(length)
69 | output_tensor = input_tensor[random_idx]
70 | return output_tensor
71 |
72 |
73 | class CIFAR_WholeSlide_challenge(torch.utils.data.Dataset):
74 | def __init__(self, train, positive_num=[8, 9], negative_num=[0, 1, 2, 3, 4, 5, 6, 7],
75 | bag_length=10, return_bag=False, num_img_per_slide=600, pos_patch_ratio=0.1, pos_slide_ratio=0.5, transform=None, accompanyPos=True):
76 | self.train = train
77 | self.positive_num = positive_num # transform the N-class into 2-class
78 | self.negative_num = negative_num # transform the N-class into 2-class
79 | self.bag_length = bag_length
80 | self.return_bag = return_bag # return patch or bag
81 | self.transform = transform # transform the patch image
82 | self.num_img_per_slide = num_img_per_slide
83 |
84 | if train:
85 | self.ds_data, self.ds_label, _, _ = get_CIFAR10_data()
86 | try:
87 | self.ds_data_simCLR_feat = torch.from_numpy(np.load("./Datasets_loader/all_feats_CIFAR.npy")[:50000, :]).float()
88 | print("Pre-trained feat found")
89 | except FileNotFoundError:
90 | print("No pre-trained feat found")
91 | else:
92 | _, _ , self.ds_data, self.ds_label = get_CIFAR10_data()
93 | try:
94 | self.ds_data_simCLR_feat = torch.from_numpy(np.load("./Datasets_loader/all_feats_CIFAR.npy")[50000:, :]).float()
95 | print("Pre-trained feat found")
96 | except FileNotFoundError:
97 | print("No pre-trained feat found")
98 |
99 | self.build_whole_slides(num_img=num_img_per_slide, positive_nums=positive_num, negative_nums=negative_num, pos_patch_ratio=pos_patch_ratio, pos_slide_ratio=pos_slide_ratio)
100 | print("")
101 |
102 | def build_whole_slides(self, num_img, positive_nums, negative_nums, pos_patch_ratio=0.1, pos_slide_ratio=0.5):
103 | # num_img: number of images per slide
104 | # pos_patch_ratio: positive patch ratio in each slide
105 |
106 | num_pos_per_slide = int(num_img * pos_patch_ratio)
107 | num_neg_per_slide = num_img - num_pos_per_slide
108 |
109 | idx_pos = []
110 | for num in positive_nums:
111 | idx_pos.append(torch.where(self.ds_label == num)[0])
112 | idx_pos = torch.cat(idx_pos).unsqueeze(1)
113 | idx_neg = []
114 | for num in negative_nums:
115 | idx_neg.append(torch.where(self.ds_label == num)[0])
116 | idx_neg = torch.cat(idx_neg).unsqueeze(1)
117 |
118 | idx_pos = random_shuffle(idx_pos)
119 | idx_neg = random_shuffle(idx_neg)
120 |
121 | # build pos slides using the patch budget calculated above
122 | num_pos_2PosSlides = int(idx_neg.numel() // ((1 - pos_slide_ratio) / (pos_patch_ratio*pos_slide_ratio) + (1 - pos_patch_ratio) / pos_patch_ratio))
123 | if num_pos_2PosSlides > idx_pos.shape[0]:
124 | num_pos_2PosSlides = idx_pos.shape[0]
125 | num_pos_2PosSlides = int(num_pos_2PosSlides // num_pos_per_slide * num_pos_per_slide)
126 | num_neg_2PosSlides = int(num_pos_2PosSlides * ((1-pos_patch_ratio)/pos_patch_ratio))
127 | num_neg_2NegSlides = int(num_pos_2PosSlides * ((1-pos_slide_ratio)/(pos_patch_ratio*pos_slide_ratio)))
128 |
129 | num_neg_2PosSlides = int(num_neg_2PosSlides // num_neg_per_slide * num_neg_per_slide)
130 | num_neg_2NegSlides = int(num_neg_2NegSlides // num_img * num_img)
131 |
132 | if num_neg_2PosSlides // num_neg_per_slide != num_pos_2PosSlides // num_pos_per_slide :
133 | num_diff_slide = num_pos_2PosSlides // num_pos_per_slide - num_neg_2PosSlides // num_neg_per_slide
134 | num_pos_2PosSlides = num_pos_2PosSlides - num_pos_per_slide * num_diff_slide
135 |
136 | idx_pos = idx_pos[0:num_pos_2PosSlides]
137 | idx_neg = idx_neg[0:(num_neg_2PosSlides+num_neg_2NegSlides)]
138 |
139 | idx_pos_toPosSlide = idx_pos[:].reshape(-1, num_pos_per_slide)
140 | idx_neg_toPosSlide = idx_neg[0:num_neg_2PosSlides].reshape(-1, num_neg_per_slide)
141 | idx_neg_toNegSlide = idx_neg[num_neg_2PosSlides:].reshape(-1, num_img)
142 |
143 | idx_pos_slides = torch.cat([idx_pos_toPosSlide, idx_neg_toPosSlide], dim=1)
144 | # idx_pos_slides = idx_pos_slides[:, torch.randperm(idx_pos_slides.shape[1])] # shuffle pos and neg idx
145 | for i_ in range(idx_pos_slides.shape[0]):
146 | idx_pos_slides[i_, :] = idx_pos_slides[i_, torch.randperm(idx_pos_slides.shape[1])]
147 | idx_neg_slides = idx_neg_toNegSlide
148 |
149 | self.idx_all_slides = torch.cat([idx_pos_slides, idx_neg_slides], dim=0)
150 | self.label_all_slides = torch.cat([torch.ones(idx_pos_slides.shape[0]), torch.zeros(idx_neg_slides.shape[0])], dim=0)
151 | self.label_all_slides = self.label_all_slides.unsqueeze(1).repeat([1,self.idx_all_slides.shape[1]]).long()
152 | print("[Info] dataset: {}".format(self.idx_all_slides.shape))
153 | #self.visualize(idx_pos_slides[0])
154 |
155 | def __getitem__(self, index):
156 | if self.return_bag:
157 | bagPerSlide = self.idx_all_slides.shape[1] // self.bag_length
158 | idx_slide = index // bagPerSlide
159 | idx_bag_in_slide = index % bagPerSlide
160 | idx_images = self.idx_all_slides[idx_slide, (idx_bag_in_slide*self.bag_length):((idx_bag_in_slide+1)*self.bag_length)]
161 | bag = self.ds_data[idx_images]
162 | patch_labels_raw = self.ds_label[idx_images]
163 | patch_labels = torch.zeros_like(patch_labels_raw)
164 | for num in self.positive_num:
165 | patch_labels[patch_labels_raw == num] = 1
166 | patch_labels = patch_labels.long()
167 | slide_label = self.label_all_slides[idx_slide, 0]
168 | slide_name = str(idx_slide)
169 | return bag.float()/255, [patch_labels, slide_label, idx_slide, slide_name], index
170 | else:
171 | idx_image = self.idx_all_slides.flatten()[index]
172 | slide_label = self.label_all_slides.flatten()[index]
173 | idx_slide = index // self.num_img_per_slide
174 | slide_name = str(idx_slide)
175 | patch = self.ds_data[idx_image]
176 | patch_label = self.ds_label[idx_image]
177 | patch_label = int(patch_label in self.positive_num)
178 | return patch.float()/255, [patch_label, slide_label, idx_slide, slide_name], index
179 |
180 | def __len__(self):
181 | if self.return_bag:
182 | return self.idx_all_slides.shape[1] // self.bag_length * self.idx_all_slides.shape[0]
183 | else:
184 | return self.idx_all_slides.numel()
185 |
186 | def visualize(self, idx, number_row=20, number_col=30):
187 | # idx should be of shape num_img_per_slide
188 | slide = self.ds_data[idx].clone() # num_img_per_slide * 3 * 32 * 32
189 | patch_label = self.ds_label[idx].clone()
190 | idx_pos_patch = []
191 | for num in self.positive_num:
192 | idx_pos_patch.append(torch.where(patch_label == num)[0])
193 | idx_pos_patch = torch.cat(idx_pos_patch)
194 | slide[idx_pos_patch, 0, :2, :] = 255
195 | slide[idx_pos_patch, 0, -2:, :] = 255
196 | slide[idx_pos_patch, 0, :, :2] = 255
197 | slide[idx_pos_patch, 0, :, -2:] = 255
198 |
199 | slide[idx_pos_patch, 1, :2, :] = 0
200 | slide[idx_pos_patch, 1, -2:, :] = 0
201 | slide[idx_pos_patch, 1, :, :2] = 0
202 | slide[idx_pos_patch, 1, :, -2:] = 0
203 |
204 | slide[idx_pos_patch, 2, :2, :] = 0
205 | slide[idx_pos_patch, 2, -2:, :] = 0
206 | slide[idx_pos_patch, 2, :, :2] = 0
207 | slide[idx_pos_patch, 2, :, -2:] = 0
208 |
209 | slide = slide.unsqueeze(0).reshape(number_row, number_col, 3, 32, 32).permute(0, 3, 1, 4, 2).reshape(number_row*32, number_col*32, 3)
210 | import utliz
211 | utliz.show_img(slide)
212 | return 0
213 |
214 |
215 | if __name__ == "__main__":
216 | # Invoke the above function to get our data.
217 | x_train, y_train,x_test, y_test = get_CIFAR10_data()
218 |
219 | # for pos_slide_ratio in [0.01, 0.05, 0.1, 0.2, 0.3, 0.5]:
220 | for pos_slide_ratio in [0.01, 0.05, 0.1, 0.2, 0.5, 0.7]:
221 | print("=========== pos slide ratio: {} ===========".format(pos_slide_ratio))
222 | train_ds = CIFAR_WholeSlide_challenge(train=True, positive_num=[9], negative_num=[0, 1, 2, 3, 4, 5, 6, 7, 8], bag_length=100, return_bag=False, num_img_per_slide=100, pos_patch_ratio=pos_slide_ratio, pos_slide_ratio=0.5, transform=None)
223 | train_loader = data_utils.DataLoader(train_ds, batch_size=64, shuffle=True, drop_last=False)
224 | test_ds_part1 = CIFAR_WholeSlide_challenge(train=False, positive_num=[9], negative_num=[0, 1, 2, 3, 4, 5, 6, 7, 8], bag_length=100, return_bag=False, num_img_per_slide=100, pos_patch_ratio=pos_slide_ratio, pos_slide_ratio=0.5, transform=None)
225 | test_loader_part1 = data_utils.DataLoader(test_ds_part1, batch_size=64, shuffle=True, drop_last=False)
226 | print("")
227 | print("")
228 |
229 |
--------------------------------------------------------------------------------
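
The slide-budget arithmetic in build_whole_slides deserves a note: every negative patch is consumed either inside a positive slide (a fraction (1 - pos_patch_ratio) of each positive slide's patches are negative) or inside an all-negative slide, so the number of usable positive patches P solves n_neg = P*(1-ppr)/ppr + P*(1-psr)/(ppr*psr), which is exactly the divisor in the code. A quick exact-arithmetic check with the __main__ settings above (one positive CIFAR-10 class, hence 45000 negative training patches; num_img=100, ppr=0.1, psr=0.5):

from fractions import Fraction

ppr, psr = Fraction(1, 10), Fraction(1, 2)  # pos_patch_ratio, pos_slide_ratio
num_img, n_neg = 100, 45000
divisor = (1 - psr) / (ppr * psr) + (1 - ppr) / ppr     # = 19
num_pos = n_neg // divisor                              # 2368 usable positive patches
num_pos = num_pos // (num_img * ppr) * (num_img * ppr)  # 2360, rounded to whole slides
neg_to_pos = num_pos * (1 - ppr) / ppr                  # 21240 negatives inside pos slides
neg_to_neg = num_pos * (1 - psr) / (ppr * psr)          # 23600 negatives in neg slides
assert neg_to_pos + neg_to_neg <= n_neg
print(num_pos / (num_img * ppr), neg_to_neg / num_img)  # 236 pos slides, 236 neg slides

(The floating-point version in the code can land one slide off these exact values; the subsequent flooring and count-matching lines are what keep the slide counts consistent.)
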
/Datasets_loader/dataset_CervicalCancer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import torch.utils.data as data_utils
5 | from torchvision import datasets, transforms
6 | from PIL import Image
7 | import os
8 | import glob
9 | from skimage import io
10 | from tqdm import tqdm
11 |
12 |
13 | def shuffle_downsample_myself(input_list, downsample):
14 | np.random.shuffle(input_list)
15 | if downsample == 1:
16 | output_list = input_list
17 | elif downsample < 1:
18 | output_list = input_list[:int(len(input_list) * downsample)]
19 | elif downsample > 1:
20 | downsample = int(downsample)
21 | if downsample > len(input_list):
22 | output_list = input_list
23 | else:
24 | output_list = input_list[:downsample]
25 | else:
26 | raise ValueError("invalid downsample value: {}".format(downsample))
27 | return output_list
28 |
29 |
30 | class CervicalCaner_16(torch.utils.data.Dataset):
31 | # @profile
32 | def __init__(self, root_dir="",
33 | train=True, transform=None, downsample=1.0, preload=True, return_bag=False):
34 | self.root_dir = root_dir
35 | self.train = train
36 | self.return_bag = return_bag
37 | self.transform = transform
38 | self.downsample = downsample
39 | self.preload = preload
40 | if self.transform is None:
41 | self.transform = transforms.Compose([transforms.ToTensor()])
42 | if train:
43 | self.root_dir = os.path.join(self.root_dir, "train_datasets")
44 | else:
45 | self.root_dir = os.path.join(self.root_dir, "test_datasets")
46 |
47 | all_slides = glob.glob(self.root_dir + "/*/*")
48 |
49 | self.num_slides = len(all_slides)
50 |
51 | # 2.extract all available patches and build corresponding labels
52 | if self.preload:
53 | self.num_patches = 0
54 | for i in tqdm(all_slides, ascii=True, desc='scanning data'):
55 | self.num_patches = self.num_patches + len(shuffle_downsample_myself(os.listdir(i), self.downsample))
56 | self.all_patches = np.zeros([self.num_patches, 224, 224, 3], dtype=np.uint8)
57 | else:
58 | self.all_patches = []
59 | self.patch_label = []
60 | self.patch_corresponding_slide_label = []
61 | self.patch_corresponding_slide_index = []
62 | self.patch_corresponding_slide_name = []
63 | cnt_slide = 0
64 | cnt_patch = 0
65 | for i in tqdm(all_slides, ascii=True, desc='preload data'):
66 | patches_from_slide_i = os.listdir(i)
67 | patches_from_slide_i = shuffle_downsample_myself(patches_from_slide_i, self.downsample)
68 | for j in patches_from_slide_i:
69 | if self.preload:
70 | self.all_patches[cnt_patch, :, :, :] = io.imread(os.path.join(i, j))
71 | else:
72 | self.all_patches.append(os.path.join(i, j))
73 | self.patch_label.append(0)
74 | self.patch_corresponding_slide_label.append(int('P' == i.split('/')[-2]))
75 | self.patch_corresponding_slide_index.append(cnt_slide)
76 | self.patch_corresponding_slide_name.append(i.split('/')[-1])
77 | cnt_patch = cnt_patch + 1
78 | cnt_slide = cnt_slide + 1
79 | if not self.preload:
80 | self.all_patches = np.array(self.all_patches)
81 | self.num_patches = self.all_patches.shape[0]
82 | self.all_patches = self.all_patches.transpose(0, 3, 1, 2)
83 | self.patch_label = np.array(self.patch_label, dtype=np.long) # [Attention] patch label is unavailable and set to 0
84 | self.patch_corresponding_slide_label = np.array(self.patch_corresponding_slide_label, dtype=np.long)
85 | self.patch_corresponding_slide_index = np.array(self.patch_corresponding_slide_index, dtype=np.long)
86 | self.patch_corresponding_slide_name = np.array(self.patch_corresponding_slide_name)
87 |
88 | # 3.do some statistics
89 | print("[DATA INFO] num_slide is {}; num_patches is {}\npos_patch_ratio is unknown".format(
90 | self.num_slides, self.num_patches))
91 | print("")
92 |
93 | def __getitem__(self, index):
94 | if self.return_bag:
95 | idx_patch_from_slide_i = np.where(self.patch_corresponding_slide_index == index)[0]
96 | bag = self.all_patches[idx_patch_from_slide_i, :, :, :]
97 | bag = bag.astype(np.float32)/255
98 | patch_labels = self.patch_label[idx_patch_from_slide_i] # Patch label is not available and set to 0 !
99 | slide_label = self.patch_corresponding_slide_label[idx_patch_from_slide_i][0]
100 | slide_index = self.patch_corresponding_slide_index[idx_patch_from_slide_i][0]
101 | slide_name = self.patch_corresponding_slide_name[idx_patch_from_slide_i][0]
102 |
103 | # check data
104 | if self.patch_corresponding_slide_label[idx_patch_from_slide_i].max() != self.patch_corresponding_slide_label[idx_patch_from_slide_i].min():
105 | raise ValueError("inconsistent slide labels within bag {}".format(index))
106 | if self.patch_corresponding_slide_index[idx_patch_from_slide_i].max() != self.patch_corresponding_slide_index[idx_patch_from_slide_i].min():
107 | raise ValueError("inconsistent slide indices within bag {}".format(index))
108 | return bag, [patch_labels, slide_label, slide_index, slide_name], index
109 | else:
110 | if self.preload:
111 | patch_image = self.all_patches[index]
112 | else:
113 | patch_image = io.imread(self.all_patches[index])
114 | patch_label = self.patch_label[index] # [Attention] patch label is unavailable and set to 0
115 | patch_corresponding_slide_label = self.patch_corresponding_slide_label[index]
116 | patch_corresponding_slide_index = self.patch_corresponding_slide_index[index]
117 | patch_corresponding_slide_name = self.patch_corresponding_slide_name[index]
118 |
119 | # patch_image = self.transform(Image.fromarray(np.uint8(patch_image), 'RGB'))
120 | patch_image = patch_image.astype(np.float32)/255
121 | return patch_image, [patch_label, patch_corresponding_slide_label, patch_corresponding_slide_index,
122 | patch_corresponding_slide_name], index
123 |
124 | def __len__(self):
125 | if self.return_bag:
126 | return self.patch_corresponding_slide_index.max() + 1
127 | else:
128 | return self.num_patches
129 |
130 |
131 | class CervicalCaner_16_feat(torch.utils.data.Dataset):
132 | def __init__(self, root_dir="",
133 | train=True, return_bag=True):
134 | self.root_dir = root_dir
135 | self.train = train
136 | self.return_bag = return_bag
137 |
138 | # 1. load all features and slide labels and indices
139 | save_path = ""
140 | if train:
141 | self.all_patches = np.load(os.path.join(save_path, "train_feats.npy"))
142 | self.patch_corresponding_slide_label = np.load(os.path.join(save_path, "train_corresponding_slide_label.npy"))
143 | self.patch_corresponding_slide_index = np.load(os.path.join(save_path, "train_corresponding_slide_index.npy"))
144 | self.patch_label = np.zeros_like(self.patch_corresponding_slide_label) # [Attention] patch label is unavailable and set to 0
145 | self.patch_corresponding_slide_name = self.patch_corresponding_slide_index
146 | else:
147 | self.all_patches = np.load(os.path.join(save_path, "test_feats.npy"))
148 | self.patch_corresponding_slide_label = np.load(os.path.join(save_path, "test_corresponding_slide_label.npy"))
149 | self.patch_corresponding_slide_index = np.load(os.path.join(save_path, "test_corresponding_slide_index.npy"))
150 | self.patch_label = np.zeros_like(self.patch_corresponding_slide_label) # [Attention] patch label is unavailable and set to 0
151 | self.patch_corresponding_slide_name = self.patch_corresponding_slide_index
152 |
153 | self.num_patches = self.all_patches.shape[0]
154 | self.num_slides = self.patch_corresponding_slide_index.max() + 1
155 | print("[DATA INFO] num_slide is {}; num_patches is {}\npos_patch_ratio is unknown".format(
156 | self.num_slides, self.num_patches))
157 |
158 | # 2. sort instances features into bag
159 | self.slide_feat_all = []
160 | self.slide_label_all = []
161 | self.slide_patch_label_all = []
162 | for i in range(self.num_slides):
163 | idx_from_same_slide = self.patch_corresponding_slide_index == i
164 | idx_from_same_slide = np.nonzero(idx_from_same_slide)[0]
165 |
166 | self.slide_feat_all.append(self.all_patches[idx_from_same_slide])
167 | if self.patch_corresponding_slide_label[idx_from_same_slide].max() != self.patch_corresponding_slide_label[idx_from_same_slide].min():
168 | raise ValueError("inconsistent slide labels within slide {}".format(i))
169 | self.slide_label_all.append(self.patch_corresponding_slide_label[idx_from_same_slide].max())
170 | self.slide_patch_label_all.append(np.zeros(idx_from_same_slide.shape[0]).astype(np.long))
171 | print("")
172 |
173 | def __getitem__(self, index):
174 | if self.return_bag:
175 | slide_feat = self.slide_feat_all[index]
176 | slide_label = self.slide_label_all[index]
177 | slide_patch_label = self.slide_patch_label_all[index]
178 | return slide_feat, [slide_patch_label, slide_label], index
179 | else:
180 | patch_image_feat = self.all_patches[index]
181 | patch_label = self.patch_label[index] # [Attention] patch label is unavailable and set to 0
182 | patch_corresponding_slide_label = self.patch_corresponding_slide_label[index]
183 | patch_corresponding_slide_index = self.patch_corresponding_slide_index[index]
184 | patch_corresponding_slide_name = self.patch_corresponding_slide_name[index]
185 |
186 | return patch_image_feat, [patch_label, patch_corresponding_slide_label, patch_corresponding_slide_index,
187 | patch_corresponding_slide_name], index
188 |
189 | def __len__(self):
190 | if self.return_bag:
191 | return self.num_slides
192 | else:
193 | return self.num_patches
194 |
195 |
196 | if __name__ == '__main__':
197 | train_ds_return_bag = CervicalCaner_16_feat(train=True, return_bag=True)
198 | train_ds_return_instance = CervicalCaner_16_feat(train=True, return_bag=False)
199 |
200 | train_ds = CervicalCaner_16(root_dir='', train=True, transform=None, downsample=10, preload=True)
201 | val_ds = CervicalCaner_16(root_dir='', train=False, transform=None, downsample=10, preload=True)
202 | train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64,
203 | shuffle=True, num_workers=0, drop_last=False, pin_memory=True)
204 | val_loader = torch.utils.data.DataLoader(val_ds, batch_size=64,
205 | shuffle=False, num_workers=0, drop_last=False, pin_memory=True)
206 | for data in train_loader:
207 | patch_img = data[0]
208 | label_patch = data[1][0]
209 | label_bag = data[1][1]
210 | idx = data[-1]
211 | print("END")
212 |
--------------------------------------------------------------------------------
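
For reference, one way the .npy files consumed by CervicalCaner_16_feat could be produced. This is a sketch under assumptions: the frozen ResNet-18 encoder and the omitted ImageNet normalization are placeholders, not the features actually used in the paper.

import numpy as np
import torch
import torchvision
from Datasets_loader.dataset_CervicalCancer import CervicalCaner_16

encoder = torchvision.models.resnet18(pretrained=True)
encoder.fc = torch.nn.Identity()  # expose 512-d pooled features
encoder.eval()

ds = CervicalCaner_16(root_dir='', train=True, preload=False, return_bag=False)
loader = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=False)
feats, slide_labels, slide_indices = [], [], []
with torch.no_grad():
    for patch, (p_lab, s_lab, s_idx, s_name), idx in loader:
        if patch.shape[-1] == 3:              # HWC from io.imread -> CHW
            patch = patch.permute(0, 3, 1, 2)
        feats.append(encoder(patch).cpu().numpy())
        slide_labels.append(s_lab.numpy())
        slide_indices.append(s_idx.numpy())
np.save("train_feats.npy", np.concatenate(feats))
np.save("train_corresponding_slide_label.npy", np.concatenate(slide_labels))
np.save("train_corresponding_slide_index.npy", np.concatenate(slide_indices))
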
/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import math
4 | import numpy as np
5 | from scipy.special import logsumexp
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.init as init
10 |
11 | from torch.nn import ModuleList
12 | from torchvision.utils import make_grid
13 |
14 | class AverageMeter(object):
15 | """Computes and stores the average and current value"""
16 | def __init__(self):
17 | self.reset()
18 |
19 | def reset(self):
20 | self.val = 0
21 | self.avg = 0
22 | self.sum = 0
23 | self.count = 0
24 |
25 | def update(self, val, n=1):
26 | self.val = val
27 | self.sum += val * n
28 | self.count += n
29 | self.avg = self.sum / self.count
30 |
31 |
32 | def setup_runtime(seed=0, cuda_dev_id=[0]):
33 | """Initialize CUDA, CuDNN and the random seeds."""
34 | # Setup CUDA
35 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
36 | if len(cuda_dev_id) == 1:
37 | os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_dev_id[0])
38 | else:
39 | os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_dev_id[0])
40 | for i in cuda_dev_id[1:]:
41 | os.environ["CUDA_VISIBLE_DEVICES"] += "," + str(i)
42 |
43 | # global cuda_dev_id
44 | _cuda_device_id = cuda_dev_id
45 | if torch.cuda.is_available():
46 | torch.backends.cudnn.enabled = True
47 | torch.backends.cudnn.benchmark = True
48 | torch.backends.cudnn.deterministic = False
49 | # Fix random seeds
50 | random.seed(seed)
51 | np.random.seed(seed)
52 | torch.manual_seed(seed)
53 | if torch.cuda.is_available():
54 | torch.cuda.manual_seed_all(seed)
55 |
56 | class TotalAverage():
57 | def __init__(self):
58 | self.reset()
59 |
60 | def reset(self):
61 | self.val = 0.
62 | self.mass = 0.
63 | self.sum = 0.
64 | self.avg = 0.
65 |
66 | def update(self, val, mass=1):
67 | self.val = val
68 | self.mass += mass
69 | self.sum += val * mass
70 | self.avg = self.sum / self.mass
71 |
72 |
73 | class MovingAverage():
74 | def __init__(self, intertia=0.9):
75 | self.intertia = intertia
76 | self.reset()
77 |
78 | def reset(self):
79 | self.avg = 0.
80 |
81 | def update(self, val):
82 | self.avg = self.intertia * self.avg + (1 - self.intertia) * val
83 |
84 |
85 | def accuracy(output, target, topk=(1,)):
86 | """Computes the precision@k for the specified values of k."""
87 | with torch.no_grad():
88 | maxk = max(topk)
89 | batch_size = target.size(0)
90 |
91 | _, pred = output.topk(maxk, 1, True, True)
92 | pred = pred.t()
93 | correct = pred.eq(target.view(1, -1).expand_as(pred))
94 |
95 | res = []
96 | for k in topk:
97 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
98 | res.append(correct_k.mul_(100.0 / batch_size))
99 | return res
100 |
101 | def write_conv(writer, model, epoch, sobel=False):
102 | if not sobel:
103 | conv1_ = make_grid(list(ModuleList(list(model.children())[0].children())[0].parameters())[0],
104 | nrow=8, normalize=True, scale_each=True)
105 | writer.add_image('conv1', conv1_, epoch)
106 | else:
107 | conv1_sobel_w = list(ModuleList(list(model.children())[0].children())[0].parameters())[0]
108 | conv1_ = make_grid(conv1_sobel_w[:, 0:1, :, :], nrow=8,
109 | normalize=True, scale_each=True)
110 | writer.add_image('conv1_sobel_1', conv1_, epoch)
111 | conv2_ = make_grid(conv1_sobel_w[:, 1:2, :, :], nrow=8,
112 | normalize=True, scale_each=True)
113 | writer.add_image('conv1_sobel_2', conv2_, epoch)
114 | conv1_x = make_grid(torch.sum(conv1_sobel_w[:, :, :, :], 1, keepdim=True), nrow=8,
115 | normalize=True, scale_each=True)
116 | writer.add_image('conv1', conv1_x, epoch)
117 |
118 |
119 | ### LP stuff ###
120 | def absorb_bn(module, bn_module):
121 | w = module.weight.data
122 | if module.bias is None:
123 | if isinstance(module, nn.Linear):
124 | zeros = torch.Tensor(module.out_features).zero_().type(w.type())
125 | else:
126 | zeros = torch.Tensor(module.out_channels).zero_().type(w.type())
127 | module.bias = nn.Parameter(zeros)
128 | b = module.bias.data
129 | invstd = bn_module.running_var.clone().add_(bn_module.eps).pow_(-0.5)
130 | if isinstance(module, nn.Conv2d):
131 | w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w))
132 | else:
133 | w.mul_(invstd.unsqueeze(1).expand_as(w))
134 | b.add_(-bn_module.running_mean).mul_(invstd)
135 |
136 | if bn_module.affine:
137 | if isinstance(module, nn.Conv2d):
138 | w.mul_(bn_module.weight.data.view(w.size(0), 1, 1, 1).expand_as(w))
139 | else:
140 | w.mul_(bn_module.weight.data.unsqueeze(1).expand_as(w))
141 | b.mul_(bn_module.weight.data).add_(bn_module.bias.data)
142 |
143 | bn_module.reset_parameters()
144 | bn_module.register_buffer('running_mean', None)
145 | bn_module.register_buffer('running_var', None)
146 | bn_module.affine = False
147 | bn_module.register_parameter('weight', None)
148 | bn_module.register_parameter('bias', None)
149 |
150 |
151 | def is_bn(m):
152 | return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)
153 |
154 |
155 | def is_absorbing(m):
156 | return isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear)
157 |
158 |
159 | def search_absorb_bn(model):
160 | prev = None
161 | for m in model.children():
162 | if is_bn(m) and is_absorbing(prev):
163 | print("absorbing",m)
164 | absorb_bn(prev, m)
165 | search_absorb_bn(m)
166 | prev = m
167 |
168 |
169 | class View(nn.Module):
170 | """A shape adaptation layer to patch certain networks."""
171 | def __init__(self):
172 | super(View, self).__init__()
173 |
174 | def forward(self, x):
175 | return x.view(x.shape[0], -1)
176 |
177 |
178 | def sequential_skipping_bn_cut(model):
179 | mods = []
180 | layers = list(model.features) + [View()]
181 | if 'sobel' in dict(model.named_children()).keys():
182 | layers = list(model.sobel) + layers
183 | for m in nn.Sequential(*(layers)).children():
184 | if not is_bn(m):
185 | mods.append(m)
186 | return nn.Sequential(*mods)
187 |
188 |
189 | def py_softmax(x, axis=None):
190 | """stable softmax"""
191 | return np.exp(x - logsumexp(x, axis=axis, keepdims=True))
192 |
193 | def warmup_batchnorm(model, data_loader, device, batches=100):
194 | """
195 | Run some batches through all parts of the model to warmup the running
196 | stats for batchnorm layers.
197 | """
198 | model.train()
199 | for i, q in enumerate(data_loader):
200 | images = q[0]
201 | if i == batches:
202 | break
203 | images = images.to(device)
204 | _ = model(images)
205 |
206 | def init_pytorch_defaults(m, version='041'):
207 | '''
208 | copied from AMDIM repo: https://github.com/Philip-Bachman/amdim-public/
209 | note from me: haven't checked systematically if this improves results
210 | '''
211 | if version == '041':
212 | # print('init.pt041: {0:s}'.format(str(m.weight.data.size())))
213 | if isinstance(m, nn.Linear):
214 | stdv = 1. / math.sqrt(m.weight.size(1))
215 | m.weight.data.uniform_(-stdv, stdv)
216 | if m.bias is not None:
217 | m.bias.data.uniform_(-stdv, stdv)
218 | elif isinstance(m, nn.Conv2d):
219 | n = m.in_channels
220 | for k in m.kernel_size:
221 | n *= k
222 | stdv = 1. / math.sqrt(n)
223 | m.weight.data.uniform_(-stdv, stdv)
224 | if m.bias is not None:
225 | m.bias.data.uniform_(-stdv, stdv)
226 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
227 | if m.affine:
228 | m.weight.data.uniform_()
229 | m.bias.data.zero_()
230 | else:
231 | assert False
232 | elif version == '100':
233 | # print('init.pt100: {0:s}'.format(str(m.weight.data.size())))
234 | if isinstance(m, nn.Linear):
235 | init.kaiming_uniform_(m.weight, a=math.sqrt(5))
236 | if m.bias is not None:
237 | fan_in, _ = init._calculate_fan_in_and_fan_out(m.weight)
238 | bound = 1 / math.sqrt(fan_in)
239 | init.uniform_(m.bias, -bound, bound)
240 | elif isinstance(m, nn.Conv2d):
241 | n = m.in_channels
242 | init.kaiming_uniform_(m.weight, a=math.sqrt(5))
243 | if m.bias is not None:
244 | fan_in, _ = init._calculate_fan_in_and_fan_out(m.weight)
245 | bound = 1 / math.sqrt(fan_in)
246 | init.uniform_(m.bias, -bound, bound)
247 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
248 | if m.affine:
249 | m.weight.data.uniform_()
250 | m.bias.data.zero_()
251 | else:
252 | assert False
253 | elif version == 'custom':
254 | # print('init.custom: {0:s}'.format(str(m.weight.data.size())))
255 | if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
256 | init.normal_(m.weight.data, mean=1, std=0.02)
257 | init.constant_(m.bias.data, 0)
258 | else:
259 | assert False
260 | else:
261 | assert False
262 |
263 |
264 | def weight_init(m):
265 | '''
266 | Usage:
267 | model = Model()
268 | model.apply(weight_init)
269 | '''
270 | if isinstance(m, nn.Linear):
271 | init_pytorch_defaults(m, version='041')
272 | elif isinstance(m, nn.Conv2d):
273 | init_pytorch_defaults(m, version='041')
274 | elif isinstance(m, nn.BatchNorm1d):
275 | init_pytorch_defaults(m, version='041')
276 | elif isinstance(m, nn.BatchNorm2d):
277 | init_pytorch_defaults(m, version='041')
278 | elif isinstance(m, nn.Conv1d):
279 | init.normal_(m.weight.data)
280 | if m.bias is not None:
281 | init.normal_(m.bias.data)
282 | elif isinstance(m, nn.ConvTranspose1d):
283 | init.normal_(m.weight.data)
284 | if m.bias is not None:
285 | init.normal_(m.bias.data)
286 | elif isinstance(m, nn.ConvTranspose2d):
287 | init.xavier_normal_(m.weight.data)
288 | if m.bias is not None:
289 | init.normal_(m.bias.data)
290 |
291 |
292 | def search_set_bn_eval(model,toeval):
293 | for m in model.children():
294 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
295 | if toeval:
296 | m.eval()
297 | else:
298 | m.train()
299 | search_set_bn_eval(m, toeval)
300 |
301 | def prepmodel(model, modelpath):
302 | dat = torch.load(modelpath, map_location=lambda storage, loc: storage) # ['model']
303 | from collections import OrderedDict
304 | new_state_dict = OrderedDict()
305 | for k, v in dat.items():
306 | name = k.replace('module.', '') # remove `module.`
307 | new_state_dict[name] = v
308 | model.load_state_dict(new_state_dict)
309 | del dat
310 | for param in model.features.parameters():
311 | param.requires_grad = False
312 |
313 | if model.headcount > 1:
314 | for i in range(model.headcount):
315 | setattr(model, "top_layer%d" % i, None)
316 |
317 | model.top_layer = nn.Sequential(nn.Linear(2048, 1000))
318 | model.headcount = 1
319 | model.withfeature = False
320 | model.return_feature_only = False
--------------------------------------------------------------------------------
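
As a small usage note, AverageMeter and accuracy above compose naturally in an evaluation loop. A minimal sketch in which model and val_loader are placeholders:

import torch
from util import AverageMeter, accuracy

def evaluate(model, val_loader, device='cpu'):
    top1 = AverageMeter()
    model.eval()
    with torch.no_grad():
        for images, target in val_loader:
            images, target = images.to(device), target.to(device)
            output = model(images)
            acc1, = accuracy(output, target, topk=(1,))
            top1.update(acc1.item(), n=images.size(0))  # weight by batch size
    return top1.avg
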
/Datasets_loader/dataset_MIL_NCTCRCHE.py:
--------------------------------------------------------------------------------
1 | """Pytorch Dataset object that loads perfectly balanced MNIST dataset in bag form."""
2 |
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import torch
6 | import torch.utils.data as data_utils
7 | from torchvision import datasets, transforms
8 | from PIL import Image
9 | from tqdm import tqdm
10 | import numpy as np
11 | from six.moves import cPickle as pickle
12 | import os
13 | from sklearn.model_selection import train_test_split
14 | import platform
15 |
16 |
17 | Patho_classes = ['NORM', 'MUC', 'TUM', 'STR', 'LYM', 'BACK', 'MUS', 'DEB', 'ADI']
18 |
19 |
20 | def load_Pathdata(Root, downsample_ratio=1.0):
21 | X = []
22 | Y = []
23 | X_train = []
24 | X_test = []
25 | Y_train = []
26 | Y_test = []
27 | for index_class, folder in enumerate(Patho_classes):
28 | all_files_class_i = os.listdir(os.path.join(Root, folder))
29 | all_files_class_i = all_files_class_i[:int(len(all_files_class_i)*downsample_ratio)]
30 | for file in tqdm(all_files_class_i):
31 | image = np.array(Image.open(os.path.join(Root, folder, file)).convert("RGB")).transpose(2, 0, 1)
32 | # image = np.array(cv2.imread(os.path.join(Root, folder, file)),cv2.IMREAD_UNCHANGED)
33 | label = index_class
34 | X.append(image)
35 | Y.append(label)
36 | train_X, test_X, train_Y, test_Y = train_test_split(np.array(X), np.array(Y), test_size=0.2, random_state=42)
37 | X_train.append(train_X)
38 | X_test.append(test_X)
39 | Y_train.append(train_Y)
40 | Y_test.append(test_Y)
41 | X = []  # free the per-class buffers before the next class
42 | Y = []
43 | X_train = [b for a in X_train for b in a]
44 | X_test = [b for a in X_test for b in a]
45 | Y_train = [b for a in Y_train for b in a]
46 | Y_test = [b for a in Y_test for b in a]
47 |
48 | return np.array(X_train), np.array(X_test), np.array(Y_train), np.array(Y_test)
49 |
50 |
51 | def get_Path10_data(Path10_dir='', downsample_ratio=1.0):
52 | # Load the raw Path-10 data
53 | X_train, X_test, Y_train, Y_test = load_Pathdata(Path10_dir, downsample_ratio)
54 | x_train, y_train, X_test, y_test = torch.from_numpy(X_train), torch.from_numpy(Y_train), \
55 | torch.from_numpy(X_test), torch.from_numpy(Y_test)
56 | return x_train, y_train, X_test, y_test
57 |
58 |
59 | def random_shuffle(input_tensor):
60 | length = input_tensor.shape[0]
61 | random_idx = torch.randperm(length)
62 | output_tensor = input_tensor[random_idx]
63 | return output_tensor
64 |
65 |
66 | class NCT_WholeSlide_challenge(torch.utils.data.Dataset):
67 | def __init__(self, ds_data, ds_label, positive_num=[2], negative_num=[0, 1, 3, 4, 5, 6, 7, 8],
68 | bag_length=10, return_bag=False, num_img_per_slide=600, pos_patch_ratio=0.1, pos_slide_ratio=0.5, transform=None):
69 |
70 | self.positive_num = positive_num # transform the N-class into 2-class
71 | self.negative_num = negative_num # transform the N-class into 2-class
72 | self.bag_length = bag_length
73 | self.return_bag = return_bag # return patch or bag
74 | self.transform = transform # transform the patch image
75 | self.num_img_per_slide = num_img_per_slide
76 |
77 | self.ds_data, self.ds_label = ds_data, ds_label
78 | self.build_whole_slides(num_img=num_img_per_slide, positive_nums=positive_num, negative_nums=negative_num, pos_patch_ratio=pos_patch_ratio, pos_slide_ratio=pos_slide_ratio)
79 | print("")
80 |
81 | def build_whole_slides(self, num_img, positive_nums, negative_nums, pos_patch_ratio=0.1, pos_slide_ratio=0.5):
82 | # num_img: number of images per slide
83 | # pos_patch_ratio: positive patch ratio in each slide
84 |
85 | num_pos_per_slide = int(num_img * pos_patch_ratio)
86 | num_neg_per_slide = num_img - num_pos_per_slide
87 |
88 | idx_pos = []
89 | for num in positive_nums:
90 | idx_pos.append(torch.where(self.ds_label == num)[0])
91 | idx_pos = torch.cat(idx_pos).unsqueeze(1)
92 | idx_neg = []
93 | for num in negative_nums:
94 | idx_neg.append(torch.where(self.ds_label == num)[0])
95 | idx_neg = torch.cat(idx_neg).unsqueeze(1)
96 |
97 | idx_pos = random_shuffle(idx_pos)
98 | idx_neg = random_shuffle(idx_neg)
99 |
100 | # build pos slides using the patch budget calculated above
101 | num_pos_2PosSlides = int(idx_neg.numel() // ((1 - pos_slide_ratio) / (pos_patch_ratio*pos_slide_ratio) + (1 - pos_patch_ratio) / pos_patch_ratio))
102 | if num_pos_2PosSlides > idx_pos.shape[0]:
103 | num_pos_2PosSlides = idx_pos.shape[0]
104 | num_pos_2PosSlides = int(num_pos_2PosSlides // num_pos_per_slide * num_pos_per_slide)
105 | num_neg_2PosSlides = int(num_pos_2PosSlides * ((1-pos_patch_ratio)/pos_patch_ratio))
106 | num_neg_2NegSlides = int(num_pos_2PosSlides * ((1-pos_slide_ratio)/(pos_patch_ratio*pos_slide_ratio)))
107 |
108 | num_neg_2PosSlides = int(num_neg_2PosSlides // num_neg_per_slide * num_neg_per_slide)
109 | num_neg_2NegSlides = int(num_neg_2NegSlides // num_img * num_img)
110 |
111 | if num_neg_2PosSlides // num_neg_per_slide != num_pos_2PosSlides // num_pos_per_slide :
112 | num_diff_slide = num_pos_2PosSlides // num_pos_per_slide - num_neg_2PosSlides // num_neg_per_slide
113 | num_pos_2PosSlides = num_pos_2PosSlides - num_pos_per_slide * num_diff_slide
114 |
115 | idx_pos = idx_pos[0:num_pos_2PosSlides]
116 | idx_neg = idx_neg[0:(num_neg_2PosSlides+num_neg_2NegSlides)]
117 |
118 | idx_pos_toPosSlide = idx_pos[:].reshape(-1, num_pos_per_slide)
119 | idx_neg_toPosSlide = idx_neg[0:num_neg_2PosSlides].reshape(-1, num_neg_per_slide)
120 | idx_neg_toNegSlide = idx_neg[num_neg_2PosSlides:].reshape(-1, num_img)
121 |
122 | idx_pos_slides = torch.cat([idx_pos_toPosSlide, idx_neg_toPosSlide], dim=1)
123 | # idx_pos_slides = idx_pos_slides[:, torch.randperm(idx_pos_slides.shape[1])] # shuffle pos and neg idx
124 | for i_ in range(idx_pos_slides.shape[0]):
125 | idx_pos_slides[i_, :] = idx_pos_slides[i_, torch.randperm(idx_pos_slides.shape[1])]
126 | idx_neg_slides = idx_neg_toNegSlide
127 |
128 | self.idx_all_slides = torch.cat([idx_pos_slides, idx_neg_slides], dim=0)
129 | self.label_all_slides = torch.cat([torch.ones(idx_pos_slides.shape[0]), torch.zeros(idx_neg_slides.shape[0])], dim=0)
130 | self.label_all_slides = self.label_all_slides.unsqueeze(1).repeat([1,self.idx_all_slides.shape[1]]).long()
131 | print("[Info] dataset: {}".format(self.idx_all_slides.shape))
132 | #self.visualize(idx_pos_slides[0])
133 |
134 | def __getitem__(self, index):
135 | if self.return_bag:
136 | bagPerSlide = self.idx_all_slides.shape[1] // self.bag_length
137 | idx_slide = index // bagPerSlide
138 | idx_bag_in_slide = index % bagPerSlide
139 | idx_images = self.idx_all_slides[idx_slide, (idx_bag_in_slide*self.bag_length):((idx_bag_in_slide+1)*self.bag_length)]
140 | bag = self.ds_data[idx_images]
141 | patch_labels_raw = self.ds_label[idx_images]
142 | patch_labels = torch.zeros_like(patch_labels_raw)
143 | for num in self.positive_num:
144 | patch_labels[patch_labels_raw == num] = 1
145 | patch_labels = patch_labels.long()
146 | slide_label = self.label_all_slides[idx_slide, 0]
147 | slide_name = str(idx_slide)
148 | return bag.float()/255, [patch_labels, slide_label, idx_slide, slide_name], index
149 | else:
150 | idx_image = self.idx_all_slides.flatten()[index]
151 | slide_label = self.label_all_slides.flatten()[index]
152 | idx_slide = index // self.num_img_per_slide
153 | slide_name = str(idx_slide)
154 | patch = self.ds_data[idx_image]
155 | patch_label = self.ds_label[idx_image]
156 | patch_label = int(patch_label in self.positive_num)
157 | return patch.float()/255, [patch_label, slide_label, idx_slide, slide_name], index
158 |
159 | def __len__(self):
160 | if self.return_bag:
161 | return self.idx_all_slides.shape[1] // self.bag_length * self.idx_all_slides.shape[0]
162 | else:
163 | return self.idx_all_slides.numel()
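# Index arithmetic sketch (assumed settings): with num_img_per_slide=600 and
# bag_length=10, each slide yields 60 bags, so in bag mode index 125 maps to slide
# 125 // 60 = 2 and bag 125 % 60 = 5, covering patch columns 50..59 of
# self.idx_all_slides[2]; in instance mode the flattened index addresses one patch.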
164 |
165 | def visualize(self, idx, number_row=10, number_col=10):
166 | # idx should be of shape num_img_per_slide
167 | slide = self.ds_data[idx].clone() # num_img_per_slide x 3 x 224 x 224
168 | patch_label = self.ds_label[idx].clone()
169 | idx_pos_patch = []
170 | for num in self.positive_num:
171 | idx_pos_patch.append(torch.where(patch_label == num)[0])
172 | idx_pos_patch = torch.cat(idx_pos_patch)
173 | slide[idx_pos_patch, 0, :10, :] = 255
174 | slide[idx_pos_patch, 0, -10:, :] = 255
175 | slide[idx_pos_patch, 0, :, :10] = 255
176 | slide[idx_pos_patch, 0, :, -10:] = 255
177 |
178 | slide[idx_pos_patch, 1, :10, :] = 0
179 | slide[idx_pos_patch, 1, -10:, :] = 0
180 | slide[idx_pos_patch, 1, :, :10] = 0
181 | slide[idx_pos_patch, 1, :, -10:] = 0
182 |
183 | slide[idx_pos_patch, 2, :10, :] = 0
184 | slide[idx_pos_patch, 2, -10:, :] = 0
185 | slide[idx_pos_patch, 2, :, :10] = 0
186 | slide[idx_pos_patch, 2, :, -10:] = 0
187 |
188 | slide = slide.unsqueeze(0).reshape(number_row, number_col, 3, 224, 224).permute(0, 3, 1, 4, 2).reshape(number_row*224, number_col*224, 3)
189 | # import utliz  # unused here; show_img_1 is defined below in this file
190 | # show_img_1(slide)
191 | return slide
192 |
193 |
194 | def show_img_1(img, save_file_name=''):
195 | if type(img) == torch.Tensor:
196 | img = img.cpu().detach().numpy()
197 | if len(img.shape) == 3: # HxWx3 or 3xHxW, treat as RGB image
198 | if img.shape[0] == 3:
199 | img = img.transpose(1, 2, 0)
200 | fig = plt.figure()
201 | plt.imshow(img)
202 | if save_file_name != '':
203 | plt.savefig(save_file_name, format='svg')
204 | plt.colorbar()
205 | plt.show()
206 |
207 |
208 | def show_img(img, save_file_name='',format='svg', dpi=1200):
209 | if type(img) == torch.Tensor:
210 | img = img.cpu().detach().numpy()
211 | if len(img.shape) == 3: # HxWx3 or 3xHxW, treat as RGB image
212 | if img.shape[0] == 3:
213 | img = img.transpose(1, 2, 0)
214 | fig = plt.figure()
215 | plt.imshow(img)
216 | plt.xticks([])
217 | plt.yticks([])
218 | plt.axis('off')
219 | if save_file_name != '':
220 | plt.savefig(save_file_name, format=format, dpi=dpi, pad_inches=0.0, bbox_inches='tight')
221 | plt.colorbar()
222 | plt.show()
223 |
224 |
225 | if __name__ == "__main__":
226 | # train_data, train_label, val_data, val_label = get_Path10_data(downsample_ratio=0.1)
227 | # for pos_slide_ratio in [0.01, 0.05, 0.1, 0.2, 0.5, 0.7]:
228 | # print("=========== pos slide ratio: {} ===========".format(pos_slide_ratio))
229 | # train_ds = NCT_WholeSlide_challenge(ds_data=train_data, ds_label=train_label, positive_num=[2], negative_num=[0, 1, 3, 4, 5, 6, 7, 8], bag_length=100, return_bag=False, num_img_per_slide=100, pos_patch_ratio=pos_slide_ratio, pos_slide_ratio=0.5, transform=None)
230 | # slide = train_ds.visualize(train_ds.idx_all_slides[0])
231 | # show_img(slide, save_file_name='../figures/NCT_WSI_pos_PPR{}.png'.format(pos_slide_ratio), format='png', dpi=2400)
232 | # slide = train_ds.visualize(train_ds.idx_all_slides[-10])
233 | # show_img(slide, save_file_name='../figures/NCT_WSI_neg_PPR{}.png'.format(pos_slide_ratio), format='png', dpi=2400)
234 | # # print("")
235 | # print("")
236 |
237 | train_data, train_label, val_data, val_label = get_Path10_data(downsample_ratio=1.0)
238 | for pos_slide_ratio in [0.01, 0.05, 0.1, 0.2, 0.5, 0.7]:
239 | print("=========== pos slide ratio: {} ===========".format(pos_slide_ratio))
240 | train_ds = NCT_WholeSlide_challenge(ds_data=train_data, ds_label=train_label, positive_num=[2], negative_num=[0, 1, 3, 4, 5, 6, 7, 8], bag_length=100, return_bag=False, num_img_per_slide=100, pos_patch_ratio=pos_slide_ratio, pos_slide_ratio=0.5, transform=None)
241 | val_ds = NCT_WholeSlide_challenge(ds_data=val_data, ds_label=val_label, positive_num=[2], negative_num=[0, 1, 3, 4, 5, 6, 7, 8], bag_length=100, return_bag=False, num_img_per_slide=100, pos_patch_ratio=pos_slide_ratio, pos_slide_ratio=0.5, transform=None)
242 | print("")
243 | print("")
244 |
245 |
--------------------------------------------------------------------------------
/utliz.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 | from mpl_toolkits.mplot3d.axes3d import Axes3D
5 | from sklearn import metrics
6 |
7 |
8 | def show_pointcloud(pc, size=10):
9 | if type(pc) == torch.Tensor:
10 | pc = pc.numpy()
11 | if pc.shape[0]==3:
12 | pc = pc.transpose()
13 | fig = plt.figure()
14 | ax = fig.add_subplot(1, 1, 1, projection='3d')
15 | ax.scatter(pc[:,0], pc[:,1], pc[:,2],s=size)
16 |
17 |
18 | def show_pointcloud_batch(pc, size=10):
19 | if type(pc) == torch.Tensor:
20 | pc = pc.numpy()
21 | if pc.shape[1]==3:
22 | pc = pc.transpose(0,2,1)
23 | B,N,C = pc.shape
24 | fig = plt.figure()
25 | for i in range(B):
26 | ax = fig.add_subplot(2, int(B/2), i+1, projection='3d')
27 | ax.scatter(pc[i, :, 0], pc[i, :, 1], pc[i, :, 2], s=size)
28 |
29 |
30 | def show_pointcloud_2pc(pc_1, pc_2, ax=None, c1='r', c2='b',s1=1, s2=1):
31 | if type(pc_1) == torch.Tensor:
32 | pc_1 = pc_1.cpu().detach().numpy()
33 | pc_2 = pc_2.cpu().detach().numpy()
34 | if pc_1.shape[0]==3:
35 | pc_1 = pc_1.transpose()
36 | pc_2 = pc_2.transpose()
37 | if ax is None:
38 | fig = plt.figure()
39 | ax = fig.add_subplot(1, 1, 1, projection='3d')
40 | ax.scatter(pc_1[:, 0], pc_1[:, 1], pc_1[:, 2], s=s1, c=c1, alpha=0.5)
41 | ax.scatter(pc_2[:, 0], pc_2[:, 1], pc_2[:, 2], s=s2, c=c2, alpha=0.5)
42 |
43 |
44 | def show_pointcloud_perpointcolor(pc, size=10,c='r'):
45 | # pc.shape = Nx3, c.shape = N
46 | if type(pc) == torch.Tensor:
47 | pc = pc.cpu().detach().numpy()
48 | if pc.shape[0]==3:
49 | pc = pc.transpose()
50 | if type(c) == torch.Tensor:
51 | c = c.cpu().detach().numpy()
52 | if type(c) == np.ndarray:
53 | if len(c.shape) == 2:
54 | c = np.squeeze(c)
55 | fig = plt.figure()
56 | ax = fig.add_subplot(1, 1, 1, projection='3d')
57 | ax0 = ax.scatter(pc[:,0], pc[:,1], pc[:,2],s=size, alpha=0.5,c=c)
58 | plt.colorbar(ax0, ax=ax)
59 |
60 |
61 | def cal_auc(label, pred, pos_label=1, return_fpr_tpr=False, save_fpr_tpr=False):
62 | if type(label) == torch.Tensor:
63 | label = label.detach().cpu().numpy()
64 | if type(pred) == torch.Tensor:
65 | pred = pred.detach().cpu().numpy()
66 | fpr, tpr, thresholds = metrics.roc_curve(label, pred, pos_label=pos_label, drop_intermediate=False)
67 | auc_score = metrics.auc(fpr, tpr)
68 | if save_fpr_tpr:
69 | if auc_score > 0.5:
70 | np.save("./ROC_reinter/{:.0f}".format(auc_score * 10000),
71 | np.concatenate([np.expand_dims(fpr, axis=1), np.expand_dims(tpr, axis=1)], axis=1))
72 | if return_fpr_tpr:
73 | return fpr, tpr, auc_score
74 | return auc_score
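# Usage sketch (toy values): cal_auc accepts tensors or numpy arrays, e.g.
# cal_auc(torch.tensor([0, 0, 1, 1]), torch.tensor([0.1, 0.4, 0.35, 0.8])) returns 0.75;
# pass return_fpr_tpr=True to also retrieve the ROC curve points.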
75 |
76 |
77 | def cal_acc(label, pred, threshold=0.5):
78 | if type(label) == torch.Tensor:
79 | label = label.detach().cpu().numpy()
80 | if type(pred) == torch.Tensor:
81 | pred = pred.detach().cpu().numpy()
82 | pred_logit = pred>threshold
83 | pred_logit = pred_logit.astype(np.int64)  # np.long was removed in NumPy >= 1.24
84 | acc = np.sum(pred_logit == label)/label.shape[0]
85 | return acc
86 |
87 |
88 | def optimal_thresh(fpr, tpr, thresholds, p=0):
89 | loss = (fpr - tpr) - p * tpr / (fpr + tpr + 1)
90 | idx = np.argmin(loss, axis=0)
91 | return fpr[idx], tpr[idx], thresholds[idx]
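# With the default p=0 this minimizes fpr - tpr, i.e. it selects the ROC operating point
# that maximizes Youden's J statistic (tpr - fpr); p > 0 additionally rewards a high tpr.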
92 |
93 |
94 | def cal_acc_optimThre(label, pred, pos_label=1):
95 | if type(label) == torch.Tensor:
96 | label = label.detach().cpu().numpy()
97 | if type(pred) == torch.Tensor:
98 | pred = pred.detach().cpu().numpy()
99 | fpr, tpr, thresholds = metrics.roc_curve(label, pred, pos_label=pos_label, drop_intermediate=False)
100 | fpr_optimal, tpr_optimal, threshold_optimal = optimal_thresh(fpr, tpr, thresholds)
101 | pred[pred > threshold_optimal] = 1
102 | pred[pred <= threshold_optimal] = 0
103 | acc = np.sum(pred == label) / label.shape[0]
104 | return acc
105 |
106 |
107 | def cal_acc_bestThre(label, pred):
108 | if type(label) == torch.Tensor:
109 | label = label.detach().cpu().numpy()
110 | if type(pred) == torch.Tensor:
111 | pred = pred.detach().cpu().numpy()
112 | _, _, thresholds = metrics.roc_curve(label, pred, drop_intermediate=False)
113 | best_acc = 0  # pick the threshold that maximizes accuracy over all ROC thresholds
114 | for threshold in thresholds:
115 | acc = np.sum((pred > threshold).astype(np.int64) == label) / label.shape[0]
116 | if acc > best_acc:
117 | best_acc = acc
118 | return best_acc
119 |
120 |
121 | def cal_TPR_TNR_FPR_FNR(label, pred):
122 | if type(pred) is not torch.Tensor:
123 | pred = torch.from_numpy(pred)
124 | else:
125 | pred = pred.detach().cpu()
126 | if type(label) is not torch.Tensor:
127 | label = torch.from_numpy(label)
128 | else:
129 | label = label.detach().cpu()
130 |
131 | pred_logit = pred.round()
132 | pseudo_label_TP = torch.sum(label * pred_logit)
133 | pseudo_label_TN = torch.sum((1 - label) * (1 - pred_logit))
134 | pseudo_label_FP = torch.sum((1 - label) * pred_logit)
135 | pseudo_label_FN = torch.sum(label * (1 - pred_logit))
136 | pseudo_label_TPR = 1.0 * pseudo_label_TP / (label.sum() + 1e-9)
137 | pseudo_label_TNR = 1.0 * pseudo_label_TN / (label.numel() - label.sum() + 1e-9)
138 | pseudo_label_FPR = 1.0 * pseudo_label_FP / (label.numel() - label.sum() + 1e-9)
139 | pseudo_label_FNR = 1.0 * pseudo_label_FN / (label.sum() + 1e-9)
140 |
141 | pseudo_label_precision = 1.0 * pseudo_label_TP / (pred_logit.sum() + 1e-9)
142 | pseudo_label_acc = 1.0 * torch.sum(label == pred_logit) / label.numel()
143 | pseudo_label_auc = cal_auc(label, pred)
144 |
145 | return [pseudo_label_TPR.item(), pseudo_label_TNR.item(), pseudo_label_FPR.item(), pseudo_label_FNR.item()],\
146 | pseudo_label_acc.item(), pseudo_label_auc
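# Toy check (assumed inputs): label=[1, 1, 0, 0] with pred=[0.9, 0.2, 0.6, 0.1] rounds to
# [1, 0, 1, 0], giving TP=TN=FP=FN=1 and hence TPR=TNR=FPR=FNR=0.5 and accuracy 0.5.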
147 |
148 |
149 | class AverageMeter(object):
150 | """Computes and stores the average and current value"""
151 | def __init__(self):
152 | self.reset()
153 |
154 | def reset(self):
155 | self.val = 0
156 | self.avg = 0
157 | self.sum = 0
158 | self.count = 0
159 | self.val_window = []
160 | self.avg_window = 0
161 |
162 | def update(self, val, n=1):
163 | self.val = val
164 | self.sum += val * n
165 | self.count += n
166 | self.avg = self.sum / self.count
167 | if len(self.val_window)< 10:
168 | self.val_window.append(self.val)
169 | elif len(self.val_window) == 10:
170 | self.val_window.pop(0)
171 | self.val_window.append(self.val)
172 | else:
173 | print("windows avg ERROR")
174 | self.avg_window = np.array(self.val_window).mean()
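# Usage sketch: meter = AverageMeter(); call meter.update(loss.item(), n=batch_size) each
# step; meter.avg is the running mean over all samples seen, while meter.avg_window is
# the mean of only the last 10 updates (the pop above caps the window at length 10).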
175 |
176 |
177 | # class VisdomLinePlotter(object):
178 | # """Plots to Visdom"""
179 | # def __init__(self, env_name='main'):
180 | # self.viz = Visdom()
181 | # self.env = env_name
182 | # self.plots = {}
183 | # self.scatters = {}
184 | # def plot(self, var_name, split_name, title_name, x, y):
185 | # if var_name not in self.plots:
186 | # self.plots[var_name] = self.viz.line(X=np.array([x,x]), Y=np.array([y,y]), env=self.env, opts=dict(
187 | # legend=[split_name],
188 | # title=title_name,
189 | # xlabel='Epochs',
190 | # ylabel=var_name
191 | # ))
192 | # else:
193 | # self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env, win=self.plots[var_name], name=split_name, update = 'append')
194 | #
195 | # # def scatter(self, var_name, split_name, title_name, x, size=10):
196 | # # if var_name not in self.scatters:
197 | # # self.scatters[var_name] = self.viz.scatter(X=x.cpu().detach().numpy(), env=self.env, opts=dict(
198 | # # legend=[split_name],
199 | # # title=title_name,
200 | # # markersize=size
201 | # # ))
202 | # # else:
203 | # # self.viz.scatter(X=x.cpu().detach().numpy(), env=self.env, win=self.scatters[var_name], name=split_name, update='replace')
204 | #
205 | # def scatter(self, var_name, split_name, title_name, x, size=10, color=0, symbol='dot'):
206 | # if var_name not in self.scatters:
207 | # if type(x) == torch.Tensor:
208 | # x = x.cpu().detach().numpy()
209 | # self.scatters[var_name] = self.viz.scatter(X=x, env=self.env, opts=dict(
210 | # legend=[split_name],
211 | # title=title_name,
212 | # markersize=size,
213 | # markercolor=color,
214 | # markerborderwidth=0,
215 | # # opacity=0.5
216 | # # markersymbol=symbol,
217 | # # linecolor='white',
218 | # ))
219 | # else:
220 | # if type(x) == torch.Tensor:
221 | # x = x.cpu().detach().numpy()
222 | # self.viz.scatter(X=x, env=self.env, win=self.scatters[var_name], name=split_name, update='replace')
223 |
224 |
225 | ####################################
226 | ########### plotly plot ############
227 | def show_3D_imageSlice_plotly(volume):
228 | if type(volume) == torch.Tensor:
229 | volume = volume.detach().cpu().numpy()
230 | r, c = volume[0].shape
231 | # Define frames
232 | import plotly.graph_objects as go
233 | import plotly.io as pio
234 | pio.renderers.default = "browser"
235 | nb_frames = volume.shape[0]
236 |
237 | fig = go.Figure(frames=[go.Frame(data=go.Surface(
238 | z=((nb_frames-1)/10 - k * 0.1) * np.ones((r, c)),
239 | surfacecolor=np.flipud(volume[nb_frames-1 - k]),
240 | cmin=0, cmax=200
241 | ),
242 | name=str(k) # you need to name the frame for the animation to behave properly
243 | )
244 | for k in range(nb_frames)])
245 |
246 | # Add data to be displayed before animation starts
247 | fig.add_trace(go.Surface(
248 | z=(nb_frames-1)/10 * np.ones((r, c)),
249 | surfacecolor=np.flipud(volume[nb_frames-1]),
250 | colorscale='Gray',
251 | cmin=0, cmax=200,
252 | colorbar=dict(thickness=20, ticklen=4)
253 | ))
254 |
255 |
256 | def frame_args(duration):
257 | return {
258 | "frame": {"duration": duration},
259 | "mode": "immediate",
260 | "fromcurrent": True,
261 | "transition": {"duration": duration, "easing": "linear"},
262 | }
263 |
264 | sliders = [
265 | {
266 | "pad": {"b": 10, "t": 60},
267 | "len": 0.9,
268 | "x": 0.1,
269 | "y": 0,
270 | "steps": [
271 | {
272 | "args": [[f.name], frame_args(0)],
273 | "label": str(k),
274 | "method": "animate",
275 | }
276 | for k, f in enumerate(fig.frames)
277 | ],
278 | }
279 | ]
280 |
281 | # Layout
282 | fig.update_layout(
283 | title='Slices in volumetric data',
284 | width=600,
285 | height=600,
286 | scene=dict(
287 | zaxis=dict(range=[-0.1, (nb_frames-1)/10], autorange=False),
288 | aspectratio=dict(x=1, y=1, z=1),
289 | ),
290 | updatemenus = [
291 | {
292 | "buttons": [
293 | {
294 | "args": [None, frame_args(50)],
295 | "label": "▶", # play symbol
296 | "method": "animate",
297 | },
298 | {
299 | "args": [[None], frame_args(0)],
300 | "label": "◼", # pause symbol
301 | "method": "animate",
302 | },
303 | ],
304 | "direction": "left",
305 | "pad": {"r": 10, "t": 70},
306 | "type": "buttons",
307 | "x": 0.1,
308 | "y": 0,
309 | }
310 | ],
311 | sliders=sliders
312 | )
313 |
314 | fig.show()
315 |
316 |
317 | def show_3D_volume_plotly(volume, surface_count=17):
318 | import plotly.graph_objects as go
319 | import numpy as np
320 | import plotly.io as pio
321 | pio.renderers.default = "browser"
322 | if type(volume) == torch.Tensor:
323 | volume = volume.detach().cpu().numpy()
324 |
325 | X, Y, Z = np.mgrid[0:volume.shape[0], 0:volume.shape[1], 0:volume.shape[2]]
326 |
327 | fig = go.Figure(data=go.Volume(
328 | x=X.flatten(),
329 | y=Y.flatten(),
330 | z=Z.flatten(),
331 | value=volume.flatten(),
332 | isomin=0.1,
333 | isomax=0.8,
334 | opacity=0.1, # needs to be small to see through all surfaces
335 | surface_count=surface_count, # needs to be a large number for good volume rendering
336 | ))
337 | fig.show()
338 |
339 | ####################################
340 | ####################################
341 |
342 | def get_lr(optimizer):
343 | for param_group in optimizer.param_groups:
344 | return param_group['lr']
345 |
346 |
347 | ####################################
348 | ####################################
349 | class Network_Logger(object):
350 | def __init__(self, model):
351 | self.model = model
352 | self.model_grad_dict = {}
353 | self.model_weight_dict = {}
354 | self.model_weightSize_dict = {}
355 | for (i, j) in self.model.named_parameters():
356 | if len(j.shape) > 1:
357 | self.model_grad_dict[i] = []
358 | self.model_weight_dict[i] = [j.abs().mean().item()]
359 | self.model_weightSize_dict[i] = j.shape
360 |
361 | def log_grad(self):
362 | for (i, j) in self.model.named_parameters():
363 | if len(j.shape) > 1:
364 | self.model_grad_dict[i].append(j.grad.abs().mean().item())
365 |
366 | def log_weight(self):
367 | for (i, j) in self.model.named_parameters():
368 | if len(j.shape) > 1:
369 | self.model_weight_dict[i].append(j.abs().mean().item())
370 |
371 | def get_current_weight(self):
372 | current_weight = []
373 | for key in self.model_weight_dict.keys():
374 | current_weight.append(self.model_weight_dict[key][-1])
375 | return current_weight
376 |
377 | def get_current_grad(self):
378 | current_grad = []
379 | for key in self.model_grad_dict.keys():
380 | current_grad.append(self.model_grad_dict[key][-1])
381 | return current_grad
382 |
383 | def plot_grad(self, layer_idx=None):
384 | # example: layer_idx = [0,1,2] for only first 3 layers
385 | fig = plt.figure()
386 | ax = fig.add_subplot(1, 1, 1)
387 | if layer_idx is not None:
388 | for idx, key in enumerate(self.model_grad_dict.keys()):
389 | if idx in layer_idx:
390 | ax.plot(self.model_grad_dict[key], label=str(key))
391 | ax.legend()
392 | else:
393 | for idx, key in enumerate(self.model_grad_dict.keys()):
394 | ax.plot(self.model_grad_dict[key], label=str(key))
395 | ax.legend()
396 |
397 | def plot_weight(self, layer_idx=None):
398 | # example: layer_idx = [0,1,2] for only first 3 layers
399 | fig = plt.figure()
400 | ax = fig.add_subplot(1, 1, 1)
401 | if layer_idx is not None:
402 | for idx, key in enumerate(self.model_weight_dict.keys()):
403 | if idx in layer_idx:
404 | ax.plot(self.model_weight_dict[key], label=str(key))
405 | ax.legend()
406 | else:
407 | for idx, key in enumerate(self.model_weight_dict.keys()):
408 | ax.plot(self.model_weight_dict[key], label=str(key))
409 | ax.legend()
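# Usage sketch (assumed training loop): logger = Network_Logger(model); call
# logger.log_grad() after each loss.backward() and logger.log_weight() after each
# optimizer.step(); logger.plot_grad(layer_idx=[0, 1, 2]) then plots the history of mean
# absolute gradients for the first three multi-dimensional weight tensors.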
410 |
411 |
412 | ####################################
413 | ####################################
414 | def show_img(img, save_file_name=''):
415 | if type(img) == torch.Tensor:
416 | img = img.cpu().detach().numpy()
417 | if len(img.shape) == 3: # HxWx3 or 3xHxW, treat as RGB image
418 | if img.shape[0] == 3:
419 | img = img.transpose(1, 2, 0)
420 | fig = plt.figure()
421 | plt.imshow(img)
422 | if save_file_name != '':
423 | plt.savefig(save_file_name, format='svg')
424 | plt.colorbar()
425 | plt.show()
426 |
427 | def show_img_multi(img_list, num_col, num_row):
428 | fig = plt.figure()
429 |
430 | for idx, img in enumerate(img_list):
431 | if type(img) == torch.Tensor:
432 | img = img.cpu().detach().numpy()
433 | if len(img.shape) == 3: # HxWx3 or 3xHxW, treat as RGB image
434 | if img.shape[0] == 3:
435 | img = img.transpose(1, 2, 0)
436 | ax = fig.add_subplot(num_col, num_row, idx+1)
437 | ax.imshow(img)
438 | plt.show()
439 |
--------------------------------------------------------------------------------
/models/alexnet.py:
--------------------------------------------------------------------------------
1 | # adapted from DeepCluster repo: https://github.com/facebookresearch/deepcluster
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch
6 |
7 | __all__ = [ 'AlexNet', 'alexnet_MNIST', 'alexnet', 'alexnet_STL10', 'alexnet_PCam', 'alexnet_CAMELYON',
8 | 'AlexNet_MNIST_projection_prototype', 'alexnet_MedMNIST', 'alexnet_CIFAR10']
9 |
10 | # (number of filters, kernel size, stride, pad)
11 | CFG = {
12 | 'big': [(96, 11, 4, 2), 'M', (256, 5, 1, 2), 'M', (384, 3, 1, 1), (384, 3, 1, 1), (256, 3, 1, 1), 'M'],
13 | 'small': [(64, 11, 4, 2), 'M', (192, 5, 1, 2), 'M', (384, 3, 1, 1), (256, 3, 1, 1), (256, 3, 1, 1), 'M'],
14 | 'mnist': [(32, 6, 2, 2), (64, 3, 1, 1), 'M', (128, 3, 1, 1), (128, 3, 1, 1), 'M'],
15 | 'CAMELYON': [(96, 12, 4, 4), (256, 12, 4, 4), 'M_', (256, 5, 1, 2), 'M_', (512, 3, 1, 1), (512, 3, 1, 1), (256, 3, 1, 1), 'M_'],
16 | 'CIFAR10': [(96, 3, 1, 1), 'M', (192, 3, 1, 1), 'M', (384, 3, 1, 1), (384, 3, 1, 1), (192, 3, 1, 1), 'M']
17 | }
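# Each tuple reads (out_channels, kernel_size, stride, padding): e.g. (96, 11, 4, 2) in
# 'big' becomes nn.Conv2d(in_channels, 96, kernel_size=11, stride=4, padding=2) inside
# make_layers_features below, while 'M' and 'M_' insert the two max-pooling variants.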
18 |
19 |
20 | class AlexNet(nn.Module):
21 | def __init__(self, features, num_classes, init=True):
22 | super(AlexNet, self).__init__()
23 | self.features = features
24 | self.classifier = nn.Sequential(nn.Dropout(0.5),
25 | nn.Linear(256 * 2 * 2, 4096),
26 | nn.ReLU(inplace=True),
27 | nn.Dropout(0.5),
28 | nn.Linear(4096, 4096),
29 | nn.ReLU(inplace=True))
30 | self.headcount = len(num_classes)
31 | self.return_features = False
32 | if len(num_classes) == 1:
33 | self.top_layer = nn.Linear(4096, num_classes[0])
34 | else:
35 | for a,i in enumerate(num_classes):
36 | setattr(self, "top_layer%d" % a, nn.Linear(4096, i))
37 | self.top_layer = None # this way headcount can act as switch.
38 | if init:
39 | self._initialize_weights()
40 |
41 | def forward(self, x):
42 | x = self.features(x)
43 | x = x.view(x.size(0), 256 * 2 * 2)
44 | x = self.classifier(x)
45 | if self.return_features: # switch only used for CIFAR-experiments
46 | return x
47 | if self.headcount == 1:
48 | if self.top_layer: # this way headcount can act as switch.
49 | x = self.top_layer(x)
50 | return x
51 | else:
52 | outp = []
53 | for i in range(self.headcount):
54 | outp.append(getattr(self, "top_layer%d" % i)(x))
55 | return outp
56 |
57 | def _initialize_weights(self):
58 | for y, m in enumerate(self.modules()):
59 | if isinstance(m, nn.Conv2d):
60 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
61 | for i in range(m.out_channels):
62 | m.weight.data[i].normal_(0, math.sqrt(2. / n))
63 | if m.bias is not None:
64 | m.bias.data.zero_()
65 | elif isinstance(m, nn.BatchNorm2d):
66 | m.weight.data.fill_(1)
67 | m.bias.data.zero_()
68 | elif isinstance(m, nn.Linear):
69 | m.weight.data.normal_(0, 0.01)
70 | m.bias.data.zero_()
71 |
72 |
73 | class AlexNet_4x4(nn.Module):
74 | def __init__(self, features, num_classes, init=True):
75 | super(AlexNet_4x4, self).__init__()
76 | self.features = features
77 | self.classifier = nn.Sequential(nn.Dropout(0.5),
78 | nn.Linear(256 * 4 * 4, 4096),
79 | nn.ReLU(inplace=True),
80 | nn.Dropout(0.5),
81 | nn.Linear(4096, 4096),
82 | nn.ReLU(inplace=True))
83 | self.headcount = len(num_classes)
84 | self.return_features = False
85 | if len(num_classes) == 1:
86 | self.top_layer = nn.Linear(4096, num_classes[0])
87 | else:
88 | for a,i in enumerate(num_classes):
89 | setattr(self, "top_layer%d" % a, nn.Linear(4096, i))
90 | self.top_layer = None # this way headcount can act as switch.
91 | if init:
92 | self._initialize_weights()
93 |
94 | def forward(self, x):
95 | x = self.features(x)
96 | x = x.view(x.size(0), 256 * 4 * 4)
97 | x = self.classifier(x)
98 | if self.return_features: # switch only used for CIFAR-experiments
99 | return x
100 | if self.headcount == 1:
101 | if self.top_layer: # this way headcount can act as switch.
102 | x = self.top_layer(x)
103 | return x
104 | else:
105 | outp = []
106 | for i in range(self.headcount):
107 | outp.append(getattr(self, "top_layer%d" % i)(x))
108 | return outp
109 |
110 | def _initialize_weights(self):
111 | for y, m in enumerate(self.modules()):
112 | if isinstance(m, nn.Conv2d):
113 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
114 | for i in range(m.out_channels):
115 | m.weight.data[i].normal_(0, math.sqrt(2. / n))
116 | if m.bias is not None:
117 | m.bias.data.zero_()
118 | elif isinstance(m, nn.BatchNorm2d):
119 | m.weight.data.fill_(1)
120 | m.bias.data.zero_()
121 | elif isinstance(m, nn.Linear):
122 | m.weight.data.normal_(0, 0.01)
123 | m.bias.data.zero_()
124 |
125 |
126 | class AlexNet_MNIST(nn.Module):
127 | def __init__(self, features, num_classes, init=True):
128 | super(AlexNet_MNIST, self).__init__()
129 | self.features = features
130 | self.classifier = nn.Sequential(nn.Dropout(0.5),
131 | nn.Linear(128 * 2 * 2, 1024),
132 | nn.ReLU(inplace=True),
133 | nn.Dropout(0.5),
134 | nn.Linear(1024, 1024),
135 | nn.ReLU(inplace=True))
136 | self.headcount = len(num_classes)
137 | self.return_features = False
138 | if len(num_classes) == 1:
139 | self.top_layer = nn.Linear(1024, num_classes[0])
140 | else:
141 | for a,i in enumerate(num_classes):
142 | setattr(self, "top_layer%d" % a, nn.Linear(1024, i))
143 | self.top_layer = None # this way headcount can act as switch.
144 | if init:
145 | self._initialize_weights()
146 |
147 | def forward(self, x):
148 | x = self.features(x)
149 | x = x.view(x.size(0), 128 * 2 * 2)
150 | x = self.classifier(x)
151 | if self.return_features: # switch only used for CIFAR-experiments
152 | return x
153 | if self.headcount == 1:
154 | if self.top_layer: # this way headcount can act as switch.
155 | x = self.top_layer(x)
156 | # x = nn.functional.tanh(x) # added by xiaoyuan on 2021-04-22 to avoid NaN in the loss
157 | return x
158 | else:
159 | outp = []
160 | for i in range(self.headcount):
161 | outp.append(getattr(self, "top_layer%d" % i)(x))
162 | return outp
163 |
164 | def _initialize_weights(self):
165 | for y, m in enumerate(self.modules()):
166 | if isinstance(m, nn.Conv2d):
167 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
168 | for i in range(m.out_channels):
169 | m.weight.data[i].normal_(0, math.sqrt(2. / n))
170 | if m.bias is not None:
171 | m.bias.data.zero_()
172 | elif isinstance(m, nn.BatchNorm2d):
173 | m.weight.data.fill_(1)
174 | m.bias.data.zero_()
175 | elif isinstance(m, nn.Linear):
176 | m.weight.data.normal_(0, 0.01)
177 | m.bias.data.zero_()
178 |
179 |
180 | class AlexNet_CIFAR10(nn.Module):
181 | def __init__(self, features, num_classes, init=True, input_feat_dim=192*3*3):
182 | super(AlexNet_CIFAR10, self).__init__()
183 | self.features = features
184 | self.input_feat_dim = input_feat_dim
185 | self.classifier = nn.Sequential(
186 | # nn.Dropout(0.5),
187 | nn.Linear(input_feat_dim, 4096),
188 | nn.ReLU(inplace=True),
189 | # nn.Dropout(0.5),
190 | nn.Linear(4096, 4096),
191 | nn.ReLU(inplace=True)
192 | )
193 | self.headcount = len(num_classes)
194 | self.return_features = False
195 | if len(num_classes) == 1:
196 | self.top_layer = nn.Linear(4096, num_classes[0])
197 | else:
198 | for a, i in enumerate(num_classes):
199 | setattr(self, "top_layer%d" % a, nn.Linear(4096, i))
200 | self.top_layer = None # this way headcount can act as switch.
201 | if init:
202 | self._initialize_weights()
203 |
204 | def forward(self, x):
205 | if self.features is not None:
206 | x = self.features(x)
207 | x = x.view(x.size(0), self.input_feat_dim)
208 | x = self.classifier(x)
209 | if self.return_features: # switch only used for CIFAR-experiments
210 | return x
211 | if self.headcount == 1:
212 | if self.top_layer: # this way headcount can act as switch.
213 | x = self.top_layer(x)
214 | # x = nn.functional.tanh(x) # added by xiaoyuan on 2021-04-22 to avoid NaN in the loss
215 | return x
216 | else:
217 | outp = []
218 | for i in range(self.headcount):
219 | outp.append(getattr(self, "top_layer%d" % i)(x))
220 | return outp
221 |
222 | def _initialize_weights(self):
223 | for y, m in enumerate(self.modules()):
224 | if isinstance(m, nn.Conv2d):
225 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
226 | for i in range(m.out_channels):
227 | m.weight.data[i].normal_(0, math.sqrt(2. / n))
228 | if m.bias is not None:
229 | m.bias.data.zero_()
230 | elif isinstance(m, nn.BatchNorm2d):
231 | m.weight.data.fill_(1)
232 | m.bias.data.zero_()
233 | elif isinstance(m, nn.Linear):
234 | m.weight.data.normal_(0, 0.01)
235 | m.bias.data.zero_()
236 |
237 |
238 | class AlexNet_MNIST_projection_prototype(nn.Module):
239 | def __init__(self, output_dim=0, hidden_mlp=0, nmb_prototypes=0, init=True, normalize=True,
240 | eval_mode=False, norm_layer=None):
241 | super(AlexNet_MNIST_projection_prototype, self).__init__()
242 |
243 | self.features = make_layers_features(CFG['mnist'], 1, bn=True)
244 |
245 | if norm_layer is None:
246 | norm_layer = nn.BatchNorm2d
247 | self._norm_layer = norm_layer
248 |
249 | self.eval_mode = eval_mode
250 |
251 | # normalize output features
252 | self.l2norm = normalize
253 |
254 | # projection head
255 | if output_dim == 0:
256 | self.projection_head = None
257 | elif hidden_mlp == 0:
258 | # self.projection_head = nn.Linear(128*2*2, output_dim)
259 | self.projection_head = nn.Linear(128, output_dim)
260 | else:
261 | self.projection_head = nn.Sequential(
262 | # nn.Linear(128*2*2, hidden_mlp),
263 | nn.Linear(128, hidden_mlp),
264 | nn.BatchNorm1d(hidden_mlp),
265 | nn.ReLU(inplace=True),
266 | nn.Linear(hidden_mlp, output_dim),
267 | )
268 |
269 | # prototype layer
270 | self.prototypes = None
271 | if isinstance(nmb_prototypes, list):
272 | # self.prototypes = MultiPrototypes(output_dim, nmb_prototypes)
273 | print("Multiple Prototypes is not supported now")
274 | elif nmb_prototypes > 0:
275 | self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False)
276 |
277 | for m in self.modules():
278 | if isinstance(m, nn.Conv2d):
279 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
280 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
281 | nn.init.constant_(m.weight, 1)
282 | nn.init.constant_(m.bias, 0)
283 |
284 | def forward_backbone(self, x):
285 | x = self.features(x)
286 | # x = x.view(x.size(0), 128 * 2 * 2)
287 | x = x.view(x.size(0), 128, 2 * 2)
288 | x = x.max(dim=-1)[0]
289 | return x
290 |
291 | def forward_head(self, x):
292 | if self.projection_head is not None:
293 | x = self.projection_head(x)
294 |
295 | if self.l2norm:
296 | x = nn.functional.normalize(x, dim=1, p=2)
297 |
298 | if self.prototypes is not None:
299 | return x, self.prototypes(x)
300 | return x
301 |
302 | def forward(self, inputs):
303 | if not isinstance(inputs, list):
304 | inputs = [inputs]
305 | idx_crops = torch.cumsum(torch.unique_consecutive(
306 | torch.tensor([inp.shape[-1] for inp in inputs]),
307 | return_counts=True,
308 | )[1], 0)
309 | start_idx = 0
310 | for end_idx in idx_crops:
311 | _out = self.forward_backbone(torch.cat(inputs[start_idx: end_idx]).cuda(non_blocking=True))
312 | if start_idx == 0:
313 | output = _out
314 | else:
315 | output = torch.cat((output, _out))
316 | start_idx = end_idx
317 | return self.forward_head(output)
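# The unique_consecutive/cumsum bookkeeping groups multi-crop inputs by spatial size
# (SwAV-style), so all crops sharing a resolution go through the backbone in a single
# batched forward before the shared projection head and prototypes are applied.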
318 |
319 | def _initialize_weights(self):
320 | for y, m in enumerate(self.modules()):
321 | if isinstance(m, nn.Conv2d):
322 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
323 | for i in range(m.out_channels):
324 | m.weight.data[i].normal_(0, math.sqrt(2. / n))
325 | if m.bias is not None:
326 | m.bias.data.zero_()
327 | elif isinstance(m, nn.BatchNorm2d):
328 | m.weight.data.fill_(1)
329 | m.bias.data.zero_()
330 | elif isinstance(m, nn.Linear):
331 | m.weight.data.normal_(0, 0.01)
332 | m.bias.data.zero_()
333 |
334 |
335 | def make_layers_features(cfg, input_dim, bn):
336 | layers = []
337 | in_channels = input_dim
338 | for v in cfg:
339 | if v == 'M':
340 | layers += [nn.MaxPool2d(kernel_size=3, stride=2)]
341 | elif v == 'M_':
342 | layers += [nn.MaxPool2d(kernel_size=4, stride=2, padding=1)]
343 | else:
344 | conv2d = nn.Conv2d(in_channels, v[0], kernel_size=v[1], stride=v[2], padding=v[3])#,bias=False)
345 | if bn:
346 | layers += [conv2d, nn.BatchNorm2d(v[0]), nn.ReLU(inplace=True)]
347 | else:
348 | layers += [conv2d, nn.ReLU(inplace=True)]
349 | in_channels = v[0]
350 | return nn.Sequential(*layers)
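# Shape check (assuming a 1x28x28 MNIST input): make_layers_features(CFG['mnist'], 1,
# bn=True) yields a 128x2x2 feature map (28 -> 14 -> 14 -> pool 6 -> 6 -> 6 -> pool 2),
# which matches the 128 * 2 * 2 flattening used by AlexNet_MNIST above.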
351 |
352 |
353 | def alexnet(bn=True, num_classes=[1000], init=True, size='big'):
354 | dim = 1
355 | model = AlexNet(make_layers_features(CFG[size], dim, bn=bn), num_classes, init)
356 | return model
357 |
358 |
359 | def alexnet_MNIST(bn=True, num_classes=[2], init=True):
360 | dim = 1
361 | model = AlexNet_MNIST(make_layers_features(CFG['mnist'], dim, bn=bn), num_classes, init)
362 | return model
363 |
364 |
365 | def alexnet_MedMNIST(bn=True, num_classes=[2], init=True):
366 | dim = 3
367 | model = AlexNet_MNIST(make_layers_features(CFG['mnist'], dim, bn=bn), num_classes, init)
368 | return model
369 |
370 |
371 | def alexnet_STL10(num_classes):
372 | model = SmallAlexNet(num_classes=num_classes)  # pass by keyword: the first positional arg is in_channel
373 | return model
374 |
375 |
376 | def alexnet_PCam(bn=True, num_classes=[2], init=True):
377 | dim = 3
378 | model = AlexNet(make_layers_features(CFG['big'], input_dim=dim ,bn=bn), num_classes=num_classes, init=init)
379 | return model
380 |
381 |
382 | def alexnet_CAMELYON(bn=True, num_classes=[2], init=True):
383 | dim = 3
384 | model = AlexNet_4x4(make_layers_features(CFG['CAMELYON'], input_dim=dim ,bn=bn), num_classes=num_classes, init=init)
385 | return model
386 |
387 |
388 | def alexnet_CIFAR10(bn=True, num_classes=[2], init=True):
389 | dim = 3
390 | model = AlexNet_CIFAR10(make_layers_features(CFG['CIFAR10'], dim, bn=bn), num_classes, init)
391 | return model
392 |
393 |
394 | class L2Norm(nn.Module):
395 | def forward(self, x):
396 | return x / x.norm(p=2, dim=1, keepdim=True)
397 |
398 |
399 | class SmallAlexNet(nn.Module):
400 | def __init__(self, in_channel=3, num_classes=[2]):
401 | super(SmallAlexNet, self).__init__()
402 | blocks = []
403 |
404 | # conv_block_1
405 | blocks.append(nn.Sequential(
406 | nn.Conv2d(in_channel, 96, kernel_size=3, padding=1, bias=False),
407 | nn.BatchNorm2d(96),
408 | nn.ReLU(inplace=True),
409 | nn.MaxPool2d(3, 2),
410 | ))
411 |
412 | # conv_block_2
413 | blocks.append(nn.Sequential(
414 | nn.Conv2d(96, 192, kernel_size=3, padding=1, bias=False),
415 | nn.BatchNorm2d(192),
416 | nn.ReLU(inplace=True),
417 | nn.MaxPool2d(3, 2),
418 | ))
419 |
420 | # conv_block_3
421 | blocks.append(nn.Sequential(
422 | nn.Conv2d(192, 384, kernel_size=3, padding=1, bias=False),
423 | nn.BatchNorm2d(384),
424 | nn.ReLU(inplace=True),
425 | ))
426 |
427 | # conv_block_4
428 | blocks.append(nn.Sequential(
429 | nn.Conv2d(384, 384, kernel_size=3, padding=1, bias=False),
430 | nn.BatchNorm2d(384),
431 | nn.ReLU(inplace=True),
432 | ))
433 |
434 | # conv_block_5
435 | blocks.append(nn.Sequential(
436 | nn.Conv2d(384, 192, kernel_size=3, padding=1, bias=False),
437 | nn.BatchNorm2d(192),
438 | nn.ReLU(inplace=True),
439 | nn.MaxPool2d(3, 2),
440 | ))
441 |
442 | # fc6
443 | blocks.append(nn.Sequential(
444 | nn.Flatten(),
445 | nn.Linear(192 * 7 * 7, 4096, bias=False), # 256 * 6 * 6 if 224 * 224
446 | nn.BatchNorm1d(4096),
447 | nn.ReLU(inplace=True),
448 | ))
449 |
450 | # fc7
451 | blocks.append(nn.Sequential(
452 | nn.Linear(4096, 4096, bias=False),
453 | nn.BatchNorm1d(4096),
454 | nn.ReLU(inplace=True),
455 | ))
456 |
457 | # fc8
458 | blocks.append(nn.Sequential(
459 | nn.Linear(4096, num_classes[0]),
460 | L2Norm(),
461 | ))
462 |
463 | self.blocks = nn.ModuleList(blocks)
464 | self.init_weights_()
465 |
466 | def init_weights_(self):
467 | def init(m):
468 | if isinstance(m, (nn.Linear, nn.Conv2d)):
469 | nn.init.normal_(m.weight, 0, 0.02)
470 | if getattr(m, 'bias', None) is not None:
471 | nn.init.zeros_(m.bias)
472 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
473 | if getattr(m, 'weight', None) is not None:
474 | nn.init.ones_(m.weight)
475 | if getattr(m, 'bias', None) is not None:
476 | nn.init.zeros_(m.bias)
477 |
478 | self.apply(init)
479 |
480 | def forward(self, x, *, layer_index=-1):
481 | if layer_index < 0:
482 | layer_index += len(self.blocks)
483 | for layer in self.blocks[:(layer_index + 1)]:
484 | x = layer(x)
485 | return x
486 |
487 |
488 | class AlexNet_MNIST_attention(nn.Module):
489 | def __init__(self, features, num_classes, init=True, withoutAtten=False):
490 | super(AlexNet_MNIST_attention, self).__init__()
491 | self.withoutAtten=withoutAtten
492 | self.features = features
493 | self.classifier = nn.Sequential(nn.Dropout(0.5),
494 | nn.Linear(128 * 2 * 2, 1024),
495 | nn.ReLU(inplace=True),
496 | nn.Dropout(0.5),
497 | nn.Linear(1024, 1024),
498 | nn.ReLU(inplace=True))
499 | self.L = 1024
500 | self.D = 512
501 | self.K = 1
502 |
503 | self.attention = nn.Sequential(
504 | nn.Linear(self.L, self.D),
505 | nn.Tanh(),
506 | nn.Linear(self.D, self.K)
507 | )
508 | self.headcount = len(num_classes)
509 | self.return_features = False
510 | if len(num_classes) == 1:
511 | self.top_layer = nn.Linear(1024, num_classes[0])
512 | else:
513 | for a,i in enumerate(num_classes):
514 | setattr(self, "top_layer%d" % a, nn.Linear(1024, i))  # match the 1024-d classifier output
515 | self.top_layer = None # this way headcount can act as switch.
516 | if init:
517 | self._initialize_weights()
518 |
519 | def forward(self, x, returnBeforeSoftMaxA=False):
520 | x = x.squeeze(0)
521 | x = self.features(x)
522 | x = x.view(x.size(0), 128 * 2 * 2)
523 | x = self.classifier(x)
524 |
525 | # Attention module
526 | A_ = self.attention(x) # NxK
527 | A_ = torch.transpose(A_, 1, 0) # KxN
528 | A = F.softmax(A_, dim=1) # softmax over N
529 |
530 | if self.withoutAtten:
531 | x = torch.mean(x, dim=0, keepdim=True)
532 | else:
533 | x = torch.mm(A, x) # KxL
534 |
535 | if self.return_features: # switch only used for CIFAR-experiments
536 | return x
537 |
538 | x = self.top_layer(x)
539 | if returnBeforeSoftMaxA:
540 | return x, 0, A, A_
541 | return x, 0, A
542 |
543 | def _initialize_weights(self):
544 | for y, m in enumerate(self.modules()):
545 | if isinstance(m, nn.Conv2d):
546 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
547 | for i in range(m.out_channels):
548 | m.weight.data[i].normal_(0, math.sqrt(2. / n))
549 | if m.bias is not None:
550 | m.bias.data.zero_()
551 | elif isinstance(m, nn.BatchNorm2d):
552 | m.weight.data.fill_(1)
553 | m.bias.data.zero_()
554 | elif isinstance(m, nn.Linear):
555 | m.weight.data.normal_(0, 0.01)
556 | m.bias.data.zero_()
557 |
558 |
559 | class AlexNet_CIFAR10_attention(nn.Module):
560 | def __init__(self, features, num_classes, init=True, withoutAtten=False, input_feat_dim=192*3*3):
561 | super(AlexNet_CIFAR10_attention, self).__init__()
562 | self.input_feat_dim = input_feat_dim
563 | self.withoutAtten=withoutAtten
564 | self.features = features
565 | self.classifier = nn.Sequential(nn.Dropout(0.5),
566 | nn.Linear(input_feat_dim, 1024),
567 | nn.ReLU(inplace=True),
568 | nn.Dropout(0.5),
569 | nn.Linear(1024, 1024),
570 | nn.ReLU(inplace=True))
571 | self.L = 1024
572 | self.D = 512
573 | self.K = 1
574 |
575 | self.attention = nn.Sequential(
576 | nn.Linear(self.L, self.D),
577 | nn.Tanh(),
578 | nn.Linear(self.D, self.K)
579 | )
580 | self.headcount = len(num_classes)
581 | self.return_features = False
582 | if len(num_classes) == 1:
583 | self.top_layer = nn.Linear(1024, num_classes[0])
584 | else:
585 | for a,i in enumerate(num_classes):
586 | setattr(self, "top_layer%d" % a, nn.Linear(1024, i))  # match the 1024-d classifier output
587 | self.top_layer = None # this way headcount can act as switch.
588 | if init:
589 | self._initialize_weights()
590 |
591 | def forward(self, x, returnBeforeSoftMaxA=False, scores_replaceAS=None):
592 | if self.features is not None:
593 | x = x.squeeze(0)
594 | x = self.features(x)
595 | x = x.view(x.size(0), self.input_feat_dim)
596 | x = self.classifier(x)
597 |
598 | # Attention module
599 | A_ = self.attention(x) # NxK
600 | A_ = torch.transpose(A_, 1, 0) # KxN
601 | A = F.softmax(A_, dim=1) # softmax over N
602 |
603 | if scores_replaceAS is not None:
604 | A_ = scores_replaceAS
605 | A = F.softmax(A_, dim=1) # softmax over N
606 |
607 | if self.withoutAtten:
608 | x = torch.mean(x, dim=0, keepdim=True)
609 | else:
610 | x = torch.mm(A, x) # KxL
611 |
612 | if self.return_features: # switch only used for CIFAR-experiments
613 | return x
614 |
615 | x = self.top_layer(x)
616 | if returnBeforeSoftMaxA:
617 | return x, 0, A, A_
618 | return x, 0, A
619 |
620 | def _initialize_weights(self):
621 | for y, m in enumerate(self.modules()):
622 | if isinstance(m, nn.Conv2d):
623 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
624 | for i in range(m.out_channels):
625 | m.weight.data[i].normal_(0, math.sqrt(2. / n))
626 | if m.bias is not None:
627 | m.bias.data.zero_()
628 | elif isinstance(m, nn.BatchNorm2d):
629 | m.weight.data.fill_(1)
630 | m.bias.data.zero_()
631 | elif isinstance(m, nn.Linear):
632 | m.weight.data.normal_(0, 0.01)
633 | m.bias.data.zero_()
634 |
635 |
636 | class AlexNet_CIFAR10_dsmil(nn.Module):
637 | def __init__(self, features, num_classes, init=True, withoutAtten=False, input_feat_dim=192 * 3 * 3):
638 | super(AlexNet_CIFAR10_dsmil, self).__init__()
639 | self.withoutAtten=withoutAtten
640 | self.features = features
641 | self.classifier = nn.Sequential(nn.Dropout(0.5),
642 | nn.Linear(input_feat_dim, 1024),
643 | nn.ReLU(inplace=True),
644 | nn.Dropout(0.5),
645 | nn.Linear(1024, 1024),
646 | nn.ReLU(inplace=True))
647 | # self.L = 1024
648 | # self.D = 512
649 | # self.K = 1
650 | # self.attention = nn.Sequential(
651 | # nn.Linear(self.L, self.D),
652 | # nn.Tanh(),
653 | # nn.Linear(self.D, self.K)
654 | # )
655 |
656 | self.fc_dsmil = nn.Sequential(nn.Linear(1024, 2))
657 | self.q_dsmil = nn.Linear(1024, 1024)
658 | self.v_dsmil = nn.Sequential(
659 | nn.Dropout(0.0),
660 | nn.Linear(1024, 1024)
661 | )
662 | self.fcc_dsmil = nn.Conv1d(2, 2, kernel_size=1024)
663 |
664 | self.headcount = len(num_classes)
665 | self.return_features = False
666 | if len(num_classes) == 1:
667 | self.top_layer = nn.Linear(1024, num_classes[0])
668 | else:
669 | for a,i in enumerate(num_classes):
670 | setattr(self, "top_layer%d" % a, nn.Linear(1024, i))  # match the 1024-d classifier output
671 | self.top_layer = None # this way headcount can act as switch.
672 | if init:
673 | self._initialize_weights()
674 |
675 | def forward(self, x):
676 | if self.features is not None:
677 | x = x.squeeze(0)
678 | x = self.features(x)
679 | x = x.view(x.size(0), -1)
680 | x = self.classifier(x)
681 |
682 | # # Attention module
683 | # A_ = self.attention(x) # NxK
684 | # A_ = torch.transpose(A_, 1, 0) # KxN
685 | # A = F.softmax(A_, dim=1) # softmax over N
686 | #
687 | # if self.withoutAtten:
688 | # x = torch.mean(x, dim=0, keepdim=True)
689 | # else:
690 | # x = torch.mm(A, x) # KxL
691 | #
692 | # if self.return_features: # switch only used for CIFAR-experiments
693 | # return x
694 | # x = self.top_layer(x)
695 | # if returnBeforeSoftMaxA:
696 | # return x, 0, A, A_
697 | # return x, 0, A
698 |
699 | feat = x
700 | device = feat.device
701 | instance_pred = self.fc_dsmil(feat)
702 | V = self.v_dsmil(feat)
703 | Q = self.q_dsmil(feat).view(feat.shape[0], -1)
704 | _, m_indices = torch.sort(instance_pred, 0, descending=True) # sort class scores along the instance dimension, m_indices in shape N x C
705 | m_feats = torch.index_select(feat, dim=0, index=m_indices[0, :]) # select critical instances, m_feats in shape C x K
706 | q_max = self.q_dsmil(m_feats) # compute queries of critical instances, q_max in shape C x Q
707 | A = torch.mm(Q, q_max.transpose(0, 1)) # compute inner product of Q to each entry of q_max, A in shape N x C, each column contains unnormalized attention scores
708 | A = F.softmax( A / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device)), 0) # normalize attention scores, A in shape N x C,
709 | B = torch.mm(A.transpose(0, 1), V) # compute bag representation, B in shape C x V
710 | B = B.view(1, B.shape[0], B.shape[1]) # 1 x C x V
711 | C = self.fcc_dsmil(B) # 1 x C x 1
712 | C = C.view(1, -1)
713 | return instance_pred, C, A, B
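# Shape walkthrough for this DSMIL head (N instances, C=2 classes, feature dim 1024):
# instance_pred is N x 2, Q is N x 1024, m_feats holds the top-scoring (critical)
# instance per class (2 x 1024), A is the N x 2 attention map, B the 1 x 2 x 1024 bag
# representation, and C the 1 x 2 bag logits produced by the per-class Conv1d.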
714 |
715 | def _initialize_weights(self):
716 | for y, m in enumerate(self.modules()):
717 | if isinstance(m, nn.Conv2d):
718 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
719 | for i in range(m.out_channels):
720 | m.weight.data[i].normal_(0, math.sqrt(2. / n))
721 | if m.bias is not None:
722 | m.bias.data.zero_()
723 | elif isinstance(m, nn.BatchNorm2d):
724 | m.weight.data.fill_(1)
725 | m.bias.data.zero_()
726 | elif isinstance(m, nn.Linear):
727 | m.weight.data.normal_(0, 0.01)
728 | m.bias.data.zero_()
729 |
730 |
731 | def alexnet_MNIST_Attention(bn=True, num_classes=[2], init=True):
732 | dim = 1
733 | model = AlexNet_MNIST_attention(make_layers_features(CFG['mnist'], dim, bn=bn), num_classes, init)
734 | return model
735 |
736 |
737 | def alexnet_CIFAR10_Attention(bn=True, num_classes=[2], init=True):
738 | dim = 3
739 | model = AlexNet_CIFAR10_attention(make_layers_features(CFG['CIFAR10'], dim, bn=bn), num_classes, init)
740 | return model
741 |
742 |
743 | ########################################
744 | ## models for Shared Stu and Tea network
745 | def alexnet_CIFAR10_Encoder():
746 | dim = 3
747 | model = make_layers_features(CFG['CIFAR10'], dim, bn=True)
748 | return model
749 |
750 |
751 | def teacher_Attention_head(bn=True, num_classes=[2], init=True, input_feat_dim=192*3*3):
752 | model = AlexNet_CIFAR10_attention(features=None, num_classes=num_classes, init=init, input_feat_dim=input_feat_dim)
753 | return model
754 |
755 |
756 | def teacher_DSMIL_head(bn=True, num_classes=[2], init=True, input_feat_dim=192*3*3):
757 | model = AlexNet_CIFAR10_dsmil(features=None, num_classes=num_classes, init=init, input_feat_dim=input_feat_dim)
758 | return model
759 |
760 |
761 | def student_head(num_classes=[2], init=True, input_feat_dim=192*3*3):
762 | model = AlexNet_CIFAR10(None, num_classes, init, input_feat_dim=input_feat_dim)
763 | return model
764 |
765 |
766 | class feat_projecter(nn.Module):
767 | def __init__(self, input_feat_dim=512, output_feat_dim=512):
768 | super(feat_projecter, self).__init__()
769 | # self.projecter = nn.Sequential(
770 | # nn.Linear(input_feat_dim, input_feat_dim*2),
771 | # nn.BatchNorm1d(input_feat_dim*2),
772 | # nn.ReLU(inplace=True),
773 | # nn.Linear(input_feat_dim*2, input_feat_dim * 2),
774 | # nn.BatchNorm1d(input_feat_dim*2),
775 | # nn.ReLU(inplace=True),
776 | # nn.Linear(input_feat_dim * 2, output_feat_dim),
777 | # nn.BatchNorm1d(output_feat_dim),
778 | # )
779 | self.projecter = nn.Sequential(
780 | nn.Linear(input_feat_dim, output_feat_dim),
781 | nn.BatchNorm1d(output_feat_dim)
782 | )
783 | def forward(self, x):
784 | x = self.projecter(x)
785 | return x
786 |
787 |
788 | def camelyon_feat_projecter(input_dim, output_dim):
789 | model = feat_projecter(input_dim, output_dim)
790 | return model
791 | ########################################
792 |
793 | if __name__ == '__main__':
794 | import torch
795 | # model = alexnet(num_classes=[500]*3)
796 | # print([ k.shape for k in model(torch.randn(64,3,224,224))])
797 | model = AlexNet_MNIST_projection_prototype(output_dim=128, hidden_mlp=2048, nmb_prototypes=300)
798 | print("END")
799 |
800 |
--------------------------------------------------------------------------------
/train_TCGAFeat_BagDistillationDSMIL_SharedEnc_Similarity_StuFilterSmoothed_DropPos.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import warnings
3 | import os
4 | import time
5 | import numpy as np
6 |
7 | import torch
8 | import torch.optim
9 | import torch.nn as nn
10 | import torch.utils.data
11 | from tensorboardX import SummaryWriter
12 | # import models
13 | # from models.alexnet import alexnet_CIFAR10, alexnet_CIFAR10_Attention
14 | from models.alexnet import camelyon_feat_projecter, teacher_DSMIL_head, student_head
15 | # from dataset_toy import Dataset_toy
16 | # from Datasets_loader.dataset_MNIST_challenge import MNIST_WholeSlide_challenge
17 | # from Datasets_loader.dataset_MIL_CIFAR import CIFAR_WholeSlide_challenge
18 | from Datasets_loader.dataset_TCGA_LungCancer import TCGA_LungCancer_Feat
19 | import datetime
20 | import utliz
21 | import util
22 | import random
23 | from tqdm import tqdm
24 | import copy
25 |
26 |
27 | class Optimizer:
28 | def __init__(self, model_encoder, model_teacherHead, model_studentHead,
29 | optimizer_encoder, optimizer_teacherHead, optimizer_studentHead,
30 | train_bagloader, train_instanceloader, test_bagloader, test_instanceloader,
31 | writer=None, num_epoch=100,
32 | dev=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
33 | PLPostProcessMethod='NegGuide', StuFilterType='ReplaceAS', smoothE=100,
34 | stu_loss_weight_neg=0.1, stuOptPeriod=1):
35 | self.model_encoder = model_encoder
36 | self.model_teacherHead = model_teacherHead
37 | self.model_studentHead = model_studentHead
38 | self.optimizer_encoder = optimizer_encoder
39 | self.optimizer_teacherHead = optimizer_teacherHead
40 | self.optimizer_studentHead = optimizer_studentHead
41 | self.train_bagloader = train_bagloader
42 | self.train_instanceloader = train_instanceloader
43 | self.test_bagloader = test_bagloader
44 | self.test_instanceloader = test_instanceloader
45 | self.writer = writer
46 | self.num_epoch = num_epoch
47 | self.dev = dev
48 | self.log_period = 10
49 | self.PLPostProcessMethod = PLPostProcessMethod
50 | self.StuFilterType = StuFilterType
51 | self.smoothE = smoothE
52 | self.stu_loss_weight_neg = stu_loss_weight_neg
53 | self.stuOptPeriod = stuOptPeriod
54 |
55 | def optimize(self):
56 | self.Bank_all_Bags_label = None
57 | self.Bank_all_instances_pred_byTeacher = None
58 | self.Bank_all_instances_feat_byTeacher = None
59 | self.Bank_all_instances_pred_processed = None
60 |
61 | self.Bank_all_instances_pred_byStudent = None
62 |
63 | # Load pre-extracted SimCLR features
64 | # pre_trained_SimCLR_feat = self.train_instanceloader.dataset.ds_data_simCLR_feat[self.train_instanceloader.dataset.idx_all_slides].to(self.dev)
65 | for epoch in range(self.num_epoch):
66 | self.optimize_teacher(epoch)
67 | self.evaluate_teacher(epoch)
68 | if epoch % self.stuOptPeriod == 0:
69 | self.optimize_student(epoch)
70 | self.evaluate_student(epoch)
71 |
72 | return 0
73 |
74 | def optimize_teacher(self, epoch):
75 | self.model_encoder.train()
76 | self.model_teacherHead.train()
77 | self.model_studentHead.eval()
78 | criterion = torch.nn.CrossEntropyLoss()
79 | ## optimize teacher with bag-dataloader
80 | # 1. change loader to bag-loader
81 | loader = self.train_bagloader
82 | # 2. optimize
83 | patch_label_gt = []
84 | patch_label_pred = []
85 | bag_label_gt = []
86 | bag_label_pred = []
87 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Teacher training')):
88 | for i, j in enumerate(label):
89 | if torch.is_tensor(j):
90 | label[i] = j.to(self.dev)
91 | selected = selected.squeeze(0)
92 | niter = epoch * len(loader) + iter
93 |
94 | data = data.to(self.dev)
95 | feat = self.model_encoder(data.squeeze(0))
96 | if epoch > self.smoothE:
97 | if "FilterNegInstance" in self.StuFilterType:
98 | # using student prediction to remove negative instance feat in the positive bag
99 | if label[1] == 1:
100 | with torch.no_grad():
101 | pred_byStudent = self.model_studentHead(feat)
102 | pred_byStudent = torch.softmax(pred_byStudent, dim=1)[:, 1]
103 | if '_Top' in self.StuFilterType:
104 | # strategy A: remove the topK most negative instance
105 | idx_to_keep = torch.topk(-pred_byStudent, k=int(self.StuFilterType.split('_Top')[-1]))[1]
106 | elif '_ThreProb' in self.StuFilterType:
107 | # strategy B: remove the negative instance above prob K
108 | idx_to_keep = torch.where(pred_byStudent <= int(self.StuFilterType.split('_ThreProb')[-1])/100.0)[0]
109 | if idx_to_keep.shape[0] == 0: # if all instance are dropped, keep the most positive one
110 | idx_to_keep = torch.topk(pred_byStudent, k=1)[1]
111 | feat_removedNeg = feat[idx_to_keep]
112 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat_removedNeg)
113 | instance_attn_score = torch.cat([instance_attn_score, instance_attn_score[:, 1].min()*torch.ones(feat.shape[0]-instance_attn_score.shape[0], 2).to(instance_attn_score.device)], dim=0)
114 | else:
115 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat)
116 | else:
117 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat)
118 | else:
119 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat)
120 |
121 | max_id = torch.argmax(instance_attn_score[:, 1])
122 | bag_pred_byMax = instance_attn_score[max_id, :].squeeze(0)
123 | bag_loss = criterion(bag_prediction, label[1])
124 | bag_loss_byMax = criterion(bag_pred_byMax.unsqueeze(0), label[1])
125 | loss_teacher = 0.5 * bag_loss + 0.5 * bag_loss_byMax
126 |
127 | self.optimizer_encoder.zero_grad()
128 | self.optimizer_teacherHead.zero_grad()
129 | loss_teacher.backward()
130 | self.optimizer_encoder.step()
131 | self.optimizer_teacherHead.step()
132 |
133 | bag_prediction = 1.0 * torch.softmax(bag_prediction, dim=1) + \
134 | 0.0 * torch.softmax(bag_pred_byMax.unsqueeze(0), dim=1)
135 | # instance_attn_score = torch.softmax(instance_attn_score, dim=1)
136 |
137 | patch_label_pred.append(instance_attn_score[:, 1].detach().squeeze(0))
138 | patch_label_gt.append(label[0].squeeze(0))
139 | bag_label_pred.append(bag_prediction.detach()[0, 1])
140 | bag_label_gt.append(label[1])
141 | if niter % self.log_period == 0:
142 | self.writer.add_scalar('train_loss_Teacher', loss_teacher.item(), niter)
143 |
144 | patch_label_pred = torch.cat(patch_label_pred)
145 | patch_label_gt = torch.cat(patch_label_gt)
146 | bag_label_pred = torch.tensor(bag_label_pred)
147 | bag_label_gt = torch.cat(bag_label_gt)
148 |
149 | self.estimated_AttnScore_norm_para_min = patch_label_pred.min()
150 | self.estimated_AttnScore_norm_para_max = patch_label_pred.max()
151 | patch_label_pred_normed = self.norm_AttnScore2Prob(patch_label_pred)
152 | instance_auc_ByTeacher = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred_normed.reshape(-1))
153 |
154 | bag_auc_ByTeacher = utliz.cal_auc(bag_label_gt.reshape(-1), bag_label_pred.reshape(-1))
155 | self.writer.add_scalar('train_instance_AUC_byTeacher', instance_auc_ByTeacher, epoch)
156 | self.writer.add_scalar('train_bag_AUC_byTeacher', bag_auc_ByTeacher, epoch)
157 | # print("Epoch:{} train_bag_AUC_byTeacher:{}".format(epoch, bag_auc_ByTeacher))
158 | return 0
159 |
160 | def norm_AttnScore2Prob(self, attn_score):
161 | prob = (attn_score - self.estimated_AttnScore_norm_para_min) / (self.estimated_AttnScore_norm_para_max - self.estimated_AttnScore_norm_para_min)
162 | return prob
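# The min/max used here are cached from the latest teacher epoch
# (estimated_AttnScore_norm_para_min/max above), so raw attention scores are min-max
# rescaled to [0, 1] before being clamped and used as soft pseudo-labels for the student.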
163 |
164 | def post_process_pred_byTeacher(self, Bank_all_instances_feat, Bank_all_instances_pred, Bank_all_bags_label, method='NegGuide'):
165 | if method=='NegGuide':
166 | Bank_all_instances_pred_processed = Bank_all_instances_pred.clone()
167 | Bank_all_instances_pred_processed = self.norm_AttnScore2Prob(Bank_all_instances_pred_processed).clamp(min=1e-5, max=1 - 1e-5)
168 | idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0]
169 | Bank_all_instances_pred_processed[idx_neg_bag, :] = 0
170 | elif method=='NegGuide_TopK':
171 | Bank_all_instances_pred_processed = Bank_all_instances_pred.clone()
172 | Bank_all_instances_pred_processed = self.norm_AttnScore2Prob(Bank_all_instances_pred_processed).clamp(min=1e-5, max=1 - 1e-5)
173 | idx_pos_bag = torch.where(Bank_all_bags_label[:, 0] == 1)[0]
174 | idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0]
175 | K = 3
176 | idx_topK_inside_pos_bag = torch.topk(Bank_all_instances_pred_processed[idx_pos_bag, :], k=K, dim=-1, largest=True)[1]
177 | Bank_all_instances_pred_processed[idx_pos_bag].scatter_(index=idx_topK_inside_pos_bag, dim=1, value=1)
178 | Bank_all_instances_pred_processed[idx_neg_bag, :] = 0
179 | elif method=='NegGuide_Similarity':
180 | Bank_all_instances_pred_processed = Bank_all_instances_pred.clone()
181 | Bank_all_instances_pred_processed = self.norm_AttnScore2Prob(Bank_all_instances_pred_processed).clamp(min=1e-5, max=1 - 1e-5)
182 | idx_pos_bag = torch.where(Bank_all_bags_label[:, 0] == 1)[0]
183 | idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0]
184 | K = 1
185 | idx_topK_inside_pos_bag = torch.topk(Bank_all_instances_pred_processed[idx_pos_bag, :], k=K, dim=-1, largest=True)[1]
186 | Bank_all_instances_pred_processed[idx_pos_bag] = Bank_all_instances_pred_processed[idx_pos_bag].scatter(index=idx_topK_inside_pos_bag, dim=1, value=1)  # indexing with idx_pos_bag returns a copy, so scatter out-of-place and assign back
187 | Bank_all_Pos_instances_feat = Bank_all_instances_feat[idx_pos_bag]
188 | Bank_mostSalient_Pos_instances_feat = []
189 | for i in range(Bank_all_Pos_instances_feat.shape[0]):
190 | Bank_mostSalient_Pos_instances_feat.append(Bank_all_Pos_instances_feat[i, idx_topK_inside_pos_bag[i, 0], :].unsqueeze(0).unsqueeze(0))
191 | Bank_mostSalient_Pos_instances_feat = torch.cat(Bank_mostSalient_Pos_instances_feat, dim=0)
192 |
193 | distance_matrix = Bank_all_Pos_instances_feat - Bank_mostSalient_Pos_instances_feat
194 | distance_matrix = torch.norm(distance_matrix, dim=-1, p=2)
195 | Bank_all_instances_pred_processed[idx_pos_bag, :] = self.distanceMatrix2PL(distance_matrix)
196 | Bank_all_instances_pred_processed[idx_neg_bag, :] = 0
197 | else:
198 | raise ValueError('unknown post-processing method: {}'.format(method))
199 | return Bank_all_instances_pred_processed
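    | # In short: 'NegGuide' zeroes every instance of negative bags; 'NegGuide_TopK'
    | # additionally pins the top-K attended instances of each positive bag to 1;
    | # 'NegGuide_Similarity' spreads the top-1 instance's label to its feature-space
    | # neighbours via distanceMatrix2PL below.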
200 |
201 | def distanceMatrix2PL(self, distance_matrix, method='percentage'):
202 | # distance_matrix is of shape NxL (Num of Positive Bag * Bag Length)
203 | # represents the distance between each instance with their corresponding most salient instance
204 | # return Pseudo-labels of shape NxL (value should belong to [0,1])
205 |
206 | if method == 'softMax':
207 | # 1. softmax over inverse distances keeps the PL values inside [0, 1]
208 | similarity_matrix = 1/(distance_matrix + 1e-5)
209 | pseudo_labels = torch.softmax(similarity_matrix, dim=1)
210 | elif method == 'percentage':
211 | # 2. per-bag threshold at the k-th largest distance, with k = p * bag length:
212 | p = 0.1  # the farthest ~10% of instances get PL=0, the remaining ~90% get PL=1
213 | threshold_v = distance_matrix.topk(k=max(int(distance_matrix.shape[1] * p), 1), dim=1)[0][:, -1].unsqueeze(1)  # Nx1, broadcasts across the bag dimension
214 | pseudo_labels = torch.zeros_like(distance_matrix)
215 | pseudo_labels[distance_matrix >= threshold_v] = 0.0
216 | pseudo_labels[distance_matrix < threshold_v] = 1.0
217 | elif method == 'threshold':
218 | # 3. use a fixed distance threshold per bag (not implemented)
219 | raise NotImplementedError("'threshold' pseudo-labeling is not implemented")
220 | else:
221 | raise ValueError('unknown pseudo-labeling method: {}'.format(method))
222 |
223 | ## visualize the pseudo-label distribution inside each bag
224 | # import matplotlib.pyplot as plt
225 | # plt.figure()
226 | # plt.hist(pseudo_labels.cpu().numpy().reshape(-1))
227 |
228 | return pseudo_labels
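    | # Usage sketch (hypothetical shapes): for 4 positive bags of 100 instances,
    | #   distance_matrix = torch.rand(4, 100)
    | #   pl = self.distanceMatrix2PL(distance_matrix)  # 4x100 binary pseudo-labels
    | # with the default 'percentage' method, a row is 1 wherever the instance lies
    | # closer to its bag's most salient instance than the row's top-10% distance cut.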
229 |
230 | def optimize_student(self, epoch):
231 | self.model_teacherHead.train()
232 | self.model_encoder.train()
233 | self.model_studentHead.train()
234 | ## optimize student with instance-dataloader
235 | # 1. change loader to instance-loader
236 | loader = self.train_instanceloader
237 | # 2. optimize
238 | patch_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) # only for patch-label available dataset
239 | patch_label_pred = torch.zeros([loader.dataset.__len__(), 1]).float().to(self.dev)
240 | bag_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
241 | patch_corresponding_slide_idx = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
242 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Student training')):
243 | for i, j in enumerate(label):
244 | if torch.is_tensor(j):
245 | label[i] = j.to(self.dev)
246 | selected = selected.squeeze(0)
247 | niter = epoch * len(loader) + iter
248 |
249 | data = data.to(self.dev)
250 |
251 | # get teacher output of instance
252 | feat = self.model_encoder(data)
253 | with torch.no_grad():
254 | instance_attn_score, _, _, _ = self.model_teacherHead(feat)
255 | pseudo_instance_label = self.norm_AttnScore2Prob(instance_attn_score[:, 1]).clamp(min=1e-5, max=1-1e-5).squeeze(0)
256 | # instances from negative bags are known negatives: force their pseudo-label to 0
257 | pseudo_instance_label[label[1] == 0] = 0
258 | # # DEBUG: Assign GT patch label
259 | # pseudo_instance_label = label[0]
260 | # get student output of instance
261 | patch_prediction = self.model_studentHead(feat)
262 | patch_prediction = torch.softmax(patch_prediction, dim=1)
263 |
264 | # cal loss
265 | loss_student = -1. * torch.mean(self.stu_loss_weight_neg * (1-pseudo_instance_label) * torch.log(patch_prediction[:, 0] + 1e-5) +
266 | (1-self.stu_loss_weight_neg) * pseudo_instance_label * torch.log(patch_prediction[:, 1] + 1e-5))
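    | # The student objective above is a class-weighted binary cross-entropy on soft
    | # teacher targets:  L = -mean( w_neg*(1-y~)*log p0 + (1-w_neg)*y~*log p1 ),
    | # where y~ is the normalized teacher attention score (forced to 0 for instances
    | # of negative bags) and w_neg = stu_loss_weight_neg down-weights the negatives.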
267 | self.optimizer_encoder.zero_grad()
268 | self.optimizer_studentHead.zero_grad()
269 | loss_student.backward()
270 | self.optimizer_encoder.step()
271 | self.optimizer_studentHead.step()
272 |
273 | patch_corresponding_slide_idx[selected, 0] = label[2]
274 | patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1]
275 | patch_label_gt[selected, 0] = label[0]
276 | bag_label_gt[selected, 0] = label[1]
277 | if niter % self.log_period == 0:
278 | self.writer.add_scalar('train_loss_Student', loss_student.item(), niter)
279 |
280 | instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1))
281 | self.writer.add_scalar('train_instance_AUC_byStudent', instance_auc_ByStudent, epoch)
282 | # print("Epoch:{} train_instance_AUC_byStudent:{}".format(epoch, instance_auc_ByStudent))
283 |
284 | # cal bag-level auc
285 | bag_label_gt_coarse = []
286 | bag_label_prediction = []
287 | available_bag_idx = patch_corresponding_slide_idx.unique()
288 | for bag_idx_i in available_bag_idx:
289 | idx_same_bag_i = torch.where(patch_corresponding_slide_idx == bag_idx_i)
290 | if bag_label_gt[idx_same_bag_i].max() != bag_label_gt[idx_same_bag_i].min():  # all instances of a bag must share one label
291 | raise ValueError('Inconsistent bag label inside bag {}'.format(bag_idx_i))
292 | bag_label_gt_coarse.append(bag_label_gt[idx_same_bag_i].max())
293 | bag_label_prediction.append(patch_label_pred[idx_same_bag_i].max())
294 | bag_label_gt_coarse = torch.tensor(bag_label_gt_coarse)
295 | bag_label_prediction = torch.tensor(bag_label_prediction)
296 | bag_auc_ByStudent = utliz.cal_auc(bag_label_gt_coarse.reshape(-1), bag_label_prediction.reshape(-1))
297 | self.writer.add_scalar('train_bag_AUC_byStudent', bag_auc_ByStudent, epoch)
298 | return 0
299 |
300 | def optimize_student_fromBank(self, epoch, Bank_all_instances_pred):
301 | self.model_teacherHead.train()
302 | self.model_encoder.train()
303 | self.model_studentHead.train()
304 | ## optimize student with instance-dataloader, using banked pseudo-labels
305 | # 1. change loader to instance-loader
306 | loader = self.train_instanceloader
307 | # 2. optimize
308 | patch_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) # only for patch-label available dataset
309 | patch_label_pred = torch.zeros([loader.dataset.__len__(), 1]).float().to(self.dev)
310 | bag_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
311 | patch_corresponding_slide_idx = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
312 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Student training')):
313 | for i, j in enumerate(label):
314 | if torch.is_tensor(j):
315 | label[i] = j.to(self.dev)
316 | selected = selected.squeeze(0)
317 | niter = epoch * len(loader) + iter
318 |
319 | data = data.to(self.dev)
320 |
321 | # encode instances (the teacher-based pseudo-labeling below is kept commented for reference)
322 | feat = self.model_encoder(data)
323 | # with torch.no_grad():
324 | # _, _, _, instance_attn_score = self.model_teacherHead(feat, returnBeforeSoftMaxA=True)
325 | # pseudo_instance_label = self.norm_AttnScore2Prob(instance_attn_score).clamp(min=1e-5, max=1-1e-5).squeeze(0)
326 | # # set true negative patch label to [1, 0]
327 | # pseudo_instance_label[label[1] == 0] = 0
328 |
329 | pseudo_instance_label = Bank_all_instances_pred[selected//100, selected%100]  # assumes a fixed bag length of 100 instances
330 | # # DEBUG: Assign GT patch label
331 | # pseudo_instance_label = label[0]
332 | # get student output of instance
333 | patch_prediction = self.model_studentHead(feat)
334 | patch_prediction = torch.softmax(patch_prediction, dim=1)
335 |
336 | # cal loss
337 | loss_student = -1. * torch.mean(0.1 * (1-pseudo_instance_label) * torch.log(patch_prediction[:, 0] + 1e-5) +
338 | 0.9 * pseudo_instance_label * torch.log(patch_prediction[:, 1] + 1e-5))
339 | self.optimizer_encoder.zero_grad()
340 | self.optimizer_studentHead.zero_grad()
341 | loss_student.backward()
342 | self.optimizer_encoder.step()
343 | self.optimizer_studentHead.step()
344 |
345 | patch_corresponding_slide_idx[selected, 0] = label[2]
346 | patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1]
347 | patch_label_gt[selected, 0] = label[0]
348 | bag_label_gt[selected, 0] = label[1]
349 | if niter % self.log_period == 0:
350 | self.writer.add_scalar('train_loss_Student', loss_student.item(), niter)
351 |
352 | self.Bank_all_instances_pred_byStudent = patch_label_pred
353 | instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1))
354 | bag_auc_ByStudent = 0  # bag-level AUC is not computed in this bank-based variant
355 | self.writer.add_scalar('train_instance_AUC_byStudent', instance_auc_ByStudent, epoch)
356 | self.writer.add_scalar('train_bag_AUC_byStudent', bag_auc_ByStudent, epoch)
357 | # print("Epoch:{} train_instance_AUC_byStudent:{}".format(epoch, instance_auc_ByStudent))
358 | return 0
359 |
360 | def evaluate(self, epoch, loader, log_name_prefix=''):  # unused placeholder hook
361 | return 0
362 |
363 | def evaluate_teacher(self, epoch):
364 | self.model_encoder.eval()
365 | self.model_teacherHead.eval()
366 | ## evaluate teacher with bag-dataloader
367 | # 1. change loader to bag-loader
368 | loader = self.test_bagloader
369 | # 2. evaluate
370 | patch_label_gt = []
371 | patch_label_pred = []
372 | bag_label_gt = []
373 | bag_label_prediction_withAttnScore = []
374 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Teacher evaluating')):
375 | for i, j in enumerate(label):
376 | if torch.is_tensor(j):
377 | label[i] = j.to(self.dev)
378 | selected = selected.squeeze(0)
379 | niter = epoch * len(loader) + iter
380 |
381 | data = data.to(self.dev)
382 | with torch.no_grad():
383 | feat = self.model_encoder(data.squeeze(0))
384 |
385 | instance_attn_score, bag_prediction_withAttnScore, _, _ = self.model_teacherHead(feat)
386 | bag_prediction_withAttnScore = torch.softmax(bag_prediction_withAttnScore, 1)
387 | # instance_attn_score = torch.softmax(instance_attn_score, dim=1)
388 |
389 | patch_label_pred.append(instance_attn_score[:, 1].detach().squeeze(0))
390 | patch_label_gt.append(label[0].squeeze(0))
391 | bag_label_prediction_withAttnScore.append(bag_prediction_withAttnScore.detach()[0, 1])
392 | bag_label_gt.append(label[1])
393 |
394 | patch_label_pred = torch.cat(patch_label_pred)
395 | patch_label_gt = torch.cat(patch_label_gt)
396 | bag_label_prediction_withAttnScore = torch.tensor(bag_label_prediction_withAttnScore)
397 | bag_label_gt = torch.cat(bag_label_gt)
398 |
399 | patch_label_pred_normed = (patch_label_pred - patch_label_pred.min()) / (patch_label_pred.max() - patch_label_pred.min())
400 | instance_auc_ByTeacher = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred_normed.reshape(-1))
401 | bag_auc_ByTeacher_withAttnScore = utliz.cal_auc(bag_label_gt.reshape(-1), bag_label_prediction_withAttnScore.reshape(-1))
402 | self.writer.add_scalar('test_instance_AUC_byTeacher', instance_auc_ByTeacher, epoch)
403 | self.writer.add_scalar('test_bag_AUC_byTeacher', bag_auc_ByTeacher_withAttnScore, epoch)
404 | return 0
405 |
406 | def evaluate_student(self, epoch):
407 | self.model_encoder.eval()
408 | self.model_studentHead.eval()
409 | ## evaluate student with instance-dataloader
410 | # 1. change loader to instance-loader
411 | loader = self.test_instanceloader
412 | # 2. evaluate
413 | patch_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) # only for patch-label available dataset
414 | patch_label_pred = torch.zeros([loader.dataset.__len__(), 1]).float().to(self.dev)
415 | bag_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
416 | patch_corresponding_slide_idx = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
417 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Student evaluating')):
418 | for i, j in enumerate(label):
419 | if torch.is_tensor(j):
420 | label[i] = j.to(self.dev)
421 | selected = selected.squeeze(0)
422 | niter = epoch * len(loader) + iter
423 |
424 | data = data.to(self.dev)
425 |
426 | # get student output of instance
427 | with torch.no_grad():
428 | feat = self.model_encoder(data)
429 | patch_prediction = self.model_studentHead(feat)
430 | patch_prediction = torch.softmax(patch_prediction, dim=1)
431 |
432 | patch_corresponding_slide_idx[selected, 0] = label[2]
433 | patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1]
434 | patch_label_gt[selected, 0] = label[0]
435 | bag_label_gt[selected, 0] = label[1]
436 |
437 | instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1))
438 | self.writer.add_scalar('test_instance_AUC_byStudent', instance_auc_ByStudent, epoch)
439 | # print("Epoch:{} test_instance_AUC_byStudent:{}".format(epoch, instance_auc_ByStudent))
440 |
441 | # cal bag-level auc
442 | bag_label_gt_coarse = []
443 | bag_label_prediction = []
444 | available_bag_idx = patch_corresponding_slide_idx.unique()
445 | for bag_idx_i in available_bag_idx:
446 | idx_same_bag_i = torch.where(patch_corresponding_slide_idx == bag_idx_i)
447 | if bag_label_gt[idx_same_bag_i].max() != bag_label_gt[idx_same_bag_i].min():  # all instances of a bag must share one label
448 | raise ValueError('Inconsistent bag label inside bag {}'.format(bag_idx_i))
449 | bag_label_gt_coarse.append(bag_label_gt[idx_same_bag_i].max())
450 | bag_label_prediction.append(patch_label_pred[idx_same_bag_i].max())
451 | bag_label_gt_coarse = torch.tensor(bag_label_gt_coarse)
452 | bag_label_prediction = torch.tensor(bag_label_prediction)
453 | bag_auc_ByStudent = utliz.cal_auc(bag_label_gt_coarse.reshape(-1), bag_label_prediction.reshape(-1))
454 | self.writer.add_scalar('test_bag_AUC_byStudent', bag_auc_ByStudent, epoch)
455 | return 0
456 |
457 |
458 | def str2bool(v):
459 | """
460 | Input:
461 | v - string
462 | output:
463 | True/False
464 | """
465 | if isinstance(v, bool):
466 | return v
467 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
468 | return True
469 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
470 | return False
471 | else:
472 | raise argparse.ArgumentTypeError('Boolean value expected.')
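    | # Typical usage (illustrative): parser.add_argument('--flag', type=str2bool, default=False)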
473 |
474 |
475 | def get_parser():
476 | parser = argparse.ArgumentParser(description='PyTorch Implementation of WENO (bag-distillation DSMIL)')
477 | # optimizer
478 | parser.add_argument('--epochs', default=1500, type=int, help='number of epochs')
479 | parser.add_argument('--batch_size', default=512, type=int, help='batch size (default: 512)')
480 | parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate (default: 0.001)')
481 | parser.add_argument('--lrdrop', default=1500, type=int, help='multiply LR by 0.5 every N epochs (default: 1500)')
482 | parser.add_argument('--wd', default=-5, type=float, help='weight decay pow (default: -5)')
483 | parser.add_argument('--dtype', default='f64', choices=['f64', 'f32'], type=str, help='SK-algo dtype (default: f64)')
484 |
485 | # SK algo
486 | parser.add_argument('--nopts', default=100, type=int, help='number of pseudo-opts (default: 100)')
487 | parser.add_argument('--augs', default=3, type=int, help='augmentation level (default: 3)')
488 | parser.add_argument('--lamb', default=25, type=int, help='for pseudoopt: lambda (default:25) ')
489 |
490 | # architecture
491 | # parser.add_argument('--arch', default='alexnet_MNIST', type=str, help='alexnet or resnet (default: alexnet)')
492 |
493 | # housekeeping
494 | parser.add_argument('--device', default='0', type=str, help='GPU devices to use for storage and model')
495 | parser.add_argument('--modeldevice', default='0', type=str, help='GPU numbers on which the CNN runs')
496 | parser.add_argument('--exp', default='self-label-default', help='path to experiment directory')
497 | parser.add_argument('--workers', default=0, type=int, help='number of data-loading workers (default: 0)')
498 | parser.add_argument('--comment', default='DEBUG_BagDistillation_DSMIL', type=str, help='name for tensorboardX')
499 | parser.add_argument('--log-intv', default=1, type=int, help='save stuff every x epochs (default: 1)')
500 | parser.add_argument('--log_iter', default=200, type=int, help='log every x-th batch (default: 200)')
501 | parser.add_argument('--seed', default=10, type=int, help='random seed')
502 |
503 | parser.add_argument('--PLPostProcessMethod', default='NegGuide', type=str,
504 | help='Post-processing method of attention scores to build pseudo labels',
505 | choices=['NegGuide', 'NegGuide_TopK', 'NegGuide_Similarity'])
506 | parser.add_argument('--StuFilterType', default='FilterNegInstance__ThreProb50', type=str,
507 | help='How the student prediction is used to improve the teacher '
508 | '[ReplaceAS, FilterNegInstance_Top95, FilterNegInstance__ThreProb90]')
509 | parser.add_argument('--smoothE', default=9999, type=int, help='num of epoch to apply StuFilter')
510 | parser.add_argument('--stu_loss_weight_neg', default=0.1, type=float, help='weight of neg instances in stu training')
511 | parser.add_argument('--stuOptPeriod', default=1, type=int, help='period of stu optimization')
512 | return parser.parse_args()
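    | # Example launch (flags illustrative; defaults are defined above):
    | #   python train_TCGAFeat_BagDistillationDSMIL_SharedEnc_Similarity_StuFilterSmoothed_DropPos.py \
    | #       --epochs 1500 --batch_size 512 --lr 0.001 --PLPostProcessMethod NegGuide \
    | #       --StuFilterType FilterNegInstance__ThreProb50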
513 |
514 |
515 | if __name__ == "__main__":
516 | args = get_parser()
517 |
518 | # torch.manual_seed(args.seed)
519 | # random.seed(args.seed)
520 | # np.random.seed(args.seed)
521 |
522 | name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")+"_%s" % args.comment.replace('/', '_') + \
523 | "_Seed{}_Bs{}_lr{}_PLPostProcessBy{}_StuFilterType{}_smoothE{}_weightN{}_StuOptP{}".format(
524 | args.seed, args.batch_size, args.lr,
525 | args.PLPostProcessMethod, args.StuFilterType, args.smoothE, args.stu_loss_weight_neg, args.stuOptPeriod)
526 | try:
527 | args.device = [int(item) for item in args.device.split(',')]
528 | except AttributeError:
529 | args.device = [int(args.device)]
530 | args.modeldevice = args.device
531 | util.setup_runtime(seed=42, cuda_dev_id=list(np.unique(args.modeldevice + args.device)))
532 |
533 | print(name, flush=True)
534 |
535 | writer = SummaryWriter('./runs_TCGA/%s'%name)
536 | writer.add_text('args', " \n".join(['%s %s' % (arg, getattr(args, arg)) for arg in vars(args)]))
537 |
538 | # Setup model
539 | model_encoder = camelyon_feat_projecter(input_dim=512, output_dim=512).to('cuda:0')
540 | model_teacherHead = teacher_DSMIL_head(input_feat_dim=512).to('cuda:0')
541 | model_studentHead = student_head(input_feat_dim=512).to('cuda:0')
542 |
543 | optimizer_encoder = torch.optim.SGD(model_encoder.parameters(), lr=args.lr)
544 | optimizer_teacherHead = torch.optim.SGD(model_teacherHead.parameters(), lr=args.lr)
545 | optimizer_studentHead = torch.optim.SGD(model_studentHead.parameters(), lr=args.lr)
546 |
547 | # Setup loaders
548 | train_ds_return_instance = TCGA_LungCancer_Feat(train=True, return_bag=False)
549 |
550 | train_ds_return_bag = copy.deepcopy(train_ds_return_instance)
551 | train_ds_return_bag.return_bag = True
552 | val_ds_return_instance = TCGA_LungCancer_Feat(train=False, return_bag=False)
553 | val_ds_return_bag = TCGA_LungCancer_Feat(train=False, return_bag=True)
554 |
555 | train_loader_instance = torch.utils.data.DataLoader(train_ds_return_instance, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=False)
556 | train_loader_bag = torch.utils.data.DataLoader(train_ds_return_bag, batch_size=1, shuffle=True, num_workers=args.workers, drop_last=False)
557 | val_loader_instance = torch.utils.data.DataLoader(val_ds_return_instance, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, drop_last=False)
558 | val_loader_bag = torch.utils.data.DataLoader(val_ds_return_bag, batch_size=1, shuffle=False, num_workers=args.workers, drop_last=False)
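    | # Note: bag loaders use batch_size=1 because each sample is already a whole bag
    | # (all patch features of one slide); instance loaders batch individual patches.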
559 |
560 | print("[Data] {} training samples".format(len(train_loader_instance.dataset)))
561 | print("[Data] {} evaluating samples".format(len(val_loader_instance.dataset)))
562 |
563 | if torch.cuda.device_count() > 1:
564 | print("Let's use", len(args.modeldevice), "GPUs for the model")
565 | if len(args.modeldevice) == 1:
566 | print('single GPU model', flush=True)
567 | else:
568 | model_encoder = nn.DataParallel(model_encoder, device_ids=list(range(len(args.modeldevice))))
569 | model_teacherHead = nn.DataParallel(model_teacherHead, device_ids=list(range(len(args.modeldevice))))
570 |
571 | # Setup training driver (the Optimizer class runs the full train/eval loop)
572 | o = Optimizer(model_encoder=model_encoder, model_teacherHead=model_teacherHead, model_studentHead=model_studentHead,
573 | optimizer_encoder=optimizer_encoder, optimizer_teacherHead=optimizer_teacherHead, optimizer_studentHead=optimizer_studentHead,
574 | train_bagloader=train_loader_bag, train_instanceloader=train_loader_instance,
575 | test_bagloader=val_loader_bag, test_instanceloader=val_loader_instance,
576 | writer=writer, num_epoch=args.epochs,
577 | dev=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
578 | PLPostProcessMethod=args.PLPostProcessMethod, StuFilterType=args.StuFilterType, smoothE=args.smoothE,
579 | stu_loss_weight_neg=args.stu_loss_weight_neg, stuOptPeriod=args.stuOptPeriod)
580 | # Optimize
581 | o.optimize()
--------------------------------------------------------------------------------
/train_CervicalFeat_BagDistillationDSMIL_SharedEnc_Similarity_StuFilterSmoothed_DropPos.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import warnings
3 | import os
4 | import time
5 | import numpy as np
6 |
7 | import torch
8 | import torch.optim
9 | import torch.nn as nn
10 | import torch.utils.data
11 | from tensorboardX import SummaryWriter
12 | # import models
13 | # from models.alexnet import alexnet_CIFAR10, alexnet_CIFAR10_Attention
14 | from models.alexnet import camelyon_feat_projecter, teacher_DSMIL_head, student_head
15 | # from dataset_toy import Dataset_toy
16 | # from Datasets_loader.dataset_MNIST_challenge import MNIST_WholeSlide_challenge
17 | # from Datasets_loader.dataset_MIL_CIFAR import CIFAR_WholeSlide_challenge
18 | from Datasets_loader.dataset_CervicalCancer import CervicalCaner_16_feat
19 | import datetime
20 | import utliz
21 | import util
22 | import random
23 | from tqdm import tqdm
24 | import copy
25 |
26 |
27 | class Optimizer:
28 | def __init__(self, model_encoder, model_teacherHead, model_studentHead,
29 | optimizer_encoder, optimizer_teacherHead, optimizer_studentHead,
30 | train_bagloader, train_instanceloader, test_bagloader, test_instanceloader,
31 | writer=None, num_epoch=100,
32 | dev=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
33 | PLPostProcessMethod='NegGuide', StuFilterType='ReplaceAS', smoothE=100,
34 | stu_loss_weight_neg=0.1, stuOptPeriod=1):
35 | self.model_encoder = model_encoder
36 | self.model_teacherHead = model_teacherHead
37 | self.model_studentHead = model_studentHead
38 | self.optimizer_encoder = optimizer_encoder
39 | self.optimizer_teacherHead = optimizer_teacherHead
40 | self.optimizer_studentHead = optimizer_studentHead
41 | self.train_bagloader = train_bagloader
42 | self.train_instanceloader = train_instanceloader
43 | self.test_bagloader = test_bagloader
44 | self.test_instanceloader = test_instanceloader
45 | self.writer = writer
46 | self.num_epoch = num_epoch
47 | self.dev = dev
48 | self.log_period = 10
49 | self.PLPostProcessMethod = PLPostProcessMethod
50 | self.StuFilterType = StuFilterType
51 | self.smoothE = smoothE
52 | self.stu_loss_weight_neg = stu_loss_weight_neg
53 | self.stuOptPeriod = stuOptPeriod
54 |
55 | def optimize(self):
56 | self.Bank_all_Bags_label = None
57 | self.Bank_all_instances_pred_byTeacher = None
58 | self.Bank_all_instances_feat_byTeacher = None
59 | self.Bank_all_instances_pred_processed = None
60 |
61 | self.Bank_all_instances_pred_byStudent = None
62 |
63 | # Load pre-extracted SimCLR features
64 | # pre_trained_SimCLR_feat = self.train_instanceloader.dataset.ds_data_simCLR_feat[self.train_instanceloader.dataset.idx_all_slides].to(self.dev)
65 | for epoch in range(self.num_epoch):
66 | self.optimize_teacher(epoch)
67 | self.evaluate_teacher(epoch)
68 | if epoch % self.stuOptPeriod == 0:
69 | self.optimize_student(epoch)
70 | self.evaluate_student(epoch)
71 |
72 | return 0
73 |
74 | def optimize_teacher(self, epoch):
75 | self.model_encoder.train()
76 | self.model_teacherHead.train()
77 | self.model_studentHead.eval()
78 | criterion = torch.nn.CrossEntropyLoss()
79 | ## optimize teacher with bag-dataloader
80 | # 1. change loader to bag-loader
81 | loader = self.train_bagloader
82 | # 2. optimize
83 | patch_label_gt = []
84 | patch_label_pred = []
85 | bag_label_gt = []
86 | bag_label_pred = []
87 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Teacher training')):
88 | for i, j in enumerate(label):
89 | if torch.is_tensor(j):
90 | label[i] = j.to(self.dev)
91 | selected = selected.squeeze(0)
92 | niter = epoch * len(loader) + iter
93 |
94 | data = data.to(self.dev)
95 | feat = self.model_encoder(data.squeeze(0))
96 | if epoch > self.smoothE:
97 | if "FilterNegInstance" in self.StuFilterType:
98 | # use the student's predictions to filter instances of the positive bag (this "DropPos" variant drops the most confidently positive ones)
99 | if label[1] == 1:
100 | with torch.no_grad():
101 | pred_byStudent = self.model_studentHead(feat)
102 | pred_byStudent = torch.softmax(pred_byStudent, dim=1)[:, 1]
103 | if '_Top' in self.StuFilterType:
104 | # strategy A: keep only the K instances with the lowest positive probability
105 | idx_to_keep = torch.topk(-pred_byStudent, k=int(self.StuFilterType.split('_Top')[-1]))[1]
106 | elif '_ThreProb' in self.StuFilterType:
107 | # strategy B: keep only instances with positive probability at most K%
108 | idx_to_keep = torch.where(pred_byStudent <= int(self.StuFilterType.split('_ThreProb')[-1])/100.0)[0]
109 | if idx_to_keep.shape[0] == 0:  # if all instances were dropped, keep the most positive one
110 | idx_to_keep = torch.topk(pred_byStudent, k=1)[1]
111 | feat_removedNeg = feat[idx_to_keep]
112 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat_removedNeg)
113 | instance_attn_score = torch.cat([instance_attn_score, instance_attn_score[:, 1].min()*torch.ones(feat.shape[0]-instance_attn_score.shape[0], 2).to(instance_attn_score.device)], dim=0)  # pad dropped instances with the minimum attention score so the bag length is preserved
114 | else:
115 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat)
116 | else:
117 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat)
118 | else:
119 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat)
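    | # Worked example (hypothetical): with StuFilterType='FilterNegInstance__ThreProb50'
    | # and epoch > smoothE, instances of a positive bag whose student positive
    | # probability exceeds 0.50 are dropped before the teacher forward pass, and the
    | # concat above pads the attention scores back to the original bag length.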
120 |
121 | max_id = torch.argmax(instance_attn_score[:, 1])
122 | bag_pred_byMax = instance_attn_score[max_id, :].squeeze(0)
123 | bag_loss = criterion(bag_prediction, label[1])
124 | bag_loss_byMax = criterion(bag_pred_byMax.unsqueeze(0), label[1])
125 | loss_teacher = 0.5 * bag_loss + 0.5 * bag_loss_byMax
126 |
127 | self.optimizer_encoder.zero_grad()
128 | self.optimizer_teacherHead.zero_grad()
129 | loss_teacher.backward()
130 | self.optimizer_encoder.step()
131 | self.optimizer_teacherHead.step()
132 |
133 | bag_prediction = 1.0 * torch.softmax(bag_prediction, dim=1) + \
134 | 0.0 * torch.softmax(bag_pred_byMax.unsqueeze(0), dim=1)  # fixed ensemble weights: only the attention-stream prediction is logged
135 | # instance_attn_score = torch.softmax(instance_attn_score, dim=1)
136 |
137 | patch_label_pred.append(instance_attn_score[:, 1].detach().squeeze(0))
138 | patch_label_gt.append(label[0].squeeze(0))
139 | bag_label_pred.append(bag_prediction.detach()[0, 1])
140 | bag_label_gt.append(label[1])
141 | if niter % self.log_period == 0:
142 | self.writer.add_scalar('train_loss_Teacher', loss_teacher.item(), niter)
143 |
144 | patch_label_pred = torch.cat(patch_label_pred)
145 | patch_label_gt = torch.cat(patch_label_gt)
146 | bag_label_pred = torch.tensor(bag_label_pred)
147 | bag_label_gt = torch.cat(bag_label_gt)
148 |
149 | self.estimated_AttnScore_norm_para_min = patch_label_pred.min()
150 | self.estimated_AttnScore_norm_para_max = patch_label_pred.max()
151 | patch_label_pred_normed = self.norm_AttnScore2Prob(patch_label_pred)
152 | instance_auc_ByTeacher = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred_normed.reshape(-1))
153 |
154 | bag_auc_ByTeacher = utliz.cal_auc(bag_label_gt.reshape(-1), bag_label_pred.reshape(-1))
155 | self.writer.add_scalar('train_instance_AUC_byTeacher', instance_auc_ByTeacher, epoch)
156 | self.writer.add_scalar('train_bag_AUC_byTeacher', bag_auc_ByTeacher, epoch)
157 | # print("Epoch:{} train_bag_AUC_byTeacher:{}".format(epoch, bag_auc_ByTeacher))
158 | return 0
159 |
160 | def norm_AttnScore2Prob(self, attn_score):
161 | prob = (attn_score - self.estimated_AttnScore_norm_para_min) / (self.estimated_AttnScore_norm_para_max - self.estimated_AttnScore_norm_para_min)
162 | return prob
163 |
164 | def post_process_pred_byTeacher(self, Bank_all_instances_feat, Bank_all_instances_pred, Bank_all_bags_label, method='NegGuide'):
165 | if method=='NegGuide':
166 | Bank_all_instances_pred_processed = Bank_all_instances_pred.clone()
167 | Bank_all_instances_pred_processed = self.norm_AttnScore2Prob(Bank_all_instances_pred_processed).clamp(min=1e-5, max=1 - 1e-5)
168 | idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0]
169 | Bank_all_instances_pred_processed[idx_neg_bag, :] = 0
170 | elif method=='NegGuide_TopK':
171 | Bank_all_instances_pred_processed = Bank_all_instances_pred.clone()
172 | Bank_all_instances_pred_processed = self.norm_AttnScore2Prob(Bank_all_instances_pred_processed).clamp(min=1e-5, max=1 - 1e-5)
173 | idx_pos_bag = torch.where(Bank_all_bags_label[:, 0] == 1)[0]
174 | idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0]
175 | K = 3
176 | idx_topK_inside_pos_bag = torch.topk(Bank_all_instances_pred_processed[idx_pos_bag, :], k=K, dim=-1, largest=True)[1]
177 | Bank_all_instances_pred_processed[idx_pos_bag] = Bank_all_instances_pred_processed[idx_pos_bag].scatter(index=idx_topK_inside_pos_bag, dim=1, value=1)  # indexing with idx_pos_bag returns a copy, so scatter out-of-place and assign back
178 | Bank_all_instances_pred_processed[idx_neg_bag, :] = 0
179 | elif method=='NegGuide_Similarity':
180 | Bank_all_instances_pred_processed = Bank_all_instances_pred.clone()
181 | Bank_all_instances_pred_processed = self.norm_AttnScore2Prob(Bank_all_instances_pred_processed).clamp(min=1e-5, max=1 - 1e-5)
182 | idx_pos_bag = torch.where(Bank_all_bags_label[:, 0] == 1)[0]
183 | idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0]
184 | K = 1
185 | idx_topK_inside_pos_bag = torch.topk(Bank_all_instances_pred_processed[idx_pos_bag, :], k=K, dim=-1, largest=True)[1]
186 | Bank_all_instances_pred_processed[idx_pos_bag] = Bank_all_instances_pred_processed[idx_pos_bag].scatter(index=idx_topK_inside_pos_bag, dim=1, value=1)  # indexing with idx_pos_bag returns a copy, so scatter out-of-place and assign back
187 | Bank_all_Pos_instances_feat = Bank_all_instances_feat[idx_pos_bag]
188 | Bank_mostSalient_Pos_instances_feat = []
189 | for i in range(Bank_all_Pos_instances_feat.shape[0]):
190 | Bank_mostSalient_Pos_instances_feat.append(Bank_all_Pos_instances_feat[i, idx_topK_inside_pos_bag[i, 0], :].unsqueeze(0).unsqueeze(0))
191 | Bank_mostSalient_Pos_instances_feat = torch.cat(Bank_mostSalient_Pos_instances_feat, dim=0)
192 |
193 | distance_matrix = Bank_all_Pos_instances_feat - Bank_mostSalient_Pos_instances_feat
194 | distance_matrix = torch.norm(distance_matrix, dim=-1, p=2)
195 | Bank_all_instances_pred_processed[idx_pos_bag, :] = self.distanceMatrix2PL(distance_matrix)
196 | Bank_all_instances_pred_processed[idx_neg_bag, :] = 0
197 | else:
198 | raise ValueError('unknown post-processing method: {}'.format(method))
199 | return Bank_all_instances_pred_processed
200 |
201 | def distanceMatrix2PL(self, distance_matrix, method='percentage'):
202 | # distance_matrix is of shape NxL (Num of Positive Bag * Bag Length)
203 | # represents the distance between each instance with their corresponding most salient instance
204 | # return Pseudo-labels of shape NxL (value should belong to [0,1])
205 |
206 | if method == 'softMax':
207 | # 1. softmax over inverse distances keeps the PL values inside [0, 1]
208 | similarity_matrix = 1/(distance_matrix + 1e-5)
209 | pseudo_labels = torch.softmax(similarity_matrix, dim=1)
210 | elif method == 'percentage':
211 | # 2. per-bag threshold at the k-th largest distance, with k = p * bag length:
212 | p = 0.1  # the farthest ~10% of instances get PL=0, the remaining ~90% get PL=1
213 | threshold_v = distance_matrix.topk(k=max(int(distance_matrix.shape[1] * p), 1), dim=1)[0][:, -1].unsqueeze(1)  # Nx1, broadcasts across the bag dimension
214 | pseudo_labels = torch.zeros_like(distance_matrix)
215 | pseudo_labels[distance_matrix >= threshold_v] = 0.0
216 | pseudo_labels[distance_matrix < threshold_v] = 1.0
217 | elif method == 'threshold':
218 | # 3. use a fixed distance threshold per bag (not implemented)
219 | raise NotImplementedError("'threshold' pseudo-labeling is not implemented")
220 | else:
221 | raise ValueError('unknown pseudo-labeling method: {}'.format(method))
222 |
223 | ## visualize the pseudo-label distribution inside each bag
224 | # import matplotlib.pyplot as plt
225 | # plt.figure()
226 | # plt.hist(pseudo_labels.cpu().numpy().reshape(-1))
227 |
228 | return pseudo_labels
229 |
230 | def optimize_student(self, epoch):
231 | self.model_teacherHead.train()
232 | self.model_encoder.train()
233 | self.model_studentHead.train()
234 | ## optimize student with instance-dataloader
235 | # 1. change loader to instance-loader
236 | loader = self.train_instanceloader
237 | # 2. optimize
238 | patch_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) # only for patch-label available dataset
239 | patch_label_pred = torch.zeros([loader.dataset.__len__(), 1]).float().to(self.dev)
240 | bag_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
241 | patch_corresponding_slide_idx = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
242 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Student training')):
243 | for i, j in enumerate(label):
244 | if torch.is_tensor(j):
245 | label[i] = j.to(self.dev)
246 | selected = selected.squeeze(0)
247 | niter = epoch * len(loader) + iter
248 |
249 | data = data.to(self.dev)
250 |
251 | # get teacher output of instance
252 | feat = self.model_encoder(data)
253 | with torch.no_grad():
254 | instance_attn_score, _, _, _ = self.model_teacherHead(feat)
255 | pseudo_instance_label = self.norm_AttnScore2Prob(instance_attn_score[:, 1]).clamp(min=1e-5, max=1-1e-5).squeeze(0)
256 | # instances from negative bags are known negatives: force their pseudo-label to 0
257 | pseudo_instance_label[label[1] == 0] = 0
258 | # # DEBUG: Assign GT patch label
259 | # pseudo_instance_label = label[0]
260 | # get student output of instance
261 | patch_prediction = self.model_studentHead(feat)
262 | patch_prediction = torch.softmax(patch_prediction, dim=1)
263 |
264 | # cal loss
265 | loss_student = -1. * torch.mean(self.stu_loss_weight_neg * (1-pseudo_instance_label) * torch.log(patch_prediction[:, 0] + 1e-5) +
266 | (1-self.stu_loss_weight_neg) * pseudo_instance_label * torch.log(patch_prediction[:, 1] + 1e-5))
267 | self.optimizer_encoder.zero_grad()
268 | self.optimizer_studentHead.zero_grad()
269 | loss_student.backward()
270 | self.optimizer_encoder.step()
271 | self.optimizer_studentHead.step()
272 |
273 | patch_corresponding_slide_idx[selected, 0] = label[2]
274 | patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1]
275 | patch_label_gt[selected, 0] = label[0]
276 | bag_label_gt[selected, 0] = label[1]
277 | if niter % self.log_period == 0:
278 | self.writer.add_scalar('train_loss_Student', loss_student.item(), niter)
279 |
280 | instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1))
281 | self.writer.add_scalar('train_instance_AUC_byStudent', instance_auc_ByStudent, epoch)
282 | # print("Epoch:{} train_instance_AUC_byStudent:{}".format(epoch, instance_auc_ByStudent))
283 |
284 | # cal bag-level auc
285 | bag_label_gt_coarse = []
286 | bag_label_prediction = []
287 | available_bag_idx = patch_corresponding_slide_idx.unique()
288 | for bag_idx_i in available_bag_idx:
289 | idx_same_bag_i = torch.where(patch_corresponding_slide_idx == bag_idx_i)
290 | if bag_label_gt[idx_same_bag_i].max() != bag_label_gt[idx_same_bag_i].min():  # all instances of a bag must share one label
291 | raise ValueError('Inconsistent bag label inside bag {}'.format(bag_idx_i))
292 | bag_label_gt_coarse.append(bag_label_gt[idx_same_bag_i].max())
293 | bag_label_prediction.append(patch_label_pred[idx_same_bag_i].max())
294 | bag_label_gt_coarse = torch.tensor(bag_label_gt_coarse)
295 | bag_label_prediction = torch.tensor(bag_label_prediction)
296 | bag_auc_ByStudent = utliz.cal_auc(bag_label_gt_coarse.reshape(-1), bag_label_prediction.reshape(-1))
297 | self.writer.add_scalar('train_bag_AUC_byStudent', bag_auc_ByStudent, epoch)
298 | return 0
299 |
300 | def optimize_student_fromBank(self, epoch, Bank_all_instances_pred):
301 | self.model_teacherHead.train()
302 | self.model_encoder.train()
303 | self.model_studentHead.train()
304 | ## optimize student with instance-dataloader, using banked pseudo-labels
305 | # 1. change loader to instance-loader
306 | loader = self.train_instanceloader
307 | # 2. optimize
308 | patch_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) # only for patch-label available dataset
309 | patch_label_pred = torch.zeros([loader.dataset.__len__(), 1]).float().to(self.dev)
310 | bag_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
311 | patch_corresponding_slide_idx = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
312 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Student training')):
313 | for i, j in enumerate(label):
314 | if torch.is_tensor(j):
315 | label[i] = j.to(self.dev)
316 | selected = selected.squeeze(0)
317 | niter = epoch * len(loader) + iter
318 |
319 | data = data.to(self.dev)
320 |
321 | # encode instances (the teacher-based pseudo-labeling below is kept commented for reference)
322 | feat = self.model_encoder(data)
323 | # with torch.no_grad():
324 | # _, _, _, instance_attn_score = self.model_teacherHead(feat, returnBeforeSoftMaxA=True)
325 | # pseudo_instance_label = self.norm_AttnScore2Prob(instance_attn_score).clamp(min=1e-5, max=1-1e-5).squeeze(0)
326 | # # set true negative patch label to [1, 0]
327 | # pseudo_instance_label[label[1] == 0] = 0
328 |
329 | pseudo_instance_label = Bank_all_instances_pred[selected//100, selected%100]  # assumes a fixed bag length of 100 instances
330 | # # DEBUG: Assign GT patch label
331 | # pseudo_instance_label = label[0]
332 | # get student output of instance
333 | patch_prediction = self.model_studentHead(feat)
334 | patch_prediction = torch.softmax(patch_prediction, dim=1)
335 |
336 | # cal loss
337 | loss_student = -1. * torch.mean(0.1 * (1-pseudo_instance_label) * torch.log(patch_prediction[:, 0] + 1e-5) +
338 | 0.9 * pseudo_instance_label * torch.log(patch_prediction[:, 1] + 1e-5))
339 | self.optimizer_encoder.zero_grad()
340 | self.optimizer_studentHead.zero_grad()
341 | loss_student.backward()
342 | self.optimizer_encoder.step()
343 | self.optimizer_studentHead.step()
344 |
345 | patch_corresponding_slide_idx[selected, 0] = label[2]
346 | patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1]
347 | patch_label_gt[selected, 0] = label[0]
348 | bag_label_gt[selected, 0] = label[1]
349 | if niter % self.log_period == 0:
350 | self.writer.add_scalar('train_loss_Student', loss_student.item(), niter)
351 |
352 | self.Bank_all_instances_pred_byStudent = patch_label_pred
353 | instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1))
354 | bag_auc_ByStudent = 0  # bag-level AUC is not computed in this bank-based variant
355 | self.writer.add_scalar('train_instance_AUC_byStudent', instance_auc_ByStudent, epoch)
356 | self.writer.add_scalar('train_bag_AUC_byStudent', bag_auc_ByStudent, epoch)
357 | # print("Epoch:{} train_instance_AUC_byStudent:{}".format(epoch, instance_auc_ByStudent))
358 | return 0
359 |
360 | def evaluate(self, epoch, loader, log_name_prefix=''):  # unused placeholder hook
361 | return 0
362 |
363 | def evaluate_teacher(self, epoch):
364 | self.model_encoder.eval()
365 | self.model_teacherHead.eval()
366 | ## evaluate teacher with bag-dataloader
367 | # 1. change loader to bag-loader
368 | loader = self.test_bagloader
369 | # 2. evaluate
370 | patch_label_gt = []
371 | patch_label_pred = []
372 | bag_label_gt = []
373 | bag_label_prediction_withAttnScore = []
374 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Teacher evaluating')):
375 | for i, j in enumerate(label):
376 | if torch.is_tensor(j):
377 | label[i] = j.to(self.dev)
378 | selected = selected.squeeze(0)
379 | niter = epoch * len(loader) + iter
380 |
381 | data = data.to(self.dev)
382 | with torch.no_grad():
383 | feat = self.model_encoder(data.squeeze(0))
384 |
385 | instance_attn_score, bag_prediction_withAttnScore, _, _ = self.model_teacherHead(feat)
386 | bag_prediction_withAttnScore = torch.softmax(bag_prediction_withAttnScore, 1)
387 | # instance_attn_score = torch.softmax(instance_attn_score, dim=1)
388 |
389 | patch_label_pred.append(instance_attn_score[:, 1].detach().squeeze(0))
390 | patch_label_gt.append(label[0].squeeze(0))
391 | bag_label_prediction_withAttnScore.append(bag_prediction_withAttnScore.detach()[0, 1])
392 | bag_label_gt.append(label[1])
393 |
394 | patch_label_pred = torch.cat(patch_label_pred)
395 | patch_label_gt = torch.cat(patch_label_gt)
396 | bag_label_prediction_withAttnScore = torch.tensor(bag_label_prediction_withAttnScore)
397 | bag_label_gt = torch.cat(bag_label_gt)
398 |
399 | patch_label_pred_normed = (patch_label_pred - patch_label_pred.min()) / (patch_label_pred.max() - patch_label_pred.min())
400 | instance_auc_ByTeacher = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred_normed.reshape(-1))
401 | bag_auc_ByTeacher_withAttnScore = utliz.cal_auc(bag_label_gt.reshape(-1), bag_label_prediction_withAttnScore.reshape(-1))
402 | self.writer.add_scalar('test_instance_AUC_byTeacher', instance_auc_ByTeacher, epoch)
403 | self.writer.add_scalar('test_bag_AUC_byTeacher', bag_auc_ByTeacher_withAttnScore, epoch)
404 | return 0
405 |
406 | def evaluate_student(self, epoch):
407 | self.model_encoder.eval()
408 | self.model_studentHead.eval()
409 | ## evaluate student with instance-dataloader
410 | # 1. change loader to instance-loader
411 | loader = self.test_instanceloader
412 | # 2. evaluate
413 | patch_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) # only for patch-label available dataset
414 | patch_label_pred = torch.zeros([loader.dataset.__len__(), 1]).float().to(self.dev)
415 | bag_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
416 | patch_corresponding_slide_idx = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
417 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Student evaluating')):
418 | for i, j in enumerate(label):
419 | if torch.is_tensor(j):
420 | label[i] = j.to(self.dev)
421 | selected = selected.squeeze(0)
422 | niter = epoch * len(loader) + iter
423 |
424 | data = data.to(self.dev)
425 |
426 | # get student output of instance
427 | with torch.no_grad():
428 | feat = self.model_encoder(data)
429 | patch_prediction = self.model_studentHead(feat)
430 | patch_prediction = torch.softmax(patch_prediction, dim=1)
431 |
432 | patch_corresponding_slide_idx[selected, 0] = label[2]
433 | patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1]
434 | patch_label_gt[selected, 0] = label[0]
435 | bag_label_gt[selected, 0] = label[1]
436 |
437 | instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1))
438 | self.writer.add_scalar('test_instance_AUC_byStudent', instance_auc_ByStudent, epoch)
439 | # print("Epoch:{} test_instance_AUC_byStudent:{}".format(epoch, instance_auc_ByStudent))
440 |
441 | # cal bag-level auc
442 | bag_label_gt_coarse = []
443 | bag_label_prediction = []
444 | available_bag_idx = patch_corresponding_slide_idx.unique()
445 | for bag_idx_i in available_bag_idx:
446 | idx_same_bag_i = torch.where(patch_corresponding_slide_idx == bag_idx_i)
447 | if bag_label_gt[idx_same_bag_i].max() != bag_label_gt[idx_same_bag_i].min():  # all instances of a bag must share one label
448 | raise ValueError('Inconsistent bag label inside bag {}'.format(bag_idx_i))
449 | bag_label_gt_coarse.append(bag_label_gt[idx_same_bag_i].max())
450 | bag_label_prediction.append(patch_label_pred[idx_same_bag_i].max())
451 | bag_label_gt_coarse = torch.tensor(bag_label_gt_coarse)
452 | bag_label_prediction = torch.tensor(bag_label_prediction)
453 | bag_auc_ByStudent = utliz.cal_auc(bag_label_gt_coarse.reshape(-1), bag_label_prediction.reshape(-1))
454 | self.writer.add_scalar('test_bag_AUC_byStudent', bag_auc_ByStudent, epoch)
455 | return 0
456 |
457 |
458 | def str2bool(v):
459 | """
460 | Input:
461 | v - string
462 | output:
463 | True/False
464 | """
465 | if isinstance(v, bool):
466 | return v
467 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
468 | return True
469 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
470 | return False
471 | else:
472 | raise argparse.ArgumentTypeError('Boolean value expected.')
473 |
474 |
475 | def get_parser():
476 | parser = argparse.ArgumentParser(description='PyTorch Implementation of Self-Label')
477 | # optimizer
478 | parser.add_argument('--epochs', default=1500, type=int, help='number of epochs')
479 | parser.add_argument('--batch_size', default=512, type=int, help='batch size (default: 512)')
480 | parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate (default: 0.001)')
481 | parser.add_argument('--lrdrop', default=1500, type=int, help='multiply LR by 0.5 every N epochs (default: 1500)')
482 | parser.add_argument('--wd', default=-5, type=float, help='weight decay pow (default: -5)')
483 | parser.add_argument('--dtype', default='f64', choices=['f64', 'f32'], type=str, help='SK-algo dtype (default: f64)')
484 |
485 | # SK algo
486 | parser.add_argument('--nopts', default=100, type=int, help='number of pseudo-opts (default: 100)')
487 | parser.add_argument('--augs', default=3, type=int, help='augmentation level (default: 3)')
488 | parser.add_argument('--lamb', default=25, type=int, help='for pseudoopt: lambda (default:25) ')
489 |
490 | # architecture
491 | # parser.add_argument('--arch', default='alexnet_MNIST', type=str, help='alexnet or resnet (default: alexnet)')
492 |
493 | # housekeeping
494 | parser.add_argument('--device', default='0', type=str, help='GPU devices to use for storage and model')
495 | parser.add_argument('--modeldevice', default='0', type=str, help='GPU numbers on which the CNN runs')
496 | parser.add_argument('--exp', default='self-label-default', help='path to experiment directory')
497 | parser.add_argument('--workers', default=0, type=int, help='number of data-loading workers (default: 0)')
498 | parser.add_argument('--comment', default='DEBUG_BagDistillation_DSMIL_BasedOnFeat', type=str, help='name for tensorboardX')
499 | parser.add_argument('--log-intv', default=1, type=int, help='save stuff every x epochs (default: 1)')
500 | parser.add_argument('--log_iter', default=200, type=int, help='log every x-th batch (default: 200)')
501 | parser.add_argument('--seed', default=10, type=int, help='random seed')
502 |
503 | parser.add_argument('--PLPostProcessMethod', default='NegGuide', type=str,
504 | help='Post-processing method of attention scores to build pseudo labels',
505 | choices=['NegGuide', 'NegGuide_TopK', 'NegGuide_Similarity'])
506 | parser.add_argument('--StuFilterType', default='FilterNegInstance__ThreProb50', type=str,
507 | help='How the student prediction is used to improve the teacher '
508 | '[ReplaceAS, FilterNegInstance_Top95, FilterNegInstance__ThreProb90]')
509 | parser.add_argument('--smoothE', default=9999, type=int, help='num of epoch to apply StuFilter')
510 | parser.add_argument('--stu_loss_weight_neg', default=0.1, type=float, help='weight of neg instances in stu training')
511 | parser.add_argument('--stuOptPeriod', default=1, type=int, help='period of stu optimization')
512 | return parser.parse_args()
513 |
514 |
515 | if __name__ == "__main__":
516 | args = get_parser()
517 |
518 | # torch.manual_seed(args.seed)
519 | # random.seed(args.seed)
520 | # np.random.seed(args.seed)
521 |
522 | name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")+"_%s" % args.comment.replace('/', '_') + \
523 | "_Seed{}_Bs{}_lr{}_PLPostProcessBy{}_StuFilterType{}_smoothE{}_weightN{}_StuOptP{}".format(
524 | args.seed, args.batch_size, args.lr,
525 | args.PLPostProcessMethod, args.StuFilterType, args.smoothE, args.stu_loss_weight_neg, args.stuOptPeriod)
526 | try:
527 | args.device = [int(item) for item in args.device.split(',')]
528 | except AttributeError:
529 | args.device = [int(args.device)]
530 | args.modeldevice = args.device
531 | util.setup_runtime(seed=42, cuda_dev_id=list(np.unique(args.modeldevice + args.device)))
532 |
533 | print(name, flush=True)
534 |
535 | writer = SummaryWriter('./runs_Cervical/%s'%name)
536 | writer.add_text('args', " \n".join(['%s %s' % (arg, getattr(args, arg)) for arg in vars(args)]))
537 |
538 | # Setup model
539 | model_encoder = camelyon_feat_projecter(input_dim=512, output_dim=512).to('cuda:0')
540 | model_teacherHead = teacher_DSMIL_head(input_feat_dim=512).to('cuda:0')
541 | model_studentHead = student_head(input_feat_dim=512).to('cuda:0')
542 |
543 | optimizer_encoder = torch.optim.SGD(model_encoder.parameters(), lr=args.lr)
544 | optimizer_teacherHead = torch.optim.SGD(model_teacherHead.parameters(), lr=args.lr)
545 | optimizer_studentHead = torch.optim.SGD(model_studentHead.parameters(), lr=args.lr)
546 |
547 | # Setup loaders
548 | train_ds_return_instance = CervicalCaner_16_feat(train=True, return_bag=False)
549 |
550 | train_ds_return_bag = copy.deepcopy(train_ds_return_instance)
551 | train_ds_return_bag.return_bag = True
552 | val_ds_return_instance = CervicalCaner_16_feat(train=False, return_bag=False)
553 | val_ds_return_bag = CervicalCaner_16_feat(train=False, return_bag=True)
554 |
555 | train_loader_instance = torch.utils.data.DataLoader(train_ds_return_instance, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=False)
556 | train_loader_bag = torch.utils.data.DataLoader(train_ds_return_bag, batch_size=1, shuffle=True, num_workers=args.workers, drop_last=False)
557 | val_loader_instance = torch.utils.data.DataLoader(val_ds_return_instance, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, drop_last=False)
558 | val_loader_bag = torch.utils.data.DataLoader(val_ds_return_bag, batch_size=1, shuffle=False, num_workers=args.workers, drop_last=False)
559 |
560 | print("[Data] {} training samples".format(len(train_loader_instance.dataset)))
561 | print("[Data] {} evaluating samples".format(len(val_loader_instance.dataset)))
562 |
563 | if torch.cuda.device_count() > 1:
564 | print("Let's use", len(args.modeldevice), "GPUs for the model")
565 | if len(args.modeldevice) == 1:
566 | print('single GPU model', flush=True)
567 | else:
568 | model_encoder = nn.DataParallel(model_encoder, device_ids=list(range(len(args.modeldevice))))
569 | model_teacherHead = nn.DataParallel(model_teacherHead, device_ids=list(range(len(args.modeldevice))))
570 |
571 | # Setup training driver (the Optimizer class runs the full train/eval loop)
572 | o = Optimizer(model_encoder=model_encoder, model_teacherHead=model_teacherHead, model_studentHead=model_studentHead,
573 | optimizer_encoder=optimizer_encoder, optimizer_teacherHead=optimizer_teacherHead, optimizer_studentHead=optimizer_studentHead,
574 | train_bagloader=train_loader_bag, train_instanceloader=train_loader_instance,
575 | test_bagloader=val_loader_bag, test_instanceloader=val_loader_instance,
576 | writer=writer, num_epoch=args.epochs,
577 | dev=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
578 | PLPostProcessMethod=args.PLPostProcessMethod, StuFilterType=args.StuFilterType, smoothE=args.smoothE,
579 | stu_loss_weight_neg=args.stu_loss_weight_neg, stuOptPeriod=args.stuOptPeriod)
580 | # Optimize
581 | o.optimize()
--------------------------------------------------------------------------------
/train_CAMELYONFeat_BagDistillationDSMIL_SharedEnc_Similarity_StuFilterSmoothed_DropPos.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import warnings
3 | import os
4 | import time
5 | import numpy as np
6 |
7 | import torch
8 | import torch.optim
9 | import torch.nn as nn
10 | import torch.utils.data
11 | from tensorboardX import SummaryWriter
12 | # import models
13 | # from models.alexnet import alexnet_CIFAR10, alexnet_CIFAR10_Attention
14 | from models.alexnet import camelyon_feat_projecter, teacher_DSMIL_head, student_head
15 | # from dataset_toy import Dataset_toy
16 | # from Datasets_loader.dataset_MNIST_challenge import MNIST_WholeSlide_challenge
17 | # from Datasets_loader.dataset_MIL_CIFAR import CIFAR_WholeSlide_challenge
18 | from Datasets_loader.dataset_CAMELYON16_BasedOnFeat import CAMELYON_16_feat
19 | import datetime
20 | import utliz
21 | import util
22 | import random
23 | from tqdm import tqdm
24 | import copy
25 |
26 |
27 | class Optimizer:
28 | def __init__(self, model_encoder, model_teacherHead, model_studentHead,
29 | optimizer_encoder, optimizer_teacherHead, optimizer_studentHead,
30 | train_bagloader, train_instanceloader, test_bagloader, test_instanceloader,
31 | writer=None, num_epoch=100,
32 | dev=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
33 | PLPostProcessMethod='NegGuide', StuFilterType='ReplaceAS', smoothE=100,
34 | stu_loss_weight_neg=0.1, stuOptPeriod=1):
35 | self.model_encoder = model_encoder
36 | self.model_teacherHead = model_teacherHead
37 | self.model_studentHead = model_studentHead
38 | self.optimizer_encoder = optimizer_encoder
39 | self.optimizer_teacherHead = optimizer_teacherHead
40 | self.optimizer_studentHead = optimizer_studentHead
41 | self.train_bagloader = train_bagloader
42 | self.train_instanceloader = train_instanceloader
43 | self.test_bagloader = test_bagloader
44 | self.test_instanceloader = test_instanceloader
45 | self.writer = writer
46 | self.num_epoch = num_epoch
47 | self.dev = dev
48 | self.log_period = 10
49 | self.PLPostProcessMethod = PLPostProcessMethod
50 | self.StuFilterType = StuFilterType
51 | self.smoothE = smoothE
52 | self.stu_loss_weight_neg = stu_loss_weight_neg
53 | self.stuOptPeriod = stuOptPeriod
54 |
55 | def optimize(self):
56 | self.Bank_all_Bags_label = None
57 | self.Bank_all_instances_pred_byTeacher = None
58 | self.Bank_all_instances_feat_byTeacher = None
59 | self.Bank_all_instances_pred_processed = None
60 |
61 | self.Bank_all_instances_pred_byStudent = None
62 |
63 | # Load pre-extracted SimCLR features
64 | # pre_trained_SimCLR_feat = self.train_instanceloader.dataset.ds_data_simCLR_feat[self.train_instanceloader.dataset.idx_all_slides].to(self.dev)
65 | for epoch in range(self.num_epoch):
66 | self.optimize_teacher(epoch)
67 | self.evaluate_teacher(epoch)
68 | if epoch % self.stuOptPeriod == 0:
69 | self.optimize_student(epoch)
70 | self.evaluate_student(epoch)
71 |
72 | return 0
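    | # Each epoch runs one round of bi-directional distillation: the teacher (bag
    | # branch) trains on slide labels, then the student (instance branch) trains on
    | # pseudo-labels distilled from the teacher's attention, both sharing
    | # self.model_encoder so improvements propagate in both directions.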
73 |
74 | def optimize_teacher(self, epoch):
75 | self.model_encoder.train()
76 | self.model_teacherHead.train()
77 | self.model_studentHead.eval()
78 | criterion = torch.nn.CrossEntropyLoss()
79 | ## optimize teacher with bag-dataloader
80 | # 1. change loader to bag-loader
81 | loader = self.train_bagloader
82 | # 2. optimize
83 | patch_label_gt = []
84 | patch_label_pred = []
85 | bag_label_gt = []
86 | bag_label_pred = []
87 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Teacher training')):
88 | for i, j in enumerate(label):
89 | if torch.is_tensor(j):
90 | label[i] = j.to(self.dev)
91 | selected = selected.squeeze(0)
92 | niter = epoch * len(loader) + iter
93 |
94 | data = data.to(self.dev)
95 | feat = self.model_encoder(data.squeeze(0))
96 | if epoch > self.smoothE:
97 | if "FilterNegInstance" in self.StuFilterType:
98 | # use the student's predictions to filter instances of the positive bag (this "DropPos" variant drops the most confidently positive ones)
99 | if label[1] == 1:
100 | with torch.no_grad():
101 | pred_byStudent = self.model_studentHead(feat)
102 | pred_byStudent = torch.softmax(pred_byStudent, dim=1)[:, 1]
103 | if '_Top' in self.StuFilterType:
104 | # strategy A: keep only the K instances with the lowest positive probability
105 | idx_to_keep = torch.topk(-pred_byStudent, k=int(self.StuFilterType.split('_Top')[-1]))[1]
106 | elif '_ThreProb' in self.StuFilterType:
107 | # strategy B: keep only instances with positive probability at most K%
108 | idx_to_keep = torch.where(pred_byStudent <= int(self.StuFilterType.split('_ThreProb')[-1])/100.0)[0]
109 | if idx_to_keep.shape[0] == 0:  # if all instances were dropped, keep the most positive one
110 | idx_to_keep = torch.topk(pred_byStudent, k=1)[1]
111 | feat_removedNeg = feat[idx_to_keep]
112 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat_removedNeg)
113 | instance_attn_score = torch.cat([instance_attn_score, instance_attn_score[:, 1].min()*torch.ones(feat.shape[0]-instance_attn_score.shape[0], 2).to(instance_attn_score.device)], dim=0)  # pad dropped instances with the minimum attention score so the bag length is preserved
114 | else:
115 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat)
116 | else:
117 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat)
118 | else:
119 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat)
120 |
121 | max_id = torch.argmax(instance_attn_score[:, 1])
122 | bag_pred_byMax = instance_attn_score[max_id, :].squeeze(0)
123 | bag_loss = criterion(bag_prediction, label[1])
124 | bag_loss_byMax = criterion(bag_pred_byMax.unsqueeze(0), label[1])
125 | loss_teacher = 0.5 * bag_loss + 0.5 * bag_loss_byMax
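    | # The teacher loss averages two bag-level cross-entropy streams (DSMIL-style):
    | # the attention-aggregated bag prediction and the score of the single most
    | # attended ("critical") instance, weighted 0.5 / 0.5.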
126 |
127 | self.optimizer_encoder.zero_grad()
128 | self.optimizer_teacherHead.zero_grad()
129 | loss_teacher.backward()
130 | self.optimizer_encoder.step()
131 | self.optimizer_teacherHead.step()
132 |
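    | # Bag probability is taken from the bag stream only (the max-instance stream
    | # is weighted 0 here).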
133 | bag_prediction = 1.0 * torch.softmax(bag_prediction, dim=1) + \
134 | 0.0 * torch.softmax(bag_pred_byMax.unsqueeze(0), dim=1)
135 | # instance_attn_score = torch.softmax(instance_attn_score, dim=1)
136 |
137 | patch_label_pred.append(instance_attn_score[:, 1].detach().squeeze(0))
138 | patch_label_gt.append(label[0].squeeze(0))
139 | bag_label_pred.append(bag_prediction.detach()[0, 1])
140 | bag_label_gt.append(label[1])
141 | if niter % self.log_period == 0:
142 | self.writer.add_scalar('train_loss_Teacher', loss_teacher.item(), niter)
143 |
144 | patch_label_pred = torch.cat(patch_label_pred)
145 | patch_label_gt = torch.cat(patch_label_gt)
146 | bag_label_pred = torch.tensor(bag_label_pred)
147 | bag_label_gt = torch.cat(bag_label_gt)
148 |
149 | self.estimated_AttnScore_norm_para_min = patch_label_pred.min()
150 | self.estimated_AttnScore_norm_para_max = patch_label_pred.max()
151 | patch_label_pred_normed = self.norm_AttnScore2Prob(patch_label_pred)
152 | instance_auc_ByTeacher = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred_normed.reshape(-1))
153 |
154 | bag_auc_ByTeacher = utliz.cal_auc(bag_label_gt.reshape(-1), bag_label_pred.reshape(-1))
155 | self.writer.add_scalar('train_instance_AUC_byTeacher', instance_auc_ByTeacher, epoch)
156 | self.writer.add_scalar('train_bag_AUC_byTeacher', bag_auc_ByTeacher, epoch)
157 | # print("Epoch:{} train_bag_AUC_byTeacher:{}".format(epoch, bag_auc_ByTeacher))
158 | return 0
159 |
160 | def norm_AttnScore2Prob(self, attn_score):
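    | # Min-max normalize raw attention scores into [0, 1] pseudo-probabilities,
    | # using the min/max cached during the most recent teacher training epoch.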
161 | prob = (attn_score - self.estimated_AttnScore_norm_para_min) / (self.estimated_AttnScore_norm_para_max - self.estimated_AttnScore_norm_para_min)
162 | return prob
163 |
164 | def post_process_pred_byTeacher(self, Bank_all_instances_feat, Bank_all_instances_pred, Bank_all_bags_label, method='NegGuide'):
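    | # Turn banked teacher predictions into instance pseudo-labels:
    | #   'NegGuide':            zero out all instances of negative bags
    | #   'NegGuide_TopK':       additionally set the top-K instances of each positive bag to 1
    | #   'NegGuide_Similarity': label positive-bag instances by feature distance to the most salient instance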
165 | if method=='NegGuide':
166 | Bank_all_instances_pred_processed = Bank_all_instances_pred.clone()
167 | Bank_all_instances_pred_processed = self.norm_AttnScore2Prob(Bank_all_instances_pred_processed).clamp(min=1e-5, max=1 - 1e-5)
168 | idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0]
169 | Bank_all_instances_pred_processed[idx_neg_bag, :] = 0
170 | elif method=='NegGuide_TopK':
171 | Bank_all_instances_pred_processed = Bank_all_instances_pred.clone()
172 | Bank_all_instances_pred_processed = self.norm_AttnScore2Prob(Bank_all_instances_pred_processed).clamp(min=1e-5, max=1 - 1e-5)
173 | idx_pos_bag = torch.where(Bank_all_bags_label[:, 0] == 1)[0]
174 | idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0]
175 | K = 3
176 | idx_topK_inside_pos_bag = torch.topk(Bank_all_instances_pred_processed[idx_pos_bag, :], k=K, dim=-1, largest=True)[1]
    | # NOTE: advanced indexing returns a copy, so an in-place scatter_ on it is lost; assign the result back
177 | Bank_all_instances_pred_processed[idx_pos_bag] = Bank_all_instances_pred_processed[idx_pos_bag].scatter(dim=1, index=idx_topK_inside_pos_bag, value=1)
178 | Bank_all_instances_pred_processed[idx_neg_bag, :] = 0
179 | elif method=='NegGuide_Similarity':
180 | Bank_all_instances_pred_processed = Bank_all_instances_pred.clone()
181 | Bank_all_instances_pred_processed = self.norm_AttnScore2Prob(Bank_all_instances_pred_processed).clamp(min=1e-5, max=1 - 1e-5)
182 | idx_pos_bag = torch.where(Bank_all_bags_label[:, 0] == 1)[0]
183 | idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0]
184 | K = 1
185 | idx_topK_inside_pos_bag = torch.topk(Bank_all_instances_pred_processed[idx_pos_bag, :], k=K, dim=-1, largest=True)[1]
    | # Same copy-vs-view pitfall as above; note these rows are overwritten by distanceMatrix2PL below
186 | Bank_all_instances_pred_processed[idx_pos_bag] = Bank_all_instances_pred_processed[idx_pos_bag].scatter(dim=1, index=idx_topK_inside_pos_bag, value=1)
187 | Bank_all_Pos_instances_feat = Bank_all_instances_feat[idx_pos_bag]
188 | Bank_mostSalient_Pos_instances_feat = []
189 | for i in range(Bank_all_Pos_instances_feat.shape[0]):
190 | Bank_mostSalient_Pos_instances_feat.append(Bank_all_Pos_instances_feat[i, idx_topK_inside_pos_bag[i, 0], :].unsqueeze(0).unsqueeze(0))
191 | Bank_mostSalient_Pos_instances_feat = torch.cat(Bank_mostSalient_Pos_instances_feat, dim=0)
192 |
193 | distance_matrix = Bank_all_Pos_instances_feat - Bank_mostSalient_Pos_instances_feat
194 | distance_matrix = torch.norm(distance_matrix, dim=-1, p=2)
195 | Bank_all_instances_pred_processed[idx_pos_bag, :] = self.distanceMatrix2PL(distance_matrix)
196 | Bank_all_instances_pred_processed[idx_neg_bag, :] = 0
197 | else:
198 | raise ValueError('Unknown post-processing method: {}'.format(method))
199 | return Bank_all_instances_pred_processed
200 |
201 | def distanceMatrix2PL(self, distance_matrix, method='percentage'):
202 | # distance_matrix is of shape NxL (num of positive bags x bag length) and holds
203 | # the distance between each instance and its bag's most salient instance.
204 | # Returns pseudo-labels of shape NxL with values in [0, 1].
205 |
206 | if method == 'softMax':
207 | # 1. softmax over inverse distances so PL values fall into [0, 1]
208 | similarity_matrix = 1/(distance_matrix + 1e-5)
209 | pseudo_labels = torch.softmax(similarity_matrix, dim=1)
210 | elif method == 'percentage':
211 | # 2. percentage rule: the p fraction of instances farthest from the salient instance get PL=0, the rest PL=1
212 | p = 0.1  # drop the 10% farthest instances
213 | threshold_v = distance_matrix.topk(k=int(100 * p), dim=1)[0][:, -1].unsqueeze(1).repeat([1, 100])  # of size Nx100 (assumes bag length 100)
214 | pseudo_labels = torch.zeros_like(distance_matrix)
215 | pseudo_labels[distance_matrix >= threshold_v] = 0.0  # explicit; tensor is already zero-initialized
216 | pseudo_labels[distance_matrix < threshold_v] = 1.0
217 | elif method == 'threshold':
218 | # 3. fixed-threshold rule on the distances (not implemented)
219 | raise NotImplementedError('threshold-based pseudo-labeling is not implemented')
220 | else:
221 | raise ValueError('Unknown pseudo-label method: {}'.format(method))
222 |
223 | ## visualize the distribution of pseudo_labels inside each bag
224 | # import matplotlib.pyplot as plt
225 | # plt.figure()
226 | # plt.hist(pseudo_labels.cpu().numpy().reshape(-1))
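    | ## toy sanity check (hypothetical values): with one bag of four instances,
    | ## 'softMax' puts the largest pseudo-label on the closest instance, e.g.
    | # d = torch.tensor([[0.0, 1.0, 2.0, 3.0]])
    | # pl = self.distanceMatrix2PL(d, method='softMax')  # mass concentrates at distance 0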
227 |
228 | return pseudo_labels
229 |
230 | def optimize_student(self, epoch):
231 | self.model_teacherHead.train()
232 | self.model_encoder.train()
233 | self.model_studentHead.train()
234 | ## optimize student with instance-dataloader
235 | # 1. change loader to instance-loader
236 | loader = self.train_instanceloader
237 | # 2. optimize
238 | patch_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) # only for patch-label available dataset
239 | patch_label_pred = torch.zeros([loader.dataset.__len__(), 1]).float().to(self.dev)
240 | bag_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
241 | patch_corresponding_slide_idx = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
242 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Student training')):
243 | for i, j in enumerate(label):
244 | if torch.is_tensor(j):
245 | label[i] = j.to(self.dev)
246 | selected = selected.squeeze(0)
247 | niter = epoch * len(loader) + iter
248 |
249 | data = data.to(self.dev)
250 |
251 | # get teacher output of instance
252 | feat = self.model_encoder(data)
253 | with torch.no_grad():
254 | instance_attn_score, _, _, _ = self.model_teacherHead(feat)
255 | pseudo_instance_label = self.norm_AttnScore2Prob(instance_attn_score[:, 1]).clamp(min=1e-5, max=1-1e-5).squeeze(0)
256 | # set true negative patch label to [1, 0]
257 | pseudo_instance_label[label[1] == 0] = 0
258 | # # DEBUG: Assign GT patch label
259 | # pseudo_instance_label = label[0]
260 | # get student output of instance
261 | patch_prediction = self.model_studentHead(feat)
262 | patch_prediction = torch.softmax(patch_prediction, dim=1)
263 |
264 | # cal loss
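    | # Weighted soft-label cross-entropy: the teacher's normalized attention score
    | # serves as a soft target, with `stu_loss_weight_neg` down-weighting the
    | # (far more numerous) negative instances.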
265 | loss_student = -1. * torch.mean(self.stu_loss_weight_neg * (1-pseudo_instance_label) * torch.log(patch_prediction[:, 0] + 1e-5) +
266 | (1-self.stu_loss_weight_neg) * pseudo_instance_label * torch.log(patch_prediction[:, 1] + 1e-5))
267 | self.optimizer_encoder.zero_grad()
268 | self.optimizer_studentHead.zero_grad()
269 | loss_student.backward()
270 | self.optimizer_encoder.step()
271 | self.optimizer_studentHead.step()
272 |
273 | patch_corresponding_slide_idx[selected, 0] = label[2]
274 | patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1]
275 | patch_label_gt[selected, 0] = label[0]
276 | bag_label_gt[selected, 0] = label[1]
277 | if niter % self.log_period == 0:
278 | self.writer.add_scalar('train_loss_Student', loss_student.item(), niter)
279 |
280 | instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1))
281 | self.writer.add_scalar('train_instance_AUC_byStudent', instance_auc_ByStudent, epoch)
282 | # print("Epoch:{} train_instance_AUC_byStudent:{}".format(epoch, instance_auc_ByStudent))
283 |
284 | # cal bag-level auc
285 | bag_label_gt_coarse = []
286 | bag_label_prediction = []
287 | available_bag_idx = patch_corresponding_slide_idx.unique()
288 | for bag_idx_i in available_bag_idx:
289 | idx_same_bag_i = torch.where(patch_corresponding_slide_idx == bag_idx_i)
290 | if bag_label_gt[idx_same_bag_i].max() != bag_label_gt[idx_same_bag_i].min():
291 | raise ValueError('Inconsistent bag labels within bag {}'.format(bag_idx_i))
292 | bag_label_gt_coarse.append(bag_label_gt[idx_same_bag_i].max())
293 | bag_label_prediction.append(patch_label_pred[idx_same_bag_i].max())
294 | bag_label_gt_coarse = torch.tensor(bag_label_gt_coarse)
295 | bag_label_prediction = torch.tensor(bag_label_prediction)
296 | bag_auc_ByStudent = utliz.cal_auc(bag_label_gt_coarse.reshape(-1), bag_label_prediction.reshape(-1))
297 | self.writer.add_scalar('train_bag_AUC_byStudent', bag_auc_ByStudent, epoch)
298 | return 0
299 |
300 | def optimize_student_fromBank(self, epoch, Bank_all_instances_pred):
301 | self.model_teacherHead.train()
302 | self.model_encoder.train()
303 | self.model_studentHead.train()
304 | ## optimize student with instance-dataloader
305 | # 1. change loader to instance-loader
306 | loader = self.train_instanceloader
307 | # 2. optimize
308 | patch_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) # only for patch-label available dataset
309 | patch_label_pred = torch.zeros([loader.dataset.__len__(), 1]).float().to(self.dev)
310 | bag_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
311 | patch_corresponding_slide_idx = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
312 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Student training')):
313 | for i, j in enumerate(label):
314 | if torch.is_tensor(j):
315 | label[i] = j.to(self.dev)
316 | selected = selected.squeeze(0)
317 | niter = epoch * len(loader) + iter
318 |
319 | data = data.to(self.dev)
320 |
321 | # get teacher output of instance
322 | feat = self.model_encoder(data)
323 | # with torch.no_grad():
324 | # _, _, _, instance_attn_score = self.model_teacherHead(feat, returnBeforeSoftMaxA=True)
325 | # pseudo_instance_label = self.norm_AttnScore2Prob(instance_attn_score).clamp(min=1e-5, max=1-1e-5).squeeze(0)
326 | # # set true negative patch label to [1, 0]
327 | # pseudo_instance_label[label[1] == 0] = 0
328 |
329 | pseudo_instance_label = Bank_all_instances_pred[selected//100, selected%100]
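    | # `selected` is a flat instance index; //100 and %100 recover (bag, position),
    | # assuming every bag holds exactly 100 instances.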
330 | # # DEBUG: Assign GT patch label
331 | # pseudo_instance_label = label[0]
332 | # get student output of instance
333 | patch_prediction = self.model_studentHead(feat)
334 | patch_prediction = torch.softmax(patch_prediction, dim=1)
335 |
336 | # cal loss
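    | # Fixed 0.1/0.9 weighting here; this bank-based variant does not read stu_loss_weight_neg.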
337 | loss_student = -1. * torch.mean(0.1 * (1-pseudo_instance_label) * torch.log(patch_prediction[:, 0] + 1e-5) +
338 | 0.9 * pseudo_instance_label * torch.log(patch_prediction[:, 1] + 1e-5))
339 | self.optimizer_encoder.zero_grad()
340 | self.optimizer_studentHead.zero_grad()
341 | loss_student.backward()
342 | self.optimizer_encoder.step()
343 | self.optimizer_studentHead.step()
344 |
345 | patch_corresponding_slide_idx[selected, 0] = label[2]
346 | patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1]
347 | patch_label_gt[selected, 0] = label[0]
348 | bag_label_gt[selected, 0] = label[1]
349 | if niter % self.log_period == 0:
350 | self.writer.add_scalar('train_loss_Student', loss_student.item(), niter)
351 |
352 | self.Bank_all_instances_pred_byStudent = patch_label_pred
353 | instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1))
354 | bag_auc_ByStudent = 0  # bag-level AUC is not computed in this bank-based variant
355 | self.writer.add_scalar('train_instance_AUC_byStudent', instance_auc_ByStudent, epoch)
356 | self.writer.add_scalar('train_bag_AUC_byStudent', bag_auc_ByStudent, epoch)
357 | # print("Epoch:{} train_instance_AUC_byStudent:{}".format(epoch, instance_auc_ByStudent))
358 | return 0
359 |
360 | def evaluate(self, epoch, loader, log_name_prefix=''):
361 | return 0
362 |
363 | def evaluate_teacher(self, epoch):
364 | self.model_encoder.eval()
365 | self.model_teacherHead.eval()
366 | ## evaluate teacher with bag-dataloader
367 | # 1. change loader to bag-loader
368 | loader = self.test_bagloader
369 | # 2. evaluate
370 | patch_label_gt = []
371 | patch_label_pred = []
372 | bag_label_gt = []
373 | bag_label_prediction_withAttnScore = []
374 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Teacher evaluating')):
375 | for i, j in enumerate(label):
376 | if torch.is_tensor(j):
377 | label[i] = j.to(self.dev)
378 | selected = selected.squeeze(0)
379 | niter = epoch * len(loader) + iter
380 |
381 | data = data.to(self.dev)
382 | with torch.no_grad():
383 | feat = self.model_encoder(data.squeeze(0))
384 |
385 | instance_attn_score, bag_prediction_withAttnScore, _, _ = self.model_teacherHead(feat)
386 | bag_prediction_withAttnScore = torch.softmax(bag_prediction_withAttnScore, 1)
387 | # instance_attn_score = torch.softmax(instance_attn_score, dim=1)
388 |
389 | patch_label_pred.append(instance_attn_score[:, 1].detach().squeeze(0))
390 | patch_label_gt.append(label[0].squeeze(0))
391 | bag_label_prediction_withAttnScore.append(bag_prediction_withAttnScore.detach()[0, 1])
392 | bag_label_gt.append(label[1])
393 |
394 | patch_label_pred = torch.cat(patch_label_pred)
395 | patch_label_gt = torch.cat(patch_label_gt)
396 | bag_label_prediction_withAttnScore = torch.tensor(bag_label_prediction_withAttnScore)
397 | bag_label_gt = torch.cat(bag_label_gt)
398 |
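    | # At test time the scores are min-max normalized over the test set itself,
    | # rather than with the training statistics cached in norm_AttnScore2Prob.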
399 | patch_label_pred_normed = (patch_label_pred - patch_label_pred.min()) / (patch_label_pred.max() - patch_label_pred.min())
400 | instance_auc_ByTeacher = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred_normed.reshape(-1))
401 | bag_auc_ByTeacher_withAttnScore = utliz.cal_auc(bag_label_gt.reshape(-1), bag_label_prediction_withAttnScore.reshape(-1))
402 | self.writer.add_scalar('test_instance_AUC_byTeacher', instance_auc_ByTeacher, epoch)
403 | self.writer.add_scalar('test_bag_AUC_byTeacher', bag_auc_ByTeacher_withAttnScore, epoch)
404 | return 0
405 |
406 | def evaluate_student(self, epoch):
407 | self.model_encoder.eval()
408 | self.model_studentHead.eval()
409 | ## evaluate student with instance-dataloader
410 | # 1. change loader to instance-loader
411 | loader = self.test_instanceloader
412 | # 2. evaluate
413 | patch_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) # only for patch-label available dataset
414 | patch_label_pred = torch.zeros([loader.dataset.__len__(), 1]).float().to(self.dev)
415 | bag_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
416 | patch_corresponding_slide_idx = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev)
417 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Student evaluating')):
418 | for i, j in enumerate(label):
419 | if torch.is_tensor(j):
420 | label[i] = j.to(self.dev)
421 | selected = selected.squeeze(0)
422 | niter = epoch * len(loader) + iter
423 |
424 | data = data.to(self.dev)
425 |
426 | # get student output of instance
427 | with torch.no_grad():
428 | feat = self.model_encoder(data)
429 | patch_prediction = self.model_studentHead(feat)
430 | patch_prediction = torch.softmax(patch_prediction, dim=1)
431 |
432 | patch_corresponding_slide_idx[selected, 0] = label[2]
433 | patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1]
434 | patch_label_gt[selected, 0] = label[0]
435 | bag_label_gt[selected, 0] = label[1]
436 |
437 | instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1))
438 | self.writer.add_scalar('test_instance_AUC_byStudent', instance_auc_ByStudent, epoch)
439 | # print("Epoch:{} test_instance_AUC_byStudent:{}".format(epoch, instance_auc_ByStudent))
440 |
441 | # cal bag-level auc
442 | bag_label_gt_coarse = []
443 | bag_label_prediction = []
444 | available_bag_idx = patch_corresponding_slide_idx.unique()
445 | for bag_idx_i in available_bag_idx:
446 | idx_same_bag_i = torch.where(patch_corresponding_slide_idx == bag_idx_i)
447 | if bag_label_gt[idx_same_bag_i].max() != bag_label_gt[idx_same_bag_i].min():
448 | raise ValueError('Inconsistent bag labels within bag {}'.format(bag_idx_i))
449 | bag_label_gt_coarse.append(bag_label_gt[idx_same_bag_i].max())
450 | bag_label_prediction.append(patch_label_pred[idx_same_bag_i].max())
451 | bag_label_gt_coarse = torch.tensor(bag_label_gt_coarse)
452 | bag_label_prediction = torch.tensor(bag_label_prediction)
453 | bag_auc_ByStudent = utliz.cal_auc(bag_label_gt_coarse.reshape(-1), bag_label_prediction.reshape(-1))
454 | self.writer.add_scalar('test_bag_AUC_byStudent', bag_auc_ByStudent, epoch)
455 | return 0
456 |
457 |
458 | def str2bool(v):
459 | """
460 | Input:
461 | v - string
462 | output:
463 | True/False
464 | """
465 | if isinstance(v, bool):
466 | return v
467 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
468 | return True
469 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
470 | return False
471 | else:
472 | raise argparse.ArgumentTypeError('Boolean value expected.')
473 |
474 |
475 | def get_parser():
476 | parser = argparse.ArgumentParser(description='PyTorch Implementation of WENO')
477 | # optimizer
478 | parser.add_argument('--epochs', default=1500, type=int, help='number of epochs')
479 | parser.add_argument('--batch_size', default=512, type=int, help='batch size (default: 512)')
480 | parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate (default: 0.001)')
481 | parser.add_argument('--lrdrop', default=1500, type=int, help='multiply LR by 0.5 every this many epochs (default: 1500)')
482 | parser.add_argument('--wd', default=-5, type=float, help='weight decay power (default: -5)')
483 | parser.add_argument('--dtype', default='f64', choices=['f64', 'f32'], type=str, help='SK-algo dtype (default: f64)')
484 |
485 | # SK algo
486 | parser.add_argument('--nopts', default=100, type=int, help='number of pseudo-opts (default: 100)')
487 | parser.add_argument('--augs', default=3, type=int, help='augmentation level (default: 3)')
488 | parser.add_argument('--lamb', default=25, type=int, help='for pseudoopt: lambda (default:25) ')
489 |
490 | # architecture
491 | # parser.add_argument('--arch', default='alexnet_MNIST', type=str, help='alexnet or resnet (default: alexnet)')
492 |
493 | # housekeeping
494 | parser.add_argument('--device', default='0', type=str, help='GPU devices to use for storage and model')
495 | parser.add_argument('--modeldevice', default='0', type=str, help='GPU numbers on which the CNN runs')
496 | parser.add_argument('--exp', default='self-label-default', help='path to experiment directory')
497 | parser.add_argument('--workers', default=0, type=int, help='number of data loading workers (default: 0)')
498 | parser.add_argument('--comment', default='DEBUG_BagDistillation_DSMIL', type=str, help='name for tensorboardX')
499 | parser.add_argument('--log-intv', default=1, type=int, help='save stuff every x epochs (default: 1)')
500 | parser.add_argument('--log_iter', default=200, type=int, help='log every x-th batch (default: 200)')
501 | parser.add_argument('--seed', default=10, type=int, help='random seed')
502 |
503 | parser.add_argument('--PLPostProcessMethod', default='NegGuide', type=str,
504 | help='post-processing method turning attention scores into pseudo labels',
505 | choices=['NegGuide', 'NegGuide_TopK', 'NegGuide_Similarity'])
506 | parser.add_argument('--StuFilterType', default='FilterNegInstance__ThreProb50', type=str,
507 | help='strategy for using student predictions to improve the teacher '
508 | '[ReplaceAS, FilterNegInstance_Top95, FilterNegInstance__ThreProb90]')
509 | parser.add_argument('--smoothE', default=9999, type=int, help='epoch after which the student filter is applied')
510 | parser.add_argument('--stu_loss_weight_neg', default=0.1, type=float, help='weight of negative instances in student training')
511 | parser.add_argument('--stuOptPeriod', default=1, type=int, help='period (in epochs) of student optimization')
512 | return parser.parse_args()
513 |
514 |
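    | # Example launch (hypothetical flag values; all flags are defined in get_parser above):
    | #   python train_CAMELYONFeat_BagDistillationDSMIL_SharedEnc_Similarity_StuFilterSmoothed_DropPos.py \
    | #       --epochs 1500 --batch_size 512 --lr 0.001 --StuFilterType FilterNegInstance__ThreProb50 --smoothE 100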
515 | if __name__ == "__main__":
516 | args = get_parser()
517 |
518 | # torch.manual_seed(args.seed)
519 | # random.seed(args.seed)
520 | # np.random.seed(args.seed)
521 |
522 | name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")+"_%s" % args.comment.replace('/', '_') + \
523 | "_Seed{}_Bs{}_lr{}_PLPostProcessBy{}_StuFilterType{}_smoothE{}_weightN{}_StuOptP{}".format(
524 | args.seed, args.batch_size, args.lr,
525 | args.PLPostProcessMethod, args.StuFilterType, args.smoothE, args.stu_loss_weight_neg, args.stuOptPeriod)
526 | try:
527 | args.device = [int(item) for item in args.device.split(',')]
528 | except AttributeError:
529 | args.device = [int(args.device)]
530 | args.modeldevice = args.device
531 | util.setup_runtime(seed=42, cuda_dev_id=list(np.unique(args.modeldevice + args.device)))
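    | # NOTE: setup_runtime is called with a hard-coded seed=42; within this script,
    | # the --seed argument only affects the run name.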
532 |
533 | print(name, flush=True)
534 |
535 | writer = SummaryWriter('./runs_CAMELYON/%s'%name)
536 | writer.add_text('args', " \n".join(['%s %s' % (arg, getattr(args, arg)) for arg in vars(args)]))
537 |
538 | # Setup model
539 | model_encoder = camelyon_feat_projecter(input_dim=512, output_dim=512).to('cuda:0')
540 | model_teacherHead = teacher_DSMIL_head(input_feat_dim=512).to('cuda:0')
541 | model_studentHead = student_head(input_feat_dim=512).to('cuda:0')
542 |
543 | optimizer_encoder = torch.optim.SGD(model_encoder.parameters(), lr=args.lr)
544 | optimizer_teacherHead = torch.optim.SGD(model_teacherHead.parameters(), lr=args.lr)
545 | optimizer_studentHead = torch.optim.SGD(model_studentHead.parameters(), lr=args.lr)
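    | # Three separate SGD optimizers; the shared encoder receives gradients from
    | # both the teacher (bag) loss and the student (instance) loss.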
546 |
547 | # Setup loaders
548 | train_ds_return_instance = CAMELYON_16_feat(train=True, transform=None, downsample=1.0, drop_threshold=0, preload=True, return_bag=False)
549 |
550 | train_ds_return_bag = copy.deepcopy(train_ds_return_instance)
551 | train_ds_return_bag.return_bag = True
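    | # The bag loader is a deep copy of the instance dataset with return_bag flipped,
    | # so teacher and student train on the exact same pre-extracted features.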
552 | val_ds_return_instance = CAMELYON_16_feat(train=False, transform=None, downsample=1.0, drop_threshold=0, preload=True, return_bag=False)
553 | val_ds_return_bag = CAMELYON_16_feat(train=False, transform=None, downsample=1.0, drop_threshold=0, preload=True, return_bag=True)
554 |
555 | train_loader_instance = torch.utils.data.DataLoader(train_ds_return_instance, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=False)
556 | train_loader_bag = torch.utils.data.DataLoader(train_ds_return_bag, batch_size=1, shuffle=True, num_workers=args.workers, drop_last=False)
557 | val_loader_instance = torch.utils.data.DataLoader(val_ds_return_instance, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, drop_last=False)
558 | val_loader_bag = torch.utils.data.DataLoader(val_ds_return_bag, batch_size=1, shuffle=False, num_workers=args.workers, drop_last=False)
559 |
560 | print("[Data] {} training samples".format(len(train_loader_instance.dataset)))
561 | print("[Data] {} evaluating samples".format(len(val_loader_instance.dataset)))
562 |
563 | if torch.cuda.device_count() > 1:
564 | print("Let's use", len(args.modeldevice), "GPUs for the model")
565 | if len(args.modeldevice) == 1:
566 | print('single GPU model', flush=True)
567 | else:
568 | model_encoder = nn.DataParallel(model_encoder, device_ids=list(range(len(args.modeldevice))))
569 | model_teacherHead = nn.DataParallel(model_teacherHead, device_ids=list(range(len(args.modeldevice))))
570 | model_studentHead = nn.DataParallel(model_studentHead, device_ids=list(range(len(args.modeldevice))))
571 |
572 | # Setup optimizer
573 | o = Optimizer(model_encoder=model_encoder, model_teacherHead=model_teacherHead, model_studentHead=model_studentHead,
574 | optimizer_encoder=optimizer_encoder, optimizer_teacherHead=optimizer_teacherHead, optimizer_studentHead=optimizer_studentHead,
575 | train_bagloader=train_loader_bag, train_instanceloader=train_loader_instance,
576 | test_bagloader=val_loader_bag, test_instanceloader=val_loader_instance,
577 | writer=writer, num_epoch=args.epochs,
578 | dev=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
579 | PLPostProcessMethod=args.PLPostProcessMethod, StuFilterType=args.StuFilterType, smoothE=args.smoothE,
580 | stu_loss_weight_neg=args.stu_loss_weight_neg, stuOptPeriod=args.stuOptPeriod)
581 | # Optimize
582 | o.optimize()
--------------------------------------------------------------------------------