├── 10453.png ├── README.md ├── README_zh.md ├── dataset.py ├── datasets └── OCTA-500 │ └── OCTA_3M │ ├── GT_Artery │ └── 10301.bmp │ ├── GT_Capillary │ └── 10301.bmp │ ├── GT_FAZ │ └── 10301.bmp │ ├── GT_LargeVessel │ └── 10301.bmp │ ├── GT_Vein │ └── 10301.bmp │ └── ProjectionMaps │ ├── OCTA(FULL) │ └── 10301.bmp │ ├── OCTA(ILM_OPL) │ └── 10301.bmp │ └── OCTA(OPL_BM) │ └── 10301.bmp ├── display.py ├── figures ├── pred_rv_global.png ├── pred_rv_local.png ├── pred_rv_local2.png ├── sample_3ch.png ├── sample_3ch_prompt.png ├── sample_3ch_prompt2.png ├── sample_FAZ.gif ├── sample_RV.gif ├── sample_artery.gif └── sample_capillary.gif ├── loss_functions.py ├── metrics.py ├── options.py ├── predict.py ├── prompt_points.py ├── sam_lora_image_encoder.py ├── sam_weights └── keep_dir.txt ├── segment_anything ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-37.pyc │ ├── automatic_mask_generator.cpython-310.pyc │ ├── build_sam.cpython-310.pyc │ ├── build_sam.cpython-37.pyc │ └── predictor.cpython-310.pyc ├── automatic_mask_generator.py ├── build_sam.py ├── modeling │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── common.cpython-310.pyc │ │ ├── image_encoder.cpython-310.pyc │ │ ├── mask_decoder.cpython-310.pyc │ │ ├── prompt_encoder.cpython-310.pyc │ │ ├── sam.cpython-310.pyc │ │ └── transformer.cpython-310.pyc │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ └── transformer.py ├── predictor.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── amg.cpython-310.pyc │ └── transforms.cpython-310.pyc │ ├── amg.py │ ├── onnx.py │ └── transforms.py ├── test_sam_octa.py └── train_sam_octa.py /10453.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/10453.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SAM-OCTA 2 | 3 | Chinese README: [README_zh](./README_zh.md) 4 | 5 | ## 1.Quick Start 6 | 7 | This project fine-tunes SAM with LoRA to perform segmentation tasks on OCTA images, and is built with **PyTorch**. 8 | 9 | First, you should put a pretrained weight file in the **sam_weights** folder. The download links for the pre-trained weights are as follows: 10 | 11 | vit_h (default): https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 12 | 13 | vit_l: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth 14 | 15 | vit_b: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth 16 | 17 | After testing, the GPU memory required by the three models is as follows: **36,248 MB, 26,154 MB, and 13,467 MB**. "vit_h" is the default option. If you need to use one of the smaller models, please download the corresponding weights and modify the configuration in **options.py**. 18 | 19 | ... 20 | parser.add_argument("-model_type", type=str, default="vit_h") 21 | ... 22 | 23 | Use **train_sam_octa.py** to start fine-tuning. The warning messages will tell you which packages you need to install. These are commonly used Python libraries that require no additional configuration. 24 | 25 | python train_sam_octa.py 26 | 27 | The dataset should be organized in the **OCTA-500** format, like this: 28 | 29 | /datasets 30 | /OCTA-500 31 | /OCTA_3M 32 | /GT_Artery 33 | 10301.bmp 34 | 10302.bmp 35 | ...
36 | /GT_Capillary 37 | 10301.bmp 38 | 10302.bmp 39 | ... 40 | /GT_FAZ 41 | ... 42 | /ProjectionMaps 43 | /OCTA(FULL) 44 | 10301.bmp 45 | 10302.bmp 46 | ... 47 | /OCTA(ILM_OPL) 48 | 10301.bmp 49 | 10302.bmp 50 | ... 51 | /OCTA(OPL_BM) 52 | 10301.bmp 53 | 10302.bmp 54 | ... 55 | /OCTA_6M 56 | ... 57 | 58 | Here, I use the sample with ID 10301 from the 3M FoV (Field of View) subset of the **OCTA-500** dataset as an example. If you need the complete dataset, please contact the authors of the **OCTA-500** dataset. 59 | 60 | **OCTA-500**'s related paper: https://arxiv.org/abs/2012.07261 61 | 62 | The results and metrics will be recorded in the **results** folder (if it doesn't exist, it will be created). 63 | 64 | If you need to visualize the predicted samples, please use the **display.py** file. Since the result folders are named by timestamp, you may need to update this line of code. The generated images are saved in the **sample_display** folder. 65 | 66 | ... 67 | test_dir = "results/2024-01-01-08-17-09/3M_LargeVessel_100_True/0/0000" # Your result dir 68 | ... 69 | 70 | Here are some segmentation samples with prompt points, showing the input image, the ground truth, and the prediction from left to right. 71 | 72 | **Local Mode** 73 | 74 | *Artery* 75 | 76 | ![Sample](./figures/sample_artery.gif) 77 | 78 | *FAZ* 79 | 80 | ![Sample](./figures/sample_FAZ.gif) 81 | 82 | **Global Mode** 83 | 84 | *RV* 85 | 86 | ![Sample](./figures/sample_RV.gif) 87 | 88 | *Capillary* 89 | 90 | ![Sample](./figures/sample_capillary.gif) 91 | 92 | ## 2.Configuration 93 | 94 | The project supports multiple segmentation tasks and has two modes: **global** and **local**. The performance in global mode is comparable to that of other segmentation models, while local mode is unique to SAM-OCTA. You can configure the project in the **options.py** file; below are explanations for each option: 95 | 96 | * -device: Specifies the IDs of the available GPUs. Multiple GPUs are supported, but due to Meta's SAM implementation, the batch_size should equal the number of GPUs used. Moreover, because the dataloader must align the samples within a batch, it is preferable to train with **batch_size=1** to avoid errors caused by prompt-point lists of different lengths. 97 | * -epochs: Specifies the number of training epochs. 98 | * -lr: The maximum learning rate, considering the warm-up strategy. 99 | * -check_interval: Specifies how many training epochs to wait between saving results (including weights). 100 | * -k_fold: Specifies the number of folds for k-fold cross-validation. 101 | * -prompt_positive_num: Number of positive prompt points, -1 for random. 102 | * -prompt_total_num: Total number of prompt points, -1 for random. 103 | * -model_type: Selects the SAM model for fine-tuning: "vit_h", "vit_l", and "vit_b". 104 | * -is_local: Specifies whether to run in local mode. 105 | * -remark: A remark of your choice, which will be appended to the generated result folder name. 106 | 107 | The following are some configurations specific to the OCTA-500 dataset: 108 | 109 | * -fov: Selects the sub-dataset corresponding to the field of view ("3M" or "6M"). 110 | * -label_type: Selects the annotation type (segmentation task type): "LargeVessel", "FAZ", "Capillary", "Artery", and "Vein". 111 | * -metrics: Selects the metrics to be computed (can select multiple): "Dice", "Jaccard", "Hausdorff".
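For example, to fine-tune the smaller "vit_b" model on the FAZ task of the 6M sub-dataset in local mode, you would only change the corresponding defaults in **options.py** (the values below are illustrative, not the project defaults) and then run `python train_sam_octa.py` as before:

    parser.add_argument("-model_type", type=str, default="vit_b")
    parser.add_argument("-is_local", type=bool, default=True)
    ...
    parser.add_argument("-fov", type=str, default="6M")
    parser.add_argument("-label_type", type=str, default="FAZ")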
112 | 113 | ## 3.Others 114 | 115 | If you find the information useful, please cite the relevant paper: https://arxiv.org/abs/2309.11758 116 | 117 | **Pretrained Weights (Baidu Cloud Storage)**: 118 | 119 | Link:https://pan.baidu.com/s/1S43QadZlhT8dL8TPbA0N6g?pwd=sifh 120 | 121 | Password:sifh 122 | 123 | ## 4.Instance Prediction (Supplement) 124 | Here, I provide additional code for vessel prediction, along with explanations through text and images. 125 | 126 | 1. Prepare an Image for Prediction. Start by preparing an image that you want to predict. In the provided code, I process the image by stacking its three channels and then duplicating it side-by-side. The duplicated version is used for manual annotation of prompt points (pure green for positive points, pure red for negative points). It looks something like this: 127 | 128 | ![Sample](./figures/sample_3ch.png) 129 | 130 | 131 | 2. Load the Pretrained Weights. Since this is just an example, I use a fine-tuned ViT-L model, which requires less memory and computation time. The provided weights combine both global and local prediction modes. You can download the weights from the following link: 132 | 133 | https://pan.baidu.com/s/1iCVmPaLOWVk36YbgcQ4AOg?pwd=i54c password: i54c 134 | 135 | Then, run the script predict.py, and the results will be saved in an automatically generated folder named prediction. 136 | 137 | 3. In global mode, no prompt points are needed. I automatically added a fixed negative point at [-100, -100] in the code. Let's take a look at the segmentation results: 138 | 139 | ![Sample](./figures/pred_rv_global.png) 140 | 141 | 4. In local mode, prompt points are provided on the vessels, for example: 142 | 143 | ![Sample](./figures/sample_3ch_prompt.png) 144 | ![Sample](./figures/sample_3ch_prompt2.png) 145 | 146 | The results for the provided prompt points are as follows: 147 | 148 | ![Sample](./figures/pred_rv_local.png) 149 | ![Sample](./figures/pred_rv_local2.png) 150 | 151 | 152 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | # SAM-OCTA 2 | 3 | 4 | ## 1.快速上手 5 | 6 | 这是一个使用 LoRA 对 SAM 进行微调,并在 OCTA 图像上执行分割任务的项目, 使用 **PyTorch** 构建。 7 | 8 | 9 | 首先,您应该将一个预训练的权重文件放入 **sam_weights** 文件夹中。预训练权重的下载链接如下: 10 | 11 | vit_h (default): https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 12 | 13 | vit_l: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth 14 | 15 | vit_b: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth 16 | 17 | 经过测试,三种模型所需要的显存分别为:**36,248、 26,154、 13,467** MB。其中vit_h是默认使用的,如果您需要使用其他更小的模型,请下载对应权重,并修改对应配置项。 18 | 19 | ... 20 | parser.add_argument("-model_type", type=str, default="vit_h") 21 | ... 22 | 23 | 使用 **train_sam_octa.py** 来开始进行微调。警告信息将指导您应该去安装哪些包。这些包都是常用的python库,不需要额外的配置。 24 | 25 | python train_sam_octa.py 26 | 27 | 数据集应该按照 OCTA-500 的形式组织,就像这样: 28 | 29 | /datasets 30 | /OCTA-500 31 | /OCTA_3M 32 | /GT_Artery 33 | 10301.bmp 34 | 10302.bmp 35 | ... 36 | /GT_Capillary 37 | 10301.bmp 38 | 10302.bmp 39 | ... 40 | /GT_FAZ 41 | ... 42 | /ProjectionMaps 43 | /OCTA(FULL) 44 | 10301.bmp 45 | 10302.bmp 46 | ... 47 | /OCTA(ILM_OPL) 48 | 10301.bmp 49 | 10302.bmp 50 | ... 51 | /OCTA(OPL_BM) 52 | 10301.bmp 53 | 10302.bmp 54 | ... 55 | /OCTA_6M 56 | ... 
57 | 58 | 这里我使用了 **OCTA_500** 中 FoV (视场角) 为 3M 的id为10301样本作为一个示例,如果需要完整的数据集,需要联系 **OCTA_500** 数据集的作者。 59 | 60 | **OCTA-500**'s related paper: https://arxiv.org/abs/2012.07261 61 | 62 | 示例结果和分割指标将被记录在 **results** 文件夹中(如果不存在,则这个文件夹将被自动创建)。 63 | 64 | 如果您需要对预测结果进行可视化,请使用 **display.py** 文件。由于结果文件夹是按时间生成的,需要对这一行代码进行替换。生成的图像存放在 **sample_display** 文件夹中。 65 | 66 | ... 67 | test_dir = "results/2024-01-01-08-17-09/3M_LargeVessel_100_True/0/0000" # Your result dir 68 | ... 69 | 70 | 这是一些带有提示点的分割的示例,从左到右分别是输入图像、标注以及预测结果。 71 | 72 | **局部模式** 73 | 74 | *动脉* 75 | 76 | ![Sample](./figures/sample_artery.gif) 77 | 78 | *中心无血管区* 79 | 80 | ![Sample](./figures/sample_FAZ.gif) 81 | 82 | **全局模式** 83 | 84 | *视网膜血管* 85 | 86 | ![Sample](./figures/sample_RV.gif) 87 | 88 | *毛细血管* 89 | 90 | ![Sample](./figures/sample_capillary.gif) 91 | 92 | 93 | ## 2.相关配置 94 | 95 | 该项目能够支持多个分割任务,并且分为**全局**和**局部**两种模式。事实上,全局模式下的性能与其他分割模型相差无几,局部模式则是SAM-OCTA所独有。在 **options.py** 文件中,可以对其进行配置,以下是各个选项的说明: 96 | 97 | * -device:指定可用显卡的id,可以支持多张显卡,但是由于Meta代码实现的原因,batch_size应该和所使用的显卡数量一致。然而又因为dataloader需要把不同样本对齐,所以最好是 **batch_size=1** 地进行训练,以避免提示点的长度对不齐所造成的报错。 98 | * -epochs: 训练多少轮。 99 | * -lr: 由于采用了warm-up策略,这里指的是最大学习率。 100 | * -check_interval: 间隔多少轮训练后保存一次结果(包括权重)。 101 | * -k_fold:k折交叉验证。 102 | * -prompt_positive_num:正提示点数量,-1为随机。 103 | * -prompt_total_num: 总提示点数量,-1为随机。 104 | * -model_type: 选择SAM的训练模型:"vit_h", "vit_l" 以及 "vit_b"。 105 | * -is_local: 是否为局部模式。 106 | * -remark: 一些你需要填写的备注信息,会添加到生成的结果文件夹名中。 107 | 108 | 以下是针对 **OCTA-500** 数据集的一些配置: 109 | * -fov: 选择视场角对应的子数据集。 110 | * -label_type: 选择标注类型(分割任务类型):"LargeVessel", "FAZ", "Capillary", "Artery" 以及 "Vein"。 111 | * -metrics: 选择需要统计的指标(可多选): "Dice", "Jaccard", "Hausdorff" 112 | 113 | 114 | ## 3.其他 115 | 116 | 如果觉得有用请引用相关论文: https://arxiv.org/abs/2309.11758 117 | 118 | **预训练权重(百度网盘)**: 119 | 120 | 链接:https://pan.baidu.com/s/1S43QadZlhT8dL8TPbA0N6g?pwd=sifh 121 | 122 | 提取码:sifh 123 | 124 | ## 4.实例预测(补充) 125 | 126 | 这里我添加一份关于血管的预测代码,并结合图文给出相关说明。 127 | 128 | 1. 首先准备一张你要预测的图片,我在代码中的处理是将三个通道叠加后,再并排复制一份。复制的一份是用来方便手动标记提示点的(纯绿色正点,纯红色负点),总之就像这样。 129 | 130 | ![Sample](./figures/sample_3ch.png) 131 | 132 | 2. 然后加载这个权重,由于只是示例,所以我直接使用了vit-l来微调,花费时间短且内存占用少些。并且这个权重结合了全局和局部两种模式。权重下载链接为: 133 | 134 | 135 | https://pan.baidu.com/s/1iCVmPaLOWVk36YbgcQ4AOg?pwd=i54c 提取码: i54c 136 | 137 | 138 | 然后运行 __predict.py__,结果保存在一个自动生成的 __predition__ 的文件夹中。 139 | 140 | 3. 全局模式下,无需提供提示点,我在代码中自动加了一个[-100, -100]的固定负点,让我们看看分割效果: 141 | 142 | ![Sample](./figures/pred_rv_global.png) 143 | 144 | 4. 
局部模式下,在血管上给出提示点,例如: 145 | 146 | ![Sample](./figures/sample_3ch_prompt.png) 147 | ![Sample](./figures/sample_3ch_prompt2.png) 148 | 149 | 对应效果如图: 150 | 151 | ![Sample](./figures/pred_rv_local.png) 152 | ![Sample](./figures/pred_rv_local2.png) 153 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from prompt_points import label_to_point_prompt_local, label_to_point_prompt_global 2 | from torch.utils.data import Dataset 3 | import pandas as pd 4 | import cv2 5 | import os 6 | import random 7 | import numpy as np 8 | from collections import * 9 | 10 | from display import show_result_sample_figure 11 | import albumentations as alb 12 | from tqdm import tqdm 13 | 14 | 15 | class octa500_2d_dataset(Dataset): 16 | def __init__(self, 17 | fov="3M", 18 | label_type="LargeVessel", 19 | prompt_positive_num=-1, 20 | prompt_negative_num=-1, 21 | is_local=True, 22 | is_training=True): 23 | 24 | self.prompt_positive_num = prompt_positive_num 25 | self.prompt_negative_num = prompt_negative_num 26 | self.is_local = is_local 27 | self.is_training = is_training 28 | 29 | layers = ["OPL_BM", "ILM_OPL", "FULL"] 30 | data_dir = "datasets/OCTA-500" 31 | modal = "OCTA" 32 | label_dir = "{}/OCTA_{}/GT_{}".format(data_dir, fov, label_type) 33 | self.sample_ids = [x[:-4] for x in sorted(os.listdir(label_dir))] 34 | images = [] 35 | for sample_id in self.sample_ids: 36 | image_channels = [] 37 | for layer in layers: 38 | image_path = "{}/OCTA_{}/ProjectionMaps/{}({})/{}.bmp".format(data_dir, fov, modal, layer, sample_id) 39 | image_channels.append(cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)) 40 | images.append(np.array(image_channels)) 41 | self.images = images 42 | 43 | load_label = lambda sample_id: cv2.imread("{}/{}.bmp".format(label_dir, sample_id), cv2.IMREAD_GRAYSCALE) / 255 44 | self.labels = [load_label(x) for x in self.sample_ids] 45 | 46 | prob = 0.3 47 | self.transform = alb.Compose([ 48 | alb.RandomBrightnessContrast(p=prob), 49 | alb.CLAHE(p=prob), 50 | # alb.SafeRotate(limit=15, p=prob), 51 | alb.VerticalFlip(p=prob), 52 | alb.HorizontalFlip(p=prob), 53 | alb.AdvancedBlur(p=prob), 54 | ]) 55 | 56 | def __len__(self): 57 | return len(self.images) 58 | 59 | def __getitem__(self, index): 60 | image, prompt_points, prompt_type, selected_component = self.get_sam_item(self.images[index], self.labels[index]) 61 | return image, prompt_points, prompt_type, selected_component, self.sample_ids[index] 62 | 63 | def get_sam_item(self, image, label): 64 | if self.is_training: 65 | transformed = self.transform(**{"image": image.transpose((1,2,0)), "mask": label[np.newaxis,:].transpose((1,2,0))}) 66 | image, label = transformed["image"].transpose((2,0,1)), transformed["mask"].transpose((2,0,1))[0] 67 | ppn, pnn = self.prompt_positive_num, self.prompt_negative_num 68 | if self.is_local: 69 | random_max = 4 70 | if ppn == -1: ppn = random.randint(0, random_max) 71 | if pnn == -1: pnn = random.randint(int(ppn == 0), random_max) 72 | selected_component, prompt_points_pos, prompt_points_neg = label_to_point_prompt_local(label, ppn, pnn) 73 | else: 74 | selected_component, prompt_points_pos, prompt_points_neg = label_to_point_prompt_global(label, ppn, pnn) 75 | 76 | prompt_type = np.array([1] * len(prompt_points_pos) + [0] * len(prompt_points_neg)) 77 | prompt_points = np.array(prompt_points_pos + prompt_points_neg) 78 | 79 | return image, prompt_points, prompt_type, selected_component 80 | 81 | 
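# Because get_sam_item can return a different number of prompt points per sample, the
# default collate function may fail when batching; this mirrors the README's advice to
# train with batch_size=1. A minimal, illustrative way to wrap this dataset:
#
# from torch.utils.data import DataLoader
# loader = DataLoader(octa500_2d_dataset(is_local=True), batch_size=1, shuffle=True)
# for image, prompt_points, prompt_type, selected_component, sample_id in loader:
#     pass  # feed the batch to the LoRA-fine-tuned SAM model here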
# if __name__=="__main__": 82 | # dataset = octa500_2d_dataset(is_local=True, prompt_positive_num=1, is_training=True) 83 | # for image, prompt_points, prompt_type, selected_component, sample_id in dataset: 84 | # print(np.max(selected_component)) -------------------------------------------------------------------------------- /datasets/OCTA-500/OCTA_3M/GT_Artery/10301.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/datasets/OCTA-500/OCTA_3M/GT_Artery/10301.bmp -------------------------------------------------------------------------------- /datasets/OCTA-500/OCTA_3M/GT_Capillary/10301.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/datasets/OCTA-500/OCTA_3M/GT_Capillary/10301.bmp -------------------------------------------------------------------------------- /datasets/OCTA-500/OCTA_3M/GT_FAZ/10301.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/datasets/OCTA-500/OCTA_3M/GT_FAZ/10301.bmp -------------------------------------------------------------------------------- /datasets/OCTA-500/OCTA_3M/GT_LargeVessel/10301.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/datasets/OCTA-500/OCTA_3M/GT_LargeVessel/10301.bmp -------------------------------------------------------------------------------- /datasets/OCTA-500/OCTA_3M/GT_Vein/10301.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/datasets/OCTA-500/OCTA_3M/GT_Vein/10301.bmp -------------------------------------------------------------------------------- /datasets/OCTA-500/OCTA_3M/ProjectionMaps/OCTA(FULL)/10301.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/datasets/OCTA-500/OCTA_3M/ProjectionMaps/OCTA(FULL)/10301.bmp -------------------------------------------------------------------------------- /datasets/OCTA-500/OCTA_3M/ProjectionMaps/OCTA(ILM_OPL)/10301.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/datasets/OCTA-500/OCTA_3M/ProjectionMaps/OCTA(ILM_OPL)/10301.bmp -------------------------------------------------------------------------------- /datasets/OCTA-500/OCTA_3M/ProjectionMaps/OCTA(OPL_BM)/10301.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/datasets/OCTA-500/OCTA_3M/ProjectionMaps/OCTA(OPL_BM)/10301.bmp -------------------------------------------------------------------------------- /display.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import cv2 4 | import os 5 | from tqdm import tqdm 6 | 7 | alpha = 0.5 8 | 9 | overlay = lambda x, y: cv2.addWeighted(x, alpha, y, 1-alpha, 0) 10 | 
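# Note: OpenCV images use BGR channel order, so the color helpers below place the grayscale
# values in channel 0 for blue, channel 1 for green, and channel 2 for red.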
11 | # 灰度图像->单通道图像, Grayscale image -> single-channel image 12 | to_blue = lambda x: np.array([x, np.zeros_like(x), np.zeros_like(x)]).transpose((1,2,0)).astype(dtype=np.uint8) 13 | to_red = lambda x: np.array([np.zeros_like(x), np.zeros_like(x), x]).transpose((1,2,0)).astype(dtype=np.uint8) 14 | to_green = lambda x: np.array([np.zeros_like(x), x, np.zeros_like(x)]).transpose((1,2,0)).astype(dtype=np.uint8) 15 | to_light_green = lambda x: np.array([np.zeros_like(x), x / 2, np.zeros_like(x)]).transpose((1,2,0)).astype(dtype=np.uint8) 16 | to_yellow = lambda x: np.array([np.zeros_like(x), x, x]).transpose((1,2,0)).astype(dtype=np.uint8) 17 | 18 | to_3ch = lambda x: np.array([x,x,x]).transpose((1,2,0)).astype(dtype=np.uint8) 19 | 20 | def show_result_sample_figure(image, label, pred, prompt_points): 21 | sz = image.shape[-1] // 100 22 | cvt_img = lambda x: x.astype(np.uint8) 23 | image, label, pred = map(cvt_img, (image, label, pred)) 24 | if len(image.shape) == 2: image = to_3ch(image) 25 | else: image = image.transpose((1, 2, 0)) 26 | label, pred = cv2.resize(label, image.shape[:2]), cv2.resize(pred, image.shape[:2]) 27 | label_img = overlay(image, to_light_green(label)) 28 | pred_img = overlay(image, to_yellow(pred)) 29 | def draw_points(img): 30 | for x, y, type in prompt_points: 31 | cv2.circle(img, (x, y), int(1.5 * sz), (255, 0, 0), -1) 32 | if type: cv2.circle(img, (x, y), sz, (0, 255, 0), -1) 33 | else: cv2.circle(img, (x, y), sz, (0, 0, 255), -1) 34 | draw_points(label_img) 35 | draw_points(pred_img) 36 | return np.concatenate((image, label_img, pred_img), axis=1) 37 | 38 | def show_prompt_points_image(image, positive_region, negative_region, positive_points, negative_points, save_file=None): 39 | overlay_img = overlay(to_red(negative_region), to_yellow(positive_region)) 40 | overlay_img = overlay(to_3ch(image), overlay_img) 41 | 42 | for x, y in positive_points: cv2.circle(overlay_img, (x, y), 4, (0, 255, 0), -1) 43 | for x, y in negative_points: cv2.circle(overlay_img, (x, y), 4, (0, 0, 255), -1) 44 | 45 | if save_file: cv2.imwrite(save_file, overlay_img) 46 | 47 | return overlay_img 48 | 49 | def view_result_samples(result_dir): 50 | 51 | save_dir = "sample_display/{}".format(result_dir[len("results/"):]) 52 | if not os.path.exists(save_dir): os.makedirs(save_dir) 53 | file_names = [x[-9:-4] for x in os.listdir(result_dir) if "label" in x] 54 | data_name = [x[:-16] for x in os.listdir(result_dir) if "label" in x][0] 55 | for file_name in tqdm(file_names): 56 | label = np.load("{}/{}_label_{}.npy".format(result_dir, data_name, file_name)) 57 | pred = np.load("{}/{}_pred_{}.npy".format(result_dir, data_name, file_name)) 58 | prompt_info = np.load("{}/{}_prompt_info_{}.npy".format(result_dir, data_name, file_name)) 59 | image = np.load("{}/{}_sample_{}.npy".format(result_dir, data_name, file_name)) 60 | 61 | result = show_result_sample_figure(image* 255, label * 255, pred * 255, prompt_info) 62 | cv2.imwrite("{}/{}.png".format(save_dir, file_name), result) 63 | 64 | 65 | if __name__=="__main__": 66 | result_dir = "results/2024-01-01-10-46-22/3M_LargeVessel_100_True/0/0020" 67 | view_result_samples(result_dir) -------------------------------------------------------------------------------- /figures/pred_rv_global.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/figures/pred_rv_global.png 
-------------------------------------------------------------------------------- /figures/pred_rv_local.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/figures/pred_rv_local.png -------------------------------------------------------------------------------- /figures/pred_rv_local2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/figures/pred_rv_local2.png -------------------------------------------------------------------------------- /figures/sample_3ch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/figures/sample_3ch.png -------------------------------------------------------------------------------- /figures/sample_3ch_prompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/figures/sample_3ch_prompt.png -------------------------------------------------------------------------------- /figures/sample_3ch_prompt2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/figures/sample_3ch_prompt2.png -------------------------------------------------------------------------------- /figures/sample_FAZ.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/figures/sample_FAZ.gif -------------------------------------------------------------------------------- /figures/sample_RV.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/figures/sample_RV.gif -------------------------------------------------------------------------------- /figures/sample_artery.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/figures/sample_artery.gif -------------------------------------------------------------------------------- /figures/sample_capillary.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/figures/sample_capillary.gif -------------------------------------------------------------------------------- /loss_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | class DiceLoss(torch.nn.Module): 5 | def __init__(self, smooth=1.): 6 | super(DiceLoss, self).__init__() 7 | self.smooth = smooth 8 | 9 | def forward(self, pred, target): 10 | intersection = (pred * target).sum() 11 | denominator = pred.sum() + target.sum() 12 | dice_score = (2. * intersection + self.smooth) / (denominator + self.smooth) 13 | dice_loss = 1. 
- dice_score 14 | return dice_loss 15 | 16 | class clDiceLoss(torch.nn.Module): 17 | def __init__(self, smooth=1.): 18 | super(clDiceLoss, self).__init__() 19 | self.smooth = smooth 20 | 21 | def soft_cldice_loss(self, pred, target, target_skeleton=None): 22 | ''' 23 | inputs shape (batch, channel, height, width). 24 | calculate clDice loss 25 | Because pred and target at moment of loss calculation will be a torch tensors 26 | it is preferable to calculate target_skeleton on the step of batch forming, 27 | when it will be in numpy array format by means of opencv 28 | ''' 29 | cl_pred = self.soft_skeletonize(pred) 30 | if target_skeleton is None: 31 | target_skeleton = self.soft_skeletonize(target) 32 | iflat = self.norm_intersection(cl_pred, target) 33 | tflat = self.norm_intersection(target_skeleton, pred) 34 | intersection = (iflat * tflat).sum() 35 | return 1. - (2. * intersection) / (iflat + tflat).sum() 36 | 37 | def dice_loss(self, pred, target): 38 | ''' 39 | inputs shape (batch, channel, height, width). 40 | calculate dice loss per batch and channel of sample. 41 | E.g. if batch shape is [64, 1, 128, 128] -> [64, 1] 42 | ''' 43 | intersection = (pred * target).sum() 44 | denominator = pred.sum() + target.sum() 45 | dice_score = (2. * intersection + self.smooth) / (denominator + self.smooth) 46 | dice_loss = 1. - dice_score 47 | return dice_loss 48 | 49 | def soft_skeletonize(self, x, thresh_width=10): 50 | ''' 51 | Differenciable aproximation of morphological skelitonization operaton 52 | thresh_width - maximal expected width of vessel 53 | ''' 54 | for i in range(thresh_width): 55 | min_pool_x = torch.nn.functional.max_pool2d(x * -1, (3, 3), 1, 1) * -1 56 | contour = torch.nn.functional.relu(torch.nn.functional.max_pool2d(min_pool_x, (3, 3), 1, 1) - min_pool_x) 57 | x = torch.nn.functional.relu(x - contour) 58 | return x 59 | 60 | def norm_intersection(self, center_line, vessel): 61 | ''' 62 | inputs shape (batch, channel, height, width) 63 | intersection formalized by first ares 64 | x - suppose to be centerline of vessel (pred or gt) and y - is vessel (pred or gt) 65 | ''' 66 | smooth = 1. 
67 | clf = center_line.view(*center_line.shape[:2], -1) 68 | vf = vessel.view(*vessel.shape[:2], -1) 69 | intersection = (clf * vf).sum(-1) 70 | return (intersection + smooth) / (clf.sum(-1) + smooth) 71 | 72 | def forward(self, pred, target): 73 | return 0.8 * self.dice_loss(pred, target) + 0.2 * self.soft_cldice_loss(pred, target) 74 | 75 | class FocalLoss(torch.nn.Module): 76 | def __init__(self, alpha=1, gamma=2, reduction='mean'): 77 | super(FocalLoss, self).__init__() 78 | self.alpha = alpha 79 | self.gamma = gamma 80 | self.reduction = reduction 81 | 82 | def forward(self, inputs, targets): 83 | ce_loss = F.cross_entropy(inputs, targets, reduction='none') 84 | pt = torch.exp(-ce_loss) 85 | focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss 86 | 87 | if self.reduction == 'mean': 88 | return focal_loss.mean() 89 | elif self.reduction == 'sum': 90 | return focal_loss.sum() 91 | else: 92 | return focal_loss 93 | 94 | class MCELoss(torch.nn.Module): 95 | def __init__(self, reduction='mean'): 96 | super(MCELoss, self).__init__() 97 | self.reduction = reduction 98 | 99 | def forward(self, inputs, targets): 100 | predicted_classes = inputs.argmax(dim=1) 101 | incorrect_predictions = predicted_classes != targets 102 | mce_loss = incorrect_predictions.float().mean() 103 | 104 | if self.reduction == 'mean': 105 | return mce_loss 106 | elif self.reduction == 'sum': 107 | return mce_loss * inputs.size(0) 108 | else: 109 | return mce_loss -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial.distance import directed_hausdorff 2 | from collections import * 3 | import pandas as pd 4 | from statistics import mean 5 | from torch.utils.tensorboard import SummaryWriter 6 | 7 | class MetricsStatistics: 8 | def __init__(self, save_dir="./results/"): 9 | self.epsilon = 1e-6 10 | self.func_dct = { 11 | "Precision": self.cal_precision, 12 | "Recall": self.cal_recall, 13 | "Specificity": self.cal_specificity, 14 | "Jaccard": self.cal_jaccard_index, 15 | "Dice": self.cal_dice, 16 | "Hausdorff": self.cal_hausdorff 17 | } 18 | self.save_dir = save_dir 19 | self.metric_values = defaultdict(list) # check epoch 临时用 20 | self.metric_epochs = defaultdict(list) # 保存了指定epoch的各样本平均值 21 | self.summary_writer = SummaryWriter(log_dir=save_dir) 22 | 23 | def cal_epoch_metric(self, metrics, label_type, label, pred): # 计算并保存样本指标 24 | for x in metrics:self.metric_values["{}-{}".format(x, label_type)].append(self.func_dct[x](label, pred)) 25 | 26 | def record_result(self, epoch): 27 | self.metric_epochs["epoch"].append(epoch) 28 | for k, v in self.metric_values.items(): 29 | self.summary_writer.add_scalar(k, mean(v), epoch) 30 | self.metric_epochs[k].append(mean(v)) 31 | pd.DataFrame(self.metric_epochs).to_excel("{}/metrics_statistics.xlsx".format(self.save_dir), index=False) 32 | self.metric_values.clear() 33 | 34 | def close(self): 35 | self.summary_writer.close() 36 | 37 | def cal_confusion_matrix(self, pred, label): 38 | TP = ((pred == 1) & (label == 1)).sum().item() 39 | FP = ((pred == 0) & (label == 1)).sum().item() 40 | FN = ((pred == 1) & (label == 0)).sum().item() 41 | TN = ((pred == 0) & (label == 0)).sum().item() 42 | return TP, FP, FN, TN 43 | 44 | def cal_precision(self, pred, label): 45 | TP, FP, FN, TN = self.cal_confusion_matrix(pred, label) 46 | return TP / (TP + FP + self.epsilon) 47 | 48 | def cal_recall(self, pred, label): 49 | TP, FP, FN, TN = 
self.cal_confusion_matrix(pred, label) 50 | return TP / (TP + FN + self.epsilon) 51 | 52 | def cal_specificity(self, pred, label): 53 | TP, FP, FN, TN = self.cal_confusion_matrix(pred, label) 54 | return TN / (TN + FP + self.epsilon) 55 | 56 | def cal_jaccard_index(self, pred, label): 57 | intersection = (pred & label).sum().item() 58 | union = (pred | label).sum().item() 59 | jaccard_index = intersection / (union + self.epsilon) 60 | return jaccard_index 61 | 62 | def cal_dice(self, pred, label): 63 | intersection = (pred & label).sum().item() 64 | union = pred.sum().item() + label.sum().item() 65 | dice = 2 * intersection / (union + self.epsilon) 66 | return dice 67 | 68 | def cal_hausdorff(self, pred, label): 69 | array1 = pred.cpu().numpy() 70 | array2 = label.cpu().numpy() 71 | dist1 = directed_hausdorff(array1, array2)[0] 72 | dist2 = directed_hausdorff(array2, array1)[0] 73 | hausdorff_dist = max(dist1, dist2) 74 | return hausdorff_dist -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser(description='training argument values') 4 | 5 | def add_training_parser(parser): 6 | parser.add_argument("-device", type=str, default="0") 7 | parser.add_argument("-epochs", type=int, default=100) 8 | parser.add_argument("-lr", type=float, default=1e-4) 9 | parser.add_argument("-check_interval", type=int, default=10) 10 | parser.add_argument("-k_fold", type=str, default=10) 11 | parser.add_argument("-prompt_positive_num", type=int, default=1) 12 | parser.add_argument("-prompt_negative_num", type=int, default=1) 13 | parser.add_argument("-model_type", type=str, default="vit_h") 14 | parser.add_argument("-is_local", type=bool, default=False) 15 | parser.add_argument("-remark", type=str, default="OCTA-500") 16 | 17 | def add_cell_parser(parser): 18 | parser.add_argument("-metrics", type=list, default=["Dice", "Jaccard", "Hausdorff"]) 19 | 20 | def add_octa500_2d_parser(parser): 21 | parser.add_argument("-fov", type=str, default="3M") 22 | parser.add_argument("-label_type", type=str, default="LargeVessel") #"LargeVessel", "FAZ", "Capillary", "Artery", "Vein" 23 | parser.add_argument("-metrics", type=list, default=["Dice", "Jaccard", "Hausdorff"]) -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from scipy.ndimage import label, center_of_mass 4 | 5 | # system 6 | import os, random, time, GPUtil 7 | from tqdm import tqdm 8 | from collections import * 9 | 10 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 11 | 12 | # torch 13 | import torch 14 | import torch.optim as optim 15 | from torch.nn import DataParallel 16 | 17 | # SAM 18 | from segment_anything import * 19 | from sam_lora_image_encoder import LoRA_Sam 20 | from segment_anything.utils.transforms import ResizeLongestSide 21 | 22 | class DisplayManager: 23 | def __init__(self, save_dir="display"): 24 | alpha = 0.5 25 | self.overlay = lambda x, y: cv2.addWeighted(x, alpha, y, 1-alpha, 0) 26 | self.to_3ch = lambda x: np.array([x,x,x]).transpose((1,2,0)).astype(dtype=np.uint8) 27 | self.to_color = lambda x, color: (self.to_3ch(x) * color).astype(dtype=np.uint8) 28 | self.to_visible = lambda x : (x * 255 if x.max() <= 1 else x).astype(np.uint8) 29 | 30 | self.point_color_dct = {True:(0, 255, 0), False:(0, 0, 255)} 
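# point_color_dct above maps the prompt-point type to a BGR color: True (positive point) -> green, False (negative point) -> red.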
31 | self.point_size = 10 32 | self.save_dir = save_dir 33 | os.makedirs(self.save_dir, exist_ok=True) 34 | 35 | def display_prompt(self, image, mask, prompts, save_name="temp"): 36 | image = self.overlay(image, self.to_color(mask, (0, 1, 1))) 37 | 38 | for x, y, z in prompts: 39 | cv2.circle(image, (int(x), int(y)), self.point_size, self.point_color_dct[z], -1) 40 | cv2.imwrite("{}/{}.png".format(self.save_dir, save_name), image) 41 | 42 | def display_predict(self, image, label, pred, prompt_points, prompt_type, save_name="temp"): 43 | to_numpy = lambda x : x.numpy() 44 | image, label, pred, prompt_points, prompt_type = map(to_numpy, [image, label, pred, prompt_points, prompt_type]) 45 | image, label, pred = map(self.to_visible, [image, label, pred]) 46 | 47 | image = self.to_3ch(image) 48 | 49 | image_pred = self.overlay(image, self.to_color(pred, (0, 1, 1))) 50 | image_pred_prompt = image_pred.copy() 51 | 52 | for (x, y), z in zip(prompt_points, prompt_type): 53 | cv2.circle(image_pred_prompt, (int(x), int(y)), self.point_size, self.point_color_dct[z], -1) 54 | 55 | merged_image = np.concatenate([image, image_pred_prompt], axis=1) 56 | 57 | cv2.imwrite("{}/{}.png".format(self.save_dir, save_name), merged_image) 58 | 59 | 60 | class PredictManager_OCTA: 61 | def __init__(self, weight_path, save_dir="predition", model_type="vit_l"): 62 | self.device_ids = "0" 63 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 64 | self.save_dir = save_dir 65 | self.to_cuda = lambda x: x.to(torch.float).to(self.device) 66 | 67 | if model_type == "vit_h": sam = sam_model_registry["vit_h"](checkpoint="sam_weights/sam_vit_h_4b8939.pth") 68 | elif model_type == "vit_l": sam = sam_model_registry["vit_l"](checkpoint="sam_weights/sam_vit_l_0b3195.pth") 69 | else: sam = sam_model_registry["vit_b"](checkpoint="sam_weights/sam_vit_b_01ec64.pth") 70 | 71 | self.sam_transform = ResizeLongestSide(224) if model_type == "vit_b" else ResizeLongestSide(1024) 72 | 73 | rank = 4 74 | lora_sam = LoRA_Sam(sam, rank).cuda() 75 | 76 | 77 | 78 | self.model = DataParallel(lora_sam).to(self.device) 79 | 80 | self.model.load_state_dict(torch.load(weight_path)) 81 | 82 | 83 | 84 | def predict(self, image, save_name): 85 | dm = DisplayManager(save_dir=self.save_dir) 86 | 87 | # process image: 88 | w = image.shape[1] 89 | image, prompt_image = image[:, :w//2], image[:, w//2:] 90 | 91 | ppn, pnn = self.get_red_and_green_points(prompt_image) 92 | 93 | prompt_points = torch.tensor(np.array([ppn + pnn])) 94 | prompt_type = torch.tensor(np.array([[1] * len(ppn) + [0] * len(pnn)])) 95 | 96 | images = torch.tensor(np.array([image.transpose((2,0,1))])) 97 | 98 | with torch.no_grad(): 99 | images, prompt_type = map(self.to_cuda, (images, prompt_type)) 100 | images, original_size, prompt_points = self.make_prompts(images, prompt_points) 101 | 102 | preds = self.model(images, original_size, prompt_points, prompt_type) 103 | 104 | preds = torch.gt(preds, 0.8).int() 105 | 106 | image = images[0][1].cpu().detach() 107 | pred = preds[0][0].cpu().detach() 108 | 109 | prompt_points, prompt_types = prompt_points[0].cpu().detach().int(), prompt_type[0].cpu().detach().int() 110 | 111 | dm.display_predict(image, torch.zeros_like(pred), pred, prompt_points, prompt_types, save_name) 112 | 113 | def get_red_and_green_points(self, image): 114 | 115 | diff_map = np.sum(np.abs(image - (0,255,0)), axis=2) 116 | green_mask = np.where(diff_map < 10, 1, 0).astype(np.uint8) 117 | 118 | diff_map = np.sum(np.abs(image - (0,0,255)), axis=2) 
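# (0, 0, 255) is pure red in OpenCV's BGR order; pixels close to it are collected as negative prompt points.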
119 | red_mask = np.where(diff_map < 10, 1, 0).astype(np.uint8) 120 | 121 | red_coords, green_coords = [], [] 122 | 123 | labeled_array, num_features = label(green_mask, structure=np.ones((3,3))) 124 | 125 | for i in range(1, num_features + 1): 126 | center = center_of_mass(green_mask, labeled_array, index=i) 127 | green_coords.append([int(center[1]), int(center[0])]) 128 | 129 | labeled_array, num_features = label(red_mask, structure=np.ones((3,3))) 130 | 131 | for i in range(1, num_features + 1): 132 | center = center_of_mass(red_mask, labeled_array, index=i) 133 | red_coords.append([int(center[1]), int(center[0])]) 134 | 135 | if len(green_coords) + len(red_coords) == 0: red_coords = [[-100, -100]] 136 | 137 | return green_coords, red_coords 138 | 139 | 140 | def make_prompts(self, images, prompt_points): 141 | original_size = tuple(images.shape[-2:]) 142 | images = self.sam_transform.apply_image_torch(images) 143 | prompt_points = self.sam_transform.apply_coords_torch(prompt_points, original_size) 144 | 145 | return images, original_size, prompt_points 146 | 147 | if __name__=="__main__": 148 | weight_path = "vit_l_rv.pth" 149 | sample_id = 10453 150 | image_path = "{}.png".format(sample_id) 151 | image = cv2.imread(image_path, cv2.IMREAD_COLOR) 152 | pm = PredictManager_OCTA(weight_path) 153 | pm.predict(image, str(sample_id)) 154 | -------------------------------------------------------------------------------- /prompt_points.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import multivariate_normal 3 | from scipy import ndimage 4 | import cv2 5 | from collections import * 6 | import random 7 | from itertools import * 8 | from functools import * 9 | from display import show_prompt_points_image 10 | from tqdm import tqdm 11 | 12 | random_seed = 0 13 | 14 | if random_seed: 15 | random.seed(random_seed) 16 | np.random.seed(random_seed) 17 | 18 | # 将二维的坐标点转换为高斯热图, Converting 2D coordinate points to Gaussian heat maps 19 | def points_to_gaussian_heatmap(centers, height, width, scale): 20 | gaussians = [] 21 | for y, x in centers: 22 | s = np.eye(2) * scale 23 | g = multivariate_normal(mean=(x, y), cov=s) 24 | gaussians.append(g) 25 | x, y = np.arange(0, width), np.arange(0, height) 26 | xx, yy = np.meshgrid(x, y) 27 | xxyy = np.stack([xx.ravel(), yy.ravel()]).T 28 | zz = sum(g.pdf(xxyy) for g in gaussians) 29 | img = zz.reshape((height, width)) 30 | 31 | return img / np.max(img) 32 | 33 | def get_labelmap(label): 34 | structure = ndimage.generate_binary_structure(2, 2) 35 | labelmaps, connected_num = ndimage.label(label, structure=structure) 36 | # 像素->联通分量,0为背景, Pixel->connected component, 0 is the background 37 | pixel2connetedId = {(x, y): val for (x, y), val in np.ndenumerate(labelmaps)} 38 | return labelmaps, connected_num, pixel2connetedId 39 | 40 | def get_negative_region(labelmap, neg_range=8): 41 | kernel = np.ones((neg_range, neg_range), np.uint8) 42 | negative_region = cv2.dilate(labelmap, kernel, iterations=1) - labelmap 43 | return negative_region 44 | 45 | def label_to_point_prompt_global(label, positive_num=2, negative_num=-1): 46 | labelmaps, connected_num, _ = get_labelmap(label) 47 | positive_points, negative_points = [], [] 48 | connected_points_pos, connected_points_neg = defaultdict(list), defaultdict(list) 49 | negative_region = get_negative_region(labelmaps.astype(np.uint8)) 50 | 51 | for (x, y), val in np.ndenumerate(labelmaps): connected_points_pos[val].append((y, x)) 52 | for (x, y), 
val in np.ndenumerate(negative_region): connected_points_neg[val].append((y, x)) 53 | 54 | # time consuming loop 55 | for connected_id in range(1, connected_num+1): 56 | if positive_num <= len(connected_points_pos[connected_id]): 57 | positive_points += random.sample(connected_points_pos[connected_id], max(0, positive_num)) 58 | if 0 < negative_num <= len(connected_points_neg[connected_id]): 59 | negative_points += random.sample(connected_points_neg[connected_id], max(0, negative_num)) 60 | 61 | if negative_num == -1: 62 | total_num = 30 * positive_num 63 | negative_num = total_num - connected_num * positive_num 64 | negative_region = get_negative_region(label) 65 | negative_points = [(y, x) for (x, y), val in np.ndenumerate(negative_region) if val] 66 | negative_points = random.sample(negative_points, max(0, negative_num)) 67 | 68 | return np.array([label], dtype=float), positive_points, negative_points 69 | 70 | def label_to_point_prompt_local(label, positive_num=2, negative_num=2): 71 | labelmaps, _, pixel2connetedId = get_labelmap(label) 72 | labelmap_points = [(x, y) for (x, y), val in np.ndenumerate(labelmaps) if val] 73 | 74 | min_area = positive_num + negative_num 75 | 76 | def get_selected_points(): 77 | selected_pixel = random.randint(0, len(labelmap_points)-1) 78 | selected_id = pixel2connetedId[labelmap_points[selected_pixel]] 79 | return [(x, y) for (x, y), val in np.ndenumerate(labelmaps) if val == selected_id] 80 | 81 | selected_points = get_selected_points() 82 | while len(selected_points) < min_area: selected_points = get_selected_points() 83 | 84 | selected_labelmap = np.zeros_like(labelmaps, dtype=np.uint8) 85 | for (x, y) in selected_points: selected_labelmap[(x, y)] = 1 86 | 87 | negative_region = get_negative_region(selected_labelmap) 88 | 89 | positive_points = [(y, x) for (x, y), val in np.ndenumerate(selected_labelmap) if val] 90 | negative_points = [(y, x) for (x, y), val in np.ndenumerate(negative_region) if val] 91 | 92 | positive_points = random.sample(positive_points, max(0, positive_num)) 93 | negative_points = random.sample(negative_points, max(0, negative_num)) 94 | 95 | # no prompt points, no segmentation 96 | if not positive_points + negative_points: selected_labelmap = np.zeros_like(labelmaps, dtype=np.uint8) 97 | 98 | return np.array([selected_labelmap], dtype=float), positive_points, negative_points 99 | 100 | # if __name__=="__main__": 101 | 102 | # for sample_id in range(10301, 10501): 103 | # image_path = "datasets/OCTA-500/OCTA_3M/ProjectionMaps/OCTA(OPL_BM)/{}.bmp".format(sample_id) 104 | # label_path = "datasets/OCTA-500/OCTA_3M/GT_LargeVessel/{}.bmp".format(sample_id) 105 | # image, label = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE), cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) 106 | 107 | # _, _, positive_points, negative_points = label_to_point_prompt_global(label, 1, -1) 108 | # print(len(positive_points + negative_points)) -------------------------------------------------------------------------------- /sam_lora_image_encoder.py: -------------------------------------------------------------------------------- 1 | from segment_anything import build_sam, SamPredictor 2 | from segment_anything import sam_model_registry 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | from torch.nn.parameter import Parameter 10 | from segment_anything.modeling import Sam 11 | from safetensors import safe_open 12 | from safetensors.torch import save_file 13 | 14 | from icecream 
import ic 15 | 16 | 17 | class _LoRA_qkv(nn.Module): 18 | """In Sam it is implemented as 19 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 20 | B, N, C = x.shape 21 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 22 | q, k, v = qkv.unbind(0) 23 | """ 24 | 25 | def __init__( 26 | self, 27 | qkv: nn.Module, 28 | linear_a_q: nn.Module, 29 | linear_b_q: nn.Module, 30 | linear_a_v: nn.Module, 31 | linear_b_v: nn.Module, 32 | ): 33 | super().__init__() 34 | self.qkv = qkv 35 | self.linear_a_q = linear_a_q 36 | self.linear_b_q = linear_b_q 37 | self.linear_a_v = linear_a_v 38 | self.linear_b_v = linear_b_v 39 | self.dim = qkv.in_features 40 | self.w_identity = torch.eye(qkv.in_features) 41 | 42 | def forward(self, x): 43 | qkv = self.qkv(x) # B,N,N,3*org_C 44 | new_q = self.linear_b_q(self.linear_a_q(x)) 45 | new_v = self.linear_b_v(self.linear_a_v(x)) 46 | qkv[:, :, :, : self.dim] += new_q 47 | qkv[:, :, :, -self.dim:] += new_v 48 | return qkv 49 | 50 | 51 | class LoRA_Sam(nn.Module): 52 | """Applies low-rank adaptation to a Sam model's image encoder. 53 | 54 | Args: 55 | sam_model: a vision transformer model, see base_vit.py 56 | r: rank of LoRA 57 | num_classes: how many classes the model output, default to the vit model 58 | lora_layer: which layer we apply LoRA. 59 | 60 | Examples:: 61 | >>> model = ViT('B_16_imagenet1k') 62 | >>> lora_model = LoRA_ViT(model, r=4) 63 | >>> preds = lora_model(img) 64 | >>> print(preds.shape) 65 | torch.Size([1, 1000]) 66 | """ 67 | 68 | def __init__(self, sam_model: Sam, r: int, lora_layer=None): 69 | super(LoRA_Sam, self).__init__() 70 | 71 | assert r > 0 72 | # base_vit_dim = sam_model.image_encoder.patch_embed.proj.out_channels 73 | # dim = base_vit_dim 74 | if lora_layer: self.lora_layer = lora_layer 75 | else: self.lora_layer = list(range(len(sam_model.image_encoder.blocks))) # Only apply lora to the image encoder by default 76 | # create for storage, then we can init them or load weights 77 | self.w_As, self.w_Bs = [], [] # These are linear layers 78 | 79 | # lets freeze first 80 | for param in sam_model.image_encoder.parameters(): param.requires_grad = False 81 | 82 | for param in sam_model.prompt_encoder.parameters(): param.requires_grad = False 83 | for param in sam_model.mask_decoder.parameters(): param.requires_grad = False 84 | 85 | # Here, we do the surgery 86 | for t_layer_i, blk in enumerate(sam_model.image_encoder.blocks): 87 | # If we only want few lora layer instead of all 88 | # if t_layer_i not in self.lora_layer: continue 89 | w_qkv_linear = blk.attn.qkv 90 | self.dim = w_qkv_linear.in_features 91 | w_a_linear_q = nn.Linear(self.dim, r, bias=False) 92 | w_b_linear_q = nn.Linear(r, self.dim, bias=False) 93 | w_a_linear_v = nn.Linear(self.dim, r, bias=False) 94 | w_b_linear_v = nn.Linear(r, self.dim, bias=False) 95 | self.w_As.append(w_a_linear_q) 96 | self.w_Bs.append(w_b_linear_q) 97 | self.w_As.append(w_a_linear_v) 98 | self.w_Bs.append(w_b_linear_v) 99 | blk.attn.qkv = _LoRA_qkv( 100 | w_qkv_linear, 101 | w_a_linear_q, 102 | w_b_linear_q, 103 | w_a_linear_v, 104 | w_b_linear_v, 105 | ) 106 | self.reset_parameters() 107 | self.sam = sam_model 108 | 109 | def save_lora_parameters(self, filename: str) -> None: 110 | r"""Only safetensors is supported now. 111 | 112 | pip install safetensor if you do not have one installed yet. 113 | 114 | save both lora and fc parameters. 
115 | """ 116 | 117 | assert filename.endswith(".pt") or filename.endswith('.pth') 118 | 119 | num_layer = len(self.w_As) # actually, it is half 120 | a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)} 121 | b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)} 122 | prompt_encoder_tensors = {} 123 | mask_decoder_tensors = {} 124 | 125 | # save prompt encoder, only `state_dict`, the `named_parameter` is not permitted 126 | if isinstance(self.sam, torch.nn.DataParallel) or isinstance(self.sam, torch.nn.parallel.DistributedDataParallel): 127 | state_dict = self.sam.module.state_dict() 128 | else: 129 | state_dict = self.sam.state_dict() 130 | for key, value in state_dict.items(): 131 | if 'prompt_encoder' in key: prompt_encoder_tensors[key] = value 132 | if 'mask_decoder' in key: mask_decoder_tensors[key] = value 133 | 134 | merged_dict = {**a_tensors, **b_tensors, **prompt_encoder_tensors, **mask_decoder_tensors} 135 | torch.save(merged_dict, filename) 136 | 137 | def load_lora_parameters(self, filename: str) -> None: 138 | r"""Only safetensors is supported now. 139 | 140 | pip install safetensor if you do not have one installed yet.\ 141 | 142 | load both lora and fc parameters. 143 | """ 144 | 145 | assert filename.endswith(".pt") or filename.endswith('.pth') 146 | 147 | state_dict = torch.load(filename) 148 | 149 | for i, w_A_linear in enumerate(self.w_As): 150 | saved_key = f"w_a_{i:03d}" 151 | saved_tensor = state_dict[saved_key] 152 | w_A_linear.weight = Parameter(saved_tensor) 153 | 154 | for i, w_B_linear in enumerate(self.w_Bs): 155 | saved_key = f"w_b_{i:03d}" 156 | saved_tensor = state_dict[saved_key] 157 | w_B_linear.weight = Parameter(saved_tensor) 158 | 159 | sam_dict = self.sam.state_dict() 160 | sam_keys = sam_dict.keys() 161 | 162 | # load prompt encoder 163 | prompt_encoder_keys = [k for k in sam_keys if 'prompt_encoder' in k] 164 | prompt_encoder_values = [state_dict[k] for k in prompt_encoder_keys] 165 | prompt_encoder_new_state_dict = {k: v for k, v in zip(prompt_encoder_keys, prompt_encoder_values)} 166 | sam_dict.update(prompt_encoder_new_state_dict) 167 | 168 | # load mask decoder 169 | mask_decoder_keys = [k for k in sam_keys if 'mask_decoder' in k] 170 | mask_decoder_values = [state_dict[k] for k in mask_decoder_keys] 171 | mask_decoder_new_state_dict = {k: v for k, v in zip(mask_decoder_keys, mask_decoder_values)} 172 | sam_dict.update(mask_decoder_new_state_dict) 173 | self.sam.load_state_dict(sam_dict) 174 | 175 | def reset_parameters(self) -> None: 176 | for w_A in self.w_As: nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5)) 177 | for w_B in self.w_Bs: nn.init.zeros_(w_B.weight) 178 | 179 | def forward(self, images, original_size, point_coords, point_labels, multimask_output=False): 180 | input_size = tuple(images.shape[-2:]) 181 | images = self.sam.preprocess(images) 182 | image_features = self.sam.image_encoder(images) # shape: 1, 256, 64, 64, range: 0-1 183 | 184 | sparse_embeddings, dense_embeddings = self.sam.prompt_encoder( 185 | points=(point_coords, point_labels), 186 | boxes=None, 187 | masks=None, 188 | ) 189 | 190 | # Predict masks 191 | low_res_masks, iou_predictions = self.sam.mask_decoder( 192 | image_embeddings=image_features, 193 | image_pe=self.sam.prompt_encoder.get_dense_pe(), 194 | sparse_prompt_embeddings=sparse_embeddings, 195 | dense_prompt_embeddings=dense_embeddings, 196 | multimask_output=multimask_output, 197 | ) 198 | 199 | # Upscale the masks to the original image resolution 200 | 
masks = self.sam.postprocess_masks(low_res_masks, input_size, original_size) 201 | return nn.Sigmoid()(masks) 202 | 203 | 204 | 205 | 206 | # if __name__ == "__main__": 207 | # sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth") 208 | # lora_sam = LoRA_Sam(sam, 4) 209 | # lora_sam.sam.image_encoder(torch.rand(size=(1, 3, 1024, 1024))) -------------------------------------------------------------------------------- /sam_weights/keep_dir.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/sam_weights/keep_dir.txt -------------------------------------------------------------------------------- /segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /segment_anything/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/automatic_mask_generator.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/__pycache__/automatic_mask_generator.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/build_sam.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/__pycache__/build_sam.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/build_sam.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/__pycache__/build_sam.cpython-37.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/predictor.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/__pycache__/predictor.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/automatic_mask_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore 10 | 11 | from typing import Any, Dict, List, Optional, Tuple 12 | 13 | from .modeling import Sam 14 | from .predictor import SamPredictor 15 | from .utils.amg import ( 16 | MaskData, 17 | area_from_rle, 18 | batch_iterator, 19 | batched_mask_to_box, 20 | box_xyxy_to_xywh, 21 | build_all_layer_point_grids, 22 | calculate_stability_score, 23 | coco_encode_rle, 24 | generate_crop_boxes, 25 | is_box_near_crop_edge, 26 | mask_to_rle_pytorch, 27 | remove_small_regions, 28 | rle_to_mask, 29 | uncrop_boxes_xyxy, 30 | uncrop_masks, 31 | uncrop_points, 32 | ) 33 | 34 | 35 | class SamAutomaticMaskGenerator: 36 | def __init__( 37 | self, 38 | model: Sam, 39 | points_per_side: Optional[int] = 32, 40 | points_per_batch: int = 64, 41 | pred_iou_thresh: float = 0.88, 42 | stability_score_thresh: float = 0.95, 43 | stability_score_offset: float = 1.0, 44 | box_nms_thresh: float = 0.7, 45 | crop_n_layers: int = 0, 46 | crop_nms_thresh: float = 0.7, 47 | crop_overlap_ratio: float = 512 / 1500, 48 | crop_n_points_downscale_factor: int = 1, 49 | point_grids: Optional[List[np.ndarray]] = None, 50 | min_mask_region_area: int = 0, 51 | output_mode: str = "binary_mask", 52 | ) -> None: 53 | """ 54 | Using a SAM model, generates masks for the entire image. 55 | Generates a grid of point prompts over the image, then filters 56 | low quality and duplicate masks. The default settings are chosen 57 | for SAM with a ViT-H backbone. 58 | 59 | Arguments: 60 | model (Sam): The SAM model to use for mask prediction. 61 | points_per_side (int or None): The number of points to be sampled 62 | along one side of the image. The total number of points is 63 | points_per_side**2. If None, 'point_grids' must provide explicit 64 | point sampling. 65 | points_per_batch (int): Sets the number of points run simultaneously 66 | by the model. Higher numbers may be faster but use more GPU memory. 67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the 68 | model's predicted mask quality. 69 | stability_score_thresh (float): A filtering threshold in [0,1], using 70 | the stability of the mask under changes to the cutoff used to binarize 71 | the model's mask predictions. 72 | stability_score_offset (float): The amount to shift the cutoff when 73 | calculated the stability score. 74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal 75 | suppression to filter duplicate masks. 76 | crop_n_layers (int): If >0, mask prediction will be run again on 77 | crops of the image. Sets the number of layers to run, where each 78 | layer has 2**i_layer number of image crops. 79 | crop_nms_thresh (float): The box IoU cutoff used by non-maximal 80 | suppression to filter duplicate masks between different crops. 81 | crop_overlap_ratio (float): Sets the degree to which crops overlap. 
82 | In the first crop layer, crops will overlap by this fraction of 83 | the image length. Later layers with more crops scale down this overlap. 84 | crop_n_points_downscale_factor (int): The number of points-per-side 85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 86 | point_grids (list(np.ndarray) or None): A list over explicit grids 87 | of points used for sampling, normalized to [0,1]. The nth grid in the 88 | list is used in the nth crop layer. Exclusive with points_per_side. 89 | min_mask_region_area (int): If >0, postprocessing will be applied 90 | to remove disconnected regions and holes in masks with area smaller 91 | than min_mask_region_area. Requires opencv. 92 | output_mode (str): The form masks are returned in. Can be 'binary_mask', 93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 94 | For large resolutions, 'binary_mask' may consume large amounts of 95 | memory. 96 | """ 97 | 98 | assert (points_per_side is None) != ( 99 | point_grids is None 100 | ), "Exactly one of points_per_side or point_grid must be provided." 101 | if points_per_side is not None: 102 | self.point_grids = build_all_layer_point_grids( 103 | points_per_side, 104 | crop_n_layers, 105 | crop_n_points_downscale_factor, 106 | ) 107 | elif point_grids is not None: 108 | self.point_grids = point_grids 109 | else: 110 | raise ValueError("Can't have both points_per_side and point_grid be None.") 111 | 112 | assert output_mode in [ 113 | "binary_mask", 114 | "uncompressed_rle", 115 | "coco_rle", 116 | ], f"Unknown output_mode {output_mode}." 117 | if output_mode == "coco_rle": 118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401 119 | 120 | if min_mask_region_area > 0: 121 | import cv2 # type: ignore # noqa: F401 122 | 123 | self.predictor = SamPredictor(model) 124 | self.points_per_batch = points_per_batch 125 | self.pred_iou_thresh = pred_iou_thresh 126 | self.stability_score_thresh = stability_score_thresh 127 | self.stability_score_offset = stability_score_offset 128 | self.box_nms_thresh = box_nms_thresh 129 | self.crop_n_layers = crop_n_layers 130 | self.crop_nms_thresh = crop_nms_thresh 131 | self.crop_overlap_ratio = crop_overlap_ratio 132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor 133 | self.min_mask_region_area = min_mask_region_area 134 | self.output_mode = output_mode 135 | 136 | @torch.no_grad() 137 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: 138 | """ 139 | Generates masks for the given image. 140 | 141 | Arguments: 142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format. 143 | 144 | Returns: 145 | list(dict(str, any)): A list over records for masks. Each record is 146 | a dict containing the following keys: 147 | segmentation (dict(str, any) or np.ndarray): The mask. If 148 | output_mode='binary_mask', is an array of shape HW. Otherwise, 149 | is a dictionary containing the RLE. 150 | bbox (list(float)): The box around the mask, in XYWH format. 151 | area (int): The area in pixels of the mask. 152 | predicted_iou (float): The model's own prediction of the mask's 153 | quality. This is filtered by the pred_iou_thresh parameter. 154 | point_coords (list(list(float))): The point coordinates input 155 | to the model to generate this mask. 156 | stability_score (float): A measure of the mask's quality. This 157 | is filtered on using the stability_score_thresh parameter. 
158 | crop_box (list(float)): The crop of the image used to generate 159 | the mask, given in XYWH format. 160 | """ 161 | 162 | # Generate masks 163 | mask_data = self._generate_masks(image) 164 | 165 | # Filter small disconnected regions and holes in masks 166 | if self.min_mask_region_area > 0: 167 | mask_data = self.postprocess_small_regions( 168 | mask_data, 169 | self.min_mask_region_area, 170 | max(self.box_nms_thresh, self.crop_nms_thresh), 171 | ) 172 | 173 | # Encode masks 174 | if self.output_mode == "coco_rle": 175 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] 176 | elif self.output_mode == "binary_mask": 177 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 178 | else: 179 | mask_data["segmentations"] = mask_data["rles"] 180 | 181 | # Write mask records 182 | curr_anns = [] 183 | for idx in range(len(mask_data["segmentations"])): 184 | ann = { 185 | "segmentation": mask_data["segmentations"][idx], 186 | "area": area_from_rle(mask_data["rles"][idx]), 187 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 188 | "predicted_iou": mask_data["iou_preds"][idx].item(), 189 | "point_coords": [mask_data["points"][idx].tolist()], 190 | "stability_score": mask_data["stability_score"][idx].item(), 191 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 192 | } 193 | curr_anns.append(ann) 194 | 195 | return curr_anns 196 | 197 | def _generate_masks(self, image: np.ndarray) -> MaskData: 198 | orig_size = image.shape[:2] 199 | crop_boxes, layer_idxs = generate_crop_boxes( 200 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 201 | ) 202 | 203 | # Iterate over image crops 204 | data = MaskData() 205 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 206 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) 207 | data.cat(crop_data) 208 | 209 | # Remove duplicate masks between crops 210 | if len(crop_boxes) > 1: 211 | # Prefer masks from smaller crops 212 | scores = 1 / box_area(data["crop_boxes"]) 213 | scores = scores.to(data["boxes"].device) 214 | keep_by_nms = batched_nms( 215 | data["boxes"].float(), 216 | scores, 217 | torch.zeros_like(data["boxes"][:, 0]), # categories 218 | iou_threshold=self.crop_nms_thresh, 219 | ) 220 | data.filter(keep_by_nms) 221 | 222 | data.to_numpy() 223 | return data 224 | 225 | def _process_crop( 226 | self, 227 | image: np.ndarray, 228 | crop_box: List[int], 229 | crop_layer_idx: int, 230 | orig_size: Tuple[int, ...], 231 | ) -> MaskData: 232 | # Crop the image and calculate embeddings 233 | x0, y0, x1, y1 = crop_box 234 | cropped_im = image[y0:y1, x0:x1, :] 235 | cropped_im_size = cropped_im.shape[:2] 236 | self.predictor.set_image(cropped_im) 237 | 238 | # Get points for this crop 239 | points_scale = np.array(cropped_im_size)[None, ::-1] 240 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 241 | 242 | # Generate masks for this crop in batches 243 | data = MaskData() 244 | for (points,) in batch_iterator(self.points_per_batch, points_for_image): 245 | batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) 246 | data.cat(batch_data) 247 | del batch_data 248 | self.predictor.reset_image() 249 | 250 | # Remove duplicates within this crop. 
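# batched_nms is called with a single dummy category (all-zero idxs), so it acts as plain
# class-agnostic NMS over every box in this crop: among boxes whose overlap exceeds
# box_nms_thresh, only the mask with the highest predicted IoU is kept.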
251 | keep_by_nms = batched_nms( 252 | data["boxes"].float(), 253 | data["iou_preds"], 254 | torch.zeros_like(data["boxes"][:, 0]), # categories 255 | iou_threshold=self.box_nms_thresh, 256 | ) 257 | data.filter(keep_by_nms) 258 | 259 | # Return to the original image frame 260 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 261 | data["points"] = uncrop_points(data["points"], crop_box) 262 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 263 | 264 | return data 265 | 266 | def _process_batch( 267 | self, 268 | points: np.ndarray, 269 | im_size: Tuple[int, ...], 270 | crop_box: List[int], 271 | orig_size: Tuple[int, ...], 272 | ) -> MaskData: 273 | orig_h, orig_w = orig_size 274 | 275 | # Run model on this batch 276 | transformed_points = self.predictor.transform.apply_coords(points, im_size) 277 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device) 278 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 279 | masks, iou_preds, _ = self.predictor.predict_torch( 280 | in_points[:, None, :], 281 | in_labels[:, None], 282 | multimask_output=True, 283 | return_logits=True, 284 | ) 285 | 286 | # Serialize predictions and store in MaskData 287 | data = MaskData( 288 | masks=masks.flatten(0, 1), 289 | iou_preds=iou_preds.flatten(0, 1), 290 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), 291 | ) 292 | del masks 293 | 294 | # Filter by predicted IoU 295 | if self.pred_iou_thresh > 0.0: 296 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 297 | data.filter(keep_mask) 298 | 299 | # Calculate stability score 300 | data["stability_score"] = calculate_stability_score( 301 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset 302 | ) 303 | if self.stability_score_thresh > 0.0: 304 | keep_mask = data["stability_score"] >= self.stability_score_thresh 305 | data.filter(keep_mask) 306 | 307 | # Threshold masks and calculate boxes 308 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold 309 | data["boxes"] = batched_mask_to_box(data["masks"]) 310 | 311 | # Filter boxes that touch crop boundaries 312 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 313 | if not torch.all(keep_mask): 314 | data.filter(keep_mask) 315 | 316 | # Compress to RLE 317 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 318 | data["rles"] = mask_to_rle_pytorch(data["masks"]) 319 | del data["masks"] 320 | 321 | return data 322 | 323 | @staticmethod 324 | def postprocess_small_regions( 325 | mask_data: MaskData, min_area: int, nms_thresh: float 326 | ) -> MaskData: 327 | """ 328 | Removes small disconnected regions and holes in masks, then reruns 329 | box NMS to remove any new duplicates. 330 | 331 | Edits mask_data in place. 332 | 333 | Requires open-cv as a dependency. 
334 | """ 335 | if len(mask_data["rles"]) == 0: 336 | return mask_data 337 | 338 | # Filter small disconnected regions and holes 339 | new_masks = [] 340 | scores = [] 341 | for rle in mask_data["rles"]: 342 | mask = rle_to_mask(rle) 343 | 344 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 345 | unchanged = not changed 346 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 347 | unchanged = unchanged and not changed 348 | 349 | new_masks.append(torch.as_tensor(mask).unsqueeze(0)) 350 | # Give score=0 to changed masks and score=1 to unchanged masks 351 | # so NMS will prefer ones that didn't need postprocessing 352 | scores.append(float(unchanged)) 353 | 354 | # Recalculate boxes and remove any new duplicates 355 | masks = torch.cat(new_masks, dim=0) 356 | boxes = batched_mask_to_box(masks) 357 | keep_by_nms = batched_nms( 358 | boxes.float(), 359 | torch.as_tensor(scores), 360 | torch.zeros_like(boxes[:, 0]), # categories 361 | iou_threshold=nms_thresh, 362 | ) 363 | 364 | # Only recalculate RLEs for masks that have changed 365 | for i_mask in keep_by_nms: 366 | if scores[i_mask] == 0.0: 367 | mask_torch = masks[i_mask].unsqueeze(0) 368 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 369 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 370 | mask_data.filter(keep_by_nms) 371 | 372 | return mask_data 373 | -------------------------------------------------------------------------------- /segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
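# Re-exports the core SAM building blocks (Sam, ImageEncoderViT, MaskDecoder, PromptEncoder,
# TwoWayTransformer) so they can be imported directly from segment_anything.modeling.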
6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/modeling/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/common.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/modeling/__pycache__/common.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/image_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/modeling/__pycache__/image_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/mask_decoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/modeling/__pycache__/mask_decoder.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/prompt_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/modeling/__pycache__/prompt_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/sam.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/modeling/__pycache__/sam.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/transformer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/modeling/__pycache__/transformer.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
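# Shared building blocks: MLPBlock is the two-layer feed-forward block used inside the
# transformer layers, and LayerNorm2d applies layer normalization over the channel dimension
# of NCHW feature maps (used in the image encoder neck and the mask decoder upscaling path).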
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /segment_anything/modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d, MLPBlock 14 | 15 | 16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 17 | class ImageEncoderViT(nn.Module): 18 | def __init__( 19 | self, 20 | img_size: int = 1024, 21 | patch_size: int = 16, 22 | in_chans: int = 3, 23 | embed_dim: int = 768, 24 | depth: int = 12, 25 | num_heads: int = 12, 26 | mlp_ratio: float = 4.0, 27 | out_chans: int = 256, 28 | qkv_bias: bool = True, 29 | norm_layer: Type[nn.Module] = nn.LayerNorm, 30 | act_layer: Type[nn.Module] = nn.GELU, 31 | use_abs_pos: bool = True, 32 | use_rel_pos: bool = False, 33 | rel_pos_zero_init: bool = True, 34 | window_size: int = 0, 35 | global_attn_indexes: Tuple[int, ...] = (), 36 | ) -> None: 37 | """ 38 | Args: 39 | img_size (int): Input image size. 40 | patch_size (int): Patch size. 41 | in_chans (int): Number of input image channels. 42 | embed_dim (int): Patch embedding dimension. 43 | depth (int): Depth of ViT. 44 | num_heads (int): Number of attention heads in each ViT block. 45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 46 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 47 | norm_layer (nn.Module): Normalization layer. 48 | act_layer (nn.Module): Activation layer. 49 | use_abs_pos (bool): If True, use absolute positional embeddings. 50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 
52 | window_size (int): Window size for window attention blocks. 53 | global_attn_indexes (list): Indexes for blocks using global attention. 54 | """ 55 | super().__init__() 56 | self.img_size = img_size 57 | 58 | self.patch_embed = PatchEmbed( 59 | kernel_size=(patch_size, patch_size), 60 | stride=(patch_size, patch_size), 61 | in_chans=in_chans, 62 | embed_dim=embed_dim, 63 | ) 64 | 65 | self.pos_embed: Optional[nn.Parameter] = None 66 | if use_abs_pos: 67 | # Initialize absolute positional embedding with pretrain image size. 68 | self.pos_embed = nn.Parameter( 69 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 70 | ) 71 | 72 | self.blocks = nn.ModuleList() 73 | for i in range(depth): 74 | block = Block( 75 | dim=embed_dim, 76 | num_heads=num_heads, 77 | mlp_ratio=mlp_ratio, 78 | qkv_bias=qkv_bias, 79 | norm_layer=norm_layer, 80 | act_layer=act_layer, 81 | use_rel_pos=use_rel_pos, 82 | rel_pos_zero_init=rel_pos_zero_init, 83 | window_size=window_size if i not in global_attn_indexes else 0, 84 | input_size=(img_size // patch_size, img_size // patch_size), 85 | ) 86 | self.blocks.append(block) 87 | 88 | self.neck = nn.Sequential( 89 | nn.Conv2d( 90 | embed_dim, 91 | out_chans, 92 | kernel_size=1, 93 | bias=False, 94 | ), 95 | LayerNorm2d(out_chans), 96 | nn.Conv2d( 97 | out_chans, 98 | out_chans, 99 | kernel_size=3, 100 | padding=1, 101 | bias=False, 102 | ), 103 | LayerNorm2d(out_chans), 104 | ) 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | x = self.patch_embed(x) 108 | if self.pos_embed is not None: 109 | x = x + self.pos_embed 110 | 111 | for blk in self.blocks: 112 | x = blk(x) 113 | 114 | x = self.neck(x.permute(0, 3, 1, 2)) 115 | 116 | return x 117 | 118 | 119 | class Block(nn.Module): 120 | """Transformer blocks with support of window attention and residual propagation blocks""" 121 | 122 | def __init__( 123 | self, 124 | dim: int, 125 | num_heads: int, 126 | mlp_ratio: float = 4.0, 127 | qkv_bias: bool = True, 128 | norm_layer: Type[nn.Module] = nn.LayerNorm, 129 | act_layer: Type[nn.Module] = nn.GELU, 130 | use_rel_pos: bool = False, 131 | rel_pos_zero_init: bool = True, 132 | window_size: int = 0, 133 | input_size: Optional[Tuple[int, int]] = None, 134 | ) -> None: 135 | """ 136 | Args: 137 | dim (int): Number of input channels. 138 | num_heads (int): Number of attention heads in each ViT block. 139 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 140 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 141 | norm_layer (nn.Module): Normalization layer. 142 | act_layer (nn.Module): Activation layer. 143 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 144 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 145 | window_size (int): Window size for window attention blocks. If it equals 0, then 146 | use global attention. 147 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 148 | positional parameter size. 
149 | """ 150 | super().__init__() 151 | self.norm1 = norm_layer(dim) 152 | self.attn = Attention( 153 | dim, 154 | num_heads=num_heads, 155 | qkv_bias=qkv_bias, 156 | use_rel_pos=use_rel_pos, 157 | rel_pos_zero_init=rel_pos_zero_init, 158 | input_size=input_size if window_size == 0 else (window_size, window_size), 159 | ) 160 | 161 | self.norm2 = norm_layer(dim) 162 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 163 | 164 | self.window_size = window_size 165 | 166 | def forward(self, x: torch.Tensor) -> torch.Tensor: 167 | shortcut = x 168 | x = self.norm1(x) 169 | # Window partition 170 | if self.window_size > 0: 171 | H, W = x.shape[1], x.shape[2] 172 | x, pad_hw = window_partition(x, self.window_size) 173 | 174 | x = self.attn(x) 175 | # Reverse window partition 176 | if self.window_size > 0: 177 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 178 | 179 | x = shortcut + x 180 | x = x + self.mlp(self.norm2(x)) 181 | 182 | return x 183 | 184 | 185 | class Attention(nn.Module): 186 | """Multi-head Attention block with relative position embeddings.""" 187 | 188 | def __init__( 189 | self, 190 | dim: int, 191 | num_heads: int = 8, 192 | qkv_bias: bool = True, 193 | use_rel_pos: bool = False, 194 | rel_pos_zero_init: bool = True, 195 | input_size: Optional[Tuple[int, int]] = None, 196 | ) -> None: 197 | """ 198 | Args: 199 | dim (int): Number of input channels. 200 | num_heads (int): Number of attention heads. 201 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 202 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 203 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 204 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 205 | positional parameter size. 206 | """ 207 | super().__init__() 208 | self.num_heads = num_heads 209 | head_dim = dim // num_heads 210 | self.scale = head_dim**-0.5 211 | 212 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 213 | self.proj = nn.Linear(dim, dim) 214 | 215 | self.use_rel_pos = use_rel_pos 216 | if self.use_rel_pos: 217 | assert ( 218 | input_size is not None 219 | ), "Input size must be provided if using relative positional encoding." 220 | # initialize relative positional embeddings 221 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 222 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 223 | 224 | def forward(self, x: torch.Tensor) -> torch.Tensor: 225 | B, H, W, _ = x.shape 226 | # qkv with shape (3, B, nHead, H * W, C) 227 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 228 | # q, k, v with shape (B * nHead, H * W, C) 229 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 230 | 231 | attn = (q * self.scale) @ k.transpose(-2, -1) 232 | 233 | if self.use_rel_pos: 234 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 235 | 236 | attn = attn.softmax(dim=-1) 237 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 238 | x = self.proj(x) 239 | 240 | return x 241 | 242 | 243 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 244 | """ 245 | Partition into non-overlapping windows with padding if needed. 246 | Args: 247 | x (tensor): input tokens with [B, H, W, C]. 248 | window_size (int): window size. 
249 | 250 | Returns: 251 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 252 | (Hp, Wp): padded height and width before partition 253 | """ 254 | B, H, W, C = x.shape 255 | 256 | pad_h = (window_size - H % window_size) % window_size 257 | pad_w = (window_size - W % window_size) % window_size 258 | if pad_h > 0 or pad_w > 0: 259 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 260 | Hp, Wp = H + pad_h, W + pad_w 261 | 262 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 263 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 264 | return windows, (Hp, Wp) 265 | 266 | 267 | def window_unpartition( 268 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 269 | ) -> torch.Tensor: 270 | """ 271 | Window unpartition into original sequences and removing padding. 272 | Args: 273 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 274 | window_size (int): window size. 275 | pad_hw (Tuple): padded height and width (Hp, Wp). 276 | hw (Tuple): original height and width (H, W) before padding. 277 | 278 | Returns: 279 | x: unpartitioned sequences with [B, H, W, C]. 280 | """ 281 | Hp, Wp = pad_hw 282 | H, W = hw 283 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 284 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 285 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 286 | 287 | if Hp > H or Wp > W: 288 | x = x[:, :H, :W, :].contiguous() 289 | return x 290 | 291 | 292 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 293 | """ 294 | Get relative positional embeddings according to the relative positions of 295 | query and key sizes. 296 | Args: 297 | q_size (int): size of query q. 298 | k_size (int): size of key k. 299 | rel_pos (Tensor): relative position embeddings (L, C). 300 | 301 | Returns: 302 | Extracted positional embeddings according to relative positions. 303 | """ 304 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 305 | # Interpolate rel pos if needed. 306 | if rel_pos.shape[0] != max_rel_dist: 307 | # Interpolate rel pos. 308 | rel_pos_resized = F.interpolate( 309 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 310 | size=max_rel_dist, 311 | mode="linear", 312 | ) 313 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 314 | else: 315 | rel_pos_resized = rel_pos 316 | 317 | # Scale the coords with short length if shapes for q and k are different. 318 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 319 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 320 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 321 | 322 | return rel_pos_resized[relative_coords.long()] 323 | 324 | 325 | def add_decomposed_rel_pos( 326 | attn: torch.Tensor, 327 | q: torch.Tensor, 328 | rel_pos_h: torch.Tensor, 329 | rel_pos_w: torch.Tensor, 330 | q_size: Tuple[int, int], 331 | k_size: Tuple[int, int], 332 | ) -> torch.Tensor: 333 | """ 334 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 335 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 336 | Args: 337 | attn (Tensor): attention map. 338 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 
339 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 340 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 341 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 342 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 343 | 344 | Returns: 345 | attn (Tensor): attention map with added relative positional embeddings. 346 | """ 347 | q_h, q_w = q_size 348 | k_h, k_w = k_size 349 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 350 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 351 | 352 | B, _, dim = q.shape 353 | r_q = q.reshape(B, q_h, q_w, dim) 354 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 355 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 356 | 357 | attn = ( 358 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 359 | ).view(B, q_h * q_w, k_h * k_w) 360 | 361 | return attn 362 | 363 | 364 | class PatchEmbed(nn.Module): 365 | """ 366 | Image to Patch Embedding. 367 | """ 368 | 369 | def __init__( 370 | self, 371 | kernel_size: Tuple[int, int] = (16, 16), 372 | stride: Tuple[int, int] = (16, 16), 373 | padding: Tuple[int, int] = (0, 0), 374 | in_chans: int = 3, 375 | embed_dim: int = 768, 376 | ) -> None: 377 | """ 378 | Args: 379 | kernel_size (Tuple): kernel size of the projection layer. 380 | stride (Tuple): stride of the projection layer. 381 | padding (Tuple): padding size of the projection layer. 382 | in_chans (int): Number of input image channels. 383 | embed_dim (int): Patch embedding dimension. 384 | """ 385 | super().__init__() 386 | 387 | self.proj = nn.Conv2d( 388 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 389 | ) 390 | 391 | def forward(self, x: torch.Tensor) -> torch.Tensor: 392 | x = self.proj(x) 393 | # B C H W -> B H W C 394 | x = x.permute(0, 2, 3, 1) 395 | return x 396 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. 
See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | src = src + dense_prompt_embeddings 128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 129 | b, c, h, w = src.shape 130 | 131 | # Run the transformer 132 | hs, src = self.transformer(src, pos_src, tokens) 133 | iou_token_out = hs[:, 0, :] 134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 135 | 136 | # Upscale mask embeddings and predict masks using the mask tokens 137 | src = src.transpose(1, 2).view(b, c, h, w) 138 | upscaled_embedding = self.output_upscaling(src) 139 | hyper_in_list: List[torch.Tensor] = [] 140 | for i in range(self.num_mask_tokens): 141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 142 | hyper_in = torch.stack(hyper_in_list, dim=1) 143 | b, c, h, w = upscaled_embedding.shape 144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 145 | 146 | # Generate mask quality predictions 147 | iou_pred = self.iou_prediction_head(iou_token_out) 148 | 149 | return masks, iou_pred 150 | 151 | 152 | # Lightly adapted from 153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 154 | class MLP(nn.Module): 155 | def __init__( 156 | self, 157 | input_dim: int, 158 | hidden_dim: int, 159 | output_dim: int, 160 | num_layers: int, 161 | sigmoid_output: bool = False, 162 | ) -> None: 163 | super().__init__() 164 | self.num_layers = num_layers 165 | h = [hidden_dim] * (num_layers - 1) 166 | self.layers = nn.ModuleList( 167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 168 | ) 169 | self.sigmoid_output = sigmoid_output 170 | 171 | def forward(self, x): 172 | for i, layer in enumerate(self.layers): 173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 174 | if self.sigmoid_output: 175 | x = F.sigmoid(x) 176 | return x 177 | -------------------------------------------------------------------------------- /segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 
34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | 85 | points = torch.cat([points, padding_point], dim=1) 86 | labels = torch.cat([labels, padding_label], dim=1) 87 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 88 | point_embedding[labels == -1] = 0.0 89 | point_embedding[labels == -1] += self.not_a_point_embed.weight 90 | point_embedding[labels == 0] += self.point_embeddings[0].weight 91 | point_embedding[labels == 1] += self.point_embeddings[1].weight 92 | return point_embedding 93 | 94 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 95 | """Embeds box prompts.""" 96 | boxes = boxes + 0.5 # Shift to center of pixel 97 | coords = boxes.reshape(-1, 2, 2) 98 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 99 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 100 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 101 | return corner_embedding 102 | 103 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 104 | """Embeds mask inputs.""" 105 | mask_embedding = self.mask_downscaling(masks) 106 | return mask_embedding 107 | 108 | def _get_batch_size( 109 | self, 110 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 111 | boxes: Optional[torch.Tensor], 112 | masks: Optional[torch.Tensor], 113 | ) -> int: 114 | """ 115 | Gets the batch size of the output given the batch size of the input prompts. 
116 | """ 117 | if points is not None: 118 | return points[0].shape[0] 119 | elif boxes is not None: 120 | return boxes.shape[0] 121 | elif masks is not None: 122 | return masks.shape[0] 123 | else: 124 | return 1 125 | 126 | def _get_device(self) -> torch.device: 127 | return self.point_embeddings[0].weight.device 128 | 129 | def forward( 130 | self, 131 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 132 | boxes: Optional[torch.Tensor], 133 | masks: Optional[torch.Tensor], 134 | ) -> Tuple[torch.Tensor, torch.Tensor]: 135 | """ 136 | Embeds different types of prompts, returning both sparse and dense 137 | embeddings. 138 | 139 | Arguments: 140 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 141 | and labels to embed. 142 | boxes (torch.Tensor or none): boxes to embed 143 | masks (torch.Tensor or none): masks to embed 144 | 145 | Returns: 146 | torch.Tensor: sparse embeddings for the points and boxes, with shape 147 | BxNx(embed_dim), where N is determined by the number of input points 148 | and boxes. 149 | torch.Tensor: dense embeddings for the masks, in the shape 150 | Bx(embed_dim)x(embed_H)x(embed_W) 151 | """ 152 | bs = self._get_batch_size(points, boxes, masks) 153 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 154 | if points is not None: 155 | coords, labels = points 156 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 157 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 158 | if boxes is not None: 159 | box_embeddings = self._embed_boxes(boxes) 160 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 161 | 162 | if masks is not None: 163 | dense_embeddings = self._embed_masks(masks) 164 | else: 165 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 166 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 167 | ) 168 | 169 | return sparse_embeddings, dense_embeddings 170 | 171 | 172 | class PositionEmbeddingRandom(nn.Module): 173 | """ 174 | Positional encoding using random spatial frequencies. 175 | """ 176 | 177 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 178 | super().__init__() 179 | if scale is None or scale <= 0.0: 180 | scale = 1.0 181 | self.register_buffer( 182 | "positional_encoding_gaussian_matrix", 183 | scale * torch.randn((2, num_pos_feats)), 184 | ) 185 | 186 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 187 | """Positionally encode points that are normalized to [0,1].""" 188 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 189 | coords = 2 * coords - 1 190 | coords = coords @ self.positional_encoding_gaussian_matrix 191 | coords = 2 * np.pi * coords 192 | # outputs d_1 x ... 
x d_n x C shape 193 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 194 | 195 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 196 | """Generate positional encoding for a grid of the specified size.""" 197 | h, w = size 198 | device: Any = self.positional_encoding_gaussian_matrix.device 199 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 200 | y_embed = grid.cumsum(dim=0) - 0.5 201 | x_embed = grid.cumsum(dim=1) - 0.5 202 | y_embed = y_embed / h 203 | x_embed = x_embed / w 204 | 205 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 206 | return pe.permute(2, 0, 1) # C x H x W 207 | 208 | def forward_with_coords( 209 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 210 | ) -> torch.Tensor: 211 | """Positionally encode points that are not normalized to [0,1].""" 212 | coords = coords_input.clone() 213 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 214 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 215 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 216 | -------------------------------------------------------------------------------- /segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | # @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 
63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 
149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
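# --- Editor's sketch (not part of transformer.py): the shape contract of the two-way transformer
# --- described above. The import path and the random tensors are assumptions for illustration.
import torch
from segment_anything.modeling.transformer import TwoWayTransformer

t = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 64, 64)   # B x embedding_dim x h x w image features
image_pe = torch.randn(1, 256, 64, 64)          # positional encoding, same shape as the image embedding
point_embedding = torch.randn(1, 7, 256)        # B x N_points x embedding_dim query tokens
queries, keys = t(image_embedding, image_pe, point_embedding)
print(queries.shape, keys.shape)                # torch.Size([1, 7, 256]) torch.Size([1, 4096, 256])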
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /segment_anything/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from segment_anything.modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. 
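# --- Editor's sketch (not part of the original files): the Attention block defined at the end of
# --- transformer.py above projects q/k/v down to embedding_dim // downsample_rate before splitting
# --- heads, then projects back. Import path and tensor sizes are illustrative assumptions.
import torch
from segment_anything.modeling.transformer import Attention

attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)  # internal_dim = 128
q = k = v = torch.randn(2, 5, 256)   # B x N_tokens x C
out = attn(q, k, v)
print(out.shape)                     # torch.Size([2, 5, 256]) -- restored to embedding_dim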
Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | 57 | input_image = self.transform.apply_image(image) # shape: (1024, 1024, 3) 58 | input_image_torch = torch.as_tensor(input_image, device=self.device) 59 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 60 | 61 | self.set_torch_image(input_image_torch, image.shape[:2]) 62 | 63 | @torch.no_grad() 64 | def set_torch_image( 65 | self, 66 | transformed_image: torch.Tensor, 67 | original_image_size: Tuple[int, ...], 68 | ) -> None: 69 | """ 70 | Calculates the image embeddings for the provided image, allowing 71 | masks to be predicted with the 'predict' method. Expects the input 72 | image to be already transformed to the format expected by the model. 73 | 74 | Arguments: 75 | transformed_image (torch.Tensor): The input image, with shape 76 | 1x3xHxW, which has been transformed with ResizeLongestSide. 77 | original_image_size (tuple(int, int)): The size of the image 78 | before transformation, in (H, W) format. 79 | """ 80 | assert ( 81 | len(transformed_image.shape) == 4 82 | and transformed_image.shape[1] == 3 83 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 84 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 85 | self.reset_image() 86 | # print("transformed_image.shape :", transformed_image.shape) # : (1, 3, 1024, 1024) 87 | self.original_size = original_image_size 88 | self.input_size = tuple(transformed_image.shape[-2:]) 89 | input_image = self.model.preprocess(transformed_image) 90 | self.features = self.model.image_encoder(input_image) 91 | self.is_image_set = True 92 | 93 | def predict( 94 | self, 95 | point_coords: Optional[np.ndarray] = None, 96 | point_labels: Optional[np.ndarray] = None, 97 | box: Optional[np.ndarray] = None, 98 | mask_input: Optional[np.ndarray] = None, 99 | multimask_output: bool = True, 100 | return_logits: bool = False, 101 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 102 | """ 103 | Predict masks for the given input prompts, using the currently set image. 104 | 105 | Arguments: 106 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 107 | model. Each point is in (X,Y) in pixels. 108 | point_labels (np.ndarray or None): A length N array of labels for the 109 | point prompts. 1 indicates a foreground point and 0 indicates a 110 | background point. 111 | box (np.ndarray or None): A length 4 array given a box prompt to the 112 | model, in XYXY format. 113 | mask_input (np.ndarray): A low resolution mask input to the model, typically 114 | coming from a previous prediction iteration. Has form 1xHxW, where 115 | for SAM, H=W=256. 116 | multimask_output (bool): If true, the model will return three masks. 117 | For ambiguous input prompts (such as a single click), this will often 118 | produce better masks than a single prediction. If only a single 119 | mask is needed, the model's predicted quality score can be used 120 | to select the best mask. For non-ambiguous prompts, such as multiple 121 | input prompts, multimask_output=False can give better results. 
122 | return_logits (bool): If true, returns un-thresholded masks logits 123 | instead of a binary mask. 124 | 125 | Returns: 126 | (np.ndarray): The output masks in CxHxW format, where C is the 127 | number of masks, and (H, W) is the original image size. 128 | (np.ndarray): An array of length C containing the model's 129 | predictions for the quality of each mask. 130 | (np.ndarray): An array of shape CxHxW, where C is the number 131 | of masks and H=W=256. These low resolution logits can be passed to 132 | a subsequent iteration as mask input. 133 | """ 134 | if not self.is_image_set: 135 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 136 | 137 | # Transform input prompts 138 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 139 | if point_coords is not None: 140 | assert ( 141 | point_labels is not None 142 | ), "point_labels must be supplied if point_coords is supplied." 143 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 144 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 145 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 146 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 147 | if box is not None: 148 | box = self.transform.apply_boxes(box, self.original_size) 149 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 150 | box_torch = box_torch[None, :] 151 | if mask_input is not None: 152 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 153 | mask_input_torch = mask_input_torch[None, :, :, :] 154 | 155 | 156 | print("coords_torch.shape :", coords_torch.shape, coords_torch.max()) 157 | print("labels_torch.shape :", labels_torch.shape, labels_torch.max()) 158 | 159 | masks, iou_predictions, low_res_masks = self.predict_torch( 160 | coords_torch, 161 | labels_torch, 162 | box_torch, 163 | mask_input_torch, 164 | multimask_output, 165 | return_logits=return_logits, 166 | ) 167 | 168 | masks_np = masks[0].detach().cpu().numpy() 169 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy() 170 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy() 171 | return masks_np, iou_predictions_np, low_res_masks_np 172 | 173 | @torch.no_grad() 174 | def predict_torch( 175 | self, 176 | point_coords: Optional[torch.Tensor], 177 | point_labels: Optional[torch.Tensor], 178 | boxes: Optional[torch.Tensor] = None, 179 | mask_input: Optional[torch.Tensor] = None, 180 | multimask_output: bool = True, 181 | return_logits: bool = False, 182 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 183 | """ 184 | Predict masks for the given input prompts, using the currently set image. 185 | Input prompts are batched torch tensors and are expected to already be 186 | transformed to the input frame using ResizeLongestSide. 187 | 188 | Arguments: 189 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 190 | model. Each point is in (X,Y) in pixels. 191 | point_labels (torch.Tensor or None): A BxN array of labels for the 192 | point prompts. 1 indicates a foreground point and 0 indicates a 193 | background point. 194 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 195 | model, in XYXY format. 196 | mask_input (np.ndarray): A low resolution mask input to the model, typically 197 | coming from a previous prediction iteration. Has form Bx1xHxW, where 198 | for SAM, H=W=256. 
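# --- Editor's sketch (not part of predictor.py): typical use of the predict() API documented above.
# --- The checkpoint path follows this repository's sam_weights convention; the zero image is a
# --- placeholder standing in for a 304 x 304 OCTA projection map.
import numpy as np
from segment_anything import sam_model_registry
from segment_anything.predictor import SamPredictor

sam = sam_model_registry["vit_b"](checkpoint="sam_weights/sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam)
image = np.zeros((304, 304, 3), dtype=np.uint8)      # HWC uint8, pixel values in [0, 255]
predictor.set_image(image)
masks, scores, low_res = predictor.predict(
    point_coords=np.array([[152, 152]]),             # N x 2 points in (X, Y) pixels
    point_labels=np.array([1]),                      # 1 = foreground, 0 = background
    multimask_output=True,
)
print(masks.shape, scores.shape, low_res.shape)      # (3, 304, 304) (3,) (3, 256, 256)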
Masks returned by a previous iteration of the 199 | predict method do not need further transformation. 200 | multimask_output (bool): If true, the model will return three masks. 201 | For ambiguous input prompts (such as a single click), this will often 202 | produce better masks than a single prediction. If only a single 203 | mask is needed, the model's predicted quality score can be used 204 | to select the best mask. For non-ambiguous prompts, such as multiple 205 | input prompts, multimask_output=False can give better results. 206 | return_logits (bool): If true, returns un-thresholded masks logits 207 | instead of a binary mask. 208 | 209 | Returns: 210 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 211 | number of masks, and (H, W) is the original image size. 212 | (torch.Tensor): An array of shape BxC containing the model's 213 | predictions for the quality of each mask. 214 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 215 | of masks and H=W=256. These low res logits can be passed to 216 | a subsequent iteration as mask input. 217 | """ 218 | if not self.is_image_set: 219 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 220 | 221 | if point_coords is not None: 222 | points = (point_coords, point_labels) 223 | else: 224 | points = None 225 | 226 | # Embed prompts 227 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 228 | points=points, 229 | boxes=boxes, 230 | masks=mask_input, 231 | ) 232 | 233 | # Predict masks 234 | low_res_masks, iou_predictions = self.model.mask_decoder( 235 | image_embeddings=self.features, 236 | image_pe=self.model.prompt_encoder.get_dense_pe(), 237 | sparse_prompt_embeddings=sparse_embeddings, 238 | dense_prompt_embeddings=dense_embeddings, 239 | multimask_output=multimask_output, 240 | ) 241 | 242 | # Upscale the masks to the original image resolution 243 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 244 | 245 | if not return_logits: 246 | masks = masks > self.model.mask_threshold 247 | 248 | return masks, iou_predictions, low_res_masks 249 | 250 | def get_image_embedding(self) -> torch.Tensor: 251 | """ 252 | Returns the image embeddings for the currently set image, with 253 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 254 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 255 | """ 256 | if not self.is_image_set: 257 | raise RuntimeError( 258 | "An image must be set with .set_image(...) to generate an embedding." 259 | ) 260 | assert self.features is not None, "Features must exist if an image has been set." 261 | return self.features 262 | 263 | @property 264 | def device(self) -> torch.device: 265 | return self.model.device 266 | 267 | def reset_image(self) -> None: 268 | """Resets the currently set image.""" 269 | self.is_image_set = False 270 | self.features = None 271 | self.orig_h = None 272 | self.orig_w = None 273 | self.input_h = None 274 | self.input_w = None -------------------------------------------------------------------------------- /segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/amg.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/utils/__pycache__/amg.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/transforms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShellRedia/SAM-OCTA/9edc56247a3e30c96e5c6f1e8fe01250b8338527/segment_anything/utils/__pycache__/transforms.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 
33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 
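# --- Editor's sketch (not part of amg.py): MaskData, defined above, keeps per-mask fields aligned;
# --- a boolean keep-mask passed to filter() is applied to every stored tensor, array, and list at once.
# --- The field names below are illustrative, not fixed by the class.
import torch
from segment_anything.utils.amg import MaskData

data = MaskData(iou_preds=torch.tensor([0.9, 0.4, 0.8]), ids=[0, 1, 2])
data.filter(data["iou_preds"] > 0.5)
print(data["iou_preds"], data["ids"])   # tensor([0.9000, 0.8000]) [0, 2]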
111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 165 | # Save memory by preventing unnecessary cast to torch.int64 166 | intersections = ( 167 | (masks > (mask_threshold + threshold_offset)) 168 | .sum(-1, dtype=torch.int16) 169 | .sum(-1, dtype=torch.int32) 170 | ) 171 | unions = ( 172 | (masks > (mask_threshold - threshold_offset)) 173 | .sum(-1, dtype=torch.int16) 174 | .sum(-1, dtype=torch.int32) 175 | ) 176 | return intersections / unions 177 | 178 | 179 | def build_point_grid(n_per_side: int) -> np.ndarray: 180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 181 | offset = 1 / (2 * n_per_side) 182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 186 | return points 187 | 188 | 189 | def build_all_layer_point_grids( 190 | n_per_side: int, n_layers: int, scale_per_layer: int 191 | ) -> List[np.ndarray]: 192 | """Generates point grids for all crop layers.""" 193 | points_by_layer = [] 194 | for i in range(n_layers + 1): 195 | n_points = int(n_per_side / (scale_per_layer**i)) 196 | points_by_layer.append(build_point_grid(n_points)) 197 | return points_by_layer 198 | 199 | 200 | def generate_crop_boxes( 201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 202 | ) -> Tuple[List[List[int]], List[int]]: 203 | """ 204 | Generates a list of crop boxes of different sizes. Each layer 205 | has (2**i)**2 boxes for the ith layer. 
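# --- Editor's sketch (not part of amg.py): a round trip through the uncompressed RLE helpers above.
# --- The counts alternate background/foreground run lengths over a column-major (Fortran-order) flattening.
import torch
from segment_anything.utils.amg import mask_to_rle_pytorch, rle_to_mask, area_from_rle

mask = torch.zeros(1, 4, 4, dtype=torch.bool)
mask[0, 1:3, 1:3] = True
rle = mask_to_rle_pytorch(mask)[0]
print(rle["counts"], area_from_rle(rle))            # [5, 2, 2, 2, 5] 4
print((rle_to_mask(rle) == mask[0].numpy()).all())  # True -- decoding recovers the original mask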
206 | """ 207 | crop_boxes, layer_idxs = [], [] 208 | im_h, im_w = im_size 209 | short_side = min(im_h, im_w) 210 | 211 | # Original image 212 | crop_boxes.append([0, 0, im_w, im_h]) 213 | layer_idxs.append(0) 214 | 215 | def crop_len(orig_len, n_crops, overlap): 216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 217 | 218 | for i_layer in range(n_layers): 219 | n_crops_per_side = 2 ** (i_layer + 1) 220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 221 | 222 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 223 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 224 | 225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 227 | 228 | # Crops in XYWH format 229 | for x0, y0 in product(crop_box_x0, crop_box_y0): 230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 231 | crop_boxes.append(box) 232 | layer_idxs.append(i_layer + 1) 233 | 234 | return crop_boxes, layer_idxs 235 | 236 | 237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 238 | x0, y0, _, _ = crop_box 239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 240 | # Check if boxes has a channel dimension 241 | if len(boxes.shape) == 3: 242 | offset = offset.unsqueeze(1) 243 | return boxes + offset 244 | 245 | 246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 247 | x0, y0, _, _ = crop_box 248 | offset = torch.tensor([[x0, y0]], device=points.device) 249 | # Check if points has a channel dimension 250 | if len(points.shape) == 3: 251 | offset = offset.unsqueeze(1) 252 | return points + offset 253 | 254 | 255 | def uncrop_masks( 256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 257 | ) -> torch.Tensor: 258 | x0, y0, x1, y1 = crop_box 259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 260 | return masks 261 | # Coordinate transform masks 262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 263 | pad = (x0, pad_x - x0, y0, pad_y - y0) 264 | return torch.nn.functional.pad(masks, pad, value=0) 265 | 266 | 267 | def remove_small_regions( 268 | mask: np.ndarray, area_thresh: float, mode: str 269 | ) -> Tuple[np.ndarray, bool]: 270 | """ 271 | Removes small disconnected regions and holes in a mask. Returns the 272 | mask and an indicator of if the mask has been modified. 
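# --- Editor's sketch (not part of amg.py): generate_crop_boxes, defined above, returns the full image
# --- plus (2**i)**2 boxes per layer i. Note that, despite the in-code "XYWH" comment, each returned box
# --- is [x0, y0, x1, y1] clipped to the image. The 304 x 304 size and overlap ratio are placeholders.
from segment_anything.utils.amg import generate_crop_boxes

crop_boxes, layer_idxs = generate_crop_boxes((304, 304), n_layers=1, overlap_ratio=0.25)
print(len(crop_boxes), layer_idxs)    # 5 [0, 1, 1, 1, 1]
print(crop_boxes[0], crop_boxes[1])   # [0, 0, 304, 304] and the first layer-1 crop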
273 | """ 274 | import cv2 # type: ignore 275 | 276 | assert mode in ["holes", "islands"] 277 | correct_holes = mode == "holes" 278 | working_mask = (correct_holes ^ mask).astype(np.uint8) 279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 280 | sizes = stats[:, -1][1:] # Row 0 is background label 281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 282 | if len(small_regions) == 0: 283 | return mask, False 284 | fill_labels = [0] + small_regions 285 | if not correct_holes: 286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 287 | # If every region is below threshold, keep largest 288 | if len(fill_labels) == 0: 289 | fill_labels = [int(np.argmax(sizes)) + 1] 290 | mask = np.isin(regions, fill_labels) 291 | return mask, True 292 | 293 | 294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 295 | from pycocotools import mask as mask_utils # type: ignore 296 | 297 | h, w = uncompressed_rle["size"] 298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 300 | return rle 301 | 302 | 303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 307 | """ 308 | # torch.max below raises an error on empty inputs, just skip in this case 309 | if torch.numel(masks) == 0: 310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 311 | 312 | # Normalize shape to CxHxW 313 | shape = masks.shape 314 | h, w = shape[-2:] 315 | if len(shape) > 2: 316 | masks = masks.flatten(0, -3) 317 | else: 318 | masks = masks.unsqueeze(0) 319 | 320 | # Get top and bottom edges 321 | in_height, _ = torch.max(masks, dim=-1) 322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 324 | in_height_coords = in_height_coords + h * (~in_height) 325 | top_edges, _ = torch.min(in_height_coords, dim=-1) 326 | 327 | # Get left and right edges 328 | in_width, _ = torch.max(masks, dim=-2) 329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 330 | right_edges, _ = torch.max(in_width_coords, dim=-1) 331 | in_width_coords = in_width_coords + w * (~in_width) 332 | left_edges, _ = torch.min(in_width_coords, dim=-1) 333 | 334 | # If the mask is empty the right edge will be to the left of the left edge. 335 | # Replace these boxes with [0, 0, 0, 0] 336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 338 | out = out * (~empty_filter).unsqueeze(-1) 339 | 340 | # Return to original shape 341 | if len(shape) > 2: 342 | out = out.reshape(*shape[:-2], 4) 343 | else: 344 | out = out[0] 345 | 346 | return out 347 | -------------------------------------------------------------------------------- /segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
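# --- Editor's sketch (not part of the original files): batched_mask_to_box from amg.py above returns
# --- tight XYXY boxes and [0, 0, 0, 0] for empty masks. The mask sizes below are illustrative.
import torch
from segment_anything.utils.amg import batched_mask_to_box

masks = torch.zeros(2, 8, 8, dtype=torch.bool)
masks[0, 2:5, 3:7] = True                       # rows 2..4, cols 3..6; second mask stays empty
print(batched_mask_to_box(masks).tolist())      # [[3, 2, 6, 4], [0, 0, 0, 0]]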
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 
96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
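# --- Editor's sketch (not part of transforms.py): how the resize transform above scales coordinates.
# --- The 304 x 304 size matches the commented self-test at the end of this file; the point is a placeholder.
import numpy as np
from segment_anything.utils.transforms import ResizeLongestSide

t = ResizeLongestSide(1024)
print(t.get_preprocess_shape(304, 304, 1024))   # (1024, 1024)
pts = np.array([[152.0, 76.0]])                  # (x, y) in the original 304 x 304 image
print(t.apply_coords(pts, (304, 304)))           # [[512. 256.]] -- scaled by 1024 / 304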
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | 104 | # if __name__=="__main__": 105 | # test_image = np.ones((304, 304, 3), dtype=np.uint8) 106 | # original_size = tuple(test_image.shape[:2]) 107 | # sam_transform = ResizeLongestSide(224) 108 | # image = sam_transform.apply_image(test_image) 109 | # print(image.shape) 110 | # prompt_points = np.array([[-100, -100]]) 111 | # prompt_points = sam_transform.apply_coords(prompt_points, original_size) 112 | # print(prompt_points.shape) 113 | # print(prompt_points) -------------------------------------------------------------------------------- /test_sam_octa.py: -------------------------------------------------------------------------------- 1 | from sam_lora_image_encoder import LoRA_Sam 2 | from segment_anything import * 3 | import torch 4 | import torch.optim as optim 5 | from torch.utils.data import DataLoader 6 | from torch.nn import DataParallel 7 | from torch.utils.data import SubsetRandomSampler 8 | from dataset import octa500_2d_dataset 9 | from tqdm import tqdm 10 | import numpy as np 11 | from options import * 12 | import itertools 13 | from statistics import * 14 | from loss_functions import * 15 | import os 16 | import random 17 | import time 18 | from display import * 19 | from metrics import MetricsStatistics 20 | from collections import * 21 | from segment_anything.utils.transforms import ResizeLongestSide 22 | 23 | parser = argparse.ArgumentParser() 24 | add_training_parser(parser) 25 | add_octa500_2d_parser(parser) 26 | args = parser.parse_args() 27 | 28 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 29 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | 31 | num_gpus = torch.cuda.device_count() 32 | 33 | for i in range(num_gpus): 34 | gpu_name = torch.cuda.get_device_name(i) 35 | print(f"GPU {i}: {gpu_name}") 36 | 37 | time_str = "-".join(["{:0>2}".format(x) for x in time.localtime(time.time())][:-3]) 38 | print(time_str) 39 | 40 | test_weight_path = '3M_LargeVessel_Global.pth' 41 | 42 | to_cuda = lambda x: x.to(torch.float).to(device) 43 | 44 | ppn, pnn = args.prompt_positive_num, args.prompt_negative_num 45 | dataset_params = [args.fov, args.label_type, ppn, pnn, args.is_local, False] 46 | dataset_test = octa500_2d_dataset(*dataset_params) 47 | 48 | parameters = [args.fov, args.label_type, args.epochs, args.is_local, args.model_type, args.remark] 49 | 50 | save_dir = "test/{}/{}".format(time_str, "_".join(map(str, parameters))) 51 | 52 | sample_n = len(dataset_test) 53 | 54 | if args.model_type == "vit_h": 55 | sam = sam_model_registry["vit_h"](checkpoint="sam_weights/sam_vit_h_4b8939.pth") 56 | elif args.model_type == "vit_l": 57 | sam = sam_model_registry["vit_l"](checkpoint="sam_weights/sam_vit_l_0b3195.pth") 58 | else: 59 | sam = sam_model_registry["vit_b"](checkpoint="sam_weights/sam_vit_b_01ec64.pth") 60 | 61 | sam_transform = ResizeLongestSide(224) if args.model_type == "vit_b" else ResizeLongestSide(1024) 62 | 63 | model = LoRA_Sam(sam, 4).cuda() 64 | model = torch.nn.DataParallel(model) 65 | model.load_state_dict(torch.load(test_weight_path)) 66 | model = torch.nn.DataParallel(model).to(device) 67 | model.eval() 68 | 69 | val_loader = DataLoader(dataset_test, batch_size=1) 70 | 71 | metrics_statistics = MetricsStatistics(save_dir=save_dir) 72 | 73 | def make_prompts(images, prompt_points): 74 | original_size = tuple(images.shape[-2:]) 75 | images 
= sam_transform.apply_image_torch(images) 76 | prompt_points = sam_transform.apply_coords_torch(prompt_points, original_size) 77 | 78 | return images, original_size, prompt_points 79 | 80 | with torch.no_grad(): 81 | for images, prompt_points, prompt_type, selected_components, sample_ids in tqdm(val_loader): 82 | images, labels, prompt_type = map(to_cuda, (images, selected_components, prompt_type)) 83 | images, original_size, prompt_points = make_prompts(images, prompt_points) 84 | preds = model(images, original_size, prompt_points, prompt_type) 85 | 86 | preds = torch.gt(preds, 0.8).int() 87 | sample_id = str(sample_ids[0]) 88 | 89 | image, label, pred = map(lambda x:x[0][0].cpu().detach(), (images, labels, preds)) 90 | prompt_points, prompt_type = prompt_points[0].cpu().detach(), prompt_type[0].cpu().detach() 91 | prompt_info = np.concatenate((prompt_points, prompt_type[:,np.newaxis]), axis=1).astype(int) 92 | metrics_statistics.cal_epoch_metric(args.metrics, args.label_type, label.int(), pred.int()) 93 | 94 | if not os.path.exists(save_dir): os.makedirs(save_dir) 95 | save_sample_func = lambda x, y: np.save("/".join([save_dir,\ 96 | "{}_{}_{}.npy".format(args.label_type, x, sample_id)]), y) 97 | save_items = {"sample":image / 255, "label":label, "prompt_info":prompt_info, "pred":pred} 98 | for x, y in save_items.items(): save_sample_func(x, y) 99 | 100 | 101 | metrics_statistics.record_result(-1) 102 | metrics_statistics.close() -------------------------------------------------------------------------------- /train_sam_octa.py: -------------------------------------------------------------------------------- 1 | from sam_lora_image_encoder import LoRA_Sam 2 | from segment_anything import * 3 | import torch 4 | import torch.optim as optim 5 | from torch.utils.data import DataLoader 6 | from torch.nn import DataParallel 7 | from torch.utils.data import SubsetRandomSampler 8 | from dataset import octa500_2d_dataset 9 | from tqdm import tqdm 10 | import numpy as np 11 | from options import * 12 | import itertools 13 | from statistics import * 14 | from loss_functions import * 15 | import os 16 | import random 17 | import time 18 | from display import * 19 | from metrics import MetricsStatistics 20 | from collections import * 21 | from segment_anything.utils.transforms import ResizeLongestSide 22 | 23 | parser = argparse.ArgumentParser(description='training arguments') 24 | add_training_parser(parser) 25 | add_octa500_2d_parser(parser) 26 | args = parser.parse_args() 27 | 28 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device 29 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | 31 | num_gpus = torch.cuda.device_count() 32 | 33 | for i in range(num_gpus): 34 | gpu_name = torch.cuda.get_device_name(i) 35 | print(f"GPU {i}: {gpu_name}") 36 | 37 | time_str = "-".join(["{:0>2}".format(x) for x in time.localtime(time.time())][:-3]) 38 | print(time_str) 39 | 40 | to_cuda = lambda x: x.to(torch.float).to(device) 41 | 42 | class TrainManager_OCTA: 43 | def __init__(self, dataset_train, dataset_val): 44 | self.dataset_train, self.dataset_val = dataset_train, dataset_val 45 | parameters = [args.fov, args.label_type, args.epochs, args.is_local, args.model_type, args.remark] 46 | self.record_dir = "results/{}/{}".format(time_str, "_".join(map(str, parameters))) 47 | self.cpt_dir = "{}/checkpoints".format(self.record_dir) 48 | 49 | if not os.path.exists(self.cpt_dir): os.makedirs(self.cpt_dir) 50 | 51 | sample_n = len(self.dataset_train) 52 | self.indices = list(range(sample_n)) 
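# --- Editor's note (not part of train_sam_octa.py): the shuffled-index split built here uses integer
# --- division, so when sample_n is not divisible by k_fold the remainder samples land in every training
# --- fold and never appear in validation; e.g. sample_n = 200 with k_fold = 3 gives split = 66 and
# --- leaves indices[198:200] out of all validation folds.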
53 | random.shuffle(self.indices) 54 | self.split = sample_n // args.k_fold 55 | 56 | if args.model_type == "vit_h": 57 | sam = sam_model_registry["vit_h"](checkpoint="sam_weights/sam_vit_h_4b8939.pth") 58 | elif args.model_type == "vit_l": 59 | sam = sam_model_registry["vit_l"](checkpoint="sam_weights/sam_vit_l_0b3195.pth") 60 | else: 61 | sam = sam_model_registry["vit_b"](checkpoint="sam_weights/sam_vit_b_01ec64.pth") 62 | 63 | self.sam_transform = ResizeLongestSide(224) if args.model_type == "vit_b" else ResizeLongestSide(1024) 64 | 65 | lora_sam = LoRA_Sam(sam, 4).cuda() 66 | self.model = DataParallel(lora_sam).to(device) 67 | torch.save(self.model.state_dict(), '{}/init.pth'.format(self.cpt_dir)) 68 | 69 | self.loss_func = DiceLoss() 70 | if args.label_type in ["Artery", "Vein", "LargeVessel"]: 71 | self.loss_func = lambda x, y: 0.8 * DiceLoss()(x, y) + 0.2 * clDiceLoss()(x, y) 72 | 73 | def get_dataloader(self, fold_i): 74 | train_indices = self.indices[:fold_i * self.split] + self.indices[(fold_i + 1) * self.split:] 75 | val_indices = self.indices[fold_i * self.split:(fold_i + 1) * self.split] 76 | train_sampler, val_sampler = [SubsetRandomSampler(x) for x in (train_indices, val_indices)] 77 | batch_size = len(args.device.split(",")) 78 | train_loader = DataLoader(self.dataset_train, batch_size=batch_size, sampler=train_sampler) 79 | val_loader = DataLoader(self.dataset_val, batch_size=1, sampler=val_sampler) 80 | 81 | return train_loader, val_loader 82 | 83 | def reset(self): 84 | self.model.load_state_dict(torch.load('{}/init.pth'.format(self.cpt_dir))) 85 | pg = [p for p in self.model.parameters() if p.requires_grad] # lora parameters 86 | self.optimizer = optim.AdamW(pg, lr=1, weight_decay=1e-4) 87 | epoch_p = args.epochs // 5 88 | lr_lambda = lambda x: max(1e-5, args.lr * x / epoch_p if x <= epoch_p else args.lr * 0.98 ** (x - epoch_p)) 89 | self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda) 90 | 91 | def record_performance(self, train_loader, val_loader, fold_i, epoch, metrics_statistics): 92 | save_dir = "{}/{}/{:0>4}".format(self.record_dir, fold_i, epoch) 93 | torch.save(self.model.state_dict(), '{}/fold-{}_{:0>4}.pth'.format(self.cpt_dir, fold_i, epoch)) 94 | 95 | metrics_statistics.metric_values["learning rate"].append(self.optimizer.param_groups[0]['lr']) 96 | 97 | def record_dataloader(dataloader, loader_type="val", is_complete=True): 98 | for images, prompt_points, prompt_type, selected_components, sample_ids in dataloader: 99 | images, labels, prompt_type = map(to_cuda, (images, selected_components, prompt_type)) 100 | images, original_size, prompt_points = self.make_prompts(images, prompt_points) 101 | preds = self.model(images, original_size, prompt_points, prompt_type) 102 | metrics_statistics.metric_values["loss_"+loader_type].append(self.loss_func(preds, labels).cpu().item()) 103 | 104 | if is_complete: 105 | preds = torch.gt(preds, 0.8).int() 106 | sample_id = str(sample_ids[0]) 107 | 108 | image, label, pred = map(lambda x:x[0][0].cpu().detach(), (images, labels, preds)) 109 | prompt_points, prompt_type = prompt_points[0].cpu().detach(), prompt_type[0].cpu().detach() 110 | prompt_info = np.concatenate((prompt_points, prompt_type[:,np.newaxis]), axis=1).astype(int) 111 | metrics_statistics.cal_epoch_metric( 112 | args.metrics, "{}-{}".format(args.label_type,loader_type), label.int(), pred.int()) 113 | 114 | if not os.path.exists(save_dir): os.makedirs(save_dir) 115 | save_sample_func = lambda x, y: np.save("/".join([save_dir,\ 116 | 
"{}_{}_{}.npy".format(args.label_type, x, sample_id)]), y) 117 | save_items = {"sample":image / 255, "label":label, "prompt_info":prompt_info, "pred":pred} 118 | for x, y in save_items.items(): save_sample_func(x, y) 119 | 120 | record_dataloader(train_loader, "train", False) 121 | record_dataloader(val_loader, "val", True) 122 | 123 | metrics_statistics.record_result(epoch) 124 | 125 | def train(self): 126 | for fold_i in range(args.k_fold): 127 | train_loader, val_loader = self.get_dataloader(fold_i) 128 | self.reset() 129 | metrics_statistics = MetricsStatistics(save_dir="{}/{}".format(self.record_dir, fold_i)) 130 | self.record_performance(train_loader, val_loader, fold_i, 0, metrics_statistics) 131 | for epoch in tqdm(range(1, args.epochs+1), desc="training"): 132 | for images, prompt_points, prompt_type, selected_components, sample_ids in train_loader: 133 | images, labels, prompt_type = map(to_cuda, (images, selected_components, prompt_type)) 134 | images, original_size, prompt_points = self.make_prompts(images, prompt_points) 135 | self.optimizer.zero_grad() 136 | preds = self.model(images, original_size, prompt_points, prompt_type) 137 | self.loss_func(preds, labels).backward() 138 | self.optimizer.step() 139 | self.scheduler.step() 140 | if epoch % args.check_interval == 0: 141 | self.record_performance(train_loader, val_loader, fold_i, epoch, metrics_statistics) 142 | metrics_statistics.close() 143 | 144 | def make_prompts(self, images, prompt_points): 145 | original_size = tuple(images.shape[-2:]) 146 | images = self.sam_transform.apply_image_torch(images) 147 | prompt_points = self.sam_transform.apply_coords_torch(prompt_points, original_size) 148 | 149 | return images, original_size, prompt_points 150 | 151 | if __name__=="__main__": 152 | ppn, pnn = args.prompt_positive_num, args.prompt_negative_num 153 | dataset_params = [args.fov, args.label_type, ppn, pnn, args.is_local, True] 154 | dataset_train = octa500_2d_dataset(*dataset_params) 155 | dataset_params[-1] = False 156 | dataset_val = octa500_2d_dataset(*dataset_params) 157 | train_manager = TrainManager_OCTA(dataset_train, dataset_val) 158 | train_manager.train() --------------------------------------------------------------------------------