├── README.md ├── configs.py ├── data.py ├── demo └── data │ └── url.txt ├── image └── README │ ├── Intro.jpg │ ├── SOI.jpg │ ├── biomarker.jpg │ ├── cell_state.jpg │ ├── cell_type.jpg │ ├── cell_type1.jpg │ ├── cell_type2.jpg │ ├── deconvolution1.jpg │ ├── deconvolution2.jpg │ ├── logo.png │ ├── segmentation.jpg │ └── tissue_compartment.jpg ├── infer.py ├── model ├── __pycache__ │ └── arch.cpython-37.pyc └── arch.py ├── requirements.txt ├── tcs ├── tissue_compartment_BRCA.json ├── tissue_compartment_CRC.json ├── tissue_compartment_HCC.json ├── tissue_compartment_HD.json ├── tissue_compartment_LUSC.json ├── tissue_compartment_Mix.json ├── tissue_compartment_OVC.json ├── tissue_compartment_PDAC.json ├── tissue_compartment_PRAD.json ├── tissue_compartment_RCC.json ├── tissue_compartment_STAD.json ├── tissue_compartment_Xenium.json └── tissue_compartment_state.json ├── train.py ├── train_oneout_pannuke.py ├── train_state.py ├── tutorial ├── data │ └── url.txt └── tutorial.ipynb └── utils ├── __pycache__ └── utils.cpython-37.pyc └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # HistoCell 2 | 3 | **HistoCell** is a **weakly-supervised deep learning framework** to elucidate the **hierarchical spatial cellular information** including **tissue compartments, single cell types and cell states** with **histopathology images only**. This tutorial implements HistoCell to predict super-resolution spatial cellular information and illustrates representative applications. The link to the HistoCell method will be presented soon. \ 4 | Our website: http://histocell.qhdyr.net/index/index/index.html 5 | 6 | Image 7 | 8 | ## Environments 9 | ```sh 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ## Data Format and Preprocessing 14 | 15 | ### Data Preparation 16 | 17 | * For model pre-training, HistoCell takes **scRNA-seq** and **spatial transcriptomics** data with paired **high-resolution histopathology images** as input. 18 | * For histopathology images, we cut the paired images according to the pixel coordinates from ST data. The preprocessing code can be found in **./tutorial/tutorial.ipynb**. 19 | * For transcriptomics data, we apply deconvolution methods to get the cell composition as the supervision. Applicable methods include [CARD](https://www.nature.com/articles/s41587-022-01273-7), [RCTD](https://www.nature.com/articles/s41587-021-00830-w), [Tangram](https://www.nature.com/articles/s41592-021-01264-7), [Cell2location](https://www.nature.com/articles/s41587-021-01139-4), etc. 20 | * For model inference, HistoCell requires only histopathology images, either tiles or WSIs. 21 | * WSIs should be formatted as .svs or .tif. To convert WSIs into tiles that can be processed in a high-throughput manner, we apply the toolbox from [CLAM](https://github.com/mahmoodlab/CLAM) for image segmentation, stitching and patching. For TCGA diagnostic images, patches of 256x256 pixels are recommended. 22 | * Tiles can be stored in any common image format. 23 | 24 | ### Data Preprocessing 25 | 26 | Using GPUs is highly recommended for whole slide image (WSI) processing. 27 | 28 | Before we begin the model inference process, cell segmentation is required.
Here, we apply [HoVerNet](https://www.sciencedirect.com/science/article/pii/S1361841519301045) [[code](https://github.com/vqdang/hover_net)], a well-developed nuclei instance segmentation method, to perform cell segmentation on image tiles. It is worth mentioning that other segmentation methods (e.g.
[MaskRCNN](https://arxiv.org/abs/1703.06870), [Cerberus](https://www.sciencedirect.com/science/article/pii/S1361841522003139), etc.) are also applicable here. 29 | 30 | With cell segmentation by HoVerNet, you will get a JSON file with the segmentation results, illustrated below: 31 | 32 | ```json 33 | {"mag": null, "nuc": {"YOUR_CELL_ID": {"bbox": [[1, 2], [3, 4]], "centroid": [2, 3], "contour": [[0, 1]], "type_prob": 0.9, "type": 0}}} 34 | ``` 35 | 36 | ### Input data format 37 | The tiled images, segmentation results (JSON files) and the cell type/state proportions given by deconvolution are required during the training process. 38 | 39 | ## Data Accession 40 | Demo data and tutorial data can be accessed on Google Drive via [Demo](https://drive.google.com/file/d/1DzwsFvD4wWKNN1wJCe3YsQZlk9xIp8DV/view?usp=sharing) and [Tutorial](https://drive.google.com/file/d/1iwXzxafdUF7SF6IAEyWluN728WbxAMEi/view?usp=sharing), respectively. To run the demo, you need to download the data to the directory **./demo/data**. 41 | 42 | ## Model pre-training with ST data 43 | Run **train.py** for model pre-training on cell type prediction. You need to change the parameters in **configs.py**, including the model information, dataset directory and training details. Besides, choose the proper tissue compartment file in **./tcs**. To develop your own model, run the command below.
44 | ```python 45 | python train.py \ 46 | --model YOUR_MODEL_DESC \ 47 | --tissue TISSUE_TYPE \ 48 | --deconv DECONVOLUTION_METHOD \ 49 | --prefix SAMPLE_PREFIX \ 50 | --k_class NUM_OF_CLASS \ 51 | --tissue_compartment PATH_TO_TCS_FILE 52 | ``` 53 | A simple demo on subject H1 from the HER2ST dataset: 54 | ```python 55 | python train.py \ 56 | --model Breast_Benchmark_H1 \ 57 | --tissue BRCA \ 58 | --deconv RCTD \ 59 | --prefix H1 \ 60 | --k_class 6 \ 61 | --tissue_compartment ./tcs/tissue_compartment_BRCA.json 62 | ``` 63 | 64 | For the benchmark analysis on the PanNuke dataset, run the leave-one-out cross-validation command: 65 | ```python 66 | python train_oneout_pannuke.py \ 67 | --model PanNuke_cross_validation \ 68 | --tissue Breast \ 69 | --reso 1 \ 70 | --deconv Mix \ 71 | --prefix fold1 fold2 fold3 \ 72 | --k_class 5 \ 73 | --tissue_compartment ./tcs/tissue_compartment_Mix.json 74 | ``` 75 | 76 | For cell states, run **train_state.py** to develop the cell state prediction model. For example, to develop cell state prediction on the breast cancer ST samples: 77 | ```python 78 | python train_state.py \ 79 | --model brca_state \ 80 | --tissue BRCA \ 81 | --deconv RCTD \ 82 | --prefix 10x_BRCA A_1938345_11 B_1938529_9 C_2000752_23 D_2000910_33 \ 83 | --tissue_compartment ./tcs/tissue_compartment_state.json 84 | ``` 85 | 86 | ## Quick inference with histopathology images only 87 | With the pretrained model, you can infer the cellular spatial profile with **infer.py**. 88 | ```python 89 | python infer.py \ 90 | --model Breast_Benchmark_H1 \ 91 | --epoch 30 \ 92 | --tissue BRCA \ 93 | --deconv RCTD \ 94 | --prefix H1_hires \ 95 | --k_class 6 \ 96 | --tissue_compartment ./tcs/tissue_compartment_BRCA.json \ 97 | --omit_gt 98 | ``` 99 | 100 | ## Representative Results 101 | Here we illustrate only the demo and representative results corresponding to the paper in tutorial.ipynb. The predicted hierarchical spatial cellular information is stored as a dict in a pickle file for each slide; a loading sketch is shown at the end of this README. For more results, you can directly jump to our [HistoCell website](http://histocell.qhdyr.net/index/index/index.html). 102 | 103 | ### Benchmark results 104 | 105 | * **Tissue Compartment** 106 |
107 | Image 108 |
109 | * **Single-cell Type** 110 |
111 | Image 112 |
113 |
114 | Image 115 |
116 | 117 | Red, blue and green scatter points represent cancer epithelial cells, stromal cells and macrophages, respectively. 118 | * **Cell States** 119 |
120 | Image 121 |
122 | 123 | ### Representative Application: Tissue architecture annotations 124 | 125 | Given a histopathology image, HistoCell first infers pixel-level cell types and then clusters the cells into tissue regions; the resulting annotations exhibit high accuracy and allow users to further identify small foci within tissue regions at pixel-level resolution. A minimal clustering sketch is shown below the figure. 126 |
127 | Image 128 |
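Below is a minimal sketch of this clustering step (an illustration under assumptions, not the exact pipeline from the paper): it loads the per-tile predictions saved by **infer.py** and groups cells into coarse regions with KMeans over their predicted type probabilities. The pickle path follows the demo inference command above, and the cluster count of 3 is an arbitrary example.

```python
# Illustrative sketch only: group cells into coarse tissue regions by
# clustering the per-cell type probabilities saved by infer.py.
# The pickle path and n_clusters are placeholder assumptions.
import pickle as pkl
import numpy as np
from sklearn.cluster import KMeans

with open('./train_log/BRCA/ckpts/Breast_Benchmark_H1/epoch_30_H1_hires_val.pkl', 'rb') as f:
    results = pkl.load(f)

# 'prob' holds one row of class probabilities per segmented cell in each tile
probs = np.concatenate([tile['prob'] for tile in results.values()])
regions = KMeans(n_clusters=3, random_state=0).fit_predict(probs)  # one region label per cell

for r in np.unique(regions):
    print(f"region {r}: {(regions == r).sum()} cells")
```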
129 | 130 | ### Representative Application: Cell Type Deconvolution 131 | 132 | Since HistoCell integrates spot-level cellular compositions deconvoluted from expression data with those derived from histologic morphology, it can produce more precise and robust deconvolution results. A toy integration sketch follows the figures below. 133 |
134 | Image 135 |
136 |
137 | Image 138 |
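The sketch below illustrates the integration idea in its simplest form (a toy weighted average, which is an assumption rather than HistoCell's actual scheme): spot-level proportions from an expression-based deconvolution tool are blended with image-based proportions predicted by HistoCell and renormalized. The two TSV file names are placeholders.

```python
# Toy sketch, not HistoCell's actual integration scheme: blend expression-based
# and image-based spot proportions with a fixed weight. File names are placeholders.
import pandas as pd

expr = pd.read_csv('rctd_proportions.tsv', sep='\t', index_col=0)        # spots x cell types
histo = pd.read_csv('histocell_proportions.tsv', sep='\t', index_col=0)  # spots x cell types

w = 0.5  # illustrative weight on the expression-based estimate
common = expr.index.intersection(histo.index)  # spots present in both tables
blend = w * expr.loc[common] + (1 - w) * histo.loc[common]
blend = blend.div(blend.sum(axis=1), axis=0)   # renormalize each spot to sum to 1
print(blend.head())
```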
139 | 140 | ### Representative Application: Spatial organization indicators identification 141 |
142 | Image 143 |
144 | The histopathology image is converted into a **spatial cellular map** with HistoCell, and the cells are aggregated into clusters. Through correlation analysis between clinical outcomes and cellular spatial clustering features, we identify spatial biomarkers for prognosis. Demo results can be found in tutorial.ipynb. The representative spatial features for prognosis stratification are visualized below. 145 |
146 | Image 147 |
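As referenced in the Representative Results section, each slide's predictions are saved as a pickled dict by **infer.py**. A minimal loading sketch follows (the path assumes the demo inference command above; the keys match the `all_results` dict built in **infer.py**):

```python
# Minimal sketch: inspect the per-slide prediction pickle written by infer.py.
# The path assumes the demo inference command above; adjust it to your own run.
import pickle as pkl

with open('./train_log/BRCA/ckpts/Breast_Benchmark_H1/epoch_30_H1_hires_val.pkl', 'rb') as f:
    results = pkl.load(f)

name, tile = next(iter(results.items()))  # keys are '{sample}_{tile}' names
print(name)
print(tile['cell_num'])          # number of segmented cells kept for the tile
print(tile['prob'].shape)        # (cell_num, k_class) per-cell type probabilities
print(tile['pred_proportion'])   # tile-level cell type proportions
print(tile['cell_coords'][:5])   # within-tile pixel centroids of the first cells
```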
148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | import json 2 | from yacs.config import CfgNode as CN 3 | 4 | def _get_config(tissue_type, deconv, subtype, k_class, tissue_dir): 5 | config = CN() 6 | config.train = CN() 7 | config.train.lr = 0.0005 8 | config.train.epoch = 41 9 | config.train.val_iter = 10 10 | config.train.val_min_iter = 9 11 | 12 | config.data = CN() 13 | config.data.deconv = deconv 14 | config.data.save_model = f'./train_log/{tissue_type}/models' # model saved 15 | config.data.ckpt = f'./train_log/{tissue_type}/ckpts' # eval results saved 16 | config.data.tile_dir = f'./demo/data/{tissue_type}/tiles' # path to tiles 17 | config.data.mask_dir = f'./demo/data/{tissue_type}/seg' # path to json segmentation file 18 | config.data.batch_size = 32 19 | config.data.tissue_dir = tissue_dir # tissue compartment directory 20 | config.data.max_cell_num = 256 # max cell number in a single tile for batch learning 21 | config.data.cell_dir = f'./demo/data/{tissue_type}/cell_proportion/type/{config.data.deconv}' # path to cell proportion label 22 | 23 | config.model = CN() 24 | config.model.tissue_class = 3 25 | config.model.pretrained = True 26 | config.model.channels = 3 27 | config.model.k_class = k_class 28 | 29 | return config 30 | 31 | def _get_cell_state_config(tissue_type, deconv, subtype, tissue_dir): 32 | config = CN() 33 | config.train = CN() 34 | config.train.lr = 0.0005 35 | config.train.epoch = 41 36 | config.train.val_iter = 5 37 | config.train.val_min_iter = 9 38 | config.train.state_epoch = 41 39 | 40 | config.data = CN() 41 | config.data.deconv = deconv 42 | config.data.save_model = f'./train_log/{tissue_type}/models' 43 | config.data.ckpt = f'./train_log/{tissue_type}/ckpts' 44 | config.data.tile_dir = f'./demo/data/{tissue_type}/tiles' 45 | config.data.mask_dir = f'./demo/data/{tissue_type}/seg' 46 | config.data.batch_size = 32 47 | config.data.tissue_dir = tissue_dir 48 | config.data.max_cell_num = 256 49 | with open(tissue_dir, 'r') as tissue_file: 50 | tc = json.load(tissue_file) 51 | 52 | config.data.cell_dir = f'./demo/data/{tissue_type}/cell_proportion/type/{config.data.deconv}' 53 | config.data.state_dir = f'./demo/data/{tissue_type}/cell_proportion/state/{config.data.deconv}' # path to cell state directory 54 | 55 | config.model = CN() 56 | config.model.tissue_class = len(tc['list']) 57 | config.model.pretrained = True 58 | config.model.channels = 3 59 | config.model.k_class = len(tc['dict']) 60 | 61 | k_state = 0 62 | for key, value in tc['state'].items(): 63 | k_state += int(value) 64 | config.model.k_state = k_state 65 | print(k_state) 66 | 67 | return config -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | import json 7 | import random 8 | from glob import glob 9 | from PIL import Image 10 | from utils.utils import load_json, load_image 11 | from torch.utils.data import Dataset 12 | from torchvision import transforms 13 | from tqdm import tqdm 14 | 15 | 16 | class TileBatchDataset(Dataset): 17 | def __init__(self, 18 | tile_dir, 19 | mask_dir, 20 | tissue, 21 | cell_dir=None, 22 | deconv='POLARIS', 23 | prefix=['*'], 24 | channels=3, 25 | aug=True, 26 | val=False, 27 | 
panuke=None, 28 | sample_ratio=1, 29 | max_cell_num=128, 30 | ext='jpg', 31 | focus_tissue='Breast', 32 | reso=1 33 | ) -> None: 34 | ### Batch Size should be 1 35 | self.tile_list = [] 36 | 37 | if aug is True: 38 | self.transforms = transforms.Compose( 39 | [ 40 | transforms.ColorJitter(brightness=0.1, contrast=0.3, saturation=0.1, hue=0.2), 41 | transforms.RandomGrayscale(p=0.1), 42 | transforms.ToTensor(), 43 | transforms.Normalize( 44 | mean=[0.485, 0.456, 0.406], 45 | std=[0.229, 0.224, 0.225] 46 | ) 47 | ] 48 | ) 49 | else: 50 | self.transforms = transforms.Compose( 51 | [ 52 | transforms.ToTensor(), 53 | transforms.Normalize( 54 | mean=[0.485, 0.456, 0.406], 55 | std=[0.229, 0.224, 0.225] 56 | ) 57 | ] 58 | ) 59 | self.channels = channels 60 | 61 | self.mask_dir = mask_dir 62 | self.cell_dir = cell_dir 63 | self.val = val 64 | self.panuke = panuke 65 | self.mcn = max_cell_num 66 | self.ext = ext 67 | self.reso = reso 68 | 69 | with open(tissue, 'r') as tissue_file: 70 | self.tc = json.load(tissue_file) 71 | print(f"Tissue Compartment: {self.tc['list']}") 72 | if cell_dir is not None: 73 | cell_labels = {} 74 | 75 | if deconv in ['Tangram', 'RCTD', 'stereoscope', 'CARD', 'POLARIS', 'Xenium', 'CARD_fix', 'RCTD_fix']: 76 | for cell_prop in glob(os.path.join(cell_dir, "*.tsv")): 77 | sample_name = cell_prop.split('/')[-1].split('.')[0] 78 | if sample_name not in prefix: 79 | continue 80 | cell_df = pd.read_csv(cell_prop, sep='\t', index_col=0) 81 | for cell_index, row in tqdm(cell_df.iterrows()): 82 | cell_index = str(cell_index) 83 | x, y, barcode = cell_index.split('_') 84 | cell_num = np.array(row) 85 | cell_propotion = cell_num / np.sum(cell_num) 86 | abs_path = glob(os.path.join(tile_dir, sample_name, f'*_{x}x{y}.jpg')) 87 | if len(abs_path) <= 0: 88 | continue 89 | img_name = abs_path[0].split('/')[-1].split('.')[0] 90 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json") 91 | cell_labels.update( 92 | { 93 | sample_name + '_' + img_name: cell_propotion 94 | } 95 | ) 96 | if not os.path.exists(json_file): 97 | continue 98 | self.tile_list.append(abs_path[0]) 99 | 100 | elif deconv in ['Mix']: 101 | for cell_prop in glob(os.path.join(cell_dir, "*.tsv")): 102 | sample_name = cell_prop.split('/')[-1].replace('.tsv', '') 103 | if sample_name not in prefix: 104 | continue 105 | cell_df = pd.read_csv(cell_prop, sep='\t', index_col=0) 106 | for cell_index, row in cell_df.iterrows(): 107 | # cell_index = "%06d" % int(cell_index) 108 | cell_index = cell_index.split('-')[0] 109 | abs_path = glob(os.path.join(tile_dir, sample_name, f'{cell_index}-{focus_tissue}.{ext}')) 110 | if len(abs_path) == 0: 111 | continue 112 | tissue_index = abs_path[0].split('/')[-1].split('.')[0] 113 | cell_propotion = np.array(row) 114 | cell_propotion = cell_propotion / np.sum(cell_propotion) 115 | cell_labels.update( 116 | { 117 | sample_name + '_' + tissue_index: cell_propotion 118 | } 119 | ) 120 | img_name = abs_path[0].split('/')[-1].replace(f'.{ext}', '') 121 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json") 122 | if not os.path.exists(json_file): 123 | continue 124 | self.tile_list.append(abs_path[0]) 125 | 126 | else: 127 | raise NotImplementedError("The Deconvolution Method is not supported.") 128 | 129 | self.cell_labels = cell_labels 130 | 131 | if not val: 132 | print(self.tc['dict']) 133 | print(cell_df.keys()) 134 | assert len(self.tc['dict'].keys()) == len(cell_df.keys()) 135 | print(f"Cell Category: {list(cell_df.keys())}") 136 | else: 137 | 
for dir_prefix in prefix: 138 | path_list = [] 139 | if panuke is None: 140 | panuke = '*' 141 | else: 142 | panuke = f'*-{panuke}' 143 | for abs_path in glob(os.path.join(tile_dir, f"{dir_prefix}/{panuke}.{ext}")): 144 | # for abs_path in glob(os.path.join(tile_dir, f"tumor_003/*.{ext}")): 145 | img_name = abs_path.split('/')[-1].replace(f'.{ext}', '') 146 | sample_name = abs_path.split('/')[-2] 147 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json") 148 | if os.path.exists(json_file): 149 | path_list.append(abs_path) 150 | self.tile_list.extend(path_list) 151 | 152 | self.tile_list = random.sample(self.tile_list, int(len(self.tile_list) * sample_ratio)) 153 | 154 | def __len__(self) -> int: 155 | return len(self.tile_list) 156 | 157 | def __getitem__(self, index: int): 158 | # load image 159 | img_path = self.tile_list[index] 160 | img_name = img_path.split('/')[-1].split('.')[0] 161 | sample_name = img_path.split('/')[-2] 162 | with open(img_path, 'rb') as fp: 163 | pic = Image.open(fp).convert('RGB').resize((256 // self.reso, 256 // self.reso)) 164 | pic = pic.resize((256, 256)) 165 | image = self.transforms(pic) 166 | 167 | # load mask 168 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json") 169 | nucs = load_json(json_file) 170 | mask_list, size_list, pos_list, type_list = [], [], [], [] 171 | for index, nuc in nucs.items(): 172 | if len(image.shape) < 3: 173 | image = np.expand_dims(image, axis=0) 174 | # mask = np.zeros_like(image) 175 | center_x, center_y = nuc['centroid'] 176 | bbox_xa, bbox_ya = nuc['bbox'][0] 177 | bbox_xb, bbox_yb = nuc['bbox'][1] 178 | mask = np.zeros([3, bbox_xb - bbox_xa, bbox_yb - bbox_ya]) 179 | mask = image[:, bbox_xa: bbox_xb, bbox_ya: bbox_yb] 180 | try: 181 | mask = torch.tensor(cv2.resize(mask.permute(1, 2, 0).numpy(), (30, 30))) 182 | except: 183 | continue 184 | 185 | mask_list.append(mask.permute(2, 0, 1)), size_list.append(np.array([(bbox_xb - bbox_xa) / image.shape[-1], (bbox_yb - bbox_ya) / image.shape[-1]])) 186 | pos_list.append(np.array([center_x, center_y])) 187 | if nuc['type'] is None: 188 | type_list.append(-1) 189 | else: 190 | type_list.append(int(self.tc['HoVerNet'][int(nuc['type'])])) 191 | 192 | cell_types = np.array(type_list) 193 | image = torch.tensor(cv2.resize(image.permute(1, 2, 0).numpy(), (128, 128))).permute(2, 0, 1) 194 | cell_num = len(mask_list) 195 | if cell_num == 1: 196 | dist_mat = np.zeros([cell_num, cell_num]) 197 | cell_coords = np.array([pos_list]) 198 | elif cell_num == 0: 199 | dist_mat = np.zeros([cell_num, cell_num]) 200 | cell_coords = np.zeros([0, 2]) 201 | else: 202 | cell_coords = np.stack(pos_list, axis=0) 203 | dist_mat = np.zeros([cell_num, cell_num]) 204 | for i in range(cell_num): 205 | for j in range(i + 1, cell_num): 206 | dist = np.linalg.norm((cell_coords[i] - cell_coords[j]), ord=2) 207 | if dist < 40: 208 | dist_mat[i, j] = 1 209 | 210 | dist_mat = dist_mat + dist_mat.T + np.identity(cell_num) 211 | 212 | adj_mat = np.zeros([self.mcn, self.mcn]) 213 | cell_pixels = np.zeros([self.mcn, 2]) 214 | CellType = -np.ones([self.mcn]) 215 | valid_mask = len(mask_list) if len(mask_list) < self.mcn else self.mcn # max cell num 216 | if valid_mask >= self.mcn: 217 | mask_list = mask_list[:self.mcn] 218 | size_list = size_list[:self.mcn] 219 | adj_mat = dist_mat[:self.mcn, :self.mcn] 220 | cell_pixels = cell_coords[:self.mcn] 221 | CellType = cell_types[:self.mcn] 222 | else: 223 | for _ in range(self.mcn - valid_mask): 224 | 
mask_list.append(torch.zeros((3, 30, 30))) 225 | size_list.append(torch.zeros((2))) 226 | 227 | adj_mat = np.pad(dist_mat, ((0, self.mcn - valid_mask), (0, self.mcn - valid_mask))) 228 | cell_pixels[:valid_mask] = cell_coords 229 | CellType[:valid_mask] = cell_types 230 | 231 | cell_images = np.stack(mask_list, axis=0) 232 | cell_sizes = np.stack(size_list, axis=0) 233 | 234 | if self.cell_dir is None: 235 | return { 236 | 'name': f'{sample_name}_{img_name}', 237 | 'tissue': image, # B 3 256 256 238 | 'image': torch.tensor(cell_images, dtype=torch.float32), # B 128 3 128 128 239 | 'mask': int(valid_mask), # B 240 | 'size': torch.tensor(cell_sizes, dtype=torch.float32), 241 | 'adj': torch.tensor(adj_mat, dtype=torch.long), 242 | 'cell_coords': torch.tensor(cell_pixels, dtype=torch.float32), 243 | 'cell_types': torch.tensor(CellType, dtype=torch.float32) 244 | } 245 | # Add tissue compartment 246 | if self.panuke is not None: 247 | sample_keys = sample_name + '_' + img_name.split('-')[0] + f'-{self.panuke}' # Only for Panuke 248 | cell_prop = self.cell_labels[sample_keys] 249 | tissue_cat = -1 250 | 251 | else: 252 | sample_keys = sample_name + '_' + img_name 253 | cell_prop = self.cell_labels[sample_keys] 254 | # tc = np.zeros(len(self.tc['list'])) 255 | # for cata_i, prop in enumerate(cell_prop): 256 | # tc[self.tc['list'].index(self.tc['dict'][str(cata_i)])] += prop 257 | if np.max(cell_prop) > 0.75: 258 | tissue_cat = self.tc['list'].index(self.tc['dict'][str(np.argmax(cell_prop))]) 259 | else: 260 | tissue_cat = len(self.tc['list']) 261 | 262 | return { 263 | 'name': f'{sample_name}_{img_name}', 264 | 'tissue': image, # B 3 128 128 265 | 'image': torch.tensor(cell_images, dtype=torch.float32), # B 128 3 128 128 266 | 'mask': int(valid_mask), # B 267 | 'cells': torch.tensor(cell_prop, dtype=torch.float32), # B cell_type 268 | 'size': torch.tensor(cell_sizes, dtype=torch.float32), 269 | 'adj': torch.tensor(adj_mat, dtype=torch.long), 270 | 'tissue_cat': torch.tensor(tissue_cat, dtype=torch.long), # B 1 271 | 'cell_coords': torch.tensor(cell_pixels, dtype=torch.float32), 272 | 'cell_types': torch.tensor(CellType, dtype=torch.float32) 273 | } 274 | 275 | 276 | class TileBatchStateDataset(Dataset): 277 | def __init__(self, tile_dir, mask_dir, tissue, cell_dir=None, state_dir=None, deconv='POLARIS', prefix=['*'], channels=3, aug=True, val=False, panuke=False, sample_ratio=1, max_cell_num=128, ext='jpg') -> None: 278 | ### Batch Size should be 1 279 | self.tile_list = [] 280 | 281 | if aug is True: 282 | self.transforms = transforms.Compose( 283 | [ 284 | transforms.ColorJitter(brightness=0.1, contrast=0.3, saturation=0.1, hue=0.2), 285 | transforms.RandomGrayscale(p=0.1), 286 | transforms.ToTensor(), 287 | transforms.Normalize( 288 | mean=[0.485, 0.456, 0.406], 289 | std=[0.229, 0.224, 0.225] 290 | ) 291 | ] 292 | ) 293 | else: 294 | self.transforms = transforms.Compose( 295 | [ 296 | transforms.ToTensor(), 297 | transforms.Normalize( 298 | mean=[0.485, 0.456, 0.406], 299 | std=[0.229, 0.224, 0.225] 300 | ) 301 | ] 302 | ) 303 | self.channels = channels 304 | 305 | self.mask_dir = mask_dir 306 | self.cell_dir = cell_dir 307 | self.state_dir = state_dir 308 | self.val = val 309 | self.panuke = panuke 310 | self.mcn = max_cell_num 311 | self.ext = ext 312 | 313 | with open(tissue, 'r') as tissue_file: 314 | self.tc = json.load(tissue_file) 315 | print(f"Tissue Compartment: {self.tc['list']}") 316 | 317 | if cell_dir is not None: 318 | cell_labels = {} 319 | state_labels = {} 320 | 321 | if 
deconv in ['Tangram', 'RCTD', 'stereoscope', 'CARD', 'POLARIS', 'Xenium']: 322 | for cell_prop in glob(os.path.join(cell_dir, "*.tsv")): 323 | sample_name = cell_prop.split('/')[-1].split('.')[0] 324 | if sample_name not in prefix: 325 | continue 326 | cell_df = pd.read_csv(cell_prop, sep='\t', index_col=0) 327 | state_df = pd.read_csv(os.path.join(state_dir, f"{sample_name}.tsv"), sep='\t', index_col=0) 328 | for cell_index, state_row in state_df.iterrows(): 329 | cell_index = str(cell_index) 330 | x, y, barcode = cell_index.split('_') 331 | 332 | # cell prop 333 | row = cell_df.loc[cell_index] 334 | cell_num = np.array(row) 335 | cell_propotion = cell_num / np.sum(cell_num) 336 | cell_labels.update( 337 | { 338 | sample_name + '_' + f'{x}x{y}': cell_propotion 339 | } 340 | ) 341 | 342 | # state prop 343 | state_num = np.array(state_row) 344 | state_proportion = state_num / np.sum(state_num) 345 | state_labels.update( 346 | { 347 | sample_name + '_' + f'{x}x{y}': state_proportion 348 | } 349 | ) 350 | 351 | abs_path = glob(os.path.join(tile_dir, sample_name, f'*_{x}x{y}.{ext}')) 352 | if len(abs_path) <= 0: 353 | continue 354 | img_name = abs_path[0].split('/')[-1].split('.')[0] 355 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json") 356 | if not os.path.exists(json_file): 357 | continue 358 | self.tile_list.append(abs_path[0]) 359 | elif deconv in ['Mix']: 360 | for cell_prop in glob(os.path.join(cell_dir, "*.tsv")): 361 | sample_name = cell_prop.split('/')[-1].split('.')[0] 362 | if sample_name not in prefix: 363 | continue 364 | cell_df = pd.read_csv(cell_prop, sep='\t', index_col=0) 365 | for cell_index, row in cell_df.iterrows(): 366 | cell_index = "%06d" % int(cell_index) 367 | abs_path = glob(os.path.join(tile_dir, sample_name, f'{cell_index}-Prostate.{ext}')) 368 | if len(abs_path) == 0: 369 | continue 370 | cell_propotion = np.array(row) 371 | cell_labels.update( 372 | { 373 | sample_name + '_' + cell_index + '-Prostate': cell_propotion 374 | } 375 | ) 376 | img_name = abs_path[0].split('/')[-1].split('.')[0] 377 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json") 378 | if not os.path.exists(json_file): 379 | continue 380 | self.tile_list.append(abs_path[0]) 381 | 382 | else: 383 | raise NotImplementedError("The Deconvolution Method is not supported.") 384 | 385 | self.cell_labels = cell_labels 386 | self.state_labels = state_labels 387 | 388 | if not val: 389 | print(self.tc['dict']) 390 | print(cell_df.keys()) 391 | assert len(self.tc['dict'].keys()) == len(cell_df.keys()) 392 | print(f"Cell Category: {list(cell_df.keys())}") 393 | else: 394 | for dir_prefix in prefix: 395 | path_list = [] 396 | print(os.path.join(tile_dir, f"{dir_prefix}/*.{ext}")) 397 | for abs_path in glob(os.path.join(tile_dir, f"{dir_prefix}/*.{ext}")): 398 | # img_name = abs_path.split('/')[-1].split('.')[0] 399 | img_name = abs_path.split('/')[-1].strip(f'.{ext}') 400 | sample_name = abs_path.split('/')[-2] 401 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json") 402 | # print(json_file) 403 | if os.path.exists(json_file): 404 | path_list.append(abs_path) 405 | self.tile_list.extend(path_list) 406 | 407 | self.tile_list = random.sample(self.tile_list, int(len(self.tile_list) * sample_ratio)) 408 | 409 | def __len__(self) -> int: 410 | return len(self.tile_list) 411 | 412 | def __getitem__(self, index: int): 413 | # load image 414 | img_path = self.tile_list[index] 415 | img_name = 
img_path.split('/')[-1].strip(f'.{self.ext}') 416 | sample_name = img_path.split('/')[-2] 417 | with open(img_path, 'rb') as fp: 418 | pic = Image.open(fp) 419 | image = self.transforms(pic) 420 | 421 | # load mask 422 | json_file = os.path.join(self.mask_dir, sample_name, f"json/{img_name}.json") 423 | nucs = load_json(json_file) 424 | mask_list, size_list, pos_list, type_list = [], [], [], [] 425 | for index, nuc in nucs.items(): 426 | if len(image.shape) < 3: 427 | image = np.expand_dims(image, axis=0) 428 | # mask = np.zeros_like(image) 429 | center_x, center_y = nuc['centroid'] 430 | bbox_xa, bbox_ya = nuc['bbox'][0] 431 | bbox_xb, bbox_yb = nuc['bbox'][1] 432 | mask = np.zeros([3, bbox_xb - bbox_xa, bbox_yb - bbox_ya]) 433 | mask = image[:, bbox_xa: bbox_xb, bbox_ya: bbox_yb] 434 | try: 435 | mask = torch.tensor(cv2.resize(mask.permute(1, 2, 0).numpy(), (30, 30))) 436 | except: 437 | continue 438 | 439 | mask_list.append(mask.permute(2, 0, 1)), size_list.append(np.array([(bbox_xb - bbox_xa) / image.shape[-1], (bbox_yb - bbox_ya) / image.shape[-1]])) 440 | pos_list.append(np.array([center_x, center_y])) 441 | type_list.append(int(self.tc['HoVerNet'][int(nuc['type'])])) 442 | 443 | cell_types = np.array(type_list) 444 | image = torch.tensor(cv2.resize(image.permute(1, 2, 0).numpy(), (256, 256))).permute(2, 0, 1) 445 | cell_num = len(mask_list) 446 | if cell_num <= 1: 447 | dist_mat = np.zeros([cell_num, cell_num]) 448 | cell_coords = np.array(pos_list) 449 | else: 450 | cell_coords = np.stack(pos_list, axis=0) 451 | dist_mat = np.zeros([cell_num, cell_num]) 452 | for i in range(cell_num): 453 | for j in range(i + 1, cell_num): 454 | dist = np.linalg.norm((cell_coords[i] - cell_coords[j]), ord=2) 455 | if dist < 40: 456 | dist_mat[i, j] = 1 457 | 458 | dist_mat = dist_mat + dist_mat.T + np.identity(cell_num) 459 | 460 | adj_mat = np.zeros([self.mcn, self.mcn]) 461 | cell_pixels = np.zeros([self.mcn, 2]) 462 | CellType = np.zeros([self.mcn]) 463 | valid_mask = len(mask_list) if len(mask_list) < self.mcn else self.mcn # max cell num 464 | if valid_mask >= self.mcn: 465 | mask_list = mask_list[:self.mcn] 466 | size_list = size_list[:self.mcn] 467 | adj_mat = dist_mat[:self.mcn, :self.mcn] 468 | cell_pixels = cell_coords[:self.mcn, :self.mcn] 469 | CellType = cell_types[:self.mcn] 470 | else: 471 | for _ in range(self.mcn - valid_mask): 472 | mask_list.append(torch.zeros((3, 30, 30))) 473 | size_list.append(torch.zeros((2))) 474 | 475 | adj_mat = np.pad(dist_mat, ((0, self.mcn - valid_mask), (0, self.mcn - valid_mask))) 476 | for i in range(valid_mask): 477 | cell_pixels[i] = cell_coords[i] 478 | CellType[:valid_mask] = cell_types 479 | 480 | cell_images = np.stack(mask_list, axis=0) 481 | cell_sizes = np.stack(size_list, axis=0) 482 | 483 | if self.cell_dir is None: 484 | return { 485 | 'name': f'{sample_name}_{img_name}', 486 | 'tissue': image, # B 3 128 128 487 | 'image': torch.tensor(cell_images, dtype=torch.float32), # B 128 3 128 128 488 | 'mask': int(valid_mask), # B 489 | 'size': torch.tensor(cell_sizes, dtype=torch.float32), 490 | 'adj': torch.tensor(adj_mat, dtype=torch.long), 491 | 'cell_coords': torch.tensor(cell_pixels, dtype=torch.float32), 492 | 'cell_types': torch.tensor(CellType, dtype=torch.float32) 493 | } 494 | # Add tissue compartment 495 | if self.panuke: 496 | sample_keys = sample_name + '_' + img_name.split('-')[0] + '-Prostate' # Only for Panuke 497 | cell_prop = self.cell_labels[sample_keys] 498 | tissue_cat = -1 499 | 500 | else: 501 | sample_keys = sample_name + 
'_' + img_name.split('_')[-1] 502 | cell_prop = self.cell_labels[sample_keys] 503 | # tc = np.zeros(len(self.tc['list'])) 504 | # for cata_i, prop in enumerate(cell_prop): 505 | # tc[self.tc['list'].index(self.tc['dict'][str(cata_i)])] += prop 506 | if np.max(cell_prop) > 0.75: 507 | tissue_cat = self.tc['list'].index(self.tc['dict'][str(np.argmax(cell_prop))]) 508 | else: 509 | tissue_cat = len(self.tc['list']) 510 | 511 | state_prop = self.state_labels[sample_keys] 512 | 513 | return { 514 | 'name': f'{sample_name}_{img_name}', 515 | 'tissue': image, # B 3 128 128 516 | 'image': torch.tensor(cell_images, dtype=torch.float32), # B 128 3 128 128 517 | 'mask': int(valid_mask), # B 518 | 'cells': torch.tensor(cell_prop, dtype=torch.float32), # B cell_type 519 | 'states': torch.tensor(state_prop, dtype=torch.float32), 520 | 'size': torch.tensor(cell_sizes, dtype=torch.float32), 521 | 'adj': torch.tensor(adj_mat, dtype=torch.long), 522 | 'tissue_cat': torch.tensor(tissue_cat, dtype=torch.long), # B 1 523 | 'cell_coords': torch.tensor(cell_pixels, dtype=torch.float32), 524 | 'cell_types': torch.tensor(CellType, dtype=torch.float32) 525 | } -------------------------------------------------------------------------------- /demo/data/url.txt: -------------------------------------------------------------------------------- 1 | Download from url: 2 | https://drive.google.com/file/d/1DzwsFvD4wWKNN1wJCe3YsQZlk9xIp8DV/view?usp=sharing -------------------------------------------------------------------------------- /image/README/Intro.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/Intro.jpg -------------------------------------------------------------------------------- /image/README/SOI.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/SOI.jpg -------------------------------------------------------------------------------- /image/README/biomarker.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/biomarker.jpg -------------------------------------------------------------------------------- /image/README/cell_state.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/cell_state.jpg -------------------------------------------------------------------------------- /image/README/cell_type.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/cell_type.jpg -------------------------------------------------------------------------------- /image/README/cell_type1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/cell_type1.jpg -------------------------------------------------------------------------------- /image/README/cell_type2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/cell_type2.jpg -------------------------------------------------------------------------------- /image/README/deconvolution1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/deconvolution1.jpg -------------------------------------------------------------------------------- /image/README/deconvolution2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/deconvolution2.jpg -------------------------------------------------------------------------------- /image/README/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/logo.png -------------------------------------------------------------------------------- /image/README/segmentation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/segmentation.jpg -------------------------------------------------------------------------------- /image/README/tissue_compartment.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/image/README/tissue_compartment.jpg -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | import pickle as pkl 5 | from model.arch import HistoCell 6 | from data import TileBatchDataset 7 | from utils.utils import * 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | from yacs.config import CfgNode as CN 11 | 12 | def _get_config(tissue_type, deconv, subtype, k_class, tissue_dir): 13 | config = CN() 14 | 15 | config.data = CN() 16 | config.data.deconv = deconv 17 | config.data.save_model = f'./train_log/{tissue_type}/models' 18 | config.data.ckpt = f'./train_log/{tissue_type}/ckpts' 19 | config.data.tile_dir = f'./demo/data/{tissue_type}/tiles' 20 | config.data.mask_dir = f'./demo/data/{tissue_type}/seg' 21 | config.data.batch_size = 16 22 | config.data.tissue_dir = tissue_dir 23 | config.data.max_cell_num = 256 24 | 25 | config.model = CN() 26 | config.model.tissue_class = 3 27 | config.model.pretrained = True 28 | config.model.channels = 3 29 | 30 | config.data.cell_dir = '' 31 | config.model.k_class = k_class 32 | 33 | return config 34 | 35 | def main(): 36 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 37 | parser = argparse.ArgumentParser(description="Prediction with Spots Images") 38 | parser.add_argument('--seed', default=47, type=int) 39 | parser.add_argument('--model', default='Baseline', type=str, help='model description') 40 | parser.add_argument('--epoch', default=-1, type=int) 41 | parser.add_argument('--tissue', default='BRCA', type=str) 42 | parser.add_argument('--deconv', default='CARD', type=str) 43 | parser.add_argument('--subtype', action='store_true') 44 | parser.add_argument('--prefix', required=False, nargs = '+') 45 | 
parser.add_argument('--k_class', default=8, type=int) 46 | parser.add_argument('--tissue_compartment', type=str, required=True) 47 | parser.add_argument('--omit_gt', action='store_true') 48 | parser.add_argument('--val_panuke', default=None, type=str) 49 | 50 | args = parser.parse_args() 51 | config = _get_config(args.tissue, args.deconv, args.subtype, args.k_class, args.tissue_compartment) 52 | 53 | print(f"Model details: {args.model}") 54 | 55 | print("Load Dataset...") 56 | 57 | data_prefix = args.prefix if isinstance(args.prefix, list) else [args.prefix] 58 | 59 | print(data_prefix) 60 | 61 | val_data = TileBatchDataset(config.data.tile_dir, config.data.mask_dir, config.data.tissue_dir, cell_dir=None, 62 | deconv=config.data.deconv, prefix=data_prefix, 63 | aug=False, val=args.omit_gt, panuke=args.val_panuke, max_cell_num=config.data.max_cell_num) 64 | print(f"length of train data: {len(val_data)}") 65 | # val_data = TileMaskDataset(config.data.tile_dir, config.data.mask_dir, config.data.cell_dir) 66 | val_loader = DataLoader(val_data, batch_size=config.data.batch_size, num_workers=8, pin_memory=True) 67 | # val_loader = DataLoader(val_data, batch_size=1, collate_fn=collate_fn) 68 | 69 | print("Load Model...") 70 | model = HistoCell(config.model) 71 | 72 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 73 | if torch.cuda.is_available(): 74 | print("Using 1 gpu") 75 | else: 76 | print("Using cpu") 77 | 78 | 79 | model_dir, ckpt_dir = \ 80 | os.path.join(config.data.save_model, args.model), os.path.join(config.data.ckpt, args.model) 81 | os.makedirs(model_dir, exist_ok=True), os.makedirs(ckpt_dir, exist_ok=True) 82 | 83 | ckpt_file = find_ckpt(model_dir) 84 | 85 | if args.epoch >= 0: 86 | state_dict, current_epoch = load_ckpt(os.path.join(model_dir, f"epoch_{args.epoch}.ckpt")) 87 | model.load_state_dict(state_dict) 88 | 89 | # ipdb.set_trace() 90 | elif ckpt_file is not None: 91 | state_dict, current_epoch = load_ckpt(os.path.join(model_dir, ckpt_file)) 92 | model.load_state_dict(state_dict) 93 | 94 | else: 95 | raise FileNotFoundError("No trained model exits") 96 | 97 | model.to(device) 98 | 99 | print(f"Current Epoch: {current_epoch}") 100 | 101 | val_results = val_loop(current_epoch, model, val_loader, device) 102 | with open(os.path.join(ckpt_dir, f'epoch_{current_epoch}_{data_prefix[0]}_val.pkl'), 'wb') as file: 103 | pkl.dump(val_results, file) 104 | 105 | print("Val Finished!") 106 | print(f"Results saved in {os.path.join(config.data.ckpt, args.model)}") 107 | 108 | 109 | def val_loop(epoch, model, val_loader, device): 110 | model.train() 111 | val_bar = tqdm(val_loader, desc="epoch " + str(epoch), total=len(val_loader), 112 | unit="batch", dynamic_ncols=True) 113 | 114 | all_results = {} 115 | for idx, data in enumerate(val_bar): 116 | tissue = data['tissue'].to(torch.float32).to(device) 117 | images = data['image'].to(torch.float32).to(device) 118 | cell_size = data['size'].to(torch.float32).to(device) 119 | adj_mat = data['adj'].to(torch.float32).to(device) 120 | valid_mask = data['mask'].to(torch.long).to(device) 121 | cell_coords = data['cell_coords'].to(torch.float32).to(device) 122 | ref_type = data['cell_types'].to(torch.long).to(device) 123 | if torch.sum(data['mask']) <= 0: 124 | continue 125 | batch, cells, channels, height, width = images.shape 126 | images = images.reshape(batch * cells, channels, height, width) 127 | prob_list, pred_proportion, tissue_cat, cell_embeddings = model(tissue, images, adj_mat, cell_size, valid_mask, {'batch': batch, 
'cells': cells}) 128 | name_list, valid_list, coord_list, type_list = [], [], [], [] 129 | for name, valid_index, valid_coord, valid_type in zip(data['name'], valid_mask, cell_coords, ref_type): 130 | if valid_index <= 0: 131 | continue 132 | name_list.append(name) 133 | valid_list.append(valid_index) 134 | coord_list.append(valid_coord) 135 | type_list.append(valid_type) 136 | 137 | for name, pb, pp, vix, cc, ct, cell_features in zip(name_list, prob_list, pred_proportion, valid_list, coord_list, type_list, cell_embeddings): 138 | valid_num = int(vix.detach().cpu()) 139 | all_results.update( 140 | { 141 | name: { 142 | 'prob': pb.detach().cpu().numpy(), 143 | 'pred_proportion': pp.detach().cpu().numpy(), 144 | 'prior_type': ct.cpu().numpy(), 145 | 'cell_num': valid_num, 146 | 'cell_coords': cc[:valid_num].detach().cpu().numpy(), 147 | # 'cell_embedding': cell_features.detach().cpu().numpy() 148 | } 149 | } 150 | ) 151 | 152 | val_bar.set_description('epoch:{} iter:{}'.format(epoch, idx)) 153 | 154 | return all_results 155 | 156 | 157 | if __name__ == '__main__': 158 | main() -------------------------------------------------------------------------------- /model/__pycache__/arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/model/__pycache__/arch.cpython-37.pyc -------------------------------------------------------------------------------- /model/arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision.models import resnet18 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | 7 | 8 | class GraphAttentionLayer(nn.Module): 9 | """ 10 | Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 11 | """ 12 | def __init__(self, in_features, out_features, dropout, alpha, concat=True): 13 | super(GraphAttentionLayer, self).__init__() 14 | self.dropout = dropout 15 | self.in_features = in_features 16 | self.out_features = out_features 17 | self.alpha = alpha 18 | self.concat = concat 19 | 20 | self.W = nn.Parameter(torch.empty(size=(in_features, out_features))) 21 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 22 | self.a = nn.Parameter(torch.empty(size=(2*out_features, 1))) 23 | nn.init.xavier_uniform_(self.a.data, gain=1.414) 24 | 25 | self.leakyrelu = nn.LeakyReLU(self.alpha) 26 | 27 | def forward(self, h, adj): 28 | Wh = torch.matmul(h, self.W) # h.shape: (B, N, in_features), Wh.shape: (N, out_features) 29 | e = self._prepare_attentional_mechanism_input(Wh) 30 | 31 | zero_vec = -9e15*torch.ones_like(e) # B N N 32 | attention = torch.where(adj > 0, e, zero_vec) 33 | attention = F.softmax(attention, dim=1) 34 | attention = F.dropout(attention, self.dropout, training=self.training) 35 | h_prime = torch.matmul(attention, Wh) 36 | 37 | if self.concat: 38 | return F.elu(h_prime) 39 | else: 40 | return h_prime 41 | 42 | def _prepare_attentional_mechanism_input(self, Wh): 43 | # Wh.shape (N, out_feature) 44 | # self.a.shape (2 * out_feature, 1) 45 | # Wh1&2.shape (N, 1) 46 | # e.shape (N, N) 47 | Wh1 = torch.matmul(Wh, self.a[:self.out_features, :]) 48 | Wh2 = torch.matmul(Wh, self.a[self.out_features:, :]) 49 | # broadcast add 50 | e = Wh1 + Wh2.transpose(-1, -2) 51 | return self.leakyrelu(e) 52 | 53 | def __repr__(self): 54 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 55 | 56 
| 57 | class GAT(nn.Module): 58 | def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads): 59 | """Dense version of GAT.""" 60 | super(GAT, self).__init__() 61 | self.dropout = dropout 62 | 63 | self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)] 64 | for i, attention in enumerate(self.attentions): 65 | self.add_module('attention_{}'.format(i), attention) 66 | 67 | self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False) 68 | 69 | def forward(self, x, adj): 70 | x = F.dropout(x, self.dropout, training=self.training) 71 | x = torch.cat([att(x, adj) for att in self.attentions], dim=-1) 72 | x = F.dropout(x, self.dropout, training=self.training) 73 | x = F.elu(self.out_att(x, adj)) 74 | return F.log_softmax(x, dim=1) 75 | 76 | 77 | class HistoCell(nn.Module): 78 | def __init__(self, config) -> None: 79 | super(HistoCell, self).__init__() 80 | resnet = resnet18(pretrained=config.pretrained) 81 | 82 | if config.channels == 1: 83 | modules = [nn.Conv2d(config.channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)] + list(resnet.children())[1:-1] 84 | self.resnet = nn.Sequential(*modules) 85 | 86 | elif config.channels == 3: 87 | self.resnet = nn.Sequential(*list(resnet.children())[:-1]) 88 | 89 | self.gat = GraphAttentionLayer(512, 512, dropout=0.5, alpha=0.2, concat=False) 90 | 91 | self.size_embed = nn.Sequential( 92 | nn.Linear(2, 16), 93 | nn.ReLU() 94 | ) 95 | self.merge = nn.Linear(512 + 16, 512) 96 | 97 | self.out = nn.LSTM(input_size=512, hidden_size=512, num_layers=1, batch_first=True) 98 | self.predict = nn.Sequential( 99 | nn.Linear(512, config.k_class), 100 | nn.Softmax(dim=-1) 101 | ) 102 | self.tc = nn.Linear(512, config.tissue_class + 1) 103 | 104 | def forward(self, tissue, bag, adj, cell_size, valid_mask, raw_size): # cell number * 3 * 224 * 224 105 | mask_feats = self.resnet(bag).squeeze().reshape(raw_size['batch'], raw_size['cells'], -1) 106 | size_feats = self.size_embed(cell_size) # B C 16 107 | mask_feats = self.merge(torch.concat([mask_feats, size_feats], dim=-1)) # B C 512 108 | # import ipdb 109 | # ipdb.set_trace() 110 | dmask_feats = nn.functional.dropout(mask_feats, p=0.5, training=True) 111 | graph_feats = self.gat(dmask_feats, adj) # B C 512 112 | 113 | global_feat = self.resnet(tissue).squeeze() 114 | if len(global_feat.shape) <= 1: 115 | global_feat = global_feat.unsqueeze(0) 116 | 117 | global_feats = torch.stack([global_feat for _ in range(graph_feats.shape[1])], dim=1) 118 | seq_feats = torch.stack([global_feats, graph_feats], dim=2) # (BxC) 3 512 119 | seq_feats = rearrange(seq_feats, 'B C L F-> (B C) L F') 120 | out_feats, _ = self.out(seq_feats) 121 | out_feats = rearrange(out_feats, '(B C) L F-> B C L F', B=raw_size['batch']) # B C 3 512 122 | dout_feats = nn.functional.dropout(out_feats, p=0.5, training=True) 123 | 124 | # tissue 125 | tissue_cat = self.tc(dout_feats[:, 0, 0]) 126 | # Proportion 127 | probs = self.predict(dout_feats[:, :, 1]) # 16 64 cell_type 128 | prop_list, prob_list, cell_features = [], [], [] 129 | for single_probs, valid_index, cell_embedding in zip(probs, valid_mask, dout_feats[:, :, 1]): 130 | if valid_index <= 0: 131 | continue 132 | prop_list.append(torch.mean(single_probs[:valid_index], dim=0)) 133 | prob_list.append(single_probs[:valid_index]) 134 | cell_features.append(cell_embedding[:valid_index]) 135 | 136 | avg_probs = torch.stack(prop_list, dim=0) 137 | 138 | return prob_list, 
avg_probs, tissue_cat, cell_features 139 | 140 | 141 | class HistoState(nn.Module): 142 | def __init__(self, config) -> None: 143 | super(HistoState, self).__init__() 144 | resnet = resnet18(pretrained=config.pretrained) 145 | 146 | if config.channels == 1: 147 | modules = [nn.Conv2d(config.channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)] + list(resnet.children())[1:-1] 148 | self.resnet = nn.Sequential(*modules) 149 | 150 | elif config.channels == 3: 151 | self.resnet = nn.Sequential(*list(resnet.children())[:-1]) 152 | 153 | self.gat = GraphAttentionLayer(512, 512, dropout=0.5, alpha=0.2, concat=False) 154 | 155 | self.size_embed = nn.Sequential( 156 | nn.Linear(2, 16), 157 | nn.ReLU() 158 | ) 159 | self.merge = nn.Linear(512 + 16, 512) 160 | 161 | self.out = nn.LSTM(input_size=512, hidden_size=512, num_layers=1, batch_first=True) 162 | self.predict1 = nn.Sequential( 163 | nn.Linear(512, config.k_class), 164 | nn.Softmax(dim=-1) 165 | ) 166 | self.predict2 = nn.Sequential( 167 | nn.Linear(512, config.k_state), 168 | nn.Softmax(dim=-1) 169 | ) 170 | self.tc = nn.Linear(512, config.tissue_class + 1) 171 | 172 | def forward(self, tissue, bag, adj, cell_size, valid_mask, raw_size): # cell number * 3 * 224 * 224 173 | mask_feats = self.resnet(bag).squeeze().reshape(raw_size['batch'], raw_size['cells'], -1) 174 | size_feats = self.size_embed(cell_size) # B C 16 175 | mask_feats = self.merge(torch.concat([mask_feats, size_feats], dim=-1)) # B C 512 176 | graph_feats = self.gat(mask_feats, adj) # B C 512 177 | 178 | global_feat = self.resnet(tissue).squeeze() 179 | if len(global_feat.shape) <= 1: 180 | global_feat = global_feat.unsqueeze(0) 181 | 182 | global_feats = torch.stack([global_feat for _ in range(graph_feats.shape[1])], dim=1) 183 | seq_feats = torch.stack([global_feats, graph_feats, graph_feats], dim=2) # (BxC) 3 512 184 | seq_feats = rearrange(seq_feats, 'B C L F-> (B C) L F') 185 | out_feats, _ = self.out(seq_feats) 186 | out_feats = rearrange(out_feats, '(B C) L F-> B C L F', B=raw_size['batch']) # B C 3 512 187 | 188 | # tissue 189 | tissue_cat = self.tc(out_feats[:, 0, 0]) 190 | # Cell Proportion 191 | type_probs = self.predict1(out_feats[:, :, 1]) # 16 64 cell_type 192 | type_prop_list, type_prob_list = [], [] 193 | for single_probs, valid_index in zip(type_probs, valid_mask): 194 | if valid_index <= 0: 195 | continue 196 | type_prop_list.append(torch.mean(single_probs[:valid_index], dim=0)) 197 | type_prob_list.append(single_probs[:valid_index]) 198 | 199 | avg_type_probs = torch.stack(type_prop_list, dim=0) 200 | 201 | state_probs = self.predict2(out_feats[:, :, 2]) # 16 64 cell_type 202 | state_prop_list, state_prob_list = [], [] 203 | for single_probs, valid_index in zip(state_probs, valid_mask): 204 | if valid_index <= 0: 205 | continue 206 | state_prop_list.append(torch.mean(single_probs[:valid_index], dim=0)) 207 | state_prob_list.append(single_probs[:valid_index]) 208 | 209 | avg_state_probs = torch.stack(state_prop_list, dim=0) 210 | 211 | return { 212 | 'tissue_compartment': tissue_cat, 213 | 'type_prob_list': type_prob_list, 214 | 'type_prop': avg_type_probs, 215 | 'state_prob_list': state_prob_list, 216 | 'state_prop': avg_state_probs 217 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 
4 | _libgcc_mutex=0.1=main 5 | _openmp_mutex=5.1=1_gnu 6 | ca-certificates=2025.2.25=h06a4308_0 7 | certifi=2022.12.7=py37h06a4308_0 8 | charset-normalizer=3.4.1=pypi_0 9 | einops=0.6.1=pypi_0 10 | idna=3.10=pypi_0 11 | joblib=1.3.2=pypi_0 12 | ld_impl_linux-64=2.40=h12ee557_0 13 | libffi=3.4.4=h6a678d5_1 14 | libgcc-ng=11.2.0=h1234567_1 15 | libgomp=11.2.0=h1234567_1 16 | libstdcxx-ng=11.2.0=h1234567_1 17 | ncurses=6.4=h6a678d5_0 18 | numpy=1.21.6=pypi_0 19 | opencv-python=4.11.0.86=pypi_0 20 | openssl=1.1.1w=h7f8727e_0 21 | pandas=1.3.5=pypi_0 22 | pillow=9.5.0=pypi_0 23 | pip=22.3.1=py37h06a4308_0 24 | python=3.7.16=h7a1cb2a_0 25 | python-dateutil=2.9.0.post0=pypi_0 26 | pytz=2025.1=pypi_0 27 | pyyaml=6.0.1=pypi_0 28 | readline=8.2=h5eee18b_0 29 | requests=2.31.0=pypi_0 30 | scikit-learn=1.0.2=pypi_0 31 | scipy=1.7.3=pypi_0 32 | setuptools=65.6.3=py37h06a4308_0 33 | six=1.17.0=pypi_0 34 | sqlite=3.45.3=h5eee18b_0 35 | threadpoolctl=3.1.0=pypi_0 36 | tk=8.6.14=h39e8969_0 37 | torch=1.12.1+cu113=pypi_0 38 | torchaudio=0.12.1+cu113=pypi_0 39 | torchvision=0.13.1+cu113=pypi_0 40 | tqdm=4.67.1=pypi_0 41 | typing-extensions=4.7.1=pypi_0 42 | urllib3=2.0.7=pypi_0 43 | wheel=0.38.4=py37h06a4308_0 44 | xz=5.6.4=h5eee18b_1 45 | yacs=0.1.8=pypi_0 46 | zlib=1.2.13=h5eee18b_1 47 | -------------------------------------------------------------------------------- /tcs/tissue_compartment_BRCA.json: -------------------------------------------------------------------------------- 1 | { 2 | "dict": 3 | {"0": "Stromal", "1": "TME", "2": "TME", "3": "TME", "4": "TME", "5": "Epi"}, 4 | "list": 5 | ["Epi", "TME", "Stromal"], 6 | "HoVerNet": 7 | [2, 0, 1, 2, 2, 0] 8 | } -------------------------------------------------------------------------------- /tcs/tissue_compartment_CRC.json: -------------------------------------------------------------------------------- 1 | { 2 | "dict": 3 | {"0": "Epi", "1": "Stromal", "2": "TME", "3": "TME", "4": "TME", "5": "TME"}, 4 | "list": 5 | ["Epi", "TME", "Stromal"], 6 | "HoVerNet": 7 | [2, 0, 1, 2, 2, 0] 8 | } -------------------------------------------------------------------------------- /tcs/tissue_compartment_HCC.json: -------------------------------------------------------------------------------- 1 | { 2 | "dict": 3 | {"0": "TME", "1": "Epi", "2": "TME", "3": "Stromal", "4": "TME", "5": "Epi"}, 4 | "list": 5 | ["Epi", "TME", "Stromal"], 6 | "HoVerNet": 7 | [2, 0, 1, 2, 2, 0] 8 | } -------------------------------------------------------------------------------- /tcs/tissue_compartment_HD.json: -------------------------------------------------------------------------------- 1 | { 2 | "dict": 3 | {"0": "TME", "1": "TME", "2": "TME", "3": "Stromal", "4": "Epi", "5": "Epi"}, 4 | "list": 5 | ["Epi", "TME", "Stromal"], 6 | "HoVerNet": 7 | [2, 0, 1, 2, 2, 0] 8 | } -------------------------------------------------------------------------------- /tcs/tissue_compartment_LUSC.json: -------------------------------------------------------------------------------- 1 | { 2 | "dict": 3 | {"0": "TME", "1": "Epi", "2": "Epi", "3": "Stromal", "4": "TME"}, 4 | "list": 5 | ["Epi", "TME", "Stromal"], 6 | "HoVerNet": 7 | [2, 0, 1, 2, 2, 0] 8 | } -------------------------------------------------------------------------------- /tcs/tissue_compartment_Mix.json: -------------------------------------------------------------------------------- 1 | { 2 | "dict": 3 | {"0": "Epi", "1": "TME", "2": "Stromal", "3": "TME", "4": "Epi"}, 4 | "list": 5 | ["Epi", "TME", "Stromal"], 6 | "HoVerNet": 7 | [2, 0, 
1, 2, 2, 0] 8 | } -------------------------------------------------------------------------------- /tcs/tissue_compartment_OVC.json: -------------------------------------------------------------------------------- 1 | { 2 | "dict": 3 | {"0": "Stromal", "1": "TME", "2": "Stromal", "3": "Epi", "4": "TME", "5": "TME", "6": "TME", "7": "Epi", "8": "TME", "9": "TME", "10": "TME"}, 4 | "list": 5 | ["Epi", "TME", "Stromal"], 6 | "HoVerNet": 7 | [2, 0, 1, 2, 2, 0] 8 | } -------------------------------------------------------------------------------- /tcs/tissue_compartment_PDAC.json: -------------------------------------------------------------------------------- 1 | { 2 | "dict": 3 | {"0": "TME", "1": "Epi", "2": "TME", "3": "Epi", "4": "Stromal", "5": "TME"}, 4 | "list": 5 | ["Epi", "TME", "Stromal"], 6 | "HoVerNet": 7 | [2, 0, 1, 2, 2, 0] 8 | } -------------------------------------------------------------------------------- /tcs/tissue_compartment_PRAD.json: -------------------------------------------------------------------------------- 1 | { 2 | "dict": 3 | {"0": "TME", "1": "Epi", "2": "Stromal", "3": "Stromal", "4": "TME", "5": "TME", "6": "TME"}, 4 | "list": 5 | ["Epi", "TME", "Stromal"], 6 | "HoVerNet": 7 | [2, 0, 1, 2, 2, 0] 8 | } -------------------------------------------------------------------------------- /tcs/tissue_compartment_RCC.json: -------------------------------------------------------------------------------- 1 | { 2 | "dict": 3 | {"0": "Epi", "1": "Epi", "2": "Stromal", "3": "TME", "4": "TME", "5": "TME", "6": "Stromal", "7": "TME", "8": "Epi", "9": "TME", "10": "TME", "11": "TME"}, 4 | "list": 5 | ["Epi", "TME", "Stromal"], 6 | "HoVerNet": 7 | [2, 0, 1, 2, 2, 0] 8 | } -------------------------------------------------------------------------------- /tcs/tissue_compartment_STAD.json: -------------------------------------------------------------------------------- 1 | { 2 | "dict": 3 | {"0": "Epi", "1": "Epi", "2": "Epi", "3": "TME", "4": "Epi", "5": "Epi", "6": "Epi", "7": "Epi", "8": "TME", "9": "Stromal", "10": "Stromal", "11": "TME", "12": "TME", "13": "Stromal", "14": "Epi", "15": "Stromal", "16": "TME"}, 4 | "list": 5 | ["Epi", "TME", "Stromal"], 6 | "HoVerNet": 7 | [2, 0, 1, 2, 2, 3] 8 | } -------------------------------------------------------------------------------- /tcs/tissue_compartment_Xenium.json: -------------------------------------------------------------------------------- 1 | { 2 | "dict": 3 | {"0": "TME", "1": "Epi", "2": "Stromal", "3": "TME", "4": "Epi", "5": "Stromal", "6": "TME"}, 4 | "list": 5 | ["Epi", "TME", "Stromal"], 6 | "HoVerNet": 7 | [2, 0, 1, 2, 2, 0] 8 | } -------------------------------------------------------------------------------- /tcs/tissue_compartment_state.json: -------------------------------------------------------------------------------- 1 | { 2 | "dict": 3 | {"0": "TME", "1": "Epi", "2": "TME", "3": "TME", "4": "Stromal", "5": "TME"}, 4 | "list": 5 | ["Epi", "TME", "Stromal"], 6 | "state": 7 | {"0": "1", "1": "6", "2": "1", "3": "1", "4": "1", "5": "1"}, 8 | "HoVerNet": 9 | [2, 0, 1, 2, 2, 0] 10 | } -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | from model.arch import HistoCell 5 | from data import TileBatchDataset 6 | from utils.utils import * 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | from torch.optim import Adam 10 
| from configs import _get_config 11 | 12 | def main(): 13 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 14 | parser = argparse.ArgumentParser(description="Prediction with Spots Images") 15 | parser.add_argument('--seed', default=47, type=int) 16 | parser.add_argument('--model', default='Baseline', type=str, help='model description') 17 | parser.add_argument('--tissue', default='BRCA', type=str) 18 | parser.add_argument('--deconv', default='CARD', type=str) 19 | parser.add_argument('--subtype', action='store_true') 20 | parser.add_argument('--prefix', required = False, nargs = '+') 21 | parser.add_argument('--k_class', default=8, type=int) 22 | parser.add_argument('--tissue_compartment', type=str, required=True) 23 | parser.add_argument('--ratio', type=float, required=False, default=1.0) 24 | 25 | args = parser.parse_args() 26 | config = _get_config(args.tissue, args.deconv, args.subtype, args.k_class, args.tissue_compartment) 27 | 28 | setup_seed(args.seed) 29 | print(f"Model details: {args.model}") 30 | 31 | print("Load Dataset...") 32 | 33 | data_prefix = args.prefix if isinstance(args.prefix, list) else [args.prefix] 34 | 35 | print(data_prefix) 36 | train_data = TileBatchDataset(config.data.tile_dir, config.data.mask_dir, config.data.tissue_dir, config.data.cell_dir, deconv=config.data.deconv, prefix=data_prefix, sample_ratio=args.ratio, ext='png') 37 | print(f"length of train data: {len(train_data)}") 38 | train_loader = DataLoader(train_data, batch_size=config.data.batch_size, shuffle=True, num_workers=6, pin_memory=True) 39 | 40 | print("Load Model...") 41 | model = HistoCell(config.model) 42 | loss_func = torch.nn.KLDivLoss() 43 | aux_loss = torch.nn.CrossEntropyLoss() 44 | optimizer = Adam(model.parameters(), lr=config.train.lr) 45 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 46 | if torch.cuda.is_available(): 47 | print("Using 1 gpu") 48 | else: 49 | print("Using cpu") 50 | 51 | 52 | model_dir, ckpt_dir = \ 53 | os.path.join(config.data.save_model, args.model), os.path.join(config.data.ckpt, args.model) 54 | os.makedirs(model_dir, exist_ok=True), os.makedirs(ckpt_dir, exist_ok=True) 55 | 56 | ckpt_file = find_ckpt(model_dir) 57 | # ipdb.set_trace() 58 | if ckpt_file is not None: 59 | state_dict, current_epoch = load_ckpt(os.path.join(model_dir, ckpt_file)) 60 | model.load_state_dict(state_dict) 61 | else: 62 | current_epoch = -1 63 | 64 | model.to(device) 65 | 66 | print(f"Current Epoch: {current_epoch}") 67 | 68 | print("Training...") 69 | for iter in range(current_epoch + 1, config.train.epoch): 70 | train_loop(iter, model, train_loader, loss_func, aux_loss, optimizer, device) 71 | if iter % config.train.val_iter == 0 and iter > config.train.val_min_iter: 72 | save_checkpoint(model, optimizer, os.path.join(model_dir, f'epoch_{iter}.ckpt')) 73 | 74 | print("Training finished!") 75 | print(f"Results saved in {os.path.join(config.data.ckpt, args.model)}") 76 | 77 | 78 | def train_loop(epoch, model, train_loader, loss_func, aux_loss, opt, device): 79 | model.train() 80 | train_bar = tqdm(train_loader, desc="epoch " + str(epoch), total=len(train_loader), 81 | unit="batch", dynamic_ncols=True) 82 | for idx, data in enumerate(train_bar): 83 | tissue = data['tissue'].to(torch.float32).to(device) 84 | images, cell_proportion = data['image'].to(torch.float32).to(device), data['cells'].to(torch.float32).to(device) 85 | valid_mask = data['mask'].to(torch.long).to(device) 86 | cell_size = data['size'].to(torch.float32).to(device) 87 | adj_mat = 
data['adj'].to(torch.float32).to(device) 88 | gt_tissue_cat = data['tissue_cat'].to(device) 89 | if torch.sum(data['mask']) <= 0: 90 | continue 91 | batch, cells, channels, height, width = images.shape 92 | images = images.reshape(batch * cells, channels, height, width) 93 | probs, pred_proportion, tissue_cat, _ = model(tissue, images, adj_mat, cell_size, valid_mask, {'batch': batch, 'cells': cells}) 94 | cell_prop_list = [] 95 | for single_props, valid_index in zip(cell_proportion, valid_mask): 96 | if valid_index <= 0: 97 | continue 98 | cell_prop_list.append(single_props) 99 | cell_proportion = torch.stack(cell_prop_list, dim=0) 100 | 101 | loss_out = loss_func((cell_proportion + 1e-10).log(), pred_proportion + 1e-10) + \ 102 | loss_func((pred_proportion + 1e-10).log(), cell_proportion + 1e-10) + \ 103 | aux_loss(tissue_cat, gt_tissue_cat) 104 | loss_out.backward() 105 | opt.step() 106 | opt.zero_grad() 107 | 108 | loss_value = loss_out.item() 109 | train_bar.set_description('epoch:{} iter:{} loss:{}'.format(epoch, idx, loss_value)) 110 | 111 | if __name__ == '__main__': 112 | main() -------------------------------------------------------------------------------- /train_oneout_pannuke.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import argparse 5 | import random 6 | import pickle as pkl 7 | from model.arch import HistoCell 8 | from data import TileBatchDataset 9 | from utils.utils import * 10 | from torch.utils.data import DataLoader, dataset 11 | from tqdm import tqdm 12 | from yacs.config import CfgNode as CN 13 | from torch.optim import Adam 14 | 15 | def _get_config(tissue_type, deconv, subtype, k_class, tissue_dir): 16 | config = CN() 17 | config.train = CN() 18 | config.train.lr = 0.0005 19 | config.train.epoch = 61 20 | config.train.val_iter = 20 21 | config.train.val_min_iter = 1 22 | 23 | config.data = CN() 24 | config.data.deconv = deconv 25 | config.data.save_model = f'./train_log/pannuke/{tissue_type}/models' 26 | config.data.ckpt = f'./train_log/pannuke/{tissue_type}/ckpts' 27 | config.data.tile_dir = f'/data/gcf22/PaNnuke/raw_images' 28 | config.data.mask_dir = f'/data/gcf22/PaNnuke/cls_results_cv' 29 | config.data.batch_size = 64 30 | config.data.tissue_dir = tissue_dir 31 | 32 | config.model = CN() 33 | config.model.tissue_class = 3 34 | config.model.pretrained = True 35 | config.model.channels = 3 36 | 37 | config.data.cell_dir = f'/data/gcf22/PaNnuke/cell_props' 38 | 39 | config.model.k_class = k_class 40 | 41 | return config 42 | 43 | 44 | def train_loop(epoch, model, train_loader, loss_func, aux_loss, opt, device): 45 | model.train() 46 | train_bar = tqdm(train_loader, desc="epoch " + str(epoch), total=len(train_loader), 47 | unit="batch", dynamic_ncols=True) 48 | for idx, data in enumerate(train_bar): 49 | tissue = data['tissue'].to(torch.float32).to(device) 50 | images, cell_proportion = data['image'].to(torch.float32).to(device), data['cells'].to(torch.float32).to(device) 51 | valid_mask = data['mask'].to(torch.long).to(device) 52 | cell_size = data['size'].to(torch.float32).to(device) 53 | adj_mat = data['adj'].to(torch.float32).to(device) 54 | gt_tissue_cat = data['tissue_cat'].to(device) 55 | if torch.sum(data['mask']) <= 0: 56 | continue 57 | batch, cells, channels, height, width = images.shape 58 | images = images.reshape(batch * cells, channels, height, width) 59 | probs, pred_proportion, tissue_cat, _ = model(tissue, images, adj_mat, cell_size, valid_mask, 
{'batch': batch, 'cells': cells})
60 |         cell_prop_list = []
61 |         for single_props, valid_index in zip(cell_proportion, valid_mask):
62 |             if valid_index <= 0:
63 |                 continue
64 |             cell_prop_list.append(single_props)
65 |         cell_proportion = torch.stack(cell_prop_list, dim=0)
66 | 
67 |         loss_out = loss_func((cell_proportion + 1e-10).log(), pred_proportion + 1e-10) + \
68 |             loss_func((pred_proportion + 1e-10).log(), cell_proportion + 1e-10)
69 |         loss_out.backward()
70 |         opt.step()
71 |         opt.zero_grad()
72 | 
73 |         loss_value = loss_out.item()
74 |         train_bar.set_description('epoch:{} iter:{} loss:{}'.format(epoch, idx, loss_value))
75 | 
76 | 
77 | def val_loop(epoch, model, val_loader, device):
78 |     model.eval()  # evaluation mode: keep dropout and batch-norm statistics frozen during validation
79 |     val_bar = tqdm(val_loader, desc="epoch " + str(epoch), total=len(val_loader),
80 |                    unit="batch", dynamic_ncols=True)
81 | 
82 |     all_results = {}
83 |     for idx, data in enumerate(val_bar):
84 |         tissue = data['tissue'].to(torch.float32).to(device)
85 |         images, cell_proportion = data['image'].to(torch.float32).to(device), data['cells'].to(torch.float32).to(device)
86 |         cell_size = data['size'].to(torch.float32).to(device)
87 |         adj_mat = data['adj'].to(torch.float32).to(device)
88 |         valid_mask = data['mask'].to(torch.long).to(device)
89 |         cell_coords = data['cell_coords'].to(torch.float32).to(device)
90 |         ref_type = data['cell_types'].to(torch.long).to(device)
91 |         if torch.sum(data['mask']) <= 0:
92 |             continue
93 |         batch, cells, channels, height, width = images.shape
94 |         images = images.reshape(batch * cells, channels, height, width)
95 |         prob_list, pred_proportion, _, cell_embeddings = model(tissue, images, adj_mat, cell_size, valid_mask, {'batch': batch, 'cells': cells})
96 |         name_list, valid_list, coord_list, type_list, cell_prop_list = [], [], [], [], []
97 |         for name, single_prop, valid_index, valid_coord, valid_type in zip(data['name'], cell_proportion, valid_mask, cell_coords, ref_type):
98 |             if valid_index <= 0:
99 |                 continue
100 |             name_list.append(name)
101 |             valid_list.append(valid_index)
102 |             coord_list.append(valid_coord)
103 |             type_list.append(valid_type)
104 |             cell_prop_list.append(single_prop)
105 | 
106 |         for name, pb, pp, cp, vix, cc, ct in zip(name_list, prob_list, pred_proportion, cell_prop_list, valid_list, coord_list, type_list):
107 |             valid_num = int(vix.detach().cpu())
108 |             all_results.update(
109 |                 {
110 |                     name: {
111 |                         'pred_proportion': pp.detach().cpu().numpy(),
112 |                         'cell_num': valid_num,
113 |                         'cell_coords': cc[:valid_num].detach().cpu().numpy(),
114 |                         'prob': pb.detach().cpu().numpy(),
115 |                         'prior_type': ct[:valid_num].detach().cpu().numpy(),
116 |                         'gt_proportion': cp.detach().cpu().numpy(),
117 |                         # 'cell_embedding': cell_embeddings.detach().cpu().numpy()
118 |                     }
119 |                 }
120 |             )
121 |         val_bar.set_description('epoch:{} iter:{}'.format(epoch, idx))
122 |     return all_results
123 | 
124 | 
125 | def main():
126 |     os.environ['CUDA_VISIBLE_DEVICES'] = '0'
127 |     parser = argparse.ArgumentParser(description="Prediction with Spots Images")
128 |     parser.add_argument('--seed', default=47, type=int)
129 |     parser.add_argument('--model', default='Baseline', type=str, help='model description')
130 |     parser.add_argument('--tissue', default='*', type=str)
131 |     parser.add_argument('--deconv', default='CARD', type=str)
132 |     parser.add_argument('--subtype', action='store_true')
133 |     parser.add_argument('--prefix', required=False, nargs='+')
134 |     parser.add_argument('--k_class', default=8, type=int)
135 |     parser.add_argument('--tissue_compartment', type=str, required=True)
136
| parser.add_argument('--sample_ratio', default = 1, type=float) 137 | parser.add_argument('--reso', default = 1, type=int) 138 | 139 | args = parser.parse_args() 140 | config = _get_config(args.tissue, args.deconv, args.subtype, args.k_class, args.tissue_compartment) 141 | 142 | setup_seed(args.seed) 143 | print(f"Model details: {args.model}") 144 | 145 | print("Load Dataset...") 146 | data_prefix = args.prefix if isinstance(args.prefix, list) else [args.prefix] 147 | print(data_prefix) 148 | 149 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 150 | if torch.cuda.is_available(): 151 | print("Using 1 gpu") 152 | else: 153 | print("Using cpu") 154 | 155 | for idx, val_prefix in enumerate(data_prefix): 156 | raw_prefix = [item for item in data_prefix] 157 | raw_prefix.pop(idx) 158 | val_fold = TileBatchDataset(config.data.tile_dir, config.data.mask_dir, config.data.tissue_dir, config.data.cell_dir, deconv=config.data.deconv, prefix=[val_prefix], focus_tissue=args.tissue) 159 | train_fold = TileBatchDataset(config.data.tile_dir, config.data.mask_dir, config.data.tissue_dir, config.data.cell_dir, deconv=config.data.deconv, prefix=raw_prefix, focus_tissue=args.tissue, reso=args.reso) 160 | print(f"length of train data: {len(train_fold)} with sample {raw_prefix}") 161 | train_fold,_ = torch.utils.data.random_split(train_fold, [int(len(train_fold)*args.sample_ratio), len(train_fold) - int(len(train_fold)*args.sample_ratio)]) 162 | 163 | print(f"length of sampled train data: {len(train_fold)} with sample {raw_prefix}") 164 | print(f"length of val data: {len(val_fold)} with sample {val_prefix}") 165 | 166 | train_loader = DataLoader(train_fold, batch_size=config.data.batch_size, num_workers=6, pin_memory=True) 167 | val_loader = DataLoader(val_fold, batch_size=config.data.batch_size // 2, num_workers=6, pin_memory=True) 168 | 169 | print("Load Model...") 170 | model = HistoCell(config.model) 171 | loss_func = torch.nn.KLDivLoss() 172 | aux_loss = torch.nn.CrossEntropyLoss() 173 | optimizer = Adam(model.parameters(), lr=config.train.lr) 174 | 175 | model_dir, ckpt_dir = \ 176 | os.path.join(config.data.save_model, args.model, val_prefix), os.path.join(config.data.ckpt, args.model, val_prefix) 177 | os.makedirs(model_dir, exist_ok=True), os.makedirs(ckpt_dir, exist_ok=True) 178 | 179 | ckpt_file = find_ckpt(model_dir) 180 | # ipdb.set_trace() 181 | if ckpt_file is not None: 182 | state_dict, current_epoch = load_ckpt(os.path.join(model_dir, ckpt_file)) 183 | model.load_state_dict(state_dict) 184 | else: 185 | current_epoch = -1 186 | 187 | model.to(device) 188 | 189 | print(f"Current Epoch: {current_epoch}") 190 | 191 | print("Training...") 192 | for iter in range(current_epoch + 1, config.train.epoch): 193 | train_loop(iter, model, train_loader, loss_func, aux_loss, optimizer, device) 194 | # if iter % 5 == 0: 195 | # save_checkpoint(model, optimizer, os.path.join(model_dir, f'epoch_{iter}.ckpt')) 196 | # TODO 197 | if iter > config.train.val_min_iter and iter % config.train.val_iter == 0: 198 | print("##### Validation #####") 199 | val_results = val_loop(iter, model, val_loader, device) 200 | save_checkpoint(model, optimizer, os.path.join(model_dir, f'epoch_{iter}.pt')) 201 | with open(os.path.join(ckpt_dir, f'epoch_{iter}_val.pkl'), 'wb') as file: 202 | pkl.dump(val_results, file) 203 | 204 | # save_checkpoint(model, optimizer, os.path.join(model_dir, f'epoch_{iter}.ckpt')) 205 | 206 | print("Training finished!") 207 | print(f"Results saved in 
{os.path.join(config.data.ckpt, args.model)}") 208 | 209 | if __name__ == '__main__': 210 | main() -------------------------------------------------------------------------------- /train_state.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | from model.arch import HistoState 5 | from data import TileBatchStateDataset 6 | from utils.utils import * 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | from torch.optim import Adam 10 | from configs import _get_cell_state_config 11 | 12 | def main(): 13 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 14 | parser = argparse.ArgumentParser(description="Prediction with Spots Images") 15 | parser.add_argument('--seed', default=47, type=int) 16 | parser.add_argument('--model', default='Baseline', type=str, help='model description') 17 | parser.add_argument('--tissue', default='BRCA', type=str) 18 | parser.add_argument('--deconv', default='CARD', type=str) 19 | parser.add_argument('--subtype', action='store_true') 20 | parser.add_argument('--prefix', required = False, nargs = '+') 21 | parser.add_argument('--tissue_compartment', type=str, required=True) 22 | 23 | args = parser.parse_args() 24 | config = _get_cell_state_config(args.tissue, args.deconv, args.subtype, args.tissue_compartment) 25 | 26 | setup_seed(args.seed) 27 | print(f"Model details: {args.model}") 28 | 29 | print("Load Dataset...") 30 | 31 | data_prefix = args.prefix if isinstance(args.prefix, list) else [args.prefix] 32 | 33 | print(data_prefix) 34 | train_data = TileBatchStateDataset(config.data.tile_dir, config.data.mask_dir, config.data.tissue_dir, config.data.cell_dir, config.data.state_dir, deconv=config.data.deconv, prefix=data_prefix) 35 | print(f"length of train data: {len(train_data)}") 36 | train_loader = DataLoader(train_data, batch_size=config.data.batch_size, shuffle=True, num_workers=6, pin_memory=True) 37 | 38 | print("Load Model...") 39 | model = HistoState(config.model) 40 | loss_func = torch.nn.KLDivLoss() 41 | aux_loss = torch.nn.CrossEntropyLoss() 42 | optimizer = Adam(model.parameters(), lr=config.train.lr) 43 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 44 | if torch.cuda.is_available(): 45 | print("Using 1 gpu") 46 | else: 47 | print("Using cpu") 48 | 49 | 50 | model_dir, ckpt_dir = \ 51 | os.path.join(config.data.save_model, args.model), os.path.join(config.data.ckpt, args.model) 52 | os.makedirs(model_dir, exist_ok=True), os.makedirs(ckpt_dir, exist_ok=True) 53 | 54 | ckpt_file = find_ckpt(model_dir) 55 | # ipdb.set_trace() 56 | if ckpt_file is not None: 57 | state_dict, current_epoch = load_ckpt(os.path.join(model_dir, ckpt_file)) 58 | model.load_state_dict(state_dict) 59 | else: 60 | current_epoch = -1 61 | 62 | model.to(device) 63 | 64 | print(f"Current Epoch: {current_epoch}") 65 | 66 | print("Training with Cell Types...") 67 | if current_epoch + 1 <= config.train.epoch: 68 | for iter in range(current_epoch + 1, config.train.epoch): 69 | train_loop(iter, model, train_loader, loss_func, aux_loss, optimizer, device) 70 | 71 | if iter % config.train.val_iter == 0 and iter > config.train.val_min_iter: 72 | save_checkpoint(model, optimizer, os.path.join(model_dir, f'epoch_{iter}.ckpt')) 73 | 74 | print("Training with Cell States...") 75 | for iter in range(max(config.train.epoch, current_epoch + 1), config.train.epoch + config.train.state_epoch): 76 | train_loop(iter, model, train_loader, loss_func, aux_loss, optimizer, device, 
with_state=True) 77 | 78 | if iter % config.train.val_iter == 0 and iter > config.train.val_min_iter: 79 | save_checkpoint(model, optimizer, os.path.join(model_dir, f'epoch_{iter}.ckpt')) 80 | 81 | print("Training finished!") 82 | print(f"Results saved in {os.path.join(config.data.ckpt, args.model)}") 83 | 84 | 85 | def train_loop(epoch, model, train_loader, loss_func, aux_loss, opt, device, with_state=False): 86 | model.train() 87 | train_bar = tqdm(train_loader, desc="epoch " + str(epoch), total=len(train_loader), 88 | unit="batch", dynamic_ncols=True) 89 | for idx, data in enumerate(train_bar): 90 | tissue = data['tissue'].to(torch.float32).to(device) 91 | images, cell_proportion, state_proportion = \ 92 | data['image'].to(torch.float32).to(device), data['cells'].to(torch.float32).to(device), data['states'].to(torch.float32).to(device) 93 | valid_mask = data['mask'].to(torch.long).to(device) 94 | cell_size = data['size'].to(torch.float32).to(device) 95 | adj_mat = data['adj'].to(torch.float32).to(device) 96 | gt_tissue_cat = data['tissue_cat'].to(device) 97 | if torch.sum(data['mask']) <= 0: 98 | continue 99 | batch, cells, channels, height, width = images.shape 100 | images = images.reshape(batch * cells, channels, height, width) 101 | pred_dict = model(tissue, images, adj_mat, cell_size, valid_mask, {'batch': batch, 'cells': cells}) 102 | cell_prop_list = [] 103 | for single_props, valid_index in zip(cell_proportion, valid_mask): 104 | if valid_index <= 0: 105 | continue 106 | cell_prop_list.append(single_props) 107 | cell_proportion = torch.stack(cell_prop_list, dim=0) 108 | 109 | loss_out = loss_func((cell_proportion + 1e-10).log(), pred_dict['type_prop'] + 1e-10) + \ 110 | loss_func((pred_dict['type_prop'] + 1e-10).log(), cell_proportion + 1e-10) + \ 111 | aux_loss(pred_dict['tissue_compartment'], gt_tissue_cat) 112 | 113 | if with_state: 114 | state_prop_list = [] 115 | for name, single_states, valid_index in zip(data['name'], state_proportion, valid_mask): 116 | if valid_index <= 0: 117 | continue 118 | state_prop_list.append(single_states) 119 | state_proportion = torch.stack(state_prop_list, dim=0) 120 | 121 | state_loss = loss_func((state_proportion + 1e-10).log(), pred_dict['state_prop'] + 1e-10) + loss_func((pred_dict['state_prop'] + 1e-10).log(), state_proportion + 1e-10) 122 | # consist_loss = torch.nn.functional.l1_loss(pred_dict['state_prop'][:, :5], pred_dict['type_prop'][:, [0, 2, 3, 5, 6]]) 123 | consist_loss = torch.nn.functional.l1_loss(pred_dict['state_prop'][:, :6], pred_dict['type_prop'][:, :6]) 124 | loss_out = loss_out + state_loss + 0.5 * consist_loss 125 | 126 | loss_out.backward() 127 | opt.step() 128 | opt.zero_grad() 129 | 130 | loss_value = loss_out.item() 131 | train_bar.set_description('epoch:{} iter:{} loss:{}'.format(epoch, idx, loss_value)) 132 | 133 | if __name__ == '__main__': 134 | main() -------------------------------------------------------------------------------- /tutorial/data/url.txt: -------------------------------------------------------------------------------- 1 | Download from url: 2 | https://drive.google.com/file/d/1iwXzxafdUF7SF6IAEyWluN728WbxAMEi/view?usp=sharing -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/recolyce/HistoCell/d6c7863a0334f2219777fd297c1083ee7b4675b8/utils/__pycache__/utils.cpython-37.pyc 
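
For reference, the three training scripts above (train.py, train_oneout_pannuke.py and train_state.py) all optimize the same core objective: a symmetrized KL divergence between the predicted cell-type proportions and the deconvolution-derived proportions, plus (in train.py and train_state.py) an auxiliary cross-entropy loss on the tissue compartment label. The following minimal sketch reproduces that objective on dummy tensors; the shapes and variable names here are illustrative placeholders, not the repository's data pipeline.

```python
# Minimal sketch of the shared training objective, on dummy tensors.
# batch / k_class / n_compartments values are placeholders.
import torch

kl = torch.nn.KLDivLoss()  # same construction as in the training scripts
ce = torch.nn.CrossEntropyLoss()

batch, k_class, n_compartments = 4, 8, 3
type_logits = torch.randn(batch, k_class, requires_grad=True)        # stands in for the model head
tissue_logits = torch.randn(batch, n_compartments, requires_grad=True)
pred_proportion = torch.softmax(type_logits, dim=-1)                 # predicted cell-type proportions
cell_proportion = torch.softmax(torch.randn(batch, k_class), dim=-1)  # deconvolution supervision
gt_tissue_cat = torch.randint(0, n_compartments, (batch,))

# symmetrized KL on the proportions + auxiliary compartment classification,
# mirroring the loss_out expression in train.py
loss = kl((cell_proportion + 1e-10).log(), pred_proportion + 1e-10) \
     + kl((pred_proportion + 1e-10).log(), cell_proportion + 1e-10) \
     + ce(tissue_logits, gt_tissue_cat)
loss.backward()
```

The 1e-10 offsets keep the logarithm finite when a proportion is exactly zero.
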
-------------------------------------------------------------------------------- /utils/utils.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | import torch
4 | import random
5 | import os
6 | import cv2
7 | import scipy.spatial.distance  # bare `import scipy` does not load scipy.spatial, which pair_coordinates needs for cdist
8 | from scipy.optimize import linear_sum_assignment
9 | from PIL import Image
10 | 
11 | def reg_loss(props: torch.Tensor):
12 |     # props: torch.Size([Batch Cell Prob])
13 |     log_prop = torch.log(props)
14 |     loss = props.mul(log_prop)
15 |     return -torch.mean(loss)
16 | 
17 | def index_filter(used_idx, all_idx):
18 |     rest_index = []
19 |     for item in all_idx:
20 |         if item in used_idx:
21 |             continue
22 |         rest_index.append(item)
23 |     return rest_index
24 | 
25 | def load_image(img_path):
26 |     fp = open(img_path, 'rb')
27 |     pic = Image.open(fp)
28 |     pic = np.array(pic)
29 |     fp.close()
30 |     return pic
31 | 
32 | def save_image(src, img_path):
33 |     img = Image.fromarray(src, 'RGB')
34 |     img.save(img_path)
35 | 
36 | def load_json(json_path):
37 |     with open(json_path, 'r', encoding='utf-8') as file:
38 |         nuc_info = json.load(file)
39 | 
40 |     return nuc_info['nuc']
41 | 
42 | def dump_json(src: dict, json_path):
43 |     with open(json_path, 'w', encoding='utf-8') as file:
44 |         json.dump(src, file)
45 | 
46 | def setup_seed(seed):
47 |     torch.manual_seed(seed)
48 |     torch.cuda.manual_seed_all(seed)
49 |     np.random.seed(seed)
50 |     random.seed(seed)
51 |     torch.backends.cudnn.deterministic = True
52 | 
53 | def collate_fn(batch):
54 |     tensor_batch = {}
55 |     for key, value in batch[0].items():
56 |         if isinstance(value, str):
57 |             tensor_batch.update({key: value})
58 |             continue
59 |         tensor_batch.update({key: torch.tensor(value, dtype=torch.float32)})
60 | 
61 |     return tensor_batch
62 | 
63 | def collate_batch_fn(iter_batch):
64 |     tensor_batch = {key: [] for key in iter_batch[0].keys()}
65 |     for batch in iter_batch:
66 |         for key, value in batch.items():
67 |             tensor_batch[key].append(value)
68 | 
69 |     for key in tensor_batch.keys():
70 |         if key in ['tissue', 'cells', 'image']:
71 |             tensor_batch[key] = torch.stack(tensor_batch[key], dim=0).to(torch.float32)
72 |         if key in ['mask']:
73 |             tensor_batch[key] = torch.tensor(tensor_batch[key], dtype=torch.long)
74 | 
75 |     return tensor_batch
76 | 
77 | def find_ckpt(file_dir):
78 |     ckpt_files = os.listdir(file_dir)  # renamed from `list` to avoid shadowing the built-in
79 |     ckpt_files.sort(key=lambda fn: os.path.getmtime(os.path.join(file_dir, fn)) if not os.path.isdir(os.path.join(file_dir, fn)) else 0)  # fixed: `file_dir+fn` dropped the path separator
80 |     if ckpt_files:
81 |         return ckpt_files[-1]  # most recently modified checkpoint
82 |     else:
83 |         return None
84 | 
85 | 
86 | def load_ckpt(path):
87 |     ckpt = torch.load(path, map_location="cpu")
88 |     model = ckpt['state_dict']
89 |     current_epoch = int(path.split('/')[-1].split('.')[0].split('_')[-1])
90 |     return model, current_epoch
91 | 
92 | def save_checkpoint(model, optimizer, save_dir):
93 |     print(f"Saving checkpoint to {save_dir}")
94 |     checkpoint = {
95 |         'state_dict': model.state_dict(),
96 |         'optimizer': optimizer.state_dict()
97 |     }
98 |     torch.save(checkpoint, save_dir)
99 | 
100 | def get_bounding_box(img):
101 |     """Get bounding box coordinate information."""
102 |     rows = np.any(img, axis=1)
103 |     cols = np.any(img, axis=0)
104 |     rmin, rmax = np.where(rows)[0][[0, -1]]
105 |     cmin, cmax = np.where(cols)[0][[0, -1]]
106 |     # due to python indexing, need to add 1 to max
107 |     # else accessing will be 1px in the box, not out
108 |     rmax += 1
109 |     cmax += 1
110 |     return [rmin, rmax, cmin, cmax]
111 | 
112 | def get_centroid(pred_inst, types=None, label_dict=None, id_offset=0):
113 |     inst_id_list = np.unique(pred_inst[pred_inst > 0])  # exclude background (id 0)
114 |     inst_info_dict = {}
115 |     for inst_id in inst_id_list:
116 |         inst_map = pred_inst == inst_id
117 |         # TODO: change format of bbox output
118 |         rmin, rmax, cmin, cmax = get_bounding_box(inst_map)
119 |         inst_bbox = np.array([[rmin, cmin], [rmax, cmax]])
120 |         inst_map = inst_map[
121 |             inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1]
122 |         ]
123 |         inst_map = inst_map.astype(np.uint8)
124 |         inst_moment = cv2.moments(inst_map)
125 |         inst_contour = cv2.findContours(
126 |             inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
127 |         )
128 |         # * opencv protocol format may break
129 |         inst_contour = np.squeeze(inst_contour[0][0].astype("int32"))
130 |         # fewer than 3 points cannot form a contour, so skip; such instances
131 |         # are likely artifacts of the contour approximation (too small or degenerate)
132 |         if inst_contour.shape[0] < 3:
133 |             continue
134 |         if len(inst_contour.shape) != 2:
135 |             continue  # ! guard against unexpected contour shapes
136 |         inst_centroid = [
137 |             (inst_moment["m10"] / inst_moment["m00"]),
138 |             (inst_moment["m01"] / inst_moment["m00"]),
139 |         ]
140 |         inst_centroid = np.array(inst_centroid)
141 |         inst_contour[:, 0] += inst_bbox[0][1]  # X
142 |         inst_contour[:, 1] += inst_bbox[0][0]  # Y
143 |         inst_centroid[0] += inst_bbox[0][1]  # X
144 |         inst_centroid[1] += inst_bbox[0][0]  # Y
145 |         inst_info_dict[inst_id + id_offset] = {  # inst_id should start at 1
146 |             "bbox": inst_bbox,
147 |             "centroid": inst_centroid,
148 |             "contour": inst_contour,
149 |             "type": types if label_dict is None else label_dict[types],
150 |         }
151 | 
152 |     return inst_info_dict
153 | 
154 | 
155 | def pair_coordinates(setA, setB, radius):
156 |     """Use the Kuhn-Munkres (Hungarian) algorithm to find the optimal
157 |     unique pairing (largest possible match) when pairing points in set B
158 |     against points in set A, using distance as the cost function.
159 | 
160 |     Args:
161 |         setA, setB: np.array (float32) of size Nx2 containing the XY coordinates
162 |             of N different points
163 |         radius: maximum distance around a point in setA within which a
164 |             coordinate in setB is considered a candidate match
165 |     Return:
166 |         pairing: array of index pairs, where the point at index pairing[i][0]
167 |             in set A is paired with the point at index pairing[i][1]
168 |             in set B
169 |         unpairedA, unpairedB: indices of the remaining unpaired points in set A and set B
170 | 
171 |     """
172 |     # * Euclidean distance as the cost matrix
173 |     pair_distance = scipy.spatial.distance.cdist(setA, setB, metric='euclidean')
174 | 
175 |     # * Munkres pairing with the scipy library
176 |     # the algorithm returns (row indices, matched column indices);
177 |     # if a row contains multiple equal costs, the index of the first
178 |     # occurrence is returned, so the pairing is unique
179 |     indicesA, paired_indicesB = linear_sum_assignment(pair_distance)
180 | 
181 |     # extract the paired cost and remove instances
182 |     # outside of the designated radius
183 |     pair_cost = pair_distance[indicesA, paired_indicesB]
184 | 
185 |     pairedA = indicesA[pair_cost <= radius]
186 |     pairedB = paired_indicesB[pair_cost <= radius]
187 | 
188 |     pairing = np.concatenate([pairedA[:, None], pairedB[:, None]], axis=-1)
189 |     unpairedA = np.delete(np.arange(setA.shape[0]), pairedA)
190 |     unpairedB = np.delete(np.arange(setB.shape[0]), pairedB)
191 |     return pairing, unpairedA, unpairedB
192 | 
193 | def remap_label(pred, by_size=False):
194 |     """Rename all instance ids so that the ids are contiguous, i.e. [0, 1, 2, 3]
195 |     rather than [0, 2, 4, 6]. The ordering of instances (which one comes first)
196 |     is preserved unless by_size=True, in which case the instances are reordered
197 |     so that bigger nuclei get smaller ids.
198 | 
199 |     Args:
200 |         pred : the 2d array containing instances, where each instance is marked
201 |             by a non-zero integer
202 |         by_size : if True, larger nuclei are renamed to smaller ids (listed first)
203 | 
204 |     """
205 |     pred_id = list(np.unique(pred))
206 |     pred_id.remove(0)
207 |     if len(pred_id) == 0:
208 |         return pred  # no label
209 |     if by_size:
210 |         pred_size = []
211 |         for inst_id in pred_id:
212 |             size = (pred == inst_id).sum()
213 |             pred_size.append(size)
214 |         # sort the ids by size in descending order
215 |         pair_list = zip(pred_id, pred_size)
216 |         pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True)
217 |         pred_id, pred_size = zip(*pair_list)
218 | 
219 |     new_pred = np.zeros(pred.shape, np.int32)
220 |     for idx, inst_id in enumerate(pred_id):
221 |         new_pred[pred == inst_id] = idx + 1
222 |     return new_pred
--------------------------------------------------------------------------------
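
To illustrate the two evaluation helpers at the end of utils/utils.py, here is a small hypothetical usage sketch. It assumes the repository root is on PYTHONPATH so that utils.utils is importable; the toy arrays are invented for illustration only.

```python
# Toy usage of remap_label and pair_coordinates from utils/utils.py.
import numpy as np
from utils.utils import remap_label, pair_coordinates

# an instance mask with non-contiguous ids {0, 2, 5}
pred = np.array([[0, 2, 2],
                 [0, 0, 5],
                 [5, 5, 5]])
print(remap_label(pred))                # ids become contiguous: {0, 1, 2}
print(remap_label(pred, by_size=True))  # the larger instance (old id 5) becomes id 1

# match predicted centroids against reference centroids within a 12 px radius
setA = np.array([[10.0, 10.0], [50.0, 50.0]])  # e.g. reference cells
setB = np.array([[12.0, 11.0], [90.0, 90.0]])  # e.g. predicted cells
pairing, unpairedA, unpairedB = pair_coordinates(setA, setB, radius=12)
print(pairing)    # [[0 0]] -> A[0] matched with B[0]
print(unpairedA)  # [1]
print(unpairedB)  # [1]
```

pair_coordinates underlies centroid-level matching between predicted and reference cells, while remap_label ensures instance masks carry contiguous ids before such comparisons.
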