├── .gitignore ├── LICENSE ├── README.md ├── backend ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets │ ├── data │ │ └── dogs.jpg │ ├── masks1.png │ ├── masks2.jpg │ ├── model_diagram.png │ ├── notebook1.png │ └── notebook2.png ├── embedded.py ├── linter.sh ├── main.py ├── notebooks │ ├── automatic_mask_generator_example.ipynb │ ├── images │ │ ├── dog.jpg │ │ ├── groceries.jpg │ │ └── truck.jpg │ ├── onnx_model_example.ipynb │ └── predictor_example.ipynb ├── put your model sam_vit_b.pth ├── scripts │ ├── amg.py │ └── export_onnx_model.py ├── segment_anything.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ ├── requires.txt │ └── top_level.txt ├── segment_anything │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── build_sam.py │ ├── modeling │ │ ├── __init__.py │ │ ├── common.py │ │ ├── image_encoder.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ ├── sam.py │ │ └── transformer.py │ ├── predictor.py │ └── utils │ │ ├── __init__.py │ │ ├── amg.py │ │ ├── onnx.py │ │ └── transforms.py ├── setup.cfg └── setup.py └── frontend ├── .gitignore ├── README.md ├── configs └── webpack │ ├── common.js │ ├── dev.js │ └── prod.js ├── model ├── meta_multi_onnx.onnx ├── meta_onnx.onnx └── onnx_example.onnx ├── package-lock.json ├── package.json ├── postcss.config.js ├── readme ├── demo.gif └── demo.mkv ├── src ├── App.tsx ├── assets │ ├── Github.svg │ ├── Meta.svg │ ├── arrow-icn.svg │ ├── chairs.png │ ├── circle-plus.svg │ ├── dataset.png │ ├── gallery │ │ ├── 1.jpg │ │ ├── 1.jpg.npy │ │ ├── 2.jpg │ │ ├── 2.jpg.npy │ │ ├── 3.jpg │ │ ├── 3.jpg.npy │ │ ├── 4.jpg │ │ ├── 4.jpg.npy │ │ ├── 5.jpg │ │ ├── 5.jpg.npy │ │ ├── 6.jpg │ │ ├── 7.jpg │ │ ├── 8.jpg │ │ └── 8.jpg.npy │ ├── hamburger.svg │ ├── horses.png │ ├── icn-image-gallery.svg │ ├── icn-nn.svg │ ├── index.html │ ├── scss │ │ └── App.scss │ ├── stack.svg │ └── upload_arrow.svg ├── components │ ├── Canvas.tsx │ ├── ErrorPage.tsx │ ├── FeatureSummary.tsx │ ├── FeedbackModal.tsx │ ├── Footer.tsx │ ├── ImagePicker.tsx │ ├── MobileOptionNavBar.tsx │ ├── MobileSegmentDrawer.tsx │ ├── NavBar.tsx │ ├── SegmentDrawer.tsx │ ├── SegmentOptions.tsx │ ├── Sparkle.tsx │ ├── Stage.tsx │ ├── SvgMask.tsx │ ├── ToolTip.tsx │ ├── helpers │ │ ├── CanvasHelper.tsx │ │ ├── Interfaces.tsx │ │ ├── colors.tsx │ │ ├── files.tsx │ │ ├── maskUtils.tsx │ │ ├── metaTheme.tsx │ │ ├── onnxModelAPI.tsx │ │ ├── photos.tsx │ │ ├── scaleHelper.tsx │ │ └── trace.tsx │ └── hooks │ │ ├── Animation.tsx │ │ ├── context.tsx │ │ └── createContext.tsx ├── enviroments.tsx └── index.tsx ├── tailwind.config.js └── tsconfig.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | npm-debug.log* 5 | yarn-debug.log* 6 | yarn-error.log* 7 | lerna-debug.log* 8 | 9 | # Diagnostic reports (https://nodejs.org/api/report.html) 10 | report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json 11 | 12 | # Runtime data 13 | pids 14 | *.pid 15 | *.seed 16 | *.pid.lock 17 | 18 | # Directory for instrumented libs generated by jscoverage/JSCover 19 | lib-cov 20 | 21 | # Coverage directory used by tools like istanbul 22 | coverage 23 | *.lcov 24 | 25 | # nyc test coverage 26 | .nyc_output 27 | 28 | # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) 29 | .grunt 30 | 31 | # Bower dependency directory (https://bower.io/) 32 | bower_components 33 | 34 | # node-waf configuration 35 | .lock-wscript 36 | 37 | # Compiled binary addons 
(https://nodejs.org/api/addons.html) 38 | build/Release 39 | 40 | # Dependency directories 41 | node_modules/ 42 | jspm_packages/ 43 | 44 | # TypeScript v1 declaration files 45 | typings/ 46 | 47 | # TypeScript cache 48 | *.tsbuildinfo 49 | 50 | # Optional npm cache directory 51 | .npm 52 | 53 | # Optional eslint cache 54 | .eslintcache 55 | 56 | # Microbundle cache 57 | .rpt2_cache/ 58 | .rts2_cache_cjs/ 59 | .rts2_cache_es/ 60 | .rts2_cache_umd/ 61 | 62 | # Optional REPL history 63 | .node_repl_history 64 | 65 | # Output of 'npm pack' 66 | *.tgz 67 | 68 | # Yarn Integrity file 69 | .yarn-integrity 70 | 71 | # dotenv environment variables file 72 | .env 73 | .env.test 74 | 75 | # parcel-bundler cache (https://parceljs.org/) 76 | .cache 77 | 78 | # Next.js build output 79 | .next 80 | 81 | # Nuxt.js build / generate output 82 | .nuxt 83 | dist 84 | 85 | # Gatsby files 86 | .cache/ 87 | # Comment in the public line in if your project uses Gatsby and *not* Next.js 88 | # https://nextjs.org/blog/next-9-1#public-directory-support 89 | # public 90 | 91 | # vuepress build output 92 | .vuepress/dist 93 | 94 | # Serverless directories 95 | .serverless/ 96 | 97 | # FuseBox cache 98 | .fusebox/ 99 | 100 | # DynamoDB Local files 101 | .dynamodb/ 102 | 103 | # TernJS port file 104 | .tern-port 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # segment-anything-demo 2 | This repository is a sample implementation of frontend/backend using SAM code from meta. 3 | 4 | [![video](frontend/readme/demo.gif)](https://www.youtube.com/watch?v=e9Aj4llndvs) 5 | 6 | We have decided to use SAM from Meta for the development and improvement of our debugging tool. Please refer to the URL below for detailed information on the code. 
7 | 8 | [Meta : segment-anything](https://github.com/facebookresearch/segment-anything) 9 | 10 | Please read the README.md files in the backend and frontend folders. 11 | 12 | ## version 13 | 14 | The following are the versions required to run the program. Please refer to them. 15 | 16 | ### front-end 17 | 18 | ```cmd 19 | node --version 20 | v18.16.0 21 | ``` 22 | 23 | ### back-end 24 | 25 | ```cmd 26 | python --version 27 | 3.10.10 28 | ``` 29 | 30 | Please leave any questions or inquiries as an issue. Thank you. 😊 31 | -------------------------------------------------------------------------------- /backend/.gitignore: -------------------------------------------------------------------------------- 1 | .pytest_cache 2 | __pycache__ 3 | *.pyc 4 | .vscode/ 5 | venv/ 6 | build/ 7 | sam_vit_b.pth -------------------------------------------------------------------------------- /backend/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /backend/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to segment-anything 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints, using the `linter.sh` script in the project's root directory. Linting requires `black==23.*`, `isort==5.12.0`, `flake8`, and `mypy`. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to segment-anything, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /backend/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /backend/README.md: -------------------------------------------------------------------------------- 1 | # Segment Anything 2 | 3 | **[Meta AI Research, FAIR](https://ai.facebook.com/research/)** 4 | 5 | [Alexander Kirillov](https://alexander-kirillov.github.io/), [Eric Mintun](https://ericmintun.github.io/), [Nikhila Ravi](https://nikhilaravi.com/), [Hanzi Mao](https://hanzimao.me/), Chloe Rolland, Laura Gustafson, [Tete Xiao](https://tetexiao.com), [Spencer Whitehead](https://www.spencerwhitehead.com/), Alex Berg, Wan-Yen Lo, [Piotr Dollar](https://pdollar.github.io/), [Ross Girshick](https://www.rossgirshick.info/) 6 | 7 | [[`Paper`](https://ai.facebook.com/research/publications/segment-anything/)] [[`Project`](https://segment-anything.com/)] [[`Demo`](https://segment-anything.com/demo)] [[`Dataset`](https://segment-anything.com/dataset/index.html)] [[`Blog`](https://ai.facebook.com/blog/segment-anything-foundation-model-image-segmentation/)] [[`BibTeX`](#citing-segment-anything)] 8 | 9 | ![SAM design](assets/model_diagram.png?raw=true) 10 | 11 | The **Segment Anything Model (SAM)** produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a [dataset](https://segment-anything.com/dataset/index.html) of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks. 12 | 13 |

14 | ![example masks](assets/masks1.png?raw=true) 15 | ![example masks](assets/masks2.jpg?raw=true) 16 | 

17 | 18 | ## Installation 19 | 20 | The code requires `python=3.8 ~ 3.10`, as well as `pytorch>=1.7` and `torchvision>=0.8`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended. 21 | 22 | 23 | 24 | Install Segment Anything: 25 | 26 | ``` 27 | pip install -e . 28 | ``` 29 | 30 | The following optional dependencies are necessary for mask post-processing, saving masks in COCO format, the example notebooks, and exporting the model in ONNX format. `jupyter` is also required to run the example notebooks. 31 | ``` 32 | pip install opencv-python pycocotools matplotlib onnxruntime onnx 33 | ``` 34 | 35 | To use FastAPI, you first need to install the `fastapi` package and the `uvicorn` web server. To install them, run the following command in your terminal. 36 | ``` 37 | pip install fastapi uvicorn[standard] 38 | ``` 39 | 40 | To run the model on a GPU, install the PyTorch build that matches your graphics card; the command below targets CUDA 11.3, so adjust the index URL to your CUDA version. 41 | ``` 42 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu113 43 | ``` 44 | 45 | Then follow the instructions below to download and set up the model checkpoint that the backend uses to compute image embeddings. 46 | 47 | 48 | ## Getting Started 49 | 50 | First download a [model checkpoint](#model-checkpoints). Then the model can be used in just a few lines to get masks from a given prompt: 51 | 52 | ``` 53 | from segment_anything import SamPredictor, sam_model_registry 54 | sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>") 55 | predictor = SamPredictor(sam) 56 | predictor.set_image(<your_image>) 57 | masks, _, _ = predictor.predict(<input_prompts>) 58 | ``` 59 | 60 | or generate masks for an entire image: 61 | 62 | ``` 63 | from segment_anything import SamAutomaticMaskGenerator, sam_model_registry 64 | sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>") 65 | mask_generator = SamAutomaticMaskGenerator(sam) 66 | masks = mask_generator.generate(<your_image>) 67 | ``` 68 | 69 | Additionally, masks can be generated for images from the command line: 70 | 71 | ``` 72 | python scripts/amg.py --checkpoint <path/to/checkpoint> --model-type <model_type> --input <image_or_folder> --output <path/to/output> 73 | ``` 74 | 75 | See the example notebooks on [using SAM with prompts](/notebooks/predictor_example.ipynb) and [automatically generating masks](/notebooks/automatic_mask_generator_example.ipynb) for more details. 76 | 77 |
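As a concrete companion to the snippets above, the sketch below runs a single point prompt end to end with the ViT-B variant. The image path, the point coordinates, and the checkpoint location are illustrative placeholders (the `sam_vit_b.pth` name matches the checkpoint this repository's backend expects), not values prescribed by SAM.

```python
import cv2
import numpy as np

from segment_anything import SamPredictor, sam_model_registry

# Read an image and convert BGR -> RGB, since the predictor expects RGB input.
image = cv2.imread("assets/data/dogs.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Build the ViT-B model from a locally downloaded checkpoint.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")
predictor = SamPredictor(sam)
predictor.set_image(image)

# One foreground point prompt (label 1 = foreground, 0 = background).
masks, scores, low_res_logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape)  # (num_masks, H, W) boolean masks, one per candidate
```

With `multimask_output=True` the predictor returns several candidate masks together with their predicted quality scores, so you can keep the highest-scoring one.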

78 | ![notebook example](assets/notebook1.png?raw=true) 79 | ![notebook example](assets/notebook2.png?raw=true) 80 | 
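The automatic mask generation path from Getting Started can be made concrete in the same hedged way; the image path below is again only an example.

```python
import cv2

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Load an example image and convert BGR -> RGB.
image = cv2.imread("assets/data/dogs.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Build the ViT-B variant and generate masks for the whole image.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)

# Each entry is a dict holding the binary mask ("segmentation") plus metadata
# such as "area", "bbox" (XYWH), "predicted_iou", and "stability_score".
print(len(masks), masks[0]["bbox"], masks[0]["predicted_iou"])
```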

81 | 82 | ## ONNX Export 83 | 84 | SAM's lightweight mask decoder can be exported to ONNX format so that it can be run in any environment that supports ONNX runtime, such as in-browser as showcased in the [demo](https://segment-anything.com/demo). Export the model with 85 | 86 | ``` 87 | python scripts/export_onnx_model.py --checkpoint <path/to/checkpoint> --model-type <model_type> --output <path/to/output> 88 | ``` 89 | 90 | See the [example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb) for details on how to combine image preprocessing via SAM's backbone with mask prediction using the ONNX model. It is recommended to use the latest stable version of PyTorch for ONNX export. 91 | 92 | ## Model Checkpoints 93 | 94 | Three versions of the model are available with different backbone sizes. These models can be instantiated by running 95 | ``` 96 | from segment_anything import sam_model_registry 97 | sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>") 98 | ``` 99 | Click the links below to download the checkpoint for the corresponding model type. 100 | 101 | * **`default` or `vit_h`: [ViT-H SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)** 102 | * `vit_l`: [ViT-L SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth) 103 | * `vit_b`: [ViT-B SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) 104 | 105 | ## Dataset 106 | See [here](https://ai.facebook.com/datasets/segment-anything/) for an overview of the dataset. The dataset can be downloaded [here](https://ai.facebook.com/datasets/segment-anything-downloads/). By downloading the datasets you agree that you have read and accepted the terms of the SA-1B Dataset Research License. 107 | 108 | We save masks per image as a JSON file. It can be loaded as a dictionary in Python in the format below. 109 | 110 | 111 | ```python 112 | { 113 | "image" : image_info, 114 | "annotations" : [annotation], 115 | } 116 | 117 | image_info { 118 | "image_id" : int, # Image id 119 | "width" : int, # Image width 120 | "height" : int, # Image height 121 | "file_name" : str, # Image filename 122 | } 123 | 124 | annotation { 125 | "id" : int, # Annotation id 126 | "segmentation" : dict, # Mask saved in COCO RLE format. 127 | "bbox" : [x, y, w, h], # The box around the mask, in XYWH format 128 | "area" : int, # The area in pixels of the mask 129 | "predicted_iou" : float, # The model's own prediction of the mask's quality 130 | "stability_score" : float, # A measure of the mask's quality 131 | "crop_box" : [x, y, w, h], # The crop of the image used to generate the mask, in XYWH format 132 | "point_coords" : [[x, y]], # The point coordinates input to the model to generate the mask 133 | } 134 | ``` 135 | 136 | Image ids can be found in sa_images_ids.txt which can be downloaded using the above [link](https://ai.facebook.com/datasets/segment-anything-downloads/) as well. 137 | 138 | To decode a mask in COCO RLE format into binary: 139 | ``` 140 | from pycocotools import mask as mask_utils 141 | mask = mask_utils.decode(annotation["segmentation"]) 142 | ``` 143 | See [here](https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/mask.py) for more instructions to manipulate masks stored in RLE format. 144 | 145 | 146 | ## License 147 | The model is licensed under the [Apache 2.0 license](LICENSE). 148 | 149 | ## Contributing 150 | 151 | See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
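Returning to the Dataset section above, the sketch below loads one per-image annotation file and decodes every mask with `pycocotools`; the file name `sa_000001.json` is only an illustrative placeholder for whichever annotation file you downloaded.

```python
import json

from pycocotools import mask as mask_utils

# Open one per-image annotation file in the format described in the Dataset section.
with open("sa_000001.json") as f:
    data = json.load(f)

info = data["image"]
print(info["file_name"], info["width"], info["height"])

# Each "segmentation" is a COCO RLE dict; decode() returns a binary H x W array.
for ann in data["annotations"]:
    binary_mask = mask_utils.decode(ann["segmentation"])
    print(ann["id"], ann["area"], ann["bbox"], binary_mask.shape)
```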
152 | 153 | ## Contributors 154 | 155 | The Segment Anything project was made possible with the help of many contributors (alphabetical): 156 | 157 | Aaron Adcock, Vaibhav Aggarwal, Morteza Behrooz, Cheng-Yang Fu, Ashley Gabriel, Ahuva Goldstand, Allen Goodman, Sumanth Gurram, Jiabo Hu, Somya Jain, Devansh Kukreja, Robert Kuo, Joshua Lane, Yanghao Li, Lilian Luong, Jitendra Malik, Mallika Malhotra, William Ngan, Omkar Parkhi, Nikhil Raina, Dirk Rowe, Neil Sejoor, Vanessa Stark, Bala Varadarajan, Bram Wasti, Zachary Winstrom 158 | 159 | ## Citing Segment Anything 160 | 161 | If you use SAM or SA-1B in your research, please use the following BibTeX entry. 162 | 163 | ``` 164 | @article{kirillov2023segany, 165 | title={Segment Anything}, 166 | author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross}, 167 | journal={arXiv:2304.02643}, 168 | year={2023} 169 | } 170 | ``` 171 | 172 | ## Fastapi run 173 | 174 | Please put the "sam_vit_b.pth" model in the same directory as the "main.py" file, and run the backend using FastAPI. 175 | 176 | ``` 177 | python main.py 178 | ``` -------------------------------------------------------------------------------- /backend/assets/data/dogs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/backend/assets/data/dogs.jpg -------------------------------------------------------------------------------- /backend/assets/masks1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/backend/assets/masks1.png -------------------------------------------------------------------------------- /backend/assets/masks2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/backend/assets/masks2.jpg -------------------------------------------------------------------------------- /backend/assets/model_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/backend/assets/model_diagram.png -------------------------------------------------------------------------------- /backend/assets/notebook1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/backend/assets/notebook1.png -------------------------------------------------------------------------------- /backend/assets/notebook2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/backend/assets/notebook2.png -------------------------------------------------------------------------------- /backend/embedded.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | 
import matplotlib.pyplot as plt 5 | 6 | import argparse 7 | import onnxruntime 8 | 9 | from onnxruntime.quantization import QuantType 10 | from onnxruntime.quantization.quantize import quantize_dynamic 11 | 12 | from segment_anything import sam_model_registry, SamPredictor 13 | from segment_anything.utils.onnx import SamOnnxModel 14 | 15 | def make_embedding(encode, root:str, checkpoint:str, model_type:str): 16 | encode = np.fromstring(encode, dtype = np.uint8) 17 | image = cv2.imdecode(encode, cv2.IMREAD_UNCHANGED) 18 | sam = sam_model_registry[model_type](checkpoint=checkpoint) 19 | sam.to(device='cuda') 20 | predictor = SamPredictor(sam) 21 | predictor.set_image(image) 22 | image_embedding = predictor.get_image_embedding().cpu().numpy() 23 | np.save(root, image_embedding) 24 | 25 | -------------------------------------------------------------------------------- /backend/linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | { 5 | black --version | grep -E "23\." > /dev/null 6 | } || { 7 | echo "Linter requires 'black==23.*' !" 8 | exit 1 9 | } 10 | 11 | ISORT_VERSION=$(isort --version-number) 12 | if [[ "$ISORT_VERSION" != 5.12* ]]; then 13 | echo "Linter requires isort==5.12.0 !" 14 | exit 1 15 | fi 16 | 17 | echo "Running isort ..." 18 | isort . --atomic 19 | 20 | echo "Running black ..." 21 | black -l 100 . 22 | 23 | echo "Running flake8 ..." 24 | if [ -x "$(command -v flake8)" ]; then 25 | flake8 . 26 | else 27 | python3 -m flake8 . 28 | fi 29 | 30 | echo "Running mypy..." 31 | 32 | mypy --exclude 'setup.py|notebooks' . 33 | -------------------------------------------------------------------------------- /backend/main.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Request, Response 2 | import uvicorn 3 | from starlette.middleware.cors import CORSMiddleware 4 | 5 | import os 6 | 7 | from embedded import make_embedding 8 | 9 | app = FastAPI() 10 | 11 | @app.post("/ai/embedded/{file_name}") 12 | async def embedded(request: Request, file_name:str): 13 | body = await request.body() 14 | 15 | root = f"{os.path.abspath(os.path.join(os.getcwd(), os.pardir))}/frontend/src/assets/gallery/{file_name}.npy" 16 | checkpoint = "sam_vit_b.pth" 17 | model_type = "vit_b" 18 | print(root) 19 | make_embedding(body, root, checkpoint, model_type) 20 | 21 | return {"npy": f"{file_name}.npy"} 22 | 23 | @app.post("/ai/embedded/all/{file_name}") 24 | async def embedded(request: Request, file_name:str): 25 | return Response(status_code=200) 26 | 27 | 28 | app.add_middleware( 29 | CORSMiddleware, 30 | allow_origins=["*"], 31 | allow_credentials=True, 32 | allow_methods=["*"], 33 | allow_headers=["*"], 34 | ) 35 | 36 | if __name__ == "__main__": 37 | uvicorn.run( 38 | app="main:app", 39 | host="localhost", 40 | port=8000, 41 | reload=True, 42 | ) 43 | -------------------------------------------------------------------------------- /backend/notebooks/images/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/backend/notebooks/images/dog.jpg -------------------------------------------------------------------------------- /backend/notebooks/images/groceries.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/backend/notebooks/images/groceries.jpg -------------------------------------------------------------------------------- /backend/notebooks/images/truck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/backend/notebooks/images/truck.jpg -------------------------------------------------------------------------------- /backend/put your model sam_vit_b.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/backend/put your model sam_vit_b.pth -------------------------------------------------------------------------------- /backend/scripts/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import cv2 # type: ignore 8 | 9 | from segment_anything import SamAutomaticMaskGenerator, sam_model_registry 10 | 11 | import argparse 12 | import json 13 | import os 14 | from typing import Any, Dict, List 15 | 16 | parser = argparse.ArgumentParser( 17 | description=( 18 | "Runs automatic mask generation on an input image or directory of images, " 19 | "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, " 20 | "as well as pycocotools if saving in RLE format." 21 | ) 22 | ) 23 | 24 | parser.add_argument( 25 | "--input", 26 | type=str, 27 | required=True, 28 | help="Path to either a single input image or folder of images.", 29 | ) 30 | 31 | parser.add_argument( 32 | "--output", 33 | type=str, 34 | required=True, 35 | help=( 36 | "Path to the directory where masks will be output. Output will be either a folder " 37 | "of PNGs per image or a single json with COCO-style masks." 38 | ), 39 | ) 40 | 41 | parser.add_argument( 42 | "--model-type", 43 | type=str, 44 | required=True, 45 | help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']", 46 | ) 47 | 48 | parser.add_argument( 49 | "--checkpoint", 50 | type=str, 51 | required=True, 52 | help="The path to the SAM checkpoint to use for mask generation.", 53 | ) 54 | 55 | parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.") 56 | 57 | parser.add_argument( 58 | "--convert-to-rle", 59 | action="store_true", 60 | help=( 61 | "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. " 62 | "Requires pycocotools." 
63 | ), 64 | ) 65 | 66 | amg_settings = parser.add_argument_group("AMG Settings") 67 | 68 | amg_settings.add_argument( 69 | "--points-per-side", 70 | type=int, 71 | default=None, 72 | help="Generate masks by sampling a grid over the image with this many points to a side.", 73 | ) 74 | 75 | amg_settings.add_argument( 76 | "--points-per-batch", 77 | type=int, 78 | default=None, 79 | help="How many input points to process simultaneously in one batch.", 80 | ) 81 | 82 | amg_settings.add_argument( 83 | "--pred-iou-thresh", 84 | type=float, 85 | default=None, 86 | help="Exclude masks with a predicted score from the model that is lower than this threshold.", 87 | ) 88 | 89 | amg_settings.add_argument( 90 | "--stability-score-thresh", 91 | type=float, 92 | default=None, 93 | help="Exclude masks with a stability score lower than this threshold.", 94 | ) 95 | 96 | amg_settings.add_argument( 97 | "--stability-score-offset", 98 | type=float, 99 | default=None, 100 | help="Larger values perturb the mask more when measuring stability score.", 101 | ) 102 | 103 | amg_settings.add_argument( 104 | "--box-nms-thresh", 105 | type=float, 106 | default=None, 107 | help="The overlap threshold for excluding a duplicate mask.", 108 | ) 109 | 110 | amg_settings.add_argument( 111 | "--crop-n-layers", 112 | type=int, 113 | default=None, 114 | help=( 115 | "If >0, mask generation is run on smaller crops of the image to generate more masks. " 116 | "The value sets how many different scales to crop at." 117 | ), 118 | ) 119 | 120 | amg_settings.add_argument( 121 | "--crop-nms-thresh", 122 | type=float, 123 | default=None, 124 | help="The overlap threshold for excluding duplicate masks across different crops.", 125 | ) 126 | 127 | amg_settings.add_argument( 128 | "--crop-overlap-ratio", 129 | type=int, 130 | default=None, 131 | help="Larger numbers mean image crops will overlap more.", 132 | ) 133 | 134 | amg_settings.add_argument( 135 | "--crop-n-points-downscale-factor", 136 | type=int, 137 | default=None, 138 | help="The number of points-per-side in each layer of crop is reduced by this factor.", 139 | ) 140 | 141 | amg_settings.add_argument( 142 | "--min-mask-region-area", 143 | type=int, 144 | default=None, 145 | help=( 146 | "Disconnected mask regions or holes with area smaller than this value " 147 | "in pixels are removed by postprocessing." 
148 | ), 149 | ) 150 | 151 | 152 | def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None: 153 | header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa 154 | metadata = [header] 155 | for i, mask_data in enumerate(masks): 156 | mask = mask_data["segmentation"] 157 | filename = f"{i}.png" 158 | cv2.imwrite(os.path.join(path, filename), mask * 255) 159 | mask_metadata = [ 160 | str(i), 161 | str(mask_data["area"]), 162 | *[str(x) for x in mask_data["bbox"]], 163 | *[str(x) for x in mask_data["point_coords"][0]], 164 | str(mask_data["predicted_iou"]), 165 | str(mask_data["stability_score"]), 166 | *[str(x) for x in mask_data["crop_box"]], 167 | ] 168 | row = ",".join(mask_metadata) 169 | metadata.append(row) 170 | metadata_path = os.path.join(path, "metadata.csv") 171 | with open(metadata_path, "w") as f: 172 | f.write("\n".join(metadata)) 173 | 174 | return 175 | 176 | 177 | def get_amg_kwargs(args): 178 | amg_kwargs = { 179 | "points_per_side": args.points_per_side, 180 | "points_per_batch": args.points_per_batch, 181 | "pred_iou_thresh": args.pred_iou_thresh, 182 | "stability_score_thresh": args.stability_score_thresh, 183 | "stability_score_offset": args.stability_score_offset, 184 | "box_nms_thresh": args.box_nms_thresh, 185 | "crop_n_layers": args.crop_n_layers, 186 | "crop_nms_thresh": args.crop_nms_thresh, 187 | "crop_overlap_ratio": args.crop_overlap_ratio, 188 | "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor, 189 | "min_mask_region_area": args.min_mask_region_area, 190 | } 191 | amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None} 192 | return amg_kwargs 193 | 194 | 195 | def main(args: argparse.Namespace) -> None: 196 | print("Loading model...") 197 | sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint) 198 | _ = sam.to(device=args.device) 199 | output_mode = "coco_rle" if args.convert_to_rle else "binary_mask" 200 | amg_kwargs = get_amg_kwargs(args) 201 | generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs) 202 | 203 | if not os.path.isdir(args.input): 204 | targets = [args.input] 205 | else: 206 | targets = [ 207 | f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f)) 208 | ] 209 | targets = [os.path.join(args.input, f) for f in targets] 210 | 211 | os.makedirs(args.output, exist_ok=True) 212 | 213 | for t in targets: 214 | print(f"Processing '{t}'...") 215 | image = cv2.imread(t) 216 | if image is None: 217 | print(f"Could not load '{t}' as an image, skipping...") 218 | continue 219 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 220 | 221 | masks = generator.generate(image) 222 | 223 | base = os.path.basename(t) 224 | base = os.path.splitext(base)[0] 225 | save_base = os.path.join(args.output, base) 226 | if output_mode == "binary_mask": 227 | os.makedirs(save_base, exist_ok=False) 228 | write_masks_to_folder(masks, save_base) 229 | else: 230 | save_file = save_base + ".json" 231 | with open(save_file, "w") as f: 232 | json.dump(masks, f) 233 | print("Done!") 234 | 235 | 236 | if __name__ == "__main__": 237 | args = parser.parse_args() 238 | main(args) 239 | -------------------------------------------------------------------------------- /backend/scripts/export_onnx_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from segment_anything import sam_model_registry 10 | from segment_anything.utils.onnx import SamOnnxModel 11 | 12 | import argparse 13 | import warnings 14 | 15 | try: 16 | import onnxruntime # type: ignore 17 | 18 | onnxruntime_exists = True 19 | except ImportError: 20 | onnxruntime_exists = False 21 | 22 | parser = argparse.ArgumentParser( 23 | description="Export the SAM prompt encoder and mask decoder to an ONNX model." 24 | ) 25 | 26 | parser.add_argument( 27 | "--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint." 28 | ) 29 | 30 | parser.add_argument( 31 | "--output", type=str, required=True, help="The filename to save the ONNX model to." 32 | ) 33 | 34 | parser.add_argument( 35 | "--model-type", 36 | type=str, 37 | required=True, 38 | help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.", 39 | ) 40 | 41 | parser.add_argument( 42 | "--return-single-mask", 43 | action="store_true", 44 | help=( 45 | "If true, the exported ONNX model will only return the best mask, " 46 | "instead of returning multiple masks. For high resolution images " 47 | "this can improve runtime when upscaling masks is expensive." 48 | ), 49 | ) 50 | 51 | parser.add_argument( 52 | "--opset", 53 | type=int, 54 | default=17, 55 | help="The ONNX opset version to use. Must be >=11", 56 | ) 57 | 58 | parser.add_argument( 59 | "--quantize-out", 60 | type=str, 61 | default=None, 62 | help=( 63 | "If set, will quantize the model and save it with this name. " 64 | "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize." 65 | ), 66 | ) 67 | 68 | parser.add_argument( 69 | "--gelu-approximate", 70 | action="store_true", 71 | help=( 72 | "Replace GELU operations with approximations using tanh. Useful " 73 | "for some runtimes that have slow or unimplemented erf ops, used in GELU." 74 | ), 75 | ) 76 | 77 | parser.add_argument( 78 | "--use-stability-score", 79 | action="store_true", 80 | help=( 81 | "Replaces the model's predicted mask quality score with the stability " 82 | "score calculated on the low resolution masks using an offset of 1.0. " 83 | ), 84 | ) 85 | 86 | parser.add_argument( 87 | "--return-extra-metrics", 88 | action="store_true", 89 | help=( 90 | "The model will return five results: (masks, scores, stability_scores, " 91 | "areas, low_res_logits) instead of the usual three. This can be " 92 | "significantly slower for high resolution outputs." 
93 | ), 94 | ) 95 | 96 | 97 | def run_export( 98 | model_type: str, 99 | checkpoint: str, 100 | output: str, 101 | opset: int, 102 | return_single_mask: bool, 103 | gelu_approximate: bool = False, 104 | use_stability_score: bool = False, 105 | return_extra_metrics=False, 106 | ): 107 | print("Loading model...") 108 | sam = sam_model_registry[model_type](checkpoint=checkpoint) 109 | 110 | onnx_model = SamOnnxModel( 111 | model=sam, 112 | return_single_mask=return_single_mask, 113 | use_stability_score=use_stability_score, 114 | return_extra_metrics=return_extra_metrics, 115 | ) 116 | 117 | if gelu_approximate: 118 | for n, m in onnx_model.named_modules(): 119 | if isinstance(m, torch.nn.GELU): 120 | m.approximate = "tanh" 121 | 122 | dynamic_axes = { 123 | "point_coords": {1: "num_points"}, 124 | "point_labels": {1: "num_points"}, 125 | } 126 | 127 | embed_dim = sam.prompt_encoder.embed_dim 128 | embed_size = sam.prompt_encoder.image_embedding_size 129 | mask_input_size = [4 * x for x in embed_size] 130 | dummy_inputs = { 131 | "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float), 132 | "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float), 133 | "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float), 134 | "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float), 135 | "has_mask_input": torch.tensor([1], dtype=torch.float), 136 | "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float), 137 | } 138 | 139 | _ = onnx_model(**dummy_inputs) 140 | 141 | output_names = ["masks", "iou_predictions", "low_res_masks"] 142 | 143 | with warnings.catch_warnings(): 144 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) 145 | warnings.filterwarnings("ignore", category=UserWarning) 146 | with open(output, "wb") as f: 147 | print(f"Exporting onnx model to {output}...") 148 | torch.onnx.export( 149 | onnx_model, 150 | tuple(dummy_inputs.values()), 151 | f, 152 | export_params=True, 153 | verbose=False, 154 | opset_version=opset, 155 | do_constant_folding=True, 156 | input_names=list(dummy_inputs.keys()), 157 | output_names=output_names, 158 | dynamic_axes=dynamic_axes, 159 | ) 160 | 161 | if onnxruntime_exists: 162 | ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()} 163 | ort_session = onnxruntime.InferenceSession(output) 164 | _ = ort_session.run(None, ort_inputs) 165 | print("Model has successfully been run with ONNXRuntime.") 166 | 167 | 168 | def to_numpy(tensor): 169 | return tensor.cpu().numpy() 170 | 171 | 172 | if __name__ == "__main__": 173 | args = parser.parse_args() 174 | run_export( 175 | model_type=args.model_type, 176 | checkpoint=args.checkpoint, 177 | output=args.output, 178 | opset=args.opset, 179 | return_single_mask=args.return_single_mask, 180 | gelu_approximate=args.gelu_approximate, 181 | use_stability_score=args.use_stability_score, 182 | return_extra_metrics=args.return_extra_metrics, 183 | ) 184 | 185 | if args.quantize_out is not None: 186 | assert onnxruntime_exists, "onnxruntime is required to quantize the model." 
187 | from onnxruntime.quantization import QuantType # type: ignore 188 | from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore 189 | 190 | print(f"Quantizing model and writing to {args.quantize_out}...") 191 | quantize_dynamic( 192 | model_input=args.output, 193 | model_output=args.quantize_out, 194 | optimize_model=True, 195 | per_channel=False, 196 | reduce_range=False, 197 | weight_type=QuantType.QUInt8, 198 | ) 199 | print("Done!") 200 | -------------------------------------------------------------------------------- /backend/segment_anything.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: segment-anything 3 | Version: 1.0 4 | Provides-Extra: all 5 | Provides-Extra: dev 6 | License-File: LICENSE 7 | -------------------------------------------------------------------------------- /backend/segment_anything.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | README.md 3 | setup.cfg 4 | setup.py 5 | segment_anything/__init__.py 6 | segment_anything/automatic_mask_generator.py 7 | segment_anything/build_sam.py 8 | segment_anything/predictor.py 9 | segment_anything.egg-info/PKG-INFO 10 | segment_anything.egg-info/SOURCES.txt 11 | segment_anything.egg-info/dependency_links.txt 12 | segment_anything.egg-info/requires.txt 13 | segment_anything.egg-info/top_level.txt 14 | segment_anything/modeling/__init__.py 15 | segment_anything/modeling/common.py 16 | segment_anything/modeling/image_encoder.py 17 | segment_anything/modeling/mask_decoder.py 18 | segment_anything/modeling/prompt_encoder.py 19 | segment_anything/modeling/sam.py 20 | segment_anything/modeling/transformer.py 21 | segment_anything/utils/__init__.py 22 | segment_anything/utils/amg.py 23 | segment_anything/utils/onnx.py 24 | segment_anything/utils/transforms.py -------------------------------------------------------------------------------- /backend/segment_anything.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /backend/segment_anything.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | 2 | [all] 3 | matplotlib 4 | pycocotools 5 | opencv-python 6 | onnx 7 | onnxruntime 8 | 9 | [dev] 10 | flake8 11 | isort 12 | black 13 | mypy 14 | -------------------------------------------------------------------------------- /backend/segment_anything.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | segment_anything 2 | -------------------------------------------------------------------------------- /backend/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /backend/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /backend/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /backend/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /backend/segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks.
See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | src = src + dense_prompt_embeddings 128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 129 | b, c, h, w = src.shape 130 | 131 | # Run the transformer 132 | hs, src = self.transformer(src, pos_src, tokens) 133 | iou_token_out = hs[:, 0, :] 134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 135 | 136 | # Upscale mask embeddings and predict masks using the mask tokens 137 | src = src.transpose(1, 2).view(b, c, h, w) 138 | upscaled_embedding = self.output_upscaling(src) 139 | hyper_in_list: List[torch.Tensor] = [] 140 | for i in range(self.num_mask_tokens): 141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 142 | hyper_in = torch.stack(hyper_in_list, dim=1) 143 | b, c, h, w = upscaled_embedding.shape 144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 145 | 146 | # Generate mask quality predictions 147 | iou_pred = self.iou_prediction_head(iou_token_out) 148 | 149 | return masks, iou_pred 150 | 151 | 152 | # Lightly adapted from 153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 154 | class MLP(nn.Module): 155 | def __init__( 156 | self, 157 | input_dim: int, 158 | hidden_dim: int, 159 | output_dim: int, 160 | num_layers: int, 161 | sigmoid_output: bool = False, 162 | ) -> None: 163 | super().__init__() 164 | self.num_layers = num_layers 165 | h = [hidden_dim] * (num_layers - 1) 166 | self.layers = nn.ModuleList( 167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 168 | ) 169 | self.sigmoid_output = sigmoid_output 170 | 171 | def forward(self, x): 172 | for i, layer in enumerate(self.layers): 173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 174 | if self.sigmoid_output: 175 | x = F.sigmoid(x) 176 | return x 177 | -------------------------------------------------------------------------------- /backend/segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 
32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /backend/segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 
63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | a dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format.
149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /backend/segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /backend/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /backend/segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 
23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 
97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /backend/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /backend/setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=100 3 | multi_line_output=3 4 | include_trailing_comma=True 5 | known_standard_library=numpy,setuptools 6 | skip_glob=*/__init__.py 7 | known_myself=segment_anything 8 | known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort 9 | no_lines_before=STDLIB,THIRDPARTY 10 | sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER 11 | default_section=FIRSTPARTY 12 | -------------------------------------------------------------------------------- /backend/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from setuptools import find_packages, setup 8 | 9 | setup( 10 | name="segment_anything", 11 | version="1.0", 12 | install_requires=[], 13 | packages=find_packages(exclude="notebooks"), 14 | extras_require={ 15 | "all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime"], 16 | "dev": ["flake8", "isort", "black", "mypy"], 17 | }, 18 | ) 19 | -------------------------------------------------------------------------------- /frontend/.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | npm-debug.log* 5 | yarn-debug.log* 6 | yarn-error.log* 7 | lerna-debug.log* 8 | 9 | # Diagnostic reports (https://nodejs.org/api/report.html) 10 | report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json 11 | 12 | # Runtime data 13 | pids 14 | *.pid 15 | *.seed 16 | *.pid.lock 17 | 18 | # Directory for instrumented libs generated by jscoverage/JSCover 19 | lib-cov 20 | 21 | # Coverage directory used by tools like istanbul 22 | coverage 23 | *.lcov 24 | 25 | # nyc test coverage 26 | .nyc_output 27 | 28 | # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) 29 | .grunt 30 | 31 | # Bower dependency directory (https://bower.io/) 32 | bower_components 33 | 34 | # node-waf configuration 35 | .lock-wscript 36 | 37 | # Compiled binary addons (https://nodejs.org/api/addons.html) 38 | build/Release 39 | 40 | # Dependency directories 41 | node_modules/ 42 | jspm_packages/ 43 | 44 | # TypeScript v1 declaration files 45 | typings/ 46 | 47 | # TypeScript cache 48 | *.tsbuildinfo 49 | 50 | # Optional npm cache directory 51 | .npm 52 | 53 | # Optional eslint cache 54 | .eslintcache 55 | 56 | # Microbundle cache 57 | .rpt2_cache/ 58 | .rts2_cache_cjs/ 59 | .rts2_cache_es/ 60 | .rts2_cache_umd/ 61 | 62 | # Optional REPL history 63 | .node_repl_history 64 | 65 | # Output of 'npm pack' 66 | *.tgz 67 | 68 | # Yarn Integrity file 69 | .yarn-integrity 70 | 71 | # dotenv environment variables file 72 | .env 73 | .env.test 74 | 75 | # parcel-bundler cache (https://parceljs.org/) 76 | .cache 77 | 78 | # Next.js build output 79 | .next 80 | 81 | # Nuxt.js build / generate output 82 | .nuxt 83 | dist 84 | 85 | # Gatsby files 86 | .cache/ 87 | # Comment in the public line in if your project uses Gatsby and *not* Next.js 88 | # https://nextjs.org/blog/next-9-1#public-directory-support 89 | # public 90 | 91 | # vuepress build output 92 | .vuepress/dist 93 | 94 | # Serverless directories 95 | .serverless/ 96 | 97 | # FuseBox cache 98 | .fusebox/ 99 | 100 | # DynamoDB Local files 101 | .dynamodb/ 102 | 103 | # TernJS port file 104 | .tern-port 105 | -------------------------------------------------------------------------------- /frontend/README.md: -------------------------------------------------------------------------------- 1 | ## Segment Anything Simple Web demo 2 | 3 | This **front-end only** demo shows how to load a fixed image and a `.npy` file of the SAM image embedding, and run the SAM ONNX model in the browser using WebAssembly with multithreading enabled by `SharedArrayBuffer`, Web Worker, and SIMD128. 4 | 5 | ## Demo 6 | 7 | 8 | [![asciicast](./readme/demo.gif)](https://github.com/MizzleAa/segment-anything-demo/tree/main/frontend/readme/demo.mkv) 9 | 10 | 11 | ## Run the app 12 | 13 | ``` 14 | yarn && yarn start 15 | ``` 16 | 17 | Navigate to [`http://localhost:8080/`](http://localhost:8080/) 18 | 19 | Move your cursor around to see the mask prediction update in real time.
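Note that the image upload and automatic mask features call the backend endpoints configured later in this README (`API_ENDPOINT` and `ALL_MASK_API_ENDPOINT`, both pointing at port 8000), so the Python backend in `backend/` should be running as well. Below is a minimal sketch of one way to start it; it assumes the backend is an ASGI (FastAPI-style) app exposed as `app` in `backend/main.py` and that a SAM checkpoint is already in place. The module name, app name, and launch command are assumptions, not taken from the backend's own docs.

```
cd backend
pip install -e ".[all]"   # extras from setup.py: matplotlib, pycocotools, opencv-python, onnx, onnxruntime
# torch is not listed in install_requires, so install PyTorch separately
uvicorn main:app --host 127.0.0.1 --port 8000
```

If you only use the gallery images whose embeddings are already saved under `/assets/gallery`, the frontend alone is enough.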
20 | 21 | ## Export the image embedding 22 | 23 | In the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb), upload the image of your choice, then generate and save the corresponding embedding. 24 | 25 | Initialize the predictor: 26 | 27 | ```python 28 | checkpoint = "sam_vit_h_4b8939.pth" 29 | model_type = "vit_h" 30 | sam = sam_model_registry[model_type](checkpoint=checkpoint) 31 | sam.to(device='cuda') 32 | predictor = SamPredictor(sam) 33 | ``` 34 | 35 | Set the new image and export the embedding: 36 | 37 | ```python 38 | image = cv2.imread('src/assets/dogs.jpg') 39 | predictor.set_image(image) 40 | image_embedding = predictor.get_image_embedding().cpu().numpy() 41 | np.save("dogs_embedding.npy", image_embedding) 42 | ``` 43 | 44 | Save the new image and embedding in `/assets/gallery`. 45 | 46 | ## Export the ONNX model 47 | 48 | You also need to export the quantized ONNX model from the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb). 49 | 50 | Run the notebook cell that saves the `onnx_example.onnx` file, download it, and copy it to the path `/model/onnx_example.onnx`. 51 | 52 | Here is a snippet of the export/quantization code: 53 | 54 | ```python 55 | onnx_model_path = "sam_onnx_example.onnx" 56 | onnx_model_quantized_path = "onnx_example.onnx" 57 | quantize_dynamic( 58 | model_input=onnx_model_path, 59 | model_output=onnx_model_quantized_path, 60 | optimize_model=True, 61 | per_channel=False, 62 | reduce_range=False, 63 | weight_type=QuantType.QUInt8, 64 | ) 65 | ``` 66 | 67 | **NOTE: if you change the ONNX model by using a new checkpoint, you also need to re-export the embedding.** 68 | 69 | ## Update the image, embedding, and model in the app 70 | 71 | Update the following file paths at the top of `enviroments.tsx`: 72 | 73 | The current version does not support `MULTI_MASK_MODEL_DIR`, 74 | and `ERASE_API_ENDPOINT` is likewise unsupported. 75 | 76 | ```typescript 77 | export const MODEL_DIR = "/model/onnx_example.onnx"; 78 | export const MULTI_MASK_MODEL_DIR ="/model/meta_multi_onnx.onnx" 79 | export const API_ENDPOINT = "http://127.0.0.1:8000/ai/embedded"; 80 | export const ALL_MASK_API_ENDPOINT = "http://127.0.0.1:8000/ai/embedded/all"; 81 | export const ERASE_API_ENDPOINT = ""; 82 | ``` 83 | 84 | ## ONNX multithreading with SharedArrayBuffer 85 | 86 | To use multithreading, the appropriate headers need to be set to create a cross-origin isolation state, which enables use of `SharedArrayBuffer` (see this [blog post](https://cloudblogs.microsoft.com/opensource/2021/09/02/onnx-runtime-web-running-your-machine-learning-model-in-browser/) for more details). 87 | 88 | The headers below are set in `configs/webpack/dev.js`: 89 | 90 | ```js 91 | headers: { 92 | "Cross-Origin-Opener-Policy": "same-origin", 93 | "Cross-Origin-Embedder-Policy": "credentialless", 94 | } 95 | ``` 96 | 97 | ## Structure of the app 98 | 99 | The example code provides a UI similar to the one in the demo version published by Meta.
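As a rough reference for how the pieces above fit together before the component breakdown below, here is a minimal sketch of creating the ONNX Runtime Web session for the quantized model and opting into WASM multithreading only when the page is actually cross-origin isolated. The `initModel` name and the import path are assumptions for illustration; the real initialization lives in `App.tsx`.

```typescript
import * as ort from "onnxruntime-web";
import { MODEL_DIR } from "./enviroments"; // "/model/onnx_example.onnx"

// Hypothetical initializer: only raise the WASM thread count when
// SharedArrayBuffer is available (i.e. the cross-origin isolation headers
// shown above are in effect); otherwise ONNX Runtime Web runs single-threaded.
export const initModel = async (): Promise<ort.InferenceSession> => {
  if (typeof SharedArrayBuffer !== "undefined" && self.crossOriginIsolated) {
    ort.env.wasm.numThreads = navigator.hardwareConcurrency || 4;
  }
  return ort.InferenceSession.create(MODEL_DIR);
};
```

Note that in production these headers must come from whatever server hosts the built files, since `configs/webpack/prod.js` does not set them.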
100 | 101 | **`App.tsx`** 102 | 103 | - Initializes the ONNX model 104 | - Loads the image embedding and image 105 | - Runs the ONNX model based on input prompts 106 | 107 | **`Stage.tsx`** 108 | 109 | - Handles mouse move interaction to update the ONNX model prompt 110 | 111 | **`Tool.tsx`** 112 | 113 | - Renders the image and the mask prediction 114 | 115 | **`helpers/maskUtils.tsx`** 116 | 117 | - Converts the ONNX model output from an array to an HTMLImageElement 118 | 119 | **`helpers/onnxModelAPI.tsx`** 120 | 121 | - Formats the inputs for the ONNX model 122 | 123 | **`helpers/scaleHelper.tsx`** 124 | 125 | - Handles image scaling logic for SAM (longest side 1024) 126 | 127 | **`hooks/`** 128 | 129 | - Handles shared state for the app 130 | -------------------------------------------------------------------------------- /frontend/configs/webpack/common.js: -------------------------------------------------------------------------------- 1 | const { resolve } = require("path"); 2 | const HtmlWebpackPlugin = require("html-webpack-plugin"); 3 | // const FriendlyErrorsWebpackPlugin = require("friendly-errors-webpack-plugin"); 4 | const CopyPlugin = require("copy-webpack-plugin"); 5 | const webpack = require("webpack"); 6 | 7 | module.exports = { 8 | entry: "./src/index.tsx", 9 | resolve: { 10 | extensions: [".js", ".jsx", ".ts", ".tsx"], 11 | }, 12 | output: { 13 | path: resolve(__dirname, "dist"), 14 | }, 15 | module: { 16 | rules: [ 17 | { 18 | test: /\.mjs$/, 19 | include: /node_modules/, 20 | type: "javascript/auto", 21 | resolve: { 22 | fullySpecified: false, 23 | }, 24 | }, 25 | { 26 | test: [/\.jsx?$/, /\.tsx?$/], 27 | use: ["ts-loader"], 28 | exclude: /node_modules/, 29 | }, 30 | { 31 | test: /\.css$/, 32 | use: ["style-loader", "css-loader"], 33 | }, 34 | { 35 | test: /\.(scss|sass)$/, 36 | use: ["style-loader", "css-loader", "postcss-loader"], 37 | }, 38 | { 39 | test: /\.(jpe?g|png|gif|svg)$/i, 40 | use: [ 41 | "file-loader?hash=sha512&digest=hex&name=img/[contenthash].[ext]", 42 | "image-webpack-loader?bypassOnDebug&optipng.optimizationLevel=7&gifsicle.interlaced=false", 43 | ], 44 | }, 45 | { 46 | test: /\.(woff|woff2|ttf)$/, 47 | use: { 48 | loader: "url-loader", 49 | }, 50 | }, 51 | ], 52 | }, 53 | plugins: [ 54 | new CopyPlugin({ 55 | patterns: [ 56 | { 57 | from: "node_modules/onnxruntime-web/dist/*.wasm", 58 | to: "[name][ext]", 59 | }, 60 | { 61 | from: "model", 62 | to: "model", 63 | }, 64 | { 65 | from: "src/assets", 66 | to: "assets", 67 | }, 68 | ], 69 | }), 70 | new HtmlWebpackPlugin({ 71 | template: "./src/assets/index.html", 72 | }), 73 | // new FriendlyErrorsWebpackPlugin(), 74 | new webpack.ProvidePlugin({ 75 | process: "process/browser", 76 | }), 77 | ], 78 | }; 79 | -------------------------------------------------------------------------------- /frontend/configs/webpack/dev.js: -------------------------------------------------------------------------------- 1 | // development config 2 | const { merge } = require("webpack-merge"); 3 | const commonConfig = require("./common"); 4 | 5 | module.exports = merge(commonConfig, { 6 | mode: "development", 7 | devServer: { 8 | hot: true, // enable HMR on the server 9 | open: true, 10 | // These headers enable the cross origin isolation state 11 | // needed to enable use of SharedArrayBuffer for ONNX 12 | // multithreading.
13 | headers: { 14 | "Cross-Origin-Opener-Policy": "same-origin", 15 | "Cross-Origin-Embedder-Policy": "credentialless", 16 | }, 17 | }, 18 | devtool: "cheap-module-source-map", 19 | }); 20 | -------------------------------------------------------------------------------- /frontend/configs/webpack/prod.js: -------------------------------------------------------------------------------- 1 | // production config 2 | const { merge } = require("webpack-merge"); 3 | const { resolve } = require("path"); 4 | const Dotenv = require("dotenv-webpack"); 5 | const commonConfig = require("./common"); 6 | 7 | module.exports = merge(commonConfig, { 8 | mode: "production", 9 | output: { 10 | filename: "js/bundle.[contenthash].min.js", 11 | path: resolve(__dirname, "../../dist"), 12 | publicPath: "/", 13 | }, 14 | devtool: "source-map", 15 | plugins: [new Dotenv()], 16 | }); 17 | -------------------------------------------------------------------------------- /frontend/model/meta_multi_onnx.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/model/meta_multi_onnx.onnx -------------------------------------------------------------------------------- /frontend/model/meta_onnx.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/model/meta_onnx.onnx -------------------------------------------------------------------------------- /frontend/model/onnx_example.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/model/onnx_example.onnx -------------------------------------------------------------------------------- /frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "segment-anything-mini-demo", 3 | "version": "0.1.0", 4 | "license": "MIT", 5 | "scripts": { 6 | "build": "yarn run clean-dist && webpack --config=configs/webpack/prod.js && mv dist/*.wasm dist/js", 7 | "clean-dist": "rimraf dist/*", 8 | "lint": "eslint './src/**/*.{js,ts,tsx}' --quiet", 9 | "start": "yarn run start-dev", 10 | "test": "yarn run start-model-test", 11 | "start-dev": "webpack serve --config=configs/webpack/dev.js" 12 | }, 13 | "devDependencies": { 14 | "@babel/core": "^7.21.4", 15 | "@babel/preset-env": "^7.21.4", 16 | "@babel/preset-react": "^7.18.6", 17 | "@babel/preset-typescript": "^7.21.4", 18 | "@pmmmwh/react-refresh-webpack-plugin": "^0.5.10", 19 | "@testing-library/react": "^14.0.0", 20 | "@types/node": "^18.15.11", 21 | "@types/react": "^18.0.34", 22 | "@types/react-dom": "^18.0.11", 23 | "@types/underscore": "^1.11.4", 24 | "@typescript-eslint/eslint-plugin": "^5.58.0", 25 | "@typescript-eslint/parser": "^5.58.0", 26 | "babel-loader": "^9.1.2", 27 | "copy-webpack-plugin": "^11.0.0", 28 | "css-loader": "^6.7.3", 29 | "dotenv": "^16.0.3", 30 | "dotenv-webpack": "^8.0.1", 31 | "eslint": "^8.38.0", 32 | "eslint-plugin-react": "^7.32.2", 33 | "file-loader": "^6.2.0", 34 | "fork-ts-checker-webpack-plugin": "^8.0.0", 35 | "html-webpack-plugin": "^5.5.0", 36 | "image-webpack-loader": "^8.1.0", 37 | "postcss-loader": "^7.2.4", 38 | "postcss-preset-env": "^8.3.1", 39 | "process": 
"^0.11.10", 40 | "rimraf": "^5.0.0", 41 | "sass": "^1.62.0", 42 | "sass-loader": "^13.2.2", 43 | "style-loader": "^3.3.2", 44 | "tailwindcss": "^3.3.1", 45 | "ts-loader": "^9.4.2", 46 | "typescript": "^5.0.4", 47 | "webpack": "^5.78.0", 48 | "webpack-cli": "^5.0.1", 49 | "webpack-dev-server": "^4.13.2", 50 | "webpack-dotenv-plugin": "^2.1.0", 51 | "webpack-merge": "^5.8.0" 52 | }, 53 | "dependencies": { 54 | "konva": "^8.4.3", 55 | "lz-string": "^1.5.0", 56 | "npyjs": "^0.4.0", 57 | "onnxruntime-web": "^1.14.0", 58 | "react": "^18.2.0", 59 | "react-cookie": "^4.1.1", 60 | "react-cookie-consent": "^8.0.1", 61 | "react-daisyui": "^3.1.2", 62 | "react-dom": "^18.2.0", 63 | "react-dropzone": "^14.2.3", 64 | "react-ga4": "^2.1.0", 65 | "react-icons": "^4.8.0", 66 | "react-konva": "^18.2.5", 67 | "react-photo-album": "^2.0.4", 68 | "react-refresh": "^0.14.0", 69 | "react-router-dom": "^6.10.0", 70 | "react-swipeable": "^7.0.0", 71 | "underscore": "^1.13.6" 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /frontend/postcss.config.js: -------------------------------------------------------------------------------- 1 | const tailwindcss = require("tailwindcss"); 2 | module.exports = { 3 | plugins: ["postcss-preset-env", 'tailwindcss/nesting', tailwindcss], 4 | }; 5 | -------------------------------------------------------------------------------- /frontend/readme/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/readme/demo.gif -------------------------------------------------------------------------------- /frontend/readme/demo.mkv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/readme/demo.mkv -------------------------------------------------------------------------------- /frontend/src/assets/Github.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /frontend/src/assets/arrow-icn.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /frontend/src/assets/chairs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/chairs.png -------------------------------------------------------------------------------- /frontend/src/assets/circle-plus.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /frontend/src/assets/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/dataset.png -------------------------------------------------------------------------------- /frontend/src/assets/gallery/1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/gallery/1.jpg -------------------------------------------------------------------------------- /frontend/src/assets/gallery/1.jpg.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/gallery/1.jpg.npy -------------------------------------------------------------------------------- /frontend/src/assets/gallery/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/gallery/2.jpg -------------------------------------------------------------------------------- /frontend/src/assets/gallery/2.jpg.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/gallery/2.jpg.npy -------------------------------------------------------------------------------- /frontend/src/assets/gallery/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/gallery/3.jpg -------------------------------------------------------------------------------- /frontend/src/assets/gallery/3.jpg.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/gallery/3.jpg.npy -------------------------------------------------------------------------------- /frontend/src/assets/gallery/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/gallery/4.jpg -------------------------------------------------------------------------------- /frontend/src/assets/gallery/4.jpg.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/gallery/4.jpg.npy -------------------------------------------------------------------------------- /frontend/src/assets/gallery/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/gallery/5.jpg -------------------------------------------------------------------------------- /frontend/src/assets/gallery/5.jpg.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/gallery/5.jpg.npy -------------------------------------------------------------------------------- /frontend/src/assets/gallery/6.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/gallery/6.jpg -------------------------------------------------------------------------------- /frontend/src/assets/gallery/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/gallery/7.jpg -------------------------------------------------------------------------------- /frontend/src/assets/gallery/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/gallery/8.jpg -------------------------------------------------------------------------------- /frontend/src/assets/gallery/8.jpg.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/gallery/8.jpg.npy -------------------------------------------------------------------------------- /frontend/src/assets/hamburger.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /frontend/src/assets/horses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MizzleAa/segment-anything-demo-react-fastapi/75f512410850a3507a85801fac0b7adafae4f6e6/frontend/src/assets/horses.png -------------------------------------------------------------------------------- /frontend/src/assets/icn-image-gallery.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /frontend/src/assets/icn-nn.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /frontend/src/assets/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 9 | Segment Anything Demo 10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | -------------------------------------------------------------------------------- /frontend/src/assets/stack.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /frontend/src/assets/upload_arrow.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /frontend/src/components/ErrorPage.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import { useRouteError } from "react-router-dom"; 3 | 4 | const ErrorPage = () => { 5 | const error: any = useRouteError(); 6 | console.error(error); 7 | 8 | return ( 9 |
<div id="error-page"> 10 | <h1>Oops!</h1> 11 | <p>Sorry, an unexpected error has occurred.</p> 12 | <p> 13 | <i>{error.statusText || error.message}</i> 14 | </p> 15 | </div>
16 | ); 17 | }; 18 | 19 | export default ErrorPage; 20 | -------------------------------------------------------------------------------- /frontend/src/components/FeatureSummary.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import { NavLink } from "react-router-dom"; 3 | import { getTextColors } from "./helpers/metaTheme"; 4 | 5 | interface Action { 6 | action: string; 7 | actionUrl: string; 8 | } 9 | export type FeatureSummaryProps = { 10 | label?: string; 11 | actions?: Action[]; 12 | darkMode?: boolean; 13 | centerAlign?: boolean; 14 | small?: boolean; 15 | useNavLink?: boolean; 16 | style?: React.CSSProperties; 17 | children: React.ReactNode; 18 | className?: string; 19 | justifyCenter?: boolean; 20 | }; 21 | 22 | export default function FeatureSummary({ 23 | label, 24 | actions, 25 | darkMode, 26 | centerAlign, 27 | small, 28 | style, 29 | children, 30 | className, 31 | useNavLink, 32 | justifyCenter, 33 | }: FeatureSummaryProps) { 34 | const { primary, secondary } = getTextColors(darkMode || false); 35 | const uiColor = darkMode ? "bg-gray-800" : "bg-white"; 36 | const linkBody = (l: Action) => ( 37 | <> 38 | 64 | 68 | {l.action} 69 | 70 | 71 | ); 72 | 73 | return ( 74 |
78 | {label && 79 | (small ? ( 80 |

{label}

81 | ) : ( 82 |
{label}
83 | ))} 84 | 85 |
{children}
86 | 87 |
91 | {/* This button should be refactored as its own component */} 92 | {actions && 93 | actions.map((l, key) => { 94 | return useNavLink ? ( 95 | 100 | {linkBody(l)} 101 | 102 | ) : ( 103 | 108 | {linkBody(l)} 109 | 110 | ); 111 | })} 112 |
113 |
114 | ); 115 | } 116 | -------------------------------------------------------------------------------- /frontend/src/components/FeedbackModal.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | 3 | const SPACE = "%20"; 4 | const NEW_LINE = "%0D%0A"; 5 | const EMAIL = "segment-anything@meta.com"; 6 | const SUBJECT = "Segment Anything Demo Feedback"; 7 | const BODY = `Hello Segment Anything team,${NEW_LINE}${NEW_LINE}I'd like to give you some feedback about your demo.`; 8 | 9 | const subject = SUBJECT.replaceAll(" ", SPACE); 10 | const body = BODY.replaceAll(" ", SPACE); 11 | 12 | const FeedbackModal = () => { 13 | return ( 14 |
15 |
16 |
17 |

Feedback

18 | 19 | 20 | Close 21 | 22 | 23 |
24 |

25 | Please email all feedback to
26 | 31 | {EMAIL} 32 | 33 |

34 |
35 |
36 | ); 37 | }; 38 | 39 | export default FeedbackModal; 40 | -------------------------------------------------------------------------------- /frontend/src/components/Footer.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import { Link, NavLink, useLocation } from "react-router-dom"; 3 | import { AiFillGithub } from "react-icons/ai"; 4 | const Footer = () => { 5 | return ( 6 |
7 |
8 |
9 | 10 |

Sampling by @ MizzleAa

11 | 12 |
13 |
14 |
15 | 16 | ); 17 | }; 18 | 19 | export default Footer; 20 | -------------------------------------------------------------------------------- /frontend/src/components/ImagePicker.tsx: -------------------------------------------------------------------------------- 1 | import React, { useContext, useEffect, useLayoutEffect, useState } from "react"; 2 | import { Button, Checkbox } from "react-daisyui"; 3 | import { useDropzone } from "react-dropzone"; 4 | import PhotoAlbum from "react-photo-album"; 5 | import { NavLink } from "react-router-dom"; 6 | import photos from "./helpers/photos"; 7 | import AppContext from "./hooks/createContext"; 8 | import { useLocation } from 'react-router-dom'; 9 | 10 | export interface ImagePickerProps { 11 | handleSelectedImage: ( 12 | data: File | URL, 13 | options?: { shouldDownload?: boolean; shouldNotFetchAllModel?: boolean } 14 | ) => void; 15 | showGallery: [showGallery: boolean, setShowGallery: (e: boolean) => void]; 16 | } 17 | 18 | const ImagePicker = ({ 19 | handleSelectedImage, 20 | showGallery: [showGallery, setShowGallery], 21 | }: ImagePickerProps) => { 22 | const [error, setError] = useState(""); 23 | const [isLoadedCount, setIsLoadedCount] = useState(0); 24 | const [acceptedTerms, setAcceptedTerms] = useState(false); 25 | const { 26 | enableDemo: [enableDemo, setEnableDemo], 27 | } = useContext(AppContext)!; 28 | // const location = useLocation(); 29 | const path = `${window.location.origin.toString()}` 30 | const isMobile = window.innerWidth < 768; 31 | 32 | const downloadAllImageResponses = () => { 33 | photos.forEach((photo, i) => { 34 | setTimeout(() => { 35 | handleSelectedImage(new URL(photo.src, path), { 36 | shouldDownload: true, 37 | }); 38 | }, i * 30000); 39 | }); 40 | }; 41 | 42 | const handleAttemptContinue = () => { 43 | setAcceptedTerms(true); 44 | setTimeout(() => setEnableDemo(true), 500); 45 | }; 46 | 47 | const { getRootProps, getInputProps } = useDropzone({ 48 | accept: { 49 | "image/png": [".png"], 50 | "image/jpeg": [".jpeg", ".jpg"], 51 | }, 52 | onDrop: (acceptedFile) => { 53 | try { 54 | if (acceptedFile.length === 0) { 55 | setError("File not accepted! Try again."); 56 | return; 57 | } 58 | if (acceptedFile.length > 1) { 59 | setError("Too many files! Try again with 1 file."); 60 | return; 61 | } 62 | const reader = new FileReader(); 63 | reader.onloadend = () => { 64 | handleSelectedImage(acceptedFile[0]); 65 | }; 66 | reader.readAsDataURL(acceptedFile[0]); 67 | } catch (error) { 68 | console.log(error); 69 | } 70 | }, 71 | maxSize: 50_000_000, 72 | }); 73 | 74 | const image = ({ imageProps }: { imageProps: any }) => { 75 | const { src, key, style, onClick } = imageProps; 76 | 77 | return ( 78 | onClick!(e, { index: 0 })} 84 | onLoad={() => { 85 | setIsLoadedCount((prev) => prev + 1); 86 | }} 87 | > 88 | ); 89 | }; 90 | 91 | const onClickPhoto = (src: any) => { 92 | console.log(src) 93 | } 94 | 95 | // return ( 96 | //
97 | 98 | // 101 | 102 | //
107 | // handleSelectedImage(e.photo.src)} 112 | // renderPhoto={image} 113 | // /> 114 | //
115 | //
116 | // ); 117 | 118 | 119 | return ( 120 |
124 | handleSelectedImage(e.photo.src)} 129 | renderPhoto={image} 130 | /> 131 |
132 | ); 133 | }; 134 | 135 | export default ImagePicker; 136 | -------------------------------------------------------------------------------- /frontend/src/components/MobileOptionNavBar.tsx: -------------------------------------------------------------------------------- 1 | import React, { useContext, useState } from "react"; 2 | import AppContext from "./hooks/createContext"; 3 | 4 | interface MobileOptionNavBarProps { 5 | handleResetInteraction: () => void; 6 | handleUndoInteraction: () => void; 7 | handleRedoInteraction: () => void; 8 | handleResetState: () => void; 9 | handleImage: (img?: HTMLImageElement) => void; 10 | userNegClickBool: [ 11 | userNegClickBool: boolean, 12 | setUserNegClickBool: (e: boolean) => void 13 | ]; 14 | } 15 | 16 | const MobileOptionNavBar = ({ 17 | handleResetInteraction, 18 | handleRedoInteraction, 19 | handleResetState, 20 | handleUndoInteraction, 21 | handleImage, 22 | userNegClickBool: [userNegClickBool, setUserNegClickBool], 23 | }: MobileOptionNavBarProps) => { 24 | const { 25 | svg: [svg, setSVG], 26 | clicksHistory: [clicksHistory, setClicksHistory], 27 | segmentTypes: [segmentTypes, setSegmentTypes], 28 | isErased: [isErased, setIsErased], 29 | isLoading: [, setIsLoading], 30 | } = useContext(AppContext)!; 31 | const [hasTouchedUpload, setHasTouchedUpload] = useState(false); 32 | return ( 33 |
34 |
35 | 52 | 60 | 68 |
69 | 89 |
90 | ); 91 | }; 92 | 93 | export default MobileOptionNavBar; 94 | -------------------------------------------------------------------------------- /frontend/src/components/NavBar.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import { NavLink, useLocation } from "react-router-dom"; 3 | 4 | interface NavBarProps { 5 | resetState: () => void; 6 | } 7 | 8 | const NavBar = ({ resetState }: NavBarProps) => { 9 | const [isMenuOpen, setIsMenuOpen] = React.useState(false); 10 | const location = useLocation(); 11 | const desktopClasses = "mr-10 font-medium text-base"; 12 | const mobileClasses = "mx-8 text-xl font-semibold"; 13 | 14 | return ( 15 |
16 | { 18 | setIsMenuOpen(true); 19 | }} 20 | className="absolute left-0 h-3 mx-6 md:hidden" 21 | src="/assets/hamburger.svg" 22 | alt="Mobile Menu" 23 | /> 24 | 30 |
Segment Anything
31 |
32 | Research by Meta AI 33 |
34 |
35 | 36 |
37 | ); 38 | }; 39 | 40 | export default NavBar; 41 | -------------------------------------------------------------------------------- /frontend/src/components/SegmentOptions.tsx: -------------------------------------------------------------------------------- 1 | import React, { useContext } from "react"; 2 | import AppContext from "./hooks/createContext"; 3 | 4 | interface SegmentOptionsProps { 5 | handleResetInteraction: () => void; 6 | handleUndoInteraction: () => void; 7 | handleRedoInteraction: () => void; 8 | handleCreateSticker: () => void; 9 | handleMagicErase: () => void; 10 | handleImage: (img?: HTMLImageElement) => void; 11 | hasClicked: boolean; 12 | isCutOut: [isCutOut: boolean, setIsCutOut: (e: boolean) => void]; 13 | handleMultiMaskMode: () => void; 14 | } 15 | 16 | const SegmentOptions = ({ 17 | handleResetInteraction, 18 | handleUndoInteraction, 19 | handleRedoInteraction, 20 | handleCreateSticker, 21 | handleMagicErase, 22 | handleImage, 23 | hasClicked, 24 | isCutOut: [isCutOut, setIsCutOut], 25 | handleMultiMaskMode, 26 | }: SegmentOptionsProps) => { 27 | const { 28 | isModelLoaded: [isModelLoaded, setIsModelLoaded], 29 | segmentTypes: [segmentTypes, setSegmentTypes], 30 | isLoading: [isLoading, setIsLoading], 31 | isErased: [isErased, setIsErased], 32 | svg: [svg, setSVG], 33 | clicksHistory: [clicksHistory, setClicksHistory], 34 | image: [image], 35 | isMultiMaskMode: [isMultiMaskMode, setIsMultiMaskMode], 36 | svgs: [svgs, setSVGs], 37 | clicks: [clicks, setClicks], 38 | showLoadingModal: [showLoadingModal, setShowLoadingModal], 39 | didShowAMGAnimation: [didShowAMGAnimation, setDidShowAMGAnimation], 40 | } = useContext(AppContext)!; 41 | return ( 42 | <> 43 |
48 | 65 | 73 | 81 |
82 |
88 | {/* {segmentTypes === "Click" && ( 89 | 143 | )} */} 144 | {/* */} 156 | 180 |
181 | 182 | ); 183 | }; 184 | 185 | export default SegmentOptions; 186 | -------------------------------------------------------------------------------- /frontend/src/components/Sparkle.tsx: -------------------------------------------------------------------------------- 1 | import React, { useContext, useEffect, useState } from "react"; 2 | import AppContext from "./hooks/createContext"; 3 | 4 | const Sparkle = ({ isActive }: { isActive: Boolean }) => { 5 | const { 6 | isModelLoaded: [isModelLoaded, setIsModelLoaded], 7 | segmentTypes: [segmentTypes, setSegmentTypes], 8 | } = useContext(AppContext)!; 9 | const FILL = segmentTypes === "All" ? "#2962D9" : "#000"; 10 | const [showStar1, setShowStar1] = useState(false); 11 | const [showStar2, setShowStar2] = useState(false); 12 | const [showStar3, setShowStar3] = useState(false); 13 | const [timers, setTimers] = useState(new Array(6).fill(null)); 14 | 15 | const animate = () => { 16 | setShowStar1(true); 17 | setTimers([ 18 | setTimeout(() => { 19 | setShowStar2(true); 20 | setTimers((prev) => [ 21 | ...prev, 22 | setTimeout(() => { 23 | setShowStar3(true); 24 | setTimers((prev) => [ 25 | ...prev, 26 | setTimeout(() => { 27 | setShowStar3(false); 28 | setTimers((prev) => [ 29 | ...prev, 30 | setTimeout(() => { 31 | setShowStar2(false); 32 | setTimers((prev) => [ 33 | ...prev, 34 | setTimeout(() => { 35 | setShowStar1(false); 36 | setTimers((prev) => [ 37 | ...prev, 38 | setTimeout(() => { 39 | animate(); 40 | }, 700), 41 | ]); 42 | }, 100), 43 | ]); 44 | }, 150), 45 | ]); 46 | }, 800), 47 | ]); 48 | }, 100), 49 | ]); 50 | }, 150), 51 | ]); 52 | }; 53 | 54 | const clearTimers = () => { 55 | for (const timer of timers) { 56 | clearTimeout(timer); 57 | } 58 | }; 59 | 60 | useEffect(() => { 61 | if (!isModelLoaded.allModel) { 62 | animate(); 63 | } else { 64 | clearTimers(); 65 | } 66 | return () => { 67 | clearTimers(); 68 | }; 69 | }, [isModelLoaded.allModel]); 70 | 71 | return ( 72 | <> 73 | 81 | 90 | 99 | 108 | 109 | 110 | ); 111 | }; 112 | 113 | export default Sparkle; 114 | -------------------------------------------------------------------------------- /frontend/src/components/SvgMask.tsx: -------------------------------------------------------------------------------- 1 | import React, { useContext, useEffect, useState, useRef } from "react"; 2 | import AppContext from "./hooks/createContext"; 3 | 4 | interface SvgMaskProps { 5 | xScale: number; 6 | yScale: number; 7 | svgStr: string; 8 | id?: string | undefined; 9 | className?: string | undefined; 10 | } 11 | 12 | const SvgMask = ({ 13 | xScale, 14 | yScale, 15 | svgStr, 16 | id = "", 17 | className = "", 18 | }: SvgMaskProps) => { 19 | const { 20 | click: [click, setClick], 21 | image: [image], 22 | isLoading: [isLoading, setIsLoading], 23 | canvasWidth: [, setCanvasWidth], 24 | canvasHeight: [, setCanvasHeight], 25 | isErasing: [isErasing, setIsErasing], 26 | svg: [svg], 27 | isMultiMaskMode: [isMultiMaskMode, setIsMultiMaskMode], 28 | } = useContext(AppContext)!; 29 | const [key, setKey] = useState(Math.random()); 30 | const [boundingBox, setBoundingBox] = useState( 31 | undefined 32 | ); 33 | const pathRef = useRef(null); 34 | const getBoundingBox = () => { 35 | if (!pathRef?.current) return; 36 | setBoundingBox(pathRef.current.getBBox()); 37 | }; 38 | useEffect(() => { 39 | if (!isLoading) { 40 | setKey(Math.random()); 41 | } 42 | getBoundingBox(); 43 | }, [svg]); 44 | const bbX = boundingBox?.x; 45 | const bbY = boundingBox?.y; 46 | const bbWidth = boundingBox?.width; 47 | const 
bbHeight = boundingBox?.height; 48 | const bbMiddleY = bbY && bbHeight && bbY + bbHeight / 2; 49 | const bbWidthRatio = bbWidth && bbWidth / xScale; 50 | return ( 51 | 57 | {!isMultiMaskMode && bbX && bbWidth && ( 58 | <> 59 | 67 | 68 | 69 | 70 | 71 | 72 | 83 | 84 | 85 | )} 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 100 | {!click && (!isLoading || isErasing) && ( 101 | <> 102 | {!isMultiMaskMode && bbWidthRatio && ( 103 | 0.5 && window.innerWidth < 768 ? "hidden" : "" 106 | }`} 107 | d={svgStr} 108 | strokeLinecap="round" 109 | strokeLinejoin="round" 110 | strokeOpacity="0" 111 | fillOpacity="1" 112 | fill={`url(#gradient${id})`} 113 | /> 114 | )} 115 | 128 | 129 | )} 130 | 131 | ); 132 | }; 133 | 134 | export default SvgMask; 135 | -------------------------------------------------------------------------------- /frontend/src/components/ToolTip.tsx: -------------------------------------------------------------------------------- 1 | import React, { useContext } from "react"; 2 | import { AnnotationProps } from "./helpers/Interfaces"; 3 | // import useTimeout from "./helpers/useTimeout"; 4 | import AppContext from "./hooks/createContext"; 5 | 6 | interface ToolTipProps { 7 | isHoverToolTip: [boolean, React.Dispatch>]; 8 | allText: [allText: any, setAllText: any]; 9 | hasClicked: boolean; 10 | annotations: Array; 11 | } 12 | 13 | const ToolTip = ({ 14 | hasClicked, 15 | annotations, 16 | isHoverToolTip: [isHoverToolTip, setIsHoverToolTip], 17 | allText: [allText, setAllText], 18 | }: ToolTipProps) => { 19 | const { 20 | segmentTypes: [segmentTypes, setSegmentTypes], 21 | clicks: [clicks, setClicks], 22 | eraserText: [eraserText, setEraserText], 23 | isErasing: [isErasing, setIsErasing], 24 | isMultiMaskMode: [isMultiMaskMode, setIsMultiMaskMode], 25 | } = useContext(AppContext)!; 26 | 27 | // useEffect(() => { 28 | // return () => { 29 | // clearTimeout(timerRefOne.current); 30 | // clearTimeout(timerRefTwo.current); 31 | // clearTimeout(timerRefThree.current); 32 | // }; 33 | // }, []); 34 | 35 | // const timerRefOne = useRef(null); 36 | // const timerRefTwo = useRef(null); 37 | // const timerRefThree = useRef(null); 38 | 39 | const isMobile = window.innerWidth < 768; 40 | const getText = () => { 41 | if (isErasing) return null; 42 | // if (eraserText.isErase) 43 | // return "Masks can be input into other open source models, like Erase."; 44 | // if (eraserText.isEmbedding) 45 | // return "Re-extracting an embedding on the erased image."; 46 | if (isMultiMaskMode) { 47 | if (clicks?.length) 48 | return "Move your cursor on or off the image to expand or collapse the layers."; 49 | return "SAM predicts multiple mask possibilities with a single click. Select an object to start."; 50 | } 51 | if (segmentTypes === "Click") { 52 | if (isMobile) { 53 | if (hasClicked && clicks?.length) 54 | return "Cut out the selected object using the Cut-out tool."; 55 | return "Select any object, SAM is running in the browser."; 56 | } 57 | if (hasClicked && clicks?.length) 58 | return "Cut out the selected object, or try multi-mask mode."; 59 | if (isHoverToolTip) 60 | return "When hovering over the image, SAM is running in the browser."; 61 | } 62 | if (segmentTypes === "Box") { 63 | if (annotations.length) return "Refine by adding or subtracting points."; 64 | return "Draw a box around an object."; 65 | } 66 | if (segmentTypes === "All") { 67 | return allText; 68 | } 69 | return null; 70 | }; 71 | return ( 72 | <> 73 |
78 |
79 | {getText()} 80 |
81 |
82 | 83 | ); 84 | }; 85 | 86 | export default ToolTip; 87 | -------------------------------------------------------------------------------- /frontend/src/components/helpers/CanvasHelper.tsx: -------------------------------------------------------------------------------- 1 | import { RefObject } from "react"; 2 | 3 | interface canvasScaleInitializerProps { 4 | width: number; 5 | height: number; 6 | containerRef: RefObject; 7 | shouldFitToWidth?: boolean; 8 | } 9 | 10 | interface canvasScaleResizerProps { 11 | width: number; 12 | height: number; 13 | containerWidth: number; 14 | containerHeight: number; 15 | shouldFitToWidth?: boolean; 16 | } 17 | 18 | const canvasScaleInitializer = ({ 19 | width, 20 | height, 21 | containerRef, 22 | shouldFitToWidth, 23 | }: canvasScaleInitializerProps) => { 24 | const containerWidth = containerRef.current?.offsetWidth || width; 25 | const containerHeight = containerRef.current?.offsetHeight || height; 26 | return canvasScaleResizer({ 27 | width, 28 | height, 29 | containerWidth, 30 | containerHeight, 31 | shouldFitToWidth, 32 | }); 33 | }; 34 | 35 | const canvasScaleResizer = ({ 36 | width, 37 | height, 38 | containerWidth, 39 | containerHeight, 40 | shouldFitToWidth, 41 | }: canvasScaleResizerProps) => { 42 | const isMobile = window.innerWidth < 768; 43 | let scale = 1; 44 | const xScale = containerWidth / width; 45 | const yScale = containerHeight / height; 46 | if (isMobile) { 47 | scale = Math.max(xScale, yScale); 48 | } else { 49 | if (shouldFitToWidth) { 50 | scale = xScale; 51 | } else { 52 | scale = Math.min(xScale, yScale); 53 | } 54 | } 55 | const scaledWidth = scale * width; 56 | const scaledHeight = scale * height; 57 | const scalingStyle = { 58 | transform: `scale(${scale})`, 59 | transformOrigin: "left top", 60 | }; 61 | const scaledDimensionsStyle = { 62 | width: scaledWidth, 63 | height: scaledHeight, 64 | }; 65 | return { 66 | scalingStyle, 67 | scaledDimensionsStyle, 68 | scaledWidth, 69 | scaledHeight, 70 | containerWidth, 71 | containerHeight, 72 | }; 73 | }; 74 | 75 | export { canvasScaleInitializer, canvasScaleResizer }; 76 | -------------------------------------------------------------------------------- /frontend/src/components/helpers/Interfaces.tsx: -------------------------------------------------------------------------------- 1 | import { InferenceSession, Tensor } from "onnxruntime-web"; 2 | import { convertSegmentsToSVG, generatePolygonSegments } from "./trace"; 3 | const ort = require("onnxruntime-web"); 4 | 5 | export interface AnnotationProps { 6 | x: number; 7 | y: number; 8 | width: number; 9 | height: number; 10 | clickType: number; 11 | } 12 | 13 | export interface modelInputProps { 14 | x: number; 15 | y: number; 16 | width: null | number; 17 | height: null | number; 18 | clickType: number; 19 | } 20 | 21 | export enum clickType { 22 | POSITIVE = 1.0, 23 | NEGATIVE = 0.0, 24 | UPPER_LEFT = 2.0, 25 | BOTTOM_RIGHT = 3.0, 26 | } 27 | 28 | export interface modelScaleProps { 29 | onnxScale: number; 30 | maskWidth: number; 31 | maskHeight: number; 32 | scale: number; 33 | uploadScale: number; 34 | width: number; 35 | height: number; 36 | } 37 | 38 | export interface setParmsandQueryModelProps { 39 | width: number; 40 | height: number; 41 | uploadScale: number; 42 | imgData: HTMLImageElement; 43 | handleSegModelResults: ({ tensor }: { tensor: Tensor }) => void; 44 | handleAllModelResults: ({ 45 | allJSON, 46 | image_height, 47 | }: { 48 | allJSON: { 49 | encodedMask: string; 50 | bbox: number[]; 51 | score: number; 52 | 
point_coord: number[]; 53 | uncertain_iou: number; 54 | area: number; 55 | }[]; 56 | image_height: number; 57 | }) => void; 58 | imgName: string; 59 | shouldDownload: boolean | undefined; 60 | shouldNotFetchAllModel: boolean | undefined; 61 | } 62 | 63 | export interface setParmsandQueryEraseModelProps { 64 | width: number; 65 | height: number; 66 | uploadScale: number; 67 | imgData: HTMLImageElement | null; 68 | mask: 69 | | string[] 70 | | Uint8Array 71 | | Float32Array 72 | | Int8Array 73 | | Uint16Array 74 | | Int16Array 75 | | Int32Array 76 | | BigInt64Array 77 | | Float64Array 78 | | Uint32Array 79 | | BigUint64Array; 80 | handlePredictedImage: (e: string) => void; 81 | } 82 | 83 | export interface queryEraseModelProps { 84 | image: Blob | string; 85 | mask: 86 | | string[] 87 | | Uint8Array 88 | | Float32Array 89 | | Int8Array 90 | | Uint16Array 91 | | Int16Array 92 | | Int32Array 93 | | BigInt64Array 94 | | Float64Array 95 | | Uint32Array 96 | | BigUint64Array; 97 | handlePredictedImage: (e: string) => void; 98 | } 99 | 100 | export interface queryModelReturnTensorsProps { 101 | blob: Blob; 102 | handleSegModelResults: ({ tensor }: { tensor: Tensor }) => void; 103 | handleAllModelResults: ({ 104 | allJSON, 105 | image_height, 106 | }: { 107 | allJSON: { 108 | encodedMask: string; 109 | bbox: number[]; 110 | score: number; 111 | point_coord: number[]; 112 | uncertain_iou: number; 113 | area: number; 114 | }[]; 115 | image_height: number; 116 | }) => void; 117 | image_height: number; 118 | imgName: string; 119 | shouldDownload: boolean | undefined; 120 | shouldNotFetchAllModel: boolean | undefined; 121 | } 122 | 123 | export interface modeDataProps { 124 | clicks?: Array; 125 | tensor: Tensor; 126 | modelScale: modelScaleProps; 127 | best_box?: number[]; 128 | point_coords?: Array; 129 | point_labels?: number[]; 130 | last_pred_mask: Tensor | null; 131 | } 132 | 133 | export interface StageProps { 134 | handleResetState: () => void; 135 | handleMagicErase: () => void; 136 | handleImage: (img?: HTMLImageElement) => void; 137 | scale: modelScaleProps | null; 138 | hasClicked: boolean; 139 | setHasClicked: (e: boolean) => void; 140 | handleSelectedImage: ( 141 | data: File | URL, 142 | options?: { shouldDownload?: boolean; shouldNotFetchAllModel?: boolean } 143 | ) => void; 144 | image: HTMLImageElement | null; 145 | isStandalone?: boolean; 146 | model?: InferenceSession | null; 147 | } 148 | 149 | /** 150 | * Converts RLE Array into SVG data as a single string. 
151 | * @param {Float32Array} rleMask 152 | * @param {number} height 153 | * @returns {string} 154 | */ 155 | export const traceRleToSVG = ( 156 | rleMask: 157 | | Array 158 | | string[] 159 | | Uint8Array 160 | | Float32Array 161 | | Int8Array 162 | | Uint16Array 163 | | Int16Array 164 | | Int32Array 165 | | BigInt64Array 166 | | Float64Array 167 | | Uint32Array 168 | | BigUint64Array, 169 | height: number 170 | ) => { 171 | const polySegments = generatePolygonSegments(rleMask, height); 172 | const svgStr = convertSegmentsToSVG(polySegments); 173 | return svgStr; 174 | }; -------------------------------------------------------------------------------- /frontend/src/components/helpers/colors.tsx: -------------------------------------------------------------------------------- 1 | // Colormap options copied from detectron2 2 | 3 | const DETECTRON2_COLORS = [ 4 | [0.0, 0.447, 0.741], 5 | [0.85, 0.325, 0.098], 6 | [0.929, 0.694, 0.125], 7 | [0.494, 0.184, 0.556], 8 | [0.466, 0.674, 0.188], 9 | [0.301, 0.745, 0.933], 10 | [0.635, 0.078, 0.184], 11 | [0.3, 0.3, 0.3], 12 | [0.6, 0.6, 0.6], 13 | [1.0, 0.0, 0.0], 14 | [1.0, 0.5, 0.0], 15 | [0.749, 0.749, 0.0], 16 | [0.0, 1.0, 0.0], 17 | [0.0, 0.0, 1.0], 18 | [0.667, 0.0, 1.0], 19 | [0.333, 0.333, 0.0], 20 | [0.333, 0.667, 0.0], 21 | [0.333, 1.0, 0.0], 22 | [0.667, 0.333, 0.0], 23 | [0.667, 0.667, 0.0], 24 | [0.667, 1.0, 0.0], 25 | [1.0, 0.333, 0.0], 26 | [1.0, 0.667, 0.0], 27 | [1.0, 1.0, 0.0], 28 | [0.0, 0.333, 0.5], 29 | [0.0, 0.667, 0.5], 30 | [0.0, 1.0, 0.5], 31 | [0.333, 0.0, 0.5], 32 | [0.333, 0.333, 0.5], 33 | [0.333, 0.667, 0.5], 34 | [0.333, 1.0, 0.5], 35 | [0.667, 0.0, 0.5], 36 | [0.667, 0.333, 0.5], 37 | [0.667, 0.667, 0.5], 38 | [0.667, 1.0, 0.5], 39 | [1.0, 0.0, 0.5], 40 | [1.0, 0.333, 0.5], 41 | [1.0, 0.667, 0.5], 42 | [1.0, 1.0, 0.5], 43 | [0.0, 0.333, 1.0], 44 | [0.0, 0.667, 1.0], 45 | [0.0, 1.0, 1.0], 46 | [0.333, 0.0, 1.0], 47 | [0.333, 0.333, 1.0], 48 | [0.333, 0.667, 1.0], 49 | [0.333, 1.0, 1.0], 50 | [0.667, 0.0, 1.0], 51 | [0.667, 0.333, 1.0], 52 | [0.667, 0.667, 1.0], 53 | [0.667, 1.0, 1.0], 54 | [1.0, 0.0, 1.0], 55 | [1.0, 0.333, 1.0], 56 | [1.0, 0.667, 1.0], 57 | [0.333, 0.0, 0.0], 58 | [0.5, 0.0, 0.0], 59 | [0.667, 0.0, 0.0], 60 | [0.833, 0.0, 0.0], 61 | [1.0, 0.0, 0.0], 62 | [0.0, 0.167, 0.0], 63 | [0.0, 0.333, 0.0], 64 | [0.0, 0.5, 0.0], 65 | [0.0, 0.667, 0.0], 66 | [0.0, 0.833, 0.0], 67 | [0.0, 1.0, 0.0], 68 | [0.0, 0.0, 0.167], 69 | [0.0, 0.0, 0.333], 70 | [0.0, 0.0, 0.5], 71 | [0.0, 0.0, 0.667], 72 | [0.0, 0.0, 0.833], 73 | [0.0, 0.0, 1.0], 74 | [0.0, 0.0, 0.0], 75 | [0.143, 0.143, 0.143], 76 | [0.857, 0.857, 0.857], 77 | [1.0, 1.0, 1.0], 78 | ]; 79 | 80 | const colors = (function () { 81 | const RGBs: Array = []; 82 | DETECTRON2_COLORS.map((color) => { 83 | const [r, g, b] = color.map((n) => { 84 | return Math.round(n * 255); 85 | }); 86 | RGBs.push(`rgb(${r},${g},${b})`); 87 | }); 88 | return RGBs; 89 | })(); 90 | 91 | export default colors; 92 | -------------------------------------------------------------------------------- /frontend/src/components/helpers/files.tsx: -------------------------------------------------------------------------------- 1 | const getFile = async (data: URL) => { 2 | const response = await fetch(data); 3 | const blob = await response.blob(); 4 | return new File([blob], "image.jpeg"); 5 | }; 6 | 7 | export default getFile; 8 | -------------------------------------------------------------------------------- /frontend/src/components/helpers/maskUtils.tsx: 
-------------------------------------------------------------------------------- 1 | 2 | import { generatePolygonSegments, convertSegmentsToSVG } from "./trace"; 3 | import { Tensor } from "onnxruntime-web"; 4 | const ort = require("onnxruntime-web"); 5 | // Functions for handling mask output from the ONNX model 6 | 7 | // Convert the onnx model mask prediction to ImageData 8 | function arrayToImageData(input: any, width: number, height: number) { 9 | const [r, g, b, a] = [0, 114, 189, 255]; // the masks's blue color 10 | const arr = new Uint8ClampedArray(4 * width * height).fill(0); 11 | for (let i = 0; i < input.length; i++) { 12 | 13 | // Threshold the onnx model mask prediction at 0.0 14 | // This is equivalent to thresholding the mask using predictor.model.mask_threshold 15 | // in python 16 | if (input[i] > 0.0) { 17 | arr[4 * i + 0] = r; 18 | arr[4 * i + 1] = g; 19 | arr[4 * i + 2] = b; 20 | arr[4 * i + 3] = a; 21 | } 22 | } 23 | return new ImageData(arr, height, width); 24 | } 25 | 26 | // Use a Canvas element to produce an image from ImageData 27 | function imageDataToImage(imageData: ImageData) { 28 | const canvas = imageDataToCanvas(imageData); 29 | const image = new Image(); 30 | image.src = canvas.toDataURL(); 31 | return image; 32 | } 33 | 34 | // Canvas elements can be created from ImageData 35 | function imageDataToCanvas(imageData: ImageData) { 36 | const canvas = document.createElement("canvas"); 37 | const ctx = canvas.getContext("2d"); 38 | canvas.width = imageData.width; 39 | canvas.height = imageData.height; 40 | ctx?.putImageData(imageData, 0, 0); 41 | return canvas; 42 | } 43 | 44 | // Convert the onnx model mask output to an HTMLImageElement 45 | export function onnxMaskToImage(input: any, width: number, height: number) { 46 | return imageDataToImage(arrayToImageData(input, width, height)); 47 | } 48 | 49 | 50 | /** 51 | * Functions for handling and tracing masks. 52 | */ 53 | 54 | // const { 55 | // generatePolygonSegments, 56 | // convertSegmentsToSVG, 57 | // } = require("./custom_tracer"); 58 | 59 | /** 60 | * Converts mask array into RLE array using the fortran array 61 | * format where rows and columns are transposed. This is the 62 | * format used by the COCO API and is expected by the mask tracer. 63 | * @param {Array} input 64 | * @param {number} nrows 65 | * @param {number} ncols 66 | * @returns array of integers 67 | */ 68 | export function maskDataToFortranArrayToRle( 69 | input: any, 70 | nrows: number, 71 | ncols: number 72 | ) { 73 | const result = []; 74 | let count = 0; 75 | let bit = false; 76 | for (let c = 0; c < ncols; c++) { 77 | for (let r = 0; r < nrows; r++) { 78 | var i = c + r * ncols; 79 | if (i < input.length) { 80 | const filled = input[i] > 0.0; 81 | if (filled !== bit) { 82 | result.push(count); 83 | bit = !bit; 84 | count = 1; 85 | } else count++; 86 | } 87 | } 88 | } 89 | if (count > 0) result.push(count); 90 | return result; 91 | } 92 | 93 | /** 94 | * Converts RLE Array into SVG data as a single string. 
95 | * @param {Float32Array} rleMask 96 | * @param {number} height 97 | * @returns {string} 98 | */ 99 | export const traceRleToSVG = ( 100 | rleMask: 101 | | Array 102 | | string[] 103 | | Uint8Array 104 | | Float32Array 105 | | Int8Array 106 | | Uint16Array 107 | | Int16Array 108 | | Int32Array 109 | | BigInt64Array 110 | | Float64Array 111 | | Uint32Array 112 | | BigUint64Array, 113 | height: number 114 | ) => { 115 | const polySegments = generatePolygonSegments(rleMask, height); 116 | const svgStr = convertSegmentsToSVG(polySegments); 117 | return svgStr; 118 | }; 119 | 120 | export const getAllMasks = (maskData: any, height: number, width: number) => { 121 | let masks = []; 122 | for (let m = 0; m < 4; m++) { 123 | let nthMask = new Float32Array(height * width); 124 | const offset = m * width * height; 125 | for (let i = 0; i < height; i++) { 126 | for (let j = 0; j < width; j++) { 127 | var idx = i * width + j; 128 | if (idx < width * height) { 129 | nthMask[idx] = maskData[offset + idx]; 130 | } 131 | } 132 | } 133 | masks.push(nthMask); 134 | } 135 | return masks; 136 | }; 137 | 138 | export const getBestPredMask = ( 139 | maskData: any, 140 | height: number, 141 | width: number, 142 | index: number 143 | ) => { 144 | let nthMask = new Float32Array(height * width); 145 | const offset = index * width * height; 146 | for (let i = 0; i < height; i++) { 147 | for (let j = 0; j < width; j++) { 148 | var idx = i * width + j; 149 | if (idx < width * height) { 150 | nthMask[idx] = maskData[offset + idx]; 151 | } 152 | } 153 | } 154 | const bestMask = new Tensor("float32", nthMask, [1, 1, width, height]); 155 | return bestMask; 156 | }; 157 | 158 | function areaUnderLine(x0: number, y0: number, x1: number, y1: number) { 159 | // A vertical line has no area 160 | if (x0 === x1) return 0; 161 | // Square piece 162 | const ymin = Math.min(y0, y1); 163 | const squareArea = (x1 - x0) * ymin; 164 | // Triangle piece 165 | const ymax = Math.max(y0, y1); 166 | const triangleArea = Math.trunc((x1 - x0) * (ymax - ymin) / 2); 167 | return squareArea + triangleArea; 168 | } 169 | 170 | function svgCoordToInt(input: string) { 171 | if ((input.charAt(0) === "L") || (input.charAt(0) === "M")) { 172 | return parseInt(input.slice(1)); 173 | } 174 | return parseInt(input); 175 | } 176 | 177 | function areaOfSVGPolygon(input: string) { 178 | let coords = input.split(" "); 179 | if (coords.length < 4) return 0; 180 | if (coords.length % 2 != 0) return 0; 181 | let area = 0; 182 | // We need to close the polygon loop, so start with the last coords. 183 | let old_x = svgCoordToInt(coords[coords.length - 2]); 184 | let old_y = svgCoordToInt(coords[coords.length - 1]); 185 | for (let i = 0; i < coords.length; i = i + 2) { 186 | let new_x = svgCoordToInt(coords[i]); 187 | let new_y = svgCoordToInt(coords[i + 1]); 188 | area = area + areaUnderLine(old_x, old_y, new_x, new_y); 189 | old_x = new_x; 190 | old_y = new_y; 191 | } 192 | return area; 193 | } 194 | 195 | /** 196 | * Filters SVG edges that enclose an area smaller than maxRegionSize. 197 | * Expects a list over SVG strings, with each string in the format: 198 | * 'M L ... ' 199 | * The area calculation is not quite exact, truncating fractional pixels 200 | * instead of rounding. Both clockwise and counterclockwise SVG edges 201 | * are filtered, removing stray regions and small holes. Always keeps 202 | * at least one positive area region. 
203 | */ 204 | export function filterSmallSVGRegions( 205 | input: string[], maxRegionSize: number = 100 206 | ) { 207 | const filtered_regions = input.filter( 208 | (region: string) => Math.abs(areaOfSVGPolygon(region)) > maxRegionSize 209 | ); 210 | if (filtered_regions.length === 0) { 211 | const areas = input.map((region: string) => areaOfSVGPolygon(region)); 212 | const bestIdx = areas.indexOf(Math.max(...areas)); 213 | return [input[bestIdx]]; 214 | } 215 | return filtered_regions; 216 | } 217 | 218 | /** 219 | * Converts onnx model output into SVG data as a single string 220 | * @param {Float32Array} maskData 221 | * @param {number} height 222 | * @param {number} width 223 | * @returns {string} 224 | */ 225 | export const traceOnnxMaskToSVG = ( 226 | maskData: 227 | | string[] 228 | | Uint8Array 229 | | Uint8ClampedArray 230 | | Float32Array 231 | | Int8Array 232 | | Uint16Array 233 | | Int16Array 234 | | Int32Array 235 | | BigInt64Array 236 | | Float64Array 237 | | Uint32Array 238 | | BigUint64Array, 239 | height: number, 240 | width: number 241 | ) => { 242 | const rleMask = maskDataToFortranArrayToRle(maskData, width, height); 243 | let svgStr = traceRleToSVG(rleMask, width); 244 | svgStr = filterSmallSVGRegions(svgStr); 245 | return svgStr; 246 | }; 247 | 248 | /** 249 | * Converts compressed RLE string into SVG 250 | * @param {string} maskString 251 | * @param {number} height 252 | * @returns {string} 253 | */ 254 | export const traceCompressedRLeStringToSVG = ( 255 | maskString: string | null, 256 | height: number 257 | ) => { 258 | const rleMask = rleFrString(maskString); 259 | let svgStr = traceRleToSVG(rleMask, height); 260 | svgStr = filterSmallSVGRegions(svgStr); 261 | return svgStr; 262 | }; 263 | 264 | /** 265 | * Parses RLE from compressed string 266 | * @param {Array} input 267 | * @returns array of integers 268 | */ 269 | export const rleFrString = (input: any) => { 270 | let result = []; 271 | let charIndex = 0; 272 | while (charIndex < input.length) { 273 | let value = 0, 274 | k = 0, 275 | more = 1; 276 | while (more) { 277 | let c = input.charCodeAt(charIndex) - 48; 278 | value |= (c & 0x1f) << (5 * k); 279 | more = c & 0x20; 280 | charIndex++; 281 | k++; 282 | if (!more && c & 0x10) value |= -1 << (5 * k); 283 | } 284 | if (result.length > 2) value += result[result.length - 2]; 285 | result.push(value); 286 | } 287 | return result; 288 | }; 289 | 290 | function toImageData(input: any, width: number, height: number) { 291 | const [r, g, b, a] = [0, 114, 189, 255]; 292 | const arr = new Uint8ClampedArray(4 * width * height).fill(0); 293 | for (let i = 0; i < input.length; i++) { 294 | if (input[i] > 0.0) { 295 | arr[4 * i + 0] = r; 296 | arr[4 * i + 1] = g; 297 | arr[4 * i + 2] = b; 298 | arr[4 * i + 3] = a; 299 | } 300 | } 301 | return new ImageData(arr, height, width); 302 | } 303 | 304 | export function rleToImage(input: any, width: number, height: number) { 305 | return imageDataToImage(toImageData(input, width, height)); 306 | } 307 | 308 | export function rleToCanvas(input: any, width: number, height: number) { 309 | return imageDataToCanvas(toImageData(input, width, height)); 310 | } 311 | 312 | 313 | // Returns a boolean array for which masks to keep in the multi-mask 314 | // display, given uncertain IoUs and overlap IoUs. 
315 | export function keepArrayForMultiMask( 316 | uncertainIoUs: number[], 317 | overlapIoUs: number[], 318 | uncertainThresh: number = 0.8, 319 | overlapThresh: number = 0.9, 320 | ) { 321 | let keepArray = uncertainIoUs.map((iou: number) => iou > uncertainThresh); 322 | const duplicateArray = overlapIoUs.map((iou: number) => iou < overlapThresh); 323 | keepArray = keepArray.map((val: boolean, i: number) => val && duplicateArray[i]); 324 | // If all masks fail tests, keep just the best one 325 | if (keepArray.every(item => item === false)) { 326 | const bestIdx = uncertainIoUs.indexOf(Math.max(...uncertainIoUs)); 327 | keepArray[bestIdx] = true; 328 | } 329 | return keepArray; 330 | } 331 | -------------------------------------------------------------------------------- /frontend/src/components/helpers/metaTheme.tsx: -------------------------------------------------------------------------------- 1 | /** 2 | * Returns tailwind color classes for text contents, based on the darkMode boolean 3 | */ 4 | export function getTextColors(darkMode: boolean): { 5 | primary: string; 6 | secondary: string; 7 | } { 8 | const primary = darkMode ? "text-white" : "text-gray-800"; 9 | const secondary = darkMode ? "text-gray-300" : "text-gray-600"; 10 | return { primary, secondary }; 11 | } 12 | -------------------------------------------------------------------------------- /frontend/src/components/helpers/photos.tsx: -------------------------------------------------------------------------------- 1 | const photos = [ 2 | { 3 | src: "/assets/gallery/1.jpg", 4 | width: 1920, 5 | height: 1080, 6 | }, 7 | { 8 | src: "/assets/gallery/2.jpg", 9 | width: 1920, 10 | height: 1080, 11 | }, 12 | { 13 | src: "/assets/gallery/3.jpg", 14 | width: 1920, 15 | height: 1080, 16 | }, 17 | { 18 | src: "/assets/gallery/4.jpg", 19 | width: 1920, 20 | height: 1080, 21 | }, 22 | { 23 | src: "/assets/gallery/5.jpg", 24 | width: 1920, 25 | height: 1080, 26 | }, 27 | { 28 | src: "/assets/gallery/6.jpg", 29 | width: 1072, 30 | height: 608, 31 | }, 32 | { 33 | src: "/assets/gallery/7.jpg", 34 | width: 1080, 35 | height: 1440, 36 | }, 37 | { 38 | src: "/assets/gallery/8.jpg", 39 | width: 2048, 40 | height: 1365, 41 | }, 42 | 43 | ]; 44 | 45 | export default photos; 46 | -------------------------------------------------------------------------------- /frontend/src/components/helpers/scaleHelper.tsx: -------------------------------------------------------------------------------- 1 | const handleImageScale = (data: HTMLImageElement) => { 2 | const IMAGE_SIZE = 500; 3 | const UPLOAD_IMAGE_SIZE = 1024; 4 | let w = data.naturalWidth; 5 | let h = data.naturalHeight; 6 | let scale; 7 | let uploadScale; 8 | if (h < w) { 9 | scale = IMAGE_SIZE / h; 10 | if (h * scale > 1333) { 11 | scale = 1333 / h; 12 | } 13 | uploadScale = UPLOAD_IMAGE_SIZE / w; 14 | } else { 15 | scale = IMAGE_SIZE / w; 16 | if (w * scale > 1333) { 17 | scale = 1333 / w; 18 | } 19 | uploadScale = UPLOAD_IMAGE_SIZE / h; 20 | } 21 | return { height: h, width: w, scale, uploadScale }; 22 | }; 23 | 24 | export { handleImageScale }; 25 | -------------------------------------------------------------------------------- /frontend/src/components/hooks/Animation.tsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState } from "react"; 2 | 3 | function useDelayUnmount(isMounted: boolean, delayTime: number) { 4 | const [showDiv, setShowDiv] = useState(false); 5 | useEffect(() => { 6 | let timeoutId: string | number | 
NodeJS.Timeout | undefined; 7 | if (isMounted && !showDiv) { 8 | setShowDiv(true); 9 | } else if (!isMounted && showDiv) { 10 | timeoutId = setTimeout(() => setShowDiv(false), delayTime); // delay the unmount so the exit animation can play 11 | } 12 | return () => clearTimeout(timeoutId); // cleanup for the effect; the setTimeout above is a side effect and must be cleared 13 | }, [isMounted, delayTime, showDiv]); 14 | return showDiv; 15 | } 16 | 17 | const Animate = ({ children, isMounted }: any) => { 18 | const showDiv = useDelayUnmount(isMounted, 450); 19 | const mountedStyle = { animation: "inAnimation 450ms ease-in" }; 20 | const unmountedStyle = { 21 | animation: "outAnimation 700ms ease-out", 22 | animationFillMode: "forwards", 23 | }; 24 | return ( 25 | <div>
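{/* showDiv stays true for delayTime ms after isMounted flips to false (see useDelayUnmount above), so the outAnimation keyframes get a chance to play before the children unmount. */}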
26 | {showDiv && ( 27 | <div style={isMounted ? mountedStyle : unmountedStyle}>
{children}</div>
28 | )} 29 | </div>
30 | ); 31 | }; 32 | 33 | export default Animate; 34 | 35 | // THE CSS: 36 | 37 | // @keyframes inAnimation { 38 | // 0% { 39 | // opacity: 0; 40 | // max-height: 0px; 41 | // } 42 | // 100% { 43 | // opacity: 1; 44 | // max-height: 600px; 45 | // } 46 | // } 47 | 48 | // @keyframes outAnimation { 49 | // 0% { 50 | // opacity: 1; 51 | // max-height: 600px; 52 | // } 53 | // 100% { 54 | // opacity: 0; 55 | // max-height: 0px; 56 | // } 57 | // } 58 | -------------------------------------------------------------------------------- /frontend/src/components/hooks/context.tsx: -------------------------------------------------------------------------------- 1 | import { Tensor } from "onnxruntime-web"; 2 | const ort = require("onnxruntime-web"); 3 | 4 | import React, { useState } from "react"; 5 | import { modelInputProps } from "../helpers/Interfaces"; 6 | import AppContext from "./createContext"; 7 | 8 | const AppContextProvider = (props: { 9 | children: React.ReactElement>; 10 | }) => { 11 | const [click, setClick] = useState(null); 12 | const [clicks, setClicks] = useState | null>(null); 13 | const [clicksHistory, setClicksHistory] = 14 | useState | null>(null); 15 | const [isLoading, setIsLoading] = useState(false); 16 | const [image, setImage] = useState(null); 17 | const [prevImage, setPrevImage] = useState(null); 18 | const [isErasing, setIsErasing] = useState(false); 19 | const [isErased, setIsErased] = useState(false); 20 | const [error, setError] = useState(false); 21 | const [svg, setSVG] = useState(null); 22 | const [svgs, setSVGs] = useState(null); 23 | const [allsvg, setAllsvg] = useState< 24 | { svg: string[]; point_coord: number[] }[] | null 25 | >(null); 26 | const [isModelLoaded, setIsModelLoaded] = useState<{ 27 | boxModel: boolean; 28 | allModel: boolean; 29 | }>({ boxModel: false, allModel: false }); 30 | const [stickers, setStickers] = useState([]); 31 | const [activeSticker, setActiveSticker] = useState(0); 32 | const [segmentTypes, setSegmentTypes] = useState<"Box" | "Click" | "All">( 33 | "Click" 34 | ); 35 | const [canvasWidth, setCanvasWidth] = useState(0); 36 | const [canvasHeight, setCanvasHeight] = useState(0); 37 | const [maskImg, setMaskImg] = useState(null); 38 | const [maskCanvas, setMaskCanvas] = useState(null); 39 | const [userNegClickBool, setUserNegClickBool] = useState(false); 40 | const [hasNegClicked, setHasNegClicked] = useState(false); 41 | const [stickerTabBool, setStickerTabBool] = useState(false); 42 | const [enableDemo, setEnableDemo] = useState(false); 43 | const [isMultiMaskMode, setIsMultiMaskMode] = useState(false); 44 | const [isHovering, setIsHovering] = useState(null); 45 | const [showLoadingModal, setShowLoadingModal] = useState(false); 46 | const [eraserText, setEraserText] = useState<{ 47 | isErase: boolean; 48 | isEmbedding: boolean; 49 | }>({ isErase: false, isEmbedding: false }); 50 | const [didShowAMGAnimation, setDidShowAMGAnimation] = 51 | useState(false); 52 | const [predMask, setPredMask] = useState(null); 53 | const [predMasks, setPredMasks] = useState(null); 54 | const [predMasksHistory, setPredMasksHistory] = useState( 55 | null 56 | ); 57 | const [isAllAnimationDone, setIsAllAnimationDone] = useState(false); 58 | const [isToolBarUpload, setIsToolBarUpload] = useState(false); 59 | 60 | return ( 61 | 99 | {props.children} 100 | 101 | ); 102 | }; 103 | 104 | export default AppContextProvider; 105 | -------------------------------------------------------------------------------- /frontend/src/components/hooks/createContext.tsx: 
-------------------------------------------------------------------------------- 1 | import { Tensor } from "onnxruntime-web"; 2 | import { createContext } from "react"; 3 | import { modelInputProps } from "../helpers/Interfaces"; 4 | 5 | interface contextProps { 6 | click: [ 7 | click: modelInputProps | null, 8 | setClick: (e: modelInputProps | null) => void 9 | ]; 10 | clicks: [ 11 | clicks: modelInputProps[] | null, 12 | setClicks: (e: modelInputProps[] | null) => void 13 | ]; 14 | clicksHistory: [ 15 | clicksHistory: modelInputProps[] | null, 16 | setClicksHistory: (e: modelInputProps[] | null) => void 17 | ]; 18 | image: [ 19 | image: HTMLImageElement | null, 20 | setImage: (e: HTMLImageElement | null) => void 21 | ]; 22 | prevImage: [ 23 | prevImage: HTMLImageElement | null, 24 | setPrevImage: (e: HTMLImageElement | null) => void 25 | ]; 26 | isLoading: [isLoading: boolean, setIsLoading: (e: boolean) => void]; 27 | isErasing: [isErasing: boolean, setIsErasing: (e: boolean) => void]; 28 | isErased: [isErased: boolean, setIsErased: (e: boolean) => void]; 29 | error: [error: boolean, setError: (e: boolean) => void]; 30 | svg: [svg: string[] | null, setSVG: (e: string[] | null) => void]; 31 | svgs: [svgs: string[][] | null, setSVGs: (e: string[][] | null) => void]; 32 | allsvg: [ 33 | allsvg: { svg: string[]; point_coord: number[] }[] | null, 34 | setAllsvg: (e: { svg: string[]; point_coord: number[] }[] | null) => void 35 | ]; 36 | stickers: [ 37 | stickers: HTMLCanvasElement[], 38 | setStickers: (e: HTMLCanvasElement[]) => void 39 | ]; 40 | activeSticker: [ 41 | activerSticker: number, 42 | setActiveSticker: (e: number) => void 43 | ]; 44 | isModelLoaded: [ 45 | isModelLoaded: { 46 | boxModel: boolean; 47 | allModel: boolean; 48 | }, 49 | setIsModelLoaded: React.Dispatch< 50 | React.SetStateAction<{ boxModel: boolean; allModel: boolean }> 51 | > 52 | ]; 53 | segmentTypes: [ 54 | segmentTypes: "Box" | "Click" | "All", 55 | setSegmentTypes: (e: "Box" | "Click" | "All") => void 56 | ]; 57 | canvasWidth: [canvasWidth: number, setCanvasWidth: (e: number) => void]; 58 | canvasHeight: [canvasHeight: number, setCanvasHeight: (e: number) => void]; 59 | maskImg: [ 60 | maskImg: HTMLImageElement | null, 61 | setMaskImg: (e: HTMLImageElement | null) => void 62 | ]; 63 | maskCanvas: [ 64 | maskCanvas: HTMLCanvasElement | null, 65 | setMaskCanvas: (e: HTMLCanvasElement | null) => void 66 | ]; 67 | userNegClickBool: [ 68 | userNegClickBool: boolean, 69 | setUserNegClickBool: (e: boolean) => void 70 | ]; 71 | hasNegClicked: [ 72 | hasNegClicked: boolean, 73 | setHasNegClicked: (e: boolean) => void 74 | ]; 75 | stickerTabBool: [ 76 | stickerTabBool: boolean, 77 | setStickerTabBool: React.Dispatch> 78 | ]; 79 | enableDemo: [ 80 | enableDemo: boolean, 81 | setEnableDemo: React.Dispatch> 82 | ]; 83 | isMultiMaskMode: [ 84 | isMultiMaskMode: boolean, 85 | setIsMultiMaskMode: React.Dispatch> 86 | ]; 87 | isHovering: [ 88 | isHovering: boolean | null, 89 | setIsHovering: React.Dispatch> 90 | ]; 91 | showLoadingModal: [ 92 | showLoadingModal: boolean, 93 | setShowLoadingModal: React.Dispatch> 94 | ]; 95 | eraserText: [ 96 | eraserText: { 97 | isErase: boolean; 98 | isEmbedding: boolean; 99 | }, 100 | setEraserText: React.Dispatch< 101 | React.SetStateAction<{ 102 | isErase: boolean; 103 | isEmbedding: boolean; 104 | }> 105 | > 106 | ]; 107 | didShowAMGAnimation: [ 108 | didShowAMGAnimation: boolean, 109 | setDidShowAMGAnimation: React.Dispatch> 110 | ]; 111 | predMask: [ 112 | predMask: Tensor | null, 113 | 
setPredMask: React.Dispatch> 114 | ]; 115 | predMasks: [ 116 | predMasks: Tensor[] | null, 117 | setPredMasks: React.Dispatch> 118 | ]; 119 | predMasksHistory: [ 120 | predMasksHistory: Tensor[] | null, 121 | setPredMasksHistory: React.Dispatch> 122 | ]; 123 | isAllAnimationDone: [ 124 | isAllAnimationDone: boolean, 125 | setIsAllAnimationDone: React.Dispatch> 126 | ]; 127 | isToolBarUpload: [ 128 | isToolBarUpload: boolean, 129 | setIsToolBarUpload: React.Dispatch> 130 | ]; 131 | } 132 | 133 | const AppContext = createContext(null); 134 | 135 | export default AppContext; 136 | -------------------------------------------------------------------------------- /frontend/src/enviroments.tsx: -------------------------------------------------------------------------------- 1 | // export const IMAGE_PATH = "/assets/data/dogs.jpg"; 2 | // export const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy"; 3 | export const MODEL_DIR = "/model/onnx_example.onnx"; 4 | export const MULTI_MASK_MODEL_DIR ="/model/meta_multi_onnx.onnx" 5 | export const API_ENDPOINT = "http://127.0.0.1:8000/ai/embedded"; 6 | export const ALL_MASK_API_ENDPOINT = "http://127.0.0.1:8000/ai/embedded/all"; 7 | export const ERASE_API_ENDPOINT = ""; 8 | 9 | //META set 10 | // const META_API_ENDPOINT = "https://model-zoo.metademolab.com/predictions/segment_everything_box_model" 11 | // const META_API_ALL_MASK_API_ENDPOINT = "https://model-zoo.metademolab.com/predictions/automatic_masks" 12 | // const META_MODEL_DIR = "../../model/interactive_module_quantized_592547_2023_03_19_sam6_long_uncertain.onnx" 13 | -------------------------------------------------------------------------------- /frontend/src/index.tsx: -------------------------------------------------------------------------------- 1 | import * as React from "react"; 2 | import { createRoot } from "react-dom/client"; 3 | import { createBrowserRouter, RouterProvider } from "react-router-dom"; 4 | import App from "./App"; 5 | import ErrorPage from "./components/ErrorPage"; 6 | import AppContextProvider from "./components/hooks/context"; 7 | 8 | const container = document.getElementById("root"); 9 | const root = createRoot(container!); 10 | 11 | const router = createBrowserRouter([ 12 | { 13 | path: "*", 14 | element: , 15 | errorElement: , 16 | }, 17 | ]); 18 | root.render( 19 | 20 | 21 | 22 | ); 23 | -------------------------------------------------------------------------------- /frontend/tailwind.config.js: -------------------------------------------------------------------------------- 1 | /** @type {import('tailwindcss').Config} */ 2 | module.exports = { 3 | content: ["./src/**/*.{html,js,tsx}"], 4 | theme: {}, 5 | plugins: [], 6 | }; 7 | -------------------------------------------------------------------------------- /frontend/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "lib": ["dom", "dom.iterable", "esnext"], 4 | "allowJs": true, 5 | "skipLibCheck": true, 6 | "strict": true, 7 | "forceConsistentCasingInFileNames": true, 8 | "noEmit": false, 9 | "esModuleInterop": true, 10 | "module": "esnext", 11 | "moduleResolution": "node", 12 | "resolveJsonModule": true, 13 | "isolatedModules": true, 14 | "jsx": "react", 15 | "incremental": true, 16 | "target": "ESNext", 17 | "useDefineForClassFields": true, 18 | "allowSyntheticDefaultImports": true, 19 | "outDir": "./dist/", 20 | "sourceMap": true 21 | }, 22 | "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", "src"], 23 | "exclude": 
["node_modules"] 24 | } 25 | --------------------------------------------------------------------------------