├── README.md
├── cluster.py
├── data_preprocess
│   ├── generate_superpixel.py
│   ├── graph_construction.ipynb
│   └── superpixel_utils.py
├── data_splits
│   ├── train_val_test_split_argo.pkl
│   ├── train_val_test_split_kirc.pkl
│   ├── train_val_test_split_lihc.pkl
│   └── useless
├── images
│   ├── miccai_framework.png
│   └── try
├── label
│   ├── argo_label.pt
│   ├── kirc_label.pt
│   ├── lihc_label.pt
│   └── useless
├── requirements.txt
└── train
    ├── block_utils.py
    ├── run.sh
    ├── superpixel_transformer_n.py
    └── train.py

/README.md:
--------------------------------------------------------------------------------
## Multi-scope Analysis Driven Hierarchical Graph Transformer for Whole Slide Image based Cancer Survival Prediction

![Framework overview](images/miccai_framework.png)
## Installation
Clone the repo:
```bash
git clone https://github.com/Baeksweety/HGTHGT && cd HGTHGT
```
Create a conda environment and activate it:
```bash
conda create -n env python=3.8
conda activate env
pip install -r requirements.txt
```

## Data Preprocess
***generate_superpixel.py*** shows how to generate merged superpixels for whole slide images, and ***graph_construction.ipynb*** shows how to transform a histological image into hierarchical graphs. After data preprocessing is complete, put all hierarchical graphs into one folder, organized as follows:
```bash
PYG_Data
└── Dataset
    ├── pyg_data_1.pt
    ├── pyg_data_2.pt
    :
    └── pyg_data_n.pt
```

## Cluster
***cluster.py*** shows how to generate a fixed number of clusters that are used during training. The output is organized as follows (a minimal sketch for loading the graph and cluster files is given at the end of this README):
```bash
Cluster_Info
└── Dataset
    ├── cluster_info_1.pt
    ├── cluster_info_2.pt
    :
    └── cluster_info_n.pt
```

## Training
First, set the data path, data splits and hyperparameters in ***train.py***. Then, experiments can be run with the following commands:
```bash
cd train
python train.py
# or
bash run.sh
```

## Saved models
We provide 5-fold checkpoints for each dataset, which perform as follows (concordance index, CI):

| Dataset   |  CI   |
| --------- |:-----:|
| CRC       | 0.607 |
| TCGA_LIHC | 0.657 |
| TCGA_KIRC | 0.646 |

## More Info
- Our implementation refers to the following publicly available code:
  - [PyTorch Geometric](https://github.com/pyg-team/pytorch_geometric) -- Fey M, Lenssen J E. Fast graph representation learning with PyTorch Geometric[J]. arXiv preprint arXiv:1903.02428, 2019.
  - [Histocartography](https://github.com/histocartography/histocartography) -- Jaume G, Pati P, Anklin V, et al. HistoCartography: A toolkit for graph analytics in digital pathology[C]//MICCAI Workshop on Computational Pathology. PMLR, 2021: 117-128.
  - [ViT PyTorch](https://github.com/lukemelas/PyTorch-Pretrained-ViT) -- Dosovitskiy A, Beyer L, Kolesnikov A, et al. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale[C]//International Conference on Learning Representations. 2020.
  - [NAGCN](https://github.com/YohnGuan/NAGCN) -- Guan Y, Zhang J, Tian K, et al. Node-aligned graph convolutional network for whole-slide image representation and classification[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022: 18813-18823.
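## Loading the processed data (example)
The snippet below is a minimal sketch, not part of the training pipeline, showing how the files produced above can be inspected. The paths `PYG_Data/Dataset/pyg_data_1.pt` and `Cluster_Info/Dataset/cluster_info_1.pt` are placeholders that follow the layouts in the Data Preprocess and Cluster sections; the attribute names (`x`, `edge_patch`, `edge_superpixel`, `superpixel_attri`, `centroid`) and the per-patch dictionaries (`patch_id`, `cluster`) match what ***graph_construction.ipynb*** and ***cluster.py*** save with `torch.save`.
```python
import torch

# Placeholder paths following the PYG_Data / Cluster_Info layouts above.
graph_path = 'PYG_Data/Dataset/pyg_data_1.pt'
cluster_path = 'Cluster_Info/Dataset/cluster_info_1.pt'

# Each hierarchical graph is a torch_geometric Data object carrying patch
# features (x), intra-superpixel patch edges (edge_patch), superpixel-level
# edges (edge_superpixel), the superpixel id of every patch (superpixel_attri)
# and patch centroids (centroid).
graph = torch.load(graph_path)
print(graph.x.shape)                 # (num_patches, feature_dim)
print(graph.edge_patch.shape)        # (2, num_patch_edges)
print(graph.edge_superpixel.shape)   # (2, num_superpixel_edges)
print(graph.superpixel_attri.shape)  # (num_patches,)

# The cluster file is a list with one entry per patch, holding the patch index
# and the cluster id assigned by the fitted MiniBatchKMeans model.
cluster_info = torch.load(cluster_path)
print(cluster_info[0])  # e.g. {'patch_id': 0, 'cluster': 3}

# Sanity check: both files should describe the same number of patches.
assert len(cluster_info) == graph.x.shape[0]
```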
69 | -------------------------------------------------------------------------------- /cluster.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import h5py 5 | from sklearn.cluster import MiniBatchKMeans, KMeans 6 | import os 7 | import random 8 | import argparse 9 | # 3rd party library 10 | from tqdm import tqdm 11 | import numpy as np 12 | import pandas as pds 13 | from sklearn.cluster import MiniBatchKMeans, KMeans 14 | import torch 15 | 16 | def get_cluster(datas, k=100, num=-1, seed=3): 17 | if seed is not None: 18 | random.seed(seed) 19 | if num > 0: 20 | data_sam = random.sample(datas, num) 21 | else: 22 | data_sam = datas 23 | if seed is not None: 24 | random.seed(seed) 25 | random.shuffle(data_sam) 26 | 27 | dim_N = 0 #record the number of all patches 28 | dim_D = data_sam[0]['shape'][1] 29 | for data in tqdm(data_sam): 30 | dim_N += data['shape'][0] 31 | con_data = np.zeros((dim_N, dim_D), dtype=np.float32) 32 | ind = 0 33 | for data in tqdm(data_sam): 34 | data_path, data_shape = data['slide'], data['shape'] 35 | cur_data = torch.load(data_path) 36 | con_data[ind:ind + data_shape[0], :] = cur_data.numpy() 37 | ind += data_shape[0] 38 | # clusterer = KMeans(n_clusters=k) 39 | clusterer = MiniBatchKMeans(n_clusters=k, batch_size=10000) 40 | clusterer.fit(con_data) 41 | print("cluster done") 42 | return clusterer 43 | 44 | for fold_num in range(5): 45 | cluster_data = torch.load('/data14/yanhe/miccai/codebook/data/argo/codebook_info_fold{}.pt'.format(fold_num)) 46 | # print(len(cluster_data)) 47 | train_data = [] 48 | for data in cluster_data: 49 | train_data.append(data) 50 | print(len(train_data)) 51 | clusterer = get_cluster(train_data, k=16, num=-1, seed=3) #k represents the number of clusters 52 | saved_path='/data14/yanhe/miccai/codebook/patch_cluster/argo/fold{}/'.format(fold_num) 53 | os.makedirs(saved_path,exist_ok=True) 54 | for slidename in os.listdir('/data12/ybj/survival/argo_selected/20x/slides_feat/h5_files'): 55 | if slidename[-3:] == '.h5': 56 | slide_path=os.path.join('/data12/ybj/survival/argo_selected/20x/slides_feat/h5_files',slidename) 57 | # print(slide_path) 58 | f = h5py.File(slide_path) 59 | name, _ = os.path.splitext(slidename) 60 | print(name) 61 | slide_cluster_path=saved_path+name+'.pth' 62 | # print(slide_cluster_path) 63 | length = f['coords'].shape[0] 64 | wsi_patch=[] 65 | for i in range(length): 66 | wsi_patch_info = {} 67 | patch_fea = f['features'][i].reshape(1,-1) 68 | cluster_class = clusterer.predict(patch_fea)[0] 69 | wsi_patch_info['patch_id']=i 70 | wsi_patch_info['cluster']=cluster_class 71 | wsi_patch.append(wsi_patch_info) 72 | torch.save(wsi_patch,slide_cluster_path) 73 | -------------------------------------------------------------------------------- /data_preprocess/generate_superpixel.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 3 | import torch 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | import numpy as np 7 | import pandas as pd 8 | import h5py 9 | from sklearn.cluster import MiniBatchKMeans, KMeans 10 | import random 11 | from PIL import Image 12 | from dgl.data.utils import save_graphs 13 | 14 | from histocartography.utils import download_example_data 15 | from histocartography.preprocessing import ( 16 | ColorMergedSuperpixelExtractor, 17 | DeepFeatureExtractor 18 | ) 19 | 20 | from histocartography.visualization import 
OverlayGraphVisualization 21 | from superpixel_utils import RAGGraphBuilder,MyColorMergedSuperpixelExtractor 22 | import argparse 23 | from skimage.measure import regionprops 24 | import joblib 25 | import cv2 26 | 27 | 28 | #remove background 29 | def get_node_centroids(instance_map: np.ndarray,raw_image:np.ndarray): 30 | # print(instance_map) 31 | regions = regionprops(instance_map) 32 | mask_value=1 33 | # centroids = np.empty((len(regions), 2)) 34 | for i, region in enumerate(regions): 35 | center_y, center_x = region.centroid # (y, x) 36 | # print(i) 37 | # print(region.coords) 38 | center_x = int(round(center_x)) 39 | center_y = int(round(center_y)) 40 | if sum(raw_image[center_y,center_x])== 0: 41 | for index,couple in enumerate(region.coords): 42 | y,x = couple 43 | # print(y,x) 44 | instance_map[y,x]=0 45 | else: 46 | for index,couple in enumerate(region.coords): 47 | y,x = couple 48 | # print(y,x) 49 | instance_map[y,x]=mask_value 50 | mask_value+=1 51 | 52 | # centroids[i, 0] = center_x 53 | # centroids[i, 1] = center_y 54 | # print(instance_map) 55 | return instance_map 56 | 57 | def generate_superpixel(image_path,downsampling_factor): 58 | """ 59 | Generate a tissue graph for all the images in image path dir. 60 | """ 61 | 62 | # 1. get image path 63 | # image_fnames = glob(os.path.join(image_path, '*.png')) 64 | # image_fnames = img_path 65 | # 2. define superpixel extractor. Here, we query 50 SLIC superpixels, 66 | # but a superpixel size (in #pixels) can be provided as well in the case 67 | # where image size vary from one sample to another. 68 | superpixel_detector = ColorMergedSuperpixelExtractor( 69 | nr_superpixels=args.nr_superpixels, 70 | compactness=10, 71 | blur_kernel_size=1, 72 | threshold=1, 73 | downsampling_factor=downsampling_factor, 74 | connectivity = 2, 75 | ) 76 | 77 | # 3. define feature extractor: extract patches of 144x144 pixels 78 | # resized to 224 to match resnet input size. If the superpixel is larger 79 | # than 144x144, several patches are extracted and patch embeddings are averaged. 80 | # Everything is handled internally. Please refer to the implementation for 81 | # details. 82 | feature_extractor = DeepFeatureExtractor( 83 | architecture='resnet34', 84 | patch_size=144, 85 | resize_size=224 86 | ) 87 | 88 | # 4. define graph builder 89 | tissue_graph_builder = RAGGraphBuilder(add_loc_feats=True) 90 | 91 | # 5. 
define graph visualizer 92 | visualizer = OverlayGraphVisualization() 93 | 94 | _, image_name = os.path.split(image_path) 95 | image = Image.open(image_path) 96 | image=np.array(image) 97 | fname = image_name[:-4]+'.png' 98 | print(image.shape) 99 | 100 | superpixels, _ = superpixel_detector.process(image) 101 | superpixels = get_node_centroids(superpixels,image) 102 | return superpixels 103 | 104 | 105 | def modify_superpixels(superpixel_patchnum,instance_map): 106 | regions = regionprops(instance_map) 107 | mask_value=1 108 | for i, region in enumerate(regions): 109 | sup_value = region.label 110 | if (superpixel_patchnum[sup_value]==0) or (superpixel_patchnum[sup_value]==1): 111 | print('true!') 112 | for index,couple in enumerate(region.coords): 113 | y,x = couple 114 | instance_map[y,x]=0 115 | else: 116 | for index,couple in enumerate(region.coords): 117 | y,x = couple 118 | instance_map[y,x]=mask_value 119 | mask_value+=1 120 | return instance_map 121 | 122 | 123 | 124 | def get_max_value(data_matrix): 125 | new_data=[] 126 | for i in range(len(data_matrix)): 127 | new_data.append(max(data_matrix[i])) 128 | return max(new_data) 129 | 130 | 131 | 132 | def patch_judge_40x(stitch_path,saved_path): 133 | _, image_name = os.path.split(stitch_path) 134 | name = image_name[:-4] 135 | superpixels=generate_superpixel(stitch_path,downsampling_factor=4) 136 | print(superpixels.shape) 137 | h,w=superpixels.shape 138 | slide_info_path = '/data12/ybj/survival/CRC/40x/slides_feat/h5_files/'+name+'.h5' 139 | print(slide_info_path) 140 | f = h5py.File(slide_info_path) 141 | new_slide_info=[] 142 | new_slide_info_path=saved_path+name+'.pth' 143 | length = f['coords'].shape[0] 144 | for i in range(length): 145 | patch_info={} 146 | patch_info['patch_id']=i 147 | patch_info['coords']=f['coords'][i] 148 | patch_info['features']=f['features'][i] 149 | patch_coords=f['coords'][i] 150 | # x,y=slide_info[i]['coords'] 151 | x,y=patch_coords 152 | x=x+512 153 | y=y+512 154 | x=int(x/16) 155 | y=int(y/16) 156 | if (y=h)&(x=w): 166 | x_n,y_n=f['coords'][i] 167 | x_n=x_n+(w*16-x_n)/2 168 | y_n=y_n+512 169 | x_n=int(x_n/16) 170 | y_n=int(y_n/16) 171 | patch_info['superpixel']=superpixels[y_n][x_n] 172 | else: 173 | x_n,y_n=f['coords'][i] 174 | x_n=x_n+(w*16-x_n)/2 175 | y_n=y_n+(h*16-y_n)/2 176 | x_n=int(x_n/16) 177 | y_n=int(y_n/16) 178 | patch_info['superpixel']=superpixels[y_n][x_n] 179 | new_slide_info.append(patch_info) 180 | superpixel_patchnum={} 181 | max_superpixel = get_max_value(superpixels) 182 | # print(max(superpixel)) 183 | print(max_superpixel) 184 | for sup_value in range(1,max_superpixel+1): 185 | # print(sup_value) 186 | coords_s=[] 187 | features_s=[] 188 | for patch_id in range(len(new_slide_info)): 189 | if new_slide_info[patch_id]['superpixel']==sup_value: 190 | coords_s.append(new_slide_info[patch_id]['coords']) 191 | # features_s.append(new_slide_info[patch_id]['features']) 192 | coords_np=np.array(coords_s) 193 | # features_np=np.array(features_s) 194 | # print(coords_np.shape,features_np.shape) 195 | patch_number=coords_np.shape[0] 196 | superpixel_patchnum[sup_value] = patch_number 197 | superpixels = modify_superpixels(superpixel_patchnum,superpixels) 198 | wsi_info = [] 199 | for i in range(length): 200 | patch_info={} 201 | patch_info['patch_id']=i 202 | patch_info['coords']=f['coords'][i] 203 | patch_info['features']=f['features'][i] 204 | patch_coords=f['coords'][i] 205 | # x,y=slide_info[i]['coords'] 206 | x,y=patch_coords 207 | x=x+512 208 | y=y+512 209 | x=int(x/16) 210 | 
y=int(y/16) 211 | if (y=h)&(x=w): 221 | x_n,y_n=f['coords'][i] 222 | x_n=x_n+(w*16-x_n)/2 223 | y_n=y_n+512 224 | x_n=int(x_n/16) 225 | y_n=int(y_n/16) 226 | patch_info['superpixel']=superpixels[y_n][x_n] 227 | else: 228 | x_n,y_n=f['coords'][i] 229 | x_n=x_n+(w*16-x_n)/2 230 | y_n=y_n+(h*16-y_n)/2 231 | x_n=int(x_n/16) 232 | y_n=int(y_n/16) 233 | patch_info['superpixel']=superpixels[y_n][x_n] 234 | wsi_info.append(patch_info) 235 | torch.save(wsi_info,new_slide_info_path) 236 | return superpixels 237 | 238 | 239 | def patch_judge_20x(stitch_path,saved_path): 240 | _, image_name = os.path.split(stitch_path) 241 | name = image_name[:-4] 242 | superpixels=generate_superpixel(stitch_path,downsampling_factor=2) 243 | print(superpixels.shape) 244 | h,w=superpixels.shape 245 | slide_info_path = '/data12/ybj/survival/CRC/20x/slides_feat/h5_files/'+name+'.h5' 246 | f = h5py.File(slide_info_path) 247 | new_slide_info=[] 248 | new_slide_info_path=saved_path+name+'.pth' 249 | length = f['coords'].shape[0] 250 | for i in range(length): 251 | patch_info={} 252 | patch_info['patch_id']=i 253 | patch_info['coords']=f['coords'][i] 254 | patch_info['features']=f['features'][i] 255 | patch_coords=f['coords'][i] 256 | x,y=patch_coords 257 | x=x+256 258 | y=y+256 259 | x=int(x/16) 260 | y=int(y/16) 261 | if (y=h)&(x=w): 271 | x_n,y_n=f['coords'][i] 272 | x_n=x_n+(w*16-x_n)/2 273 | y_n=y_n+256 274 | x_n=int(x_n/16) 275 | y_n=int(y_n/16) 276 | patch_info['superpixel']=superpixels[y_n][x_n] 277 | else: 278 | x_n,y_n=f['coords'][i] 279 | x_n=x_n+(w*16-x_n)/2 280 | y_n=y_n+(h*16-y_n)/2 281 | x_n=int(x_n/16) 282 | y_n=int(y_n/16) 283 | patch_info['superpixel']=superpixels[y_n][x_n] 284 | new_slide_info.append(patch_info) 285 | superpixel_patchnum={} 286 | max_superpixel = get_max_value(superpixels) 287 | # print(max(superpixel)) 288 | print(max_superpixel) 289 | for sup_value in range(1,max_superpixel+1): 290 | # print(sup_value) 291 | coords_s=[] 292 | features_s=[] 293 | for patch_id in range(len(new_slide_info)): 294 | if new_slide_info[patch_id]['superpixel']==sup_value: 295 | coords_s.append(new_slide_info[patch_id]['coords']) 296 | coords_np=np.array(coords_s) 297 | patch_number=coords_np.shape[0] 298 | superpixel_patchnum[sup_value] = patch_number 299 | superpixels = modify_superpixels(superpixel_patchnum,superpixels) 300 | wsi_info = [] 301 | for i in range(length): 302 | patch_info={} 303 | patch_info['patch_id']=i 304 | patch_info['coords']=f['coords'][i] 305 | patch_info['features']=f['features'][i] 306 | patch_coords=f['coords'][i] 307 | x,y=patch_coords 308 | x=x+256 309 | y=y+256 310 | x=int(x/16) 311 | y=int(y/16) 312 | if (y=h)&(x=w): 322 | x_n,y_n=f['coords'][i] 323 | x_n=x_n+(w*16-x_n)/2 324 | y_n=y_n+256 325 | x_n=int(x_n/16) 326 | y_n=int(y_n/16) 327 | patch_info['superpixel']=superpixels[y_n][x_n] 328 | else: 329 | x_n,y_n=f['coords'][i] 330 | x_n=x_n+(w*16-x_n)/2 331 | y_n=y_n+(h*16-y_n)/2 332 | x_n=int(x_n/16) 333 | y_n=int(y_n/16) 334 | patch_info['superpixel']=superpixels[y_n][x_n] 335 | wsi_info.append(patch_info) 336 | torch.save(wsi_info,new_slide_info_path) 337 | return superpixels 338 | 339 | def generate_tissue_graph(slide_list_path,image_path,saved_path,graph_file_saved_path,vis_saved_path): 340 | """ 341 | Generate a tissue graph for all the images in image path dir. 342 | """ 343 | feature_extractor = DeepFeatureExtractor( 344 | architecture='resnet34', 345 | patch_size=144, 346 | resize_size=224 347 | ) 348 | 349 | # 4. 
define graph builder 350 | tissue_graph_builder = RAGGraphBuilder(add_loc_feats=True) 351 | 352 | # 5. define graph visualizer 353 | visualizer = OverlayGraphVisualization() 354 | # print(fname) 355 | # b. extract superpixels 356 | slide_list = joblib.load(slide_list_path) 357 | # slide_x, _ = os.path.splitext(slide_list_path) 358 | _, slide_x = os.path.split(slide_list_path) 359 | # print(slide_list_path) 360 | # print(slide_x) 361 | slide_x = slide_x.split('_')[1] 362 | print(slide_x) 363 | # for slidename in os.listdir(image_path): 364 | for index in range(len(slide_list)): 365 | slidename = slide_list[index] 366 | name, _ = os.path.splitext(slidename) 367 | # print(name) 368 | stitch_name = name+'.jpg' 369 | stitch_path = os.path.join(image_path,stitch_name) 370 | print(stitch_path) 371 | if slide_x == '20x': 372 | superpixels = patch_judge_20x(stitch_path,saved_path) 373 | elif slide_x == '40x': 374 | superpixels = patch_judge_40x(stitch_path,saved_path) 375 | # torch.save(superpixels,os.path.join('/data13/yanhe/miccai/super_pixel/superpixel_array/tcga_lihc_200superpixel',image_name[:-4]+'.pt')) 376 | # print(image_name) 377 | print(superpixels.shape) 378 | # print(superpixels) 379 | # c. extract deep features 380 | image = Image.open(stitch_path) 381 | features = feature_extractor.process(image, superpixels) 382 | # print(features.shape) 383 | # d. build a Region Adjacency Graph (RAG) 384 | # graph = tissue_graph_builder.process(image, superpixels, features) 385 | graph = tissue_graph_builder.process(superpixels, features) 386 | # print(graph) 387 | # e. save the graph 388 | torch.save(graph,os.path.join(graph_file_saved_path,slidename[:-4]+'.pt')) 389 | # f. visualize and save the graph 390 | canvas = visualizer.process(image, graph, instance_map=superpixels) 391 | canvas.save(os.path.join(vis_saved_path,slidename[:-4]+'.png')) 392 | 393 | def main(args): 394 | saved_path = args.saved_path 395 | graph_file_saved_path = args.graph_file_saved_path 396 | vis_saved_path = args.vis_saved_path 397 | os.makedirs(saved_path,exist_ok=True) 398 | os.makedirs(graph_file_saved_path,exist_ok=True) 399 | os.makedirs(vis_saved_path,exist_ok=True) 400 | #20x 401 | slide_20x_path = args.slide_20x_path 402 | stitch_20x_path = args.stitch_20x_path 403 | generate_tissue_graph(slide_20x_path,stitch_20x_path,saved_path,graph_file_saved_path,vis_saved_path) 404 | 405 | #40x 406 | slide_40x_path = args.slide_40x_path 407 | stitch_40x_path = args.stitch_40x_path 408 | generate_tissue_graph(slide_40x_path,stitch_40x_path,saved_path,graph_file_saved_path,vis_saved_path) 409 | 410 | 411 | 412 | def get_params(): 413 | parser = argparse.ArgumentParser(description='superpixel_generate') 414 | 415 | parser.add_argument('--slide_40x_path', type=str, default='/data12/yanhe/miccai/data/tcga_crc/slide_40x_list.pkl') 416 | parser.add_argument('--slide_20x_path', type=str, default='/data12/yanhe/miccai/data/tcga_crc/slide_20x_list.pkl') 417 | parser.add_argument('--stitch_20x_path', type=str, default='/data12/ybj/survival/CRC/20x/stitches') 418 | parser.add_argument('--stitch_40x_path', type=str, default='/data12/ybj/survival/CRC/40x/stitches') 419 | parser.add_argument('--saved_path', type=str, default='/data11/yanhe/miccai/super_pixel/slide_superpixel/tcga_crc/superpixel_num_300/') 420 | parser.add_argument('--vis_saved_path', type=str, default='/data11/yanhe/miccai/super_pixel/vis/tcga_crc/superpixel_num_300') 421 | parser.add_argument('--graph_file_saved_path', type=str, 
default='/data11/yanhe/miccai/super_pixel/graph_file/tcga_crc/superpixel_num_300') 422 | parser.add_argument('--nr_superpixels', type=int, default=300) 423 | 424 | args, _ = parser.parse_known_args() 425 | return args 426 | 427 | 428 | if __name__ == '__main__': 429 | try: 430 | args=get_params() 431 | main(args) 432 | except Exception as exception: 433 | # logger.exception(exception) 434 | raise 435 | -------------------------------------------------------------------------------- /data_preprocess/graph_construction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import nmslib\n", 10 | "import networkx as nx\n", 11 | "import os\n", 12 | "import torch\n", 13 | "import numpy as np\n", 14 | "from tqdm import tqdm\n", 15 | "import h5py" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "class Hnsw:\n", 25 | " def __init__(self, space='cosinesimil', index_params=None,\n", 26 | " query_params=None, print_progress=True):\n", 27 | " self.space = space\n", 28 | " self.index_params = index_params\n", 29 | " self.query_params = query_params\n", 30 | " self.print_progress = print_progress\n", 31 | "\n", 32 | " def fit(self, X):\n", 33 | " index_params = self.index_params\n", 34 | " if index_params is None:\n", 35 | " index_params = {'M': 16, 'post': 0, 'efConstruction': 400}\n", 36 | "\n", 37 | " query_params = self.query_params\n", 38 | " if query_params is None:\n", 39 | " query_params = {'ef': 90}\n", 40 | "\n", 41 | " # this is the actual nmslib part, hopefully the syntax should\n", 42 | " # be pretty readable, the documentation also has a more verbiage\n", 43 | " # introduction: https://nmslib.github.io/nmslib/quickstart.html\n", 44 | " index = nmslib.init(space=self.space, method='hnsw')\n", 45 | " index.addDataPointBatch(X)\n", 46 | " index.createIndex(index_params, print_progress=self.print_progress)\n", 47 | " index.setQueryTimeParams(query_params)\n", 48 | "\n", 49 | " self.index_ = index\n", 50 | " self.index_params_ = index_params\n", 51 | " self.query_params_ = query_params\n", 52 | " return self\n", 53 | "\n", 54 | " def query(self, vector, topn):\n", 55 | " # the knnQuery returns indices and corresponding distance\n", 56 | " # we will throw the distance away for now\n", 57 | " indices, dist = self.index_.knnQuery(vector, k=topn)\n", 58 | " return indices" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "from torch_geometric.data import Data as geomData\n", 68 | "from itertools import chain\n", 69 | "\n", 70 | "def pt2graph(wsi_h5, slidename,radius=9):\n", 71 | " from torch_geometric.data import Data as geomData\n", 72 | " from itertools import chain\n", 73 | " coords, features = np.array(wsi_h5['coords']), np.array(wsi_h5['features'])\n", 74 | " assert coords.shape[0] == features.shape[0]\n", 75 | " num_patches = coords.shape[0]\n", 76 | "# print(num_patches)\n", 77 | " superpixel_info_path = os.path.join('/data14/yanhe/miccai/super_pixel/slide_superpixel/argo_new/superpixel_num_500',slidename+'.pth')\n", 78 | " superpixel_info = torch.load(superpixel_info_path)\n", 79 | " superpixel_attri=[]\n", 80 | "# print(len(superpixel_info))\n", 81 | " for index in range(len(superpixel_info)):\n", 82 | " superpixel = 
superpixel_info[index]['superpixel']\n", 83 | " superpixel_attri.append(superpixel)\n", 84 | " superpixel_attri = torch.LongTensor(superpixel_attri) \n", 85 | "# print(superpixel_attri)\n", 86 | " \n", 87 | " inter_graph_path = os.path.join('/data14/yanhe/miccai/super_pixel/graph_file/argo/superpixel_num_500',slidename+'.pt')\n", 88 | " g = torch.load(inter_graph_path)\n", 89 | " for index in range(g.ndata['centroid'].shape[0]):\n", 90 | " edge_index = torch.Tensor()\n", 91 | " edge_index = g.edges()[0].unsqueeze(0)\n", 92 | " edge_index = torch.cat((edge_index,g.edges()[1].unsqueeze(0)),dim=0)\n", 93 | " \n", 94 | " model = Hnsw(space='l2')\n", 95 | " model.fit(coords)\n", 96 | " a = np.repeat(range(num_patches), radius-1)\n", 97 | " b = np.fromiter(chain(*[model.query(coords[v_idx], topn=radius)[1:] for v_idx in range(num_patches)]),dtype=int)\n", 98 | " edge_spatial = torch.Tensor(np.stack([a,b])).type(torch.LongTensor)\n", 99 | " superpixel_edge = superpixel_attri[edge_spatial]\n", 100 | "\n", 101 | " edge_mask = (superpixel_edge[0,:] == superpixel_edge[1,:])\n", 102 | "\n", 103 | " remain_edge_index = edge_spatial[:,edge_mask]\n", 104 | " G = geomData(x = torch.Tensor(features),\n", 105 | " edge_patch = remain_edge_index,\n", 106 | " edge_superpixel = edge_index,\n", 107 | " superpixel_attri = superpixel_attri,\n", 108 | " centroid = torch.Tensor(coords))\n", 109 | " return G" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "def createDir_h5toPyG(h5_path, save_path):\n", 119 | " pbar = tqdm(os.listdir(h5_path))\n", 120 | " for h5_fname in pbar:\n", 121 | " pbar.set_description('%s - Creating Graph' % (h5_fname[:-3]))\n", 122 | "\n", 123 | " try:\n", 124 | " wsi_h5 = h5py.File(os.path.join(h5_path, h5_fname), \"r\")\n", 125 | " slidename = h5_fname[:-3]\n", 126 | " if slidename != 'ZS6Y1A01554_HE520' and slidename != 'ZS6Y1A03883_HE208' and slidename != 'ZS6Y1A07240_HE400' and slidename != 'ZS6Y1A08318_HE155':\n", 127 | " G = pt2graph(wsi_h5,slidename)\n", 128 | " torch.save(G, os.path.join(save_path, h5_fname[:-3]+'.pt'))\n", 129 | " wsi_h5.close()\n", 130 | " except OSError:\n", 131 | " pbar.set_description('%s - Broken H5' % (h5_fname[:12]))\n", 132 | " print(h5_fname, 'Broken')" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "h5_path = '/data12/ybj/survival/argo_selected/20x/slides_feat/h5_files'\n", 142 | "save_path = '/data14/yanhe/miccai/graph_file/argo/superpixel_num_500/'\n", 143 | "os.makedirs(save_path, exist_ok=True)\n", 144 | "createDir_h5toPyG(h5_path, save_path)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [], 187 | 
"source": [] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [] 202 | } 203 | ], 204 | "metadata": { 205 | "kernelspec": { 206 | "display_name": "pylcx", 207 | "language": "python", 208 | "name": "pylcx" 209 | }, 210 | "language_info": { 211 | "codemirror_mode": { 212 | "name": "ipython", 213 | "version": 3 214 | }, 215 | "file_extension": ".py", 216 | "mimetype": "text/x-python", 217 | "name": "python", 218 | "nbconvert_exporter": "python", 219 | "pygments_lexer": "ipython3", 220 | "version": "3.8.5" 221 | } 222 | }, 223 | "nbformat": 4, 224 | "nbformat_minor": 4 225 | } 226 | -------------------------------------------------------------------------------- /data_preprocess/superpixel_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 3 | import torch 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | import numpy as np 7 | import pandas as pd 8 | import h5py 9 | from sklearn.cluster import MiniBatchKMeans, KMeans 10 | import random 11 | from PIL import Image 12 | from dgl.data.utils import save_graphs 13 | 14 | from histocartography.utils import download_example_data 15 | from histocartography.preprocessing import ( 16 | ColorMergedSuperpixelExtractor, 17 | DeepFeatureExtractor 18 | ) 19 | from histocartography.visualization import OverlayGraphVisualization 20 | 21 | from skimage.measure import regionprops 22 | import joblib 23 | import cv2 24 | 25 | import logging 26 | import multiprocessing 27 | import os 28 | import sys 29 | from abc import ABC, abstractmethod 30 | from copy import deepcopy 31 | from functools import partial 32 | from pathlib import Path 33 | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union 34 | 35 | import h5py 36 | import pandas as pd 37 | from tqdm.auto import tqdm 38 | 39 | import logging 40 | import multiprocessing 41 | import os 42 | import sys 43 | from abc import ABC, abstractmethod 44 | from copy import deepcopy 45 | from functools import partial 46 | from pathlib import Path 47 | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union 48 | 49 | import h5py 50 | import pandas as pd 51 | from tqdm.auto import tqdm 52 | 53 | class PipelineStep(ABC): 54 | """Base pipelines step""" 55 | 56 | def __init__( 57 | self, 58 | save_path: Union[None, str, Path] = None, 59 | precompute: bool = True, 60 | link_path: Union[None, str, Path] = None, 61 | precompute_path: Union[None, str, Path] = None, 62 | ) -> None: 63 | """Abstract class that helps with saving and loading precomputed results 64 | Args: 65 | save_path (Union[None, str, Path], optional): Base path to save results to. 66 | When set to None, the results are not saved to disk. Defaults to None. 67 | precompute (bool, optional): Whether to perform the precomputation necessary 68 | for the step. Defaults to True. 69 | link_path (Union[None, str, Path], optional): Path to link the output directory 70 | to. When None, no link is created. Only supported when save_path is not None. 71 | Defaults to None. 72 | precompute_path (Union[None, str, Path], optional): Path to save the output of 73 | the precomputation to. If not specified it defaults to the output directory 74 | of the step when save_path is not None. Defaults to None. 
75 | """ 76 | assert ( 77 | save_path is not None or link_path is None 78 | ), "link_path only supported when save_path is not None" 79 | 80 | name = self.__repr__() 81 | self.save_path = save_path 82 | if self.save_path is not None: 83 | self.output_dir = Path(self.save_path) / name 84 | self.output_key = "default_key" 85 | self._mkdir() 86 | if precompute_path is None: 87 | precompute_path = save_path 88 | 89 | if precompute: 90 | self.precompute( 91 | link_path=link_path, 92 | precompute_path=precompute_path) 93 | 94 | def __repr__(self) -> str: 95 | """Representation of a pipeline step. 96 | Returns: 97 | str: Representation of a pipeline step. 98 | """ 99 | variables = ",".join( 100 | [f"{k}={v}" for k, v in sorted(self.__dict__.items())]) 101 | return ( 102 | f"{self.__class__.__name__}({variables})".replace(" ", "") 103 | .replace('"', "") 104 | .replace("'", "") 105 | .replace("..", "") 106 | .replace("/", "_") 107 | ) 108 | 109 | def _mkdir(self) -> None: 110 | """Create path to output files""" 111 | assert ( 112 | self.save_path is not None 113 | ), "Can only create directory if base_path was not None when constructing the object" 114 | if not self.output_dir.exists(): 115 | self.output_dir.mkdir() 116 | 117 | def _link_to_path(self, link_directory: Union[None, str, Path]) -> None: 118 | """Links the output directory to the given directory. 119 | Args: 120 | link_directory (Union[None, str, Path]): Directory to link to 121 | """ 122 | if link_directory is None or Path( 123 | link_directory).parent.resolve() == Path(self.output_dir): 124 | logging.info("Link to self skipped") 125 | return 126 | assert ( 127 | self.save_path is not None 128 | ), f"Linking only supported when saving is enabled, i.e. when save_path is passed in the constructor." 129 | if os.path.islink(link_directory): 130 | if os.path.exists(link_directory): 131 | logging.info("Link already exists: overwriting...") 132 | os.remove(link_directory) 133 | else: 134 | logging.critical( 135 | "Link exists, but points nowhere. Ignoring...") 136 | return 137 | elif os.path.exists(link_directory): 138 | os.remove(link_directory) 139 | os.symlink(self.output_dir, link_directory, target_is_directory=True) 140 | 141 | def precompute( 142 | self, 143 | link_path: Union[None, str, Path] = None, 144 | precompute_path: Union[None, str, Path] = None, 145 | ) -> None: 146 | """Precompute all necessary information for this step 147 | Args: 148 | link_path (Union[None, str, Path], optional): Path to link the output to. Defaults to None. 149 | precompute_path (Union[None, str, Path], optional): Path to load/save the precomputation outputs. Defaults to None. 150 | """ 151 | pass 152 | 153 | def process( 154 | self, *args: Any, output_name: Optional[str] = None, **kwargs: Any 155 | ) -> Any: 156 | """Main process function of the step and outputs the result. Try to saves the output when output_name is passed. 157 | Args: 158 | output_name (Optional[str], optional): Unique identifier of the passed datapoint. Defaults to None. 
159 | Returns: 160 | Any: Result of the pipeline step 161 | """ 162 | if output_name is not None and self.save_path is not None: 163 | return self._process_and_save( 164 | *args, output_name=output_name, **kwargs) 165 | else: 166 | return self._process(*args, **kwargs) 167 | 168 | @abstractmethod 169 | def _process(self, *args: Any, **kwargs: Any) -> Any: 170 | """Abstract method that performs the computation of the pipeline step 171 | Returns: 172 | Any: Result of the pipeline step 173 | """ 174 | 175 | def _get_outputs(self, input_file: h5py.File) -> Union[Any, Tuple]: 176 | """Extracts the step output from a given h5 file 177 | Args: 178 | input_file (h5py.File): File to load from 179 | Returns: 180 | Union[Any, Tuple]: Previously computed output of the step 181 | """ 182 | outputs = list() 183 | nr_outputs = len(input_file.keys()) 184 | 185 | # Legacy, remove at some point 186 | if nr_outputs == 1 and self.output_key in input_file.keys(): 187 | return tuple([input_file[self.output_key][()]]) 188 | 189 | for i in range(nr_outputs): 190 | outputs.append(input_file[f"{self.output_key}_{i}"][()]) 191 | if len(outputs) == 1: 192 | return outputs[0] 193 | else: 194 | return tuple(outputs) 195 | 196 | def _set_outputs(self, output_file: h5py.File, 197 | outputs: Union[Tuple, Any]) -> None: 198 | """Save the step output to a given h5 file 199 | Args: 200 | output_file (h5py.File): File to write to 201 | outputs (Union[Tuple, Any]): Computed step output 202 | """ 203 | if not isinstance(outputs, tuple): 204 | outputs = tuple([outputs]) 205 | for i, output in enumerate(outputs): 206 | output_file.create_dataset( 207 | f"{self.output_key}_{i}", 208 | data=output, 209 | compression="gzip", 210 | compression_opts=9, 211 | ) 212 | 213 | def _process_and_save( 214 | self, *args: Any, output_name: str, **kwargs: Any 215 | ) -> Any: 216 | """Process and save in the provided path as as .h5 file 217 | Args: 218 | output_name (str): Unique identifier of the the passed datapoint 219 | Raises: 220 | read_error (OSError): When the unable to read to self.output_dir/output_name.h5 221 | write_error (OSError): When the unable to write to self.output_dir/output_name.h5 222 | Returns: 223 | Any: Result of the pipeline step 224 | """ 225 | assert ( 226 | self.save_path is not None 227 | ), "Can only save intermediate output if base_path was not None when constructing the object" 228 | output_path = self.output_dir / f"{output_name}.h5" 229 | if output_path.exists(): 230 | logging.info( 231 | f"{self.__class__.__name__}: Output of {output_name} already exists, using it instead of recomputing" 232 | ) 233 | try: 234 | with h5py.File(output_path, "r") as input_file: 235 | output = self._get_outputs(input_file=input_file) 236 | except OSError as read_error: 237 | print(f"\n\nCould not read from {output_path}!\n\n") 238 | raise read_error 239 | else: 240 | output = self._process(*args, **kwargs) 241 | try: 242 | with h5py.File(output_path, "w") as output_file: 243 | self._set_outputs(output_file=output_file, outputs=output) 244 | except OSError as write_error: 245 | print(f"\n\nCould not write to {output_path}!\n\n") 246 | raise write_error 247 | return output 248 | 249 | def fast_histogram(input_array: np.ndarray, nr_values: int) -> np.ndarray: 250 | """Calculates a histogram of a matrix of the values from 0 up to (excluding) nr_values 251 | Args: 252 | x (np.array): Input tensor 253 | nr_values (int): Possible values. From 0 up to (exclusing) nr_values. 
254 | Returns: 255 | np.array: Output tensor 256 | """ 257 | output_array = np.empty(nr_values, dtype=int) 258 | for i in range(nr_values): 259 | output_array[i] = (input_array == i).sum() 260 | return output_array 261 | 262 | 263 | def load_image(image_path: Path) -> np.ndarray: 264 | """Loads an image from a given path and returns it as a numpy array 265 | Args: 266 | image_path (Path): Path of the image 267 | Returns: 268 | np.ndarray: Array representation of the image 269 | """ 270 | assert image_path.exists() 271 | try: 272 | with Image.open(image_path) as img: 273 | image = np.array(img) 274 | except OSError as e: 275 | logging.critical("Could not open %s", image_path) 276 | raise OSError(e) 277 | return image 278 | 279 | """This module handles all the graph building""" 280 | 281 | import logging 282 | from abc import abstractmethod 283 | from pathlib import Path 284 | from typing import Any, Optional, Tuple, Union 285 | 286 | import cv2 287 | import dgl 288 | import networkx as nx 289 | import numpy as np 290 | import pandas as pd 291 | import torch 292 | from dgl.data.utils import load_graphs, save_graphs 293 | from skimage.measure import regionprops 294 | from sklearn.neighbors import kneighbors_graph 295 | 296 | # from ..pipeline import PipelineStep 297 | # from .utils import fast_histogram 298 | 299 | 300 | 301 | LABEL = "label" 302 | CENTROID = "centroid" 303 | FEATURES = "feat" 304 | 305 | 306 | def two_hop_neighborhood(graph: dgl.DGLGraph) -> dgl.DGLGraph: 307 | """Increases the connectivity of a given graph by an additional hop 308 | Args: 309 | graph (dgl.DGLGraph): Input graph 310 | Returns: 311 | dgl.DGLGraph: Output graph 312 | """ 313 | A = graph.adjacency_matrix().to_dense() 314 | A_tilde = (1.0 * ((A + A.matmul(A)) >= 1)) - torch.eye(A.shape[0]) 315 | ngraph = nx.convert_matrix.from_numpy_matrix(A_tilde.numpy()) 316 | new_graph = dgl.DGLGraph() 317 | new_graph.from_networkx(ngraph) 318 | for k, v in graph.ndata.items(): 319 | new_graph.ndata[k] = v 320 | for k, v in graph.edata.items(): 321 | new_graph.edata[k] = v 322 | return new_graph 323 | 324 | 325 | class BaseGraphBuilder(PipelineStep): 326 | """ 327 | Base interface class for graph building. 328 | """ 329 | 330 | def __init__( 331 | self, 332 | nr_annotation_classes: int = 5, 333 | annotation_background_class: Optional[int] = None, 334 | add_loc_feats: bool = False, 335 | **kwargs: Any 336 | ) -> None: 337 | """ 338 | Base Graph Builder constructor. 339 | Args: 340 | nr_annotation_classes (int): Number of classes in annotation. Used only if setting node labels. 341 | annotation_background_class (int): Background class label in annotation. Used only if setting node labels. 342 | add_loc_feats (bool): Flag to include location-based features (ie normalized centroids) 343 | in node feature representation. 344 | Defaults to False. 345 | """ 346 | self.nr_annotation_classes = nr_annotation_classes 347 | self.annotation_background_class = annotation_background_class 348 | self.add_loc_feats = add_loc_feats 349 | super().__init__(**kwargs) 350 | 351 | def _process( # type: ignore[override] 352 | self, 353 | instance_map: np.ndarray, 354 | features: torch.Tensor, 355 | annotation: Optional[np.ndarray] = None, 356 | ) -> dgl.DGLGraph: 357 | """Generates a graph from a given instance_map and features 358 | Args: 359 | instance_map (np.array): Instance map depicting tissue components 360 | features (torch.Tensor): Features of each node. 
Shape (nr_nodes, nr_features) 361 | annotation (Union[None, np.array], optional): Optional node level to include. 362 | Defaults to None. 363 | Returns: 364 | dgl.DGLGraph: The constructed graph 365 | """ 366 | # add nodes 367 | num_nodes = features.shape[0] 368 | graph = dgl.DGLGraph() 369 | graph.add_nodes(num_nodes) 370 | 371 | # add image size as graph data 372 | image_size = (instance_map.shape[1], instance_map.shape[0]) # (x, y) 373 | 374 | # get instance centroids 375 | centroids = self._get_node_centroids(instance_map) 376 | 377 | # add node content 378 | self._set_node_centroids(centroids, graph) 379 | self._set_node_features(features, image_size, graph) 380 | if annotation is not None: 381 | self._set_node_labels(instance_map, annotation, graph) 382 | 383 | # build edges 384 | self._build_topology(instance_map, centroids, graph) 385 | return graph 386 | 387 | def _process_and_save( # type: ignore[override] 388 | self, 389 | instance_map: np.ndarray, 390 | features: torch.Tensor, 391 | annotation: Optional[np.ndarray] = None, 392 | output_name: str = None, 393 | ) -> dgl.DGLGraph: 394 | """Process and save in provided directory 395 | Args: 396 | output_name (str): Name of output file 397 | instance_map (np.ndarray): Instance map depicting tissue components 398 | (eg nuclei, tissue superpixels) 399 | features (torch.Tensor): Features of each node. Shape (nr_nodes, nr_features) 400 | annotation (Optional[np.ndarray], optional): Optional node level to include. 401 | Defaults to None. 402 | Returns: 403 | dgl.DGLGraph: [description] 404 | """ 405 | assert ( 406 | self.save_path is not None 407 | ), "Can only save intermediate output if base_path was not None during construction" 408 | output_path = self.output_dir / f"{output_name}.bin" 409 | if output_path.exists(): 410 | logging.info( 411 | f"Output of {output_name} already exists, using it instead of recomputing" 412 | ) 413 | graphs, _ = load_graphs(str(output_path)) 414 | assert len(graphs) == 1 415 | graph = graphs[0] 416 | else: 417 | graph = self._process( 418 | instance_map=instance_map, 419 | features=features, 420 | annotation=annotation) 421 | save_graphs(str(output_path), [graph]) 422 | return graph 423 | 424 | def _get_node_centroids( 425 | self, instance_map: np.ndarray 426 | ) -> np.ndarray: 427 | """Get the centroids of the graphs 428 | Args: 429 | instance_map (np.ndarray): Instance map depicting tissue components 430 | Returns: 431 | centroids (np.ndarray): Node centroids 432 | """ 433 | regions = regionprops(instance_map) 434 | centroids = np.empty((len(regions), 2)) 435 | for i, region in enumerate(regions): 436 | center_y, center_x = region.centroid # (y, x) 437 | center_x = int(round(center_x)) 438 | center_y = int(round(center_y)) 439 | centroids[i, 0] = center_x 440 | centroids[i, 1] = center_y 441 | return centroids 442 | 443 | def _set_node_centroids( 444 | self, 445 | centroids: np.ndarray, 446 | graph: dgl.DGLGraph 447 | ) -> None: 448 | """Set the centroids of the graphs 449 | Args: 450 | centroids (np.ndarray): Node centroids 451 | graph (dgl.DGLGraph): Graph to add the centroids to 452 | """ 453 | graph.ndata[CENTROID] = torch.FloatTensor(centroids) 454 | 455 | def _set_node_features( 456 | self, 457 | features: torch.Tensor, 458 | image_size: Tuple[int, int], 459 | graph: dgl.DGLGraph 460 | ) -> None: 461 | """Set the provided node features 462 | Args: 463 | features (torch.Tensor): Node features 464 | image_size (Tuple[int,int]): Image dimension (x, y) 465 | graph (dgl.DGLGraph): Graph to add the 
features to 466 | """ 467 | if not torch.is_tensor(features): 468 | features = torch.FloatTensor(features) 469 | if not self.add_loc_feats: 470 | graph.ndata[FEATURES] = features 471 | elif ( 472 | self.add_loc_feats 473 | and image_size is not None 474 | ): 475 | # compute normalized centroid features 476 | centroids = graph.ndata[CENTROID] 477 | 478 | normalized_centroids = torch.empty_like(centroids) # (x, y) 479 | normalized_centroids[:, 0] = centroids[:, 0] / image_size[0] 480 | normalized_centroids[:, 1] = centroids[:, 1] / image_size[1] 481 | 482 | if features.ndim == 3: 483 | normalized_centroids = normalized_centroids \ 484 | .unsqueeze(dim=1) \ 485 | .repeat(1, features.shape[1], 1) 486 | concat_dim = 2 487 | elif features.ndim == 2: 488 | concat_dim = 1 489 | 490 | concat_features = torch.cat( 491 | ( 492 | features, 493 | normalized_centroids 494 | ), 495 | dim=concat_dim, 496 | ) 497 | graph.ndata[FEATURES] = concat_features 498 | else: 499 | raise ValueError( 500 | "Please provide image size to add the normalized centroid to the node features." 501 | ) 502 | 503 | @abstractmethod 504 | def _set_node_labels( 505 | self, 506 | instance_map: np.ndarray, 507 | annotation: np.ndarray, 508 | graph: dgl.DGLGraph 509 | ) -> None: 510 | """Set the node labels of the graphs 511 | Args: 512 | instance_map (np.ndarray): Instance map depicting tissue components 513 | annotation (np.ndarray): Annotations, eg node labels 514 | graph (dgl.DGLGraph): Graph to add the centroids to 515 | """ 516 | 517 | @abstractmethod 518 | def _build_topology( 519 | self, 520 | instance_map: np.ndarray, 521 | centroids: np.ndarray, 522 | graph: dgl.DGLGraph 523 | ) -> None: 524 | """Generate the graph topology from the provided instance_map 525 | Args: 526 | instance_map (np.array): Instance map depicting tissue components 527 | centroids (np.array): Node centroids 528 | graph (dgl.DGLGraph): Graph to add the edges 529 | """ 530 | 531 | def precompute( 532 | self, 533 | link_path: Union[None, str, Path] = None, 534 | precompute_path: Union[None, str, Path] = None, 535 | ) -> None: 536 | """Precompute all necessary information 537 | Args: 538 | link_path (Union[None, str, Path], optional): Path to link to. Defaults to None. 539 | precompute_path (Union[None, str, Path], optional): Path to save precomputation outputs. Defaults to None. 540 | """ 541 | if self.save_path is not None and link_path is not None: 542 | self._link_to_path(Path(link_path) / "graphs") 543 | 544 | 545 | class RAGGraphBuilder(BaseGraphBuilder): 546 | """ 547 | Super-pixel Graphs class for graph building. 548 | """ 549 | 550 | def __init__(self, kernel_size: int = 3, hops: int = 1, **kwargs) -> None: 551 | """Create a graph builder that uses a provided kernel size to detect connectivity 552 | Args: 553 | kernel_size (int, optional): Size of the kernel to detect connectivity. Defaults to 5. 554 | """ 555 | logging.debug("*** RAG Graph Builder ***") 556 | assert hops > 0 and isinstance( 557 | hops, int 558 | ), f"Invalid hops {hops} ({type(hops)}). 
Must be integer >= 0" 559 | self.kernel_size = kernel_size 560 | self.hops = hops 561 | super().__init__(**kwargs) 562 | 563 | def _set_node_labels( 564 | self, 565 | instance_map: np.ndarray, 566 | annotation: np.ndarray, 567 | graph: dgl.DGLGraph) -> None: 568 | """Set the node labels of the graphs using annotation map""" 569 | assert ( 570 | self.nr_annotation_classes < 256 571 | ), "Cannot handle that many classes with 8-bits" 572 | regions = regionprops(instance_map) 573 | labels = torch.empty(len(regions), dtype=torch.uint8) 574 | 575 | for region_label in np.arange(1, len(regions) + 1): 576 | histogram = fast_histogram( 577 | annotation[instance_map == region_label], 578 | nr_values=self.nr_annotation_classes 579 | ) 580 | mask = np.ones(len(histogram), np.bool) 581 | mask[self.annotation_background_class] = 0 582 | if histogram[mask].sum() == 0: 583 | assignment = self.annotation_background_class 584 | else: 585 | histogram[self.annotation_background_class] = 0 586 | assignment = np.argmax(histogram) 587 | labels[region_label - 1] = int(assignment) 588 | graph.ndata[LABEL] = labels 589 | 590 | def _build_topology( 591 | self, 592 | instance_map: np.ndarray, 593 | centroids: np.ndarray, 594 | graph: dgl.DGLGraph 595 | ) -> None: 596 | """Create the graph topology from the instance connectivty in the instance_map""" 597 | regions = regionprops(instance_map) 598 | instance_ids = torch.empty(len(regions), dtype=torch.uint8) 599 | 600 | kernel = np.ones((3, 3), np.uint8) 601 | adjacency = np.zeros(shape=(len(instance_ids), len(instance_ids))) 602 | 603 | for instance_id in np.arange(1, len(instance_ids) + 1): 604 | mask = (instance_map == instance_id).astype(np.uint8) 605 | # print("mask:{}".format(mask)) 606 | dilation = cv2.dilate(mask,kernel, iterations=1) 607 | # print("dilation:{}".format(dilation)) 608 | boundary = dilation - mask 609 | # print("boundary:{}".format(boundary)) 610 | # print(sum(sum(boundary))) 611 | idx = pd.unique(instance_map[boundary.astype(bool)]) 612 | # print("idx:{}".format(idx)) 613 | # print(len(idx)) 614 | instance_id -= 1 # because instance_map id starts from 1 615 | idx -= 1 # because instance_map id starts from 1 616 | # print("new idx:{}".format(idx)) 617 | # print(type(idx)) 618 | idx = idx.tolist() 619 | # print(type(idx)) 620 | if -1 in idx: 621 | idx.remove(-1) 622 | idx = np.array(idx) 623 | # print(type(idx)) 624 | # print("new new idx:{}".format(idx)) 625 | if idx.shape[0] != 0: 626 | adjacency[instance_id, idx] = 1 627 | # print(adjacency) 628 | 629 | edge_list = np.nonzero(adjacency) 630 | graph.add_edges(list(edge_list[0]), list(edge_list[1])) 631 | 632 | for _ in range(self.hops - 1): 633 | graph = two_hop_neighborhood(graph) 634 | 635 | """This module handles everything related to superpixels""" 636 | 637 | import logging 638 | import math 639 | import sys 640 | from abc import abstractmethod 641 | from pathlib import Path 642 | from typing import Any, Dict, Optional, Union 643 | 644 | import cv2 645 | import h5py 646 | import numpy as np 647 | from skimage.color.colorconv import rgb2hed 648 | from skimage.future import graph 649 | from skimage.segmentation import slic 650 | from skimage.future import graph 651 | 652 | 653 | class SuperpixelExtractor(PipelineStep): 654 | """Helper class to extract superpixels from images""" 655 | 656 | def __init__( 657 | self, 658 | nr_superpixels: int = None, 659 | superpixel_size: int = None, 660 | max_nr_superpixels: Optional[int] = None, 661 | blur_kernel_size: Optional[float] = 1, 662 | compactness: 
Optional[int] = 20, 663 | max_iterations: Optional[int] = 10, 664 | threshold: Optional[float] = 0.03, 665 | connectivity: Optional[int] = 2, 666 | color_space: Optional[str] = "rgb", 667 | downsampling_factor: Optional[int] = 1, 668 | **kwargs, 669 | ) -> None: 670 | """Abstract class that extracts superpixels from RGB Images 671 | Args: 672 | nr_superpixels (None, int): The number of super pixels before any merging. 673 | superpixel_size (None, int): The size of super pixels before any merging. 674 | max_nr_superpixels (int, optional): Upper bound for the number of super pixels. 675 | Useful when providing a superpixel size. 676 | blur_kernel_size (float, optional): Size of the blur kernel. Defaults to 0. 677 | compactness (int, optional): Compactness of the superpixels. Defaults to 30. 678 | max_iterations (int, optional): Number of iterations of the slic algorithm. Defaults to 10. 679 | threshold (float, optional): Connectivity threshold. Defaults to 0.03. 680 | connectivity (int, optional): Connectivity for merging graph. Defaults to 2. 681 | downsampling_factor (int, optional): Downsampling factor from the input image 682 | resolution. Defaults to 1. 683 | """ 684 | assert (nr_superpixels is None and superpixel_size is not None) or ( 685 | nr_superpixels is not None and superpixel_size is None 686 | ), "Provide value for either nr_superpixels or superpixel_size" 687 | self.nr_superpixels = nr_superpixels 688 | self.superpixel_size = superpixel_size 689 | self.max_nr_superpixels = max_nr_superpixels 690 | self.blur_kernel_size = blur_kernel_size 691 | self.compactness = compactness 692 | self.max_iterations = max_iterations 693 | self.threshold = threshold 694 | self.connectivity = connectivity 695 | self.color_space = color_space 696 | self.downsampling_factor = downsampling_factor 697 | super().__init__(**kwargs) 698 | 699 | def _process( # type: ignore[override] 700 | self, input_image: np.ndarray, tissue_mask: np.ndarray = None 701 | ) -> np.ndarray: 702 | """Return the superpixels of a given input image 703 | Args: 704 | input_image (np.array): Input image 705 | tissue_mask (None, np.array): Input tissue mask 706 | Returns: 707 | np.array: Extracted superpixels 708 | """ 709 | logging.debug("Input size: %s", input_image.shape) 710 | original_height, original_width, _ = input_image.shape 711 | if self.downsampling_factor != 1: 712 | input_image = self._downsample( 713 | input_image, self.downsampling_factor) 714 | if tissue_mask is not None: 715 | tissue_mask = self._downsample( 716 | tissue_mask, self.downsampling_factor) 717 | logging.debug("Downsampled to %s", input_image.shape) 718 | superpixels = self._extract_superpixels( 719 | image=input_image, tissue_mask=tissue_mask 720 | ) 721 | if self.downsampling_factor != 1: 722 | superpixels = self._upsample( 723 | superpixels, original_height, original_width) 724 | logging.debug("Upsampled to %s", superpixels.shape) 725 | return superpixels 726 | 727 | @abstractmethod 728 | def _extract_superpixels( 729 | self, image: np.ndarray, tissue_mask: np.ndarray = None 730 | ) -> np.ndarray: 731 | """Perform the superpixel extraction 732 | Args: 733 | image (np.array): Input tensor 734 | tissue_mask (np.array): Tissue mask tensor 735 | Returns: 736 | np.array: Output tensor 737 | """ 738 | 739 | @staticmethod 740 | def _downsample(image: np.ndarray, downsampling_factor: int) -> np.ndarray: 741 | """Downsample an input image with a given downsampling factor 742 | Args: 743 | image (np.array): Input tensor 744 | downsampling_factor (int): 
Factor to downsample 745 | Returns: 746 | np.array: Output tensor 747 | """ 748 | height, width = image.shape[0], image.shape[1] 749 | new_height = math.floor(height / downsampling_factor) 750 | new_width = math.floor(width / downsampling_factor) 751 | downsampled_image = cv2.resize( 752 | image, (new_width, new_height), interpolation=cv2.INTER_NEAREST 753 | ) 754 | return downsampled_image 755 | 756 | @staticmethod 757 | def _upsample( 758 | image: np.ndarray, 759 | new_height: int, 760 | new_width: int) -> np.ndarray: 761 | """Upsample an input image to a speficied new height and width 762 | Args: 763 | image (np.array): Input tensor 764 | new_height (int): Target height 765 | new_width (int): Target width 766 | Returns: 767 | np.array: Output tensor 768 | """ 769 | upsampled_image = cv2.resize( 770 | image, (new_width, new_height), interpolation=cv2.INTER_NEAREST 771 | ) 772 | return upsampled_image 773 | 774 | def precompute( 775 | self, 776 | link_path: Union[None, str, Path] = None, 777 | precompute_path: Union[None, str, Path] = None, 778 | ) -> None: 779 | """Precompute all necessary information 780 | Args: 781 | link_path (Union[None, str, Path], optional): Path to link to. Defaults to None. 782 | precompute_path (Union[None, str, Path], optional): Path to save precomputation outputs. Defaults to None. 783 | """ 784 | if self.save_path is not None and link_path is not None: 785 | self._link_to_path(Path(link_path) / "superpixels") 786 | 787 | 788 | class SLICSuperpixelExtractor(SuperpixelExtractor): 789 | """Use the SLIC algorithm to extract superpixels.""" 790 | 791 | def __init__(self, **kwargs) -> None: 792 | """Extract superpixels with the SLIC algorithm""" 793 | super().__init__(**kwargs) 794 | 795 | def _get_nr_superpixels(self, image: np.ndarray) -> int: 796 | """Compute the number of superpixels for initial segmentation 797 | Args: 798 | image (np.array): Input tensor 799 | Returns: 800 | int: Output number of superpixels 801 | """ 802 | if self.superpixel_size is not None: 803 | nr_superpixels = int( 804 | (image.shape[0] * image.shape[1] / self.superpixel_size) 805 | ) 806 | elif self.nr_superpixels is not None: 807 | nr_superpixels = self.nr_superpixels 808 | if self.max_nr_superpixels is not None: 809 | nr_superpixels = min(nr_superpixels, self.max_nr_superpixels) 810 | return nr_superpixels 811 | 812 | def _extract_superpixels( 813 | self, 814 | image: np.ndarray, 815 | *args, 816 | **kwargs) -> np.ndarray: 817 | """Perform the superpixel extraction 818 | Args: 819 | image (np.array): Input tensor 820 | Returns: 821 | np.array: Output tensor 822 | """ 823 | if self.color_space == "hed": 824 | image = rgb2hed(image) 825 | nr_superpixels = self._get_nr_superpixels(image) 826 | superpixels = slic( 827 | image, 828 | sigma=self.blur_kernel_size, 829 | n_segments=nr_superpixels, 830 | max_iter=self.max_iterations, 831 | compactness=self.compactness, 832 | start_label=1, 833 | ) 834 | return superpixels 835 | 836 | 837 | class MergedSuperpixelExtractor(SuperpixelExtractor): 838 | def __init__(self, **kwargs) -> None: 839 | """Extract superpixels with the SLIC algorithm""" 840 | super().__init__(**kwargs) 841 | 842 | def _get_nr_superpixels(self, image: np.ndarray) -> int: 843 | """Compute the number of superpixels for initial segmentation 844 | Args: 845 | image (np.array): Input tensor 846 | Returns: 847 | int: Output number of superpixels 848 | """ 849 | if self.superpixel_size is not None: 850 | nr_superpixels = int( 851 | (image.shape[0] * image.shape[1] / 
self.superpixel_size) 852 | ) 853 | elif self.nr_superpixels is not None: 854 | nr_superpixels = self.nr_superpixels 855 | if self.max_nr_superpixels is not None: 856 | nr_superpixels = min(nr_superpixels, self.max_nr_superpixels) 857 | return nr_superpixels 858 | 859 | def _extract_initial_superpixels(self, image: np.ndarray) -> np.ndarray: 860 | """Extract initial superpixels using SLIC 861 | Args: 862 | image (np.array): Input tensor 863 | Returns: 864 | np.array: Output tensor 865 | """ 866 | nr_superpixels = self._get_nr_superpixels(image) 867 | superpixels = slic( 868 | image, 869 | sigma=self.blur_kernel_size, 870 | n_segments=nr_superpixels, 871 | compactness=self.compactness, 872 | max_iter=self.max_iterations, 873 | start_label=1, 874 | ) 875 | return superpixels 876 | 877 | def _merge_superpixels( 878 | self, 879 | input_image: np.ndarray, 880 | initial_superpixels: np.ndarray, 881 | tissue_mask: np.ndarray = None, 882 | ) -> np.ndarray: 883 | """Merge the initial superpixels to return merged superpixels 884 | Args: 885 | image (np.array): Input image 886 | initial_superpixels (np.array): Initial superpixels 887 | tissue_mask (None, np.array): Tissue mask 888 | Returns: 889 | np.array: Output merged superpixel tensor 890 | """ 891 | if tissue_mask is not None: 892 | # Remove superpixels belonging to background or having < 10% tissue 893 | # content 894 | ids_initial = np.unique(initial_superpixels, return_counts=True) 895 | ids_masked = np.unique( 896 | tissue_mask * initial_superpixels, return_counts=True 897 | ) 898 | 899 | ctr = 1 900 | superpixels = np.zeros_like(initial_superpixels) 901 | for i in range(len(ids_initial[0])): 902 | id = ids_initial[0][i] 903 | if id in ids_masked[0]: 904 | idx = np.where(id == ids_masked[0])[0] 905 | ratio = ids_masked[1][idx] / ids_initial[1][i] 906 | if ratio >= 0.1: 907 | superpixels[initial_superpixels == id] = ctr 908 | ctr += 1 909 | 910 | initial_superpixels = superpixels 911 | 912 | # Merge superpixels within tissue region 913 | g = graph.rag_mean_color(input_image, initial_superpixels) 914 | merged_superpixels = graph.merge_hierarchical( 915 | initial_superpixels, 916 | g, 917 | thresh=self.threshold, 918 | rag_copy=False, 919 | in_place_merge=True, 920 | merge_func=self._merging_function, 921 | weight_func=self._weighting_function, 922 | ) 923 | merged_superpixels += 1 # Handle regionprops that ignores all values of 0 924 | # mask = np.zeros_like(initial_superpixels) 925 | # mask[initial_superpixels != 0] = 1 926 | # merged_superpixels = merged_superpixels * mask 927 | return merged_superpixels 928 | 929 | @abstractmethod 930 | def _weighting_function( 931 | self, graph: graph.RAG, src: int, dst: int, n: int 932 | ) -> Dict[str, Any]: 933 | """Handle merging of nodes of a region boundary region adjacency graph.""" 934 | 935 | @abstractmethod 936 | def _merging_function(self, graph: graph.RAG, src: int, dst: int) -> None: 937 | """Call back called before merging 2 nodes.""" 938 | 939 | def _extract_superpixels( 940 | self, image: np.ndarray, tissue_mask: np.ndarray = None 941 | ) -> np.ndarray: 942 | """Perform superpixel extraction 943 | Args: 944 | image (np.array): Input tensor 945 | tissue_mask (np.array, optional): Input tissue mask 946 | Returns: 947 | np.array: Extracted merged superpixels. 948 | np.array: Extracted init superpixels, ie before merging. 
949 | """ 950 | initial_superpixels = self._extract_initial_superpixels(image) 951 | merged_superpixels = self._merge_superpixels( 952 | image, initial_superpixels, tissue_mask 953 | ) 954 | 955 | return merged_superpixels, initial_superpixels 956 | 957 | def _process( # type: ignore[override] 958 | self, input_image: np.ndarray, tissue_mask=None 959 | ) -> np.ndarray: 960 | """Return the superpixels of a given input image 961 | Args: 962 | input_image (np.array): Input image. 963 | tissue_mask (None, np.array): Tissue mask. 964 | Returns: 965 | np.array: Extracted merged superpixels. 966 | np.array: Extracted init superpixels, ie before merging. 967 | """ 968 | logging.debug("Input size: %s", input_image.shape) 969 | original_height, original_width, _ = input_image.shape 970 | if self.downsampling_factor is not None and self.downsampling_factor != 1: 971 | input_image = self._downsample( 972 | input_image, self.downsampling_factor) 973 | if tissue_mask is not None: 974 | tissue_mask = self._downsample( 975 | tissue_mask, self.downsampling_factor) 976 | logging.debug("Downsampled to %s", input_image.shape) 977 | merged_superpixels, initial_superpixels = self._extract_superpixels( 978 | input_image, tissue_mask 979 | ) 980 | if self.downsampling_factor != 1: 981 | merged_superpixels = self._upsample( 982 | merged_superpixels, original_height, original_width 983 | ) 984 | initial_superpixels = self._upsample( 985 | initial_superpixels, original_height, original_width 986 | ) 987 | logging.debug("Upsampled to %s", merged_superpixels.shape) 988 | return merged_superpixels, initial_superpixels 989 | 990 | def _process_and_save( 991 | self, 992 | *args: Any, 993 | output_name: str, 994 | **kwargs: Any) -> Any: 995 | """Process and save in the provided path as as .h5 file 996 | Args: 997 | output_name (str): Name of output file 998 | """ 999 | assert ( 1000 | self.save_path is not None 1001 | ), "Can only save intermediate output if base_path was not None when constructing the object" 1002 | superpixel_output_path = self.output_dir / f"{output_name}.h5" 1003 | if superpixel_output_path.exists(): 1004 | logging.info( 1005 | f"{self.__class__.__name__}: Output of {output_name} already exists, using it instead of recomputing" 1006 | ) 1007 | try: 1008 | with h5py.File(superpixel_output_path, "r") as input_file: 1009 | merged_superpixels, initial_superpixels = self._get_outputs( 1010 | input_file=input_file) 1011 | except OSError as e: 1012 | print( 1013 | f"\n\nCould not read from {superpixel_output_path}!\n\n", 1014 | file=sys.stderr, 1015 | flush=True, 1016 | ) 1017 | print( 1018 | f"\n\nCould not read from {superpixel_output_path}!\n\n", 1019 | flush=True) 1020 | raise e 1021 | else: 1022 | merged_superpixels, initial_superpixels = self._process( 1023 | *args, **kwargs) 1024 | try: 1025 | with h5py.File(superpixel_output_path, "w") as output_file: 1026 | self._set_outputs( 1027 | output_file=output_file, 1028 | outputs=(merged_superpixels, initial_superpixels), 1029 | ) 1030 | except OSError as e: 1031 | print( 1032 | f"\n\nCould not write to {superpixel_output_path}!\n\n", 1033 | flush=True) 1034 | raise e 1035 | return merged_superpixels, initial_superpixels 1036 | 1037 | 1038 | class MyColorMergedSuperpixelExtractor(MergedSuperpixelExtractor): 1039 | def __init__( 1040 | self, 1041 | w_hist: float = 0.5, 1042 | w_mean: float = 0.5, 1043 | **kwargs) -> None: 1044 | """Superpixel merger based on color attibutes taken from the HACT-Net Implementation 1045 | Args: 1046 | w_hist (float, optional): 
Weight of the histogram features for merging. Defaults to 0.5. 1047 | w_mean (float, optional): Weight of the mean features for merging. Defaults to 0.5. 1048 | """ 1049 | self.w_hist = w_hist 1050 | self.w_mean = w_mean 1051 | super().__init__(**kwargs) 1052 | 1053 | def _color_features_per_channel(self, img_ch: np.ndarray) -> np.ndarray: 1054 | """Extract color histograms from image channel 1055 | Args: 1056 | img_ch (np.ndarray): Image channel 1057 | Returns: 1058 | np.ndarray: Histogram of the image channel 1059 | """ 1060 | hist, _ = np.histogram(img_ch, bins=np.arange(0, 257, 64)) # 8 bins 1061 | return hist 1062 | 1063 | def _weighting_function( 1064 | self, graph: graph.RAG, src: int, dst: int, n: int 1065 | ) -> Dict[str, Any]: 1066 | diff = graph.nodes[dst]['mean color'] - graph.nodes[n]['mean color'] 1067 | diff = np.linalg.norm(diff) 1068 | return {'weight': diff} 1069 | 1070 | def _merging_function(self, graph: graph.RAG, src: int, dst: int) -> None: 1071 | graph.nodes[dst]['total color'] += graph.nodes[src]['total color'] 1072 | graph.nodes[dst]['pixel count'] += graph.nodes[src]['pixel count'] 1073 | graph.nodes[dst]['mean color'] = (graph.nodes[dst]['total color'] / 1074 | graph.nodes[dst]['pixel count']) -------------------------------------------------------------------------------- /data_splits/train_val_test_split_argo.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baeksweety/superpixel_transformer/5c5f395cc9759a48f1ad7cdd01eeaaf7d445b67c/data_splits/train_val_test_split_argo.pkl -------------------------------------------------------------------------------- /data_splits/train_val_test_split_kirc.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baeksweety/superpixel_transformer/5c5f395cc9759a48f1ad7cdd01eeaaf7d445b67c/data_splits/train_val_test_split_kirc.pkl -------------------------------------------------------------------------------- /data_splits/train_val_test_split_lihc.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baeksweety/superpixel_transformer/5c5f395cc9759a48f1ad7cdd01eeaaf7d445b67c/data_splits/train_val_test_split_lihc.pkl -------------------------------------------------------------------------------- /data_splits/useless: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /images/miccai_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baeksweety/superpixel_transformer/5c5f395cc9759a48f1ad7cdd01eeaaf7d445b67c/images/miccai_framework.png -------------------------------------------------------------------------------- /images/try: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /label/argo_label.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baeksweety/superpixel_transformer/5c5f395cc9759a48f1ad7cdd01eeaaf7d445b67c/label/argo_label.pt -------------------------------------------------------------------------------- /label/kirc_label.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Baeksweety/superpixel_transformer/5c5f395cc9759a48f1ad7cdd01eeaaf7d445b67c/label/kirc_label.pt -------------------------------------------------------------------------------- /label/lihc_label.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baeksweety/superpixel_transformer/5c5f395cc9759a48f1ad7cdd01eeaaf7d445b67c/label/lihc_label.pt -------------------------------------------------------------------------------- /label/useless: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dgl==0.4.3 2 | einops==0.3.0 3 | histocartography==0.2.0 4 | joblib==1.0.0 5 | lifelines==0.25.7 6 | numpy==1.19.2 7 | opencv-python==4.5.1.48 8 | opencv-python-headless==4.1.2.30 9 | openpyxl==3.0.8 10 | openslide-python==1.1.2 11 | pandas==1.2.0 12 | Pillow==8.1.0 13 | scikit-image==0.18.1 14 | scikit-learn==1.1.2 15 | scipy==1.6.0 16 | sklearn-pandas==2.0.4 17 | timm==0.4.12 18 | torch==1.7.1 19 | torch-cluster==1.5.8 20 | torch-geometric==1.6.3 21 | torch-scatter==2.0.5 22 | torch-sparse==0.6.8 23 | torch-spline-conv==1.2.0 24 | torchaudio==0.7.0 25 | torchstain==1.1.0 26 | torchvision==0.8.2 27 | tqdm==4.56.0 28 | -------------------------------------------------------------------------------- /train/block_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import Mlp,DropPath 4 | 5 | class Attention(nn.Module): 6 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 7 | super().__init__() 8 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 9 | self.num_heads = num_heads 10 | head_dim = dim // num_heads 11 | self.scale = head_dim ** -0.5 12 | 13 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 14 | self.attn_drop = nn.Dropout(attn_drop) 15 | self.proj = nn.Linear(dim, dim) 16 | self.proj_drop = nn.Dropout(proj_drop) 17 | 18 | self.attention_weights: Optional[Tensor] = None 19 | 20 | def forward(self, x): 21 | B, N, C = x.shape 22 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 23 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 24 | 25 | attn = (q @ k.transpose(-2, -1)) * self.scale 26 | attn = attn.softmax(dim=-1) 27 | attn = self.attn_drop(attn) 28 | # print(attn.shape) 29 | self.attention_weights = attn 30 | 31 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 32 | x = self.proj(x) 33 | x = self.proj_drop(x) 34 | return x 35 | 36 | def get_attention_weights(self): 37 | return self.attention_weights 38 | 39 | 40 | class LayerScale(nn.Module): 41 | def __init__(self, dim, init_values=1e-5, inplace=False): 42 | super().__init__() 43 | self.inplace = inplace 44 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 45 | 46 | def forward(self, x): 47 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 48 | 49 | 50 | class Block(nn.Module): 51 | 52 | def __init__( 53 | self, 54 | dim, 55 | num_heads, 56 | mlp_ratio=4., 57 | qkv_bias=False, 58 | drop=0., 59 | attn_drop=0., 60 | init_values=None, 61 | drop_path=0., 62 | act_layer=nn.GELU, 63 | norm_layer=nn.LayerNorm 64 | ): 65 | super().__init__() 66 | self.norm1 = norm_layer(dim) 67 | 
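# Note: the custom Attention module defined above caches its post-softmax
# attention map in self.attention_weights on every forward pass;
# Block.get_attention_weights() simply forwards that cache so the per-head
# scores can be read out later for interpretability
# (see Superpixel_Vit.get_attention_weights() in superpixel_transformer_n.py).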
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 68 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 69 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 70 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 71 | 72 | self.norm2 = norm_layer(dim) 73 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) 74 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 75 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 76 | 77 | def forward(self, x): 78 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) 79 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 80 | return x 81 | 82 | def get_attention_weights(self): 83 | return self.attn.get_attention_weights() -------------------------------------------------------------------------------- /train/run.sh: -------------------------------------------------------------------------------- 1 | # #!/bin/bash 2 | 3 | python train.py --mpool_intra 'global_max_pool' --seed 1 --fold_num 0 --label "tcga_argo_sage_max_1024featdim_fold0_superpixel600_cluster16_numhead8_lr1e-5_30epoch_l2regalpha0.001_batchsize_16_split0_depth1" 4 | python train.py --mpool_intra 'global_max_pool' --seed 1 --fold_num 1 --label "tcga_argo_sage_max_1024featdim_fold1_superpixel600_cluster16_numhead8_lr1e-5_30epoch_l2regalpha0.001_batchsize_16_split0_depth1" 5 | python train.py --mpool_intra 'global_max_pool' --seed 1 --fold_num 2 --label "tcga_argo_sage_max_1024featdim_fold2_superpixel600_cluster16_numhead8_lr1e-5_30epoch_l2regalpha0.001_batchsize_16_split0_depth1" 6 | python train.py --mpool_intra 'global_max_pool' --seed 1 --fold_num 3 --label "tcga_argo_sage_max_1024featdim_fold3_superpixel600_cluster16_numhead8_lr1e-5_30epoch_l2regalpha0.001_batchsize_16_split0_depth1" 7 | python train.py --mpool_intra 'global_max_pool' --seed 1 --fold_num 4 --label "tcga_argo_sage_max_1024featdim_fold4_superpixel600_cluster16_numhead8_lr1e-5_30epoch_l2regalpha0.001_batchsize_16_split0_depth1" 8 | 9 | #according to the dataset and parameters, the value of label should be changed 10 | -------------------------------------------------------------------------------- /train/superpixel_transformer_n.py: -------------------------------------------------------------------------------- 1 | from timm.models.vision_transformer import VisionTransformer 2 | import timm.models.vision_transformer 3 | import skimage.io as io 4 | import argparse 5 | import joblib 6 | import copy 7 | import random 8 | import os 9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 10 | 11 | import skimage.io as io 12 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_,PatchEmbed 13 | from timm.models.helpers import build_model_with_cfg, named_apply 14 | from torch_geometric.nn import global_mean_pool,global_max_pool,GlobalAttention,dense_diff_pool,global_add_pool,TopKPooling,ASAPooling,SAGPooling 15 | from torch_geometric.nn import GCNConv,ChebConv,SAGEConv,GraphConv,LEConv,LayerNorm,GATConv 16 | import torch 17 | from sklearn.metrics import accuracy_score,f1_score,roc_auc_score 18 | import torch.nn as nn 19 | torch.set_num_threads(8) 20 | import numpy as np 21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | from functools import partial 23 | from block_utils import Block 24 | from torch_geometric.data import 
Data as geomData 25 | from timm.models.layers import trunc_normal_ 26 | from torch_scatter import scatter_add 27 | from torch_geometric.utils import softmax 28 | try: 29 | from torch import _assert 30 | except ImportError: 31 | def _assert(condition: bool, message: str): 32 | assert condition, message 33 | 34 | 35 | def reset(nn): 36 | def _reset(item): 37 | if hasattr(item, 'reset_parameters'): 38 | item.reset_parameters() 39 | 40 | if nn is not None: 41 | if hasattr(nn, 'children') and len(list(nn.children())) > 0: 42 | for item in nn.children(): 43 | _reset(item) 44 | else: 45 | _reset(nn) 46 | 47 | class my_GlobalAttention(torch.nn.Module): 48 | def __init__(self, gate_nn, nn=None): 49 | super(my_GlobalAttention, self).__init__() 50 | self.gate_nn = gate_nn 51 | self.nn = nn 52 | 53 | self.reset_parameters() 54 | 55 | def reset_parameters(self): 56 | reset(self.gate_nn) 57 | reset(self.nn) 58 | 59 | 60 | def forward(self, x, batch, size=None): 61 | """""" 62 | x = x.unsqueeze(-1) if x.dim() == 1 else x 63 | size = batch.max().item() + 1 if size is None else size #modified 64 | 65 | gate = self.gate_nn(x).view(-1, 1) 66 | x = self.nn(x) if self.nn is not None else x 67 | assert gate.dim() == x.dim() and gate.size(0) == x.size(0) 68 | 69 | gate = softmax(gate, batch, num_nodes=size) 70 | out = scatter_add(gate * x, batch, dim=0, dim_size=size) 71 | 72 | return out 73 | 74 | 75 | class Intra_GCN(nn.Module): 76 | def __init__(self,in_feats,n_hidden,out_feats,drop_out_ratio=0.2,mpool_method="global_mean_pool",gnn_method='sage'): 77 | super(Intra_GCN,self).__init__() 78 | if gnn_method == 'sage': 79 | self.conv1= SAGEConv(in_channels=in_feats,out_channels=out_feats) 80 | elif gnn_method == 'gcn': 81 | self.conv1= GCNConv(in_channels=in_feats,out_channels=out_feats) 82 | elif gnn_method == 'gat': 83 | self.conv1= GATConv(in_channels=in_feats,out_channels=out_feats) 84 | elif gnn_method == 'leconv': 85 | self.conv1= LEConv(in_channels=in_feats,out_channels=out_feats) 86 | elif gnn_method == 'graphconv': 87 | self.conv1= GraphConv(in_channels=in_feats,out_channels=out_feats) 88 | # self.conv2= SAGEConv(in_channels=n_hidden,out_channels=out_feats) 89 | 90 | self.relu = torch.nn.ReLU() 91 | self.sigmoid = torch.nn.Sigmoid() 92 | self.dropout=nn.Dropout(p=drop_out_ratio) 93 | self.softmax = nn.Softmax(dim=-1) 94 | 95 | if mpool_method == "global_mean_pool": 96 | self.mpool = global_mean_pool 97 | elif mpool_method == "global_max_pool": 98 | self.mpool = global_max_pool 99 | elif mpool_method == "global_att_pool": 100 | att_net=nn.Sequential(nn.Linear(out_feats, out_feats//2), nn.ReLU(), nn.Linear(out_feats//2, 1)) 101 | self.mpool = my_GlobalAttention(att_net) 102 | self.norm = LayerNorm(in_feats) 103 | self.norm2 = LayerNorm(out_feats) 104 | self.norm1 = LayerNorm(n_hidden) 105 | 106 | def forward(self,data): 107 | x=data.x 108 | edge_index = data.edge_patch 109 | 110 | x = self.norm(x) 111 | x = self.conv1(x,edge_index) 112 | x = self.relu(x) 113 | # x = self.sigmoid(x) 114 | # x = self.norm(x) 115 | x = self.norm1(x) 116 | x = self.dropout(x) 117 | 118 | # x = self.conv2(x,edge_index) 119 | # x = self.relu(x) 120 | # # x = self.sigmoid(x) 121 | # # x = self.norm(x) 122 | # x = self.norm2(x) 123 | # x = self.dropout(x) 124 | # print(x) 125 | 126 | # batch = torch.zeros(len(x),dtype=torch.long).to(device) 127 | batch = data.superpixel_attri.to(device) 128 | x = self.mpool(x,batch) 129 | # print('fea dim is {}'.format(x.shape)) 130 | # print(x) 131 | 132 | fea = x 133 | # print(fea.shape) 134 | 
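# At this point the patch-level node features have been aggregated by
# `self.mpool`, using `data.superpixel_attri` as the pooling index, so `fea`
# holds one feature vector per superpixel (shape [num_superpixels, out_feats])
# that is passed on to the inter-superpixel GCN (Inter_GCN below).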
135 | return fea 136 | 137 | class Inter_GCN(nn.Module): 138 | def __init__(self,in_feats,n_hidden,out_feats,drop_out_ratio=0.2,mpool_method="global_mean_pool",gnn_method='sage'): 139 | super(Inter_GCN,self).__init__() 140 | # self.conv1= SAGEConv(in_channels=in_feats,out_channels=out_feats) 141 | # self.conv2= SAGEConv(in_channels=n_hidden,out_channels=out_feats) 142 | if gnn_method == 'sage': 143 | self.conv1= SAGEConv(in_channels=in_feats,out_channels=out_feats) 144 | elif gnn_method == 'gcn': 145 | self.conv1= GCNConv(in_channels=in_feats,out_channels=out_feats) 146 | elif gnn_method == 'gat': 147 | self.conv1= GATConv(in_channels=in_feats,out_channels=out_feats) 148 | elif gnn_method == 'leconv': 149 | self.conv1= LEConv(in_channels=in_feats,out_channels=out_feats) 150 | elif gnn_method == 'graphconv': 151 | self.conv1= GraphConv(in_channels=in_feats,out_channels=out_feats) 152 | 153 | self.relu = torch.nn.ReLU() 154 | self.sigmoid = torch.nn.Sigmoid() 155 | self.dropout=nn.Dropout(p=drop_out_ratio) 156 | self.softmax = nn.Softmax(dim=-1) 157 | 158 | if mpool_method == "global_mean_pool": 159 | self.mpool = global_mean_pool 160 | elif mpool_method == "global_max_pool": 161 | self.mpool = global_max_pool 162 | elif mpool_method == "global_att_pool": 163 | att_net=nn.Sequential(nn.Linear(out_feats, out_feats//2), nn.ReLU(), nn.Linear(out_feats//2, 1)) 164 | self.mpool = my_GlobalAttention(att_net) 165 | self.norm = LayerNorm(in_feats) 166 | self.norm2 = LayerNorm(out_feats) 167 | self.norm1 = LayerNorm(n_hidden) 168 | 169 | def forward(self,data,feature): 170 | x=feature 171 | edge_index = data.edge_superpixel 172 | # print(x.shape) 173 | x = self.norm(x) 174 | x = self.conv1(x,edge_index) 175 | x = self.relu(x) 176 | # x = self.sigmoid(x) 177 | x = self.norm1(x) 178 | x = self.dropout(x) 179 | 180 | # x = self.conv2(x,edge_index) 181 | # x = self.relu(x) 182 | # # x = self.sigmoid(x) 183 | # x = self.norm2(x) 184 | # x = self.dropout(x) 185 | 186 | # batch = torch.zeros(len(x),dtype=torch.long).to(device) 187 | # x = self.mpool(x,batch) 188 | 189 | fea = x 190 | # print(x.shape) 191 | # print(fea.shape) 192 | 193 | return fea 194 | 195 | 196 | 197 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 198 | """ Vision Transformer with support for global average pooling 199 | """ 200 | def __init__(self, num_patches=100,no_embed_class=False,class_token=True, depth=1,drop_path_rate=0.,mlp_ratio=4.,pre_norm=True,qkv_bias=True,init_values=None,drop_rate=0.,attn_drop_rate=0.,norm_layer=None,act_layer=None,weight_init='',global_pool='token', fc_norm=None,**kwargs): 201 | super(VisionTransformer, self).__init__(**kwargs) 202 | embed_dim = kwargs['embed_dim'] 203 | self.patch_embed = nn.Linear(embed_dim,embed_dim) 204 | num_patches = num_patches 205 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 206 | act_layer = act_layer or nn.GELU 207 | self.no_embed_class = no_embed_class 208 | self.global_pool = global_pool 209 | 210 | self.num_prefix_tokens = 1 if class_token else 0 211 | embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens 212 | self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) 213 | self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() 214 | use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm 215 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() 216 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] 217 | 218 | self.blocks = 
nn.Sequential(*[ 219 | Block( 220 | dim=embed_dim, 221 | num_heads=kwargs['num_heads'], 222 | mlp_ratio=mlp_ratio, 223 | qkv_bias=qkv_bias, 224 | init_values=init_values, 225 | drop=drop_rate, 226 | attn_drop=attn_drop_rate, 227 | drop_path=dpr[i], 228 | norm_layer=norm_layer, 229 | act_layer=act_layer 230 | ) 231 | for i in range(depth)]) 232 | self.init_weights(weight_init) 233 | 234 | def init_weights(self, mode=''): 235 | assert mode in ('jax', 'jax_nlhb', 'moco', '') 236 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 237 | trunc_normal_(self.pos_embed, std=.02) 238 | if self.cls_token is not None: 239 | nn.init.normal_(self.cls_token, std=1e-6) 240 | named_apply(get_init_weights_vit(mode, head_bias), self) 241 | 242 | def _pos_embed(self, x): 243 | if self.no_embed_class: 244 | # deit-3, updated JAX (big vision) 245 | # position embedding does not overlap with class token, add then concat 246 | x = x + self.pos_embed 247 | if self.cls_token is not None: 248 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 249 | else: 250 | # original timm, JAX, and deit vit impl 251 | # pos_embed has entry for class token, concat then add 252 | if self.cls_token is not None: 253 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 254 | x = x + self.pos_embed 255 | return self.pos_drop(x) 256 | 257 | def forward_features(self, x): 258 | x = self.patch_embed(x) 259 | x = self._pos_embed(x) 260 | x = self.norm_pre(x) 261 | # if self.grad_checkpointing and not torch.jit.is_scripting(): 262 | # x = checkpoint_seq(self.blocks, x) 263 | # else: 264 | x = self.blocks(x) 265 | x = self.norm(x) 266 | return x 267 | 268 | def forward_head(self, x, pre_logits: bool = False): 269 | if self.global_pool: 270 | x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] #token 271 | x = self.fc_norm(x) 272 | return x if pre_logits else self.head(x) 273 | 274 | def forward(self, x): 275 | x = self.forward_features(x) 276 | x = self.forward_head(x) 277 | return torch.sigmoid(x) 278 | 279 | def get_attention_weights(self): 280 | return [block.get_attention_weights() for block in self.blocks] 281 | 282 | def init_weights_vit_timm(module: nn.Module, name: str = ''): 283 | """ ViT weight initialization, original timm impl (for reproducibility) """ 284 | if isinstance(module, nn.Linear): 285 | trunc_normal_(module.weight, std=.02) 286 | if module.bias is not None: 287 | nn.init.zeros_(module.bias) 288 | elif hasattr(module, 'init_weights'): 289 | module.init_weights() 290 | 291 | 292 | def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): 293 | """ ViT weight initialization, matching JAX (Flax) impl """ 294 | if isinstance(module, nn.Linear): 295 | if name.startswith('head'): 296 | nn.init.zeros_(module.weight) 297 | nn.init.constant_(module.bias, head_bias) 298 | else: 299 | nn.init.xavier_uniform_(module.weight) 300 | if module.bias is not None: 301 | nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) 302 | elif isinstance(module, nn.Conv2d): 303 | lecun_normal_(module.weight) 304 | if module.bias is not None: 305 | nn.init.zeros_(module.bias) 306 | elif hasattr(module, 'init_weights'): 307 | module.init_weights() 308 | 309 | 310 | def init_weights_vit_moco(module: nn.Module, name: str = ''): 311 | """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ 312 | if isinstance(module, nn.Linear): 313 | if 'qkv' in name: 314 | # treat the weights 
of Q, K, V separately 315 | val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) 316 | nn.init.uniform_(module.weight, -val, val) 317 | else: 318 | nn.init.xavier_uniform_(module.weight) 319 | if module.bias is not None: 320 | nn.init.zeros_(module.bias) 321 | elif hasattr(module, 'init_weights'): 322 | module.init_weights() 323 | 324 | 325 | def get_init_weights_vit(mode='jax', head_bias: float = 0.): 326 | if 'jax' in mode: 327 | return partial(init_weights_vit_jax, head_bias=head_bias) 328 | elif 'moco' in mode: 329 | return init_weights_vit_moco 330 | else: 331 | return init_weights_vit_timm 332 | 333 | 334 | class Superpixel_Vit(nn.Module): 335 | def __init__(self,in_feats_intra=1024,n_hidden_intra=1024,out_feats_intra=1024,in_feats_inter=1024,n_hidden_inter=1024,out_feats_inter=1024,vw_num=16,feat_dim=1024,num_classes=1,depth=1,num_heads = 16,final_fea_type = 'mean',mpool_intra='global_mean_pool',mpool_inter='global_mean_pool',gnn_intra='sage',gnn_inter='sage'): 336 | super(Superpixel_Vit, self).__init__() 337 | 338 | self.vw_num = vw_num 339 | self.feat_dim = feat_dim 340 | 341 | #intra-graph 342 | self.gcn1 = Intra_GCN(in_feats=in_feats_intra,n_hidden=n_hidden_intra,out_feats=out_feats_intra,mpool_method=mpool_intra,gnn_method=gnn_intra) 343 | 344 | #inter-graph 345 | self.gcn2 = Inter_GCN(in_feats=in_feats_inter,n_hidden=n_hidden_inter,out_feats=out_feats_inter,mpool_method=mpool_inter,gnn_method=gnn_inter) 346 | self.vit = VisionTransformer(num_patches = vw_num,num_classes = num_classes, embed_dim = feat_dim,depth = depth,num_heads = num_heads) 347 | 348 | self.final_fea_type = final_fea_type 349 | 350 | def superpixel_graph(self,data): 351 | superpixel_attri = data.superpixel_attri 352 | min_value = int(min(superpixel_attri)) 353 | #intra-graph 354 | superpixel_fea = self.gcn1(data) 355 | #intra-graph 356 | superpixel_feas = self.gcn2(data,superpixel_fea) 357 | if min_value == 0: 358 | print('min superpixel value is 0') 359 | superpixel_fea_all={} 360 | for index in range(1,superpixel_feas.shape[0]): 361 | fea = superpixel_feas[index].unsqueeze(0) 362 | superpixel_value = index 363 | superpixel_fea_all[superpixel_value] = fea 364 | else: 365 | print('min superpixel value is 1') 366 | superpixel_fea_all={} 367 | for index in range(superpixel_feas.shape[0]): 368 | fea = superpixel_feas[index].unsqueeze(0) 369 | superpixel_value = index+1 370 | superpixel_fea_all[superpixel_value] = fea 371 | return superpixel_fea_all 372 | 373 | 374 | def mean_feature(self,superpixel_features,cluster_info): 375 | mask=np.zeros((self.vw_num,self.feat_dim)) 376 | mask=torch.tensor(mask).to(device) 377 | # superpixel_cluster_path = os.path.join(cluster_info_path,slidename+'.pth') 378 | # cluster_info = torch.load(superpixel_cluster_path) 379 | for vw in range(self.vw_num): 380 | fea_all=torch.Tensor().to(device) 381 | for superpixel_value in cluster_info.keys(): 382 | if cluster_info[superpixel_value]['cluster']==vw: 383 | if fea_all.shape[0]==0: 384 | fea_all=superpixel_features[superpixel_value] 385 | else: 386 | fea_all=torch.cat((fea_all,superpixel_features[superpixel_value]),dim=0) 387 | if fea_all.shape[0]!=0: 388 | fea_avg=torch.mean(fea_all,axis=0) 389 | # print('fea_avg shape:{}'.format(fea_avg.shape)) 390 | mask[vw]=fea_avg 391 | return mask 392 | 393 | def max_feature(self,superpixel_features,cluster_info): 394 | mask=np.zeros((self.vw_num,self.feat_dim)) 395 | mask=torch.tensor(mask).to(device) 396 | # superpixel_cluster_path = 
os.path.join(cluster_info_path,slidename+'.pth') 397 | # cluster_info = torch.load(superpixel_cluster_path) 398 | for vw in range(self.vw_num): 399 | fea_all=torch.Tensor().to(device) 400 | for superpixel_value in cluster_info.keys(): 401 | if cluster_info[superpixel_value]['cluster']==vw: 402 | if fea_all.shape[0]==0: 403 | fea_all=superpixel_features[superpixel_value] 404 | else: 405 | fea_all=torch.cat((fea_all,superpixel_features[superpixel_value]),dim=0) 406 | if fea_all.shape[0]!=0: 407 | fea_max,_=torch.max(fea_all,dim=0) 408 | # print('fea_avg shape:{}'.format(fea_avg.shape)) 409 | mask[vw]=fea_max 410 | return mask 411 | 412 | def forward(self,data,cluster_info): 413 | 414 | superpixels_fea = self.superpixel_graph(data) 415 | #final-fea 416 | if self.final_fea_type == 'mean': 417 | fea = self.mean_feature(superpixels_fea,cluster_info) #[16,1024] 418 | elif self.final_fea_type == 'max': 419 | fea = self.max_feature(superpixels_fea,cluster_info) 420 | fea = fea.unsqueeze(0) #[1,16,1024] 421 | # print(fea.shape) 422 | # print(fea.shape) 423 | #vit 424 | fea = fea.float() 425 | out = self.vit(fea) 426 | return out 427 | 428 | def get_attention_weights(self): 429 | return self.vit.get_attention_weights() 430 | -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | from timm.models.vision_transformer import VisionTransformer 2 | import timm.models.vision_transformer 3 | import skimage.io as io 4 | import argparse 5 | import joblib 6 | import copy 7 | import random 8 | import os 9 | os.environ['CUDA_VISIBLE_DEVICES'] = '7' 10 | 11 | import skimage.io as io 12 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_,PatchEmbed 13 | # from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq 14 | import torch 15 | from sklearn.metrics import accuracy_score,f1_score,roc_auc_score 16 | # from model_utils import Block,DropPath,get_sinusoid_encoding_table 17 | import torch.nn as nn 18 | torch.set_num_threads(8) 19 | # from lifelines.utils import concordance_index 20 | import numpy as np 21 | from lifelines.utils import concordance_index as ci 22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 23 | from functools import partial 24 | from block_utils import Block 25 | from superpixel_transformer_n import Superpixel_Vit 26 | 27 | try: 28 | from torch import _assert 29 | except ImportError: 30 | def _assert(condition: bool, message: str): 31 | assert condition, message 32 | 33 | 34 | def _neg_partial_log(prediction, T, E): 35 | 36 | current_batch_len = len(prediction) 37 | R_matrix_train = np.zeros([current_batch_len, current_batch_len], dtype=int) 38 | for i in range(current_batch_len): 39 | for j in range(current_batch_len): 40 | R_matrix_train[i, j] = T[j] >= T[i] 41 | 42 | train_R = torch.FloatTensor(R_matrix_train) 43 | train_R = train_R.cuda() 44 | 45 | # train_ystatus = torch.tensor(np.array(E),dtype=torch.float).to(device) 46 | train_ystatus = E 47 | 48 | theta = prediction.reshape(-1) 49 | 50 | exp_theta = torch.exp(theta) 51 | loss_nn = - torch.mean((theta - torch.log(torch.sum(exp_theta * train_R, dim=1))) * train_ystatus) 52 | 53 | return loss_nn 54 | 55 | def get_val_ci(pre_time,patient_and_time,patient_sur_type): 56 | ordered_time, ordered_pred_time, ordered_observed=[],[],[] 57 | for x in pre_time: 58 | ordered_time.append(patient_and_time[x]) 59 | 
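# The model output is a risk-style score (higher = worse prognosis), while
# lifelines' concordance_index expects predictions ordered like survival
# times (higher = longer survival), hence the negation before scoring.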
ordered_pred_time.append(pre_time[x]*-1) 60 | ordered_observed.append(patient_sur_type[x]) 61 | # print(len(ordered_time), len(ordered_pred_time), len(ordered_observed)) 62 | return ci(ordered_time, ordered_pred_time, ordered_observed) 63 | 64 | 65 | def get_train_val(train_slide,val_slide,test_slide,label_path): 66 | train_slidename=[] 67 | train_label_censorship=[] 68 | train_label_survtime=[] 69 | val_slidename=[] 70 | val_label_censorship=[] 71 | val_label_survtime=[] 72 | test_slidename=[] 73 | test_label_censorship=[] 74 | test_label_survtime=[] 75 | label = torch.load(label_path) 76 | 77 | for index in train_slide: 78 | train_slidename.append(index) 79 | train_label_censorship.append(label[index]['censorship']) 80 | train_label_survtime.append(label[index]['surv_time']) 81 | 82 | for index1 in val_slide: 83 | val_slidename.append(index1) 84 | val_label_censorship.append(label[index1]['censorship']) 85 | val_label_survtime.append(label[index1]['surv_time']) 86 | 87 | for index2 in test_slide: 88 | test_slidename.append(index2) 89 | test_label_censorship.append(label[index2]['censorship']) 90 | test_label_survtime.append(label[index2]['surv_time']) 91 | 92 | return train_slidename,val_slidename,test_slidename,train_label_censorship,train_label_survtime,val_label_censorship,val_label_survtime,test_label_censorship,test_label_survtime 93 | 94 | 95 | 96 | def eval_model(model,test_slide,test_label_surv_type,test_label_time,batch_size,cuda=True): 97 | model = model.cuda() 98 | model = model.to(device) 99 | model = model.eval() 100 | 101 | #print('begin evaluation!') 102 | with torch.no_grad(): 103 | running_loss=0. 104 | test_loss = 0. 105 | t_out_pre = torch.Tensor().to(device) 106 | t_labelall_surv_type = torch.Tensor().to(device) 107 | t_labelall_time = torch.Tensor().to(device) 108 | 109 | atten_dict = {} 110 | total_loss_for_test = 0 111 | # batch_id = 0 112 | test_num = len(test_slide) 113 | # batch divide 114 | iter_num = test_num // batch_size + 1 115 | 116 | # batch 117 | for batch_iter in range(iter_num): 118 | outs = torch.Tensor().to(device) 119 | labels_surv_type = torch.Tensor().to(device) 120 | labels_time = torch.Tensor().to(device) 121 | 122 | if (batch_iter + 1) == iter_num: # last batch 123 | for sample_index in range(batch_iter * batch_size, test_num): 124 | # print("Sample Index: ", sample_index) 125 | 126 | slidename = test_slide[sample_index] 127 | pyg_data = torch.load(os.path.join(args.pyg_path,slidename+'.pt')).to(device) 128 | print(pyg_data) 129 | label_surv_type = torch.tensor(test_label_surv_type[sample_index]).unsqueeze(0).to(device) 130 | label_time = torch.tensor(test_label_time[sample_index]).unsqueeze(0).to(device) 131 | 132 | cluster_info = torch.load(os.path.join(args.cluster_info_path,slidename+'.pth')) 133 | # model train 134 | output = model(pyg_data,cluster_info) 135 | atten = model.get_attention_weights() 136 | average_atten = torch.mean(atten[0],dim=1) #record attention score 137 | atten_dict[label_time[0]] = average_atten[0] 138 | # save output and label to calculate loss 139 | if outs.shape[0] == 0: 140 | outs = output 141 | else: 142 | outs = torch.cat((outs, output), dim=0) 143 | 144 | if labels_surv_type.shape[0] == 0: 145 | labels_surv_type = label_surv_type 146 | else: 147 | labels_surv_type = torch.cat((labels_surv_type, label_surv_type), dim=0) 148 | 149 | if labels_time.shape[0] == 0: 150 | labels_time = label_time 151 | else: 152 | labels_time = torch.cat((labels_time, label_time),dim=0) 153 | 154 | # save all samples results 155 | 156 
| if t_out_pre.shape[0] == 0: 157 | t_out_pre = -1 * output 158 | else: 159 | t_out_pre = torch.cat((t_out_pre, -1 * output),dim=0) 160 | 161 | if t_labelall_surv_type.shape[0] == 0: 162 | t_labelall_surv_type = label_surv_type 163 | else: 164 | t_labelall_surv_type = torch.cat((t_labelall_surv_type, label_surv_type), dim=0) 165 | 166 | if t_labelall_time.shape[0] == 0: 167 | t_labelall_time = label_time 168 | else: 169 | t_labelall_time = torch.cat((t_labelall_time, label_time), dim=0) 170 | # calculate loss of the batch 171 | if torch.sum(labels_surv_type) > 0.0: 172 | print("outs.shape:",outs.shape,"labels_time.shape:",labels_time.shape) 173 | loss = _neg_partial_log(outs,labels_time,labels_surv_type) 174 | loss = args.cox_loss * loss 175 | 176 | total_loss_for_test += loss.item() 177 | print("Batch Avg Loss: {:.4f}", loss.item()) 178 | 179 | else: 180 | # batch 181 | for batch_index in range(batch_size): 182 | index = batch_iter * batch_size + batch_index 183 | # print("Sample Index: ", index) 184 | 185 | # acquire dat and label 186 | slidename = test_slide[index] 187 | pyg_data = torch.load(os.path.join(args.pyg_path,slidename+'.pt')).to(device) 188 | print(pyg_data) 189 | label_surv_type = torch.tensor(test_label_surv_type[index]).unsqueeze(0).to(device) 190 | label_time = torch.tensor(test_label_time[index]).unsqueeze(0).to(device) 191 | 192 | cluster_info = torch.load(os.path.join(args.cluster_info_path,slidename+'.pth')) 193 | # model train 194 | output = model(pyg_data,cluster_info) 195 | atten = model.get_attention_weights() 196 | average_atten = torch.mean(atten[0],dim=1) 197 | atten_dict[label_time[0]] = average_atten[0] 198 | # save the output and label 199 | if outs.shape[0] == 0: 200 | outs = output 201 | else: 202 | outs = torch.cat((outs, output), dim=0) 203 | 204 | if labels_surv_type.shape[0] == 0: 205 | labels_surv_type = label_surv_type 206 | else: 207 | labels_surv_type = torch.cat((labels_surv_type, label_surv_type), dim=0) 208 | 209 | if labels_time.shape[0] == 0: 210 | labels_time = label_time 211 | else: 212 | labels_time = torch.cat((labels_time, label_time),dim=0) 213 | 214 | # save all samples results 215 | 216 | if t_out_pre.shape[0] == 0: 217 | t_out_pre = -1 * output 218 | else: 219 | t_out_pre = torch.cat((t_out_pre, -1 * output),dim=0) 220 | 221 | if t_labelall_surv_type.shape[0] == 0: 222 | t_labelall_surv_type = label_surv_type 223 | else: 224 | t_labelall_surv_type = torch.cat((t_labelall_surv_type, label_surv_type), dim=0) 225 | 226 | if t_labelall_time.shape[0] == 0: 227 | t_labelall_time = label_time 228 | else: 229 | t_labelall_time = torch.cat((t_labelall_time, label_time), dim=0) 230 | #compute the loss 231 | if torch.sum(labels_surv_type) > 0.0: 232 | print("outs.shape:",outs.shape,"labels_time.shape:",labels_time.shape) 233 | 234 | 235 | loss = _neg_partial_log(outs,labels_time,labels_surv_type) 236 | loss = args.cox_loss * loss 237 | 238 | total_loss_for_test += loss.item() 239 | print("Batch Avg Loss: {:.4f}", loss.item()) 240 | 241 | c_idx_epochs_avg = ci(t_labelall_time.data.cpu(),t_out_pre.data.cpu(),t_labelall_surv_type.data.cpu()) 242 | epoch_loss_test = total_loss_for_test / iter_num 243 | print("epoch_val_test_loss:{}".format(epoch_loss_test)) 244 | 245 | return c_idx_epochs_avg,epoch_loss_test,atten_dict 246 | 247 | 248 | def 
train_model(n_epochs,model,optimizer,scheduler,train_slide,val_slide,test_slide,train_label_censorship,train_label_survtime,val_label_censorship,val_label_survtime,test_label_censorship,test_label_survtime,fold_num,batch_size,cuda=True): 249 | if cuda: 250 | model = model.to(device) 251 | 252 | os.makedirs("/data14/yanhe/miccai/train/saved_model/tcga_kirc/interpreatable_transformer_depth_"+str(args.depth)+"/seed_"+str(args.seed)+"/",exist_ok=True) 253 | os.makedirs("/data14/yanhe/miccai/train/log_result/tcga_kirc/interpreatable_transformer_depth_"+str(args.depth)+"/seed_"+str(args.seed),exist_ok=True) 254 | os.makedirs("/data14/yanhe/miccai/train/attention_weights/tcga_kirc/interpreatable_transformer_depth_"+str(args.depth)+"/seed_"+str(args.seed)+'/atten_dict/',exist_ok=True) 255 | 256 | best_loss = 1e9 257 | best_ci = 0. 258 | best_acc = 0. 259 | n_out_features = 1 260 | n_classes = [1]*n_out_features 261 | for epoch in range(n_epochs): 262 | model.train() 263 | # pbar.set_description("RepeatNum:{}/{} Seed:{} Fold:{}/{}".format(repeat_num_temp + 1,repeat_num,seed,fold_num,all_fold_num)) 264 | total_loss_for_train = 0 265 | batch_id = 0 266 | out_pre = torch.Tensor().to(device) 267 | pre_for_batch = torch.Tensor().to(device) 268 | labelall_surv_type = torch.Tensor().to(device) 269 | label_surv_type_for_batch = torch.Tensor().to(device) 270 | label_time_for_batch = torch.Tensor().to(device) 271 | # print(labelall_surv_type.shape[0]) 272 | labelall_time = torch.Tensor().to(device) 273 | train_num = len(train_slide) 274 | print("train_slide Num: ",len(train_slide)) 275 | 276 | 277 | iter_num = train_num // batch_size + 1 278 | 279 | for batch_iter in range(iter_num): 280 | outs = torch.Tensor().to(device) 281 | labels_surv_type = torch.Tensor().to(device) 282 | labels_time = torch.Tensor().to(device) 283 | 284 | if (batch_iter + 1) == iter_num: 285 | for sample_index in range(batch_iter * batch_size, train_num): 286 | # print("Sample Index: ", sample_index) 287 | 288 | slidename = train_slide[sample_index] 289 | pyg_data = torch.load(os.path.join(args.pyg_path,slidename+'.pt')).to(device) 290 | print(pyg_data) 291 | label_surv_type = torch.tensor(train_label_censorship[sample_index]).unsqueeze(0).to(device) 292 | label_time = torch.tensor(train_label_survtime[sample_index]).unsqueeze(0).to(device) 293 | 294 | cluster_info = torch.load(os.path.join(args.cluster_info_path,slidename+'.pth')) 295 | 296 | output = model(pyg_data,cluster_info) 297 | 298 | if outs.shape[0] == 0: 299 | outs = output 300 | else: 301 | outs = torch.cat((outs, output), dim=0) 302 | 303 | if labels_surv_type.shape[0] == 0: 304 | labels_surv_type = label_surv_type 305 | else: 306 | labels_surv_type = torch.cat((labels_surv_type, label_surv_type), dim=0) 307 | 308 | if labels_time.shape[0] == 0: 309 | labels_time = label_time 310 | else: 311 | labels_time = torch.cat((labels_time, label_time),dim=0) 312 | 313 | 314 | if out_pre.shape[0] == 0: 315 | out_pre = -1 * output 316 | else: 317 | out_pre = torch.cat((out_pre, -1 * output),dim=0) 318 | 319 | if labelall_surv_type.shape[0] == 0: 320 | labelall_surv_type = label_surv_type 321 | else: 322 | labelall_surv_type = torch.cat((labelall_surv_type, label_surv_type), dim=0) 323 | 324 | if labelall_time.shape[0] == 0: 325 | labelall_time = label_time 326 | else: 327 | labelall_time = torch.cat((labelall_time, label_time), dim=0) 328 | 329 | if torch.sum(labels_surv_type) > 0.0: 330 | # print("outs.shape:",outs.shape,"labels_time.shape:",labels_time.shape) 331 | 332 | 333 | 
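# `_neg_partial_log` is the negative Cox partial log-likelihood:
#   L = -mean_i [ E_i * ( theta_i - log sum_{j: T_j >= T_i} exp(theta_j) ) ]
# where theta is the predicted risk score, T the survival time and E the
# event-status label (stored under the `censorship` key); samples with E_i = 0
# contribute nothing, which is why the batch is only used when
# torch.sum(labels_surv_type) > 0 (the guard above). `args.cox_loss` rescales it.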
loss = _neg_partial_log(outs,labels_time,labels_surv_type) 334 | loss = args.cox_loss * loss 335 | optimizer.zero_grad() 336 | loss.backward() 337 | optimizer.step() 338 | total_loss_for_train += loss.item() 339 | # print("Batch Avg Loss: {:.4f}", loss.item()) 340 | 341 | else: 342 | 343 | for batch_index in range(batch_size): 344 | index = batch_iter * batch_size + batch_index 345 | # print("Sample Index: ", index) 346 | 347 | slidename = train_slide[index] 348 | pyg_data = torch.load(os.path.join(args.pyg_path,slidename+'.pt')).to(device) 349 | print(pyg_data) 350 | label_surv_type = torch.tensor(train_label_censorship[index]).unsqueeze(0).to(device) 351 | label_time = torch.tensor(train_label_survtime[index]).unsqueeze(0).to(device) 352 | 353 | cluster_info = torch.load(os.path.join(args.cluster_info_path,slidename+'.pth')) 354 | 355 | output = model(pyg_data,cluster_info) 356 | 357 | if outs.shape[0] == 0: 358 | outs = output 359 | else: 360 | outs = torch.cat((outs, output), dim=0) 361 | 362 | if labels_surv_type.shape[0] == 0: 363 | labels_surv_type = label_surv_type 364 | else: 365 | labels_surv_type = torch.cat((labels_surv_type, label_surv_type), dim=0) 366 | 367 | if labels_time.shape[0] == 0: 368 | labels_time = label_time 369 | else: 370 | labels_time = torch.cat((labels_time, label_time),dim=0) 371 | 372 | 373 | if out_pre.shape[0] == 0: 374 | out_pre = -1 * output 375 | else: 376 | out_pre = torch.cat((out_pre, -1 * output),dim=0) 377 | 378 | if labelall_surv_type.shape[0] == 0: 379 | labelall_surv_type = label_surv_type 380 | else: 381 | labelall_surv_type = torch.cat((labelall_surv_type, label_surv_type), dim=0) 382 | 383 | if labelall_time.shape[0] == 0: 384 | labelall_time = label_time 385 | else: 386 | labelall_time = torch.cat((labelall_time, label_time), dim=0) 387 | 388 | if torch.sum(labels_surv_type) > 0.0: 389 | # print("outs.shape:",outs.shape,"labels_time.shape:",labels_time.shape) 390 | 391 | 392 | loss = _neg_partial_log(outs,labels_time,labels_surv_type) 393 | loss = args.cox_loss * loss 394 | optimizer.zero_grad() 395 | loss.backward() 396 | optimizer.step() 397 | total_loss_for_train += loss.item() 398 | # print("Batch Avg Loss: {:.4f}", loss.item()) 399 | 400 | 401 | 402 | c_idx_for_train = ci(labelall_time.data.cpu(),out_pre.data.cpu(),labelall_surv_type.data.cpu()) 403 | 404 | epoch_loss = total_loss_for_train / iter_num 405 | 406 | print("Epoch [{}/{}], epoch_loss {:.4f}".format(epoch+1,n_epochs, epoch_loss)) 407 | 408 | #val 409 | val_c_idx,val_loss,_ = eval_model(model,val_slide,val_label_censorship,val_label_survtime,batch_size,cuda=cuda) 410 | 411 | #test 412 | test_c_idx,test_loss,_ =eval_model(model,test_slide,test_label_censorship,test_label_survtime,batch_size,cuda=cuda) 413 | # if scheduler is not None: 414 | # scheduler.step() 415 | 416 | with open('/data14/yanhe/miccai/train/log_result/tcga_kirc/interpreatable_transformer_depth_'+str(args.depth)+"/seed_"+str(args.seed)+'/'+args.label+'_random.log',"a") as f: 417 | f.write(f"EPOCH {epoch} : \n") 418 | f.write(f"train loss - {epoch_loss} train ci - {c_idx_for_train};\n") 419 | f.write(f"val loss - {val_loss} val ci -{val_c_idx};\n") 420 | f.write(f"test loss - {test_loss}test ci - {test_c_idx};\n") 421 | if val_c_idx >= best_ci: 422 | best_epoch = epoch 423 | best_ci = val_c_idx 424 | t_model = copy.deepcopy(model) 425 | # if val_loss < best_loss: 426 | # best_loss = val_loss 427 | # best_epoch = epoch 428 | # t_model = copy.deepcopy(model) 429 | # t_model = copy.deepcopy(model) 430 | save_path 
= '/data14/yanhe/miccai/train/saved_model/tcga_kirc/interpreatable_transformer_depth_'+str(args.depth)+"/seed_"+str(args.seed)+'/fold_num_{}.pth'.format(fold_num) 431 | torch.save({'model_state_dict': t_model.state_dict(), 432 | 'optimizer_state_dict': optimizer.state_dict(), 433 | }, save_path) 434 | print("Model saved: %s" % save_path) 435 | 436 | t_test_c_idx,t_test_loss,atten_dict = eval_model(t_model,test_slide,test_label_censorship,test_label_survtime,batch_size,cuda=cuda) 437 | 438 | atten_dict_saved_path = '/data14/yanhe/miccai/train/attention_weights/tcga_kirc/interpreatable_transformer_depth_'+str(args.depth)+"/seed_"+str(args.seed)+'/atten_dict/'+args.label+'_foldnum_{}.pth'.format(fold_num) 439 | torch.save(atten_dict,atten_dict_saved_path) 440 | with open('/data14/yanhe/miccai/train/log_result/tcga_kirc/interpreatable_transformer_depth_'+str(args.depth)+"/seed_"+str(args.seed)+'/'+args.label+'_random.log',"a") as f: 441 | f.write(f"best model test ci value {t_test_c_idx} occurs at EPOCH {best_epoch} ;\n") 442 | # f.write(f"best model test ci value {t_test_c_idx} ;\n") 443 | 444 | 445 | def setup_seed(seed): 446 | torch.manual_seed(seed) 447 | os.environ['PYTHONHASHSEED'] = str(seed) 448 | torch.cuda.manual_seed(seed) 449 | torch.cuda.manual_seed_all(seed) 450 | np.random.seed(seed) 451 | random.seed(seed) 452 | torch.backends.cudnn.benchmark = False 453 | torch.backends.cudnn.deterministic = True 454 | torch.backends.cudnn.enabled = True 455 | 456 | 457 | def main(args): 458 | fold_num = args.fold_num 459 | label_path = args.label_path 460 | n_epochs = args.epochs 461 | batch_size = args.batch_size 462 | label = torch.load(label_path) 463 | split_dict_path = args.split_dict_path 464 | seed = args.seed 465 | pyg_path = args.pyg_path 466 | print("seed:{}".format(seed)) 467 | #splits 468 | # split_dict=torch.load('/data14/yanhe/miccai/data/tcga_lihc/train_val_test_split.pkl') 469 | split_dict = joblib.load(split_dict_path) 470 | fold_num = args.fold_num 471 | fold_train = 'fold_'+str(fold_num)+'_train' 472 | fold_val = 'fold_'+str(fold_num)+'_val' 473 | fold_test = 'fold_'+str(fold_num)+'_test' 474 | train_slide = split_dict[fold_train] 475 | val_slide = split_dict[fold_val] 476 | test_slide = split_dict[fold_test] 477 | setup_seed(seed) 478 | 479 | model = Superpixel_Vit(in_feats_intra=args.in_feats_intra, 480 | n_hidden_intra=args.n_hidden_intra, 481 | out_feats_intra=args.out_feats_intra, 482 | in_feats_inter=args.in_feats_inter, 483 | n_hidden_inter=args.n_hidden_inter, 484 | out_feats_inter=args.out_feats_inter, 485 | vw_num=args.vw_num, 486 | feat_dim=args.feat_dim, 487 | num_classes=1, 488 | depth=args.depth, 489 | num_heads = args.num_heads, 490 | final_fea_type = args.final_fea_type, 491 | mpool_intra=args.mpool_intra, 492 | mpool_inter=args.mpool_inter, 493 | gnn_intra=args.gnn_intra, 494 | gnn_inter=args.gnn_inter 495 | ) 496 | 497 | optimizer = torch.optim.AdamW([dict(params=model.parameters(), lr=args.lr, betas=(0.9, 0.95),weight_decay = args.l2_reg_alpha),]) 498 | t_max = n_epochs 499 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=t_max, eta_min=0. 
) 500 | 501 | 502 | Train_slide,Val_slide,Test_slide,Train_Labels_censorship,Train_Labels_survtime,Val_Labels_censorship,Val_Labels_survtime,Test_Labels_censorship,Test_Labels_survtime = get_train_val(train_slide,val_slide,test_slide,label_path) 503 | train_model(n_epochs, model, optimizer, scheduler, Train_slide,Val_slide,Test_slide,Train_Labels_censorship,Train_Labels_survtime,Val_Labels_censorship,Val_Labels_survtime,Test_Labels_censorship,Test_Labels_survtime,fold_num,batch_size,cuda=True) 504 | 505 | 506 | 507 | 508 | 509 | def get_params(): 510 | parser = argparse.ArgumentParser(description='model training') 511 | 512 | parser.add_argument('--label_path',type=str, default='/data12/yanhe/miccai/data/tcga_kirc/slide_label.pt') 513 | parser.add_argument('--split_dict_path',type=str, default='/data12/yanhe/miccai/data/tcga_kirc/train_val_test_split_random1.pkl') 514 | parser.add_argument('--pyg_path',type=str, default='/data14/yanhe/miccai/graph_file/tcga_kirc/superpixel_num_600') 515 | parser.add_argument('--cluster_info_path',type=str, default='/data14/yanhe/miccai/codebook/cluster_info/tcga_kirc/superpixel600_cluster16/all_fold' ) 516 | parser.add_argument('--vw_num',type=int, default=16) 517 | parser.add_argument('--feat_dim',type =int,default=1024) 518 | parser.add_argument('--depth',type=int,default=1) 519 | parser.add_argument('--num_heads',type=int,default=4) 520 | parser.add_argument('--mpool_intra',type=str,default='global_mean_pool') #‘global_mean_pool’,'global_max_pool','global_att_pool' 521 | parser.add_argument('--mpool_inter',type=str,default='global_mean_pool') 522 | parser.add_argument('--gnn_intra',type=str,default='sage') #'sage''gcn''gat''leconv''graphconv' 523 | parser.add_argument('--gnn_inter',type=str,default='sage') 524 | parser.add_argument('--in_feats_intra',type=int, default=1024) 525 | parser.add_argument('--n_hidden_intra',type=int, default=1024) 526 | parser.add_argument('--out_feats_intra',type=int,default=1024) 527 | parser.add_argument('--in_feats_inter',type=int,default=1024) #in_feats_inter=out_feats_intra 528 | parser.add_argument('--n_hidden_inter',type=int, default=1024) 529 | parser.add_argument('--out_feats_inter',type=int,default=1024) #out_feats_inter=feat_dim 530 | parser.add_argument('--final_fea_type',type=str,default='mean') 531 | parser.add_argument('--epochs', type=int, default=30) 532 | parser.add_argument('--batch_size',type=int,default=16) 533 | #parser.add_argument('--warmup_epochs', type=int, default= 40) 534 | parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate of model training") 535 | parser.add_argument("--l2_reg_alpha",type=float,default=0.001) 536 | parser.add_argument("--cox_loss",type=float,default=12) 537 | parser.add_argument("--seed",type=int, default=1) 538 | parser.add_argument("--fold_num",type=int, default=0) 539 | parser.add_argument("--label",type=str,default="tcga_lihc_fold_0_lr1e-5_30epoch") 540 | 541 | args, _ = parser.parse_known_args() 542 | return args 543 | 544 | 545 | if __name__ == '__main__': 546 | try: 547 | # tuner_params = nni.get_next_parameter() 548 | # logger.debug(tuner_params) 549 | # params = vars(merge_parameter(get_params(), tuner_params)) 550 | # main(params) 551 | args=get_params() 552 | main(args) 553 | except Exception as exception: 554 | # logger.exception(exception) 555 | raise 556 | 557 | --------------------------------------------------------------------------------
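For orientation, below is a minimal, illustrative inference sketch (not part of the repository) showing how `Superpixel_Vit`, a saved checkpoint, one hierarchical graph and its superpixel-to-cluster assignment fit together. All file paths are placeholders, and the hyperparameters simply mirror the argparse defaults in `train.py`.

```python
# Illustrative sketch only: paths are placeholders; hyperparameters mirror
# the defaults in train.py. Run from inside the train/ directory.
import torch
from superpixel_transformer_n import Superpixel_Vit

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Superpixel_Vit(vw_num=16, feat_dim=1024, num_classes=1,
                       depth=1, num_heads=4,
                       mpool_intra='global_mean_pool',
                       mpool_inter='global_mean_pool',
                       gnn_intra='sage', gnn_inter='sage').to(device)

# Restore a trained fold; train.py saves {'model_state_dict', 'optimizer_state_dict'}.
ckpt = torch.load('saved_model/fold_num_0.pth', map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

# One hierarchical graph (PyG Data) and its superpixel-to-cluster assignment.
pyg_data = torch.load('PYG_Data/Dataset/pyg_data_1.pt').to(device)
cluster_info = torch.load('Cluster_Info/Dataset/cluster_info_1.pt')

with torch.no_grad():
    risk = model(pyg_data, cluster_info)   # sigmoid risk score, shape [1, 1]
    attn = model.get_attention_weights()   # per-block attention maps for interpretation
print(float(risk))
```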