├── figure2.png ├── figure3.jpg ├── models ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── alexnet.cpython-36.pyc │ ├── resnetv1.cpython-36.pyc │ └── resnetv2.cpython-36.pyc ├── __init__.py ├── resnetv2.py └── alexnet.py ├── README.md ├── Datasets_loader ├── dataset_CAMELYON16_BasedOnFeat.py ├── dataset_CAMELYON16.py ├── dataset_TCGA_LungCancer.py ├── dataset_MIL_CIFAR.py ├── dataset_CervicalCancer.py └── dataset_MIL_NCTCRCHE.py ├── util.py ├── utliz.py ├── train_TCGAFeat_BagDistillationDSMIL_SharedEnc_Similarity_StuFilterSmoothed_DropPos.py ├── train_CervicalFeat_BagDistillationDSMIL_SharedEnc_Similarity_StuFilterSmoothed_DropPos.py └── train_CAMELYONFeat_BagDistillationDSMIL_SharedEnc_Similarity_StuFilterSmoothed_DropPos.py /figure2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/miccaiif/WENO/HEAD/figure2.png -------------------------------------------------------------------------------- /figure3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/miccaiif/WENO/HEAD/figure3.jpg -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/miccaiif/WENO/HEAD/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/alexnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/miccaiif/WENO/HEAD/models/__pycache__/alexnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnetv1.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/miccaiif/WENO/HEAD/models/__pycache__/resnetv1.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnetv2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/miccaiif/WENO/HEAD/models/__pycache__/resnetv2.cpython-36.pyc -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .alexnet import * 8 | from .resnetv2 import * 9 | from .resnetv1 import * 10 | 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # :camel: WENO 2 | Official PyTorch implementation of our NeurIPS 2022 paper: **[Bi-directional Weakly Supervised Knowledge Distillation for Whole Slide Image Classification](https://arxiv.org/abs/2210.03664)**. We propose an end-to-end weakly supervised knowledge distillation framework (**WENO**) for WSI classification, which integrates a bag classifier and an instance classifier in a knowledge distillation framework to mutually improve the performance of both classifiers. WENO is a plug-and-play framework that can be easily applied to any existing attention-based bag classification methods. 3 | 4 |

5 | 6 |

7 | 8 | ### Frequently Asked Questions 9 | 10 | * Regarding preprocessing 11 | 12 | Because preprocessing settings (such as patch size and magnification scale) vary across MIL experiments, patching should be performed according to your own experimental setup. The [DSMIL](https://github.com/binli123/dsmil-wsi) paper provides a good reference example (and is also cited in our paper). Because uploading all of the extracted feature files would require a great deal of time and storage, we have open-sourced the core code instead. The training details in the paper, together with the released code, are sufficient to reproduce this work. Thank you again for your interest, and you are welcome to contact us and cite our work! 13 | 14 | ### Citation 15 | If this work is helpful to you, please cite it as: 16 | ``` 17 | @article{qu2022bi, 18 | title={Bi-directional weakly supervised knowledge distillation for whole slide image classification}, 19 | author={Qu, Linhao and Wang, Manning and Song, Zhijian and others}, 20 | journal={Advances in Neural Information Processing Systems}, 21 | volume={35}, 22 | pages={15368--15381}, 23 | year={2022} 24 | } 25 | ``` 26 | 27 | ### Contact Information 28 | If you have any questions, please email me at [lhqu20@fudan.edu.cn](mailto:lhqu20@fudan.edu.cn). 29 | -------------------------------------------------------------------------------- /models/resnetv2.py: -------------------------------------------------------------------------------- 1 | """ Pre-activation ResNet in PyTorch. 2 | also called ResNet v2. 3 | 4 | adapted from https://github.com/kuangliu/pytorch-cifar/edit/master/models/preact_resnet.py 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Identity Mappings in Deep Residual Networks.
arXiv:1603.05027 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import os 13 | 14 | __all__ = ['resnetv2'] 15 | 16 | class PreActBottleneck(nn.Module): 17 | '''Pre-activation version of the original Bottleneck module.''' 18 | def __init__(self, in_planes, planes, stride=1,expansion=4): 19 | super(PreActBottleneck, self).__init__() 20 | self.expansion = expansion 21 | self.bn1 = nn.BatchNorm2d(in_planes) 22 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 25 | self.bn3 = nn.BatchNorm2d(planes) 26 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 27 | 28 | if stride != 1 or in_planes != self.expansion*planes: 29 | self.shortcut = nn.Sequential( 30 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 31 | ) 32 | 33 | def forward(self, x): 34 | out = F.relu(self.bn1(x)) 35 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 36 | out = self.conv1(out) 37 | out = self.conv2(F.relu(self.bn2(out))) 38 | out = self.conv3(F.relu(self.bn3(out))) 39 | out += shortcut 40 | return out 41 | 42 | 43 | class PreActResNet(nn.Module): 44 | def __init__(self, block, num_blocks, num_classes=10,expansion=4): 45 | super(PreActResNet, self).__init__() 46 | self.in_planes = 16*expansion 47 | 48 | self.features = nn.Sequential(*[ 49 | nn.Conv2d(3, self.in_planes, kernel_size=7, stride=2, padding=3, bias=False), 50 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), 51 | self._make_layer(block, 16*expansion, num_blocks[0], stride=1, expansion=4), 52 | self._make_layer(block, 2*16*expansion, num_blocks[1], stride=2, expansion=4), 53 | self._make_layer(block, 4*16*expansion, num_blocks[2], stride=2, expansion=4), 54 | self._make_layer(block, 8*16*expansion, num_blocks[3], stride=2, expansion=4), 55 | 
nn.AdaptiveAvgPool2d((1, 1)) 56 | ]) 57 | self.headcount = len(num_classes) 58 | if len(num_classes) == 1: 59 | self.top_layer = nn.Sequential(*[nn.Linear(512*expansion, num_classes[0])]) # for later compatib. 60 | else: 61 | for a,i in enumerate(num_classes): 62 | setattr(self, "top_layer%d" % a, nn.Linear(512*expansion, i)) 63 | self.top_layer = None # this way headcount can act as switch. 64 | 65 | def _make_layer(self, block, planes, num_blocks, stride,expansion): 66 | strides = [stride] + [1]*(num_blocks-1) 67 | layers = [] 68 | for stride in strides: 69 | layers.append(block(self.in_planes, planes, stride,expansion)) 70 | self.in_planes = planes * expansion 71 | return nn.Sequential(*layers) 72 | 73 | def forward(self, x): 74 | out = self.features(x) 75 | out = out.view(out.size(0), -1) 76 | if self.headcount == 1: 77 | if self.top_layer: 78 | out = self.top_layer(out) 79 | return out 80 | else: 81 | outp = [] 82 | for i in range(self.headcount): 83 | outp.append(getattr(self, "top_layer%d" % i)(out)) 84 | return outp 85 | 86 | 87 | 88 | def PreActResNet50(num_classes): 89 | return PreActResNet(PreActBottleneck, [3,4,6,3],num_classes) 90 | 91 | def resnetv2(nlayers=50, num_classes=[1000], expansion=1): 92 | if nlayers == 50: 93 | return PreActResNet(PreActBottleneck, [3,4,6,3], num_classes, expansion=4*expansion) 94 | else: 95 | raise NotImplementedError 96 | 97 | 98 | if __name__ == '__main__': 99 | import torch 100 | model = resnetv2(num_classes=[500]*3) 101 | print([ k.shape for k in model(torch.randn(64,3,224,224))]) 102 | -------------------------------------------------------------------------------- /Datasets_loader/dataset_CAMELYON16_BasedOnFeat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.utils.data as data_utils 5 | from torchvision import datasets, transforms 6 | from PIL import Image 7 | import os 8 | import glob 9 | 
from skimage import io 10 | from tqdm import tqdm 11 | 12 | 13 | def statistics_slide(slide_path_list): 14 | num_pos_patch_allPosSlide = 0 15 | num_patch_allPosSlide = 0 16 | num_neg_patch_allNegSlide = 0 17 | num_all_slide = len(slide_path_list) 18 | 19 | for i in slide_path_list: 20 | if 'pos' in i.split('/')[-1]: # pos slide 21 | num_pos_patch = len(glob.glob(i + "/*_pos.jpg")) 22 | num_patch = len(glob.glob(i + "/*.jpg")) 23 | num_pos_patch_allPosSlide = num_pos_patch_allPosSlide + num_pos_patch 24 | num_patch_allPosSlide = num_patch_allPosSlide + num_patch 25 | else: # neg slide 26 | num_neg_patch = len(glob.glob(i + "/*.jpg")) 27 | num_neg_patch_allNegSlide = num_neg_patch_allNegSlide + num_neg_patch 28 | 29 | print("[DATA INFO] {} slides totally".format(num_all_slide)) 30 | print("[DATA INFO] pos_patch_ratio in pos slide: {:.4f}({}/{})".format( 31 | num_pos_patch_allPosSlide / num_patch_allPosSlide, num_pos_patch_allPosSlide, num_patch_allPosSlide)) 32 | print("[DATA INFO] num of patches: {} ({} from pos slide, {} from neg slide)".format( 33 | num_patch_allPosSlide+num_neg_patch_allNegSlide, num_patch_allPosSlide, num_neg_patch_allNegSlide)) 34 | return num_patch_allPosSlide+num_neg_patch_allNegSlide 35 | 36 | 37 | class CAMELYON_16_feat(torch.utils.data.Dataset): 38 | # @profile 39 | def __init__(self, root_dir='', 40 | train=True, transform=None, downsample=1.0, drop_threshold=0.0, preload=True, return_bag=False): 41 | self.root_dir = root_dir 42 | self.train = train 43 | self.transform = transform 44 | self.downsample = downsample 45 | self.drop_threshold = drop_threshold # drop the pos slide of which positive patch ratio less than the threshold 46 | self.preload = preload 47 | self.return_bag = return_bag 48 | if train: 49 | self.root_dir = os.path.join(self.root_dir, "training") 50 | else: 51 | self.root_dir = os.path.join(self.root_dir, "testing") 52 | 53 | all_slides = glob.glob(self.root_dir + "/*") 54 | # 1.filter the pos slides which have 0 pos 
patch 55 | all_pos_slides = glob.glob(self.root_dir + "/*_pos") 56 | 57 | for i in all_pos_slides: 58 | num_pos_patch = len(glob.glob(i + "/*_pos.jpg")) 59 | num_patch = len(glob.glob(i + "/*.jpg")) 60 | if num_pos_patch/num_patch <= self.drop_threshold: 61 | all_slides.remove(i) 62 | print("[DATA] {} of positive patch ratio {:.4f}({}/{}) is removed".format( 63 | i, num_pos_patch/num_patch, num_pos_patch, num_patch)) 64 | statistics_slide(all_slides) 65 | # 1.1 down sample the slides 66 | print("================ Down sample ================") 67 | np.random.shuffle(all_slides) 68 | all_slides = all_slides[:int(len(all_slides)*self.downsample)] 69 | self.num_slides = len(all_slides) 70 | self.num_patches = statistics_slide(all_slides) 71 | 72 | # 2. load all pre-trained patch features (by SimCLR in DSMIL) 73 | all_slides_name = [i.split('/')[-1] for i in all_slides] 74 | if train: 75 | all_slides_feat_file = glob.glob("") 76 | else: 77 | all_slides_feat_file = glob.glob("") 78 | 79 | self.slide_feat_all = np.zeros([self.num_patches, 512], dtype=np.float32) 80 | self.slide_patch_label_all = np.zeros([self.num_patches], dtype=np.long) 81 | self.patch_corresponding_slide_label = np.zeros([self.num_patches], dtype=np.long) 82 | self.patch_corresponding_slide_index = np.zeros([self.num_patches], dtype=np.long) 83 | self.patch_corresponding_slide_name = np.zeros([self.num_patches], dtype=' 3: 62 | all_patches_slide_i = sample(all_patches_slide_i, int(len(all_patches_slide_i)*self.patch_downsample)) 63 | for j in all_patches_slide_i: 64 | if self.preload: 65 | self.all_patches[cnt_patch, :, :, :] = io.imread(j) 66 | else: 67 | self.all_patches.append(j) 68 | self.patch_corresponding_slide_label.append(int('LUSC' in i.split('/')[-4])) 69 | self.patch_corresponding_slide_index.append(cnt_slide) 70 | self.patch_corresponding_slide_name.append(i.split('/')[-1]) 71 | cnt_patch = cnt_patch + 1 72 | cnt_slide = cnt_slide + 1 73 | if not self.preload: 74 | self.all_patches = 
np.array(self.all_patches) 75 | self.patch_corresponding_slide_label = np.array(self.patch_corresponding_slide_label) 76 | self.patch_corresponding_slide_index = np.array(self.patch_corresponding_slide_index) 77 | self.patch_corresponding_slide_name = np.array(self.patch_corresponding_slide_name) 78 | 79 | self.num_patches = len(self.all_patches) 80 | # 3.do some statistics 81 | print("[DATA INFO] num_slide is {}; num_patches is {}\n".format(self.num_slides, self.num_patches)) 82 | 83 | def __getitem__(self, index): 84 | if self.preload: 85 | patch_image = self.all_patches[index] 86 | else: 87 | patch_image = io.imread(self.all_patches[index]) 88 | patch_corresponding_slide_label = self.patch_corresponding_slide_label[index] 89 | patch_corresponding_slide_index = self.patch_corresponding_slide_index[index] 90 | patch_corresponding_slide_name = self.patch_corresponding_slide_name[index] 91 | 92 | patch_image = self.transform(Image.fromarray(np.uint8(patch_image), 'RGB')) 93 | patch_label = 0 # patch_label is not available in TCGA 94 | return patch_image, [patch_label, patch_corresponding_slide_label, patch_corresponding_slide_index, 95 | patch_corresponding_slide_name], index 96 | 97 | def __len__(self): 98 | return self.num_patches 99 | 100 | 101 | class TCGA_LungCancer_Feat(torch.utils.data.Dataset): 102 | # @profile 103 | def __init__(self, train=True, downsample=1.0, return_bag=False): 104 | self.train = train 105 | self.return_bag = return_bag 106 | bags_csv = '' 107 | bags_path = pd.read_csv(bags_csv) 108 | train_path = bags_path.iloc[0:int(len(bags_path) * 0.8), :] 109 | test_path = bags_path.iloc[int(len(bags_path) * 0.8):, :] 110 | train_path = shuffle(train_path).reset_index(drop=True) 111 | test_path = shuffle(test_path).reset_index(drop=True) 112 | 113 | if downsample < 1.0: 114 | train_path = train_path.iloc[0:int(len(train_path) * downsample), :] 115 | test_path = test_path.iloc[0:int(len(test_path) * downsample), :] 116 | 117 | self.patch_feat_all = 
[] 118 | self.patch_corresponding_slide_label = [] 119 | self.patch_corresponding_slide_index = [] 120 | self.patch_corresponding_slide_name = [] 121 | if self.train: 122 | for i in tqdm(range(len(train_path)), desc='loading data'): 123 | label, feats = get_bag_feats(train_path.iloc[i]) 124 | self.patch_feat_all.append(feats) 125 | self.patch_corresponding_slide_label.append(np.ones(feats.shape[0]) * label) 126 | self.patch_corresponding_slide_index.append(np.ones(feats.shape[0]) * i) 127 | self.patch_corresponding_slide_name.append(np.ones(feats.shape[0]) * i) 128 | else: 129 | for i in tqdm(range(len(test_path)), desc='loading data'): 130 | label, feats = get_bag_feats(test_path.iloc[i]) 131 | self.patch_feat_all.append(feats) 132 | self.patch_corresponding_slide_label.append(np.ones(feats.shape[0]) * label) 133 | self.patch_corresponding_slide_index.append(np.ones(feats.shape[0]) * i) 134 | self.patch_corresponding_slide_name.append(np.ones(feats.shape[0]) * i) 135 | 136 | self.patch_feat_all = np.concatenate(self.patch_feat_all, axis=0).astype(np.float32) 137 | self.patch_corresponding_slide_label = np.concatenate(self.patch_corresponding_slide_label).astype(np.long) 138 | self.patch_corresponding_slide_index =np.concatenate(self.patch_corresponding_slide_index).astype(np.long) 139 | self.patch_corresponding_slide_name = np.concatenate(self.patch_corresponding_slide_name) 140 | 141 | self.num_patches = self.patch_feat_all.shape[0] 142 | self.patch_label_all = np.zeros([self.patch_feat_all.shape[0]], dtype=np.long) # Patch label is not available and set to 0 ! 
143 | # 3.do some statistics 144 | print("[DATA INFO] num_slide is {}; num_patches is {}\n".format(len(train_path), self.num_patches)) 145 | 146 | def __getitem__(self, index): 147 | if self.return_bag: 148 | idx_patch_from_slide_i = np.where(self.patch_corresponding_slide_index == index)[0] 149 | bag = self.patch_feat_all[idx_patch_from_slide_i, :] 150 | patch_labels = self.patch_label_all[idx_patch_from_slide_i] # Patch label is not available and set to 0 ! 151 | slide_label = self.patch_corresponding_slide_label[idx_patch_from_slide_i][0] 152 | slide_index = self.patch_corresponding_slide_index[idx_patch_from_slide_i][0] 153 | slide_name = self.patch_corresponding_slide_name[idx_patch_from_slide_i][0] 154 | 155 | # check data 156 | if self.patch_corresponding_slide_label[idx_patch_from_slide_i].max() != self.patch_corresponding_slide_label[idx_patch_from_slide_i].min(): 157 | raise 158 | if self.patch_corresponding_slide_index[idx_patch_from_slide_i].max() != self.patch_corresponding_slide_index[idx_patch_from_slide_i].min(): 159 | raise 160 | return bag, [patch_labels, slide_label, slide_index, slide_name], index 161 | else: 162 | patch_image = self.patch_feat_all[index] 163 | patch_corresponding_slide_label = self.patch_corresponding_slide_label[index] 164 | patch_corresponding_slide_index = self.patch_corresponding_slide_index[index] 165 | patch_corresponding_slide_name = self.patch_corresponding_slide_name[index] 166 | 167 | patch_label = self.patch_label_all[index] # Patch label is not available and set to 0 ! 
168 | return patch_image, [patch_label, patch_corresponding_slide_label, patch_corresponding_slide_index, 169 | patch_corresponding_slide_name], index 170 | 171 | def __len__(self): 172 | if self.return_bag: 173 | return self.patch_corresponding_slide_index.max() + 1 174 | else: 175 | return self.num_patches 176 | 177 | 178 | def get_bag_feats(csv_file_df): 179 | feats_csv_path = '' + csv_file_df.iloc[0].split('/')[1] + '.csv' 180 | df = pd.read_csv(feats_csv_path) 181 | feats = shuffle(df).reset_index(drop=True) 182 | feats = feats.to_numpy() 183 | label = np.zeros(1) 184 | label[0] = csv_file_df.iloc[1] 185 | return label, feats 186 | 187 | 188 | if __name__ == '__main__': 189 | train_ds_feat = TCGA_LungCancer_Feat(train=True, downsample=1.0) 190 | test_ds_feat = TCGA_LungCancer_Feat(train=False, downsample=1.0) 191 | 192 | trans = transforms.Compose([ 193 | transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2), 194 | transforms.RandomHorizontalFlip(p=0.5), 195 | transforms.RandomVerticalFlip(p=0.5), 196 | transforms.RandomRotation(degrees=90), 197 | transforms.ToTensor()]) 198 | train_ds = TCGA_LungCancer(train=True, transform=None, downsample=0.1, drop_threshold=0, preload=False) 199 | val_ds = TCGA_LungCancer(train=False, transform=None, downsample=0.1, drop_threshold=0, preload=False) 200 | train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64, 201 | shuffle=True, num_workers=0, drop_last=False, pin_memory=True) 202 | val_loader = torch.utils.data.DataLoader(val_ds, batch_size=64, 203 | shuffle=False, num_workers=0, drop_last=False, pin_memory=True) 204 | for data in tqdm(train_loader, desc='loading'): 205 | patch_img = data[0] 206 | label_patch = data[1][0] 207 | label_bag = data[1][1] 208 | idx = data[-1] 209 | print("END") 210 | -------------------------------------------------------------------------------- /Datasets_loader/dataset_MIL_CIFAR.py: 
# --------------------------------------------------------------------------------
"""PyTorch Dataset that builds synthetic MIL "whole slides" from CIFAR-10.

Each slide is a fixed-size group of CIFAR-10 images. Images whose class is in
``positive_num`` are positive instances and all others are negative. Positive
slides mix positive and negative images; negative slides contain only
negatives. Depending on ``return_bag``, the dataset yields single patches or
bags of ``bag_length`` patches.
"""

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms
from PIL import Image

import numpy as np
from six.moves import cPickle as pickle
import os
import platform
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

img_rows, img_cols = 32, 32
input_shape = (img_rows, img_cols, 3)


def load_pickle(f):
    """Unpickle a CIFAR batch from the open binary file *f*.

    Under Python 3, ``encoding='latin1'`` is needed to read pickles written
    by Python 2. NOTE: unpickling can execute code, so only load trusted files.
    """
    version = platform.python_version_tuple()
    if version[0] == '2':
        return pickle.load(f)
    elif version[0] == '3':
        return pickle.load(f, encoding='latin1')
    raise ValueError("invalid python version: {}".format(version))


def load_CIFAR_batch(filename):
    """Load a single CIFAR-10 batch file.

    Returns:
        (X, Y): X is reshaped to (N, 3, 32, 32), and Y is a numpy array of the
        N labels. N is read from the file instead of being hard-coded to 10000.
    """
    with open(filename, 'rb') as f:
        datadict = load_pickle(f)
    X = datadict['data'].reshape(-1, 3, 32, 32)
    Y = np.array(datadict['labels'])
    return X, Y


def load_CIFAR10(ROOT):
    """Load ``data_batch_1..5`` and ``test_batch`` from directory *ROOT*.

    Returns:
        (Xtr, Ytr, Xte, Yte): the five training batches are concatenated along
        the first axis.
    """
    train_batches = [load_CIFAR_batch(os.path.join(ROOT, 'data_batch_%d' % (b, ))) for b in range(1, 6)]
    Xtr = np.concatenate([X for X, _ in train_batches], axis=0)
    Ytr = np.concatenate([Y for _, Y in train_batches])
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte


def get_CIFAR10_data(cifar10_dir=''):
    """Load the raw CIFAR-10 data and convert it to torch tensors.

    Args:
        cifar10_dir: directory that holds the CIFAR-10 batch files. The default
            ``''`` keeps the previous hard-coded value and resolves relative to
            the working directory.

    Returns:
        (x_train, y_train, x_test, y_test) as tensors.
    """
    x_train, y_train, x_test, y_test = load_CIFAR10(cifar10_dir)
    return (torch.from_numpy(x_train), torch.from_numpy(y_train),
            torch.from_numpy(x_test), torch.from_numpy(y_test))


def random_shuffle(input_tensor):
    """Return *input_tensor* with its first dimension randomly permuted."""
    length = input_tensor.shape[0]
    random_idx = torch.randperm(length)
    output_tensor = input_tensor[random_idx]
    return output_tensor


class CIFAR_WholeSlide_challenge(torch.utils.data.Dataset):
    """Synthetic MIL dataset whose "slides" are groups of CIFAR-10 images.

    Args:
        train: use the CIFAR-10 training split if True, otherwise the test split.
        positive_num: class indices treated as positive instances.
        negative_num: class indices treated as negative instances.
        bag_length: number of patches per bag when ``return_bag`` is True.
        return_bag: yield bags (True) or single patches (False).
        num_img_per_slide: number of images per slide.
        pos_patch_ratio: fraction of positive images inside a positive slide.
        pos_slide_ratio: target fraction of positive slides.
        transform: stored but not applied by ``__getitem__``.
        accompanyPos: accepted for API compatibility; unused.
    """
    def __init__(self, train, positive_num=(8, 9), negative_num=(0, 1, 2, 3, 4, 5, 6, 7),
                 bag_length=10, return_bag=False, num_img_per_slide=600, pos_patch_ratio=0.1,
                 pos_slide_ratio=0.5, transform=None, accompanyPos=True):
        self.train = train
        # Copy into fresh lists so instances never share (or mutate) the defaults.
        self.positive_num = list(positive_num)  # transform the N-class into 2-class
        self.negative_num = list(negative_num)  # transform the N-class into 2-class
        self.bag_length = bag_length
        self.return_bag = return_bag  # return patch or bag
        self.transform = transform
        self.num_img_per_slide = num_img_per_slide

        if train:
            self.ds_data, self.ds_label, _, _ = get_CIFAR10_data()
            feat_rows = slice(None, 50000)  # rows [:50000] belong to the training split
        else:
            _, _, self.ds_data, self.ds_label = get_CIFAR10_data()
            feat_rows = slice(50000, None)  # rows [50000:] belong to the test split
        # Optional pre-extracted SimCLR features. This is best effort: the
        # attribute is only set when the file loads. Catch Exception rather
        # than using a bare except so Ctrl-C still interrupts.
        try:
            self.ds_data_simCLR_feat = torch.from_numpy(
                np.load("./Datasets_loader/all_feats_CIFAR.npy")[feat_rows, :]).float()
            print("Pre-trained feat found")
        except Exception:
            print("No pre-trained feat found")

        self.build_whole_slides(num_img=num_img_per_slide, positive_nums=self.positive_num,
                                negative_nums=self.negative_num, pos_patch_ratio=pos_patch_ratio,
                                pos_slide_ratio=pos_slide_ratio)
        print("")

    def build_whole_slides(self, num_img, positive_nums, negative_nums, pos_patch_ratio=0.1, pos_slide_ratio=0.5):
        """Group image indices into positive and negative slides.

        Each positive slide holds ``int(num_img * pos_patch_ratio)`` positive
        images, with negatives filling the rest. Each negative slide holds
        ``num_img`` negatives. The slide counts are chosen so that, as far as
        the available images allow, a fraction ``pos_slide_ratio`` of slides
        is positive.

        Sets:
            self.idx_all_slides: (num_slides, num_img) image indices, positive slides first.
            self.label_all_slides: (num_slides, num_img) slide label repeated per image.

        Raises:
            ValueError: if a slide would get no positive or no negative images,
                or if ``pos_slide_ratio`` is not positive. These cases used to
                fail with ZeroDivisionError.
        """
        num_pos_per_slide = int(num_img * pos_patch_ratio)
        num_neg_per_slide = num_img - num_pos_per_slide
        if num_pos_per_slide < 1 or num_neg_per_slide < 1 or pos_slide_ratio <= 0:
            raise ValueError(
                "invalid slide composition: num_img={}, pos_patch_ratio={}, pos_slide_ratio={}".format(
                    num_img, pos_patch_ratio, pos_slide_ratio))

        idx_pos = torch.cat([torch.where(self.ds_label == num)[0] for num in positive_nums]).unsqueeze(1)
        idx_neg = torch.cat([torch.where(self.ds_label == num)[0] for num in negative_nums]).unsqueeze(1)

        idx_pos = random_shuffle(idx_pos)
        idx_neg = random_shuffle(idx_neg)

        # Use P positives. The negatives needed are
        #   P * (1 - r_p) / r_p            (filling the positive slides)
        # + P * (1 - r_s) / (r_p * r_s)    (the matching negative slides).
        # Take the largest P the negative pool supports, capped by the positives available.
        num_pos_2PosSlides = int(idx_neg.numel() // ((1 - pos_slide_ratio) / (pos_patch_ratio * pos_slide_ratio)
                                                     + (1 - pos_patch_ratio) / pos_patch_ratio))
        num_pos_2PosSlides = min(num_pos_2PosSlides, idx_pos.shape[0])
        num_pos_2PosSlides = int(num_pos_2PosSlides // num_pos_per_slide * num_pos_per_slide)
        num_neg_2PosSlides = int(num_pos_2PosSlides * ((1 - pos_patch_ratio) / pos_patch_ratio))
        num_neg_2NegSlides = int(num_pos_2PosSlides * ((1 - pos_slide_ratio) / (pos_patch_ratio * pos_slide_ratio)))

        # Round the negative counts down to whole slides.
        num_neg_2PosSlides = int(num_neg_2PosSlides // num_neg_per_slide * num_neg_per_slide)
        num_neg_2NegSlides = int(num_neg_2NegSlides // num_img * num_img)

        # Flooring can leave more positive parts than negative parts for the
        # positive slides. Drop the surplus positives so the parts pair up one-to-one.
        num_diff_slide = num_pos_2PosSlides // num_pos_per_slide - num_neg_2PosSlides // num_neg_per_slide
        num_pos_2PosSlides = num_pos_2PosSlides - num_pos_per_slide * num_diff_slide

        idx_pos = idx_pos[:num_pos_2PosSlides]
        idx_neg = idx_neg[:(num_neg_2PosSlides + num_neg_2NegSlides)]

        idx_pos_toPosSlide = idx_pos.reshape(-1, num_pos_per_slide)
        idx_neg_toPosSlide = idx_neg[:num_neg_2PosSlides].reshape(-1, num_neg_per_slide)
        idx_neg_toNegSlide = idx_neg[num_neg_2PosSlides:].reshape(-1, num_img)

        idx_pos_slides = torch.cat([idx_pos_toPosSlide, idx_neg_toPosSlide], dim=1)
        # Mix positive and negative images within each positive slide.
        for row in range(idx_pos_slides.shape[0]):
            idx_pos_slides[row, :] = idx_pos_slides[row, torch.randperm(idx_pos_slides.shape[1])]
        idx_neg_slides = idx_neg_toNegSlide

        self.idx_all_slides = torch.cat([idx_pos_slides, idx_neg_slides], dim=0)
        self.label_all_slides = torch.cat([torch.ones(idx_pos_slides.shape[0]), torch.zeros(idx_neg_slides.shape[0])], dim=0)
        self.label_all_slides = self.label_all_slides.unsqueeze(1).repeat([1, self.idx_all_slides.shape[1]]).long()
        print("[Info] dataset: {}".format(self.idx_all_slides.shape))

    def __getitem__(self, index):
        """Return ``(data, [patch_label(s), slide_label, slide_index, slide_name], index)``.

        In bag mode, *data* is ``bag_length`` consecutive images of one slide
        together with per-image 0/1 labels. In patch mode it is one image with
        a 0/1 label. Image values are divided by 255.
        """
        if self.return_bag:
            bag_per_slide = self.idx_all_slides.shape[1] // self.bag_length
            idx_slide = index // bag_per_slide
            start = (index % bag_per_slide) * self.bag_length
            idx_images = self.idx_all_slides[idx_slide, start:start + self.bag_length]
            bag = self.ds_data[idx_images]
            patch_labels_raw = self.ds_label[idx_images]
            patch_labels = torch.zeros_like(patch_labels_raw)
            for num in self.positive_num:
                patch_labels[patch_labels_raw == num] = 1
            patch_labels = patch_labels.long()
            slide_label = self.label_all_slides[idx_slide, 0]
            slide_name = str(idx_slide)
            return bag.float() / 255, [patch_labels, slide_label, idx_slide, slide_name], index

        idx_image = self.idx_all_slides.flatten()[index]
        slide_label = self.label_all_slides.flatten()[index]
        idx_slide = index // self.num_img_per_slide
        slide_name = str(idx_slide)
        patch = self.ds_data[idx_image]
        patch_label = int(self.ds_label[idx_image] in self.positive_num)
        return patch.float() / 255, [patch_label, slide_label, idx_slide, slide_name], index

    def __len__(self):
        """Number of bags (bag mode) or of images (patch mode)."""
        if self.return_bag:
            return self.idx_all_slides.shape[1] // self.bag_length * self.idx_all_slides.shape[0]
        return self.idx_all_slides.numel()

    def visualize(self, idx, number_row=20, number_col=30):
        """Display one slide as an image grid, outlining positive patches in red.

        Args:
            idx: image indices of one slide. ``number_row * number_col`` must
                equal ``len(idx)`` for the final reshape to succeed.
        """
        slide = self.ds_data[idx].clone()  # (len(idx), 3, 32, 32)
        patch_label = self.ds_label[idx].clone()
        idx_pos_patch = torch.cat([torch.where(patch_label == num)[0] for num in self.positive_num])
        # Paint a 2-pixel border with R=255, G=0, B=0 around every positive patch.
        for channel, value in ((0, 255), (1, 0), (2, 0)):
            slide[idx_pos_patch, channel, :2, :] = value
            slide[idx_pos_patch, channel, -2:, :] = value
            slide[idx_pos_patch, channel, :, :2] = value
            slide[idx_pos_patch, channel, :, -2:] = value

        slide = slide.unsqueeze(0).reshape(number_row, number_col, 3, 32, 32).permute(0, 3, 1, 4, 2).reshape(number_row * 32, number_col * 32, 3)
        import utliz
        utliz.show_img(slide)
        return 0


if __name__ == "__main__":
    # Invoke the above function to get our data.
    x_train, y_train, x_test, y_test = get_CIFAR10_data()

    # Sweep the fraction of positive patches inside positive slides; half of
    # the slides are positive in every run.
    for pos_patch_ratio in [0.01, 0.05, 0.1, 0.2, 0.5, 0.7]:
        print("=========== pos patch ratio: {} ===========".format(pos_patch_ratio))
        train_ds = CIFAR_WholeSlide_challenge(train=True, positive_num=[9], negative_num=[0, 1, 2, 3, 4, 5, 6, 7, 8], bag_length=100, return_bag=False, num_img_per_slide=100, pos_patch_ratio=pos_patch_ratio, pos_slide_ratio=0.5, transform=None)
        train_loader = data_utils.DataLoader(train_ds, batch_size=64, shuffle=True, drop_last=False)
        test_ds_part1 = CIFAR_WholeSlide_challenge(train=False, positive_num=[9], negative_num=[0, 1, 2, 3, 4, 5, 6, 7, 8], bag_length=100, return_bag=False, num_img_per_slide=100, pos_patch_ratio=pos_patch_ratio, pos_slide_ratio=0.5, transform=None)
        test_loader_part1 = data_utils.DataLoader(test_ds_part1, batch_size=64, shuffle=True, drop_last=False)
        print("")
    print("")

# --------------------------------------------------------------------------------
# /Datasets_loader/dataset_CervicalCancer.py:
# --------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms
from PIL import Image
import os
import glob
from skimage import io
from tqdm import tqdm


def shuffle_downsample_myself(input_list, downsample):
    """Shuffle *input_list* in place and return a downsampled portion of it.

    Args:
        input_list: sequence accepted by ``np.random.shuffle``, e.g. the list
            returned by ``os.listdir``. NOTE: it is shuffled in place, so the
            caller's object is modified.
        downsample: how much to keep.
            ``1`` keeps everything.
            A value ``< 1`` keeps the first ``int(len(input_list) * downsample)``
            items; ``0`` gives an empty result.
            A value ``> 1`` is truncated to an int and used as an absolute
            count, capped at ``len(input_list)``.

    Returns:
        The shuffled sequence itself when everything is kept, otherwise a slice of it.

    Raises:
        ValueError: if *downsample* is negative or NaN. A NaN previously hit a
            bare ``raise`` with no active exception, which raised RuntimeError.
    """
    # NaN is the only value that compares unequal to itself.
    if downsample != downsample or downsample < 0:
        raise ValueError("downsample must be a non-negative number, got {!r}".format(downsample))
    np.random.shuffle(input_list)
    if downsample == 1:
        return input_list
    if downsample < 1:
        return input_list[:int(len(input_list) * downsample)]
    count = int(downsample)
    return input_list if count > len(input_list) else input_list[:count]


class CervicalCaner_16(torch.utils.data.Dataset): 31 | # @profile 32 | def __init__(self, root_dir="", 33 | train=True, transform=None, downsample=1.0, preload=True, return_bag=False): 34 | self.root_dir = root_dir 35 | self.train = train 36 | self.return_bag = return_bag 37 | self.transform = transform 38 | self.downsample = downsample 39 | self.preload = preload 40 | if self.transform is None: 41 | self.transform = transforms.Compose([transforms.ToTensor()]) 42 | if train: 43 | self.root_dir = os.path.join(self.root_dir, "train_datasets") 44 | else: 45 | self.root_dir = os.path.join(self.root_dir, "test_datasets") 46 | 47 | all_slides = glob.glob(self.root_dir + "/*/*") 48 | 49 | self.num_slides = len(all_slides) 50 | 51 | # 2.extract all available patches and build corresponding labels 52 | if self.preload: 53 | self.num_patches = 0 54 | for i in tqdm(all_slides, ascii=True, desc='scanning data'): 55 | self.num_patches = self.num_patches + len(shuffle_downsample_myself(os.listdir(i), self.downsample)) 56 | self.all_patches = np.zeros([self.num_patches, 224, 224, 3], dtype=np.uint8) 57 | else: 58 | self.all_patches = [] 59 | self.patch_label = [] 60 | self.patch_corresponding_slide_label = [] 61 | self.patch_corresponding_slide_index = [] 62 | self.patch_corresponding_slide_name = [] 63 | cnt_slide = 0 64 | cnt_patch = 0 65 | for i in tqdm(all_slides, ascii=True, desc='preload data'): 66 | patches_from_slide_i = os.listdir(i) 67 | patches_from_slide_i = shuffle_downsample_myself(patches_from_slide_i, self.downsample) 68 | for j in patches_from_slide_i: 69 | if self.preload: 70 | self.all_patches[cnt_patch, :, :, :] = io.imread(os.path.join(i, j)) 71 | else: 72 | self.all_patches.append(os.path.join(i, j)) 73 | self.patch_label.append(0) 74 | self.patch_corresponding_slide_label.append(int('P' == i.split('/')[-2])) 75 | self.patch_corresponding_slide_index.append(cnt_slide) 76 | self.patch_corresponding_slide_name.append(i.split('/')[-1]) 77 | cnt_patch = 
cnt_patch + 1 78 | cnt_slide = cnt_slide + 1 79 | if not self.preload: 80 | self.all_patches = np.array(self.all_patches) 81 | self.num_patches = self.all_patches.shape[0] 82 | self.all_patches = self.all_patches.transpose(0, 3, 1, 2) 83 | self.patch_label = np.array(self.patch_label, dtype=np.long) # [Attention] patch label is unavailable and set to 0 84 | self.patch_corresponding_slide_label = np.array(self.patch_corresponding_slide_label, dtype=np.long) 85 | self.patch_corresponding_slide_index = np.array(self.patch_corresponding_slide_index, dtype=np.long) 86 | self.patch_corresponding_slide_name = np.array(self.patch_corresponding_slide_name) 87 | 88 | # 3.do some statistics 89 | print("[DATA INFO] num_slide is {}; num_patches is {}\npos_patch_ratio is unknown".format( 90 | self.num_slides, self.num_patches)) 91 | print("") 92 | 93 | def __getitem__(self, index): 94 | if self.return_bag: 95 | idx_patch_from_slide_i = np.where(self.patch_corresponding_slide_index == index)[0] 96 | bag = self.all_patches[idx_patch_from_slide_i, :, :, :] 97 | bag = bag.astype(np.float32)/255 98 | patch_labels = self.patch_label[idx_patch_from_slide_i] # Patch label is not available and set to 0 ! 
99 | slide_label = self.patch_corresponding_slide_label[idx_patch_from_slide_i][0] 100 | slide_index = self.patch_corresponding_slide_index[idx_patch_from_slide_i][0] 101 | slide_name = self.patch_corresponding_slide_name[idx_patch_from_slide_i][0] 102 | 103 | # check data 104 | if self.patch_corresponding_slide_label[idx_patch_from_slide_i].max() != self.patch_corresponding_slide_label[idx_patch_from_slide_i].min(): 105 | raise 106 | if self.patch_corresponding_slide_index[idx_patch_from_slide_i].max() != self.patch_corresponding_slide_index[idx_patch_from_slide_i].min(): 107 | raise 108 | return bag, [patch_labels, slide_label, slide_index, slide_name], index 109 | else: 110 | if self.preload: 111 | patch_image = self.all_patches[index] 112 | else: 113 | patch_image = io.imread(self.all_patches[index]) 114 | patch_label = self.patch_label[index] # [Attention] patch label is unavailable and set to 0 115 | patch_corresponding_slide_label = self.patch_corresponding_slide_label[index] 116 | patch_corresponding_slide_index = self.patch_corresponding_slide_index[index] 117 | patch_corresponding_slide_name = self.patch_corresponding_slide_name[index] 118 | 119 | # patch_image = self.transform(Image.fromarray(np.uint8(patch_image), 'RGB')) 120 | patch_image = patch_image.astype(np.float32)/255 121 | return patch_image, [patch_label, patch_corresponding_slide_label, patch_corresponding_slide_index, 122 | patch_corresponding_slide_name], index 123 | 124 | def __len__(self): 125 | if self.return_bag: 126 | return self.patch_corresponding_slide_index.max() + 1 127 | else: 128 | return self.num_patches 129 | 130 | 131 | class CervicalCaner_16_feat(torch.utils.data.Dataset): 132 | def __init__(self, root_dir="", 133 | train=True, return_bag=True): 134 | self.root_dir = root_dir 135 | self.train = train 136 | self.return_bag = return_bag 137 | 138 | # 1. 
load all featreus and slide label and index 139 | save_path = "" 140 | if train: 141 | self.all_patches = np.load(os.path.join(save_path, "train_feats.npy")) 142 | self.patch_corresponding_slide_label = np.load(os.path.join(save_path, "train_corresponding_slide_label.npy")) 143 | self.patch_corresponding_slide_index = np.load(os.path.join(save_path, "train_corresponding_slide_index.npy")) 144 | self.patch_label = np.zeros_like(self.patch_corresponding_slide_label) # [Attention] patch label is unavailable and set to 0 145 | self.patch_corresponding_slide_name = self.patch_corresponding_slide_index 146 | else: 147 | self.all_patches = np.load(os.path.join(save_path, "test_feats.npy")) 148 | self.patch_corresponding_slide_label = np.load(os.path.join(save_path, "test_corresponding_slide_label.npy")) 149 | self.patch_corresponding_slide_index = np.load(os.path.join(save_path, "test_corresponding_slide_index.npy")) 150 | self.patch_label = np.zeros_like(self.patch_corresponding_slide_label) # [Attention] patch label is unavailable and set to 0 151 | self.patch_corresponding_slide_name = self.patch_corresponding_slide_index 152 | 153 | self.num_patches = self.all_patches.shape[0] 154 | self.num_slides = self.patch_corresponding_slide_index.max() + 1 155 | print("[DATA INFO] num_slide is {}; num_patches is {}\npos_patch_ratio is unknown".format( 156 | self.num_slides, self.num_patches)) 157 | 158 | # 2. 
sort instances features into bag 159 | self.slide_feat_all = [] 160 | self.slide_label_all = [] 161 | self.slide_patch_label_all = [] 162 | for i in range(self.num_slides): 163 | idx_from_same_slide = self.patch_corresponding_slide_index == i 164 | idx_from_same_slide = np.nonzero(idx_from_same_slide)[0] 165 | 166 | self.slide_feat_all.append(self.all_patches[idx_from_same_slide]) 167 | if self.patch_corresponding_slide_label[idx_from_same_slide].max() != self.patch_corresponding_slide_label[idx_from_same_slide].min(): 168 | raise 169 | self.slide_label_all.append(self.patch_corresponding_slide_label[idx_from_same_slide].max()) 170 | self.slide_patch_label_all.append(np.zeros(idx_from_same_slide.shape[0]).astype(np.long)) 171 | print("") 172 | 173 | def __getitem__(self, index): 174 | if self.return_bag: 175 | slide_feat = self.slide_feat_all[index] 176 | slide_label = self.slide_label_all[index] 177 | slide_patch_label = self.slide_patch_label_all[index] 178 | return slide_feat, [slide_patch_label, slide_label], index 179 | else: 180 | patch_image_feat = self.all_patches[index] 181 | patch_label = self.patch_label[index] # [Attention] patch label is unavailable and set to 0 182 | patch_corresponding_slide_label = self.patch_corresponding_slide_label[index] 183 | patch_corresponding_slide_index = self.patch_corresponding_slide_index[index] 184 | patch_corresponding_slide_name = self.patch_corresponding_slide_name[index] 185 | 186 | return patch_image_feat, [patch_label, patch_corresponding_slide_label, patch_corresponding_slide_index, 187 | patch_corresponding_slide_name], index 188 | 189 | def __len__(self): 190 | if self.return_bag: 191 | return self.num_slides 192 | else: 193 | return self.num_patches 194 | 195 | 196 | if __name__ == '__main__': 197 | train_ds_return_bag = CervicalCaner_16_feat(train=True, return_bag=True) 198 | train_ds_return_instance = CervicalCaner_16_feat(train=True, return_bag=False) 199 | 200 | train_ds = 
CervicalCaner_16(root_dir=root_dir, train=True, transform=None, downsample=10, preload=True) 201 | val_ds = CervicalCaner_16(root_dir=root_dir, train=False, transform=None, downsample=10, preload=True) 202 | train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64, 203 | shuffle=True, num_workers=0, drop_last=False, pin_memory=True) 204 | val_loader = torch.utils.data.DataLoader(val_ds, batch_size=64, 205 | shuffle=False, num_workers=0, drop_last=False, pin_memory=True) 206 | for data in train_loader: 207 | patch_img = data[0] 208 | label_patch = data[1][0] 209 | label_bag = data[1][1] 210 | idx = data[-1] 211 | print("END") 212 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import math 4 | import numpy as np 5 | from scipy.special import logsumexp 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.init as init 10 | 11 | from torch.nn import ModuleList 12 | from torchvision.utils import make_grid 13 | 14 | class AverageMeter(object): 15 | """Computes and stores the average and current value""" 16 | def __init__(self): 17 | self.reset() 18 | 19 | def reset(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | 25 | def update(self, val, n=1): 26 | self.val = val 27 | self.sum += val * n 28 | self.count += n 29 | self.avg = self.sum / self.count 30 | 31 | 32 | def setup_runtime(seed=0, cuda_dev_id=[0]): 33 | """Initialize CUDA, CuDNN and the random seeds.""" 34 | # Setup CUDA 35 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 36 | if len(cuda_dev_id) == 1: 37 | os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_dev_id[0]) 38 | else: 39 | os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_dev_id[0]) 40 | for i in cuda_dev_id[1:]: 41 | os.environ["CUDA_VISIBLE_DEVICES"] += "," + str(i) 42 | 43 | # global cuda_dev_id 44 | _cuda_device_id = cuda_dev_id 45 | if 
torch.cuda.is_available(): 46 | torch.backends.cudnn.enabled = True 47 | torch.backends.cudnn.benchmark = True 48 | torch.backends.cudnn.deterministic = False 49 | # Fix random seeds 50 | random.seed(seed) 51 | np.random.seed(seed) 52 | torch.manual_seed(seed) 53 | if torch.cuda.is_available(): 54 | torch.cuda.manual_seed_all(seed) 55 | 56 | class TotalAverage(): 57 | def __init__(self): 58 | self.reset() 59 | 60 | def reset(self): 61 | self.val = 0. 62 | self.mass = 0. 63 | self.sum = 0. 64 | self.avg = 0. 65 | 66 | def update(self, val, mass=1): 67 | self.val = val 68 | self.mass += mass 69 | self.sum += val * mass 70 | self.avg = self.sum / self.mass 71 | 72 | 73 | class MovingAverage(): 74 | def __init__(self, intertia=0.9): 75 | self.intertia = intertia 76 | self.reset() 77 | 78 | def reset(self): 79 | self.avg = 0. 80 | 81 | def update(self, val): 82 | self.avg = self.intertia * self.avg + (1 - self.intertia) * val 83 | 84 | 85 | def accuracy(output, target, topk=(1,)): 86 | """Computes the precision@k for the specified values of k.""" 87 | with torch.no_grad(): 88 | maxk = max(topk) 89 | batch_size = target.size(0) 90 | 91 | _, pred = output.topk(maxk, 1, True, True) 92 | pred = pred.t() 93 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 94 | 95 | res = [] 96 | for k in topk: 97 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 98 | res.append(correct_k.mul_(100.0 / batch_size)) 99 | return res 100 | 101 | def write_conv(writer, model, epoch, sobel=False): 102 | if not sobel: 103 | conv1_ = make_grid(list(ModuleList(list(model.children())[0].children())[0].parameters())[0], 104 | nrow=8, normalize=True, scale_each=True) 105 | writer.add_image('conv1', conv1_, epoch) 106 | else: 107 | conv1_sobel_w = list(ModuleList(list(model.children())[0].children())[0].parameters())[0] 108 | conv1_ = make_grid(conv1_sobel_w[:, 0:1, :, :], nrow=8, 109 | normalize=True, scale_each=True) 110 | writer.add_image('conv1_sobel_1', conv1_, epoch) 111 | 
conv2_ = make_grid(conv1_sobel_w[:, 1:2, :, :], nrow=8, 112 | normalize=True, scale_each=True) 113 | writer.add_image('conv1_sobel_2', conv2_, epoch) 114 | conv1_x = make_grid(torch.sum(conv1_sobel_w[:, :, :, :], 1, keepdim=True), nrow=8, 115 | normalize=True, scale_each=True) 116 | writer.add_image('conv1', conv1_x, epoch) 117 | 118 | 119 | ### LP stuff ### 120 | def absorb_bn(module, bn_module): 121 | w = module.weight.data 122 | if module.bias is None: 123 | if isinstance(module, nn.Linear): 124 | zeros = torch.Tensor(module.out_features).zero_().type(w.type()) 125 | else: 126 | zeros = torch.Tensor(module.out_channels).zero_().type(w.type()) 127 | module.bias = nn.Parameter(zeros) 128 | b = module.bias.data 129 | invstd = bn_module.running_var.clone().add_(bn_module.eps).pow_(-0.5) 130 | if isinstance(module, nn.Conv2d): 131 | w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w)) 132 | else: 133 | w.mul_(invstd.unsqueeze(1).expand_as(w)) 134 | b.add_(-bn_module.running_mean).mul_(invstd) 135 | 136 | if bn_module.affine: 137 | if isinstance(module, nn.Conv2d): 138 | w.mul_(bn_module.weight.data.view(w.size(0), 1, 1, 1).expand_as(w)) 139 | else: 140 | w.mul_(bn_module.weight.data.unsqueeze(1).expand_as(w)) 141 | b.mul_(bn_module.weight.data).add_(bn_module.bias.data) 142 | 143 | bn_module.reset_parameters() 144 | bn_module.register_buffer('running_mean', None) 145 | bn_module.register_buffer('running_var', None) 146 | bn_module.affine = False 147 | bn_module.register_parameter('weight', None) 148 | bn_module.register_parameter('bias', None) 149 | 150 | 151 | def is_bn(m): 152 | return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) 153 | 154 | 155 | def is_absorbing(m): 156 | return isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) 157 | 158 | 159 | def search_absorb_bn(model): 160 | prev = None 161 | for m in model.children(): 162 | if is_bn(m) and is_absorbing(prev): 163 | print("absorbing",m) 164 | absorb_bn(prev, m) 165 | search_absorb_bn(m) 
166 | prev = m 167 | 168 | 169 | class View(nn.Module): 170 | """A shape adaptation layer to patch certain networks.""" 171 | def __init__(self): 172 | super(View, self).__init__() 173 | 174 | def forward(self, x): 175 | return x.view(x.shape[0], -1) 176 | 177 | 178 | def sequential_skipping_bn_cut(model): 179 | mods = [] 180 | layers = list(model.features) + [View()] 181 | if 'sobel' in dict(model.named_children()).keys(): 182 | layers = list(model.sobel) + layers 183 | for m in nn.Sequential(*(layers)).children(): 184 | if not is_bn(m): 185 | mods.append(m) 186 | return nn.Sequential(*mods) 187 | 188 | 189 | def py_softmax(x, axis=None): 190 | """stable softmax""" 191 | return np.exp(x - logsumexp(x, axis=axis, keepdims=True)) 192 | 193 | def warmup_batchnorm(model, data_loader, device, batches=100): 194 | """ 195 | Run some batches through all parts of the model to warmup the running 196 | stats for batchnorm layers. 197 | """ 198 | model.train() 199 | for i, q in enumerate(data_loader): 200 | images = q[0] 201 | if i == batches: 202 | break 203 | images = images.to(device) 204 | _ = model(images) 205 | 206 | def init_pytorch_defaults(m, version='041'): 207 | ''' 208 | copied from AMDIM repo: https://github.com/Philip-Bachman/amdim-public/ 209 | note from me: haven't checked systematically if this improves results 210 | ''' 211 | if version == '041': 212 | # print('init.pt041: {0:s}'.format(str(m.weight.data.size()))) 213 | if isinstance(m, nn.Linear): 214 | stdv = 1. / math.sqrt(m.weight.size(1)) 215 | m.weight.data.uniform_(-stdv, stdv) 216 | if m.bias is not None: 217 | m.bias.data.uniform_(-stdv, stdv) 218 | elif isinstance(m, nn.Conv2d): 219 | n = m.in_channels 220 | for k in m.kernel_size: 221 | n *= k 222 | stdv = 1. 
/ math.sqrt(n) 223 | m.weight.data.uniform_(-stdv, stdv) 224 | if m.bias is not None: 225 | m.bias.data.uniform_(-stdv, stdv) 226 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): 227 | if m.affine: 228 | m.weight.data.uniform_() 229 | m.bias.data.zero_() 230 | else: 231 | assert False 232 | elif version == '100': 233 | # print('init.pt100: {0:s}'.format(str(m.weight.data.size()))) 234 | if isinstance(m, nn.Linear): 235 | init.kaiming_uniform_(m.weight, a=math.sqrt(5)) 236 | if m.bias is not None: 237 | fan_in, _ = init._calculate_fan_in_and_fan_out(m.weight) 238 | bound = 1 / math.sqrt(fan_in) 239 | init.uniform_(m.bias, -bound, bound) 240 | elif isinstance(m, nn.Conv2d): 241 | n = m.in_channels 242 | init.kaiming_uniform_(m.weight, a=math.sqrt(5)) 243 | if m.bias is not None: 244 | fan_in, _ = init._calculate_fan_in_and_fan_out(m.weight) 245 | bound = 1 / math.sqrt(fan_in) 246 | init.uniform_(m.bias, -bound, bound) 247 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): 248 | if m.affine: 249 | m.weight.data.uniform_() 250 | m.bias.data.zero_() 251 | else: 252 | assert False 253 | elif version == 'custom': 254 | # print('init.custom: {0:s}'.format(str(m.weight.data.size()))) 255 | if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): 256 | init.normal_(m.weight.data, mean=1, std=0.02) 257 | init.constant_(m.bias.data, 0) 258 | else: 259 | assert False 260 | else: 261 | assert False 262 | 263 | 264 | def weight_init(m): 265 | ''' 266 | Usage: 267 | model = Model() 268 | model.apply(weight_init) 269 | ''' 270 | if isinstance(m, nn.Linear): 271 | init_pytorch_defaults(m, version='041') 272 | elif isinstance(m, nn.Conv2d): 273 | init_pytorch_defaults(m, version='041') 274 | elif isinstance(m, nn.BatchNorm1d): 275 | init_pytorch_defaults(m, version='041') 276 | elif isinstance(m, nn.BatchNorm2d): 277 | init_pytorch_defaults(m, version='041') 278 | elif isinstance(m, nn.Conv1d): 279 | init.normal_(m.weight.data) 280 | if m.bias is not None: 281 | 
init.normal_(m.bias.data) 282 | elif isinstance(m, nn.ConvTranspose1d): 283 | init.normal_(m.weight.data) 284 | if m.bias is not None: 285 | init.normal_(m.bias.data) 286 | elif isinstance(m, nn.ConvTranspose2d): 287 | init.xavier_normal_(m.weight.data) 288 | if m.bias is not None: 289 | init.normal_(m.bias.data) 290 | 291 | 292 | def search_set_bn_eval(model,toeval): 293 | for m in model.children(): 294 | if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 295 | if toeval: 296 | m.eval() 297 | else: 298 | m.train() 299 | search_set_bn_eval(m, toeval) 300 | 301 | def prepmodel(model, modelpath): 302 | dat = torch.load(modelpath, map_location=lambda storage, loc: storage) # ['model'] 303 | from collections import OrderedDict 304 | new_state_dict = OrderedDict() 305 | for k, v in dat.items(): 306 | name = k.replace('module.', '') # remove `module.` 307 | new_state_dict[name] = v 308 | model.load_state_dict(new_state_dict) 309 | del dat 310 | for param in model.features.parameters(): 311 | param.requires_grad = False 312 | 313 | if model.headcount > 1: 314 | for i in range(model.headcount): 315 | setattr(model, "top_layer%d" % i, None) 316 | 317 | model.top_layer = nn.Sequential(nn.Linear(2048, 1000)) 318 | model.headcount = 1 319 | model.withfeature = False 320 | model.return_feature_only = False -------------------------------------------------------------------------------- /Datasets_loader/dataset_MIL_NCTCRCHE.py: -------------------------------------------------------------------------------- 1 | """Pytorch Dataset object that loads perfectly balanced MNIST dataset in bag form.""" 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import torch 6 | import torch.utils.data as data_utils 7 | from torchvision import datasets, transforms 8 | from PIL import Image 9 | from tqdm import tqdm 10 | import numpy as np 11 | from six.moves import cPickle as pickle 12 | import os 13 | from sklearn.model_selection import train_test_split 14 | 
import platform 15 | 16 | 17 | Patho_classes = ['NORM', 'MUC', 'TUM', 'STR', 'LYM', 'BACK', 'MUS', 'DEB', 'ADI'] 18 | 19 | 20 | def load_Pathdata(Root, downsample_ratio=1.0): 21 | X = [] 22 | Y = [] 23 | X_train = [] 24 | X_test = [] 25 | Y_train = [] 26 | Y_test = [] 27 | for index_class, folder in enumerate(Patho_classes): 28 | all_files_class_i = os.listdir(os.path.join(Root, folder)) 29 | all_files_class_i = all_files_class_i[:int(len(all_files_class_i)*downsample_ratio)] 30 | for file in tqdm(all_files_class_i): 31 | image = np.array(Image.open(os.path.join(Root, folder, file)).convert("RGB")).transpose(2, 0, 1) 32 | # image = np.array(cv2.imread(os.path.join(Root, folder, file)),cv2.IMREAD_UNCHANGED) 33 | label = index_class 34 | X.append(image) 35 | Y.append(label) 36 | train_X, test_X, train_Y, test_Y = train_test_split(np.array(X), np.array(Y), test_size=0.2, random_state=42) 37 | X_train.append(train_X) 38 | X_test.append(test_X) 39 | Y_train.append(train_Y) 40 | Y_test.append(test_Y) 41 | X=[] 42 | Y=[] 43 | X_train = [b for a in X_train for b in a] 44 | X_test = [b for a in X_test for b in a] 45 | Y_train = [b for a in Y_train for b in a] 46 | Y_test = [b for a in Y_test for b in a] 47 | 48 | return np.array(X_train), np.array(X_test), np.array(Y_train), np.array(Y_test) 49 | 50 | 51 | def get_Path10_data(Path10_dir='', downsample_ratio=1.0): 52 | # Load the raw Path-10 data 53 | X_train, X_test, Y_train, Y_test = load_Pathdata(Path10_dir, downsample_ratio) 54 | x_train, y_train, X_test, y_test = torch.from_numpy(X_train), torch.from_numpy(Y_train), \ 55 | torch.from_numpy(X_test), torch.from_numpy(Y_test) 56 | return x_train, y_train, X_test, y_test 57 | 58 | 59 | def random_shuffle(input_tensor): 60 | length = input_tensor.shape[0] 61 | random_idx = torch.randperm(length) 62 | output_tensor = input_tensor[random_idx] 63 | return output_tensor 64 | 65 | 66 | class NCT_WholeSlide_challenge(torch.utils.data.Dataset): 67 | def __init__(self, ds_data, 
ds_label, positive_num=[2], negative_num=[0, 1, 3, 4, 5, 6, 7, 8], 68 | bag_length=10, return_bag=False, num_img_per_slide=600, pos_patch_ratio=0.1, pos_slide_ratio=0.5, transform=None): 69 | 70 | self.positive_num = positive_num # transform the N-class into 2-class 71 | self.negative_num = negative_num # transform the N-class into 2-class 72 | self.bag_length = bag_length 73 | self.return_bag = return_bag # return patch ot bag 74 | self.transform = transform # transform the patch image 75 | self.num_img_per_slide = num_img_per_slide 76 | 77 | self.ds_data, self.ds_label = ds_data, ds_label 78 | self.build_whole_slides(num_img=num_img_per_slide, positive_nums=positive_num, negative_nums=negative_num, pos_patch_ratio=pos_patch_ratio, pos_slide_ratio=pos_slide_ratio) 79 | print("") 80 | 81 | def build_whole_slides(self, num_img, positive_nums, negative_nums, pos_patch_ratio=0.1, pos_slide_ratio=0.5): 82 | # num_img: num of images per slide 83 | # positive patch ratio in each slide 84 | 85 | num_pos_per_slide = int(num_img * pos_patch_ratio) 86 | num_neg_per_slide = num_img - num_pos_per_slide 87 | 88 | idx_pos = [] 89 | for num in positive_nums: 90 | idx_pos.append(torch.where(self.ds_label == num)[0]) 91 | idx_pos = torch.cat(idx_pos).unsqueeze(1) 92 | idx_neg = [] 93 | for num in negative_nums: 94 | idx_neg.append(torch.where(self.ds_label == num)[0]) 95 | idx_neg = torch.cat(idx_neg).unsqueeze(1) 96 | 97 | idx_pos = random_shuffle(idx_pos) 98 | idx_neg = random_shuffle(idx_neg) 99 | 100 | # build pos slides using calculated 101 | num_pos_2PosSlides = int(idx_neg.numel() // ((1 - pos_slide_ratio) / (pos_patch_ratio*pos_slide_ratio) + (1 - pos_patch_ratio) / pos_patch_ratio)) 102 | if num_pos_2PosSlides > idx_pos.shape[0]: 103 | num_pos_2PosSlides = idx_pos.shape[0] 104 | num_pos_2PosSlides = int(num_pos_2PosSlides // num_pos_per_slide * num_pos_per_slide) 105 | num_neg_2PosSlides = int(num_pos_2PosSlides * ((1-pos_patch_ratio)/pos_patch_ratio)) 106 | 
num_neg_2NegSlides = int(num_pos_2PosSlides * ((1-pos_slide_ratio)/(pos_patch_ratio*pos_slide_ratio))) 107 | 108 | num_neg_2PosSlides = int(num_neg_2PosSlides // num_neg_per_slide * num_neg_per_slide) 109 | num_neg_2NegSlides = int(num_neg_2NegSlides // num_img * num_img) 110 | 111 | if num_neg_2PosSlides // num_neg_per_slide != num_pos_2PosSlides // num_pos_per_slide : 112 | num_diff_slide = num_pos_2PosSlides // num_pos_per_slide - num_neg_2PosSlides // num_neg_per_slide 113 | num_pos_2PosSlides = num_pos_2PosSlides - num_pos_per_slide * num_diff_slide 114 | 115 | idx_pos = idx_pos[0:num_pos_2PosSlides] 116 | idx_neg = idx_neg[0:(num_neg_2PosSlides+num_neg_2NegSlides)] 117 | 118 | idx_pos_toPosSlide = idx_pos[:].reshape(-1, num_pos_per_slide) 119 | idx_neg_toPosSlide = idx_neg[0:num_neg_2PosSlides].reshape(-1, num_neg_per_slide) 120 | idx_neg_toNegSlide = idx_neg[num_neg_2PosSlides:].reshape(-1, num_img) 121 | 122 | idx_pos_slides = torch.cat([idx_pos_toPosSlide, idx_neg_toPosSlide], dim=1) 123 | # idx_pos_slides = idx_pos_slides[:, torch.randperm(idx_pos_slides.shape[1])] # shuffle pos and neg idx 124 | for i_ in range(idx_pos_slides.shape[0]): 125 | idx_pos_slides[i_, :] = idx_pos_slides[i_, torch.randperm(idx_pos_slides.shape[1])] 126 | idx_neg_slides = idx_neg_toNegSlide 127 | 128 | self.idx_all_slides = torch.cat([idx_pos_slides, idx_neg_slides], dim=0) 129 | self.label_all_slides = torch.cat([torch.ones(idx_pos_slides.shape[0]), torch.zeros(idx_neg_slides.shape[0])], dim=0) 130 | self.label_all_slides = self.label_all_slides.unsqueeze(1).repeat([1,self.idx_all_slides.shape[1]]).long() 131 | print("[Info] dataset: {}".format(self.idx_all_slides.shape)) 132 | #self.visualize(idx_pos_slides[0]) 133 | 134 | def __getitem__(self, index): 135 | if self.return_bag: 136 | bagPerSlide = self.idx_all_slides.shape[1] // self.bag_length 137 | idx_slide = index // bagPerSlide 138 | idx_bag_in_slide = index % bagPerSlide 139 | idx_images = self.idx_all_slides[idx_slide, 
(idx_bag_in_slide*self.bag_length):((idx_bag_in_slide+1)*self.bag_length)] 140 | bag = self.ds_data[idx_images] 141 | patch_labels_raw = self.ds_label[idx_images] 142 | patch_labels = torch.zeros_like(patch_labels_raw) 143 | for num in self.positive_num: 144 | patch_labels[patch_labels_raw == num] = 1 145 | patch_labels = patch_labels.long() 146 | slide_label = self.label_all_slides[idx_slide, 0] 147 | slide_name = str(idx_slide) 148 | return bag.float()/255, [patch_labels, slide_label, idx_slide, slide_name], index 149 | else: 150 | idx_image = self.idx_all_slides.flatten()[index] 151 | slide_label = self.label_all_slides.flatten()[index] 152 | idx_slide = index // self.num_img_per_slide 153 | slide_name = str(idx_slide) 154 | patch = self.ds_data[idx_image] 155 | patch_label = self.ds_label[idx_image] 156 | patch_label = int(patch_label in self.positive_num) 157 | return patch.float()/255, [patch_label, slide_label, idx_slide, slide_name], index 158 | 159 | def __len__(self): 160 | if self.return_bag: 161 | return self.idx_all_slides.shape[1] // self.bag_length * self.idx_all_slides.shape[0] 162 | else: 163 | return self.idx_all_slides.numel() 164 | 165 | def visualize(self, idx, number_row=10, number_col=10): 166 | # idx should be of shape num_img_per_slide 167 | slide = self.ds_data[idx].clone() # num_img_per_slide * 3 * 32 * 32 168 | patch_label = self.ds_label[idx].clone() 169 | idx_pos_patch = [] 170 | for num in self.positive_num: 171 | idx_pos_patch.append(torch.where(patch_label == num)[0]) 172 | idx_pos_patch = torch.cat(idx_pos_patch) 173 | slide[idx_pos_patch, 0, :10, :] = 255 174 | slide[idx_pos_patch, 0, -10:, :] = 255 175 | slide[idx_pos_patch, 0, :, :10] = 255 176 | slide[idx_pos_patch, 0, :, -10:] = 255 177 | 178 | slide[idx_pos_patch, 1, :10, :] = 0 179 | slide[idx_pos_patch, 1, -10:, :] = 0 180 | slide[idx_pos_patch, 1, :, :10] = 0 181 | slide[idx_pos_patch, 1, :, -10:] = 0 182 | 183 | slide[idx_pos_patch, 2, :10, :] = 0 184 | 
slide[idx_pos_patch, 2, -10:, :] = 0 185 | slide[idx_pos_patch, 2, :, :10] = 0 186 | slide[idx_pos_patch, 2, :, -10:] = 0 187 | 188 | slide = slide.unsqueeze(0).reshape(number_row, number_col, 3, 224, 224).permute(0, 3, 1, 4, 2).reshape(number_row*224, number_col*224, 3) 189 | import utliz 190 | # show_img_1(slide) 191 | return slide 192 | 193 | 194 | def show_img_1(img, save_file_name=''): 195 | if type(img) == torch.Tensor: 196 | img = img.cpu().detach().numpy() 197 | if len(img.shape) == 3: # HxWx3 or 3xHxW, treat as RGB image 198 | if img.shape[0] == 3: 199 | img = img.transpose(1, 2, 0) 200 | fig = plt.figure() 201 | plt.imshow(img) 202 | if save_file_name != '': 203 | plt.savefig(save_file_name, format='svg') 204 | plt.colorbar() 205 | plt.show() 206 | 207 | 208 | def show_img(img, save_file_name='',format='svg', dpi=1200): 209 | if type(img) == torch.Tensor: 210 | img = img.cpu().detach().numpy() 211 | if len(img.shape) == 3: # HxWx3 or 3xHxW, treat as RGB image 212 | if img.shape[0] == 3: 213 | img = img.transpose(1, 2, 0) 214 | fig = plt.figure() 215 | plt.imshow(img) 216 | plt.xticks([]) 217 | plt.yticks([]) 218 | plt.axis('off') 219 | if save_file_name != '': 220 | plt.savefig(save_file_name, format=format, dpi=dpi, pad_inches=0.0, bbox_inches='tight') 221 | plt.colorbar() 222 | plt.show() 223 | 224 | 225 | if __name__ == "__main__": 226 | # train_data, train_label, val_data, val_label = get_Path10_data(downsample_ratio=0.1) 227 | # for pos_slide_ratio in [0.01, 0.05, 0.1, 0.2, 0.5, 0.7]: 228 | # print("=========== pos slide ratio: {} ===========".format(pos_slide_ratio)) 229 | # train_ds = NCT_WholeSlide_challenge(ds_data=train_data, ds_label=train_label, positive_num=[2], negative_num=[0, 1, 3, 4, 5, 6, 7, 8], bag_length=100, return_bag=False, num_img_per_slide=100, pos_patch_ratio=pos_slide_ratio, pos_slide_ratio=0.5, transform=None) 230 | # slide = train_ds.visualize(train_ds.idx_all_slides[0]) 231 | # show_img(slide, 
save_file_name='../figures/NCT_WSI_pos_PPR{}.png'.format(pos_slide_ratio), format='png', dpi=2400) 232 | # slide = train_ds.visualize(train_ds.idx_all_slides[-10]) 233 | # show_img(slide, save_file_name='../figures/NCT_WSI_neg_PPR{}.png'.format(pos_slide_ratio), format='png', dpi=2400) 234 | # # print("") 235 | # print("") 236 | 237 | train_data, train_label, val_data, val_label = get_Path10_data(downsample_ratio=1.0) 238 | for pos_slide_ratio in [0.01, 0.05, 0.1, 0.2, 0.5, 0.7]: 239 | print("=========== pos slide ratio: {} ===========".format(pos_slide_ratio)) 240 | train_ds = NCT_WholeSlide_challenge(ds_data=train_data, ds_label=train_label, positive_num=[2], negative_num=[0, 1, 3, 4, 5, 6, 7, 8], bag_length=100, return_bag=False, num_img_per_slide=100, pos_patch_ratio=pos_slide_ratio, pos_slide_ratio=0.5, transform=None) 241 | val_ds = NCT_WholeSlide_challenge(ds_data=val_data, ds_label=val_label, positive_num=[2], negative_num=[0, 1, 3, 4, 5, 6, 7, 8], bag_length=100, return_bag=False, num_img_per_slide=100, pos_patch_ratio=pos_slide_ratio, pos_slide_ratio=0.5, transform=None) 242 | print("") 243 | print("") 244 | 245 | -------------------------------------------------------------------------------- /utliz.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | from mpl_toolkits.mplot3d.axes3d import Axes3D 5 | from sklearn import metrics 6 | 7 | 8 | def show_pointcloud(pc, size=10): 9 | if type(pc) == torch.Tensor: 10 | pc = pc.numpy() 11 | if pc.shape[0]==3: 12 | pc = pc.transpose() 13 | fig = plt.figure() 14 | ax = fig.add_subplot(1, 1, 1, projection='3d') 15 | ax.scatter(pc[:,0], pc[:,1], pc[:,2],s=size) 16 | 17 | 18 | def show_pointcloud_batch(pc, size=10): 19 | if type(pc) == torch.Tensor: 20 | pc = pc.numpy() 21 | if pc.shape[1]==3: 22 | pc = pc.transpose(0,2,1) 23 | B,N,C = pc.shape 24 | fig = plt.figure() 25 | for i in range(B): 26 | ax = 
fig.add_subplot(2, int(B/2), i+1, projection='3d') 27 | ax.scatter(pc[i, :, 0], pc[i, :, 1], pc[i, :, 2], s=size) 28 | 29 | 30 | def show_pointcloud_2pc(pc_1, pc_2, ax=None, c1='r', c2='b',s1=1, s2=1): 31 | if type(pc_1) == torch.Tensor: 32 | pc_1 = pc_1.cpu().detach().numpy() 33 | pc_2 = pc_2.cpu().detach().numpy() 34 | if pc_1.shape[0]==3: 35 | pc_1 = pc_1.transpose() 36 | pc_2 = pc_2.transpose() 37 | if ax is None: 38 | fig = plt.figure() 39 | ax = fig.add_subplot(1, 1, 1, projection='3d') 40 | ax.scatter(pc_1[:, 0], pc_1[:, 1], pc_1[:, 2], s=s1, c=c1, alpha=0.5) 41 | ax.scatter(pc_2[:, 0], pc_2[:, 1], pc_2[:, 2], s=s2, c=c2, alpha=0.5) 42 | 43 | 44 | def show_pointcloud_perpointcolor(pc, size=10,c='r'): 45 | # pc.shape = Nx3, c.shape = N 46 | if type(pc) == torch.Tensor: 47 | pc = pc.cpu().detach().numpy() 48 | if pc.shape[0]==3: 49 | pc = pc.transpose() 50 | if type(c) == torch.Tensor: 51 | c = c.cpu().detach().numpy() 52 | if type(c) == np.ndarray: 53 | if len(c.shape) == 2: 54 | c = np.squeeze(c) 55 | fig = plt.figure() 56 | ax = fig.add_subplot(1, 1, 1, projection='3d') 57 | ax0 = ax.scatter(pc[:,0], pc[:,1], pc[:,2],s=size, alpha=0.5,c=c) 58 | plt.colorbar(ax0, ax=ax) 59 | 60 | 61 | def cal_auc(label, pred, pos_label=1, return_fpr_tpr=False, save_fpr_tpr=False): 62 | if type(label) == torch.Tensor: 63 | label = label.detach().cpu().numpy() 64 | if type(pred) == torch.Tensor: 65 | pred = pred.detach().cpu().numpy() 66 | fpr, tpr, thresholds = metrics.roc_curve(label, pred, pos_label=pos_label, drop_intermediate=False) 67 | auc_score = metrics.auc(fpr, tpr) 68 | if save_fpr_tpr: 69 | if auc_score > 0.5: 70 | np.save("./ROC_reinter/{:.0f}".format(auc_score * 10000), 71 | np.concatenate([np.expand_dims(fpr, axis=1), np.expand_dims(tpr, axis=1)], axis=1)) 72 | if return_fpr_tpr: 73 | return fpr, tpr, auc_score 74 | return auc_score 75 | 76 | 77 | def cal_acc(label, pred, threshold=0.5): 78 | if type(label) == torch.Tensor: 79 | label = 
label.detach().cpu().numpy() 80 | if type(pred) == torch.Tensor: 81 | pred = pred.detach().cpu().numpy() 82 | pred_logit = pred>threshold 83 | pred_logit = pred_logit.astype(np.long) 84 | acc = np.sum(pred_logit == label)/label.shape[0] 85 | return acc 86 | 87 | 88 | def optimal_thresh(fpr, tpr, thresholds, p=0): 89 | loss = (fpr - tpr) - p * tpr / (fpr + tpr + 1) 90 | idx = np.argmin(loss, axis=0) 91 | return fpr[idx], tpr[idx], thresholds[idx] 92 | 93 | 94 | def cal_acc_optimThre(label, pred, pos_label=1): 95 | if type(label) == torch.Tensor: 96 | label = label.detach().cpu().numpy() 97 | if type(pred) == torch.Tensor: 98 | pred = pred.detach().cpu().numpy() 99 | fpr, tpr, thresholds = metrics.roc_curve(label, pred, pos_label=pos_label, drop_intermediate=False) 100 | fpr_optimal, tpr_optimal, threshold_optimal = optimal_thresh(fpr, tpr, thresholds) 101 | pred[pred>threshold_optimal] = 1 102 | pred[pred best_acc: 117 | best_acc = acc 118 | return best_acc 119 | 120 | 121 | def cal_TPR_TNR_FPR_FNR(label, pred): 122 | if type(pred) is not torch.Tensor: 123 | pred = torch.from_numpy(pred) 124 | else: 125 | pred = pred.detach().cpu() 126 | if type(label) is not torch.Tensor: 127 | label = torch.from_numpy(label) 128 | else: 129 | label = label.detach().cpu() 130 | 131 | pred_logit = pred.round() 132 | pseudo_label_TP = torch.sum(label * pred_logit) 133 | pseudo_label_TN = torch.sum((1 - label) * (1 - pred_logit)) 134 | pesudo_label_FP = torch.sum((1 - label) * pred_logit) 135 | pesudo_label_FN = torch.sum(label * (1 - pred_logit)) 136 | pseudo_label_TPR = 1.0 * pseudo_label_TP / (label.sum() + 1e-9) 137 | pseudo_label_TNR = 1.0 * pseudo_label_TN / (label.numel() - label.sum() + 1e-9) 138 | pseudo_label_FPR = 1.0 * pesudo_label_FP / (label.numel() - label.sum() + 1e-9) 139 | pseudo_label_FNR = 1.0 * pesudo_label_FN / (label.sum() + 1e-9) 140 | 141 | pseudo_label_precision = 1.0 * pseudo_label_TP / (pred_logit.sum() + 1e-9) 142 | pseudo_label_acc = 1.0 * torch.sum(label 
== pred_logit) / label.numel() 143 | pseudo_label_auc = cal_auc(label, pred) 144 | 145 | return [pseudo_label_TPR.item(), pseudo_label_TNR.item(), pseudo_label_FPR.item(), pseudo_label_FNR.item()],\ 146 | pseudo_label_acc.item(), pseudo_label_auc 147 | 148 | 149 | class AverageMeter(object): 150 | """Computes and stores the average and current value""" 151 | def __init__(self): 152 | self.reset() 153 | 154 | def reset(self): 155 | self.val = 0 156 | self.avg = 0 157 | self.sum = 0 158 | self.count = 0 159 | self.val_window = [] 160 | self.avg_window = 0 161 | 162 | def update(self, val, n=1): 163 | self.val = val 164 | self.sum += val * n 165 | self.count += n 166 | self.avg = self.sum / self.count 167 | if len(self.val_window)< 10: 168 | self.val_window.append(self.val) 169 | elif len(self.val_window) == 10: 170 | self.val_window.pop(0) 171 | self.val_window.append(self.val) 172 | else: 173 | print("windows avg ERROR") 174 | self.avg_window = np.array(self.val_window).mean() 175 | 176 | 177 | # class VisdomLinePlotter(object): 178 | # """Plots to Visdom""" 179 | # def __init__(self, env_name='main'): 180 | # self.viz = Visdom() 181 | # self.env = env_name 182 | # self.plots = {} 183 | # self.scatters = {} 184 | # def plot(self, var_name, split_name, title_name, x, y): 185 | # if var_name not in self.plots: 186 | # self.plots[var_name] = self.viz.line(X=np.array([x,x]), Y=np.array([y,y]), env=self.env, opts=dict( 187 | # legend=[split_name], 188 | # title=title_name, 189 | # xlabel='Epochs', 190 | # ylabel=var_name 191 | # )) 192 | # else: 193 | # self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env, win=self.plots[var_name], name=split_name, update = 'append') 194 | # 195 | # # def scatter(self, var_name, split_name, title_name, x, size=10): 196 | # # if var_name not in self.scatters: 197 | # # self.scatters[var_name] = self.viz.scatter(X=x.cpu().detach().numpy(), env=self.env, opts=dict( 198 | # # legend=[split_name], 199 | # # title=title_name, 200 | # # 
markersize=size 201 | # # )) 202 | # # else: 203 | # # self.viz.scatter(X=x.cpu().detach().numpy(), env=self.env, win=self.scatters[var_name], name=split_name, update='replace') 204 | # 205 | # def scatter(self, var_name, split_name, title_name, x, size=10, color=0, symbol='dot'): 206 | # if var_name not in self.scatters: 207 | # if type(x) == torch.Tensor: 208 | # x = x.cpu().detach().numpy() 209 | # self.scatters[var_name] = self.viz.scatter(X=x, env=self.env, opts=dict( 210 | # legend=[split_name], 211 | # title=title_name, 212 | # markersize=size, 213 | # markercolor=color, 214 | # markerborderwidth=0, 215 | # # opacity=0.5 216 | # # markersymbol=symbol, 217 | # # linecolor='white', 218 | # )) 219 | # else: 220 | # if type(x) == torch.Tensor: 221 | # x = x.cpu().detach().numpy() 222 | # self.viz.scatter(X=x, env=self.env, win=self.scatters[var_name], name=split_name, update='replace') 223 | 224 | 225 | #################################### 226 | ########### plotly plot ############ 227 | def show_3D_imageSlice_plotly(volume): 228 | if type(volume) == torch.Tensor: 229 | volume = volume.detach().cpu().numpy() 230 | r, c = volume[0].shape 231 | # Define frames 232 | import plotly.graph_objects as go 233 | import plotly.io as pio 234 | pio.renderers.default = "browser" 235 | nb_frames = volume.shape[0] 236 | 237 | fig = go.Figure(frames=[go.Frame(data=go.Surface( 238 | z=((nb_frames-1)/10 - k * 0.1) * np.ones((r, c)), 239 | surfacecolor=np.flipud(volume[nb_frames-1 - k]), 240 | cmin=0, cmax=200 241 | ), 242 | name=str(k) # you need to name the frame for the animation to behave properly 243 | ) 244 | for k in range(nb_frames)]) 245 | 246 | # Add data to be displayed before animation starts 247 | fig.add_trace(go.Surface( 248 | z=(nb_frames-1)/10 * np.ones((r, c)), 249 | surfacecolor=np.flipud(volume[nb_frames-1]), 250 | colorscale='Gray', 251 | cmin=0, cmax=200, 252 | colorbar=dict(thickness=20, ticklen=4) 253 | )) 254 | 255 | 256 | def frame_args(duration): 257 | 
return { 258 | "frame": {"duration": duration}, 259 | "mode": "immediate", 260 | "fromcurrent": True, 261 | "transition": {"duration": duration, "easing": "linear"}, 262 | } 263 | 264 | sliders = [ 265 | { 266 | "pad": {"b": 10, "t": 60}, 267 | "len": 0.9, 268 | "x": 0.1, 269 | "y": 0, 270 | "steps": [ 271 | { 272 | "args": [[f.name], frame_args(0)], 273 | "label": str(k), 274 | "method": "animate", 275 | } 276 | for k, f in enumerate(fig.frames) 277 | ], 278 | } 279 | ] 280 | 281 | # Layout 282 | fig.update_layout( 283 | title='Slices in volumetric data', 284 | width=600, 285 | height=600, 286 | scene=dict( 287 | zaxis=dict(range=[-0.1, (nb_frames-1)/10], autorange=False), 288 | aspectratio=dict(x=1, y=1, z=1), 289 | ), 290 | updatemenus = [ 291 | { 292 | "buttons": [ 293 | { 294 | "args": [None, frame_args(50)], 295 | "label": "▶", # play symbol 296 | "method": "animate", 297 | }, 298 | { 299 | "args": [[None], frame_args(0)], 300 | "label": "◼", # pause symbol 301 | "method": "animate", 302 | }, 303 | ], 304 | "direction": "left", 305 | "pad": {"r": 10, "t": 70}, 306 | "type": "buttons", 307 | "x": 0.1, 308 | "y": 0, 309 | } 310 | ], 311 | sliders=sliders 312 | ) 313 | 314 | fig.show() 315 | 316 | 317 | def show_3D_volume_plotly(volume, surface_count=17): 318 | import plotly.graph_objects as go 319 | import numpy as np 320 | import plotly.io as pio 321 | pio.renderers.default = "browser" 322 | if type(volume) == torch.Tensor: 323 | volume = volume.detach().cpu().numpy() 324 | 325 | X, Y, Z = np.mgrid[0:volume.shape[0], 0:volume.shape[1], 0:volume.shape[2]] 326 | 327 | fig = go.Figure(data=go.Volume( 328 | x=X.flatten(), 329 | y=Y.flatten(), 330 | z=Z.flatten(), 331 | value=volume.flatten(), 332 | isomin=0.1, 333 | isomax=0.8, 334 | opacity=0.1, # needs to be small to see through all surfaces 335 | surface_count=surface_count, # needs to be a large number for good volume rendering 336 | )) 337 | fig.show() 338 | 339 | #################################### 340 | 
#################################### 341 | 342 | def get_lr(optimizer): 343 | for param_group in optimizer.param_groups: 344 | return param_group['lr'] 345 | 346 | 347 | #################################### 348 | #################################### 349 | class Network_Logger(object): 350 | def __init__(self, model): 351 | self.model = model 352 | self.model_grad_dict = {} 353 | self.model_weight_dict = {} 354 | self.model_weightSize_dict = {} 355 | for (i, j) in self.model.named_parameters(): 356 | if len(j.shape) > 1: 357 | self.model_grad_dict[i] = [] 358 | self.model_weight_dict[i] = [j.abs().mean().item()] 359 | self.model_weightSize_dict[i] = j.shape 360 | 361 | def log_grad(self): 362 | for (i, j) in self.model.named_parameters(): 363 | if len(j.shape) > 1: 364 | self.model_grad_dict[i].append(j.grad.abs().mean().item()) 365 | 366 | def log_weight(self): 367 | for (i, j) in self.model.named_parameters(): 368 | if len(j.shape) > 1: 369 | self.model_weight_dict[i].append(j.abs().mean().item()) 370 | 371 | def get_current_weight(self): 372 | current_weight = [] 373 | for key in self.model_weight_dict.keys(): 374 | current_weight.append(self.model_weight_dict[key][-1]) 375 | return current_weight 376 | 377 | def get_current_grad(self): 378 | current_grad = [] 379 | for key in self.model_grad_dict.keys(): 380 | current_grad.append(self.model_grad_dict[key][-1]) 381 | return current_grad 382 | 383 | def plot_grad(self, layer_idx=None): 384 | # example: layer_idx = [0,1,2] for only first 3 layers 385 | fig = plt.figure() 386 | ax = fig.add_subplot(1, 1, 1) 387 | if layer_idx is not None: 388 | for idx, key in enumerate(self.model_grad_dict.keys()): 389 | if idx in layer_idx: 390 | ax.plot(self.model_grad_dict[key], label=str(key)) 391 | ax.legend() 392 | else: 393 | for idx, key in enumerate(self.model_grad_dict.keys()): 394 | ax.plot(self.model_grad_dict[key], label=str(key)) 395 | ax.legend() 396 | 397 | def plot_weight(self, layer_idx=None): 398 | # example: 
layer_idx = [0,1,2] for only first 3 layers 399 | fig = plt.figure() 400 | ax = fig.add_subplot(1, 1, 1) 401 | if layer_idx is not None: 402 | for idx, key in enumerate(self.model_weight_dict.keys()): 403 | if idx in layer_idx: 404 | ax.plot(self.model_weight_dict[key], label=str(key)) 405 | ax.legend() 406 | else: 407 | for idx, key in enumerate(self.model_weight_dict.keys()): 408 | ax.plot(self.model_weight_dict[key], label=str(key)) 409 | ax.legend() 410 | 411 | 412 | #################################### 413 | #################################### 414 | def show_img(img, save_file_name=''): 415 | if type(img) == torch.Tensor: 416 | img = img.cpu().detach().numpy() 417 | if len(img.shape) == 3: # HxWx3 or 3xHxW, treat as RGB image 418 | if img.shape[0] == 3: 419 | img = img.transpose(1, 2, 0) 420 | fig = plt.figure() 421 | plt.imshow(img) 422 | if save_file_name != '': 423 | plt.savefig(save_file_name, format='svg') 424 | plt.colorbar() 425 | plt.show() 426 | 427 | def show_img_multi(img_list, num_col, num_row): 428 | fig = plt.figure() 429 | 430 | for idx, img in enumerate(img_list): 431 | if type(img) == torch.Tensor: 432 | img = img.cpu().detach().numpy() 433 | if len(img.shape) == 3: # HxWx3 or 3xHxW, treat as RGB image 434 | if img.shape[0] == 3: 435 | img = img.transpose(1, 2, 0) 436 | ax = fig.add_subplot(num_col, num_row, idx+1) 437 | ax.imshow(img) 438 | plt.show() 439 | -------------------------------------------------------------------------------- /models/alexnet.py: -------------------------------------------------------------------------------- 1 | # adapted from DeepCluster repo: https://github.com/facebookresearch/deepcluster 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch 6 | 7 | __all__ = [ 'AlexNet', 'alexnet_MNIST', 'alexnet', 'alexnet_STL10', 'alexnet_PCam', 'alexnet_CAMELYON', 8 | 'AlexNet_MNIST_projection_prototype', 'alexnet_MedMNIST', 'alexnet_CIFAR10'] 9 | 10 | # (number of filters, kernel 
size, stride, pad) 11 | CFG = { 12 | 'big': [(96, 11, 4, 2), 'M', (256, 5, 1, 2), 'M', (384, 3, 1, 1), (384, 3, 1, 1), (256, 3, 1, 1), 'M'], 13 | 'small': [(64, 11, 4, 2), 'M', (192, 5, 1, 2), 'M', (384, 3, 1, 1), (256, 3, 1, 1), (256, 3, 1, 1), 'M'], 14 | 'mnist': [(32, 6, 2, 2), (64, 3, 1, 1), 'M', (128, 3, 1, 1), (128, 3, 1, 1), 'M'], 15 | 'CAMELYON': [(96, 12, 4, 4), (256, 12, 4, 4), 'M_', (256, 5, 1, 2), 'M_', (512, 3, 1, 1), (512, 3, 1, 1), (256, 3, 1, 1), 'M_'], 16 | 'CIFAR10': [(96, 3, 1, 1), 'M', (192, 3, 1, 1), 'M', (384, 3, 1, 1), (384, 3, 1, 1), (192, 3, 1, 1), 'M'] 17 | } 18 | 19 | 20 | class AlexNet(nn.Module): 21 | def __init__(self, features, num_classes, init=True): 22 | super(AlexNet, self).__init__() 23 | self.features = features 24 | self.classifier = nn.Sequential(nn.Dropout(0.5), 25 | nn.Linear(256 * 2 * 2, 4096), 26 | nn.ReLU(inplace=True), 27 | nn.Dropout(0.5), 28 | nn.Linear(4096, 4096), 29 | nn.ReLU(inplace=True)) 30 | self.headcount = len(num_classes) 31 | self.return_features = False 32 | if len(num_classes) == 1: 33 | self.top_layer = nn.Linear(4096, num_classes[0]) 34 | else: 35 | for a,i in enumerate(num_classes): 36 | setattr(self, "top_layer%d" % a, nn.Linear(4096, i)) 37 | self.top_layer = None # this way headcount can act as switch. 38 | if init: 39 | self._initialize_weights() 40 | 41 | def forward(self, x): 42 | x = self.features(x) 43 | x = x.view(x.size(0), 256 * 2 * 2) 44 | x = self.classifier(x) 45 | if self.return_features: # switch only used for CIFAR-experiments 46 | return x 47 | if self.headcount == 1: 48 | if self.top_layer: # this way headcount can act as switch. 
49 | x = self.top_layer(x) 50 | return x 51 | else: 52 | outp = [] 53 | for i in range(self.headcount): 54 | outp.append(getattr(self, "top_layer%d" % i)(x)) 55 | return outp 56 | 57 | def _initialize_weights(self): 58 | for y, m in enumerate(self.modules()): 59 | if isinstance(m, nn.Conv2d): 60 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 61 | for i in range(m.out_channels): 62 | m.weight.data[i].normal_(0, math.sqrt(2. / n)) 63 | if m.bias is not None: 64 | m.bias.data.zero_() 65 | elif isinstance(m, nn.BatchNorm2d): 66 | m.weight.data.fill_(1) 67 | m.bias.data.zero_() 68 | elif isinstance(m, nn.Linear): 69 | m.weight.data.normal_(0, 0.01) 70 | m.bias.data.zero_() 71 | 72 | 73 | class AlexNet_4x4(nn.Module): 74 | def __init__(self, features, num_classes, init=True): 75 | super(AlexNet_4x4, self).__init__() 76 | self.features = features 77 | self.classifier = nn.Sequential(nn.Dropout(0.5), 78 | nn.Linear(256 * 4 * 4, 4096), 79 | nn.ReLU(inplace=True), 80 | nn.Dropout(0.5), 81 | nn.Linear(4096, 4096), 82 | nn.ReLU(inplace=True)) 83 | self.headcount = len(num_classes) 84 | self.return_features = False 85 | if len(num_classes) == 1: 86 | self.top_layer = nn.Linear(4096, num_classes[0]) 87 | else: 88 | for a,i in enumerate(num_classes): 89 | setattr(self, "top_layer%d" % a, nn.Linear(4096, i)) 90 | self.top_layer = None # this way headcount can act as switch. 91 | if init: 92 | self._initialize_weights() 93 | 94 | def forward(self, x): 95 | x = self.features(x) 96 | x = x.view(x.size(0), 256 * 4 * 4) 97 | x = self.classifier(x) 98 | if self.return_features: # switch only used for CIFAR-experiments 99 | return x 100 | if self.headcount == 1: 101 | if self.top_layer: # this way headcount can act as switch. 
102 | x = self.top_layer(x) 103 | return x 104 | else: 105 | outp = [] 106 | for i in range(self.headcount): 107 | outp.append(getattr(self, "top_layer%d" % i)(x)) 108 | return outp 109 | 110 | def _initialize_weights(self): 111 | for y, m in enumerate(self.modules()): 112 | if isinstance(m, nn.Conv2d): 113 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 114 | for i in range(m.out_channels): 115 | m.weight.data[i].normal_(0, math.sqrt(2. / n)) 116 | if m.bias is not None: 117 | m.bias.data.zero_() 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | elif isinstance(m, nn.Linear): 122 | m.weight.data.normal_(0, 0.01) 123 | m.bias.data.zero_() 124 | 125 | 126 | class AlexNet_MNIST(nn.Module): 127 | def __init__(self, features, num_classes, init=True): 128 | super(AlexNet_MNIST, self).__init__() 129 | self.features = features 130 | self.classifier = nn.Sequential(nn.Dropout(0.5), 131 | nn.Linear(128 * 2 * 2, 1024), 132 | nn.ReLU(inplace=True), 133 | nn.Dropout(0.5), 134 | nn.Linear(1024, 1024), 135 | nn.ReLU(inplace=True)) 136 | self.headcount = len(num_classes) 137 | self.return_features = False 138 | if len(num_classes) == 1: 139 | self.top_layer = nn.Linear(1024, num_classes[0]) 140 | else: 141 | for a,i in enumerate(num_classes): 142 | setattr(self, "top_layer%d" % a, nn.Linear(1024, i)) 143 | self.top_layer = None # this way headcount can act as switch. 144 | if init: 145 | self._initialize_weights() 146 | 147 | def forward(self, x): 148 | x = self.features(x) 149 | x = x.view(x.size(0), 128 * 2 * 2) 150 | x = self.classifier(x) 151 | if self.return_features: # switch only used for CIFAR-experiments 152 | return x 153 | if self.headcount == 1: 154 | if self.top_layer: # this way headcount can act as switch. 
155 | x = self.top_layer(x) 156 | # x = nn.functional.tanh(x) # add by xiaoyuan 2021_4_22 to avoid nan in loss 157 | return x 158 | else: 159 | outp = [] 160 | for i in range(self.headcount): 161 | outp.append(getattr(self, "top_layer%d" % i)(x)) 162 | return outp 163 | 164 | def _initialize_weights(self): 165 | for y, m in enumerate(self.modules()): 166 | if isinstance(m, nn.Conv2d): 167 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 168 | for i in range(m.out_channels): 169 | m.weight.data[i].normal_(0, math.sqrt(2. / n)) 170 | if m.bias is not None: 171 | m.bias.data.zero_() 172 | elif isinstance(m, nn.BatchNorm2d): 173 | m.weight.data.fill_(1) 174 | m.bias.data.zero_() 175 | elif isinstance(m, nn.Linear): 176 | m.weight.data.normal_(0, 0.01) 177 | m.bias.data.zero_() 178 | 179 | 180 | class AlexNet_CIFAR10(nn.Module): 181 | def __init__(self, features, num_classes, init=True, input_feat_dim=192*3*3): 182 | super(AlexNet_CIFAR10, self).__init__() 183 | self.features = features 184 | self.input_feat_dim = input_feat_dim 185 | self.classifier = nn.Sequential( 186 | # nn.Dropout(0.5), 187 | nn.Linear(input_feat_dim, 4096), 188 | nn.ReLU(inplace=True), 189 | # nn.Dropout(0.5), 190 | nn.Linear(4096, 4096), 191 | nn.ReLU(inplace=True) 192 | ) 193 | self.headcount = len(num_classes) 194 | self.return_features = False 195 | if len(num_classes) == 1: 196 | self.top_layer = nn.Linear(4096, num_classes[0]) 197 | else: 198 | for a, i in enumerate(num_classes): 199 | setattr(self, "top_layer%d" % a, nn.Linear(4096, i)) 200 | self.top_layer = None # this way headcount can act as switch. 
201 | if init: 202 | self._initialize_weights() 203 | 204 | def forward(self, x): 205 | if self.features is not None: 206 | x = self.features(x) 207 | x = x.view(x.size(0), self.input_feat_dim) 208 | x = self.classifier(x) 209 | if self.return_features: # switch only used for CIFAR-experiments 210 | return x 211 | if self.headcount == 1: 212 | if self.top_layer: # this way headcount can act as switch. 213 | x = self.top_layer(x) 214 | # x = nn.functional.tanh(x) # add by xiaoyuan 2021_4_22 to avoid nan in loss 215 | return x 216 | else: 217 | outp = [] 218 | for i in range(self.headcount): 219 | outp.append(getattr(self, "top_layer%d" % i)(x)) 220 | return outp 221 | 222 | def _initialize_weights(self): 223 | for y, m in enumerate(self.modules()): 224 | if isinstance(m, nn.Conv2d): 225 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 226 | for i in range(m.out_channels): 227 | m.weight.data[i].normal_(0, math.sqrt(2. / n)) 228 | if m.bias is not None: 229 | m.bias.data.zero_() 230 | elif isinstance(m, nn.BatchNorm2d): 231 | m.weight.data.fill_(1) 232 | m.bias.data.zero_() 233 | elif isinstance(m, nn.Linear): 234 | m.weight.data.normal_(0, 0.01) 235 | m.bias.data.zero_() 236 | 237 | 238 | class AlexNet_MNIST_projection_prototype(nn.Module): 239 | def __init__(self, output_dim=0, hidden_mlp=0, nmb_prototypes=0, init=True, normalize=True, 240 | eval_mode=False, norm_layer=None): 241 | super(AlexNet_MNIST_projection_prototype, self).__init__() 242 | 243 | self.features = make_layers_features(CFG['mnist'], 1, bn=True) 244 | 245 | if norm_layer is None: 246 | norm_layer = nn.BatchNorm2d 247 | self._norm_layer = norm_layer 248 | 249 | self.eval_mode = eval_mode 250 | 251 | # normalize output features 252 | self.l2norm = normalize 253 | 254 | # projection head 255 | if output_dim == 0: 256 | self.projection_head = None 257 | elif hidden_mlp == 0: 258 | # self.projection_head = nn.Linear(128*2*2, output_dim) 259 | self.projection_head = nn.Linear(128, output_dim) 
260 | else: 261 | self.projection_head = nn.Sequential( 262 | # nn.Linear(128*2*2, hidden_mlp), 263 | nn.Linear(128, hidden_mlp), 264 | nn.BatchNorm1d(hidden_mlp), 265 | nn.ReLU(inplace=True), 266 | nn.Linear(hidden_mlp, output_dim), 267 | ) 268 | 269 | # prototype layer 270 | self.prototypes = None 271 | if isinstance(nmb_prototypes, list): 272 | # self.prototypes = MultiPrototypes(output_dim, nmb_prototypes) 273 | print("Multiple Prototypes is not supported now") 274 | elif nmb_prototypes > 0: 275 | self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False) 276 | 277 | for m in self.modules(): 278 | if isinstance(m, nn.Conv2d): 279 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 280 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 281 | nn.init.constant_(m.weight, 1) 282 | nn.init.constant_(m.bias, 0) 283 | 284 | def forward_backbone(self, x): 285 | x = self.features(x) 286 | # x = x.view(x.size(0), 128 * 2 * 2) 287 | x = x.view(x.size(0), 128, 2 * 2) 288 | x = x.max(dim=-1)[0] 289 | return x 290 | 291 | def forward_head(self, x): 292 | if self.projection_head is not None: 293 | x = self.projection_head(x) 294 | 295 | if self.l2norm: 296 | x = nn.functional.normalize(x, dim=1, p=2) 297 | 298 | if self.prototypes is not None: 299 | return x, self.prototypes(x) 300 | return x 301 | 302 | def forward(self, inputs): 303 | if not isinstance(inputs, list): 304 | inputs = [inputs] 305 | idx_crops = torch.cumsum(torch.unique_consecutive( 306 | torch.tensor([inp.shape[-1] for inp in inputs]), 307 | return_counts=True, 308 | )[1], 0) 309 | start_idx = 0 310 | for end_idx in idx_crops: 311 | _out = self.forward_backbone(torch.cat(inputs[start_idx: end_idx]).cuda(non_blocking=True)) 312 | if start_idx == 0: 313 | output = _out 314 | else: 315 | output = torch.cat((output, _out)) 316 | start_idx = end_idx 317 | return self.forward_head(output) 318 | 319 | def _initialize_weights(self): 320 | for y, m in enumerate(self.modules()): 321 | 
if isinstance(m, nn.Conv2d): 322 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 323 | for i in range(m.out_channels): 324 | m.weight.data[i].normal_(0, math.sqrt(2. / n)) 325 | if m.bias is not None: 326 | m.bias.data.zero_() 327 | elif isinstance(m, nn.BatchNorm2d): 328 | m.weight.data.fill_(1) 329 | m.bias.data.zero_() 330 | elif isinstance(m, nn.Linear): 331 | m.weight.data.normal_(0, 0.01) 332 | m.bias.data.zero_() 333 | 334 | 335 | def make_layers_features(cfg, input_dim, bn): 336 | layers = [] 337 | in_channels = input_dim 338 | for v in cfg: 339 | if v == 'M': 340 | layers += [nn.MaxPool2d(kernel_size=3, stride=2)] 341 | elif v == 'M_': 342 | layers += [nn.MaxPool2d(kernel_size=4, stride=2, padding=1)] 343 | else: 344 | conv2d = nn.Conv2d(in_channels, v[0], kernel_size=v[1], stride=v[2], padding=v[3])#,bias=False) 345 | if bn: 346 | layers += [conv2d, nn.BatchNorm2d(v[0]), nn.ReLU(inplace=True)] 347 | else: 348 | layers += [conv2d, nn.ReLU(inplace=True)] 349 | in_channels = v[0] 350 | return nn.Sequential(*layers) 351 | 352 | 353 | def alexnet(bn=True, num_classes=[1000], init=True, size='big'): 354 | dim = 1 355 | model = AlexNet(make_layers_features(CFG[size], dim, bn=bn), num_classes, init) 356 | return model 357 | 358 | 359 | def alexnet_MNIST(bn=True, num_classes=[2], init=True): 360 | dim = 1 361 | model = AlexNet_MNIST(make_layers_features(CFG['mnist'], dim, bn=bn), num_classes, init) 362 | return model 363 | 364 | 365 | def alexnet_MedMNIST(bn=True, num_classes=[2], init=True): 366 | dim = 3 367 | model = AlexNet_MNIST(make_layers_features(CFG['mnist'], dim, bn=bn), num_classes, init) 368 | return model 369 | 370 | 371 | def alexnet_STL10(num_classes): 372 | model = SmallAlexNet(num_classes) 373 | return model 374 | 375 | 376 | def alexnet_PCam(bn=True, num_classes=[2], init=True): 377 | dim = 3 378 | model = AlexNet(make_layers_features(CFG['big'], input_dim=dim ,bn=bn), num_classes=num_classes, init=init) 379 | return model 380 | 381 | 
382 | def alexnet_CAMELYON(bn=True, num_classes=[2], init=True): 383 | dim = 3 384 | model = AlexNet_4x4(make_layers_features(CFG['CAMELYON'], input_dim=dim ,bn=bn), num_classes=num_classes, init=init) 385 | return model 386 | 387 | 388 | def alexnet_CIFAR10(bn=True, num_classes=[2], init=True): 389 | dim = 3 390 | model = AlexNet_CIFAR10(make_layers_features(CFG['CIFAR10'], dim, bn=bn), num_classes, init) 391 | return model 392 | 393 | 394 | class L2Norm(nn.Module): 395 | def forward(self, x): 396 | return x / x.norm(p=2, dim=1, keepdim=True) 397 | 398 | 399 | class SmallAlexNet(nn.Module): 400 | def __init__(self, in_channel=3, num_classes=[2]): 401 | super(SmallAlexNet, self).__init__() 402 | blocks = [] 403 | 404 | # conv_block_1 405 | blocks.append(nn.Sequential( 406 | nn.Conv2d(in_channel, 96, kernel_size=3, padding=1, bias=False), 407 | nn.BatchNorm2d(96), 408 | nn.ReLU(inplace=True), 409 | nn.MaxPool2d(3, 2), 410 | )) 411 | 412 | # conv_block_2 413 | blocks.append(nn.Sequential( 414 | nn.Conv2d(96, 192, kernel_size=3, padding=1, bias=False), 415 | nn.BatchNorm2d(192), 416 | nn.ReLU(inplace=True), 417 | nn.MaxPool2d(3, 2), 418 | )) 419 | 420 | # conv_block_3 421 | blocks.append(nn.Sequential( 422 | nn.Conv2d(192, 384, kernel_size=3, padding=1, bias=False), 423 | nn.BatchNorm2d(384), 424 | nn.ReLU(inplace=True), 425 | )) 426 | 427 | # conv_block_4 428 | blocks.append(nn.Sequential( 429 | nn.Conv2d(384, 384, kernel_size=3, padding=1, bias=False), 430 | nn.BatchNorm2d(384), 431 | nn.ReLU(inplace=True), 432 | )) 433 | 434 | # conv_block_5 435 | blocks.append(nn.Sequential( 436 | nn.Conv2d(384, 192, kernel_size=3, padding=1, bias=False), 437 | nn.BatchNorm2d(192), 438 | nn.ReLU(inplace=True), 439 | nn.MaxPool2d(3, 2), 440 | )) 441 | 442 | # fc6 443 | blocks.append(nn.Sequential( 444 | nn.Flatten(), 445 | nn.Linear(192 * 7 * 7, 4096, bias=False), # 256 * 6 * 6 if 224 * 224 446 | nn.BatchNorm1d(4096), 447 | nn.ReLU(inplace=True), 448 | )) 449 | 450 | # fc7 451 | 
blocks.append(nn.Sequential( 452 | nn.Linear(4096, 4096, bias=False), 453 | nn.BatchNorm1d(4096), 454 | nn.ReLU(inplace=True), 455 | )) 456 | 457 | # fc8 458 | blocks.append(nn.Sequential( 459 | nn.Linear(4096, num_classes[0]), 460 | L2Norm(), 461 | )) 462 | 463 | self.blocks = nn.ModuleList(blocks) 464 | self.init_weights_() 465 | 466 | def init_weights_(self): 467 | def init(m): 468 | if isinstance(m, (nn.Linear, nn.Conv2d)): 469 | nn.init.normal_(m.weight, 0, 0.02) 470 | if getattr(m, 'bias', None) is not None: 471 | nn.init.zeros_(m.bias) 472 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): 473 | if getattr(m, 'weight', None) is not None: 474 | nn.init.ones_(m.weight) 475 | if getattr(m, 'bias', None) is not None: 476 | nn.init.zeros_(m.bias) 477 | 478 | self.apply(init) 479 | 480 | def forward(self, x, *, layer_index=-1): 481 | if layer_index < 0: 482 | layer_index += len(self.blocks) 483 | for layer in self.blocks[:(layer_index + 1)]: 484 | x = layer(x) 485 | return x 486 | 487 | 488 | class AlexNet_MNIST_attention(nn.Module): 489 | def __init__(self, features, num_classes, init=True, withoutAtten=False): 490 | super(AlexNet_MNIST_attention, self).__init__() 491 | self.withoutAtten=withoutAtten 492 | self.features = features 493 | self.classifier = nn.Sequential(nn.Dropout(0.5), 494 | nn.Linear(128 * 2 * 2, 1024), 495 | nn.ReLU(inplace=True), 496 | nn.Dropout(0.5), 497 | nn.Linear(1024, 1024), 498 | nn.ReLU(inplace=True)) 499 | self.L = 1024 500 | self.D = 512 501 | self.K = 1 502 | 503 | self.attention = nn.Sequential( 504 | nn.Linear(self.L, self.D), 505 | nn.Tanh(), 506 | nn.Linear(self.D, self.K) 507 | ) 508 | self.headcount = len(num_classes) 509 | self.return_features = False 510 | if len(num_classes) == 1: 511 | self.top_layer = nn.Linear(1024, num_classes[0]) 512 | else: 513 | for a,i in enumerate(num_classes): 514 | setattr(self, "top_layer%d" % a, nn.Linear(4096, i)) 515 | self.top_layer = None # this way headcount can act as switch. 
516 | if init: 517 | self._initialize_weights() 518 | 519 | def forward(self, x, returnBeforeSoftMaxA=False): 520 | x = x.squeeze(0) 521 | x = self.features(x) 522 | x = x.view(x.size(0), 128 * 2 * 2) 523 | x = self.classifier(x) 524 | 525 | # Attention module 526 | A_ = self.attention(x) # NxK 527 | A_ = torch.transpose(A_, 1, 0) # KxN 528 | A = F.softmax(A_, dim=1) # softmax over N 529 | 530 | if self.withoutAtten: 531 | x = torch.mean(x, dim=0, keepdim=True) 532 | else: 533 | x = torch.mm(A, x) # KxL 534 | 535 | if self.return_features: # switch only used for CIFAR-experiments 536 | return x 537 | 538 | x = self.top_layer(x) 539 | if returnBeforeSoftMaxA: 540 | return x, 0, A, A_ 541 | return x, 0, A 542 | 543 | def _initialize_weights(self): 544 | for y, m in enumerate(self.modules()): 545 | if isinstance(m, nn.Conv2d): 546 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 547 | for i in range(m.out_channels): 548 | m.weight.data[i].normal_(0, math.sqrt(2. / n)) 549 | if m.bias is not None: 550 | m.bias.data.zero_() 551 | elif isinstance(m, nn.BatchNorm2d): 552 | m.weight.data.fill_(1) 553 | m.bias.data.zero_() 554 | elif isinstance(m, nn.Linear): 555 | m.weight.data.normal_(0, 0.01) 556 | m.bias.data.zero_() 557 | 558 | 559 | class AlexNet_CIFAR10_attention(nn.Module): 560 | def __init__(self, features, num_classes, init=True, withoutAtten=False, input_feat_dim=192*3*3): 561 | super(AlexNet_CIFAR10_attention, self).__init__() 562 | self.input_feat_dim = input_feat_dim 563 | self.withoutAtten=withoutAtten 564 | self.features = features 565 | self.classifier = nn.Sequential(nn.Dropout(0.5), 566 | nn.Linear(input_feat_dim, 1024), 567 | nn.ReLU(inplace=True), 568 | nn.Dropout(0.5), 569 | nn.Linear(1024, 1024), 570 | nn.ReLU(inplace=True)) 571 | self.L = 1024 572 | self.D = 512 573 | self.K = 1 574 | 575 | self.attention = nn.Sequential( 576 | nn.Linear(self.L, self.D), 577 | nn.Tanh(), 578 | nn.Linear(self.D, self.K) 579 | ) 580 | self.headcount = 
len(num_classes) 581 | self.return_features = False 582 | if len(num_classes) == 1: 583 | self.top_layer = nn.Linear(1024, num_classes[0]) 584 | else: 585 | for a,i in enumerate(num_classes): 586 | setattr(self, "top_layer%d" % a, nn.Linear(4096, i)) 587 | self.top_layer = None # this way headcount can act as switch. 588 | if init: 589 | self._initialize_weights() 590 | 591 | def forward(self, x, returnBeforeSoftMaxA=False, scores_replaceAS=None): 592 | if self.features is not None: 593 | x = x.squeeze(0) 594 | x = self.features(x) 595 | x = x.view(x.size(0), self.input_feat_dim) 596 | x = self.classifier(x) 597 | 598 | # Attention module 599 | A_ = self.attention(x) # NxK 600 | A_ = torch.transpose(A_, 1, 0) # KxN 601 | A = F.softmax(A_, dim=1) # softmax over N 602 | 603 | if scores_replaceAS is not None: 604 | A_ = scores_replaceAS 605 | A = F.softmax(A_, dim=1) # softmax over N 606 | 607 | if self.withoutAtten: 608 | x = torch.mean(x, dim=0, keepdim=True) 609 | else: 610 | x = torch.mm(A, x) # KxL 611 | 612 | if self.return_features: # switch only used for CIFAR-experiments 613 | return x 614 | 615 | x = self.top_layer(x) 616 | if returnBeforeSoftMaxA: 617 | return x, 0, A, A_ 618 | return x, 0, A 619 | 620 | def _initialize_weights(self): 621 | for y, m in enumerate(self.modules()): 622 | if isinstance(m, nn.Conv2d): 623 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 624 | for i in range(m.out_channels): 625 | m.weight.data[i].normal_(0, math.sqrt(2. 
/ n)) 626 | if m.bias is not None: 627 | m.bias.data.zero_() 628 | elif isinstance(m, nn.BatchNorm2d): 629 | m.weight.data.fill_(1) 630 | m.bias.data.zero_() 631 | elif isinstance(m, nn.Linear): 632 | m.weight.data.normal_(0, 0.01) 633 | m.bias.data.zero_() 634 | 635 | 636 | class AlexNet_CIFAR10_dsmil(nn.Module): 637 | def __init__(self, features, num_classes, init=True, withoutAtten=False, input_feat_dim=192 * 3 * 3): 638 | super(AlexNet_CIFAR10_dsmil, self).__init__() 639 | self.withoutAtten=withoutAtten 640 | self.features = features 641 | self.classifier = nn.Sequential(nn.Dropout(0.5), 642 | nn.Linear(input_feat_dim, 1024), 643 | nn.ReLU(inplace=True), 644 | nn.Dropout(0.5), 645 | nn.Linear(1024, 1024), 646 | nn.ReLU(inplace=True)) 647 | # self.L = 1024 648 | # self.D = 512 649 | # self.K = 1 650 | # self.attention = nn.Sequential( 651 | # nn.Linear(self.L, self.D), 652 | # nn.Tanh(), 653 | # nn.Linear(self.D, self.K) 654 | # ) 655 | 656 | self.fc_dsmil = nn.Sequential(nn.Linear(1024, 2)) 657 | self.q_dsmil = nn.Linear(1024, 1024) 658 | self.v_dsmil = nn.Sequential( 659 | nn.Dropout(0.0), 660 | nn.Linear(1024, 1024) 661 | ) 662 | self.fcc_dsmil = nn.Conv1d(2, 2, kernel_size=1024) 663 | 664 | self.headcount = len(num_classes) 665 | self.return_features = False 666 | if len(num_classes) == 1: 667 | self.top_layer = nn.Linear(1024, num_classes[0]) 668 | else: 669 | for a,i in enumerate(num_classes): 670 | setattr(self, "top_layer%d" % a, nn.Linear(4096, i)) 671 | self.top_layer = None # this way headcount can act as switch. 
672 | if init: 673 | self._initialize_weights() 674 | 675 | def forward(self, x): 676 | if self.features is not None: 677 | x = x.squeeze(0) 678 | x = self.features(x) 679 | x = x.view(x.size(0), -1) 680 | x = self.classifier(x) 681 | 682 | # # Attention module 683 | # A_ = self.attention(x) # NxK 684 | # A_ = torch.transpose(A_, 1, 0) # KxN 685 | # A = F.softmax(A_, dim=1) # softmax over N 686 | # 687 | # if self.withoutAtten: 688 | # x = torch.mean(x, dim=0, keepdim=True) 689 | # else: 690 | # x = torch.mm(A, x) # KxL 691 | # 692 | # if self.return_features: # switch only used for CIFAR-experiments 693 | # return x 694 | # x = self.top_layer(x) 695 | # if returnBeforeSoftMaxA: 696 | # return x, 0, A, A_ 697 | # return x, 0, A 698 | 699 | feat = x 700 | device = feat.device 701 | instance_pred = self.fc_dsmil(feat) 702 | V = self.v_dsmil(feat) 703 | Q = self.q_dsmil(feat).view(feat.shape[0], -1) 704 | _, m_indices = torch.sort(instance_pred, 0, descending=True) # sort class scores along the instance dimension, m_indices in shape N x C 705 | m_feats = torch.index_select(feat, dim=0, index=m_indices[0, :]) # select critical instances, m_feats in shape C x K 706 | q_max = self.q_dsmil(m_feats) # compute queries of critical instances, q_max in shape C x Q 707 | A = torch.mm(Q, q_max.transpose(0, 1)) # compute inner product of Q to each entry of q_max, A in shape N x C, each column contains unnormalized attention scores 708 | A = F.softmax( A / torch.sqrt(torch.tensor(Q.shape[1], dtype=torch.float32, device=device)), 0) # normalize attention scores, A in shape N x C, 709 | B = torch.mm(A.transpose(0, 1), V) # compute bag representation, B in shape C x V 710 | B = B.view(1, B.shape[0], B.shape[1]) # 1 x C x V 711 | C = self.fcc_dsmil(B) # 1 x C x 1 712 | C = C.view(1, -1) 713 | return instance_pred, C, A, B 714 | 715 | def _initialize_weights(self): 716 | for y, m in enumerate(self.modules()): 717 | if isinstance(m, nn.Conv2d): 718 | n = m.kernel_size[0] * 
m.kernel_size[1] * m.out_channels 719 | for i in range(m.out_channels): 720 | m.weight.data[i].normal_(0, math.sqrt(2. / n)) 721 | if m.bias is not None: 722 | m.bias.data.zero_() 723 | elif isinstance(m, nn.BatchNorm2d): 724 | m.weight.data.fill_(1) 725 | m.bias.data.zero_() 726 | elif isinstance(m, nn.Linear): 727 | m.weight.data.normal_(0, 0.01) 728 | m.bias.data.zero_() 729 | 730 | 731 | def alexnet_MNIST_Attention(bn=True, num_classes=[2], init=True): 732 | dim = 1 733 | model = AlexNet_MNIST_attention(make_layers_features(CFG['mnist'], dim, bn=bn), num_classes, init) 734 | return model 735 | 736 | 737 | def alexnet_CIFAR10_Attention(bn=True, num_classes=[2], init=True): 738 | dim = 3 739 | model = AlexNet_CIFAR10_attention(make_layers_features(CFG['CIFAR10'], dim, bn=bn), num_classes, init) 740 | return model 741 | 742 | 743 | ######################################## 744 | ## models for Shared Stu and Tea network 745 | def alexnet_CIFAR10_Encoder(): 746 | dim = 3 747 | model = make_layers_features(CFG['CIFAR10'], dim, bn=True) 748 | return model 749 | 750 | 751 | def teacher_Attention_head(bn=True, num_classes=[2], init=True, input_feat_dim=192*3*3): 752 | model = AlexNet_CIFAR10_attention(features=None, num_classes=num_classes, init=init, input_feat_dim=input_feat_dim) 753 | return model 754 | 755 | 756 | def teacher_DSMIL_head(bn=True, num_classes=[2], init=True, input_feat_dim=192*3*3): 757 | model = AlexNet_CIFAR10_dsmil(features=None, num_classes=num_classes, init=init, input_feat_dim=input_feat_dim) 758 | return model 759 | 760 | 761 | def student_head(num_classes=[2], init=True, input_feat_dim=192*3*3): 762 | model = AlexNet_CIFAR10(None, num_classes, init, input_feat_dim=input_feat_dim) 763 | return model 764 | 765 | 766 | class feat_projecter(nn.Module): 767 | def __init__(self, input_feat_dim=512, output_feat_dim=512): 768 | super(feat_projecter, self).__init__() 769 | # self.projecter = nn.Sequential( 770 | # nn.Linear(input_feat_dim, 
input_feat_dim*2), 771 | # nn.BatchNorm1d(input_feat_dim*2), 772 | # nn.ReLU(inplace=True), 773 | # nn.Linear(input_feat_dim*2, input_feat_dim * 2), 774 | # nn.BatchNorm1d(input_feat_dim*2), 775 | # nn.ReLU(inplace=True), 776 | # nn.Linear(input_feat_dim * 2, output_feat_dim), 777 | # nn.BatchNorm1d(output_feat_dim), 778 | # ) 779 | self.projecter = nn.Sequential( 780 | nn.Linear(input_feat_dim, output_feat_dim), 781 | nn.BatchNorm1d(output_feat_dim) 782 | ) 783 | def forward(self, x): 784 | x = self.projecter(x) 785 | return x 786 | 787 | 788 | def camelyon_feat_projecter(input_dim, output_dim): 789 | model = feat_projecter(input_dim, output_dim) 790 | return model 791 | ######################################## 792 | 793 | if __name__ == '__main__': 794 | import torch 795 | # model = alexnet(num_classes=[500]*3) 796 | # print([ k.shape for k in model(torch.randn(64,3,224,224))]) 797 | model = AlexNet_MNIST_projection_prototype(output_dim=128, hidden_mlp=2048, nmb_prototypes=300) 798 | print("END") 799 | 800 | -------------------------------------------------------------------------------- /train_TCGAFeat_BagDistillationDSMIL_SharedEnc_Similarity_StuFilterSmoothed_DropPos.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import warnings 3 | import os 4 | import time 5 | import numpy as np 6 | 7 | import torch 8 | import torch.optim 9 | import torch.nn as nn 10 | import torch.utils.data 11 | from tensorboardX import SummaryWriter 12 | # import models 13 | # from models.alexnet import alexnet_CIFAR10, alexnet_CIFAR10_Attention 14 | from models.alexnet import camelyon_feat_projecter, teacher_DSMIL_head, student_head 15 | # from dataset_toy import Dataset_toy 16 | # from Datasets_loader.dataset_MNIST_challenge import MNIST_WholeSlide_challenge 17 | # from Datasets_loader.dataset_MIL_CIFAR import CIFAR_WholeSlide_challenge 18 | from Datasets_loader.dataset_TCGA_LungCancer import TCGA_LungCancer_Feat 19 | import 
datetime 20 | import utliz 21 | import util 22 | import random 23 | from tqdm import tqdm 24 | import copy 25 | 26 | 27 | class Optimizer: 28 | def __init__(self, model_encoder, model_teacherHead, model_studentHead, 29 | optimizer_encoder, optimizer_teacherHead, optimizer_studentHead, 30 | train_bagloader, train_instanceloader, test_bagloader, test_instanceloader, 31 | writer=None, num_epoch=100, 32 | dev=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 33 | PLPostProcessMethod='NegGuide', StuFilterType='ReplaceAS', smoothE=100, 34 | stu_loss_weight_neg=0.1, stuOptPeriod=1): 35 | self.model_encoder = model_encoder 36 | self.model_teacherHead = model_teacherHead 37 | self.model_studentHead = model_studentHead 38 | self.optimizer_encoder = optimizer_encoder 39 | self.optimizer_teacherHead = optimizer_teacherHead 40 | self.optimizer_studentHead = optimizer_studentHead 41 | self.train_bagloader = train_bagloader 42 | self.train_instanceloader = train_instanceloader 43 | self.test_bagloader = test_bagloader 44 | self.test_instanceloader = test_instanceloader 45 | self.writer = writer 46 | self.num_epoch = num_epoch 47 | self.dev = dev 48 | self.log_period = 10 49 | self.PLPostProcessMethod = PLPostProcessMethod 50 | self.StuFilterType = StuFilterType 51 | self.smoothE = smoothE 52 | self.stu_loss_weight_neg = stu_loss_weight_neg 53 | self.stuOptPeriod = stuOptPeriod 54 | 55 | def optimize(self): 56 | self.Bank_all_Bags_label = None 57 | self.Bank_all_instances_pred_byTeacher = None 58 | self.Bank_all_instances_feat_byTeacher = None 59 | self.Bank_all_instances_pred_processed = None 60 | 61 | self.Bank_all_instances_pred_byStudent = None 62 | 63 | # Load pre-extracted SimCLR features 64 | # pre_trained_SimCLR_feat = self.train_instanceloader.dataset.ds_data_simCLR_feat[self.train_instanceloader.dataset.idx_all_slides].to(self.dev) 65 | for epoch in range(self.num_epoch): 66 | self.optimize_teacher(epoch) 67 | self.evaluate_teacher(epoch) 68 | if epoch % 
self.stuOptPeriod == 0: 69 | self.optimize_student(epoch) 70 | self.evaluate_student(epoch) 71 | 72 | return 0 73 | 74 | def optimize_teacher(self, epoch): 75 | self.model_encoder.train() 76 | self.model_teacherHead.train() 77 | self.model_studentHead.eval() 78 | criterion = torch.nn.CrossEntropyLoss() 79 | ## optimize teacher with bag-dataloader 80 | # 1. change loader to bag-loader 81 | loader = self.train_bagloader 82 | # 2. optimize 83 | patch_label_gt = [] 84 | patch_label_pred = [] 85 | bag_label_gt = [] 86 | bag_label_pred = [] 87 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Teacher training')): 88 | for i, j in enumerate(label): 89 | if torch.is_tensor(j): 90 | label[i] = j.to(self.dev) 91 | selected = selected.squeeze(0) 92 | niter = epoch * len(loader) + iter 93 | 94 | data = data.to(self.dev) 95 | feat = self.model_encoder(data.squeeze(0)) 96 | if epoch > self.smoothE: 97 | if "FilterNegInstance" in self.StuFilterType: 98 | # using student prediction to remove negative instance feat in the positive bag 99 | if label[1] == 1: 100 | with torch.no_grad(): 101 | pred_byStudent = self.model_studentHead(feat) 102 | pred_byStudent = torch.softmax(pred_byStudent, dim=1)[:, 1] 103 | if '_Top' in self.StuFilterType: 104 | # strategy A: remove the topK most negative instance 105 | idx_to_keep = torch.topk(-pred_byStudent, k=int(self.StuFilterType.split('_Top')[-1]))[1] 106 | elif '_ThreProb' in self.StuFilterType: 107 | # strategy B: remove the negative instance above prob K 108 | idx_to_keep = torch.where(pred_byStudent <= int(self.StuFilterType.split('_ThreProb')[-1])/100.0)[0] 109 | if idx_to_keep.shape[0] == 0: # if all instance are dropped, keep the most positive one 110 | idx_to_keep = torch.topk(pred_byStudent, k=1)[1] 111 | feat_removedNeg = feat[idx_to_keep] 112 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat_removedNeg) 113 | instance_attn_score = torch.cat([instance_attn_score, instance_attn_score[:, 
1].min()*torch.ones(feat.shape[0]-instance_attn_score.shape[0], 2).to(instance_attn_score.device)], dim=0) 114 | else: 115 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat) 116 | else: 117 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat) 118 | else: 119 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat) 120 | 121 | max_id = torch.argmax(instance_attn_score[:, 1]) 122 | bag_pred_byMax = instance_attn_score[max_id, :].squeeze(0) 123 | bag_loss = criterion(bag_prediction, label[1]) 124 | bag_loss_byMax = criterion(bag_pred_byMax.unsqueeze(0), label[1]) 125 | loss_teacher = 0.5 * bag_loss + 0.5 * bag_loss_byMax 126 | 127 | self.optimizer_encoder.zero_grad() 128 | self.optimizer_teacherHead.zero_grad() 129 | loss_teacher.backward() 130 | self.optimizer_encoder.step() 131 | self.optimizer_teacherHead.step() 132 | 133 | bag_prediction = 1.0 * torch.softmax(bag_prediction, dim=1) + \ 134 | 0.0 * torch.softmax(bag_pred_byMax.unsqueeze(0), dim=1) 135 | # instance_attn_score = torch.softmax(instance_attn_score, dim=1) 136 | 137 | patch_label_pred.append(instance_attn_score[:, 1].detach().squeeze(0)) 138 | patch_label_gt.append(label[0].squeeze(0)) 139 | bag_label_pred.append(bag_prediction.detach()[0, 1]) 140 | bag_label_gt.append(label[1]) 141 | if niter % self.log_period == 0: 142 | self.writer.add_scalar('train_loss_Teacher', loss_teacher.item(), niter) 143 | 144 | patch_label_pred = torch.cat(patch_label_pred) 145 | patch_label_gt = torch.cat(patch_label_gt) 146 | bag_label_pred = torch.tensor(bag_label_pred) 147 | bag_label_gt = torch.cat(bag_label_gt) 148 | 149 | self.estimated_AttnScore_norm_para_min = patch_label_pred.min() 150 | self.estimated_AttnScore_norm_para_max = patch_label_pred.max() 151 | patch_label_pred_normed = self.norm_AttnScore2Prob(patch_label_pred) 152 | instance_auc_ByTeacher = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred_normed.reshape(-1)) 153 | 154 | 
bag_auc_ByTeacher = utliz.cal_auc(bag_label_gt.reshape(-1), bag_label_pred.reshape(-1)) 155 | self.writer.add_scalar('train_instance_AUC_byTeacher', instance_auc_ByTeacher, epoch) 156 | self.writer.add_scalar('train_bag_AUC_byTeacher', bag_auc_ByTeacher, epoch) 157 | # print("Epoch:{} train_bag_AUC_byTeacher:{}".format(epoch, bag_auc_ByTeacher)) 158 | return 0 159 | 160 | def norm_AttnScore2Prob(self, attn_score): 161 | prob = (attn_score - self.estimated_AttnScore_norm_para_min) / (self.estimated_AttnScore_norm_para_max - self.estimated_AttnScore_norm_para_min) 162 | return prob 163 | 164 | def post_process_pred_byTeacher(self, Bank_all_instances_feat, Bank_all_instances_pred, Bank_all_bags_label, method='NegGuide'): 165 | if method=='NegGuide': 166 | Bank_all_instances_pred_processed = Bank_all_instances_pred.clone() 167 | Bank_all_instances_pred_processed = self.norm_AttnScore2Prob(Bank_all_instances_pred_processed).clamp(min=1e-5, max=1 - 1e-5) 168 | idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0] 169 | Bank_all_instances_pred_processed[idx_neg_bag, :] = 0 170 | elif method=='NegGuide_TopK': 171 | Bank_all_instances_pred_processed = Bank_all_instances_pred.clone() 172 | Bank_all_instances_pred_processed = self.norm_AttnScore2Prob(Bank_all_instances_pred_processed).clamp(min=1e-5, max=1 - 1e-5) 173 | idx_pos_bag = torch.where(Bank_all_bags_label[:, 0] == 1)[0] 174 | idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0] 175 | K = 3 176 | idx_topK_inside_pos_bag = torch.topk(Bank_all_instances_pred_processed[idx_pos_bag, :], k=K, dim=-1, largest=True)[1] 177 | Bank_all_instances_pred_processed[idx_pos_bag].scatter_(index=idx_topK_inside_pos_bag, dim=1, value=1) 178 | Bank_all_instances_pred_processed[idx_neg_bag, :] = 0 179 | elif method=='NegGuide_Similarity': 180 | Bank_all_instances_pred_processed = Bank_all_instances_pred.clone() 181 | Bank_all_instances_pred_processed = 
self.norm_AttnScore2Prob(Bank_all_instances_pred_processed).clamp(min=1e-5, max=1 - 1e-5) 182 | idx_pos_bag = torch.where(Bank_all_bags_label[:, 0] == 1)[0] 183 | idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0] 184 | K = 1 185 | idx_topK_inside_pos_bag = torch.topk(Bank_all_instances_pred_processed[idx_pos_bag, :], k=K, dim=-1, largest=True)[1] 186 | Bank_all_instances_pred_processed[idx_pos_bag].scatter_(index=idx_topK_inside_pos_bag, dim=1, value=1) 187 | Bank_all_Pos_instances_feat = Bank_all_instances_feat[idx_pos_bag] 188 | Bank_mostSalient_Pos_instances_feat = [] 189 | for i in range(Bank_all_Pos_instances_feat.shape[0]): 190 | Bank_mostSalient_Pos_instances_feat.append(Bank_all_Pos_instances_feat[i, idx_topK_inside_pos_bag[i, 0], :].unsqueeze(0).unsqueeze(0)) 191 | Bank_mostSalient_Pos_instances_feat = torch.cat(Bank_mostSalient_Pos_instances_feat, dim=0) 192 | 193 | distance_matrix = Bank_all_Pos_instances_feat - Bank_mostSalient_Pos_instances_feat 194 | distance_matrix = torch.norm(distance_matrix, dim=-1, p=2) 195 | Bank_all_instances_pred_processed[idx_pos_bag, :] = self.distanceMatrix2PL(distance_matrix) 196 | Bank_all_instances_pred_processed[idx_neg_bag, :] = 0 197 | else: 198 | raise TypeError 199 | return Bank_all_instances_pred_processed 200 | 201 | def distanceMatrix2PL(self, distance_matrix, method='percentage'): 202 | # distance_matrix is of shape NxL (Num of Positive Bag * Bag Length) 203 | # represents the distance between each instance with their corresponding most salient instance 204 | # return Pseudo-labels of shape NxL (value should belong to [0,1]) 205 | 206 | if method == 'softMax': 207 | # 1. just use softMax to keep PLs value fall into [0,1] 208 | similarity_matrix = 1/(distance_matrix + 1e-5) 209 | pseudo_labels = torch.softmax(similarity_matrix, dim=1) 210 | elif method == 'percentage': 211 | # 2. 
use percentage to keep n% PL=1, 1-n% PL=0 212 | p = 0.1 # 10% is set 213 | threshold_v = distance_matrix.topk(k=int(100 * p), dim=1)[0][:, -1].unsqueeze(1).repeat([1, 100]) # of size Nx100 214 | pseudo_labels = torch.zeros_like(distance_matrix) 215 | pseudo_labels[distance_matrix >= threshold_v] = 0.0 216 | pseudo_labels[distance_matrix < threshold_v] = 1.0 217 | elif method == 'threshold': 218 | # 3. use threshold to set PLs of instance with distance above the threshold to 1 219 | raise TypeError 220 | else: 221 | raise TypeError 222 | 223 | ## visulaize the pseudo_labels distribution of inside each bag 224 | # import matplotlib.pyplot as plt 225 | # plt.figure() 226 | # plt.hist(pseudo_labels.cpu().numpy().reshape(-1)) 227 | 228 | return pseudo_labels 229 | 230 | def optimize_student(self, epoch): 231 | self.model_teacherHead.train() 232 | self.model_encoder.train() 233 | self.model_studentHead.train() 234 | ## optimize teacher with instance-dataloader 235 | # 1. change loader to instance-loader 236 | loader = self.train_instanceloader 237 | # 2. 
optimize 238 | patch_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) # only for patch-label available dataset 239 | patch_label_pred = torch.zeros([loader.dataset.__len__(), 1]).float().to(self.dev) 240 | bag_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) 241 | patch_corresponding_slide_idx = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) 242 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Student training')): 243 | for i, j in enumerate(label): 244 | if torch.is_tensor(j): 245 | label[i] = j.to(self.dev) 246 | selected = selected.squeeze(0) 247 | niter = epoch * len(loader) + iter 248 | 249 | data = data.to(self.dev) 250 | 251 | # get teacher output of instance 252 | feat = self.model_encoder(data) 253 | with torch.no_grad(): 254 | instance_attn_score, _, _, _ = self.model_teacherHead(feat) 255 | pseudo_instance_label = self.norm_AttnScore2Prob(instance_attn_score[:, 1]).clamp(min=1e-5, max=1-1e-5).squeeze(0) 256 | # set true negative patch label to [1, 0] 257 | pseudo_instance_label[label[1] == 0] = 0 258 | # # DEBUG: Assign GT patch label 259 | # pseudo_instance_label = label[0] 260 | # get student output of instance 261 | patch_prediction = self.model_studentHead(feat) 262 | patch_prediction = torch.softmax(patch_prediction, dim=1) 263 | 264 | # cal loss 265 | loss_student = -1. 
* torch.mean(self.stu_loss_weight_neg * (1-pseudo_instance_label) * torch.log(patch_prediction[:, 0] + 1e-5) + 266 | (1-self.stu_loss_weight_neg) * pseudo_instance_label * torch.log(patch_prediction[:, 1] + 1e-5)) 267 | self.optimizer_encoder.zero_grad() 268 | self.optimizer_studentHead.zero_grad() 269 | loss_student.backward() 270 | self.optimizer_encoder.step() 271 | self.optimizer_studentHead.step() 272 | 273 | patch_corresponding_slide_idx[selected, 0] = label[2] 274 | patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1] 275 | patch_label_gt[selected, 0] = label[0] 276 | bag_label_gt[selected, 0] = label[1] 277 | if niter % self.log_period == 0: 278 | self.writer.add_scalar('train_loss_Student', loss_student.item(), niter) 279 | 280 | instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1)) 281 | self.writer.add_scalar('train_instance_AUC_byStudent', instance_auc_ByStudent, epoch) 282 | # print("Epoch:{} train_instance_AUC_byStudent:{}".format(epoch, instance_auc_ByStudent)) 283 | 284 | # cal bag-level auc 285 | bag_label_gt_coarse = [] 286 | bag_label_prediction = [] 287 | available_bag_idx = patch_corresponding_slide_idx.unique() 288 | for bag_idx_i in available_bag_idx: 289 | idx_same_bag_i = torch.where(patch_corresponding_slide_idx == bag_idx_i) 290 | if bag_label_gt[idx_same_bag_i].max() != bag_label_gt[idx_same_bag_i].max(): 291 | raise 292 | bag_label_gt_coarse.append(bag_label_gt[idx_same_bag_i].max()) 293 | bag_label_prediction.append(patch_label_pred[idx_same_bag_i].max()) 294 | bag_label_gt_coarse = torch.tensor(bag_label_gt_coarse) 295 | bag_label_prediction = torch.tensor(bag_label_prediction) 296 | bag_auc_ByStudent = utliz.cal_auc(bag_label_gt_coarse.reshape(-1), bag_label_prediction.reshape(-1)) 297 | self.writer.add_scalar('train_bag_AUC_byStudent', bag_auc_ByStudent, epoch) 298 | return 0 299 | 300 | def optimize_student_fromBank(self, epoch, Bank_all_instances_pred): 301 | 
self.model_teacherHead.train() 302 | self.model_encoder.train() 303 | self.model_studentHead.train() 304 | ## optimize teacher with instance-dataloader 305 | # 1. change loader to instance-loader 306 | loader = self.train_instanceloader 307 | # 2. optimize 308 | patch_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) # only for patch-label available dataset 309 | patch_label_pred = torch.zeros([loader.dataset.__len__(), 1]).float().to(self.dev) 310 | bag_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) 311 | patch_corresponding_slide_idx = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) 312 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Student training')): 313 | for i, j in enumerate(label): 314 | if torch.is_tensor(j): 315 | label[i] = j.to(self.dev) 316 | selected = selected.squeeze(0) 317 | niter = epoch * len(loader) + iter 318 | 319 | data = data.to(self.dev) 320 | 321 | # get teacher output of instance 322 | feat = self.model_encoder(data) 323 | # with torch.no_grad(): 324 | # _, _, _, instance_attn_score = self.model_teacherHead(feat, returnBeforeSoftMaxA=True) 325 | # pseudo_instance_label = self.norm_AttnScore2Prob(instance_attn_score).clamp(min=1e-5, max=1-1e-5).squeeze(0) 326 | # # set true negative patch label to [1, 0] 327 | # pseudo_instance_label[label[1] == 0] = 0 328 | 329 | pseudo_instance_label = Bank_all_instances_pred[selected//100, selected%100] 330 | # # DEBUG: Assign GT patch label 331 | # pseudo_instance_label = label[0] 332 | # get student output of instance 333 | patch_prediction = self.model_studentHead(feat) 334 | patch_prediction = torch.softmax(patch_prediction, dim=1) 335 | 336 | # cal loss 337 | loss_student = -1. 
* torch.mean(0.1 * (1-pseudo_instance_label) * torch.log(patch_prediction[:, 0] + 1e-5) + 338 | 0.9 * pseudo_instance_label * torch.log(patch_prediction[:, 1] + 1e-5)) 339 | self.optimizer_encoder.zero_grad() 340 | self.optimizer_studentHead.zero_grad() 341 | loss_student.backward() 342 | self.optimizer_encoder.step() 343 | self.optimizer_studentHead.step() 344 | 345 | patch_corresponding_slide_idx[selected, 0] = label[2] 346 | patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1] 347 | patch_label_gt[selected, 0] = label[0] 348 | bag_label_gt[selected, 0] = label[1] 349 | if niter % self.log_period == 0: 350 | self.writer.add_scalar('train_loss_Student', loss_student.item(), niter) 351 | 352 | self.Bank_all_instances_pred_byStudent = patch_label_pred 353 | instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1)) 354 | bag_auc_ByStudent = 0 355 | self.writer.add_scalar('train_instance_AUC_byStudent', instance_auc_ByStudent, epoch) 356 | self.writer.add_scalar('train_bag_AUC_byStudent', bag_auc_ByStudent, epoch) 357 | # print("Epoch:{} train_instance_AUC_byStudent:{}".format(epoch, instance_auc_ByStudent)) 358 | return 0 359 | 360 | def evaluate(self, epoch, loader, log_name_prefix=''): 361 | return 0 362 | 363 | def evaluate_teacher(self, epoch): 364 | self.model_encoder.eval() 365 | self.model_teacherHead.eval() 366 | ## optimize teacher with bag-dataloader 367 | # 1. change loader to bag-loader 368 | loader = self.test_bagloader 369 | # 2. 
optimize 370 | patch_label_gt = [] 371 | patch_label_pred = [] 372 | bag_label_gt = [] 373 | bag_label_prediction_withAttnScore = [] 374 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Teacher evaluating')): 375 | for i, j in enumerate(label): 376 | if torch.is_tensor(j): 377 | label[i] = j.to(self.dev) 378 | selected = selected.squeeze(0) 379 | niter = epoch * len(loader) + iter 380 | 381 | data = data.to(self.dev) 382 | with torch.no_grad(): 383 | feat = self.model_encoder(data.squeeze(0)) 384 | 385 | instance_attn_score, bag_prediction_withAttnScore, _, _ = self.model_teacherHead(feat) 386 | bag_prediction_withAttnScore = torch.softmax(bag_prediction_withAttnScore, 1) 387 | # instance_attn_score = torch.softmax(instance_attn_score, dim=1) 388 | 389 | patch_label_pred.append(instance_attn_score[:, 1].detach().squeeze(0)) 390 | patch_label_gt.append(label[0].squeeze(0)) 391 | bag_label_prediction_withAttnScore.append(bag_prediction_withAttnScore.detach()[0, 1]) 392 | bag_label_gt.append(label[1]) 393 | 394 | patch_label_pred = torch.cat(patch_label_pred) 395 | patch_label_gt = torch.cat(patch_label_gt) 396 | bag_label_prediction_withAttnScore = torch.tensor(bag_label_prediction_withAttnScore) 397 | bag_label_gt = torch.cat(bag_label_gt) 398 | 399 | patch_label_pred_normed = (patch_label_pred - patch_label_pred.min()) / (patch_label_pred.max() - patch_label_pred.min()) 400 | instance_auc_ByTeacher = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred_normed.reshape(-1)) 401 | bag_auc_ByTeacher_withAttnScore = utliz.cal_auc(bag_label_gt.reshape(-1), bag_label_prediction_withAttnScore.reshape(-1)) 402 | self.writer.add_scalar('test_instance_AUC_byTeacher', instance_auc_ByTeacher, epoch) 403 | self.writer.add_scalar('test_bag_AUC_byTeacher', bag_auc_ByTeacher_withAttnScore, epoch) 404 | return 0 405 | 406 | def evaluate_student(self, epoch): 407 | self.model_encoder.eval() 408 | self.model_studentHead.eval() 409 | ## optimize teacher with 
instance-dataloader 410 | # 1. change loader to instance-loader 411 | loader = self.test_instanceloader 412 | # 2. optimize 413 | patch_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) # only for patch-label available dataset 414 | patch_label_pred = torch.zeros([loader.dataset.__len__(), 1]).float().to(self.dev) 415 | bag_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) 416 | patch_corresponding_slide_idx = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) 417 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Student evaluating')): 418 | for i, j in enumerate(label): 419 | if torch.is_tensor(j): 420 | label[i] = j.to(self.dev) 421 | selected = selected.squeeze(0) 422 | niter = epoch * len(loader) + iter 423 | 424 | data = data.to(self.dev) 425 | 426 | # get student output of instance 427 | with torch.no_grad(): 428 | feat = self.model_encoder(data) 429 | patch_prediction = self.model_studentHead(feat) 430 | patch_prediction = torch.softmax(patch_prediction, dim=1) 431 | 432 | patch_corresponding_slide_idx[selected, 0] = label[2] 433 | patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1] 434 | patch_label_gt[selected, 0] = label[0] 435 | bag_label_gt[selected, 0] = label[1] 436 | 437 | instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1)) 438 | self.writer.add_scalar('test_instance_AUC_byStudent', instance_auc_ByStudent, epoch) 439 | # print("Epoch:{} test_instance_AUC_byStudent:{}".format(epoch, instance_auc_ByStudent)) 440 | 441 | # cal bag-level auc 442 | bag_label_gt_coarse = [] 443 | bag_label_prediction = [] 444 | available_bag_idx = patch_corresponding_slide_idx.unique() 445 | for bag_idx_i in available_bag_idx: 446 | idx_same_bag_i = torch.where(patch_corresponding_slide_idx == bag_idx_i) 447 | if bag_label_gt[idx_same_bag_i].max() != bag_label_gt[idx_same_bag_i].max(): 448 | raise 449 | 
bag_label_gt_coarse.append(bag_label_gt[idx_same_bag_i].max()) 450 | bag_label_prediction.append(patch_label_pred[idx_same_bag_i].max()) 451 | bag_label_gt_coarse = torch.tensor(bag_label_gt_coarse) 452 | bag_label_prediction = torch.tensor(bag_label_prediction) 453 | bag_auc_ByStudent = utliz.cal_auc(bag_label_gt_coarse.reshape(-1), bag_label_prediction.reshape(-1)) 454 | self.writer.add_scalar('test_bag_AUC_byStudent', bag_auc_ByStudent, epoch) 455 | return 0 456 | 457 | 458 | def str2bool(v): 459 | """ 460 | Input: 461 | v - string 462 | output: 463 | True/False 464 | """ 465 | if isinstance(v, bool): 466 | return v 467 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 468 | return True 469 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 470 | return False 471 | else: 472 | raise argparse.ArgumentTypeError('Boolean value expected.') 473 | 474 | 475 | def get_parser(): 476 | parser = argparse.ArgumentParser(description='PyTorch Implementation of Self-Label') 477 | # optimizer 478 | parser.add_argument('--epochs', default=1500, type=int, help='number of epochs') 479 | parser.add_argument('--batch_size', default=512, type=int, help='batch size (default: 256)') 480 | parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate (default: 0.05)') 481 | parser.add_argument('--lrdrop', default=1500, type=int, help='multiply LR by 0.5 every (default: 150 epochs)') 482 | parser.add_argument('--wd', default=-5, type=float, help='weight decay pow (default: (-5)') 483 | parser.add_argument('--dtype', default='f64', choices=['f64', 'f32'], type=str, help='SK-algo dtype (default: f64)') 484 | 485 | # SK algo 486 | parser.add_argument('--nopts', default=100, type=int, help='number of pseudo-opts (default: 100)') 487 | parser.add_argument('--augs', default=3, type=int, help='augmentation level (default: 3)') 488 | parser.add_argument('--lamb', default=25, type=int, help='for pseudoopt: lambda (default:25) ') 489 | 490 | # architecture 491 | # 
parser.add_argument('--arch', default='alexnet_MNIST', type=str, help='alexnet or resnet (default: alexnet)') 492 | 493 | # housekeeping 494 | parser.add_argument('--device', default='0', type=str, help='GPU devices to use for storage and model') 495 | parser.add_argument('--modeldevice', default='0', type=str, help='GPU numbers on which the CNN runs') 496 | parser.add_argument('--exp', default='self-label-default', help='path to experiment directory') 497 | parser.add_argument('--workers', default=0, type=int,help='number workers (default: 6)') 498 | parser.add_argument('--comment', default='DEBUG_BagDistillation_DSMIL', type=str, help='name for tensorboardX') 499 | parser.add_argument('--log-intv', default=1, type=int, help='save stuff every x epochs (default: 1)') 500 | parser.add_argument('--log_iter', default=200, type=int, help='log every x-th batch (default: 200)') 501 | parser.add_argument('--seed', default=10, type=int, help='random seed') 502 | 503 | parser.add_argument('--PLPostProcessMethod', default='NegGuide', type=str, 504 | help='Post-processing method of Attention Scores to build Pseudo Lables', 505 | choices=['NegGuide', 'NegGuide_TopK', 'NegGuide_Similarity']) 506 | parser.add_argument('--StuFilterType', default='FilterNegInstance__ThreProb50', type=str, 507 | help='Type of using Student Prediction to imporve Teacher ' 508 | '[ReplaceAS, FilterNegInstance_Top95, FilterNegInstance__ThreProb90]') 509 | parser.add_argument('--smoothE', default=9999, type=int, help='num of epoch to apply StuFilter') 510 | parser.add_argument('--stu_loss_weight_neg', default=0.1, type=float, help='weight of neg instances in stu training') 511 | parser.add_argument('--stuOptPeriod', default=1, type=int, help='period of stu optimization') 512 | return parser.parse_args() 513 | 514 | 515 | if __name__ == "__main__": 516 | args = get_parser() 517 | 518 | # torch.manual_seed(args.seed) 519 | # random.seed(args.seed) 520 | # np.random.seed(args.seed) 521 | 522 | name = 
datetime.datetime.now().strftime("%Y%m%d_%H%M%S")+"_%s" % args.comment.replace('/', '_') + \ 523 | "_Seed{}_Bs{}_lr{}_PLPostProcessBy{}_StuFilterType{}_smoothE{}_weightN{}_StuOptP{}".format( 524 | args.seed, args.batch_size, args.lr, 525 | args.PLPostProcessMethod, args.StuFilterType, args.smoothE, args.stu_loss_weight_neg, args.stuOptPeriod) 526 | try: 527 | args.device = [int(item) for item in args.device.split(',')] 528 | except AttributeError: 529 | args.device = [int(args.device)] 530 | args.modeldevice = args.device 531 | util.setup_runtime(seed=42, cuda_dev_id=list(np.unique(args.modeldevice + args.device))) 532 | 533 | print(name, flush=True) 534 | 535 | writer = SummaryWriter('./runs_TCGA/%s'%name) 536 | writer.add_text('args', " \n".join(['%s %s' % (arg, getattr(args, arg)) for arg in vars(args)])) 537 | 538 | # Setup model 539 | model_encoder = camelyon_feat_projecter(input_dim=512, output_dim=512).to('cuda:0') 540 | model_teacherHead = teacher_DSMIL_head(input_feat_dim=512).to('cuda:0') 541 | model_studentHead = student_head(input_feat_dim=512).to('cuda:0') 542 | 543 | optimizer_encoder = torch.optim.SGD(model_encoder.parameters(), lr=args.lr) 544 | optimizer_teacherHead = torch.optim.SGD(model_teacherHead.parameters(), lr=args.lr) 545 | optimizer_studentHead = torch.optim.SGD(model_studentHead.parameters(), lr=args.lr) 546 | 547 | # Setup loaders 548 | train_ds_return_instance = TCGA_LungCancer_Feat(train=True, return_bag=False) 549 | 550 | train_ds_return_bag = copy.deepcopy(train_ds_return_instance) 551 | train_ds_return_bag.return_bag = True 552 | val_ds_return_instance = TCGA_LungCancer_Feat(train=False, return_bag=False) 553 | val_ds_return_bag = TCGA_LungCancer_Feat(train=False, return_bag=True) 554 | 555 | train_loader_instance = torch.utils.data.DataLoader(train_ds_return_instance, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=False) 556 | train_loader_bag = torch.utils.data.DataLoader(train_ds_return_bag, 
batch_size=1, shuffle=True, num_workers=args.workers, drop_last=False) 557 | val_loader_instance = torch.utils.data.DataLoader(val_ds_return_instance, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, drop_last=False) 558 | val_loader_bag = torch.utils.data.DataLoader(val_ds_return_bag, batch_size=1, shuffle=False, num_workers=args.workers, drop_last=False) 559 | 560 | print("[Data] {} training samples".format(len(train_loader_instance.dataset))) 561 | print("[Data] {} evaluating samples".format(len(val_loader_instance.dataset))) 562 | 563 | if torch.cuda.device_count() > 1: 564 | print("Let's use", len(args.modeldevice), "GPUs for the model") 565 | if len(args.modeldevice) == 1: 566 | print('single GPU model', flush=True) 567 | else: 568 | model_encoder = nn.DataParallel(model_encoder, device_ids=list(range(len(args.modeldevice)))) 569 | model_teacherHead = nn.DataParallel(model_teacherHead, device_ids=list(range(len(args.modeldevice)))) 570 | 571 | # Setup optimizer 572 | o = Optimizer(model_encoder=model_encoder, model_teacherHead=model_teacherHead, model_studentHead=model_studentHead, 573 | optimizer_encoder=optimizer_encoder, optimizer_teacherHead=optimizer_teacherHead, optimizer_studentHead=optimizer_studentHead, 574 | train_bagloader=train_loader_bag, train_instanceloader=train_loader_instance, 575 | test_bagloader=val_loader_bag, test_instanceloader=val_loader_instance, 576 | writer=writer, num_epoch=args.epochs, 577 | dev=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 578 | PLPostProcessMethod=args.PLPostProcessMethod, StuFilterType=args.StuFilterType, smoothE=args.smoothE, 579 | stu_loss_weight_neg=args.stu_loss_weight_neg, stuOptPeriod=args.stuOptPeriod) 580 | # Optimize 581 | o.optimize() -------------------------------------------------------------------------------- /train_CervicalFeat_BagDistillationDSMIL_SharedEnc_Similarity_StuFilterSmoothed_DropPos.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import warnings 3 | import os 4 | import time 5 | import numpy as np 6 | 7 | import torch 8 | import torch.optim 9 | import torch.nn as nn 10 | import torch.utils.data 11 | from tensorboardX import SummaryWriter 12 | # import models 13 | # from models.alexnet import alexnet_CIFAR10, alexnet_CIFAR10_Attention 14 | from models.alexnet import camelyon_feat_projecter, teacher_DSMIL_head, student_head 15 | # from dataset_toy import Dataset_toy 16 | # from Datasets_loader.dataset_MNIST_challenge import MNIST_WholeSlide_challenge 17 | # from Datasets_loader.dataset_MIL_CIFAR import CIFAR_WholeSlide_challenge 18 | from Datasets_loader.dataset_CervicalCancer import CervicalCaner_16_feat 19 | import datetime 20 | import utliz 21 | import util 22 | import random 23 | from tqdm import tqdm 24 | import copy 25 | 26 | 27 | class Optimizer: 28 | def __init__(self, model_encoder, model_teacherHead, model_studentHead, 29 | optimizer_encoder, optimizer_teacherHead, optimizer_studentHead, 30 | train_bagloader, train_instanceloader, test_bagloader, test_instanceloader, 31 | writer=None, num_epoch=100, 32 | dev=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 33 | PLPostProcessMethod='NegGuide', StuFilterType='ReplaceAS', smoothE=100, 34 | stu_loss_weight_neg=0.1, stuOptPeriod=1): 35 | self.model_encoder = model_encoder 36 | self.model_teacherHead = model_teacherHead 37 | self.model_studentHead = model_studentHead 38 | self.optimizer_encoder = optimizer_encoder 39 | self.optimizer_teacherHead = optimizer_teacherHead 40 | self.optimizer_studentHead = optimizer_studentHead 41 | self.train_bagloader = train_bagloader 42 | self.train_instanceloader = train_instanceloader 43 | self.test_bagloader = test_bagloader 44 | self.test_instanceloader = test_instanceloader 45 | self.writer = writer 46 | self.num_epoch = num_epoch 47 | self.dev = dev 48 | self.log_period = 10 49 | 
self.PLPostProcessMethod = PLPostProcessMethod 50 | self.StuFilterType = StuFilterType 51 | self.smoothE = smoothE 52 | self.stu_loss_weight_neg = stu_loss_weight_neg 53 | self.stuOptPeriod = stuOptPeriod 54 | 55 | def optimize(self): 56 | self.Bank_all_Bags_label = None 57 | self.Bank_all_instances_pred_byTeacher = None 58 | self.Bank_all_instances_feat_byTeacher = None 59 | self.Bank_all_instances_pred_processed = None 60 | 61 | self.Bank_all_instances_pred_byStudent = None 62 | 63 | # Load pre-extracted SimCLR features 64 | # pre_trained_SimCLR_feat = self.train_instanceloader.dataset.ds_data_simCLR_feat[self.train_instanceloader.dataset.idx_all_slides].to(self.dev) 65 | for epoch in range(self.num_epoch): 66 | self.optimize_teacher(epoch) 67 | self.evaluate_teacher(epoch) 68 | if epoch % self.stuOptPeriod == 0: 69 | self.optimize_student(epoch) 70 | self.evaluate_student(epoch) 71 | 72 | return 0 73 | 74 | def optimize_teacher(self, epoch): 75 | self.model_encoder.train() 76 | self.model_teacherHead.train() 77 | self.model_studentHead.eval() 78 | criterion = torch.nn.CrossEntropyLoss() 79 | ## optimize teacher with bag-dataloader 80 | # 1. change loader to bag-loader 81 | loader = self.train_bagloader 82 | # 2. 
optimize 83 | patch_label_gt = [] 84 | patch_label_pred = [] 85 | bag_label_gt = [] 86 | bag_label_pred = [] 87 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Teacher training')): 88 | for i, j in enumerate(label): 89 | if torch.is_tensor(j): 90 | label[i] = j.to(self.dev) 91 | selected = selected.squeeze(0) 92 | niter = epoch * len(loader) + iter 93 | 94 | data = data.to(self.dev) 95 | feat = self.model_encoder(data.squeeze(0)) 96 | if epoch > self.smoothE: 97 | if "FilterNegInstance" in self.StuFilterType: 98 | # using student prediction to remove negative instance feat in the positive bag 99 | if label[1] == 1: 100 | with torch.no_grad(): 101 | pred_byStudent = self.model_studentHead(feat) 102 | pred_byStudent = torch.softmax(pred_byStudent, dim=1)[:, 1] 103 | if '_Top' in self.StuFilterType: 104 | # strategy A: remove the topK most negative instance 105 | idx_to_keep = torch.topk(-pred_byStudent, k=int(self.StuFilterType.split('_Top')[-1]))[1] 106 | elif '_ThreProb' in self.StuFilterType: 107 | # strategy B: remove the negative instance above prob K 108 | idx_to_keep = torch.where(pred_byStudent <= int(self.StuFilterType.split('_ThreProb')[-1])/100.0)[0] 109 | if idx_to_keep.shape[0] == 0: # if all instance are dropped, keep the most positive one 110 | idx_to_keep = torch.topk(pred_byStudent, k=1)[1] 111 | feat_removedNeg = feat[idx_to_keep] 112 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat_removedNeg) 113 | instance_attn_score = torch.cat([instance_attn_score, instance_attn_score[:, 1].min()*torch.ones(feat.shape[0]-instance_attn_score.shape[0], 2).to(instance_attn_score.device)], dim=0) 114 | else: 115 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat) 116 | else: 117 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat) 118 | else: 119 | instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat) 120 | 121 | max_id = torch.argmax(instance_attn_score[:, 
1]) 122 | bag_pred_byMax = instance_attn_score[max_id, :].squeeze(0) 123 | bag_loss = criterion(bag_prediction, label[1]) 124 | bag_loss_byMax = criterion(bag_pred_byMax.unsqueeze(0), label[1]) 125 | loss_teacher = 0.5 * bag_loss + 0.5 * bag_loss_byMax 126 | 127 | self.optimizer_encoder.zero_grad() 128 | self.optimizer_teacherHead.zero_grad() 129 | loss_teacher.backward() 130 | self.optimizer_encoder.step() 131 | self.optimizer_teacherHead.step() 132 | 133 | bag_prediction = 1.0 * torch.softmax(bag_prediction, dim=1) + \ 134 | 0.0 * torch.softmax(bag_pred_byMax.unsqueeze(0), dim=1) 135 | # instance_attn_score = torch.softmax(instance_attn_score, dim=1) 136 | 137 | patch_label_pred.append(instance_attn_score[:, 1].detach().squeeze(0)) 138 | patch_label_gt.append(label[0].squeeze(0)) 139 | bag_label_pred.append(bag_prediction.detach()[0, 1]) 140 | bag_label_gt.append(label[1]) 141 | if niter % self.log_period == 0: 142 | self.writer.add_scalar('train_loss_Teacher', loss_teacher.item(), niter) 143 | 144 | patch_label_pred = torch.cat(patch_label_pred) 145 | patch_label_gt = torch.cat(patch_label_gt) 146 | bag_label_pred = torch.tensor(bag_label_pred) 147 | bag_label_gt = torch.cat(bag_label_gt) 148 | 149 | self.estimated_AttnScore_norm_para_min = patch_label_pred.min() 150 | self.estimated_AttnScore_norm_para_max = patch_label_pred.max() 151 | patch_label_pred_normed = self.norm_AttnScore2Prob(patch_label_pred) 152 | instance_auc_ByTeacher = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred_normed.reshape(-1)) 153 | 154 | bag_auc_ByTeacher = utliz.cal_auc(bag_label_gt.reshape(-1), bag_label_pred.reshape(-1)) 155 | self.writer.add_scalar('train_instance_AUC_byTeacher', instance_auc_ByTeacher, epoch) 156 | self.writer.add_scalar('train_bag_AUC_byTeacher', bag_auc_ByTeacher, epoch) 157 | # print("Epoch:{} train_bag_AUC_byTeacher:{}".format(epoch, bag_auc_ByTeacher)) 158 | return 0 159 | 160 | def norm_AttnScore2Prob(self, attn_score): 161 | prob = (attn_score - 
self.estimated_AttnScore_norm_para_min) / (self.estimated_AttnScore_norm_para_max - self.estimated_AttnScore_norm_para_min) 162 | return prob 163 | 164 | def post_process_pred_byTeacher(self, Bank_all_instances_feat, Bank_all_instances_pred, Bank_all_bags_label, method='NegGuide'): 165 | if method=='NegGuide': 166 | Bank_all_instances_pred_processed = Bank_all_instances_pred.clone() 167 | Bank_all_instances_pred_processed = self.norm_AttnScore2Prob(Bank_all_instances_pred_processed).clamp(min=1e-5, max=1 - 1e-5) 168 | idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0] 169 | Bank_all_instances_pred_processed[idx_neg_bag, :] = 0 170 | elif method=='NegGuide_TopK': 171 | Bank_all_instances_pred_processed = Bank_all_instances_pred.clone() 172 | Bank_all_instances_pred_processed = self.norm_AttnScore2Prob(Bank_all_instances_pred_processed).clamp(min=1e-5, max=1 - 1e-5) 173 | idx_pos_bag = torch.where(Bank_all_bags_label[:, 0] == 1)[0] 174 | idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0] 175 | K = 3 176 | idx_topK_inside_pos_bag = torch.topk(Bank_all_instances_pred_processed[idx_pos_bag, :], k=K, dim=-1, largest=True)[1] 177 | Bank_all_instances_pred_processed[idx_pos_bag].scatter_(index=idx_topK_inside_pos_bag, dim=1, value=1) 178 | Bank_all_instances_pred_processed[idx_neg_bag, :] = 0 179 | elif method=='NegGuide_Similarity': 180 | Bank_all_instances_pred_processed = Bank_all_instances_pred.clone() 181 | Bank_all_instances_pred_processed = self.norm_AttnScore2Prob(Bank_all_instances_pred_processed).clamp(min=1e-5, max=1 - 1e-5) 182 | idx_pos_bag = torch.where(Bank_all_bags_label[:, 0] == 1)[0] 183 | idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0] 184 | K = 1 185 | idx_topK_inside_pos_bag = torch.topk(Bank_all_instances_pred_processed[idx_pos_bag, :], k=K, dim=-1, largest=True)[1] 186 | Bank_all_instances_pred_processed[idx_pos_bag].scatter_(index=idx_topK_inside_pos_bag, dim=1, value=1) 187 | Bank_all_Pos_instances_feat = 
Bank_all_instances_feat[idx_pos_bag] 188 | Bank_mostSalient_Pos_instances_feat = [] 189 | for i in range(Bank_all_Pos_instances_feat.shape[0]): 190 | Bank_mostSalient_Pos_instances_feat.append(Bank_all_Pos_instances_feat[i, idx_topK_inside_pos_bag[i, 0], :].unsqueeze(0).unsqueeze(0)) 191 | Bank_mostSalient_Pos_instances_feat = torch.cat(Bank_mostSalient_Pos_instances_feat, dim=0) 192 | 193 | distance_matrix = Bank_all_Pos_instances_feat - Bank_mostSalient_Pos_instances_feat 194 | distance_matrix = torch.norm(distance_matrix, dim=-1, p=2) 195 | Bank_all_instances_pred_processed[idx_pos_bag, :] = self.distanceMatrix2PL(distance_matrix) 196 | Bank_all_instances_pred_processed[idx_neg_bag, :] = 0 197 | else: 198 | raise TypeError 199 | return Bank_all_instances_pred_processed 200 | 201 | def distanceMatrix2PL(self, distance_matrix, method='percentage'): 202 | # distance_matrix is of shape NxL (Num of Positive Bag * Bag Length) 203 | # represents the distance between each instance with their corresponding most salient instance 204 | # return Pseudo-labels of shape NxL (value should belong to [0,1]) 205 | 206 | if method == 'softMax': 207 | # 1. just use softMax to keep PLs value fall into [0,1] 208 | similarity_matrix = 1/(distance_matrix + 1e-5) 209 | pseudo_labels = torch.softmax(similarity_matrix, dim=1) 210 | elif method == 'percentage': 211 | # 2. use percentage to keep n% PL=1, 1-n% PL=0 212 | p = 0.1 # 10% is set 213 | threshold_v = distance_matrix.topk(k=int(100 * p), dim=1)[0][:, -1].unsqueeze(1).repeat([1, 100]) # of size Nx100 214 | pseudo_labels = torch.zeros_like(distance_matrix) 215 | pseudo_labels[distance_matrix >= threshold_v] = 0.0 216 | pseudo_labels[distance_matrix < threshold_v] = 1.0 217 | elif method == 'threshold': 218 | # 3. 
use threshold to set PLs of instance with distance above the threshold to 1 219 | raise TypeError 220 | else: 221 | raise TypeError 222 | 223 | ## visulaize the pseudo_labels distribution of inside each bag 224 | # import matplotlib.pyplot as plt 225 | # plt.figure() 226 | # plt.hist(pseudo_labels.cpu().numpy().reshape(-1)) 227 | 228 | return pseudo_labels 229 | 230 | def optimize_student(self, epoch): 231 | self.model_teacherHead.train() 232 | self.model_encoder.train() 233 | self.model_studentHead.train() 234 | ## optimize teacher with instance-dataloader 235 | # 1. change loader to instance-loader 236 | loader = self.train_instanceloader 237 | # 2. optimize 238 | patch_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) # only for patch-label available dataset 239 | patch_label_pred = torch.zeros([loader.dataset.__len__(), 1]).float().to(self.dev) 240 | bag_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) 241 | patch_corresponding_slide_idx = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) 242 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Student training')): 243 | for i, j in enumerate(label): 244 | if torch.is_tensor(j): 245 | label[i] = j.to(self.dev) 246 | selected = selected.squeeze(0) 247 | niter = epoch * len(loader) + iter 248 | 249 | data = data.to(self.dev) 250 | 251 | # get teacher output of instance 252 | feat = self.model_encoder(data) 253 | with torch.no_grad(): 254 | instance_attn_score, _, _, _ = self.model_teacherHead(feat) 255 | pseudo_instance_label = self.norm_AttnScore2Prob(instance_attn_score[:, 1]).clamp(min=1e-5, max=1-1e-5).squeeze(0) 256 | # set true negative patch label to [1, 0] 257 | pseudo_instance_label[label[1] == 0] = 0 258 | # # DEBUG: Assign GT patch label 259 | # pseudo_instance_label = label[0] 260 | # get student output of instance 261 | patch_prediction = self.model_studentHead(feat) 262 | patch_prediction = torch.softmax(patch_prediction, 
dim=1) 263 | 264 | # cal loss 265 | loss_student = -1. * torch.mean(self.stu_loss_weight_neg * (1-pseudo_instance_label) * torch.log(patch_prediction[:, 0] + 1e-5) + 266 | (1-self.stu_loss_weight_neg) * pseudo_instance_label * torch.log(patch_prediction[:, 1] + 1e-5)) 267 | self.optimizer_encoder.zero_grad() 268 | self.optimizer_studentHead.zero_grad() 269 | loss_student.backward() 270 | self.optimizer_encoder.step() 271 | self.optimizer_studentHead.step() 272 | 273 | patch_corresponding_slide_idx[selected, 0] = label[2] 274 | patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1] 275 | patch_label_gt[selected, 0] = label[0] 276 | bag_label_gt[selected, 0] = label[1] 277 | if niter % self.log_period == 0: 278 | self.writer.add_scalar('train_loss_Student', loss_student.item(), niter) 279 | 280 | instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1)) 281 | self.writer.add_scalar('train_instance_AUC_byStudent', instance_auc_ByStudent, epoch) 282 | # print("Epoch:{} train_instance_AUC_byStudent:{}".format(epoch, instance_auc_ByStudent)) 283 | 284 | # cal bag-level auc 285 | bag_label_gt_coarse = [] 286 | bag_label_prediction = [] 287 | available_bag_idx = patch_corresponding_slide_idx.unique() 288 | for bag_idx_i in available_bag_idx: 289 | idx_same_bag_i = torch.where(patch_corresponding_slide_idx == bag_idx_i) 290 | if bag_label_gt[idx_same_bag_i].max() != bag_label_gt[idx_same_bag_i].max(): 291 | raise 292 | bag_label_gt_coarse.append(bag_label_gt[idx_same_bag_i].max()) 293 | bag_label_prediction.append(patch_label_pred[idx_same_bag_i].max()) 294 | bag_label_gt_coarse = torch.tensor(bag_label_gt_coarse) 295 | bag_label_prediction = torch.tensor(bag_label_prediction) 296 | bag_auc_ByStudent = utliz.cal_auc(bag_label_gt_coarse.reshape(-1), bag_label_prediction.reshape(-1)) 297 | self.writer.add_scalar('train_bag_AUC_byStudent', bag_auc_ByStudent, epoch) 298 | return 0 299 | 300 | def optimize_student_fromBank(self, 
epoch, Bank_all_instances_pred): 301 | self.model_teacherHead.train() 302 | self.model_encoder.train() 303 | self.model_studentHead.train() 304 | ## optimize teacher with instance-dataloader 305 | # 1. change loader to instance-loader 306 | loader = self.train_instanceloader 307 | # 2. optimize 308 | patch_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) # only for patch-label available dataset 309 | patch_label_pred = torch.zeros([loader.dataset.__len__(), 1]).float().to(self.dev) 310 | bag_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) 311 | patch_corresponding_slide_idx = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) 312 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Student training')): 313 | for i, j in enumerate(label): 314 | if torch.is_tensor(j): 315 | label[i] = j.to(self.dev) 316 | selected = selected.squeeze(0) 317 | niter = epoch * len(loader) + iter 318 | 319 | data = data.to(self.dev) 320 | 321 | # get teacher output of instance 322 | feat = self.model_encoder(data) 323 | # with torch.no_grad(): 324 | # _, _, _, instance_attn_score = self.model_teacherHead(feat, returnBeforeSoftMaxA=True) 325 | # pseudo_instance_label = self.norm_AttnScore2Prob(instance_attn_score).clamp(min=1e-5, max=1-1e-5).squeeze(0) 326 | # # set true negative patch label to [1, 0] 327 | # pseudo_instance_label[label[1] == 0] = 0 328 | 329 | pseudo_instance_label = Bank_all_instances_pred[selected//100, selected%100] 330 | # # DEBUG: Assign GT patch label 331 | # pseudo_instance_label = label[0] 332 | # get student output of instance 333 | patch_prediction = self.model_studentHead(feat) 334 | patch_prediction = torch.softmax(patch_prediction, dim=1) 335 | 336 | # cal loss 337 | loss_student = -1. 
* torch.mean(0.1 * (1-pseudo_instance_label) * torch.log(patch_prediction[:, 0] + 1e-5) + 338 | 0.9 * pseudo_instance_label * torch.log(patch_prediction[:, 1] + 1e-5)) 339 | self.optimizer_encoder.zero_grad() 340 | self.optimizer_studentHead.zero_grad() 341 | loss_student.backward() 342 | self.optimizer_encoder.step() 343 | self.optimizer_studentHead.step() 344 | 345 | patch_corresponding_slide_idx[selected, 0] = label[2] 346 | patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1] 347 | patch_label_gt[selected, 0] = label[0] 348 | bag_label_gt[selected, 0] = label[1] 349 | if niter % self.log_period == 0: 350 | self.writer.add_scalar('train_loss_Student', loss_student.item(), niter) 351 | 352 | self.Bank_all_instances_pred_byStudent = patch_label_pred 353 | instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1)) 354 | bag_auc_ByStudent = 0 355 | self.writer.add_scalar('train_instance_AUC_byStudent', instance_auc_ByStudent, epoch) 356 | self.writer.add_scalar('train_bag_AUC_byStudent', bag_auc_ByStudent, epoch) 357 | # print("Epoch:{} train_instance_AUC_byStudent:{}".format(epoch, instance_auc_ByStudent)) 358 | return 0 359 | 360 | def evaluate(self, epoch, loader, log_name_prefix=''): 361 | return 0 362 | 363 | def evaluate_teacher(self, epoch): 364 | self.model_encoder.eval() 365 | self.model_teacherHead.eval() 366 | ## optimize teacher with bag-dataloader 367 | # 1. change loader to bag-loader 368 | loader = self.test_bagloader 369 | # 2. 
optimize 370 | patch_label_gt = [] 371 | patch_label_pred = [] 372 | bag_label_gt = [] 373 | bag_label_prediction_withAttnScore = [] 374 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Teacher evaluating')): 375 | for i, j in enumerate(label): 376 | if torch.is_tensor(j): 377 | label[i] = j.to(self.dev) 378 | selected = selected.squeeze(0) 379 | niter = epoch * len(loader) + iter 380 | 381 | data = data.to(self.dev) 382 | with torch.no_grad(): 383 | feat = self.model_encoder(data.squeeze(0)) 384 | 385 | instance_attn_score, bag_prediction_withAttnScore, _, _ = self.model_teacherHead(feat) 386 | bag_prediction_withAttnScore = torch.softmax(bag_prediction_withAttnScore, 1) 387 | # instance_attn_score = torch.softmax(instance_attn_score, dim=1) 388 | 389 | patch_label_pred.append(instance_attn_score[:, 1].detach().squeeze(0)) 390 | patch_label_gt.append(label[0].squeeze(0)) 391 | bag_label_prediction_withAttnScore.append(bag_prediction_withAttnScore.detach()[0, 1]) 392 | bag_label_gt.append(label[1]) 393 | 394 | patch_label_pred = torch.cat(patch_label_pred) 395 | patch_label_gt = torch.cat(patch_label_gt) 396 | bag_label_prediction_withAttnScore = torch.tensor(bag_label_prediction_withAttnScore) 397 | bag_label_gt = torch.cat(bag_label_gt) 398 | 399 | patch_label_pred_normed = (patch_label_pred - patch_label_pred.min()) / (patch_label_pred.max() - patch_label_pred.min()) 400 | instance_auc_ByTeacher = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred_normed.reshape(-1)) 401 | bag_auc_ByTeacher_withAttnScore = utliz.cal_auc(bag_label_gt.reshape(-1), bag_label_prediction_withAttnScore.reshape(-1)) 402 | self.writer.add_scalar('test_instance_AUC_byTeacher', instance_auc_ByTeacher, epoch) 403 | self.writer.add_scalar('test_bag_AUC_byTeacher', bag_auc_ByTeacher_withAttnScore, epoch) 404 | return 0 405 | 406 | def evaluate_student(self, epoch): 407 | self.model_encoder.eval() 408 | self.model_studentHead.eval() 409 | ## optimize teacher with 
instance-dataloader 410 | # 1. change loader to instance-loader 411 | loader = self.test_instanceloader 412 | # 2. optimize 413 | patch_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) # only for patch-label available dataset 414 | patch_label_pred = torch.zeros([loader.dataset.__len__(), 1]).float().to(self.dev) 415 | bag_label_gt = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) 416 | patch_corresponding_slide_idx = torch.zeros([loader.dataset.__len__(), 1]).long().to(self.dev) 417 | for iter, (data, label, selected) in enumerate(tqdm(loader, desc='Student evaluating')): 418 | for i, j in enumerate(label): 419 | if torch.is_tensor(j): 420 | label[i] = j.to(self.dev) 421 | selected = selected.squeeze(0) 422 | niter = epoch * len(loader) + iter 423 | 424 | data = data.to(self.dev) 425 | 426 | # get student output of instance 427 | with torch.no_grad(): 428 | feat = self.model_encoder(data) 429 | patch_prediction = self.model_studentHead(feat) 430 | patch_prediction = torch.softmax(patch_prediction, dim=1) 431 | 432 | patch_corresponding_slide_idx[selected, 0] = label[2] 433 | patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1] 434 | patch_label_gt[selected, 0] = label[0] 435 | bag_label_gt[selected, 0] = label[1] 436 | 437 | instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1)) 438 | self.writer.add_scalar('test_instance_AUC_byStudent', instance_auc_ByStudent, epoch) 439 | # print("Epoch:{} test_instance_AUC_byStudent:{}".format(epoch, instance_auc_ByStudent)) 440 | 441 | # cal bag-level auc 442 | bag_label_gt_coarse = [] 443 | bag_label_prediction = [] 444 | available_bag_idx = patch_corresponding_slide_idx.unique() 445 | for bag_idx_i in available_bag_idx: 446 | idx_same_bag_i = torch.where(patch_corresponding_slide_idx == bag_idx_i) 447 | if bag_label_gt[idx_same_bag_i].max() != bag_label_gt[idx_same_bag_i].max(): 448 | raise 449 | 
bag_label_gt_coarse.append(bag_label_gt[idx_same_bag_i].max()) 450 | bag_label_prediction.append(patch_label_pred[idx_same_bag_i].max()) 451 | bag_label_gt_coarse = torch.tensor(bag_label_gt_coarse) 452 | bag_label_prediction = torch.tensor(bag_label_prediction) 453 | bag_auc_ByStudent = utliz.cal_auc(bag_label_gt_coarse.reshape(-1), bag_label_prediction.reshape(-1)) 454 | self.writer.add_scalar('test_bag_AUC_byStudent', bag_auc_ByStudent, epoch) 455 | return 0 456 | 457 | 458 | def str2bool(v): 459 | """ 460 | Input: 461 | v - string 462 | output: 463 | True/False 464 | """ 465 | if isinstance(v, bool): 466 | return v 467 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 468 | return True 469 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 470 | return False 471 | else: 472 | raise argparse.ArgumentTypeError('Boolean value expected.') 473 | 474 | 475 | def get_parser(): 476 | parser = argparse.ArgumentParser(description='PyTorch Implementation of Self-Label') 477 | # optimizer 478 | parser.add_argument('--epochs', default=1500, type=int, help='number of epochs') 479 | parser.add_argument('--batch_size', default=512, type=int, help='batch size (default: 256)') 480 | parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate (default: 0.05)') 481 | parser.add_argument('--lrdrop', default=1500, type=int, help='multiply LR by 0.5 every (default: 150 epochs)') 482 | parser.add_argument('--wd', default=-5, type=float, help='weight decay pow (default: (-5)') 483 | parser.add_argument('--dtype', default='f64', choices=['f64', 'f32'], type=str, help='SK-algo dtype (default: f64)') 484 | 485 | # SK algo 486 | parser.add_argument('--nopts', default=100, type=int, help='number of pseudo-opts (default: 100)') 487 | parser.add_argument('--augs', default=3, type=int, help='augmentation level (default: 3)') 488 | parser.add_argument('--lamb', default=25, type=int, help='for pseudoopt: lambda (default:25) ') 489 | 490 | # architecture 491 | # 
parser.add_argument('--arch', default='alexnet_MNIST', type=str, help='alexnet or resnet (default: alexnet)') 492 | 493 | # housekeeping 494 | parser.add_argument('--device', default='0', type=str, help='GPU devices to use for storage and model') 495 | parser.add_argument('--modeldevice', default='0', type=str, help='GPU numbers on which the CNN runs') 496 | parser.add_argument('--exp', default='self-label-default', help='path to experiment directory') 497 | parser.add_argument('--workers', default=0, type=int,help='number workers (default: 6)') 498 | parser.add_argument('--comment', default='DEBUG_BagDistillation_DSMIL_BasedOnFeat', type=str, help='name for tensorboardX') 499 | parser.add_argument('--log-intv', default=1, type=int, help='save stuff every x epochs (default: 1)') 500 | parser.add_argument('--log_iter', default=200, type=int, help='log every x-th batch (default: 200)') 501 | parser.add_argument('--seed', default=10, type=int, help='random seed') 502 | 503 | parser.add_argument('--PLPostProcessMethod', default='NegGuide', type=str, 504 | help='Post-processing method of Attention Scores to build Pseudo Lables', 505 | choices=['NegGuide', 'NegGuide_TopK', 'NegGuide_Similarity']) 506 | parser.add_argument('--StuFilterType', default='FilterNegInstance__ThreProb50', type=str, 507 | help='Type of using Student Prediction to imporve Teacher ' 508 | '[ReplaceAS, FilterNegInstance_Top95, FilterNegInstance__ThreProb90]') 509 | parser.add_argument('--smoothE', default=9999, type=int, help='num of epoch to apply StuFilter') 510 | parser.add_argument('--stu_loss_weight_neg', default=0.1, type=float, help='weight of neg instances in stu training') 511 | parser.add_argument('--stuOptPeriod', default=1, type=int, help='period of stu optimization') 512 | return parser.parse_args() 513 | 514 | 515 | if __name__ == "__main__": 516 | args = get_parser() 517 | 518 | # torch.manual_seed(args.seed) 519 | # random.seed(args.seed) 520 | # np.random.seed(args.seed) 521 | 522 | 
name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")+"_%s" % args.comment.replace('/', '_') + \ 523 | "_Seed{}_Bs{}_lr{}_PLPostProcessBy{}_StuFilterType{}_smoothE{}_weightN{}_StuOptP{}".format( 524 | args.seed, args.batch_size, args.lr, 525 | args.PLPostProcessMethod, args.StuFilterType, args.smoothE, args.stu_loss_weight_neg, args.stuOptPeriod) 526 | try: 527 | args.device = [int(item) for item in args.device.split(',')] 528 | except AttributeError: 529 | args.device = [int(args.device)] 530 | args.modeldevice = args.device 531 | util.setup_runtime(seed=42, cuda_dev_id=list(np.unique(args.modeldevice + args.device))) 532 | 533 | print(name, flush=True) 534 | 535 | writer = SummaryWriter('./runs_Cervical/%s'%name) 536 | writer.add_text('args', " \n".join(['%s %s' % (arg, getattr(args, arg)) for arg in vars(args)])) 537 | 538 | # Setup model 539 | model_encoder = camelyon_feat_projecter(input_dim=512, output_dim=512).to('cuda:0') 540 | model_teacherHead = teacher_DSMIL_head(input_feat_dim=512).to('cuda:0') 541 | model_studentHead = student_head(input_feat_dim=512).to('cuda:0') 542 | 543 | optimizer_encoder = torch.optim.SGD(model_encoder.parameters(), lr=args.lr) 544 | optimizer_teacherHead = torch.optim.SGD(model_teacherHead.parameters(), lr=args.lr) 545 | optimizer_studentHead = torch.optim.SGD(model_studentHead.parameters(), lr=args.lr) 546 | 547 | # Setup loaders 548 | train_ds_return_instance = CervicalCaner_16_feat(train=True, return_bag=False) 549 | 550 | train_ds_return_bag = copy.deepcopy(train_ds_return_instance) 551 | train_ds_return_bag.return_bag = True 552 | val_ds_return_instance = CervicalCaner_16_feat(train=False, return_bag=False) 553 | val_ds_return_bag = CervicalCaner_16_feat(train=False, return_bag=True) 554 | 555 | train_loader_instance = torch.utils.data.DataLoader(train_ds_return_instance, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=False) 556 | train_loader_bag = 
torch.utils.data.DataLoader(train_ds_return_bag, batch_size=1, shuffle=True, num_workers=args.workers, drop_last=False) 557 | val_loader_instance = torch.utils.data.DataLoader(val_ds_return_instance, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, drop_last=False) 558 | val_loader_bag = torch.utils.data.DataLoader(val_ds_return_bag, batch_size=1, shuffle=False, num_workers=args.workers, drop_last=False) 559 | 560 | print("[Data] {} training samples".format(len(train_loader_instance.dataset))) 561 | print("[Data] {} evaluating samples".format(len(val_loader_instance.dataset))) 562 | 563 | if torch.cuda.device_count() > 1: 564 | print("Let's use", len(args.modeldevice), "GPUs for the model") 565 | if len(args.modeldevice) == 1: 566 | print('single GPU model', flush=True) 567 | else: 568 | model_encoder = nn.DataParallel(model_encoder, device_ids=list(range(len(args.modeldevice)))) 569 | model_teacherHead = nn.DataParallel(model_teacherHead, device_ids=list(range(len(args.modeldevice)))) 570 | 571 | # Setup optimizer 572 | o = Optimizer(model_encoder=model_encoder, model_teacherHead=model_teacherHead, model_studentHead=model_studentHead, 573 | optimizer_encoder=optimizer_encoder, optimizer_teacherHead=optimizer_teacherHead, optimizer_studentHead=optimizer_studentHead, 574 | train_bagloader=train_loader_bag, train_instanceloader=train_loader_instance, 575 | test_bagloader=val_loader_bag, test_instanceloader=val_loader_instance, 576 | writer=writer, num_epoch=args.epochs, 577 | dev=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 578 | PLPostProcessMethod=args.PLPostProcessMethod, StuFilterType=args.StuFilterType, smoothE=args.smoothE, 579 | stu_loss_weight_neg=args.stu_loss_weight_neg, stuOptPeriod=args.stuOptPeriod) 580 | # Optimize 581 | o.optimize() -------------------------------------------------------------------------------- /train_CAMELYONFeat_BagDistillationDSMIL_SharedEnc_Similarity_StuFilterSmoothed_DropPos.py: 
import argparse
import warnings
import os
import time
import numpy as np

import torch
import torch.optim
import torch.nn as nn
import torch.utils.data
from tensorboardX import SummaryWriter
from models.alexnet import camelyon_feat_projecter, teacher_DSMIL_head, student_head
from Datasets_loader.dataset_CAMELYON16_BasedOnFeat import CAMELYON_16_feat
import datetime
import utliz
import util
import random
from tqdm import tqdm
import copy


class Optimizer:
    """Alternately train a DSMIL teacher (bag level) and a student (instance level).

    The encoder is shared. The teacher head is trained on whole bags from
    ``train_bagloader``. The student head is trained on single instances from
    ``train_instanceloader``, using the teacher's min-max-normalised class-1 scores
    as soft pseudo labels. Losses and AUCs are written to the tensorboardX ``writer``.
    """

    def __init__(self, model_encoder, model_teacherHead, model_studentHead,
                 optimizer_encoder, optimizer_teacherHead, optimizer_studentHead,
                 train_bagloader, train_instanceloader, test_bagloader, test_instanceloader,
                 writer=None, num_epoch=100,
                 dev=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                 PLPostProcessMethod='NegGuide', StuFilterType='ReplaceAS', smoothE=100,
                 stu_loss_weight_neg=0.1, stuOptPeriod=1):
        """Store models, optimizers, loaders and hyper-parameters.

        Args:
            model_encoder, model_teacherHead, model_studentHead: the three networks.
            optimizer_encoder, optimizer_teacherHead, optimizer_studentHead: their optimizers.
            train_bagloader / test_bagloader: loaders yielding one bag per batch.
            train_instanceloader / test_instanceloader: loaders yielding instance batches.
            writer: tensorboardX writer; logging is skipped when None.
            num_epoch: number of training epochs.
            dev: device that inputs and labels are moved to.
            PLPostProcessMethod: method name for ``post_process_pred_byTeacher`` (stored only).
            StuFilterType: student-based instance filtering used by the teacher after ``smoothE``.
            smoothE: the student filter is applied only when ``epoch > smoothE``.
            stu_loss_weight_neg: weight of the negative term in the student loss.
            stuOptPeriod: the student is trained every ``stuOptPeriod`` epochs.
        """
        self.model_encoder = model_encoder
        self.model_teacherHead = model_teacherHead
        self.model_studentHead = model_studentHead
        self.optimizer_encoder = optimizer_encoder
        self.optimizer_teacherHead = optimizer_teacherHead
        self.optimizer_studentHead = optimizer_studentHead
        self.train_bagloader = train_bagloader
        self.train_instanceloader = train_instanceloader
        self.test_bagloader = test_bagloader
        self.test_instanceloader = test_instanceloader
        self.writer = writer
        self.num_epoch = num_epoch
        self.dev = dev
        self.log_period = 10  # log per-iteration losses every 10 iterations
        self.PLPostProcessMethod = PLPostProcessMethod
        self.StuFilterType = StuFilterType
        self.smoothE = smoothE
        self.stu_loss_weight_neg = stu_loss_weight_neg
        self.stuOptPeriod = stuOptPeriod

    # ------------------------------------------------------------------ helpers
    def _log_scalar(self, tag, value, step):
        """Write one scalar to the tensorboardX writer, if a writer was given."""
        if self.writer is not None:
            self.writer.add_scalar(tag, value, step)

    def _labels_to_device(self, label):
        """Move every tensor entry of the collated ``label`` list to ``self.dev`` in place."""
        for i, j in enumerate(label):
            if torch.is_tensor(j):
                label[i] = j.to(self.dev)
        return label

    @staticmethod
    def _minmax_normalize(values, vmin, vmax):
        """Map ``values`` linearly from [vmin, vmax] to [0, 1]; all zeros if vmax == vmin."""
        span = vmax - vmin
        if span == 0:
            # Avoid 0/0 -> NaN when every score is identical.
            return torch.zeros_like(values)
        return (values - vmin) / span

    @staticmethod
    def _student_loss(patch_prediction, pseudo_instance_label, weight_neg):
        """Weighted soft binary cross-entropy between student probabilities and pseudo labels.

        ``patch_prediction`` holds softmax probabilities; column 0 is the negative class
        and column 1 is the positive class.
        """
        return -1. * torch.mean(
            weight_neg * (1 - pseudo_instance_label) * torch.log(patch_prediction[:, 0] + 1e-5) +
            (1 - weight_neg) * pseudo_instance_label * torch.log(patch_prediction[:, 1] + 1e-5))

    def _bag_auc_from_instances(self, patch_corresponding_slide_idx, bag_label_gt, patch_label_pred):
        """Bag-level AUC that scores each slide by its maximum instance prediction.

        Raises:
            ValueError: if the instances of one slide carry different bag labels.
        """
        bag_label_gt_coarse = []
        bag_label_prediction = []
        for bag_idx_i in patch_corresponding_slide_idx.unique():
            idx_same_bag_i = torch.where(patch_corresponding_slide_idx == bag_idx_i)
            labels_i = bag_label_gt[idx_same_bag_i]
            if labels_i.max() != labels_i.min():
                raise ValueError('Inconsistent bag labels among instances of slide {}'.format(int(bag_idx_i)))
            bag_label_gt_coarse.append(labels_i.max())
            bag_label_prediction.append(patch_label_pred[idx_same_bag_i].max())
        bag_label_gt_coarse = torch.tensor(bag_label_gt_coarse)
        bag_label_prediction = torch.tensor(bag_label_prediction)
        return utliz.cal_auc(bag_label_gt_coarse.reshape(-1), bag_label_prediction.reshape(-1))

    def _use_student_filter(self, epoch, label):
        """True when the student should filter instances of this (positive) training bag."""
        return (epoch > self.smoothE
                and "FilterNegInstance" in self.StuFilterType
                and bool(label[1] == 1))

    def _teacher_forward_filtered(self, feat):
        """Run the teacher head only on the instances selected by the student.

        Which instances are kept depends on ``StuFilterType``. ``p`` below is the
        student's softmax class-1 probability.
        - ``..._Top<K>``: keeps the K instances with the LOWEST ``p``.
        - ``..._ThreProb<K>``: keeps instances with ``p <= K/100``. If none qualify, it
          keeps the single instance with the highest ``p``.
        NOTE(review): both strategies keep the least-positive instances, consistent with
        the "DropPos" file name; the original "remove negative" comments disagreed.

        Returns:
            (instance_attn_score, bag_prediction). ``instance_attn_score`` has one row
            per ORIGINAL instance, in the original order. Rows of dropped instances are
            filled with the minimum kept class-1 score.
        Raises:
            ValueError: if ``StuFilterType`` names neither strategy.
        """
        with torch.no_grad():
            pred_byStudent = torch.softmax(self.model_studentHead(feat), dim=1)[:, 1]
        if '_Top' in self.StuFilterType:
            k = int(self.StuFilterType.split('_Top')[-1])
            k = min(k, pred_byStudent.shape[0])  # topk fails if k exceeds the bag size
            idx_to_keep = torch.topk(-pred_byStudent, k=k)[1]
        elif '_ThreProb' in self.StuFilterType:
            threshold = int(self.StuFilterType.split('_ThreProb')[-1]) / 100.0
            idx_to_keep = torch.where(pred_byStudent <= threshold)[0]
            if idx_to_keep.shape[0] == 0:
                idx_to_keep = torch.topk(pred_byStudent, k=1)[1]
        else:
            raise ValueError('Unsupported StuFilterType: {}'.format(self.StuFilterType))

        kept_score, bag_prediction, _, _ = self.model_teacherHead(feat[idx_to_keep])
        # Scatter kept scores back to their original rows so that they stay aligned with
        # the per-instance ground truth used for the instance AUC.
        instance_attn_score = kept_score[:, 1].min() * torch.ones(
            feat.shape[0], 2, device=kept_score.device, dtype=kept_score.dtype)
        instance_attn_score = instance_attn_score.index_copy(0, idx_to_keep.to(kept_score.device), kept_score)
        return instance_attn_score, bag_prediction

    # ------------------------------------------------------------------ schedule
    def optimize(self):
        """Run the full training schedule and return 0.

        Every epoch trains and evaluates the teacher. Every ``stuOptPeriod`` epochs the
        student is also trained and evaluated.
        """
        # Memory banks. Only ``optimize_student_fromBank`` writes one of them
        # (``Bank_all_instances_pred_byStudent``), and it is not called here.
        self.Bank_all_Bags_label = None
        self.Bank_all_instances_pred_byTeacher = None
        self.Bank_all_instances_feat_byTeacher = None
        self.Bank_all_instances_pred_processed = None
        self.Bank_all_instances_pred_byStudent = None

        for epoch in range(self.num_epoch):
            self.optimize_teacher(epoch)
            self.evaluate_teacher(epoch)
            if epoch % self.stuOptPeriod == 0:
                self.optimize_student(epoch)
                self.evaluate_student(epoch)
        return 0

    def optimize_teacher(self, epoch):
        """Train encoder + teacher head for one epoch on bags, log train AUCs, return 0.

        The loss is the mean of two cross-entropies: one on the DSMIL bag prediction and
        one on the score row of the highest class-1 instance. Also stores the min/max of
        this epoch's class-1 instance scores for ``norm_AttnScore2Prob``.
        """
        self.model_encoder.train()
        self.model_teacherHead.train()
        self.model_studentHead.eval()
        criterion = torch.nn.CrossEntropyLoss()
        loader = self.train_bagloader

        patch_label_gt = []
        patch_label_pred = []
        bag_label_gt = []
        bag_label_pred = []
        for it, (data, label, _selected) in enumerate(tqdm(loader, desc='Teacher training')):
            self._labels_to_device(label)
            niter = epoch * len(loader) + it

            feat = self.model_encoder(data.to(self.dev).squeeze(0))
            if self._use_student_filter(epoch, label):
                instance_attn_score, bag_prediction = self._teacher_forward_filtered(feat)
            else:
                instance_attn_score, bag_prediction, _, _ = self.model_teacherHead(feat)

            # Max-pooling branch: the score row of the instance with the highest class-1 score.
            max_id = torch.argmax(instance_attn_score[:, 1])
            bag_pred_byMax = instance_attn_score[max_id, :].squeeze(0)
            bag_loss = criterion(bag_prediction, label[1])
            bag_loss_byMax = criterion(bag_pred_byMax.unsqueeze(0), label[1])
            loss_teacher = 0.5 * bag_loss + 0.5 * bag_loss_byMax

            self.optimizer_encoder.zero_grad()
            self.optimizer_teacherHead.zero_grad()
            loss_teacher.backward()
            self.optimizer_encoder.step()
            self.optimizer_teacherHead.step()

            # The reported bag score comes from the DSMIL bag head only (the max-pooled
            # branch previously had mixing weight 0.0).
            bag_prob = torch.softmax(bag_prediction, dim=1)

            patch_label_pred.append(instance_attn_score[:, 1].detach().squeeze(0))
            patch_label_gt.append(label[0].squeeze(0))
            bag_label_pred.append(bag_prob.detach()[0, 1])
            bag_label_gt.append(label[1])
            if niter % self.log_period == 0:
                self._log_scalar('train_loss_Teacher', loss_teacher.item(), niter)

        patch_label_pred = torch.cat(patch_label_pred)
        patch_label_gt = torch.cat(patch_label_gt)
        bag_label_pred = torch.tensor(bag_label_pred)
        bag_label_gt = torch.cat(bag_label_gt)

        self.estimated_AttnScore_norm_para_min = patch_label_pred.min()
        self.estimated_AttnScore_norm_para_max = patch_label_pred.max()
        patch_label_pred_normed = self.norm_AttnScore2Prob(patch_label_pred)
        instance_auc_ByTeacher = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred_normed.reshape(-1))

        bag_auc_ByTeacher = utliz.cal_auc(bag_label_gt.reshape(-1), bag_label_pred.reshape(-1))
        self._log_scalar('train_instance_AUC_byTeacher', instance_auc_ByTeacher, epoch)
        self._log_scalar('train_bag_AUC_byTeacher', bag_auc_ByTeacher, epoch)
        return 0

    def norm_AttnScore2Prob(self, attn_score):
        """Min-max normalise teacher scores with the min/max from the latest teacher epoch.

        Values outside that range map outside [0, 1]; callers clamp where needed.
        Returns zeros when the stored min equals the stored max.
        """
        return self._minmax_normalize(attn_score,
                                      self.estimated_AttnScore_norm_para_min,
                                      self.estimated_AttnScore_norm_para_max)

    def post_process_pred_byTeacher(self, Bank_all_instances_feat, Bank_all_instances_pred, Bank_all_bags_label, method='NegGuide'):
        """Turn banked teacher scores into instance pseudo labels.

        Args:
            Bank_all_instances_feat: instance features indexed ``[bag, instance, :]``
                (read only by 'NegGuide_Similarity').
            Bank_all_instances_pred: teacher class-1 scores indexed ``[bag, instance]``.
            Bank_all_bags_label: bag labels; column 0 is used as the bag label.
            method: one of the following.
                - 'NegGuide': normalised and clamped scores.
                - 'NegGuide_TopK': additionally sets the top-3 of each positive bag to 1.
                - 'NegGuide_Similarity': positive bags get distance-based labels from
                  ``distanceMatrix2PL``.
        Returns:
            Tensor shaped like ``Bank_all_instances_pred``. Every instance of a
            negative bag is set to 0.
        Raises:
            TypeError: unknown ``method``.
        """
        if method not in ('NegGuide', 'NegGuide_TopK', 'NegGuide_Similarity'):
            raise TypeError('Unknown pseudo-label post-processing method: {}'.format(method))

        processed = self.norm_AttnScore2Prob(Bank_all_instances_pred).clamp(min=1e-5, max=1 - 1e-5)
        idx_pos_bag = torch.where(Bank_all_bags_label[:, 0] == 1)[0]
        idx_neg_bag = torch.where(Bank_all_bags_label[:, 0] == 0)[0]

        if method in ('NegGuide_TopK', 'NegGuide_Similarity'):
            K = 3 if method == 'NegGuide_TopK' else 1
            idx_topK_inside_pos_bag = torch.topk(processed[idx_pos_bag, :], k=K, dim=-1, largest=True)[1]
            # Tensor indexing returns a copy, so an in-place scatter_ on processed[idx_pos_bag]
            # would be lost. Write the scattered rows back explicitly.
            processed[idx_pos_bag] = processed[idx_pos_bag].scatter(1, idx_topK_inside_pos_bag, 1.0)

        if method == 'NegGuide_Similarity':
            pos_feat = Bank_all_instances_feat[idx_pos_bag]
            rows = torch.arange(pos_feat.shape[0], device=pos_feat.device)
            most_salient = pos_feat[rows, idx_topK_inside_pos_bag[:, 0].to(pos_feat.device)].unsqueeze(1)
            distance_matrix = torch.norm(pos_feat - most_salient, dim=-1, p=2)
            processed[idx_pos_bag, :] = self.distanceMatrix2PL(distance_matrix)

        processed[idx_neg_bag, :] = 0
        return processed

    def distanceMatrix2PL(self, distance_matrix, method='percentage', p=0.1):
        """Convert distances to each bag's most salient instance into pseudo labels.

        Args:
            distance_matrix: [N, L] distances (N positive bags, L instances per bag).
            method: one of the following.
                - 'softMax': softmax over inverse distances along each bag.
                - 'percentage': k = max(1, int(L * p)) and the threshold is each bag's
                  k-th LARGEST distance. Instances below it get 1; the rest (about the
                  k farthest) get 0.
            p: fraction used by 'percentage' (default 0.1, the previously hard-coded value).
        Returns:
            [N, L] pseudo labels in [0, 1].
        Raises:
            TypeError: for 'threshold' (not implemented) or an unknown method.
        """
        if method == 'softMax':
            similarity_matrix = 1 / (distance_matrix + 1e-5)
            return torch.softmax(similarity_matrix, dim=1)
        if method == 'percentage':
            k = max(1, int(distance_matrix.shape[1] * p))
            # [N, 1] threshold broadcasts over the bag length (previously hard-coded to 100).
            threshold_v = distance_matrix.topk(k=k, dim=1)[0][:, -1:]
            return (distance_matrix < threshold_v).to(distance_matrix.dtype)
        raise TypeError('Unsupported pseudo-label method: {}'.format(method))

    def optimize_student(self, epoch):
        """Train encoder + student head for one epoch on instances, log train AUCs, return 0.

        Pseudo labels are the teacher's class-1 scores, computed under no_grad. They are
        normalised with ``norm_AttnScore2Prob`` and clamped to [1e-5, 1 - 1e-5].
        Instances whose bag label (``label[1]``) is 0 get pseudo label 0.
        """
        self.model_teacherHead.train()
        self.model_encoder.train()
        self.model_studentHead.train()
        loader = self.train_instanceloader
        n_samples = len(loader.dataset)
        patch_label_gt = torch.zeros([n_samples, 1]).long().to(self.dev)  # only for patch-label available dataset
        patch_label_pred = torch.zeros([n_samples, 1]).float().to(self.dev)
        bag_label_gt = torch.zeros([n_samples, 1]).long().to(self.dev)
        patch_corresponding_slide_idx = torch.zeros([n_samples, 1]).long().to(self.dev)
        for it, (data, label, selected) in enumerate(tqdm(loader, desc='Student training')):
            self._labels_to_device(label)
            selected = selected.squeeze(0)
            niter = epoch * len(loader) + it

            feat = self.model_encoder(data.to(self.dev))
            with torch.no_grad():
                instance_attn_score, _, _, _ = self.model_teacherHead(feat)
            pseudo_instance_label = self.norm_AttnScore2Prob(instance_attn_score[:, 1]).clamp(min=1e-5, max=1 - 1e-5).squeeze(0)
            pseudo_instance_label[label[1] == 0] = 0  # instances of negative bags are negative
            patch_prediction = torch.softmax(self.model_studentHead(feat), dim=1)

            loss_student = self._student_loss(patch_prediction, pseudo_instance_label, self.stu_loss_weight_neg)
            self.optimizer_encoder.zero_grad()
            self.optimizer_studentHead.zero_grad()
            loss_student.backward()
            self.optimizer_encoder.step()
            self.optimizer_studentHead.step()

            patch_corresponding_slide_idx[selected, 0] = label[2]
            patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1]
            patch_label_gt[selected, 0] = label[0]
            bag_label_gt[selected, 0] = label[1]
            if niter % self.log_period == 0:
                self._log_scalar('train_loss_Student', loss_student.item(), niter)

        instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1))
        self._log_scalar('train_instance_AUC_byStudent', instance_auc_ByStudent, epoch)

        bag_auc_ByStudent = self._bag_auc_from_instances(patch_corresponding_slide_idx, bag_label_gt, patch_label_pred)
        self._log_scalar('train_bag_AUC_byStudent', bag_auc_ByStudent, epoch)
        return 0

    def optimize_student_fromBank(self, epoch, Bank_all_instances_pred):
        """Like ``optimize_student``, but read pseudo labels from a precomputed bank.

        The pseudo label for dataset index ``s`` is
        ``Bank_all_instances_pred[s // 100, s % 100]``. NOTE(review): this assumes every
        bag contributes exactly 100 consecutive dataset indices; confirm against the dataset.
        Stores the student predictions in ``self.Bank_all_instances_pred_byStudent``.
        The bag AUC is logged as 0. Not called by ``optimize``.
        """
        self.model_teacherHead.train()
        self.model_encoder.train()
        self.model_studentHead.train()
        loader = self.train_instanceloader
        n_samples = len(loader.dataset)
        patch_label_gt = torch.zeros([n_samples, 1]).long().to(self.dev)  # only for patch-label available dataset
        patch_label_pred = torch.zeros([n_samples, 1]).float().to(self.dev)
        bag_label_gt = torch.zeros([n_samples, 1]).long().to(self.dev)
        patch_corresponding_slide_idx = torch.zeros([n_samples, 1]).long().to(self.dev)
        for it, (data, label, selected) in enumerate(tqdm(loader, desc='Student training')):
            self._labels_to_device(label)
            selected = selected.squeeze(0)
            niter = epoch * len(loader) + it

            feat = self.model_encoder(data.to(self.dev))
            pseudo_instance_label = Bank_all_instances_pred[selected // 100, selected % 100]
            patch_prediction = torch.softmax(self.model_studentHead(feat), dim=1)

            # Same weighting knob as optimize_student (was hard-coded to 0.1 / 0.9).
            loss_student = self._student_loss(patch_prediction, pseudo_instance_label, self.stu_loss_weight_neg)
            self.optimizer_encoder.zero_grad()
            self.optimizer_studentHead.zero_grad()
            loss_student.backward()
            self.optimizer_encoder.step()
            self.optimizer_studentHead.step()

            patch_corresponding_slide_idx[selected, 0] = label[2]
            patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1]
            patch_label_gt[selected, 0] = label[0]
            bag_label_gt[selected, 0] = label[1]
            if niter % self.log_period == 0:
                self._log_scalar('train_loss_Student', loss_student.item(), niter)

        self.Bank_all_instances_pred_byStudent = patch_label_pred
        instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1))
        bag_auc_ByStudent = 0
        self._log_scalar('train_instance_AUC_byStudent', instance_auc_ByStudent, epoch)
        self._log_scalar('train_bag_AUC_byStudent', bag_auc_ByStudent, epoch)
        return 0

    def evaluate(self, epoch, loader, log_name_prefix=''):
        """Placeholder evaluation hook; does nothing and returns 0."""
        return 0

    def evaluate_teacher(self, epoch):
        """Evaluate encoder + teacher head on the test bag loader, log test AUCs, return 0.

        For the instance AUC, instance scores are min-max normalised with the test set's
        own min/max. The bag AUC uses the softmax class-1 probability of the bag prediction.
        """
        self.model_encoder.eval()
        self.model_teacherHead.eval()
        loader = self.test_bagloader
        patch_label_gt = []
        patch_label_pred = []
        bag_label_gt = []
        bag_label_prediction_withAttnScore = []
        for data, label, _selected in tqdm(loader, desc='Teacher evaluating'):
            self._labels_to_device(label)
            with torch.no_grad():
                feat = self.model_encoder(data.to(self.dev).squeeze(0))
                instance_attn_score, bag_prediction_withAttnScore, _, _ = self.model_teacherHead(feat)
            bag_prediction_withAttnScore = torch.softmax(bag_prediction_withAttnScore, 1)

            patch_label_pred.append(instance_attn_score[:, 1].detach().squeeze(0))
            patch_label_gt.append(label[0].squeeze(0))
            bag_label_prediction_withAttnScore.append(bag_prediction_withAttnScore.detach()[0, 1])
            bag_label_gt.append(label[1])

        patch_label_pred = torch.cat(patch_label_pred)
        patch_label_gt = torch.cat(patch_label_gt)
        bag_label_prediction_withAttnScore = torch.tensor(bag_label_prediction_withAttnScore)
        bag_label_gt = torch.cat(bag_label_gt)

        patch_label_pred_normed = self._minmax_normalize(patch_label_pred, patch_label_pred.min(), patch_label_pred.max())
        instance_auc_ByTeacher = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred_normed.reshape(-1))
        bag_auc_ByTeacher_withAttnScore = utliz.cal_auc(bag_label_gt.reshape(-1), bag_label_prediction_withAttnScore.reshape(-1))
        self._log_scalar('test_instance_AUC_byTeacher', instance_auc_ByTeacher, epoch)
        self._log_scalar('test_bag_AUC_byTeacher', bag_auc_ByTeacher_withAttnScore, epoch)
        return 0

    def evaluate_student(self, epoch):
        """Evaluate encoder + student head on the test instance loader, log test AUCs, return 0."""
        self.model_encoder.eval()
        self.model_studentHead.eval()
        loader = self.test_instanceloader
        n_samples = len(loader.dataset)
        patch_label_gt = torch.zeros([n_samples, 1]).long().to(self.dev)  # only for patch-label available dataset
        patch_label_pred = torch.zeros([n_samples, 1]).float().to(self.dev)
        bag_label_gt = torch.zeros([n_samples, 1]).long().to(self.dev)
        patch_corresponding_slide_idx = torch.zeros([n_samples, 1]).long().to(self.dev)
        for data, label, selected in tqdm(loader, desc='Student evaluating'):
            self._labels_to_device(label)
            selected = selected.squeeze(0)
            with torch.no_grad():
                feat = self.model_encoder(data.to(self.dev))
                patch_prediction = torch.softmax(self.model_studentHead(feat), dim=1)

            patch_corresponding_slide_idx[selected, 0] = label[2]
            patch_label_pred[selected, 0] = patch_prediction.detach()[:, 1]
            patch_label_gt[selected, 0] = label[0]
            bag_label_gt[selected, 0] = label[1]

        instance_auc_ByStudent = utliz.cal_auc(patch_label_gt.reshape(-1), patch_label_pred.reshape(-1))
        self._log_scalar('test_instance_AUC_byStudent', instance_auc_ByStudent, epoch)

        bag_auc_ByStudent = self._bag_auc_from_instances(patch_corresponding_slide_idx, bag_label_gt, patch_label_pred)
        self._log_scalar('test_bag_AUC_byStudent', bag_auc_ByStudent, epoch)
        return 0


def str2bool(v):
    """
    Input:
        v - string
    output:
        True/False
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def get_parser():
    """Parse the command line and return the resulting ``argparse.Namespace``.

    Several options (lrdrop, wd, dtype, the SK-algo group, exp, log-intv, log_iter)
    are parsed but not read anywhere in this script.
    """
    parser = argparse.ArgumentParser(description='PyTorch Implementation of Self-Label')
    # optimizer
    parser.add_argument('--epochs', default=1500, type=int, help='number of epochs')
    parser.add_argument('--batch_size', default=512, type=int, help='batch size (default: 256)')
    parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate (default: 0.05)')
    parser.add_argument('--lrdrop', default=1500, type=int, help='multiply LR by 0.5 every (default: 150 epochs)')
    parser.add_argument('--wd', default=-5, type=float, help='weight decay pow (default: (-5)')
    parser.add_argument('--dtype', default='f64', choices=['f64', 'f32'], type=str, help='SK-algo dtype (default: f64)')

    # SK algo
    parser.add_argument('--nopts', default=100, type=int, help='number of pseudo-opts (default: 100)')
    parser.add_argument('--augs', default=3, type=int, help='augmentation level (default: 3)')
    parser.add_argument('--lamb', default=25, type=int, help='for pseudoopt: lambda (default:25) ')

    # housekeeping
    parser.add_argument('--device', default='0', type=str, help='GPU devices to use for storage and model')
    parser.add_argument('--modeldevice', default='0', type=str, help='GPU numbers on which the CNN runs')
    parser.add_argument('--exp', default='self-label-default', help='path to experiment directory')
    parser.add_argument('--workers', default=0, type=int,help='number workers (default: 6)')
    parser.add_argument('--comment', default='DEBUG_BagDistillation_DSMIL', type=str, help='name for tensorboardX')
    parser.add_argument('--log-intv', default=1, type=int, help='save stuff every x epochs (default: 1)')
    parser.add_argument('--log_iter', default=200, type=int, help='log every x-th batch (default: 200)')
    parser.add_argument('--seed', default=10, type=int, help='random seed')

    parser.add_argument('--PLPostProcessMethod', default='NegGuide', type=str,
                        help='Post-processing method of Attention Scores to build Pseudo Lables',
                        choices=['NegGuide', 'NegGuide_TopK', 'NegGuide_Similarity'])
    parser.add_argument('--StuFilterType', default='FilterNegInstance__ThreProb50', type=str,
                        help='Type of using Student Prediction to imporve Teacher '
                             '[ReplaceAS, FilterNegInstance_Top95, FilterNegInstance__ThreProb90]')
    parser.add_argument('--smoothE', default=9999, type=int, help='num of epoch to apply StuFilter')
    parser.add_argument('--stu_loss_weight_neg', default=0.1, type=float, help='weight of neg instances in stu training')
    parser.add_argument('--stuOptPeriod', default=1, type=int, help='period of stu optimization')
    return parser.parse_args()


if __name__ == "__main__":
    args = get_parser()

    name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + "_%s" % args.comment.replace('/', '_') + \
        "_Seed{}_Bs{}_lr{}_PLPostProcessBy{}_StuFilterType{}_smoothE{}_weightN{}_StuOptP{}".format(
            args.seed, args.batch_size, args.lr,
            args.PLPostProcessMethod, args.StuFilterType, args.smoothE, args.stu_loss_weight_neg, args.stuOptPeriod)
    try:
        args.device = [int(item) for item in args.device.split(',')]
    except AttributeError:
        args.device = [int(args.device)]
    # NOTE: --modeldevice is always overridden by --device.
    args.modeldevice = args.device
    # Seed with --seed, the value recorded in the run name (it used to be hard-coded to 42).
    util.setup_runtime(seed=args.seed, cuda_dev_id=list(np.unique(args.modeldevice + args.device)))

    print(name, flush=True)

    writer = SummaryWriter('./runs_CAMELYON/%s' % name)
    writer.add_text('args', " \n".join(['%s %s' % (arg, getattr(args, arg)) for arg in vars(args)]))

    # Setup model: first GPU when available, otherwise CPU (was hard-coded to 'cuda:0').
    model_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_encoder = camelyon_feat_projecter(input_dim=512, output_dim=512).to(model_device)
    model_teacherHead = teacher_DSMIL_head(input_feat_dim=512).to(model_device)
    model_studentHead = student_head(input_feat_dim=512).to(model_device)

    optimizer_encoder = torch.optim.SGD(model_encoder.parameters(), lr=args.lr)
    optimizer_teacherHead = torch.optim.SGD(model_teacherHead.parameters(), lr=args.lr)
    optimizer_studentHead = torch.optim.SGD(model_studentHead.parameters(), lr=args.lr)

    # Setup loaders
    train_ds_return_instance = CAMELYON_16_feat(train=True, transform=None, downsample=1.0, drop_threshold=0, preload=True, return_bag=False)

    train_ds_return_bag = copy.deepcopy(train_ds_return_instance)
    train_ds_return_bag.return_bag = True
    val_ds_return_instance = CAMELYON_16_feat(train=False, transform=None, downsample=1.0, drop_threshold=0, preload=True, return_bag=False)
    val_ds_return_bag = CAMELYON_16_feat(train=False, transform=None, downsample=1.0, drop_threshold=0, preload=True, return_bag=True)

    train_loader_instance = torch.utils.data.DataLoader(train_ds_return_instance, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=False)
    train_loader_bag = torch.utils.data.DataLoader(train_ds_return_bag, batch_size=1, shuffle=True, num_workers=args.workers, drop_last=False)
    val_loader_instance = torch.utils.data.DataLoader(val_ds_return_instance, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, drop_last=False)
    val_loader_bag = torch.utils.data.DataLoader(val_ds_return_bag, batch_size=1, shuffle=False, num_workers=args.workers, drop_last=False)

    print("[Data] {} training samples".format(len(train_loader_instance.dataset)))
    print("[Data] {} evaluating samples".format(len(val_loader_instance.dataset)))

    if torch.cuda.device_count() > 1:
        print("Let's use", len(args.modeldevice), "GPUs for the model")
        if len(args.modeldevice) == 1:
            print('single GPU model', flush=True)
        else:
            device_ids = list(range(len(args.modeldevice)))
            model_encoder = nn.DataParallel(model_encoder, device_ids=device_ids)
            model_teacherHead = nn.DataParallel(model_teacherHead, device_ids=device_ids)
            # Wrap the student MODEL. The optimizer was wrapped before, which left an object
            # without .step() and crashed multi-GPU training.
            model_studentHead = nn.DataParallel(model_studentHead, device_ids=device_ids)

    # Setup optimizer
    o = Optimizer(model_encoder=model_encoder, model_teacherHead=model_teacherHead, model_studentHead=model_studentHead,
                  optimizer_encoder=optimizer_encoder, optimizer_teacherHead=optimizer_teacherHead, optimizer_studentHead=optimizer_studentHead,
                  train_bagloader=train_loader_bag, train_instanceloader=train_loader_instance,
                  test_bagloader=val_loader_bag, test_instanceloader=val_loader_instance,
                  writer=writer, num_epoch=args.epochs,
                  dev=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                  PLPostProcessMethod=args.PLPostProcessMethod, StuFilterType=args.StuFilterType, smoothE=args.smoothE,
                  stu_loss_weight_neg=args.stu_loss_weight_neg, stuOptPeriod=args.stuOptPeriod)
    # Optimize
    o.optimize()