├── LICENSE ├── README.md ├── automated_external_benchmarking ├── README.md ├── README.md~ ├── inference.py └── inference.py~ └── code ├── feature_extraction ├── README.md ├── README.md~ ├── ctran.py ├── encoders.py ├── extract_tissue.py ├── extract_tissue.py~ ├── make_tensors.py ├── optimus.py ├── phikon.py ├── resnet_custom.py ├── uni.py ├── virchow.py └── vision_transformer.py └── gma_training ├── README.md ├── datasets.py ├── modules.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Computational Pathology @Mount Sinai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Clinical Benchmark of Public Self-Supervised Pathology Foundation Models 2 | 3 | Repository of training recipes for the manuscript: "A Clinical Benchmark of Public Self-Supervised Pathology Foundation Models". 4 | Manuscript link: [arxiv](https://arxiv.org/abs/2407.06508) 5 | 6 | ## Abstract 7 | The use of self-supervised learning (SSL) to train pathology foundation models has increased substantially in the past few years. Notably, several models trained on large quantities of clinical data have been made publicly available in recent months. This will significantly enhance scientific research in computational pathology and help bridge the gap between research and clinical deployment. With the increase in availability of public foundation models of different sizes, trained using different algorithms on different datasets, it becomes important to establish a benchmark to compare the performance of such models on a variety of clinically relevant tasks spanning multiple organs and diseases. In this work, we present a collection of pathology datasets comprising clinical slides associated with clinically relevant endpoints including cancer diagnoses and a variety of biomarkers generated during standard hospital operation from three medical centers. We leverage these datasets to systematically assess the performance of public pathology foundation models and provide insights into best practices for training new foundation models and selecting appropriate pretrained models. 
To enable the community to evaluate their models on our clinical datasets, we make available an automated benchmarking pipeline for external use. 8 | 9 | ## Leaderboard 10 | 11 | Average AUC (standard deviation) across 20 MCCV splits. Models are ordered in increasing rank left to right. Updated 11/26/2024. 12 | 13 | ### Detection Tasks 14 | 15 | | Task | H-optimus-0 | Prov-GigaPath | SP85M | UNI | Virchow2 | SP22M | Phikon-v2 | Virchow | Phikon | CTransPath | tRes50 | 16 | |:----------------|:--------------|:----------------|:--------------|:--------------|:--------------|:--------------|:--------------|:--------------|:--------------|:--------------|:--------------| 17 | | MSHS Bladder | 0.963 (0.017) | 0.958 (0.016) | 0.961 (0.015) | 0.954 (0.018) | 0.964 (0.016) | 0.954 (0.016) | 0.960 (0.015) | 0.949 (0.022) | 0.950 (0.017) | 0.950 (0.023) | 0.938 (0.027) | 18 | | MSHS Breast | 0.981 (0.007) | 0.979 (0.009) | 0.981 (0.007) | 0.981 (0.007) | 0.978 (0.008) | 0.980 (0.009) | 0.978 (0.008) | 0.978 (0.006) | 0.977 (0.007) | 0.969 (0.011) | 0.932 (0.013) | 19 | | MSHS Colorectal | 0.974 (0.029) | 0.974 (0.026) | 0.973 (0.024) | 0.970 (0.028) | 0.970 (0.027) | 0.972 (0.023) | 0.969 (0.027) | 0.970 (0.030) | 0.966 (0.028) | 0.963 (0.027) | 0.950 (0.022) | 20 | | MSHS DCIS | 0.992 (0.006) | 0.983 (0.017) | 0.985 (0.018) | 0.992 (0.010) | 0.989 (0.014) | 0.985 (0.012) | 0.989 (0.010) | 0.984 (0.015) | 0.988 (0.012) | 0.974 (0.025) | 0.951 (0.034) | 21 | | MSHS IBD | 0.980 (0.008) | 0.980 (0.006) | 0.975 (0.009) | 0.973 (0.008) | 0.973 (0.010) | 0.975 (0.007) | 0.961 (0.011) | 0.965 (0.010) | 0.967 (0.010) | 0.956 (0.009) | 0.939 (0.020) | 22 | | MSHS Kidney | 0.973 (0.009) | 0.970 (0.009) | 0.972 (0.009) | 0.971 (0.009) | 0.971 (0.010) | 0.970 (0.010) | 0.967 (0.008) | 0.965 (0.011) | 0.965 (0.008) | 0.962 (0.009) | 0.952 (0.012) | 23 | | MSHS Oral | 0.991 (0.018) | 0.993 (0.013) | 0.987 (0.018) | 0.992 (0.016) | 0.994 (0.013) | 0.989 (0.013) | 0.990 (0.013) | 0.992 (0.014) | 0.992 (0.012) | 0.987 (0.012) | 0.966 (0.024) | 24 | | MSHS Prostate | 0.991 (0.005) | 0.991 (0.006) | 0.992 (0.006) | 0.992 (0.005) | 0.990 (0.007) | 0.990 (0.006) | 0.988 (0.006) | 0.992 (0.005) | 0.989 (0.008) | 0.983 (0.011) | 0.985 (0.007) | 25 | | MSHS Thyroid | 0.975 (0.013) | 0.974 (0.011) | 0.975 (0.014) | 0.977 (0.010) | 0.968 (0.015) | 0.973 (0.012) | 0.970 (0.010) | 0.975 (0.012) | 0.972 (0.011) | 0.971 (0.013) | 0.968 (0.013) | 26 | | Overall | 0.980 | 0.978 | 0.978 | 0.978 | 0.978 | 0.976 | 0.975 | 0.975 | 0.974 | 0.968 | 0.954 | 27 | 28 | ### Biomarker Tasks 29 | 30 | | Task | H-optimus-0 | Prov-GigaPath | UNI | Phikon | Virchow2 | Phikon-v2 | SP85M | SP22M | Virchow | CTransPath | tRes50 | 31 | |:------------------|:--------------|:----------------|:--------------|:--------------|:--------------|:--------------|:--------------|:--------------|:--------------|:--------------|:--------------| 32 | | MSHS BCa ER | 0.973 (0.008) | 0.972 (0.008) | 0.970 (0.009) | 0.966 (0.009) | 0.970 (0.007) | 0.964 (0.008) | 0.965 (0.010) | 0.967 (0.009) | 0.968 (0.008) | 0.949 (0.009) | 0.912 (0.022) | 33 | | MSHS BCa HER2 | 0.860 (0.026) | 0.831 (0.027) | 0.830 (0.024) | 0.812 (0.033) | 0.836 (0.023) | 0.807 (0.029) | 0.802 (0.023) | 0.814 (0.021) | 0.825 (0.026) | 0.803 (0.023) | 0.772 (0.039) | 34 | | MSHS BCa PR | 0.937 (0.010) | 0.925 (0.011) | 0.929 (0.012) | 0.925 (0.013) | 0.912 (0.011) | 0.913 (0.014) | 0.926 (0.013) | 0.926 (0.011) | 0.918 (0.012) | 0.896 (0.013) | 0.838 (0.027) | 35 | | MSHS BioMe HRD | 0.695 
(0.152) | 0.741 (0.097) | 0.727 (0.129) | 0.711 (0.117) | 0.703 (0.092) | 0.813 (0.104) | 0.718 (0.099) | 0.685 (0.113) | 0.702 (0.099) | 0.672 (0.140) | 0.545 (0.138) | 36 | | MSHS LUAD EGFR | 0.829 (0.047) | 0.821 (0.039) | 0.797 (0.041) | 0.763 (0.053) | 0.792 (0.054) | 0.754 (0.052) | 0.745 (0.043) | 0.725 (0.059) | 0.767 (0.052) | 0.718 (0.051) | 0.594 (0.057) | 37 | | MSKCC LUAD ALK | 0.821 (0.044) | 0.813 (0.042) | 0.789 (0.045) | 0.782 (0.046) | 0.778 (0.053) | 0.761 (0.040) | 0.736 (0.053) | 0.742 (0.049) | 0.747 (0.041) | 0.732 (0.049) | 0.625 (0.062) | 38 | | MSKCC LUAD EGFR | 0.823 (0.030) | 0.812 (0.028) | 0.795 (0.039) | 0.760 (0.039) | 0.766 (0.038) | 0.765 (0.036) | 0.755 (0.034) | 0.753 (0.030) | 0.755 (0.042) | 0.739 (0.037) | 0.649 (0.031) | 39 | | MSKCC LUAD KRAS | 0.711 (0.031) | 0.730 (0.030) | 0.716 (0.031) | 0.683 (0.029) | 0.694 (0.037) | 0.671 (0.018) | 0.677 (0.027) | 0.667 (0.020) | 0.658 (0.020) | 0.633 (0.032) | 0.546 (0.062) | 40 | | MSKCC LUAD STK11 | 0.874 (0.030) | 0.868 (0.036) | 0.843 (0.034) | 0.824 (0.032) | 0.808 (0.048) | 0.803 (0.038) | 0.813 (0.054) | 0.811 (0.049) | 0.785 (0.059) | 0.799 (0.042) | 0.657 (0.061) | 41 | | MSKCC LUAD TP53 | 0.732 (0.031) | 0.757 (0.028) | 0.746 (0.032) | 0.743 (0.031) | 0.706 (0.027) | 0.728 (0.030) | 0.705 (0.032) | 0.703 (0.028) | 0.702 (0.024) | 0.708 (0.029) | 0.658 (0.043) | 42 | | MSKCC NSCLC IO | 0.596 (0.069) | 0.562 (0.060) | 0.607 (0.053) | 0.573 (0.047) | 0.578 (0.055) | 0.541 (0.070) | 0.503 (0.039) | 0.525 (0.058) | 0.527 (0.060) | 0.556 (0.044) | 0.537 (0.077) | 43 | | SUH Melanoma BRAF | 0.705 (0.053) | 0.651 (0.052) | 0.698 (0.051) | 0.668 (0.057) | 0.671 (0.063) | 0.671 (0.053) | 0.653 (0.052) | 0.650 (0.043) | 0.607 (0.052) | 0.655 (0.090) | 0.567 (0.060) | 44 | | SUH Melanoma NRAS | 0.650 (0.049) | 0.625 (0.068) | 0.608 (0.060) | 0.635 (0.070) | 0.611 (0.053) | 0.608 (0.056) | 0.633 (0.069) | 0.609 (0.062) | 0.590 (0.061) | 0.575 (0.055) | 0.530 (0.058) | 45 | | Overall | 0.785 | 0.778 | 0.773 | 0.757 | 0.756 | 0.754 | 0.741 | 0.737 | 0.735 | 0.726 | 0.648 | 46 | 47 | 48 | ## Methods 49 | 50 | ### Clinically Relevant Downstream Tasks 51 | 52 | #### Detection Tasks 53 | 54 | | Task | Origin | Disease | Slides | Scanner | 55 | | --------- | ------ | ------------- | -----: | ----------------- | 56 | | Detection | MSHS | Breast Cancer | 1,998 | Philips Ultrafast | 57 | | Detection | MSHS | Oral Cancer | 279 | Philips Ultrafast | 58 | | Detection | MSHS | Bladder Cancer | 448 | Philips Ultrafast | 59 | | Detection | MSHS | Kidney Cancer | 1,000 | Philips Ultrafast | 60 | | Detection | MSHS | Thyroid Cancer | 710 | Philips Ultrafast | 61 | | Detection | MSHS | DCIS | 233 | Philips Ultrafast | 62 | | Detection | MSHS | Prostate Cancer | 1,000 | Philips Ultrafast | 63 | | Detection | MSHS | Colorectal Cancer | 413 | Philips Ultrafast | 64 | | Detection | MSHS | IBD | 1,448 | Philips Ultrafast | 65 | 66 | #### Biomarker Tasks 67 | 68 | | Task | Origin | Biomarker | Specimen | Slides | Scanner | 69 | | --------- | ------ | ------------- | ------------- | -----: | ----------------- | 70 | | Biomarker | MSHS | IHC ER | Breast Cancer | 2,000 | Philips Ultrafast | 71 | | Biomarker | MSHS | IHC PR | Breast Cancer | 1,986 | Philips Ultrafast | 72 | | Biomarker | MSHS | IHC/FISH HER2 | Breast Cancer | 2,018 | Philips Ultrafast | 73 | | Biomarker | MSHS | BioMe HRD | Breast | 563 | Philips Ultrafast | 74 | | Biomarker | SUH | NGS BRAF | Melanoma | 283 | Nanozoomer S210 | 75 | | Biomarker | SUH | NGS NRAS | Melanoma 
| 283 | Nanozoomer S210 |
76 | | Biomarker | MSHS | NGS EGFR | LUAD | 294 | Philips Ultrafast |
77 | | Biomarker | MSKCC | NGS EGFR | LUAD | 1,000 | Aperio AT2 |
78 | | Biomarker | MSKCC | NGS ALK | LUAD | 999 | Aperio AT2 |
79 | | Biomarker | MSKCC | NGS STK11 | LUAD | 998 | Aperio AT2 |
80 | | Biomarker | MSKCC | NGS KRAS | LUAD | 998 | Aperio AT2 |
81 | | Biomarker | MSKCC | NGS TP53 | LUAD | 998 | Aperio AT2 |
82 | | Outcome | MSKCC | ICI Response | NSCLC | 454 | Aperio AT2 |
83 |
84 | MSHS: Mount Sinai Health System;
85 | DCIS: Ductal Carcinoma In Situ;
86 | IBD: Inflammatory Bowel Disease;
87 | ER: Estrogen Receptor;
88 | PR: Progesterone Receptor;
89 | IHC: Immunohistochemistry;
90 | FISH: Fluorescence In Situ Hybridization;
91 | SUH: Sahlgrenska University Hospital;
92 | MSKCC: Memorial Sloan Kettering Cancer Center;
93 | LUAD: Lung Adenocarcinoma;
94 | ICI: Immune Checkpoint Inhibitors;
95 | NSCLC: Non-Small Cell Lung Cancer
96 |
97 | ### Public Pathology Foundation Models
98 |
99 | | Model | Param. (M) | Algorithm | Training Data | Tiles (M) | Slides (K) |
100 | | ------------------------------------------------------------------- | ---------: | --------- | ------------- | --------: | ---------: |
101 | | [CTransPath](https://github.com/Xiyue-Wang/TransPath) | 28 | SRCL | TCGA, PAIP | 16 | 32 |
102 | | [Phikon](https://huggingface.co/owkin/phikon) | 86 | iBOT | TCGA | 43 | 6 |
103 | | [UNI](https://huggingface.co/MahmoodLab/UNI) | 303 | DINOv2 | MGB | 100 | 100 |
104 | | [Virchow](https://huggingface.co/paige-ai/Virchow) | 631 | DINOv2 | MSKCC | 2,000 | 1,488 |
105 | | [SP22M](https://huggingface.co/MountSinaiCompPath/SP22M) | 22 | DINO | MSHS | 1,600 | 423 |
106 | | [SP85M](https://huggingface.co/MountSinaiCompPath/SP85M) | 86 | DINO | MSHS | 1,600 | 423 |
107 | | [Prov-GigaPath](https://huggingface.co/prov-gigapath/prov-gigapath) | 1,135 | DINOv2 | PHS | 1,300 | 171 |
108 | | [Virchow2](https://huggingface.co/paige-ai/Virchow2) | 631 | DINOv2 | MSKCC | 1,700 | 3,100 |
109 | | [H-optimus-0](https://huggingface.co/bioptimus/H-optimus-0) | 1,135 | DINOv2 | Proprietary | >100 | >500 |
110 | | [Phikon-v2](https://huggingface.co/owkin/phikon-v2) | 307 | DINOv2 | Multicenter | 456 | 58 |
111 |
112 | MGB: Mass General Brigham;
113 | MSKCC: Memorial Sloan Kettering Cancer Center;
114 | MSHS: Mount Sinai Health System;
115 | PHS: Providence Health and Services
116 |
117 |
118 | ## Automated External Benchmarking
119 | We provide a workflow to benchmark user-submitted models. To submit a request, follow the instructions below:
120 | 1. Submit [this form](https://forms.office.com/Pages/ResponsePage.aspx?id=YZ3odw9XsEO55GNPRi40uCeVZRc28JFPi1Agm1twtOFUMjVTRThNQVpRN1RNMldBMTNCUVZHVFFQSi4u) with the user's name and a valid email address. Optionally, users can agree to have their results recorded on our leaderboard by checking the corresponding checkbox and providing a model name.
121 | 2. The user will receive an email with a link to a secure OneDrive folder.
122 | 3. The user should upload the following files to the provided OneDrive folder:
123 | - Docker container: a Docker (or Singularity) containerized environment including the model's weights. Note: currently there is a 250GB limit per file.
124 | - `inference.py` script: a script which can run in the provided container. It should accept as input a csv file listing the slides to run inference over. It should output a torch tensor of features per slide.
We provide a sample [script](https://github.com/fuchs-lab-public/OPAL/blob/main/SSL_benchmarks/automated_external_benchmarking/inference.py) which can be modified accordingly. Additional instructions can be found [here](https://github.com/fuchs-lab-public/OPAL/tree/main/SSL_benchmarks/automated_external_benchmarking). 125 | 4. Within a 2 week timeframe, the user will receive via the provided email the results of the benchmarks as a csv file with the following columns: 126 | - Task 127 | - Task Type: Detection, Biomarker 128 | - Mean AUC 129 | - AUC Standard Deviation 130 | 5. After analysis, all data will be purged. If the user opted to save the results, they will be posted in the leaderboard. -------------------------------------------------------------------------------- /automated_external_benchmarking/README.md: -------------------------------------------------------------------------------- 1 | # Automated External Benchmarking 2 | 3 | ## Container 4 | 5 | Docker or Singularity containers are acceptable. Our code will run on the local cluster using Singularity. The official [nvidia-pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) or [MONAI](https://hub.docker.com/r/projectmonai/monai/tags) containers are a good place to start. 6 | 7 | ## `inference.py` Script 8 | 9 | This script should run in the container provided by the user. 10 | The script accepts three main arguments: 11 | - `slide_data` path to input csv file which lists the slides to encode. It contains four columns: `slide` a unique slide id which will also be used to name the tensors and matches with the `tile_data` csv files; `slide_path` the path to the slide file; `mult` a rescaling factor necessary if the right magnification is not available in the slide; `level` the level to extract pixel data from the slide. 12 | - `tile_data` path to input csv file with tile information for all tiles listed in the `slide_data` file. 13 | - `output`: path to output directory where .pth files will be saved. 14 | For each slide in the `slide_data` csv, a `.pth` binary tensor file will be generated in the `output` directory. 15 | We provide an [example script](https://github.com/fuchs-lab-public/OPAL/blob/main/SSL_benchmarks/automated_external_benchmarking/inference.py) which the users can modify to their needs. 16 | 17 | ## Testing 18 | 19 | Below we provide minimal examples of tile and slide csv files based on public data from openslide. 20 | Download the slide: 21 | ```bash 22 | wget https://openslide.cs.cmu.edu/download/openslide-testdata/Aperio/CMU-1.svs 23 | ``` 24 | 25 | ### Slide Data 26 | ```text 27 | slide,slide_path,mult,level 28 | CMU-1.svs,CMU-1.svs,1.,0 29 | ``` 30 | 31 | ### Tile Data 32 | ```text 33 | slide,x,y 34 | CMU-1.svs,0,0 35 | CMU-1.svs,224,224 36 | ``` 37 | 38 | ### Run 39 | ```bash 40 | python inference.py --slide_data slide_data.csv --tile_data tile_data.csv 41 | ``` 42 | -------------------------------------------------------------------------------- /automated_external_benchmarking/README.md~: -------------------------------------------------------------------------------- 1 | # Automated External Benchmarking 2 | 3 | ## Container 4 | 5 | Docker or Singularity containers are acceptable. Our code will run on the local cluster using Singularity. The official [nvidia-pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) or [MONAI](https://hub.docker.com/r/projectmonai/monai/tags) containers are a good place to start. The container should also come with the model weights. 
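A minimal sketch of how such a container could be prepared and tested locally with Singularity; the image tag and file names below are only examples (not part of our pipeline), and the model weights would still need to be baked into the image or shipped alongside the script:

```bash
# Pull an NGC PyTorch image as a Singularity file (the tag is illustrative; any recent tag works)
singularity pull pytorch.sif docker://nvcr.io/nvidia/pytorch:24.05-py3
# Run the user-provided inference script inside the container with GPU support
singularity exec --nv pytorch.sif python inference.py --input slides.csv --output features/
```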
6 |
7 | ## `inference.py` Script
8 |
9 | This script should run within the container provided by the user.
10 | The script should accept two arguments:
11 | - `input`: path to the input csv file which lists the slides to encode. The input csv file should contain two columns: `slide`, a unique slide id which will also be used to name the tensors; `slide_path`, the path to the slide file.
12 | - `output`: path to the output directory.
13 | For each slide in the input csv, a `.pth` binary tensor file should be generated in the `output` directory.
14 | We provide an [example script](https://github.com/fuchs-lab-public/OPAL/blob/main/SSL_benchmarks/automated_external_benchmarking/inference.py) which the users can modify to their needs.
15 |
16 |
-------------------------------------------------------------------------------- /automated_external_benchmarking/inference.py: --------------------------------------------------------------------------------
1 | '''
2 | Sample inference script to generate feature representations of tiles from a foundation model.
3 | '''
4 | import os
5 | import torch
6 | import torch.utils.data as data
7 | import torchvision
8 | import torchvision.transforms as transforms
9 | import openslide
10 | import PIL.Image as Image
11 | import argparse
12 | import numpy as np, pandas as pd
13 | class slide_dataset(data.Dataset):
14 | '''
15 | This dataset class should be included without modifications
16 | Arguments:
17 | - slide: openslide object.
18 | - df: dataframe with coordinates. It has columns x, y. Coordinates are always for level 0.
19 | - mult: rescaling factor necessary if the right magnification is not available in the slide.
20 | - level: the level to extract pixel data from the slide.
21 | - transform: a PIL image transform.
22 | - tilesize: 224.
23 | '''
24 | def __init__(self, slide, df, mult=1., level=0, tilesize=224, transform=None):
25 | self.slide = slide
26 | self.df = df
27 | self.mult = mult
28 | self.level = level
29 | self.size = int(np.round(tilesize * mult))
30 | self.tilesize = tilesize
31 | self.transform = transform
32 | def __getitem__(self, index):
33 | row = self.df.iloc[index]
34 | img = self.slide.read_region((int(row.x), int(row.y)), int(self.level), (self.size, self.size)).convert('RGB')
35 | if self.mult != 1:
36 | img = img.resize((self.tilesize, self.tilesize), Image.LANCZOS)
37 | img = self.transform(img)
38 | return img
39 | def __len__(self):
40 | return len(self.df)
41 |
42 | ##### Model Definition #####
43 | def get_model():
44 | # Define the model
45 | # Load model weights
46 | # Here we use a ResNet50 as an example
47 | model = torchvision.models.resnet50()
48 | model.fc = torch.nn.Identity()
49 | return model, 2048
50 | ############################
51 |
52 | ##### Transform Definition #####
53 | def get_transform():
54 | # Define image transform
55 | return transforms.Compose([
56 | transforms.ToTensor(),
57 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
58 | ])
59 | ################################
60 |
61 | def main(args):
62 |
63 | # Set up device
64 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65 |
66 | # Set up model
67 | model, ndim = get_model()
68 | model.eval()
69 | model.to(device)
70 | transform = get_transform()
71 |
72 | # Set up data
73 | slide_data = pd.read_csv(args.slide_data) # Slide dataframe contains columns: slide id (slide), slide path (slide_path), rescaling factor (mult), slide level (level)
74 | tile_data = pd.read_csv(args.tile_data) # Tile dataframe contains columns: slide id (slide, matching the slide data file), tile coordinate x (x), tile coordinate y (y)
75 |
76 | # Iterate through slides
77 | for i, row in slide_data.iterrows():
78 |
79 | # Output file name
80 | tensor_name = os.path.join(args.output, f'{row.slide}.pth')
81 |
82 | # Set up data
83 | slide = openslide.OpenSlide(row.slide_path)
84 | dataset = slide_dataset(slide, tile_data[tile_data.slide==row.slide], mult=row.mult, level=row.level, transform=transform, tilesize=args.tilesize)
85 | loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
86 |
87 | # Extract features and save
88 | with torch.no_grad():
89 | tensor = torch.zeros((tile_data.slide==row.slide).sum(), ndim).float()
90 | for j, img in enumerate(loader):
91 | out = model(img.to(device))
92 | tensor[j*args.batch_size:j*args.batch_size+img.size(0),:] = out.detach().clone()
93 |
94 | torch.save(tensor, tensor_name)
95 |
96 | if __name__ == '__main__':
97 | parser = argparse.ArgumentParser()
98 | parser.add_argument('--slide_data', type=str, default='', help='path to slide data csv file. It should include the following columns: slide, slide_path, mult, level.')
99 | parser.add_argument('--tile_data', type=str, default='', help='path to tile data csv file. It should include the following columns: slide, x, y.')
100 | parser.add_argument('--output', type=str, default='', help='path to the output directory where .pth files will be saved.')
101 | parser.add_argument('--batch_size', type=int, default=128, help='batch size')
102 | parser.add_argument('--workers', type=int, default=10, help='number of data loading workers')
103 | parser.add_argument('--tilesize', type=int, default=224, help='tile size in pixels')
104 | args = parser.parse_args()
105 | main(args)
106 |
-------------------------------------------------------------------------------- /automated_external_benchmarking/inference.py~: --------------------------------------------------------------------------------
1 | '''
2 | Sample inference script to generate feature representations of tiles from a foundation model.
3 | '''
4 | import os
5 | import torch
6 | import torchvision
7 | def initialize_model():
8 | # Define the model
9 | # Load model weights
10 | # Here we use a ResNet50 as an example
11 | model = torchvision.models.resnet50()
12 | model.fc = torch.nn.Identity()
13 | return model
14 |
15 | def main():
16 |
17 | # Set up device
18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19 |
20 | # Set up model
21 | model = initialize_model()
22 | model.eval()
23 | model.to(device)
24 |
25 |
26 |
-------------------------------------------------------------------------------- /code/feature_extraction/README.md: --------------------------------------------------------------------------------
1 | # Feature Extraction
2 |
3 | Code to extract features from tiles and save them to tensors. For each task and foundation model, it generates, for each slide, a 2D tensor of shape (number of tiles, number of features).
4 | Usage:
5 | ```bash
6 | python make_tensors.py \
7 | --slide_data path/to/slide/data/file/for/this/task.csv \
8 | --tile_data path/to/tile/data/file/for/this/task.csv \
9 | --encoder a_foundation_model \
10 | --bsize 1024 \
11 | --workers 10
12 | ```
13 |
14 | ## `slide_data`
15 | Task specific slide level data. Must contain the following columns:
16 | - `slide_path`: full path to slide
17 | - `slide`: unique slide identifier
18 | - `tensor_root`: full path to the output root for that data (the encoder type needs to be added to it)
19 | - `tensor_name`: name of tensor file without path
20 |
21 | ## `tile_data`
22 | Task specific tile level data.
Must contain the following columns: 23 | - `slide`: unique slide identifier, same as in the `slide_data` file 24 | - `x`: x coordinate 25 | - `y`: y coordinate 26 | - `level`: pyramid level at which to extract data 27 | - `mult`: factor for tile resize 28 | 29 | ## Tissue Tile Generation 30 | There are many options to extract tissue tiles and any of them could be used here. In our work we use the strategy from [Campanella et al.](https://www.nature.com/articles/s41591-019-0508-1). 31 | This is a fast method that works well for H&E slides. We provide the code in `extract_tissue.py`. To use: 32 | ```python 33 | import extract_tissue 34 | import openslide 35 | help(extract_tissue.make_sample_grid) 36 | slide = openslide.OpenSlide(path/to/a/slide) 37 | # Generate coordinates 38 | base_mpp = extract_tissue.slide_base_mpp(slide) 39 | coord_list = extract_tissue.make_sample_grid(slide, patch_size=224, mpp=0.5, mult=4, base_mpp=base_mpp) 40 | # Plot extraction 41 | extract_tissue.plot_extraction(slide, patch_size=224, mpp=0.5, mult=4, base_mpp=base_mpp) 42 | ``` -------------------------------------------------------------------------------- /code/feature_extraction/README.md~: -------------------------------------------------------------------------------- 1 | # Feature Extraction 2 | 3 | Code to extract features from tiles and save them to tensors. For each task and foundation model, it generates for each slide a 2D tensor of shape number of tiles by number of features. 4 | Usage: 5 | ``` 6 | python make_tensors.py \ 7 | --slide_data path/to/slide/data/file/for/this/task.csv \ 8 | --tile_data path/to/tile/data/file/for/this/task.csv \ 9 | --encoder a_foundation_model \ 10 | --bsize 1024 11 | --workers 10 12 | ``` 13 | 14 | ## `slide_data` 15 | Task specific slide level data. Must contain the following columns: 16 | - `slide_path`: full path to slide 17 | - `slide`: unique slide identifier 18 | - `tensor_root`: full path to root for that data. Need to add the encoder type 19 | - `tensor_name`: name of tensor file without path 20 | 21 | ## `tile_data` 22 | Task specific tile level data. 
Must contain the following columns: 23 | - `slide`: unique slide identifier, same as in the `slide_data` file 24 | - `x`: x coordinate 25 | - `y`: y coordinate 26 | - `level`: pyramid level at which to extract data 27 | - `mult`: factor for tile resize 28 | -------------------------------------------------------------------------------- /code/feature_extraction/ctran.py: -------------------------------------------------------------------------------- 1 | from timm.models.layers.helpers import to_2tuple 2 | import timm 3 | import torch.nn as nn 4 | 5 | 6 | class ConvStem(nn.Module): 7 | 8 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 9 | super().__init__() 10 | 11 | assert patch_size == 4 12 | assert embed_dim % 8 == 0 13 | 14 | img_size = to_2tuple(img_size) 15 | patch_size = to_2tuple(patch_size) 16 | self.img_size = img_size 17 | self.patch_size = patch_size 18 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 19 | self.num_patches = self.grid_size[0] * self.grid_size[1] 20 | self.flatten = flatten 21 | 22 | 23 | stem = [] 24 | input_dim, output_dim = 3, embed_dim // 8 25 | for l in range(2): 26 | stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False)) 27 | stem.append(nn.BatchNorm2d(output_dim)) 28 | stem.append(nn.ReLU(inplace=True)) 29 | input_dim = output_dim 30 | output_dim *= 2 31 | stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1)) 32 | self.proj = nn.Sequential(*stem) 33 | 34 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 35 | 36 | def forward(self, x): 37 | B, C, H, W = x.shape 38 | assert H == self.img_size[0] and W == self.img_size[1], \ 39 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
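# The conv stem (self.proj) downsamples the input 4x with two stride-2 convolutions and projects to embed_dim; the result is then flattened to (B, N, C) tokens for the Swin backbone below.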
40 | x = self.proj(x) 41 | if self.flatten: 42 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 43 | x = self.norm(x) 44 | return x 45 | 46 | def ctranspath(): 47 | model = timm.create_model('swin_tiny_patch4_window7_224', embed_layer=ConvStem, pretrained=False) 48 | return model 49 | -------------------------------------------------------------------------------- /code/feature_extraction/encoders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import resnet_custom 3 | import vision_transformer as vits 4 | import uni 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.transforms as transforms 8 | import torchvision.models as models 9 | 10 | class tres50(nn.Module): 11 | def __init__(self): 12 | super(tres50, self).__init__() 13 | resnet = resnet_custom.resnet50_baseline(True) 14 | self.features = nn.Sequential(*list(resnet.children())[0:-1]) 15 | def forward(self, x): 16 | x = self.features(x).view(x.size(0), -1) 17 | return x 18 | 19 | def load_pretrained_weights_dino(model, pretrained_weights, checkpoint_key): 20 | if os.path.isfile(pretrained_weights): 21 | state_dict = torch.load(pretrained_weights, map_location="cpu") 22 | if checkpoint_key is not None and checkpoint_key in state_dict: 23 | print(f"Take key {checkpoint_key} in provided checkpoint dict") 24 | state_dict = state_dict[checkpoint_key] 25 | # remove `module.` prefix 26 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 27 | # remove `backbone.` prefix induced by multicrop wrapper 28 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 29 | msg = model.load_state_dict(state_dict, strict=False) 30 | print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg)) 31 | 32 | def get_encoder(encoder): 33 | 34 | if encoder == 'tres50_imagenet': 35 | model = tres50() 36 | ndim = 1024 37 | elif encoder == 'res50_imagenet': 38 | model = models.__dict__['resnet50'](weights='DEFAULT') 39 | model.fc = nn.Identity() 40 | ndim = 2048 41 | elif encoder == 'ctranspath': 42 | from ctran import ctranspath 43 | model = ctranspath() 44 | model.head = nn.Identity() 45 | td = torch.load(r'path/to/checkpoint.pth') 46 | model.load_state_dict(td['model'], strict=True) 47 | ndim = 768 48 | elif encoder == 'phikon': 49 | import phikon 50 | model = phikon.get_model() 51 | ndim = 768 52 | elif encoder == 'uni': 53 | model = uni.get_model() 54 | ndim = 1024 55 | elif encoder == 'virchow': 56 | import virchow 57 | model = virchow.virchow() 58 | ndim = 2560 59 | elif encoder == 'h-optimus-0': 60 | import optimus 61 | model = optimus.get_model(encoder) 62 | ndim = 1536 63 | elif encoder == 'dinosmall': 64 | path = 'path/to/checkpoint.pth' 65 | model = vits.vit_small(num_classes=0) 66 | load_pretrained_weights_dino(model, path, 'teacher') 67 | ndim = 384 68 | elif encoder == 'dinobase': 69 | path = 'path/to/checkpoint.pth' 70 | model = vits.vit_base(num_classes=0) 71 | load_pretrained_weights_dino(model, path, 'teacher') 72 | ndim = 768 73 | elif encoder == 'gigapath': 74 | import timm 75 | model = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=True) 76 | ndim = 1536 77 | else: 78 | raise Exception('Wrong encoder name') 79 | 80 | model.eval() 81 | if encoder == 'virchow': 82 | from timm.data import resolve_data_config 83 | from timm.data.transforms_factory import create_transform 84 | transform = create_transform(**resolve_data_config(model.virchow.pretrained_cfg, model=model.virchow)) 85 
| elif encoder == 'h-optimus-0': 86 | transform = transforms.Compose([ 87 | transforms.ToTensor(), 88 | transforms.Normalize( 89 | mean=(0.707223, 0.578729, 0.703617), 90 | std=(0.211883, 0.230117, 0.177517) 91 | ), 92 | ]) 93 | else: 94 | transform = transforms.Compose([transforms.ToTensor(), 95 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 96 | ]) 97 | total_params = sum(p.numel() for p in model.parameters()) 98 | print(f'Model: {encoder} - {total_params} parameters') 99 | return model, transform, ndim 100 | -------------------------------------------------------------------------------- /code/feature_extraction/extract_tissue.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import openslide 4 | from skimage.morphology import binary_erosion, binary_dilation, label, dilation, square, skeletonize 5 | from skimage.filters import threshold_otsu 6 | import PIL.Image as Image 7 | import random 8 | MAX_PIXEL_DIFFERENCE = 0.2 # difference must be within 20% of image size 9 | 10 | def slide_base_mpp(slide): 11 | return float(slide.properties[openslide.PROPERTY_NAME_MPP_X]) 12 | 13 | def find_level(slide, mpp, patchsize=224, base_mpp=None): 14 | downsample = mpp / base_mpp 15 | for i in range(slide.level_count)[::-1]: 16 | if abs(downsample / slide.level_downsamples[i] * patchsize - patchsize) < MAX_PIXEL_DIFFERENCE * patchsize or downsample > slide.level_downsamples[i]: 17 | level = i 18 | mult = downsample / slide.level_downsamples[level] 19 | break 20 | else: 21 | raise Exception('Requested resolution ({} mpp) is too high'.format(mpp)) 22 | #move mult to closest pixel 23 | mult = np.round(mult*patchsize)/patchsize 24 | if abs(mult*patchsize - patchsize) < MAX_PIXEL_DIFFERENCE * patchsize: 25 | mult = 1. 
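# level is the slide pyramid level to read from; mult is the residual resize factor still needed to reach the requested mpp after reading at that level.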
26 | return level, mult 27 | 28 | def image2array(img): 29 | if img.__class__.__name__=='Image': 30 | if img.mode=='RGB': 31 | img=np.array(img) 32 | r,g,b = np.rollaxis(img, axis=-1) 33 | img=np.stack([r,g,b],axis=-1) 34 | elif img.mode=='RGBA': 35 | img=np.array(img) 36 | r,g,b,a = np.rollaxis(img, axis=-1) 37 | img=np.stack([r,g,b],axis=-1) 38 | else: 39 | sys.exit('Error: image is not RGB slide') 40 | img=np.uint8(img) 41 | return img 42 | 43 | def is_sample(img,threshold=0.9,ratioCenter=0.1,wholeAreaCutoff=0.5,centerAreaCutoff=0.9): 44 | nrows,ncols=img.shape 45 | timg=cv2.threshold(img, 255*threshold, 1, cv2.THRESH_BINARY_INV) 46 | kernel=np.ones((5,5),np.uint8) 47 | cimg=cv2.morphologyEx(timg[1], cv2.MORPH_CLOSE, kernel) 48 | crow=np.rint(nrows/2).astype(int) 49 | ccol=np.rint(ncols/2).astype(int) 50 | drow=np.rint(nrows*ratioCenter/2).astype(int) 51 | dcol=np.rint(ncols*ratioCenter/2).astype(int) 52 | centerw=cimg[crow-drow:crow+drow,ccol-dcol:ccol+dcol] 53 | if (np.count_nonzero(cimg)0)|(img_g==255)) 76 | t = threshold_otsu(masked.compressed()) 77 | img_g = cv2.threshold(img_g, t, 255, cv2.THRESH_BINARY)[1] 78 | 79 | else: 80 | masked = np.ma.masked_array(img_g, img_g==255) 81 | t = threshold_otsu(masked.compressed()) 82 | img_g = cv2.threshold(img_g, t, 255, cv2.THRESH_BINARY)[1] 83 | 84 | # Exclude marker 85 | if marker is not None: 86 | img_g = cv2.subtract(~img_g, marker) 87 | else: 88 | img_g = 255 - img_g 89 | 90 | # Remove grays 91 | img_g[std<5] = 0 92 | 93 | # Rescale 94 | if mult>1: 95 | img_g = img_g.reshape(h//mult, mult, w//mult, mult).max(axis=(1, 3)) 96 | 97 | return img_g, t 98 | 99 | def remove_black_ink(img_g, th=50, delta=50): 100 | ''' 101 | image in gray scale 102 | returns mask where ink is positive 103 | th=50 and delta=50 was chosen based on some slides 104 | ''' 105 | dist = np.clip(img_g - float(th), 0, None) 106 | mask = dist < delta 107 | if mask.sum() > 0: 108 | mask_s = skeletonize(mask) 109 | d = int(np.round(0.1 * mask.sum() / mask_s.sum())) 110 | mask = dilation(mask, square(2*d+1)) 111 | return mask 112 | else: 113 | return None 114 | 115 | def filter_regions(img,min_size): 116 | l, n = label(img, return_num=True) 117 | for i in range(1,n+1): 118 | #filter small regions 119 | if l[l==i].size < min_size: 120 | l[l==i] = 0 121 | return l 122 | 123 | def add(overlap): 124 | return np.linspace(0,1,overlap+1)[1:-1] 125 | 126 | def add2offset(img, slide, patch_size, mpp, maxmpp): 127 | size_x = img.shape[1] 128 | size_y = img.shape[0] 129 | offset_x = np.floor((slide.dimensions[0]*1./(patch_size*mpp/maxmpp)-size_x)*(patch_size*mpp/maxmpp)) 130 | offset_y = np.floor((slide.dimensions[1]*1./(patch_size*mpp/maxmpp)-size_y)*(patch_size*mpp/maxmpp)) 131 | add_x = np.linspace(0,offset_x,size_x).astype(int) 132 | add_y = np.linspace(0,offset_y,size_y).astype(int) 133 | return add_x, add_y 134 | 135 | def addoverlap(w, grid, overlap, patch_size, mpp, maxmpp, img, offset=0): 136 | o = (add(overlap)*(patch_size*mpp/maxmpp)).astype(int) 137 | ox,oy = np.meshgrid(o,o) 138 | connx = np.zeros(img.shape).astype(bool) 139 | conny = np.zeros(img.shape).astype(bool) 140 | connd = np.zeros(img.shape).astype(bool) 141 | connu = np.zeros(img.shape).astype(bool) 142 | connx[:,:-1] = img[:,1:] 143 | conny[:-1,:] = img[1:,:] 144 | connd[:-1,:-1] = img[1:,1:] 145 | connu[1:,:-1] = img[:-1,1:] & ( ~img[1:,1:] | ~img[:-1,:-1] ) 146 | connx = connx[w] 147 | conny = conny[w] 148 | connd = connd[w] 149 | connu = connu[w] 150 | extra = [] 151 | for i,(x,y) in enumerate(grid): 152 | 
if connx[i]: extra.extend(zip(o+x-offset,np.repeat(y,overlap-1)-offset)) 153 | if conny[i]: extra.extend(zip(np.repeat(x,overlap-1)-offset,o+y-offset)) 154 | if connd[i]: extra.extend(zip(ox.flatten()+x-offset,oy.flatten()+y-offset)) 155 | if connu[i]: extra.extend(zip(x+ox.flatten()-offset,y-oy.flatten()-offset)) 156 | return extra 157 | 158 | def make_sample_grid(slide, patch_size=224, mpp=0.5, min_cc_size=10, max_ratio_size=10, dilate=False, erode=False, prune=False, overlap=1, maxn=None, bmp=None, oversample=False, mult=1, centerpixel=False, base_mpp=None, thumbnail_coords=False): 159 | ''' 160 | Script that given an openslide object return a list of tuples 161 | in the form of (x,y) coordinates for patch extraction of sample patches. 162 | It has an erode option to make sure to get patches that are full of tissue. 163 | It has a prune option to check if patches are sample. It is slow. 164 | If bmp is given, it samples from within areas of the bmp that are nonzero. 165 | If oversample is True, it will downsample for full resolution regardless of what resolution is requested. 166 | mult is used to increase the resolution of the thumbnail to get finer tissue extraction 167 | ''' 168 | if oversample: 169 | img, th = threshold(slide, patch_size, base_mpp, base_mpp, mult) 170 | else: 171 | img, th = threshold(slide, patch_size, mpp, base_mpp, mult) 172 | 173 | if bmp: 174 | bmplab = Image.open(bmp) 175 | thumbx, thumby = img.shape 176 | bmplab = bmplab.resize((thumby, thumbx), Image.ANTIALIAS) 177 | bmplab = np.array(bmplab) 178 | bmplab[bmplab>0] = 1 179 | img = np.logical_and(img, bmplab) 180 | 181 | img = filter_regions(img,min_cc_size) 182 | img[img>0]=1 183 | if erode: 184 | img = binary_erosion(img) 185 | if dilate: 186 | img = binary_dilation(img) 187 | 188 | if oversample: 189 | add_x, add_y = add2offset(img, slide, patch_size, base_mpp, base_mpp) 190 | else: 191 | add_x, add_y = add2offset(img, slide, patch_size, mpp, base_mpp) 192 | 193 | #list of sample pixels 194 | w = np.where(img>0) 195 | if thumbnail_coords: 196 | return list(zip(w[1],w[0])), img.shape[1], img.shape[0] 197 | 198 | #grid=zip(w[1]*patch_size,w[0]*patch_size) 199 | if oversample: 200 | offset = int(0.5 * patch_size * ((mpp/base_mpp) - 1)) 201 | grid = list(zip((w[1]*(patch_size)+add_x[w[1]]-offset).astype(int),(w[0]*(patch_size)+add_y[w[0]]-offset).astype(int))) 202 | else: 203 | grid = list(zip((w[1]*(patch_size*mpp/base_mpp)+add_x[w[1]]).astype(int),(w[0]*(patch_size*mpp/base_mpp)+add_y[w[0]]).astype(int))) 204 | 205 | #connectivity 206 | if overlap > 1: 207 | if oversample: 208 | extra = addoverlap(w, grid, overlap, patch_size, base_mpp, base_mpp, img, offset=offset) 209 | grid.extend(extra) 210 | else: 211 | extra = addoverlap(w, grid, overlap, patch_size, mpp, base_mpp, img) 212 | grid.extend(extra) 213 | 214 | # center pixel offset 215 | if centerpixel: 216 | offset = int(mpp / base_mpp * patch_size // 2) 217 | grid = [(x[0] + offset, x[1] + offset) for x in grid] 218 | 219 | #prune squares 220 | if prune: 221 | level, mult = find_level(slide,mpp,base_mpp) 222 | psize = int(patch_size*mult) 223 | truegrid = [] 224 | for tup in grid: 225 | reg = slide.read_region(tup,level,(psize,psize)) 226 | if mult != 1: 227 | reg = reg.resize((224,224),Image.BILINEAR) 228 | reg = image2array(reg) 229 | if is_sample(reg,th/255,0.2,0.4,0.5): 230 | truegrid.append(tup) 231 | else: 232 | truegrid = grid 233 | 234 | #sample if maxn 235 | if maxn: 236 | truegrid = random.sample(truegrid, min(maxn, len(truegrid))) 237 | 238 | 
return truegrid 239 | 240 | def make_hires_map(slide, pred, grid, patch_size, mpp, maxmpp, overlap): 241 | ''' 242 | Given the list of predictions and the known overlap it gives the hires probability map 243 | ''' 244 | W = slide.dimensions[0] 245 | H = slide.dimensions[1] 246 | w = int(np.round(W*1./(patch_size*mpp/maxmpp))) 247 | h = int(np.round(H*1./(patch_size*mpp/maxmpp))) 248 | 249 | newimg = np.zeros((h*overlap,w*overlap))-1 250 | offset_x = np.floor((W*1./(patch_size*mpp/maxmpp)-w)*(patch_size*mpp/maxmpp)) 251 | offset_y = np.floor((H*1./(patch_size*mpp/maxmpp)-h)*(patch_size*mpp/maxmpp)) 252 | add_x = np.linspace(0,offset_x,w).astype(int) 253 | add_y = np.linspace(0,offset_y,h).astype(int) 254 | for i,(xgrid,ygrid) in enumerate(grid): 255 | yindx = int(ygrid/(patch_size*mpp/maxmpp)) 256 | xindx = int(xgrid/(patch_size*mpp/maxmpp)) 257 | y = np.round((ygrid-add_y[yindx])*overlap/(patch_size*mpp/maxmpp)).astype(int) 258 | x = np.round((xgrid-add_x[xindx])*overlap/(patch_size*mpp/maxmpp)).astype(int) 259 | newimg[y,x] = pred[i] 260 | return newimg 261 | 262 | def make_hires_map_stride(slide, pred, grid, stride): 263 | ''' 264 | Given the list of predictions and the stride it gives the hires probability map 265 | Grid ndarray specify the center pixel of a tile 266 | ''' 267 | W = slide.dimensions[0] 268 | H = slide.dimensions[1] 269 | w = int(round(W*1./stride)) 270 | h = int(round(H*1./stride)) 271 | 272 | # Scale grid to pixels 273 | ngrid = np.floor(grid.astype(float) / stride).astype(int) 274 | 275 | # Make image 276 | newimg = np.zeros((h,w))-2 277 | 278 | # Add tissue 279 | tissue = threshold_stride(slide, stride) 280 | newimg[tissue>0] = -1 281 | 282 | # paint predictions 283 | for i in range(len(ngrid)): 284 | x, y = ngrid[i] 285 | newimg[y,x] = pred[i] 286 | 287 | return newimg 288 | 289 | def threshold_stride(slide, stride): 290 | W = slide.dimensions[0] 291 | H = slide.dimensions[1] 292 | w = int(np.ceil(W*1./stride)) 293 | h = int(np.ceil(H*1./stride)) 294 | thumbnail = slide.get_thumbnail((w,h)) 295 | thumbnail = thumbnail.resize((w,h)) 296 | img = image2array(thumbnail) 297 | #calc std on color image 298 | std = np.std(img,axis=-1) 299 | #image to bw 300 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 301 | ## remove black dots ## 302 | _,tmp = cv2.threshold(img,20,255,cv2.THRESH_BINARY_INV) 303 | kernel = np.ones((5,5),np.uint8) 304 | tmp = cv2.dilate(tmp,kernel,iterations = 1) 305 | img[tmp==255] = 255 306 | img = cv2.GaussianBlur(img,(5,5),0) 307 | t,img = cv2.threshold(img,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU) 308 | img = 255-img 309 | img[std<5] = 0 310 | return img 311 | 312 | def plot_extraction(slide, patch_size=224, mpp=0.5, min_cc_size=10, max_ratio_size=10, dilate=False, erode=False, prune=False, overlap=1, maxn=None, bmp=None, oversample=False, mult=1, base_mpp=None, save=''): 313 | '''Script that shows the result of applying the detector in case you get weird results''' 314 | import matplotlib.pyplot as plt 315 | import matplotlib.patches as patches 316 | downsample = 50. 
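# downsample controls the thumbnail scale used for plotting; dpi only affects the size of the output figure.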
317 | dpi = 100 318 | 319 | if save: 320 | plt.switch_backend('agg') 321 | 322 | grid = make_sample_grid(slide, patch_size, mpp=mpp, min_cc_size=min_cc_size, max_ratio_size=max_ratio_size, dilate=dilate, erode=erode, prune=prune, overlap=overlap, maxn=maxn, bmp=bmp, oversample=oversample, mult=mult, base_mpp=base_mpp) 323 | thumb = slide.get_thumbnail((np.round(slide.dimensions[0]/downsample),np.round(slide.dimensions[1]/downsample))) 324 | width, height = thumb.size 325 | 326 | ps = [] 327 | for tup in grid: 328 | ps.append(patches.Rectangle( 329 | (tup[0]/downsample, tup[1]/downsample), patch_size/downsample*(mpp/base_mpp), patch_size/downsample*(mpp/base_mpp), fill=False, 330 | edgecolor="red" 331 | )) 332 | 333 | fig = plt.figure(figsize=(int(width/dpi), int(height/dpi)), dpi=dpi) 334 | ax = fig.add_subplot(111, aspect='equal') 335 | ax.imshow(thumb) 336 | for p in ps: 337 | ax.add_patch(p) 338 | if save: 339 | plt.savefig(save) 340 | else: 341 | plt.show() 342 | 343 | def detect_marker(thumb, mult): 344 | ksize = int(max(1, mult)) 345 | #ksize = 1 346 | img = cv2.GaussianBlur(thumb, (5,5), 0) 347 | hsv_origimg = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) 348 | # Extract marker 349 | black_marker = cv2.inRange(hsv_origimg, np.array([0, 0, 0]), np.array([180, 255, 125])) # black marker 350 | blue_marker = cv2.inRange(hsv_origimg, np.array([90, 30, 30]), np.array([130, 255, 255])) # blue marker 351 | green_marker = cv2.inRange(hsv_origimg, np.array([40, 30, 30]), np.array([90, 255, 255])) # green marker 352 | mask_hsv = cv2.bitwise_or(cv2.bitwise_or(black_marker, blue_marker), green_marker) 353 | mask_hsv = cv2.erode(mask_hsv, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize,ksize))) 354 | mask_hsv = cv2.dilate(mask_hsv, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize*3,ksize*3))) 355 | if np.count_nonzero(mask_hsv) > 0: 356 | return mask_hsv 357 | else: 358 | return None 359 | -------------------------------------------------------------------------------- /code/feature_extraction/extract_tissue.py~: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import openslide 4 | from skimage.morphology import binary_erosion, binary_dilation, label, dilation, square, skeletonize 5 | from skimage.filters import threshold_otsu 6 | import PIL.Image as Image 7 | import random 8 | import pdb 9 | 10 | isyntax_mpp = 0.25 11 | MAX_PIXEL_DIFFERENCE = 0.2 # difference must be within 20% of image size 12 | 13 | def slide_base_mpp(slide): 14 | return float(slide.properties[openslide.PROPERTY_NAME_MPP_X]) 15 | 16 | def find_level(slide, mpp, patchsize=224, base_mpp=None): 17 | if base_mpp is None: 18 | base_mpp = isyntax_mpp 19 | downsample = mpp / base_mpp 20 | for i in range(slide.level_count)[::-1]: 21 | if abs(downsample / slide.level_downsamples[i] * patchsize - patchsize) < MAX_PIXEL_DIFFERENCE * patchsize or downsample > slide.level_downsamples[i]: 22 | level = i 23 | mult = downsample / slide.level_downsamples[level] 24 | break 25 | else: 26 | raise Exception('Requested resolution ({} mpp) is too high'.format(mpp)) 27 | #move mult to closest pixel 28 | mult = np.round(mult*patchsize)/patchsize 29 | if abs(mult*patchsize - patchsize) < MAX_PIXEL_DIFFERENCE * patchsize: 30 | mult = 1. 
31 | return level, mult 32 | 33 | def image2array(img): 34 | if img.__class__.__name__=='Image': 35 | if img.mode=='RGB': 36 | img=np.array(img) 37 | r,g,b = np.rollaxis(img, axis=-1) 38 | img=np.stack([r,g,b],axis=-1) 39 | elif img.mode=='RGBA': 40 | img=np.array(img) 41 | r,g,b,a = np.rollaxis(img, axis=-1) 42 | img=np.stack([r,g,b],axis=-1) 43 | else: 44 | sys.exit('Error: image is not RGB slide') 45 | img=np.uint8(img) 46 | return img#cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 47 | 48 | def is_sample(img,threshold=0.9,ratioCenter=0.1,wholeAreaCutoff=0.5,centerAreaCutoff=0.9): 49 | nrows,ncols=img.shape 50 | timg=cv2.threshold(img, 255*threshold, 1, cv2.THRESH_BINARY_INV) 51 | kernel=np.ones((5,5),np.uint8) 52 | cimg=cv2.morphologyEx(timg[1], cv2.MORPH_CLOSE, kernel) 53 | crow=np.rint(nrows/2).astype(int) 54 | ccol=np.rint(ncols/2).astype(int) 55 | drow=np.rint(nrows*ratioCenter/2).astype(int) 56 | dcol=np.rint(ncols*ratioCenter/2).astype(int) 57 | centerw=cimg[crow-drow:crow+drow,ccol-dcol:ccol+dcol] 58 | if (np.count_nonzero(cimg)0)|(std<5)) 84 | masked = np.ma.masked_array(img_g, (marker>0)|(img_g==255)) 85 | t = threshold_otsu(masked.compressed()) 86 | img_g = cv2.threshold(img_g, t, 255, cv2.THRESH_BINARY)[1] 87 | 88 | else: 89 | #masked = np.ma.masked_array(img_g, std<5) 90 | masked = np.ma.masked_array(img_g, img_g==255) 91 | t = threshold_otsu(masked.compressed()) 92 | img_g = cv2.threshold(img_g, t, 255, cv2.THRESH_BINARY)[1] 93 | #t, img_g = cv2.threshold(img_g, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) 94 | 95 | # Exclude marker 96 | if marker is not None: 97 | img_g = cv2.subtract(~img_g, marker) 98 | #img_g = cv2.erode(img_g, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15,15))) 99 | #img_g = cv2.dilate(img_g, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15,15))) 100 | else: 101 | img_g = 255 - img_g 102 | 103 | # Remove grays 104 | img_g[std<5] = 0 105 | 106 | # Rescale 107 | if mult>1: 108 | img_g = img_g.reshape(h//mult, mult, w//mult, mult).max(axis=(1, 3)) 109 | 110 | return img_g, t 111 | 112 | def remove_black_ink(img_g, th=50, delta=50): 113 | ''' 114 | image in gray scale 115 | returns mask where ink is positive 116 | th=50 and delta=50 was chosen based on some slides 117 | ''' 118 | dist = np.clip(img_g - float(th), 0, None) 119 | mask = dist < delta 120 | if mask.sum() > 0: 121 | mask_s = skeletonize(mask) 122 | d = int(np.round(0.1 * mask.sum() / mask_s.sum())) 123 | mask = dilation(mask, square(2*d+1)) 124 | return mask 125 | else: 126 | return None 127 | 128 | def filter_regions(img,min_size): 129 | l, n = label(img, return_num=True) 130 | for i in range(1,n+1): 131 | #filter small regions 132 | if l[l==i].size < min_size: 133 | l[l==i] = 0 134 | return l 135 | 136 | def add(overlap): 137 | return np.linspace(0,1,overlap+1)[1:-1] 138 | 139 | def add2offset(img, slide, patch_size, mpp, maxmpp): 140 | size_x = img.shape[1] 141 | size_y = img.shape[0] 142 | offset_x = np.floor((slide.dimensions[0]*1./(patch_size*mpp/maxmpp)-size_x)*(patch_size*mpp/maxmpp)) 143 | offset_y = np.floor((slide.dimensions[1]*1./(patch_size*mpp/maxmpp)-size_y)*(patch_size*mpp/maxmpp)) 144 | add_x = np.linspace(0,offset_x,size_x).astype(int) 145 | add_y = np.linspace(0,offset_y,size_y).astype(int) 146 | return add_x, add_y 147 | 148 | def addoverlap(w, grid, overlap, patch_size, mpp, maxmpp, img, offset=0): 149 | o = (add(overlap)*(patch_size*mpp/maxmpp)).astype(int) 150 | ox,oy = np.meshgrid(o,o) 151 | connx = np.zeros(img.shape).astype(bool) 152 | conny = np.zeros(img.shape).astype(bool) 
153 | connd = np.zeros(img.shape).astype(bool) 154 | connu = np.zeros(img.shape).astype(bool) 155 | connx[:,:-1] = img[:,1:] 156 | conny[:-1,:] = img[1:,:] 157 | connd[:-1,:-1] = img[1:,1:] 158 | connu[1:,:-1] = img[:-1,1:] & ( ~img[1:,1:] | ~img[:-1,:-1] ) 159 | connx = connx[w] 160 | conny = conny[w] 161 | connd = connd[w] 162 | connu = connu[w] 163 | extra = [] 164 | for i,(x,y) in enumerate(grid): 165 | if connx[i]: extra.extend(zip(o+x-offset,np.repeat(y,overlap-1)-offset)) 166 | if conny[i]: extra.extend(zip(np.repeat(x,overlap-1)-offset,o+y-offset)) 167 | if connd[i]: extra.extend(zip(ox.flatten()+x-offset,oy.flatten()+y-offset)) 168 | if connu[i]: extra.extend(zip(x+ox.flatten()-offset,y-oy.flatten()-offset)) 169 | return extra 170 | 171 | def make_sample_grid(slide, patch_size=224, mpp=0.5, min_cc_size=10, max_ratio_size=10, dilate=False, erode=False, prune=False, overlap=1, maxn=None, bmp=None, oversample=False, mult=1, centerpixel=False, base_mpp=None, thumbnail_coords=False): 172 | ''' 173 | Script that given an openslide object return a list of tuples 174 | in the form of (x,y) coordinates for patch extraction of sample patches. 175 | It has an erode option to make sure to get patches that are full of tissue. 176 | It has a prune option to check if patches are sample. It is slow. 177 | If bmp is given, it samples from within areas of the bmp that are nonzero. 178 | If oversample is True, it will downsample for full resolution regardless of what resolution is requested. 179 | mult is used to increase the resolution of the thumbnail to get finer tissue extraction 180 | ''' 181 | if base_mpp is None: 182 | base_mpp = isyntax_mpp 183 | 184 | if oversample: 185 | img, th = threshold(slide, patch_size, base_mpp, base_mpp, mult) 186 | else: 187 | img, th = threshold(slide, patch_size, mpp, base_mpp, mult) 188 | 189 | if bmp: 190 | bmplab = Image.open(bmp) 191 | thumbx, thumby = img.shape 192 | bmplab = bmplab.resize((thumby, thumbx), Image.ANTIALIAS) 193 | bmplab = np.array(bmplab) 194 | bmplab[bmplab>0] = 1 195 | img = np.logical_and(img, bmplab) 196 | 197 | img = filter_regions(img,min_cc_size) 198 | img[img>0]=1 199 | if erode: 200 | img = binary_erosion(img) 201 | if dilate: 202 | img = binary_dilation(img) 203 | 204 | if oversample: 205 | add_x, add_y = add2offset(img, slide, patch_size, base_mpp, base_mpp) 206 | else: 207 | add_x, add_y = add2offset(img, slide, patch_size, mpp, base_mpp) 208 | 209 | #list of sample pixels 210 | w = np.where(img>0) 211 | if thumbnail_coords: 212 | return list(zip(w[1],w[0])), img.shape[1], img.shape[0] 213 | 214 | #grid=zip(w[1]*patch_size,w[0]*patch_size) 215 | if oversample: 216 | offset = int(0.5 * patch_size * ((mpp/base_mpp) - 1)) 217 | grid = list(zip((w[1]*(patch_size)+add_x[w[1]]-offset).astype(int),(w[0]*(patch_size)+add_y[w[0]]-offset).astype(int))) 218 | else: 219 | grid = list(zip((w[1]*(patch_size*mpp/base_mpp)+add_x[w[1]]).astype(int),(w[0]*(patch_size*mpp/base_mpp)+add_y[w[0]]).astype(int))) 220 | 221 | #connectivity 222 | if overlap > 1: 223 | if oversample: 224 | extra = addoverlap(w, grid, overlap, patch_size, base_mpp, base_mpp, img, offset=offset) 225 | grid.extend(extra) 226 | else: 227 | extra = addoverlap(w, grid, overlap, patch_size, mpp, base_mpp, img) 228 | grid.extend(extra) 229 | 230 | # center pixel offset 231 | if centerpixel: 232 | offset = int(mpp / base_mpp * patch_size // 2) 233 | grid = [(x[0] + offset, x[1] + offset) for x in grid] 234 | 235 | #prune squares 236 | if prune: 237 | level, mult = 
find_level(slide,mpp,base_mpp) 238 | psize = int(patch_size*mult) 239 | truegrid = [] 240 | for tup in grid: 241 | reg = slide.read_region(tup,level,(psize,psize)) 242 | if mult != 1: 243 | reg = reg.resize((224,224),Image.BILINEAR) 244 | reg = image2array(reg) 245 | if is_sample(reg,th/255,0.2,0.4,0.5): 246 | truegrid.append(tup) 247 | else: 248 | truegrid = grid 249 | 250 | #sample if maxn 251 | if maxn: 252 | truegrid = random.sample(truegrid, min(maxn, len(truegrid))) 253 | 254 | return truegrid 255 | 256 | def make_hires_map(slide, pred, grid, patch_size, mpp, maxmpp, overlap): 257 | ''' 258 | Given the list of predictions and the known overlap it gives the hires probability map 259 | ''' 260 | W = slide.dimensions[0] 261 | H = slide.dimensions[1] 262 | w = int(np.round(W*1./(patch_size*mpp/maxmpp))) 263 | h = int(np.round(H*1./(patch_size*mpp/maxmpp))) 264 | 265 | newimg = np.zeros((h*overlap,w*overlap))-1 266 | offset_x = np.floor((W*1./(patch_size*mpp/maxmpp)-w)*(patch_size*mpp/maxmpp)) 267 | offset_y = np.floor((H*1./(patch_size*mpp/maxmpp)-h)*(patch_size*mpp/maxmpp)) 268 | add_x = np.linspace(0,offset_x,w).astype(int) 269 | add_y = np.linspace(0,offset_y,h).astype(int) 270 | for i,(xgrid,ygrid) in enumerate(grid): 271 | yindx = int(ygrid/(patch_size*mpp/maxmpp)) 272 | xindx = int(xgrid/(patch_size*mpp/maxmpp)) 273 | y = np.round((ygrid-add_y[yindx])*overlap/(patch_size*mpp/maxmpp)).astype(int) 274 | x = np.round((xgrid-add_x[xindx])*overlap/(patch_size*mpp/maxmpp)).astype(int) 275 | newimg[y,x] = pred[i] 276 | return newimg 277 | 278 | def make_hires_map_stride(slide, pred, grid, stride): 279 | ''' 280 | Given the list of predictions and the stride it gives the hires probability map 281 | Grid ndarray specify the center pixel of a tile 282 | ''' 283 | W = slide.dimensions[0] 284 | H = slide.dimensions[1] 285 | w = int(round(W*1./stride)) 286 | h = int(round(H*1./stride)) 287 | 288 | # Scale grid to pixels 289 | ngrid = np.floor(grid.astype(float) / stride).astype(int) 290 | 291 | # Make image 292 | newimg = np.zeros((h,w))-2 293 | 294 | # Add tissue 295 | tissue = threshold_stride(slide, stride) 296 | newimg[tissue>0] = -1 297 | 298 | # paint predictions 299 | for i in range(len(ngrid)): 300 | x, y = ngrid[i] 301 | newimg[y,x] = pred[i] 302 | 303 | return newimg 304 | 305 | def threshold_stride(slide, stride): 306 | W = slide.dimensions[0] 307 | H = slide.dimensions[1] 308 | w = int(np.ceil(W*1./stride)) 309 | h = int(np.ceil(H*1./stride)) 310 | thumbnail = slide.get_thumbnail((w,h)) 311 | thumbnail = thumbnail.resize((w,h)) 312 | img = image2array(thumbnail) 313 | #calc std on color image 314 | std = np.std(img,axis=-1) 315 | #image to bw 316 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 317 | ## remove black dots ## 318 | _,tmp = cv2.threshold(img,20,255,cv2.THRESH_BINARY_INV) 319 | kernel = np.ones((5,5),np.uint8) 320 | tmp = cv2.dilate(tmp,kernel,iterations = 1) 321 | img[tmp==255] = 255 322 | img = cv2.GaussianBlur(img,(5,5),0) 323 | t,img = cv2.threshold(img,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU) 324 | img = 255-img 325 | img[std<5] = 0 326 | return img 327 | 328 | def plot_extraction(slide, patch_size=224, mpp=0.5, min_cc_size=10, max_ratio_size=10, dilate=False, erode=False, prune=False, overlap=1, maxn=None, bmp=None, oversample=False, mult=1, base_mpp=None, save=''): 329 | '''Script that shows the result of applying the detector in case you get weird results''' 330 | import matplotlib.pyplot as plt 331 | import matplotlib.patches as patches 332 | downsample = 50. 
333 | dpi = 100 334 | if base_mpp is None: 335 | base_mpp = isyntax_mpp 336 | 337 | if save: 338 | plt.switch_backend('agg') 339 | 340 | grid = make_sample_grid(slide, patch_size, mpp=mpp, min_cc_size=min_cc_size, max_ratio_size=max_ratio_size, dilate=dilate, erode=erode, prune=prune, overlap=overlap, maxn=maxn, bmp=bmp, oversample=oversample, mult=mult, base_mpp=base_mpp) 341 | thumb = slide.get_thumbnail((np.round(slide.dimensions[0]/downsample),np.round(slide.dimensions[1]/downsample))) 342 | width, height = thumb.size 343 | 344 | ps = [] 345 | for tup in grid: 346 | ps.append(patches.Rectangle( 347 | (tup[0]/downsample, tup[1]/downsample), patch_size/downsample*(mpp/base_mpp), patch_size/downsample*(mpp/base_mpp), fill=False, 348 | edgecolor="red" 349 | )) 350 | 351 | fig = plt.figure(figsize=(int(width/dpi), int(height/dpi)), dpi=dpi) 352 | ax = fig.add_subplot(111, aspect='equal') 353 | ax.imshow(thumb) 354 | for p in ps: 355 | ax.add_patch(p) 356 | if save: 357 | plt.savefig(save) 358 | else: 359 | plt.show() 360 | 361 | def detect_marker(thumb, mult): 362 | ksize = int(max(1, mult)) 363 | #ksize = 1 364 | img = cv2.GaussianBlur(thumb, (5,5), 0) 365 | hsv_origimg = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) 366 | # Extract marker 367 | black_marker = cv2.inRange(hsv_origimg, np.array([0, 0, 0]), np.array([180, 255, 125])) # black marker 368 | blue_marker = cv2.inRange(hsv_origimg, np.array([90, 30, 30]), np.array([130, 255, 255])) # blue marker 369 | green_marker = cv2.inRange(hsv_origimg, np.array([40, 30, 30]), np.array([90, 255, 255])) # green marker 370 | mask_hsv = cv2.bitwise_or(cv2.bitwise_or(black_marker, blue_marker), green_marker) 371 | mask_hsv = cv2.erode(mask_hsv, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize,ksize))) 372 | mask_hsv = cv2.dilate(mask_hsv, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize*3,ksize*3))) 373 | if np.count_nonzero(mask_hsv) > 0: 374 | return mask_hsv 375 | else: 376 | return None 377 | 378 | #kernel=np.ones((5,5),np.uint8) 379 | #cimg=cv2.morphologyEx(th2, cv2.MORPH_OPEN, kernel) 380 | 381 | #out=[] 382 | #for tup in grid: 383 | # out.append(sum([1 if x==tup else 0 for x in grid])) 384 | #np.unique(np.array(out)) 385 | #TESTS 386 | #from SlideTileExtractor import extract_tissue 387 | #import openslide 388 | #slide = openslide.OpenSlide('/lila/data/fuchs/projects/lung/impacted/1142892.svs') 389 | #extract_tissue.plot_extraction(slide, patch_size=224, mpp=0.5, power=None, min_cc_size=10, max_ratio_size=10, dilate=False, erode=False, prune=False, overlap=1, maxn=None, bmp=None, oversample=False, mult=1, save='') 390 | -------------------------------------------------------------------------------- /code/feature_extraction/make_tensors.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import openslide 5 | import torch 6 | import torch.nn as nn 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | import PIL.Image as Image 10 | import resnet_custom 11 | import pdb 12 | import encoders 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--slide_data', type=str, default='', help='path to slide data') 17 | parser.add_argument('--tile_data', type=str, default='', help='path to tile data') 18 | parser.add_argument('--encoder', type=str, default='', choices=[ 19 | 'tres50_imagenet', 20 | 'ctranspath', 21 | 'phikon', 22 | 'uni', 23 | 'gigapath', 24 | 'virchow', 25 | 'h-optimus-0', 
26 |     'dinosmall',
27 |     'dinobase',
28 | ], help='choice of encoder')
29 | parser.add_argument('--tilesize', type=int, default=224, help='tile size')
30 | parser.add_argument('--bsize', type=int, default=128, help='batch size')
31 | parser.add_argument('--workers', type=int, default=10, help='workers')
32 | 
33 | class slide_dataset(data.Dataset):
34 |     def __init__(self, slide, df, trans, tilesize):
35 |         self.slide = slide
36 |         self.df = df
37 |         self.tilesize = tilesize
38 |         self.trans = trans
39 |     def __getitem__(self, index):
40 |         row = self.df.iloc[index]
41 |         size = int(np.round(self.tilesize * row.mult))
42 |         img = self.slide.read_region((int(row.x), int(row.y)), int(row.level), (size, size)).convert('RGB')
43 |         if row.mult != 1:
44 |             img = img.resize((self.tilesize, self.tilesize), Image.LANCZOS)
45 |         img = self.trans(img)
46 |         return img
47 |     def __len__(self):
48 |         return len(self.df)
49 | 
50 | def main():
51 |     '''
52 |     slide_data has columns:
53 |     - slide_path: full path to slide
54 |     - slide: unique slide identifier
55 |     - tensor_root: root directory for the output tensors; the encoder name is appended to it
56 |     - tensor_name: name of tensor file without path
57 |     tile_data has columns:
58 |     - slide: unique slide identifier
59 |     - x: x coord
60 |     - y: y coord
61 |     - level: pyramid level at which to extract data
62 |     - mult: factor for tile resize
63 |     '''
64 |     global args
65 |     args = parser.parse_args()
66 | 
67 |     # Set up encoder
68 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
69 |     model, transform, ndim = encoders.get_encoder(args.encoder)
70 |     model.to(device)
71 | 
72 |     # Set up data
73 |     master = pd.read_csv(args.slide_data)
74 |     tiles = pd.read_csv(args.tile_data)
75 | 
76 |     # Output directory
77 |     row = master.iloc[0]
78 |     if not os.path.exists(row.tensor_root):
79 |         os.mkdir(row.tensor_root)
80 | 
81 |     if not os.path.exists(os.path.join(row.tensor_root, args.encoder)):
82 |         os.mkdir(os.path.join(row.tensor_root, args.encoder))
83 | 
84 |     # Iterate dataset
85 |     with torch.no_grad():
86 |         for i, row in master.iterrows():
87 |             print(f'[{i+1}]/[{len(master)}]', end='\r')
88 | 
89 |             tensor_name = os.path.join(row.tensor_root, args.encoder, row.tensor_name)
90 | 
91 |             if not os.path.exists(tensor_name):
92 |                 # Set up slide
93 |                 slide = openslide.OpenSlide(row.slide_path)
94 | 
95 |                 # Get coords
96 |                 grid = tiles[tiles.slide == row.slide].reset_index(drop=True)
97 | 
98 |                 # Set up dataset and loader
99 |                 dset = slide_dataset(slide, grid, transform, args.tilesize)
100 |                 loader = torch.utils.data.DataLoader(dset, batch_size=args.bsize, shuffle=False, num_workers=args.workers)
101 | 
102 |                 # Save tensor
103 |                 tensor = torch.zeros(len(grid), ndim).float()
104 |                 for j, img in enumerate(loader):
105 |                     out = model(img.to(device))
106 |                     tensor[j*args.bsize:j*args.bsize+img.size(0),:] = out.detach().clone()
107 | 
108 |                 torch.save(tensor, tensor_name)
109 | 
110 |     print('')
111 | 
112 | if __name__ == '__main__':
113 |     main()
114 | 
--------------------------------------------------------------------------------
/code/feature_extraction/optimus.py:
--------------------------------------------------------------------------------
1 | '''https://github.com/bioptimus/releases/tree/main/models/h-optimus/v0?utm_source=owkin&utm_medium=referral&utm_campaign=h-bioptimus-o
2 | import functools
3 | 
4 | import timm
5 | import torch
6 | from torchvision import transforms
7 | 
8 | 
9 | PATH_TO_CHECKPOINT = "" # Path to the downloaded checkpoint. 
10 | 11 | params = { 12 | 'patch_size': 14, 13 | 'embed_dim': 1536, 14 | 'depth': 40, 15 | 'num_heads': 24, 16 | 'init_values': 1e-05, 17 | 'mlp_ratio': 5.33334, 18 | 'mlp_layer': functools.partial( 19 | timm.layers.mlp.GluMlp, act_layer=torch.nn.modules.activation.SiLU, gate_last=False 20 | ), 21 | 'act_layer': torch.nn.modules.activation.SiLU, 22 | 'reg_tokens': 4, 23 | 'no_embed_class': True, 24 | 'img_size': 224, 25 | 'num_classes': 0, 26 | 'in_chans': 3 27 | } 28 | 29 | model = timm.models.VisionTransformer(**params) 30 | model.load_state_dict(torch.load(PATH_TO_CHECKPOINT, map_location="cpu")) 31 | model.eval() 32 | model.to("cuda") 33 | 34 | transform = transforms.Compose([ 35 | transforms.ToTensor(), 36 | transforms.Normalize( 37 | mean=(0.707223, 0.578729, 0.703617), 38 | std=(0.211883, 0.230117, 0.177517) 39 | ), 40 | ]) 41 | 42 | input = torch.rand(3, 224, 224) 43 | input = transforms.ToPILImage()(input) 44 | 45 | # We recommend using mixed precision for faster inference. 46 | with torch.autocast(device_type="cuda", dtype=torch.float16): 47 | with torch.inference_mode(): 48 | features = model(transform(input).unsqueeze(0).to("cuda")) 49 | 50 | assert features.shape == (1, 1536) 51 | ''' 52 | import functools 53 | import timm 54 | import torch 55 | 56 | def get_model(arch): 57 | if arch == 'h-optimus-0': 58 | params = { 59 | 'patch_size': 14, 60 | 'embed_dim': 1536, 61 | 'depth': 40, 62 | 'num_heads': 24, 63 | 'init_values': 1e-05, 64 | 'mlp_ratio': 5.33334, 65 | 'mlp_layer': functools.partial( 66 | timm.layers.mlp.GluMlp, act_layer=torch.nn.modules.activation.SiLU, gate_last=False 67 | ), 68 | 'act_layer': torch.nn.modules.activation.SiLU, 69 | 'reg_tokens': 4, 70 | 'no_embed_class': True, 71 | 'img_size': 224, 72 | 'num_classes': 0, 73 | 'in_chans': 3 74 | } 75 | model = timm.models.VisionTransformer(**params) 76 | model.load_state_dict(torch.load('path/to/checkpoint.pth', map_location="cpu")) 77 | 78 | return model 79 | -------------------------------------------------------------------------------- /code/feature_extraction/phikon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import AutoImageProcessor, ViTModel 4 | 5 | class phikon_encoder(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | self.vit = ViTModel.from_pretrained("owkin/phikon", add_pooling_layer=False) 9 | 10 | def forward(self, x): 11 | outputs = self.vit(x) 12 | features = outputs.last_hidden_state[:, 0, :] 13 | return features 14 | 15 | def get_model(): 16 | return phikon_encoder() 17 | -------------------------------------------------------------------------------- /code/feature_extraction/resnet_custom.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152'] 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | class Bottleneck_Baseline(nn.Module): 18 | expansion = 4 19 | 20 | 
def __init__(self, inplanes, planes, stride=1, downsample=None): 21 | super(Bottleneck_Baseline, self).__init__() 22 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 25 | padding=1, bias=False) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.downsample = downsample 31 | self.stride = stride 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out = self.conv1(x) 37 | out = self.bn1(out) 38 | out = self.relu(out) 39 | 40 | out = self.conv2(out) 41 | out = self.bn2(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv3(out) 45 | out = self.bn3(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | class ResNet_Baseline(nn.Module): 56 | 57 | def __init__(self, block, layers, nclass=2): 58 | self.inplanes = 64 59 | super(ResNet_Baseline, self).__init__() 60 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 61 | bias=False) 62 | self.bn1 = nn.BatchNorm2d(64) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = self._make_layer(block, 64, layers[0]) 66 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 67 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 68 | self.avgpool = nn.AdaptiveAvgPool2d(1) 69 | self.classifier = nn.Linear(1024, nclass) 70 | 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 74 | elif isinstance(m, nn.BatchNorm2d): 75 | nn.init.constant_(m.weight, 1) 76 | nn.init.constant_(m.bias, 0) 77 | 78 | def _make_layer(self, block, planes, blocks, stride=1): 79 | downsample = None 80 | if stride != 1 or self.inplanes != planes * block.expansion: 81 | downsample = nn.Sequential( 82 | nn.Conv2d(self.inplanes, planes * block.expansion, 83 | kernel_size=1, stride=stride, bias=False), 84 | nn.BatchNorm2d(planes * block.expansion), 85 | ) 86 | 87 | layers = [] 88 | layers.append(block(self.inplanes, planes, stride, downsample)) 89 | self.inplanes = planes * block.expansion 90 | for i in range(1, blocks): 91 | layers.append(block(self.inplanes, planes)) 92 | 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | x = self.conv1(x) 97 | x = self.bn1(x) 98 | x = self.relu(x) 99 | x = self.maxpool(x) 100 | 101 | x = self.layer1(x) 102 | x = self.layer2(x) 103 | x = self.layer3(x) 104 | 105 | x = self.avgpool(x) 106 | x = x.view(x.size(0), -1) 107 | 108 | #x = self.classifier(x) Commented out to get the feature embedding 109 | 110 | return x 111 | 112 | def resnet50_baseline(pretrained=False): 113 | """Constructs a Modified ResNet-50 model. 
114 | Args: 115 | pretrained (bool): If True, returns a model pre-trained on ImageNet 116 | """ 117 | model = ResNet_Baseline(Bottleneck_Baseline, [3, 4, 6, 3]) 118 | if pretrained: 119 | model = load_pretrained_weights(model, 'resnet50') 120 | return model 121 | 122 | def load_pretrained_weights(model, name): 123 | pretrained_dict = model_zoo.load_url(model_urls[name]) 124 | model.load_state_dict(pretrained_dict, strict=False) 125 | return model 126 | -------------------------------------------------------------------------------- /code/feature_extraction/uni.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import timm 4 | 5 | def get_model(): 6 | model = timm.create_model( 7 | "vit_large_patch16_224", img_size=224, patch_size=16, init_values=1e-5, num_classes=0, dynamic_img_size=True 8 | ) 9 | model.load_state_dict(torch.load('path/to/checkpoint.pth', map_location="cpu"), strict=True) 10 | return model 11 | -------------------------------------------------------------------------------- /code/feature_extraction/virchow.py: -------------------------------------------------------------------------------- 1 | import os 2 | import timm 3 | import torch 4 | import torch.nn as nn 5 | from timm.data import resolve_data_config 6 | from timm.data.transforms_factory import create_transform 7 | from timm.layers import SwiGLUPacked 8 | from PIL import Image 9 | from huggingface_hub import login, hf_hub_download 10 | 11 | class virchow(nn.Module): 12 | def __init__(self): 13 | super(virchow, self).__init__() 14 | self.virchow = timm.create_model("hf-hub:paige-ai/Virchow", pretrained=True, mlp_layer=SwiGLUPacked, act_layer=torch.nn.SiLU) 15 | 16 | def forward(self, x): 17 | output = self.virchow(x) # size: 1 x 257 x 1280 18 | class_token = output[:, 0] # size: 1 x 1280 19 | patch_tokens = output[:, 1:] # size: 1 x 256 x 1280 20 | # concatenate class token and average pool of patch tokens 21 | embedding = torch.cat([class_token, patch_tokens.mean(1)], dim=-1) # size: 1 x 2560 22 | return embedding 23 | -------------------------------------------------------------------------------- /code/feature_extraction/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Mostly copy-paste from timm library. 
16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17 | """
18 | import math
19 | import warnings
20 | from functools import partial
21 | import torch
22 | import torch.nn as nn
23 | 
24 | 
25 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
26 |     # Cut & paste from PyTorch official master until it's in a few official releases - RW
27 |     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
28 |     def norm_cdf(x):
29 |         # Computes standard normal cumulative distribution function
30 |         return (1. + math.erf(x / math.sqrt(2.))) / 2.
31 | 
32 |     if (mean < a - 2 * std) or (mean > b + 2 * std):
33 |         warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
34 |                       "The distribution of values may be incorrect.",
35 |                       stacklevel=2)
36 | 
37 |     with torch.no_grad():
38 |         # Values are generated by using a truncated uniform distribution and
39 |         # then using the inverse CDF for the normal distribution.
40 |         # Get upper and lower cdf values
41 |         l = norm_cdf((a - mean) / std)
42 |         u = norm_cdf((b - mean) / std)
43 | 
44 |         # Uniformly fill tensor with values from [l, u], then translate to
45 |         # [2l-1, 2u-1].
46 |         tensor.uniform_(2 * l - 1, 2 * u - 1)
47 | 
48 |         # Use inverse cdf transform for normal distribution to get truncated
49 |         # standard normal
50 |         tensor.erfinv_()
51 | 
52 |         # Transform to proper mean, std
53 |         tensor.mul_(std * math.sqrt(2.))
54 |         tensor.add_(mean)
55 | 
56 |         # Clamp to ensure it's in the proper range
57 |         tensor.clamp_(min=a, max=b)
58 |         return tensor
59 | 
60 | 
61 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
62 |     # type: (Tensor, float, float, float, float) -> Tensor
63 |     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
64 | 
65 | def drop_path(x, drop_prob: float = 0., training: bool = False):
66 |     if drop_prob == 0. or not training:
67 |         return x
68 |     keep_prob = 1 - drop_prob
69 |     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
70 |     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
71 |     random_tensor.floor_()  # binarize
72 |     output = x.div(keep_prob) * random_tensor
73 |     return output
74 | 
75 | 
76 | class DropPath(nn.Module):
77 |     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
78 | """ 79 | def __init__(self, drop_prob=None): 80 | super(DropPath, self).__init__() 81 | self.drop_prob = drop_prob 82 | 83 | def forward(self, x): 84 | return drop_path(x, self.drop_prob, self.training) 85 | 86 | 87 | class Mlp(nn.Module): 88 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 89 | super().__init__() 90 | out_features = out_features or in_features 91 | hidden_features = hidden_features or in_features 92 | self.fc1 = nn.Linear(in_features, hidden_features) 93 | self.act = act_layer() 94 | self.fc2 = nn.Linear(hidden_features, out_features) 95 | self.drop = nn.Dropout(drop) 96 | 97 | def forward(self, x): 98 | x = self.fc1(x) 99 | x = self.act(x) 100 | x = self.drop(x) 101 | x = self.fc2(x) 102 | x = self.drop(x) 103 | return x 104 | 105 | 106 | class Attention(nn.Module): 107 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 108 | super().__init__() 109 | self.num_heads = num_heads 110 | head_dim = dim // num_heads 111 | self.scale = qk_scale or head_dim ** -0.5 112 | 113 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 114 | self.attn_drop = nn.Dropout(attn_drop) 115 | self.proj = nn.Linear(dim, dim) 116 | self.proj_drop = nn.Dropout(proj_drop) 117 | 118 | def forward(self, x): 119 | B, N, C = x.shape 120 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 121 | q, k, v = qkv[0], qkv[1], qkv[2] 122 | 123 | attn = (q @ k.transpose(-2, -1)) * self.scale 124 | attn = attn.softmax(dim=-1) 125 | attn = self.attn_drop(attn) 126 | 127 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 128 | x = self.proj(x) 129 | x = self.proj_drop(x) 130 | return x, attn 131 | 132 | 133 | class Block(nn.Module): 134 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 135 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 136 | super().__init__() 137 | self.norm1 = norm_layer(dim) 138 | self.attn = Attention( 139 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 140 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 141 | self.norm2 = norm_layer(dim) 142 | mlp_hidden_dim = int(dim * mlp_ratio) 143 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 144 | 145 | def forward(self, x, return_attention=False): 146 | y, attn = self.attn(self.norm1(x)) 147 | if return_attention: 148 | return attn 149 | x = x + self.drop_path(y) 150 | x = x + self.drop_path(self.mlp(self.norm2(x))) 151 | return x 152 | 153 | 154 | class PatchEmbed(nn.Module): 155 | """ Image to Patch Embedding 156 | """ 157 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 158 | super().__init__() 159 | num_patches = (img_size // patch_size) * (img_size // patch_size) 160 | self.img_size = img_size 161 | self.patch_size = patch_size 162 | self.num_patches = num_patches 163 | 164 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 165 | 166 | def forward(self, x): 167 | B, C, H, W = x.shape 168 | x = self.proj(x).flatten(2).transpose(1, 2) 169 | return x 170 | 171 | 172 | class VisionTransformer(nn.Module): 173 | """ Vision Transformer """ 174 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 175 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 176 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 177 | super().__init__() 178 | self.num_features = self.embed_dim = embed_dim 179 | 180 | self.patch_embed = PatchEmbed( 181 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 182 | num_patches = self.patch_embed.num_patches 183 | 184 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 185 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 186 | self.pos_drop = nn.Dropout(p=drop_rate) 187 | 188 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 189 | self.blocks = nn.ModuleList([ 190 | Block( 191 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 192 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 193 | for i in range(depth)]) 194 | self.norm = norm_layer(embed_dim) 195 | 196 | # Classifier head 197 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 198 | 199 | trunc_normal_(self.pos_embed, std=.02) 200 | trunc_normal_(self.cls_token, std=.02) 201 | self.apply(self._init_weights) 202 | 203 | def _init_weights(self, m): 204 | if isinstance(m, nn.Linear): 205 | trunc_normal_(m.weight, std=.02) 206 | if isinstance(m, nn.Linear) and m.bias is not None: 207 | nn.init.constant_(m.bias, 0) 208 | elif isinstance(m, nn.LayerNorm): 209 | nn.init.constant_(m.bias, 0) 210 | nn.init.constant_(m.weight, 1.0) 211 | 212 | def interpolate_pos_encoding(self, x, w, h): 213 | npatch = x.shape[1] - 1 214 | N = self.pos_embed.shape[1] - 1 215 | if npatch == N and w == h: 216 | return self.pos_embed 217 | class_pos_embed = self.pos_embed[:, 0] 218 | patch_pos_embed = self.pos_embed[:, 1:] 219 | dim = x.shape[-1] 220 | w0 = w // self.patch_embed.patch_size 221 | h0 = h // self.patch_embed.patch_size 222 | # we add a small number to avoid floating point error in the interpolation 223 | # see discussion at https://github.com/facebookresearch/dino/issues/8 224 | w0, h0 = w0 + 0.1, h0 + 0.1 225 | patch_pos_embed = nn.functional.interpolate( 226 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), 
dim).permute(0, 3, 1, 2), 227 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 228 | mode='bicubic', 229 | ) 230 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 231 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 232 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 233 | 234 | def prepare_tokens(self, x): 235 | B, nc, w, h = x.shape 236 | x = self.patch_embed(x) # patch linear embedding 237 | 238 | # add the [CLS] token to the embed patch tokens 239 | cls_tokens = self.cls_token.expand(B, -1, -1) 240 | x = torch.cat((cls_tokens, x), dim=1) 241 | 242 | # add positional encoding to each token 243 | x = x + self.interpolate_pos_encoding(x, w, h) 244 | 245 | return self.pos_drop(x) 246 | 247 | def forward(self, x): 248 | x = self.prepare_tokens(x) 249 | for blk in self.blocks: 250 | x = blk(x) 251 | x = self.norm(x) 252 | return x[:, 0] 253 | 254 | def get_last_selfattention(self, x): 255 | x = self.prepare_tokens(x) 256 | for i, blk in enumerate(self.blocks): 257 | if i < len(self.blocks) - 1: 258 | x = blk(x) 259 | else: 260 | # return attention of the last block 261 | return blk(x, return_attention=True) 262 | 263 | def get_intermediate_layers(self, x, n=1): 264 | x = self.prepare_tokens(x) 265 | # we return the output tokens from the `n` last blocks 266 | output = [] 267 | for i, blk in enumerate(self.blocks): 268 | x = blk(x) 269 | if len(self.blocks) - i <= n: 270 | output.append(self.norm(x)) 271 | return output 272 | 273 | 274 | def vit_tiny(patch_size=16, **kwargs): 275 | model = VisionTransformer( 276 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, 277 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 278 | return model 279 | 280 | 281 | def vit_small(patch_size=16, **kwargs): 282 | model = VisionTransformer( 283 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 284 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 285 | return model 286 | 287 | 288 | def vit_base(patch_size=16, **kwargs): 289 | model = VisionTransformer( 290 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 291 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 292 | return model 293 | 294 | 295 | class DINOHead(nn.Module): 296 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): 297 | super().__init__() 298 | nlayers = max(nlayers, 1) 299 | if nlayers == 1: 300 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 301 | else: 302 | layers = [nn.Linear(in_dim, hidden_dim)] 303 | if use_bn: 304 | layers.append(nn.BatchNorm1d(hidden_dim)) 305 | layers.append(nn.GELU()) 306 | for _ in range(nlayers - 2): 307 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 308 | if use_bn: 309 | layers.append(nn.BatchNorm1d(hidden_dim)) 310 | layers.append(nn.GELU()) 311 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 312 | self.mlp = nn.Sequential(*layers) 313 | self.apply(self._init_weights) 314 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 315 | self.last_layer.weight_g.data.fill_(1) 316 | if norm_last_layer: 317 | self.last_layer.weight_g.requires_grad = False 318 | 319 | def _init_weights(self, m): 320 | if isinstance(m, nn.Linear): 321 | trunc_normal_(m.weight, std=.02) 322 | if isinstance(m, nn.Linear) and m.bias is not None: 323 | nn.init.constant_(m.bias, 0) 324 | 325 | def 
forward(self, x):
326 |         x = self.mlp(x)
327 |         x = nn.functional.normalize(x, dim=-1, p=2)
328 |         x = self.last_layer(x)
329 |         return x
330 | 
--------------------------------------------------------------------------------
/code/gma_training/README.md:
--------------------------------------------------------------------------------
1 | # Training Script
2 | 
3 | Code to train a Gated-MIL Attention (GMA) aggregation model for all benchmarking tasks.
4 | Usage (`--data` is one of the 20 benchmark tasks; `--mccv` selects one of the 20 Monte Carlo cross-validation splits):
5 | ```
6 | python train.py \
7 |     --output output/directory \
8 |     --data benchmark_task \
9 |     --encoder foundation_model \
10 |     --mccv 1 \
11 |     --lr 0.0001
12 | ```
13 | The script writes a log file named `convergence.csv` containing the training loss and validation AUC for each epoch.
14 | 
--------------------------------------------------------------------------------
/code/gma_training/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | import torch
5 | 
6 | def get_datasets(mccv=0, data='', encoder=''):
7 |     # Load slide data
8 |     df = pd.read_csv(os.path.join('root/data/directory', data, 'slide_data.csv'))
9 |     df['tensor_path'] = [os.path.join(x.tensor_root, encoder, x.tensor_name) for _, x in df.iterrows()]
10 |     # Select mccv and clean
11 |     df = df.rename(columns={'mccv{}'.format(mccv):'mccvsplit'})[['slide','target','mccvsplit','tensor_path']]
12 |     # Split into train and val
13 |     df_train = df[df.mccvsplit=='train'].reset_index(drop=True).drop(columns=['mccvsplit'])
14 |     df_val = df[df.mccvsplit=='val'].reset_index(drop=True).drop(columns=['mccvsplit'])
15 |     # Create my loader objects
16 |     dset_train = slide_dataset_classification(df_train)
17 |     dset_val = slide_dataset_classification(df_val)
18 |     return dset_train, dset_val
19 | 
20 | class slide_dataset_classification(object):
21 |     '''
22 |     Slide level dataset which returns for each slide the feature matrix (h) and the target
23 |     '''
24 |     def __init__(self, df):
25 |         self.df = df
26 | 
27 |     def __len__(self):
28 |         # number of slides
29 |         return len(self.df)
30 | 
31 |     def __getitem__(self, index):
32 |         row = self.df.iloc[index]
33 |         # get the feature matrix for that slide
34 |         h = torch.load(row.tensor_path)
35 |         # get the target
36 |         return h, row.target
37 | 
--------------------------------------------------------------------------------
/code/gma_training/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | 
6 | class Attn_Net_Gated(nn.Module):
7 | 
8 |     def __init__(self, L = 1024, D = 256, dropout = False, n_tasks = 1):
9 |         super(Attn_Net_Gated, self).__init__()
10 |         self.attention_a = [
11 |             nn.Linear(L, D),
12 |             nn.Tanh()]
13 | 
14 |         self.attention_b = [nn.Linear(L, D),
15 |                             nn.Sigmoid()]
16 |         if dropout:
17 |             self.attention_a.append(nn.Dropout(0.25))
18 |             self.attention_b.append(nn.Dropout(0.25))
19 | 
20 |         self.attention_a = nn.Sequential(*self.attention_a)
21 |         self.attention_b = nn.Sequential(*self.attention_b)
22 | 
23 |         self.attention_c = nn.Linear(D, n_tasks)
24 | 
25 |     def forward(self, x):
26 |         a = self.attention_a(x)
27 |         b = self.attention_b(x)
28 |         A = a.mul(b)
29 |         A = self.attention_c(A) # N x n_classes
30 |         return A, x
31 | 
32 | class GMA(nn.Module):
33 |     def __init__(self, ndim=1024, gate = True, size_arg = "big", dropout = False, n_classes = 1, n_tasks=1):
34 | 
super(GMA, self).__init__() 35 | self.size_dict = {"small": [ndim, 512, 256], "big": [ndim, 512, 384]} 36 | size = self.size_dict[size_arg] 37 | 38 | fc = [nn.Linear(size[0], size[1]), nn.ReLU()] 39 | if dropout: 40 | fc.append(nn.Dropout(0.25)) 41 | fc.extend([nn.Linear(size[1], size[1]), nn.ReLU()]) 42 | if dropout: 43 | fc.append(nn.Dropout(0.25)) 44 | attention_net = Attn_Net_Gated(L = size[1], D = size[2], dropout = dropout, n_tasks = 1) 45 | 46 | fc.append(attention_net) 47 | self.attention_net = nn.Sequential(*fc) 48 | self.classifier = nn.Linear(size[1], n_classes) 49 | 50 | initialize_weights(self) 51 | 52 | def get_sign(self, h): 53 | A, h = self.attention_net(h)# h: Bx512 54 | w = self.classifier.weight.detach() 55 | sign = torch.mm(h, w.t()) 56 | return sign 57 | 58 | def forward(self, h, attention_only=False): 59 | A, h = self.attention_net(h) 60 | A = torch.transpose(A, 1, 0) 61 | if attention_only: 62 | return A[0] 63 | 64 | A_raw = A.detach().cpu().numpy()[0] 65 | w = self.classifier.weight.detach() 66 | sign = torch.mm(h.detach(), w.t()).cpu().numpy() 67 | 68 | A = F.softmax(A, dim=1) 69 | M = torch.mm(A, h) 70 | 71 | logits = self.classifier(M) 72 | return A_raw, sign, logits 73 | 74 | def initialize_weights(module): 75 | for m in module.modules(): 76 | if isinstance(m, nn.Linear): 77 | nn.init.xavier_normal_(m.weight) 78 | m.bias.data.zero_() 79 | -------------------------------------------------------------------------------- /code/gma_training/train.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import modules 3 | import os 4 | import argparse 5 | import torch.backends.cudnn as cudnn 6 | import torch 7 | import torch.optim as optim 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import numpy as np 11 | import time 12 | import pdb 13 | from sklearn.metrics import roc_auc_score 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | #I/O PARAMS 18 | parser.add_argument('--output', type=str, default='.', help='name of output directory') 19 | parser.add_argument('--log', type=str, default='convergence.csv', help='name of log file') 20 | parser.add_argument('--data', type=str, default='', choices=[ 21 | 'msk_lung_egfr', 22 | 'msk_lung_tp53', 23 | 'msk_lung_kras', 24 | 'msk_lung_stk11', 25 | 'msk_lung_alk', 26 | 'msk_lung_io', 27 | 'sinai_breast_cancer', 28 | 'sinai_breast_er', 29 | 'sinai_breast_pr', 30 | 'sinai_breast_her2', 31 | 'sinai_lung_egfr', 32 | 'sinai_ibd_detection', 33 | 'sinai_bladder_cancer', 34 | 'sinai_colorectal_cancer', 35 | 'sinai_dcis_cancer', 36 | 'sinai_kidney_cancer', 37 | 'sinai_oral_cancer', 38 | 'sinai_prostate_cancer', 39 | 'sinai_thyroid_cancer', 40 | 'biome_breast_hrd' 41 | ], help='which data to use') 42 | parser.add_argument('--encoder', type=str, default='', choices=[ 43 | 'tres50_imagenet', 44 | 'ctranspath', 45 | 'phikon', 46 | 'uni', 47 | 'virchow', 48 | 'gigapath', 49 | 'dinosmall', 50 | 'dinobase' 51 | ], help='which encoder to use') 52 | parser.add_argument('--mccv', default=1, type=int, choices=list(range(1,21)), help='which seed (default: 1/20)') 53 | #OPTIMIZATION PARAMS 54 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)') 55 | parser.add_argument("--lr", default=0.0005, type=float, help="""Learning rate at the end of linear warmup (highest LR used during training).""") 56 | parser.add_argument('--lr_end', type=float, default=1e-6, help="""Target LR at the end of optimization. 
We use a cosine LR schedule with linear warmup.""") 57 | parser.add_argument("--warmup_epochs", default=10, type=int, help="Number of epochs for the linear learning-rate warm up.") 58 | parser.add_argument('--weight_decay', type=float, default=0.04, help="""Initial value of the weight decay. With ViT, a smaller value at the beginning of training works well.""") 59 | parser.add_argument('--weight_decay_end', type=float, default=0.4, help="""Final value of the weight decay. We use a cosine schedule for WD and using a larger decay by the end of training improves performance for ViTs.""") 60 | parser.add_argument('--nepochs', type=int, default=40, help='number of epochs (default: 40)') 61 | parser.add_argument('--workers', default=10, type=int, help='number of data loading workers (default: 10)') 62 | 63 | def main(): 64 | 65 | # Get user input 66 | global args 67 | args = parser.parse_args() 68 | 69 | # Set datasets 70 | train_dset, val_dset = datasets.get_datasets(mccv=args.mccv, data=args.data, encoder=args.encoder) 71 | train_loader = torch.utils.data.DataLoader(train_dset, batch_size=1, shuffle=True, num_workers=args.workers) 72 | val_loader = torch.utils.data.DataLoader(val_dset, batch_size=1, shuffle=False, num_workers=args.workers) 73 | 74 | # Dim of features 75 | if args.encoder.startswith('tres50'): 76 | args.ndim = 1024 77 | elif args.encoder == 'ctranspath': 78 | args.ndim = 768 79 | elif args.encoder == 'phikon': 80 | args.ndim = 768 81 | elif args.encoder == 'uni': 82 | args.ndim = 1024 83 | elif args.encoder == 'virchow': 84 | args.ndim = 2560 85 | elif args.encoder == 'gigapath': 86 | args.ndim = 1536 87 | elif args.encoder.startswith('dinosmall'): 88 | args.ndim = 384 89 | elif args.encoder.startswith('dinobase'): 90 | args.ndim = 768 91 | 92 | # Get model 93 | gma = modules.GMA(dropout=True, n_classes=2, ndim=args.ndim) 94 | gma.cuda() 95 | 96 | # Set loss 97 | criterion = nn.CrossEntropyLoss().cuda() 98 | 99 | # Set optimizer 100 | params_groups = get_params_groups(gma) 101 | optimizer = optim.AdamW(params_groups) 102 | 103 | # Set schedulers 104 | lr_schedule = cosine_scheduler( 105 | args.lr, 106 | args.lr_end, 107 | args.nepochs, 108 | len(train_loader), 109 | warmup_epochs=args.warmup_epochs, 110 | ) 111 | wd_schedule = cosine_scheduler( 112 | args.weight_decay, 113 | args.weight_decay_end, 114 | args.nepochs, 115 | len(train_loader), 116 | ) 117 | cudnn.benchmark = True 118 | 119 | # Set output files 120 | with open(os.path.join(args.output,args.log), 'w') as fconv: 121 | fconv.write('epoch,metric,value\n') 122 | 123 | # Main training loop 124 | for epoch in range(args.nepochs+1): 125 | 126 | ## Training logic 127 | if epoch > 0: 128 | loss = train(epoch, train_loader, gma, criterion, optimizer, lr_schedule, wd_schedule) 129 | print('Training\tEpoch: [{}/{}]\tLoss: {}'.format(epoch, args.nepochs, loss)) 130 | with open(os.path.join(args.output,args.log), 'a') as fconv: 131 | fconv.write('{},loss,{}\n'.format(epoch, loss)) 132 | 133 | ## Validation logic 134 | probs = test(epoch, val_loader, gma) 135 | auc = roc_auc_score(val_loader.dataset.df.target, probs) 136 | ### Printing stats 137 | print('Validation\tEpoch: [{}/{}]\tAUC: {}'.format(epoch, args.nepochs, auc)) 138 | with open(os.path.join(args.output,args.log), 'a') as fconv: 139 | fconv.write('{},auc,{}\n'.format(epoch, auc)) 140 | 141 | ### Model saving logic 142 | obj = { 143 | 'epoch': epoch, 144 | 'state_dict': gma.state_dict(), 145 | 'auc': auc, 146 | 'optimizer' : optimizer.state_dict() 147 | } 148 | 
torch.save(obj, os.path.join(args.output,'checkpoint.pth')) 149 | 150 | def test(run, loader, model): 151 | # Set model in test mode 152 | model.eval() 153 | # Initialize probability vector 154 | probs = torch.FloatTensor(len(loader)).cuda() 155 | # Loop through batches 156 | with torch.no_grad(): 157 | for i, (input, _) in enumerate(loader): 158 | print('Inference\tEpoch: [{}/{}]\tBatch: [{}/{}]'.format(run, args.nepochs, i+1, len(loader))) 159 | ## Copy batch to GPU 160 | input = input.squeeze(0).cuda() 161 | ## Forward pass 162 | _, _, output = model(input) 163 | output = F.softmax(output, dim=1) 164 | ## Clone output to output vector 165 | probs[i] = output.detach()[:,1].item() 166 | return probs.cpu().numpy() 167 | 168 | def train(run, loader, model, criterion, optimizer, lr_schedule, wd_schedule): 169 | # Set model in training mode 170 | model.train() 171 | # Initialize loss 172 | running_loss = 0. 173 | # Loop through batches 174 | for i, (input, target) in enumerate(loader): 175 | ## Update weight decay and learning rate according to their schedule 176 | it = len(loader) * (run-1) + i # global training iteration 177 | for j, param_group in enumerate(optimizer.param_groups): 178 | param_group["lr"] = lr_schedule[it] 179 | if j == 0: # only the first group is regularized 180 | param_group["weight_decay"] = wd_schedule[it] 181 | 182 | ## Copy to GPU 183 | input = input.squeeze(0).cuda() 184 | target = target.long().cuda() 185 | ## Forward pass 186 | _, _, output = model(input) 187 | ## Calculate loss 188 | loss = criterion(output, target) 189 | ## Optimization step 190 | optimizer.zero_grad() 191 | loss.backward() 192 | optimizer.step() 193 | ## Store loss 194 | running_loss += loss.item() 195 | print('Training\tEpoch: [{}/{}]\tBatch: [{}/{}]\tLoss: {}'.format(run, args.nepochs, i+1, len(loader), loss.item())) 196 | return running_loss / len(loader) 197 | 198 | def get_params_groups(model): 199 | regularized = [] 200 | not_regularized = [] 201 | for name, param in model.named_parameters(): 202 | if not param.requires_grad: 203 | continue 204 | # we do not regularize biases nor Norm parameters 205 | if name.endswith(".bias") or len(param.shape) == 1: 206 | not_regularized.append(param) 207 | else: 208 | regularized.append(param) 209 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}] 210 | 211 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 212 | warmup_schedule = np.array([]) 213 | warmup_iters = warmup_epochs * niter_per_ep 214 | if warmup_epochs > 0: 215 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 216 | 217 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 218 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 219 | 220 | schedule = np.concatenate((warmup_schedule, schedule)) 221 | assert len(schedule) == epochs * niter_per_ep 222 | return schedule 223 | 224 | if __name__ == '__main__': 225 | main() 226 | --------------------------------------------------------------------------------
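The following is an illustrative smoke test, not part of the repository: assuming it is run from `code/gma_training` so that `modules.py` and `train.py` import cleanly, it pushes a random bag of tile features through the GMA aggregator and sanity-checks the cosine learning-rate schedule that `train.py` builds.

```python
# Minimal smoke test (sketch under the assumptions above, not repository code).
import torch
import modules
from train import cosine_scheduler

ndim = 768                    # e.g. ctranspath / phikon feature dimension
bag = torch.randn(500, ndim)  # stand-in for one slide's 500 tile embeddings

gma = modules.GMA(dropout=True, n_classes=2, ndim=ndim)
gma.eval()
with torch.no_grad():
    A_raw, sign, logits = gma(bag)

print(A_raw.shape)   # (500,)   raw attention score per tile
print(sign.shape)    # (500, 2) per-tile projection onto the classifier weights
print(logits.shape)  # torch.Size([1, 2]) slide-level logits

# LR schedule as configured in train.py: linear warmup, then cosine decay.
sched = cosine_scheduler(5e-4, 1e-6, epochs=40, niter_per_ep=100, warmup_epochs=10)
assert len(sched) == 40 * 100
print(sched[0], sched.max(), sched[-1])  # 0.0 -> 5e-4 -> ~1e-6
```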