├── README.md
├── cluster.py
├── data_preprocess
│   ├── generate_superpixel.py
│   ├── graph_construction.ipynb
│   └── superpixel_utils.py
├── data_splits
│   ├── train_val_test_split_argo.pkl
│   ├── train_val_test_split_kirc.pkl
│   ├── train_val_test_split_lihc.pkl
│   └── useless
├── images
│   ├── miccai_framework.png
│   └── try
├── label
│   ├── argo_label.pt
│   ├── kirc_label.pt
│   ├── lihc_label.pt
│   └── useless
├── requirements.txt
└── train
    ├── block_utils.py
    ├── run.sh
    ├── superpixel_transformer_n.py
    └── train.py
/README.md:
--------------------------------------------------------------------------------
1 | ## Multi-scope Analysis Driven Hierarchical Graph Transformer for Whole Slide Image based Cancer Survival Prediction
2 |
3 |

4 |
5 | ## Installation
6 | Clone the repo:
7 | ```bash
8 | git clone https://github.com/Baeksweety/HGTHGT && cd HGTHGT
9 | ```
10 | Create a conda environment and activate it:
11 | ```bash
12 | conda create -n env python=3.8
13 | conda activate env
14 | pip install -r requirements.txt
15 | ```
16 |
17 | ## Data Preprocess
18 | ***generate_superpixel.py*** shows how to generate merged superpixels for whole slide images, and ***graph_construction.ipynb*** shows how to transform a histological image into hierarchical graphs. After data processing is completed, put all hierarchical graphs into one folder, organized as follows:
19 | ```bash
20 | PYG_Data
21 | └── Dataset
22 | ├── pyg_data_1.pt
23 | ├── pyg_data_2.pt
24 | :
25 | └── pyg_data_n.pt
26 | ```
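As a quick check, each hierarchical graph can be loaded back with PyTorch. The snippet below is a minimal sketch: the attribute names follow the graph construction in ***graph_construction.ipynb***, and the path is only a placeholder.
```python
import torch

# placeholder path: any graph produced by graph_construction.ipynb
graph = torch.load('PYG_Data/Dataset/pyg_data_1.pt')

print(graph.x.shape)                 # patch features, (num_patches, feat_dim)
print(graph.centroid.shape)          # patch coordinates, (num_patches, 2)
print(graph.superpixel_attri.shape)  # superpixel id of every patch
print(graph.edge_patch.shape)        # patch-level edges kept inside a superpixel, (2, E_patch)
print(graph.edge_superpixel.shape)   # superpixel-level RAG edges, (2, E_superpixel)
```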
27 |
28 |
29 | ## Cluster
30 | ***cluster.py*** shows how to generate a fixed number of patch clusters that are used during training. The resulting files are organized as follows:
31 | ```bash
32 | Cluster_Info
33 | └── Dataset
34 | ├── cluster_info_1.pt
35 | ├── cluster_info_2.pt
36 | :
37 | └── cluster_info_n.pt
38 | ```
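Each cluster file stores, for every patch of a slide, the cluster it was assigned to. The following is a minimal sketch of how such a file can be inspected; the path is a placeholder and the keys follow ***cluster.py***.
```python
import torch

# placeholder path: any file produced by cluster.py
cluster_info = torch.load('Cluster_Info/Dataset/cluster_info_1.pt')

# one dict per patch, e.g. {'patch_id': 0, 'cluster': 7}
print(len(cluster_info))
print(cluster_info[0]['patch_id'], cluster_info[0]['cluster'])
```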
39 |
40 |
41 |
42 | ## Training
43 | First, set the data path, data splits, and hyperparameters in ***train.py***. Then, experiments can be run with the following commands:
44 | ```bash
45 | cd train
46 | python train.py
47 | # or
48 | bash run.sh
49 | ```
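The split files under ***data_splits*** and the label files under ***label*** are read inside ***train.py***; their exact layout is defined there. Purely as an illustration, they can be opened and inspected as follows (a sketch, not part of the training pipeline):
```python
import pickle
import torch

# inspect a split file; how folds and train/val/test lists are organized is defined in train.py
with open('data_splits/train_val_test_split_kirc.pkl', 'rb') as f:
    splits = pickle.load(f)
print(type(splits))

# inspect the corresponding label file
labels = torch.load('label/kirc_label.pt')
print(type(labels))
```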
50 |
51 | ## Saved models
52 | We provide 5-fold checkpoints for each dataset, which perform as follows (CI = concordance index):
53 | | Dataset | CI |
54 | | ----- |:--------:|
55 | | CRC | 0.607 |
56 | | TCGA_LIHC | 0.657 |
57 | | TCGA_KIRC | 0.646 |
58 |
59 |
60 |
61 |
62 |
63 | ## More Info
64 | - Our implementation refers to the following publicly available code:
65 | - [Pytorch Geometric](https://github.com/pyg-team/pytorch_geometric)--Fey M, Lenssen J E. Fast graph representation learning with PyTorch Geometric[J]. arXiv preprint arXiv:1903.02428, 2019.
66 | - [Histocartography](https://github.com/histocartography/histocartography)--Jaume G, Pati P, Anklin V, et al. HistoCartography: A toolkit for graph analytics in digital pathology[C]//MICCAI Workshop on Computational Pathology. PMLR, 2021: 117-128.
67 | - [ViT Pytorch](https://github.com/lukemelas/PyTorch-Pretrained-ViT)--Dosovitskiy A, Beyer L, Kolesnikov A, et al. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale[C]//International Conference on Learning Representations. 2020.
68 | - [NAGCN](https://github.com/YohnGuan/NAGCN)--Guan Y, Zhang J, Tian K, et al. Node-aligned graph convolutional network for whole-slide image representation and classification[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022: 18813-18823.
69 |
--------------------------------------------------------------------------------
/cluster.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import argparse
4 |
5 | # 3rd party library
6 | import h5py
7 | import numpy as np
8 | import pandas as pds
9 | import torch
10 | from tqdm import tqdm
11 | from sklearn.cluster import MiniBatchKMeans, KMeans
15 |
16 | def get_cluster(datas, k=100, num=-1, seed=3):  # fit MiniBatchKMeans (k clusters) on the concatenated patch features of the sampled slides
17 | if seed is not None:
18 | random.seed(seed)
19 | if num > 0:
20 | data_sam = random.sample(datas, num)
21 | else:
22 | data_sam = datas
23 | if seed is not None:
24 | random.seed(seed)
25 | random.shuffle(data_sam)
26 |
27 | dim_N = 0 #record the number of all patches
28 | dim_D = data_sam[0]['shape'][1]
29 | for data in tqdm(data_sam):
30 | dim_N += data['shape'][0]
31 | con_data = np.zeros((dim_N, dim_D), dtype=np.float32)
32 | ind = 0
33 | for data in tqdm(data_sam):
34 | data_path, data_shape = data['slide'], data['shape']
35 | cur_data = torch.load(data_path)
36 | con_data[ind:ind + data_shape[0], :] = cur_data.numpy()
37 | ind += data_shape[0]
38 | # clusterer = KMeans(n_clusters=k)
39 | clusterer = MiniBatchKMeans(n_clusters=k, batch_size=10000)
40 | clusterer.fit(con_data)
41 | print("cluster done")
42 | return clusterer
43 |
44 | for fold_num in range(5):
45 | cluster_data = torch.load('/data14/yanhe/miccai/codebook/data/argo/codebook_info_fold{}.pt'.format(fold_num))
46 | # print(len(cluster_data))
47 | train_data = []
48 | for data in cluster_data:
49 | train_data.append(data)
50 | print(len(train_data))
51 | clusterer = get_cluster(train_data, k=16, num=-1, seed=3) #k represents the number of clusters
52 | saved_path='/data14/yanhe/miccai/codebook/patch_cluster/argo/fold{}/'.format(fold_num)
53 | os.makedirs(saved_path,exist_ok=True)
54 | for slidename in os.listdir('/data12/ybj/survival/argo_selected/20x/slides_feat/h5_files'):
55 | if slidename[-3:] == '.h5':
56 | slide_path=os.path.join('/data12/ybj/survival/argo_selected/20x/slides_feat/h5_files',slidename)
57 | # print(slide_path)
58 | f = h5py.File(slide_path)
59 | name, _ = os.path.splitext(slidename)
60 | print(name)
61 | slide_cluster_path=saved_path+name+'.pth'
62 | # print(slide_cluster_path)
63 | length = f['coords'].shape[0]
64 | wsi_patch=[]
65 | for i in range(length):
66 | wsi_patch_info = {}
67 | patch_fea = f['features'][i].reshape(1,-1)
68 | cluster_class = clusterer.predict(patch_fea)[0]
69 | wsi_patch_info['patch_id']=i
70 | wsi_patch_info['cluster']=cluster_class
71 | wsi_patch.append(wsi_patch_info)
72 | torch.save(wsi_patch,slide_cluster_path)
73 |
--------------------------------------------------------------------------------
/data_preprocess/generate_superpixel.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "2"
3 | import torch
4 | import torch.nn as nn
5 | from tqdm import tqdm
6 | import numpy as np
7 | import pandas as pd
8 | import h5py
9 | from sklearn.cluster import MiniBatchKMeans, KMeans
10 | import random
11 | from PIL import Image
12 | from dgl.data.utils import save_graphs
13 |
14 | from histocartography.utils import download_example_data
15 | from histocartography.preprocessing import (
16 | ColorMergedSuperpixelExtractor,
17 | DeepFeatureExtractor
18 | )
19 |
20 | from histocartography.visualization import OverlayGraphVisualization
21 | from superpixel_utils import RAGGraphBuilder,MyColorMergedSuperpixelExtractor
22 | import argparse
23 | from skimage.measure import regionprops
24 | import joblib
25 | import cv2
26 |
27 |
28 | #remove background
29 | def get_node_centroids(instance_map: np.ndarray,raw_image:np.ndarray):
30 | # print(instance_map)
31 | regions = regionprops(instance_map)
32 | mask_value=1
33 | # centroids = np.empty((len(regions), 2))
34 | for i, region in enumerate(regions):
35 | center_y, center_x = region.centroid # (y, x)
36 | # print(i)
37 | # print(region.coords)
38 | center_x = int(round(center_x))
39 | center_y = int(round(center_y))
40 | if sum(raw_image[center_y,center_x])== 0:
41 | for index,couple in enumerate(region.coords):
42 | y,x = couple
43 | # print(y,x)
44 | instance_map[y,x]=0
45 | else:
46 | for index,couple in enumerate(region.coords):
47 | y,x = couple
48 | # print(y,x)
49 | instance_map[y,x]=mask_value
50 | mask_value+=1
51 |
52 | # centroids[i, 0] = center_x
53 | # centroids[i, 1] = center_y
54 | # print(instance_map)
55 | return instance_map
56 |
57 | def generate_superpixel(image_path,downsampling_factor):
58 | """
59 |     Generate merged superpixels for a single stitched whole slide image.
60 | """
61 |
62 | # 1. get image path
63 | # image_fnames = glob(os.path.join(image_path, '*.png'))
64 | # image_fnames = img_path
65 | # 2. define superpixel extractor. Here, we query 50 SLIC superpixels,
66 | # but a superpixel size (in #pixels) can be provided as well in the case
67 | # where image size vary from one sample to another.
68 | superpixel_detector = ColorMergedSuperpixelExtractor(
69 | nr_superpixels=args.nr_superpixels,
70 | compactness=10,
71 | blur_kernel_size=1,
72 | threshold=1,
73 | downsampling_factor=downsampling_factor,
74 | connectivity = 2,
75 | )
76 |
77 | # 3. define feature extractor: extract patches of 144x144 pixels
78 | # resized to 224 to match resnet input size. If the superpixel is larger
79 | # than 144x144, several patches are extracted and patch embeddings are averaged.
80 | # Everything is handled internally. Please refer to the implementation for
81 | # details.
82 | feature_extractor = DeepFeatureExtractor(
83 | architecture='resnet34',
84 | patch_size=144,
85 | resize_size=224
86 | )
87 |
88 | # 4. define graph builder
89 | tissue_graph_builder = RAGGraphBuilder(add_loc_feats=True)
90 |
91 | # 5. define graph visualizer
92 | visualizer = OverlayGraphVisualization()
93 |
94 | _, image_name = os.path.split(image_path)
95 | image = Image.open(image_path)
96 | image=np.array(image)
97 | fname = image_name[:-4]+'.png'
98 | print(image.shape)
99 |
100 | superpixels, _ = superpixel_detector.process(image)
101 | superpixels = get_node_centroids(superpixels,image)
102 | return superpixels
103 |
104 |
105 | def modify_superpixels(superpixel_patchnum,instance_map):
106 | regions = regionprops(instance_map)
107 | mask_value=1
108 | for i, region in enumerate(regions):
109 | sup_value = region.label
110 | if (superpixel_patchnum[sup_value]==0) or (superpixel_patchnum[sup_value]==1):
111 | print('true!')
112 | for index,couple in enumerate(region.coords):
113 | y,x = couple
114 | instance_map[y,x]=0
115 | else:
116 | for index,couple in enumerate(region.coords):
117 | y,x = couple
118 | instance_map[y,x]=mask_value
119 | mask_value+=1
120 | return instance_map
121 |
122 |
123 |
124 | def get_max_value(data_matrix):
125 | new_data=[]
126 | for i in range(len(data_matrix)):
127 | new_data.append(max(data_matrix[i]))
128 | return max(new_data)
129 |
130 |
131 |
132 | def patch_judge_40x(stitch_path,saved_path):
133 | _, image_name = os.path.split(stitch_path)
134 | name = image_name[:-4]
135 | superpixels=generate_superpixel(stitch_path,downsampling_factor=4)
136 | print(superpixels.shape)
137 | h,w=superpixels.shape
138 | slide_info_path = '/data12/ybj/survival/CRC/40x/slides_feat/h5_files/'+name+'.h5'
139 | print(slide_info_path)
140 | f = h5py.File(slide_info_path)
141 | new_slide_info=[]
142 | new_slide_info_path=saved_path+name+'.pth'
143 | length = f['coords'].shape[0]
144 | for i in range(length):
145 | patch_info={}
146 | patch_info['patch_id']=i
147 | patch_info['coords']=f['coords'][i]
148 | patch_info['features']=f['features'][i]
149 | patch_coords=f['coords'][i]
150 | # x,y=slide_info[i]['coords']
151 | x,y=patch_coords
152 | x=x+512
153 | y=y+512
154 | x=int(x/16)
155 | y=int(y/16)
156 |         if (y<h)&(x<w):   # patch centre falls inside the superpixel map
157 |             patch_info['superpixel']=superpixels[y][x]
158 |         elif (y>=h)&(x<w):   # y overflows: pull the centre back inside along y
159 |             x_n,y_n=f['coords'][i]
160 |             x_n=x_n+512
161 |             y_n=y_n+(h*16-y_n)/2
162 |             x_n=int(x_n/16)
163 |             y_n=int(y_n/16)
164 |             patch_info['superpixel']=superpixels[y_n][x_n]
165 |         elif (y<h)&(x>=w):   # x overflows: pull the centre back inside along x
166 |             x_n,y_n=f['coords'][i]
167 |             x_n=x_n+(w*16-x_n)/2
168 |             y_n=y_n+512
169 |             x_n=int(x_n/16)
170 |             y_n=int(y_n/16)
171 |             patch_info['superpixel']=superpixels[y_n][x_n]
172 |         else:   # both coordinates overflow: pull both back inside
173 |             x_n,y_n=f['coords'][i]
174 |             x_n=x_n+(w*16-x_n)/2
175 |             y_n=y_n+(h*16-y_n)/2
176 |             x_n=int(x_n/16)
177 |             y_n=int(y_n/16)
178 |             patch_info['superpixel']=superpixels[y_n][x_n]
179 |         new_slide_info.append(patch_info)
180 | superpixel_patchnum={}
181 | max_superpixel = get_max_value(superpixels)
182 | # print(max(superpixel))
183 | print(max_superpixel)
184 | for sup_value in range(1,max_superpixel+1):
185 | # print(sup_value)
186 | coords_s=[]
187 | features_s=[]
188 | for patch_id in range(len(new_slide_info)):
189 | if new_slide_info[patch_id]['superpixel']==sup_value:
190 | coords_s.append(new_slide_info[patch_id]['coords'])
191 | # features_s.append(new_slide_info[patch_id]['features'])
192 | coords_np=np.array(coords_s)
193 | # features_np=np.array(features_s)
194 | # print(coords_np.shape,features_np.shape)
195 | patch_number=coords_np.shape[0]
196 | superpixel_patchnum[sup_value] = patch_number
197 | superpixels = modify_superpixels(superpixel_patchnum,superpixels)
198 | wsi_info = []
199 | for i in range(length):
200 | patch_info={}
201 | patch_info['patch_id']=i
202 | patch_info['coords']=f['coords'][i]
203 | patch_info['features']=f['features'][i]
204 | patch_coords=f['coords'][i]
205 | # x,y=slide_info[i]['coords']
206 | x,y=patch_coords
207 | x=x+512
208 | y=y+512
209 | x=int(x/16)
210 | y=int(y/16)
211 |         if (y<h)&(x<w):   # patch centre falls inside the superpixel map
212 |             patch_info['superpixel']=superpixels[y][x]
213 |         elif (y>=h)&(x<w):   # y overflows: pull the centre back inside along y
214 |             x_n,y_n=f['coords'][i]
215 |             x_n=x_n+512
216 |             y_n=y_n+(h*16-y_n)/2
217 |             x_n=int(x_n/16)
218 |             y_n=int(y_n/16)
219 |             patch_info['superpixel']=superpixels[y_n][x_n]
220 |         elif (y<h)&(x>=w):   # x overflows: pull the centre back inside along x
221 |             x_n,y_n=f['coords'][i]
222 |             x_n=x_n+(w*16-x_n)/2
223 |             y_n=y_n+512
224 |             x_n=int(x_n/16)
225 |             y_n=int(y_n/16)
226 |             patch_info['superpixel']=superpixels[y_n][x_n]
227 |         else:   # both coordinates overflow: pull both back inside
228 |             x_n,y_n=f['coords'][i]
229 |             x_n=x_n+(w*16-x_n)/2
230 |             y_n=y_n+(h*16-y_n)/2
231 |             x_n=int(x_n/16)
232 |             y_n=int(y_n/16)
233 |             patch_info['superpixel']=superpixels[y_n][x_n]
234 |         wsi_info.append(patch_info)
235 | torch.save(wsi_info,new_slide_info_path)
236 | return superpixels
237 |
238 |
239 | def patch_judge_20x(stitch_path,saved_path):
240 | _, image_name = os.path.split(stitch_path)
241 | name = image_name[:-4]
242 | superpixels=generate_superpixel(stitch_path,downsampling_factor=2)
243 | print(superpixels.shape)
244 | h,w=superpixels.shape
245 | slide_info_path = '/data12/ybj/survival/CRC/20x/slides_feat/h5_files/'+name+'.h5'
246 | f = h5py.File(slide_info_path)
247 | new_slide_info=[]
248 | new_slide_info_path=saved_path+name+'.pth'
249 | length = f['coords'].shape[0]
250 | for i in range(length):
251 | patch_info={}
252 | patch_info['patch_id']=i
253 | patch_info['coords']=f['coords'][i]
254 | patch_info['features']=f['features'][i]
255 | patch_coords=f['coords'][i]
256 | x,y=patch_coords
257 | x=x+256
258 | y=y+256
259 | x=int(x/16)
260 | y=int(y/16)
261 |         if (y<h)&(x<w):   # patch centre falls inside the superpixel map
262 |             patch_info['superpixel']=superpixels[y][x]
263 |         elif (y>=h)&(x<w):   # y overflows: pull the centre back inside along y
264 |             x_n,y_n=f['coords'][i]
265 |             x_n=x_n+256
266 |             y_n=y_n+(h*16-y_n)/2
267 |             x_n=int(x_n/16)
268 |             y_n=int(y_n/16)
269 |             patch_info['superpixel']=superpixels[y_n][x_n]
270 |         elif (y<h)&(x>=w):   # x overflows: pull the centre back inside along x
271 |             x_n,y_n=f['coords'][i]
272 |             x_n=x_n+(w*16-x_n)/2
273 |             y_n=y_n+256
274 |             x_n=int(x_n/16)
275 |             y_n=int(y_n/16)
276 |             patch_info['superpixel']=superpixels[y_n][x_n]
277 |         else:   # both coordinates overflow: pull both back inside
278 |             x_n,y_n=f['coords'][i]
279 |             x_n=x_n+(w*16-x_n)/2
280 |             y_n=y_n+(h*16-y_n)/2
281 |             x_n=int(x_n/16)
282 |             y_n=int(y_n/16)
283 |             patch_info['superpixel']=superpixels[y_n][x_n]
284 |         new_slide_info.append(patch_info)
285 | superpixel_patchnum={}
286 | max_superpixel = get_max_value(superpixels)
287 | # print(max(superpixel))
288 | print(max_superpixel)
289 | for sup_value in range(1,max_superpixel+1):
290 | # print(sup_value)
291 | coords_s=[]
292 | features_s=[]
293 | for patch_id in range(len(new_slide_info)):
294 | if new_slide_info[patch_id]['superpixel']==sup_value:
295 | coords_s.append(new_slide_info[patch_id]['coords'])
296 | coords_np=np.array(coords_s)
297 | patch_number=coords_np.shape[0]
298 | superpixel_patchnum[sup_value] = patch_number
299 | superpixels = modify_superpixels(superpixel_patchnum,superpixels)
300 | wsi_info = []
301 | for i in range(length):
302 | patch_info={}
303 | patch_info['patch_id']=i
304 | patch_info['coords']=f['coords'][i]
305 | patch_info['features']=f['features'][i]
306 | patch_coords=f['coords'][i]
307 | x,y=patch_coords
308 | x=x+256
309 | y=y+256
310 | x=int(x/16)
311 | y=int(y/16)
311 |         if (y<h)&(x<w):   # patch centre falls inside the superpixel map
312 |             patch_info['superpixel']=superpixels[y][x]
313 |         elif (y>=h)&(x<w):   # y overflows: pull the centre back inside along y
314 |             x_n,y_n=f['coords'][i]
315 |             x_n=x_n+256
316 |             y_n=y_n+(h*16-y_n)/2
317 |             x_n=int(x_n/16)
318 |             y_n=int(y_n/16)
319 |             patch_info['superpixel']=superpixels[y_n][x_n]
320 |
321 |         elif (y<h)&(x>=w):   # x overflows: pull the centre back inside along x
322 |             x_n,y_n=f['coords'][i]
323 |             x_n=x_n+(w*16-x_n)/2
324 |             y_n=y_n+256
325 |             x_n=int(x_n/16)
326 |             y_n=int(y_n/16)
327 |             patch_info['superpixel']=superpixels[y_n][x_n]
328 |         else:   # both coordinates overflow: pull both back inside
329 |             x_n,y_n=f['coords'][i]
330 |             x_n=x_n+(w*16-x_n)/2
331 |             y_n=y_n+(h*16-y_n)/2
332 |             x_n=int(x_n/16)
333 |             y_n=int(y_n/16)
334 |             patch_info['superpixel']=superpixels[y_n][x_n]
335 |         wsi_info.append(patch_info)
336 | torch.save(wsi_info,new_slide_info_path)
337 | return superpixels
338 |
339 | def generate_tissue_graph(slide_list_path,image_path,saved_path,graph_file_saved_path,vis_saved_path):
340 | """
341 | Generate a tissue graph for all the images in image path dir.
342 | """
343 | feature_extractor = DeepFeatureExtractor(
344 | architecture='resnet34',
345 | patch_size=144,
346 | resize_size=224
347 | )
348 |
349 | # 4. define graph builder
350 | tissue_graph_builder = RAGGraphBuilder(add_loc_feats=True)
351 |
352 | # 5. define graph visualizer
353 | visualizer = OverlayGraphVisualization()
354 | # print(fname)
355 | # b. extract superpixels
356 | slide_list = joblib.load(slide_list_path)
357 | # slide_x, _ = os.path.splitext(slide_list_path)
358 | _, slide_x = os.path.split(slide_list_path)
359 | # print(slide_list_path)
360 | # print(slide_x)
361 | slide_x = slide_x.split('_')[1]
362 | print(slide_x)
363 | # for slidename in os.listdir(image_path):
364 | for index in range(len(slide_list)):
365 | slidename = slide_list[index]
366 | name, _ = os.path.splitext(slidename)
367 | # print(name)
368 | stitch_name = name+'.jpg'
369 | stitch_path = os.path.join(image_path,stitch_name)
370 | print(stitch_path)
371 | if slide_x == '20x':
372 | superpixels = patch_judge_20x(stitch_path,saved_path)
373 | elif slide_x == '40x':
374 | superpixels = patch_judge_40x(stitch_path,saved_path)
375 | # torch.save(superpixels,os.path.join('/data13/yanhe/miccai/super_pixel/superpixel_array/tcga_lihc_200superpixel',image_name[:-4]+'.pt'))
376 | # print(image_name)
377 | print(superpixels.shape)
378 | # print(superpixels)
379 | # c. extract deep features
380 | image = Image.open(stitch_path)
381 | features = feature_extractor.process(image, superpixels)
382 | # print(features.shape)
383 | # d. build a Region Adjacency Graph (RAG)
384 | # graph = tissue_graph_builder.process(image, superpixels, features)
385 | graph = tissue_graph_builder.process(superpixels, features)
386 | # print(graph)
387 | # e. save the graph
388 | torch.save(graph,os.path.join(graph_file_saved_path,slidename[:-4]+'.pt'))
389 | # f. visualize and save the graph
390 | canvas = visualizer.process(image, graph, instance_map=superpixels)
391 | canvas.save(os.path.join(vis_saved_path,slidename[:-4]+'.png'))
392 |
393 | def main(args):
394 | saved_path = args.saved_path
395 | graph_file_saved_path = args.graph_file_saved_path
396 | vis_saved_path = args.vis_saved_path
397 | os.makedirs(saved_path,exist_ok=True)
398 | os.makedirs(graph_file_saved_path,exist_ok=True)
399 | os.makedirs(vis_saved_path,exist_ok=True)
400 | #20x
401 | slide_20x_path = args.slide_20x_path
402 | stitch_20x_path = args.stitch_20x_path
403 | generate_tissue_graph(slide_20x_path,stitch_20x_path,saved_path,graph_file_saved_path,vis_saved_path)
404 |
405 | #40x
406 | slide_40x_path = args.slide_40x_path
407 | stitch_40x_path = args.stitch_40x_path
408 | generate_tissue_graph(slide_40x_path,stitch_40x_path,saved_path,graph_file_saved_path,vis_saved_path)
409 |
410 |
411 |
412 | def get_params():
413 | parser = argparse.ArgumentParser(description='superpixel_generate')
414 |
415 | parser.add_argument('--slide_40x_path', type=str, default='/data12/yanhe/miccai/data/tcga_crc/slide_40x_list.pkl')
416 | parser.add_argument('--slide_20x_path', type=str, default='/data12/yanhe/miccai/data/tcga_crc/slide_20x_list.pkl')
417 | parser.add_argument('--stitch_20x_path', type=str, default='/data12/ybj/survival/CRC/20x/stitches')
418 | parser.add_argument('--stitch_40x_path', type=str, default='/data12/ybj/survival/CRC/40x/stitches')
419 | parser.add_argument('--saved_path', type=str, default='/data11/yanhe/miccai/super_pixel/slide_superpixel/tcga_crc/superpixel_num_300/')
420 | parser.add_argument('--vis_saved_path', type=str, default='/data11/yanhe/miccai/super_pixel/vis/tcga_crc/superpixel_num_300')
421 | parser.add_argument('--graph_file_saved_path', type=str, default='/data11/yanhe/miccai/super_pixel/graph_file/tcga_crc/superpixel_num_300')
422 | parser.add_argument('--nr_superpixels', type=int, default=300)
423 |
424 | args, _ = parser.parse_known_args()
425 | return args
426 |
427 |
428 | if __name__ == '__main__':
429 | try:
430 | args=get_params()
431 | main(args)
432 | except Exception as exception:
433 | # logger.exception(exception)
434 | raise
435 |
--------------------------------------------------------------------------------
/data_preprocess/graph_construction.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import nmslib\n",
10 | "import networkx as nx\n",
11 | "import os\n",
12 | "import torch\n",
13 | "import numpy as np\n",
14 | "from tqdm import tqdm\n",
15 | "import h5py"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": null,
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "class Hnsw:\n",
25 | " def __init__(self, space='cosinesimil', index_params=None,\n",
26 | " query_params=None, print_progress=True):\n",
27 | " self.space = space\n",
28 | " self.index_params = index_params\n",
29 | " self.query_params = query_params\n",
30 | " self.print_progress = print_progress\n",
31 | "\n",
32 | " def fit(self, X):\n",
33 | " index_params = self.index_params\n",
34 | " if index_params is None:\n",
35 | " index_params = {'M': 16, 'post': 0, 'efConstruction': 400}\n",
36 | "\n",
37 | " query_params = self.query_params\n",
38 | " if query_params is None:\n",
39 | " query_params = {'ef': 90}\n",
40 | "\n",
41 | " # this is the actual nmslib part, hopefully the syntax should\n",
42 | " # be pretty readable, the documentation also has a more verbiage\n",
43 | " # introduction: https://nmslib.github.io/nmslib/quickstart.html\n",
44 | " index = nmslib.init(space=self.space, method='hnsw')\n",
45 | " index.addDataPointBatch(X)\n",
46 | " index.createIndex(index_params, print_progress=self.print_progress)\n",
47 | " index.setQueryTimeParams(query_params)\n",
48 | "\n",
49 | " self.index_ = index\n",
50 | " self.index_params_ = index_params\n",
51 | " self.query_params_ = query_params\n",
52 | " return self\n",
53 | "\n",
54 | " def query(self, vector, topn):\n",
55 | " # the knnQuery returns indices and corresponding distance\n",
56 | " # we will throw the distance away for now\n",
57 | " indices, dist = self.index_.knnQuery(vector, k=topn)\n",
58 | " return indices"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "from torch_geometric.data import Data as geomData\n",
68 | "from itertools import chain\n",
69 | "\n",
70 | "def pt2graph(wsi_h5, slidename,radius=9):\n",
73 | " coords, features = np.array(wsi_h5['coords']), np.array(wsi_h5['features'])\n",
74 | " assert coords.shape[0] == features.shape[0]\n",
75 | " num_patches = coords.shape[0]\n",
76 | "# print(num_patches)\n",
77 | " superpixel_info_path = os.path.join('/data14/yanhe/miccai/super_pixel/slide_superpixel/argo_new/superpixel_num_500',slidename+'.pth')\n",
78 | " superpixel_info = torch.load(superpixel_info_path)\n",
79 | " superpixel_attri=[]\n",
80 | "# print(len(superpixel_info))\n",
81 | " for index in range(len(superpixel_info)):\n",
82 | " superpixel = superpixel_info[index]['superpixel']\n",
83 | " superpixel_attri.append(superpixel)\n",
84 | " superpixel_attri = torch.LongTensor(superpixel_attri) \n",
85 | "# print(superpixel_attri)\n",
86 | " \n",
87 | " inter_graph_path = os.path.join('/data14/yanhe/miccai/super_pixel/graph_file/argo/superpixel_num_500',slidename+'.pt')\n",
88 | " g = torch.load(inter_graph_path)\n",
89 | "    # stack the DGL superpixel-graph edges into a (2, E) edge index\n",
90 | "    edge_index = torch.stack((g.edges()[0], g.edges()[1]), dim=0)\n",
93 | " \n",
94 | " model = Hnsw(space='l2')\n",
95 | " model.fit(coords)\n",
96 | " a = np.repeat(range(num_patches), radius-1)\n",
97 | " b = np.fromiter(chain(*[model.query(coords[v_idx], topn=radius)[1:] for v_idx in range(num_patches)]),dtype=int)\n",
98 | " edge_spatial = torch.Tensor(np.stack([a,b])).type(torch.LongTensor)\n",
99 | " superpixel_edge = superpixel_attri[edge_spatial]\n",
100 | "\n",
101 | " edge_mask = (superpixel_edge[0,:] == superpixel_edge[1,:])\n",
102 | "\n",
103 | " remain_edge_index = edge_spatial[:,edge_mask]\n",
104 | " G = geomData(x = torch.Tensor(features),\n",
105 | " edge_patch = remain_edge_index,\n",
106 | " edge_superpixel = edge_index,\n",
107 | " superpixel_attri = superpixel_attri,\n",
108 | " centroid = torch.Tensor(coords))\n",
109 | " return G"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "def createDir_h5toPyG(h5_path, save_path):\n",
119 | " pbar = tqdm(os.listdir(h5_path))\n",
120 | " for h5_fname in pbar:\n",
121 | " pbar.set_description('%s - Creating Graph' % (h5_fname[:-3]))\n",
122 | "\n",
123 | " try:\n",
124 | " wsi_h5 = h5py.File(os.path.join(h5_path, h5_fname), \"r\")\n",
125 | " slidename = h5_fname[:-3]\n",
126 | " if slidename != 'ZS6Y1A01554_HE520' and slidename != 'ZS6Y1A03883_HE208' and slidename != 'ZS6Y1A07240_HE400' and slidename != 'ZS6Y1A08318_HE155':\n",
127 | " G = pt2graph(wsi_h5,slidename)\n",
128 | " torch.save(G, os.path.join(save_path, h5_fname[:-3]+'.pt'))\n",
129 | " wsi_h5.close()\n",
130 | " except OSError:\n",
131 | " pbar.set_description('%s - Broken H5' % (h5_fname[:12]))\n",
132 | " print(h5_fname, 'Broken')"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": null,
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "h5_path = '/data12/ybj/survival/argo_selected/20x/slides_feat/h5_files'\n",
142 | "save_path = '/data14/yanhe/miccai/graph_file/argo/superpixel_num_500/'\n",
143 | "os.makedirs(save_path, exist_ok=True)\n",
144 | "createDir_h5toPyG(h5_path, save_path)"
145 | ]
146 | }
203 | ],
204 | "metadata": {
205 | "kernelspec": {
206 | "display_name": "pylcx",
207 | "language": "python",
208 | "name": "pylcx"
209 | },
210 | "language_info": {
211 | "codemirror_mode": {
212 | "name": "ipython",
213 | "version": 3
214 | },
215 | "file_extension": ".py",
216 | "mimetype": "text/x-python",
217 | "name": "python",
218 | "nbconvert_exporter": "python",
219 | "pygments_lexer": "ipython3",
220 | "version": "3.8.5"
221 | }
222 | },
223 | "nbformat": 4,
224 | "nbformat_minor": 4
225 | }
226 |
--------------------------------------------------------------------------------
/data_preprocess/superpixel_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
3 | import torch
4 | import torch.nn as nn
5 | from tqdm import tqdm
6 | import numpy as np
7 | import pandas as pd
8 | import h5py
9 | from sklearn.cluster import MiniBatchKMeans, KMeans
10 | import random
11 | from PIL import Image
12 | from dgl.data.utils import save_graphs
13 |
14 | from histocartography.utils import download_example_data
15 | from histocartography.preprocessing import (
16 | ColorMergedSuperpixelExtractor,
17 | DeepFeatureExtractor
18 | )
19 | from histocartography.visualization import OverlayGraphVisualization
20 |
21 | from skimage.measure import regionprops
22 | import joblib
23 | import cv2
24 |
25 | import logging
26 | import multiprocessing
27 | import os
28 | import sys
29 | from abc import ABC, abstractmethod
30 | from copy import deepcopy
31 | from functools import partial
32 | from pathlib import Path
33 | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
34 |
35 | import h5py
36 | import pandas as pd
37 | from tqdm.auto import tqdm
38 |
53 | class PipelineStep(ABC):
54 | """Base pipelines step"""
55 |
56 | def __init__(
57 | self,
58 | save_path: Union[None, str, Path] = None,
59 | precompute: bool = True,
60 | link_path: Union[None, str, Path] = None,
61 | precompute_path: Union[None, str, Path] = None,
62 | ) -> None:
63 | """Abstract class that helps with saving and loading precomputed results
64 | Args:
65 | save_path (Union[None, str, Path], optional): Base path to save results to.
66 | When set to None, the results are not saved to disk. Defaults to None.
67 | precompute (bool, optional): Whether to perform the precomputation necessary
68 | for the step. Defaults to True.
69 | link_path (Union[None, str, Path], optional): Path to link the output directory
70 | to. When None, no link is created. Only supported when save_path is not None.
71 | Defaults to None.
72 | precompute_path (Union[None, str, Path], optional): Path to save the output of
73 | the precomputation to. If not specified it defaults to the output directory
74 | of the step when save_path is not None. Defaults to None.
75 | """
76 | assert (
77 | save_path is not None or link_path is None
78 | ), "link_path only supported when save_path is not None"
79 |
80 | name = self.__repr__()
81 | self.save_path = save_path
82 | if self.save_path is not None:
83 | self.output_dir = Path(self.save_path) / name
84 | self.output_key = "default_key"
85 | self._mkdir()
86 | if precompute_path is None:
87 | precompute_path = save_path
88 |
89 | if precompute:
90 | self.precompute(
91 | link_path=link_path,
92 | precompute_path=precompute_path)
93 |
94 | def __repr__(self) -> str:
95 | """Representation of a pipeline step.
96 | Returns:
97 | str: Representation of a pipeline step.
98 | """
99 | variables = ",".join(
100 | [f"{k}={v}" for k, v in sorted(self.__dict__.items())])
101 | return (
102 | f"{self.__class__.__name__}({variables})".replace(" ", "")
103 | .replace('"', "")
104 | .replace("'", "")
105 | .replace("..", "")
106 | .replace("/", "_")
107 | )
108 |
109 | def _mkdir(self) -> None:
110 | """Create path to output files"""
111 | assert (
112 | self.save_path is not None
113 | ), "Can only create directory if base_path was not None when constructing the object"
114 | if not self.output_dir.exists():
115 | self.output_dir.mkdir()
116 |
117 | def _link_to_path(self, link_directory: Union[None, str, Path]) -> None:
118 | """Links the output directory to the given directory.
119 | Args:
120 | link_directory (Union[None, str, Path]): Directory to link to
121 | """
122 | if link_directory is None or Path(
123 | link_directory).parent.resolve() == Path(self.output_dir):
124 | logging.info("Link to self skipped")
125 | return
126 | assert (
127 | self.save_path is not None
128 | ), f"Linking only supported when saving is enabled, i.e. when save_path is passed in the constructor."
129 | if os.path.islink(link_directory):
130 | if os.path.exists(link_directory):
131 | logging.info("Link already exists: overwriting...")
132 | os.remove(link_directory)
133 | else:
134 | logging.critical(
135 | "Link exists, but points nowhere. Ignoring...")
136 | return
137 | elif os.path.exists(link_directory):
138 | os.remove(link_directory)
139 | os.symlink(self.output_dir, link_directory, target_is_directory=True)
140 |
141 | def precompute(
142 | self,
143 | link_path: Union[None, str, Path] = None,
144 | precompute_path: Union[None, str, Path] = None,
145 | ) -> None:
146 | """Precompute all necessary information for this step
147 | Args:
148 | link_path (Union[None, str, Path], optional): Path to link the output to. Defaults to None.
149 | precompute_path (Union[None, str, Path], optional): Path to load/save the precomputation outputs. Defaults to None.
150 | """
151 | pass
152 |
153 | def process(
154 | self, *args: Any, output_name: Optional[str] = None, **kwargs: Any
155 | ) -> Any:
156 |         """Main process function of the step; returns the result. Tries to save the output when output_name is passed.
157 | Args:
158 | output_name (Optional[str], optional): Unique identifier of the passed datapoint. Defaults to None.
159 | Returns:
160 | Any: Result of the pipeline step
161 | """
162 | if output_name is not None and self.save_path is not None:
163 | return self._process_and_save(
164 | *args, output_name=output_name, **kwargs)
165 | else:
166 | return self._process(*args, **kwargs)
167 |
168 | @abstractmethod
169 | def _process(self, *args: Any, **kwargs: Any) -> Any:
170 | """Abstract method that performs the computation of the pipeline step
171 | Returns:
172 | Any: Result of the pipeline step
173 | """
174 |
175 | def _get_outputs(self, input_file: h5py.File) -> Union[Any, Tuple]:
176 | """Extracts the step output from a given h5 file
177 | Args:
178 | input_file (h5py.File): File to load from
179 | Returns:
180 | Union[Any, Tuple]: Previously computed output of the step
181 | """
182 | outputs = list()
183 | nr_outputs = len(input_file.keys())
184 |
185 | # Legacy, remove at some point
186 | if nr_outputs == 1 and self.output_key in input_file.keys():
187 | return tuple([input_file[self.output_key][()]])
188 |
189 | for i in range(nr_outputs):
190 | outputs.append(input_file[f"{self.output_key}_{i}"][()])
191 | if len(outputs) == 1:
192 | return outputs[0]
193 | else:
194 | return tuple(outputs)
195 |
196 | def _set_outputs(self, output_file: h5py.File,
197 | outputs: Union[Tuple, Any]) -> None:
198 | """Save the step output to a given h5 file
199 | Args:
200 | output_file (h5py.File): File to write to
201 | outputs (Union[Tuple, Any]): Computed step output
202 | """
203 | if not isinstance(outputs, tuple):
204 | outputs = tuple([outputs])
205 | for i, output in enumerate(outputs):
206 | output_file.create_dataset(
207 | f"{self.output_key}_{i}",
208 | data=output,
209 | compression="gzip",
210 | compression_opts=9,
211 | )
212 |
213 | def _process_and_save(
214 | self, *args: Any, output_name: str, **kwargs: Any
215 | ) -> Any:
216 |         """Process and save in the provided path as a .h5 file
217 | Args:
218 |             output_name (str): Unique identifier of the passed datapoint
219 | Raises:
220 |             read_error (OSError): When unable to read from self.output_dir/output_name.h5
221 |             write_error (OSError): When unable to write to self.output_dir/output_name.h5
222 | Returns:
223 | Any: Result of the pipeline step
224 | """
225 | assert (
226 | self.save_path is not None
227 | ), "Can only save intermediate output if base_path was not None when constructing the object"
228 | output_path = self.output_dir / f"{output_name}.h5"
229 | if output_path.exists():
230 | logging.info(
231 | f"{self.__class__.__name__}: Output of {output_name} already exists, using it instead of recomputing"
232 | )
233 | try:
234 | with h5py.File(output_path, "r") as input_file:
235 | output = self._get_outputs(input_file=input_file)
236 | except OSError as read_error:
237 | print(f"\n\nCould not read from {output_path}!\n\n")
238 | raise read_error
239 | else:
240 | output = self._process(*args, **kwargs)
241 | try:
242 | with h5py.File(output_path, "w") as output_file:
243 | self._set_outputs(output_file=output_file, outputs=output)
244 | except OSError as write_error:
245 | print(f"\n\nCould not write to {output_path}!\n\n")
246 | raise write_error
247 | return output
248 |
249 | def fast_histogram(input_array: np.ndarray, nr_values: int) -> np.ndarray:
250 | """Calculates a histogram of a matrix of the values from 0 up to (excluding) nr_values
251 | Args:
252 |         input_array (np.array): Input array
253 |         nr_values (int): Number of possible values, from 0 up to (excluding) nr_values.
254 | Returns:
255 | np.array: Output tensor
256 | """
257 | output_array = np.empty(nr_values, dtype=int)
258 | for i in range(nr_values):
259 | output_array[i] = (input_array == i).sum()
260 | return output_array
261 |
262 |
263 | def load_image(image_path: Path) -> np.ndarray:
264 | """Loads an image from a given path and returns it as a numpy array
265 | Args:
266 | image_path (Path): Path of the image
267 | Returns:
268 | np.ndarray: Array representation of the image
269 | """
270 | assert image_path.exists()
271 | try:
272 | with Image.open(image_path) as img:
273 | image = np.array(img)
274 | except OSError as e:
275 | logging.critical("Could not open %s", image_path)
276 | raise OSError(e)
277 | return image
278 |
279 | """This module handles all the graph building"""
280 |
281 | import logging
282 | from abc import abstractmethod
283 | from pathlib import Path
284 | from typing import Any, Optional, Tuple, Union
285 |
286 | import cv2
287 | import dgl
288 | import networkx as nx
289 | import numpy as np
290 | import pandas as pd
291 | import torch
292 | from dgl.data.utils import load_graphs, save_graphs
293 | from skimage.measure import regionprops
294 | from sklearn.neighbors import kneighbors_graph
295 |
296 | # from ..pipeline import PipelineStep
297 | # from .utils import fast_histogram
298 |
299 |
300 |
301 | LABEL = "label"
302 | CENTROID = "centroid"
303 | FEATURES = "feat"
304 |
305 |
306 | def two_hop_neighborhood(graph: dgl.DGLGraph) -> dgl.DGLGraph:
307 | """Increases the connectivity of a given graph by an additional hop
308 | Args:
309 | graph (dgl.DGLGraph): Input graph
310 | Returns:
311 | dgl.DGLGraph: Output graph
312 | """
313 | A = graph.adjacency_matrix().to_dense()
314 | A_tilde = (1.0 * ((A + A.matmul(A)) >= 1)) - torch.eye(A.shape[0])
315 | ngraph = nx.convert_matrix.from_numpy_matrix(A_tilde.numpy())
316 | new_graph = dgl.DGLGraph()
317 | new_graph.from_networkx(ngraph)
318 | for k, v in graph.ndata.items():
319 | new_graph.ndata[k] = v
320 | for k, v in graph.edata.items():
321 | new_graph.edata[k] = v
322 | return new_graph
323 |
324 |
325 | class BaseGraphBuilder(PipelineStep):
326 | """
327 | Base interface class for graph building.
328 | """
329 |
330 | def __init__(
331 | self,
332 | nr_annotation_classes: int = 5,
333 | annotation_background_class: Optional[int] = None,
334 | add_loc_feats: bool = False,
335 | **kwargs: Any
336 | ) -> None:
337 | """
338 | Base Graph Builder constructor.
339 | Args:
340 | nr_annotation_classes (int): Number of classes in annotation. Used only if setting node labels.
341 | annotation_background_class (int): Background class label in annotation. Used only if setting node labels.
342 | add_loc_feats (bool): Flag to include location-based features (ie normalized centroids)
343 | in node feature representation.
344 | Defaults to False.
345 | """
346 | self.nr_annotation_classes = nr_annotation_classes
347 | self.annotation_background_class = annotation_background_class
348 | self.add_loc_feats = add_loc_feats
349 | super().__init__(**kwargs)
350 |
351 | def _process( # type: ignore[override]
352 | self,
353 | instance_map: np.ndarray,
354 | features: torch.Tensor,
355 | annotation: Optional[np.ndarray] = None,
356 | ) -> dgl.DGLGraph:
357 | """Generates a graph from a given instance_map and features
358 | Args:
359 | instance_map (np.array): Instance map depicting tissue components
360 | features (torch.Tensor): Features of each node. Shape (nr_nodes, nr_features)
361 | annotation (Union[None, np.array], optional): Optional node level to include.
362 | Defaults to None.
363 | Returns:
364 | dgl.DGLGraph: The constructed graph
365 | """
366 | # add nodes
367 | num_nodes = features.shape[0]
368 | graph = dgl.DGLGraph()
369 | graph.add_nodes(num_nodes)
370 |
371 | # add image size as graph data
372 | image_size = (instance_map.shape[1], instance_map.shape[0]) # (x, y)
373 |
374 | # get instance centroids
375 | centroids = self._get_node_centroids(instance_map)
376 |
377 | # add node content
378 | self._set_node_centroids(centroids, graph)
379 | self._set_node_features(features, image_size, graph)
380 | if annotation is not None:
381 | self._set_node_labels(instance_map, annotation, graph)
382 |
383 | # build edges
384 | self._build_topology(instance_map, centroids, graph)
385 | return graph
386 |
387 | def _process_and_save( # type: ignore[override]
388 | self,
389 | instance_map: np.ndarray,
390 | features: torch.Tensor,
391 | annotation: Optional[np.ndarray] = None,
392 | output_name: str = None,
393 | ) -> dgl.DGLGraph:
394 | """Process and save in provided directory
395 | Args:
396 | output_name (str): Name of output file
397 | instance_map (np.ndarray): Instance map depicting tissue components
398 | (eg nuclei, tissue superpixels)
399 | features (torch.Tensor): Features of each node. Shape (nr_nodes, nr_features)
400 | annotation (Optional[np.ndarray], optional): Optional node level to include.
401 | Defaults to None.
402 | Returns:
403 | dgl.DGLGraph: [description]
404 | """
405 | assert (
406 | self.save_path is not None
407 | ), "Can only save intermediate output if base_path was not None during construction"
408 | output_path = self.output_dir / f"{output_name}.bin"
409 | if output_path.exists():
410 | logging.info(
411 | f"Output of {output_name} already exists, using it instead of recomputing"
412 | )
413 | graphs, _ = load_graphs(str(output_path))
414 | assert len(graphs) == 1
415 | graph = graphs[0]
416 | else:
417 | graph = self._process(
418 | instance_map=instance_map,
419 | features=features,
420 | annotation=annotation)
421 | save_graphs(str(output_path), [graph])
422 | return graph
423 |
424 | def _get_node_centroids(
425 | self, instance_map: np.ndarray
426 | ) -> np.ndarray:
427 | """Get the centroids of the graphs
428 | Args:
429 | instance_map (np.ndarray): Instance map depicting tissue components
430 | Returns:
431 | centroids (np.ndarray): Node centroids
432 | """
433 | regions = regionprops(instance_map)
434 | centroids = np.empty((len(regions), 2))
435 | for i, region in enumerate(regions):
436 | center_y, center_x = region.centroid # (y, x)
437 | center_x = int(round(center_x))
438 | center_y = int(round(center_y))
439 | centroids[i, 0] = center_x
440 | centroids[i, 1] = center_y
441 | return centroids
442 |
443 | def _set_node_centroids(
444 | self,
445 | centroids: np.ndarray,
446 | graph: dgl.DGLGraph
447 | ) -> None:
448 | """Set the centroids of the graphs
449 | Args:
450 | centroids (np.ndarray): Node centroids
451 | graph (dgl.DGLGraph): Graph to add the centroids to
452 | """
453 | graph.ndata[CENTROID] = torch.FloatTensor(centroids)
454 |
455 | def _set_node_features(
456 | self,
457 | features: torch.Tensor,
458 | image_size: Tuple[int, int],
459 | graph: dgl.DGLGraph
460 | ) -> None:
461 | """Set the provided node features
462 | Args:
463 | features (torch.Tensor): Node features
464 | image_size (Tuple[int,int]): Image dimension (x, y)
465 | graph (dgl.DGLGraph): Graph to add the features to
466 | """
467 | if not torch.is_tensor(features):
468 | features = torch.FloatTensor(features)
469 | if not self.add_loc_feats:
470 | graph.ndata[FEATURES] = features
471 | elif (
472 | self.add_loc_feats
473 | and image_size is not None
474 | ):
475 | # compute normalized centroid features
476 | centroids = graph.ndata[CENTROID]
477 |
478 | normalized_centroids = torch.empty_like(centroids) # (x, y)
479 | normalized_centroids[:, 0] = centroids[:, 0] / image_size[0]
480 | normalized_centroids[:, 1] = centroids[:, 1] / image_size[1]
481 |
482 | if features.ndim == 3:
483 | normalized_centroids = normalized_centroids \
484 | .unsqueeze(dim=1) \
485 | .repeat(1, features.shape[1], 1)
486 | concat_dim = 2
487 | elif features.ndim == 2:
488 | concat_dim = 1
489 |
490 | concat_features = torch.cat(
491 | (
492 | features,
493 | normalized_centroids
494 | ),
495 | dim=concat_dim,
496 | )
497 | graph.ndata[FEATURES] = concat_features
498 | else:
499 | raise ValueError(
500 | "Please provide image size to add the normalized centroid to the node features."
501 | )
502 |
503 | @abstractmethod
504 | def _set_node_labels(
505 | self,
506 | instance_map: np.ndarray,
507 | annotation: np.ndarray,
508 | graph: dgl.DGLGraph
509 | ) -> None:
510 | """Set the node labels of the graphs
511 | Args:
512 | instance_map (np.ndarray): Instance map depicting tissue components
513 | annotation (np.ndarray): Annotations, eg node labels
514 | graph (dgl.DGLGraph): Graph to add the centroids to
515 | """
516 |
517 | @abstractmethod
518 | def _build_topology(
519 | self,
520 | instance_map: np.ndarray,
521 | centroids: np.ndarray,
522 | graph: dgl.DGLGraph
523 | ) -> None:
524 | """Generate the graph topology from the provided instance_map
525 | Args:
526 | instance_map (np.array): Instance map depicting tissue components
527 | centroids (np.array): Node centroids
528 | graph (dgl.DGLGraph): Graph to add the edges
529 | """
530 |
531 | def precompute(
532 | self,
533 | link_path: Union[None, str, Path] = None,
534 | precompute_path: Union[None, str, Path] = None,
535 | ) -> None:
536 | """Precompute all necessary information
537 | Args:
538 | link_path (Union[None, str, Path], optional): Path to link to. Defaults to None.
539 | precompute_path (Union[None, str, Path], optional): Path to save precomputation outputs. Defaults to None.
540 | """
541 | if self.save_path is not None and link_path is not None:
542 | self._link_to_path(Path(link_path) / "graphs")
543 |
544 |
545 | class RAGGraphBuilder(BaseGraphBuilder):
546 | """
547 | Super-pixel Graphs class for graph building.
548 | """
549 |
550 | def __init__(self, kernel_size: int = 3, hops: int = 1, **kwargs) -> None:
551 | """Create a graph builder that uses a provided kernel size to detect connectivity
552 | Args:
553 |             kernel_size (int, optional): Size of the kernel to detect connectivity. Defaults to 3.
554 | """
555 | logging.debug("*** RAG Graph Builder ***")
556 | assert hops > 0 and isinstance(
557 | hops, int
558 | ), f"Invalid hops {hops} ({type(hops)}). Must be integer >= 0"
559 | self.kernel_size = kernel_size
560 | self.hops = hops
561 | super().__init__(**kwargs)
562 |
563 | def _set_node_labels(
564 | self,
565 | instance_map: np.ndarray,
566 | annotation: np.ndarray,
567 | graph: dgl.DGLGraph) -> None:
568 | """Set the node labels of the graphs using annotation map"""
569 | assert (
570 | self.nr_annotation_classes < 256
571 | ), "Cannot handle that many classes with 8-bits"
572 | regions = regionprops(instance_map)
573 | labels = torch.empty(len(regions), dtype=torch.uint8)
574 |
575 | for region_label in np.arange(1, len(regions) + 1):
576 | histogram = fast_histogram(
577 | annotation[instance_map == region_label],
578 | nr_values=self.nr_annotation_classes
579 | )
580 | mask = np.ones(len(histogram), np.bool)
581 | mask[self.annotation_background_class] = 0
582 | if histogram[mask].sum() == 0:
583 | assignment = self.annotation_background_class
584 | else:
585 | histogram[self.annotation_background_class] = 0
586 | assignment = np.argmax(histogram)
587 | labels[region_label - 1] = int(assignment)
588 | graph.ndata[LABEL] = labels
589 |
590 | def _build_topology(
591 | self,
592 | instance_map: np.ndarray,
593 | centroids: np.ndarray,
594 | graph: dgl.DGLGraph
595 | ) -> None:
596 |         """Create the graph topology from the instance connectivity in the instance_map"""
597 | regions = regionprops(instance_map)
598 | instance_ids = torch.empty(len(regions), dtype=torch.uint8)
599 |
600 | kernel = np.ones((3, 3), np.uint8)
601 | adjacency = np.zeros(shape=(len(instance_ids), len(instance_ids)))
602 |
603 | for instance_id in np.arange(1, len(instance_ids) + 1):
604 | mask = (instance_map == instance_id).astype(np.uint8)
605 | # print("mask:{}".format(mask))
606 | dilation = cv2.dilate(mask,kernel, iterations=1)
607 | # print("dilation:{}".format(dilation))
608 | boundary = dilation - mask
609 | # print("boundary:{}".format(boundary))
610 | # print(sum(sum(boundary)))
611 | idx = pd.unique(instance_map[boundary.astype(bool)])
612 | # print("idx:{}".format(idx))
613 | # print(len(idx))
614 | instance_id -= 1 # because instance_map id starts from 1
615 | idx -= 1 # because instance_map id starts from 1
616 | # print("new idx:{}".format(idx))
617 | # print(type(idx))
618 | idx = idx.tolist()
619 | # print(type(idx))
620 | if -1 in idx:
621 | idx.remove(-1)
622 | idx = np.array(idx)
623 | # print(type(idx))
624 | # print("new new idx:{}".format(idx))
625 | if idx.shape[0] != 0:
626 | adjacency[instance_id, idx] = 1
627 | # print(adjacency)
628 |
629 | edge_list = np.nonzero(adjacency)
630 | graph.add_edges(list(edge_list[0]), list(edge_list[1]))
631 |
632 | for _ in range(self.hops - 1):
633 | graph = two_hop_neighborhood(graph)
634 |
635 | """This module handles everything related to superpixels"""
636 |
637 | import logging
638 | import math
639 | import sys
640 | from abc import abstractmethod
641 | from pathlib import Path
642 | from typing import Any, Dict, Optional, Union
643 |
644 | import cv2
645 | import h5py
646 | import numpy as np
647 | from skimage.color.colorconv import rgb2hed
648 | from skimage.future import graph
649 | from skimage.segmentation import slic
651 |
652 |
653 | class SuperpixelExtractor(PipelineStep):
654 | """Helper class to extract superpixels from images"""
655 |
656 | def __init__(
657 | self,
658 | nr_superpixels: int = None,
659 | superpixel_size: int = None,
660 | max_nr_superpixels: Optional[int] = None,
661 | blur_kernel_size: Optional[float] = 1,
662 | compactness: Optional[int] = 20,
663 | max_iterations: Optional[int] = 10,
664 | threshold: Optional[float] = 0.03,
665 | connectivity: Optional[int] = 2,
666 | color_space: Optional[str] = "rgb",
667 | downsampling_factor: Optional[int] = 1,
668 | **kwargs,
669 | ) -> None:
670 | """Abstract class that extracts superpixels from RGB Images
671 | Args:
672 | nr_superpixels (None, int): The number of super pixels before any merging.
673 | superpixel_size (None, int): The size of super pixels before any merging.
674 | max_nr_superpixels (int, optional): Upper bound for the number of super pixels.
675 | Useful when providing a superpixel size.
676 |             blur_kernel_size (float, optional): Size of the blur kernel. Defaults to 1.
677 |             compactness (int, optional): Compactness of the superpixels. Defaults to 20.
678 | max_iterations (int, optional): Number of iterations of the slic algorithm. Defaults to 10.
679 | threshold (float, optional): Connectivity threshold. Defaults to 0.03.
680 | connectivity (int, optional): Connectivity for merging graph. Defaults to 2.
681 | downsampling_factor (int, optional): Downsampling factor from the input image
682 | resolution. Defaults to 1.
683 | """
684 | assert (nr_superpixels is None and superpixel_size is not None) or (
685 | nr_superpixels is not None and superpixel_size is None
686 | ), "Provide value for either nr_superpixels or superpixel_size"
687 | self.nr_superpixels = nr_superpixels
688 | self.superpixel_size = superpixel_size
689 | self.max_nr_superpixels = max_nr_superpixels
690 | self.blur_kernel_size = blur_kernel_size
691 | self.compactness = compactness
692 | self.max_iterations = max_iterations
693 | self.threshold = threshold
694 | self.connectivity = connectivity
695 | self.color_space = color_space
696 | self.downsampling_factor = downsampling_factor
697 | super().__init__(**kwargs)
698 |
699 | def _process( # type: ignore[override]
700 | self, input_image: np.ndarray, tissue_mask: np.ndarray = None
701 | ) -> np.ndarray:
702 | """Return the superpixels of a given input image
703 | Args:
704 | input_image (np.array): Input image
705 | tissue_mask (None, np.array): Input tissue mask
706 | Returns:
707 | np.array: Extracted superpixels
708 | """
709 | logging.debug("Input size: %s", input_image.shape)
710 | original_height, original_width, _ = input_image.shape
711 | if self.downsampling_factor != 1:
712 | input_image = self._downsample(
713 | input_image, self.downsampling_factor)
714 | if tissue_mask is not None:
715 | tissue_mask = self._downsample(
716 | tissue_mask, self.downsampling_factor)
717 | logging.debug("Downsampled to %s", input_image.shape)
718 | superpixels = self._extract_superpixels(
719 | image=input_image, tissue_mask=tissue_mask
720 | )
721 | if self.downsampling_factor != 1:
722 | superpixels = self._upsample(
723 | superpixels, original_height, original_width)
724 | logging.debug("Upsampled to %s", superpixels.shape)
725 | return superpixels
726 |
727 | @abstractmethod
728 | def _extract_superpixels(
729 | self, image: np.ndarray, tissue_mask: np.ndarray = None
730 | ) -> np.ndarray:
731 | """Perform the superpixel extraction
732 | Args:
733 | image (np.array): Input tensor
734 | tissue_mask (np.array): Tissue mask tensor
735 | Returns:
736 | np.array: Output tensor
737 | """
738 |
739 | @staticmethod
740 | def _downsample(image: np.ndarray, downsampling_factor: int) -> np.ndarray:
741 | """Downsample an input image with a given downsampling factor
742 | Args:
743 | image (np.array): Input tensor
744 | downsampling_factor (int): Factor to downsample
745 | Returns:
746 | np.array: Output tensor
747 | """
748 | height, width = image.shape[0], image.shape[1]
749 | new_height = math.floor(height / downsampling_factor)
750 | new_width = math.floor(width / downsampling_factor)
751 | downsampled_image = cv2.resize(
752 | image, (new_width, new_height), interpolation=cv2.INTER_NEAREST
753 | )
754 | return downsampled_image
755 |
756 | @staticmethod
757 | def _upsample(
758 | image: np.ndarray,
759 | new_height: int,
760 | new_width: int) -> np.ndarray:
761 |         """Upsample an input image to a specified new height and width
762 | Args:
763 | image (np.array): Input tensor
764 | new_height (int): Target height
765 | new_width (int): Target width
766 | Returns:
767 | np.array: Output tensor
768 | """
769 | upsampled_image = cv2.resize(
770 | image, (new_width, new_height), interpolation=cv2.INTER_NEAREST
771 | )
772 | return upsampled_image
773 |
774 | def precompute(
775 | self,
776 | link_path: Union[None, str, Path] = None,
777 | precompute_path: Union[None, str, Path] = None,
778 | ) -> None:
779 | """Precompute all necessary information
780 | Args:
781 | link_path (Union[None, str, Path], optional): Path to link to. Defaults to None.
782 | precompute_path (Union[None, str, Path], optional): Path to save precomputation outputs. Defaults to None.
783 | """
784 | if self.save_path is not None and link_path is not None:
785 | self._link_to_path(Path(link_path) / "superpixels")
786 |
787 |
788 | class SLICSuperpixelExtractor(SuperpixelExtractor):
789 | """Use the SLIC algorithm to extract superpixels."""
790 |
791 | def __init__(self, **kwargs) -> None:
792 | """Extract superpixels with the SLIC algorithm"""
793 | super().__init__(**kwargs)
794 |
795 | def _get_nr_superpixels(self, image: np.ndarray) -> int:
796 | """Compute the number of superpixels for initial segmentation
797 | Args:
798 | image (np.array): Input tensor
799 | Returns:
800 | int: Output number of superpixels
801 | """
802 | if self.superpixel_size is not None:
803 | nr_superpixels = int(
804 | (image.shape[0] * image.shape[1] / self.superpixel_size)
805 | )
806 | elif self.nr_superpixels is not None:
807 | nr_superpixels = self.nr_superpixels
808 | if self.max_nr_superpixels is not None:
809 | nr_superpixels = min(nr_superpixels, self.max_nr_superpixels)
810 | return nr_superpixels
811 |
812 | def _extract_superpixels(
813 | self,
814 | image: np.ndarray,
815 | *args,
816 | **kwargs) -> np.ndarray:
817 | """Perform the superpixel extraction
818 | Args:
819 | image (np.array): Input tensor
820 | Returns:
821 | np.array: Output tensor
822 | """
823 | if self.color_space == "hed":
824 | image = rgb2hed(image)
825 | nr_superpixels = self._get_nr_superpixels(image)
826 | superpixels = slic(
827 | image,
828 | sigma=self.blur_kernel_size,
829 | n_segments=nr_superpixels,
830 | max_iter=self.max_iterations,
831 | compactness=self.compactness,
832 | start_label=1,
833 | )
834 | return superpixels
835 |
836 |
837 | class MergedSuperpixelExtractor(SuperpixelExtractor):
838 | def __init__(self, **kwargs) -> None:
839 | """Extract superpixels with the SLIC algorithm"""
840 | super().__init__(**kwargs)
841 |
842 | def _get_nr_superpixels(self, image: np.ndarray) -> int:
843 | """Compute the number of superpixels for initial segmentation
844 | Args:
845 | image (np.array): Input tensor
846 | Returns:
847 | int: Output number of superpixels
848 | """
849 | if self.superpixel_size is not None:
850 | nr_superpixels = int(
851 | (image.shape[0] * image.shape[1] / self.superpixel_size)
852 | )
853 | elif self.nr_superpixels is not None:
854 | nr_superpixels = self.nr_superpixels
855 | if self.max_nr_superpixels is not None:
856 | nr_superpixels = min(nr_superpixels, self.max_nr_superpixels)
857 | return nr_superpixels
858 |
859 | def _extract_initial_superpixels(self, image: np.ndarray) -> np.ndarray:
860 | """Extract initial superpixels using SLIC
861 | Args:
862 | image (np.array): Input tensor
863 | Returns:
864 | np.array: Output tensor
865 | """
866 | nr_superpixels = self._get_nr_superpixels(image)
867 | superpixels = slic(
868 | image,
869 | sigma=self.blur_kernel_size,
870 | n_segments=nr_superpixels,
871 | compactness=self.compactness,
872 | max_iter=self.max_iterations,
873 | start_label=1,
874 | )
875 | return superpixels
876 |
877 | def _merge_superpixels(
878 | self,
879 | input_image: np.ndarray,
880 | initial_superpixels: np.ndarray,
881 | tissue_mask: np.ndarray = None,
882 | ) -> np.ndarray:
883 | """Merge the initial superpixels to return merged superpixels
884 | Args:
885 |             input_image (np.array): Input image
886 | initial_superpixels (np.array): Initial superpixels
887 | tissue_mask (None, np.array): Tissue mask
888 | Returns:
889 | np.array: Output merged superpixel tensor
890 | """
891 | if tissue_mask is not None:
892 | # Remove superpixels belonging to background or having < 10% tissue
893 | # content
894 | ids_initial = np.unique(initial_superpixels, return_counts=True)
895 | ids_masked = np.unique(
896 | tissue_mask * initial_superpixels, return_counts=True
897 | )
898 |
899 | ctr = 1
900 | superpixels = np.zeros_like(initial_superpixels)
901 | for i in range(len(ids_initial[0])):
902 | id = ids_initial[0][i]
903 | if id in ids_masked[0]:
904 | idx = np.where(id == ids_masked[0])[0]
905 | ratio = ids_masked[1][idx] / ids_initial[1][i]
906 | if ratio >= 0.1:
907 | superpixels[initial_superpixels == id] = ctr
908 | ctr += 1
909 |
910 | initial_superpixels = superpixels
911 |
912 | # Merge superpixels within tissue region
913 | g = graph.rag_mean_color(input_image, initial_superpixels)
914 | merged_superpixels = graph.merge_hierarchical(
915 | initial_superpixels,
916 | g,
917 | thresh=self.threshold,
918 | rag_copy=False,
919 | in_place_merge=True,
920 | merge_func=self._merging_function,
921 | weight_func=self._weighting_function,
922 | )
923 | merged_superpixels += 1 # Handle regionprops that ignores all values of 0
924 | # mask = np.zeros_like(initial_superpixels)
925 | # mask[initial_superpixels != 0] = 1
926 | # merged_superpixels = merged_superpixels * mask
927 | return merged_superpixels
928 |
929 | @abstractmethod
930 | def _weighting_function(
931 | self, graph: graph.RAG, src: int, dst: int, n: int
932 | ) -> Dict[str, Any]:
933 | """Handle merging of nodes of a region boundary region adjacency graph."""
934 |
935 | @abstractmethod
936 | def _merging_function(self, graph: graph.RAG, src: int, dst: int) -> None:
937 | """Call back called before merging 2 nodes."""
938 |
939 | def _extract_superpixels(
940 | self, image: np.ndarray, tissue_mask: np.ndarray = None
941 | ) -> np.ndarray:
942 | """Perform superpixel extraction
943 | Args:
944 | image (np.array): Input tensor
945 | tissue_mask (np.array, optional): Input tissue mask
946 | Returns:
947 | np.array: Extracted merged superpixels.
948 | np.array: Extracted init superpixels, ie before merging.
949 | """
950 | initial_superpixels = self._extract_initial_superpixels(image)
951 | merged_superpixels = self._merge_superpixels(
952 | image, initial_superpixels, tissue_mask
953 | )
954 |
955 | return merged_superpixels, initial_superpixels
956 |
957 | def _process( # type: ignore[override]
958 | self, input_image: np.ndarray, tissue_mask=None
959 | ) -> np.ndarray:
960 | """Return the superpixels of a given input image
961 | Args:
962 | input_image (np.array): Input image.
963 | tissue_mask (None, np.array): Tissue mask.
964 | Returns:
965 | np.array: Extracted merged superpixels.
966 | np.array: Extracted init superpixels, ie before merging.
967 | """
968 | logging.debug("Input size: %s", input_image.shape)
969 | original_height, original_width, _ = input_image.shape
970 | if self.downsampling_factor is not None and self.downsampling_factor != 1:
971 | input_image = self._downsample(
972 | input_image, self.downsampling_factor)
973 | if tissue_mask is not None:
974 | tissue_mask = self._downsample(
975 | tissue_mask, self.downsampling_factor)
976 | logging.debug("Downsampled to %s", input_image.shape)
977 | merged_superpixels, initial_superpixels = self._extract_superpixels(
978 | input_image, tissue_mask
979 | )
980 | if self.downsampling_factor != 1:
981 | merged_superpixels = self._upsample(
982 | merged_superpixels, original_height, original_width
983 | )
984 | initial_superpixels = self._upsample(
985 | initial_superpixels, original_height, original_width
986 | )
987 | logging.debug("Upsampled to %s", merged_superpixels.shape)
988 | return merged_superpixels, initial_superpixels
989 |
990 | def _process_and_save(
991 | self,
992 | *args: Any,
993 | output_name: str,
994 | **kwargs: Any) -> Any:
995 | """Process and save in the provided path as as .h5 file
996 | Args:
997 | output_name (str): Name of output file
998 | """
999 | assert (
1000 | self.save_path is not None
1001 | ), "Can only save intermediate output if base_path was not None when constructing the object"
1002 | superpixel_output_path = self.output_dir / f"{output_name}.h5"
1003 | if superpixel_output_path.exists():
1004 | logging.info(
1005 | f"{self.__class__.__name__}: Output of {output_name} already exists, using it instead of recomputing"
1006 | )
1007 | try:
1008 | with h5py.File(superpixel_output_path, "r") as input_file:
1009 | merged_superpixels, initial_superpixels = self._get_outputs(
1010 | input_file=input_file)
1011 | except OSError as e:
1012 | print(
1013 | f"\n\nCould not read from {superpixel_output_path}!\n\n",
1014 | file=sys.stderr,
1015 | flush=True,
1016 | )
1017 | print(
1018 | f"\n\nCould not read from {superpixel_output_path}!\n\n",
1019 | flush=True)
1020 | raise e
1021 | else:
1022 | merged_superpixels, initial_superpixels = self._process(
1023 | *args, **kwargs)
1024 | try:
1025 | with h5py.File(superpixel_output_path, "w") as output_file:
1026 | self._set_outputs(
1027 | output_file=output_file,
1028 | outputs=(merged_superpixels, initial_superpixels),
1029 | )
1030 | except OSError as e:
1031 | print(
1032 | f"\n\nCould not write to {superpixel_output_path}!\n\n",
1033 | flush=True)
1034 | raise e
1035 | return merged_superpixels, initial_superpixels
1036 |
1037 |
1038 | class MyColorMergedSuperpixelExtractor(MergedSuperpixelExtractor):
1039 | def __init__(
1040 | self,
1041 | w_hist: float = 0.5,
1042 | w_mean: float = 0.5,
1043 | **kwargs) -> None:
1044 | """Superpixel merger based on color attibutes taken from the HACT-Net Implementation
1045 | Args:
1046 | w_hist (float, optional): Weight of the histogram features for merging. Defaults to 0.5.
1047 | w_mean (float, optional): Weight of the mean features for merging. Defaults to 0.5.
1048 | """
1049 | self.w_hist = w_hist
1050 | self.w_mean = w_mean
1051 | super().__init__(**kwargs)
1052 |
1053 | def _color_features_per_channel(self, img_ch: np.ndarray) -> np.ndarray:
1054 | """Extract color histograms from image channel
1055 | Args:
1056 | img_ch (np.ndarray): Image channel
1057 | Returns:
1058 | np.ndarray: Histogram of the image channel
1059 | """
1060 |         hist, _ = np.histogram(img_ch, bins=np.arange(0, 257, 64))  # 4 bins
1061 | return hist
1062 |
1063 | def _weighting_function(
1064 | self, graph: graph.RAG, src: int, dst: int, n: int
1065 | ) -> Dict[str, Any]:
1066 | diff = graph.nodes[dst]['mean color'] - graph.nodes[n]['mean color']
1067 | diff = np.linalg.norm(diff)
1068 | return {'weight': diff}
1069 |
1070 | def _merging_function(self, graph: graph.RAG, src: int, dst: int) -> None:
1071 | graph.nodes[dst]['total color'] += graph.nodes[src]['total color']
1072 | graph.nodes[dst]['pixel count'] += graph.nodes[src]['pixel count']
1073 | graph.nodes[dst]['mean color'] = (graph.nodes[dst]['total color'] /
1074 | graph.nodes[dst]['pixel count'])
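1075 | 
1076 | 
1077 | # --- Minimal usage sketch (illustration only, not part of the original pipeline) ---
1078 | # The call below only shows how the merged extractor above might be driven on a
1079 | # single RGB image. The keyword arguments are assumptions inferred from the
1080 | # attributes used in this file (self.nr_superpixels, self.downsampling_factor);
1081 | # check the base-class __init__ for the exact signature and defaults.
1082 | if __name__ == "__main__":
1083 |     extractor = MyColorMergedSuperpixelExtractor(
1084 |         nr_superpixels=600,       # assumed kwarg, mirrors self.nr_superpixels
1085 |         downsampling_factor=2,    # assumed kwarg, mirrors self.downsampling_factor
1086 |     )
1087 |     dummy_image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
1088 |     merged, initial = extractor._process(dummy_image)
1089 |     print(merged.shape, initial.shape)  # both upsampled back to the input height and width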
--------------------------------------------------------------------------------
/data_splits/train_val_test_split_argo.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Baeksweety/superpixel_transformer/5c5f395cc9759a48f1ad7cdd01eeaaf7d445b67c/data_splits/train_val_test_split_argo.pkl
--------------------------------------------------------------------------------
/data_splits/train_val_test_split_kirc.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Baeksweety/superpixel_transformer/5c5f395cc9759a48f1ad7cdd01eeaaf7d445b67c/data_splits/train_val_test_split_kirc.pkl
--------------------------------------------------------------------------------
/data_splits/train_val_test_split_lihc.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Baeksweety/superpixel_transformer/5c5f395cc9759a48f1ad7cdd01eeaaf7d445b67c/data_splits/train_val_test_split_lihc.pkl
--------------------------------------------------------------------------------
/data_splits/useless:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/images/miccai_framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Baeksweety/superpixel_transformer/5c5f395cc9759a48f1ad7cdd01eeaaf7d445b67c/images/miccai_framework.png
--------------------------------------------------------------------------------
/images/try:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/label/argo_label.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Baeksweety/superpixel_transformer/5c5f395cc9759a48f1ad7cdd01eeaaf7d445b67c/label/argo_label.pt
--------------------------------------------------------------------------------
/label/kirc_label.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Baeksweety/superpixel_transformer/5c5f395cc9759a48f1ad7cdd01eeaaf7d445b67c/label/kirc_label.pt
--------------------------------------------------------------------------------
/label/lihc_label.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Baeksweety/superpixel_transformer/5c5f395cc9759a48f1ad7cdd01eeaaf7d445b67c/label/lihc_label.pt
--------------------------------------------------------------------------------
/label/useless:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | dgl==0.4.3
2 | einops==0.3.0
3 | histocartography==0.2.0
4 | joblib==1.0.0
5 | lifelines==0.25.7
6 | numpy==1.19.2
7 | opencv-python==4.5.1.48
8 | opencv-python-headless==4.1.2.30
9 | openpyxl==3.0.8
10 | openslide-python==1.1.2
11 | pandas==1.2.0
12 | Pillow==8.1.0
13 | scikit-image==0.18.1
14 | scikit-learn==1.1.2
15 | scipy==1.6.0
16 | sklearn-pandas==2.0.4
17 | timm==0.4.12
18 | torch==1.7.1
19 | torch-cluster==1.5.8
20 | torch-geometric==1.6.3
21 | torch-scatter==2.0.5
22 | torch-sparse==0.6.8
23 | torch-spline-conv==1.2.0
24 | torchaudio==0.7.0
25 | torchstain==1.1.0
26 | torchvision==0.8.2
27 | tqdm==4.56.0
28 |
--------------------------------------------------------------------------------
/train/block_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from timm.models.layers import Mlp,DropPath
4 |
5 | class Attention(nn.Module):
6 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
7 | super().__init__()
8 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
9 | self.num_heads = num_heads
10 | head_dim = dim // num_heads
11 | self.scale = head_dim ** -0.5
12 |
13 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
14 | self.attn_drop = nn.Dropout(attn_drop)
15 | self.proj = nn.Linear(dim, dim)
16 | self.proj_drop = nn.Dropout(proj_drop)
17 |
18 |         self.attention_weights = None  # holds the latest attention map (torch.Tensor) after forward()
19 |
20 | def forward(self, x):
21 | B, N, C = x.shape
22 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
23 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
24 |
25 | attn = (q @ k.transpose(-2, -1)) * self.scale
26 | attn = attn.softmax(dim=-1)
27 | attn = self.attn_drop(attn)
28 | # print(attn.shape)
29 | self.attention_weights = attn
30 |
31 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
32 | x = self.proj(x)
33 | x = self.proj_drop(x)
34 | return x
35 |
36 | def get_attention_weights(self):
37 | return self.attention_weights
38 |
39 |
40 | class LayerScale(nn.Module):
41 | def __init__(self, dim, init_values=1e-5, inplace=False):
42 | super().__init__()
43 | self.inplace = inplace
44 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
45 |
46 | def forward(self, x):
47 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
48 |
49 |
50 | class Block(nn.Module):
51 |
52 | def __init__(
53 | self,
54 | dim,
55 | num_heads,
56 | mlp_ratio=4.,
57 | qkv_bias=False,
58 | drop=0.,
59 | attn_drop=0.,
60 | init_values=None,
61 | drop_path=0.,
62 | act_layer=nn.GELU,
63 | norm_layer=nn.LayerNorm
64 | ):
65 | super().__init__()
66 | self.norm1 = norm_layer(dim)
67 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
68 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
69 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
70 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
71 |
72 | self.norm2 = norm_layer(dim)
73 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
74 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
75 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
76 |
77 | def forward(self, x):
78 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
79 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
80 | return x
81 |
82 | def get_attention_weights(self):
83 | return self.attn.get_attention_weights()
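84 | 
85 | 
86 | # --- Minimal shape-check sketch (illustration only, not part of the original code) ---
87 | # Assuming a batch of 2 slides, 17 tokens (16 visual-word tokens plus one class
88 | # token, matching vw_num=16 in the model) and a 1024-d embedding, one Block keeps
89 | # the [B, N, C] shape and exposes its attention map via get_attention_weights().
90 | if __name__ == "__main__":
91 |     block = Block(dim=1024, num_heads=8, qkv_bias=True)
92 |     tokens = torch.randn(2, 17, 1024)
93 |     out = block(tokens)                   # -> torch.Size([2, 17, 1024])
94 |     attn = block.get_attention_weights()  # -> torch.Size([2, 8, 17, 17])
95 |     print(out.shape, attn.shape)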
--------------------------------------------------------------------------------
/train/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python train.py --mpool_intra 'global_max_pool' --seed 1 --fold_num 0 --label "tcga_argo_sage_max_1024featdim_fold0_superpixel600_cluster16_numhead8_lr1e-5_30epoch_l2regalpha0.001_batchsize_16_split0_depth1"
4 | python train.py --mpool_intra 'global_max_pool' --seed 1 --fold_num 1 --label "tcga_argo_sage_max_1024featdim_fold1_superpixel600_cluster16_numhead8_lr1e-5_30epoch_l2regalpha0.001_batchsize_16_split0_depth1"
5 | python train.py --mpool_intra 'global_max_pool' --seed 1 --fold_num 2 --label "tcga_argo_sage_max_1024featdim_fold2_superpixel600_cluster16_numhead8_lr1e-5_30epoch_l2regalpha0.001_batchsize_16_split0_depth1"
6 | python train.py --mpool_intra 'global_max_pool' --seed 1 --fold_num 3 --label "tcga_argo_sage_max_1024featdim_fold3_superpixel600_cluster16_numhead8_lr1e-5_30epoch_l2regalpha0.001_batchsize_16_split0_depth1"
7 | python train.py --mpool_intra 'global_max_pool' --seed 1 --fold_num 4 --label "tcga_argo_sage_max_1024featdim_fold4_superpixel600_cluster16_numhead8_lr1e-5_30epoch_l2regalpha0.001_batchsize_16_split0_depth1"
8 |
9 | # Adjust --label (and the other arguments) to match the dataset and hyperparameters being used.
10 |
--------------------------------------------------------------------------------
/train/superpixel_transformer_n.py:
--------------------------------------------------------------------------------
1 | from timm.models.vision_transformer import VisionTransformer
2 | import timm.models.vision_transformer
3 | import skimage.io as io
4 | import argparse
5 | import joblib
6 | import copy
7 | import random
8 | import math, os
9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
10 |
11 | import skimage.io as io
12 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_, lecun_normal_, PatchEmbed
13 | from timm.models.helpers import build_model_with_cfg, named_apply
14 | from torch_geometric.nn import global_mean_pool,global_max_pool,GlobalAttention,dense_diff_pool,global_add_pool,TopKPooling,ASAPooling,SAGPooling
15 | from torch_geometric.nn import GCNConv,ChebConv,SAGEConv,GraphConv,LEConv,LayerNorm,GATConv
16 | import torch
17 | from sklearn.metrics import accuracy_score,f1_score,roc_auc_score
18 | import torch.nn as nn
19 | torch.set_num_threads(8)
20 | import numpy as np
21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22 | from functools import partial
23 | from block_utils import Block
24 | from torch_geometric.data import Data as geomData
25 | from timm.models.layers import trunc_normal_
26 | from torch_scatter import scatter_add
27 | from torch_geometric.utils import softmax
28 | try:
29 | from torch import _assert
30 | except ImportError:
31 | def _assert(condition: bool, message: str):
32 | assert condition, message
33 |
34 |
35 | def reset(nn):
36 | def _reset(item):
37 | if hasattr(item, 'reset_parameters'):
38 | item.reset_parameters()
39 |
40 | if nn is not None:
41 | if hasattr(nn, 'children') and len(list(nn.children())) > 0:
42 | for item in nn.children():
43 | _reset(item)
44 | else:
45 | _reset(nn)
46 |
47 | class my_GlobalAttention(torch.nn.Module):
48 | def __init__(self, gate_nn, nn=None):
49 | super(my_GlobalAttention, self).__init__()
50 | self.gate_nn = gate_nn
51 | self.nn = nn
52 |
53 | self.reset_parameters()
54 |
55 | def reset_parameters(self):
56 | reset(self.gate_nn)
57 | reset(self.nn)
58 |
59 |
60 | def forward(self, x, batch, size=None):
61 | """"""
62 | x = x.unsqueeze(-1) if x.dim() == 1 else x
63 | size = batch.max().item() + 1 if size is None else size #modified
64 |
65 | gate = self.gate_nn(x).view(-1, 1)
66 | x = self.nn(x) if self.nn is not None else x
67 | assert gate.dim() == x.dim() and gate.size(0) == x.size(0)
68 |
69 | gate = softmax(gate, batch, num_nodes=size)
70 | out = scatter_add(gate * x, batch, dim=0, dim_size=size)
71 |
72 | return out
73 |
74 |
75 | class Intra_GCN(nn.Module):
76 | def __init__(self,in_feats,n_hidden,out_feats,drop_out_ratio=0.2,mpool_method="global_mean_pool",gnn_method='sage'):
77 | super(Intra_GCN,self).__init__()
78 | if gnn_method == 'sage':
79 | self.conv1= SAGEConv(in_channels=in_feats,out_channels=out_feats)
80 | elif gnn_method == 'gcn':
81 | self.conv1= GCNConv(in_channels=in_feats,out_channels=out_feats)
82 | elif gnn_method == 'gat':
83 | self.conv1= GATConv(in_channels=in_feats,out_channels=out_feats)
84 | elif gnn_method == 'leconv':
85 | self.conv1= LEConv(in_channels=in_feats,out_channels=out_feats)
86 | elif gnn_method == 'graphconv':
87 | self.conv1= GraphConv(in_channels=in_feats,out_channels=out_feats)
88 | # self.conv2= SAGEConv(in_channels=n_hidden,out_channels=out_feats)
89 |
90 | self.relu = torch.nn.ReLU()
91 | self.sigmoid = torch.nn.Sigmoid()
92 | self.dropout=nn.Dropout(p=drop_out_ratio)
93 | self.softmax = nn.Softmax(dim=-1)
94 |
95 | if mpool_method == "global_mean_pool":
96 | self.mpool = global_mean_pool
97 | elif mpool_method == "global_max_pool":
98 | self.mpool = global_max_pool
99 | elif mpool_method == "global_att_pool":
100 | att_net=nn.Sequential(nn.Linear(out_feats, out_feats//2), nn.ReLU(), nn.Linear(out_feats//2, 1))
101 | self.mpool = my_GlobalAttention(att_net)
102 | self.norm = LayerNorm(in_feats)
103 | self.norm2 = LayerNorm(out_feats)
104 | self.norm1 = LayerNorm(n_hidden)
105 |
106 | def forward(self,data):
107 | x=data.x
108 | edge_index = data.edge_patch
109 |
110 | x = self.norm(x)
111 | x = self.conv1(x,edge_index)
112 | x = self.relu(x)
113 | # x = self.sigmoid(x)
114 | # x = self.norm(x)
115 | x = self.norm1(x)
116 | x = self.dropout(x)
117 |
118 | # x = self.conv2(x,edge_index)
119 | # x = self.relu(x)
120 | # # x = self.sigmoid(x)
121 | # # x = self.norm(x)
122 | # x = self.norm2(x)
123 | # x = self.dropout(x)
124 | # print(x)
125 |
126 | # batch = torch.zeros(len(x),dtype=torch.long).to(device)
127 | batch = data.superpixel_attri.to(device)
128 | x = self.mpool(x,batch)
129 | # print('fea dim is {}'.format(x.shape))
130 | # print(x)
131 |
132 | fea = x
133 | # print(fea.shape)
134 |
135 | return fea
136 |
137 | class Inter_GCN(nn.Module):
138 | def __init__(self,in_feats,n_hidden,out_feats,drop_out_ratio=0.2,mpool_method="global_mean_pool",gnn_method='sage'):
139 | super(Inter_GCN,self).__init__()
140 | # self.conv1= SAGEConv(in_channels=in_feats,out_channels=out_feats)
141 | # self.conv2= SAGEConv(in_channels=n_hidden,out_channels=out_feats)
142 | if gnn_method == 'sage':
143 | self.conv1= SAGEConv(in_channels=in_feats,out_channels=out_feats)
144 | elif gnn_method == 'gcn':
145 | self.conv1= GCNConv(in_channels=in_feats,out_channels=out_feats)
146 | elif gnn_method == 'gat':
147 | self.conv1= GATConv(in_channels=in_feats,out_channels=out_feats)
148 | elif gnn_method == 'leconv':
149 | self.conv1= LEConv(in_channels=in_feats,out_channels=out_feats)
150 | elif gnn_method == 'graphconv':
151 | self.conv1= GraphConv(in_channels=in_feats,out_channels=out_feats)
152 |
153 | self.relu = torch.nn.ReLU()
154 | self.sigmoid = torch.nn.Sigmoid()
155 | self.dropout=nn.Dropout(p=drop_out_ratio)
156 | self.softmax = nn.Softmax(dim=-1)
157 |
158 | if mpool_method == "global_mean_pool":
159 | self.mpool = global_mean_pool
160 | elif mpool_method == "global_max_pool":
161 | self.mpool = global_max_pool
162 | elif mpool_method == "global_att_pool":
163 | att_net=nn.Sequential(nn.Linear(out_feats, out_feats//2), nn.ReLU(), nn.Linear(out_feats//2, 1))
164 | self.mpool = my_GlobalAttention(att_net)
165 | self.norm = LayerNorm(in_feats)
166 | self.norm2 = LayerNorm(out_feats)
167 | self.norm1 = LayerNorm(n_hidden)
168 |
169 | def forward(self,data,feature):
170 | x=feature
171 | edge_index = data.edge_superpixel
172 | # print(x.shape)
173 | x = self.norm(x)
174 | x = self.conv1(x,edge_index)
175 | x = self.relu(x)
176 | # x = self.sigmoid(x)
177 | x = self.norm1(x)
178 | x = self.dropout(x)
179 |
180 | # x = self.conv2(x,edge_index)
181 | # x = self.relu(x)
182 | # # x = self.sigmoid(x)
183 | # x = self.norm2(x)
184 | # x = self.dropout(x)
185 |
186 | # batch = torch.zeros(len(x),dtype=torch.long).to(device)
187 | # x = self.mpool(x,batch)
188 |
189 | fea = x
190 | # print(x.shape)
191 | # print(fea.shape)
192 |
193 | return fea
194 |
195 |
196 |
197 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
198 | """ Vision Transformer with support for global average pooling
199 | """
200 | def __init__(self, num_patches=100,no_embed_class=False,class_token=True, depth=1,drop_path_rate=0.,mlp_ratio=4.,pre_norm=True,qkv_bias=True,init_values=None,drop_rate=0.,attn_drop_rate=0.,norm_layer=None,act_layer=None,weight_init='',global_pool='token', fc_norm=None,**kwargs):
201 | super(VisionTransformer, self).__init__(**kwargs)
202 | embed_dim = kwargs['embed_dim']
203 | self.patch_embed = nn.Linear(embed_dim,embed_dim)
204 | num_patches = num_patches
205 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
206 | act_layer = act_layer or nn.GELU
207 | self.no_embed_class = no_embed_class
208 | self.global_pool = global_pool
209 |
210 | self.num_prefix_tokens = 1 if class_token else 0
211 | embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
212 | self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
213 | self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
214 | use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
215 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
216 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
217 |
218 | self.blocks = nn.Sequential(*[
219 | Block(
220 | dim=embed_dim,
221 | num_heads=kwargs['num_heads'],
222 | mlp_ratio=mlp_ratio,
223 | qkv_bias=qkv_bias,
224 | init_values=init_values,
225 | drop=drop_rate,
226 | attn_drop=attn_drop_rate,
227 | drop_path=dpr[i],
228 | norm_layer=norm_layer,
229 | act_layer=act_layer
230 | )
231 | for i in range(depth)])
232 | self.init_weights(weight_init)
233 |
234 | def init_weights(self, mode=''):
235 | assert mode in ('jax', 'jax_nlhb', 'moco', '')
236 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
237 | trunc_normal_(self.pos_embed, std=.02)
238 | if self.cls_token is not None:
239 | nn.init.normal_(self.cls_token, std=1e-6)
240 | named_apply(get_init_weights_vit(mode, head_bias), self)
241 |
242 | def _pos_embed(self, x):
243 | if self.no_embed_class:
244 | # deit-3, updated JAX (big vision)
245 | # position embedding does not overlap with class token, add then concat
246 | x = x + self.pos_embed
247 | if self.cls_token is not None:
248 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
249 | else:
250 | # original timm, JAX, and deit vit impl
251 | # pos_embed has entry for class token, concat then add
252 | if self.cls_token is not None:
253 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
254 | x = x + self.pos_embed
255 | return self.pos_drop(x)
256 |
257 | def forward_features(self, x):
258 | x = self.patch_embed(x)
259 | x = self._pos_embed(x)
260 | x = self.norm_pre(x)
261 | # if self.grad_checkpointing and not torch.jit.is_scripting():
262 | # x = checkpoint_seq(self.blocks, x)
263 | # else:
264 | x = self.blocks(x)
265 | x = self.norm(x)
266 | return x
267 |
268 | def forward_head(self, x, pre_logits: bool = False):
269 | if self.global_pool:
270 | x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] #token
271 | x = self.fc_norm(x)
272 | return x if pre_logits else self.head(x)
273 |
274 | def forward(self, x):
275 | x = self.forward_features(x)
276 | x = self.forward_head(x)
277 | return torch.sigmoid(x)
278 |
279 | def get_attention_weights(self):
280 | return [block.get_attention_weights() for block in self.blocks]
281 |
282 | def init_weights_vit_timm(module: nn.Module, name: str = ''):
283 | """ ViT weight initialization, original timm impl (for reproducibility) """
284 | if isinstance(module, nn.Linear):
285 | trunc_normal_(module.weight, std=.02)
286 | if module.bias is not None:
287 | nn.init.zeros_(module.bias)
288 | elif hasattr(module, 'init_weights'):
289 | module.init_weights()
290 |
291 |
292 | def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.):
293 | """ ViT weight initialization, matching JAX (Flax) impl """
294 | if isinstance(module, nn.Linear):
295 | if name.startswith('head'):
296 | nn.init.zeros_(module.weight)
297 | nn.init.constant_(module.bias, head_bias)
298 | else:
299 | nn.init.xavier_uniform_(module.weight)
300 | if module.bias is not None:
301 | nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
302 | elif isinstance(module, nn.Conv2d):
303 | lecun_normal_(module.weight)
304 | if module.bias is not None:
305 | nn.init.zeros_(module.bias)
306 | elif hasattr(module, 'init_weights'):
307 | module.init_weights()
308 |
309 |
310 | def init_weights_vit_moco(module: nn.Module, name: str = ''):
311 | """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """
312 | if isinstance(module, nn.Linear):
313 | if 'qkv' in name:
314 | # treat the weights of Q, K, V separately
315 | val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
316 | nn.init.uniform_(module.weight, -val, val)
317 | else:
318 | nn.init.xavier_uniform_(module.weight)
319 | if module.bias is not None:
320 | nn.init.zeros_(module.bias)
321 | elif hasattr(module, 'init_weights'):
322 | module.init_weights()
323 |
324 |
325 | def get_init_weights_vit(mode='jax', head_bias: float = 0.):
326 | if 'jax' in mode:
327 | return partial(init_weights_vit_jax, head_bias=head_bias)
328 | elif 'moco' in mode:
329 | return init_weights_vit_moco
330 | else:
331 | return init_weights_vit_timm
332 |
333 |
334 | class Superpixel_Vit(nn.Module):
335 | def __init__(self,in_feats_intra=1024,n_hidden_intra=1024,out_feats_intra=1024,in_feats_inter=1024,n_hidden_inter=1024,out_feats_inter=1024,vw_num=16,feat_dim=1024,num_classes=1,depth=1,num_heads = 16,final_fea_type = 'mean',mpool_intra='global_mean_pool',mpool_inter='global_mean_pool',gnn_intra='sage',gnn_inter='sage'):
336 | super(Superpixel_Vit, self).__init__()
337 |
338 | self.vw_num = vw_num
339 | self.feat_dim = feat_dim
340 |
341 | #intra-graph
342 | self.gcn1 = Intra_GCN(in_feats=in_feats_intra,n_hidden=n_hidden_intra,out_feats=out_feats_intra,mpool_method=mpool_intra,gnn_method=gnn_intra)
343 |
344 | #inter-graph
345 | self.gcn2 = Inter_GCN(in_feats=in_feats_inter,n_hidden=n_hidden_inter,out_feats=out_feats_inter,mpool_method=mpool_inter,gnn_method=gnn_inter)
346 | self.vit = VisionTransformer(num_patches = vw_num,num_classes = num_classes, embed_dim = feat_dim,depth = depth,num_heads = num_heads)
347 |
348 | self.final_fea_type = final_fea_type
349 |
350 | def superpixel_graph(self,data):
351 | superpixel_attri = data.superpixel_attri
352 | min_value = int(min(superpixel_attri))
353 | #intra-graph
354 | superpixel_fea = self.gcn1(data)
355 |         #inter-graph
356 | superpixel_feas = self.gcn2(data,superpixel_fea)
357 | if min_value == 0:
358 | print('min superpixel value is 0')
359 | superpixel_fea_all={}
360 | for index in range(1,superpixel_feas.shape[0]):
361 | fea = superpixel_feas[index].unsqueeze(0)
362 | superpixel_value = index
363 | superpixel_fea_all[superpixel_value] = fea
364 | else:
365 | print('min superpixel value is 1')
366 | superpixel_fea_all={}
367 | for index in range(superpixel_feas.shape[0]):
368 | fea = superpixel_feas[index].unsqueeze(0)
369 | superpixel_value = index+1
370 | superpixel_fea_all[superpixel_value] = fea
371 | return superpixel_fea_all
372 |
373 |
374 | def mean_feature(self,superpixel_features,cluster_info):
375 | mask=np.zeros((self.vw_num,self.feat_dim))
376 | mask=torch.tensor(mask).to(device)
377 | # superpixel_cluster_path = os.path.join(cluster_info_path,slidename+'.pth')
378 | # cluster_info = torch.load(superpixel_cluster_path)
379 | for vw in range(self.vw_num):
380 | fea_all=torch.Tensor().to(device)
381 | for superpixel_value in cluster_info.keys():
382 | if cluster_info[superpixel_value]['cluster']==vw:
383 | if fea_all.shape[0]==0:
384 | fea_all=superpixel_features[superpixel_value]
385 | else:
386 | fea_all=torch.cat((fea_all,superpixel_features[superpixel_value]),dim=0)
387 | if fea_all.shape[0]!=0:
388 | fea_avg=torch.mean(fea_all,axis=0)
389 | # print('fea_avg shape:{}'.format(fea_avg.shape))
390 | mask[vw]=fea_avg
391 | return mask
392 |
393 | def max_feature(self,superpixel_features,cluster_info):
394 | mask=np.zeros((self.vw_num,self.feat_dim))
395 | mask=torch.tensor(mask).to(device)
396 | # superpixel_cluster_path = os.path.join(cluster_info_path,slidename+'.pth')
397 | # cluster_info = torch.load(superpixel_cluster_path)
398 | for vw in range(self.vw_num):
399 | fea_all=torch.Tensor().to(device)
400 | for superpixel_value in cluster_info.keys():
401 | if cluster_info[superpixel_value]['cluster']==vw:
402 | if fea_all.shape[0]==0:
403 | fea_all=superpixel_features[superpixel_value]
404 | else:
405 | fea_all=torch.cat((fea_all,superpixel_features[superpixel_value]),dim=0)
406 | if fea_all.shape[0]!=0:
407 | fea_max,_=torch.max(fea_all,dim=0)
408 | # print('fea_avg shape:{}'.format(fea_avg.shape))
409 | mask[vw]=fea_max
410 | return mask
411 |
412 | def forward(self,data,cluster_info):
413 |
414 | superpixels_fea = self.superpixel_graph(data)
415 | #final-fea
416 | if self.final_fea_type == 'mean':
417 | fea = self.mean_feature(superpixels_fea,cluster_info) #[16,1024]
418 | elif self.final_fea_type == 'max':
419 | fea = self.max_feature(superpixels_fea,cluster_info)
420 | fea = fea.unsqueeze(0) #[1,16,1024]
421 | # print(fea.shape)
422 | # print(fea.shape)
423 | #vit
424 | fea = fea.float()
425 | out = self.vit(fea)
426 | return out
427 |
428 | def get_attention_weights(self):
429 | return self.vit.get_attention_weights()
430 |
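431 | 
432 | # --- Minimal smoke-test sketch (illustration only, not part of the original code) ---
433 | # Builds a tiny dummy hierarchical graph and cluster assignment using the
434 | # attribute names read in the forward pass above (x, edge_patch, edge_superpixel,
435 | # superpixel_attri); the sizes are arbitrary. Real inputs come from
436 | # graph_construction.ipynb and cluster.py.
437 | if __name__ == "__main__":
438 |     num_patches, num_superpixels = 40, 5
439 |     dummy_graph = geomData(
440 |         x=torch.randn(num_patches, 1024),                             # patch features
441 |         edge_patch=torch.randint(0, num_patches, (2, 120)),           # patch-level edges
442 |         edge_superpixel=torch.randint(0, num_superpixels, (2, 10)),   # superpixel-level edges
443 |         superpixel_attri=torch.arange(num_patches) % num_superpixels, # patch -> superpixel id
444 |     ).to(device)
445 |     # assign every non-background superpixel id (1..4 here) to one of the 16 visual words
446 |     dummy_cluster_info = {sp: {'cluster': sp % 16} for sp in range(1, num_superpixels)}
447 |     model = Superpixel_Vit().to(device)
448 |     risk = model(dummy_graph, dummy_cluster_info)
449 |     print(risk.shape)  # expected: torch.Size([1, 1])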
--------------------------------------------------------------------------------
/train/train.py:
--------------------------------------------------------------------------------
1 | from timm.models.vision_transformer import VisionTransformer
2 | import timm.models.vision_transformer
3 | import skimage.io as io
4 | import argparse
5 | import joblib
6 | import copy
7 | import random
8 | import os
9 | os.environ['CUDA_VISIBLE_DEVICES'] = '7'
10 |
11 | import skimage.io as io
12 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_,PatchEmbed
13 | # from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq
14 | import torch
15 | from sklearn.metrics import accuracy_score,f1_score,roc_auc_score
16 | # from model_utils import Block,DropPath,get_sinusoid_encoding_table
17 | import torch.nn as nn
18 | torch.set_num_threads(8)
19 | # from lifelines.utils import concordance_index
20 | import numpy as np
21 | from lifelines.utils import concordance_index as ci
22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23 | from functools import partial
24 | from block_utils import Block
25 | from superpixel_transformer_n import Superpixel_Vit
26 |
27 | try:
28 | from torch import _assert
29 | except ImportError:
30 | def _assert(condition: bool, message: str):
31 | assert condition, message
32 |
33 |
34 | def _neg_partial_log(prediction, T, E):
35 |
36 | current_batch_len = len(prediction)
37 | R_matrix_train = np.zeros([current_batch_len, current_batch_len], dtype=int)
38 | for i in range(current_batch_len):
39 | for j in range(current_batch_len):
40 | R_matrix_train[i, j] = T[j] >= T[i]
41 |
42 | train_R = torch.FloatTensor(R_matrix_train)
43 | train_R = train_R.cuda()
44 |
45 | # train_ystatus = torch.tensor(np.array(E),dtype=torch.float).to(device)
46 | train_ystatus = E
47 |
48 | theta = prediction.reshape(-1)
49 |
50 | exp_theta = torch.exp(theta)
51 | loss_nn = - torch.mean((theta - torch.log(torch.sum(exp_theta * train_R, dim=1))) * train_ystatus)
52 |
53 | return loss_nn
54 |
55 | def get_val_ci(pre_time,patient_and_time,patient_sur_type):
56 | ordered_time, ordered_pred_time, ordered_observed=[],[],[]
57 | for x in pre_time:
58 | ordered_time.append(patient_and_time[x])
59 | ordered_pred_time.append(pre_time[x]*-1)
60 | ordered_observed.append(patient_sur_type[x])
61 | # print(len(ordered_time), len(ordered_pred_time), len(ordered_observed))
62 | return ci(ordered_time, ordered_pred_time, ordered_observed)
63 |
64 |
65 | def get_train_val(train_slide,val_slide,test_slide,label_path):
66 | train_slidename=[]
67 | train_label_censorship=[]
68 | train_label_survtime=[]
69 | val_slidename=[]
70 | val_label_censorship=[]
71 | val_label_survtime=[]
72 | test_slidename=[]
73 | test_label_censorship=[]
74 | test_label_survtime=[]
75 | label = torch.load(label_path)
76 |
77 | for index in train_slide:
78 | train_slidename.append(index)
79 | train_label_censorship.append(label[index]['censorship'])
80 | train_label_survtime.append(label[index]['surv_time'])
81 |
82 | for index1 in val_slide:
83 | val_slidename.append(index1)
84 | val_label_censorship.append(label[index1]['censorship'])
85 | val_label_survtime.append(label[index1]['surv_time'])
86 |
87 | for index2 in test_slide:
88 | test_slidename.append(index2)
89 | test_label_censorship.append(label[index2]['censorship'])
90 | test_label_survtime.append(label[index2]['surv_time'])
91 |
92 | return train_slidename,val_slidename,test_slidename,train_label_censorship,train_label_survtime,val_label_censorship,val_label_survtime,test_label_censorship,test_label_survtime
93 |
94 |
95 |
96 | def eval_model(model,test_slide,test_label_surv_type,test_label_time,batch_size,cuda=True):
97 | model = model.cuda()
98 | model = model.to(device)
99 | model = model.eval()
100 |
101 | #print('begin evaluation!')
102 | with torch.no_grad():
103 | running_loss=0.
104 | test_loss = 0.
105 | t_out_pre = torch.Tensor().to(device)
106 | t_labelall_surv_type = torch.Tensor().to(device)
107 | t_labelall_time = torch.Tensor().to(device)
108 |
109 | atten_dict = {}
110 | total_loss_for_test = 0
111 | # batch_id = 0
112 | test_num = len(test_slide)
113 | # batch divide
114 | iter_num = test_num // batch_size + 1
115 |
116 | # batch
117 | for batch_iter in range(iter_num):
118 | outs = torch.Tensor().to(device)
119 | labels_surv_type = torch.Tensor().to(device)
120 | labels_time = torch.Tensor().to(device)
121 |
122 | if (batch_iter + 1) == iter_num: # last batch
123 | for sample_index in range(batch_iter * batch_size, test_num):
124 | # print("Sample Index: ", sample_index)
125 |
126 | slidename = test_slide[sample_index]
127 | pyg_data = torch.load(os.path.join(args.pyg_path,slidename+'.pt')).to(device)
128 | print(pyg_data)
129 | label_surv_type = torch.tensor(test_label_surv_type[sample_index]).unsqueeze(0).to(device)
130 | label_time = torch.tensor(test_label_time[sample_index]).unsqueeze(0).to(device)
131 |
132 | cluster_info = torch.load(os.path.join(args.cluster_info_path,slidename+'.pth'))
133 | # model train
134 | output = model(pyg_data,cluster_info)
135 | atten = model.get_attention_weights()
136 | average_atten = torch.mean(atten[0],dim=1) #record attention score
137 | atten_dict[label_time[0]] = average_atten[0]
138 | # save output and label to calculate loss
139 | if outs.shape[0] == 0:
140 | outs = output
141 | else:
142 | outs = torch.cat((outs, output), dim=0)
143 |
144 | if labels_surv_type.shape[0] == 0:
145 | labels_surv_type = label_surv_type
146 | else:
147 | labels_surv_type = torch.cat((labels_surv_type, label_surv_type), dim=0)
148 |
149 | if labels_time.shape[0] == 0:
150 | labels_time = label_time
151 | else:
152 | labels_time = torch.cat((labels_time, label_time),dim=0)
153 |
154 | # save all samples results
155 |
156 | if t_out_pre.shape[0] == 0:
157 | t_out_pre = -1 * output
158 | else:
159 | t_out_pre = torch.cat((t_out_pre, -1 * output),dim=0)
160 |
161 | if t_labelall_surv_type.shape[0] == 0:
162 | t_labelall_surv_type = label_surv_type
163 | else:
164 | t_labelall_surv_type = torch.cat((t_labelall_surv_type, label_surv_type), dim=0)
165 |
166 | if t_labelall_time.shape[0] == 0:
167 | t_labelall_time = label_time
168 | else:
169 | t_labelall_time = torch.cat((t_labelall_time, label_time), dim=0)
170 | # calculate loss of the batch
171 | if torch.sum(labels_surv_type) > 0.0:
172 | print("outs.shape:",outs.shape,"labels_time.shape:",labels_time.shape)
173 | loss = _neg_partial_log(outs,labels_time,labels_surv_type)
174 | loss = args.cox_loss * loss
175 |
176 | total_loss_for_test += loss.item()
177 |                     print("Batch Avg Loss: {:.4f}".format(loss.item()))
178 |
179 | else:
180 | # batch
181 | for batch_index in range(batch_size):
182 | index = batch_iter * batch_size + batch_index
183 | # print("Sample Index: ", index)
184 |
185 | # acquire dat and label
186 | slidename = test_slide[index]
187 | pyg_data = torch.load(os.path.join(args.pyg_path,slidename+'.pt')).to(device)
188 | print(pyg_data)
189 | label_surv_type = torch.tensor(test_label_surv_type[index]).unsqueeze(0).to(device)
190 | label_time = torch.tensor(test_label_time[index]).unsqueeze(0).to(device)
191 |
192 | cluster_info = torch.load(os.path.join(args.cluster_info_path,slidename+'.pth'))
193 | # model train
194 | output = model(pyg_data,cluster_info)
195 | atten = model.get_attention_weights()
196 | average_atten = torch.mean(atten[0],dim=1)
197 | atten_dict[label_time[0]] = average_atten[0]
198 | # save the output and label
199 | if outs.shape[0] == 0:
200 | outs = output
201 | else:
202 | outs = torch.cat((outs, output), dim=0)
203 |
204 | if labels_surv_type.shape[0] == 0:
205 | labels_surv_type = label_surv_type
206 | else:
207 | labels_surv_type = torch.cat((labels_surv_type, label_surv_type), dim=0)
208 |
209 | if labels_time.shape[0] == 0:
210 | labels_time = label_time
211 | else:
212 | labels_time = torch.cat((labels_time, label_time),dim=0)
213 |
214 | # save all samples results
215 |
216 | if t_out_pre.shape[0] == 0:
217 | t_out_pre = -1 * output
218 | else:
219 | t_out_pre = torch.cat((t_out_pre, -1 * output),dim=0)
220 |
221 | if t_labelall_surv_type.shape[0] == 0:
222 | t_labelall_surv_type = label_surv_type
223 | else:
224 | t_labelall_surv_type = torch.cat((t_labelall_surv_type, label_surv_type), dim=0)
225 |
226 | if t_labelall_time.shape[0] == 0:
227 | t_labelall_time = label_time
228 | else:
229 | t_labelall_time = torch.cat((t_labelall_time, label_time), dim=0)
230 | #compute the loss
231 | if torch.sum(labels_surv_type) > 0.0:
232 | print("outs.shape:",outs.shape,"labels_time.shape:",labels_time.shape)
233 |
234 |
235 | loss = _neg_partial_log(outs,labels_time,labels_surv_type)
236 | loss = args.cox_loss * loss
237 |
238 | total_loss_for_test += loss.item()
239 |                     print("Batch Avg Loss: {:.4f}".format(loss.item()))
240 |
241 | c_idx_epochs_avg = ci(t_labelall_time.data.cpu(),t_out_pre.data.cpu(),t_labelall_surv_type.data.cpu())
242 | epoch_loss_test = total_loss_for_test / iter_num
243 | print("epoch_val_test_loss:{}".format(epoch_loss_test))
244 |
245 | return c_idx_epochs_avg,epoch_loss_test,atten_dict
246 |
247 |
248 | def train_model(n_epochs,model,optimizer,scheduler,train_slide,val_slide,test_slide,train_label_censorship,train_label_survtime,val_label_censorship,val_label_survtime,test_label_censorship,test_label_survtime,fold_num,batch_size,cuda=True):
249 | if cuda:
250 | model = model.to(device)
251 |
252 | os.makedirs("/data14/yanhe/miccai/train/saved_model/tcga_kirc/interpreatable_transformer_depth_"+str(args.depth)+"/seed_"+str(args.seed)+"/",exist_ok=True)
253 | os.makedirs("/data14/yanhe/miccai/train/log_result/tcga_kirc/interpreatable_transformer_depth_"+str(args.depth)+"/seed_"+str(args.seed),exist_ok=True)
254 | os.makedirs("/data14/yanhe/miccai/train/attention_weights/tcga_kirc/interpreatable_transformer_depth_"+str(args.depth)+"/seed_"+str(args.seed)+'/atten_dict/',exist_ok=True)
255 |
256 | best_loss = 1e9
257 | best_ci = 0.
258 | best_acc = 0.
259 | n_out_features = 1
260 | n_classes = [1]*n_out_features
261 | for epoch in range(n_epochs):
262 | model.train()
263 | # pbar.set_description("RepeatNum:{}/{} Seed:{} Fold:{}/{}".format(repeat_num_temp + 1,repeat_num,seed,fold_num,all_fold_num))
264 | total_loss_for_train = 0
265 | batch_id = 0
266 | out_pre = torch.Tensor().to(device)
267 | pre_for_batch = torch.Tensor().to(device)
268 | labelall_surv_type = torch.Tensor().to(device)
269 | label_surv_type_for_batch = torch.Tensor().to(device)
270 | label_time_for_batch = torch.Tensor().to(device)
271 | # print(labelall_surv_type.shape[0])
272 | labelall_time = torch.Tensor().to(device)
273 | train_num = len(train_slide)
274 | print("train_slide Num: ",len(train_slide))
275 |
276 |
277 | iter_num = train_num // batch_size + 1
278 |
279 | for batch_iter in range(iter_num):
280 | outs = torch.Tensor().to(device)
281 | labels_surv_type = torch.Tensor().to(device)
282 | labels_time = torch.Tensor().to(device)
283 |
284 | if (batch_iter + 1) == iter_num:
285 | for sample_index in range(batch_iter * batch_size, train_num):
286 | # print("Sample Index: ", sample_index)
287 |
288 | slidename = train_slide[sample_index]
289 | pyg_data = torch.load(os.path.join(args.pyg_path,slidename+'.pt')).to(device)
290 | print(pyg_data)
291 | label_surv_type = torch.tensor(train_label_censorship[sample_index]).unsqueeze(0).to(device)
292 | label_time = torch.tensor(train_label_survtime[sample_index]).unsqueeze(0).to(device)
293 |
294 | cluster_info = torch.load(os.path.join(args.cluster_info_path,slidename+'.pth'))
295 |
296 | output = model(pyg_data,cluster_info)
297 |
298 | if outs.shape[0] == 0:
299 | outs = output
300 | else:
301 | outs = torch.cat((outs, output), dim=0)
302 |
303 | if labels_surv_type.shape[0] == 0:
304 | labels_surv_type = label_surv_type
305 | else:
306 | labels_surv_type = torch.cat((labels_surv_type, label_surv_type), dim=0)
307 |
308 | if labels_time.shape[0] == 0:
309 | labels_time = label_time
310 | else:
311 | labels_time = torch.cat((labels_time, label_time),dim=0)
312 |
313 |
314 | if out_pre.shape[0] == 0:
315 | out_pre = -1 * output
316 | else:
317 | out_pre = torch.cat((out_pre, -1 * output),dim=0)
318 |
319 | if labelall_surv_type.shape[0] == 0:
320 | labelall_surv_type = label_surv_type
321 | else:
322 | labelall_surv_type = torch.cat((labelall_surv_type, label_surv_type), dim=0)
323 |
324 | if labelall_time.shape[0] == 0:
325 | labelall_time = label_time
326 | else:
327 | labelall_time = torch.cat((labelall_time, label_time), dim=0)
328 |
329 | if torch.sum(labels_surv_type) > 0.0:
330 | # print("outs.shape:",outs.shape,"labels_time.shape:",labels_time.shape)
331 |
332 |
333 | loss = _neg_partial_log(outs,labels_time,labels_surv_type)
334 | loss = args.cox_loss * loss
335 | optimizer.zero_grad()
336 | loss.backward()
337 | optimizer.step()
338 | total_loss_for_train += loss.item()
339 | # print("Batch Avg Loss: {:.4f}", loss.item())
340 |
341 | else:
342 |
343 | for batch_index in range(batch_size):
344 | index = batch_iter * batch_size + batch_index
345 | # print("Sample Index: ", index)
346 |
347 | slidename = train_slide[index]
348 | pyg_data = torch.load(os.path.join(args.pyg_path,slidename+'.pt')).to(device)
349 | print(pyg_data)
350 | label_surv_type = torch.tensor(train_label_censorship[index]).unsqueeze(0).to(device)
351 | label_time = torch.tensor(train_label_survtime[index]).unsqueeze(0).to(device)
352 |
353 | cluster_info = torch.load(os.path.join(args.cluster_info_path,slidename+'.pth'))
354 |
355 | output = model(pyg_data,cluster_info)
356 |
357 | if outs.shape[0] == 0:
358 | outs = output
359 | else:
360 | outs = torch.cat((outs, output), dim=0)
361 |
362 | if labels_surv_type.shape[0] == 0:
363 | labels_surv_type = label_surv_type
364 | else:
365 | labels_surv_type = torch.cat((labels_surv_type, label_surv_type), dim=0)
366 |
367 | if labels_time.shape[0] == 0:
368 | labels_time = label_time
369 | else:
370 | labels_time = torch.cat((labels_time, label_time),dim=0)
371 |
372 |
373 | if out_pre.shape[0] == 0:
374 | out_pre = -1 * output
375 | else:
376 | out_pre = torch.cat((out_pre, -1 * output),dim=0)
377 |
378 | if labelall_surv_type.shape[0] == 0:
379 | labelall_surv_type = label_surv_type
380 | else:
381 | labelall_surv_type = torch.cat((labelall_surv_type, label_surv_type), dim=0)
382 |
383 | if labelall_time.shape[0] == 0:
384 | labelall_time = label_time
385 | else:
386 | labelall_time = torch.cat((labelall_time, label_time), dim=0)
387 |
388 | if torch.sum(labels_surv_type) > 0.0:
389 | # print("outs.shape:",outs.shape,"labels_time.shape:",labels_time.shape)
390 |
391 |
392 | loss = _neg_partial_log(outs,labels_time,labels_surv_type)
393 | loss = args.cox_loss * loss
394 | optimizer.zero_grad()
395 | loss.backward()
396 | optimizer.step()
397 | total_loss_for_train += loss.item()
398 | # print("Batch Avg Loss: {:.4f}", loss.item())
399 |
400 |
401 |
402 | c_idx_for_train = ci(labelall_time.data.cpu(),out_pre.data.cpu(),labelall_surv_type.data.cpu())
403 |
404 | epoch_loss = total_loss_for_train / iter_num
405 |
406 | print("Epoch [{}/{}], epoch_loss {:.4f}".format(epoch+1,n_epochs, epoch_loss))
407 |
408 | #val
409 | val_c_idx,val_loss,_ = eval_model(model,val_slide,val_label_censorship,val_label_survtime,batch_size,cuda=cuda)
410 |
411 | #test
412 | test_c_idx,test_loss,_ =eval_model(model,test_slide,test_label_censorship,test_label_survtime,batch_size,cuda=cuda)
413 | # if scheduler is not None:
414 | # scheduler.step()
415 |
416 | with open('/data14/yanhe/miccai/train/log_result/tcga_kirc/interpreatable_transformer_depth_'+str(args.depth)+"/seed_"+str(args.seed)+'/'+args.label+'_random.log',"a") as f:
417 | f.write(f"EPOCH {epoch} : \n")
418 | f.write(f"train loss - {epoch_loss} train ci - {c_idx_for_train};\n")
419 | f.write(f"val loss - {val_loss} val ci -{val_c_idx};\n")
420 |             f.write(f"test loss - {test_loss} test ci - {test_c_idx};\n")
421 | if val_c_idx >= best_ci:
422 | best_epoch = epoch
423 | best_ci = val_c_idx
424 | t_model = copy.deepcopy(model)
425 | # if val_loss < best_loss:
426 | # best_loss = val_loss
427 | # best_epoch = epoch
428 | # t_model = copy.deepcopy(model)
429 | # t_model = copy.deepcopy(model)
430 | save_path = '/data14/yanhe/miccai/train/saved_model/tcga_kirc/interpreatable_transformer_depth_'+str(args.depth)+"/seed_"+str(args.seed)+'/fold_num_{}.pth'.format(fold_num)
431 | torch.save({'model_state_dict': t_model.state_dict(),
432 | 'optimizer_state_dict': optimizer.state_dict(),
433 | }, save_path)
434 | print("Model saved: %s" % save_path)
435 |
436 | t_test_c_idx,t_test_loss,atten_dict = eval_model(t_model,test_slide,test_label_censorship,test_label_survtime,batch_size,cuda=cuda)
437 |
438 | atten_dict_saved_path = '/data14/yanhe/miccai/train/attention_weights/tcga_kirc/interpreatable_transformer_depth_'+str(args.depth)+"/seed_"+str(args.seed)+'/atten_dict/'+args.label+'_foldnum_{}.pth'.format(fold_num)
439 | torch.save(atten_dict,atten_dict_saved_path)
440 | with open('/data14/yanhe/miccai/train/log_result/tcga_kirc/interpreatable_transformer_depth_'+str(args.depth)+"/seed_"+str(args.seed)+'/'+args.label+'_random.log',"a") as f:
441 | f.write(f"best model test ci value {t_test_c_idx} occurs at EPOCH {best_epoch} ;\n")
442 | # f.write(f"best model test ci value {t_test_c_idx} ;\n")
443 |
444 |
445 | def setup_seed(seed):
446 | torch.manual_seed(seed)
447 | os.environ['PYTHONHASHSEED'] = str(seed)
448 | torch.cuda.manual_seed(seed)
449 | torch.cuda.manual_seed_all(seed)
450 | np.random.seed(seed)
451 | random.seed(seed)
452 | torch.backends.cudnn.benchmark = False
453 | torch.backends.cudnn.deterministic = True
454 | torch.backends.cudnn.enabled = True
455 |
456 |
457 | def main(args):
458 | fold_num = args.fold_num
459 | label_path = args.label_path
460 | n_epochs = args.epochs
461 | batch_size = args.batch_size
462 | label = torch.load(label_path)
463 | split_dict_path = args.split_dict_path
464 | seed = args.seed
465 | pyg_path = args.pyg_path
466 | print("seed:{}".format(seed))
467 | #splits
468 | # split_dict=torch.load('/data14/yanhe/miccai/data/tcga_lihc/train_val_test_split.pkl')
469 | split_dict = joblib.load(split_dict_path)
470 | fold_num = args.fold_num
471 | fold_train = 'fold_'+str(fold_num)+'_train'
472 | fold_val = 'fold_'+str(fold_num)+'_val'
473 | fold_test = 'fold_'+str(fold_num)+'_test'
474 | train_slide = split_dict[fold_train]
475 | val_slide = split_dict[fold_val]
476 | test_slide = split_dict[fold_test]
477 | setup_seed(seed)
478 |
479 | model = Superpixel_Vit(in_feats_intra=args.in_feats_intra,
480 | n_hidden_intra=args.n_hidden_intra,
481 | out_feats_intra=args.out_feats_intra,
482 | in_feats_inter=args.in_feats_inter,
483 | n_hidden_inter=args.n_hidden_inter,
484 | out_feats_inter=args.out_feats_inter,
485 | vw_num=args.vw_num,
486 | feat_dim=args.feat_dim,
487 | num_classes=1,
488 | depth=args.depth,
489 | num_heads = args.num_heads,
490 | final_fea_type = args.final_fea_type,
491 | mpool_intra=args.mpool_intra,
492 | mpool_inter=args.mpool_inter,
493 | gnn_intra=args.gnn_intra,
494 | gnn_inter=args.gnn_inter
495 | )
496 |
497 | optimizer = torch.optim.AdamW([dict(params=model.parameters(), lr=args.lr, betas=(0.9, 0.95),weight_decay = args.l2_reg_alpha),])
498 | t_max = n_epochs
499 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=t_max, eta_min=0. )
500 |
501 |
502 | Train_slide,Val_slide,Test_slide,Train_Labels_censorship,Train_Labels_survtime,Val_Labels_censorship,Val_Labels_survtime,Test_Labels_censorship,Test_Labels_survtime = get_train_val(train_slide,val_slide,test_slide,label_path)
503 | train_model(n_epochs, model, optimizer, scheduler, Train_slide,Val_slide,Test_slide,Train_Labels_censorship,Train_Labels_survtime,Val_Labels_censorship,Val_Labels_survtime,Test_Labels_censorship,Test_Labels_survtime,fold_num,batch_size,cuda=True)
504 |
505 |
506 |
507 |
508 |
509 | def get_params():
510 | parser = argparse.ArgumentParser(description='model training')
511 |
512 | parser.add_argument('--label_path',type=str, default='/data12/yanhe/miccai/data/tcga_kirc/slide_label.pt')
513 | parser.add_argument('--split_dict_path',type=str, default='/data12/yanhe/miccai/data/tcga_kirc/train_val_test_split_random1.pkl')
514 | parser.add_argument('--pyg_path',type=str, default='/data14/yanhe/miccai/graph_file/tcga_kirc/superpixel_num_600')
515 | parser.add_argument('--cluster_info_path',type=str, default='/data14/yanhe/miccai/codebook/cluster_info/tcga_kirc/superpixel600_cluster16/all_fold' )
516 | parser.add_argument('--vw_num',type=int, default=16)
517 | parser.add_argument('--feat_dim',type =int,default=1024)
518 | parser.add_argument('--depth',type=int,default=1)
519 | parser.add_argument('--num_heads',type=int,default=4)
520 |     parser.add_argument('--mpool_intra',type=str,default='global_mean_pool') # 'global_mean_pool', 'global_max_pool', 'global_att_pool'
521 | parser.add_argument('--mpool_inter',type=str,default='global_mean_pool')
522 |     parser.add_argument('--gnn_intra',type=str,default='sage') # 'sage', 'gcn', 'gat', 'leconv', 'graphconv'
523 | parser.add_argument('--gnn_inter',type=str,default='sage')
524 | parser.add_argument('--in_feats_intra',type=int, default=1024)
525 | parser.add_argument('--n_hidden_intra',type=int, default=1024)
526 | parser.add_argument('--out_feats_intra',type=int,default=1024)
527 | parser.add_argument('--in_feats_inter',type=int,default=1024) #in_feats_inter=out_feats_intra
528 | parser.add_argument('--n_hidden_inter',type=int, default=1024)
529 | parser.add_argument('--out_feats_inter',type=int,default=1024) #out_feats_inter=feat_dim
530 | parser.add_argument('--final_fea_type',type=str,default='mean')
531 | parser.add_argument('--epochs', type=int, default=30)
532 | parser.add_argument('--batch_size',type=int,default=16)
533 | #parser.add_argument('--warmup_epochs', type=int, default= 40)
534 | parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate of model training")
535 | parser.add_argument("--l2_reg_alpha",type=float,default=0.001)
536 | parser.add_argument("--cox_loss",type=float,default=12)
537 | parser.add_argument("--seed",type=int, default=1)
538 | parser.add_argument("--fold_num",type=int, default=0)
539 | parser.add_argument("--label",type=str,default="tcga_lihc_fold_0_lr1e-5_30epoch")
540 |
541 | args, _ = parser.parse_known_args()
542 | return args
543 |
544 |
545 | if __name__ == '__main__':
546 | try:
547 | # tuner_params = nni.get_next_parameter()
548 | # logger.debug(tuner_params)
549 | # params = vars(merge_parameter(get_params(), tuner_params))
550 | # main(params)
551 | args=get_params()
552 | main(args)
553 | except Exception as exception:
554 | # logger.exception(exception)
555 | raise
556 |
557 |
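558 | # --- Expected input formats (inferred sketch, illustration only) -------------
559 | # The helpers above read the label file and the split dict with roughly the
560 | # structures below; the exact contents come from the provided .pt/.pkl files
561 | # and the preprocessing scripts. Slide names and values here are made up.
562 | def _make_dummy_inputs():
563 |     # label file, loaded via torch.load(args.label_path)
564 |     label = {'slide_001': {'censorship': 1, 'surv_time': 23.4},
565 |              'slide_002': {'censorship': 0, 'surv_time': 41.0}}
566 |     # split dict, loaded via joblib.load(args.split_dict_path)
567 |     split_dict = {'fold_0_train': ['slide_001'],
568 |                   'fold_0_val': ['slide_002'],
569 |                   'fold_0_test': ['slide_002']}
570 |     return label, split_dict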
--------------------------------------------------------------------------------