├── .gitignore ├── LICENSE ├── README.md ├── Segment_Anything_Automatic_Mask_Generator_Demo.ipynb ├── Segment_Anything_Benchmarks.ipynb ├── Segment_Anything_multi_backend_Keras_Demo.ipynb ├── benchmark.png ├── requirements.txt ├── sam_keras ├── __init__.py ├── amg_utils.py ├── automatic_mask_generator.py ├── jax_nms.py ├── predictor.py └── prompter.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Tirth Patel 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Segment Anything Model in Multi-Backend Keras 2 | 3 | This is an implementation of the Segment Anything predictor and automatic mask 4 | generator in Keras 3. 
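At a glance, prompting the predictor with a single foreground point looks roughly like this (a condensed, illustrative sketch of what the demo and benchmark notebooks do; the backend, image file, and point coordinates below are just examples, and the package must be installed first as described in the next section):

```python
import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # or "jax" / "torch"

import cv2
import numpy as np
from keras import ops
from keras_cv.models import SegmentAnythingModel
from sam_keras import SAMPredictor

model = SegmentAnythingModel.from_preset("sam_huge_sa1b")
predictor = SAMPredictor(model)
transform = predictor.transform

# Load an RGB image and define a single foreground point prompt.
image = cv2.cvtColor(cv2.imread("truck.jpg"), cv2.COLOR_BGR2RGB)
point = ops.reshape(ops.convert_to_tensor([[500, 375]], dtype="float32"), (1, 1, 2))

image_record = {
    "image": ops.convert_to_tensor(
        transform.apply_image(image)[np.newaxis, ...], dtype="float32"
    ),
    "original_size": image.shape[:2],
    "point_coords": transform.apply_coords(point, image.shape[:2]),
    "point_labels": ops.convert_to_tensor(np.array([[1]]), dtype="float32"),
}

out = predictor.predict(image_record)
masks, iou_predictions = out["masks"], out["iou_predictions"]
```

Here `out["masks"]` and `out["iou_predictions"]` hold the predicted masks and their predicted quality scores, which is how the automatic mask generator in this package consumes the predictor's output.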
5 | 6 | The demos use KerasCV's Segment Anything model: 7 | 8 | - [Predictor demo](Segment_Anything_multi_backend_Keras_Demo.ipynb) 9 | - [Automatic Mask Generator demo](Segment_Anything_Automatic_Mask_Generator_Demo.ipynb) 10 | 11 | ## Install the package 12 | 13 | ```shell 14 | pip install git+https://github.com/tirthasheshpatel/segment_anything_keras.git 15 | ``` 16 | 17 | Install the required dependencies: 18 | 19 | ```shell 20 | pip install -U Pillow numpy keras keras-cv 21 | ``` 22 | 23 | Install TensorFlow, JAX, or PyTorch, whichever backend you'd like to use. 24 | 25 | To get all the dependencies and all the backends to run the demos, do: 26 | 27 | ```shell 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | ## Getting the pretrained Segment Anything Model 32 | 33 | ```python 34 | # Use the TensorFlow backend; choose any backend you want 35 | import os 36 | os.environ['KERAS_BACKEND'] = "tensorflow" 37 | 38 | from keras_cv.models import SegmentAnythingModel 39 | from sam_keras import SAMPredictor 40 | 41 | # Get the huge model trained on the SA-1B dataset. 42 | # Other available options are: 43 | # - "sam_base_sa1b" 44 | # - "sam_large_sa1b" 45 | model = SegmentAnythingModel.from_preset("sam_huge_sa1b") 46 | 47 | # Create the predictor 48 | predictor = SAMPredictor(model) 49 | 50 | # Now you can use the predictor just like the one in the original repo. 51 | # The only difference is that a list of input dicts isn't supported; instead, 52 | # pass each input dict separately to the `predict` method. 53 | ``` 54 | 55 | ## Notes 56 | 57 | Right now, JAX and TensorFlow have a large compile-time overhead: the prompt encoder 58 | recompiles each time a different combination of prompts (points only, 59 | points + boxes, boxes only, etc.) is passed. To avoid this, compile the model 60 | with `run_eagerly=True` and `jit_compile=False`. 61 | -------------------------------------------------------------------------------- /Segment_Anything_Benchmarks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "gpuType": "V100", 8 | "machine_shape": "hm" 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "source": [ 23 | "# Segment Anything Keras 3.0 port Benchmarks\n", 24 | "\n", 25 | "This notebook benchmarks the Segment Anything model for the TensorFlow, JAX, and PyTorch backends using Keras 3.0.\n", 26 | "\n", 27 | "There are three types of benchmarks:\n", 28 | "\n", 29 | "1. End-to-end model inference (`image_encoder + prompt_encoder + mask_decoder`)\n", 30 | "2. End-to-end model inference with pre and post-processing\n", 31 | "3. 
Prompt benchmarks (`prompt_encoder + mask_decoder` with image features set)" 32 | ], 33 | "metadata": { 34 | "id": "belFxJbxUSuK" 35 | } 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "source": [ 40 | "## Get all the dependencies and weight sets" 41 | ], 42 | "metadata": { 43 | "id": "3nS0Cb8hU3eV" 44 | } 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 1, 49 | "metadata": { 50 | "id": "wvTf8fxHSp5c", 51 | "colab": { 52 | "base_uri": "https://localhost:8080/" 53 | }, 54 | "outputId": "7c1c8dbe-dc5a-4e4b-fac4-6e09fdb2b569" 55 | }, 56 | "outputs": [ 57 | { 58 | "output_type": "stream", 59 | "name": "stdout", 60 | "text": [ 61 | "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", 62 | "tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.0.2 which is incompatible.\u001b[0m\u001b[31m\n", 63 | "\u001b[0m" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "# Get the dependencies\n", 69 | "!pip install -Uq keras-cv >> /dev/null\n", 70 | "!pip install -Uq keras >> /dev/null\n", 71 | "!pip install -Uq git+https://github.com/tirthasheshpatel/segment_anything_keras.git >> /dev/null\n", 72 | "\n", 73 | "# Get the image for the demo\n", 74 | "!curl -sSL https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg -o truck.jpg\n", 75 | "!curl -sSL https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/groceries.jpg -o groceries.jpg" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "source": [ 81 | "## Set the backend" 82 | ], 83 | "metadata": { 84 | "id": "wdSmBxslU8Xd" 85 | } 86 | }, 87 | { 88 | "cell_type": "code", 89 | "source": [ 90 | "import os\n", 91 | "os.environ['KERAS_BACKEND'] = \"torch\"" 92 | ], 93 | "metadata": { 94 | "id": "fXasdeUnS0pb" 95 | }, 96 | "execution_count": 2, 97 | "outputs": [] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "source": [ 102 | "## Choose the model" 103 | ], 104 | "metadata": { 105 | "id": "IfNiHXYLhsvQ" 106 | } 107 | }, 108 | { 109 | "cell_type": "code", 110 | "source": [ 111 | "model_type = \"huge\"" 112 | ], 113 | "metadata": { 114 | "id": "xykkspqNhuvS" 115 | }, 116 | "execution_count": 3, 117 | "outputs": [] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "source": [ 122 | "## Import Dependencies" 123 | ], 124 | "metadata": { 125 | "id": "vSxB0WbKVaGw" 126 | } 127 | }, 128 | { 129 | "cell_type": "code", 130 | "source": [ 131 | "import cv2\n", 132 | "import numpy as np\n", 133 | "import matplotlib.pyplot as plt\n", 134 | "import keras\n", 135 | "from keras import ops\n", 136 | "from keras_cv.models import SegmentAnythingModel\n", 137 | "from sam_keras import SAMPredictor" 138 | ], 139 | "metadata": { 140 | "id": "XHiX11_QS0rr" 141 | }, 142 | "execution_count": 4, 143 | "outputs": [] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "source": [ 148 | "## Define the model" 149 | ], 150 | "metadata": { 151 | "id": "YMKf6hiLVDqf" 152 | } 153 | }, 154 | { 155 | "cell_type": "code", 156 | "source": [ 157 | "sam = SegmentAnythingModel.from_preset(f\"sam_{model_type}_sa1b\")" 158 | ], 159 | "metadata": { 160 | "id": "38UUq9vzS0uY", 161 | "colab": { 162 | "base_uri": "https://localhost:8080/" 163 | }, 164 | "outputId": "eb01e05f-3272-4c2d-f4ba-1a4daa754365" 165 | }, 166 | "execution_count": 5, 167 | "outputs": [ 168 | { 169 | "output_type": "stream", 170 | "name": "stdout", 171 | "text": [ 172 | "Downloading 
data from https://storage.googleapis.com/keras-cv/models/segment_anything/sam_huge.h5\n", 173 | "\u001b[1m2564774344/2564774344\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m74s\u001b[0m 0us/step\n" 174 | ] 175 | } 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "source": [ 181 | "## End-to-End Model Inference with pre and post-processing" 182 | ], 183 | "metadata": { 184 | "id": "i1OZ0AaOVOul" 185 | } 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "source": [ 190 | "### Setup" 191 | ], 192 | "metadata": { 193 | "id": "tXn-pElWhQdw" 194 | } 195 | }, 196 | { 197 | "cell_type": "code", 198 | "source": [ 199 | "# Define predictor\n", 200 | "model = SAMPredictor(sam)\n", 201 | "transform = model.transform\n", 202 | "\n", 203 | "# Load the image\n", 204 | "image = cv2.imread('truck.jpg')\n", 205 | "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", 206 | "\n", 207 | "# Define the inputs\n", 208 | "input_point = np.array([[500, 375]])\n", 209 | "input_label = np.array([1])\n", 210 | "\n", 211 | "image_record = {}\n", 212 | "\n", 213 | "image_record[\"image\"] = ops.convert_to_tensor(\n", 214 | " transform.apply_image(image)[np.newaxis, ...],\n", 215 | " dtype=\"float32\"\n", 216 | ")\n", 217 | "\n", 218 | "image_record[\"original_size\"] = (image.shape[0], image.shape[1])\n", 219 | "\n", 220 | "image_record[\"point_coords\"] = ops.reshape(\n", 221 | " ops.convert_to_tensor(\n", 222 | " input_point, dtype=\"float32\"\n", 223 | " ),\n", 224 | " (1, 1, 2)\n", 225 | ")\n", 226 | "image_record[\"point_coords\"] = transform.apply_coords(\n", 227 | " image_record[\"point_coords\"], image_record[\"original_size\"]\n", 228 | ")\n", 229 | "\n", 230 | "image_record[\"point_labels\"] = ops.convert_to_tensor(\n", 231 | " input_label[np.newaxis, ...],\n", 232 | " dtype=\"float32\"\n", 233 | ")" 234 | ], 235 | "metadata": { 236 | "id": "Ps93S2vVTZie" 237 | }, 238 | "execution_count": 6, 239 | "outputs": [] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "source": [ 244 | "### Benchmark" 245 | ], 246 | "metadata": { 247 | "id": "UDqGr2i7hT-7" 248 | } 249 | }, 250 | { 251 | "cell_type": "code", 252 | "source": [ 253 | "# Dry run to build the model\n", 254 | "out = model.predict(image_record)" 255 | ], 256 | "metadata": { 257 | "colab": { 258 | "base_uri": "https://localhost:8080/" 259 | }, 260 | "id": "N0CmG5a5TZk8", 261 | "outputId": "9fd3b98d-c5cd-4d24-8175-46c5d89f1b26" 262 | }, 263 | "execution_count": 7, 264 | "outputs": [ 265 | { 266 | "output_type": "stream", 267 | "name": "stdout", 268 | "text": [ 269 | "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 3s/step\n" 270 | ] 271 | } 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "source": [ 277 | "# Predict also reports a time. 
Let's consider that too.\n", 278 | "out = model.predict(image_record)" 279 | ], 280 | "metadata": { 281 | "colab": { 282 | "base_uri": "https://localhost:8080/" 283 | }, 284 | "id": "6N_xMBuhRhtu", 285 | "outputId": "e1218f55-fd33-4888-e373-729751890ef4" 286 | }, 287 | "execution_count": 8, 288 | "outputs": [ 289 | { 290 | "output_type": "stream", 291 | "name": "stdout", 292 | "text": [ 293 | "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 598ms/step\n" 294 | ] 295 | } 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "source": [ 301 | "# Benchmark the model\n", 302 | "%timeit out = model.predict(image_record, verbose=0)" 303 | ], 304 | "metadata": { 305 | "colab": { 306 | "base_uri": "https://localhost:8080/" 307 | }, 308 | "id": "We_RwF7dWY63", 309 | "outputId": "b16161c9-6c03-4db6-a939-fe4cbdcc5753" 310 | }, 311 | "execution_count": 9, 312 | "outputs": [ 313 | { 314 | "output_type": "stream", 315 | "name": "stdout", 316 | "text": [ 317 | "609 ms ± 3.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" 318 | ] 319 | } 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "source": [ 325 | "## End-to-End Model Inference" 326 | ], 327 | "metadata": { 328 | "id": "AKE1PdWqYLYJ" 329 | } 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "source": [ 334 | "### Setup" 335 | ], 336 | "metadata": { 337 | "id": "Q4rm45Y7hXhF" 338 | } 339 | }, 340 | { 341 | "cell_type": "code", 342 | "source": [ 343 | "# Run the pre and post-processing steps here itself\n", 344 | "images = model.preprocess_images(image_record[\"image\"])\n", 345 | "points = ops.convert_to_tensor(\n", 346 | " image_record.get(\"point_coords\", ops.ones((1, 0, 2))),\n", 347 | " dtype=\"float32\"\n", 348 | ")\n", 349 | "labels = ops.convert_to_tensor(\n", 350 | " image_record.get(\"point_labels\", ops.ones((1, 0))),\n", 351 | " dtype=\"float32\"\n", 352 | ")\n", 353 | "box = ops.convert_to_tensor(\n", 354 | " image_record.get(\"boxes\", ops.ones((1, 0, 2, 2))),\n", 355 | " dtype=\"float32\"\n", 356 | ")\n", 357 | "mask = ops.convert_to_tensor(\n", 358 | " image_record.get(\"mask_inputs\", ops.ones((1, 0, 256, 256, 1))),\n", 359 | " dtype=\"float32\"\n", 360 | ")\n", 361 | "\n", 362 | "if ops.size(points) and not ops.size(box):\n", 363 | " pad_point = ops.zeros((points.shape[0], 1, 2), dtype=\"float32\")\n", 364 | " pad_label = -ops.ones((labels.shape[0], 1), dtype=\"float32\")\n", 365 | " points = ops.concatenate([points, pad_point], axis=1)\n", 366 | " labels = ops.concatenate([labels, pad_label], axis=1)\n", 367 | "\n", 368 | "B = max([\n", 369 | " images.shape[0],\n", 370 | " points.shape[0],\n", 371 | " labels.shape[0],\n", 372 | " box.shape[0],\n", 373 | " mask.shape[0],\n", 374 | "])\n", 375 | "\n", 376 | "images, points, labels, box, mask = model._broadcast_batch(\n", 377 | " B, images, points, labels, box, mask\n", 378 | ")\n", 379 | "\n", 380 | "model_input = {\n", 381 | " \"images\": images,\n", 382 | " \"points\": points,\n", 383 | " \"labels\": labels,\n", 384 | " \"boxes\": box,\n", 385 | " \"masks\": mask\n", 386 | "}" 387 | ], 388 | "metadata": { 389 | "id": "835jlxIyXkvv" 390 | }, 391 | "execution_count": 10, 392 | "outputs": [] 393 | }, 394 | { 395 | "cell_type": "markdown", 396 | "source": [ 397 | "### Benchmark" 398 | ], 399 | "metadata": { 400 | "id": "DoovDjh5hbax" 401 | } 402 | }, 403 | { 404 | "cell_type": "code", 405 | "source": [ 406 | "model.model.predict(model_input);" 407 | ], 408 | "metadata": { 409 | "colab": { 410 | 
"base_uri": "https://localhost:8080/" 411 | }, 412 | "id": "jprz_7nqXJ2V", 413 | "outputId": "fd3be700-ac13-45b8-a3dc-4dc45d837ab3" 414 | }, 415 | "execution_count": 11, 416 | "outputs": [ 417 | { 418 | "output_type": "stream", 419 | "name": "stdout", 420 | "text": [ 421 | "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 600ms/step\n" 422 | ] 423 | } 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "source": [ 429 | "%timeit model.model.predict(model_input, verbose=0)" 430 | ], 431 | "metadata": { 432 | "colab": { 433 | "base_uri": "https://localhost:8080/" 434 | }, 435 | "id": "U27ga7D8YO12", 436 | "outputId": "5f70ff29-715f-4061-e159-46d6046ab19d" 437 | }, 438 | "execution_count": 12, 439 | "outputs": [ 440 | { 441 | "output_type": "stream", 442 | "name": "stdout", 443 | "text": [ 444 | "608 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" 445 | ] 446 | } 447 | ] 448 | }, 449 | { 450 | "cell_type": "markdown", 451 | "source": [ 452 | "## Prompt Benchmarks" 453 | ], 454 | "metadata": { 455 | "id": "nvaEd-AlaCf2" 456 | } 457 | }, 458 | { 459 | "cell_type": "markdown", 460 | "source": [ 461 | "### Setup" 462 | ], 463 | "metadata": { 464 | "id": "h_mCbABphgGC" 465 | } 466 | }, 467 | { 468 | "cell_type": "code", 469 | "source": [ 470 | "# Set the features\n", 471 | "features = ops.convert_to_tensor(\n", 472 | " model.model.backbone.predict(model_input[\"images\"], verbose=0),\n", 473 | " dtype=\"float32\"\n", 474 | ")" 475 | ], 476 | "metadata": { 477 | "id": "eKgwlNw7YTem" 478 | }, 479 | "execution_count": 13, 480 | "outputs": [] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "source": [ 485 | "class SAMPrompter(keras.Model):\n", 486 | " def __init__(self, prompt_encoder, mask_decoder, feature_shape=(64, 64, 256), **kwargs):\n", 487 | " # Define the prompt encoder inputs -- Prompts\n", 488 | " prompt_inputs = {\n", 489 | " \"points\": keras.Input(shape=[None, 2], name=\"points\"),\n", 490 | " \"labels\": keras.Input(shape=[None], name=\"labels\"),\n", 491 | " \"boxes\": keras.Input(shape=[None, 2, 2], name=\"boxes\"),\n", 492 | " \"masks\": keras.Input(shape=[None, None, None, 1], name=\"masks\"),\n", 493 | " }\n", 494 | "\n", 495 | " # All Inputs -- Features + Prompts\n", 496 | " all_inputs = {\"features\": keras.Input(feature_shape, name=\"features\")}\n", 497 | " all_inputs.update(prompt_inputs)\n", 498 | "\n", 499 | " # Build the prompt encoder\n", 500 | " prompt_embeddings = prompt_encoder(prompt_inputs)\n", 501 | "\n", 502 | " # Define the mask decoder inputs\n", 503 | " mask_decoder_inputs = {\n", 504 | " \"image_embeddings\": all_inputs[\"features\"],\n", 505 | " \"image_pe\": prompt_embeddings[\"dense_positional_embeddings\"],\n", 506 | " \"sparse_prompt_embeddings\": prompt_embeddings[\"sparse_embeddings\"],\n", 507 | " \"dense_prompt_embeddings\": prompt_embeddings[\"dense_embeddings\"],\n", 508 | " }\n", 509 | "\n", 510 | " # Build the mask decoder\n", 511 | " outputs = mask_decoder(mask_decoder_inputs)\n", 512 | "\n", 513 | " super().__init__(inputs=all_inputs, outputs=outputs, **kwargs)\n", 514 | "\n", 515 | " self.prompt_encoder = prompt_encoder\n", 516 | " self.mask_decoder = mask_decoder" 517 | ], 518 | "metadata": { 519 | "id": "SZif2w3rZSJs" 520 | }, 521 | "execution_count": 14, 522 | "outputs": [] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "source": [ 527 | "prompter_model = SAMPrompter(model.model.prompt_encoder, model.model.mask_decoder, feature_shape=features.shape[1:])" 528 | ], 
529 | "metadata": { 530 | "id": "9xEISA_IgKa9" 531 | }, 532 | "execution_count": 15, 533 | "outputs": [] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "source": [ 538 | "prompt_inputs = {\n", 539 | " \"features\": features,\n", 540 | " \"points\": model_input[\"points\"],\n", 541 | " \"labels\": model_input[\"labels\"],\n", 542 | " \"boxes\": model_input[\"boxes\"],\n", 543 | " \"masks\": model_input[\"masks\"]\n", 544 | "}" 545 | ], 546 | "metadata": { 547 | "id": "UtwBbwv4gUa5" 548 | }, 549 | "execution_count": 16, 550 | "outputs": [] 551 | }, 552 | { 553 | "cell_type": "markdown", 554 | "source": [ 555 | "### Benchmark" 556 | ], 557 | "metadata": { 558 | "id": "nLRbgdeDhiyP" 559 | } 560 | }, 561 | { 562 | "cell_type": "code", 563 | "source": [ 564 | "# Dry run to build the model\n", 565 | "outs = prompter_model.predict(prompt_inputs)" 566 | ], 567 | "metadata": { 568 | "colab": { 569 | "base_uri": "https://localhost:8080/" 570 | }, 571 | "id": "ptDHJOtLgnuL", 572 | "outputId": "902c2cd4-62bd-4b42-e896-298c195ff601" 573 | }, 574 | "execution_count": 17, 575 | "outputs": [ 576 | { 577 | "output_type": "stream", 578 | "name": "stdout", 579 | "text": [ 580 | "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 38ms/step\n" 581 | ] 582 | } 583 | ] 584 | }, 585 | { 586 | "cell_type": "code", 587 | "source": [ 588 | "# Predict also reports a time. Let's also consider that.\n", 589 | "outs = prompter_model.predict(prompt_inputs)" 590 | ], 591 | "metadata": { 592 | "colab": { 593 | "base_uri": "https://localhost:8080/" 594 | }, 595 | "id": "DBkGykXKRykq", 596 | "outputId": "7bde8332-3572-4062-cbbf-49c942e37c2e" 597 | }, 598 | "execution_count": 18, 599 | "outputs": [ 600 | { 601 | "output_type": "stream", 602 | "name": "stdout", 603 | "text": [ 604 | "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 38ms/step\n" 605 | ] 606 | } 607 | ] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "source": [ 612 | "%timeit outs = prompter_model.predict(prompt_inputs, verbose=0)" 613 | ], 614 | "metadata": { 615 | "colab": { 616 | "base_uri": "https://localhost:8080/" 617 | }, 618 | "id": "9AGFkbYQg1Op", 619 | "outputId": "b0fe8765-a429-4d2d-fa96-15c4dc3a0021" 620 | }, 621 | "execution_count": 19, 622 | "outputs": [ 623 | { 624 | "output_type": "stream", 625 | "name": "stdout", 626 | "text": [ 627 | "39.4 ms ± 834 µs per loop (mean ± std. dev. 
of 7 runs, 10 loops each)\n" 628 | ] 629 | } 630 | ] 631 | }, 632 | { 633 | "cell_type": "code", 634 | "source": [], 635 | "metadata": { 636 | "id": "HbuSvHkQWGdc" 637 | }, 638 | "execution_count": 19, 639 | "outputs": [] 640 | } 641 | ] 642 | } -------------------------------------------------------------------------------- /benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tirthasheshpatel/segment_anything_keras/305eab30bc303e23e8890add837489598f6b67d0/benchmark.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow 2 | numpy 3 | tf-nightly 4 | torch 5 | torchvision 6 | torchaudio 7 | jaxlib 8 | jax 9 | "keras-cv==0.7.0.dev0" 10 | pycocotools 11 | opencv-python 12 | -------------------------------------------------------------------------------- /sam_keras/__init__.py: -------------------------------------------------------------------------------- 1 | # Author: Tirth Patel (tirthasheshpatel@gmail.com) 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | from sam_keras.predictor import SAMPredictor, ResizeLongestSide 10 | from sam_keras.automatic_mask_generator import SAMAutomaticMaskGenerator 11 | from sam_keras.prompter import SAMPrompter 12 | -------------------------------------------------------------------------------- /sam_keras/amg_utils.py: -------------------------------------------------------------------------------- 1 | # Author: Tirth Patel (tirthasheshpatel@gmail.com) 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import numpy as np 10 | 11 | from keras import ops 12 | 13 | import math 14 | from copy import deepcopy 15 | from itertools import product 16 | 17 | 18 | class MaskData: 19 | """ 20 | A structure for storing masks and their related data in batched format. 21 | Implements basic filtering and concatenation. 22 | """ 23 | 24 | def __init__(self, **kwargs): 25 | for v in kwargs.values(): 26 | if ( 27 | not isinstance(v, list) 28 | and not ops.is_tensor(v) 29 | and not isinstance(v, np.ndarray) 30 | ): 31 | raise ValueError( 32 | "`MaskData` only supports `list`, tensors, and numpy arrays." 33 | ) 34 | self._stats = dict(**kwargs) 35 | 36 | def __setitem__(self, key, item): 37 | if ( 38 | not isinstance(item, list) 39 | and not ops.is_tensor(item) 40 | and not isinstance(item, np.ndarray) 41 | ): 42 | raise ValueError( 43 | "`MaskData` only supports `list`, tensors, and numpy arrays." 
44 | ) 45 | self._stats[key] = item 46 | 47 | def __delitem__(self, key): 48 | del self._stats[key] 49 | 50 | def __getitem__(self, key): 51 | return self._stats[key] 52 | 53 | def items(self): 54 | return self._stats.items() 55 | 56 | def filter(self, keep): 57 | for k, v in self._stats.items(): 58 | if v is None: 59 | self._stats[k] = None 60 | elif ops.is_tensor(v): 61 | if "bool" in str(keep.dtype): 62 | self._stats[k] = v[keep] 63 | else: 64 | self._stats[k] = ops.take(v, keep, 0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = v[ops.convert_to_numpy(keep)] 67 | elif isinstance(v, list) and "bool" in str(keep.dtype): 68 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 69 | elif isinstance(v, list): 70 | self._stats[k] = [v[i] for i in keep] 71 | else: 72 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 73 | 74 | def cat(self, new_stats): 75 | for k, v in new_stats.items(): 76 | if k not in self._stats or self._stats[k] is None: 77 | self._stats[k] = deepcopy(v) 78 | elif ops.is_tensor(v): 79 | self._stats[k] = ops.concatenate([self._stats[k], v], axis=0) 80 | elif isinstance(v, np.ndarray): 81 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 82 | elif isinstance(v, list): 83 | self._stats[k] = self._stats[k] + v 84 | else: 85 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 86 | 87 | def to_numpy(self): 88 | for k, v in self._stats.items(): 89 | if ops.is_tensor(v): 90 | self._stats[k] = ops.convert_to_numpy(v) 91 | 92 | 93 | def _isclose(x1, x2, atol, rtol): 94 | x1 = ops.convert_to_numpy(x1) 95 | x2 = ops.convert_to_numpy(x2) 96 | return ops.convert_to_tensor(np.isclose(x1, x2, rtol=rtol, atol=atol)) 97 | 98 | 99 | def is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0): 100 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 101 | crop_box_torch = ops.convert_to_tensor(crop_box, dtype="float32") 102 | orig_box_torch = ops.convert_to_tensor(orig_box, dtype="float32") 103 | boxes = ops.cast(uncrop_boxes_xyxy(boxes, crop_box), dtype="float32") 104 | near_crop_edge = _isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 105 | near_image_edge = _isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 106 | near_crop_edge = near_crop_edge & (~near_image_edge) 107 | return ops.any(near_crop_edge, axis=1) 108 | 109 | 110 | def box_xyxy_to_xywh(box_xyxy, axis=-1): 111 | box_xyxy = ops.moveaxis(box_xyxy, axis, 0) 112 | box_xywh = ops.stack( 113 | [ 114 | box_xyxy[0], 115 | box_xyxy[1], 116 | box_xyxy[2] - box_xyxy[0], 117 | box_xyxy[3] - box_xyxy[1], 118 | ], 119 | axis=0, 120 | ) 121 | return ops.moveaxis(box_xywh, 0, axis) 122 | 123 | 124 | def box_xyxy_to_yxyx(box_xyxy, axis=-1): 125 | box_xyxy = ops.moveaxis(box_xyxy, axis, 0) 126 | box_yxyx = ops.stack( 127 | [ 128 | box_xyxy[1], 129 | box_xyxy[0], 130 | box_xyxy[3], 131 | box_xyxy[2], 132 | ], 133 | axis=0, 134 | ) 135 | return ops.moveaxis(box_yxyx, 0, axis) 136 | 137 | 138 | def batch_iterator(batch_size: int, *args): 139 | if not len(args) > 0 or not all(len(a) == len(args[0]) for a in args): 140 | raise ValueError("Batched iteration must have inputs of all the same size.") 141 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 142 | for b in range(n_batches): 143 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 144 | 145 | 146 | def mask_to_rle_tensor(tensor): 147 | """ 148 | Encodes masks to an uncompressed RLE, in the format expected by 149 | 
pycoco tools. 150 | """ 151 | tensor = ops.convert_to_numpy(tensor) 152 | # Put in fortran order and flatten h,w 153 | b, h, w = tensor.shape 154 | tensor = np.reshape(np.transpose(tensor, axes=(0, 2, 1)), (b, w * h)) 155 | 156 | # Compute change indices 157 | diff = tensor[:, 1:] ^ tensor[:, :-1] 158 | change_indices = np.stack(np.nonzero(diff), 1) 159 | 160 | # Encode run length 161 | out = [] 162 | for i in range(b): 163 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 164 | cur_idxs = np.concatenate( 165 | [ 166 | np.array([0], dtype=cur_idxs.dtype), 167 | cur_idxs + 1, 168 | np.array([h * w], dtype=cur_idxs.dtype), 169 | ] 170 | ) 171 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 172 | counts = [] if tensor[i, 0] == False else [0] 173 | counts.extend(list(btw_idxs)) 174 | out.append({"size": [h, w], "counts": counts}) 175 | return out 176 | 177 | 178 | def rle_to_mask(rle): 179 | """Compute a binary mask from an uncompressed RLE.""" 180 | h, w = rle["size"] 181 | mask = np.empty(h * w, dtype=bool) 182 | idx = 0 183 | parity = False 184 | for count in rle["counts"]: 185 | mask[idx : idx + count] = parity 186 | idx += count 187 | parity ^= True 188 | mask = mask.reshape(w, h) 189 | return mask.transpose() 190 | 191 | 192 | def area_from_rle(rle): 193 | return sum(rle["counts"][1::2]) 194 | 195 | 196 | def calculate_stability_score(masks, mask_threshold, threshold_offset): 197 | """ 198 | Computes the stability score for a batch of masks. The stability 199 | score is the IoU between the binary masks obtained by thresholding 200 | the predicted mask logits at high and low values. 201 | """ 202 | # One mask is always contained inside the other. 203 | # Save memory by preventing unnecessary cast to torch.int64 204 | intersections = ops.sum( 205 | ops.sum( 206 | ops.cast(masks > (mask_threshold + threshold_offset), dtype="float32"), -1 207 | ), 208 | -1, 209 | ) 210 | unions = ops.sum( 211 | ops.sum( 212 | ops.cast(masks > (mask_threshold - threshold_offset), dtype="float32"), -1 213 | ), 214 | -1, 215 | ) 216 | return intersections / unions 217 | 218 | 219 | def build_point_grid(n_per_side): 220 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 221 | offset = 1 / (2 * n_per_side) 222 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 223 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 224 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 225 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 226 | return points 227 | 228 | 229 | def build_all_layer_point_grids(n_per_side, n_layers, scale_per_layer): 230 | """Generates point grids for all crop layers.""" 231 | points_by_layer = [] 232 | for i in range(n_layers + 1): 233 | n_points = int(n_per_side / (scale_per_layer**i)) 234 | points_by_layer.append(build_point_grid(n_points)) 235 | return points_by_layer 236 | 237 | 238 | def generate_crop_boxes(im_size, n_layers, overlap_ratio): 239 | """ 240 | Generates a list of crop boxes of different sizes. Each layer 241 | has (2**i)**2 boxes for the ith layer. 
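    For example (illustrative), cropping is effectively disabled with n_layers=0
    and only the full-image box is generated, while n_layers=1 adds four
    overlapping crops on top of it:

        >>> generate_crop_boxes((480, 640), n_layers=0, overlap_ratio=512 / 1500)
        ([[0, 0, 640, 480]], [0])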
242 | """ 243 | crop_boxes, layer_idxs = [], [] 244 | im_h, im_w = im_size 245 | short_side = min(im_h, im_w) 246 | 247 | # Original image 248 | crop_boxes.append([0, 0, im_w, im_h]) 249 | layer_idxs.append(0) 250 | 251 | def crop_len(orig_len, n_crops, overlap): 252 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 253 | 254 | for i_layer in range(n_layers): 255 | n_crops_per_side = 2 ** (i_layer + 1) 256 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 257 | 258 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 259 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 260 | 261 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 262 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 263 | 264 | # Crops in XYWH format 265 | for x0, y0 in product(crop_box_x0, crop_box_y0): 266 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 267 | crop_boxes.append(box) 268 | layer_idxs.append(i_layer + 1) 269 | 270 | return crop_boxes, layer_idxs 271 | 272 | 273 | def uncrop_boxes_xyxy(boxes, crop_box): 274 | x0, y0, _, _ = crop_box 275 | boxes = ops.cast(boxes, "float32") 276 | offset = ops.convert_to_tensor([[x0, y0, x0, y0]], dtype="float32") 277 | # Check if boxes has a channel dimension 278 | if len(boxes.shape) == 3: 279 | offset = offset[:, None, ...] 280 | return boxes + offset 281 | 282 | 283 | def uncrop_points(points, crop_box): 284 | x0, y0, _, _ = crop_box 285 | points = ops.cast(points, "float32") 286 | offset = ops.convert_to_tensor([[x0, y0]], dtype="float32") 287 | # Check if points has a channel dimension 288 | if len(points.shape) == 3: 289 | offset = offset[:, None, ...] 290 | return points + offset 291 | 292 | 293 | def uncrop_masks(masks, crop_box, orig_h, orig_w): 294 | x0, y0, x1, y1 = crop_box 295 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 296 | return masks 297 | # Coordinate transform masks 298 | pad = [(0, 0)] * len(masks.shape[:-2]) 299 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 300 | pad = pad + [(y0, pad_y - y0), (x0, pad_x - x0)] 301 | return ops.pad(masks, pad) 302 | 303 | 304 | def remove_small_regions(mask, area_thresh, mode): 305 | """ 306 | Removes small disconnected regions and holes in a mask. Returns the 307 | mask and an indicator of if the mask has been modified. 
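    For example (illustrative), a one-pixel island below area_thresh is dropped
    while the large region is kept (any mode other than "holes" removes islands):

        >>> mask = np.zeros((10, 10), dtype=bool)
        >>> mask[2:8, 2:8] = True  # large region
        >>> mask[0, 0] = True      # tiny disconnected region
        >>> cleaned, modified = remove_small_regions(mask, area_thresh=4, mode="islands")
        >>> bool(cleaned[0, 0]), bool(cleaned[4, 4]), modified
        (False, True, True)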
308 | """ 309 | import cv2 310 | 311 | correct_holes = mode == "holes" 312 | working_mask = (correct_holes ^ mask).astype(np.uint8) 313 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 314 | sizes = stats[:, -1][1:] # Row 0 is background label 315 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 316 | if len(small_regions) == 0: 317 | return mask, False 318 | fill_labels = [0] + small_regions 319 | if not correct_holes: 320 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 321 | # If every region is below threshold, keep largest 322 | if len(fill_labels) == 0: 323 | fill_labels = [int(np.argmax(sizes)) + 1] 324 | mask = np.isin(regions, fill_labels) 325 | return mask, True 326 | 327 | 328 | def coco_encode_rle(uncompressed_rle): 329 | from pycocotools import mask as mask_utils 330 | 331 | h, w = uncompressed_rle["size"] 332 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 333 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 334 | return rle 335 | 336 | 337 | def batched_mask_to_box(masks): 338 | """ 339 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 340 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 341 | """ 342 | if ops.size(masks) == 0: 343 | return ops.zeros((*masks.shape[:-2], 4)) 344 | 345 | # Normalize shape to CxHxW 346 | shape = masks.shape 347 | h, w = shape[-2:] 348 | if len(shape) > 2: 349 | masks = ops.reshape(masks, (-1, *masks.shape[-2:])) 350 | else: 351 | masks = masks[None, ...] 352 | 353 | # Get top and bottom edges 354 | in_height = ops.max(masks, axis=-1) 355 | in_height_coords = ( 356 | ops.cast(in_height, "float32") * ops.arange(h, dtype="float32")[None, :] 357 | ) 358 | bottom_edges = ops.max(in_height_coords, axis=-1) 359 | in_height_coords = in_height_coords + h * ops.cast(~in_height, "float32") 360 | top_edges = ops.min(in_height_coords, axis=-1) 361 | 362 | # Get left and right edges 363 | in_width = ops.max(masks, axis=-2) 364 | in_width_coords = ( 365 | ops.cast(in_width, "float32") * ops.arange(w, dtype="float32")[None, :] 366 | ) 367 | right_edges = ops.max(in_width_coords, axis=-1) 368 | in_width_coords = in_width_coords + w * ops.cast(~in_width, "float32") 369 | left_edges = ops.min(in_width_coords, axis=-1) 370 | 371 | # If the mask is empty the right edge will be to the left of the left edge. 372 | # Replace these boxes with [0, 0, 0, 0] 373 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 374 | out = ops.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1) 375 | out = out * ops.cast((~empty_filter)[..., None], "float32") 376 | 377 | # Return to original shape 378 | if len(shape) > 2: 379 | out = ops.reshape(out, (*shape[:-2], 4)) 380 | else: 381 | out = out[0] 382 | 383 | return out 384 | -------------------------------------------------------------------------------- /sam_keras/automatic_mask_generator.py: -------------------------------------------------------------------------------- 1 | # Author: Tirth Patel (tirthasheshpatel@gmail.com) 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import numpy as np 10 | import keras 11 | from keras import ops 12 | 13 | # TODO: KerasCV made the non_max_suppression layer internal since 0.7.0 release. 
14 | # Instead of trying to access the internals, copy-paste the code for 15 | # non_max_suppression in this repo and use that instead. 16 | try: 17 | from keras_cv.layers.object_detection.non_max_suppression import non_max_suppression 18 | except ImportError: 19 | from keras_cv.src.layers.object_detection.non_max_suppression import ( 20 | non_max_suppression, 21 | ) 22 | 23 | from sam_keras.amg_utils import ( 24 | MaskData, 25 | area_from_rle, 26 | batch_iterator, 27 | batched_mask_to_box, 28 | box_xyxy_to_xywh, 29 | box_xyxy_to_yxyx, 30 | build_all_layer_point_grids, 31 | calculate_stability_score, 32 | coco_encode_rle, 33 | generate_crop_boxes, 34 | is_box_near_crop_edge, 35 | mask_to_rle_tensor, 36 | remove_small_regions, 37 | rle_to_mask, 38 | uncrop_boxes_xyxy, 39 | uncrop_masks, 40 | uncrop_points, 41 | ) 42 | 43 | 44 | __all__ = ["SAMAutomaticMaskGenerator"] 45 | 46 | 47 | def _box_area(boxes): 48 | return (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1]) 49 | 50 | 51 | def _batched_nms(boxes, scores, iou_threshold, max_output_size): 52 | if keras.config.backend() == "torch": 53 | from torchvision.ops import batched_nms 54 | 55 | idx = batched_nms( 56 | boxes, 57 | scores, 58 | ops.zeros_like(boxes[:, 0]), # categories 59 | iou_threshold=iou_threshold, 60 | ) 61 | del batched_nms 62 | elif keras.config.backend() == "tensorflow": 63 | import tensorflow as tf 64 | 65 | idx = tf.image.non_max_suppression( 66 | boxes=box_xyxy_to_yxyx(boxes), 67 | scores=scores, 68 | max_output_size=max_output_size, 69 | iou_threshold=iou_threshold, 70 | ) 71 | del tf 72 | elif keras.config.backend() == "jax": 73 | from sam_keras import jax_nms 74 | 75 | idx = jax_nms.non_max_suppression_padded( 76 | boxes=box_xyxy_to_yxyx(boxes)[None, ...], 77 | scores=scores[None, ...], 78 | max_output_size=max_output_size, 79 | iou_threshold=iou_threshold, 80 | ) 81 | del jax_nms 82 | else: 83 | idx, num_valid = non_max_suppression( 84 | boxes=box_xyxy_to_yxyx(boxes), 85 | scores=scores, 86 | max_output_size=max_output_size, 87 | iou_threshold=iou_threshold, 88 | ) 89 | idx = idx[0][:num_valid] 90 | return idx 91 | 92 | 93 | class SAMAutomaticMaskGenerator: 94 | def __init__( 95 | self, 96 | predictor, 97 | points_per_side=32, 98 | points_per_batch=64, 99 | pred_iou_thresh=0.88, 100 | stability_score_thresh=0.95, 101 | stability_score_offset=1.0, 102 | box_nms_thresh=0.7, 103 | crop_n_layers=0, 104 | crop_nms_thresh=0.7, 105 | crop_overlap_ratio=512 / 1500, 106 | crop_n_points_downscale_factor=1, 107 | point_grids=None, 108 | min_mask_region_area=0, 109 | output_mode="binary_mask", 110 | max_output_masks=100, 111 | ) -> None: 112 | """ 113 | Using a SAM model, generates masks for the entire image. 114 | Generates a grid of point prompts over the image, then filters 115 | low quality and duplicate masks. The default settings are chosen 116 | for SAM with a ViT-H backbone. 117 | 118 | Arguments: 119 | predictor (Sam): The SAM model to use for mask prediction. 120 | points_per_side (int or None): The number of points to be sampled 121 | along one side of the image. The total number of points is 122 | points_per_side**2. If None, 'point_grids' must provide explicit 123 | point sampling. 124 | points_per_batch (int): Sets the number of points run simultaneously 125 | by the model. Higher numbers may be faster but use more GPU memory. 126 | pred_iou_thresh (float): A filtering threshold in [0,1], using the 127 | model's predicted mask quality. 
128 | stability_score_thresh (float): A filtering threshold in [0,1], using 129 | the stability of the mask under changes to the cutoff used to binarize 130 | the model's mask predictions. 131 | stability_score_offset (float): The amount to shift the cutoff when 132 | calculated the stability score. 133 | box_nms_thresh (float): The box IoU cutoff used by non-maximal 134 | suppression to filter duplicate masks. 135 | crop_n_layers (int): If >0, mask prediction will be run again on 136 | crops of the image. Sets the number of layers to run, where each 137 | layer has 2**i_layer number of image crops. 138 | crop_nms_thresh (float): The box IoU cutoff used by non-maximal 139 | suppression to filter duplicate masks between different crops. 140 | crop_overlap_ratio (float): Sets the degree to which crops overlap. 141 | In the first crop layer, crops will overlap by this fraction of 142 | the image length. Later layers with more crops scale down this overlap. 143 | crop_n_points_downscale_factor (int): The number of points-per-side 144 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 145 | point_grids (list(np.ndarray) or None): A list over explicit grids 146 | of points used for sampling, normalized to [0,1]. The nth grid in the 147 | list is used in the nth crop layer. Exclusive with points_per_side. 148 | min_mask_region_area (int): If >0, postprocessing will be applied 149 | to remove disconnected regions and holes in masks with area smaller 150 | than min_mask_region_area. Requires opencv. 151 | output_mode (str): The form masks are returned in. Can be 'binary_mask', 152 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 153 | For large resolutions, 'binary_mask' may consume large amounts of 154 | memory. 155 | max_output_masks (int): Maximum number of masks to generate. 156 | """ 157 | 158 | if not ((points_per_side is None) ^ (point_grids is None)): 159 | raise ValueError( 160 | "Exactly one of points_per_side or point_grid must be provided." 161 | ) 162 | if points_per_side is not None: 163 | self.point_grids = build_all_layer_point_grids( 164 | points_per_side, 165 | crop_n_layers, 166 | crop_n_points_downscale_factor, 167 | ) 168 | elif point_grids is not None: 169 | self.point_grids = point_grids 170 | else: 171 | raise ValueError("Can't have both points_per_side and point_grid be None.") 172 | 173 | if output_mode not in [ 174 | "binary_mask", 175 | "uncompressed_rle", 176 | "coco_rle", 177 | ]: 178 | raise ValueError(f"Unknown output_mode {output_mode}.") 179 | if output_mode == "coco_rle": 180 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401 181 | 182 | if min_mask_region_area > 0: 183 | import cv2 # type: ignore # noqa: F401 184 | 185 | self.predictor = predictor 186 | self.points_per_batch = points_per_batch 187 | self.pred_iou_thresh = pred_iou_thresh 188 | self.stability_score_thresh = stability_score_thresh 189 | self.stability_score_offset = stability_score_offset 190 | self.box_nms_thresh = box_nms_thresh 191 | self.crop_n_layers = crop_n_layers 192 | self.crop_nms_thresh = crop_nms_thresh 193 | self.crop_overlap_ratio = crop_overlap_ratio 194 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor 195 | self.min_mask_region_area = min_mask_region_area 196 | self.output_mode = output_mode 197 | self.max_output_masks = max_output_masks 198 | 199 | def generate(self, image, **kwargs): 200 | """ 201 | Generates masks for the given image. 
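        A minimal illustrative sketch (assuming `predictor` is a `SAMPredictor`
        wrapping a pretrained KerasCV `SegmentAnythingModel`, as shown in the
        README):

            generator = SAMAutomaticMaskGenerator(predictor)
            masks = generator.generate(image)  # image: HWC uint8 RGB array
            best = max(masks, key=lambda m: m["predicted_iou"])
            print(best["bbox"], best["area"], best["stability_score"])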
202 | 203 | Arguments: 204 | image (np.ndarray): The image to generate masks for, in HWC uint8 format. 205 | 206 | Returns: 207 | list(dict(str, any)): A list over records for masks. Each record is 208 | a dict containing the following keys: 209 | segmentation (dict(str, any) or np.ndarray): The mask. If 210 | output_mode='binary_mask', is an array of shape HW. Otherwise, 211 | is a dictionary containing the RLE. 212 | bbox (list(float)): The box around the mask, in XYWH format. 213 | area (int): The area in pixels of the mask. 214 | predicted_iou (float): The model's own prediction of the mask's 215 | quality. This is filtered by the pred_iou_thresh parameter. 216 | point_coords (list(list(float))): The point coordinates input 217 | to the model to generate this mask. 218 | stability_score (float): A measure of the mask's quality. This 219 | is filtered on using the stability_score_thresh parameter. 220 | crop_box (list(float)): The crop of the image used to generate 221 | the mask, given in XYWH format. 222 | """ 223 | 224 | # Generate masks 225 | mask_data = self._generate_masks(image, **kwargs) 226 | 227 | # Filter small disconnected regions and holes in masks 228 | if self.min_mask_region_area > 0: 229 | mask_data = self.postprocess_small_regions( 230 | mask_data, 231 | self.min_mask_region_area, 232 | max(self.box_nms_thresh, self.crop_nms_thresh), 233 | ) 234 | 235 | # Encode masks 236 | if self.output_mode == "coco_rle": 237 | mask_data["segmentations"] = [ 238 | coco_encode_rle(rle) for rle in mask_data["rles"] 239 | ] 240 | elif self.output_mode == "binary_mask": 241 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 242 | else: 243 | mask_data["segmentations"] = mask_data["rles"] 244 | 245 | # Write mask records 246 | curr_anns = [] 247 | for idx in range(len(mask_data["segmentations"])): 248 | ann = { 249 | "segmentation": mask_data["segmentations"][idx], 250 | "area": area_from_rle(mask_data["rles"][idx]), 251 | "bbox": ops.convert_to_numpy( 252 | box_xyxy_to_xywh(mask_data["boxes"][idx]) 253 | ).tolist(), 254 | "predicted_iou": mask_data["iou_preds"][idx].item(), 255 | "point_coords": [mask_data["points"][idx].tolist()], 256 | "stability_score": mask_data["stability_score"][idx].item(), 257 | "crop_box": ops.convert_to_numpy( 258 | box_xyxy_to_xywh(mask_data["crop_boxes"][idx]) 259 | ).tolist(), 260 | } 261 | curr_anns.append(ann) 262 | 263 | return curr_anns 264 | 265 | def _generate_masks(self, image, **kwargs): 266 | orig_size = image.shape[:2] 267 | crop_boxes, layer_idxs = generate_crop_boxes( 268 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 269 | ) 270 | 271 | # Iterate over image crops 272 | data = MaskData() 273 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 274 | crop_data = self._process_crop( 275 | image, crop_box, layer_idx, orig_size, **kwargs 276 | ) 277 | data.cat(crop_data) 278 | 279 | # Remove duplicate masks between crops 280 | if len(crop_boxes) > 1: 281 | # Prefer masks from smaller crops 282 | scores = 1 / _box_area(data["crop_boxes"]) 283 | boxes = ops.cast(data["boxes"], "float32") 284 | scores = ops.cast(scores, "float32") 285 | keep_by_nms = _batched_nms( 286 | boxes, 287 | scores, 288 | iou_threshold=self.crop_nms_thresh, 289 | max_output_size=self.max_output_masks, 290 | ) 291 | data.filter(keep_by_nms) 292 | 293 | data.to_numpy() 294 | 295 | return data 296 | 297 | def _process_crop(self, image, crop_box, crop_layer_idx, orig_size, **kwargs): 298 | # Crop the image and calculate embeddings 299 | 
x0, y0, x1, y1 = crop_box 300 | cropped_im = image[y0:y1, x0:x1, :] 301 | cropped_im_size = cropped_im.shape[:2] 302 | self.predictor.set_image(cropped_im, **kwargs) 303 | 304 | # Get points for this crop 305 | points_scale = np.array(cropped_im_size)[None, ::-1] 306 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 307 | 308 | # Generate masks for this crop in batches 309 | data = MaskData() 310 | for (points,) in batch_iterator(self.points_per_batch, points_for_image): 311 | batch_data = self._process_batch( 312 | points, cropped_im_size, crop_box, orig_size, **kwargs 313 | ) 314 | data.cat(batch_data) 315 | del batch_data 316 | self.predictor.reset_image() 317 | 318 | # Remove duplicates within this crop. 319 | keep_by_nms = _batched_nms( 320 | ops.cast(data["boxes"], "float32"), 321 | ops.cast(data["iou_preds"], "float32"), 322 | iou_threshold=self.box_nms_thresh, 323 | max_output_size=self.max_output_masks, 324 | ) 325 | data.filter(keep_by_nms) 326 | 327 | # Return to the original image frame 328 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 329 | data["points"] = uncrop_points(data["points"], crop_box) 330 | data["crop_boxes"] = ops.convert_to_tensor( 331 | [crop_box for _ in range(len(data["rles"]))] 332 | ) 333 | 334 | return data 335 | 336 | def _process_batch(self, points, im_size, crop_box, orig_size, **kwargs): 337 | orig_h, orig_w = orig_size 338 | 339 | # Run model on this batch 340 | transformed_points = self.predictor.transform.apply_coords(points, im_size) 341 | in_points = ops.convert_to_tensor(transformed_points)[:, None, :] 342 | B = in_points.shape[0] 343 | in_labels = ops.ones(B, dtype="int32")[:, None] 344 | in_points = ops.concatenate( 345 | [in_points, ops.zeros((B, 1, 2), dtype=in_points.dtype)], axis=1 346 | ) 347 | in_labels = ops.concatenate( 348 | [in_labels, -ops.ones((B, 1), dtype=in_labels.dtype)], axis=1 349 | ) 350 | out = self.predictor.predict( 351 | dict(point_coords=in_points, point_labels=in_labels), 352 | multimask_output=True, 353 | **kwargs, 354 | ) 355 | masks, iou_preds = out["masks"], out["iou_predictions"] 356 | 357 | # Serialize predictions and store in MaskData 358 | masks, iou_preds, points = map( 359 | ops.convert_to_tensor, [masks, iou_preds, points] 360 | ) 361 | data = MaskData( 362 | masks=ops.reshape(masks, (-1, *masks.shape[2:])), 363 | iou_preds=ops.reshape(iou_preds, (-1, *iou_preds.shape[2:])), 364 | points=ops.convert_to_tensor(ops.repeat(points, masks.shape[1], axis=0)), 365 | ) 366 | del masks 367 | 368 | # Filter by predicted IoU 369 | if self.pred_iou_thresh > 0.0: 370 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 371 | data.filter(keep_mask) 372 | 373 | # Calculate stability score 374 | data["stability_score"] = calculate_stability_score( 375 | data["masks"], self.predictor.mask_threshold, self.stability_score_offset 376 | ) 377 | if self.stability_score_thresh > 0.0: 378 | keep_mask = data["stability_score"] >= self.stability_score_thresh 379 | data.filter(keep_mask) 380 | 381 | # Threshold masks and calculate boxes 382 | data["masks"] = data["masks"] > self.predictor.mask_threshold 383 | data["boxes"] = batched_mask_to_box(data["masks"]) 384 | 385 | # Filter boxes that touch crop boundaries 386 | keep_mask = ~is_box_near_crop_edge( 387 | data["boxes"], crop_box, [0, 0, orig_w, orig_h] 388 | ) 389 | if not ops.all(keep_mask): 390 | data.filter(keep_mask) 391 | 392 | # Compress to RLE 393 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 394 | data["rles"] = 
mask_to_rle_tensor(data["masks"]) 395 | del data["masks"] 396 | 397 | return data 398 | 399 | def postprocess_small_regions( 400 | self, mask_data: MaskData, min_area: int, nms_thresh: float 401 | ) -> MaskData: 402 | """ 403 | Removes small disconnected regions and holes in masks, then reruns 404 | box NMS to remove any new duplicates. 405 | 406 | Edits mask_data in place. 407 | 408 | Requires open-cv as a dependency. 409 | """ 410 | if len(mask_data["rles"]) == 0: 411 | return mask_data 412 | 413 | # Filter small disconnected regions and holes 414 | new_masks = [] 415 | scores = [] 416 | for rle in mask_data["rles"]: 417 | mask = rle_to_mask(rle) 418 | 419 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 420 | unchanged = not changed 421 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 422 | unchanged = unchanged and not changed 423 | 424 | new_masks.append(ops.convert_to_tensor(mask)[None, ...]) 425 | # Give score=0 to changed masks and score=1 to unchanged masks 426 | # so NMS will prefer ones that didn't need postprocessing 427 | scores.append(float(unchanged)) 428 | 429 | # Recalculate boxes and remove any new duplicates 430 | masks = ops.concatenate(new_masks, axis=0) 431 | scores = ops.convert_to_tensor(scores, "float32") 432 | boxes = batched_mask_to_box(masks) 433 | keep_by_nms = _batched_nms( 434 | ops.cast(boxes, "float32"), 435 | scores, 436 | iou_threshold=nms_thresh, 437 | max_output_size=self.max_output_masks, 438 | ) 439 | 440 | # We update the boxes directly in the loop below. 441 | # Copy the boxes data since, for the tensorflow backend, Keras 3 returns 442 | # readonly arrays which can't be mutated in-place. 443 | mask_data["boxes"] = mask_data["boxes"].copy() 444 | 445 | # Only recalculate RLEs for masks that have changed 446 | for i_mask in keep_by_nms: 447 | if scores[i_mask] == 0.0: 448 | mask_tensor = masks[i_mask][None, ...] 449 | mask_data["rles"][i_mask] = mask_to_rle_tensor(mask_tensor)[0] 450 | mask_data["boxes"][i_mask] = ops.convert_to_numpy( 451 | boxes[i_mask] 452 | ) # update res directly 453 | mask_data.filter(keep_by_nms) 454 | 455 | return mask_data 456 | -------------------------------------------------------------------------------- /sam_keras/jax_nms.py: -------------------------------------------------------------------------------- 1 | # Taken from https://github.com/mlperf/training_results_v0.7/blob/3dbb53064a6b79354c68a6832414b6536fee1a75/Google/benchmarks/ssd/implementations/ssd-research-JAX-tpu-v3-4096/nms.py 2 | # See https://github.com/google/flax/discussions/1929#discussioncomment-2378312 3 | # 4 | # Apache License 5 | # Version 2.0, January 2004 6 | # http://www.apache.org/licenses/ 7 | # 8 | # TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 9 | # 10 | # 1. Definitions. 11 | # 12 | # "License" shall mean the terms and conditions for use, reproduction, 13 | # and distribution as defined by Sections 1 through 9 of this document. 14 | # 15 | # "Licensor" shall mean the copyright owner or entity authorized by 16 | # the copyright owner that is granting the License. 17 | # 18 | # "Legal Entity" shall mean the union of the acting entity and all 19 | # other entities that control, are controlled by, or are under common 20 | # control with that entity. 
For the purposes of this definition, 21 | # "control" means (i) the power, direct or indirect, to cause the 22 | # direction or management of such entity, whether by contract or 23 | # otherwise, or (ii) ownership of fifty percent (50%) or more of the 24 | # outstanding shares, or (iii) beneficial ownership of such entity. 25 | # 26 | # "You" (or "Your") shall mean an individual or Legal Entity 27 | # exercising permissions granted by this License. 28 | # 29 | # "Source" form shall mean the preferred form for making modifications, 30 | # including but not limited to software source code, documentation 31 | # source, and configuration files. 32 | # 33 | # "Object" form shall mean any form resulting from mechanical 34 | # transformation or translation of a Source form, including but 35 | # not limited to compiled object code, generated documentation, 36 | # and conversions to other media types. 37 | # 38 | # "Work" shall mean the work of authorship, whether in Source or 39 | # Object form, made available under the License, as indicated by a 40 | # copyright notice that is included in or attached to the work 41 | # (an example is provided in the Appendix below). 42 | # 43 | # "Derivative Works" shall mean any work, whether in Source or Object 44 | # form, that is based on (or derived from) the Work and for which the 45 | # editorial revisions, annotations, elaborations, or other modifications 46 | # represent, as a whole, an original work of authorship. For the purposes 47 | # of this License, Derivative Works shall not include works that remain 48 | # separable from, or merely link (or bind by name) to the interfaces of, 49 | # the Work and Derivative Works thereof. 50 | # 51 | # "Contribution" shall mean any work of authorship, including 52 | # the original version of the Work and any modifications or additions 53 | # to that Work or Derivative Works thereof, that is intentionally 54 | # submitted to Licensor for inclusion in the Work by the copyright owner 55 | # or by an individual or Legal Entity authorized to submit on behalf of 56 | # the copyright owner. For the purposes of this definition, "submitted" 57 | # means any form of electronic, verbal, or written communication sent 58 | # to the Licensor or its representatives, including but not limited to 59 | # communication on electronic mailing lists, source code control systems, 60 | # and issue tracking systems that are managed by, or on behalf of, the 61 | # Licensor for the purpose of discussing and improving the Work, but 62 | # excluding communication that is conspicuously marked or otherwise 63 | # designated in writing by the copyright owner as "Not a Contribution." 64 | # 65 | # "Contributor" shall mean Licensor and any individual or Legal Entity 66 | # on behalf of whom a Contribution has been received by Licensor and 67 | # subsequently incorporated within the Work. 68 | # 69 | # 2. Grant of Copyright License. Subject to the terms and conditions of 70 | # this License, each Contributor hereby grants to You a perpetual, 71 | # worldwide, non-exclusive, no-charge, royalty-free, irrevocable 72 | # copyright license to reproduce, prepare Derivative Works of, 73 | # publicly display, publicly perform, sublicense, and distribute the 74 | # Work and such Derivative Works in Source or Object form. 75 | # 76 | # 3. Grant of Patent License. 
Subject to the terms and conditions of 77 | # this License, each Contributor hereby grants to You a perpetual, 78 | # worldwide, non-exclusive, no-charge, royalty-free, irrevocable 79 | # (except as stated in this section) patent license to make, have made, 80 | # use, offer to sell, sell, import, and otherwise transfer the Work, 81 | # where such license applies only to those patent claims licensable 82 | # by such Contributor that are necessarily infringed by their 83 | # Contribution(s) alone or by combination of their Contribution(s) 84 | # with the Work to which such Contribution(s) was submitted. If You 85 | # institute patent litigation against any entity (including a 86 | # cross-claim or counterclaim in a lawsuit) alleging that the Work 87 | # or a Contribution incorporated within the Work constitutes direct 88 | # or contributory patent infringement, then any patent licenses 89 | # granted to You under this License for that Work shall terminate 90 | # as of the date such litigation is filed. 91 | # 92 | # 4. Redistribution. You may reproduce and distribute copies of the 93 | # Work or Derivative Works thereof in any medium, with or without 94 | # modifications, and in Source or Object form, provided that You 95 | # meet the following conditions: 96 | # 97 | # (a) You must give any other recipients of the Work or 98 | # Derivative Works a copy of this License; and 99 | # 100 | # (b) You must cause any modified files to carry prominent notices 101 | # stating that You changed the files; and 102 | # 103 | # (c) You must retain, in the Source form of any Derivative Works 104 | # that You distribute, all copyright, patent, trademark, and 105 | # attribution notices from the Source form of the Work, 106 | # excluding those notices that do not pertain to any part of 107 | # the Derivative Works; and 108 | # 109 | # (d) If the Work includes a "NOTICE" text file as part of its 110 | # distribution, then any Derivative Works that You distribute must 111 | # include a readable copy of the attribution notices contained 112 | # within such NOTICE file, excluding those notices that do not 113 | # pertain to any part of the Derivative Works, in at least one 114 | # of the following places: within a NOTICE text file distributed 115 | # as part of the Derivative Works; within the Source form or 116 | # documentation, if provided along with the Derivative Works; or, 117 | # within a display generated by the Derivative Works, if and 118 | # wherever such third-party notices normally appear. The contents 119 | # of the NOTICE file are for informational purposes only and 120 | # do not modify the License. You may add Your own attribution 121 | # notices within Derivative Works that You distribute, alongside 122 | # or as an addendum to the NOTICE text from the Work, provided 123 | # that such additional attribution notices cannot be construed 124 | # as modifying the License. 125 | # 126 | # You may add Your own copyright statement to Your modifications and 127 | # may provide additional or different license terms and conditions 128 | # for use, reproduction, or distribution of Your modifications, or 129 | # for any such Derivative Works as a whole, provided Your use, 130 | # reproduction, and distribution of the Work otherwise complies with 131 | # the conditions stated in this License. 132 | # 133 | # 5. Submission of Contributions. 
Unless You explicitly state otherwise, 134 | # any Contribution intentionally submitted for inclusion in the Work 135 | # by You to the Licensor shall be under the terms and conditions of 136 | # this License, without any additional terms or conditions. 137 | # Notwithstanding the above, nothing herein shall supersede or modify 138 | # the terms of any separate license agreement you may have executed 139 | # with Licensor regarding such Contributions. 140 | # 141 | # 6. Trademarks. This License does not grant permission to use the trade 142 | # names, trademarks, service marks, or product names of the Licensor, 143 | # except as required for reasonable and customary use in describing the 144 | # origin of the Work and reproducing the content of the NOTICE file. 145 | # 146 | # 7. Disclaimer of Warranty. Unless required by applicable law or 147 | # agreed to in writing, Licensor provides the Work (and each 148 | # Contributor provides its Contributions) on an "AS IS" BASIS, 149 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 150 | # implied, including, without limitation, any warranties or conditions 151 | # of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 152 | # PARTICULAR PURPOSE. You are solely responsible for determining the 153 | # appropriateness of using or redistributing the Work and assume any 154 | # risks associated with Your exercise of permissions under this License. 155 | # 156 | # 8. Limitation of Liability. In no event and under no legal theory, 157 | # whether in tort (including negligence), contract, or otherwise, 158 | # unless required by applicable law (such as deliberate and grossly 159 | # negligent acts) or agreed to in writing, shall any Contributor be 160 | # liable to You for damages, including any direct, indirect, special, 161 | # incidental, or consequential damages of any character arising as a 162 | # result of this License or out of the use or inability to use the 163 | # Work (including but not limited to damages for loss of goodwill, 164 | # work stoppage, computer failure or malfunction, or any and all 165 | # other commercial damages or losses), even if such Contributor 166 | # has been advised of the possibility of such damages. 167 | # 168 | # 9. Accepting Warranty or Additional Liability. While redistributing 169 | # the Work or Derivative Works thereof, You may choose to offer, 170 | # and charge a fee for, acceptance of support, warranty, indemnity, 171 | # or other liability obligations and/or rights consistent with this 172 | # License. However, in accepting such obligations, You may act only 173 | # on Your own behalf and on Your sole responsibility, not on behalf 174 | # of any other Contributor, and only if You agree to indemnify, 175 | # defend, and hold each Contributor harmless for any liability 176 | # incurred by, or claims asserted against, such Contributor by reason 177 | # of your accepting any such warranty or additional liability. 178 | # 179 | # END OF TERMS AND CONDITIONS 180 | # 181 | # APPENDIX: How to apply the Apache License to your work. 182 | # 183 | # To apply the Apache License to your work, attach the following 184 | # boilerplate notice, with the fields enclosed by brackets "[]" 185 | # replaced with your own identifying information. (Don't include 186 | # the brackets!) The text should be enclosed in the appropriate 187 | # comment syntax for the file format. 
We also recommend that a 188 | # file or class name and description of purpose be included on the 189 | # same "printed page" as the copyright notice for easier 190 | # identification within third-party archives. 191 | # 192 | # Copyright 2018 The MLPerf Authors 193 | # 194 | # Licensed under the Apache License, Version 2.0 (the "License"); 195 | # you may not use this file except in compliance with the License. 196 | # You may obtain a copy of the License at 197 | # 198 | # http://www.apache.org/licenses/LICENSE-2.0 199 | # 200 | # Unless required by applicable law or agreed to in writing, software 201 | # distributed under the License is distributed on an "AS IS" BASIS, 202 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 203 | # See the License for the specific language governing permissions and 204 | # limitations under the License. 205 | 206 | 207 | """Non-max Suppression example. 208 | 209 | This script does non-max suppression used in models like SSD 210 | """ 211 | 212 | from jax import lax 213 | import jax.numpy as jnp 214 | 215 | _NMS_TILE_SIZE = 256 216 | 217 | 218 | def _bbox_overlap(boxes, gt_boxes): 219 | """Find Bounding box overlap. 220 | 221 | Args: 222 | boxes: first set of bounding boxes 223 | gt_boxes: second set of boxes to compute IOU 224 | 225 | Returns: 226 | iou: Intersection over union matrix of all input bounding boxes 227 | """ 228 | bb_y_min, bb_x_min, bb_y_max, bb_x_max = jnp.split( 229 | ary=boxes, indices_or_sections=4, axis=2 230 | ) 231 | gt_y_min, gt_x_min, gt_y_max, gt_x_max = jnp.split( 232 | ary=gt_boxes, indices_or_sections=4, axis=2 233 | ) 234 | 235 | # Calculates the intersection area. 236 | i_xmin = jnp.maximum(bb_x_min, jnp.transpose(gt_x_min, [0, 2, 1])) 237 | i_xmax = jnp.minimum(bb_x_max, jnp.transpose(gt_x_max, [0, 2, 1])) 238 | i_ymin = jnp.maximum(bb_y_min, jnp.transpose(gt_y_min, [0, 2, 1])) 239 | i_ymax = jnp.minimum(bb_y_max, jnp.transpose(gt_y_max, [0, 2, 1])) 240 | i_area = jnp.maximum((i_xmax - i_xmin), 0) * jnp.maximum((i_ymax - i_ymin), 0) 241 | 242 | # Calculates the union area. 243 | bb_area = (bb_y_max - bb_y_min) * (bb_x_max - bb_x_min) 244 | gt_area = (gt_y_max - gt_y_min) * (gt_x_max - gt_x_min) 245 | # Adds a small epsilon to avoid divide-by-zero. 246 | u_area = bb_area + jnp.transpose(gt_area, [0, 2, 1]) - i_area + 1e-8 247 | 248 | # Calculates IoU. 
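# IoU = intersection area / union area; the 1e-8 added to u_area above keeps the division below from dividing by zero for degenerate boxes.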
249 | iou = i_area / u_area 250 | 251 | return iou 252 | 253 | 254 | def _self_suppression(in_args): 255 | iou, _, iou_sum = in_args 256 | batch_size = iou.shape[0] 257 | can_suppress_others = jnp.reshape( 258 | jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1] 259 | ).astype(iou.dtype) 260 | iou_suppressed = ( 261 | jnp.reshape( 262 | (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(iou.dtype), 263 | [batch_size, -1, 1], 264 | ) 265 | * iou 266 | ) 267 | iou_sum_new = jnp.sum(iou_suppressed, [1, 2]) 268 | return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new 269 | 270 | 271 | def _cross_suppression(in_args): 272 | boxes, box_slice, iou_threshold, inner_idx = in_args 273 | batch_size = boxes.shape[0] 274 | new_slice = lax.dynamic_slice( 275 | boxes, [0, inner_idx * _NMS_TILE_SIZE, 0], [batch_size, _NMS_TILE_SIZE, 4] 276 | ) 277 | iou = _bbox_overlap(new_slice, box_slice) 278 | ret_slice = ( 279 | jnp.expand_dims((jnp.all(iou < iou_threshold, [1])).astype(box_slice.dtype), 2) 280 | * box_slice 281 | ) 282 | return boxes, ret_slice, iou_threshold, inner_idx + 1 283 | 284 | 285 | def _suppression_loop_body(in_args): 286 | """Process boxes in the range [idx*_NMS_TILE_SIZE, (idx+1)*_NMS_TILE_SIZE). 287 | 288 | Args: 289 | in_args: A tuple of arguments: boxes, iou_threshold, output_size, idx 290 | 291 | Returns: 292 | boxes: updated boxes. 293 | iou_threshold: pass down iou_threshold to the next iteration. 294 | output_size: the updated output_size. 295 | idx: the updated induction variable. 296 | """ 297 | boxes, iou_threshold, output_size, idx = in_args 298 | num_tiles = boxes.shape[1] // _NMS_TILE_SIZE 299 | batch_size = boxes.shape[0] 300 | 301 | # Iterates over tiles that can possibly suppress the current tile. 302 | box_slice = lax.dynamic_slice( 303 | boxes, [0, idx * _NMS_TILE_SIZE, 0], [batch_size, _NMS_TILE_SIZE, 4] 304 | ) 305 | 306 | def _loop_cond(in_args): 307 | _, _, _, inner_idx = in_args 308 | return inner_idx < idx 309 | 310 | _, box_slice, _, _ = lax.while_loop( 311 | _loop_cond, _cross_suppression, (boxes, box_slice, iou_threshold, 0) 312 | ) 313 | 314 | # Iterates over the current tile to compute self-suppression. 315 | iou = _bbox_overlap(box_slice, box_slice) 316 | mask = jnp.expand_dims( 317 | jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1]) 318 | > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]), 319 | 0, 320 | ) 321 | iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype) 322 | 323 | def _loop_cond2(in_args): 324 | _, loop_condition, _ = in_args 325 | return loop_condition 326 | 327 | suppressed_iou, _, _ = lax.while_loop( 328 | _loop_cond2, _self_suppression, (iou, True, jnp.sum(iou, [1, 2])) 329 | ) 330 | suppressed_box = jnp.sum(suppressed_iou, 1) > 0 331 | box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2) 332 | 333 | # Uses box_slice to update the input boxes. 334 | mask = jnp.reshape( 335 | (jnp.equal(jnp.arange(num_tiles), idx)).astype(boxes.dtype), [1, -1, 1, 1] 336 | ) 337 | boxes = jnp.tile( 338 | jnp.expand_dims(box_slice, 1), [1, num_tiles, 1, 1] 339 | ) * mask + jnp.reshape(boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * ( 340 | 1 - mask 341 | ) 342 | boxes = jnp.reshape(boxes, [batch_size, -1, 4]) 343 | 344 | # Updates output_size. 
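# Suppressed boxes were zeroed out above, so a slot counts toward output_size only if at least one of its coordinates is still non-zero.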
345 | output_size += jnp.sum(jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1]) 346 | return boxes, iou_threshold, output_size, idx + 1 347 | 348 | 349 | def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold): 350 | """A wrapper that handles non-maximum suppression. 351 | 352 | Assumption: 353 | * The boxes are sorted by scores unless the box is a dot (all coordinates 354 | are zero). 355 | * Boxes with higher scores can be used to suppress boxes with lower scores. 356 | 357 | The overall design of the algorithm is to handle boxes tile-by-tile: 358 | 359 | boxes = boxes.pad_to_multiple_of(tile_size) 360 | num_tiles = len(boxes) // tile_size 361 | output_boxes = [] 362 | for i in range(num_tiles): 363 | box_tile = boxes[i*tile_size : (i+1)*tile_size] 364 | for j in range(i): 365 | suppressing_tile = boxes[j*tile_size : (j+1)*tile_size] 366 | iou = _bbox_overlap(box_tile, suppressing_tile) 367 | # if the box is suppressed in iou, clear it to a dot 368 | box_tile *= _update_boxes(iou) 369 | # Iteratively handle the diagonal tile. 370 | iou = _box_overlap(box_tile, box_tile) 371 | iou_changed = True 372 | while iou_changed: 373 | # boxes that are not suppressed by anything else 374 | suppressing_boxes = _get_suppressing_boxes(iou) 375 | # boxes that are suppressed by suppressing_boxes 376 | suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes) 377 | # clear iou to 0 for boxes that are suppressed, as they cannot be used 378 | # to suppress other boxes any more 379 | new_iou = _clear_iou(iou, suppressed_boxes) 380 | iou_changed = (new_iou != iou) 381 | iou = new_iou 382 | # remaining boxes that can still suppress others, are selected boxes. 383 | output_boxes.append(_get_suppressing_boxes(iou)) 384 | if len(output_boxes) >= max_output_size: 385 | break 386 | 387 | Args: 388 | scores: a tensor with a shape of [batch_size, anchors]. 389 | boxes: a tensor with a shape of [batch_size, anchors, 4]. 390 | max_output_size: a scalar integer `Tensor` representing the maximum number 391 | of boxes to be selected by non max suppression. 392 | iou_threshold: a float representing the threshold for deciding whether boxes 393 | overlap too much with respect to IOU. 394 | Returns: 395 | idx: a 1-D integer tensor with the indices of the selected boxes, flattened 396 | across the batch (each index is offset by its batch index times the padded 397 | number of boxes). Padded slots are excluded, so at most 398 | batch_size * max_output_size indices are returned.
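Note: in this repository the wrapper is called from _batched_nms with a single-element batch (boxes[None, ...], scores[None, ...]), so the returned flat indices can be used directly to filter the unbatched boxes.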
399 | """ 400 | batch_size = boxes.shape[0] 401 | num_boxes = boxes.shape[1] 402 | pad = int(jnp.ceil(float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes 403 | boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]]) 404 | scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]]) 405 | num_boxes += pad 406 | 407 | def _loop_cond(in_args): 408 | unused_boxes, unused_threshold, output_size, idx = in_args 409 | return jnp.logical_and( 410 | jnp.min(output_size) < max_output_size, idx < num_boxes // _NMS_TILE_SIZE 411 | ) 412 | 413 | selected_boxes, _, output_size, _ = lax.while_loop( 414 | _loop_cond, 415 | _suppression_loop_body, 416 | (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0), 417 | ) 418 | idx = num_boxes - lax.top_k( 419 | jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) 420 | * jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0), 421 | max_output_size, 422 | )[0].astype(jnp.int32) 423 | idx = jnp.minimum(idx, num_boxes - 1) 424 | idx = jnp.reshape( 425 | idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1] 426 | ) 427 | return idx[idx < num_boxes - pad] 428 | -------------------------------------------------------------------------------- /sam_keras/predictor.py: -------------------------------------------------------------------------------- 1 | # Author: Tirth Patel (tirthasheshpatel@gmail.com) 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import numpy as np 10 | from keras import ops 11 | from PIL import Image 12 | from .prompter import SAMPrompter 13 | 14 | 15 | __all__ = ["ResizeLongestSide", "SAMPredictor"] 16 | 17 | 18 | class ResizeLongestSide: 19 | def __init__(self, target_length): 20 | self.target_length = int(target_length) 21 | 22 | def apply_image(self, image): 23 | image = np.asarray(image) 24 | if len(image.shape) != 3: 25 | raise ValueError("`image` must be of shape `(H, W, C)`.") 26 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1]) 27 | return np.asarray( 28 | Image.fromarray(image).resize( 29 | target_size[::-1], resample=Image.Resampling.BILINEAR 30 | ) 31 | ) 32 | 33 | def apply_coords(self, coords, original_size): 34 | coords = ops.convert_to_tensor(coords) 35 | if len(coords.shape) != 3 and coords.shape[-1] != 2: 36 | raise ValueError( 37 | f"`coords` must be of shape `(B, N, 2)` but got `{coords.shape}`" 38 | ) 39 | old_h, old_w = tuple(original_size) 40 | new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1]) 41 | coords = ops.cast(coords, "float32") 42 | coords_x = coords[..., 0] * (new_w / old_w) 43 | coords_y = coords[..., 1] * (new_h / old_h) 44 | return ops.stack([coords_x, coords_y], axis=-1) 45 | 46 | def apply_boxes(self, boxes, original_size): 47 | boxes = ops.convert_to_tensor(boxes) 48 | B, N = boxes.shape[0:2] 49 | if len(boxes.shape) != 3 and boxes.shape[-1] != 4: 50 | raise ValueError( 51 | f"`boxes` must of shape `(B, N, 4)` but got `{boxes.shape}`" 52 | ) 53 | boxes = self.apply_coords(ops.reshape(boxes, (B, N, 2, 2)), original_size) 54 | return boxes 55 | 56 | def get_preprocess_shape(self, old_h, old_w): 57 | scale = self.target_length * 1.0 / max(old_h, old_w) 58 | new_h = old_h * scale 59 | new_w = old_w * scale 60 | return int(new_h + 0.5), int(new_w + 0.5) 61 | 62 | 63 | class SAMPredictor: 64 | mask_threshold = 0.0 65 | 66 | def __init__( 67 | self, 
68 | model, 69 | pixel_mean=[123.675, 116.28, 103.53], 70 | pixel_std=[58.395, 57.12, 57.375], 71 | ): 72 | self.model = model 73 | self.pixel_mean = ops.convert_to_tensor(pixel_mean, dtype="float32") 74 | self.pixel_std = ops.convert_to_tensor(pixel_std, dtype="float32") 75 | self.img_size = model.backbone.input.shape[1] 76 | self.transform = ResizeLongestSide(self.img_size) 77 | self.prompter = SAMPrompter(self.model.prompt_encoder, self.model.mask_decoder) 78 | self.reset_image() 79 | 80 | def set_image(self, image, **kwargs): 81 | input_image = self.transform.apply_image(image) 82 | input_image_tensor = ops.convert_to_tensor(input_image, dtype="float32") 83 | input_image_tensor = input_image_tensor[None, :, :, :] 84 | 85 | self.set_tensor_image(input_image_tensor, image.shape[:2], **kwargs) 86 | 87 | def set_tensor_image(self, transformed_image, original_image_size, **kwargs): 88 | """ 89 | Calculates the image embeddings for the provided image, allowing 90 | masks to be predicted with the 'predict' method. Expects the input 91 | image to be already transformed to the format expected by the model. 92 | 93 | Arguments: 94 | transformed_image (tensor): The input image, with shape 95 | 1x3xHxW, which has been transformed with ResizeLongestSide. 96 | original_image_size (tuple(int, int)): The size of the image 97 | before transformation, in (H, W) format. 98 | """ 99 | self.reset_image() 100 | 101 | self.original_size = tuple(original_image_size) 102 | self.input_size = tuple(transformed_image.shape[-2:]) 103 | self.unprocessed_image = transformed_image 104 | input_image = self.preprocess_images(transformed_image) 105 | self.features = self.model.backbone.predict(input_image, **kwargs) 106 | self.is_image_set = True 107 | 108 | def _broadcast_batch(self, B, *args): 109 | res = [] 110 | for arg in args: 111 | res.append( 112 | ops.broadcast_to(arg, (B,) + arg.shape[1:]) if arg is not None else arg 113 | ) 114 | return res 115 | 116 | def predict( 117 | self, batched_input, multimask_output=True, return_logits=True, **kwargs 118 | ): 119 | batched_input = batched_input.copy() 120 | 121 | if self.is_image_set: 122 | batched_input["image"] = self.unprocessed_image 123 | batched_input["original_size"] = self.original_size 124 | images = self.features 125 | else: 126 | images = self.preprocess_images(batched_input["image"]) 127 | 128 | points = batched_input.get("point_coords", None) 129 | labels = batched_input.get("point_labels", None) 130 | boxes = batched_input.get("boxes", None) 131 | masks = batched_input.get("mask_inputs", None) 132 | 133 | if points is not None and boxes is None: 134 | pad_point = ops.zeros((points.shape[0], 1, 2), dtype="float32") 135 | pad_label = -ops.ones((labels.shape[0], 1), dtype="float32") 136 | points = ops.concatenate([points, pad_point], axis=1) 137 | labels = ops.concatenate([labels, pad_label], axis=1) 138 | 139 | B = max( 140 | [ 141 | images.shape[0], 142 | points.shape[0] if points is not None else 0, 143 | labels.shape[0] if labels is not None else 0, 144 | boxes.shape[0] if boxes is not None else 0, 145 | masks.shape[0] if masks is not None else 0, 146 | ] 147 | ) 148 | 149 | images, points, labels, boxes, masks = self._broadcast_batch( 150 | B, images, points, labels, boxes, masks 151 | ) 152 | 153 | model_input = {"images": images} 154 | 155 | if points is not None: 156 | model_input["points"] = points 157 | model_input["labels"] = labels 158 | if boxes is not None: 159 | model_input["boxes"] = boxes 160 | if masks is not None: 161 | 
model_input["masks"] = masks 162 | 163 | if self.is_image_set: 164 | outs = self.prompter.predict(model_input, **kwargs) 165 | else: 166 | outs = self.model.predict(model_input, **kwargs) 167 | low_res_masks, iou_scores = outs["masks"], outs["iou_pred"] 168 | if multimask_output: 169 | low_res_masks = low_res_masks[:, 1:, :, :] 170 | iou_scores = iou_scores[:, 1:] 171 | else: 172 | low_res_masks = low_res_masks[:, :1, :, :] 173 | iou_scores = iou_scores[:, :1] 174 | masks = self.postprocess_masks( 175 | low_res_masks, 176 | input_size=batched_input["image"].shape[1:3], 177 | original_size=batched_input["original_size"], 178 | ) 179 | if not return_logits: 180 | masks = ops.cast(masks > self.mask_threshold, dtype="float32") 181 | batched_output = { 182 | "masks": masks, 183 | "iou_predictions": iou_scores, 184 | "low_res_masks": low_res_masks, 185 | } 186 | return batched_output 187 | 188 | def postprocess_masks(self, masks, input_size, original_size): 189 | masks = ops.image.resize( 190 | ops.transpose(masks, axes=(0, 2, 3, 1)), 191 | size=(self.img_size, self.img_size), 192 | interpolation="bilinear", 193 | ) 194 | masks = masks[..., : input_size[0], : input_size[1], :] 195 | masks = ops.image.resize(masks, size=original_size, interpolation="bilinear") 196 | return ops.transpose(masks, axes=(0, 3, 1, 2)) 197 | 198 | def preprocess_images(self, x): 199 | x = (x - self.pixel_mean) / self.pixel_std 200 | h, w = x.shape[1:3] 201 | pad_h = self.img_size - h 202 | pad_w = self.img_size - w 203 | x = ops.pad(x, [(0, 0), (0, pad_h), (0, pad_w), (0, 0)]) 204 | # KerasCV now rescales the images and normalizes them. 205 | # Just unnormalize such that when KerasCV normalizes them 206 | # again, the padded values map to 0. 207 | x = x * self.pixel_std + self.pixel_mean 208 | return x 209 | 210 | def get_image_embedding(self): 211 | if not self.is_image_set: 212 | raise RuntimeError( 213 | "An image must be set with .set_image(...) to generate an embedding." 
214 | ) 215 | return self.features 216 | 217 | def reset_image(self): 218 | """Resets the currently set image.""" 219 | self.is_image_set = False 220 | self.unprocessed_image = None 221 | self.features = None 222 | self.orig_h = None 223 | self.orig_w = None 224 | self.input_h = None 225 | self.input_w = None 226 | -------------------------------------------------------------------------------- /sam_keras/prompter.py: -------------------------------------------------------------------------------- 1 | # Author: Tirth Patel (tirthasheshpatel@gmail.com) 2 | 3 | import numpy as np 4 | import keras 5 | from keras import ops 6 | 7 | 8 | __all__ = ["SAMPrompter"] 9 | 10 | 11 | class SAMPrompter(keras.Model): 12 | def __init__( 13 | self, prompt_encoder, mask_decoder, feature_shape=(64, 64, 256), **kwargs 14 | ): 15 | # Define the prompt encoder inputs -- Prompts 16 | prompt_inputs = { 17 | "points": keras.Input(shape=[None, 2], name="points"), 18 | "labels": keras.Input(shape=[None], name="labels"), 19 | "boxes": keras.Input(shape=[None, 2, 2], name="boxes"), 20 | "masks": keras.Input(shape=[None, None, None, 1], name="masks"), 21 | } 22 | 23 | # All Inputs -- Features + Prompts 24 | all_inputs = {"images": keras.Input(feature_shape, name="images")} 25 | all_inputs.update(prompt_inputs) 26 | 27 | # Build the prompt encoder 28 | prompt_embeddings = prompt_encoder(prompt_inputs) 29 | 30 | # Define the mask decoder inputs 31 | mask_decoder_inputs = { 32 | "image_embeddings": all_inputs["images"], 33 | "image_pe": prompt_embeddings["dense_positional_embeddings"], 34 | "sparse_prompt_embeddings": prompt_embeddings["sparse_embeddings"], 35 | "dense_prompt_embeddings": prompt_embeddings["dense_embeddings"], 36 | } 37 | 38 | # Build the mask decoder 39 | outputs = mask_decoder(mask_decoder_inputs) 40 | 41 | super().__init__(inputs=all_inputs, outputs=outputs, **kwargs) 42 | 43 | self.prompt_encoder = prompt_encoder 44 | self.mask_decoder = mask_decoder 45 | 46 | def predict_step(self, *args, **kwargs): 47 | if len(args) == 2: 48 | args = (args[0], _add_placeholder_prompts(args[-1])) 49 | else: 50 | args = (_add_placeholder_prompts(args[0]),) 51 | 52 | return super().predict_step(*args, **kwargs) 53 | 54 | 55 | def _add_placeholder_prompts(inputs): 56 | """Adds placeholder prompt inputs for a call to SAM. 57 | 58 | Because SAM is a functional subclass model, all inputs must be specified in 59 | calls to the model. However, prompt inputs are all optional, so we have to 60 | add placeholders when they're not specified by the user. 61 | """ 62 | inputs = inputs.copy() 63 | 64 | # Get the batch shape based on the image input 65 | B = ops.shape(inputs["images"])[0] 66 | 67 | # The type of the placeholders must match the existing inputs with respect 68 | # to whether or not they are tensors (as opposed to Numpy arrays). 69 | zeros = ops.zeros if ops.is_tensor(inputs["images"]) else np.zeros 70 | 71 | # Fill in missing inputs. 72 | if "points" not in inputs: 73 | inputs["points"] = zeros((B, 0, 2)) 74 | if "labels" not in inputs: 75 | inputs["labels"] = zeros((B, 0)) 76 | if "boxes" not in inputs: 77 | inputs["boxes"] = zeros((B, 0, 2, 2)) 78 | if "masks" not in inputs: 79 | inputs["masks"] = zeros((B, 0, 256, 256, 1)) 80 | 81 | return inputs 82 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Author: Tirth Patel (tirthasheshpatel@gmail.com) 2 | 3 | # Copyright (c) Tirth Patel. 
4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | from setuptools import find_packages, setup 10 | 11 | setup( 12 | name="sam_keras", 13 | version="0.0.1", 14 | # keras-cv is listed here as well; note that the package imports from its 15 | # source modules (keras_cv.src), so a sufficiently recent install may be needed. 16 | install_requires=["numpy", "Pillow", "keras-cv"], 17 | packages=find_packages(), 18 | extras_require={ 19 | "all": ["matplotlib", "pycocotools", "opencv-python", "tensorflow", "torch", 20 | "torchvision", "torchaudio", "jax", "jaxlib"], 21 | }, 22 | ) 23 | --------------------------------------------------------------------------------
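A minimal end-to-end usage sketch (not part of the repository files) that ties the pieces above together. It assumes a KerasCV SAM model exposing `backbone`, `prompt_encoder`, and `mask_decoder`, for example `keras_cv.models.SegmentAnythingModel`; the preset name below is illustrative and should be replaced with whatever SAM preset your KerasCV installation provides.

import numpy as np
from PIL import Image
from keras_cv.models import SegmentAnythingModel  # assumed KerasCV API
from sam_keras.predictor import SAMPredictor
from sam_keras.automatic_mask_generator import SAMAutomaticMaskGenerator

# Load a SAM model ("sam_huge_sa1b" is an illustrative preset name).
model = SegmentAnythingModel.from_preset("sam_huge_sa1b")

# Wrap the model in the predictor and the automatic mask generator.
predictor = SAMPredictor(model)
generator = SAMAutomaticMaskGenerator(predictor, points_per_side=32)

# Generate masks for any HWC uint8 image ("example.jpg" is a placeholder).
image = np.asarray(Image.open("example.jpg").convert("RGB"))
masks = generator.generate(image)
print(len(masks), sorted(masks[0].keys()))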