├── README.md
├── configs.py
├── data.py
├── demo
│   └── data
│       └── url.txt
├── image
│   └── README
│       ├── Intro.jpg
│       ├── SOI.jpg
│       ├── biomarker.jpg
│       ├── cell_state.jpg
│       ├── cell_type.jpg
│       ├── cell_type1.jpg
│       ├── cell_type2.jpg
│       ├── deconvolution1.jpg
│       ├── deconvolution2.jpg
│       ├── logo.png
│       ├── segmentation.jpg
│       └── tissue_compartment.jpg
├── infer.py
├── model
│   ├── __pycache__
│   │   └── arch.cpython-37.pyc
│   └── arch.py
├── requirements.txt
├── tcs
│   ├── tissue_compartment_BRCA.json
│   ├── tissue_compartment_CRC.json
│   ├── tissue_compartment_HCC.json
│   ├── tissue_compartment_HD.json
│   ├── tissue_compartment_LUSC.json
│   ├── tissue_compartment_Mix.json
│   ├── tissue_compartment_OVC.json
│   ├── tissue_compartment_PDAC.json
│   ├── tissue_compartment_PRAD.json
│   ├── tissue_compartment_RCC.json
│   ├── tissue_compartment_STAD.json
│   ├── tissue_compartment_Xenium.json
│   └── tissue_compartment_state.json
├── train.py
├── train_oneout_pannuke.py
├── train_state.py
├── tutorial
│   ├── data
│   │   └── url.txt
│   └── tutorial.ipynb
└── utils
    ├── __pycache__
    │   └── utils.cpython-37.pyc
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # HistoCell
2 |
3 | **HistoCell** is a **weakly-supervised deep learning framework** that elucidates **hierarchical spatial cellular information**, including **tissue compartments, single cell types and cell states**, from **histopathology images only**. This tutorial applies HistoCell to predict super-resolution spatial cellular information and illustrates representative applications. A link to the HistoCell paper will be provided soon. \
4 | Our website: http://histocell.qhdyr.net/index/index/index.html
5 |
6 |
7 |
8 | ## Environments
9 | ```sh
10 | conda create --name YOUR_ENV_NAME --file requirements.txt  # requirements.txt uses conda's explicit-list format (see its header)
11 | ```
12 |
13 | ## Data Format and Preprocessing
14 |
15 | ### Data Preparation
16 |
17 | * For model pre-training, HistoCell takes **scRNA-seq** and **spatial transcriptomics** data with paired **high-resolution histopathology images** as input.
18 |   * For histopathology images, we cut the paired images into tiles according to the pixel coordinates from the ST data. The preprocessing code can be found in **./tutorial/tutorial.ipynb**.
19 |   * For transcriptomics data, we apply deconvolution methods to obtain the cell composition used as supervision. Applicable methods include [CARD](https://www.nature.com/articles/s41587-022-01273-7), [RCTD](https://www.nature.com/articles/s41587-021-00830-w), [Tangram](https://www.nature.com/articles/s41592-021-01264-7), [Cell2location](https://www.nature.com/articles/s41587-021-01139-4), etc.
20 | * For model inference, HistoCell only requires histopathology images, either tiles or WSIs.
21 |   * WSIs should be formatted as .svs or .tif. To convert WSIs into tiles that can be processed in a high-throughput manner, we apply the toolbox from [CLAM](https://github.com/mahmoodlab/CLAM) for image segmentation, stitching and patching (see the example after this list). For TCGA diagnostic images, patches of 256x256 pixels are recommended.
22 |   * Tiles can be stored in any image format.
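
To patch WSIs with CLAM, a command along the following lines is used (the flags are taken from CLAM's own README; paths are placeholders and the interface may differ across CLAM versions):

```sh
python create_patches_fp.py --source WSI_DIRECTORY --save_dir PATCH_DIRECTORY \
    --patch_size 256 --seg --patch --stitch
```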
23 |
24 | ### Data Preprocessing
25 |
26 | Using GPUs is highly recommended for whole-slide image (WSI) processing.
27 |
28 | Before we begin the model inference process, cell segmentation is required. Here, we apply [HoVerNet](https://www.sciencedirect.com/science/article/pii/S1361841519301045) [[code](https://github.com/vqdang/hover_net)], a well-developed nuclei instance segmentation method, to perform cell segmentation on image tiles. It is worth mentioning that other segmentation methods (e.g. [MaskRCNN](https://arxiv.org/abs/1703.06870), [Cerberus](https://www.sciencedirect.com/science/article/pii/S1361841522003139), etc.) are also applicable here.
29 |
30 | With cell segmentation by HoVerNet, you will get a JSON file with the segmentation results, illustrated below:
31 |
32 | ```json
33 | {"mag": null, "nuc": {"YOUR_CELL_ID": {"bbox": [[1, 2], [3, 4]], "centroid": [2, 3], "contour": [[0, 1]], "type_prob": 0.9, "type": 0}}}
34 | ```
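
For orientation, a minimal sketch of parsing such a file with the standard library (the field names follow the sample above; the helper name is ours, not the repository's):

```python
import json

def load_nuclei(json_path):
    """Collect per-nucleus centroid, bbox and type from a HoVerNet JSON file."""
    with open(json_path, "r") as fp:
        seg = json.load(fp)
    nuclei = {}
    for cell_id, nuc in seg["nuc"].items():
        nuclei[cell_id] = {
            "centroid": nuc["centroid"],  # [x, y] within the tile, as in the sample
            "bbox": nuc["bbox"],          # [[x_a, y_a], [x_b, y_b]]
            "type": nuc["type"],          # HoVerNet nucleus type id (may be None)
        }
    return nuclei
```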
35 |
36 | ### Input data format
37 | The tiled images, the segmentation results (JSON files) and the cell type/state proportions given by deconvolution are required during the training process. A hypothetical layout (inferred from the paths in **configs.py** and **data.py**) is shown below.
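
For reference, the data loaders expect roughly the following layout, illustrated for the BRCA demo with sample prefix `H1` (file names are illustrative):

```
demo/data/BRCA/
├── tiles/H1/<slide>_{x}x{y}.jpg        # image tiles named by spot coordinates
├── seg/H1/json/<slide>_{x}x{y}.json    # HoVerNet segmentation, one JSON per tile
└── cell_proportion/type/RCTD/H1.tsv    # rows indexed by {x}_{y}_{barcode}; one column per cell type
```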
38 |
39 | ## Data Accession
40 | Demo and tutorial data can be accessed on Google Drive via [Demo](https://drive.google.com/file/d/1DzwsFvD4wWKNN1wJCe3YsQZlk9xIp8DV/view?usp=sharing) and [Tutorial](https://drive.google.com/file/d/1iwXzxafdUF7SF6IAEyWluN728WbxAMEi/view?usp=sharing), respectively. To run the demo, download the data to **./demo/data**.
41 |
42 | ## Model pre-training with ST data
43 | Run **train.py** for model pre-training on cell type prediction. You need to change the parameters in **configs.py**, including the model information, dataset directories and training details. Also choose the proper tissue compartment file in **./tcs**: its "dict" field maps each deconvolution cell-type column to a compartment, "list" enumerates the compartments, and "HoVerNet" maps HoVerNet nucleus type ids to compartment indices in "list" (see **./tcs/tissue_compartment_BRCA.json** for an example). To train your own model, run the command below.
44 | ```sh
45 | python train.py \
46 | --model YOUR_MODEL_DESC \
47 | --tissue TISSUE_TYPE \
48 | --deconv DECONVOLUTION_METHOD \
49 | --prefix SAMPLE_PREFIX \
50 | --k_class NUM_OF_CLASS \
51 | --tissue_compartment PATH_TO_TCS_FILE
52 | ```
53 | A simple demo on subject H1 from the HER2ST dataset:
54 | ```sh
55 | python train.py \
56 | --model Breast_Benchmark_H1 \
57 | --tissue BRCA \
58 | --deconv RCTD \
59 | --prefix H1 \
60 | --k_class 6 \
61 | --tissue_compartment ./tcs/tissue_compartment_BRCA.json
62 | ```
63 |
64 | For the benchmark analysis on the PanNuke dataset, run the leave-one-out cross-validation command:
65 | ```sh
66 | python train_oneout_pannuke.py \
67 | --model PanNuke_cross_validation \
68 | --tissue Breast \
69 | --reso 1 \
70 | --deconv Mix \
71 | --prefix fold1 fold2 fold3 \
72 | --k_class 5 \
73 | --tissue_compartment ./tcs/tissue_compartment_Mix.json
74 | ```
75 |
76 | For cell states, run **train_state.py** to train the cell state prediction model. For example, to train cell state prediction on the breast cancer ST samples:
77 | ```sh
78 | python train_state.py \
79 | --model brca_state \
80 | --tissue BRCA \
81 | --deconv RCTD \
82 | --prefix 10x_BRCA A_1938345_11 B_1938529_9 C_2000752_23 D_2000910_33 \
83 | --tissue_compartment ./tcs/tissue_compartment_state.json
84 | ```
85 |
86 | ## Quick inference with histopathology images only
87 | With the pretrained model, you can infer the cellular spatial profile with **infer.py**.
88 | ```sh
89 | python infer.py \
90 | --model Breast_Benchmark_H1 \
91 | --epoch 30 \
92 | --tissue BRCA \
93 | --deconv RCTD \
94 | --prefix H1_hires \
95 | --k_class 6 \
96 | --tissue_compartment ./tcs/tissue_compartment_BRCA.json \
97 | --omit_gt
98 | ```
99 |
100 | ## Representative Results
101 | Here we only illustrate the demo and representative results corresponding to the paper in **tutorial.ipynb**. The predicted hierarchical spatial cellular information is stored as a dict in a pickle file for each slide, as sketched below. For more results, you can jump directly to our [HistoCell website](http://histocell.qhdyr.net/index/index/index.html).
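
A minimal sketch of reading these predictions (the dictionary keys follow `val_loop` in **infer.py**; the file path is hypothetical and mirrors the inference demo above):

```python
import pickle as pkl

# Path pattern from infer.py: {ckpt_dir}/{model}/epoch_{E}_{prefix}_val.pkl
with open("./train_log/BRCA/ckpts/Breast_Benchmark_H1/epoch_30_H1_hires_val.pkl", "rb") as file:
    results = pkl.load(file)

for tile_name, pred in results.items():
    print(tile_name,
          pred["cell_num"],               # number of segmented cells kept in the tile
          pred["pred_proportion"].shape,  # tile-level cell-type proportions, (k_class,)
          pred["prob"].shape,             # per-cell type probabilities, (cell_num, k_class)
          pred["cell_coords"].shape)      # per-cell centroids in tile pixels, (cell_num, 2)
    break
```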
102 |
103 | ### Benchmark results
104 |
105 | * **Tissue Compartment**
106 |
107 |
![Tissue compartment benchmark](image/README/tissue_compartment.jpg)
108 |
109 | * **Single-cell Type**
110 |
111 |
![Single-cell type benchmark](image/README/cell_type1.jpg)
112 |
113 |
114 |
![Single-cell type prediction](image/README/cell_type2.jpg)
115 |
116 |
117 | Red, blue and green dots represent cancer epithelial cells, stromal cells and macrophages, respectively.
118 | * **Cell States**
119 |
120 |
![Cell state benchmark](image/README/cell_state.jpg)
121 |
122 |
123 | ### Representative Application: Tissue architecture annotations
124 |
125 | Given a histopathology image, HistoCell first infers pixel-level cell types and then clusters cells into tissue regions. This annotation exhibits high accuracy and allows users to further identify small foci within tissue regions at pixel-level resolution; a rough post-processing sketch follows.
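
A rough post-processing sketch (an assumption for illustration, not the paper's exact pipeline): derive hard per-cell types from the saved predictions, then group the typed cells spatially. The file name is hypothetical, and `cell_coords` are tile-local, so slide-level clustering requires offsetting them by each tile's origin first.

```python
import pickle as pkl
import numpy as np

with open("epoch_30_H1_hires_val.pkl", "rb") as file:  # hypothetical infer.py output
    results = pkl.load(file)

for tile_name, pred in results.items():
    cell_types = pred["prob"].argmax(axis=1)  # most likely type per cell
    composition = np.bincount(cell_types, minlength=pred["prob"].shape[1]) / pred["cell_num"]
    print(tile_name, composition)
    # The typed cells at pred["cell_coords"] can then be clustered into tissue
    # regions, e.g. with KMeans/DBSCAN from scikit-learn on slide-level coordinates.
```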
126 |
127 |
![Tissue architecture annotation](image/README/segmentation.jpg)
128 |
129 |
130 | ### Representative Application: Cell Type Deconvolution
131 |
132 | Since HistoCell integrates spot-level cellular compositions deconvolved from expression data with those derived from histological morphology, it can produce more precise and robust deconvolution results; one illustrative fusion rule is sketched below.
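
For intuition only, one simple way to combine the two per-spot estimates is a convex combination (this rule and the weight `alpha` are illustrative assumptions, not the paper's actual integration scheme):

```python
import numpy as np

def fuse_proportions(expr_prop, image_prop, alpha=0.5):
    """Blend expression-based and image-based cell-type proportions for one spot."""
    fused = alpha * np.asarray(expr_prop) + (1 - alpha) * np.asarray(image_prop)
    return fused / fused.sum(axis=-1, keepdims=True)  # renormalize to a composition
```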
133 |
134 |
![Cell type deconvolution](image/README/deconvolution1.jpg)
135 |
136 |
137 |
![Cell type deconvolution comparison](image/README/deconvolution2.jpg)
138 |
139 |
140 | ### Representative Application: Spatial organization indicators identification
141 |
142 |
![Spatial organization indicators](image/README/SOI.jpg)
143 |
144 | The histopathology image is converted to a **spatial cellular map** with HistoCell, and the cells are aggregated into clusters. Through correlation analysis between clinical outcomes and cellular spatial clustering features, we identify spatial biomarkers for prognosis; a toy version of such an analysis is sketched below. Demo results can be found in **tutorial.ipynb**. The representative spatial features for prognosis stratification are visualized below.
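
A toy version of such a correlation analysis (the spatial feature, per-patient values and outcomes are hypothetical placeholders; only scipy, already pinned in **requirements.txt**, is used):

```python
import numpy as np
from scipy.spatial.distance import cdist
from scipy.stats import spearmanr

def proximity_feature(tumor_xy, immune_xy):
    """Mean distance from each immune cell to its nearest tumor cell (illustrative feature)."""
    return cdist(immune_xy, tumor_xy).min(axis=1).mean()

# One value of proximity_feature per patient, plus a clinical outcome per patient.
features = np.array([0.8, 1.4, 2.2, 0.9])  # hypothetical per-patient spatial features
outcomes = np.array([30, 18, 12, 26])      # hypothetical survival times in months
rho, p = spearmanr(features, outcomes)
print(rho, p)
```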
145 |
146 |
![Spatial biomarkers for prognosis](image/README/biomarker.jpg)
147 |
148 |
149 |
150 |
151 |
152 |
--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
1 | import json
2 | from yacs.config import CfgNode as CN
3 |
4 | def _get_config(tissue_type, deconv, subtype, k_class, tissue_dir):
5 | config = CN()
6 | config.train = CN()
7 | config.train.lr = 0.0005
8 | config.train.epoch = 41
9 | config.train.val_iter = 10
10 | config.train.val_min_iter = 9
11 |
12 | config.data = CN()
13 | config.data.deconv = deconv
14 | config.data.save_model = f'./train_log/{tissue_type}/models' # model saved
15 | config.data.ckpt = f'./train_log/{tissue_type}/ckpts' # eval results saved
16 | config.data.tile_dir = f'./demo/data/{tissue_type}/tiles' # path to tiles
17 | config.data.mask_dir = f'./demo/data/{tissue_type}/seg' # path to json segmentation file
18 | config.data.batch_size = 32
19 | config.data.tissue_dir = tissue_dir # tissue compartment directory
20 | config.data.max_cell_num = 256 # max cell number in a single tile for batch learning
21 | config.data.cell_dir = f'./demo/data/{tissue_type}/cell_proportion/type/{config.data.deconv}' # path to cell proportion label
22 |
23 | config.model = CN()
24 | config.model.tissue_class = 3
25 | config.model.pretrained = True
26 | config.model.channels = 3
27 | config.model.k_class = k_class
28 |
29 | return config
30 |
31 | def _get_cell_state_config(tissue_type, deconv, subtype, tissue_dir):
32 | config = CN()
33 | config.train = CN()
34 | config.train.lr = 0.0005
35 | config.train.epoch = 41
36 | config.train.val_iter = 5
37 | config.train.val_min_iter = 9
38 | config.train.state_epoch = 41
39 |
40 | config.data = CN()
41 | config.data.deconv = deconv
42 | config.data.save_model = f'./train_log/{tissue_type}/models'
43 | config.data.ckpt = f'./train_log/{tissue_type}/ckpts'
44 | config.data.tile_dir = f'./demo/data/{tissue_type}/tiles'
45 | config.data.mask_dir = f'./demo/data/{tissue_type}/seg'
46 | config.data.batch_size = 32
47 | config.data.tissue_dir = tissue_dir
48 | config.data.max_cell_num = 256
49 | with open(tissue_dir, 'r') as tissue_file:
50 | tc = json.load(tissue_file)
51 |
52 | config.data.cell_dir = f'./demo/data/{tissue_type}/cell_proportion/type/{config.data.deconv}'
53 | config.data.state_dir = f'./demo/data/{tissue_type}/cell_proportion/state/{config.data.deconv}' # path to cell state directory
54 |
55 | config.model = CN()
56 | config.model.tissue_class = len(tc['list'])
57 | config.model.pretrained = True
58 | config.model.channels = 3
59 | config.model.k_class = len(tc['dict'])
60 |
61 | k_state = 0
62 | for key, value in tc['state'].items():
63 | k_state += int(value)
64 | config.model.k_state = k_state
65 |     print(f"Total number of cell states: {k_state}")
66 |
67 | return config
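
# Example usage (hypothetical values mirroring the README demo):
#   cfg = _get_config('BRCA', 'RCTD', False, 6, './tcs/tissue_compartment_BRCA.json')
#   print(cfg.model.k_class)  # -> 6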
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 | import json
7 | import random
8 | from glob import glob
9 | from PIL import Image
10 | from utils.utils import load_json, load_image
11 | from torch.utils.data import Dataset
12 | from torchvision import transforms
13 | from tqdm import tqdm
14 |
15 |
16 | class TileBatchDataset(Dataset):
17 | def __init__(self,
18 | tile_dir,
19 | mask_dir,
20 | tissue,
21 | cell_dir=None,
22 | deconv='POLARIS',
23 | prefix=['*'],
24 | channels=3,
25 | aug=True,
26 | val=False,
27 | panuke=None,
28 | sample_ratio=1,
29 | max_cell_num=128,
30 | ext='jpg',
31 | focus_tissue='Breast',
32 | reso=1
33 | ) -> None:
34 | ### Batch Size should be 1
35 | self.tile_list = []
36 |
37 | if aug is True:
38 | self.transforms = transforms.Compose(
39 | [
40 | transforms.ColorJitter(brightness=0.1, contrast=0.3, saturation=0.1, hue=0.2),
41 | transforms.RandomGrayscale(p=0.1),
42 | transforms.ToTensor(),
43 | transforms.Normalize(
44 | mean=[0.485, 0.456, 0.406],
45 | std=[0.229, 0.224, 0.225]
46 | )
47 | ]
48 | )
49 | else:
50 | self.transforms = transforms.Compose(
51 | [
52 | transforms.ToTensor(),
53 | transforms.Normalize(
54 | mean=[0.485, 0.456, 0.406],
55 | std=[0.229, 0.224, 0.225]
56 | )
57 | ]
58 | )
59 | self.channels = channels
60 |
61 | self.mask_dir = mask_dir
62 | self.cell_dir = cell_dir
63 | self.val = val
64 | self.panuke = panuke
65 | self.mcn = max_cell_num
66 | self.ext = ext
67 | self.reso = reso
68 |
69 | with open(tissue, 'r') as tissue_file:
70 | self.tc = json.load(tissue_file)
71 | print(f"Tissue Compartment: {self.tc['list']}")
72 | if cell_dir is not None:
73 | cell_labels = {}
74 |
75 | if deconv in ['Tangram', 'RCTD', 'stereoscope', 'CARD', 'POLARIS', 'Xenium', 'CARD_fix', 'RCTD_fix']:
76 | for cell_prop in glob(os.path.join(cell_dir, "*.tsv")):
77 | sample_name = cell_prop.split('/')[-1].split('.')[0]
78 | if sample_name not in prefix:
79 | continue
80 | cell_df = pd.read_csv(cell_prop, sep='\t', index_col=0)
81 | for cell_index, row in tqdm(cell_df.iterrows()):
82 | cell_index = str(cell_index)
83 | x, y, barcode = cell_index.split('_')
84 | cell_num = np.array(row)
85 | cell_propotion = cell_num / np.sum(cell_num)
86 |                         abs_path = glob(os.path.join(tile_dir, sample_name, f'*_{x}x{y}.{ext}'))  # honor the ext argument instead of hard-coding .jpg
87 | if len(abs_path) <= 0:
88 | continue
89 | img_name = abs_path[0].split('/')[-1].split('.')[0]
90 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json")
91 | cell_labels.update(
92 | {
93 | sample_name + '_' + img_name: cell_propotion
94 | }
95 | )
96 | if not os.path.exists(json_file):
97 | continue
98 | self.tile_list.append(abs_path[0])
99 |
100 | elif deconv in ['Mix']:
101 | for cell_prop in glob(os.path.join(cell_dir, "*.tsv")):
102 | sample_name = cell_prop.split('/')[-1].replace('.tsv', '')
103 | if sample_name not in prefix:
104 | continue
105 | cell_df = pd.read_csv(cell_prop, sep='\t', index_col=0)
106 | for cell_index, row in cell_df.iterrows():
107 | # cell_index = "%06d" % int(cell_index)
108 | cell_index = cell_index.split('-')[0]
109 | abs_path = glob(os.path.join(tile_dir, sample_name, f'{cell_index}-{focus_tissue}.{ext}'))
110 | if len(abs_path) == 0:
111 | continue
112 | tissue_index = abs_path[0].split('/')[-1].split('.')[0]
113 | cell_propotion = np.array(row)
114 | cell_propotion = cell_propotion / np.sum(cell_propotion)
115 | cell_labels.update(
116 | {
117 | sample_name + '_' + tissue_index: cell_propotion
118 | }
119 | )
120 | img_name = abs_path[0].split('/')[-1].replace(f'.{ext}', '')
121 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json")
122 | if not os.path.exists(json_file):
123 | continue
124 | self.tile_list.append(abs_path[0])
125 |
126 | else:
127 | raise NotImplementedError("The Deconvolution Method is not supported.")
128 |
129 | self.cell_labels = cell_labels
130 |
131 | if not val:
132 | print(self.tc['dict'])
133 | print(cell_df.keys())
134 | assert len(self.tc['dict'].keys()) == len(cell_df.keys())
135 | print(f"Cell Category: {list(cell_df.keys())}")
136 | else:
137 | for dir_prefix in prefix:
138 | path_list = []
139 | if panuke is None:
140 | panuke = '*'
141 | else:
142 | panuke = f'*-{panuke}'
143 | for abs_path in glob(os.path.join(tile_dir, f"{dir_prefix}/{panuke}.{ext}")):
144 | # for abs_path in glob(os.path.join(tile_dir, f"tumor_003/*.{ext}")):
145 | img_name = abs_path.split('/')[-1].replace(f'.{ext}', '')
146 | sample_name = abs_path.split('/')[-2]
147 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json")
148 | if os.path.exists(json_file):
149 | path_list.append(abs_path)
150 | self.tile_list.extend(path_list)
151 |
152 | self.tile_list = random.sample(self.tile_list, int(len(self.tile_list) * sample_ratio))
153 |
154 | def __len__(self) -> int:
155 | return len(self.tile_list)
156 |
157 | def __getitem__(self, index: int):
158 | # load image
159 | img_path = self.tile_list[index]
160 | img_name = img_path.split('/')[-1].split('.')[0]
161 | sample_name = img_path.split('/')[-2]
162 | with open(img_path, 'rb') as fp:
163 | pic = Image.open(fp).convert('RGB').resize((256 // self.reso, 256 // self.reso))
164 | pic = pic.resize((256, 256))
165 | image = self.transforms(pic)
166 |
167 | # load mask
168 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json")
169 | nucs = load_json(json_file)
170 | mask_list, size_list, pos_list, type_list = [], [], [], []
171 | for index, nuc in nucs.items():
172 | if len(image.shape) < 3:
173 | image = np.expand_dims(image, axis=0)
174 | # mask = np.zeros_like(image)
175 | center_x, center_y = nuc['centroid']
176 | bbox_xa, bbox_ya = nuc['bbox'][0]
177 | bbox_xb, bbox_yb = nuc['bbox'][1]
178 | mask = np.zeros([3, bbox_xb - bbox_xa, bbox_yb - bbox_ya])
179 | mask = image[:, bbox_xa: bbox_xb, bbox_ya: bbox_yb]
180 | try:
181 | mask = torch.tensor(cv2.resize(mask.permute(1, 2, 0).numpy(), (30, 30)))
182 |             except Exception:  # skip crops that cannot be resized (e.g. zero-area bbox)
183 | continue
184 |
185 | mask_list.append(mask.permute(2, 0, 1)), size_list.append(np.array([(bbox_xb - bbox_xa) / image.shape[-1], (bbox_yb - bbox_ya) / image.shape[-1]]))
186 | pos_list.append(np.array([center_x, center_y]))
187 | if nuc['type'] is None:
188 | type_list.append(-1)
189 | else:
190 | type_list.append(int(self.tc['HoVerNet'][int(nuc['type'])]))
191 |
192 | cell_types = np.array(type_list)
193 | image = torch.tensor(cv2.resize(image.permute(1, 2, 0).numpy(), (128, 128))).permute(2, 0, 1)
194 | cell_num = len(mask_list)
195 | if cell_num == 1:
196 | dist_mat = np.zeros([cell_num, cell_num])
197 |             cell_coords = np.array(pos_list)  # shape (1, 2); the extra nesting would break the padding below
198 | elif cell_num == 0:
199 | dist_mat = np.zeros([cell_num, cell_num])
200 | cell_coords = np.zeros([0, 2])
201 | else:
202 | cell_coords = np.stack(pos_list, axis=0)
203 | dist_mat = np.zeros([cell_num, cell_num])
204 | for i in range(cell_num):
205 | for j in range(i + 1, cell_num):
206 | dist = np.linalg.norm((cell_coords[i] - cell_coords[j]), ord=2)
207 |                     if dist < 40:  # connect nuclei whose centroids lie within 40 px
208 | dist_mat[i, j] = 1
209 |
210 |             dist_mat = dist_mat + dist_mat.T + np.identity(cell_num)  # symmetrize and add self-loops
211 |
212 | adj_mat = np.zeros([self.mcn, self.mcn])
213 | cell_pixels = np.zeros([self.mcn, 2])
214 | CellType = -np.ones([self.mcn])
215 | valid_mask = len(mask_list) if len(mask_list) < self.mcn else self.mcn # max cell num
216 | if valid_mask >= self.mcn:
217 | mask_list = mask_list[:self.mcn]
218 | size_list = size_list[:self.mcn]
219 | adj_mat = dist_mat[:self.mcn, :self.mcn]
220 | cell_pixels = cell_coords[:self.mcn]
221 | CellType = cell_types[:self.mcn]
222 | else:
223 | for _ in range(self.mcn - valid_mask):
224 | mask_list.append(torch.zeros((3, 30, 30)))
225 | size_list.append(torch.zeros((2)))
226 |
227 | adj_mat = np.pad(dist_mat, ((0, self.mcn - valid_mask), (0, self.mcn - valid_mask)))
228 | cell_pixels[:valid_mask] = cell_coords
229 | CellType[:valid_mask] = cell_types
230 |
231 | cell_images = np.stack(mask_list, axis=0)
232 | cell_sizes = np.stack(size_list, axis=0)
233 |
234 | if self.cell_dir is None:
235 | return {
236 | 'name': f'{sample_name}_{img_name}',
237 | 'tissue': image, # B 3 256 256
238 | 'image': torch.tensor(cell_images, dtype=torch.float32), # B 128 3 128 128
239 | 'mask': int(valid_mask), # B
240 | 'size': torch.tensor(cell_sizes, dtype=torch.float32),
241 | 'adj': torch.tensor(adj_mat, dtype=torch.long),
242 | 'cell_coords': torch.tensor(cell_pixels, dtype=torch.float32),
243 | 'cell_types': torch.tensor(CellType, dtype=torch.float32)
244 | }
245 | # Add tissue compartment
246 | if self.panuke is not None:
247 | sample_keys = sample_name + '_' + img_name.split('-')[0] + f'-{self.panuke}' # Only for Panuke
248 | cell_prop = self.cell_labels[sample_keys]
249 | tissue_cat = -1
250 |
251 | else:
252 | sample_keys = sample_name + '_' + img_name
253 | cell_prop = self.cell_labels[sample_keys]
254 | # tc = np.zeros(len(self.tc['list']))
255 | # for cata_i, prop in enumerate(cell_prop):
256 | # tc[self.tc['list'].index(self.tc['dict'][str(cata_i)])] += prop
257 | if np.max(cell_prop) > 0.75:
258 | tissue_cat = self.tc['list'].index(self.tc['dict'][str(np.argmax(cell_prop))])
259 | else:
260 | tissue_cat = len(self.tc['list'])
261 |
262 | return {
263 | 'name': f'{sample_name}_{img_name}',
264 | 'tissue': image, # B 3 128 128
265 | 'image': torch.tensor(cell_images, dtype=torch.float32), # B 128 3 128 128
266 | 'mask': int(valid_mask), # B
267 | 'cells': torch.tensor(cell_prop, dtype=torch.float32), # B cell_type
268 | 'size': torch.tensor(cell_sizes, dtype=torch.float32),
269 | 'adj': torch.tensor(adj_mat, dtype=torch.long),
270 | 'tissue_cat': torch.tensor(tissue_cat, dtype=torch.long), # B 1
271 | 'cell_coords': torch.tensor(cell_pixels, dtype=torch.float32),
272 | 'cell_types': torch.tensor(CellType, dtype=torch.float32)
273 | }
274 |
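# Example construction (illustrative arguments; see infer.py for the full pipeline —
# when cell_dir is None, val must be True so that no proportion labels are expected):
#   dataset = TileBatchDataset('./demo/data/BRCA/tiles', './demo/data/BRCA/seg',
#                              './tcs/tissue_compartment_BRCA.json', cell_dir=None,
#                              deconv='RCTD', prefix=['H1_hires'], aug=False, val=True)
#   sample = dataset[0]  # dict with 'tissue', 'image', 'mask', 'adj', 'cell_coords', ...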
275 |
276 | class TileBatchStateDataset(Dataset):
277 | def __init__(self, tile_dir, mask_dir, tissue, cell_dir=None, state_dir=None, deconv='POLARIS', prefix=['*'], channels=3, aug=True, val=False, panuke=False, sample_ratio=1, max_cell_num=128, ext='jpg') -> None:
278 | ### Batch Size should be 1
279 | self.tile_list = []
280 |
281 | if aug is True:
282 | self.transforms = transforms.Compose(
283 | [
284 | transforms.ColorJitter(brightness=0.1, contrast=0.3, saturation=0.1, hue=0.2),
285 | transforms.RandomGrayscale(p=0.1),
286 | transforms.ToTensor(),
287 | transforms.Normalize(
288 | mean=[0.485, 0.456, 0.406],
289 | std=[0.229, 0.224, 0.225]
290 | )
291 | ]
292 | )
293 | else:
294 | self.transforms = transforms.Compose(
295 | [
296 | transforms.ToTensor(),
297 | transforms.Normalize(
298 | mean=[0.485, 0.456, 0.406],
299 | std=[0.229, 0.224, 0.225]
300 | )
301 | ]
302 | )
303 | self.channels = channels
304 |
305 | self.mask_dir = mask_dir
306 | self.cell_dir = cell_dir
307 | self.state_dir = state_dir
308 | self.val = val
309 | self.panuke = panuke
310 | self.mcn = max_cell_num
311 | self.ext = ext
312 |
313 | with open(tissue, 'r') as tissue_file:
314 | self.tc = json.load(tissue_file)
315 | print(f"Tissue Compartment: {self.tc['list']}")
316 |
317 | if cell_dir is not None:
318 | cell_labels = {}
319 | state_labels = {}
320 |
321 | if deconv in ['Tangram', 'RCTD', 'stereoscope', 'CARD', 'POLARIS', 'Xenium']:
322 | for cell_prop in glob(os.path.join(cell_dir, "*.tsv")):
323 | sample_name = cell_prop.split('/')[-1].split('.')[0]
324 | if sample_name not in prefix:
325 | continue
326 | cell_df = pd.read_csv(cell_prop, sep='\t', index_col=0)
327 | state_df = pd.read_csv(os.path.join(state_dir, f"{sample_name}.tsv"), sep='\t', index_col=0)
328 | for cell_index, state_row in state_df.iterrows():
329 | cell_index = str(cell_index)
330 | x, y, barcode = cell_index.split('_')
331 |
332 | # cell prop
333 | row = cell_df.loc[cell_index]
334 | cell_num = np.array(row)
335 | cell_propotion = cell_num / np.sum(cell_num)
336 | cell_labels.update(
337 | {
338 | sample_name + '_' + f'{x}x{y}': cell_propotion
339 | }
340 | )
341 |
342 | # state prop
343 | state_num = np.array(state_row)
344 | state_proportion = state_num / np.sum(state_num)
345 | state_labels.update(
346 | {
347 | sample_name + '_' + f'{x}x{y}': state_proportion
348 | }
349 | )
350 |
351 | abs_path = glob(os.path.join(tile_dir, sample_name, f'*_{x}x{y}.{ext}'))
352 | if len(abs_path) <= 0:
353 | continue
354 | img_name = abs_path[0].split('/')[-1].split('.')[0]
355 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json")
356 | if not os.path.exists(json_file):
357 | continue
358 | self.tile_list.append(abs_path[0])
359 | elif deconv in ['Mix']:
360 | for cell_prop in glob(os.path.join(cell_dir, "*.tsv")):
361 | sample_name = cell_prop.split('/')[-1].split('.')[0]
362 | if sample_name not in prefix:
363 | continue
364 | cell_df = pd.read_csv(cell_prop, sep='\t', index_col=0)
365 | for cell_index, row in cell_df.iterrows():
366 | cell_index = "%06d" % int(cell_index)
367 | abs_path = glob(os.path.join(tile_dir, sample_name, f'{cell_index}-Prostate.{ext}'))
368 | if len(abs_path) == 0:
369 | continue
370 | cell_propotion = np.array(row)
371 | cell_labels.update(
372 | {
373 | sample_name + '_' + cell_index + '-Prostate': cell_propotion
374 | }
375 | )
376 | img_name = abs_path[0].split('/')[-1].split('.')[0]
377 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json")
378 | if not os.path.exists(json_file):
379 | continue
380 | self.tile_list.append(abs_path[0])
381 |
382 | else:
383 | raise NotImplementedError("The Deconvolution Method is not supported.")
384 |
385 | self.cell_labels = cell_labels
386 | self.state_labels = state_labels
387 |
388 | if not val:
389 | print(self.tc['dict'])
390 | print(cell_df.keys())
391 | assert len(self.tc['dict'].keys()) == len(cell_df.keys())
392 | print(f"Cell Category: {list(cell_df.keys())}")
393 | else:
394 | for dir_prefix in prefix:
395 | path_list = []
396 | print(os.path.join(tile_dir, f"{dir_prefix}/*.{ext}"))
397 | for abs_path in glob(os.path.join(tile_dir, f"{dir_prefix}/*.{ext}")):
398 | # img_name = abs_path.split('/')[-1].split('.')[0]
399 |                     img_name = abs_path.split('/')[-1].replace(f'.{ext}', '')  # str.strip would wrongly eat matching edge characters
400 | sample_name = abs_path.split('/')[-2]
401 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json")
402 | # print(json_file)
403 | if os.path.exists(json_file):
404 | path_list.append(abs_path)
405 | self.tile_list.extend(path_list)
406 |
407 | self.tile_list = random.sample(self.tile_list, int(len(self.tile_list) * sample_ratio))
408 |
409 | def __len__(self) -> int:
410 | return len(self.tile_list)
411 |
412 | def __getitem__(self, index: int):
413 | # load image
414 | img_path = self.tile_list[index]
415 |         img_name = img_path.split('/')[-1].replace(f'.{self.ext}', '')
416 | sample_name = img_path.split('/')[-2]
417 | with open(img_path, 'rb') as fp:
418 | pic = Image.open(fp)
419 | image = self.transforms(pic)
420 |
421 | # load mask
422 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json")
423 | nucs = load_json(json_file)
424 | mask_list, size_list, pos_list, type_list = [], [], [], []
425 | for index, nuc in nucs.items():
426 | if len(image.shape) < 3:
427 | image = np.expand_dims(image, axis=0)
428 | # mask = np.zeros_like(image)
429 | center_x, center_y = nuc['centroid']
430 | bbox_xa, bbox_ya = nuc['bbox'][0]
431 | bbox_xb, bbox_yb = nuc['bbox'][1]
432 | mask = np.zeros([3, bbox_xb - bbox_xa, bbox_yb - bbox_ya])
433 | mask = image[:, bbox_xa: bbox_xb, bbox_ya: bbox_yb]
434 | try:
435 | mask = torch.tensor(cv2.resize(mask.permute(1, 2, 0).numpy(), (30, 30)))
436 |             except Exception:  # skip crops that cannot be resized (e.g. zero-area bbox)
437 | continue
438 |
439 | mask_list.append(mask.permute(2, 0, 1)), size_list.append(np.array([(bbox_xb - bbox_xa) / image.shape[-1], (bbox_yb - bbox_ya) / image.shape[-1]]))
440 | pos_list.append(np.array([center_x, center_y]))
441 |             type_list.append(-1 if nuc['type'] is None else int(self.tc['HoVerNet'][int(nuc['type'])]))  # mirror TileBatchDataset's None handling
442 |
443 | cell_types = np.array(type_list)
444 | image = torch.tensor(cv2.resize(image.permute(1, 2, 0).numpy(), (256, 256))).permute(2, 0, 1)
445 | cell_num = len(mask_list)
446 | if cell_num <= 1:
447 | dist_mat = np.zeros([cell_num, cell_num])
448 | cell_coords = np.array(pos_list)
449 | else:
450 | cell_coords = np.stack(pos_list, axis=0)
451 | dist_mat = np.zeros([cell_num, cell_num])
452 | for i in range(cell_num):
453 | for j in range(i + 1, cell_num):
454 | dist = np.linalg.norm((cell_coords[i] - cell_coords[j]), ord=2)
455 | if dist < 40:
456 | dist_mat[i, j] = 1
457 |
458 | dist_mat = dist_mat + dist_mat.T + np.identity(cell_num)
459 |
460 | adj_mat = np.zeros([self.mcn, self.mcn])
461 | cell_pixels = np.zeros([self.mcn, 2])
462 | CellType = np.zeros([self.mcn])
463 | valid_mask = len(mask_list) if len(mask_list) < self.mcn else self.mcn # max cell num
464 | if valid_mask >= self.mcn:
465 | mask_list = mask_list[:self.mcn]
466 | size_list = size_list[:self.mcn]
467 | adj_mat = dist_mat[:self.mcn, :self.mcn]
468 |             cell_pixels = cell_coords[:self.mcn]
469 | CellType = cell_types[:self.mcn]
470 | else:
471 | for _ in range(self.mcn - valid_mask):
472 | mask_list.append(torch.zeros((3, 30, 30)))
473 | size_list.append(torch.zeros((2)))
474 |
475 | adj_mat = np.pad(dist_mat, ((0, self.mcn - valid_mask), (0, self.mcn - valid_mask)))
476 | for i in range(valid_mask):
477 | cell_pixels[i] = cell_coords[i]
478 | CellType[:valid_mask] = cell_types
479 |
480 | cell_images = np.stack(mask_list, axis=0)
481 | cell_sizes = np.stack(size_list, axis=0)
482 |
483 | if self.cell_dir is None:
484 | return {
485 | 'name': f'{sample_name}_{img_name}',
486 | 'tissue': image, # B 3 128 128
487 | 'image': torch.tensor(cell_images, dtype=torch.float32), # B 128 3 128 128
488 | 'mask': int(valid_mask), # B
489 | 'size': torch.tensor(cell_sizes, dtype=torch.float32),
490 | 'adj': torch.tensor(adj_mat, dtype=torch.long),
491 | 'cell_coords': torch.tensor(cell_pixels, dtype=torch.float32),
492 | 'cell_types': torch.tensor(CellType, dtype=torch.float32)
493 | }
494 | # Add tissue compartment
495 | if self.panuke:
496 | sample_keys = sample_name + '_' + img_name.split('-')[0] + '-Prostate' # Only for Panuke
497 | cell_prop = self.cell_labels[sample_keys]
498 | tissue_cat = -1
499 |
500 | else:
501 | sample_keys = sample_name + '_' + img_name.split('_')[-1]
502 | cell_prop = self.cell_labels[sample_keys]
503 | # tc = np.zeros(len(self.tc['list']))
504 | # for cata_i, prop in enumerate(cell_prop):
505 | # tc[self.tc['list'].index(self.tc['dict'][str(cata_i)])] += prop
506 | if np.max(cell_prop) > 0.75:
507 | tissue_cat = self.tc['list'].index(self.tc['dict'][str(np.argmax(cell_prop))])
508 | else:
509 | tissue_cat = len(self.tc['list'])
510 |
511 | state_prop = self.state_labels[sample_keys]
512 |
513 | return {
514 | 'name': f'{sample_name}_{img_name}',
515 | 'tissue': image, # B 3 128 128
516 | 'image': torch.tensor(cell_images, dtype=torch.float32), # B 128 3 128 128
517 | 'mask': int(valid_mask), # B
518 | 'cells': torch.tensor(cell_prop, dtype=torch.float32), # B cell_type
519 | 'states': torch.tensor(state_prop, dtype=torch.float32),
520 | 'size': torch.tensor(cell_sizes, dtype=torch.float32),
521 | 'adj': torch.tensor(adj_mat, dtype=torch.long),
522 | 'tissue_cat': torch.tensor(tissue_cat, dtype=torch.long), # B 1
523 | 'cell_coords': torch.tensor(cell_pixels, dtype=torch.float32),
524 | 'cell_types': torch.tensor(CellType, dtype=torch.float32)
525 | }
--------------------------------------------------------------------------------
/demo/data/url.txt:
--------------------------------------------------------------------------------
1 | Download from url:
2 | https://drive.google.com/file/d/1DzwsFvD4wWKNN1wJCe3YsQZlk9xIp8DV/view?usp=sharing
--------------------------------------------------------------------------------
/image/README/Intro.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/Intro.jpg
--------------------------------------------------------------------------------
/image/README/SOI.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/SOI.jpg
--------------------------------------------------------------------------------
/image/README/biomarker.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/biomarker.jpg
--------------------------------------------------------------------------------
/image/README/cell_state.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/cell_state.jpg
--------------------------------------------------------------------------------
/image/README/cell_type.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/cell_type.jpg
--------------------------------------------------------------------------------
/image/README/cell_type1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/cell_type1.jpg
--------------------------------------------------------------------------------
/image/README/cell_type2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/cell_type2.jpg
--------------------------------------------------------------------------------
/image/README/deconvolution1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/deconvolution1.jpg
--------------------------------------------------------------------------------
/image/README/deconvolution2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/deconvolution2.jpg
--------------------------------------------------------------------------------
/image/README/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/logo.png
--------------------------------------------------------------------------------
/image/README/segmentation.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/segmentation.jpg
--------------------------------------------------------------------------------
/image/README/tissue_compartment.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/tissue_compartment.jpg
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import argparse
4 | import pickle as pkl
5 | from model.arch import HistoCell
6 | from data import TileBatchDataset
7 | from utils.utils import *
8 | from torch.utils.data import DataLoader
9 | from tqdm import tqdm
10 | from yacs.config import CfgNode as CN
11 |
12 | def _get_config(tissue_type, deconv, subtype, k_class, tissue_dir):
13 | config = CN()
14 |
15 | config.data = CN()
16 | config.data.deconv = deconv
17 | config.data.save_model = f'./train_log/{tissue_type}/models'
18 | config.data.ckpt = f'./train_log/{tissue_type}/ckpts'
19 | config.data.tile_dir = f'./demo/data/{tissue_type}/tiles'
20 | config.data.mask_dir = f'./demo/data/{tissue_type}/seg'
21 | config.data.batch_size = 16
22 | config.data.tissue_dir = tissue_dir
23 | config.data.max_cell_num = 256
24 |
25 | config.model = CN()
26 | config.model.tissue_class = 3
27 | config.model.pretrained = True
28 | config.model.channels = 3
29 |
30 | config.data.cell_dir = ''
31 | config.model.k_class = k_class
32 |
33 | return config
34 |
35 | def main():
36 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
37 | parser = argparse.ArgumentParser(description="Prediction with Spots Images")
38 | parser.add_argument('--seed', default=47, type=int)
39 | parser.add_argument('--model', default='Baseline', type=str, help='model description')
40 | parser.add_argument('--epoch', default=-1, type=int)
41 | parser.add_argument('--tissue', default='BRCA', type=str)
42 | parser.add_argument('--deconv', default='CARD', type=str)
43 | parser.add_argument('--subtype', action='store_true')
44 | parser.add_argument('--prefix', required=False, nargs = '+')
45 | parser.add_argument('--k_class', default=8, type=int)
46 | parser.add_argument('--tissue_compartment', type=str, required=True)
47 | parser.add_argument('--omit_gt', action='store_true')
48 | parser.add_argument('--val_panuke', default=None, type=str)
49 |
50 | args = parser.parse_args()
51 | config = _get_config(args.tissue, args.deconv, args.subtype, args.k_class, args.tissue_compartment)
52 |
53 | print(f"Model details: {args.model}")
54 |
55 | print("Load Dataset...")
56 |
57 | data_prefix = args.prefix if isinstance(args.prefix, list) else [args.prefix]
58 |
59 | print(data_prefix)
60 |
61 | val_data = TileBatchDataset(config.data.tile_dir, config.data.mask_dir, config.data.tissue_dir, cell_dir=None,
62 | deconv=config.data.deconv, prefix=data_prefix,
63 | aug=False, val=args.omit_gt, panuke=args.val_panuke, max_cell_num=config.data.max_cell_num)
64 | print(f"length of train data: {len(val_data)}")
65 | # val_data = TileMaskDataset(config.data.tile_dir, config.data.mask_dir, config.data.cell_dir)
66 | val_loader = DataLoader(val_data, batch_size=config.data.batch_size, num_workers=8, pin_memory=True)
67 | # val_loader = DataLoader(val_data, batch_size=1, collate_fn=collate_fn)
68 |
69 | print("Load Model...")
70 | model = HistoCell(config.model)
71 |
72 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73 | if torch.cuda.is_available():
74 | print("Using 1 gpu")
75 | else:
76 | print("Using cpu")
77 |
78 |
79 | model_dir, ckpt_dir = \
80 | os.path.join(config.data.save_model, args.model), os.path.join(config.data.ckpt, args.model)
81 | os.makedirs(model_dir, exist_ok=True), os.makedirs(ckpt_dir, exist_ok=True)
82 |
83 | ckpt_file = find_ckpt(model_dir)
84 |
85 | if args.epoch >= 0:
86 | state_dict, current_epoch = load_ckpt(os.path.join(model_dir, f"epoch_{args.epoch}.ckpt"))
87 | model.load_state_dict(state_dict)
88 |
89 | # ipdb.set_trace()
90 | elif ckpt_file is not None:
91 | state_dict, current_epoch = load_ckpt(os.path.join(model_dir, ckpt_file))
92 | model.load_state_dict(state_dict)
93 |
94 | else:
95 |         raise FileNotFoundError("No trained model exists")
96 |
97 | model.to(device)
98 |
99 | print(f"Current Epoch: {current_epoch}")
100 |
101 | val_results = val_loop(current_epoch, model, val_loader, device)
102 | with open(os.path.join(ckpt_dir, f'epoch_{current_epoch}_{data_prefix[0]}_val.pkl'), 'wb') as file:
103 | pkl.dump(val_results, file)
104 |
105 | print("Val Finished!")
106 | print(f"Results saved in {os.path.join(config.data.ckpt, args.model)}")
107 |
108 |
109 | def val_loop(epoch, model, val_loader, device):
110 |     model.train()  # kept in train mode: the dropout in arch.py is always on, and BatchNorm keeps using batch statistics
111 | val_bar = tqdm(val_loader, desc="epoch " + str(epoch), total=len(val_loader),
112 | unit="batch", dynamic_ncols=True)
113 |
114 | all_results = {}
115 | for idx, data in enumerate(val_bar):
116 | tissue = data['tissue'].to(torch.float32).to(device)
117 | images = data['image'].to(torch.float32).to(device)
118 | cell_size = data['size'].to(torch.float32).to(device)
119 | adj_mat = data['adj'].to(torch.float32).to(device)
120 | valid_mask = data['mask'].to(torch.long).to(device)
121 | cell_coords = data['cell_coords'].to(torch.float32).to(device)
122 | ref_type = data['cell_types'].to(torch.long).to(device)
123 | if torch.sum(data['mask']) <= 0:
124 | continue
125 | batch, cells, channels, height, width = images.shape
126 | images = images.reshape(batch * cells, channels, height, width)
127 | prob_list, pred_proportion, tissue_cat, cell_embeddings = model(tissue, images, adj_mat, cell_size, valid_mask, {'batch': batch, 'cells': cells})
128 | name_list, valid_list, coord_list, type_list = [], [], [], []
129 | for name, valid_index, valid_coord, valid_type in zip(data['name'], valid_mask, cell_coords, ref_type):
130 | if valid_index <= 0:
131 | continue
132 | name_list.append(name)
133 | valid_list.append(valid_index)
134 | coord_list.append(valid_coord)
135 | type_list.append(valid_type)
136 |
137 | for name, pb, pp, vix, cc, ct, cell_features in zip(name_list, prob_list, pred_proportion, valid_list, coord_list, type_list, cell_embeddings):
138 | valid_num = int(vix.detach().cpu())
139 | all_results.update(
140 | {
141 | name: {
142 | 'prob': pb.detach().cpu().numpy(),
143 | 'pred_proportion': pp.detach().cpu().numpy(),
144 | 'prior_type': ct.cpu().numpy(),
145 | 'cell_num': valid_num,
146 | 'cell_coords': cc[:valid_num].detach().cpu().numpy(),
147 | # 'cell_embedding': cell_features.detach().cpu().numpy()
148 | }
149 | }
150 | )
151 |
152 | val_bar.set_description('epoch:{} iter:{}'.format(epoch, idx))
153 |
154 | return all_results
155 |
156 |
157 | if __name__ == '__main__':
158 | main()
--------------------------------------------------------------------------------
/model/__pycache__/arch.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/model/__pycache__/arch.cpython-37.pyc
--------------------------------------------------------------------------------
/model/arch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torchvision.models import resnet18
4 | import torch.nn.functional as F
5 | from einops import rearrange
6 |
7 |
8 | class GraphAttentionLayer(nn.Module):
9 | """
10 | Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
11 | """
12 | def __init__(self, in_features, out_features, dropout, alpha, concat=True):
13 | super(GraphAttentionLayer, self).__init__()
14 | self.dropout = dropout
15 | self.in_features = in_features
16 | self.out_features = out_features
17 | self.alpha = alpha
18 | self.concat = concat
19 |
20 | self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
21 | nn.init.xavier_uniform_(self.W.data, gain=1.414)
22 | self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
23 | nn.init.xavier_uniform_(self.a.data, gain=1.414)
24 |
25 | self.leakyrelu = nn.LeakyReLU(self.alpha)
26 |
27 | def forward(self, h, adj):
28 |         Wh = torch.matmul(h, self.W)  # h: (B, N, in_features) -> Wh: (B, N, out_features)
29 | e = self._prepare_attentional_mechanism_input(Wh)
30 |
31 | zero_vec = -9e15*torch.ones_like(e) # B N N
32 | attention = torch.where(adj > 0, e, zero_vec)
33 | attention = F.softmax(attention, dim=1)
34 | attention = F.dropout(attention, self.dropout, training=self.training)
35 | h_prime = torch.matmul(attention, Wh)
36 |
37 | if self.concat:
38 | return F.elu(h_prime)
39 | else:
40 | return h_prime
41 |
42 | def _prepare_attentional_mechanism_input(self, Wh):
43 | # Wh.shape (N, out_feature)
44 | # self.a.shape (2 * out_feature, 1)
45 | # Wh1&2.shape (N, 1)
46 | # e.shape (N, N)
47 | Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])
48 | Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])
49 | # broadcast add
50 | e = Wh1 + Wh2.transpose(-1, -2)
51 | return self.leakyrelu(e)
52 |
53 | def __repr__(self):
54 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
55 |
56 |
57 | class GAT(nn.Module):
58 | def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
59 | """Dense version of GAT."""
60 | super(GAT, self).__init__()
61 | self.dropout = dropout
62 |
63 | self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
64 | for i, attention in enumerate(self.attentions):
65 | self.add_module('attention_{}'.format(i), attention)
66 |
67 | self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)
68 |
69 | def forward(self, x, adj):
70 | x = F.dropout(x, self.dropout, training=self.training)
71 | x = torch.cat([att(x, adj) for att in self.attentions], dim=-1)
72 | x = F.dropout(x, self.dropout, training=self.training)
73 | x = F.elu(self.out_att(x, adj))
74 | return F.log_softmax(x, dim=1)
75 |
76 |
77 | class HistoCell(nn.Module):
78 | def __init__(self, config) -> None:
79 | super(HistoCell, self).__init__()
80 | resnet = resnet18(pretrained=config.pretrained)
81 |
82 | if config.channels == 1:
83 | modules = [nn.Conv2d(config.channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)] + list(resnet.children())[1:-1]
84 | self.resnet = nn.Sequential(*modules)
85 |
86 | elif config.channels == 3:
87 | self.resnet = nn.Sequential(*list(resnet.children())[:-1])
88 |
89 | self.gat = GraphAttentionLayer(512, 512, dropout=0.5, alpha=0.2, concat=False)
90 |
91 | self.size_embed = nn.Sequential(
92 | nn.Linear(2, 16),
93 | nn.ReLU()
94 | )
95 | self.merge = nn.Linear(512 + 16, 512)
96 |
97 | self.out = nn.LSTM(input_size=512, hidden_size=512, num_layers=1, batch_first=True)
98 | self.predict = nn.Sequential(
99 | nn.Linear(512, config.k_class),
100 | nn.Softmax(dim=-1)
101 | )
102 | self.tc = nn.Linear(512, config.tissue_class + 1)
103 |
104 | def forward(self, tissue, bag, adj, cell_size, valid_mask, raw_size): # cell number * 3 * 224 * 224
105 | mask_feats = self.resnet(bag).squeeze().reshape(raw_size['batch'], raw_size['cells'], -1)
106 | size_feats = self.size_embed(cell_size) # B C 16
107 | mask_feats = self.merge(torch.concat([mask_feats, size_feats], dim=-1)) # B C 512
108 | # import ipdb
109 | # ipdb.set_trace()
110 | dmask_feats = nn.functional.dropout(mask_feats, p=0.5, training=True)
111 | graph_feats = self.gat(dmask_feats, adj) # B C 512
112 |
113 | global_feat = self.resnet(tissue).squeeze()
114 | if len(global_feat.shape) <= 1:
115 | global_feat = global_feat.unsqueeze(0)
116 |
117 | global_feats = torch.stack([global_feat for _ in range(graph_feats.shape[1])], dim=1)
118 |         seq_feats = torch.stack([global_feats, graph_feats], dim=2)  # B C 2 512
119 | seq_feats = rearrange(seq_feats, 'B C L F-> (B C) L F')
120 | out_feats, _ = self.out(seq_feats)
121 |         out_feats = rearrange(out_feats, '(B C) L F-> B C L F', B=raw_size['batch'])  # B C 2 512
122 | dout_feats = nn.functional.dropout(out_feats, p=0.5, training=True)
123 |
124 | # tissue
125 | tissue_cat = self.tc(dout_feats[:, 0, 0])
126 | # Proportion
127 | probs = self.predict(dout_feats[:, :, 1]) # 16 64 cell_type
128 | prop_list, prob_list, cell_features = [], [], []
129 | for single_probs, valid_index, cell_embedding in zip(probs, valid_mask, dout_feats[:, :, 1]):
130 | if valid_index <= 0:
131 | continue
132 | prop_list.append(torch.mean(single_probs[:valid_index], dim=0))
133 | prob_list.append(single_probs[:valid_index])
134 | cell_features.append(cell_embedding[:valid_index])
135 |
136 | avg_probs = torch.stack(prop_list, dim=0)
137 |
138 | return prob_list, avg_probs, tissue_cat, cell_features
139 |
140 |
141 | class HistoState(nn.Module):
142 | def __init__(self, config) -> None:
143 | super(HistoState, self).__init__()
144 | resnet = resnet18(pretrained=config.pretrained)
145 |
146 | if config.channels == 1:
147 | modules = [nn.Conv2d(config.channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)] + list(resnet.children())[1:-1]
148 | self.resnet = nn.Sequential(*modules)
149 |
150 | elif config.channels == 3:
151 | self.resnet = nn.Sequential(*list(resnet.children())[:-1])
152 |
153 | self.gat = GraphAttentionLayer(512, 512, dropout=0.5, alpha=0.2, concat=False)
154 |
155 | self.size_embed = nn.Sequential(
156 | nn.Linear(2, 16),
157 | nn.ReLU()
158 | )
159 | self.merge = nn.Linear(512 + 16, 512)
160 |
161 | self.out = nn.LSTM(input_size=512, hidden_size=512, num_layers=1, batch_first=True)
162 | self.predict1 = nn.Sequential(
163 | nn.Linear(512, config.k_class),
164 | nn.Softmax(dim=-1)
165 | )
166 | self.predict2 = nn.Sequential(
167 | nn.Linear(512, config.k_state),
168 | nn.Softmax(dim=-1)
169 | )
170 | self.tc = nn.Linear(512, config.tissue_class + 1)
171 |
172 | def forward(self, tissue, bag, adj, cell_size, valid_mask, raw_size): # cell number * 3 * 224 * 224
173 | mask_feats = self.resnet(bag).squeeze().reshape(raw_size['batch'], raw_size['cells'], -1)
174 | size_feats = self.size_embed(cell_size) # B C 16
175 | mask_feats = self.merge(torch.concat([mask_feats, size_feats], dim=-1)) # B C 512
176 | graph_feats = self.gat(mask_feats, adj) # B C 512
177 |
178 | global_feat = self.resnet(tissue).squeeze()
179 | if len(global_feat.shape) <= 1:
180 | global_feat = global_feat.unsqueeze(0)
181 |
182 | global_feats = torch.stack([global_feat for _ in range(graph_feats.shape[1])], dim=1)
183 |         seq_feats = torch.stack([global_feats, graph_feats, graph_feats], dim=2)  # B C 3 512
184 | seq_feats = rearrange(seq_feats, 'B C L F-> (B C) L F')
185 | out_feats, _ = self.out(seq_feats)
186 | out_feats = rearrange(out_feats, '(B C) L F-> B C L F', B=raw_size['batch']) # B C 3 512
187 |
188 | # tissue
189 | tissue_cat = self.tc(out_feats[:, 0, 0])
190 | # Cell Proportion
191 | type_probs = self.predict1(out_feats[:, :, 1]) # 16 64 cell_type
192 | type_prop_list, type_prob_list = [], []
193 | for single_probs, valid_index in zip(type_probs, valid_mask):
194 | if valid_index <= 0:
195 | continue
196 | type_prop_list.append(torch.mean(single_probs[:valid_index], dim=0))
197 | type_prob_list.append(single_probs[:valid_index])
198 |
199 | avg_type_probs = torch.stack(type_prop_list, dim=0)
200 |
201 | state_probs = self.predict2(out_feats[:, :, 2]) # 16 64 cell_type
202 | state_prop_list, state_prob_list = [], []
203 | for single_probs, valid_index in zip(state_probs, valid_mask):
204 | if valid_index <= 0:
205 | continue
206 | state_prop_list.append(torch.mean(single_probs[:valid_index], dim=0))
207 | state_prob_list.append(single_probs[:valid_index])
208 |
209 | avg_state_probs = torch.stack(state_prop_list, dim=0)
210 |
211 | return {
212 | 'tissue_compartment': tissue_cat,
213 | 'type_prob_list': type_prob_list,
214 | 'type_prop': avg_type_probs,
215 | 'state_prob_list': state_prob_list,
216 | 'state_prop': avg_state_probs
217 | }
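

# A minimal smoke test of HistoCell's forward signature (shapes follow the
# DataLoader output in data.py and the reshape in infer.py; all values are
# illustrative, and pretrained=False avoids the ImageNet weight download):
if __name__ == '__main__':
    from yacs.config import CfgNode as CN

    cfg = CN()
    cfg.pretrained = False
    cfg.channels = 3
    cfg.k_class = 6
    cfg.tissue_class = 3

    model = HistoCell(cfg)
    B, C = 2, 8                            # tiles per batch, padded cells per tile
    tissue = torch.randn(B, 3, 128, 128)   # tile images
    cells = torch.randn(B * C, 3, 30, 30)  # flattened per-cell crops
    adj = torch.eye(C).repeat(B, 1, 1)     # adjacency with self-loops only
    sizes = torch.rand(B, C, 2)            # normalized bbox extents
    valid = torch.tensor([8, 5])           # real (non-padded) cells per tile

    prob_list, props, tissue_cat, feats = model(
        tissue, cells, adj, sizes, valid, {'batch': B, 'cells': C})
    print(props.shape, tissue_cat.shape)   # torch.Size([2, 6]) and torch.Size([2, 4])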
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=main
5 | _openmp_mutex=5.1=1_gnu
6 | ca-certificates=2025.2.25=h06a4308_0
7 | certifi=2022.12.7=py37h06a4308_0
8 | charset-normalizer=3.4.1=pypi_0
9 | einops=0.6.1=pypi_0
10 | idna=3.10=pypi_0
11 | joblib=1.3.2=pypi_0
12 | ld_impl_linux-64=2.40=h12ee557_0
13 | libffi=3.4.4=h6a678d5_1
14 | libgcc-ng=11.2.0=h1234567_1
15 | libgomp=11.2.0=h1234567_1
16 | libstdcxx-ng=11.2.0=h1234567_1
17 | ncurses=6.4=h6a678d5_0
18 | numpy=1.21.6=pypi_0
19 | opencv-python=4.11.0.86=pypi_0
20 | openssl=1.1.1w=h7f8727e_0
21 | pandas=1.3.5=pypi_0
22 | pillow=9.5.0=pypi_0
23 | pip=22.3.1=py37h06a4308_0
24 | python=3.7.16=h7a1cb2a_0
25 | python-dateutil=2.9.0.post0=pypi_0
26 | pytz=2025.1=pypi_0
27 | pyyaml=6.0.1=pypi_0
28 | readline=8.2=h5eee18b_0
29 | requests=2.31.0=pypi_0
30 | scikit-learn=1.0.2=pypi_0
31 | scipy=1.7.3=pypi_0
32 | setuptools=65.6.3=py37h06a4308_0
33 | six=1.17.0=pypi_0
34 | sqlite=3.45.3=h5eee18b_0
35 | threadpoolctl=3.1.0=pypi_0
36 | tk=8.6.14=h39e8969_0
37 | torch=1.12.1+cu113=pypi_0
38 | torchaudio=0.12.1+cu113=pypi_0
39 | torchvision=0.13.1+cu113=pypi_0
40 | tqdm=4.67.1=pypi_0
41 | typing-extensions=4.7.1=pypi_0
42 | urllib3=2.0.7=pypi_0
43 | wheel=0.38.4=py37h06a4308_0
44 | xz=5.6.4=h5eee18b_1
45 | yacs=0.1.8=pypi_0
46 | zlib=1.2.13=h5eee18b_1
47 |
--------------------------------------------------------------------------------
/tcs/tissue_compartment_BRCA.json:
--------------------------------------------------------------------------------
1 | {
2 | "dict":
3 | {"0": "Stromal", "1": "TME", "2": "TME", "3": "TME", "4": "TME", "5": "Epi"},
4 | "list":
5 | ["Epi", "TME", "Stromal"],
6 | "HoVerNet":
7 | [2, 0, 1, 2, 2, 0]
8 | }
--------------------------------------------------------------------------------
/tcs/tissue_compartment_CRC.json:
--------------------------------------------------------------------------------
1 | {
2 | "dict":
3 | {"0": "Epi", "1": "Stromal", "2": "TME", "3": "TME", "4": "TME", "5": "TME"},
4 | "list":
5 | ["Epi", "TME", "Stromal"],
6 | "HoVerNet":
7 | [2, 0, 1, 2, 2, 0]
8 | }
--------------------------------------------------------------------------------
/tcs/tissue_compartment_HCC.json:
--------------------------------------------------------------------------------
1 | {
2 | "dict":
3 | {"0": "TME", "1": "Epi", "2": "TME", "3": "Stromal", "4": "TME", "5": "Epi"},
4 | "list":
5 | ["Epi", "TME", "Stromal"],
6 | "HoVerNet":
7 | [2, 0, 1, 2, 2, 0]
8 | }
--------------------------------------------------------------------------------
/tcs/tissue_compartment_HD.json:
--------------------------------------------------------------------------------
1 | {
2 | "dict":
3 | {"0": "TME", "1": "TME", "2": "TME", "3": "Stromal", "4": "Epi", "5": "Epi"},
4 | "list":
5 | ["Epi", "TME", "Stromal"],
6 | "HoVerNet":
7 | [2, 0, 1, 2, 2, 0]
8 | }
--------------------------------------------------------------------------------
/tcs/tissue_compartment_LUSC.json:
--------------------------------------------------------------------------------
1 | {
2 | "dict":
3 | {"0": "TME", "1": "Epi", "2": "Epi", "3": "Stromal", "4": "TME"},
4 | "list":
5 | ["Epi", "TME", "Stromal"],
6 | "HoVerNet":
7 | [2, 0, 1, 2, 2, 0]
8 | }
--------------------------------------------------------------------------------
/tcs/tissue_compartment_Mix.json:
--------------------------------------------------------------------------------
1 | {
2 | "dict":
3 | {"0": "Epi", "1": "TME", "2": "Stromal", "3": "TME", "4": "Epi"},
4 | "list":
5 | ["Epi", "TME", "Stromal"],
6 | "HoVerNet":
7 | [2, 0, 1, 2, 2, 0]
8 | }
--------------------------------------------------------------------------------
/tcs/tissue_compartment_OVC.json:
--------------------------------------------------------------------------------
1 | {
2 | "dict":
3 | {"0": "Stromal", "1": "TME", "2": "Stromal", "3": "Epi", "4": "TME", "5": "TME", "6": "TME", "7": "Epi", "8": "TME", "9": "TME", "10": "TME"},
4 | "list":
5 | ["Epi", "TME", "Stromal"],
6 | "HoVerNet":
7 | [2, 0, 1, 2, 2, 0]
8 | }
--------------------------------------------------------------------------------
/tcs/tissue_compartment_PDAC.json:
--------------------------------------------------------------------------------
1 | {
2 | "dict":
3 | {"0": "TME", "1": "Epi", "2": "TME", "3": "Epi", "4": "Stromal", "5": "TME"},
4 | "list":
5 | ["Epi", "TME", "Stromal"],
6 | "HoVerNet":
7 | [2, 0, 1, 2, 2, 0]
8 | }
--------------------------------------------------------------------------------
/tcs/tissue_compartment_PRAD.json:
--------------------------------------------------------------------------------
1 | {
2 | "dict":
3 | {"0": "TME", "1": "Epi", "2": "Stromal", "3": "Stromal", "4": "TME", "5": "TME", "6": "TME"},
4 | "list":
5 | ["Epi", "TME", "Stromal"],
6 | "HoVerNet":
7 | [2, 0, 1, 2, 2, 0]
8 | }
--------------------------------------------------------------------------------
/tcs/tissue_compartment_RCC.json:
--------------------------------------------------------------------------------
1 | {
2 | "dict":
3 | {"0": "Epi", "1": "Epi", "2": "Stromal", "3": "TME", "4": "TME", "5": "TME", "6": "Stromal", "7": "TME", "8": "Epi", "9": "TME", "10": "TME", "11": "TME"},
4 | "list":
5 | ["Epi", "TME", "Stromal"],
6 | "HoVerNet":
7 | [2, 0, 1, 2, 2, 0]
8 | }
--------------------------------------------------------------------------------
/tcs/tissue_compartment_STAD.json:
--------------------------------------------------------------------------------
1 | {
2 | "dict":
3 | {"0": "Epi", "1": "Epi", "2": "Epi", "3": "TME", "4": "Epi", "5": "Epi", "6": "Epi", "7": "Epi", "8": "TME", "9": "Stromal", "10": "Stromal", "11": "TME", "12": "TME", "13": "Stromal", "14": "Epi", "15": "Stromal", "16": "TME"},
4 | "list":
5 | ["Epi", "TME", "Stromal"],
6 | "HoVerNet":
7 | [2, 0, 1, 2, 2, 3]
8 | }
--------------------------------------------------------------------------------
/tcs/tissue_compartment_Xenium.json:
--------------------------------------------------------------------------------
1 | {
2 | "dict":
3 | {"0": "TME", "1": "Epi", "2": "Stromal", "3": "TME", "4": "Epi", "5": "Stromal", "6": "TME"},
4 | "list":
5 | ["Epi", "TME", "Stromal"],
6 | "HoVerNet":
7 | [2, 0, 1, 2, 2, 0]
8 | }
--------------------------------------------------------------------------------
/tcs/tissue_compartment_state.json:
--------------------------------------------------------------------------------
1 | {
2 | "dict":
3 | {"0": "TME", "1": "Epi", "2": "TME", "3": "TME", "4": "Stromal", "5": "TME"},
4 | "list":
5 | ["Epi", "TME", "Stromal"],
6 | "state":
7 | {"0": "1", "1": "6", "2": "1", "3": "1", "4": "1", "5": "1"},
8 | "HoVerNet":
9 | [2, 0, 1, 2, 2, 0]
10 | }
--------------------------------------------------------------------------------
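Each `tcs/*.json` groups deconvolution cluster labels into the three coarse compartments. A minimal loading sketch (my own usage, not repo code; reading the `HoVerNet` field as per-class entries aligned with `list` is an assumption):

```python
import json

# Load one tissue-compartment spec from this repo's tcs/ directory.
with open("tcs/tissue_compartment_BRCA.json", encoding="utf-8") as f:
    tc = json.load(f)

cluster_to_compartment = tc["dict"]  # deconvolution cluster id -> compartment name
compartments = tc["list"]            # ordered compartment labels, e.g. ["Epi", "TME", "Stromal"]
hovernet_field = tc["HoVerNet"]      # assumed: one entry per HoVerNet nuclei class

print(f"{len(cluster_to_compartment)} clusters grouped into {compartments}")
```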
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import argparse
4 | from model.arch import HistoCell
5 | from data import TileBatchDataset
6 | from utils.utils import *
7 | from torch.utils.data import DataLoader
8 | from tqdm import tqdm
9 | from torch.optim import Adam
10 | from configs import _get_config
11 |
12 | def main():
13 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
14 | parser = argparse.ArgumentParser(description="Prediction with Spots Images")
15 | parser.add_argument('--seed', default=47, type=int)
16 | parser.add_argument('--model', default='Baseline', type=str, help='model description')
17 | parser.add_argument('--tissue', default='BRCA', type=str)
18 | parser.add_argument('--deconv', default='CARD', type=str)
19 | parser.add_argument('--subtype', action='store_true')
20 | parser.add_argument('--prefix', required=False, nargs='+')
21 | parser.add_argument('--k_class', default=8, type=int)
22 | parser.add_argument('--tissue_compartment', type=str, required=True)
23 | parser.add_argument('--ratio', type=float, required=False, default=1.0)
24 |
25 | args = parser.parse_args()
26 | config = _get_config(args.tissue, args.deconv, args.subtype, args.k_class, args.tissue_compartment)
27 |
28 | setup_seed(args.seed)
29 | print(f"Model details: {args.model}")
30 |
31 | print("Load Dataset...")
32 |
33 | data_prefix = args.prefix if isinstance(args.prefix, list) else [args.prefix]
34 |
35 | print(data_prefix)
36 | train_data = TileBatchDataset(config.data.tile_dir, config.data.mask_dir, config.data.tissue_dir, config.data.cell_dir, deconv=config.data.deconv, prefix=data_prefix, sample_ratio=args.ratio, ext='png')
37 | print(f"length of train data: {len(train_data)}")
38 | train_loader = DataLoader(train_data, batch_size=config.data.batch_size, shuffle=True, num_workers=6, pin_memory=True)
39 |
40 | print("Load Model...")
41 | model = HistoCell(config.model)
42 | loss_func = torch.nn.KLDivLoss()
43 | aux_loss = torch.nn.CrossEntropyLoss()
44 | optimizer = Adam(model.parameters(), lr=config.train.lr)
45 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46 | if torch.cuda.is_available():
47 | print("Using 1 gpu")
48 | else:
49 | print("Using cpu")
50 |
51 |
52 | model_dir, ckpt_dir = \
53 | os.path.join(config.data.save_model, args.model), os.path.join(config.data.ckpt, args.model)
54 | os.makedirs(model_dir, exist_ok=True), os.makedirs(ckpt_dir, exist_ok=True)
55 |
56 | ckpt_file = find_ckpt(model_dir)
57 | # ipdb.set_trace()
58 | if ckpt_file is not None:
59 | state_dict, current_epoch = load_ckpt(os.path.join(model_dir, ckpt_file))
60 | model.load_state_dict(state_dict)
61 | else:
62 | current_epoch = -1
63 |
64 | model.to(device)
65 |
66 | print(f"Current Epoch: {current_epoch}")
67 |
68 | print("Training...")
69 | for iter in range(current_epoch + 1, config.train.epoch):
70 | train_loop(iter, model, train_loader, loss_func, aux_loss, optimizer, device)
71 | if iter % config.train.val_iter == 0 and iter > config.train.val_min_iter:
72 | save_checkpoint(model, optimizer, os.path.join(model_dir, f'epoch_{iter}.ckpt'))
73 |
74 | print("Training finished!")
75 | print(f"Results saved in {os.path.join(config.data.ckpt, args.model)}")
76 |
77 |
78 | def train_loop(epoch, model, train_loader, loss_func, aux_loss, opt, device):
79 | model.train()
80 | train_bar = tqdm(train_loader, desc="epoch " + str(epoch), total=len(train_loader),
81 | unit="batch", dynamic_ncols=True)
82 | for idx, data in enumerate(train_bar):
83 | tissue = data['tissue'].to(torch.float32).to(device)
84 | images, cell_proportion = data['image'].to(torch.float32).to(device), data['cells'].to(torch.float32).to(device)
85 | valid_mask = data['mask'].to(torch.long).to(device)
86 | cell_size = data['size'].to(torch.float32).to(device)
87 | adj_mat = data['adj'].to(torch.float32).to(device)
88 | gt_tissue_cat = data['tissue_cat'].to(device)
89 | if torch.sum(data['mask']) <= 0:
90 | continue
91 | batch, cells, channels, height, width = images.shape
92 | images = images.reshape(batch * cells, channels, height, width)
93 | probs, pred_proportion, tissue_cat, _ = model(tissue, images, adj_mat, cell_size, valid_mask, {'batch': batch, 'cells': cells})
94 | cell_prop_list = []
95 | for single_props, valid_index in zip(cell_proportion, valid_mask):
96 | if valid_index <= 0:
97 | continue
98 | cell_prop_list.append(single_props)
99 | cell_proportion = torch.stack(cell_prop_list, dim=0)
100 |
101 | loss_out = loss_func((cell_proportion + 1e-10).log(), pred_proportion + 1e-10) + \
102 | loss_func((pred_proportion + 1e-10).log(), cell_proportion + 1e-10) + \
103 | aux_loss(tissue_cat, gt_tissue_cat)
104 | loss_out.backward()
105 | opt.step()
106 | opt.zero_grad()
107 |
108 | loss_value = loss_out.item()
109 | train_bar.set_description('epoch:{} iter:{} loss:{}'.format(epoch, idx, loss_value))
110 |
111 | if __name__ == '__main__':
112 | main()
--------------------------------------------------------------------------------
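Given the arguments defined in `main()`, a typical invocation might look as follows (a sketch; `sample_A` and `sample_B` are hypothetical sample prefixes):

```sh
python train.py --model Baseline --tissue BRCA --deconv CARD \
    --k_class 8 --tissue_compartment tcs/tissue_compartment_BRCA.json \
    --prefix sample_A sample_B
```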
/train_oneout_pannuke.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import argparse
5 | import random
6 | import pickle as pkl
7 | from model.arch import HistoCell
8 | from data import TileBatchDataset
9 | from utils.utils import *
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 | from yacs.config import CfgNode as CN
13 | from torch.optim import Adam
14 |
15 | def _get_config(tissue_type, deconv, subtype, k_class, tissue_dir):
16 | config = CN()
17 | config.train = CN()
18 | config.train.lr = 0.0005
19 | config.train.epoch = 61
20 | config.train.val_iter = 20
21 | config.train.val_min_iter = 1
22 |
23 | config.data = CN()
24 | config.data.deconv = deconv
25 | config.data.save_model = f'./train_log/pannuke/{tissue_type}/models'
26 | config.data.ckpt = f'./train_log/pannuke/{tissue_type}/ckpts'
27 | config.data.tile_dir = '/data/gcf22/PaNnuke/raw_images'
28 | config.data.mask_dir = '/data/gcf22/PaNnuke/cls_results_cv'
29 | config.data.batch_size = 64
30 | config.data.tissue_dir = tissue_dir
31 |
32 | config.model = CN()
33 | config.model.tissue_class = 3
34 | config.model.pretrained = True
35 | config.model.channels = 3
36 |
37 | config.data.cell_dir = '/data/gcf22/PaNnuke/cell_props'
38 |
39 | config.model.k_class = k_class
40 |
41 | return config
42 |
43 |
44 | def train_loop(epoch, model, train_loader, loss_func, aux_loss, opt, device):
45 | model.train()
46 | train_bar = tqdm(train_loader, desc="epoch " + str(epoch), total=len(train_loader),
47 | unit="batch", dynamic_ncols=True)
48 | for idx, data in enumerate(train_bar):
49 | tissue = data['tissue'].to(torch.float32).to(device)
50 | images, cell_proportion = data['image'].to(torch.float32).to(device), data['cells'].to(torch.float32).to(device)
51 | valid_mask = data['mask'].to(torch.long).to(device)
52 | cell_size = data['size'].to(torch.float32).to(device)
53 | adj_mat = data['adj'].to(torch.float32).to(device)
54 | gt_tissue_cat = data['tissue_cat'].to(device)
55 | if torch.sum(data['mask']) <= 0:
56 | continue
57 | batch, cells, channels, height, width = images.shape
58 | images = images.reshape(batch * cells, channels, height, width)
59 | probs, pred_proportion, tissue_cat, _ = model(tissue, images, adj_mat, cell_size, valid_mask, {'batch': batch, 'cells': cells})
60 | cell_prop_list = []
61 | for single_props, valid_index in zip(cell_proportion, valid_mask):
62 | if valid_index <= 0:
63 | continue
64 | cell_prop_list.append(single_props)
65 | cell_proportion = torch.stack(cell_prop_list, dim=0)
66 |
67 | loss_out = loss_func((cell_proportion + 1e-10).log(), pred_proportion + 1e-10) + \
68 | loss_func((pred_proportion + 1e-10).log(), cell_proportion + 1e-10)
69 | loss_out.backward()
70 | opt.step()
71 | opt.zero_grad()
72 |
73 | loss_value = loss_out.item()
74 | train_bar.set_description('epoch:{} iter:{} loss:{}'.format(epoch, idx, loss_value))
75 |
76 |
77 | def val_loop(epoch, model, val_loader, device):
78 | model.eval()  # evaluation mode (no dropout/batchnorm updates) during validation
79 | val_bar = tqdm(val_loader, desc="epoch " + str(epoch), total=len(val_loader),
80 | unit="batch", dynamic_ncols=True)
81 |
82 | all_results = {}
83 | for idx, data in enumerate(val_bar):
84 | tissue = data['tissue'].to(torch.float32).to(device)
85 | images, cell_proportion = data['image'].to(torch.float32).to(device), data['cells'].to(torch.float32).to(device)
86 | cell_size = data['size'].to(torch.float32).to(device)
87 | adj_mat = data['adj'].to(torch.float32).to(device)
88 | valid_mask = data['mask'].to(torch.long).to(device)
89 | cell_coords = data['cell_coords'].to(torch.float32).to(device)
90 | ref_type = data['cell_types'].to(torch.long).to(device)
91 | if torch.sum(data['mask']) <= 0:
92 | continue
93 | batch, cells, channels, height, width = images.shape
94 | images = images.reshape(batch * cells, channels, height, width)
95 | prob_list, pred_proportion, _, cell_embeddings = model(tissue, images, adj_mat, cell_size, valid_mask, {'batch': batch, 'cells': cells})
96 | name_list, valid_list, coord_list, type_list, cell_prop_list = [], [], [], [], []
97 | for name, single_prop, valid_index, valid_coord, valid_type in zip(data['name'], cell_proportion, valid_mask, cell_coords, ref_type):
98 | if valid_index <= 0:
99 | continue
100 | name_list.append(name)
101 | valid_list.append(valid_index)
102 | coord_list.append(valid_coord)
103 | type_list.append(valid_type)
104 | cell_prop_list.append(single_prop)
105 |
106 | for name, pb, pp, cp, vix, cc, ct in zip(name_list, prob_list, pred_proportion, cell_prop_list, valid_list, coord_list, type_list):
107 | valid_num = int(vix.detach().cpu())
108 | all_results.update(
109 | {
110 | name: {
111 | 'pred_proportion': pp.detach().cpu().numpy(),
112 | 'cell_num': valid_num,
113 | 'cell_coords': cc[:valid_num].detach().cpu().numpy(),
114 | 'prob': pb.detach().cpu().numpy(),
115 | 'prior_type': ct[:valid_num].detach().cpu().numpy(),
116 | 'gt_proportion': cp.detach().cpu().numpy(),
117 | # 'cell_embedding': cell_features.detach().cpu().numpy()
118 | }
119 | }
120 | )
121 | val_bar.set_description('epoch:{} iter:{}'.format(epoch, idx))
122 | return all_results
123 |
124 |
125 | def main():
126 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
127 | parser = argparse.ArgumentParser(description="Prediction with Spots Images")
128 | parser.add_argument('--seed', default=47, type=int)
129 | parser.add_argument('--model', default='Baseline', type=str, help='model description')
130 | parser.add_argument('--tissue', default='*', type=str)
131 | parser.add_argument('--deconv', default='CARD', type=str)
132 | parser.add_argument('--subtype', action='store_true')
133 | parser.add_argument('--prefix', required=False, nargs='+')
134 | parser.add_argument('--k_class', default=8, type=int)
135 | parser.add_argument('--tissue_compartment', type=str, required=True)
136 | parser.add_argument('--sample_ratio', default=1, type=float)
137 | parser.add_argument('--reso', default=1, type=int)
138 |
139 | args = parser.parse_args()
140 | config = _get_config(args.tissue, args.deconv, args.subtype, args.k_class, args.tissue_compartment)
141 |
142 | setup_seed(args.seed)
143 | print(f"Model details: {args.model}")
144 |
145 | print("Load Dataset...")
146 | data_prefix = args.prefix if isinstance(args.prefix, list) else [args.prefix]
147 | print(data_prefix)
148 |
149 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
150 | if torch.cuda.is_available():
151 | print("Using 1 gpu")
152 | else:
153 | print("Using cpu")
154 |
155 | for idx, val_prefix in enumerate(data_prefix):
156 | raw_prefix = [item for item in data_prefix]
157 | raw_prefix.pop(idx)
158 | val_fold = TileBatchDataset(config.data.tile_dir, config.data.mask_dir, config.data.tissue_dir, config.data.cell_dir, deconv=config.data.deconv, prefix=[val_prefix], focus_tissue=args.tissue)
159 | train_fold = TileBatchDataset(config.data.tile_dir, config.data.mask_dir, config.data.tissue_dir, config.data.cell_dir, deconv=config.data.deconv, prefix=raw_prefix, focus_tissue=args.tissue, reso=args.reso)
160 | print(f"length of train data: {len(train_fold)} with sample {raw_prefix}")
161 | train_fold, _ = torch.utils.data.random_split(train_fold, [int(len(train_fold)*args.sample_ratio), len(train_fold) - int(len(train_fold)*args.sample_ratio)])
162 |
163 | print(f"length of sampled train data: {len(train_fold)} with sample {raw_prefix}")
164 | print(f"length of val data: {len(val_fold)} with sample {val_prefix}")
165 |
166 | train_loader = DataLoader(train_fold, batch_size=config.data.batch_size, num_workers=6, pin_memory=True)
167 | val_loader = DataLoader(val_fold, batch_size=config.data.batch_size // 2, num_workers=6, pin_memory=True)
168 |
169 | print("Load Model...")
170 | model = HistoCell(config.model)
171 | loss_func = torch.nn.KLDivLoss()
172 | aux_loss = torch.nn.CrossEntropyLoss()
173 | optimizer = Adam(model.parameters(), lr=config.train.lr)
174 |
175 | model_dir, ckpt_dir = \
176 | os.path.join(config.data.save_model, args.model, val_prefix), os.path.join(config.data.ckpt, args.model, val_prefix)
177 | os.makedirs(model_dir, exist_ok=True), os.makedirs(ckpt_dir, exist_ok=True)
178 |
179 | ckpt_file = find_ckpt(model_dir)
180 | # ipdb.set_trace()
181 | if ckpt_file is not None:
182 | state_dict, current_epoch = load_ckpt(os.path.join(model_dir, ckpt_file))
183 | model.load_state_dict(state_dict)
184 | else:
185 | current_epoch = -1
186 |
187 | model.to(device)
188 |
189 | print(f"Current Epoch: {current_epoch}")
190 |
191 | print("Training...")
192 | for iter in range(current_epoch + 1, config.train.epoch):
193 | train_loop(iter, model, train_loader, loss_func, aux_loss, optimizer, device)
194 | # if iter % 5 == 0:
195 | # save_checkpoint(model, optimizer, os.path.join(model_dir, f'epoch_{iter}.ckpt'))
196 | # TODO
197 | if iter > config.train.val_min_iter and iter % config.train.val_iter == 0:
198 | print("##### Validation #####")
199 | val_results = val_loop(iter, model, val_loader, device)
200 | save_checkpoint(model, optimizer, os.path.join(model_dir, f'epoch_{iter}.pt'))
201 | with open(os.path.join(ckpt_dir, f'epoch_{iter}_val.pkl'), 'wb') as file:
202 | pkl.dump(val_results, file)
203 |
204 | # save_checkpoint(model, optimizer, os.path.join(model_dir, f'epoch_{iter}.ckpt'))
205 |
206 | print("Training finished!")
207 | print(f"Results saved in {os.path.join(config.data.ckpt, args.model)}")
208 |
209 | if __name__ == '__main__':
210 | main()
--------------------------------------------------------------------------------
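`val_loop` dumps one dictionary per tile into `epoch_{iter}_val.pkl`. A minimal sketch of reading those results back (the path components `Breast`, `Baseline`, and `fold0` are hypothetical placeholders for the tissue, model name, and held-out prefix):

```python
import pickle as pkl
import numpy as np

with open("train_log/pannuke/Breast/ckpts/Baseline/fold0/epoch_20_val.pkl", "rb") as f:
    results = pkl.load(f)

for name, res in results.items():
    # Compare predicted vs. deconvolution-derived cell-type proportions per tile.
    l1_err = np.abs(res["pred_proportion"] - res["gt_proportion"]).mean()
    print(name, "cells:", res["cell_num"], "L1 error:", round(float(l1_err), 4))
```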
/train_state.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import argparse
4 | from model.arch import HistoState
5 | from data import TileBatchStateDataset
6 | from utils.utils import *
7 | from torch.utils.data import DataLoader
8 | from tqdm import tqdm
9 | from torch.optim import Adam
10 | from configs import _get_cell_state_config
11 |
12 | def main():
13 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
14 | parser = argparse.ArgumentParser(description="Prediction with Spots Images")
15 | parser.add_argument('--seed', default=47, type=int)
16 | parser.add_argument('--model', default='Baseline', type=str, help='model description')
17 | parser.add_argument('--tissue', default='BRCA', type=str)
18 | parser.add_argument('--deconv', default='CARD', type=str)
19 | parser.add_argument('--subtype', action='store_true')
20 | parser.add_argument('--prefix', required=False, nargs='+')
21 | parser.add_argument('--tissue_compartment', type=str, required=True)
22 |
23 | args = parser.parse_args()
24 | config = _get_cell_state_config(args.tissue, args.deconv, args.subtype, args.tissue_compartment)
25 |
26 | setup_seed(args.seed)
27 | print(f"Model details: {args.model}")
28 |
29 | print("Load Dataset...")
30 |
31 | data_prefix = args.prefix if isinstance(args.prefix, list) else [args.prefix]
32 |
33 | print(data_prefix)
34 | train_data = TileBatchStateDataset(config.data.tile_dir, config.data.mask_dir, config.data.tissue_dir, config.data.cell_dir, config.data.state_dir, deconv=config.data.deconv, prefix=data_prefix)
35 | print(f"length of train data: {len(train_data)}")
36 | train_loader = DataLoader(train_data, batch_size=config.data.batch_size, shuffle=True, num_workers=6, pin_memory=True)
37 |
38 | print("Load Model...")
39 | model = HistoState(config.model)
40 | loss_func = torch.nn.KLDivLoss()
41 | aux_loss = torch.nn.CrossEntropyLoss()
42 | optimizer = Adam(model.parameters(), lr=config.train.lr)
43 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44 | if torch.cuda.is_available():
45 | print("Using 1 gpu")
46 | else:
47 | print("Using cpu")
48 |
49 |
50 | model_dir, ckpt_dir = \
51 | os.path.join(config.data.save_model, args.model), os.path.join(config.data.ckpt, args.model)
52 | os.makedirs(model_dir, exist_ok=True), os.makedirs(ckpt_dir, exist_ok=True)
53 |
54 | ckpt_file = find_ckpt(model_dir)
55 | # ipdb.set_trace()
56 | if ckpt_file is not None:
57 | state_dict, current_epoch = load_ckpt(os.path.join(model_dir, ckpt_file))
58 | model.load_state_dict(state_dict)
59 | else:
60 | current_epoch = -1
61 |
62 | model.to(device)
63 |
64 | print(f"Current Epoch: {current_epoch}")
65 |
66 | print("Training with Cell Types...")
67 | if current_epoch + 1 <= config.train.epoch:
68 | for iter in range(current_epoch + 1, config.train.epoch):
69 | train_loop(iter, model, train_loader, loss_func, aux_loss, optimizer, device)
70 |
71 | if iter % config.train.val_iter == 0 and iter > config.train.val_min_iter:
72 | save_checkpoint(model, optimizer, os.path.join(model_dir, f'epoch_{iter}.ckpt'))
73 |
74 | print("Training with Cell States...")
75 | for iter in range(max(config.train.epoch, current_epoch + 1), config.train.epoch + config.train.state_epoch):
76 | train_loop(iter, model, train_loader, loss_func, aux_loss, optimizer, device, with_state=True)
77 |
78 | if iter % config.train.val_iter == 0 and iter > config.train.val_min_iter:
79 | save_checkpoint(model, optimizer, os.path.join(model_dir, f'epoch_{iter}.ckpt'))
80 |
81 | print("Training finished!")
82 | print(f"Results saved in {os.path.join(config.data.ckpt, args.model)}")
83 |
84 |
85 | def train_loop(epoch, model, train_loader, loss_func, aux_loss, opt, device, with_state=False):
86 | model.train()
87 | train_bar = tqdm(train_loader, desc="epoch " + str(epoch), total=len(train_loader),
88 | unit="batch", dynamic_ncols=True)
89 | for idx, data in enumerate(train_bar):
90 | tissue = data['tissue'].to(torch.float32).to(device)
91 | images, cell_proportion, state_proportion = \
92 | data['image'].to(torch.float32).to(device), data['cells'].to(torch.float32).to(device), data['states'].to(torch.float32).to(device)
93 | valid_mask = data['mask'].to(torch.long).to(device)
94 | cell_size = data['size'].to(torch.float32).to(device)
95 | adj_mat = data['adj'].to(torch.float32).to(device)
96 | gt_tissue_cat = data['tissue_cat'].to(device)
97 | if torch.sum(data['mask']) <= 0:
98 | continue
99 | batch, cells, channels, height, width = images.shape
100 | images = images.reshape(batch * cells, channels, height, width)
101 | pred_dict = model(tissue, images, adj_mat, cell_size, valid_mask, {'batch': batch, 'cells': cells})
102 | cell_prop_list = []
103 | for single_props, valid_index in zip(cell_proportion, valid_mask):
104 | if valid_index <= 0:
105 | continue
106 | cell_prop_list.append(single_props)
107 | cell_proportion = torch.stack(cell_prop_list, dim=0)
108 |
109 | loss_out = loss_func((cell_proportion + 1e-10).log(), pred_dict['type_prop'] + 1e-10) + \
110 | loss_func((pred_dict['type_prop'] + 1e-10).log(), cell_proportion + 1e-10) + \
111 | aux_loss(pred_dict['tissue_compartment'], gt_tissue_cat)
112 |
113 | if with_state:
114 | state_prop_list = []
115 | for single_states, valid_index in zip(state_proportion, valid_mask):
116 | if valid_index <= 0:
117 | continue
118 | state_prop_list.append(single_states)
119 | state_proportion = torch.stack(state_prop_list, dim=0)
120 |
121 | state_loss = loss_func((state_proportion + 1e-10).log(), pred_dict['state_prop'] + 1e-10) + loss_func((pred_dict['state_prop'] + 1e-10).log(), state_proportion + 1e-10)
122 | # consist_loss = torch.nn.functional.l1_loss(pred_dict['state_prop'][:, :5], pred_dict['type_prop'][:, [0, 2, 3, 5, 6]])
123 | consist_loss = torch.nn.functional.l1_loss(pred_dict['state_prop'][:, :6], pred_dict['type_prop'][:, :6])
124 | loss_out = loss_out + state_loss + 0.5 * consist_loss
125 |
126 | loss_out.backward()
127 | opt.step()
128 | opt.zero_grad()
129 |
130 | loss_value = loss_out.item()
131 | train_bar.set_description('epoch:{} iter:{} loss:{}'.format(epoch, idx, loss_value))
132 |
133 | if __name__ == '__main__':
134 | main()
--------------------------------------------------------------------------------
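The objective in `train_loop` above is a symmetric KL divergence between the deconvolution type proportions `c` and the predicted proportions `p`, plus a cross-entropy term on the tissue compartment (logits `\hat{t}`, label `t`); the state stage adds a symmetric KL between reference and predicted state proportions `s`, `q` and an L1 consistency penalty (weight 0.5) over the first six state vs. type proportions. In LaTeX (notation mine):

```latex
\mathcal{L}_{\mathrm{type}} = \mathrm{KL}(p \,\|\, c) + \mathrm{KL}(c \,\|\, p) + \mathrm{CE}(\hat{t}, t)
\qquad
\mathcal{L}_{\mathrm{state}} = \mathcal{L}_{\mathrm{type}} + \mathrm{KL}(q \,\|\, s) + \mathrm{KL}(s \,\|\, q)
    + \tfrac{1}{2}\,\mathrm{mean}\bigl(\lvert q_{1:6} - p_{1:6} \rvert\bigr)
```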
/tutorial/data/url.txt:
--------------------------------------------------------------------------------
1 | Download from URL:
2 | https://drive.google.com/file/d/1iwXzxafdUF7SF6IAEyWluN728WbxAMEi/view?usp=sharing
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import torch
4 | import random
5 | import os
6 | import cv2
7 | import scipy.spatial  # needed so scipy.spatial.distance is available below
8 | from scipy.optimize import linear_sum_assignment
9 | from PIL import Image
10 |
11 | def reg_loss(props: torch.Tensor):
12 | # props: torch.Size([Batch Cell Prob])
13 | log_prop = torch.log(props)
14 | loss = props.mul(log_prop)
15 | return -torch.mean(loss)
16 |
17 | def index_filter(used_idx, all_idx):
18 | rest_index = []
19 | for item in all_idx:
20 | if item in used_idx:
21 | continue
22 | rest_index.append(item)
23 | return rest_index
24 |
25 | def load_image(img_path):
26 | fp = open(img_path, 'rb')
27 | pic = Image.open(fp)
28 | pic = np.array(pic)
29 | fp.close()
30 | return pic
31 |
32 | def save_image(src, img_path):
33 | img = Image.fromarray(src, 'RGB')
34 | img.save(img_path)
35 |
36 | def load_json(json_path):
37 | with open(json_path, 'r', encoding='utf-8') as file:
38 | nuc_info = json.load(file)
39 |
40 | return nuc_info['nuc']
41 |
42 | def dump_json(src: dict, json_path):
43 | with open(json_path, 'w', encoding='utf-8') as file:
44 | json.dump(src, file)
45 |
46 | def setup_seed(seed):
47 | torch.manual_seed(seed)
48 | torch.cuda.manual_seed_all(seed)
49 | np.random.seed(seed)
50 | random.seed(seed)
51 | torch.backends.cudnn.deterministic = True
52 |
53 | def collate_fn(batch):
54 | tensor_batch = {}
55 | for key, value in batch[0].items():
56 | if isinstance(value, str):
57 | tensor_batch.update({key: value})
58 | continue
59 | tensor_batch.update({key: torch.tensor(value, dtype=torch.float32)})
60 |
61 | return tensor_batch
62 |
63 | def collate_batch_fn(iter_batch):
64 | tensor_batch = {key: [] for key in iter_batch[0].keys()}
65 | for batch in iter_batch:
66 | for key, value in batch.items():
67 | tensor_batch[key].append(value)
68 |
69 | for key in tensor_batch.keys():
70 | if key in ['tissue', 'cells', 'image']:
71 | tensor_batch[key] = torch.stack(tensor_batch[key], dim=0).to(torch.float32)
72 | if key in ['mask']:
73 | tensor_batch[key] = torch.tensor(tensor_batch[key], dtype=torch.long)
74 |
75 | return tensor_batch
76 |
77 | def find_ckpt(file_dir):
78 | ckpt_files = os.listdir(file_dir)
79 | ckpt_files.sort(key=lambda fn: os.path.getmtime(os.path.join(file_dir, fn)) if not os.path.isdir(os.path.join(file_dir, fn)) else 0)
80 | if ckpt_files:
81 | return ckpt_files[-1]
82 | else:
83 | return None
84 |
85 |
86 | def load_ckpt(path):
87 | ckpt = torch.load(path, map_location="cpu")
88 | model = ckpt['state_dict']
89 | current_epoch = int(path.split('/')[-1].split('.')[0].split('_')[-1])
90 | return model, current_epoch
91 |
92 | def save_checkpoint(model, optimizer, save_dir):
93 | print(f"Saving checkpoint to {save_dir}")
94 | checkpoint = {
95 | 'state_dict': model.state_dict(),
96 | 'optimizer': optimizer.state_dict()
97 | }
98 | torch.save(checkpoint, save_dir)
99 |
100 | def get_bounding_box(img):
101 | """Get bounding box coordinate information."""
102 | rows = np.any(img, axis=1)
103 | cols = np.any(img, axis=0)
104 | rmin, rmax = np.where(rows)[0][[0, -1]]
105 | cmin, cmax = np.where(cols)[0][[0, -1]]
106 | # due to python indexing, need to add 1 to max
107 | # else accessing will be 1px in the box, not out
108 | rmax += 1
109 | cmax += 1
110 | return [rmin, rmax, cmin, cmax]
111 |
112 | def get_centroid(pred_inst, types = None, label_dict = None, id_offset = 0):
113 | inst_id_list = np.unique(pred_inst[pred_inst > 0]) # exclude background (label 0)
114 | inst_info_dict = {}
115 | for inst_id in inst_id_list:
116 | inst_map = pred_inst == inst_id
117 | # TODO: change the format of the bbox output
118 | rmin, rmax, cmin, cmax = get_bounding_box(inst_map)
119 | inst_bbox = np.array([[rmin, cmin], [rmax, cmax]])
120 | inst_map = inst_map[
121 | inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1]
122 | ]
123 | inst_map = inst_map.astype(np.uint8)
124 | inst_moment = cv2.moments(inst_map)
125 | inst_contour = cv2.findContours(
126 | inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
127 | )
128 | # * opencv protocol format may break
129 | inst_contour = np.squeeze(inst_contour[0][0].astype("int32"))
130 | # fewer than 3 points do not make a contour, so skip; such instances
131 | # are likely artifacts (contours from approximation can be too small)
132 | if inst_contour.shape[0] < 3:
133 | continue
134 | if len(inst_contour.shape) != 2:
135 | continue # ! guard against unexpected contour shapes
136 | inst_centroid = [
137 | (inst_moment["m10"] / inst_moment["m00"]),
138 | (inst_moment["m01"] / inst_moment["m00"]),
139 | ]
140 | inst_centroid = np.array(inst_centroid)
141 | inst_contour[:, 0] += inst_bbox[0][1] # X
142 | inst_contour[:, 1] += inst_bbox[0][0] # Y
143 | inst_centroid[0] += inst_bbox[0][1] # X
144 | inst_centroid[1] += inst_bbox[0][0] # Y
145 | inst_info_dict[inst_id + id_offset] = { # inst_id should start at 1
146 | "bbox": inst_bbox,
147 | "centroid": inst_centroid,
148 | "contour": inst_contour,
149 | "type": types if label_dict is None else label_dict[types],
150 | }
151 |
152 | return inst_info_dict
153 |
154 |
155 | def pair_coordinates(setA, setB, radius):
156 | """Use the Munkres or Kuhn-Munkres algorithm to find the most optimal
157 | unique pairing (largest possible match) when pairing points in set B
158 | against points in set A, using distance as cost function.
159 |
160 | Args:
161 | setA, setB: np.array (float32) of size Nx2 contains the of XY coordinate
162 | of N different points
163 | radius: valid area around a point in setA to consider
164 | a given coordinate in setB a candidate for match
165 | Return:
166 | pairing: pairing is an array of indices
167 | where point at index pairing[0] in set A paired with point
168 | in set B at index pairing[1]
169 | unparedA, unpairedB: remaining poitn in set A and set B unpaired
170 |
171 | """
172 | # * Euclidean distance as the cost matrix
173 | pair_distance = scipy.spatial.distance.cdist(setA, setB, metric='euclidean')
174 |
175 | # * Munkres pairing with scipy library
176 | # the algorithm returns (row indices, matched column indices);
177 | # if multiple entries in a row share the same cost, the index of the
178 | # first occurrence is returned, so a unique pairing is ensured
179 | indicesA, paired_indicesB = linear_sum_assignment(pair_distance)
180 |
181 | # extract the paired cost and remove instances
182 | # outside of designated radius
183 | pair_cost = pair_distance[indicesA, paired_indicesB]
184 |
185 | pairedA = indicesA[pair_cost <= radius]
186 | pairedB = paired_indicesB[pair_cost <= radius]
187 |
188 | pairing = np.concatenate([pairedA[:,None], pairedB[:,None]], axis=-1)
189 | unpairedA = np.delete(np.arange(setA.shape[0]), pairedA)
190 | unpairedB = np.delete(np.arange(setB.shape[0]), pairedB)
191 | return pairing, unpairedA, unpairedB
192 |
193 | def remap_label(pred, by_size=False):
194 | """Rename all instance id so that the id is contiguous i.e [0, 1, 2, 3]
195 | not [0, 2, 4, 6]. The ordering of instances (which one comes first)
196 | is preserved unless by_size=True, then the instances will be reordered
197 | so that bigger nucler has smaller ID.
198 |
199 | Args:
200 | pred : the 2d array contain instances where each instances is marked
201 | by non-zero integer
202 | by_size : renaming with larger nuclei has smaller id (on-top)
203 |
204 | """
205 | pred_id = list(np.unique(pred))
206 | pred_id.remove(0)
207 | if len(pred_id) == 0:
208 | return pred # no label
209 | if by_size:
210 | pred_size = []
211 | for inst_id in pred_id:
212 | size = (pred == inst_id).sum()
213 | pred_size.append(size)
214 | # sort the id by size in descending order
215 | pair_list = zip(pred_id, pred_size)
216 | pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True)
217 | pred_id, pred_size = zip(*pair_list)
218 |
219 | new_pred = np.zeros(pred.shape, np.int32)
220 | for idx, inst_id in enumerate(pred_id):
221 | new_pred[pred == inst_id] = idx + 1
222 | return new_pred
--------------------------------------------------------------------------------
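A minimal usage sketch (toy arrays, not repo code) for the two instance-map helpers defined above: `pair_coordinates` does Hungarian matching of predicted vs. reference centroids within a radius, and `remap_label` makes instance ids contiguous.

```python
import numpy as np
from utils.utils import pair_coordinates, remap_label

pred_centroids = np.array([[10.0, 12.0], [40.0, 41.0]], dtype=np.float32)
gt_centroids = np.array([[11.0, 11.0], [80.0, 80.0]], dtype=np.float32)

# Match predicted against reference centroids within 6 pixels.
pairing, unpaired_pred, unpaired_gt = pair_coordinates(pred_centroids, gt_centroids, radius=6)
print(pairing, unpaired_pred, unpaired_gt)  # [[0 0]] [1] [1]

# Make instance ids contiguous in a toy label map: {2, 4} -> {1, 2}.
inst_map = np.array([[0, 2], [4, 4]])
print(remap_label(inst_map))                # [[0 1] [2 2]]
```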