├── .gitignore ├── Dockerfile ├── README.md ├── notebooks ├── infer_bboxes.ipynb └── point_prompt_inference.ipynb ├── num_points.sh ├── requirements.txt ├── src ├── dataset.py ├── form_samples.py ├── model.py ├── point_prompt_demo.py ├── preprocess_CT.py ├── test_bboxes.py ├── test_model.py └── train_point_prompt.py ├── test_random.sh ├── test_samples ├── gts │ ├── CT_Abd_word_0014-000.npy │ ├── CT_Abd_word_0014-050.npy │ ├── CT_Abd_word_0016-099.npy │ ├── CT_Abd_word_0017-020.npy │ ├── CT_Abd_word_0019-110.npy │ ├── CT_Abd_word_0019-129.npy │ ├── CT_Abd_word_0021-001.npy │ ├── CT_Abd_word_0021-021.npy │ └── CT_Abd_word_0024-100.npy └── imgs │ ├── CT_Abd_word_0014-000.npy │ ├── CT_Abd_word_0014-050.npy │ ├── CT_Abd_word_0016-099.npy │ ├── CT_Abd_word_0017-020.npy │ ├── CT_Abd_word_0019-110.npy │ ├── CT_Abd_word_0019-129.npy │ ├── CT_Abd_word_0021-001.npy │ ├── CT_Abd_word_0021-021.npy │ └── CT_Abd_word_0024-100.npy └── weights └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | __pycache__/ 3 | weights/sam/ 4 | weights/medsam/ 5 | data/ 6 | lightning_logs/ 7 | logs/ 8 | MedSAM/ 9 | .ipynb_checkpoints/ -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:20.04 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | 5 | RUN apt-get update && apt-get install -y \ 6 | ffmpeg libsm6 libxext6 \ 7 | git \ 8 | python3.9 python3.9-dev python3.9-venv python3-pip \ 9 | && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1 10 | 11 | RUN apt-get install wget -y 12 | RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-ubuntu2004.pin 13 | RUN mv cuda-ubuntu2004.pin /etc/apt/preferences.d/cuda-repository-pin-600 14 | RUN wget https://developer.download.nvidia.com/compute/cuda/12.3.0/local_installers/cuda-repo-ubuntu2004-12-3-local_12.3.0-545.23.06-1_amd64.deb 15 | RUN dpkg -i cuda-repo-ubuntu2004-12-3-local_12.3.0-545.23.06-1_amd64.deb 16 | RUN cp /var/cuda-repo-ubuntu2004-12-3-local/cuda-*-keyring.gpg /usr/share/keyrings/ 17 | RUN apt-get update 18 | RUN apt-get -y install cuda-toolkit-12-3 19 | 20 | COPY requirements.txt /repo/requirements.txt 21 | 22 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \ 23 | bash ~/miniconda.sh -b -p $HOME/miniconda && eval "$(/root/miniconda/bin/conda shell.bash hook)" && \ 24 | conda init && conda config --set auto_activate_base true && \ 25 | pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 && \ 26 | git clone https://github.com/bowang-lab/MedSAM && pip install -e MedSAM/ 27 | 28 | WORKDIR /repo 29 | RUN eval "$(/root/miniconda/bin/conda shell.bash hook)" && \ 30 | pip install --no-cache-dir -r requirements.txt 31 | 32 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gaze-Assisted Medical Image Segmentation 2 | 3 | ## Overview 4 | 5 | In this study, we explore semi-supervised medical image segmentation using human gaze as interactive input for correcting segmentation. 
We fine-tuned the Segment Anything Model in medical images (MedSAM) with gaze data from abdominal images and validated it on the WORD dataset, consisting of 120 CT scans of 16 abdominal organs. Our gaze-assisted MedSAM outperformed state-of-the-art models on the WORD benchmark, achieving Dice coefficients of 85.8%, 86.7%, 81.7%, and **90.5%** for nnUNetV2, ResUNet, the original MedSAM, and our gaze-assisted MedSAM (fine-tuned on a random 5% subset of WORD), respectively. The best approach, fine-tuned on the complete WORD dataset, demonstrated a Dice score of **92.5%**. 6 | 7 | ## Usage 8 | 9 | The fine-tuned model checkpoints are integrated into eye-tracking software for interactive segmentation of medical images. Below is a demonstration of our gaze-assisted model in action: 10 | 11 | ![visualization2](https://github.com/user-attachments/assets/993c259a-5d24-43f2-9589-99d2dde231f2) 12 | 13 | 14 | You can download the checkpoints of gaze-assisted MedSAM from [Google Drive](https://drive.google.com/file/d/1DR7fMNzBZzyJ8_gBQKNWL4_8bzAU74R9/view?usp=sharing). 15 | 16 | ## Getting started 17 | 18 | Follow these steps to set up the project: 19 | 20 | 0. Clone this repo and the MedSAM repo inside it: 21 | ``` 22 | git clone https://github.com/leiluk1/gaze-based-segmentation.git 23 | cd gaze-based-segmentation 24 | git clone https://github.com/bowang-lab/MedSAM 25 | ``` 26 | 27 | 1. Build the docker container: 28 | ``` 29 | docker build -t medsam_ft:latest . 30 | ``` 31 | 32 | 2. Run the docker container as a daemon: 33 | ``` 34 | docker run \ 35 | -v .:/repo/ \ 36 | --gpus all \ 37 | -it -d --name medsam_ft medsam_ft 38 | ``` 39 | 40 | 41 | 3. Start bash inside the docker container: 42 | 43 | 0. To run scripts in the background, install and launch screen: 44 | ``` 45 | sudo apt install screen 46 | ``` 47 | 48 | 1. Start bash: 49 | 50 | ``` 51 | docker exec -it medsam_ft bash 52 | ``` 53 | 54 | 4. Download the data and model checkpoints to `data` and `weights`, respectively: 55 | ``` 56 | pip install gdown 57 | gdown 19OWCXZGrimafREhXm8O8w2HBHZTfxEgU -O ./data/ # download WORD dataset 58 | apt-get install p7zip-full 59 | cd data 60 | 7z x WORD-V0.1.0.zip # unzip WORD dataset 61 | wget https://github.com/HiLab-git/WORD/raw/main/WORD_V0.1.0_labelsTs.zip # download WORD test annotations 62 | unzip WORD_V0.1.0_labelsTs.zip -d ./WORD-V0.1.0/ 63 | ``` 64 | 65 | ``` 66 | wget https://zenodo.org/records/5903037/files/Subtask1.zip?download=1 -O ./data/Subtask1.zip # download AbdomenCT-1K 67 | wget https://zenodo.org/records/5903037/files/Subtask2.zip?download=1 -O ./data/Subtask2.zip # download AbdomenCT-1K 68 | cd data 69 | unzip Subtask1.zip 70 | unzip Subtask2.zip 71 | cd Subtask2/TrainImage 72 | ls | xargs -I {} mv {} 2_{} 73 | cd ../TrainMask 74 | ls | xargs -I {} mv {} 2_{} 75 | cd .. 76 | mv TrainImage/* ../Subtask1/TrainImage/ 77 | mv TrainMask/* ../Subtask1/TrainMask/ 78 | cd .. 79 | rm -r Subtask2 80 | ``` 81 | 82 | ``` 83 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -O weights/sam/sam_vit_b_01ec64.pth # download SAM checkpoint 84 | ``` 85 | 86 | ``` 87 | gdown 1UAmWL88roYR7wKlnApw5Bcuzf2iQgk6_ -O ./weights/medsam/ # download MedSAM checkpoint 88 | ``` 89 | 90 | 5. Inside the container, run the following commands to double-check that the dependencies are installed: 91 | ``` 92 | pip install -r requirements.txt 93 | pip install -e MedSAM/ 94 | ``` 95 | 96 | 6. 
Initialize your ClearML credentials via: 97 | ``` 98 | clearml-init 99 | ``` 100 | 101 | ## Training 102 | 103 | The training script fine-tunes MedSAM with point prompts on the WORD dataset. 104 | 105 | The training script `src/train_point_prompt.py` takes the following arguments: 106 | * `--tr_npy_path`: Path to the train data root directory; 107 | * `--val_npy_path`: Path to the validation data root directory; 108 | * `--test_npy_path`: Path to the test data root directory; 109 | * `--medsam_checkpoint`: Path to the MedSAM checkpoint; 110 | * `--max_epochs`: Maximum number of epochs; 111 | * `--batch_size`: Batch size; 112 | * `--num_workers`: Number of data loader workers; 113 | * `--lr`: Learning rate (absolute lr); 114 | * `--weight_decay`: Weight decay; 115 | * `--accumulate_grad_batches`: Accumulate grad batches; 116 | * `--seed`: Random seed for reproducibility; 117 | * `--disable_aug`: Disable data augmentation; 118 | * `--freeze_prompt_encoder`: Freeze the prompt encoder; 119 | * `--gt_in_ram`: Store ground-truth masks in RAM during data processing; 120 | * `--num_points`: Number of points in the prompt; 121 | * `--mask_diff`: Use the approach based on the mask difference; 122 | * `--mask_prompt`: Whether a mask prompt is incorporated; 123 | * `--base_medsam_checkpoint`: Path to the MedSAM base predictor checkpoint (used only with the mask_diff approach; if not provided, the base predictor is a copy of our MedSAM model); 124 | * `--eval_per_organ`: Report performance per organ (evaluation for each class). 125 | 126 | 127 | For instance, assume that the preprocessed data is stored in the `data` directory, the MedSAM model is placed in the `weights/medsam` folder, and the model checkpoints should be saved in `train_point_prompt`. Then, to train the model, run the following commands: 128 | 129 | 1. Data preprocessing (with 10% of slices saved to disk): 130 | 1. WORD Dataset: 131 | ``` 132 | python src/preprocess_CT.py \ 133 | --nii_path "./data/WORD-V0.1.0/imagesTr" \ 134 | --gt_path "./data/WORD-V0.1.0/labelsTr" \ 135 | --img_name_suffix ".nii.gz" \ 136 | --npy_path "./data/WORD/train_" \ 137 | --proportion 0.1; \ 138 | python src/preprocess_CT.py \ 139 | --nii_path "./data/WORD-V0.1.0/imagesVal" \ 140 | --gt_path "./data/WORD-V0.1.0/labelsVal" \ 141 | --img_name_suffix ".nii.gz" \ 142 | --npy_path "./data/WORD/val_" \ 143 | --proportion 0.1; \ 144 | python src/preprocess_CT.py \ 145 | --nii_path "./data/WORD-V0.1.0/imagesTs" \ 146 | --gt_path "./data/WORD-V0.1.0/labelsTs" \ 147 | --img_name_suffix ".nii.gz" \ 148 | --npy_path "./data/WORD/test_" \ 149 | --proportion 0.1 150 | ``` 151 | 152 | 2. AbdomenCT-1K Dataset: 153 | ``` 154 | python src/preprocess_CT.py \ 155 | --nii_path "./data/Subtask1/TrainImage" \ 156 | --gt_path "./data/Subtask1/TrainMask" \ 157 | --npy_path "./data/AbdomenCT/train_" \ 158 | --proportion 0.1 159 | ``` 160 | 161 | 2. 
Fine-tuning: 162 | 163 | One point prompt: 164 | 165 | ``` 166 | python src/train_point_prompt.py \ 167 | --tr_npy_path "data/WORD/train_CT_Abd/" \ 168 | --val_npy_path "data/WORD/val_CT_Abd/" \ 169 | --test_npy_path "data/WORD/test_CT_Abd/" \ 170 | --medsam_checkpoint "weights/medsam/medsam_vit_b.pth" \ 171 | --max_epochs 200 \ 172 | --batch_size 24 \ 173 | --num_workers 0 \ 174 | --no-gt_in_ram \ 175 | --eval_per_organ 176 | ``` 177 | 178 | An example of the prompt with 20 points: 179 | 180 | ``` 181 | python src/train_point_prompt.py \ 182 | --tr_npy_path "data/WORD/train_CT_Abd/" \ 183 | --val_npy_path "data/WORD/val_CT_Abd/" \ 184 | --test_npy_path "data/WORD/test_CT_Abd/" \ 185 | --medsam_checkpoint "weights/medsam/medsam_vit_b.pth" \ 186 | --max_epochs 200 \ 187 | --batch_size 24 \ 188 | --num_workers 0 \ 189 | --num_points 20 \ 190 | --no-gt_in_ram \ 191 | --eval_per_organ 192 | ``` 193 | 194 | An example of fine-tuning based on the mask difference with 20 points prompt: 195 | 196 | ``` 197 | python src/train_point_prompt.py \ 198 | --tr_npy_path "data/WORD/train_CT_Abd/" \ 199 | --val_npy_path "data/WORD/val_CT_Abd/" \ 200 | --test_npy_path "data/WORD/test_CT_Abd/" \ 201 | --medsam_checkpoint "weights/medsam/medsam_vit_b.pth" \ 202 | --max_epochs 200 \ 203 | --batch_size 24 \ 204 | --num_workers 0 \ 205 | --num_points 20 \ 206 | --no-gt_in_ram \ 207 | --mask_diff \ 208 | --eval_per_organ 209 | ``` 210 | 211 | 212 | ## Testing 213 | 214 | One point prompt: 215 | 216 | ``` 217 | python src/test_model.py \ 218 | --tr_npy_path "data/WORD/train_CT_Abd/" \ 219 | --val_npy_path "data/WORD/val_CT_Abd/" \ 220 | --test_npy_path "data/WORD/test_CT_Abd/" \ 221 | --medsam_checkpoint "weights/medsam/medsam_vit_b.pth" \ 222 | --checkpoint "exp_name=0-epoch=42-val_loss=0.00.ckpt" \ 223 | --batch_size 24 \ 224 | --num_workers 0 \ 225 | --num_points 1 \ 226 | --eval_per_organ 227 | ``` 228 | -------------------------------------------------------------------------------- /num_points.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define the range of values for num_points 4 | for num in 1 3 5 10 15 50 5 | do 6 | python src/train_point_prompt.py \ 7 | --tr_npy_path "./data/WORD/train_CT_Abd/" \ 8 | --val_npy_path "./data/WORD/val_CT_Abd/" \ 9 | --test_npy_path "data/WORD/test_CT_Abd/" \ 10 | --medsam_checkpoint "./weights/medsam/medsam_vit_b.pth" \ 11 | --max_epochs 200 \ 12 | --batch_size 24 \ 13 | --num_workers 0 \ 14 | --num_points $num 15 | done 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | SimpleITK==2.3.1 2 | connected-components-3d==3.17.0 3 | scikit-image==0.24.0 4 | lightning==2.3.0 5 | opencv-python==4.7.0.68 6 | torch==2.3.1 7 | torchmetrics==1.4.0 8 | torchvision==0.18.1 9 | numpy==1.26.4 10 | scikit-learn==1.5.0 11 | pandas==2.2.2 12 | tqdm==4.66.4 13 | clearml==1.16.2 14 | tensorboard==2.17.0 15 | notebook==7.2.1 16 | imutils==0.5.4 -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import lightning as pl 5 | from os.path import join 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset, DataLoader 9 | import cv2 10 | 11 | 12 | # Dataset class 13 | class NpyDataset(Dataset): 14 | 
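# NpyDataset expects `data_root` to contain `imgs/` and `gts/` subfolders with matching
# *.npy files: images normalized to [0, 1] and ground truth stored as integer organ labels.
# Every connected component of every organ label becomes one sample; __getitem__ returns the
# (3, 1024, 1024) image tensor, a 256x256 binary mask ("gt2D"), the full-resolution mask
# ("gt2D_orig"), the file name, and the organ class. With gt_in_ram=False, per-component
# masks are redrawn from the stored contours at access time to save memory.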
def __init__(self, data_root, image_size=1024, data_aug=True, gt_in_ram=True): 15 | self.data_root = data_root 16 | self.gt_path = join(data_root, 'gts') 17 | self.img_path = join(data_root, 'imgs') 18 | self.gt_path_files = sorted(glob.glob(join(self.gt_path, '**/*.npy'), recursive=True)) 19 | self.gt_path_files = [file for file in self.gt_path_files if os.path.isfile(join(self.img_path, os.path.basename(file)))] 20 | self.image_size = image_size 21 | self.data_aug = data_aug 22 | self.gt_in_ram = gt_in_ram 23 | self.data = self.read_data() 24 | 25 | def __len__(self): 26 | return len(self.data) 27 | 28 | def read_data(self): 29 | data = [] 30 | for gt_path in self.gt_path_files: 31 | img_name = os.path.basename(gt_path) 32 | img_path = join(self.img_path, img_name) 33 | gt = np.load(gt_path, 'r', allow_pickle=True) # multiple labels [0,1,4,5...], (256,256) 34 | label_ids = np.unique(gt)[1:] # [1,4,5...] 35 | for label_id in label_ids: 36 | gt2D = np.uint8(gt == label_id) 37 | gt2D = (gt2D * 255).astype(np.uint8) 38 | thresh = cv2.threshold( 39 | gt2D, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU 40 | )[1] 41 | cnts = cv2.findContours( 42 | thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE 43 | ) 44 | cnts = cnts[0] if len(cnts) == 2 else cnts[1] 45 | 46 | for i in range(len(cnts)): 47 | if self.gt_in_ram: 48 | mask = np.zeros_like(gt2D) 49 | cv2.drawContours( 50 | mask, cnts, i, (255, 255, 255), thickness=cv2.FILLED 51 | ) 52 | gt_segm = np.uint8(mask == 255) 53 | data.append([img_path, gt_segm, label_id]) 54 | else: 55 | assert (self.image_size, self.image_size) == gt2D.shape, "GT size does not match image size" 56 | data.append([img_path, cnts, i, label_id]) 57 | return data 58 | 59 | def __getitem__(self, index): 60 | if self.gt_in_ram: 61 | img_path, gt2D, organ_class = self.data[index] 62 | else: 63 | img_path, cnts, i, organ_class = self.data[index] 64 | mask = np.zeros((self.image_size, self.image_size)) 65 | cv2.drawContours( 66 | mask, cnts, i, (255, 255, 255), thickness=cv2.FILLED 67 | ) 68 | gt2D = np.uint8(mask == 255) 69 | 70 | img_name = os.path.basename(img_path) 71 | img_1024 = np.load(img_path, 'r', allow_pickle=True) # (H, W, 3) 72 | # convert the shape to (3, H, W) 73 | img_1024 = np.transpose(img_1024, (2, 0, 1)) # (3, 256, 256) 74 | assert np.max(img_1024) <= 1.0 and np.min(img_1024) >= 0.0, 'image should be normalized to [0, 1]' 75 | 76 | # add data augmentation: random fliplr and random flipud 77 | if self.data_aug: 78 | if random.random() > 0.5: 79 | img_1024 = np.ascontiguousarray(np.flip(img_1024, axis=-1)) 80 | gt2D = np.ascontiguousarray(np.flip(gt2D, axis=-1)) 81 | if random.random() > 0.5: 82 | img_1024 = np.ascontiguousarray(np.flip(img_1024, axis=-2)) 83 | gt2D = np.ascontiguousarray(np.flip(gt2D, axis=-2)) 84 | 85 | gt2D = np.uint8(gt2D > 0) 86 | gt2D_256 = cv2.resize( 87 | gt2D, 88 | (256, 256), 89 | interpolation=cv2.INTER_NEAREST 90 | ) 91 | return { 92 | "image": torch.tensor(img_1024).float(), 93 | "gt2D": torch.tensor(gt2D_256[None, :, :]).long(), 94 | "gt2D_orig": torch.tensor(gt2D).long(), 95 | "image_name": img_name, 96 | "organ_class": organ_class 97 | } 98 | 99 | 100 | class NpyDataModule(pl.LightningDataModule): 101 | def __init__( 102 | self, 103 | train_data_path, 104 | val_data_path, 105 | test_npy_path, 106 | batch_size=8, 107 | num_workers=0, 108 | data_aug=True, 109 | gt_in_ram=True, 110 | ): 111 | self.train_data_path = train_data_path 112 | self.val_data_path = val_data_path 113 | self.test_npy_path = test_npy_path 114 | 
self.batch_size = batch_size 115 | self.num_workers = num_workers 116 | self.data_aug = data_aug 117 | self.gt_in_ram = gt_in_ram 118 | 119 | def setup(self): 120 | self.train_dataset = NpyDataset( 121 | data_root=self.train_data_path, 122 | data_aug=self.data_aug, 123 | gt_in_ram=self.gt_in_ram, 124 | ) 125 | self.val_dataset = NpyDataset( 126 | data_root=self.val_data_path, 127 | data_aug=False, 128 | gt_in_ram=self.gt_in_ram, 129 | ) 130 | self.test_dataset = NpyDataset( 131 | data_root=self.test_npy_path, 132 | data_aug=False, 133 | gt_in_ram=self.gt_in_ram, 134 | ) 135 | 136 | print("train size:", len(self.train_dataset)) 137 | print("val size:", len(self.val_dataset)) 138 | print("test size:", len(self.test_dataset)) 139 | 140 | def train_dataloader(self): 141 | return DataLoader( 142 | self.train_dataset, 143 | batch_size=self.batch_size, 144 | shuffle=True, 145 | num_workers=self.num_workers, 146 | pin_memory=True 147 | ) 148 | 149 | def val_dataloader(self): 150 | return DataLoader( 151 | self.val_dataset, 152 | batch_size=self.batch_size, 153 | shuffle=False, 154 | num_workers=self.num_workers, 155 | pin_memory=True 156 | ) 157 | 158 | def test_dataloader(self): 159 | return DataLoader( 160 | self.test_dataset, 161 | batch_size=self.batch_size, 162 | shuffle=False, 163 | num_workers=self.num_workers, 164 | pin_memory=True 165 | ) 166 | -------------------------------------------------------------------------------- /src/form_samples.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | 8 | def get_parser(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument( 11 | '--prefix_name', 12 | type=str, 13 | default="CT_Abd_word_0021", 14 | help="Fixed prefix filename." 15 | ) 16 | parser.add_argument( 17 | '--data_directory', 18 | type=str, 19 | default="human_exp/doctors_exp/gts/1", 20 | help="Path to the validation data root directory." 21 | ) 22 | 23 | return parser 24 | 25 | 26 | def compute_count_labels(directory, prefix_filename): 27 | filename_labels = {} 28 | count_labels = {} 29 | 30 | for filename in os.listdir(directory): 31 | if filename.startswith(prefix_filename): 32 | gt_pth = os.path.join(directory, filename) 33 | 34 | filename_labels[filename] = [] 35 | gt = np.load(gt_pth, 'r', allow_pickle=True) # multiple labels [0,1,4,5...], (256,256) 36 | label_ids = np.unique(gt)[1:] # [1,4,5...] 
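# For every organ label present in the slice, split its binary mask into connected
# components (external contours) and count each component as a separate instance.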
37 | for label_id in label_ids: 38 | gt2D = np.uint8(gt == label_id) 39 | gt2D = (gt2D * 255).astype(np.uint8) 40 | thresh = cv2.threshold( 41 | gt2D, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU 42 | )[1] 43 | cnts = cv2.findContours( 44 | thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE 45 | ) 46 | cnts = cnts[0] if len(cnts) == 2 else cnts[1] 47 | 48 | if label_id not in count_labels: 49 | count_labels[label_id] = 0 50 | 51 | for i in range(len(cnts)): 52 | mask = np.zeros_like(gt2D) 53 | cv2.drawContours( 54 | mask, cnts, i, (255, 255, 255), thickness=cv2.FILLED 55 | ) 56 | filename_labels[filename].append(label_id) 57 | count_labels[label_id] += 1 58 | 59 | return filename_labels, count_labels 60 | 61 | 62 | def save_metadata(counts, data_directory, filename): 63 | file_path = os.path.join(data_directory, "metadata.txt") 64 | with open(file_path, 'w') as file: 65 | file.write(f"Filename prefix: {filename}\n") 66 | for label, count in counts.items(): 67 | file.write(f"Label: {label}, Count: {count}\n") 68 | 69 | 70 | def main(data_directory, prefix_filename): 71 | filename_labels, count_labels = compute_count_labels( 72 | data_directory, 73 | prefix_filename 74 | ) 75 | 76 | sorted_filename_labels = dict(sorted(filename_labels.items())) 77 | 78 | sorted_count_labels = dict(sorted(count_labels.items())) 79 | 80 | print("Sorted filenames and labels:") 81 | for file, labels in sorted_filename_labels.items(): 82 | print(f"Filename: {file}, labels: {labels}") 83 | 84 | print("\nSorted counts of labels:") 85 | for label, count in sorted_count_labels.items(): 86 | print(f"Label: {label}, count: {count}") 87 | 88 | save_metadata(sorted_count_labels, data_directory, prefix_filename) 89 | 90 | 91 | if __name__ == "__main__": 92 | parser = get_parser() 93 | args = parser.parse_args() 94 | 95 | main(args.data_directory, args.prefix_name) 96 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import lightning as pl 3 | import matplotlib.pyplot as plt 4 | import monai 5 | import numpy as np 6 | import torch 7 | from segment_anything import sam_model_registry 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from torchmetrics import Dice, JaccardIndex 11 | import torchvision 12 | 13 | 14 | class MedSAM(pl.LightningModule): 15 | def __init__( 16 | self, 17 | backbone: str = "vit_b", 18 | medsam_checkpoint: str = None, 19 | freeze_image_encoder: bool = False, 20 | freeze_prompt_encoder: bool = False, 21 | lr: float = 0.00005, 22 | weight_decay: float = 0.01, 23 | num_points: int = 20, 24 | is_mask_diff: bool = False, 25 | is_mask_prompt: bool = False, 26 | base_medsam_checkpoint: str = None, 27 | eval_per_organ: bool = False, 28 | logger=None 29 | ): 30 | super().__init__() 31 | self.sam_model = sam_model_registry[backbone](checkpoint=medsam_checkpoint) 32 | 33 | self.freeze_prompt_encoder = freeze_prompt_encoder 34 | if self.freeze_prompt_encoder: 35 | # freeze prompt encoder 36 | for param in self.sam_model.prompt_encoder.parameters(): 37 | param.requires_grad = False 38 | print("Prompt encoder is frozen") 39 | 40 | self.freeze_image_encoder = freeze_image_encoder 41 | if self.freeze_image_encoder: 42 | for param in self.sam_model.image_encoder.parameters(): 43 | param.requires_grad = False 44 | print("Image encoder is frozen") 45 | 46 | self.lr = lr 47 | self.weight_decay = weight_decay 48 | 49 | self.jaccard = JaccardIndex(task="binary") 
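# Dice below uses threshold=0 because it is fed raw logits: a logit > 0 (sigmoid > 0.5)
# counts as foreground, matching the `> mask_threshold` binarization (0.0 in SAM) used for IoU.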
50 | self.dice_score = Dice(threshold=0) 51 | 52 | self.seg_loss = monai.losses.DiceLoss( 53 | sigmoid=True, 54 | squared_pred=True, 55 | reduction='mean' 56 | ) 57 | self.ce_loss = nn.BCEWithLogitsLoss(reduction="mean") 58 | 59 | self.num_points = num_points 60 | self.is_mask_diff = is_mask_diff 61 | self.is_mask_prompt = is_mask_prompt 62 | self.base_medsam_checkpoint = base_medsam_checkpoint 63 | 64 | if self.is_mask_diff and self.base_medsam_checkpoint is not None: 65 | # load base model 66 | self.base_sam = sam_model_registry[backbone](checkpoint=medsam_checkpoint) 67 | base_medsam_checkpoint = torch.load(self.base_medsam_checkpoint) 68 | self.base_sam.load_state_dict(base_medsam_checkpoint['state_dict'], strict=False) 69 | # freeze base model 70 | for param in self.base_sam.parameters(): 71 | param.requires_grad = False 72 | 73 | self.eval_per_organ = eval_per_organ 74 | 75 | self.clearml_logger = logger 76 | 77 | def forward(self, image, point_prompt, mask_prompt=None): 78 | image_embedding = self.sam_model.image_encoder(image) # (B, 256, 64, 64) 79 | # not need to convert box to 1024x1024 grid 80 | # bbox is already in 1024x1024 81 | sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder( 82 | points=point_prompt, 83 | boxes=None, 84 | masks=mask_prompt, 85 | ) 86 | low_res_masks, _ = self.sam_model.mask_decoder( 87 | image_embeddings=image_embedding, # (B, 256, 64, 64) 88 | image_pe=self.sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64) 89 | sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256) 90 | dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64) 91 | multimask_output=False, 92 | ) # (B, 1, 256, 256) 93 | 94 | return low_res_masks 95 | 96 | def base_pred(self, image, point_prompt): 97 | if self.base_medsam_checkpoint is not None: 98 | base_sam = self.base_sam 99 | else: 100 | base_sam = self.sam_model 101 | base_sam.eval() 102 | with torch.no_grad(): 103 | image_embedding = base_sam.image_encoder(image) # (B, 256, 64, 64) 104 | sparse_embeddings, dense_embeddings = base_sam.prompt_encoder( 105 | points=point_prompt, 106 | boxes=None, 107 | masks=None, 108 | ) 109 | low_res_masks, _ = base_sam.mask_decoder( 110 | image_embeddings=image_embedding, # (B, 256, 64, 64) 111 | image_pe=base_sam.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64) 112 | sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256) 113 | dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64) 114 | multimask_output=False, 115 | ) # (B, 1, 256, 256) 116 | 117 | base_sam.train() 118 | 119 | return low_res_masks 120 | 121 | def _shared_step(self, batch, batch_idx, 122 | phase: str, calculate_metrics: bool = True): 123 | image = batch["image"] 124 | gt2D = batch["gt2D"] # (B, 256, 256) 125 | gt2D_orig = batch["gt2D_orig"] # (B, 1024, 1024) 126 | 127 | low_base_pred_logits = None 128 | 129 | if self.is_mask_diff: 130 | point_prompt, low_base_pred_logits, base_pred_binary = self.generate_prompt_mask_diff(image, gt2D_orig) 131 | if not self.is_mask_prompt: 132 | low_base_pred_logits = None 133 | else: 134 | point_prompt = self.generate_point_prompt(gt2D_orig, phase) 135 | 136 | medsam_lite_pred = self(image, point_prompt, low_base_pred_logits) 137 | loss = self.seg_loss(medsam_lite_pred, gt2D) + self.ce_loss(medsam_lite_pred, gt2D.float()) 138 | 139 | logs = { 140 | f"loss_{phase}": loss, 141 | "step": self.current_epoch + 1 142 | } 143 | 144 | if calculate_metrics: 145 | logs.update(self._compute_metrics(medsam_lite_pred, gt2D, phase)) 146 | if self.eval_per_organ: 147 | 
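# Per-organ evaluation: metrics are additionally logged under keys such as
# dice_mean/{phase}/{organ} and iou_mean/{phase}/{organ}, computed only over the
# samples belonging to that organ class.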
classes = batch["organ_class"] 148 | logs.update(self._compute_metrics_per_organ(medsam_lite_pred, gt2D, classes, phase)) 149 | 150 | self.log_dict(logs, prog_bar=True, on_epoch=True, on_step=False) 151 | 152 | if phase == "test": 153 | if batch_idx < 10: 154 | pred_binary = medsam_lite_pred[0] > self.sam_model.mask_threshold 155 | img_name = batch["image_name"][0] 156 | orig_img = image[0].squeeze().detach().cpu().permute(1, 2, 0) 157 | pred_mask = torchvision.transforms.functional.resize( 158 | pred_binary, 159 | (1024, 1024), 160 | interpolation=2 161 | ) 162 | 163 | plt.figure(figsize=(10, 5)) 164 | 165 | plt.subplot(1, 2, 1) 166 | plt.imshow(orig_img, cmap='gray') 167 | plt.imshow(pred_mask.squeeze().detach().cpu(), alpha=0.5, cmap='viridis') 168 | plt.title("Predicted mask") 169 | 170 | plt.subplot(1, 2, 2) 171 | plt.imshow(orig_img, cmap='gray') 172 | plt.imshow(gt2D_orig[0].squeeze().detach().cpu(), alpha=0.5, cmap='viridis') 173 | plt.title("Ground Truth mask") 174 | 175 | self.clearml_logger.report_matplotlib_figure( 176 | title=f"Test Prediction: {img_name}", 177 | series="Mask pred visualization", 178 | iteration=batch_idx, 179 | figure=plt 180 | ) 181 | 182 | plt.close() 183 | 184 | if phase == "val" and self.is_mask_diff: 185 | if batch_idx % 5 == 0: 186 | orig_img = image[0].squeeze().detach().cpu().permute(1, 2, 0) 187 | img_name = batch["image_name"][0] 188 | img = np.load("./data/WORD/val_CT_Abd/imgs/" + img_name, 'r', allow_pickle=True) 189 | img = (img * 255).astype(np.uint8) 190 | 191 | coords = point_prompt[0][0].cpu().tolist() 192 | 193 | pred_binary = medsam_lite_pred[0] > self.sam_model.mask_threshold 194 | pred_mask = torchvision.transforms.functional.resize( 195 | pred_binary, 196 | (1024, 1024), 197 | interpolation=2 198 | ) 199 | 200 | fig, axs = plt.subplots(1, 4, figsize=(20, 5)) 201 | 202 | axs[0].imshow(orig_img, cmap='gray') 203 | axs[0].imshow(base_pred_binary[0].squeeze().detach().cpu(), alpha=0.5, cmap='viridis') 204 | axs[0].set_title("Base prediction mask") 205 | 206 | axs[1].imshow(orig_img, cmap='gray') 207 | axs[1].imshow(gt2D_orig[0].squeeze().detach().cpu(), alpha=0.5, cmap='viridis') 208 | axs[1].set_title("Ground Truth mask") 209 | 210 | for x, y in coords: 211 | cv2.circle(img, (int(x), int(y)), 5, (0, 0, 255), -1) 212 | axs[2].imshow(img) 213 | axs[2].set_title("Generated points from mask differences") 214 | 215 | axs[3].imshow(orig_img, cmap='gray') 216 | axs[3].imshow(pred_mask.squeeze().detach().cpu(), alpha=0.5, cmap='viridis') 217 | axs[3].set_title("Final predicted mask") 218 | 219 | self.clearml_logger.report_matplotlib_figure( 220 | title=f"Validation Prediction: {img_name}", 221 | series="Mask pred visualization", 222 | iteration=batch_idx, 223 | figure=plt 224 | ) 225 | 226 | plt.close() 227 | 228 | return loss, logs 229 | 230 | def _compute_metrics_per_organ(self, pred_logits, gt_mask, classes, phase): 231 | pred_binary = pred_logits > self.sam_model.mask_threshold 232 | num_classes = torch.max(classes).item() 233 | metrics = {} 234 | for organ in range(1, num_classes + 1): 235 | dice_arr = [] 236 | jaccard_arr = [] 237 | organ_mask = (classes == organ) 238 | for i in range(pred_logits.size(0)): 239 | if organ_mask[i].item(): 240 | dice = self.dice_score(pred_logits[i], gt_mask[i]).item() 241 | jaccard = self.jaccard(pred_binary[i], gt_mask[i]).item() 242 | dice_arr.append(dice) 243 | jaccard_arr.append(jaccard) 244 | 245 | if dice_arr and jaccard_arr: 246 | dice_mean = np.mean(dice_arr) 247 | dice_std = np.std(dice_arr) 248 | 249 
| jaccard_mean = np.mean(jaccard_arr) 250 | jaccard_std = np.std(jaccard_arr) 251 | 252 | metrics[f"iou_mean/{phase}/{organ}"] = jaccard_mean 253 | metrics[f"iou_std/{phase}/{organ}"] = jaccard_std 254 | metrics[f"dice_mean/{phase}/{organ}"] = dice_mean 255 | metrics[f"dice_std/{phase}/{organ}"] = dice_std 256 | 257 | return metrics 258 | 259 | def _compute_metrics(self, pred_logits, gt_mask, phase): 260 | pred_binary = pred_logits > self.sam_model.mask_threshold 261 | dice_arr = [] 262 | jaccard_arr = [] 263 | for i in range(pred_logits.size(0)): 264 | dice = self.dice_score(pred_logits[i], gt_mask[i]).item() 265 | jaccard = self.jaccard(pred_binary[i], gt_mask[i]).item() 266 | dice_arr.append(dice) 267 | jaccard_arr.append(jaccard) 268 | dice_mean = np.mean(dice_arr) 269 | dice_std = np.std(dice_arr) 270 | 271 | jaccard_mean = np.mean(jaccard_arr) 272 | jaccard_std = np.std(jaccard_arr) 273 | 274 | metrics = { 275 | f"iou_mean/{phase}": jaccard_mean, 276 | f"iou_std/{phase}": jaccard_std, 277 | f"dice_mean/{phase}": dice_mean, 278 | f"dice_std/{phase}": dice_std, 279 | } 280 | 281 | return metrics 282 | 283 | def training_step(self, batch, batch_idx): 284 | return self._shared_step(batch, batch_idx, "train", False)[0] 285 | 286 | def validation_step(self, batch, batch_idx): 287 | self._shared_step(batch, batch_idx, "val", True) 288 | 289 | def test_step(self, batch, batch_idx): 290 | return self._shared_step(batch, batch_idx, "test", True)[1] 291 | 292 | def predict_step(self, batch, batch_idx): 293 | image = batch["image"] 294 | gt2D_orig = batch.get("gt2D_orig", None) # (B, 1024, 1024) 295 | 296 | low_base_pred_logits = batch.get("low_base_pred_logits", None) 297 | 298 | if gt2D_orig is not None: 299 | if self.is_mask_diff: 300 | if low_base_pred_logits is None: 301 | point_prompt, low_base_pred_logits, _ = self.generate_prompt_mask_diff(image, gt2D_orig) 302 | else: 303 | point_prompt, _, _ = self.generate_prompt_mask_diff(image, gt2D_orig) 304 | if not self.is_mask_prompt: 305 | low_base_pred_logits = None 306 | else: 307 | point_prompt = self.generate_point_prompt(gt2D_orig) 308 | coords = point_prompt[0] 309 | else: 310 | coords = batch["coords"] 311 | low_base_pred_logits = batch.get("low_base_pred_logits", None) 312 | coords_torch = torch.tensor(coords).float() 313 | coords_torch = torch.stack(coords_torch) 314 | labels_torch = torch.ones(coords_torch.shape[0], coords_torch.shape[1]).long() # (B, N) 315 | point_prompt = (coords_torch, labels_torch) 316 | 317 | medsam_lite_pred = self(image, point_prompt, low_base_pred_logits) 318 | 319 | return medsam_lite_pred, coords 320 | 321 | def configure_optimizers(self): 322 | 323 | optimizer = torch.optim.AdamW( 324 | self.sam_model.parameters(), 325 | lr=self.lr, 326 | betas=(0.9, 0.999), 327 | eps=1e-08, 328 | weight_decay=self.weight_decay 329 | ) 330 | 331 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 332 | optimizer, 333 | mode="min", 334 | factor=0.5, 335 | min_lr=1e-6, 336 | patience=5, 337 | verbose=True 338 | ) 339 | 340 | return { 341 | "optimizer": optimizer, 342 | "lr_scheduler": { 343 | "scheduler": scheduler, 344 | "monitor": "loss_val" 345 | } 346 | } 347 | 348 | def generate_point_prompt(self, gt2D_orig, phase=None): 349 | assert self.num_points > 0, "The number of points in the prompt cannot be less than 1" 350 | coords_torch = [] 351 | for i in range(gt2D_orig.shape[0]): # B 352 | gt2D = gt2D_orig[i].cpu().numpy() 353 | y_indices, x_indices = np.where(gt2D == 1) 354 | if self.num_points == 1: 355 | x_point = 
np.random.choice(x_indices) 356 | y_point = np.random.choice(y_indices) 357 | coords = np.array([x_point, y_point])[None, ...] 358 | else: 359 | # if phase == "train": 360 | # chosen_indices = np.random.choice(len(x_indices), self.num_points, replace=False) 361 | # x_points = x_indices[chosen_indices] 362 | # y_points = y_indices[chosen_indices] 363 | # coords = np.array([x_points, y_points]).T 364 | # else: 365 | 366 | y_indices_out, x_indices_out = np.where(gt2D == 0) 367 | 368 | num_points_in = int(self.num_points * 0.8) 369 | num_points_out = self.num_points - num_points_in 370 | 371 | chosen_indices_in = np.random.choice(len(x_indices), num_points_in, replace=False) 372 | chosen_indices_out = np.random.choice(len(x_indices_out), num_points_out, replace=False) 373 | x_points_in = x_indices[chosen_indices_in] 374 | y_points_in = y_indices[chosen_indices_in] 375 | 376 | x_points_out = x_indices_out[chosen_indices_out] 377 | y_points_out = y_indices_out[chosen_indices_out] 378 | 379 | coords_in = np.array([x_points_in, y_points_in]).T 380 | coords_out = np.array([x_points_out, y_points_out]).T 381 | coords = np.concatenate((coords_in, coords_out), axis=0) # (N, 2) 382 | 383 | # chosen_indices = np.random.choice(len(x_indices), self.num_points, replace=False) 384 | # x_points = x_indices[chosen_indices] 385 | # y_points = y_indices[chosen_indices] 386 | # coords = np.array([x_points, y_points]).T 387 | 388 | coords_torch.append(torch.tensor(coords).float()) 389 | 390 | coords_torch = torch.stack(coords_torch).to(gt2D_orig.device) # (B, N, 2) 391 | 392 | # Fixed label (1) 393 | labels_torch = torch.ones(coords_torch.shape[0], coords_torch.shape[1]).long() # (B, N) 394 | 395 | # Padding 396 | # num_padding = np.random.randint(0, self.num_points) 397 | # padding_indices = np.random.choice(coords_torch.shape[1], num_padding, replace=False) 398 | # coords_torch[:, padding_indices, :] = torch.tensor([0, 0], dtype=torch.float, device=coords_torch.device) 399 | # labels_torch[:, padding_indices] = -1 400 | 401 | # Assign ones as labels for coords_in and zeros for coords_out 402 | # num_points_in = int(coords_torch.shape[1] * 0.8) 403 | # num_points_out = coords_torch.shape[1] - num_points_in 404 | # labels_torch = torch.cat((torch.ones(coords_torch.shape[0], num_points_in), 405 | # torch.zeros(coords_torch.shape[0], num_points_out)), dim=1).long() 406 | 407 | # Random labels (0 or 1) 408 | # labels_torch = torch.randint(low=0, high=2, size=(coords_torch.shape[0], coords_torch.shape[1])).long() # (B, N) 409 | 410 | return (coords_torch, labels_torch) 411 | 412 | def generate_prompt_mask_diff(self, image, gt2D_orig): 413 | coords_torch_base = [] 414 | for i in range(gt2D_orig.shape[0]): # B 415 | gt2D = gt2D_orig[i].cpu().numpy() 416 | y_indices, x_indices = np.where(gt2D == 1) 417 | chosen_indices = np.random.choice(len(x_indices), self.num_points, replace=False) 418 | x_points = x_indices[chosen_indices] 419 | y_points = y_indices[chosen_indices] 420 | coords_base = np.array([x_points, y_points]).T # (N, 2) 421 | coords_torch_base.append(torch.tensor(coords_base).float()) 422 | coords_torch_base = torch.stack(coords_torch_base).to(gt2D_orig.device) # (B, N, 2) 423 | labels_torch_base = torch.ones(coords_torch_base.shape[0], coords_torch_base.shape[1]).long() 424 | point_prompt = (coords_torch_base, labels_torch_base) 425 | 426 | low_base_pred_logits = self.base_pred(image, point_prompt) 427 | base_pred_logits = F.interpolate( 428 | low_base_pred_logits, 429 | size=(image.shape[2], image.shape[3]), 
430 | mode="bilinear", 431 | align_corners=False, 432 | ) 433 | 434 | base_pred_binary = (base_pred_logits > self.sam_model.mask_threshold).int() 435 | gt_mask = gt2D_orig.unsqueeze(1) 436 | delta = (base_pred_binary - (gt_mask > 0).int()).abs().squeeze(1) 437 | 438 | # num_points = np.random.randint(2, self.num_points) 439 | gt_mask_for_idx = gt_mask.squeeze(1) 440 | 441 | coords_list = point_prompt[0].tolist() 442 | labels_list = point_prompt[1].tolist() 443 | 444 | for num_sample in range(delta.size(0)): 445 | conditions = [ 446 | (delta[num_sample] == 1, 0.7), # from mask differences 447 | (gt_mask_for_idx[num_sample] == 1, 0.2), # from gt mask 448 | (gt_mask_for_idx[num_sample] == 0, 0.1) # from outside of gt mask 449 | ] 450 | 451 | for pos, ratio in conditions: 452 | y_idx, x_idx = torch.where(pos) 453 | rnd_idx = np.random.randint(0, len(y_idx), int(ratio*self.num_points)) 454 | coords_list[num_sample].extend([[x_idx[id].item(), y_idx[id].item()] for id in rnd_idx]) 455 | labels_list[num_sample].extend([1] * len(rnd_idx)) 456 | 457 | coords_torch = torch.tensor( 458 | coords_list, 459 | dtype=torch.float64, 460 | requires_grad=True, 461 | device=self.device 462 | ) 463 | labels_torch = torch.tensor( 464 | labels_list, 465 | dtype=torch.float64, 466 | requires_grad=True, 467 | device=self.device 468 | ) 469 | 470 | return (coords_torch, labels_torch), low_base_pred_logits, base_pred_binary 471 | -------------------------------------------------------------------------------- /src/point_prompt_demo.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | import numpy as np 3 | import torch 4 | 5 | 6 | class PointPromptDemo: 7 | def __init__(self, model): 8 | self.model = model 9 | self.model.eval() 10 | 11 | def show_mask(self, mask, ax, random_color=False, alpha=0.30): 12 | if random_color: 13 | color = np.concatenate([np.random.random(3), np.array([alpha])], axis=0) 14 | else: 15 | color = np.array([251/255, 52/255, 30/255, alpha]) 16 | h, w = mask.shape[-2:] 17 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 18 | ax.imshow(mask_image) 19 | 20 | @torch.no_grad() 21 | def infer(self, batch): 22 | medsam_lite_pred, coords = self.model.predict_step(batch, 1) 23 | return medsam_lite_pred, coords 24 | 25 | def show(self, image, seg, fig_size=5, alpha=0.7): 26 | fig, ax = plt.subplots(1, 1, figsize=(fig_size, fig_size)) 27 | plt.tight_layout() 28 | ax.imshow(image) 29 | ax.axis('off') 30 | self.show_mask(seg, ax, random_color=False, alpha=alpha) 31 | plt.show() 32 | -------------------------------------------------------------------------------- /src/preprocess_CT.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import SimpleITK as sitk 5 | 6 | join = os.path.join 7 | import cc3d 8 | import numpy as np 9 | from skimage import transform 10 | from tqdm import tqdm 11 | 12 | 13 | def get_parser(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--nii_path', type=str, help='Path to folder with nii images') 16 | parser.add_argument('--gt_path', type=str, help='Path to folder with nii ground truth masks (labels)') 17 | parser.add_argument('--img_name_suffix', type=str, default="_0000.nii.gz") 18 | parser.add_argument('--gt_name_suffix', type=str, default=".nii.gz") 19 | parser.add_argument('--npy_path', type=str, help='Path to save npy files (path to dataset)') 20 | parser.add_argument('--proportion', type=float, 
default=1, help='Proportion of slices to be sampled from each CT') 21 | return parser 22 | 23 | 24 | def main(): 25 | parser = get_parser() 26 | args = parser.parse_args() 27 | # convert nii image to npz files, including original image and corresponding masks 28 | modality = "CT" 29 | anatomy = "Abd" # anantomy + dataset name 30 | img_name_suffix = args.img_name_suffix 31 | gt_name_suffix = args.gt_name_suffix 32 | prefix = modality + "_" + anatomy + "_" 33 | 34 | nii_path = args.nii_path 35 | gt_path = args.gt_path 36 | npy_path = args.npy_path + prefix[:-1] 37 | os.makedirs(join(npy_path, "gts"), exist_ok=True) 38 | os.makedirs(join(npy_path, "imgs"), exist_ok=True) 39 | 40 | image_size = 1024 41 | voxel_num_thre2d = 100 42 | voxel_num_thre3d = 1000 43 | 44 | names = sorted(os.listdir(gt_path)) 45 | print(f"ori \# files {len(names)=}") 46 | names = [ 47 | name 48 | for name in names 49 | if os.path.exists(join(nii_path, name.split(gt_name_suffix)[0] + img_name_suffix)) 50 | ] 51 | print(f"after sanity check \# files {len(names)=}") 52 | tumor_id = None # only set this when there are multiple tumors; convert semantic masks to instance masks 53 | # set window level and width 54 | # https://radiopaedia.org/articles/windowing-ct 55 | WINDOW_LEVEL = 40 # only for CT images 56 | WINDOW_WIDTH = 400 # only for CT images 57 | 58 | # %% save preprocessed images and masks as npz files 59 | for name in tqdm(names): # use all cases 60 | image_name = name.split(gt_name_suffix)[0] + img_name_suffix 61 | gt_name = name 62 | gt_sitk = sitk.ReadImage(join(gt_path, gt_name)) 63 | gt_data_ori = np.uint8(sitk.GetArrayFromImage(gt_sitk)) 64 | # label tumor masks as instances and remove from gt_data_ori 65 | if tumor_id is not None: 66 | tumor_bw = np.uint8(gt_data_ori == tumor_id) 67 | gt_data_ori[tumor_bw > 0] = 0 68 | # label tumor masks as instances 69 | tumor_inst, tumor_n = cc3d.connected_components( 70 | tumor_bw, connectivity=26, return_N=True 71 | ) 72 | # put the tumor instances back to gt_data_ori 73 | gt_data_ori[tumor_inst > 0] = ( 74 | tumor_inst[tumor_inst > 0] + np.max(gt_data_ori) + 1 75 | ) 76 | 77 | # exclude the objects with less than 1000 pixels in 3D 78 | gt_data_ori = cc3d.dust( 79 | gt_data_ori, threshold=voxel_num_thre3d, connectivity=26, in_place=True 80 | ) 81 | # remove small objects with less than 100 pixels in 2D slices 82 | 83 | for slice_i in range(gt_data_ori.shape[0]): 84 | gt_i = gt_data_ori[slice_i, :, :] 85 | # remove small objects with less than 100 pixels 86 | # reason: fro such small objects, the main challenge is detection rather than segmentation 87 | gt_data_ori[slice_i, :, :] = cc3d.dust( 88 | gt_i, threshold=voxel_num_thre2d, connectivity=8, in_place=True 89 | ) 90 | # find non-zero slices 91 | z_index, _, _ = np.where(gt_data_ori > 0) 92 | z_index = np.unique(z_index) 93 | 94 | if len(z_index) > 0: 95 | # crop the ground truth with non-zero slices 96 | gt_roi = gt_data_ori[z_index, :, :] 97 | # load image and preprocess 98 | img_sitk = sitk.ReadImage(join(nii_path, image_name)) 99 | image_data = sitk.GetArrayFromImage(img_sitk) 100 | # nii preprocess start 101 | if modality == "CT": 102 | lower_bound = WINDOW_LEVEL - WINDOW_WIDTH / 2 103 | upper_bound = WINDOW_LEVEL + WINDOW_WIDTH / 2 104 | image_data_pre = np.clip(image_data, lower_bound, upper_bound) 105 | image_data_pre = ( 106 | (image_data_pre - np.min(image_data_pre)) 107 | / (np.max(image_data_pre) - np.min(image_data_pre)) 108 | * 255.0 109 | ) 110 | else: 111 | lower_bound, upper_bound = np.percentile( 112 
| image_data[image_data > 0], 0.5 113 | ), np.percentile(image_data[image_data > 0], 99.5) 114 | image_data_pre = np.clip(image_data, lower_bound, upper_bound) 115 | image_data_pre = ( 116 | (image_data_pre - np.min(image_data_pre)) 117 | / (np.max(image_data_pre) - np.min(image_data_pre)) 118 | * 255.0 119 | ) 120 | image_data_pre[image_data == 0] = 0 121 | 122 | image_data_pre = np.uint8(image_data_pre) 123 | img_roi = image_data_pre[z_index, :, :] 124 | # np.savez_compressed(join(npy_path, prefix + gt_name.split(gt_name_suffix)[0]+'.npz'), imgs=img_roi, gts=gt_roi, spacing=img_sitk.GetSpacing()) 125 | # save the image and ground truth as nii files for sanity check; 126 | # they can be removed 127 | img_roi_sitk = sitk.GetImageFromArray(img_roi) 128 | gt_roi_sitk = sitk.GetImageFromArray(gt_roi) 129 | sitk.WriteImage( 130 | img_roi_sitk, 131 | join(npy_path, prefix + gt_name.split(gt_name_suffix)[0] + "_img.nii.gz"), 132 | ) 133 | sitk.WriteImage( 134 | gt_roi_sitk, 135 | join(npy_path, prefix + gt_name.split(gt_name_suffix)[0] + "_gt.nii.gz"), 136 | ) 137 | # save the each CT image as npy file based on a proportion to be saved 138 | for i in range(img_roi.shape[0]): 139 | if np.random.uniform(0, 1) < args.proportion: 140 | img_i = img_roi[i, :, :] 141 | img_3c = np.repeat(img_i[:, :, None], 3, axis=-1) 142 | resize_img_skimg = transform.resize( 143 | img_3c, 144 | (image_size, image_size), 145 | order=3, 146 | preserve_range=True, 147 | mode="constant", 148 | anti_aliasing=True, 149 | ) 150 | resize_img_skimg_01 = (resize_img_skimg - resize_img_skimg.min()) / np.clip( 151 | resize_img_skimg.max() - resize_img_skimg.min(), a_min=1e-8, a_max=None 152 | ) # normalize to [0, 1], (H, W, 3) 153 | gt_i = gt_roi[i, :, :] 154 | resize_gt_skimg = transform.resize( 155 | gt_i, 156 | (image_size, image_size), 157 | order=0, 158 | preserve_range=True, 159 | mode="constant", 160 | anti_aliasing=False, 161 | ) 162 | resize_gt_skimg = np.uint8(resize_gt_skimg) 163 | assert resize_img_skimg_01.shape[:2] == resize_gt_skimg.shape 164 | np.save( 165 | join( 166 | npy_path, 167 | "imgs", 168 | prefix 169 | + gt_name.split(gt_name_suffix)[0] 170 | + "-" 171 | + str(i).zfill(3) 172 | + ".npy", 173 | ), 174 | resize_img_skimg_01, 175 | ) 176 | np.save( 177 | join( 178 | npy_path, 179 | "gts", 180 | prefix 181 | + gt_name.split(gt_name_suffix)[0] 182 | + "-" 183 | + str(i).zfill(3) 184 | + ".npy", 185 | ), 186 | resize_gt_skimg, 187 | ) 188 | 189 | 190 | if __name__ == "__main__": 191 | main() 192 | -------------------------------------------------------------------------------- /src/test_bboxes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | 5 | import cv2 6 | import numpy as np 7 | from matplotlib import pyplot as plt 8 | from segment_anything import sam_model_registry 9 | import torch 10 | import torchvision 11 | from tqdm import tqdm 12 | 13 | from dataset import NpyDataModule 14 | 15 | 16 | def get_parser(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | '--tr_npy_path', 20 | type=str, 21 | help="Path to the train data root directory.", 22 | required=True 23 | ) 24 | parser.add_argument( 25 | '--val_npy_path', 26 | type=str, 27 | help="Path to the validation data root directory.", 28 | required=True 29 | ) 30 | parser.add_argument( 31 | '--test_npy_path', 32 | type=str, 33 | help="Path to the test data root directory.", 34 | required=True 35 | ) 36 | parser.add_argument( 37 | '--batch_size', 
38 | type=int, 39 | default=4, 40 | help="Batch size." 41 | ) 42 | parser.add_argument( 43 | '--num_workers', 44 | type=int, 45 | default=0, 46 | help="Number of data loader workers." 47 | ) 48 | parser.add_argument( 49 | '--seed', 50 | type=int, 51 | default=2023, 52 | help="Random seed for reproducibility." 53 | ) 54 | parser.add_argument( 55 | '--disable_aug', 56 | action='store_true', 57 | help="Disable data augmentation." 58 | ) 59 | parser.add_argument( 60 | '--gt_in_ram', 61 | default=True, 62 | action=argparse.BooleanOptionalAction 63 | ) 64 | parser.add_argument( 65 | '--num_classes', 66 | type=int, 67 | default=16, 68 | help="Number of classes in the dataset." 69 | ) 70 | return parser 71 | 72 | 73 | def load_model(medsam_checkpoint="weights/medsam/medsam_vit_b.pth", device="cuda"): 74 | medsam_model = sam_model_registry['vit_b'](checkpoint=medsam_checkpoint) 75 | return medsam_model.to(device) 76 | 77 | 78 | def dice(pred, true, k=1): 79 | intersection = np.sum(pred[true == k]) * 2.0 80 | dice = intersection / (np.sum(pred) + np.sum(true)) 81 | return dice 82 | 83 | 84 | def compute_metrics_per_organ(pred_logits, gt_mask, classes): 85 | pred_binary = pred_logits > 0.0 86 | num_classes = torch.max(classes).item() 87 | metrics = {} 88 | for organ in range(1, num_classes + 1): 89 | dice_arr = [] 90 | for i in range(pred_logits.size(0)): 91 | organ_mask = (classes == organ) 92 | if organ_mask[i].item(): 93 | dice_score = dice( 94 | np.uint8(pred_binary[i].cpu().numpy()), 95 | gt_mask[i].cpu().numpy()) 96 | dice_arr.append(dice_score) 97 | if dice_arr: 98 | dice_mean = np.mean(dice_arr) 99 | dice_std = np.std(dice_arr) 100 | metrics[f"dice_mean/{organ}"] = dice_mean 101 | metrics[f"dice_std/{organ}"] = dice_std 102 | return metrics 103 | 104 | 105 | def generate_bboxes(batch): 106 | batch_gts = batch["gt2D_orig"] 107 | batch_bboxes = [] 108 | for i in range(batch_gts.shape[0]): 109 | gt_segm = batch_gts[i].squeeze().detach().cpu().numpy() 110 | gt_segm = cv2.normalize(gt_segm, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8) 111 | thresh = cv2.threshold(gt_segm, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1] 112 | cnts = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 113 | cnts = cnts[0] if len(cnts) == 2 else cnts[1] 114 | bboxes = [] 115 | for c in cnts: 116 | x, y, w, h = cv2.boundingRect(c) 117 | bboxes.append((x, y, w + x, h + y)) 118 | batch_bboxes.append(bboxes) 119 | return batch_bboxes 120 | 121 | 122 | def infer_bboxes(batch, medsam_model, device): 123 | with torch.no_grad(): 124 | img_embed = medsam_model.image_encoder(batch["image"].to(device)) 125 | batch_bboxes = generate_bboxes(batch) 126 | box_torch = torch.as_tensor( 127 | batch_bboxes, 128 | dtype=torch.float, 129 | device=img_embed.device 130 | ) 131 | 132 | if len(box_torch.shape) == 2: 133 | box_torch = box_torch[:, None, :] # (B, 1, 4) 134 | 135 | sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder( 136 | points=None, 137 | boxes=box_torch, 138 | masks=None, 139 | ) 140 | low_res_logits, _ = medsam_model.mask_decoder( 141 | image_embeddings=img_embed, 142 | image_pe=medsam_model.prompt_encoder.get_dense_pe(), 143 | sparse_prompt_embeddings=sparse_embeddings, 144 | dense_prompt_embeddings=dense_embeddings, 145 | multimask_output=False, 146 | ) 147 | pred_binary = low_res_logits > 0.0 148 | pred_mask = torchvision.transforms.functional.resize( 149 | pred_binary, 150 | (1024, 1024), 151 | interpolation=2 152 | ) 153 | medsam_seg = pred_mask.squeeze() 154 | return medsam_seg 155 | 
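# Usage sketch (hypothetical; mirrors what main() below does): run box-prompt inference
# on a single test batch. Boxes come from the ground-truth masks via generate_bboxes(),
# so this estimates MedSAM's box-prompt quality with oracle boxes rather than a detector.
#   batch = next(iter(datamodule.test_dataloader()))
#   seg = infer_bboxes(batch, medsam_model, device)  # boolean masks resized to 1024x1024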
156 | 157 | def show_mask(mask, ax, random_color=False): 158 | if random_color: 159 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 160 | else: 161 | color = np.array([251/255, 252/255, 30/255, 0.6]) 162 | h, w = mask.shape[-2:] 163 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 164 | ax.imshow(mask_image) 165 | 166 | 167 | def show_box(box, ax): 168 | x0, y0 = box[0], box[1] 169 | w, h = box[2] - box[0], box[3] - box[1] 170 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2)) 171 | 172 | 173 | def main(): 174 | parser = get_parser() 175 | args = parser.parse_args() 176 | 177 | seed = args.seed 178 | torch.cuda.empty_cache() 179 | os.environ['PYTHONHASHSEED'] = str(seed) 180 | random.seed(seed) 181 | np.random.seed(seed) 182 | torch.manual_seed(seed) 183 | torch.cuda.manual_seed(seed) 184 | 185 | medsam_checkpoint = "weights/medsam/medsam_vit_b.pth" 186 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 187 | medsam_model = load_model(medsam_checkpoint, device) 188 | medsam_model.eval() 189 | print(f"MedSAM size: {sum(p.numel() for p in medsam_model.parameters())}") 190 | 191 | datamodule = NpyDataModule( 192 | train_data_path=args.tr_npy_path, 193 | val_data_path=args.val_npy_path, 194 | test_npy_path=args.test_npy_path, 195 | batch_size=args.batch_size, 196 | num_workers=args.num_workers, 197 | gt_in_ram=args.gt_in_ram, 198 | ) 199 | datamodule.setup() 200 | num_classes = args.num_classes 201 | 202 | metrics = {} 203 | 204 | for batch in tqdm(datamodule.test_dataloader()): 205 | medsam_seg = infer_bboxes(batch, medsam_model, device) 206 | metrics.update( 207 | compute_metrics_per_organ( 208 | medsam_seg, 209 | batch["gt2D_orig"], 210 | batch["organ_class"] 211 | ) 212 | ) 213 | 214 | mean_dice = np.mean([metrics[f"dice_mean/{organ}"] for organ in range(1, num_classes + 1)]) 215 | std_dice = np.std([metrics[f"dice_mean/{organ}"] for organ in range(1, num_classes + 1)]) 216 | 217 | print(f"Total mean dice: {mean_dice:.4f}") 218 | print(f"Total std of dice: {std_dice:.4f}") 219 | 220 | for organ in range(1, num_classes + 1): 221 | print(f"mean_dice/{organ}: {metrics.get(f'dice_mean/{organ}'):.4f}") 222 | 223 | for organ in range(1, num_classes + 1): 224 | print(f"mean_std/{organ}: {metrics.get(f'dice_std/{organ}'):.4f}") 225 | 226 | print("\nMetrics summary:") 227 | print(metrics) 228 | 229 | 230 | if __name__ == "__main__": 231 | main() 232 | -------------------------------------------------------------------------------- /src/test_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | from datetime import datetime 5 | 6 | import lightning as pl 7 | import numpy as np 8 | import torch 9 | from clearml import Task 10 | 11 | from dataset import NpyDataModule 12 | from model import MedSAM 13 | 14 | 15 | def get_parser(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | '--tr_npy_path', 19 | type=str, 20 | help="Path to the train data root directory.", 21 | required=True 22 | ) 23 | parser.add_argument( 24 | '--val_npy_path', 25 | type=str, 26 | help="Path to the validation data root directory.", 27 | required=True 28 | ) 29 | parser.add_argument( 30 | '--test_npy_path', 31 | type=str, 32 | help="Path to the test data root directory.", 33 | required=True 34 | ) 35 | parser.add_argument( 36 | '--medsam_checkpoint', 37 | type=str, 38 | help="Path to the MedSAM checkpoint.", 39 | 
required=True 40 | ) 41 | parser.add_argument( 42 | '--checkpoint', 43 | type=str, 44 | help="MedSAM fine-tuned checkpoint file name.", 45 | required=True 46 | ) 47 | parser.add_argument( 48 | '--batch_size', 49 | type=int, 50 | default=16, 51 | help="Batch size." 52 | ) 53 | parser.add_argument( 54 | '--num_workers', 55 | type=int, 56 | default=8, 57 | help="Number of data loader workers." 58 | ) 59 | parser.add_argument( 60 | '--seed', 61 | type=int, 62 | default=2023, 63 | help="Random seed for reproducibility." 64 | ) 65 | parser.add_argument( 66 | '--disable_aug', 67 | action='store_true', 68 | help="Disable data augmentation." 69 | ) 70 | parser.add_argument( 71 | '--gt_in_ram', 72 | default=True, 73 | action=argparse.BooleanOptionalAction 74 | ) 75 | parser.add_argument( 76 | '--num_points', 77 | type=int, 78 | default=1, 79 | help="Number of points in prompt to test on." 80 | ) 81 | parser.add_argument( 82 | '--mask_diff', 83 | default=False, 84 | action=argparse.BooleanOptionalAction 85 | ) 86 | parser.add_argument( 87 | '--mask_prompt', 88 | default=False, 89 | action=argparse.BooleanOptionalAction 90 | ) 91 | parser.add_argument( 92 | '--base_medsam_checkpoint', 93 | type=str, 94 | default=None, 95 | help="Path to the base predictor (MedSAM) checkpoint." 96 | ) 97 | parser.add_argument( 98 | '--eval_per_organ', 99 | default=False, 100 | action=argparse.BooleanOptionalAction 101 | ) 102 | 103 | return parser 104 | 105 | 106 | def test(exp_name, args): 107 | task = Task.init( 108 | project_name="medsam_point", 109 | tags=[ 110 | "testing", 111 | "1_point", 112 | ], 113 | task_name=exp_name, 114 | auto_connect_frameworks={"matplotlib": False}, 115 | ) 116 | 117 | medsam_model = MedSAM( 118 | medsam_checkpoint=args.medsam_checkpoint, 119 | freeze_image_encoder=True, 120 | num_points=args.num_points, 121 | eval_per_organ=args.eval_per_organ, 122 | is_mask_diff=args.mask_diff, 123 | is_mask_prompt=args.mask_prompt, 124 | base_medsam_checkpoint=args.base_medsam_checkpoint, 125 | logger=task.get_logger() 126 | ) 127 | checkpoint = torch.load("logs/" + args.checkpoint) 128 | medsam_model.load_state_dict(checkpoint['state_dict'], strict=False) 129 | 130 | datamodule = NpyDataModule( 131 | args.tr_npy_path, 132 | args.val_npy_path, 133 | args.test_npy_path, 134 | batch_size=args.batch_size, 135 | num_workers=args.num_workers, 136 | data_aug=not args.disable_aug, 137 | gt_in_ram=args.gt_in_ram, 138 | ) 139 | datamodule.setup() 140 | 141 | trainer = pl.Trainer() 142 | 143 | test_dice = trainer.test( 144 | medsam_model, 145 | datamodule.test_dataloader() 146 | )[0]["dice_mean/test"] 147 | 148 | return test_dice 149 | 150 | 151 | def main(): 152 | parser = get_parser() 153 | args = parser.parse_args() 154 | 155 | seed = args.seed 156 | torch.cuda.empty_cache() 157 | os.environ['PYTHONHASHSEED'] = str(seed) 158 | random.seed(seed) 159 | np.random.seed(seed) 160 | torch.manual_seed(seed) 161 | torch.cuda.manual_seed(seed) 162 | 163 | exp_name = datetime.now().strftime("%d-%m-%Y %H:%M:%S") 164 | test_dice = test(exp_name, args) 165 | print(test_dice) 166 | 167 | if __name__ == "__main__": 168 | main() 169 | -------------------------------------------------------------------------------- /src/train_point_prompt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | from datetime import datetime 5 | 6 | import lightning as pl 7 | import numpy as np 8 | import torch 9 | from clearml import Task 10 | from 
lightning.pytorch.callbacks import ModelCheckpoint 11 | from lightning.pytorch.callbacks.early_stopping import EarlyStopping 12 | 13 | from dataset import NpyDataModule 14 | from model import MedSAM 15 | 16 | 17 | def get_parser(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument( 20 | '--tr_npy_path', 21 | type=str, 22 | help="Path to the train data root directory.", 23 | required=True 24 | ) 25 | parser.add_argument( 26 | '--val_npy_path', 27 | type=str, 28 | help="Path to the validation data root directory.", 29 | required=True 30 | ) 31 | parser.add_argument( 32 | '--test_npy_path', 33 | type=str, 34 | help="Path to the test data root directory.", 35 | required=True 36 | ) 37 | parser.add_argument( 38 | '--medsam_checkpoint', 39 | type=str, 40 | help="Path to the MedSAM checkpoint.", 41 | required=True 42 | ) 43 | parser.add_argument( 44 | '--max_epochs', 45 | type=int, 46 | default=1000, 47 | help="Maximum number of epochs." 48 | ) 49 | parser.add_argument( 50 | '--batch_size', 51 | type=int, 52 | default=16, 53 | help="Batch size." 54 | ) 55 | parser.add_argument( 56 | '--num_workers', 57 | type=int, 58 | default=8, 59 | help="Number of data loader workers." 60 | ) 61 | parser.add_argument( 62 | '--lr', 63 | type=float, 64 | default=0.00005, 65 | help="learning rate (absolute lr)" 66 | ) 67 | parser.add_argument( 68 | '--weight_decay', 69 | type=float, 70 | default=0.01, 71 | help="Weight decay." 72 | ) 73 | parser.add_argument( 74 | '--accumulate_grad_batches', 75 | type=int, 76 | default=4, 77 | help="Accumulate grad batches." 78 | ) 79 | parser.add_argument( 80 | '--seed', 81 | type=int, 82 | default=2023, 83 | help="Random seed for reproducibility." 84 | ) 85 | parser.add_argument( 86 | '--disable_aug', 87 | action='store_true', 88 | help="Disable data augmentation." 89 | ) 90 | parser.add_argument( 91 | '--freeze_prompt_encoder', 92 | default=True, 93 | action=argparse.BooleanOptionalAction 94 | ) 95 | parser.add_argument( 96 | '--gt_in_ram', 97 | default=True, 98 | action=argparse.BooleanOptionalAction 99 | ) 100 | parser.add_argument( 101 | '--num_points', 102 | type=int, 103 | default=1, 104 | help="Number of points in prompt." 105 | ) 106 | parser.add_argument( 107 | '--mask_diff', 108 | default=False, 109 | action=argparse.BooleanOptionalAction 110 | ) 111 | parser.add_argument( 112 | '--mask_prompt', 113 | default=False, 114 | action=argparse.BooleanOptionalAction 115 | ) 116 | parser.add_argument( 117 | '--base_medsam_checkpoint', 118 | type=str, 119 | default=None, 120 | help="Path to the base predictor (MedSAM) checkpoint." 
121 | ) 122 | parser.add_argument( 123 | '--eval_per_organ', 124 | default=False, 125 | action=argparse.BooleanOptionalAction 126 | ) 127 | 128 | return parser 129 | 130 | 131 | def train(exp_name, args): 132 | task = Task.init( 133 | project_name="medsam_point", 134 | tags=[ 135 | "fine_tuning", 136 | "fixed_label_1", 137 | "mask_diff" 138 | # "fixed_label_1", 139 | # "remove_point_embedding", # add tags if necessary 140 | ], 141 | task_name=exp_name, 142 | ) 143 | 144 | medsam_model = MedSAM( 145 | medsam_checkpoint=args.medsam_checkpoint, 146 | freeze_image_encoder=True, 147 | freeze_prompt_encoder=args.freeze_prompt_encoder, 148 | lr=args.lr, 149 | weight_decay=args.weight_decay, 150 | num_points=args.num_points, 151 | is_mask_diff=args.mask_diff, 152 | is_mask_prompt=args.mask_prompt, 153 | base_medsam_checkpoint=args.base_medsam_checkpoint, 154 | eval_per_organ=args.eval_per_organ, 155 | logger=task.get_logger() 156 | ) 157 | 158 | print(f"MedSAM size: {sum(p.numel() for p in medsam_model.parameters())}") 159 | 160 | datamodule = NpyDataModule( 161 | args.tr_npy_path, 162 | args.val_npy_path, 163 | args.test_npy_path, 164 | batch_size=args.batch_size, 165 | num_workers=args.num_workers, 166 | data_aug=not args.disable_aug, 167 | gt_in_ram=args.gt_in_ram, 168 | ) 169 | datamodule.setup() 170 | 171 | checkpoint_callback = ModelCheckpoint( 172 | dirpath="logs/", 173 | filename=f"{exp_name}-" + "{epoch}-{loss_val:.2f}", 174 | save_top_k=1, 175 | monitor="loss_val", 176 | mode="min", 177 | ) 178 | 179 | early_stop_callback = EarlyStopping( 180 | monitor="loss_val", 181 | min_delta=1e-4, 182 | patience=10, 183 | verbose=False, 184 | mode="min" 185 | ) 186 | 187 | trainer = pl.Trainer( 188 | max_epochs=args.max_epochs, 189 | accumulate_grad_batches=args.accumulate_grad_batches, 190 | callbacks=[checkpoint_callback, early_stop_callback], 191 | accelerator="gpu", 192 | devices=1 193 | ) 194 | trainer.fit( 195 | medsam_model, 196 | train_dataloaders=datamodule.train_dataloader(), 197 | val_dataloaders=datamodule.val_dataloader(), 198 | ) 199 | 200 | test_dice = trainer.test( 201 | medsam_model, 202 | datamodule.test_dataloader() 203 | )[0]["dice_mean/test"] 204 | 205 | return test_dice 206 | 207 | 208 | def main(): 209 | parser = get_parser() 210 | args = parser.parse_args() 211 | 212 | seed = args.seed 213 | torch.cuda.empty_cache() 214 | os.environ['PYTHONHASHSEED'] = str(seed) 215 | random.seed(seed) 216 | np.random.seed(seed) 217 | torch.manual_seed(seed) 218 | torch.cuda.manual_seed(seed) 219 | 220 | exp_name = datetime.now().strftime("%d-%m-%Y-%H:%M:%S") 221 | test_dice = train(exp_name, args) 222 | print(test_dice) 223 | 224 | 225 | if __name__ == "__main__": 226 | main() 227 | -------------------------------------------------------------------------------- /test_random.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for checkpoint in "07-07-2024-11:43:57-epoch=34-loss_val=0.26.ckpt" "07-07-2024-14:22:10-epoch=41-loss_val=0.20.ckpt" "07-07-2024-18:04:05-epoch=43-loss_val=0.16.ckpt" "07-07-2024-21:46:09-epoch=22-loss_val=0.14.ckpt" "08-07-2024-00:08:52-epoch=48-loss_val=0.12.ckpt" "04-07-2024-11:38:46-epoch=59-loss_val=0.12.ckpt" "08-07-2024-04:21:35-epoch=122-loss_val=0.10.ckpt" 4 | do 5 | for num in 1 3 5 10 15 20 50 90 100 6 | do 7 | python src/test_model.py \ 8 | --tr_npy_path "data/WORD/train_CT_Abd/" \ 9 | --val_npy_path "data/WORD/val_CT_Abd/" \ 10 | --test_npy_path "data/WORD/test_CT_Abd/" \ 11 | --medsam_checkpoint
"weights/medsam/medsam_vit_b.pth" \ 12 | --checkpoint $checkpoint \ 13 | --batch_size 24 \ 14 | --num_workers 0 \ 15 | --num_points $num 16 | done 17 | done 18 | -------------------------------------------------------------------------------- /test_samples/gts/CT_Abd_word_0014-000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/gts/CT_Abd_word_0014-000.npy -------------------------------------------------------------------------------- /test_samples/gts/CT_Abd_word_0014-050.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/gts/CT_Abd_word_0014-050.npy -------------------------------------------------------------------------------- /test_samples/gts/CT_Abd_word_0016-099.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/gts/CT_Abd_word_0016-099.npy -------------------------------------------------------------------------------- /test_samples/gts/CT_Abd_word_0017-020.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/gts/CT_Abd_word_0017-020.npy -------------------------------------------------------------------------------- /test_samples/gts/CT_Abd_word_0019-110.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/gts/CT_Abd_word_0019-110.npy -------------------------------------------------------------------------------- /test_samples/gts/CT_Abd_word_0019-129.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/gts/CT_Abd_word_0019-129.npy -------------------------------------------------------------------------------- /test_samples/gts/CT_Abd_word_0021-001.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/gts/CT_Abd_word_0021-001.npy -------------------------------------------------------------------------------- /test_samples/gts/CT_Abd_word_0021-021.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/gts/CT_Abd_word_0021-021.npy -------------------------------------------------------------------------------- /test_samples/gts/CT_Abd_word_0024-100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/gts/CT_Abd_word_0024-100.npy -------------------------------------------------------------------------------- /test_samples/imgs/CT_Abd_word_0014-000.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/imgs/CT_Abd_word_0014-000.npy -------------------------------------------------------------------------------- /test_samples/imgs/CT_Abd_word_0014-050.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/imgs/CT_Abd_word_0014-050.npy -------------------------------------------------------------------------------- /test_samples/imgs/CT_Abd_word_0016-099.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/imgs/CT_Abd_word_0016-099.npy -------------------------------------------------------------------------------- /test_samples/imgs/CT_Abd_word_0017-020.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/imgs/CT_Abd_word_0017-020.npy -------------------------------------------------------------------------------- /test_samples/imgs/CT_Abd_word_0019-110.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/imgs/CT_Abd_word_0019-110.npy -------------------------------------------------------------------------------- /test_samples/imgs/CT_Abd_word_0019-129.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/imgs/CT_Abd_word_0019-129.npy -------------------------------------------------------------------------------- /test_samples/imgs/CT_Abd_word_0021-001.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/imgs/CT_Abd_word_0021-001.npy -------------------------------------------------------------------------------- /test_samples/imgs/CT_Abd_word_0021-021.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/imgs/CT_Abd_word_0021-021.npy -------------------------------------------------------------------------------- /test_samples/imgs/CT_Abd_word_0024-100.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leiluk1/gaze-based-segmentation/f7b97320b8fac3a78bab20bfeaa3d8502fdae079/test_samples/imgs/CT_Abd_word_0024-100.npy -------------------------------------------------------------------------------- /weights/README.md: -------------------------------------------------------------------------------- 1 | ## Model weights 2 | 3 | Download checkpoints [ViT-B SAM](https://github.com/facebookresearch/segment-anything?tab=readme-ov-file#model-checkpoints) and [ViT-B MedSAM](https://github.com/bowang-lab/MedSAM/tree/main?tab=readme-ov-file#get-started). 
Place them in the `sam` and `medsam` folders, respectively. 4 | --------------------------------------------------------------------------------
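As a supplement to `weights/README.md` above, here is a minimal sketch of the `weights/` layout the scripts expect once the checkpoints are in place. The `medsam_vit_b.pth` file name matches the path hard-coded in `src/test_bboxes.py` and `test_random.sh`; the SAM file name shown under `weights/sam/` is an assumption based on the upstream ViT-B release and may differ depending on which checkpoint you download:

```
# Create the folders referenced in weights/README.md
mkdir -p weights/sam weights/medsam

# Expected layout after placing the downloaded checkpoints:
# weights/
# ├── README.md
# ├── sam/
# │   └── sam_vit_b_01ec64.pth   # ViT-B SAM checkpoint (assumed file name)
# └── medsam/
#     └── medsam_vit_b.pth       # ViT-B MedSAM checkpoint (path used by the test scripts)

# Sanity check that the MedSAM checkpoint sits where the scripts look for it:
ls -lh weights/medsam/medsam_vit_b.pth
```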