├── .gitignore ├── LICENSE.txt ├── README.md ├── app.py ├── davis2017 ├── __init__.py ├── davis.py ├── evaluation.py ├── metrics.py ├── results.py └── utils.py ├── eval_miou.py ├── eval_video.py ├── examples ├── cat_00.jpg ├── cat_00.png ├── cat_01.jpg ├── cat_01.png ├── cat_02.jpg ├── cat_02.png ├── colorful_sneaker_00.jpg ├── colorful_sneaker_00.png ├── colorful_sneaker_01.jpg ├── colorful_sneaker_01.png ├── colorful_sneaker_02.jpg ├── colorful_sneaker_02.png ├── duck_toy_00.jpg ├── duck_toy_00.png ├── duck_toy_01.jpg ├── duck_toy_01.png ├── duck_toy_02.jpg └── duck_toy_02.png ├── figs ├── fig_db.png └── fig_persam.png ├── per_segment_anything ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ ├── tiny_vit_sam.py │ └── transformer.py ├── predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── onnx.py │ └── transforms.py ├── persam.py ├── persam_f.py ├── persam_f_multi_obj.py ├── persam_video.py ├── persam_video_f.py ├── prepare_coco.py ├── requirements.txt ├── show.py └── weights └── mobile_sam.pt /.gitignore: -------------------------------------------------------------------------------- 1 | lation and distribution 2 | __pycache__ 3 | _ext 4 | *.pyc 5 | *.pyd 6 | *.so 7 | *.dll 8 | *.egg-info/ 9 | build/ 10 | dist/ 11 | wheels/ 12 | 13 | # pytorch/python/numpy formats 14 | *.pth 15 | *.pkl 16 | *.npy 17 | *.ts 18 | model_ts*.txt 19 | 20 | # onnx models 21 | *.onnx 22 | 23 | # ipython/jupyter notebooks 24 | **/.ipynb_checkpoints/ 25 | 26 | # Editor temporaries 27 | *.swn 28 | *.swo 29 | *.swp 30 | *~ 31 | 32 | # editor settings 33 | .idea 34 | .vscode 35 | _darcs 36 | 37 | # output 38 | data 39 | work_dirs 40 | 41 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Renrui Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Personalize Segment Anything with 1 Shot in 10 Seconds 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/personalize-segment-anything-model-with-one/personalized-segmentation-on-perseg)](https://paperswithcode.com/sota/personalized-segmentation-on-perseg?p=personalize-segment-anything-model-with-one) 4 | 5 | Official implementation of ['Personalize Segment Anything Model with One Shot'](https://arxiv.org/pdf/2305.03048.pdf). 6 | 7 | 💥 Try out the [web demo](https://huggingface.co/spaces/justin-zk/Personalize-SAM) 🤗 of PerSAM and PerSAM-F: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/justin-zk/Personalize-SAM) 8 | 9 | 10 | 🎉 Try out the [tutorial notebooks](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/PerSAM) in colab for your own dataset. Great thanks to [@NielsRogge](https://github.com/NielsRogge)! 11 | 12 | 🎆 Try out the online web demo of PerSAM in OpenXLab : 13 | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/RenRuiZhang/Personalize-SAM) 14 | 15 | 16 | ## News 17 | * Support [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) 🔥 with significant efficiency improvement. Thanks for their wonderful work! 18 | * **TODO**: Release the PerSAM-assisted [Dreambooth](https://arxiv.org/pdf/2208.12242.pdf) for better fine-tuning [Stable Diffusion](https://github.com/CompVis/stable-diffusion) 📌. 19 | * We release the code of PerSAM and PerSAM-F 🔥. Check our [video](https://www.youtube.com/watch?v=QlunvXpYQXM) here! 20 | * We release a new dataset for personalized segmentation, [PerSeg](https://drive.google.com/file/d/18TbrwhZtAPY5dlaoEqkPa5h08G9Rjcio/view?usp=sharing) 🔥. 21 | 22 | ## Introduction 23 | *How to customize SAM to automatically segment your pet dog in a photo album?* 24 | 25 | In this project, we propose a training-free **Per**sonalization approach for [Segment Anything Model (SAM)](https://ai.facebook.com/research/publications/segment-anything/), termed as **PerSAM**. Given only a single image with a reference mask, PerSAM can segment specific visual concepts, e.g., your pet dog, within other images or videos without any training. 26 | For better performance, we further present an efficient one-shot fine-tuning variant, **PerSAM-F**. We freeze the entire SAM and introduce two learnable mask weights, which only trains **2 parameters** within **10 seconds**. 27 | 28 |
29 | ![PerSAM](figs/fig_persam.png)
30 |
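To make the idea concrete, below is a minimal PyTorch sketch of the training-free location prior described above: SAM features of the reference image are averaged inside the given mask to form a target embedding, and the most similar position in a new image's feature map is taken as a positive point prompt. All tensor names and shapes here are illustrative assumptions, not the exact code in `persam.py`.

```python
import torch
import torch.nn.functional as F


def target_embedding(ref_feat: torch.Tensor, ref_mask: torch.Tensor) -> torch.Tensor:
    """Average the reference features inside the mask into one target vector.

    ref_feat: (C, H, W) encoder features of the reference image.
    ref_mask: (H, W) binary mask of the target concept.
    """
    mask = ref_mask.float()
    feat = (ref_feat * mask).sum(dim=(1, 2)) / mask.sum().clamp(min=1.0)
    return F.normalize(feat, dim=0)  # (C,)


def location_prior(test_feat: torch.Tensor, target: torch.Tensor) -> tuple:
    """Return the (row, col) with the highest cosine similarity to the target."""
    sim = torch.einsum("chw,c->hw", F.normalize(test_feat, dim=0), target)
    idx = int(sim.flatten().argmax())
    return idx // sim.shape[1], idx % sim.shape[1]


# Toy run with random tensors standing in for SAM image-encoder outputs.
ref_feat, test_feat = torch.randn(256, 64, 64), torch.randn(256, 64, 64)
ref_mask = torch.zeros(64, 64)
ref_mask[20:40, 20:40] = 1
point = location_prior(test_feat, target_embedding(ref_feat, ref_mask))
print("positive point prompt (feature-grid coords):", point)
```

In the released scripts, the same similarity map also supplies a negative background point and additional guidance to SAM's decoder; see `persam.py` and `persam_f.py` for the complete pipeline.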
31 | 32 | Besides, our approach can be used to assist [DreamBooth](https://arxiv.org/pdf/2208.12242.pdf) in better fine-tuning [Stable Diffusion](https://github.com/CompVis/stable-diffusion) for personalized image synthesis. We adopt PerSAM to segment the target object in the user-provided few-shot images, which eliminates the **background disturbance** and benefits target representation learning. 33 | 34 |
35 | ![PerSAM-assisted DreamBooth](figs/fig_db.png)
36 |
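As a rough illustration of how the segmentation helps here, the sketch below restricts a DreamBooth-style reconstruction loss to the PerSAM foreground mask, so that background pixels of the few-shot images do not influence the personalized weights. Function and variable names are assumptions for illustration only; the PerSAM-assisted DreamBooth code itself has not been released yet (see the TODO in the News section).

```python
import torch
import torch.nn.functional as F


def foreground_only_loss(pred_noise: torch.Tensor,
                         true_noise: torch.Tensor,
                         fg_mask: torch.Tensor) -> torch.Tensor:
    """Mean-squared error computed only inside the segmented foreground.

    pred_noise, true_noise: (B, C, H, W) diffusion noise prediction and target.
    fg_mask: (B, 1, H, W) binary mask from PerSAM, resized to the same resolution.
    """
    per_pixel = F.mse_loss(pred_noise, true_noise, reduction="none")
    weights = fg_mask.expand_as(per_pixel)
    return (per_pixel * weights).sum() / weights.sum().clamp(min=1.0)


# Toy check: only the masked region contributes to the loss.
pred, target = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
mask = torch.zeros(1, 1, 8, 8)
mask[..., 2:6, 2:6] = 1
print(foreground_only_loss(pred, target, mask))
```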
37 | 38 | ## Requirements 39 | ### Installation 40 | Clone the repo and create a conda environment: 41 | ```bash 42 | git clone https://github.com/ZrrSkywalker/Personalize-SAM.git 43 | cd Personalize-SAM 44 | 45 | conda create -n persam python=3.8 46 | conda activate persam 47 | 48 | pip install -r requirements.txt 49 | ``` 50 | 51 | Similar to Segment Anything, our code requires `pytorch>=1.7` and `torchvision>=0.8`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. 52 | 53 | 54 | 55 | ### Preparation 56 | Please download our constructed dataset **PerSeg** for personalized segmentation from [Google Drive](https://drive.google.com/file/d/18TbrwhZtAPY5dlaoEqkPa5h08G9Rjcio/view?usp=sharing) or [Baidu Yun](https://pan.baidu.com/s/1X-czD-FYW0ELlk2x90eTLg) (code `222k`), and the pre-trained weights of SAM from [here](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth). Then, unzip the dataset file and organize them as 57 | ``` 58 | data/ 59 | |–– Annotations/ 60 | |–– Images/ 61 | sam_vit_h_4b8939.pth 62 | ``` 63 | Please download 480p [TrainVal](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip) split of DAVIS 2017. Then decompress the file to `DAVIS/2017` and organize them as 64 | ``` 65 | DAVIS/ 66 | |––2017/ 67 | |–– Annotations/ 68 | |–– ImageSets/ 69 | |–– JPEGImages/ 70 | ``` 71 | 72 | ## Getting Started 73 | 74 | ### Personalized Segmentation 75 | 76 | For the training-free 🧊 **PerSAM**, just run: 77 | ```bash 78 | python persam.py --outdir 79 | ``` 80 | 81 | For 10-second fine-tuning of 🚀 **PerSAM-F**, just run: 82 | ```bash 83 | python persam_f.py --outdir 84 | ``` 85 | 86 | For [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) with higher efficiency, just add `--sam_type vit_t`: 87 | ```bash 88 | python persam.py/persam_f.py --outdir --sam_type vit_t 89 | ``` 90 | 91 | 92 | For **Multi-Object** segmentation of the same category by PerSAM-F (Great thanks to [@mlzoo](https://github.com/mlzoo)), just run: 93 | ```bash 94 | python persam_f_multi_obj.py --sam_type --outdir 95 | ``` 96 | 97 | After running, the output masks and visualizations will be stored at `outputs/`. 98 | 99 | ### Evaluation 100 | Then, for mIoU evaluation, please run: 101 | ```bash 102 | python eval_miou.py --pred_path 103 | ``` 104 | 105 | ### Personalized Segmentation On Video 106 | 107 | For the training-free and evaluation of 🧊 **PerSAM** on video, just run: 108 | ```bash 109 | python persam_video.py --output_path 110 | ``` 111 | 112 | For 10-second fine-tuning and evaluation of 🚀 **PerSAM-F** on video, just run: 113 | ```bash 114 | python persam_video_f.py --output_path 115 | ``` 116 | 117 | ### Personalized Stable Diffusion 118 | Our approach can enhance DreamBooth to better personalize Stable Diffusion for text-to-image generation. 119 | 120 | Coming soon. 121 | 122 | ## Citation 123 | ```bash 124 | @article{zhang2023personalize, 125 | title={Personalize Segment Anything Model with One Shot}, 126 | author={Zhang, Renrui and Jiang, Zhengkai and Guo, Ziyu and Yan, Shilin and Pan, Junting and Dong, Hao and Gao, Peng and Li, Hongsheng}, 127 | journal={arXiv preprint arXiv:2305.03048}, 128 | year={2023} 129 | } 130 | ``` 131 | 132 | ## Acknowledgement 133 | This repo benefits from [Segment Anything](https://github.com/facebookresearch/segment-anything) and [DreamBooth](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion). Thanks for their wonderful works. 
134 | 135 | ## Contact 136 | If you have any question about this project, please feel free to contact zhangrenrui@pjlab.org.cn. 137 | -------------------------------------------------------------------------------- /davis2017/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | __version__ = '0.1.0' 4 | -------------------------------------------------------------------------------- /davis2017/davis.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from collections import defaultdict 4 | import numpy as np 5 | from PIL import Image 6 | from torch.utils.data.dataset import Dataset 7 | from torch.utils.data import DataLoader 8 | from os import path 9 | 10 | def all_to_onehot(masks, labels): 11 | if len(masks.shape) == 3: 12 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8) 13 | else: 14 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8) 15 | 16 | for k, l in enumerate(labels): 17 | Ms[k] = (masks == l).astype(np.uint8) 18 | 19 | return Ms 20 | 21 | class DAVISTestDataset(Dataset): 22 | def __init__(self, root, imset='2017/val.txt', resolution=480, single_object=False, target_name=None): 23 | self.root = root 24 | if resolution == 480: 25 | res_tag = '480p' 26 | else: 27 | res_tag = 'Full-Resolution' 28 | self.mask_dir = path.join(root, 'Annotations', res_tag) 29 | self.mask480_dir = path.join(root, 'Annotations', '480p') 30 | self.image_dir = path.join(root, 'JPEGImages', res_tag) 31 | self.resolution = resolution 32 | _imset_dir = path.join(root, 'ImageSets') 33 | _imset_f = path.join(_imset_dir, imset) 34 | 35 | self.videos = [] 36 | self.num_frames = {} 37 | self.num_objects = {} 38 | self.shape = {} 39 | self.size_480p = {} 40 | with open(path.join(_imset_f), "r") as lines: 41 | for line in lines: 42 | _video = line.rstrip('\n') 43 | if target_name is not None and target_name != _video: 44 | continue 45 | self.videos.append(_video) 46 | self.num_frames[_video] = len(os.listdir(path.join(self.image_dir, _video))) 47 | _mask = np.array(Image.open(path.join(self.mask_dir, _video, '00000.png')).convert("P")) 48 | self.num_objects[_video] = np.max(_mask) 49 | self.shape[_video] = np.shape(_mask) 50 | _mask480 = np.array(Image.open(path.join(self.mask480_dir, _video, '00000.png')).convert("P")) 51 | self.size_480p[_video] = np.shape(_mask480) 52 | self.single_object = single_object 53 | 54 | 55 | def __len__(self): 56 | return len(self.videos) 57 | 58 | def __getitem__(self, index): 59 | video = self.videos[index] 60 | info = {} 61 | info['name'] = video 62 | info['frames'] = [] 63 | info['num_frames'] = self.num_frames[video] 64 | info['size_480p'] = self.size_480p[video] 65 | 66 | images = [] 67 | masks = [] 68 | for f in range(self.num_frames[video]): 69 | img_file = path.join(self.image_dir, video, '{:05d}.jpg'.format(f)) 70 | img = Image.open(img_file).convert('RGB') 71 | img = np.array(img, dtype = 'uint8') 72 | images.append(img) 73 | info['frames'].append('{:05d}.jpg'.format(f)) 74 | 75 | mask_file = path.join(self.mask_dir, video, '{:05d}.png'.format(f)) 76 | if path.exists(mask_file): 77 | m = np.array(Image.open(mask_file).convert('P'), dtype=np.uint8) #(480, 910) 78 | masks.append(m) #(480, 910), numpy 79 | else: 80 | masks.append(np.zeros_like(masks[0])) 81 | 82 | images = np.stack(images, 0) 83 | masks = np.stack(masks, 0) 84 | 85 | if 
self.single_object: 86 | labels = [1] 87 | masks = (masks > 0.5).astype(np.uint8) 88 | masks = all_to_onehot(masks, labels) 89 | else: 90 | labels = np.unique(masks[0]) 91 | labels = labels[labels!=0] 92 | masks = all_to_onehot(masks, labels) 93 | 94 | info['labels'] = labels 95 | 96 | data = { 97 | 'rgb': images, 98 | 'gt': masks, 99 | 'info': info, 100 | } 101 | 102 | return data 103 | 104 | 105 | class DAVIS(object): 106 | SUBSET_OPTIONS = ['train', 'val', 'test-dev', 'test-challenge'] 107 | TASKS = ['semi-supervised', 'unsupervised'] 108 | DATASET_WEB = 'https://davischallenge.org/davis2017/code.html' 109 | VOID_LABEL = 255 110 | 111 | def __init__(self, root, task='unsupervised', subset='val', sequences='all', resolution='480p', codalab=False): 112 | """ 113 | Class to read the DAVIS dataset 114 | :param root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders. 115 | :param task: Task to load the annotations, choose between semi-supervised or unsupervised. 116 | :param subset: Set to load the annotations 117 | :param sequences: Sequences to consider, 'all' to use all the sequences in a set. 118 | :param resolution: Specify the resolution to use the dataset, choose between '480' and 'Full-Resolution' 119 | """ 120 | if subset not in self.SUBSET_OPTIONS: 121 | raise ValueError(f'Subset should be in {self.SUBSET_OPTIONS}') 122 | if task not in self.TASKS: 123 | raise ValueError(f'The only tasks that are supported are {self.TASKS}') 124 | 125 | self.task = task 126 | self.subset = subset 127 | self.root = root 128 | self.img_path = os.path.join(self.root, 'JPEGImages', resolution) 129 | annotations_folder = 'Annotations' if task == 'semi-supervised' else 'Annotations_unsupervised' 130 | self.mask_path = os.path.join(self.root, annotations_folder, resolution) 131 | year = '2019' if task == 'unsupervised' and (subset == 'test-dev' or subset == 'test-challenge') else '2017' 132 | self.imagesets_path = os.path.join(self.root, 'ImageSets', year) 133 | 134 | self._check_directories() 135 | 136 | if sequences == 'all': 137 | with open(os.path.join(self.imagesets_path, f'{self.subset}.txt'), 'r') as f: 138 | tmp = f.readlines() 139 | sequences_names = [x.strip() for x in tmp] 140 | else: 141 | sequences_names = sequences if isinstance(sequences, list) else [sequences] 142 | self.sequences = defaultdict(dict) 143 | 144 | for seq in sequences_names: 145 | images = np.sort(glob(os.path.join(self.img_path, seq, '*.jpg'))).tolist() 146 | if len(images) == 0 and not codalab: 147 | raise FileNotFoundError(f'Images for sequence {seq} not found.') 148 | self.sequences[seq]['images'] = images 149 | masks = np.sort(glob(os.path.join(self.mask_path, seq, '*.png'))).tolist() 150 | masks.extend([-1] * (len(images) - len(masks))) 151 | self.sequences[seq]['masks'] = masks 152 | 153 | def _check_directories(self): 154 | if not os.path.exists(self.root): 155 | raise FileNotFoundError(f'DAVIS not found in the specified directory, download it from {self.DATASET_WEB}') 156 | if not os.path.exists(os.path.join(self.imagesets_path, f'{self.subset}.txt')): 157 | raise FileNotFoundError(f'Subset sequences list for {self.subset} not found, download the missing subset ' 158 | f'for the {self.task} task from {self.DATASET_WEB}') 159 | if self.subset in ['train', 'val'] and not os.path.exists(self.mask_path): 160 | raise FileNotFoundError(f'Annotations folder for the {self.task} task not found, download it from {self.DATASET_WEB}') 161 | 162 | def get_frames(self, sequence): 163 | for img, msk in 
zip(self.sequences[sequence]['images'], self.sequences[sequence]['masks']): 164 | image = np.array(Image.open(img)) 165 | mask = None if msk is None else np.array(Image.open(msk)) 166 | yield image, mask 167 | 168 | def _get_all_elements(self, sequence, obj_type): 169 | obj = np.array(Image.open(self.sequences[sequence][obj_type][0])) 170 | all_objs = np.zeros((len(self.sequences[sequence][obj_type]), *obj.shape)) 171 | obj_id = [] 172 | for i, obj in enumerate(self.sequences[sequence][obj_type]): 173 | all_objs[i, ...] = np.array(Image.open(obj)) 174 | obj_id.append(''.join(obj.split('/')[-1].split('.')[:-1])) 175 | return all_objs, obj_id 176 | 177 | def get_all_images(self, sequence): 178 | return self._get_all_elements(sequence, 'images') 179 | 180 | def get_all_masks(self, sequence, separate_objects_masks=False): 181 | masks, masks_id = self._get_all_elements(sequence, 'masks') 182 | masks_void = np.zeros_like(masks) 183 | 184 | # Separate void and object masks 185 | for i in range(masks.shape[0]): 186 | masks_void[i, ...] = masks[i, ...] == 255 187 | masks[i, masks[i, ...] == 255] = 0 188 | 189 | if separate_objects_masks: 190 | num_objects = int(np.max(masks[0, ...])) 191 | tmp = np.ones((num_objects, *masks.shape)) 192 | tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None] 193 | masks = (tmp == masks[None, ...]) 194 | masks = masks > 0 195 | return masks, masks_void, masks_id 196 | 197 | def get_sequences(self): 198 | for seq in self.sequences: 199 | yield seq 200 | 201 | 202 | 203 | 204 | -------------------------------------------------------------------------------- /davis2017/evaluation.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | from tqdm import tqdm 3 | import warnings 4 | warnings.filterwarnings("ignore", category=RuntimeWarning) 5 | 6 | import numpy as np 7 | from davis2017.davis import DAVIS 8 | from davis2017.metrics import db_eval_boundary, db_eval_iou 9 | from davis2017 import utils 10 | from davis2017.results import Results 11 | from scipy.optimize import linear_sum_assignment 12 | 13 | 14 | class DAVISEvaluation(object): 15 | def __init__(self, davis_root, task, gt_set, sequences='all', codalab=False): 16 | """ 17 | Class to evaluate DAVIS sequences from a certain set and for a certain task 18 | :param davis_root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders. 19 | :param task: Task to compute the evaluation, chose between semi-supervised or unsupervised. 20 | :param gt_set: Set to compute the evaluation 21 | :param sequences: Sequences to consider for the evaluation, 'all' to use all the sequences in a set. 
22 | """ 23 | self.davis_root = davis_root 24 | self.task = task 25 | self.dataset = DAVIS(root=davis_root, task=task, subset=gt_set, sequences=sequences, codalab=codalab) 26 | 27 | @staticmethod 28 | def _evaluate_semisupervised(all_gt_masks, all_res_masks, all_void_masks, metric): 29 | if all_res_masks.shape[0] > all_gt_masks.shape[0]: 30 | sys.stdout.write("\nIn your PNG files there is an index higher than the number of objects in the sequence!") 31 | sys.exit() 32 | elif all_res_masks.shape[0] < all_gt_masks.shape[0]: 33 | zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:])) 34 | all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0) 35 | j_metrics_res, f_metrics_res = np.zeros(all_gt_masks.shape[:2]), np.zeros(all_gt_masks.shape[:2]) 36 | for ii in range(all_gt_masks.shape[0]): 37 | if 'J' in metric: 38 | j_metrics_res[ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks) 39 | if 'F' in metric: 40 | f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks) 41 | return j_metrics_res, f_metrics_res 42 | 43 | @staticmethod 44 | def _evaluate_unsupervised(all_gt_masks, all_res_masks, all_void_masks, metric, max_n_proposals=20): 45 | if all_res_masks.shape[0] > max_n_proposals: 46 | sys.stdout.write(f"\nIn your PNG files there is an index higher than the maximum number ({max_n_proposals}) of proposals allowed!") 47 | sys.exit() 48 | elif all_res_masks.shape[0] < all_gt_masks.shape[0]: 49 | zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:])) 50 | all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0) 51 | j_metrics_res = np.zeros((all_res_masks.shape[0], all_gt_masks.shape[0], all_gt_masks.shape[1])) 52 | f_metrics_res = np.zeros((all_res_masks.shape[0], all_gt_masks.shape[0], all_gt_masks.shape[1])) 53 | for ii in range(all_gt_masks.shape[0]): 54 | for jj in range(all_res_masks.shape[0]): 55 | if 'J' in metric: 56 | j_metrics_res[jj, ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[jj, ...], all_void_masks) 57 | if 'F' in metric: 58 | f_metrics_res[jj, ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[jj, ...], all_void_masks) 59 | if 'J' in metric and 'F' in metric: 60 | all_metrics = (np.mean(j_metrics_res, axis=2) + np.mean(f_metrics_res, axis=2)) / 2 61 | else: 62 | all_metrics = np.mean(j_metrics_res, axis=2) if 'J' in metric else np.mean(f_metrics_res, axis=2) 63 | row_ind, col_ind = linear_sum_assignment(-all_metrics) 64 | return j_metrics_res[row_ind, col_ind, :], f_metrics_res[row_ind, col_ind, :] 65 | 66 | def evaluate(self, res_path, metric=('J', 'F'), debug=False): 67 | metric = metric if isinstance(metric, tuple) or isinstance(metric, list) else [metric] 68 | if 'T' in metric: 69 | raise ValueError('Temporal metric not supported!') 70 | if 'J' not in metric and 'F' not in metric: 71 | raise ValueError('Metric possible values are J for IoU or F for Boundary') 72 | 73 | # Containers 74 | metrics_res = {} 75 | if 'J' in metric: 76 | metrics_res['J'] = {"M": [], "R": [], "D": [], "M_per_object": {}} 77 | if 'F' in metric: 78 | metrics_res['F'] = {"M": [], "R": [], "D": [], "M_per_object": {}} 79 | 80 | # Sweep all sequences 81 | results = Results(root_dir=res_path) 82 | L = os.listdir(res_path) 83 | for seq in tqdm(L): 84 | print("Calculating Class", seq) 85 | # if seq == "car-roundabout": 86 | # break 87 | all_gt_masks, all_void_masks, all_masks_id = 
self.dataset.get_all_masks(seq, True) 88 | if self.task == 'semi-supervised': 89 | all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1] 90 | all_res_masks = results.read_masks(seq, all_masks_id) 91 | if self.task == 'unsupervised': 92 | j_metrics_res, f_metrics_res = self._evaluate_unsupervised(all_gt_masks, all_res_masks, all_void_masks, metric) 93 | elif self.task == 'semi-supervised': 94 | j_metrics_res, f_metrics_res = self._evaluate_semisupervised(all_gt_masks, all_res_masks, None, metric) 95 | for ii in range(all_gt_masks.shape[0]): 96 | seq_name = f'{seq}_{ii+1}' 97 | if 'J' in metric: 98 | [JM, JR, JD] = utils.db_statistics(j_metrics_res[ii]) 99 | metrics_res['J']["M"].append(JM) 100 | metrics_res['J']["R"].append(JR) 101 | metrics_res['J']["D"].append(JD) 102 | metrics_res['J']["M_per_object"][seq_name] = JM 103 | if 'F' in metric: 104 | [FM, FR, FD] = utils.db_statistics(f_metrics_res[ii]) 105 | metrics_res['F']["M"].append(FM) 106 | metrics_res['F']["R"].append(FR) 107 | metrics_res['F']["D"].append(FD) 108 | metrics_res['F']["M_per_object"][seq_name] = FM 109 | # break 110 | 111 | # Show progress 112 | if debug: 113 | sys.stdout.write(seq + '\n') 114 | sys.stdout.flush() 115 | return metrics_res 116 | -------------------------------------------------------------------------------- /davis2017/metrics.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | def db_eval_iou(annotation, segmentation, void_pixels=None): 7 | """ Compute region similarity as the Jaccard Index. 8 | Arguments: 9 | annotation (ndarray): binary annotation map. 10 | segmentation (ndarray): binary segmentation map. 11 | void_pixels (ndarray): optional mask with void pixels 12 | 13 | Return: 14 | jaccard (float): region similarity 15 | """ 16 | assert annotation.shape == segmentation.shape, \ 17 | f'Annotation({annotation.shape}) and segmentation:{segmentation.shape} dimensions do not match.' 18 | annotation = annotation.astype(np.bool) 19 | segmentation = segmentation.astype(np.bool) 20 | 21 | if void_pixels is not None: 22 | assert annotation.shape == void_pixels.shape, \ 23 | f'Annotation({annotation.shape}) and void pixels:{void_pixels.shape} dimensions do not match.' 
24 | void_pixels = void_pixels.astype(np.bool) 25 | else: 26 | void_pixels = np.zeros_like(segmentation) 27 | 28 | # Intersection between all sets 29 | inters = np.sum((segmentation & annotation) & np.logical_not(void_pixels), axis=(-2, -1)) 30 | union = np.sum((segmentation | annotation) & np.logical_not(void_pixels), axis=(-2, -1)) 31 | 32 | j = inters / union 33 | if j.ndim == 0: 34 | j = 1 if np.isclose(union, 0) else j 35 | else: 36 | j[np.isclose(union, 0)] = 1 37 | return j 38 | 39 | 40 | def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_th=0.008): 41 | assert annotation.shape == segmentation.shape 42 | if void_pixels is not None: 43 | assert annotation.shape == void_pixels.shape 44 | if annotation.ndim == 3: 45 | n_frames = annotation.shape[0] 46 | f_res = np.zeros(n_frames) 47 | for frame_id in range(n_frames): 48 | void_pixels_frame = None if void_pixels is None else void_pixels[frame_id, :, :, ] 49 | f_res[frame_id] = f_measure(segmentation[frame_id, :, :, ], annotation[frame_id, :, :], void_pixels_frame, bound_th=bound_th) 50 | elif annotation.ndim == 2: 51 | f_res = f_measure(segmentation, annotation, void_pixels, bound_th=bound_th) 52 | else: 53 | raise ValueError(f'db_eval_boundary does not support tensors with {annotation.ndim} dimensions') 54 | return f_res 55 | 56 | 57 | def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008): 58 | """ 59 | Compute mean,recall and decay from per-frame evaluation. 60 | Calculates precision/recall for boundaries between foreground_mask and 61 | gt_mask using morphological operators to speed it up. 62 | 63 | Arguments: 64 | foreground_mask (ndarray): binary segmentation image. 65 | gt_mask (ndarray): binary annotated image. 66 | void_pixels (ndarray): optional mask with void pixels 67 | 68 | Returns: 69 | F (float): boundaries F-measure 70 | """ 71 | assert np.atleast_3d(foreground_mask).shape[2] == 1 72 | if void_pixels is not None: 73 | void_pixels = void_pixels.astype(np.bool) 74 | else: 75 | void_pixels = np.zeros_like(foreground_mask).astype(np.bool) 76 | 77 | bound_pix = bound_th if bound_th >= 1 else \ 78 | np.ceil(bound_th * np.linalg.norm(foreground_mask.shape)) 79 | 80 | # Get the pixel boundaries of both masks 81 | fg_boundary = _seg2bmap(foreground_mask * np.logical_not(void_pixels)) 82 | gt_boundary = _seg2bmap(gt_mask * np.logical_not(void_pixels)) 83 | 84 | from skimage.morphology import disk 85 | 86 | # fg_dil = binary_dilation(fg_boundary, disk(bound_pix)) 87 | fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) 88 | # gt_dil = binary_dilation(gt_boundary, disk(bound_pix)) 89 | gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) 90 | 91 | # Get the intersection 92 | gt_match = gt_boundary * fg_dil 93 | fg_match = fg_boundary * gt_dil 94 | 95 | # Area of the intersection 96 | n_fg = np.sum(fg_boundary) 97 | n_gt = np.sum(gt_boundary) 98 | 99 | # % Compute precision and recall 100 | if n_fg == 0 and n_gt > 0: 101 | precision = 1 102 | recall = 0 103 | elif n_fg > 0 and n_gt == 0: 104 | precision = 0 105 | recall = 1 106 | elif n_fg == 0 and n_gt == 0: 107 | precision = 1 108 | recall = 1 109 | else: 110 | precision = np.sum(fg_match) / float(n_fg) 111 | recall = np.sum(gt_match) / float(n_gt) 112 | 113 | # Compute F measure 114 | if precision + recall == 0: 115 | F = 0 116 | else: 117 | F = 2 * precision * recall / (precision + recall) 118 | 119 | return F 120 | 121 | 122 | def _seg2bmap(seg, width=None, height=None): 123 | 
""" 124 | From a segmentation, compute a binary boundary map with 1 pixel wide 125 | boundaries. The boundary pixels are offset by 1/2 pixel towards the 126 | origin from the actual segment boundary. 127 | Arguments: 128 | seg : Segments labeled from 1..k. 129 | width : Width of desired bmap <= seg.shape[1] 130 | height : Height of desired bmap <= seg.shape[0] 131 | Returns: 132 | bmap (ndarray): Binary boundary map. 133 | David Martin 134 | January 2003 135 | """ 136 | 137 | seg = seg.astype(np.bool) 138 | seg[seg > 0] = 1 139 | 140 | assert np.atleast_3d(seg).shape[2] == 1 141 | 142 | width = seg.shape[1] if width is None else width 143 | height = seg.shape[0] if height is None else height 144 | 145 | h, w = seg.shape[:2] 146 | 147 | ar1 = float(width) / float(height) 148 | ar2 = float(w) / float(h) 149 | 150 | assert not ( 151 | width > w | height > h | abs(ar1 - ar2) > 0.01 152 | ), "Can" "t convert %dx%d seg to %dx%d bmap." % (w, h, width, height) 153 | 154 | e = np.zeros_like(seg) 155 | s = np.zeros_like(seg) 156 | se = np.zeros_like(seg) 157 | 158 | e[:, :-1] = seg[:, 1:] 159 | s[:-1, :] = seg[1:, :] 160 | se[:-1, :-1] = seg[1:, 1:] 161 | 162 | b = seg ^ e | seg ^ s | seg ^ se 163 | b[-1, :] = seg[-1, :] ^ e[-1, :] 164 | b[:, -1] = seg[:, -1] ^ s[:, -1] 165 | b[-1, -1] = 0 166 | 167 | if w == width and h == height: 168 | bmap = b 169 | else: 170 | bmap = np.zeros((height, width)) 171 | for x in range(w): 172 | for y in range(h): 173 | if b[y, x]: 174 | j = 1 + math.floor((y - 1) + height / h) 175 | i = 1 + math.floor((x - 1) + width / h) 176 | bmap[j, i] = 1 177 | 178 | return bmap 179 | 180 | 181 | if __name__ == '__main__': 182 | from davis2017.davis import DAVIS 183 | from davis2017.results import Results 184 | 185 | dataset = DAVIS(root='input_dir/ref', subset='val', sequences='aerobatics') 186 | results = Results(root_dir='examples/osvos') 187 | # Test timing F measure 188 | for seq in dataset.get_sequences(): 189 | all_gt_masks, _, all_masks_id = dataset.get_all_masks(seq, True) 190 | all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1] 191 | all_res_masks = results.read_masks(seq, all_masks_id) 192 | f_metrics_res = np.zeros(all_gt_masks.shape[:2]) 193 | for ii in range(all_gt_masks.shape[0]): 194 | f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...]) 195 | 196 | # Run using to profile code: python -m cProfile -o f_measure.prof metrics.py 197 | # snakeviz f_measure.prof 198 | -------------------------------------------------------------------------------- /davis2017/results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import sys 5 | 6 | 7 | class Results(object): 8 | def __init__(self, root_dir): 9 | self.root_dir = root_dir 10 | 11 | def _read_mask(self, sequence, frame_id): 12 | try: 13 | mask_path = os.path.join(self.root_dir, sequence, f'{frame_id}.png') 14 | return np.array(Image.open(mask_path)) 15 | except IOError as err: 16 | sys.stdout.write(sequence + " frame %s not found!\n" % frame_id) 17 | sys.stdout.write("The frames have to be indexed PNG files placed inside the corespondent sequence " 18 | "folder.\nThe indexes have to match with the initial frame.\n") 19 | sys.stderr.write("IOError: " + err.strerror + "\n") 20 | sys.exit() 21 | 22 | def read_masks(self, sequence, masks_id): 23 | mask_0 = self._read_mask(sequence, masks_id[0]) 24 | masks = np.zeros((len(masks_id), *mask_0.shape)) 25 | for ii, m in 
enumerate(masks_id): 26 | masks[ii, ...] = self._read_mask(sequence, m) 27 | num_objects = int(np.max(masks)) 28 | tmp = np.ones((num_objects, *masks.shape)) 29 | tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None] 30 | masks = (tmp == masks[None, ...]) > 0 31 | return masks 32 | -------------------------------------------------------------------------------- /davis2017/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import numpy as np 4 | from PIL import Image 5 | import warnings 6 | from davis2017.davis import DAVIS 7 | 8 | 9 | def _pascal_color_map(N=256, normalized=False): 10 | """ 11 | Python implementation of the color map function for the PASCAL VOC data set. 12 | Official Matlab version can be found in the PASCAL VOC devkit 13 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit 14 | """ 15 | 16 | def bitget(byteval, idx): 17 | return (byteval & (1 << idx)) != 0 18 | 19 | dtype = 'float32' if normalized else 'uint8' 20 | cmap = np.zeros((N, 3), dtype=dtype) 21 | for i in range(N): 22 | r = g = b = 0 23 | c = i 24 | for j in range(8): 25 | r = r | (bitget(c, 0) << 7 - j) 26 | g = g | (bitget(c, 1) << 7 - j) 27 | b = b | (bitget(c, 2) << 7 - j) 28 | c = c >> 3 29 | 30 | cmap[i] = np.array([r, g, b]) 31 | 32 | cmap = cmap / 255 if normalized else cmap 33 | return cmap 34 | 35 | 36 | def overlay_semantic_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None): 37 | im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=np.int) 38 | if im.shape[:-1] != ann.shape: 39 | raise ValueError('First two dimensions of `im` and `ann` must match') 40 | if im.shape[-1] != 3: 41 | raise ValueError('im must have three channels at the 3 dimension') 42 | 43 | colors = colors or _pascal_color_map() 44 | colors = np.asarray(colors, dtype=np.uint8) 45 | 46 | mask = colors[ann] 47 | fg = im * alpha + (1 - alpha) * mask 48 | 49 | img = im.copy() 50 | img[ann > 0] = fg[ann > 0] 51 | 52 | if contour_thickness: # pragma: no cover 53 | import cv2 54 | for obj_id in np.unique(ann[ann > 0]): 55 | contours = cv2.findContours((ann == obj_id).astype( 56 | np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] 57 | cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(), 58 | contour_thickness) 59 | return img 60 | 61 | 62 | def generate_obj_proposals(davis_root, subset, num_proposals, save_path): 63 | dataset = DAVIS(davis_root, subset=subset, codalab=True) 64 | for seq in dataset.get_sequences(): 65 | save_dir = os.path.join(save_path, seq) 66 | if os.path.exists(save_dir): 67 | continue 68 | all_gt_masks, all_masks_id = dataset.get_all_masks(seq, True) 69 | img_size = all_gt_masks.shape[2:] 70 | num_rows = int(np.ceil(np.sqrt(num_proposals))) 71 | proposals = np.zeros((num_proposals, len(all_masks_id), *img_size)) 72 | height_slices = np.floor(np.arange(0, img_size[0] + 1, img_size[0]/num_rows)).astype(np.uint).tolist() 73 | width_slices = np.floor(np.arange(0, img_size[1] + 1, img_size[1]/num_rows)).astype(np.uint).tolist() 74 | ii = 0 75 | prev_h, prev_w = 0, 0 76 | for h in height_slices[1:]: 77 | for w in width_slices[1:]: 78 | proposals[ii, :, prev_h:h, prev_w:w] = 1 79 | prev_w = w 80 | ii += 1 81 | if ii == num_proposals: 82 | break 83 | prev_h, prev_w = h, 0 84 | if ii == num_proposals: 85 | break 86 | 87 | os.makedirs(save_dir, exist_ok=True) 88 | for i, mask_id in enumerate(all_masks_id): 89 | mask = np.sum(proposals[:, i, ...] 
* np.arange(1, proposals.shape[0] + 1)[:, None, None], axis=0) 90 | save_mask(mask, os.path.join(save_dir, f'{mask_id}.png')) 91 | 92 | 93 | def generate_random_permutation_gt_obj_proposals(davis_root, subset, save_path): 94 | dataset = DAVIS(davis_root, subset=subset, codalab=True) 95 | for seq in dataset.get_sequences(): 96 | gt_masks, all_masks_id = dataset.get_all_masks(seq, True) 97 | obj_swap = np.random.permutation(np.arange(gt_masks.shape[0])) 98 | gt_masks = gt_masks[obj_swap, ...] 99 | save_dir = os.path.join(save_path, seq) 100 | os.makedirs(save_dir, exist_ok=True) 101 | for i, mask_id in enumerate(all_masks_id): 102 | mask = np.sum(gt_masks[:, i, ...] * np.arange(1, gt_masks.shape[0] + 1)[:, None, None], axis=0) 103 | save_mask(mask, os.path.join(save_dir, f'{mask_id}.png')) 104 | 105 | 106 | def color_map(N=256, normalized=False): 107 | def bitget(byteval, idx): 108 | return ((byteval & (1 << idx)) != 0) 109 | 110 | dtype = 'float32' if normalized else 'uint8' 111 | cmap = np.zeros((N, 3), dtype=dtype) 112 | for i in range(N): 113 | r = g = b = 0 114 | c = i 115 | for j in range(8): 116 | r = r | (bitget(c, 0) << 7-j) 117 | g = g | (bitget(c, 1) << 7-j) 118 | b = b | (bitget(c, 2) << 7-j) 119 | c = c >> 3 120 | 121 | cmap[i] = np.array([r, g, b]) 122 | 123 | cmap = cmap/255 if normalized else cmap 124 | return cmap 125 | 126 | 127 | def save_mask(mask, img_path): 128 | if np.max(mask) > 255: 129 | raise ValueError('Maximum id pixel value is 255') 130 | mask_img = Image.fromarray(mask.astype(np.uint8)) 131 | mask_img.putpalette(color_map().flatten().tolist()) 132 | mask_img.save(img_path) 133 | 134 | 135 | def db_statistics(per_frame_values): 136 | """ Compute mean,recall and decay from per-frame evaluation. 137 | Arguments: 138 | per_frame_values (ndarray): per-frame evaluation 139 | 140 | Returns: 141 | M,O,D (float,float,float): 142 | return evaluation statistics: mean,recall,decay. 
143 | """ 144 | 145 | # strip off nan values 146 | with warnings.catch_warnings(): 147 | warnings.simplefilter("ignore", category=RuntimeWarning) 148 | M = np.nanmean(per_frame_values) 149 | O = np.nanmean(per_frame_values > 0.5) 150 | 151 | N_bins = 4 152 | ids = np.round(np.linspace(1, len(per_frame_values), N_bins + 1) + 1e-10) - 1 153 | ids = ids.astype(np.uint8) 154 | 155 | D_bins = [per_frame_values[ids[i]:ids[i + 1] + 1] for i in range(0, 4)] 156 | 157 | with warnings.catch_warnings(): 158 | warnings.simplefilter("ignore", category=RuntimeWarning) 159 | D = np.nanmean(D_bins[0]) - np.nanmean(D_bins[3]) 160 | 161 | return M, O, D 162 | 163 | 164 | def list_files(dir, extension=".png"): 165 | return [os.path.splitext(file_)[0] for file_ in os.listdir(dir) if file_.endswith(extension)] 166 | 167 | 168 | def force_symlink(file1, file2): 169 | try: 170 | os.symlink(file1, file2) 171 | except OSError as e: 172 | if e.errno == errno.EEXIST: 173 | os.remove(file2) 174 | os.symlink(file1, file2) 175 | -------------------------------------------------------------------------------- /eval_miou.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from pathlib import Path 5 | import argparse 6 | 7 | 8 | 9 | def get_arguments(): 10 | 11 | parser = argparse.ArgumentParser() 12 | 13 | parser.add_argument('--pred_path', type=str, default='persam') 14 | parser.add_argument('--gt_path', type=str, default='./data/Annotations') 15 | 16 | parser.add_argument('--ref_idx', type=str, default='00') 17 | 18 | args = parser.parse_args() 19 | return args 20 | 21 | 22 | def main(): 23 | 24 | args = get_arguments() 25 | print("Args:", args, "\n"), 26 | 27 | class_names = sorted(os.listdir(args.gt_path)) 28 | class_names = [class_name for class_name in class_names if ".DS" not in class_name] 29 | class_names.sort() 30 | 31 | mIoU, mAcc = 0, 0 32 | count = 0 33 | for class_name in class_names: 34 | count += 1 35 | gt_path_class = os.path.join(args.gt_path, class_name) 36 | pred_path_class = os.path.join("./outputs/" + args.pred_path, class_name) 37 | 38 | gt_images = [str(img_path) for img_path in sorted(Path(gt_path_class).rglob("*.png"))] 39 | pred_images = [str(img_path) for img_path in sorted(Path(pred_path_class).rglob("*.png"))] 40 | 41 | intersection_meter = AverageMeter() 42 | union_meter = AverageMeter() 43 | target_meter = AverageMeter() 44 | 45 | for i, (gt_img, pred_img) in enumerate(zip(gt_images, pred_images)): 46 | if args.ref_idx in gt_img: 47 | continue 48 | 49 | gt_img = cv2.imread(gt_img) 50 | gt_img = cv2.cvtColor(gt_img, cv2.COLOR_BGR2GRAY) > 0 51 | gt_img = np.uint8(gt_img) 52 | 53 | pred_img = cv2.imread(pred_img) 54 | pred_img = cv2.cvtColor(pred_img, cv2.COLOR_BGR2GRAY) > 0 55 | pred_img = np.uint8(pred_img) 56 | 57 | intersection, union, target = intersectionAndUnion(pred_img, gt_img) 58 | intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target) 59 | 60 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) 61 | accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10) 62 | 63 | print(class_name + ',', "IoU: %.2f," %(100 * iou_class), "Acc: %.2f\n" %(100 * accuracy_class)) 64 | 65 | mIoU += iou_class 66 | mAcc += accuracy_class 67 | 68 | print("\nmIoU: %.2f" %(100 * mIoU / count)) 69 | print("mAcc: %.2f\n" %(100 * mAcc / count)) 70 | 71 | 72 | class AverageMeter(object): 73 | """Computes and stores the average and current value""" 74 | 75 | def 
__init__(self): 76 | self.reset() 77 | 78 | def reset(self): 79 | self.val = 0 80 | self.avg = 0 81 | self.sum = 0 82 | self.count = 0 83 | 84 | def update(self, val, n=1): 85 | self.val = val 86 | self.sum += val * n 87 | self.count += n 88 | self.avg = self.sum / self.count 89 | 90 | 91 | def intersectionAndUnion(output, target): 92 | assert (output.ndim in [1, 2, 3]) 93 | assert output.shape == target.shape 94 | output = output.reshape(output.size).copy() 95 | target = target.reshape(target.size) 96 | 97 | area_intersection = np.logical_and(output, target).sum() 98 | area_union = np.logical_or(output, target).sum() 99 | area_target = target.sum() 100 | 101 | return area_intersection, area_union, area_target 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /eval_video.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | def eval_davis_result(results_path, davis_path): 3 | import os 4 | import sys 5 | from time import time 6 | import argparse 7 | 8 | import numpy as np 9 | import pandas as pd 10 | from davis2017.evaluation import DAVISEvaluation 11 | 12 | time_start = time() 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--davis_path', type=str, help='Path to the DAVIS folder containing the JPEGImages, Annotations, ' 15 | 'ImageSets, Annotations_unsupervised folders') 16 | parser.add_argument('--set', type=str, help='Subset to evaluate the results', default='val') 17 | parser.add_argument('--task', type=str, help='Task to evaluate the results', default='semi-supervised', 18 | choices=['semi-supervised', 'unsupervised']) 19 | parser.add_argument('--results_path', type=str, help='Path to the folder containing the sequences folders', default='') 20 | args, _ = parser.parse_known_args() 21 | csv_name_global = f'global_results-{args.set}.csv' 22 | csv_name_per_sequence = f'per-sequence_results-{args.set}.csv' 23 | 24 | # Check if the method has been evaluated before, if so read the results, otherwise compute the results 25 | args.results_path = results_path 26 | args.davis_path = davis_path 27 | csv_path = args.results_path.replace("result", "result_csv") 28 | os.makedirs(csv_path, exist_ok=True) 29 | csv_name_global_path = os.path.join(csv_path, csv_name_global) 30 | csv_name_per_sequence_path = os.path.join(csv_path, csv_name_per_sequence) 31 | 32 | print(f'Evaluating sequences for the {args.task} task...') 33 | # Create dataset and evaluate 34 | dataset_eval = DAVISEvaluation(davis_root=args.davis_path, task=args.task, gt_set=args.set) 35 | metrics_res = dataset_eval.evaluate(args.results_path) 36 | J, F = metrics_res['J'], metrics_res['F'] 37 | 38 | # Generate dataframe for the general results 39 | g_measures = ['J&F-Mean', 'J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 'F-Decay'] 40 | final_mean = (np.mean(J["M"]) + np.mean(F["M"])) / 2. 
41 | g_res = np.array([final_mean, np.mean(J["M"]), np.mean(J["R"]), np.mean(J["D"]), np.mean(F["M"]), np.mean(F["R"]), 42 | np.mean(F["D"])]) 43 | g_res = np.reshape(g_res, [1, len(g_res)]) 44 | table_g = pd.DataFrame(data=g_res, columns=g_measures) 45 | with open(csv_name_global_path, 'w') as f: 46 | table_g.to_csv(f, index=False, float_format="%.3f") 47 | print(f'Global results saved in {csv_name_global_path}') 48 | 49 | # Generate a dataframe for the per sequence results 50 | seq_names = list(J['M_per_object'].keys()) 51 | seq_measures = ['Sequence', 'J-Mean', 'F-Mean'] 52 | J_per_object = [J['M_per_object'][x] for x in seq_names] 53 | F_per_object = [F['M_per_object'][x] for x in seq_names] 54 | table_seq = pd.DataFrame(data=list(zip(seq_names, J_per_object, F_per_object)), columns=seq_measures) 55 | with open(csv_name_per_sequence_path, 'w') as f: 56 | table_seq.to_csv(f, index=False, float_format="%.3f") 57 | print(f'Per-sequence results saved in {csv_name_per_sequence_path}') 58 | 59 | # Print the results 60 | sys.stdout.write(f"--------------------------- Global results for {args.set} ---------------------------\n") 61 | print(table_g.to_string(index=False)) 62 | sys.stdout.write(f"\n---------- Per sequence results for {args.set} ----------\n") 63 | print(table_seq.to_string(index=False)) 64 | total_time = time() - time_start 65 | sys.stdout.write('\nTotal time:' + str(total_time)) 66 | 67 | -------------------------------------------------------------------------------- /examples/cat_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/cat_00.jpg -------------------------------------------------------------------------------- /examples/cat_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/cat_00.png -------------------------------------------------------------------------------- /examples/cat_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/cat_01.jpg -------------------------------------------------------------------------------- /examples/cat_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/cat_01.png -------------------------------------------------------------------------------- /examples/cat_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/cat_02.jpg -------------------------------------------------------------------------------- /examples/cat_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/cat_02.png -------------------------------------------------------------------------------- /examples/colorful_sneaker_00.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/colorful_sneaker_00.jpg -------------------------------------------------------------------------------- /examples/colorful_sneaker_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/colorful_sneaker_00.png -------------------------------------------------------------------------------- /examples/colorful_sneaker_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/colorful_sneaker_01.jpg -------------------------------------------------------------------------------- /examples/colorful_sneaker_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/colorful_sneaker_01.png -------------------------------------------------------------------------------- /examples/colorful_sneaker_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/colorful_sneaker_02.jpg -------------------------------------------------------------------------------- /examples/colorful_sneaker_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/colorful_sneaker_02.png -------------------------------------------------------------------------------- /examples/duck_toy_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/duck_toy_00.jpg -------------------------------------------------------------------------------- /examples/duck_toy_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/duck_toy_00.png -------------------------------------------------------------------------------- /examples/duck_toy_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/duck_toy_01.jpg -------------------------------------------------------------------------------- /examples/duck_toy_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/duck_toy_01.png -------------------------------------------------------------------------------- /examples/duck_toy_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/duck_toy_02.jpg -------------------------------------------------------------------------------- /examples/duck_toy_02.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/examples/duck_toy_02.png -------------------------------------------------------------------------------- /figs/fig_db.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/figs/fig_db.png -------------------------------------------------------------------------------- /figs/fig_persam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/figs/fig_persam.png -------------------------------------------------------------------------------- /per_segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /per_segment_anything/automatic_mask_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore 10 | 11 | from typing import Any, Dict, List, Optional, Tuple 12 | 13 | from .modeling import Sam 14 | from .predictor import SamPredictor 15 | from .utils.amg import ( 16 | MaskData, 17 | area_from_rle, 18 | batch_iterator, 19 | batched_mask_to_box, 20 | box_xyxy_to_xywh, 21 | build_all_layer_point_grids, 22 | calculate_stability_score, 23 | coco_encode_rle, 24 | generate_crop_boxes, 25 | is_box_near_crop_edge, 26 | mask_to_rle_pytorch, 27 | remove_small_regions, 28 | rle_to_mask, 29 | uncrop_boxes_xyxy, 30 | uncrop_masks, 31 | uncrop_points, 32 | ) 33 | 34 | 35 | class SamAutomaticMaskGenerator: 36 | def __init__( 37 | self, 38 | model: Sam, 39 | points_per_side: Optional[int] = 32, 40 | points_per_batch: int = 64, 41 | pred_iou_thresh: float = 0.88, 42 | stability_score_thresh: float = 0.95, 43 | stability_score_offset: float = 1.0, 44 | box_nms_thresh: float = 0.7, 45 | crop_n_layers: int = 0, 46 | crop_nms_thresh: float = 0.7, 47 | crop_overlap_ratio: float = 512 / 1500, 48 | crop_n_points_downscale_factor: int = 1, 49 | point_grids: Optional[List[np.ndarray]] = None, 50 | min_mask_region_area: int = 0, 51 | output_mode: str = "binary_mask", 52 | ) -> None: 53 | """ 54 | Using a SAM model, generates masks for the entire image. 55 | Generates a grid of point prompts over the image, then filters 56 | low quality and duplicate masks. The default settings are chosen 57 | for SAM with a ViT-H backbone. 58 | 59 | Arguments: 60 | model (Sam): The SAM model to use for mask prediction. 
61 | points_per_side (int or None): The number of points to be sampled 62 | along one side of the image. The total number of points is 63 | points_per_side**2. If None, 'point_grids' must provide explicit 64 | point sampling. 65 | points_per_batch (int): Sets the number of points run simultaneously 66 | by the model. Higher numbers may be faster but use more GPU memory. 67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the 68 | model's predicted mask quality. 69 | stability_score_thresh (float): A filtering threshold in [0,1], using 70 | the stability of the mask under changes to the cutoff used to binarize 71 | the model's mask predictions. 72 | stability_score_offset (float): The amount to shift the cutoff when 73 | calculated the stability score. 74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal 75 | suppression to filter duplicate masks. 76 | crop_n_layers (int): If >0, mask prediction will be run again on 77 | crops of the image. Sets the number of layers to run, where each 78 | layer has 2**i_layer number of image crops. 79 | crop_nms_thresh (float): The box IoU cutoff used by non-maximal 80 | suppression to filter duplicate masks between different crops. 81 | crop_overlap_ratio (float): Sets the degree to which crops overlap. 82 | In the first crop layer, crops will overlap by this fraction of 83 | the image length. Later layers with more crops scale down this overlap. 84 | crop_n_points_downscale_factor (int): The number of points-per-side 85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 86 | point_grids (list(np.ndarray) or None): A list over explicit grids 87 | of points used for sampling, normalized to [0,1]. The nth grid in the 88 | list is used in the nth crop layer. Exclusive with points_per_side. 89 | min_mask_region_area (int): If >0, postprocessing will be applied 90 | to remove disconnected regions and holes in masks with area smaller 91 | than min_mask_region_area. Requires opencv. 92 | output_mode (str): The form masks are returned in. Can be 'binary_mask', 93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 94 | For large resolutions, 'binary_mask' may consume large amounts of 95 | memory. 96 | """ 97 | 98 | assert (points_per_side is None) != ( 99 | point_grids is None 100 | ), "Exactly one of points_per_side or point_grid must be provided." 101 | if points_per_side is not None: 102 | self.point_grids = build_all_layer_point_grids( 103 | points_per_side, 104 | crop_n_layers, 105 | crop_n_points_downscale_factor, 106 | ) 107 | elif point_grids is not None: 108 | self.point_grids = point_grids 109 | else: 110 | raise ValueError("Can't have both points_per_side and point_grid be None.") 111 | 112 | assert output_mode in [ 113 | "binary_mask", 114 | "uncompressed_rle", 115 | "coco_rle", 116 | ], f"Unknown output_mode {output_mode}." 
117 | if output_mode == "coco_rle": 118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401 119 | 120 | if min_mask_region_area > 0: 121 | import cv2 # type: ignore # noqa: F401 122 | 123 | self.predictor = SamPredictor(model) 124 | self.points_per_batch = points_per_batch 125 | self.pred_iou_thresh = pred_iou_thresh 126 | self.stability_score_thresh = stability_score_thresh 127 | self.stability_score_offset = stability_score_offset 128 | self.box_nms_thresh = box_nms_thresh 129 | self.crop_n_layers = crop_n_layers 130 | self.crop_nms_thresh = crop_nms_thresh 131 | self.crop_overlap_ratio = crop_overlap_ratio 132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor 133 | self.min_mask_region_area = min_mask_region_area 134 | self.output_mode = output_mode 135 | 136 | @torch.no_grad() 137 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: 138 | """ 139 | Generates masks for the given image. 140 | 141 | Arguments: 142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format. 143 | 144 | Returns: 145 | list(dict(str, any)): A list over records for masks. Each record is 146 | a dict containing the following keys: 147 | segmentation (dict(str, any) or np.ndarray): The mask. If 148 | output_mode='binary_mask', is an array of shape HW. Otherwise, 149 | is a dictionary containing the RLE. 150 | bbox (list(float)): The box around the mask, in XYWH format. 151 | area (int): The area in pixels of the mask. 152 | predicted_iou (float): The model's own prediction of the mask's 153 | quality. This is filtered by the pred_iou_thresh parameter. 154 | point_coords (list(list(float))): The point coordinates input 155 | to the model to generate this mask. 156 | stability_score (float): A measure of the mask's quality. This 157 | is filtered on using the stability_score_thresh parameter. 158 | crop_box (list(float)): The crop of the image used to generate 159 | the mask, given in XYWH format. 
160 | """ 161 | 162 | # Generate masks 163 | mask_data = self._generate_masks(image) 164 | 165 | # Filter small disconnected regions and holes in masks 166 | if self.min_mask_region_area > 0: 167 | mask_data = self.postprocess_small_regions( 168 | mask_data, 169 | self.min_mask_region_area, 170 | max(self.box_nms_thresh, self.crop_nms_thresh), 171 | ) 172 | 173 | # Encode masks 174 | if self.output_mode == "coco_rle": 175 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] 176 | elif self.output_mode == "binary_mask": 177 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 178 | else: 179 | mask_data["segmentations"] = mask_data["rles"] 180 | 181 | # Write mask records 182 | curr_anns = [] 183 | for idx in range(len(mask_data["segmentations"])): 184 | ann = { 185 | "segmentation": mask_data["segmentations"][idx], 186 | "area": area_from_rle(mask_data["rles"][idx]), 187 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 188 | "predicted_iou": mask_data["iou_preds"][idx].item(), 189 | "point_coords": [mask_data["points"][idx].tolist()], 190 | "stability_score": mask_data["stability_score"][idx].item(), 191 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 192 | } 193 | curr_anns.append(ann) 194 | 195 | return curr_anns 196 | 197 | def _generate_masks(self, image: np.ndarray) -> MaskData: 198 | orig_size = image.shape[:2] 199 | crop_boxes, layer_idxs = generate_crop_boxes( 200 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 201 | ) 202 | 203 | # Iterate over image crops 204 | data = MaskData() 205 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 206 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) 207 | data.cat(crop_data) 208 | 209 | # Remove duplicate masks between crops 210 | if len(crop_boxes) > 1: 211 | # Prefer masks from smaller crops 212 | scores = 1 / box_area(data["crop_boxes"]) 213 | scores = scores.to(data["boxes"].device) 214 | keep_by_nms = batched_nms( 215 | data["boxes"].float(), 216 | scores, 217 | torch.zeros_like(data["boxes"][:, 0]), # categories 218 | iou_threshold=self.crop_nms_thresh, 219 | ) 220 | data.filter(keep_by_nms) 221 | 222 | data.to_numpy() 223 | return data 224 | 225 | def _process_crop( 226 | self, 227 | image: np.ndarray, 228 | crop_box: List[int], 229 | crop_layer_idx: int, 230 | orig_size: Tuple[int, ...], 231 | ) -> MaskData: 232 | # Crop the image and calculate embeddings 233 | x0, y0, x1, y1 = crop_box 234 | cropped_im = image[y0:y1, x0:x1, :] 235 | cropped_im_size = cropped_im.shape[:2] 236 | self.predictor.set_image(cropped_im) 237 | 238 | # Get points for this crop 239 | points_scale = np.array(cropped_im_size)[None, ::-1] 240 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 241 | 242 | # Generate masks for this crop in batches 243 | data = MaskData() 244 | for (points,) in batch_iterator(self.points_per_batch, points_for_image): 245 | batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) 246 | data.cat(batch_data) 247 | del batch_data 248 | self.predictor.reset_image() 249 | 250 | # Remove duplicates within this crop. 
251 | keep_by_nms = batched_nms( 252 | data["boxes"].float(), 253 | data["iou_preds"], 254 | torch.zeros_like(data["boxes"][:, 0]), # categories 255 | iou_threshold=self.box_nms_thresh, 256 | ) 257 | data.filter(keep_by_nms) 258 | 259 | # Return to the original image frame 260 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 261 | data["points"] = uncrop_points(data["points"], crop_box) 262 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 263 | 264 | return data 265 | 266 | def _process_batch( 267 | self, 268 | points: np.ndarray, 269 | im_size: Tuple[int, ...], 270 | crop_box: List[int], 271 | orig_size: Tuple[int, ...], 272 | ) -> MaskData: 273 | orig_h, orig_w = orig_size 274 | 275 | # Run model on this batch 276 | transformed_points = self.predictor.transform.apply_coords(points, im_size) 277 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device) 278 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 279 | masks, iou_preds, _ = self.predictor.predict_torch( 280 | in_points[:, None, :], 281 | in_labels[:, None], 282 | multimask_output=True, 283 | return_logits=True, 284 | ) 285 | 286 | # Serialize predictions and store in MaskData 287 | data = MaskData( 288 | masks=masks.flatten(0, 1), 289 | iou_preds=iou_preds.flatten(0, 1), 290 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), 291 | ) 292 | del masks 293 | 294 | # Filter by predicted IoU 295 | if self.pred_iou_thresh > 0.0: 296 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 297 | data.filter(keep_mask) 298 | 299 | # Calculate stability score 300 | data["stability_score"] = calculate_stability_score( 301 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset 302 | ) 303 | if self.stability_score_thresh > 0.0: 304 | keep_mask = data["stability_score"] >= self.stability_score_thresh 305 | data.filter(keep_mask) 306 | 307 | # Threshold masks and calculate boxes 308 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold 309 | data["boxes"] = batched_mask_to_box(data["masks"]) 310 | 311 | # Filter boxes that touch crop boundaries 312 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 313 | if not torch.all(keep_mask): 314 | data.filter(keep_mask) 315 | 316 | # Compress to RLE 317 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 318 | data["rles"] = mask_to_rle_pytorch(data["masks"]) 319 | del data["masks"] 320 | 321 | return data 322 | 323 | @staticmethod 324 | def postprocess_small_regions( 325 | mask_data: MaskData, min_area: int, nms_thresh: float 326 | ) -> MaskData: 327 | """ 328 | Removes small disconnected regions and holes in masks, then reruns 329 | box NMS to remove any new duplicates. 330 | 331 | Edits mask_data in place. 332 | 333 | Requires open-cv as a dependency. 
334 | """ 335 | if len(mask_data["rles"]) == 0: 336 | return mask_data 337 | 338 | # Filter small disconnected regions and holes 339 | new_masks = [] 340 | scores = [] 341 | for rle in mask_data["rles"]: 342 | mask = rle_to_mask(rle) 343 | 344 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 345 | unchanged = not changed 346 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 347 | unchanged = unchanged and not changed 348 | 349 | new_masks.append(torch.as_tensor(mask).unsqueeze(0)) 350 | # Give score=0 to changed masks and score=1 to unchanged masks 351 | # so NMS will prefer ones that didn't need postprocessing 352 | scores.append(float(unchanged)) 353 | 354 | # Recalculate boxes and remove any new duplicates 355 | masks = torch.cat(new_masks, dim=0) 356 | boxes = batched_mask_to_box(masks) 357 | keep_by_nms = batched_nms( 358 | boxes.float(), 359 | torch.as_tensor(scores), 360 | torch.zeros_like(boxes[:, 0]), # categories 361 | iou_threshold=nms_thresh, 362 | ) 363 | 364 | # Only recalculate RLEs for masks that have changed 365 | for i_mask in keep_by_nms: 366 | if scores[i_mask] == 0.0: 367 | mask_torch = masks[i_mask].unsqueeze(0) 368 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 369 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 370 | mask_data.filter(keep_by_nms) 371 | 372 | return mask_data 373 | -------------------------------------------------------------------------------- /per_segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | def build_sam_vit_t(checkpoint=None): 47 | prompt_embed_dim = 256 48 | image_size = 1024 49 | vit_patch_size = 16 50 | image_embedding_size = image_size // vit_patch_size 51 | mobile_sam = Sam( 52 | image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000, 53 | embed_dims=[64, 128, 160, 320], 54 | depths=[2, 2, 6, 2], 55 | num_heads=[2, 4, 5, 10], 56 | window_sizes=[7, 7, 14, 7], 57 | mlp_ratio=4., 58 | drop_rate=0., 59 | drop_path_rate=0.0, 60 | use_checkpoint=False, 61 | mbconv_expand_ratio=4.0, 62 | local_conv_size=3, 63 | layer_lr_decay=0.8 64 | ), 65 | prompt_encoder=PromptEncoder( 66 | embed_dim=prompt_embed_dim, 67 | image_embedding_size=(image_embedding_size, image_embedding_size), 68 | input_image_size=(image_size, image_size), 69 | mask_in_chans=16, 70 | ), 71 | mask_decoder=MaskDecoder( 72 | num_multimask_outputs=3, 73 | transformer=TwoWayTransformer( 74 | depth=2, 75 | embedding_dim=prompt_embed_dim, 76 | mlp_dim=2048, 77 | num_heads=8, 78 | ), 79 | transformer_dim=prompt_embed_dim, 80 | iou_head_depth=3, 81 | iou_head_hidden_dim=256, 82 | ), 83 | pixel_mean=[123.675, 116.28, 103.53], 84 | pixel_std=[58.395, 57.12, 57.375], 85 | ) 86 | 87 | mobile_sam.eval() 88 | if checkpoint is not None: 89 | with open(checkpoint, "rb") as f: 90 | state_dict = torch.load(f) 91 | mobile_sam.load_state_dict(state_dict) 92 | return mobile_sam 93 | 94 | sam_model_registry = { 95 | "default": build_sam_vit_h, 96 | "vit_h": build_sam_vit_h, 97 | "vit_l": build_sam_vit_l, 98 | "vit_b": build_sam_vit_b, 99 | "vit_t": build_sam_vit_t, 100 | } 101 | 102 | 103 | def _build_sam( 104 | encoder_embed_dim, 105 | encoder_depth, 106 | encoder_num_heads, 107 | encoder_global_attn_indexes, 108 | checkpoint=None, 109 | ): 110 | prompt_embed_dim = 256 111 | image_size = 1024 112 | vit_patch_size = 16 113 | image_embedding_size = image_size // vit_patch_size 114 | sam = Sam( 115 | image_encoder=ImageEncoderViT( 116 | depth=encoder_depth, 117 | embed_dim=encoder_embed_dim, 118 | img_size=image_size, 119 | mlp_ratio=4, 120 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 121 | num_heads=encoder_num_heads, 122 | patch_size=vit_patch_size, 123 | qkv_bias=True, 124 | use_rel_pos=True, 125 | global_attn_indexes=encoder_global_attn_indexes, 126 | window_size=14, 127 | out_chans=prompt_embed_dim, 128 | ), 129 | prompt_encoder=PromptEncoder( 130 | embed_dim=prompt_embed_dim, 131 | image_embedding_size=(image_embedding_size, image_embedding_size), 132 | input_image_size=(image_size, image_size), 133 | mask_in_chans=16, 134 | ), 135 | mask_decoder=MaskDecoder( 136 | num_multimask_outputs=3, 
137 | transformer=TwoWayTransformer( 138 | depth=2, 139 | embedding_dim=prompt_embed_dim, 140 | mlp_dim=2048, 141 | num_heads=8, 142 | ), 143 | transformer_dim=prompt_embed_dim, 144 | iou_head_depth=3, 145 | iou_head_hidden_dim=256, 146 | ), 147 | pixel_mean=[123.675, 116.28, 103.53], 148 | pixel_std=[58.395, 57.12, 57.375], 149 | ) 150 | sam.eval() 151 | if checkpoint is not None: 152 | with open(checkpoint, "rb") as f: 153 | state_dict = torch.load(f) 154 | sam.load_state_dict(state_dict) 155 | return sam 156 | -------------------------------------------------------------------------------- /per_segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | from .tiny_vit_sam import TinyViT -------------------------------------------------------------------------------- /per_segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /per_segment_anything/modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
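# NOTE (illustrative): shape sketch for the encoder defined in this file. With its
# defaults (img_size=1024, patch_size=16, out_chans=256), which match what _build_sam
# passes for those arguments, a padded 1024x1024 image becomes a 64x64 patch grid and
# the convolutional neck returns a 256-channel feature map:
#
#   import torch
#   enc = ImageEncoderViT()                     # ViT-B-sized defaults of this class
#   feats = enc(torch.randn(1, 3, 1024, 1024))
#   feats.shape                                 # -> torch.Size([1, 256, 64, 64])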
6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d, MLPBlock 14 | 15 | 16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 17 | class ImageEncoderViT(nn.Module): 18 | def __init__( 19 | self, 20 | img_size: int = 1024, 21 | patch_size: int = 16, 22 | in_chans: int = 3, 23 | embed_dim: int = 768, 24 | depth: int = 12, 25 | num_heads: int = 12, 26 | mlp_ratio: float = 4.0, 27 | out_chans: int = 256, 28 | qkv_bias: bool = True, 29 | norm_layer: Type[nn.Module] = nn.LayerNorm, 30 | act_layer: Type[nn.Module] = nn.GELU, 31 | use_abs_pos: bool = True, 32 | use_rel_pos: bool = False, 33 | rel_pos_zero_init: bool = True, 34 | window_size: int = 0, 35 | global_attn_indexes: Tuple[int, ...] = (), 36 | ) -> None: 37 | """ 38 | Args: 39 | img_size (int): Input image size. 40 | patch_size (int): Patch size. 41 | in_chans (int): Number of input image channels. 42 | embed_dim (int): Patch embedding dimension. 43 | depth (int): Depth of ViT. 44 | num_heads (int): Number of attention heads in each ViT block. 45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 46 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 47 | norm_layer (nn.Module): Normalization layer. 48 | act_layer (nn.Module): Activation layer. 49 | use_abs_pos (bool): If True, use absolute positional embeddings. 50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 52 | window_size (int): Window size for window attention blocks. 53 | global_attn_indexes (list): Indexes for blocks using global attention. 54 | """ 55 | super().__init__() 56 | self.img_size = img_size 57 | 58 | self.patch_embed = PatchEmbed( 59 | kernel_size=(patch_size, patch_size), 60 | stride=(patch_size, patch_size), 61 | in_chans=in_chans, 62 | embed_dim=embed_dim, 63 | ) 64 | 65 | self.pos_embed: Optional[nn.Parameter] = None 66 | if use_abs_pos: 67 | # Initialize absolute positional embedding with pretrain image size. 
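            # NOTE: for the SAM defaults (img_size=1024, patch_size=16) this parameter
            # has shape (1, 64, 64, embed_dim) and is added directly to the
            # (B, H/16, W/16, embed_dim) output of patch_embed in forward().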
68 | self.pos_embed = nn.Parameter( 69 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 70 | ) 71 | 72 | self.blocks = nn.ModuleList() 73 | for i in range(depth): 74 | block = Block( 75 | dim=embed_dim, 76 | num_heads=num_heads, 77 | mlp_ratio=mlp_ratio, 78 | qkv_bias=qkv_bias, 79 | norm_layer=norm_layer, 80 | act_layer=act_layer, 81 | use_rel_pos=use_rel_pos, 82 | rel_pos_zero_init=rel_pos_zero_init, 83 | window_size=window_size if i not in global_attn_indexes else 0, 84 | input_size=(img_size // patch_size, img_size // patch_size), 85 | ) 86 | self.blocks.append(block) 87 | 88 | self.neck = nn.Sequential( 89 | nn.Conv2d( 90 | embed_dim, 91 | out_chans, 92 | kernel_size=1, 93 | bias=False, 94 | ), 95 | LayerNorm2d(out_chans), 96 | nn.Conv2d( 97 | out_chans, 98 | out_chans, 99 | kernel_size=3, 100 | padding=1, 101 | bias=False, 102 | ), 103 | LayerNorm2d(out_chans), 104 | ) 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | x = self.patch_embed(x) 108 | if self.pos_embed is not None: 109 | x = x + self.pos_embed 110 | 111 | for blk in self.blocks: 112 | x = blk(x) 113 | 114 | x = self.neck(x.permute(0, 3, 1, 2)) 115 | 116 | return x 117 | 118 | 119 | class Block(nn.Module): 120 | """Transformer blocks with support of window attention and residual propagation blocks""" 121 | 122 | def __init__( 123 | self, 124 | dim: int, 125 | num_heads: int, 126 | mlp_ratio: float = 4.0, 127 | qkv_bias: bool = True, 128 | norm_layer: Type[nn.Module] = nn.LayerNorm, 129 | act_layer: Type[nn.Module] = nn.GELU, 130 | use_rel_pos: bool = False, 131 | rel_pos_zero_init: bool = True, 132 | window_size: int = 0, 133 | input_size: Optional[Tuple[int, int]] = None, 134 | ) -> None: 135 | """ 136 | Args: 137 | dim (int): Number of input channels. 138 | num_heads (int): Number of attention heads in each ViT block. 139 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 140 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 141 | norm_layer (nn.Module): Normalization layer. 142 | act_layer (nn.Module): Activation layer. 143 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 144 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 145 | window_size (int): Window size for window attention blocks. If it equals 0, then 146 | use global attention. 147 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 148 | positional parameter size. 
149 | """ 150 | super().__init__() 151 | self.norm1 = norm_layer(dim) 152 | self.attn = Attention( 153 | dim, 154 | num_heads=num_heads, 155 | qkv_bias=qkv_bias, 156 | use_rel_pos=use_rel_pos, 157 | rel_pos_zero_init=rel_pos_zero_init, 158 | input_size=input_size if window_size == 0 else (window_size, window_size), 159 | ) 160 | 161 | self.norm2 = norm_layer(dim) 162 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 163 | 164 | self.window_size = window_size 165 | 166 | def forward(self, x: torch.Tensor) -> torch.Tensor: 167 | shortcut = x 168 | x = self.norm1(x) 169 | # Window partition 170 | if self.window_size > 0: 171 | H, W = x.shape[1], x.shape[2] 172 | x, pad_hw = window_partition(x, self.window_size) 173 | 174 | x = self.attn(x) 175 | # Reverse window partition 176 | if self.window_size > 0: 177 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 178 | 179 | x = shortcut + x 180 | x = x + self.mlp(self.norm2(x)) 181 | 182 | return x 183 | 184 | 185 | class Attention(nn.Module): 186 | """Multi-head Attention block with relative position embeddings.""" 187 | 188 | def __init__( 189 | self, 190 | dim: int, 191 | num_heads: int = 8, 192 | qkv_bias: bool = True, 193 | use_rel_pos: bool = False, 194 | rel_pos_zero_init: bool = True, 195 | input_size: Optional[Tuple[int, int]] = None, 196 | ) -> None: 197 | """ 198 | Args: 199 | dim (int): Number of input channels. 200 | num_heads (int): Number of attention heads. 201 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 202 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 203 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 204 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 205 | positional parameter size. 206 | """ 207 | super().__init__() 208 | self.num_heads = num_heads 209 | head_dim = dim // num_heads 210 | self.scale = head_dim**-0.5 211 | 212 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 213 | self.proj = nn.Linear(dim, dim) 214 | 215 | self.use_rel_pos = use_rel_pos 216 | if self.use_rel_pos: 217 | assert ( 218 | input_size is not None 219 | ), "Input size must be provided if using relative positional encoding." 220 | # initialize relative positional embeddings 221 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 222 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 223 | 224 | def forward(self, x: torch.Tensor) -> torch.Tensor: 225 | B, H, W, _ = x.shape 226 | # qkv with shape (3, B, nHead, H * W, C) 227 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 228 | # q, k, v with shape (B * nHead, H * W, C) 229 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 230 | 231 | attn = (q * self.scale) @ k.transpose(-2, -1) 232 | 233 | if self.use_rel_pos: 234 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 235 | 236 | attn = attn.softmax(dim=-1) 237 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 238 | x = self.proj(x) 239 | 240 | return x 241 | 242 | 243 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 244 | """ 245 | Partition into non-overlapping windows with padding if needed. 246 | Args: 247 | x (tensor): input tokens with [B, H, W, C]. 248 | window_size (int): window size. 
249 | 250 | Returns: 251 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 252 | (Hp, Wp): padded height and width before partition 253 | """ 254 | B, H, W, C = x.shape 255 | 256 | pad_h = (window_size - H % window_size) % window_size 257 | pad_w = (window_size - W % window_size) % window_size 258 | if pad_h > 0 or pad_w > 0: 259 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 260 | Hp, Wp = H + pad_h, W + pad_w 261 | 262 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 263 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 264 | return windows, (Hp, Wp) 265 | 266 | 267 | def window_unpartition( 268 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 269 | ) -> torch.Tensor: 270 | """ 271 | Window unpartition into original sequences and removing padding. 272 | Args: 273 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 274 | window_size (int): window size. 275 | pad_hw (Tuple): padded height and width (Hp, Wp). 276 | hw (Tuple): original height and width (H, W) before padding. 277 | 278 | Returns: 279 | x: unpartitioned sequences with [B, H, W, C]. 280 | """ 281 | Hp, Wp = pad_hw 282 | H, W = hw 283 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 284 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 285 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 286 | 287 | if Hp > H or Wp > W: 288 | x = x[:, :H, :W, :].contiguous() 289 | return x 290 | 291 | 292 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 293 | """ 294 | Get relative positional embeddings according to the relative positions of 295 | query and key sizes. 296 | Args: 297 | q_size (int): size of query q. 298 | k_size (int): size of key k. 299 | rel_pos (Tensor): relative position embeddings (L, C). 300 | 301 | Returns: 302 | Extracted positional embeddings according to relative positions. 303 | """ 304 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 305 | # Interpolate rel pos if needed. 306 | if rel_pos.shape[0] != max_rel_dist: 307 | # Interpolate rel pos. 308 | rel_pos_resized = F.interpolate( 309 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 310 | size=max_rel_dist, 311 | mode="linear", 312 | ) 313 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 314 | else: 315 | rel_pos_resized = rel_pos 316 | 317 | # Scale the coords with short length if shapes for q and k are different. 318 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 319 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 320 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 321 | 322 | return rel_pos_resized[relative_coords.long()] 323 | 324 | 325 | def add_decomposed_rel_pos( 326 | attn: torch.Tensor, 327 | q: torch.Tensor, 328 | rel_pos_h: torch.Tensor, 329 | rel_pos_w: torch.Tensor, 330 | q_size: Tuple[int, int], 331 | k_size: Tuple[int, int], 332 | ) -> torch.Tensor: 333 | """ 334 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 335 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 336 | Args: 337 | attn (Tensor): attention map. 338 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 
339 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 340 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 341 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 342 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 343 | 344 | Returns: 345 | attn (Tensor): attention map with added relative positional embeddings. 346 | """ 347 | q_h, q_w = q_size 348 | k_h, k_w = k_size 349 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 350 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 351 | 352 | B, _, dim = q.shape 353 | r_q = q.reshape(B, q_h, q_w, dim) 354 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 355 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 356 | 357 | attn = ( 358 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 359 | ).view(B, q_h * q_w, k_h * k_w) 360 | 361 | return attn 362 | 363 | 364 | class PatchEmbed(nn.Module): 365 | """ 366 | Image to Patch Embedding. 367 | """ 368 | 369 | def __init__( 370 | self, 371 | kernel_size: Tuple[int, int] = (16, 16), 372 | stride: Tuple[int, int] = (16, 16), 373 | padding: Tuple[int, int] = (0, 0), 374 | in_chans: int = 3, 375 | embed_dim: int = 768, 376 | ) -> None: 377 | """ 378 | Args: 379 | kernel_size (Tuple): kernel size of the projection layer. 380 | stride (Tuple): stride of the projection layer. 381 | padding (Tuple): padding size of the projection layer. 382 | in_chans (int): Number of input image channels. 383 | embed_dim (int): Patch embedding dimension. 384 | """ 385 | super().__init__() 386 | 387 | self.proj = nn.Conv2d( 388 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 389 | ) 390 | 391 | def forward(self, x: torch.Tensor) -> torch.Tensor: 392 | x = self.proj(x) 393 | # B C H W -> B H W C 394 | x = x.permute(0, 2, 3, 1) 395 | return x 396 | -------------------------------------------------------------------------------- /per_segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | attn_sim=None, 79 | target_embedding=None 80 | ) -> Tuple[torch.Tensor, torch.Tensor]: 81 | """ 82 | Predict masks given image and prompt embeddings. 83 | 84 | Arguments: 85 | image_embeddings (torch.Tensor): the embeddings from the image encoder 86 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 87 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 88 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 89 | multimask_output (bool): Whether to return multiple masks or a single 90 | mask. 91 | 92 | Returns: 93 | torch.Tensor: batched predicted masks 94 | torch.Tensor: batched predictions of mask quality 95 | """ 96 | masks, iou_pred = self.predict_masks( 97 | image_embeddings=image_embeddings, 98 | image_pe=image_pe, 99 | sparse_prompt_embeddings=sparse_prompt_embeddings, 100 | dense_prompt_embeddings=dense_prompt_embeddings, 101 | attn_sim=attn_sim, 102 | target_embedding=target_embedding 103 | ) 104 | 105 | # Select the correct mask or masks for output 106 | if multimask_output: 107 | mask_slice = slice(1, None) 108 | else: 109 | mask_slice = slice(0, 1) 110 | masks = masks[:, mask_slice, :, :] 111 | iou_pred = iou_pred[:, mask_slice] 112 | 113 | # Prepare output 114 | return masks, iou_pred 115 | 116 | def predict_masks( 117 | self, 118 | image_embeddings: torch.Tensor, 119 | image_pe: torch.Tensor, 120 | sparse_prompt_embeddings: torch.Tensor, 121 | dense_prompt_embeddings: torch.Tensor, 122 | attn_sim=None, 123 | target_embedding=None 124 | ) -> Tuple[torch.Tensor, torch.Tensor]: 125 | """Predicts masks. 
See 'forward' for more details.""" 126 | # Concatenate output tokens 127 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 128 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 129 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 130 | 131 | # Expand per-image data in batch direction to be per-mask 132 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 133 | src = src + dense_prompt_embeddings 134 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 135 | b, c, h, w = src.shape 136 | 137 | # Run the transformer 138 | hs, src = self.transformer(src, pos_src, tokens, attn_sim, target_embedding) 139 | iou_token_out = hs[:, 0, :] 140 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 141 | 142 | # Upscale mask embeddings and predict masks using the mask tokens 143 | src = src.transpose(1, 2).view(b, c, h, w) 144 | upscaled_embedding = self.output_upscaling(src) 145 | hyper_in_list: List[torch.Tensor] = [] 146 | for i in range(self.num_mask_tokens): 147 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 148 | hyper_in = torch.stack(hyper_in_list, dim=1) 149 | b, c, h, w = upscaled_embedding.shape 150 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 151 | 152 | # Generate mask quality predictions 153 | iou_pred = self.iou_prediction_head(iou_token_out) 154 | 155 | return masks, iou_pred 156 | 157 | 158 | # Lightly adapted from 159 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 160 | class MLP(nn.Module): 161 | def __init__( 162 | self, 163 | input_dim: int, 164 | hidden_dim: int, 165 | output_dim: int, 166 | num_layers: int, 167 | sigmoid_output: bool = False, 168 | ) -> None: 169 | super().__init__() 170 | self.num_layers = num_layers 171 | h = [hidden_dim] * (num_layers - 1) 172 | self.layers = nn.ModuleList( 173 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 174 | ) 175 | self.sigmoid_output = sigmoid_output 176 | 177 | def forward(self, x): 178 | for i, layer in enumerate(self.layers): 179 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 180 | if self.sigmoid_output: 181 | x = F.sigmoid(x) 182 | return x 183 | -------------------------------------------------------------------------------- /per_segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 
32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /per_segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple, Union 12 | from .tiny_vit_sam import TinyViT 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: Union[ImageEncoderViT, TinyViT], 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 
63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 
149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | 176 | def preprocess_mask(self, x: torch.Tensor) -> torch.Tensor: 177 | """Normalize pixel values and pad to a square input.""" 178 | # Pad 179 | h, w = x.shape[-2:] 180 | padh = self.image_encoder.img_size - h 181 | padw = self.image_encoder.img_size - w 182 | x = F.pad(x, (0, padw, 0, padh)) 183 | return x 184 | -------------------------------------------------------------------------------- /per_segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | attn_sim: Tensor, 68 | target_embedding=None 69 | ) -> Tuple[Tensor, Tensor]: 70 | """ 71 | Args: 72 | image_embedding (torch.Tensor): image to attend to. Should be shape 73 | B x embedding_dim x h x w for any h and w. 
74 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 75 | have the same shape as image_embedding. 76 | point_embedding (torch.Tensor): the embedding to add to the query points. 77 | Must have shape B x N_points x embedding_dim for any N_points. 78 | 79 | Returns: 80 | torch.Tensor: the processed point_embedding 81 | torch.Tensor: the processed image_embedding 82 | """ 83 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 84 | bs, c, h, w = image_embedding.shape 85 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 86 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 87 | 88 | # Prepare queries 89 | queries = point_embedding 90 | keys = image_embedding 91 | 92 | # Apply transformer blocks and final layernorm 93 | for layer in self.layers: 94 | if target_embedding is not None: 95 | queries += target_embedding 96 | queries, keys = layer( 97 | queries=queries, 98 | keys=keys, 99 | query_pe=point_embedding, 100 | key_pe=image_pe, 101 | attn_sim=attn_sim, 102 | ) 103 | 104 | # Apply the final attention layer from the points to the image 105 | q = queries + point_embedding 106 | k = keys + image_pe 107 | 108 | if target_embedding is not None: 109 | q += target_embedding 110 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 111 | queries = queries + attn_out 112 | queries = self.norm_final_attn(queries) 113 | 114 | return queries, keys 115 | 116 | 117 | class TwoWayAttentionBlock(nn.Module): 118 | def __init__( 119 | self, 120 | embedding_dim: int, 121 | num_heads: int, 122 | mlp_dim: int = 2048, 123 | activation: Type[nn.Module] = nn.ReLU, 124 | attention_downsample_rate: int = 2, 125 | skip_first_layer_pe: bool = False, 126 | ) -> None: 127 | """ 128 | A transformer block with four layers: (1) self-attention of sparse 129 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 130 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 131 | inputs. 
132 | 133 | Arguments: 134 | embedding_dim (int): the channel dimension of the embeddings 135 | num_heads (int): the number of heads in the attention layers 136 | mlp_dim (int): the hidden dimension of the mlp block 137 | activation (nn.Module): the activation of the mlp block 138 | skip_first_layer_pe (bool): skip the PE on the first layer 139 | """ 140 | super().__init__() 141 | self.self_attn = Attention(embedding_dim, num_heads) 142 | self.norm1 = nn.LayerNorm(embedding_dim) 143 | 144 | self.cross_attn_token_to_image = Attention( 145 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 146 | ) 147 | self.norm2 = nn.LayerNorm(embedding_dim) 148 | 149 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 150 | self.norm3 = nn.LayerNorm(embedding_dim) 151 | 152 | self.norm4 = nn.LayerNorm(embedding_dim) 153 | self.cross_attn_image_to_token = Attention( 154 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 155 | ) 156 | 157 | self.skip_first_layer_pe = skip_first_layer_pe 158 | 159 | def forward( 160 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor, attn_sim: Tensor 161 | ) -> Tuple[Tensor, Tensor]: 162 | # Self attention block 163 | if self.skip_first_layer_pe: 164 | queries = self.self_attn(q=queries, k=queries, v=queries) 165 | else: 166 | q = queries + query_pe 167 | attn_out = self.self_attn(q=q, k=q, v=queries) 168 | queries = queries + attn_out 169 | queries = self.norm1(queries) 170 | 171 | # Cross attention block, tokens attending to image embedding 172 | q = queries + query_pe 173 | k = keys + key_pe 174 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys, attn_sim=attn_sim) 175 | queries = queries + attn_out 176 | queries = self.norm2(queries) 177 | 178 | # MLP block 179 | mlp_out = self.mlp(queries) 180 | queries = queries + mlp_out 181 | queries = self.norm3(queries) 182 | 183 | # Cross attention block, image embedding attending to tokens 184 | q = queries + query_pe 185 | k = keys + key_pe 186 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 187 | keys = keys + attn_out 188 | keys = self.norm4(keys) 189 | 190 | return queries, keys 191 | 192 | 193 | class Attention(nn.Module): 194 | """ 195 | An attention layer that allows for downscaling the size of the embedding 196 | after projection to queries, keys, and values. 197 | """ 198 | 199 | def __init__( 200 | self, 201 | embedding_dim: int, 202 | num_heads: int, 203 | downsample_rate: int = 1, 204 | ) -> None: 205 | super().__init__() 206 | self.embedding_dim = embedding_dim 207 | self.internal_dim = embedding_dim // downsample_rate 208 | self.num_heads = num_heads 209 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 
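# Queries, keys, and values are first projected from embedding_dim down to
# internal_dim = embedding_dim // downsample_rate, attended over, and then
# projected back to embedding_dim; e.g. with embedding_dim=256 and the default
# cross-attention downsample_rate of 2, the projections below map 256 -> 128 -> 256.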
210 | 211 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 212 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 213 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 214 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 215 | 216 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 217 | b, n, c = x.shape 218 | x = x.reshape(b, n, num_heads, c // num_heads) 219 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 220 | 221 | def _recombine_heads(self, x: Tensor) -> Tensor: 222 | b, n_heads, n_tokens, c_per_head = x.shape 223 | x = x.transpose(1, 2) 224 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 225 | 226 | def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_sim: Tensor = None) -> Tensor: 227 | # Input projections 228 | q = self.q_proj(q) 229 | k = self.k_proj(k) 230 | v = self.v_proj(v) 231 | 232 | # Separate into heads 233 | q = self._separate_heads(q, self.num_heads) 234 | k = self._separate_heads(k, self.num_heads) 235 | v = self._separate_heads(v, self.num_heads) 236 | 237 | # Attention 238 | _, _, _, c_per_head = q.shape 239 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 240 | attn = attn / math.sqrt(c_per_head) 241 | attn = torch.softmax(attn, dim=-1) 242 | 243 | if attn_sim is not None: 244 | attn = attn + attn_sim 245 | attn = torch.softmax(attn, dim=-1) 246 | 247 | # Get output 248 | out = attn @ v 249 | out = self._recombine_heads(out) 250 | out = self.out_proj(out) 251 | 252 | return out 253 | -------------------------------------------------------------------------------- /per_segment_anything/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from typing import Optional, Tuple 11 | 12 | from .utils.transforms import ResizeLongestSide 13 | 14 | 15 | class SamPredictor: 16 | def __init__( 17 | self, 18 | sam_model, 19 | ) -> None: 20 | """ 21 | Uses SAM to calculate the image embedding for an image, and then 22 | allow repeated, efficient mask prediction given prompts. 23 | 24 | Arguments: 25 | sam_model (Sam): The model to use for mask prediction. 26 | """ 27 | super().__init__() 28 | self.model = sam_model 29 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 30 | self.reset_image() 31 | 32 | def set_image( 33 | self, 34 | image: np.ndarray, 35 | mask: np.ndarray = None, 36 | image_format: str = "RGB", 37 | cal_image=True 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 
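# PerSAM extension: set_image additionally accepts an optional reference mask,
# which is resized with the same ResizeLongestSide transform as the image,
# padded to the square input size, and returned to the caller so the one-shot
# mask can be aligned with the image embedding grid.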
52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | # Transform the mask to the form expected by the model 61 | input_mask_torch = None 62 | if mask is not None: 63 | input_mask = self.transform.apply_image(mask) 64 | input_mask_torch = torch.as_tensor(input_mask, device=self.device) 65 | input_mask_torch = input_mask_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 66 | 67 | input_mask = self.set_torch_image(input_image_torch, image.shape[:2], transformed_mask=input_mask_torch) 68 | return input_mask 69 | 70 | 71 | @torch.no_grad() 72 | def set_torch_image( 73 | self, 74 | transformed_image: torch.Tensor, 75 | original_image_size: Tuple[int, ...], 76 | transformed_mask: torch.Tensor = None, 77 | cal_image=True 78 | ) -> None: 79 | """ 80 | Calculates the image embeddings for the provided image, allowing 81 | masks to be predicted with the 'predict' method. Expects the input 82 | image to be already transformed to the format expected by the model. 83 | 84 | Arguments: 85 | transformed_image (torch.Tensor): The input image, with shape 86 | 1x3xHxW, which has been transformed with ResizeLongestSide. 87 | original_image_size (tuple(int, int)): The size of the image 88 | before transformation, in (H, W) format. 89 | """ 90 | assert ( 91 | len(transformed_image.shape) == 4 92 | and transformed_image.shape[1] == 3 93 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 94 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 95 | 96 | if cal_image: 97 | self.reset_image() 98 | self.original_size = original_image_size 99 | self.input_size = tuple(transformed_image.shape[-2:]) 100 | input_image = self.model.preprocess(transformed_image) 101 | self.features = self.model.image_encoder(input_image) 102 | self.is_image_set = True 103 | 104 | if transformed_mask is not None: 105 | input_mask = self.model.preprocess(transformed_mask) # pad to 1024 106 | return input_mask 107 | 108 | def predict( 109 | self, 110 | point_coords: Optional[np.ndarray] = None, 111 | point_labels: Optional[np.ndarray] = None, 112 | box: Optional[np.ndarray] = None, 113 | mask_input: Optional[np.ndarray] = None, 114 | multimask_output: bool = True, 115 | return_logits: bool = False, 116 | attn_sim = None, 117 | target_embedding = None 118 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 119 | """ 120 | Predict masks for the given input prompts, using the currently set image. 121 | 122 | Arguments: 123 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 124 | model. Each point is in (X,Y) in pixels. 125 | point_labels (np.ndarray or None): A length N array of labels for the 126 | point prompts. 1 indicates a foreground point and 0 indicates a 127 | background point. 128 | box (np.ndarray or None): A length 4 array given a box prompt to the 129 | model, in XYXY format. 130 | mask_input (np.ndarray): A low resolution mask input to the model, typically 131 | coming from a previous prediction iteration. Has form 1xHxW, where 132 | for SAM, H=W=256. 133 | multimask_output (bool): If true, the model will return three masks. 
134 | For ambiguous input prompts (such as a single click), this will often 135 | produce better masks than a single prediction. If only a single 136 | mask is needed, the model's predicted quality score can be used 137 | to select the best mask. For non-ambiguous prompts, such as multiple 138 | input prompts, multimask_output=False can give better results. 139 | return_logits (bool): If true, returns un-thresholded masks logits 140 | instead of a binary mask. 141 | 142 | Returns: 143 | (np.ndarray): The output masks in CxHxW format, where C is the 144 | number of masks, and (H, W) is the original image size. 145 | (np.ndarray): An array of length C containing the model's 146 | predictions for the quality of each mask. 147 | (np.ndarray): An array of shape CxHxW, where C is the number 148 | of masks and H=W=256. These low resolution logits can be passed to 149 | a subsequent iteration as mask input. 150 | """ 151 | if not self.is_image_set: 152 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 153 | 154 | # Transform input prompts 155 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 156 | if point_coords is not None: 157 | assert ( 158 | point_labels is not None 159 | ), "point_labels must be supplied if point_coords is supplied." 160 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 161 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 162 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 163 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 164 | if box is not None: 165 | box = self.transform.apply_boxes(box, self.original_size) 166 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 167 | box_torch = box_torch[None, :] 168 | if mask_input is not None: 169 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 170 | mask_input_torch = mask_input_torch[None, :, :, :] 171 | masks, iou_predictions, low_res_masks, high_res_masks = self.predict_torch( 172 | coords_torch, 173 | labels_torch, 174 | box_torch, 175 | mask_input_torch, 176 | multimask_output, 177 | return_logits=return_logits, 178 | attn_sim=attn_sim, 179 | target_embedding=target_embedding, 180 | ) 181 | 182 | masks = masks[0].detach().cpu().numpy() 183 | iou_predictions = iou_predictions[0].detach().cpu().numpy() 184 | low_res_masks = low_res_masks[0].detach().cpu().numpy() 185 | high_res_masks = high_res_masks[0] 186 | 187 | return masks, iou_predictions, low_res_masks, high_res_masks 188 | 189 | @torch.no_grad() 190 | def predict_torch( 191 | self, 192 | point_coords: Optional[torch.Tensor], 193 | point_labels: Optional[torch.Tensor], 194 | boxes: Optional[torch.Tensor] = None, 195 | mask_input: Optional[torch.Tensor] = None, 196 | multimask_output: bool = True, 197 | return_logits: bool = False, 198 | attn_sim = None, 199 | target_embedding = None 200 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 201 | """ 202 | Predict masks for the given input prompts, using the currently set image. 203 | Input prompts are batched torch tensors and are expected to already be 204 | transformed to the input frame using ResizeLongestSide. 205 | 206 | Arguments: 207 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 208 | model. Each point is in (X,Y) in pixels. 209 | point_labels (torch.Tensor or None): A BxN array of labels for the 210 | point prompts. 
1 indicates a foreground point and 0 indicates a 211 | background point. 212 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 213 | model, in XYXY format. 214 | mask_input (np.ndarray): A low resolution mask input to the model, typically 215 | coming from a previous prediction iteration. Has form Bx1xHxW, where 216 | for SAM, H=W=256. Masks returned by a previous iteration of the 217 | predict method do not need further transformation. 218 | multimask_output (bool): If true, the model will return three masks. 219 | For ambiguous input prompts (such as a single click), this will often 220 | produce better masks than a single prediction. If only a single 221 | mask is needed, the model's predicted quality score can be used 222 | to select the best mask. For non-ambiguous prompts, such as multiple 223 | input prompts, multimask_output=False can give better results. 224 | return_logits (bool): If true, returns un-thresholded masks logits 225 | instead of a binary mask. 226 | 227 | Returns: 228 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 229 | number of masks, and (H, W) is the original image size. 230 | (torch.Tensor): An array of shape BxC containing the model's 231 | predictions for the quality of each mask. 232 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 233 | of masks and H=W=256. These low res logits can be passed to 234 | a subsequent iteration as mask input. 235 | """ 236 | if not self.is_image_set: 237 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 238 | 239 | if point_coords is not None: 240 | points = (point_coords, point_labels) 241 | else: 242 | points = None 243 | 244 | # Embed prompts 245 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 246 | points=points, 247 | boxes=boxes, 248 | masks=mask_input, 249 | ) 250 | 251 | # Predict masks 252 | low_res_masks, iou_predictions = self.model.mask_decoder( 253 | image_embeddings=self.features, 254 | image_pe=self.model.prompt_encoder.get_dense_pe(), 255 | sparse_prompt_embeddings=sparse_embeddings, 256 | dense_prompt_embeddings=dense_embeddings, 257 | multimask_output=multimask_output, 258 | attn_sim=attn_sim, 259 | target_embedding=target_embedding 260 | ) 261 | 262 | # Upscale the masks to the original image resolution 263 | high_res_masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 264 | 265 | if not return_logits: 266 | masks = high_res_masks > self.model.mask_threshold # 0.0 267 | return masks, iou_predictions, low_res_masks, high_res_masks 268 | else: 269 | return high_res_masks, iou_predictions, low_res_masks, high_res_masks 270 | 271 | 272 | def get_image_embedding(self) -> torch.Tensor: 273 | """ 274 | Returns the image embeddings for the currently set image, with 275 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 276 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 277 | """ 278 | if not self.is_image_set: 279 | raise RuntimeError( 280 | "An image must be set with .set_image(...) to generate an embedding." 281 | ) 282 | assert self.features is not None, "Features must exist if an image has been set." 
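# self.features is the image embedding cached by set_torch_image; PerSAM reads it
# directly (predictor.features) to compute cosine-similarity location priors
# against the one-shot target feature.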
283 | return self.features 284 | 285 | @property 286 | def device(self) -> torch.device: 287 | return self.model.device 288 | 289 | def reset_image(self) -> None: 290 | """Resets the currently set image.""" 291 | self.is_image_set = False 292 | self.features = None 293 | self.orig_h = None 294 | self.orig_w = None 295 | self.input_h = None 296 | self.input_w = None 297 | -------------------------------------------------------------------------------- /per_segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /per_segment_anything/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 
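# Every key stores one batched, per-mask field (e.g. masks, boxes, scores);
# filter() and cat() below keep all fields aligned along the batch dimension.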
33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 
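For example, a flattened binary row [0, 0, 1, 1, 1, 0] is encoded as
counts [2, 3, 1]; runs alternate starting with background, so a leading 0
is inserted when the mask begins with a foreground pixel.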
111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 165 | # Save memory by preventing unnecessary cast to torch.int64 166 | intersections = ( 167 | (masks > (mask_threshold + threshold_offset)) 168 | .sum(-1, dtype=torch.int16) 169 | .sum(-1, dtype=torch.int32) 170 | ) 171 | unions = ( 172 | (masks > (mask_threshold - threshold_offset)) 173 | .sum(-1, dtype=torch.int16) 174 | .sum(-1, dtype=torch.int32) 175 | ) 176 | return intersections / unions 177 | 178 | 179 | def build_point_grid(n_per_side: int) -> np.ndarray: 180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 181 | offset = 1 / (2 * n_per_side) 182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 186 | return points 187 | 188 | 189 | def build_all_layer_point_grids( 190 | n_per_side: int, n_layers: int, scale_per_layer: int 191 | ) -> List[np.ndarray]: 192 | """Generates point grids for all crop layers.""" 193 | points_by_layer = [] 194 | for i in range(n_layers + 1): 195 | n_points = int(n_per_side / (scale_per_layer**i)) 196 | points_by_layer.append(build_point_grid(n_points)) 197 | return points_by_layer 198 | 199 | 200 | def generate_crop_boxes( 201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 202 | ) -> Tuple[List[List[int]], List[int]]: 203 | """ 204 | Generates a list of crop boxes of different sizes. Each layer 205 | has (2**i)**2 boxes for the ith layer. 
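Layer 0 is the uncropped image, so n_layers=1 yields 1 + 4 = 5 boxes and
n_layers=2 yields 1 + 4 + 16 = 21 boxes, with each deeper layer using
smaller, overlapping crops.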
206 | """ 207 | crop_boxes, layer_idxs = [], [] 208 | im_h, im_w = im_size 209 | short_side = min(im_h, im_w) 210 | 211 | # Original image 212 | crop_boxes.append([0, 0, im_w, im_h]) 213 | layer_idxs.append(0) 214 | 215 | def crop_len(orig_len, n_crops, overlap): 216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 217 | 218 | for i_layer in range(n_layers): 219 | n_crops_per_side = 2 ** (i_layer + 1) 220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 221 | 222 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 223 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 224 | 225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 227 | 228 | # Crops in XYWH format 229 | for x0, y0 in product(crop_box_x0, crop_box_y0): 230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 231 | crop_boxes.append(box) 232 | layer_idxs.append(i_layer + 1) 233 | 234 | return crop_boxes, layer_idxs 235 | 236 | 237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 238 | x0, y0, _, _ = crop_box 239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 240 | # Check if boxes has a channel dimension 241 | if len(boxes.shape) == 3: 242 | offset = offset.unsqueeze(1) 243 | return boxes + offset 244 | 245 | 246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 247 | x0, y0, _, _ = crop_box 248 | offset = torch.tensor([[x0, y0]], device=points.device) 249 | # Check if points has a channel dimension 250 | if len(points.shape) == 3: 251 | offset = offset.unsqueeze(1) 252 | return points + offset 253 | 254 | 255 | def uncrop_masks( 256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 257 | ) -> torch.Tensor: 258 | x0, y0, x1, y1 = crop_box 259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 260 | return masks 261 | # Coordinate transform masks 262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 263 | pad = (x0, pad_x - x0, y0, pad_y - y0) 264 | return torch.nn.functional.pad(masks, pad, value=0) 265 | 266 | 267 | def remove_small_regions( 268 | mask: np.ndarray, area_thresh: float, mode: str 269 | ) -> Tuple[np.ndarray, bool]: 270 | """ 271 | Removes small disconnected regions and holes in a mask. Returns the 272 | mask and an indicator of if the mask has been modified. 
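With mode="holes", enclosed background regions smaller than area_thresh are
filled in; with mode="islands", disconnected foreground components smaller
than area_thresh are removed.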
273 | """ 274 | import cv2 # type: ignore 275 | 276 | assert mode in ["holes", "islands"] 277 | correct_holes = mode == "holes" 278 | working_mask = (correct_holes ^ mask).astype(np.uint8) 279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 280 | sizes = stats[:, -1][1:] # Row 0 is background label 281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 282 | if len(small_regions) == 0: 283 | return mask, False 284 | fill_labels = [0] + small_regions 285 | if not correct_holes: 286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 287 | # If every region is below threshold, keep largest 288 | if len(fill_labels) == 0: 289 | fill_labels = [int(np.argmax(sizes)) + 1] 290 | mask = np.isin(regions, fill_labels) 291 | return mask, True 292 | 293 | 294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 295 | from pycocotools import mask as mask_utils # type: ignore 296 | 297 | h, w = uncompressed_rle["size"] 298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 300 | return rle 301 | 302 | 303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 307 | """ 308 | # torch.max below raises an error on empty inputs, just skip in this case 309 | if torch.numel(masks) == 0: 310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 311 | 312 | # Normalize shape to CxHxW 313 | shape = masks.shape 314 | h, w = shape[-2:] 315 | if len(shape) > 2: 316 | masks = masks.flatten(0, -3) 317 | else: 318 | masks = masks.unsqueeze(0) 319 | 320 | # Get top and bottom edges 321 | in_height, _ = torch.max(masks, dim=-1) 322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 324 | in_height_coords = in_height_coords + h * (~in_height) 325 | top_edges, _ = torch.min(in_height_coords, dim=-1) 326 | 327 | # Get left and right edges 328 | in_width, _ = torch.max(masks, dim=-2) 329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 330 | right_edges, _ = torch.max(in_width_coords, dim=-1) 331 | in_width_coords = in_width_coords + w * (~in_width) 332 | left_edges, _ = torch.min(in_width_coords, dim=-1) 333 | 334 | # If the mask is empty the right edge will be to the left of the left edge. 335 | # Replace these boxes with [0, 0, 0, 0] 336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 338 | out = out * (~empty_filter).unsqueeze(-1) 339 | 340 | # Return to original shape 341 | if len(shape) > 2: 342 | out = out.reshape(*shape[:-2], 4) 343 | else: 344 | out = out[0] 345 | 346 | return out 347 | -------------------------------------------------------------------------------- /per_segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 
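# A single click is an ambiguous prompt, so one of the multimask outputs is
# preferred; prompts with more points shift the score toward the single-mask
# token at index 0.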
96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /per_segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
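Coordinates are scaled by new_w / old_w and new_h / old_h; e.g. with
target_length=1024, a point (320, 240) in a 480x640 (H, W) image maps
to (512, 384).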
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
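E.g. get_preprocess_shape(720, 1280, 1024) returns (576, 1024): the longer
side is scaled to the target length and the shorter side is scaled by the
same factor and rounded to the nearest integer.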
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /persam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import cv2 7 | from tqdm import tqdm 8 | import argparse 9 | import matplotlib.pyplot as plt 10 | import warnings 11 | warnings.filterwarnings('ignore') 12 | 13 | from show import * 14 | from per_segment_anything import sam_model_registry, SamPredictor 15 | 16 | 17 | 18 | def get_arguments(): 19 | 20 | parser = argparse.ArgumentParser() 21 | 22 | parser.add_argument('--data', type=str, default='./data') 23 | parser.add_argument('--outdir', type=str, default='persam') 24 | parser.add_argument('--ckpt', type=str, default='sam_vit_h_4b8939.pth') 25 | parser.add_argument('--ref_idx', type=str, default='00') 26 | parser.add_argument('--sam_type', type=str, default='vit_h') 27 | 28 | args = parser.parse_args() 29 | return args 30 | 31 | 32 | def main(): 33 | 34 | args = get_arguments() 35 | print("Args:", args) 36 | 37 | images_path = args.data + '/Images/' 38 | masks_path = args.data + '/Annotations/' 39 | output_path = './outputs/' + args.outdir 40 | 41 | if not os.path.exists('./outputs/'): 42 | os.mkdir('./outputs/') 43 | 44 | for obj_name in os.listdir(images_path): 45 | if ".DS" not in obj_name: 46 | persam(args, obj_name, images_path, masks_path, output_path) 47 | 48 | 49 | def persam(args, obj_name, images_path, masks_path, output_path): 50 | 51 | print("\n------------> Segment " + obj_name) 52 | 53 | # Path preparation 54 | ref_image_path = os.path.join(images_path, obj_name, args.ref_idx + '.jpg') 55 | ref_mask_path = os.path.join(masks_path, obj_name, args.ref_idx + '.png') 56 | test_images_path = os.path.join(images_path, obj_name) 57 | 58 | output_path = os.path.join(output_path, obj_name) 59 | os.makedirs(output_path, exist_ok=True) 60 | 61 | # Load images and masks 62 | ref_image = cv2.imread(ref_image_path) 63 | ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB) 64 | 65 | ref_mask = cv2.imread(ref_mask_path) 66 | ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB) 67 | 68 | 69 | print("======> Load SAM" ) 70 | if args.sam_type == 'vit_h': 71 | sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth' 72 | sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda() 73 | elif args.sam_type == 'vit_t': 74 | sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt' 75 | device = "cuda" if torch.cuda.is_available() else "cpu" 76 | sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device) 77 | sam.eval() 78 | 79 | predictor = SamPredictor(sam) 80 | 81 | print("======> Obtain Location Prior" ) 82 | # Image features encoding 83 | ref_mask = predictor.set_image(ref_image, ref_mask) 84 | ref_feat = predictor.features.squeeze().permute(1, 2, 0) 85 | 86 | ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear") 87 | ref_mask = ref_mask.squeeze()[0] 88 | 89 | # Target feature extraction 90 | target_feat = ref_feat[ref_mask > 0] 91 | target_embedding = target_feat.mean(0).unsqueeze(0) 92 | target_feat = target_embedding / target_embedding.norm(dim=-1, keepdim=True) 93 | target_embedding = target_embedding.unsqueeze(0) 94 | 95 | 96 | print('======> Start Testing') 97 | for test_idx in 
tqdm(range(len(os.listdir(test_images_path)))): 98 | 99 | # Load test image 100 | test_idx = '%02d' % test_idx 101 | test_image_path = test_images_path + '/' + test_idx + '.jpg' 102 | test_image = cv2.imread(test_image_path) 103 | test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB) 104 | 105 | # Image feature encoding 106 | predictor.set_image(test_image) 107 | test_feat = predictor.features.squeeze() 108 | 109 | # Cosine similarity 110 | C, h, w = test_feat.shape 111 | test_feat = test_feat / test_feat.norm(dim=0, keepdim=True) 112 | test_feat = test_feat.reshape(C, h * w) 113 | sim = target_feat @ test_feat 114 | 115 | sim = sim.reshape(1, 1, h, w) 116 | sim = F.interpolate(sim, scale_factor=4, mode="bilinear") 117 | sim = predictor.model.postprocess_masks( 118 | sim, 119 | input_size=predictor.input_size, 120 | original_size=predictor.original_size).squeeze() 121 | 122 | # Positive-negative location prior 123 | topk_xy_i, topk_label_i, last_xy_i, last_label_i = point_selection(sim, topk=1) 124 | topk_xy = np.concatenate([topk_xy_i, last_xy_i], axis=0) 125 | topk_label = np.concatenate([topk_label_i, last_label_i], axis=0) 126 | 127 | # Obtain the target guidance for cross-attention layers 128 | sim = (sim - sim.mean()) / torch.std(sim) 129 | sim = F.interpolate(sim.unsqueeze(0).unsqueeze(0), size=(64, 64), mode="bilinear") 130 | attn_sim = sim.sigmoid_().unsqueeze(0).flatten(3) 131 | 132 | # First-step prediction 133 | masks, scores, logits, _ = predictor.predict( 134 | point_coords=topk_xy, 135 | point_labels=topk_label, 136 | multimask_output=False, 137 | attn_sim=attn_sim, # Target-guided Attention 138 | target_embedding=target_embedding # Target-semantic Prompting 139 | ) 140 | best_idx = 0 141 | 142 | # Cascaded Post-refinement-1 143 | masks, scores, logits, _ = predictor.predict( 144 | point_coords=topk_xy, 145 | point_labels=topk_label, 146 | mask_input=logits[best_idx: best_idx + 1, :, :], 147 | multimask_output=True) 148 | best_idx = np.argmax(scores) 149 | 150 | # Cascaded Post-refinement-2 151 | y, x = np.nonzero(masks[best_idx]) 152 | x_min = x.min() 153 | x_max = x.max() 154 | y_min = y.min() 155 | y_max = y.max() 156 | input_box = np.array([x_min, y_min, x_max, y_max]) 157 | masks, scores, logits, _ = predictor.predict( 158 | point_coords=topk_xy, 159 | point_labels=topk_label, 160 | box=input_box[None, :], 161 | mask_input=logits[best_idx: best_idx + 1, :, :], 162 | multimask_output=True) 163 | best_idx = np.argmax(scores) 164 | 165 | # Save masks 166 | plt.figure(figsize=(10, 10)) 167 | plt.imshow(test_image) 168 | show_mask(masks[best_idx], plt.gca()) 169 | show_points(topk_xy, topk_label, plt.gca()) 170 | plt.title(f"Mask {best_idx}", fontsize=18) 171 | plt.axis('off') 172 | vis_mask_output_path = os.path.join(output_path, f'vis_mask_{test_idx}.jpg') 173 | with open(vis_mask_output_path, 'wb') as outfile: 174 | plt.savefig(outfile, format='jpg') 175 | 176 | final_mask = masks[best_idx] 177 | mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8) 178 | mask_colors[final_mask, :] = np.array([[0, 0, 128]]) 179 | mask_output_path = os.path.join(output_path, test_idx + '.png') 180 | cv2.imwrite(mask_output_path, mask_colors) 181 | 182 | 183 | def point_selection(mask_sim, topk=1): 184 | # Top-1 point selection 185 | w, h = mask_sim.shape 186 | topk_xy = mask_sim.flatten(0).topk(topk)[1] 187 | topk_x = (topk_xy // h).unsqueeze(0) 188 | topk_y = (topk_xy - topk_x * h) 189 | topk_xy = torch.cat((topk_y, topk_x), dim=0).permute(1, 0) 190 | 
topk_label = np.array([1] * topk) 191 | topk_xy = topk_xy.cpu().numpy() 192 | 193 | # Top-last point selection 194 | last_xy = mask_sim.flatten(0).topk(topk, largest=False)[1] 195 | last_x = (last_xy // h).unsqueeze(0) 196 | last_y = (last_xy - last_x * h) 197 | last_xy = torch.cat((last_y, last_x), dim=0).permute(1, 0) 198 | last_label = np.array([0] * topk) 199 | last_xy = last_xy.cpu().numpy() 200 | 201 | return topk_xy, topk_label, last_xy, last_label 202 | 203 | 204 | if __name__ == "__main__": 205 | main() 206 | -------------------------------------------------------------------------------- /persam_f.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | import os 7 | import cv2 8 | from tqdm import tqdm 9 | import argparse 10 | import matplotlib.pyplot as plt 11 | import warnings 12 | warnings.filterwarnings('ignore') 13 | 14 | from show import * 15 | from per_segment_anything import sam_model_registry, SamPredictor 16 | 17 | 18 | 19 | def get_arguments(): 20 | 21 | parser = argparse.ArgumentParser() 22 | 23 | parser.add_argument('--data', type=str, default='./data') 24 | parser.add_argument('--outdir', type=str, default='persam_f') 25 | parser.add_argument('--ckpt', type=str, default='./sam_vit_h_4b8939.pth') 26 | parser.add_argument('--sam_type', type=str, default='vit_h') 27 | 28 | parser.add_argument('--lr', type=float, default=1e-3) 29 | parser.add_argument('--train_epoch', type=int, default=1000) 30 | parser.add_argument('--log_epoch', type=int, default=200) 31 | parser.add_argument('--ref_idx', type=str, default='00') 32 | 33 | args = parser.parse_args() 34 | return args 35 | 36 | 37 | def main(): 38 | 39 | args = get_arguments() 40 | print("Args:", args) 41 | 42 | images_path = args.data + '/Images/' 43 | masks_path = args.data + '/Annotations/' 44 | output_path = './outputs/' + args.outdir 45 | 46 | if not os.path.exists('./outputs/'): 47 | os.mkdir('./outputs/') 48 | 49 | for obj_name in os.listdir(images_path): 50 | if ".DS" not in obj_name: 51 | persam_f(args, obj_name, images_path, masks_path, output_path) 52 | 53 | 54 | def persam_f(args, obj_name, images_path, masks_path, output_path): 55 | 56 | print("\n------------> Segment " + obj_name) 57 | 58 | # Path preparation 59 | ref_image_path = os.path.join(images_path, obj_name, args.ref_idx + '.jpg') 60 | ref_mask_path = os.path.join(masks_path, obj_name, args.ref_idx + '.png') 61 | test_images_path = os.path.join(images_path, obj_name) 62 | 63 | output_path = os.path.join(output_path, obj_name) 64 | os.makedirs(output_path, exist_ok=True) 65 | 66 | # Load images and masks 67 | ref_image = cv2.imread(ref_image_path) 68 | ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB) 69 | 70 | ref_mask = cv2.imread(ref_mask_path) 71 | ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB) 72 | 73 | gt_mask = torch.tensor(ref_mask)[:, :, 0] > 0 74 | gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda() 75 | 76 | 77 | print("======> Load SAM" ) 78 | if args.sam_type == 'vit_h': 79 | sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth' 80 | sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda() 81 | elif args.sam_type == 'vit_t': 82 | sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt' 83 | device = "cuda" if torch.cuda.is_available() else "cpu" 84 | sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device) 85 | sam.eval() 86 | 87 | 88 | for name, param in 
sam.named_parameters(): 89 | param.requires_grad = False 90 | predictor = SamPredictor(sam) 91 | 92 | 93 | print("======> Obtain Self Location Prior" ) 94 | # Image features encoding 95 | ref_mask = predictor.set_image(ref_image, ref_mask) 96 | ref_feat = predictor.features.squeeze().permute(1, 2, 0) 97 | 98 | ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear") 99 | ref_mask = ref_mask.squeeze()[0] 100 | 101 | # Target feature extraction 102 | target_feat = ref_feat[ref_mask > 0] 103 | target_feat_mean = target_feat.mean(0) 104 | target_feat_max = torch.max(target_feat, dim=0)[0] 105 | target_feat = (target_feat_max / 2 + target_feat_mean / 2).unsqueeze(0) 106 | 107 | # Cosine similarity 108 | h, w, C = ref_feat.shape 109 | target_feat = target_feat / target_feat.norm(dim=-1, keepdim=True) 110 | ref_feat = ref_feat / ref_feat.norm(dim=-1, keepdim=True) 111 | ref_feat = ref_feat.permute(2, 0, 1).reshape(C, h * w) 112 | sim = target_feat @ ref_feat 113 | 114 | sim = sim.reshape(1, 1, h, w) 115 | sim = F.interpolate(sim, scale_factor=4, mode="bilinear") 116 | sim = predictor.model.postprocess_masks( 117 | sim, 118 | input_size=predictor.input_size, 119 | original_size=predictor.original_size).squeeze() 120 | 121 | # Positive location prior 122 | topk_xy, topk_label = point_selection(sim, topk=1) 123 | 124 | 125 | print('======> Start Training') 126 | # Learnable mask weights 127 | mask_weights = Mask_Weights().cuda() 128 | mask_weights.train() 129 | 130 | optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=args.lr, eps=1e-4) 131 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch) 132 | 133 | for train_idx in range(args.train_epoch): 134 | 135 | # Run the decoder 136 | masks, scores, logits, logits_high = predictor.predict( 137 | point_coords=topk_xy, 138 | point_labels=topk_label, 139 | multimask_output=True) 140 | logits_high = logits_high.flatten(1) 141 | 142 | # Weighted sum three-scale masks 143 | weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0) 144 | logits_high = logits_high * weights 145 | logits_high = logits_high.sum(0).unsqueeze(0) 146 | 147 | dice_loss = calculate_dice_loss(logits_high, gt_mask) 148 | focal_loss = calculate_sigmoid_focal_loss(logits_high, gt_mask) 149 | loss = dice_loss + focal_loss 150 | 151 | optimizer.zero_grad() 152 | loss.backward() 153 | optimizer.step() 154 | scheduler.step() 155 | 156 | if train_idx % args.log_epoch == 0: 157 | print('Train Epoch: {:} / {:}'.format(train_idx, args.train_epoch)) 158 | current_lr = scheduler.get_last_lr()[0] 159 | print('LR: {:.6f}, Dice_Loss: {:.4f}, Focal_Loss: {:.4f}'.format(current_lr, dice_loss.item(), focal_loss.item())) 160 | 161 | 162 | mask_weights.eval() 163 | weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0) 164 | weights_np = weights.detach().cpu().numpy() 165 | print('======> Mask weights:\n', weights_np) 166 | 167 | print('======> Start Testing') 168 | for test_idx in tqdm(range(len(os.listdir(test_images_path)))): 169 | 170 | # Load test image 171 | test_idx = '%02d' % test_idx 172 | test_image_path = test_images_path + '/' + test_idx + '.jpg' 173 | test_image = cv2.imread(test_image_path) 174 | test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB) 175 | 176 | # Image feature encoding 177 | predictor.set_image(test_image) 178 | test_feat = predictor.features.squeeze() 179 | 180 | # Cosine similarity 181 | C, h, w = test_feat.shape 182 | test_feat 
= test_feat / test_feat.norm(dim=0, keepdim=True) 183 | test_feat = test_feat.reshape(C, h * w) 184 | sim = target_feat @ test_feat 185 | 186 | sim = sim.reshape(1, 1, h, w) 187 | sim = F.interpolate(sim, scale_factor=4, mode="bilinear") 188 | sim = predictor.model.postprocess_masks( 189 | sim, 190 | input_size=predictor.input_size, 191 | original_size=predictor.original_size).squeeze() 192 | 193 | # Positive location prior 194 | topk_xy, topk_label = point_selection(sim, topk=1) 195 | 196 | # First-step prediction 197 | masks, scores, logits, logits_high = predictor.predict( 198 | point_coords=topk_xy, 199 | point_labels=topk_label, 200 | multimask_output=True) 201 | 202 | # Weighted sum three-scale masks 203 | logits_high = logits_high * weights.unsqueeze(-1) 204 | logit_high = logits_high.sum(0) 205 | mask = (logit_high > 0).detach().cpu().numpy() 206 | 207 | logits = logits * weights_np[..., None] 208 | logit = logits.sum(0) 209 | 210 | # Cascaded Post-refinement-1 211 | y, x = np.nonzero(mask) 212 | x_min = x.min() 213 | x_max = x.max() 214 | y_min = y.min() 215 | y_max = y.max() 216 | input_box = np.array([x_min, y_min, x_max, y_max]) 217 | masks, scores, logits, _ = predictor.predict( 218 | point_coords=topk_xy, 219 | point_labels=topk_label, 220 | box=input_box[None, :], 221 | mask_input=logit[None, :, :], 222 | multimask_output=True) 223 | best_idx = np.argmax(scores) 224 | 225 | # Cascaded Post-refinement-2 226 | y, x = np.nonzero(masks[best_idx]) 227 | x_min = x.min() 228 | x_max = x.max() 229 | y_min = y.min() 230 | y_max = y.max() 231 | input_box = np.array([x_min, y_min, x_max, y_max]) 232 | masks, scores, logits, _ = predictor.predict( 233 | point_coords=topk_xy, 234 | point_labels=topk_label, 235 | box=input_box[None, :], 236 | mask_input=logits[best_idx: best_idx + 1, :, :], 237 | multimask_output=True) 238 | best_idx = np.argmax(scores) 239 | 240 | # Save masks 241 | plt.figure(figsize=(10, 10)) 242 | plt.imshow(test_image) 243 | show_mask(masks[best_idx], plt.gca()) 244 | show_points(topk_xy, topk_label, plt.gca()) 245 | plt.title(f"Mask {best_idx}", fontsize=18) 246 | plt.axis('off') 247 | vis_mask_output_path = os.path.join(output_path, f'vis_mask_{test_idx}.jpg') 248 | with open(vis_mask_output_path, 'wb') as outfile: 249 | plt.savefig(outfile, format='jpg') 250 | 251 | final_mask = masks[best_idx] 252 | mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8) 253 | mask_colors[final_mask, :] = np.array([[0, 0, 128]]) 254 | mask_output_path = os.path.join(output_path, test_idx + '.png') 255 | cv2.imwrite(mask_output_path, mask_colors) 256 | 257 | 258 | class Mask_Weights(nn.Module): 259 | def __init__(self): 260 | super().__init__() 261 | self.weights = nn.Parameter(torch.ones(2, 1, requires_grad=True) / 3) 262 | 263 | 264 | def point_selection(mask_sim, topk=1): 265 | # Top-1 point selection 266 | w, h = mask_sim.shape 267 | topk_xy = mask_sim.flatten(0).topk(topk)[1] 268 | topk_x = (topk_xy // h).unsqueeze(0) 269 | topk_y = (topk_xy - topk_x * h) 270 | topk_xy = torch.cat((topk_y, topk_x), dim=0).permute(1, 0) 271 | topk_label = np.array([1] * topk) 272 | topk_xy = topk_xy.cpu().numpy() 273 | 274 | return topk_xy, topk_label 275 | 276 | 277 | def calculate_dice_loss(inputs, targets, num_masks = 1): 278 | """ 279 | Compute the DICE loss, similar to generalized IOU for masks 280 | Args: 281 | inputs: A float tensor of arbitrary shape. 282 | The predictions for each example. 283 | targets: A float tensor with the same shape as inputs. 
Stores the binary 284 | classification label for each element in inputs 285 | (0 for the negative class and 1 for the positive class). 286 | """ 287 | inputs = inputs.sigmoid() 288 | inputs = inputs.flatten(1) 289 | numerator = 2 * (inputs * targets).sum(-1) 290 | denominator = inputs.sum(-1) + targets.sum(-1) 291 | loss = 1 - (numerator + 1) / (denominator + 1) 292 | return loss.sum() / num_masks 293 | 294 | 295 | def calculate_sigmoid_focal_loss(inputs, targets, num_masks = 1, alpha: float = 0.25, gamma: float = 2): 296 | """ 297 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 298 | Args: 299 | inputs: A float tensor of arbitrary shape. 300 | The predictions for each example. 301 | targets: A float tensor with the same shape as inputs. Stores the binary 302 | classification label for each element in inputs 303 | (0 for the negative class and 1 for the positive class). 304 | alpha: (optional) Weighting factor in range (0,1) to balance 305 | positive vs negative examples. Default = -1 (no weighting). 306 | gamma: Exponent of the modulating factor (1 - p_t) to 307 | balance easy vs hard examples. 308 | Returns: 309 | Loss tensor 310 | """ 311 | prob = inputs.sigmoid() 312 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 313 | p_t = prob * targets + (1 - prob) * (1 - targets) 314 | loss = ce_loss * ((1 - p_t) ** gamma) 315 | 316 | if alpha >= 0: 317 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 318 | loss = alpha_t * loss 319 | 320 | return loss.mean(1).sum() / num_masks 321 | 322 | 323 | if __name__ == '__main__': 324 | main() 325 | -------------------------------------------------------------------------------- /persam_f_multi_obj.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | import os 7 | import cv2 8 | from tqdm import tqdm 9 | import argparse 10 | import matplotlib.pyplot as plt 11 | import warnings 12 | warnings.filterwarnings('ignore') 13 | 14 | from show import * 15 | from per_segment_anything import sam_model_registry, SamPredictor 16 | 17 | 18 | 19 | def get_arguments(): 20 | 21 | parser = argparse.ArgumentParser() 22 | 23 | parser.add_argument('--data', type=str, default='./data') 24 | parser.add_argument('--outdir', type=str, default='persam_f') 25 | parser.add_argument('--ckpt', type=str, default='./sam_vit_h_4b8939.pth') 26 | parser.add_argument('--sam_type', type=str, default='vit_h') 27 | 28 | parser.add_argument('--lr', type=float, default=1e-3) 29 | parser.add_argument('--train_epoch_outside', type=int, default=1) 30 | parser.add_argument('--train_epoch_inside', type=int, default=200) 31 | parser.add_argument('--log_epoch', type=int, default=200) 32 | parser.add_argument('--training_percentage', type=float, default=0.5) 33 | 34 | parser.add_argument('--max_objects', type=int, default=10) 35 | parser.add_argument('--iou_threshold', type=float, default=0.8) 36 | 37 | args = parser.parse_args() 38 | return args 39 | 40 | 41 | def main(): 42 | 43 | args = get_arguments() 44 | print("Args:", args) 45 | 46 | images_path = args.data + '/Images/' 47 | masks_path = args.data + '/Annotations/' 48 | output_path = './outputs/' + args.outdir 49 | 50 | 51 | 52 | if not os.path.exists('./outputs/'): 53 | os.mkdir('./outputs/') 54 | 55 | for obj_name in os.listdir(images_path): 56 | persam_f(args, obj_name, images_path, masks_path, output_path) 57 | 58 | 59 | def
persam_f(args, obj_name, images_path, masks_path, output_path): 60 | print("======> Load SAM" ) 61 | if args.sam_type == 'vit_h': 62 | sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth' 63 | sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda() 64 | elif args.sam_type == 'vit_t': 65 | sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt' 66 | device = "cuda" if torch.cuda.is_available() else "cpu" 67 | sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device) 68 | sam.eval() 69 | 70 | 71 | for name, param in sam.named_parameters(): 72 | param.requires_grad = False 73 | predictor = SamPredictor(sam) 74 | 75 | print("\n------------> Segment " + obj_name) 76 | for i in tqdm(range(args.train_epoch_outside)): 77 | output_path = os.path.join(output_path, obj_name) 78 | os.makedirs(output_path, exist_ok=True) 79 | training_size = int(len(os.listdir(os.path.join(images_path, obj_name))) * args.training_percentage) 80 | for ref_idx in range(training_size): 81 | # Path preparation 82 | ref_image_path = os.path.join(images_path, obj_name, '{:02}.jpg'.format(ref_idx)) 83 | ref_mask_path = os.path.join(masks_path, obj_name, '{:02}.png'.format(ref_idx)) 84 | test_images_path = os.path.join(images_path, obj_name) 85 | 86 | # Load images and masks 87 | ref_image = cv2.imread(ref_image_path) 88 | ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB) 89 | 90 | ref_mask = cv2.imread(ref_mask_path) 91 | ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB) 92 | 93 | gt_mask = torch.tensor(ref_mask)[:, :, 0] > 0 94 | gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda() 95 | 96 | # print("======> Obtain Self Location Prior" ) 97 | # Image features encoding 98 | ref_mask = predictor.set_image(ref_image, ref_mask) 99 | ref_feat = predictor.features.squeeze().permute(1, 2, 0) 100 | 101 | ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear") 102 | ref_mask = ref_mask.squeeze()[0] 103 | 104 | # Target feature extraction 105 | target_feat = ref_feat[ref_mask > 0] 106 | target_feat_mean = target_feat.mean(0) 107 | target_feat_max = torch.max(target_feat, dim=0)[0] 108 | target_feat = (target_feat_max / 2 + target_feat_mean / 2).unsqueeze(0) 109 | 110 | # Cosine similarity 111 | h, w, C = ref_feat.shape 112 | target_feat = target_feat / target_feat.norm(dim=-1, keepdim=True) 113 | ref_feat = ref_feat / ref_feat.norm(dim=-1, keepdim=True) 114 | ref_feat = ref_feat.permute(2, 0, 1).reshape(C, h * w) 115 | sim = target_feat @ ref_feat 116 | 117 | sim = sim.reshape(1, 1, h, w) 118 | sim = F.interpolate(sim, scale_factor=4, mode="bilinear") 119 | sim = predictor.model.postprocess_masks( 120 | sim, 121 | input_size=predictor.input_size, 122 | original_size=predictor.original_size).squeeze() 123 | 124 | # Positive location prior 125 | topk_xy, topk_label = point_selection(sim, topk=1) 126 | 127 | 128 | # print('======> Start Training') 129 | # Learnable mask weights 130 | mask_weights = Mask_Weights().cuda() 131 | mask_weights.train() 132 | 133 | optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=args.lr, eps=1e-4) 134 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch_inside) 135 | 136 | for train_idx in range(args.train_epoch_inside): 137 | 138 | # Run the decoder 139 | masks, scores, logits, logits_high = predictor.predict( 140 | point_coords=topk_xy, 141 | point_labels=topk_label, 142 | multimask_output=True) 143 | logits_high = logits_high.flatten(1) 144 | 145 | # Weighted sum three-scale masks 146 | weights = torch.cat((1 
- mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0) 147 | logits_high = logits_high * weights 148 | logits_high = logits_high.sum(0).unsqueeze(0) 149 | 150 | dice_loss = calculate_dice_loss(logits_high, gt_mask) 151 | focal_loss = calculate_sigmoid_focal_loss(logits_high, gt_mask) 152 | loss = dice_loss + focal_loss 153 | 154 | optimizer.zero_grad() 155 | loss.backward() 156 | optimizer.step() 157 | scheduler.step() 158 | # print('Train Epoch: {:} / {:}'.format(train_idx, args.train_epoch_inside)) 159 | current_lr = scheduler.get_last_lr()[0] 160 | 161 | 162 | mask_weights.eval() 163 | weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0) 164 | weights_np = weights.detach().cpu().numpy() 165 | # print('======> Mask weights:\n', weights_np) 166 | print('LR: {:.6f}, Dice_Loss: {:.4f}, Focal_Loss: {:.4f}'.format(current_lr, dice_loss.item(), focal_loss.item())) 167 | 168 | print('======> Start Testing') 169 | for test_idx in tqdm(range(len(os.listdir(test_images_path)))): 170 | 171 | # Load test image 172 | test_idx = '%02d' % test_idx 173 | test_image_path = test_images_path + '/' + test_idx + '.jpg' 174 | test_image = cv2.imread(test_image_path) 175 | test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB) 176 | test_image_original = cv2.imread(test_image_path) 177 | test_image_original = cv2.cvtColor(test_image_original, cv2.COLOR_BGR2RGB) 178 | 179 | history_masks = [] 180 | plt.figure(figsize=(10, 10)) 181 | for i in tqdm(range(args.max_objects)): 182 | # Image feature encoding 183 | predictor.set_image(test_image) 184 | test_feat = predictor.features.squeeze() 185 | 186 | # Cosine similarity 187 | C, h, w = test_feat.shape 188 | test_feat = test_feat / test_feat.norm(dim=0, keepdim=True) 189 | test_feat = test_feat.reshape(C, h * w) 190 | sim = target_feat @ test_feat 191 | 192 | sim = sim.reshape(1, 1, h, w) 193 | sim = F.interpolate(sim, scale_factor=4, mode="bilinear") 194 | sim = predictor.model.postprocess_masks( 195 | sim, 196 | input_size=predictor.input_size, 197 | original_size=predictor.original_size).squeeze() 198 | 199 | # Positive location prior 200 | topk_xy, topk_label = point_selection(sim, topk=1) 201 | 202 | # First-step prediction 203 | masks, scores, logits, logits_high = predictor.predict( 204 | point_coords=topk_xy, 205 | point_labels=topk_label, 206 | multimask_output=True) 207 | 208 | # Weighted sum three-scale masks 209 | logits_high = logits_high * weights.unsqueeze(-1) 210 | logit_high = logits_high.sum(0) 211 | mask = (logit_high > 0).detach().cpu().numpy() 212 | 213 | logits = logits * weights_np[..., None] 214 | logit = logits.sum(0) 215 | 216 | # Cascaded Post-refinement-1 217 | y, x = np.nonzero(mask) 218 | x_min = x.min() 219 | x_max = x.max() 220 | y_min = y.min() 221 | y_max = y.max() 222 | input_box = np.array([x_min, y_min, x_max, y_max]) 223 | masks, scores, logits, _ = predictor.predict( 224 | point_coords=topk_xy, 225 | point_labels=topk_label, 226 | box=input_box[None, :], 227 | mask_input=logit[None, :, :], 228 | multimask_output=True) 229 | best_idx = np.argmax(scores) 230 | 231 | # Cascaded Post-refinement-2 232 | y, x = np.nonzero(masks[best_idx]) 233 | x_min = x.min() 234 | x_max = x.max() 235 | y_min = y.min() 236 | y_max = y.max() 237 | input_box = np.array([x_min, y_min, x_max, y_max]) 238 | masks, scores, logits, _ = predictor.predict( 239 | point_coords=topk_xy, 240 | point_labels=topk_label, 241 | box=input_box[None, :], 242 | mask_input=logits[best_idx: best_idx + 1, :, 
:], 243 | multimask_output=True) 244 | best_idx = np.argmax(scores) 245 | 246 | 247 | final_mask = masks[best_idx] 248 | mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8) 249 | mask_colors[final_mask, :] = np.array([[0, 0, 128]]) 250 | 251 | mask_bool = mask_colors.sum(axis=2) == 128 252 | test_image[mask_bool] = 0 253 | iou_over_threshold = False 254 | for h_mask in history_masks: 255 | if calculate_iou(h_mask, mask_colors) >= args.iou_threshold: 256 | iou_over_threshold = True 257 | break 258 | if iou_over_threshold: 259 | break 260 | show_mask(masks[best_idx], plt.gca()) 261 | show_points(topk_xy, topk_label, plt.gca()) 262 | history_masks.append(mask_colors) 263 | # Save masks 264 | 265 | plt.imshow(test_image_original) 266 | vis_mask_output_path = os.path.join(output_path, f'vis_mask_{test_idx}_objects:{len(history_masks)}.jpg') 267 | with open(vis_mask_output_path, 'wb') as outfile: 268 | plt.savefig(outfile, format='jpg') 269 | 270 | mask_output_path = os.path.join(output_path, test_idx + '.png') 271 | cv2.imwrite(mask_output_path, mask_colors) 272 | 273 | 274 | 275 | class Mask_Weights(nn.Module): 276 | def __init__(self): 277 | super().__init__() 278 | self.weights = nn.Parameter(torch.ones(2, 1, requires_grad=True) / 3) 279 | 280 | 281 | def point_selection(mask_sim, topk=1): 282 | # Top-1 point selection 283 | w, h = mask_sim.shape 284 | topk_xy = mask_sim.flatten(0).topk(topk)[1] 285 | topk_x = (topk_xy // h).unsqueeze(0) 286 | topk_y = (topk_xy - topk_x * h) 287 | topk_xy = torch.cat((topk_y, topk_x), dim=0).permute(1, 0) 288 | topk_label = np.array([1] * topk) 289 | topk_xy = topk_xy.cpu().numpy() 290 | 291 | return topk_xy, topk_label 292 | 293 | 294 | def calculate_dice_loss(inputs, targets, num_masks = 1): 295 | """ 296 | Compute the DICE loss, similar to generalized IOU for masks 297 | Args: 298 | inputs: A float tensor of arbitrary shape. 299 | The predictions for each example. 300 | targets: A float tensor with the same shape as inputs. Stores the binary 301 | classification label for each element in inputs 302 | (0 for the negative class and 1 for the positive class). 303 | """ 304 | inputs = inputs.sigmoid() 305 | inputs = inputs.flatten(1) 306 | numerator = 2 * (inputs * targets).sum(-1) 307 | denominator = inputs.sum(-1) + targets.sum(-1) 308 | loss = 1 - (numerator + 1) / (denominator + 1) 309 | return loss.sum() / num_masks 310 | 311 | 312 | def calculate_sigmoid_focal_loss(inputs, targets, num_masks = 1, alpha: float = 0.25, gamma: float = 2): 313 | """ 314 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 315 | Args: 316 | inputs: A float tensor of arbitrary shape. 317 | The predictions for each example. 318 | targets: A float tensor with the same shape as inputs. Stores the binary 319 | classification label for each element in inputs 320 | (0 for the negative class and 1 for the positive class). 321 | alpha: (optional) Weighting factor in range (0,1) to balance 322 | positive vs negative examples. Default = -1 (no weighting). 323 | gamma: Exponent of the modulating factor (1 - p_t) to 324 | balance easy vs hard examples. 
325 | Returns: 326 | Loss tensor 327 | """ 328 | prob = inputs.sigmoid() 329 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 330 | p_t = prob * targets + (1 - prob) * (1 - targets) 331 | loss = ce_loss * ((1 - p_t) ** gamma) 332 | 333 | if alpha >= 0: 334 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 335 | loss = alpha_t * loss 336 | 337 | return loss.mean(1).sum() / num_masks 338 | 339 | def calculate_iou(mask1, mask2): 340 | """ 341 | Calculate the Intersection over Union (IoU) score between two masks. 342 | 343 | Args: 344 | mask1: The first mask as a n*m*3 matrix with the mask parts being 128. 345 | mask2: The second mask as a n*m*3 matrix with the mask parts being 128. 346 | 347 | Returns: 348 | iou: The IoU score between the two masks. 349 | """ 350 | 351 | mask1 = mask1.sum(axis=2) 352 | mask2 = mask2.sum(axis=2) 353 | 354 | mask1 = np.where(mask1 == 128, 1, 0) 355 | mask2 = np.where(mask2 == 128, 1, 0) 356 | intersection = np.sum(np.logical_and(mask1, mask2)) 357 | union = np.sum(np.logical_or(mask1, mask2)) 358 | iou = intersection / union 359 | return iou 360 | 361 | if __name__ == '__main__': 362 | main() 363 | -------------------------------------------------------------------------------- /persam_video.py: -------------------------------------------------------------------------------- 1 | import argparse, os 2 | from PIL import Image 3 | from os import path 4 | import numpy as np 5 | import torch 6 | from torch.nn import functional as F 7 | from torch.utils.data import DataLoader 8 | from per_segment_anything import SamPredictor, sam_model_registry 9 | from davis2017.davis import DAVISTestDataset, all_to_onehot 10 | from eval_video import eval_davis_result 11 | 12 | def main(args): 13 | if args.eval: 14 | eval_davis_result(args.output_path, args.davis_path) 15 | return 16 | 17 | # Dataset 18 | print("Running on DAVIS", args.dataset_set) 19 | test_dataset = DAVISTestDataset(args.davis_path, imset=args.dataset_set + '/val.txt') 20 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1) 21 | palette = Image.open(path.expanduser(os.path.join(args.davis_path, 'Annotations/480p/bike-packing/00000.png'))).getpalette() 22 | 23 | # Load SAM 24 | sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth' 25 | sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda() 26 | predictor = SamPredictor(sam) 27 | 28 | # Start eval 29 | for iter, data in enumerate(test_loader): 30 | rgb = data['rgb'].cpu().numpy() 31 | msk = data['gt'][0].cpu().numpy() 32 | info = data['info'] 33 | name = info['name'][0] 34 | os.makedirs(args.output_path, exist_ok=True) 35 | L = os.listdir(args.output_path) 36 | print("Processing Video", name, "....") 37 | if name in L: 38 | print("File", name, "exists in", args.output_path, ", skip...") 39 | continue 40 | num_obj = len(info['labels'][0]) 41 | 42 | frame_num = rgb.shape[1] 43 | 44 | save_path = args.output_path + '/{}/'.format(name) 45 | os.makedirs(save_path, exist_ok=True) 46 | first_frame_image = rgb[0, 0] 47 | first_frame_mask = msk[:, 0] * args.exp 48 | 49 | fore_feat_list = [] 50 | # Foreground features 51 | input_boxes = [] 52 | for k in range(msk[:, 0].shape[0]): 53 | input_boxes.append(msk[:, 0][k]) 54 | for obj in range(num_obj): 55 | print("Processing Object", obj) 56 | frame_image = first_frame_image 57 | 58 | obj_mask = first_frame_mask[obj].reshape(first_frame_mask.shape[1], first_frame_mask.shape[2], 1) 59 | obj_mask = np.concatenate((obj_mask, np.zeros((obj_mask.shape[0], 
obj_mask.shape[1], 2), dtype=obj_mask.dtype)), axis=2) 60 | obj_mask = predictor.set_image(frame_image, obj_mask) 61 | if obj == 0: 62 | img_feat1 = predictor.features.squeeze().permute(1, 2, 0) 63 | obj_mask = F.interpolate(obj_mask, size=img_feat1.shape[0:2], mode="bilinear") 64 | obj_mask = obj_mask.squeeze()[0] 65 | 66 | fore_feat = img_feat1[obj_mask > 0] 67 | 68 | if fore_feat.shape[0] == 0: 69 | fore_feat_list.append(fore_feat.mean(0)) 70 | print("Find a small object in", name, "Object", obj) 71 | continue 72 | 73 | fore_feat_mean = fore_feat.mean(0) 74 | fore_feat_max = torch.max(fore_feat, dim=0)[0] 75 | fore_feat = (fore_feat_max / 2 + fore_feat_mean / 2).unsqueeze(0) 76 | fore_feat = fore_feat / fore_feat.norm(dim=-1, keepdim=True) 77 | fore_feat_list.append(fore_feat) 78 | 79 | for i in range (1, frame_num): 80 | current_img = rgb[0, i] 81 | predictor.set_image(current_img) 82 | 83 | # pred masks 84 | test_feat = predictor.features.squeeze() 85 | C, htest, wtest = test_feat.shape 86 | 87 | test_feat = test_feat / test_feat.norm(dim=0, keepdim=True) 88 | test_feat = test_feat.reshape(C, htest * wtest) 89 | 90 | concat_mask = np.zeros((1, first_frame_mask.shape[1], first_frame_mask.shape[2]), dtype=np.uint8) 91 | for j in range(min(len(fore_feat_list), len(input_boxes))): 92 | # Cosine similarity 93 | fore_feat = fore_feat_list[j] 94 | sim = fore_feat @ test_feat # 1, h*w 95 | sim = sim.reshape(1, 1, htest, wtest) 96 | sim = F.interpolate(sim, scale_factor=4, mode="bilinear") 97 | 98 | mask_sim = predictor.model.postprocess_masks( 99 | sim, 100 | input_size=predictor.input_size, 101 | original_size=predictor.original_size).squeeze() 102 | 103 | # Top-k point selection 104 | w, h = mask_sim.shape 105 | 106 | topk_xy_i, topk_label_i = point_selection(mask_sim, topk=args.topk) 107 | topk_xy = topk_xy_i 108 | topk_label = topk_label_i 109 | 110 | if args.center: 111 | topk_label = np.concatenate([topk_label, [1]], axis=0) 112 | 113 | if args.box_prompt: 114 | center, input_box_ = get_box_prompt(input_boxes[j], args.threshold) 115 | if args.center: 116 | topk_xy = np.concatenate((topk_xy, center), axis=0) 117 | 118 | masks, scores, logits, _ = predictor.predict( 119 | point_coords=topk_xy, 120 | point_labels=topk_label, 121 | box=input_box_[None, :], 122 | multimask_output=True) 123 | else: 124 | masks, scores, logits, _ = predictor.predict( 125 | point_coords=topk_xy, 126 | point_labels=topk_label, 127 | multimask_output=True) 128 | if args.large: 129 | masks_ = masks 130 | mask_num = np.array([np.sum(masks_[0]), np.sum(masks_[1]), np.sum(masks_[2])], dtype=np.uint8) 131 | ic_index = np.argmax(mask_num, axis=0).astype(np.uint8) 132 | else: 133 | ic_index = 0 134 | 135 | masks, scores, logits, _ = predictor.predict( 136 | point_coords=topk_xy, 137 | point_labels=topk_label, 138 | mask_input=logits[ic_index: ic_index + 1, :, :], 139 | multimask_output=True) 140 | ic_index = np.argmax(scores) 141 | 142 | # box refine 143 | y, x = np.nonzero(masks[ic_index]) 144 | x_min = x.min() 145 | x_max = x.max() 146 | y_min = y.min() 147 | y_max = y.max() 148 | input_box = np.array([x_min, y_min, x_max, y_max]) 149 | masks, scores, logits, _ = predictor.predict( 150 | point_coords=topk_xy, 151 | point_labels=topk_label, 152 | box=input_box[None, :], 153 | mask_input=logits[ic_index: ic_index + 1, :, :], 154 | multimask_output=True, 155 | return_logits=True) 156 | 157 | ic_index = np.argmax(scores) 158 | 159 | concat_mask = np.concatenate((concat_mask, masks[ic_index].reshape(1, masks.shape[1], 
masks.shape[2])), axis=0) 160 | 161 | current_mask_pred = np.argmax(concat_mask, axis=0).astype(np.uint8) 162 | output = Image.fromarray(current_mask_pred) 163 | output.putpalette(palette) 164 | output.save(save_path + '{:05d}.png'.format(i)) 165 | 166 | if args.box_prompt: 167 | cur_labels = np.unique(current_mask_pred) 168 | cur_labels = cur_labels[cur_labels!=0] 169 | input_boxes = all_to_onehot(current_mask_pred, cur_labels) 170 | 171 | print(f"Finish predict video: {name}") 172 | 173 | eval_davis_result(args.output_path, args.davis_path) 174 | 175 | def get_box_prompt(img, threshold): 176 | rows = np.any(img, axis=1) 177 | cols = np.any(img, axis=0) 178 | rmin, rmax = np.where(rows)[0][[0, -1]] 179 | cmin, cmax = np.where(cols)[0][[0, -1]] 180 | 181 | cmin = 0 if cmin - threshold <= 0 else cmin - threshold 182 | rmin = 0 if rmin - threshold <= 0 else rmin - threshold 183 | cmax = img.shape[1] if cmax + threshold >= img.shape[1] else cmax + threshold 184 | rmax = img.shape[0] if rmax + threshold >= img.shape[0] else rmax + threshold 185 | 186 | return np.array([[(cmin + cmax) // 2, (rmin + rmax) // 2]]), np.array([cmin,rmin,cmax,rmax]) # x1,y1,x2,y2 187 | 188 | def point_selection(mask_sim, topk=1): 189 | w, h = mask_sim.shape 190 | topk_xy = mask_sim.flatten(0).topk(topk)[1] 191 | topk_x = (topk_xy // h).unsqueeze(0) 192 | topk_y = (topk_xy - topk_x * h) 193 | topk_xy = torch.cat((topk_y, topk_x), dim=0).permute(1, 0) 194 | topk_label = np.array([1] * topk) 195 | topk_xy = topk_xy.cpu().numpy() 196 | return topk_xy, topk_label 197 | 198 | if __name__ == '__main__': 199 | parser = argparse.ArgumentParser() 200 | parser.add_argument("--output_path", type=str, help="output path", required=True) 201 | parser.add_argument('--davis_path', default='./DAVIS/2017') 202 | parser.add_argument("--dataset_set", type=str, help="2017", default='2017') 203 | parser.add_argument("--topk", type=int, help="choose topk points", default=2) 204 | parser.add_argument("--exp", type=int, help="expand mask value to", default=215) 205 | parser.add_argument("--threshold", type=int, help="the threshold for bounding box expansion", default=10) 206 | parser.add_argument("--eval", action="store_true", help="eval only") 207 | parser.add_argument("--box_prompt", action="store_true", help="whether use box prompt") 208 | parser.add_argument("--large", action="store_true", help="whether choose largest mask for prompting after stage 1") 209 | parser.add_argument("--center", action="store_true", help="whether prompt with center") 210 | parser.set_defaults(box_prompt=True) 211 | parser.set_defaults(large=True) 212 | parser.set_defaults(center=True) 213 | args = parser.parse_args() 214 | print(args) 215 | main(args) -------------------------------------------------------------------------------- /prepare_coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from PIL import Image 5 | import cv2 6 | import shutil 7 | import argparse 8 | 9 | def copy_file(src, dst): 10 | shutil.copy(src, dst) 11 | 12 | def coco2mask(coco_file_path, img_save_path): 13 | print('coco_file_path:', coco_file_path) 14 | print('img_save_path:', img_save_path) 15 | with open(coco_file_path, 'r') as f: 16 | coco_json = json.load(f) 17 | 18 | class_mapper = dict() 19 | for category in coco_json['categories']: 20 | class_mapper[category['id']] = category['name'] 21 | os.makedirs(os.path.join(img_save_path,category['name']), exist_ok=True) 22 | 
os.makedirs(os.path.join('data', 'Images', category['name']), exist_ok=True) 23 | for category in coco_json['categories']: 24 | print(category) 25 | idx = 0 26 | for image in coco_json['images']: 27 | height = image['height'] 28 | width = image['width'] 29 | mask = np.zeros((height, width), dtype=np.uint8) 30 | 31 | has_annots = False 32 | for annotation in coco_json['annotations']: 33 | if annotation['segmentation'] == []: 34 | continue 35 | if annotation['category_id'] == category['id'] and annotation['image_id'] == image['id']: 36 | has_annots = True 37 | seg = annotation['segmentation'] 38 | seg = np.array(seg).reshape((-1, 2)).astype(np.int32) 39 | mask = cv2.fillPoly(mask, [seg], 128) 40 | if has_annots: 41 | 42 | mask_img = np.zeros((height, width, 3), dtype=np.uint8) 43 | mask_img[:, :, 0] = mask 44 | mask_img[:, :, 1] = mask 45 | mask_img[:, :, 2] = mask 46 | img_save_name = os.path.join(img_save_path, class_mapper[category['id']], "{:02}".format(idx) + '.png') 47 | mask_img[:,:,1:] = 0 48 | mask_img = cv2.cvtColor(mask_img, cv2.COLOR_RGB2BGR) # convert the mask image from RGB to BGR before saving with cv2.imwrite 49 | copy_file(os.path.join('auto-sam-data/',image['file_name']), 50 | os.path.join('data', 'Images',class_mapper[category['id']], "{:02}".format(idx) + '.jpg') 51 | ) 52 | 53 | cv2.imwrite(img_save_name, mask_img) 54 | idx += 1 55 | 56 | def main(): 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('coco_path', type=str, help='path to JSON file') 59 | parser.add_argument('output_dir', type=str, help='output directory') 60 | args = parser.parse_args() 61 | 62 | coco2mask(os.path.join(args.coco_path, 'result.json'), args.output_dir) 63 | 64 | if __name__ == '__main__': 65 | main() 66 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | tqdm 3 | numpy 4 | opencv-python 5 | -------------------------------------------------------------------------------- /show.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import cv2 5 | 6 | 7 | 8 | def show_mask(mask, ax, random_color=False): 9 | if random_color: 10 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 11 | else: 12 | color = np.array([30/255, 144/255, 255/255, 0.4]) 13 | h, w = mask.shape[-2:] 14 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 15 | ax.imshow(mask_image) 16 | 17 | 18 | def show_points(coords, labels, ax, marker_size=375): 19 | pos_points = coords[labels==1] 20 | neg_points = coords[labels==0] 21 | ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 22 | ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 23 | 24 | 25 | def show_box(box, ax): 26 | x0, y0 = box[0], box[1] 27 | w, h = box[2] - box[0], box[3] - box[1] 28 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) -------------------------------------------------------------------------------- /weights/mobile_sam.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZrrSkywalker/Personalize-SAM/a7e87245156e0b8efcba10a15a745e699c679d9d/weights/mobile_sam.pt
--------------------------------------------------------------------------------