├── README.md
├── data
│   └── sample_TCGA_images
│       ├── Adrenocortical_carcinoma
│       │   └── 0.jpg
│       ├── Bladder_Urothelial_Carcinoma
│       │   └── 0.jpg
│       ├── Brain_Lower_Grade_Glioma
│       │   └── 0.jpg
│       ├── Breast_invasive_carcinoma
│       │   └── 0.jpg
│       ├── Cervical_squamous_cell_carcinoma_and_endocervical_adenocarcinoma
│       │   └── 0.jpg
│       ├── Cholangiocarcinoma
│       │   └── 0.jpg
│       ├── Colon_adenocarcinoma
│       │   └── 0.jpg
│       ├── Esophageal_carcinoma
│       │   └── 0.jpg
│       ├── Glioblastoma_multiforme
│       │   └── 0.jpg
│       ├── Head_and_Neck_squamous_cell_carcinoma
│       │   └── 0.jpg
│       ├── Kidney_Chromophobe
│       │   └── 0.jpg
│       ├── Kidney_renal_clear_cell_carcinoma
│       │   └── 0.jpg
│       ├── Kidney_renal_papillary_cell_carcinoma
│       │   └── 0.jpg
│       ├── Liver_hepatocellular_carcinoma
│       │   └── 0.jpg
│       ├── Lung_adenocarcinoma
│       │   └── 0.jpg
│       ├── Lung_squamous_cell_carcinoma
│       │   └── 0.jpg
│       ├── Lymphoid_Neoplasm_Diffuse_Large_B-cell_Lymphoma
│       │   └── 0.jpg
│       ├── Mesothelioma
│       │   └── 0.jpg
│       ├── Ovarian_serous_cystadenocarcinoma
│       │   └── 0.jpg
│       ├── Pancreatic_adenocarcinoma
│       │   └── 0.jpg
│       ├── Pheochromocytoma_and_Paraganglioma
│       │   └── 0.jpg
│       ├── Prostate_adenocarcinoma
│       │   └── 0.jpg
│       ├── Rectum_adenocarcinoma
│       │   └── 0.jpg
│       ├── Sarcoma
│       │   └── 0.jpg
│       ├── Skin_Cutaneous_Melanoma
│       │   └── 0.jpg
│       ├── Stomach_adenocarcinoma
│       │   └── 0.jpg
│       ├── Testicular_Germ_Cell_Tumors
│       │   └── 0.jpg
│       ├── Thymoma
│       │   └── 0.jpg
│       ├── Thyroid_carcinoma
│       │   └── 0.jpg
│       ├── Uterine_Carcinosarcoma
│       │   └── 0.jpg
│       ├── Uterine_Corpus_Endometrial_Carcinoma
│       │   └── 0.jpg
│       └── Uveal_Melanoma
│           └── 0.jpg
├── model.py
├── pretrained_models
│   └── .gitignore
├── run_inference.py
├── sample_visual_results
│   └── .gitignore
└── validate_model.py
/README.md:
--------------------------------------------------------------------------------
1 | ## Histopathological Image Classification with Cell Morphology Aware Deep Neural Networks
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | #### 1. Overview [[Paper]](https://openaccess.thecvf.com/content/CVPR2024W/CVMI/papers/Ignatov_Histopathological_Image_Classification_with_Cell_Morphology_Aware_Deep_Neural_Networks_CVPRW_2024_paper.pdf)
10 |
11 | This repository provides the implementation of the **DeepCMorph CNN model**, a foundation model designed for histopathological image classification and analysis. Unlike existing models, DeepCMorph explicitly **learns cell morphology**: its segmentation module is trained to identify different cell types and nuclei morphological features.
12 |
13 | Key DeepCMorph features:
14 |
15 | 1. Achieves state-of-the-art results on the **TCGA**, **NCT-CRC-HE** and **Colorectal cancer (CRC)** datasets
16 | 2. Consists of two independent **nuclei segmentation / classification** and **tissue classification** modules
17 | 3. The segmentation module is pre-trained on a combination of **8 segmentation datasets**
18 | 4. The classification module is pre-trained on the **Pan-Cancer TCGA dataset** (8736 diagnostic slides / 7175 patients)
19 | 5. Can be applied to images of **arbitrary resolutions**
20 | 6. Can be trained or fine-tuned on **one GPU**
21 |
22 |
23 |
24 | #### 2. Prerequisites
25 |
26 | - Python with the numpy and imageio packages (plus scikit-learn, which ``validate_model.py`` imports)
27 | - [PyTorch + TorchVision](https://pytorch.org/) libraries
28 | - [Optional] Nvidia GPU
29 |
30 |
31 |
32 | #### 3. Download Pre-Trained Models
33 |
34 | The segmentation module of all pre-trained models is trained on a combination of 8 publicly available nuclei segmentation / classification datasets: **Lizard, CryoNuSeg, MoNuSAC, BNS, TNBC, KUMAR, MICCAI** and **PanNuke** datasets.
35 |
36 | | Dataset | #Classes | Accuracy | Download Link |
37 | |-----------------------------|----------|----------|---------------|
38 | | Combined [[TCGA](https://zenodo.org/records/5889558) + [NCT_CRC_HE](https://zenodo.org/records/1214456)] | 41 | 81.59% | [Link](https://data.vision.ee.ethz.ch/ihnatova/public/DeepCMorph/DeepCMorph_Datasets_Combined_41_classes_acc_8159.pth) |
39 | | [TCGA](https://zenodo.org/records/5889558) [Extreme Augmentations] | 32 | 82.00% | [Link](https://data.vision.ee.ethz.ch/ihnatova/public/DeepCMorph/DeepCMorph_Pan_Cancer_Regularized_32_classes_acc_8200.pth) |
40 | | [TCGA](https://zenodo.org/records/5889558) [Moderate Augmentations] | 32 | 82.73% | [Link](https://data.vision.ee.ethz.ch/ihnatova/public/DeepCMorph/DeepCMorph_Pan_Cancer_32_classes_acc_8273.pth) |
41 | | [NCT_CRC_HE](https://zenodo.org/records/1214456) | 9 | 96.99% | [Link](https://data.vision.ee.ethz.ch/ihnatova/public/DeepCMorph/DeepCMorph_NCT_CRC_HE_Dataset_9_classes_acc_9699.pth) |
42 |
43 | Download the required models and copy them to the ``pretrained_models/`` directory.
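
Before loading a model, it can help to confirm that the downloaded files are in place. The sketch below is not part of the repository: `missing_checkpoints` is a hypothetical helper, and the file names are the default paths that `model.py` looks up for each `dataset` value.

```python
from pathlib import Path

# Default checkpoint file names, as referenced in model.py (load_weights)
CHECKPOINTS = {
    "COMBINED": "DeepCMorph_Datasets_Combined_41_classes_acc_8159.pth",
    "TCGA": "DeepCMorph_Pan_Cancer_32_classes_acc_8273.pth",
    "TCGA_REGULARIZED": "DeepCMorph_Pan_Cancer_Regularized_32_classes_acc_8200.pth",
    "CRC": "DeepCMorph_NCT_CRC_HE_Dataset_9_classes_acc_9699.pth",
}


def missing_checkpoints(directory="pretrained_models"):
    """Return the dataset names whose checkpoint file is not in `directory`."""
    root = Path(directory)
    return [name for name, filename in CHECKPOINTS.items()
            if not (root / filename).is_file()]


if __name__ == "__main__":
    print("Missing checkpoints:", missing_checkpoints())
```

An empty list means every default checkpoint is available; otherwise only the listed `dataset` values will fail to load.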
44 |
45 |
46 |
47 | #### 4. Pre-Trained Model Usage
48 |
49 | Integrating the DeepCMorph model into your project takes only a few lines of code. The snippet below shows how to define, initialize and run the model on a sample histopathological image:
50 |
51 |
52 | ```python
53 | import torch
54 | from model import DeepCMorph
55 | # Defining the model and specifying the number of target classes:
56 | # 41 for combined datasets, 32 for TCGA, 9 for CRC
57 | model = DeepCMorph(num_classes=41)
58 |
59 | # Loading model weights corresponding to the network trained on combined datasets
60 | # Possible 'dataset' values: TCGA, TCGA_REGULARIZED, CRC, COMBINED
61 | model.load_weights(dataset="COMBINED")
62 | model.eval()  # switch to inference mode, as done in run_inference.py
63 | # Get the predicted class for a sample input image
64 | predictions = model(sample_image)
65 | _, predicted_class = torch.max(predictions.data, 1)
66 |
67 | # Get feature vector of size 2560 for a sample input image
68 | features = model(sample_image, return_features=True)
69 |
70 | # Get predicted segmentation and classification maps for a sample input image
71 | nuclei_segmentation_map, nuclei_classification_maps = model(sample_image, return_segmentation_maps=True)
72 | ```
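
The snippet above assumes that `sample_image` already exists. Judging by `run_inference.py`, which passes images through `transforms.ToTensor()` with a batch size of 1, the model expects a float tensor of shape `(1, 3, H, W)` scaled to `[0, 1]`. A minimal NumPy sketch of that conversion follows; `to_model_input` is a hypothetical helper name, not part of the repository.

```python
import numpy as np


def to_model_input(image_hwc_uint8):
    """Convert an (H, W, 3) uint8 RGB image to a (1, 3, H, W) float32 array in [0, 1]."""
    image = np.asarray(image_hwc_uint8)
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError("expected an RGB image of shape (H, W, 3)")
    scaled = image.astype(np.float32) / 255.0      # same scaling ToTensor() applies to uint8 input
    chw = scaled.transpose(2, 0, 1)                # HWC -> CHW
    return np.ascontiguousarray(chw[np.newaxis])   # add the batch dimension


if __name__ == "__main__":
    dummy = np.zeros((224, 224, 3), dtype=np.uint8)  # stand-in for an image read with imageio
    print(to_model_input(dummy).shape)
```

Wrapping the result with `torch.from_numpy(...)` then gives a tensor that can be passed to the model.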
73 |
74 | A more detailed usage example is provided in the script ``run_inference.py``. It applies the pre-trained DeepCMorph model to 32 sample images from the TCGA dataset and generates 1) **classification predictions**, 2) **feature vectors of dimension 2560** that can be used for classification with an SVM or another stand-alone model, and 3) **nuclei segmentation / classification maps**, which it also visualizes.
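
To illustrate how the 2560-dimensional features could feed a stand-alone classifier, here is a small NumPy sketch. It uses a nearest-centroid rule as a simple stand-in for the SVM mentioned above, and the feature arrays are synthetic placeholders rather than real model outputs; `fit_centroids` and `predict` are illustrative names.

```python
import numpy as np


def fit_centroids(features, labels):
    """Compute one mean feature vector per class."""
    classes = np.unique(labels)
    centroids = np.stack([features[labels == c].mean(axis=0) for c in classes])
    return classes, centroids


def predict(features, classes, centroids):
    """Assign each feature vector to the class with the nearest centroid."""
    # Pairwise squared Euclidean distances, shape (num_samples, num_classes)
    distances = ((features[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return classes[distances.argmin(axis=1)]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Synthetic stand-in for features returned by model(image, return_features=True)
    train_x = np.concatenate([rng.normal(0.0, 1.0, (20, 2560)),
                              rng.normal(3.0, 1.0, (20, 2560))])
    train_y = np.array([0] * 20 + [1] * 20)
    classes, centroids = fit_centroids(train_x, train_y)
    print(predict(train_x[:3], classes, centroids))
```

In practice the rows of `train_x` would be the feature vectors collected by `run_inference.py`, and any classifier that accepts a fixed-length vector could replace the centroid rule.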
75 |
76 |
77 |
78 | #### 5. Fine-Tuning the Model
79 |
80 | The following code initializes the model for fine-tuning:
81 |
82 | ```python
83 | from model import DeepCMorph
84 |
85 | # Defining the model with frozen segmentation module (typical usage)
86 | # All weights of the classification module are trainable
87 | model = DeepCMorph(num_classes=...)
88 |
89 | # Defining the model with frozen segmentation and classification modules
90 | # Only the last fully-connected layer is trainable
91 | model = DeepCMorph(num_classes=..., freeze_classification_module=True)
92 |
93 | # Defining the model with all layers being trainable
94 | model = DeepCMorph(num_classes=..., freeze_segmentation_module=False)
95 | ```
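
Once the model is defined, fine-tuning can follow the usual PyTorch pattern of passing only the parameters that still have `requires_grad=True` to the optimizer. The sketch below is an assumption about how one might set this up, not the authors' training code; `TinyStandIn`, the learning rate and the random batch are placeholders so that the example runs without the pre-trained weights.

```python
import torch
import torch.nn as nn


class TinyStandIn(nn.Module):
    """Placeholder for DeepCMorph: a frozen backbone plus a trainable head."""

    def __init__(self, num_classes=9):
        super().__init__()
        self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16))
        self.output = nn.Linear(16, num_classes)
        for p in self.backbone.parameters():
            p.requires_grad = False  # mimics a frozen module

    def forward(self, x):
        return self.output(self.backbone(x))


def train_step(model, optimizer, images, labels):
    """Run one optimization step and return the loss value."""
    model.train()
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(images), labels)
    loss.backward()
    optimizer.step()
    return loss.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    model = TinyStandIn()  # in practice: DeepCMorph(num_classes=...)
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=1e-4)
    images = torch.rand(4, 3, 8, 8)       # random placeholder batch
    labels = torch.randint(0, 9, (4,))    # random placeholder labels
    print("loss:", train_step(model, optimizer, images, labels))
```

Filtering on `requires_grad` keeps frozen weights out of the optimizer state, so the same loop works for all three configurations shown above.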
96 |
97 |
98 |
99 | #### 6. Pre-Trained Model Evaluation
100 |
101 | The file ``validate_model.py`` contains sample code for evaluating the model on the **NCT-CRC-HE-7K** dataset, which the script reads from ``data/CRC-VAL-HE-7K/``. To check the model accuracy:
102 |
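103 | 1. Download the corresponding model weights (the NCT_CRC_HE model from Section 3) and copy them to the ``pretrained_models/`` directory.
104 | 2. Download the [NCT-CRC-HE-7K](https://zenodo.org/records/1214456) dataset and extract it to the ``data`` directory, so that the images are located under ``data/CRC-VAL-HE-7K/``.
105 | 3. Run the test script: ``python validate_model.py``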
106 |
107 | The provided script can also be easily adapted to other datasets.
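
`validate_model.py` reports two metrics via scikit-learn: plain accuracy and balanced accuracy. The latter is the mean of the per-class recalls, which matters when classes are unevenly represented. The dependency-free sketch below illustrates the difference; the function names are illustrative and the labels are made up.

```python
from collections import defaultdict


def accuracy(targets, predictions):
    """Fraction of samples whose prediction matches the target."""
    correct = sum(t == p for t, p in zip(targets, predictions))
    return correct / len(targets)


def balanced_accuracy(targets, predictions):
    """Mean of per-class recalls over the classes present in `targets`."""
    hits, totals = defaultdict(int), defaultdict(int)
    for t, p in zip(targets, predictions):
        totals[t] += 1
        hits[t] += int(t == p)
    return sum(hits[c] / totals[c] for c in totals) / len(totals)


if __name__ == "__main__":
    targets = [0, 0, 0, 1]       # made-up labels: class 1 is rare
    predictions = [0, 0, 0, 0]   # a model that always predicts class 0
    print(accuracy(targets, predictions))           # 0.75
    print(balanced_accuracy(targets, predictions))  # 0.5
```

A model that ignores the rare class still scores 75% accuracy here, while its balanced accuracy drops to 50%.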
108 |
109 |
110 |
111 |
112 | #### 7. Folder structure
113 |
114 | >```data/sample_TCGA_images/``` - the folder with sample TCGA images
115 | >```pretrained_models/``` - the folder with the provided pre-trained DeepCMorph models
116 | >```sample_visual_results/``` - visualization of the nuclei segmentation and classification maps
117 |
118 | >```model.py``` - DeepCMorph implementation [PyTorch]
119 | >```run_inference.py``` - the script showing model usage on sample histopathological images
120 | >```validate_model.py``` - the script for model validation on the NCT-CRC-HE-7K dataset
121 |
122 |
123 |
124 | #### 8. License
125 |
126 | Copyright (C) 2024 Andrey Ignatov. All rights reserved.
127 |
128 | Licensed under the [CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International)](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
129 |
130 | The code is released for academic research use only.
131 |
132 |
133 |
134 | #### 9. Citation
135 |
136 | ```
137 | @inproceedings{ignatov2024histopathological,
138 | title={Histopathological Image Classification with Cell Morphology Aware Deep Neural Networks},
139 | author={Ignatov, Andrey and Yates, Josephine and Boeva, Valentina},
140 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
141 | pages={6913--6925},
142 | year={2024}
143 | }
144 | ```
145 |
146 |
147 | #### 10. Any further questions?
148 |
149 | ```
150 | Please contact Andrey Ignatov (andrey@vision.ee.ethz.ch) for more information
151 | ```
152 |
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Adrenocortical_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Adrenocortical_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Bladder_Urothelial_Carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Bladder_Urothelial_Carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Brain_Lower_Grade_Glioma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Brain_Lower_Grade_Glioma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Breast_invasive_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Breast_invasive_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Cervical_squamous_cell_carcinoma_and_endocervical_adenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Cervical_squamous_cell_carcinoma_and_endocervical_adenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Cholangiocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Cholangiocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Colon_adenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Colon_adenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Esophageal_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Esophageal_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Glioblastoma_multiforme/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Glioblastoma_multiforme/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Head_and_Neck_squamous_cell_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Head_and_Neck_squamous_cell_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Kidney_Chromophobe/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Kidney_Chromophobe/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Kidney_renal_clear_cell_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Kidney_renal_clear_cell_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Kidney_renal_papillary_cell_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Kidney_renal_papillary_cell_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Liver_hepatocellular_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Liver_hepatocellular_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Lung_adenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Lung_adenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Lung_squamous_cell_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Lung_squamous_cell_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Lymphoid_Neoplasm_Diffuse_Large_B-cell_Lymphoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Lymphoid_Neoplasm_Diffuse_Large_B-cell_Lymphoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Mesothelioma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Mesothelioma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Ovarian_serous_cystadenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Ovarian_serous_cystadenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Pancreatic_adenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Pancreatic_adenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Pheochromocytoma_and_Paraganglioma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Pheochromocytoma_and_Paraganglioma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Prostate_adenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Prostate_adenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Rectum_adenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Rectum_adenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Sarcoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Sarcoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Skin_Cutaneous_Melanoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Skin_Cutaneous_Melanoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Stomach_adenocarcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Stomach_adenocarcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Testicular_Germ_Cell_Tumors/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Testicular_Germ_Cell_Tumors/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Thymoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Thymoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Thyroid_carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Thyroid_carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Uterine_Carcinosarcoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Uterine_Carcinosarcoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Uterine_Corpus_Endometrial_Carcinoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Uterine_Corpus_Endometrial_Carcinoma/0.jpg
--------------------------------------------------------------------------------
/data/sample_TCGA_images/Uveal_Melanoma/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DeepCMorph/95fb09eeac8f35d6dca1188bfa5e80a8d2331e05/data/sample_TCGA_images/Uveal_Melanoma/0.jpg
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 by Andrey Ignatov. All Rights Reserved.
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.models as models
6 | from torchvision.models.feature_extraction import create_feature_extractor
7 |
8 |
9 | class UpsampleConvLayer(torch.nn.Module):
10 |
11 |     def __init__(self, in_channels, out_channels, kernel_size, stride=2, relu=False):
12 |         # Note: the `relu` argument is unused; LeakyReLU(0.2) is always applied
13 |         super(UpsampleConvLayer, self).__init__()
14 |         self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride)
15 |         self.relu = nn.LeakyReLU(0.2)
16 |
17 |     def forward(self, x):
18 |
19 |         out = self.upconv(x)
20 |         out = self.relu(out)
21 |
22 |         return out
23 |
24 |
25 | class DoubleConv(nn.Module):
26 |     def __init__(self, in_channels, out_channels):
27 |         super(DoubleConv, self).__init__()
28 |         self.conv = nn.Sequential(
29 |             nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
30 |             nn.BatchNorm2d(out_channels),
31 |             nn.LeakyReLU(0.2),
32 |             nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
33 |             nn.BatchNorm2d(out_channels),
34 |             nn.LeakyReLU(0.2),
35 |         )
36 |
37 |     def forward(self, x):
38 |         return self.conv(x)
39 |
40 |
41 | class DeepCMorphSegmentationModule(nn.Module):
42 |
43 |     def __init__(self, use_skips=False, num_classes=7):
44 |
45 |         super(DeepCMorphSegmentationModule, self).__init__()
46 |
47 |         net = models.efficientnet_b7(weights=None)
48 |
49 |         self.return_nodes = {
50 |             "features.2.0.block.0": "f1",
51 |             "features.3.0.block.0": "f2",
52 |             "features.4.0.block.0": "f3",
53 |             "features.6.0.block.0": "f4",
54 |         }
55 |
56 |         self.encoder = create_feature_extractor(net, return_nodes=self.return_nodes)
57 |
58 |         for p in self.encoder.parameters():
59 |             p.requires_grad = True
60 |
61 |         self.use_skips = use_skips
62 |
63 |         self.upsample_1 = UpsampleConvLayer(1344, 512, 2)
64 |         self.upsample_2 = UpsampleConvLayer(512, 256, 2)
65 |         self.upsample_3 = UpsampleConvLayer(256, 128, 2)
66 |         self.upsample_4 = UpsampleConvLayer(128, 64, 2)
67 |
68 |         self.conv_1 = DoubleConv(992, 512)
69 |         self.conv_2 = DoubleConv(544, 256)
70 |         self.conv_3 = DoubleConv(320, 128)
71 |         self.conv_4 = DoubleConv(64, 64)
72 |
73 |         self.conv_segmentation = nn.Conv2d(64, 1, 3, 1, padding="same")
74 |         self.conv_classification = nn.Conv2d(64, num_classes, 3, 1, padding="same")
75 |         self.sigmoid = nn.Sigmoid()
76 |
77 |     def forward(self, x):
78 |
79 |         features = self.encoder(x)
80 |
81 |         net = self.conv_1(torch.cat((self.upsample_1(features["f4"]), features["f3"]), dim=1))
82 |         net = self.conv_2(torch.cat((self.upsample_2(net), features["f2"]), dim=1))
83 |         net = self.conv_3(torch.cat((self.upsample_3(net), features["f1"]), dim=1))
84 |         net = self.conv_4(self.upsample_4(net))
85 |
86 |         predictions_segmentation = self.sigmoid(self.conv_segmentation(net))
87 |         predictions_classification = self.sigmoid(self.conv_classification(net))
88 |
89 |         return predictions_segmentation, predictions_classification
90 |
91 |
92 | class DeepCMorph(nn.Module):
93 |
94 |     def __init__(self, num_classes=41, dropout_rate=0.0,
95 |                  freeze_classification_module=False, freeze_segmentation_module=True):
96 |
97 |         super(DeepCMorph, self).__init__()
98 |
99 |         self.num_classes = num_classes
100 |         self.use_dropout = dropout_rate > 0
101 |         self.dropout = nn.Dropout(dropout_rate)
102 |
103 |         # Defining nuclei segmentation and classification module
104 |         self.model_preprocessing = DeepCMorphSegmentationModule()
105 |
106 |         # Freezing the weights of the segmentation module (if requested)
107 |         for p in self.model_preprocessing.parameters():
108 |             p.requires_grad = not freeze_segmentation_module
109 |
110 |         # Defining the DeepCMorph classification module
111 |
112 |         # Using the standard Torchvision EfficientNetB7 implementation
113 |         EfficientNetB7_backbone = models.efficientnet_b7(weights=None)
114 |         self.return_nodes = {"flatten": "features"}
115 |
116 |         self.encoder = create_feature_extractor(EfficientNetB7_backbone, return_nodes=self.return_nodes)
117 |
118 |         # Changing the number of EfficientNet's input channels from 3 to 11:
119 |         # 3 RGB + 1 nuclei segmentation + 7 nuclei classification feature maps
120 |         self.encoder.features._modules['0'] = nn.Conv2d(11, 64, 3, stride=2, padding=1, bias=False)
121 |
122 |         for p in self.encoder.parameters():
123 |             p.requires_grad = not freeze_classification_module
124 |
125 |         # Defining the final fully-connected layer producing the predictions
126 |         self.output = nn.Linear(2560, num_classes)
127 |
128 |         self.output_41 = nn.Linear(2560, 41)
129 |         self.output_32 = nn.Linear(2560, 32)
130 |         self.output_9 = nn.Linear(2560, 9)
131 |
132 |     def forward(self, x, return_features=False, return_segmentation_maps=False):
133 |
134 |         nuclei_segmentation_map, nuclei_classification_maps = self.model_preprocessing(x)
135 |
136 |         if return_segmentation_maps:
137 |             return nuclei_segmentation_map, nuclei_classification_maps
138 |
139 |         x = torch.cat((nuclei_segmentation_map, nuclei_classification_maps, x), dim=1)
140 |
141 |         features = self.encoder(x)
142 |         extracted_features = features["features"]
143 |
144 |         if return_features:
145 |             return extracted_features
146 |
147 |         if self.use_dropout:
148 |             extracted_features = self.dropout(extracted_features)
149 |
150 |         if self.num_classes == 41:
151 |             return self.output_41(extracted_features)
152 |
153 |         if self.num_classes == 32:
154 |             return self.output_32(extracted_features)
155 |
156 |         if self.num_classes == 9:
157 |             return self.output_9(extracted_features)
158 |
159 |         return self.output(extracted_features)
160 |
161 |     def load_weights(self, dataset=None, path_to_checkpoints=None):
162 |         # Wrapping in DataParallel adds the "module." prefix to parameter names (presumably to match the checkpoints)
163 |         self = torch.nn.DataParallel(self)
164 |
165 |         if dataset is None and path_to_checkpoints is None:
166 |             raise ValueError("Please provide either the dataset name or the path to a checkpoint!")
167 |
168 |         if path_to_checkpoints is None:
169 |
170 |             if dataset == "COMBINED":
171 |                 path_to_checkpoints = "pretrained_models/DeepCMorph_Datasets_Combined_41_classes_acc_8159.pth"
172 |
173 |             if dataset == "TCGA":
174 |                 path_to_checkpoints = "pretrained_models/DeepCMorph_Pan_Cancer_32_classes_acc_8273.pth"
175 |
176 |             if dataset == "TCGA_REGULARIZED":
177 |                 path_to_checkpoints = "pretrained_models/DeepCMorph_Pan_Cancer_Regularized_32_classes_acc_8200.pth"
178 |
179 |             if dataset == "CRC":
180 |                 path_to_checkpoints = "pretrained_models/DeepCMorph_NCT_CRC_HE_Dataset_9_classes_acc_9699.pth"
181 |
182 |             if path_to_checkpoints is None:
183 |                 raise ValueError("Please provide a valid dataset name = {'COMBINED', 'TCGA', 'TCGA_REGULARIZED', 'CRC'}")
184 |
185 |         missing_keys, unexpected_keys = self.load_state_dict(torch.load(path_to_checkpoints, map_location="cpu"), strict=False)
186 |
187 |         print("Model loaded, unexpected keys:", unexpected_keys)
188 |
189 |
--------------------------------------------------------------------------------
/pretrained_models/.gitignore:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/run_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 by Andrey Ignatov. All Rights Reserved.
2 |
3 | import torch
4 | from torchvision import datasets, transforms
5 | from torch.utils import data
6 | import imageio
7 | import numpy as np
8 |
9 | from model import DeepCMorph
10 |
11 | np.random.seed(42)
12 |
13 | # Modify the target number of classes and the path to the dataset
14 | NUM_CLASSES = 32
15 | PATH_TO_SAMPLE_FOLDER = "data/sample_TCGA_images/"
16 |
17 |
18 | if __name__ == '__main__':
19 |
20 |     torch.backends.cudnn.deterministic = True
21 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22 |
23 |     # Defining the model
24 |     model = DeepCMorph(num_classes=NUM_CLASSES)
25 |     # Loading model weights corresponding to the TCGA Pan Cancer dataset
26 |     # Possible dataset values: TCGA, TCGA_REGULARIZED, CRC, COMBINED
27 |     model.load_weights(dataset="TCGA")
28 |
29 |     model.to(device)
30 |     model.eval()
31 |
32 |     # Loading test images
33 |     test_transforms = transforms.Compose([transforms.ToTensor()])
34 |     test_dataset = datasets.ImageFolder(PATH_TO_SAMPLE_FOLDER, transform=test_transforms)
35 |     test_dataloader = data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=False, drop_last=False)
36 |     TEST_SIZE = len(test_dataloader.dataset)
37 |
38 |     print("1. Running sample inference")
39 |
40 |     with torch.no_grad():
41 |
42 |         image_id = 0
43 |
44 |         test_iter = iter(test_dataloader)
45 |         for j in range(len(test_dataloader)):
46 |
47 |             image, labels = next(test_iter)
48 |             image = image.to(device, non_blocking=True)
49 |             labels = labels.to(device, non_blocking=True)
50 |
51 |             # Get the predicted class for each input image
52 |             predictions = model(image)
53 |             _, predictions = torch.max(predictions.data, 1)
54 |
55 |             predictions = predictions.detach().cpu().numpy()[0]
56 |             targets = labels.detach().cpu().numpy()[0]
57 |
58 |             print("Image %d: predicted class: %d, target class: %d" % (image_id, predictions, targets))
59 |             image_id += 1
60 |
61 |     print("2. Generating feature maps for sample input images")
62 |
63 |     with torch.no_grad():
64 |
65 |         feature_maps = np.zeros((TEST_SIZE, 2560))
66 |
67 |         image_id = 0
68 |
69 |         test_iter = iter(test_dataloader)
70 |         for j in range(len(test_dataloader)):
71 |
72 |             image, labels = next(test_iter)
73 |             image = image.to(device, non_blocking=True)
74 |             labels = labels.to(device, non_blocking=True)
75 |
76 |             # Get a feature vector of size 2560 for each input image
77 |             image_features = model(image, return_features=True)
78 |
79 |             image_features = image_features.detach().cpu().numpy()[0]
80 |             feature_maps[image_id] = image_features
81 |
82 |             print("Image " + str(image_id) + ", generated features:", image_features)
83 |             image_id += 1
84 |
85 |         print("Features generated, feature array shape:", feature_maps.shape)
86 |
87 |     print("3. Generating segmentation and classification maps for sample images")
88 |
89 |     with torch.no_grad():
90 |
91 |         # Each saved image shows, side by side: input | segmentation map | classification map
92 |
93 |         image_id = 0
94 |
95 |         test_iter = iter(test_dataloader)
96 |         for j in range(len(test_dataloader)):
97 |
98 |             image, labels = next(test_iter)
99 |             image = image.to(device, non_blocking=True)
100 |             labels = labels.to(device, non_blocking=True)
101 |
102 |             # Get predicted segmentation and classification maps for each input image
103 |             nuclei_segmentation_map, nuclei_classification_maps = model(image, return_segmentation_maps=True)
104 |
105 |             # Visualizing the predicted segmentation map
106 |             nuclei_segmentation_map = nuclei_segmentation_map.detach().cpu().numpy()[0].transpose(1,2,0) * 255
107 |             nuclei_segmentation_map = np.dstack((nuclei_segmentation_map, nuclei_segmentation_map, nuclei_segmentation_map))
108 |
109 |             # Visualizing the predicted nuclei classification map
110 |             nuclei_classification_maps = nuclei_classification_maps.detach().cpu().numpy()[0].transpose(1, 2, 0)
111 |             nuclei_classification_maps = np.argmax(nuclei_classification_maps, axis=2)
112 |
113 |             nuclei_classification_maps_visualized = np.zeros((nuclei_classification_maps.shape[0], nuclei_classification_maps.shape[1], 3))
114 |             nuclei_classification_maps_visualized[nuclei_classification_maps == 1] = [255, 0, 0]
115 |             nuclei_classification_maps_visualized[nuclei_classification_maps == 2] = [0, 255, 0]
116 |             nuclei_classification_maps_visualized[nuclei_classification_maps == 3] = [0, 0, 255]
117 |             nuclei_classification_maps_visualized[nuclei_classification_maps == 4] = [255, 255, 0]
118 |             nuclei_classification_maps_visualized[nuclei_classification_maps == 5] = [255, 0, 255]
119 |             nuclei_classification_maps_visualized[nuclei_classification_maps == 6] = [0, 255, 255]
120 |
121 |             image = image.detach().cpu().numpy()[0].transpose(1,2,0) * 255
122 |
123 |             # Saving visual results
124 |             combined_image = np.hstack((image, nuclei_segmentation_map, nuclei_classification_maps_visualized))
125 |             imageio.imsave("sample_visual_results/" + str(image_id) + ".jpg", combined_image.astype(np.uint8))
126 |             image_id += 1
127 |
128 |         print("All visual results saved")
129 |
--------------------------------------------------------------------------------
/sample_visual_results/.gitignore:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/validate_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 by Andrey Ignatov. All Rights Reserved.
2 |
3 | import numpy as np
4 | import torch
5 | from sklearn.metrics import accuracy_score, balanced_accuracy_score
6 | from torch.utils import data
7 | from torchvision import datasets, transforms
8 |
9 | from model import DeepCMorph
10 |
11 | np.random.seed(42)
12 |
13 | # Modify the evaluation parameters below:
14 | NUM_CLASSES = 9
15 | BATCH_SIZE = 64
16 |
17 | # The dataset folder must contain one sub-folder per class,
18 | # as expected by torchvision.datasets.ImageFolder
19 | PATH_TO_TEST_DATASET = "data/CRC-VAL-HE-7K/"
20 |
21 |
22 | if __name__ == '__main__':
23 |
24 |     torch.backends.cudnn.deterministic = True
25 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26 |
27 |     model = DeepCMorph(num_classes=NUM_CLASSES)
28 |     model.load_weights(dataset="CRC")
29 |
30 |     model.to(device)
31 |     model.eval()
32 |
33 |     test_transforms = transforms.Compose([transforms.ToTensor()])
34 |     test_dataset = datasets.ImageFolder(PATH_TO_TEST_DATASET, transform=test_transforms)
35 |     test_dataloader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=False, drop_last=False)
36 |
37 |     TEST_SIZE = len(test_dataloader.dataset)
38 |     print("Test size:", TEST_SIZE)
39 |
40 |     print("Running Evaluation...")
41 |
42 |     # Ground-truth labels and predicted labels,
43 |     # collected over the whole test set
44 |     targets_array = []
45 |     predictions_array = []
46 |
47 |     with torch.no_grad():
48 |
49 |         test_iter = iter(test_dataloader)
50 |         for j in range(len(test_dataloader)):
51 |
52 |             image, labels = next(test_iter)
53 |             image = image.to(device, non_blocking=True)
54 |             labels = labels.to(device, non_blocking=True)
55 |
56 |             predictions = model(image)
57 |             _, predictions = torch.max(predictions.data, 1)
58 |
59 |             predictions = predictions.detach().cpu().numpy()
60 |             targets = labels.detach().cpu().numpy()
61 |
62 |             for k in range(targets.shape[0]):
63 |
64 |                 target = targets[k]
65 |                 predicted = predictions[k]
66 |
67 |                 targets_array.append(target)
68 |                 predictions_array.append(predicted)
69 |
70 |     print("Accuracy: " + str(accuracy_score(targets_array, predictions_array)))
71 |     print("Balanced Accuracy: " + str(balanced_accuracy_score(targets_array, predictions_array)))
72 |
73 |
--------------------------------------------------------------------------------