├── .github
│   └── workflows
│       ├── commit-lint.yml
│       ├── install_test_ci.yml
│       └── release_to_pypi.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── examples
│   ├── image_cv_segautomask.py
│   ├── image_sahi_slice_predict.py
│   ├── image_segautomask.py
│   ├── image_segmanualmask.py
│   ├── video_segautomask.py
│   └── video_segmanualmask.py
├── metaseg
│   ├── __init__.py
│   ├── falai_predictor.py
│   ├── generator
│   │   ├── __init__.py
│   │   ├── automatic_mask_generator.py
│   │   ├── build_sam.py
│   │   └── predictor.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── image_encoder.py
│   │   ├── mask_decoder.py
│   │   ├── prompt_encoder.py
│   │   ├── sam.py
│   │   └── transformer.py
│   ├── sahi_predictor.py
│   ├── sam_predictor.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── amg.py
│   │   ├── data_utils.py
│   │   ├── model_file_downloader.py
│   │   ├── onnx.py
│   │   └── transforms.py
│   └── webapp
│       ├── __init__.py
│       ├── app.py
│       └── demo.py
├── poetry.lock
├── pyproject.toml
├── requirements-dev.txt
├── requirements.txt
└── scripts
    ├── amg.py
    └── export_onnx_model.py
/.github/workflows/commit-lint.yml:
--------------------------------------------------------------------------------
1 | name: Conventional Commits
2 | on: [pull_request]
3 |
4 | jobs:
5 | lint-commits:
6 | name: Lint Commits
7 | runs-on: ubuntu-latest
8 |
9 | steps:
10 | - name: Checkout 🛎️
11 | uses: actions/checkout@v3.5.3
12 | with:
13 | fetch-depth: 0
14 |
15 | - name: Install Commit Linting Tool 🔧
16 | run: npm install --save-dev @commitlint/{cli,config-conventional}
17 |
18 | - name: Set Linting Config to Conventional Commits spec 🔧
19 | run: |
20 | echo "module.exports = { extends: ['@commitlint/config-conventional'] };" > commitlint.config.js
21 |
22 | - name: Lint 🚨
23 | run: npx commitlint --from HEAD~1 --to HEAD --verbose
24 |
--------------------------------------------------------------------------------
/.github/workflows/install_test_ci.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Install Test Python
3 | on: [push, pull_request, workflow_dispatch]
4 |
5 | jobs:
6 | Install_Test_Python:
7 | strategy:
8 | fail-fast: false
9 | matrix:
10 | os: [ubuntu-latest, windows-latest, macos-latest]
11 | python-version: ["3.8", "3.9", "3.10", "3.11"]
12 | runs-on: ${{ matrix.os }}
13 | steps:
14 | - uses: actions/checkout@v3
15 |
16 | - name: Set up Python ${{ matrix.python-version }}
17 | if: matrix.os == 'macos-latest' || matrix.os == 'ubuntu-latest'
18 | uses: actions/setup-python@v4
19 | with:
20 | python-version: ${{ matrix.python-version }}
21 |
22 | - name: Set up a virtual environment for Python ${{ matrix.python-version }}
23 | if: matrix.os == 'macos-latest' || matrix.os == 'ubuntu-latest'
24 | run: |
25 | python -m pip install --upgrade virtualenv
26 | virtualenv venv
27 | source venv/bin/activate
28 |
29 | - name: Install the base dependencies
30 | if: matrix.os == 'macos-latest' || matrix.os == 'ubuntu-latest'
31 | run: |
32 | source venv/bin/activate
33 | python -m pip install --upgrade poetry
34 |
35 | - name: Check the correctness of the project config
36 | if: matrix.os == 'macos-latest' || matrix.os == 'ubuntu-latest'
37 | run: |
38 | source venv/bin/activate
39 | poetry check
40 |
41 | - name: Install the package
42 | if: matrix.os == 'macos-latest' || matrix.os == 'ubuntu-latest'
43 | run: |
44 | source venv/bin/activate
45 | poetry install
46 |
47 |
48 | - name: Set up a virtual environment
49 | if: matrix.os == 'windows-latest'
50 | shell: pwsh
51 | run: |
52 | python -m pip install --upgrade virtualenv
53 | python -m virtualenv venv
54 | .\venv\Scripts\Activate.ps1
55 |
56 | - name: Install the base dependencies
57 | shell: pwsh
58 | if: matrix.os == 'windows-latest'
59 | run: |
60 | .\venv\Scripts\Activate.ps1
61 | python -m pip install --upgrade poetry
62 |
63 | - name: Check the correctness of the project config
64 | shell: pwsh
65 | if: matrix.os == 'windows-latest'
66 | run: |
67 | .\venv\Scripts\Activate.ps1
68 | poetry check
69 |
70 | - name: Install the package
71 | shell: pwsh
72 | if: matrix.os == 'windows-latest'
73 | run: |
74 | .\venv\Scripts\Activate.ps1
75 | poetry install
76 |
--------------------------------------------------------------------------------
/.github/workflows/release_to_pypi.yml:
--------------------------------------------------------------------------------
1 | name: MetaSeg Release to PyPi
2 | on:
3 | push:
4 | tags:
5 | - 'v[0-9]+.[0-9]+.[0-9]+'
6 |
7 | # Allows you to run this workflow manually from the Actions tab
8 | workflow_dispatch:
9 |
10 | jobs:
11 | build-n-publish:
12 | name: Build and publish to PyPI
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - name: Checkout source
17 | uses: actions/checkout@v3.5.3
18 |
19 | - name: Set up Python
20 | uses: actions/setup-python@v4
21 | with:
22 | python-version: "3.9"
23 |
24 | - name: Build source and wheel distributions
25 | run: |
26 | python -m pip install --upgrade build twine
27 | python -m build
28 | twine check --strict dist/*
29 | - name: Publish distribution to PyPI
30 | uses: pypa/gh-action-pypi-publish@release/v1
31 | with:
32 | user: __token__
33 | password: ${{ secrets.PYPI_METASEG_API_KEY }}
34 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # .pth/pt files (pytorch)
132 | *.pth
133 | *.pt
134 |
135 | # Ignore OSX folder attributes
136 | .DS_Store
137 |
138 | # Idea project files
139 | .idea/
140 | .idea_modules/
141 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v5.0.0
6 | hooks:
7 | - id: end-of-file-fixer
8 | - id: trailing-whitespace
9 | - id: check-yaml
10 | - id: check-docstring-first
11 | - id: check-executables-have-shebangs
12 | - id: check-toml
13 | - id: check-case-conflict
14 | - id: check-added-large-files
15 | args: ['--maxkb=2048']
16 | exclude: ^logo/
17 | - id: detect-private-key
18 | - id: forbid-new-submodules
19 | - id: pretty-format-json
20 | args: ['--autofix', '--no-sort-keys', '--indent=4']
21 | - id: end-of-file-fixer
22 | - id: mixed-line-ending
23 | - repo: https://github.com/asottile/pyupgrade
24 | rev: v3.19.1
25 | hooks:
26 | - id: pyupgrade
27 | args:
28 | - --py3-plus
29 | - --keep-runtime-typing
30 | - repo: https://github.com/astral-sh/ruff-pre-commit
31 | rev: v0.9.1
32 | hooks:
33 | - id: ruff
34 | args: [--fix, --exit-non-zero-on-fix]
35 | - repo: https://github.com/pycqa/isort
36 | rev: 5.13.2
37 | hooks:
38 | - id: isort
39 | name: isort (python)
40 | - id: isort
41 | name: isort (cython)
42 | types: [cython]
43 | - id: isort
44 | name: isort (pyi)
45 | types: [pyi]
46 | - repo: https://github.com/psf/black
47 | rev: 24.10.0
48 | hooks:
49 | - id: black
50 | - repo: https://github.com/PyCQA/bandit
51 | rev: '1.8.2'
52 | hooks:
53 | - id: bandit
54 | args: ["-c", "pyproject.toml"]
55 | additional_dependencies: ["bandit[toml]"]
56 | - repo: https://github.com/PyCQA/autoflake
57 | rev: v2.3.1
58 | hooks:
59 | - id: autoflake
60 |
61 | ci:
62 | autofix_commit_msg: "fix(pre_commit): 🎨 auto format pre-commit hooks"
63 | autoupdate_commit_msg: "fix(pre_commit): ⬆ pre_commit autoupdate"
64 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 |
2 | # Contributor Covenant Code of Conduct
3 |
4 | ## Our Pledge
5 |
6 | We as members, contributors, and leaders pledge to make participation in our
7 | community a harassment-free experience for everyone, regardless of age, body
8 | size, visible or invisible disability, ethnicity, sex characteristics, gender
9 | identity and expression, level of experience, education, socio-economic status,
10 | nationality, personal appearance, race, caste, color, religion, or sexual
11 | identity and orientation.
12 |
13 | We pledge to act and interact in ways that contribute to an open, welcoming,
14 | diverse, inclusive, and healthy community.
15 |
16 | ## Our Standards
17 |
18 | Examples of behavior that contributes to a positive environment for our
19 | community include:
20 |
21 | * Demonstrating empathy and kindness toward other people
22 | * Being respectful of differing opinions, viewpoints, and experiences
23 | * Giving and gracefully accepting constructive feedback
24 | * Accepting responsibility and apologizing to those affected by our mistakes,
25 | and learning from the experience
26 | * Focusing on what is best not just for us as individuals, but for the overall
27 | community
28 |
29 | Examples of unacceptable behavior include:
30 |
31 | * The use of sexualized language or imagery, and sexual attention or advances of
32 | any kind
33 | * Trolling, insulting or derogatory comments, and personal or political attacks
34 | * Public or private harassment
35 | * Publishing others' private information, such as a physical or email address,
36 | without their explicit permission
37 | * Other conduct which could reasonably be considered inappropriate in a
38 | professional setting
39 |
40 | ## Enforcement Responsibilities
41 |
42 | Community leaders are responsible for clarifying and enforcing our standards of
43 | acceptable behavior and will take appropriate and fair corrective action in
44 | response to any behavior that they deem inappropriate, threatening, offensive,
45 | or harmful.
46 |
47 | Community leaders have the right and responsibility to remove, edit, or reject
48 | comments, commits, code, wiki edits, issues, and other contributions that are
49 | not aligned to this Code of Conduct, and will communicate reasons for moderation
50 | decisions when appropriate.
51 |
52 | ## Scope
53 |
54 | This Code of Conduct applies within all community spaces, and also applies when
55 | an individual is officially representing the community in public spaces.
56 | Examples of representing our community include using an official e-mail address,
57 | posting via an official social media account, or acting as an appointed
58 | representative at an online or offline event.
59 |
60 | ## Enforcement
61 |
62 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
63 | reported to the community leaders responsible for enforcement at
64 | [INSERT CONTACT METHOD].
65 | All complaints will be reviewed and investigated promptly and fairly.
66 |
67 | All community leaders are obligated to respect the privacy and security of the
68 | reporter of any incident.
69 |
70 | ## Enforcement Guidelines
71 |
72 | Community leaders will follow these Community Impact Guidelines in determining
73 | the consequences for any action they deem in violation of this Code of Conduct:
74 |
75 | ### 1. Correction
76 |
77 | **Community Impact**: Use of inappropriate language or other behavior deemed
78 | unprofessional or unwelcome in the community.
79 |
80 | **Consequence**: A private, written warning from community leaders, providing
81 | clarity around the nature of the violation and an explanation of why the
82 | behavior was inappropriate. A public apology may be requested.
83 |
84 | ### 2. Warning
85 |
86 | **Community Impact**: A violation through a single incident or series of
87 | actions.
88 |
89 | **Consequence**: A warning with consequences for continued behavior. No
90 | interaction with the people involved, including unsolicited interaction with
91 | those enforcing the Code of Conduct, for a specified period of time. This
92 | includes avoiding interactions in community spaces as well as external channels
93 | like social media. Violating these terms may lead to a temporary or permanent
94 | ban.
95 |
96 | ### 3. Temporary Ban
97 |
98 | **Community Impact**: A serious violation of community standards, including
99 | sustained inappropriate behavior.
100 |
101 | **Consequence**: A temporary ban from any sort of interaction or public
102 | communication with the community for a specified period of time. No public or
103 | private interaction with the people involved, including unsolicited interaction
104 | with those enforcing the Code of Conduct, is allowed during this period.
105 | Violating these terms may lead to a permanent ban.
106 |
107 | ### 4. Permanent Ban
108 |
109 | **Community Impact**: Demonstrating a pattern of violation of community
110 | standards, including sustained inappropriate behavior, harassment of an
111 | individual, or aggression toward or disparagement of classes of individuals.
112 |
113 | **Consequence**: A permanent ban from any sort of public interaction within the
114 | community.
115 |
116 | ## Attribution
117 |
118 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
119 | version 2.1, available at
120 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
121 |
122 | Community Impact Guidelines were inspired by
123 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC].
124 |
125 | For answers to common questions about this code of conduct, see the FAQ at
126 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
127 | [https://www.contributor-covenant.org/translations][translations].
128 |
129 | [homepage]: https://www.contributor-covenant.org
130 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
131 | [Mozilla CoC]: https://github.com/mozilla/diversity
132 | [FAQ]: https://www.contributor-covenant.org/faq
133 | [translations]: https://www.contributor-covenant.org/translations
134 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # MetaSeg: Packaged version of the Segment Anything repository
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 | This repo is a packaged version of the [segment-anything](https://github.com/facebookresearch/segment-anything) model.
33 |
34 | ### Installation
35 | ```bash
36 | pip install metaseg
37 | ```
38 |
39 | ### Usage
40 | ```python
41 | from metaseg import SegAutoMaskPredictor, SegManualMaskPredictor
42 |
43 | # If gpu memory is not enough, reduce the points_per_side and points_per_batch.
44 |
45 | # For image
46 | results = SegAutoMaskPredictor().image_predict(
47 | source="image.jpg",
48 | model_type="vit_l", # vit_l, vit_h, vit_b
49 | points_per_side=16,
50 | points_per_batch=64,
51 | min_area=0,
52 | output_path="output.jpg",
53 | show=True,
54 | save=False,
55 | )
56 |
57 | # For video
58 | results = SegAutoMaskPredictor().video_predict(
59 | source="video.mp4",
60 | model_type="vit_l", # vit_l, vit_h, vit_b
61 | points_per_side=16,
62 | points_per_batch=64,
63 | min_area=1000,
64 | output_path="output.mp4",
65 | )
66 |
67 | # For manual box and point selection
68 |
69 | # For image
70 | results = SegManualMaskPredictor().image_predict(
71 | source="image.jpg",
72 | model_type="vit_l", # vit_l, vit_h, vit_b
73 | input_point=[[100, 100], [200, 200]],
74 | input_label=[0, 1],
75 | input_box=[100, 100, 200, 200], # or [[100, 100, 200, 200], [100, 100, 200, 200]]
76 | multimask_output=False,
77 | random_color=False,
78 | show=True,
79 | save=False,
80 | )
81 |
82 | # For video
83 |
84 | results = SegManualMaskPredictor().video_predict(
85 | source="video.mp4",
86 | model_type="vit_l", # vit_l, vit_h, vit_b
87 | input_point=[0, 0, 100, 100],
88 | input_label=[0, 1],
89 | input_box=None,
90 | multimask_output=False,
91 | random_color=False,
92 | output_path="output.mp4",
93 | )
94 | ```
95 |
96 | ### [SAHI](https://github.com/obss/sahi) + Segment Anything
97 |
98 | ```bash
99 | pip install sahi metaseg
100 | ```
101 |
102 | ```python
103 | from metaseg import SahiAutoSegmentation, sahi_sliced_predict
104 |
105 | image_path = "image.jpg"
106 | boxes = sahi_sliced_predict(
107 | image_path=image_path,
108 | detection_model_type="yolov5", # yolov8, detectron2, mmdetection, torchvision
109 | detection_model_path="yolov5l6.pt",
110 | conf_th=0.25,
111 | image_size=1280,
112 | slice_height=256,
113 | slice_width=256,
114 | overlap_height_ratio=0.2,
115 | overlap_width_ratio=0.2,
116 | )
117 |
118 | SahiAutoSegmentation().image_predict(
119 | source=image_path,
120 | model_type="vit_b",
121 | input_box=boxes,
122 | multimask_output=False,
123 | random_color=False,
124 | show=True,
125 | save=False,
126 | )
127 | ```
128 |
129 |
130 | ### [FalAI (Cloud GPU)](https://docs.fal.ai/fal-serverless/quickstart) + Segment Anything
131 | ```bash
132 | pip install metaseg fal_serverless
133 | fal-serverless auth login
134 | ```
135 |
136 | ```python
137 | # For Auto Mask
138 | from metaseg import falai_automask_image
139 |
140 | image = falai_automask_image(
141 | image_path="image.jpg",
142 | model_type="vit_b",
143 | points_per_side=16,
144 | points_per_batch=32,
145 | min_area=0,
146 | )
147 | image.show() # Show image
148 | image.save("output.jpg") # Save image
149 |
150 | # For Manual Mask
151 | from metaseg import falai_manuelmask_image
152 |
153 | image = falai_manuelmask_image(
154 | image_path="image.jpg",
155 | model_type="vit_b",
156 | input_point=[[100, 100], [200, 200]],
157 | input_label=[0, 1],
158 | input_box=[100, 100, 200, 200], # or [[100, 100, 200, 200], [100, 100, 200, 200]],
159 | multimask_output=False,
160 | random_color=False,
161 | )
162 | image.show() # Show image
163 | image.save("output.jpg") # Save image
164 | ```
165 | ### Extra Features
166 |
167 | - [x] Support for YOLOv5/YOLOv8, Detectron2, MMDetection, and Torchvision detection models
168 | - [x] Support for video and web application (Hugging Face Spaces)
169 | - [x] Support for manual single/multi box and point selection
170 | - [x] Support for pip installation
171 | - [x] Support for SAHI library
172 | - [x] Support for FalAI
173 |
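174 | ### Low-level SAM API
175 |
176 | The predictors above wrap the bundled `metaseg.generator` module. The snippet below is a minimal sketch of calling it directly; the checkpoint and image paths are placeholders, and the printed keys follow the `SamAutomaticMaskGenerator.generate` docstring.
177 |
178 | ```python
179 | import cv2
180 | from metaseg.generator import SamAutomaticMaskGenerator, sam_model_registry
181 |
182 | # Build a ViT-B SAM and load a locally downloaded checkpoint (placeholder path).
183 | sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")
184 |
185 | # generate() expects an HWC uint8 RGB array.
186 | image = cv2.cvtColor(cv2.imread("image.jpg"), cv2.COLOR_BGR2RGB)
187 |
188 | masks = SamAutomaticMaskGenerator(sam, points_per_side=16).generate(image)
189 | for record in masks:
190 |     print(record["bbox"], record["area"], record["predicted_iou"])
191 | ```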
--------------------------------------------------------------------------------
/examples/image_cv_segautomask.py:
--------------------------------------------------------------------------------
1 | from cv2 import COLOR_BGR2RGB, Mat
2 | from cv2 import cvtColor as cv_cvtColor
3 | from cv2 import imread as cv_imread
4 |
5 | from metaseg import SegAutoMaskPredictor
6 |
7 |
8 | # If gpu memory is not enough, reduce the points_per_side and points_per_batch.
9 | def main(src: Mat) -> None:
10 | SegAutoMaskPredictor().image_predict(
11 | source=src,
12 | model_type="vit_l", # vit_l, vit_h, vit_b
13 | points_per_side=16,
14 | points_per_batch=64,
15 | min_area=0,
16 | output_path="output.jpg",
17 | show=True,
18 | save=False,
19 | )
20 |
21 |
22 | if __name__ == "__main__":
23 | image = cv_imread("image.png")
24 | image = cv_cvtColor(image, COLOR_BGR2RGB)
25 | main(image)
26 |
--------------------------------------------------------------------------------
/examples/image_sahi_slice_predict.py:
--------------------------------------------------------------------------------
1 | from metaseg import SahiAutoSegmentation, sahi_sliced_predict
2 |
3 |
4 | def main(src: str = "image.png") -> None:
5 | img_path = src
6 | boxes = sahi_sliced_predict(
7 | image_path=img_path,
8 | detection_model_type="yolov5", # yolov8, detectron2, mmdetection, torchvision
9 | detection_model_path="yolov5l6.pt",
10 | conf_th=0.25,
11 | image_size=1280,
12 | slice_height=256,
13 | slice_width=256,
14 | overlap_height_ratio=0.2,
15 | overlap_width_ratio=0.2,
16 | )
17 |
18 | SahiAutoSegmentation().image_predict(
19 | source=img_path,
20 | model_type="vit_b",
21 | input_box=boxes,
22 | multimask_output=False,
23 | random_color=False,
24 | show=True,
25 | save=False,
26 | )
27 |
28 |
29 | if __name__ == "__main__":
30 | main()
31 |
--------------------------------------------------------------------------------
/examples/image_segautomask.py:
--------------------------------------------------------------------------------
1 | from metaseg import SegAutoMaskPredictor
2 |
3 | # If gpu memory is not enough, reduce the points_per_side and points_per_batch.
4 |
5 |
6 | def main(src: str = "image.png") -> None:
7 | SegAutoMaskPredictor().image_predict(
8 | source=src,
9 | model_type="vit_l", # vit_l, vit_h, vit_b
10 | points_per_side=16,
11 | points_per_batch=64,
12 | min_area=0,
13 | output_path="output.jpg",
14 | show=True,
15 | save=False,
16 | )
17 |
18 |
19 | if __name__ == "__main__":
20 | main()
21 |
--------------------------------------------------------------------------------
/examples/image_segmanualmask.py:
--------------------------------------------------------------------------------
1 | from metaseg import SegManualMaskPredictor
2 |
3 | # If gpu memory is not enough, reduce the points_per_side and points_per_batch.
4 |
5 |
6 | def main(src: str = "image.png") -> None:
7 | SegManualMaskPredictor().image_predict(
8 | source=src,
9 | model_type="vit_l", # vit_l, vit_h, vit_b
10 | input_point=[[100, 100], [200, 200]],
11 | input_label=[0, 1],
12 | input_box=[
13 | 100,
14 | 100,
15 | 200,
16 | 200,
17 | ], # or [[100, 100, 200, 200], [100, 100, 200, 200]]
18 | multimask_output=False,
19 | random_color=False,
20 | show=True,
21 | save=False,
22 | )
23 |
24 |
25 | if __name__ == "__main__":
26 | main()
27 |
--------------------------------------------------------------------------------
/examples/video_segautomask.py:
--------------------------------------------------------------------------------
1 | from metaseg import SegAutoMaskPredictor
2 |
3 | # If gpu memory is not enough, reduce the points_per_side and points_per_batch.
4 |
5 |
6 | # For video
7 | def main(src: str = "video.mp4") -> None:
8 | SegAutoMaskPredictor().video_predict(
9 | source=src,
10 | model_type="vit_l", # vit_l, vit_h, vit_b
11 | points_per_side=16,
12 | points_per_batch=64,
13 | min_area=1000,
14 | output_path="output.mp4",
15 | )
16 |
17 |
18 | if __name__ == "__main__":
19 | main()
20 |
--------------------------------------------------------------------------------
/examples/video_segmanualmask.py:
--------------------------------------------------------------------------------
1 | from metaseg import SegManualMaskPredictor
2 |
3 | # If gpu memory is not enough, reduce the points_per_side and points_per_batch.
4 |
5 |
6 | def main(src: str = "video.mp4") -> None:
7 | SegManualMaskPredictor().video_predict(
8 | source=src,
9 | model_type="vit_l", # vit_l, vit_h, vit_b
10 | input_point=[0, 0, 100, 100],
11 | input_label=[0, 1],
12 | input_box=None,
13 | multimask_output=False,
14 | random_color=False,
15 | output_path="output.mp4",
16 | )
17 |
18 |
19 | if __name__ == "__main__":
20 | main()
21 |
--------------------------------------------------------------------------------
/metaseg/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import importlib.metadata as importlib_metadata
8 |
9 | from .falai_predictor import automask_image as automask_image
10 | from .falai_predictor import falai_automask_image as falai_automask_image
11 | from .falai_predictor import falai_manuelmask_image as falai_manuelmask_image
12 | from .falai_predictor import manuelmask_image as manuelmask_image
13 | from .sahi_predictor import SahiAutoSegmentation as SahiAutoSegmentation
14 | from .sahi_predictor import sahi_sliced_predict as sahi_sliced_predict
15 | from .sam_predictor import SegAutoMaskPredictor as SegAutoMaskPredictor
16 | from .sam_predictor import SegManualMaskPredictor as SegManualMaskPredictor
17 |
18 | try:
19 | # This will read version from pyproject.toml
20 | __version__ = importlib_metadata.version(__package__ or __name__)
21 | except importlib_metadata.PackageNotFoundError:
22 | __version__ = "development"
23 |
--------------------------------------------------------------------------------
/metaseg/falai_predictor.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | from PIL import Image
4 |
5 | from .sam_predictor import SegAutoMaskPredictor, SegManualMaskPredictor
6 | from .utils.data_utils import load_server_image
7 |
8 | try:
9 | from fal_serverless import isolated
10 | except ImportError:
11 | raise ImportError(
12 | "Please install FalAI library using 'pip install fal_serverless'."
13 | )
14 |
15 |
16 | @isolated(requirements=["metaseg"], keep_alive=1800, machine_type="GPU-T4")
17 | def automask_image(
18 | data, model_type="vit_b", points_per_side=16, points_per_batch=32, min_area=0
19 | ):
20 | image_path, output_path = load_server_image(data)
21 | SegAutoMaskPredictor().image_predict(
22 | source=image_path,
23 | model_type=model_type,
24 | points_per_side=points_per_side,
25 | points_per_batch=points_per_batch,
26 | min_area=min_area,
27 | output_path=output_path,
28 | show=False,
29 | save=True,
30 | )
31 | with open(output_path, "rb") as f:
32 | result = f.read()
33 |
34 | return result
35 |
36 |
37 | @isolated(requirements=["metaseg"], keep_alive=1800, machine_type="GPU-T4")
38 | def manuelmask_image(
39 | data,
40 | model_type="vit_b",
41 | input_point=[[100, 100], [200, 200]],
42 | input_label=[0, 1],
43 | input_box=[100, 100, 200, 200],
44 | multimask_output=False,
45 | random_color=False,
46 | min_area=0,
47 | ):
48 | image_path, output_path = load_server_image(data)
49 | SegManualMaskPredictor().image_predict(
50 | source=image_path,
51 | model_type=model_type,
52 | input_point=input_point,
53 | input_label=input_label,
54 | input_box=input_box,
55 | multimask_output=multimask_output,
56 | random_color=random_color,
57 | min_area=min_area,
58 | output_path=output_path,
59 | show=False,
60 | save=True,
61 | )
62 | with open(output_path, "rb") as f:
63 | result = f.read()
64 |
65 | return result
66 |
67 |
68 | def falai_automask_image(
69 | image_path, model_type="vit_b", points_per_side=16, points_per_batch=32, min_area=0
70 | ):
71 | with open(image_path, "rb") as f:
72 | data = f.read()
73 |
74 | image = automask_image(
75 | data=data,
76 | model_type=model_type,
77 | points_per_side=points_per_side,
78 | points_per_batch=points_per_batch,
79 | min_area=min_area,
80 | )
81 | image = Image.open(BytesIO(image))
82 | return image
83 |
84 |
85 | def falai_manuelmask_image(
86 | image_path,
87 | model_type="vit_b",
88 | input_point=[[100, 100], [200, 200]],
89 | input_label=[0, 1],
90 | input_box=[100, 100, 200, 200],
91 | multimask_output=False,
92 | random_color=False,
93 | min_area=0,
94 | ):
95 | with open(image_path, "rb") as f:
96 | data = f.read()
97 |
98 | image = manuelmask_image(
99 | data=data,
100 | model_type=model_type,
101 | input_point=input_point,
102 | input_label=input_label,
103 | input_box=input_box,
104 | multimask_output=multimask_output,
105 | random_color=random_color,
106 | min_area=min_area,
107 | )
108 | image = Image.open(BytesIO(image))
109 | return image
110 |
--------------------------------------------------------------------------------
/metaseg/generator/__init__.py:
--------------------------------------------------------------------------------
1 | from .automatic_mask_generator import (
2 | SamAutomaticMaskGenerator as SamAutomaticMaskGenerator,
3 | )
4 | from .build_sam import build_sam as build_sam
5 | from .build_sam import build_sam_vit_b as build_sam_vit_b
6 | from .build_sam import build_sam_vit_h as build_sam_vit_h
7 | from .build_sam import build_sam_vit_l as build_sam_vit_l
8 | from .build_sam import sam_model_registry as sam_model_registry
9 | from .predictor import SamPredictor as SamPredictor
10 |
--------------------------------------------------------------------------------
/metaseg/generator/automatic_mask_generator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Any, Dict, List, Optional, Tuple
8 |
9 | import numpy as np
10 | import torch
11 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore
12 |
13 | from metaseg.generator.predictor import SamPredictor
14 | from metaseg.modeling import Sam
15 | from metaseg.utils.amg import (
16 | MaskData,
17 | area_from_rle,
18 | batch_iterator,
19 | batched_mask_to_box,
20 | box_xyxy_to_xywh,
21 | build_all_layer_point_grids,
22 | calculate_stability_score,
23 | coco_encode_rle,
24 | generate_crop_boxes,
25 | is_box_near_crop_edge,
26 | mask_to_rle_pytorch,
27 | remove_small_regions,
28 | rle_to_mask,
29 | uncrop_boxes_xyxy,
30 | uncrop_masks,
31 | uncrop_points,
32 | )
33 |
34 |
35 | class SamAutomaticMaskGenerator:
36 | def __init__(
37 | self,
38 | model: Sam,
39 | points_per_side: Optional[int] = 32,
40 | points_per_batch: int = 64,
41 | pred_iou_thresh: float = 0.88,
42 | stability_score_thresh: float = 0.95,
43 | stability_score_offset: float = 1.0,
44 | box_nms_thresh: float = 0.7,
45 | crop_n_layers: int = 0,
46 | crop_nms_thresh: float = 0.7,
47 | crop_overlap_ratio: float = 512 / 1500,
48 | crop_n_points_downscale_factor: int = 1,
49 | point_grids: Optional[List[np.ndarray]] = None,
50 | min_mask_region_area: int = 0,
51 | output_mode: str = "binary_mask",
52 | ) -> None:
53 | """
54 | Using a SAM model, generates masks for the entire image.
55 | Generates a grid of point prompts over the image, then filters
56 | low quality and duplicate masks. The default settings are chosen
57 | for SAM with a ViT-H backbone.
58 |
59 | Arguments:
60 | model (Sam): The SAM model to use for mask prediction.
61 | points_per_side (int or None): The number of points to be sampled
62 | along one side of the image. The total number of points is
63 | points_per_side**2. If None, 'point_grids' must provide explicit
64 | point sampling.
65 | points_per_batch (int): Sets the number of points run simultaneously
66 | by the model. Higher numbers may be faster but use more GPU memory.
67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the
68 | model's predicted mask quality.
69 | stability_score_thresh (float): A filtering threshold in [0,1], using
70 | the stability of the mask under changes to the cutoff used to binarize
71 | the model's mask predictions.
72 | stability_score_offset (float): The amount to shift the cutoff when
73 | calculating the stability score.
74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal
75 | suppression to filter duplicate masks.
76 | crop_n_layers (int): If >0, mask prediction will be run again on
77 | crops of the image. Sets the number of layers to run, where each
78 | layer has 2**i_layer number of image crops.
79 | crop_nms_thresh (float): The box IoU cutoff used by non-maximal
80 | suppression to filter duplicate masks between different crops.
81 | crop_overlap_ratio (float): Sets the degree to which crops overlap.
82 | In the first crop layer, crops will overlap by this fraction of
83 | the image length. Later layers with more crops scale down this overlap.
84 | crop_n_points_downscale_factor (int): The number of points-per-side
85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
86 | point_grids (list(np.ndarray) or None): A list over explicit grids
87 | of points used for sampling, normalized to [0,1]. The nth grid in the
88 | list is used in the nth crop layer. Exclusive with points_per_side.
89 | min_mask_region_area (int): If >0, postprocessing will be applied
90 | to remove disconnected regions and holes in masks with area smaller
91 | than min_mask_region_area. Requires opencv.
92 | output_mode (str): The form masks are returned in. Can be 'binary_mask',
93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
94 | For large resolutions, 'binary_mask' may consume large amounts of
95 | memory.
96 | """
97 |
98 | assert (points_per_side is None) != (
99 | point_grids is None
100 | ), "Exactly one of points_per_side or point_grid must be provided."
101 | if points_per_side is not None:
102 | self.point_grids = build_all_layer_point_grids(
103 | points_per_side,
104 | crop_n_layers,
105 | crop_n_points_downscale_factor,
106 | )
107 | elif point_grids is not None:
108 | self.point_grids = point_grids
109 | else:
110 | raise ValueError("Can't have both points_per_side and point_grid be None.")
111 |
112 | assert output_mode in [
113 | "binary_mask",
114 | "uncompressed_rle",
115 | "coco_rle",
116 | ], f"Unknown output_mode {output_mode}."
117 | if output_mode == "coco_rle":
118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401
119 |
120 | if min_mask_region_area > 0:
121 | import cv2 # type: ignore # noqa: F401
122 |
123 | self.predictor = SamPredictor(model)
124 | self.points_per_batch = points_per_batch
125 | self.pred_iou_thresh = pred_iou_thresh
126 | self.stability_score_thresh = stability_score_thresh
127 | self.stability_score_offset = stability_score_offset
128 | self.box_nms_thresh = box_nms_thresh
129 | self.crop_n_layers = crop_n_layers
130 | self.crop_nms_thresh = crop_nms_thresh
131 | self.crop_overlap_ratio = crop_overlap_ratio
132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
133 | self.min_mask_region_area = min_mask_region_area
134 | self.output_mode = output_mode
135 |
136 | @torch.no_grad()
137 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
138 | """
139 | Generates masks for the given image.
140 |
141 | Arguments:
142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format.
143 |
144 | Returns:
145 | list(dict(str, any)): A list over records for masks. Each record is
146 | a dict containing the following keys:
147 | segmentation (dict(str, any) or np.ndarray): The mask. If
148 | output_mode='binary_mask', is an array of shape HW. Otherwise,
149 | is a dictionary containing the RLE.
150 | bbox (list(float)): The box around the mask, in XYWH format.
151 | area (int): The area in pixels of the mask.
152 | predicted_iou (float): The model's own prediction of the mask's
153 | quality. This is filtered by the pred_iou_thresh parameter.
154 | point_coords (list(list(float))): The point coordinates input
155 | to the model to generate this mask.
156 | stability_score (float): A measure of the mask's quality. This
157 | is filtered on using the stability_score_thresh parameter.
158 | crop_box (list(float)): The crop of the image used to generate
159 | the mask, given in XYWH format.
160 | """
161 |
162 | # Generate masks
163 | mask_data = self._generate_masks(image)
164 |
165 | # Filter small disconnected regions and holes in masks
166 | if self.min_mask_region_area > 0:
167 | mask_data = self.postprocess_small_regions(
168 | mask_data,
169 | self.min_mask_region_area,
170 | max(self.box_nms_thresh, self.crop_nms_thresh),
171 | )
172 |
173 | # Encode masks
174 | if self.output_mode == "coco_rle":
175 | mask_data["segmentations"] = [
176 | coco_encode_rle(rle) for rle in mask_data["rles"]
177 | ]
178 | elif self.output_mode == "binary_mask":
179 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
180 | else:
181 | mask_data["segmentations"] = mask_data["rles"]
182 |
183 | # Write mask records
184 | curr_anns = []
185 | for idx in range(len(mask_data["segmentations"])):
186 | ann = {
187 | "segmentation": mask_data["segmentations"][idx],
188 | "area": area_from_rle(mask_data["rles"][idx]),
189 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
190 | "predicted_iou": mask_data["iou_preds"][idx].item(),
191 | "point_coords": [mask_data["points"][idx].tolist()],
192 | "stability_score": mask_data["stability_score"][idx].item(),
193 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
194 | }
195 | curr_anns.append(ann)
196 |
197 | return curr_anns
198 |
199 | def _generate_masks(self, image: np.ndarray) -> MaskData:
200 | orig_size = image.shape[:2]
201 | crop_boxes, layer_idxs = generate_crop_boxes(
202 | orig_size, self.crop_n_layers, self.crop_overlap_ratio
203 | )
204 |
205 | # Iterate over image crops
206 | data = MaskData()
207 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
208 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
209 | data.cat(crop_data)
210 |
211 | # Remove duplicate masks between crops
212 | if len(crop_boxes) > 1:
213 | # Prefer masks from smaller crops
214 | scores = 1 / box_area(data["crop_boxes"])
215 | scores = scores.to(data["boxes"].device)
216 | keep_by_nms = batched_nms(
217 | data["boxes"].float(),
218 | scores,
219 | torch.zeros(len(data["boxes"])), # categories
220 | iou_threshold=self.crop_nms_thresh,
221 | )
222 | data.filter(keep_by_nms)
223 |
224 | data.to_numpy()
225 | return data
226 |
227 | def _process_crop(
228 | self,
229 | image: np.ndarray,
230 | crop_box: List[int],
231 | crop_layer_idx: int,
232 | orig_size: Tuple[int, ...],
233 | ) -> MaskData:
234 | # Crop the image and calculate embeddings
235 | x0, y0, x1, y1 = crop_box
236 | cropped_im = image[y0:y1, x0:x1, :]
237 | cropped_im_size = cropped_im.shape[:2]
238 | self.predictor.set_image(cropped_im)
239 |
240 | # Get points for this crop
241 | points_scale = np.array(cropped_im_size)[None, ::-1]
242 | points_for_image = self.point_grids[crop_layer_idx] * points_scale
243 |
244 | # Generate masks for this crop in batches
245 | data = MaskData()
246 | for (points,) in batch_iterator(self.points_per_batch, points_for_image):
247 | batch_data = self._process_batch(
248 | points, cropped_im_size, crop_box, orig_size
249 | )
250 | data.cat(batch_data)
251 | del batch_data
252 | self.predictor.reset_image()
253 |
254 | # Remove duplicates within this crop.
255 | keep_by_nms = batched_nms(
256 | data["boxes"].float(),
257 | data["iou_preds"],
258 | torch.zeros(len(data["boxes"])), # categories
259 | iou_threshold=self.box_nms_thresh,
260 | )
261 | data.filter(keep_by_nms)
262 |
263 | # Return to the original image frame
264 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
265 | data["points"] = uncrop_points(data["points"], crop_box)
266 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
267 |
268 | return data
269 |
270 | def _process_batch(
271 | self,
272 | points: np.ndarray,
273 | im_size: Tuple[int, ...],
274 | crop_box: List[int],
275 | orig_size: Tuple[int, ...],
276 | ) -> MaskData:
277 | orig_h, orig_w = orig_size
278 |
279 | # Run model on this batch
280 | transformed_points = self.predictor.transform.apply_coords(points, im_size)
281 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
282 | in_labels = torch.ones(
283 | in_points.shape[0], dtype=torch.int, device=in_points.device
284 | )
285 | masks, iou_preds, _ = self.predictor.predict_torch(
286 | in_points[:, None, :],
287 | in_labels[:, None],
288 | multimask_output=True,
289 | return_logits=True,
290 | )
291 |
292 | # Serialize predictions and store in MaskData
293 | data = MaskData(
294 | masks=masks.flatten(0, 1),
295 | iou_preds=iou_preds.flatten(0, 1),
296 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
297 | )
298 | del masks
299 |
300 | # Filter by predicted IoU
301 | if self.pred_iou_thresh > 0.0:
302 | keep_mask = data["iou_preds"] > self.pred_iou_thresh
303 | data.filter(keep_mask)
304 |
305 | # Calculate stability score
306 | data["stability_score"] = calculate_stability_score(
307 | data["masks"],
308 | self.predictor.model.mask_threshold,
309 | self.stability_score_offset,
310 | )
311 | if self.stability_score_thresh > 0.0:
312 | keep_mask = data["stability_score"] >= self.stability_score_thresh
313 | data.filter(keep_mask)
314 |
315 | # Threshold masks and calculate boxes
316 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold
317 | data["boxes"] = batched_mask_to_box(data["masks"])
318 |
319 | # Filter boxes that touch crop boundaries
320 | keep_mask = ~is_box_near_crop_edge(
321 | data["boxes"], crop_box, [0, 0, orig_w, orig_h]
322 | )
323 | if not torch.all(keep_mask):
324 | data.filter(keep_mask)
325 |
326 | # Compress to RLE
327 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
328 | data["rles"] = mask_to_rle_pytorch(data["masks"])
329 | del data["masks"]
330 |
331 | return data
332 |
333 | @staticmethod
334 | def postprocess_small_regions(
335 | mask_data: MaskData, min_area: int, nms_thresh: float
336 | ) -> MaskData:
337 | """
338 | Removes small disconnected regions and holes in masks, then reruns
339 | box NMS to remove any new duplicates.
340 |
341 | Edits mask_data in place.
342 |
343 | Requires open-cv as a dependency.
344 | """
345 | if len(mask_data["rles"]) == 0:
346 | return mask_data
347 |
348 | # Filter small disconnected regions and holes
349 | new_masks = []
350 | scores = []
351 | for rle in mask_data["rles"]:
352 | mask = rle_to_mask(rle)
353 |
354 | mask, changed = remove_small_regions(mask, min_area, mode="holes")
355 | unchanged = not changed
356 | mask, changed = remove_small_regions(mask, min_area, mode="islands")
357 | unchanged = unchanged and not changed
358 |
359 | new_masks.append(torch.as_tensor(mask).unsqueeze(0))
360 | # Give score=0 to changed masks and score=1 to unchanged masks
361 | # so NMS will prefer ones that didn't need postprocessing
362 | scores.append(float(unchanged))
363 |
364 | # Recalculate boxes and remove any new duplicates
365 | masks = torch.cat(new_masks, dim=0)
366 | boxes = batched_mask_to_box(masks)
367 | keep_by_nms = batched_nms(
368 | boxes.float(),
369 | torch.as_tensor(scores),
370 | torch.zeros(len(boxes)), # categories
371 | iou_threshold=nms_thresh,
372 | )
373 |
374 | # Only recalculate RLEs for masks that have changed
375 | for i_mask in keep_by_nms:
376 | if scores[i_mask] == 0.0:
377 | mask_torch = masks[i_mask].unsqueeze(0)
378 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
379 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
380 | mask_data.filter(keep_by_nms)
381 |
382 | return mask_data
383 |
--------------------------------------------------------------------------------
/metaseg/generator/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from functools import partial
8 |
9 | import torch
10 |
11 | from metaseg.modeling import (
12 | ImageEncoderViT,
13 | MaskDecoder,
14 | PromptEncoder,
15 | Sam,
16 | TwoWayTransformer,
17 | )
18 |
19 |
20 | def build_sam_vit_h(checkpoint=None):
21 | return _build_sam(
22 | encoder_embed_dim=1280,
23 | encoder_depth=32,
24 | encoder_num_heads=16,
25 | encoder_global_attn_indexes=[7, 15, 23, 31],
26 | checkpoint=checkpoint,
27 | )
28 |
29 |
30 | build_sam = build_sam_vit_h
31 |
32 |
33 | def build_sam_vit_l(checkpoint=None):
34 | return _build_sam(
35 | encoder_embed_dim=1024,
36 | encoder_depth=24,
37 | encoder_num_heads=16,
38 | encoder_global_attn_indexes=[5, 11, 17, 23],
39 | checkpoint=checkpoint,
40 | )
41 |
42 |
43 | def build_sam_vit_b(checkpoint=None):
44 | return _build_sam(
45 | encoder_embed_dim=768,
46 | encoder_depth=12,
47 | encoder_num_heads=12,
48 | encoder_global_attn_indexes=[2, 5, 8, 11],
49 | checkpoint=checkpoint,
50 | )
51 |
52 |
60 | sam_model_registry = {
61 | "default": build_sam,
62 | "vit_h": build_sam,
63 | "vit_l": build_sam_vit_l,
64 | "vit_b": build_sam_vit_b,
65 | }
66 |
67 |
68 | def _build_sam(
69 | encoder_embed_dim,
70 | encoder_depth,
71 | encoder_num_heads,
72 | encoder_global_attn_indexes,
73 | checkpoint=None,
74 | ):
75 | prompt_embed_dim = 256
76 | image_size = 1024
77 | vit_patch_size = 16
78 | image_embedding_size = image_size // vit_patch_size
79 | sam = Sam(
80 | image_encoder=ImageEncoderViT(
81 | depth=encoder_depth,
82 | embed_dim=encoder_embed_dim,
83 | img_size=image_size,
84 | mlp_ratio=4,
85 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
86 | num_heads=encoder_num_heads,
87 | patch_size=vit_patch_size,
88 | qkv_bias=True,
89 | use_rel_pos=True,
90 | global_attn_indexes=encoder_global_attn_indexes,
91 | window_size=14,
92 | out_chans=prompt_embed_dim,
93 | ),
94 | prompt_encoder=PromptEncoder(
95 | embed_dim=prompt_embed_dim,
96 | image_embedding_size=(image_embedding_size, image_embedding_size),
97 | input_image_size=(image_size, image_size),
98 | mask_in_chans=16,
99 | ),
100 | mask_decoder=MaskDecoder(
101 | num_multimask_outputs=3,
102 | transformer=TwoWayTransformer(
103 | depth=2,
104 | embedding_dim=prompt_embed_dim,
105 | mlp_dim=2048,
106 | num_heads=8,
107 | ),
108 | transformer_dim=prompt_embed_dim,
109 | iou_head_depth=3,
110 | iou_head_hidden_dim=256,
111 | ),
112 | pixel_mean=[123.675, 116.28, 103.53],
113 | pixel_std=[58.395, 57.12, 57.375],
114 | )
115 | sam.eval()
116 | if checkpoint is not None:
117 | with open(checkpoint, "rb") as f:
118 | state_dict = torch.load(f)
119 | sam.load_state_dict(state_dict)
120 | return sam
121 |
--------------------------------------------------------------------------------
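
A minimal usage sketch for the registry defined above; the checkpoint path is a placeholder and the import path simply mirrors the file layout of this repository:

    import torch

    from metaseg.generator.build_sam import sam_model_registry

    sam = sam_model_registry["vit_b"]()  # architecture only, randomly initialized
    # sam = sam_model_registry["vit_b"](checkpoint="path/to/sam_vit_b.pth")  # load weights
    sam.to("cuda" if torch.cuda.is_available() else "cpu")
    print(f"{sum(p.numel() for p in sam.parameters()) / 1e6:.1f}M parameters")
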
/metaseg/generator/predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional, Tuple
8 |
9 | import numpy as np
10 | import torch
11 |
12 | from metaseg.modeling import Sam
13 | from metaseg.utils.transforms import ResizeLongestSide
14 |
15 |
16 | class SamPredictor:
17 | def __init__(
18 | self,
19 | sam_model: Sam,
20 | ) -> None:
21 | """
22 | Uses SAM to calculate the image embedding for an image, and then
23 |         allows repeated, efficient mask prediction given prompts.
24 |
25 | Arguments:
26 | sam_model (Sam): The model to use for mask prediction.
27 | """
28 | super().__init__()
29 | self.model = sam_model
30 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
31 | self.reset_image()
32 |
33 | def set_image(
34 | self,
35 | image: np.ndarray,
36 | image_format: str = "RGB",
37 | ) -> None:
38 | """
39 | Calculates the image embeddings for the provided image, allowing
40 | masks to be predicted with the 'predict' method.
41 |
42 | Arguments:
43 | image (np.ndarray): The image for calculating masks. Expects an
44 | image in HWC uint8 format, with pixel values in [0, 255].
45 | image_format (str): The color format of the image, in ['RGB', 'BGR'].
46 | """
47 | assert image_format in [
48 | "RGB",
49 | "BGR",
50 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
51 | if image_format != self.model.image_format:
52 | image = image[..., ::-1]
53 |
54 | # Transform the image to the form expected by the model
55 | input_image = self.transform.apply_image(image)
56 | input_image_torch = torch.as_tensor(input_image, device=self.device)
57 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[
58 | None, :, :, :
59 | ]
60 |
61 | self.set_torch_image(input_image_torch, image.shape[:2])
62 |
63 | @torch.no_grad()
64 | def set_torch_image(
65 | self,
66 | transformed_image: torch.Tensor,
67 | original_image_size: Tuple[int, ...],
68 | ) -> None:
69 | """
70 | Calculates the image embeddings for the provided image, allowing
71 | masks to be predicted with the 'predict' method. Expects the input
72 | image to be already transformed to the format expected by the model.
73 |
74 | Arguments:
75 | transformed_image (torch.Tensor): The input image, with shape
76 | 1x3xHxW, which has been transformed with ResizeLongestSide.
77 | original_image_size (tuple(int, int)): The size of the image
78 | before transformation, in (H, W) format.
79 | """
80 | assert (
81 | len(transformed_image.shape) == 4
82 | and transformed_image.shape[1] == 3
83 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
84 | ), (
85 | f"set_torch_image input must be BCHW with long side "
86 | f"{self.model.image_encoder.img_size}."
87 | )
88 | self.reset_image()
89 |
90 | self.original_size = original_image_size
91 | self.input_size = tuple(transformed_image.shape[-2:])
92 | input_image = self.model.preprocess(transformed_image)
93 | self.features = self.model.image_encoder(input_image)
94 | self.is_image_set = True
95 |
96 | def predict(
97 | self,
98 | point_coords: Optional[np.ndarray] = None,
99 | point_labels: Optional[np.ndarray] = None,
100 | box: Optional[np.ndarray] = None,
101 | mask_input: Optional[np.ndarray] = None,
102 | multimask_output: bool = True,
103 | return_logits: bool = False,
104 |     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
105 | """
106 | Predict masks for the given input prompts, using the currently set image.
107 |
108 | Arguments:
109 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the
110 | model. Each point is in (X,Y) in pixels.
111 | point_labels (np.ndarray or None): A length N array of labels for the
112 | point prompts. 1 indicates a foreground point and 0 indicates a
113 | background point.
114 |             box (np.ndarray or None): A length 4 array giving a box prompt to the
115 | model, in XYXY format.
116 | mask_input (np.ndarray): A low resolution mask input to the model, typically
117 | coming from a previous prediction iteration. Has form 1xHxW, where
118 | for SAM, H=W=256.
119 | multimask_output (bool): If true, the model will return three masks.
120 | For ambiguous input prompts (such as a single click), this will often
121 | produce better masks than a single prediction. If only a single
122 | mask is needed, the model's predicted quality score can be used
123 | to select the best mask. For non-ambiguous prompts, such as multiple
124 | input prompts, multimask_output=False can give better results.
125 |             return_logits (bool): If true, returns un-thresholded mask logits
126 | instead of a binary mask.
127 |
128 | Returns:
129 | (np.ndarray): The output masks in CxHxW format, where C is the
130 | number of masks, and (H, W) is the original image size.
131 | (np.ndarray): An array of length C containing the model's
132 | predictions for the quality of each mask.
133 | (np.ndarray): An array of shape CxHxW, where C is the number
134 | of masks and H=W=256. These low resolution logits can be passed to
135 | a subsequent iteration as mask input.
136 | """
137 | if not self.is_image_set:
138 | raise RuntimeError(
139 | "An image must be set with .set_image(...) before mask prediction."
140 | )
141 |
142 | # Transform input prompts
143 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
144 | if point_coords is not None:
145 | assert (
146 | point_labels is not None
147 | ), "point_labels must be supplied if point_coords is supplied."
148 | point_coords = self.transform.apply_coords(point_coords, self.original_size)
149 | coords_torch = torch.as_tensor(
150 | point_coords, dtype=torch.float, device=self.device
151 | )
152 | labels_torch = torch.as_tensor(
153 | point_labels, dtype=torch.int, device=self.device
154 | )
155 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
156 | if box is not None:
157 | box = self.transform.apply_boxes(box, self.original_size)
158 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
159 | box_torch = box_torch[None, :]
160 | if mask_input is not None:
161 | mask_input_torch = torch.as_tensor(
162 | mask_input, dtype=torch.float, device=self.device
163 | )
164 | mask_input_torch = mask_input_torch[None, :, :, :]
165 |
166 | masks, iou_predictions, low_res_masks = self.predict_torch(
167 | coords_torch,
168 | labels_torch,
169 | box_torch,
170 | mask_input_torch,
171 | multimask_output,
172 | return_logits=return_logits,
173 | )
174 |
175 | masks = masks[0].detach().cpu().numpy()
176 | iou_predictions = iou_predictions[0].detach().cpu().numpy()
177 | low_res_masks = low_res_masks[0].detach().cpu().numpy()
178 | return masks, iou_predictions, low_res_masks
179 |
180 | @torch.no_grad()
181 | def predict_torch(
182 | self,
183 | point_coords: Optional[torch.Tensor],
184 | point_labels: Optional[torch.Tensor],
185 | boxes: Optional[torch.Tensor] = None,
186 | mask_input: Optional[torch.Tensor] = None,
187 | multimask_output: bool = True,
188 | return_logits: bool = False,
189 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
190 | """
191 | Predict masks for the given input prompts, using the currently set image.
192 | Input prompts are batched torch tensors and are expected to already be
193 | transformed to the input frame using ResizeLongestSide.
194 |
195 | Arguments:
196 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
197 | model. Each point is in (X,Y) in pixels.
198 | point_labels (torch.Tensor or None): A BxN array of labels for the
199 | point prompts. 1 indicates a foreground point and 0 indicates a
200 | background point.
201 |             boxes (torch.Tensor or None): A Bx4 array giving box prompts to the
202 |                 model, in XYXY format.
203 |             mask_input (torch.Tensor): A low resolution mask input to the model, typically
204 | coming from a previous prediction iteration. Has form Bx1xHxW, where
205 | for SAM, H=W=256. Masks returned by a previous iteration of the
206 | predict method do not need further transformation.
207 | multimask_output (bool): If true, the model will return three masks.
208 | For ambiguous input prompts (such as a single click), this will often
209 | produce better masks than a single prediction. If only a single
210 | mask is needed, the model's predicted quality score can be used
211 | to select the best mask. For non-ambiguous prompts, such as multiple
212 | input prompts, multimask_output=False can give better results.
213 |             return_logits (bool): If true, returns un-thresholded mask logits
214 | instead of a binary mask.
215 |
216 | Returns:
217 | (torch.Tensor): The output masks in BxCxHxW format, where C is the
218 | number of masks, and (H, W) is the original image size.
219 | (torch.Tensor): An array of shape BxC containing the model's
220 | predictions for the quality of each mask.
221 | (torch.Tensor): An array of shape BxCxHxW, where C is the number
222 | of masks and H=W=256. These low res logits can be passed to
223 | a subsequent iteration as mask input.
224 | """
225 | if not self.is_image_set:
226 | raise RuntimeError(
227 | "An image must be set with .set_image(...) before mask prediction."
228 | )
229 |
230 | if point_coords is not None:
231 | points = (point_coords, point_labels)
232 | else:
233 | points = None
234 |
235 | # Embed prompts
236 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
237 | points=points,
238 | boxes=boxes,
239 | masks=mask_input,
240 | )
241 |
242 | # Predict masks
243 | low_res_masks, iou_predictions = self.model.mask_decoder(
244 | image_embeddings=self.features,
245 | image_pe=self.model.prompt_encoder.get_dense_pe(),
246 | sparse_prompt_embeddings=sparse_embeddings,
247 | dense_prompt_embeddings=dense_embeddings,
248 | multimask_output=multimask_output,
249 | )
250 |
251 | # Upscale the masks to the original image resolution
252 | masks = self.model.postprocess_masks(
253 | low_res_masks, self.input_size, self.original_size
254 | )
255 |
256 | if not return_logits:
257 | masks = masks > self.model.mask_threshold
258 |
259 | return masks, iou_predictions, low_res_masks
260 |
261 | def get_image_embedding(self) -> torch.Tensor:
262 | """
263 | Returns the image embeddings for the currently set image, with
264 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are
265 | the embedding spatial dimension of SAM (typically C=256, H=W=64).
266 | """
267 | if not self.is_image_set:
268 | raise RuntimeError(
269 | "An image must be set with .set_image(...) to generate an embedding."
270 | )
271 | assert (
272 | self.features is not None
273 | ), "Features must exist if an image has been set."
274 | return self.features
275 |
276 | @property
277 | def device(self) -> torch.device:
278 | return self.model.device
279 |
280 | def reset_image(self) -> None:
281 | """Resets the currently set image."""
282 | self.is_image_set = False
283 | self.features = None
284 | self.orig_h = None
285 | self.orig_w = None
286 | self.input_h = None
287 | self.input_w = None
288 |
--------------------------------------------------------------------------------
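
A small sketch of the two-step flow documented above: set_image runs the heavy image encoder once, after which predict can be called repeatedly with cheap point/box prompts. Untrained weights and a dummy image are used here, and the import paths assume the file layout of this repository:

    import numpy as np

    from metaseg.generator.build_sam import sam_model_registry
    from metaseg.generator.predictor import SamPredictor

    sam = sam_model_registry["vit_b"]()              # or pass checkpoint="path/to/weights.pth"
    predictor = SamPredictor(sam)

    image = np.zeros((480, 640, 3), dtype=np.uint8)  # HWC uint8 image in RGB
    predictor.set_image(image)                       # image embedding computed once

    masks, iou_preds, low_res = predictor.predict(
        point_coords=np.array([[320, 240]]),         # one foreground click, (X, Y) in pixels
        point_labels=np.array([1]),
        multimask_output=True,
    )
    print(masks.shape, iou_preds.shape, low_res.shape)  # (3, 480, 640) (3,) (3, 256, 256)
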
/metaseg/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .common import LayerNorm2d as LayerNorm2d
8 | from .common import MLPBlock as MLPBlock
9 | from .image_encoder import ImageEncoderViT as ImageEncoderViT
10 | from .mask_decoder import MaskDecoder as MaskDecoder
11 | from .prompt_encoder import PositionEmbeddingRandom as PositionEmbeddingRandom
12 | from .prompt_encoder import PromptEncoder as PromptEncoder
13 | from .sam import Sam as Sam
14 | from .transformer import Attention as Attention
15 | from .transformer import TwoWayAttentionBlock as TwoWayAttentionBlock
16 | from .transformer import TwoWayTransformer as TwoWayTransformer
17 |
--------------------------------------------------------------------------------
/metaseg/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Type
8 |
9 | import torch
10 | import torch.nn as nn
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
--------------------------------------------------------------------------------
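
LayerNorm2d above normalizes each spatial location of an NCHW tensor across its channel dimension, which is why it can sit directly after Conv2d layers without any permutes. A quick check of that behaviour:

    import torch

    from metaseg.modeling.common import LayerNorm2d

    x = torch.randn(2, 8, 4, 4)                 # N, C, H, W
    y = LayerNorm2d(num_channels=8)(x)          # weight=1, bias=0 at init

    print(y.mean(dim=1).abs().max())            # ~0: zero mean over channels at each location
    print(y.var(dim=1, unbiased=False).mean())  # ~1: unit variance over channels
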
/metaseg/modeling/image_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional, Tuple, Type
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from metaseg.modeling.common import LayerNorm2d, MLPBlock
14 |
15 |
16 | # This class and its supporting functions below lightly adapted from the
17 | # ViTDet backbone available at:
18 | # https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py
19 | class ImageEncoderViT(nn.Module):
20 | def __init__(
21 | self,
22 | img_size: int = 1024,
23 | patch_size: int = 16,
24 | in_chans: int = 3,
25 | embed_dim: int = 768,
26 | depth: int = 12,
27 | num_heads: int = 12,
28 | mlp_ratio: float = 4.0,
29 | out_chans: int = 256,
30 | qkv_bias: bool = True,
31 | norm_layer: Type[nn.Module] = nn.LayerNorm,
32 | act_layer: Type[nn.Module] = nn.GELU,
33 | use_abs_pos: bool = True,
34 | use_rel_pos: bool = False,
35 | rel_pos_zero_init: bool = True,
36 | window_size: int = 0,
37 | global_attn_indexes: Tuple[int, ...] = (),
38 | ) -> None:
39 | """
40 | Args:
41 | img_size (int): Input image size.
42 | patch_size (int): Patch size.
43 | in_chans (int): Number of input image channels.
44 | embed_dim (int): Patch embedding dimension.
45 | depth (int): Depth of ViT.
46 | num_heads (int): Number of attention heads in each ViT block.
47 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
48 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
49 | norm_layer (nn.Module): Normalization layer.
50 | act_layer (nn.Module): Activation layer.
51 | use_abs_pos (bool): If True, use absolute
52 | positional embeddings.
53 | use_rel_pos (bool): If True, add relative
54 | positional embeddings to the attention map.
55 | rel_pos_zero_init (bool): If True, zero initialize
56 | relative positional parameters.
57 | window_size (int): Window size for window
58 | attention blocks.
59 | global_attn_indexes (list): Indexes for
60 | blocks using global attention.
61 | """
62 | super().__init__()
63 | self.img_size = img_size
64 |
65 | self.patch_embed = PatchEmbed(
66 | kernel_size=(patch_size, patch_size),
67 | stride=(patch_size, patch_size),
68 | in_chans=in_chans,
69 | embed_dim=embed_dim,
70 | )
71 |
72 | self.pos_embed: Optional[nn.Parameter] = None
73 | if use_abs_pos:
74 | # Initialize absolute positional embedding with pretrain image size.
75 | self.pos_embed = nn.Parameter(
76 | torch.zeros(
77 | 1, img_size // patch_size, img_size // patch_size, embed_dim
78 | )
79 | )
80 |
81 | self.blocks = nn.ModuleList()
82 | for i in range(depth):
83 | block = Block(
84 | dim=embed_dim,
85 | num_heads=num_heads,
86 | mlp_ratio=mlp_ratio,
87 | qkv_bias=qkv_bias,
88 | norm_layer=norm_layer,
89 | act_layer=act_layer,
90 | use_rel_pos=use_rel_pos,
91 | rel_pos_zero_init=rel_pos_zero_init,
92 | window_size=window_size if i not in global_attn_indexes else 0,
93 | input_size=(img_size // patch_size, img_size // patch_size),
94 | )
95 | self.blocks.append(block)
96 |
97 | self.neck = nn.Sequential(
98 | nn.Conv2d(
99 | embed_dim,
100 | out_chans,
101 | kernel_size=1,
102 | bias=False,
103 | ),
104 | LayerNorm2d(out_chans),
105 | nn.Conv2d(
106 | out_chans,
107 | out_chans,
108 | kernel_size=3,
109 | padding=1,
110 | bias=False,
111 | ),
112 | LayerNorm2d(out_chans),
113 | )
114 |
115 | def forward(self, x: torch.Tensor) -> torch.Tensor:
116 | x = self.patch_embed(x)
117 | if self.pos_embed is not None:
118 | x = x + self.pos_embed
119 |
120 | for blk in self.blocks:
121 | x = blk(x)
122 |
123 | x = self.neck(x.permute(0, 3, 1, 2))
124 |
125 | return x
126 |
127 |
128 | class Block(nn.Module):
129 |     """Transformer block with support for window
130 |     attention and residual propagation blocks."""
131 |
132 | def __init__(
133 | self,
134 | dim: int,
135 | num_heads: int,
136 | mlp_ratio: float = 4.0,
137 | qkv_bias: bool = True,
138 | norm_layer: Type[nn.Module] = nn.LayerNorm,
139 | act_layer: Type[nn.Module] = nn.GELU,
140 | use_rel_pos: bool = False,
141 | rel_pos_zero_init: bool = True,
142 | window_size: int = 0,
143 | input_size: Optional[Tuple[int, int]] = None,
144 | ) -> None:
145 | """
146 | Args:
147 | dim (int): Number of input channels.
148 | num_heads (int): Number of attention heads in each ViT block.
149 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
150 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
151 | norm_layer (nn.Module): Normalization layer.
152 | act_layer (nn.Module): Activation layer.
153 | use_rel_pos (bool): If True, add relative
154 | positional embeddings to the attention map.
155 | rel_pos_zero_init (bool): If True, zero
156 | initialize relative positional parameters.
157 | window_size (int): Window size for
158 | window attention blocks. If it equals 0, then
159 | use global attention.
160 |             input_size (tuple(int, int) or None): Input resolution for
161 | calculating the relative positional parameter size.
162 | """
163 | super().__init__()
164 | self.norm1 = norm_layer(dim)
165 | self.attn = Attention(
166 | dim,
167 | num_heads=num_heads,
168 | qkv_bias=qkv_bias,
169 | use_rel_pos=use_rel_pos,
170 | rel_pos_zero_init=rel_pos_zero_init,
171 | input_size=input_size if window_size == 0 else (window_size, window_size),
172 | )
173 |
174 | self.norm2 = norm_layer(dim)
175 | self.mlp = MLPBlock(
176 | embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
177 | )
178 |
179 | self.window_size = window_size
180 |
181 | def forward(self, x: torch.Tensor) -> torch.Tensor:
182 | shortcut = x
183 | x = self.norm1(x)
184 | # Window partition
185 | if self.window_size > 0:
186 | H, W = x.shape[1], x.shape[2]
187 | x, pad_hw = window_partition(x, self.window_size)
188 |
189 | x = self.attn(x)
190 | # Reverse window partition
191 | if self.window_size > 0:
192 | x = window_unpartition(x, self.window_size, pad_hw, (H, W))
193 |
194 | x = shortcut + x
195 | x = x + self.mlp(self.norm2(x))
196 |
197 | return x
198 |
199 |
200 | class Attention(nn.Module):
201 | """Multi-head Attention block with relative position embeddings."""
202 |
203 | def __init__(
204 | self,
205 | dim: int,
206 | num_heads: int = 8,
207 | qkv_bias: bool = True,
208 | use_rel_pos: bool = False,
209 | rel_pos_zero_init: bool = True,
210 | input_size: Optional[Tuple[int, int]] = None,
211 | ) -> None:
212 | """
213 | Args:
214 |             dim (int): Number of input channels.
215 |             num_heads (int): Number of attention heads.
216 |             qkv_bias (bool): If True, add a learnable
217 |                 bias to query, key, value.
218 |             use_rel_pos (bool): If True, add relative
219 |                 positional embeddings to the attention map.
220 |             rel_pos_zero_init (bool): If True, zero initialize
221 |                 relative positional parameters.
222 |             input_size (tuple(int, int) or None): Input resolution for
223 |                 calculating the relative positional
224 |                 parameter size.
225 | """
226 | super().__init__()
227 | self.num_heads = num_heads
228 | head_dim = dim // num_heads
229 | self.scale = head_dim**-0.5
230 |
231 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
232 | self.proj = nn.Linear(dim, dim)
233 |
234 | self.use_rel_pos = use_rel_pos
235 | if self.use_rel_pos:
236 | assert (
237 | input_size is not None
238 | ), "Input size must be provided if using relative positional encoding."
239 | # initialize relative positional embeddings
240 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
241 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
242 |
243 | def forward(self, x: torch.Tensor) -> torch.Tensor:
244 | B, H, W, _ = x.shape
245 | # qkv with shape (3, B, nHead, H * W, C)
246 | qkv = (
247 | self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
248 | )
249 | # q, k, v with shape (B * nHead, H * W, C)
250 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
251 |
252 | attn = (q * self.scale) @ k.transpose(-2, -1)
253 |
254 | if self.use_rel_pos:
255 | attn = add_decomposed_rel_pos(
256 | attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
257 | )
258 |
259 | attn = attn.softmax(dim=-1)
260 | x = (
261 | (attn @ v)
262 | .view(B, self.num_heads, H, W, -1)
263 | .permute(0, 2, 3, 1, 4)
264 | .reshape(B, H, W, -1)
265 | )
266 | x = self.proj(x)
267 |
268 | return x
269 |
270 |
271 | def window_partition(
272 | x: torch.Tensor, window_size: int
273 | ) -> Tuple[torch.Tensor, Tuple[int, int]]:
274 | """
275 | Partition into non-overlapping windows with padding if needed.
276 | Args:
277 | x (tensor): input tokens with [B, H, W, C].
278 | window_size (int): window size.
279 |
280 | Returns:
281 | windows: windows after partition with
282 | [B * num_windows, window_size, window_size, C].
283 | (Hp, Wp): padded height and width before partition
284 | """
285 | B, H, W, C = x.shape
286 |
287 | pad_h = (window_size - H % window_size) % window_size
288 | pad_w = (window_size - W % window_size) % window_size
289 | if pad_h > 0 or pad_w > 0:
290 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
291 | Hp, Wp = H + pad_h, W + pad_w
292 |
293 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
294 | windows = (
295 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
296 | )
297 | return windows, (Hp, Wp)
298 |
299 |
300 | def window_unpartition(
301 | windows: torch.Tensor,
302 | window_size: int,
303 | pad_hw: Tuple[int, int],
304 | hw: Tuple[int, int],
305 | ) -> torch.Tensor:
306 | """
307 |     Window unpartition into original sequences, removing padding.
308 |     Args:
309 |         windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
310 | window_size (int): window size.
311 | pad_hw (Tuple): padded height and width (Hp, Wp).
312 | hw (Tuple): original height and width (H, W) before padding.
313 |
314 | Returns:
315 | x: unpartitioned sequences with [B, H, W, C].
316 | """
317 | Hp, Wp = pad_hw
318 | H, W = hw
319 | B = windows.shape[0] // (Hp * Wp // window_size // window_size)
320 | x = windows.view(
321 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1
322 | )
323 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
324 |
325 | if Hp > H or Wp > W:
326 | x = x[:, :H, :W, :].contiguous()
327 | return x
328 |
329 |
330 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
331 | """
332 | Get relative positional embeddings according to the relative positions of
333 | query and key sizes.
334 | Args:
335 | q_size (int): size of query q.
336 | k_size (int): size of key k.
337 | rel_pos (Tensor): relative position embeddings (L, C).
338 |
339 | Returns:
340 | Extracted positional embeddings according to relative positions.
341 | """
342 | max_rel_dist = int(2 * max(q_size, k_size) - 1)
343 | # Interpolate rel pos if needed.
344 | if rel_pos.shape[0] != max_rel_dist:
345 | # Interpolate rel pos.
346 | rel_pos_resized = F.interpolate(
347 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
348 | size=max_rel_dist,
349 | mode="linear",
350 | )
351 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
352 | else:
353 | rel_pos_resized = rel_pos
354 |
355 | # Scale the coords with short length if shapes for q and k are different.
356 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
357 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
358 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
359 |
360 | return rel_pos_resized[relative_coords.long()]
361 |
362 |
363 | def add_decomposed_rel_pos(
364 | attn: torch.Tensor,
365 | q: torch.Tensor,
366 | rel_pos_h: torch.Tensor,
367 | rel_pos_w: torch.Tensor,
368 | q_size: Tuple[int, int],
369 | k_size: Tuple[int, int],
370 | ) -> torch.Tensor:
371 | """
372 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
373 | commit_sha: 19786631e330df9f3622e5402b4a419a263a2c80
374 | https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py
375 | Args:
376 | attn (Tensor): attention map.
377 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
378 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
379 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
380 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
381 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
382 |
383 | Returns:
384 | attn (Tensor): attention map with added relative positional embeddings.
385 | """
386 | q_h, q_w = q_size
387 | k_h, k_w = k_size
388 | Rh = get_rel_pos(q_h, k_h, rel_pos_h)
389 | Rw = get_rel_pos(q_w, k_w, rel_pos_w)
390 |
391 | B, _, dim = q.shape
392 | r_q = q.reshape(B, q_h, q_w, dim)
393 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
394 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
395 |
396 | attn = (
397 | attn.view(B, q_h, q_w, k_h, k_w)
398 | + rel_h[:, :, :, :, None]
399 | + rel_w[:, :, :, None, :]
400 | ).view(B, q_h * q_w, k_h * k_w)
401 |
402 | return attn
403 |
404 |
405 | class PatchEmbed(nn.Module):
406 | """
407 | Image to Patch Embedding.
408 | """
409 |
410 | def __init__(
411 | self,
412 | kernel_size: Tuple[int, int] = (16, 16),
413 | stride: Tuple[int, int] = (16, 16),
414 | padding: Tuple[int, int] = (0, 0),
415 | in_chans: int = 3,
416 | embed_dim: int = 768,
417 | ) -> None:
418 | """
419 | Args:
420 | kernel_size (Tuple): kernel size of the projection layer.
421 | stride (Tuple): stride of the projection layer.
422 | padding (Tuple): padding size of the projection layer.
423 | in_chans (int): Number of input image channels.
424 |             embed_dim (int): Patch embedding dimension.
425 | """
426 | super().__init__()
427 |
428 | self.proj = nn.Conv2d(
429 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
430 | )
431 |
432 | def forward(self, x: torch.Tensor) -> torch.Tensor:
433 | x = self.proj(x)
434 | # B C H W -> B H W C
435 | x = x.permute(0, 2, 3, 1)
436 | return x
437 |
--------------------------------------------------------------------------------
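
window_partition pads the token grid up to a multiple of the window size and window_unpartition crops that padding back off, so the two are exact inverses. A quick round-trip check:

    import torch

    from metaseg.modeling.image_encoder import window_partition, window_unpartition

    x = torch.randn(1, 30, 30, 64)                   # B, H, W, C with H, W not divisible by 14
    windows, pad_hw = window_partition(x, window_size=14)
    print(windows.shape, pad_hw)                     # torch.Size([9, 14, 14, 64]) (42, 42)

    x_back = window_unpartition(windows, 14, pad_hw, (30, 30))
    print(torch.allclose(x, x_back))                 # True
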
/metaseg/modeling/mask_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List, Tuple, Type
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 |
13 | from metaseg.modeling.common import LayerNorm2d
14 |
15 |
16 | class MaskDecoder(nn.Module):
17 | def __init__(
18 | self,
19 | *,
20 | transformer_dim: int,
21 | transformer: nn.Module,
22 | num_multimask_outputs: int = 3,
23 | activation: Type[nn.Module] = nn.GELU,
24 | iou_head_depth: int = 3,
25 | iou_head_hidden_dim: int = 256,
26 | ) -> None:
27 | """
28 | Predicts masks given an image and prompt embeddings, using a
29 |         transformer architecture.
30 |
31 | Arguments:
32 | transformer_dim (int): the channel dimension of the transformer
33 | transformer (nn.Module): the transformer used to predict masks
34 | num_multimask_outputs (int): the number of masks to predict
35 | when disambiguating masks
36 | activation (nn.Module): the type of activation to use when
37 | upscaling masks
38 | iou_head_depth (int): the depth of the MLP used to predict
39 | mask quality
40 | iou_head_hidden_dim (int): the hidden dimension of the MLP
41 | used to predict mask quality
42 | """
43 | super().__init__()
44 | self.transformer_dim = transformer_dim
45 | self.transformer = transformer
46 |
47 | self.num_multimask_outputs = num_multimask_outputs
48 |
49 | self.iou_token = nn.Embedding(1, transformer_dim)
50 | self.num_mask_tokens = num_multimask_outputs + 1
51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52 |
53 | self.output_upscaling = nn.Sequential(
54 | nn.ConvTranspose2d(
55 | transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
56 | ),
57 | LayerNorm2d(transformer_dim // 4),
58 | activation(),
59 | nn.ConvTranspose2d(
60 | transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
61 | ),
62 | activation(),
63 | )
64 | self.output_hypernetworks_mlps = nn.ModuleList(
65 | [
66 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
67 | for i in range(self.num_mask_tokens)
68 | ]
69 | )
70 |
71 | self.iou_prediction_head = MLP(
72 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
73 | )
74 |
75 | def forward(
76 | self,
77 | image_embeddings: torch.Tensor,
78 | image_pe: torch.Tensor,
79 | sparse_prompt_embeddings: torch.Tensor,
80 | dense_prompt_embeddings: torch.Tensor,
81 | multimask_output: bool,
82 | ) -> Tuple[torch.Tensor, torch.Tensor]:
83 | """
84 | Predict masks given image and prompt embeddings.
85 |
86 | Arguments:
87 | image_embeddings (torch.Tensor): the embeddings from the image encoder
88 | image_pe (torch.Tensor): positional encoding
89 | with the shape of image_embeddings
90 | sparse_prompt_embeddings (torch.Tensor): the embeddings
91 | of the points and boxes
92 | dense_prompt_embeddings (torch.Tensor): the embeddings
93 | of the mask inputs
94 | multimask_output (bool): Whether to return
95 | multiple masks or a single mask.
96 |
97 | Returns:
98 | torch.Tensor: batched predicted masks
99 | torch.Tensor: batched predictions of mask quality
100 | """
101 | masks, iou_pred = self.predict_masks(
102 | image_embeddings=image_embeddings,
103 | image_pe=image_pe,
104 | sparse_prompt_embeddings=sparse_prompt_embeddings,
105 | dense_prompt_embeddings=dense_prompt_embeddings,
106 | )
107 |
108 |         # Select the correct mask or masks for output
109 | if multimask_output:
110 | mask_slice = slice(1, None)
111 | else:
112 | mask_slice = slice(0, 1)
113 | masks = masks[:, mask_slice, :, :]
114 | iou_pred = iou_pred[:, mask_slice]
115 |
116 | # Prepare output
117 | return masks, iou_pred
118 |
119 | def predict_masks(
120 | self,
121 | image_embeddings: torch.Tensor,
122 | image_pe: torch.Tensor,
123 | sparse_prompt_embeddings: torch.Tensor,
124 | dense_prompt_embeddings: torch.Tensor,
125 | ) -> Tuple[torch.Tensor, torch.Tensor]:
126 | """Predicts masks. See 'forward' for more details."""
127 | # Concatenate output tokens
128 | output_tokens = torch.cat(
129 | [self.iou_token.weight, self.mask_tokens.weight], dim=0
130 | )
131 | output_tokens = output_tokens.unsqueeze(0).expand(
132 | sparse_prompt_embeddings.size(0), -1, -1
133 | )
134 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
135 |
136 | # Expand per-image data in batch direction to be per-mask
137 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
138 | src = src + dense_prompt_embeddings
139 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
140 | b, c, h, w = src.shape
141 |
142 | # Run the transformer
143 | hs, src = self.transformer(src, pos_src, tokens)
144 | iou_token_out = hs[:, 0, :]
145 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
146 |
147 | # Upscale mask embeddings and predict masks using the mask tokens
148 | src = src.transpose(1, 2).view(b, c, h, w)
149 | upscaled_embedding = self.output_upscaling(src)
150 | hyper_in_list: List[torch.Tensor] = []
151 | for i in range(self.num_mask_tokens):
152 | hyper_in_list.append(
153 | self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
154 | )
155 | hyper_in = torch.stack(hyper_in_list, dim=1)
156 | b, c, h, w = upscaled_embedding.shape
157 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
158 |
159 | # Generate mask quality predictions
160 | iou_pred = self.iou_prediction_head(iou_token_out)
161 |
162 | return masks, iou_pred
163 |
164 |
165 | # Lightly adapted from
166 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
167 | class MLP(nn.Module):
168 | def __init__(
169 | self,
170 | input_dim: int,
171 | hidden_dim: int,
172 | output_dim: int,
173 | num_layers: int,
174 | sigmoid_output: bool = False,
175 | ) -> None:
176 | super().__init__()
177 | self.num_layers = num_layers
178 | h = [hidden_dim] * (num_layers - 1)
179 | self.layers = nn.ModuleList(
180 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
181 | )
182 | self.sigmoid_output = sigmoid_output
183 |
184 | def forward(self, x):
185 | for i, layer in enumerate(self.layers):
186 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
187 | if self.sigmoid_output:
188 | x = F.sigmoid(x)
189 | return x
190 |
--------------------------------------------------------------------------------
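
The decoder can be exercised on its own with dummy embeddings, which makes the token layout above (one IoU token plus four mask tokens prepended to the sparse prompts) and the multimask slicing easy to see. A shape-check sketch using the SAM defaults from build_sam.py:

    import torch

    from metaseg.modeling import MaskDecoder, TwoWayTransformer

    dim = 256
    decoder = MaskDecoder(
        transformer_dim=dim,
        transformer=TwoWayTransformer(depth=2, embedding_dim=dim, mlp_dim=2048, num_heads=8),
    )

    masks, iou_pred = decoder(
        image_embeddings=torch.randn(1, dim, 64, 64),     # from the image encoder
        image_pe=torch.randn(1, dim, 64, 64),             # dense positional encoding
        sparse_prompt_embeddings=torch.randn(1, 2, dim),  # e.g. one point plus padding token
        dense_prompt_embeddings=torch.randn(1, dim, 64, 64),
        multimask_output=True,
    )
    print(masks.shape, iou_pred.shape)  # torch.Size([1, 3, 256, 256]) torch.Size([1, 3])
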
/metaseg/modeling/prompt_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Any, Optional, Tuple, Type
8 |
9 | import numpy as np
10 | import torch
11 | from torch import nn
12 |
13 | from metaseg.modeling.common import LayerNorm2d
14 |
15 |
16 | class PromptEncoder(nn.Module):
17 | def __init__(
18 | self,
19 | embed_dim: int,
20 | image_embedding_size: Tuple[int, int],
21 | input_image_size: Tuple[int, int],
22 | mask_in_chans: int,
23 | activation: Type[nn.Module] = nn.GELU,
24 | ) -> None:
25 | """
26 | Encodes prompts for input to SAM's mask decoder.
27 |
28 | Arguments:
29 | embed_dim (int): The prompts' embedding dimension
30 | image_embedding_size (tuple(int, int)): The spatial size of the
31 | image embedding, as (H, W).
32 |             input_image_size (tuple(int, int)): The padded size of the image as input
33 | to the image encoder, as (H, W).
34 | mask_in_chans (int): The number of hidden channels used for
35 | encoding input masks.
36 | activation (nn.Module): The activation to use when encoding
37 | input masks.
38 | """
39 | super().__init__()
40 | self.embed_dim = embed_dim
41 | self.input_image_size = input_image_size
42 | self.image_embedding_size = image_embedding_size
43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44 |
45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46 | point_embeddings = [
47 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
48 | ]
49 | self.point_embeddings = nn.ModuleList(point_embeddings)
50 | self.not_a_point_embed = nn.Embedding(1, embed_dim)
51 |
52 | self.mask_input_size = (
53 | 4 * image_embedding_size[0],
54 | 4 * image_embedding_size[1],
55 | )
56 | self.mask_downscaling = nn.Sequential(
57 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
58 | LayerNorm2d(mask_in_chans // 4),
59 | activation(),
60 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
61 | LayerNorm2d(mask_in_chans),
62 | activation(),
63 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
64 | )
65 | self.no_mask_embed = nn.Embedding(1, embed_dim)
66 |
67 | def get_dense_pe(self) -> torch.Tensor:
68 | """
69 | Returns the positional encoding used to encode point prompts,
70 |         applied to a dense set of points matching the shape of the image encoding.
71 |
72 | Returns:
73 | torch.Tensor: Positional encoding with shape
74 | 1x(embed_dim)x(embedding_h)x(embedding_w)
75 | """
76 | return self.pe_layer(self.image_embedding_size).unsqueeze(0)
77 |
78 | def _embed_points(
79 | self,
80 | points: torch.Tensor,
81 | labels: torch.Tensor,
82 | pad: bool,
83 | ) -> torch.Tensor:
84 | """Embeds point prompts."""
85 | points = points + 0.5 # Shift to center of pixel
86 | if pad:
87 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
88 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
89 | points = torch.cat([points, padding_point], dim=1)
90 | labels = torch.cat([labels, padding_label], dim=1)
91 | point_embedding = self.pe_layer.forward_with_coords(
92 | points, self.input_image_size
93 | )
94 | point_embedding[labels == -1] = 0.0
95 | point_embedding[labels == -1] += self.not_a_point_embed.weight
96 | point_embedding[labels == 0] += self.point_embeddings[0].weight
97 | point_embedding[labels == 1] += self.point_embeddings[1].weight
98 | return point_embedding
99 |
100 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
101 | """Embeds box prompts."""
102 | boxes = boxes + 0.5 # Shift to center of pixel
103 | coords = boxes.reshape(-1, 2, 2)
104 | corner_embedding = self.pe_layer.forward_with_coords(
105 | coords, self.input_image_size
106 | )
107 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight
108 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight
109 | return corner_embedding
110 |
111 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
112 | """Embeds mask inputs."""
113 | mask_embedding = self.mask_downscaling(masks)
114 | return mask_embedding
115 |
116 | def _get_batch_size(
117 | self,
118 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
119 | boxes: Optional[torch.Tensor],
120 | masks: Optional[torch.Tensor],
121 | ) -> int:
122 | """
123 | Gets the batch size of the output given the batch size of the input prompts.
124 | """
125 | if points is not None:
126 | return points[0].shape[0]
127 | elif boxes is not None:
128 | return boxes.shape[0]
129 | elif masks is not None:
130 | return masks.shape[0]
131 | else:
132 | return 1
133 |
134 | def _get_device(self) -> torch.device:
135 | return self.point_embeddings[0].weight.device
136 |
137 | def forward(
138 | self,
139 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
140 | boxes: Optional[torch.Tensor],
141 | masks: Optional[torch.Tensor],
142 | ) -> Tuple[torch.Tensor, torch.Tensor]:
143 | """
144 | Embeds different types of prompts, returning both sparse and dense
145 | embeddings.
146 |
147 | Arguments:
148 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
149 | and labels to embed.
150 | boxes (torch.Tensor or none): boxes to embed
151 | masks (torch.Tensor or none): masks to embed
152 |
153 | Returns:
154 | torch.Tensor: sparse embeddings for the points and boxes, with shape
155 | BxNx(embed_dim), where N is determined by the number of input points
156 | and boxes.
157 | torch.Tensor: dense embeddings for the masks, in the shape
158 | Bx(embed_dim)x(embed_H)x(embed_W)
159 | """
160 | bs = self._get_batch_size(points, boxes, masks)
161 | sparse_embeddings = torch.empty(
162 | (bs, 0, self.embed_dim), device=self._get_device()
163 | )
164 | if points is not None:
165 | coords, labels = points
166 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
167 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
168 | if boxes is not None:
169 | box_embeddings = self._embed_boxes(boxes)
170 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
171 |
172 | if masks is not None:
173 | dense_embeddings = self._embed_masks(masks)
174 | else:
175 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
176 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
177 | )
178 |
179 | return sparse_embeddings, dense_embeddings
180 |
181 |
182 | class PositionEmbeddingRandom(nn.Module):
183 | """
184 | Positional encoding using random spatial frequencies.
185 | """
186 |
187 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
188 | super().__init__()
189 | if scale is None or scale <= 0.0:
190 | scale = 1.0
191 | self.register_buffer(
192 | "positional_encoding_gaussian_matrix",
193 | scale * torch.randn((2, num_pos_feats)),
194 | )
195 |
196 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
197 | """Positionally encode points that are normalized to [0,1]."""
198 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
199 | coords = 2 * coords - 1
200 | coords = coords @ self.positional_encoding_gaussian_matrix
201 | coords = 2 * np.pi * coords
202 | # outputs d_1 x ... x d_n x C shape
203 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
204 |
205 | def forward(self, size: Tuple[int, int]) -> torch.Tensor:
206 | """Generate positional encoding for a grid of the specified size."""
207 | h, w = size
208 | device: Any = self.positional_encoding_gaussian_matrix.device
209 | grid = torch.ones((h, w), device=device, dtype=torch.float32)
210 | y_embed = grid.cumsum(dim=0) - 0.5
211 | x_embed = grid.cumsum(dim=1) - 0.5
212 | y_embed = y_embed / h
213 | x_embed = x_embed / w
214 |
215 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
216 | return pe.permute(2, 0, 1) # C x H x W
217 |
218 | def forward_with_coords(
219 | self, coords_input: torch.Tensor, image_size: Tuple[int, int]
220 | ) -> torch.Tensor:
221 | """Positionally encode points that are not normalized to [0,1]."""
222 | coords = coords_input.clone()
223 | coords[:, :, 0] = coords[:, :, 0] / image_size[1]
224 | coords[:, :, 1] = coords[:, :, 1] / image_size[0]
225 | return self._pe_encoding(coords.to(torch.float)) # B x N x C
226 |
--------------------------------------------------------------------------------
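
A small sketch of the shapes described in the forward docstring above, using the same configuration that build_sam.py passes in: a single foreground point yields two sparse embeddings (the point plus the padding token added when no box is given), and the dense embedding falls back to no_mask_embed:

    import torch

    from metaseg.modeling import PromptEncoder

    pe = PromptEncoder(
        embed_dim=256,
        image_embedding_size=(64, 64),
        input_image_size=(1024, 1024),
        mask_in_chans=16,
    )
    coords = torch.tensor([[[512.0, 512.0]]])  # B=1, N=1 point, (X, Y) in the input frame
    labels = torch.tensor([[1]])               # 1 = foreground

    sparse, dense = pe(points=(coords, labels), boxes=None, masks=None)
    print(sparse.shape, dense.shape)  # torch.Size([1, 2, 256]) torch.Size([1, 256, 64, 64])
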
/metaseg/modeling/sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Any, Dict, List, Tuple
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 |
13 | from metaseg.modeling.image_encoder import ImageEncoderViT
14 | from metaseg.modeling.mask_decoder import MaskDecoder
15 | from metaseg.modeling.prompt_encoder import PromptEncoder
16 |
17 |
18 | class Sam(nn.Module):
19 | mask_threshold: float = 0.0
20 | image_format: str = "RGB"
21 |
22 | def __init__(
23 | self,
24 | image_encoder: ImageEncoderViT,
25 | prompt_encoder: PromptEncoder,
26 | mask_decoder: MaskDecoder,
27 | pixel_mean: List[float] = [123.675, 116.28, 103.53],
28 | pixel_std: List[float] = [58.395, 57.12, 57.375],
29 | ) -> None:
30 | """
31 | SAM predicts object masks from an image and input prompts.
32 |
33 | Arguments:
34 | image_encoder (ImageEncoderViT): The backbone used to encode the
35 | image into image embeddings that allow for efficient mask prediction.
36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts.
37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings
38 | and encoded prompts.
39 | pixel_mean (list(float)): Mean values for normalizing
40 | pixels in the input image.
41 | pixel_std (list(float)): Std values for normalizing
42 | pixels in the input image.
43 | """
44 | super().__init__()
45 | self.image_encoder = image_encoder
46 | self.prompt_encoder = prompt_encoder
47 | self.mask_decoder = mask_decoder
48 | self.register_buffer(
49 | "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
50 | )
51 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
52 |
53 | @property
54 | def device(self) -> Any:
55 | return self.pixel_mean.device
56 |
57 | @torch.no_grad()
58 | def forward(
59 | self,
60 | batched_input: List[Dict[str, Any]],
61 | multimask_output: bool,
62 | ) -> List[Dict[str, torch.Tensor]]:
63 | """
64 | Predicts masks end-to-end from provided images and prompts.
65 | If prompts are not known in advance, using SamPredictor is
66 | recommended over calling the model directly.
67 |
68 | Arguments:
69 | batched_input (list(dict)): A list over input images, each a
70 | dictionary with the following keys. A prompt key can be
71 | excluded if it is not present.
72 | 'image': The image as a torch tensor in 3xHxW format,
73 | already transformed for input to the model.
74 | 'original_size': (tuple(int, int)) The original size of
75 | the image before transformation, as (H, W).
76 | 'point_coords': (torch.Tensor) Batched point prompts for
77 | this image, with shape BxNx2. Already transformed to the
78 | input frame of the model.
79 | 'point_labels': (torch.Tensor) Batched labels for point prompts,
80 | with shape BxN.
81 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
82 | Already transformed to the input frame of the model.
83 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
84 | in the form Bx1xHxW.
85 | multimask_output (bool): Whether the model should predict multiple
86 | disambiguating masks, or return a single mask.
87 |
88 | Returns:
89 | (list(dict)): A list over input images, where each element is
90 |             a dictionary with the following keys.
91 | 'masks': (torch.Tensor) Batched binary mask predictions,
92 |                     with shape BxCxHxW, where B is the number of input prompts,
93 |                     C is determined by multimask_output, and (H, W) is the
94 | original size of the image.
95 | 'iou_predictions': (torch.Tensor) The model's predictions
96 | of mask quality, in shape BxC.
97 | 'low_res_logits': (torch.Tensor) Low resolution logits with
98 | shape BxCxHxW, where H=W=256. Can be passed as mask input
99 | to subsequent iterations of prediction.
100 | """
101 | input_images = torch.stack(
102 | [self.preprocess(x["image"]) for x in batched_input], dim=0
103 | )
104 | image_embeddings = self.image_encoder(input_images)
105 |
106 | outputs = []
107 | for image_record, curr_embedding in zip(batched_input, image_embeddings):
108 | if "point_coords" in image_record:
109 | points = (image_record["point_coords"], image_record["point_labels"])
110 | else:
111 | points = None
112 | sparse_embeddings, dense_embeddings = self.prompt_encoder(
113 | points=points,
114 | boxes=image_record.get("boxes", None),
115 | masks=image_record.get("mask_inputs", None),
116 | )
117 | low_res_masks, iou_predictions = self.mask_decoder(
118 | image_embeddings=curr_embedding.unsqueeze(0),
119 | image_pe=self.prompt_encoder.get_dense_pe(),
120 | sparse_prompt_embeddings=sparse_embeddings,
121 | dense_prompt_embeddings=dense_embeddings,
122 | multimask_output=multimask_output,
123 | )
124 | masks = self.postprocess_masks(
125 | low_res_masks,
126 | input_size=image_record["image"].shape[-2:],
127 | original_size=image_record["original_size"],
128 | )
129 | masks = masks > self.mask_threshold
130 | outputs.append(
131 | {
132 | "masks": masks,
133 | "iou_predictions": iou_predictions,
134 | "low_res_logits": low_res_masks,
135 | }
136 | )
137 | return outputs
138 |
139 | def postprocess_masks(
140 | self,
141 | masks: torch.Tensor,
142 | input_size: Tuple[int, ...],
143 | original_size: Tuple[int, ...],
144 | ) -> torch.Tensor:
145 | """
146 | Remove padding and upscale masks to the original image size.
147 |
148 | Arguments:
149 | masks (torch.Tensor): Batched masks from the mask_decoder,
150 | in BxCxHxW format.
151 | input_size (tuple(int, int)): The size of the image input to the
152 | model, in (H, W) format. Used to remove padding.
153 | original_size (tuple(int, int)): The original size of the image
154 | before resizing for input to the model, in (H, W) format.
155 |
156 | Returns:
157 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
158 | is given by original_size.
159 | """
160 | masks = F.interpolate(
161 | masks,
162 | (self.image_encoder.img_size, self.image_encoder.img_size),
163 | mode="bilinear",
164 | align_corners=False,
165 | )
166 | masks = masks[..., : input_size[0], : input_size[1]]
167 | masks = F.interpolate(
168 | masks, original_size, mode="bilinear", align_corners=False
169 | )
170 | return masks
171 |
172 | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
173 | """Normalize pixel values and pad to a square input."""
174 | # Normalize colors
175 | x = (x - self.pixel_mean) / self.pixel_std
176 |
177 | # Pad
178 | h, w = x.shape[-2:]
179 | padh = self.image_encoder.img_size - h
180 | padw = self.image_encoder.img_size - w
181 | x = F.pad(x, (0, padw, 0, padh))
182 | return x
183 |
--------------------------------------------------------------------------------
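
For completeness, a sketch of the batched_input format that Sam.forward documents above (random weights and a dummy tensor; in practice the image is first resized with ResizeLongestSide and a checkpoint is loaded):

    import torch

    from metaseg.generator.build_sam import build_sam_vit_b

    sam = build_sam_vit_b()                                 # randomly initialized ViT-B variant
    batched_input = [{
        "image": torch.zeros(3, 1024, 768),                 # 3xHxW, long side already resized to 1024
        "original_size": (768, 576),                        # (H, W) before resizing
        "point_coords": torch.tensor([[[512.0, 384.0]]]),   # BxNx2, in the resized input frame
        "point_labels": torch.tensor([[1]]),                # BxN, 1 = foreground
    }]

    outputs = sam(batched_input, multimask_output=True)
    print(outputs[0]["masks"].shape)                        # torch.Size([1, 3, 768, 576])
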
/metaseg/modeling/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from typing import Tuple, Type
9 |
10 | import torch
11 | from torch import Tensor, nn
12 |
13 | from metaseg.modeling.common import MLPBlock
14 |
15 |
16 | class TwoWayTransformer(nn.Module):
17 | def __init__(
18 | self,
19 | depth: int,
20 | embedding_dim: int,
21 | num_heads: int,
22 | mlp_dim: int,
23 | activation: Type[nn.Module] = nn.ReLU,
24 | attention_downsample_rate: int = 2,
25 | ) -> None:
26 | """
27 | A transformer decoder that attends to an input image using
28 | queries whose positional embedding is supplied.
29 |
30 | Args:
31 | depth (int): number of layers in the transformer
32 | embedding_dim (int): the channel dimension for the input embeddings
33 | num_heads (int): the number of heads for multihead attention. Must
34 | divide embedding_dim
35 | mlp_dim (int): the channel dimension internal to the MLP block
36 | activation (nn.Module): the activation to use in the MLP block
37 | """
38 | super().__init__()
39 | self.depth = depth
40 | self.embedding_dim = embedding_dim
41 | self.num_heads = num_heads
42 | self.mlp_dim = mlp_dim
43 | self.layers = nn.ModuleList()
44 |
45 | for i in range(depth):
46 | self.layers.append(
47 | TwoWayAttentionBlock(
48 | embedding_dim=embedding_dim,
49 | num_heads=num_heads,
50 | mlp_dim=mlp_dim,
51 | activation=activation,
52 | attention_downsample_rate=attention_downsample_rate,
53 | skip_first_layer_pe=(i == 0),
54 | )
55 | )
56 |
57 | self.final_attn_token_to_image = Attention(
58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
59 | )
60 | self.norm_final_attn = nn.LayerNorm(embedding_dim)
61 |
62 | def forward(
63 | self,
64 | image_embedding: Tensor,
65 | image_pe: Tensor,
66 | point_embedding: Tensor,
67 | ) -> Tuple[Tensor, Tensor]:
68 | """
69 | Args:
70 | image_embedding (torch.Tensor): image to attend to. Should be shape
71 | B x embedding_dim x h x w for any h and w.
72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must
73 | have the same shape as image_embedding.
74 | point_embedding (torch.Tensor): the embedding to add to the query points.
75 | Must have shape B x N_points x embedding_dim for any N_points.
76 |
77 | Returns:
78 | torch.Tensor: the processed point_embedding
79 | torch.Tensor: the processed image_embedding
80 | """
81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C
82 | bs, c, h, w = image_embedding.shape
83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
84 | image_pe = image_pe.flatten(2).permute(0, 2, 1)
85 |
86 | # Prepare queries
87 | queries = point_embedding
88 | keys = image_embedding
89 |
90 | # Apply transformer blocks and final layernorm
91 | for layer in self.layers:
92 | queries, keys = layer(
93 | queries=queries,
94 | keys=keys,
95 | query_pe=point_embedding,
96 | key_pe=image_pe,
97 | )
98 |
99 |         # Apply the final attention layer from the points to the image
100 | q = queries + point_embedding
101 | k = keys + image_pe
102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
103 | queries = queries + attn_out
104 | queries = self.norm_final_attn(queries)
105 |
106 | return queries, keys
107 |
108 |
109 | class TwoWayAttentionBlock(nn.Module):
110 | def __init__(
111 | self,
112 | embedding_dim: int,
113 | num_heads: int,
114 | mlp_dim: int = 2048,
115 | activation: Type[nn.Module] = nn.ReLU,
116 | attention_downsample_rate: int = 2,
117 | skip_first_layer_pe: bool = False,
118 | ) -> None:
119 | """
120 | A transformer block with four layers: (1) self-attention of sparse
121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse
123 | inputs.
124 |
125 | Arguments:
126 | embedding_dim (int): the channel dimension of the embeddings
127 | num_heads (int): the number of heads in the attention layers
128 | mlp_dim (int): the hidden dimension of the mlp block
129 | activation (nn.Module): the activation of the mlp block
130 | skip_first_layer_pe (bool): skip the PE on the first layer
131 | """
132 | super().__init__()
133 | self.self_attn = Attention(embedding_dim, num_heads)
134 | self.norm1 = nn.LayerNorm(embedding_dim)
135 |
136 | self.cross_attn_token_to_image = Attention(
137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
138 | )
139 | self.norm2 = nn.LayerNorm(embedding_dim)
140 |
141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
142 | self.norm3 = nn.LayerNorm(embedding_dim)
143 |
144 | self.norm4 = nn.LayerNorm(embedding_dim)
145 | self.cross_attn_image_to_token = Attention(
146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147 | )
148 |
149 | self.skip_first_layer_pe = skip_first_layer_pe
150 |
151 | def forward(
152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
153 | ) -> Tuple[Tensor, Tensor]:
154 | # Self attention block
155 | if self.skip_first_layer_pe:
156 | queries = self.self_attn(q=queries, k=queries, v=queries)
157 | else:
158 | q = queries + query_pe
159 | attn_out = self.self_attn(q=q, k=q, v=queries)
160 | queries = queries + attn_out
161 | queries = self.norm1(queries)
162 |
163 | # Cross attention block, tokens attending to image embedding
164 | q = queries + query_pe
165 | k = keys + key_pe
166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
167 | queries = queries + attn_out
168 | queries = self.norm2(queries)
169 |
170 | # MLP block
171 | mlp_out = self.mlp(queries)
172 | queries = queries + mlp_out
173 | queries = self.norm3(queries)
174 |
175 | # Cross attention block, image embedding attending to tokens
176 | q = queries + query_pe
177 | k = keys + key_pe
178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
179 | keys = keys + attn_out
180 | keys = self.norm4(keys)
181 |
182 | return queries, keys
183 |
184 |
185 | class Attention(nn.Module):
186 | """
187 | An attention layer that allows for downscaling the size of the embedding
188 | after projection to queries, keys, and values.
189 | """
190 |
191 | def __init__(
192 | self,
193 | embedding_dim: int,
194 | num_heads: int,
195 | downsample_rate: int = 1,
196 | ) -> None:
197 | super().__init__()
198 | self.embedding_dim = embedding_dim
199 | self.internal_dim = embedding_dim // downsample_rate
200 | self.num_heads = num_heads
201 | assert (
202 | self.internal_dim % num_heads == 0
203 | ), "num_heads must divide embedding_dim."
204 |
205 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
206 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
207 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
208 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
209 |
210 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
211 | b, n, c = x.shape
212 | x = x.reshape(b, n, num_heads, c // num_heads)
213 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
214 |
215 | def _recombine_heads(self, x: Tensor) -> Tensor:
216 | b, n_heads, n_tokens, c_per_head = x.shape
217 | x = x.transpose(1, 2)
218 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
219 |
220 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
221 | # Input projections
222 | q = self.q_proj(q)
223 | k = self.k_proj(k)
224 | v = self.v_proj(v)
225 |
226 | # Separate into heads
227 | q = self._separate_heads(q, self.num_heads)
228 | k = self._separate_heads(k, self.num_heads)
229 | v = self._separate_heads(v, self.num_heads)
230 |
231 | # Attention
232 | _, _, _, c_per_head = q.shape
233 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
234 | attn = attn / math.sqrt(c_per_head)
235 | attn = torch.softmax(attn, dim=-1)
236 |
237 | # Get output
238 | out = attn @ v
239 | out = self._recombine_heads(out)
240 | out = self.out_proj(out)
241 |
242 | return out
243 |
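
As a quick sanity check of the shapes involved, the sketch below runs a TwoWayTransformer forward pass on random tensors; the hyperparameters (depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048) are illustrative rather than the configuration of any particular SAM checkpoint.

import torch

from metaseg.modeling.transformer import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)

image_embedding = torch.randn(1, 256, 64, 64)  # B x C x H x W
image_pe = torch.randn(1, 256, 64, 64)         # same shape as image_embedding
point_embedding = torch.randn(1, 5, 256)       # B x N_points x C

queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape)  # torch.Size([1, 5, 256])    -> processed point embeddings
print(keys.shape)     # torch.Size([1, 4096, 256]) -> processed image tokens (64 * 64)
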
--------------------------------------------------------------------------------
/metaseg/sahi_predictor.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import cv2
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import torch
7 | from cv2 import Mat
8 | from PIL import Image
9 |
10 | from metaseg.generator import SamPredictor, sam_model_registry
11 | from metaseg.utils import (
12 | download_model,
13 | load_image,
14 | multi_boxes,
15 | plt_load_box,
16 | plt_load_mask,
17 | )
18 |
19 |
20 | def sahi_sliced_predict(
21 | image_path,
22 | detection_model_type,
23 | detection_model_path,
24 | conf_th,
25 | image_size,
26 | slice_height,
27 | slice_width,
28 | overlap_height_ratio,
29 | overlap_width_ratio,
30 | ):
31 | try:
32 | from sahi import AutoDetectionModel
33 |         from sahi.predict import get_sliced_prediction
34 | except ImportError:
35 | raise ImportError("Please install SAHI library using 'pip install sahi'.")
36 |
37 | device = "cuda" if torch.cuda.is_available() else "cpu"
38 |
39 | detection_model = AutoDetectionModel.from_pretrained(
40 | image_size=image_size,
41 | model_type=detection_model_type,
42 | model_path=detection_model_path,
43 | confidence_threshold=conf_th,
44 | device=device,
45 | )
46 | result = get_sliced_prediction(
47 | image_path,
48 | detection_model,
49 | slice_height=slice_height,
50 | slice_width=slice_width,
51 | overlap_height_ratio=overlap_height_ratio,
52 | overlap_width_ratio=overlap_width_ratio,
53 | )
54 |
55 |     # Use the detections from the sliced prediction above
56 | output = result.object_prediction_list
57 | boxes = []
58 | for i in output:
59 | boxes.append(i.bbox.to_xyxy())
60 |
61 | return boxes
62 |
63 |
64 | class SahiAutoSegmentation:
65 | def __init__(self):
66 | self.model = None
67 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
68 |
69 | def load_model(self, model_type):
70 | if self.model is None:
71 | self.model_path = download_model(model_type)
72 | self.model = sam_model_registry[model_type](checkpoint=self.model_path)
73 | self.model.to(device=self.device)
74 |
75 | return self.model
76 |
77 | def image_predict(
78 | self,
79 | source: Union[str, Mat],
80 | model_type,
81 | input_box=None,
82 | input_point=None,
83 | input_label=None,
84 | multimask_output=False,
85 | random_color=False,
86 | show=False,
87 | save=False,
88 | ):
89 | read_image = load_image(source)
90 | model = self.load_model(model_type)
91 | predictor = SamPredictor(model)
92 | predictor.set_image(read_image)
93 |
94 | if type(input_box[0]) == list:
95 | input_boxes, new_boxes = multi_boxes(input_box, predictor, read_image)
96 |
97 | masks, iou_predictions, low_res_masks = predictor.predict_torch(
98 | point_coords=None,
99 | point_labels=None,
100 | boxes=new_boxes,
101 | multimask_output=False,
102 | )
103 |
104 | elif type(input_box[0]) == int:
105 | input_boxes = np.array(input_box)[None, :]
106 |
107 | masks, iou_predictions, low_res_masks = predictor.predict(
108 | point_coords=input_point,
109 | point_labels=input_label,
110 | box=input_boxes,
111 | multimask_output=multimask_output,
112 | )
113 |
114 | plt.figure(figsize=(10, 10))
115 | plt.imshow(read_image)
116 | for mask in masks:
117 | plt_load_mask(mask.cpu().numpy(), plt.gca(), random_color=random_color)
118 | for box in input_boxes:
119 | plt_load_box(box.cpu().numpy(), plt.gca())
120 | plt.axis("off")
121 | if save:
122 | plt.savefig("output.png", bbox_inches="tight")
123 |             output_image = cv2.cvtColor(cv2.imread("output.png"), cv2.COLOR_BGR2RGB)
124 |             output_image = Image.fromarray(output_image)
125 | return output_image
126 | if show:
127 | plt.show()
128 |
129 | return masks, iou_predictions, low_res_masks
130 |
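
A hedged end-to-end sketch of how the two pieces above are meant to be combined: the SAHI sliced detector supplies boxes, and SAM then segments inside them. The image path, detector weights, and threshold values are placeholders, not files or settings shipped with this package.

from metaseg import SahiAutoSegmentation, sahi_sliced_predict

boxes = sahi_sliced_predict(
    image_path="test.jpg",              # placeholder input image
    detection_model_type="yolov5",      # any detector type supported by SAHI
    detection_model_path="yolov5l6.pt", # placeholder weights path
    conf_th=0.25,
    image_size=1280,
    slice_height=256,
    slice_width=256,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)

SahiAutoSegmentation().image_predict(
    source="test.jpg",
    model_type="vit_b",                 # vit_b, vit_l, or vit_h
    input_box=boxes,
    multimask_output=False,
    random_color=False,
    show=True,
)
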
--------------------------------------------------------------------------------
/metaseg/sam_predictor.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from cv2 import Mat
7 | from tqdm import tqdm
8 |
9 | from metaseg.generator.automatic_mask_generator import SamAutomaticMaskGenerator
10 | from metaseg.generator.build_sam import sam_model_registry
11 | from metaseg.generator.predictor import SamPredictor
12 | from metaseg.utils import (
13 | download_model,
14 | load_box,
15 | load_image,
16 | load_mask,
17 | load_video,
18 | multi_boxes,
19 | show_image,
20 | )
21 |
22 |
23 | class SegAutoMaskPredictor:
24 | def __init__(self):
25 | self.model = None
26 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
27 |
28 | def load_model(self, model_type):
29 | if self.model is None:
30 | self.model_path = download_model(model_type)
31 | self.model = sam_model_registry[model_type](checkpoint=self.model_path)
32 | self.model.to(device=self.device)
33 |
34 | return self.model
35 |
36 | def image_predict(
37 | self,
38 | source: Union[str, Mat],
39 | model_type,
40 | points_per_side,
41 | points_per_batch,
42 | min_area,
43 | output_path="output.png",
44 | show=False,
45 | save=False,
46 | ):
47 | read_image = load_image(source)
48 | model = self.load_model(model_type)
49 | mask_generator = SamAutomaticMaskGenerator(
50 | model,
51 | points_per_side=points_per_side,
52 | points_per_batch=points_per_batch,
53 | min_mask_region_area=min_area,
54 | )
55 |
56 | masks = mask_generator.generate(read_image)
57 |
58 | sorted_anns = sorted(masks, key=(lambda x: x["area"]), reverse=True)
59 | mask_image = np.zeros(
60 | (masks[0]["segmentation"].shape[0], masks[0]["segmentation"].shape[1], 3),
61 | dtype=np.uint8,
62 | )
63 | colors = np.random.randint(0, 255, size=(256, 3), dtype=np.uint8)
64 | for i, ann in enumerate(sorted_anns):
65 | m = ann["segmentation"]
66 |             color = colors[i % 256]
67 |             img = np.zeros((m.shape[0], m.shape[1], 3), dtype=np.uint8)
68 |             # Fill each channel with the selected color
69 |             img[:, :, 0] = color[0]
70 |             img[:, :, 1] = color[1]
71 |             img[:, :, 2] = color[2]
72 | img = cv2.bitwise_and(img, img, mask=m.astype(np.uint8))
73 | img = cv2.addWeighted(img, 0.35, np.zeros_like(img), 0.65, 0)
74 | mask_image = cv2.add(mask_image, img)
75 |
76 | combined_mask = cv2.add(read_image, mask_image)
77 | self.combined_mask = combined_mask
78 | if show:
79 | show_image(combined_mask)
80 |
81 | if save:
82 | cv2.imwrite(output_path, combined_mask)
83 |
84 | return masks
85 |
86 | def video_predict(
87 | self,
88 | source,
89 | model_type,
90 | points_per_side,
91 | points_per_batch,
92 | min_area,
93 | output_path="output.mp4",
94 | ):
95 | cap, out = load_video(source, output_path)
96 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
97 | colors = np.random.randint(0, 255, size=(256, 3), dtype=np.uint8)
98 |
99 | for _ in tqdm(range(length)):
100 | ret, frame = cap.read()
101 | if not ret:
102 | break
103 |
104 | model = self.load_model(model_type)
105 | mask_generator = SamAutomaticMaskGenerator(
106 | model,
107 | points_per_side=points_per_side,
108 | points_per_batch=points_per_batch,
109 | min_mask_region_area=min_area,
110 | )
111 | masks = mask_generator.generate(frame)
112 |
113 | if len(masks) == 0:
114 | continue
115 |
116 | sorted_anns = sorted(masks, key=(lambda x: x["area"]), reverse=True)
117 | mask_image = np.zeros(
118 | (
119 | masks[0]["segmentation"].shape[0],
120 | masks[0]["segmentation"].shape[1],
121 | 3,
122 | ),
123 | dtype=np.uint8,
124 | )
125 |
126 | for i, ann in enumerate(sorted_anns):
127 | m = ann["segmentation"]
128 | color = colors[i % 256]
129 | img = np.zeros((m.shape[0], m.shape[1], 3), dtype=np.uint8)
130 | img[:, :, 0] = color[0]
131 | img[:, :, 1] = color[1]
132 | img[:, :, 2] = color[2]
133 | img = cv2.bitwise_and(img, img, mask=m.astype(np.uint8))
134 | img = cv2.addWeighted(img, 0.35, np.zeros_like(img), 0.65, 0)
135 | mask_image = cv2.add(mask_image, img)
136 |
137 | combined_mask = cv2.add(frame, mask_image)
138 | out.write(combined_mask)
139 |
140 | out.release()
141 | cap.release()
142 | cv2.destroyAllWindows()
143 |
144 | return output_path
145 |
146 |
147 | class SegManualMaskPredictor:
148 | def __init__(self):
149 | self.model = None
150 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
151 |
152 | def load_model(self, model_type):
153 | if self.model is None:
154 | self.model_path = download_model(model_type)
155 | self.model = sam_model_registry[model_type](checkpoint=self.model_path)
156 | self.model.to(device=self.device)
157 |
158 | return self.model
159 |
160 | def image_predict(
161 | self,
162 | source: Union[str, Mat],
163 | model_type,
164 | input_box=None,
165 | input_point=None,
166 | input_label=None,
167 | multimask_output=False,
168 | output_path="output.png",
169 | random_color=False,
170 | show=False,
171 | save=False,
172 | ):
173 | image = load_image(source)
174 | model = self.load_model(model_type)
175 | predictor = SamPredictor(model)
176 | predictor.set_image(image)
177 |
178 | if type(input_box[0]) == list:
179 | input_boxes, new_boxes = multi_boxes(input_box, predictor, image)
180 |
181 | masks, _, _ = predictor.predict_torch(
182 | point_coords=None,
183 | point_labels=None,
184 | boxes=new_boxes,
185 | multimask_output=False,
186 | )
187 | for mask in masks:
188 | mask_image = load_mask(mask.cpu().numpy(), random_color)
189 |
190 | for box in input_boxes:
191 | image = load_box(box.cpu().numpy(), image)
192 |
193 | elif type(input_box[0]) == int:
194 | input_boxes = np.array(input_box)[None, :]
195 |
196 | masks, _, _ = predictor.predict(
197 | point_coords=input_point,
198 | point_labels=input_label,
199 | box=input_boxes,
200 | multimask_output=multimask_output,
201 | )
202 | mask_image = load_mask(masks, random_color)
203 | image = load_box(input_box, image)
204 |
205 | combined_mask = cv2.add(image, mask_image)
206 | if save:
207 | cv2.imwrite(output_path, combined_mask)
208 |
209 | if show:
210 | show_image(combined_mask)
211 |
212 | return masks
213 |
214 | def video_predict(
215 | self,
216 | source,
217 | model_type,
218 | input_box=None,
219 | input_point=None,
220 | input_label=None,
221 | multimask_output=False,
222 | output_path="output.mp4",
223 | random_color=False,
224 | ):
225 | cap, out = load_video(source, output_path)
226 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
227 |
228 | for _ in tqdm(range(length)):
229 | ret, frame = cap.read()
230 | if not ret:
231 | break
232 |
233 | model = self.load_model(model_type)
234 | predictor = SamPredictor(model)
235 | predictor.set_image(frame)
236 |
237 | if type(input_box[0]) == list:
238 | input_boxes, new_boxes = multi_boxes(input_box, predictor, frame)
239 |
240 | masks, _, _ = predictor.predict_torch(
241 | point_coords=None,
242 | point_labels=None,
243 | boxes=new_boxes,
244 | multimask_output=False,
245 | )
246 | for mask in masks:
247 | mask_image = load_mask(mask.cpu().numpy(), random_color)
248 |
249 | for box in input_boxes:
250 | frame = load_box(box.cpu().numpy(), frame)
251 |
252 | elif type(input_box[0]) == int:
253 | input_boxes = np.array(input_box)[None, :]
254 |
255 | masks, _, _ = predictor.predict(
256 | point_coords=input_point,
257 | point_labels=input_label,
258 | box=input_boxes,
259 | multimask_output=multimask_output,
260 | )
261 | mask_image = load_mask(masks, random_color)
262 | frame = load_box(input_box, frame)
263 |
264 | combined_mask = cv2.add(frame, mask_image)
265 | out.write(combined_mask)
266 |
267 | out.release()
268 | cap.release()
269 | cv2.destroyAllWindows()
270 | return output_path
271 |
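
A minimal usage sketch for the two predictors above (both are re-exported from the top-level metaseg package); "image.jpg" and the box coordinates are placeholders.

from metaseg import SegAutoMaskPredictor, SegManualMaskPredictor

# Automatic mask generation over the whole image
SegAutoMaskPredictor().image_predict(
    source="image.jpg",
    model_type="vit_b",              # vit_b, vit_l, or vit_h
    points_per_side=16,
    points_per_batch=64,
    min_area=0,
    output_path="output.png",
    save=True,
)

# Prompted segmentation with a single box
SegManualMaskPredictor().image_predict(
    source="image.jpg",
    model_type="vit_b",
    input_box=[100, 100, 400, 400],  # x0, y0, x1, y1 (placeholder coordinates)
    input_point=None,
    input_label=None,
    multimask_output=False,
    output_path="manual_output.png",
    save=True,
)
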
--------------------------------------------------------------------------------
/metaseg/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .data_utils import load_box as load_box
8 | from .data_utils import load_image as load_image
9 | from .data_utils import load_mask as load_mask
10 | from .data_utils import load_server_image as load_server_image
11 | from .data_utils import load_video as load_video
12 | from .data_utils import multi_boxes as multi_boxes
13 | from .data_utils import plt_load_box as plt_load_box
14 | from .data_utils import plt_load_mask as plt_load_mask
15 | from .data_utils import show_image as show_image
16 | from .model_file_downloader import download_model as download_model
17 | from .onnx import SamOnnxModel as SamOnnxModel
18 | from .transforms import ResizeLongestSide as ResizeLongestSide
19 |
--------------------------------------------------------------------------------
/metaseg/utils/amg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from copy import deepcopy
9 | from itertools import product
10 | from typing import Any, Dict, Generator, ItemsView, List, Tuple
11 |
12 | import cv2
13 | import numpy as np
14 | import torch
15 | from pycocotools import mask as mask_utils
16 |
17 |
18 | class MaskData:
19 | """
20 | A structure for storing masks and their related data in batched format.
21 | Implements basic filtering and concatenation.
22 | """
23 |
24 | def __init__(self, **kwargs) -> None:
25 | for v in kwargs.values():
26 | assert isinstance(
27 | v, (list, np.ndarray, torch.Tensor)
28 | ), "MaskData only supports list, numpy arrays, and torch tensors."
29 | self._stats = dict(**kwargs)
30 |
31 | def __setitem__(self, key: str, item: Any) -> None:
32 | assert isinstance(
33 | item, (list, np.ndarray, torch.Tensor)
34 | ), "MaskData only supports list, numpy arrays, and torch tensors."
35 | self._stats[key] = item
36 |
37 | def __delitem__(self, key: str) -> None:
38 | del self._stats[key]
39 |
40 | def __getitem__(self, key: str) -> Any:
41 | return self._stats[key]
42 |
43 | def items(self) -> ItemsView[str, Any]:
44 | return self._stats.items()
45 |
46 | def filter(self, keep: torch.Tensor) -> None:
47 | for k, v in self._stats.items():
48 | if v is None:
49 | self._stats[k] = None
50 | elif isinstance(v, torch.Tensor):
51 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
52 | elif isinstance(v, np.ndarray):
53 | self._stats[k] = v[keep.detach().cpu().numpy()]
54 | elif isinstance(v, list) and keep.dtype == torch.bool:
55 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
56 | elif isinstance(v, list):
57 | self._stats[k] = [v[i] for i in keep]
58 | else:
59 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
60 |
61 | def cat(self, new_stats: "MaskData") -> None:
62 | for k, v in new_stats.items():
63 | if k not in self._stats or self._stats[k] is None:
64 | self._stats[k] = deepcopy(v)
65 | elif isinstance(v, torch.Tensor):
66 | self._stats[k] = torch.cat([self._stats[k], v], dim=0)
67 | elif isinstance(v, np.ndarray):
68 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
69 | elif isinstance(v, list):
70 | self._stats[k] = self._stats[k] + deepcopy(v)
71 | else:
72 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
73 |
74 | def to_numpy(self) -> None:
75 | for k, v in self._stats.items():
76 | if isinstance(v, torch.Tensor):
77 | self._stats[k] = v.detach().cpu().numpy()
78 |
79 |
80 | def is_box_near_crop_edge(
81 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
82 | ) -> torch.Tensor:
83 | """Filter masks at the edge of a crop, but not at the edge of the original image."""
84 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
85 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
86 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
87 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
88 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
89 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
90 | return torch.any(near_crop_edge, dim=1)
91 |
92 |
93 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
94 | box_xywh = deepcopy(box_xyxy)
95 | box_xywh[2] = box_xywh[2] - box_xywh[0]
96 | box_xywh[3] = box_xywh[3] - box_xywh[1]
97 | return box_xywh
98 |
99 |
100 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
101 | assert len(args) > 0 and all(
102 | len(a) == len(args[0]) for a in args
103 | ), "Batched iteration must have inputs of all the same size."
104 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
105 | for b in range(n_batches):
106 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
107 |
108 |
109 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
110 | """
111 | Encodes masks to an uncompressed RLE, in the format expected by
112 |     pycocotools.
113 | """
114 | # Put in fortran order and flatten h,w
115 | b, h, w = tensor.shape
116 | tensor = tensor.permute(0, 2, 1).flatten(1)
117 |
118 | # Compute change indices
119 | diff = tensor[:, 1:] ^ tensor[:, :-1]
120 | change_indices = diff.nonzero()
121 |
122 | # Encode run length
123 | out = []
124 | for i in range(b):
125 | cur_idxs = change_indices[change_indices[:, 0] == i, 1]
126 | cur_idxs = torch.cat(
127 | [
128 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
129 | cur_idxs + 1,
130 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
131 | ]
132 | )
133 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
134 | counts = [] if tensor[i, 0] == 0 else [0]
135 | counts.extend(btw_idxs.detach().cpu().tolist())
136 | out.append({"size": [h, w], "counts": counts})
137 | return out
138 |
139 |
140 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
141 | """Compute a binary mask from an uncompressed RLE."""
142 | h, w = rle["size"]
143 | mask = np.empty(h * w, dtype=bool)
144 | idx = 0
145 | parity = False
146 | for count in rle["counts"]:
147 | mask[idx : idx + count] = parity
148 | idx += count
149 | parity ^= True
150 | mask = mask.reshape(w, h)
151 | return mask.transpose() # Put in C order
152 |
153 |
154 | def area_from_rle(rle: Dict[str, Any]) -> int:
155 | return sum(rle["counts"][1::2])
156 |
157 |
158 | def calculate_stability_score(
159 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float
160 | ) -> torch.Tensor:
161 | """
162 | Computes the stability score for a batch of masks. The stability
163 | score is the IoU between the binary masks obtained by thresholding
164 | the predicted mask logits at high and low values.
165 | """
166 | # One mask is always contained inside the other.
167 |     # Save memory by preventing unnecessary cast to torch.int64
168 | intersections = (
169 | (masks > (mask_threshold + threshold_offset))
170 | .sum(-1, dtype=torch.int16)
171 | .sum(-1, dtype=torch.int32)
172 | )
173 | unions = (
174 | (masks > (mask_threshold - threshold_offset))
175 | .sum(-1, dtype=torch.int16)
176 | .sum(-1, dtype=torch.int32)
177 | )
178 | return intersections / unions
179 |
180 |
181 | def build_point_grid(n_per_side: int) -> np.ndarray:
182 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
183 | offset = 1 / (2 * n_per_side)
184 | points_one_side = np.linspace(offset, 1 - offset, n_per_side)
185 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
186 | points_y = np.tile(points_one_side[:, None], (1, n_per_side))
187 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
188 | return points
189 |
190 |
191 | def build_all_layer_point_grids(
192 | n_per_side: int, n_layers: int, scale_per_layer: int
193 | ) -> List[np.ndarray]:
194 | """Generates point grids for all crop layers."""
195 | points_by_layer = []
196 | for i in range(n_layers + 1):
197 | n_points = int(n_per_side / (scale_per_layer**i))
198 | points_by_layer.append(build_point_grid(n_points))
199 | return points_by_layer
200 |
201 |
202 | def generate_crop_boxes(
203 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
204 | ) -> Tuple[List[List[int]], List[int]]:
205 | """
206 | Generates a list of crop boxes of different sizes. Each layer
207 | has (2**i)**2 boxes for the ith layer.
208 | """
209 | crop_boxes, layer_idxs = [], []
210 | im_h, im_w = im_size
211 | short_side = min(im_h, im_w)
212 |
213 | # Original image
214 | crop_boxes.append([0, 0, im_w, im_h])
215 | layer_idxs.append(0)
216 |
217 | def crop_len(orig_len, n_crops, overlap):
218 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
219 |
220 | for i_layer in range(n_layers):
221 | n_crops_per_side = 2 ** (i_layer + 1)
222 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
223 |
224 | crop_w = crop_len(im_w, n_crops_per_side, overlap)
225 | crop_h = crop_len(im_h, n_crops_per_side, overlap)
226 |
227 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
228 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
229 |
230 |         # Crops in XYXY format
231 | for x0, y0 in product(crop_box_x0, crop_box_y0):
232 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
233 | crop_boxes.append(box)
234 | layer_idxs.append(i_layer + 1)
235 |
236 | return crop_boxes, layer_idxs
237 |
238 |
239 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
240 | x0, y0, _, _ = crop_box
241 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
242 | # Check if boxes has a channel dimension
243 | if len(boxes.shape) == 3:
244 | offset = offset.unsqueeze(1)
245 | return boxes + offset
246 |
247 |
248 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
249 | x0, y0, _, _ = crop_box
250 | offset = torch.tensor([[x0, y0]], device=points.device)
251 | # Check if points has a channel dimension
252 | if len(points.shape) == 3:
253 | offset = offset.unsqueeze(1)
254 | return points + offset
255 |
256 |
257 | def uncrop_masks(
258 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
259 | ) -> torch.Tensor:
260 | x0, y0, x1, y1 = crop_box
261 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
262 | return masks
263 | # Coordinate transform masks
264 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
265 | pad = (x0, pad_x - x0, y0, pad_y - y0)
266 | return torch.nn.functional.pad(masks, pad, value=0)
267 |
268 |
269 | def remove_small_regions(
270 | mask: np.ndarray, area_thresh: float, mode: str
271 | ) -> Tuple[np.ndarray, bool]:
272 | """
273 | Removes small disconnected regions and holes in a mask. Returns the
274 | mask and an indicator of if the mask has been modified.
275 | """
276 |
277 | assert mode in ["holes", "islands"]
278 | correct_holes = mode == "holes"
279 | working_mask = (correct_holes ^ mask).astype(np.uint8)
280 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
281 | sizes = stats[:, -1][1:] # Row 0 is background label
282 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
283 | if len(small_regions) == 0:
284 | return mask, False
285 | fill_labels = [0] + small_regions
286 | if not correct_holes:
287 | fill_labels = [i for i in range(n_labels) if i not in fill_labels]
288 | # If every region is below threshold, keep largest
289 | if len(fill_labels) == 0:
290 | fill_labels = [int(np.argmax(sizes)) + 1]
291 | mask = np.isin(regions, fill_labels)
292 | return mask, True
293 |
294 |
295 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
296 | h, w = uncompressed_rle["size"]
297 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
298 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
299 | return rle
300 |
301 |
302 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
303 | """
304 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
305 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
306 | """
307 | # torch.max below raises an error on empty inputs, just skip in this case
308 | if torch.numel(masks) == 0:
309 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
310 |
311 | # Normalize shape to CxHxW
312 | shape = masks.shape
313 | h, w = shape[-2:]
314 | if len(shape) > 2:
315 | masks = masks.flatten(0, -3)
316 | else:
317 | masks = masks.unsqueeze(0)
318 |
319 | # Get top and bottom edges
320 | in_height, _ = torch.max(masks, dim=-1)
321 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
322 | bottom_edges, _ = torch.max(in_height_coords, dim=-1)
323 | in_height_coords = in_height_coords + h * (~in_height)
324 | top_edges, _ = torch.min(in_height_coords, dim=-1)
325 |
326 | # Get left and right edges
327 | in_width, _ = torch.max(masks, dim=-2)
328 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
329 | right_edges, _ = torch.max(in_width_coords, dim=-1)
330 | in_width_coords = in_width_coords + w * (~in_width)
331 | left_edges, _ = torch.min(in_width_coords, dim=-1)
332 |
333 | # If the mask is empty the right edge will be to the left of the left edge.
334 | # Replace these boxes with [0, 0, 0, 0]
335 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
336 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
337 | out = out * (~empty_filter).unsqueeze(-1)
338 |
339 | # Return to original shape
340 | if len(shape) > 2:
341 | out = out.reshape(*shape[:-2], 4)
342 | else:
343 | out = out[0]
344 |
345 | return out
346 |
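
A small round-trip sketch for the uncompressed-RLE helpers above, using a toy mask purely for illustration:

import numpy as np
import torch

from metaseg.utils.amg import area_from_rle, mask_to_rle_pytorch, rle_to_mask

mask = torch.zeros(1, 4, 6, dtype=torch.bool)
mask[0, 1:3, 2:5] = True                  # a 2x3 rectangle of foreground pixels

rles = mask_to_rle_pytorch(mask)          # one uncompressed RLE dict per batch element
recovered = rle_to_mask(rles[0])          # back to an HxW numpy bool array

assert recovered.shape == (4, 6)
assert area_from_rle(rles[0]) == int(mask.sum())
assert np.array_equal(recovered, mask[0].numpy())
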
--------------------------------------------------------------------------------
/metaseg/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | from os import makedirs
3 | from os.path import isfile as isfile
4 | from typing import Union
5 | from uuid import uuid4
6 |
7 | import cv2
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | from cv2 import Mat
11 | from PIL import Image
12 | from torch import tensor
13 |
14 |
15 | def load_image(image: Union[str, Mat]) -> Mat:
16 | """
17 |     Load an image from a file path, or pass an already-loaded image through.
18 |     :param image: path to an image file, or an image as Mat or np.ndarray
19 |     :return: the image as a Mat/np.ndarray (RGB when read from disk)
20 | """
21 | if isfile(image):
22 | image = cv2.imread(image)
23 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
24 | return image
25 | elif isinstance(image, Mat) or isinstance(image, np.ndarray):
26 | return image
27 | else:
28 | raise ValueError("image must be a path or cv2.Mat")
29 |
30 |
31 | def load_server_image(image_path):
32 | imagedir = str(uuid4())
33 |     makedirs(imagedir, exist_ok=True)
34 | image = Image.open(BytesIO(image_path))
35 | if image.mode != "RGB":
36 | image = image.convert("RGB")
37 |
38 | image_path = f"{imagedir}/base_image_v0.png"
39 | output_path = f"{imagedir}/output_v0.png"
40 | image.save(image_path, format="PNG")
41 | return image_path, output_path
42 |
43 |
44 | def load_video(video_path, output_path="output.mp4"):
45 | cap = cv2.VideoCapture(video_path)
46 | frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
47 | frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
48 | fourcc = cv2.VideoWriter_fourcc(*"XVID")
49 | fps = int(cap.get(cv2.CAP_PROP_FPS))
50 | out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
51 | return cap, out
52 |
53 |
54 | def load_mask(mask, random_color):
55 | if random_color:
56 | color = np.random.rand(3) * 255
57 | else:
58 | color = np.array([100, 50, 0])
59 |
60 | h, w = mask.shape[-2:]
61 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
62 | mask_image = mask_image.astype(np.uint8)
63 | return mask_image
64 |
65 |
66 | def load_box(box, image):
67 |     x0, y0, x1, y1 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
68 |     cv2.rectangle(image, (x0, y0), (x1, y1), (0, 255, 0), 2)
69 | return image
70 |
71 |
72 | def plt_load_mask(mask, ax, random_color=False):
73 | if random_color:
74 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
75 | else:
76 | color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
77 | h, w = mask.shape[-2:]
78 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
79 | ax.imshow(mask_image)
80 |
81 |
82 | def plt_load_box(box, ax):
83 | x0, y0 = box[0], box[1]
84 | w, h = box[2] - box[0], box[3] - box[1]
85 | ax.add_patch(
86 | plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
87 | )
88 |
89 |
90 | def multi_boxes(boxes, predictor, image):
91 | input_boxes = tensor(boxes, device=predictor.device)
92 | transformed_boxes = predictor.transform.apply_boxes_torch(
93 | input_boxes, image.shape[:2]
94 | )
95 | return input_boxes, transformed_boxes
96 |
97 |
98 | def show_image(output_image):
99 | cv2.imshow("output", output_image)
100 | cv2.waitKey(0)
101 | cv2.destroyAllWindows()
102 |
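
A short sketch of the OpenCV drawing helpers on a synthetic image and mask; the sizes and box coordinates below are arbitrary.

import cv2
import numpy as np

from metaseg.utils import load_box, load_mask

image = np.zeros((256, 256, 3), dtype=np.uint8)
mask = np.zeros((256, 256), dtype=np.uint8)
mask[64:192, 64:192] = 1

overlay = load_mask(mask, random_color=True)  # H x W x 3 colored mask image
image = load_box([32, 32, 224, 224], image)   # draws a green rectangle in place
combined = cv2.add(image, overlay)            # same overlay style the predictors use
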
--------------------------------------------------------------------------------
/metaseg/utils/model_file_downloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import partial
3 | from hashlib import md5
4 | from pathlib import Path
5 | from shutil import copyfileobj
6 |
7 | from requests import Response, get
8 | from tqdm.auto import tqdm
9 |
10 | # Maps model types to (download URL, tqdm progress color, expected md5) tuples
11 | MODEL_URLS: "dict[str, tuple[str, str, str]]" = {
12 | "vit_h": (
13 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
14 | "green",
15 |         "4b8939a88964f0f4ff5f5b2642c598a6",
16 | ),
17 | "vit_l": (
18 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
19 | "red",
20 | "0b3195507c641ddb6910d2bb5adee89c",
21 | ),
22 | "vit_b": (
23 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
24 | "cyan",
25 |         "01ec64d29a2fca3f0661936605ae66f8",
26 | ),
27 | }
28 |
29 |
30 | # md5 check function
31 | def _check_md5(filename: str, orig_md5: str) -> bool:
32 | """
33 | filename: str, A string representing the path to the file.
34 | orig_md5: str, A string representing the original md5 hash.
35 | """
36 | if not os.path.exists(filename):
37 | return False
38 | with open(filename, "rb") as file_to_check:
39 | # read contents of the file
40 | data = file_to_check.read()
41 | # pipe contents of the file through
42 | md5_returned = md5(data).hexdigest()
43 | # Return True if the computed hash matches the original one
44 | if md5_returned == orig_md5:
45 | return True
46 | return False
47 |
48 |
49 | def download_model(model_type):
50 | """
51 | model_type: str, A string representing the model type.
52 | It can be 'vit_h', 'vit_l', or 'vit_b'.
53 | """
54 |
55 | # Check if the model file already exists and model_type is in MODEL_URLS
56 | filename = f"{model_type}.pth"
57 | if not os.path.exists(filename) and model_type in MODEL_URLS:
58 | print(f"Downloading {filename} model \n")
59 | res: Response = get(
60 | MODEL_URLS[model_type][0], stream=True, allow_redirects=True
61 | )
62 | if res.status_code != 200:
63 | res.raise_for_status()
64 | raise RuntimeError(
65 | f"Request to {MODEL_URLS[model_type][0]} "
66 | f"returned status code {res.status_code}"
67 | )
68 |
69 | file_size: int = int(res.headers.get("Content-Length", 0))
70 | folder_path: Path = Path(filename).expanduser().resolve()
71 | folder_path.parent.mkdir(parents=True, exist_ok=True)
72 |
73 | desc = "(Unknown total file size)" if file_size == 0 else ""
74 | res.raw.read = partial(
75 | res.raw.read, decode_content=True
76 | ) # Decompress if needed
77 | with tqdm.wrapattr(
78 | res.raw,
79 | "read",
80 | total=file_size,
81 | desc=desc,
82 | colour=MODEL_URLS[model_type][1],
83 | ) as r_raw:
84 | with folder_path.open("wb") as f:
85 | copyfileobj(r_raw, f)
86 |
87 | elif os.path.exists(filename):
88 | if not _check_md5(filename, MODEL_URLS[model_type][2]):
89 | print("File corrupted. Re-downloading... \n")
90 | os.remove(filename)
91 | download_model(model_type)
92 |
93 | print(f"{filename} model download complete. \n")
94 | else:
95 | raise ValueError(
96 | "Invalid model type. It should be 'vit_h', 'vit_l', or 'vit_b'."
97 | )
98 |
99 | return filename
100 |
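
A minimal sketch: the checkpoint is fetched on first use and written as "<model_type>.pth" in the current working directory; later calls verify the md5 and re-download only if the file is corrupted.

from metaseg.utils import download_model

checkpoint_path = download_model("vit_b")
print(checkpoint_path)  # "vit_b.pth"
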
--------------------------------------------------------------------------------
/metaseg/utils/onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Tuple
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.nn import functional as F
12 |
13 | from metaseg.modeling import Sam
14 | from metaseg.utils.amg import calculate_stability_score
15 |
16 |
17 | class SamOnnxModel(nn.Module):
18 | """
19 | This model should not be called directly, but is used in ONNX export.
20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21 | with some functions modified to enable model tracing. Also supports extra
22 |     options controlling what information is returned; see the ONNX export script.
23 | """
24 |
25 | def __init__(
26 | self,
27 | model: Sam,
28 | return_single_mask: bool,
29 | use_stability_score: bool = False,
30 | return_extra_metrics: bool = False,
31 | ) -> None:
32 | super().__init__()
33 | self.mask_decoder = model.mask_decoder
34 | self.model = model
35 | self.img_size = model.image_encoder.img_size
36 | self.return_single_mask = return_single_mask
37 | self.use_stability_score = use_stability_score
38 | self.stability_score_offset = 1.0
39 | self.return_extra_metrics = return_extra_metrics
40 |
41 | @staticmethod
42 | def resize_longest_image_size(
43 | input_image_size: torch.Tensor, longest_side: int
44 | ) -> torch.Tensor:
45 | input_image_size = input_image_size.to(torch.float32)
46 | scale = longest_side / torch.max(input_image_size)
47 | transformed_size = scale * input_image_size
48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
49 | return transformed_size
50 |
51 | def _embed_points(
52 | self, point_coords: torch.Tensor, point_labels: torch.Tensor
53 | ) -> torch.Tensor:
54 | point_coords = point_coords + 0.5
55 | point_coords = point_coords / self.img_size
56 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
57 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
58 |
59 | point_embedding = point_embedding * (point_labels != -1)
60 | point_embedding = (
61 | point_embedding
62 | + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
63 | )
64 |
65 | for i in range(self.model.prompt_encoder.num_point_embeddings):
66 | point_embedding = (
67 | point_embedding
68 | + self.model.prompt_encoder.point_embeddings[i].weight
69 | * (point_labels == i)
70 | )
71 |
72 | return point_embedding
73 |
74 | def _embed_masks(
75 | self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
76 | ) -> torch.Tensor:
77 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(
78 | input_mask
79 | )
80 | mask_embedding = mask_embedding + (
81 | 1 - has_mask_input
82 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
83 | return mask_embedding
84 |
85 | def mask_postprocessing(
86 | self, masks: torch.Tensor, orig_im_size: torch.Tensor
87 | ) -> torch.Tensor:
88 | masks = F.interpolate(
89 | masks,
90 | size=(self.img_size, self.img_size),
91 | mode="bilinear",
92 | align_corners=False,
93 | )
94 |
95 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size)
96 | masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]
97 |
98 | orig_im_size = orig_im_size.to(torch.int64)
99 | h, w = orig_im_size[0], orig_im_size[1]
100 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
101 | return masks
102 |
103 | def select_masks(
104 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
105 | ) -> Tuple[torch.Tensor, torch.Tensor]:
106 | # Determine if we should return the multi click
107 | # mask or not from the number of points.
108 | # The reweighting is used to avoid control flow.
109 | score_reweight = torch.tensor(
110 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
111 | ).to(iou_preds.device)
112 | score = iou_preds + (num_points - 2.5) * score_reweight
113 | best_idx = torch.argmax(score, dim=1)
114 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
115 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
116 |
117 | return masks, iou_preds
118 |
119 | @torch.no_grad()
120 | def forward(
121 | self,
122 | image_embeddings: torch.Tensor,
123 | point_coords: torch.Tensor,
124 | point_labels: torch.Tensor,
125 | mask_input: torch.Tensor,
126 | has_mask_input: torch.Tensor,
127 | orig_im_size: torch.Tensor,
128 | ):
129 | sparse_embedding = self._embed_points(point_coords, point_labels)
130 | dense_embedding = self._embed_masks(mask_input, has_mask_input)
131 |
132 | masks, scores = self.model.mask_decoder.predict_masks(
133 | image_embeddings=image_embeddings,
134 | image_pe=self.model.prompt_encoder.get_dense_pe(),
135 | sparse_prompt_embeddings=sparse_embedding,
136 | dense_prompt_embeddings=dense_embedding,
137 | )
138 |
139 | if self.use_stability_score:
140 | scores = calculate_stability_score(
141 | masks, self.model.mask_threshold, self.stability_score_offset
142 | )
143 |
144 | if self.return_single_mask:
145 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
146 |
147 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
148 |
149 | if self.return_extra_metrics:
150 | stability_scores = calculate_stability_score(
151 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset
152 | )
153 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
154 | return upscaled_masks, scores, stability_scores, areas, masks
155 |
156 | return upscaled_masks, scores, masks
157 |
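
A hedged export sketch; the dummy-input shapes follow the conventions of the ONNX export script referenced above, and the output filename and opset version are arbitrary choices, not fixed by this module.

import torch

from metaseg.generator import sam_model_registry
from metaseg.utils import SamOnnxModel, download_model

checkpoint = download_model("vit_b")
sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
onnx_model = SamOnnxModel(sam, return_single_mask=True)

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size),
    "point_coords": torch.randint(0, 1024, (1, 5, 2)).float(),
    "point_labels": torch.randint(0, 4, (1, 5)).float(),
    "mask_input": torch.randn(1, 1, 4 * embed_size[0], 4 * embed_size[1]),
    "has_mask_input": torch.tensor([1.0]),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}

# Dynamic axes are omitted here for brevity.
with open("sam_onnx.onnx", "wb") as f:
    torch.onnx.export(
        onnx_model,
        tuple(dummy_inputs.values()),
        f,
        input_names=list(dummy_inputs.keys()),
        output_names=["masks", "iou_predictions", "low_res_masks"],
        opset_version=16,
    )
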
--------------------------------------------------------------------------------
/metaseg/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from copy import deepcopy
8 | from typing import Tuple
9 |
10 | import numpy as np
11 | import torch
12 | from torch.nn import functional as F
13 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
14 |
15 |
16 | class ResizeLongestSide:
17 | """
18 | Resizes images to longest side 'target_length', as well as provides
19 | methods for resizing coordinates and boxes. Provides methods for
20 | transforming both numpy array and batched torch tensors.
21 | """
22 |
23 | def __init__(self, target_length: int) -> None:
24 | self.target_length = target_length
25 |
26 | def apply_image(self, image: np.ndarray) -> np.ndarray:
27 | """
28 | Expects a numpy array with shape HxWxC in uint8 format.
29 | """
30 | target_size = self.get_preprocess_shape(
31 | image.shape[0], image.shape[1], self.target_length
32 | )
33 | return np.array(resize(to_pil_image(image), target_size))
34 |
35 | def apply_coords(
36 | self, coords: np.ndarray, original_size: Tuple[int, ...]
37 | ) -> np.ndarray:
38 | """
39 | Expects a numpy array of length 2 in the final dimension. Requires the
40 | original image size in (H, W) format.
41 | """
42 | old_h, old_w = original_size
43 | new_h, new_w = self.get_preprocess_shape(
44 | original_size[0], original_size[1], self.target_length
45 | )
46 | coords = deepcopy(coords).astype(float)
47 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
48 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
49 | return coords
50 |
51 | def apply_boxes(
52 | self, boxes: np.ndarray, original_size: Tuple[int, ...]
53 | ) -> np.ndarray:
54 | """
55 | Expects a numpy array shape Bx4. Requires the original image size
56 | in (H, W) format.
57 | """
58 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
59 | return boxes.reshape(-1, 4)
60 |
61 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
62 | """
63 | Expects batched images with shape BxCxHxW and float format. This
64 | transformation may not exactly match apply_image. apply_image is
65 | the transformation expected by the model.
66 | """
67 | # Expects an image in BCHW format. May not exactly match apply_image.
68 | target_size = self.get_preprocess_shape(
69 |             image.shape[2], image.shape[3], self.target_length
70 | )
71 | return F.interpolate(
72 | image, target_size, mode="bilinear", align_corners=False, antialias=True
73 | )
74 |
75 | def apply_coords_torch(
76 | self, coords: torch.Tensor, original_size: Tuple[int, ...]
77 | ) -> torch.Tensor:
78 | """
79 | Expects a torch tensor with length 2 in the last dimension. Requires the
80 | original image size in (H, W) format.
81 | """
82 | old_h, old_w = original_size
83 | new_h, new_w = self.get_preprocess_shape(
84 | original_size[0], original_size[1], self.target_length
85 | )
86 | coords = deepcopy(coords).to(torch.float)
87 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
88 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
89 | return coords
90 |
91 | def apply_boxes_torch(
92 | self, boxes: torch.Tensor, original_size: Tuple[int, ...]
93 | ) -> torch.Tensor:
94 | """
95 | Expects a torch tensor with shape Bx4. Requires the original image
96 | size in (H, W) format.
97 | """
98 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
99 | return boxes.reshape(-1, 4)
100 |
101 | @staticmethod
102 | def get_preprocess_shape(
103 | oldh: int, oldw: int, long_side_length: int
104 | ) -> Tuple[int, int]:
105 | """
106 | Compute the output size given input size and target long side length.
107 | """
108 | scale = long_side_length * 1.0 / max(oldh, oldw)
109 | newh, neww = oldh * scale, oldw * scale
110 | neww = int(neww + 0.5)
111 | newh = int(newh + 0.5)
112 | return (newh, neww)
113 |
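
A small sketch of the resize transform; 1024 is SAM's usual long-side target, and the array and coordinates below are synthetic.

import numpy as np

from metaseg.utils import ResizeLongestSide

transform = ResizeLongestSide(1024)

image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
resized = transform.apply_image(image)
print(resized.shape)  # (768, 1024, 3): the long side becomes 1024

coords = np.array([[100.0, 200.0]])
print(transform.apply_coords(coords, image.shape[:2]))  # both axes scaled by 1024 / 640
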
--------------------------------------------------------------------------------
/metaseg/webapp/__init__.py:
--------------------------------------------------------------------------------
1 | from .app import metaseg_app as metaseg_app
2 |
--------------------------------------------------------------------------------
/metaseg/webapp/app.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from metaseg.webapp.demo import automask_image_app, automask_video_app, sahi_autoseg_app
3 |
4 |
5 | def image_app():
6 | with gr.Blocks():
7 | with gr.Row():
8 | with gr.Column():
9 | seg_automask_image_file = gr.Image(type="filepath").style(height=260)
10 | with gr.Row():
11 | with gr.Column():
12 | seg_automask_image_model_type = gr.Dropdown(
13 | choices=[
14 | "vit_h",
15 | "vit_l",
16 | "vit_b",
17 | ],
18 | value="vit_l",
19 | label="Model Type",
20 | )
21 |
22 | seg_automask_image_min_area = gr.Number(
23 | value=0,
24 | label="Min Area",
25 | )
26 | with gr.Row():
27 | with gr.Column():
28 | seg_automask_image_points_per_side = gr.Slider(
29 | minimum=0,
30 | maximum=32,
31 | step=2,
32 | value=16,
33 | label="Points per Side",
34 | )
35 |
36 | seg_automask_image_points_per_batch = gr.Slider(
37 | minimum=0,
38 | maximum=64,
39 | step=2,
40 | value=64,
41 | label="Points per Batch",
42 | )
43 |
44 | seg_automask_image_predict = gr.Button(value="Generator")
45 |
46 | with gr.Column():
47 | output_image = gr.Image()
48 |
49 | seg_automask_image_predict.click(
50 | fn=automask_image_app,
51 | inputs=[
52 | seg_automask_image_file,
53 | seg_automask_image_model_type,
54 | seg_automask_image_points_per_side,
55 | seg_automask_image_points_per_batch,
56 | seg_automask_image_min_area,
57 | ],
58 | outputs=[output_image],
59 | )
60 |
61 |
62 | def video_app():
63 | with gr.Blocks():
64 | with gr.Row():
65 | with gr.Column():
66 | seg_automask_video_file = gr.Video().style(height=260)
67 | with gr.Row():
68 | with gr.Column():
69 | seg_automask_video_model_type = gr.Dropdown(
70 | choices=[
71 | "vit_h",
72 | "vit_l",
73 | "vit_b",
74 | ],
75 | value="vit_l",
76 | label="Model Type",
77 | )
78 | seg_automask_video_min_area = gr.Number(
79 | value=1000,
80 | label="Min Area",
81 | )
82 |
83 | with gr.Row():
84 | with gr.Column():
85 | seg_automask_video_points_per_side = gr.Slider(
86 | minimum=0,
87 | maximum=32,
88 | step=2,
89 | value=16,
90 | label="Points per Side",
91 | )
92 |
93 | seg_automask_video_points_per_batch = gr.Slider(
94 | minimum=0,
95 | maximum=64,
96 | step=2,
97 | value=64,
98 | label="Points per Batch",
99 | )
100 |
101 | seg_automask_video_predict = gr.Button(value="Generator")
102 | with gr.Column():
103 | output_video = gr.Video()
104 |
105 | seg_automask_video_predict.click(
106 | fn=automask_video_app,
107 | inputs=[
108 | seg_automask_video_file,
109 | seg_automask_video_model_type,
110 | seg_automask_video_points_per_side,
111 | seg_automask_video_points_per_batch,
112 | seg_automask_video_min_area,
113 | ],
114 | outputs=[output_video],
115 | )
116 |
117 |
118 | def sahi_app():
119 | with gr.Blocks():
120 | with gr.Row():
121 | with gr.Column():
122 | sahi_image_file = gr.Image(type="filepath").style(height=260)
123 | sahi_autoseg_model_type = gr.Dropdown(
124 | choices=[
125 | "vit_h",
126 | "vit_l",
127 | "vit_b",
128 | ],
129 | value="vit_l",
130 | label="Sam Model Type",
131 | )
132 |
133 | with gr.Row():
134 | with gr.Column():
135 | sahi_model_type = gr.Dropdown(
136 | choices=[
137 | "yolov5",
138 | "yolov8",
139 | ],
140 | value="yolov5",
141 | label="Detector Model Type",
142 | )
143 | sahi_image_size = gr.Slider(
144 | minimum=0,
145 | maximum=1600,
146 | step=32,
147 | value=640,
148 | label="Image Size",
149 | )
150 |
151 | sahi_overlap_width = gr.Slider(
152 | minimum=0,
153 | maximum=1,
154 | step=0.1,
155 | value=0.2,
156 | label="Overlap Width",
157 | )
158 |
159 | sahi_slice_width = gr.Slider(
160 | minimum=0,
161 | maximum=640,
162 | step=32,
163 | value=256,
164 | label="Slice Width",
165 | )
166 |
167 | with gr.Row():
168 | with gr.Column():
169 | sahi_model_path = gr.Dropdown(
170 | choices=[
171 | "yolov5l.pt",
172 | "yolov5l6.pt",
173 | "yolov8l.pt",
174 | "yolov8x.pt",
175 | ],
176 | value="yolov5l6.pt",
177 | label="Detector Model Path",
178 | )
179 |
180 | sahi_conf_th = gr.Slider(
181 | minimum=0,
182 | maximum=1,
183 | step=0.1,
184 | value=0.2,
185 | label="Confidence Threshold",
186 | )
187 | sahi_overlap_height = gr.Slider(
188 | minimum=0,
189 | maximum=1,
190 | step=0.1,
191 | value=0.2,
192 | label="Overlap Height",
193 | )
194 | sahi_slice_height = gr.Slider(
195 | minimum=0,
196 | maximum=640,
197 | step=32,
198 | value=256,
199 | label="Slice Height",
200 | )
201 | sahi_image_predict = gr.Button(value="Generator")
202 |
203 | with gr.Column():
204 | output_image = gr.Image()
205 |
206 | sahi_image_predict.click(
207 | fn=sahi_autoseg_app,
208 | inputs=[
209 | sahi_image_file,
210 | sahi_autoseg_model_type,
211 | sahi_model_type,
212 | sahi_model_path,
213 | sahi_conf_th,
214 | sahi_image_size,
215 | sahi_slice_height,
216 | sahi_slice_width,
217 | sahi_overlap_height,
218 | sahi_overlap_width,
219 | ],
220 | outputs=[output_image],
221 | )
222 |
223 |
224 | def metaseg_app():
225 | app = gr.Blocks()
226 | with app:
227 | with gr.Row():
228 | with gr.Column():
229 | with gr.Tab("Image"):
230 | image_app()
231 | with gr.Tab("Video"):
232 | video_app()
233 | with gr.Tab("SAHI"):
234 | sahi_app()
235 |
236 | app.queue(concurrency_count=1)
237 | app.launch(debug=True, enable_queue=True)
238 |
239 |
240 | if __name__ == "__main__":
241 | metaseg_app()
242 |
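
A minimal launch sketch, assuming gradio and the optional detector extras (yolov5/ultralytics) are installed; neither is pulled in by the base install.

from metaseg.webapp import metaseg_app

if __name__ == "__main__":
    metaseg_app()  # serves the Image, Video, and SAHI tabs as a local Gradio app
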
--------------------------------------------------------------------------------
/metaseg/webapp/demo.py:
--------------------------------------------------------------------------------
1 | from metaseg import (
2 | SahiAutoSegmentation,
3 | SegAutoMaskPredictor,
4 | SegManualMaskPredictor,
5 | sahi_sliced_predict,
6 | )
7 |
8 | # For image
9 |
10 |
11 | def automask_image_app(
12 | image_path, model_type, points_per_side, points_per_batch, min_area
13 | ):
14 | SegAutoMaskPredictor().image_predict(
15 | source=image_path,
16 | model_type=model_type, # vit_l, vit_h, vit_b
17 | points_per_side=points_per_side,
18 | points_per_batch=points_per_batch,
19 | min_area=min_area,
20 | output_path="output.png",
21 | show=False,
22 | save=True,
23 | )
24 | return "output.png"
25 |
26 |
27 | # For video
28 |
29 |
30 | def automask_video_app(
31 | video_path, model_type, points_per_side, points_per_batch, min_area
32 | ):
33 | SegAutoMaskPredictor().video_predict(
34 | source=video_path,
35 | model_type=model_type, # vit_l, vit_h, vit_b
36 | points_per_side=points_per_side,
37 | points_per_batch=points_per_batch,
38 | min_area=min_area,
39 | output_path="output.mp4",
40 | )
41 | return "output.mp4"
42 |
43 |
44 | # For manual box and point selection
45 |
46 |
47 | def manual_app(
48 | image_path,
49 | model_type,
50 | input_point,
51 | input_label,
52 | input_box,
53 | multimask_output,
54 | random_color,
55 | ):
56 | SegManualMaskPredictor().image_predict(
57 | source=image_path,
58 | model_type=model_type, # vit_l, vit_h, vit_b
59 | input_point=input_point,
60 | input_label=input_label,
61 | input_box=input_box,
62 | multimask_output=multimask_output,
63 | random_color=random_color,
64 | output_path="output.png",
65 | show=False,
66 | save=True,
67 | )
68 | return "output.png"
69 |
70 |
71 | # For sahi sliced prediction
72 |
73 |
74 | def sahi_autoseg_app(
75 | image_path,
76 | sam_model_type,
77 | detection_model_type,
78 | detection_model_path,
79 | conf_th,
80 | image_size,
81 | slice_height,
82 | slice_width,
83 | overlap_height_ratio,
84 | overlap_width_ratio,
85 | ):
86 | boxes = sahi_sliced_predict(
87 | image_path=image_path,
88 | # yolov8, detectron2, mmdetection, torchvision
89 | detection_model_type=detection_model_type,
90 | detection_model_path=detection_model_path,
91 | conf_th=conf_th,
92 | image_size=image_size,
93 | slice_height=slice_height,
94 | slice_width=slice_width,
95 | overlap_height_ratio=overlap_height_ratio,
96 | overlap_width_ratio=overlap_width_ratio,
97 | )
98 |
99 | SahiAutoSegmentation().image_predict(
100 | source=image_path,
101 | model_type=sam_model_type,
102 | input_box=boxes,
103 | multimask_output=False,
104 | random_color=False,
105 | show=False,
106 | save=True,
107 | )
108 |
109 | return "output.png"
110 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "metaseg"
3 | version = "0.7.8"
4 | description = "MetaSeg: Packaged version of the Segment Anything repository"
5 | authors = ["Kadir Nar "]
6 | maintainers = ["Kadir Nar "]
7 | readme = "README.md"
8 | packages = [{include = "metaseg"}]
9 | homepage = "https://github.com/kadirnar/segment-anything-video"
10 | repository = "https://github.com/kadirnar/segment-anything-video"
11 | documentation = "https://github.com/kadirnar/segment-anything-video/blob/main/README.md"
12 | keywords = ["pytorch","segment-anything-video","metaseg"]
13 | license = "Apache-2.0"
14 | classifiers = [
15 | "Development Status :: 5 - Production/Stable",
16 | "License :: OSI Approved :: Apache Software License",
17 | "Natural Language :: English",
18 | "Programming Language :: Python",
19 | "Programming Language :: Python :: 3",
20 | "Programming Language :: Python :: 3 :: Only",
21 | "Programming Language :: Python :: 3.8",
22 | "Programming Language :: Python :: 3.9",
23 | "Programming Language :: Python :: 3.10",
24 | "Programming Language :: Python :: 3.11",
25 | "Operating System :: OS Independent",
26 | "Topic :: Software Development :: Libraries :: Application Frameworks",
27 | "Topic :: Software Development :: Libraries :: Python Modules",
28 | "Topic :: Software Development :: Libraries",
29 | "Topic :: Software Development",
30 | "Topic :: Scientific/Engineering",
31 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
32 | "Topic :: Scientific/Engineering :: Mathematics",
33 | "Topic :: Scientific/Engineering :: Image Recognition",
34 | "Topic :: Scientific/Engineering :: Image Processing",
35 | "Intended Audience :: Developers",
36 | "Intended Audience :: Science/Research",
37 | "Intended Audience :: Education",
38 | ]
39 |
40 |
41 | [tool.poetry.dependencies]
42 | python = ">=3.8.1,<3.12.0"
43 | torch = "^2.0.1"
44 | torchvision = "^0.15.2"
45 | opencv-python = "^4.7.0.72"
46 | tqdm = "^4.65.0"
47 | matplotlib = "^3.7.1"
48 | pillow = "^9.5.0"
49 | pycocotools = "^2.0.6"
50 | fal-serverless = "^0.6.35"
51 | sahi = "^0.11.14"
52 | onnx = { version = "^1.14.0", optional = true }
53 | onnxruntime = { version ="^1.15.1", optional = true }
54 | ultralytics = { version = "^8.0.123", optional = true }
55 | yolov5 = { version ="^7.0.12", optional = true }
56 | requests = "^2.31.0"
57 |
58 |
59 | [tool.poetry.extras]
60 | full = ["onnxruntime","onnx","yolov5","ultralytics"]
61 | yolov5 = ["yolov5"]
62 | yolov8 = ["ultralytics"]
63 |
64 |
65 | [tool.poetry.group.dev.dependencies]
66 | black = "^23.1.0"
67 | mypy = "^1.0.1"
68 | bandit = "^1.7.4"
69 | debugpy = "^1.6.6"
70 | rope = "^1.7.0"
71 | wheel = "^0.38.4"
72 | setuptools = "^67.4.0"
73 | coverage = "^7.2.1"
74 | pre-commit = "^3.1.1"
75 | pyupgrade = "^3.3.1"
76 | ruff = "^0.0.244"
77 | pytest = "^7.2.1"
78 | toml = "^0.10.2"
79 | flake8 = "^6.0.0"
80 | isort = "^5.12.0"
81 | parameterized = "^0.9.0"
82 |
83 |
84 |
85 | [tool.isort]
86 | line_length = 88
87 | profile = "black"
88 |
89 | [tool.bandit]
90 | target = ["tests", "metaseg"]
91 | tests = ["B201", "B301"]
92 |
93 | [tool.autoflake]
94 | check = true
95 | imports = ["cv2", "requests", "metaseg"]
96 |
97 |
98 | [tool.black]
99 | line-length = 88
100 | include = '\.pyi?$'
101 | exclude = '''
102 | /(
103 | \.git
104 | | \.hg
105 | | \.mypy_cache
106 | | \.tox
107 | | \.venv
108 | | _build
109 | | buck-out
110 | | build
111 | | dist
112 | )/
113 | '''
114 |
115 | [tool.ruff]
116 | # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
117 | select = ["E", "F"]
118 | ignore = []
119 |
120 | # Allow autofix for all enabled rules (when `--fix` is provided).
121 | fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
122 | unfixable = []
123 |
124 | # Exclude a variety of commonly ignored directories.
125 | exclude = [
126 | ".bzr",
127 | ".direnv",
128 | ".eggs",
129 | ".git",
130 | ".git-rewrite",
131 | ".hg",
132 | ".mypy_cache",
133 | ".nox",
134 | ".pants.d",
135 | ".pytype",
136 | ".ruff_cache",
137 | ".svn",
138 | ".tox",
139 | ".venv",
140 | "__pypackages__",
141 | "_build",
142 | "buck-out",
143 | "build",
144 | "dist",
145 | "node_modules",
146 | "venv",
147 | ]
148 |
149 | # Same as Black.
150 | line-length = 88
151 |
152 | # Allow unused variables when underscore-prefixed.
153 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
154 |
155 | [build-system]
156 | requires = ["poetry-core"]
157 | build-backend = "poetry.core.masonry.api"
158 |
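159 | # Editorial note (a sketch, not a Poetry setting): the extras declared above are
160 | # opted into at install time, e.g. `poetry install -E full` from a checkout of the
161 | # repository, or `pip install "metaseg[yolov8]"` against a published release.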
--------------------------------------------------------------------------------
/scripts/amg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import json
9 | import os
10 | from typing import Any, Dict, List
11 |
12 | import cv2
13 |
14 | from metaseg.generator import SamAutomaticMaskGenerator, sam_model_registry
15 |
16 | parser = argparse.ArgumentParser(
17 | description=(
18 | "Runs automatic mask generation on an input image or directory of images, "
19 | "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
20 | "as well as pycocotools if saving in RLE format."
21 | )
22 | )
23 |
24 | parser.add_argument(
25 | "--input",
26 | type=str,
27 | required=True,
28 | help="Path to either a single input image or folder of images.",
29 | )
30 |
31 | parser.add_argument(
32 | "--output",
33 | type=str,
34 | required=True,
35 | help=(
36 | "Path to the directory where masks will be output. "
37 | "Output will be either a folder "
38 | "of PNGs per image or a single json with COCO-style masks."
39 | ),
40 | )
41 |
42 | parser.add_argument(
43 | "--model-type",
44 | type=str,
45 | default="default",
46 | help="The type of model to load, in ['default', 'vit_l', 'vit_b']",
47 | )
48 |
49 | parser.add_argument(
50 | "--checkpoint",
51 | type=str,
52 | required=True,
53 | help="The path to the SAM checkpoint to use for mask generation.",
54 | )
55 |
56 | parser.add_argument(
57 | "--device", type=str, default="cuda", help="The device to run generation on."
58 | )
59 |
60 | parser.add_argument(
61 | "--convert-to-rle",
62 | action="store_true",
63 | help=(
64 | "Save masks as COCO RLEs in a single json "
65 | "instead of as a folder of PNGs. "
66 | "Requires pycocotools."
67 | ),
68 | )
69 |
70 | amg_settings = parser.add_argument_group("AMG Settings")
71 |
72 | amg_settings.add_argument(
73 | "--points-per-side",
74 | type=int,
75 | default=None,
76 | help="Generate masks by sampling a grid over "
77 | "the image with this many points to a side.",
78 | )
79 |
80 | amg_settings.add_argument(
81 | "--points-per-batch",
82 | type=int,
83 | default=None,
84 | help="How many input points to process " "simultaneously in one batch.",
85 | )
86 |
87 | amg_settings.add_argument(
88 | "--pred-iou-thresh",
89 | type=float,
90 | default=None,
91 | help="Exclude masks with a predicted score from "
92 | "the model that is lower than this threshold.",
93 | )
94 |
95 | amg_settings.add_argument(
96 | "--stability-score-thresh",
97 | type=float,
98 | default=None,
99 | help="Exclude masks with a stability " "score lower than this threshold.",
100 | )
101 |
102 | amg_settings.add_argument(
103 | "--stability-score-offset",
104 | type=float,
105 | default=None,
106 | help="Larger values perturb the mask " "more when measuring stability score.",
107 | )
108 |
109 | amg_settings.add_argument(
110 | "--box-nms-thresh",
111 | type=float,
112 | default=None,
113 | help="The overlap threshold for excluding a duplicate mask.",
114 | )
115 |
116 | amg_settings.add_argument(
117 | "--crop-n-layers",
118 | type=int,
119 | default=None,
120 | help=(
121 | "If >0, mask generation is run on smaller "
122 | "crops of the image to generate more masks. "
123 | "The value sets how many different scales to crop at."
124 | ),
125 | )
126 |
127 | amg_settings.add_argument(
128 | "--crop-nms-thresh",
129 | type=float,
130 | default=None,
131 | help="The overlap threshold for excluding "
132 | "duplicate masks across different crops.",
133 | )
134 |
135 | amg_settings.add_argument(
136 | "--crop-overlap-ratio",
137 | type=float,
138 | default=None,
139 | help="Larger numbers mean image crops will overlap more.",
140 | )
141 |
142 | amg_settings.add_argument(
143 | "--crop-n-points-downscale-factor",
144 | type=int,
145 | default=None,
146 | help="The number of points-per-side in each "
147 | "layer of crop is reduced by this factor.",
148 | )
149 |
150 | amg_settings.add_argument(
151 | "--min-mask-region-area",
152 | type=int,
153 | default=None,
154 | help=(
155 | "Disconnected mask regions or holes with "
156 | "area smaller than this value "
157 | "in pixels are removed by postprocessing."
158 | ),
159 | )
160 |
161 |
162 | def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
163 | header = (
164 | "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,"
165 | "point_input_x,point_input_y,predicted_iou,"
166 | "stability_score,crop_box_x0,crop_box_y0,"
167 | "crop_box_w,crop_box_h"
168 | )
169 | metadata = [header]
170 | for i, mask_data in enumerate(masks):
171 | mask = mask_data["segmentation"]
172 | filename = f"{i}.png"
173 | cv2.imwrite(os.path.join(path, filename), mask * 255)
174 | mask_metadata = [
175 | str(i),
176 | str(mask_data["area"]),
177 | *[str(x) for x in mask_data["bbox"]],
178 | *[str(x) for x in mask_data["point_coords"][0]],
179 | str(mask_data["predicted_iou"]),
180 | str(mask_data["stability_score"]),
181 | *[str(x) for x in mask_data["crop_box"]],
182 | ]
183 | row = ",".join(mask_metadata)
184 | metadata.append(row)
185 | metadata_path = os.path.join(path, "metadata.csv")
186 | with open(metadata_path, "w") as f:
187 | f.write("\n".join(metadata))
188 |
189 | return
190 |
191 |
192 | def get_amg_kwargs(args):
193 | amg_kwargs = {
194 | "points_per_side": args.points_per_side,
195 | "points_per_batch": args.points_per_batch,
196 | "pred_iou_thresh": args.pred_iou_thresh,
197 | "stability_score_thresh": args.stability_score_thresh,
198 | "stability_score_offset": args.stability_score_offset,
199 | "box_nms_thresh": args.box_nms_thresh,
200 | "crop_n_layers": args.crop_n_layers,
201 | "crop_nms_thresh": args.crop_nms_thresh,
202 | "crop_overlap_ratio": args.crop_overlap_ratio,
203 | "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
204 | "min_mask_region_area": args.min_mask_region_area,
205 | }
206 | amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
207 | return amg_kwargs
208 |
209 |
210 | def main(args: argparse.Namespace) -> None:
211 | print("Loading model...")
212 | sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
213 | _ = sam.to(device=args.device)
214 | output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
215 | amg_kwargs = get_amg_kwargs(args)
216 | generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)
217 |
218 | if not os.path.isdir(args.input):
219 | targets = [args.input]
220 | else:
221 | targets = [
222 | f
223 | for f in os.listdir(args.input)
224 | if not os.path.isdir(os.path.join(args.input, f))
225 | ]
226 | targets = [os.path.join(args.input, f) for f in targets]
227 |
228 | os.makedirs(args.output, exist_ok=True)
229 |
230 | for t in targets:
231 | print(f"Processing '{t}'...")
232 | image = cv2.imread(t)
233 | if image is None:
234 | print(f"Could not load '{t}' as an image, skipping...")
235 | continue
236 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
237 |
238 | masks = generator.generate(image)
239 |
240 | base = os.path.basename(t)
241 | base = os.path.splitext(base)[0]
242 | save_base = os.path.join(args.output, base)
243 | if output_mode == "binary_mask":
244 | os.makedirs(save_base, exist_ok=False)
245 | write_masks_to_folder(masks, save_base)
246 | else:
247 | save_file = save_base + ".json"
248 | with open(save_file, "w") as f:
249 | json.dump(masks, f)
250 | print("Done!")
251 |
252 |
253 | if __name__ == "__main__":
254 | args = parser.parse_args()
255 | main(args)
256 |
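257 | # Example invocation (illustrative only; the checkpoint filename and the
258 | # input/output paths are assumptions to be replaced with your own):
259 | #
260 | #   python scripts/amg.py \
261 | #       --checkpoint sam_vit_h_4b8939.pth \
262 | #       --model-type default \
263 | #       --input ./images \
264 | #       --output ./masks \
265 | #       --points-per-side 32 \
266 | #       --convert-to-rle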
--------------------------------------------------------------------------------
/scripts/export_onnx_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import warnings
9 |
10 | import torch
11 | from onnxruntime import InferenceSession
12 | from onnxruntime.quantization import QuantType
13 | from onnxruntime.quantization.quantize import quantize_dynamic
14 |
15 | from metaseg.generator import build_sam, build_sam_vit_b, build_sam_vit_l
16 | from metaseg.utils.onnx import SamOnnxModel
17 |
18 | parser = argparse.ArgumentParser(
19 | description="Export the SAM prompt encoder and mask decoder to an ONNX model."
20 | )
21 |
22 | parser.add_argument(
23 | "--checkpoint",
24 | type=str,
25 | required=True,
26 | help="The path to the SAM model checkpoint.",
27 | )
28 |
29 | parser.add_argument(
30 | "--output", type=str, required=True, help="The filename to save the ONNX model to."
31 | )
32 |
33 | parser.add_argument(
34 | "--model-type",
35 | type=str,
36 | default="default",
37 | help="In ['default', 'vit_b', 'vit_l']. Which type of SAM model to export.",
38 | )
39 |
40 | parser.add_argument(
41 | "--return-single-mask",
42 | action="store_true",
43 | help=(
44 | "If true, the exported ONNX model will only return the best mask, "
45 | "instead of returning multiple masks. For high resolution images "
46 | "this can improve runtime when upscaling masks is expensive."
47 | ),
48 | )
49 |
50 | parser.add_argument(
51 | "--opset",
52 | type=int,
53 | default=17,
54 | help="The ONNX opset version to use. Must be >=11",
55 | )
56 |
57 | parser.add_argument(
58 | "--quantize-out",
59 | type=str,
60 | default=None,
61 | help=(
62 | "If set, will quantize the model and save it with this name. "
63 | "Quantization is performed with quantize_dynamic "
64 | "from onnxruntime.quantization.quantize."
65 | ),
66 | )
67 |
68 | parser.add_argument(
69 | "--gelu-approximate",
70 | action="store_true",
71 | help=(
72 | "Replace GELU operations with approximations using tanh. Useful "
73 | "for some runtimes that have slow or unimplemented erf ops, used in GELU."
74 | ),
75 | )
76 |
77 | parser.add_argument(
78 | "--use-stability-score",
79 | action="store_true",
80 | help=(
81 | "Replaces the model's predicted mask quality score with the stability "
82 | "score calculated on the low resolution masks using an offset of 1.0. "
83 | ),
84 | )
85 |
86 | parser.add_argument(
87 | "--return-extra-metrics",
88 | action="store_true",
89 | help=(
90 | "The model will return five results: (masks, scores, stability_scores, "
91 | "areas, low_res_logits) instead of the usual three. This can be "
92 | "significantly slower for high resolution outputs."
93 | ),
94 | )
95 |
96 |
97 | def run_export(
98 | model_type: str,
99 | checkpoint: str,
100 | output: str,
101 | opset: int,
102 | return_single_mask: bool,
103 | gelu_approximate: bool = False,
104 | use_stability_score: bool = False,
105 | return_extra_metrics=False,
106 | ):
107 | print("Loading model...")
108 | if model_type == "vit_b":
109 | sam = build_sam_vit_b(checkpoint)
110 | elif model_type == "vit_l":
111 | sam = build_sam_vit_l(checkpoint)
112 | else:
113 | sam = build_sam(checkpoint)
114 |
115 | onnx_model = SamOnnxModel(
116 | model=sam,
117 | return_single_mask=return_single_mask,
118 | use_stability_score=use_stability_score,
119 | return_extra_metrics=return_extra_metrics,
120 | )
121 |
122 | if gelu_approximate:
123 | for n, m in onnx_model.named_modules():
124 | if isinstance(m, torch.nn.GELU):
125 | m.approximate = "tanh"
126 |
127 | dynamic_axes = {
128 | "point_coords": {1: "num_points"},
129 | "point_labels": {1: "num_points"},
130 | }
131 |
132 | embed_dim = sam.prompt_encoder.embed_dim
133 | embed_size = sam.prompt_encoder.image_embedding_size
134 | mask_input_size = [4 * x for x in embed_size]
135 | dummy_inputs = {
136 | "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
137 | "point_coords": torch.randint(
138 | low=0, high=1024, size=(1, 5, 2), dtype=torch.float
139 | ),
140 | "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
141 | "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
142 | "has_mask_input": torch.tensor([1], dtype=torch.float),
143 | "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
144 | }
145 |
146 | _ = onnx_model(**dummy_inputs)
147 |
148 | output_names = ["masks", "iou_predictions", "low_res_masks"]
149 |
150 | with warnings.catch_warnings():
151 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
152 | warnings.filterwarnings("ignore", category=UserWarning)
153 | with open(output, "wb") as f:
154 | print(f"Exporing onnx model to {output}...")
155 | torch.onnx.export(
156 | onnx_model,
157 | tuple(dummy_inputs.values()),
158 | f,
159 | export_params=True,
160 | verbose=False,
161 | opset_version=opset,
162 | do_constant_folding=True,
163 | input_names=list(dummy_inputs.keys()),
164 | output_names=output_names,
165 | dynamic_axes=dynamic_axes,
166 | )
167 |
168 | ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
169 | ort_session = InferenceSession(output)
170 | _ = ort_session.run(None, ort_inputs)
171 | print("Model has successfully been run with ONNXRuntime.")
172 |
173 |
174 | def to_numpy(tensor):
175 | return tensor.cpu().numpy()
176 |
177 |
178 | if __name__ == "__main__":
179 | args = parser.parse_args()
180 | run_export(
181 | model_type=args.model_type,
182 | checkpoint=args.checkpoint,
183 | output=args.output,
184 | opset=args.opset,
185 | return_single_mask=args.return_single_mask,
186 | gelu_approximate=args.gelu_approximate,
187 | use_stability_score=args.use_stability_score,
188 | return_extra_metrics=args.return_extra_metrics,
189 | )
190 |
191 | print(f"Quantizing model and writing to {args.quantize_out}...")
192 | quantize_dynamic(
193 | model_input=args.output,
194 | model_output=args.quantize_out,
195 | optimize_model=True,
196 | per_channel=False,
197 | reduce_range=False,
198 | weight_type=QuantType.QUInt8,
199 | )
200 | print("Done!")
201 |
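203 | # Example invocation (illustrative only; the checkpoint and output filenames are
204 | # assumptions to be replaced with your own):
205 | #
206 | #   python scripts/export_onnx_model.py \
207 | #       --checkpoint sam_vit_b_01ec64.pth \
208 | #       --model-type vit_b \
209 | #       --output sam_vit_b.onnx \
210 | #       --return-single-mask \
211 | #       --quantize-out sam_vit_b_quantized.onnx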
--------------------------------------------------------------------------------