├── .github └── workflows │ ├── commit-lint.yml │ ├── install_test_ci.yml │ └── release_to_pypi.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── examples ├── image_cv_segautomask.py ├── image_sahi_slice_predict.py ├── image_segautomask.py ├── image_segmanualmask.py ├── video_segautomask.py └── video_segmanualmask.py ├── metaseg ├── __init__.py ├── falai_predictor.py ├── generator │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── build_sam.py │ └── predictor.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ └── transformer.py ├── sahi_predictor.py ├── sam_predictor.py ├── utils │ ├── __init__.py │ ├── amg.py │ ├── data_utils.py │ ├── model_file_downloader.py │ ├── onnx.py │ └── transforms.py └── webapp │ ├── __init__.py │ ├── app.py │ └── demo.py ├── poetry.lock ├── pyproject.toml ├── requirements-dev.txt ├── requirements.txt └── scripts ├── amg.py └── export_onnx_model.py /.github/workflows/commit-lint.yml: -------------------------------------------------------------------------------- 1 | name: Conventional Commits 2 | on: [pull_request] 3 | 4 | jobs: 5 | lint-commits: 6 | name: Lint Commits 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - name: Checkout 🛎️ 11 | uses: actions/checkout@v3.5.3 12 | with: 13 | fetch-depth: 0 14 | 15 | - name: Install Commit Linting Tool 🔧 16 | run: npm install --save-dev @commitlint/{cli,config-conventional} 17 | 18 | - name: Set Linting Config to Conventional Commits spec 🔧 19 | run: | 20 | echo "module.exports = { extends: ['@commitlint/config-conventional'] };" > commitlint.config.js 21 | 22 | - name: Lint 🚨 23 | run: npx commitlint --from HEAD~1 --to HEAD --verbose 24 | -------------------------------------------------------------------------------- /.github/workflows/install_test_ci.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: 
Install Test Python 3 | on: [push, pull_request, workflow_dispatch] 4 | 5 | jobs: 6 | Install_Test_Python: 7 | strategy: 8 | fail-fast: false 9 | matrix: 10 | os: [ubuntu-latest, windows-latest, macos-latest] 11 | python-version: ["3.8", "3.9", "3.10","3.11"] 12 | runs-on: ${{ matrix.os }} 13 | steps: 14 | - uses: actions/checkout@v3 15 | 16 | - name: Set up Python ${{ matrix.python-version }} 17 | if: matrix.os == 'macos-latest' || matrix.os == 'ubuntu-latest' 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | 22 | - name: Set up a virtual environment for Python ${{ matrix.python-version }} 23 | if: matrix.os == 'macos-latest' || matrix.os == 'ubuntu-latest' 24 | run: | 25 | python -m pip install --upgrade virtualenv 26 | virtualenv venv 27 | source venv/bin/activate 28 | 29 | - name: Install the base dependencies 30 | if: matrix.os == 'macos-latest' || matrix.os == 'ubuntu-latest' 31 | run: | 32 | source venv/bin/activate 33 | python -m pip install --upgrade poetry 34 | 35 | - name: Check the correctness of the project config 36 | if: matrix.os == 'macos-latest' || matrix.os == 'ubuntu-latest' 37 | run: | 38 | source venv/bin/activate 39 | poetry check 40 | 41 | - name: Install the package 42 | if: matrix.os == 'macos-latest' || matrix.os == 'ubuntu-latest' 43 | run: | 44 | source venv/bin/activate 45 | poetry install 46 | 47 | 48 | - name: Set up a virtual environment 49 | if: matrix.os == 'windows-latest' 50 | shell: pwsh 51 | run: | 52 | python -m pip install --upgrade virtualenv 53 | python -m virtualenv venv 54 | .\venv\Scripts\Activate.ps1 55 | 56 | - name: Install the base dependencies 57 | shell: pwsh 58 | if: matrix.os == 'windows-latest' 59 | run: | 60 | .\venv\Scripts\Activate.ps1 61 | python -m pip install --upgrade poetry 62 | 63 | - name: Check the correctness of the project config 64 | shell: pwsh 65 | if: matrix.os == 'windows-latest' 66 | run: | 67 | .\venv\Scripts\Activate.ps1 68 | poetry check 
69 | 70 | - name: Install the package 71 | shell: pwsh 72 | if: matrix.os == 'windows-latest' 73 | run: | 74 | .\venv\Scripts\Activate.ps1 75 | poetry install 76 | -------------------------------------------------------------------------------- /.github/workflows/release_to_pypi.yml: -------------------------------------------------------------------------------- 1 | name: MetaSeg Release to PyPi 2 | on: 3 | push: 4 | tags: 5 | - 'v[0-9]+.[0-9].+[0-9]+' 6 | 7 | # Allows you to run this workflow manually from the Actions tab 8 | workflow_dispatch: 9 | 10 | jobs: 11 | build-n-publish: 12 | name: Build and publish to PyPI 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Checkout source 17 | uses: actions/checkout@v3.5.3 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: "3.9" 23 | 24 | - name: Build source and wheel distributions 25 | run: | 26 | python -m pip install --upgrade build twine 27 | python -m build 28 | twine check --strict dist/* 29 | - name: Publish distribution to PyPI 30 | uses: pypa/gh-action-pypi-publish@release/v1 31 | with: 32 | user: __token__ 33 | password: ${{ secrets.PYPI_METASEG_API_KEY }} 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # .pth/pt files (pytorch) 132 | *.pth 133 | *.pt 134 | 135 | # Ignore OSX folder attributes 136 | .DS_Store 137 | 138 | # Idea project files 139 | .idea/ 140 | .idea_modules/ 141 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: end-of-file-fixer 8 | - id: trailing-whitespace 9 | - id: check-yaml 10 | - id: check-docstring-first 11 | - id: check-executables-have-shebangs 12 | - id: check-toml 13 | - id: check-case-conflict 14 | - id: check-added-large-files 15 | args: ['--maxkb=2048'] 16 | exclude: ^logo/ 17 | - id: detect-private-key 18 | - id: forbid-new-submodules 19 | - id: pretty-format-json 20 | args: ['--autofix', '--no-sort-keys', '--indent=4'] 21 | - id: end-of-file-fixer 22 | - id: mixed-line-ending 23 | - repo: https://github.com/asottile/pyupgrade 24 | rev: v3.19.1 25 | hooks: 26 | - id: pyupgrade 27 | args: 28 | - --py3-plus 29 | - --keep-runtime-typing 30 | - repo: https://github.com/astral-sh/ruff-pre-commit 31 | rev: v0.9.1 32 | hooks: 33 | - id: ruff 34 | args: [--fix, --exit-non-zero-on-fix] 35 | - repo: 
https://github.com/pycqa/isort 36 | rev: 5.13.2 37 | hooks: 38 | - id: isort 39 | name: isort (python) 40 | - id: isort 41 | name: isort (cython) 42 | types: [cython] 43 | - id: isort 44 | name: isort (pyi) 45 | types: [pyi] 46 | - repo: https://github.com/psf/black 47 | rev: 24.10.0 48 | hooks: 49 | - id: black 50 | - repo: https://github.com/PyCQA/bandit 51 | rev: '1.8.2' 52 | hooks: 53 | - id: bandit 54 | args: ["-c", "pyproject.toml"] 55 | additional_dependencies: ["bandit[toml]"] 56 | - repo: https://github.com/PyCQA/autoflake 57 | rev: v2.3.1 58 | hooks: 59 | - id: autoflake 60 | 61 | ci: 62 | autofix_commit_msg: "fix(pre_commit): 🎨 auto format pre-commit hooks" 63 | autoupdate_commit_msg: "fix(pre_commit): ⬆ pre_commit autoupdate" 64 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributor Covenant Code of Conduct 3 | 4 | ## Our Pledge 5 | 6 | We as members, contributors, and leaders pledge to make participation in our 7 | community a harassment-free experience for everyone, regardless of age, body 8 | size, visible or invisible disability, ethnicity, sex characteristics, gender 9 | identity and expression, level of experience, education, socio-economic status, 10 | nationality, personal appearance, race, caste, color, religion, or sexual 11 | identity and orientation. 12 | 13 | We pledge to act and interact in ways that contribute to an open, welcoming, 14 | diverse, inclusive, and healthy community. 
15 | 16 | ## Our Standards 17 | 18 | Examples of behavior that contributes to a positive environment for our 19 | community include: 20 | 21 | * Demonstrating empathy and kindness toward other people 22 | * Being respectful of differing opinions, viewpoints, and experiences 23 | * Giving and gracefully accepting constructive feedback 24 | * Accepting responsibility and apologizing to those affected by our mistakes, 25 | and learning from the experience 26 | * Focusing on what is best not just for us as individuals, but for the overall 27 | community 28 | 29 | Examples of unacceptable behavior include: 30 | 31 | * The use of sexualized language or imagery, and sexual attention or advances of 32 | any kind 33 | * Trolling, insulting or derogatory comments, and personal or political attacks 34 | * Public or private harassment 35 | * Publishing others' private information, such as a physical or email address, 36 | without their explicit permission 37 | * Other conduct which could reasonably be considered inappropriate in a 38 | professional setting 39 | 40 | ## Enforcement Responsibilities 41 | 42 | Community leaders are responsible for clarifying and enforcing our standards of 43 | acceptable behavior and will take appropriate and fair corrective action in 44 | response to any behavior that they deem inappropriate, threatening, offensive, 45 | or harmful. 46 | 47 | Community leaders have the right and responsibility to remove, edit, or reject 48 | comments, commits, code, wiki edits, issues, and other contributions that are 49 | not aligned to this Code of Conduct, and will communicate reasons for moderation 50 | decisions when appropriate. 51 | 52 | ## Scope 53 | 54 | This Code of Conduct applies within all community spaces, and also applies when 55 | an individual is officially representing the community in public spaces. 
56 | Examples of representing our community include using an official e-mail address, 57 | posting via an official social media account, or acting as an appointed 58 | representative at an online or offline event. 59 | 60 | ## Enforcement 61 | 62 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 63 | reported to the community leaders responsible for enforcement at 64 | [INSERT CONTACT METHOD]. 65 | All complaints will be reviewed and investigated promptly and fairly. 66 | 67 | All community leaders are obligated to respect the privacy and security of the 68 | reporter of any incident. 69 | 70 | ## Enforcement Guidelines 71 | 72 | Community leaders will follow these Community Impact Guidelines in determining 73 | the consequences for any action they deem in violation of this Code of Conduct: 74 | 75 | ### 1. Correction 76 | 77 | **Community Impact**: Use of inappropriate language or other behavior deemed 78 | unprofessional or unwelcome in the community. 79 | 80 | **Consequence**: A private, written warning from community leaders, providing 81 | clarity around the nature of the violation and an explanation of why the 82 | behavior was inappropriate. A public apology may be requested. 83 | 84 | ### 2. Warning 85 | 86 | **Community Impact**: A violation through a single incident or series of 87 | actions. 88 | 89 | **Consequence**: A warning with consequences for continued behavior. No 90 | interaction with the people involved, including unsolicited interaction with 91 | those enforcing the Code of Conduct, for a specified period of time. This 92 | includes avoiding interactions in community spaces as well as external channels 93 | like social media. Violating these terms may lead to a temporary or permanent 94 | ban. 95 | 96 | ### 3. Temporary Ban 97 | 98 | **Community Impact**: A serious violation of community standards, including 99 | sustained inappropriate behavior. 
100 | 101 | **Consequence**: A temporary ban from any sort of interaction or public 102 | communication with the community for a specified period of time. No public or 103 | private interaction with the people involved, including unsolicited interaction 104 | with those enforcing the Code of Conduct, is allowed during this period. 105 | Violating these terms may lead to a permanent ban. 106 | 107 | ### 4. Permanent Ban 108 | 109 | **Community Impact**: Demonstrating a pattern of violation of community 110 | standards, including sustained inappropriate behavior, harassment of an 111 | individual, or aggression toward or disparagement of classes of individuals. 112 | 113 | **Consequence**: A permanent ban from any sort of public interaction within the 114 | community. 115 | 116 | ## Attribution 117 | 118 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 119 | version 2.1, available at 120 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 121 | 122 | Community Impact Guidelines were inspired by 123 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 124 | 125 | For answers to common questions about this code of conduct, see the FAQ at 126 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at 127 | [https://www.contributor-covenant.org/translations][translations]. 
128 | 129 | [homepage]: https://www.contributor-covenant.org 130 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 131 | [Mozilla CoC]: https://github.com/mozilla/diversity 132 | [FAQ]: https://www.contributor-covenant.org/faq 133 | [translations]: https://www.contributor-covenant.org/translations 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 | MetaSeg: Packaged version of the Segment Anything repository 4 |

5 |
6 | teaser 7 |
8 | downloads 9 | HuggingFace Spaces 10 |
11 | 12 | 13 |

14 | 15 | Package version 16 | 17 | 18 | Download Count 19 | 20 | 21 | Supported Python versions 22 | 23 | 24 | Project Status 25 | 26 | 27 | pre-commit.ci 28 | 29 |

30 | 31 | 32 | This repo is a packaged version of the [segment-anything](https://github.com/facebookresearch/segment-anything) model. 33 | 34 | ### Installation 35 | ```bash 36 | pip install metaseg 37 | ``` 38 | 39 | ### Usage 40 | ```python 41 | from metaseg import SegAutoMaskPredictor, SegManualMaskPredictor 42 | 43 | # If gpu memory is not enough, reduce the points_per_side and points_per_batch. 44 | 45 | # For image 46 | results = SegAutoMaskPredictor().image_predict( 47 | source="image.jpg", 48 | model_type="vit_l", # vit_l, vit_h, vit_b 49 | points_per_side=16, 50 | points_per_batch=64, 51 | min_area=0, 52 | output_path="output.jpg", 53 | show=True, 54 | save=False, 55 | ) 56 | 57 | # For video 58 | results = SegAutoMaskPredictor().video_predict( 59 | source="video.mp4", 60 | model_type="vit_l", # vit_l, vit_h, vit_b 61 | points_per_side=16, 62 | points_per_batch=64, 63 | min_area=1000, 64 | output_path="output.mp4", 65 | ) 66 | 67 | # For manual box and point selection 68 | 69 | # For image 70 | results = SegManualMaskPredictor().image_predict( 71 | source="image.jpg", 72 | model_type="vit_l", # vit_l, vit_h, vit_b 73 | input_point=[[100, 100], [200, 200]], 74 | input_label=[0, 1], 75 | input_box=[100, 100, 200, 200], # or [[100, 100, 200, 200], [100, 100, 200, 200]] 76 | multimask_output=False, 77 | random_color=False, 78 | show=True, 79 | save=False, 80 | ) 81 | 82 | # For video 83 | 84 | results = SegManualMaskPredictor().video_predict( 85 | source="video.mp4", 86 | model_type="vit_l", # vit_l, vit_h, vit_b 87 | input_point=[0, 0, 100, 100], 88 | input_label=[0, 1], 89 | input_box=None, 90 | multimask_output=False, 91 | random_color=False, 92 | output_path="output.mp4", 93 | ) 94 | ``` 95 | 96 | ### [SAHI](https://github.com/obss/sahi) + Segment Anything 97 | 98 | ```bash 99 | pip install sahi metaseg 100 | ``` 101 | 102 | ```python 103 | from metaseg import SahiAutoSegmentation, sahi_sliced_predict 104 | 105 | image_path = "image.jpg" 106 | 
boxes = sahi_sliced_predict( 107 | image_path=image_path, 108 | detection_model_type="yolov5", # yolov8, detectron2, mmdetection, torchvision 109 | detection_model_path="yolov5l6.pt", 110 | conf_th=0.25, 111 | image_size=1280, 112 | slice_height=256, 113 | slice_width=256, 114 | overlap_height_ratio=0.2, 115 | overlap_width_ratio=0.2, 116 | ) 117 | 118 | SahiAutoSegmentation().image_predict( 119 | source=image_path, 120 | model_type="vit_b", 121 | input_box=boxes, 122 | multimask_output=False, 123 | random_color=False, 124 | show=True, 125 | save=False, 126 | ) 127 | ``` 128 | teaser 129 | 130 | ### [FalAI(Cloud GPU)](https://docs.fal.ai/fal-serverless/quickstart) + Segment Anything 131 | ```bash 132 | pip install metaseg fal_serverless 133 | fal-serverless auth login 134 | ``` 135 | 136 | ```python 137 | # For Auto Mask 138 | from metaseg import falai_automask_image 139 | 140 | image = falai_automask_image( 141 | image_path="image.jpg", 142 | model_type="vit_b", 143 | points_per_side=16, 144 | points_per_batch=32, 145 | min_area=0, 146 | ) 147 | image.show() # Show image 148 | image.save("output.jpg") # Save image 149 | 150 | # For Manual Mask 151 | from metaseg import falai_manuelmask_image 152 | 153 | image = falai_manuelmask_image( 154 | image_path="image.jpg", 155 | model_type="vit_b", 156 | input_point=[[100, 100], [200, 200]], 157 | input_label=[0, 1], 158 | input_box=[100, 100, 200, 200], # or [[100, 100, 200, 200], [100, 100, 200, 200]], 159 | multimask_output=False, 160 | random_color=False, 161 | ) 162 | image.show() # Show image 163 | image.save("output.jpg") # Save image 164 | ``` 165 | # Extra Features 166 | 167 | - [x] Support for Yolov5/8, Detectron2, Mmdetection, Torchvision models 168 | - [x] Support for video and web application(Huggingface Spaces) 169 | - [x] Support for manual single/multi box and point selection 170 | - [x] Support for pip installation 171 | - [x] Support for SAHI library 172 | - [x] Support for FalAI 173 | 
-------------------------------------------------------------------------------- /examples/image_cv_segautomask.py: -------------------------------------------------------------------------------- 1 | from cv2 import COLOR_BGR2RGB, Mat 2 | from cv2 import cvtColor as cv_cvtColor 3 | from cv2 import imread as cv_imread 4 | 5 | from metaseg import SegAutoMaskPredictor 6 | 7 | 8 | # If gpu memory is not enough, reduce the points_per_side and points_per_batch. 9 | def main(src: Mat) -> None: 10 | SegAutoMaskPredictor().image_predict( 11 | source=src, 12 | model_type="vit_l", # vit_l, vit_h, vit_b 13 | points_per_side=16, 14 | points_per_batch=64, 15 | min_area=0, 16 | output_path="output.jpg", 17 | show=True, 18 | save=False, 19 | ) 20 | 21 | 22 | if __name__ == "__main__": 23 | image = cv_imread("image.png") 24 | image = cv_cvtColor(image, COLOR_BGR2RGB) 25 | main(image) 26 | -------------------------------------------------------------------------------- /examples/image_sahi_slice_predict.py: -------------------------------------------------------------------------------- 1 | from metaseg import SahiAutoSegmentation, sahi_sliced_predict 2 | 3 | 4 | def main(src: str = "image.png") -> None: 5 | img_path = src 6 | boxes = sahi_sliced_predict( 7 | image_path=img_path, 8 | detection_model_type="yolov5", # yolov8, detectron2, mmdetection, torchvision 9 | detection_model_path="yolov5l6.pt", 10 | conf_th=0.25, 11 | image_size=1280, 12 | slice_height=256, 13 | slice_width=256, 14 | overlap_height_ratio=0.2, 15 | overlap_width_ratio=0.2, 16 | ) 17 | 18 | SahiAutoSegmentation().image_predict( 19 | source=img_path, 20 | model_type="vit_b", 21 | input_box=boxes, 22 | multimask_output=False, 23 | random_color=False, 24 | show=True, 25 | save=False, 26 | ) 27 | 28 | 29 | if __name__ == "__main__": 30 | main() 31 | -------------------------------------------------------------------------------- /examples/image_segautomask.py: 
-------------------------------------------------------------------------------- 1 | from metaseg import SegAutoMaskPredictor 2 | 3 | # If gpu memory is not enough, reduce the points_per_side and points_per_batch. 4 | 5 | 6 | def main(src: str = "image.png") -> None: 7 | SegAutoMaskPredictor().image_predict( 8 | source=src, 9 | model_type="vit_l", # vit_l, vit_h, vit_b 10 | points_per_side=16, 11 | points_per_batch=64, 12 | min_area=0, 13 | output_path="output.jpg", 14 | show=True, 15 | save=False, 16 | ) 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /examples/image_segmanualmask.py: -------------------------------------------------------------------------------- 1 | from metaseg import SegManualMaskPredictor 2 | 3 | # If gpu memory is not enough, reduce the points_per_side and points_per_batch. 4 | 5 | 6 | def main(src: str = "image.png") -> None: 7 | SegManualMaskPredictor().image_predict( 8 | source=src, 9 | model_type="vit_l", # vit_l, vit_h, vit_b 10 | input_point=[[100, 100], [200, 200]], 11 | input_label=[0, 1], 12 | input_box=[ 13 | 100, 14 | 100, 15 | 200, 16 | 200, 17 | ], # or [[100, 100, 200, 200], [100, 100, 200, 200]] 18 | multimask_output=False, 19 | random_color=False, 20 | show=True, 21 | save=False, 22 | ) 23 | 24 | 25 | if __name__ == "__main__": 26 | main() 27 | -------------------------------------------------------------------------------- /examples/video_segautomask.py: -------------------------------------------------------------------------------- 1 | from metaseg import SegAutoMaskPredictor 2 | 3 | # If gpu memory is not enough, reduce the points_per_side and points_per_batch. 
4 | 5 | 6 | # For video 7 | def main(src: str = "video.mp4") -> None: 8 | SegAutoMaskPredictor().video_predict( 9 | source=src, 10 | model_type="vit_l", # vit_l, vit_h, vit_b 11 | points_per_side=16, 12 | points_per_batch=64, 13 | min_area=1000, 14 | output_path="output.mp4", 15 | ) 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /examples/video_segmanualmask.py: -------------------------------------------------------------------------------- 1 | from metaseg import SegManualMaskPredictor 2 | 3 | # If gpu memory is not enough, reduce the points_per_side and points_per_batch. 4 | 5 | 6 | def main(src: str = "video.mp4") -> None: 7 | SegManualMaskPredictor().video_predict( 8 | source=src, 9 | model_type="vit_l", # vit_l, vit_h, vit_b 10 | input_point=[0, 0, 100, 100], 11 | input_label=[0, 1], 12 | input_box=None, 13 | multimask_output=False, 14 | random_color=False, 15 | output_path="output.mp4", 16 | ) 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /metaseg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import importlib.metadata as importlib_metadata 8 | 9 | from .falai_predictor import automask_image as automask_image 10 | from .falai_predictor import falai_automask_image as falai_automask_image 11 | from .falai_predictor import falai_manuelmask_image as falai_manuelmask_image 12 | from .falai_predictor import manuelmask_image as manuelmask_image 13 | from .sahi_predictor import SahiAutoSegmentation as SahiAutoSegmentation 14 | from .sahi_predictor import sahi_sliced_predict as sahi_sliced_predict 15 | from .sam_predictor import SegAutoMaskPredictor as SegAutoMaskPredictor 16 | from .sam_predictor import SegManualMaskPredictor as SegManualMaskPredictor 17 | 18 | try: 19 | # This will read version from pyproject.toml 20 | __version__ = importlib_metadata.version(__package__ or __name__) 21 | except importlib_metadata.PackageNotFoundError: 22 | __version__ = "development" 23 | -------------------------------------------------------------------------------- /metaseg/falai_predictor.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | from PIL import Image 4 | 5 | from .sam_predictor import SegAutoMaskPredictor, SegManualMaskPredictor 6 | from .utils.data_utils import load_server_image 7 | 8 | try: 9 | from fal_serverless import isolated 10 | except ImportError: 11 | raise ImportError( 12 | "Please install FalAI library using 'pip install fal_serverless'." 
@isolated(requirements=["metaseg"], keep_alive=1800, machine_type="GPU-T4")
def automask_image(
    data, model_type="vit_b", points_per_side=16, points_per_batch=32, min_area=0
):
    """Run automatic mask prediction on ``data`` and return the saved result bytes.

    ``data`` is handed to ``load_server_image``, which yields the input and
    output paths used below (presumably it materialises the raw image bytes
    on disk -- confirm in ``utils.data_utils``). The prediction is saved to
    the output path without being displayed, and that file's content is
    returned as ``bytes``.
    """
    src_path, dst_path = load_server_image(data)
    predictor = SegAutoMaskPredictor()
    predictor.image_predict(
        source=src_path,
        model_type=model_type,
        points_per_side=points_per_side,
        points_per_batch=points_per_batch,
        min_area=min_area,
        output_path=dst_path,
        show=False,
        save=True,
    )
    with open(dst_path, "rb") as rendered:
        return rendered.read()


@isolated(requirements=["metaseg"], keep_alive=1800, machine_type="GPU-T4")
def manuelmask_image(
    data,
    model_type="vit_b",
    input_point=[[100, 100], [200, 200]],
    input_label=[0, 1],
    input_box=[100, 100, 200, 200],
    multimask_output=False,
    random_color=False,
    min_area=0,
):
    """Run prompt-driven mask prediction on ``data`` and return the saved result bytes.

    Point, label and box prompts are forwarded unchanged to
    ``SegManualMaskPredictor.image_predict``. The prediction is saved to the
    output path supplied by ``load_server_image`` and that file's content is
    returned as ``bytes``.
    """
    src_path, dst_path = load_server_image(data)
    predictor = SegManualMaskPredictor()
    predictor.image_predict(
        source=src_path,
        model_type=model_type,
        input_point=input_point,
        input_label=input_label,
        input_box=input_box,
        multimask_output=multimask_output,
        random_color=random_color,
        min_area=min_area,
        output_path=dst_path,
        show=False,
        save=True,
    )
    with open(dst_path, "rb") as rendered:
        return rendered.read()


def falai_automask_image(
    image_path, model_type="vit_b", points_per_side=16, points_per_batch=32, min_area=0
):
    """Read a local image, segment it via ``automask_image`` and return a PIL image."""
    with open(image_path, "rb") as src:
        payload = src.read()

    encoded = automask_image(
        data=payload,
        model_type=model_type,
        points_per_side=points_per_side,
        points_per_batch=points_per_batch,
        min_area=min_area,
    )
    # The isolated function returns the encoded output file; decode it lazily.
    return Image.open(BytesIO(encoded))


def falai_manuelmask_image(
    image_path,
    model_type="vit_b",
    input_point=[[100, 100], [200, 200]],
    input_label=[0, 1],
    input_box=[100, 100, 200, 200],
    multimask_output=False,
    random_color=False,
    min_area=0,
):
    """Read a local image, segment it via ``manuelmask_image`` and return a PIL image."""
    with open(image_path, "rb") as src:
        payload = src.read()

    encoded = manuelmask_image(
        data=payload,
        model_type=model_type,
        input_point=input_point,
        input_label=input_label,
        input_box=input_box,
        multimask_output=multimask_output,
        random_color=random_color,
        min_area=min_area,
    )
    # The isolated function returns the encoded output file; decode it lazily.
    return Image.open(BytesIO(encoded))
    def __init__(
        self,
        model: Sam,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
    ) -> None:
        """
        Using a SAM model, generates masks for the entire image.
        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for SAM with a ViT-H backbone.

        Arguments:
          model (Sam): The SAM model to use for mask prediction.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to binarize
            the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in the
            list is used in the nth crop layer. Exclusive with points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area smaller
            than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be 'binary_mask',
            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
            For large resolutions, 'binary_mask' may consume large amounts of
            memory.

        Raises:
          AssertionError: If both or neither of points_per_side and point_grids
            are given, or if output_mode is not one of the supported values.
        """

        # Exactly one of the two point sources may be supplied (XOR on "is None").
        assert (points_per_side is None) != (
            point_grids is None
        ), "Exactly one of points_per_side or point_grid must be provided."
        if points_per_side is not None:
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            # Only reachable when the assert above is disabled (python -O).
            raise ValueError("Can't have both points_per_side and point_grid be None.")

        assert output_mode in [
            "binary_mask",
            "uncompressed_rle",
            "coco_rle",
        ], f"Unknown output_mode {output_mode}."
        # The two imports below are unused here (hence noqa: F401); importing
        # them at construction time surfaces a missing optional dependency
        # before any prediction is run.
        if output_mode == "coco_rle":
            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401

        if min_mask_region_area > 0:
            import cv2  # type: ignore # noqa: F401

        self.predictor = SamPredictor(model)
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode
Otherwise, 149 | is a dictionary containing the RLE. 150 | bbox (list(float)): The box around the mask, in XYWH format. 151 | area (int): The area in pixels of the mask. 152 | predicted_iou (float): The model's own prediction of the mask's 153 | quality. This is filtered by the pred_iou_thresh parameter. 154 | point_coords (list(list(float))): The point coordinates input 155 | to the model to generate this mask. 156 | stability_score (float): A measure of the mask's quality. This 157 | is filtered on using the stability_score_thresh parameter. 158 | crop_box (list(float)): The crop of the image used to generate 159 | the mask, given in XYWH format. 160 | """ 161 | 162 | # Generate masks 163 | mask_data = self._generate_masks(image) 164 | 165 | # Filter small disconnected regions and holes in masks 166 | if self.min_mask_region_area > 0: 167 | mask_data = self.postprocess_small_regions( 168 | mask_data, 169 | self.min_mask_region_area, 170 | max(self.box_nms_thresh, self.crop_nms_thresh), 171 | ) 172 | 173 | # Encode masks 174 | if self.output_mode == "coco_rle": 175 | mask_data["segmentations"] = [ 176 | coco_encode_rle(rle) for rle in mask_data["rles"] 177 | ] 178 | elif self.output_mode == "binary_mask": 179 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 180 | else: 181 | mask_data["segmentations"] = mask_data["rles"] 182 | 183 | # Write mask records 184 | curr_anns = [] 185 | for idx in range(len(mask_data["segmentations"])): 186 | ann = { 187 | "segmentation": mask_data["segmentations"][idx], 188 | "area": area_from_rle(mask_data["rles"][idx]), 189 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 190 | "predicted_iou": mask_data["iou_preds"][idx].item(), 191 | "point_coords": [mask_data["points"][idx].tolist()], 192 | "stability_score": mask_data["stability_score"][idx].item(), 193 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 194 | } 195 | curr_anns.append(ann) 196 | 197 | return curr_anns 
198 | 199 | def _generate_masks(self, image: np.ndarray) -> MaskData: 200 | orig_size = image.shape[:2] 201 | crop_boxes, layer_idxs = generate_crop_boxes( 202 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 203 | ) 204 | 205 | # Iterate over image crops 206 | data = MaskData() 207 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 208 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) 209 | data.cat(crop_data) 210 | 211 | # Remove duplicate masks between crops 212 | if len(crop_boxes) > 1: 213 | # Prefer masks from smaller crops 214 | scores = 1 / box_area(data["crop_boxes"]) 215 | scores = scores.to(data["boxes"].device) 216 | keep_by_nms = batched_nms( 217 | data["boxes"].float(), 218 | scores, 219 | torch.zeros(len(data["boxes"])), # categories 220 | iou_threshold=self.crop_nms_thresh, 221 | ) 222 | data.filter(keep_by_nms) 223 | 224 | data.to_numpy() 225 | return data 226 | 227 | def _process_crop( 228 | self, 229 | image: np.ndarray, 230 | crop_box: List[int], 231 | crop_layer_idx: int, 232 | orig_size: Tuple[int, ...], 233 | ) -> MaskData: 234 | # Crop the image and calculate embeddings 235 | x0, y0, x1, y1 = crop_box 236 | cropped_im = image[y0:y1, x0:x1, :] 237 | cropped_im_size = cropped_im.shape[:2] 238 | self.predictor.set_image(cropped_im) 239 | 240 | # Get points for this crop 241 | points_scale = np.array(cropped_im_size)[None, ::-1] 242 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 243 | 244 | # Generate masks for this crop in batches 245 | data = MaskData() 246 | for (points,) in batch_iterator(self.points_per_batch, points_for_image): 247 | batch_data = self._process_batch( 248 | points, cropped_im_size, crop_box, orig_size 249 | ) 250 | data.cat(batch_data) 251 | del batch_data 252 | self.predictor.reset_image() 253 | 254 | # Remove duplicates within this crop. 
255 | keep_by_nms = batched_nms( 256 | data["boxes"].float(), 257 | data["iou_preds"], 258 | torch.zeros(len(data["boxes"])), # categories 259 | iou_threshold=self.box_nms_thresh, 260 | ) 261 | data.filter(keep_by_nms) 262 | 263 | # Return to the original image frame 264 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 265 | data["points"] = uncrop_points(data["points"], crop_box) 266 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 267 | 268 | return data 269 | 270 | def _process_batch( 271 | self, 272 | points: np.ndarray, 273 | im_size: Tuple[int, ...], 274 | crop_box: List[int], 275 | orig_size: Tuple[int, ...], 276 | ) -> MaskData: 277 | orig_h, orig_w = orig_size 278 | 279 | # Run model on this batch 280 | transformed_points = self.predictor.transform.apply_coords(points, im_size) 281 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device) 282 | in_labels = torch.ones( 283 | in_points.shape[0], dtype=torch.int, device=in_points.device 284 | ) 285 | masks, iou_preds, _ = self.predictor.predict_torch( 286 | in_points[:, None, :], 287 | in_labels[:, None], 288 | multimask_output=True, 289 | return_logits=True, 290 | ) 291 | 292 | # Serialize predictions and store in MaskData 293 | data = MaskData( 294 | masks=masks.flatten(0, 1), 295 | iou_preds=iou_preds.flatten(0, 1), 296 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), 297 | ) 298 | del masks 299 | 300 | # Filter by predicted IoU 301 | if self.pred_iou_thresh > 0.0: 302 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 303 | data.filter(keep_mask) 304 | 305 | # Calculate stability score 306 | data["stability_score"] = calculate_stability_score( 307 | data["masks"], 308 | self.predictor.model.mask_threshold, 309 | self.stability_score_offset, 310 | ) 311 | if self.stability_score_thresh > 0.0: 312 | keep_mask = data["stability_score"] >= self.stability_score_thresh 313 | data.filter(keep_mask) 314 | 315 | # Threshold 
masks and calculate boxes 316 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold 317 | data["boxes"] = batched_mask_to_box(data["masks"]) 318 | 319 | # Filter boxes that touch crop boundaries 320 | keep_mask = ~is_box_near_crop_edge( 321 | data["boxes"], crop_box, [0, 0, orig_w, orig_h] 322 | ) 323 | if not torch.all(keep_mask): 324 | data.filter(keep_mask) 325 | 326 | # Compress to RLE 327 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 328 | data["rles"] = mask_to_rle_pytorch(data["masks"]) 329 | del data["masks"] 330 | 331 | return data 332 | 333 | @staticmethod 334 | def postprocess_small_regions( 335 | mask_data: MaskData, min_area: int, nms_thresh: float 336 | ) -> MaskData: 337 | """ 338 | Removes small disconnected regions and holes in masks, then reruns 339 | box NMS to remove any new duplicates. 340 | 341 | Edits mask_data in place. 342 | 343 | Requires open-cv as a dependency. 344 | """ 345 | if len(mask_data["rles"]) == 0: 346 | return mask_data 347 | 348 | # Filter small disconnected regions and holes 349 | new_masks = [] 350 | scores = [] 351 | for rle in mask_data["rles"]: 352 | mask = rle_to_mask(rle) 353 | 354 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 355 | unchanged = not changed 356 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 357 | unchanged = unchanged and not changed 358 | 359 | new_masks.append(torch.as_tensor(mask).unsqueeze(0)) 360 | # Give score=0 to changed masks and score=1 to unchanged masks 361 | # so NMS will prefer ones that didn't need postprocessing 362 | scores.append(float(unchanged)) 363 | 364 | # Recalculate boxes and remove any new duplicates 365 | masks = torch.cat(new_masks, dim=0) 366 | boxes = batched_mask_to_box(masks) 367 | keep_by_nms = batched_nms( 368 | boxes.float(), 369 | torch.as_tensor(scores), 370 | torch.zeros(len(boxes)), # categories 371 | iou_threshold=nms_thresh, 372 | ) 373 | 374 | # Only recalculate RLEs for 
def build_sam_vit_h(checkpoint=None):
    """Build a SAM model with the ViT-H image encoder configuration.

    Arguments:
      checkpoint (str or None): Path to a state-dict file to load, or None
        to leave the weights as constructed.
    """
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
    )


# The default builder is the ViT-H variant.
build_sam = build_sam_vit_h


def build_sam_vit_l(checkpoint=None):
    """Build a SAM model with the ViT-L image encoder configuration.

    Arguments:
      checkpoint (str or None): Path to a state-dict file to load, or None
        to leave the weights as constructed.
    """
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
    )


def build_sam_vit_b(checkpoint=None):
    """Build a SAM model with the ViT-B image encoder configuration.

    Arguments:
      checkpoint (str or None): Path to a state-dict file to load, or None
        to leave the weights as constructed.
    """
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )


# NOTE: an earlier revision rebound ``build_sam_vit_h`` to a copy of this
# mapping, which replaced the builder function with a dict for every importer.
# The mapping lives only under ``sam_model_registry``.
sam_model_registry = {
    "default": build_sam,
    "vit_h": build_sam,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    """Assemble a ``Sam`` model from its encoder hyper-parameters.

    The prompt embedding size (256), input image size (1024) and ViT patch
    size (16) are fixed; only the image encoder's width, depth, head count
    and global-attention block indexes vary between variants. The model is
    put in eval mode before an optional checkpoint is loaded.

    Arguments:
      encoder_embed_dim (int): Embedding dimension of the image encoder.
      encoder_depth (int): Number of blocks in the image encoder.
      encoder_num_heads (int): Number of attention heads per encoder block.
      encoder_global_attn_indexes (list(int)): Passed to the encoder as
        ``global_attn_indexes``.
      checkpoint (str or None): Path to a file readable by ``torch.load``
        containing the model state dict; skipped when None.

    Returns:
      Sam: The constructed model.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    return sam
6 | 7 | from typing import Optional, Tuple 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from metaseg.modeling import Sam 13 | from metaseg.utils.transforms import ResizeLongestSide 14 | 15 | 16 | class SamPredictor: 17 | def __init__( 18 | self, 19 | sam_model: Sam, 20 | ) -> None: 21 | """ 22 | Uses SAM to calculate the image embedding for an image, and then 23 | allow repeated, efficient mask prediction given prompts. 24 | 25 | Arguments: 26 | sam_model (Sam): The model to use for mask prediction. 27 | """ 28 | super().__init__() 29 | self.model = sam_model 30 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 31 | self.reset_image() 32 | 33 | def set_image( 34 | self, 35 | image: np.ndarray, 36 | image_format: str = "RGB", 37 | ) -> None: 38 | """ 39 | Calculates the image embeddings for the provided image, allowing 40 | masks to be predicted with the 'predict' method. 41 | 42 | Arguments: 43 | image (np.ndarray): The image for calculating masks. Expects an 44 | image in HWC uint8 format, with pixel values in [0, 255]. 45 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 46 | """ 47 | assert image_format in [ 48 | "RGB", 49 | "BGR", 50 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 
51 | if image_format != self.model.image_format: 52 | image = image[..., ::-1] 53 | 54 | # Transform the image to the form expected by the model 55 | input_image = self.transform.apply_image(image) 56 | input_image_torch = torch.as_tensor(input_image, device=self.device) 57 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[ 58 | None, :, :, : 59 | ] 60 | 61 | self.set_torch_image(input_image_torch, image.shape[:2]) 62 | 63 | @torch.no_grad() 64 | def set_torch_image( 65 | self, 66 | transformed_image: torch.Tensor, 67 | original_image_size: Tuple[int, ...], 68 | ) -> None: 69 | """ 70 | Calculates the image embeddings for the provided image, allowing 71 | masks to be predicted with the 'predict' method. Expects the input 72 | image to be already transformed to the format expected by the model. 73 | 74 | Arguments: 75 | transformed_image (torch.Tensor): The input image, with shape 76 | 1x3xHxW, which has been transformed with ResizeLongestSide. 77 | original_image_size (tuple(int, int)): The size of the image 78 | before transformation, in (H, W) format. 79 | """ 80 | assert ( 81 | len(transformed_image.shape) == 4 82 | and transformed_image.shape[1] == 3 83 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 84 | ), ( 85 | f"set_torch_image input must be BCHW with long side " 86 | f"{self.model.image_encoder.img_size}." 
87 | ) 88 | self.reset_image() 89 | 90 | self.original_size = original_image_size 91 | self.input_size = tuple(transformed_image.shape[-2:]) 92 | input_image = self.model.preprocess(transformed_image) 93 | self.features = self.model.image_encoder(input_image) 94 | self.is_image_set = True 95 | 96 | def predict( 97 | self, 98 | point_coords: Optional[np.ndarray] = None, 99 | point_labels: Optional[np.ndarray] = None, 100 | box: Optional[np.ndarray] = None, 101 | mask_input: Optional[np.ndarray] = None, 102 | multimask_output: bool = True, 103 | return_logits: bool = False, 104 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 105 | """ 106 | Predict masks for the given input prompts, using the currently set image. 107 | 108 | Arguments: 109 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 110 | model. Each point is in (X,Y) in pixels. 111 | point_labels (np.ndarray or None): A length N array of labels for the 112 | point prompts. 1 indicates a foreground point and 0 indicates a 113 | background point. 114 | box (np.ndarray or None): A length 4 array given a box prompt to the 115 | model, in XYXY format. 116 | mask_input (np.ndarray): A low resolution mask input to the model, typically 117 | coming from a previous prediction iteration. Has form 1xHxW, where 118 | for SAM, H=W=256. 119 | multimask_output (bool): If true, the model will return three masks. 120 | For ambiguous input prompts (such as a single click), this will often 121 | produce better masks than a single prediction. If only a single 122 | mask is needed, the model's predicted quality score can be used 123 | to select the best mask. For non-ambiguous prompts, such as multiple 124 | input prompts, multimask_output=False can give better results. 125 | return_logits (bool): If true, returns un-thresholded masks logits 126 | instead of a binary mask. 
127 | 128 | Returns: 129 | (np.ndarray): The output masks in CxHxW format, where C is the 130 | number of masks, and (H, W) is the original image size. 131 | (np.ndarray): An array of length C containing the model's 132 | predictions for the quality of each mask. 133 | (np.ndarray): An array of shape CxHxW, where C is the number 134 | of masks and H=W=256. These low resolution logits can be passed to 135 | a subsequent iteration as mask input. 136 | """ 137 | if not self.is_image_set: 138 | raise RuntimeError( 139 | "An image must be set with .set_image(...) before mask prediction." 140 | ) 141 | 142 | # Transform input prompts 143 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 144 | if point_coords is not None: 145 | assert ( 146 | point_labels is not None 147 | ), "point_labels must be supplied if point_coords is supplied." 148 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 149 | coords_torch = torch.as_tensor( 150 | point_coords, dtype=torch.float, device=self.device 151 | ) 152 | labels_torch = torch.as_tensor( 153 | point_labels, dtype=torch.int, device=self.device 154 | ) 155 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 156 | if box is not None: 157 | box = self.transform.apply_boxes(box, self.original_size) 158 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 159 | box_torch = box_torch[None, :] 160 | if mask_input is not None: 161 | mask_input_torch = torch.as_tensor( 162 | mask_input, dtype=torch.float, device=self.device 163 | ) 164 | mask_input_torch = mask_input_torch[None, :, :, :] 165 | 166 | masks, iou_predictions, low_res_masks = self.predict_torch( 167 | coords_torch, 168 | labels_torch, 169 | box_torch, 170 | mask_input_torch, 171 | multimask_output, 172 | return_logits=return_logits, 173 | ) 174 | 175 | masks = masks[0].detach().cpu().numpy() 176 | iou_predictions = iou_predictions[0].detach().cpu().numpy() 177 | 
@torch.no_grad()
def predict_torch(
    self,
    point_coords: Optional[torch.Tensor],
    point_labels: Optional[torch.Tensor],
    boxes: Optional[torch.Tensor] = None,
    mask_input: Optional[torch.Tensor] = None,
    multimask_output: bool = True,
    return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Runs the prompt encoder and mask decoder on the cached image features
    for a batch of torch prompts. The prompts are forwarded to the prompt
    encoder unchanged, so they are expected to already be in the frame of
    the transformed input image (e.g. via ResizeLongestSide).

    Arguments:
      point_coords (torch.Tensor or None): A BxNx2 tensor of point prompts,
        each point given as (X,Y) in pixels.
      point_labels (torch.Tensor or None): A BxN tensor of point labels.
        1 indicates a foreground point and 0 a background point. Only
        used when point_coords is given.
      boxes (torch.Tensor or None): A Bx4 tensor of box prompts in XYXY
        format.
      mask_input (torch.Tensor or None): A low resolution mask prompt,
        typically the low resolution output of a previous prediction.
        Has form Bx1xHxW, where for SAM, H=W=256.
      multimask_output (bool): Forwarded to the mask decoder. If true,
        the model returns several candidate masks, which often helps for
        ambiguous prompts such as a single click; if false, a single
        mask is returned.
      return_logits (bool): If true, the upscaled masks are returned as
        un-thresholded logits instead of being binarized.

    Returns:
      (torch.Tensor): The output masks in BxCxHxW format, where C is the
        number of masks, and (H, W) is the original image size. Boolean
        unless return_logits is set.
      (torch.Tensor): A tensor of shape BxC containing the model's
        predictions for the quality of each mask.
      (torch.Tensor): The decoder's low resolution mask logits, shape
        BxCxHxW. These can be passed to a subsequent iteration as
        mask input.

    Raises:
      RuntimeError: If no image has been set yet.
    """
    if not self.is_image_set:
        raise RuntimeError(
            "An image must be set with .set_image(...) before mask prediction."
        )

    # The prompt encoder takes points as a (coords, labels) pair, or None.
    points = None if point_coords is None else (point_coords, point_labels)

    sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
        points=points,
        boxes=boxes,
        masks=mask_input,
    )

    low_res_masks, iou_predictions = self.model.mask_decoder(
        image_embeddings=self.features,
        image_pe=self.model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multimask_output,
    )

    # Bring the decoder output back to the original image resolution.
    masks = self.model.postprocess_masks(
        low_res_masks, self.input_size, self.original_size
    )
    if return_logits:
        return masks, iou_predictions, low_res_masks
    return masks > self.model.mask_threshold, iou_predictions, low_res_masks
def reset_image(self) -> None:
    """
    Resets the currently set image.

    Marks the predictor as having no image and clears everything cached
    for the previous one, so no embedding or size information survives
    into the next image.
    """
    self.is_image_set = False
    self.features = None
    # These are the size attributes the predictor writes when an image is
    # set and reads back when predicting; clear them so stale sizes from
    # the previous image cannot be picked up.
    self.original_size = None
    self.input_size = None
    # Older attribute names, kept so any external code that still looks at
    # them continues to find them.
    self.orig_h = None
    self.orig_w = None
    self.input_h = None
    self.input_w = None
class MLPBlock(nn.Module):
    """Two-layer feed-forward block: linear -> activation -> linear.

    The block maps `embedding_dim` features to `mlp_dim` hidden features
    and back to `embedding_dim`.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Args:
            embedding_dim (int): Size of the input and output features.
            mlp_dim (int): Size of the hidden layer.
            act (nn.Module): Activation class, instantiated with no arguments.
        """
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.lin1(x)
        activated = self.act(hidden)
        return self.lin2(activated)


# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    """Layer normalization over dimension 1 with a per-channel affine.

    The learned `weight` and `bias` are broadcast as (C, 1, 1), so the
    input is expected to carry its channels in dimension 1 followed by
    two trailing dimensions.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        """
        Args:
            num_channels (int): Number of channels; size of weight and bias.
            eps (float): Added to the variance before the square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        # Biased variance over dimension 1.
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.weight[:, None, None] * normalized + self.bias[:, None, None]
# This class and its supporting functions below lightly adapted from the
# ViTDet backbone available at:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py
class ImageEncoderViT(nn.Module):
    """ViT image encoder.

    Pipeline: patch embedding, optional absolute positional embedding, a
    stack of transformer blocks, then a convolutional neck that maps the
    tokens to `out_chans` channels.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT, i.e. the number of blocks.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Number of channels produced by the neck.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute
                positional embeddings.
            use_rel_pos (bool): If True, add relative
                positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize
                relative positional parameters.
            window_size (int): Window size for window
                attention blocks.
            global_attn_indexes (tuple): Indexes of the blocks that
                use global attention instead of window attention.
        """
        super().__init__()
        self.img_size = img_size

        # Number of patches along each side of the (square) input.
        grid_size = img_size // patch_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Absolute positional embedding sized for the pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, grid_size, grid_size, embed_dim)
            )

        self.blocks = nn.ModuleList(
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                # A window size of 0 selects global attention in the block.
                window_size=0 if layer_idx in global_attn_indexes else window_size,
                input_size=(grid_size, grid_size),
            )
            for layer_idx in range(depth)
        )

        self.neck = nn.Sequential(
            nn.Conv2d(embed_dim, out_chans, kernel_size=1, bias=False),
            LayerNorm2d(out_chans),
            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
            LayerNorm2d(out_chans),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = self.patch_embed(x)
        if self.pos_embed is not None:
            tokens = tokens + self.pos_embed

        for layer in self.blocks:
            tokens = layer(tokens)

        # The neck is convolutional: move channels from last to dimension 1.
        return self.neck(tokens.permute(0, 3, 1, 2))
class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable
                bias to query, key, value.
            use_rel_pos (bool): If True, add relative
                positional embeddings to the attention map.
            rel_pos_zero_init (bool): Accepted for interface
                compatibility; this constructor does not read it and
                always creates the relative position tables as zeros.
            input_size (tuple(int, int) or None): Input resolution (H, W)
                used to size the relative positional parameters.
                Required when use_rel_pos is True.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if not self.use_rel_pos:
            return

        assert (
            input_size is not None
        ), "Input size must be provided if using relative positional encoding."
        # One table per axis with 2 * size - 1 rows.
        self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
        self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        num_tokens = H * W

        # qkv with shape (3, B, nHead, H * W, C)
        qkv = (
            self.qkv(x)
            .reshape(B, num_tokens, 3, self.num_heads, -1)
            .permute(2, 0, 3, 1, 4)
        )
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, num_tokens, -1).unbind(0)

        scores = (q * self.scale) @ k.transpose(-2, -1)
        if self.use_rel_pos:
            scores = add_decomposed_rel_pos(
                scores, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
            )
        weights = scores.softmax(dim=-1)

        # Merge the heads back into the channel dimension: (B, H, W, C).
        context = weights @ v
        context = context.view(B, self.num_heads, H, W, -1)
        context = context.permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        return self.proj(context)
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Look up relative positional embeddings for every (query, key) position
    pair along one axis.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Tensor of shape (q_size, k_size, C) holding the embedding selected
        for each query/key position pair.
    """
    table_len = int(2 * max(q_size, k_size) - 1)

    if rel_pos.shape[0] == table_len:
        table = rel_pos
    else:
        # The table has the wrong length for these sizes: linearly
        # interpolate it along its length to table_len rows.
        channels_first = rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1)
        stretched = F.interpolate(channels_first, size=table_len, mode="linear")
        table = stretched.reshape(-1, table_len).permute(1, 0)

    # When q and k differ in size, the shorter axis is scaled up so both
    # sets of coordinates span the same range.
    q_step = max(k_size / q_size, 1.0)
    k_step = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_step
    k_coords = torch.arange(k_size)[None, :] * k_step
    # The offset shifts the differences so the smallest index is 0.
    offset = (k_size - 1) * k_step
    relative_coords = (q_coords - k_coords) + offset

    return table[relative_coords.long()]
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.

    A single convolution projects the image to `embed_dim` channels; the
    result is returned with the channels moved to the last dimension.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # B C H W -> B H W C
        return self.proj(x).permute(0, 2, 3, 1)
6 | 7 | from typing import List, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | from metaseg.modeling.common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | tranformer architecture. 30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d( 55 | transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 56 | ), 57 | LayerNorm2d(transformer_dim // 4), 58 | activation(), 59 | nn.ConvTranspose2d( 60 | transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 61 | ), 62 | activation(), 63 | ) 64 | self.output_hypernetworks_mlps = nn.ModuleList( 65 | [ 66 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 67 | for i in range(self.num_mask_tokens) 68 | 
] 69 | ) 70 | 71 | self.iou_prediction_head = MLP( 72 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 73 | ) 74 | 75 | def forward( 76 | self, 77 | image_embeddings: torch.Tensor, 78 | image_pe: torch.Tensor, 79 | sparse_prompt_embeddings: torch.Tensor, 80 | dense_prompt_embeddings: torch.Tensor, 81 | multimask_output: bool, 82 | ) -> Tuple[torch.Tensor, torch.Tensor]: 83 | """ 84 | Predict masks given image and prompt embeddings. 85 | 86 | Arguments: 87 | image_embeddings (torch.Tensor): the embeddings from the image encoder 88 | image_pe (torch.Tensor): positional encoding 89 | with the shape of image_embeddings 90 | sparse_prompt_embeddings (torch.Tensor): the embeddings 91 | of the points and boxes 92 | dense_prompt_embeddings (torch.Tensor): the embeddings 93 | of the mask inputs 94 | multimask_output (bool): Whether to return 95 | multiple masks or a single mask. 96 | 97 | Returns: 98 | torch.Tensor: batched predicted masks 99 | torch.Tensor: batched predictions of mask quality 100 | """ 101 | masks, iou_pred = self.predict_masks( 102 | image_embeddings=image_embeddings, 103 | image_pe=image_pe, 104 | sparse_prompt_embeddings=sparse_prompt_embeddings, 105 | dense_prompt_embeddings=dense_prompt_embeddings, 106 | ) 107 | 108 | # Select the correct mask or masks for outptu 109 | if multimask_output: 110 | mask_slice = slice(1, None) 111 | else: 112 | mask_slice = slice(0, 1) 113 | masks = masks[:, mask_slice, :, :] 114 | iou_pred = iou_pred[:, mask_slice] 115 | 116 | # Prepare output 117 | return masks, iou_pred 118 | 119 | def predict_masks( 120 | self, 121 | image_embeddings: torch.Tensor, 122 | image_pe: torch.Tensor, 123 | sparse_prompt_embeddings: torch.Tensor, 124 | dense_prompt_embeddings: torch.Tensor, 125 | ) -> Tuple[torch.Tensor, torch.Tensor]: 126 | """Predicts masks. 
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    """Stack of linear layers with ReLU between them.

    Every layer except the last is followed by a ReLU; the last layer's
    output is returned as is, or passed through a sigmoid when
    `sigmoid_output` is set.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        """
        Arguments:
          input_dim (int): size of the input features
          hidden_dim (int): size of every intermediate layer
          output_dim (int): size of the output features
          num_layers (int): total number of linear layers
          sigmoid_output (bool): whether to apply a sigmoid to the output
        """
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        # Layer sizes chain as input -> hidden ... hidden -> output.
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layers in order and return the final activations."""
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            # torch.sigmoid rather than F.sigmoid: the functional alias is
            # deprecated and emits a warning on older PyTorch releases.
            x = torch.sigmoid(x)
        return x
38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [ 47 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) 48 | ] 49 | self.point_embeddings = nn.ModuleList(point_embeddings) 50 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 51 | 52 | self.mask_input_size = ( 53 | 4 * image_embedding_size[0], 54 | 4 * image_embedding_size[1], 55 | ) 56 | self.mask_downscaling = nn.Sequential( 57 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 58 | LayerNorm2d(mask_in_chans // 4), 59 | activation(), 60 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 61 | LayerNorm2d(mask_in_chans), 62 | activation(), 63 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 64 | ) 65 | self.no_mask_embed = nn.Embedding(1, embed_dim) 66 | 67 | def get_dense_pe(self) -> torch.Tensor: 68 | """ 69 | Returns the positional encoding used to encode point prompts, 70 | applied to a dense set of points the shape of the image encoding. 
71 | 72 | Returns: 73 | torch.Tensor: Positional encoding with shape 74 | 1x(embed_dim)x(embedding_h)x(embedding_w) 75 | """ 76 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 77 | 78 | def _embed_points( 79 | self, 80 | points: torch.Tensor, 81 | labels: torch.Tensor, 82 | pad: bool, 83 | ) -> torch.Tensor: 84 | """Embeds point prompts.""" 85 | points = points + 0.5 # Shift to center of pixel 86 | if pad: 87 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 88 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 89 | points = torch.cat([points, padding_point], dim=1) 90 | labels = torch.cat([labels, padding_label], dim=1) 91 | point_embedding = self.pe_layer.forward_with_coords( 92 | points, self.input_image_size 93 | ) 94 | point_embedding[labels == -1] = 0.0 95 | point_embedding[labels == -1] += self.not_a_point_embed.weight 96 | point_embedding[labels == 0] += self.point_embeddings[0].weight 97 | point_embedding[labels == 1] += self.point_embeddings[1].weight 98 | return point_embedding 99 | 100 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 101 | """Embeds box prompts.""" 102 | boxes = boxes + 0.5 # Shift to center of pixel 103 | coords = boxes.reshape(-1, 2, 2) 104 | corner_embedding = self.pe_layer.forward_with_coords( 105 | coords, self.input_image_size 106 | ) 107 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 108 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 109 | return corner_embedding 110 | 111 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 112 | """Embeds mask inputs.""" 113 | mask_embedding = self.mask_downscaling(masks) 114 | return mask_embedding 115 | 116 | def _get_batch_size( 117 | self, 118 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 119 | boxes: Optional[torch.Tensor], 120 | masks: Optional[torch.Tensor], 121 | ) -> int: 122 | """ 123 | Gets the batch size of the output given the batch size of the input 
prompts. 124 | """ 125 | if points is not None: 126 | return points[0].shape[0] 127 | elif boxes is not None: 128 | return boxes.shape[0] 129 | elif masks is not None: 130 | return masks.shape[0] 131 | else: 132 | return 1 133 | 134 | def _get_device(self) -> torch.device: 135 | return self.point_embeddings[0].weight.device 136 | 137 | def forward( 138 | self, 139 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 140 | boxes: Optional[torch.Tensor], 141 | masks: Optional[torch.Tensor], 142 | ) -> Tuple[torch.Tensor, torch.Tensor]: 143 | """ 144 | Embeds different types of prompts, returning both sparse and dense 145 | embeddings. 146 | 147 | Arguments: 148 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 149 | and labels to embed. 150 | boxes (torch.Tensor or none): boxes to embed 151 | masks (torch.Tensor or none): masks to embed 152 | 153 | Returns: 154 | torch.Tensor: sparse embeddings for the points and boxes, with shape 155 | BxNx(embed_dim), where N is determined by the number of input points 156 | and boxes. 
157 | torch.Tensor: dense embeddings for the masks, in the shape 158 | Bx(embed_dim)x(embed_H)x(embed_W) 159 | """ 160 | bs = self._get_batch_size(points, boxes, masks) 161 | sparse_embeddings = torch.empty( 162 | (bs, 0, self.embed_dim), device=self._get_device() 163 | ) 164 | if points is not None: 165 | coords, labels = points 166 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 167 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 168 | if boxes is not None: 169 | box_embeddings = self._embed_boxes(boxes) 170 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 171 | 172 | if masks is not None: 173 | dense_embeddings = self._embed_masks(masks) 174 | else: 175 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 176 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 177 | ) 178 | 179 | return sparse_embeddings, dense_embeddings 180 | 181 | 182 | class PositionEmbeddingRandom(nn.Module): 183 | """ 184 | Positional encoding using random spatial frequencies. 185 | """ 186 | 187 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 188 | super().__init__() 189 | if scale is None or scale <= 0.0: 190 | scale = 1.0 191 | self.register_buffer( 192 | "positional_encoding_gaussian_matrix", 193 | scale * torch.randn((2, num_pos_feats)), 194 | ) 195 | 196 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 197 | """Positionally encode points that are normalized to [0,1].""" 198 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 199 | coords = 2 * coords - 1 200 | coords = coords @ self.positional_encoding_gaussian_matrix 201 | coords = 2 * np.pi * coords 202 | # outputs d_1 x ... 
x d_n x C shape 203 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 204 | 205 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 206 | """Generate positional encoding for a grid of the specified size.""" 207 | h, w = size 208 | device: Any = self.positional_encoding_gaussian_matrix.device 209 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 210 | y_embed = grid.cumsum(dim=0) - 0.5 211 | x_embed = grid.cumsum(dim=1) - 0.5 212 | y_embed = y_embed / h 213 | x_embed = x_embed / w 214 | 215 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 216 | return pe.permute(2, 0, 1) # C x H x W 217 | 218 | def forward_with_coords( 219 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 220 | ) -> torch.Tensor: 221 | """Positionally encode points that are not normalized to [0,1].""" 222 | coords = coords_input.clone() 223 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 224 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 225 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 226 | -------------------------------------------------------------------------------- /metaseg/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Any, Dict, List, Tuple 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | from metaseg.modeling.image_encoder import ImageEncoderViT 14 | from metaseg.modeling.mask_decoder import MaskDecoder 15 | from metaseg.modeling.prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing 40 | pixels in the input image. 41 | pixel_std (list(float)): Std values for normalizing 42 | pixels in the input image. 43 | """ 44 | super().__init__() 45 | self.image_encoder = image_encoder 46 | self.prompt_encoder = prompt_encoder 47 | self.mask_decoder = mask_decoder 48 | self.register_buffer( 49 | "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False 50 | ) 51 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 52 | 53 | @property 54 | def device(self) -> Any: 55 | return self.pixel_mean.device 56 | 57 | @torch.no_grad() 58 | def forward( 59 | self, 60 | batched_input: List[Dict[str, Any]], 61 | multimask_output: bool, 62 | ) -> List[Dict[str, torch.Tensor]]: 63 | """ 64 | Predicts masks end-to-end from provided images and prompts. 
65 | If prompts are not known in advance, using SamPredictor is 66 | recommended over calling the model directly. 67 | 68 | Arguments: 69 | batched_input (list(dict)): A list over input images, each a 70 | dictionary with the following keys. A prompt key can be 71 | excluded if it is not present. 72 | 'image': The image as a torch tensor in 3xHxW format, 73 | already transformed for input to the model. 74 | 'original_size': (tuple(int, int)) The original size of 75 | the image before transformation, as (H, W). 76 | 'point_coords': (torch.Tensor) Batched point prompts for 77 | this image, with shape BxNx2. Already transformed to the 78 | input frame of the model. 79 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 80 | with shape BxN. 81 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 82 | Already transformed to the input frame of the model. 83 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 84 | in the form Bx1xHxW. 85 | multimask_output (bool): Whether the model should predict multiple 86 | disambiguating masks, or return a single mask. 87 | 88 | Returns: 89 | (list(dict)): A list over input images, where each element is 90 | a dictionary with the following keys. 91 | 'masks': (torch.Tensor) Batched binary mask predictions, 92 | with shape BxCxHxW, where B is the number of input prompts, 93 | C is determined by multimask_output, and (H, W) is the 94 | original size of the image. 95 | 'iou_predictions': (torch.Tensor) The model's predictions 96 | of mask quality, in shape BxC. 97 | 'low_res_logits': (torch.Tensor) Low resolution logits with 98 | shape BxCxHxW, where H=W=256. Can be passed as mask input 99 | to subsequent iterations of prediction.
100 | """ 101 | input_images = torch.stack( 102 | [self.preprocess(x["image"]) for x in batched_input], dim=0 103 | ) 104 | image_embeddings = self.image_encoder(input_images) 105 | 106 | outputs = [] 107 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 108 | if "point_coords" in image_record: 109 | points = (image_record["point_coords"], image_record["point_labels"]) 110 | else: 111 | points = None 112 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 113 | points=points, 114 | boxes=image_record.get("boxes", None), 115 | masks=image_record.get("mask_inputs", None), 116 | ) 117 | low_res_masks, iou_predictions = self.mask_decoder( 118 | image_embeddings=curr_embedding.unsqueeze(0), 119 | image_pe=self.prompt_encoder.get_dense_pe(), 120 | sparse_prompt_embeddings=sparse_embeddings, 121 | dense_prompt_embeddings=dense_embeddings, 122 | multimask_output=multimask_output, 123 | ) 124 | masks = self.postprocess_masks( 125 | low_res_masks, 126 | input_size=image_record["image"].shape[-2:], 127 | original_size=image_record["original_size"], 128 | ) 129 | masks = masks > self.mask_threshold 130 | outputs.append( 131 | { 132 | "masks": masks, 133 | "iou_predictions": iou_predictions, 134 | "low_res_logits": low_res_masks, 135 | } 136 | ) 137 | return outputs 138 | 139 | def postprocess_masks( 140 | self, 141 | masks: torch.Tensor, 142 | input_size: Tuple[int, ...], 143 | original_size: Tuple[int, ...], 144 | ) -> torch.Tensor: 145 | """ 146 | Remove padding and upscale masks to the original image size. 147 | 148 | Arguments: 149 | masks (torch.Tensor): Batched masks from the mask_decoder, 150 | in BxCxHxW format. 151 | input_size (tuple(int, int)): The size of the image input to the 152 | model, in (H, W) format. Used to remove padding. 153 | original_size (tuple(int, int)): The original size of the image 154 | before resizing for input to the model, in (H, W) format. 
155 | 156 | Returns: 157 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 158 | is given by original_size. 159 | """ 160 | masks = F.interpolate( 161 | masks, 162 | (self.image_encoder.img_size, self.image_encoder.img_size), 163 | mode="bilinear", 164 | align_corners=False, 165 | ) 166 | masks = masks[..., : input_size[0], : input_size[1]] 167 | masks = F.interpolate( 168 | masks, original_size, mode="bilinear", align_corners=False 169 | ) 170 | return masks 171 | 172 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 173 | """Normalize pixel values and pad to a square input.""" 174 | # Normalize colors 175 | x = (x - self.pixel_mean) / self.pixel_std 176 | 177 | # Pad 178 | h, w = x.shape[-2:] 179 | padh = self.image_encoder.img_size - h 180 | padw = self.image_encoder.img_size - w 181 | x = F.pad(x, (0, padw, 0, padh)) 182 | return x 183 | -------------------------------------------------------------------------------- /metaseg/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple, Type 9 | 10 | import torch 11 | from torch import Tensor, nn 12 | 13 | from metaseg.modeling.common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 
29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attenion layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 
124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = 
self.cross_attn_image_to_token(q=k, k=q, v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert ( 202 | self.internal_dim % num_heads == 0 203 | ), "num_heads must divide embedding_dim." 204 | 205 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 207 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 208 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 209 | 210 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 211 | b, n, c = x.shape 212 | x = x.reshape(b, n, num_heads, c // num_heads) 213 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 214 | 215 | def _recombine_heads(self, x: Tensor) -> Tensor: 216 | b, n_heads, n_tokens, c_per_head = x.shape 217 | x = x.transpose(1, 2) 218 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 219 | 220 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 221 | # Input projections 222 | q = self.q_proj(q) 223 | k = self.k_proj(k) 224 | v = self.v_proj(v) 225 | 226 | # Separate into heads 227 | q = self._separate_heads(q, self.num_heads) 228 | k = self._separate_heads(k, self.num_heads) 229 | v = self._separate_heads(v, self.num_heads) 230 | 231 | # Attention 232 | _, _, _, c_per_head = q.shape 233 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 234 | attn = attn / math.sqrt(c_per_head) 235 | attn = 
torch.softmax(attn, dim=-1) 236 | 237 | # Get output 238 | out = attn @ v 239 | out = self._recombine_heads(out) 240 | out = self.out_proj(out) 241 | 242 | return out 243 | -------------------------------------------------------------------------------- /metaseg/sahi_predictor.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import cv2 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | from cv2 import Mat 8 | from PIL import Image 9 | 10 | from metaseg.generator import SamPredictor, sam_model_registry 11 | from metaseg.utils import ( 12 | download_model, 13 | load_image, 14 | multi_boxes, 15 | plt_load_box, 16 | plt_load_mask, 17 | ) 18 | 19 | 20 | def sahi_sliced_predict( 21 | image_path, 22 | detection_model_type, 23 | detection_model_path, 24 | conf_th, 25 | image_size, 26 | slice_height, 27 | slice_width, 28 | overlap_height_ratio, 29 | overlap_width_ratio, 30 | ): 31 | try: 32 | from sahi import AutoDetectionModel 33 | from sahi.predict import get_prediction, get_sliced_prediction 34 | except ImportError: 35 | raise ImportError("Please install SAHI library using 'pip install sahi'.") 36 | 37 | device = "cuda" if torch.cuda.is_available() else "cpu" 38 | 39 | detection_model = AutoDetectionModel.from_pretrained( 40 | image_size=image_size, 41 | model_type=detection_model_type, 42 | model_path=detection_model_path, 43 | confidence_threshold=conf_th, 44 | device=device, 45 | ) 46 | result = get_sliced_prediction( 47 | image_path, 48 | detection_model, 49 | slice_height=slice_height, 50 | slice_width=slice_width, 51 | overlap_height_ratio=overlap_height_ratio, 52 | overlap_width_ratio=overlap_width_ratio, 53 | ) 54 | 55 | result = get_prediction(image_path, detection_model) 56 | output = result.object_prediction_list 57 | boxes = [] 58 | for i in output: 59 | boxes.append(i.bbox.to_xyxy()) 60 | 61 | return boxes 62 | 63 | 64 | class SahiAutoSegmentation: 65 | def 
__init__(self): 66 | self.model = None 67 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 68 | 69 | def load_model(self, model_type): 70 | if self.model is None: 71 | self.model_path = download_model(model_type) 72 | self.model = sam_model_registry[model_type](checkpoint=self.model_path) 73 | self.model.to(device=self.device) 74 | 75 | return self.model 76 | 77 | def image_predict( 78 | self, 79 | source: Union[str, Mat], 80 | model_type, 81 | input_box=None, 82 | input_point=None, 83 | input_label=None, 84 | multimask_output=False, 85 | random_color=False, 86 | show=False, 87 | save=False, 88 | ): 89 | read_image = load_image(source) 90 | model = self.load_model(model_type) 91 | predictor = SamPredictor(model) 92 | predictor.set_image(read_image) 93 | 94 | if type(input_box[0]) == list: 95 | input_boxes, new_boxes = multi_boxes(input_box, predictor, read_image) 96 | 97 | masks, iou_predictions, low_res_masks = predictor.predict_torch( 98 | point_coords=None, 99 | point_labels=None, 100 | boxes=new_boxes, 101 | multimask_output=False, 102 | ) 103 | 104 | elif type(input_box[0]) == int: 105 | input_boxes = np.array(input_box)[None, :] 106 | 107 | masks, iou_predictions, low_res_masks = predictor.predict( 108 | point_coords=input_point, 109 | point_labels=input_label, 110 | box=input_boxes, 111 | multimask_output=multimask_output, 112 | ) 113 | 114 | plt.figure(figsize=(10, 10)) 115 | plt.imshow(read_image) 116 | for mask in masks: 117 | plt_load_mask(mask.cpu().numpy(), plt.gca(), random_color=random_color) 118 | for box in input_boxes: 119 | plt_load_box(box.cpu().numpy(), plt.gca()) 120 | plt.axis("off") 121 | if save: 122 | plt.savefig("output.png", bbox_inches="tight") 123 | output_image = cv2.imread("output.png") 124 | output_image = Image.fromarray(output_image) 125 | return output_image 126 | if show: 127 | plt.show() 128 | 129 | return masks, iou_predictions, low_res_masks 130 | 
-------------------------------------------------------------------------------- /metaseg/sam_predictor.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from cv2 import Mat 7 | from tqdm import tqdm 8 | 9 | from metaseg.generator.automatic_mask_generator import SamAutomaticMaskGenerator 10 | from metaseg.generator.build_sam import sam_model_registry 11 | from metaseg.generator.predictor import SamPredictor 12 | from metaseg.utils import ( 13 | download_model, 14 | load_box, 15 | load_image, 16 | load_mask, 17 | load_video, 18 | multi_boxes, 19 | show_image, 20 | ) 21 | 22 | 23 | class SegAutoMaskPredictor: 24 | def __init__(self): 25 | self.model = None 26 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 27 | 28 | def load_model(self, model_type): 29 | if self.model is None: 30 | self.model_path = download_model(model_type) 31 | self.model = sam_model_registry[model_type](checkpoint=self.model_path) 32 | self.model.to(device=self.device) 33 | 34 | return self.model 35 | 36 | def image_predict( 37 | self, 38 | source: Union[str, Mat], 39 | model_type, 40 | points_per_side, 41 | points_per_batch, 42 | min_area, 43 | output_path="output.png", 44 | show=False, 45 | save=False, 46 | ): 47 | read_image = load_image(source) 48 | model = self.load_model(model_type) 49 | mask_generator = SamAutomaticMaskGenerator( 50 | model, 51 | points_per_side=points_per_side, 52 | points_per_batch=points_per_batch, 53 | min_mask_region_area=min_area, 54 | ) 55 | 56 | masks = mask_generator.generate(read_image) 57 | 58 | sorted_anns = sorted(masks, key=(lambda x: x["area"]), reverse=True) 59 | mask_image = np.zeros( 60 | (masks[0]["segmentation"].shape[0], masks[0]["segmentation"].shape[1], 3), 61 | dtype=np.uint8, 62 | ) 63 | colors = np.random.randint(0, 255, size=(256, 3), dtype=np.uint8) 64 | for i, ann in enumerate(sorted_anns): 65 | m = 
ann["segmentation"] 66 | img = np.ones((m.shape[0], m.shape[1], 3), dtype=np.uint8) 67 | color = colors[i % 256] 68 | for i in range(3): 69 | img[:, :, 0] = color[0] 70 | img[:, :, 1] = color[1] 71 | img[:, :, 2] = color[2] 72 | img = cv2.bitwise_and(img, img, mask=m.astype(np.uint8)) 73 | img = cv2.addWeighted(img, 0.35, np.zeros_like(img), 0.65, 0) 74 | mask_image = cv2.add(mask_image, img) 75 | 76 | combined_mask = cv2.add(read_image, mask_image) 77 | self.combined_mask = combined_mask 78 | if show: 79 | show_image(combined_mask) 80 | 81 | if save: 82 | cv2.imwrite(output_path, combined_mask) 83 | 84 | return masks 85 | 86 | def video_predict( 87 | self, 88 | source, 89 | model_type, 90 | points_per_side, 91 | points_per_batch, 92 | min_area, 93 | output_path="output.mp4", 94 | ): 95 | cap, out = load_video(source, output_path) 96 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 97 | colors = np.random.randint(0, 255, size=(256, 3), dtype=np.uint8) 98 | 99 | for _ in tqdm(range(length)): 100 | ret, frame = cap.read() 101 | if not ret: 102 | break 103 | 104 | model = self.load_model(model_type) 105 | mask_generator = SamAutomaticMaskGenerator( 106 | model, 107 | points_per_side=points_per_side, 108 | points_per_batch=points_per_batch, 109 | min_mask_region_area=min_area, 110 | ) 111 | masks = mask_generator.generate(frame) 112 | 113 | if len(masks) == 0: 114 | continue 115 | 116 | sorted_anns = sorted(masks, key=(lambda x: x["area"]), reverse=True) 117 | mask_image = np.zeros( 118 | ( 119 | masks[0]["segmentation"].shape[0], 120 | masks[0]["segmentation"].shape[1], 121 | 3, 122 | ), 123 | dtype=np.uint8, 124 | ) 125 | 126 | for i, ann in enumerate(sorted_anns): 127 | m = ann["segmentation"] 128 | color = colors[i % 256] 129 | img = np.zeros((m.shape[0], m.shape[1], 3), dtype=np.uint8) 130 | img[:, :, 0] = color[0] 131 | img[:, :, 1] = color[1] 132 | img[:, :, 2] = color[2] 133 | img = cv2.bitwise_and(img, img, mask=m.astype(np.uint8)) 134 | img = 
cv2.addWeighted(img, 0.35, np.zeros_like(img), 0.65, 0) 135 | mask_image = cv2.add(mask_image, img) 136 | 137 | combined_mask = cv2.add(frame, mask_image) 138 | out.write(combined_mask) 139 | 140 | out.release() 141 | cap.release() 142 | cv2.destroyAllWindows() 143 | 144 | return output_path 145 | 146 | 147 | class SegManualMaskPredictor: 148 | def __init__(self): 149 | self.model = None 150 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 151 | 152 | def load_model(self, model_type): 153 | if self.model is None: 154 | self.model_path = download_model(model_type) 155 | self.model = sam_model_registry[model_type](checkpoint=self.model_path) 156 | self.model.to(device=self.device) 157 | 158 | return self.model 159 | 160 | def image_predict( 161 | self, 162 | source: Union[str, Mat], 163 | model_type, 164 | input_box=None, 165 | input_point=None, 166 | input_label=None, 167 | multimask_output=False, 168 | output_path="output.png", 169 | random_color=False, 170 | show=False, 171 | save=False, 172 | ): 173 | image = load_image(source) 174 | model = self.load_model(model_type) 175 | predictor = SamPredictor(model) 176 | predictor.set_image(image) 177 | 178 | if type(input_box[0]) == list: 179 | input_boxes, new_boxes = multi_boxes(input_box, predictor, image) 180 | 181 | masks, _, _ = predictor.predict_torch( 182 | point_coords=None, 183 | point_labels=None, 184 | boxes=new_boxes, 185 | multimask_output=False, 186 | ) 187 | for mask in masks: 188 | mask_image = load_mask(mask.cpu().numpy(), random_color) 189 | 190 | for box in input_boxes: 191 | image = load_box(box.cpu().numpy(), image) 192 | 193 | elif type(input_box[0]) == int: 194 | input_boxes = np.array(input_box)[None, :] 195 | 196 | masks, _, _ = predictor.predict( 197 | point_coords=input_point, 198 | point_labels=input_label, 199 | box=input_boxes, 200 | multimask_output=multimask_output, 201 | ) 202 | mask_image = load_mask(masks, random_color) 203 | image = load_box(input_box, image) 204 | 205 | 
combined_mask = cv2.add(image, mask_image) 206 | if save: 207 | cv2.imwrite(output_path, combined_mask) 208 | 209 | if show: 210 | show_image(combined_mask) 211 | 212 | return masks 213 | 214 | def video_predict( 215 | self, 216 | source, 217 | model_type, 218 | input_box=None, 219 | input_point=None, 220 | input_label=None, 221 | multimask_output=False, 222 | output_path="output.mp4", 223 | random_color=False, 224 | ): 225 | cap, out = load_video(source, output_path) 226 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 227 | 228 | for _ in tqdm(range(length)): 229 | ret, frame = cap.read() 230 | if not ret: 231 | break 232 | 233 | model = self.load_model(model_type) 234 | predictor = SamPredictor(model) 235 | predictor.set_image(frame) 236 | 237 | if type(input_box[0]) == list: 238 | input_boxes, new_boxes = multi_boxes(input_box, predictor, frame) 239 | 240 | masks, _, _ = predictor.predict_torch( 241 | point_coords=None, 242 | point_labels=None, 243 | boxes=new_boxes, 244 | multimask_output=False, 245 | ) 246 | for mask in masks: 247 | mask_image = load_mask(mask.cpu().numpy(), random_color) 248 | 249 | for box in input_boxes: 250 | frame = load_box(box.cpu().numpy(), frame) 251 | 252 | elif type(input_box[0]) == int: 253 | input_boxes = np.array(input_box)[None, :] 254 | 255 | masks, _, _ = predictor.predict( 256 | point_coords=input_point, 257 | point_labels=input_label, 258 | box=input_boxes, 259 | multimask_output=multimask_output, 260 | ) 261 | mask_image = load_mask(masks, random_color) 262 | frame = load_box(input_box, frame) 263 | 264 | combined_mask = cv2.add(frame, mask_image) 265 | out.write(combined_mask) 266 | 267 | out.release() 268 | cap.release() 269 | cv2.destroyAllWindows() 270 | return output_path 271 | -------------------------------------------------------------------------------- /metaseg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
class MaskData:
    """
    Batched storage for masks and their per-mask metadata.

    Every field is a list, numpy array or torch tensor whose first axis
    indexes the masks; fields can be filtered together and concatenated.
    """

    def __init__(self, **kwargs) -> None:
        for value in kwargs.values():
            assert isinstance(
                value, (list, np.ndarray, torch.Tensor)
            ), "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats = dict(**kwargs)

    def __setitem__(self, key: str, item: Any) -> None:
        supported = isinstance(item, (list, np.ndarray, torch.Tensor))
        assert (
            supported
        ), "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats[key] = item

    def __delitem__(self, key: str) -> None:
        del self._stats[key]

    def __getitem__(self, key: str) -> Any:
        return self._stats[key]

    def items(self) -> ItemsView[str, Any]:
        """Return a view over the stored (name, value) pairs."""
        return self._stats.items()

    @staticmethod
    def _select(key: str, value: Any, keep: torch.Tensor) -> Any:
        """Return the entries of ``value`` picked by ``keep`` (mask or indices)."""
        if value is None:
            return None
        if isinstance(value, torch.Tensor):
            return value[torch.as_tensor(keep, device=value.device)]
        if isinstance(value, np.ndarray):
            return value[keep.detach().cpu().numpy()]
        if isinstance(value, list):
            if keep.dtype == torch.bool:
                return [value[pos] for pos in range(len(value)) if keep[pos]]
            return [value[pos] for pos in keep]
        raise TypeError(f"MaskData key {key} has an unsupported type {type(value)}.")

    def filter(self, keep: torch.Tensor) -> None:
        """Keep, in every field, only the entries selected by ``keep``."""
        for key in list(self._stats):
            self._stats[key] = self._select(key, self._stats[key], keep)

    def cat(self, new_stats: "MaskData") -> None:
        """Append the fields of ``new_stats`` to the fields stored here."""
        for key, incoming in new_stats.items():
            current = self._stats.get(key)
            if current is None:
                # Missing (or None) field: take an independent copy.
                self._stats[key] = deepcopy(incoming)
                continue
            if isinstance(incoming, torch.Tensor):
                merged = torch.cat([current, incoming], dim=0)
            elif isinstance(incoming, np.ndarray):
                merged = np.concatenate([current, incoming], axis=0)
            elif isinstance(incoming, list):
                merged = current + deepcopy(incoming)
            else:
                raise TypeError(
                    f"MaskData key {key} has an unsupported type {type(incoming)}."
                )
            self._stats[key] = merged

    def to_numpy(self) -> None:
        """Convert every torch tensor field to a numpy array, in place."""
        for key in list(self._stats):
            value = self._stats[key]
            if isinstance(value, torch.Tensor):
                self._stats[key] = value.detach().cpu().numpy()
def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
    """Return a copy of an XYXY box converted to XYWH; the input is untouched."""
    converted = deepcopy(box_xyxy)
    left, top = converted[0], converted[1]
    converted[2] = converted[2] - left
    converted[3] = converted[3] - top
    return converted


def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    """
    Yield aligned slices of at most ``batch_size`` items from every input.

    All inputs must have the same length; each yielded list holds one slice
    per input, in argument order.
    """
    same_length = len(args) > 0 and all(len(a) == len(args[0]) for a in args)
    assert same_length, "Batched iteration must have inputs of all the same size."
    full_batches, remainder = divmod(len(args[0]), batch_size)
    n_batches = full_batches + (1 if remainder != 0 else 0)
    for batch_no in range(n_batches):
        start = batch_no * batch_size
        stop = start + batch_size
        yield [arg[start:stop] for arg in args]


def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
    """
    Encode a batch of masks (BxHxW) as uncompressed RLE dictionaries in the
    format expected by pycocotools.
    """
    n_masks, height, width = tensor.shape
    # Column-major (Fortran) flattening of each mask.
    flat = tensor.permute(0, 2, 1).flatten(1)

    # (mask index, position) of every place where the pixel value flips.
    transitions = (flat[:, 1:] ^ flat[:, :-1]).nonzero()

    encoded = []
    for mask_no in range(n_masks):
        flips = transitions[transitions[:, 0] == mask_no, 1]
        first = torch.tensor([0], dtype=flips.dtype, device=flips.device)
        last = torch.tensor([height * width], dtype=flips.dtype, device=flips.device)
        boundaries = torch.cat([first, flips + 1, last])
        run_lengths = boundaries[1:] - boundaries[:-1]
        # Runs alternate starting with background, so a mask that starts with
        # foreground gets an empty leading background run.
        counts = [] if flat[mask_no, 0] == 0 else [0]
        counts.extend(run_lengths.detach().cpu().tolist())
        encoded.append({"size": [height, width], "counts": counts})
    return encoded


def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
    """Decode an uncompressed RLE dictionary into a boolean HxW mask."""
    height, width = rle["size"]
    flat = np.empty(height * width, dtype=bool)
    cursor = 0
    for run_no, run_length in enumerate(rle["counts"]):
        # Even runs are background, odd runs are foreground.
        flat[cursor : cursor + run_length] = run_no % 2 == 1
        cursor += run_length
    # The runs are column-major; transpose back to row-major HxW.
    return flat.reshape(width, height).transpose()


def area_from_rle(rle: Dict[str, Any]) -> int:
    """Return the number of foreground pixels described by an uncompressed RLE."""
    foreground_runs = rle["counts"][1::2]
    return sum(foreground_runs)
def build_point_grid(n_per_side: int) -> np.ndarray:
    """Return (n_per_side**2)x2 (x, y) points evenly spaced in [0,1]x[0,1]."""
    half_step = 1 / (2 * n_per_side)
    ticks = np.linspace(half_step, 1 - half_step, n_per_side)
    # grid_x varies along columns, grid_y along rows.
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    return np.stack([grid_x, grid_y], axis=-1).reshape(-1, 2)


def build_all_layer_point_grids(
    n_per_side: int, n_layers: int, scale_per_layer: int
) -> List[np.ndarray]:
    """Return one point grid per crop layer (layer 0 plus ``n_layers`` more)."""
    return [
        build_point_grid(int(n_per_side / (scale_per_layer**layer)))
        for layer in range(n_layers + 1)
    ]


def generate_crop_boxes(
    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
    """
    Build the crop boxes for every layer.

    Layer 0 is the whole image; layer i adds (2**i)**2 overlapping crops.
    Returns the boxes as [x0, y0, x1, y1] lists and, in parallel, the layer
    index each box belongs to.
    """
    im_h, im_w = im_size
    short_side = min(im_h, im_w)

    def crop_len(orig_len, n_crops, overlap):
        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))

    # The original image is always the first box.
    crop_boxes = [[0, 0, im_w, im_h]]
    layer_idxs = [0]

    for layer in range(1, n_layers + 1):
        per_side = 2**layer
        overlap = int(overlap_ratio * short_side * (2 / per_side))

        crop_w = crop_len(im_w, per_side, overlap)
        crop_h = crop_len(im_h, per_side, overlap)

        x_starts = [int((crop_w - overlap) * k) for k in range(per_side)]
        y_starts = [int((crop_h - overlap) * k) for k in range(per_side)]

        # Boxes are stored as corner coordinates, clipped to the image.
        for x0, y0 in product(x_starts, y_starts):
            crop_boxes.append([x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)])
            layer_idxs.append(layer)

    return crop_boxes, layer_idxs


def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    """Shift XYXY boxes from crop coordinates by the crop's top-left corner."""
    left, top, _, _ = crop_box
    shift = torch.tensor([[left, top, left, top]], device=boxes.device)
    # A 3-D input carries an extra channel dimension to broadcast over.
    if boxes.dim() == 3:
        shift = shift.unsqueeze(1)
    return boxes + shift


def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    """Shift (x, y) points from crop coordinates by the crop's top-left corner."""
    left, top, _, _ = crop_box
    shift = torch.tensor([[left, top]], device=points.device)
    # A 3-D input carries an extra channel dimension to broadcast over.
    if points.dim() == 3:
        shift = shift.unsqueeze(1)
    return points + shift
def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of if the mask has been modified.

    ``mode`` must be "holes" or "islands". When no region is smaller than
    ``area_thresh`` the input mask is returned unchanged together with False.
    """

    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    # In "holes" mode the XOR with True inverts the mask, so that holes
    # become the labelled components; in "islands" mode it is a no-op.
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    # 8 is the connectivity passed to OpenCV.
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    # Last stats column is used as the component size (presumably the area
    # column of OpenCV's stats output - confirm against the OpenCV docs).
    sizes = stats[:, -1][1:]  # Row 0 is background label
    # +1 converts positions in `sizes` back to label ids (label 0 was dropped).
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        # "islands" mode: keep every label that is neither background nor small.
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True


def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert an uncompressed RLE dictionary (keys "size" and "counts") with
    pycocotools, returning the result with "counts" decoded to a str.
    """
    h, w = uncompressed_rle["size"]
    rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
    rle["counts"] = rle["counts"].decode("utf-8")  # Necessary to serialize with json
    return rle
306 | """ 307 | # torch.max below raises an error on empty inputs, just skip in this case 308 | if torch.numel(masks) == 0: 309 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 310 | 311 | # Normalize shape to CxHxW 312 | shape = masks.shape 313 | h, w = shape[-2:] 314 | if len(shape) > 2: 315 | masks = masks.flatten(0, -3) 316 | else: 317 | masks = masks.unsqueeze(0) 318 | 319 | # Get top and bottom edges 320 | in_height, _ = torch.max(masks, dim=-1) 321 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 322 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 323 | in_height_coords = in_height_coords + h * (~in_height) 324 | top_edges, _ = torch.min(in_height_coords, dim=-1) 325 | 326 | # Get left and right edges 327 | in_width, _ = torch.max(masks, dim=-2) 328 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 329 | right_edges, _ = torch.max(in_width_coords, dim=-1) 330 | in_width_coords = in_width_coords + w * (~in_width) 331 | left_edges, _ = torch.min(in_width_coords, dim=-1) 332 | 333 | # If the mask is empty the right edge will be to the left of the left edge. 
def load_image(image: Union[str, Mat]) -> Mat:
    """
    Load an image from disk, or pass an already-loaded image through.

    :param image: path to an image file, or an image as Mat / np.ndarray
    :return: the image; files are decoded and converted from BGR to RGB,
        arrays are returned unchanged
    :raises ValueError: if ``image`` is neither an array nor an existing
        file, or if the file cannot be decoded
    """
    # Arrays are recognised first so that pixel data is never handed to
    # os.path.isfile, which expects a path.
    if isinstance(image, np.ndarray) or isinstance(image, Mat):
        return image

    try:
        is_path = isfile(image)
    except TypeError:
        is_path = False
    if not is_path:
        raise ValueError("image must be a path or cv2.Mat")

    decoded = cv2.imread(image)
    if decoded is None:
        # cv2.imread reports unreadable files by returning None.
        raise ValueError(f"could not decode image file: {image}")
    return cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)


def load_server_image(image_path):
    """
    Store an uploaded image as PNG in a fresh, uniquely named directory.

    :param image_path: the encoded image *bytes* (despite the name)
    :return: ``(image_path, output_path)``: where the image was saved and
        where the result for it is expected to be written
    """
    # Local import: only this helper needs it. os.makedirs replaces the
    # former shell call to `mkdir -p`, which is not portable.
    from os import makedirs

    imagedir = str(uuid4())
    makedirs(imagedir, exist_ok=True)
    image = Image.open(BytesIO(image_path))
    if image.mode != "RGB":
        image = image.convert("RGB")

    image_path = f"{imagedir}/base_image_v0.png"
    output_path = f"{imagedir}/output_v0.png"
    image.save(image_path, format="PNG")
    return image_path, output_path


def load_video(video_path, output_path="output.mp4"):
    """
    Open a video for reading and a writer with the same frame size and fps.

    :return: ``(cap, out)``: the cv2.VideoCapture and cv2.VideoWriter; the
        caller is responsible for releasing both
    """
    cap = cv2.VideoCapture(video_path)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"XVID")
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
    return cap, out


def load_mask(mask, random_color):
    """
    Turn a single mask into an HxWx3 uint8 colour image.

    :param mask: array whose last two dimensions are H and W
    :param random_color: pick a random colour instead of the fixed one
    """
    if random_color:
        color = np.random.rand(3) * 255
    else:
        color = np.array([100, 50, 0])

    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    return mask_image.astype(np.uint8)


def load_box(box, image):
    """
    Draw a green rectangle on ``image`` (in place) and return it.

    :param box: corner coordinates ``[x0, y0, x1, y1]``
    """
    x0, y0, x1, y1 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
    cv2.rectangle(image, (x0, y0), (x1, y1), (0, 255, 0), 2)
    return image


def plt_load_mask(mask, ax, random_color=False):
    """Show a single mask on a matplotlib axis as a translucent RGBA overlay."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def plt_load_box(box, ax):
    """Draw the outline of an ``[x0, y0, x1, y1]`` box on a matplotlib axis."""
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
    )


def multi_boxes(boxes, predictor, image):
    """
    Prepare several boxes for batched prediction.

    :return: ``(input_boxes, transformed_boxes)``: the boxes as a tensor on
        the predictor's device, and the same boxes after the predictor's
        transform has been applied for ``image``'s height and width
    """
    input_boxes = tensor(boxes, device=predictor.device)
    transformed_boxes = predictor.transform.apply_boxes_torch(
        input_boxes, image.shape[:2]
    )
    return input_boxes, transformed_boxes


def show_image(output_image):
    """Display an image in a window and block until a key is pressed."""
    cv2.imshow("output", output_image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
# Model type -> (checkpoint URL, progress-bar colour, MD5 of the checkpoint).
# Each checkpoint file name ends with the first six hex digits of its MD5.
# The annotation is quoted so it is never evaluated at import time, which
# keeps the module importable on interpreters without subscriptable builtins.
MODEL_URLS: "dict[str, tuple[str, str, str]]" = {
    "vit_h": (
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "green",
        "4b8939a88964f0f4ff5f5b2642c598a6",
    ),
    "vit_l": (
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "red",
        "0b3195507c641ddb6910d2bb5adee89c",
    ),
    "vit_b": (
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        "cyan",
        "01ec64d29a2fca3f0661936605ae66f8",
    ),
}


# md5 check function
def _check_md5(filename: str, orig_md5: str) -> bool:
    """
    Return True when the file exists and its MD5 hex digest equals orig_md5.

    filename: str, A string representing the path to the file.
    orig_md5: str, A string representing the original md5 hash.
    """
    if not os.path.exists(filename):
        return False
    digest = md5()
    with open(filename, "rb") as file_to_check:
        # Hash in 1 MiB chunks: checkpoints are far too large to read at once.
        for chunk in iter(partial(file_to_check.read, 1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest() == orig_md5


def download_model(model_type):
    """
    Make sure the checkpoint for ``model_type`` exists locally.

    model_type: str, A string representing the model type.
    It can be 'vit_h', 'vit_l', or 'vit_b'.

    Returns the checkpoint file name (``"<model_type>.pth"`` in the current
    directory). An existing file is verified against its MD5 and downloaded
    again when the digest does not match.

    Raises ValueError for an unknown model type, and an HTTP error or
    RuntimeError when the server does not answer with status 200.
    """
    if model_type not in MODEL_URLS:
        raise ValueError(
            "Invalid model type. It should be 'vit_h', 'vit_l', or 'vit_b'."
        )
    url, colour, expected_md5 = MODEL_URLS[model_type]
    filename = f"{model_type}.pth"

    if os.path.exists(filename):
        if _check_md5(filename, expected_md5):
            print(f"{filename} model download complete. \n")
            return filename
        print("File corrupted. Re-downloading... \n")
        os.remove(filename)

    print(f"Downloading {filename} model \n")
    res: Response = get(url, stream=True, allow_redirects=True)
    if res.status_code != 200:
        res.raise_for_status()
        raise RuntimeError(
            f"Request to {url} " f"returned status code {res.status_code}"
        )

    file_size: int = int(res.headers.get("Content-Length", 0))
    folder_path: Path = Path(filename).expanduser().resolve()
    folder_path.parent.mkdir(parents=True, exist_ok=True)
    # Download next to the target and rename at the end, so an interrupted
    # transfer never leaves a truncated file under the final name.
    tmp_path: Path = folder_path.with_name(folder_path.name + ".part")

    desc = "(Unknown total file size)" if file_size == 0 else ""
    res.raw.read = partial(res.raw.read, decode_content=True)  # Decompress if needed
    try:
        with tqdm.wrapattr(
            res.raw,
            "read",
            total=file_size,
            desc=desc,
            colour=colour,
        ) as r_raw, tmp_path.open("wb") as f:
            copyfileobj(r_raw, f)
    except BaseException:
        # Drop the partial file, then let the original error propagate.
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    tmp_path.replace(folder_path)

    print(f"{filename} model download complete. \n")
    return filename
class SamOnnxModel(nn.Module):
    """
    This model should not be called directly, but is used in ONNX export.
    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
    with some functions modified to enable model tracing. Also supports extra
    options controlling what information is returned. See the ONNX export
    script for details.
    """

    def __init__(
        self,
        model: Sam,
        return_single_mask: bool,
        use_stability_score: bool = False,
        return_extra_metrics: bool = False,
    ) -> None:
        """
        :param model: the Sam model whose prompt encoder and mask decoder are used.
        :param return_single_mask: select one mask per prompt in ``forward``.
        :param use_stability_score: replace the decoder's scores with the
            stability score computed from the low-resolution masks.
        :param return_extra_metrics: additionally return stability scores and
            areas of the upscaled masks from ``forward``.
        """
        super().__init__()
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.return_single_mask = return_single_mask
        self.use_stability_score = use_stability_score
        # Offset passed to calculate_stability_score.
        self.stability_score_offset = 1.0
        self.return_extra_metrics = return_extra_metrics

    @staticmethod
    def resize_longest_image_size(
        input_image_size: torch.Tensor, longest_side: int
    ) -> torch.Tensor:
        """
        Scale a size tensor so that its largest entry equals ``longest_side``,
        rounding half up, and return it as int64.
        """
        input_image_size = input_image_size.to(torch.float32)
        scale = longest_side / torch.max(input_image_size)
        transformed_size = scale * input_image_size
        # floor(x + 0.5) rounds half up using tensor ops only.
        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
        return transformed_size

    def _embed_points(
        self, point_coords: torch.Tensor, point_labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Embed point prompts: positional encoding of the normalised coordinates
        plus the learned embedding selected by each point's label.
        """
        point_coords = point_coords + 0.5
        point_coords = point_coords / self.img_size
        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        # Label -1: zero the positional encoding and use the "not a point"
        # embedding instead. Masks are multiplied in rather than branched on.
        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = (
            point_embedding
            + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
        )

        # Add the learned embedding of label i to the points carrying label i.
        for i in range(self.model.prompt_encoder.num_point_embeddings):
            point_embedding = (
                point_embedding
                + self.model.prompt_encoder.point_embeddings[i].weight
                * (point_labels == i)
            )

        return point_embedding

    def _embed_masks(
        self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Blend the downscaled input mask and the "no mask" embedding, weighted
        by ``has_mask_input`` and ``1 - has_mask_input`` respectively.
        """
        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(
            input_mask
        )
        mask_embedding = mask_embedding + (
            1 - has_mask_input
        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

    def mask_postprocessing(
        self, masks: torch.Tensor, orig_im_size: torch.Tensor
    ) -> torch.Tensor:
        """
        Upscale masks to the model input size, crop to the pre-padding size,
        then resize to ``orig_im_size`` (its first two entries are used as
        height and width).
        """
        masks = F.interpolate(
            masks,
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        )

        # Crop to the resized image extent computed from the original size.
        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size)
        masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]

        orig_im_size = orig_im_size.to(torch.int64)
        h, w = orig_im_size[0], orig_im_size[1]
        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
        return masks

    def select_masks(
        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Keep, per batch entry, the mask with the best (reweighted) score.
        Returns the selected masks and scores with a singleton mask dimension.
        """
        # Determine if we should return the multi click
        # mask or not from the number of points.
        # The reweighting is used to avoid control flow.
        score_reweight = torch.tensor(
            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
        ).to(iou_preds.device)
        # (num_points - 2.5) is negative for up to two points and positive
        # otherwise, so the first mask is penalised or favoured by 1000x that.
        score = iou_preds + (num_points - 2.5) * score_reweight
        best_idx = torch.argmax(score, dim=1)
        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)

        return masks, iou_preds

    @torch.no_grad()
    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
        orig_im_size: torch.Tensor,
    ):
        """
        Predict masks from precomputed image embeddings and prompts.

        Returns ``(upscaled_masks, scores, masks)``, or
        ``(upscaled_masks, scores, stability_scores, areas, masks)`` when
        ``return_extra_metrics`` is set; ``masks`` are the low-resolution
        decoder outputs.
        """
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        if self.return_single_mask:
            # point_coords.shape[1] is passed as the number of points.
            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)

        if self.return_extra_metrics:
            stability_scores = calculate_stability_score(
                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
            )
            # Count of above-threshold pixels per mask.
            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
            return upscaled_masks, scores, stability_scores, areas, masks

        return upscaled_masks, scores, masks
class ResizeLongestSide:
    """
    Resizes images to longest side 'target_length', as well as provides
    methods for resizing coordinates and boxes. Provides methods for
    transforming both numpy array and batched torch tensors.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        Returns the resized image as a numpy array.
        """
        target_size = self.get_preprocess_shape(
            image.shape[0], image.shape[1], self.target_length
        )
        return np.array(resize(to_pil_image(image), target_size))

    def apply_coords(
        self, coords: np.ndarray, original_size: Tuple[int, ...]
    ) -> np.ndarray:
        """
        Expects a numpy array of length 2 in the final dimension. Requires the
        original image size in (H, W) format. Returns a rescaled float copy;
        the input is not modified.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).astype(float)
        # Last axis is (x, y): x scales with the width, y with the height.
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes(
        self, boxes: np.ndarray, original_size: Tuple[int, ...]
    ) -> np.ndarray:
        """
        Expects a numpy array shape Bx4. Requires the original image size
        in (H, W) format.
        """
        # Each box is treated as two (x, y) corner points.
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
        """
        Expects batched images with shape BxCxHxW and float format. This
        transformation may not exactly match apply_image. apply_image is
        the transformation expected by the model.
        """
        # The spatial size of a BxCxHxW batch lives in the last two
        # dimensions, not in the batch and channel dimensions.
        target_size = self.get_preprocess_shape(
            image.shape[2], image.shape[3], self.target_length
        )
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def apply_coords_torch(
        self, coords: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with length 2 in the last dimension. Requires the
        original image size in (H, W) format. Returns a rescaled float copy;
        the input is not modified.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).to(torch.float)
        # Last axis is (x, y): x scales with the width, y with the height.
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes_torch(
        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with shape Bx4. Requires the original image
        size in (H, W) format.
        """
        # Each box is treated as two (x, y) corner points.
        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    @staticmethod
    def get_preprocess_shape(
        oldh: int, oldw: int, long_side_length: int
    ) -> Tuple[int, int]:
        """
        Compute the output size given input size and target long side length.
        Returns (new_h, new_w), each rounded half up.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)
61 | 62 | def video_app(): 63 | with gr.Blocks(): 64 | with gr.Row(): 65 | with gr.Column(): 66 | seg_automask_video_file = gr.Video().style(height=260) 67 | with gr.Row(): 68 | with gr.Column(): 69 | seg_automask_video_model_type = gr.Dropdown( 70 | choices=[ 71 | "vit_h", 72 | "vit_l", 73 | "vit_b", 74 | ], 75 | value="vit_l", 76 | label="Model Type", 77 | ) 78 | seg_automask_video_min_area = gr.Number( 79 | value=1000, 80 | label="Min Area", 81 | ) 82 | 83 | with gr.Row(): 84 | with gr.Column(): 85 | seg_automask_video_points_per_side = gr.Slider( 86 | minimum=0, 87 | maximum=32, 88 | step=2, 89 | value=16, 90 | label="Points per Side", 91 | ) 92 | 93 | seg_automask_video_points_per_batch = gr.Slider( 94 | minimum=0, 95 | maximum=64, 96 | step=2, 97 | value=64, 98 | label="Points per Batch", 99 | ) 100 | 101 | seg_automask_video_predict = gr.Button(value="Generator") 102 | with gr.Column(): 103 | output_video = gr.Video() 104 | 105 | seg_automask_video_predict.click( 106 | fn=automask_video_app, 107 | inputs=[ 108 | seg_automask_video_file, 109 | seg_automask_video_model_type, 110 | seg_automask_video_points_per_side, 111 | seg_automask_video_points_per_batch, 112 | seg_automask_video_min_area, 113 | ], 114 | outputs=[output_video], 115 | ) 116 | 117 | 118 | def sahi_app(): 119 | with gr.Blocks(): 120 | with gr.Row(): 121 | with gr.Column(): 122 | sahi_image_file = gr.Image(type="filepath").style(height=260) 123 | sahi_autoseg_model_type = gr.Dropdown( 124 | choices=[ 125 | "vit_h", 126 | "vit_l", 127 | "vit_b", 128 | ], 129 | value="vit_l", 130 | label="Sam Model Type", 131 | ) 132 | 133 | with gr.Row(): 134 | with gr.Column(): 135 | sahi_model_type = gr.Dropdown( 136 | choices=[ 137 | "yolov5", 138 | "yolov8", 139 | ], 140 | value="yolov5", 141 | label="Detector Model Type", 142 | ) 143 | sahi_image_size = gr.Slider( 144 | minimum=0, 145 | maximum=1600, 146 | step=32, 147 | value=640, 148 | label="Image Size", 149 | ) 150 | 151 | sahi_overlap_width = 
gr.Slider( 152 | minimum=0, 153 | maximum=1, 154 | step=0.1, 155 | value=0.2, 156 | label="Overlap Width", 157 | ) 158 | 159 | sahi_slice_width = gr.Slider( 160 | minimum=0, 161 | maximum=640, 162 | step=32, 163 | value=256, 164 | label="Slice Width", 165 | ) 166 | 167 | with gr.Row(): 168 | with gr.Column(): 169 | sahi_model_path = gr.Dropdown( 170 | choices=[ 171 | "yolov5l.pt", 172 | "yolov5l6.pt", 173 | "yolov8l.pt", 174 | "yolov8x.pt", 175 | ], 176 | value="yolov5l6.pt", 177 | label="Detector Model Path", 178 | ) 179 | 180 | sahi_conf_th = gr.Slider( 181 | minimum=0, 182 | maximum=1, 183 | step=0.1, 184 | value=0.2, 185 | label="Confidence Threshold", 186 | ) 187 | sahi_overlap_height = gr.Slider( 188 | minimum=0, 189 | maximum=1, 190 | step=0.1, 191 | value=0.2, 192 | label="Overlap Height", 193 | ) 194 | sahi_slice_height = gr.Slider( 195 | minimum=0, 196 | maximum=640, 197 | step=32, 198 | value=256, 199 | label="Slice Height", 200 | ) 201 | sahi_image_predict = gr.Button(value="Generator") 202 | 203 | with gr.Column(): 204 | output_image = gr.Image() 205 | 206 | sahi_image_predict.click( 207 | fn=sahi_autoseg_app, 208 | inputs=[ 209 | sahi_image_file, 210 | sahi_autoseg_model_type, 211 | sahi_model_type, 212 | sahi_model_path, 213 | sahi_conf_th, 214 | sahi_image_size, 215 | sahi_slice_height, 216 | sahi_slice_width, 217 | sahi_overlap_height, 218 | sahi_overlap_width, 219 | ], 220 | outputs=[output_image], 221 | ) 222 | 223 | 224 | def metaseg_app(): 225 | app = gr.Blocks() 226 | with app: 227 | with gr.Row(): 228 | with gr.Column(): 229 | with gr.Tab("Image"): 230 | image_app() 231 | with gr.Tab("Video"): 232 | video_app() 233 | with gr.Tab("SAHI"): 234 | sahi_app() 235 | 236 | app.queue(concurrency_count=1) 237 | app.launch(debug=True, enable_queue=True) 238 | 239 | 240 | if __name__ == "__main__": 241 | metaseg_app() 242 | -------------------------------------------------------------------------------- /metaseg/webapp/demo.py: 
-------------------------------------------------------------------------------- 1 | from metaseg import ( 2 | SahiAutoSegmentation, 3 | SegAutoMaskPredictor, 4 | SegManualMaskPredictor, 5 | sahi_sliced_predict, 6 | ) 7 | 8 | # For image 9 | 10 | 11 | def automask_image_app( 12 | image_path, model_type, points_per_side, points_per_batch, min_area 13 | ): 14 | SegAutoMaskPredictor().image_predict( 15 | source=image_path, 16 | model_type=model_type, # vit_l, vit_h, vit_b 17 | points_per_side=points_per_side, 18 | points_per_batch=points_per_batch, 19 | min_area=min_area, 20 | output_path="output.png", 21 | show=False, 22 | save=True, 23 | ) 24 | return "output.png" 25 | 26 | 27 | # For video 28 | 29 | 30 | def automask_video_app( 31 | video_path, model_type, points_per_side, points_per_batch, min_area 32 | ): 33 | SegAutoMaskPredictor().video_predict( 34 | source=video_path, 35 | model_type=model_type, # vit_l, vit_h, vit_b 36 | points_per_side=points_per_side, 37 | points_per_batch=points_per_batch, 38 | min_area=min_area, 39 | output_path="output.mp4", 40 | ) 41 | return "output.mp4" 42 | 43 | 44 | # For manuel box and point selection 45 | 46 | 47 | def manual_app( 48 | image_path, 49 | model_type, 50 | input_point, 51 | input_label, 52 | input_box, 53 | multimask_output, 54 | random_color, 55 | ): 56 | SegManualMaskPredictor().image_predict( 57 | source=image_path, 58 | model_type=model_type, # vit_l, vit_h, vit_b 59 | input_point=input_point, 60 | input_label=input_label, 61 | input_box=input_box, 62 | multimask_output=multimask_output, 63 | random_color=random_color, 64 | output_path="output.png", 65 | show=False, 66 | save=True, 67 | ) 68 | return "output.png" 69 | 70 | 71 | # For sahi sliced prediction 72 | 73 | 74 | def sahi_autoseg_app( 75 | image_path, 76 | sam_model_type, 77 | detection_model_type, 78 | detection_model_path, 79 | conf_th, 80 | image_size, 81 | slice_height, 82 | slice_width, 83 | overlap_height_ratio, 84 | overlap_width_ratio, 85 | ): 86 | 
boxes = sahi_sliced_predict( 87 | image_path=image_path, 88 | # yolov8, detectron2, mmdetection, torchvision 89 | detection_model_type=detection_model_type, 90 | detection_model_path=detection_model_path, 91 | conf_th=conf_th, 92 | image_size=image_size, 93 | slice_height=slice_height, 94 | slice_width=slice_width, 95 | overlap_height_ratio=overlap_height_ratio, 96 | overlap_width_ratio=overlap_width_ratio, 97 | ) 98 | 99 | SahiAutoSegmentation().image_predict( 100 | source=image_path, 101 | model_type=sam_model_type, 102 | input_box=boxes, 103 | multimask_output=False, 104 | random_color=False, 105 | show=False, 106 | save=True, 107 | ) 108 | 109 | return "output.png" 110 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "metaseg" 3 | version = "0.7.8" 4 | description = "MetaSeg: Packaged version of the Segment Anything repository" 5 | authors = ["Kadir Nar "] 6 | maintainers = ["Kadir Nar "] 7 | readme = "README.md" 8 | packages = [{include = "metaseg"}] 9 | homepage = "https://github.com/kadirnar/segment-anything-video" 10 | repository = "https://github.com/kadirnar/segment-anything-video" 11 | documentation = "https://github.com/kadirnar/segment-anything-video/blob/main/README.md" 12 | keywords = ["pytorch","segment-anything-video","metaseg"] 13 | license = "Apache-2.0" 14 | classifiers = [ 15 | "Development Status :: 5 - Production/Stable", 16 | "License :: OSI Approved :: Apache Software License", 17 | "Natural Language :: English", 18 | "Programming Language :: Python", 19 | "Programming Language :: Python :: 3", 20 | "Programming Language :: Python :: 3 :: Only", 21 | "Programming Language :: Python :: 3.8", 22 | "Programming Language :: Python :: 3.9", 23 | "Programming Language :: Python :: 3.10", 24 | "Programming Language :: Python :: 3.11", 25 | "Operating System :: OS Independent", 26 | "Topic 
:: Software Development :: Libraries :: Application Frameworks", 27 | "Topic :: Software Development :: Libraries :: Python Modules", 28 | "Topic :: Software Development :: Libraries", 29 | "Topic :: Software Development", 30 | "Topic :: Scientific/Engineering", 31 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 32 | "Topic :: Scientific/Engineering :: Mathematics", 33 | "Topic :: Scientific/Engineering :: Image Recognition", 34 | "Topic :: Scientific/Engineering :: Image Processing", 35 | "Intended Audience :: Developers", 36 | "Intended Audience :: Science/Research", 37 | "Intended Audience :: Education", 38 | ] 39 | 40 | 41 | [tool.poetry.dependencies] 42 | python = ">=3.8.1,<3.12.0" 43 | torch = "^2.0.1" 44 | torchvision = "^0.15.2" 45 | opencv-python = "^4.7.0.72" 46 | tqdm = "^4.65.0" 47 | matplotlib = "^3.7.1" 48 | pillow = "^9.5.0" 49 | pycocotools = "^2.0.6" 50 | fal-serverless = "^0.6.35" 51 | sahi = "^0.11.14" 52 | onnx = { version = "^1.14.0", optional = true } 53 | onnxruntime = { version ="^1.15.1", optional = true } 54 | ultralytics = { version = "^8.0.123", optional = true } 55 | yolov5 = { version ="^7.0.12", optional = true } 56 | requests = "^2.31.0" 57 | 58 | 59 | [tool.poetry.extras] 60 | full = ["onnxruntime","onnx","yolov5","ultralytics"] 61 | yolov5 = ["yolov5"] 62 | yolov8 = ["ultralytics"] 63 | 64 | 65 | [tool.poetry.group.dev.dependencies] 66 | black = "^23.1.0" 67 | mypy = "^1.0.1" 68 | bandit = "^1.7.4" 69 | debugpy = "^1.6.6" 70 | rope = "^1.7.0" 71 | wheel = "^0.38.4" 72 | setuptools = "^67.4.0" 73 | coverage = "^7.2.1" 74 | pre-commit = "^3.1.1" 75 | pyupgrade = "^3.3.1" 76 | ruff = "^0.0.244" 77 | pytest = "^7.2.1" 78 | toml = "^0.10.2" 79 | flake8 = "^6.0.0" 80 | isort = "^5.12.0" 81 | parameterized = "^0.9.0" 82 | 83 | 84 | 85 | [tool.isort] 86 | line_length = 88 87 | profile = "black" 88 | 89 | [tool.bandit] 90 | target = ["tests", "metaseg"] 91 | tests = ["B201", "B301"] 92 | 93 | [tool.autoflake] 94 | check = 
true 95 | imports = ["cv2", "requests", "metaseg"] 96 | 97 | 98 | [tool.black] 99 | line-length = 88 100 | include = '\.pyi?$' 101 | exclude = ''' 102 | /( 103 | \.git 104 | | \.hg 105 | | \.mypy_cache 106 | | \.tox 107 | | \.venv 108 | | _build 109 | | buck-out 110 | | build 111 | | dist 112 | )/ 113 | ''' 114 | 115 | [tool.ruff] 116 | # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. 117 | select = ["E", "F"] 118 | ignore = [] 119 | 120 | # Allow autofix for all enabled rules (when `--fix`) is provided. 121 | fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"] 122 | unfixable = [] 123 | 124 | # Exclude a variety of commonly ignored directories. 125 | exclude = [ 126 | ".bzr", 127 | ".direnv", 128 | ".eggs", 129 | ".git", 130 | ".git-rewrite", 131 | ".hg", 132 | ".mypy_cache", 133 | ".nox", 134 | ".pants.d", 135 | ".pytype", 136 | ".ruff_cache", 137 | ".svn", 138 | ".tox", 139 | ".venv", 140 | "__pypackages__", 141 | "_build", 142 | "buck-out", 143 | "build", 144 | "dist", 145 | "node_modules", 146 | "venv", 147 | ] 148 | 149 | # Same as Black. 150 | line-length = 88 151 | 152 | # Allow unused variables when underscore-prefixed. 153 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 154 | 155 | [build-system] 156 | requires = ["poetry-core"] 157 | build-backend = "poetry.core.masonry.api" 158 | -------------------------------------------------------------------------------- /scripts/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import argparse 8 | import json 9 | import os 10 | from typing import Any, Dict, List 11 | 12 | import cv2 13 | 14 | from metaseg.generator import SamAutomaticMaskGenerator, sam_model_registry 15 | 16 | parser = argparse.ArgumentParser( 17 | description=( 18 | "Runs automatic mask generation on an input image or directory of images, " 19 | "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, " 20 | "as well as pycocotools if saving in RLE format." 21 | ) 22 | ) 23 | 24 | parser.add_argument( 25 | "--input", 26 | type=str, 27 | required=True, 28 | help="Path to either a single input image or folder of images.", 29 | ) 30 | 31 | parser.add_argument( 32 | "--output", 33 | type=str, 34 | required=True, 35 | help=( 36 | "Path to the directory where masks will be output. " 37 | "Output will be either a folder " 38 | "of PNGs per image or a single json with COCO-style masks." 39 | ), 40 | ) 41 | 42 | parser.add_argument( 43 | "--model-type", 44 | type=str, 45 | default="default", 46 | help="The type of model to load, in ['default', 'vit_l', 'vit_b']", 47 | ) 48 | 49 | parser.add_argument( 50 | "--checkpoint", 51 | type=str, 52 | required=True, 53 | help="The path to the SAM checkpoint to use for mask generation.", 54 | ) 55 | 56 | parser.add_argument( 57 | "--device", type=str, default="cuda", help="The device to run generation on." 58 | ) 59 | 60 | parser.add_argument( 61 | "--convert-to-rle", 62 | action="store_true", 63 | help=( 64 | "Save masks as COCO RLEs in a single json " 65 | "instead of as a folder of PNGs. " 66 | "Requires pycocotools." 
67 | ), 68 | ) 69 | 70 | amg_settings = parser.add_argument_group("AMG Settings") 71 | 72 | amg_settings.add_argument( 73 | "--points-per-side", 74 | type=int, 75 | default=None, 76 | help="Generate masks by sampling a grid over " 77 | "the image with this many points to a side.", 78 | ) 79 | 80 | amg_settings.add_argument( 81 | "--points-per-batch", 82 | type=int, 83 | default=None, 84 | help="How many input points to process " "simultaneously in one batch.", 85 | ) 86 | 87 | amg_settings.add_argument( 88 | "--pred-iou-thresh", 89 | type=float, 90 | default=None, 91 | help="Exclude masks with a predicted score from " 92 | "the model that is lower than this threshold.", 93 | ) 94 | 95 | amg_settings.add_argument( 96 | "--stability-score-thresh", 97 | type=float, 98 | default=None, 99 | help="Exclude masks with a stability " "score lower than this threshold.", 100 | ) 101 | 102 | amg_settings.add_argument( 103 | "--stability-score-offset", 104 | type=float, 105 | default=None, 106 | help="Larger values perturb the mask " "more when measuring stability score.", 107 | ) 108 | 109 | amg_settings.add_argument( 110 | "--box-nms-thresh", 111 | type=float, 112 | default=None, 113 | help="The overlap threshold for excluding a duplicate mask.", 114 | ) 115 | 116 | amg_settings.add_argument( 117 | "--crop-n-layers", 118 | type=int, 119 | default=None, 120 | help=( 121 | "If >0, mask generation is run on smaller " 122 | "crops of the image to generate more masks. " 123 | "The value sets how many different scales to crop at." 
124 | ), 125 | ) 126 | 127 | amg_settings.add_argument( 128 | "--crop-nms-thresh", 129 | type=float, 130 | default=None, 131 | help="The overlap threshold for excluding " 132 | "duplicate masks across different crops.", 133 | ) 134 | 135 | amg_settings.add_argument( 136 | "--crop-overlap-ratio", 137 | type=int, 138 | default=None, 139 | help="Larger numbers mean image crops will overlap more.", 140 | ) 141 | 142 | amg_settings.add_argument( 143 | "--crop-n-points-downscale-factor", 144 | type=int, 145 | default=None, 146 | help="The number of points-per-side in each " 147 | "layer of crop is reduced by this factor.", 148 | ) 149 | 150 | amg_settings.add_argument( 151 | "--min-mask-region-area", 152 | type=int, 153 | default=None, 154 | help=( 155 | "Disconnected mask regions or holes with " 156 | "area smaller than this value " 157 | "in pixels are removed by postprocessing." 158 | ), 159 | ) 160 | 161 | 162 | def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None: 163 | header = ( 164 | "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h," 165 | "point_input_x,point_input_y,predicted_iou," 166 | "stability_score,crop_box_x0,crop_box_y0," 167 | "crop_box_w,crop_box_h" 168 | ) 169 | metadata = [header] 170 | for i, mask_data in enumerate(masks): 171 | mask = mask_data["segmentation"] 172 | filename = f"{i}.png" 173 | cv2.imwrite(os.path.join(path, filename), mask * 255) 174 | mask_metadata = [ 175 | str(i), 176 | str(mask_data["area"]), 177 | *[str(x) for x in mask_data["bbox"]], 178 | *[str(x) for x in mask_data["point_coords"][0]], 179 | str(mask_data["predicted_iou"]), 180 | str(mask_data["stability_score"]), 181 | *[str(x) for x in mask_data["crop_box"]], 182 | ] 183 | row = ",".join(mask_metadata) 184 | metadata.append(row) 185 | metadata_path = os.path.join(path, "metadata.csv") 186 | with open(metadata_path, "w") as f: 187 | f.write("\n".join(metadata)) 188 | 189 | return 190 | 191 | 192 | def get_amg_kwargs(args): 193 | amg_kwargs = { 194 | 
"points_per_side": args.points_per_side, 195 | "points_per_batch": args.points_per_batch, 196 | "pred_iou_thresh": args.pred_iou_thresh, 197 | "stability_score_thresh": args.stability_score_thresh, 198 | "stability_score_offset": args.stability_score_offset, 199 | "box_nms_thresh": args.box_nms_thresh, 200 | "crop_n_layers": args.crop_n_layers, 201 | "crop_nms_thresh": args.crop_nms_thresh, 202 | "crop_overlap_ratio": args.crop_overlap_ratio, 203 | "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor, 204 | "min_mask_region_area": args.min_mask_region_area, 205 | } 206 | amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None} 207 | return amg_kwargs 208 | 209 | 210 | def main(args: argparse.Namespace) -> None: 211 | print("Loading model...") 212 | sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint) 213 | _ = sam.to(device=args.device) 214 | output_mode = "coco_rle" if args.convert_to_rle else "binary_mask" 215 | amg_kwargs = get_amg_kwargs(args) 216 | generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs) 217 | 218 | if not os.path.isdir(args.input): 219 | targets = [args.input] 220 | else: 221 | targets = [ 222 | f 223 | for f in os.listdir(args.input) 224 | if not os.path.isdir(os.path.join(args.input, f)) 225 | ] 226 | targets = [os.path.join(args.input, f) for f in targets] 227 | 228 | os.makedirs(args.output, exist_ok=True) 229 | 230 | for t in targets: 231 | print(f"Processing '{t}'...") 232 | image = cv2.imread(t) 233 | if image is None: 234 | print(f"Could not load '{t}' as an image, skipping...") 235 | continue 236 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 237 | 238 | masks = generator.generate(image) 239 | 240 | base = os.path.basename(t) 241 | base = os.path.splitext(base)[0] 242 | save_base = os.path.join(args.output, base) 243 | if output_mode == "binary_mask": 244 | os.makedirs(save_base, exist_ok=False) 245 | write_masks_to_folder(masks, save_base) 246 | else: 247 | 
save_file = save_base + ".json" 248 | with open(save_file, "w") as f: 249 | json.dump(masks, f) 250 | print("Done!") 251 | 252 | 253 | if __name__ == "__main__": 254 | args = parser.parse_args() 255 | main(args) 256 | -------------------------------------------------------------------------------- /scripts/export_onnx_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import warnings 9 | 10 | import torch 11 | from onnxruntime import InferenceSession 12 | from onnxruntime.quantization import QuantType 13 | from onnxruntime.quantization.quantize import quantize_dynamic 14 | 15 | from metaseg.generator import build_sam, build_sam_vit_b, build_sam_vit_l 16 | from metaseg.utils.onnx import SamOnnxModel 17 | 18 | parser = argparse.ArgumentParser( 19 | description="Export the SAM prompt encoder and mask decoder to an ONNX model." 20 | ) 21 | 22 | parser.add_argument( 23 | "--checkpoint", 24 | type=str, 25 | required=True, 26 | help="The path to the SAM model checkpoint.", 27 | ) 28 | 29 | parser.add_argument( 30 | "--output", type=str, required=True, help="The filename to save the ONNX model to." 31 | ) 32 | 33 | parser.add_argument( 34 | "--model-type", 35 | type=str, 36 | default="default", 37 | help="In ['default', 'vit_b', 'vit_l']. Which type of SAM model to export.", 38 | ) 39 | 40 | parser.add_argument( 41 | "--return-single-mask", 42 | action="store_true", 43 | help=( 44 | "If true, the exported ONNX model will only return the best mask, " 45 | "instead of returning multiple masks. For high resolution images " 46 | "this can improve runtime when upscaling masks is expensive." 
47 | ), 48 | ) 49 | 50 | parser.add_argument( 51 | "--opset", 52 | type=int, 53 | default=17, 54 | help="The ONNX opset version to use. Must be >=11", 55 | ) 56 | 57 | parser.add_argument( 58 | "--quantize-out", 59 | type=str, 60 | default=None, 61 | help=( 62 | "If set, will quantize the model and save it with this name. " 63 | "Quantization is performed with quantize_dynamic " 64 | "from onnxruntime.quantization.quantize." 65 | ), 66 | ) 67 | 68 | parser.add_argument( 69 | "--gelu-approximate", 70 | action="store_true", 71 | help=( 72 | "Replace GELU operations with approximations using tanh. Useful " 73 | "for some runtimes that have slow or unimplemented erf ops, used in GELU." 74 | ), 75 | ) 76 | 77 | parser.add_argument( 78 | "--use-stability-score", 79 | action="store_true", 80 | help=( 81 | "Replaces the model's predicted mask quality score with the stability " 82 | "score calculated on the low resolution masks using an offset of 1.0. " 83 | ), 84 | ) 85 | 86 | parser.add_argument( 87 | "--return-extra-metrics", 88 | action="store_true", 89 | help=( 90 | "The model will return five results: (masks, scores, stability_scores, " 91 | "areas, low_res_logits) instead of the usual three. This can be " 92 | "significantly slower for high resolution outputs." 
93 | ), 94 | ) 95 | 96 | 97 | def run_export( 98 | model_type: str, 99 | checkpoint: str, 100 | output: str, 101 | opset: int, 102 | return_single_mask: bool, 103 | gelu_approximate: bool = False, 104 | use_stability_score: bool = False, 105 | return_extra_metrics=False, 106 | ): 107 | print("Loading model...") 108 | if model_type == "vit_b": 109 | sam = build_sam_vit_b(checkpoint) 110 | elif model_type == "vit_l": 111 | sam = build_sam_vit_l(checkpoint) 112 | else: 113 | sam = build_sam(checkpoint) 114 | 115 | onnx_model = SamOnnxModel( 116 | model=sam, 117 | return_single_mask=return_single_mask, 118 | use_stability_score=use_stability_score, 119 | return_extra_metrics=return_extra_metrics, 120 | ) 121 | 122 | if gelu_approximate: 123 | for n, m in onnx_model.named_modules(): 124 | if isinstance(m, torch.nn.GELU): 125 | m.approximate = "tanh" 126 | 127 | dynamic_axes = { 128 | "point_coords": {1: "num_points"}, 129 | "point_labels": {1: "num_points"}, 130 | } 131 | 132 | embed_dim = sam.prompt_encoder.embed_dim 133 | embed_size = sam.prompt_encoder.image_embedding_size 134 | mask_input_size = [4 * x for x in embed_size] 135 | dummy_inputs = { 136 | "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float), 137 | "point_coords": torch.randint( 138 | low=0, high=1024, size=(1, 5, 2), dtype=torch.float 139 | ), 140 | "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float), 141 | "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float), 142 | "has_mask_input": torch.tensor([1], dtype=torch.float), 143 | "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float), 144 | } 145 | 146 | _ = onnx_model(**dummy_inputs) 147 | 148 | output_names = ["masks", "iou_predictions", "low_res_masks"] 149 | 150 | with warnings.catch_warnings(): 151 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) 152 | warnings.filterwarnings("ignore", category=UserWarning) 153 | with open(output, "wb") as f: 154 | 
print(f"Exporing onnx model to {output}...") 155 | torch.onnx.export( 156 | onnx_model, 157 | tuple(dummy_inputs.values()), 158 | f, 159 | export_params=True, 160 | verbose=False, 161 | opset_version=opset, 162 | do_constant_folding=True, 163 | input_names=list(dummy_inputs.keys()), 164 | output_names=output_names, 165 | dynamic_axes=dynamic_axes, 166 | ) 167 | 168 | ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()} 169 | ort_session = InferenceSession(output) 170 | _ = ort_session.run(None, ort_inputs) 171 | print("Model has successfully been run with ONNXRuntime.") 172 | 173 | 174 | def to_numpy(tensor): 175 | return tensor.cpu().numpy() 176 | 177 | 178 | if __name__ == "__main__": 179 | args = parser.parse_args() 180 | run_export( 181 | model_type=args.model_type, 182 | checkpoint=args.checkpoint, 183 | output=args.output, 184 | opset=args.opset, 185 | return_single_mask=args.return_single_mask, 186 | gelu_approximate=args.gelu_approximate, 187 | use_stability_score=args.use_stability_score, 188 | return_extra_metrics=args.return_extra_metrics, 189 | ) 190 | 191 | print(f"Quantizing model and writing to {args.quantize_out}...") 192 | quantize_dynamic( 193 | model_input=args.output, 194 | model_output=args.quantize_out, 195 | optimize_model=True, 196 | per_channel=False, 197 | reduce_range=False, 198 | weight_type=QuantType.QUInt8, 199 | ) 200 | print("Done!") 201 | --------------------------------------------------------------------------------