├── .gitignore ├── LICENSE ├── README.md ├── assets └── how-it-works.gif ├── environment.yaml ├── helpers ├── extract_embeddings.py └── generate_onnx.py ├── salt ├── dataset_explorer.py ├── display_utils.py ├── editor.py ├── interface.py ├── onnx_model.py └── utils.py └── segment_anything_annotator.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Models 132 | *.onnx 133 | 134 | # Dataset 135 | dataset/ 136 | models/ 137 | flyer_pages/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Anurag Ghosh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Segment Anything Labelling Tool (SALT) 2 | 3 | SALT uses the Segment Anything Model by Meta AI and adds a barebones interface to label images, saving the masks in COCO format. 4 | 5 | Under active development; apologies for rough edges and bugs. Use at your own risk. 6 | 7 | ## Installation 8 | 9 | 1. Install [Segment Anything](https://github.com/facebookresearch/segment-anything) on any machine with a GPU. (It need not be the labelling machine.) 10 | 2. Create a conda environment using `conda env create -f environment.yaml` on the labelling machine (it does not need a GPU). 11 | 3. (Optional) Install [coco-viewer](https://github.com/trsvchn/coco-viewer) to scroll through your annotations quickly. 12 | 13 | ## Usage 14 | 15 | 1. Set up your dataset in the format `<dataset_name>/images/*` and create an empty folder `<dataset_name>/embeddings`. 16 | - Annotations will be saved in `<dataset_name>/annotations.json` by default. 17 | 2. Copy the `helpers` scripts to the base folder of your `segment-anything` checkout. 18 | - Call `extract_embeddings.py` to extract embeddings for your images. 19 | - Call `generate_onnx.py` to generate `*.onnx` files in `models`. 20 | 3. Copy the generated `*.onnx` models into SALT's `models` folder. 21 | 4. Symlink your dataset into SALT's root folder as `<dataset_name>`. 22 | 5. Call `segment_anything_annotator.py` with arguments `--dataset-path <dataset_name>` and `--categories cat1,cat2,cat3`.
23 | - There are a few keybindings that make the annotation process faster. 24 | - Left-click on the object to add foreground points; right-click to add points outside the object boundary. 25 | - `n` adds the predicted mask to your annotations. (Add button) 26 | - `r` rejects the predicted mask. (Reset button) 27 | - `a` and `d` to cycle through images in your set. (Prev and Next) 28 | - `l` and `k` to increase and decrease the transparency of the other annotations. 29 | - `Ctrl + S` to save progress to the COCO-style annotations file. 30 | 6. Use [coco-viewer](https://github.com/trsvchn/coco-viewer) to view your annotations. 31 | - `python cocoviewer.py -i <dataset_name> -a <dataset_name>/annotations.json` 32 | 33 | ## Demo 34 | 35 | ![How it Works Gif!](https://github.com/anuragxel/salt/raw/main/assets/how-it-works.gif) 36 | 37 | ## Contributing 38 | 39 | Follow these guidelines to ensure that your contributions can be reviewed and merged. Help with improving the UI is especially welcome. 40 | 41 | If you have found a bug or have an idea for an improvement or new feature, please create an issue on GitHub. Before creating a new issue, please search existing issues to see if your issue has already been reported. 42 | 43 | When creating an issue, please include as much detail as possible, including steps to reproduce the issue if applicable. 44 | 45 | Create a pull request (PR) to the original repository. Please use the `black` formatter when making code changes. 46 | 47 | ## License 48 | 49 | The code is licensed under the MIT License. By contributing to SALT, you agree to license your contributions under the same license as the project. See LICENSE for more information. 50 | -------------------------------------------------------------------------------- /assets/how-it-works.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anuragxel/salt/e431b6ee3c938a4fe49f3cd1a759e555a9f2a81f/assets/how-it-works.gif -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: seg-tool 2 | channels: 3 | - defaults 4 | dependencies: 5 | - ca-certificates=2023.01.10 6 | - certifi=2022.12.7 7 | - libcxx=14.0.6 8 | - libffi=3.4.2 9 | - ncurses=6.4 10 | - openssl=1.1.1t 11 | - pip=23.0.1 12 | - python=3.8.16 13 | - readline=8.2 14 | - setuptools=65.6.3 15 | - sqlite=3.41.1 16 | - tk=8.6.12 17 | - wheel=0.38.4 18 | - xz=5.2.10 19 | - zlib=1.2.13 20 | - pip: 21 | - black==23.3.0 22 | - click==8.1.3 23 | - coloredlogs==15.0.1 24 | - contourpy==1.0.7 25 | - cycler==0.11.0 26 | - distinctipy==1.2.2 27 | - flatbuffers==23.3.3 28 | - fonttools==4.39.3 29 | - humanfriendly==10.0 30 | - imageio==2.27.0 31 | - importlib-resources==5.12.0 32 | - kiwisolver==1.4.4 33 | - lazy-loader==0.2 34 | - matplotlib==3.7.1 35 | - mpmath==1.3.0 36 | - mypy-extensions==1.0.0 37 | - networkx==3.1 38 | - numpy==1.24.2 39 | - onnxruntime==1.14.1 40 | - opencv-python==4.7.0.72 41 | - packaging==23.0 42 | - pathspec==0.11.1 43 | - pillow==9.5.0 44 | - platformdirs==3.2.0 45 | - protobuf==4.22.1 46 | - pycocotools==2.0.6 47 | - pyparsing==3.0.9 48 | - python-dateutil==2.8.2 49 | - pywavelets==1.4.1 50 | - scikit-image==0.20.0 51 | - scipy==1.9.1 52 | - simplification==0.6.7 53 | - six==1.16.0 54 | - sympy==1.11.1 55 | - tifffile==2023.3.21 56 | - tomli==2.0.1 57 | - typing-extensions==4.5.0 58 | - zipp==3.15.0 59 | prefix: /Users/athena/opt/anaconda3/envs/seg-tool 60 |
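Before launching the annotator, it can help to sanity-check the dataset layout described in step 1 of the README Usage section above. The snippet below is an optional, illustrative sketch (not part of SALT); `dataset_path` is a placeholder, and the check only mirrors what `extract_embeddings.py` and `DatasetExplorer` expect (`.jpg`/`.png` images with matching `.npy` embeddings).

```python
# Optional pre-flight check (not part of SALT): verify the <dataset_name>/images and
# <dataset_name>/embeddings layout before starting the annotator.
import os
import sys

dataset_path = "./my_dataset"  # placeholder: your symlinked dataset folder

images_dir = os.path.join(dataset_path, "images")
embeddings_dir = os.path.join(dataset_path, "embeddings")

image_names = sorted(
    f for f in os.listdir(images_dir) if f.endswith((".jpg", ".png"))
)
missing = [
    name for name in image_names
    if not os.path.exists(os.path.join(embeddings_dir, os.path.splitext(name)[0] + ".npy"))
]

if missing:
    print(f"{len(missing)} image(s) have no embedding yet, e.g. {missing[:3]}")
    print("Run helpers/extract_embeddings.py on the GPU machine first.")
    sys.exit(1)

print(f"Found {len(image_names)} images, all with embeddings.")
```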
-------------------------------------------------------------------------------- /helpers/extract_embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the Apache-2.0 license found in the LICENSE file in the root directory of segment_anything repository and source tree. 4 | # Adapted from onnx_model_example.ipynb in the segment_anything repository. 5 | # Please see the original notebook for more details and other examples and additional usage. 6 | import os 7 | import argparse 8 | import cv2 9 | from tqdm import tqdm 10 | import numpy as np 11 | from segment_anything import sam_model_registry, SamPredictor 12 | 13 | def main(checkpoint_path, model_type, device, images_folder, embeddings_folder): 14 | sam = sam_model_registry[model_type](checkpoint=checkpoint_path) 15 | sam.to(device=device) 16 | predictor = SamPredictor(sam) 17 | 18 | for image_name in tqdm(os.listdir(images_folder)): 19 | image_path = os.path.join(images_folder, image_name) 20 | image = cv2.imread(image_path) 21 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 22 | 23 | 24 | predictor.set_image(image) 25 | 26 | image_embedding = predictor.get_image_embedding().cpu().numpy() 27 | 28 | out_path = os.path.join(embeddings_folder, os.path.splitext(image_name)[0] + ".npy") 29 | np.save(out_path, image_embedding) 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--checkpoint-path", type=str, default="./sam_vit_h_4b8939.pth") 34 | parser.add_argument("--model_type", type=str, default="default") 35 | parser.add_argument("--device", type=str, default="cuda") 36 | parser.add_argument("--dataset-path", type=str, default="./example_dataset") 37 | args = parser.parse_args() 38 | 39 | checkpoint_path = args.checkpoint_path 40 | model_type = args.model_type 41 | device = args.device 42 | dataset_path = args.dataset_path 43 | 44 | images_folder = os.path.join(dataset_path, "images") 45 | embeddings_folder = os.path.join(dataset_path, "embeddings") 46 | if not os.path.exists(embeddings_folder): 47 | os.makedirs(embeddings_folder) 48 | 49 | main(checkpoint_path, model_type, device, images_folder, embeddings_folder) 50 | -------------------------------------------------------------------------------- /helpers/generate_onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the Apache-2.0 license found in the LICENSE file in the root directory of segment_anything repository and source tree. 4 | # Adapted from onnx_model_example.ipynb in the segment_anything repository. 5 | # Please see the original notebook for more details and other examples and additional usage. 
6 | import warnings 7 | import os, shutil 8 | import argparse 9 | 10 | from segment_anything import sam_model_registry, SamPredictor 11 | from segment_anything.utils.onnx import SamOnnxModel 12 | 13 | from onnxruntime.quantization import QuantType 14 | from onnxruntime.quantization.quantize import quantize_dynamic 15 | 16 | import cv2 17 | import torch 18 | 19 | def save_onnx_model(checkpoint, model_type, onnx_model_path, orig_im_size, opset_version, quantize = True): 20 | sam = sam_model_registry[model_type](checkpoint=checkpoint) 21 | 22 | onnx_model = SamOnnxModel(sam, return_single_mask=True) 23 | 24 | dynamic_axes = { 25 | "point_coords": {1: "num_points"}, 26 | "point_labels": {1: "num_points"}, 27 | } 28 | 29 | embed_dim = sam.prompt_encoder.embed_dim 30 | embed_size = sam.prompt_encoder.image_embedding_size 31 | mask_input_size = [4 * x for x in embed_size] 32 | dummy_inputs = { 33 | "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float), 34 | "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float), 35 | "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float), 36 | "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float), 37 | "has_mask_input": torch.tensor([1], dtype=torch.float), 38 | "orig_im_size": torch.tensor(orig_im_size, dtype=torch.float), 39 | } 40 | output_names = ["masks", "iou_predictions", "low_res_masks"] 41 | 42 | with warnings.catch_warnings(): 43 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) 44 | warnings.filterwarnings("ignore", category=UserWarning) 45 | with open(onnx_model_path, "wb") as f: 46 | torch.onnx.export( 47 | onnx_model, 48 | tuple(dummy_inputs.values()), 49 | f, 50 | export_params=True, 51 | verbose=False, 52 | opset_version=opset_version, 53 | do_constant_folding=True, 54 | input_names=list(dummy_inputs.keys()), 55 | output_names=output_names, 56 | dynamic_axes=dynamic_axes, 57 | ) 58 | 59 | if quantize: 60 | temp_model_path = os.path.join(os.path.split(onnx_model_path)[0], "temp.onnx") 61 | shutil.copy(onnx_model_path, temp_model_path) 62 | quantize_dynamic( 63 | model_input=temp_model_path, 64 | model_output=onnx_model_path, 65 | optimize_model=True, 66 | per_channel=False, 67 | reduce_range=False, 68 | weight_type=QuantType.QUInt8, 69 | ) 70 | os.remove(temp_model_path) 71 | 72 | def main(checkpoint_path, model_type, onnx_models_path, dataset_path, opset_version, quantize): 73 | if not os.path.exists(onnx_models_path): 74 | os.makedirs(onnx_models_path) 75 | 76 | images_path = os.path.join(dataset_path, "images") 77 | 78 | im_sizes = set() 79 | for image_path in os.listdir(images_path): 80 | if image_path.endswith(".jpg") or image_path.endswith(".png"): 81 | im_path = os.path.join(images_path, image_path) 82 | cv2_im = cv2.imread(im_path) 83 | im_sizes.add(cv2_im.shape[:2]) 84 | 85 | for orig_im_size in im_sizes: 86 | onnx_model_path = os.path.join(onnx_models_path, f"sam_onnx.{orig_im_size[0]}_{orig_im_size[1]}.onnx") 87 | save_onnx_model(checkpoint_path, model_type, onnx_model_path, orig_im_size, opset_version, quantize) 88 | 89 | if __name__ == "__main__": 90 | 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument("--checkpoint-path", type=str, default="./sam_vit_h_4b8939.pth") 93 | parser.add_argument("--model_type", type=str, default="default") 94 | parser.add_argument("--onnx-models-path", type=str, default="./models") 95 | parser.add_argument("--dataset-path", type=str, default="./dataset") 96 | 
parser.add_argument("--opset-version", type=int, default=15) 97 | parser.add_argument("--quantize", action="store_true") 98 | args = parser.parse_args() 99 | 100 | checkpoint_path = args.checkpoint_path 101 | model_type = args.model_type 102 | onnx_models_path = args.onnx_models_path 103 | dataset_path = args.dataset_path 104 | opset_version = args.opset_version 105 | quantize = args.quantize 106 | 107 | main(checkpoint_path, model_type, onnx_models_path, dataset_path, opset_version, quantize) -------------------------------------------------------------------------------- /salt/dataset_explorer.py: -------------------------------------------------------------------------------- 1 | from pycocotools import mask 2 | from skimage import measure 3 | import json 4 | import shutil 5 | import itertools 6 | import numpy as np 7 | from simplification.cutil import simplify_coords_vwp 8 | import os, cv2, copy 9 | from distinctipy import distinctipy 10 | 11 | 12 | def init_coco(dataset_folder, image_names, categories, coco_json_path): 13 | coco_json = { 14 | "info": { 15 | "description": "SAM Dataset", 16 | "url": "", 17 | "version": "1.0", 18 | "year": 2023, 19 | "contributor": "Sam", 20 | "date_created": "2021/07/01", 21 | }, 22 | "images": [], 23 | "annotations": [], 24 | "categories": [], 25 | } 26 | for i, category in enumerate(categories): 27 | coco_json["categories"].append( 28 | {"id": i, "name": category, "supercategory": category} 29 | ) 30 | for i, image_name in enumerate(image_names): 31 | im = cv2.imread(os.path.join(dataset_folder, image_name)) 32 | coco_json["images"].append( 33 | { 34 | "id": i, 35 | "file_name": image_name, 36 | "width": im.shape[1], 37 | "height": im.shape[0], 38 | } 39 | ) 40 | with open(coco_json_path, "w") as f: 41 | json.dump(coco_json, f) 42 | 43 | 44 | def bunch_coords(coords): 45 | coords_trans = [] 46 | for i in range(0, len(coords) // 2): 47 | coords_trans.append([coords[2 * i], coords[2 * i + 1]]) 48 | return coords_trans 49 | 50 | 51 | def unbunch_coords(coords): 52 | return list(itertools.chain(*coords)) 53 | 54 | 55 | def bounding_box_from_mask(mask): 56 | mask = mask.astype(np.uint8) 57 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 58 | all_contours = [] 59 | for contour in contours: 60 | all_contours.extend(contour) 61 | convex_hull = cv2.convexHull(np.array(all_contours)) 62 | x, y, w, h = cv2.boundingRect(convex_hull) 63 | return x, y, w, h 64 | 65 | 66 | def parse_mask_to_coco(image_id, anno_id, image_mask, category_id, poly=False): 67 | start_anno_id = anno_id 68 | x, y, width, height = bounding_box_from_mask(image_mask) 69 | if poly == False: 70 | fortran_binary_mask = np.asfortranarray(image_mask) 71 | encoded_mask = mask.encode(fortran_binary_mask) 72 | if poly == True: 73 | contours, _ = cv2.findContours(image_mask.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) 74 | annotation = { 75 | "id": start_anno_id, 76 | "image_id": image_id, 77 | "category_id": category_id, 78 | "bbox": [float(x), float(y), float(width), float(height)], 79 | "area": float(width * height), 80 | "iscrowd": 0, 81 | "segmentation": [], 82 | } 83 | if poly == False: 84 | annotation["segmentation"] = encoded_mask 85 | annotation["segmentation"]["counts"] = str( 86 | annotation["segmentation"]["counts"], "utf-8" 87 | ) 88 | if poly == True: 89 | for contour in contours: 90 | sc = simplify_coords_vwp(contour[:,0,:], 2).ravel().tolist() 91 | annotation["segmentation"].append(sc) 92 | return annotation 93 | 94 | 95 | class 
DatasetExplorer: 96 | def __init__(self, dataset_folder, categories=None, coco_json_path=None): 97 | self.dataset_folder = dataset_folder 98 | self.image_names = os.listdir(os.path.join(self.dataset_folder, "images")) 99 | self.image_names = [ 100 | os.path.split(name)[1] 101 | for name in self.image_names 102 | if name.endswith(".jpg") or name.endswith(".png") 103 | ] 104 | self.coco_json_path = coco_json_path 105 | if not os.path.exists(coco_json_path): 106 | self.__init_coco_json(categories) 107 | with open(coco_json_path, "r") as f: 108 | self.coco_json = json.load(f) 109 | 110 | self.categories = [ 111 | category["name"] for category in self.coco_json["categories"] 112 | ] 113 | self.annotations_by_image_id = {} 114 | for annotation in self.coco_json["annotations"]: 115 | image_id = annotation["image_id"] 116 | if image_id not in self.annotations_by_image_id: 117 | self.annotations_by_image_id[image_id] = [] 118 | self.annotations_by_image_id[image_id].append(annotation) 119 | 120 | # self.global_annotation_id = len(self.coco_json["annotations"]) 121 | try: 122 | self.global_annotation_id = ( 123 | max(self.coco_json["annotations"], key=lambda x: x["id"])["id"] + 1 124 | ) 125 | except: 126 | self.global_annotation_id = 0 127 | self.category_colors = distinctipy.get_colors(len(self.categories)) 128 | self.category_colors = [ 129 | tuple([int(255 * c) for c in color]) for color in self.category_colors 130 | ] 131 | 132 | def __init_coco_json(self, categories): 133 | appended_image_names = [ 134 | os.path.join("images", name) for name in self.image_names 135 | ] 136 | init_coco( 137 | self.dataset_folder, appended_image_names, categories, self.coco_json_path 138 | ) 139 | 140 | def get_colors(self, category_id): 141 | return self.category_colors[category_id] 142 | 143 | def get_categories(self, get_colors=False): 144 | if get_colors: 145 | return self.categories, self.category_colors 146 | return self.categories 147 | 148 | def get_num_images(self): 149 | return len(self.image_names) 150 | 151 | def get_image_data(self, image_id): 152 | image_name = self.coco_json["images"][image_id]["file_name"] 153 | image_path = os.path.join(self.dataset_folder, image_name) 154 | embedding_path = os.path.join( 155 | self.dataset_folder, 156 | "embeddings", 157 | os.path.splitext(os.path.split(image_name)[1])[0] + ".npy", 158 | ) 159 | image = cv2.imread(image_path) 160 | image_bgr = copy.deepcopy(image) 161 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 162 | image_embedding = np.load(embedding_path) 163 | return image, image_bgr, image_embedding 164 | 165 | def __add_to_our_annotation_dict(self, annotation): 166 | image_id = annotation["image_id"] 167 | if image_id not in self.annotations_by_image_id: 168 | self.annotations_by_image_id[image_id] = [] 169 | self.annotations_by_image_id[image_id].append(annotation) 170 | 171 | def get_annotations(self, image_id, return_colors=False): 172 | if image_id not in self.annotations_by_image_id: 173 | return [], [] 174 | cats = [a["category_id"] for a in self.annotations_by_image_id[image_id]] 175 | colors = [self.category_colors[c] for c in cats] 176 | if return_colors: 177 | return self.annotations_by_image_id[image_id], colors 178 | return self.annotations_by_image_id[image_id] 179 | 180 | def delete_annotations(self, image_id, annotation_id): 181 | for annotation in self.coco_json["annotations"]: 182 | if ( 183 | annotation["image_id"] == image_id and annotation["id"] == annotation_id 184 | ): # and annotation["id"] in annotation_ids: 185 | 
self.coco_json["annotations"].remove(annotation) 186 | break 187 | for annotation in self.annotations_by_image_id[image_id]: 188 | if annotation["id"] == annotation_id: 189 | self.annotations_by_image_id[image_id].remove(annotation) 190 | break 191 | 192 | def add_annotation(self, image_id, category_id, mask, poly=True): 193 | if mask is None: 194 | return 195 | annotation = parse_mask_to_coco( 196 | image_id, self.global_annotation_id, mask, category_id, poly=poly 197 | ) 198 | self.__add_to_our_annotation_dict(annotation) 199 | self.coco_json["annotations"].append(annotation) 200 | self.global_annotation_id += 1 201 | 202 | def save_annotation(self): 203 | with open(self.coco_json_path, "w") as f: 204 | json.dump(self.coco_json, f) 205 | -------------------------------------------------------------------------------- /salt/display_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from pycocotools import mask as coco_mask 4 | 5 | 6 | class DisplayUtils: 7 | def __init__(self): 8 | self.transparency = 0 9 | self.box_width = 2 10 | 11 | def increase_transparency(self): 12 | self.transparency = min(1.0, self.transparency + 0.05) 13 | 14 | def decrease_transparency(self): 15 | self.transparency = max(0.0, self.transparency - 0.05) 16 | 17 | def overlay_mask_on_image(self, image, mask, color=(255, 0, 0)): 18 | gray_mask = mask.astype(np.uint8) * 255 19 | gray_mask = cv2.merge([gray_mask, gray_mask, gray_mask]) 20 | color_mask = cv2.bitwise_and(gray_mask, color) 21 | masked_image = cv2.bitwise_and(image.copy(), color_mask) 22 | overlay_on_masked_image = cv2.addWeighted( 23 | masked_image, self.transparency, color_mask, 1 - self.transparency, 0 24 | ) 25 | background = cv2.bitwise_and(image.copy(), cv2.bitwise_not(color_mask)) 26 | image = cv2.add(background, overlay_on_masked_image) 27 | return image 28 | 29 | def __convert_ann_to_mask(self, ann, height, width): 30 | mask = np.zeros((height, width), dtype=np.uint8) 31 | poly = ann["segmentation"] 32 | rles = coco_mask.frPyObjects(poly, height, width) 33 | rle = coco_mask.merge(rles) 34 | mask_instance = coco_mask.decode(rle) 35 | mask_instance = np.logical_not(mask_instance) 36 | mask = np.logical_or(mask, mask_instance) 37 | mask = np.logical_not(mask) 38 | return mask 39 | 40 | def draw_box_on_image(self, image, ann, color): 41 | x, y, w, h = ann["bbox"] 42 | x, y, w, h = int(x), int(y), int(w), int(h) 43 | if color == (0, 0, 0): 44 | image = cv2.rectangle(image, (x, y), (x + w, y + h), color, -1) 45 | else: 46 | image = cv2.rectangle(image, (x, y), (x + w, y + h), color, self.box_width) 47 | image = cv2.putText( 48 | image, 49 | "id: " + str(ann["id"]), 50 | (x, y - 10), 51 | cv2.FONT_HERSHEY_SIMPLEX, 52 | 0.9, 53 | (0, 0, 0), 54 | 4, 55 | ) 56 | return image 57 | 58 | def draw_annotations(self, image, annotations, colors): 59 | for ann, color in zip(annotations, colors): 60 | image = self.draw_box_on_image(image, ann, color) 61 | mask = self.__convert_ann_to_mask(ann, image.shape[0], image.shape[1]) 62 | image = self.overlay_mask_on_image(image, mask, color) 63 | return image 64 | 65 | def draw_points( 66 | self, image, points, labels, colors={1: (0, 255, 0), 0: (0, 0, 255)}, radius=5 67 | ): 68 | for i in range(points.shape[0]): 69 | point = points[i, :] 70 | label = labels[i] 71 | color = colors[label] 72 | image = cv2.circle(image, tuple(point), radius, color, -1) 73 | return image 74 | 
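`DatasetExplorer` and `DisplayUtils` above can also be driven outside the Qt interface, for example to re-render saved annotations onto their images for a quick visual check. The following is a minimal sketch, assuming a dataset laid out as in the README (images plus extracted embeddings) and an existing `annotations.json`; paths and the output filename are placeholders.

```python
# Minimal sketch (not part of SALT's UI): render saved COCO annotations back onto an image.
# Assumes <dataset>/images, <dataset>/embeddings and <dataset>/annotations.json already exist.
import os
import cv2

from salt.dataset_explorer import DatasetExplorer
from salt.display_utils import DisplayUtils

dataset_path = "./my_dataset"  # placeholder
explorer = DatasetExplorer(
    dataset_path, coco_json_path=os.path.join(dataset_path, "annotations.json")
)
du = DisplayUtils()

image_id = 0
_, image_bgr, _ = explorer.get_image_data(image_id)  # also loads the per-image .npy embedding
annotations, colors = explorer.get_annotations(image_id, return_colors=True)
rendered = du.draw_annotations(image_bgr.copy(), annotations, colors)
cv2.imwrite("annotated_preview.png", rendered)  # placeholder output filename
```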
-------------------------------------------------------------------------------- /salt/editor.py: -------------------------------------------------------------------------------- 1 | import os, copy 2 | import numpy as np 3 | from salt.onnx_model import OnnxModels 4 | from salt.dataset_explorer import DatasetExplorer 5 | from salt.display_utils import DisplayUtils 6 | 7 | 8 | class CurrentCapturedInputs: 9 | def __init__(self): 10 | self.input_point = np.array([]) 11 | self.input_label = np.array([]) 12 | self.low_res_logits = None 13 | self.curr_mask = None 14 | 15 | def reset_inputs(self): 16 | self.input_point = np.array([]) 17 | self.input_label = np.array([]) 18 | self.low_res_logits = None 19 | self.curr_mask = None 20 | 21 | def set_mask(self, mask): 22 | self.curr_mask = mask 23 | 24 | def add_input_click(self, input_point, input_label): 25 | if len(self.input_point) == 0: 26 | self.input_point = np.array([input_point]) 27 | else: 28 | self.input_point = np.vstack([self.input_point, np.array([input_point])]) 29 | self.input_label = np.append(self.input_label, input_label) 30 | 31 | def set_low_res_logits(self, low_res_logits): 32 | self.low_res_logits = low_res_logits 33 | 34 | 35 | class Editor: 36 | def __init__( 37 | self, onnx_models_path, dataset_path, categories=None, coco_json_path=None 38 | ): 39 | self.dataset_path = dataset_path 40 | self.coco_json_path = coco_json_path 41 | if categories is None and not os.path.exists(coco_json_path): 42 | raise ValueError("categories must be provided if coco_json_path is None") 43 | if self.coco_json_path is None: 44 | self.coco_json_path = os.path.join(self.dataset_path, "annotations.json") 45 | self.dataset_explorer = DatasetExplorer( 46 | self.dataset_path, categories=categories, coco_json_path=self.coco_json_path 47 | ) 48 | self.curr_inputs = CurrentCapturedInputs() 49 | self.categories, self.category_colors = self.dataset_explorer.get_categories( 50 | get_colors=True 51 | ) 52 | self.image_id = 0 53 | self.category_id = 0 54 | self.show_other_anns = True 55 | ( 56 | self.image, 57 | self.image_bgr, 58 | self.image_embedding, 59 | ) = self.dataset_explorer.get_image_data(self.image_id) 60 | self.display = self.image_bgr.copy() 61 | self.onnx_helper = OnnxModels( 62 | onnx_models_path, 63 | image_width=self.image.shape[1], 64 | image_height=self.image.shape[0], 65 | ) 66 | self.du = DisplayUtils() 67 | self.reset() 68 | 69 | def list_annotations(self): 70 | anns, colors = self.dataset_explorer.get_annotations( 71 | self.image_id, return_colors=True 72 | ) 73 | return anns, colors 74 | 75 | def delete_annotations(self, annotation_id): 76 | self.dataset_explorer.delete_annotations(self.image_id, annotation_id) 77 | 78 | def __draw_known_annotations(self, selected_annotations=[]): 79 | anns, colors = self.dataset_explorer.get_annotations( 80 | self.image_id, return_colors=True 81 | ) 82 | for i, (ann, color) in enumerate(zip(anns, colors)): 83 | for selected_ann in selected_annotations: 84 | if ann["id"] == selected_ann: 85 | colors[i] = (0, 0, 0) 86 | # Use this to list the annotations 87 | self.display = self.du.draw_annotations(self.display, anns, colors) 88 | 89 | def __draw(self, selected_annotations=[]): 90 | self.display = self.image_bgr.copy() 91 | if self.curr_inputs.curr_mask is not None: 92 | self.display = self.du.draw_points( 93 | self.display, self.curr_inputs.input_point, self.curr_inputs.input_label 94 | ) 95 | self.display = self.du.overlay_mask_on_image( 96 | self.display, self.curr_inputs.curr_mask 97 | ) 98 | if 
self.show_other_anns: 99 | self.__draw_known_annotations(selected_annotations) 100 | 101 | def add_click(self, new_pt, new_label, selected_annotations=[]): 102 | self.curr_inputs.add_input_click(new_pt, new_label) 103 | masks, low_res_logits = self.onnx_helper.call( 104 | self.image, 105 | self.image_embedding, 106 | self.curr_inputs.input_point, 107 | self.curr_inputs.input_label, 108 | low_res_logits=self.curr_inputs.low_res_logits, 109 | ) 110 | self.curr_inputs.set_mask(masks[0, 0, :, :]) 111 | self.curr_inputs.set_low_res_logits(low_res_logits) 112 | self.__draw(selected_annotations) 113 | 114 | def remove_click(self, new_pt): 115 | print("ran remove click") 116 | 117 | def reset(self, hard=True, selected_annotations=[]): 118 | self.curr_inputs.reset_inputs() 119 | self.__draw(selected_annotations) 120 | 121 | def toggle(self, selected_annotations=[]): 122 | self.show_other_anns = not self.show_other_anns 123 | self.__draw(selected_annotations) 124 | 125 | def step_up_transparency(self, selected_annotations=[]): 126 | self.display = self.image_bgr.copy() 127 | self.du.increase_transparency() 128 | self.__draw(selected_annotations) 129 | 130 | def step_down_transparency(self, selected_annotations=[]): 131 | self.display = self.image_bgr.copy() 132 | self.du.decrease_transparency() 133 | self.__draw(selected_annotations) 134 | 135 | def draw_selected_annotations(self, selected_annotations=[]): 136 | self.__draw(selected_annotations) 137 | 138 | def save_ann(self): 139 | self.dataset_explorer.add_annotation( 140 | self.image_id, self.category_id, self.curr_inputs.curr_mask 141 | ) 142 | 143 | def save(self): 144 | self.dataset_explorer.save_annotation() 145 | 146 | def next_image(self): 147 | if self.image_id == self.dataset_explorer.get_num_images() - 1: 148 | return 149 | self.image_id += 1 150 | ( 151 | self.image, 152 | self.image_bgr, 153 | self.image_embedding, 154 | ) = self.dataset_explorer.get_image_data(self.image_id) 155 | self.display = self.image_bgr.copy() 156 | self.onnx_helper.set_image_resolution(self.image.shape[1], self.image.shape[0]) 157 | self.reset() 158 | 159 | def prev_image(self): 160 | if self.image_id == 0: 161 | return 162 | self.image_id -= 1 163 | ( 164 | self.image, 165 | self.image_bgr, 166 | self.image_embedding, 167 | ) = self.dataset_explorer.get_image_data(self.image_id) 168 | self.display = self.image_bgr.copy() 169 | self.onnx_helper.set_image_resolution(self.image.shape[1], self.image.shape[0]) 170 | self.reset() 171 | 172 | def next_category(self): 173 | if self.category_id == len(self.categories) - 1: 174 | self.category_id = 0 175 | return 176 | self.category_id += 1 177 | 178 | def prev_category(self): 179 | if self.category_id == 0: 180 | self.category_id = len(self.categories) - 1 181 | return 182 | self.category_id -= 1 183 | 184 | def get_categories(self, get_colors=False): 185 | if get_colors: 186 | return self.categories, self.category_colors 187 | return self.categories 188 | 189 | def select_category(self, category_name): 190 | category_id = self.categories.index(category_name) 191 | self.category_id = category_id 192 | -------------------------------------------------------------------------------- /salt/interface.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from PyQt5.QtWidgets import ( 3 | QScrollArea, 4 | QWidget, 5 | QVBoxLayout, 6 | QLabel, 7 | QGraphicsView, 8 | QGraphicsScene, 9 | QApplication, 10 | QListWidget, 11 | QListWidgetItem, 12 | QAbstractItemView, 13 | ) 14 | 
from PyQt5.QtGui import QImage, QPixmap, QPainter, QWheelEvent, QMouseEvent 15 | from PyQt5.QtCore import Qt, QRectF 16 | from PyQt5.QtWidgets import ( 17 | QPushButton, 18 | QVBoxLayout, 19 | QHBoxLayout, 20 | QWidget, 21 | QLabel, 22 | QRadioButton, 23 | ) 24 | 25 | selected_annotations = [] 26 | 27 | 28 | class CustomGraphicsView(QGraphicsView): 29 | def __init__(self, editor): 30 | super(CustomGraphicsView, self).__init__() 31 | 32 | self.editor = editor 33 | self.setRenderHint(QPainter.Antialiasing) 34 | self.setRenderHint(QPainter.SmoothPixmapTransform) 35 | self.setRenderHint(QPainter.TextAntialiasing) 36 | 37 | self.setOptimizationFlag(QGraphicsView.DontAdjustForAntialiasing, True) 38 | self.setOptimizationFlag(QGraphicsView.DontSavePainterState, True) 39 | self.setViewportUpdateMode(QGraphicsView.FullViewportUpdate) 40 | 41 | self.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded) 42 | self.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded) 43 | 44 | self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse) 45 | self.setResizeAnchor(QGraphicsView.AnchorUnderMouse) 46 | self.setInteractive(True) 47 | 48 | self.scene = QGraphicsScene(self) 49 | self.setScene(self.scene) 50 | 51 | self.image_item = None 52 | 53 | def set_image(self, q_img): 54 | pixmap = QPixmap.fromImage(q_img) 55 | if self.image_item: 56 | self.image_item.setPixmap(pixmap) 57 | self.setSceneRect(QRectF(pixmap.rect())) 58 | else: 59 | self.image_item = self.scene.addPixmap(pixmap) 60 | self.setSceneRect(QRectF(pixmap.rect())) 61 | 62 | def wheelEvent(self, event: QWheelEvent): 63 | modifiers = QApplication.keyboardModifiers() 64 | if modifiers == Qt.ControlModifier: 65 | adj = (event.angleDelta().y() / 120) * 0.1 66 | self.scale(1 + adj, 1 + adj) 67 | else: 68 | delta_y = event.angleDelta().y() 69 | delta_x = event.angleDelta().x() 70 | x = self.horizontalScrollBar().value() 71 | self.horizontalScrollBar().setValue(x - delta_x) 72 | y = self.verticalScrollBar().value() 73 | self.verticalScrollBar().setValue(y - delta_y) 74 | 75 | def imshow(self, img): 76 | height, width, channel = img.shape 77 | bytes_per_line = 3 * width 78 | q_img = QImage( 79 | img.data, width, height, bytes_per_line, QImage.Format_RGB888 80 | ).rgbSwapped() 81 | self.set_image(q_img) 82 | 83 | def mousePressEvent(self, event: QMouseEvent) -> None: 84 | # FUTURE USE OF RIGHT CLICK EVENT IN THIS AREA 85 | modifiers = QApplication.keyboardModifiers() 86 | if modifiers == Qt.ControlModifier: 87 | print("Control/ Command key pressed during a mouse click") 88 | # self.editor.remove_click([int(x), int(y)]) 89 | else: 90 | pos = event.pos() 91 | pos_in_item = self.mapToScene(pos) - self.image_item.pos() 92 | x, y = pos_in_item.x(), pos_in_item.y() 93 | if event.button() == Qt.LeftButton: 94 | label = 1 95 | elif event.button() == Qt.RightButton: 96 | label = 0 97 | self.editor.add_click([int(x), int(y)], label, selected_annotations) 98 | self.imshow(self.editor.display) 99 | 100 | 101 | class ApplicationInterface(QWidget): 102 | def __init__(self, app, editor, panel_size=(1920, 1080)): 103 | super(ApplicationInterface, self).__init__() 104 | self.app = app 105 | self.editor = editor 106 | self.panel_size = panel_size 107 | 108 | self.layout = QVBoxLayout() 109 | 110 | self.top_bar = self.get_top_bar() 111 | self.layout.addWidget(self.top_bar) 112 | 113 | self.main_window = QHBoxLayout() 114 | 115 | self.graphics_view = CustomGraphicsView(editor) 116 | self.main_window.addWidget(self.graphics_view) 117 | 118 | self.panel = self.get_side_panel() 119 | 
self.panel_annotations = QListWidget() 120 | self.panel_annotations.setFixedWidth(200) 121 | self.panel_annotations.setSelectionMode(QAbstractItemView.MultiSelection) 122 | self.panel_annotations.itemClicked.connect(self.annotation_list_item_clicked) 123 | self.get_side_panel_annotations() 124 | self.main_window.addWidget(self.panel) 125 | self.main_window.addWidget(self.panel_annotations) 126 | 127 | self.layout.addLayout(self.main_window) 128 | 129 | self.setLayout(self.layout) 130 | 131 | self.graphics_view.imshow(self.editor.display) 132 | 133 | def reset(self): 134 | global selected_annotations 135 | self.editor.reset(selected_annotations) 136 | self.graphics_view.imshow(self.editor.display) 137 | 138 | def add(self): 139 | global selected_annotations 140 | self.editor.save_ann() 141 | self.editor.reset(selected_annotations) 142 | self.graphics_view.imshow(self.editor.display) 143 | 144 | def next_image(self): 145 | global selected_annotations 146 | self.editor.next_image() 147 | selected_annotations = [] 148 | self.graphics_view.imshow(self.editor.display) 149 | 150 | def prev_image(self): 151 | global selected_annotations 152 | self.editor.prev_image() 153 | selected_annotations = [] 154 | self.graphics_view.imshow(self.editor.display) 155 | 156 | def toggle(self): 157 | global selected_annotations 158 | self.editor.toggle(selected_annotations) 159 | self.graphics_view.imshow(self.editor.display) 160 | 161 | def transparency_up(self): 162 | global selected_annotations 163 | self.editor.step_up_transparency(selected_annotations) 164 | self.graphics_view.imshow(self.editor.display) 165 | 166 | def transparency_down(self): 167 | self.editor.step_down_transparency(selected_annotations) 168 | self.graphics_view.imshow(self.editor.display) 169 | 170 | def save_all(self): 171 | self.editor.save() 172 | 173 | def get_top_bar(self): 174 | top_bar = QWidget() 175 | button_layout = QHBoxLayout(top_bar) 176 | self.layout.addLayout(button_layout) 177 | buttons = [ 178 | ("Add", lambda: self.add()), 179 | ("Reset", lambda: self.reset()), 180 | ("Prev", lambda: self.prev_image()), 181 | ("Next", lambda: self.next_image()), 182 | ("Toggle", lambda: self.toggle()), 183 | ("Transparency Up", lambda: self.transparency_up()), 184 | ("Transparency Down", lambda: self.transparency_down()), 185 | ("Save", lambda: self.save_all()), 186 | ( 187 | "Remove Selected Annotations", 188 | lambda: self.delete_annotations(), 189 | ), 190 | ] 191 | for button, lmb in buttons: 192 | bt = QPushButton(button) 193 | bt.clicked.connect(lmb) 194 | button_layout.addWidget(bt) 195 | 196 | return top_bar 197 | 198 | def get_side_panel(self): 199 | panel = QWidget() 200 | panel_layout = QVBoxLayout(panel) 201 | categories, colors = self.editor.get_categories(get_colors=True) 202 | label_array = [] 203 | for i, _ in enumerate(categories): 204 | label_array.append(QRadioButton(categories[i])) 205 | label_array[i].clicked.connect( 206 | lambda state, x=categories[i]: self.editor.select_category(x) 207 | ) 208 | label_array[i].setStyleSheet( 209 | "background-color: rgba({},{},{},0.6)".format(*colors[i][::-1]) 210 | ) 211 | panel_layout.addWidget(label_array[i]) 212 | 213 | scroll = QScrollArea() 214 | scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded) 215 | scroll.setWidget(panel) 216 | scroll.setFixedWidth(200) 217 | return scroll 218 | 219 | def get_side_panel_annotations(self): 220 | anns, colors = self.editor.list_annotations() 221 | list_widget = self.panel_annotations 222 | list_widget.clear() 223 | # anns, colors = 
self.editor.get_annotations(self.editor.image_id) 224 | categories = self.editor.get_categories(get_colors=False) 225 | for i, ann in enumerate(anns): 226 | listWidgetItem = QListWidgetItem( 227 | str(ann["id"]) + " - " + (categories[ann["category_id"]]) 228 | ) 229 | list_widget.addItem(listWidgetItem) 230 | return list_widget 231 | 232 | def delete_annotations(self): 233 | global selected_annotations 234 | for annotation in selected_annotations: 235 | self.editor.delete_annotations(annotation) 236 | self.get_side_panel_annotations() 237 | selected_annotations = [] 238 | self.reset() 239 | 240 | def annotation_list_item_clicked(self, item): 241 | global selected_annotations 242 | if item.isSelected(): 243 | selected_annotations.append(int(item.text().split(" ")[0])) 244 | self.editor.draw_selected_annotations(selected_annotations) 245 | else: 246 | selected_annotations.remove(int(item.text().split(" ")[0])) 247 | self.editor.draw_selected_annotations(selected_annotations) 248 | self.graphics_view.imshow(self.editor.display) 249 | 250 | def keyPressEvent(self, event): 251 | if event.key() == Qt.Key_Escape: 252 | self.app.quit() 253 | if event.key() == Qt.Key_A: 254 | self.prev_image() 255 | self.get_side_panel_annotations() 256 | if event.key() == Qt.Key_D: 257 | self.next_image() 258 | self.get_side_panel_annotations() 259 | if event.key() == Qt.Key_K: 260 | self.transparency_down() 261 | if event.key() == Qt.Key_L: 262 | self.transparency_up() 263 | if event.key() == Qt.Key_N: 264 | self.add() 265 | self.get_side_panel_annotations() 266 | if event.key() == Qt.Key_R: 267 | self.reset() 268 | if event.key() == Qt.Key_T: 269 | self.toggle() 270 | if event.modifiers() == Qt.ControlModifier and event.key() == Qt.Key_S: 271 | self.save_all() 272 | elif event.key() == Qt.Key_Space: 273 | print("Space pressed") 274 | # self.clear_annotations(selected_annotations) 275 | # Do something if the space bar is pressed 276 | # pass 277 | -------------------------------------------------------------------------------- /salt/onnx_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import onnxruntime 4 | 5 | from salt.utils import apply_coords 6 | 7 | 8 | def get_model_path_from_resolution(onnx_models_path, width, height): 9 | onnx_model_path = os.path.join(onnx_models_path, f"sam_onnx.{height}_{width}.onnx") 10 | return onnx_model_path 11 | 12 | 13 | class OnnxModels: 14 | def __init__( 15 | self, onnx_models_path, threshold=0.5, image_width=1920, image_height=1080 16 | ): 17 | self.onnx_models_path = onnx_models_path 18 | print(self.onnx_models_path) 19 | self.threshold = threshold 20 | self.set_image_resolution(image_width, image_height) 21 | 22 | def __init_model(self, onnx_model_path): 23 | self.ort_session = onnxruntime.InferenceSession( 24 | onnx_model_path, providers=["CPUExecutionProvider"] 25 | ) 26 | 27 | def set_image_resolution(self, width, height): 28 | self.image_width = width 29 | self.image_height = height 30 | onnx_model_path = get_model_path_from_resolution( 31 | self.onnx_models_path, width, height 32 | ) 33 | self.__init_model(onnx_model_path) 34 | 35 | def __translate_input( 36 | self, 37 | image, 38 | image_embedding, 39 | input_point, 40 | input_label, 41 | input_box=None, 42 | onnx_mask_input=None, 43 | ): 44 | if input_box is None: 45 | onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[ 46 | None, :, : 47 | ] 48 | onnx_label = np.concatenate([input_label, np.array([-1])], 
axis=0)[ 49 | None, : 50 | ].astype(np.float32) 51 | else: 52 | onnx_box_coords = input_box.reshape(2, 2) 53 | onnx_box_labels = np.array([2, 3]) 54 | onnx_coord = np.concatenate([input_point, onnx_box_coords], axis=0)[ 55 | None, :, : 56 | ] 57 | onnx_label = np.concatenate([input_label, onnx_box_labels], axis=0)[ 58 | None, : 59 | ].astype(np.float32) 60 | 61 | onnx_coord = apply_coords(onnx_coord, image.shape[:2]).astype(np.float32) 62 | if onnx_mask_input is None: 63 | onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32) 64 | onnx_has_mask_input = np.zeros(1, dtype=np.float32) 65 | else: 66 | onnx_has_mask_input = np.ones(1, dtype=np.float32) 67 | ort_inputs = { 68 | "image_embeddings": image_embedding, 69 | "point_coords": onnx_coord, 70 | "point_labels": onnx_label, 71 | "mask_input": onnx_mask_input, 72 | "has_mask_input": onnx_has_mask_input, 73 | "orig_im_size": np.array(image.shape[:2], dtype=np.float32), 74 | } 75 | return ort_inputs 76 | 77 | def call( 78 | self, 79 | image, 80 | image_embedding, 81 | input_point, 82 | input_label, 83 | selected_box=None, 84 | low_res_logits=None, 85 | ): 86 | onnx_mask_input = None 87 | input_box = None 88 | if low_res_logits is not None: 89 | onnx_mask_input = low_res_logits 90 | if selected_box is not None: 91 | input_box = selected_box 92 | ort_inputs = self.__translate_input( 93 | image, 94 | image_embedding, 95 | input_point, 96 | input_label, 97 | input_box=input_box, 98 | onnx_mask_input=onnx_mask_input, 99 | ) 100 | masks, _, low_res_logits = self.ort_session.run(None, ort_inputs) 101 | masks = masks > self.threshold 102 | return masks, low_res_logits 103 | -------------------------------------------------------------------------------- /salt/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | from typing import Tuple 4 | 5 | def get_preprocess_shape( 6 | oldh: int, oldw: int, long_side_length: int 7 | ) -> Tuple[int, int]: 8 | """ 9 | Compute the output size given input size and target long side length. 10 | """ 11 | scale = long_side_length * 1.0 / max(oldh, oldw) 12 | newh, neww = oldh * scale, oldw * scale 13 | neww = int(neww + 0.5) 14 | newh = int(newh + 0.5) 15 | return (newh, neww) 16 | 17 | 18 | def apply_coords(coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 19 | """ 20 | Expects a numpy array of length 2 in the final dimension. Requires the 21 | original image size in (H, W) format.
22 | """ 23 | old_h, old_w = original_size 24 | new_h, new_w = get_preprocess_shape(original_size[0], original_size[1], 1024) 25 | coords = deepcopy(coords).astype(float) 26 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 27 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 28 | return coords 29 | -------------------------------------------------------------------------------- /segment_anything_annotator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import sys 4 | 5 | from PyQt5.QtWidgets import QApplication 6 | 7 | from salt.editor import Editor 8 | from salt.interface import ApplicationInterface 9 | 10 | if __name__ == "__main__": 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--onnx-models-path", type=str, default="./models") 14 | parser.add_argument("--dataset-path", type=str, default="./dataset") 15 | parser.add_argument("--categories", type=str) 16 | args = parser.parse_args() 17 | 18 | onnx_models_path = args.onnx_models_path 19 | dataset_path = args.dataset_path 20 | categories = None 21 | if args.categories is not None: 22 | categories = args.categories.split(",") 23 | 24 | coco_json_path = os.path.join(dataset_path,"annotations.json") 25 | 26 | editor = Editor( 27 | onnx_models_path, 28 | dataset_path, 29 | categories=categories, 30 | coco_json_path=coco_json_path 31 | ) 32 | 33 | app = QApplication(sys.argv) 34 | window = ApplicationInterface(app, editor) 35 | window.show() 36 | sys.exit(app.exec_()) --------------------------------------------------------------------------------