├── .github └── workflows │ ├── macos.yml │ ├── pypi.yml │ ├── ubuntu.yml │ └── windows.yml ├── .gitignore ├── Inference.py ├── LICENSE ├── MANIFEST.in ├── MORE_USAGES.md ├── README.md ├── app_gradio.py ├── assets ├── Overview.png ├── anomaly.png ├── building.png ├── dog_clip.png ├── eightpic.pdf ├── eightpic.png ├── head_fig.png ├── hf_everything_mode.png ├── hf_points_mode.png ├── logo.png ├── more_usages │ ├── box_prompt.png │ ├── draw_edge.png │ ├── everything_mode.png │ ├── everything_mode_without_retina.png │ ├── more_points.png │ └── text_prompt_cat.png ├── replicate-1.png ├── replicate-2.png ├── replicate-3.png └── salient.png ├── cog.yaml ├── examples ├── dogs.jpg ├── sa_10039.jpg ├── sa_11025.jpg ├── sa_1309.jpg ├── sa_192.jpg ├── sa_414.jpg ├── sa_561.jpg ├── sa_862.jpg └── sa_8776.jpg ├── fastsam ├── __init__.py ├── decoder.py ├── model.py ├── predict.py ├── prompt.py └── utils.py ├── images ├── cat.jpg └── dogs.jpg ├── output ├── cat.jpg └── dogs.jpg ├── predict.py ├── requirements.txt ├── segpredict.py ├── setup.py └── utils ├── __init__.py ├── tools.py └── tools_gradio.py /.github/workflows/macos.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | pull_request: 6 | branches: 7 | - main 8 | 9 | name: macOS build 10 | jobs: 11 | test-macOS: 12 | runs-on: ${{ matrix.config.os }} 13 | name: ${{ matrix.config.os }} (${{ matrix.config.py }}) 14 | strategy: 15 | fail-fast: false 16 | matrix: 17 | config: 18 | - { os: macOS-latest, py: "3.10" } 19 | env: 20 | SDKROOT: /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk 21 | steps: 22 | - name: CHECKOUT CODE 23 | uses: actions/checkout@v3 24 | - name: SETUP PYTHON 25 | uses: actions/setup-python@v4 26 | with: 27 | python-version: ${{ matrix.config.py }} 28 | # - name: Install GDAL 29 | # run: | 30 | # python -m pip install --upgrade pip 31 | # pip install --no-cache-dir Cython 32 | # pip install --find-links=https://girder.github.io/large_image_wheels --no-cache GDAL 33 | # - name: Test GDAL installation 34 | # run: | 35 | # python -c "from osgeo import gdal" 36 | # gdalinfo --version 37 | - name: Install dependencies 38 | run: | 39 | python -m pip install --upgrade pip 40 | pip install --no-cache-dir Cython 41 | pip install codespell -r requirements.txt 42 | pip install . 
43 | -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: pypi 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Set up Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: "3.x" 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install setuptools wheel twine 24 | - name: Build and publish 25 | env: 26 | TWINE_USERNAME: ${{ secrets.PYPI_USERS }} 27 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 28 | run: | 29 | python setup.py sdist bdist_wheel 30 | twine upload dist/* 31 | -------------------------------------------------------------------------------- /.github/workflows/ubuntu.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | pull_request: 6 | branches: 7 | - main 8 | 9 | name: Linux build 10 | jobs: 11 | py-check: 12 | runs-on: ${{ matrix.config.os }} 13 | name: ${{ matrix.config.os }} (${{ matrix.config.py }}) 14 | strategy: 15 | fail-fast: false 16 | matrix: 17 | config: 18 | - { os: ubuntu-latest, py: "3.8" } 19 | - { os: ubuntu-latest, py: "3.9" } 20 | - { os: ubuntu-latest, py: "3.10" } 21 | - { os: ubuntu-latest, py: "3.11" } 22 | 23 | env: 24 | SDKROOT: /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk 25 | steps: 26 | - name: CHECKOUT CODE 27 | uses: actions/checkout@v3 28 | - name: SETUP PYTHON 29 | uses: actions/setup-python@v4 30 | with: 31 | python-version: ${{ matrix.config.py }} 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | pip install --user --no-cache-dir Cython 36 | pip install --user -r requirements.txt 37 | pip install --user . 38 | -------------------------------------------------------------------------------- /.github/workflows/windows.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | pull_request: 6 | branches: 7 | - main 8 | 9 | name: Windows build 10 | jobs: 11 | test-windows: 12 | runs-on: windows-latest 13 | steps: 14 | - uses: actions/checkout@v3 15 | - name: Install miniconda 16 | uses: conda-incubator/setup-miniconda@v2 17 | with: 18 | auto-activate-base: true 19 | python-version: "3.10" 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install --no-cache-dir Cython 24 | pip install -r requirements.txt 25 | pip install . 
26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pyo 3 | *.pyd 4 | .DS_Store 5 | .idea 6 | weights 7 | build/ 8 | *.egg-info/ 9 | gradio_cached_examples 10 | dist/ -------------------------------------------------------------------------------- /Inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from fastsam import FastSAM, FastSAMPrompt 3 | import ast 4 | import torch 5 | from PIL import Image 6 | from utils.tools import convert_box_xywh_to_xyxy 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | "--model_path", type=str, default="./weights/FastSAM.pt", help="model" 13 | ) 14 | parser.add_argument( 15 | "--img_path", type=str, default="./images/dogs.jpg", help="path to image file" 16 | ) 17 | parser.add_argument("--imgsz", type=int, default=1024, help="image size") 18 | parser.add_argument( 19 | "--iou", 20 | type=float, 21 | default=0.9, 22 | help="iou threshold for filtering the annotations", 23 | ) 24 | parser.add_argument( 25 | "--text_prompt", type=str, default=None, help='use text prompt eg: "a dog"' 26 | ) 27 | parser.add_argument( 28 | "--conf", type=float, default=0.4, help="object confidence threshold" 29 | ) 30 | parser.add_argument( 31 | "--output", type=str, default="./output/", help="image save path" 32 | ) 33 | parser.add_argument( 34 | "--randomcolor", type=bool, default=True, help="mask random color" 35 | ) 36 | parser.add_argument( 37 | "--point_prompt", type=str, default="[[0,0]]", help="[[x1,y1],[x2,y2]]" 38 | ) 39 | parser.add_argument( 40 | "--point_label", 41 | type=str, 42 | default="[0]", 43 | help="[1,0] 0:background, 1:foreground", 44 | ) 45 | parser.add_argument("--box_prompt", type=str, default="[[0,0,0,0]]", help="[[x,y,w,h],[x2,y2,w2,h2]] support multiple boxes") 46 | parser.add_argument( 47 | "--better_quality", 48 | type=str, 49 | default=False, 50 | help="better quality using morphologyEx", 51 | ) 52 | device = torch.device( 53 | "cuda" 54 | if torch.cuda.is_available() 55 | else "mps" 56 | if torch.backends.mps.is_available() 57 | else "cpu" 58 | ) 59 | parser.add_argument( 60 | "--device", type=str, default=device, help="cuda:[0,1,2,3,4] or cpu" 61 | ) 62 | parser.add_argument( 63 | "--retina", 64 | type=bool, 65 | default=True, 66 | help="draw high-resolution segmentation masks", 67 | ) 68 | parser.add_argument( 69 | "--withContours", type=bool, default=False, help="draw the edges of the masks" 70 | ) 71 | return parser.parse_args() 72 | 73 | 74 | def main(args): 75 | # load model 76 | model = FastSAM(args.model_path) 77 | args.point_prompt = ast.literal_eval(args.point_prompt) 78 | args.box_prompt = convert_box_xywh_to_xyxy(ast.literal_eval(args.box_prompt)) 79 | args.point_label = ast.literal_eval(args.point_label) 80 | input = Image.open(args.img_path) 81 | input = input.convert("RGB") 82 | everything_results = model( 83 | input, 84 | device=args.device, 85 | retina_masks=args.retina, 86 | imgsz=args.imgsz, 87 | conf=args.conf, 88 | iou=args.iou 89 | ) 90 | bboxes = None 91 | points = None 92 | point_label = None 93 | prompt_process = FastSAMPrompt(input, everything_results, device=args.device) 94 | if args.box_prompt[0][2] != 0 and args.box_prompt[0][3] != 0: 95 | ann = prompt_process.box_prompt(bboxes=args.box_prompt) 96 | bboxes = args.box_prompt 97 | elif args.text_prompt != None: 98 | 
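# Text mode: FastSAMPrompt.text_prompt selects the generated mask that best matches the text (CLIP is used for the text-image matching; see fastsam/prompt.py and cog.yaml).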
ann = prompt_process.text_prompt(text=args.text_prompt) 99 | elif args.point_prompt[0] != [0, 0]: 100 | ann = prompt_process.point_prompt( 101 | points=args.point_prompt, pointlabel=args.point_label 102 | ) 103 | points = args.point_prompt 104 | point_label = args.point_label 105 | else: 106 | ann = prompt_process.everything_prompt() 107 | prompt_process.plot( 108 | annotations=ann, 109 | output_path=args.output+args.img_path.split("/")[-1], 110 | bboxes = bboxes, 111 | points = points, 112 | point_label = point_label, 113 | withContours=args.withContours, 114 | better_quality=args.better_quality, 115 | ) 116 | 117 | 118 | 119 | 120 | if __name__ == "__main__": 121 | args = parse_args() 122 | main(args) 123 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 | CASIA-IVA-Lab 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include README.md
3 | include requirements.txt
4 | 
5 | recursive-exclude * __pycache__
6 | recursive-exclude * *.py[co]
7 | 
8 | 
--------------------------------------------------------------------------------
/MORE_USAGES.md:
--------------------------------------------------------------------------------
1 | # MORE_USAGES
2 | 
3 | 
4 | 
5 | ### Everything mode
6 | Use `--imgsz` to change the input size.
7 | 
8 | ```shell
9 | python Inference.py --model_path ./weights/FastSAM.pt \
10 | --img_path ./images/dogs.jpg \
11 | --imgsz 720
12 | ```
13 | ![everything mode](assets/more_usages/everything_mode.png)
14 | 
15 | 
16 | 
17 | ### Use more points
18 | 
19 | ```shell
20 | python Inference.py --model_path ./weights/FastSAM.pt \
21 | --img_path ./images/dogs.jpg \
22 | --point_prompt "[[520,360],[620,300],[520,300],[620,360]]" \
23 | --point_label "[1,0,1,0]"
24 | ```
25 | ![points prompt](assets/more_usages/more_points.png)
26 | ### Draw mask edge
27 | Use `--withContours True` to draw the edges of the masks.
28 | 
29 | When `--better_quality True` is set, the edges will be smoother.
30 | 
31 | ```shell
32 | python Inference.py --model_path ./weights/FastSAM.pt \
33 | --img_path ./images/dogs.jpg \
34 | --point_prompt "[[620,360]]" \
35 | --point_label "[1]" \
36 | --withContours True \
37 | --better_quality True
38 | ```
39 | 
40 | ![Draw Edge](assets/more_usages/draw_edge.png)
41 | ### Use box prompt
42 | Use `--box_prompt "[[x,y,w,h]]"` to specify the bounding box of the foreground object.
43 | ```shell
44 | python Inference.py --model_path ./weights/FastSAM.pt \
45 | --img_path ./images/dogs.jpg \
46 | --box_prompt "[[570,200,230,400]]"
47 | ```
48 | ![box prompt](assets/more_usages/box_prompt.png)
49 | 
50 | ### Use text prompt
51 | Use `--text_prompt "text"` to specify the text prompt.
52 | ```shell
53 | python Inference.py --model_path ./weights/FastSAM.pt \
54 | --img_path ./images/cat.jpg \
55 | --text_prompt "cat" \
56 | --better_quality True \
57 | --withContours True
58 | ```
59 | ![text prompt](assets/more_usages/text_prompt_cat.png)
60 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![](assets/logo.png)
2 | 
3 | # Fast Segment Anything
4 | 
5 | [![image](https://img.shields.io/pypi/v/segment-anything-fast.svg)](https://pypi.python.org/pypi/segment-anything-fast)
6 | 
7 | [[`📕Paper`](https://arxiv.org/pdf/2306.12156.pdf)] [[`🤗HuggingFace Demo`](https://huggingface.co/spaces/An-619/FastSAM)] [[`Colab demo`](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)] [[`Replicate demo & API`](https://replicate.com/casia-iva-lab/fastsam)] [[`Model Zoo`](#model-checkpoints)] [[`BibTeX`](#citing-fastsam)]
8 | 
9 | ![FastSAM Speed](assets/head_fig.png)
10 | 
11 | The **Fast Segment Anything Model (FastSAM)** is a CNN-based Segment Anything Model trained using only 2% of the SA-1B dataset published by the SAM authors. FastSAM achieves performance comparable to
12 | the SAM method at **50× higher run-time speed**.
13 | 
14 | ![FastSAM design](assets/Overview.png)
15 | 
16 | **🍇 Updates**
17 | 
18 | - **`2023/07/06`** Added to [Ultralytics (YOLOv8) Model Hub](https://docs.ultralytics.com/models/fast-sam/). Thanks to [Ultralytics](https://github.com/ultralytics/ultralytics) for the help 🌹.
19 | - **`2023/06/29`** Support [text mode](https://huggingface.co/spaces/An-619/FastSAM) in HuggingFace Space. Thanks a lot to [gaoxinge](https://github.com/gaoxinge) for the help 🌹.
20 | - **`2023/06/29`** Release [FastSAM_Awesome_TensorRT](https://github.com/ChuRuaNh0/FastSam_Awsome_TensorRT). Thanks a lot to [ChuRuaNh0](https://github.com/ChuRuaNh0) for providing the TensorRT model of FastSAM 🌹.
21 | - **`2023/06/26`** Release [FastSAM Replicate Online Demo](https://replicate.com/casia-iva-lab/fastsam). Thanks a lot to [Chenxi](https://chenxwh.github.io/) for providing this nice demo 🌹.
22 | - **`2023/06/26`** Support [points mode](https://huggingface.co/spaces/An-619/FastSAM) in HuggingFace Space. Better and faster interaction will come soon!
23 | - **`2023/06/24`** Thanks a lot to [Grounding-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything) for combining Grounding-DINO with FastSAM in [Grounded-FastSAM](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM) 🌹.
24 | 
25 | ## Installation
26 | 
27 | Install the package from PyPI:
28 | 
29 | ```shell
30 | pip install segment-anything-fast
31 | ```
32 | 
33 | ## Getting Started
34 | 
35 | First download a [model checkpoint](#model-checkpoints).
36 | 
37 | Then, you can run the scripts to try the everything mode and the three prompt modes.
38 | 
39 | ```shell
40 | # Everything mode
41 | python Inference.py --model_path ./weights/FastSAM.pt --img_path ./images/dogs.jpg
42 | ```
43 | 
44 | ```shell
45 | # Text prompt
46 | python Inference.py --model_path ./weights/FastSAM.pt --img_path ./images/dogs.jpg --text_prompt "the yellow dog"
47 | ```
48 | 
49 | ```shell
50 | # Box prompt (xywh)
51 | python Inference.py --model_path ./weights/FastSAM.pt --img_path ./images/dogs.jpg --box_prompt "[[570,200,230,400]]"
52 | ```
53 | 
54 | ```shell
55 | # Points prompt
56 | python Inference.py --model_path ./weights/FastSAM.pt --img_path ./images/dogs.jpg --point_prompt "[[520,360],[620,300]]" --point_label "[1,0]"
57 | ```
58 | 
59 | You can use the following code to generate all masks, select masks based on prompts, and visualize the results.
60 | 
61 | ```python
62 | from fastsam import FastSAM, FastSAMPrompt
63 | 
64 | model = FastSAM('./weights/FastSAM.pt')
65 | IMAGE_PATH = './images/dogs.jpg'
66 | DEVICE = 'cpu'
67 | everything_results = model(IMAGE_PATH, device=DEVICE, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9,)
68 | prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE)
69 | 
70 | # everything prompt
71 | ann = prompt_process.everything_prompt()
72 | 
73 | # box prompt: boxes are given as [x1,y1,x2,y2] (default [[0,0,0,0]])
74 | ann = prompt_process.box_prompt(bboxes=[[200, 200, 300, 300]])
75 | 
76 | # text prompt
77 | ann = prompt_process.text_prompt(text='a photo of a dog')
78 | 
79 | # point prompt
80 | # points default [[0,0]] [[x1,y1],[x2,y2]]
81 | # point_label default [0] [1,0] 0:background, 1:foreground
82 | ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1])
83 | 
84 | prompt_process.plot(annotations=ann, output_path='./output/dog.jpg')
85 | ```
86 | 
87 | You are also welcome to try our Colab demo: [FastSAM_example.ipynb](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing).
88 | 
89 | ## Different Inference Options
90 | 
91 | We provide various options for different purposes; details are in [MORE_USAGES.md](MORE_USAGES.md).
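The CLI flags documented in MORE_USAGES.md map directly onto the Python API that `Inference.py` wraps. As a minimal sketch (not part of the original README), assuming the default checkpoint has been downloaded to `./weights/FastSAM.pt`, a point prompt together with the contour and quality options can be driven programmatically like this:

```python
from fastsam import FastSAM, FastSAMPrompt

model = FastSAM('./weights/FastSAM.pt')
IMAGE_PATH = './images/dogs.jpg'
DEVICE = 'cpu'  # or 'cuda' / 'mps' when available

# Generate all masks once; every prompt type reuses these results.
everything_results = model(IMAGE_PATH, device=DEVICE, retina_masks=True,
                           imgsz=1024, conf=0.4, iou=0.9)
prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE)

# Point prompt: label 1 marks foreground, 0 marks background.
points = [[620, 360]]
point_label = [1]
ann = prompt_process.point_prompt(points=points, pointlabel=point_label)

# withContours and better_quality mirror the --withContours / --better_quality CLI flags.
prompt_process.plot(annotations=ann,
                    output_path='./output/dogs.jpg',
                    points=points,
                    point_label=point_label,
                    withContours=True,
                    better_quality=True)
```

Box and text prompts work the same way through `box_prompt(bboxes=...)` and `text_prompt(text=...)`, as shown in the Getting Started example above.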
92 | 
93 | ## Web demo
94 | 
95 | ### Gradio demo
96 | 
97 | - We also provide a UI for testing our method that is built with Gradio. You can upload a custom image, select the mode, set the parameters, click the segment button, and get a satisfactory segmentation result. Currently, the UI supports interaction with the 'Everything mode' and 'points mode'. We plan to add support for additional modes in the future. Running the following command in a terminal will launch the demo:
98 | 
99 | ```shell
100 | # Download the pre-trained model to "./weights/FastSAM.pt" first
101 | python app_gradio.py
102 | ```
103 | 
104 | - This demo is also hosted on [HuggingFace Space](https://huggingface.co/spaces/An-619/FastSAM).
105 | 
106 | ![HF_Everything](assets/hf_everything_mode.png) ![HF_Points](assets/hf_points_mode.png)
107 | 
108 | ### Replicate demo
109 | 
110 | - The [Replicate demo](https://replicate.com/casia-iva-lab/fastsam) supports all modes; you can try the points/box/text modes there.
111 | 
112 | ![Replicate-1](assets/replicate-1.png) ![Replicate-2](assets/replicate-2.png) ![Replicate-3](assets/replicate-3.png)
113 | 
114 | ## Model Checkpoints
115 | 
116 | Two versions of the model are available, with different sizes. Click the links below to download the checkpoint for the corresponding model type.
117 | 
118 | - **`default` or `FastSAM`: [YOLOv8x based Segment Anything Model](https://drive.google.com/file/d/1m1sjY4ihXBU1fZXdQ-Xdj-mDltW-2Rqv/view?usp=sharing) | [Baidu Cloud (pwd: 0000).](https://pan.baidu.com/s/18KzBmOTENjByoWWR17zdiQ?pwd=0000)**
119 | - `FastSAM-s`: [YOLOv8s based Segment Anything Model.](https://drive.google.com/file/d/10XmSj6mmpmRb8NhXbtiuO9cTTBwR_9SV/view?usp=sharing)
120 | 
121 | ## Results
122 | 
123 | All results were tested on a single NVIDIA GeForce RTX 3090.
124 | 
125 | ### 1. Inference time
126 | 
127 | Running Speed under Different Point Prompt Numbers (ms).
128 | | method | params | 1 | 10 | 100 | E(16x16) | E(32x32\*) | E(64x64) |
129 | |:------------------:|:--------:|:-----:|:-----:|:-----:|:----------:|:-----------:|:----------:|
130 | | SAM-H | 0.6G | 446 | 464 | 627 | 852 | 2099 | 6972 |
131 | | SAM-B | 136M | 110 | 125 | 230 | 432 | 1383 | 5417 |
132 | | FastSAM | 68M | 40 | 40 | 40 | 40 | 40 | 40 |
133 | 
134 | ### 2. Memory usage
135 | 
136 | | Dataset | Method | GPU Memory (MB) |
137 | | :-------: | :-----: | :-------------: |
138 | | COCO 2017 | FastSAM | 2608 |
139 | | COCO 2017 | SAM-H | 7060 |
140 | | COCO 2017 | SAM-B | 4670 |
141 | 
142 | ### 3. Zero-shot Transfer Experiments
143 | 
144 | #### Edge Detection
145 | 
146 | Tested on the BSDS500 dataset.
147 | | method | year | ODS | OIS | AP | R50 |
148 | |:----------:|:-------:|:--------:|:--------:|:------:|:-----:|
149 | | HED | 2015 | .788 | .808 | .840 | .923 |
150 | | SAM | 2023 | .768 | .786 | .794 | .928 |
151 | | FastSAM | 2023 | .750 | .790 | .793 | .903 |
152 | 
153 | #### Object Proposals
154 | 
155 | ##### COCO
156 | 
157 | | method | AR10 | AR100 | AR1000 | AUC |
158 | | :-------: | :--: | :---: | :----: | :--: |
159 | | SAM-H E64 | 15.5 | 45.6 | 67.7 | 32.1 |
160 | | SAM-H E32 | 18.5 | 49.5 | 62.5 | 33.7 |
161 | | SAM-B E32 | 11.4 | 39.6 | 59.1 | 27.3 |
162 | | FastSAM | 15.7 | 47.3 | 63.7 | 32.2 |
163 | 
164 | ##### LVIS
165 | 
166 | bbox AR@1000
167 | | method | all | small | med. | large |
168 | |:---------------:|:-----:|:------:|:-----:|:------:|
169 | | ViTDet-H | 65.0 | 53.2 | 83.3 | 91.2 |
170 | zero-shot transfer methods
171 | | SAM-H E64 | 52.1 | 36.6 | 75.1 | 88.2 |
172 | | SAM-H E32 | 50.3 | 33.1 | 76.2 | 89.8 |
173 | | SAM-B E32 | 45.0 | 29.3 | 68.7 | 80.6 |
174 | | FastSAM | 57.1 | 44.3 | 77.1 | 85.3 |
175 | 
176 | #### Instance Segmentation on COCO 2017
177 | 
178 | | method | AP | APS | APM | APL |
179 | | :------: | :--: | :--: | :--: | :--: |
180 | | ViTDet-H | .510 | .320 | .543 | .689 |
181 | | SAM | .465 | .308 | .510 | .617 |
182 | | FastSAM | .379 | .239 | .434 | .500 |
183 | 
184 | ### 4. Performance Visualization
185 | 
186 | Several segmentation results:
187 | 
188 | #### Natural Images
189 | 
190 | ![Natural Images](assets/eightpic.png)
191 | 
192 | #### Text to Mask
193 | 
194 | ![Text to Mask](assets/dog_clip.png)
195 | 
196 | ### 5. Downstream Tasks
197 | 
198 | The results on several downstream tasks show the effectiveness of FastSAM.
199 | 
200 | #### Anomaly Detection
201 | 
202 | ![Anomaly Detection](assets/anomaly.png)
203 | 
204 | #### Salient Object Detection
205 | 
206 | ![Salient Object Detection](assets/salient.png)
207 | 
208 | #### Building Extraction
209 | 
210 | ![Building Detection](assets/building.png)
211 | 
212 | ## License
213 | 
214 | The model is licensed under the [Apache 2.0 license](LICENSE).
215 | 
216 | ## Acknowledgement
217 | 
218 | - [Segment Anything](https://segment-anything.com/) provides the SA-1B dataset and the base codes.
219 | - [YOLOv8](https://github.com/ultralytics/ultralytics) provides codes and pre-trained models.
220 | - [YOLACT](https://arxiv.org/abs/2112.10003) provides a powerful instance segmentation method.
221 | - [Grounded-Segment-Anything](https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything) provides a useful web demo template.
222 | 
223 | ## Contributors
224 | 
225 | Our project wouldn't be possible without the contributions of these amazing people! Thank you all for making this project better.
226 | 
227 | 
228 | 
229 | 
230 | 
231 | ## Citing FastSAM
232 | 
233 | If you find this project useful for your research, please consider citing the following BibTeX entry.
234 | 
235 | ```
236 | @misc{zhao2023fast,
237 |   title={Fast Segment Anything},
238 |   author={Xu Zhao and Wenchao Ding and Yongqi An and Yinglong Du and Tao Yu and Min Li and Ming Tang and Jinqiao Wang},
239 |   year={2023},
240 |   eprint={2306.12156},
241 |   archivePrefix={arXiv},
242 |   primaryClass={cs.CV}
243 | }
244 | ```
245 | 
246 | [![Star History Chart](https://api.star-history.com/svg?repos=CASIA-IVA-Lab/FastSAM&type=Date)](https://star-history.com/#CASIA-IVA-Lab/FastSAM&Date)
247 | 
--------------------------------------------------------------------------------
/app_gradio.py:
--------------------------------------------------------------------------------
1 | from ultralytics import YOLO
2 | import gradio as gr
3 | import torch
4 | from utils.tools_gradio import fast_process
5 | from utils.tools import format_results, box_prompt, point_prompt, text_prompt
6 | from PIL import ImageDraw
7 | import numpy as np
8 | 
9 | # Load the pre-trained model
10 | model = YOLO('./weights/FastSAM.pt')
11 | 
12 | device = torch.device(
13 |     "cuda"
14 |     if torch.cuda.is_available()
15 |     else "mps"
16 |     if torch.backends.mps.is_available()
17 |     else "cpu"
18 | )
19 | 
20 | # Description
21 | title = "
🏃 Fast Segment Anything 🤗
" 22 | 23 | news = """ # 📖 News 24 | 🔥 2023/07/14: Add a "wider result" button in text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/95)). 25 | 26 | 🔥 2023/06/29: Support the text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/47)). 27 | 28 | 🔥 2023/06/26: Support the points mode. (Better and faster interaction will come soon!) 29 | 30 | 🔥 2023/06/24: Add the 'Advanced options" in Everything mode to get a more detailed adjustment. 31 | """ 32 | 33 | description_e = """This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it. 34 | 35 | 🎯 Upload an Image, segment it with Fast Segment Anything (Everything mode). The other modes will come soon. 36 | 37 | ⌛️ It takes about 6~ seconds to generate segment results. The concurrency_count of queue is 1, please wait for a moment when it is crowded. 38 | 39 | 🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked. 40 | 41 | 📣 You can also obtain the segmentation results of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing) 42 | 43 | 😚 A huge thanks goes out to the @HuggingFace Team for supporting us with GPU grant. 44 | 45 | 🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM) 46 | 47 | """ 48 | 49 | description_p = """ # 🎯 Instructions for points mode 50 | This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it. 51 | 52 | 1. Upload an image or choose an example. 53 | 54 | 2. Choose the point label ('Add mask' means a positive point. 'Remove' Area means a negative point that is not segmented). 55 | 56 | 3. Add points one by one on the image. 57 | 58 | 4. Click the 'Segment with points prompt' button to get the segmentation results. 59 | 60 | **5. If you get Error, click the 'Clear points' button and try again may help.** 61 | 62 | """ 63 | 64 | examples = [["examples/sa_8776.jpg"], ["examples/sa_414.jpg"], ["examples/sa_1309.jpg"], ["examples/sa_11025.jpg"], 65 | ["examples/sa_561.jpg"], ["examples/sa_192.jpg"], ["examples/sa_10039.jpg"], ["examples/sa_862.jpg"]] 66 | 67 | default_example = examples[0] 68 | 69 | css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }" 70 | 71 | 72 | def segment_everything( 73 | input, 74 | input_size=1024, 75 | iou_threshold=0.7, 76 | conf_threshold=0.25, 77 | better_quality=False, 78 | withContours=True, 79 | use_retina=True, 80 | text="", 81 | wider=False, 82 | mask_random_color=True, 83 | ): 84 | input_size = int(input_size) # 确保 imgsz 是整数 85 | # Thanks for the suggestion by hysts in HuggingFace. 
86 | w, h = input.size 87 | scale = input_size / max(w, h) 88 | new_w = int(w * scale) 89 | new_h = int(h * scale) 90 | input = input.resize((new_w, new_h)) 91 | 92 | results = model(input, 93 | device=device, 94 | retina_masks=True, 95 | iou=iou_threshold, 96 | conf=conf_threshold, 97 | imgsz=input_size,) 98 | 99 | if len(text) > 0: 100 | results = format_results(results[0], 0) 101 | annotations, _ = text_prompt(results, text, input, device=device, wider=wider) 102 | annotations = np.array([annotations]) 103 | else: 104 | annotations = results[0].masks.data 105 | 106 | fig = fast_process(annotations=annotations, 107 | image=input, 108 | device=device, 109 | scale=(1024 // input_size), 110 | better_quality=better_quality, 111 | mask_random_color=mask_random_color, 112 | bbox=None, 113 | use_retina=use_retina, 114 | withContours=withContours,) 115 | return fig 116 | 117 | 118 | def segment_with_points( 119 | input, 120 | input_size=1024, 121 | iou_threshold=0.7, 122 | conf_threshold=0.25, 123 | better_quality=False, 124 | withContours=True, 125 | use_retina=True, 126 | mask_random_color=True, 127 | ): 128 | global global_points 129 | global global_point_label 130 | 131 | input_size = int(input_size) # 确保 imgsz 是整数 132 | # Thanks for the suggestion by hysts in HuggingFace. 133 | w, h = input.size 134 | scale = input_size / max(w, h) 135 | new_w = int(w * scale) 136 | new_h = int(h * scale) 137 | input = input.resize((new_w, new_h)) 138 | 139 | scaled_points = [[int(x * scale) for x in point] for point in global_points] 140 | 141 | results = model(input, 142 | device=device, 143 | retina_masks=True, 144 | iou=iou_threshold, 145 | conf=conf_threshold, 146 | imgsz=input_size,) 147 | 148 | results = format_results(results[0], 0) 149 | annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w) 150 | annotations = np.array([annotations]) 151 | 152 | fig = fast_process(annotations=annotations, 153 | image=input, 154 | device=device, 155 | scale=(1024 // input_size), 156 | better_quality=better_quality, 157 | mask_random_color=mask_random_color, 158 | bbox=None, 159 | use_retina=use_retina, 160 | withContours=withContours,) 161 | 162 | global_points = [] 163 | global_point_label = [] 164 | return fig, None 165 | 166 | 167 | def get_points_with_draw(image, label, evt: gr.SelectData): 168 | global global_points 169 | global global_point_label 170 | 171 | x, y = evt.index[0], evt.index[1] 172 | point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255) 173 | global_points.append([x, y]) 174 | global_point_label.append(1 if label == 'Add Mask' else 0) 175 | 176 | print(x, y, label == 'Add Mask') 177 | 178 | # 创建一个可以在图像上绘图的对象 179 | draw = ImageDraw.Draw(image) 180 | draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color) 181 | return image 182 | 183 | 184 | cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil') 185 | cond_img_p = gr.Image(label="Input with points", value=default_example[0], type='pil') 186 | cond_img_t = gr.Image(label="Input with text", value="examples/dogs.jpg", type='pil') 187 | 188 | segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil') 189 | segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil') 190 | segm_img_t = gr.Image(label="Segmented Image with text", interactive=False, type='pil') 191 | 192 | global_points = [] 193 | global_point_label = [] 194 | 195 | input_size_slider = 
gr.components.Slider(minimum=512, 196 | maximum=1024, 197 | value=1024, 198 | step=64, 199 | label='Input_size', 200 | info='Our model was trained on a size of 1024') 201 | 202 | with gr.Blocks(css=css, title='Fast Segment Anything') as demo: 203 | with gr.Row(): 204 | with gr.Column(scale=1): 205 | # Title 206 | gr.Markdown(title) 207 | 208 | with gr.Column(scale=1): 209 | # News 210 | gr.Markdown(news) 211 | 212 | with gr.Tab("Everything mode"): 213 | # Images 214 | with gr.Row(variant="panel"): 215 | with gr.Column(scale=1): 216 | cond_img_e.render() 217 | 218 | with gr.Column(scale=1): 219 | segm_img_e.render() 220 | 221 | # Submit & Clear 222 | with gr.Row(): 223 | with gr.Column(): 224 | input_size_slider.render() 225 | 226 | with gr.Row(): 227 | contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks') 228 | 229 | with gr.Column(): 230 | segment_btn_e = gr.Button("Segment Everything", variant='primary') 231 | clear_btn_e = gr.Button("Clear", variant="secondary") 232 | 233 | gr.Markdown("Try some of the examples below ⬇️") 234 | gr.Examples(examples=examples, 235 | inputs=[cond_img_e], 236 | outputs=segm_img_e, 237 | fn=segment_everything, 238 | cache_examples=True, 239 | examples_per_page=4) 240 | 241 | with gr.Column(): 242 | with gr.Accordion("Advanced options", open=False): 243 | iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations') 244 | conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold') 245 | with gr.Row(): 246 | mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx') 247 | with gr.Column(): 248 | retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks') 249 | 250 | # Description 251 | gr.Markdown(description_e) 252 | 253 | segment_btn_e.click(segment_everything, 254 | inputs=[ 255 | cond_img_e, 256 | input_size_slider, 257 | iou_threshold, 258 | conf_threshold, 259 | mor_check, 260 | contour_check, 261 | retina_check, 262 | ], 263 | outputs=segm_img_e) 264 | 265 | with gr.Tab("Points mode"): 266 | # Images 267 | with gr.Row(variant="panel"): 268 | with gr.Column(scale=1): 269 | cond_img_p.render() 270 | 271 | with gr.Column(scale=1): 272 | segm_img_p.render() 273 | 274 | # Submit & Clear 275 | with gr.Row(): 276 | with gr.Column(): 277 | with gr.Row(): 278 | add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)") 279 | 280 | with gr.Column(): 281 | segment_btn_p = gr.Button("Segment with points prompt", variant='primary') 282 | clear_btn_p = gr.Button("Clear points", variant='secondary') 283 | 284 | gr.Markdown("Try some of the examples below ⬇️") 285 | gr.Examples(examples=examples, 286 | inputs=[cond_img_p], 287 | # outputs=segm_img_p, 288 | # fn=segment_with_points, 289 | # cache_examples=True, 290 | examples_per_page=4) 291 | 292 | with gr.Column(): 293 | # Description 294 | gr.Markdown(description_p) 295 | 296 | cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p) 297 | 298 | segment_btn_p.click(segment_with_points, 299 | inputs=[cond_img_p], 300 | outputs=[segm_img_p, cond_img_p]) 301 | 302 | with gr.Tab("Text mode"): 303 | # Images 304 | with gr.Row(variant="panel"): 305 | with gr.Column(scale=1): 306 | cond_img_t.render() 307 | 308 | with gr.Column(scale=1): 309 | segm_img_t.render() 310 | 311 | # Submit & Clear 312 | with 
gr.Row(): 313 | with gr.Column(): 314 | input_size_slider_t = gr.components.Slider(minimum=512, 315 | maximum=1024, 316 | value=1024, 317 | step=64, 318 | label='Input_size', 319 | info='Our model was trained on a size of 1024') 320 | with gr.Row(): 321 | with gr.Column(): 322 | contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks') 323 | text_box = gr.Textbox(label="text prompt", value="a black dog") 324 | 325 | with gr.Column(): 326 | segment_btn_t = gr.Button("Segment with text", variant='primary') 327 | clear_btn_t = gr.Button("Clear", variant="secondary") 328 | 329 | gr.Markdown("Try some of the examples below ⬇️") 330 | gr.Examples(examples=[["examples/dogs.jpg"]] + examples, 331 | inputs=[cond_img_e], 332 | # outputs=segm_img_e, 333 | # fn=segment_everything, 334 | # cache_examples=True, 335 | examples_per_page=4) 336 | 337 | with gr.Column(): 338 | with gr.Accordion("Advanced options", open=False): 339 | iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations') 340 | conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold') 341 | with gr.Row(): 342 | mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx') 343 | retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks') 344 | wider_check = gr.Checkbox(value=False, label='wider', info='wider result') 345 | 346 | # Description 347 | gr.Markdown(description_e) 348 | 349 | segment_btn_t.click(segment_everything, 350 | inputs=[ 351 | cond_img_t, 352 | input_size_slider_t, 353 | iou_threshold, 354 | conf_threshold, 355 | mor_check, 356 | contour_check, 357 | retina_check, 358 | text_box, 359 | wider_check, 360 | ], 361 | outputs=segm_img_t) 362 | 363 | def clear(): 364 | return None, None 365 | 366 | def clear_text(): 367 | return None, None, None 368 | 369 | clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e]) 370 | clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p]) 371 | clear_btn_t.click(clear_text, outputs=[cond_img_p, segm_img_p, text_box]) 372 | 373 | demo.queue() 374 | demo.launch() 375 | -------------------------------------------------------------------------------- /assets/Overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/Overview.png -------------------------------------------------------------------------------- /assets/anomaly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/anomaly.png -------------------------------------------------------------------------------- /assets/building.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/building.png -------------------------------------------------------------------------------- /assets/dog_clip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/dog_clip.png -------------------------------------------------------------------------------- /assets/eightpic.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/eightpic.pdf -------------------------------------------------------------------------------- /assets/eightpic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/eightpic.png -------------------------------------------------------------------------------- /assets/head_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/head_fig.png -------------------------------------------------------------------------------- /assets/hf_everything_mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/hf_everything_mode.png -------------------------------------------------------------------------------- /assets/hf_points_mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/hf_points_mode.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/logo.png -------------------------------------------------------------------------------- /assets/more_usages/box_prompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/more_usages/box_prompt.png -------------------------------------------------------------------------------- /assets/more_usages/draw_edge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/more_usages/draw_edge.png -------------------------------------------------------------------------------- /assets/more_usages/everything_mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/more_usages/everything_mode.png -------------------------------------------------------------------------------- /assets/more_usages/everything_mode_without_retina.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/more_usages/everything_mode_without_retina.png -------------------------------------------------------------------------------- /assets/more_usages/more_points.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/more_usages/more_points.png -------------------------------------------------------------------------------- /assets/more_usages/text_prompt_cat.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/more_usages/text_prompt_cat.png -------------------------------------------------------------------------------- /assets/replicate-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/replicate-1.png -------------------------------------------------------------------------------- /assets/replicate-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/replicate-2.png -------------------------------------------------------------------------------- /assets/replicate-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/replicate-3.png -------------------------------------------------------------------------------- /assets/salient.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/salient.png -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | # Thanks for chenxwh. 4 | 5 | build: 6 | # set to true if your model requires a GPU 7 | gpu: true 8 | cuda: "11.7" 9 | system_packages: 10 | - "libgl1-mesa-glx" 11 | - "libglib2.0-0" 12 | python_version: "3.8" 13 | python_packages: 14 | - "matplotlib==3.7.1" 15 | - "opencv-python==4.7.0.72" 16 | - "Pillow==9.5.0" 17 | - "PyYAML==6.0" 18 | - "requests==2.31.0" 19 | - "scipy==1.10.1" 20 | - "torch==2.0.1" 21 | - "torchvision==0.15.2" 22 | - "tqdm==4.65.0" 23 | - "pandas==2.0.2" 24 | - "seaborn==0.12.0" 25 | - "ultralytics==8.0.121" 26 | - git+https://github.com/openai/CLIP.git 27 | predict: "predict.py:Predictor" 28 | -------------------------------------------------------------------------------- /examples/dogs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/dogs.jpg -------------------------------------------------------------------------------- /examples/sa_10039.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_10039.jpg -------------------------------------------------------------------------------- /examples/sa_11025.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_11025.jpg -------------------------------------------------------------------------------- /examples/sa_1309.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_1309.jpg 
-------------------------------------------------------------------------------- /examples/sa_192.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_192.jpg -------------------------------------------------------------------------------- /examples/sa_414.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_414.jpg -------------------------------------------------------------------------------- /examples/sa_561.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_561.jpg -------------------------------------------------------------------------------- /examples/sa_862.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_862.jpg -------------------------------------------------------------------------------- /examples/sa_8776.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_8776.jpg -------------------------------------------------------------------------------- /fastsam/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .model import FastSAM 4 | from .predict import FastSAMPredictor 5 | from .prompt import FastSAMPrompt 6 | # from .val import FastSAMValidator 7 | from .decoder import FastSAMDecoder 8 | 9 | __all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMDecoder' 10 | -------------------------------------------------------------------------------- /fastsam/decoder.py: -------------------------------------------------------------------------------- 1 | from .model import FastSAM 2 | import numpy as np 3 | from PIL import Image 4 | import clip 5 | from typing import Optional, List, Tuple, Union 6 | 7 | 8 | class FastSAMDecoder: 9 | def __init__( 10 | self, 11 | model: FastSAM, 12 | device: str='cpu', 13 | conf: float=0.4, 14 | iou: float=0.9, 15 | imgsz: int=1024, 16 | retina_masks: bool=True, 17 | ): 18 | self.model = model 19 | self.device = device 20 | self.retina_masks = retina_masks 21 | self.imgsz = imgsz 22 | self.conf = conf 23 | self.iou = iou 24 | self.image = None 25 | self.image_embedding = None 26 | 27 | def run_encoder(self, image): 28 | if isinstance(image,str): 29 | image = np.array(Image.open(image)) 30 | self.image = image 31 | image_embedding = self.model( 32 | self.image, 33 | device=self.device, 34 | retina_masks=self.retina_masks, 35 | imgsz=self.imgsz, 36 | conf=self.conf, 37 | iou=self.iou 38 | ) 39 | return image_embedding[0].numpy() 40 | 41 | def run_decoder( 42 | self, 43 | image_embedding, 44 | point_prompt: Optional[np.ndarray]=None, 45 | point_label: Optional[np.ndarray]=None, 46 | box_prompt: Optional[np.ndarray]=None, 47 | text_prompt: Optional[str]=None, 48 | )->np.ndarray: 49 | self.image_embedding = image_embedding 50 | if point_prompt is not None: 51 | ann = self.point_prompt(points=point_prompt, pointlabel=point_label) 52 | return ann 53 | elif box_prompt is not 
None: 54 | ann = self.box_prompt(bbox=box_prompt) 55 | return ann 56 | elif text_prompt is not None: 57 | ann = self.text_prompt(text=text_prompt) 58 | return ann 59 | else: 60 | return None 61 | 62 | def box_prompt(self, bbox): 63 | assert (bbox[2] != 0 and bbox[3] != 0) 64 | masks = self.image_embedding.masks.data 65 | target_height = self.image.shape[0] 66 | target_width = self.image.shape[1] 67 | h = masks.shape[1] 68 | w = masks.shape[2] 69 | if h != target_height or w != target_width: 70 | bbox = [ 71 | int(bbox[0] * w / target_width), 72 | int(bbox[1] * h / target_height), 73 | int(bbox[2] * w / target_width), 74 | int(bbox[3] * h / target_height), ] 75 | bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0 76 | bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0 77 | bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w 78 | bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h 79 | 80 | # IoUs = torch.zeros(len(masks), dtype=torch.float32) 81 | bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) 82 | 83 | masks_area = np.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], axis=(1, 2)) 84 | orig_masks_area = np.sum(masks, axis=(1, 2)) 85 | 86 | union = bbox_area + orig_masks_area - masks_area 87 | IoUs = masks_area / union 88 | max_iou_index = np.argmax(IoUs) 89 | 90 | return np.array([masks[max_iou_index].cpu().numpy()]) 91 | 92 | def point_prompt(self, points, pointlabel): # numpy 93 | 94 | masks = self._format_results(self.image_embedding[0], 0) 95 | target_height = self.image.shape[0] 96 | target_width = self.image.shape[1] 97 | h = masks[0]['segmentation'].shape[0] 98 | w = masks[0]['segmentation'].shape[1] 99 | if h != target_height or w != target_width: 100 | points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points] 101 | onemask = np.zeros((h, w)) 102 | masks = sorted(masks, key=lambda x: x['area'], reverse=True) 103 | for i, annotation in enumerate(masks): 104 | if type(annotation) == dict: 105 | mask = annotation['segmentation'] 106 | else: 107 | mask = annotation 108 | for i, point in enumerate(points): 109 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 1: 110 | onemask[mask] = 1 111 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 0: 112 | onemask[mask] = 0 113 | onemask = onemask >= 1 114 | return np.array([onemask]) 115 | 116 | def _format_results(self, result, filter=0): 117 | annotations = [] 118 | n = len(result.masks.data) 119 | for i in range(n): 120 | annotation = {} 121 | mask = result.masks.data[i] == 1.0 122 | 123 | if np.sum(mask) < filter: 124 | continue 125 | annotation['id'] = i 126 | annotation['segmentation'] = mask 127 | annotation['bbox'] = result.boxes.data[i] 128 | annotation['score'] = result.boxes.conf[i] 129 | annotation['area'] = annotation['segmentation'].sum() 130 | annotations.append(annotation) 131 | return annotations 132 | -------------------------------------------------------------------------------- /fastsam/model.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | FastSAM model interface. 
4 | 5 | Usage - Predict: 6 | from ultralytics import FastSAM 7 | 8 | model = FastSAM('last.pt') 9 | results = model.predict('ultralytics/assets/bus.jpg') 10 | """ 11 | 12 | from ultralytics.yolo.cfg import get_cfg 13 | from ultralytics.yolo.engine.exporter import Exporter 14 | from ultralytics.yolo.engine.model import YOLO 15 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir 16 | from ultralytics.yolo.utils.checks import check_imgsz 17 | 18 | from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode 19 | from .predict import FastSAMPredictor 20 | 21 | 22 | class FastSAM(YOLO): 23 | 24 | @smart_inference_mode() 25 | def predict(self, source=None, stream=False, **kwargs): 26 | """ 27 | Perform prediction using the YOLO model. 28 | 29 | Args: 30 | source (str | int | PIL | np.ndarray): The source of the image to make predictions on. 31 | Accepts all source types accepted by the YOLO model. 32 | stream (bool): Whether to stream the predictions or not. Defaults to False. 33 | **kwargs : Additional keyword arguments passed to the predictor. 34 | Check the 'configuration' section in the documentation for all available options. 35 | 36 | Returns: 37 | (List[ultralytics.yolo.engine.results.Results]): The prediction results. 38 | """ 39 | if source is None: 40 | source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' 41 | LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") 42 | overrides = self.overrides.copy() 43 | overrides['conf'] = 0.25 44 | overrides.update(kwargs) # prefer kwargs 45 | overrides['mode'] = kwargs.get('mode', 'predict') 46 | assert overrides['mode'] in ['track', 'predict'] 47 | overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python 48 | self.predictor = FastSAMPredictor(overrides=overrides) 49 | self.predictor.setup_model(model=self.model, verbose=False) 50 | try: 51 | return self.predictor(source, stream=stream) 52 | except Exception as e: 53 | return None 54 | 55 | def train(self, **kwargs): 56 | """Function trains models but raises an error as FastSAM models do not support training.""" 57 | raise NotImplementedError("Currently, the training codes are on the way.") 58 | 59 | def val(self, **kwargs): 60 | """Run validation given dataset.""" 61 | overrides = dict(task='segment', mode='val') 62 | overrides.update(kwargs) # prefer kwargs 63 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) 64 | args.imgsz = check_imgsz(args.imgsz, max_dim=1) 65 | validator = FastSAM(args=args) 66 | validator(model=self.model) 67 | self.metrics = validator.metrics 68 | return validator.metrics 69 | 70 | @smart_inference_mode() 71 | def export(self, **kwargs): 72 | """ 73 | Export model. 74 | 75 | Args: 76 | **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs 77 | """ 78 | overrides = dict(task='detect') 79 | overrides.update(kwargs) 80 | overrides['mode'] = 'export' 81 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) 82 | args.task = self.task 83 | if args.imgsz == DEFAULT_CFG.imgsz: 84 | args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed 85 | if args.batch == DEFAULT_CFG.batch: 86 | args.batch = 1 # default to 1 if not modified 87 | return Exporter(overrides=args)(model=self.model) 88 | 89 | def info(self, detailed=False, verbose=True): 90 | """ 91 | Logs model info. 92 | 93 | Args: 94 | detailed (bool): Show detailed information about model. 
95 | verbose (bool): Controls verbosity. 96 | """ 97 | return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640) 98 | 99 | def __call__(self, source=None, stream=False, **kwargs): 100 | """Calls the 'predict' function with given arguments to perform object detection.""" 101 | return self.predict(source, stream, **kwargs) 102 | 103 | def __getattr__(self, attr): 104 | """Raises error if object has no requested attribute.""" 105 | name = self.__class__.__name__ 106 | raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") 107 | -------------------------------------------------------------------------------- /fastsam/predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ultralytics.yolo.engine.results import Results 4 | from ultralytics.yolo.utils import DEFAULT_CFG, ops 5 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor 6 | from .utils import bbox_iou 7 | 8 | class FastSAMPredictor(DetectionPredictor): 9 | 10 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 11 | super().__init__(cfg, overrides, _callbacks) 12 | self.args.task = 'segment' 13 | 14 | def postprocess(self, preds, img, orig_imgs): 15 | """TODO: filter by classes.""" 16 | p = ops.non_max_suppression(preds[0], 17 | self.args.conf, 18 | self.args.iou, 19 | agnostic=self.args.agnostic_nms, 20 | max_det=self.args.max_det, 21 | nc=len(self.model.names), 22 | classes=self.args.classes) 23 | 24 | results = [] 25 | if len(p) == 0 or len(p[0]) == 0: 26 | print("No object detected.") 27 | return results 28 | 29 | full_box = torch.zeros_like(p[0][0]) 30 | full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0 31 | full_box = full_box.view(1, -1) 32 | critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:]) 33 | if critical_iou_index.numel() != 0: 34 | full_box[0][4] = p[0][critical_iou_index][:,4] 35 | full_box[0][6:] = p[0][critical_iou_index][:,6:] 36 | p[0][critical_iou_index] = full_box 37 | 38 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported 39 | for i, pred in enumerate(p): 40 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 41 | path = self.batch[0] 42 | img_path = path[i] if isinstance(path, list) else path 43 | if not len(pred): # save empty boxes 44 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])) 45 | continue 46 | if self.args.retina_masks: 47 | if not isinstance(orig_imgs, torch.Tensor): 48 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 49 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC 50 | else: 51 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC 52 | if not isinstance(orig_imgs, torch.Tensor): 53 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 54 | results.append( 55 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) 56 | return results 57 | -------------------------------------------------------------------------------- /fastsam/prompt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 
| import torch 7 | from .utils import image_to_np_ndarray 8 | from PIL import Image 9 | 10 | try: 11 | import clip # for linear_assignment 12 | 13 | except (ImportError, AssertionError, AttributeError): 14 | from ultralytics.yolo.utils.checks import check_requirements 15 | 16 | check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source 17 | import clip 18 | 19 | 20 | class FastSAMPrompt: 21 | 22 | def __init__(self, image, results, device='cuda'): 23 | if isinstance(image, str) or isinstance(image, Image.Image): 24 | image = image_to_np_ndarray(image) 25 | self.device = device 26 | self.results = results 27 | self.img = image 28 | 29 | def _segment_image(self, image, bbox): 30 | if isinstance(image, Image.Image): 31 | image_array = np.array(image) 32 | else: 33 | image_array = image 34 | segmented_image_array = np.zeros_like(image_array) 35 | x1, y1, x2, y2 = bbox 36 | segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2] 37 | segmented_image = Image.fromarray(segmented_image_array) 38 | black_image = Image.new('RGB', image.size, (255, 255, 255)) 39 | # transparency_mask = np.zeros_like((), dtype=np.uint8) 40 | transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8) 41 | transparency_mask[y1:y2, x1:x2] = 255 42 | transparency_mask_image = Image.fromarray(transparency_mask, mode='L') 43 | black_image.paste(segmented_image, mask=transparency_mask_image) 44 | return black_image 45 | 46 | def _format_results(self, result, filter=0): 47 | annotations = [] 48 | n = len(result.masks.data) 49 | for i in range(n): 50 | annotation = {} 51 | mask = result.masks.data[i] == 1.0 52 | 53 | if torch.sum(mask) < filter: 54 | continue 55 | annotation['id'] = i 56 | annotation['segmentation'] = mask.cpu().numpy() 57 | annotation['bbox'] = result.boxes.data[i] 58 | annotation['score'] = result.boxes.conf[i] 59 | annotation['area'] = annotation['segmentation'].sum() 60 | annotations.append(annotation) 61 | return annotations 62 | 63 | def filter_masks(annotations): # filte the overlap mask 64 | annotations.sort(key=lambda x: x['area'], reverse=True) 65 | to_remove = set() 66 | for i in range(0, len(annotations)): 67 | a = annotations[i] 68 | for j in range(i + 1, len(annotations)): 69 | b = annotations[j] 70 | if i != j and j not in to_remove: 71 | # check if 72 | if b['area'] < a['area']: 73 | if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8: 74 | to_remove.add(j) 75 | 76 | return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove 77 | 78 | def _get_bbox_from_mask(self, mask): 79 | mask = mask.astype(np.uint8) 80 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 81 | x1, y1, w, h = cv2.boundingRect(contours[0]) 82 | x2, y2 = x1 + w, y1 + h 83 | if len(contours) > 1: 84 | for b in contours: 85 | x_t, y_t, w_t, h_t = cv2.boundingRect(b) 86 | # Merge multiple bounding boxes into one. 
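# When a mask breaks into several disconnected contours, the lines below grow the
# box to the union of all per-contour boxes, e.g. (10, 10, 20, 20) and (50, 60, 70, 80)
# merge into (10, 10, 70, 80).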
87 | x1 = min(x1, x_t) 88 | y1 = min(y1, y_t) 89 | x2 = max(x2, x_t + w_t) 90 | y2 = max(y2, y_t + h_t) 91 | h = y2 - y1 92 | w = x2 - x1 93 | return [x1, y1, x2, y2] 94 | 95 | def plot_to_result(self, 96 | annotations, 97 | bboxes=None, 98 | points=None, 99 | point_label=None, 100 | mask_random_color=True, 101 | better_quality=True, 102 | retina=False, 103 | withContours=True) -> np.ndarray: 104 | if isinstance(annotations[0], dict): 105 | annotations = [annotation['segmentation'] for annotation in annotations] 106 | image = self.img 107 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 108 | original_h = image.shape[0] 109 | original_w = image.shape[1] 110 | if sys.platform == "darwin": 111 | plt.switch_backend("TkAgg") 112 | plt.figure(figsize=(original_w / 100, original_h / 100)) 113 | # Add subplot with no margin. 114 | plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) 115 | plt.margins(0, 0) 116 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 117 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 118 | 119 | plt.imshow(image) 120 | if better_quality: 121 | if isinstance(annotations[0], torch.Tensor): 122 | annotations = np.array(annotations.cpu()) 123 | for i, mask in enumerate(annotations): 124 | mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)) 125 | annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)) 126 | if self.device == 'cpu': 127 | annotations = np.array(annotations) 128 | self.fast_show_mask( 129 | annotations, 130 | plt.gca(), 131 | random_color=mask_random_color, 132 | bboxes=bboxes, 133 | points=points, 134 | pointlabel=point_label, 135 | retinamask=retina, 136 | target_height=original_h, 137 | target_width=original_w, 138 | ) 139 | else: 140 | if isinstance(annotations[0], np.ndarray): 141 | annotations = torch.from_numpy(annotations) 142 | self.fast_show_mask_gpu( 143 | annotations, 144 | plt.gca(), 145 | random_color=mask_random_color, 146 | bboxes=bboxes, 147 | points=points, 148 | pointlabel=point_label, 149 | retinamask=retina, 150 | target_height=original_h, 151 | target_width=original_w, 152 | ) 153 | if isinstance(annotations, torch.Tensor): 154 | annotations = annotations.cpu().numpy() 155 | if withContours: 156 | contour_all = [] 157 | temp = np.zeros((original_h, original_w, 1)) 158 | for i, mask in enumerate(annotations): 159 | if type(mask) == dict: 160 | mask = mask['segmentation'] 161 | annotation = mask.astype(np.uint8) 162 | if not retina: 163 | annotation = cv2.resize( 164 | annotation, 165 | (original_w, original_h), 166 | interpolation=cv2.INTER_NEAREST, 167 | ) 168 | contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 169 | for contour in contours: 170 | contour_all.append(contour) 171 | cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2) 172 | color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8]) 173 | contour_mask = temp / 255 * color.reshape(1, 1, -1) 174 | plt.imshow(contour_mask) 175 | 176 | plt.axis('off') 177 | fig = plt.gcf() 178 | plt.draw() 179 | 180 | try: 181 | buf = fig.canvas.tostring_rgb() 182 | except AttributeError: 183 | fig.canvas.draw() 184 | buf = fig.canvas.tostring_rgb() 185 | cols, rows = fig.canvas.get_width_height() 186 | img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3) 187 | result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR) 188 | plt.close() 189 | return result 190 | 191 | # Remark for refactoring: IMO a function should do one thing 
only, storing the image and plotting should be seperated and do not necessarily need to be class functions but standalone utility functions that the user can chain in his scripts to have more fine-grained control. 192 | def plot(self, 193 | annotations, 194 | output_path, 195 | bboxes=None, 196 | points=None, 197 | point_label=None, 198 | mask_random_color=True, 199 | better_quality=True, 200 | retina=False, 201 | withContours=True): 202 | if len(annotations) == 0: 203 | return None 204 | result = self.plot_to_result( 205 | annotations, 206 | bboxes, 207 | points, 208 | point_label, 209 | mask_random_color, 210 | better_quality, 211 | retina, 212 | withContours, 213 | ) 214 | 215 | path = os.path.dirname(os.path.abspath(output_path)) 216 | if not os.path.exists(path): 217 | os.makedirs(path) 218 | result = result[:, :, ::-1] 219 | cv2.imwrite(output_path, result) 220 | 221 | # CPU post process 222 | def fast_show_mask( 223 | self, 224 | annotation, 225 | ax, 226 | random_color=False, 227 | bboxes=None, 228 | points=None, 229 | pointlabel=None, 230 | retinamask=True, 231 | target_height=960, 232 | target_width=960, 233 | ): 234 | msak_sum = annotation.shape[0] 235 | height = annotation.shape[1] 236 | weight = annotation.shape[2] 237 | #Sort annotations based on area. 238 | areas = np.sum(annotation, axis=(1, 2)) 239 | sorted_indices = np.argsort(areas) 240 | annotation = annotation[sorted_indices] 241 | 242 | index = (annotation != 0).argmax(axis=0) 243 | if random_color: 244 | color = np.random.random((msak_sum, 1, 1, 3)) 245 | else: 246 | color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255]) 247 | transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6 248 | visual = np.concatenate([color, transparency], axis=-1) 249 | mask_image = np.expand_dims(annotation, -1) * visual 250 | 251 | show = np.zeros((height, weight, 4)) 252 | h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij') 253 | indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) 254 | # Use vectorized indexing to update the values of 'show'. 255 | show[h_indices, w_indices, :] = mask_image[indices] 256 | if bboxes is not None: 257 | for bbox in bboxes: 258 | x1, y1, x2, y2 = bbox 259 | ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1)) 260 | # draw point 261 | if points is not None: 262 | plt.scatter( 263 | [point[0] for i, point in enumerate(points) if pointlabel[i] == 1], 264 | [point[1] for i, point in enumerate(points) if pointlabel[i] == 1], 265 | s=20, 266 | c='y', 267 | ) 268 | plt.scatter( 269 | [point[0] for i, point in enumerate(points) if pointlabel[i] == 0], 270 | [point[1] for i, point in enumerate(points) if pointlabel[i] == 0], 271 | s=20, 272 | c='m', 273 | ) 274 | 275 | if not retinamask: 276 | show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST) 277 | ax.imshow(show) 278 | 279 | def fast_show_mask_gpu( 280 | self, 281 | annotation, 282 | ax, 283 | random_color=False, 284 | bboxes=None, 285 | points=None, 286 | pointlabel=None, 287 | retinamask=True, 288 | target_height=960, 289 | target_width=960, 290 | ): 291 | msak_sum = annotation.shape[0] 292 | height = annotation.shape[1] 293 | weight = annotation.shape[2] 294 | areas = torch.sum(annotation, dim=(1, 2)) 295 | sorted_indices = torch.argsort(areas, descending=False) 296 | annotation = annotation[sorted_indices] 297 | # Find the index of the first non-zero value at each position. 
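# The masks were sorted by ascending area above, so taking argmax over the mask
# dimension picks, for each pixel, the first (smallest-area) mask covering it;
# smaller objects are therefore drawn on top of larger ones. Pixels covered by
# no mask fall back to index 0, where mask_image is zero and stays transparent.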
298 | index = (annotation != 0).to(torch.long).argmax(dim=0) 299 | if random_color: 300 | color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device) 301 | else: 302 | color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([ 303 | 30 / 255, 144 / 255, 255 / 255]).to(annotation.device) 304 | transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6 305 | visual = torch.cat([color, transparency], dim=-1) 306 | mask_image = torch.unsqueeze(annotation, -1) * visual 307 | # Select data according to the index. The index indicates which batch's data to choose at each position, converting the mask_image into a single batch form. 308 | show = torch.zeros((height, weight, 4)).to(annotation.device) 309 | try: 310 | h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij') 311 | except: 312 | h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight)) 313 | indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) 314 | # Use vectorized indexing to update the values of 'show'. 315 | show[h_indices, w_indices, :] = mask_image[indices] 316 | show_cpu = show.cpu().numpy() 317 | if bboxes is not None: 318 | for bbox in bboxes: 319 | x1, y1, x2, y2 = bbox 320 | ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1)) 321 | # draw point 322 | if points is not None: 323 | plt.scatter( 324 | [point[0] for i, point in enumerate(points) if pointlabel[i] == 1], 325 | [point[1] for i, point in enumerate(points) if pointlabel[i] == 1], 326 | s=20, 327 | c='y', 328 | ) 329 | plt.scatter( 330 | [point[0] for i, point in enumerate(points) if pointlabel[i] == 0], 331 | [point[1] for i, point in enumerate(points) if pointlabel[i] == 0], 332 | s=20, 333 | c='m', 334 | ) 335 | if not retinamask: 336 | show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST) 337 | ax.imshow(show_cpu) 338 | 339 | # clip 340 | @torch.no_grad() 341 | def retrieve(self, model, preprocess, elements, search_text: str, device) -> int: 342 | preprocessed_images = [preprocess(image).to(device) for image in elements] 343 | tokenized_text = clip.tokenize([search_text]).to(device) 344 | stacked_images = torch.stack(preprocessed_images) 345 | image_features = model.encode_image(stacked_images) 346 | text_features = model.encode_text(tokenized_text) 347 | image_features /= image_features.norm(dim=-1, keepdim=True) 348 | text_features /= text_features.norm(dim=-1, keepdim=True) 349 | probs = 100.0 * image_features @ text_features.T 350 | return probs[:, 0].softmax(dim=0) 351 | 352 | def _crop_image(self, format_results): 353 | 354 | image = Image.fromarray(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB)) 355 | ori_w, ori_h = image.size 356 | annotations = format_results 357 | mask_h, mask_w = annotations[0]['segmentation'].shape 358 | if ori_w != mask_w or ori_h != mask_h: 359 | image = image.resize((mask_w, mask_h)) 360 | cropped_boxes = [] 361 | cropped_images = [] 362 | not_crop = [] 363 | filter_id = [] 364 | # annotations, _ = filter_masks(annotations) 365 | # filter_id = list(_) 366 | for _, mask in enumerate(annotations): 367 | if np.sum(mask['segmentation']) <= 100: 368 | filter_id.append(_) 369 | continue 370 | bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox 371 | cropped_boxes.append(self._segment_image(image, bbox)) 372 | # cropped_boxes.append(segment_image(image,mask["segmentation"])) 373 | cropped_images.append(bbox) # Save the 
bounding box of the cropped image. 374 | 375 | return cropped_boxes, cropped_images, not_crop, filter_id, annotations 376 | 377 | def box_prompt(self, bbox=None, bboxes=None): 378 | if self.results == None: 379 | return [] 380 | assert bbox or bboxes 381 | if bboxes is None: 382 | bboxes = [bbox] 383 | max_iou_index = [] 384 | for bbox in bboxes: 385 | assert (bbox[2] != 0 and bbox[3] != 0) 386 | masks = self.results[0].masks.data 387 | target_height = self.img.shape[0] 388 | target_width = self.img.shape[1] 389 | h = masks.shape[1] 390 | w = masks.shape[2] 391 | if h != target_height or w != target_width: 392 | bbox = [ 393 | int(bbox[0] * w / target_width), 394 | int(bbox[1] * h / target_height), 395 | int(bbox[2] * w / target_width), 396 | int(bbox[3] * h / target_height), ] 397 | bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0 398 | bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0 399 | bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w 400 | bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h 401 | 402 | # IoUs = torch.zeros(len(masks), dtype=torch.float32) 403 | bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) 404 | 405 | masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2)) 406 | orig_masks_area = torch.sum(masks, dim=(1, 2)) 407 | 408 | union = bbox_area + orig_masks_area - masks_area 409 | IoUs = masks_area / union 410 | max_iou_index.append(int(torch.argmax(IoUs))) 411 | max_iou_index = list(set(max_iou_index)) 412 | return np.array(masks[max_iou_index].cpu().numpy()) 413 | 414 | def point_prompt(self, points, pointlabel): # numpy 415 | if self.results == None: 416 | return [] 417 | masks = self._format_results(self.results[0], 0) 418 | target_height = self.img.shape[0] 419 | target_width = self.img.shape[1] 420 | h = masks[0]['segmentation'].shape[0] 421 | w = masks[0]['segmentation'].shape[1] 422 | if h != target_height or w != target_width: 423 | points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points] 424 | onemask = np.zeros((h, w)) 425 | masks = sorted(masks, key=lambda x: x['area'], reverse=True) 426 | for i, annotation in enumerate(masks): 427 | if type(annotation) == dict: 428 | mask = annotation['segmentation'] 429 | else: 430 | mask = annotation 431 | for i, point in enumerate(points): 432 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 1: 433 | onemask[mask] = 1 434 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 0: 435 | onemask[mask] = 0 436 | onemask = onemask >= 1 437 | return np.array([onemask]) 438 | 439 | def text_prompt(self, text): 440 | if self.results == None: 441 | return [] 442 | format_results = self._format_results(self.results[0], 0) 443 | cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results) 444 | clip_model, preprocess = clip.load('ViT-B/32', device=self.device) 445 | scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device) 446 | max_idx = scores.argsort() 447 | max_idx = max_idx[-1] 448 | max_idx += sum(np.array(filter_id) <= int(max_idx)) 449 | return np.array([annotations[max_idx]['segmentation']]) 450 | 451 | def everything_prompt(self): 452 | if self.results == None: 453 | return [] 454 | return self.results[0].masks.data 455 | 456 | -------------------------------------------------------------------------------- /fastsam/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL 
import Image 4 | 5 | 6 | def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20): 7 | '''Adjust bounding boxes to stick to image border if they are within a certain threshold. 8 | Args: 9 | boxes: (n, 4) 10 | image_shape: (height, width) 11 | threshold: pixel threshold 12 | Returns: 13 | adjusted_boxes: adjusted bounding boxes 14 | ''' 15 | 16 | # Image dimensions 17 | h, w = image_shape 18 | 19 | # Adjust boxes 20 | boxes[:, 0] = torch.where(boxes[:, 0] < threshold, torch.tensor( 21 | 0, dtype=torch.float, device=boxes.device), boxes[:, 0]) # x1 22 | boxes[:, 1] = torch.where(boxes[:, 1] < threshold, torch.tensor( 23 | 0, dtype=torch.float, device=boxes.device), boxes[:, 1]) # y1 24 | boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, torch.tensor( 25 | w, dtype=torch.float, device=boxes.device), boxes[:, 2]) # x2 26 | boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, torch.tensor( 27 | h, dtype=torch.float, device=boxes.device), boxes[:, 3]) # y2 28 | 29 | return boxes 30 | 31 | 32 | 33 | def convert_box_xywh_to_xyxy(box): 34 | x1 = box[0] 35 | y1 = box[1] 36 | x2 = box[0] + box[2] 37 | y2 = box[1] + box[3] 38 | return [x1, y1, x2, y2] 39 | 40 | 41 | def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False): 42 | '''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes. 43 | Args: 44 | box1: (4, ) 45 | boxes: (n, 4) 46 | Returns: 47 | high_iou_indices: Indices of boxes with IoU > thres 48 | ''' 49 | boxes = adjust_bboxes_to_image_border(boxes, image_shape) 50 | # obtain coordinates for intersections 51 | x1 = torch.max(box1[0], boxes[:, 0]) 52 | y1 = torch.max(box1[1], boxes[:, 1]) 53 | x2 = torch.min(box1[2], boxes[:, 2]) 54 | y2 = torch.min(box1[3], boxes[:, 3]) 55 | 56 | # compute the area of intersection 57 | intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0) 58 | 59 | # compute the area of both individual boxes 60 | box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) 61 | box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 62 | 63 | # compute the area of union 64 | union = box1_area + box2_area - intersection 65 | 66 | # compute the IoU 67 | iou = intersection / union # Should be shape (n, ) 68 | if raw_output: 69 | if iou.numel() == 0: 70 | return 0 71 | return iou 72 | 73 | # get indices of boxes with IoU > thres 74 | high_iou_indices = torch.nonzero(iou > iou_thres).flatten() 75 | 76 | return high_iou_indices 77 | 78 | 79 | def image_to_np_ndarray(image): 80 | if type(image) is str: 81 | return np.array(Image.open(image)) 82 | elif issubclass(type(image), Image.Image): 83 | return np.array(image) 84 | elif type(image) is np.ndarray: 85 | return image 86 | return None 87 | -------------------------------------------------------------------------------- /images/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/images/cat.jpg -------------------------------------------------------------------------------- /images/dogs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/images/dogs.jpg -------------------------------------------------------------------------------- /output/cat.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/output/cat.jpg -------------------------------------------------------------------------------- /output/dogs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/output/dogs.jpg -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://github.com/replicate/cog/blob/main/docs/python.md 3 | # Thanks for chenxwh. 4 | 5 | import argparse 6 | import cv2 7 | import shutil 8 | import ast 9 | from cog import BasePredictor, Input, Path 10 | from ultralytics import YOLO 11 | from utils.tools import * 12 | 13 | 14 | class Predictor(BasePredictor): 15 | def setup(self): 16 | """Load the model into memory to make running multiple predictions efficient""" 17 | self.models = {k: YOLO(f"{k}.pt") for k in ["FastSAM-s", "FastSAM-x"]} 18 | 19 | def predict( 20 | self, 21 | input_image: Path = Input(description="Input image"), 22 | model_name: str = Input( 23 | description="choose a model", 24 | choices=["FastSAM-x", "FastSAM-s"], 25 | default="FastSAM-x", 26 | ), 27 | iou: float = Input( 28 | description="iou threshold for filtering the annotations", default=0.7 29 | ), 30 | text_prompt: str = Input( 31 | description='use text prompt eg: "a black dog"', default=None 32 | ), 33 | conf: float = Input(description="object confidence threshold", default=0.25), 34 | retina: bool = Input( 35 | description="draw high-resolution segmentation masks", default=True 36 | ), 37 | box_prompt: str = Input(default="[0,0,0,0]", description="[x,y,w,h]"), 38 | point_prompt: str = Input(default="[[0,0]]", description="[[x1,y1],[x2,y2]]"), 39 | point_label: str = Input(default="[0]", description="[1,0] 0:background, 1:foreground"), 40 | withContours: bool = Input( 41 | description="draw the edges of the masks", default=False 42 | ), 43 | better_quality: bool = Input( 44 | description="better quality using morphologyEx", default=False 45 | ), 46 | ) -> Path: 47 | """Run a single prediction on the model""" 48 | 49 | # default params 50 | 51 | out_path = "output" 52 | if os.path.exists(out_path): 53 | shutil.rmtree(out_path) 54 | os.makedirs(out_path, exist_ok=True) 55 | 56 | device = torch.device( 57 | "cuda" 58 | if torch.cuda.is_available() 59 | else "mps" 60 | if torch.backends.mps.is_available() 61 | else "cpu" 62 | ) 63 | 64 | args = argparse.Namespace( 65 | better_quality=better_quality, 66 | box_prompt=box_prompt, 67 | conf=conf, 68 | device=device, 69 | img_path=str(input_image), 70 | imgsz=1024, 71 | iou=iou, 72 | model_path="FastSAM-x.pt", 73 | output=out_path, 74 | point_label=point_label, 75 | point_prompt=point_prompt, 76 | randomcolor=True, 77 | retina=retina, 78 | text_prompt=text_prompt, 79 | withContours=withContours, 80 | ) 81 | args.point_prompt = ast.literal_eval(args.point_prompt) 82 | args.box_prompt = ast.literal_eval(args.box_prompt) 83 | args.point_label = ast.literal_eval(args.point_label) 84 | 85 | model = self.models[model_name] 86 | 87 | results = model( 88 | str(input_image), 89 | imgsz=args.imgsz, 90 | device=args.device, 91 | retina_masks=args.retina, 92 | iou=args.iou, 93 | conf=args.conf, 94 | max_det=100, 95 | ) 96 | 97 | if args.box_prompt[2] != 0 and args.box_prompt[3] != 0: 98 | annotations = prompt(results, args, 
box=True) 99 | annotations = np.array([annotations]) 100 | fast_process( 101 | annotations=annotations, 102 | args=args, 103 | mask_random_color=args.randomcolor, 104 | bbox=convert_box_xywh_to_xyxy(args.box_prompt), 105 | ) 106 | 107 | elif args.text_prompt != None: 108 | results = format_results(results[0], 0) 109 | annotations = prompt(results, args, text=True) 110 | annotations = np.array([annotations]) 111 | fast_process( 112 | annotations=annotations, args=args, mask_random_color=args.randomcolor 113 | ) 114 | 115 | elif args.point_prompt[0] != [0, 0]: 116 | results = format_results(results[0], 0) 117 | annotations = prompt(results, args, point=True) 118 | # list to numpy 119 | annotations = np.array([annotations]) 120 | fast_process( 121 | annotations=annotations, 122 | args=args, 123 | mask_random_color=args.randomcolor, 124 | points=args.point_prompt, 125 | ) 126 | 127 | else: 128 | fast_process( 129 | annotations=results[0].masks.data, 130 | args=args, 131 | mask_random_color=args.randomcolor, 132 | ) 133 | 134 | out = "/tmp.out.png" 135 | shutil.copy(os.path.join(out_path, os.listdir(out_path)[0]), out) 136 | 137 | return Path(out) 138 | 139 | 140 | def prompt(results, args, box=None, point=None, text=None): 141 | ori_img = cv2.imread(args.img_path) 142 | ori_h = ori_img.shape[0] 143 | ori_w = ori_img.shape[1] 144 | if box: 145 | mask, idx = box_prompt( 146 | results[0].masks.data, 147 | convert_box_xywh_to_xyxy(args.box_prompt), 148 | ori_h, 149 | ori_w, 150 | ) 151 | elif point: 152 | mask, idx = point_prompt( 153 | results, args.point_prompt, args.point_label, ori_h, ori_w 154 | ) 155 | elif text: 156 | mask, idx = text_prompt(results, args.text_prompt, args.img_path, args.device) 157 | else: 158 | return None 159 | return mask 160 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Base----------------------------------- 2 | matplotlib>=3.2.2 3 | opencv-python>=4.6.0 4 | Pillow>=7.1.2 5 | PyYAML>=5.3.1 6 | requests>=2.23.0 7 | scipy>=1.4.1 8 | torch>=1.7.0 9 | torchvision>=0.8.1 10 | tqdm>=4.64.0 11 | 12 | pandas>=1.1.4 13 | seaborn>=0.11.0 14 | 15 | gradio<=3.35.2 16 | 17 | # Ultralytics----------------------------------- 18 | ultralytics<=8.0.122 19 | 20 | openai-clip -------------------------------------------------------------------------------- /segpredict.py: -------------------------------------------------------------------------------- 1 | from fastsam import FastSAM, FastSAMPrompt 2 | import torch 3 | 4 | model = FastSAM('FastSAM.pt') 5 | IMAGE_PATH = './images/dogs.jpg' 6 | DEVICE = torch.device( 7 | "cuda" 8 | if torch.cuda.is_available() 9 | else "mps" 10 | if torch.backends.mps.is_available() 11 | else "cpu" 12 | ) 13 | everything_results = model( 14 | IMAGE_PATH, 15 | device=DEVICE, 16 | retina_masks=True, 17 | imgsz=1024, 18 | conf=0.4, 19 | iou=0.9, 20 | ) 21 | prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE) 22 | 23 | # # everything prompt 24 | ann = prompt_process.everything_prompt() 25 | 26 | # # bbox prompt 27 | # # bbox default shape [0,0,0,0] -> [x1,y1,x2,y2] 28 | # bboxes default shape [[0,0,0,0]] -> [[x1,y1,x2,y2]] 29 | # ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300]) 30 | # ann = prompt_process.box_prompt(bboxes=[[200, 200, 300, 300], [500, 500, 600, 600]]) 31 | 32 | # # text prompt 33 | # ann = prompt_process.text_prompt(text='a photo of a dog') 34 | 35 | # # point prompt 
36 | # # points default [[0,0]] [[x1,y1],[x2,y2]] 37 | # # point_label default [0] [1,0] 0:background, 1:foreground 38 | # ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1]) 39 | 40 | # point prompt 41 | # points default [[0,0]] [[x1,y1],[x2,y2]] 42 | # point_label default [0] [1,0] 0:background, 1:foreground 43 | ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1]) 44 | 45 | prompt_process.plot( 46 | annotations=ann, 47 | output='./output/', 48 | mask_random_color=True, 49 | better_quality=True, 50 | retina=False, 51 | withContours=True, 52 | ) 53 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """The setup script.""" 4 | 5 | import io 6 | from os import path as op 7 | from setuptools import setup, find_packages 8 | 9 | with open("README.md", encoding="utf-8") as readme_file: 10 | readme = readme_file.read() 11 | 12 | here = op.abspath(op.dirname(__file__)) 13 | 14 | # get the dependencies and installs 15 | with io.open(op.join(here, "requirements.txt"), encoding="utf-8") as f: 16 | all_reqs = f.read().split("\n") 17 | 18 | install_requires = [x.strip() for x in all_reqs if "git+" not in x] 19 | dependency_links = [x.strip().replace("git+", "") for x in all_reqs if "git+" not in x] 20 | 21 | extras_requires = {} 22 | 23 | 24 | requirements = [] 25 | 26 | setup_requirements = [] 27 | 28 | test_requirements = [] 29 | 30 | setup( 31 | author="Qiusheng Wu", 32 | author_email="giswqs@gmail.com", 33 | python_requires=">=3.8", 34 | classifiers=[ 35 | "Intended Audience :: Developers", 36 | "License :: OSI Approved :: Apache Software License", 37 | "Natural Language :: English", 38 | "Programming Language :: Python :: 3", 39 | "Programming Language :: Python :: 3.8", 40 | "Programming Language :: Python :: 3.9", 41 | "Programming Language :: Python :: 3.10", 42 | "Programming Language :: Python :: 3.11", 43 | ], 44 | description="Fast Segment Anything", 45 | install_requires=install_requires, 46 | extras_require=extras_requires, 47 | dependency_links=dependency_links, 48 | license="Apache Software License", 49 | long_description=readme, 50 | long_description_content_type="text/markdown", 51 | include_package_data=True, 52 | keywords="segment-anything", 53 | name="segment-anything-fast", 54 | version="0.1.2", 55 | packages=find_packages(include=["fastsam", "fastsam.*"]), 56 | setup_requires=setup_requirements, 57 | url="https://github.com/opengeos/FastSAM", 58 | zip_safe=False, 59 | ) 60 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/utils/__init__.py -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import matplotlib.pyplot as plt 4 | import cv2 5 | import torch 6 | import os 7 | import sys 8 | import clip 9 | 10 | 11 | def convert_box_xywh_to_xyxy(box): 12 | if len(box) == 4: 13 | return [box[0], box[1], box[0] + box[2], box[1] + box[3]] 14 | else: 15 | result = [] 16 | for b in box: 17 | b = convert_box_xywh_to_xyxy(b) 18 | result.append(b) 19 | return result 20 | 21 | 22 | def 
segment_image(image, bbox): 23 | image_array = np.array(image) 24 | segmented_image_array = np.zeros_like(image_array) 25 | x1, y1, x2, y2 = bbox 26 | segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2] 27 | segmented_image = Image.fromarray(segmented_image_array) 28 | black_image = Image.new("RGB", image.size, (255, 255, 255)) 29 | # transparency_mask = np.zeros_like((), dtype=np.uint8) 30 | transparency_mask = np.zeros( 31 | (image_array.shape[0], image_array.shape[1]), dtype=np.uint8 32 | ) 33 | transparency_mask[y1:y2, x1:x2] = 255 34 | transparency_mask_image = Image.fromarray(transparency_mask, mode="L") 35 | black_image.paste(segmented_image, mask=transparency_mask_image) 36 | return black_image 37 | 38 | 39 | def format_results(result, filter=0): 40 | annotations = [] 41 | n = len(result.masks.data) 42 | for i in range(n): 43 | annotation = {} 44 | mask = result.masks.data[i] == 1.0 45 | 46 | if torch.sum(mask) < filter: 47 | continue 48 | annotation["id"] = i 49 | annotation["segmentation"] = mask.cpu().numpy() 50 | annotation["bbox"] = result.boxes.data[i] 51 | annotation["score"] = result.boxes.conf[i] 52 | annotation["area"] = annotation["segmentation"].sum() 53 | annotations.append(annotation) 54 | return annotations 55 | 56 | 57 | def filter_masks(annotations): # filter the overlap mask 58 | annotations.sort(key=lambda x: x["area"], reverse=True) 59 | to_remove = set() 60 | for i in range(0, len(annotations)): 61 | a = annotations[i] 62 | for j in range(i + 1, len(annotations)): 63 | b = annotations[j] 64 | if i != j and j not in to_remove: 65 | # check if 66 | if b["area"] < a["area"]: 67 | if (a["segmentation"] & b["segmentation"]).sum() / b[ 68 | "segmentation" 69 | ].sum() > 0.8: 70 | to_remove.add(j) 71 | 72 | return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove 73 | 74 | 75 | def get_bbox_from_mask(mask): 76 | mask = mask.astype(np.uint8) 77 | contours, hierarchy = cv2.findContours( 78 | mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE 79 | ) 80 | x1, y1, w, h = cv2.boundingRect(contours[0]) 81 | x2, y2 = x1 + w, y1 + h 82 | if len(contours) > 1: 83 | for b in contours: 84 | x_t, y_t, w_t, h_t = cv2.boundingRect(b) 85 | # 将多个bbox合并成一个 86 | x1 = min(x1, x_t) 87 | y1 = min(y1, y_t) 88 | x2 = max(x2, x_t + w_t) 89 | y2 = max(y2, y_t + h_t) 90 | h = y2 - y1 91 | w = x2 - x1 92 | return [x1, y1, x2, y2] 93 | 94 | 95 | def fast_process( 96 | annotations, args, mask_random_color, bbox=None, points=None, edges=False 97 | ): 98 | if isinstance(annotations[0], dict): 99 | annotations = [annotation["segmentation"] for annotation in annotations] 100 | result_name = os.path.basename(args.img_path) 101 | image = cv2.imread(args.img_path) 102 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 103 | original_h = image.shape[0] 104 | original_w = image.shape[1] 105 | if sys.platform == "darwin": 106 | plt.switch_backend("TkAgg") 107 | plt.figure(figsize=(original_w/100, original_h/100)) 108 | # Add subplot with no margin. 
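# With figsize = (original_w / 100, original_h / 100) and matplotlib's default
# dpi of 100, removing every margin below keeps the rendered canvas at roughly
# the input image's pixel size, so the buffer grabbed via tostring_rgb() further
# down matches the dimensions reported by fig.canvas.get_width_height().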
109 | plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) 110 | plt.margins(0, 0) 111 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 112 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 113 | plt.imshow(image) 114 | if args.better_quality == True: 115 | if isinstance(annotations[0], torch.Tensor): 116 | annotations = np.array(annotations.cpu()) 117 | for i, mask in enumerate(annotations): 118 | mask = cv2.morphologyEx( 119 | mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8) 120 | ) 121 | annotations[i] = cv2.morphologyEx( 122 | mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8) 123 | ) 124 | if args.device == "cpu": 125 | annotations = np.array(annotations) 126 | fast_show_mask( 127 | annotations, 128 | plt.gca(), 129 | random_color=mask_random_color, 130 | bbox=bbox, 131 | points=points, 132 | point_label=args.point_label, 133 | retinamask=args.retina, 134 | target_height=original_h, 135 | target_width=original_w, 136 | ) 137 | else: 138 | if isinstance(annotations[0], np.ndarray): 139 | annotations = torch.from_numpy(annotations) 140 | fast_show_mask_gpu( 141 | annotations, 142 | plt.gca(), 143 | random_color=args.randomcolor, 144 | bbox=bbox, 145 | points=points, 146 | point_label=args.point_label, 147 | retinamask=args.retina, 148 | target_height=original_h, 149 | target_width=original_w, 150 | ) 151 | if isinstance(annotations, torch.Tensor): 152 | annotations = annotations.cpu().numpy() 153 | if args.withContours == True: 154 | contour_all = [] 155 | temp = np.zeros((original_h, original_w, 1)) 156 | for i, mask in enumerate(annotations): 157 | if type(mask) == dict: 158 | mask = mask["segmentation"] 159 | annotation = mask.astype(np.uint8) 160 | if args.retina == False: 161 | annotation = cv2.resize( 162 | annotation, 163 | (original_w, original_h), 164 | interpolation=cv2.INTER_NEAREST, 165 | ) 166 | contours, hierarchy = cv2.findContours( 167 | annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE 168 | ) 169 | for contour in contours: 170 | contour_all.append(contour) 171 | cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2) 172 | color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8]) 173 | contour_mask = temp / 255 * color.reshape(1, 1, -1) 174 | plt.imshow(contour_mask) 175 | 176 | save_path = args.output 177 | if not os.path.exists(save_path): 178 | os.makedirs(save_path) 179 | plt.axis("off") 180 | fig = plt.gcf() 181 | plt.draw() 182 | 183 | try: 184 | buf = fig.canvas.tostring_rgb() 185 | except AttributeError: 186 | fig.canvas.draw() 187 | buf = fig.canvas.tostring_rgb() 188 | 189 | cols, rows = fig.canvas.get_width_height() 190 | img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3) 191 | cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)) 192 | 193 | 194 | # CPU post process 195 | def fast_show_mask( 196 | annotation, 197 | ax, 198 | random_color=False, 199 | bbox=None, 200 | points=None, 201 | point_label=None, 202 | retinamask=True, 203 | target_height=960, 204 | target_width=960, 205 | ): 206 | msak_sum = annotation.shape[0] 207 | height = annotation.shape[1] 208 | weight = annotation.shape[2] 209 | # 将annotation 按照面积 排序 210 | areas = np.sum(annotation, axis=(1, 2)) 211 | sorted_indices = np.argsort(areas) 212 | annotation = annotation[sorted_indices] 213 | 214 | index = (annotation != 0).argmax(axis=0) 215 | if random_color == True: 216 | color = np.random.random((msak_sum, 1, 1, 3)) 217 | else: 218 | color = np.ones((msak_sum, 1, 1, 3)) * 
np.array( 219 | [30 / 255, 144 / 255, 255 / 255] 220 | ) 221 | transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6 222 | visual = np.concatenate([color, transparency], axis=-1) 223 | mask_image = np.expand_dims(annotation, -1) * visual 224 | 225 | show = np.zeros((height, weight, 4)) 226 | h_indices, w_indices = np.meshgrid( 227 | np.arange(height), np.arange(weight), indexing="ij" 228 | ) 229 | indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) 230 | # 使用向量化索引更新show的值 231 | show[h_indices, w_indices, :] = mask_image[indices] 232 | if bbox is not None: 233 | x1, y1, x2, y2 = bbox 234 | ax.add_patch( 235 | plt.Rectangle( 236 | (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1 237 | ) 238 | ) 239 | # draw point 240 | if points is not None: 241 | plt.scatter( 242 | [point[0] for i, point in enumerate(points) if point_label[i] == 1], 243 | [point[1] for i, point in enumerate(points) if point_label[i] == 1], 244 | s=20, 245 | c="y", 246 | ) 247 | plt.scatter( 248 | [point[0] for i, point in enumerate(points) if point_label[i] == 0], 249 | [point[1] for i, point in enumerate(points) if point_label[i] == 0], 250 | s=20, 251 | c="m", 252 | ) 253 | 254 | if retinamask == False: 255 | show = cv2.resize( 256 | show, (target_width, target_height), interpolation=cv2.INTER_NEAREST 257 | ) 258 | ax.imshow(show) 259 | 260 | 261 | def fast_show_mask_gpu( 262 | annotation, 263 | ax, 264 | random_color=False, 265 | bbox=None, 266 | points=None, 267 | point_label=None, 268 | retinamask=True, 269 | target_height=960, 270 | target_width=960, 271 | ): 272 | msak_sum = annotation.shape[0] 273 | height = annotation.shape[1] 274 | weight = annotation.shape[2] 275 | areas = torch.sum(annotation, dim=(1, 2)) 276 | sorted_indices = torch.argsort(areas, descending=False) 277 | annotation = annotation[sorted_indices] 278 | # 找每个位置第一个非零值下标 279 | index = (annotation != 0).to(torch.long).argmax(dim=0) 280 | if random_color == True: 281 | color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device) 282 | else: 283 | color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor( 284 | [30 / 255, 144 / 255, 255 / 255] 285 | ).to(annotation.device) 286 | transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6 287 | visual = torch.cat([color, transparency], dim=-1) 288 | mask_image = torch.unsqueeze(annotation, -1) * visual 289 | # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式 290 | show = torch.zeros((height, weight, 4)).to(annotation.device) 291 | h_indices, w_indices = torch.meshgrid( 292 | torch.arange(height), torch.arange(weight), indexing="ij" 293 | ) 294 | indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) 295 | # 使用向量化索引更新show的值 296 | show[h_indices, w_indices, :] = mask_image[indices] 297 | show_cpu = show.cpu().numpy() 298 | if bbox is not None: 299 | x1, y1, x2, y2 = bbox 300 | ax.add_patch( 301 | plt.Rectangle( 302 | (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1 303 | ) 304 | ) 305 | # draw point 306 | if points is not None: 307 | plt.scatter( 308 | [point[0] for i, point in enumerate(points) if point_label[i] == 1], 309 | [point[1] for i, point in enumerate(points) if point_label[i] == 1], 310 | s=20, 311 | c="y", 312 | ) 313 | plt.scatter( 314 | [point[0] for i, point in enumerate(points) if point_label[i] == 0], 315 | [point[1] for i, point in enumerate(points) if point_label[i] == 0], 316 | s=20, 317 | c="m", 318 | ) 319 | if retinamask == False: 320 | show_cpu = cv2.resize( 321 | 
show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST 322 | ) 323 | ax.imshow(show_cpu) 324 | 325 | 326 | # clip 327 | @torch.no_grad() 328 | def retriev( 329 | model, preprocess, elements: [Image.Image], search_text: str, device 330 | ): 331 | preprocessed_images = [preprocess(image).to(device) for image in elements] 332 | tokenized_text = clip.tokenize([search_text]).to(device) 333 | stacked_images = torch.stack(preprocessed_images) 334 | image_features = model.encode_image(stacked_images) 335 | text_features = model.encode_text(tokenized_text) 336 | image_features /= image_features.norm(dim=-1, keepdim=True) 337 | text_features /= text_features.norm(dim=-1, keepdim=True) 338 | probs = 100.0 * image_features @ text_features.T 339 | return probs[:, 0].softmax(dim=0) 340 | 341 | 342 | def crop_image(annotations, image_like): 343 | if isinstance(image_like, str): 344 | image = Image.open(image_like) 345 | else: 346 | image = image_like 347 | ori_w, ori_h = image.size 348 | mask_h, mask_w = annotations[0]["segmentation"].shape 349 | if ori_w != mask_w or ori_h != mask_h: 350 | image = image.resize((mask_w, mask_h)) 351 | cropped_boxes = [] 352 | cropped_images = [] 353 | not_crop = [] 354 | origin_id = [] 355 | for _, mask in enumerate(annotations): 356 | if np.sum(mask["segmentation"]) <= 100: 357 | continue 358 | origin_id.append(_) 359 | bbox = get_bbox_from_mask(mask["segmentation"]) # mask 的 bbox 360 | cropped_boxes.append(segment_image(image, bbox)) # 保存裁剪的图片 361 | # cropped_boxes.append(segment_image(image,mask["segmentation"])) 362 | cropped_images.append(bbox) # 保存裁剪的图片的bbox 363 | return cropped_boxes, cropped_images, not_crop, origin_id, annotations 364 | 365 | 366 | def box_prompt(masks, bbox, target_height, target_width): 367 | h = masks.shape[1] 368 | w = masks.shape[2] 369 | if h != target_height or w != target_width: 370 | bbox = [ 371 | int(bbox[0] * w / target_width), 372 | int(bbox[1] * h / target_height), 373 | int(bbox[2] * w / target_width), 374 | int(bbox[3] * h / target_height), 375 | ] 376 | bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0 377 | bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0 378 | bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w 379 | bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h 380 | 381 | # IoUs = torch.zeros(len(masks), dtype=torch.float32) 382 | bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) 383 | 384 | masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2)) 385 | orig_masks_area = torch.sum(masks, dim=(1, 2)) 386 | 387 | union = bbox_area + orig_masks_area - masks_area 388 | IoUs = masks_area / union 389 | max_iou_index = torch.argmax(IoUs) 390 | 391 | return masks[max_iou_index].cpu().numpy(), max_iou_index 392 | 393 | 394 | def point_prompt(masks, points, point_label, target_height, target_width): # numpy 处理 395 | h = masks[0]["segmentation"].shape[0] 396 | w = masks[0]["segmentation"].shape[1] 397 | if h != target_height or w != target_width: 398 | points = [ 399 | [int(point[0] * w / target_width), int(point[1] * h / target_height)] 400 | for point in points 401 | ] 402 | onemask = np.zeros((h, w)) 403 | masks = sorted(masks, key=lambda x: x['area'], reverse=True) 404 | for i, annotation in enumerate(masks): 405 | if type(annotation) == dict: 406 | mask = annotation['segmentation'] 407 | else: 408 | mask = annotation 409 | for i, point in enumerate(points): 410 | if mask[point[1], point[0]] == 1 and point_label[i] == 1: 411 | onemask[mask] = 1 412 | if 
mask[point[1], point[0]] == 1 and point_label[i] == 0: 413 | onemask[mask] = 0 414 | onemask = onemask >= 1 415 | return onemask, 0 416 | 417 | 418 | def text_prompt(annotations, text, img_path, device, wider=False, threshold=0.9): 419 | cropped_boxes, cropped_images, not_crop, origin_id, annotations_ = crop_image( 420 | annotations, img_path 421 | ) 422 | clip_model, preprocess = clip.load("ViT-B/32", device=device) 423 | scores = retriev( 424 | clip_model, preprocess, cropped_boxes, text, device=device 425 | ) 426 | max_idx = scores.argsort() 427 | max_idx = max_idx[-1] 428 | max_idx = origin_id[int(max_idx)] 429 | 430 | # find the biggest mask which contains the mask with max score 431 | if wider: 432 | mask0 = annotations_[max_idx]["segmentation"] 433 | area0 = np.sum(mask0) 434 | areas = [(i, np.sum(mask["segmentation"])) for i, mask in enumerate(annotations_) if i in origin_id] 435 | areas = sorted(areas, key=lambda area: area[1], reverse=True) 436 | indices = [area[0] for area in areas] 437 | for index in indices: 438 | if index == max_idx or np.sum(annotations_[index]["segmentation"] & mask0) / area0 > threshold: 439 | max_idx = index 440 | break 441 | 442 | return annotations_[max_idx]["segmentation"], max_idx 443 | -------------------------------------------------------------------------------- /utils/tools_gradio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import matplotlib.pyplot as plt 4 | import cv2 5 | import torch 6 | 7 | 8 | def fast_process( 9 | annotations, 10 | image, 11 | device, 12 | scale, 13 | better_quality=False, 14 | mask_random_color=True, 15 | bbox=None, 16 | use_retina=True, 17 | withContours=True, 18 | ): 19 | if isinstance(annotations[0], dict): 20 | annotations = [annotation['segmentation'] for annotation in annotations] 21 | 22 | original_h = image.height 23 | original_w = image.width 24 | if better_quality: 25 | if isinstance(annotations[0], torch.Tensor): 26 | annotations = np.array(annotations.cpu()) 27 | for i, mask in enumerate(annotations): 28 | mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)) 29 | annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)) 30 | if device == 'cpu': 31 | annotations = np.array(annotations) 32 | inner_mask = fast_show_mask( 33 | annotations, 34 | plt.gca(), 35 | random_color=mask_random_color, 36 | bbox=bbox, 37 | retinamask=use_retina, 38 | target_height=original_h, 39 | target_width=original_w, 40 | ) 41 | else: 42 | if isinstance(annotations[0], np.ndarray): 43 | annotations = torch.from_numpy(annotations) 44 | inner_mask = fast_show_mask_gpu( 45 | annotations, 46 | plt.gca(), 47 | random_color=mask_random_color, 48 | bbox=bbox, 49 | retinamask=use_retina, 50 | target_height=original_h, 51 | target_width=original_w, 52 | ) 53 | if isinstance(annotations, torch.Tensor): 54 | annotations = annotations.cpu().numpy() 55 | 56 | if withContours: 57 | contour_all = [] 58 | temp = np.zeros((original_h, original_w, 1)) 59 | for i, mask in enumerate(annotations): 60 | if type(mask) == dict: 61 | mask = mask['segmentation'] 62 | annotation = mask.astype(np.uint8) 63 | if use_retina == False: 64 | annotation = cv2.resize( 65 | annotation, 66 | (original_w, original_h), 67 | interpolation=cv2.INTER_NEAREST, 68 | ) 69 | contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 70 | for contour in contours: 71 | 
71 |                 contour_all.append(contour)
72 |         cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
73 |         color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
74 |         contour_mask = temp / 255 * color.reshape(1, 1, -1)
75 | 
76 |     image = image.convert('RGBA')
77 |     overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
78 |     image.paste(overlay_inner, (0, 0), overlay_inner)
79 | 
80 |     if withContours:
81 |         overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
82 |         image.paste(overlay_contour, (0, 0), overlay_contour)
83 | 
84 |     return image
85 | 
86 | 
87 | # CPU post-processing
88 | def fast_show_mask(
89 |     annotation,
90 |     ax,
91 |     random_color=False,
92 |     bbox=None,
93 |     retinamask=True,
94 |     target_height=960,
95 |     target_width=960,
96 | ):
97 |     mask_sum = annotation.shape[0]
98 |     height = annotation.shape[1]
99 |     width = annotation.shape[2]
100 |     # sort the annotations by mask area (ascending)
101 |     areas = np.sum(annotation, axis=(1, 2))
102 |     sorted_indices = np.argsort(areas)
103 |     annotation = annotation[sorted_indices]
104 | 
105 |     index = (annotation != 0).argmax(axis=0)
106 |     if random_color:
107 |         color = np.random.random((mask_sum, 1, 1, 3))
108 |     else:
109 |         color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
110 |     transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
111 |     visual = np.concatenate([color, transparency], axis=-1)
112 |     mask_image = np.expand_dims(annotation, -1) * visual
113 | 
114 |     mask = np.zeros((height, width, 4))
115 | 
116 |     h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
117 |     indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
118 | 
119 |     mask[h_indices, w_indices, :] = mask_image[indices]
120 |     if bbox is not None:
121 |         x1, y1, x2, y2 = bbox
122 |         ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
123 | 
124 |     if not retinamask:
125 |         mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
126 | 
127 |     return mask
128 | 
129 | 
130 | def fast_show_mask_gpu(
131 |     annotation,
132 |     ax,
133 |     random_color=False,
134 |     bbox=None,
135 |     retinamask=True,
136 |     target_height=960,
137 |     target_width=960,
138 | ):
139 |     device = annotation.device
140 |     mask_sum = annotation.shape[0]
141 |     height = annotation.shape[1]
142 |     width = annotation.shape[2]
143 |     areas = torch.sum(annotation, dim=(1, 2))
144 |     sorted_indices = torch.argsort(areas, descending=False)
145 |     annotation = annotation[sorted_indices]
146 |     # index of the first non-zero mask at each position
147 |     index = (annotation != 0).to(torch.long).argmax(dim=0)
148 |     if random_color:
149 |         color = torch.rand((mask_sum, 1, 1, 3)).to(device)
150 |     else:
151 |         color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
152 |             [30 / 255, 144 / 255, 255 / 255]
153 |         ).to(device)
154 |     transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
155 |     visual = torch.cat([color, transparency], dim=-1)
156 |     mask_image = torch.unsqueeze(annotation, -1) * visual
157 |     # gather by `index`: each pixel takes the value of the mask selected for it, collapsing mask_image into one image
158 |     mask = torch.zeros((height, width, 4)).to(device)
159 |     h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(width))
160 |     indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
161 |     # update `mask` with vectorized indexing
162 |     mask[h_indices, w_indices, :] = mask_image[indices]
163 |     mask_cpu = mask.cpu().numpy()
164 |     if bbox is not None:
165 |         x1, y1, x2, y2 = bbox
166 |         ax.add_patch(
167 |             plt.Rectangle(
168 |                 (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
169 |             )
170 |         )
171 |     if not retinamask:
172 |         mask_cpu = cv2.resize(
173 |             mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
174 |         )
175 |     return mask_cpu
176 | 
--------------------------------------------------------------------------------
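Usage sketch (appended for illustration, not part of the repository files above): a minimal, self-contained exercise of box_prompt and point_prompt from utils/tools.py on synthetic masks. The mask layout (an N x H x W tensor) and the annotation-dict format ({"segmentation": ..., "area": ...}) are inferred from the code above rather than taken from real FastSAM output, and the snippet assumes the repository root is on sys.path with the packages in requirements.txt (including CLIP, which utils/tools.py imports at module level) installed.

import numpy as np
import torch

from utils.tools import box_prompt, point_prompt

H, W = 64, 64

# Two synthetic candidate masks: a small square and a large square.
masks = torch.zeros((2, H, W), dtype=torch.float32)
masks[0, 10:20, 10:20] = 1  # small object
masks[1, 5:60, 5:60] = 1    # large object

# box_prompt returns the candidate with the highest IoU against the box.
best_mask, best_idx = box_prompt(
    masks, bbox=[8, 8, 22, 22], target_height=H, target_width=W
)
print("box_prompt chose mask", int(best_idx))  # expected: 0 (the small square)

# point_prompt takes annotation dicts and returns the union of every mask
# that contains a foreground point (label 1).
annotations = [
    {"segmentation": masks[i].numpy().astype(bool), "area": int(masks[i].sum())}
    for i in range(2)
]
onemask, _ = point_prompt(
    annotations, points=[[15, 15]], point_label=[1], target_height=H, target_width=W
)
print("selected pixels:", int(onemask.sum()))  # both squares contain the point, so this equals the larger area (3025)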
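A similar sketch for the Gradio helpers: it feeds two synthetic boolean masks through fast_process from utils/tools_gradio.py and saves the resulting RGBA overlay. The annotation format and argument values here are assumptions based on the signature above; in the Gradio app the annotations come from the FastSAM model and scale is set by the app.

import numpy as np
from PIL import Image

from utils.tools_gradio import fast_process

H, W = 96, 128
image = Image.new("RGB", (W, H), (200, 200, 200))  # stand-in for a real photo

# Two synthetic boolean masks, wrapped in the dict form fast_process accepts.
masks = np.zeros((2, H, W), dtype=bool)
masks[0, 10:40, 10:50] = True
masks[1, 50:90, 60:120] = True
annotations = [{"segmentation": m} for m in masks]

overlay = fast_process(
    annotations,
    image,
    device="cpu",  # CPU path -> fast_show_mask
    scale=1,       # contour thickness is 2 // scale
    better_quality=False,
    mask_random_color=True,
    use_retina=True,
    withContours=True,
)
overlay.save("overlay_demo.png")  # RGBA image: tinted masks plus blue contours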