├── .github
│   └── workflows
│       ├── macos.yml
│       ├── pypi.yml
│       ├── ubuntu.yml
│       └── windows.yml
├── .gitignore
├── Inference.py
├── LICENSE
├── MANIFEST.in
├── MORE_USAGES.md
├── README.md
├── app_gradio.py
├── assets
│   ├── Overview.png
│   ├── anomaly.png
│   ├── building.png
│   ├── dog_clip.png
│   ├── eightpic.pdf
│   ├── eightpic.png
│   ├── head_fig.png
│   ├── hf_everything_mode.png
│   ├── hf_points_mode.png
│   ├── logo.png
│   ├── more_usages
│   │   ├── box_prompt.png
│   │   ├── draw_edge.png
│   │   ├── everything_mode.png
│   │   ├── everything_mode_without_retina.png
│   │   ├── more_points.png
│   │   └── text_prompt_cat.png
│   ├── replicate-1.png
│   ├── replicate-2.png
│   ├── replicate-3.png
│   └── salient.png
├── cog.yaml
├── examples
│   ├── dogs.jpg
│   ├── sa_10039.jpg
│   ├── sa_11025.jpg
│   ├── sa_1309.jpg
│   ├── sa_192.jpg
│   ├── sa_414.jpg
│   ├── sa_561.jpg
│   ├── sa_862.jpg
│   └── sa_8776.jpg
├── fastsam
│   ├── __init__.py
│   ├── decoder.py
│   ├── model.py
│   ├── predict.py
│   ├── prompt.py
│   └── utils.py
├── images
│   ├── cat.jpg
│   └── dogs.jpg
├── output
│   ├── cat.jpg
│   └── dogs.jpg
├── predict.py
├── requirements.txt
├── segpredict.py
├── setup.py
└── utils
    ├── __init__.py
    ├── tools.py
    └── tools_gradio.py
/.github/workflows/macos.yml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 | branches:
4 | - main
5 | pull_request:
6 | branches:
7 | - main
8 |
9 | name: macOS build
10 | jobs:
11 | test-macOS:
12 | runs-on: ${{ matrix.config.os }}
13 | name: ${{ matrix.config.os }} (${{ matrix.config.py }})
14 | strategy:
15 | fail-fast: false
16 | matrix:
17 | config:
18 | - { os: macOS-latest, py: "3.10" }
19 | env:
20 | SDKROOT: /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk
21 | steps:
22 | - name: CHECKOUT CODE
23 | uses: actions/checkout@v3
24 | - name: SETUP PYTHON
25 | uses: actions/setup-python@v4
26 | with:
27 | python-version: ${{ matrix.config.py }}
28 | # - name: Install GDAL
29 | # run: |
30 | # python -m pip install --upgrade pip
31 | # pip install --no-cache-dir Cython
32 | # pip install --find-links=https://girder.github.io/large_image_wheels --no-cache GDAL
33 | # - name: Test GDAL installation
34 | # run: |
35 | # python -c "from osgeo import gdal"
36 | # gdalinfo --version
37 | - name: Install dependencies
38 | run: |
39 | python -m pip install --upgrade pip
40 | pip install --no-cache-dir Cython
41 | pip install codespell -r requirements.txt
42 | pip install .
43 |
--------------------------------------------------------------------------------
/.github/workflows/pypi.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: pypi
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - uses: actions/checkout@v3
16 | - name: Set up Python
17 | uses: actions/setup-python@v4
18 | with:
19 | python-version: "3.x"
20 | - name: Install dependencies
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install setuptools wheel twine
24 | - name: Build and publish
25 | env:
26 | TWINE_USERNAME: ${{ secrets.PYPI_USERS }}
27 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
28 | run: |
29 | python setup.py sdist bdist_wheel
30 | twine upload dist/*
31 |
--------------------------------------------------------------------------------
/.github/workflows/ubuntu.yml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 | branches:
4 | - main
5 | pull_request:
6 | branches:
7 | - main
8 |
9 | name: Linux build
10 | jobs:
11 | py-check:
12 | runs-on: ${{ matrix.config.os }}
13 | name: ${{ matrix.config.os }} (${{ matrix.config.py }})
14 | strategy:
15 | fail-fast: false
16 | matrix:
17 | config:
18 | - { os: ubuntu-latest, py: "3.8" }
19 | - { os: ubuntu-latest, py: "3.9" }
20 | - { os: ubuntu-latest, py: "3.10" }
21 | - { os: ubuntu-latest, py: "3.11" }
22 |
23 | env:
24 | SDKROOT: /Library/Developer/CommandLineTools/SDKs/MacOSX.sdk
25 | steps:
26 | - name: CHECKOUT CODE
27 | uses: actions/checkout@v3
28 | - name: SETUP PYTHON
29 | uses: actions/setup-python@v4
30 | with:
31 | python-version: ${{ matrix.config.py }}
32 | - name: Install dependencies
33 | run: |
34 | python -m pip install --upgrade pip
35 | pip install --user --no-cache-dir Cython
36 | pip install --user -r requirements.txt
37 | pip install --user .
38 |
--------------------------------------------------------------------------------
/.github/workflows/windows.yml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 | branches:
4 | - main
5 | pull_request:
6 | branches:
7 | - main
8 |
9 | name: Windows build
10 | jobs:
11 | test-windows:
12 | runs-on: windows-latest
13 | steps:
14 | - uses: actions/checkout@v3
15 | - name: Install miniconda
16 | uses: conda-incubator/setup-miniconda@v2
17 | with:
18 | auto-activate-base: true
19 | python-version: "3.10"
20 | - name: Install dependencies
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install --no-cache-dir Cython
24 | pip install -r requirements.txt
25 | pip install .
26 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.pyo
3 | *.pyd
4 | .DS_Store
5 | .idea
6 | weights
7 | build/
8 | *.egg-info/
9 | gradio_cached_examples
10 | dist/
--------------------------------------------------------------------------------
/Inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from fastsam import FastSAM, FastSAMPrompt
3 | import ast
4 | import torch
5 | from PIL import Image
6 | from utils.tools import convert_box_xywh_to_xyxy
7 |
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument(
12 | "--model_path", type=str, default="./weights/FastSAM.pt", help="model"
13 | )
14 | parser.add_argument(
15 | "--img_path", type=str, default="./images/dogs.jpg", help="path to image file"
16 | )
17 | parser.add_argument("--imgsz", type=int, default=1024, help="image size")
18 | parser.add_argument(
19 | "--iou",
20 | type=float,
21 | default=0.9,
22 | help="iou threshold for filtering the annotations",
23 | )
24 | parser.add_argument(
25 | "--text_prompt", type=str, default=None, help='use text prompt eg: "a dog"'
26 | )
27 | parser.add_argument(
28 | "--conf", type=float, default=0.4, help="object confidence threshold"
29 | )
30 | parser.add_argument(
31 | "--output", type=str, default="./output/", help="image save path"
32 | )
33 | parser.add_argument(
34 | "--randomcolor", type=bool, default=True, help="mask random color"
35 | )
36 | parser.add_argument(
37 | "--point_prompt", type=str, default="[[0,0]]", help="[[x1,y1],[x2,y2]]"
38 | )
39 | parser.add_argument(
40 | "--point_label",
41 | type=str,
42 | default="[0]",
43 | help="[1,0] 0:background, 1:foreground",
44 | )
45 | parser.add_argument("--box_prompt", type=str, default="[[0,0,0,0]]", help="[[x,y,w,h],[x2,y2,w2,h2]] support multiple boxes")
46 | parser.add_argument(
47 | "--better_quality",
48 | type=str,
49 | default=False,
50 | help="better quality using morphologyEx",
51 | )
52 | device = torch.device(
53 | "cuda"
54 | if torch.cuda.is_available()
55 | else "mps"
56 | if torch.backends.mps.is_available()
57 | else "cpu"
58 | )
59 | parser.add_argument(
60 | "--device", type=str, default=device, help="cuda:[0,1,2,3,4] or cpu"
61 | )
62 | parser.add_argument(
63 | "--retina",
64 | type=bool,
65 | default=True,
66 | help="draw high-resolution segmentation masks",
67 | )
68 | parser.add_argument(
69 | "--withContours", type=bool, default=False, help="draw the edges of the masks"
70 | )
71 | return parser.parse_args()
72 |
73 |
74 | def main(args):
75 | # load model
76 | model = FastSAM(args.model_path)
77 | args.point_prompt = ast.literal_eval(args.point_prompt)
78 | args.box_prompt = convert_box_xywh_to_xyxy(ast.literal_eval(args.box_prompt))
79 | args.point_label = ast.literal_eval(args.point_label)
80 | input = Image.open(args.img_path)
81 | input = input.convert("RGB")
82 | everything_results = model(
83 | input,
84 | device=args.device,
85 | retina_masks=args.retina,
86 | imgsz=args.imgsz,
87 | conf=args.conf,
88 | iou=args.iou
89 | )
90 | bboxes = None
91 | points = None
92 | point_label = None
93 | prompt_process = FastSAMPrompt(input, everything_results, device=args.device)
94 | if args.box_prompt[0][2] != 0 and args.box_prompt[0][3] != 0:
95 | ann = prompt_process.box_prompt(bboxes=args.box_prompt)
96 | bboxes = args.box_prompt
97 | elif args.text_prompt != None:
98 | ann = prompt_process.text_prompt(text=args.text_prompt)
99 | elif args.point_prompt[0] != [0, 0]:
100 | ann = prompt_process.point_prompt(
101 | points=args.point_prompt, pointlabel=args.point_label
102 | )
103 | points = args.point_prompt
104 | point_label = args.point_label
105 | else:
106 | ann = prompt_process.everything_prompt()
107 | prompt_process.plot(
108 | annotations=ann,
109 | output_path=args.output+args.img_path.split("/")[-1],
110 | bboxes = bboxes,
111 | points = points,
112 | point_label = point_label,
113 | withContours=args.withContours,
114 | better_quality=args.better_quality,
115 | )
116 |
117 |
118 |
119 |
120 | if __name__ == "__main__":
121 | args = parse_args()
122 | main(args)
123 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 | CASIA-IVA-Lab
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include README.md
3 | include requirements.txt
4 |
5 | recursive-exclude * __pycache__
6 | recursive-exclude * *.py[co]
7 |
8 |
--------------------------------------------------------------------------------
/MORE_USAGES.md:
--------------------------------------------------------------------------------
1 | # MORE_USAGES
2 |
3 |
4 |
5 | ### Everything mode
6 | Use `--imgsz` to change the input size.
7 |
8 | ```shell
9 | python Inference.py --model_path ./weights/FastSAM.pt \
10 | --img_path ./images/dogs.jpg \
11 | --imgsz 720
12 | ```
13 | 
14 |
15 |
16 |
17 | ### Use more points
18 |
19 | ```shell
20 | python Inference.py --model_path ./weights/FastSAM.pt \
21 | --img_path ./images/dogs.jpg \
22 | --point_prompt "[[520,360],[620,300],[520,300],[620,360]]" \
23 | --point_label "[1,0,1,0]"
24 | ```
25 | 
26 | ### Draw mask edges
27 | Use `--withContours True` to draw the edges of the masks.
28 |
29 | When `--better_quality True` is set, the edges will be smoother.
30 |
31 | ```shell
32 | python Inference.py --model_path ./weights/FastSAM.pt \
33 | --img_path ./images/dogs.jpg \
34 | --point_prompt "[[620,360]]" \
35 | --point_label "[1]" \
36 | --withContours True \
37 | --better_quality True
38 | ```
39 |
40 | 
41 | ### Use box prompt
42 | Use `--box_prompt "[[x,y,w,h]]"` to specify the bounding box (xywh) of the foreground object; multiple boxes are supported.
43 | ```shell
44 | python Inference.py --model_path ./weights/FastSAM.pt \
45 | --img_path ./images/dogs.jpg \
46 | --box_prompt "[[570,200,230,400]]"
47 | ```
48 | 
49 |
50 | ### Use text prompt
51 | Use `--text_prompt "text"` to specify the text prompt.
52 | ```shell
53 | python Inference.py --model_path ./weights/FastSAM.pt \
54 | --img_path ./images/cat.jpg \
55 | --text_prompt "cat" \
56 | --better_quality True \
57 | --withContours True
58 | ```
59 | 
60 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # Fast Segment Anything
4 |
5 | [](https://pypi.python.org/pypi/segment-anything-fast)
6 |
7 | [[`📕Paper`](https://arxiv.org/pdf/2306.12156.pdf)] [[`🤗HuggingFace Demo`](https://huggingface.co/spaces/An-619/FastSAM)] [[`Colab demo`](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)] [[`Replicate demo & API`](https://replicate.com/casia-iva-lab/fastsam)] [[`Model Zoo`](#model-checkpoints)] [[`BibTeX`](#citing-fastsam)]
8 |
9 | 
10 |
11 | The **Fast Segment Anything Model (FastSAM)** is a CNN-based Segment Anything Model trained using only 2% of the SA-1B dataset published by the SAM authors. FastSAM achieves performance comparable to
12 | the SAM method at **50× higher run-time speed**.
13 |
14 | 
15 |
16 | **🍇 Updates**
17 |
18 | - **`2023/07/06`** Added to [Ultralytics (YOLOv8) Model Hub](https://docs.ultralytics.com/models/fast-sam/). Thanks to [Ultralytics](https://github.com/ultralytics/ultralytics) for help 🌹.
19 | - **`2023/06/29`** Support [text mode](https://huggingface.co/spaces/An-619/FastSAM) in HuggingFace Space. Thanks a lot to [gaoxinge](https://github.com/gaoxinge) for help 🌹.
20 | - **`2023/06/29`** Release [FastSAM_Awesome_TensorRT](https://github.com/ChuRuaNh0/FastSam_Awsome_TensorRT). Thanks a lot to [ChuRuaNh0](https://github.com/ChuRuaNh0) for providing the TensorRT model of FastSAM 🌹.
21 | - **`2023/06/26`** Release [FastSAM Replicate Online Demo](https://replicate.com/casia-iva-lab/fastsam). Thanks a lot to [Chenxi](https://chenxwh.github.io/) for providing this nice demo 🌹.
22 | - **`2023/06/26`** Support [points mode](https://huggingface.co/spaces/An-619/FastSAM) in HuggingFace Space. Better and faster interaction will come soon!
23 | - **`2023/06/24`** Thanks a lot to [Grounding-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything) for Combining Grounding-DINO with FastSAM in [Grounded-FastSAM](https://github.com/IDEA-Research/Grounded-Segment-Anything/tree/main/EfficientSAM) 🌹.
24 |
25 | ## Installation
26 |
27 | Install the package from PyPI:
28 |
29 | ```shell
30 | pip install segment-anything-fast
31 | ```
32 |
33 | ## Getting Started
34 |
35 | First download a [model checkpoint](#model-checkpoints).
36 |
37 | Then, you can run the scripts to try the everything mode and three prompt modes.
38 |
39 | ```shell
40 | # Everything mode
41 | python Inference.py --model_path ./weights/FastSAM.pt --img_path ./images/dogs.jpg
42 | ```
43 |
44 | ```shell
45 | # Text prompt
46 | python Inference.py --model_path ./weights/FastSAM.pt --img_path ./images/dogs.jpg --text_prompt "the yellow dog"
47 | ```
48 |
49 | ```shell
50 | # Box prompt (xywh)
51 | python Inference.py --model_path ./weights/FastSAM.pt --img_path ./images/dogs.jpg --box_prompt "[[570,200,230,400]]"
52 | ```
53 |
54 | ```shell
55 | # Points prompt
56 | python Inference.py --model_path ./weights/FastSAM.pt --img_path ./images/dogs.jpg --point_prompt "[[520,360],[620,300]]" --point_label "[1,0]"
57 | ```
58 |
59 | You can use the following code to generate all masks, select masks based on prompts, and visualize the results.
60 |
61 | ```python
62 | from fastsam import FastSAM, FastSAMPrompt
63 |
64 | model = FastSAM('./weights/FastSAM.pt')
65 | IMAGE_PATH = './images/dogs.jpg'
66 | DEVICE = 'cpu'
67 | everything_results = model(IMAGE_PATH, device=DEVICE, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9,)
68 | prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE)
69 |
70 | # everything prompt
71 | ann = prompt_process.everything_prompt()
72 |
73 | # bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
74 | ann = prompt_process.box_prompt(bbox=[[200, 200, 300, 300]])
75 |
76 | # text prompt
77 | ann = prompt_process.text_prompt(text='a photo of a dog')
78 |
79 | # point prompt
80 | # points default [[0,0]] [[x1,y1],[x2,y2]]
81 | # point_label default [0] [1,0] 0:background, 1:foreground
82 | ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1])
83 |
84 | prompt_process.plot(annotations=ann,output_path='./output/dog.jpg',)
85 | ```
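
The example above runs on the CPU. If a CUDA GPU or Apple Silicon (MPS) is available, you can select the device the same way `Inference.py` does:

```python
import torch

# Pick the best available device, mirroring Inference.py
DEVICE = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
```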
86 |
87 | You are also welcome to try our Colab demo: [FastSAM_example.ipynb](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing).
88 |
89 | ## Different Inference Options
90 |
91 | We provide various options for different purposes; details are in [MORE_USAGES.md](MORE_USAGES.md).
92 |
93 | ## Web demo
94 |
95 | ### Gradio demo
96 |
97 | - We also provide a UI for testing our method, built with Gradio. You can upload a custom image, select the mode, set the parameters, and click the segment button to get the segmentation result. Currently, the UI supports interaction with the 'Everything mode' and 'Points mode'; we plan to add support for additional modes in the future. Running the following command in a terminal will launch the demo:
98 |
99 | ```shell
100 | # Download the pre-trained model to ./weights/FastSAM.pt first
101 | python app_gradio.py
102 | ```
103 |
104 | - This demo is also hosted on [HuggingFace Space](https://huggingface.co/spaces/An-619/FastSAM).
105 |
106 |  
107 |
108 | ### Replicate demo
109 |
110 | - The [Replicate demo](https://replicate.com/casia-iva-lab/fastsam) supports all modes; you can try the points/box/text modes.
111 |
112 |   
113 |
114 | ## Model Checkpoints
115 |
116 | Two versions of the model are available in different sizes. Click the links below to download the checkpoint for the corresponding model type.
117 |
118 | - **`default` or `FastSAM`: [YOLOv8x based Segment Anything Model](https://drive.google.com/file/d/1m1sjY4ihXBU1fZXdQ-Xdj-mDltW-2Rqv/view?usp=sharing) | [Baidu Cloud (pwd: 0000).](https://pan.baidu.com/s/18KzBmOTENjByoWWR17zdiQ?pwd=0000)**
119 | - `FastSAM-s`: [YOLOv8s based Segment Anything Model.](https://drive.google.com/file/d/10XmSj6mmpmRb8NhXbtiuO9cTTBwR_9SV/view?usp=sharing)
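
After downloading, pass the checkpoint path to the `FastSAM` constructor. A minimal sketch; the `FastSAM-s.pt` filename is an assumption, use whatever name you saved the smaller checkpoint under:

```python
from fastsam import FastSAM

model = FastSAM('./weights/FastSAM.pt')      # default YOLOv8x-based checkpoint
# model = FastSAM('./weights/FastSAM-s.pt')  # smaller YOLOv8s-based checkpoint (assumed filename)
```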
120 |
121 | ## Results
122 |
123 | All results were tested on a single NVIDIA GeForce RTX 3090.
124 |
125 | ### 1. Inference time
126 |
127 | Running speed under different numbers of point prompts (ms).
128 | | method | params | 1 | 10 | 100 | E(16x16) | E(32x32\*) | E(64x64) |
129 | |:------------------:|:--------:|:-----:|:-----:|:-----:|:----------:|:-----------:|:----------:|
130 | | SAM-H | 0.6G | 446 | 464 | 627 | 852 | 2099 | 6972 |
131 | | SAM-B | 136M | 110 | 125 | 230 | 432 | 1383 | 5417 |
132 | | FastSAM | 68M | 40 | 40 | 40 | 40 | 40 | 40 |
133 |
134 | ### 2. Memory usage
135 |
136 | | Dataset | Method | GPU Memory (MB) |
137 | | :-------: | :-----: | :-------------: |
138 | | COCO 2017 | FastSAM | 2608 |
139 | | COCO 2017 | SAM-H | 7060 |
140 | | COCO 2017 | SAM-B | 4670 |
141 |
142 | ### 3. Zero-shot Transfer Experiments
143 |
144 | #### Edge Detection
145 |
146 | Tested on the BSDS500 dataset.
147 | |method | year| ODS | OIS | AP | R50 |
148 | |:----------:|:-------:|:--------:|:--------:|:------:|:-----:|
149 | | HED | 2015| .788 | .808 | .840 | .923 |
150 | | SAM | 2023| .768 | .786 | .794 | .928 |
151 | | FastSAM | 2023| .750 | .790 | .793 | .903 |
152 |
153 | #### Object Proposals
154 |
155 | ##### COCO
156 |
157 | | method | AR10 | AR100 | AR1000 | AUC |
158 | | :-------: | :--: | :---: | :----: | :--: |
159 | | SAM-H E64 | 15.5 | 45.6 | 67.7 | 32.1 |
160 | | SAM-H E32 | 18.5 | 49.5 | 62.5 | 33.7 |
161 | | SAM-B E32 | 11.4 | 39.6 | 59.1 | 27.3 |
162 | | FastSAM | 15.7 | 47.3 | 63.7 | 32.2 |
163 |
164 | ##### LVIS
165 |
166 | bbox AR@1000:
167 | | method | all | small | med. | large |
168 | |:---------------:|:-----:|:------:|:-----:|:------:|
169 | | ViTDet-H | 65.0 | 53.2 | 83.3 | 91.2 |
170 | | *zero-shot transfer methods:* | | | | |
171 | | SAM-H E64 | 52.1 | 36.6 | 75.1 | 88.2 |
172 | | SAM-H E32 | 50.3 | 33.1 | 76.2 | 89.8 |
173 | | SAM-B E32 | 45.0 | 29.3 | 68.7 | 80.6 |
174 | | FastSAM | 57.1 | 44.3 | 77.1 | 85.3 |
175 |
176 | #### Instance Segmentation On COCO 2017
177 |
178 | | method | AP | APS | APM | APL |
179 | | :------: | :--: | :--: | :--: | :--: |
180 | | ViTDet-H | .510 | .320 | .543 | .689 |
181 | | SAM | .465 | .308 | .510 | .617 |
182 | | FastSAM | .379 | .239 | .434 | .500 |
183 |
184 | ### 4. Performance Visualization
185 |
186 | Several segmentation results:
187 |
188 | #### Natural Images
189 |
190 | 
191 |
192 | #### Text to Mask
193 |
194 | 
195 |
196 | ### 5. Downstream tasks
197 |
198 | The results of several downstream tasks demonstrate the effectiveness of FastSAM.
199 |
200 | #### Anomaly Detection
201 |
202 | 
203 |
204 | #### Salient Object Detection
205 |
206 | 
207 |
208 | #### Building Extraction
209 |
210 | 
211 |
212 | ## License
213 |
214 | The model is licensed under the [Apache 2.0 license](LICENSE).
215 |
216 | ## Acknowledgement
217 |
218 | - [Segment Anything](https://segment-anything.com/) provides the SA-1B dataset and the base codes.
219 | - [YOLOv8](https://github.com/ultralytics/ultralytics) provides codes and pre-trained models.
220 | - [YOLACT](https://arxiv.org/abs/2112.10003) provides a powerful instance segmentation method.
221 | - [Grounded-Segment-Anything](https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything) provides a useful web demo template.
222 |
223 | ## Contributors
224 |
225 | Our project wouldn't be possible without the contributions of these amazing people! Thank you all for making this project better.
226 |
227 |
228 |
229 |
230 |
231 | ## Citing FastSAM
232 |
233 | If you find this project useful for your research, please consider citing the following BibTeX entry.
234 |
235 | ```
236 | @misc{zhao2023fast,
237 | title={Fast Segment Anything},
238 | author={Xu Zhao and Wenchao Ding and Yongqi An and Yinglong Du and Tao Yu and Min Li and Ming Tang and Jinqiao Wang},
239 | year={2023},
240 | eprint={2306.12156},
241 | archivePrefix={arXiv},
242 | primaryClass={cs.CV}
243 | }
244 | ```
245 |
246 | [](https://star-history.com/#CASIA-IVA-Lab/FastSAM&Date)
247 |
--------------------------------------------------------------------------------
/app_gradio.py:
--------------------------------------------------------------------------------
1 | from ultralytics import YOLO
2 | import gradio as gr
3 | import torch
4 | from utils.tools_gradio import fast_process
5 | from utils.tools import format_results, box_prompt, point_prompt, text_prompt
6 | from PIL import ImageDraw
7 | import numpy as np
8 |
9 | # Load the pre-trained model
10 | model = YOLO('./weights/FastSAM.pt')
11 |
12 | device = torch.device(
13 | "cuda"
14 | if torch.cuda.is_available()
15 | else "mps"
16 | if torch.backends.mps.is_available()
17 | else "cpu"
18 | )
19 |
20 | # Description
21 | title = "🏃 Fast Segment Anything 🤗"
22 |
23 | news = """ # 📖 News
24 | 🔥 2023/07/14: Add a "wider result" button in text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/95)).
25 |
26 | 🔥 2023/06/29: Support the text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/47)).
27 |
28 | 🔥 2023/06/26: Support the points mode. (Better and faster interaction will come soon!)
29 |
30 | 🔥 2023/06/24: Add 'Advanced options' in Everything mode for more detailed adjustment.
31 | """
32 |
33 | description_e = """This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it.
34 |
35 | 🎯 Upload an Image, segment it with Fast Segment Anything (Everything mode). The other modes will come soon.
36 |
37 | ⌛️ It takes about 6 seconds to generate segmentation results. The concurrency_count of the queue is 1; please wait a moment when it is crowded.
38 |
39 | 🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.
40 |
41 | 📣 You can also obtain the segmentation results of any Image through this Colab: [](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)
42 |
43 | 😚 A huge thanks goes out to the @HuggingFace Team for supporting us with a GPU grant.
44 |
45 | 🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM)
46 |
47 | """
48 |
49 | description_p = """ # 🎯 Instructions for points mode
50 | This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it.
51 |
52 | 1. Upload an image or choose an example.
53 |
54 | 2. Choose the point label ('Add Mask' means a positive point; 'Remove Area' means a negative point that is excluded from the segmentation).
55 |
56 | 3. Add points one by one on the image.
57 |
58 | 4. Click the 'Segment with points prompt' button to get the segmentation results.
59 |
60 | **5. If you get an error, clicking the 'Clear points' button and trying again may help.**
61 |
62 | """
63 |
64 | examples = [["examples/sa_8776.jpg"], ["examples/sa_414.jpg"], ["examples/sa_1309.jpg"], ["examples/sa_11025.jpg"],
65 | ["examples/sa_561.jpg"], ["examples/sa_192.jpg"], ["examples/sa_10039.jpg"], ["examples/sa_862.jpg"]]
66 |
67 | default_example = examples[0]
68 |
69 | css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
70 |
71 |
72 | def segment_everything(
73 | input,
74 | input_size=1024,
75 | iou_threshold=0.7,
76 | conf_threshold=0.25,
77 | better_quality=False,
78 | withContours=True,
79 | use_retina=True,
80 | text="",
81 | wider=False,
82 | mask_random_color=True,
83 | ):
84 | input_size = int(input_size)  # make sure imgsz is an integer
85 | # Thanks for the suggestion by hysts in HuggingFace.
86 | w, h = input.size
87 | scale = input_size / max(w, h)
88 | new_w = int(w * scale)
89 | new_h = int(h * scale)
90 | input = input.resize((new_w, new_h))
91 |
92 | results = model(input,
93 | device=device,
94 | retina_masks=True,
95 | iou=iou_threshold,
96 | conf=conf_threshold,
97 | imgsz=input_size,)
98 |
99 | if len(text) > 0:
100 | results = format_results(results[0], 0)
101 | annotations, _ = text_prompt(results, text, input, device=device, wider=wider)
102 | annotations = np.array([annotations])
103 | else:
104 | annotations = results[0].masks.data
105 |
106 | fig = fast_process(annotations=annotations,
107 | image=input,
108 | device=device,
109 | scale=(1024 // input_size),
110 | better_quality=better_quality,
111 | mask_random_color=mask_random_color,
112 | bbox=None,
113 | use_retina=use_retina,
114 | withContours=withContours,)
115 | return fig
116 |
117 |
118 | def segment_with_points(
119 | input,
120 | input_size=1024,
121 | iou_threshold=0.7,
122 | conf_threshold=0.25,
123 | better_quality=False,
124 | withContours=True,
125 | use_retina=True,
126 | mask_random_color=True,
127 | ):
128 | global global_points
129 | global global_point_label
130 |
131 | input_size = int(input_size)  # make sure imgsz is an integer
132 | # Thanks for the suggestion by hysts in HuggingFace.
133 | w, h = input.size
134 | scale = input_size / max(w, h)
135 | new_w = int(w * scale)
136 | new_h = int(h * scale)
137 | input = input.resize((new_w, new_h))
138 |
139 | scaled_points = [[int(x * scale) for x in point] for point in global_points]
140 |
141 | results = model(input,
142 | device=device,
143 | retina_masks=True,
144 | iou=iou_threshold,
145 | conf=conf_threshold,
146 | imgsz=input_size,)
147 |
148 | results = format_results(results[0], 0)
149 | annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w)
150 | annotations = np.array([annotations])
151 |
152 | fig = fast_process(annotations=annotations,
153 | image=input,
154 | device=device,
155 | scale=(1024 // input_size),
156 | better_quality=better_quality,
157 | mask_random_color=mask_random_color,
158 | bbox=None,
159 | use_retina=use_retina,
160 | withContours=withContours,)
161 |
162 | global_points = []
163 | global_point_label = []
164 | return fig, None
165 |
166 |
167 | def get_points_with_draw(image, label, evt: gr.SelectData):
168 | global global_points
169 | global global_point_label
170 |
171 | x, y = evt.index[0], evt.index[1]
172 | point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
173 | global_points.append([x, y])
174 | global_point_label.append(1 if label == 'Add Mask' else 0)
175 |
176 | print(x, y, label == 'Add Mask')
177 |
178 | # Create a drawing object for the image
179 | draw = ImageDraw.Draw(image)
180 | draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
181 | return image
182 |
183 |
184 | cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil')
185 | cond_img_p = gr.Image(label="Input with points", value=default_example[0], type='pil')
186 | cond_img_t = gr.Image(label="Input with text", value="examples/dogs.jpg", type='pil')
187 |
188 | segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil')
189 | segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil')
190 | segm_img_t = gr.Image(label="Segmented Image with text", interactive=False, type='pil')
191 |
192 | global_points = []
193 | global_point_label = []
194 |
195 | input_size_slider = gr.components.Slider(minimum=512,
196 | maximum=1024,
197 | value=1024,
198 | step=64,
199 | label='Input_size',
200 | info='Our model was trained on a size of 1024')
201 |
202 | with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
203 | with gr.Row():
204 | with gr.Column(scale=1):
205 | # Title
206 | gr.Markdown(title)
207 |
208 | with gr.Column(scale=1):
209 | # News
210 | gr.Markdown(news)
211 |
212 | with gr.Tab("Everything mode"):
213 | # Images
214 | with gr.Row(variant="panel"):
215 | with gr.Column(scale=1):
216 | cond_img_e.render()
217 |
218 | with gr.Column(scale=1):
219 | segm_img_e.render()
220 |
221 | # Submit & Clear
222 | with gr.Row():
223 | with gr.Column():
224 | input_size_slider.render()
225 |
226 | with gr.Row():
227 | contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
228 |
229 | with gr.Column():
230 | segment_btn_e = gr.Button("Segment Everything", variant='primary')
231 | clear_btn_e = gr.Button("Clear", variant="secondary")
232 |
233 | gr.Markdown("Try some of the examples below ⬇️")
234 | gr.Examples(examples=examples,
235 | inputs=[cond_img_e],
236 | outputs=segm_img_e,
237 | fn=segment_everything,
238 | cache_examples=True,
239 | examples_per_page=4)
240 |
241 | with gr.Column():
242 | with gr.Accordion("Advanced options", open=False):
243 | iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
244 | conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
245 | with gr.Row():
246 | mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
247 | with gr.Column():
248 | retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
249 |
250 | # Description
251 | gr.Markdown(description_e)
252 |
253 | segment_btn_e.click(segment_everything,
254 | inputs=[
255 | cond_img_e,
256 | input_size_slider,
257 | iou_threshold,
258 | conf_threshold,
259 | mor_check,
260 | contour_check,
261 | retina_check,
262 | ],
263 | outputs=segm_img_e)
264 |
265 | with gr.Tab("Points mode"):
266 | # Images
267 | with gr.Row(variant="panel"):
268 | with gr.Column(scale=1):
269 | cond_img_p.render()
270 |
271 | with gr.Column(scale=1):
272 | segm_img_p.render()
273 |
274 | # Submit & Clear
275 | with gr.Row():
276 | with gr.Column():
277 | with gr.Row():
278 | add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")
279 |
280 | with gr.Column():
281 | segment_btn_p = gr.Button("Segment with points prompt", variant='primary')
282 | clear_btn_p = gr.Button("Clear points", variant='secondary')
283 |
284 | gr.Markdown("Try some of the examples below ⬇️")
285 | gr.Examples(examples=examples,
286 | inputs=[cond_img_p],
287 | # outputs=segm_img_p,
288 | # fn=segment_with_points,
289 | # cache_examples=True,
290 | examples_per_page=4)
291 |
292 | with gr.Column():
293 | # Description
294 | gr.Markdown(description_p)
295 |
296 | cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
297 |
298 | segment_btn_p.click(segment_with_points,
299 | inputs=[cond_img_p],
300 | outputs=[segm_img_p, cond_img_p])
301 |
302 | with gr.Tab("Text mode"):
303 | # Images
304 | with gr.Row(variant="panel"):
305 | with gr.Column(scale=1):
306 | cond_img_t.render()
307 |
308 | with gr.Column(scale=1):
309 | segm_img_t.render()
310 |
311 | # Submit & Clear
312 | with gr.Row():
313 | with gr.Column():
314 | input_size_slider_t = gr.components.Slider(minimum=512,
315 | maximum=1024,
316 | value=1024,
317 | step=64,
318 | label='Input_size',
319 | info='Our model was trained on a size of 1024')
320 | with gr.Row():
321 | with gr.Column():
322 | contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
323 | text_box = gr.Textbox(label="text prompt", value="a black dog")
324 |
325 | with gr.Column():
326 | segment_btn_t = gr.Button("Segment with text", variant='primary')
327 | clear_btn_t = gr.Button("Clear", variant="secondary")
328 |
329 | gr.Markdown("Try some of the examples below ⬇️")
330 | gr.Examples(examples=[["examples/dogs.jpg"]] + examples,
331 | inputs=[cond_img_t],
332 | # outputs=segm_img_e,
333 | # fn=segment_everything,
334 | # cache_examples=True,
335 | examples_per_page=4)
336 |
337 | with gr.Column():
338 | with gr.Accordion("Advanced options", open=False):
339 | iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
340 | conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
341 | with gr.Row():
342 | mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
343 | retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
344 | wider_check = gr.Checkbox(value=False, label='wider', info='wider result')
345 |
346 | # Description
347 | gr.Markdown(description_e)
348 |
349 | segment_btn_t.click(segment_everything,
350 | inputs=[
351 | cond_img_t,
352 | input_size_slider_t,
353 | iou_threshold,
354 | conf_threshold,
355 | mor_check,
356 | contour_check,
357 | retina_check,
358 | text_box,
359 | wider_check,
360 | ],
361 | outputs=segm_img_t)
362 |
363 | def clear():
364 | return None, None
365 |
366 | def clear_text():
367 | return None, None, None
368 |
369 | clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
370 | clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
371 | clear_btn_t.click(clear_text, outputs=[cond_img_t, segm_img_t, text_box])
372 |
373 | demo.queue()
374 | demo.launch()
375 |
--------------------------------------------------------------------------------
/assets/Overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/Overview.png
--------------------------------------------------------------------------------
/assets/anomaly.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/anomaly.png
--------------------------------------------------------------------------------
/assets/building.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/building.png
--------------------------------------------------------------------------------
/assets/dog_clip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/dog_clip.png
--------------------------------------------------------------------------------
/assets/eightpic.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/eightpic.pdf
--------------------------------------------------------------------------------
/assets/eightpic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/eightpic.png
--------------------------------------------------------------------------------
/assets/head_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/head_fig.png
--------------------------------------------------------------------------------
/assets/hf_everything_mode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/hf_everything_mode.png
--------------------------------------------------------------------------------
/assets/hf_points_mode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/hf_points_mode.png
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/logo.png
--------------------------------------------------------------------------------
/assets/more_usages/box_prompt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/more_usages/box_prompt.png
--------------------------------------------------------------------------------
/assets/more_usages/draw_edge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/more_usages/draw_edge.png
--------------------------------------------------------------------------------
/assets/more_usages/everything_mode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/more_usages/everything_mode.png
--------------------------------------------------------------------------------
/assets/more_usages/everything_mode_without_retina.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/more_usages/everything_mode_without_retina.png
--------------------------------------------------------------------------------
/assets/more_usages/more_points.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/more_usages/more_points.png
--------------------------------------------------------------------------------
/assets/more_usages/text_prompt_cat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/more_usages/text_prompt_cat.png
--------------------------------------------------------------------------------
/assets/replicate-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/replicate-1.png
--------------------------------------------------------------------------------
/assets/replicate-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/replicate-2.png
--------------------------------------------------------------------------------
/assets/replicate-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/replicate-3.png
--------------------------------------------------------------------------------
/assets/salient.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/assets/salient.png
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 | # Thanks to chenxwh.
4 |
5 | build:
6 | # set to true if your model requires a GPU
7 | gpu: true
8 | cuda: "11.7"
9 | system_packages:
10 | - "libgl1-mesa-glx"
11 | - "libglib2.0-0"
12 | python_version: "3.8"
13 | python_packages:
14 | - "matplotlib==3.7.1"
15 | - "opencv-python==4.7.0.72"
16 | - "Pillow==9.5.0"
17 | - "PyYAML==6.0"
18 | - "requests==2.31.0"
19 | - "scipy==1.10.1"
20 | - "torch==2.0.1"
21 | - "torchvision==0.15.2"
22 | - "tqdm==4.65.0"
23 | - "pandas==2.0.2"
24 | - "seaborn==0.12.0"
25 | - "ultralytics==8.0.121"
26 | - git+https://github.com/openai/CLIP.git
27 | predict: "predict.py:Predictor"
28 |
--------------------------------------------------------------------------------
/examples/dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/dogs.jpg
--------------------------------------------------------------------------------
/examples/sa_10039.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_10039.jpg
--------------------------------------------------------------------------------
/examples/sa_11025.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_11025.jpg
--------------------------------------------------------------------------------
/examples/sa_1309.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_1309.jpg
--------------------------------------------------------------------------------
/examples/sa_192.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_192.jpg
--------------------------------------------------------------------------------
/examples/sa_414.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_414.jpg
--------------------------------------------------------------------------------
/examples/sa_561.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_561.jpg
--------------------------------------------------------------------------------
/examples/sa_862.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_862.jpg
--------------------------------------------------------------------------------
/examples/sa_8776.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/examples/sa_8776.jpg
--------------------------------------------------------------------------------
/fastsam/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 |
3 | from .model import FastSAM
4 | from .predict import FastSAMPredictor
5 | from .prompt import FastSAMPrompt
6 | # from .val import FastSAMValidator
7 | from .decoder import FastSAMDecoder
8 |
9 | __all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMDecoder'
10 |
--------------------------------------------------------------------------------
/fastsam/decoder.py:
--------------------------------------------------------------------------------
1 | from .model import FastSAM
2 | import numpy as np
3 | from PIL import Image
4 | import clip
5 | from typing import Optional, List, Tuple, Union
6 |
7 |
8 | class FastSAMDecoder:
9 | def __init__(
10 | self,
11 | model: FastSAM,
12 | device: str='cpu',
13 | conf: float=0.4,
14 | iou: float=0.9,
15 | imgsz: int=1024,
16 | retina_masks: bool=True,
17 | ):
18 | self.model = model
19 | self.device = device
20 | self.retina_masks = retina_masks
21 | self.imgsz = imgsz
22 | self.conf = conf
23 | self.iou = iou
24 | self.image = None
25 | self.image_embedding = None
26 |
27 | def run_encoder(self, image):
28 | if isinstance(image,str):
29 | image = np.array(Image.open(image))
30 | self.image = image
31 | image_embedding = self.model(
32 | self.image,
33 | device=self.device,
34 | retina_masks=self.retina_masks,
35 | imgsz=self.imgsz,
36 | conf=self.conf,
37 | iou=self.iou
38 | )
39 | return image_embedding[0].numpy()
40 |
41 | def run_decoder(
42 | self,
43 | image_embedding,
44 | point_prompt: Optional[np.ndarray]=None,
45 | point_label: Optional[np.ndarray]=None,
46 | box_prompt: Optional[np.ndarray]=None,
47 | text_prompt: Optional[str]=None,
48 | )->np.ndarray:
49 | self.image_embedding = image_embedding
50 | if point_prompt is not None:
51 | ann = self.point_prompt(points=point_prompt, pointlabel=point_label)
52 | return ann
53 | elif box_prompt is not None:
54 | ann = self.box_prompt(bbox=box_prompt)
55 | return ann
56 | elif text_prompt is not None:
57 | ann = self.text_prompt(text=text_prompt)
58 | return ann
59 | else:
60 | return None
61 |
62 | def box_prompt(self, bbox):
63 | assert (bbox[2] != 0 and bbox[3] != 0)
64 | masks = self.image_embedding.masks.data
65 | target_height = self.image.shape[0]
66 | target_width = self.image.shape[1]
67 | h = masks.shape[1]
68 | w = masks.shape[2]
69 | if h != target_height or w != target_width:
70 | bbox = [
71 | int(bbox[0] * w / target_width),
72 | int(bbox[1] * h / target_height),
73 | int(bbox[2] * w / target_width),
74 | int(bbox[3] * h / target_height), ]
75 | bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
76 | bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
77 | bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
78 | bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
79 |
80 | # IoUs = torch.zeros(len(masks), dtype=torch.float32)
81 | bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
82 |
83 | masks_area = np.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], axis=(1, 2))
84 | orig_masks_area = np.sum(masks, axis=(1, 2))
85 |
86 | union = bbox_area + orig_masks_area - masks_area
87 | IoUs = masks_area / union
88 | max_iou_index = np.argmax(IoUs)
89 |
90 | return np.array([masks[max_iou_index].cpu().numpy()])
91 |
92 | def point_prompt(self, points, pointlabel): # numpy
93 |
94 | masks = self._format_results(self.image_embedding[0], 0)
95 | target_height = self.image.shape[0]
96 | target_width = self.image.shape[1]
97 | h = masks[0]['segmentation'].shape[0]
98 | w = masks[0]['segmentation'].shape[1]
99 | if h != target_height or w != target_width:
100 | points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
101 | onemask = np.zeros((h, w))
102 | masks = sorted(masks, key=lambda x: x['area'], reverse=True)
103 | for i, annotation in enumerate(masks):
104 | if type(annotation) == dict:
105 | mask = annotation['segmentation']
106 | else:
107 | mask = annotation
108 | for i, point in enumerate(points):
109 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
110 | onemask[mask] = 1
111 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
112 | onemask[mask] = 0
113 | onemask = onemask >= 1
114 | return np.array([onemask])
115 |
116 | def _format_results(self, result, filter=0):
117 | annotations = []
118 | n = len(result.masks.data)
119 | for i in range(n):
120 | annotation = {}
121 | mask = result.masks.data[i] == 1.0
122 |
123 | if np.sum(mask) < filter:
124 | continue
125 | annotation['id'] = i
126 | annotation['segmentation'] = mask
127 | annotation['bbox'] = result.boxes.data[i]
128 | annotation['score'] = result.boxes.conf[i]
129 | annotation['area'] = annotation['segmentation'].sum()
130 | annotations.append(annotation)
131 | return annotations
132 |
--------------------------------------------------------------------------------
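
A minimal usage sketch for FastSAMDecoder above, assuming a local 'FastSAM.pt' checkpoint and the bundled './images/dogs.jpg', and that the object returned by run_encoder is passed straight back into run_decoder:

# Usage sketch for FastSAMDecoder (assumes FastSAM.pt and images/dogs.jpg are available).
from fastsam import FastSAM, FastSAMDecoder
import numpy as np

model = FastSAM('FastSAM.pt')
decoder = FastSAMDecoder(model, device='cpu', conf=0.4, iou=0.9, imgsz=1024)

# Encode once, then reuse the embedding for multiple prompts.
image_embedding = decoder.run_encoder('./images/dogs.jpg')

# Point prompt: label 1 marks foreground, label 0 marks background.
masks = decoder.run_decoder(
    image_embedding,
    point_prompt=np.array([[620, 360]]),
    point_label=np.array([1]),
)
print(masks.shape)  # expected: (1, H, W) boolean mask
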
/fastsam/model.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, AGPL-3.0 license
2 | """
3 | FastSAM model interface.
4 |
5 | Usage - Predict:
6 |     from fastsam import FastSAM
7 | 
8 |     model = FastSAM('FastSAM.pt')
9 |     results = model.predict('./images/dogs.jpg')
10 | """
11 |
12 | from ultralytics.yolo.cfg import get_cfg
13 | from ultralytics.yolo.engine.exporter import Exporter
14 | from ultralytics.yolo.engine.model import YOLO
15 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
16 | from ultralytics.yolo.utils.checks import check_imgsz
17 |
18 | from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode
19 | from .predict import FastSAMPredictor
20 |
21 |
22 | class FastSAM(YOLO):
23 |
24 | @smart_inference_mode()
25 | def predict(self, source=None, stream=False, **kwargs):
26 | """
27 | Perform prediction using the YOLO model.
28 |
29 | Args:
30 | source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
31 | Accepts all source types accepted by the YOLO model.
32 | stream (bool): Whether to stream the predictions or not. Defaults to False.
33 | **kwargs : Additional keyword arguments passed to the predictor.
34 | Check the 'configuration' section in the documentation for all available options.
35 |
36 | Returns:
37 | (List[ultralytics.yolo.engine.results.Results]): The prediction results.
38 | """
39 | if source is None:
40 | source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
41 | LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
42 | overrides = self.overrides.copy()
43 | overrides['conf'] = 0.25
44 | overrides.update(kwargs) # prefer kwargs
45 | overrides['mode'] = kwargs.get('mode', 'predict')
46 | assert overrides['mode'] in ['track', 'predict']
47 | overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
48 | self.predictor = FastSAMPredictor(overrides=overrides)
49 | self.predictor.setup_model(model=self.model, verbose=False)
50 | try:
51 | return self.predictor(source, stream=stream)
52 |         except Exception:
53 |             return None  # prediction errors are swallowed; the caller receives None
54 |
55 | def train(self, **kwargs):
56 | """Function trains models but raises an error as FastSAM models do not support training."""
57 | raise NotImplementedError("Currently, the training codes are on the way.")
58 |
59 | def val(self, **kwargs):
60 | """Run validation given dataset."""
61 | overrides = dict(task='segment', mode='val')
62 | overrides.update(kwargs) # prefer kwargs
63 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
64 | args.imgsz = check_imgsz(args.imgsz, max_dim=1)
65 |         validator = FastSAM(args=args)  # NOTE: expects a validator (see the commented FastSAMValidator import in __init__.py)
66 | validator(model=self.model)
67 | self.metrics = validator.metrics
68 | return validator.metrics
69 |
70 | @smart_inference_mode()
71 | def export(self, **kwargs):
72 | """
73 | Export model.
74 |
75 | Args:
76 | **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
77 | """
78 | overrides = dict(task='detect')
79 | overrides.update(kwargs)
80 | overrides['mode'] = 'export'
81 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
82 | args.task = self.task
83 | if args.imgsz == DEFAULT_CFG.imgsz:
84 | args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
85 | if args.batch == DEFAULT_CFG.batch:
86 | args.batch = 1 # default to 1 if not modified
87 | return Exporter(overrides=args)(model=self.model)
88 |
89 | def info(self, detailed=False, verbose=True):
90 | """
91 | Logs model info.
92 |
93 | Args:
94 | detailed (bool): Show detailed information about model.
95 | verbose (bool): Controls verbosity.
96 | """
97 | return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
98 |
99 | def __call__(self, source=None, stream=False, **kwargs):
100 | """Calls the 'predict' function with given arguments to perform object detection."""
101 | return self.predict(source, stream, **kwargs)
102 |
103 | def __getattr__(self, attr):
104 | """Raises error if object has no requested attribute."""
105 | name = self.__class__.__name__
106 | raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
107 |
--------------------------------------------------------------------------------
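
FastSAM.predict above copies the model overrides, defaults conf to 0.25, and routes inference through FastSAMPredictor; it returns None if the predictor raises. A minimal sketch of calling it with explicit overrides, assuming a local 'FastSAM.pt' checkpoint (the values mirror segpredict.py further below):

# Sketch: predict() keyword arguments become FastSAMPredictor overrides.
from fastsam import FastSAM

model = FastSAM('FastSAM.pt')
results = model.predict(
    './images/dogs.jpg',
    device='cpu',
    retina_masks=True,
    imgsz=1024,
    conf=0.4,       # overrides the 0.25 default set inside predict()
    iou=0.9,
)
if results is not None:  # predict() returns None when the predictor raises
    print(len(results), results[0].masks.data.shape)
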
/fastsam/predict.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ultralytics.yolo.engine.results import Results
4 | from ultralytics.yolo.utils import DEFAULT_CFG, ops
5 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor
6 | from .utils import bbox_iou
7 |
8 | class FastSAMPredictor(DetectionPredictor):
9 |
10 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
11 | super().__init__(cfg, overrides, _callbacks)
12 | self.args.task = 'segment'
13 |
14 | def postprocess(self, preds, img, orig_imgs):
15 | """TODO: filter by classes."""
16 | p = ops.non_max_suppression(preds[0],
17 | self.args.conf,
18 | self.args.iou,
19 | agnostic=self.args.agnostic_nms,
20 | max_det=self.args.max_det,
21 | nc=len(self.model.names),
22 | classes=self.args.classes)
23 |
24 | results = []
25 | if len(p) == 0 or len(p[0]) == 0:
26 | print("No object detected.")
27 | return results
28 |
29 | full_box = torch.zeros_like(p[0][0])
30 | full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
31 | full_box = full_box.view(1, -1)
32 | critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
33 | if critical_iou_index.numel() != 0:
34 | full_box[0][4] = p[0][critical_iou_index][:,4]
35 | full_box[0][6:] = p[0][critical_iou_index][:,6:]
36 | p[0][critical_iou_index] = full_box
37 |
38 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
39 | for i, pred in enumerate(p):
40 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
41 | path = self.batch[0]
42 | img_path = path[i] if isinstance(path, list) else path
43 | if not len(pred): # save empty boxes
44 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
45 | continue
46 | if self.args.retina_masks:
47 | if not isinstance(orig_imgs, torch.Tensor):
48 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
49 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
50 | else:
51 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
52 | if not isinstance(orig_imgs, torch.Tensor):
53 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
54 | results.append(
55 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
56 | return results
57 |
--------------------------------------------------------------------------------
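
FastSAMPredictor.postprocess runs NMS and then builds a "full-image box": any detection whose box has IoU > 0.9 with the whole frame is replaced by a box spanning the image, keeping its confidence and mask coefficients. The toy sketch below reproduces just that substitution step with made-up tensors (two fake detections with 2 mask coefficients instead of the real 32); bbox_iou is the helper from fastsam/utils.py shown further below:

# Toy sketch of the full-image-box substitution (made-up tensors, not real model output).
# A prediction row is laid out as [x1, y1, x2, y2, conf, cls, mask_coeff_1, mask_coeff_2, ...].
import torch
from fastsam.utils import bbox_iou

img_h, img_w = 640, 640
p0 = torch.tensor([
    [5.0, 5.0, 635.0, 635.0, 0.90, 0.0, 0.1, 0.2],     # nearly covers the frame
    [100.0, 100.0, 200.0, 200.0, 0.80, 0.0, 0.3, 0.4],
])

full_box = torch.zeros_like(p0[0])
full_box[2], full_box[3], full_box[4], full_box[6:] = img_w, img_h, 1.0, 1.0
full_box = full_box.view(1, -1)

# Detections almost as large as the frame are snapped to the exact full-image box,
# keeping their confidence (col 4) and mask coefficients (cols 6:).
# Note: bbox_iou also snaps near-border boxes to the frame in place via adjust_bboxes_to_image_border.
critical = bbox_iou(full_box[0][:4], p0[:, :4], iou_thres=0.9, image_shape=(img_h, img_w))
if critical.numel() != 0:
    full_box[0][4] = p0[critical][:, 4]
    full_box[0][6:] = p0[critical][:, 6:]
    p0[critical] = full_box
print(p0[0])  # the first row now spans the whole image
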
/fastsam/prompt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import torch
7 | from .utils import image_to_np_ndarray
8 | from PIL import Image
9 |
10 | try:
11 |     import clip  # OpenAI CLIP, used for the text prompt
12 |
13 | except (ImportError, AssertionError, AttributeError):
14 | from ultralytics.yolo.utils.checks import check_requirements
15 |
16 |     check_requirements('git+https://github.com/openai/CLIP.git')  # install CLIP from source if it is missing
17 | import clip
18 |
19 |
20 | class FastSAMPrompt:
21 |
22 | def __init__(self, image, results, device='cuda'):
23 | if isinstance(image, str) or isinstance(image, Image.Image):
24 | image = image_to_np_ndarray(image)
25 | self.device = device
26 | self.results = results
27 | self.img = image
28 |
29 | def _segment_image(self, image, bbox):
30 | if isinstance(image, Image.Image):
31 | image_array = np.array(image)
32 | else:
33 | image_array = image
34 | segmented_image_array = np.zeros_like(image_array)
35 | x1, y1, x2, y2 = bbox
36 | segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
37 | segmented_image = Image.fromarray(segmented_image_array)
38 | black_image = Image.new('RGB', image.size, (255, 255, 255))
39 | # transparency_mask = np.zeros_like((), dtype=np.uint8)
40 | transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
41 | transparency_mask[y1:y2, x1:x2] = 255
42 | transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
43 | black_image.paste(segmented_image, mask=transparency_mask_image)
44 | return black_image
45 |
46 | def _format_results(self, result, filter=0):
47 | annotations = []
48 | n = len(result.masks.data)
49 | for i in range(n):
50 | annotation = {}
51 | mask = result.masks.data[i] == 1.0
52 |
53 | if torch.sum(mask) < filter:
54 | continue
55 | annotation['id'] = i
56 | annotation['segmentation'] = mask.cpu().numpy()
57 | annotation['bbox'] = result.boxes.data[i]
58 | annotation['score'] = result.boxes.conf[i]
59 | annotation['area'] = annotation['segmentation'].sum()
60 | annotations.append(annotation)
61 | return annotations
62 |
63 |     def filter_masks(annotations):  # filter overlapping masks (note: defined without 'self')
64 | annotations.sort(key=lambda x: x['area'], reverse=True)
65 | to_remove = set()
66 | for i in range(0, len(annotations)):
67 | a = annotations[i]
68 | for j in range(i + 1, len(annotations)):
69 | b = annotations[j]
70 | if i != j and j not in to_remove:
71 | # check if
72 | if b['area'] < a['area']:
73 | if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
74 | to_remove.add(j)
75 |
76 | return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
77 |
78 | def _get_bbox_from_mask(self, mask):
79 | mask = mask.astype(np.uint8)
80 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
81 | x1, y1, w, h = cv2.boundingRect(contours[0])
82 | x2, y2 = x1 + w, y1 + h
83 | if len(contours) > 1:
84 | for b in contours:
85 | x_t, y_t, w_t, h_t = cv2.boundingRect(b)
86 | # Merge multiple bounding boxes into one.
87 | x1 = min(x1, x_t)
88 | y1 = min(y1, y_t)
89 | x2 = max(x2, x_t + w_t)
90 | y2 = max(y2, y_t + h_t)
91 | h = y2 - y1
92 | w = x2 - x1
93 | return [x1, y1, x2, y2]
94 |
95 | def plot_to_result(self,
96 | annotations,
97 | bboxes=None,
98 | points=None,
99 | point_label=None,
100 | mask_random_color=True,
101 | better_quality=True,
102 | retina=False,
103 | withContours=True) -> np.ndarray:
104 | if isinstance(annotations[0], dict):
105 | annotations = [annotation['segmentation'] for annotation in annotations]
106 | image = self.img
107 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
108 | original_h = image.shape[0]
109 | original_w = image.shape[1]
110 | if sys.platform == "darwin":
111 | plt.switch_backend("TkAgg")
112 | plt.figure(figsize=(original_w / 100, original_h / 100))
113 | # Add subplot with no margin.
114 | plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
115 | plt.margins(0, 0)
116 | plt.gca().xaxis.set_major_locator(plt.NullLocator())
117 | plt.gca().yaxis.set_major_locator(plt.NullLocator())
118 |
119 | plt.imshow(image)
120 | if better_quality:
121 | if isinstance(annotations[0], torch.Tensor):
122 | annotations = np.array(annotations.cpu())
123 | for i, mask in enumerate(annotations):
124 | mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
125 | annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
126 | if self.device == 'cpu':
127 | annotations = np.array(annotations)
128 | self.fast_show_mask(
129 | annotations,
130 | plt.gca(),
131 | random_color=mask_random_color,
132 | bboxes=bboxes,
133 | points=points,
134 | pointlabel=point_label,
135 | retinamask=retina,
136 | target_height=original_h,
137 | target_width=original_w,
138 | )
139 | else:
140 | if isinstance(annotations[0], np.ndarray):
141 | annotations = torch.from_numpy(annotations)
142 | self.fast_show_mask_gpu(
143 | annotations,
144 | plt.gca(),
145 | random_color=mask_random_color,
146 | bboxes=bboxes,
147 | points=points,
148 | pointlabel=point_label,
149 | retinamask=retina,
150 | target_height=original_h,
151 | target_width=original_w,
152 | )
153 | if isinstance(annotations, torch.Tensor):
154 | annotations = annotations.cpu().numpy()
155 | if withContours:
156 | contour_all = []
157 | temp = np.zeros((original_h, original_w, 1))
158 | for i, mask in enumerate(annotations):
159 | if type(mask) == dict:
160 | mask = mask['segmentation']
161 | annotation = mask.astype(np.uint8)
162 | if not retina:
163 | annotation = cv2.resize(
164 | annotation,
165 | (original_w, original_h),
166 | interpolation=cv2.INTER_NEAREST,
167 | )
168 | contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
169 | for contour in contours:
170 | contour_all.append(contour)
171 | cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
172 | color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
173 | contour_mask = temp / 255 * color.reshape(1, 1, -1)
174 | plt.imshow(contour_mask)
175 |
176 | plt.axis('off')
177 | fig = plt.gcf()
178 | plt.draw()
179 |
180 | try:
181 | buf = fig.canvas.tostring_rgb()
182 | except AttributeError:
183 | fig.canvas.draw()
184 | buf = fig.canvas.tostring_rgb()
185 | cols, rows = fig.canvas.get_width_height()
186 | img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
187 | result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
188 | plt.close()
189 | return result
190 |
191 |     # Remark for refactoring: a function should ideally do one thing only; saving the image and plotting should be separated, and they do not need to be class methods but could be standalone utility functions that users can chain in their own scripts for more fine-grained control.
192 | def plot(self,
193 | annotations,
194 | output_path,
195 | bboxes=None,
196 | points=None,
197 | point_label=None,
198 | mask_random_color=True,
199 | better_quality=True,
200 | retina=False,
201 | withContours=True):
202 | if len(annotations) == 0:
203 | return None
204 | result = self.plot_to_result(
205 | annotations,
206 | bboxes,
207 | points,
208 | point_label,
209 | mask_random_color,
210 | better_quality,
211 | retina,
212 | withContours,
213 | )
214 |
215 | path = os.path.dirname(os.path.abspath(output_path))
216 | if not os.path.exists(path):
217 | os.makedirs(path)
218 | result = result[:, :, ::-1]
219 | cv2.imwrite(output_path, result)
220 |
221 | # CPU post process
222 | def fast_show_mask(
223 | self,
224 | annotation,
225 | ax,
226 | random_color=False,
227 | bboxes=None,
228 | points=None,
229 | pointlabel=None,
230 | retinamask=True,
231 | target_height=960,
232 | target_width=960,
233 | ):
234 | msak_sum = annotation.shape[0]
235 | height = annotation.shape[1]
236 | weight = annotation.shape[2]
237 | #Sort annotations based on area.
238 | areas = np.sum(annotation, axis=(1, 2))
239 | sorted_indices = np.argsort(areas)
240 | annotation = annotation[sorted_indices]
241 |
242 | index = (annotation != 0).argmax(axis=0)
243 | if random_color:
244 | color = np.random.random((msak_sum, 1, 1, 3))
245 | else:
246 | color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
247 | transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
248 | visual = np.concatenate([color, transparency], axis=-1)
249 | mask_image = np.expand_dims(annotation, -1) * visual
250 |
251 | show = np.zeros((height, weight, 4))
252 | h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
253 | indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
254 | # Use vectorized indexing to update the values of 'show'.
255 | show[h_indices, w_indices, :] = mask_image[indices]
256 | if bboxes is not None:
257 | for bbox in bboxes:
258 | x1, y1, x2, y2 = bbox
259 | ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
260 | # draw point
261 | if points is not None:
262 | plt.scatter(
263 | [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
264 | [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
265 | s=20,
266 | c='y',
267 | )
268 | plt.scatter(
269 | [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
270 | [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
271 | s=20,
272 | c='m',
273 | )
274 |
275 | if not retinamask:
276 | show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
277 | ax.imshow(show)
278 |
279 | def fast_show_mask_gpu(
280 | self,
281 | annotation,
282 | ax,
283 | random_color=False,
284 | bboxes=None,
285 | points=None,
286 | pointlabel=None,
287 | retinamask=True,
288 | target_height=960,
289 | target_width=960,
290 | ):
291 | msak_sum = annotation.shape[0]
292 | height = annotation.shape[1]
293 | weight = annotation.shape[2]
294 | areas = torch.sum(annotation, dim=(1, 2))
295 | sorted_indices = torch.argsort(areas, descending=False)
296 | annotation = annotation[sorted_indices]
297 | # Find the index of the first non-zero value at each position.
298 | index = (annotation != 0).to(torch.long).argmax(dim=0)
299 | if random_color:
300 | color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
301 | else:
302 | color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([
303 | 30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
304 | transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
305 | visual = torch.cat([color, transparency], dim=-1)
306 | mask_image = torch.unsqueeze(annotation, -1) * visual
307 | # Select data according to the index. The index indicates which batch's data to choose at each position, converting the mask_image into a single batch form.
308 | show = torch.zeros((height, weight, 4)).to(annotation.device)
309 | try:
310 | h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij')
311 | except:
312 | h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
313 | indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
314 | # Use vectorized indexing to update the values of 'show'.
315 | show[h_indices, w_indices, :] = mask_image[indices]
316 | show_cpu = show.cpu().numpy()
317 | if bboxes is not None:
318 | for bbox in bboxes:
319 | x1, y1, x2, y2 = bbox
320 | ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
321 | # draw point
322 | if points is not None:
323 | plt.scatter(
324 | [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
325 | [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
326 | s=20,
327 | c='y',
328 | )
329 | plt.scatter(
330 | [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
331 | [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
332 | s=20,
333 | c='m',
334 | )
335 | if not retinamask:
336 | show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
337 | ax.imshow(show_cpu)
338 |
339 | # clip
340 | @torch.no_grad()
341 | def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
342 | preprocessed_images = [preprocess(image).to(device) for image in elements]
343 | tokenized_text = clip.tokenize([search_text]).to(device)
344 | stacked_images = torch.stack(preprocessed_images)
345 | image_features = model.encode_image(stacked_images)
346 | text_features = model.encode_text(tokenized_text)
347 | image_features /= image_features.norm(dim=-1, keepdim=True)
348 | text_features /= text_features.norm(dim=-1, keepdim=True)
349 | probs = 100.0 * image_features @ text_features.T
350 | return probs[:, 0].softmax(dim=0)
351 |
352 | def _crop_image(self, format_results):
353 |
354 | image = Image.fromarray(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB))
355 | ori_w, ori_h = image.size
356 | annotations = format_results
357 | mask_h, mask_w = annotations[0]['segmentation'].shape
358 | if ori_w != mask_w or ori_h != mask_h:
359 | image = image.resize((mask_w, mask_h))
360 | cropped_boxes = []
361 | cropped_images = []
362 | not_crop = []
363 | filter_id = []
364 | # annotations, _ = filter_masks(annotations)
365 | # filter_id = list(_)
366 | for _, mask in enumerate(annotations):
367 | if np.sum(mask['segmentation']) <= 100:
368 | filter_id.append(_)
369 | continue
370 |             bbox = self._get_bbox_from_mask(mask['segmentation'])  # bbox of the mask
371 | cropped_boxes.append(self._segment_image(image, bbox))
372 | # cropped_boxes.append(segment_image(image,mask["segmentation"]))
373 | cropped_images.append(bbox) # Save the bounding box of the cropped image.
374 |
375 | return cropped_boxes, cropped_images, not_crop, filter_id, annotations
376 |
377 | def box_prompt(self, bbox=None, bboxes=None):
378 |         if self.results is None:
379 | return []
380 | assert bbox or bboxes
381 | if bboxes is None:
382 | bboxes = [bbox]
383 | max_iou_index = []
384 | for bbox in bboxes:
385 | assert (bbox[2] != 0 and bbox[3] != 0)
386 | masks = self.results[0].masks.data
387 | target_height = self.img.shape[0]
388 | target_width = self.img.shape[1]
389 | h = masks.shape[1]
390 | w = masks.shape[2]
391 | if h != target_height or w != target_width:
392 | bbox = [
393 | int(bbox[0] * w / target_width),
394 | int(bbox[1] * h / target_height),
395 | int(bbox[2] * w / target_width),
396 | int(bbox[3] * h / target_height), ]
397 | bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
398 | bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
399 | bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
400 | bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
401 |
402 | # IoUs = torch.zeros(len(masks), dtype=torch.float32)
403 | bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
404 |
405 | masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
406 | orig_masks_area = torch.sum(masks, dim=(1, 2))
407 |
408 | union = bbox_area + orig_masks_area - masks_area
409 | IoUs = masks_area / union
410 | max_iou_index.append(int(torch.argmax(IoUs)))
411 | max_iou_index = list(set(max_iou_index))
412 | return np.array(masks[max_iou_index].cpu().numpy())
413 |
414 | def point_prompt(self, points, pointlabel): # numpy
415 |         if self.results is None:
416 | return []
417 | masks = self._format_results(self.results[0], 0)
418 | target_height = self.img.shape[0]
419 | target_width = self.img.shape[1]
420 | h = masks[0]['segmentation'].shape[0]
421 | w = masks[0]['segmentation'].shape[1]
422 | if h != target_height or w != target_width:
423 | points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
424 | onemask = np.zeros((h, w))
425 | masks = sorted(masks, key=lambda x: x['area'], reverse=True)
426 | for i, annotation in enumerate(masks):
427 | if type(annotation) == dict:
428 | mask = annotation['segmentation']
429 | else:
430 | mask = annotation
431 | for i, point in enumerate(points):
432 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
433 | onemask[mask] = 1
434 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
435 | onemask[mask] = 0
436 | onemask = onemask >= 1
437 | return np.array([onemask])
438 |
439 | def text_prompt(self, text):
440 |         if self.results is None:
441 | return []
442 | format_results = self._format_results(self.results[0], 0)
443 | cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
444 | clip_model, preprocess = clip.load('ViT-B/32', device=self.device)
445 | scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
446 | max_idx = scores.argsort()
447 | max_idx = max_idx[-1]
448 | max_idx += sum(np.array(filter_id) <= int(max_idx))
449 | return np.array([annotations[max_idx]['segmentation']])
450 |
451 | def everything_prompt(self):
452 |         if self.results is None:
453 | return []
454 | return self.results[0].masks.data
455 |
456 |
--------------------------------------------------------------------------------
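
Tying the prompt class together: plot_to_result renders the masks and returns the image as a NumPy array, while plot additionally writes it to disk. A minimal sketch using a box prompt and keeping the rendered array in memory, assuming a local 'FastSAM.pt' checkpoint, the bundled './images/dogs.jpg', and an existing './output/' directory (the output filename is just an example):

# Sketch: box prompt + in-memory rendering via plot_to_result.
import cv2
from fastsam import FastSAM, FastSAMPrompt

model = FastSAM('FastSAM.pt')
results = model('./images/dogs.jpg', device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)

prompt_process = FastSAMPrompt('./images/dogs.jpg', results, device='cpu')
ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])   # [x1, y1, x2, y2]
rendered = prompt_process.plot_to_result(annotations=ann)    # np.ndarray; nothing written to disk
cv2.imwrite('./output/box_prompt.jpg', rendered)             # example output path
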
/fastsam/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from PIL import Image
4 |
5 |
6 | def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
7 | '''Adjust bounding boxes to stick to image border if they are within a certain threshold.
8 | Args:
9 | boxes: (n, 4)
10 | image_shape: (height, width)
11 | threshold: pixel threshold
12 | Returns:
13 | adjusted_boxes: adjusted bounding boxes
14 | '''
15 |
16 | # Image dimensions
17 | h, w = image_shape
18 |
19 | # Adjust boxes
20 | boxes[:, 0] = torch.where(boxes[:, 0] < threshold, torch.tensor(
21 | 0, dtype=torch.float, device=boxes.device), boxes[:, 0]) # x1
22 | boxes[:, 1] = torch.where(boxes[:, 1] < threshold, torch.tensor(
23 | 0, dtype=torch.float, device=boxes.device), boxes[:, 1]) # y1
24 | boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, torch.tensor(
25 | w, dtype=torch.float, device=boxes.device), boxes[:, 2]) # x2
26 | boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, torch.tensor(
27 | h, dtype=torch.float, device=boxes.device), boxes[:, 3]) # y2
28 |
29 | return boxes
30 |
31 |
32 |
33 | def convert_box_xywh_to_xyxy(box):
34 | x1 = box[0]
35 | y1 = box[1]
36 | x2 = box[0] + box[2]
37 | y2 = box[1] + box[3]
38 | return [x1, y1, x2, y2]
39 |
40 |
41 | def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
42 | '''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
43 | Args:
44 | box1: (4, )
45 | boxes: (n, 4)
46 | Returns:
47 | high_iou_indices: Indices of boxes with IoU > thres
48 | '''
49 | boxes = adjust_bboxes_to_image_border(boxes, image_shape)
50 | # obtain coordinates for intersections
51 | x1 = torch.max(box1[0], boxes[:, 0])
52 | y1 = torch.max(box1[1], boxes[:, 1])
53 | x2 = torch.min(box1[2], boxes[:, 2])
54 | y2 = torch.min(box1[3], boxes[:, 3])
55 |
56 | # compute the area of intersection
57 | intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
58 |
59 | # compute the area of both individual boxes
60 | box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
61 | box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
62 |
63 | # compute the area of union
64 | union = box1_area + box2_area - intersection
65 |
66 | # compute the IoU
67 | iou = intersection / union # Should be shape (n, )
68 | if raw_output:
69 | if iou.numel() == 0:
70 | return 0
71 | return iou
72 |
73 | # get indices of boxes with IoU > thres
74 | high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
75 |
76 | return high_iou_indices
77 |
78 |
79 | def image_to_np_ndarray(image):
80 | if type(image) is str:
81 | return np.array(Image.open(image))
82 | elif issubclass(type(image), Image.Image):
83 | return np.array(image)
84 | elif type(image) is np.ndarray:
85 | return image
86 | return None
87 |
--------------------------------------------------------------------------------
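
A small self-contained check of the helpers above, with made-up boxes: near-border boxes are snapped to the frame before IoU is computed, which is what lets postprocess treat an almost-full-frame detection as the full image:

# Toy check of adjust_bboxes_to_image_border, bbox_iou and convert_box_xywh_to_xyxy.
import torch
from fastsam.utils import adjust_bboxes_to_image_border, bbox_iou, convert_box_xywh_to_xyxy

boxes = torch.tensor([[10., 15., 630., 635.],      # within 20 px of every border
                      [100., 100., 200., 200.]])
print(adjust_bboxes_to_image_border(boxes.clone(), image_shape=(640, 640), threshold=20))
# -> first box snapped to [0, 0, 640, 640]; second box unchanged

full_frame = torch.tensor([0., 0., 640., 640.])
print(bbox_iou(full_frame, boxes.clone(), iou_thres=0.9, image_shape=(640, 640)))  # tensor([0])

print(convert_box_xywh_to_xyxy([200, 200, 100, 150]))  # [200, 200, 300, 350]
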
/images/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/images/cat.jpg
--------------------------------------------------------------------------------
/images/dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/images/dogs.jpg
--------------------------------------------------------------------------------
/output/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/output/cat.jpg
--------------------------------------------------------------------------------
/output/dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/output/dogs.jpg
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | # Prediction interface for Cog ⚙️
2 | # https://github.com/replicate/cog/blob/main/docs/python.md
3 | # Thanks to chenxwh.
4 |
5 | import argparse
6 | import cv2
7 | import shutil
8 | import ast
9 | from cog import BasePredictor, Input, Path
10 | from ultralytics import YOLO
11 | from utils.tools import *
12 |
13 |
14 | class Predictor(BasePredictor):
15 | def setup(self):
16 | """Load the model into memory to make running multiple predictions efficient"""
17 | self.models = {k: YOLO(f"{k}.pt") for k in ["FastSAM-s", "FastSAM-x"]}
18 |
19 | def predict(
20 | self,
21 | input_image: Path = Input(description="Input image"),
22 | model_name: str = Input(
23 | description="choose a model",
24 | choices=["FastSAM-x", "FastSAM-s"],
25 | default="FastSAM-x",
26 | ),
27 | iou: float = Input(
28 | description="iou threshold for filtering the annotations", default=0.7
29 | ),
30 | text_prompt: str = Input(
31 | description='use text prompt eg: "a black dog"', default=None
32 | ),
33 | conf: float = Input(description="object confidence threshold", default=0.25),
34 | retina: bool = Input(
35 | description="draw high-resolution segmentation masks", default=True
36 | ),
37 | box_prompt: str = Input(default="[0,0,0,0]", description="[x,y,w,h]"),
38 | point_prompt: str = Input(default="[[0,0]]", description="[[x1,y1],[x2,y2]]"),
39 | point_label: str = Input(default="[0]", description="[1,0] 0:background, 1:foreground"),
40 | withContours: bool = Input(
41 | description="draw the edges of the masks", default=False
42 | ),
43 | better_quality: bool = Input(
44 | description="better quality using morphologyEx", default=False
45 | ),
46 | ) -> Path:
47 | """Run a single prediction on the model"""
48 |
49 | # default params
50 |
51 | out_path = "output"
52 | if os.path.exists(out_path):
53 | shutil.rmtree(out_path)
54 | os.makedirs(out_path, exist_ok=True)
55 |
56 | device = torch.device(
57 | "cuda"
58 | if torch.cuda.is_available()
59 | else "mps"
60 | if torch.backends.mps.is_available()
61 | else "cpu"
62 | )
63 |
64 | args = argparse.Namespace(
65 | better_quality=better_quality,
66 | box_prompt=box_prompt,
67 | conf=conf,
68 | device=device,
69 | img_path=str(input_image),
70 | imgsz=1024,
71 | iou=iou,
72 | model_path="FastSAM-x.pt",
73 | output=out_path,
74 | point_label=point_label,
75 | point_prompt=point_prompt,
76 | randomcolor=True,
77 | retina=retina,
78 | text_prompt=text_prompt,
79 | withContours=withContours,
80 | )
81 | args.point_prompt = ast.literal_eval(args.point_prompt)
82 | args.box_prompt = ast.literal_eval(args.box_prompt)
83 | args.point_label = ast.literal_eval(args.point_label)
84 |
85 | model = self.models[model_name]
86 |
87 | results = model(
88 | str(input_image),
89 | imgsz=args.imgsz,
90 | device=args.device,
91 | retina_masks=args.retina,
92 | iou=args.iou,
93 | conf=args.conf,
94 | max_det=100,
95 | )
96 |
97 | if args.box_prompt[2] != 0 and args.box_prompt[3] != 0:
98 | annotations = prompt(results, args, box=True)
99 | annotations = np.array([annotations])
100 | fast_process(
101 | annotations=annotations,
102 | args=args,
103 | mask_random_color=args.randomcolor,
104 | bbox=convert_box_xywh_to_xyxy(args.box_prompt),
105 | )
106 |
107 |         elif args.text_prompt is not None:
108 | results = format_results(results[0], 0)
109 | annotations = prompt(results, args, text=True)
110 | annotations = np.array([annotations])
111 | fast_process(
112 | annotations=annotations, args=args, mask_random_color=args.randomcolor
113 | )
114 |
115 | elif args.point_prompt[0] != [0, 0]:
116 | results = format_results(results[0], 0)
117 | annotations = prompt(results, args, point=True)
118 | # list to numpy
119 | annotations = np.array([annotations])
120 | fast_process(
121 | annotations=annotations,
122 | args=args,
123 | mask_random_color=args.randomcolor,
124 | points=args.point_prompt,
125 | )
126 |
127 | else:
128 | fast_process(
129 | annotations=results[0].masks.data,
130 | args=args,
131 | mask_random_color=args.randomcolor,
132 | )
133 |
134 |         out = "/tmp/out.png"
135 | shutil.copy(os.path.join(out_path, os.listdir(out_path)[0]), out)
136 |
137 | return Path(out)
138 |
139 |
140 | def prompt(results, args, box=None, point=None, text=None):
141 | ori_img = cv2.imread(args.img_path)
142 | ori_h = ori_img.shape[0]
143 | ori_w = ori_img.shape[1]
144 | if box:
145 | mask, idx = box_prompt(
146 | results[0].masks.data,
147 | convert_box_xywh_to_xyxy(args.box_prompt),
148 | ori_h,
149 | ori_w,
150 | )
151 | elif point:
152 | mask, idx = point_prompt(
153 | results, args.point_prompt, args.point_label, ori_h, ori_w
154 | )
155 | elif text:
156 | mask, idx = text_prompt(results, args.text_prompt, args.img_path, args.device)
157 | else:
158 | return None
159 | return mask
160 |
--------------------------------------------------------------------------------
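
The Cog predictor accepts the box, point, and label prompts as strings and parses them with ast.literal_eval before use; boxes arrive in [x, y, w, h] form and are converted to [x1, y1, x2, y2]. A small sketch of that parsing convention with example values (assuming the repo root is on PYTHONPATH; importing utils.tools also pulls in clip and matplotlib):

# Sketch: how the string prompts are parsed (example values only).
import ast
from utils.tools import convert_box_xywh_to_xyxy

box = ast.literal_eval("[100,100,200,150]")        # "[x,y,w,h]" -> [100, 100, 200, 150]
print(convert_box_xywh_to_xyxy(box))               # -> [100, 100, 300, 250]

points = ast.literal_eval("[[620,360]]")           # "[[x1,y1],[x2,y2]]"
labels = ast.literal_eval("[1]")                   # 1: foreground, 0: background
print(points, labels)
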
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Base-----------------------------------
2 | matplotlib>=3.2.2
3 | opencv-python>=4.6.0
4 | Pillow>=7.1.2
5 | PyYAML>=5.3.1
6 | requests>=2.23.0
7 | scipy>=1.4.1
8 | torch>=1.7.0
9 | torchvision>=0.8.1
10 | tqdm>=4.64.0
11 |
12 | pandas>=1.1.4
13 | seaborn>=0.11.0
14 |
15 | gradio<=3.35.2
16 |
17 | # Ultralytics-----------------------------------
18 | ultralytics<=8.0.122
19 |
20 | openai-clip
--------------------------------------------------------------------------------
/segpredict.py:
--------------------------------------------------------------------------------
1 | from fastsam import FastSAM, FastSAMPrompt
2 | import torch
3 |
4 | model = FastSAM('FastSAM.pt')
5 | IMAGE_PATH = './images/dogs.jpg'
6 | DEVICE = torch.device(
7 | "cuda"
8 | if torch.cuda.is_available()
9 | else "mps"
10 | if torch.backends.mps.is_available()
11 | else "cpu"
12 | )
13 | everything_results = model(
14 | IMAGE_PATH,
15 | device=DEVICE,
16 | retina_masks=True,
17 | imgsz=1024,
18 | conf=0.4,
19 | iou=0.9,
20 | )
21 | prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE)
22 |
23 | # # everything prompt
24 | ann = prompt_process.everything_prompt()
25 |
26 | # # bbox prompt
27 | # # bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
28 | # bboxes default shape [[0,0,0,0]] -> [[x1,y1,x2,y2]]
29 | # ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])
30 | # ann = prompt_process.box_prompt(bboxes=[[200, 200, 300, 300], [500, 500, 600, 600]])
31 |
32 | # # text prompt
33 | # ann = prompt_process.text_prompt(text='a photo of a dog')
34 |
35 | # # point prompt
36 | # # points default [[0,0]] [[x1,y1],[x2,y2]]
37 | # # point_label default [0] [1,0] 0:background, 1:foreground
38 | # ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1])
39 |
40 | # point prompt
41 | # points default [[0,0]] [[x1,y1],[x2,y2]]
42 | # point_label default [0] [1,0] 0:background, 1:foreground
43 | ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1])
44 |
45 | prompt_process.plot(
46 | annotations=ann,
47 |     output_path='./output/dogs.jpg',
48 | mask_random_color=True,
49 | better_quality=True,
50 | retina=False,
51 | withContours=True,
52 | )
53 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | """The setup script."""
4 |
5 | import io
6 | from os import path as op
7 | from setuptools import setup, find_packages
8 |
9 | with open("README.md", encoding="utf-8") as readme_file:
10 | readme = readme_file.read()
11 |
12 | here = op.abspath(op.dirname(__file__))
13 |
14 | # get the dependencies and installs
15 | with io.open(op.join(here, "requirements.txt"), encoding="utf-8") as f:
16 | all_reqs = f.read().split("\n")
17 |
18 | install_requires = [x.strip() for x in all_reqs if "git+" not in x]
19 | dependency_links = [x.strip().replace("git+", "") for x in all_reqs if "git+" in x]
20 |
21 | extras_requires = {}
22 |
23 |
24 | requirements = []
25 |
26 | setup_requirements = []
27 |
28 | test_requirements = []
29 |
30 | setup(
31 | author="Qiusheng Wu",
32 | author_email="giswqs@gmail.com",
33 | python_requires=">=3.8",
34 | classifiers=[
35 | "Intended Audience :: Developers",
36 | "License :: OSI Approved :: Apache Software License",
37 | "Natural Language :: English",
38 | "Programming Language :: Python :: 3",
39 | "Programming Language :: Python :: 3.8",
40 | "Programming Language :: Python :: 3.9",
41 | "Programming Language :: Python :: 3.10",
42 | "Programming Language :: Python :: 3.11",
43 | ],
44 | description="Fast Segment Anything",
45 | install_requires=install_requires,
46 | extras_require=extras_requires,
47 | dependency_links=dependency_links,
48 | license="Apache Software License",
49 | long_description=readme,
50 | long_description_content_type="text/markdown",
51 | include_package_data=True,
52 | keywords="segment-anything",
53 | name="segment-anything-fast",
54 | version="0.1.2",
55 | packages=find_packages(include=["fastsam", "fastsam.*"]),
56 | setup_requires=setup_requirements,
57 | url="https://github.com/opengeos/FastSAM",
58 | zip_safe=False,
59 | )
60 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opengeos/FastSAM/1c7245d73e84debc8a7f8d4124f6d6c4d0e86700/utils/__init__.py
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import matplotlib.pyplot as plt
4 | import cv2
5 | import torch
6 | import os
7 | import sys
8 | import clip
9 |
10 |
11 | def convert_box_xywh_to_xyxy(box):
12 | if len(box) == 4:
13 | return [box[0], box[1], box[0] + box[2], box[1] + box[3]]
14 | else:
15 | result = []
16 | for b in box:
17 | b = convert_box_xywh_to_xyxy(b)
18 | result.append(b)
19 | return result
20 |
21 |
22 | def segment_image(image, bbox):
23 | image_array = np.array(image)
24 | segmented_image_array = np.zeros_like(image_array)
25 | x1, y1, x2, y2 = bbox
26 | segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
27 | segmented_image = Image.fromarray(segmented_image_array)
28 | black_image = Image.new("RGB", image.size, (255, 255, 255))
29 | # transparency_mask = np.zeros_like((), dtype=np.uint8)
30 | transparency_mask = np.zeros(
31 | (image_array.shape[0], image_array.shape[1]), dtype=np.uint8
32 | )
33 | transparency_mask[y1:y2, x1:x2] = 255
34 | transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
35 | black_image.paste(segmented_image, mask=transparency_mask_image)
36 | return black_image
37 |
38 |
39 | def format_results(result, filter=0):
40 | annotations = []
41 | n = len(result.masks.data)
42 | for i in range(n):
43 | annotation = {}
44 | mask = result.masks.data[i] == 1.0
45 |
46 | if torch.sum(mask) < filter:
47 | continue
48 | annotation["id"] = i
49 | annotation["segmentation"] = mask.cpu().numpy()
50 | annotation["bbox"] = result.boxes.data[i]
51 | annotation["score"] = result.boxes.conf[i]
52 | annotation["area"] = annotation["segmentation"].sum()
53 | annotations.append(annotation)
54 | return annotations
55 |
56 |
57 | def filter_masks(annotations): # filter the overlap mask
58 | annotations.sort(key=lambda x: x["area"], reverse=True)
59 | to_remove = set()
60 | for i in range(0, len(annotations)):
61 | a = annotations[i]
62 | for j in range(i + 1, len(annotations)):
63 | b = annotations[j]
64 | if i != j and j not in to_remove:
65 | # check if
66 | if b["area"] < a["area"]:
67 | if (a["segmentation"] & b["segmentation"]).sum() / b[
68 | "segmentation"
69 | ].sum() > 0.8:
70 | to_remove.add(j)
71 |
72 | return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
73 |
74 |
75 | def get_bbox_from_mask(mask):
76 | mask = mask.astype(np.uint8)
77 | contours, hierarchy = cv2.findContours(
78 | mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
79 | )
80 | x1, y1, w, h = cv2.boundingRect(contours[0])
81 | x2, y2 = x1 + w, y1 + h
82 | if len(contours) > 1:
83 | for b in contours:
84 | x_t, y_t, w_t, h_t = cv2.boundingRect(b)
85 |             # Merge multiple bounding boxes into one.
86 | x1 = min(x1, x_t)
87 | y1 = min(y1, y_t)
88 | x2 = max(x2, x_t + w_t)
89 | y2 = max(y2, y_t + h_t)
90 | h = y2 - y1
91 | w = x2 - x1
92 | return [x1, y1, x2, y2]
93 |
94 |
95 | def fast_process(
96 | annotations, args, mask_random_color, bbox=None, points=None, edges=False
97 | ):
98 | if isinstance(annotations[0], dict):
99 | annotations = [annotation["segmentation"] for annotation in annotations]
100 | result_name = os.path.basename(args.img_path)
101 | image = cv2.imread(args.img_path)
102 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
103 | original_h = image.shape[0]
104 | original_w = image.shape[1]
105 | if sys.platform == "darwin":
106 | plt.switch_backend("TkAgg")
107 | plt.figure(figsize=(original_w/100, original_h/100))
108 | # Add subplot with no margin.
109 | plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
110 | plt.margins(0, 0)
111 | plt.gca().xaxis.set_major_locator(plt.NullLocator())
112 | plt.gca().yaxis.set_major_locator(plt.NullLocator())
113 | plt.imshow(image)
114 |     if args.better_quality:
115 | if isinstance(annotations[0], torch.Tensor):
116 | annotations = np.array(annotations.cpu())
117 | for i, mask in enumerate(annotations):
118 | mask = cv2.morphologyEx(
119 | mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
120 | )
121 | annotations[i] = cv2.morphologyEx(
122 | mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
123 | )
124 | if args.device == "cpu":
125 | annotations = np.array(annotations)
126 | fast_show_mask(
127 | annotations,
128 | plt.gca(),
129 | random_color=mask_random_color,
130 | bbox=bbox,
131 | points=points,
132 | point_label=args.point_label,
133 | retinamask=args.retina,
134 | target_height=original_h,
135 | target_width=original_w,
136 | )
137 | else:
138 | if isinstance(annotations[0], np.ndarray):
139 | annotations = torch.from_numpy(annotations)
140 | fast_show_mask_gpu(
141 | annotations,
142 | plt.gca(),
143 | random_color=args.randomcolor,
144 | bbox=bbox,
145 | points=points,
146 | point_label=args.point_label,
147 | retinamask=args.retina,
148 | target_height=original_h,
149 | target_width=original_w,
150 | )
151 | if isinstance(annotations, torch.Tensor):
152 | annotations = annotations.cpu().numpy()
153 |     if args.withContours:
154 | contour_all = []
155 | temp = np.zeros((original_h, original_w, 1))
156 | for i, mask in enumerate(annotations):
157 | if type(mask) == dict:
158 | mask = mask["segmentation"]
159 | annotation = mask.astype(np.uint8)
160 |             if not args.retina:
161 | annotation = cv2.resize(
162 | annotation,
163 | (original_w, original_h),
164 | interpolation=cv2.INTER_NEAREST,
165 | )
166 | contours, hierarchy = cv2.findContours(
167 | annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
168 | )
169 | for contour in contours:
170 | contour_all.append(contour)
171 | cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
172 | color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
173 | contour_mask = temp / 255 * color.reshape(1, 1, -1)
174 | plt.imshow(contour_mask)
175 |
176 | save_path = args.output
177 | if not os.path.exists(save_path):
178 | os.makedirs(save_path)
179 | plt.axis("off")
180 | fig = plt.gcf()
181 | plt.draw()
182 |
183 | try:
184 | buf = fig.canvas.tostring_rgb()
185 | except AttributeError:
186 | fig.canvas.draw()
187 | buf = fig.canvas.tostring_rgb()
188 |
189 | cols, rows = fig.canvas.get_width_height()
190 |     img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
191 | cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
192 |
193 |
194 | # CPU post process
195 | def fast_show_mask(
196 | annotation,
197 | ax,
198 | random_color=False,
199 | bbox=None,
200 | points=None,
201 | point_label=None,
202 | retinamask=True,
203 | target_height=960,
204 | target_width=960,
205 | ):
206 | msak_sum = annotation.shape[0]
207 | height = annotation.shape[1]
208 | weight = annotation.shape[2]
209 |     # Sort annotations by area.
210 | areas = np.sum(annotation, axis=(1, 2))
211 | sorted_indices = np.argsort(areas)
212 | annotation = annotation[sorted_indices]
213 |
214 | index = (annotation != 0).argmax(axis=0)
215 |     if random_color:
216 | color = np.random.random((msak_sum, 1, 1, 3))
217 | else:
218 | color = np.ones((msak_sum, 1, 1, 3)) * np.array(
219 | [30 / 255, 144 / 255, 255 / 255]
220 | )
221 | transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
222 | visual = np.concatenate([color, transparency], axis=-1)
223 | mask_image = np.expand_dims(annotation, -1) * visual
224 |
225 | show = np.zeros((height, weight, 4))
226 | h_indices, w_indices = np.meshgrid(
227 | np.arange(height), np.arange(weight), indexing="ij"
228 | )
229 | indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
230 |     # Use vectorized indexing to update the values of 'show'.
231 | show[h_indices, w_indices, :] = mask_image[indices]
232 | if bbox is not None:
233 | x1, y1, x2, y2 = bbox
234 | ax.add_patch(
235 | plt.Rectangle(
236 | (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
237 | )
238 | )
239 | # draw point
240 | if points is not None:
241 | plt.scatter(
242 | [point[0] for i, point in enumerate(points) if point_label[i] == 1],
243 | [point[1] for i, point in enumerate(points) if point_label[i] == 1],
244 | s=20,
245 | c="y",
246 | )
247 | plt.scatter(
248 | [point[0] for i, point in enumerate(points) if point_label[i] == 0],
249 | [point[1] for i, point in enumerate(points) if point_label[i] == 0],
250 | s=20,
251 | c="m",
252 | )
253 |
254 |     if not retinamask:
255 | show = cv2.resize(
256 | show, (target_width, target_height), interpolation=cv2.INTER_NEAREST
257 | )
258 | ax.imshow(show)
259 |
260 |
261 | def fast_show_mask_gpu(
262 | annotation,
263 | ax,
264 | random_color=False,
265 | bbox=None,
266 | points=None,
267 | point_label=None,
268 | retinamask=True,
269 | target_height=960,
270 | target_width=960,
271 | ):
272 | msak_sum = annotation.shape[0]
273 | height = annotation.shape[1]
274 | weight = annotation.shape[2]
275 | areas = torch.sum(annotation, dim=(1, 2))
276 | sorted_indices = torch.argsort(areas, descending=False)
277 | annotation = annotation[sorted_indices]
278 |     # Find the index of the first non-zero value at each position.
279 | index = (annotation != 0).to(torch.long).argmax(dim=0)
280 |     if random_color:
281 | color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
282 | else:
283 | color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor(
284 | [30 / 255, 144 / 255, 255 / 255]
285 | ).to(annotation.device)
286 | transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
287 | visual = torch.cat([color, transparency], dim=-1)
288 | mask_image = torch.unsqueeze(annotation, -1) * visual
289 |     # Select values by 'index' (which mask to show at each pixel), flattening mask_image into a single image.
290 | show = torch.zeros((height, weight, 4)).to(annotation.device)
291 | h_indices, w_indices = torch.meshgrid(
292 | torch.arange(height), torch.arange(weight), indexing="ij"
293 | )
294 | indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
295 |     # Use vectorized indexing to update the values of 'show'.
296 | show[h_indices, w_indices, :] = mask_image[indices]
297 | show_cpu = show.cpu().numpy()
298 | if bbox is not None:
299 | x1, y1, x2, y2 = bbox
300 | ax.add_patch(
301 | plt.Rectangle(
302 | (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
303 | )
304 | )
305 | # draw point
306 | if points is not None:
307 | plt.scatter(
308 | [point[0] for i, point in enumerate(points) if point_label[i] == 1],
309 | [point[1] for i, point in enumerate(points) if point_label[i] == 1],
310 | s=20,
311 | c="y",
312 | )
313 | plt.scatter(
314 | [point[0] for i, point in enumerate(points) if point_label[i] == 0],
315 | [point[1] for i, point in enumerate(points) if point_label[i] == 0],
316 | s=20,
317 | c="m",
318 | )
319 |     if not retinamask:
320 | show_cpu = cv2.resize(
321 | show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
322 | )
323 | ax.imshow(show_cpu)
324 |
325 |
326 | # clip
327 | @torch.no_grad()
328 | def retriev(
329 |     model, preprocess, elements: list, search_text: str, device
330 | ):
331 | preprocessed_images = [preprocess(image).to(device) for image in elements]
332 | tokenized_text = clip.tokenize([search_text]).to(device)
333 | stacked_images = torch.stack(preprocessed_images)
334 | image_features = model.encode_image(stacked_images)
335 | text_features = model.encode_text(tokenized_text)
336 | image_features /= image_features.norm(dim=-1, keepdim=True)
337 | text_features /= text_features.norm(dim=-1, keepdim=True)
338 | probs = 100.0 * image_features @ text_features.T
339 | return probs[:, 0].softmax(dim=0)
340 |
341 |
342 | def crop_image(annotations, image_like):
343 | if isinstance(image_like, str):
344 | image = Image.open(image_like)
345 | else:
346 | image = image_like
347 | ori_w, ori_h = image.size
348 | mask_h, mask_w = annotations[0]["segmentation"].shape
349 | if ori_w != mask_w or ori_h != mask_h:
350 | image = image.resize((mask_w, mask_h))
351 | cropped_boxes = []
352 | cropped_images = []
353 | not_crop = []
354 | origin_id = []
355 | for _, mask in enumerate(annotations):
356 | if np.sum(mask["segmentation"]) <= 100:
357 | continue
358 | origin_id.append(_)
359 |         bbox = get_bbox_from_mask(mask["segmentation"])  # bbox of the mask
360 |         cropped_boxes.append(segment_image(image, bbox))  # save the cropped image
361 |         # cropped_boxes.append(segment_image(image,mask["segmentation"]))
362 |         cropped_images.append(bbox)  # save the bbox of the cropped image
363 | return cropped_boxes, cropped_images, not_crop, origin_id, annotations
364 |
365 |
366 | def box_prompt(masks, bbox, target_height, target_width):
367 | h = masks.shape[1]
368 | w = masks.shape[2]
369 | if h != target_height or w != target_width:
370 | bbox = [
371 | int(bbox[0] * w / target_width),
372 | int(bbox[1] * h / target_height),
373 | int(bbox[2] * w / target_width),
374 | int(bbox[3] * h / target_height),
375 | ]
376 | bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
377 | bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
378 | bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
379 | bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
380 |
381 | # IoUs = torch.zeros(len(masks), dtype=torch.float32)
382 | bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
383 |
384 | masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
385 | orig_masks_area = torch.sum(masks, dim=(1, 2))
386 |
387 | union = bbox_area + orig_masks_area - masks_area
388 | IoUs = masks_area / union
389 | max_iou_index = torch.argmax(IoUs)
390 |
391 | return masks[max_iou_index].cpu().numpy(), max_iou_index
392 |
393 |
394 | def point_prompt(masks, points, point_label, target_height, target_width):  # numpy processing
395 | h = masks[0]["segmentation"].shape[0]
396 | w = masks[0]["segmentation"].shape[1]
397 | if h != target_height or w != target_width:
398 | points = [
399 | [int(point[0] * w / target_width), int(point[1] * h / target_height)]
400 | for point in points
401 | ]
402 | onemask = np.zeros((h, w))
403 | masks = sorted(masks, key=lambda x: x['area'], reverse=True)
404 | for i, annotation in enumerate(masks):
405 | if type(annotation) == dict:
406 | mask = annotation['segmentation']
407 | else:
408 | mask = annotation
409 | for i, point in enumerate(points):
410 | if mask[point[1], point[0]] == 1 and point_label[i] == 1:
411 | onemask[mask] = 1
412 | if mask[point[1], point[0]] == 1 and point_label[i] == 0:
413 | onemask[mask] = 0
414 | onemask = onemask >= 1
415 | return onemask, 0
416 |
417 |
418 | def text_prompt(annotations, text, img_path, device, wider=False, threshold=0.9):
419 | cropped_boxes, cropped_images, not_crop, origin_id, annotations_ = crop_image(
420 | annotations, img_path
421 | )
422 | clip_model, preprocess = clip.load("ViT-B/32", device=device)
423 | scores = retriev(
424 | clip_model, preprocess, cropped_boxes, text, device=device
425 | )
426 | max_idx = scores.argsort()
427 | max_idx = max_idx[-1]
428 | max_idx = origin_id[int(max_idx)]
429 |
430 | # find the biggest mask which contains the mask with max score
431 | if wider:
432 | mask0 = annotations_[max_idx]["segmentation"]
433 | area0 = np.sum(mask0)
434 | areas = [(i, np.sum(mask["segmentation"])) for i, mask in enumerate(annotations_) if i in origin_id]
435 | areas = sorted(areas, key=lambda area: area[1], reverse=True)
436 | indices = [area[0] for area in areas]
437 | for index in indices:
438 | if index == max_idx or np.sum(annotations_[index]["segmentation"] & mask0) / area0 > threshold:
439 | max_idx = index
440 | break
441 |
442 | return annotations_[max_idx]["segmentation"], max_idx
443 |
--------------------------------------------------------------------------------
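
The prompt helpers above (`box_prompt`, `point_prompt`) operate on plain arrays and tensors, so they can be exercised without loading the FastSAM model. Below is a minimal sketch with synthetic masks; the toy sizes and coordinates, and the assumption that the repository root is on `PYTHONPATH` (so `utils.tools` imports cleanly, which also requires the `clip` dependency used by `text_prompt`), are illustrative rather than part of the repository.

```python
import numpy as np
import torch

from utils.tools import box_prompt, point_prompt  # assumes repo root is on PYTHONPATH

# Two synthetic 100x100 masks: a small square and a large square.
masks = torch.zeros((2, 100, 100))
masks[0, 10:30, 10:30] = 1   # small object, area 400
masks[1, 40:90, 40:90] = 1   # large object, area 2500

# A box around the large object should select mask index 1 (highest IoU with the box).
best_mask, idx = box_prompt(masks, bbox=[35, 35, 95, 95], target_height=100, target_width=100)
print(int(idx), best_mask.sum())     # -> 1, 2500.0

# point_prompt expects the "everything" annotation format: dicts with a
# boolean 'segmentation' array and an 'area' used for sorting.
annotations = [
    {"segmentation": masks[i].numpy().astype(bool), "area": int(masks[i].sum())}
    for i in range(2)
]
onemask, _ = point_prompt(
    annotations, points=[[20, 20]], point_label=[1], target_height=100, target_width=100
)
print(onemask.shape, onemask.sum())  # -> (100, 100), 400 (the small mask containing the point)
```

`text_prompt` follows the same pattern but additionally loads a ViT-B/32 CLIP model and scores each cropped mask against the query text, so it needs a real image path and the CLIP weights.
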
/utils/tools_gradio.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import matplotlib.pyplot as plt
4 | import cv2
5 | import torch
6 |
7 |
8 | def fast_process(
9 | annotations,
10 | image,
11 | device,
12 | scale,
13 | better_quality=False,
14 | mask_random_color=True,
15 | bbox=None,
16 | use_retina=True,
17 | withContours=True,
18 | ):
19 | if isinstance(annotations[0], dict):
20 | annotations = [annotation['segmentation'] for annotation in annotations]
21 |
22 | original_h = image.height
23 | original_w = image.width
24 | if better_quality:
25 | if isinstance(annotations[0], torch.Tensor):
26 | annotations = np.array(annotations.cpu())
27 | for i, mask in enumerate(annotations):
28 | mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
29 | annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
30 | if device == 'cpu':
31 | annotations = np.array(annotations)
32 | inner_mask = fast_show_mask(
33 | annotations,
34 | plt.gca(),
35 | random_color=mask_random_color,
36 | bbox=bbox,
37 | retinamask=use_retina,
38 | target_height=original_h,
39 | target_width=original_w,
40 | )
41 | else:
42 | if isinstance(annotations[0], np.ndarray):
43 | annotations = torch.from_numpy(annotations)
44 | inner_mask = fast_show_mask_gpu(
45 | annotations,
46 | plt.gca(),
47 | random_color=mask_random_color,
48 | bbox=bbox,
49 | retinamask=use_retina,
50 | target_height=original_h,
51 | target_width=original_w,
52 | )
53 | if isinstance(annotations, torch.Tensor):
54 | annotations = annotations.cpu().numpy()
55 |
56 | if withContours:
57 | contour_all = []
58 | temp = np.zeros((original_h, original_w, 1))
59 | for i, mask in enumerate(annotations):
60 |             if isinstance(mask, dict):
61 | mask = mask['segmentation']
62 | annotation = mask.astype(np.uint8)
63 |             if not use_retina:
64 | annotation = cv2.resize(
65 | annotation,
66 | (original_w, original_h),
67 | interpolation=cv2.INTER_NEAREST,
68 | )
69 | contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
70 | for contour in contours:
71 | contour_all.append(contour)
72 | cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
73 | color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
74 | contour_mask = temp / 255 * color.reshape(1, 1, -1)
75 |
76 | image = image.convert('RGBA')
77 | overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
78 | image.paste(overlay_inner, (0, 0), overlay_inner)
79 |
80 | if withContours:
81 | overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
82 | image.paste(overlay_contour, (0, 0), overlay_contour)
83 |
84 | return image
85 |
86 |
87 | # CPU post process
88 | def fast_show_mask(
89 | annotation,
90 | ax,
91 | random_color=False,
92 | bbox=None,
93 | retinamask=True,
94 | target_height=960,
95 | target_width=960,
96 | ):
97 | mask_sum = annotation.shape[0]
98 | height = annotation.shape[1]
99 |     width = annotation.shape[2]
100 |     # sort the annotations by mask area (ascending)
101 |     areas = np.sum(annotation, axis=(1, 2))
102 |     sorted_indices = np.argsort(areas)
103 | annotation = annotation[sorted_indices]
104 |
105 | index = (annotation != 0).argmax(axis=0)
106 | if random_color:
107 | color = np.random.random((mask_sum, 1, 1, 3))
108 | else:
109 | color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
110 | transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
111 | visual = np.concatenate([color, transparency], axis=-1)
112 | mask_image = np.expand_dims(annotation, -1) * visual
113 |
114 |     mask = np.zeros((height, width, 4))
115 |
116 |     h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
117 | indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
118 |
119 | mask[h_indices, w_indices, :] = mask_image[indices]
120 | if bbox is not None:
121 | x1, y1, x2, y2 = bbox
122 | ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
123 |
124 | if not retinamask:
125 | mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
126 |
127 | return mask
128 |
129 |
130 | def fast_show_mask_gpu(
131 | annotation,
132 | ax,
133 | random_color=False,
134 | bbox=None,
135 | retinamask=True,
136 | target_height=960,
137 | target_width=960,
138 | ):
139 | device = annotation.device
140 | mask_sum = annotation.shape[0]
141 | height = annotation.shape[1]
142 |     width = annotation.shape[2]
143 |     areas = torch.sum(annotation, dim=(1, 2))
144 |     sorted_indices = torch.argsort(areas, descending=False)
145 |     annotation = annotation[sorted_indices]
146 |     # index of the first non-zero mask at each pixel position
147 | index = (annotation != 0).to(torch.long).argmax(dim=0)
148 | if random_color:
149 | color = torch.rand((mask_sum, 1, 1, 3)).to(device)
150 | else:
151 | color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
152 | [30 / 255, 144 / 255, 255 / 255]
153 | ).to(device)
154 | transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
155 | visual = torch.cat([color, transparency], dim=-1)
156 | mask_image = torch.unsqueeze(annotation, -1) * visual
157 |     # gather by index: for each pixel, take the RGBA value of the mask chosen by `index`, collapsing mask_image into a single overlay
158 |     mask = torch.zeros((height, width, 4)).to(device)
159 |     h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij')
160 |     indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
161 |     # update the overlay values with vectorized indexing
162 | mask[h_indices, w_indices, :] = mask_image[indices]
163 | mask_cpu = mask.cpu().numpy()
164 | if bbox is not None:
165 | x1, y1, x2, y2 = bbox
166 | ax.add_patch(
167 | plt.Rectangle(
168 | (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
169 | )
170 | )
171 | if not retinamask:
172 | mask_cpu = cv2.resize(
173 | mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
174 | )
175 | return mask_cpu
176 |
--------------------------------------------------------------------------------
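
`fast_process` is the end-to-end overlay used by the Gradio demo: it rasterises the masks with `fast_show_mask` (CPU path) or `fast_show_mask_gpu`, optionally draws contours, and pastes the result onto the input image as RGBA layers. A minimal sketch of the CPU path with a synthetic image and masks follows; the sizes, colours, and output filename are made up for illustration, and the repository root is assumed to be on `PYTHONPATH`.

```python
import numpy as np
from PIL import Image

from utils.tools_gradio import fast_process  # assumes repo root is on PYTHONPATH

# Synthetic 256x256 grey image and two masks in the "everything" dict format.
image = Image.fromarray(np.full((256, 256, 3), 200, dtype=np.uint8))
seg_a = np.zeros((256, 256), dtype=bool)
seg_a[40:120, 40:120] = True
seg_b = np.zeros((256, 256), dtype=bool)
seg_b[150:230, 100:220] = True
annotations = [{"segmentation": seg_a}, {"segmentation": seg_b}]

result = fast_process(
    annotations,
    image,
    device="cpu",          # the string 'cpu' selects the numpy branch / fast_show_mask
    scale=1,               # contour thickness is 2 // scale pixels
    mask_random_color=True,
    use_retina=True,       # keep masks at their native resolution
    withContours=True,     # draw contours on top of the colour overlay
)
result.save("overlay.png")  # RGBA composite with the two masks and their contours
```

Any device other than `'cpu'` takes the `fast_show_mask_gpu` branch; when the masks are already on the GPU, the compositing runs there before the result is copied back for the PIL paste.
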