├── LICENSE ├── README.md ├── ROAM ├── compute_metric.py ├── configs │ ├── int_glioma_tumor_subtyping.ini │ └── int_glioma_tumor_subtyping_test.ini ├── dataset │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── patchdataset.cpython-37.pyc │ │ ├── patchdataset.cpython-38.pyc │ │ ├── roidataset.cpython-37.pyc │ │ ├── roidataset.cpython-38.pyc │ │ ├── roidataset.cpython-39.pyc │ │ ├── vis_dataset.cpython-38.pyc │ │ └── wsi_dataset.cpython-38.pyc │ ├── roidataset.py │ └── vis_dataset.py ├── gen_visheatmaps_roi_batch.py ├── gen_visheatmaps_slide_batch.py ├── main.py ├── models │ ├── ROAM.py │ └── __init__.py ├── models_embed │ ├── ResNet.py │ ├── ccl.py │ ├── ctran.py │ ├── extractor.py │ └── simclr_ciga.py ├── parse_config.py ├── position_embedding.py ├── predict_cascade.py ├── prediction_results │ └── s1 │ │ ├── cascade_int_oli_grade_split.npy │ │ └── int_glioma_tumor_subtyping │ │ ├── predictions.json │ │ └── results.json ├── results │ └── int_glioma_tumor_subtyping │ │ └── int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False │ │ └── s1 │ │ ├── results.json │ │ └── visual_res │ │ ├── cm_mean.png │ │ ├── metrics.json │ │ └── normal_cm_mean.png ├── scripts │ ├── cascade_pred_int_glioma_tumor_subtyping.sh │ ├── int_glioma_tumor_subtyping.sh │ └── int_glioma_tumor_subtyping_test.sh ├── utils │ ├── __pycache__ │ │ ├── core_utils.cpython-37.pyc │ │ ├── file_utils.cpython-37.pyc │ │ ├── file_utils.cpython-38.pyc │ │ ├── utils.cpython-37.pyc │ │ └── utils.cpython-38.pyc │ ├── eval_utils.py │ ├── file_utils.py │ └── utils.py ├── vahadane.py ├── vis_utils │ ├── __pycache__ │ │ ├── heatmap_utils.cpython-38.pyc │ │ ├── vit_grad_rollout.cpython-38.pyc │ │ └── vit_rollout.cpython-38.pyc │ ├── heatmap_utils.py │ ├── vit_explain.py │ ├── vit_grad_rollout.py │ └── vit_rollout.py ├── visheatmaps │ ├── roi_vis │ │ ├── configs │ │ │ └── int_glioma_tumor_subtyping_vis_roi.ini │ │ └── high │ │ │ └── int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False │ │ │ └── grad_rollout │ │ │ └── oligodendroglioma │ │ │ └── d0ab09865c3b467 │ │ │ └── top1_seeds1_0_d2_l2_r0.0_avg.png │ ├── slide_vis │ │ ├── configs │ │ │ ├── config_int_glioma_tumor_subtyping.yaml │ │ │ └── config_int_glioma_tumor_subtyping_roi.yaml │ │ ├── process_list │ │ │ ├── int_glioma_tumor_subtyping.csv │ │ │ └── int_glioma_tumor_subtyping_roi.csv │ │ └── results │ │ │ ├── heatmap_production_results │ │ │ └── int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False │ │ │ │ └── oligodendroglioma │ │ │ │ ├── d0ab09865c3b467_0.95_roi_1_blur_1_rs_1_bc_0_a_0.4_l_4_bi_0_-1.0.jpg │ │ │ │ ├── d0ab09865c3b467_0.9_roi_0_blur_1_rs_1_bc_0_a_0.4_l_5_bi_0_-1.0.jpg │ │ │ │ ├── d0ab09865c3b467_orig_4.jpg │ │ │ │ └── d0ab09865c3b467_orig_5.jpg │ │ │ └── int_glioma_tumor_subtyping_roi.csv │ └── target_roi_6e3.jpg └── wsi_core │ ├── WholeSlideImage.py │ ├── batch_process_utils.py │ ├── util_classes.py │ └── wsi_utils.py ├── data_prepare ├── create_patches_fp.py ├── create_splits.ipynb ├── data_csv │ └── example_xiangya_data_info_pro.csv ├── data_split │ └── xiangya_split_subtype │ │ └── example_test_split.npy ├── extract_feature_patch.py ├── models │ ├── ResNet.py │ ├── ccl.py │ ├── ctran.py │ ├── extractor.py │ └── simclr_ciga.py ├── patchdataset.py ├── target_images │ ├── target_image_6e3_1024.jpg │ ├── target_image_6e3_256.jpg │ ├── target_image_6e3_512.jpg │ └── target_roi_6e3.jpg ├── utils │ ├── 
__pycache__ │ │ ├── file_utils.cpython-38.pyc │ │ └── utils.cpython-38.pyc │ ├── file_utils.py │ └── utils.py ├── vahadane.py └── wsi_core │ ├── WholeSlideImage.py │ ├── __pycache__ │ ├── WholeSlideImage.cpython-37.pyc │ ├── WholeSlideImage.cpython-38.pyc │ ├── batch_process_utils.cpython-37.pyc │ ├── batch_process_utils.cpython-38.pyc │ ├── util_classes.cpython-37.pyc │ ├── util_classes.cpython-38.pyc │ ├── wsi_utils.cpython-37.pyc │ └── wsi_utils.cpython-38.pyc │ ├── batch_process_utils.py │ ├── util_classes.py │ └── wsi_utils.py └── docs ├── ROAM.png ├── cascade_diagnosis.jpg ├── environment.yaml └── visualization_examples.png /ROAM/compute_metric.py: -------------------------------------------------------------------------------- 1 | from unicodedata import name 2 | import numpy as np 3 | import pandas as pd 4 | import pickle 5 | import json 6 | import os 7 | import sys 8 | from sklearn import metrics 9 | from sklearn.metrics import confusion_matrix, balanced_accuracy_score 10 | import matplotlib.pyplot as plt 11 | import seaborn as sns 12 | 13 | ## class names for each task 14 | cls_name_dict = { 15 | 'int_glioma_tumor_subtyping':['astrocytoma','oligodendroglioma','ependymoma'], 16 | 'ext_glioma_tumor_subtyping3':['astrocytoma','oligodendroglioma'], 17 | 'int_glioma_cls':['normal','gliosis','tumor'] 18 | } 19 | 20 | def getmetric(cm,num_cls): 21 | tr= np.trace(cm) 22 | precs = [] 23 | recs = [] 24 | f1scores = [] 25 | if num_cls == 2: 26 | cid = 1 27 | TP = cm[cid,cid] 28 | #TN = tr-cm[cid,cid] 29 | gt = cm[cid] #tp+fn 30 | pred = cm[:,cid] #tp+fp 31 | 32 | prec = TP/np.sum(pred) 33 | rec = TP/np.sum(gt) 34 | f1score = 2 * prec*rec/(prec+rec) 35 | return prec,rec,f1score 36 | for cid in range(num_cls): 37 | TP = cm[cid,cid] 38 | #TN = tr-cm[cid,cid] 39 | gt = cm[cid] #tp+fn 40 | pred = cm[:,cid] #tp+fp 41 | 42 | prec = TP/np.sum(pred) 43 | rec = TP/np.sum(gt) 44 | f1score = 2 * prec*rec/(prec+rec) 45 | 46 | precs.append(prec) 47 | recs.append(rec) 48 | f1scores.append(f1score) 49 | return np.mean(precs),np.mean(recs),np.mean(f1scores) 50 | 51 | 52 | def compute_metric_results(args,task): 53 | metric = {} 54 | 55 | savepath = os.path.join(args.results_dir,'visual_res') 56 | respath = os.path.join(args.results_dir,'results.json') 57 | 58 | os.makedirs(savepath,exist_ok=True) 59 | 60 | with open(respath,'r') as f: 61 | res = json.load(f) 62 | 63 | test_acc = res['test']['acc'] 64 | 65 | clsnames = cls_name_dict[task] 66 | 67 | 68 | cls_num = len(clsnames) 69 | gt = res['test']['trues'] 70 | pred = res['test']['preds'] 71 | cm = confusion_matrix(gt,pred) 72 | acc_b = balanced_accuracy_score(gt,pred) 73 | print(cm) 74 | 75 | prec,rec,f1score = getmetric(cm,cls_num) 76 | print(f'prec:{prec},recall:{rec},f1_score:{f1score}') 77 | metric['acc'] = test_acc 78 | metric['precision'] = prec 79 | metric['recall'] = rec 80 | metric['f1_score'] = f1score 81 | metric['balanced_accuracy'] = acc_b 82 | 83 | cm_normal = cm/cm.sum(axis=1)[:,np.newaxis] 84 | # confusion matrix 85 | plt.figure(figsize=(10,10)) 86 | sns.heatmap(cm,annot=True,cmap='Blues') 87 | plt.title(f'confusion matrix mean') 88 | 89 | plt.xlabel('predicted labels') 90 | plt.ylabel('ground truth labels') 91 | xlocations = np.array(range(len(clsnames))) + 0.5 92 | plt.xticks(xlocations,clsnames) 93 | plt.yticks(xlocations,clsnames,rotation = 90) 94 | 95 | plt.savefig(os.path.join(savepath,'cm_mean.png')) 96 | 97 | # normalized confusion matrix 98 | plt.figure(figsize=(10,10)) 99 | sns.heatmap(cm_normal,annot=True,cmap='Blues') 100 | 
plt.title(f'confusion matrix mean') 101 | 102 | plt.xlabel('predicted labels') 103 | plt.ylabel('ground truth labels') 104 | xlocations = np.array(range(len(clsnames))) + 0.5 105 | plt.xticks(xlocations,clsnames) 106 | plt.yticks(xlocations,clsnames,rotation = 90) 107 | 108 | plt.savefig(os.path.join(savepath,'normal_cm_mean.png')) 109 | 110 | 111 | with open(os.path.join(savepath,'metrics.json'),'w') as f: 112 | json.dump(metric,f) 113 | 114 | 115 | # compute metics for specific seed 116 | def compute_metric_results_seed(exp_code,task,seed): 117 | metric = {} 118 | 119 | results_dir = f'results/{task}/{exp_code}/{seed}' 120 | savepath = os.path.join(results_dir,'visual_res') 121 | respath = os.path.join(results_dir,'results.json') 122 | 123 | os.makedirs(savepath,exist_ok=True) 124 | 125 | with open(respath,'r') as f: 126 | res = json.load(f) 127 | 128 | test_acc = res['test']['acc'] 129 | 130 | clsnames = cls_name_dict[task] 131 | 132 | 133 | cls_num = len(clsnames) 134 | gt = res['test']['trues'] 135 | pred = res['test']['preds'] 136 | cm = confusion_matrix(gt,pred) 137 | acc_b = balanced_accuracy_score(gt,pred) 138 | print(cm) 139 | 140 | prec,rec,f1score = getmetric(cm,cls_num) 141 | print(f'prec:{prec},recall:{rec},f1_score:{f1score}') 142 | metric['acc'] = test_acc 143 | metric['precision'] = prec 144 | metric['recall'] = rec 145 | metric['f1_score'] = f1score 146 | metric['balanced_accuracy'] = acc_b 147 | 148 | cm_normal = cm/cm.sum(axis=1)[:,np.newaxis] 149 | # confusion matrix 150 | plt.figure(figsize=(10,10)) 151 | sns.heatmap(cm,annot=True,cmap='Blues') 152 | plt.title(f'confusion matrix mean') 153 | 154 | plt.xlabel('predicted labels') 155 | plt.ylabel('ground truth labels') 156 | xlocations = np.array(range(len(clsnames))) + 0.5 157 | plt.xticks(xlocations,clsnames) 158 | plt.yticks(xlocations,clsnames,rotation = 90) 159 | 160 | plt.savefig(os.path.join(savepath,'cm_mean.png')) 161 | 162 | # normalized confusion matrix 163 | plt.figure(figsize=(10,10)) 164 | sns.heatmap(cm_normal,annot=True,cmap='Blues') 165 | plt.title(f'confusion matrix mean') 166 | 167 | plt.xlabel('predicted labels') 168 | plt.ylabel('ground truth labels') 169 | xlocations = np.array(range(len(clsnames))) + 0.5 170 | plt.xticks(xlocations,clsnames) 171 | plt.yticks(xlocations,clsnames,rotation = 90) 172 | 173 | plt.savefig(os.path.join(savepath,'normal_cm_mean.png')) 174 | 175 | 176 | with open(os.path.join(savepath,'metrics.json'),'w') as f: 177 | json.dump(metric,f) 178 | 179 | 180 | if __name__ == '__main__': 181 | exp_name = sys.argv[1] 182 | task = sys.argv[2] 183 | seed = sys.argv[3] 184 | 185 | compute_metric_results_seed(exp_name,task,seed) 186 | -------------------------------------------------------------------------------- /ROAM/configs/int_glioma_tumor_subtyping.ini: -------------------------------------------------------------------------------- 1 | [int_glioma_tumor_subtyping] 2 | 3 | seed = 1 4 | stage = train 5 | embed_type = ImageNet 6 | sample_size = 100 7 | not_stainnorm = False 8 | test_dataset = xiangya 9 | data_root_dir = ../data_prepare/example 10 | results_dir = results 11 | 12 | max_epochs = 200 13 | batch_size = 4 14 | lr = 2e-4 15 | optimizer = adamw 16 | weight_decay = 1e-5 17 | scheduler = none 18 | stop_epochs = 20 19 | weighted_sample = True 20 | emb_dropout = 0 21 | attn_dropout = 0.25 22 | dropout = 0.2 23 | 24 | model_type = ROAM 25 | roi_dropout = True 26 | roi_supervise = True 27 | roi_weight = 1 28 | topk = 4 29 | roi_level = 0 30 | single_level = 0 31 | scale_type = 
ms 32 | embed_weightx5 = 0.3333 33 | embed_weightx10 = 0.3333 34 | embed_weightx20 = 0.3333 35 | not_interscale = False 36 | 37 | dim = 256 38 | depths = [2,2,2,2,2] 39 | heads = 8 40 | mlp_dim = 512 41 | dim_head = 64 42 | pool = cls 43 | ape = True 44 | attn_type = rel_sa 45 | shared_pe = True 46 | -------------------------------------------------------------------------------- /ROAM/configs/int_glioma_tumor_subtyping_test.ini: -------------------------------------------------------------------------------- 1 | [int_glioma_tumor_subtyping] 2 | 3 | seed = 1 4 | stage = test 5 | embed_type = ImageNet 6 | sample_size = 100 7 | not_stainnorm = False 8 | test_dataset = xiangya 9 | data_root_dir = ../data_prepare/example 10 | results_dir = results 11 | 12 | max_epochs = 200 13 | batch_size = 4 14 | lr = 2e-4 15 | optimizer = adamw 16 | weight_decay = 1e-5 17 | scheduler = none 18 | stop_epochs = 20 19 | weighted_sample = True 20 | emb_dropout = 0 21 | attn_dropout = 0.25 22 | dropout = 0.2 23 | 24 | model_type = ROAM 25 | roi_dropout = True 26 | roi_supervise = True 27 | roi_weight = 1 28 | topk = 4 29 | roi_level = 0 30 | single_level = 0 31 | scale_type = ms 32 | embed_weightx5 = 0.3333 33 | embed_weightx10 = 0.3333 34 | embed_weightx20 = 0.3333 35 | not_interscale = False 36 | 37 | dim = 256 38 | depths = [2,2,2,2,2] 39 | heads = 8 40 | mlp_dim = 512 41 | dim_head = 64 42 | pool = cls 43 | ape = True 44 | attn_type = rel_sa 45 | shared_pe = True 46 | -------------------------------------------------------------------------------- /ROAM/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/dataset/__init__.py -------------------------------------------------------------------------------- /ROAM/dataset/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/dataset/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ROAM/dataset/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/dataset/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ROAM/dataset/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/dataset/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ROAM/dataset/__pycache__/patchdataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/dataset/__pycache__/patchdataset.cpython-37.pyc -------------------------------------------------------------------------------- /ROAM/dataset/__pycache__/patchdataset.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/dataset/__pycache__/patchdataset.cpython-38.pyc -------------------------------------------------------------------------------- /ROAM/dataset/__pycache__/roidataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/dataset/__pycache__/roidataset.cpython-37.pyc -------------------------------------------------------------------------------- /ROAM/dataset/__pycache__/roidataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/dataset/__pycache__/roidataset.cpython-38.pyc -------------------------------------------------------------------------------- /ROAM/dataset/__pycache__/roidataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/dataset/__pycache__/roidataset.cpython-39.pyc -------------------------------------------------------------------------------- /ROAM/dataset/__pycache__/vis_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/dataset/__pycache__/vis_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /ROAM/dataset/__pycache__/wsi_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/dataset/__pycache__/wsi_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /ROAM/dataset/roidataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import h5py 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from torch.utils.data import Dataset 8 | 9 | 10 | # generate dataset with batchsize of 1 11 | class Wsi_Dataset_sb(Dataset): 12 | def __init__(self,slide_ids,label_ids,csv_path,data_dir,label_dict): 13 | ''' 14 | Args: 15 | slide_ids (list): Ids of all WSIs in the dataset 16 | label_ids (list): Labels of all WSIs in the dataset 17 | csv_path (string): Path to the csv file with complete data information of all available WSIs 18 | data_dir (string): Root directory of all WSI data 19 | label_dict (dict): Dictionary with key, value pairs for converting label to int that can be used fot the current task 20 | ''' 21 | super(Wsi_Dataset_sb,self).__init__() 22 | self.data_csv = pd.read_csv(csv_path) 23 | self.slide_ids_avl = self.data_csv['slide_id'].values 24 | self.slide_cls_ids = [[] for i in range(len(label_dict))] 25 | 26 | self.data_dir = data_dir 27 | self.label_dict = label_dict 28 | 29 | self.slide_data = [] 30 | self.slide_label = [] 31 | 32 | for i in range(len(label_ids)): 33 | if slide_ids[i] in self.slide_ids_avl: 34 | self.slide_data.append(slide_ids[i]) 35 | self.slide_label.append(self.label_dict[label_ids[i]]) 36 | self.slide_cls_ids[self.label_dict[label_ids[i]]].append(slide_ids[i]) 37 | assert len(self.slide_data)==len(self.slide_label) 38 | 39 | def __len__(self): 40 | return 
len(self.slide_data) 41 | 42 | def get_label(self,idx): 43 | label = self.slide_label[idx] 44 | return label 45 | 46 | 47 | def __getitem__(self, idx): 48 | slide_id = self.slide_data[idx] 49 | label = self.slide_label[idx] 50 | 51 | feat_path = os.path.join(self.data_dir,f'{slide_id}.h5') 52 | with h5py.File(feat_path,'r') as hdf5_file: 53 | features = hdf5_file['features'][:] # num_patches,84,1024 54 | coords = hdf5_file['coords'][:] # num_patches,2 55 | 56 | features = torch.from_numpy(features) 57 | 58 | 59 | return features,coords,label 60 | 61 | # generate dataset with batchsize exceeding 1 62 | class Wsi_Dataset_mb(Dataset): 63 | def __init__(self,slide_ids,label_ids,csv_path,data_dir,label_dict): 64 | super(Wsi_Dataset_mb,self).__init__() 65 | self.data_csv = pd.read_csv(csv_path) 66 | self.slide_ids_avl = self.data_csv['slide_id'].values 67 | self.slide_cls_ids = [[] for i in range(len(label_dict))] 68 | 69 | self.data_dir = data_dir 70 | self.label_dict = label_dict 71 | 72 | self.slide_data = [] 73 | self.slide_label = [] 74 | 75 | for i in range(len(label_ids)): 76 | if slide_ids[i] in self.slide_ids_avl: 77 | self.slide_data.append(slide_ids[i]) 78 | self.slide_label.append(self.label_dict[label_ids[i]]) 79 | self.slide_cls_ids[self.label_dict[label_ids[i]]].append(slide_ids[i]) 80 | assert len(self.slide_data)==len(self.slide_label) 81 | 82 | 83 | def __len__(self): 84 | return len(self.slide_data) 85 | 86 | def get_label(self,idx): 87 | label = self.slide_label[idx] 88 | return label 89 | 90 | def __getitem__(self, idx): 91 | slide_id = self.slide_data[idx] 92 | label = self.slide_label[idx] 93 | 94 | feat_path = os.path.join(self.data_dir,f'{slide_id}.h5') 95 | 96 | return feat_path,label # return path instead of specific data 97 | 98 | 99 | ## for cascade predict, no labels 100 | class Wsi_Dataset_pred(Dataset): 101 | def __init__(self,slide_ids,csv_path,data_dir): 102 | super(Wsi_Dataset_pred,self).__init__() 103 | self.data_csv = pd.read_csv(csv_path) 104 | self.slide_ids_avl = self.data_csv['slide_id'].values 105 | 106 | self.data_dir = data_dir 107 | 108 | #self.slide_data = self.slide_data[self.slide_data['slide_id'].isin(sample_ids)].reset_index(drop=True) 109 | self.slide_data = [] 110 | 111 | for i in range(len(slide_ids)): 112 | if slide_ids[i] in self.slide_ids_avl: 113 | self.slide_data.append(slide_ids[i]) 114 | 115 | def __len__(self): 116 | return len(self.slide_data) 117 | 118 | 119 | def __getitem__(self, idx): 120 | slide_id = self.slide_data[idx] 121 | 122 | feat_path = os.path.join(self.data_dir,f'{slide_id}.h5') 123 | with h5py.File(feat_path,'r') as hdf5_file: 124 | features = hdf5_file['features'][:] # num_patches,84,1024 125 | coords = hdf5_file['coords'][:] # num_patches,2 126 | 127 | features = torch.from_numpy(features) 128 | 129 | return features,coords -------------------------------------------------------------------------------- /ROAM/dataset/vis_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import pandas as pd 3 | import numpy as np 4 | import time 5 | import pdb 6 | import PIL.Image as Image 7 | import h5py 8 | import openslide 9 | from torch.utils.data import Dataset 10 | import torch 11 | from wsi_core.util_classes import Contour_Checking_fn, isInContourV1, isInContourV2, isInContourV3_Easy, isInContourV3_Hard 12 | import vahadane 13 | 14 | TARGET_IMAGE_DIR = 'visheatmaps/target_image_6e3_256.jpg' 15 | TARGET_IMAGE_DIR2 = 
'visheatmaps/target_roi_6e3.jpg' 16 | 17 | mean = (0.485, 0.456, 0.406) 18 | std = (0.229, 0.224, 0.225) 19 | transform_patch = transforms.Compose( 20 | [# may be other transform 21 | transforms.ToTensor(), 22 | transforms.Normalize(mean = mean, std = std) 23 | ] 24 | ) 25 | 26 | def default_transforms(mean = (0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 27 | t = transforms.Compose( 28 | [transforms.ToTensor(), 29 | transforms.Normalize(mean = mean, std = std)]) 30 | return t 31 | 32 | def get_contour_check_fn(contour_fn='four_pt_hard', cont=None, ref_patch_size=None, center_shift=None): 33 | if contour_fn == 'four_pt_hard': 34 | cont_check_fn = isInContourV3_Hard(contour=cont, patch_size=ref_patch_size, center_shift=center_shift) 35 | elif contour_fn == 'four_pt_easy': 36 | cont_check_fn = isInContourV3_Easy(contour=cont, patch_size=ref_patch_size, center_shift=0.5) 37 | elif contour_fn == 'center': 38 | cont_check_fn = isInContourV2(contour=cont, patch_size=ref_patch_size) 39 | elif contour_fn == 'basic': 40 | cont_check_fn = isInContourV1(contour=cont) 41 | else: 42 | raise NotImplementedError 43 | return cont_check_fn 44 | 45 | 46 | 47 | class Wsi_Region(Dataset): 48 | ''' 49 | args: 50 | wsi_object: instance of WholeSlideImage wrapper over a WSI 51 | top_left: tuple of coordinates representing the top left corner of WSI region (Default: None) 52 | bot_right tuple of coordinates representing the bot right corner of WSI region (Default: None) 53 | level: downsample level at which to prcess the WSI region 54 | patch_size: tuple of width, height representing the patch size 55 | step_size: tuple of w_step, h_step representing the step size 56 | contour_fn (str): 57 | contour checking fn to use 58 | choice of ['four_pt_hard', 'four_pt_easy', 'center', 'basic'] (Default: 'four_pt_hard') 59 | t: custom torchvision transformation to apply 60 | custom_downsample (int): additional downscale factor to apply 61 | use_center_shift: for 'four_pt_hard' contour check, how far out to shift the 4 points 62 | ''' 63 | def __init__(self, wsi_object, slide_path, top_left=None, bot_right=None, level=0, 64 | patch_size = (4096, 4096), step_size=(512, 512), 65 | target_roi_size = (2048,2048),target_patch_size = (256,256),contour_fn='four_pt_easy', 66 | t=None, custom_downsample=1, use_center_shift=False, 67 | is_stain_norm = True, target_image_dir = None): 68 | 69 | self.custom_downsample = custom_downsample 70 | self.roi_level = level 71 | self.slide_path = slide_path 72 | #print('cont_fn:',contour_fn) 73 | 74 | # downscale factor in reference to level 0 75 | self.ref_downsample = wsi_object.level_downsamples[level] 76 | # patch size in reference to level 0 77 | self.ref_size = tuple((np.array(patch_size) * np.array(self.ref_downsample)).astype(int)) 78 | 79 | if self.custom_downsample > 1: 80 | self.target_patch_size = patch_size 81 | patch_size = tuple((np.array(patch_size) * np.array(self.ref_downsample) * custom_downsample).astype(int)) 82 | step_size = tuple((np.array(step_size) * custom_downsample).astype(int)) 83 | self.ref_size = patch_size 84 | else: 85 | step_size = tuple((np.array(step_size)).astype(int)) 86 | self.ref_size = tuple((np.array(patch_size) * np.array(self.ref_downsample)).astype(int)) 87 | 88 | self.wsi = wsi_object.wsi 89 | self.level = level 90 | self.patch_size = target_patch_size 91 | self.target_roi_size = target_roi_size 92 | self.roi_size = patch_size 93 | self.levels = [0,1,2] 94 | #self.patch_nums = 84 95 | 96 | if not use_center_shift: 97 | center_shift = 0. 
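# [Editor's note, illustrative only] The else-branch below maps the patch overlap to the
# center_shift used by the 'four_pt_hard' contour check. With the constructor defaults
# (patch_size=(4096, 4096), step_size=(512, 512)), overlap = 1 - 512/4096 = 0.875,
# which falls in [0.75, 0.95) and therefore selects center_shift = 0.625.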
98 | else: 99 | overlap = 1 - float(step_size[0] / patch_size[0]) 100 | if overlap < 0.25: 101 | center_shift = 0.375 102 | elif overlap >= 0.25 and overlap < 0.75: 103 | center_shift = 0.5 104 | elif overlap >=0.75 and overlap < 0.95: 105 | center_shift = 0.625 106 | else: 107 | center_shift = 1.0 108 | #center_shift = 0.375 # 25% overlap 109 | #center_shift = 0.625 #50%, 75% overlap 110 | #center_shift = 1.0 #95% overlap 111 | 112 | # print(f'=========step_size:{step_size[0]}') 113 | # print(f'=========patch_size:{patch_size[0]}') 114 | filtered_coords = [] 115 | #iterate through tissue contours for valid patch coordinates 116 | for cont_idx, contour in enumerate(wsi_object.contours_tissue): 117 | print('processing {}/{} contours'.format(cont_idx, len(wsi_object.contours_tissue))) 118 | cont_check_fn = get_contour_check_fn(contour_fn, contour, self.ref_size[0], center_shift) 119 | #print(wsi_object.holes_tissue) 120 | #print(wsi_object.holes_tissue[cont_idx]) 121 | coord_results, _ = wsi_object.process_contour(contour, wsi_object.holes_tissue[cont_idx], level, '', 122 | patch_size = patch_size[0], step_size = step_size[0], contour_fn=cont_check_fn, 123 | use_padding=True, top_left = top_left, bot_right = bot_right) 124 | if len(coord_results) > 0: 125 | filtered_coords.append(coord_results['coords']) 126 | 127 | #print(filtered_coords) 128 | #print(len(filtered_coords)) 129 | coords=np.vstack(filtered_coords) 130 | 131 | self.coords = coords 132 | print('filtered a total of {} coordinates'.format(len(self.coords))) 133 | 134 | #target_image_dir = 'visheatmaps/target_roi_6e3.jpg' 135 | #print(target_image_dir) 136 | if is_stain_norm: 137 | self.target_img = np.array(Image.open(target_image_dir)) 138 | #print(self.target_img.shape) 139 | ## may raise muliti-process problem 140 | self.vhd = vahadane.vahadane(LAMBDA1=0.01,LAMBDA2=0.01,fast_mode=0,ITER=100) 141 | self.Wt,self.Ht = self.vhd.stain_separate(self.target_img) 142 | #self.vhd.fast_mode = 1 #fast separate 143 | self.is_stain_norm = is_stain_norm 144 | 145 | def __len__(self): 146 | return len(self.coords) 147 | 148 | def stain_norm(self,src_img): 149 | #print(src_img.shape) 150 | #Image.fromarray(src_img).save('test.jpg') 151 | #vhd = vahadane.vahadane(LAMBDA1=0.01,LAMBDA2=0.01,fast_mode=0,ITER=100) 152 | 153 | Ws,Hs = self.vhd.stain_separate(src_img) 154 | #print(src_img.shape) 155 | #print(Ws,Hs,self.Wt,self.Ht) 156 | img = self.vhd.SPCN(src_img,Ws,Hs,self.Wt,self.Ht) 157 | return img 158 | 159 | def __getitem__(self, idx): 160 | coord = self.coords[idx] 161 | 162 | try: 163 | img = self.wsi.read_region(coord, self.roi_level, (self.roi_size[0], self.roi_size[1])).convert('RGB') 164 | except: 165 | # or subsequent normal patches will also raise errors 166 | self.wsi = openslide.open_slide(self.slide_path) 167 | available = False 168 | #img = np.zeros((self.roi_size,self.patch_size,3)) 169 | #return img, coord, torch.tensor([False]) 170 | else: 171 | img = self.wsi.read_region(coord, self.roi_level, (self.roi_size[0], self.roi_size[1])).convert('RGB') 172 | available = True 173 | 174 | #patch_num_all = np.sum(self.patch_nums) 175 | patch_num_all = 84 176 | if not available: 177 | img_batch = torch.zeros((patch_num_all,3,self.patch_size[0],self.patch_size[0])) 178 | #print('????????') 179 | else: 180 | img_batch = [] 181 | img_roi = img.resize((self.target_roi_size[0],self.target_roi_size[1])) 182 | #img_roi.save('test_roi.jpg') 183 | if self.is_stain_norm: 184 | img_roi = self.stain_norm(np.array(img_roi)) 185 | 
#Image.fromarray(img_roi).save('test_roi_pro.jpg') 186 | for level in self.levels: 187 | roi_size_cur = int(self.target_roi_size[0]/(2**level)) 188 | img_roi = np.array(img_roi) 189 | img_cur = Image.fromarray(img_roi).resize((roi_size_cur,roi_size_cur)) 190 | 191 | imgarray = np.array(img_cur) 192 | for i in range(0,roi_size_cur,self.patch_size[0]): 193 | for j in range(0,roi_size_cur,self.patch_size[1]): 194 | img_patch = imgarray[i:i+self.patch_size[0],j:j+self.patch_size[1],:] 195 | img_patch = transform_patch(img_patch) 196 | img_batch.append(img_patch) 197 | 198 | #img_batch = torch.stack(img_batch).unsqueeze(0) #(1,84,3,256,256) 199 | if available: 200 | img_batch = torch.stack(img_batch) 201 | #print('img_batch_shape:',img_batch.shape) 202 | 203 | 204 | return img_batch, coord, torch.tensor([available]) 205 | -------------------------------------------------------------------------------- /ROAM/gen_visheatmaps_roi_batch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from secrets import choice 3 | import sys 4 | import os 5 | import torch 6 | import torch.nn as nn 7 | from PIL import Image 8 | from torchvision import transforms 9 | import numpy as np 10 | import pandas as pd 11 | import cv2 12 | import vahadane 13 | import matplotlib.pyplot as plt 14 | 15 | from vis_utils.vit_rollout import VITAttentionRollout 16 | from vis_utils.vit_grad_rollout import VITAttentionGradRollout 17 | from models.ROAM import ROAM_VIS 18 | from models_embed.extractor import resnet50 19 | import models_embed.ResNet as ResNet 20 | from models_embed.ccl import CCL 21 | from models_embed.ctran import ctranspath 22 | from models_embed.simclr_ciga import simclr_ciga_model 23 | from parse_config import parse_args_heatmap_roi 24 | 25 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 26 | 27 | 28 | cls_name_dict = { 29 | 'int_glioma_tumor_subtyping':['astrocytoma','oligodendroglioma','ependymoma'], 30 | 'ext_glioma_tumor_subtyping3':['astrocytoma','oligodendroglioma'], 31 | 'int_glioma_cls':['normal','gliosis','tumor'] 32 | } 33 | 34 | 35 | ## for stain normalization 36 | mean = (0.485, 0.456, 0.406) 37 | std = (0.229, 0.224, 0.225) 38 | transform_patch = transforms.Compose( 39 | [# may be other transform 40 | transforms.ToTensor(), 41 | transforms.Normalize(mean = mean, std = std) 42 | ] 43 | ) 44 | target_img = np.array(Image.open('visheatmaps/target_roi_6e3.jpg')) 45 | vhd = vahadane.vahadane(LAMBDA1=0.01,LAMBDA2=0.01,fast_mode=0,ITER=100) 46 | Wt,Ht = vhd.stain_separate(target_img) 47 | 48 | 49 | 50 | def preprocess_image(src_img): 51 | target_roi_size = 2048 52 | patch_size = 256 53 | img_batch = [] 54 | img_roi = src_img.resize((target_roi_size,target_roi_size)) 55 | # stain normalization 56 | Ws,Hs = vhd.stain_separate(np.array(img_roi)) 57 | img = vhd.SPCN(np.array(img_roi),Ws,Hs,Wt,Ht) 58 | 59 | # segment roi to 84 patches with size of 256 60 | for size in [2048,1024,512]: 61 | img_cur = Image.fromarray(img).resize((size,size)) 62 | img_array = np.array(img_cur) 63 | for i in range(0,size,patch_size): 64 | for j in range(0,size,patch_size): 65 | img_patch = img_array[i:i+patch_size,j:j+patch_size,:] 66 | img_patch = transform_patch(img_patch) 67 | img_batch.append(img_patch) 68 | 69 | img_batch = torch.stack(img_batch) #84,3,256,256 70 | return img_batch 71 | 72 | def masks_preprocess(masks,img,w): 73 | ''' 74 | Sum the attention masks of three magnification levels, weighted by the model's 'embed_weights' 75 | parameters, to obtain teh final attention 
heatmap of the ROI. 76 | 77 | Args: 78 | masks (List of array): attention masks of 3 magnifications (20x:(8,8), 10x:(4,4),5x(2,2)) 79 | img (PIL.Image): ROI image to be visualized. 80 | w (List of float): attention weights of 3 magnifications, same with parameters 'embed_weights' during training. 81 | ''' 82 | np_img = np.array(img)[:, :, ::-1] #RGB->BGR 83 | mask_weighted = 0 84 | # multi-scale 85 | for i in range(3): 86 | mask = masks[i] 87 | mask = cv2.resize(mask, (np_img.shape[1], np_img.shape[0])) 88 | mask_weighted += w[i]*mask 89 | 90 | masked_img = show_mask_on_image_weighted(np_img, mask_weighted) 91 | 92 | return masked_img,mask_weighted 93 | 94 | 95 | 96 | 97 | 98 | def show_mask_on_image(img, mask): 99 | img = np.float32(img) / 255 100 | 101 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) 102 | heatmap = np.float32(heatmap) / 255 103 | #print(heatmap.shape) 104 | cam = heatmap + np.float32(img) 105 | 106 | cam = cam / np.max(cam) 107 | return np.uint8(255 * cam) 108 | 109 | def show_mask_on_image_weighted(img, mask, alpha=0.3): 110 | ''' 111 | draw heatmaps according to attention scores 112 | ''' 113 | cmap = plt.get_cmap('jet') 114 | 115 | mask_block = (cmap(mask) * 255)[:,:,:3].astype(np.uint8)[:,:,::-1] 116 | cam = cv2.addWeighted(mask_block,alpha,img,1-alpha,0) 117 | 118 | return cam 119 | 120 | 121 | if __name__ == '__main__': 122 | assert len(sys.argv)==3, 'please give configuration file and split seed!' 123 | 124 | 125 | args, task_info = parse_args_heatmap_roi(sys.argv[1], sys.argv[2]) 126 | if args.level == -1: 127 | args.level = args.depths[-1] 128 | ## embed_model 129 | print('load embedding model') 130 | if args.embed_type == 'ImageNet': 131 | embed_model = resnet50(pretrained=True).cuda() 132 | patch_dim = 1024 133 | elif args.embed_type == 'RetCCL': 134 | backbone = ResNet.resnet50 135 | embed_model = CCL(backbone, 128, 65536, mlp=True, two_branch=True, normlinear=True).cuda() 136 | ckpt_path = f'models_embed/RetCCL_ckpt.pth' 137 | embed_model.load_state_dict(torch.load(ckpt_path),strict=True) 138 | embed_model.encoder_q.fc = nn.Identity() 139 | embed_model.encoder_q.instDis = nn.Identity() 140 | embed_model.encoder_q.groupDis = nn.Identity() 141 | patch_dim = 2048 142 | elif args.embed_type == 'ctranspath': 143 | embed_model = ctranspath().cuda() 144 | embed_model.head = nn.Identity() 145 | td = torch.load(r'models_embed/ctranspath.pth') 146 | embed_model.load_state_dict(td['model'], strict=True) 147 | patch_dim = 768 148 | else: 149 | embed_model = simclr_ciga_model().cuda() 150 | patch_dim = 1024 151 | 152 | embed_model.eval() 153 | 154 | 155 | if args.embed_weightx5==None and args.embed_weightx10==None and args.embed_weightx20==None: 156 | embed_weights = None 157 | print('use learnabel weights') 158 | else: 159 | embed_weights = [args.embed_weightx5,args.embed_weightx10,args.embed_weightx20] 160 | print('set weights:', embed_weights) 161 | ## main model ROAM 162 | print('load main model ROAM') 163 | 164 | model = ROAM_VIS(choose_num = args.topk, 165 | num_patches = 84, 166 | patch_dim=patch_dim, 167 | num_classes=task_info[args.task]['n_classes'], 168 | roi_level = args.roi_level, 169 | scale_type = args.scale_type, 170 | embed_weights=embed_weights, 171 | dim=args.dim, 172 | depths=args.depths, 173 | heads=args.heads, 174 | mlp_dim=args.mlp_dim, 175 | dim_head=args.dim_head, 176 | dropout=args.dropout, 177 | emb_dropout=args.emb_dropout, 178 | attn_dropout=args.attn_dropout, 179 | pool=args.pool, 180 | ape=args.ape, 181 | 
attn_type=args.attn_type, 182 | shared_pe=args.shared_pe) 183 | model = model.cuda() 184 | 185 | 186 | print('exp_code:',args.exp_code) 187 | 188 | 189 | # read topk roi list 190 | img_root = f'visheatmaps/slide_vis/results/heatmap_production_results/{args.exp_code}/sampled_patches' 191 | slide_list = pd.read_csv(args.process_list) 192 | slide_ids = slide_list['slide_id'].values 193 | labels = slide_list['label'].values 194 | cls_name = cls_name_dict[args.task] 195 | 196 | for i in range(len(slide_ids)): 197 | sid = slide_ids[i] 198 | label = labels[i] 199 | catname = cls_name[label] 200 | print(f'===process topk roi of slide: {sid}') 201 | 202 | imagelist = os.listdir(f'{img_root}/pred_{label}_label_{label}/topk_high_attention') 203 | 204 | topidx = 0 205 | 206 | for imgpath in sorted(imagelist): 207 | img_name = str(topidx) + '_' + sid 208 | if img_name not in imgpath: continue 209 | topidx += 1 210 | print(f'topkidx:{topidx}') 211 | if topidx > args.topk_num: break 212 | mask_all = [] 213 | print(f'===process topk roi {topidx}/{args.topk_num}') 214 | img_save_path = f'visheatmaps/roi_vis/{args.sample}/{args.exp_code}/{args.vis_type}/{catname}/{sid}' 215 | if not os.path.exists(img_save_path): 216 | os.makedirs(img_save_path) 217 | image_path = os.path.join(f'{img_root}/pred_{label}_label_{label}/topk_high_attention',imgpath) 218 | 219 | ## read origin image 220 | img = Image.open(image_path) 221 | img_batch = preprocess_image(img) 222 | 223 | ## extract features 224 | input = img_batch.cuda() 225 | for k in range(5): 226 | print(f'split{k}') 227 | model_path = os.path.join(f'results/{args.task}/{args.exp_code}/{args.split_seed}',f'ROAM_split{k}.pth') 228 | #print('model_path:',model_path) 229 | model.load_state_dict(torch.load(model_path)) 230 | model.eval() 231 | print('done!') 232 | 233 | features = embed_model(input) #84,2048 234 | input_tensor = features.unsqueeze(0) #1,84,2048 235 | #print(input_tensor.device) 236 | 237 | ## get attention grad 238 | 239 | if args.category_index is None: 240 | print("Doing Attention Rollout") 241 | attention_rollout = VITAttentionRollout(model, head_fusion=args.head_fusion, 242 | discard_ratio=args.discard_ratio) 243 | masks = attention_rollout(input_tensor) 244 | name = "attention_rollout_{:.3f}_{}.png".format(args.discard_ratio, args.head_fusion) 245 | else: 246 | print("Doing Gradient Attention Rollout") 247 | if embed_weights: 248 | grad_rollout = VITAttentionGradRollout(model, args.level, discard_ratio=args.discard_ratio,vis_type=args.vis_type,vis_scale=args.vis_scale) 249 | w = embed_weights 250 | else: 251 | grad_rollout,w = VITAttentionGradRollout(model, args.level, discard_ratio=args.discard_ratio,vis_type=args.vis_type,vis_scale=args.vis_scale,learnable_weights=True) 252 | masks = grad_rollout(input_tensor, args.category_index) 253 | name = "grad_rollout_{}_{:.3f}_{}.png".format(args.category_index, 254 | args.discard_ratio, args.head_fusion) 255 | 256 | 257 | print('embed_weights:',w) 258 | masked_img,mask = masks_preprocess(masks,img,w) 259 | 260 | if args.vis_type == 'grad_rollout': 261 | Image.fromarray(masked_img[:,:,::-1]).convert('RGB').resize((1024,1024)).save(os.path.join(img_save_path,f'top{topidx}_seed{args.split_seed}_{args.category_index}_d{args.depths[-1]}_l{args.level}_r{args.discard_ratio}_fold{k}.png')) 262 | else: 263 | Image.fromarray(masked_img[:,:,::-1]).convert('RGB').resize((1024,1024)).save(os.path.join(img_save_path,f'top{topidx}_seed{args.split_seed}_{args.category_index}_d{args.depths[-1]}_fold{k}.png')) 264 | 
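# [Editor's note] Summary of the steps below: each of the 5 cross-validation folds
# contributes one attention mask; masks containing NaN values are discarded, the
# remainder are averaged and max-normalized, and the result is overlaid on the original
# ROI to produce the per-seed "avg" heatmap saved alongside the per-fold heatmaps.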
265 | ## exclude mask with nan value 266 | if not np.isnan(mask).any(): 267 | mask_all.append(mask) 268 | 269 | ## average attention mask for 5 splits in each seed 270 | mask_avg = np.mean(mask_all,0) 271 | 272 | mask_avg = mask_avg/np.max(mask_avg) #normalization 273 | 274 | 275 | np_img = np.array(img)[:, :, ::-1] #RGB->BGR 276 | masked_img_avg = show_mask_on_image_weighted(np_img, mask_avg) 277 | 278 | if args.vis_type == 'grad_rollout': 279 | Image.fromarray(masked_img_avg[:,:,::-1]).convert('RGB').resize((1024,1024)).save(os.path.join(img_save_path,f'top{topidx}_seed{args.split_seed}_{args.category_index}_d{args.depths[-1]}_l{args.level}_r{args.discard_ratio}_avg.png')) 280 | else: 281 | Image.fromarray(masked_img_avg[:,:,::-1]).convert('RGB').resize((1024,1024)).save(os.path.join(img_save_path,f'top{topidx}_seed{args.split_seed}_{args.category_index}_d{args.depths[-1]}_avg.png')) 282 | img.resize((1024,1024)).save(os.path.join(img_save_path,f'top{topidx}_seed{args.split_seed}_ori.png')) -------------------------------------------------------------------------------- /ROAM/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/models/__init__.py -------------------------------------------------------------------------------- /ROAM/models_embed/ccl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torchvision.models as models 5 | from PIL import Image 6 | import os 7 | 8 | 9 | class CCL(nn.Module): 10 | def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False, two_branch=False, normlinear=False, normalize=False): 11 | super(CCL, self).__init__() 12 | 13 | self.K = K 14 | self.m = m 15 | self.T = T 16 | self.two_branch = two_branch 17 | self.normalize = normalize 18 | 19 | # create the encoders 20 | # num_classes is the output fc dimension 21 | self.encoder_q = base_encoder(num_classes=dim, two_branch=two_branch, mlp=mlp, normlinear=normlinear) 22 | self.encoder_k = base_encoder(num_classes=dim, two_branch=two_branch, mlp=mlp, normlinear=normlinear) 23 | 24 | if mlp and not two_branch: # hack: brute-force replacement 25 | dim_mlp = self.encoder_q.fc.weight.shape[1] 26 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc) 27 | self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc) 28 | 29 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 30 | param_k.data.copy_(param_q.data) # initialize 31 | param_k.requires_grad = False # not update by gradient 32 | 33 | def forward(self, im_q): 34 | # compute query features 35 | q = self.encoder_q(im_q) # queries: NxC 36 | if self.two_branch: 37 | eq1 = nn.functional.normalize(q[1], dim=1) # branch 2 38 | q = q[0] # branch 1 39 | if self.normalize: 40 | print(1) 41 | q = nn.functional.normalize(q, dim=1) 42 | return q -------------------------------------------------------------------------------- /ROAM/models_embed/ctran.py: -------------------------------------------------------------------------------- 1 | from timm.models.layers.helpers import to_2tuple 2 | import timm 3 | import torch.nn as nn 4 | 5 | 6 | class ConvStem(nn.Module): 7 | 8 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 9 | 
super().__init__() 10 | 11 | assert patch_size == 4 12 | assert embed_dim % 8 == 0 13 | 14 | img_size = to_2tuple(img_size) 15 | patch_size = to_2tuple(patch_size) 16 | self.img_size = img_size 17 | self.patch_size = patch_size 18 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 19 | self.num_patches = self.grid_size[0] * self.grid_size[1] 20 | self.flatten = flatten 21 | 22 | 23 | stem = [] 24 | input_dim, output_dim = 3, embed_dim // 8 25 | for l in range(2): 26 | stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False)) 27 | stem.append(nn.BatchNorm2d(output_dim)) 28 | stem.append(nn.ReLU(inplace=True)) 29 | input_dim = output_dim 30 | output_dim *= 2 31 | stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1)) 32 | self.proj = nn.Sequential(*stem) 33 | 34 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 35 | 36 | def forward(self, x): 37 | B, C, H, W = x.shape 38 | assert H == self.img_size[0] and W == self.img_size[1], \ 39 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 40 | x = self.proj(x) 41 | if self.flatten: 42 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 43 | x = self.norm(x) 44 | return x 45 | 46 | def ctranspath(): 47 | model = timm.create_model('swin_tiny_patch4_window7_224', embed_layer=ConvStem, pretrained=False) 48 | return model -------------------------------------------------------------------------------- /ROAM/models_embed/extractor.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | import torch 4 | from torchsummary import summary 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152'] 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | class Bottleneck_Baseline(nn.Module): 19 | expansion = 4 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None): 22 | super(Bottleneck_Baseline, self).__init__() 23 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 26 | padding=1, bias=False) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 29 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv3(out) 46 | out = self.bn3(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | class ResNet_Baseline(nn.Module): 57 | 58 | def __init__(self, block, layers): 59 | self.inplanes = 64 60 | super(ResNet_Baseline, self).__init__() 61 
| self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 66 | self.layer1 = self._make_layer(block, 64, layers[0]) 67 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 68 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 69 | self.avgpool = nn.AdaptiveAvgPool2d(1) 70 | 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 74 | elif isinstance(m, nn.BatchNorm2d): 75 | nn.init.constant_(m.weight, 1) 76 | nn.init.constant_(m.bias, 0) 77 | 78 | def _make_layer(self, block, planes, blocks, stride=1): 79 | downsample = None 80 | if stride != 1 or self.inplanes != planes * block.expansion: 81 | downsample = nn.Sequential( 82 | nn.Conv2d(self.inplanes, planes * block.expansion, 83 | kernel_size=1, stride=stride, bias=False), 84 | nn.BatchNorm2d(planes * block.expansion), 85 | ) 86 | 87 | layers = [] 88 | layers.append(block(self.inplanes, planes, stride, downsample)) 89 | self.inplanes = planes * block.expansion 90 | for i in range(1, blocks): 91 | layers.append(block(self.inplanes, planes)) 92 | 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | x = self.conv1(x) 97 | x = self.bn1(x) 98 | x = self.relu(x) 99 | x = self.maxpool(x) 100 | 101 | x = self.layer1(x) 102 | x = self.layer2(x) 103 | x = self.layer3(x) 104 | 105 | x = self.avgpool(x) 106 | x = x.view(x.size(0), -1) 107 | 108 | return x 109 | 110 | def resnet50(pretrained=False,pretrained_weights=None): 111 | """Constructs a Modified ResNet-50 model. 112 | Args: 113 | pretrained (bool): If True, returns a model pre-trained on ImageNet 114 | pretrained_weights: not None for pathological image 115 | """ 116 | model = ResNet_Baseline(Bottleneck_Baseline, [3, 4, 6, 3]) 117 | if pretrained: 118 | if pretrained_weights: 119 | model.load_state_dict(pretrained_weights) 120 | else: 121 | model = load_pretrained_weights(model, 'resnet50') 122 | return model 123 | 124 | def load_pretrained_weights(model, name): 125 | pretrained_dict = model_zoo.load_url(model_urls[name]) 126 | model.load_state_dict(pretrained_dict, strict=False) 127 | return model 128 | 129 | -------------------------------------------------------------------------------- /ROAM/models_embed/simclr_ciga.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torch 3 | 4 | 5 | MODEL_PATH = 'models/Simclr_ciga.ckpt' 6 | RETURN_PREACTIVATION = True # return features from the model, if false return classification logits 7 | NUM_CLASSES = 4 # only used if RETURN_PREACTIVATION = False 8 | 9 | 10 | def load_model_weights(model, weights): 11 | 12 | model_dict = model.state_dict() 13 | weights = {k: v for k, v in weights.items() if k in model_dict} 14 | if weights == {}: 15 | print('No weight could be loaded..') 16 | model_dict.update(weights) 17 | model.load_state_dict(model_dict) 18 | 19 | return model 20 | 21 | def simclr_ciga_model(): 22 | model = torchvision.models.__dict__['resnet18'](pretrained=False) 23 | 24 | state = torch.load(MODEL_PATH, map_location='cuda:0') 25 | 26 | state_dict = state['state_dict'] 27 | for key in list(state_dict.keys()): 28 | state_dict[key.replace('model.', '').replace('resnet.', '')] = state_dict.pop(key) 29 | 30 | model = load_model_weights(model, state_dict) 31 | 32 | if RETURN_PREACTIVATION: 
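# [Editor's note] In the branch below, replacing model.fc with an empty nn.Sequential
# turns the final layer into an identity mapping, so the ResNet-18 backbone returns its
# pooled pre-activation feature vector rather than classification logits; the else-branch
# instead attaches a fresh linear head with NUM_CLASSES outputs.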
33 | model.fc = torch.nn.Sequential() 34 | else: 35 | model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES) 36 | 37 | return model 38 | -------------------------------------------------------------------------------- /ROAM/parse_config.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import time 3 | 4 | class Args: 5 | def __init__(self): 6 | pass 7 | 8 | def parse_args(config_dir, split_seed): 9 | args = Args 10 | cf = configparser.ConfigParser() 11 | cf.read(config_dir) 12 | task = cf.sections()[0] 13 | print(f'Current task: {task}, Configuration file: {config_dir}') 14 | 15 | args.task = task 16 | 17 | #### 18 | # random seed of the model for training 19 | args.seed = cf.getint(task, 'seed') 20 | # random seed for spliting training data (s0~s4) 21 | args.split_seed = split_seed 22 | # stage of the process, train or test 23 | args.stage = cf.get(task, 'stage') 24 | # pre-trained model for extracting patch features: ImageNet(default), RetCCL, ctraspath, simclr-ciga 25 | args.embed_type = cf.get(task, 'embed_type') 26 | # percentage of data used for training: 20%,40%,60%,100%(default) 27 | args.sample_size = cf.getint(task, 'sample_size') 28 | # whether to use stain normalization: True(default) or False 29 | args.not_stainnorm = cf.getboolean(task, 'not_stainnorm') 30 | # type of test dataset: xiangya(in-house validation), TCGA(external validation) 31 | args.test_dataset = cf.get(task, 'test_dataset') 32 | # root directory of feature data 33 | args.data_root_dir = cf.get(task,'data_root_dir') 34 | # root directory to save all results 35 | args.results_dir = cf.get(task, 'results_dir') 36 | 37 | #### training 38 | # the maxmium epochs of training epochs allowed 39 | args.max_epochs = cf.getint(task, 'max_epochs') 40 | # size of a training batch 41 | args.batch_size = cf.getint(task, 'batch_size') 42 | args.lr = cf.getfloat(task, 'lr') 43 | args.optimizer = cf.get(task, 'optimizer') 44 | args.weight_decay = cf.getfloat(task, 'weight_decay') 45 | args.scheduler = cf.get(task, 'scheduler') 46 | # early stop when metrics have not improved for certain epochs 47 | args.stop_epochs = cf.getint(task, 'stop_epochs') 48 | # sampling weights for each class in the dataloader, for class imbalance 49 | args.weighted_sample = cf.getboolean(task, 'weighted_sample') 50 | args.emb_dropout = cf.getfloat(task, 'emb_dropout') 51 | args.attn_dropout = cf.getfloat(task, 'attn_dropout') 52 | args.dropout = cf.getfloat(task, 'dropout') 53 | 54 | #### ROAM specific options 55 | # name of the model 56 | args.model_type = cf.get(task, 'model_type') 57 | args.roi_dropout = cf.getboolean(task, 'roi_dropout') 58 | args.roi_supervise = cf.getboolean(task, 'roi_supervise') 59 | args.roi_weight = cf.getfloat(task, 'roi_weight') 60 | # the number of instances used in instance-level supervision 61 | args.topk = cf.getint(task, 'topk') 62 | # size of ROI at 20x, (0:2048,1:1024,2:512) 63 | args.roi_level = cf.getint(task, 'roi_level') 64 | # multi-scale ('ms') or single scale ('ss') 65 | args.scale_type = cf.get(task, 'scale_type') 66 | # magnification scale of input ROI for single-scale model. 
(0:20x,1:10x,2:5x) 67 | args.single_level = cf.getint(task, 'single_level') 68 | # weight coefficient of instance embedding at each magnificant level 69 | args.embed_weightx5 = eval(cf.get(task, 'embed_weightx5')) 70 | args.embed_weightx10 = eval(cf.get(task, 'embed_weightx10')) 71 | args.embed_weightx20 = eval(cf.get(task, 'embed_weightx20')) 72 | # whether to use inter-scale self-attention module. False (with inter-scale SA), True (without inter-scale SA) 73 | args.not_interscale = cf.getboolean(task, 'not_interscale') 74 | # model config 75 | args.dim = cf.getint(task, 'dim') 76 | args.depths = eval(cf.get(task, 'depths')) 77 | args.heads = cf.getint(task, 'heads') 78 | args.mlp_dim = cf.getint(task, 'mlp_dim') 79 | args.dim_head = cf.getint(task, 'dim_head') 80 | args.pool = cf.get(task, 'pool') 81 | args.ape = cf.getboolean(task, 'ape') 82 | args.attn_type = cf.get(task, 'attn_type') 83 | args.shared_pe = cf.getboolean(task, 'shared_pe') 84 | # name of the experiment 85 | args.exp_code = '_'.join(map(str, [args.task, args.depths, 86 | args.embed_type, 87 | args.batch_size, args.roi_dropout, 88 | args.roi_supervise, 89 | args.roi_weight, args.topk, 90 | args.roi_level, 91 | args.scale_type, args.single_level, 92 | args.not_interscale])) 93 | 94 | print('exp_code: {}'.format(args.exp_code)) 95 | # information of all tasks 96 | # example tasks 97 | task_info = { 98 | ## in-house validataion with xiangya test dataset 99 | 'int_glioma_detection':{'csv_path':'../data_prepare/data_csv/xiangya_data_info_pro.csv', 100 | 'label_dict':{0:0,1:1,2:2,3:2,4:2}, 101 | 'n_classes': 3, 102 | 'split_dir': f'../data_prepare/data_split/xiangya_split_detection/xiangya_split_detection_{split_seed}.npy', 103 | 'test_split_dir': '../data_prepare/data_split/xiangya_split_detection/test_split_label_detection.npy', 104 | 'cls_weights':[50,24,514]}, 105 | 106 | 'int_glioma_tumor_subtyping':{'csv_path': '../data_prepare/data_csv/example_xiangya_data_info_pro.csv', 107 | 'label_dict': {i+2:i for i in range(3)}, 108 | 'n_classes': 3, 109 | 'split_dir': f'../data_prepare/data_split/xiangya_split_subtype/xiangya_split_subtype_size{args.sample_size}_{split_seed}.npy', 110 | 'test_split_dir': '../data_prepare/data_split/xiangya_split_subtype/example_test_split.npy', 111 | 'cls_weights':[281,119,111]}, 112 | # external validation with TCGA test dataset 113 | 'ext_glioma_tumor_subtyping3':{'csv_path':'../data_prepare/data_csv/tcga_data_info_pro.csv', 114 | 'label_dict':{3:0,4:0,5:0,7:1,8:1}, 115 | 'n_classes': 2, 116 | 'split_dir': f'../data_prepare/data_split/xiangya_tcga_split_subtype2/xiangya_split_subtype2_{split_seed}.npy', 117 | 'label_dict_ext': {'astrocytoma_G2':0,'astrocytoma_G3':0,'glioblastoma_G4':0,'oligodendroglioma_G2':1,'oligodendroglioma_G3':1}, 118 | 'test_split_dir_ext': '../data_prepare/data_split/xiangya_tcga_split_subtype2/TCGA_test_split_label_subtype2.npy', 119 | 'cls_weights':[261,179]}, 120 | 121 | } 122 | 123 | 124 | return args, task_info 125 | 126 | 127 | def parse_args_heatmap_roi(config_dir, split_seed): 128 | args = Args 129 | cf = configparser.ConfigParser() 130 | cf.read(config_dir) 131 | task = cf.sections()[0] 132 | print(f'Current task: {task}, Configuration file: {config_dir}') 133 | 134 | args.task = task 135 | 136 | #### 137 | args.seed = cf.getint(task, 'seed') 138 | args.split_seed = split_seed 139 | args.embed_type = cf.get(task, 'embed_type') 140 | args.not_stainnorm = cf.getboolean(task, 'not_stainnorm') 141 | args.emb_dropout = cf.getfloat(task, 'emb_dropout') 142 | 
args.attn_dropout = cf.getfloat(task, 'attn_dropout') 143 | args.dropout = cf.getfloat(task, 'dropout') 144 | 145 | args.batch_size = cf.getint(task, 'batch_size') 146 | 147 | #### ROAM specific options 148 | args.model_type = cf.get(task, 'model_type') 149 | args.roi_dropout = cf.getboolean(task, 'roi_dropout') 150 | args.roi_supervise = cf.getboolean(task, 'roi_supervise') 151 | args.roi_weight = cf.getfloat(task, 'roi_weight') 152 | args.topk = cf.getint(task, 'topk') 153 | args.roi_level = cf.getint(task, 'roi_level') 154 | args.scale_type = cf.get(task, 'scale_type') 155 | args.single_level = cf.getint(task, 'single_level') 156 | args.embed_weightx5 = eval(cf.get(task, 'embed_weightx5')) 157 | args.embed_weightx10 = eval(cf.get(task, 'embed_weightx10')) 158 | args.embed_weightx20 = eval(cf.get(task, 'embed_weightx20')) 159 | args.not_interscale = cf.getboolean(task, 'not_interscale') 160 | 161 | args.dim = cf.getint(task, 'dim') 162 | args.depths = eval(cf.get(task, 'depths')) 163 | args.heads = cf.getint(task, 'heads') 164 | args.mlp_dim = cf.getint(task, 'mlp_dim') 165 | args.dim_head = cf.getint(task, 'dim_head') 166 | args.pool = cf.get(task, 'pool') 167 | args.ape = cf.getboolean(task, 'ape') 168 | args.attn_type = cf.get(task, 'attn_type') 169 | args.shared_pe = cf.getboolean(task, 'shared_pe') 170 | 171 | # roi vis parameters 172 | #args.image_path = cf.get(task,'image_path') 173 | args.process_list = cf.get(task,'process_list') 174 | args.topk_num = cf.getint(task,'topk_num') 175 | args.vis_type = cf.get(task,'vis_type') 176 | args.vis_scale = cf.get(task,'vis_scale') 177 | args.sample = cf.get(task,'sample') 178 | args.level = cf.getint(task, 'level') 179 | args.head_fusion = cf.get(task, 'head_fusion') 180 | args.discard_ratio = cf.getfloat(task, 'discard_ratio') 181 | args.category_index = cf.getint(task,'category_index') 182 | 183 | args.exp_code = '_'.join(map(str, [args.task, args.depths, 184 | args.embed_type, 185 | args.batch_size, args.roi_dropout, 186 | args.roi_supervise, 187 | args.roi_weight, args.topk, 188 | args.roi_level, 189 | args.scale_type,args.single_level, 190 | args.not_interscale])) 191 | 192 | task_info = { 193 | ## in-house validataion with xiangya test dataset 194 | 'int_glioma_detection':{'csv_path':'../data_prepare/data_csv/xiangya_data_info_pro.csv', 195 | 'label_dict':{0:0,1:1,2:2,3:2,4:2}, 196 | 'n_classes': 3, 197 | 'split_dir': f'../data_prepare/data_split/xiangya_split_detection/xiangya_split_detection_{split_seed}.npy', 198 | 'test_split_dir': '../data_prepare/data_split/xiangya_split_detection/test_split_label_detection.npy', 199 | 'cls_weights':[50,24,514]}, 200 | 201 | 'int_glioma_tumor_subtyping':{'csv_path': '../data_prepare/data_csv/example_xiangya_data_info_pro.csv', 202 | 'label_dict': {i+2:i for i in range(3)}, 203 | 'n_classes': 3, 204 | 'split_dir': f'../data_prepare/data_split/xiangya_split_subtype/xiangya_split_subtype_size100_{split_seed}.npy', 205 | 'test_split_dir': '../data_prepare/data_split/xiangya_split_subtype/example_test_split.npy', 206 | 'cls_weights':[281,119,111]}, 207 | # external validation with TCGA test dataset 208 | 'ext_glioma_tumor_subtyping3':{'csv_path':'../data_prepare/data_csv/tcga_data_info_pro.csv', 209 | 'label_dict':{3:0,4:0,5:0,7:1,8:1}, 210 | 'n_classes': 2, 211 | 'split_dir': f'../data_prepare/data_split/xiangya_tcga_split_subtype2/xiangya_split_subtype2_{split_seed}.npy', 212 | 'label_dict_ext': 
{'astrocytoma_G2':0,'astrocytoma_G3':0,'glioblastoma_G4':0,'oligodendroglioma_G2':1,'oligodendroglioma_G3':1}, 213 | 'test_split_dir_ext': '../data_prepare/data_split/xiangya_tcga_split_subtype2/TCGA_test_split_label_subtype2.npy', 214 | 'cls_weights':[261,179]}, 215 | 216 | 217 | } 218 | return args, task_info 219 | 220 | 221 | def read_taskinfo(split_seed): 222 | 223 | task_info = { 224 | ## in-house validataion with xiangya test dataset 225 | 'int_glioma_detection':{'csv_path':'../data_prepare/data_csv/xiangya_data_info_pro.csv', 226 | 'label_dict':{0:0,1:1,2:2,3:2,4:2}, 227 | 'n_classes': 3, 228 | 'split_dir': f'../data_prepare/data_split/xiangya_split_detection/xiangya_split_detection_{split_seed}.npy', 229 | 'test_split_dir': '../data_prepare/data_split/xiangya_split_detection/test_split_label_detection.npy', 230 | 'cls_weights':[50,24,514]}, 231 | 232 | 'int_glioma_tumor_subtyping':{'csv_path': '../data_prepare/data_csv/xiangya_data_info_pro.csv', 233 | 'label_dict': {i+2:i for i in range(3)}, 234 | 'n_classes': 3, 235 | 'split_dir': f'../data_prepare/data_split/xiangya_split_subtype/xiangya_split_subtype_size100_{split_seed}.npy', 236 | 'test_split_dir': '../data_prepare/data_split/xiangya_split_subtype/test_split_label_subtype.npy', 237 | 'cls_weights':[281,119,111]}, 238 | # external validation with TCGA test dataset 239 | 'ext_glioma_tumor_subtyping3':{'csv_path':'../data_prepare/data_csv/tcga_data_info_pro.csv', 240 | 'label_dict':{3:0,4:0,5:0,7:1,8:1}, 241 | 'n_classes': 2, 242 | 'split_dir': f'../data_prepare/data_split/xiangya_tcga_split_subtype2/xiangya_split_subtype2_{split_seed}.npy', 243 | 'label_dict_ext': {'astrocytoma_G2':0,'astrocytoma_G3':0,'glioblastoma_G4':0,'oligodendroglioma_G2':1,'oligodendroglioma_G3':1}, 244 | 'test_split_dir_ext': '../data_prepare/data_split/xiangya_tcga_split_subtype2/TCGA_test_split_label_subtype2.npy', 245 | 'cls_weights':[261,179]}, 246 | 247 | } 248 | return task_info 249 | 250 | 251 | -------------------------------------------------------------------------------- /ROAM/position_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | def positionalencoding1d(d_model, length, ratio=1): 6 | """ 7 | :param d_model: dimension of the model 8 | :param length: length of positions 9 | :return: (length+1)*d_model position matrix 10 | """ 11 | if d_model % 2 != 0: 12 | raise ValueError("Cannot use sin/cos positional encoding with " 13 | "odd dim (got dim={:d})".format(d_model)) 14 | pe = torch.zeros(length+1, d_model) 15 | position = torch.arange(0, length+1).unsqueeze(1) 16 | div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * 17 | -(math.log(10000.0) / d_model))*ratio 18 | pe[:, 0::2] = torch.sin(position.float() * div_term) 19 | pe[:, 1::2] = torch.cos(position.float() * div_term) 20 | 21 | return pe 22 | 23 | 24 | def positionalencoding2d(d_model, height, width, ratio=1): 25 | """ 26 | :param d_model: dimension of the model 27 | :param height: height of the positions 28 | :param width: width of the positions 29 | :return: d_model*height*width position matrix 30 | """ 31 | if d_model % 4 != 0: 32 | raise ValueError("Cannot use sin/cos positional encoding with " 33 | "odd dimension (got dim={:d})".format(d_model)) 34 | pe = torch.zeros(height*width+1, d_model) 35 | # Each dimension use half of d_model 36 | d_model = int(d_model / 2) 37 | 38 | height_pe = positionalencoding1d(d_model, height, ratio) 39 | width_pe = positionalencoding1d(d_model, 
width, ratio) 40 | 41 | #print(height_pe.shape, width_pe.shape) 42 | 43 | pe[0, :d_model] = height_pe[0] 44 | pe[0, d_model:] = width_pe[0] 45 | 46 | for i in range(height): 47 | for j in range(width): 48 | pe[i*width+j+1, :d_model] = height_pe[i+1] 49 | pe[i*width+j+1, d_model:] = width_pe[j+1] 50 | 51 | return pe 52 | 53 | 54 | if __name__ == '__main__': 55 | x20 = positionalencoding2d(512, 8, 8) 56 | x10 = positionalencoding2d(512, 4, 4, 2) 57 | x5 = positionalencoding2d(512, 2, 2, 4) 58 | print(x20, x10, x5) 59 | cos = torch.nn.CosineSimilarity() 60 | print(cos(x10[1:2], x20)[1:].reshape((8,8))) 61 | 62 | 63 | -------------------------------------------------------------------------------- /ROAM/predict_cascade.py: -------------------------------------------------------------------------------- 1 | 2 | from cProfile import label 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.nn.modules import loss 7 | from torch.utils.data import DataLoader,WeightedRandomSampler 8 | import numpy as np 9 | import h5py 10 | import os 11 | import shutil 12 | import sys 13 | import json 14 | import random 15 | 16 | from tqdm import tqdm 17 | 18 | from dataset.roidataset import Wsi_Dataset_pred 19 | from models.ROAM import ROAM 20 | from parse_config import parse_args 21 | 22 | os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID' 23 | 24 | cls_name_dict = { 25 | 'int_glioma_tumor_subtyping':['astrocytoma','oligodendroglioma','ependymoma'] 26 | } 27 | 28 | 29 | def seed_torch(seed=7): 30 | random.seed(seed) 31 | os.environ['PYTHONHASHSEED'] = str(seed) 32 | np.random.seed(seed) 33 | torch.manual_seed(seed) 34 | 35 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 36 | if device.type == 'cuda': 37 | torch.cuda.manual_seed(seed) 38 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
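# Note: cuDNN benchmarking is disabled and deterministic kernels are requested below so that repeated runs with the same seed are reproducible (at some speed cost).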
39 | torch.backends.cudnn.benchmark = False 40 | torch.backends.cudnn.deterministic = True 41 | 42 | def read_features(feat_path): 43 | with h5py.File(feat_path,'r') as hdf5_file: 44 | features = hdf5_file['features'][:] # num_patches,84,1024 45 | return torch.from_numpy(features) 46 | 47 | def weights_for_balanced_class(train_dataset,weight_cls): 48 | n = float(len(train_dataset)) 49 | weight = [0]*int(n) 50 | for idx in range(len(train_dataset)): 51 | label = train_dataset.get_label(idx) 52 | weight[idx] = weight_cls[label] 53 | return weight 54 | 55 | 56 | 57 | def val_epoch(args,model,loader,loss_fn,epoch): 58 | val_loss = [] 59 | val_acc = [] 60 | preds = [] 61 | trues = [] 62 | probs = [] 63 | total_loss = 0.0 64 | model.eval() 65 | 66 | with torch.no_grad(): 67 | progressbar = tqdm(loader) 68 | for i,(feature,_,label) in enumerate(progressbar): 69 | feature, label = feature.cuda(),label.cuda() 70 | logits,loss_instance = model(feature,label,inst_level=False) # only slide 71 | pred = logits.argmax(1).cpu() 72 | 73 | loss_bag = loss_fn(logits,label) 74 | 75 | #print(loss_instance) 76 | loss = loss_bag 77 | 78 | val_loss.append(loss.item()) 79 | 80 | if pred == label.cpu(): 81 | val_acc.append(1) 82 | else: 83 | val_acc.append(0) 84 | 85 | preds.append(float(pred)) 86 | trues.append(int(label.cpu())) 87 | probs.append(torch.nn.functional.softmax(logits, dim=1).cpu()) 88 | 89 | progressbar.set_description(f'epoch: {epoch}, val_acc: {np.mean(val_acc):.4f}, val_loss: {np.mean(val_loss):.4f}, current_inst_loss: {float(loss_instance):.4f}, current_bag_loss: {float(loss_bag):.4f}') 90 | 91 | del loss 92 | 93 | probs = torch.cat(probs) 94 | 95 | return np.mean(val_acc),np.mean(val_loss),preds,trues,probs 96 | 97 | def pred_epoch(args,model,loader,epoch): 98 | preds = [] 99 | probs = [] 100 | model.eval() 101 | 102 | with torch.no_grad(): 103 | progressbar = tqdm(loader) 104 | for i,(feature,_) in enumerate(progressbar): 105 | feature = feature.cuda() 106 | logits,_ = model(feature,inst_level=False) # only slide 107 | pred = logits.argmax(1).cpu() 108 | 109 | preds.append(float(pred)) 110 | probs.append(torch.nn.functional.softmax(logits, dim=1).cpu()) 111 | 112 | progressbar.set_description(f'epoch: {epoch}, pred:{float(pred)}') 113 | 114 | 115 | probs = torch.cat(probs) 116 | 117 | return preds,probs 118 | 119 | if __name__ == "__main__": 120 | assert len(sys.argv) in [3,4], 'please give configuration file and split seed!' 
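# Example invocation (see scripts/cascade_pred_int_glioma_tumor_subtyping.sh):
#   python predict_cascade.py configs/int_glioma_tumor_subtyping.ini s1 [exp_code]
# argv[1] = .ini config, argv[2] = split seed (s1-s5), optional argv[3] overrides the experiment code.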
121 | ### split seed [s1, s2, s3, s4, s5] 122 | 123 | args, task_info = parse_args(sys.argv[1], sys.argv[2]) 124 | if len(sys.argv) == 4: 125 | args.exp_code = sys.argv[3] 126 | print(f'exp_code: {args.exp_code}') 127 | print(f'split seed: {args.split_seed}') 128 | 129 | 130 | args.n_classes = task_info[args.task]['n_classes'] 131 | if args.task not in task_info: 132 | raise NotImplementedError 133 | 134 | ## ==> save_dir 135 | if not os.path.exists(args.results_dir): 136 | os.mkdir(args.results_dir) 137 | 138 | args.results_dir = os.path.join(args.results_dir,args.task) 139 | if not os.path.exists(args.results_dir): 140 | os.mkdir(args.results_dir) 141 | 142 | args.results_dir = os.path.join(args.results_dir,f'{args.exp_code}') 143 | if not os.path.exists(args.results_dir): 144 | os.mkdir(args.results_dir) 145 | 146 | shutil.copy(sys.argv[1], args.results_dir) 147 | with open(sys.argv[1], 'r') as f: 148 | for l in f.readlines(): 149 | print(l.strip()) 150 | 151 | args.results_dir = os.path.join(args.results_dir,f'{args.split_seed}') 152 | if not os.path.exists(args.results_dir): 153 | os.mkdir(args.results_dir) 154 | 155 | ## pred results save dir 156 | pred_result_dir = f'prediction_results/{args.split_seed}/{args.task}' 157 | if not os.path.exists(pred_result_dir): 158 | os.makedirs(pred_result_dir) 159 | 160 | 161 | 162 | if args.embed_type == 'ImageNet': 163 | patch_dim = 1024 164 | if args.embed_type == 'RetCCL': 165 | patch_dim = 2048 166 | if args.embed_type == 'ctranspath': 167 | patch_dim = 768 168 | if args.embed_type == 'simclr-ciga': 169 | patch_dim = 512 170 | 171 | 172 | 173 | ### ==> test dataset 174 | # confirm test slide list 175 | default_test_tasks = ['int_glioma_cls','int_idh_cls','int_mgmt_cls'] 176 | 177 | ## read test slide list file 178 | test_split_dir = task_info[args.task]['test_split_dir'] 179 | if args.task not in default_test_tasks: 180 | cascade_test_dir = os.path.join(f'prediction_results/{args.split_seed}',f'cascade_{args.task}_split.npy') 181 | if os.path.exists(cascade_test_dir): 182 | test_split_dir = cascade_test_dir 183 | 184 | # if args.task in default_test_tasks: 185 | # test_split_dir = task_info[args.task]['test_split_dir'] #xiangya 186 | # else: 187 | # # use the results from the previous layer of the cascade system for prediction 188 | # test_split_dir = os.path.join(f'prediction_results/{args.split_seed}',f'cascade_{args.task}_split.npy') 189 | 190 | ## original test slide list file has label, ignore 191 | test_info = np.load(test_split_dir) 192 | test_ids = test_info[0] if len(test_info)==2 else test_info 193 | 194 | 195 | data_dir = f'{args.data_root_dir}/feats_{args.embed_type}_norm' 196 | if args.test_dataset == 'xiangya': 197 | test_dataset = Wsi_Dataset_pred(slide_ids=test_ids, 198 | csv_path=task_info[args.task]['csv_path'], 199 | data_dir=data_dir) 200 | test_loader = DataLoader(test_dataset,batch_size=1,shuffle=False,num_workers=4) 201 | if args.test_dataset == 'TCGA': 202 | test_ids,test_labels = np.load(task_info[args.task]['test_split_dir_ext']) 203 | 204 | test_dataset = Wsi_Dataset_pred(slide_ids = test_ids, 205 | csv_path = task_info[args.task]['csv_path'], 206 | data_dir = data_dir, 207 | ) 208 | test_loader = DataLoader(test_dataset,batch_size=1,shuffle=False,num_workers=4) 209 | 210 | ### ==> training 211 | 212 | num_folds = 5 213 | results = {k: {} for k in range(num_folds)} 214 | probs_all = [] 215 | 216 | ## whether setting fixed weights of embedding for each level 217 | if args.embed_weightx5==None and 
args.embed_weightx10==None and args.embed_weightx20==None: 218 | embed_weights = None 219 | print('use learnabel weights') 220 | else: 221 | embed_weights = [args.embed_weightx5,args.embed_weightx10,args.embed_weightx20] 222 | print('set weights:', embed_weights) 223 | 224 | for k in range(num_folds): 225 | print(f'split: {k}') 226 | 227 | ## set seed 228 | seed_torch(args.seed) 229 | 230 | # model 231 | model = ROAM(choose_num = args.topk, 232 | num_patches = 84, 233 | patch_dim=patch_dim, 234 | num_classes=task_info[args.task]['n_classes'], 235 | roi_level = args.roi_level, 236 | scale_type = args.scale_type, 237 | single_level = args.single_level, 238 | embed_weights=embed_weights, 239 | dim=args.dim, 240 | depths=args.depths, 241 | heads=args.heads, 242 | mlp_dim=args.mlp_dim, 243 | not_interscale = args.not_interscale, 244 | dim_head=args.dim_head, 245 | dropout=args.dropout, 246 | emb_dropout=args.emb_dropout, 247 | attn_dropout=args.attn_dropout, 248 | pool=args.pool, 249 | ape=args.ape, 250 | attn_type=args.attn_type, 251 | shared_pe=args.shared_pe) 252 | 253 | model = model.cuda() 254 | 255 | loss_fn = nn.CrossEntropyLoss() 256 | 257 | model_path = os.path.join(args.results_dir,f'{args.model_type}_split{str(k)}.pth') 258 | assert os.path.exists(model_path), 'No trained model checkpoint!' 259 | model.load_state_dict(torch.load(model_path)) 260 | 261 | preds,probs = pred_epoch(args,model,test_loader,1) 262 | 263 | 264 | results[k]['preds'] = preds 265 | 266 | probs_all.append(probs) 267 | 268 | print(f'end of prediction of split_{k}') 269 | 270 | 271 | 272 | probs_all = torch.stack(probs_all) 273 | 274 | probs_mean = probs_all.mean(0) #b,n_classes 275 | preds_mean = probs_mean.argmax(1) 276 | 277 | results['test'] = {} 278 | results['test']['preds'] = preds_mean.tolist() 279 | results['test']['probs'] = probs_mean.tolist() 280 | 281 | final_preds = {} 282 | cls_names = cls_name_dict[args.task] 283 | print('predict results:') 284 | for idx,sid in enumerate(test_ids): 285 | final_preds[sid] = cls_names[results['test']['preds'][idx]] 286 | print(f'{sid}:{final_preds[sid]}') 287 | 288 | pred_res_dir = os.path.join(pred_result_dir,'predictions.json') 289 | print(f'save the predictions to {pred_res_dir}') 290 | with open(os.path.join(pred_result_dir,'predictions.json'),'w') as f: 291 | json.dump(final_preds,f) 292 | 293 | 294 | 295 | with open(os.path.join(pred_result_dir,'results.json'),'w') as f: 296 | json.dump(results,f) 297 | 298 | ## results split: prepare for next level prediction 299 | preds_res = np.array(results['test']['preds']) 300 | #print('preds_res:',preds_res) 301 | if args.task == 'int_glioma_cls': 302 | slide_ids_c = test_ids[preds_res == 2] 303 | np.save(os.path.join(f'prediction_results/{args.split_seed}','cascade_int_glioma_tumor_subtyping_split.npy'),slide_ids_c) 304 | elif args.task == 'int_glioma_tumor_subtyping': 305 | cat_names = ['int_ast_grade1','int_oli_grade','int_epe_grade'] 306 | for c in range(task_info[args.task]['n_classes']): 307 | slide_ids_c = test_ids[preds_res == c] 308 | if len(slide_ids_c)>0: 309 | np.save(os.path.join(f'prediction_results/{args.split_seed}',f'cascade_{cat_names[c]}_split.npy'),slide_ids_c) 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | -------------------------------------------------------------------------------- /ROAM/prediction_results/s1/cascade_int_oli_grade_split.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/prediction_results/s1/cascade_int_oli_grade_split.npy -------------------------------------------------------------------------------- /ROAM/prediction_results/s1/int_glioma_tumor_subtyping/predictions.json: -------------------------------------------------------------------------------- 1 | {"d0ab09865c3b467": "oligodendroglioma"} -------------------------------------------------------------------------------- /ROAM/prediction_results/s1/int_glioma_tumor_subtyping/results.json: -------------------------------------------------------------------------------- 1 | {"0": {"preds": [1.0]}, "1": {"preds": [1.0]}, "2": {"preds": [1.0]}, "3": {"preds": [1.0]}, "4": {"preds": [1.0]}, "test": {"preds": [1], "probs": [[0.0024064790923148394, 0.9973974227905273, 0.00019609039009083062]]}} -------------------------------------------------------------------------------- /ROAM/results/int_glioma_tumor_subtyping/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/s1/visual_res/cm_mean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/results/int_glioma_tumor_subtyping/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/s1/visual_res/cm_mean.png -------------------------------------------------------------------------------- /ROAM/results/int_glioma_tumor_subtyping/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/s1/visual_res/metrics.json: -------------------------------------------------------------------------------- 1 | {"acc": 0.8518518805503845, "precision": 0.8523505217716442, "recall": 0.8325258919308611, "f1_score": 0.8416857387445623, "balanced_accuracy": 0.8325258919308611} -------------------------------------------------------------------------------- /ROAM/results/int_glioma_tumor_subtyping/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/s1/visual_res/normal_cm_mean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/results/int_glioma_tumor_subtyping/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/s1/visual_res/normal_cm_mean.png -------------------------------------------------------------------------------- /ROAM/scripts/cascade_pred_int_glioma_tumor_subtyping.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | python predict_cascade.py configs/int_glioma_tumor_subtyping.ini s1 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /ROAM/scripts/int_glioma_tumor_subtyping.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | python train.py configs/int_glioma_tumor_subtyping.ini s1 exp_code 3 | -------------------------------------------------------------------------------- /ROAM/scripts/int_glioma_tumor_subtyping_test.sh: -------------------------------------------------------------------------------- 1 | cd .. 
2 | python train.py configs/int_glioma_tumor_subtyping_test.ini s1 3 | -------------------------------------------------------------------------------- /ROAM/utils/__pycache__/core_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/utils/__pycache__/core_utils.cpython-37.pyc -------------------------------------------------------------------------------- /ROAM/utils/__pycache__/file_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/utils/__pycache__/file_utils.cpython-37.pyc -------------------------------------------------------------------------------- /ROAM/utils/__pycache__/file_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/utils/__pycache__/file_utils.cpython-38.pyc -------------------------------------------------------------------------------- /ROAM/utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /ROAM/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /ROAM/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from models.model_mil import MIL_fc, MIL_fc_mc 7 | from models.model_clam import CLAM_SB, CLAM_MB 8 | import pdb 9 | import os 10 | import pandas as pd 11 | from utils.utils import * 12 | from utils.core_utils import Accuracy_Logger 13 | from sklearn.metrics import roc_auc_score, roc_curve, auc 14 | from sklearn.preprocessing import label_binarize 15 | import matplotlib.pyplot as plt 16 | 17 | def initiate_model(args, ckpt_path): 18 | print('Init Model') 19 | model_dict = {"dropout": args.drop_out, 'n_classes': args.n_classes} 20 | 21 | if args.model_size is not None and args.model_type in ['clam_sb', 'clam_mb']: 22 | model_dict.update({"size_arg": args.model_size}) 23 | 24 | if args.model_type =='clam_sb': 25 | model = CLAM_SB(**model_dict) 26 | elif args.model_type =='clam_mb': 27 | model = CLAM_MB(**model_dict) 28 | else: # args.model_type == 'mil' 29 | if args.n_classes > 2: 30 | model = MIL_fc_mc(**model_dict) 31 | else: 32 | model = MIL_fc(**model_dict) 33 | 34 | print_network(model) 35 | 36 | ckpt = torch.load(ckpt_path) 37 | ckpt_clean = {} 38 | for key in ckpt.keys(): 39 | if 'instance_loss_fn' in key: 40 | continue 41 | ckpt_clean.update({key.replace('.module', ''):ckpt[key]}) 42 | model.load_state_dict(ckpt_clean, strict=True) 43 | 44 | model.relocate() 45 | model.eval() 46 | return model 47 | 48 | def eval(dataset, args, ckpt_path): 49 | model = initiate_model(args, ckpt_path) 50 | 51 | print('Init 
Loaders') 52 | loader = get_simple_loader(dataset) 53 | patient_results, test_error, auc, df, _ = summary(model, loader, args) 54 | print('test_error: ', test_error) 55 | print('auc: ', auc) 56 | return model, patient_results, test_error, auc, df 57 | 58 | def summary(model, loader, args): 59 | acc_logger = Accuracy_Logger(n_classes=args.n_classes) 60 | model.eval() 61 | test_loss = 0. 62 | test_error = 0. 63 | 64 | all_probs = np.zeros((len(loader), args.n_classes)) 65 | all_labels = np.zeros(len(loader)) 66 | all_preds = np.zeros(len(loader)) 67 | 68 | slide_ids = loader.dataset.slide_data['slide_id'] 69 | patient_results = {} 70 | for batch_idx, (data, label) in enumerate(loader): 71 | data, label = data.to(device), label.to(device) 72 | slide_id = slide_ids.iloc[batch_idx] 73 | with torch.no_grad(): 74 | logits, Y_prob, Y_hat, _, results_dict = model(data) 75 | 76 | acc_logger.log(Y_hat, label) 77 | 78 | probs = Y_prob.cpu().numpy() 79 | 80 | all_probs[batch_idx] = probs 81 | all_labels[batch_idx] = label.item() 82 | all_preds[batch_idx] = Y_hat.item() 83 | 84 | patient_results.update({slide_id: {'slide_id': np.array(slide_id), 'prob': probs, 'label': label.item()}}) 85 | 86 | error = calculate_error(Y_hat, label) 87 | test_error += error 88 | 89 | del data 90 | test_error /= len(loader) 91 | 92 | aucs = [] 93 | if len(np.unique(all_labels)) == 1: 94 | auc_score = -1 95 | 96 | else: 97 | if args.n_classes == 2: 98 | auc_score = roc_auc_score(all_labels, all_probs[:, 1]) 99 | else: 100 | binary_labels = label_binarize(all_labels, classes=[i for i in range(args.n_classes)]) 101 | for class_idx in range(args.n_classes): 102 | if class_idx in all_labels: 103 | fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], all_probs[:, class_idx]) 104 | aucs.append(auc(fpr, tpr)) 105 | else: 106 | aucs.append(float('nan')) 107 | if args.micro_average: 108 | binary_labels = label_binarize(all_labels, classes=[i for i in range(args.n_classes)]) 109 | fpr, tpr, _ = roc_curve(binary_labels.ravel(), all_probs.ravel()) 110 | auc_score = auc(fpr, tpr) 111 | else: 112 | auc_score = np.nanmean(np.array(aucs)) 113 | 114 | results_dict = {'slide_id': slide_ids, 'Y': all_labels, 'Y_hat': all_preds} 115 | for c in range(args.n_classes): 116 | results_dict.update({'p_{}'.format(c): all_probs[:,c]}) 117 | df = pd.DataFrame(results_dict) 118 | return patient_results, test_error, auc_score, df, acc_logger 119 | -------------------------------------------------------------------------------- /ROAM/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import h5py 3 | 4 | def save_pkl(filename, save_object): 5 | writer = open(filename,'wb') 6 | pickle.dump(save_object, writer) 7 | writer.close() 8 | 9 | def load_pkl(filename): 10 | loader = open(filename,'rb') 11 | file = pickle.load(loader) 12 | loader.close() 13 | return file 14 | 15 | 16 | def save_hdf5(output_path, asset_dict, attr_dict= None, mode='a'): 17 | file = h5py.File(output_path, mode) 18 | for key, val in asset_dict.items(): 19 | data_shape = val.shape 20 | if key not in file: 21 | data_type = val.dtype 22 | chunk_shape = (1, ) + data_shape[1:] 23 | maxshape = (None, ) + data_shape[1:] 24 | dset = file.create_dataset(key, shape=data_shape, maxshape=maxshape, chunks=chunk_shape, dtype=data_type) 25 | dset[:] = val 26 | if attr_dict is not None: 27 | if key in attr_dict.keys(): 28 | for attr_key, attr_val in attr_dict[key].items(): 29 | dset.attrs[attr_key] = attr_val 30 | else: 31 | dset 
= file[key] 32 | dset.resize(len(dset) + data_shape[0], axis=0) 33 | dset[-data_shape[0]:] = val 34 | file.close() 35 | return output_path -------------------------------------------------------------------------------- /ROAM/utils/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import pdb 6 | 7 | import torch 8 | import numpy as np 9 | import torch.nn as nn 10 | from torchvision import transforms 11 | from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler, SequentialSampler, sampler 12 | import torch.optim as optim 13 | import pdb 14 | import torch.nn.functional as F 15 | import math 16 | from itertools import islice 17 | import collections 18 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | class SubsetSequentialSampler(Sampler): 21 | """Samples elements sequentially from a given list of indices, without replacement. 22 | 23 | Arguments: 24 | indices (sequence): a sequence of indices 25 | """ 26 | def __init__(self, indices): 27 | self.indices = indices 28 | 29 | def __iter__(self): 30 | return iter(self.indices) 31 | 32 | def __len__(self): 33 | return len(self.indices) 34 | 35 | def collate_MIL(batch): 36 | img = torch.cat([item[0] for item in batch], dim = 0) 37 | label = torch.LongTensor([item[1] for item in batch]) 38 | return [img, label] 39 | 40 | def collate_features(batch): 41 | img = torch.cat([item[0] for item in batch], dim = 0) 42 | coords = np.vstack([item[1] for item in batch]) 43 | available = torch.cat([item[2] for item in batch], dim = 0) 44 | return [img, coords, available] 45 | 46 | 47 | def get_simple_loader(dataset, batch_size=1, num_workers=1): 48 | kwargs = {'pin_memory': False, 'num_workers': num_workers} if device.type == "cuda" else {} 49 | loader = DataLoader(dataset, batch_size=batch_size, sampler = sampler.SequentialSampler(dataset), collate_fn = collate_MIL, **kwargs) 50 | return loader 51 | 52 | def get_split_loader(split_dataset, training = False, testing = False, weighted = False): 53 | """ 54 | return either the validation loader or training loader 55 | """ 56 | kwargs = {'num_workers': 4} if device.type == "cuda" else {} 57 | if not testing: 58 | if training: 59 | if weighted: 60 | weights = make_weights_for_balanced_classes_split(split_dataset) 61 | loader = DataLoader(split_dataset, batch_size=1, sampler = WeightedRandomSampler(weights, len(weights)), collate_fn = collate_MIL, **kwargs) 62 | else: 63 | loader = DataLoader(split_dataset, batch_size=1, sampler = RandomSampler(split_dataset), collate_fn = collate_MIL, **kwargs) 64 | else: 65 | loader = DataLoader(split_dataset, batch_size=1, sampler = SequentialSampler(split_dataset), collate_fn = collate_MIL, **kwargs) 66 | 67 | else: 68 | ids = np.random.choice(np.arange(len(split_dataset)), int(len(split_dataset)*0.1), replace = False) # sample ~10% of the split for quick testing 69 | loader = DataLoader(split_dataset, batch_size=1, sampler = SubsetSequentialSampler(ids), collate_fn = collate_MIL, **kwargs ) 70 | 71 | return loader 72 | 73 | def get_optim(model, args): 74 | if args.opt == "adam": 75 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.reg) 76 | elif args.opt == 'sgd': 77 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.reg) 78 | else: 79 | raise NotImplementedError 80 | return optimizer 81 | 
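# Minimal usage sketch (illustrative only; assumes an `args` namespace with `opt`, `lr`, `reg` and bag-level datasets `train_split` / `val_split`):
#   train_loader = get_split_loader(train_split, training=True, weighted=True)
#   val_loader = get_split_loader(val_split)
#   optimizer = get_optim(model, args)  # 'adam' or 'sgd'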
82 | def print_network(net): 83 | num_params = 0 84 | num_params_train = 0 85 | print(net) 86 | 87 | for param in net.parameters(): 88 | n = param.numel() 89 | num_params += n 90 | if param.requires_grad: 91 | num_params_train += n 92 | 93 | print('Total number of parameters: %d' % num_params) 94 | print('Total number of trainable parameters: %d' % num_params_train) 95 | 96 | 97 | def generate_split(cls_ids, val_num, test_num, samples, n_splits = 5, 98 | seed = 7, label_frac = 1.0, custom_test_ids = None): 99 | indices = np.arange(samples).astype(int) 100 | 101 | if custom_test_ids is not None: 102 | indices = np.setdiff1d(indices, custom_test_ids) 103 | 104 | np.random.seed(seed) 105 | for i in range(n_splits): 106 | all_val_ids = [] 107 | all_test_ids = [] 108 | sampled_train_ids = [] 109 | 110 | if custom_test_ids is not None: # pre-built test split, do not need to sample 111 | all_test_ids.extend(custom_test_ids) 112 | 113 | for c in range(len(val_num)): 114 | possible_indices = np.intersect1d(cls_ids[c], indices) #all indices of this class 115 | val_ids = np.random.choice(possible_indices, val_num[c], replace = False) # validation ids 116 | 117 | remaining_ids = np.setdiff1d(possible_indices, val_ids) #indices of this class left after validation 118 | all_val_ids.extend(val_ids) 119 | 120 | if custom_test_ids is None: # sample test split 121 | 122 | test_ids = np.random.choice(remaining_ids, test_num[c], replace = False) 123 | remaining_ids = np.setdiff1d(remaining_ids, test_ids) 124 | all_test_ids.extend(test_ids) 125 | 126 | if label_frac == 1: 127 | sampled_train_ids.extend(remaining_ids) 128 | 129 | else: 130 | #print(len(remaining_ids)) 131 | sample_num = math.ceil(len(remaining_ids) * label_frac) 132 | slice_ids = np.arange(sample_num) 133 | sampled_train_ids.extend(remaining_ids[slice_ids]) 134 | 135 | yield sampled_train_ids, all_val_ids, all_test_ids 136 | 137 | 138 | def nth(iterator, n, default=None): 139 | if n is None: 140 | return collections.deque(iterator, maxlen=0) 141 | else: 142 | return next(islice(iterator,n, None), default) 143 | 144 | def calculate_error(Y_hat, Y): 145 | error = 1. 
- Y_hat.float().eq(Y.float()).float().mean().item() 146 | 147 | return error 148 | 149 | def make_weights_for_balanced_classes_split(dataset): 150 | N = float(len(dataset)) 151 | weight_per_class = [N/len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids))] 152 | weight = [0] * int(N) 153 | for idx in range(len(dataset)): 154 | y = dataset.getlabel(idx) 155 | weight[idx] = weight_per_class[y] 156 | 157 | return torch.DoubleTensor(weight) 158 | 159 | def initialize_weights(module): 160 | for m in module.modules(): 161 | if isinstance(m, nn.Linear): 162 | nn.init.xavier_normal_(m.weight) 163 | m.bias.data.zero_() 164 | 165 | elif isinstance(m, nn.BatchNorm1d): 166 | nn.init.constant_(m.weight, 1) 167 | nn.init.constant_(m.bias, 0) 168 | 169 | -------------------------------------------------------------------------------- /ROAM/vahadane.py: -------------------------------------------------------------------------------- 1 | import spams 2 | import numpy as np 3 | import cv2 4 | import time 5 | 6 | 7 | class vahadane(object): 8 | 9 | def __init__(self, STAIN_NUM=2, THRESH=0.9, LAMBDA1=0.01, LAMBDA2=0.01, ITER=100, fast_mode=0, getH_mode=0): 10 | self.STAIN_NUM = STAIN_NUM 11 | self.THRESH = THRESH 12 | self.LAMBDA1 = LAMBDA1 13 | self.LAMBDA2 = LAMBDA2 14 | self.ITER = ITER 15 | self.fast_mode = fast_mode # 0: normal; 1: fast 16 | self.getH_mode = getH_mode # 0: spams.lasso; 1: pinv; 17 | 18 | 19 | def show_config(self): 20 | print('STAIN_NUM =', self.STAIN_NUM) 21 | print('THRESH =', self.THRESH) 22 | print('LAMBDA1 =', self.LAMBDA1) 23 | print('LAMBDA2 =', self.LAMBDA2) 24 | print('ITER =', self.ITER) 25 | print('fast_mode =', self.fast_mode) 26 | print('getH_mode =', self.getH_mode) 27 | 28 | 29 | def getV(self, img): 30 | 31 | I0 = img.reshape((-1,3)).T 32 | I0[I0==0] = 1 33 | V0 = np.log(255 / I0) 34 | 35 | img_LAB = cv2.cvtColor(img, cv2.COLOR_RGB2LAB) 36 | mask = img_LAB[:, :, 0] / 255 < self.THRESH 37 | I = img[mask].reshape((-1, 3)).T 38 | I[I == 0] = 1 39 | V = np.log(255 / I) 40 | 41 | return V0, V 42 | 43 | 44 | def getW(self, V): 45 | W = spams.trainDL(np.asfortranarray(V), numThreads=1, K=self.STAIN_NUM, lambda1=self.LAMBDA1, iter=self.ITER, mode=2, modeD=0, posAlpha=True, posD=True, verbose=False) 46 | W = W / np.linalg.norm(W, axis=0)[None, :] 47 | if (W[0,0] < W[0,1]): 48 | W = W[:, [1,0]] 49 | return W 50 | 51 | 52 | def getH(self, V, W): 53 | if (self.getH_mode == 0): 54 | H = spams.lasso(np.asfortranarray(V), np.asfortranarray(W), numThreads=1, mode=2, lambda1=self.LAMBDA2, pos=True, verbose=False).toarray() 55 | elif (self.getH_mode == 1): 56 | H = np.linalg.pinv(W).dot(V); 57 | H[H<0] = 0 58 | else: 59 | H = 0 60 | return H 61 | 62 | 63 | def stain_separate(self, img): 64 | start = time.time() 65 | if (self.fast_mode == 0): 66 | V0, V = self.getV(img) 67 | W = self.getW(V) 68 | H = self.getH(V0, W) 69 | elif (self.fast_mode == 1): 70 | m = img.shape[0] 71 | n = img.shape[1] 72 | grid_size_m = int(m / 5) 73 | lenm = int(m / 20) 74 | grid_size_n = int(n / 5) 75 | lenn = int(n / 20) 76 | W = np.zeros((81, 3, self.STAIN_NUM)).astype(np.float64) 77 | for i in range(0, 4): 78 | for j in range(0, 4): 79 | px = (i + 1) * grid_size_m 80 | py = (j + 1) * grid_size_n 81 | patch = img[px - lenm : px + lenm, py - lenn: py + lenn, :] 82 | V0, V = self.getV(patch) 83 | W[i*9+j] = self.getW(V) 84 | W = np.mean(W, axis=0) 85 | V0, V = self.getV(img) 86 | H = self.getH(V0, W) 87 | #print('stain separation time:', time.time()-start, 's') 88 | return W, H 89 | 90 | 91 | def 
SPCN(self, img, Ws, Hs, Wt, Ht): 92 | Hs_RM = np.percentile(Hs, 99) 93 | Ht_RM = np.percentile(Ht, 99) 94 | Hs_norm = Hs * Ht_RM / Hs_RM 95 | Vs_norm = np.dot(Wt, Hs_norm) 96 | Is_norm = 255 * np.exp(-1 * Vs_norm) 97 | I = Is_norm.T.reshape(img.shape).astype(np.uint8) 98 | return I 99 | -------------------------------------------------------------------------------- /ROAM/vis_utils/__pycache__/heatmap_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/vis_utils/__pycache__/heatmap_utils.cpython-38.pyc -------------------------------------------------------------------------------- /ROAM/vis_utils/__pycache__/vit_grad_rollout.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/vis_utils/__pycache__/vit_grad_rollout.cpython-38.pyc -------------------------------------------------------------------------------- /ROAM/vis_utils/__pycache__/vit_rollout.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/vis_utils/__pycache__/vit_rollout.cpython-38.pyc -------------------------------------------------------------------------------- /ROAM/vis_utils/heatmap_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import pdb 6 | import os 7 | import pandas as pd 8 | from utils.utils import * 9 | from PIL import Image 10 | from math import floor 11 | import matplotlib.pyplot as plt 12 | from dataset.vis_dataset import Wsi_Region 13 | import h5py 14 | from wsi_core.WholeSlideImage import WholeSlideImage 15 | from scipy.stats import percentileofscore 16 | import math 17 | from utils.file_utils import save_hdf5 18 | from scipy.stats import percentileofscore 19 | 20 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | def score2percentile(score, ref): 23 | percentile = percentileofscore(ref, score) 24 | return percentile 25 | 26 | def drawHeatmap(scores, coords, slide_path=None, wsi_object=None, vis_level = -1, **kwargs): 27 | if wsi_object is None: 28 | wsi_object = WholeSlideImage(slide_path) 29 | print(wsi_object.name) 30 | 31 | wsi = wsi_object.getOpenSlide() 32 | if vis_level < 0: 33 | vis_level = wsi.get_best_level_for_downsample(32) 34 | 35 | heatmap = wsi_object.visHeatmap(scores=scores, coords=coords, vis_level=vis_level, **kwargs) 36 | return heatmap 37 | 38 | def initialize_wsi(slide_id, wsi_path, seg_mask_path=None, seg_params=None, filter_params=None): 39 | wsi_object = WholeSlideImage(wsi_path,slide_id) 40 | if seg_params['seg_level'] < 0: 41 | best_level = wsi_object.wsi.get_best_level_for_downsample(32) 42 | seg_params['seg_level'] = best_level 43 | 44 | wsi_object.segmentTissue(**seg_params, filter_params=filter_params) 45 | wsi_object.saveSegmentation(seg_mask_path) 46 | return wsi_object 47 | 48 | def compute_from_patches(wsi_object, clam_pred=None, model=None, feature_extractor=None, batch_size=512, 49 | attn_save_path=None, ref_scores=None, feat_save_path=None, **wsi_kwargs): 50 | top_left = wsi_kwargs['top_left'] 51 | bot_right = wsi_kwargs['bot_right'] 52 | patch_size = 
wsi_kwargs['patch_size'] 53 | 54 | roi_dataset = Wsi_Region(wsi_object, **wsi_kwargs) 55 | roi_loader = get_simple_loader(roi_dataset, batch_size=batch_size, num_workers=8) 56 | print('total number of patches to process: ', len(roi_dataset)) 57 | num_batches = len(roi_loader) 58 | print('number of batches: ', len(roi_loader)) 59 | mode = "w" 60 | for idx, (roi, coords) in enumerate(roi_loader): 61 | roi = roi.to(device) 62 | coords = coords.numpy() 63 | 64 | with torch.no_grad(): 65 | features = feature_extractor(roi) 66 | 67 | if attn_save_path is not None: 68 | A = model(features, attention_only=True) 69 | 70 | if A.size(0) > 1: #CLAM multi-branch attention 71 | A = A[clam_pred] 72 | 73 | A = A.view(-1, 1).cpu().numpy() 74 | 75 | if ref_scores is not None: 76 | for score_idx in range(len(A)): 77 | A[score_idx] = score2percentile(A[score_idx], ref_scores) 78 | 79 | asset_dict = {'attention_scores': A, 'coords': coords} 80 | save_path = save_hdf5(attn_save_path, asset_dict, mode=mode) 81 | 82 | if idx % math.ceil(num_batches * 0.05) == 0: 83 | print('processed {} / {}'.format(idx, num_batches)) 84 | 85 | if feat_save_path is not None: 86 | asset_dict = {'features': features.cpu().numpy(), 'coords': coords} 87 | save_hdf5(feat_save_path, asset_dict, mode=mode) 88 | 89 | mode = "a" 90 | return attn_save_path, feat_save_path, wsi_object -------------------------------------------------------------------------------- /ROAM/vis_utils/vit_explain.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Explainability visualization of Transformer and ViT is referenced from: 3 | 1. Transformer Interpretability Beyond Attention Visualization (Chefer et al., CVPR 2021) 4 | paper: http://arxiv.org/abs/2012.09838 5 | code: https://github.com/hila-chefer/Transformer-Explainability 6 | 2. Exploring Explainability for vision transformers 7 | blog: https://jacobgil.github.io/deeplearning/vision-transformer-explainability 8 | code: https://github.com/jacobgil/vit-explain 9 | ''' 10 | import argparse 11 | import sys 12 | import torch 13 | from PIL import Image 14 | from torchvision import transforms 15 | import numpy as np 16 | import cv2 17 | 18 | from vit_rollout import VITAttentionRollout 19 | from vit_grad_rollout import VITAttentionGradRollout 20 | 21 | def get_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--use_cuda', action='store_true', default=False, 24 | help='Use NVIDIA GPU acceleration') 25 | parser.add_argument('--image_path', type=str, default='./examples/both.png', 26 | help='Input image path') 27 | parser.add_argument('--head_fusion', type=str, default='max', 28 | help='How to fuse the attention heads for attention rollout. 
\ 29 | Can be mean/max/min') 30 | parser.add_argument('--discard_ratio', type=float, default=0.9, 31 | help='How many of the lowest 14x14 attention paths should we discard') 32 | parser.add_argument('--category_index', type=int, default=None, 33 | help='The category index for gradient rollout') 34 | args = parser.parse_args() 35 | args.use_cuda = args.use_cuda and torch.cuda.is_available() 36 | if args.use_cuda: 37 | print("Using GPU") 38 | else: 39 | print("Using CPU") 40 | 41 | return args 42 | 43 | def show_mask_on_image(img, mask): 44 | img = np.float32(img) / 255 45 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) 46 | heatmap = np.float32(heatmap) / 255 47 | cam = heatmap + np.float32(img) 48 | cam = cam / np.max(cam) 49 | return np.uint8(255 * cam) 50 | 51 | if __name__ == '__main__': 52 | args = get_args() 53 | model = torch.hub.load('facebookresearch/deit:main', 54 | 'deit_tiny_patch16_224', pretrained=True) 55 | model.eval() 56 | 57 | if args.use_cuda: 58 | model = model.cuda() 59 | 60 | transform = transforms.Compose([ 61 | transforms.Resize((224, 224)), 62 | transforms.ToTensor(), 63 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 64 | ]) 65 | img = Image.open(args.image_path) 66 | img = img.resize((224, 224)) 67 | input_tensor = transform(img).unsqueeze(0) 68 | if args.use_cuda: 69 | input_tensor = input_tensor.cuda() 70 | 71 | if args.category_index is None: 72 | print("Doing Attention Rollout") 73 | attention_rollout = VITAttentionRollout(model, head_fusion=args.head_fusion, 74 | discard_ratio=args.discard_ratio) 75 | mask = attention_rollout(input_tensor) 76 | name = "attention_rollout_{:.3f}_{}.png".format(args.discard_ratio, args.head_fusion) 77 | else: 78 | print("Doing Gradient Attention Rollout") 79 | grad_rollout = VITAttentionGradRollout(model, discard_ratio=args.discard_ratio) 80 | mask = grad_rollout(input_tensor, args.category_index) 81 | name = "grad_rollout_{}_{:.3f}_{}.png".format(args.category_index, 82 | args.discard_ratio, args.head_fusion) 83 | 84 | 85 | np_img = np.array(img)[:, :, ::-1] 86 | mask = cv2.resize(mask, (np_img.shape[1], np_img.shape[0])) 87 | mask = show_mask_on_image(np_img, mask) 88 | cv2.imshow("Input Image", np_img) 89 | cv2.imshow(name, mask) 90 | cv2.imwrite("input.png", np_img) 91 | cv2.imwrite(name, mask) 92 | cv2.waitKey(-1) -------------------------------------------------------------------------------- /ROAM/vis_utils/vit_grad_rollout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from PIL import Image 4 | import numpy 5 | import sys 6 | from torchvision import transforms 7 | import numpy as np 8 | import cv2 9 | 10 | def avg_heads(cam, grad=None): 11 | cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]) 12 | if grad != None: 13 | grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1]) 14 | cam = grad * cam 15 | cam = cam.clamp(min=0).mean(dim=0) # filter negative 16 | 17 | return cam 18 | 19 | 20 | def grad_rollout(attentions, gradients, discard_ratio,vis_scale='ss',level=3, learnable_weights=False): 21 | if vis_scale == 'ss': 22 | 23 | result = torch.eye(attentions[0].size(-1)) 24 | # The order of obtaining gradients and attention scores is reversed 25 | gradients = gradients[::-1] 26 | with torch.no_grad(): 27 | for attention, grad in zip(attentions, gradients): 28 | weights = grad 29 | attention_heads_fused = (attention*weights).clamp(min=0).mean(axis=1) 30 | 31 | # Drop the lowest attentions, but 32 | # 
don't drop the class token 33 | flat = attention_heads_fused.view(attention_heads_fused.size(0), -1) 34 | _, indices = flat.topk(int(flat.size(-1)*discard_ratio), -1, False) 35 | flat[0, indices] = 0 36 | 37 | I = torch.eye(attention_heads_fused.size(-1)) 38 | a = (attention_heads_fused + 1.0*I)/2 39 | #a = (attention_heads_fused)/2 40 | a = a / a.sum(dim=-1) 41 | result = torch.matmul(a, result) 42 | 43 | # Look at the total attention between the class token, 44 | # and the image patches 45 | mask = result[0,0,1:] 46 | # In case of 224x224 image, this brings us from 196 to 14 47 | width = int(mask.size(-1)**0.5) 48 | mask = mask.reshape(width, width).numpy() 49 | mask = mask / np.max(mask) 50 | print(mask) 51 | return mask 52 | else: 53 | mask_all = [] 54 | ''' 55 | attentions: [transformer_20x.layer_0,transformer_20x.layer_1,transformer_20x.layer_2, 56 | transformer_10x.layer_0,transformer_10x.layer_1,transformer_10x.layer_2, 57 | transformer_5x.layer_0,transformer_5x.layer_1,transformer_5x.layer_2] 58 | ''' 59 | if learnable_weights: 60 | w = attentions[-1] 61 | w = torch.softmax(w,dim=1) 62 | w = w.detach().numpy() 63 | #print(w) 64 | attns = attentions[:-1] 65 | grads = gradients[1:] 66 | else: 67 | attns = attentions 68 | grads = gradients 69 | grads = grads[::-1] 70 | with torch.no_grad(): 71 | for i in range(3): 72 | attns_curl = attns[level*i:level*(i+1)] 73 | grads_curl = grads[level*i:level*(i+1)] 74 | result = torch.eye(attns_curl[0].size(-1)) 75 | for attn,grad in zip(attns_curl,grads_curl): 76 | weights = grad 77 | 78 | attention_heads_fused = (attn*grad).clamp(min=0).mean(axis=1) 79 | 80 | 81 | flat = attention_heads_fused.view(attention_heads_fused.size(0),-1) 82 | _, indices = flat.topk(int(flat.size(-1)*discard_ratio), -1, False) 83 | flat[0, indices] = 0 84 | I = torch.eye(attention_heads_fused.size(-1)) 85 | a = (attention_heads_fused + 1.0*I)/2 86 | #a = (attention_heads_fused)/2 87 | a = a / a.sum(dim=-1) 88 | result = torch.matmul(a, result) 89 | mask = result[0,0,1:] 90 | width = int(mask.size(-1)**0.5) 91 | mask = mask.reshape(width, width).numpy() 92 | mask = mask / np.max(mask) 93 | 94 | mask_all.append(mask) 95 | 96 | if learnable_weights: return mask_all, w 97 | return mask_all 98 | 99 | def grad_cam(attentions, gradients, vis_scale, level, learnable_weights=False): 100 | if vis_scale == 'ss': 101 | print(attentions[-1].shape) 102 | gradients = gradients[::-1] 103 | #print(gradients[0].shape) 104 | with torch.no_grad(): 105 | attn = attentions[-1] # h,s,s 106 | grad = gradients[-1] # h,s,s 107 | 108 | print(attn.shape) 109 | 110 | 111 | attn = attn[0,:,0,1:].reshape((-1,int((attn.shape[-1]-1)**0.5),int((attn.shape[-1]-1)**0.5))) 112 | grad = grad[0,:,0,1:].reshape((-1,int((gradients[-1].shape[-1]-1)**0.5),int((gradients[-1].shape[-1]-1)**0.5))) 113 | 114 | # h,n,n 115 | cam_grad = (grad*attn).mean(0).clamp(min=0) #n,n 116 | cam_grad = (cam_grad-cam_grad.min())/(cam_grad.max()-cam_grad.min()) 117 | print(cam_grad) 118 | 119 | 120 | return cam_grad.numpy() 121 | else: 122 | ## multi-scale 123 | cam_grad_all = [] 124 | if learnable_weights: 125 | w = attentions[-1] 126 | w = torch.softmax(w,dim=1) 127 | w = w.detach().numpy() 128 | #print(w) 129 | attns = attentions[:-1] 130 | grads = gradients[1:] 131 | else: 132 | attns = attentions 133 | grads = gradients 134 | grads = grads[::-1] 135 | with torch.no_grad(): 136 | for i in range(3): 137 | #print(f'attns_len:{len(attns)}') 138 | attn_mag = attns[level*(i+1)-1] 139 | grad_mag = grads[level*(i+1)-1] 140 | 141 | 
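# take the CLS-token attention row (query 0, all heads) of the last layer at this magnification and reshape it into a square patch grid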
print(attn_mag.shape) 142 | attn = attn_mag[0,:,0,1:].reshape((-1,int((attn_mag.shape[-1]-1)**0.5),int((attn_mag.shape[-1]-1)**0.5))) 143 | grad = grad_mag[0,:,0,1:].reshape((-1,int((grad_mag.shape[-1]-1)**0.5),int((grad_mag.shape[-1]-1)**0.5))) 144 | 145 | # h,n,n 146 | cam_grad = (grad*attn).mean(0).clamp(min=0) #n,n 147 | cam_grad = (cam_grad-cam_grad.min())/(cam_grad.max()-cam_grad.min()) 148 | 149 | print(cam_grad) 150 | print(cam_grad.shape) 151 | cam_grad_all.append(cam_grad.numpy()) 152 | 153 | if learnable_weights: 154 | return cam_grad_all, w 155 | return cam_grad_all 156 | 157 | 158 | class VITAttentionGradRollout: 159 | def __init__(self, model, level, 160 | attention_layer_name='attend', 161 | discard_ratio=0.9, 162 | vis_type = 'grad_rollout', 163 | vis_scale='ms', 164 | learnable_weights=False): 165 | ''' 166 | ROI-level visualization: generate attention heatmaps from the Transformer's self-attention matrices 167 | 168 | args: 169 | model: ROAM model used to draw the visualization heatmap 170 | level: depth of the Transformer block 171 | discard_ratio: proportion of the lowest attention scores to discard; only the top attentions are kept 172 | vis_type: type of visualization method, 'grad_rollout' or 'grad_cam' 173 | grad_cam: only uses the last Transformer layer at each magnification level 174 | grad_rollout: considers all self-attention layers 175 | vis_scale: single-scale (ss) or multi-scale (ms) 176 | 'ss': only compute the heatmap at the 20x magnification scale 177 | learnable_weights: whether the weight coefficients of each scale in the model are learnable 178 | 'True': obtain the final weights from the model's state dict 179 | 'False': fixed weight coefficients are taken from the initial config 180 | ''' 181 | self.model = model 182 | self.discard_ratio = discard_ratio 183 | self.vis_type = vis_type 184 | self.vis_scale = vis_scale 185 | self.level = level 186 | self.learnable_weights = learnable_weights 187 | 188 | if self.vis_scale == 'ms': 189 | att_layer_name = [f'transformer_{s}.layers.{l}.0.fn.attend' for s in [20,10,5] for l in range(level)] 190 | if learnable_weights: 191 | att_layer_name += [f'ms_attn.{level}'] 192 | 193 | cur_l = 0 194 | for name, module in self.model.named_modules(): 195 | if att_layer_name[cur_l] in name: 196 | module.register_forward_hook(self.get_attention) 197 | module.register_backward_hook(self.get_attention_gradient) 198 | 199 | cur_l += 1 200 | if cur_l >= len(att_layer_name): 201 | break 202 | 203 | else: 204 | # only the attention scores of transformer_20 are needed 205 | att_layer_name = [f'transformer_{s}.layers.{l}' for s in [20,10,5] for l in range(level)] 206 | 207 | cur_l = 0 208 | for name, module in self.model.named_modules(): 209 | if attention_layer_name in name and att_layer_name[cur_l] in name: 210 | module.register_forward_hook(self.get_attention) 211 | 212 | module.register_backward_hook(self.get_attention_gradient) 213 | print(name,'is attention') 214 | cur_l += 1 215 | if cur_l >= level: 216 | break 217 | #print(name) 218 | self.attentions = [] 219 | self.attention_gradients = [] 220 | 221 | def get_attention(self, module, input, output): 222 | self.attentions.append(output.cpu()) 223 | 224 | def get_attention_gradient(self, module, grad_input, grad_output): 225 | self.attention_gradients.append(grad_input[0].cpu()) 226 | 227 | # def save_attn_gradients(self, attn_gradients): 228 | # print('grad_hook') 229 | # print(attn_gradients[0,0,0,1:10]) 230 | # self.attn_gradients = attn_gradients 231 | 232 | def __call__(self, input_tensor, 
category_index): 233 | self.model.zero_grad() 234 | _,output = self.model(input_tensor.unsqueeze(0),vis_mode=3) 235 | #print(output.shape) 236 | loss_fn = nn.CrossEntropyLoss() 237 | 238 | category_mask = torch.zeros(output.size()).cuda() 239 | category_mask[:, category_index] = 1 240 | 241 | 242 | loss = (output*category_mask).sum() 243 | 244 | loss.backward() 245 | 246 | #print(self.vis_type) 247 | 248 | if self.vis_type == 'grad_rollout': 249 | return grad_rollout(self.attentions, self.attention_gradients, 250 | self.discard_ratio, self.vis_scale,self.level,self.learnable_weights) 251 | else: 252 | ## grad_cam 253 | return grad_cam(self.attentions, self.attention_gradients, self.vis_scale, self.level, self.learnable_weights) 254 | -------------------------------------------------------------------------------- /ROAM/vis_utils/vit_rollout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import numpy 4 | import sys 5 | from torchvision import transforms 6 | import numpy as np 7 | import cv2 8 | 9 | def rollout(attentions, discard_ratio, head_fusion): 10 | result = torch.eye(attentions[0].size(-1)) 11 | with torch.no_grad(): 12 | for attention in attentions: 13 | if head_fusion == "mean": 14 | attention_heads_fused = attention.mean(axis=1) 15 | elif head_fusion == "max": 16 | attention_heads_fused = attention.max(axis=1)[0] 17 | elif head_fusion == "min": 18 | attention_heads_fused = attention.min(axis=1)[0] 19 | else: 20 | raise ValueError("Attention head fusion type not supported") 21 | 22 | # Drop the lowest attentions, but 23 | # don't drop the class token 24 | flat = attention_heads_fused.view(attention_heads_fused.size(0), -1) 25 | _, indices = flat.topk(int(flat.size(-1)*discard_ratio), -1, False) 26 | indices = indices[indices != 0] 27 | flat[0, indices] = 0 28 | 29 | I = torch.eye(attention_heads_fused.size(-1)) 30 | a = (attention_heads_fused + 1.0*I)/2 31 | a = a / a.sum(dim=-1) 32 | 33 | result = torch.matmul(a, result) 34 | 35 | # Look at the total attention between the class token, 36 | # and the image patches 37 | mask = result[0, 0 , 1 :] 38 | # In case of 224x224 image, this brings us from 196 to 14 39 | width = int(mask.size(-1)**0.5) 40 | mask = mask.reshape(width, width).numpy() 41 | mask = mask / np.max(mask) 42 | return mask 43 | 44 | class VITAttentionRollout: 45 | def __init__(self, model, attention_layer_name='attn_drop', head_fusion="mean", 46 | discard_ratio=0.9): 47 | self.model = model 48 | self.head_fusion = head_fusion 49 | self.discard_ratio = discard_ratio 50 | for name, module in self.model.named_modules(): 51 | if attention_layer_name in name: 52 | module.register_forward_hook(self.get_attention) 53 | 54 | self.attentions = [] 55 | 56 | def get_attention(self, module, input, output): 57 | self.attentions.append(output.cpu()) 58 | 59 | def __call__(self, input_tensor): 60 | self.attentions = [] 61 | with torch.no_grad(): 62 | output = self.model(input_tensor) 63 | 64 | return rollout(self.attentions, self.discard_ratio, self.head_fusion) -------------------------------------------------------------------------------- /ROAM/visheatmaps/roi_vis/configs/int_glioma_tumor_subtyping_vis_roi.ini: -------------------------------------------------------------------------------- 1 | [int_glioma_tumor_subtyping] 2 | 3 | seed = 1 4 | embed_type = ImageNet 5 | not_stainnorm = False 6 | 7 | batch_size = 4 8 | emb_dropout = 0 9 | attn_dropout = 0.25 10 | dropout = 0.2 11 | 12 | model_type = ROAM 13 | 
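# ROAM-specific ROI options (ROI dropout, ROI-level supervision and its loss weight, top-k ROI selection)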
roi_dropout = True 14 | roi_supervise = True 15 | roi_weight = 1.0 16 | topk = 4 17 | roi_level = 0 18 | scale_type = ms 19 | single_level = 0 20 | embed_weightx5 = 0.3333 21 | embed_weightx10 = 0.3333 22 | embed_weightx20 = 0.3333 23 | not_interscale = False 24 | 25 | dim = 256 26 | depths = [2,2,2,2,2] 27 | heads = 8 28 | mlp_dim = 512 29 | dim_head = 64 30 | pool = cls 31 | ape = True 32 | attn_type = rel_sa 33 | shared_pe = True 34 | 35 | process_list = visheatmaps/slide_vis/results/int_glioma_tumor_subtyping.csv 36 | topk_num = 5 37 | vis_type = grad_rollout 38 | vis_scale = ms 39 | sample = high 40 | level = -1 41 | head_fusion = max 42 | discard_ratio = 0 43 | category_index = 0 44 | 45 | -------------------------------------------------------------------------------- /ROAM/visheatmaps/roi_vis/high/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/grad_rollout/oligodendroglioma/d0ab09865c3b467/top1_seeds1_0_d2_l2_r0.0_avg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/visheatmaps/roi_vis/high/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/grad_rollout/oligodendroglioma/d0ab09865c3b467/top1_seeds1_0_d2_l2_r0.0_avg.png -------------------------------------------------------------------------------- /ROAM/visheatmaps/slide_vis/configs/config_int_glioma_tumor_subtyping.yaml: -------------------------------------------------------------------------------- 1 | # CUDA_VISIBLE_DEVICES=0,1 python create_heatmaps.py --config config_template.yaml 2 | --- 3 | exp_arguments: 4 | # number of classes, depends on the classification task 5 | n_classes: 2 6 | # where to save raw asset files 7 | raw_save_dir: visheatmaps/slide_vis/results/heatmap_raw_results 8 | # where to save final heatmaps 9 | production_save_dir: visheatmaps/slide_vis/results/heatmap_production_results 10 | data_arguments: 11 | # where is data stored; can be a single str path or a dictionary of key, data_dir mapping 12 | data_dir: heatmaps/demo/slides/ 13 | # csv list containing slide_ids (can additionally have seg/patch paramters, class labels, etc.) 
14 | process_list: int_glioma_tumor_subtyping.csv 15 | 16 | # arguments of ROAM model 17 | model_arguments: 18 | task: int_glioma_tumor_subtyping 19 | seed: 1 20 | batch_size: 4 21 | exp_code: int_idh_cls_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False 22 | embed_type: ImageNet 23 | sample_size: 100 24 | not_stainnorm: False 25 | test_dataset: xiangya 26 | results_dir: results 27 | 28 | emb_dropout: 0 29 | attn_dropout: 0.25 30 | dropout: 0.2 31 | 32 | model_type: ROAM 33 | roi_dropout: True 34 | roi_supervise: True 35 | roi_weight: 1.0 36 | topk: 4 37 | roi_level: 0 38 | scale_type: ms 39 | single_level: 0 40 | embed_weightx5: 0.3333 41 | embed_weightx10: 0.3333 42 | embed_weightx20: 0.3333 43 | not_interscale: False 44 | 45 | dim: 256 46 | depths: [2,2,2,2,2] 47 | heads: 8 48 | mlp_dim: 512 49 | dim_head: 64 50 | pool: cls 51 | ape: True 52 | attn_type: rel_sa 53 | shared_pe: True 54 | 55 | 56 | patching_arguments: 57 | # arguments for patching 58 | patch_size: 4096 59 | overlap: 0.5 60 | patch_level: 0 61 | custom_downsample: 1 62 | # for stain normalization 63 | target_image_dir: visheatmaps/target_roi_6e3.jpg 64 | 65 | heatmap_arguments: 66 | # downsample at which to visualize heatmap (-1 refers to downsample closest to 32x downsample) 67 | vis_level: 4 68 | # transparency for overlaying heatmap on background (0: background only, 1: foreground only) 69 | alpha: 0.4 70 | # whether to use a blank canvas instead of original slide 71 | blank_canvas: false 72 | # whether to also save the original H&E image 73 | save_orig: true 74 | # file extension for saving heatmap/original image 75 | save_ext: jpg 76 | # whether to calculate percentile scores in reference to the set of non-overlapping patches 77 | use_ref_scores: True 78 | # whether to use gaussian blur for further smoothing 79 | blur: True 80 | # whether to shift the 4 default corner points for checking if a patch is inside a foreground contour 81 | use_center_shift: true 82 | # whether to only compute heatmap for ROI specified by x1, x2, y1, y2 83 | use_roi: false 84 | # whether to calculate heatmap with specified overlap (by default, coarse heatmap without overlap is always calculated) 85 | calc_heatmap: true 86 | # whether to binarize attention scores 87 | binarize: false 88 | # binarization threshold: (0, 1) 89 | binary_thresh: -1 90 | # factor for downscaling the heatmap before final dispaly 91 | custom_downsample: 1 92 | cmap: jet 93 | sample_arguments: 94 | samples: 95 | - name: "topk_high_attention" 96 | sample: true 97 | seed: 1 98 | k: 10 # save top-k patches 99 | mode: topk 100 | - name: "topk_low_attention" 101 | sample: true 102 | seed: 1 103 | k: 5 # save top-k patches 104 | mode: reverse_topk 105 | # - name: "random_attention" 106 | # sample: true 107 | # seed: 1 108 | # k: 5 # save top-k patches 109 | # mode: range_sample 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /ROAM/visheatmaps/slide_vis/configs/config_int_glioma_tumor_subtyping_roi.yaml: -------------------------------------------------------------------------------- 1 | # CUDA_VISIBLE_DEVICES=0,1 python create_heatmaps.py --config config_template.yaml 2 | --- 3 | exp_arguments: 4 | # number of classes, depends on the classification task 5 | n_classes: 2 6 | # where to save raw asset files 7 | raw_save_dir: visheatmaps/slide_vis/results/heatmap_raw_results 8 | # where to save final heatmaps 9 | production_save_dir: visheatmaps/slide_vis/results/heatmap_production_results 10 | data_arguments: 11 | 
# where is data stored; can be a single str path or a dictionary of key, data_dir mapping 12 | data_dir: heatmaps/demo/slides/ 13 | # csv list containing slide_ids (can additionally have seg/patch paramters, class labels, etc.) 14 | process_list: int_glioma_tumor_subtyping_roi.csv 15 | 16 | # arguments of ROAM model 17 | model_arguments: 18 | task: int_glioma_tumor_subtyping 19 | seed: 1 20 | batch_size: 4 21 | exp_code: int_idh_cls_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False 22 | embed_type: ImageNet 23 | sample_size: 100 24 | not_stainnorm: False 25 | test_dataset: xiangya 26 | results_dir: results 27 | 28 | emb_dropout: 0 29 | attn_dropout: 0.25 30 | dropout: 0.2 31 | 32 | model_type: ROAM 33 | roi_dropout: True 34 | roi_supervise: True 35 | roi_weight: 1.0 36 | topk: 4 37 | roi_level: 0 38 | scale_type: ms 39 | single_level: 0 40 | embed_weightx5: 0.3333 41 | embed_weightx10: 0.3333 42 | embed_weightx20: 0.3333 43 | not_interscale: False 44 | 45 | dim: 256 46 | depths: [2,2,2,2,2] 47 | heads: 8 48 | mlp_dim: 512 49 | dim_head: 64 50 | pool: cls 51 | ape: True 52 | attn_type: rel_sa 53 | shared_pe: True 54 | 55 | 56 | patching_arguments: 57 | # arguments for patching 58 | patch_size: 4096 59 | overlap: 0.95 60 | patch_level: 0 61 | custom_downsample: 1 62 | # for stain normalization 63 | target_image_dir: visheatmaps/target_roi_6e3.jpg 64 | 65 | heatmap_arguments: 66 | # downsample at which to visualize heatmap (-1 refers to downsample closest to 32x downsample) 67 | vis_level: 2 68 | # transparency for overlaying heatmap on background (0: background only, 1: foreground only) 69 | alpha: 0.4 70 | # whether to use a blank canvas instead of original slide 71 | blank_canvas: false 72 | # whether to also save the original H&E image 73 | save_orig: true 74 | # file extension for saving heatmap/original image 75 | save_ext: jpg 76 | # whether to calculate percentile scores in reference to the set of non-overlapping patches 77 | use_ref_scores: True 78 | # whether to use gaussian blur for further smoothing 79 | blur: True 80 | # whether to shift the 4 default corner points for checking if a patch is inside a foreground contour 81 | use_center_shift: true 82 | # whether to only compute heatmap for ROI specified by x1, x2, y1, y2 83 | use_roi: true 84 | # whether to calculate heatmap with specified overlap (by default, coarse heatmap without overlap is always calculated) 85 | calc_heatmap: true 86 | # whether to binarize attention scores 87 | binarize: false 88 | # binarization threshold: (0, 1) 89 | binary_thresh: -1 90 | # factor for downscaling the heatmap before final dispaly 91 | custom_downsample: 1 92 | cmap: jet 93 | sample_arguments: 94 | samples: 95 | - name: "topk_high_attention_roi" 96 | sample: true 97 | seed: 1 98 | k: 1 # save top-k patches 99 | mode: topk 100 | # - name: "topk_low_attention" 101 | # sample: true 102 | # seed: 1 103 | # k: 5 # save top-k patches 104 | # mode: reverse_topk 105 | # - name: "random_attention" 106 | # sample: true 107 | # seed: 1 108 | # k: 5 # save top-k patches 109 | # mode: range_sample 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /ROAM/visheatmaps/slide_vis/process_list/int_glioma_tumor_subtyping.csv: -------------------------------------------------------------------------------- 1 | ,slide_id,path,preset_vis_level,label 2 | 0,d0ab09865c3b467,/images/202102/d0ab09865c3b467.tiff,5,1 3 | -------------------------------------------------------------------------------- 
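Note on the process lists consumed by the two heatmap configs above: each row pairs a slide_id with its path, a preset visualization level and a label, and the ROI variant that follows adds x1,y1,x2,y2 level-0 bounding-box coordinates that are only read when use_roi is true. A minimal, hypothetical sketch of generating such a CSV with pandas (slide id, path, coordinates and output name are placeholders, not files from this repository):

import pandas as pd

rows = [{
    'slide_id': 'example_slide',                # placeholder id
    'path': '/path/to/example_slide.tiff',      # placeholder slide path
    'preset_vis_level': 5,
    'label': 1,
    # optional level-0 ROI bounds, used only when use_roi is true in the config
    'x1': 10000, 'y1': 18000, 'x2': 18000, 'y2': 26000,
}]
# to_csv keeps the pandas index, which reproduces the unnamed leading column seen in the example files
pd.DataFrame(rows).to_csv('visheatmaps/slide_vis/process_list/example_process_list.csv')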
/ROAM/visheatmaps/slide_vis/process_list/int_glioma_tumor_subtyping_roi.csv: -------------------------------------------------------------------------------- 1 | ,slide_id,path,preset_vis_level,label,x1,y1,x2,y2 2 | 1,d0ab09865c3b467,/images/202102/d0ab09865c3b467.tiff,3,1,12000,20000,20000,28000 3 | 4 | -------------------------------------------------------------------------------- /ROAM/visheatmaps/slide_vis/results/heatmap_production_results/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/oligodendroglioma/d0ab09865c3b467_0.95_roi_1_blur_1_rs_1_bc_0_a_0.4_l_4_bi_0_-1.0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/visheatmaps/slide_vis/results/heatmap_production_results/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/oligodendroglioma/d0ab09865c3b467_0.95_roi_1_blur_1_rs_1_bc_0_a_0.4_l_4_bi_0_-1.0.jpg -------------------------------------------------------------------------------- /ROAM/visheatmaps/slide_vis/results/heatmap_production_results/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/oligodendroglioma/d0ab09865c3b467_0.9_roi_0_blur_1_rs_1_bc_0_a_0.4_l_5_bi_0_-1.0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/visheatmaps/slide_vis/results/heatmap_production_results/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/oligodendroglioma/d0ab09865c3b467_0.9_roi_0_blur_1_rs_1_bc_0_a_0.4_l_5_bi_0_-1.0.jpg -------------------------------------------------------------------------------- /ROAM/visheatmaps/slide_vis/results/heatmap_production_results/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/oligodendroglioma/d0ab09865c3b467_orig_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/visheatmaps/slide_vis/results/heatmap_production_results/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/oligodendroglioma/d0ab09865c3b467_orig_4.jpg -------------------------------------------------------------------------------- /ROAM/visheatmaps/slide_vis/results/heatmap_production_results/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/oligodendroglioma/d0ab09865c3b467_orig_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/visheatmaps/slide_vis/results/heatmap_production_results/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/oligodendroglioma/d0ab09865c3b467_orig_5.jpg -------------------------------------------------------------------------------- /ROAM/visheatmaps/slide_vis/results/int_glioma_tumor_subtyping_roi.csv: -------------------------------------------------------------------------------- 1 | Unnamed: 0,slide_id,path,preset_vis_level,label,x1,y1,x2,y2,slide_path,process,status,seg_level,sthresh,mthresh,close,use_otsu,keep_ids,exclude_ids,a_t,a_h,max_n_holes,vis_level,line_thickness,use_padding,contour_fn,feat_dir,p_0,p_1,p_2,pred 2 | 
1,d0ab09865c3b467,/images/202102/d0ab09865c3b467.tiff,3,1,12000,20000,20000,28000,int_glioma_tumor_subtyping_roi.csv,1,tbp,4,8,7,4,False,none,none,25.0,16.0,8,4,250,True,four_pt,"visheatmaps/slide_vis/results/heatmap_raw_results/int_glioma_tumor_subtyping_[2, 2, 2, 2, 2]_ImageNet_4_True_True_1.0_4_0_ms_0_False/oligodendroglioma/d0ab09865c3b467/d0ab09865c3b467_0.95.h5",0.002407548949122429,0.9973962903022766,0.00019619803060777485,1.0 3 | -------------------------------------------------------------------------------- /ROAM/visheatmaps/target_roi_6e3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/ROAM/visheatmaps/target_roi_6e3.jpg -------------------------------------------------------------------------------- /ROAM/wsi_core/batch_process_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import pdb 4 | 5 | ''' 6 | initiate a pandas df describing a list of slides to process 7 | args: 8 | slides (df or array-like): 9 | array-like structure containing list of slide ids, if df, these ids assumed to be 10 | stored under the 'slide_id' column 11 | seg_params (dict): segmentation paramters 12 | filter_params (dict): filter parameters 13 | vis_params (dict): visualization paramters 14 | patch_params (dict): patching paramters 15 | use_heatmap_args (bool): whether to include heatmap arguments such as ROI coordinates 16 | ''' 17 | def initialize_df(slides, slides_path, seg_params, filter_params, vis_params, patch_params, 18 | use_heatmap_args=False, save_patches=False): 19 | 20 | total = len(slides) 21 | if isinstance(slides, pd.DataFrame): 22 | slide_ids = slides.slide_id.values 23 | else: 24 | slide_ids = slides 25 | default_df_dict = {'slide_id': slide_ids, 'slide_path':slides_path, 'process': np.full((total), 1, dtype=np.uint8)} 26 | 27 | # initiate empty labels in case not provided 28 | if use_heatmap_args: 29 | default_df_dict.update({'label': np.full((total), -1)}) 30 | 31 | default_df_dict.update({ 32 | 'status': np.full((total), 'tbp'), 33 | # seg params 34 | 'seg_level': np.full((total), int(seg_params['seg_level']), dtype=np.int8), 35 | 'sthresh': np.full((total), int(seg_params['sthresh']), dtype=np.uint8), 36 | 'mthresh': np.full((total), int(seg_params['mthresh']), dtype=np.uint8), 37 | 'close': np.full((total), int(seg_params['close']), dtype=np.uint32), 38 | 'use_otsu': np.full((total), bool(seg_params['use_otsu']), dtype=bool), 39 | 'keep_ids': np.full((total), seg_params['keep_ids']), 40 | 'exclude_ids': np.full((total), seg_params['exclude_ids']), 41 | 42 | # filter params 43 | 'a_t': np.full((total), int(filter_params['a_t']), dtype=np.float32), 44 | 'a_h': np.full((total), int(filter_params['a_h']), dtype=np.float32), 45 | 'max_n_holes': np.full((total), int(filter_params['max_n_holes']), dtype=np.uint32), 46 | 47 | # vis params 48 | 'vis_level': np.full((total), int(vis_params['vis_level']), dtype=np.int8), 49 | 'line_thickness': np.full((total), int(vis_params['line_thickness']), dtype=np.uint32), 50 | 51 | # patching params 52 | 'use_padding': np.full((total), bool(patch_params['use_padding']), dtype=bool), 53 | 'contour_fn': np.full((total), patch_params['contour_fn']) 54 | }) 55 | 56 | if save_patches: 57 | default_df_dict.update({ 58 | 'white_thresh': np.full((total), int(patch_params['white_thresh']), dtype=np.uint8), 59 | 'black_thresh': 
np.full((total), int(patch_params['black_thresh']), dtype=np.uint8)}) 60 | 61 | if use_heatmap_args: 62 | # initiate empty x,y coordinates in case not provided 63 | default_df_dict.update({'x1': np.empty((total)).fill(np.NaN), 64 | 'x2': np.empty((total)).fill(np.NaN), 65 | 'y1': np.empty((total)).fill(np.NaN), 66 | 'y2': np.empty((total)).fill(np.NaN)}) 67 | 68 | 69 | if isinstance(slides, pd.DataFrame): 70 | temp_copy = pd.DataFrame(default_df_dict) # temporary dataframe w/ default params 71 | # find key in provided df 72 | # if exist, fill empty fields w/ default values, else, insert the default values as a new column 73 | for key in default_df_dict.keys(): 74 | if key in slides.columns: 75 | mask = slides[key].isna() 76 | slides.loc[mask, key] = temp_copy.loc[mask, key] 77 | else: 78 | slides.insert(len(slides.columns), key, default_df_dict[key]) 79 | else: 80 | slides = pd.DataFrame(default_df_dict) 81 | 82 | return slides -------------------------------------------------------------------------------- /ROAM/wsi_core/util_classes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import pdb 5 | import cv2 6 | class Mosaic_Canvas(object): 7 | def __init__(self,patch_size=256, n=100, downscale=4, n_per_row=10, bg_color=(0,0,0), alpha=-1): 8 | self.patch_size = patch_size 9 | self.downscaled_patch_size = int(np.ceil(patch_size/downscale)) 10 | self.n_rows = int(np.ceil(n / n_per_row)) 11 | self.n_cols = n_per_row 12 | w = self.n_cols * self.downscaled_patch_size 13 | h = self.n_rows * self.downscaled_patch_size 14 | if alpha < 0: 15 | canvas = Image.new(size=(w,h), mode="RGB", color=bg_color) 16 | else: 17 | canvas = Image.new(size=(w,h), mode="RGBA", color=bg_color + (int(255 * alpha),)) 18 | 19 | self.canvas = canvas 20 | self.dimensions = np.array([w, h]) 21 | self.reset_coord() 22 | 23 | def reset_coord(self): 24 | self.coord = np.array([0, 0]) 25 | 26 | def increment_coord(self): 27 | #print('current coord: {} x {} / {} x {}'.format(self.coord[0], self.coord[1], self.dimensions[0], self.dimensions[1])) 28 | assert np.all(self.coord<=self.dimensions) 29 | if self.coord[0] + self.downscaled_patch_size <=self.dimensions[0] - self.downscaled_patch_size: 30 | self.coord[0]+=self.downscaled_patch_size 31 | else: 32 | self.coord[0] = 0 33 | self.coord[1]+=self.downscaled_patch_size 34 | 35 | 36 | def save(self, save_path, **kwargs): 37 | self.canvas.save(save_path, **kwargs) 38 | 39 | def paste_patch(self, patch): 40 | assert patch.size[0] == self.patch_size 41 | assert patch.size[1] == self.patch_size 42 | self.canvas.paste(patch.resize(tuple([self.downscaled_patch_size, self.downscaled_patch_size])), tuple(self.coord)) 43 | self.increment_coord() 44 | 45 | def get_painting(self): 46 | return self.canvas 47 | 48 | class Contour_Checking_fn(object): 49 | # Defining __call__ method 50 | def __call__(self, pt): 51 | raise NotImplementedError 52 | 53 | class isInContourV1(Contour_Checking_fn): 54 | def __init__(self, contour): 55 | self.cont = contour 56 | 57 | def __call__(self, pt): 58 | return 1 if cv2.pointPolygonTest(self.cont, pt, False) >= 0 else 0 59 | 60 | class isInContourV2(Contour_Checking_fn): 61 | def __init__(self, contour, patch_size): 62 | self.cont = contour 63 | self.patch_size = patch_size 64 | 65 | def __call__(self, pt): 66 | return 1 if cv2.pointPolygonTest(self.cont, (pt[0]+self.patch_size//2, pt[1]+self.patch_size//2), False) >= 0 else 0 67 | 68 | # Easy version of 4pt 
contour checking function - 1 of 4 points need to be in the contour for test to pass 69 | class isInContourV3_Easy(Contour_Checking_fn): 70 | def __init__(self, contour, patch_size, center_shift=0.5): 71 | self.cont = contour 72 | self.patch_size = patch_size 73 | self.shift = int(patch_size//2*center_shift) 74 | def __call__(self, pt): 75 | center = (pt[0]+self.patch_size//2, pt[1]+self.patch_size//2) 76 | if self.shift > 0: 77 | all_points = [(center[0]-self.shift, center[1]-self.shift), 78 | (center[0]+self.shift, center[1]+self.shift), 79 | (center[0]+self.shift, center[1]-self.shift), 80 | (center[0]-self.shift, center[1]+self.shift) 81 | ] 82 | else: 83 | all_points = [center] 84 | #print(all_points) 85 | for points in all_points: 86 | # need to convert 'numpy.int64' to 'int' 87 | if cv2.pointPolygonTest(self.cont, (int(points[0]),int(points[1])), False) >= 0: 88 | return 1 89 | return 0 90 | 91 | # Hard version of 4pt contour checking function - all 4 points need to be in the contour for test to pass 92 | class isInContourV3_Hard(Contour_Checking_fn): 93 | def __init__(self, contour, patch_size, center_shift=0.5): 94 | self.cont = contour 95 | self.patch_size = patch_size 96 | self.shift = int(patch_size//2*center_shift) 97 | def __call__(self, pt): 98 | center = (pt[0]+self.patch_size//2, pt[1]+self.patch_size//2) 99 | if self.shift > 0: 100 | all_points = [(center[0]-self.shift, center[1]-self.shift), 101 | (center[0]+self.shift, center[1]+self.shift), 102 | (center[0]+self.shift, center[1]-self.shift), 103 | (center[0]-self.shift, center[1]+self.shift) 104 | ] 105 | else: 106 | all_points = [center] 107 | 108 | for points in all_points: 109 | if cv2.pointPolygonTest(self.cont, (int(points[0]),int(points[1])), False) < 0: 110 | return 0 111 | return 1 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /ROAM/wsi_core/wsi_utils.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import os 4 | import pdb 5 | from wsi_core.util_classes import Mosaic_Canvas 6 | from PIL import Image 7 | import math 8 | import cv2 9 | 10 | def isWhitePatch(patch, satThresh=5): 11 | patch_hsv = cv2.cvtColor(patch, cv2.COLOR_RGB2HSV) 12 | return True if np.mean(patch_hsv[:,:,1]) < satThresh else False 13 | 14 | def isBlackPatch(patch, rgbThresh=40): 15 | return True if np.all(np.mean(patch, axis = (0,1)) < rgbThresh) else False 16 | 17 | def isBlackPatch_S(patch, rgbThresh=20, percentage=0.05): 18 | num_pixels = patch.size[0] * patch.size[1] 19 | return True if np.all(np.array(patch) < rgbThresh, axis=(2)).sum() > num_pixels * percentage else False 20 | 21 | def isWhitePatch_S(patch, rgbThresh=220, percentage=0.2): 22 | num_pixels = patch.size[0] * patch.size[1] 23 | return True if np.all(np.array(patch) > rgbThresh, axis=(2)).sum() > num_pixels * percentage else False 24 | 25 | def coord_generator(x_start, x_end, x_step, y_start, y_end, y_step, args_dict=None): 26 | for x in range(x_start, x_end, x_step): 27 | for y in range(y_start, y_end, y_step): 28 | if args_dict is not None: 29 | process_dict = args_dict.copy() 30 | process_dict.update({'pt':(x,y)}) 31 | yield process_dict 32 | else: 33 | yield (x,y) 34 | 35 | def savePatchIter_bag_hdf5(patch): 36 | x, y, cont_idx, patch_level, downsample, downsampled_level_dim, level_dim, img_patch, name, save_path= tuple(patch.values()) 37 | img_patch = np.array(img_patch)[np.newaxis,...] 
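# (the leading np.newaxis turns the patch into a (1, H, W, 3) array so it can be appended as a single row of the resizable 'imgs' dataset below)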
38 | img_shape = img_patch.shape 39 | 40 | file_path = os.path.join(save_path, name)+'.h5' 41 | file = h5py.File(file_path, "a") 42 | 43 | dset = file['imgs'] 44 | dset.resize(len(dset) + img_shape[0], axis=0) 45 | dset[-img_shape[0]:] = img_patch 46 | 47 | if 'coords' in file: 48 | coord_dset = file['coords'] 49 | coord_dset.resize(len(coord_dset) + img_shape[0], axis=0) 50 | coord_dset[-img_shape[0]:] = (x,y) 51 | 52 | file.close() 53 | 54 | def save_hdf5(output_path, asset_dict, attr_dict= None, mode='a'): 55 | file = h5py.File(output_path, mode) 56 | for key, val in asset_dict.items(): 57 | data_shape = val.shape 58 | if key not in file: 59 | data_type = val.dtype 60 | chunk_shape = (1, ) + data_shape[1:] 61 | maxshape = (None, ) + data_shape[1:] 62 | dset = file.create_dataset(key, shape=data_shape, maxshape=maxshape, chunks=chunk_shape, dtype=data_type) 63 | dset[:] = val 64 | if attr_dict is not None: 65 | if key in attr_dict.keys(): 66 | for attr_key, attr_val in attr_dict[key].items(): 67 | dset.attrs[attr_key] = attr_val 68 | else: 69 | dset = file[key] 70 | dset.resize(len(dset) + data_shape[0], axis=0) 71 | dset[-data_shape[0]:] = val 72 | file.close() 73 | return output_path 74 | 75 | def initialize_hdf5_bag(first_patch, save_coord=False): 76 | x, y, cont_idx, patch_level, downsample, downsampled_level_dim, level_dim, img_patch, name, save_path = tuple(first_patch.values()) 77 | file_path = os.path.join(save_path, name)+'.h5' 78 | file = h5py.File(file_path, "w") 79 | img_patch = np.array(img_patch)[np.newaxis,...] 80 | dtype = img_patch.dtype 81 | 82 | # Initialize a resizable dataset to hold the output 83 | img_shape = img_patch.shape 84 | maxshape = (None,) + img_shape[1:] #maximum dimensions up to which dataset maybe resized (None means unlimited) 85 | dset = file.create_dataset('imgs', 86 | shape=img_shape, maxshape=maxshape, chunks=img_shape, dtype=dtype) 87 | 88 | dset[:] = img_patch 89 | dset.attrs['patch_level'] = patch_level 90 | dset.attrs['wsi_name'] = name 91 | dset.attrs['downsample'] = downsample 92 | dset.attrs['level_dim'] = level_dim 93 | dset.attrs['downsampled_level_dim'] = downsampled_level_dim 94 | 95 | if save_coord: 96 | coord_dset = file.create_dataset('coords', shape=(1, 2), maxshape=(None, 2), chunks=(1, 2), dtype=np.int32) 97 | coord_dset[:] = (x,y) 98 | 99 | file.close() 100 | return file_path 101 | 102 | def sample_indices(scores, k, start=0.48, end=0.52, convert_to_percentile=False, seed=1): 103 | np.random.seed(seed) 104 | if convert_to_percentile: 105 | end_value = np.quantile(scores, end) 106 | start_value = np.quantile(scores, start) 107 | else: 108 | end_value = end 109 | start_value = start 110 | score_window = np.logical_and(scores >= start_value, scores <= end_value) 111 | indices = np.where(score_window)[0] 112 | if len(indices) < 1: 113 | return -1 114 | else: 115 | return np.random.choice(indices, min(k, len(indices)), replace=False) 116 | 117 | def top_k(scores, k, invert=False): 118 | if invert: 119 | top_k_ids=scores.argsort()[:k] 120 | else: 121 | top_k_ids=scores.argsort()[::-1][:k] 122 | return top_k_ids 123 | 124 | def to_percentiles(scores): 125 | from scipy.stats import rankdata 126 | scores = rankdata(scores, 'average')/len(scores) * 100 127 | return scores 128 | 129 | def screen_coords(scores, coords, top_left, bot_right): 130 | bot_right = np.array(bot_right) 131 | top_left = np.array(top_left) 132 | mask = np.logical_and(np.all(coords >= top_left, axis=1), np.all(coords <= bot_right, axis=1)) 133 | scores = scores[mask] 
134 | coords = coords[mask] 135 | return scores, coords 136 | 137 | def sample_rois(scores, coords, k=5, mode='range_sample', seed=1, score_start=0.45, score_end=0.55, top_left=None, bot_right=None): 138 | 139 | if len(scores.shape) == 2: 140 | scores = scores.flatten() 141 | 142 | scores = to_percentiles(scores) 143 | if top_left is not None and bot_right is not None: 144 | scores, coords = screen_coords(scores, coords, top_left, bot_right) 145 | 146 | if mode == 'range_sample': 147 | sampled_ids = sample_indices(scores, start=score_start, end=score_end, k=k, convert_to_percentile=False, seed=seed) 148 | elif mode == 'topk': 149 | sampled_ids = top_k(scores, k, invert=False) 150 | elif mode == 'reverse_topk': 151 | sampled_ids = top_k(scores, k, invert=True) 152 | else: 153 | raise NotImplementedError 154 | coords = coords[sampled_ids] 155 | scores = scores[sampled_ids] 156 | 157 | asset = {'sampled_coords': coords, 'sampled_scores': scores} 158 | return asset 159 | 160 | def DrawGrid(img, coord, shape, thickness=2, color=(0,0,0,255)): 161 | cv2.rectangle(img, tuple(np.maximum([0, 0], coord-thickness//2)), tuple(coord - thickness//2 + np.array(shape)), (0, 0, 0, 255), thickness=thickness) 162 | return img 163 | 164 | def DrawMap(canvas, patch_dset, coords, patch_size, indices=None, verbose=1, draw_grid=True): 165 | if indices is None: 166 | indices = np.arange(len(coords)) 167 | total = len(indices) 168 | if verbose > 0: 169 | ten_percent_chunk = math.ceil(total * 0.1) 170 | print('start stitching {}'.format(patch_dset.attrs['wsi_name'])) 171 | 172 | for idx in range(total): 173 | if verbose > 0: 174 | if idx % ten_percent_chunk == 0: 175 | print('progress: {}/{} stitched'.format(idx, total)) 176 | 177 | patch_id = indices[idx] 178 | patch = patch_dset[patch_id] 179 | patch = cv2.resize(patch, patch_size) 180 | coord = coords[patch_id] 181 | canvas_crop_shape = canvas[coord[1]:coord[1]+patch_size[1], coord[0]:coord[0]+patch_size[0], :3].shape[:2] 182 | canvas[coord[1]:coord[1]+patch_size[1], coord[0]:coord[0]+patch_size[0], :3] = patch[:canvas_crop_shape[0], :canvas_crop_shape[1], :] 183 | if draw_grid: 184 | DrawGrid(canvas, coord, patch_size) 185 | 186 | return Image.fromarray(canvas) 187 | 188 | def DrawMapFromCoords(canvas, wsi_object, coords, patch_size, vis_level, indices=None, verbose=1, draw_grid=True): 189 | downsamples = wsi_object.wsi.level_downsamples[vis_level] 190 | if indices is None: 191 | indices = np.arange(len(coords)) 192 | total = len(indices) 193 | if verbose > 0: 194 | ten_percent_chunk = math.ceil(total * 0.1) 195 | 196 | patch_size = tuple(np.ceil((np.array(patch_size)/np.array(downsamples))).astype(np.int32)) 197 | print('downscaled patch size: {}x{}'.format(patch_size[0], patch_size[1])) 198 | 199 | for idx in range(total): 200 | if verbose > 0: 201 | if idx % ten_percent_chunk == 0: 202 | print('progress: {}/{} stitched'.format(idx, total)) 203 | 204 | patch_id = indices[idx] 205 | coord = coords[patch_id] 206 | patch = np.array(wsi_object.wsi.read_region(tuple(coord), vis_level, patch_size).convert("RGB")) 207 | coord = np.ceil(coord / downsamples).astype(np.int32) 208 | canvas_crop_shape = canvas[coord[1]:coord[1]+patch_size[1], coord[0]:coord[0]+patch_size[0], :3].shape[:2] 209 | canvas[coord[1]:coord[1]+patch_size[1], coord[0]:coord[0]+patch_size[0], :3] = patch[:canvas_crop_shape[0], :canvas_crop_shape[1], :] 210 | if draw_grid: 211 | DrawGrid(canvas, coord, patch_size) 212 | 213 | return Image.fromarray(canvas) 214 | 215 | def 
StitchPatches(hdf5_file_path, downscale=16, draw_grid=False, bg_color=(0,0,0), alpha=-1): 216 | file = h5py.File(hdf5_file_path, 'r') 217 | dset = file['imgs'] 218 | coords = file['coords'][:] 219 | if 'downsampled_level_dim' in dset.attrs.keys(): 220 | w, h = dset.attrs['downsampled_level_dim'] 221 | else: 222 | w, h = dset.attrs['level_dim'] 223 | print('original size: {} x {}'.format(w, h)) 224 | w = w // downscale 225 | h = h //downscale 226 | coords = (coords / downscale).astype(np.int32) 227 | print('downscaled size for stiching: {} x {}'.format(w, h)) 228 | print('number of patches: {}'.format(len(dset))) 229 | img_shape = dset[0].shape 230 | print('patch shape: {}'.format(img_shape)) 231 | downscaled_shape = (img_shape[1] // downscale, img_shape[0] // downscale) 232 | 233 | if w*h > Image.MAX_IMAGE_PIXELS: 234 | raise Image.DecompressionBombError("Visualization Downscale %d is too large" % downscale) 235 | 236 | if alpha < 0 or alpha == -1: 237 | heatmap = Image.new(size=(w,h), mode="RGB", color=bg_color) 238 | else: 239 | heatmap = Image.new(size=(w,h), mode="RGBA", color=bg_color + (int(255 * alpha),)) 240 | 241 | heatmap = np.array(heatmap) 242 | heatmap = DrawMap(heatmap, dset, coords, downscaled_shape, indices=None, draw_grid=draw_grid) 243 | 244 | file.close() 245 | return heatmap 246 | 247 | def StitchCoords(hdf5_file_path, wsi_object, downscale=16, draw_grid=False, bg_color=(0,0,0), alpha=-1): 248 | wsi = wsi_object.getOpenSlide() 249 | vis_level = wsi.get_best_level_for_downsample(downscale) 250 | file = h5py.File(hdf5_file_path, 'r') 251 | dset = file['coords'] 252 | coords = dset[:] 253 | w, h = wsi.level_dimensions[0] 254 | 255 | print('start stitching {}'.format(dset.attrs['name'])) 256 | print('original size: {} x {}'.format(w, h)) 257 | 258 | w, h = wsi.level_dimensions[vis_level] 259 | 260 | print('downscaled size for stiching: {} x {}'.format(w, h)) 261 | print('number of patches: {}'.format(len(coords))) 262 | 263 | patch_size = dset.attrs['patch_size'] 264 | patch_level = dset.attrs['patch_level'] 265 | print('patch size: {}x{} patch level: {}'.format(patch_size, patch_size, patch_level)) 266 | patch_size = tuple((np.array((patch_size, patch_size)) * wsi.level_downsamples[patch_level]).astype(np.int32)) 267 | print('ref patch size: {}x{}'.format(patch_size, patch_size)) 268 | 269 | if w*h > Image.MAX_IMAGE_PIXELS: 270 | raise Image.DecompressionBombError("Visualization Downscale %d is too large" % downscale) 271 | 272 | if alpha < 0 or alpha == -1: 273 | heatmap = Image.new(size=(w,h), mode="RGB", color=bg_color) 274 | else: 275 | heatmap = Image.new(size=(w,h), mode="RGBA", color=bg_color + (int(255 * alpha),)) 276 | 277 | heatmap = np.array(heatmap) 278 | heatmap = DrawMapFromCoords(heatmap, wsi_object, coords, patch_size, vis_level, indices=None, draw_grid=draw_grid) 279 | 280 | file.close() 281 | return heatmap 282 | 283 | def SamplePatches(coords_file_path, save_file_path, wsi_object, 284 | patch_level=0, custom_downsample=1, patch_size=256, sample_num=100, seed=1, stitch=True, verbose=1, mode='w'): 285 | file = h5py.File(coords_file_path, 'r') 286 | dset = file['coords'] 287 | coords = dset[:] 288 | 289 | h5_patch_size = dset.attrs['patch_size'] 290 | h5_patch_level = dset.attrs['patch_level'] 291 | 292 | if verbose>0: 293 | print('in .h5 file: total number of patches: {}'.format(len(coords))) 294 | print('in .h5 file: patch size: {}x{} patch level: {}'.format(h5_patch_size, h5_patch_size, h5_patch_level)) 295 | 296 | if patch_level < 0: 297 | patch_level = 
h5_patch_level 298 | 299 | if patch_size < 0: 300 | patch_size = h5_patch_size 301 | 302 | np.random.seed(seed) 303 | indices = np.random.choice(np.arange(len(coords)), min(len(coords), sample_num), replace=False) 304 | 305 | target_patch_size = np.array([patch_size, patch_size]) 306 | 307 | if custom_downsample > 1: 308 | target_patch_size = (np.array([patch_size, patch_size]) / custom_downsample).astype(np.int32) 309 | 310 | if stitch: 311 | canvas = Mosaic_Canvas(patch_size=target_patch_size[0], n=sample_num, downscale=4, n_per_row=10, bg_color=(0,0,0), alpha=-1) 312 | else: 313 | canvas = None 314 | 315 | for idx in indices: 316 | coord = coords[idx] 317 | patch = wsi_object.wsi.read_region(coord, patch_level, tuple([patch_size, patch_size])).convert('RGB') 318 | if custom_downsample > 1: 319 | patch = patch.resize(tuple(target_patch_size)) 320 | 321 | # if isBlackPatch_S(patch, rgbThresh=20, percentage=0.05) or isWhitePatch_S(patch, rgbThresh=220, percentage=0.25): 322 | # continue 323 | 324 | if stitch: 325 | canvas.paste_patch(patch) 326 | 327 | asset_dict = {'imgs': np.array(patch)[np.newaxis,...], 'coords': coord} 328 | save_hdf5(save_file_path, asset_dict, mode=mode) 329 | mode='a' 330 | 331 | return canvas, len(coords), len(indices) -------------------------------------------------------------------------------- /data_prepare/data_csv/example_xiangya_data_info_pro.csv: -------------------------------------------------------------------------------- 1 | ,slide_id,bingli_id,path,subtype,grade,vis_level 2 | 0,d0ab09865c3b467,少Ⅲ1345304-2,/images/202102/d0ab09865c3b467.tiff,3,8,5 3 | -------------------------------------------------------------------------------- /data_prepare/data_split/xiangya_split_subtype/example_test_split.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/data_prepare/data_split/xiangya_split_subtype/example_test_split.npy -------------------------------------------------------------------------------- /data_prepare/extract_feature_patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import openslide 5 | import h5py 6 | import os 7 | 8 | from torch.utils.data import DataLoader 9 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 10 | import torch 11 | import pandas as pd 12 | from PIL import Image 13 | from tqdm import tqdm 14 | import matplotlib.pyplot as plt 15 | 16 | from patchdataset import Roi_Seg_Dataset, Patch_Seg_Dataset 17 | from models.extractor import resnet50 18 | import models.ResNet as ResNet 19 | from models.ccl import CCL 20 | from models.ctran import ctranspath 21 | from models.simclr_ciga import simclr_ciga_model 22 | 23 | from utils.file_utils import save_hdf5 24 | from utils.utils import collate_features 25 | from PIL import Image 26 | from tqdm import tqdm 27 | import h5py 28 | import openslide 29 | import argparse 30 | 31 | 32 | 33 | 34 | def extract_feats(args,h5_file_path,wsi,slide_path,model,output_path,target_roi_size=2048,patch_size=256,levels=[0,1,2],batch_size=1,is_stain_norm=False): 35 | ''' 36 | extract_feats: 37 | Extract features of patches within ROIs through pre-trained models. 38 | For ROI with size of 2048*2048 at 20x magnification, 8*8=64 patches with size of 256 39 | will be put into models to extract features. 
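(With the default levels [0,1,2], the same ROI is also rescaled to 1024*1024 and 512*512 for the 10x and 5x views, contributing a further 4*4=16 and 2*2=4 patches, i.e. 84 patch embeddings per ROI in total, matching the '84,d' shape noted below.)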
40 | args: 41 | h5_file_path: directory of bag (.h5 file), the coordinates of all patches segmented from the slide(bag) 42 | wsi: object of WSI ('OpenSlide' object) 43 | slide_path: directory of WSI 44 | model: pre-trained model 45 | output_path: directory to save computed featrues (.h5 file) 46 | target_roi_size: size of roi at 20x magnification 47 | patch_size: size of patch used for feature extraction 48 | levels: at which magnification levels the patch features are extracted (List. 0:20x,1:10x,2:5x) 49 | batch_size: batch size of patches for feature extraction 50 | is_stain_norm: whether to perform stain normalization 51 | resize : whether to resize patches to 224*224 52 | ''' 53 | # for ctranspath pre-trained model, input size of image should be reset to 224*224 instead 256*256 54 | if args.pretrained_model == 'ctranspath': 55 | roi_dataset = Roi_Seg_Dataset(args.pretrained_model,h5_file_path,slide_path,wsi,levels,target_roi_size,patch_size,is_stain_norm,resize=True) 56 | else: 57 | roi_dataset = Roi_Seg_Dataset(args.pretrained_model,h5_file_path,slide_path,wsi,levels,target_roi_size,patch_size,is_stain_norm) 58 | roi_dataloader = DataLoader(roi_dataset,batch_size=batch_size,num_workers=4) 59 | # first w 60 | mode = 'w' 61 | 62 | for batch,coords,available in tqdm(roi_dataloader): 63 | with torch.no_grad(): 64 | 65 | for b in range(batch_size): 66 | if not available[b]: 67 | continue 68 | 69 | img_batch = batch[b].cuda() 70 | 71 | features = model(img_batch) # 84,d 72 | features = features.unsqueeze(0) #1,84,d 73 | features = features.cpu().numpy() 74 | 75 | coord = coords[b].unsqueeze(0) 76 | coord = coord.numpy() 77 | 78 | asset_dict = {'features':features,'coords':coord} 79 | save_hdf5(output_path,asset_dict,attr_dict=None,mode=mode) 80 | mode = 'a' 81 | 82 | 83 | def extract_feats_patch(h5_file_path,wsi,slide_path,model,output_path,patch_size=256,batch_size=1,is_stain_norm=False): 84 | ''' 85 | extract_feats_patch: 86 | normal function for extracting patch features. 87 | extract features directly with input patch instead of cutting 88 | it into smaller patches for indidual feature extraction. 
89 | args: 90 | h5_file_path: directory of bag (.h5 file), the coordinates of all patches segmented from the slide(bag) 91 | wsi: object of WSI ('OpenSlide' object) 92 | slide_path: directory of WSI 93 | model: pre-trained model 94 | output_path: directory to save computed featrues (.h5 file) 95 | patch_size: size of patch used for feature extraction 96 | batch_size: batch size of patches for feature extraction 97 | is_stain_norm: whether to perform stain normalization 98 | ''' 99 | roi_dataset = Patch_Seg_Dataset(h5_file_path,slide_path,wsi,patch_size,is_stain_norm) 100 | roi_dataloader = DataLoader(roi_dataset,batch_size=batch_size,num_workers=4,pin_memory=True,collate_fn=collate_features) 101 | # first w 102 | mode = 'w' 103 | 104 | for batch,coords,available in tqdm(roi_dataloader): 105 | with torch.no_grad(): 106 | batch = batch.cuda() 107 | 108 | features = model(batch) 109 | features = features.cpu().numpy() 110 | 111 | if features.shape[0] < 2: 112 | continue 113 | 114 | 115 | features_normal = features[available] 116 | coords_normal = coords[available] 117 | 118 | if features_normal.shape[0] > 0: 119 | asset_dict = {'features':features_normal,'coords':coords_normal} 120 | save_hdf5(output_path,asset_dict,attr_dict=None,mode=mode) 121 | mode = 'a' 122 | 123 | 124 | parser = argparse.ArgumentParser(description='Feature Extraction') 125 | parser.add_argument('--data_h5_dir', type=str, default='/data/glioma_data/datapro') 126 | parser.add_argument('--data_slide_dir', type=str, default='/data/glioma_data/iapsfile') 127 | parser.add_argument('--csv_path', type=str, default='./xiangya_data_info') 128 | parser.add_argument('--dataset', type=str, default='xiangya') 129 | parser.add_argument('--data_format', type=str, default='roi',choices=['roi','patch']) 130 | parser.add_argument('--feat_dir', type=str, default = '/data/glioma_data/datapro') 131 | parser.add_argument('--batch_size', type=int, default=1) 132 | parser.add_argument('--no_auto_skip', default=False, action='store_true') 133 | parser.add_argument('--target_patch_size', type=int, default=256) 134 | parser.add_argument('--target_roi_size', type=int, default=2048) 135 | parser.add_argument('--level',default=0,type=int,choices=[0,1,2]) 136 | parser.add_argument('--is_stain_norm',action='store_true',default=False,help='whether stain normlization') 137 | parser.add_argument('--pretrained_model',type=str,default='ImageNet',choices=['ImageNet','RetCCL','simclr-ciga','ctranspath'],help='model weights for extracting features') 138 | args = parser.parse_args() 139 | 140 | if __name__ == '__main__': 141 | 142 | # create directory to save generated features 143 | os.makedirs(args.feat_dir, exist_ok=True) 144 | if args.is_stain_norm: 145 | args.feat_dir = os.path.join(args.feat_dir, f'feats_{args.pretrained_model}_norm') 146 | #os.makedirs(os.path.join(args.feat_dir, f'feats_{args.pretrained_model}_norm')) 147 | else: 148 | args.feat_dir = os.path.join(args.feat_dir, f'feats_{args.pretrained_model}') 149 | os.makedirs(args.feat_dir,exist_ok=True) 150 | dest_files = os.listdir(args.feat_dir) 151 | 152 | # read slide info csv data 153 | data_csv = pd.read_csv(args.csv_path) 154 | slide_id = data_csv['slide_id'].values 155 | slide_path = data_csv['path'].values 156 | 157 | # calculate magnifications 158 | roi_size_list = [2048,1024,512] 159 | 160 | levels = [i for i in range(4-args.level)] 161 | target_roi_size = roi_size_list[args.level] 162 | 163 | # select pre-trained model for feature extraction 164 | if args.pretrained_model == 'ImageNet': 
165 | model = resnet50(pretrained=True).cuda() 166 | elif args.pretrained_model == 'RetCCL': 167 | backbone = ResNet.resnet50 168 | model = CCL(backbone, 128, 65536, mlp=True, two_branch=True, normlinear=True).cuda() 169 | ckpt_path = f'models/{args.pretrained_model}_ckpt.pth' 170 | model.load_state_dict(torch.load(ckpt_path),strict=True) 171 | model.encoder_q.fc = nn.Identity() 172 | model.encoder_q.instDis = nn.Identity() 173 | model.encoder_q.groupDis = nn.Identity() 174 | elif args.pretrained_model == 'ctranspath': 175 | model = ctranspath() 176 | model.head = nn.Identity() 177 | td = torch.load(r'models/ctranspath.pth') 178 | model.load_state_dict(td['model'], strict=True) 179 | model = model.cuda() 180 | else: 181 | model = simclr_ciga_model().cuda() 182 | 183 | #model = nn.DataParallel(model) 184 | 185 | # feature extraction 186 | model.eval() 187 | 188 | for i in range(len(slide_id)): 189 | print(f'extract features from {slide_id[i]},{i}/{len(slide_id)}') 190 | 191 | bag_name = slide_id[i]+'.h5' 192 | h5_file_path = os.path.join(args.data_h5_dir, 'patches', bag_name) 193 | 194 | if args.dataset == 'xiangya': 195 | slide_file_path = args.data_slide_dir + slide_path[i] 196 | else: 197 | slide_file_path = slide_path[i] 198 | 199 | if not args.no_auto_skip and slide_id[i]+'.h5' in dest_files: 200 | print(f'skipped {slide_id[i]}') 201 | continue 202 | 203 | output_path = os.path.join(args.feat_dir, bag_name) 204 | wsi = openslide.open_slide(slide_file_path) 205 | ''' 206 | data_format: form of feature extraction 207 | roi: segment the roi into patches with size of 256*256 and extract features of these patches 208 | patch: directly extract features of input patches 209 | ''' 210 | if args.data_format == 'roi': 211 | extract_feats(args,h5_file_path,wsi,slide_file_path,model,output_path,target_roi_size=target_roi_size,levels = levels,is_stain_norm=args.is_stain_norm) 212 | else: 213 | extract_feats_patch(h5_file_path,wsi,slide_file_path,model,output_path,batch_size = args.batch_size,is_stain_norm=args.is_stain_norm) 214 | 215 | 216 | -------------------------------------------------------------------------------- /data_prepare/models/ccl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torchvision.models as models 5 | from PIL import Image 6 | import os 7 | 8 | 9 | class CCL(nn.Module): 10 | def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False, two_branch=False, normlinear=False, normalize=False): 11 | super(CCL, self).__init__() 12 | 13 | self.K = K 14 | self.m = m 15 | self.T = T 16 | self.two_branch = two_branch 17 | self.normalize = normalize 18 | 19 | # create the encoders 20 | # num_classes is the output fc dimension 21 | self.encoder_q = base_encoder(num_classes=dim, two_branch=two_branch, mlp=mlp, normlinear=normlinear) 22 | self.encoder_k = base_encoder(num_classes=dim, two_branch=two_branch, mlp=mlp, normlinear=normlinear) 23 | 24 | if mlp and not two_branch: # hack: brute-force replacement 25 | dim_mlp = self.encoder_q.fc.weight.shape[1] 26 | self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc) 27 | self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc) 28 | 29 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 30 | param_k.data.copy_(param_q.data) # initialize 31 | param_k.requires_grad = False # not update by gradient 32 | 33 | 
def forward(self, im_q): 34 | # compute query features 35 | q = self.encoder_q(im_q) # queries: NxC 36 | if self.two_branch: 37 | eq1 = nn.functional.normalize(q[1], dim=1) # branch 2 38 | q = q[0] # branch 1 39 | if self.normalize: 40 | print(1) 41 | q = nn.functional.normalize(q, dim=1) 42 | return q -------------------------------------------------------------------------------- /data_prepare/models/ctran.py: -------------------------------------------------------------------------------- 1 | from timm.models.layers.helpers import to_2tuple 2 | import timm 3 | import torch.nn as nn 4 | 5 | 6 | class ConvStem(nn.Module): 7 | 8 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 9 | super().__init__() 10 | 11 | assert patch_size == 4 12 | assert embed_dim % 8 == 0 13 | 14 | img_size = to_2tuple(img_size) 15 | patch_size = to_2tuple(patch_size) 16 | self.img_size = img_size 17 | self.patch_size = patch_size 18 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 19 | self.num_patches = self.grid_size[0] * self.grid_size[1] 20 | self.flatten = flatten 21 | 22 | 23 | stem = [] 24 | input_dim, output_dim = 3, embed_dim // 8 25 | for l in range(2): 26 | stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False)) 27 | stem.append(nn.BatchNorm2d(output_dim)) 28 | stem.append(nn.ReLU(inplace=True)) 29 | input_dim = output_dim 30 | output_dim *= 2 31 | stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1)) 32 | self.proj = nn.Sequential(*stem) 33 | 34 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 35 | 36 | def forward(self, x): 37 | B, C, H, W = x.shape 38 | assert H == self.img_size[0] and W == self.img_size[1], \ 39 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
40 | x = self.proj(x) 41 | if self.flatten: 42 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 43 | x = self.norm(x) 44 | return x 45 | 46 | def ctranspath(): 47 | model = timm.create_model('swin_tiny_patch4_window7_224', embed_layer=ConvStem, pretrained=False) 48 | return model -------------------------------------------------------------------------------- /data_prepare/models/extractor.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | import torch 4 | from torchsummary import summary 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152'] 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | class Bottleneck_Baseline(nn.Module): 19 | expansion = 4 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None): 22 | super(Bottleneck_Baseline, self).__init__() 23 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 26 | padding=1, bias=False) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 29 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv3(out) 46 | out = self.bn3(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | class ResNet_Baseline(nn.Module): 57 | 58 | def __init__(self, block, layers): 59 | self.inplanes = 64 60 | super(ResNet_Baseline, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 66 | self.layer1 = self._make_layer(block, 64, layers[0]) 67 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 68 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 69 | self.avgpool = nn.AdaptiveAvgPool2d(1) 70 | 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 74 | elif isinstance(m, nn.BatchNorm2d): 75 | nn.init.constant_(m.weight, 1) 76 | nn.init.constant_(m.bias, 0) 77 | 78 | def _make_layer(self, block, planes, blocks, stride=1): 79 | downsample = None 80 | if stride != 1 or self.inplanes != planes * block.expansion: 81 | downsample = nn.Sequential( 82 | nn.Conv2d(self.inplanes, planes * block.expansion, 83 | kernel_size=1, stride=stride, bias=False), 84 | nn.BatchNorm2d(planes * block.expansion), 85 | ) 86 | 87 | layers = [] 88 | 
layers.append(block(self.inplanes, planes, stride, downsample)) 89 | self.inplanes = planes * block.expansion 90 | for i in range(1, blocks): 91 | layers.append(block(self.inplanes, planes)) 92 | 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | x = self.conv1(x) 97 | x = self.bn1(x) 98 | x = self.relu(x) 99 | x = self.maxpool(x) 100 | 101 | x = self.layer1(x) 102 | x = self.layer2(x) 103 | x = self.layer3(x) 104 | 105 | x = self.avgpool(x) 106 | x = x.view(x.size(0), -1) 107 | 108 | return x 109 | 110 | def resnet50(pretrained=False,pretrained_weights=None): 111 | """Constructs a Modified ResNet-50 model. 112 | Args: 113 | pretrained (bool): If True, returns a model pre-trained on ImageNet 114 | pretrained_weights: not None for pathological image 115 | """ 116 | model = ResNet_Baseline(Bottleneck_Baseline, [3, 4, 6, 3]) 117 | if pretrained: 118 | if pretrained_weights: 119 | model.load_state_dict(pretrained_weights) 120 | else: 121 | model = load_pretrained_weights(model, 'resnet50') 122 | return model 123 | 124 | def load_pretrained_weights(model, name): 125 | pretrained_dict = model_zoo.load_url(model_urls[name]) 126 | model.load_state_dict(pretrained_dict, strict=False) 127 | return model 128 | 129 | -------------------------------------------------------------------------------- /data_prepare/models/simclr_ciga.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torch 3 | 4 | 5 | MODEL_PATH = 'models/Simclr_ciga.ckpt' 6 | RETURN_PREACTIVATION = True # return features from the model, if false return classification logits 7 | NUM_CLASSES = 4 # only used if RETURN_PREACTIVATION = False 8 | 9 | 10 | def load_model_weights(model, weights): 11 | 12 | model_dict = model.state_dict() 13 | weights = {k: v for k, v in weights.items() if k in model_dict} 14 | if weights == {}: 15 | print('No weight could be loaded..') 16 | model_dict.update(weights) 17 | model.load_state_dict(model_dict) 18 | 19 | return model 20 | 21 | def simclr_ciga_model(): 22 | model = torchvision.models.__dict__['resnet18'](pretrained=False) 23 | 24 | state = torch.load(MODEL_PATH, map_location='cuda:0') 25 | 26 | state_dict = state['state_dict'] 27 | for key in list(state_dict.keys()): 28 | state_dict[key.replace('model.', '').replace('resnet.', '')] = state_dict.pop(key) 29 | 30 | model = load_model_weights(model, state_dict) 31 | 32 | if RETURN_PREACTIVATION: 33 | model.fc = torch.nn.Sequential() 34 | else: 35 | model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES) 36 | 37 | return model 38 | -------------------------------------------------------------------------------- /data_prepare/patchdataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import openslide 3 | import numpy as np 4 | 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | from PIL import Image 9 | import h5py 10 | import vahadane 11 | 12 | 13 | mean = (0.485, 0.456, 0.406) 14 | std = (0.229, 0.224, 0.225) 15 | transform_patch = transforms.Compose( 16 | [# may be other transform 17 | transforms.ToTensor(), 18 | transforms.Normalize(mean = mean, std = std) 19 | ] 20 | ) 21 | 22 | 23 | class Roi_Seg_Dataset(Dataset): 24 | def __init__(self, 25 | embed_type='ImageNet', 26 | file_path='', 27 | slide_path='', 28 | wsi=None, 29 | levels=[0,1,2], 30 | target_roi_size = 2048, 31 | patch_size=256, 32 | is_stain_norm=False, 33 | resize=False): 34 | ''' 35 | args: 36 | 
embed_type (string): type of pre-trained model. select from ImageNet, RetCCL, ctranspath, simclr-ciga 37 | file_path (string): directory of bag (.h5 file), the coordinates of all patches segmented from the slide(bag) 38 | slide_path (string): directory of WSI 39 | wsi ('OpenSlide' object): object of WSI, for reading regions from WSI according to coordinates. 40 | levels (Lsit): at which magnification levels the patch features are extracted (0:20x,1:10x,2:5x) 41 | target_roi_size (int): size of roi at 20x magnification 42 | patch_size (int): size of patch used for feature extraction 43 | is_stain_norm (bool): whether to perform stain normalization 44 | resize (bool): whether to resize patches to 224*224 45 | ''' 46 | self.file_path = file_path 47 | self.wsi = wsi 48 | self.levels = levels #[0,1] or [0,1,2] 49 | self.patch_size = patch_size 50 | self.slide_path = slide_path 51 | self.target_roi_size = target_roi_size 52 | self.downscale = 0 53 | self.resize = resize 54 | 55 | ''' 56 | mean and std for normalization of simclr-ciga are different 57 | ''' 58 | mean = (0.485, 0.456, 0.406) 59 | std = (0.229, 0.224, 0.225) 60 | if embed_type == 'simclr-ciga': 61 | mean = (0.5, 0.5, 0.5) 62 | std = (0.5, 0.5, 0.5) 63 | 64 | self.transform_patch = transforms.Compose( 65 | [# may be other transform 66 | transforms.ToTensor(), 67 | transforms.Normalize(mean = mean, std = std) 68 | ] 69 | ) 70 | 71 | with h5py.File(self.file_path,'r') as f: 72 | dset = f['coords'] 73 | self.roi_level = f['coords'].attrs['patch_level'] 74 | self.roi_size = f['coords'].attrs['patch_size'] 75 | self.downscale = int(self.roi_size/self.target_roi_size) 76 | self.length = len(dset) 77 | patch_num_0 = (self.target_roi_size/self.patch_size)**2 78 | self.patch_nums = [int(patch_num_0/(2**level)) for level in self.levels] 79 | 80 | 81 | # select target image for stain normalization 82 | if target_roi_size == 512: 83 | target_image_dir = 'target_images/target_image_6e3_512.jpg' 84 | if target_roi_size == 1024: 85 | target_image_dir = 'target_images/target_image_6e3_1024.jpg' 86 | if target_roi_size == 2048: 87 | target_image_dir = 'target_images/target_roi_6e3.jpg' 88 | 89 | if is_stain_norm: 90 | self.target_img = np.array(Image.open(target_image_dir)) 91 | self.vhd = vahadane.vahadane(LAMBDA1=0.01,LAMBDA2=0.01,fast_mode=0,ITER=100) 92 | self.Wt,self.Ht = self.vhd.stain_separate(self.target_img) 93 | self.is_stain_norm = is_stain_norm 94 | 95 | 96 | 97 | def __len__(self): 98 | return self.length 99 | 100 | def stain_norm(self,src_img): 101 | ''' 102 | perform stain normalization for source img 103 | input (numpy array, shape: (3,h,w)): source image 104 | output (numpy array, shape: (3,h,w)): normalized image 105 | ''' 106 | std = np.std(src_img[:,:,0].reshape(-1)) 107 | # exclude images with large backgrounds 108 | if std < 5: 109 | return src_img,False 110 | else: 111 | Ws,Hs = self.vhd.stain_separate(src_img) 112 | img = self.vhd.SPCN(src_img,Ws,Hs,self.Wt,self.Ht) 113 | return img,True 114 | 115 | def __getitem__(self,idx): 116 | with h5py.File(self.file_path,'r') as hdf5_file: 117 | coord = hdf5_file['coords'][idx] 118 | 119 | try: 120 | img = self.wsi.read_region(coord, self.roi_level, (self.roi_size, self.roi_size)).convert('RGB') 121 | except: 122 | # or subsequent normal patches will also raise errors 123 | self.wsi = openslide.open_slide(self.slide_path) 124 | available = False 125 | 126 | else: 127 | img = self.wsi.read_region(coord, self.roi_level, (self.roi_size, self.roi_size)).convert('RGB') 128 | available = 
True 129 | 130 | patch_num_all = np.sum(self.patch_nums) 131 | if not available: 132 | if self.resize: 133 | img_batch = torch.zeros((patch_num_all,3,224,224)) 134 | else: 135 | img_batch = torch.zeros((patch_num_all,3,self.patch_size,self.patch_size)) 136 | print(f'not available: {img_batch.shape}') 137 | else: 138 | img_batch = [] 139 | img_roi = img.resize((self.target_roi_size,self.target_roi_size)) 140 | flag = True # default when stain normalization is disabled 141 | if self.is_stain_norm: 142 | img_roi,flag = self.stain_norm(np.array(img_roi)) 143 | 144 | if not flag: 145 | img_batch = torch.zeros((patch_num_all,3,224,224)) if self.resize else torch.zeros((patch_num_all,3,self.patch_size,self.patch_size)) 146 | available = False 147 | else: 148 | for level in self.levels: 149 | roi_size_cur = int(self.target_roi_size/(2**level)) 150 | img_roi = np.array(img_roi) 151 | img_cur = Image.fromarray(img_roi).resize((roi_size_cur,roi_size_cur)) 152 | 153 | imgarray = np.array(img_cur) 154 | for i in range(0,roi_size_cur,self.patch_size): 155 | for j in range(0,roi_size_cur,self.patch_size): 156 | img_patch = imgarray[i:i+self.patch_size,j:j+self.patch_size,:] 157 | if self.resize: 158 | img_patch = Image.fromarray(img_patch).resize((224,224)) 159 | img_patch = np.array(img_patch) 160 | img_patch = self.transform_patch(img_patch) 161 | img_batch.append(img_patch) 162 | 163 | 164 | if available: 165 | img_batch = torch.stack(img_batch) 166 | 167 | return img_batch, coord, torch.tensor([available]) 168 | 169 | 170 | ### for feature extraction of a single patch 171 | class Patch_Seg_Dataset(Dataset): 172 | def __init__(self, 173 | file_path, 174 | slide_path, 175 | wsi, 176 | patch_size=256, 177 | is_stain_norm=False): 178 | 179 | ''' 180 | args: 181 | file_path (string): directory of bag (.h5 file), the coordinates of all patches segmented from the slide(bag) 182 | slide_path (string): directory of WSI 183 | wsi ('OpenSlide' object): object of WSI, for reading regions from WSI according to coordinates. 
184 | patch_size (int): size of patch used for feature extraction 185 | is_stain_norm (bool): whether to perform stain normalization 186 | ''' 187 | 188 | self.file_path = file_path 189 | self.wsi = wsi 190 | self.target_patch_size = patch_size 191 | self.slide_path = slide_path 192 | self.downscale = 0 193 | 194 | with h5py.File(self.file_path,'r') as f: 195 | dset = f['coords'] 196 | self.patch_level = f['coords'].attrs['patch_level'] 197 | self.patch_size = f['coords'].attrs['patch_size'] 198 | self.downscale = int(self.patch_size/self.target_patch_size) 199 | self.length = len(dset) 200 | 201 | print(self.patch_size) 202 | 203 | if self.patch_size == 256: 204 | target_image_dir = 'target_images/target_image_6e3_256.jpg' 205 | if self.patch_size == 512: 206 | target_image_dir = 'target_images/target_image_6e3_512.jpg' 207 | if self.patch_size == 1024: 208 | target_image_dir = 'target_images/target_image_6e3_1024.jpg' 209 | if is_stain_norm: 210 | self.target_img = np.array(Image.open(target_image_dir)) 211 | 212 | self.vhd = vahadane.vahadane(LAMBDA1=0.01,LAMBDA2=0.01,fast_mode=0,ITER=100) 213 | self.Wt,self.Ht = self.vhd.stain_separate(self.target_img) 214 | #self.vhd.fast_mode = 1 #fast separate 215 | self.is_stain_norm = is_stain_norm 216 | 217 | 218 | 219 | def __len__(self): 220 | return self.length 221 | 222 | def stain_norm(self,src_img): 223 | 224 | std = np.std(src_img[:,:,0].reshape(-1)) 225 | if std < 10: 226 | return src_img,False 227 | else: 228 | Ws,Hs = self.vhd.stain_separate(src_img) 229 | img = self.vhd.SPCN(src_img,Ws,Hs,self.Wt,self.Ht) 230 | return img,True 231 | 232 | def __getitem__(self,idx): 233 | with h5py.File(self.file_path,'r') as hdf5_file: 234 | coord = hdf5_file['coords'][idx] 235 | #print(coord) 236 | try: 237 | img = self.wsi.read_region(coord, self.patch_level, (self.patch_size, self.patch_size)).convert('RGB') 238 | except: 239 | # or subsequent normal patches will also raise errors 240 | self.wsi = openslide.open_slide(self.slide_path) 241 | available = False 242 | 243 | else: 244 | img = self.wsi.read_region(coord, self.patch_level, (self.patch_size, self.patch_size)).convert('RGB') 245 | available = True 246 | 247 | if not available: 248 | img_patch = torch.ones((1,3,self.target_patch_size,self.target_patch_size)) 249 | 250 | else: 251 | img_patch = img.resize((self.target_patch_size,self.target_patch_size)) 252 | 253 | flag = True 254 | if self.is_stain_norm: 255 | img_patch,flag = self.stain_norm(np.array(img_patch)) 256 | 257 | if not flag: 258 | img_patch = torch.ones((1,3,self.target_patch_size,self.target_patch_size)) 259 | available = False 260 | else: 261 | 262 | img_patch = transform_patch(img_patch) 263 | img_patch = img_patch.unsqueeze(0) 264 | 265 | return img_patch, coord, torch.tensor([available]) 266 | 267 | -------------------------------------------------------------------------------- /data_prepare/target_images/target_image_6e3_1024.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/data_prepare/target_images/target_image_6e3_1024.jpg -------------------------------------------------------------------------------- /data_prepare/target_images/target_image_6e3_256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/data_prepare/target_images/target_image_6e3_256.jpg 
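A minimal usage sketch for the ROI dataset defined above, assuming a slide has already been processed by create_patches_fp.py and that the script is run from inside data_prepare/ (so patchdataset and utils.utils resolve); the slide and .h5 paths are hypothetical and only for illustration. collate_features is the collate function from data_prepare/utils/utils.py.

import openslide
from torch.utils.data import DataLoader
from patchdataset import Roi_Seg_Dataset
from utils.utils import collate_features

slide_path = '/data/slides/example.svs'   # hypothetical WSI path
h5_path = '/data/patches/example.h5'      # hypothetical coordinate file produced by create_patches_fp.py
wsi = openslide.open_slide(slide_path)

# one item = all patches of one ROI at 20x/10x/5x, already normalized for the chosen encoder
dataset = Roi_Seg_Dataset(embed_type='ImageNet', file_path=h5_path, slide_path=slide_path,
                          wsi=wsi, levels=[0, 1, 2], target_roi_size=2048,
                          patch_size=256, is_stain_norm=False, resize=False)
loader = DataLoader(dataset, batch_size=1, collate_fn=collate_features, num_workers=0)

for img_batch, coords, available in loader:
    # img_batch: (num_patches, 3, 256, 256); coords: (1, 2); available: (1,) flag tensor
    print(img_batch.shape, coords.shape, available)
    break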
-------------------------------------------------------------------------------- /data_prepare/target_images/target_image_6e3_512.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/data_prepare/target_images/target_image_6e3_512.jpg -------------------------------------------------------------------------------- /data_prepare/target_images/target_roi_6e3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/data_prepare/target_images/target_roi_6e3.jpg -------------------------------------------------------------------------------- /data_prepare/utils/__pycache__/file_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/data_prepare/utils/__pycache__/file_utils.cpython-38.pyc -------------------------------------------------------------------------------- /data_prepare/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/data_prepare/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /data_prepare/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import h5py 3 | 4 | def save_pkl(filename, save_object): 5 | writer = open(filename,'wb') 6 | pickle.dump(save_object, writer) 7 | writer.close() 8 | 9 | def load_pkl(filename): 10 | loader = open(filename,'rb') 11 | file = pickle.load(loader) 12 | loader.close() 13 | return file 14 | 15 | 16 | def save_hdf5(output_path, asset_dict, attr_dict= None, mode='a'): 17 | file = h5py.File(output_path, mode) 18 | for key, val in asset_dict.items(): 19 | data_shape = val.shape 20 | if key not in file: 21 | data_type = val.dtype 22 | chunk_shape = (1, ) + data_shape[1:] 23 | maxshape = (None, ) + data_shape[1:] 24 | dset = file.create_dataset(key, shape=data_shape, maxshape=maxshape, chunks=chunk_shape, dtype=data_type) 25 | dset[:] = val 26 | if attr_dict is not None: 27 | if key in attr_dict.keys(): 28 | for attr_key, attr_val in attr_dict[key].items(): 29 | dset.attrs[attr_key] = attr_val 30 | else: 31 | dset = file[key] 32 | dset.resize(len(dset) + data_shape[0], axis=0) 33 | dset[-data_shape[0]:] = val 34 | file.close() 35 | return output_path -------------------------------------------------------------------------------- /data_prepare/utils/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import pdb 6 | 7 | import torch 8 | import numpy as np 9 | import torch.nn as nn 10 | from torchvision import transforms 11 | from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler, SequentialSampler, sampler 12 | import torch.optim as optim 13 | import pdb 14 | import torch.nn.functional as F 15 | import math 16 | from itertools import islice 17 | import collections 18 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | class SubsetSequentialSampler(Sampler): 21 | """Samples elements sequentially 
from a given list of indices, without replacement. 22 | 23 | Arguments: 24 | indices (sequence): a sequence of indices 25 | """ 26 | def __init__(self, indices): 27 | self.indices = indices 28 | 29 | def __iter__(self): 30 | return iter(self.indices) 31 | 32 | def __len__(self): 33 | return len(self.indices) 34 | 35 | def collate_MIL(batch): 36 | img = torch.cat([item[0] for item in batch], dim = 0) 37 | label = torch.LongTensor([item[1] for item in batch]) 38 | return [img, label] 39 | 40 | def collate_features(batch): 41 | img = torch.cat([item[0] for item in batch], dim = 0) 42 | coords = np.vstack([item[1] for item in batch]) 43 | available = torch.cat([item[2] for item in batch], dim = 0) 44 | #print(img.shape) 45 | #print(coords.shape) 46 | #print(available.shape) 47 | return [img, coords, available] 48 | 49 | 50 | def get_simple_loader(dataset, batch_size=1, num_workers=1): 51 | kwargs = {'num_workers': num_workers, 'pin_memory': False} if device.type == "cuda" else {} 52 | loader = DataLoader(dataset, batch_size=batch_size, sampler = sampler.SequentialSampler(dataset), collate_fn = collate_MIL, **kwargs) 53 | return loader 54 | 55 | def get_split_loader(split_dataset, training = False, testing = False, weighted = False): 56 | """ 57 | return either the training, validation or testing loader 58 | """ 59 | kwargs = {'num_workers': 4} if device.type == "cuda" else {} 60 | if not testing: 61 | if training: 62 | if weighted: 63 | weights = make_weights_for_balanced_classes_split(split_dataset) 64 | loader = DataLoader(split_dataset, batch_size=1, sampler = WeightedRandomSampler(weights, len(weights)), collate_fn = collate_MIL, **kwargs) 65 | else: 66 | loader = DataLoader(split_dataset, batch_size=1, sampler = RandomSampler(split_dataset), collate_fn = collate_MIL, **kwargs) 67 | else: 68 | loader = DataLoader(split_dataset, batch_size=1, sampler = SequentialSampler(split_dataset), collate_fn = collate_MIL, **kwargs) 69 | 70 | else: 71 | ids = np.random.choice(np.arange(len(split_dataset)), int(len(split_dataset)*0.1), replace = False) 72 | loader = DataLoader(split_dataset, batch_size=1, sampler = SubsetSequentialSampler(ids), collate_fn = collate_MIL, **kwargs ) 73 | 74 | return loader 75 | 76 | def get_optim(model, args): 77 | if args.opt == "adam": 78 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.reg) 79 | elif args.opt == 'sgd': 80 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.reg) 81 | else: 82 | raise NotImplementedError 83 | return optimizer 84 | 85 | def print_network(net): 86 | num_params = 0 87 | num_params_train = 0 88 | print(net) 89 | 90 | for param in net.parameters(): 91 | n = param.numel() 92 | num_params += n 93 | if param.requires_grad: 94 | num_params_train += n 95 | 96 | print('Total number of parameters: %d' % num_params) 97 | print('Total number of trainable parameters: %d' % num_params_train) 98 | 99 | 100 | def generate_split(cls_ids, val_num, test_num, samples, n_splits = 5, 101 | seed = 7, label_frac = 1.0, custom_test_ids = None): 102 | indices = np.arange(samples).astype(int) 103 | 104 | if custom_test_ids is not None: 105 | indices = np.setdiff1d(indices, custom_test_ids) 106 | 107 | np.random.seed(seed) 108 | for i in range(n_splits): 109 | all_val_ids = [] 110 | all_test_ids = [] 111 | sampled_train_ids = [] 112 | 113 | if custom_test_ids is not None: # pre-built test split, do not need to sample 114
| all_test_ids.extend(custom_test_ids) 115 | 116 | for c in range(len(val_num)): 117 | possible_indices = np.intersect1d(cls_ids[c], indices) #all indices of this class 118 | val_ids = np.random.choice(possible_indices, val_num[c], replace = False) # validation ids 119 | 120 | remaining_ids = np.setdiff1d(possible_indices, val_ids) #indices of this class left after validation 121 | all_val_ids.extend(val_ids) 122 | 123 | if custom_test_ids is None: # sample test split 124 | 125 | test_ids = np.random.choice(remaining_ids, test_num[c], replace = False) 126 | remaining_ids = np.setdiff1d(remaining_ids, test_ids) 127 | all_test_ids.extend(test_ids) 128 | 129 | if label_frac == 1: 130 | sampled_train_ids.extend(remaining_ids) 131 | 132 | else: 133 | #print(len(remaining_ids)) 134 | sample_num = math.ceil(len(remaining_ids) * label_frac) 135 | slice_ids = np.arange(sample_num) 136 | sampled_train_ids.extend(remaining_ids[slice_ids]) 137 | 138 | yield sampled_train_ids, all_val_ids, all_test_ids 139 | 140 | 141 | def nth(iterator, n, default=None): 142 | if n is None: 143 | return collections.deque(iterator, maxlen=0) 144 | else: 145 | return next(islice(iterator,n, None), default) 146 | 147 | def calculate_error(Y_hat, Y): 148 | error = 1. - Y_hat.float().eq(Y.float()).float().mean().item() 149 | 150 | return error 151 | 152 | def make_weights_for_balanced_classes_split(dataset): 153 | N = float(len(dataset)) 154 | weight_per_class = [N/len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids))] 155 | weight = [0] * int(N) 156 | for idx in range(len(dataset)): 157 | y = dataset.getlabel(idx) 158 | weight[idx] = weight_per_class[y] 159 | 160 | return torch.DoubleTensor(weight) 161 | 162 | def initialize_weights(module): 163 | for m in module.modules(): 164 | if isinstance(m, nn.Linear): 165 | nn.init.xavier_normal_(m.weight) 166 | m.bias.data.zero_() 167 | 168 | elif isinstance(m, nn.BatchNorm1d): 169 | nn.init.constant_(m.weight, 1) 170 | nn.init.constant_(m.bias, 0) 171 | 172 | -------------------------------------------------------------------------------- /data_prepare/vahadane.py: -------------------------------------------------------------------------------- 1 | import spams 2 | import numpy as np 3 | import cv2 4 | import time 5 | 6 | 7 | class vahadane(object): 8 | 9 | def __init__(self, STAIN_NUM=2, THRESH=0.9, LAMBDA1=0.01, LAMBDA2=0.01, ITER=100, fast_mode=0, getH_mode=0): 10 | self.STAIN_NUM = STAIN_NUM 11 | self.THRESH = THRESH 12 | self.LAMBDA1 = LAMBDA1 13 | self.LAMBDA2 = LAMBDA2 14 | self.ITER = ITER 15 | self.fast_mode = fast_mode # 0: normal; 1: fast 16 | self.getH_mode = getH_mode # 0: spams.lasso; 1: pinv; 17 | 18 | 19 | def show_config(self): 20 | print('STAIN_NUM =', self.STAIN_NUM) 21 | print('THRESH =', self.THRESH) 22 | print('LAMBDA1 =', self.LAMBDA1) 23 | print('LAMBDA2 =', self.LAMBDA2) 24 | print('ITER =', self.ITER) 25 | print('fast_mode =', self.fast_mode) 26 | print('getH_mode =', self.getH_mode) 27 | 28 | 29 | def getV(self, img): 30 | 31 | I0 = img.reshape((-1,3)).T 32 | I0[I0==0] = 1 33 | V0 = np.log(255 / I0) 34 | 35 | img_LAB = cv2.cvtColor(img, cv2.COLOR_RGB2LAB) 36 | mask = img_LAB[:, :, 0] / 255 < self.THRESH 37 | I = img[mask].reshape((-1, 3)).T 38 | I[I == 0] = 1 39 | V = np.log(255 / I) 40 | 41 | return V0, V 42 | 43 | 44 | def getW(self, V): 45 | W = spams.trainDL(np.asfortranarray(V), numThreads=1, K=self.STAIN_NUM, lambda1=self.LAMBDA1, iter=self.ITER, mode=2, modeD=0, posAlpha=True, posD=True, verbose=False) 46 | W = W / 
np.linalg.norm(W, axis=0)[None, :] 47 | if (W[0,0] < W[0,1]): 48 | W = W[:, [1,0]] 49 | return W 50 | 51 | 52 | def getH(self, V, W): 53 | if (self.getH_mode == 0): 54 | H = spams.lasso(np.asfortranarray(V), np.asfortranarray(W), numThreads=1, mode=2, lambda1=self.LAMBDA2, pos=True, verbose=False).toarray() 55 | elif (self.getH_mode == 1): 56 | H = np.linalg.pinv(W).dot(V); 57 | H[H<0] = 0 58 | else: 59 | H = 0 60 | return H 61 | 62 | 63 | def stain_separate(self, img): 64 | start = time.time() 65 | if (self.fast_mode == 0): 66 | V0, V = self.getV(img) 67 | W = self.getW(V) 68 | H = self.getH(V0, W) 69 | elif (self.fast_mode == 1): 70 | m = img.shape[0] 71 | n = img.shape[1] 72 | grid_size_m = int(m / 5) 73 | lenm = int(m / 20) 74 | grid_size_n = int(n / 5) 75 | lenn = int(n / 20) 76 | W = np.zeros((81, 3, self.STAIN_NUM)).astype(np.float64) 77 | for i in range(0, 4): 78 | for j in range(0, 4): 79 | px = (i + 1) * grid_size_m 80 | py = (j + 1) * grid_size_n 81 | patch = img[px - lenm : px + lenm, py - lenn: py + lenn, :] 82 | V0, V = self.getV(patch) 83 | W[i*9+j] = self.getW(V) 84 | W = np.mean(W, axis=0) 85 | V0, V = self.getV(img) 86 | H = self.getH(V0, W) 87 | #print('stain separation time:', time.time()-start, 's') 88 | return W, H 89 | 90 | 91 | def SPCN(self, img, Ws, Hs, Wt, Ht): 92 | Hs_RM = np.percentile(Hs, 99) 93 | Ht_RM = np.percentile(Ht, 99) 94 | Hs_norm = Hs * Ht_RM / Hs_RM 95 | Vs_norm = np.dot(Wt, Hs_norm) 96 | Is_norm = 255 * np.exp(-1 * Vs_norm) 97 | I = Is_norm.T.reshape(img.shape).astype(np.uint8) 98 | return I 99 | -------------------------------------------------------------------------------- /data_prepare/wsi_core/__pycache__/WholeSlideImage.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/data_prepare/wsi_core/__pycache__/WholeSlideImage.cpython-37.pyc -------------------------------------------------------------------------------- /data_prepare/wsi_core/__pycache__/WholeSlideImage.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/data_prepare/wsi_core/__pycache__/WholeSlideImage.cpython-38.pyc -------------------------------------------------------------------------------- /data_prepare/wsi_core/__pycache__/batch_process_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/data_prepare/wsi_core/__pycache__/batch_process_utils.cpython-37.pyc -------------------------------------------------------------------------------- /data_prepare/wsi_core/__pycache__/batch_process_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/data_prepare/wsi_core/__pycache__/batch_process_utils.cpython-38.pyc -------------------------------------------------------------------------------- /data_prepare/wsi_core/__pycache__/util_classes.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/data_prepare/wsi_core/__pycache__/util_classes.cpython-37.pyc 
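A short sketch of how the vahadane class above is typically driven for stain normalization, mirroring its use in patchdataset.py; the source image path is hypothetical, and both images are expected as RGB uint8 numpy arrays.

import numpy as np
from PIL import Image
import vahadane

target = np.array(Image.open('target_images/target_roi_6e3.jpg'))   # reference staining shipped in data_prepare/target_images
source = np.array(Image.open('example_roi.jpg'))                    # hypothetical ROI to be normalized

vhd = vahadane.vahadane(LAMBDA1=0.01, LAMBDA2=0.01, fast_mode=0, ITER=100)
Wt, Ht = vhd.stain_separate(target)    # target stain color matrix and concentrations
Ws, Hs = vhd.stain_separate(source)    # source stain color matrix and concentrations
normalized = vhd.SPCN(source, Ws, Hs, Wt, Ht)   # map source concentrations onto the target stain basis

Image.fromarray(normalized).save('example_roi_normalized.jpg')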
-------------------------------------------------------------------------------- /data_prepare/wsi_core/__pycache__/util_classes.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/data_prepare/wsi_core/__pycache__/util_classes.cpython-38.pyc -------------------------------------------------------------------------------- /data_prepare/wsi_core/__pycache__/wsi_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/data_prepare/wsi_core/__pycache__/wsi_utils.cpython-37.pyc -------------------------------------------------------------------------------- /data_prepare/wsi_core/__pycache__/wsi_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/data_prepare/wsi_core/__pycache__/wsi_utils.cpython-38.pyc -------------------------------------------------------------------------------- /data_prepare/wsi_core/batch_process_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import pdb 4 | 5 | ''' 6 | initiate a pandas df describing a list of slides to process 7 | args: 8 | slides (df or array-like): 9 | array-like structure containing list of slide ids, if df, these ids assumed to be 10 | stored under the 'slide_id' column 11 | seg_params (dict): segmentation paramters 12 | filter_params (dict): filter parameters 13 | vis_params (dict): visualization paramters 14 | patch_params (dict): patching paramters 15 | use_heatmap_args (bool): whether to include heatmap arguments such as ROI coordinates 16 | ''' 17 | def initialize_df(slides, slides_path, seg_params, filter_params, vis_params, patch_params, 18 | use_heatmap_args=False, save_patches=False): 19 | 20 | total = len(slides) 21 | if isinstance(slides, pd.DataFrame): 22 | slide_ids = slides.slide_id.values 23 | else: 24 | slide_ids = slides 25 | default_df_dict = {'slide_id': slide_ids, 'slide_path':slides_path, 'process': np.full((total), 1, dtype=np.uint8)} 26 | 27 | # initiate empty labels in case not provided 28 | if use_heatmap_args: 29 | default_df_dict.update({'label': np.full((total), -1)}) 30 | 31 | default_df_dict.update({ 32 | 'status': np.full((total), 'tbp'), 33 | # seg params 34 | 'seg_level': np.full((total), int(seg_params['seg_level']), dtype=np.int8), 35 | 'sthresh': np.full((total), int(seg_params['sthresh']), dtype=np.uint8), 36 | 'mthresh': np.full((total), int(seg_params['mthresh']), dtype=np.uint8), 37 | 'close': np.full((total), int(seg_params['close']), dtype=np.uint32), 38 | 'use_otsu': np.full((total), bool(seg_params['use_otsu']), dtype=bool), 39 | 'keep_ids': np.full((total), seg_params['keep_ids']), 40 | 'exclude_ids': np.full((total), seg_params['exclude_ids']), 41 | 42 | # filter params 43 | 'a_t': np.full((total), int(filter_params['a_t']), dtype=np.float32), 44 | 'a_h': np.full((total), int(filter_params['a_h']), dtype=np.float32), 45 | 'max_n_holes': np.full((total), int(filter_params['max_n_holes']), dtype=np.uint32), 46 | 47 | # vis params 48 | 'vis_level': np.full((total), int(vis_params['vis_level']), dtype=np.int8), 49 | 'line_thickness': np.full((total), int(vis_params['line_thickness']), dtype=np.uint32), 
50 | 51 | # patching params 52 | 'use_padding': np.full((total), bool(patch_params['use_padding']), dtype=bool), 53 | 'contour_fn': np.full((total), patch_params['contour_fn']) 54 | }) 55 | 56 | if save_patches: 57 | default_df_dict.update({ 58 | 'white_thresh': np.full((total), int(patch_params['white_thresh']), dtype=np.uint8), 59 | 'black_thresh': np.full((total), int(patch_params['black_thresh']), dtype=np.uint8)}) 60 | 61 | if use_heatmap_args: 62 | # initiate empty x,y coordinates in case not provided 63 | default_df_dict.update({'x1': np.empty((total)).fill(np.NaN), 64 | 'x2': np.empty((total)).fill(np.NaN), 65 | 'y1': np.empty((total)).fill(np.NaN), 66 | 'y2': np.empty((total)).fill(np.NaN)}) 67 | 68 | 69 | if isinstance(slides, pd.DataFrame): 70 | temp_copy = pd.DataFrame(default_df_dict) # temporary dataframe w/ default params 71 | # find key in provided df 72 | # if exist, fill empty fields w/ default values, else, insert the default values as a new column 73 | for key in default_df_dict.keys(): 74 | if key in slides.columns: 75 | mask = slides[key].isna() 76 | slides.loc[mask, key] = temp_copy.loc[mask, key] 77 | else: 78 | slides.insert(len(slides.columns), key, default_df_dict[key]) 79 | else: 80 | slides = pd.DataFrame(default_df_dict) 81 | 82 | return slides -------------------------------------------------------------------------------- /data_prepare/wsi_core/util_classes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import pdb 5 | import cv2 6 | class Mosaic_Canvas(object): 7 | def __init__(self,patch_size=256, n=100, downscale=4, n_per_row=10, bg_color=(0,0,0), alpha=-1): 8 | self.patch_size = patch_size 9 | self.downscaled_patch_size = int(np.ceil(patch_size/downscale)) 10 | self.n_rows = int(np.ceil(n / n_per_row)) 11 | self.n_cols = n_per_row 12 | w = self.n_cols * self.downscaled_patch_size 13 | h = self.n_rows * self.downscaled_patch_size 14 | if alpha < 0: 15 | canvas = Image.new(size=(w,h), mode="RGB", color=bg_color) 16 | else: 17 | canvas = Image.new(size=(w,h), mode="RGBA", color=bg_color + (int(255 * alpha),)) 18 | 19 | self.canvas = canvas 20 | self.dimensions = np.array([w, h]) 21 | self.reset_coord() 22 | 23 | def reset_coord(self): 24 | self.coord = np.array([0, 0]) 25 | 26 | def increment_coord(self): 27 | #print('current coord: {} x {} / {} x {}'.format(self.coord[0], self.coord[1], self.dimensions[0], self.dimensions[1])) 28 | assert np.all(self.coord<=self.dimensions) 29 | if self.coord[0] + self.downscaled_patch_size <=self.dimensions[0] - self.downscaled_patch_size: 30 | self.coord[0]+=self.downscaled_patch_size 31 | else: 32 | self.coord[0] = 0 33 | self.coord[1]+=self.downscaled_patch_size 34 | 35 | 36 | def save(self, save_path, **kwargs): 37 | self.canvas.save(save_path, **kwargs) 38 | 39 | def paste_patch(self, patch): 40 | assert patch.size[0] == self.patch_size 41 | assert patch.size[1] == self.patch_size 42 | self.canvas.paste(patch.resize(tuple([self.downscaled_patch_size, self.downscaled_patch_size])), tuple(self.coord)) 43 | self.increment_coord() 44 | 45 | def get_painting(self): 46 | return self.canvas 47 | 48 | class Contour_Checking_fn(object): 49 | # Defining __call__ method 50 | def __call__(self, pt): 51 | raise NotImplementedError 52 | 53 | class isInContourV1(Contour_Checking_fn): 54 | def __init__(self, contour): 55 | self.cont = contour 56 | 57 | def __call__(self, pt): 58 | return 1 if cv2.pointPolygonTest(self.cont, pt, 
False) >= 0 else 0 59 | 60 | class isInContourV2(Contour_Checking_fn): 61 | def __init__(self, contour, patch_size): 62 | self.cont = contour 63 | self.patch_size = patch_size 64 | 65 | def __call__(self, pt): 66 | return 1 if cv2.pointPolygonTest(self.cont, (pt[0]+self.patch_size//2, pt[1]+self.patch_size//2), False) >= 0 else 0 67 | 68 | # Easy version of 4pt contour checking function - 1 of 4 points need to be in the contour for test to pass 69 | class isInContourV3_Easy(Contour_Checking_fn): 70 | def __init__(self, contour, patch_size, center_shift=0.5): 71 | self.cont = contour 72 | self.patch_size = patch_size 73 | self.shift = int(patch_size//2*center_shift) 74 | def __call__(self, pt): 75 | center = (pt[0]+self.patch_size//2, pt[1]+self.patch_size//2) 76 | if self.shift > 0: 77 | all_points = [(center[0]-self.shift, center[1]-self.shift), 78 | (center[0]+self.shift, center[1]+self.shift), 79 | (center[0]+self.shift, center[1]-self.shift), 80 | (center[0]-self.shift, center[1]+self.shift) 81 | ] 82 | else: 83 | all_points = [center] 84 | #print(all_points) 85 | for points in all_points: 86 | # need to convert 'numpy.int64' to 'int' 87 | if cv2.pointPolygonTest(self.cont, (int(points[0]),int(points[1])), False) >= 0: 88 | return 1 89 | return 0 90 | 91 | # Hard version of 4pt contour checking function - all 4 points need to be in the contour for test to pass 92 | class isInContourV3_Hard(Contour_Checking_fn): 93 | def __init__(self, contour, patch_size, center_shift=0.5): 94 | self.cont = contour 95 | self.patch_size = patch_size 96 | self.shift = int(patch_size//2*center_shift) 97 | def __call__(self, pt): 98 | center = (pt[0]+self.patch_size//2, pt[1]+self.patch_size//2) 99 | if self.shift > 0: 100 | all_points = [(center[0]-self.shift, center[1]-self.shift), 101 | (center[0]+self.shift, center[1]+self.shift), 102 | (center[0]+self.shift, center[1]-self.shift), 103 | (center[0]-self.shift, center[1]+self.shift) 104 | ] 105 | else: 106 | all_points = [center] 107 | 108 | for points in all_points: 109 | if cv2.pointPolygonTest(self.cont, points, False) < 0: 110 | return 0 111 | return 1 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /docs/ROAM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/docs/ROAM.png -------------------------------------------------------------------------------- /docs/cascade_diagnosis.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/docs/cascade_diagnosis.jpg -------------------------------------------------------------------------------- /docs/environment.yaml: -------------------------------------------------------------------------------- 1 | name: py39 2 | channels: 3 | - https://mirror.tuna.tsinghua.edu.cn/anaconda/pkgs/main 4 | - https://mirror.tuna.tsinghua.edu.cn/anaconda/pkgs/main/linux-64 5 | - https://mirror.tuna.tsinghua.edu.cn/anaconda/pkgs/r 6 | - https://mirror.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2 7 | - https://mirror.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge 8 | - https://mirror.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch 9 | dependencies: 10 | - _libgcc_mutex=0.1=conda_forge 11 | - _openmp_mutex=4.5=2_kmp_llvm 12 | - blas=1.0=mkl 13 | - blessings=1.7=py39h06a4308_1002 14 | - brotlipy=0.7.0=py39h27cfd23_1003 15 | - 
bzip2=1.0.8=h7b6447c_0 16 | - c-ares=1.18.1=h7f8727e_0 17 | - ca-certificates=2023.05.30=h06a4308_0 18 | - cached-property=1.5.2=py_0 19 | - cairo=1.16.0=hf32fb01_1 20 | - certifi=2023.5.7=py39h06a4308_0 21 | - cffi=1.15.1=py39h74dc2b5_0 22 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 23 | - cryptography=37.0.1=py39h9ce1e76_0 24 | - cudatoolkit=11.3.1=h2bc3f7f_2 25 | - dbus=1.13.18=hb2f20db_0 26 | - expat=2.4.4=h295c915_0 27 | - ffmpeg=4.3.2=h37c90e5_3 28 | - fontconfig=2.13.1=h6c09931_0 29 | - freeglut=3.0.0=hf484d3e_5 30 | - freetype=2.11.0=h70c0345_0 31 | - fsspec=2023.4.0=py39h06a4308_0 32 | - gdk-pixbuf=2.42.6=h04a7f16_0 33 | - gettext=0.21.0=hf68c758_0 34 | - giflib=5.2.1=h7b6447c_0 35 | - glib=2.68.4=h9c3ff4c_0 36 | - glib-tools=2.68.4=h9c3ff4c_0 37 | - gmp=6.2.1=h295c915_3 38 | - gnutls=3.6.15=he1e5248_0 39 | - gpustat=0.6.0=pyhd3eb1b0_1 40 | - graphite2=1.3.14=h295c915_1 41 | - gst-plugins-base=1.14.5=h0935bb2_2 42 | - gstreamer=1.18.5=h76c114f_0 43 | - h5py=3.7.0=nompi_py39h63b1161_100 44 | - harfbuzz=3.0.0=h83ec7ef_1 45 | - hdf5=1.12.1=h70be1eb_2 46 | - icu=68.1=h2531618_0 47 | - idna=3.3=pyhd3eb1b0_0 48 | - intel-openmp=2021.4.0=h06a4308_3561 49 | - jasper=2.0.14=hd8c5072_2 50 | - jbig=2.1=hdba287a_0 51 | - jpeg=9e=h7f8727e_0 52 | - krb5=1.19.2=hac12032_0 53 | - lame=3.100=h7b6447c_0 54 | - lcms2=2.12=h3be6417_0 55 | - ld_impl_linux-64=2.38=h1181459_1 56 | - lerc=3.0=h295c915_0 57 | - libblas=3.9.0=12_linux64_mkl 58 | - libcblas=3.9.0=12_linux64_mkl 59 | - libclang=11.1.0=default_ha53f305_1 60 | - libcurl=7.84.0=h91b91d3_0 61 | - libdeflate=1.8=h7f8727e_5 62 | - libedit=3.1.20210910=h7f8727e_0 63 | - libev=4.33=h7f8727e_1 64 | - libevent=2.1.10=h9b69904_4 65 | - libffi=3.3=he6710b0_2 66 | - libgcc-ng=12.1.0=h8d9b700_16 67 | - libgfortran-ng=11.2.0=h00389a5_1 68 | - libgfortran5=11.2.0=h1234567_1 69 | - libglib=2.68.4=h3e27bee_0 70 | - libglu=9.0.0=hf484d3e_1 71 | - libiconv=1.16=h7f8727e_2 72 | - libidn2=2.3.2=h7f8727e_0 73 | - liblapack=3.9.0=12_linux64_mkl 74 | - liblapacke=3.9.0=12_linux64_mkl 75 | - libllvm11=11.1.0=h3826bc1_1 76 | - libnghttp2=1.46.0=hce63b2e_0 77 | - libopencv=4.5.3=py39ha7f30a5_5 78 | - libopus=1.3.1=h7b6447c_0 79 | - libpng=1.6.37=hbc83047_0 80 | - libpq=12.9=h16c4e8d_3 81 | - libprotobuf=3.18.1=h780b84a_0 82 | - libssh2=1.10.0=h8f2d780_0 83 | - libstdcxx-ng=11.2.0=h1234567_1 84 | - libtasn1=4.16.0=h27cfd23_0 85 | - libtiff=4.3.0=h6f004c6_2 86 | - libunistring=0.9.10=h27cfd23_0 87 | - libuuid=1.0.3=h7f8727e_2 88 | - libvpx=1.7.0=h439df22_0 89 | - libwebp=1.2.2=h55f646e_0 90 | - libwebp-base=1.2.2=h7f8727e_0 91 | - libxcb=1.15=h7f8727e_0 92 | - libxkbcommon=1.0.3=he3ba5ed_0 93 | - libxml2=2.9.12=h72842e0_0 94 | - libzlib=1.2.12=h166bdaf_2 95 | - lightning-utilities=0.7.1=py39h06a4308_1 96 | - llvm-openmp=14.0.4=he0ac6c6_0 97 | - lz4-c=1.9.3=h295c915_1 98 | - mkl=2021.4.0=h06a4308_640 99 | - mkl-service=2.4.0=py39h7f8727e_0 100 | - mkl_fft=1.3.1=py39hd3c417c_0 101 | - mkl_random=1.2.2=py39h51133e4_0 102 | - mysql-common=8.0.29=haf5c9bc_1 103 | - mysql-libs=8.0.29=h28c427c_1 104 | - ncurses=6.3=h5eee18b_3 105 | - nettle=3.7.3=hbbd107a_1 106 | - nspr=4.33=h295c915_0 107 | - nss=3.74=h0370c37_0 108 | - numpy=1.23.1=py39h6c91a56_0 109 | - numpy-base=1.23.1=py39ha15fc14_0 110 | - nvidia-ml=7.352.0=pyhd3eb1b0_0 111 | - opencv=4.5.3=py39hf3d152e_5 112 | - openh264=2.1.1=h4ff587b_0 113 | - openjpeg=2.4.0=h3ad879b_0 114 | - openslide=3.4.1=h8137273_1 115 | - openslide-python=1.2.0=py39hb9d737c_0 116 | - openssl=1.1.1u=h7f8727e_0 117 | - pcre=8.45=h295c915_0 118 | - 
pillow=9.2.0=py39hace64e9_1 119 | - pip=22.1.2=py39h06a4308_0 120 | - pixman=0.40.0=h7f8727e_1 121 | - psutil=5.9.0=py39h5eee18b_0 122 | - py-opencv=4.5.3=py39hef51801_5 123 | - pycparser=2.21=pyhd3eb1b0_0 124 | - pyopenssl=22.0.0=pyhd3eb1b0_0 125 | - pysocks=1.7.1=py39h06a4308_0 126 | - python=3.9.0=hdb3f193_2 127 | - python_abi=3.9=2_cp39 128 | - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0 129 | - pytorch-mutex=1.0=cuda 130 | - pyyaml=6.0=py39h5eee18b_1 131 | - qt=5.12.9=h9d6b050_2 132 | - readline=8.1.2=h7f8727e_1 133 | - requests=2.28.1=py39h06a4308_0 134 | - scipy=1.8.1=py39he49c0e8_0 135 | - setuptools=61.2.0=py39h06a4308_0 136 | - six=1.16.0=pyhd3eb1b0_1 137 | - sqlite=3.39.2=h5082296_0 138 | - timm=0.4.12=pyhd8ed1ab_0 139 | - tk=8.6.12=h1ccaba5_0 140 | - torchaudio=0.12.1=py39_cu113 141 | - torchmetrics=0.11.4=py39h2f386ee_1 142 | - torchvision=0.13.1=py39_cu113 143 | - tqdm=4.64.0=py39h06a4308_0 144 | - typing_extensions=4.3.0=py39h06a4308_0 145 | - tzdata=2022a=hda174b7_0 146 | - urllib3=1.26.11=py39h06a4308_0 147 | - wheel=0.37.1=pyhd3eb1b0_0 148 | - x264=1!161.3030=h7f98852_1 149 | - xz=5.2.5=h7f8727e_1 150 | - yaml=0.2.5=h7b6447c_0 151 | - zlib=1.2.12=h7f8727e_2 152 | - zstd=1.5.2=ha4553b6_0 153 | - pip: 154 | - absl-py==1.4.0 155 | - addict==2.4.0 156 | - aiohttp==3.8.4 157 | - aiosignal==1.3.1 158 | - async-timeout==4.0.2 159 | - attrs==23.1.0 160 | - cachetools==5.3.1 161 | - cycler==0.11.0 162 | - einops==0.4.1 163 | - fonttools==4.34.4 164 | - frozenlist==1.3.3 165 | - future==0.18.2 166 | - google-auth==2.21.0 167 | - google-auth-oauthlib==1.0.0 168 | - grpcio==1.56.0 169 | - importlib-metadata==6.7.0 170 | - joblib==1.1.0 171 | - kiwisolver==1.4.4 172 | - markdown==3.4.3 173 | - markupsafe==2.1.3 174 | - matplotlib==3.5.3 175 | - multidict==6.0.4 176 | - nystrom-attention==0.0.11 177 | - oauthlib==3.2.2 178 | - opencv-python==4.8.0.74 179 | - packaging==21.3 180 | - pandas==1.4.3 181 | - protobuf==4.23.3 182 | - pyasn1==0.5.0 183 | - pyasn1-modules==0.3.0 184 | - pydeprecate==0.3.2 185 | - pyparsing==3.0.9 186 | - python-dateutil==2.8.2 187 | - pytorch-lightning==1.6.3 188 | - pytorch-toolbelt==0.6.3 189 | - pytorch-warmup==0.1.1 190 | - pytz==2022.2.1 191 | - requests-oauthlib==1.3.1 192 | - rsa==4.9 193 | - scikit-learn==1.1.2 194 | - seaborn==0.11.2 195 | - spams==2.6.5.4 196 | - tensorboard==2.13.0 197 | - tensorboard-data-server==0.7.1 198 | - tensorboardx==2.6.1 199 | - threadpoolctl==3.1.0 200 | - torchsummary==1.5.1 201 | - werkzeug==2.3.6 202 | - yarl==1.9.2 203 | - zipp==3.15.0 204 | prefix: /home/yinxiaoxu/anaconda3/envs/py39 205 | 206 | -------------------------------------------------------------------------------- /docs/visualization_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteyunjie/ROAM/2c8414c2aa2d43d293bf6d45be37382fcc90530b/docs/visualization_examples.png --------------------------------------------------------------------------------
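After creating the environment (for example with conda env create -f docs/environment.yaml), a quick import-level check such as the following sketch can confirm that the key dependencies listed above are in place; it is illustrative only and not part of the repository.

import torch
import torchvision
import openslide
import h5py
import spams   # required by vahadane.py for sparse stain separation

print('torch', torch.__version__, '| CUDA available:', torch.cuda.is_available())
print('torchvision', torchvision.__version__)
print('openslide (library)', openslide.__library_version__)
print('h5py', h5py.__version__)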