├── README.md
├── data
│   └── sample_TCGA_images
│       ├── Adrenocortical_carcinoma
│       │   └── 0.jpg
│       ├── Bladder_Urothelial_Carcinoma
│       │   └── 0.jpg
│       ├── Brain_Lower_Grade_Glioma
│       │   └── 0.jpg
│       ├── Breast_invasive_carcinoma
│       │   └── 0.jpg
│       ├── Cervical_squamous_cell_carcinoma_and_endocervical_adenocarcinoma
│       │   └── 0.jpg
│       ├── Cholangiocarcinoma
│       │   └── 0.jpg
│       ├── Colon_adenocarcinoma
│       │   └── 0.jpg
│       ├── Esophageal_carcinoma
│       │   └── 0.jpg
│       ├── Glioblastoma_multiforme
│       │   └── 0.jpg
│       ├── Head_and_Neck_squamous_cell_carcinoma
│       │   └── 0.jpg
│       ├── Kidney_Chromophobe
│       │   └── 0.jpg
│       ├── Kidney_renal_clear_cell_carcinoma
│       │   └── 0.jpg
│       ├── Kidney_renal_papillary_cell_carcinoma
│       │   └── 0.jpg
│       ├── Liver_hepatocellular_carcinoma
│       │   └── 0.jpg
│       ├── Lung_adenocarcinoma
│       │   └── 0.jpg
│       ├── Lung_squamous_cell_carcinoma
│       │   └── 0.jpg
│       ├── Lymphoid_Neoplasm_Diffuse_Large_B-cell_Lymphoma
│       │   └── 0.jpg
│       ├── Mesothelioma
│       │   └── 0.jpg
│       ├── Ovarian_serous_cystadenocarcinoma
│       │   └── 0.jpg
│       ├── Pancreatic_adenocarcinoma
│       │   └── 0.jpg
│       ├── Pheochromocytoma_and_Paraganglioma
│       │   └── 0.jpg
│       ├── Prostate_adenocarcinoma
│       │   └── 0.jpg
│       ├── Rectum_adenocarcinoma
│       │   └── 0.jpg
│       ├── Sarcoma
│       │   └── 0.jpg
│       ├── Skin_Cutaneous_Melanoma
│       │   └── 0.jpg
│       ├── Stomach_adenocarcinoma
│       │   └── 0.jpg
│       ├── Testicular_Germ_Cell_Tumors
│       │   └── 0.jpg
│       ├── Thymoma
│       │   └── 0.jpg
│       ├── Thyroid_carcinoma
│       │   └── 0.jpg
│       ├── Uterine_Carcinosarcoma
│       │   └── 0.jpg
│       ├── Uterine_Corpus_Endometrial_Carcinoma
│       │   └── 0.jpg
│       └── Uveal_Melanoma
│           └── 0.jpg
├── model.py
├── pretrained_models
│   └── .gitignore
├── run_inference.py
├── sample_visual_results
│   └── .gitignore
└── validate_model.py

/README.md:
--------------------------------------------------------------------------------
## Histopathological Image Classification with Cell Morphology Aware Deep Neural Networks

#### 1. Overview [[Paper]](https://openaccess.thecvf.com/content/CVPR2024W/CVMI/papers/Ignatov_Histopathological_Image_Classification_with_Cell_Morphology_Aware_Deep_Neural_Networks_CVPRW_2024_paper.pdf)

This repository provides the implementation of **DeepCMorph**, a foundation CNN model for histopathological image classification and analysis. Unlike existing models, DeepCMorph explicitly **learns cell morphology**: its segmentation module is trained to identify different cell types and morphological features of nuclei.

Key DeepCMorph features:

1. Achieves state-of-the-art results on the **TCGA**, **NCT-CRC-HE** and **Colorectal cancer (CRC)** datasets
2. Consists of two independent modules: **nuclei segmentation / classification** and **tissue classification**
3. The segmentation module is pre-trained on a combination of **8 segmentation datasets**
4. The classification module is pre-trained on the **Pan-Cancer TCGA dataset** (8736 diagnostic slides / 7175 patients)
5. Can be applied to images of **arbitrary resolutions**
6. Can be trained or fine-tuned on **one GPU**
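
The "arbitrary resolutions" property has a practical limit, imposed by the skip connections in ``model.py``. The sketch below is our own shape arithmetic, not code from the repository. It relies on two assumptions. First, in torchvision's EfficientNet-B7 the stem and the stages ``features.2``, ``features.3``, ``features.4`` and ``features.6`` each halve the resolution, rounding up. Second, the kernel-2, stride-2 transposed convolutions of ``UpsampleConvLayer`` each double it. Under these assumptions, two things line up only when height and width are multiples of 16 (for example 224x224): the decoder's concatenations, and the predicted maps against the RGB input in the final 11-channel concatenation. The helper names are hypothetical.

```python
import math


def downsample(n: int) -> int:
    # A stride-2 convolution with padding (k - 1) // 2 maps n pixels to ceil(n / 2)
    return math.ceil(n / 2)


def encoder_sizes(n: int) -> list:
    """Spatial size of the skip features f1..f4 (assumed torchvision EfficientNet-B7 layout)."""
    f1 = downsample(n)   # after the stride-2 stem; features.1 keeps the resolution
    f2 = downsample(f1)  # after features.2
    f3 = downsample(f2)  # after features.3
    f4 = downsample(f3)  # after features.4; features.5 keeps the resolution
    return [f1, f2, f3, f4]


def segmentation_output_matches_input(n: int) -> bool:
    """True if every decoder concatenation and the final 11-channel concatenation line up."""
    f1, f2, f3, f4 = encoder_sizes(n)
    size = f4
    for skip in (f3, f2, f1):
        size *= 2            # ConvTranspose2d(kernel_size=2, stride=2) doubles the size
        if size != skip:     # torch.cat along channels needs equal spatial sizes
            return False
    size *= 2                # upsample_4 has no skip connection
    return size == n         # the maps are concatenated with the RGB input


for n in (224, 256, 100, 15):
    print(n, encoder_sizes(n), segmentation_output_matches_input(n))
```

If these assumptions hold, pad or crop images of other sizes to a multiple of 16 before passing them to the model.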

#### 2. Prerequisites

- Python packages: numpy, imageio and scikit-learn (scikit-learn is only needed by ``validate_model.py``)
- [PyTorch + TorchVision](https://pytorch.org/) libraries
- [Optional] Nvidia GPU
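
The following optional sanity check verifies that the packages above are importable. It is our own sketch, not part of the repository. Note that scikit-learn is imported under the name ``sklearn``. GPU availability can be checked separately with ``torch.cuda.is_available()``.

```python
import importlib.util

# Import names of the required packages (scikit-learn is imported as "sklearn")
REQUIRED = ["numpy", "imageio", "torch", "torchvision", "sklearn"]


def missing_packages(names):
    """Return the import names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(REQUIRED)
print("Missing packages:", ", ".join(missing) if missing else "none")
```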

#### 3. Download Pre-Trained Models

The segmentation module of all pre-trained models is trained on a combination of 8 publicly available nuclei segmentation / classification datasets: **Lizard, CryoNuSeg, MoNuSAC, BNS, TNBC, KUMAR, MICCAI** and **PanNuke**.

| Dataset | #Classes | Accuracy | ``load_weights`` dataset value | Download Link |
|---------|----------|----------|--------------------------------|---------------|
| Combined [[TCGA](https://zenodo.org/records/5889558) + [NCT_CRC_HE](https://zenodo.org/records/1214456)] | 41 | 81.59% | COMBINED | [Link](https://data.vision.ee.ethz.ch/ihnatova/public/DeepCMorph/DeepCMorph_Datasets_Combined_41_classes_acc_8159.pth) |
| [TCGA](https://zenodo.org/records/5889558) [Extreme Augmentations] | 32 | 82.00% | TCGA_REGULARIZED | [Link](https://data.vision.ee.ethz.ch/ihnatova/public/DeepCMorph/DeepCMorph_Pan_Cancer_Regularized_32_classes_acc_8200.pth) |
| [TCGA](https://zenodo.org/records/5889558) [Moderate Augmentations] | 32 | 82.73% | TCGA | [Link](https://data.vision.ee.ethz.ch/ihnatova/public/DeepCMorph/DeepCMorph_Pan_Cancer_32_classes_acc_8273.pth) |
| [NCT_CRC_HE](https://zenodo.org/records/1214456) | 9 | 96.99% | CRC | [Link](https://data.vision.ee.ethz.ch/ihnatova/public/DeepCMorph/DeepCMorph_NCT_CRC_HE_Dataset_9_classes_acc_9699.pth) |

Download the required models and copy them to the ``pretrained_models/`` directory without renaming them: ``load_weights`` looks up the checkpoints by their original file names.
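
The download step can also be scripted. The sketch below is our own, not part of the repository: ``CHECKPOINTS``, ``checkpoint_url`` and ``download_checkpoint`` are hypothetical names. It maps the dataset values accepted by ``load_weights`` to the file names and URLs in the table above and stores each file under ``pretrained_models/``.

```python
import os
import urllib.request

BASE_URL = "https://data.vision.ee.ethz.ch/ihnatova/public/DeepCMorph/"

# Checkpoint file names from the table above, keyed by the dataset values of load_weights()
CHECKPOINTS = {
    "COMBINED": "DeepCMorph_Datasets_Combined_41_classes_acc_8159.pth",
    "TCGA_REGULARIZED": "DeepCMorph_Pan_Cancer_Regularized_32_classes_acc_8200.pth",
    "TCGA": "DeepCMorph_Pan_Cancer_32_classes_acc_8273.pth",
    "CRC": "DeepCMorph_NCT_CRC_HE_Dataset_9_classes_acc_9699.pth",
}


def checkpoint_url(dataset: str) -> str:
    return BASE_URL + CHECKPOINTS[dataset]


def download_checkpoint(dataset: str, target_dir: str = "pretrained_models") -> str:
    """Download one checkpoint into target_dir, skipping files that already exist."""
    os.makedirs(target_dir, exist_ok=True)
    path = os.path.join(target_dir, CHECKPOINTS[dataset])
    if not os.path.exists(path):
        urllib.request.urlretrieve(checkpoint_url(dataset), path)
    return path
```

For example, ``download_checkpoint("TCGA")`` followed by ``model.load_weights(dataset="TCGA")`` should pick up the downloaded file, because the stored name matches the path hard-coded in ``model.py``.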

#### 4. Pre-Trained Model Usage

Integrating the DeepCMorph model into your project is straightforward. The code below shows how to define, initialize and run the model on a sample histopathological image:

```python
import torch
from model import DeepCMorph

# Defining the model and specifying the number of target classes:
# 41 for combined datasets, 32 for TCGA, 9 for CRC
model = DeepCMorph(num_classes=41)

# Loading model weights corresponding to the network trained on combined datasets
# Possible 'dataset' values: TCGA, TCGA_REGULARIZED, CRC, COMBINED
model.load_weights(dataset="COMBINED")
model.eval()

# Placeholder input: a batch of one RGB image with values in [0, 1]
# (as produced by transforms.ToTensor()); replace it with a real image
sample_image = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    # Get the predicted class for a sample input image
    predictions = model(sample_image)
    _, predicted_class = torch.max(predictions, 1)

    # Get feature vector of size 2560 for a sample input image
    features = model(sample_image, return_features=True)

    # Get predicted segmentation and classification maps for a sample input image
    nuclei_segmentation_map, nuclei_classification_maps = model(sample_image, return_segmentation_maps=True)
```

A more detailed usage example is provided in the ``run_inference.py`` script. It applies the pre-trained DeepCMorph model to the 32 sample TCGA images (one per cancer type) to 1) produce **classification predictions**, 2) extract **2560-dimensional feature vectors** that can be used to train an SVM or another standalone classifier, and 3) generate and visualize **nuclei segmentation / classification maps**.
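
``run_inference.py`` prints class indices rather than names. ``torchvision.datasets.ImageFolder`` assigns indices to the class sub-folders in sorted order, so for the bundled sample folder the index-to-name mapping can be reproduced as shown below. This sketch assumes that the TCGA model uses the same alphabetical class order, which is what comparing predictions with ``ImageFolder`` targets in ``run_inference.py`` implies. The helper names are hypothetical.

```python
import os


def class_index_map(folder_names):
    # ImageFolder assigns indices 0..N-1 to the class folders in sorted order
    return {index: name for index, name in enumerate(sorted(folder_names))}


def class_index_map_from_folder(root="data/sample_TCGA_images"):
    # Only sub-directories count as classes, mirroring ImageFolder
    names = [entry.name for entry in os.scandir(root) if entry.is_dir()]
    return class_index_map(names)


if os.path.isdir("data/sample_TCGA_images"):
    index_to_name = class_index_map_from_folder()
    print(index_to_name[0], "...", index_to_name[len(index_to_name) - 1])
```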

#### 5. Fine-Tuning the Model

The following code initializes the model for further fine-tuning:

```python
from model import DeepCMorph

# Replace ... with the number of target classes of your dataset

# Defining the model with a frozen segmentation module (typical usage)
# All weights of the classification module are trainable
model = DeepCMorph(num_classes=...)

# Defining the model with frozen segmentation and classification modules
# Only the final fully-connected layer is trainable
model = DeepCMorph(num_classes=..., freeze_classification_module=True)

# Defining the model with all layers being trainable
model = DeepCMorph(num_classes=..., freeze_segmentation_module=False)
```
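
The constructor flags determine which parts of the network are trainable. The function below is a plain-Python illustration of the ``requires_grad`` logic in ``model.py`` and is not part of the repository (``describe_setup`` is a hypothetical name). It also captures one detail from ``forward()``: 41, 32 and 9 classes are routed to the dedicated heads ``output_41``, ``output_32`` and ``output_9``, while any other ``num_classes`` uses the generic ``output`` layer. That layer is the one that has to be learned for a new label set.

```python
def describe_setup(num_classes, freeze_classification_module=False, freeze_segmentation_module=True):
    """Illustrative mirror of the trainable/frozen settings in DeepCMorph.__init__."""
    # forward() selects the head by num_classes; other values fall back to 'output'
    head = {41: "output_41", 32: "output_32", 9: "output_9"}.get(num_classes, "output")
    return {
        "segmentation_trainable": not freeze_segmentation_module,
        "classification_trainable": not freeze_classification_module,
        "head": head,  # the linear heads are never frozen by the constructor
    }


print(describe_setup(5))
print(describe_setup(5, freeze_classification_module=True))
print(describe_setup(5, freeze_segmentation_module=False))
```

When some modules are frozen, a common PyTorch pattern is to pass only the trainable parameters to the optimizer, for example ``torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)``. The learning rate here is an arbitrary placeholder.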

#### 6. Pre-Trained Model Evaluation

The ``validate_model.py`` file contains sample code for evaluating the model on the **CRC-VAL-HE-7K** dataset. To check the model accuracy:

1. Download the corresponding model weights (``DeepCMorph_NCT_CRC_HE_Dataset_9_classes_acc_9699.pth``)
2. Download the [CRC-VAL-HE-7K](https://zenodo.org/records/1214456) dataset and extract it to the ``data/`` directory, so that the images are located in ``data/CRC-VAL-HE-7K/`` (one sub-folder per class, as expected by ``ImageFolder``)
3. Run the test script: ``python validate_model.py``

The script reports both accuracy and balanced accuracy, and it can easily be adapted to other datasets.
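
For reference, here are plain-Python versions of the two reported scores. The script itself uses scikit-learn's ``accuracy_score`` and ``balanced_accuracy_score``; these helpers are our own illustration. Balanced accuracy averages the per-class recall over the classes present in the targets, so large classes cannot dominate it.

```python
from collections import Counter


def accuracy(targets, predictions):
    # Fraction of samples whose predicted class equals the target class
    return sum(t == p for t, p in zip(targets, predictions)) / len(targets)


def balanced_accuracy(targets, predictions):
    # Mean per-class recall over the classes that occur in the targets
    support = Counter(targets)
    correct = Counter(t for t, p in zip(targets, predictions) if t == p)
    return sum(correct[c] / support[c] for c in support) / len(support)


targets = [0, 0, 0, 1]
predictions = [0, 0, 0, 0]
print(accuracy(targets, predictions), balanced_accuracy(targets, predictions))  # 0.75 0.5
```

In this example, a model that always predicts the majority class still reaches 75% accuracy, but only 50% balanced accuracy.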

#### 7. Folder structure

>```data/sample_TCGA_images/``` - the folder with sample TCGA images
>```pretrained_models/``` - the folder for the pre-trained DeepCMorph models
>```sample_visual_results/``` - visualizations of the nuclei segmentation and classification maps

>```model.py``` - DeepCMorph implementation [PyTorch]
>```run_inference.py``` - the script showing model usage on sample histopathological images
>```validate_model.py``` - the script for model validation on the CRC-VAL-HE-7K dataset

#### 8. License

Copyright (C) 2024 Andrey Ignatov. All rights reserved.

Licensed under the [CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International)](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

The code is released for academic research use only.

#### 9. Citation

```
@inproceedings{ignatov2024histopathological,
  title={Histopathological Image Classification with Cell Morphology Aware Deep Neural Networks},
  author={Ignatov, Andrey and Yates, Josephine and Boeva, Valentina},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
  pages={6913--6925},
  year={2024}
}
```

#### 10. Any further questions?

```
Please contact Andrey Ignatov (andrey@vision.ee.ethz.ch) for more information
```

--------------------------------------------------------------------------------
/data/sample_TCGA_images/Adrenocortical_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Adrenocortical_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Bladder_Urothelial_Carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Bladder_Urothelial_Carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Brain_Lower_Grade_Glioma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Brain_Lower_Grade_Glioma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Breast_invasive_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Breast_invasive_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Cervical_squamous_cell_carcinoma_and_endocervical_adenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Cervical_squamous_cell_carcinoma_and_endocervical_adenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Cholangiocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Cholangiocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Colon_adenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Colon_adenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Esophageal_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Esophageal_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Glioblastoma_multiforme/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Glioblastoma_multiforme/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Head_and_Neck_squamous_cell_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Head_and_Neck_squamous_cell_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Kidney_Chromophobe/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Kidney_Chromophobe/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Kidney_renal_clear_cell_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Kidney_renal_clear_cell_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Kidney_renal_papillary_cell_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Kidney_renal_papillary_cell_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Liver_hepatocellular_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Liver_hepatocellular_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Lung_adenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Lung_adenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Lung_squamous_cell_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Lung_squamous_cell_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Lymphoid_Neoplasm_Diffuse_Large_B-cell_Lymphoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Lymphoid_Neoplasm_Diffuse_Large_B-cell_Lymphoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Mesothelioma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Mesothelioma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Ovarian_serous_cystadenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Ovarian_serous_cystadenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Pancreatic_adenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Pancreatic_adenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Pheochromocytoma_and_Paraganglioma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Pheochromocytoma_and_Paraganglioma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Prostate_adenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Prostate_adenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Rectum_adenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Rectum_adenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Sarcoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Sarcoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Skin_Cutaneous_Melanoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Skin_Cutaneous_Melanoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Stomach_adenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Stomach_adenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Testicular_Germ_Cell_Tumors/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Testicular_Germ_Cell_Tumors/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Thymoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Thymoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Thyroid_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Thyroid_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Uterine_Carcinosarcoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Uterine_Carcinosarcoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Uterine_Corpus_Endometrial_Carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Uterine_Corpus_Endometrial_Carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Uveal_Melanoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Uveal_Melanoma/0.jpg
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
# Copyright 2024 by Andrey Ignatov. All Rights Reserved.

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models.feature_extraction import create_feature_extractor


class UpsampleConvLayer(torch.nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride=2, relu=False):

        super(UpsampleConvLayer, self).__init__()
        # Transposed convolution followed by LeakyReLU (the 'relu' argument is currently unused)
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride)
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):

        out = self.upconv(x)
        out = self.relu(out)

        return out


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        # Two (3x3 conv + BatchNorm + LeakyReLU) blocks that keep the spatial resolution
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)


class DeepCMorphSegmentationModule(nn.Module):

    def __init__(self, use_skips=False, num_classes=7):

        super(DeepCMorphSegmentationModule, self).__init__()

        net = models.efficientnet_b7(weights=None)

        # Skip features taken from EfficientNet-B7 at 1/2, 1/4, 1/8 and 1/16 of the
        # input resolution (192, 288, 480 and 1344 channels, respectively)
        self.return_nodes = {
            "features.2.0.block.0": "f1",
            "features.3.0.block.0": "f2",
            "features.4.0.block.0": "f3",
            "features.6.0.block.0": "f4",
        }

        self.encoder = create_feature_extractor(net, return_nodes=self.return_nodes)

        for p in self.encoder.parameters():
            p.requires_grad = True

        self.use_skips = use_skips

        # Decoder: each upsampling step doubles the spatial resolution
        self.upsample_1 = UpsampleConvLayer(1344, 512, 2)
        self.upsample_2 = UpsampleConvLayer(512, 256, 2)
        self.upsample_3 = UpsampleConvLayer(256, 128, 2)
        self.upsample_4 = UpsampleConvLayer(128, 64, 2)

        self.conv_1 = DoubleConv(992, 512)  # 512 upsampled + 480 (f3) channels
        self.conv_2 = DoubleConv(544, 256)  # 256 upsampled + 288 (f2) channels
        self.conv_3 = DoubleConv(320, 128)  # 128 upsampled + 192 (f1) channels
        self.conv_4 = DoubleConv(64, 64)

        # 1 nuclei segmentation map and num_classes nuclei classification maps
        self.conv_segmentation = nn.Conv2d(64, 1, 3, 1, padding="same")
        self.conv_classification = nn.Conv2d(64, num_classes, 3, 1, padding="same")
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):

        features = self.encoder(x)

        net = self.conv_1(torch.cat((self.upsample_1(features["f4"]), features["f3"]), dim=1))
        net = self.conv_2(torch.cat((self.upsample_2(net), features["f2"]), dim=1))
        net = self.conv_3(torch.cat((self.upsample_3(net), features["f1"]), dim=1))
        net = self.conv_4(self.upsample_4(net))

        predictions_segmentation = self.sigmoid(self.conv_segmentation(net))
        predictions_classification = self.sigmoid(self.conv_classification(net))

        return predictions_segmentation, predictions_classification


class DeepCMorph(nn.Module):

    def __init__(self, num_classes=41, dropout_rate=0.0,
                 freeze_classification_module=False, freeze_segmentation_module=True):

        super(DeepCMorph, self).__init__()

        self.num_classes = num_classes
        self.use_dropout = dropout_rate > 0
        self.dropout = nn.Dropout(dropout_rate)

        # Defining nuclei segmentation and classification module
        self.model_preprocessing = DeepCMorphSegmentationModule()

        # Freezing the weights of the segmentation module (unless freeze_segmentation_module=False)
        for p in self.model_preprocessing.parameters():
            p.requires_grad = not freeze_segmentation_module

        # Defining the DeepCMorph classification module

        # Using the standard Torchvision EfficientNetB7 implementation
        EfficientNetB7_backbone = models.efficientnet_b7(weights=None)
        self.return_nodes = {"flatten": "features"}

        self.encoder = create_feature_extractor(EfficientNetB7_backbone, return_nodes=self.return_nodes)

        # Changing the number of EfficientNet's input channels from 3 to 11:
        # 3 RGB + 1 nuclei segmentation + 7 nuclei classification feature maps
        self.encoder.features._modules['0'] = nn.Conv2d(11, 64, 3, stride=2, padding=1, bias=False)

        for p in self.encoder.parameters():
            p.requires_grad = not freeze_classification_module

        # Defining the final fully-connected layers producing the predictions:
        # 'output' is used for a custom number of classes, while output_41, output_32 and
        # output_9 are used when num_classes is 41, 32 or 9 (the released models)
        self.output = nn.Linear(2560, num_classes)

        self.output_41 = nn.Linear(2560, 41)
        self.output_32 = nn.Linear(2560, 32)
        self.output_9 = nn.Linear(2560, 9)

    def forward(self, x, return_features=False, return_segmentation_maps=False):

        nuclei_segmentation_map, nuclei_classification_maps = self.model_preprocessing(x)

        if return_segmentation_maps:
            return nuclei_segmentation_map, nuclei_classification_maps

        # Stacking the 1 + 7 predicted maps with the 3 RGB channels (11 channels in total)
        x = torch.cat((nuclei_segmentation_map, nuclei_classification_maps, x), dim=1)

        features = self.encoder(x)
        extracted_features = features["features"]

        if return_features:
            return extracted_features

        if self.use_dropout:
            extracted_features = self.dropout(extracted_features)

        if self.num_classes == 41:
            return self.output_41(extracted_features)

        if self.num_classes == 32:
            return self.output_32(extracted_features)

        if self.num_classes == 9:
            return self.output_9(extracted_features)

        return self.output(extracted_features)

    def load_weights(self, dataset=None, path_to_checkpoints=None):

        # Wrapping the model in DataParallel makes its state-dict keys carry the "module." prefix,
        # matching the checkpoints; the loaded weights end up in this module's parameters
        wrapped_model = torch.nn.DataParallel(self)

        if dataset is None and path_to_checkpoints is None:
            raise ValueError("Please provide either the dataset name or the path to a checkpoint!")

        if path_to_checkpoints is None:

            if dataset == "COMBINED":
                path_to_checkpoints = "pretrained_models/DeepCMorph_Datasets_Combined_41_classes_acc_8159.pth"

            if dataset == "TCGA":
                path_to_checkpoints = "pretrained_models/DeepCMorph_Pan_Cancer_32_classes_acc_8273.pth"

            if dataset == "TCGA_REGULARIZED":
                path_to_checkpoints = "pretrained_models/DeepCMorph_Pan_Cancer_Regularized_32_classes_acc_8200.pth"

            if dataset == "CRC":
                path_to_checkpoints = "pretrained_models/DeepCMorph_NCT_CRC_HE_Dataset_9_classes_acc_9699.pth"

            if path_to_checkpoints is None:
                raise ValueError("Please provide a valid dataset name = {'COMBINED', 'TCGA', 'TCGA_REGULARIZED', 'CRC'}")

        # map_location="cpu" allows loading the checkpoints on machines without a GPU
        state_dict = torch.load(path_to_checkpoints, map_location="cpu")
        missing_keys, unexpected_keys = wrapped_model.load_state_dict(state_dict, strict=False)

        print("Model loaded, unexpected keys:", unexpected_keys)

--------------------------------------------------------------------------------
/pretrained_models/.gitignore:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/run_inference.py:
--------------------------------------------------------------------------------
# Copyright 2024 by Andrey Ignatov. All Rights Reserved.

import torch
from torchvision import datasets, transforms
from torch.utils import data
import imageio
import numpy as np

from model import DeepCMorph

np.random.seed(42)

# Modify the target number of classes and the path to the dataset
NUM_CLASSES = 32
PATH_TO_SAMPLE_FOLDER = "data/sample_TCGA_images/"


if __name__ == '__main__':

    torch.backends.cudnn.deterministic = True
    # Falling back to the CPU if no CUDA-capable GPU is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Defining the model
    model = DeepCMorph(num_classes=NUM_CLASSES)
    # Loading model weights corresponding to the TCGA Pan Cancer dataset
    # Possible dataset values: TCGA, TCGA_REGULARIZED, CRC, COMBINED
    model.load_weights(dataset="TCGA")

    model.to(device)
    model.eval()

    # Loading test images (class indices follow the sorted sub-folder names)
    test_transforms = transforms.Compose([transforms.ToTensor()])
    test_dataset = datasets.ImageFolder(PATH_TO_SAMPLE_FOLDER, transform=test_transforms)
    test_dataloader = data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=False, drop_last=False)
    TEST_SIZE = len(test_dataloader.dataset)

    print("1. Running sample inference")

    with torch.no_grad():

        for image_id, (image, labels) in enumerate(test_dataloader):

            image = image.to(device, non_blocking=True)

            # Get the predicted class for each input image
            predictions = model(image)
            _, predictions = torch.max(predictions, 1)

            predictions = predictions.cpu().numpy()[0]
            targets = labels.numpy()[0]

            print("Image %d: predicted class: %d, target class: %d" % (image_id, predictions, targets))

    print("2. Generating feature vectors for sample input images")

    with torch.no_grad():

        feature_maps = np.zeros((TEST_SIZE, 2560))

        for image_id, (image, labels) in enumerate(test_dataloader):

            image = image.to(device, non_blocking=True)

            # Get feature vector of size 2560 for each input image
            image_features = model(image, return_features=True)

            image_features = image_features.cpu().numpy()[0]
            feature_maps[image_id] = image_features

            print("Image " + str(image_id) + ", generated features:", image_features)

        print("Features generated, feature array shape:", feature_maps.shape)

    print("3. Generating segmentation and classification maps for sample images")

    with torch.no_grad():

        for image_id, (image, labels) in enumerate(test_dataloader):

            image = image.to(device, non_blocking=True)

            # Get predicted segmentation and classification maps for each input image
            nuclei_segmentation_map, nuclei_classification_maps = model(image, return_segmentation_maps=True)

            # Visualizing the predicted segmentation map as a grayscale RGB image
            nuclei_segmentation_map = nuclei_segmentation_map.cpu().numpy()[0].transpose(1, 2, 0) * 255
            nuclei_segmentation_map = np.dstack((nuclei_segmentation_map, nuclei_segmentation_map, nuclei_segmentation_map))

            # Visualizing the predicted nuclei classification map: the most likely class per pixel
            nuclei_classification_maps = nuclei_classification_maps.cpu().numpy()[0].transpose(1, 2, 0)
            nuclei_classification_maps = np.argmax(nuclei_classification_maps, axis=2)

            # Class 0 stays black, classes 1-6 get distinct colors
            nuclei_classification_maps_visualized = np.zeros((nuclei_classification_maps.shape[0], nuclei_classification_maps.shape[1], 3))
            nuclei_classification_maps_visualized[nuclei_classification_maps == 1] = [255, 0, 0]
            nuclei_classification_maps_visualized[nuclei_classification_maps == 2] = [0, 255, 0]
            nuclei_classification_maps_visualized[nuclei_classification_maps == 3] = [0, 0, 255]
            nuclei_classification_maps_visualized[nuclei_classification_maps == 4] = [255, 255, 0]
            nuclei_classification_maps_visualized[nuclei_classification_maps == 5] = [255, 0, 255]
            nuclei_classification_maps_visualized[nuclei_classification_maps == 6] = [0, 255, 255]

            image = image.cpu().numpy()[0].transpose(1, 2, 0) * 255

            # Saving visual results: input image | segmentation map | classification map
            combined_image = np.hstack((image, nuclei_segmentation_map, nuclei_classification_maps_visualized))
            imageio.imsave("sample_visual_results/" + str(image_id) + ".jpg", combined_image.astype(np.uint8))

    print("All visual results saved")

--------------------------------------------------------------------------------
/sample_visual_results/.gitignore:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/validate_model.py:
--------------------------------------------------------------------------------
# Copyright 2024 by Andrey Ignatov. All Rights Reserved.

import torch
from torchvision import datasets, transforms
from torch.utils import data
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score

from model import DeepCMorph

np.random.seed(42)

# Modify the target number of classes, the batch size and the path to the test dataset
NUM_CLASSES = 9
BATCH_SIZE = 64

PATH_TO_TEST_DATASET = "data/CRC-VAL-HE-7K/"


if __name__ == '__main__':

    torch.backends.cudnn.deterministic = True
    # Falling back to the CPU if no CUDA-capable GPU is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = DeepCMorph(num_classes=NUM_CLASSES)
    model.load_weights(dataset="CRC")

    model.to(device)
    model.eval()

    test_transforms = transforms.Compose([transforms.ToTensor()])
    test_dataset = datasets.ImageFolder(PATH_TO_TEST_DATASET, transform=test_transforms)
    test_dataloader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=False, drop_last=False)

    TEST_SIZE = len(test_dataloader.dataset)
    print("Test size:", TEST_SIZE)

    print("Running Evaluation...")

    targets_array = []
    predictions_array = []

    with torch.no_grad():

        for image, labels in test_dataloader:

            image = image.to(device, non_blocking=True)

            # Predicted class = index of the largest logit
            predictions = model(image)
            _, predictions = torch.max(predictions, 1)

            targets_array.extend(labels.tolist())
            predictions_array.extend(predictions.cpu().tolist())

    print("Accuracy: " + str(accuracy_score(targets_array, predictions_array)))
    print("Balanced Accuracy: " + str(balanced_accuracy_score(targets_array, predictions_array)))

--------------------------------------------------------------------------------