├── .gitignore ├── LICENSE ├── README.md ├── docs └── logos │ └── logo.png ├── environment.yaml ├── examples ├── README.md ├── annotator_2d.py ├── annotator_2d_wsi.py ├── finetuning │ ├── README.md │ ├── annotator_with_finetuned_model.py │ └── finetune_nuclick.py └── train_semantic.py ├── experiments ├── README.md ├── benchmarking │ ├── README.md │ ├── cellvit │ │ ├── cellvit_inference.py │ │ ├── eval_util.py │ │ └── pannuke_semantic.py │ ├── hovernet │ │ ├── eval_util.py │ │ ├── hovernet_inference.py │ │ └── semantic.py │ ├── hovernext │ │ ├── evaluate_ais_hover.py │ │ └── hovernext_inference.py │ ├── instanseg │ │ └── instanseg_inference.py │ ├── outdated │ │ └── cellvitplusplus │ │ │ ├── eval.py │ │ │ └── infer_cvtplus.py │ └── stardist │ │ └── stardist_inference.py ├── data │ ├── README.md │ ├── cvtplus_preproc.py │ ├── dataloaders.py │ ├── get_paths.py │ ├── interactive_data.py │ ├── load_original_data.py │ ├── load_padded_data.py │ ├── semantic │ │ └── get_semantic_paths.py │ └── util.py ├── patho-sam │ ├── README.md │ ├── evaluate_ais.py │ ├── evaluate_amg.py │ ├── evaluate_iterative_prompting.py │ ├── get_results.py │ ├── per_image_eval.py │ ├── push_inference.py │ ├── run_ais.py │ ├── run_amg.py │ ├── run_iter_boxes.py │ ├── run_iter_points.py │ └── util.py ├── semantic_segmentation │ ├── README.md │ ├── benchmark_methods │ │ ├── evaluate_other_methods.py │ │ └── run_biomedparse.py │ ├── generalists │ │ ├── evaluate_pannuke.py │ │ ├── submit_training.py │ │ ├── train_generalist.py │ │ └── train_pannuke.py │ └── results │ │ ├── get_bar_plots.py │ │ └── get_qualitative_plots.py └── training │ ├── README.md │ ├── generalists │ ├── get_generalist_datasets.py │ ├── run_all_finetuning.py │ └── train_generalist_histopathology.py │ └── specialists │ ├── finetune_neutrophils.py │ ├── finetune_specialist.py │ ├── get_specialist_dataset.py │ └── run_all_finetuning.py ├── patho_sam ├── __init__.py ├── __version__.py ├── automatic_segmentation.py ├── evaluation │ ├── class_weights.py │ └── evaluation.py ├── io │ ├── __init__.py │ └── util.py ├── semantic_segmentation.py ├── training │ ├── __init__.py │ ├── semantic_trainer.py │ └── util.py └── util.py ├── pyproject.toml ├── scripts ├── plotting │ └── get_dataset_figure.py └── test_pannuke.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | # Other stuff 165 | *.png 166 | gpu_jobs/ 167 | checkpoints/ 168 | logs/ 169 | *.out 170 | *.zarr 171 | *.sh 172 | configs/ 173 | inference_utils/ 174 | *.tif 175 | *.svg 176 | *.csv 177 | *.tiff 178 | data/ 179 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Computational Cell Analytics (Research Group of Prof. Dr. Constantin Pape) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Segment Anything for Histopathology 2 | 3 | 4 | 5 | PathoSAM implements interactive annotation and (automatic) instance and semantic segmentation for histopathology images. It is built on top of [Segment Anything](https://segment-anything.com/) by Meta AI and our prior work [Segment Anything for Microscopy](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html). It specializes Segment Anything for nucleus segmentation in histopathology data. Its core components are: 6 | - The publicly available `patho_sam` models for interactive data annotation that were fine-tuned on openly available histopathology images. 7 | - The `patho_sam` library, which provides training functionality based on [Segment Anything for Microscopy](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html), and supports: 8 | - Application of Segment Anything to histopathology images, including whole-slide images, and fine-tuning on your data. 9 | - Semantic segmentation. 10 | 11 | Based on these components, `patho_sam` enables fast interactive and automatic annotation for histopathology images; see [Usage](#usage) for details. 12 | 13 | ## Installation 14 | 15 | How to install the `patho_sam` Python library from source? 16 | 17 | To create the environment and install `patho_sam` into it, follow these steps: 18 | 19 | 1. Clone the repository: `git clone https://github.com/computational-cell-analytics/patho-sam` 20 | 2. Enter it: `cd patho-sam` 21 | 3. Create the environment with the necessary requirements: `conda env create -f environment.yaml` 22 | 4. Activate the environment: `conda activate patho-sam` 23 | 5.
Install `patho_sam`: `pip install -e .` 24 | 25 | ## Usage 26 | 27 | ### Using example scripts: 28 | 29 | See the [examples](./examples/) folder for more details. 30 | 31 | ### Using CLI: 32 | 33 | - Download the example whole-slide image by running the following via terminal: `patho_sam.example_data` (see `patho_sam.example_data -h` for more details about the CLI). 34 | - Run automatic segmentation on your own WSI or the example data by running the following via terminal: 35 | ```bash 36 | patho_sam.automatic_segmentation -i /home/anwai/.cache/micro_sam/sample_data/whole-slide-histopathology-example-image.svs -o segmentation.tif 37 | ``` 38 | 39 | > NOTE 1: See `patho_sam.automatic_segmentation -h` for more details about the CLI. 40 | 41 | > NOTE 2: You can find your cache directory using: `python -c "from micro_sam.util import get_cache_directory; print(get_cache_directory())"`. 42 | 43 | 44 | ## Citation 45 | 46 | If you are using this repository in your research, please cite: 47 | - [Our preprint](https://doi.org/10.48550/arXiv.2502.00408). 48 | - The [Segment Anything for Microscopy](https://www.nature.com/articles/s41592-024-02580-4) publication. 49 | - The original [Segment Anything](https://arxiv.org/abs/2304.02643) publication. 50 | -------------------------------------------------------------------------------- /docs/logos/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/computational-cell-analytics/patho-sam/482e2e336b9e5b75028c013452a70aba8107207a/docs/logos/logo.png -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: patho-sam 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - micro_sam 6 | # Note: installing the pytorch package from conda-forge will generally 7 | # give you the most optimized version for your system, if you have a modern 8 | # enough OS and CUDA version (CUDA >= 12). For older versions, you can 9 | # specify the CUDA version by pinning libtorch. 10 | # For example, add this line for a CUDA 11 version: 11 | # - libtorch=*=cuda11* 12 | # or, to enforce a CPU installation, change to 13 | # - "pytorch=*=cpu*" 14 | - pytorch >=2.5 15 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Example Scripts 2 | 3 | Examples for using the [`micro_sam`](https://github.com/computational-cell-analytics/micro-sam) annotation tools: 4 | 5 | - `annotator_2d_wsi.py`: Run the interactive 2d annotation tool, suitable for whole-slide images (WSIs). 6 | - `annotator_2d.py`: Run the interactive 2d annotation tool. 7 | - `train_semantic.py`: Train a model for semantic segmentation of nuclei in PanNuke histopathology images. 8 | 9 | There are Jupyter notebooks available for using automatic segmentation and finetuning on some example data in the [notebooks](https://github.com/computational-cell-analytics/micro-sam/tree/master/notebooks) folder of the `micro-sam` repository. 10 | 11 | The folder `finetuning` contains example scripts that show how a Segment Anything model can be fine-tuned on custom data with the `micro_sam.training` library, and how the finetuned models can then be used within the `micro_sam` annotation tools.
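A minimal sketch of the WSI / ROI reading pattern the annotator examples build on (added for orientation; the ROI offsets below are placeholders):

```python
from micro_sam.sample_data import fetch_wholeslide_histopathology_example_data
from patho_sam.io import read_wsi

# Download the example WSI (cached in the given folder) and read a small ROI from it,
# as done in 'annotator_2d.py'. Without 'image_size', the full slide is loaded.
example_data = fetch_wholeslide_histopathology_example_data("./data")
roi = read_wsi(example_data, image_size=(10000, 15000, 512, 512))  # spatial shape: (512, 512)
```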
12 | -------------------------------------------------------------------------------- /examples/annotator_2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from micro_sam.util import get_cache_directory 4 | from micro_sam.sam_annotator import annotator_2d 5 | from micro_sam.sample_data import fetch_wholeslide_histopathology_example_data 6 | 7 | 8 | from patho_sam.io import read_wsi 9 | 10 | 11 | DATA_CACHE = os.path.join(get_cache_directory(), "sample_data") 12 | 13 | 14 | def whole_slide_image_annotator(use_finetuned_model): 15 | """Run the 2d annotator for an example histopathology image. 16 | 17 | See 'fetch_wholeslide_histopathology_example_data' for details on the data. 18 | """ 19 | example_data = fetch_wholeslide_histopathology_example_data(DATA_CACHE) 20 | 21 | # Read a small ROI for the 2d annotator. 22 | image = read_wsi(example_data, image_size=(10000, 15000, 512, 512)) # spatial shape: (512, 512) 23 | 24 | if use_finetuned_model: 25 | model_type = "vit_b_histopathology" 26 | else: 27 | model_type = "vit_b" 28 | 29 | # Store the embeddings at a desired path. 30 | save_path = f"./embedding_{model_type}.zarr" 31 | 32 | annotator_2d(image=image, embedding_path=save_path, model_type=model_type) 33 | 34 | 35 | def main(): 36 | # Whether to use the fine-tuned SAM model for histopathology. 37 | use_finetuned_model = True 38 | 39 | # 2d annotator for a small ROI from WSI data. 40 | whole_slide_image_annotator(use_finetuned_model) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /examples/annotator_2d_wsi.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from micro_sam.util import get_cache_directory 4 | from micro_sam.sam_annotator import annotator_2d 5 | from micro_sam.sample_data import fetch_wholeslide_histopathology_example_data 6 | 7 | from patho_sam.io import read_wsi 8 | 9 | 10 | DATA_CACHE = os.path.join(get_cache_directory(), "sample_data") 11 | 12 | 13 | def whole_slide_image_annotator(use_finetuned_model): 14 | """Run the 2d annotator for an example whole-slide image for histopathology. 15 | 16 | See 'fetch_wholeslide_histopathology_example_data' for details on the data. 17 | """ 18 | example_data = fetch_wholeslide_histopathology_example_data(DATA_CACHE) 19 | 20 | # Load the WSI image. 21 | image = read_wsi(example_data) 22 | 23 | if use_finetuned_model: 24 | model_type = "vit_b_histopathology" 25 | else: 26 | model_type = "vit_b" 27 | 28 | # Store the embeddings at a desired path. 29 | save_path = f"./embedding_{model_type}.zarr" 30 | 31 | annotator_2d( 32 | image=image, 33 | embedding_path=save_path, 34 | model_type=model_type, 35 | tile_shape=(384, 384), 36 | halo=(64, 64), 37 | precompute_amg_state=True, 38 | ) 39 | 40 | 41 | def main(): 42 | # Whether to use the fine-tuned SAM model for WSIs. 43 | use_finetuned_model = True 44 | 45 | # 2d annotator for WSI data.
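    # NOTE: in contrast to 'annotator_2d.py', the full slide is passed to the annotator here,
    # so the embeddings are computed in tiles (see 'tile_shape' / 'halo' above) and the
    # state for automatic mask generation (AMG) is precomputed.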
46 | whole_slide_image_annotator(use_finetuned_model) 47 | 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /examples/finetuning/README.md: -------------------------------------------------------------------------------- 1 | # Example for model finetuning 2 | 3 | This folder contains example scripts that show how to finetune a SAM model on your own data and how the finetuned model can be used: 4 | 5 | - `finetune_nuclick.py`: Shows how to finetune the model on new data. Set `train_instance_segmentation` to `True` in order to also train a decoder for automatic instance segmentation. 6 | - `annotator_with_finetuned_model.py`: Use the finetuned model in the 2d annotator. 7 | -------------------------------------------------------------------------------- /examples/finetuning/annotator_with_finetuned_model.py: -------------------------------------------------------------------------------- 1 | import imageio.v3 as imageio 2 | 3 | from micro_sam.sam_annotator import annotator_2d 4 | 5 | 6 | def run_annotator_with_finetuned_model(): 7 | """Run the 2d annotator with a custom (finetuned) model. 8 | 9 | Here, we use the model that is produced by `finetune_nuclick.py` and apply it 10 | to an image from the validation set. 11 | """ 12 | # take the last image, which is part of the val set, so the model was not directly trained on it 13 | im = imageio.imread("./data/IHC_nuclick/IHC/images/Validation/ROI_338_2.png") 14 | 15 | # set the checkpoint and the path for caching the embeddings 16 | checkpoint = "./finetuned_nuclick_model.pth" 17 | embedding_path = "./embeddings/embeddings-finetuned.zarr" 18 | 19 | # Adapt this if you finetune a different model type, e.g. vit_h. 20 | model_type = "vit_b" # We finetune a vit_b in the example script. 21 | 22 | # Run the 2d annotator with the custom model. 23 | annotator_2d(im, model_type=model_type, embedding_path=embedding_path, checkpoint=checkpoint) 24 | 25 | 26 | if __name__ == "__main__": 27 | run_annotator_with_finetuned_model() 28 | -------------------------------------------------------------------------------- /examples/finetuning/finetune_nuclick.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | import torch_em 6 | from torch_em.data import MinInstanceSampler 7 | from torch_em.data.datasets import get_nuclick_dataset 8 | from torch_em.transform.label import PerObjectDistanceTransform 9 | 10 | import micro_sam.training as sam_training 11 | from micro_sam.util import export_custom_sam_model 12 | 13 | from patho_sam.training import get_train_val_split, histopathology_identity 14 | 15 | 16 | DATA_FOLDER = "data" 17 | 18 | 19 | def get_dataloaders(batch_size, patch_shape, train_instance_segmentation): 20 | """This returns the NuClick dataloaders implemented in `torch-em`. 21 | https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/histopathology/nuclick.py 22 | It will automatically download the NuClick data. 23 | 24 | NOTE: To replace this with another data loader, you need to return a torch data loader 25 | that returns `x, y` tensors, where `x` is the image data and `y` are corresponding labels. 26 | The labels have to be in a label mask instance segmentation format, 27 | i.e. a tensor of the same spatial shape as `x`, with instance labels for objects. 28 | Important: the ID 0 is reserved for background, and the IDs must be consecutive.
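    (Connected-component labeling, e.g. via 'skimage.measure.label', yields exactly this format: consecutive instance IDs with 0 as background.)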
29 | 30 | See https://github.com/computational-cell-analytics/micro-sam/blob/master/examples/finetuning/finetune_hela.py 31 | for more details on how to create your custom dataloaders. 32 | """ 33 | os.makedirs(DATA_FOLDER, exist_ok=True) 34 | 35 | # All relevant stuff for the dataset. 36 | raw_transform = histopathology_identity # Avoids normalizing the inputs, i.e. keeps the intensities between [0, 255]. 37 | sampler = MinInstanceSampler(min_num_instances=2, min_size=10) # Ensures at least 2 foreground objects per sample. 38 | label_dtype = torch.float32 # Converts labels to expected dtype. 39 | 40 | if train_instance_segmentation: 41 | # Computes the distance transform for objects to perform end-to-end automatic instance segmentation. 42 | label_transform = PerObjectDistanceTransform( 43 | distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, 44 | ) 45 | else: 46 | label_transform = torch_em.transform.label.connected_components 47 | 48 | dataset = get_nuclick_dataset( 49 | path=DATA_FOLDER, 50 | patch_shape=patch_shape, 51 | split="Train", 52 | sampler=sampler, 53 | label_dtype=label_dtype, 54 | label_transform=label_transform, 55 | raw_transform=raw_transform, 56 | download=True, # This will download the image and segmentation data for training. 57 | ) 58 | 59 | # Get the datasets. 60 | train_ds, val_ds = get_train_val_split(dataset) 61 | 62 | # Get the dataloaders. 63 | train_loader = torch_em.get_data_loader(train_ds, batch_size=batch_size, shuffle=True) 64 | val_loader = torch_em.get_data_loader(val_ds, batch_size=batch_size, shuffle=True) 65 | 66 | return train_loader, val_loader 67 | 68 | 69 | def run_training(checkpoint_name, model_type, train_instance_segmentation): 70 | """Run the actual model training.""" 71 | 72 | # All hyperparameters for training. 73 | batch_size = 1 # the training batch size 74 | patch_shape = (512, 512) # the size of patches for training 75 | n_objects_per_batch = 25 # the number of objects per batch that will be sampled 76 | device = "cuda" if torch.cuda.is_available() else "cpu" # the device/GPU used for training. 77 | 78 | # Get the dataloaders. 79 | train_loader, val_loader = get_dataloaders(batch_size, patch_shape, train_instance_segmentation) 80 | 81 | # Run training. 82 | sam_training.train_sam( 83 | name=checkpoint_name, 84 | model_type=model_type, 85 | train_loader=train_loader, 86 | val_loader=val_loader, 87 | n_epochs=100, 88 | n_objects_per_batch=n_objects_per_batch, 89 | with_segmentation_decoder=train_instance_segmentation, 90 | device=device, 91 | ) 92 | 93 | 94 | def export_model(checkpoint_name, model_type): 95 | """Export the trained model.""" 96 | export_path = "./finetuned_nuclick_model.pth" 97 | checkpoint_path = os.path.join("checkpoints", checkpoint_name, "best.pt") 98 | export_custom_sam_model(checkpoint_path=checkpoint_path, model_type=model_type, save_path=export_path) 99 | 100 | 101 | def main(): 102 | """Finetune a Segment Anything model. 103 | 104 | This example uses image data and segmentations from the NuClick dataset for lymphocyte segmentation, 105 | but can be easily adapted for other data (including data you have annotated with 'micro_sam' beforehand). 106 | """ 107 | # The 'model_type' determines which base model is used to initialize the weights that are finetuned. 108 | # We use 'vit_b' here because it can be trained faster. Note that 'vit_h' usually yields higher quality results. 109 | model_type = "vit_b" 110 | 111 | # The name of the checkpoint.
The checkpoints will be stored in './checkpoints/' 112 | checkpoint_name = "sam_nuclick" 113 | 114 | # Train an additional convolutional decoder for end-to-end automatic instance segmentation. 115 | train_instance_segmentation = True 116 | 117 | run_training(checkpoint_name, model_type, train_instance_segmentation) 118 | export_model(checkpoint_name, model_type) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /examples/train_semantic.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | import torch_em 6 | from torch_em.data import MinTwoInstanceSampler 7 | from torch_em.data.datasets import get_pannuke_dataset 8 | 9 | import micro_sam.training as sam_training 10 | from micro_sam.instance_segmentation import get_unetr 11 | 12 | from patho_sam.training import SemanticInstanceTrainer, get_train_val_split 13 | 14 | 15 | DATA_FOLDER = "data" 16 | 17 | 18 | def get_dataloaders(data_path): 19 | """This returns the PanNuke dataloaders implemented in `torch-em`. 20 | https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/histopathology/pannuke.py 21 | It will automatically download the PanNuke data. 22 | 23 | NOTE: To replace this with another data loader, you need to return a torch data loader 24 | that returns `x, y` tensors, where `x` is the image data and `y` are corresponding labels. 25 | The labels have to be in a label mask semantic segmentation format, 26 | i.e. a tensor of the same spatial shape as `x`, with semantic labels for objects. 27 | Important: the ID 0 is reserved for background, and all semantic classes must be present. 28 | """ 29 | # All relevant stuff for the dataset. 30 | raw_transform = sam_training.identity # Avoids normalizing the inputs, i.e. keeps the intensities between [0, 255]. 31 | sampler = MinTwoInstanceSampler() # Ensures that at least two instances are present per sample. 32 | label_dtype = torch.float32 # Converts labels to expected dtype. 33 | 34 | # Get the dataset 35 | dataset = get_pannuke_dataset( 36 | path=data_path, 37 | patch_shape=(1, 512, 512), 38 | ndim=2, 39 | folds=["fold_1", "fold_2"], 40 | custom_label_choice="semantic", 41 | sampler=sampler, 42 | label_dtype=label_dtype, 43 | raw_transform=raw_transform, 44 | download=True, 45 | ) 46 | 47 | # Create custom splits. 48 | train_dataset, val_dataset = get_train_val_split(dataset) 49 | 50 | # Get the dataloaders. 51 | train_loader = torch_em.get_data_loader(dataset=train_dataset, batch_size=1, shuffle=True) 52 | val_loader = torch_em.get_data_loader(dataset=val_dataset, batch_size=1, shuffle=True) 53 | 54 | return train_loader, val_loader 55 | 56 | 57 | def train_pannuke_semantic_segmentation(checkpoint_name, model_type): 58 | """Script for semantic segmentation for PanNuke data.""" 59 | 60 | # Parameters for training 61 | num_classes = 6 # available classes are [0, 1, 2, 3, 4, 5] 62 | device = "cuda" if torch.cuda.is_available() else "cpu" 63 | train_loader, val_loader = get_dataloaders(data_path=os.path.join(DATA_FOLDER, "pannuke")) 64 | 65 | # Get the trainable Segment Anything Model. 66 | model = sam_training.get_trainable_sam_model( 67 | model_type=model_type, 68 | device=device, 69 | checkpoint_path=None, # override to provide a filepath to your trained SAM model.
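        # (An illustrative value would be './checkpoints/<run_name>/best.pt' from a previous training run.)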
70 | ) 71 | 72 | # Get the UNETR model for the semantic segmentation pipeline. 73 | unetr = get_unetr( 74 | image_encoder=model.sam.image_encoder, device=device, out_channels=num_classes, flexible_load_checkpoint=True, 75 | ) 76 | 77 | # All other stuff we need for training 78 | optimizer = torch.optim.AdamW(unetr.parameters(), lr=1e-4) 79 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=5) 80 | 81 | # Converts the per-batch inputs and corresponding labels to desired format required for training. 82 | convert_inputs = sam_training.util.ConvertToSemanticSamInputs() 83 | 84 | # Trainer for semantic segmentation (implemented using 'torch_em') 85 | trainer = SemanticInstanceTrainer( 86 | name=checkpoint_name, 87 | train_loader=train_loader, 88 | val_loader=val_loader, 89 | model=unetr, 90 | optimizer=optimizer, 91 | device=device, 92 | lr_scheduler=scheduler, 93 | log_image_interval=100, 94 | mixed_precision=True, 95 | compile_model=False, 96 | convert_inputs=convert_inputs, 97 | num_classes=num_classes, 98 | dice_weight=0, # override to use a weighted dice-cross-entropy loss; with 0, the trainer uses cross-entropy loss only. 99 | ) 100 | trainer.fit(epochs=100) 101 | 102 | 103 | def main(): 104 | """Finetune a Segment Anything model for semantic segmentation on the PanNuke dataset. 105 | 106 | This example uses image data and semantic segmentation labels for the PanNuke dataset, 107 | but can easily be adapted for other data (including data you have annotated with patho_sam beforehand). 108 | NOTE: You must provide semantic class labels to train within this setup. 109 | """ 110 | # The model_type determines which base model is used to initialize the weights that are finetuned. 111 | # We use 'vit_b' here because it can be trained faster. Note that 'vit_h' yields higher quality results. 112 | model_type = "vit_b" 113 | 114 | # The name of the checkpoint. The checkpoints will be stored in './checkpoints/'. 115 | checkpoint_name = "pannuke_semantic" 116 | 117 | train_pannuke_semantic_segmentation(checkpoint_name, model_type) 118 | 119 | 120 | if __name__ == "__main__": 121 | main() 122 | -------------------------------------------------------------------------------- /experiments/README.md: -------------------------------------------------------------------------------- 1 | # Segment Anything for Histopathology *Experiment Scripts* 2 | 3 | This directory contains all experiment scripts for building and benchmarking Segment Anything models for histopathology. 4 | 5 | Here is a brief description of relevant folders: 6 | - `benchmarking`: Contains scripts to run reference methods (e.g., HoVerNet, HoVerNeXt, CellViT, InstanSeg, StarDist). 7 | - `data`: Contains scripts for preprocessing images and corresponding labels for all experiments. 8 | - `patho-sam`: Contains scripts for running SAM models (e.g., default SAM, `micro-sam`, and `patho-sam` models for automatic and interactive instance segmentation). 9 | - `semantic_segmentation`: Contains scripts for training and evaluating semantic segmentation using the `patho-sam` generalist models. 10 | - `training`: Contains scripts for training the specialist and generalist `patho-sam` models. 11 | 12 | > NOTE 1: There are scripts where the expected filepaths / directories might be hard-coded. Replace them with your respective paths where you would like to store / fetch files from. 13 | 14 | > NOTE 2: We provide [example scripts](../examples/) for convenience.
Feel free to check them out for both interactive and automatic segmentation using `PathoSAM` models. 15 | -------------------------------------------------------------------------------- /experiments/benchmarking/README.md: -------------------------------------------------------------------------------- 1 | # Benchmarking Methods for Automatic Instance (and Semantic) Segmentation 2 | 3 | This folder contains the scripts used to compare the automatic segmentation of `patho-sam` with the following methods: 4 | - `cellvit`: Contains scripts to run CellViT for automatic instance segmentation and classification with several model variants. 5 | - `hovernet`: Contains scripts to run HoVerNet for automatic instance segmentation and classification with several model variants. 6 | - `hovernext`: Contains scripts to run HoVerNeXt for automatic instance segmentation and classification with several model variants. 7 | - `instanseg`: Contains scripts to run InstanSeg for automatic instance segmentation. 8 | - `stardist`: Contains scripts to run StarDist for automatic instance segmentation. 9 | - Other miscellaneous folders: 10 | - `outdated`: Contains scripts that are either outdated or under development. 11 | -------------------------------------------------------------------------------- /experiments/benchmarking/cellvit/cellvit_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | import argparse 5 | 6 | from eval_util import evaluate_cellvit 7 | 8 | 9 | def run_inference(input_dir, output_dir, dataset, checkpoint_path): 10 | if dataset not in ["pannuke", "nuclick", "srsanet", "lizard", "cpm15", "consep", "cpm17", "monuseg"]: 11 | data_dir = os.path.join(input_dir, "original_data", dataset, "eval_split") 12 | else: 13 | data_dir = os.path.join(input_dir, "vit_data", dataset, "eval_split") 14 | 15 | checkpoint = os.path.splitext(os.path.basename(checkpoint_path))[0].split("-", 1)[1] 16 | 17 | result_dir = os.path.join(output_dir, "results") 18 | output_path = os.path.join(output_dir, dataset, checkpoint) 19 | os.makedirs(output_path, exist_ok=True) 20 | args = [ 21 | "--model", f"{checkpoint_path}", 22 | "--outdir", f"{output_path}", 23 | "--magnification", "40", 24 | "--data", f"{data_dir}", 25 | ] 26 | 27 | command = [ 28 | "python", os.path.expanduser("~/CellViT/cell_segmentation/inference/inference_cellvit_experiment_monuseg.py"), 29 | ] + args 30 | 31 | print(f"Running inference with CellViT {checkpoint} model on {dataset} dataset...") 32 | subprocess.run(command) 33 | 34 | plot_dir = os.path.join(output_dir, dataset, checkpoint, dataset, "plots") 35 | if os.path.exists(plot_dir): 36 | shutil.rmtree(plot_dir) 37 | 38 | if os.path.exists(os.path.join(output_path, "inference_monuseg.log")): 39 | os.remove(os.path.join(output_path, "inference_monuseg.log")) 40 | 41 | evaluate_cellvit(output_path, checkpoint, dataset, data_dir, result_dir) 42 | print(f"Successfully ran inference with CellViT {checkpoint} model on {dataset} dataset") 43 | 44 | 45 | def get_cellvit_args(): 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument("-i", "--input", type=str, default=None, help="The path to the input data") 48 | parser.add_argument("-d", "--dataset", type=str, default=None, help="The datasets to infer on") 49 | parser.add_argument("-o", "--output_path", type=str, default=None, help="The path where the results are saved") 50 | parser.add_argument("-c", "--checkpoint_path", type=str, default=None, help="The
checkpoints to use for inference") 51 | args = parser.parse_args() 52 | return args 53 | 54 | 55 | def main(): 56 | args = get_cellvit_args() 57 | run_inference( 58 | input_dir=args.input, 59 | output_dir=args.output_path, 60 | datasets=[args.dataset], 61 | checkpoints=[args.checkpoint_path], 62 | ) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /experiments/benchmarking/cellvit/eval_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | from glob import glob 4 | from tqdm import tqdm 5 | from natsort import natsorted 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import imageio.v3 as imageio 10 | from skimage.measure import label 11 | 12 | from elf.evaluation import mean_segmentation_accuracy 13 | 14 | 15 | def zip_predictions(path, target_dir): 16 | print(f"Zipping {path}...") 17 | zip_name = os.path.basename(path) + ".zip" 18 | zip_path = os.path.join(target_dir, zip_name) 19 | with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf: 20 | for root, _dirs, files in os.walk(path): 21 | for file in files: 22 | file_path = os.path.join(root, file) 23 | arcname = os.path.relpath(file_path, start=path) 24 | zipf.write(file_path, arcname) 25 | print("Successfully zipped results") 26 | 27 | 28 | def _run_evaluation(gt_paths, prediction_paths, verbose=True): 29 | print(len(gt_paths), len(prediction_paths)) 30 | assert len(gt_paths) == len(prediction_paths), \ 31 | f"label / prediction mismatch: {len(gt_paths)} / {len(prediction_paths)}" 32 | 33 | msas, sa50s, sa75s = [], [], [] 34 | for gt_path, pred_path in tqdm( 35 | zip(gt_paths, prediction_paths), 36 | desc="Evaluate predictions", 37 | total=len(gt_paths), 38 | disable=not verbose, 39 | ): 40 | assert os.path.exists(gt_path), gt_path 41 | assert os.path.exists(pred_path), pred_path 42 | 43 | gt = imageio.imread(gt_path) 44 | gt = label(gt) 45 | pred = imageio.imread(pred_path) 46 | 47 | msa, scores = mean_segmentation_accuracy(pred, gt, return_accuracies=True) 48 | sa50, sa75 = scores[0], scores[5] 49 | msas.append(msa), sa50s.append(sa50), sa75s.append(sa75) 50 | 51 | return msas, sa50s, sa75s 52 | 53 | 54 | def evaluate_cellvit(prediction_dir, checkpoint, dataset, label_dir, result_dir): 55 | save_path = os.path.join(dataset, checkpoint, f'{dataset}_cellvit_{checkpoint}_ais_result.csv') 56 | if os.path.exists(save_path): 57 | print(f"Results for {dataset} evaluation already exist") 58 | return 59 | prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*"))) 60 | gt_paths = natsorted(glob(os.path.join(label_dir, "test_labels", "*"))) 61 | if len(prediction_paths) == 0: 62 | print(f"No predictions for {dataset} dataset on {checkpoint} checkpoint found") 63 | return 64 | 65 | msas, sa50s, sa75s = _run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths) 66 | results = pd.DataFrame.from_dict( 67 | { 68 | "mSA": [np.mean(msas)], "SA50": [np.mean(sa50s)], "SA75": [np.mean(sa75s)], 69 | } 70 | ) 71 | print(results.head(2)) 72 | os.makedirs(os.path.join(result_dir, dataset, checkpoint), exist_ok=True) 73 | results.to_csv(save_path, index=False) 74 | -------------------------------------------------------------------------------- /experiments/benchmarking/cellvit/pannuke_semantic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import subprocess 5 | 6 | 7 | 
CVT_CP = ["256-x20", "256-x40", "SAM-H-x20", "SAM-H-x40"] 8 | 9 | 10 | def run_inference(checkpoint_path, input_dir, output_dir, dataset): 11 | data_dir = os.path.join(input_dir, dataset, "eval_split") 12 | checkpoint = os.path.splitext((os.path.basename(checkpoint_path)[8:]))[0] 13 | output_path = os.path.join(output_dir, dataset, checkpoint) 14 | os.makedirs(output_path, exist_ok=True) 15 | args = [ 16 | "--model", f"{checkpoint_path}", 17 | "--outdir", f"{output_path}", 18 | "--magnification", "40", 19 | "--data", f"{data_dir}", 20 | ] 21 | 22 | command = [ 23 | "python3", 24 | os.path.expanduser("~/CellViT/cell_segmentation/inference/inference_cellvit_experiment_pannuke.py"), # noqa 25 | ] + args 26 | 27 | print(f"Running inference with CellViT {checkpoint} model on {dataset} dataset...") 28 | subprocess.run(command) 29 | plot_dir = os.path.join(output_path, "plots") 30 | if os.path.exists(plot_dir): 31 | shutil.rmtree(plot_dir) 32 | 33 | print(f"Successfully ran inference with CellViT {checkpoint} model on {dataset} dataset") 34 | 35 | 36 | def get_cellvit_args(): 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("-i", "--input", type=str, default=None, help="The path to the input data") 39 | parser.add_argument("-d", "--dataset", type=str, default=None, help="The datasets to infer on") 40 | parser.add_argument("-o", "--output_path", type=str, default=None, help="The path where the results are saved") 41 | parser.add_argument("-c", "--checkpoint_path", type=str, default=None, help="The checkpoints to use for inference") 42 | args = parser.parse_args() 43 | return args 44 | 45 | 46 | def main(): 47 | args = get_cellvit_args() 48 | run_inference( 49 | input_dir=args.input, 50 | output_dir=args.output_path, 51 | dataset=args.dataset, 52 | checkpoint_path=args.checkpoint_path, 53 | ) 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /experiments/benchmarking/hovernet/eval_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from tqdm import tqdm 4 | from natsort import natsorted 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import imageio.v3 as imageio 9 | from skimage.measure import label 10 | 11 | from elf.evaluation import mean_segmentation_accuracy 12 | 13 | 14 | DATASETS = [ 15 | "consep", 16 | "cpm15", 17 | "cpm17", 18 | "cryonuseg", 19 | "lizard", 20 | "lynsec_he", 21 | "lynsec_ihc", 22 | "monuseg", 23 | "nuclick", 24 | "nuinsseg", 25 | "pannuke", 26 | "puma", 27 | "srsanet", 28 | "tnbc", 29 | ] 30 | 31 | 32 | def _run_evaluation(gt_paths, prediction_paths, verbose=True): 33 | assert len(gt_paths) == len( 34 | prediction_paths 35 | ), f"label / prediction mismatch: {len(gt_paths)} / {len(prediction_paths)}" 36 | msas, sa50s, sa75s = [], [], [] 37 | 38 | for gt_path, pred_path in tqdm( 39 | zip(gt_paths, prediction_paths), 40 | desc="Evaluate predictions", 41 | total=len(gt_paths), 42 | disable=not verbose, 43 | ): 44 | assert os.path.exists(gt_path), gt_path 45 | assert os.path.exists(pred_path), pred_path 46 | 47 | gt = imageio.imread(gt_path) 48 | gt = label(gt) 49 | pred = imageio.imread(pred_path) 50 | 51 | msa, scores = mean_segmentation_accuracy(pred, gt, return_accuracies=True) 52 | sa50, sa75 = scores[0], scores[5] 53 | msas.append(msa), sa50s.append(sa50), sa75s.append(sa75) 54 | 55 | return msas, sa50s, sa75s 56 | 57 | 58 | def evaluate_hovernet(prediction_dir, label_dir, result_dir,
dataset, checkpoint): 59 | gt_paths = natsorted(glob(os.path.join(label_dir, "*"))) 60 | save_path = os.path.join(result_dir, dataset, checkpoint, f'{dataset}_hovernet_{checkpoint}_ais_result.csv') 61 | prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*.tiff"))) 62 | os.makedirs(os.path.join(result_dir, dataset, checkpoint), exist_ok=True) 63 | print(f"Evaluating {dataset} dataset with checkpoint {checkpoint} ...") 64 | msas, sa50s, sa75s = _run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths) 65 | results = pd.DataFrame.from_dict( 66 | { 67 | "mSA": [np.mean(msas)], 68 | "SA50": [np.mean(sa50s)], 69 | "SA75": [np.mean(sa75s)], 70 | } 71 | ) 72 | print(results.head()) 73 | results.to_csv(save_path, index=False) 74 | -------------------------------------------------------------------------------- /experiments/benchmarking/hovernet/hovernet_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import subprocess 5 | from glob import glob 6 | from tqdm import tqdm 7 | from natsort import natsorted 8 | 9 | import imageio.v3 as imageio 10 | from scipy.io import loadmat 11 | 12 | from eval_util import evaluate_hovernet 13 | 14 | 15 | def mat_to_tiff(path): 16 | pred_mat_paths = natsorted(glob(os.path.join(path, "mat", "*.mat"))) 17 | for mpath in tqdm(pred_mat_paths, desc="Preprocessing labels"): 18 | pred_path = os.path.join(path, os.path.basename(mpath.replace(".mat", ".tiff"))) 19 | pred = loadmat(mpath)["inst_map"] 20 | imageio.imwrite(pred_path, pred) 21 | 22 | 23 | def run_inference(input_dir, output_dir, dataset, checkpoint_path): 24 | checkpoint = os.path.basename(checkpoint_path).split("_")[2] 25 | output_path = os.path.join(output_dir, "inference", dataset, checkpoint) 26 | input_path = os.path.join(input_dir, dataset, "eval_split", "test_images") 27 | 28 | os.makedirs(output_path, exist_ok=True) 29 | if checkpoint in ["consep", "cpm17", "kumar"]: 30 | model_mode = "original" 31 | nr_types = 0 32 | type_info = "" 33 | else: 34 | model_mode = "fast" 35 | type_info = os.path.expanduser("~/hover_net/type_info.json") 36 | if checkpoint == "pannuke": 37 | nr_types = 6 38 | else: 39 | nr_types = 5 40 | 41 | args = [ 42 | "--nr_types", f"{nr_types}", 43 | "--type_info_path", f"{type_info}", 44 | "--model_mode", f"{model_mode}", 45 | "--model_path", f"{checkpoint_path}", 46 | "--nr_inference_workers", "2", 47 | "--nr_post_proc_worker", "0", 48 | "tile", 49 | "--input_dir", f"{input_path}", 50 | "--output_dir", f"{output_path}", 51 | "--save_raw_map", 52 | ] 53 | 54 | command = ["python3", os.path.expanduser("~/hover_net/run_infer.py")] + args 55 | print(f"Running inference with HoVerNet {checkpoint} model on {dataset} dataset...") 56 | 57 | subprocess.run(command) 58 | mat_to_tiff(os.path.join(output_path)) 59 | evaluate_hovernet( 60 | prediction_dir=output_path, 61 | label_dir=os.path.join(input_dir, dataset, "eval_split", "test_labels"), 62 | result_dir=os.path.join(output_dir, "results"), 63 | checkpoint=checkpoint, 64 | dataset=dataset, 65 | ) 66 | shutil.rmtree(os.path.join(output_path, "json")) 67 | shutil.rmtree(os.path.join(output_path, "mat")) 68 | shutil.rmtree(os.path.join(output_path, "overlay")) 69 | print(f"Inference on {dataset} dataset with the HoVerNet {checkpoint} model successfully completed") 70 | 71 | 72 | def get_hovernet_args(): 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument("-i", "--input", type=str, default=None,
help="The path to the input data") 75 | parser.add_argument("-d", "--dataset", type=str, default=None, help="The datasets to infer on") 76 | parser.add_argument("-o", "--output_dir", type=str, default=None, help="The path where the results are saved") 77 | parser.add_argument( 78 | "-c", "--checkpoint_path", type=str, default=None, 79 | help="The path to the HoVerNet checkpoint to use for inference." 80 | ) 81 | args = parser.parse_args() 82 | return args 83 | 84 | 85 | def main(): 86 | args = get_hovernet_args() 87 | run_inference( 88 | input_dir=args.input, 89 | output_dir=args.output_dir, 90 | dataset=args.dataset, 91 | checkpoint=args.checkpoint_path, 92 | ) 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /experiments/benchmarking/hovernet/semantic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | from glob import glob 5 | from tqdm import tqdm 6 | from natsort import natsorted 7 | 8 | import cv2 9 | import json 10 | import numpy as np 11 | import imageio.v3 as imageio 12 | 13 | 14 | def json_to_tiff(path): 15 | label_json_paths = [p for p in natsorted(glob(os.path.join(path, "json", "*.json")))] 16 | img_shape = (256, 256) 17 | for mpath in tqdm(label_json_paths, desc="Postprocessing labels"): 18 | label_path = os.path.join(path, os.path.basename(mpath.replace(".json", ".tiff"))) 19 | with open(mpath, 'r') as file: 20 | data = json.load(file) 21 | pred_class_map = np.zeros(img_shape, dtype=np.int32) 22 | for id, cell_data in enumerate(data['nuc'].items(), start=1): 23 | cell_data = cell_data[1] 24 | contour = np.array(cell_data["contour"]) 25 | contour[:, 0] = np.clip(contour[:, 0], 0, img_shape[1]) 26 | contour[:, 1] = np.clip(contour[:, 1], 0, img_shape[0]) 27 | contour = contour.reshape((-1, 1, 2)) 28 | cell_type = cell_data["type"] 29 | contour = np.vstack((contour, [contour[0]])) 30 | contour = contour.astype(np.int32) 31 | cv2.fillPoly(pred_class_map, [contour], cell_type) 32 | 33 | imageio.imwrite(label_path, pred_class_map) 34 | 35 | 36 | def run_inference(model_dir, input_dir, output_dir, type_info_path): 37 | for dataset in ["pannuke"]: 38 | for checkpoint in ["pannuke"]: 39 | output_path = os.path.join(output_dir, "inference", dataset, checkpoint) 40 | input_path = os.path.join(input_dir, dataset, "semantic_split", "test_images") 41 | if os.path.exists(os.path.join(output_dir, "results", dataset, checkpoint, 'ais_result.csv')): 42 | print(f"Inference with HoVerNet model (type: {checkpoint}) on {dataset} dataset already done") 43 | continue 44 | 45 | os.makedirs(output_path, exist_ok=True) 46 | if checkpoint in ["consep", "cpm17", "kumar"]: 47 | model_mode = "original" 48 | model_path = os.path.join( 49 | model_dir, "checkpoints", f"hovernet_original_{checkpoint}_notype_tf2pytorch.tar" 50 | ) 51 | nr_types = 0 52 | else: 53 | model_mode = "fast" 54 | 55 | model_path = os.path.join(model_dir, "checkpoints", f"hovernet_fast_{checkpoint}_type_tf2pytorch.tar") 56 | if checkpoint == "pannuke": 57 | nr_types = 6 58 | else: 59 | nr_types = 5 60 | 61 | args = [ 62 | "--nr_types", f"{nr_types}", 63 | "--type_info_path", "/user/titus.griebel/u12649/hover_net/type_info.json", 64 | "--model_mode", f"{model_mode}", 65 | "--model_path", f"{model_path}", 66 | "--nr_inference_workers", "2", 67 | "--nr_post_proc_worker", "0", 68 | "tile", 69 | "--input_dir", f"{input_path}", 70 | "--output_dir", 
f"{output_path}", 71 | ] 72 | 73 | command = ["python3", "/user/titus.griebel/u12649/hover_net/run_infer.py"] + args 74 | print(f"Running inference with HoVerNet {checkpoint} model on {dataset} dataset...") 75 | 76 | subprocess.run(command) 77 | json_to_tiff(output_path) 78 | shutil.rmtree(os.path.join(output_path, "json")) 79 | shutil.rmtree(os.path.join(output_path, "mat")) 80 | shutil.rmtree(os.path.join(output_path, "overlay")) 81 | print(f"Inference on {dataset} dataset with the HoVerNet {checkpoint} model successfully completed") 82 | 83 | 84 | run_inference( 85 | model_dir="/mnt/lustre-grete/usr/u12649/models/hovernet", 86 | input_dir="/mnt/lustre-grete/usr/u12649/data/original_data", 87 | output_dir="/mnt/lustre-grete/usr/u12649/models/hovernet_types", 88 | type_info_path="/user/titus.griebel/u12649/hover_net/type_info.json", 89 | ) 90 | -------------------------------------------------------------------------------- /experiments/benchmarking/hovernext/evaluate_ais_hover.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from tqdm import tqdm 4 | from natsort import natsorted 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import imageio.v3 as imageio 9 | from skimage.measure import label 10 | 11 | from elf.evaluation import mean_segmentation_accuracy 12 | 13 | 14 | CHECKPOINTS = [ 15 | "lizard_convnextv2_large", 16 | "lizard_convnextv2_base", 17 | "lizard_convnextv2_tiny", 18 | "pannuke_convnextv2_tiny_1", 19 | "pannuke_convnextv2_tiny_2", 20 | "pannuke_convnextv2_tiny_3", 21 | ] 22 | 23 | DATASETS = [ 24 | "consep", 25 | "cpm15", 26 | "cpm17", 27 | "cryonuseg", 28 | "lizard", 29 | "lynsec_he", 30 | "lynsec_ihc", 31 | "monusac", 32 | "monuseg", 33 | "nuclick", 34 | "nuinsseg", 35 | "pannuke", 36 | "puma", 37 | "srsanet", 38 | "tnbc", 39 | ] 40 | 41 | 42 | def _run_evaluation(gt_paths, prediction_paths, verbose=True): 43 | assert len(gt_paths) == len(prediction_paths), \ 44 | f"label / prediction mismatch: {len(gt_paths)} / {len(prediction_paths)}" 45 | 46 | msas, sa50s, sa75s = [], [], [] 47 | for gt_path, pred_path in tqdm( 48 | zip(gt_paths, prediction_paths, strict=False), 49 | desc="Evaluate predictions", 50 | total=len(gt_paths), 51 | disable=not verbose, 52 | ): 53 | assert os.path.exists(gt_path), gt_path 54 | assert os.path.exists(pred_path), pred_path 55 | 56 | gt = imageio.imread(gt_path) 57 | gt = label(gt) 58 | pred = imageio.imread(pred_path) 59 | 60 | msa, scores = mean_segmentation_accuracy(pred, gt, return_accuracies=True) 61 | sa50, sa75 = scores[0], scores[5] 62 | msas.append(msa), sa50s.append(sa50), sa75s.append(sa75) 63 | 64 | return msas, sa50s, sa75s 65 | 66 | 67 | def evaluate_all_datasets_hovernet(prediction_dir, label_dir, result_dir): 68 | for dataset in DATASETS: 69 | gt_paths = natsorted(glob(os.path.join(label_dir, dataset, "eval_split", "test_labels", "*"))) 70 | for checkpoint in CHECKPOINTS: 71 | save_path = os.path.join( 72 | result_dir, dataset, checkpoint, f'{dataset}_hovernext_{checkpoint}_ais_result.csv' 73 | ) 74 | if os.path.exists(save_path): 75 | continue 76 | 77 | prediction_paths = natsorted(glob(os.path.join(prediction_dir, dataset, checkpoint, "*"))) 78 | if len(prediction_paths) == 0: 79 | print(f"No predictions for {dataset} dataset on {checkpoint} checkpoint found") 80 | continue 81 | 82 | os.makedirs(os.path.join(result_dir, dataset, checkpoint), exist_ok=True) 83 | print(f"evaluation {dataset} dataset on checkpoint {checkpoint} ...") 84 | 85 | 
msas, sa50s, sa75s = _run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths) 86 | results = pd.DataFrame.from_dict( 87 | {"mSA": [np.mean(msas)], "SA50": [np.mean(sa50s)], "SA75": [np.mean(sa75s)]} 88 | ) 89 | print(results.head()) 90 | results.to_csv(save_path, index=False) 91 | 92 | 93 | evaluate_all_datasets_hovernet( 94 | "/mnt/lustre-grete/usr/u12649/models/hovernext/inference", 95 | "/mnt/lustre-grete/usr/u12649/data/original_data", 96 | "/mnt/lustre-grete/usr/u12649/models/hovernext/results", 97 | ) 98 | -------------------------------------------------------------------------------- /experiments/benchmarking/hovernext/hovernext_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | # This script does not work right now; use the same script inside the hovernext clone 5 | 6 | 7 | def run_inference(input_dir, output_dir): 8 | for dataset in [ 9 | "cpm15", 10 | "cpm17", 11 | "cryonuseg", 12 | "janowczyk", 13 | "lizard", 14 | "lynsec", 15 | "monusac", 16 | "monuseg", 17 | "nuinsseg", 18 | "pannuke", 19 | "puma", 20 | "tnbc", 21 | ]: 22 | for model in ["pannuke_convnextv2_tiny_1", "pannuke_convnextv2_tiny_2", "pannuke_convnextv2_tiny_3"]: 23 | output_path = os.path.join(output_dir, dataset, model) 24 | input_path = os.path.join(input_dir, dataset, "loaded_dataset/complete_dataset/eval_split/test_images/*") 25 | os.makedirs(output_path, exist_ok=True) 26 | if len(os.listdir(output_path)) > 0: # Skip if predictions already exist. 27 | continue 28 | args = [ 29 | "--input", f"{input_path}", 30 | "--cp", f"{model}", 31 | "--output_root", f"{output_path}", 32 | "--tile_size", "512", 33 | ] 34 | 35 | command = ["python3", "/user/titus.griebel/u12649/hover_next_inference/main.py"] + args 36 | print(f"Running inference with HoVerNeXt {model} model on {dataset} dataset...") 37 | subprocess.run(command) 38 | print(f"Inference on {dataset} dataset with the HoVerNeXt model {model} successfully completed") 39 | 40 | 41 | run_inference( 42 | input_dir="/mnt/lustre-grete/usr/u12649/data/test", 43 | output_dir="/mnt/lustre-grete/usr/u12649/models/hovernext/inference", 44 | ) 45 | -------------------------------------------------------------------------------- /experiments/benchmarking/instanseg/instanseg_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from glob import glob 4 | from tqdm import tqdm 5 | from natsort import natsorted 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import imageio.v3 as imageio 10 | from skimage.measure import label 11 | 12 | from tukra.io import read_image 13 | from tukra.inference import segment_using_instanseg 14 | 15 | from elf.evaluation import mean_segmentation_accuracy 16 | 17 | 18 | def _run_evaluation(gt_paths, prediction_paths, verbose=True): 19 | assert len(gt_paths) == len(prediction_paths), \ 20 | f"label / prediction mismatch: {len(gt_paths)} / {len(prediction_paths)}" 21 | 22 | msas, sa50s, sa75s = [], [], [] 23 | for gt_path, pred_path in tqdm( 24 | zip(gt_paths, prediction_paths, strict=False), 25 | desc="Evaluate predictions", 26 | total=len(gt_paths), 27 | disable=not verbose, 28 | ): 29 | assert os.path.exists(gt_path), gt_path 30 | assert os.path.exists(pred_path), pred_path 31 | 32 | gt = imageio.imread(gt_path) 33 | gt = label(gt) 34 | pred = imageio.imread(pred_path) 35 | 36 | msa, scores = mean_segmentation_accuracy(pred, gt, return_accuracies=True) 37 | sa50, sa75 = scores[0], scores[5] 38 |
msas.append(msa), sa50s.append(sa50), sa75s.append(sa75) 39 | 40 | return msas, sa50s, sa75s 41 | 42 | 43 | def evaluate_instanseg(prediction_dir, label_dir, result_dir, dataset): 44 | gt_paths = natsorted(glob(os.path.join(label_dir, "eval_split", "test_labels", "*"))) 45 | for checkpoint in ["instanseg"]: 46 | save_path = os.path.join(result_dir, dataset, checkpoint, f'{dataset}_instanseg_{checkpoint}_ais_result.csv') 47 | if os.path.exists(save_path): 48 | continue 49 | prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*"))) 50 | if len(prediction_paths) == 0: 51 | print(f"No predictions for {dataset} dataset on {checkpoint} checkpoint found") 52 | continue 53 | os.makedirs(os.path.join(result_dir, dataset, checkpoint), exist_ok=True) 54 | print(f"Evaluating {dataset} dataset on InstanSeg ...") 55 | msas, sa50s, sa75s = _run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths) 56 | results = pd.DataFrame.from_dict( 57 | {"mSA": [np.mean(msas)], "SA50": [np.mean(sa50s)], "SA75": [np.mean(sa75s)]} 58 | ) 59 | 60 | results.to_csv(save_path, index=False) 61 | print(results.head(2)) 62 | 63 | 64 | def infer_instanseg(data_dir, output_path, dataset): 65 | image_paths = natsorted(glob(os.path.join(data_dir, "eval_split", "test_images", "*"))) 66 | os.makedirs(output_path, exist_ok=True) 67 | for image_path in tqdm(image_paths, desc=f"Performing inference on {dataset}"): 68 | image = read_image(image_path) 69 | segmentation = segment_using_instanseg( 70 | image=image, model_type="brightfield_nuclei", target="nuclei", scale="small" 71 | ) 72 | imageio.imwrite(os.path.join(output_path, os.path.basename(image_path).replace(".png", ".tiff")), segmentation) 73 | 74 | 75 | def run_inference(input_dir, output_dir, dataset): 76 | output_path = os.path.join(output_dir, "inference", dataset, "instanseg") 77 | input_path = os.path.join(input_dir, dataset) 78 | if os.path.exists( 79 | os.path.join(output_dir, "results", dataset, "instanseg", f'{dataset}_instanseg_instanseg_ais_result.csv') 80 | ): 81 | return 82 | 83 | os.makedirs(output_path, exist_ok=True) 84 | print(f"Running inference with InstanSeg model on {dataset} dataset... 
\n") 85 | infer_instanseg(input_path, output_path, dataset) 86 | print(f"Inference on {dataset} dataset with the InstanSeg model successfully completed \n") 87 | 88 | evaluate_instanseg( 89 | prediction_dir=output_path, 90 | label_dir=input_path, 91 | result_dir=os.path.join(output_dir, "results"), 92 | dataset=dataset, 93 | ) 94 | 95 | 96 | def get_instanseg_args(): 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument("-i", "--input", type=str, default=None, help="The path to the input data") 99 | parser.add_argument("-d", "--dataset", type=str, default=None, help="The datasets to infer on") 100 | parser.add_argument("-o", "--output_dir", type=str, default=None, help="The path where the results are saved") 101 | args = parser.parse_args() 102 | return args 103 | 104 | 105 | def main(): 106 | args = get_instanseg_args() 107 | run_inference(input_dir=args.input, output_dir=args.output_dir, dataset=args.dataset) 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /experiments/benchmarking/outdated/cellvitplusplus/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from tqdm import tqdm 4 | from natsort import natsorted 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import imageio.v3 as imageio 9 | from skimage.measure import label 10 | 11 | from elf.evaluation import mean_segmentation_accuracy 12 | 13 | 14 | DATASETS = [ 15 | "consep", 16 | "cpm15", 17 | "cpm17", 18 | "cryonuseg", 19 | "lizard", 20 | "lynsec_he", 21 | "lynsec_ihc", 22 | "monusac", 23 | "monuseg", 24 | "nuclick", 25 | "nuinsseg", 26 | "pannuke", 27 | "puma", 28 | "srsanet", 29 | "tnbc", 30 | ] 31 | 32 | CVTPP_CP = [ 33 | # 'Virchow-x40-AMP', 34 | "SAM-H-x40-AMP", 35 | # '256-x40-AMP' 36 | ] 37 | 38 | 39 | def _run_evaluation(gt_paths, prediction_paths, verbose=True): 40 | print(len(gt_paths), len(prediction_paths)) 41 | assert len(gt_paths) == len(prediction_paths), \ 42 | f"label / prediction mismatch: {len(gt_paths)} / {len(prediction_paths)}" 43 | 44 | msas, sa50s, sa75s = [], [], [] 45 | for gt_path, pred_path in tqdm( 46 | zip(gt_paths, prediction_paths), desc="Evaluate predictions", total=len(gt_paths), disable=not verbose, 47 | ): 48 | assert os.path.exists(gt_path), gt_path 49 | assert os.path.exists(pred_path), pred_path 50 | 51 | gt = imageio.imread(gt_path) 52 | gt = label(gt) 53 | pred = imageio.imread(pred_path) 54 | 55 | msa, scores = mean_segmentation_accuracy(pred, gt, return_accuracies=True) 56 | sa50, sa75 = scores[0], scores[5] 57 | msas.append(msa), sa50s.append(sa50), sa75s.append(sa75) 58 | 59 | return msas, sa50s, sa75s 60 | 61 | 62 | def evaluate_cellvit(prediction_dir, checkpoint, dataset, result_dir, label_dir): 63 | save_path = os.path.join(result_dir, dataset, checkpoint, "ais_result.csv") 64 | if os.path.exists(save_path): 65 | print(f"Results for {dataset} evaluation already exist") 66 | return 67 | 68 | prediction_paths = natsorted(glob(os.path.join(prediction_dir, dataset, checkpoint, "*.tiff"))) 69 | gt_paths = natsorted(glob(os.path.join(label_dir, "*.tiff"))) 70 | if len(prediction_paths) == 0: 71 | print(f"No predictions for {dataset} dataset on {checkpoint} checkpoint found") 72 | return 73 | 74 | msas, sa50s, sa75s = _run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths) 75 | results = pd.DataFrame.from_dict( 76 | {"mSA": [np.mean(msas)], "SA50": [np.mean(sa50s)], "SA75": 
[np.mean(sa75s)]} 77 | ) 78 | print(results.head()) 79 | os.makedirs(os.path.join(result_dir, dataset, checkpoint), exist_ok=True) 80 | results.to_csv(save_path, index=False) 81 | 82 | 83 | def main(): 84 | input_dir = "/mnt/lustre-grete/usr/u12649/data/final_test" 85 | prediction_dir = "/mnt/lustre-grete/usr/u12649/models/cellvitpp/inference/" 86 | result_dir = "/mnt/lustre-grete/usr/u12649/models/cellvitpp/results" 87 | for dataset in DATASETS: 88 | label_dir = os.path.join(input_dir, dataset, "loaded_testset", "eval_split", "test_labels") 89 | for checkpoint in CVTPP_CP: 90 | evaluate_cellvit(prediction_dir, checkpoint, dataset, result_dir, label_dir) 91 | 92 | 93 | main() 94 | -------------------------------------------------------------------------------- /experiments/benchmarking/outdated/cellvitplusplus/infer_cvtplus.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import pandas as pd 4 | from glob import glob 5 | from natsort import natsorted 6 | 7 | DATASETS = [ 8 | "consep", 9 | "cpm15", 10 | "cpm17", 11 | "cryonuseg", 12 | "lizard", 13 | "lynsec_he", 14 | "lynsec_ihc", 15 | "monusac", 16 | "monuseg", 17 | "nuclick", 18 | "nuinsseg", 19 | "pannuke", 20 | "puma", 21 | "srsanet", 22 | "tnbc", 23 | ] 24 | 25 | CVTPP_CP = [ 26 | # 'Virchow-x40-AMP', 27 | 'SAM-H-x40-AMP', 28 | # '256-x40-AMP' 29 | ] 30 | 31 | 32 | def run_inference(model_dir, input_dir, output_dir, result_dir, label_dir): 33 | for dataset in DATASETS: 34 | data_dir = os.path.join(input_dir, dataset) 35 | tiff_paths = natsorted(glob(os.path.join(data_dir, '*.tiff')))  # collect the image paths once instead of re-globbing per column 36 | files = { 37 | "path": tiff_paths, 38 | "slide_mpp": [0.25] * len(tiff_paths), "magnification": [40] * len(tiff_paths), 39 | } 40 | filelist_df = pd.DataFrame(files) 41 | os.makedirs(os.path.join(input_dir, "file_lists"), exist_ok=True) 42 | csv_filelist = os.path.join(input_dir, "file_lists", f"{dataset}_filelist.csv") 43 | filelist_df.to_csv(csv_filelist, index=False) 44 | 45 | for checkpoint in CVTPP_CP: 46 | checkpoint_path = os.path.join(model_dir, f"CellViT-{checkpoint}.pth") 47 | output_path = os.path.join(output_dir, dataset, checkpoint) 48 | if os.path.exists(output_path): 49 | if len(os.listdir(output_path)) > 1: 50 | continue 51 | os.makedirs(output_path, exist_ok=True) 52 | args = [ 53 | "--model", f"{checkpoint_path}", 54 | "--outdir", f"{output_path}", 55 | "process_dataset", 56 | "--filelist", f"{csv_filelist}" 57 | ] 58 | command = [ 59 | "python3", 60 | "/user/titus.griebel/u12649/CellViT-plus-plus/cellvit/detect_cells.py", 61 | ] + args 62 | print(f"Running inference with CellViT-plus-plus {checkpoint} model on {dataset} dataset...") 63 | subprocess.run(command) 64 | 65 | for file in glob(os.path.join(output_path, '*json')): 66 | os.remove(file) 67 | # evaluate_cellvit(output_path, checkpoint, dataset, result_dir, label_dir) 68 | 69 | print(f"Successfully ran inference with CellViT {checkpoint} model on {dataset} dataset") 70 | 71 | 72 | def main(): 73 | run_inference( 74 | "/mnt/lustre-grete/usr/u12649/models/cellvit_plusplus/checkpoints", 75 | "/mnt/lustre-grete/usr/u12649/data/cvtplus/preprocessed", 76 | "/mnt/lustre-grete/usr/u12649/models/cellvit_plusplus/inference/", 77 | "/mnt/lustre-grete/usr/u12649/models/cellvit_plusplus/results", 78 | "/mnt/lustre-grete/usr/u12649/data/final_test" 79 | ) 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | 
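For reference, the filelist handed to `detect_cells.py` above is a plain three-column CSV. A minimal sketch (the image folder and output name are placeholders) for building the same kind of filelist for a custom set of images:

```python
import os
from glob import glob

import pandas as pd
from natsort import natsorted

# One row per image; CellViT-plus-plus additionally expects the pixel size in
# microns ("slide_mpp") and the nominal magnification for every image.
tiff_paths = natsorted(glob(os.path.join("/path/to/images", "*.tiff")))
filelist = pd.DataFrame(
    {"path": tiff_paths, "slide_mpp": [0.25] * len(tiff_paths), "magnification": [40] * len(tiff_paths)}
)
filelist.to_csv("custom_filelist.csv", index=False)
```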
-------------------------------------------------------------------------------- /experiments/benchmarking/stardist/stardist_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | from glob import glob 5 | from natsort import natsorted 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import imageio.v3 as imageio 10 | from skimage.measure import label 11 | 12 | from tukra.io import read_image 13 | from tukra.inference import segment_using_stardist 14 | 15 | from elf.evaluation import mean_segmentation_accuracy 16 | 17 | 18 | def _run_evaluation(gt_paths, prediction_paths, verbose=True): 19 | assert len(gt_paths) == len(prediction_paths), \ 20 | f"label / prediction mismatch: {len(gt_paths)} / {len(prediction_paths)}" 21 | 22 | msas, sa50s, sa75s = [], [], [] 23 | for gt_path, pred_path in tqdm( 24 | zip(gt_paths, prediction_paths, strict=False), 25 | desc="Evaluate predictions", 26 | total=len(gt_paths), 27 | disable=not verbose, 28 | ): 29 | assert os.path.exists(gt_path), gt_path 30 | assert os.path.exists(pred_path), pred_path 31 | 32 | gt = imageio.imread(gt_path) 33 | gt = label(gt) 34 | pred = imageio.imread(pred_path) 35 | 36 | msa, scores = mean_segmentation_accuracy(pred, gt, return_accuracies=True) 37 | sa50, sa75 = scores[0], scores[5] 38 | msas.append(msa), sa50s.append(sa50), sa75s.append(sa75) 39 | 40 | return msas, sa50s, sa75s 41 | 42 | 43 | def evaluate_stardist(prediction_dir, label_dir, result_dir, dataset): 44 | gt_paths = natsorted(glob(os.path.join(label_dir, "eval_split", "test_labels", "*"))) 45 | for checkpoint in ["stardist"]: 46 | save_path = os.path.join(result_dir, dataset, checkpoint, f'{dataset}_stardist_stardist_ais_result.csv') 47 | if os.path.exists(save_path): 48 | continue 49 | 50 | prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*"))) 51 | if len(prediction_paths) == 0: 52 | print(f"No predictions for {dataset} dataset on {checkpoint} checkpoint found") 53 | continue 54 | 55 | os.makedirs(os.path.join(result_dir, dataset, checkpoint), exist_ok=True) 56 | print(f"Evaluating {dataset} dataset on Stardist ...") 57 | msas, sa50s, sa75s = _run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths) 58 | results = pd.DataFrame.from_dict( 59 | {"mSA": [np.mean(msas)], "SA50": [np.mean(sa50s)], "SA75": [np.mean(sa75s)]} 60 | ) 61 | 62 | results.to_csv(save_path, index=False) 63 | print(results.head(2)) 64 | 65 | 66 | def infer_stardist(data_dir, output_path): 67 | image_paths = natsorted(glob(os.path.join(data_dir, "eval_split", "test_images", "*"))) 68 | os.makedirs(output_path, exist_ok=True) 69 | for image_path in image_paths: 70 | image = read_image(image_path) 71 | segmentation = segment_using_stardist(image=image, model_name="2D_versatile_he") 72 | imageio.imwrite(os.path.join(output_path, os.path.basename(image_path)), segmentation) 73 | 74 | 75 | def run_inference(input_dir, output_dir, dataset): 76 | output_path = os.path.join(output_dir, 'inference', dataset, "stardist") 77 | input_path = os.path.join(input_dir, dataset) 78 | if os.path.exists( 79 | os.path.join(output_dir, "results", dataset, "stardist", f'{dataset}_stardist_stardist_ais_result.csv') 80 | ): 81 | return 82 | 83 | os.makedirs(output_path, exist_ok=True) 84 | print(f"Running inference with StarDist model on {dataset} dataset... 
\n") 85 | infer_stardist(input_path, output_path) 86 | print(f"Inference on {dataset} dataset with the StarDist model successfully completed \n") 87 | 88 | evaluate_stardist( 89 | prediction_dir=output_path, 90 | label_dir=input_path, 91 | result_dir=os.path.join(output_dir, 'results'), 92 | dataset=dataset 93 | ) 94 | 95 | 96 | def get_stardist_args(): 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument("-i", "--input", type=str, default=None, help="The path to the input data") 99 | parser.add_argument("-d", "--dataset", type=str, default=None, help="The datasets to infer on") 100 | parser.add_argument("-o", "--output_dir", type=str, default=None, help="The path where the results are saved") 101 | args = parser.parse_args() 102 | return args 103 | 104 | 105 | def main(): 106 | args = get_stardist_args() 107 | run_inference(input_dir=args.input, output_dir=args.output_dir, dataset=args.dataset) 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /experiments/data/README.md: -------------------------------------------------------------------------------- 1 | # Scripts to Preprocess Datasets for Evaluation 2 | 3 | This folder contains scripts for preprocessing datasets for all of our benchmarking (both automatic and interactive segmentation) 4 | 5 | > TLDR: The idea of this scripts are to use the input images corresponding to how they should be on a practical point of view. You can write your own scripts / provide path to your images in our [example scripts](../../examples/). 6 | -------------------------------------------------------------------------------- /experiments/data/cvtplus_preproc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from glob import glob 4 | from tqdm import tqdm 5 | 6 | import numpy as np 7 | import imageio.v3 as imageio 8 | 9 | 10 | DATASETS = [ 11 | "consep", 12 | "cpm15", 13 | "cpm17", 14 | "cryonuseg", 15 | "lizard", 16 | "lynsec_he", 17 | "lynsec_ihc", 18 | "monusac", 19 | "monuseg", 20 | "nuclick", 21 | "nuinsseg", 22 | "pannuke", 23 | "puma", 24 | "srsanet", 25 | "tnbc", 26 | ] 27 | 28 | 29 | def preprocess_cvtplus(input_dir, output_dir): 30 | # create resized pyramid tiffs with vips 31 | import pyvips 32 | 33 | for dataset in DATASETS: 34 | data_dir = os.path.join(input_dir, dataset, "loaded_testset", "eval_split", "test_images") 35 | intermediate_folder = os.path.join(output_dir, "intermediate", dataset) 36 | output_folder = os.path.join(output_dir, "preprocessed", dataset) 37 | os.makedirs(intermediate_folder, exist_ok=True) 38 | os.makedirs(output_folder, exist_ok=True) 39 | for img_path in tqdm(glob(os.path.join(data_dir, "*.tiff"))): 40 | img = imageio.imread(img_path) 41 | img_uint = img.astype(np.uint8) 42 | intermediate_file = os.path.join(intermediate_folder, os.path.basename(img_path)) 43 | imageio.imwrite(intermediate_file, img_uint) 44 | output_file = os.path.join(output_folder, os.path.basename(img_path)) 45 | image = pyvips.Image.new_from_file(intermediate_file) 46 | image.tiffsave(output_file, tile=True, tile_width=512, tile_height=512, pyramid=True) 47 | 48 | shutil.rmtree(intermediate_folder) 49 | 50 | 51 | def main(): 52 | preprocess_cvtplus( 53 | input_dir="/mnt/lustre-grete/usr/u12649/data/final_test", output_dir="/mnt/lustre-grete/usr/u12649/data/cvtplus" 54 | ) 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | 
-------------------------------------------------------------------------------- /experiments/data/dataloaders.py: -------------------------------------------------------------------------------- 1 | import micro_sam.training as sam_training 2 | 3 | from torch_em.data import datasets 4 | from torch_em.data import MinInstanceSampler 5 | 6 | 7 | def get_dataloaders(patch_shape, data_path, dataset): 8 | raw_transform = sam_training.identity 9 | sampler = MinInstanceSampler(min_num_instances=3) 10 | 11 | if dataset == "consep": 12 | loader = datasets.get_consep_loader( 13 | path=data_path, 14 | patch_shape=patch_shape, 15 | batch_size=1, 16 | download=True, 17 | split="test", 18 | raw_transform=raw_transform, 19 | sampler=sampler, 20 | ) 21 | 22 | elif dataset == "cpm15": 23 | loader = datasets.get_cpm_loader( 24 | path=data_path, 25 | patch_shape=patch_shape, 26 | batch_size=1, 27 | download=False, 28 | split="test", 29 | raw_transform=raw_transform, 30 | sampler=sampler, 31 | data_choice="cpm15", 32 | ) 33 | 34 | elif dataset == "cpm17": 35 | loader = datasets.get_cpm_loader( 36 | path=data_path, 37 | patch_shape=patch_shape, 38 | batch_size=1, 39 | download=False, 40 | split="test", 41 | raw_transform=raw_transform, 42 | sampler=sampler, 43 | data_choice="cpm17", 44 | ) 45 | 46 | elif dataset == "cryonuseg": 47 | loader = datasets.get_cryonuseg_loader( 48 | path=data_path, 49 | patch_shape=(1,) + patch_shape, 50 | batch_size=1, 51 | rater="b1", 52 | split="test", 53 | download=True, 54 | raw_transform=raw_transform, 55 | sampler=sampler, 56 | ) 57 | 58 | elif dataset == "glas": 59 | loader = datasets.get_glas_loader( 60 | path=data_path, 61 | patch_shape=patch_shape, 62 | batch_size=1, 63 | download=True, 64 | split="test", 65 | raw_transform=raw_transform, 66 | sampler=MinInstanceSampler(min_num_instances=2), 67 | ) 68 | 69 | elif dataset == "lizard": 70 | loader = datasets.get_lizard_loader( 71 | path=data_path, 72 | patch_shape=patch_shape, 73 | batch_size=1, 74 | download=True, 75 | split="test", 76 | raw_transform=raw_transform, 77 | sampler=sampler, 78 | ) 79 | 80 | elif dataset == "lynsec_he": 81 | loader = datasets.get_lynsec_loader( 82 | path=data_path, 83 | patch_shape=patch_shape, 84 | batch_size=1, 85 | choice="h&e", 86 | download=True, 87 | raw_transform=raw_transform, 88 | sampler=sampler, 89 | ) 90 | 91 | elif dataset == "lynsec_ihc": 92 | loader = datasets.get_lynsec_loader( 93 | path=data_path, 94 | patch_shape=patch_shape, 95 | batch_size=1, 96 | choice="ihc", 97 | download=True, 98 | raw_transform=raw_transform, 99 | sampler=sampler, 100 | ) 101 | 102 | elif dataset == "monusac": 103 | loader = datasets.get_monusac_loader( 104 | path=data_path, 105 | patch_shape=patch_shape, 106 | batch_size=1, 107 | split="test", 108 | download=True, 109 | raw_transform=raw_transform, 110 | sampler=sampler, 111 | ) 112 | 113 | elif dataset == "monuseg": 114 | loader = datasets.get_monuseg_loader( 115 | path=data_path, 116 | patch_shape=patch_shape, 117 | batch_size=1, 118 | split="test", 119 | download=True, 120 | raw_transform=raw_transform, 121 | sampler=sampler, 122 | ) 123 | 124 | elif dataset == "nuinsseg": 125 | loader = datasets.get_nuinsseg_loader( 126 | path=data_path, 127 | patch_shape=patch_shape, 128 | batch_size=1, 129 | download=True, 130 | raw_transform=raw_transform, 131 | sampler=sampler, 132 | ) 133 | elif dataset == "nuclick": 134 | loader = datasets.get_nuclick_loader( 135 | path=data_path, 136 | patch_shape=patch_shape, 137 | batch_size=1, 138 | download=True, 139 | 
split="Validation", 140 | raw_transform=raw_transform, 141 | sampler=sampler, 142 | ) 143 | 144 | elif dataset == "pannuke": 145 | loader = datasets.get_pannuke_loader( 146 | path=data_path, 147 | patch_shape=(1,) + patch_shape, 148 | batch_size=1, 149 | folds=["fold_3"], 150 | ndim=2, 151 | download=True, 152 | raw_transform=raw_transform, 153 | sampler=sampler, 154 | ) 155 | 156 | elif dataset == "puma": 157 | loader = datasets.get_puma_loader( 158 | path=data_path, 159 | patch_shape=patch_shape, 160 | batch_size=1, 161 | annotations="nuclei", 162 | download=True, 163 | split="test", 164 | raw_transform=raw_transform, 165 | sampler=sampler, 166 | ) 167 | 168 | elif dataset == "srsanet": 169 | loader = datasets.get_srsanet_loader( 170 | path=data_path, 171 | patch_shape=patch_shape, 172 | batch_size=1, 173 | download=True, 174 | split="test", 175 | raw_transform=raw_transform, 176 | sampler=sampler, 177 | ) 178 | 179 | elif dataset == "tnbc": 180 | loader = datasets.get_tnbc_loader( 181 | path=data_path, 182 | patch_shape=patch_shape, 183 | batch_size=1, 184 | ndim=2, 185 | download=True, 186 | split="test", 187 | raw_transform=raw_transform, 188 | sampler=sampler, 189 | ) 190 | 191 | return loader 192 | -------------------------------------------------------------------------------- /experiments/data/get_paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from natsort import natsorted 4 | 5 | import h5py 6 | import imageio.v3 as imageio 7 | 8 | from torch_em.data.datasets import histopathology 9 | 10 | 11 | def get_dataset_paths(data_path, dataset) -> list: 12 | 13 | if dataset == "consep": 14 | data_paths = histopathology.consep.get_consep_paths( 15 | path=data_path, download=True, split="test", 16 | ) 17 | label_key = 'labels' 18 | image_key = 'raw' 19 | 20 | elif dataset == "cpm15": 21 | image_paths, label_paths = histopathology.cpm.get_cpm_paths( 22 | path=data_path, download=False, split="test", data_choice="cpm15", 23 | ) 24 | 25 | elif dataset == "cpm17": 26 | image_paths, label_paths = histopathology.cpm.get_cpm_paths( 27 | path=data_path, download=False, split="test", data_choice="cpm17", 28 | ) 29 | 30 | elif dataset == "cryonuseg": 31 | image_paths, label_paths = histopathology.cryonuseg.get_cryonuseg_paths( 32 | path=data_path, rater_choice="b1", split="test", download=True, 33 | ) 34 | 35 | elif dataset == "glas": 36 | data_paths = histopathology.glas.get_glas_paths( 37 | path=data_path, download=True, split="test", 38 | ) 39 | label_key = 'labels' 40 | image_key = 'raw' 41 | 42 | elif dataset == "lizard": 43 | data_paths = histopathology.lizard.get_lizard_paths( 44 | path=data_path, download=True, split="test", 45 | ) 46 | label_key = 'labels/segmentation' 47 | image_key = 'image' 48 | 49 | elif dataset == "lynsec_he": 50 | image_paths, label_paths = histopathology.lynsec.get_lynsec_paths( 51 | path=data_path, choice="h&e", download=True, 52 | ) 53 | 54 | elif dataset == "lynsec_ihc": 55 | image_paths, label_paths = histopathology.lynsec.get_lynsec_paths( 56 | path=data_path, choice="ihc", download=True, 57 | ) 58 | 59 | elif dataset == "monuseg": 60 | image_paths, label_paths = histopathology.monuseg.get_monuseg_paths( 61 | path=data_path, split="test", download=True, 62 | ) 63 | 64 | elif dataset == "nuclick": 65 | image_paths, label_paths = histopathology.nuclick.get_nuclick_paths( 66 | path=data_path, download=True, split="Validation", 67 | ) 68 | 69 | elif dataset == "nuinsseg": 70 | 
image_paths, label_paths = histopathology.nuinsseg.get_nuinsseg_paths( 71 | path=data_path, download=True, 72 | ) 73 | 74 | elif dataset == "pannuke": 75 | data_paths = histopathology.pannuke.get_pannuke_paths( 76 | path=data_path, folds=["fold_3"], download=True, 77 | ) 78 | cached_images = os.path.join(data_path, "loaded_images") 79 | cached_labels = os.path.join(data_path, "loaded_labels") 80 | os.makedirs(cached_images, exist_ok=True) 81 | os.makedirs(cached_labels, exist_ok=True) 82 | 83 | for h5_path in data_paths: 84 | with h5py.File(h5_path, 'r') as file: 85 | images = file['images'] 86 | labels = file['labels/instances'] 87 | images = images[:] 88 | labels = labels[:] 89 | 90 | # PanNuke is provided in an array of shape (C, B, H, W) 91 | images = images.transpose(1, 2, 3, 0) # --> (B, H, W, C) 92 | 93 | counter = 1 94 | for image, label in zip(images, labels): 95 | image_path = os.path.join(cached_images, f"{counter:04}.tiff") 96 | label_path = os.path.join(cached_labels, f"{counter:04}.tiff") 97 | 98 | assert image.shape == (256, 256, 3) 99 | imageio.imwrite(image_path, image) 100 | imageio.imwrite(label_path, label) 101 | 102 | counter += 1 103 | 104 | image_paths = glob(os.path.join(cached_images, "*.tiff")) 105 | label_paths = glob(os.path.join(cached_labels, "*.tiff")) 106 | 107 | elif dataset == "puma": 108 | data_paths = histopathology.puma.get_puma_paths( 109 | path=data_path, annotations="nuclei", download=True, split="test", 110 | ) 111 | label_key = 'labels/nuclei' 112 | image_key = 'raw' 113 | 114 | elif dataset == "srsanet": 115 | image_paths, label_paths = histopathology.srsanet.get_srsanet_paths( 116 | path=data_path, 117 | download=True, 118 | split="test", 119 | ) 120 | 121 | elif dataset == "tnbc": 122 | data_paths = histopathology.tnbc.get_tnbc_paths( 123 | path=data_path, download=True, split="test", 124 | ) 125 | label_key = 'labels/instances' 126 | image_key = 'raw' 127 | 128 | 129 | 130 | if dataset in ["consep", "lizard", "glas", "puma", "tnbc"]: 131 | cached_images = os.path.join(data_path, "loaded_images") 132 | cached_labels = os.path.join(data_path, "loaded_labels") 133 | os.makedirs(cached_images, exist_ok=True) 134 | os.makedirs(cached_labels, exist_ok=True) 135 | for h5_path in data_paths: 136 | with h5py.File(h5_path, 'r') as file: 137 | img = file[image_key] 138 | label = file[label_key] 139 | img = img[:] 140 | label = label[:] 141 | image = img.transpose(1, 2, 0) 142 | 143 | img_path = os.path.join( 144 | cached_images, os.path.basename(h5_path).replace(".h5", ".tiff")) 145 | label_path = os.path.join( 146 | cached_labels, os.path.basename(h5_path).replace(".h5", ".tiff")) 147 | assert image.shape[:2] == label.shape, f"{image.shape}, {label.shape}" 148 | 149 | imageio.imwrite(img_path, image) 150 | imageio.imwrite(label_path, label) 151 | image_paths = glob(os.path.join(cached_images, "*.tiff")) 152 | label_paths = glob(os.path.join(cached_labels, "*.tiff")) 153 | 154 | return natsorted(image_paths), natsorted(label_paths) 155 | -------------------------------------------------------------------------------- /experiments/data/interactive_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import imageio.v3 as imageio 5 | 6 | from dataloaders import get_dataloaders 7 | from util import DATASETS, create_val_split, remove_empty_labels 8 | 9 | 10 | def load_testsets(path, dsets=DATASETS, patch_shape=(512, 512)) -> None: 11 
| for dataset in dsets: 12 | if os.path.exists(os.path.join(path, dataset, "loaded_testset", "images")): 13 | if len(os.listdir(os.path.join(path, dataset, "loaded_testset", "images"))) > 1: 14 | print(f"Dataset {dataset} is loaded already.") 15 | continue 16 | 17 | print(f"Loading {dataset} dataset...") 18 | dpath = os.path.join(path, dataset) 19 | os.makedirs(dpath, exist_ok=True) 20 | loader = get_dataloaders(patch_shape=patch_shape, data_path=dpath, dataset=dataset) 21 | 22 | image_output_path = os.path.join(path, dataset, "loaded_testset", "images") 23 | label_output_path = os.path.join(path, dataset, "loaded_testset", "labels") 24 | 25 | os.makedirs(image_output_path, exist_ok=True) 26 | os.makedirs(label_output_path, exist_ok=True) 27 | 28 | for idx, (image, label) in enumerate(loader, start=1): 29 | image = image.squeeze().numpy() 30 | label = label.squeeze().numpy() 31 | image = image.transpose(1, 2, 0) 32 | if image.shape[-1] == 4: # deletes alpha channel if one exists 33 | image = image[..., :-1] 34 | 35 | imageio.imwrite(os.path.join(image_output_path, f"{idx:04}.tiff"), image) 36 | imageio.imwrite(os.path.join(label_output_path, f"{idx:04}.tiff"), label) 37 | 38 | remove_empty_labels(dpath) 39 | create_val_split(os.path.join(dpath, "loaded_testset"), custom_name="eval_split", dataset=dataset) 40 | print(f"{dataset} testset has successfully been loaded.") 41 | 42 | 43 | def dataloading_args(): 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("-p", "--path", type=str, default=None) 46 | parser.add_argument("-d", "--datasets", type=str, default=None) 47 | parser.add_argument("--patch_shape", type=int, nargs=2, default=(512, 512)) 48 | 49 | args = parser.parse_args() 50 | return args 51 | 52 | 53 | def main(): 54 | args = dataloading_args() 55 | if args.path is not None: 56 | data_path = args.path 57 | else: 58 | data_path = "/mnt/lustre-grete/usr/u12649/data/final_test/" 59 | 60 | if args.datasets is not None: 61 | load_testsets(data_path, [args.datasets], tuple(args.patch_shape)) 62 | else: 63 | load_testsets(data_path, patch_shape=tuple(args.patch_shape)) 64 | 65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /experiments/data/load_original_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | import imageio.v3 as imageio 7 | 8 | from util import DATASETS 9 | from get_paths import get_dataset_paths 10 | 11 | 12 | def load_datasets(path, datasets=DATASETS): 13 | for dataset in datasets: 14 | dataset_path = os.path.join(path, dataset) 15 | image_outpath = os.path.join(dataset_path, "loaded_images") 16 | label_outpath = os.path.join(dataset_path, "loaded_labels") 17 | 18 | 19 | os.makedirs(image_outpath, exist_ok=True) 20 | os.makedirs(label_outpath, exist_ok=True) 21 | if len(os.listdir(image_outpath)) > 1: 22 | continue 23 | 24 | print(f"Loading {dataset}...") 25 | 26 | 27 | 28 | 29 | image_paths, label_paths = get_dataset_paths(dataset_path, dataset) 30 | assert len(image_paths) == len(label_paths) 31 | 32 | count = 1 33 | for image_path, label_path in tqdm(zip(image_paths, label_paths), desc="Moving files to new directory..."): 34 | img_ext = os.path.splitext(image_path)[1] 35 | 
label_ext = os.path.splitext(label_path)[1] 36 | image_dest = os.path.join(image_outpath, f"{count:04}{img_ext}") 37 | label_dest = os.path.join(label_outpath, f"{count:04}{label_ext}") 38 | 39 | img = imageio.imread(image_path) 40 | if img.shape[2] == 4: # checks for and removes alpha channel 41 | img = img[:, :, :-1] 42 | imageio.imwrite(image_path, img) 43 | print("Alpha channel successfully removed.") 44 | 45 | shutil.move(image_path, image_dest) 46 | shutil.move(label_path, label_dest) 47 | 48 | count += 1 49 | 50 | 51 | def dataloading_args(): 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument("-p", "--path", type=str, default=None) 54 | parser.add_argument("-d", "--datasets", type=str, default=None) 55 | 56 | args = parser.parse_args() 57 | return args 58 | 59 | 60 | def main(): 61 | args = dataloading_args() 62 | if args.path is not None: 63 | data_path = args.path 64 | else: 65 | data_path = "/mnt/lustre-grete/usr/u12649/data/original_data" 66 | 67 | if args.datasets is not None: 68 | load_datasets(data_path, [args.datasets]) 69 | else: 70 | load_datasets(data_path) 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /experiments/data/load_padded_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | import numpy as np 7 | import imageio.v3 as imageio 8 | 9 | from util import DATASETS, PADDING_DS 10 | from get_paths import get_dataset_paths 11 | 12 | 13 | def load_datasets(path, datasets=DATASETS): 14 | for dataset in datasets: 15 | dataset_path = os.path.join(path, dataset) 16 | image_outpath = os.path.join(dataset_path, "loaded_images") 17 | label_outpath = os.path.join(dataset_path, "loaded_labels") 18 | os.makedirs(image_outpath, exist_ok=True) 19 | os.makedirs(label_outpath, exist_ok=True) 20 | if os.path.exists(os.path.join(dataset_path, "eval_split", "test_images")): 21 | if len(os.listdir(os.path.join(dataset_path, "eval_split", "test_images"))) > 1: 22 | continue 23 | 24 | print(f"Loading {dataset}...") 25 | image_paths, label_paths = get_dataset_paths(dataset_path, dataset) 26 | assert len(image_paths) == len(label_paths) 27 | 28 | count = 1 29 | for image_path, label_path in tqdm(zip(image_paths, label_paths), desc="Moving files to new directory..."): 30 | img_ext = os.path.splitext(image_path)[1] 31 | label_ext = os.path.splitext(label_path)[1] 32 | image_dest = os.path.join(image_outpath, f"{count:04}{img_ext}") 33 | label_dest = os.path.join(label_outpath, f"{count:04}{label_ext}") 34 | 35 | img = imageio.imread(image_path) 36 | if img.shape[2] == 4: # checks for and removes alpha channel 37 | img = img[:, :, :-1] 38 | imageio.imwrite(image_path, img) 39 | print("Alpha channel successfully removed.") 40 | 41 | if dataset in PADDING_DS: 42 | padded_img = np.zeros((512, 512, 3), dtype=img.dtype) 43 | padded_img[:256, :256, :] = img 44 | assert padded_img.shape == (512, 512, 3), padded_img.shape 45 | imageio.imwrite(image_path, padded_img) 46 | 47 | label = imageio.imread(label_path) 48 | padded_label = np.zeros((512, 512), dtype=label.dtype) 49 | padded_label[:256, :256] = label 50 | assert padded_label.shape == (512, 512), padded_label.shape 51 | imageio.imwrite(label_path, padded_label) 52 | 53 | if dataset != 'lizard': 54 | if img.shape[0] != img.shape[1] or img.shape[0] % 16 != 0: 55 | shape = img.shape 56 | new_shape = max(shape[:2]) // 16 57 
| new_dim = (new_shape + 1) * 16 58 | padded_img = np.zeros((new_dim, new_dim, 3), dtype=img.dtype) 59 | padded_img[:shape[0], :shape[1], :] = img 60 | 61 | label = imageio.imread(label_path) 62 | padded_label = np.zeros((new_dim, new_dim), dtype=label.dtype) 63 | padded_label[:shape[0], :shape[1]] = label 64 | imageio.imwrite(label_path, padded_label) 65 | assert padded_label.shape == padded_img.shape[:2] 66 | 67 | imageio.imwrite(image_path, padded_img) 68 | 69 | if dataset == 'lizard': 70 | shape = img.shape 71 | if img.shape[0] != img.shape[1] or img.shape[0] % 1024 != 0: 72 | new_shape = max(shape[:2]) // 1024 73 | new_dim = (new_shape + 1) * 1024 74 | padded_img = np.zeros((new_dim, new_dim, 3), dtype=img.dtype) 75 | padded_img[:shape[0], :shape[1], :] = img 76 | imageio.imwrite(image_path, padded_img) 77 | 78 | label = imageio.imread(label_path) 79 | padded_label = np.zeros((new_dim, new_dim), dtype=label.dtype) 80 | padded_label[:shape[0], :shape[1]] = label 81 | imageio.imwrite(label_path, padded_label) 82 | assert padded_label.shape == padded_img.shape[:2] 83 | 84 | shutil.move(image_path, image_dest) 85 | shutil.move(label_path, label_dest) 86 | 87 | count += 1 88 | 89 | 90 | def dataloading_args(): 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument("-p", "--path", type=str, default=None) 93 | parser.add_argument("-d", "--datasets", type=str, default=None) 94 | 95 | args = parser.parse_args() 96 | return args 97 | 98 | 99 | def main(): 100 | args = dataloading_args() 101 | if args.path is not None: 102 | data_path = args.path 103 | else: 104 | data_path = "/mnt/lustre-grete/usr/u12649/data/vit_data" 105 | 106 | if args.datasets is not None: 107 | load_datasets(data_path, [args.datasets]) 108 | else: 109 | load_datasets(data_path) 110 | 111 | 112 | if __name__ == "__main__": 113 | main() 114 | -------------------------------------------------------------------------------- /experiments/data/semantic/get_semantic_paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from natsort import natsorted 4 | 5 | import h5py 6 | import imageio.v3 as imageio 7 | 8 | from torch_em.data.datasets import histopathology 9 | 10 | from patho_sam.training.util import remap_labels 11 | 12 | 13 | def get_dataset_paths(data_path, dataset): 14 | cached_images = os.path.join(data_path, "loaded_images") 15 | cached_labels = os.path.join(data_path, "loaded_labels") 16 | os.makedirs(cached_images, exist_ok=True) 17 | os.makedirs(cached_labels, exist_ok=True) 18 | 19 | if dataset == "pannuke": 20 | data_paths = histopathology.pannuke.get_pannuke_paths( 21 | path=data_path, folds=["fold_3"], download=True, 22 | ) 23 | 24 | for h5_path in data_paths: 25 | with h5py.File(h5_path, 'r') as file: 26 | images = file['images'] 27 | labels = file['labels/semantic'] 28 | images = images[:] 29 | labels = labels[:] 30 | 31 | # PanNuke is provided in an array of shape (C, B, H, W) 32 | images = images.transpose(1, 2, 3, 0) # --> (B, H, W, C) 33 | 34 | counter = 1 35 | for image, label in zip(images, labels): 36 | image_path = os.path.join(cached_images, f"{counter:04}.tiff") 37 | label_path = os.path.join(cached_labels, f"{counter:04}.tiff") 38 | 39 | assert image.shape == (256, 256, 3) 40 | imageio.imwrite(image_path, image) 41 | imageio.imwrite(label_path, label) 42 | 43 | counter += 1 44 | 45 | elif dataset == "puma": 46 | data_paths = histopathology.puma.get_puma_paths( 47 | path=data_path, annotations="nuclei", 
download=True, split="test", 48 | ) 49 | 50 | for h5_path in data_paths: 51 | with h5py.File(h5_path, 'r') as file: 52 | img = file['raw'] 53 | label = file['labels/semantic/nuclei'] 54 | img = img[:] 55 | label = label[:] 56 | image = img.transpose(1, 2, 0) 57 | label = remap_labels(label, name=None) 58 | img_path = os.path.join( 59 | cached_images, os.path.basename(h5_path).replace(".h5", ".tiff")) 60 | label_path = os.path.join( 61 | cached_labels, os.path.basename(h5_path).replace(".h5", ".tiff")) 62 | assert image.shape[:2] == label.shape, f"{image.shape}, {label.shape}" 63 | 64 | imageio.imwrite(img_path, image) 65 | imageio.imwrite(label_path, label) 66 | 67 | else: 68 | raise ValueError 69 | 70 | image_paths = glob(os.path.join(cached_images, "*.tiff")) 71 | label_paths = glob(os.path.join(cached_labels, "*.tiff")) 72 | 73 | return natsorted(image_paths), natsorted(label_paths) 74 | -------------------------------------------------------------------------------- /experiments/patho-sam/README.md: -------------------------------------------------------------------------------- 1 | # Scripts for Evaluating Segment Anything Models for Histopathology 2 | 3 | This folder contains scripts for evaluating interactive and automatic instance segmentation of nucleus in histopathology images. 4 | 5 | > TLDR: The scripts expect path to input images and corresponding labels. You can write your own function to automate the processes in the scripts. The keywords mean: 1) AIS (Automatic Instance Segmentation), 2) AMG (Automatic Mask Generation), 3) Iterative Prompting (i.e. interactive segmentation, where prompts are derived from the ground-truth labels, simulating a user, and add additional points for 7 iterations). See our [preprint](https://doi.org/10.48550/arXiv.2502.00408) for more details on this. 
-------------------------------------------------------------------------------- /experiments/patho-sam/evaluate_ais.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from micro_sam.evaluation.evaluation import run_evaluation 4 | from micro_sam.evaluation.inference import run_instance_segmentation_with_decoder 5 | 6 | from util import VANILLA_MODELS, get_default_arguments, get_pred_paths, get_test_paths, get_val_paths 7 | 8 | 9 | def run_instance_segmentation_with_decoder_inference( 10 | model_type, checkpoint, experiment_folder, input_path, tiling_window_params 11 | ): 12 | val_image_paths, val_gt_paths = get_val_paths(input_path) 13 | test_image_paths, _ = get_test_paths(input_path) 14 | prediction_folder = run_instance_segmentation_with_decoder( 15 | checkpoint, model_type, experiment_folder, val_image_paths, val_gt_paths, test_image_paths, 16 | tiling_window_params=tiling_window_params 17 | ) 18 | return prediction_folder 19 | 20 | 21 | def eval_instance_segmentation_with_decoder(prediction_folder, experiment_folder, input_path): 22 | print("Evaluating", prediction_folder) 23 | _, gt_paths = get_test_paths(input_path) 24 | pred_paths = get_pred_paths(prediction_folder) 25 | save_path = os.path.join(experiment_folder, "results", "instance_segmentation_with_decoder.csv") 26 | res = run_evaluation(gt_paths, pred_paths, save_path=save_path) 27 | print(res) 28 | 29 | 30 | def main(): 31 | args = get_default_arguments() 32 | 33 | if args.checkpoint is None: 34 | ckpt = VANILLA_MODELS[args.model] 35 | else: 36 | ckpt = args.checkpoint 37 | 38 | if args.tiling_window: 39 | tiling_window_params = {"tile_shape": [384, 384], "halo": [64, 64]} 40 | else: 41 | tiling_window_params = None 42 | 43 | prediction_folder = run_instance_segmentation_with_decoder_inference( 44 | args.model, ckpt, args.experiment_folder, args.input_path, tiling_window_params 45 | ) 46 | eval_instance_segmentation_with_decoder(prediction_folder, args.experiment_folder, args.input_path) 47 | 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /experiments/patho-sam/evaluate_amg.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from micro_sam.evaluation.evaluation import run_evaluation 4 | from micro_sam.evaluation.inference import run_amg 5 | 6 | from util import ( 7 | VANILLA_MODELS, get_default_arguments, get_pred_paths, get_test_paths, get_val_paths, 8 | ) 9 | 10 | 11 | def run_amg_inference(model_type, checkpoint, experiment_folder, input_path, tiling_window_params=None): 12 | val_image_paths, val_gt_paths = get_val_paths(input_path) 13 | test_image_paths, _ = get_test_paths(input_path) 14 | prediction_folder = run_amg( 15 | checkpoint, 16 | model_type, 17 | experiment_folder, 18 | val_image_paths, 19 | val_gt_paths, 20 | test_image_paths, 21 | tiling_window_params=tiling_window_params, 22 | ) 23 | return prediction_folder 24 | 25 | 26 | def eval_amg(prediction_folder, experiment_folder, input_path): 27 | print("Evaluating", prediction_folder) 28 | _, gt_paths = get_test_paths(input_path) 29 | pred_paths = get_pred_paths(prediction_folder) 30 | save_path = os.path.join(experiment_folder, "results", "amg.csv") 31 | res = run_evaluation(gt_paths, pred_paths, save_path=save_path) 32 | print(res) 33 | 34 | 35 | def main(): 36 | args = get_default_arguments() 37 | if args.checkpoint is None: 38 | ckpt = VANILLA_MODELS[args.model] 39 
| else: 40 | ckpt = args.checkpoint 41 | 42 | if args.tiling_window: 43 | tiling_window_params = {"tile_shape": [384, 384], "halo": [64, 64]} 44 | else: 45 | tiling_window_params = None 46 | 47 | prediction_folder = run_amg_inference( 48 | args.model, ckpt, args.experiment_folder, args.input_path, tiling_window_params 49 | ) 50 | eval_amg(prediction_folder, args.experiment_folder, args.input_path) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /experiments/patho-sam/evaluate_iterative_prompting.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from micro_sam.evaluation import inference 4 | from micro_sam.evaluation.evaluation import run_evaluation_for_iterative_prompting 5 | 6 | from util import get_default_arguments, get_model, get_test_paths 7 | 8 | 9 | def _run_iterative_prompting(exp_folder, predictor, start_with_box_prompt, use_masks, input_path): 10 | prediction_root = os.path.join(exp_folder, "start_with_box" if start_with_box_prompt else "start_with_point") 11 | embedding_folder = os.path.join(exp_folder, "embeddings") 12 | image_paths, gt_paths = get_test_paths(input_path) 13 | inference.run_inference_with_iterative_prompting( 14 | predictor=predictor, 15 | image_paths=image_paths, 16 | gt_paths=gt_paths, 17 | embedding_dir=embedding_folder, 18 | prediction_dir=prediction_root, 19 | start_with_box_prompt=start_with_box_prompt, 20 | use_masks=use_masks, 21 | ) 22 | return prediction_root 23 | 24 | 25 | def _evaluate_iterative_prompting(prediction_root, start_with_box_prompt, exp_folder, input_path): 26 | _, gt_paths = get_test_paths(input_path) 27 | 28 | run_evaluation_for_iterative_prompting( 29 | gt_paths=gt_paths, 30 | prediction_root=prediction_root, 31 | experiment_folder=exp_folder, 32 | start_with_box_prompt=start_with_box_prompt, 33 | ) 34 | 35 | 36 | def main(): 37 | args = get_default_arguments() 38 | 39 | start_with_box_prompt = args.box # overwrite to start first iters' prompt with box instead of single point 40 | 41 | # Get the predictor to perform inference 42 | predictor = get_model(model_type=args.model, ckpt=args.checkpoint) 43 | 44 | prediction_root = _run_iterative_prompting( 45 | args.experiment_folder, predictor, start_with_box_prompt, args.use_masks, args.input_path 46 | ) 47 | _evaluate_iterative_prompting(prediction_root, start_with_box_prompt, args.experiment_folder, args.input_path) 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /experiments/patho-sam/get_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | from natsort import natsorted 3 | 4 | import pandas as pd 5 | 6 | 7 | SAM_TYPES = ["vit_b", "vit_l", "vit_h", "vit_b_lm"] 8 | SAM_MODELS = ["generalist_sam", "lm_sam", "pannuke_sam", "vanilla_sam", "nuclick_sam", "glas_sam", "cryonuseg_sam"] 9 | MODEL_NAMES = ["hovernet", "cellvit", "hovernext", "stardist", "cellvitpp", "instanseg"] + SAM_MODELS 10 | 11 | DATASETS = [ 12 | "consep", 13 | "cpm15", 14 | "cpm17", 15 | "cryonuseg", 16 | "glas", 17 | "lizard", 18 | "lynsec_he", 19 | "lynsec_ihc", 20 | "monusac", 21 | "monuseg", 22 | "nuclick", 23 | "nuinsseg", 24 | "pannuke", 25 | "puma", 26 | "srsanet", 27 | "tnbc", 28 | ] 29 | 30 | HNXT_CP = [ 31 | "lizard_convnextv2_large", 32 | "lizard_convnextv2_base", 33 | "lizard_convnextv2_tiny", 34 | 
"pannuke_convnextv2_tiny_1", 35 | "pannuke_convnextv2_tiny_2", 36 | "pannuke_convnextv2_tiny_3", 37 | ] 38 | 39 | CVT_CP = [ 40 | "256-x20", 41 | "256-x40", 42 | "SAM-H-x20", 43 | "SAM-H-x40", 44 | ] 45 | 46 | HVNT_CP = [ 47 | "consep", 48 | "cpm17", 49 | "kumar", 50 | "monusac", 51 | "pannuke", 52 | ] 53 | 54 | CVTPP_CP = ["SAM-H-x40-AMP"] 55 | 56 | CHECKPOINTS = { 57 | "hovernet": HVNT_CP, 58 | "hovernext": HNXT_CP, 59 | "cellvit": CVT_CP, 60 | "cellvitpp": CVTPP_CP, 61 | "generalist_sam": SAM_TYPES, 62 | "pannuke_sam": SAM_TYPES, 63 | "old_generalist_sam": SAM_TYPES, 64 | "nuclick_sam": SAM_TYPES, 65 | "vanilla_sam": SAM_TYPES, 66 | "glas_sam": SAM_TYPES, 67 | "cryonuseg_sam": SAM_TYPES, 68 | "lm_sam": ["vit_b_lm"], 69 | "stardist": ["stardist"], 70 | "instanseg": ["instanseg"], 71 | } 72 | 73 | 74 | def get_results(path, overwrite=False): 75 | os.makedirs(os.path.join(path, "sum_results", "concatenated_results"), exist_ok=True) 76 | for mode in ['ais', 'amg', 'boxes', 'points']: 77 | csv_concat = os.path.join(path, "sum_results", "concatenated_results", f"concatenated_{mode}_results.csv") 78 | concat_df = pd.DataFrame() 79 | for model in MODEL_NAMES: 80 | if model not in SAM_MODELS and mode != 'ais': 81 | continue 82 | 83 | if model == "vanilla_sam" and mode == 'ais': 84 | continue 85 | 86 | for checkpoint in CHECKPOINTS[model]: 87 | if mode in ['ais', 'amg']: 88 | result_dict = {"dataset": [], "msa": [], "sa50": [], "sa75": []} 89 | else: 90 | result_dict = { 91 | "dataset": [], 92 | "msa_1st": [], 93 | "msa_8th": [], 94 | "sa50_1st": [], 95 | "sa50_8th": [], 96 | "sa75_1st": [], 97 | "sa75_8th": [], 98 | } 99 | 100 | os.makedirs(os.path.join(path, "sum_results", model), exist_ok=True) 101 | csv_out = os.path.join(path, "sum_results", model, f"{mode}_{model}_{checkpoint}_results.csv") 102 | if os.path.exists(csv_out) and not overwrite: 103 | print(f"{csv_out} already exists.") 104 | continue 105 | 106 | for dataset in natsorted(DATASETS): 107 | if model in SAM_MODELS: 108 | csv_path = os.path.join( 109 | path, model, "results", dataset, mode, f"{dataset}_{model}_{checkpoint}_{mode}.csv" 110 | ) 111 | else: 112 | csv_path = os.path.join( 113 | path, model, "results", dataset, checkpoint, 114 | f"{dataset}_{model}_{checkpoint}_{mode}_result.csv" 115 | ) 116 | 117 | if not os.path.exists(csv_path): 118 | continue 119 | 120 | df = pd.read_csv(csv_path) 121 | if mode in ['ais', 'amg']: 122 | result_dict["msa"].append(df.loc[0, "mSA"]) 123 | result_dict["sa50"].append(df.loc[0, "SA50"]) 124 | result_dict["sa75"].append(df.loc[0, "SA75"]) 125 | result_dict["dataset"].append(dataset) 126 | else: 127 | result_dict["msa_1st"].append(df.loc[0, "mSA"]) 128 | result_dict["sa50_1st"].append(df.loc[0, "SA50"]) 129 | result_dict["sa75_1st"].append(df.loc[0, "SA75"]) 130 | result_dict["msa_8th"].append(df.loc[7, "mSA"]) 131 | result_dict["sa50_8th"].append(df.loc[7, "SA50"]) 132 | result_dict["sa75_8th"].append(df.loc[7, "SA75"]) 133 | result_dict["dataset"].append(dataset) 134 | 135 | df = pd.DataFrame(result_dict) 136 | if df.empty: 137 | continue 138 | 139 | df.to_csv(csv_out, index=False) 140 | if mode in ["amg", "ais"]: 141 | df = df[["dataset", "msa"]] 142 | df.rename(columns={"msa": f"{model}_{checkpoint}_{mode}"}, inplace=True) 143 | else: 144 | df = df[["dataset", "msa_1st", "msa_8th"]] 145 | df.rename( 146 | columns={ 147 | "msa_1st": f"{model}_{checkpoint}_{mode}", 148 | "msa_8th": f"{model}_{checkpoint}_I{mode[0]}" 149 | }, 150 | inplace=True, 151 | ) 152 | 153 | if concat_df.empty: 154 | 
concat_df = df 155 | else: 156 | concat_df = pd.merge(concat_df, df, on="dataset", how="outer") 157 | print(f"{model} (checkpoint: {checkpoint}) added to concat results") 158 | 159 | if not concat_df.empty: 160 | concat_df.to_csv(csv_concat, index=False) 161 | 162 | 163 | def main(): 164 | get_results("/mnt/lustre-grete/usr/u12649/models/", overwrite=True) 165 | 166 | 167 | if __name__ == "__main__": 168 | main() 169 | -------------------------------------------------------------------------------- /experiments/patho-sam/per_image_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from tqdm import tqdm 4 | from natsort import natsorted 5 | from typing import List, Optional, Union 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import imageio.v3 as imageio 10 | from skimage.measure import label 11 | 12 | from elf.evaluation import mean_segmentation_accuracy 13 | 14 | 15 | DATASETS = [ 16 | "consep", 17 | "cpm15", 18 | "cpm17", 19 | "cryonuseg", 20 | "lizard", 21 | "lynsec_he", 22 | "lynsec_ihc", 23 | "monusac", 24 | "monuseg", 25 | "nuclick", 26 | "nuinsseg", 27 | "pannuke", 28 | "puma", 29 | "srsanet", 30 | "tnbc", 31 | ] 32 | 33 | SAM_MODELS = ['generalist_sam', 'pannuke_sam'] 34 | 35 | MODELS = [ 36 | 'stardist', 37 | 'hovernet', 38 | 'hovernext', 39 | 'instanseg', 40 | 'cellvit', 41 | # 'cellvitpp', 42 | 'generalist_sam', 43 | 'pannuke_sam', 44 | # 'old_generalist_sam' 45 | ] 46 | 47 | MODEL_NAMES = { 48 | 'stardist': "StarDist", 49 | 'hovernet': "HoVerNet", 50 | 'hovernext': "HoVerNeXt", 51 | 'instanseg': "InstanSeg", 52 | 'cellvit': "CellViT", 53 | 'pannuke_sam': "Patho-SAM (Specialist)", 54 | 'generalist_sam': "Patho-SAM (Generalist)", 55 | } 56 | 57 | HNXT_CP = [ 58 | "lizard_convnextv2_large", 59 | "lizard_convnextv2_base", 60 | "lizard_convnextv2_tiny", 61 | "pannuke_convnextv2_tiny_1", 62 | "pannuke_convnextv2_tiny_2", 63 | "pannuke_convnextv2_tiny_3", 64 | ] 65 | CVT_CP = [ 66 | "256-x20", 67 | "256-x40", 68 | "SAM-H-x20", 69 | "SAM-H-x40", 70 | ] 71 | 72 | CVTPP_CP = ["SAM-H-x40-AMP"] 73 | 74 | SAM_TYPES = ["vit_b", "vit_l", "vit_h"] 75 | 76 | 77 | HVNT_CP = [ 78 | 'consep', 79 | 'cpm17', 80 | 'kumar', 81 | 'monusac', 82 | 'pannuke', 83 | ] 84 | 85 | CHECKPOINTS = { 86 | 'hovernet': HVNT_CP, 87 | 'hovernext': HNXT_CP, 88 | 'cellvit': CVT_CP, 89 | 'cellvitpp': CVTPP_CP, 90 | 'generalist_sam': SAM_TYPES, 91 | 'pannuke_sam': ['vit_b'], 92 | 'old_generalist_sam': ['vit_b'], 93 | 'stardist': ['stardist'], 94 | 'instanseg': ['instanseg'], 95 | } 96 | 97 | 98 | def _run_evaluation(gt_paths, prediction_paths, verbose=True): 99 | assert len(gt_paths) == len(prediction_paths), f"{len(gt_paths)}, {len(prediction_paths)}" 100 | msas, sa50s, sa75s = [], [], [] 101 | 102 | for gt_path, pred_path in tqdm( 103 | zip(gt_paths, prediction_paths), desc="Evaluate predictions", total=len(gt_paths), disable=not verbose 104 | ): 105 | assert os.path.exists(gt_path), gt_path 106 | assert os.path.exists(pred_path), pred_path 107 | 108 | gt = imageio.imread(gt_path) 109 | gt = label(gt) 110 | pred = imageio.imread(pred_path) 111 | 112 | msa, scores = mean_segmentation_accuracy(pred, gt, return_accuracies=True) 113 | sa50, sa75 = scores[0], scores[5] 114 | msas.append(msa), sa50s.append(sa50), sa75s.append(sa75) 115 | 116 | return msas, sa50s, sa75s 117 | 118 | 119 | def run_evaluation( 120 | gt_paths: List[Union[os.PathLike, str]], 121 | prediction_paths: List[Union[os.PathLike, str]], 122 | save_path: 
Optional[Union[os.PathLike, str]] = None, 123 | verbose: bool = True, 124 | ) -> np.ndarray: 125 | """Run evaluation for instance segmentation predictions. 126 | 127 | Args: 128 | gt_paths: The list of paths to ground-truth images. 129 | prediction_paths: The list of paths with the instance segmentations to evaluate. 130 | save_path: Optional path for saving the per-image results as a numpy array. 131 | verbose: Whether to print the progress. 132 | 133 | Returns: 134 | An array with the per-image mean segmentation accuracy scores. 135 | """ 136 | assert len(gt_paths) == len(prediction_paths), f"{len(gt_paths)}, {len(prediction_paths)}" 137 | # if a save_path is given and it already exists then just load it instead of running the eval 138 | if save_path is not None and os.path.exists(save_path): 139 | return np.load(save_path) 140 | 141 | msas, sa50s, sa75s = _run_evaluation(gt_paths, prediction_paths, verbose=verbose) 142 | 143 | results = np.array(msas) 144 | np.save(save_path, results) 145 | return results 146 | 147 | def per_sample_eval(inf_path, data_path): 148 | for model in MODELS: 149 | for checkpoint in CHECKPOINTS[model]: 150 | for dataset in DATASETS: 151 | if model not in SAM_MODELS: 152 | inference_paths = natsorted( 153 | glob(os.path.join(inf_path, model, "inference", dataset, checkpoint, "*.tiff")) 154 | ) 155 | else: 156 | inference_paths = natsorted( 157 | glob( 158 | os.path.join( 159 | inf_path, model, "inference", dataset, checkpoint, "instance", 160 | "instance_segmentation_with_decoder", "inference", "*.tiff" 161 | ) 162 | ) 163 | ) 164 | 165 | label_paths = natsorted( 166 | glob(os.path.join(data_path, dataset, "loaded_testset", "eval_split", "test_labels", "*.tiff")) 167 | ) 168 | save_path = os.path.join( 169 | inf_path, model, "results", "per_sample", f"{model}_{dataset}_{checkpoint}_ps.npy" 170 | ) 171 | if os.path.exists(save_path): 172 | continue 173 | 174 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 175 | print(f"evaluating {model}, checkpoint {checkpoint}") 176 | run_evaluation(gt_paths=label_paths, prediction_paths=inference_paths, save_path=save_path) 177 | 178 | 179 | per_sample_eval( 180 | data_path="/mnt/lustre-grete/usr/u12649/data/final_test", 181 | inf_path="/mnt/lustre-grete/usr/u12649/models", 182 | ) 183 | -------------------------------------------------------------------------------- /experiments/patho-sam/push_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from datetime import datetime 4 | 5 | from util import DATASETS, SAM_TYPES, MODEL_NAMES 6 | 7 | 8 | CHECKPOINTS = { 9 | "vanilla_sam": SAM_TYPES, 10 | "generalist_sam": SAM_TYPES, 11 | "pannuke_sam": ["vit_b"], 12 | "lm_sam": ["vit_b_lm"], 13 | "glas_sam": ["vit_b"], 14 | "nuclick_sam": ["vit_b"], 15 | } 16 | 17 | 18 | def write_batch_script(out_path, _name, model_type, dataset, dry, model_name): 19 | "Writing scripts for patho-sam inference." 
20 | batch_script = """#!/bin/bash 21 | #SBATCH -t 2-00:00:00 22 | #SBATCH --nodes=1 23 | #SBATCH --ntasks=1 24 | #SBATCH -p grete:shared 25 | #SBATCH -G A100:1 26 | #SBATCH -A nim00007 27 | #SBATCH -x ggpu114 28 | #SBATCH -c 16 29 | #SBATCH --job-name=patho-sam-inference 30 | 31 | source ~/.bashrc 32 | conda activate sam2 \n""" 33 | 34 | # python script 35 | python_script = f"python {_name}.py " 36 | python_script += f"-d {dataset} " # dataset to infer on 37 | python_script += f"-m {model_type} " # name of the model configuration 38 | python_script += f"-n {model_name} " # name of the model 39 | 40 | # let's add the python script to the bash script 41 | batch_script += python_script 42 | 43 | _op = out_path[:-3] + f"_{os.path.split(_name)[-1]}.sh" 44 | with open(_op, "w") as f: 45 | f.write(batch_script) 46 | 47 | cmd = ["sbatch", _op] 48 | if not dry: 49 | subprocess.run(cmd) 50 | 51 | 52 | def get_batch_script_names(tmp_folder): 53 | tmp_folder = os.path.expanduser(tmp_folder) 54 | os.makedirs(tmp_folder, exist_ok=True) 55 | 56 | script_name = "patho-sam-inference" 57 | 58 | dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") 59 | tmp_name = script_name + dt 60 | batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh") 61 | 62 | return batch_script 63 | 64 | 65 | def submit_slurm(args): 66 | "Submit python script that needs gpus with given inputs on a slurm node." 67 | tmp_folder = "./gpu_jobs" 68 | model_path = "/mnt/lustre-grete/usr/u12649/models" 69 | script_names = ["run_ais", "run_amg"] 70 | for script_name in script_names: 71 | print(f"Running for {script_name}") 72 | for dataset in DATASETS: 73 | for model_name in MODEL_NAMES: 74 | for model in CHECKPOINTS[model_name]: 75 | if model_name == 'glas_sam' and dataset != 'glas': 76 | continue 77 | 78 | if model_name == 'nuclick_sam' and dataset != 'nuclick': 79 | continue 80 | 81 | if script_name != 'run_amg' and model_name == 'vanilla_sam': 82 | continue 83 | 84 | result = os.path.join( 85 | model_path, model_name, "results", dataset, 86 | f"{script_name[4:]}", f"{dataset}_{model_name}_{model}_{script_name[4:]}.csv" 87 | ) 88 | print(result) 89 | if os.path.exists(result): 90 | continue 91 | 92 | write_batch_script( 93 | out_path=get_batch_script_names(tmp_folder), 94 | _name=script_name, 95 | model_type=model, 96 | dataset=dataset, 97 | dry=args.dry, 98 | model_name=model_name, 99 | ) 100 | 101 | 102 | def main(args): 103 | submit_slurm(args) 104 | 105 | 106 | if __name__ == "__main__": 107 | import argparse 108 | parser = argparse.ArgumentParser() 109 | parser.add_argument("-m", "--model_type", type=str, default=None) 110 | parser.add_argument("--dry", action="store_true") 111 | args = parser.parse_args() 112 | main(args) 113 | -------------------------------------------------------------------------------- /experiments/patho-sam/run_ais.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import subprocess 5 | 6 | from util import TILING_WINDOW_DS, PADDING_DS 7 | 8 | 9 | def run_inference(input_dir, output_dir, model_type, dataset, model_name, checkpoint_path): 10 | output_path = os.path.join(output_dir, model_name, "inference", dataset, model_type, "instance") 11 | os.makedirs(output_path, exist_ok=True) 12 | if dataset not in PADDING_DS: 13 | input_path = os.path.join(input_dir, "original_data", dataset, "eval_split") 14 | else: 15 | input_path = os.path.join(input_dir, "vit_data", dataset, "eval_split") 16 | 17 | args = [ 18 | "-m", 
f"{model_type}", 19 | "-c", f"{checkpoint_path}", 20 | "--experiment_folder", f"{output_path}", 21 | "-i", f"{input_path}", 22 | ] 23 | 24 | if dataset in TILING_WINDOW_DS: 25 | args.append("--tiling_window") 26 | command = [ 27 | "python3", os.path.expanduser("~/patho-sam/experiments/patho-sam/evaluate_ais.py"), 28 | ] + args 29 | print(f"Running inference with {model_name} model (type: {model_type}) on {dataset} dataset...") 30 | subprocess.run(command) 31 | os.makedirs(os.path.join(output_dir, model_name, 'results', dataset, 'ais'), exist_ok=True) 32 | 33 | shutil.copy( 34 | os.path.join( 35 | output_dir, model_name, "inference", dataset, model_type, 'instance', 36 | 'results', 'instance_segmentation_with_decoder.csv' 37 | ), 38 | os.path.join(output_dir, model_name, 'results', dataset, 'ais', f'{dataset}_{model_name}_{model_type}_ais.csv') 39 | ) 40 | 41 | print(f"Successfully ran inference with {model_name} model (type: {model_type}) on {dataset} dataset") 42 | 43 | 44 | def get_ais_args(): 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument( 47 | "-d", "--dataset", type=str, default=None, help="The dataset to infer on." 48 | ) 49 | parser.add_argument( 50 | "-m", "--model", type=str, default=None, help="Provide the model type to infer with {vit_b, vit_l, vit_h}." 51 | ) 52 | parser.add_argument( 53 | "-o", "--output_dir", type=str, default=None, help="Provide path where the results will be stored." 54 | ) 55 | parser.add_argument( 56 | "-i", "--input_dir", type=str, default=None, help="Provide path where the dataset is located." 57 | ) 58 | parser.add_argument( 59 | "-n", "--name", type=str, default=None, 60 | help="Provide the name of the model to infer with {generalist_sam, vanilla_sam, ..}." 61 | ) 62 | parser.add_argument( 63 | "-c", "--checkpoint_path", type=str, default=None, 64 | help="(Optional) provide the path to the checkpoint to use for inference." 
65 |     )
66 |     return parser.parse_args()
67 | 
68 | 
69 | def main():
70 |     args = get_ais_args()
71 |     run_inference(
72 |         output_dir=args.output_dir,
73 |         input_dir=args.input_dir,
74 |         model_type=args.model,
75 |         dataset=args.dataset,
76 |         model_name=args.name,
77 |         checkpoint_path=args.checkpoint_path,
78 |     )
79 | 
80 | 
81 | if __name__ == "__main__":
82 |     main()
83 | 
-------------------------------------------------------------------------------- /experiments/patho-sam/run_amg.py: --------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import subprocess
5 | 
6 | from util import TILING_WINDOW_DS, PADDING_DS
7 | 
8 | 
9 | def run_inference(input_dir, output_dir, model_type, dataset, model_name, checkpoint_path):
10 |     output_path = os.path.join(output_dir, model_name, "inference", dataset, model_type, "amg")
11 |     os.makedirs(output_path, exist_ok=True)
12 |     if os.path.exists(
13 |         os.path.join(output_dir, model_name, 'results', dataset, 'amg', f'{dataset}_{model_name}_{model_type}_amg.csv')
14 |     ):
15 |         print(f"Inference with {model_name} model (type: {model_type}) on {dataset} dataset already done")
16 |         return
17 | 
18 |     if dataset not in PADDING_DS:
19 |         input_path = os.path.join(input_dir, "original_data", dataset, "eval_split")
20 |     else:
21 |         input_path = os.path.join(input_dir, "vit_data", dataset, "eval_split")
22 | 
23 |     args = [
24 |         "-m", f"{model_type}",
25 |         "-c", f"{checkpoint_path}",
26 |         "--experiment_folder", f"{output_path}",
27 |         "-i", f"{input_path}",
28 |     ]
29 | 
30 |     if dataset in TILING_WINDOW_DS:
31 |         args.append("--tiling_window")
32 | 
33 |     command = [
34 |         "python3", os.path.expanduser("~/patho-sam/experiments/patho-sam/evaluate_amg.py"),
35 |     ] + args
36 |     print(f"Running inference with {model_name} model (type: {model_type}) on {dataset} dataset...")
37 |     subprocess.run(command)
38 |     os.makedirs(os.path.join(output_dir, model_name, 'results', dataset, 'amg'), exist_ok=True)
39 | 
40 |     shutil.copy(
41 |         os.path.join(output_dir, model_name, "inference", dataset, model_type, 'amg', 'results', 'amg.csv'),
42 |         os.path.join(output_dir, model_name, 'results', dataset, 'amg', f'{dataset}_{model_name}_{model_type}_amg.csv')
43 |     )
44 |     print(f"Successfully ran amg inference with {model_name} model (type: {model_type}) on {dataset} dataset")
45 | 
46 | 
47 | def get_amg_args():
48 |     parser = argparse.ArgumentParser()
49 |     parser.add_argument(
50 |         "-d", "--dataset", type=str, default=None, help="The dataset to infer on."
51 |     )
52 |     parser.add_argument(
53 |         "-m", "--model", type=str, default=None, help="Provide the model type to infer with {vit_b, vit_l, vit_h}."
54 |     )
55 |     parser.add_argument(
56 |         "-o", "--output_dir", type=str, default=None, help="Provide path where the results will be stored."
57 |     )
58 |     parser.add_argument(
59 |         "-i", "--input_dir", type=str, default=None, help="Provide path where the dataset is located."
60 |     )
61 |     parser.add_argument(
62 |         "-n", "--name", type=str, default=None,
63 |         help="Provide the name of the model to infer with {generalist_sam, pannuke_sam, vanilla_sam, ..}."
64 |     )
65 |     parser.add_argument(
66 |         "-c", "--checkpoint_path", type=str, default=None,
67 |         help="(Optional) provide the path to the checkpoint to use for inference."
68 |     )
69 |     return parser.parse_args()
70 | 
71 | 
72 | def main():
73 |     args = get_amg_args()
74 |     run_inference(
75 |         output_dir=args.output_dir,
76 |         input_dir=args.input_dir,
77 |         model_type=args.model,
78 |         dataset=args.dataset,
79 |         model_name=args.name,
80 |         checkpoint_path=args.checkpoint_path,
81 |     )
82 | 
83 | 
84 | if __name__ == "__main__":
85 |     main()
86 | 
-------------------------------------------------------------------------------- /experiments/patho-sam/run_iter_boxes.py: --------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import subprocess
5 | 
6 | 
7 | def run_inference(input_dir, output_dir, model_type, dataset, model_name, checkpoint_path, use_masks):
8 |     output_path = os.path.join(output_dir, model_name, "inference", dataset, model_type, "boxes")
9 |     os.makedirs(output_path, exist_ok=True)
10 |     if os.path.exists(
11 |         os.path.join(
12 |             output_dir, model_name, 'results', dataset, 'boxes',
13 |             f'{dataset}_{model_name}_{model_type}_boxes.csv'
14 |         )
15 |     ):
16 |         print(f"Inference with {model_name} model on {dataset} dataset already done")
17 |         return
18 | 
19 |     input_path = os.path.join(input_dir, dataset, "loaded_testset", "eval_split")
20 |     args = [
21 |         "-m", f"{model_type}",
22 |         "-c", f"{checkpoint_path}",
23 |         "--experiment_folder", f"{output_path}",
24 |         "-i", f"{input_path}",
25 |         "--box",
26 |     ]
27 | 
28 |     if use_masks:
29 |         args.append("--use_masks")
30 | 
31 |     command = [
32 |         "python3", os.path.expanduser("~/patho-sam/experiments/patho-sam/evaluate_iterative_prompting.py"),
33 |     ] + args
34 | 
35 |     print(f"Running inference with {model_name} model (type: {model_type}) on {dataset} dataset...")
36 |     subprocess.run(command)
37 |     shutil.rmtree(os.path.join(output_path, "embeddings"))
38 |     os.makedirs(os.path.join(output_dir, model_name, 'results', dataset, 'boxes'), exist_ok=True)
39 |     shutil.copy(
40 |         os.path.join(
41 |             output_dir, model_name, "inference", dataset, model_type, 'boxes', 'results',
42 |             'iterative_prompting_without_mask', 'iterative_prompts_start_box.csv'
43 |         ),
44 |         os.path.join(
45 |             output_dir, model_name, 'results', dataset, 'boxes', f'{dataset}_{model_name}_{model_type}_boxes.csv'
46 |         )
47 |     )
48 |     print(f"Successfully ran inference with {model_name} model (type: {model_type}) on {dataset} dataset")
49 | 
50 | 
51 | def get_iterative_boxes_args():
52 |     parser = argparse.ArgumentParser()
53 |     parser.add_argument(
54 |         "-d", "--dataset", type=str, default=None, help="The dataset to infer on."
55 |     )
56 |     parser.add_argument(
57 |         "-m", "--model", type=str, default=None, help="Provide the model type to infer with {vit_b, vit_l, vit_h}."
58 |     )
59 |     parser.add_argument(
60 |         "-o", "--output_dir", type=str, default=None, help="Provide path where the results will be stored."
61 |     )
62 |     parser.add_argument(
63 |         "-i", "--input_dir", type=str, default=None, help="Provide path where the dataset is located."
64 |     )
65 |     parser.add_argument(
66 |         "-n", "--name", type=str, default=None,
67 |         help="Provide the name of the model to infer with {generalist_sam, pannuke_sam, vanilla_sam, ..}."
68 |     )
69 |     parser.add_argument(
70 |         "-c", "--checkpoint_path", type=str, default=None,
71 |         help="(Optional) provide the path to the checkpoint to use for inference."
72 |     )
73 |     parser.add_argument(
74 |         "--masks_off", action="store_false", help="To disable the usage of logit masks for iterative prompting."
75 |     )
76 |     return parser.parse_args()
77 | 
78 | 
79 | def main():
80 |     args = get_iterative_boxes_args()
81 |     run_inference(
82 |         output_dir=args.output_dir,
83 |         input_dir=args.input_dir,
84 |         model_type=args.model,
85 |         dataset=args.dataset,
86 |         model_name=args.name,
87 |         checkpoint_path=args.checkpoint_path,
88 |         use_masks=args.masks_off,
89 |     )
90 | 
91 | 
92 | if __name__ == "__main__":
93 |     main()
94 | 
-------------------------------------------------------------------------------- /experiments/patho-sam/run_iter_points.py: --------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import subprocess
5 | 
6 | 
7 | def run_inference(input_dir, output_dir, model_type, dataset, model_name, checkpoint_path, use_masks):
8 |     output_path = os.path.join(output_dir, model_name, "inference", dataset, model_type, "points")
9 |     os.makedirs(output_path, exist_ok=True)
10 |     if os.path.exists(
11 |         os.path.join(
12 |             output_dir, model_name, 'results', dataset, 'points',
13 |             f'{dataset}_{model_name}_{model_type}_points.csv'
14 |         )
15 |     ):
16 |         print(f"Inference with {model_name} model (type: {model_type}) on {dataset} dataset already done")
17 |         return
18 | 
19 |     input_path = os.path.join(input_dir, dataset, "loaded_testset", "eval_split")
20 |     args = [
21 |         "-m", f"{model_type}",
22 |         "-c", f"{checkpoint_path}",
23 |         "--experiment_folder", f"{output_path}",
24 |         "-i", f"{input_path}",
25 |     ]
26 | 
27 |     if use_masks:
28 |         args.append("--use_masks")
29 | 
30 |     command = [
31 |         "python3", os.path.expanduser("~/patho-sam/experiments/patho-sam/evaluate_iterative_prompting.py"),
32 |     ] + args
33 | 
34 |     print(f"Running inference with {model_name} model (type: {model_type}) on {dataset} dataset...")
35 |     subprocess.run(command)
36 |     embedding_path = os.path.join(output_path, "embeddings")
37 |     shutil.rmtree(embedding_path)
38 |     os.makedirs(os.path.join(output_dir, model_name, 'results', dataset, 'points'), exist_ok=True)
39 |     shutil.copy(
40 |         os.path.join(
41 |             output_dir, model_name, "inference", dataset, model_type, 'points', 'results',
42 |             'iterative_prompting_without_mask', 'iterative_prompts_start_point.csv'
43 |         ),
44 |         os.path.join(
45 |             output_dir, model_name, 'results', dataset, 'points',
46 |             f'{dataset}_{model_name}_{model_type}_points.csv'
47 |         )
48 |     )
49 | 
50 |     print(
51 |         f"Successfully ran iterative points inference with {model_name} "
52 |         f"model (type: {model_type}) on {dataset} dataset"
53 |     )
54 | 
55 | 
56 | def get_iterative_points_args():
57 |     parser = argparse.ArgumentParser()
58 |     parser.add_argument(
59 |         "-d", "--dataset", type=str, default=None, help="The dataset to infer on."
60 |     )
61 |     parser.add_argument(
62 |         "-m", "--model", type=str, default=None, help="Provide the model type to infer with {vit_b, vit_l, vit_h}."
63 |     )
64 |     parser.add_argument(
65 |         "-o", "--output_dir", type=str, default=None, help="Provide path where the results will be stored."
66 |     )
67 |     parser.add_argument(
68 |         "-i", "--input_dir", type=str, default=None, help="Provide path where the dataset is located."
69 |     )
70 |     parser.add_argument(
71 |         "-n", "--name", type=str, default=None,
72 |         help="Provide the name of the model to infer with {generalist_sam, pannuke_sam, vanilla_sam, ..}."
73 |     )
74 |     parser.add_argument(
75 |         "-c", "--checkpoint_path", type=str, default=None,
76 |         help="(Optional) provide the path to the checkpoint to use for inference."
77 |     )
78 |     parser.add_argument(
79 |         "--masks_off", action="store_false", help="To disable the usage of logit masks for iterative prompting."
80 |     )
81 |     return parser.parse_args()
82 | 
83 | 
84 | def main():
85 |     args = get_iterative_points_args()
86 |     run_inference(
87 |         output_dir=args.output_dir,
88 |         input_dir=args.input_dir,
89 |         model_type=args.model,
90 |         dataset=args.dataset,
91 |         model_name=args.name,
92 |         checkpoint_path=args.checkpoint_path,
93 |         use_masks=args.masks_off,
94 |     )
95 | 
96 | 
97 | if __name__ == "__main__":
98 |     main()
99 | 
-------------------------------------------------------------------------------- /experiments/patho-sam/util.py: --------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from glob import glob
4 | from natsort import natsorted
5 | 
6 | from micro_sam.util import get_sam_model
7 | 
8 | 
9 | ROOT = "/scratch/projects/nim00007/sam/data/"
10 | EXPERIMENT_ROOT = "/scratch/projects/nim00007/sam/experiments/new_models"
11 | 
12 | VANILLA_MODELS = {
13 |     "vit_t": "/scratch-grete/projects/nim00007/sam/models/vanilla/vit_t_mobile_sam.pth",
14 |     "vit_b": "/scratch-grete/projects/nim00007/sam/models/vanilla/sam_vit_b_01ec64.pth",
15 |     "vit_l": "/scratch-grete/projects/nim00007/sam/models/vanilla/sam_vit_l_0b3195.pth",
16 |     "vit_h": "/scratch-grete/projects/nim00007/sam/models/vanilla/sam_vit_h_4b8939.pth",
17 |     "vit_b_lm": None,
18 |     "vit_b_histopathology": None,
19 | }
20 | 
21 | 
22 | SAM_TYPES = ["vit_b", "vit_l", "vit_h"]
23 | 
24 | MODEL_NAMES = ["lm_sam", "vanilla_sam", "generalist_sam", "pannuke_sam", "nuclick_sam", "glas_sam"]
25 | 
26 | DATASETS = [
27 |     "consep",
28 |     "cpm15",
29 |     "cpm17",
30 |     "cryonuseg",
31 |     "glas",
32 |     "lizard",
33 |     "lynsec_he",
34 |     "lynsec_ihc",
35 |     "monuseg",
36 |     "nuclick",
37 |     "nuinsseg",
38 |     "pannuke",
39 |     "puma",
40 |     "srsanet",
41 |     "tnbc",
42 | ]
43 | 
44 | TILING_WINDOW_DS = [
45 |     "cpm15",
46 |     "consep",
47 |     "lizard",
48 |     "monuseg",
49 |     "puma",
50 | ]
51 | 
52 | PADDING_DS = [
53 |     "pannuke",
54 |     "srsanet",
55 |     "nuclick",
56 | ]
57 | 
58 | 
59 | def get_dataset_paths(dataset_name, split_choice):
60 |     file_search_specs = "*"
61 |     is_explicit_split = True
62 | 
63 |     # if the datasets have different modalities/species, let's make use of it
64 |     split_names = dataset_name.split("/")
65 |     if len(split_names) > 1:
66 |         assert len(split_names) <= 2
67 |         dataset_name = [split_names[0], "slices", split_names[1]]
68 |     else:
69 |         dataset_name = [*split_names, "slices"]
70 | 
71 |     # if there is an explicit val/test split made, let's look at them
72 |     if is_explicit_split:
73 |         dataset_name.append(split_choice)
74 | 
75 |     raw_dir = os.path.join(ROOT, *dataset_name, "raw", file_search_specs)
76 |     labels_dir = os.path.join(ROOT, *dataset_name, "labels", file_search_specs)
77 | 
78 |     return raw_dir, labels_dir
79 | 
80 | 
81 | def get_model(model_type, ckpt):
82 |     if ckpt is None:
83 |         ckpt = VANILLA_MODELS[model_type]
84 |     predictor = get_sam_model(model_type=model_type, checkpoint_path=ckpt)
85 |     return predictor
86 | 
87 | 
88 | def get_pred_paths(prediction_folder):
89 |     pred_paths = sorted(glob(os.path.join(prediction_folder, "*")))
90 |     return pred_paths
91 | 
92 | 
93 | def get_default_arguments():
94 |     parser = argparse.ArgumentParser()
95 |     parser.add_argument(
96 |         "-m", "--model", type=str, required=True, help="Provide the model type to initialize the predictor"
97 |     )
98 |     parser.add_argument("-c", "--checkpoint", type=none_or_str, default=None)  # expects best.pt
99 |     parser.add_argument("-e",
"--experiment_folder", type=str, required=True) # empty directory for saving the output 100 | parser.add_argument( 101 | "-i", "--input_path", type=str, required=True, default=None, 102 | help="Requires path to a directory containing 'test_images', 'test_labels', 'val_images' and 'val_labels' directories that contain the data", # noqa 103 | ) 104 | parser.add_argument( 105 | "--tiling_window", action="store_true", help="To use tiling window for inputs larger than 512 x 512" 106 | ) 107 | parser.add_argument( 108 | "--box", action="store_true", help="If passed, starts with first prompt as box" 109 | ) 110 | parser.add_argument( 111 | "--use_masks", action="store_true", help="To use logits masks for iterative prompting." 112 | ) 113 | args = parser.parse_args() 114 | return args 115 | 116 | 117 | def dataloading_args(): 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument("-p", "--path", type=str, default=None) 120 | parser.add_argument("-d", "--datasets", type=str, default=None) 121 | parser.add_argument("--patch_shape", type=tuple, default=(512, 512)) 122 | 123 | args = parser.parse_args() 124 | return args 125 | 126 | 127 | def none_or_str(value): 128 | if value == "None": 129 | return None 130 | return value 131 | 132 | 133 | def get_val_paths(input_path): 134 | val_image_paths = natsorted(glob(os.path.join(input_path, "val_images/*"))) 135 | val_label_paths = natsorted(glob(os.path.join(input_path, "val_labels/*"))) 136 | return val_image_paths, val_label_paths 137 | 138 | 139 | def get_test_paths(input_path): 140 | test_image_paths = natsorted(glob(os.path.join(input_path, "test_images/*"))) 141 | test_label_paths = natsorted(glob(os.path.join(input_path, "test_labels/*"))) 142 | return test_image_paths, test_label_paths 143 | 144 | 145 | def get_inference_args(): 146 | parser = argparse.ArgumentParser() 147 | parser.add_argument( 148 | "-d", "--dataset", type=str, default=None, 149 | help="The dataset to infer on. If None, all datasets will be chosen." 150 | ) 151 | parser.add_argument( 152 | "-m", "--model", type=str, default=None, 153 | help="Provide the model type to infer with {vit_b, vit_l, vit_h}." 154 | ) 155 | parser.add_argument( 156 | "--use_masks", action="store_true", help="To use logits masks for iterative prompting." 157 | ) 158 | parser.add_argument( 159 | "--tiling_window", action="store_true", 160 | help="To use tiling window for inputs larger than 512 x 512") 161 | parser.add_argument( 162 | "-n", "--name", type=str, default=None, 163 | help="Provide the name of the model to infer with {generalist_sam, vanilla_sam, ..}." 164 | ) 165 | args = parser.parse_args() 166 | return args 167 | -------------------------------------------------------------------------------- /experiments/semantic_segmentation/README.md: -------------------------------------------------------------------------------- 1 | # Scripts for Training and Evaluating Semantic Segmentation using PathoSAM 2 | 3 | This folder contains scripts for training a semantic segmentation model, starting from `patho-sam` generalist models. Here is a brief mention of the folders: 4 | - `benchmarking`: Contains scripts to run BioMedParse for semantic segmentation in histopathology images and scripts to evaluate other benchmark methods for the task. 5 | - `generalists`: Contains scripts to run the training and evaluation of semantic segmentation (in all combinations) from `patho-sam` generalist models. 6 | - `results`: Contains scripts for plotting the results for semantic segmentation. 
7 | -------------------------------------------------------------------------------- /experiments/semantic_segmentation/benchmark_methods/evaluate_other_methods.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from tqdm import tqdm 4 | from natsort import natsorted 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | from tukra.io import read_image 10 | 11 | from patho_sam.evaluation.evaluation import semantic_segmentation_quality, extract_class_weights_for_pannuke 12 | 13 | 14 | ROOT = "/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/semantic/external" 15 | 16 | CLASS_IDS = [1, 2, 3, 4, 5] 17 | 18 | 19 | def evaluate_benchmark_methods(per_class_weights): 20 | # Get the original images first. 21 | image_paths = natsorted(glob(os.path.join(ROOT, "semantic_split", "test_images", "*.tiff"))) 22 | gt_paths = natsorted(glob(os.path.join(ROOT, "semantic_split", "test_labels", "*.tiff"))) 23 | 24 | assert image_paths and len(image_paths) == len(gt_paths) 25 | 26 | cellvit_256_20_scores, cellvit_256_40_scores, cellvit_sam_20_scores, cellvit_sam_40_scores = [], [], [], [] 27 | hovernet_scores = [] 28 | hovernext_1_scores, hovernext_2_scores = [], [] 29 | for image_path, gt_path in tqdm(zip(image_paths, gt_paths), total=len(image_paths)): 30 | # Load the input image and corresponding labels. 31 | gt = read_image(gt_path) 32 | 33 | # If the inputs do not have any semantic labels, we do not evaluate them! 34 | if len(np.unique(gt)) == 1: 35 | continue 36 | 37 | # Get the filename 38 | fname = os.path.basename(image_path) 39 | 40 | # Get predictions and scores per experiment. 41 | # 1. cellvit results (for all models). 42 | cellvit_256_20 = read_image(os.path.join(ROOT, "cellvit", "256-x20", fname)) 43 | cellvit_256_40 = read_image(os.path.join(ROOT, "cellvit", "256-x40", fname)) 44 | cellvit_sam_20 = read_image(os.path.join(ROOT, "cellvit", "SAM-H-x20", fname)) 45 | cellvit_sam_40 = read_image(os.path.join(ROOT, "cellvit", "SAM-H-x40", fname)) 46 | 47 | cellvit_256_20_scores.append( 48 | semantic_segmentation_quality(ground_truth=gt, segmentation=cellvit_256_20, class_ids=CLASS_IDS) 49 | ) 50 | cellvit_256_40_scores.append( 51 | semantic_segmentation_quality(ground_truth=gt, segmentation=cellvit_256_40, class_ids=CLASS_IDS) 52 | ) 53 | cellvit_sam_20_scores.append( 54 | semantic_segmentation_quality(ground_truth=gt, segmentation=cellvit_sam_20, class_ids=CLASS_IDS) 55 | ) 56 | cellvit_sam_40_scores.append( 57 | semantic_segmentation_quality(ground_truth=gt, segmentation=cellvit_sam_40, class_ids=CLASS_IDS) 58 | ) 59 | 60 | # 2. hovernet results. 61 | hovernet = read_image(os.path.join(ROOT, "hovernet", "pannuke", fname)) 62 | 63 | hovernet_scores.append( 64 | semantic_segmentation_quality(ground_truth=gt, segmentation=hovernet, class_ids=CLASS_IDS) 65 | ) 66 | 67 | # 3. hovernext results. 
68 | hovernext_1 = read_image(os.path.join(ROOT, "hovernext", "pannuke_convnextv2_tiny_1", fname)) 69 | hovernext_2 = read_image(os.path.join(ROOT, "hovernext", "pannuke_convnextv2_tiny_2", fname)) 70 | 71 | hovernext_1_scores.append( 72 | semantic_segmentation_quality(ground_truth=gt, segmentation=hovernext_1, class_ids=CLASS_IDS) 73 | ) 74 | hovernext_2_scores.append( 75 | semantic_segmentation_quality(ground_truth=gt, segmentation=hovernext_2, class_ids=CLASS_IDS) 76 | ) 77 | 78 | def _get_average_results(sq_per_image, fname): 79 | msq_neoplastic_cells = np.nanmean([sq[0] for sq in sq_per_image]) 80 | msq_inflammatory = np.nanmean([sq[1] for sq in sq_per_image]) 81 | msq_connective = np.nanmean([sq[2] for sq in sq_per_image]) 82 | msq_dead = np.nanmean([sq[3] for sq in sq_per_image]) 83 | msq_epithelial = np.nanmean([sq[4] for sq in sq_per_image]) 84 | 85 | all_msq = [msq_neoplastic_cells, msq_inflammatory, msq_connective, msq_dead, msq_epithelial] 86 | weighted_mean_msq = [msq * weight for msq, weight in zip(all_msq, per_class_weights)] 87 | 88 | results = { 89 | "neoplastic_cells": msq_neoplastic_cells, 90 | "inflammatory_cells": msq_inflammatory, 91 | "connective_cells": msq_connective, 92 | "dead_cells": msq_dead, 93 | "epithelial_cells": msq_epithelial, 94 | "weighted_mean": np.sum(weighted_mean_msq), 95 | "absolute_mean": np.mean(all_msq) 96 | } 97 | results = pd.DataFrame.from_dict([results]) 98 | results.to_csv(fname) 99 | print(results) 100 | 101 | # Get average results per method. 102 | _get_average_results(cellvit_256_20_scores, "cellvit_256_20_semantic.csv") 103 | _get_average_results(cellvit_256_40_scores, "cellvit_256_40_semantic.csv") 104 | _get_average_results(cellvit_sam_20_scores, "cellvit_sam_20_semantic.csv") 105 | _get_average_results(cellvit_sam_40_scores, "cellvit_sam_40_semantic.csv") 106 | _get_average_results(hovernet_scores, "hovernet_semantic.csv") 107 | _get_average_results(hovernext_1_scores, "hovernext_1_semantic.csv") 108 | _get_average_results(hovernext_2_scores, "hovernext_2_semantic.csv") 109 | 110 | 111 | def main(): 112 | # Get per class weights. 113 | fpath = os.path.join(*ROOT.rsplit("/")[:-2], "data", "pannuke", "pannuke_fold_3.h5") 114 | fpath = "/" + fpath 115 | per_class_weights = extract_class_weights_for_pannuke(fpath=fpath) 116 | 117 | # Run evaluation for external benchmark methods. 118 | evaluate_benchmark_methods(per_class_weights) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /experiments/semantic_segmentation/benchmark_methods/run_biomedparse.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from tqdm import tqdm 4 | from natsort import natsorted 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import imageio.v3 as imageio 9 | 10 | from tukra.io import read_image 11 | from tukra.inference import get_biomedparse 12 | 13 | from patho_sam.evaluation.evaluation import semantic_segmentation_quality, extract_class_weights_for_pannuke 14 | 15 | 16 | ROOT = "/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/semantic/external" 17 | 18 | MAPPING = { 19 | "neoplastic cells": 1, 20 | "inflammatory cells": 2, 21 | "connective tissue cells": 3, 22 | "dead cells": 4, # NOTE: dead cells are not a semantic class involved in biomedparse. 
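    # (The evaluation below therefore expects a zero score for this class.)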
23 | "epithelial cells": 5, 24 | } 25 | 26 | 27 | def evaluate_biomedparse_for_histopathology(dataset): 28 | 29 | # Other stuff for biomedparse 30 | modality = "Pathology" # choice of modality to determine the semantic targets. 31 | model = get_biomedparse.get_biomedparse_model() # get the biomedparse model. 32 | 33 | # Get per class weights. 34 | fpath = os.path.join(*ROOT.rsplit("/")[:-2], "data", "pannuke", "pannuke_fold_3.h5") 35 | fpath = "/" + fpath 36 | per_class_weights = extract_class_weights_for_pannuke(fpath=fpath) 37 | 38 | # Get the inputs and corresponding labels. 39 | image_paths = natsorted(glob(os.path.join(ROOT, dataset, "test_images", "*.tiff"))) 40 | gt_paths = natsorted(glob(os.path.join(ROOT, dataset, "test_labels", "*.tiff"))) 41 | 42 | assert image_paths and len(image_paths) == len(gt_paths) 43 | 44 | # Get results directory 45 | result_dir = os.path.join(ROOT, "biomedparse", dataset) 46 | os.makedirs(result_dir, exist_ok=True) 47 | 48 | sq_per_image = [] 49 | for image_path, gt_path in tqdm(zip(image_paths, gt_paths), total=len(image_paths)): 50 | # Get the input image and corresponding semantic labels. 51 | image = read_image(image_path).astype("float32") 52 | gt = read_image(gt_path) 53 | 54 | # If the inputs do not have any semantic labels, we do not evaluate them! 55 | if len(np.unique(gt)) == 1: 56 | continue 57 | 58 | # Run inference per image. 59 | prediction = get_biomedparse.run_biomedparse_automatic_inference( 60 | input_path=image, modality_type=modality, model=model, verbose=False, 61 | ) 62 | 63 | semantic_seg = np.zeros_like(gt, dtype="uint8") 64 | if prediction is not None: 65 | prompts = list(prediction.keys()) # Extracting detected classes. 66 | segmentations = list(prediction.values()) # Extracting the segmentations. 67 | 68 | # Map all predicted labels. 69 | for prompt, per_seg in zip(prompts, segmentations): 70 | semantic_seg[per_seg > 0] = MAPPING[prompt] 71 | 72 | # Evaluate scores. 73 | sq_score = semantic_segmentation_quality( 74 | ground_truth=gt, segmentation=semantic_seg, class_ids=list(MAPPING.values()), 75 | ) 76 | sq_per_image.append(sq_score) 77 | 78 | # Store the semantic segmentation results to avoid running them all the time. 79 | imageio.imwrite(os.path.join(result_dir, os.path.basename(image_path)), semantic_seg, compression="zlib") 80 | 81 | def _get_average_results(sq_per_image, fname): 82 | msq_neoplastic_cells = np.nanmean([sq[0] for sq in sq_per_image]) 83 | msq_inflammatory = np.nanmean([sq[1] for sq in sq_per_image]) 84 | msq_connective = np.nanmean([sq[2] for sq in sq_per_image]) 85 | msq_dead = np.nanmean([sq[3] for sq in sq_per_image]) # This class is absent. So this would obv. be zero. 86 | msq_epithelial = np.nanmean([sq[4] for sq in sq_per_image]) 87 | 88 | all_msq = [msq_neoplastic_cells, msq_inflammatory, msq_connective, msq_dead, msq_epithelial] 89 | weighted_mean_msq = [msq * weight for msq, weight in zip(all_msq, per_class_weights)] 90 | 91 | results = { 92 | "neoplastic_cells": msq_neoplastic_cells, 93 | "inflammatory_cells": msq_inflammatory, 94 | "connective_cells": msq_connective, 95 | "dead_cells": msq_dead, 96 | "epithelial_cells": msq_epithelial, 97 | "weighted_mean": np.sum(weighted_mean_msq), 98 | "absolute_mean": np.mean(all_msq) 99 | } 100 | results = pd.DataFrame.from_dict([results]) 101 | results.to_csv(fname) 102 | print(results) 103 | 104 | # Get average results for biomedparse. 
105 |     _get_average_results(sq_per_image, "biomedparse_semantic.csv")
106 | 
107 | 
108 | def main():
109 |     # Run automatic (semantic) segmentation inference using biomedparse
110 |     evaluate_biomedparse_for_histopathology("pannuke")
111 |     evaluate_biomedparse_for_histopathology("puma")
112 | 
113 | 
114 | if __name__ == "__main__":
115 |     main()
116 | 
-------------------------------------------------------------------------------- /experiments/semantic_segmentation/generalists/evaluate_pannuke.py: --------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from tqdm import tqdm
4 | from natsort import natsorted
5 | 
6 | import numpy as np
7 | import pandas as pd
8 | 
9 | import torch
10 | 
11 | from tukra.io import read_image
12 | 
13 | from micro_sam.util import get_sam_model
14 | from micro_sam.instance_segmentation import get_unetr
15 | 
16 | from patho_sam.evaluation.evaluation import semantic_segmentation_quality, extract_class_weights_for_pannuke
17 | 
18 | 
19 | ROOT = "/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/semantic/external"
20 | 
21 | 
22 | def evaluate_pannuke_semantic_segmentation(args):
23 |     # Settings needed for inference.
24 |     model_type = args.model_type
25 |     checkpoint_path = args.checkpoint_path
26 |     num_classes = 6  # available classes are [0, 1, 2, 3, 4, 5]
27 |     device = "cuda" if torch.cuda.is_available() else "cpu"
28 | 
29 |     # Get per class weights.
30 |     fpath = os.path.join(*ROOT.rsplit("/")[:-2], "data", "pannuke", "pannuke_fold_3.h5")
31 |     fpath = "/" + fpath
32 |     per_class_weights = extract_class_weights_for_pannuke(fpath=fpath)
33 | 
34 |     # Get the inputs and corresponding labels.
35 |     image_paths = natsorted(glob(os.path.join(ROOT, "semantic_split", "test_images", "*.tiff")))
36 |     gt_paths = natsorted(glob(os.path.join(ROOT, "semantic_split", "test_labels", "*.tiff")))
37 | 
38 |     assert len(image_paths) == len(gt_paths) and image_paths
39 | 
40 |     # Get the SAM model
41 |     predictor = get_sam_model(model_type=model_type, device=device)
42 | 
43 |     # Get the UNETR model for semantic segmentation pipeline
44 |     unetr = get_unetr(
45 |         image_encoder=predictor.model.image_encoder, out_channels=num_classes, device=device,
46 |     )
47 | 
48 |     # Load the model weights
49 |     model_state = torch.load(checkpoint_path, map_location="cpu")["model_state"]
50 |     unetr.load_state_dict(model_state)
51 |     unetr.to(device)
52 |     unetr.eval()
53 | 
54 |     sq_per_image = []
55 |     with torch.no_grad():
56 |         for image_path, gt_path in tqdm(zip(image_paths, gt_paths), total=len(image_paths)):
57 |             # Read input image and corresponding labels.
58 |             image = read_image(image_path)
59 |             gt = read_image(gt_path)
60 | 
61 |             # If the inputs do not have any semantic labels, we do not evaluate them!
62 |             if len(np.unique(gt)) == 1:
63 |                 continue
64 | 
65 |             # Pad the input image to fit the trained image shape.
66 |             # NOTE: The original content stays in the top-left corner of the padded array.
67 |             image = np.pad(array=image, pad_width=((0, 256), (0, 256), (0, 0)), mode='constant')
68 | 
69 |             # Run inference
70 |             tensor_image = image.transpose(2, 0, 1)
71 |             tensor_image = torch.from_numpy(tensor_image[None]).to(device, torch.float32)
72 |             outputs = unetr(tensor_image)
73 | 
74 |             # Perform argmax to get per class outputs.
75 |             masks = torch.argmax(outputs, dim=1)
76 |             masks = masks.detach().cpu().numpy().squeeze()
77 | 
78 |             # Unpad the images back to match the original shape.
79 |             masks = masks[:256, :256]
80 | 
81 |             # Calculate the score.
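            # semantic_segmentation_quality returns one score per entry in class_ids (indexed in this
            # order by _get_average_results below); absent classes presumably yield NaN, which the
            # np.nanmean calls skip.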
82 | sq_per_image.append( 83 | semantic_segmentation_quality(gt, masks, class_ids=[1, 2, 3, 4, 5]) 84 | ) 85 | 86 | def _get_average_results(sq_per_image, fname): 87 | msq_neoplastic_cells = np.nanmean([sq[0] for sq in sq_per_image]) 88 | msq_inflammatory = np.nanmean([sq[1] for sq in sq_per_image]) 89 | msq_connective = np.nanmean([sq[2] for sq in sq_per_image]) 90 | msq_dead = np.nanmean([sq[3] for sq in sq_per_image]) 91 | msq_epithelial = np.nanmean([sq[4] for sq in sq_per_image]) 92 | 93 | all_msq = [msq_neoplastic_cells, msq_inflammatory, msq_connective, msq_dead, msq_epithelial] 94 | weighted_mean_msq = [msq * weight for msq, weight in zip(all_msq, per_class_weights)] 95 | 96 | results = { 97 | "neoplastic_cells": msq_neoplastic_cells, 98 | "inflammatory_cells": msq_inflammatory, 99 | "connective_cells": msq_connective, 100 | "dead_cells": msq_dead, 101 | "epithelial_cells": msq_epithelial, 102 | "weighted_mean": np.sum(weighted_mean_msq), 103 | "absolute_mean": np.mean(all_msq) 104 | } 105 | results = pd.DataFrame.from_dict([results]) 106 | results.to_csv(fname) 107 | print(results) 108 | 109 | # Get average results per method. 110 | fname = checkpoint_path.rsplit("/")[-2] # Fetches the name of the style of training for semantic segmentation. 111 | _get_average_results(sq_per_image, f"pathosam_{fname}.csv") 112 | 113 | 114 | def main(args): 115 | evaluate_pannuke_semantic_segmentation(args) 116 | 117 | 118 | if __name__ == "__main__": 119 | import argparse 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument("-i", "--input_path", default="/mnt/vast-nhr/projects/cidas/cca/test/data", type=str) 122 | parser.add_argument("-m", "--model_type", default="vit_b", type=str) 123 | parser.add_argument("-c", "--checkpoint_path", required=True) 124 | parser.add_argument("--view", action="store_true") 125 | args = parser.parse_args() 126 | main(args) 127 | -------------------------------------------------------------------------------- /experiments/semantic_segmentation/generalists/submit_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | import itertools 5 | from datetime import datetime 6 | 7 | 8 | def submit_batch_script(script_name, decoder_only, decoder_from_pretrained, save_root, dry): 9 | batch_script = """#!/bin/bash 10 | #SBATCH -c 16 11 | #SBATCH --mem 64G 12 | #SBATCH -p grete-h100:shared 13 | #SBATCH -t 2-00:00:00 14 | #SBATCH -G H100:1 15 | #SBATCH -A gzz0001 16 | #SBATCH --job-name=patho-sam 17 | 18 | source ~/.bashrc 19 | micromamba activate super 20 | """ 21 | # Prepare the python scripts 22 | python_script = "python train_pannuke.py " 23 | python_script += f"-s {save_root} " 24 | 25 | if decoder_only: 26 | python_script += "--decoder_only " 27 | 28 | if decoder_from_pretrained: 29 | python_script += "--decoder_from_pretrained " 30 | 31 | # Add the python script to the bash script 32 | batch_script += python_script 33 | with open(script_name, "w") as f: 34 | f.write(batch_script) 35 | 36 | if not dry: 37 | cmd = ["sbatch", script_name] 38 | subprocess.run(cmd) 39 | 40 | 41 | def get_batch_script_names(tmp_folder): 42 | tmp_folder = os.path.expanduser(tmp_folder) 43 | os.makedirs(tmp_folder, exist_ok=True) 44 | 45 | script_name = "patho-sam-semantic-segmentation" 46 | 47 | dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") 48 | tmp_name = script_name + dt 49 | batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh") 50 | 51 | return batch_script 52 | 53 | 54 | def 
main(args): 55 | tmp_folder = "./gpu_jobs" 56 | if os.path.exists(tmp_folder): 57 | shutil.rmtree(tmp_folder) 58 | 59 | for decoder_only, decoder_from_pretrained in itertools.product( 60 | [True, False], [True, False] 61 | ): 62 | submit_batch_script( 63 | script_name=get_batch_script_names(tmp_folder), 64 | decoder_only=decoder_only, 65 | decoder_from_pretrained=decoder_from_pretrained, 66 | save_root=args.save_root, 67 | dry=args.dry, 68 | ) 69 | 70 | 71 | if __name__ == "__main__": 72 | import argparse 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument("--dry", action="store_true") 75 | parser.add_argument( 76 | "-s", "--save_root", type=str, default="/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/semantic/models" 77 | ) 78 | args = parser.parse_args() 79 | main(args) 80 | -------------------------------------------------------------------------------- /experiments/semantic_segmentation/results/get_qualitative_plots.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | from skimage.segmentation import find_boundaries 6 | 7 | from tukra.io import read_image 8 | 9 | 10 | ROOT = "/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/semantic/external" 11 | 12 | 13 | def get_qualitative_plots(): 14 | # Get the inputs and corresponding labels 15 | fnames = [ 16 | '1445.tiff', '1145.tiff', '1391.tiff', '2087.tiff', 17 | '2382.tiff', '0335.tiff', '0594.tiff', '2386.tiff', 18 | '2316.tiff', '2381.tiff', "1446.tiff", "1407.tiff" 19 | ] 20 | image_paths = [os.path.join(ROOT, "semantic_split", "test_images", fname) for fname in fnames] 21 | gt_paths = [os.path.join(ROOT, "semantic_split", "test_labels", fname) for fname in fnames] 22 | 23 | for image_path, gt_path in zip(image_paths, gt_paths): 24 | this_fname = os.path.basename(image_path) 25 | 26 | image = read_image(image_path) 27 | gt = read_image(gt_path) 28 | 29 | # Get an overlay for the input image 30 | mask_overlay = np.zeros_like(image) 31 | # Individual colors per class 32 | mask_overlay[find_boundaries(gt == 1)] = [255, 0, 0] 33 | mask_overlay[find_boundaries(gt == 2)] = [0, 255, 0] 34 | mask_overlay[find_boundaries(gt == 3)] = [0, 0, 255] 35 | mask_overlay[find_boundaries(gt == 4)] = [255, 255, 0] 36 | mask_overlay[find_boundaries(gt == 5)] = [0, 255, 255] 37 | 38 | # Map overlay over the image 39 | alpha = 0.5 40 | overlay = alpha * image + (1.0 - alpha) * mask_overlay 41 | overlay = overlay.astype("uint8") 42 | 43 | def _get_expected_labels(fpath): 44 | label = read_image(fpath) 45 | 46 | # Individual colors per class 47 | label_overlay = np.zeros_like(image) 48 | label_overlay[label == 1] = [255, 0, 0] 49 | label_overlay[label == 2] = [0, 255, 0] 50 | label_overlay[label == 3] = [0, 0, 255] 51 | label_overlay[label == 4] = [255, 255, 0] 52 | label_overlay[label == 5] = [0, 255, 255] 53 | 54 | return label_overlay.astype("uint8") 55 | 56 | import matplotlib.pyplot as plt 57 | fig, ax = plt.subplots(1, 6, figsize=(30, 20)) 58 | ax[0].imshow(overlay) 59 | ax[0].axis("off") 60 | 61 | ax[1].imshow(_get_expected_labels(os.path.join(ROOT, "hovernet", "pannuke", this_fname))) 62 | ax[1].axis("off") 63 | 64 | ax[2].imshow(_get_expected_labels(os.path.join(ROOT, "hovernext", "pannuke_convnextv2_tiny_1", this_fname))) 65 | ax[2].axis("off") 66 | 67 | ax[3].imshow(_get_expected_labels(os.path.join(ROOT, "cellvit", "SAM-H-x40", this_fname))) 68 | ax[3].axis("off") 69 | 70 | 
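        # Unlike the panels above, the next two read from relative paths, i.e. prediction files
        # cached in the current working directory by the respective evaluation scripts.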
ax[4].imshow(_get_expected_labels(f"biomedparse_{this_fname}"))  # NOTE: predictions are cached on-the-fly in the working directory at the moment.
71 |         ax[4].axis("off")
72 | 
73 |         ax[5].imshow(_get_expected_labels(f"pathosam_{this_fname}"))  # NOTE: predictions are cached on-the-fly in the working directory at the moment.
74 |         ax[5].axis("off")
75 | 
76 |         plt.subplots_adjust(wspace=0.05, hspace=0)
77 |         plt.savefig(("./" + Path(this_fname).stem + ".svg"), bbox_inches="tight")
78 |         plt.close()
79 | 
80 | 
81 | def main():
82 |     get_qualitative_plots()
83 | 
84 | 
85 | if __name__ == "__main__":
86 |     main()
87 | 
-------------------------------------------------------------------------------- /experiments/training/README.md: --------------------------------------------------------------------------------
1 | # Finetuning Segment Anything Model for Histopathology
2 | 
3 | This folder contains scripts for finetuning a generalist model for nuclei instance segmentation in H&E-stained histopathology images. Here is a brief overview of the folders:
4 | - `generalists`: Contains scripts for training our generalist `patho-sam` models on a large histopathology dataset for nuclei (automatic and interactive) instance segmentation, inspired by `micro-sam`.
5 | - `specialists`: Contains scripts for training our specialist `patho-sam` models on PanNuke (for nuclei), GlaS (for glands), NuClick (for lymphocytes) and PUMA (for neutrophils) for task-/data-specific exploration.
6 | 
7 | > NOTE: The scripts should run out-of-the-box and download most datasets automatically. For some datasets, automatic downloads are not supported. See the respective `get_generalist_datasets.py` or `get_specialist_dataset.py` for details about their downloads!
8 | 
-------------------------------------------------------------------------------- /experiments/training/generalists/get_generalist_datasets.py: --------------------------------------------------------------------------------
1 | import os
2 | 
3 | import torch
4 | import torch.utils.data as data_util
5 | 
6 | import torch_em
7 | from torch_em.data import ConcatDataset, MinInstanceSampler, datasets
8 | from torch_em.transform.label import PerObjectDistanceTransform
9 | 
10 | from patho_sam.training import histopathology_identity
11 | 
12 | 
13 | def _get_train_val_split(ds, val_fraction=0.2):
14 |     generator = torch.Generator().manual_seed(42)
15 |     train_ds, val_ds = data_util.random_split(ds, [1 - val_fraction, val_fraction], generator=generator)
16 |     return train_ds, val_ds
17 | 
18 | 
19 | def get_concat_hp_datasets(path, patch_shape, split_choice):
20 |     # Important stuff for dataloaders.
21 |     label_dtype = torch.float32
22 |     sampler = MinInstanceSampler(min_num_instances=4, min_size=10)
23 | 
24 |     # Expected raw and label transforms.
25 |     raw_transform = histopathology_identity
26 |     label_transform = PerObjectDistanceTransform(
27 |         distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=10,
28 |     )
29 | 
30 |     # Datasets used for training: CPM15, CPM17, Lizard, MoNuSeg, PanNuke, PUMA, TNBC
31 |     cpm15_ds = datasets.get_cpm_dataset(
32 |         path=os.path.join(path, "cpm15"), patch_shape=patch_shape, sampler=sampler,
33 |         label_dtype=label_dtype, raw_transform=raw_transform, data_choice="cpm15",
34 |         split=split_choice, label_transform=label_transform, n_samples=50,  # NOTE: oversampling the data.
35 |     )
36 | 
37 |     lizard_ds = datasets.get_lizard_dataset(
38 |         path=os.path.join(path, "lizard"), patch_shape=patch_shape, download=True, sampler=sampler,
39 |         label_dtype=label_dtype, split=split_choice, label_transform=label_transform, raw_transform=raw_transform,
40 |     )
41 | 
42 |     puma_ds = datasets.get_puma_dataset(
43 |         path=os.path.join(path, "puma"), patch_shape=patch_shape, download=True, sampler=sampler, split=split_choice,
44 |         label_transform=label_transform, raw_transform=raw_transform, label_dtype=label_dtype,
45 |     )
46 | 
47 |     tnbc_ds = datasets.get_tnbc_dataset(
48 |         path=os.path.join(path, "tnbc"), patch_shape=patch_shape, download=True, sampler=sampler,
49 |         split=split_choice, label_transform=label_transform, label_dtype=label_dtype,
50 |         ndim=2, raw_transform=raw_transform, n_samples=50,  # NOTE: oversampling the data.
51 |     )
52 | 
53 |     def _get_cpm17_dataset():
54 |         cpm17_ds = datasets.get_cpm_dataset(
55 |             path=os.path.join(path, "cpm17"), patch_shape=patch_shape, sampler=sampler,
56 |             label_dtype=label_dtype, raw_transform=raw_transform, data_choice="cpm17",
57 |             split="train", label_transform=label_transform, n_samples=50,  # NOTE: oversampling the data.
58 |         )
59 |         cpm17_train_ds, cpm17_val_ds = _get_train_val_split(ds=cpm17_ds)
60 |         if split_choice == "train":
61 |             return cpm17_train_ds
62 |         else:
63 |             return cpm17_val_ds
64 | 
65 |     def _get_monuseg_dataset():
66 |         monuseg_ds = datasets.get_monuseg_dataset(
67 |             path=os.path.join(path, "monuseg"), patch_shape=patch_shape, download=True, split="train",
68 |             sampler=sampler, label_transform=label_transform, label_dtype=label_dtype,
69 |             ndim=2, raw_transform=raw_transform, n_samples=50,  # NOTE: oversampling the data.
70 |         )
71 |         monuseg_train_ds, monuseg_val_ds = _get_train_val_split(ds=monuseg_ds)
72 |         if split_choice == "train":
73 |             return monuseg_train_ds
74 |         else:
75 |             return monuseg_val_ds
76 | 
77 |     def _get_pannuke_dataset():
78 |         pannuke_ds = datasets.get_pannuke_dataset(
79 |             path=os.path.join(path, "pannuke"), patch_shape=(1, *patch_shape), download=True,
80 |             sampler=sampler, ndim=2, folds=["fold_1", "fold_2"], label_dtype=label_dtype,
81 |             label_transform=label_transform, raw_transform=raw_transform,
82 |         )
83 |         pannuke_train_ds, pannuke_val_ds = _get_train_val_split(ds=pannuke_ds)
84 |         if split_choice == "train":
85 |             return pannuke_train_ds
86 |         else:
87 |             return pannuke_val_ds
88 | 
89 |     _datasets = [
90 |         cpm15_ds, lizard_ds, puma_ds, tnbc_ds, _get_cpm17_dataset(), _get_monuseg_dataset(), _get_pannuke_dataset()
91 |     ]
92 | 
93 |     return ConcatDataset(*_datasets)
94 | 
95 | 
96 | def get_generalist_hp_loaders(patch_shape, data_path):
97 |     """This returns the concatenated histopathology datasets implemented in `torch_em`:
98 |     https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets/histopathology
99 |     It will automatically download all the datasets.
100 | 
101 |     NOTE: To remove or replace datasets, you need to add the datasets (for train and val splits)
102 |     in `get_concat_hp_datasets`. The labels have to be in a label mask instance segmentation format,
103 |     i.e. the tensors (inputs & masks) should be of the same spatial shape, with each object in the mask having its own ID.
104 |     IMPORTANT: the ID 0 is reserved for background, and the IDs must be consecutive.
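    For example, a mask with three objects must contain exactly the IDs {0, 1, 2, 3},
    where 0 is the background and 1, 2, 3 are the objects.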
105 | """ 106 | # Get the datasets 107 | generalist_train_dataset = get_concat_hp_datasets(path=data_path, patch_shape=patch_shape, split_choice="train") 108 | generalist_val_dataset = get_concat_hp_datasets(path=data_path, patch_shape=patch_shape, split_choice="val") 109 | 110 | # Get the dataloaders 111 | train_loader = torch_em.get_data_loader(generalist_train_dataset, batch_size=2, shuffle=True, num_workers=16) 112 | val_loader = torch_em.get_data_loader(generalist_val_dataset, batch_size=1, shuffle=True, num_workers=16) 113 | 114 | return train_loader, val_loader 115 | -------------------------------------------------------------------------------- /experiments/training/generalists/run_all_finetuning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | from datetime import datetime 5 | 6 | 7 | N_OBJECTS = {"vit_b": 40, "vit_l": 30, "vit_h": 25} 8 | 9 | 10 | def write_batch_script(out_path, _name, model_type, save_root, dry): 11 | "Writing scripts for different patho-sam finetunings." 12 | batch_script = """#!/bin/bash 13 | #SBATCH -t 14-00:00:00 14 | #SBATCH --mem 128G 15 | #SBATCH --nodes=1 16 | #SBATCH --ntasks=1 17 | #SBATCH -p grete:shared 18 | #SBATCH -G A100:1 19 | #SBATCH -A gzz0001 20 | #SBATCH -c 32 21 | #SBATCH --qos=14d 22 | #SBATCH --constraint=80gb 23 | #SBATCH --job-name=patho-sam-generalist 24 | 25 | source ~/.bashrc 26 | micromamba activate sam \n""" 27 | 28 | # python script 29 | python_script = f"python {_name}.py " 30 | python_script += f"-s {save_root} " # save root folder 31 | python_script += f"-m {model_type} " # name of the model configuration 32 | python_script += f"--n_objects {N_OBJECTS[model_type]} " # choice of the number of objects 33 | 34 | # let's add the python script to the bash script 35 | batch_script += python_script 36 | 37 | _op = out_path[:-3] + f"_{os.path.split(_name)[-1]}.sh" 38 | with open(_op, "w") as f: 39 | f.write(batch_script) 40 | 41 | cmd = ["sbatch", _op] 42 | if not dry: 43 | subprocess.run(cmd) 44 | 45 | 46 | def get_batch_script_names(tmp_folder): 47 | tmp_folder = os.path.expanduser(tmp_folder) 48 | os.makedirs(tmp_folder, exist_ok=True) 49 | 50 | script_name = "patho-sam-finetuning" 51 | 52 | dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") 53 | tmp_name = script_name + dt 54 | batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh") 55 | 56 | return batch_script 57 | 58 | 59 | def submit_slurm(args): 60 | "Submit python script that needs gpus with given inputs on a slurm node." 
61 | tmp_folder = "./gpu_jobs" 62 | if os.path.exists(tmp_folder): 63 | shutil.rmtree(tmp_folder) 64 | 65 | if args.model_type is None: 66 | models = list(N_OBJECTS.keys()) 67 | else: 68 | models = [args.model_type] 69 | 70 | script_name = "train_generalist_histopathology" 71 | print(f"Running for {script_name}") 72 | for model_type in models: 73 | write_batch_script( 74 | out_path=get_batch_script_names(tmp_folder), 75 | _name=script_name, 76 | model_type=model_type, 77 | save_root=args.save_root, 78 | dry=args.dry, 79 | ) 80 | 81 | 82 | def main(args): 83 | submit_slurm(args) 84 | 85 | 86 | if __name__ == "__main__": 87 | import argparse 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument( 90 | "-s", "--save_root", type=str, default="/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/models" 91 | ) 92 | parser.add_argument("-m", "--model_type", type=str, default=None) 93 | parser.add_argument("--dry", action="store_true") 94 | args = parser.parse_args() 95 | main(args) 96 | -------------------------------------------------------------------------------- /experiments/training/generalists/train_generalist_histopathology.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | 6 | import micro_sam.training as sam_training 7 | from micro_sam.util import export_custom_sam_model 8 | 9 | from get_generalist_datasets import get_generalist_hp_loaders 10 | 11 | 12 | def finetune_generalist(args): 13 | """Example code for finetuning SAM on histopathology datasets.""" 14 | # override this (below) if you have some more complex set-up and need to specify the exact GPU. 15 | device = "cuda" if torch.cuda.is_available() else "cpu" 16 | 17 | # training settings: 18 | model_type = args.model_type 19 | checkpoint_path = None # override this to start training from a custom checkpoint. 20 | patch_shape = (512, 512) # the patch shape for training. 21 | n_objects_per_batch = args.n_objects # the number of objects per batch that will be sampled. 22 | checkpoint_name = f"{args.model_type}/patho_sam" 23 | 24 | # all the stuff we need for training 25 | train_loader, val_loader = get_generalist_hp_loaders(patch_shape=patch_shape, data_path=args.input_path) 26 | scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10} 27 | 28 | # Run training. 29 | sam_training.train_sam( 30 | name=checkpoint_name, 31 | model_type=model_type, 32 | train_loader=train_loader, 33 | val_loader=val_loader, 34 | early_stopping=None, 35 | n_objects_per_batch=n_objects_per_batch, 36 | checkpoint_path=checkpoint_path, 37 | with_segmentation_decoder=True, 38 | device=device, 39 | lr=1e-5, 40 | n_iterations=args.iterations, 41 | save_root=args.save_root, 42 | scheduler_kwargs=scheduler_kwargs, 43 | verify_n_labels_in_loader=None, # NOTE: Setting to 'None' verifies all labels in the loader(s). 44 | ) 45 | 46 | if args.export_path is not None: 47 | checkpoint_path = os.path.join( 48 | "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" 49 | ) 50 | export_custom_sam_model( 51 | checkpoint_path=checkpoint_path, 52 | model_type=model_type, 53 | save_path=args.export_path, 54 | ) 55 | 56 | 57 | def main(): 58 | parser = argparse.ArgumentParser(description="Finetune Segment Anything for the Histopathology datasets.") 59 | parser.add_argument( 60 | "--input_path", "-i", default="/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/data", 61 | help="The filepath to the datasets. 
If the data does not exist yet it will be downloaded.", 62 | ) 63 | parser.add_argument( 64 | "--model_type", "-m", default="vit_b", 65 | help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h.", 66 | ) 67 | parser.add_argument( 68 | "--save_root", "-s", default=None, 69 | help="Where to save the checkpoint and logs. By default they will be saved where this script is run.", 70 | ) 71 | parser.add_argument( 72 | "--iterations", type=int, default=int(25e4), 73 | help="For how many iterations should the model be trained? By default 250k.", 74 | ) 75 | parser.add_argument( 76 | "--export_path", "-e", 77 | help="Where to export the finetuned model to. The exported model can be used in the annotation tools.", 78 | ) 79 | parser.add_argument( 80 | "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." 81 | ) 82 | args = parser.parse_args() 83 | finetune_generalist(args) 84 | 85 | 86 | if __name__ == "__main__": 87 | main() 88 | -------------------------------------------------------------------------------- /experiments/training/specialists/finetune_neutrophils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from glob import glob 4 | from natsort import natsorted 5 | 6 | import torch 7 | 8 | import micro_sam.training as sam_training 9 | from micro_sam.util import export_custom_sam_model 10 | 11 | import torch_em 12 | from torch_em.data import MinInstanceSampler 13 | from torch_em.transform.label import PerObjectDistanceTransform 14 | 15 | from patho_sam.training import histopathology_identity 16 | 17 | 18 | def get_dataloaders(patch_shape, data_path): 19 | label_dtype = torch.float32 20 | sampler = MinInstanceSampler(min_num_instances=4, min_size=10) 21 | 22 | # Expected raw and label transforms. 
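    # histopathology_identity passes the RGB intensities through unchanged, while PerObjectDistanceTransform
    # (below) converts instance masks into the foreground and distance targets used by the segmentation decoder.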
23 |     raw_transform = histopathology_identity
24 |     label_transform = PerObjectDistanceTransform(
25 |         distances=True,
26 |         boundary_distances=True,
27 |         directed_distances=False,
28 |         foreground=True,
29 |         instances=True,
30 |         min_size=10,
31 |     )
32 | 
33 |     train_images = natsorted(glob(os.path.join(data_path, "train", "images", "*.tiff")))
34 |     train_labels = natsorted(glob(os.path.join(data_path, "train", "labels", "*.tiff")))
35 | 
36 |     val_images = natsorted(glob(os.path.join(data_path, "val", "images", "*.tiff")))
37 |     val_labels = natsorted(glob(os.path.join(data_path, "val", "labels", "*.tiff")))
38 | 
39 |     train_loader = torch_em.default_segmentation_loader(
40 |         raw_paths=train_images,
41 |         raw_key=None,
42 |         label_paths=train_labels,
43 |         label_key=None,
44 |         patch_shape=patch_shape,
45 |         ndim=2,
46 |         with_channels=True,
47 |         is_seg_dataset=False,
48 |         sampler=sampler,
49 |         label_dtype=label_dtype,
50 |         label_transform=label_transform,
51 |         raw_transform=raw_transform,
52 |         batch_size=1,
53 |         shuffle=True,
54 |         num_workers=16,
55 |     )
56 |     # Use the loader helper (not the dataset helper), so that batch_size, shuffle and num_workers apply.
57 |     val_loader = torch_em.default_segmentation_loader(
58 |         raw_paths=val_images,
59 |         raw_key=None,
60 |         label_paths=val_labels,
61 |         label_key=None,
62 |         patch_shape=patch_shape,
63 |         ndim=2,
64 |         with_channels=True,
65 |         is_seg_dataset=False,
66 |         sampler=sampler,
67 |         label_dtype=label_dtype,
68 |         label_transform=label_transform,
69 |         raw_transform=raw_transform,
70 |         batch_size=1,
71 |         shuffle=True,
72 |         num_workers=16,
73 |     )
74 | 
75 |     return train_loader, val_loader
76 | 
77 | 
78 | def finetune_specialist(args):
79 |     """Example code for finetuning SAM on histopathology datasets."""
80 |     # override this (below) if you have some more complex set-up and need to specify the exact GPU.
81 |     device = "cuda" if torch.cuda.is_available() else "cpu"
82 | 
83 |     # training settings:
84 |     model_type = args.model_type
85 |     checkpoint_path = args.checkpoint  # override this to start training from a custom checkpoint.
86 |     patch_shape = (512, 512)  # the patch shape for training.
87 |     n_objects_per_batch = args.n_objects  # the number of objects per batch that will be sampled.
88 |     checkpoint_name = f"{args.model_type}/neutrophil_sam"
89 | 
90 |     # all the stuff we need for training
91 |     train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)
92 |     scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10}
93 | 
94 |     # Run training.
95 |     sam_training.train_sam(
96 |         name=checkpoint_name,
97 |         model_type=model_type,
98 |         train_loader=train_loader,
99 |         val_loader=val_loader,
100 |         early_stopping=None,
101 |         n_objects_per_batch=n_objects_per_batch,
102 |         checkpoint_path=checkpoint_path,
103 |         with_segmentation_decoder=True,
104 |         device=device,
105 |         lr=1e-5,
106 |         n_iterations=args.iterations,
107 |         save_root=args.save_root,
108 |         scheduler_kwargs=scheduler_kwargs,
109 |         verify_n_labels_in_loader=None,  # NOTE: Setting to 'None' verifies all labels in the loader(s).
110 | )
111 |
112 | if args.export_path is not None:
113 | checkpoint_path = os.path.join(
114 | "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt"
115 | )
116 | export_custom_sam_model(
117 | checkpoint_path=checkpoint_path,
118 | model_type=model_type,
119 | save_path=args.export_path,
120 | )
121 |
122 |
123 | def main():
124 | parser = argparse.ArgumentParser(description="Finetune Segment Anything for the Histopathology datasets.")
125 | parser.add_argument(
126 | "--input_path", "-i", default="/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/data",
127 | help="The filepath to the datasets. If the data does not exist yet it will be downloaded.",
128 | )
129 | parser.add_argument(
130 | "--model_type", "-m", default="vit_b_histopathology",
131 | help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l, vit_h, or a generalist such as the default vit_b_histopathology.",
132 | )
133 | parser.add_argument(
134 | "--save_root", "-s", default=None,
135 | help="Where to save the checkpoint and logs. By default they will be saved where this script is run.",
136 | )
137 | parser.add_argument(
138 | "--dataset", "-d", default=None,
139 | help="Which dataset to finetune the specialist model on. Choose from datasets implemented in torch-em.",
140 | )
141 | parser.add_argument(
142 | "--checkpoint", "-c", default=None,
143 | help="Which custom checkpoint to start the finetuning from. If not provided, SAM will be finetuned.",
144 | )
145 | parser.add_argument(
146 | "--iterations", type=int, default=int(1e4),
147 | help="For how many iterations should the model be trained? By default 10k.",
148 | )
149 | parser.add_argument(
150 | "--export_path", "-e",
151 | help="Where to export the finetuned model to. The exported model can be used in the annotation tools.",
152 | )
153 | parser.add_argument(
154 | "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning."
155 | )
156 | args = parser.parse_args()
157 | finetune_specialist(args)
158 |
159 |
160 | if __name__ == "__main__":
161 | main()
162 |
--------------------------------------------------------------------------------
/experiments/training/specialists/finetune_specialist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 |
6 | import micro_sam.training as sam_training
7 | from micro_sam.util import export_custom_sam_model
8 |
9 | from get_specialist_dataset import get_specialist_loaders
10 |
11 |
12 | def finetune_specialist(args):
13 | """Example code for finetuning SAM on histopathology datasets."""
14 | # override this (below) if you have some more complex set-up and need to specify the exact GPU.
15 | device = "cuda" if torch.cuda.is_available() else "cpu"
16 |
17 | # training settings:
18 | model_type = args.model_type
19 | checkpoint_path = args.checkpoint # override this to start training from a custom checkpoint.
20 | patch_shape = (512, 512) # the patch shape for training.
21 | n_objects_per_batch = args.n_objects # the number of objects per batch that will be sampled.
22 | checkpoint_name = f"{args.model_type}/patho_sam"
23 |
24 | # all the stuff we need for training
25 | train_loader, val_loader = get_specialist_loaders(
26 | patch_shape=patch_shape, data_path=args.input_path, dataset=args.dataset
27 | )
28 | scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10}
29 |
30 | # Run training.
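# NOTE: 'train_sam' runs the interactive finetuning loop; with 'with_segmentation_decoder=True'
# it additionally trains a convolutional decoder for automatic instance segmentation.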
31 | sam_training.train_sam( 32 | name=checkpoint_name, 33 | model_type=model_type, 34 | train_loader=train_loader, 35 | val_loader=val_loader, 36 | early_stopping=None, 37 | n_objects_per_batch=n_objects_per_batch, 38 | checkpoint_path=checkpoint_path, 39 | with_segmentation_decoder=True, 40 | device=device, 41 | lr=1e-5, 42 | n_iterations=args.iterations, 43 | save_root=args.save_root, 44 | scheduler_kwargs=scheduler_kwargs, 45 | verify_n_labels_in_loader=None, # NOTE: Setting to 'None' verifies all labels in the loader(s). 46 | ) 47 | 48 | if args.export_path is not None: 49 | checkpoint_path = os.path.join( 50 | "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" 51 | ) 52 | export_custom_sam_model( 53 | checkpoint_path=checkpoint_path, 54 | model_type=model_type, 55 | save_path=args.export_path, 56 | ) 57 | 58 | 59 | def main(): 60 | parser = argparse.ArgumentParser(description="Finetune Segment Anything for the Histopathology datasets.") 61 | parser.add_argument( 62 | "--input_path", "-i", default="/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/data", 63 | help="The filepath to the datasets. If the data does not exist yet it will be downloaded.", 64 | ) 65 | parser.add_argument( 66 | "--model_type", "-m", default="vit_b", 67 | help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h.", 68 | ) 69 | parser.add_argument( 70 | "--save_root", "-s", default=None, 71 | help="Where to save the checkpoint and logs. By default they will be saved where this script is run.", 72 | ) 73 | parser.add_argument( 74 | "--dataset", "-d", default=None, 75 | help="Which dataset to finetune the specialist model on. Choose from datasets implemented in torch-em.", 76 | ) 77 | parser.add_argument( 78 | "--checkpoint", "-c", default=None, 79 | help="Which custom checkpoint to finetune the specialist model on. If not provided, SAM will be finetuned.", 80 | ) 81 | parser.add_argument( 82 | "--iterations", type=int, default=int(10e4), 83 | help="For how many iterations should the model be trained? By default 100k.", 84 | ) 85 | parser.add_argument( 86 | "--export_path", "-e", 87 | help="Where to export the finetuned model to. The exported model can be used in the annotation tools.", 88 | ) 89 | parser.add_argument( 90 | "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." 91 | ) 92 | args = parser.parse_args() 93 | finetune_specialist(args) 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /experiments/training/specialists/get_specialist_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.utils.data as data_util 5 | 6 | import torch_em 7 | from torch_em.data import MinInstanceSampler, datasets 8 | from torch_em.transform.label import PerObjectDistanceTransform 9 | 10 | from patho_sam.training import histopathology_identity 11 | 12 | 13 | def _get_train_val_split(ds, val_fraction=0.2): 14 | generator = torch.Generator().manual_seed(42) 15 | train_ds, val_ds = data_util.random_split(ds, [1 - val_fraction, val_fraction], generator=generator) 16 | return train_ds, val_ds 17 | 18 | 19 | def get_specialist_dataset(path, patch_shape, split_choice, dataset): 20 | # Important stuff for dataloaders. 
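# NOTE: float32 labels are needed since 'PerObjectDistanceTransform' (below) produces continuous
# distance targets; the sampler guarantees that each sampled patch contains enough (sufficiently large) instances.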
21 | label_dtype = torch.float32
22 | sampler = MinInstanceSampler(min_num_instances=4, min_size=10)
23 |
24 | # Expected raw and label transforms.
25 | raw_transform = histopathology_identity
26 | label_transform = PerObjectDistanceTransform(
27 | distances=True,
28 | boundary_distances=True,
29 | directed_distances=False,
30 | foreground=True,
31 | instances=True,
32 | min_size=10,
33 | )
34 |
35 | if dataset == "nuclick":
36 | ds = datasets.get_nuclick_dataset(
37 | path=os.path.join(path, dataset),
38 | patch_shape=patch_shape,
39 | download=True,
40 | sampler=MinInstanceSampler(min_num_instances=2, min_size=10),
41 | split="Train",
42 | label_dtype=label_dtype,
43 | label_transform=label_transform,
44 | raw_transform=raw_transform,
45 | )
46 | nuclick_train_ds, nuclick_val_ds = _get_train_val_split(ds)
47 | if split_choice == "train":
48 | return nuclick_train_ds
49 | else:
50 | return nuclick_val_ds
51 |
52 | elif dataset == "cryonuseg":
53 | return datasets.get_cryonuseg_dataset(
54 | path=os.path.join(path, "cryonuseg"),
55 | patch_shape=(1, *patch_shape),
56 | download=True,
57 | sampler=sampler,
58 | split=split_choice,
59 | rater="b1",
60 | label_dtype=label_dtype,
61 | label_transform=label_transform,
62 | raw_transform=raw_transform,
63 | )
64 |
65 | elif dataset == "pannuke":
66 | ds = datasets.get_pannuke_dataset(
67 | path=os.path.join(path, "pannuke"),
68 | patch_shape=(1, *patch_shape),
69 | download=True,
70 | sampler=sampler,
71 | ndim=2,
72 | folds=["fold_1", "fold_2"],
73 | label_dtype=label_dtype,
74 | label_transform=label_transform,
75 | raw_transform=raw_transform,
76 | )
77 | pannuke_train_ds, pannuke_val_ds = _get_train_val_split(ds=ds)
78 | if split_choice == "train":
79 | return pannuke_train_ds
80 | else:
81 | return pannuke_val_ds
82 |
83 | elif dataset == "glas":
84 | ds = datasets.get_glas_dataset(
85 | path=os.path.join(path, "glas"),
86 | patch_shape=patch_shape,
87 | download=True,
88 | sampler=MinInstanceSampler(min_num_instances=2),
89 | split="train",
90 | label_dtype=label_dtype,
91 | label_transform=label_transform,
92 | raw_transform=raw_transform,
93 | )
94 | glas_train_ds, glas_val_ds = _get_train_val_split(ds=ds)
95 | if split_choice == "train":
96 | return glas_train_ds
97 | else:
98 | return glas_val_ds
99 |
100 | else:
101 | raise NotImplementedError(f"The chosen dataset '{dataset}' is not supported.")
102 |
103 |
104 | def get_specialist_loaders(patch_shape, data_path, dataset):
105 | """This returns a selected histopathology dataset implemented in `torch_em`:
106 | https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets/histopathology
107 | It will automatically download the dataset.
108 |
109 | NOTE: To remove / replace the dataset with another dataset, you need to add the datasets (for train and val splits)
110 | in `get_specialist_dataset`. The labels have to be in a label mask instance segmentation format,
111 | i.e. the tensors (inputs & masks) should be of the same spatial shape, with each object in the mask having its own ID.
112 | IMPORTANT: the ID 0 is reserved for background, and the IDs must be consecutive.
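For example, a valid label mask with three objects would contain exactly the values {0, 1, 2, 3}.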
113 | """ 114 | # Get the datasets 115 | specialist_train_dataset = get_specialist_dataset( 116 | path=data_path, patch_shape=patch_shape, split_choice="train", dataset=dataset 117 | ) 118 | specialist_val_dataset = get_specialist_dataset( 119 | path=data_path, patch_shape=patch_shape, split_choice="val", dataset=dataset 120 | ) 121 | # Get the dataloaders 122 | train_loader = torch_em.get_data_loader(specialist_train_dataset, batch_size=2, shuffle=True, num_workers=16) 123 | val_loader = torch_em.get_data_loader(specialist_val_dataset, batch_size=1, shuffle=True, num_workers=16) 124 | 125 | return train_loader, val_loader 126 | -------------------------------------------------------------------------------- /experiments/training/specialists/run_all_finetuning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | from datetime import datetime 5 | 6 | 7 | N_OBJECTS = {"vit_b": 40, "vit_l": 30, "vit_h": 25} 8 | 9 | 10 | def write_batch_script(out_path, _name, model_type, save_root, dry, dataset, checkpoint): 11 | "Writing scripts for different patho-sam specialist finetunings." 12 | batch_script = """#!/bin/bash 13 | #SBATCH -t 2-00:00:00 14 | #SBATCH --mem 64G 15 | #SBATCH --nodes=1 16 | #SBATCH --ntasks=1 17 | #SBATCH -p grete:shared 18 | #SBATCH -G A100:1 19 | #SBATCH -A nim00007 20 | #SBATCH -c 16 21 | #SBATCH --constraint=80gb 22 | #SBATCH --job-name=patho-sam-specialist 23 | 24 | source ~/.bashrc 25 | micromamba activate sam \n""" 26 | 27 | # python script 28 | python_script = f"python {_name}.py " 29 | python_script += f"-s {save_root} " # save root folder 30 | python_script += f"-d {dataset} " # dataset to finetune on 31 | python_script += f"-c {checkpoint} " # checkpoint to train on 32 | python_script += f"-m {model_type} " # name of the model configuration 33 | python_script += f"--n_objects {N_OBJECTS[model_type]} " # choice of the number of objects 34 | 35 | # let's add the python script to the bash script 36 | batch_script += python_script 37 | 38 | _op = out_path[:-3] + f"_{os.path.split(_name)[-1]}.sh" 39 | with open(_op, "w") as f: 40 | f.write(batch_script) 41 | 42 | cmd = ["sbatch", _op] 43 | if not dry: 44 | subprocess.run(cmd) 45 | 46 | 47 | def get_batch_script_names(tmp_folder): 48 | tmp_folder = os.path.expanduser(tmp_folder) 49 | os.makedirs(tmp_folder, exist_ok=True) 50 | 51 | script_name = "patho-sam-finetuning" 52 | 53 | dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") 54 | tmp_name = script_name + dt 55 | batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh") 56 | 57 | return batch_script 58 | 59 | 60 | def submit_slurm(args): 61 | "Submit python script that needs gpus with given inputs on a slurm node." 
62 | tmp_folder = "./gpu_jobs"
63 | if os.path.exists(tmp_folder):
64 | shutil.rmtree(tmp_folder)
65 |
66 | if args.model_type is None:
67 | models = list(N_OBJECTS.keys())
68 | else:
69 | models = [args.model_type]
70 |
71 | script_name = f"finetune_{args.dataset}_specialist"
72 | print(f"Running for {script_name}")
73 | for model_type in models:
74 | write_batch_script(
75 | out_path=get_batch_script_names(tmp_folder),
76 | _name=script_name,
77 | model_type=model_type,
78 | save_root=args.save_root,
79 | dry=args.dry,
80 | dataset=args.dataset,
81 | checkpoint=args.checkpoint
82 | )
83 |
84 |
85 | def main(args):
86 | submit_slurm(args)
87 |
88 |
89 | if __name__ == "__main__":
90 | import argparse
91 |
92 | parser = argparse.ArgumentParser()
93 | parser.add_argument("-s", "--save_root", type=str, default=None)
94 | parser.add_argument("-d", "--dataset", type=str, default=None, required=True)
95 | parser.add_argument("-c", "--checkpoint", type=str, default=None)
96 | parser.add_argument("-m", "--model_type", type=str, default=None)
97 | parser.add_argument("--dry", action="store_true")
98 | args = parser.parse_args()
99 | main(args)
100 |
--------------------------------------------------------------------------------
/patho_sam/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-cell-analytics/patho-sam/482e2e336b9e5b75028c013452a70aba8107207a/patho_sam/__init__.py
--------------------------------------------------------------------------------
/patho_sam/__version__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1"
2 |
--------------------------------------------------------------------------------
/patho_sam/evaluation/class_weights.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from typing import List, Union, Optional
4 |
5 | import numpy as np
6 | import pandas as pd
7 |
8 | from torch_em.util.image import load_data
9 | from torch_em.data.datasets import histopathology
10 |
11 | from ..training.util import CLASS_DICT
12 |
13 |
14 | LABEL_KEYS = {
15 | 'pannuke': {'semantic': 'labels/semantic', 'instance': 'labels/instances'},
16 | 'puma': {'semantic': 'labels/semantic/nuclei', 'instance': 'labels/instance/nuclei'},
17 | }
18 |
19 |
20 | def extract_class_weights(
21 | path: Union[os.PathLike, str], dataset: Optional[str] = None, output_path: Optional[Union[os.PathLike, str]] = None,
22 | ) -> List:
23 | """Extract class weights per semantic class.
24 |
25 | Args:
26 | path: The filepath where the input data is either located or will be downloaded automatically.
27 | dataset: The choice of dataset to extract the class weights for.
28 | output_path: The output directory where you would like to store the class weights.
29 |
30 | Returns:
31 | List of class weights for the chosen dataset.
32 | """
33 |
34 | # Get the input filepaths.
35 | _get_dataset = {
36 | "puma": lambda: histopathology.puma.get_puma_paths(path=path, split="train", download=True),
37 | "pannuke": lambda: histopathology.pannuke.get_pannuke_paths(
38 | path=path, folds=["fold_1", "fold_2"], download=True,
39 | )
40 | }
41 |
42 | # Next, get the volume paths.
43 | volume_paths = _get_dataset[dataset]()
44 |
45 | # Load the entire instance and semantic stack.
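# NOTE: Depending on the dataset, 'volume_paths' may be a single stacked file or a list of files;
# in the latter case the individual stacks are concatenated below.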
46 | if isinstance(volume_paths, str):
47 | instances = load_data(volume_paths, key=LABEL_KEYS[dataset]['instance'])
48 | semantics = load_data(volume_paths, key=LABEL_KEYS[dataset]['semantic'])
49 | else:
50 | all_semantic, all_instances = [], []
51 | for fpath in volume_paths:
52 | _instance = load_data(fpath, key=LABEL_KEYS[dataset]['instance'])
53 | _semantic = load_data(fpath, key=LABEL_KEYS[dataset]['semantic'])
54 |
55 | # Get all semantic and instance labels.
56 | all_instances.append(_instance)
57 | all_semantic.append(_semantic)
58 |
59 | instances = np.concatenate(all_instances, axis=0)
60 | semantics = np.concatenate(all_semantic, axis=0)
61 |
62 | # We need the following:
63 | # - Count the total number of instances.
64 | total_instance_counts = [
65 | len(np.unique(ilabel)[1:]) for ilabel in instances if len(np.unique(ilabel)) > 1
66 | ] # Counting all valid foreground instances only.
67 | total_instance_counts = sum(total_instance_counts)
68 |
69 | # - Count per-semantic-class instances.
70 | class_ids = CLASS_DICT[dataset].values()
71 | total_per_class_instance_counts = [
72 | [len(np.unique(np.where(slabel == cid, ilabel, 0))[1:]) for cid in class_ids]
73 | for ilabel, slabel in zip(instances, semantics) if len(np.unique(ilabel)) > 1
74 | ]
75 | assert total_instance_counts == sum([sum(t) for t in total_per_class_instance_counts])
76 |
77 | # Calculate per class mean values.
78 | total_per_class_instance_counts = [sum(x) for x in zip(*total_per_class_instance_counts)]
79 | assert total_instance_counts == sum(total_per_class_instance_counts)
80 |
81 | # Finally, let's get the weight per class. The results are saved as .csv in the output folder and returned as a list.
82 | per_class_weights = [t / total_instance_counts for t in total_per_class_instance_counts]
83 |
84 | # Store the class weights locally.
85 | os.makedirs(output_path, exist_ok=True)
86 | result_dict = {nt: weight for nt, weight in zip(CLASS_DICT[dataset].keys(), per_class_weights)}
87 | result_df = pd.DataFrame.from_dict(result_dict, orient='index', columns=["Class Weight"])
88 | result_df.to_csv(os.path.join(output_path, f"{dataset}_class_weights.csv"), index=True)
89 |
90 | return per_class_weights
91 |
--------------------------------------------------------------------------------
/patho_sam/evaluation/evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Union
3 |
4 | import h5py
5 | import numpy as np
6 |
7 | from elf.evaluation import dice_score
8 |
9 | CLASS_IDS = [1, 2, 3, 4, 5]
10 |
11 |
12 | def extract_class_weights_for_pannuke(fpath: Union[os.PathLike, str], class_ids: List = CLASS_IDS) -> List[float]:
13 | """Extract class weights per semantic class.
14 |
15 | Args:
16 | fpath: The filepath where the input stack for fold 3 is stored for the PanNuke dataset.
17 | Use `torch_em.data.datasets.histopathology.get_pannuke_paths` to get the filepath for the stack.
18 | class_ids: The choice of all available class ids.
19 |
20 | Returns:
21 | List of class weights.
22 | """
23 | # Load the entire instance and semantic stack.
24 | with h5py.File(fpath, "r") as f:
25 | instances = f['labels/instances'][:]
26 | semantic = f['labels/semantic'][:]
27 |
28 | # We need the following:
29 | # - Count the total number of instances.
30 | total_instance_counts = [
31 | len(np.unique(ilabel)[1:]) for ilabel in instances if len(np.unique(ilabel)) > 1
32 | ] # Counting all valid foreground instances only.
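# NOTE: 'np.unique' returns the sorted label IDs including the background ID 0,
# so slicing with '[1:]' keeps only the foreground instance IDs.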
33 | total_instance_counts = sum(total_instance_counts)
34 |
35 | # - Count per-semantic-class instances.
36 | total_per_class_instance_counts = [
37 | [len(np.unique(np.where(slabel == cid, ilabel, 0))[1:]) for cid in class_ids]
38 | for ilabel, slabel in zip(instances, semantic) if len(np.unique(ilabel)) > 1
39 | ]
40 | assert total_instance_counts == sum([sum(t) for t in total_per_class_instance_counts])
41 |
42 | # Calculate per class mean values.
43 | total_per_class_instance_counts = [sum(x) for x in zip(*total_per_class_instance_counts)]
44 | assert total_instance_counts == sum(total_per_class_instance_counts)
45 |
46 | # Finally, let's get the weight per class.
47 | per_class_weights = [t / total_instance_counts for t in total_per_class_instance_counts]
48 |
49 | return per_class_weights
50 |
51 |
52 | def semantic_segmentation_quality(
53 | ground_truth: np.ndarray, segmentation: np.ndarray, class_ids: List[int]
54 | ) -> List[float]:
55 | """Evaluation metric for the semantic segmentation task.
56 |
57 | Args:
58 | ground_truth: The ground truth with expected semantic labels.
59 | segmentation: The predicted masks with expected semantic labels.
60 | class_ids: The class ids for which to calculate the per-class semantic quality score.
61 |
62 | Returns:
63 | List of semantic quality scores per class.
64 | """
65 | # First, we iterate over all classes.
66 | sq_per_class = []
67 | for class_id in class_ids:
68 | # Get the per semantic class values.
69 | this_gt = (ground_truth == class_id).astype("uint32")
70 | this_seg = (segmentation == class_id).astype("uint32")
71 |
72 | # Check if both the ground truth and the prediction are empty for this semantic class. We skip it then.
73 | if len(np.unique(this_gt)) == 1 and len(np.unique(this_seg)) == 1:
74 | this_sq = np.nan
75 | else:
76 | this_sq = dice_score(this_seg, this_gt)
77 |
78 | sq_per_class.append(this_sq)
79 |
80 | return sq_per_class
81 |
--------------------------------------------------------------------------------
/patho_sam/io/__init__.py:
--------------------------------------------------------------------------------
1 | from .util import read_wsi
2 |
--------------------------------------------------------------------------------
/patho_sam/io/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Union, Tuple, Optional
3 |
4 | import numpy as np
5 |
6 | try:
7 | import slideio
8 | except ImportError:
9 | slideio = None
10 |
11 |
12 | def read_wsi(
13 | input_path: Union[os.PathLike, str],
14 | scale: Optional[Tuple[int, int]] = None,
15 | image_size: Optional[Tuple[int, int, int, int]] = None,
16 | ) -> np.ndarray:
17 | """Function to read whole-slide images (WSIs) in histopathology.
18 |
19 | The file formats tested are '.svs', '.scn', '.czi', '.zvi', '.ndpi',
20 | '.vsi', '.qptiff' and other gdal formats.
21 |
22 | Args:
23 | input_path: The path to the WSI.
24 | scale: Relevant for WSIs, to get the image for a desired scale. Provide the desired (H, W) combination to scale.
25 | You can choose (H, 0) or (0, W) to scale along one dimension and keep the aspect ratio intact.
26 | image_size: Relevant for WSIs, to get an ROI crop for a desired shape.
27 |
28 | Returns:
29 | The image as a numpy array.
30 | """
31 | if not os.path.exists(input_path):
32 | raise FileNotFoundError(input_path)
33 |
34 | assert slideio is not None, "Please install 'slideio': 'pip install slideio'."
35 | slide = slideio.open_slide(input_path) # Fetches the slide object.
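# Illustrative (hypothetical) usage: 'read_wsi("slide.svs", scale=(4096, 0))' would load the slide
# downscaled to a height of 4096 px, with the second dimension inferred (cf. the docstring above);
# see 'get_example_wsi_data' in 'patho_sam/util.py' for a real multi-scale example.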
36 |
37 | # Let's check the expected scale.
38 | if scale is None:
39 | scale = (0, 0) # Loads the original resolution.
40 | else:
41 | if not isinstance(scale, tuple) or len(scale) != 2:
42 | raise ValueError(
43 | "The scale parameter is expected to be a tuple of height and width dimensions, "
44 | "such that the new shape is (H', W')."
45 | )
46 |
47 | # Let's check for the expected size of the desired ROI.
48 | # NOTE: Here, we expect all values for placing an ROI precisely: (x, y, W, H).
49 | if image_size is None:
50 | image_size = (0, 0, 0, 0)
51 | else:
52 | if not isinstance(image_size, tuple):
53 | raise ValueError(
54 | "The image size parameter is expected to be a tuple describing the desired target ROI crop, "
55 | "such that the new crop shape matches this ROI."
56 | )
57 |
58 | # If the user provides shapes in the usual 2d axes format, e.g. (1024, 1024),
59 | # we provide them a top-left corner crop.
60 | if len(image_size) == 2:
61 | image_size = (0, 0, *image_size)
62 |
63 | assert len(scale) == 2
64 | assert len(image_size) == 4
65 |
66 | # NOTE: Each slide object can contain one or multiple scenes,
67 | # i.e. continuous raster regions (with the 2d image, other metadata, etc.).
68 | scene = slide.get_scene(0)
69 | input_array = scene.read_block(size=scale, rect=image_size)
70 |
71 | return input_array
72 |
--------------------------------------------------------------------------------
/patho_sam/training/__init__.py:
--------------------------------------------------------------------------------
1 | from .util import histopathology_identity, get_train_val_split
2 | from .semantic_trainer import SemanticInstanceTrainer
3 |
--------------------------------------------------------------------------------
/patho_sam/training/semantic_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from micro_sam.training import SemanticSamTrainer
7 | from micro_sam.training.semantic_sam_trainer import CustomDiceLoss
8 |
9 |
10 | class SemanticInstanceTrainer(SemanticSamTrainer):
11 | """Modified trainer class for training the Segment Anything Model for semantic (instance) segmentation.
12 | """
13 | def __init__(
14 | self,
15 | convert_inputs,
16 | num_classes: int,
17 | dice_weight: Optional[float] = None,
18 | class_weights: Optional[List[float]] = None,
19 | **kwargs
20 | ):
21 | assert num_classes > 1
22 |
23 | if "loss" not in kwargs:
24 | kwargs["loss"] = CustomDiceLoss(num_classes=num_classes)
25 |
26 | if "metric" not in kwargs:
27 | kwargs["metric"] = CustomDiceLoss(num_classes=num_classes)
28 |
29 | super().__init__(convert_inputs=convert_inputs, num_classes=num_classes, dice_weight=dice_weight, **kwargs)
30 |
31 | self.class_weights = class_weights
32 | if self.class_weights is None:
33 | self.compute_ce_loss = nn.CrossEntropyLoss()
34 | else:
35 | weight_vals = torch.tensor(self.class_weights, dtype=torch.float32, device=self.device)
36 | self.compute_ce_loss = nn.CrossEntropyLoss(weight=weight_vals)
37 |
38 | if self.dice_weight is not None and (self.dice_weight < 0 or self.dice_weight > 1):
39 | raise ValueError("The weight factor should lie between 0 and 1.")
40 |
41 | self._kwargs = kwargs
42 |
43 | def _get_model_outputs(self, batched_inputs):
44 | """Get the predictions from the model.
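Expects 'batched_inputs' as a list of per-sample dicts, each holding a preprocessed 'image' tensor
(cf. the stacking below).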
45 | """
46 | inputs = torch.stack([bi["image"] for bi in batched_inputs], dim=0).to(self.device)
47 | outputs = self.model(inputs)
48 | return outputs
49 |
--------------------------------------------------------------------------------
/patho_sam/training/util.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, List
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.utils.data as data_util
7 |
8 | from torch_em.data.datasets.light_microscopy.neurips_cell_seg import to_rgb
9 |
10 |
11 | CLASS_MAP = {
12 | 'puma': {
13 | 2: 1,
14 | 3: 2, 4: 2, 5: 2, 6: 2, 7: 2,
15 | 1: 3, 8: 3,
16 | 10: 4,
17 | 9: 5,
18 | },
19 | }
20 |
21 | CLASS_DICT = {
22 | 'puma': {
23 | "nuclei_stroma": 1,
24 | "nuclei_tumor": 2,
25 | "nuclei_plasma_cell": 3,
26 | "nuclei_histiocyte": 4,
27 | "nuclei_lymphocyte": 5,
28 | "nuclei_melanophage": 6,
29 | "nuclei_neutrophil": 7,
30 | "nuclei_endothelium": 8,
31 | "nuclei_epithelium": 9,
32 | "nuclei_apoptosis": 10
33 | },
34 | 'pannuke': {
35 | "neoplastic": 1,
36 | "inflammatory": 2,
37 | "connective / soft tissue": 3,
38 | "dead cells": 4,
39 | "epithelial": 5,
40 | },
41 | }
42 |
43 |
44 | def histopathology_identity(x, ensure_rgb=True):
45 | """Identity transform.
46 | Inspired by the 'identity' function in 'micro_sam/training/util.py'.
47 |
48 | This ensures that data normalization is skipped when finetuning SAM.
49 | Data normalization is performed within the model to SA-1B data statistics
50 | and should thus be skipped as a preprocessing step in training.
51 | """
52 | if ensure_rgb:
53 | x = to_rgb(x)
54 |
55 | return x
56 |
57 |
58 | def get_train_val_split(
59 | ds: torch.utils.data.Dataset, val_fraction: float = 0.2, seed: int = 42,
60 | ) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
61 | """Creates a train / val split of a dataset for a given fraction.
62 |
63 | Args:
64 | ds: The segmentation dataset.
65 | val_fraction: The fraction of the dataset used for validation; the remainder is used for training.
66 | seed: The seed for the random split generator, for reproducibility.
67 |
68 | Returns:
69 | Tuple of train and val datasets.
70 | """
71 | generator = torch.Generator().manual_seed(seed)
72 | train_ds, val_ds = data_util.random_split(ds, [1 - val_fraction, val_fraction], generator=generator)
73 | return train_ds, val_ds
74 |
75 |
76 | def remap_labels(y: np.ndarray, name: str) -> np.ndarray:
77 | """Maps the labels to overall meta classes, to match the
78 | semantic class structure of the PanNuke dataset.
79 |
80 | Args:
81 | y: The original semantic label.
82 | name: The name of the target dataset, to remap the original class ids to PanNuke class ids.
83 |
84 | Returns:
85 | The remapped labels.
86 | """
87 | if name not in CLASS_MAP:
88 | raise ValueError(f"The chosen dataset '{name}' is not supported.")
89 |
90 | # Get the class id map.
91 | mapping = CLASS_MAP[name]
92 |
93 | # Remap the labels.
94 | # NOTE: We go with this remapping to make sure that each id is mapped to the exact value.
95 | per_id_lookup_array = np.array([mapping.get(i, 0) for i in range(max(mapping) + 1)], dtype=np.int32)
96 | y_remapped = per_id_lookup_array[y]
97 | return y_remapped
98 |
99 |
100 | def calculate_class_weights_for_loss_weighting(
101 | foreground_class_weights: List[float] = [0.4702, 0.1797, 0.2229, 0.0159, 0.1113],
102 | ) -> List[float]:
103 | """Calculates the class weights for weighting the cross entropy loss.
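As a worked example of the mapping implemented below: the default foreground frequencies
[0.4702, 0.1797, 0.2229, 0.0159, 0.1113] map (approximately) to the integer weights
[1, 1, 7, 6, 10, 8], with the background weight of 1 prepended.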
104 |
105 | NOTE 1: The default weights originate from weighting both the PanNuke and PUMA labels.
106 | TODO: Scripts coming soon!
107 |
108 | NOTE 2: We weigh the classes using relative integers on a scale of 1 to 10,
109 | where 1 corresponds to the most frequent class and 10 to the least frequent class.
110 |
111 | NOTE 3: Make sure that the order of the weights matches the class id order.
112 |
113 | Args:
114 | foreground_class_weights: The relative frequencies of the foreground classes.
115 |
116 | Returns:
117 | The integer weighting for each class, including the background class.
118 | """
119 | foreground_class_weights = np.array(foreground_class_weights)
120 |
121 | # Define the range for integer weighting.
122 | background_weight, max_weight = 1, 10
123 |
124 | # Normalize the class weights.
125 | min_val, max_val = np.min(foreground_class_weights), np.max(foreground_class_weights)
126 |
127 | # Invert the mapping (i.e. higher for rarer classes, lower for common classes).
128 | mapped_weights = max_weight - ((foreground_class_weights - min_val) / (max_val - min_val)) * (max_weight - 1)
129 |
130 | # Make sure that the most common class has weight 1.
131 | mapped_weights[np.argmax(foreground_class_weights)] = background_weight
132 |
133 | # Round the weights and convert them to integer values.
134 | final_weights = np.round(mapped_weights).astype(int)
135 |
136 | # Add the background weight at the beginning.
137 | final_weights_with_bg = [background_weight, *final_weights]
138 |
139 | return final_weights_with_bg
140 |
--------------------------------------------------------------------------------
/patho_sam/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import collections
3 | from typing import Union, Optional, OrderedDict
4 |
5 | import pooch
6 |
7 | import torch
8 |
9 | from micro_sam.util import microsam_cachedir, get_cache_directory
10 | from micro_sam.sample_data import fetch_wholeslide_histopathology_example_data
11 |
12 | from .io import read_wsi
13 |
14 |
15 | DECODER_URL = "https://owncloud.gwdg.de/index.php/s/TFbI25UZoixd1hi/download"
16 |
17 |
18 | def export_semantic_segmentation_decoder(
19 | checkpoint_path: Union[str, os.PathLike], save_path: Union[str, os.PathLike],
20 | ):
21 | """Exports the weights of the trained convolutional decoder for the semantic segmentation task.
22 |
23 | Args:
24 | checkpoint_path: Filepath to the trained semantic segmentation checkpoint.
25 | save_path: Filepath where the decoder weights will be stored.
26 | """
27 | # Load the model state from the finetuned checkpoint.
28 | model_state = torch.load(checkpoint_path, map_location="cpu")["model_state"]
29 |
30 | # Get the decoder state only.
31 | decoder_state = collections.OrderedDict(
32 | [(k, v) for k, v in model_state.items() if not k.startswith("encoder")]
33 | )
34 |
35 | # Store the decoder state to the desired path.
36 | torch.save(decoder_state, save_path)
37 |
38 |
39 | def get_semantic_segmentation_decoder_weights(save_path: Optional[Union[str, os.PathLike]] = None) -> OrderedDict:
40 | """Get the semantic segmentation decoder weights for initializing only the decoder.
41 |
42 | Args:
43 | save_path: Optional path where the model checkpoint will be stored.
44 |
45 | Returns:
46 | The pretrained decoder weights.
47 | """
48 | # By default, we store the decoder weights in the `micro-sam` cache directory.
49 | save_directory = os.path.join(microsam_cachedir(), "models") if save_path is None else save_path
50 |
51 | # Download the model weights.
52 | fname = "vit_b_histopathology_semantic_segmentation_decoder"
53 | pooch.retrieve(
54 | url=DECODER_URL,
55 | known_hash="bdd05a55c72c02abce72a7aa6885c6ec21df9c43fda9cf3c5d11ef5788de0ab0",
56 | fname=fname,
57 | path=save_directory,
58 | progressbar=True,
59 | )
60 |
61 | # Get the checkpoint path.
62 | checkpoint_path = os.path.join(save_directory, fname)
63 |
64 | # Load the decoder state.
65 | state = torch.load(checkpoint_path, map_location="cpu")
66 |
67 | return state
68 |
69 |
70 | def get_example_wsi_data():
71 | """@private"""
72 | import argparse
73 | parser = argparse.ArgumentParser(description="Download and visualize the example whole-slide image (WSI).")
74 | parser.add_argument(
75 | "-s", "--save_path", type=str, default=None,
76 | help=f"The folder to store the whole-slide image. By default, it is stored at '{get_cache_directory()}'."
77 | )
78 | parser.add_argument(
79 | "--roi", nargs="+", type=int, default=None,
80 | help="The ROI shape of the whole-slide image for automatic segmentation. By default, predicts on the entire WSI. "
81 | "You can provide the ROI shape as: '--roi X Y W H'.",
82 | )
83 | parser.add_argument(
84 | "--view", action="store_true", help="Whether to view the WSI in napari."
85 | )
86 |
87 | args = parser.parse_args()
88 |
89 | # Get the folder to store the WSI. By default, stores it at the 'micro-sam' cache directory.
90 | save_dir = os.path.join(get_cache_directory(), "sample_data") if args.save_path is None else args.save_path
91 |
92 | # Download the example WSI.
93 | example_data = fetch_wholeslide_histopathology_example_data(save_dir)
94 |
95 | # Check the ROI and convert it to a tuple, if provided by the user.
96 | roi = None if args.roi is None else tuple(args.roi)
97 |
98 | if args.view:
99 | # Load the WSI image.
100 | image = read_wsi(example_data, image_size=roi)
101 |
102 | # Get multi-scales for the input image.
103 | multiscale_images = [
104 | image,
105 | read_wsi(example_data, image_size=roi, scale=(int(image.shape[0] / 2), 0)),
106 | read_wsi(example_data, image_size=roi, scale=(int(image.shape[0] / 4), 0)),
107 | read_wsi(example_data, image_size=roi, scale=(int(image.shape[0] / 8), 0)),
108 | ]
109 |
110 | import napari
111 | v = napari.Viewer()
112 | v.add_image(multiscale_images, name="Input Image")
113 | napari.run()
114 |
115 | print(f"The example WSI is stored at '{example_data}'.")
116 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
--------------------------------------------------------------------------------
/scripts/plotting/get_dataset_figure.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib.pyplot as plt
4 |
5 | from torch_em.data import datasets, MinTwoInstanceSampler
6 |
7 | from micro_sam.training import identity
8 | from micro_sam.evaluation.model_comparison import _overlay_outline
9 |
10 |
11 | ROOT = "/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/data"
12 |
13 |
14 | def for_fig_1a():
15 | sampler = MinTwoInstanceSampler()
16 |
17 | # Get a cumulative image for multiple datasets.
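# NOTE: Each entry below is a lambda, so a dataset is only instantiated (and downloaded)
# once its loader is requested in the plotting loop.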
18 | get_loaders = { 19 | "consep": lambda: datasets.get_consep_loader( 20 | path=os.path.join(ROOT, "consep"), batch_size=1, shuffle=True, raw_transform=identity, 21 | patch_shape=(512, 512), split="test", download=True, sampler=sampler, 22 | ), 23 | # "cpm15": lambda: datasets.get_cpm_loader( 24 | # path=os.path.join(ROOT, "cpm15"), batch_size=1, shuffle=True, patch_shape=(512, 512), 25 | # data_choice="cpm15", resize_inputs=True, download=True, sampler=sampler, raw_transform=identity, 26 | # ), 27 | "cpm17": lambda: datasets.get_cpm_loader( 28 | path=os.path.join(ROOT, "cpm17"), batch_size=1, shuffle=True, patch_shape=(512, 512), split="test", 29 | data_choice="cpm17", resize_inputs=True, download=True, sampler=sampler, raw_transform=identity, 30 | ), 31 | "cryonuseg": lambda: datasets.get_cryonuseg_loader( 32 | path=os.path.join(ROOT, "cryonuseg"), batch_size=1, shuffle=True, patch_shape=(512, 512), 33 | split="test", resize_inputs=True, download=True, sampler=sampler, raw_transform=identity, 34 | ), 35 | "lizard": lambda: datasets.get_lizard_loader( 36 | path=os.path.join(ROOT, "lizard"), batch_size=1, patch_shape=(512, 512), split="test", 37 | resize_inputs=True, download=True, shuffle=True, sampler=sampler, raw_transform=identity, 38 | ), 39 | "lynsec": lambda: datasets.get_lynsec_loader( 40 | path=os.path.join(ROOT, "lynsec"), batch_size=1, patch_shape=(512, 512), shuffle=True, 41 | choice="h&e", resize_inputs=True, download=True, sampler=sampler, raw_transform=identity, 42 | ), 43 | "monuseg": lambda: datasets.get_monuseg_loader( 44 | path=os.path.join(ROOT, "monuseg"), batch_size=1, shuffle=True, patch_shape=(512, 512), 45 | resize_inputs=True, download=True, sampler=sampler, raw_transform=identity, split="test", 46 | ), 47 | "nuinsseg": lambda: datasets.get_nuinsseg_loader( 48 | path=os.path.join(ROOT, "nuinsseg"), batch_size=1, shuffle=True, patch_shape=(512, 512), 49 | download=True, sampler=sampler, raw_transform=identity, resize_inputs=True, 50 | ), 51 | # "pannuke": lambda: datasets.get_pannuke_loader( 52 | # path=os.path.join(ROOT, "pannuke"), batch_size=1, patch_shape=(512, 512), folds=["fold_3"], 53 | # download=True, shuffle=True, sampler=sampler, raw_transform=identity, 54 | # ), 55 | "puma": lambda: datasets.get_puma_loader( 56 | path=os.path.join(ROOT, "puma"), batch_size=1, patch_shape=(512, 512), split="test", 57 | download=True, sampler=sampler, raw_transform=identity, resize_inputs=True, 58 | ), 59 | "tnbc": lambda: datasets.get_tnbc_loader( 60 | path=os.path.join(ROOT, "tnbc"), batch_size=1, patch_shape=(512, 512), ndim=2, shuffle=True, 61 | split="train", resize_inputs=True, download=True, sampler=sampler, raw_transform=identity, 62 | ) 63 | } 64 | 65 | fig, ax = plt.subplots(3, 3, figsize=(30, 30)) 66 | ax = ax.flatten() 67 | 68 | for i, dname in enumerate(get_loaders.keys()): 69 | loader = get_loaders[dname]() 70 | counter = 0 71 | for x, y in loader: 72 | if counter > 0: 73 | break 74 | counter += 1 75 | 76 | x, y = x.squeeze().numpy(), y.squeeze().numpy() 77 | 78 | # Make channels last for RGB images. 79 | if x.shape[0] == 3: 80 | x = x.transpose(1, 2, 0) 81 | 82 | # Normalize images. 83 | from torch_em.transform.raw import normalize 84 | x = normalize(x) * 255 85 | x = x.astype(int) 86 | 87 | # Finally, plot them into one place. 
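# NOTE: '_overlay_outline' draws the instance outlines from 'y' on top of the raw image,
# slightly thickened here via 'outline_dilation=1'.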
88 | image = _overlay_outline(x, y, outline_dilation=1)
89 |
90 | ax[i].imshow(image, cmap="gray")
91 | ax[i].axis("off")
92 |
93 | plt.subplots_adjust(hspace=0.01, wspace=0.01)
94 | plt.savefig("./fig_1a_histopathology_dataset_images.png", bbox_inches="tight")
95 | plt.savefig("./fig_1a_histopathology_dataset_images.svg", bbox_inches="tight")
96 |
97 |
98 | def main():
99 | for_fig_1a()
100 |
101 |
102 | if __name__ == "__main__":
103 | main()
104 |
--------------------------------------------------------------------------------
/scripts/test_pannuke.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | from glob import glob
4 | from natsort import natsorted
5 |
6 | import h5py
7 | import numpy as np
8 | import imageio.v3 as imageio
9 | from skimage.segmentation import relabel_sequential
10 |
11 | from torch_em.data.datasets.histopathology import pannuke, monuseg
12 |
13 | from micro_sam.util import get_sam_model
14 | from micro_sam.evaluation import inference, evaluation
15 |
16 |
17 | def _pad_image(image, target_shape=(512, 512), pad_value=0):
18 | pad = [
19 | (max(t - s, 0) // 2, max(t - s, 0) - max(t - s, 0) // 2) for s, t in zip(image.shape[:2], target_shape)
20 | ]
21 |
22 | if image.ndim == 3:
23 | pad.append((0, 0))
24 |
25 | return np.pad(image, pad, mode="constant", constant_values=pad_value)
26 |
27 |
28 | def get_data_paths(input_path, dataset_name):
29 | # Set specific data folders.
30 | input_path = os.path.join(input_path, dataset_name)
31 |
32 | if dataset_name == "pannuke":
33 | # Create a clone of the single images in the input_path directory.
34 | data_dir = os.path.join(input_path, "benchmark_2d")
35 |
36 | if not os.path.exists(data_dir):
37 | # First, we get fold 3.
38 | fold_path = pannuke.get_pannuke_paths(path=input_path, folds=["fold_3"], download=True)
39 | fold_path = fold_path[0]
40 |
41 | # Next, simply extract the images.
42 | with h5py.File(fold_path, "r") as f:
43 | image_stack = f["images"][:].transpose(1, 2, 3, 0)
44 | label_stack = f["labels/instances"][:]
45 |
46 | # Store them one-by-one locally in an experiment folder.
47 | os.makedirs(data_dir)
48 |
49 | for i, (image, label) in tqdm(
50 | enumerate(zip(image_stack, label_stack)), total=len(image_stack), desc="Extracting images",
51 | ):
52 | # There has to be some foreground in the image to be considered for interactive segmentation.
53 | if len(np.unique(label)) == 1:
54 | continue
55 |
56 | # NOTE: I am padding the image below to match the shape of inputs on which it is trained,
57 | # i.e. (512, 512), for proper reproducibility (otherwise the results are slightly worse).
58 | image = _pad_image(image)
59 | label = _pad_image(label)
60 |
61 | imageio.imwrite(os.path.join(data_dir, f"pannuke_fold_3_{i:05}_image.tif"), image, compression="zlib")
62 | imageio.imwrite(os.path.join(data_dir, f"pannuke_fold_3_{i:05}_label.tif"), label, compression="zlib")
63 |
64 | # Well, now we have our image and label paths.
65 | image_paths = natsorted(glob(os.path.join(data_dir, "*_image.tif")))
66 | label_paths = natsorted(glob(os.path.join(data_dir, "*_label.tif")))
67 |
68 | assert len(image_paths) == len(label_paths)
69 |
70 | # HACK: for debugging purposes, I will check on the first 100 images. Remove the next line to run it on all images.
71 | image_paths, label_paths = image_paths[:100], label_paths[:100]
72 |
73 | elif dataset_name == "monuseg":
74 | # Create a clone of the cropped images in the input_path directory.
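# NOTE: The MoNuSeg test images are larger than the (512, 512) training patch shape,
# hence the top-left crop taken below (see also the comment in 'run_interactive_segmentation').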
75 | data_dir = os.path.join(input_path, "benchmark_2d") 76 | os.makedirs(data_dir, exist_ok=True) 77 | 78 | curr_image_paths, curr_label_paths = monuseg.get_monuseg_paths(path=input_path, split="test", download=True) 79 | 80 | # Let's do a simple cropping to test stuff. 81 | image_paths, label_paths = [], [] 82 | for curr_image_path, curr_label_path in zip(curr_image_paths, curr_label_paths): 83 | image = imageio.imread(curr_image_path) 84 | label = imageio.imread(curr_label_path).astype("uint32") 85 | 86 | # Do a simple cropping and relabel instances. 87 | image, label = image[:512, :512, :], label[:512, :512] 88 | label = relabel_sequential(label)[0] 89 | 90 | # And save the cropped image and corresponding label. 91 | image_path = os.path.join(data_dir, os.path.basename(curr_image_path)) 92 | image_paths.append(image_path) 93 | imageio.imwrite(image_path, image, compression="zlib") 94 | 95 | label_path = os.path.join(data_dir, os.path.basename(curr_label_path)) 96 | label_paths.append(label_path) 97 | imageio.imwrite(label_path, label, compression="zlib") 98 | 99 | else: 100 | raise ValueError 101 | 102 | return image_paths, label_paths 103 | 104 | 105 | def run_interactive_segmentation(input_path, experiment_folder, model_type, start_with_box_prompt=True): 106 | 107 | # Setup 1: PanNuke images (pad (256, 256) images up to (512, 512) to match the training patch shape) 108 | # image_paths, label_paths = get_data_paths(input_path, "pannuke") # NOTE: uncomment to run it on PanNuke 109 | 110 | # Setup 2: MoNuSeg images (since the images are larger than training patch shape, we crop them to shape (512, 512)) 111 | image_paths, label_paths = get_data_paths(input_path, "monuseg") # NOTE: comment this before running other setups. 112 | 113 | # Get the Segment Anything model. 114 | predictor = get_sam_model(model_type=model_type) 115 | 116 | # Then run interactive segmentation by simulating prompts from labels. 117 | prediction_root = os.path.join( 118 | experiment_folder, ("start_with_box" if start_with_box_prompt else "start_with_point") 119 | ) 120 | inference.run_inference_with_iterative_prompting( 121 | predictor=predictor, 122 | image_paths=image_paths, 123 | gt_paths=label_paths, 124 | embedding_dir=None, 125 | prediction_dir=prediction_root, 126 | start_with_box_prompt=start_with_box_prompt, 127 | ) 128 | 129 | # And evaluate the results. 
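# NOTE: This compares the predictions from each prompting iteration against the ground truth
# and summarizes the scores per iteration.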
130 | results = evaluation.run_evaluation_for_iterative_prompting(
131 | gt_paths=label_paths,
132 | prediction_root=prediction_root,
133 | experiment_folder=experiment_folder,
134 | start_with_box_prompt=start_with_box_prompt,
135 | )
136 |
137 | print(results)
138 |
139 |
140 | def main(args):
141 | run_interactive_segmentation(
142 | input_path=args.input_path,
143 | model_type=args.model_type,
144 | experiment_folder=args.experiment_folder,
145 | start_with_box_prompt=False,
146 | )
147 |
148 |
149 | if __name__ == "__main__":
150 | import argparse
151 | parser = argparse.ArgumentParser()
152 | parser.add_argument("-i", "--input_path", default="/mnt/vast-nhr/projects/cidas/cca/data", type=str)
153 | parser.add_argument("-e", "--experiment_folder", default="./experiments", type=str)
154 | parser.add_argument("-m", "--model_type", default="vit_b_histopathology", type=str)
155 | args = parser.parse_args()
156 | main(args)
157 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import runpy
4 | from setuptools import setup, find_packages
5 |
6 |
7 | __version__ = runpy.run_path("patho_sam/__version__.py")["__version__"]
8 |
9 |
10 | setup(
11 | name='patho_sam',
12 | description='Segment Anything for Histopathology.',
13 | version=__version__,
14 | author='Titus Griebel, Anwai Archit, Constantin Pape',
15 | author_email='titus.griebel@stud.uni-goettingen.de, anwai.archit@uni-goettingen.de, constantin.pape@informatik.uni-goettingen.de', # noqa
16 | url='https://github.com/computational-cell-analytics/patho-sam',
17 | packages=find_packages(), # includes the subpackages, e.g. 'patho_sam.io' and 'patho_sam.training'
18 | license="MIT",
19 | entry_points={
20 | "console_scripts": [
21 | "patho_sam.example_data = patho_sam.util:get_example_wsi_data",
22 | "patho_sam.automatic_segmentation = patho_sam.automatic_segmentation:main",
23 | ]
24 | }
25 | )
26 |
--------------------------------------------------------------------------------