├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── MedSAM_Inference.py ├── README.md ├── assets ├── MedSAM_supp.pdf ├── architecture.png ├── box-task-illustration.png ├── img_demo.png ├── pathology.png ├── seg_demo.gif ├── seg_pathology.mp4 └── task-illustration.png ├── comparisons ├── DeepLabV3+ │ ├── README.md │ ├── infer_deeplabv3_res50_2D.py │ ├── infer_deeplabv3_res50_3D.py │ └── train_deeplabv3_res50.py ├── SAM │ ├── README.md │ ├── infer_SAM_2D_npz.py │ └── infer_SAM_3D_npz.py └── nnU-Net │ ├── README.md │ ├── infer_nnunet_2D.py │ └── infer_nnunet_3D.py ├── extensions ├── point_prompt │ ├── README.md │ ├── point_seg_demo.gif │ ├── train_point_prompt.py │ └── tutorial_point_prompt_seg.ipynb ├── seg_3dnii_sparse_marker │ ├── README.md │ ├── label_interpolate.py │ └── medsam_infer_3Dbox_adrenal.py └── text_prompt │ ├── README.md │ ├── text_seg_demo.gif │ ├── train_text_prompt.py │ └── tutorial_text_prompt_seg.ipynb ├── gui.py ├── pre_CT_MR.py ├── segment_anything ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ └── transformer.py ├── predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── onnx.py │ └── transforms.py ├── setup.py ├── train_multi_gpus.py ├── train_multi_gpus.sh ├── train_one_gpu.py ├── tutorial_quickstart.ipynb ├── utils ├── README.md ├── SurfaceDice.py ├── ckpt_convert.py ├── demo.py ├── format_convert.py ├── pre_CT_MR.py ├── pre_grey_rgb.py └── split.py └── work_dir └── MedSAM └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | # Project specific 3 | *.DS_Store 4 | *.pth 5 | *.nii.gz 6 | *.nii 7 | *.npy 8 | *.npz 9 | *.csv 10 | *.log 11 | *.png 12 | data/ 13 | work_dir/ 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .nox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | *.py,cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | cover/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | db.sqlite3-journal 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | .pybuilder/ 90 | target/ 91 | 92 | # Jupyter Notebook 93 | .ipynb_checkpoints 94 | 95 | # IPython 96 | profile_default/ 97 | ipython_config.py 98 | 99 | # pyenv 100 | # For a library or package, you might want to ignore these files since the code is 101 | # intended to run in multiple environments; otherwise, check them in: 102 | # .python-version 103 | 104 | # pipenv 105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 108 | # install all needed dependencies. 109 | #Pipfile.lock 110 | 111 | # poetry 112 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 113 | # This is especially recommended for binary packages to ensure reproducibility, and is more 114 | # commonly ignored for libraries. 115 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 116 | #poetry.lock 117 | 118 | # pdm 119 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 120 | #pdm.lock 121 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 122 | # in version control. 123 | # https://pdm.fming.dev/#use-with-ide 124 | .pdm.toml 125 | 126 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 127 | __pypackages__/ 128 | 129 | # Celery stuff 130 | celerybeat-schedule 131 | celerybeat.pid 132 | 133 | # SageMath parsed files 134 | *.sage.py 135 | 136 | # Environments 137 | .env 138 | .venv 139 | env/ 140 | venv/ 141 | ENV/ 142 | env.bak/ 143 | venv.bak/ 144 | 145 | # Spyder project settings 146 | .spyderproject 147 | .spyproject 148 | 149 | # Rope project settings 150 | .ropeproject 151 | 152 | # mkdocs documentation 153 | /site 154 | 155 | # mypy 156 | .mypy_cache/ 157 | .dmypy.json 158 | dmypy.json 159 | 160 | # Pyre type checker 161 | .pyre/ 162 | 163 | # pytype static type analyzer 164 | .pytype/ 165 | 166 | # Cython debug symbols 167 | cython_debug/ 168 | 169 | # PyCharm 170 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 171 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 172 | # and can be added to the global gitignore or merged into this file. For a more nuclear 173 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
174 | #.idea/ 175 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | # - id: no-commit-to-branch 6 | # args: [--pattern, ^v] 7 | # - id: check-added-large-files 8 | # args: [--maxkb=64] 9 | - id: check-case-conflict 10 | - id: check-yaml 11 | - id: check-xml 12 | - id: check-toml 13 | - id: check-merge-conflict 14 | - id: check-symlinks 15 | - id: destroyed-symlinks 16 | - id: mixed-line-ending 17 | args: [--fix=lf] 18 | - id: end-of-file-fixer 19 | - id: trailing-whitespace 20 | - id: check-json 21 | - id: pretty-format-json 22 | args: [--autofix, --indent=4, --no-ensure-ascii] 23 | - id: detect-private-key 24 | - id: fix-encoding-pragma 25 | 26 | - repo: https://github.com/psf/black 27 | rev: 23.3.0 28 | hooks: 29 | - id: black 30 | -------------------------------------------------------------------------------- /MedSAM_Inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | usage example: 5 | python MedSAM_Inference.py -i assets/img_demo.png -o ./ --box "[95,255,190,350]" 6 | 7 | """ 8 | 9 | # %% load environment 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | import os 13 | 14 | join = os.path.join 15 | import torch 16 | from segment_anything import sam_model_registry 17 | from skimage import io, transform 18 | import torch.nn.functional as F 19 | import argparse 20 | 21 | 22 | # visualization functions 23 | # source: https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb 24 | # change color to avoid red and green 25 | def show_mask(mask, ax, random_color=False): 26 | if random_color: 27 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 28 | else: 29 | color = np.array([251 / 255, 252 / 255, 30 / 255, 0.6]) 30 | h, w = mask.shape[-2:] 31 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 32 | ax.imshow(mask_image) 33 | 34 | 35 | def show_box(box, ax): 36 | x0, y0 = box[0], box[1] 37 | w, h = box[2] - box[0], box[3] - box[1] 38 | ax.add_patch( 39 | plt.Rectangle((x0, y0), w, h, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2) 40 | ) 41 | 42 | 43 | @torch.no_grad() 44 | def medsam_inference(medsam_model, img_embed, box_1024, H, W): 45 | box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device) 46 | if len(box_torch.shape) == 2: 47 | box_torch = box_torch[:, None, :] # (B, 1, 4) 48 | 49 | sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder( 50 | points=None, 51 | boxes=box_torch, 52 | masks=None, 53 | ) 54 | low_res_logits, _ = medsam_model.mask_decoder( 55 | image_embeddings=img_embed, # (B, 256, 64, 64) 56 | image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64) 57 | sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256) 58 | dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64) 59 | multimask_output=False, 60 | ) 61 | 62 | low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256) 63 | 64 | low_res_pred = F.interpolate( 65 | low_res_pred, 66 | size=(H, W), 67 | mode="bilinear", 68 | align_corners=False, 69 | ) # (1, 1, gt.shape) 70 | low_res_pred = low_res_pred.squeeze().cpu().numpy() # (256, 256) 71 | medsam_seg = (low_res_pred > 0.5).astype(np.uint8) 72 | return medsam_seg 73 | 74 | 75 | # %% load model 
and image 76 | parser = argparse.ArgumentParser( 77 | description="run inference on testing set based on MedSAM" 78 | ) 79 | parser.add_argument( 80 | "-i", 81 | "--data_path", 82 | type=str, 83 | default="assets/img_demo.png", 84 | help="path to the data folder", 85 | ) 86 | parser.add_argument( 87 | "-o", 88 | "--seg_path", 89 | type=str, 90 | default="assets/", 91 | help="path to the segmentation folder", 92 | ) 93 | parser.add_argument( 94 | "--box", 95 | type=str, 96 | default='[95, 255, 190, 350]', 97 | help="bounding box of the segmentation target", 98 | ) 99 | parser.add_argument("--device", type=str, default="cuda:0", help="device") 100 | parser.add_argument( 101 | "-chk", 102 | "--checkpoint", 103 | type=str, 104 | default="work_dir/MedSAM/medsam_vit_b.pth", 105 | help="path to the trained model", 106 | ) 107 | args = parser.parse_args() 108 | 109 | device = args.device 110 | medsam_model = sam_model_registry["vit_b"](checkpoint=args.checkpoint) 111 | medsam_model = medsam_model.to(device) 112 | medsam_model.eval() 113 | 114 | img_np = io.imread(args.data_path) 115 | if len(img_np.shape) == 2: 116 | img_3c = np.repeat(img_np[:, :, None], 3, axis=-1) 117 | else: 118 | img_3c = img_np 119 | H, W, _ = img_3c.shape 120 | # %% image preprocessing 121 | img_1024 = transform.resize( 122 | img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True 123 | ).astype(np.uint8) 124 | img_1024 = (img_1024 - img_1024.min()) / np.clip( 125 | img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None 126 | ) # normalize to [0, 1], (H, W, 3) 127 | # convert the shape to (3, H, W) 128 | img_1024_tensor = ( 129 | torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device) 130 | ) 131 | 132 | box_np = np.array([[int(x) for x in args.box[1:-1].split(',')]]) 133 | # transfer box_np t0 1024x1024 scale 134 | box_1024 = box_np / np.array([W, H, W, H]) * 1024 135 | with torch.no_grad(): 136 | image_embedding = medsam_model.image_encoder(img_1024_tensor) # (1, 256, 64, 64) 137 | 138 | medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H, W) 139 | io.imsave( 140 | join(args.seg_path, "seg_" + os.path.basename(args.data_path)), 141 | medsam_seg, 142 | check_contrast=False, 143 | ) 144 | 145 | # %% visualize results 146 | fig, ax = plt.subplots(1, 2, figsize=(10, 5)) 147 | ax[0].imshow(img_3c) 148 | show_box(box_np[0], ax[0]) 149 | ax[0].set_title("Input Image and Bounding Box") 150 | ax[1].imshow(img_3c) 151 | show_mask(medsam_seg, ax[1]) 152 | show_box(box_np[0], ax[1]) 153 | ax[1].set_title("MedSAM Segmentation") 154 | plt.show() 155 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MedSAM 2 | This is the official repository for MedSAM: Segment Anything in Medical Images. 3 | 4 | Welcome to join our [mailing list](https://forms.gle/hk4Efp6uWnhjUHFP6) to get updates. 5 | 6 | 7 | ## News 8 | 9 | - 2025.04.07: Release [MedSAM2](https://github.com/bowang-lab/MedSAM2) for 3D and video segmentation. 
10 | - 2025.02: Welcome to join CVPR 2025 Challenges: [Interactive](https://www.codabench.org/competitions/5263/) and [Text-guided](https://www.codabench.org/competitions/5651/) 3D Biomedical Image Segmentation 11 | - 2024.01.15: Welcome to join [CVPR 2024 Challenge: MedSAM on Laptop](https://www.codabench.org/competitions/1847/) 12 | - 2024.01.15: Release [LiteMedSAM](https://github.com/bowang-lab/MedSAM/blob/LiteMedSAM/README.md) and [3D Slicer Plugin](https://github.com/bowang-lab/MedSAMSlicer), 10x faster than MedSAM! 13 | 14 | 15 | ## Installation 16 | 1. Create a virtual environment `conda create -n medsam python=3.10 -y` and activate it `conda activate medsam` 17 | 2. Install [PyTorch 2.0](https://pytorch.org/get-started/locally/) 18 | 3. `git clone https://github.com/bowang-lab/MedSAM` 19 | 4. Enter the MedSAM folder `cd MedSAM` and run `pip install -e .` 20 | 21 | 22 | ## Get Started 23 | Download the [model checkpoint](https://drive.google.com/drive/folders/1ETWmi4AiniJeWOt6HAsYgTjYv_fkgzoN?usp=drive_link) and place it at, e.g., `work_dir/MedSAM/medsam_vit_b.pth` 24 | 25 | We provide three ways to quickly test the model on your images. 26 | 27 | 1. Command line 28 | 29 | ```bash 30 | python MedSAM_Inference.py # segment the demo image 31 | ``` 32 | 33 | Segment other images with the following flags: 34 | ```bash 35 | -i input_img 36 | -o output_path 37 | --box bounding box of the segmentation target 38 | ``` 39 | 40 | 2. Jupyter notebook 41 | 42 | We provide a step-by-step tutorial on [Colab](https://colab.research.google.com/drive/19WNtRMbpsxeqimBlmJwtd1dzpaIvK2FZ?usp=sharing) 43 | 44 | You can also run it locally with `tutorial_quickstart.ipynb`. 45 | 46 | 3. GUI 47 | 48 | Install `PyQt5` with [pip](https://pypi.org/project/PyQt5/): `pip install PyQt5` or [conda](https://anaconda.org/anaconda/pyqt): `conda install -c anaconda pyqt` 49 | 50 | ```bash 51 | python gui.py 52 | ``` 53 | 54 | Load the image into the GUI and specify segmentation targets by drawing bounding boxes. 55 | 56 | 57 | 58 | https://github.com/bowang-lab/MedSAM/assets/19947331/a8d94b4d-0221-4d09-a43a-1251842487ee 59 | 60 | 61 | 62 | 63 | 64 | ## Model Training 65 | 66 | ### Data preprocessing 67 | 68 | Download the [SAM checkpoint](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) and place it at `work_dir/SAM/sam_vit_b_01ec64.pth`. 69 | 70 | Download the demo [dataset](https://zenodo.org/record/7860267) and unzip it to `data/FLARE22Train/`. 71 | 72 | This dataset contains 50 abdomen CT scans and each scan contains an annotation mask with 13 organs. The names of the organ labels are available at [MICCAI FLARE2022](https://flare22.grand-challenge.org/). 73 | 74 | Run the pre-processing script 75 | 76 | Install `cc3d`: `pip install connected-components-3d` 77 | 78 | ```bash 79 | python pre_CT_MR.py 80 | ``` 81 | The script performs the following steps: 82 | - split the dataset: 80% for training and 20% for testing 83 | - adjust CT scans to the [soft tissue](https://radiopaedia.org/articles/windowing-ct) window level (40) and width (400) 84 | - min-max intensity normalization 85 | - resize images to `1024x1024` 86 | - save the pre-processed images and labels as `npy` files 87 | 88 | 89 | ### Training on multiple GPUs (Recommended) 90 | 91 | The model was trained on five A100 nodes, each with four 80GB GPUs (20 A100 GPUs in total). Please use the Slurm script to start the training process. 92 | 93 | ```bash 94 | sbatch train_multi_gpus.sh 95 | ``` 96 | 97 | When the training process is done, please convert the checkpoint to SAM's format for convenient inference.
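Conceptually, the conversion simply unwraps the training checkpoint and re-saves the model weights as a plain `state_dict` that `sam_model_registry` can load. A minimal sketch of this idea is shown below (the checkpoint path and the `"model"` key are assumptions about how the training checkpoint is stored; `utils/ckpt_convert.py` is the reference implementation):

```python
# Hedged sketch of the checkpoint conversion idea; use utils/ckpt_convert.py in practice.
import torch

ckpt = torch.load("work_dir/medsam_model_latest.pth", map_location="cpu")  # training checkpoint (assumed path)
state_dict = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt  # unwrap the training wrapper
torch.save(state_dict, "work_dir/MedSAM/medsam_vit_b.pth")  # SAM-style checkpoint for sam_model_registry
```

The provided conversion script can be run directly: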
98 | 99 | ```bash 100 | python utils/ckpt_convert.py # Please set the corresponding checkpoint path first 101 | ``` 102 | 103 | ### Training on one GPU 104 | 105 | ```bash 106 | python train_one_gpu.py 107 | ``` 108 | 109 | 110 | 111 | ## Acknowledgements 112 | - We highly appreciate all the challenge organizers and dataset owners for providing the public dataset to the community. 113 | - We thank Meta AI for making the source code of [segment anything](https://github.com/facebookresearch/segment-anything) publicly available. 114 | - We also thank Alexandre Bonnet for sharing this great [blog](https://encord.com/blog/learn-how-to-fine-tune-the-segment-anything-model-sam/) 115 | 116 | 117 | ## Reference 118 | 119 | ``` 120 | @article{MedSAM, 121 | title={Segment Anything in Medical Images}, 122 | author={Ma, Jun and He, Yuting and Li, Feifei and Han, Lin and You, Chenyu and Wang, Bo}, 123 | journal={Nature Communications}, 124 | volume={15}, 125 | pages={654}, 126 | year={2024} 127 | } 128 | ``` 129 | -------------------------------------------------------------------------------- /assets/MedSAM_supp.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM/d71e8a1a99ad751840a22a7fa3ecfb4166fb1488/assets/MedSAM_supp.pdf -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM/d71e8a1a99ad751840a22a7fa3ecfb4166fb1488/assets/architecture.png -------------------------------------------------------------------------------- /assets/box-task-illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM/d71e8a1a99ad751840a22a7fa3ecfb4166fb1488/assets/box-task-illustration.png -------------------------------------------------------------------------------- /assets/img_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM/d71e8a1a99ad751840a22a7fa3ecfb4166fb1488/assets/img_demo.png -------------------------------------------------------------------------------- /assets/pathology.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM/d71e8a1a99ad751840a22a7fa3ecfb4166fb1488/assets/pathology.png -------------------------------------------------------------------------------- /assets/seg_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM/d71e8a1a99ad751840a22a7fa3ecfb4166fb1488/assets/seg_demo.gif -------------------------------------------------------------------------------- /assets/seg_pathology.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM/d71e8a1a99ad751840a22a7fa3ecfb4166fb1488/assets/seg_pathology.mp4 -------------------------------------------------------------------------------- /assets/task-illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM/d71e8a1a99ad751840a22a7fa3ecfb4166fb1488/assets/task-illustration.png -------------------------------------------------------------------------------- 
/comparisons/DeepLabV3+/README.md: -------------------------------------------------------------------------------- 1 | # DeepLabV3+ Training and Inference Scripts 2 | 3 | This folder contains the training and inference scripts of the DeepLabV3+ model for segmentation on medical image data in MedSAM's preprocessed `npz` format. For details regarding the data preprocessing pipeline, please refer to the [MedSAM repository](https://github.com/bowang-lab/MedSAM#data-preprocessing). 4 | 5 | ## Prerequisites 6 | 7 | This codebase uses the [Segmentation Models PyTorch](https://github.com/qubvel/segmentation_models.pytorch) library, which can be installed via pip: 8 | 9 | ``` 10 | pip install segmentation-models-pytorch 11 | ``` 12 | 13 | ## Training 14 | 15 | To train the DeepLabV3+ model, one can use the provided `train_deeplabv3_res50.py` script. In order to incorporate the bounding box prompts into the model, we converted the bounding box into a binary mask and concatenated it with the image as the model input. The bounding box was simulated from the ground-truth mask. 16 | Below are the required parameters that need to be configured before training: 17 | 18 | * `-i /path/to/input`: Path to the input dataset (npy format). 19 | * `-o /path/to/output`: Path to save the trained model. 20 | 21 | Example command for training (`-b` sets the batch size, `--num_workers` the number of data-loading workers, `--max_epochs` the maximum number of training epochs, and `--compile` enables model compilation for acceleration): 22 | ```sh 23 | python train_deeplabv3_res50.py \ 24 | -i /path/to/input \ 25 | -o /path/to/output \ 26 | -b 600 \ 27 | --num_workers 4 \ 28 | --max_epochs 500 \ 29 | --compile 30 | ``` 31 | 32 | 33 | 34 | ## Inference 35 | The inference scripts assume that the data is in the `npz` format generated by the MedSAM preprocessing pipeline. To run inference, one can download the model [here](https://drive.google.com/drive/folders/1xYUgdjIsmBkobiBKXNb1uyqN-kGHW2p_?usp=sharing) and use the provided inference scripts. 36 | 37 | 38 | ### Inference for 2D images 39 | 40 | To perform inference on 2D images, one can use the `infer_deeplabv3_res50_2D.py` script. Below are the parameters that need to be configured: 41 | 42 | * `-checkpoint`: Path to the trained model checkpoint. 43 | * `-data_root`: Path to the input images. 44 | * `-pred_save_dir`: Path to save the output segmented images. 45 | * `--save_overlay`: Save the overlay of the segmentation on the original image. (Optional) 46 | * `-png_save_dir`: Path to save the overlay images. (Required if `--save_overlay` is used) 47 | * `-num_workers`: Number of workers for multiprocessing during inference. 48 | * `--grey`: Save the overlay images in greyscale. (Optional) 49 | 50 | ```sh 51 | python infer_deeplabv3_res50_2D.py \ 52 | -checkpoint path/to/checkpoint/deeplabv3plus_best.pt \ 53 | -data_root /path/to/input \ 54 | -pred_save_dir /path/to/output \ 55 | --save_overlay \ 56 | -png_save_dir /path/to/saved/overlay \ 57 | -num_workers 2 \ 58 | --grey 59 | ``` 60 | 61 | ### Inference for 3D images 62 | 63 | To perform inference on 3D medical images, such as CT or MR scans, the `infer_deeplabv3_res50_3D.py` script can be used. Below are the parameters that one can configure: 64 | 65 | * `-checkpoint`: Path to the trained model checkpoint. 66 | * `-data_root`: Path to the input 3D images. 67 | * `-pred_save_dir`: Path to save the output segmented 3D images. 68 | * `-png_save_dir`: Path to save the overlay images. (Optional) 69 | * `-num_workers`: Number of workers for multiprocessing during inference.
70 | 71 | ```sh 72 | python infer_deeplabv3_res50_3D.py \ 73 | -checkpoint /path/to/checkpoint/deeplabv3plus_best.pt \ 74 | -data_root /path/to/input \ 75 | -pred_save_dir /path/to/output \ 76 | -png_save_dir /path/to/saved/overlay \ 77 | -num_workers 2 78 | ``` 79 | 80 | ## Acknowledgement 81 | This codebasse uses the [Segmentation Models Pytorch](https://github.com/qubvel/segmentation_models.pytorch) repository. We would like to thank the authors and the contributors for their great work and for making the code publicly available. -------------------------------------------------------------------------------- /comparisons/DeepLabV3+/infer_deeplabv3_res50_2D.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import glob 4 | import random 5 | from os import listdir, makedirs 6 | from os.path import join, isdir, basename, dirname, isfile 7 | from tqdm import tqdm 8 | import numpy as np 9 | import torch 10 | from torch._dynamo import OptimizedModule 11 | from torch import multiprocessing as mp 12 | import cv2 13 | import torch.nn.functional as F 14 | from matplotlib import pyplot as plt 15 | import segmentation_models_pytorch as smp 16 | import argparse 17 | 18 | torch.cuda.empty_cache() 19 | os.environ['PYTHONHASHSEED']=str(2023) 20 | random.seed(2023) 21 | np.random.seed(2023) 22 | torch.manual_seed(2023) 23 | torch.cuda.manual_seed(2023) 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | '-checkpoint', 28 | type=str, 29 | default='', 30 | help='Path to the model checkpoint', 31 | required=True 32 | ) 33 | parser.add_argument('-device', type=str, default='cuda:0') 34 | parser.add_argument( 35 | '-data_root', 36 | type=str, 37 | default='', 38 | help='Path to the validation data directory', 39 | required=True 40 | ) 41 | parser.add_argument( 42 | '-pred_save_dir', 43 | type=str, 44 | default='segs', 45 | help='Path to the directory where the segmentation results will be saved in npz format' 46 | ) 47 | parser.add_argument('--save_overlay', action='store_true', default=False, help="Whether to save segmentation overlay") 48 | parser.add_argument( 49 | '-png_save_dir', 50 | type=str, 51 | default='png', 52 | help='Path to the directory where the segmentation overlay will be saved in png format' 53 | ) 54 | parser.add_argument( 55 | '--grey', 56 | action='store_true', 57 | default=False, 58 | help="Whether to save segmentation overlay in grey scale" 59 | ) 60 | parser.add_argument('-num_workers', type=int, default=1, help='number of workers for dataloader') 61 | 62 | args = parser.parse_args() 63 | checkpoint = args.checkpoint 64 | device = args.device 65 | data_root = args.data_root 66 | pred_save_dir = args.pred_save_dir 67 | png_save_dir = args.png_save_dir 68 | makedirs(pred_save_dir, exist_ok=True) 69 | save_overlay = args.save_overlay 70 | if save_overlay: 71 | makedirs(png_save_dir, exist_ok=True) 72 | num_workers = args.num_workers 73 | data_root_files = listdir(data_root) 74 | has_task = isdir(join(data_root, data_root_files[0])) 75 | if has_task: 76 | gt_path_files = sorted(glob.glob(join(data_root, '**/*.npz'), recursive=True)) 77 | else: 78 | gt_path_files = sorted(glob.glob(join(data_root, '*.npz'), recursive=True)) 79 | image_size = 224 80 | bbox_shift = 5 81 | 82 | def show_mask(mask, ax, random_color=False): 83 | if random_color: 84 | color = np.concatenate([np.random.random(3), np.array([0.45])], axis=0) 85 | else: 86 | color = np.array([251/255, 252/255, 30/255, 0.45]) 87 | h, w = 
mask.shape[-2:] 88 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 89 | ax.imshow(mask_image) 90 | 91 | def show_box(box, ax): 92 | x0, y0 = box[0], box[1] 93 | w, h = box[2] - box[0], box[3] - box[1] 94 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2)) 95 | 96 | def dice_coefficient(preds, targets): 97 | smooth = 1.0 98 | assert preds.shape == targets.shape 99 | 100 | intersection = (preds * targets).sum() 101 | dice = (2.0 * intersection + smooth) / (preds.sum() + targets.sum() + smooth) 102 | return dice 103 | 104 | # %% 105 | model = smp.DeepLabV3Plus( 106 | encoder_name="resnet50", # encoder model type 107 | encoder_weights="imagenet", # use `imagenet` pretrained weights for encoder initialization 108 | in_channels=4, # Additional channel for bounding box prompt 109 | classes=1, # model output channels (number of classes in your dataset) 110 | activation=None # Output logits 111 | ) 112 | checkpoint = torch.load(checkpoint) 113 | model.load_state_dict(checkpoint['model']) 114 | model.to(device) 115 | model.eval() 116 | 117 | def preprocess_image(img_3c, gt_2D, image_size=224, bbox_shift=5): 118 | """ 119 | Append bounding box prompt channel to image 120 | """ 121 | resize_img_cv2 = cv2.resize( 122 | img_3c, 123 | (image_size, image_size), 124 | interpolation=cv2.INTER_AREA 125 | ) 126 | resize_img_cv2_01 = (resize_img_cv2 - resize_img_cv2.min()) / np.clip(resize_img_cv2.max() - resize_img_cv2.min(), a_min=1e-8, a_max=None) # normalize to [0, 1], (H, W, 3) 127 | # convert the shape to (3, H, W) 128 | resize_img = np.transpose(resize_img_cv2_01, (2, 0, 1)) 129 | assert np.max(resize_img)<=1.0 and np.min(resize_img)>=0.0, 'image should be normalized to [0, 1]' 130 | if gt_2D.shape[0] != image_size or gt_2D.shape[1] != image_size: 131 | gt_2D = cv2.resize( 132 | gt_2D, (image_size, image_size), 133 | interpolation=cv2.INTER_NEAREST 134 | ) 135 | gt_2D = np.uint8(gt_2D) 136 | else: 137 | gt_2D = gt_2D.astype(np.uint8) 138 | try: 139 | assert np.max(gt_2D) == 1 and np.min(gt_2D) == 0, 'ground truth should be 0, 1, got: ' + str(np.unique(gt_2D)) 140 | except: 141 | assert np.max(gt_2D) == 0 and np.min(gt_2D) == 0, 'ground truth should be 0, 1, got: ' + str(np.unique(gt_2D)) 142 | return None 143 | 144 | y_indices, x_indices = np.where(gt_2D > 0) 145 | x_min, x_max = np.min(x_indices), np.max(x_indices) 146 | y_min, y_max = np.min(y_indices), np.max(y_indices) 147 | H, W = gt_2D.shape 148 | x_min = max(0, x_min - bbox_shift) 149 | x_max = min(W, x_max + bbox_shift) 150 | y_min = max(0, y_min - bbox_shift) 151 | y_max = min(H, y_max + bbox_shift) 152 | bboxes = np.array([x_min, y_min, x_max, y_max]) 153 | 154 | ## Append bbox prompt channel 155 | resize_img_bbox = np.concatenate([resize_img, np.zeros((1, image_size, image_size))], axis=0) 156 | resize_img_bbox[-1, y_min:y_max, x_min:x_max] = 1.0 157 | resize_img_bbox = resize_img_bbox[None, ...] 
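# At this point resize_img_bbox has shape (1, 4, image_size, image_size): the three normalized image channels plus a binary box-prompt channel that is 1 inside the (shifted) bounding box and 0 elsewhere, with a leading batch dimension so it can be fed to the model directly.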
158 | 159 | return torch.tensor(resize_img_bbox).float() 160 | 161 | 162 | def deeplabv3plus_infer_npz(gt_path_file): 163 | npz_name = basename(gt_path_file) 164 | if has_task: 165 | task_folder = gt_path_file.split('/')[-2] 166 | pred_save_dir_task = join(pred_save_dir, task_folder) 167 | png_save_dir_task = join(png_save_dir, task_folder) 168 | makedirs(pred_save_dir_task, exist_ok=True) 169 | makedirs(png_save_dir_task, exist_ok=True) 170 | else: 171 | pred_save_dir_task = pred_save_dir 172 | png_save_dir_task = png_save_dir 173 | if isfile(join(pred_save_dir_task, npz_name)): 174 | return 175 | npz = np.load(gt_path_file, 'r', allow_pickle=True) 176 | img_3c = npz['imgs'] # (Num, H, W) 177 | gts = npz['gts'] # (Num, 256, 256) 178 | segs = np.zeros_like(img_3c[..., 0], dtype=np.uint8) 179 | 180 | label_ids = np.unique(gts)[1:] 181 | 182 | for label_id in label_ids: 183 | gt_2D = np.uint8(gts == label_id) # only one label 184 | img_4c = preprocess_image( 185 | img_3c, 186 | gt_2D, 187 | image_size=image_size, 188 | bbox_shift=bbox_shift 189 | ) 190 | if img_4c == None: 191 | continue 192 | img_4c = img_4c.to(device) 193 | with torch.no_grad(): 194 | seg_logits = model(img_4c) 195 | seg_logits = F.interpolate( 196 | seg_logits, 197 | size=img_3c.shape[:2], 198 | mode='bilinear', 199 | align_corners=False 200 | ) 201 | seg_probs = torch.sigmoid(seg_logits) 202 | seg_probs_np = seg_probs.detach().cpu().numpy().squeeze() 203 | torch.cuda.empty_cache() 204 | seg_2D = np.uint8(seg_probs_np > 0.5) 205 | segs[seg_2D > 0] = label_id 206 | 207 | if gts.shape[0] != img_3c.shape[0] or gts.shape[1] != img_3c.shape[1]: 208 | gts = cv2.resize( 209 | gts, 210 | (img_3c.shape[1], img_3c.shape[0]), 211 | interpolation=cv2.INTER_NEAREST 212 | ) 213 | 214 | np.savez_compressed( 215 | join(pred_save_dir_task, npz_name), 216 | segs=segs, 217 | gts=gts 218 | ) 219 | 220 | if save_overlay: 221 | fig, ax = plt.subplots(1, 3, figsize=(15, 5)) 222 | if args.grey: 223 | ax[0].imshow(img_3c, cmap='gray') 224 | else: 225 | ax[0].imshow(img_3c) 226 | ax[0].set_title("Image") 227 | ax[0].axis('off') 228 | if args.grey: 229 | ax[1].imshow(img_3c, cmap='gray') 230 | else: 231 | ax[1].imshow(img_3c) 232 | show_mask(gts, ax[1]) 233 | ax[1].axis('off') 234 | #show_box(boxes_np, ax[1]) 235 | ax[1].set_title("Ground Truth") 236 | if args.grey: 237 | ax[2].imshow(img_3c, cmap='gray') 238 | else: 239 | ax[2].imshow(img_3c) 240 | show_mask(segs, ax[2]) 241 | ax[2].set_title("Segmentation") 242 | ax[2].axis('off') 243 | plt.savefig( 244 | join(png_save_dir_task, npz_name.split(".")[0] + '.png'), 245 | dpi=300 246 | ) 247 | plt.close() 248 | 249 | if __name__ == '__main__': 250 | num_workers = num_workers 251 | try: 252 | mp.set_start_method('spawn', force=True) 253 | print("spawned") 254 | except RuntimeError: 255 | pass 256 | with mp.Pool(processes=num_workers) as pool: 257 | with tqdm(total=len(gt_path_files)) as pbar: 258 | for i, _ in tqdm(enumerate(pool.imap_unordered(deeplabv3plus_infer_npz, gt_path_files))): 259 | pbar.update() 260 | -------------------------------------------------------------------------------- /comparisons/DeepLabV3+/infer_deeplabv3_res50_3D.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import glob 4 | import random 5 | from os import listdir, makedirs 6 | from os.path import join, isfile, isdir, basename 7 | from tqdm import tqdm 8 | import numpy as np 9 | import torch 10 | from torch._dynamo import OptimizedModule 11 | from torch import 
multiprocessing as mp 12 | from datetime import datetime 13 | import cv2 14 | from skimage import morphology 15 | import torch.nn.functional as F 16 | from matplotlib import pyplot as plt 17 | import segmentation_models_pytorch as smp 18 | import argparse 19 | 20 | torch.cuda.empty_cache() 21 | os.environ['PYTHONHASHSEED']=str(2023) 22 | random.seed(2023) 23 | np.random.seed(2023) 24 | torch.manual_seed(2023) 25 | torch.cuda.manual_seed(2023) 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument( 29 | '-checkpoint', 30 | type=str, 31 | default='', 32 | help='Path to the model checkpoint', 33 | required=True 34 | ) 35 | parser.add_argument('-device', type=str, default='cuda:0') 36 | parser.add_argument( 37 | '-data_root', 38 | type=str, 39 | default='', 40 | help='Path to the validation data directory', 41 | required=True 42 | ) 43 | parser.add_argument( 44 | '-pred_save_dir', 45 | type=str, 46 | default='segs', 47 | help='Path to the directory where the segmentation results will be saved in npz format' 48 | ) 49 | parser.add_argument('--save_overlay', action='store_true', default=False, help="Whether to save segmentation overlay") 50 | parser.add_argument( 51 | '-png_save_dir', 52 | type=str, 53 | default='png', 54 | help='Path to the directory where the segmentation overlay will be saved in png format' 55 | ) 56 | parser.add_argument('-num_workers', type=int, default=1) 57 | 58 | args = parser.parse_args() 59 | checkpoint = args.checkpoint 60 | device = args.device 61 | data_root = args.data_root 62 | pred_save_dir = args.pred_save_dir 63 | png_save_dir = args.png_save_dir 64 | makedirs(pred_save_dir, exist_ok=True) 65 | save_overlay = args.save_overlay 66 | if save_overlay: 67 | makedirs(png_save_dir, exist_ok=True) 68 | num_workers = args.num_workers 69 | data_root_files = listdir(data_root) 70 | 71 | ## Check whether the data root has subfolders 72 | has_task = isdir(join(data_root, data_root_files[0])) 73 | if has_task: 74 | gt_path_files = sorted(glob.glob(join(data_root, '**/*.npz'), recursive=True)) 75 | else: 76 | gt_path_files = sorted(glob.glob(join(data_root, '*.npz'), recursive=True)) 77 | image_size = 224 78 | bbox_shift = 5 79 | # %% 80 | def show_mask(mask, ax, random_color=False): 81 | if random_color: 82 | color = np.concatenate([np.random.random(3), np.array([0.45])], axis=0) 83 | else: 84 | color = np.array([251/255, 252/255, 30/255, 0.45]) 85 | h, w = mask.shape[-2:] 86 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 87 | ax.imshow(mask_image) 88 | 89 | def show_box(box, ax): 90 | x0, y0 = box[0], box[1] 91 | w, h = box[2] - box[0], box[3] - box[1] 92 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2)) 93 | 94 | def dice_coefficient(preds, targets): 95 | smooth = 1.0 96 | assert preds.shape == targets.shape 97 | 98 | intersection = (preds * targets).sum() 99 | dice = (2.0 * intersection + smooth) / (preds.sum() + targets.sum() + smooth) 100 | return dice 101 | 102 | def get_bbox(mask, bbox_shift=5): 103 | y_indices, x_indices = np.where(mask > 0) 104 | x_min, x_max = np.min(x_indices), np.max(x_indices) 105 | y_min, y_max = np.min(y_indices), np.max(y_indices) 106 | # add perturbation to bounding box coordinates 107 | H, W = mask.shape 108 | x_min = max(0, x_min - bbox_shift) 109 | x_max = min(W, x_max + bbox_shift) 110 | y_min = max(0, y_min - bbox_shift) 111 | y_max = min(H, y_max + bbox_shift) 112 | bboxes = np.array([x_min, y_min, x_max, y_max]) 113 | 114 | return bboxes 115 | 116 | 117 | def 
preprocess_slice(img_2D, gt_2D, image_size=224, bbox_shift=5): 118 | img_3c = np.repeat(img_2D[..., None], 3, axis=-1) # (H, W, 3) 119 | resize_img_cv2 = cv2.resize( 120 | img_3c, 121 | (image_size, image_size), 122 | interpolation=cv2.INTER_AREA 123 | ) 124 | resize_img_cv2_01 = (resize_img_cv2 - resize_img_cv2.min()) / np.clip(resize_img_cv2.max() - resize_img_cv2.min(), a_min=1e-8, a_max=None) # normalize to [0, 1], (H, W, 3) 125 | # convert the shape to (3, H, W) 126 | resize_img = np.transpose(resize_img_cv2_01, (2, 0, 1)) 127 | assert np.max(resize_img)<=1.0 and np.min(resize_img)>=0.0, 'image should be normalized to [0, 1]' 128 | if gt_2D.shape[0] != image_size or gt_2D.shape[1] != image_size: 129 | gt_2D = cv2.resize( 130 | gt_2D, (image_size, image_size), 131 | interpolation=cv2.INTER_NEAREST 132 | ) 133 | gt_2D = np.uint8(gt_2D) 134 | else: 135 | gt_2D = gt_2D.astype(np.uint8) 136 | try: 137 | assert np.max(gt_2D) == 1 and np.min(gt_2D) == 0, 'ground truth should be 0, 1, got: ' + str(np.unique(gt_2D)) 138 | except: 139 | assert np.max(gt_2D) == 0 and np.min(gt_2D) == 0, 'ground truth should be 0, 1, got: ' + str(np.unique(gt_2D)) 140 | return None 141 | 142 | y_indices, x_indices = np.where(gt_2D > 0) 143 | x_min, x_max = np.min(x_indices), np.max(x_indices) 144 | y_min, y_max = np.min(y_indices), np.max(y_indices) 145 | # add perturbation to bounding box coordinates 146 | H, W = gt_2D.shape 147 | x_min = max(0, x_min - bbox_shift) 148 | x_max = min(W, x_max + bbox_shift) 149 | y_min = max(0, y_min - bbox_shift) 150 | y_max = min(H, y_max + bbox_shift) 151 | 152 | ## Append bbox prompt channel 153 | resize_img_bbox = np.concatenate([resize_img, np.zeros((1, image_size, image_size))], axis=0) 154 | resize_img_bbox[-1, y_min:y_max, x_min:x_max] = 1.0 155 | resize_img_bbox = resize_img_bbox[None, ...] 
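# As in the 2D script, the slice is now a (1, 4, image_size, image_size) tensor: RGB-replicated image channels plus the binary box-prompt channel derived from this slice's ground-truth extent.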
156 | 157 | return torch.tensor(resize_img_bbox).float() 158 | 159 | model = smp.DeepLabV3Plus( 160 | encoder_name="resnet50", # encoder model type 161 | encoder_weights="imagenet", # use `imagenet` pretrained weights for encoder initialization 162 | in_channels=4, # Additional channel for bounding box prompt 163 | classes=1, # model output channels (number of classes in your dataset) 164 | activation=None # Output logits 165 | ) 166 | checkpoint = torch.load(checkpoint) 167 | weights = checkpoint['model'] 168 | model.load_state_dict(checkpoint['model']) 169 | model.to(device) 170 | model.eval() 171 | 172 | def deeplabv3plus_infer_npz(gt_path_file): 173 | npz_name = basename(gt_path_file) 174 | if has_task: 175 | task_folder = gt_path_file.split('/')[-2] 176 | pred_save_dir_task = join(pred_save_dir, task_folder) 177 | png_save_dir_task = join(png_save_dir, task_folder) 178 | makedirs(pred_save_dir_task, exist_ok=True) 179 | makedirs(png_save_dir_task, exist_ok=True) 180 | else: 181 | pred_save_dir_task = pred_save_dir 182 | png_save_dir_task = png_save_dir 183 | if isfile(join(pred_save_dir_task, npz_name)): 184 | return 185 | npz = np.load(gt_path_file, 'r', allow_pickle=True) 186 | img_3D = npz['imgs'] # (Num, H, W) 187 | gt_3D = npz['gts'] # (Num, 256, 256) 188 | spacing = npz['spacing'] 189 | seg_3D = np.zeros_like(gt_3D, dtype=np.uint8) # (Num, 256, 256) 190 | 191 | for i in range(img_3D.shape[0]): 192 | img_2D = img_3D[i,:,:] # (H, W) 193 | gt = gt_3D[i,:,:] # (H, W) 194 | label_ids = np.unique(gt)[1:] 195 | for label_id in label_ids: 196 | gt2D = np.uint8(gt == label_id) # only one label, (256, 256) 197 | img_4c = preprocess_slice( 198 | img_2D, 199 | gt2D, 200 | image_size=image_size, 201 | bbox_shift=bbox_shift 202 | ) 203 | if img_4c is None: 204 | continue 205 | img_4c = img_4c.to(device) 206 | with torch.no_grad(): 207 | seg_logits = model(img_4c) 208 | seg_logits = F.interpolate( 209 | seg_logits, 210 | size=(gt2D.shape[0], gt2D.shape[1]), 211 | mode='bilinear', 212 | align_corners=False 213 | ) 214 | seg_probs = torch.sigmoid(seg_logits) 215 | seg_probs_np = seg_probs.detach().cpu().numpy().squeeze() 216 | torch.cuda.empty_cache() 217 | seg_2D = np.uint8(seg_probs_np > 0.5) 218 | seg_3D[i, seg_2D>0] = label_id 219 | 220 | np.savez_compressed( 221 | join(pred_save_dir_task, npz_name), 222 | segs=seg_3D, 223 | gts=gt_3D, 224 | spacing=spacing 225 | ) 226 | 227 | if save_overlay: 228 | idx = int(seg_3D.shape[0] / 2) 229 | fig, ax = plt.subplots(1, 3, figsize=(15, 5)) 230 | ax[0].imshow(img_3D[idx], cmap='gray') 231 | ax[0].set_title("Image") 232 | ax[0].axis('off') 233 | ax[1].imshow(img_3D[idx], cmap='gray') 234 | show_mask(gt_3D[idx], ax[1]) 235 | ax[1].axis('off') 236 | ax[1].set_title("Ground Truth") 237 | ax[2].imshow(img_3D[idx], cmap='gray') 238 | show_mask(seg_3D[idx], ax[2]) 239 | ax[2].set_title("Segmentation") 240 | ax[2].axis('off') 241 | plt.savefig( 242 | join(png_save_dir_task, npz_name.split(".")[0] + '.png'), 243 | dpi=300 244 | ) 245 | plt.close() 246 | 247 | print(f"Case {npz_name} at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") 248 | 249 | # %% 250 | if __name__ == '__main__': 251 | num_workers = num_workers 252 | try: 253 | mp.set_start_method('spawn', force=True) 254 | print("spawned") 255 | except RuntimeError: 256 | pass 257 | with mp.Pool(processes=num_workers) as pool: 258 | with tqdm(total=len(gt_path_files)) as pbar: 259 | for i, _ in tqdm(enumerate(pool.imap_unordered(deeplabv3plus_infer_npz, gt_path_files))): 260 | pbar.update() 261 | 
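# Note: the saved npz files contain both `segs` and `gts`, so segmentation quality
# can be checked offline with the `dice_coefficient` helper defined above.
# A minimal, hypothetical evaluation sketch (the 'segs' glob below assumes the
# default -pred_save_dir and is only a placeholder):
#
#   import glob
#   import numpy as np
#   for f in sorted(glob.glob('segs/**/*.npz', recursive=True)):
#       d = np.load(f)
#       segs, gts = d['segs'], d['gts']
#       dscs = [dice_coefficient(np.uint8(segs == i), np.uint8(gts == i))
#               for i in np.unique(gts)[1:]]
#       print(f, np.mean(dscs) if len(dscs) > 0 else 'no labels')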
-------------------------------------------------------------------------------- /comparisons/DeepLabV3+/train_deeplabv3_res50.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import glob 4 | import random 5 | import monai 6 | from os import makedirs 7 | from os.path import join 8 | from tqdm import tqdm 9 | from copy import deepcopy 10 | from time import time 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import Dataset, DataLoader 14 | from datetime import datetime 15 | 16 | import cv2 17 | import argparse 18 | from matplotlib import pyplot as plt 19 | import segmentation_models_pytorch as smp 20 | import monai 21 | # %% 22 | parser = argparse.ArgumentParser(description='Train DeepLabV3Plus') 23 | parser.add_argument('-i', '--data_root', type=str, default='', help='Two subfolders for training data: imgs and gts') 24 | parser.add_argument('-o', '--ckpt_dir', type=str, default='', help='Checkpoint save directory') 25 | parser.add_argument('-b', '--batch_size', type=int, default=600, help='batch size') 26 | parser.add_argument('--num_workers', type=int, default=30, help='number of workers for dataloader') 27 | parser.add_argument("--max_epochs", type=int, default=500, help="number of epochs") 28 | parser.add_argument('--compile', action='store_true', help='compile model') 29 | args = parser.parse_args() 30 | 31 | model_compile = args.compile 32 | num_epochs = args.max_epochs 33 | resume = None 34 | device = torch.device("cuda:0") 35 | data_root = args.data_root 36 | ckpt_dir = args.ckpt_dir 37 | batch_size = args.batch_size 38 | num_workers = args.num_workers 39 | makedirs(ckpt_dir, exist_ok=True) 40 | 41 | # %% 42 | torch.cuda.empty_cache() 43 | torch.set_float32_matmul_precision('high') 44 | 45 | def show_mask(mask, ax, random_color=False): 46 | if random_color: 47 | color = np.concatenate([np.random.random(3), np.array([0.45])], axis=0) 48 | else: 49 | color = np.array([251/255, 252/255, 30/255, 0.45]) 50 | h, w = mask.shape[-2:] 51 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 52 | ax.imshow(mask_image) 53 | 54 | def show_box(box, ax): 55 | x0, y0 = box[0], box[1] 56 | w, h = box[2] - box[0], box[3] - box[1] 57 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2)) 58 | 59 | 60 | # %% 61 | class NpyDataset(Dataset): 62 | def __init__(self, data_root, image_size=224, bbox_shift=5, data_aug=False): 63 | self.data_root = data_root 64 | self.gt_path = join(data_root, 'gts') 65 | self.img_path = join(data_root, 'imgs') 66 | self.gt_path_files = sorted(glob.glob(join(self.gt_path, '**/*.npy'), recursive=True)) 67 | self.gt_path_files = [file for file in self.gt_path_files if os.path.isfile(join(self.img_path, os.path.basename(file)))] 68 | self.image_size = image_size 69 | self.bbox_shift = bbox_shift 70 | self.data_aug = data_aug 71 | print(f'number of images: {len(self.gt_path_files)}') 72 | 73 | def __len__(self): 74 | return len(self.gt_path_files) 75 | 76 | def __getitem__(self, index): 77 | img_name = os.path.basename(self.gt_path_files[index]) 78 | assert img_name == os.path.basename(self.gt_path_files[index]), 'img gt name error' + self.gt_path_files[index] + self.npy_files[index] 79 | img_3c = np.load(join(self.img_path, img_name), 'r', allow_pickle=True) # (H, W, 3) 80 | resize_img_cv2 = cv2.resize( 81 | img_3c, 82 | (self.image_size, self.image_size), 83 | interpolation=cv2.INTER_AREA 84 | ) 85 | resize_img_cv2_01 = (resize_img_cv2 - 
resize_img_cv2.min()) / np.clip(resize_img_cv2.max() - resize_img_cv2.min(), a_min=1e-8, a_max=None) # normalize to [0, 1], (H, W, 3) 86 | # convert the shape to (3, H, W) 87 | resize_img = np.transpose(resize_img_cv2_01, (2, 0, 1)) 88 | assert np.max(resize_img)<=1.0 and np.min(resize_img)>=0.0, 'image should be normalized to [0, 1]' 89 | gt = np.load(self.gt_path_files[index], 'r', allow_pickle=True) # multiple labels [0, 1,4,5...] 90 | if gt.shape[0] != self.image_size or gt.shape[1] != self.image_size: 91 | gt_resize = cv2.resize( 92 | gt, (self.image_size, self.image_size), 93 | interpolation=cv2.INTER_NEAREST 94 | ) 95 | gt_resize = np.uint8(gt_resize) 96 | else: 97 | gt_resize = gt 98 | label_ids = np.unique(gt_resize)[1:] 99 | label_id = random.choice(label_ids.tolist()) 100 | gt2D = np.uint8(gt_resize == label_id) # only one label 101 | assert np.max(gt2D) == 1 and np.min(gt2D) == 0, 'ground truth should be 0, 1, got: ' + str(np.unique(gt2D)) 102 | # add data augmentation: random fliplr and random flipud 103 | if self.data_aug: 104 | if random.random() > 0.5: 105 | resize_img = np.ascontiguousarray(np.flip(resize_img, axis=-1)) 106 | gt2D = np.ascontiguousarray(np.flip(gt2D, axis=-1)) 107 | # print('DA with flip left right') 108 | if random.random() > 0.5: 109 | resize_img = np.ascontiguousarray(np.flip(resize_img, axis=-2)) 110 | gt2D = np.ascontiguousarray(np.flip(gt2D, axis=-2)) 111 | # print('DA with flip up down') 112 | y_indices, x_indices = np.where(gt2D > 0) 113 | x_min, x_max = np.min(x_indices), np.max(x_indices) 114 | y_min, y_max = np.min(y_indices), np.max(y_indices) 115 | # add perturbation to bounding box coordinates 116 | H, W = gt2D.shape 117 | x_min = max(0, x_min - random.randint(0, self.bbox_shift)) 118 | x_max = min(W, x_max + random.randint(0, self.bbox_shift)) 119 | y_min = max(0, y_min - random.randint(0, self.bbox_shift)) 120 | y_max = min(H, y_max + random.randint(0, self.bbox_shift)) 121 | bboxes = np.array([x_min, y_min, x_max, y_max]) 122 | 123 | ## Append bbox prompt channel 124 | resize_img_bbox = np.concatenate([resize_img, np.zeros((1, self.image_size, self.image_size))], axis=0) 125 | resize_img_bbox[-1, y_min:y_max, x_min:x_max] = 1.0 126 | # print(img_name, resize_img_bbox.shape, gt2D.shape) 127 | return torch.tensor(resize_img_bbox).float(), torch.tensor(gt2D[None, :,:]).long(), torch.tensor(bboxes).float(), img_name 128 | 129 | 130 | # %% 131 | model = smp.DeepLabV3Plus( 132 | encoder_name="resnet50", # encoder model type 133 | encoder_weights="imagenet", # use `imagenet` pretrained weights for encoder initialization 134 | in_channels=4, # Additional channel for bounding box prompt 135 | classes=1, # model output channels (number of classes in your dataset) 136 | activation=None # Output logits 137 | ) 138 | model.to(device) 139 | if model_compile: 140 | print("Compiling model...") 141 | model = torch.compile(model) 142 | 143 | # %% 144 | optimizer = torch.optim.Adam( 145 | model.parameters(), 146 | lr=0.001, 147 | weight_decay=4e-5, 148 | ) 149 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.94) 150 | # %% 151 | train_dataset = NpyDataset(data_root=data_root, data_aug=False, bbox_shift=5, image_size=224) 152 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) 153 | 154 | # %% 155 | # loss function 156 | seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean', to_onehot_y=False) 157 | # %% 158 | # training 159 | if resume is 
not None: 160 | checkpoint = torch.load(resume) 161 | (model._orig_mod if model_compile else model).load_state_dict(checkpoint['model']) # handle both compiled and uncompiled models 162 | optimizer.load_state_dict(checkpoint['optimizer']) 163 | best_loss = checkpoint['best_loss'] 164 | start_epoch = checkpoint['epoch'] + 1 165 | print(f"Resuming training from epoch {start_epoch} with best loss {best_loss:.4f}") 166 | else: 167 | best_loss = 1e10 168 | best_epoch = 0 169 | start_epoch = 0 170 | 171 | for epoch in range(start_epoch, num_epochs): 172 | epoch_loss = [1e10 for _ in range(len(train_dataloader))] 173 | pbar = tqdm(train_dataloader) 174 | for step, (image, gt2D, boxes, img_names) in enumerate(pbar): 175 | optimizer.zero_grad() 176 | boxes_np = boxes.detach().cpu().numpy() 177 | image, gt2D = image.to(device), gt2D.to(device) 178 | pred = model(image) 179 | loss = seg_loss(pred, gt2D) 180 | epoch_loss[step] = loss.item() 181 | loss.backward() 182 | optimizer.step() 183 | optimizer.zero_grad() 184 | pbar.set_description(f"Epoch {epoch} at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, loss: {loss.item():.4f}") 185 | 186 | epoch_loss_reduced = sum(epoch_loss) / len(epoch_loss) 187 | model_weights = model._orig_mod.state_dict() if model_compile else model.state_dict() # unwrap the torch.compile wrapper only when it was used 188 | checkpoint = { 189 | "model": model_weights, 190 | "epoch": epoch, 191 | "optimizer": optimizer.state_dict(), 192 | "loss": epoch_loss_reduced, 193 | "best_loss": best_loss, 194 | } 195 | torch.save(checkpoint, join(ckpt_dir, "deeplabv3plus_latest.pt")) 196 | if epoch_loss_reduced < best_loss: 197 | print(f"New best loss: {best_loss:.4f} -> {epoch_loss_reduced:.4f}") 198 | best_loss = epoch_loss_reduced 199 | torch.save(checkpoint, join(ckpt_dir, "deeplabv3plus_best.pt")) 200 | 201 | epoch_loss_reduced = 1e10 202 | lr_scheduler.step() 203 | -------------------------------------------------------------------------------- /comparisons/SAM/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | The pre-trained [SAM](https://github.com/facebookresearch/segment-anything) model was used as a baseline in this study. SAM has three model types: ViT-H, ViT-L, and ViT-B. We used the ViT-B model since it offers a good trade-off between segmentation accuracy and runtime performance based on their ablation study (Fig. 13). 4 | 5 | 6 | 1. Download the [checkpoint](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth). 7 | 8 | Note: We assume that the preprocessing steps have already been done. 9 | 10 | 2. For 2D images, run 11 | 12 | ```bash 13 | python infer_SAM_2D_npz.py -i input_path -o output_path -m model_path 14 | ``` 15 | 16 | 17 | 3. 
For 3D images, run 18 | 19 | ```bash 20 | python infer_SAM_3D_npz.py -i input_path -o output_path -m model_path 21 | ``` 22 | -------------------------------------------------------------------------------- /comparisons/SAM/infer_SAM_2D_npz.py: -------------------------------------------------------------------------------- 1 | #%% set environment 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from os import makedirs 5 | from os.path import join 6 | from tqdm import tqdm 7 | from skimage import transform 8 | import torch 9 | from segment_anything import sam_model_registry, SamPredictor 10 | import glob 11 | import os 12 | import argparse 13 | import cv2 14 | 15 | def show_mask(mask, ax, random_color=False): 16 | if random_color: 17 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 18 | else: 19 | color = np.array([251/255, 252/255, 30/255, 0.5]) 20 | h, w = mask.shape[-2:] 21 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 22 | ax.imshow(mask_image) 23 | 24 | def show_box(box, ax): 25 | x0, y0 = box[0], box[1] 26 | w, h = box[2] - box[0], box[3] - box[1] 27 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2)) 28 | 29 | 30 | # %% 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('-i', '--img_path', type=str, default="") 33 | parser.add_argument('-o', '--seg_path', type=str, default="") 34 | parser.add_argument('-m', '--model_path', type=str, default="path to sam_vit_b_01ec64") 35 | args = parser.parse_args() 36 | img_path = args.img_path 37 | seg_path = args.seg_path 38 | makedirs(seg_path, exist_ok=True) 39 | 40 | SAM_MODEL_TYPE = "vit_b" 41 | SAM_CKPT_PATH = args.model_path 42 | device = torch.device("cuda:0") 43 | 44 | gt_path_files = sorted(glob.glob(join(img_path, '**/*.npz'), recursive=True)) 45 | print('find {} files'.format(len(gt_path_files))) 46 | image_size = 1024 47 | bbox_shift = 20 48 | 49 | # %% set up model 50 | sam_model = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CKPT_PATH) 51 | sam_model.to(device=device) 52 | predictor = SamPredictor(sam_model) 53 | 54 | #%% predict npz files and save results 55 | for gt_path_file in tqdm(gt_path_files): 56 | npz_name = os.path.basename(gt_path_file) 57 | task_folder = gt_path_file.split('/')[-2] 58 | os.makedirs(join(seg_path, task_folder), exist_ok=True) 59 | if not os.path.isfile(join(seg_path, task_folder, npz_name)): 60 | npz_data = np.load(gt_path_file, 'r', allow_pickle=True) # npz file with 'imgs' and 'gts' 61 | img_3c = npz_data['imgs'] # (H, W) or (H, W, 3) 62 | gt = npz_data['gts'] 63 | segs = np.zeros_like(gt, dtype=np.uint8) 64 | 65 | if len(img_3c.shape) == 2: 66 | img_3c = np.repeat(img_3c[:,:, None], 3, axis=-1) 67 | 68 | resize_img_1024 = cv2.resize(img_3c, (1024, 1024), interpolation=cv2.INTER_CUBIC) 69 | predictor.set_image(resize_img_1024.astype(np.uint8)) # compute the image embedding only once 70 | 71 | gt_1024 = cv2.resize(gt, (1024, 1024), interpolation=cv2.INTER_NEAREST) # (1024, 1024) 72 | label_ids = np.unique(gt)[1:] 73 | for label_id in label_ids: 74 | gt_1024_label_id = np.uint8(gt_1024 == label_id) # binary mask of the current label, (1024, 1024) 75 | y_indices, x_indices = np.where(gt_1024_label_id > 0) 76 | x_min, x_max = np.min(x_indices), np.max(x_indices) 77 | y_min, y_max = np.min(y_indices), np.max(y_indices) 78 | # add perturbation to bounding box coordinates 79 | H, W = gt_1024_label_id.shape 80 | x_min = max(0, x_min - bbox_shift) 81 | x_max = min(W, x_max + bbox_shift) 82 | y_min = max(0, y_min - bbox_shift) 83 | y_max = min(H, y_max + 
bbox_shift) 84 | bboxes1024 = np.array([x_min, y_min, x_max, y_max]) 85 | 86 | sam_mask, _, _ = predictor.predict(point_coords=None, point_labels=None, box=bboxes1024[None, :], multimask_output=False) #1024x1024, bool 87 | sam_mask = transform.resize(sam_mask[0].astype(np.uint8), (gt.shape[-2], gt.shape[-1]), order=0, preserve_range=True, mode='constant', anti_aliasing=False) # back to the original ground-truth resolution 88 | segs[sam_mask > 0] = label_id 89 | 90 | np.savez_compressed(join(seg_path, task_folder, npz_name), segs=segs, gts=gt) 91 | -------------------------------------------------------------------------------- /comparisons/SAM/infer_SAM_3D_npz.py: -------------------------------------------------------------------------------- 1 | #%% set environment 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from os import makedirs 5 | from os.path import join 6 | from tqdm import tqdm 7 | from skimage import transform 8 | import torch 9 | from segment_anything import sam_model_registry, SamPredictor 10 | import glob 11 | import os 12 | import argparse 13 | import cv2 14 | def show_mask(mask, ax, random_color=False): 15 | if random_color: 16 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 17 | else: 18 | color = np.array([251/255, 252/255, 30/255, 0.5]) 19 | h, w = mask.shape[-2:] 20 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 21 | ax.imshow(mask_image) 22 | 23 | def show_box(box, ax): 24 | x0, y0 = box[0], box[1] 25 | w, h = box[2] - box[0], box[3] - box[1] 26 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2)) 27 | 28 | 29 | # %% 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('-i', '--img_path', type=str, default="") 32 | parser.add_argument('-o', '--seg_path', type=str, default="") 33 | parser.add_argument('-m', '--model_path', type=str, default="path to sam_vit_b_01ec64") 34 | args = parser.parse_args() 35 | img_path = args.img_path 36 | seg_path = args.seg_path 37 | makedirs(seg_path, exist_ok=True) 38 | 39 | SAM_MODEL_TYPE = "vit_b" 40 | SAM_CKPT_PATH = args.model_path 41 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 42 | device = torch.device("cuda:0") 43 | 44 | gt_path_files = sorted(glob.glob(join(img_path, '**/*.npz'), recursive=True)) 45 | print('find {} files'.format(len(gt_path_files))) 46 | image_size = 1024 47 | bbox_shift = 20 48 | 49 | # %% set up model 50 | sam_model = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CKPT_PATH) 51 | sam_model.to(device=device) 52 | predictor = SamPredictor(sam_model) 53 | 54 | #%% predict npz files and save results 55 | for gt_path_file in tqdm(gt_path_files): 56 | npz_name = os.path.basename(gt_path_file) 57 | task_folder = gt_path_file.split('/')[-2] 58 | os.makedirs(join(seg_path, task_folder), exist_ok=True) 59 | npz_data = np.load(gt_path_file, 'r', allow_pickle=True) # npz file with 'imgs', 'gts' and 'spacing' 60 | img_3D = npz_data['imgs'] # (Num, H, W) 61 | gt_3D = npz_data['gts'] # (Num, 256, 256) 62 | spacing = npz_data['spacing'] 63 | seg_3D = np.zeros_like(gt_3D, dtype=np.uint8) # (Num, 256, 256) 64 | 65 | for i in range(img_3D.shape[0]): 66 | img_2d = img_3D[i,:,:] # (H, W) 67 | img_3c = np.repeat(img_2d[:,:, None], 3, axis=-1) 68 | 69 | resize_img_1024 = cv2.resize(img_3c, (1024, 1024), interpolation=cv2.INTER_CUBIC) 70 | predictor.set_image(resize_img_1024.astype(np.uint8)) # compute the image embedding only once per slice 71 | 72 | gt = gt_3D[i,:,:] # (H, W) 73 | gt_1024 = cv2.resize(gt, (1024, 1024), interpolation=cv2.INTER_NEAREST) # (1024, 1024) 74 | 
label_ids = np.unique(gt)[1:] 75 | for label_id in label_ids: 76 | gt_1024_label_id = np.uint8(gt_1024 == label_id) # only one label, (256, 256) 77 | y_indices, x_indices = np.where(gt_1024_label_id > 0) 78 | x_min, x_max = np.min(x_indices), np.max(x_indices) 79 | y_min, y_max = np.min(y_indices), np.max(y_indices) 80 | # add perturbation to bounding box coordinates 81 | H, W = gt_1024_label_id.shape 82 | x_min = max(0, x_min - bbox_shift) 83 | x_max = min(W, x_max + bbox_shift) 84 | y_min = max(0, y_min - bbox_shift) 85 | y_max = min(H, y_max + bbox_shift) 86 | bboxes1024 = np.array([x_min, y_min, x_max, y_max]) 87 | 88 | sam_mask, _, _ = predictor.predict(point_coords=None, point_labels=None, box=bboxes1024[None, :], multimask_output=False) #1024x1024, bool 89 | sam_mask = transform.resize(sam_mask[0].astype(np.uint8), (gt.shape[-2], gt.shape[-1]), order=0, preserve_range=True, mode='constant', anti_aliasing=False) # (256, 256) 90 | seg_3D[i, sam_mask>0] = label_id 91 | np.savez_compressed(join(seg_path, task_folder, npz_name), segs=seg_3D, gts=gt_3D, spacing=spacing) # save spacing for metric computation 92 | -------------------------------------------------------------------------------- /comparisons/nnU-Net/README.md: -------------------------------------------------------------------------------- 1 | # nnU-Net 2 | 3 | This folder contains the scripts for training and inference of the nnUNet model on medical image data in MedSAM's preprocessed `npz` format. For details regarding the data preprocessing pipeline, please refer to the [MedSAM](https://github.com/bowang-lab/MedSAM/tree/main/utils). 4 | 5 | ## Prerequisites 6 | 7 | This codebase uses the [nnUNetv2](https://github.com/MIC-DKFZ/nnUNet/tree/master). One can choose to install the out-of-the-box version via pip as follows: 8 | 9 | ```sh 10 | pip install nnunetv2 11 | ``` 12 | 13 | For further details regarding configuring the nnUNetv2, please consult nnUNet's official [documentation](https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/installation_instructions.md). 14 | 15 | 16 | ## Training 17 | We converted the training npz/npy files to nii format which is widely used by nnU-Net. In order to incorporate the bounding box prompts into the model, we converted the bounding box as a binary mask and concatenated it with the image as the model input. The bounding box was simulated based on ground truth. 18 | 19 | 20 | ```bash 21 | nnUNetv2_train xxx 2d all 22 | ``` 23 | 24 | 25 | ## Inference 26 | 27 | The inference scripts assume that the data is in the `npz` format generated by MedSAM preprocess pipeline. To run inference, one can download the model [here](https://drive.google.com/drive/folders/1xYUgdjIsmBkobiBKXNb1uyqN-kGHW2p_?usp=sharing) and use the provided inference scripts. 28 | 29 | ### Inference for 2D images 30 | 31 | The `infer_nnunet_2D.py` script can be used for inference on 2D images. Below are the parameters that need to be configured: 32 | 33 | * `-checkpoint`: Path to the trained model checkpoint. 34 | * `-data_root`: Path to the test data. 35 | * `--grey`: Whether the input dataset is greyscale. 36 | * `-pred_save_dir`: Path to save the output segmented images. 37 | * `--save_overlay`: Save the overlay of the segmentation on the original image. (Optional) 38 | * `-png_save_dir`: Path to save the overlay images. (Required if `--save_overlay` is used) 39 | * `-num_workers`: Number of workers for multiprocessing during inference. 
(Optional) 40 | 41 | Note that for 2D images, the preprocessing step prior to inference is different for the RGB and greyscale. Hence it is necessary to specify the `--grey` flag when running inference on greyscale images. 42 | 43 | For RGB images: 44 | ```sh 45 | python infer_nnunet_2D.py \ 46 | -checkpoint nnUNet_results/Dataset001_Fundus \ 47 | -data_root path/to/test/data \ 48 | -pred_save_dir path/to/save/results \ 49 | --save_overlay \ 50 | -png_save_dir path/to/save/overlay \ 51 | -num_workers 2 \ 52 | ``` 53 | 54 | For greyscale images: 55 | ```sh 56 | python infer_nnunet_2D.py \ 57 | -checkpoint nnUNet_results/Dataset002_X-Ray \ 58 | -data_root path/to/test/data \ 59 | -pred_save_dir path/to/save/results \ 60 | --save_overlay \ 61 | -png_save_dir path/to/save/overlay \ 62 | -num_workers 2 \ 63 | --grey 64 | ``` 65 | 66 | ### Inference for 3D images 67 | 68 | The `infer_nnunet_3D.py` script can be used for inference on 3D images. Below are the parameters that need to be configured: 69 | 70 | * `-checkpoint`: Path to the trained model checkpoint. 71 | * `-data_root`: Path to the test data. 72 | * `-pred_save_dir`: Path to save the output segmented 3D images. 73 | * `--save_overlay`: Save the overlay of the segmentation on the original image. (Optional) 74 | * `-png_save_dir`: Path to save the overlay images. (Required if --save_overlay is used) 75 | * `-num_workers`: Number of workers for multiprocessing during inference. (Optional) 76 | 77 | Example command for 3D images inference: 78 | ```sh 79 | python infer_nnunet_3D.py \ 80 | -checkpoint nnUNet_results/Dataset003_CT \ 81 | -data_root path/to/test/data \ 82 | -pred_save_dir path/to/save/results \ 83 | --save_overlay \ 84 | -png_save_dir path/to/save/overlay \ 85 | -num_workers 2 \ 86 | ``` 87 | 88 | ## Acknowledgement 89 | We would like to thank the authors and the contributors of [nnUNet](https://github.com/MIC-DKFZ/nnUNet/tree/master) for their great work and for making the code publicly available. 
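Both inference scripts write one compressed `npz` file per case containing the predicted masks (`segs`) and the corresponding ground truth (`gts`; the 3D script additionally stores the voxel `spacing`). A minimal sketch for sanity-checking a saved case, where the file path is only a placeholder:

```python
import numpy as np

def dice_coefficient(pred, target, smooth=1.0):
    # same smoothed Dice formulation used inside the inference scripts
    intersection = (pred * target).sum()
    return (2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth)

# placeholder path: any npz produced by infer_nnunet_2D.py or infer_nnunet_3D.py
npz = np.load("path/to/save/results/case_0001.npz")
segs, gts = npz["segs"], npz["gts"]
for label_id in np.unique(gts)[1:]:  # skip background (0)
    dsc = dice_coefficient((segs == label_id).astype(np.uint8),
                           (gts == label_id).astype(np.uint8))
    print(f"label {int(label_id)}: DSC = {dsc:.4f}")
```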
90 | -------------------------------------------------------------------------------- /comparisons/nnU-Net/infer_nnunet_2D.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import glob 4 | import random 5 | from os import listdir, makedirs 6 | from os.path import join, exists, isfile, isdir, basename 7 | from tqdm import tqdm, trange 8 | from copy import deepcopy 9 | from time import time 10 | import numpy as np 11 | import torch 12 | from torch._dynamo import OptimizedModule 13 | from torch import multiprocessing as mp 14 | 15 | import cv2 16 | from skimage import morphology, color 17 | import torch.nn.functional as F 18 | 19 | from matplotlib import pyplot as plt 20 | 21 | from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor 22 | 23 | import argparse 24 | # %% 25 | torch.cuda.empty_cache() 26 | os.environ['PYTHONHASHSEED']=str(2023) 27 | random.seed(2023) 28 | np.random.seed(2023) 29 | torch.manual_seed(2023) 30 | torch.cuda.manual_seed(2023) 31 | 32 | # %% 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument( 35 | '-checkpoint', 36 | type=str, 37 | default='', 38 | help='Path to the model checkpoint directory in nnUNet_results', 39 | required=True 40 | ) 41 | parser.add_argument( 42 | '-data_root', 43 | type=str, 44 | default='', 45 | help='Path to the validation data directory', 46 | required=True 47 | ) 48 | parser.add_argument( 49 | '-pred_save_dir', 50 | type=str, 51 | default='', 52 | help='Path to the directory where the segmentation results will be saved in npz format' 53 | ) 54 | parser.add_argument('--save_overlay', action='store_true', default=False, help="Whether to save segmentation overlay") 55 | parser.add_argument( 56 | '-png_save_dir', 57 | type=str, 58 | default='', 59 | help='Path to the directory where the segmentation overlay will be saved in png format' 60 | ) 61 | parser.add_argument( 62 | '-bbox_shift', 63 | type=int, 64 | default=10, 65 | help='Perturbation shift of bounding box prompt' 66 | ) 67 | parser.add_argument( 68 | '--grey', 69 | action='store_true', 70 | default=False, 71 | help="Whether the input image is in grey scale" 72 | ) 73 | parser.add_argument( 74 | '-num_workers', type=int, default=1, 75 | help='number of workers for multiprocessing' 76 | ) 77 | 78 | # %% 79 | args = parser.parse_args() 80 | checkpoint = args.checkpoint 81 | data_root = args.data_root 82 | pred_save_dir = args.pred_save_dir 83 | png_save_dir = args.png_save_dir 84 | makedirs(pred_save_dir, exist_ok=True) 85 | save_overlay = args.save_overlay 86 | if save_overlay: 87 | makedirs(png_save_dir, exist_ok=True) 88 | num_workers = args.num_workers 89 | data_root_files = listdir(data_root) 90 | 91 | ## Check whether there exist subfolders 92 | has_task = isdir(join(data_root, data_root_files[0])) 93 | if has_task: 94 | gt_path_files = sorted(glob.glob(join(data_root, '**/*.npz'), recursive=True)) 95 | else: 96 | gt_path_files = sorted(glob.glob(join(data_root, '*.npz'), recursive=True)) 97 | bbox_shift = args.bbox_shift 98 | props = {'spacing': (999, 1, 1)} 99 | is_grey = args.grey 100 | 101 | # %% 102 | def show_mask(mask, ax, random_color=True, alpha=0.45): 103 | if random_color: 104 | color = np.concatenate([np.random.random(3), np.array([alpha])], axis=0) 105 | else: 106 | color = np.array([251/255, 252/255, 30/255, alpha]) 107 | h, w = mask.shape[-2:] 108 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 109 | ax.imshow(mask_image) 110 | 111 | def show_box(box, ax): 112 | x0, y0 
= box[0], box[1] 113 | w, h = box[2] - box[0], box[3] - box[1] 114 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2)) 115 | 116 | def dice_coefficient(preds, targets): 117 | smooth = 1.0 118 | assert preds.shape == targets.shape 119 | 120 | intersection = (preds * targets).sum() 121 | dice = (2.0 * intersection + smooth) / (preds.sum() + targets.sum() + smooth) 122 | return dice 123 | 124 | # %% 125 | predictor = nnUNetPredictor( 126 | tile_step_size=0.5, 127 | use_gaussian=True, 128 | use_mirroring=False, ## disable tta 129 | perform_everything_on_gpu=True, 130 | device=torch.device('cuda', 0), 131 | verbose=False, 132 | verbose_preprocessing=False, 133 | allow_tqdm=False 134 | ) 135 | predictor.initialize_from_trained_model_folder( 136 | join(checkpoint, 'nnUNetTrainer__nnUNetPlans__2d'), 137 | use_folds='all', 138 | checkpoint_name='checkpoint_final.pth', 139 | ) 140 | 141 | # %% 142 | def preprocess_image_rgb(img_3c, gt_2D, bbox_shift=10): 143 | """ 144 | Append bounding box prompt channel to image 145 | """ 146 | img_3c = np.transpose(img_3c, (2, 0, 1)) ## (3, H, W) 147 | if gt_2D.shape[0] != img_3c.shape[1] or gt_2D.shape[1] != img_3c.shape[2]: 148 | gt_2D = cv2.resize( 149 | gt_2D, (img_3c.shape[2], img_3c.shape[1]), 150 | interpolation=cv2.INTER_NEAREST 151 | ) 152 | gt_2D = np.uint8(gt_2D) 153 | else: 154 | gt_2D = gt_2D.astype(np.uint8) 155 | try: 156 | assert np.max(gt_2D) == 1 and np.min(gt_2D) == 0, 'ground truth should be 0, 1, got: ' + str(np.unique(gt_2D)) 157 | except: 158 | assert np.max(gt_2D) == 0 and np.min(gt_2D) == 0, 'ground truth should be 0, 1, got: ' + str(np.unique(gt_2D)) 159 | return None 160 | 161 | y_indices, x_indices = np.where(gt_2D > 0) 162 | x_min, x_max = np.min(x_indices), np.max(x_indices) 163 | y_min, y_max = np.min(y_indices), np.max(y_indices) 164 | # add perturbation to bounding box coordinates 165 | H, W = gt_2D.shape 166 | x_min = max(0, x_min - bbox_shift) 167 | x_max = min(W, x_max + bbox_shift) 168 | y_min = max(0, y_min - bbox_shift) 169 | y_max = min(H, y_max + bbox_shift) 170 | 171 | ## Append bbox prompt channel as the last channel 172 | img_4c = np.concatenate([img_3c, np.zeros((1, img_3c.shape[1], img_3c.shape[2]))], axis=0) 173 | img_4c[-1, y_min:y_max, x_min:x_max] = 1.0 174 | img_4c = img_4c[:, None, ...] ## (4, 1, H, W) 175 | 176 | return torch.tensor(img_4c).float() 177 | 178 | def preprocess_image_grey(img_3c, gt_2D, bbox_shift=10): 179 | """ 180 | Append bounding box prompt channel to image 181 | """ 182 | if len(img_3c.shape) == 3: 183 | img_1c = np.uint8(color.rgb2gray(img_3c)*255.0) 184 | else: 185 | img_1c = img_3c 186 | img_1c = img_1c[None, ...] 
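# Greyscale cases are reduced to a single intensity channel here, while
# preprocess_image_rgb keeps all three channels, so the grey and RGB checkpoints
# expect different input channel counts (2 vs. 4 once the box-prompt channel is
# appended); this is why the `--grey` flag has to match the dataset at inference time.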
## (1, H, W) 187 | if gt_2D.shape[0] != img_1c.shape[1] or gt_2D.shape[1] != img_1c.shape[2]: 188 | gt_2D = cv2.resize( 189 | gt_2D, (img_1c.shape[2], img_1c.shape[1]), 190 | interpolation=cv2.INTER_NEAREST 191 | ) 192 | gt_2D = np.uint8(gt_2D) 193 | else: 194 | gt_2D = gt_2D.astype(np.uint8) 195 | try: 196 | assert np.max(gt_2D) == 1 and np.min(gt_2D) == 0, 'ground truth should be 0, 1, got: ' + str(np.unique(gt_2D)) 197 | except: 198 | assert np.max(gt_2D) == 0 and np.min(gt_2D) == 0, 'ground truth should be 0, 1, got: ' + str(np.unique(gt_2D)) 199 | return None 200 | 201 | y_indices, x_indices = np.where(gt_2D > 0) 202 | x_min, x_max = np.min(x_indices), np.max(x_indices) 203 | y_min, y_max = np.min(y_indices), np.max(y_indices) 204 | 205 | H, W = gt_2D.shape 206 | x_min = max(0, x_min - bbox_shift) 207 | x_max = min(W, x_max + bbox_shift) 208 | y_min = max(0, y_min - bbox_shift) 209 | y_max = min(H, y_max + bbox_shift) 210 | 211 | ## Append bbox prompt channel as the last channel 212 | img_2c = np.concatenate([img_1c, np.zeros((1, img_1c.shape[1], img_1c.shape[2]))], axis=0) 213 | img_2c[-1, y_min:y_max, x_min:x_max] = 1.0 214 | img_2c = img_2c[:, None, ...] ## (2, 1, H, W) 215 | 216 | return torch.tensor(img_2c).float() 217 | # %% 218 | def nnunet_infer_npz(gt_path_file): 219 | npz_name = basename(gt_path_file) 220 | if has_task: 221 | task_folder = gt_path_file.split('/')[-2] 222 | pred_save_dir_task = join(pred_save_dir, task_folder) 223 | png_save_dir_task = join(png_save_dir, task_folder) 224 | makedirs(pred_save_dir_task, exist_ok=True) 225 | makedirs(png_save_dir_task, exist_ok=True) 226 | else: 227 | pred_save_dir_task = pred_save_dir 228 | png_save_dir_task = png_save_dir 229 | if isfile(join(pred_save_dir_task, npz_name)): 230 | return 231 | npz = np.load(gt_path_file, 'r', allow_pickle=True) 232 | imgs = npz['imgs'] # (H, W, C) 233 | gts = npz['gts'] # (H, W) 234 | segs = np.zeros_like(imgs[..., 0], dtype=np.uint8) 235 | 236 | label_ids = np.unique(gts)[1:] 237 | for label_id in label_ids: 238 | gt_2D = np.uint8(gts == label_id) # one label at a time 239 | if is_grey: 240 | img_bbox = preprocess_image_grey( 241 | imgs, 242 | gt_2D, 243 | bbox_shift=bbox_shift 244 | ) 245 | else: 246 | img_bbox = preprocess_image_rgb( 247 | imgs, 248 | gt_2D, 249 | bbox_shift=bbox_shift 250 | ) 251 | if img_bbox == None: 252 | continue ## No label available for the image 253 | seg_2D = predictor.predict_single_npy_array( 254 | input_image = img_bbox, 255 | image_properties = props, 256 | segmentation_previous_stage = None, 257 | output_file_truncated = None, 258 | save_or_return_probabilities = False 259 | ) 260 | seg_2D = seg_2D.squeeze() 261 | seg_2D = cv2.resize( 262 | seg_2D, (imgs.shape[1], imgs.shape[0]), 263 | interpolation=cv2.INTER_NEAREST 264 | ).astype(np.uint8) 265 | 266 | segs[seg_2D > 0] = label_id 267 | 268 | if gts.shape[0] != imgs.shape[0] or gts.shape[1] != imgs.shape[1]: 269 | gts = cv2.resize( 270 | gts, 271 | (imgs.shape[1], imgs.shape[0]), 272 | interpolation=cv2.INTER_NEAREST 273 | ) 274 | 275 | np.savez_compressed( 276 | join(pred_save_dir_task, npz_name), 277 | segs=segs, 278 | gts=gts 279 | ) 280 | 281 | if save_overlay: 282 | fig, ax = plt.subplots(1, 3, figsize=(15, 5)) 283 | if args.grey: 284 | ax[0].imshow(imgs, cmap='gray') 285 | else: 286 | ax[0].imshow(imgs) 287 | ax[0].set_title("Image") 288 | ax[0].axis('off') 289 | if args.grey: 290 | ax[1].imshow(imgs, cmap='gray') 291 | else: 292 | ax[1].imshow(imgs) 293 | show_mask(gts, ax[1]) 294 | ax[1].axis('off') 295 
| ax[1].set_title("Ground Truth") 296 | if args.grey: 297 | ax[2].imshow(imgs, cmap='gray') 298 | else: 299 | ax[2].imshow(imgs) 300 | show_mask(segs, ax[2]) 301 | ax[2].set_title("Segmentation") 302 | ax[2].axis('off') 303 | plt.savefig( 304 | join(png_save_dir_task, npz_name.split(".")[0] + '.png'), 305 | dpi=300 306 | ) 307 | plt.close() 308 | # %% 309 | if __name__ == '__main__': 310 | num_workers = num_workers 311 | mp.set_start_method('spawn') 312 | with mp.Pool(processes=num_workers) as pool: 313 | with tqdm(total=len(gt_path_files)) as pbar: 314 | for i, _ in tqdm(enumerate(pool.imap_unordered(nnunet_infer_npz, gt_path_files))): 315 | pbar.update() -------------------------------------------------------------------------------- /comparisons/nnU-Net/infer_nnunet_3D.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import glob 4 | import random 5 | from os import listdir, makedirs 6 | from os.path import join, exists, isfile, isdir, basename 7 | from tqdm import tqdm, trange 8 | from copy import deepcopy 9 | from time import time 10 | import numpy as np 11 | import torch 12 | from torch._dynamo import OptimizedModule 13 | from torch import multiprocessing as mp 14 | from datetime import datetime 15 | 16 | import cv2 17 | from skimage import morphology 18 | import torch.nn.functional as F 19 | 20 | from matplotlib import pyplot as plt 21 | 22 | from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor 23 | 24 | import argparse 25 | 26 | import timeit 27 | 28 | torch.cuda.empty_cache() 29 | os.environ['PYTHONHASHSEED']=str(2023) 30 | random.seed(2023) 31 | np.random.seed(2023) 32 | torch.manual_seed(2023) 33 | torch.cuda.manual_seed(2023) 34 | 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument( 37 | '-checkpoint', 38 | type=str, 39 | default='', 40 | help='Path to the model checkpoint directory in nnUNet_results', 41 | required=True 42 | ) 43 | parser.add_argument( 44 | '-data_root', 45 | type=str, 46 | default='', 47 | help='Path to the validation data directory', 48 | required=True 49 | ) 50 | parser.add_argument( 51 | '-pred_save_dir', 52 | type=str, 53 | default='', 54 | help='Path to the directory where the segmentation results will be saved in npz format' 55 | ) 56 | parser.add_argument('--save_overlay', action='store_true', default=False, help="Whether to save segmentation overlay") 57 | parser.add_argument( 58 | '-png_save_dir', 59 | type=str, 60 | default='', 61 | help='Path to the directory where the segmentation overlay will be saved in png format' 62 | ) 63 | parser.add_argument( 64 | '-bbox_shift', 65 | type=int, 66 | default=10, 67 | help='Perturbation shift of bounding box prompt' 68 | ) 69 | parser.add_argument( 70 | '-num_workers', type=int, default=1, 71 | help='number of workers for multiprocessing' 72 | ) 73 | 74 | args = parser.parse_args() 75 | checkpoint = args.checkpoint 76 | data_root = args.data_root 77 | pred_save_dir = args.pred_save_dir 78 | png_save_dir = args.png_save_dir 79 | makedirs(pred_save_dir, exist_ok=True) 80 | save_overlay = args.save_overlay 81 | if save_overlay: 82 | makedirs(png_save_dir, exist_ok=True) 83 | num_workers = args.num_workers 84 | data_root_files = listdir(data_root) 85 | ## Check if there are subfolders 86 | has_task = isdir(join(data_root, data_root_files[0])) 87 | if has_task: 88 | gt_path_files = sorted(glob.glob(join(data_root, '**/*.npz'), recursive=True)) 89 | else: 90 | gt_path_files = sorted(glob.glob(join(data_root, '*.npz'), 
recursive=True)) 91 | bbox_shift = args.bbox_shift 92 | props = {'spacing': (999, 1, 1)} 93 | # %% 94 | def show_mask(mask, ax, random_color=False): 95 | if random_color: 96 | color = np.concatenate([np.random.random(3), np.array([0.45])], axis=0) 97 | else: 98 | color = np.array([251/255, 252/255, 30/255, 0.45]) 99 | h, w = mask.shape[-2:] 100 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 101 | ax.imshow(mask_image) 102 | 103 | def show_box(box, ax): 104 | x0, y0 = box[0], box[1] 105 | w, h = box[2] - box[0], box[3] - box[1] 106 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2)) 107 | 108 | def dice_coefficient(preds, targets): 109 | smooth = 1.0 110 | assert preds.shape == targets.shape 111 | 112 | intersection = (preds * targets).sum() 113 | dice = (2.0 * intersection + smooth) / (preds.sum() + targets.sum() + smooth) 114 | return dice 115 | 116 | predictor = nnUNetPredictor( 117 | tile_step_size=0.5, 118 | use_gaussian=True, 119 | use_mirroring=False, ## disable tta 120 | perform_everything_on_gpu=True, 121 | device=torch.device('cuda', 0), 122 | verbose=False, 123 | verbose_preprocessing=False, 124 | allow_tqdm=False 125 | ) 126 | predictor.initialize_from_trained_model_folder( 127 | join(checkpoint, 'nnUNetTrainer__nnUNetPlans__2d'), 128 | use_folds='all', 129 | checkpoint_name='checkpoint_final.pth', 130 | ) 131 | 132 | # %% 133 | def preprocess_slice(img_2D, gt_2D, bbox_shift=10): 134 | """ 135 | Append bounding box prompt channel to image 136 | """ 137 | img_1c = img_2D[None, ...] ## (1, H, W) 138 | if gt_2D.shape[0] != img_1c.shape[1] or gt_2D.shape[1] != img_1c.shape[2]: 139 | gt_2D = cv2.resize( 140 | gt_2D, (img_1c.shape[2], img_1c.shape[1]), 141 | interpolation=cv2.INTER_NEAREST 142 | ) 143 | gt_2D = np.uint8(gt_2D) 144 | else: 145 | gt_2D = gt_2D.astype(np.uint8) 146 | try: 147 | assert np.max(gt_2D) == 1 and np.min(gt_2D) == 0, 'ground truth should be 0, 1, got: ' + str(np.unique(gt_2D)) 148 | except: 149 | assert np.max(gt_2D) == 0 and np.min(gt_2D) == 0, 'ground truth should be 0, 1, got: ' + str(np.unique(gt_2D)) 150 | return None 151 | 152 | y_indices, x_indices = np.where(gt_2D > 0) 153 | x_min, x_max = np.min(x_indices), np.max(x_indices) 154 | y_min, y_max = np.min(y_indices), np.max(y_indices) 155 | # add perturbation to bounding box coordinates 156 | H, W = gt_2D.shape 157 | x_min = max(0, x_min - bbox_shift) 158 | x_max = min(W, x_max + bbox_shift) 159 | y_min = max(0, y_min - bbox_shift) 160 | y_max = min(H, y_max + bbox_shift) 161 | 162 | ## Append bbox prompt channel as the last channel 163 | img_2c = np.concatenate([img_1c, np.zeros((1, img_1c.shape[1], img_1c.shape[2]))], axis=0) 164 | img_2c[-1, y_min:y_max, x_min:x_max] = 1.0 165 | img_2c = img_2c[:, None, ...] 
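# The box prompt is passed to the network as an extra binary channel (1 inside the
# perturbed box, 0 elsewhere) appended after the image channel, giving the 2-channel
# input that the bounding-box-conditioned nnU-Net was trained on.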
## (2, 1, H, W) 166 | 167 | return torch.tensor(img_2c).float() 168 | 169 | # %% 170 | def nnunet_infer_npz(gt_path_file): 171 | npz_name = basename(gt_path_file) 172 | if has_task: 173 | task_folder = gt_path_file.split('/')[-2] 174 | pred_save_dir_task = join(pred_save_dir, task_folder) 175 | png_save_dir_task = join(png_save_dir, task_folder) 176 | makedirs(pred_save_dir_task, exist_ok=True) 177 | makedirs(png_save_dir_task, exist_ok=True) 178 | else: 179 | pred_save_dir_task = pred_save_dir 180 | png_save_dir_task = png_save_dir 181 | if isfile(join(pred_save_dir_task, npz_name)): 182 | return 183 | 184 | npz = np.load(gt_path_file, 'r', allow_pickle=True) 185 | img_3D = npz['imgs'] # (Num, H, W) 186 | gt_3D = npz['gts'] # (Num, H, W) 187 | spacing = npz['spacing'] 188 | seg_3D = np.zeros_like(gt_3D, dtype=np.uint8) 189 | 190 | for i in range(img_3D.shape[0]): 191 | img_2D = img_3D[i,:,:] # (H, W) 192 | gt = gt_3D[i,:,:] # (H, W) 193 | label_ids = np.unique(gt)[1:] 194 | for label_id in label_ids: 195 | gt2D = np.uint8(gt == label_id) # one label at a time 196 | img_2c = preprocess_slice( 197 | img_2D, 198 | gt2D, 199 | bbox_shift=bbox_shift 200 | ) 201 | if img_2c is None: 202 | continue ## no label available for this slice 203 | seg_2D = predictor.predict_single_npy_array( 204 | input_image = img_2c, 205 | image_properties = props, 206 | segmentation_previous_stage = None, 207 | output_file_truncated = None, 208 | save_or_return_probabilities = False 209 | ) 210 | seg_2D = seg_2D.squeeze() 211 | seg_2D = cv2.resize( 212 | seg_2D, (gt2D.shape[1], gt2D.shape[0]), 213 | interpolation=cv2.INTER_NEAREST 214 | ).astype(np.uint8) 215 | 216 | seg_3D[i, seg_2D>0] = label_id 217 | 218 | np.savez_compressed( 219 | join(pred_save_dir_task, npz_name), 220 | segs=seg_3D, 221 | gts=gt_3D, 222 | spacing=spacing 223 | ) 224 | 225 | # visualize image, mask and bounding box 226 | 227 | if save_overlay: 228 | idx = int(seg_3D.shape[0] / 2) 229 | fig, ax = plt.subplots(1, 3, figsize=(15, 5)) 230 | ax[0].imshow(img_3D[idx], cmap='gray') 231 | ax[0].set_title("Image") 232 | ax[0].axis('off') 233 | ax[1].imshow(img_3D[idx], cmap='gray') 234 | show_mask(gt_3D[idx], ax[1]) 235 | ax[1].axis('off') 236 | ax[1].set_title("Ground Truth") 237 | ax[2].imshow(img_3D[idx], cmap='gray') 238 | show_mask(seg_3D[idx], ax[2]) 239 | ax[2].set_title("Segmentation") 240 | ax[2].axis('off') 241 | plt.savefig( 242 | join(png_save_dir_task, npz_name.split(".")[0] + '.png'), 243 | dpi=300 244 | ) 245 | plt.close() 246 | 247 | print(f"Case {npz_name} at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") 248 | 249 | 250 | if __name__ == "__main__": 251 | num_workers = num_workers 252 | mp.set_start_method('spawn') 253 | with mp.Pool(processes=num_workers) as pool: 254 | with tqdm(total=len(gt_path_files)) as pbar: 255 | pbar.set_description(f"[{basename(data_root)}]: ") 256 | for i, _ in tqdm(enumerate(pool.imap_unordered(nnunet_infer_npz, gt_path_files))): 257 | pbar.update() 258 | -------------------------------------------------------------------------------- /extensions/point_prompt/README.md: -------------------------------------------------------------------------------- 1 | # MedSAM with point prompts 2 | 3 | ## Inference 4 | 5 | Please try this out-of-the-box demo: [(colab)](https://colab.research.google.com/drive/1cCBw_IhdPiWE4sN7QwqKJPgAFlWsKgkm?usp=sharing) 6 | 7 | ## Training 8 | 9 | This training script shows how to train MedSAM with point prompts on the [MICCAI FLARE 2022](https://flare22.grand-challenge.org/) dataset, 
and assume that the dataset has been preprocessed into the format used by MedSAM as described [here](https://github.com/bowang-lab/MedSAM#data-preprocessing). 10 | 11 | The training script `train_point_prompt.py` takes the following arguments: 12 | * `-i`, `--tr_npy_path`: Path to the preprocessed npy data in MedSAM's format 13 | * `-medsam_checkpoint`: Path to the MedSAM checkpoint 14 | * `-work_dir`: Path to the directory where the model checkpoints will be saved 15 | * `-resume`: Path to the checkpoint to resume training from 16 | * `-batch_size`: Batch size 17 | 18 | For example, assume that the preprocessed data is stored in directory `npy_data`, the MedSAM model is stored in `MedSAM/work_dir/MedSAM/medsam_vit_b.pth`, and the model checkpoints should be saved in `train_point_prompt`. To train the model with a batch size of 16, run the following command: 19 | ``` 20 | python train_point_prompt.py \ 21 | -i npy_data \ 22 | -medsam_checkpoint MedSAM/work_dir/MedSAM/medsam_vit_b.pth \ 23 | -work_dir ./train_point_prompt 24 | ``` 25 | 26 | To resume an interrupted training, simply add the `-resume` argument: 27 | ``` 28 | python train_point_prompt.py \ 29 | -i npy_data \ 30 | -medsam_checkpoint MedSAM/work_dir/MedSAM/medsam_vit_b.pth \ 31 | -work_dir ./train_point_prompt \ 32 | -resume ./train_point_prompt/medsam_point_prompt_latest.pth 33 | ``` 34 | -------------------------------------------------------------------------------- /extensions/point_prompt/point_seg_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM/d71e8a1a99ad751840a22a7fa3ecfb4166fb1488/extensions/point_prompt/point_seg_demo.gif -------------------------------------------------------------------------------- /extensions/point_prompt/train_point_prompt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import monai 5 | from os import makedirs 6 | from os.path import join 7 | from tqdm import tqdm 8 | from time import time 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | from torch.utils.data import Dataset, DataLoader 14 | from datetime import datetime 15 | from segment_anything import sam_model_registry 16 | import cv2 17 | from matplotlib import pyplot as plt 18 | import argparse 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | '-i', 23 | '--tr_npy_path', 24 | type=str, 25 | help="Path to the data root directory.", 26 | required=True 27 | ) 28 | parser.add_argument( 29 | '-medsam_checkpoint', 30 | type=str, 31 | help="Path to the MedSAM checkpoint.", 32 | required=True 33 | ) 34 | parser.add_argument( 35 | '-work_dir', 36 | type=str, 37 | default="finetune_point_prompt", 38 | help="Path to where the checkpoints and logs are saved." 39 | ) 40 | parser.add_argument( 41 | '-max_epochs', 42 | type=int, 43 | default=1000, 44 | help="Maximum number of epochs." 45 | ) 46 | parser.add_argument( 47 | '-batch_size', 48 | type=int, 49 | default=16, 50 | help="Batch size." 51 | ) 52 | parser.add_argument( 53 | '-num_workers', 54 | type=int, 55 | default=8, 56 | help="Number of data loader workers." 57 | ) 58 | parser.add_argument( 59 | '-resume', 60 | type=str, 61 | default=None, 62 | help="Path to the checkpoint to resume from." 
63 | ) 64 | parser.add_argument( 65 | '-lr', 66 | type=float, 67 | default=0.00005, 68 | help="learning rate (absolute lr)" 69 | ) 70 | parser.add_argument( 71 | '-weight_decay', 72 | type=float, 73 | default=0.01, 74 | help="Weight decay." 75 | ) 76 | parser.add_argument( 77 | '-seed', 78 | type=int, 79 | default=2023, 80 | help="Random seed for reproducibility." 81 | ) 82 | parser.add_argument( 83 | '--disable_aug', 84 | action='store_true', 85 | help="Disable data augmentation." 86 | ) 87 | args = parser.parse_args() 88 | 89 | data_root = args.tr_npy_path 90 | work_dir = args.work_dir 91 | num_epochs = args.max_epochs 92 | batch_size = args.batch_size 93 | num_workers = args.num_workers 94 | medsam_checkpoint = args.medsam_checkpoint 95 | data_aug = not args.disable_aug 96 | seed = args.seed 97 | device = "cuda:0" 98 | makedirs(work_dir, exist_ok=True) 99 | 100 | torch.cuda.empty_cache() 101 | os.environ['PYTHONHASHSEED']=str(seed) 102 | random.seed(seed) 103 | np.random.seed(seed) 104 | torch.manual_seed(seed) 105 | torch.cuda.manual_seed(seed) 106 | 107 | # Dataset class 108 | class NpyDataset(Dataset): 109 | def __init__(self, data_root, image_size=1024, data_aug=True): 110 | self.data_root = data_root 111 | self.gt_path = join(data_root, 'gts') 112 | self.img_path = join(data_root, 'imgs') 113 | self.gt_path_files = sorted(glob.glob(join(self.gt_path, '**/*.npy'), recursive=True)) 114 | self.gt_path_files = [file for file in self.gt_path_files if os.path.isfile(join(self.img_path, os.path.basename(file)))] 115 | self.image_size = image_size 116 | self.data_aug = data_aug 117 | 118 | def __len__(self): 119 | return len(self.gt_path_files) 120 | 121 | def __getitem__(self, index): 122 | img_name = os.path.basename(self.gt_path_files[index]) 123 | assert img_name == os.path.basename(self.gt_path_files[index]), 'img gt name error' + self.gt_path_files[index] + self.npy_files[index] 124 | img_1024 = np.load(join(self.img_path, img_name), 'r', allow_pickle=True) # (H, W, 3) 125 | # convert the shape to (3, H, W) 126 | img_1024 = np.transpose(img_1024, (2, 0, 1)) # (3, 256, 256) 127 | assert np.max(img_1024)<=1.0 and np.min(img_1024)>=0.0, 'image should be normalized to [0, 1]' 128 | gt = np.load(self.gt_path_files[index], 'r', allow_pickle=True) # multiple labels [0, 1,4,5...], (256,256) 129 | label_ids = np.unique(gt)[1:] 130 | try: 131 | gt2D = np.uint8(gt == random.choice(label_ids.tolist())) # only one label, (256, 256) 132 | except: 133 | print(img_name, 'label_ids.tolist()', label_ids.tolist()) 134 | gt2D = np.uint8(gt == np.max(gt)) # only one label, (256, 256) 135 | # add data augmentation: random fliplr and random flipud 136 | if self.data_aug: 137 | if random.random() > 0.5: 138 | img_1024 = np.ascontiguousarray(np.flip(img_1024, axis=-1)) 139 | gt2D = np.ascontiguousarray(np.flip(gt2D, axis=-1)) 140 | if random.random() > 0.5: 141 | img_1024 = np.ascontiguousarray(np.flip(img_1024, axis=-2)) 142 | gt2D = np.ascontiguousarray(np.flip(gt2D, axis=-2)) 143 | gt2D = np.uint8(gt2D > 0) 144 | y_indices, x_indices = np.where(gt2D > 0) 145 | x_point = np.random.choice(x_indices) 146 | y_point = np.random.choice(y_indices) 147 | coords = np.array([x_point, y_point]) 148 | 149 | ## Randomly sample a point from the gt at scale 1024 150 | gt2D_256 = cv2.resize( 151 | gt2D, 152 | (256, 256), 153 | interpolation=cv2.INTER_NEAREST 154 | ) 155 | return { 156 | "image": torch.tensor(img_1024).float(), 157 | "gt2D": torch.tensor(gt2D_256[None, :,:]).long(), 158 | "coords": 
torch.tensor(coords[None, ...]).float(), 159 | "image_name": img_name 160 | } 161 | 162 | class MedSAM(nn.Module): 163 | def __init__(self, 164 | image_encoder, 165 | mask_decoder, 166 | prompt_encoder, 167 | freeze_image_encoder=False, 168 | ): 169 | super().__init__() 170 | self.image_encoder = image_encoder 171 | self.mask_decoder = mask_decoder 172 | self.prompt_encoder = prompt_encoder 173 | 174 | # freeze prompt encoder 175 | for param in self.prompt_encoder.parameters(): 176 | param.requires_grad = False 177 | 178 | self.freeze_image_encoder = freeze_image_encoder 179 | if self.freeze_image_encoder: 180 | for param in self.image_encoder.parameters(): 181 | param.requires_grad = False 182 | 183 | def forward(self, image, point_prompt): 184 | 185 | # do not compute gradients for pretrained img encoder and prompt encoder 186 | with torch.no_grad(): 187 | image_embedding = self.image_encoder(image) # (B, 256, 64, 64) 188 | # not need to convert box to 1024x1024 grid 189 | # bbox is already in 1024x1024 190 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 191 | points=point_prompt, 192 | boxes=None, 193 | masks=None, 194 | ) 195 | low_res_masks, iou_predictions = self.mask_decoder( 196 | image_embeddings=image_embedding, # (B, 256, 64, 64) 197 | image_pe=self.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64) 198 | sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256) 199 | dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64) 200 | multimask_output=False, 201 | ) # (B, 1, 256, 256) 202 | 203 | return low_res_masks 204 | 205 | sam_model = sam_model_registry["vit_b"](checkpoint=medsam_checkpoint) 206 | medsam_model = MedSAM( 207 | image_encoder = sam_model.image_encoder, 208 | mask_decoder = sam_model.mask_decoder, 209 | prompt_encoder = sam_model.prompt_encoder, 210 | freeze_image_encoder = True 211 | ) 212 | medsam_model = medsam_model.to(device) 213 | medsam_model.train() 214 | print(f"MedSAM size: {sum(p.numel() for p in medsam_model.parameters())}") 215 | 216 | optimizer = optim.AdamW( 217 | medsam_model.mask_decoder.parameters(), 218 | lr=args.lr, 219 | betas=(0.9, 0.999), 220 | eps=1e-08, 221 | weight_decay=args.weight_decay 222 | ) 223 | 224 | seg_loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True, reduction='mean') 225 | ce_loss = nn.BCEWithLogitsLoss(reduction="mean") 226 | 227 | train_dataset = NpyDataset(data_root=data_root, data_aug=data_aug) 228 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) 229 | 230 | resume = args.resume 231 | if resume: 232 | checkpoint = torch.load(resume) 233 | medsam_model.load_state_dict(checkpoint["model"]) 234 | optimizer.load_state_dict(checkpoint["optimizer"]) 235 | start_epoch = checkpoint["epoch"] + 1 236 | best_loss = checkpoint["best_loss"] 237 | print(f"Loaded checkpoint from epoch {start_epoch}, best loss: {best_loss:.4f}") 238 | else: 239 | start_epoch = 0 240 | best_loss = 1e10 241 | torch.cuda.empty_cache() 242 | 243 | epoch_time = [] 244 | losses = [] 245 | for epoch in range(start_epoch, num_epochs): 246 | epoch_loss = [1e10 for _ in range(len(train_loader))] 247 | epoch_start_time = time() 248 | pbar = tqdm(train_loader) 249 | for step, batch in enumerate(pbar): 250 | image = batch["image"] 251 | gt2D = batch["gt2D"] 252 | coords_torch = batch["coords"] # (B, 2) 253 | optimizer.zero_grad() 254 | labels_torch = torch.ones(coords_torch.shape[0]).long() # (B,) 255 | labels_torch = labels_torch.unsqueeze(1) # (B, 1) 256 | image, gt2D 
= image.to(device), gt2D.to(device) 257 | coords_torch, labels_torch = coords_torch.to(device), labels_torch.to(device) 258 | point_prompt = (coords_torch, labels_torch) 259 | medsam_lite_pred = medsam_model(image, point_prompt) 260 | loss = seg_loss(medsam_lite_pred, gt2D) + ce_loss(medsam_lite_pred, gt2D.float()) 261 | epoch_loss[step] = loss.item() 262 | loss.backward() 263 | optimizer.step() 264 | optimizer.zero_grad() 265 | pbar.set_description(f"Epoch {epoch} at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, loss: {loss.item():.4f}") 266 | 267 | epoch_end_time = time() 268 | epoch_time.append(epoch_end_time - epoch_start_time) 269 | epoch_loss_reduced = sum(epoch_loss) / len(epoch_loss) 270 | losses.append(epoch_loss_reduced) 271 | model_weights = medsam_model.state_dict() 272 | checkpoint = { 273 | "model": model_weights, 274 | "epoch": epoch, 275 | "optimizer": optimizer.state_dict(), 276 | "loss": epoch_loss_reduced, 277 | "best_loss": best_loss 278 | } 279 | if epoch_loss_reduced < best_loss: 280 | print(f"New best loss: {best_loss:.4f} -> {epoch_loss_reduced:.4f}") 281 | best_loss = epoch_loss_reduced 282 | checkpoint["best_loss"] = best_loss 283 | torch.save(checkpoint, join(work_dir, "medsam_point_prompt_best.pth")) 284 | 285 | torch.save(checkpoint, join(work_dir, "medsam_point_prompt_latest.pth")) 286 | 287 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10)) 288 | ax1.plot(losses) 289 | ax1.set_title("Dice + Cross Entropy Loss") 290 | ax1.set_xlabel("Epoch") 291 | ax1.set_ylabel("Loss") 292 | ax2.plot(epoch_time) 293 | ax2.set_title("Epoch Running Time") 294 | ax2.set_xlabel("Epoch") 295 | ax2.set_ylabel("Time (s)") 296 | fig.savefig(join(work_dir, "medsam_point_prompt_loss_time.png")) 297 | 298 | epoch_loss_reduced = 1e10 299 | -------------------------------------------------------------------------------- /extensions/seg_3dnii_sparse_marker/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | 4 | 5 | We conducted an annotation study to show how MedSAM can be used to assit annotation. Specifically, we used the recently released [adrenocortical carcinoma CT dataset](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=93257945), where the segmentation target (adrenal tumor) did not appear in the training set or existing validation sets. We randomly selected 10 cases with 733 tumor slices that need to be annotated. The data are available [here](https://drive.google.com/drive/folders/1QhD4vPDie-P2ddpur6lofRYWMHklYao3?usp=sharing). 6 | 7 | 8 | The annotation process contained three steps. 9 | 10 | 1. Initial marker 11 | Two radiologists independently draw the long and short tumor axes (initial marker), which is a common measure in clinical practice (e.g., RECIST). This process was conducted every 3-10 slices for the 3D tumor. 12 | 13 | 2. For each annotated slice, a rectangle binary mask was generated based on the linear label that can completely cover the linear label. For the unlabeled slices, the rectangle binary masks were simulated by interploating the surrounding labeled slices. 14 | 15 | ```bash 16 | python label_interpolate.py 17 | ``` 18 | 19 | 3. We converted the binary masks to bounding boxes followed by feeding them to medsam together with images and generating segmentation results. 
20 | 21 | ```bash 22 | python medsam_infer_3Dbox_adrenal.py 23 | ``` 24 | -------------------------------------------------------------------------------- /extensions/seg_3dnii_sparse_marker/label_interpolate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | join = os.path.join 5 | import SimpleITK as sitk 6 | import numpy as np 7 | from collections import OrderedDict 8 | from scipy.ndimage import distance_transform_edt 9 | from scipy.interpolate import interp1d 10 | 11 | def interpolate_labels(label_volume): 12 | depth, height, width = label_volume.shape 13 | 14 | # Create an array to hold the interpolated labels 15 | interpolated_labels = np.zeros((depth, height, width), dtype=label_volume.dtype) 16 | 17 | # Loop through each unique label in the label volume 18 | for label in np.unique(label_volume): 19 | if label == 0: # Skip background 20 | continue 21 | 22 | # Create a binary mask for the current label 23 | binary_mask = (label_volume == label) 24 | 25 | # Extract slices with the current label 26 | labeled_slices = [i for i in range(depth) if np.any(binary_mask[i])] 27 | 28 | if len(labeled_slices) < 2: # At least two slices are needed for interpolation 29 | continue 30 | 31 | # Initialize array to hold distances for the slices 32 | distances = np.zeros((depth, height, width)) 33 | 34 | # Calculate distances from object border for each labeled slice 35 | for i in labeled_slices: 36 | distances[i] = distance_transform_edt(np.logical_not(binary_mask[i])) 37 | 38 | # Create an interpolating function 39 | f = interp1d(labeled_slices, [distances[i] for i in labeled_slices], 40 | axis=0, kind='linear', bounds_error=False, fill_value="extrapolate") 41 | 42 | # Interpolate 43 | for i, next_slice in zip(labeled_slices[:-1], labeled_slices[1:]): 44 | interpolated_slices = np.round(f(np.arange(i, next_slice + 1))).astype(int) 45 | 46 | # Update label 47 | interpolated_labels[i:next_slice + 1][interpolated_slices <= 0] = label 48 | 49 | return interpolated_labels 50 | 51 | def get_bbox(mask, bbox_shift=5): 52 | y_indices, x_indices = np.where(mask > 0) 53 | x_min, x_max = np.min(x_indices), np.max(x_indices) 54 | y_min, y_max = np.min(y_indices), np.max(y_indices) 55 | # add perturbation to bounding box coordinates 56 | H, W = mask.shape 57 | x_min = max(0, x_min - bbox_shift) 58 | x_max = min(W, x_max + bbox_shift) 59 | y_min = max(0, y_min - bbox_shift) 60 | y_max = min(H, y_max + bbox_shift) 61 | bboxes = np.array([x_min, y_min, x_max, y_max]) 62 | 63 | return bboxes 64 | 65 | 66 | # Check input directories. 
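# Input/output layout used by the scripts in this folder: `marker-expert1/` holds one
# .nii.gz label volume per case with the sparse expert markers (long/short tumor axes
# drawn every few slices); the interpolated rectangular box masks are written to
# `marker-expert1_interpolated/` under the same file names, which is the folder that
# medsam_infer_3Dbox_adrenal.py later reads as its `marker_path`.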
67 | marker_dir = 'marker-expert1' 68 | save_dir = marker_dir + '_interpolated' 69 | os.makedirs(save_dir, exist_ok=True) 70 | names = sorted(os.listdir(marker_dir)) 71 | names = [name for name in names if name.endswith('.nii.gz')] 72 | 73 | for name in names: 74 | nii = sitk.ReadImage(join(marker_dir, name)) 75 | marker_data = np.uint8(sitk.GetArrayFromImage(nii)) 76 | # simulate bounding box based on marker 77 | box_data = np.zeros_like(marker_data, dtype=np.uint8) 78 | label_ids = np.unique(marker_data)[1:] 79 | print(f'label ids: {label_ids}') 80 | for label_id in label_ids: 81 | marker_data_id = (marker_data == label_id).astype(np.uint8) 82 | marker_zids, _, _ = np.where(marker_data_id > 0) 83 | marker_zids = np.sort(np.unique(marker_zids)) 84 | print(f'z indices: {marker_zids}') 85 | # bbox_dict = {} # key: z_index, value: bbox 86 | for z in marker_zids: 87 | # get bbox for each slice 88 | z_box = get_bbox(marker_data_id[z, :, :], bbox_shift=5) 89 | box_data[z, z_box[1]:z_box[3], z_box[0]:z_box[2]] = label_id 90 | # interpolate labels 91 | interpolated_labels = interpolate_labels(box_data) 92 | # save interpolated labels 93 | save_name = name# .replace('.nii.gz', '_interpolated.nii.gz') 94 | save_path = join(save_dir, save_name) 95 | save_sitk = sitk.GetImageFromArray(interpolated_labels) 96 | # add meta information 97 | save_sitk.CopyInformation(nii) 98 | sitk.WriteImage(save_sitk, save_path) -------------------------------------------------------------------------------- /extensions/seg_3dnii_sparse_marker/medsam_infer_3Dbox_adrenal.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import pandas as pd 3 | import numpy as np 4 | from os.path import join, exists 5 | from tqdm import tqdm 6 | from matplotlib import pyplot as plt 7 | import torch 8 | from torch.nn import functional as F 9 | from segment_anything import sam_model_registry 10 | import SimpleITK as sitk 11 | import random 12 | import os 13 | import cv2 14 | from skimage import io, measure 15 | from tqdm import tqdm 16 | from collections import OrderedDict 17 | import time 18 | 19 | def getLargestCC(segmentation): 20 | labels = measure.label(segmentation) 21 | largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1 22 | return largestCC.astype(np.uint8) 23 | 24 | image_size = 1024 25 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 26 | # %% 27 | # visualization functions 28 | # source: https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb 29 | # change color to avoid red and green 30 | def show_mask(mask, ax, random_color=False): 31 | if random_color: 32 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 33 | else: 34 | color = np.array([251/255, 252/255, 30/255, 0.6]) 35 | h, w = mask.shape[-2:] 36 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 37 | ax.imshow(mask_image) 38 | 39 | def show_box(box, ax): 40 | x0, y0 = box[0], box[1] 41 | w, h = box[2] - box[0], box[3] - box[1] 42 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2)) 43 | 44 | def get_bbox(mask, bbox_shift=5): 45 | y_indices, x_indices = np.where(mask > 0) 46 | x_min, x_max = np.min(x_indices), np.max(x_indices) 47 | y_min, y_max = np.min(y_indices), np.max(y_indices) 48 | # add perturbation to bounding box coordinates 49 | H, W = mask.shape 50 | x_min = max(0, x_min - bbox_shift) 51 | x_max = min(W, x_max + bbox_shift) 52 | y_min = max(0, y_min - bbox_shift) 
53 | y_max = min(H, y_max + bbox_shift) 54 | bboxes = np.array([x_min, y_min, x_max, y_max]) 55 | return bboxes 56 | 57 | @torch.no_grad() 58 | def medsam_inference(medsam_model, img_embed, box_1024, H, W): 59 | box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device) 60 | if len(box_torch.shape) == 2: 61 | box_torch = box_torch[:, None, :] # (B, 1, 4) 62 | 63 | sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder( 64 | points=None, 65 | boxes=box_torch, 66 | masks=None, 67 | ) 68 | low_res_logits, _ = medsam_model.mask_decoder( 69 | image_embeddings=img_embed, # (B, 256, 64, 64) 70 | image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64) 71 | sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256) 72 | dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64) 73 | multimask_output=False, 74 | ) 75 | 76 | low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256) 77 | 78 | low_res_pred = F.interpolate( 79 | low_res_pred, 80 | size=(H, W), 81 | mode="bilinear", 82 | align_corners=False, 83 | ) # (1, 1, gt.shape) 84 | low_res_pred = low_res_pred.squeeze().cpu().numpy() # (256, 256) 85 | medsam_seg = (low_res_pred > 0.5).astype(np.uint8) 86 | return medsam_seg 87 | 88 | 89 | # %% load model 90 | model_path = 'medsam_vit_b.pth' 91 | medsam_model = sam_model_registry['vit_b'](checkpoint=model_path) 92 | medsam_model = medsam_model.to(device) 93 | medsam_model.eval() 94 | 95 | seg_info = OrderedDict() 96 | seg_info['name'] = [] 97 | seg_info['running time'] = [] 98 | 99 | 100 | img_path = 'images' 101 | marker_path = 'marker-expert1_interpolated' 102 | seg_path = 'medsam_seg_expert1' 103 | os.makedirs(seg_path, exist_ok=True) 104 | 105 | # load data 106 | names = sorted(os.listdir(marker_path)) 107 | for name in tqdm(names): 108 | start_time = time.time() 109 | img_name = name.split('.nii.gz')[0] + '_0000.nii.gz' 110 | img_sitk = sitk.ReadImage(join(img_path, img_name)) 111 | image_data = sitk.GetArrayFromImage(img_sitk) 112 | # adjust window level and window width 113 | image_data_pre = image_data.astype(np.float32) # np.clip(image_data, -160.0, 240.0) 114 | image_data_pre = (image_data_pre - np.min(image_data_pre))/(np.max(image_data_pre)-np.min(image_data_pre))*255.0 115 | image_data_pre = np.uint8(image_data_pre) 116 | seg_data = np.zeros_like(image_data_pre, dtype=np.uint8) 117 | marker_data = sitk.GetArrayFromImage(sitk.ReadImage(join(marker_path, name))) 118 | marker_data = np.uint8(marker_data) 119 | label_ids = np.unique(marker_data)[1:] 120 | print(f'label ids: {label_ids}') 121 | for label_id in label_ids: 122 | marker_data_id = (marker_data == label_id).astype(np.uint8) 123 | marker_zids, _, _ = np.where(marker_data_id > 0) 124 | marker_zids = np.sort(np.unique(marker_zids)) 125 | print(f'z indices: {marker_zids}') 126 | bbox_dict = {} # key: z_index, value: bbox 127 | for z in marker_zids: 128 | # get bbox for each slice 129 | z_box = get_bbox(marker_data_id[z, :, :], bbox_shift=5) 130 | bbox_dict[z] = z_box 131 | # find largest bbox in bbox_dict 132 | bbox_areas = [np.prod(bbox_dict[z][2:] - bbox_dict[z][:2]) for z in bbox_dict.keys()] 133 | z_middle = list(bbox_dict.keys())[np.argmax(bbox_areas)] # middle slice 134 | z_min = min(bbox_dict.keys()) 135 | z_max = max(bbox_dict.keys()) 136 | z_middle_bbox = bbox_dict[z_middle] 137 | # sanity check 138 | # img_roi = image_data_pre[z_middle, z_middle_bbox[1]:z_middle_bbox[3], z_middle_bbox[0]:z_middle_bbox[2]] 139 | # io.imsave(name.split('.nii.gz')[0] + '_roi.png', 
img_roi) 140 | 141 | # infer from middle slice to the z_max 142 | print('infer', name, 'from middle slice to the z_max') 143 | for z in tqdm(range(z_middle, z_max+1)): # include z_max 144 | img_2d = image_data_pre[z, :, :] 145 | if len(img_2d.shape) == 2: 146 | img_3c = np.repeat(img_2d[:, :, None], 3, axis=-1) 147 | else: 148 | img_3c = img_2d 149 | H, W, _ = img_3c.shape 150 | # img_1024 = transform.resize(img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8) 151 | img_1024 = cv2.resize(img_3c, (1024, 1024), interpolation=cv2.INTER_CUBIC) 152 | img_1024 = (img_1024 - img_1024.min()) / np.clip( 153 | img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None 154 | ) # normalize to [0, 1], (H, W, 3) 155 | # convert the shape to (3, H, W) 156 | img_1024_tensor = torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device) 157 | # get the image embedding 158 | with torch.no_grad(): 159 | image_embedding = medsam_model.image_encoder(img_1024_tensor) # (1, 256, 64, 64) 160 | if z in bbox_dict.keys(): 161 | box_1024 = bbox_dict[z] / np.array([W, H, W, H]) * 1024 162 | else: 163 | pre_seg = seg_data[z-1, :, :] # use the previous slice 164 | if np.max(pre_seg) > 0: 165 | pre_seg1024 = cv2.resize(pre_seg, (1024, 1024), interpolation=cv2.INTER_NEAREST) 166 | box_1024 = get_bbox(pre_seg1024) 167 | else: 168 | # find the closest z index in bbox_dict 169 | z_diff = [abs(z - z_id) for z_id in bbox_dict.keys()] 170 | z_closest = list(bbox_dict.keys())[np.argmin(z_diff)] 171 | box_1024 = bbox_dict[z_closest] / np.array([W, H, W, H]) * 1024 172 | bbox_dict[z] = box_1024 / 1024 * np.array([W, H, W, H]) 173 | img_2d_seg = medsam_inference(medsam_model, image_embedding, box_1024[None,:], H, W) 174 | seg_data[z, img_2d_seg>0] = 1 175 | 176 | # infer from middle slice to the z_max 177 | print('infer', name, 'from middle slice to the z_min') 178 | for z in tqdm(range(z_middle-1, z_min-1, -1)): 179 | img_2d = image_data_pre[z, :, :] 180 | if len(img_2d.shape) == 2: 181 | img_3c = np.repeat(img_2d[:, :, None], 3, axis=-1) 182 | else: 183 | img_3c = img_2d 184 | H, W, _ = img_3c.shape 185 | img_1024 = cv2.resize(img_3c, (1024, 1024), interpolation=cv2.INTER_CUBIC) 186 | img_1024 = (img_1024 - img_1024.min()) / np.clip( 187 | img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None 188 | ) # normalize to [0, 1], (H, W, 3) 189 | # convert the shape to (3, H, W) 190 | img_1024_tensor = torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device) 191 | # get the image embedding 192 | with torch.no_grad(): 193 | image_embedding = medsam_model.image_encoder(img_1024_tensor) # (1, 256, 64, 64) 194 | 195 | if z in bbox_dict.keys(): 196 | box_1024 = bbox_dict[z] / np.array([W, H, W, H]) * 1024 197 | else: 198 | pre_seg = seg_data[z+1, :, :] 199 | if np.max(pre_seg) > 0: 200 | pre_seg1024 = cv2.resize(pre_seg, (1024, 1024), interpolation=cv2.INTER_NEAREST) 201 | box_1024 = get_bbox(pre_seg1024.astype(np.uint8)) 202 | else: 203 | # find the closest z index in bbox_dict 204 | z_diff = [abs(z - z_id) for z_id in bbox_dict.keys()] 205 | z_closest = list(bbox_dict.keys())[np.argmin(z_diff)] 206 | box_1024 = bbox_dict[z_closest] / np.array([W, H, W, H]) * 1024 207 | bbox_dict[z] = box_1024 / 1024 * np.array([W, H, W, H]) 208 | img_2d_seg = medsam_inference(medsam_model, image_embedding, box_1024[None,:], H, W) 209 | seg_data[z, img_2d_seg>0] = 1 210 | 211 | seg_sitk = sitk.GetImageFromArray(seg_data) 212 | seg_sitk.CopyInformation(img_sitk) 213 | sitk.WriteImage(seg_sitk, 
join(seg_path, name)) 214 | end_time = time.time() 215 | 216 | # save bounding box info 217 | seg_info['name'].append(name) 218 | seg_info['running time'].append(end_time - start_time) 219 | 220 | # save bbox info 221 | seg_df = pd.DataFrame(seg_info) 222 | seg_df.to_csv(join(seg_path, 'seg_info.csv'), index=False) 223 | 224 | -------------------------------------------------------------------------------- /extensions/text_prompt/README.md: -------------------------------------------------------------------------------- 1 | # MedSAM with text prompts 2 | 3 | 4 | ## Requirements 5 | The text prompt training uses the CLIP model from [Huggingface transformers](https://huggingface.co/docs/transformers/index). To install Huggingface transformers: 6 | ``` 7 | pip install transformers 8 | ``` 9 | 10 | ## Inference 11 | 12 | Please try this out-of-the-box demo: [colab](https://colab.research.google.com/drive/1wexPLewVMI-9EMiplfyoEtGGayYDH3tt?usp=sharing) 13 | 14 | ## Training 15 | 16 | This training script demonstrates how to train MedSAM with text prompts on the [MICCAI FLARE 2022](https://flare22.grand-challenge.org/) dataset, and assume that the dataset has been preprocessed into the format used by MedSAM as described [here](https://github.com/bowang-lab/MedSAM#data-preprocessing). 17 | 18 | The training script `train_text_prompt.py` takes the following arguments: 19 | * `-i`, `--tr_npy_path`: Path to the preprocessed npy data in MedSAM's format 20 | * `-medsam_checkpoint`: Path to the MedSAM checkpoint 21 | * `-work_dir`: Path to the directory where the model checkpoints will be saved 22 | * `-resume`: Path to the checkpoint to resume training from 23 | * `-batch_size`: Batch size 24 | 25 | For example, assume that the preprocessed data is stored in directory `npy_data`, the MedSAM model is stored in `MedSAM/work_dir/MedSAM/medsam_vit_b.pth`, and the model checkpoints should be saved in `train_text_prompt`. To train the model with a batch size of 16, run the following command: 26 | ``` 27 | python train_text_prompt.py \ 28 | -i npy_data \ 29 | -medsam_checkpoint MedSAM/work_dir/MedSAM/medsam_vit_b.pth \ 30 | -work_dir ./train_text_prompt 31 | ``` 32 | 33 | To resume an interrupted training, simply add the `-resume` argument: 34 | ``` 35 | python train_text_prompt.py \ 36 | -i npy_data \ 37 | -medsam_checkpoint MedSAM/work_dir/MedSAM/medsam_vit_b.pth \ 38 | -work_dir ./train_text_prompt \ 39 | -resume ./train_text_prompt/medsam_text_prompt_latest.pt 40 | ``` 41 | 42 | ## Train on your own dataset 43 | To train MedSAM with text prompts on your own dataset, you need to modify the `label_dict` in `NpyDataset` in the training script based on the label values and the corresponding text prompts. For example, if your dataset has two labels, `1` and `2`, and you want to use the text prompts `normal` and `abnormal`, then the `label_dict` should be: 44 | ``` 45 | class NpyDataset(Dataset): 46 | def __init__(self, data_root, image_size=1024, data_aug=True): 47 | ... 
48 | self.label_dict = { 49 | 1: ['normal'], ## need to be a list 50 | 2: ['abnormal'] 51 | } 52 | ``` 53 | -------------------------------------------------------------------------------- /extensions/text_prompt/text_seg_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/MedSAM/d71e8a1a99ad751840a22a7fa3ecfb4166fb1488/extensions/text_prompt/text_seg_demo.gif -------------------------------------------------------------------------------- /gui.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import time 4 | from PyQt5.QtGui import ( 5 | QBrush, 6 | QPainter, 7 | QPen, 8 | QPixmap, 9 | QKeySequence, 10 | QPen, 11 | QBrush, 12 | QColor, 13 | QImage, 14 | ) 15 | from PyQt5.QtWidgets import ( 16 | QFileDialog, 17 | QApplication, 18 | QGraphicsEllipseItem, 19 | QGraphicsItem, 20 | QGraphicsRectItem, 21 | QGraphicsScene, 22 | QGraphicsView, 23 | QGraphicsPixmapItem, 24 | QHBoxLayout, 25 | QPushButton, 26 | QSlider, 27 | QVBoxLayout, 28 | QWidget, 29 | QShortcut, 30 | ) 31 | 32 | import numpy as np 33 | from skimage import transform, io 34 | import torch 35 | import torch.nn as nn 36 | from torch.nn import functional as F 37 | from PIL import Image 38 | from segment_anything import sam_model_registry 39 | 40 | # freeze seeds 41 | torch.manual_seed(2023) 42 | torch.cuda.empty_cache() 43 | torch.cuda.manual_seed(2023) 44 | np.random.seed(2023) 45 | 46 | SAM_MODEL_TYPE = "vit_b" 47 | MedSAM_CKPT_PATH = "work_dir/MedSAM/medsam_vit_b.pth" 48 | MEDSAM_IMG_INPUT_SIZE = 1024 49 | 50 | if torch.backends.mps.is_available(): 51 | device = torch.device("mps") 52 | else: 53 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 54 | 55 | 56 | @torch.no_grad() 57 | def medsam_inference(medsam_model, img_embed, box_1024, height, width): 58 | box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device) 59 | if len(box_torch.shape) == 2: 60 | box_torch = box_torch[:, None, :] # (B, 1, 4) 61 | 62 | sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder( 63 | points=None, 64 | boxes=box_torch, 65 | masks=None, 66 | ) 67 | low_res_logits, _ = medsam_model.mask_decoder( 68 | image_embeddings=img_embed, # (B, 256, 64, 64) 69 | image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64) 70 | sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256) 71 | dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64) 72 | multimask_output=False, 73 | ) 74 | 75 | low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256) 76 | 77 | low_res_pred = F.interpolate( 78 | low_res_pred, 79 | size=(height, width), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) # (1, 1, gt.shape) 83 | low_res_pred = low_res_pred.squeeze().cpu().numpy() # (256, 256) 84 | medsam_seg = (low_res_pred > 0.5).astype(np.uint8) 85 | return medsam_seg 86 | 87 | 88 | print("Loading MedSAM model, a sec.") 89 | tic = time.perf_counter() 90 | 91 | # set up model 92 | medsam_model = sam_model_registry["vit_b"](checkpoint=MedSAM_CKPT_PATH).to(device) 93 | medsam_model.eval() 94 | 95 | print(f"Done, took {time.perf_counter() - tic}") 96 | 97 | 98 | def np2pixmap(np_img): 99 | height, width, channel = np_img.shape 100 | bytesPerLine = 3 * width 101 | qImg = QImage(np_img.data, width, height, bytesPerLine, QImage.Format_RGB888) 102 | return QPixmap.fromImage(qImg) 103 | 104 | 105 | colors = [ 106 | (255, 0, 0), 107 | (0, 255, 
0), 108 | (0, 0, 255), 109 | (255, 255, 0), 110 | (255, 0, 255), 111 | (0, 255, 255), 112 | (128, 0, 0), 113 | (0, 128, 0), 114 | (0, 0, 128), 115 | (128, 128, 0), 116 | (128, 0, 128), 117 | (0, 128, 128), 118 | (255, 255, 255), 119 | (192, 192, 192), 120 | (64, 64, 64), 121 | (255, 0, 255), 122 | (0, 255, 255), 123 | (255, 255, 0), 124 | (0, 0, 127), 125 | (192, 0, 192), 126 | ] 127 | 128 | 129 | class Window(QWidget): 130 | def __init__(self): 131 | super().__init__() 132 | 133 | # configs 134 | self.half_point_size = 5 # radius of bbox starting and ending points 135 | 136 | # app stats 137 | self.image_path = None 138 | self.color_idx = 0 139 | self.bg_img = None 140 | self.is_mouse_down = False 141 | self.rect = None 142 | self.point_size = self.half_point_size * 2 143 | self.start_point = None 144 | self.end_point = None 145 | self.start_pos = (None, None) 146 | self.embedding = None 147 | self.prev_mask = None 148 | 149 | self.view = QGraphicsView() 150 | self.view.setRenderHint(QPainter.Antialiasing) 151 | 152 | pixmap = self.load_image() 153 | 154 | vbox = QVBoxLayout(self) 155 | vbox.addWidget(self.view) 156 | 157 | load_button = QPushButton("Load Image") 158 | save_button = QPushButton("Save Mask") 159 | 160 | hbox = QHBoxLayout(self) 161 | hbox.addWidget(load_button) 162 | hbox.addWidget(save_button) 163 | 164 | vbox.addLayout(hbox) 165 | 166 | self.setLayout(vbox) 167 | 168 | # keyboard shortcuts 169 | self.quit_shortcut = QShortcut(QKeySequence("Ctrl+Q"), self) 170 | self.quit_shortcut.activated.connect(lambda: quit()) 171 | 172 | self.undo_shortcut = QShortcut(QKeySequence("Ctrl+Z"), self) 173 | self.undo_shortcut.activated.connect(self.undo) 174 | 175 | load_button.clicked.connect(self.load_image) 176 | save_button.clicked.connect(self.save_mask) 177 | 178 | def undo(self): 179 | if self.prev_mask is None: 180 | print("No previous mask record") 181 | return 182 | 183 | self.color_idx -= 1 184 | 185 | bg = Image.fromarray(self.img_3c.astype("uint8"), "RGB") 186 | mask = Image.fromarray(self.prev_mask.astype("uint8"), "RGB") 187 | img = Image.blend(bg, mask, 0.2) 188 | 189 | self.scene.removeItem(self.bg_img) 190 | self.bg_img = self.scene.addPixmap(np2pixmap(np.array(img))) 191 | 192 | self.mask_c = self.prev_mask 193 | self.prev_mask = None 194 | 195 | def load_image(self): 196 | file_path, file_type = QFileDialog.getOpenFileName( 197 | self, "Choose Image to Segment", ".", "Image Files (*.png *.jpg *.bmp)" 198 | ) 199 | 200 | if file_path is None or len(file_path) == 0: 201 | print("No image path specified, plz select an image") 202 | exit() 203 | 204 | img_np = io.imread(file_path) 205 | if len(img_np.shape) == 2: 206 | img_3c = np.repeat(img_np[:, :, None], 3, axis=-1) 207 | else: 208 | img_3c = img_np 209 | 210 | self.img_3c = img_3c 211 | self.image_path = file_path 212 | self.get_embeddings() 213 | pixmap = np2pixmap(self.img_3c) 214 | 215 | H, W, _ = self.img_3c.shape 216 | 217 | self.scene = QGraphicsScene(0, 0, W, H) 218 | self.end_point = None 219 | self.rect = None 220 | self.bg_img = self.scene.addPixmap(pixmap) 221 | self.bg_img.setPos(0, 0) 222 | self.mask_c = np.zeros((*self.img_3c.shape[:2], 3), dtype="uint8") 223 | self.view.setScene(self.scene) 224 | 225 | # events 226 | self.scene.mousePressEvent = self.mouse_press 227 | self.scene.mouseMoveEvent = self.mouse_move 228 | self.scene.mouseReleaseEvent = self.mouse_release 229 | 230 | def mouse_press(self, ev): 231 | x, y = ev.scenePos().x(), ev.scenePos().y() 232 | self.is_mouse_down = True 233 | 
self.start_pos = ev.scenePos().x(), ev.scenePos().y() 234 | self.start_point = self.scene.addEllipse( 235 | x - self.half_point_size, 236 | y - self.half_point_size, 237 | self.point_size, 238 | self.point_size, 239 | pen=QPen(QColor("red")), 240 | brush=QBrush(QColor("red")), 241 | ) 242 | 243 | def mouse_move(self, ev): 244 | if not self.is_mouse_down: 245 | return 246 | 247 | x, y = ev.scenePos().x(), ev.scenePos().y() 248 | 249 | if self.end_point is not None: 250 | self.scene.removeItem(self.end_point) 251 | self.end_point = self.scene.addEllipse( 252 | x - self.half_point_size, 253 | y - self.half_point_size, 254 | self.point_size, 255 | self.point_size, 256 | pen=QPen(QColor("red")), 257 | brush=QBrush(QColor("red")), 258 | ) 259 | 260 | if self.rect is not None: 261 | self.scene.removeItem(self.rect) 262 | sx, sy = self.start_pos 263 | xmin = min(x, sx) 264 | xmax = max(x, sx) 265 | ymin = min(y, sy) 266 | ymax = max(y, sy) 267 | self.rect = self.scene.addRect( 268 | xmin, ymin, xmax - xmin, ymax - ymin, pen=QPen(QColor("red")) 269 | ) 270 | 271 | def mouse_release(self, ev): 272 | x, y = ev.scenePos().x(), ev.scenePos().y() 273 | sx, sy = self.start_pos 274 | xmin = min(x, sx) 275 | xmax = max(x, sx) 276 | ymin = min(y, sy) 277 | ymax = max(y, sy) 278 | 279 | self.is_mouse_down = False 280 | 281 | H, W, _ = self.img_3c.shape 282 | box_np = np.array([[xmin, ymin, xmax, ymax]]) 283 | # print("bounding box:", box_np) 284 | box_1024 = box_np / np.array([W, H, W, H]) * 1024 285 | 286 | sam_mask = medsam_inference(medsam_model, self.embedding, box_1024, H, W) 287 | 288 | self.prev_mask = self.mask_c.copy() 289 | self.mask_c[sam_mask != 0] = colors[self.color_idx % len(colors)] 290 | self.color_idx += 1 291 | 292 | bg = Image.fromarray(self.img_3c.astype("uint8"), "RGB") 293 | mask = Image.fromarray(self.mask_c.astype("uint8"), "RGB") 294 | img = Image.blend(bg, mask, 0.2) 295 | 296 | self.scene.removeItem(self.bg_img) 297 | self.bg_img = self.scene.addPixmap(np2pixmap(np.array(img))) 298 | 299 | def save_mask(self): 300 | out_path = f"{self.image_path.split('.')[0]}_mask.png" 301 | io.imsave(out_path, self.mask_c) 302 | 303 | @torch.no_grad() 304 | def get_embeddings(self): 305 | print("Calculating embedding, gui may be unresponsive.") 306 | img_1024 = transform.resize( 307 | self.img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True 308 | ).astype(np.uint8) 309 | img_1024 = (img_1024 - img_1024.min()) / np.clip( 310 | img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None 311 | ) # normalize to [0, 1], (H, W, 3) 312 | # convert the shape to (3, H, W) 313 | img_1024_tensor = ( 314 | torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device) 315 | ) 316 | 317 | # if self.embedding is None: 318 | with torch.no_grad(): 319 | self.embedding = medsam_model.image_encoder( 320 | img_1024_tensor 321 | ) # (1, 256, 64, 64) 322 | print("Done.") 323 | 324 | 325 | app = QApplication(sys.argv) 326 | 327 | w = Window() 328 | w.show() 329 | 330 | app.exec() 331 | -------------------------------------------------------------------------------- /pre_CT_MR.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # %% import packages 3 | # pip install connected-components-3d 4 | import numpy as np 5 | 6 | # import nibabel as nib 7 | import SimpleITK as sitk 8 | import os 9 | 10 | join = os.path.join 11 | from skimage import transform 12 | from tqdm import tqdm 13 | import cc3d 14 | 15 | # convert nii image to npz 
files, including original image and corresponding masks 16 | modality = "CT" 17 | anatomy = "Abd" # anantomy + dataset name 18 | img_name_suffix = "_0000.nii.gz" 19 | gt_name_suffix = ".nii.gz" 20 | prefix = modality + "_" + anatomy + "_" 21 | 22 | nii_path = "data/FLARE22Train/images" # path to the nii images 23 | gt_path = "data/FLARE22Train/labels" # path to the ground truth 24 | npy_path = "data/npy/" + prefix[:-1] 25 | os.makedirs(join(npy_path, "gts"), exist_ok=True) 26 | os.makedirs(join(npy_path, "imgs"), exist_ok=True) 27 | 28 | image_size = 1024 29 | voxel_num_thre2d = 100 30 | voxel_num_thre3d = 1000 31 | 32 | names = sorted(os.listdir(gt_path)) 33 | print(f"ori \# files {len(names)=}") 34 | names = [ 35 | name 36 | for name in names 37 | if os.path.exists(join(nii_path, name.split(gt_name_suffix)[0] + img_name_suffix)) 38 | ] 39 | print(f"after sanity check \# files {len(names)=}") 40 | 41 | # set label ids that are excluded 42 | remove_label_ids = [ 43 | 12 44 | ] # remove deodenum since it is scattered in the image, which is hard to specify with the bounding box 45 | tumor_id = None # only set this when there are multiple tumors; convert semantic masks to instance masks 46 | # set window level and width 47 | # https://radiopaedia.org/articles/windowing-ct 48 | WINDOW_LEVEL = 40 # only for CT images 49 | WINDOW_WIDTH = 400 # only for CT images 50 | 51 | # %% save preprocessed images and masks as npz files 52 | for name in tqdm(names[:40]): # use the remaining 10 cases for validation 53 | image_name = name.split(gt_name_suffix)[0] + img_name_suffix 54 | gt_name = name 55 | gt_sitk = sitk.ReadImage(join(gt_path, gt_name)) 56 | gt_data_ori = np.uint8(sitk.GetArrayFromImage(gt_sitk)) 57 | # remove label ids 58 | for remove_label_id in remove_label_ids: 59 | gt_data_ori[gt_data_ori == remove_label_id] = 0 60 | # label tumor masks as instances and remove from gt_data_ori 61 | if tumor_id is not None: 62 | tumor_bw = np.uint8(gt_data_ori == tumor_id) 63 | gt_data_ori[tumor_bw > 0] = 0 64 | # label tumor masks as instances 65 | tumor_inst, tumor_n = cc3d.connected_components( 66 | tumor_bw, connectivity=26, return_N=True 67 | ) 68 | # put the tumor instances back to gt_data_ori 69 | gt_data_ori[tumor_inst > 0] = ( 70 | tumor_inst[tumor_inst > 0] + np.max(gt_data_ori) + 1 71 | ) 72 | 73 | # exclude the objects with less than 1000 pixels in 3D 74 | gt_data_ori = cc3d.dust( 75 | gt_data_ori, threshold=voxel_num_thre3d, connectivity=26, in_place=True 76 | ) 77 | # remove small objects with less than 100 pixels in 2D slices 78 | 79 | for slice_i in range(gt_data_ori.shape[0]): 80 | gt_i = gt_data_ori[slice_i, :, :] 81 | # remove small objects with less than 100 pixels 82 | # reason: fro such small objects, the main challenge is detection rather than segmentation 83 | gt_data_ori[slice_i, :, :] = cc3d.dust( 84 | gt_i, threshold=voxel_num_thre2d, connectivity=8, in_place=True 85 | ) 86 | # find non-zero slices 87 | z_index, _, _ = np.where(gt_data_ori > 0) 88 | z_index = np.unique(z_index) 89 | 90 | if len(z_index) > 0: 91 | # crop the ground truth with non-zero slices 92 | gt_roi = gt_data_ori[z_index, :, :] 93 | # load image and preprocess 94 | img_sitk = sitk.ReadImage(join(nii_path, image_name)) 95 | image_data = sitk.GetArrayFromImage(img_sitk) 96 | # nii preprocess start 97 | if modality == "CT": 98 | lower_bound = WINDOW_LEVEL - WINDOW_WIDTH / 2 99 | upper_bound = WINDOW_LEVEL + WINDOW_WIDTH / 2 100 | image_data_pre = np.clip(image_data, lower_bound, upper_bound) 101 | image_data_pre 
= ( 102 | (image_data_pre - np.min(image_data_pre)) 103 | / (np.max(image_data_pre) - np.min(image_data_pre)) 104 | * 255.0 105 | ) 106 | else: 107 | lower_bound, upper_bound = np.percentile( 108 | image_data[image_data > 0], 0.5 109 | ), np.percentile(image_data[image_data > 0], 99.5) 110 | image_data_pre = np.clip(image_data, lower_bound, upper_bound) 111 | image_data_pre = ( 112 | (image_data_pre - np.min(image_data_pre)) 113 | / (np.max(image_data_pre) - np.min(image_data_pre)) 114 | * 255.0 115 | ) 116 | image_data_pre[image_data == 0] = 0 117 | 118 | image_data_pre = np.uint8(image_data_pre) 119 | img_roi = image_data_pre[z_index, :, :] 120 | np.savez_compressed(join(npy_path, prefix + gt_name.split(gt_name_suffix)[0]+'.npz'), imgs=img_roi, gts=gt_roi, spacing=img_sitk.GetSpacing()) 121 | # save the image and ground truth as nii files for sanity check; 122 | # they can be removed 123 | img_roi_sitk = sitk.GetImageFromArray(img_roi) 124 | gt_roi_sitk = sitk.GetImageFromArray(gt_roi) 125 | sitk.WriteImage( 126 | img_roi_sitk, 127 | join(npy_path, prefix + gt_name.split(gt_name_suffix)[0] + "_img.nii.gz"), 128 | ) 129 | sitk.WriteImage( 130 | gt_roi_sitk, 131 | join(npy_path, prefix + gt_name.split(gt_name_suffix)[0] + "_gt.nii.gz"), 132 | ) 133 | # save the each CT image as npy file 134 | for i in range(img_roi.shape[0]): 135 | img_i = img_roi[i, :, :] 136 | img_3c = np.repeat(img_i[:, :, None], 3, axis=-1) 137 | resize_img_skimg = transform.resize( 138 | img_3c, 139 | (image_size, image_size), 140 | order=3, 141 | preserve_range=True, 142 | mode="constant", 143 | anti_aliasing=True, 144 | ) 145 | resize_img_skimg_01 = (resize_img_skimg - resize_img_skimg.min()) / np.clip( 146 | resize_img_skimg.max() - resize_img_skimg.min(), a_min=1e-8, a_max=None 147 | ) # normalize to [0, 1], (H, W, 3) 148 | gt_i = gt_roi[i, :, :] 149 | resize_gt_skimg = transform.resize( 150 | gt_i, 151 | (image_size, image_size), 152 | order=0, 153 | preserve_range=True, 154 | mode="constant", 155 | anti_aliasing=False, 156 | ) 157 | resize_gt_skimg = np.uint8(resize_gt_skimg) 158 | assert resize_img_skimg_01.shape[:2] == resize_gt_skimg.shape 159 | np.save( 160 | join( 161 | npy_path, 162 | "imgs", 163 | prefix 164 | + gt_name.split(gt_name_suffix)[0] 165 | + "-" 166 | + str(i).zfill(3) 167 | + ".npy", 168 | ), 169 | resize_img_skimg_01, 170 | ) 171 | np.save( 172 | join( 173 | npy_path, 174 | "gts", 175 | prefix 176 | + gt_name.split(gt_name_suffix)[0] 177 | + "-" 178 | + str(i).zfill(3) 179 | + ".npy", 180 | ), 181 | resize_gt_skimg, 182 | ) 183 | -------------------------------------------------------------------------------- /segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from .build_sam import ( 9 | build_sam, 10 | build_sam_vit_h, 11 | build_sam_vit_l, 12 | build_sam_vit_b, 13 | sam_model_registry, 14 | ) 15 | from .predictor import SamPredictor 16 | from .automatic_mask_generator import SamAutomaticMaskGenerator 17 | -------------------------------------------------------------------------------- /segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | from functools import partial 8 | from pathlib import Path 9 | import urllib.request 10 | import torch 11 | 12 | from .modeling import ( 13 | ImageEncoderViT, 14 | MaskDecoder, 15 | PromptEncoder, 16 | Sam, 17 | TwoWayTransformer, 18 | ) 19 | 20 | 21 | def build_sam_vit_h(checkpoint=None): 22 | return _build_sam( 23 | encoder_embed_dim=1280, 24 | encoder_depth=32, 25 | encoder_num_heads=16, 26 | encoder_global_attn_indexes=[7, 15, 23, 31], 27 | checkpoint=checkpoint, 28 | ) 29 | 30 | 31 | build_sam = build_sam_vit_h 32 | 33 | 34 | def build_sam_vit_l(checkpoint=None): 35 | return _build_sam( 36 | encoder_embed_dim=1024, 37 | encoder_depth=24, 38 | encoder_num_heads=16, 39 | encoder_global_attn_indexes=[5, 11, 17, 23], 40 | checkpoint=checkpoint, 41 | ) 42 | 43 | 44 | def build_sam_vit_b(checkpoint=None): 45 | return _build_sam( 46 | encoder_embed_dim=768, 47 | encoder_depth=12, 48 | encoder_num_heads=12, 49 | encoder_global_attn_indexes=[2, 5, 8, 11], 50 | checkpoint=checkpoint, 51 | ) 52 | 53 | 54 | sam_model_registry = { 55 | "default": build_sam_vit_h, 56 | "vit_h": build_sam_vit_h, 57 | "vit_l": build_sam_vit_l, 58 | "vit_b": build_sam_vit_b, 59 | } 60 | 61 | 62 | def _build_sam( 63 | encoder_embed_dim, 64 | encoder_depth, 65 | encoder_num_heads, 66 | encoder_global_attn_indexes, 67 | checkpoint=None, 68 | ): 69 | prompt_embed_dim = 256 70 | image_size = 1024 71 | vit_patch_size = 16 72 | image_embedding_size = image_size // vit_patch_size 73 | sam = Sam( 74 | image_encoder=ImageEncoderViT( 75 | depth=encoder_depth, 76 | embed_dim=encoder_embed_dim, 77 | img_size=image_size, 78 | mlp_ratio=4, 79 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 80 | num_heads=encoder_num_heads, 81 | patch_size=vit_patch_size, 82 | qkv_bias=True, 83 | use_rel_pos=True, 84 | global_attn_indexes=encoder_global_attn_indexes, 85 | window_size=14, 86 | out_chans=prompt_embed_dim, 87 | ), 88 | prompt_encoder=PromptEncoder( 89 | embed_dim=prompt_embed_dim, 90 | image_embedding_size=(image_embedding_size, image_embedding_size), 91 | input_image_size=(image_size, image_size), 92 | mask_in_chans=16, 93 | ), 94 | mask_decoder=MaskDecoder( 95 | num_multimask_outputs=3, 96 | transformer=TwoWayTransformer( 97 | depth=2, 98 | embedding_dim=prompt_embed_dim, 99 | mlp_dim=2048, 100 | num_heads=8, 101 | ), 102 | transformer_dim=prompt_embed_dim, 103 | iou_head_depth=3, 104 | iou_head_hidden_dim=256, 105 | ), 106 | pixel_mean=[123.675, 116.28, 103.53], 107 | pixel_std=[58.395, 57.12, 57.375], 108 | ) 109 | sam.eval() 110 | checkpoint = Path(checkpoint) 111 | if checkpoint.name == "sam_vit_b_01ec64.pth" and not checkpoint.exists(): 112 | cmd = input("Download sam_vit_b_01ec64.pth from facebook AI? [y]/n: ") 113 | if len(cmd) == 0 or cmd.lower() == "y": 114 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 115 | print("Downloading SAM ViT-B checkpoint...") 116 | urllib.request.urlretrieve( 117 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 118 | checkpoint, 119 | ) 120 | print(checkpoint.name, " is downloaded!") 121 | elif checkpoint.name == "sam_vit_h_4b8939.pth" and not checkpoint.exists(): 122 | cmd = input("Download sam_vit_h_4b8939.pth from facebook AI? 
[y]/n: ") 123 | if len(cmd) == 0 or cmd.lower() == "y": 124 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 125 | print("Downloading SAM ViT-H checkpoint...") 126 | urllib.request.urlretrieve( 127 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 128 | checkpoint, 129 | ) 130 | print(checkpoint.name, " is downloaded!") 131 | elif checkpoint.name == "sam_vit_l_0b3195.pth" and not checkpoint.exists(): 132 | cmd = input("Download sam_vit_l_0b3195.pth from facebook AI? [y]/n: ") 133 | if len(cmd) == 0 or cmd.lower() == "y": 134 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 135 | print("Downloading SAM ViT-L checkpoint...") 136 | urllib.request.urlretrieve( 137 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 138 | checkpoint, 139 | ) 140 | print(checkpoint.name, " is downloaded!") 141 | 142 | if checkpoint is not None: 143 | with open(checkpoint, "rb") as f: 144 | state_dict = torch.load(f, map_location=torch.device('cpu')) 145 | sam.load_state_dict(state_dict) 146 | return sam 147 | -------------------------------------------------------------------------------- /segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from .sam import Sam 9 | from .image_encoder import ImageEncoderViT 10 | from .mask_decoder import MaskDecoder 11 | from .prompt_encoder import PromptEncoder 12 | from .transformer import TwoWayTransformer 13 | -------------------------------------------------------------------------------- /segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from typing import Type 12 | 13 | 14 | class MLPBlock(nn.Module): 15 | def __init__( 16 | self, 17 | embedding_dim: int, 18 | mlp_dim: int, 19 | act: Type[nn.Module] = nn.GELU, 20 | ) -> None: 21 | super().__init__() 22 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 23 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 24 | self.act = act() 25 | 26 | def forward(self, x: torch.Tensor) -> torch.Tensor: 27 | return self.lin2(self.act(self.lin1(x))) 28 | 29 | 30 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 31 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 32 | class LayerNorm2d(nn.Module): 33 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 34 | super().__init__() 35 | self.weight = nn.Parameter(torch.ones(num_channels)) 36 | self.bias = nn.Parameter(torch.zeros(num_channels)) 37 | self.eps = eps 38 | 39 | def forward(self, x: torch.Tensor) -> torch.Tensor: 40 | u = x.mean(1, keepdim=True) 41 | s = (x - u).pow(2).mean(1, keepdim=True) 42 | x = (x - u) / torch.sqrt(s + self.eps) 43 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 44 | return x 45 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from typing import List, Tuple, Type 13 | 14 | from .common import LayerNorm2d 15 | 16 | 17 | class MaskDecoder(nn.Module): 18 | def __init__( 19 | self, 20 | *, 21 | transformer_dim: int, 22 | transformer: nn.Module, 23 | num_multimask_outputs: int = 3, 24 | activation: Type[nn.Module] = nn.GELU, 25 | iou_head_depth: int = 3, 26 | iou_head_hidden_dim: int = 256, 27 | ) -> None: 28 | """ 29 | Predicts masks given an image and prompt embeddings, using a 30 | transformer architecture. 
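        Illustrative call sketch (not part of the original docstring; the
        tensors img_embed, sparse, and dense are hypothetical, and the shapes
        assume the ViT-B setup used elsewhere in this repo, i.e.
        transformer_dim=256 with a 64x64 image embedding):

            low_res_logits, iou_pred = mask_decoder(
                image_embeddings=img_embed,              # (B, 256, 64, 64)
                image_pe=prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
                sparse_prompt_embeddings=sparse,         # (B, N, 256)
                dense_prompt_embeddings=dense,           # (B, 256, 64, 64)
                multimask_output=False,
            )
            # low_res_logits: (B, 1, 256, 256), iou_pred: (B, 1)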
31 | 32 | Arguments: 33 | transformer_dim (int): the channel dimension of the transformer 34 | transformer (nn.Module): the transformer used to predict masks 35 | num_multimask_outputs (int): the number of masks to predict 36 | when disambiguating masks 37 | activation (nn.Module): the type of activation to use when 38 | upscaling masks 39 | iou_head_depth (int): the depth of the MLP used to predict 40 | mask quality 41 | iou_head_hidden_dim (int): the hidden dimension of the MLP 42 | used to predict mask quality 43 | """ 44 | super().__init__() 45 | self.transformer_dim = transformer_dim 46 | self.transformer = transformer 47 | 48 | self.num_multimask_outputs = num_multimask_outputs 49 | 50 | self.iou_token = nn.Embedding(1, transformer_dim) 51 | self.num_mask_tokens = num_multimask_outputs + 1 52 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 53 | 54 | self.output_upscaling = nn.Sequential( 55 | nn.ConvTranspose2d( 56 | transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 57 | ), 58 | LayerNorm2d(transformer_dim // 4), 59 | activation(), 60 | nn.ConvTranspose2d( 61 | transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 62 | ), 63 | activation(), 64 | ) 65 | self.output_hypernetworks_mlps = nn.ModuleList( 66 | [ 67 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 68 | for i in range(self.num_mask_tokens) 69 | ] 70 | ) 71 | 72 | self.iou_prediction_head = MLP( 73 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 74 | ) 75 | 76 | def forward( 77 | self, 78 | image_embeddings: torch.Tensor, 79 | image_pe: torch.Tensor, 80 | sparse_prompt_embeddings: torch.Tensor, 81 | dense_prompt_embeddings: torch.Tensor, 82 | multimask_output: bool, 83 | ) -> Tuple[torch.Tensor, torch.Tensor]: 84 | """ 85 | Predict masks given image and prompt embeddings. 86 | 87 | Arguments: 88 | image_embeddings (torch.Tensor): the embeddings from the image encoder 89 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 90 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 91 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 92 | multimask_output (bool): Whether to return multiple masks or a single 93 | mask. 94 | 95 | Returns: 96 | torch.Tensor: batched predicted masks 97 | torch.Tensor: batched predictions of mask quality 98 | """ 99 | masks, iou_pred = self.predict_masks( 100 | image_embeddings=image_embeddings, 101 | image_pe=image_pe, 102 | sparse_prompt_embeddings=sparse_prompt_embeddings, 103 | dense_prompt_embeddings=dense_prompt_embeddings, 104 | ) 105 | 106 | # Select the correct mask or masks for output 107 | if multimask_output: 108 | mask_slice = slice(1, None) 109 | else: 110 | mask_slice = slice(0, 1) 111 | masks = masks[:, mask_slice, :, :] 112 | iou_pred = iou_pred[:, mask_slice] 113 | 114 | # Prepare output 115 | return masks, iou_pred 116 | 117 | def predict_masks( 118 | self, 119 | image_embeddings: torch.Tensor, 120 | image_pe: torch.Tensor, 121 | sparse_prompt_embeddings: torch.Tensor, 122 | dense_prompt_embeddings: torch.Tensor, 123 | ) -> Tuple[torch.Tensor, torch.Tensor]: 124 | """Predicts masks. 
See 'forward' for more details.""" 125 | # Concatenate output tokens 126 | output_tokens = torch.cat( 127 | [self.iou_token.weight, self.mask_tokens.weight], dim=0 128 | ) 129 | output_tokens = output_tokens.unsqueeze(0).expand( 130 | sparse_prompt_embeddings.size(0), -1, -1 131 | ) 132 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 133 | 134 | # Expand per-image data in batch direction to be per-mask 135 | if image_embeddings.shape[0] != tokens.shape[0]: 136 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 137 | else: 138 | src = image_embeddings 139 | src = src + dense_prompt_embeddings 140 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 141 | b, c, h, w = src.shape 142 | 143 | # Run the transformer 144 | hs, src = self.transformer(src, pos_src, tokens) 145 | iou_token_out = hs[:, 0, :] 146 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 147 | 148 | # Upscale mask embeddings and predict masks using the mask tokens 149 | src = src.transpose(1, 2).view(b, c, h, w) 150 | upscaled_embedding = self.output_upscaling(src) 151 | hyper_in_list: List[torch.Tensor] = [] 152 | for i in range(self.num_mask_tokens): 153 | hyper_in_list.append( 154 | self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) 155 | ) 156 | hyper_in = torch.stack(hyper_in_list, dim=1) 157 | b, c, h, w = upscaled_embedding.shape 158 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 159 | 160 | # Generate mask quality predictions 161 | iou_pred = self.iou_prediction_head(iou_token_out) 162 | 163 | return masks, iou_pred 164 | 165 | 166 | # Lightly adapted from 167 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 168 | class MLP(nn.Module): 169 | def __init__( 170 | self, 171 | input_dim: int, 172 | hidden_dim: int, 173 | output_dim: int, 174 | num_layers: int, 175 | sigmoid_output: bool = False, 176 | ) -> None: 177 | super().__init__() 178 | self.num_layers = num_layers 179 | h = [hidden_dim] * (num_layers - 1) 180 | self.layers = nn.ModuleList( 181 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 182 | ) 183 | self.sigmoid_output = sigmoid_output 184 | 185 | def forward(self, x): 186 | for i, layer in enumerate(self.layers): 187 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 188 | if self.sigmoid_output: 189 | x = F.sigmoid(x) 190 | return x 191 | -------------------------------------------------------------------------------- /segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | 12 | from typing import Any, Optional, Tuple, Type 13 | 14 | from .common import LayerNorm2d 15 | 16 | 17 | class PromptEncoder(nn.Module): 18 | def __init__( 19 | self, 20 | embed_dim: int, 21 | image_embedding_size: Tuple[int, int], 22 | input_image_size: Tuple[int, int], 23 | mask_in_chans: int, 24 | activation: Type[nn.Module] = nn.GELU, 25 | ) -> None: 26 | """ 27 | Encodes prompts for input to SAM's mask decoder. 
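        Illustrative call sketch (not part of the original docstring; the
        box tensor is hypothetical, and the shapes assume embed_dim=256 with
        a 64x64 image embedding, matching how this repo calls the encoder):

            sparse, dense = prompt_encoder(
                points=None,
                boxes=box_torch,  # (B, 1, 4), xyxy in the 1024x1024 input frame
                masks=None,
            )
            # sparse: (B, 2, 256)       one embedding per box corner
            # dense:  (B, 256, 64, 64)  the learned "no mask" embedding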
28 | 29 | Arguments: 30 | embed_dim (int): The prompts' embedding dimension 31 | image_embedding_size (tuple(int, int)): The spatial size of the 32 | image embedding, as (H, W). 33 | input_image_size (int): The padded size of the image as input 34 | to the image encoder, as (H, W). 35 | mask_in_chans (int): The number of hidden channels used for 36 | encoding input masks. 37 | activation (nn.Module): The activation to use when encoding 38 | input masks. 39 | """ 40 | super().__init__() 41 | self.embed_dim = embed_dim 42 | self.input_image_size = input_image_size 43 | self.image_embedding_size = image_embedding_size 44 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 45 | 46 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 47 | point_embeddings = [ 48 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) 49 | ] 50 | self.point_embeddings = nn.ModuleList(point_embeddings) 51 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 52 | 53 | self.mask_input_size = ( 54 | 4 * image_embedding_size[0], 55 | 4 * image_embedding_size[1], 56 | ) 57 | self.mask_downscaling = nn.Sequential( 58 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 59 | LayerNorm2d(mask_in_chans // 4), 60 | activation(), 61 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 62 | LayerNorm2d(mask_in_chans), 63 | activation(), 64 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 65 | ) 66 | self.no_mask_embed = nn.Embedding(1, embed_dim) 67 | 68 | def get_dense_pe(self) -> torch.Tensor: 69 | """ 70 | Returns the positional encoding used to encode point prompts, 71 | applied to a dense set of points the shape of the image encoding. 72 | 73 | Returns: 74 | torch.Tensor: Positional encoding with shape 75 | 1x(embed_dim)x(embedding_h)x(embedding_w) 76 | """ 77 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 78 | 79 | def _embed_points( 80 | self, 81 | points: torch.Tensor, 82 | labels: torch.Tensor, 83 | pad: bool, 84 | ) -> torch.Tensor: 85 | """Embeds point prompts.""" 86 | points = points + 0.5 # Shift to center of pixel 87 | if pad: 88 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 89 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 90 | points = torch.cat([points, padding_point], dim=1) 91 | labels = torch.cat([labels, padding_label], dim=1) 92 | point_embedding = self.pe_layer.forward_with_coords( 93 | points, self.input_image_size 94 | ) 95 | point_embedding[labels == -1] = 0.0 96 | point_embedding[labels == -1] += self.not_a_point_embed.weight 97 | point_embedding[labels == 0] += self.point_embeddings[0].weight 98 | point_embedding[labels == 1] += self.point_embeddings[1].weight 99 | return point_embedding 100 | 101 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 102 | """Embeds box prompts.""" 103 | boxes = boxes + 0.5 # Shift to center of pixel 104 | coords = boxes.reshape(-1, 2, 2) 105 | corner_embedding = self.pe_layer.forward_with_coords( 106 | coords, self.input_image_size 107 | ) 108 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 109 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 110 | return corner_embedding 111 | 112 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 113 | """Embeds mask inputs.""" 114 | mask_embedding = self.mask_downscaling(masks) 115 | return mask_embedding 116 | 117 | def _get_batch_size( 118 | self, 119 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 120 | boxes: 
Optional[torch.Tensor], 121 | masks: Optional[torch.Tensor], 122 | ) -> int: 123 | """ 124 | Gets the batch size of the output given the batch size of the input prompts. 125 | """ 126 | if points is not None: 127 | return points[0].shape[0] 128 | elif boxes is not None: 129 | return boxes.shape[0] 130 | elif masks is not None: 131 | return masks.shape[0] 132 | else: 133 | return 1 134 | 135 | def _get_device(self) -> torch.device: 136 | return self.point_embeddings[0].weight.device 137 | 138 | def forward( 139 | self, 140 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 141 | boxes: Optional[torch.Tensor], 142 | masks: Optional[torch.Tensor], 143 | ) -> Tuple[torch.Tensor, torch.Tensor]: 144 | """ 145 | Embeds different types of prompts, returning both sparse and dense 146 | embeddings. 147 | 148 | Arguments: 149 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 150 | and labels to embed. 151 | boxes (torch.Tensor or none): boxes to embed 152 | masks (torch.Tensor or none): masks to embed 153 | 154 | Returns: 155 | torch.Tensor: sparse embeddings for the points and boxes, with shape 156 | BxNx(embed_dim), where N is determined by the number of input points 157 | and boxes. 158 | torch.Tensor: dense embeddings for the masks, in the shape 159 | Bx(embed_dim)x(embed_H)x(embed_W) 160 | """ 161 | bs = self._get_batch_size(points, boxes, masks) 162 | sparse_embeddings = torch.empty( 163 | (bs, 0, self.embed_dim), device=self._get_device() 164 | ) 165 | if points is not None: 166 | coords, labels = points 167 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 168 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 169 | if boxes is not None: 170 | box_embeddings = self._embed_boxes(boxes) 171 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 172 | 173 | if masks is not None: 174 | dense_embeddings = self._embed_masks(masks) 175 | else: 176 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 177 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 178 | ) 179 | 180 | return sparse_embeddings, dense_embeddings 181 | 182 | 183 | class PositionEmbeddingRandom(nn.Module): 184 | """ 185 | Positional encoding using random spatial frequencies. 186 | """ 187 | 188 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 189 | super().__init__() 190 | if scale is None or scale <= 0.0: 191 | scale = 1.0 192 | self.register_buffer( 193 | "positional_encoding_gaussian_matrix", 194 | scale * torch.randn((2, num_pos_feats)), 195 | ) 196 | 197 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 198 | """Positionally encode points that are normalized to [0,1].""" 199 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 200 | coords = 2 * coords - 1 201 | coords = coords @ self.positional_encoding_gaussian_matrix 202 | coords = 2 * np.pi * coords 203 | # outputs d_1 x ... 
x d_n x C shape 204 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 205 | 206 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 207 | """Generate positional encoding for a grid of the specified size.""" 208 | h, w = size 209 | device: Any = self.positional_encoding_gaussian_matrix.device 210 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 211 | y_embed = grid.cumsum(dim=0) - 0.5 212 | x_embed = grid.cumsum(dim=1) - 0.5 213 | y_embed = y_embed / h 214 | x_embed = x_embed / w 215 | 216 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 217 | return pe.permute(2, 0, 1) # C x H x W 218 | 219 | def forward_with_coords( 220 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 221 | ) -> torch.Tensor: 222 | """Positionally encode points that are not normalized to [0,1].""" 223 | coords = coords_input.clone() 224 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 225 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 226 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 227 | -------------------------------------------------------------------------------- /segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from typing import Any, Dict, List, Tuple 13 | 14 | from .image_encoder import ImageEncoderViT 15 | from .mask_decoder import MaskDecoder 16 | from .prompt_encoder import PromptEncoder 17 | 18 | 19 | class Sam(nn.Module): 20 | mask_threshold: float = 0.0 21 | image_format: str = "RGB" 22 | 23 | def __init__( 24 | self, 25 | image_encoder: ImageEncoderViT, 26 | prompt_encoder: PromptEncoder, 27 | mask_decoder: MaskDecoder, 28 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 29 | pixel_std: List[float] = [58.395, 57.12, 57.375], 30 | ) -> None: 31 | """ 32 | SAM predicts object masks from an image and input prompts. 33 | 34 | Arguments: 35 | image_encoder (ImageEncoderViT): The backbone used to encode the 36 | image into image embeddings that allow for efficient mask prediction. 37 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 38 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 39 | and encoded prompts. 40 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 41 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 42 | """ 43 | super().__init__() 44 | self.image_encoder = image_encoder 45 | self.prompt_encoder = prompt_encoder 46 | self.mask_decoder = mask_decoder 47 | self.register_buffer( 48 | "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False 49 | ) 50 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 51 | 52 | @property 53 | def device(self) -> Any: 54 | return self.pixel_mean.device 55 | 56 | @torch.no_grad() 57 | def forward( 58 | self, 59 | batched_input: List[Dict[str, Any]], 60 | multimask_output: bool, 61 | ) -> List[Dict[str, torch.Tensor]]: 62 | """ 63 | Predicts masks end-to-end from provided images and prompts. 64 | If prompts are not known in advance, using SamPredictor is 65 | recommended over calling the model directly. 
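        Illustrative sketch of a batched_input entry (not part of the
        original docstring; image, box, H, and W are hypothetical):

            batched_input = [
                {
                    "image": image,           # (3, H', W') tensor, already resized for the model
                    "original_size": (H, W),  # size before the resize
                    "boxes": box,             # (B, 4) boxes in the input frame
                }
            ]
            outputs = sam(batched_input, multimask_output=False)
            masks = outputs[0]["masks"]       # (B, 1, H, W) boolean masks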
66 | 67 | Arguments: 68 | batched_input (list(dict)): A list over input images, each a 69 | dictionary with the following keys. A prompt key can be 70 | excluded if it is not present. 71 | 'image': The image as a torch tensor in 3xHxW format, 72 | already transformed for input to the model. 73 | 'original_size': (tuple(int, int)) The original size of 74 | the image before transformation, as (H, W). 75 | 'point_coords': (torch.Tensor) Batched point prompts for 76 | this image, with shape BxNx2. Already transformed to the 77 | input frame of the model. 78 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 79 | with shape BxN. 80 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 81 | Already transformed to the input frame of the model. 82 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 83 | in the form Bx1xHxW. 84 | multimask_output (bool): Whether the model should predict multiple 85 | disambiguating masks, or return a single mask. 86 | 87 | Returns: 88 | (list(dict)): A list over input images, where each element is 89 | as dictionary with the following keys. 90 | 'masks': (torch.Tensor) Batched binary mask predictions, 91 | with shape BxCxHxW, where B is the number of input prompts, 92 | C is determined by multimask_output, and (H, W) is the 93 | original size of the image. 94 | 'iou_predictions': (torch.Tensor) The model's predictions 95 | of mask quality, in shape BxC. 96 | 'low_res_logits': (torch.Tensor) Low resolution logits with 97 | shape BxCxHxW, where H=W=256. Can be passed as mask input 98 | to subsequent iterations of prediction. 99 | """ 100 | input_images = torch.stack( 101 | [self.preprocess(x["image"]) for x in batched_input], dim=0 102 | ) 103 | image_embeddings = self.image_encoder(input_images) 104 | 105 | outputs = [] 106 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 107 | if "point_coords" in image_record: 108 | points = (image_record["point_coords"], image_record["point_labels"]) 109 | else: 110 | points = None 111 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 112 | points=points, 113 | boxes=image_record.get("boxes", None), 114 | masks=image_record.get("mask_inputs", None), 115 | ) 116 | low_res_masks, iou_predictions = self.mask_decoder( 117 | image_embeddings=curr_embedding.unsqueeze(0), 118 | image_pe=self.prompt_encoder.get_dense_pe(), 119 | sparse_prompt_embeddings=sparse_embeddings, 120 | dense_prompt_embeddings=dense_embeddings, 121 | multimask_output=multimask_output, 122 | ) 123 | masks = self.postprocess_masks( 124 | low_res_masks, 125 | input_size=image_record["image"].shape[-2:], 126 | original_size=image_record["original_size"], 127 | ) 128 | masks = masks > self.mask_threshold 129 | outputs.append( 130 | { 131 | "masks": masks, 132 | "iou_predictions": iou_predictions, 133 | "low_res_logits": low_res_masks, 134 | } 135 | ) 136 | return outputs 137 | 138 | def postprocess_masks( 139 | self, 140 | masks: torch.Tensor, 141 | input_size: Tuple[int, ...], 142 | original_size: Tuple[int, ...], 143 | ) -> torch.Tensor: 144 | """ 145 | Remove padding and upscale masks to the original image size. 146 | 147 | Arguments: 148 | masks (torch.Tensor): Batched masks from the mask_decoder, 149 | in BxCxHxW format. 150 | input_size (tuple(int, int)): The size of the image input to the 151 | model, in (H, W) format. Used to remove padding. 152 | original_size (tuple(int, int)): The original size of the image 153 | before resizing for input to the model, in (H, W) format. 
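            Worked example (illustrative, not part of the original docstring):
            a 720x1280 image resized so that its longest side is 1024 enters
            the model as a 576x1024 input padded to 1024x1024, so the call
            would be postprocess_masks(low_res_masks, input_size=(576, 1024),
            original_size=(720, 1280)), which crops away the padding and
            upscales the masks back to 720x1280.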
154 | 155 | Returns: 156 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 157 | is given by original_size. 158 | """ 159 | masks = F.interpolate( 160 | masks, 161 | (self.image_encoder.img_size, self.image_encoder.img_size), 162 | mode="bilinear", 163 | align_corners=False, 164 | ) 165 | masks = masks[..., : input_size[0], : input_size[1]] 166 | masks = F.interpolate( 167 | masks, original_size, mode="bilinear", align_corners=False 168 | ) 169 | return masks 170 | 171 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 172 | """Normalize pixel values and pad to a square input.""" 173 | # Normalize colors 174 | x = (x - self.pixel_mean) / self.pixel_std 175 | 176 | # Pad 177 | h, w = x.shape[-2:] 178 | padh = self.image_encoder.img_size - h 179 | padw = self.image_encoder.img_size - w 180 | x = F.pad(x, (0, padw, 0, padh)) 181 | return x 182 | -------------------------------------------------------------------------------- /segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | from torch import Tensor, nn 10 | 11 | import math 12 | from typing import Tuple, Type 13 | 14 | from .common import MLPBlock 15 | 16 | 17 | class TwoWayTransformer(nn.Module): 18 | def __init__( 19 | self, 20 | depth: int, 21 | embedding_dim: int, 22 | num_heads: int, 23 | mlp_dim: int, 24 | activation: Type[nn.Module] = nn.ReLU, 25 | attention_downsample_rate: int = 2, 26 | ) -> None: 27 | """ 28 | A transformer decoder that attends to an input image using 29 | queries whose positional embedding is supplied. 30 | 31 | Args: 32 | depth (int): number of layers in the transformer 33 | embedding_dim (int): the channel dimension for the input embeddings 34 | num_heads (int): the number of heads for multihead attention. Must 35 | divide embedding_dim 36 | mlp_dim (int): the channel dimension internal to the MLP block 37 | activation (nn.Module): the activation to use in the MLP block 38 | """ 39 | super().__init__() 40 | self.depth = depth 41 | self.embedding_dim = embedding_dim 42 | self.num_heads = num_heads 43 | self.mlp_dim = mlp_dim 44 | self.layers = nn.ModuleList() 45 | 46 | for i in range(depth): 47 | self.layers.append( 48 | TwoWayAttentionBlock( 49 | embedding_dim=embedding_dim, 50 | num_heads=num_heads, 51 | mlp_dim=mlp_dim, 52 | activation=activation, 53 | attention_downsample_rate=attention_downsample_rate, 54 | skip_first_layer_pe=(i == 0), 55 | ) 56 | ) 57 | 58 | self.final_attn_token_to_image = Attention( 59 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 60 | ) 61 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 62 | 63 | def forward( 64 | self, 65 | image_embedding: Tensor, 66 | image_pe: Tensor, 67 | point_embedding: Tensor, 68 | ) -> Tuple[Tensor, Tensor]: 69 | """ 70 | Args: 71 | image_embedding (torch.Tensor): image to attend to. Should be shape 72 | B x embedding_dim x h x w for any h and w. 73 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 74 | have the same shape as image_embedding. 75 | point_embedding (torch.Tensor): the embedding to add to the query points. 76 | Must have shape B x N_points x embedding_dim for any N_points. 
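          Illustrative shape sketch (not part of the original docstring;
          the tensors are hypothetical and use the default decoder sizes):

            image_embedding: (B, 256, 64, 64)
            image_pe:        (B, 256, 64, 64)
            point_embedding: (B, N_tokens, 256)
            queries, keys = transformer(image_embedding, image_pe, point_embedding)
            # queries: (B, N_tokens, 256)  processed token embeddings
            # keys:    (B, 4096, 256)      processed image embedding (64*64 tokens)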
77 | 78 | Returns: 79 | torch.Tensor: the processed point_embedding 80 | torch.Tensor: the processed image_embedding 81 | """ 82 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 83 | bs, c, h, w = image_embedding.shape 84 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 85 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 86 | 87 | # Prepare queries 88 | queries = point_embedding 89 | keys = image_embedding 90 | 91 | # Apply transformer blocks and final layernorm 92 | for layer in self.layers: 93 | queries, keys = layer( 94 | queries=queries, 95 | keys=keys, 96 | query_pe=point_embedding, 97 | key_pe=image_pe, 98 | ) 99 | 100 | # Apply the final attention layer from the points to the image 101 | q = queries + point_embedding 102 | k = keys + image_pe 103 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 104 | queries = queries + attn_out 105 | queries = self.norm_final_attn(queries) 106 | 107 | return queries, keys 108 | 109 | 110 | class TwoWayAttentionBlock(nn.Module): 111 | def __init__( 112 | self, 113 | embedding_dim: int, 114 | num_heads: int, 115 | mlp_dim: int = 2048, 116 | activation: Type[nn.Module] = nn.ReLU, 117 | attention_downsample_rate: int = 2, 118 | skip_first_layer_pe: bool = False, 119 | ) -> None: 120 | """ 121 | A transformer block with four layers: (1) self-attention of sparse 122 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 123 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 124 | inputs. 125 | 126 | Arguments: 127 | embedding_dim (int): the channel dimension of the embeddings 128 | num_heads (int): the number of heads in the attention layers 129 | mlp_dim (int): the hidden dimension of the mlp block 130 | activation (nn.Module): the activation of the mlp block 131 | skip_first_layer_pe (bool): skip the PE on the first layer 132 | """ 133 | super().__init__() 134 | self.self_attn = Attention(embedding_dim, num_heads) 135 | self.norm1 = nn.LayerNorm(embedding_dim) 136 | 137 | self.cross_attn_token_to_image = Attention( 138 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 139 | ) 140 | self.norm2 = nn.LayerNorm(embedding_dim) 141 | 142 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 143 | self.norm3 = nn.LayerNorm(embedding_dim) 144 | 145 | self.norm4 = nn.LayerNorm(embedding_dim) 146 | self.cross_attn_image_to_token = Attention( 147 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 148 | ) 149 | 150 | self.skip_first_layer_pe = skip_first_layer_pe 151 | 152 | def forward( 153 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 154 | ) -> Tuple[Tensor, Tensor]: 155 | # Self attention block 156 | if self.skip_first_layer_pe: 157 | queries = self.self_attn(q=queries, k=queries, v=queries) 158 | else: 159 | q = queries + query_pe 160 | attn_out = self.self_attn(q=q, k=q, v=queries) 161 | queries = queries + attn_out 162 | queries = self.norm1(queries) 163 | 164 | # Cross attention block, tokens attending to image embedding 165 | q = queries + query_pe 166 | k = keys + key_pe 167 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 168 | queries = queries + attn_out 169 | queries = self.norm2(queries) 170 | 171 | # MLP block 172 | mlp_out = self.mlp(queries) 173 | queries = queries + mlp_out 174 | queries = self.norm3(queries) 175 | 176 | # Cross attention block, image embedding attending to tokens 177 | q = queries + query_pe 178 | k = keys + key_pe 179 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 180 | keys = keys + attn_out 181 | keys = self.norm4(keys) 182 | 183 | return queries, keys 184 | 185 | 186 | class Attention(nn.Module): 187 | """ 188 | An attention layer that allows for downscaling the size of the embedding 189 | after projection to queries, keys, and values. 190 | """ 191 | 192 | def __init__( 193 | self, 194 | embedding_dim: int, 195 | num_heads: int, 196 | downsample_rate: int = 1, 197 | ) -> None: 198 | super().__init__() 199 | self.embedding_dim = embedding_dim 200 | self.internal_dim = embedding_dim // downsample_rate 201 | self.num_heads = num_heads 202 | assert ( 203 | self.internal_dim % num_heads == 0 204 | ), "num_heads must divide embedding_dim." 205 | 206 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 207 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 208 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 209 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 210 | 211 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 212 | b, n, c = x.shape 213 | x = x.reshape(b, n, num_heads, c // num_heads) 214 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 215 | 216 | def _recombine_heads(self, x: Tensor) -> Tensor: 217 | b, n_heads, n_tokens, c_per_head = x.shape 218 | x = x.transpose(1, 2) 219 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 220 | 221 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 222 | # Input projections 223 | q = self.q_proj(q) 224 | k = self.k_proj(k) 225 | v = self.v_proj(v) 226 | 227 | # Separate into heads 228 | q = self._separate_heads(q, self.num_heads) 229 | k = self._separate_heads(k, self.num_heads) 230 | v = self._separate_heads(v, self.num_heads) 231 | 232 | # Attention 233 | _, _, _, c_per_head = q.shape 234 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 235 | attn = attn / math.sqrt(c_per_head) 236 | attn = torch.softmax(attn, dim=-1) 237 | 238 | # Get output 239 | out = attn @ v 240 | out = self._recombine_heads(out) 241 | out = self.out_proj(out) 242 | 243 | return out 244 | -------------------------------------------------------------------------------- /segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | -------------------------------------------------------------------------------- /segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import functional as F 11 | 12 | from typing import Tuple 13 | 14 | from ..modeling import Sam 15 | from .amg import calculate_stability_score 16 | 17 | 18 | class SamOnnxModel(nn.Module): 19 | """ 20 | This model should not be called directly, but is used in ONNX export. 21 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 22 | with some functions modified to enable model tracing. Also supports extra 23 | options controlling what information. 
See the ONNX export script for details. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | model: Sam, 29 | return_single_mask: bool, 30 | use_stability_score: bool = False, 31 | return_extra_metrics: bool = False, 32 | ) -> None: 33 | super().__init__() 34 | self.mask_decoder = model.mask_decoder 35 | self.model = model 36 | self.img_size = model.image_encoder.img_size 37 | self.return_single_mask = return_single_mask 38 | self.use_stability_score = use_stability_score 39 | self.stability_score_offset = 1.0 40 | self.return_extra_metrics = return_extra_metrics 41 | 42 | @staticmethod 43 | def resize_longest_image_size( 44 | input_image_size: torch.Tensor, longest_side: int 45 | ) -> torch.Tensor: 46 | input_image_size = input_image_size.to(torch.float32) 47 | scale = longest_side / torch.max(input_image_size) 48 | transformed_size = scale * input_image_size 49 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 50 | return transformed_size 51 | 52 | def _embed_points( 53 | self, point_coords: torch.Tensor, point_labels: torch.Tensor 54 | ) -> torch.Tensor: 55 | point_coords = point_coords + 0.5 56 | point_coords = point_coords / self.img_size 57 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 58 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 59 | 60 | point_embedding = point_embedding * (point_labels != -1) 61 | point_embedding = ( 62 | point_embedding 63 | + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1) 64 | ) 65 | 66 | for i in range(self.model.prompt_encoder.num_point_embeddings): 67 | point_embedding = ( 68 | point_embedding 69 | + self.model.prompt_encoder.point_embeddings[i].weight 70 | * (point_labels == i) 71 | ) 72 | 73 | return point_embedding 74 | 75 | def _embed_masks( 76 | self, input_mask: torch.Tensor, has_mask_input: torch.Tensor 77 | ) -> torch.Tensor: 78 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling( 79 | input_mask 80 | ) 81 | mask_embedding = mask_embedding + ( 82 | 1 - has_mask_input 83 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 84 | return mask_embedding 85 | 86 | def mask_postprocessing( 87 | self, masks: torch.Tensor, orig_im_size: torch.Tensor 88 | ) -> torch.Tensor: 89 | masks = F.interpolate( 90 | masks, 91 | size=(self.img_size, self.img_size), 92 | mode="bilinear", 93 | align_corners=False, 94 | ) 95 | 96 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to( 97 | torch.int64 98 | ) 99 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 100 | 101 | orig_im_size = orig_im_size.to(torch.int64) 102 | h, w = orig_im_size[0], orig_im_size[1] 103 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 104 | return masks 105 | 106 | def select_masks( 107 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 108 | ) -> Tuple[torch.Tensor, torch.Tensor]: 109 | # Determine if we should return the multiclick mask or not from the number of points. 110 | # The reweighting is used to avoid control flow. 
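        # Added explanatory note (illustrative, not in the original source):
        # score_reweight below is [[1000, 0, ..., 0]], so index 0 (the
        # single-output mask token) receives a large positive boost when
        # num_points > 2.5 and a large negative penalty when num_points <= 2.5.
        # torch.argmax therefore selects either the single-output mask or the
        # best of the multimask candidates purely through arithmetic, with no
        # data-dependent if-branch, which keeps the exported ONNX graph
        # traceable.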
111 | score_reweight = torch.tensor( 112 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 113 | ).to(iou_preds.device) 114 | score = iou_preds + (num_points - 2.5) * score_reweight 115 | best_idx = torch.argmax(score, dim=1) 116 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 117 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 118 | 119 | return masks, iou_preds 120 | 121 | @torch.no_grad() 122 | def forward( 123 | self, 124 | image_embeddings: torch.Tensor, 125 | point_coords: torch.Tensor, 126 | point_labels: torch.Tensor, 127 | mask_input: torch.Tensor, 128 | has_mask_input: torch.Tensor, 129 | orig_im_size: torch.Tensor, 130 | ): 131 | sparse_embedding = self._embed_points(point_coords, point_labels) 132 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 133 | 134 | masks, scores = self.model.mask_decoder.predict_masks( 135 | image_embeddings=image_embeddings, 136 | image_pe=self.model.prompt_encoder.get_dense_pe(), 137 | sparse_prompt_embeddings=sparse_embedding, 138 | dense_prompt_embeddings=dense_embedding, 139 | ) 140 | 141 | if self.use_stability_score: 142 | scores = calculate_stability_score( 143 | masks, self.model.mask_threshold, self.stability_score_offset 144 | ) 145 | 146 | if self.return_single_mask: 147 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 148 | 149 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 150 | 151 | if self.return_extra_metrics: 152 | stability_scores = calculate_stability_score( 153 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 154 | ) 155 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 156 | return upscaled_masks, scores, stability_scores, areas, masks 157 | 158 | return upscaled_masks, scores, masks 159 | -------------------------------------------------------------------------------- /segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import numpy as np 9 | import torch 10 | from torch.nn import functional as F 11 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 12 | 13 | from copy import deepcopy 14 | from typing import Tuple 15 | 16 | 17 | class ResizeLongestSide: 18 | """ 19 | Resizes images to the longest side 'target_length', as well as provides 20 | methods for resizing coordinates and boxes. Provides methods for 21 | transforming both numpy array and batched torch tensors. 22 | """ 23 | 24 | def __init__(self, target_length: int) -> None: 25 | self.target_length = target_length 26 | 27 | def apply_image(self, image: np.ndarray) -> np.ndarray: 28 | """ 29 | Expects a numpy array with shape HxWxC in uint8 format. 30 | """ 31 | target_size = self.get_preprocess_shape( 32 | image.shape[0], image.shape[1], self.target_length 33 | ) 34 | return np.array(resize(to_pil_image(image), target_size)) 35 | 36 | def apply_coords( 37 | self, coords: np.ndarray, original_size: Tuple[int, ...] 38 | ) -> np.ndarray: 39 | """ 40 | Expects a numpy array of length 2 in the final dimension. Requires the 41 | original image size in (H, W) format. 
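        Worked example (illustrative, not part of the original docstring):
        with target_length=1024 and original_size=(512, 768), the resized
        size is (683, 1024), so a point at (x=384, y=256) maps to
        (512.0, 341.5) after apply_coords.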
42 | """ 43 | old_h, old_w = original_size 44 | new_h, new_w = self.get_preprocess_shape(old_h, old_w, self.target_length) 45 | new_coords = np.empty_like(coords) 46 | new_coords[..., 0] = coords[..., 0] * (new_w / old_w) 47 | new_coords[..., 1] = coords[..., 1] * (new_h / old_h) 48 | return new_coords 49 | 50 | def apply_boxes( 51 | self, boxes: np.ndarray, original_size: Tuple[int, ...] 52 | ) -> np.ndarray: 53 | """ 54 | Expects a numpy array shape Bx4. Requires the original image size 55 | in (H, W) format. 56 | """ 57 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 58 | return boxes.reshape(-1, 4) 59 | 60 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 61 | """ 62 | Expects batched images with shape BxCxHxW and float format. This 63 | transformation may not exactly match apply_image. apply_image is 64 | the transformation expected by the model. 65 | """ 66 | # Expects an image in BCHW format. May not exactly match apply_image. 67 | target_size = self.get_preprocess_shape( 68 | image.shape[2], image.shape[3], self.target_length 69 | ) 70 | return F.interpolate( 71 | image, target_size, mode="bilinear", align_corners=False, antialias=True 72 | ) 73 | 74 | def apply_coords_torch( 75 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 76 | ) -> torch.Tensor: 77 | """ 78 | Expects a torch tensor with length 2 in the last dimension. Requires the 79 | original image size in (H, W) format. 80 | """ 81 | old_h, old_w = original_size 82 | new_h, new_w = self.get_preprocess_shape( 83 | original_size[0], original_size[1], self.target_length 84 | ) 85 | coords = deepcopy(coords).to(torch.float) 86 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 87 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 88 | return coords 89 | 90 | def apply_boxes_torch( 91 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 92 | ) -> torch.Tensor: 93 | """ 94 | Expects a torch tensor with shape Bx4. Requires the original image 95 | size in (H, W) format. 96 | """ 97 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 98 | return boxes.reshape(-1, 4) 99 | 100 | @staticmethod 101 | def get_preprocess_shape( 102 | oldh: int, oldw: int, long_side_length: int 103 | ) -> Tuple[int, int]: 104 | """ 105 | Compute the output size given input size and target long side length. 106 | """ 107 | scale = long_side_length * 1.0 / max(oldh, oldw) 108 | newh, neww = oldh * scale, oldw * scale 109 | neww = int(neww + 0.5) 110 | newh = int(newh + 0.5) 111 | return (newh, neww) 112 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # Adapted from https://github.com/facebookresearch/segment-anything 7 | 8 | from setuptools import find_packages, setup 9 | from setuptools.command.install import install 10 | from tempfile import NamedTemporaryFile 11 | from os import unlink 12 | import subprocess 13 | import sys 14 | 15 | class CustomInstall(install): 16 | def run(self): 17 | with NamedTemporaryFile(mode='w+', delete=False) as f: 18 | f.write('numpy>=2.0.0\n') 19 | f.write('nvidia-cublas-cu12>=12.4.5.8\n') 20 | f.write('nvidia-cuda-cupti-cu12>=12.4.127\n') 21 | f.write('nvidia-cuda-nvrtc-cu12>=12.4.127\n') 22 | f.write('nvidia-cuda-runtime-cu12>=12.4.127\n') 23 | f.write('nvidia-cudnn-cu12>=9.1.0.70\n') 24 | f.write('nvidia-cufft-cu12>=11.2.1.3\n') 25 | f.write('nvidia-curand-cu12>=10.3.5.147\n') 26 | constraints_file = f.name 27 | try: 28 | subprocess.check_call([ 29 | sys.executable, 30 | '-m', 31 | 'pip', 32 | 'install', 33 | 'monai', 34 | '--constraint', 35 | constraints_file 36 | ]) 37 | finally: 38 | unlink(constraints_file) 39 | install.run(self) 40 | 41 | setup( 42 | name="medsam", 43 | version="0.0.1", 44 | author="Jun Ma", 45 | python_requires=">=3.9", 46 | install_requires=["matplotlib", "scikit-image", "SimpleITK>=2.2.1", "nibabel", "tqdm", "scipy", "ipympl", "opencv-python", "jupyterlab", "ipywidgets"], 47 | packages=find_packages(exclude="notebooks"), 48 | extras_require={ 49 | "all": ["pycocotools", "opencv-python", "onnx", "onnxruntime"], 50 | "dev": ["flake8", "isort", "black", "mypy"], 51 | }, 52 | cmdclass={ 53 | 'install': CustomInstall, 54 | }, 55 | ) 56 | -------------------------------------------------------------------------------- /train_multi_gpus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --nodes=5 3 | #SBATCH --ntasks=5 4 | #SBATCH --cpus-per-task=24 5 | #SBATCH --job-name=n-5nodes 6 | #SBATCH --mem=200GB 7 | #SBATCH --gres=gpu:4 8 | #SBATCH --partition=a100 9 | #SBATCH --output=logs/mgpus_%x-%j.out 10 | #SBATCH --error=logs/mgpus_%x-%j.err 11 | #SBATCH --time=20-00:00:00 12 | #SBATCH --exclude=gpu101,gpu113 13 | 14 | set -x -e 15 | 16 | # log the sbatch environment 17 | echo "start time: $(date)" 18 | echo "SLURM_JOBID="$SLURM_JOBID 19 | echo "SLURM_JOB_NODELIST"=$SLURM_JOB_NODELIST 20 | echo "SLURM_JOB_PARTITION"=$SLURM_JOB_PARTITION 21 | echo "SLURM_NNODES"=$SLURM_NNODES 22 | echo "SLURM_GPUS_ON_NODE"=$SLURM_GPUS_ON_NODE 23 | echo "SLURM_SUBMIT_DIR"=$SLURM_SUBMIT_DIR 24 | 25 | # Training setup 26 | GPUS_PER_NODE=$SLURM_GPUS_ON_NODE 27 | 28 | ## Master node setup 29 | MAIN_HOST=`hostname -s` 30 | export MASTER_ADDR=$MAIN_HOST 31 | 32 | # Get a free port using python 33 | export MASTER_PORT=$(python - <> ./logs/log_for_${SLURM_JOB_ID}_node_${i}.log 2>&1 & 78 | done 79 | wait ## Wait for the tasks on nodes to finish 80 | 81 | echo "END TIME: $(date)" 82 | -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | This folder contains the code for data organization, splitting, preprocessing, and checkpoint converting. 3 | 4 | 5 | ## Data Organization 6 | Since the original data formats and folder structures vary greatly across different datasets, we need to organize them into unified structures, allowing us to use the same functions for data preprocessing.
7 | The expected folder structures are as follows: 8 | 9 | 3D nii data 10 | ``` 11 | ----dataset_name 12 | --------images 13 | ------------xxxx_0000.nii.gz 14 | ------------xxxx_0000.nii.gz 15 | --------labels 16 | ------------xxxx.nii.gz 17 | ------------xxxx.nii.gz 18 | ``` 19 | Note: you can also use different suffixes for images and labels. Please change them in the following preprocessing scripts as well. 20 | 21 | 2D data 22 | ``` 23 | ----dataset_name 24 | --------images 25 | ------------xxxx.jpg/png 26 | ------------xxxx.jpg/png 27 | --------labels 28 | ------------xxxx.png 29 | ------------xxxx.jpg/png 30 | ``` 31 | 32 | video data 33 | ``` 34 | ----dataset_name 35 | --------images 36 | ------------video1 37 | ----------------xxxx.png 38 | ------------video2 39 | ----------------xxxx.png 40 | --------labels 41 | ------------video1 42 | ----------------xxxx.png 43 | ------------video2 44 | ----------------xxxx.png 45 | ``` 46 | 47 | Unfortunately, it is impossible to have one script that finishes all the data organization. We organized the datasets manually with commonly used data format conversion functions, including `dcm2nii`, `mhd2nii`, `nii2nii`, `nrrd2nii`, `jpg2png`, `tif2png`, and `rle_decode`. These functions are available in `format_convert.py`. 48 | 49 | ## Data Splitting 50 | Common 2D images (e.g., skin cancer dermoscopy, chest X-ray) can be directly separated into 80%/10%/10% for training, parameter tuning, and internal validation, respectively. 3D images (e.g., all the MRI/CT scans) and video data should be split at the case/video level rather than the 2D slice/frame level. For 2D whole-slide images, the splitting is at the whole-slide level. Since they cannot be directly sent to the model because of the high resolution, we divide them into patches with a fixed size of `1024x1024` after data splitting. 51 | 52 | After finishing the data organization, the data splitting can be easily done by running 53 | ```bash 54 | python split.py 55 | ``` 56 | Please set the proper data path in the script. The expected folder structures (e.g., 3D images) are 57 | 58 | ``` 59 | ----dataset_name 60 | --------images 61 | ------------xxxx_0000.nii.gz 62 | ------------xxxx_0000.nii.gz 63 | --------labels 64 | ------------xxxx.nii.gz 65 | ------------xxxx.nii.gz 66 | --------validation 67 | ------------images 68 | ----------------xxxx_0000.nii.gz 69 | ----------------xxxx_0000.nii.gz 70 | ------------labels 71 | ----------------xxxx.nii.gz 72 | ----------------xxxx.nii.gz 73 | --------testing 74 | ------------images 75 | ----------------xxxx_0000.nii.gz 76 | ----------------xxxx_0000.nii.gz 77 | ------------labels 78 | ----------------xxxx.nii.gz 79 | ----------------xxxx.nii.gz 80 | ``` 81 | 82 | ## Data Preprocessing and Ensembling 83 | 84 | All the images will be preprocessed as `npy` files. There are two main reasons for choosing this format. First, it allows fast data loading (the main reason); we learned this point from [nnU-Net](https://github.com/MIC-DKFZ/nnUNet). Second, the numpy file format is a universal data interface that unifies all the different data formats. For the convenience of debugging and inference, we also save the original images and labels as `npz` files. Spacing information is also saved for CT and MR images.
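As an illustration of this storage format, the minimal sketch below loads one preprocessed case and inspects it. The file name is hypothetical; the keys `imgs`, `gts`, and `spacing` follow the `np.savez_compressed` call in `pre_CT_MR.py` (2D cases produced by `pre_grey_rgb.py` have no `spacing` entry).

```python
import numpy as np

# Hypothetical example path; point this to a file actually produced by pre_CT_MR.py
npz_path = "data/npy/CT_Abd/CT_Abd_FLARE22_Tr_0001.npz"

case = np.load(npz_path)
print(case.files)           # ['imgs', 'gts', 'spacing'] for CT/MR cases
print(case["imgs"].shape)   # (num_slices, H, W) uint8 images after windowing/rescaling
print(case["gts"].shape)    # (num_slices, H, W) uint8 label maps
print(case["spacing"])      # voxel spacing recorded from SimpleITK

# The per-slice training files are saved separately as .npy:
#   imgs/*.npy -> (1024, 1024, 3) float images normalized to [0, 1]
#   gts/*.npy  -> (1024, 1024) uint8 masks
```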
85 | 86 | The following steps are applied to all images: 87 | - max-min normalization 88 | - resample image size to 1024x1024 89 | - save the pre-processed images and labels as npy files 90 | 91 | Different modalities also have their own additional preprocessing steps based on the data features. 92 | 93 | For CT images, we first adjust the window level and width following the [common practice](https://radiopaedia.org/articles/windowing-ct). 94 | - Soft tissue window level (40) and width (400) 95 | - Chest window level (-600) and width (1500) 96 | - Brain window level (40) and width (80) 97 | 98 | For MR, ultrasound, mammography, and Optical Coherence Tomography (OCT) images, we apply an intensity cut-off at the 0.5 and 99.5 percentiles of the foreground voxels. Regarding RGB images (e.g., endoscopy, dermoscopy, fundus, and pathology images), if they are already within the expected intensity range of [0, 255], their intensities remain unchanged. However, if they fall outside this range, max-min normalization is applied to rescale the intensity values to [0, 255]. 99 | 100 | Preprocess for CT/MR images: 101 | ```bash 102 | python pre_CT_MR.py 103 | ``` 104 | 105 | Preprocess for grey and RGB images: 106 | ```bash 107 | python pre_grey_rgb.py 108 | ``` 109 | 110 | Note: Please set the corresponding folder path and modality information. We provide an example in the script. 111 | 112 | Data ensembling of different training datasets is very simple: since all the training data are converted into `npy` files during preprocessing, you just need to merge them into one folder. 113 | 114 | 115 | ## Checkpoint Converting 116 | If the model is trained with multiple GPUs, please use the script `ckpt_convert.py` to convert the checkpoint format, since users typically run inference on a single GPU in practice. 117 | 118 | Set the paths `sam_ckpt_path`, `medsam_ckpt_path`, and `save_path`, and run 119 | 120 | ```bash 121 | python ckpt_convert.py 122 | ``` 123 | 124 | -------------------------------------------------------------------------------- /utils/ckpt_convert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | 4 | # %% convert medsam model checkpoint to sam checkpoint format for convenient inference 5 | sam_ckpt_path = "" 6 | medsam_ckpt_path = "" 7 | save_path = "" 8 | multi_gpu_ckpt = True # set as True if the model is trained with multi-gpu 9 | 10 | sam_ckpt = torch.load(sam_ckpt_path) 11 | medsam_ckpt = torch.load(medsam_ckpt_path) 12 | sam_keys = sam_ckpt.keys() 13 | for key in sam_keys: 14 | if not multi_gpu_ckpt: 15 | sam_ckpt[key] = medsam_ckpt["model"][key] 16 | else: 17 | sam_ckpt[key] = medsam_ckpt["model"]["module."
+ key] 18 | 19 | torch.save(sam_ckpt, save_path) 20 | -------------------------------------------------------------------------------- /utils/format_convert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | join = os.path.join 4 | import random 5 | import numpy as np 6 | from skimage import io 7 | import SimpleITK as sitk 8 | 9 | 10 | def dcm2nii(dcm_path, nii_path): 11 | """ 12 | Convert dicom files to nii files 13 | """ 14 | reader = sitk.ImageSeriesReader() 15 | dicom_names = reader.GetGDCMSeriesFileNames(dcm_path) 16 | reader.SetFileNames(dicom_names) 17 | image = reader.Execute() 18 | sitk.WriteImage(image, nii_path) 19 | 20 | def mhd2nii(mhd_path, nii_path): 21 | """ 22 | Convert mhd files to nii files 23 | """ 24 | image = sitk.ReadImage(mhd_path) 25 | sitk.WriteImage(image, nii_path) 26 | 27 | def nii2nii(nii_path, nii_gz_path): 28 | """ 29 | Convert nii files to nii.gz files, which can reduce the file size 30 | """ 31 | image = sitk.ReadImage(nii_path) 32 | sitk.WriteImage(image, nii_gz_path) 33 | 34 | def nrrd2nii(nrrd_path, nii_path): 35 | """ 36 | Convert nrrd files to nii files 37 | """ 38 | image = sitk.ReadImage(nrrd_path) 39 | sitk.WriteImage(image, nii_path) 40 | 41 | def jpg2png(jpg_path, png_path): 42 | """ 43 | Convert jpg files to png files 44 | """ 45 | image = io.imread(jpg_path) 46 | io.imsave(png_path, image) 47 | 48 | def patchfy(img, mask, outpath, basename): 49 | """ 50 | Patchfy the image and mask into 1024x1024 patches 51 | """ 52 | image_patch_dir = join(outpath, "images") 53 | mask_patch_dir = join(outpath, "labels") 54 | os.makedirs(image_patch_dir, exist_ok=True) 55 | os.makedirs(mask_patch_dir, exist_ok=True) 56 | assert img.shape[:2] == mask.shape 57 | patch_height = 1024 58 | patch_width = 1024 59 | 60 | img_height, img_width = img.shape[:2] 61 | mask_height, mask_width = mask.shape 62 | 63 | if img_height % patch_height != 0: 64 | img = np.pad(img, ((0, patch_height - img_height % patch_height), (0, 0), (0, 0)), mode="constant") 65 | if img_width % patch_width != 0: 66 | img = np.pad(img, ((0, 0), (0, patch_width - img_width % patch_width), (0, 0)), mode="constant") 67 | if mask_height % patch_height != 0: 68 | mask = np.pad(mask, ((0, patch_height - mask_height % patch_height), (0, 0)), mode="constant") 69 | if mask_width % patch_width != 0: 70 | mask = np.pad(mask, ((0, 0), (0, patch_width - mask_width % patch_width)), mode="constant") 71 | 72 | assert img.shape[:2] == mask.shape 73 | assert img.shape[0] % patch_height == 0 74 | assert img.shape[1] % patch_width == 0 75 | assert mask.shape[0] % patch_height == 0 76 | assert mask.shape[1] % patch_width == 0 77 | 78 | height_steps = (img_height // patch_height) if img_height % patch_height == 0 else (img_height // patch_height + 1) 79 | width_steps = (img_width // patch_width) if img_width % patch_width == 0 else (img_width // patch_width + 1) 80 | 81 | for i in range(height_steps): 82 | for j in range(width_steps): 83 | img_patch = img[i * patch_height:(i + 1) * patch_height, j * patch_width:(j + 1) * patch_width, :] 84 | mask_patch = mask[i * patch_height:(i + 1) * patch_height, j * patch_width:(j + 1) * patch_width] 85 | assert img_patch.shape[:2] == mask_patch.shape 86 | assert img_patch.shape[0] == patch_height 87 | assert img_patch.shape[1] == patch_width 88 | print(f"img_patch.shape: {img_patch.shape}, mask_patch.shape: {mask_patch.shape}") 89 | img_patch_path = join(image_patch_dir, f"{basename}_{i}_{j}.png") 90 | 
mask_patch_path = join(mask_patch_dir, f"{basename}_{i}_{j}.png") 91 | io.imsave(img_patch_path, img_patch) 92 | io.imsave(mask_patch_path, mask_patch) 93 | 94 | 95 | def rle_decode(mask_rle, img_shape): 96 | """ 97 | #functions to convert encoding to mask and mask to encoding 98 | mask_rle: run-length as string formated (start length) 99 | shape: (height,width) of array to return 100 | Returns numpy array, 1 - mask, 0 - background 101 | """ 102 | seq = mask_rle.split() 103 | starts = np.array(list(map(int, seq[0::2]))) 104 | lengths = np.array(list(map(int, seq[1::2]))) 105 | assert len(starts) == len(lengths) 106 | ends = starts + lengths 107 | img = np.zeros((np.product(img_shape),), dtype=np.uint8) 108 | for begin, end in zip(starts, ends): 109 | img[begin:end] = 255 110 | # https://stackoverflow.com/a/46574906/4521646 111 | img.shape = img_shape 112 | return img.T -------------------------------------------------------------------------------- /utils/pre_CT_MR.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # %% import packages 3 | # pip install connected-components-3d 4 | import numpy as np 5 | 6 | # import nibabel as nib 7 | import SimpleITK as sitk 8 | import os 9 | 10 | join = os.path.join 11 | from skimage import transform 12 | from tqdm import tqdm 13 | import cc3d 14 | 15 | # convert nii image to npz files, including original image and corresponding masks 16 | modality = "CT" 17 | anatomy = "Abd" # anantomy + dataset name 18 | img_name_suffix = "_0000.nii.gz" 19 | gt_name_suffix = ".nii.gz" 20 | prefix = modality + "_" + anatomy + "_" 21 | 22 | nii_path = "data/FLARE22Train/images" # path to the nii images 23 | gt_path = "data/FLARE22Train/labels" # path to the ground truth 24 | npy_path = "data/npy/" + prefix[:-1] 25 | os.makedirs(join(npy_path, "gts"), exist_ok=True) 26 | os.makedirs(join(npy_path, "imgs"), exist_ok=True) 27 | 28 | image_size = 1024 29 | voxel_num_thre2d = 100 30 | voxel_num_thre3d = 1000 31 | 32 | names = sorted(os.listdir(gt_path)) 33 | print(f"ori \# files {len(names)=}") 34 | names = [ 35 | name 36 | for name in names 37 | if os.path.exists(join(nii_path, name.split(gt_name_suffix)[0] + img_name_suffix)) 38 | ] 39 | print(f"after sanity check \# files {len(names)=}") 40 | 41 | # set label ids that are excluded 42 | remove_label_ids = [ 43 | 12 44 | ] # remove deodenum since it is scattered in the image, which is hard to specify with the bounding box 45 | tumor_id = None # only set this when there are multiple tumors; convert semantic masks to instance masks 46 | # set window level and width 47 | # https://radiopaedia.org/articles/windowing-ct 48 | WINDOW_LEVEL = 40 # only for CT images 49 | WINDOW_WIDTH = 400 # only for CT images 50 | 51 | # %% save preprocessed images and masks as npz files 52 | for name in tqdm(names[:40]): # use the remaining 10 cases for validation 53 | image_name = name.split(gt_name_suffix)[0] + img_name_suffix 54 | gt_name = name 55 | gt_sitk = sitk.ReadImage(join(gt_path, gt_name)) 56 | gt_data_ori = np.uint8(sitk.GetArrayFromImage(gt_sitk)) 57 | # remove label ids 58 | for remove_label_id in remove_label_ids: 59 | gt_data_ori[gt_data_ori == remove_label_id] = 0 60 | # label tumor masks as instances and remove from gt_data_ori 61 | if tumor_id is not None: 62 | tumor_bw = np.uint8(gt_data_ori == tumor_id) 63 | gt_data_ori[tumor_bw > 0] = 0 64 | # label tumor masks as instances 65 | tumor_inst, tumor_n = cc3d.connected_components( 66 | tumor_bw, connectivity=26, 
return_N=True 67 | ) 68 | # put the tumor instances back to gt_data_ori 69 | gt_data_ori[tumor_inst > 0] = ( 70 | tumor_inst[tumor_inst > 0] + np.max(gt_data_ori) + 1 71 | ) 72 | 73 | # exclude the objects with less than 1000 pixels in 3D 74 | gt_data_ori = cc3d.dust( 75 | gt_data_ori, threshold=voxel_num_thre3d, connectivity=26, in_place=True 76 | ) 77 | # remove small objects with less than 100 pixels in 2D slices 78 | 79 | for slice_i in range(gt_data_ori.shape[0]): 80 | gt_i = gt_data_ori[slice_i, :, :] 81 | # remove small objects with less than 100 pixels 82 | # reason: fro such small objects, the main challenge is detection rather than segmentation 83 | gt_data_ori[slice_i, :, :] = cc3d.dust( 84 | gt_i, threshold=voxel_num_thre2d, connectivity=8, in_place=True 85 | ) 86 | # find non-zero slices 87 | z_index, _, _ = np.where(gt_data_ori > 0) 88 | z_index = np.unique(z_index) 89 | 90 | if len(z_index) > 0: 91 | # crop the ground truth with non-zero slices 92 | gt_roi = gt_data_ori[z_index, :, :] 93 | # load image and preprocess 94 | img_sitk = sitk.ReadImage(join(nii_path, image_name)) 95 | image_data = sitk.GetArrayFromImage(img_sitk) 96 | # nii preprocess start 97 | if modality == "CT": 98 | lower_bound = WINDOW_LEVEL - WINDOW_WIDTH / 2 99 | upper_bound = WINDOW_LEVEL + WINDOW_WIDTH / 2 100 | image_data_pre = np.clip(image_data, lower_bound, upper_bound) 101 | image_data_pre = ( 102 | (image_data_pre - np.min(image_data_pre)) 103 | / (np.max(image_data_pre) - np.min(image_data_pre)) 104 | * 255.0 105 | ) 106 | else: 107 | lower_bound, upper_bound = np.percentile( 108 | image_data[image_data > 0], 0.5 109 | ), np.percentile(image_data[image_data > 0], 99.5) 110 | image_data_pre = np.clip(image_data, lower_bound, upper_bound) 111 | image_data_pre = ( 112 | (image_data_pre - np.min(image_data_pre)) 113 | / (np.max(image_data_pre) - np.min(image_data_pre)) 114 | * 255.0 115 | ) 116 | image_data_pre[image_data == 0] = 0 117 | 118 | image_data_pre = np.uint8(image_data_pre) 119 | img_roi = image_data_pre[z_index, :, :] 120 | np.savez_compressed(join(npy_path, prefix + gt_name.split(gt_name_suffix)[0]+'.npz'), imgs=img_roi, gts=gt_roi, spacing=img_sitk.GetSpacing()) 121 | # save the image and ground truth as nii files for sanity check; 122 | # they can be removed 123 | img_roi_sitk = sitk.GetImageFromArray(img_roi) 124 | gt_roi_sitk = sitk.GetImageFromArray(gt_roi) 125 | sitk.WriteImage( 126 | img_roi_sitk, 127 | join(npy_path, prefix + gt_name.split(gt_name_suffix)[0] + "_img.nii.gz"), 128 | ) 129 | sitk.WriteImage( 130 | gt_roi_sitk, 131 | join(npy_path, prefix + gt_name.split(gt_name_suffix)[0] + "_gt.nii.gz"), 132 | ) 133 | # save the each CT image as npy file 134 | for i in range(img_roi.shape[0]): 135 | img_i = img_roi[i, :, :] 136 | img_3c = np.repeat(img_i[:, :, None], 3, axis=-1) 137 | resize_img_skimg = transform.resize( 138 | img_3c, 139 | (image_size, image_size), 140 | order=3, 141 | preserve_range=True, 142 | mode="constant", 143 | anti_aliasing=True, 144 | ) 145 | resize_img_skimg_01 = (resize_img_skimg - resize_img_skimg.min()) / np.clip( 146 | resize_img_skimg.max() - resize_img_skimg.min(), a_min=1e-8, a_max=None 147 | ) # normalize to [0, 1], (H, W, 3) 148 | gt_i = gt_roi[i, :, :] 149 | resize_gt_skimg = transform.resize( 150 | gt_i, 151 | (image_size, image_size), 152 | order=0, 153 | preserve_range=True, 154 | mode="constant", 155 | anti_aliasing=False, 156 | ) 157 | resize_gt_skimg = np.uint8(resize_gt_skimg) 158 | assert resize_img_skimg_01.shape[:2] == 
resize_gt_skimg.shape 159 | np.save( 160 | join( 161 | npy_path, 162 | "imgs", 163 | prefix 164 | + gt_name.split(gt_name_suffix)[0] 165 | + "-" 166 | + str(i).zfill(3) 167 | + ".npy", 168 | ), 169 | resize_img_skimg_01, 170 | ) 171 | np.save( 172 | join( 173 | npy_path, 174 | "gts", 175 | prefix 176 | + gt_name.split(gt_name_suffix)[0] 177 | + "-" 178 | + str(i).zfill(3) 179 | + ".npy", 180 | ), 181 | resize_gt_skimg, 182 | ) 183 | -------------------------------------------------------------------------------- /utils/pre_grey_rgb.py: -------------------------------------------------------------------------------- 1 | #%% import packages 2 | import numpy as np 3 | import os 4 | join = os.path.join 5 | from skimage import io, transform 6 | from tqdm import tqdm 7 | 8 | # convert 2D data to npy files, including images and corresponding masks 9 | modality = 'dd' # e.g., 'Dermoscopy 10 | anatomy = 'dd' # e.g., 'SkinCancer' 11 | img_name_suffix = '.png' 12 | gt_name_suffix = '.png' 13 | prefix = modality + '_' + anatomy + '_' 14 | save_suffix = '.npy' 15 | image_size = 1024 16 | img_path = 'path to /images' # path to the images 17 | gt_path = 'path to/labels' # path to the corresponding annotations 18 | npy_path = 'path to/data/npy/' + prefix[:-1] # save npy path e.g., MedSAM/data/npy/; don't miss the `/` 19 | os.makedirs(join(npy_path, "gts"), exist_ok=True) 20 | os.makedirs(join(npy_path, "imgs"), exist_ok=True) 21 | names = sorted(os.listdir(gt_path)) 22 | print(f'ori \# files {len(names)=}') 23 | 24 | # set label ids that are excluded 25 | remove_label_ids = [] 26 | tumor_id = None # only set this when there are multiple tumors in one image; convert semantic masks to instance masks 27 | label_id_offset = 0 28 | do_intensity_cutoff = False # True for grey images 29 | #%% save preprocessed images and masks as npz files 30 | for name in tqdm(names): 31 | image_name = name.split(gt_name_suffix)[0] + img_name_suffix 32 | gt_name = name 33 | npy_save_name = prefix + gt_name.split(gt_name_suffix)[0]+save_suffix 34 | gt_data_ori = np.uint8(io.imread(join(gt_path, gt_name))) 35 | # remove label ids 36 | for remove_label_id in remove_label_ids: 37 | gt_data_ori[gt_data_ori==remove_label_id] = 0 38 | # label tumor masks as instances and remove from gt_data_ori 39 | if tumor_id is not None: 40 | tumor_bw = np.uint8(gt_data_ori==tumor_id) 41 | gt_data_ori[tumor_bw>0] = 0 42 | # label tumor masks as instances 43 | tumor_inst, tumor_n = cc3d.connected_components(tumor_bw, connectivity=26, return_N=True) 44 | # put the tumor instances back to gt_data_ori 45 | gt_data_ori[tumor_inst>0] = tumor_inst[tumor_inst>0] + label_id_offset + 1 46 | 47 | # crop the ground truth with non-zero slices 48 | image_data = io.imread(join(img_path, image_name)) 49 | if np.max(image_data) > 255.0: 50 | image_data = np.uint8((image_data-image_data.min()) / (np.max(image_data)-np.min(image_data))*255.0) 51 | if len(image_data.shape) == 2: 52 | image_data = np.repeat(np.expand_dims(image_data, -1), 3, -1) 53 | assert len(image_data.shape) == 3, 'image data is not three channels: img shape:' + str(image_data.shape) + image_name 54 | # convert three channel to one channel 55 | if image_data.shape[-1] > 3: 56 | image_data = image_data[:,:,:3] 57 | # image preprocess start 58 | if do_intensity_cutoff: 59 | lower_bound, upper_bound = np.percentile(image_data[image_data>0], 0.5), np.percentile(image_data[image_data>0], 99.5) 60 | image_data_pre = np.clip(image_data, lower_bound, upper_bound) 61 | image_data_pre = (image_data_pre - 
np.min(image_data_pre))/(np.max(image_data_pre)-np.min(image_data_pre))*255.0 62 | image_data_pre[image_data==0] = 0 63 | image_data_pre = np.uint8(image_data_pre) 64 | else: 65 | # print('no intensity cutoff') 66 | image_data_pre = image_data.copy() 67 | np.savez_compressed(join(npy_path, prefix + gt_name.split(gt_name_suffix)[0]+'.npz'), imgs=image_data_pre, gts=gt_data_ori) 68 | resize_img = transform.resize(image_data_pre, (image_size, image_size), order=3, mode='constant', preserve_range=True, anti_aliasing=True) 69 | resize_img01 = resize_img/255.0 70 | resize_gt = transform.resize(gt_data_ori, (image_size, image_size), order=0, mode='constant', preserve_range=True, anti_aliasing=False) 71 | # save resize img and gt as npy 72 | np.save(join(npy_path, "imgs", npy_save_name), resize_img01) 73 | np.save(join(npy_path, "gts", npy_save_name), resize_gt.astype(np.uint8)) 74 | 75 | -------------------------------------------------------------------------------- /utils/split.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | join = os.path.join 4 | import random 5 | 6 | path_nii = '' # please complete path; two subfolders: images and labels 7 | path_video = None # or specify the path 8 | path_2d = None # or specify the path 9 | 10 | #%% split 3D nii data 11 | if path_nii is not None: 12 | img_path = join(path_nii, 'images') 13 | gt_path = join(path_nii, 'labels') 14 | gt_names = sorted(os.listdir(gt_path)) 15 | img_suffix = '_0000.nii.gz' 16 | gt_suffix = '.nii.gz' 17 | # split 20% data for validation and testing 18 | validation_path = join(path_nii, 'validation') 19 | os.makedirs(join(validation_path, 'images'), exist_ok=True) 20 | os.makedirs(join(validation_path, 'labels'), exist_ok=True) 21 | testing_path = join(path_nii, 'testing') 22 | os.makedirs(join(testing_path, 'images'), exist_ok=True) 23 | os.makedirs(join(testing_path, 'labels'), exist_ok=True) 24 | candidates = random.sample(gt_names, int(len(gt_names)*0.2)) 25 | # split half of test names for validation 26 | validation_names = random.sample(candidates, int(len(candidates)*0.5)) 27 | test_names = [name for name in candidates if name not in validation_names] 28 | # move validation and testing data to corresponding folders 29 | for name in validation_names: 30 | img_name = name.split(gt_suffix)[0] + img_suffix 31 | os.rename(join(img_path, img_name), join(validation_path, 'images', img_name)) 32 | os.rename(join(gt_path, name), join(validation_path, 'labels', name)) 33 | for name in test_names: 34 | img_name = name.split(gt_suffix)[0] + img_suffix 35 | os.rename(join(img_path, img_name), join(testing_path, 'images', img_name)) 36 | os.rename(join(gt_path, name), join(testing_path, 'labels', name)) 37 | 38 | 39 | ##% split 2D images 40 | if path_2d is not None: 41 | img_path = join(path_2d, 'images') 42 | gt_path = join(path_2d, 'labels') 43 | gt_names = sorted(os.listdir(gt_path)) 44 | img_suffix = '.png' 45 | gt_suffix = '.png' 46 | # split 20% data for validation and testing 47 | validation_path = join(path_2d, 'validation') 48 | os.makedirs(join(validation_path, 'images'), exist_ok=True) 49 | os.makedirs(join(validation_path, 'labels'), exist_ok=True) 50 | testing_path = join(path_2d, 'testing') 51 | os.makedirs(join(testing_path, 'images'), exist_ok=True) 52 | os.makedirs(join(testing_path, 'labels'), exist_ok=True) 53 | candidates = random.sample(gt_names, int(len(gt_names)*0.2)) 54 | # split half of test names for validation 55 | 
validation_names = random.sample(candidates, int(len(candidates)*0.5)) 56 | test_names = [name for name in candidates if name not in validation_names] 57 | # move validation and testing data to corresponding folders 58 | for name in validation_names: 59 | img_name = name.split(gt_suffix)[0] + img_suffix 60 | os.rename(join(img_path, img_name), join(validation_path, 'images', img_name)) 61 | os.rename(join(gt_path, name), join(validation_path, 'labels', name)) 62 | 63 | for name in test_names: 64 | img_name = name.split(gt_suffix)[0] + img_suffix 65 | os.rename(join(img_path, img_name), join(testing_path, 'images', img_name)) 66 | os.rename(join(gt_path, name), join(testing_path, 'labels', name)) 67 | 68 | #%% split video data 69 | if path_video is not None: 70 | img_path = join(path_video, 'images') 71 | gt_path = join(path_video, 'labels') 72 | gt_folders = sorted(os.listdir(gt_path)) 73 | # split 20% videos for validation and testing 74 | validation_path = join(path_video, 'validation') 75 | os.makedirs(join(validation_path, 'images'), exist_ok=True) 76 | os.makedirs(join(validation_path, 'labels'), exist_ok=True) 77 | testing_path = join(path_video, 'testing') 78 | os.makedirs(join(testing_path, 'images'), exist_ok=True) 79 | os.makedirs(join(testing_path, 'labels'), exist_ok=True) 80 | candidates = random.sample(gt_folders, int(len(gt_folders)*0.2)) 81 | # split half of test names for validation 82 | validation_names = random.sample(candidates, int(len(candidates)*0.5)) 83 | test_names = [name for name in candidates if name not in validation_names] 84 | # move validation and testing data to corresponding folders 85 | for name in validation_names: 86 | os.rename(join(img_path, name), join(validation_path, 'images', name)) 87 | os.rename(join(gt_path, name), join(validation_path, 'labels', name)) 88 | for name in test_names: 89 | os.rename(join(img_path, name), join(testing_path, 'images', name)) 90 | os.rename(join(gt_path, name), join(testing_path, 'labels', name)) 91 | -------------------------------------------------------------------------------- /work_dir/MedSAM/README.md: -------------------------------------------------------------------------------- 1 | Note: put the MedSAM model checkpoint in this folder. 2 | --------------------------------------------------------------------------------
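The note above only says where to place the model checkpoint. As a minimal loading sketch (the file name `medsam_vit_b.pth` is an assumption; use whichever checkpoint you downloaded or produced with `utils/ckpt_convert.py`), the weights are consumed through the `sam_model_registry` exported by the bundled `segment_anything` package; MedSAM uses the SAM ViT-B backbone, hence the `"vit_b"` entry:

```python
import torch
from segment_anything import sam_model_registry

# Assumed checkpoint location/name inside work_dir/MedSAM
ckpt_path = "work_dir/MedSAM/medsam_vit_b.pth"

device = "cuda" if torch.cuda.is_available() else "cpu"
# Builds the SAM ViT-B architecture and loads the converted single-GPU state dict
medsam_model = sam_model_registry["vit_b"](checkpoint=ckpt_path)
medsam_model = medsam_model.to(device)
medsam_model.eval()
```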