├── .gitignore
├── GroundingDINO
│ ├── .asset
│ │ ├── COCO.png
│ │ ├── GD_GLIGEN.png
│ │ ├── GD_SD.png
│ │ ├── ODinW.png
│ │ ├── arch.png
│ │ ├── cats.png
│ │ └── hero_figure.png
│ ├── LICENSE
│ ├── README.md
│ ├── demo
│ │ ├── gradio_app.py
│ │ └── inference_on_a_image.py
│ ├── groundingdino
│ │ ├── __init__.py
│ │ ├── config
│ │ │ └── GroundingDINO_SwinT_OGC.py
│ │ ├── datasets
│ │ │ └── transforms.py
│ │ ├── models
│ │ │ ├── GroundingDINO
│ │ │ │ ├── __init__.py
│ │ │ │ ├── backbone
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── backbone.py
│ │ │ │ │ ├── position_encoding.py
│ │ │ │ │ └── swin_transformer.py
│ │ │ │ ├── bertwarper.py
│ │ │ │ ├── csrc
│ │ │ │ │ ├── MsDeformAttn
│ │ │ │ │ │ ├── ms_deform_attn.h
│ │ │ │ │ │ ├── ms_deform_attn_cpu.cpp
│ │ │ │ │ │ ├── ms_deform_attn_cpu.h
│ │ │ │ │ │ ├── ms_deform_attn_cuda.cu
│ │ │ │ │ │ ├── ms_deform_attn_cuda.h
│ │ │ │ │ │ └── ms_deform_im2col_cuda.cuh
│ │ │ │ │ ├── cuda_version.cu
│ │ │ │ │ └── vision.cpp
│ │ │ │ ├── fuse_modules.py
│ │ │ │ ├── groundingdino.py
│ │ │ │ ├── ms_deform_attn.py
│ │ │ │ ├── transformer.py
│ │ │ │ ├── transformer_vanilla.py
│ │ │ │ └── utils.py
│ │ │ ├── __init__.py
│ │ │ └── registry.py
│ │ ├── util
│ │ │ ├── __init__.py
│ │ │ ├── box_ops.py
│ │ │ ├── get_tokenlizer.py
│ │ │ ├── inference.py
│ │ │ ├── logger.py
│ │ │ ├── misc.py
│ │ │ ├── slconfig.py
│ │ │ ├── slio.py
│ │ │ ├── time_counter.py
│ │ │ ├── utils.py
│ │ │ ├── visualizer.py
│ │ │ └── vl_utils.py
│ │ └── version.py
│ ├── requirements.txt
│ └── setup.py
├── LICENSE
├── README.md
├── assets
│ ├── Grounded-SAM_logo.png
│ ├── automatic_label_output
│ │ ├── demo1.jpg
│ │ ├── demo2.jpg
│ │ ├── demo4.jpg
│ │ └── demo8.jpg
│ ├── automatic_label_output_demo3.jpg
│ ├── demo1.jpg
│ ├── demo2.jpg
│ ├── demo3.jpg
│ ├── demo4.jpg
│ ├── demo5.jpg
│ ├── demo6.jpg
│ ├── demo7.jpg
│ ├── demo8.jpg
│ ├── gradio_demo.png
│ ├── grounded_sam.jpg
│ ├── grounded_sam2.png
│ ├── grounded_sam_demo3_demo4.png
│ ├── grounded_sam_inpainting_demo.png
│ ├── grounded_sam_output_demo1.jpg
│ ├── grounding_dino_output_demo1.jpg
│ └── inpaint_demo.jpg
├── automatic_label_demo.py
├── gradio_app.py
├── grounded_dino_sam_inpainting_demo.py
├── grounded_sam.ipynb
├── grounded_sam_demo.py
├── grounded_sam_inpainting_demo.py
├── grounding_dino_demo.py
├── gsa_api.py
├── requirements.txt
└── segment_anything
  ├── .flake8
  ├── CODE_OF_CONDUCT.md
  ├── CONTRIBUTING.md
  ├── LICENSE
  ├── README.md
  ├── assets
  │ ├── masks1.png
  │ ├── masks2.jpg
  │ ├── model_diagram.png
  │ ├── notebook1.png
  │ └── notebook2.png
  ├── linter.sh
  ├── notebooks
  │ ├── automatic_mask_generator_example.ipynb
  │ ├── images
  │ │ ├── dog.jpg
  │ │ ├── groceries.jpg
  │ │ └── truck.jpg
  │ ├── onnx_model_example.ipynb
  │ └── predictor_example.ipynb
  ├── scripts
  │ ├── amg.py
  │ └── export_onnx_model.py
  ├── segment_anything
  │ ├── __init__.py
  │ ├── automatic_mask_generator.py
  │ ├── build_sam.py
  │ ├── modeling
  │ │ ├── __init__.py
  │ │ ├── common.py
  │ │ ├── image_encoder.py
  │ │ ├── mask_decoder.py
  │ │ ├── prompt_encoder.py
  │ │ ├── sam.py
  │ │ └── transformer.py
  │ ├── predictor.py
  │ └── utils
  │   ├── __init__.py
  │   ├── amg.py
  │   ├── onnx.py
  │   └── transforms.py
  ├── setup.cfg
  └── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # checkpoint
132 | *.pth
--------------------------------------------------------------------------------
/GroundingDINO/.asset/COCO.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/.asset/COCO.png
--------------------------------------------------------------------------------
/GroundingDINO/.asset/GD_GLIGEN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/.asset/GD_GLIGEN.png
--------------------------------------------------------------------------------
/GroundingDINO/.asset/GD_SD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/.asset/GD_SD.png
--------------------------------------------------------------------------------
/GroundingDINO/.asset/ODinW.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/.asset/ODinW.png
--------------------------------------------------------------------------------
/GroundingDINO/.asset/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/.asset/arch.png
--------------------------------------------------------------------------------
/GroundingDINO/.asset/cats.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/.asset/cats.png
--------------------------------------------------------------------------------
/GroundingDINO/.asset/hero_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/.asset/hero_figure.png
--------------------------------------------------------------------------------
/GroundingDINO/README.md:
--------------------------------------------------------------------------------
1 | # Grounding DINO
2 |
3 | ---
4 |
5 | [arXiv](https://arxiv.org/abs/2303.05499)
6 | [YouTube: overview](https://youtu.be/wxWDt5UiwY8)
7 | [Colab demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb)
8 | [YouTube: tutorial](https://youtu.be/cMa77r3YrDk)
9 | [Hugging Face Space demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo)
10 |
11 | [PWC: zero-shot object detection on MS-COCO](https://paperswithcode.com/sota/zero-shot-object-detection-on-mscoco?p=grounding-dino-marrying-dino-with-grounded) \
12 | [PWC: zero-shot object detection on ODinW](https://paperswithcode.com/sota/zero-shot-object-detection-on-odinw?p=grounding-dino-marrying-dino-with-grounded) \
13 | [PWC: object detection on COCO minival](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=grounding-dino-marrying-dino-with-grounded) \
14 | [PWC: object detection on COCO](https://paperswithcode.com/sota/object-detection-on-coco?p=grounding-dino-marrying-dino-with-grounded)
15 |
16 |
17 |
18 | Official PyTorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499), a stronger open-set object detector. Code is available now!
19 |
20 |
21 | ## Highlight
22 |
23 | - **Open-Set Detection.** Detect **everything** with language!
24 | - **High Performance.** COCO zero-shot **52.5 AP** (trained without any COCO data!). COCO fine-tuned **63.0 AP**.
25 | - **Flexible.** Collaborates with Stable Diffusion for image editing.
26 |
27 | ## News
28 | [2023/03/28] A YouTube [video](https://youtu.be/cMa77r3YrDk) about Grounding DINO and basic object detection prompt engineering. [[SkalskiP](https://github.com/SkalskiP)] \
29 | [2023/03/28] Add a [demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo) on Hugging Face Space! \
30 | [2023/03/27] Support CPU-only mode. Now the model can run on machines without GPUs.\
31 | [2023/03/25] A [demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) for Grounding DINO is available at Colab. [[SkalskiP](https://github.com/SkalskiP)] \
32 | [2023/03/22] Code is available now!
33 |
34 |
35 |
36 | Description
37 |
38 |
39 |
40 |
41 |
42 |
43 | ## TODO
44 |
45 | - [x] Release inference code and demo.
46 | - [x] Release checkpoints.
47 | - [ ] Grounding DINO with Stable Diffusion and GLIGEN demos.
48 | - [ ] Release training codes.
49 |
50 | ## Install
51 |
52 | If you have a CUDA environment, make sure the environment variable `CUDA_HOME` is set. The package will be compiled in CPU-only mode if CUDA is not available.
53 |
54 | ```bash
55 | pip install -e .
56 | ```
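If the build cannot find your toolkit, set `CUDA_HOME` before installing; the path below is only an example and should point at your own CUDA installation:

```bash
# Example only: adjust to where your CUDA toolkit actually lives.
export CUDA_HOME=/usr/local/cuda
echo $CUDA_HOME   # verify it is set before running pip install -e .
```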
57 |
58 | ## Demo
59 |
60 | ```bash
61 | CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
62 | -c /path/to/config \
63 | -p /path/to/checkpoint \
64 | -i .asset/cats.png \
65 | -o "outputs/0" \
66 | -t "cat ear." \
67 | [--cpu-only] # add this flag to run on CPU
68 | ```
69 | See `demo/inference_on_a_image.py` for more details.
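For scripted use, the helper functions defined in that file can be called directly. A minimal sketch, assuming the repo root is on `PYTHONPATH` and the paths below are placeholders for your own checkpoint and image:

```python
# Sketch only: uses load_model / load_image / get_grounding_output from
# demo/inference_on_a_image.py; adjust the placeholder paths to your files.
from demo.inference_on_a_image import load_image, load_model, get_grounding_output

model = load_model(
    "groundingdino/config/GroundingDINO_SwinT_OGC.py",  # config shipped with the repo
    "/path/to/groundingdino_swint_ogc.pth",             # downloaded checkpoint (placeholder)
    cpu_only=True,
)
image_pil, image = load_image(".asset/cats.png")
boxes, phrases = get_grounding_output(
    model, image, "cat ear.", box_threshold=0.3, text_threshold=0.25, cpu_only=True
)
print(list(zip(phrases, boxes.tolist())))
```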
70 |
71 | **Web UI**
72 |
73 | We also provide demo code that integrates Grounding DINO with a Gradio web UI. See `demo/gradio_app.py` for more details.
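One way to launch it locally (the script serves on 0.0.0.0:7579; `--share` is optional and creates a temporary public link):

```bash
# Launch the Gradio demo from the GroundingDINO directory.
python demo/gradio_app.py --share
```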
74 |
75 | ## Checkpoints
76 |
77 | | | name | backbone | Data | box AP on COCO | Checkpoint | Config |
78 | |---|------|----------|------|----------------|------------|--------|
79 | | 1 | GroundingDINO-T | Swin-T | O365,GoldG,Cap4M | 48.4 (zero-shot) / 57.2 (fine-tune) | GitHub link \| HF link | link |
80 |
102 |
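The checkpoint can also be fetched programmatically. A small sketch using `huggingface_hub`, with the repo id and filename taken from `demo/gradio_app.py` (treat them as the expected location rather than a guarantee):

```python
from huggingface_hub import hf_hub_download

# Download the Swin-T checkpoint into the local Hugging Face cache.
ckpt_path = hf_hub_download(
    repo_id="ShilongLiu/GroundingDINO",
    filename="groundingdino_swint_ogc.pth",
)
print(ckpt_path)  # pass this path to -p / --checkpoint_path
```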
103 | ## Results
104 |
105 | - COCO Object Detection Results
106 | - ODinW Object Detection Results
107 | - Marrying Grounding DINO with Stable Diffusion for Image Editing
108 | - Marrying Grounding DINO with GLIGEN for more Detailed Image Editing
109 |
133 | ## Model
134 |
135 | The model consists of a text backbone, an image backbone, a feature enhancer, a language-guided query selection module, and a cross-modality decoder.
136 |
137 | 
138 |
139 |
140 | ## Acknowledgement
141 |
142 | Our model is related to [DINO](https://github.com/IDEA-Research/DINO) and [GLIP](https://github.com/microsoft/GLIP). Thanks for their great work!
143 |
144 | We also thank great previous work including DETR, Deformable DETR, SMCA, Conditional DETR, Anchor DETR, Dynamic DETR, DAB-DETR, DN-DETR, etc. More related work is available at [Awesome Detection Transformer](https://github.com/IDEACVR/awesome-detection-transformer). A new toolbox [detrex](https://github.com/IDEA-Research/detrex) is available as well.
145 |
146 | Thanks [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) and [GLIGEN](https://github.com/gligen/GLIGEN) for their awesome models.
147 |
148 |
149 | ## Citation
150 |
151 | If you find our work helpful for your research, please consider citing the following BibTeX entry.
152 |
153 | ```bibtex
154 | @inproceedings{ShilongLiu2023GroundingDM,
155 | title={Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection},
156 | author={Shilong Liu and Zhaoyang Zeng and Tianhe Ren and Feng Li and Hao Zhang and Jie Yang and Chunyuan Li and Jianwei Yang and Hang Su and Jun Zhu and Lei Zhang},
157 | year={2023}
158 | }
159 | ```
160 |
161 |
162 |
163 |
164 |
--------------------------------------------------------------------------------
/GroundingDINO/demo/gradio_app.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from functools import partial
3 | import cv2
4 | import requests
5 | import os
6 | from io import BytesIO
7 | from PIL import Image
8 | import numpy as np
9 | from pathlib import Path
10 |
11 |
12 | import warnings
13 |
14 | import torch
15 |
16 | # prepare the environment
17 | os.system("python setup.py build develop --user")
18 | os.system("pip install packaging==21.3")
19 | os.system("pip install gradio")
20 |
21 |
22 | warnings.filterwarnings("ignore")
23 |
24 | import gradio as gr
25 |
26 | from groundingdino.models import build_model
27 | from groundingdino.util.slconfig import SLConfig
28 | from groundingdino.util.utils import clean_state_dict
29 | from groundingdino.util.inference import annotate, load_image, predict
30 | import groundingdino.datasets.transforms as T
31 |
32 | from huggingface_hub import hf_hub_download
33 |
34 |
35 |
36 | # Config and checkpoint used to load the GroundingDINO-T (Swin-T) model
37 | config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
38 | ckpt_repo_id = "ShilongLiu/GroundingDINO"
39 | ckpt_filename = "groundingdino_swint_ogc.pth"
40 |
41 |
42 | def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
43 | args = SLConfig.fromfile(model_config_path)
44 | model = build_model(args)
45 | args.device = device
46 |
47 | cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
48 | checkpoint = torch.load(cache_file, map_location='cpu')
49 | log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
50 | print("Model loaded from {} \n => {}".format(cache_file, log))
51 | _ = model.eval()
52 | return model
53 |
54 | def image_transform_grounding(init_image):
55 | transform = T.Compose([
56 | T.RandomResize([800], max_size=1333),
57 | T.ToTensor(),
58 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
59 | ])
60 | image, _ = transform(init_image, None) # 3, h, w
61 | return init_image, image
62 |
63 | def image_transform_grounding_for_vis(init_image):
64 | transform = T.Compose([
65 | T.RandomResize([800], max_size=1333),
66 | ])
67 | image, _ = transform(init_image, None) # 3, h, w
68 | return image
69 |
70 | model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename)
71 |
72 | def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
73 | init_image = input_image.convert("RGB")
74 | original_size = init_image.size
75 |
76 | _, image_tensor = image_transform_grounding(init_image)
77 | image_pil: Image = image_transform_grounding_for_vis(init_image)
78 |
79 | # run grounding
80 | boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
81 | annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
82 | image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
83 |
84 |
85 | return image_with_box
86 |
87 | if __name__ == "__main__":
88 |
89 | parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
90 | parser.add_argument("--debug", action="store_true", help="using debug mode")
91 | parser.add_argument("--share", action="store_true", help="share the app")
92 | args = parser.parse_args()
93 |
94 | block = gr.Blocks().queue()
95 | with block:
96 | gr.Markdown("# [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)")
97 | gr.Markdown("### Open-World Detection with Grounding DINO")
98 |
99 | with gr.Row():
100 | with gr.Column():
101 | input_image = gr.Image(source='upload', type="pil")
102 | grounding_caption = gr.Textbox(label="Detection Prompt")
103 | run_button = gr.Button(label="Run")
104 | with gr.Accordion("Advanced options", open=False):
105 | box_threshold = gr.Slider(
106 | label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
107 | )
108 | text_threshold = gr.Slider(
109 | label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
110 | )
111 |
112 | with gr.Column():
113 | gallery = gr.outputs.Image(
114 | type="pil",
115 | # label="grounding results"
116 | ).style(full_width=True, full_height=True)
117 | # gallery = gr.Gallery(label="Generated images", show_label=False).style(
118 | # grid=[1], height="auto", container=True, full_width=True, full_height=True)
119 |
120 | run_button.click(fn=run_grounding, inputs=[
121 | input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
122 |
123 |
124 | block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share)
125 |
126 |
--------------------------------------------------------------------------------
/GroundingDINO/demo/inference_on_a_image.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | import numpy as np
6 | import torch
7 | from PIL import Image, ImageDraw, ImageFont
8 |
9 | import groundingdino.datasets.transforms as T
10 | from groundingdino.models import build_model
11 | from groundingdino.util import box_ops
12 | from groundingdino.util.slconfig import SLConfig
13 | from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
14 |
15 |
16 | def plot_boxes_to_image(image_pil, tgt):
17 | H, W = tgt["size"]
18 | boxes = tgt["boxes"]
19 | labels = tgt["labels"]
20 | assert len(boxes) == len(labels), "boxes and labels must have same length"
21 |
22 | draw = ImageDraw.Draw(image_pil)
23 | mask = Image.new("L", image_pil.size, 0)
24 | mask_draw = ImageDraw.Draw(mask)
25 |
26 | # draw boxes and masks
27 | for box, label in zip(boxes, labels):
28 | # from 0..1 to 0..W, 0..H
29 | box = box * torch.Tensor([W, H, W, H])
30 | # from xywh to xyxy
31 | box[:2] -= box[2:] / 2
32 | box[2:] += box[:2]
33 | # random color
34 | color = tuple(np.random.randint(0, 255, size=3).tolist())
35 | # draw
36 | x0, y0, x1, y1 = box
37 | x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
38 |
39 | draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
40 | # draw.text((x0, y0), str(label), fill=color)
41 |
42 | font = ImageFont.load_default()
43 | if hasattr(font, "getbbox"):
44 | bbox = draw.textbbox((x0, y0), str(label), font)
45 | else:
46 | w, h = draw.textsize(str(label), font)
47 | bbox = (x0, y0, w + x0, y0 + h)
48 | # bbox = draw.textbbox((x0, y0), str(label))
49 | draw.rectangle(bbox, fill=color)
50 | draw.text((x0, y0), str(label), fill="white")
51 |
52 | mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
53 |
54 | return image_pil, mask
55 |
56 |
57 | def load_image(image_path):
58 | # load image
59 | image_pil = Image.open(image_path).convert("RGB") # load image
60 |
61 | transform = T.Compose(
62 | [
63 | T.RandomResize([800], max_size=1333),
64 | T.ToTensor(),
65 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
66 | ]
67 | )
68 | image, _ = transform(image_pil, None) # 3, h, w
69 | return image_pil, image
70 |
71 |
72 | def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
73 | args = SLConfig.fromfile(model_config_path)
74 | args.device = "cuda" if not cpu_only else "cpu"
75 | model = build_model(args)
76 | checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
77 | load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
78 | print(load_res)
79 | _ = model.eval()
80 | return model
81 |
82 |
83 | def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, cpu_only=False):
84 | caption = caption.lower()
85 | caption = caption.strip()
86 | if not caption.endswith("."):
87 | caption = caption + "."
88 | device = "cuda" if not cpu_only else "cpu"
89 | model = model.to(device)
90 | image = image.to(device)
91 | with torch.no_grad():
92 | outputs = model(image[None], captions=[caption])
93 | logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
94 | boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
95 | logits.shape[0]
96 |
97 | # filter output
98 | logits_filt = logits.clone()
99 | boxes_filt = boxes.clone()
100 | filt_mask = logits_filt.max(dim=1)[0] > box_threshold
101 | logits_filt = logits_filt[filt_mask] # num_filt, 256
102 | boxes_filt = boxes_filt[filt_mask] # num_filt, 4
103 | logits_filt.shape[0]
104 |
105 | # get phrase
106 | tokenlizer = model.tokenizer
107 | tokenized = tokenlizer(caption)
108 | # build pred
109 | pred_phrases = []
110 | for logit, box in zip(logits_filt, boxes_filt):
111 | pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
112 | if with_logits:
113 | pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
114 | else:
115 | pred_phrases.append(pred_phrase)
116 |
117 | return boxes_filt, pred_phrases
118 |
119 |
120 | if __name__ == "__main__":
121 |
122 | parser = argparse.ArgumentParser("Grounding DINO example", add_help=True)
123 | parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file")
124 | parser.add_argument(
125 | "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
126 | )
127 | parser.add_argument("--image_path", "-i", type=str, required=True, help="path to image file")
128 | parser.add_argument("--text_prompt", "-t", type=str, required=True, help="text prompt")
129 | parser.add_argument(
130 | "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
131 | )
132 |
133 | parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
134 | parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
135 |
136 | parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False")
137 | args = parser.parse_args()
138 |
139 | # cfg
140 | config_file = args.config_file # change the path of the model config file
141 | checkpoint_path = args.checkpoint_path # change the path of the model
142 | image_path = args.image_path
143 | text_prompt = args.text_prompt
144 | output_dir = args.output_dir
145 | box_threshold = args.box_threshold
146 | text_threshold = args.text_threshold
147 |
148 | # make dir
149 | os.makedirs(output_dir, exist_ok=True)
150 | # load image
151 | image_pil, image = load_image(image_path)
152 | # load model
153 | model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only)
154 |
155 | # visualize raw image
156 | image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
157 |
158 | # run model
159 | boxes_filt, pred_phrases = get_grounding_output(
160 | model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only
161 | )
162 |
163 | # visualize pred
164 | size = image_pil.size
165 | pred_dict = {
166 | "boxes": boxes_filt,
167 | "size": [size[1], size[0]], # H,W
168 | "labels": pred_phrases,
169 | }
170 | # import ipdb; ipdb.set_trace()
171 | image_with_box = plot_boxes_to_image(image_pil, pred_dict)[0]
172 | image_with_box.save(os.path.join(output_dir, "pred.jpg"))
173 |
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/groundingdino/__init__.py
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py:
--------------------------------------------------------------------------------
1 | batch_size = 1
2 | modelname = "groundingdino"
3 | backbone = "swin_T_224_1k"
4 | position_embedding = "sine"
5 | pe_temperatureH = 20
6 | pe_temperatureW = 20
7 | return_interm_indices = [1, 2, 3]
8 | backbone_freeze_keywords = None
9 | enc_layers = 6
10 | dec_layers = 6
11 | pre_norm = False
12 | dim_feedforward = 2048
13 | hidden_dim = 256
14 | dropout = 0.0
15 | nheads = 8
16 | num_queries = 900
17 | query_dim = 4
18 | num_patterns = 0
19 | num_feature_levels = 4
20 | enc_n_points = 4
21 | dec_n_points = 4
22 | two_stage_type = "standard"
23 | two_stage_bbox_embed_share = False
24 | two_stage_class_embed_share = False
25 | transformer_activation = "relu"
26 | dec_pred_bbox_embed_share = True
27 | dn_box_noise_scale = 1.0
28 | dn_label_noise_ratio = 0.5
29 | dn_label_coef = 1.0
30 | dn_bbox_coef = 1.0
31 | embed_init_tgt = True
32 | dn_labelbook_size = 2000
33 | max_text_len = 256
34 | text_encoder_type = "bert-base-uncased"
35 | use_text_enhancer = True
36 | use_fusion_layer = True
37 | use_checkpoint = True
38 | use_transformer_ckpt = True
39 | use_text_cross_attention = True
40 | text_dropout = 0.0
41 | fusion_dropout = 0.0
42 | fusion_droppath = 0.1
43 | sub_sentence_present = True
44 |
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Transforms and data augmentation for both image + bbox.
4 | """
5 | import os
6 | import random
7 |
8 | import PIL
9 | import torch
10 | import torchvision.transforms as T
11 | import torchvision.transforms.functional as F
12 |
13 | from groundingdino.util.box_ops import box_xyxy_to_cxcywh
14 | from groundingdino.util.misc import interpolate
15 |
16 |
17 | def crop(image, target, region):
18 | cropped_image = F.crop(image, *region)
19 |
20 | target = target.copy()
21 | i, j, h, w = region
22 |
23 | # should we do something wrt the original size?
24 | target["size"] = torch.tensor([h, w])
25 |
26 | fields = ["labels", "area", "iscrowd", "positive_map"]
27 |
28 | if "boxes" in target:
29 | boxes = target["boxes"]
30 | max_size = torch.as_tensor([w, h], dtype=torch.float32)
31 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
32 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
33 | cropped_boxes = cropped_boxes.clamp(min=0)
34 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
35 | target["boxes"] = cropped_boxes.reshape(-1, 4)
36 | target["area"] = area
37 | fields.append("boxes")
38 |
39 | if "masks" in target:
40 | # FIXME should we update the area here if there are no boxes?
41 | target["masks"] = target["masks"][:, i : i + h, j : j + w]
42 | fields.append("masks")
43 |
44 | # remove elements whose boxes or masks have zero area
45 | if "boxes" in target or "masks" in target:
46 | # favor boxes selection when defining which elements to keep
47 | # this is compatible with previous implementation
48 | if "boxes" in target:
49 | cropped_boxes = target["boxes"].reshape(-1, 2, 2)
50 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
51 | else:
52 | keep = target["masks"].flatten(1).any(1)
53 |
54 | for field in fields:
55 | if field in target:
56 | target[field] = target[field][keep]
57 |
58 | if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO":
59 | # for debug and visualization only.
60 | if "strings_positive" in target:
61 | target["strings_positive"] = [
62 | _i for _i, _j in zip(target["strings_positive"], keep) if _j
63 | ]
64 |
65 | return cropped_image, target
66 |
67 |
68 | def hflip(image, target):
69 | flipped_image = F.hflip(image)
70 |
71 | w, h = image.size
72 |
73 | target = target.copy()
74 | if "boxes" in target:
75 | boxes = target["boxes"]
76 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor(
77 | [w, 0, w, 0]
78 | )
79 | target["boxes"] = boxes
80 |
81 | if "masks" in target:
82 | target["masks"] = target["masks"].flip(-1)
83 |
84 | return flipped_image, target
85 |
86 |
87 | def resize(image, target, size, max_size=None):
88 | # size can be min_size (scalar) or (w, h) tuple
89 |
90 | def get_size_with_aspect_ratio(image_size, size, max_size=None):
91 | w, h = image_size
92 | if max_size is not None:
93 | min_original_size = float(min((w, h)))
94 | max_original_size = float(max((w, h)))
95 | if max_original_size / min_original_size * size > max_size:
96 | size = int(round(max_size * min_original_size / max_original_size))
97 |
98 | if (w <= h and w == size) or (h <= w and h == size):
99 | return (h, w)
100 |
101 | if w < h:
102 | ow = size
103 | oh = int(size * h / w)
104 | else:
105 | oh = size
106 | ow = int(size * w / h)
107 |
108 | return (oh, ow)
109 |
110 | def get_size(image_size, size, max_size=None):
111 | if isinstance(size, (list, tuple)):
112 | return size[::-1]
113 | else:
114 | return get_size_with_aspect_ratio(image_size, size, max_size)
115 |
116 | size = get_size(image.size, size, max_size)
117 | rescaled_image = F.resize(image, size)
118 |
119 | if target is None:
120 | return rescaled_image, None
121 |
122 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
123 | ratio_width, ratio_height = ratios
124 |
125 | target = target.copy()
126 | if "boxes" in target:
127 | boxes = target["boxes"]
128 | scaled_boxes = boxes * torch.as_tensor(
129 | [ratio_width, ratio_height, ratio_width, ratio_height]
130 | )
131 | target["boxes"] = scaled_boxes
132 |
133 | if "area" in target:
134 | area = target["area"]
135 | scaled_area = area * (ratio_width * ratio_height)
136 | target["area"] = scaled_area
137 |
138 | h, w = size
139 | target["size"] = torch.tensor([h, w])
140 |
141 | if "masks" in target:
142 | target["masks"] = (
143 | interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5
144 | )
145 |
146 | return rescaled_image, target
147 |
148 |
149 | def pad(image, target, padding):
150 | # assumes that we only pad on the bottom right corners
151 | padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
152 | if target is None:
153 | return padded_image, None
154 | target = target.copy()
155 | # should we do something wrt the original size?
156 | target["size"] = torch.tensor(padded_image.size[::-1])
157 | if "masks" in target:
158 | target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1]))
159 | return padded_image, target
160 |
161 |
162 | class ResizeDebug(object):
163 | def __init__(self, size):
164 | self.size = size
165 |
166 | def __call__(self, img, target):
167 | return resize(img, target, self.size)
168 |
169 |
170 | class RandomCrop(object):
171 | def __init__(self, size):
172 | self.size = size
173 |
174 | def __call__(self, img, target):
175 | region = T.RandomCrop.get_params(img, self.size)
176 | return crop(img, target, region)
177 |
178 |
179 | class RandomSizeCrop(object):
180 | def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False):
181 | # respect_boxes: True to keep all boxes
182 | # False to tolerate boxes being dropped by the crop
183 | self.min_size = min_size
184 | self.max_size = max_size
185 | self.respect_boxes = respect_boxes
186 |
187 | def __call__(self, img: PIL.Image.Image, target: dict):
188 | init_boxes = len(target["boxes"])
189 | max_patience = 10
190 | for i in range(max_patience):
191 | w = random.randint(self.min_size, min(img.width, self.max_size))
192 | h = random.randint(self.min_size, min(img.height, self.max_size))
193 | region = T.RandomCrop.get_params(img, [h, w])
194 | result_img, result_target = crop(img, target, region)
195 | if (
196 | not self.respect_boxes
197 | or len(result_target["boxes"]) == init_boxes
198 | or i == max_patience - 1
199 | ):
200 | return result_img, result_target
201 | return result_img, result_target
202 |
203 |
204 | class CenterCrop(object):
205 | def __init__(self, size):
206 | self.size = size
207 |
208 | def __call__(self, img, target):
209 | image_width, image_height = img.size
210 | crop_height, crop_width = self.size
211 | crop_top = int(round((image_height - crop_height) / 2.0))
212 | crop_left = int(round((image_width - crop_width) / 2.0))
213 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
214 |
215 |
216 | class RandomHorizontalFlip(object):
217 | def __init__(self, p=0.5):
218 | self.p = p
219 |
220 | def __call__(self, img, target):
221 | if random.random() < self.p:
222 | return hflip(img, target)
223 | return img, target
224 |
225 |
226 | class RandomResize(object):
227 | def __init__(self, sizes, max_size=None):
228 | assert isinstance(sizes, (list, tuple))
229 | self.sizes = sizes
230 | self.max_size = max_size
231 |
232 | def __call__(self, img, target=None):
233 | size = random.choice(self.sizes)
234 | return resize(img, target, size, self.max_size)
235 |
236 |
237 | class RandomPad(object):
238 | def __init__(self, max_pad):
239 | self.max_pad = max_pad
240 |
241 | def __call__(self, img, target):
242 | pad_x = random.randint(0, self.max_pad)
243 | pad_y = random.randint(0, self.max_pad)
244 | return pad(img, target, (pad_x, pad_y))
245 |
246 |
247 | class RandomSelect(object):
248 | """
249 | Randomly selects between transforms1 and transforms2,
250 | with probability p for transforms1 and (1 - p) for transforms2
251 | """
252 |
253 | def __init__(self, transforms1, transforms2, p=0.5):
254 | self.transforms1 = transforms1
255 | self.transforms2 = transforms2
256 | self.p = p
257 |
258 | def __call__(self, img, target):
259 | if random.random() < self.p:
260 | return self.transforms1(img, target)
261 | return self.transforms2(img, target)
262 |
263 |
264 | class ToTensor(object):
265 | def __call__(self, img, target):
266 | return F.to_tensor(img), target
267 |
268 |
269 | class RandomErasing(object):
270 | def __init__(self, *args, **kwargs):
271 | self.eraser = T.RandomErasing(*args, **kwargs)
272 |
273 | def __call__(self, img, target):
274 | return self.eraser(img), target
275 |
276 |
277 | class Normalize(object):
278 | def __init__(self, mean, std):
279 | self.mean = mean
280 | self.std = std
281 |
282 | def __call__(self, image, target=None):
283 | image = F.normalize(image, mean=self.mean, std=self.std)
284 | if target is None:
285 | return image, None
286 | target = target.copy()
287 | h, w = image.shape[-2:]
288 | if "boxes" in target:
289 | boxes = target["boxes"]
290 | boxes = box_xyxy_to_cxcywh(boxes)
291 | boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
292 | target["boxes"] = boxes
293 | return image, target
294 |
295 |
296 | class Compose(object):
297 | def __init__(self, transforms):
298 | self.transforms = transforms
299 |
300 | def __call__(self, image, target):
301 | for t in self.transforms:
302 | image, target = t(image, target)
303 | return image, target
304 |
305 | def __repr__(self):
306 | format_string = self.__class__.__name__ + "("
307 | for t in self.transforms:
308 | format_string += "\n"
309 | format_string += " {0}".format(t)
310 | format_string += "\n)"
311 | return format_string
312 |
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/models/GroundingDINO/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Grounding DINO
3 | # url: https://github.com/IDEA-Research/GroundingDINO
4 | # Copyright (c) 2023 IDEA. All Rights Reserved.
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | # ------------------------------------------------------------------------
7 | # Conditional DETR
8 | # Copyright (c) 2021 Microsoft. All Rights Reserved.
9 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10 | # ------------------------------------------------------------------------
11 | # Copied from DETR (https://github.com/facebookresearch/detr)
12 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13 | # ------------------------------------------------------------------------
14 |
15 | from .groundingdino import build_groundingdino
16 |
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/models/GroundingDINO/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .backbone import build_backbone
2 |
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/models/GroundingDINO/backbone/backbone.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Grounding DINO
3 | # url: https://github.com/IDEA-Research/GroundingDINO
4 | # Copyright (c) 2023 IDEA. All Rights Reserved.
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | # ------------------------------------------------------------------------
7 | # Conditional DETR
8 | # Copyright (c) 2021 Microsoft. All Rights Reserved.
9 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10 | # ------------------------------------------------------------------------
11 | # Copied from DETR (https://github.com/facebookresearch/detr)
12 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13 | # ------------------------------------------------------------------------
14 |
15 | """
16 | Backbone modules.
17 | """
18 |
19 | from typing import Dict, List
20 |
21 | import torch
22 | import torch.nn.functional as F
23 | import torchvision
24 | from torch import nn
25 | from torchvision.models._utils import IntermediateLayerGetter
26 |
27 | from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process
28 |
29 | from .position_encoding import build_position_encoding
30 | from .swin_transformer import build_swin_transformer
31 |
32 |
33 | class FrozenBatchNorm2d(torch.nn.Module):
34 | """
35 | BatchNorm2d where the batch statistics and the affine parameters are fixed.
36 |
37 | Copy-paste from torchvision.misc.ops with added eps before rsqrt,
38 | without which any other models than torchvision.models.resnet[18,34,50,101]
39 | produce nans.
40 | """
41 |
42 | def __init__(self, n):
43 | super(FrozenBatchNorm2d, self).__init__()
44 | self.register_buffer("weight", torch.ones(n))
45 | self.register_buffer("bias", torch.zeros(n))
46 | self.register_buffer("running_mean", torch.zeros(n))
47 | self.register_buffer("running_var", torch.ones(n))
48 |
49 | def _load_from_state_dict(
50 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
51 | ):
52 | num_batches_tracked_key = prefix + "num_batches_tracked"
53 | if num_batches_tracked_key in state_dict:
54 | del state_dict[num_batches_tracked_key]
55 |
56 | super(FrozenBatchNorm2d, self)._load_from_state_dict(
57 | state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
58 | )
59 |
60 | def forward(self, x):
61 | # move reshapes to the beginning
62 | # to make it fuser-friendly
63 | w = self.weight.reshape(1, -1, 1, 1)
64 | b = self.bias.reshape(1, -1, 1, 1)
65 | rv = self.running_var.reshape(1, -1, 1, 1)
66 | rm = self.running_mean.reshape(1, -1, 1, 1)
67 | eps = 1e-5
68 | scale = w * (rv + eps).rsqrt()
69 | bias = b - rm * scale
70 | return x * scale + bias
71 |
72 |
73 | class BackboneBase(nn.Module):
74 | def __init__(
75 | self,
76 | backbone: nn.Module,
77 | train_backbone: bool,
78 | num_channels: int,
79 | return_interm_indices: list,
80 | ):
81 | super().__init__()
82 | for name, parameter in backbone.named_parameters():
83 | if (
84 | not train_backbone
85 | or "layer2" not in name
86 | and "layer3" not in name
87 | and "layer4" not in name
88 | ):
89 | parameter.requires_grad_(False)
90 |
91 | return_layers = {}
92 | for idx, layer_index in enumerate(return_interm_indices):
93 | return_layers.update(
94 | {"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
95 | )
96 |
97 | # if len:
98 | # if use_stage1_feature:
99 | # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
100 | # else:
101 | # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
102 | # else:
103 | # return_layers = {'layer4': "0"}
104 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
105 | self.num_channels = num_channels
106 |
107 | def forward(self, tensor_list: NestedTensor):
108 | xs = self.body(tensor_list.tensors)
109 | out: Dict[str, NestedTensor] = {}
110 | for name, x in xs.items():
111 | m = tensor_list.mask
112 | assert m is not None
113 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
114 | out[name] = NestedTensor(x, mask)
115 | # import ipdb; ipdb.set_trace()
116 | return out
117 |
118 |
119 | class Backbone(BackboneBase):
120 | """ResNet backbone with frozen BatchNorm."""
121 |
122 | def __init__(
123 | self,
124 | name: str,
125 | train_backbone: bool,
126 | dilation: bool,
127 | return_interm_indices: list,
128 | batch_norm=FrozenBatchNorm2d,
129 | ):
130 | if name in ["resnet18", "resnet34", "resnet50", "resnet101"]:
131 | backbone = getattr(torchvision.models, name)(
132 | replace_stride_with_dilation=[False, False, dilation],
133 | pretrained=is_main_process(),
134 | norm_layer=batch_norm,
135 | )
136 | else:
137 | raise NotImplementedError("Unknown backbone name: {}".format(name))
138 | # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
139 | assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
140 | assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
141 | num_channels_all = [256, 512, 1024, 2048]
142 | num_channels = num_channels_all[4 - len(return_interm_indices) :]
143 | super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
144 |
145 |
146 | class Joiner(nn.Sequential):
147 | def __init__(self, backbone, position_embedding):
148 | super().__init__(backbone, position_embedding)
149 |
150 | def forward(self, tensor_list: NestedTensor):
151 | xs = self[0](tensor_list)
152 | out: List[NestedTensor] = []
153 | pos = []
154 | for name, x in xs.items():
155 | out.append(x)
156 | # position encoding
157 | pos.append(self[1](x).to(x.tensors.dtype))
158 |
159 | return out, pos
160 |
161 |
162 | def build_backbone(args):
163 | """
164 | Useful args:
165 | - backbone: backbone name
166 | - lr_backbone:
167 | - dilation
168 | - return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
169 | - backbone_freeze_keywords:
170 | - use_checkpoint: for swin only for now
171 |
172 | """
173 | position_embedding = build_position_encoding(args)
174 | train_backbone = True
175 | if not train_backbone:
176 | raise ValueError("Please set lr_backbone > 0")
177 | return_interm_indices = args.return_interm_indices
178 | assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
179 | args.backbone_freeze_keywords
180 | use_checkpoint = getattr(args, "use_checkpoint", False)
181 |
182 | if args.backbone in ["resnet50", "resnet101"]:
183 | backbone = Backbone(
184 | args.backbone,
185 | train_backbone,
186 | args.dilation,
187 | return_interm_indices,
188 | batch_norm=FrozenBatchNorm2d,
189 | )
190 | bb_num_channels = backbone.num_channels
191 | elif args.backbone in [
192 | "swin_T_224_1k",
193 | "swin_B_224_22k",
194 | "swin_B_384_22k",
195 | "swin_L_224_22k",
196 | "swin_L_384_22k",
197 | ]:
198 | pretrain_img_size = int(args.backbone.split("_")[-2])
199 | backbone = build_swin_transformer(
200 | args.backbone,
201 | pretrain_img_size=pretrain_img_size,
202 | out_indices=tuple(return_interm_indices),
203 | dilation=False,
204 | use_checkpoint=use_checkpoint,
205 | )
206 |
207 | bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
208 | else:
209 | raise NotImplementedError("Unknown backbone {}".format(args.backbone))
210 |
211 | assert len(bb_num_channels) == len(
212 | return_interm_indices
213 | ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
214 |
215 | model = Joiner(backbone, position_embedding)
216 | model.num_channels = bb_num_channels
217 | assert isinstance(
218 | bb_num_channels, List
219 | ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
220 | # import ipdb; ipdb.set_trace()
221 | return model
222 |
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/models/GroundingDINO/backbone/position_encoding.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Grounding DINO
3 | # url: https://github.com/IDEA-Research/GroundingDINO
4 | # Copyright (c) 2023 IDEA. All Rights Reserved.
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | # ------------------------------------------------------------------------
7 | # DINO
8 | # Copyright (c) 2022 IDEA. All Rights Reserved.
9 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10 | # ------------------------------------------------------------------------
11 | # Conditional DETR
12 | # Copyright (c) 2021 Microsoft. All Rights Reserved.
13 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
14 | # ------------------------------------------------------------------------
15 | # Copied from DETR (https://github.com/facebookresearch/detr)
16 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
17 | # ------------------------------------------------------------------------
18 |
19 | """
20 | Various positional encodings for the transformer.
21 | """
22 | import math
23 |
24 | import torch
25 | from torch import nn
26 |
27 | from groundingdino.util.misc import NestedTensor
28 |
29 |
30 | class PositionEmbeddingSine(nn.Module):
31 | """
32 | This is a more standard version of the position embedding, very similar to the one
33 | used by the Attention is all you need paper, generalized to work on images.
34 | """
35 |
36 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
37 | super().__init__()
38 | self.num_pos_feats = num_pos_feats
39 | self.temperature = temperature
40 | self.normalize = normalize
41 | if scale is not None and normalize is False:
42 | raise ValueError("normalize should be True if scale is passed")
43 | if scale is None:
44 | scale = 2 * math.pi
45 | self.scale = scale
46 |
47 | def forward(self, tensor_list: NestedTensor):
48 | x = tensor_list.tensors
49 | mask = tensor_list.mask
50 | assert mask is not None
51 | not_mask = ~mask
52 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
53 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
54 | if self.normalize:
55 | eps = 1e-6
56 | # if os.environ.get("SHILONG_AMP", None) == '1':
57 | # eps = 1e-4
58 | # else:
59 | # eps = 1e-6
60 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
61 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
62 |
63 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
64 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
65 |
66 | pos_x = x_embed[:, :, :, None] / dim_t
67 | pos_y = y_embed[:, :, :, None] / dim_t
68 | pos_x = torch.stack(
69 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
70 | ).flatten(3)
71 | pos_y = torch.stack(
72 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
73 | ).flatten(3)
74 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
75 | return pos
76 |
77 |
78 | class PositionEmbeddingSineHW(nn.Module):
79 | """
80 | This is a more standard version of the position embedding, very similar to the one
81 | used by the Attention is all you need paper, generalized to work on images.
82 | """
83 |
84 | def __init__(
85 | self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None
86 | ):
87 | super().__init__()
88 | self.num_pos_feats = num_pos_feats
89 | self.temperatureH = temperatureH
90 | self.temperatureW = temperatureW
91 | self.normalize = normalize
92 | if scale is not None and normalize is False:
93 | raise ValueError("normalize should be True if scale is passed")
94 | if scale is None:
95 | scale = 2 * math.pi
96 | self.scale = scale
97 |
98 | def forward(self, tensor_list: NestedTensor):
99 | x = tensor_list.tensors
100 | mask = tensor_list.mask
101 | assert mask is not None
102 | not_mask = ~mask
103 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
104 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
105 |
106 | # import ipdb; ipdb.set_trace()
107 |
108 | if self.normalize:
109 | eps = 1e-6
110 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
111 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
112 |
113 | dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
114 | dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
115 | pos_x = x_embed[:, :, :, None] / dim_tx
116 |
117 | dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
118 | dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
119 | pos_y = y_embed[:, :, :, None] / dim_ty
120 |
121 | pos_x = torch.stack(
122 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
123 | ).flatten(3)
124 | pos_y = torch.stack(
125 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
126 | ).flatten(3)
127 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
128 |
129 | # import ipdb; ipdb.set_trace()
130 |
131 | return pos
132 |
133 |
134 | class PositionEmbeddingLearned(nn.Module):
135 | """
136 | Absolute pos embedding, learned.
137 | """
138 |
139 | def __init__(self, num_pos_feats=256):
140 | super().__init__()
141 | self.row_embed = nn.Embedding(50, num_pos_feats)
142 | self.col_embed = nn.Embedding(50, num_pos_feats)
143 | self.reset_parameters()
144 |
145 | def reset_parameters(self):
146 | nn.init.uniform_(self.row_embed.weight)
147 | nn.init.uniform_(self.col_embed.weight)
148 |
149 | def forward(self, tensor_list: NestedTensor):
150 | x = tensor_list.tensors
151 | h, w = x.shape[-2:]
152 | i = torch.arange(w, device=x.device)
153 | j = torch.arange(h, device=x.device)
154 | x_emb = self.col_embed(i)
155 | y_emb = self.row_embed(j)
156 | pos = (
157 | torch.cat(
158 | [
159 | x_emb.unsqueeze(0).repeat(h, 1, 1),
160 | y_emb.unsqueeze(1).repeat(1, w, 1),
161 | ],
162 | dim=-1,
163 | )
164 | .permute(2, 0, 1)
165 | .unsqueeze(0)
166 | .repeat(x.shape[0], 1, 1, 1)
167 | )
168 | return pos
169 |
170 |
171 | def build_position_encoding(args):
172 | N_steps = args.hidden_dim // 2
173 | if args.position_embedding in ("v2", "sine"):
174 | # TODO find a better way of exposing other arguments
175 | position_embedding = PositionEmbeddingSineHW(
176 | N_steps,
177 | temperatureH=args.pe_temperatureH,
178 | temperatureW=args.pe_temperatureW,
179 | normalize=True,
180 | )
181 | elif args.position_embedding in ("v3", "learned"):
182 | position_embedding = PositionEmbeddingLearned(N_steps)
183 | else:
184 | raise ValueError(f"not supported {args.position_embedding}")
185 |
186 | return position_embedding
187 |
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 |
13 | #include "ms_deform_attn_cpu.h"
14 |
15 | #ifdef WITH_CUDA
16 | #include "ms_deform_attn_cuda.h"
17 | #endif
18 |
19 | namespace groundingdino {
20 |
21 | at::Tensor
22 | ms_deform_attn_forward(
23 | const at::Tensor &value,
24 | const at::Tensor &spatial_shapes,
25 | const at::Tensor &level_start_index,
26 | const at::Tensor &sampling_loc,
27 | const at::Tensor &attn_weight,
28 | const int im2col_step)
29 | {
30 | if (value.type().is_cuda())
31 | {
32 | #ifdef WITH_CUDA
33 | return ms_deform_attn_cuda_forward(
34 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
35 | #else
36 | AT_ERROR("Not compiled with GPU support");
37 | #endif
38 | }
39 | AT_ERROR("Not implemented on the CPU");
40 | }
41 |
42 | std::vector<at::Tensor>
43 | ms_deform_attn_backward(
44 | const at::Tensor &value,
45 | const at::Tensor &spatial_shapes,
46 | const at::Tensor &level_start_index,
47 | const at::Tensor &sampling_loc,
48 | const at::Tensor &attn_weight,
49 | const at::Tensor &grad_output,
50 | const int im2col_step)
51 | {
52 | if (value.type().is_cuda())
53 | {
54 | #ifdef WITH_CUDA
55 | return ms_deform_attn_cuda_backward(
56 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
57 | #else
58 | AT_ERROR("Not compiled with GPU support");
59 | #endif
60 | }
61 | AT_ERROR("Not implemented on the CPU");
62 | }
63 |
64 | } // namespace groundingdino
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include <vector>
12 |
13 | #include <ATen/ATen.h>
14 | #include <ATen/cuda/CUDAContext.h>
15 |
16 | namespace groundingdino {
17 |
18 | at::Tensor
19 | ms_deform_attn_cpu_forward(
20 | const at::Tensor &value,
21 | const at::Tensor &spatial_shapes,
22 | const at::Tensor &level_start_index,
23 | const at::Tensor &sampling_loc,
24 | const at::Tensor &attn_weight,
25 | const int im2col_step)
26 | {
27 | AT_ERROR("Not implemented on the CPU");
28 | }
29 |
30 | std::vector<at::Tensor>
31 | ms_deform_attn_cpu_backward(
32 | const at::Tensor &value,
33 | const at::Tensor &spatial_shapes,
34 | const at::Tensor &level_start_index,
35 | const at::Tensor &sampling_loc,
36 | const at::Tensor &attn_weight,
37 | const at::Tensor &grad_output,
38 | const int im2col_step)
39 | {
40 | AT_ERROR("Not implemented on the CPU");
41 | }
42 |
43 | } // namespace groundingdino
44 |
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 | #include <torch/extension.h>
13 |
14 | namespace groundingdino {
15 |
16 | at::Tensor
17 | ms_deform_attn_cpu_forward(
18 | const at::Tensor &value,
19 | const at::Tensor &spatial_shapes,
20 | const at::Tensor &level_start_index,
21 | const at::Tensor &sampling_loc,
22 | const at::Tensor &attn_weight,
23 | const int im2col_step);
24 |
25 | std::vector<at::Tensor>
26 | ms_deform_attn_cpu_backward(
27 | const at::Tensor &value,
28 | const at::Tensor &spatial_shapes,
29 | const at::Tensor &level_start_index,
30 | const at::Tensor &sampling_loc,
31 | const at::Tensor &attn_weight,
32 | const at::Tensor &grad_output,
33 | const int im2col_step);
34 |
35 | } // namespace groundingdino
36 |
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include <vector>
12 | #include "ms_deform_im2col_cuda.cuh"
13 |
14 | #include <ATen/ATen.h>
15 | #include <ATen/cuda/CUDAContext.h>
16 | #include <cuda.h>
17 | #include <cuda_runtime.h>
18 |
19 | namespace groundingdino {
20 |
21 | at::Tensor ms_deform_attn_cuda_forward(
22 | const at::Tensor &value,
23 | const at::Tensor &spatial_shapes,
24 | const at::Tensor &level_start_index,
25 | const at::Tensor &sampling_loc,
26 | const at::Tensor &attn_weight,
27 | const int im2col_step)
28 | {
29 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
30 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
31 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
32 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
33 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
34 |
35 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
36 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
37 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
38 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
39 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
40 |
41 | const int batch = value.size(0);
42 | const int spatial_size = value.size(1);
43 | const int num_heads = value.size(2);
44 | const int channels = value.size(3);
45 |
46 | const int num_levels = spatial_shapes.size(0);
47 |
48 | const int num_query = sampling_loc.size(1);
49 | const int num_point = sampling_loc.size(4);
50 |
51 | const int im2col_step_ = std::min(batch, im2col_step);
52 |
53 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
54 |
55 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
56 |
57 | const int batch_n = im2col_step_;
58 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
59 | auto per_value_size = spatial_size * num_heads * channels;
60 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
61 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
62 | for (int n = 0; n < batch/im2col_step_; ++n)
63 | {
64 | auto columns = output_n.select(0, n);
65 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
66 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
67 | value.data<scalar_t>() + n * im2col_step_ * per_value_size,
68 | spatial_shapes.data<int64_t>(),
69 | level_start_index.data<int64_t>(),
70 | sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
71 | attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
72 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
73 | columns.data<scalar_t>());
74 |
75 | }));
76 | }
77 |
78 | output = output.view({batch, num_query, num_heads*channels});
79 |
80 | return output;
81 | }
82 |
83 |
84 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
85 | const at::Tensor &value,
86 | const at::Tensor &spatial_shapes,
87 | const at::Tensor &level_start_index,
88 | const at::Tensor &sampling_loc,
89 | const at::Tensor &attn_weight,
90 | const at::Tensor &grad_output,
91 | const int im2col_step)
92 | {
93 |
94 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
95 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
96 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
97 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
98 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
99 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
100 |
101 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
102 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
103 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
104 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
105 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
106 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
107 |
108 | const int batch = value.size(0);
109 | const int spatial_size = value.size(1);
110 | const int num_heads = value.size(2);
111 | const int channels = value.size(3);
112 |
113 | const int num_levels = spatial_shapes.size(0);
114 |
115 | const int num_query = sampling_loc.size(1);
116 | const int num_point = sampling_loc.size(4);
117 |
118 | const int im2col_step_ = std::min(batch, im2col_step);
119 |
120 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
121 |
122 | auto grad_value = at::zeros_like(value);
123 | auto grad_sampling_loc = at::zeros_like(sampling_loc);
124 | auto grad_attn_weight = at::zeros_like(attn_weight);
125 |
126 | const int batch_n = im2col_step_;
127 | auto per_value_size = spatial_size * num_heads * channels;
128 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
129 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
130 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
131 |
132 | for (int n = 0; n < batch/im2col_step_; ++n)
133 | {
134 | auto grad_output_g = grad_output_n.select(0, n);
135 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
136 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
137 | grad_output_g.data<scalar_t>(),
138 | value.data<scalar_t>() + n * im2col_step_ * per_value_size,
139 | spatial_shapes.data<int64_t>(),
140 | level_start_index.data<int64_t>(),
141 | sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
142 | attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
143 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
144 | grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
145 | grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
146 | grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
147 |
148 | }));
149 | }
150 |
151 | return {
152 | grad_value, grad_sampling_loc, grad_attn_weight
153 | };
154 | }
155 |
156 | } // namespace groundingdino
--------------------------------------------------------------------------------
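The wrapper above reads all of its shape information from the tensors themselves, so the calling convention is: value is (batch, spatial_size, num_heads, channels), sampling_loc is (batch, num_query, num_heads, num_levels, num_point, 2), attn_weight is (batch, num_query, num_heads, num_levels, num_point), and spatial_shapes/level_start_index are int64 tensors over the feature levels. Below is a minimal smoke test of the compiled op, assuming the extension was built as groundingdino._C (see setup.py later in this repo) and a CUDA device is available.

    import torch
    import groundingdino._C as _C  # compiled extension; requires a CUDA build

    bs, num_heads, channels, num_query, num_point = 2, 8, 32, 100, 4
    spatial_shapes = torch.as_tensor([[64, 64], [32, 32]], dtype=torch.long, device="cuda")
    level_start_index = torch.cat(
        (spatial_shapes.new_zeros(1), spatial_shapes.prod(1).cumsum(0)[:-1])
    )
    spatial_size = int(spatial_shapes.prod(1).sum())
    num_levels = spatial_shapes.size(0)

    value = torch.rand(bs, spatial_size, num_heads, channels, device="cuda")
    sampling_loc = torch.rand(bs, num_query, num_heads, num_levels, num_point, 2, device="cuda")
    attn_weight = torch.rand(bs, num_query, num_heads, num_levels, num_point, device="cuda")
    attn_weight = attn_weight / attn_weight.sum(dim=(-1, -2), keepdim=True)  # normalize per query/head

    out = _C.ms_deform_attn_forward(
        value, spatial_shapes, level_start_index, sampling_loc, attn_weight, 64
    )
    print(out.shape)  # torch.Size([2, 100, 256]) == (bs, num_query, num_heads * channels)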
/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 | #include <torch/extension.h>
13 |
14 | namespace groundingdino {
15 |
16 | at::Tensor ms_deform_attn_cuda_forward(
17 | const at::Tensor &value,
18 | const at::Tensor &spatial_shapes,
19 | const at::Tensor &level_start_index,
20 | const at::Tensor &sampling_loc,
21 | const at::Tensor &attn_weight,
22 | const int im2col_step);
23 |
24 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
25 | const at::Tensor &value,
26 | const at::Tensor &spatial_shapes,
27 | const at::Tensor &level_start_index,
28 | const at::Tensor &sampling_loc,
29 | const at::Tensor &attn_weight,
30 | const at::Tensor &grad_output,
31 | const int im2col_step);
32 |
33 | } // namespace groundingdino
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/models/GroundingDINO/csrc/cuda_version.cu:
--------------------------------------------------------------------------------
1 | #include <cuda_runtime_api.h>
2 |
3 | namespace groundingdino {
4 | int get_cudart_version() {
5 | return CUDART_VERSION;
6 | }
7 | } // namespace groundingdino
8 |
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/models/GroundingDINO/csrc/vision.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | #include "MsDeformAttn/ms_deform_attn.h"
4 |
5 | namespace groundingdino {
6 |
7 | #ifdef WITH_CUDA
8 | extern int get_cudart_version();
9 | #endif
10 |
11 | std::string get_cuda_version() {
12 | #ifdef WITH_CUDA
13 | std::ostringstream oss;
14 |
15 | // copied from
16 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
17 | auto printCudaStyleVersion = [&](int v) {
18 | oss << (v / 1000) << "." << (v / 10 % 100);
19 | if (v % 10 != 0) {
20 | oss << "." << (v % 10);
21 | }
22 | };
23 | printCudaStyleVersion(get_cudart_version());
24 | return oss.str();
25 | #else
26 | return std::string("not available");
27 | #endif
28 | }
29 |
30 | // similar to
31 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp
32 | std::string get_compiler_version() {
33 | std::ostringstream ss;
34 | #if defined(__GNUC__)
35 | #ifndef __clang__
36 | { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; }
37 | #endif
38 | #endif
39 |
40 | #if defined(__clang_major__)
41 | {
42 | ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
43 | << __clang_patchlevel__;
44 | }
45 | #endif
46 |
47 | #if defined(_MSC_VER)
48 | { ss << "MSVC " << _MSC_FULL_VER; }
49 | #endif
50 | return ss.str();
51 | }
52 |
53 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
54 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
55 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
56 | }
57 |
58 | } // namespace groundingdino
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/models/GroundingDINO/transformer_vanilla.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Grounding DINO
3 | # url: https://github.com/IDEA-Research/GroundingDINO
4 | # Copyright (c) 2023 IDEA. All Rights Reserved.
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | # ------------------------------------------------------------------------
7 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
9 | """
10 | DETR Transformer class.
11 |
12 | Copy-paste from torch.nn.Transformer with modifications:
13 | * positional encodings are passed in MHattention
14 | * extra LN at the end of encoder is removed
15 | * decoder returns a stack of activations from all decoding layers
16 | """
17 | from typing import Optional
18 |
19 | import torch
20 | import torch.nn.functional as F
21 | from torch import Tensor, nn
22 |
23 | from .utils import (
24 | MLP,
25 | _get_activation_fn,
26 | _get_clones,
27 | gen_encoder_output_proposals,
28 | gen_sineembed_for_position,
29 | sigmoid_focal_loss,
30 | )
31 |
32 |
33 | class TextTransformer(nn.Module):
34 | def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1):
35 | super().__init__()
36 | self.num_layers = num_layers
37 | self.d_model = d_model
38 | self.nheads = nheads
39 | self.dim_feedforward = dim_feedforward
40 | self.norm = None
41 |
42 | single_encoder_layer = TransformerEncoderLayer(
43 | d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout
44 | )
45 | self.layers = _get_clones(single_encoder_layer, num_layers)
46 |
47 | def forward(self, memory_text: torch.Tensor, text_attention_mask: torch.Tensor):
48 | """
49 |
50 | Args:
51 | text_attention_mask: bs, num_token
52 | memory_text: bs, num_token, d_model
53 |
54 | Raises:
55 | RuntimeError: _description_
56 |
57 | Returns:
58 | output: bs, num_token, d_model
59 | """
60 |
61 | output = memory_text.transpose(0, 1)
62 |
63 | for layer in self.layers:
64 | output = layer(output, src_key_padding_mask=text_attention_mask)
65 |
66 | if self.norm is not None:
67 | output = self.norm(output)
68 |
69 | return output.transpose(0, 1)
70 |
71 |
72 | class TransformerEncoderLayer(nn.Module):
73 | def __init__(
74 | self,
75 | d_model,
76 | nhead,
77 | dim_feedforward=2048,
78 | dropout=0.1,
79 | activation="relu",
80 | normalize_before=False,
81 | ):
82 | super().__init__()
83 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
84 | # Implementation of Feedforward model
85 | self.linear1 = nn.Linear(d_model, dim_feedforward)
86 | self.dropout = nn.Dropout(dropout)
87 | self.linear2 = nn.Linear(dim_feedforward, d_model)
88 |
89 | self.norm1 = nn.LayerNorm(d_model)
90 | self.norm2 = nn.LayerNorm(d_model)
91 | self.dropout1 = nn.Dropout(dropout)
92 | self.dropout2 = nn.Dropout(dropout)
93 |
94 | self.activation = _get_activation_fn(activation)
95 | self.normalize_before = normalize_before
96 | self.nhead = nhead
97 |
98 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
99 | return tensor if pos is None else tensor + pos
100 |
101 | def forward(
102 | self,
103 | src,
104 | src_mask: Optional[Tensor] = None,
105 | src_key_padding_mask: Optional[Tensor] = None,
106 | pos: Optional[Tensor] = None,
107 | ):
108 | # repeat attn mask
109 | if src_mask is not None and src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]:
110 | # bs, num_q, num_k
111 | src_mask = src_mask.repeat(self.nhead, 1, 1)
112 |
113 | q = k = self.with_pos_embed(src, pos)
114 |
115 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0]
116 |
117 | # src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
118 | src = src + self.dropout1(src2)
119 | src = self.norm1(src)
120 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
121 | src = src + self.dropout2(src2)
122 | src = self.norm2(src)
123 | return src
124 |
--------------------------------------------------------------------------------
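A minimal sketch of running the TextTransformer defined above on a dummy batch. Note that the key-padding mask is accepted but not currently forwarded to self-attention, and src_mask stays None (which the guard in TransformerEncoderLayer.forward tolerates).

    import torch
    from groundingdino.models.GroundingDINO.transformer_vanilla import TextTransformer

    text_encoder = TextTransformer(num_layers=2, d_model=256, nheads=8)
    memory_text = torch.rand(2, 10, 256)                        # (bs, num_token, d_model)
    text_attention_mask = torch.zeros(2, 10, dtype=torch.bool)  # False everywhere: no padding
    out = text_encoder(memory_text, text_attention_mask)
    print(out.shape)  # torch.Size([2, 10, 256])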
/GroundingDINO/groundingdino/models/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Grounding DINO
3 | # url: https://github.com/IDEA-Research/GroundingDINO
4 | # Copyright (c) 2023 IDEA. All Rights Reserved.
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | # ------------------------------------------------------------------------
7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8 | from .GroundingDINO import build_groundingdino
9 |
10 |
11 | def build_model(args):
12 | # we use a registry to maintain models from catdet6 onward.
13 | from .registry import MODULE_BUILD_FUNCS
14 |
15 | assert args.modelname in MODULE_BUILD_FUNCS._module_dict
16 | build_func = MODULE_BUILD_FUNCS.get(args.modelname)
17 | model = build_func(args)
18 | return model
19 |
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/models/registry.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Grounding DINO
3 | # url: https://github.com/IDEA-Research/GroundingDINO
4 | # Copyright (c) 2023 IDEA. All Rights Reserved.
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | # ------------------------------------------------------------------------
7 | # -*- coding: utf-8 -*-
8 | # @Author: Yihao Chen
9 | # @Date: 2021-08-16 16:03:17
10 | # @Last Modified by: Shilong Liu
11 | # @Last Modified time: 2022-01-23 15:26
12 | # modified from mmcv
13 |
14 | import inspect
15 | from functools import partial
16 |
17 |
18 | class Registry(object):
19 | def __init__(self, name):
20 | self._name = name
21 | self._module_dict = dict()
22 |
23 | def __repr__(self):
24 | format_str = self.__class__.__name__ + "(name={}, items={})".format(
25 | self._name, list(self._module_dict.keys())
26 | )
27 | return format_str
28 |
29 | def __len__(self):
30 | return len(self._module_dict)
31 |
32 | @property
33 | def name(self):
34 | return self._name
35 |
36 | @property
37 | def module_dict(self):
38 | return self._module_dict
39 |
40 | def get(self, key):
41 | return self._module_dict.get(key, None)
42 |
43 | def registe_with_name(self, module_name=None, force=False):
44 | return partial(self.register, module_name=module_name, force=force)
45 |
46 | def register(self, module_build_function, module_name=None, force=False):
47 | """Register a module build function.
48 | Args:
49 | module (:obj:`nn.Module`): Module to be registered.
50 | """
51 | if not inspect.isfunction(module_build_function):
52 | raise TypeError(
53 | "module_build_function must be a function, but got {}".format(
54 | type(module_build_function)
55 | )
56 | )
57 | if module_name is None:
58 | module_name = module_build_function.__name__
59 | if not force and module_name in self._module_dict:
60 | raise KeyError("{} is already registered in {}".format(module_name, self.name))
61 | self._module_dict[module_name] = module_build_function
62 |
63 | return module_build_function
64 |
65 |
66 | MODULE_BUILD_FUNCS = Registry("model build functions")
67 |
--------------------------------------------------------------------------------
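A sketch of how this registry is used: build functions register themselves under a name, and build_model() in models/__init__.py looks that name up via args.modelname. The "toy_model" name below is purely illustrative.

    from groundingdino.models.registry import MODULE_BUILD_FUNCS

    @MODULE_BUILD_FUNCS.registe_with_name(module_name="toy_model")  # hypothetical name
    def build_toy_model(args=None):
        # placeholder build function; the real one constructs and returns the detector
        return {"name": "toy_model", "args": args}

    print(MODULE_BUILD_FUNCS)                     # items now include 'toy_model'
    model = MODULE_BUILD_FUNCS.get("toy_model")()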
/GroundingDINO/groundingdino/util/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
--------------------------------------------------------------------------------
/GroundingDINO/groundingdino/util/box_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Utilities for bounding box manipulation and GIoU.
4 | """
5 | import torch
6 | from torchvision.ops.boxes import box_area
7 |
8 |
9 | def box_cxcywh_to_xyxy(x):
10 | x_c, y_c, w, h = x.unbind(-1)
11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
12 | return torch.stack(b, dim=-1)
13 |
14 |
15 | def box_xyxy_to_cxcywh(x):
16 | x0, y0, x1, y1 = x.unbind(-1)
17 | b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
18 | return torch.stack(b, dim=-1)
19 |
20 |
21 | # modified from torchvision to also return the union
22 | def box_iou(boxes1, boxes2):
23 | area1 = box_area(boxes1)
24 | area2 = box_area(boxes2)
25 |
26 | # import ipdb; ipdb.set_trace()
27 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
28 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
29 |
30 | wh = (rb - lt).clamp(min=0) # [N,M,2]
31 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
32 |
33 | union = area1[:, None] + area2 - inter
34 |
35 | iou = inter / (union + 1e-6)
36 | return iou, union
37 |
38 |
39 | def generalized_box_iou(boxes1, boxes2):
40 | """
41 | Generalized IoU from https://giou.stanford.edu/
42 |
43 | The boxes should be in [x0, y0, x1, y1] format
44 |
45 | Returns a [N, M] pairwise matrix, where N = len(boxes1)
46 | and M = len(boxes2)
47 | """
48 | # degenerate boxes gives inf / nan results
49 | # so do an early check
50 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
51 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
52 | # except:
53 | # import ipdb; ipdb.set_trace()
54 | iou, union = box_iou(boxes1, boxes2)
55 |
56 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
57 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
58 |
59 | wh = (rb - lt).clamp(min=0) # [N,M,2]
60 | area = wh[:, :, 0] * wh[:, :, 1]
61 |
62 | return iou - (area - union) / (area + 1e-6)
63 |
64 |
65 | # modified from torchvision to also return the union
66 | def box_iou_pairwise(boxes1, boxes2):
67 | area1 = box_area(boxes1)
68 | area2 = box_area(boxes2)
69 |
70 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2]
71 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2]
72 |
73 | wh = (rb - lt).clamp(min=0) # [N,2]
74 | inter = wh[:, 0] * wh[:, 1] # [N]
75 |
76 | union = area1 + area2 - inter
77 |
78 | iou = inter / union
79 | return iou, union
80 |
81 |
82 | def generalized_box_iou_pairwise(boxes1, boxes2):
83 | """
84 | Generalized IoU from https://giou.stanford.edu/
85 |
86 | Input:
87 | - boxes1, boxes2: N,4
88 | Output:
89 | - giou: N, 4
90 | """
91 | # degenerate boxes gives inf / nan results
92 | # so do an early check
93 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
94 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
95 | assert boxes1.shape == boxes2.shape
96 | iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4
97 |
98 | lt = torch.min(boxes1[:, :2], boxes2[:, :2])
99 | rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])
100 |
101 | wh = (rb - lt).clamp(min=0) # [N,2]
102 | area = wh[:, 0] * wh[:, 1]
103 |
104 | return iou - (area - union) / area
105 |
106 |
107 | def masks_to_boxes(masks):
108 | """Compute the bounding boxes around the provided masks
109 |
110 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
111 |
112 | Returns a [N, 4] tensors, with the boxes in xyxy format
113 | """
114 | if masks.numel() == 0:
115 | return torch.zeros((0, 4), device=masks.device)
116 |
117 | h, w = masks.shape[-2:]
118 |
119 | y = torch.arange(0, h, dtype=torch.float)
120 | x = torch.arange(0, w, dtype=torch.float)
121 | y, x = torch.meshgrid(y, x)
122 |
123 | x_mask = masks * x.unsqueeze(0)
124 | x_max = x_mask.flatten(1).max(-1)[0]
125 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
126 |
127 | y_mask = masks * y.unsqueeze(0)
128 | y_max = y_mask.flatten(1).max(-1)[0]
129 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
130 |
131 | return torch.stack([x_min, y_min, x_max, y_max], 1)
132 |
133 |
134 | if __name__ == "__main__":
135 | x = torch.rand(5, 4)
136 | y = torch.rand(3, 4)
137 | iou, union = box_iou(x, y)
138 | import ipdb
139 |
140 | ipdb.set_trace()
141 |
--------------------------------------------------------------------------------
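A tiny worked example of the IoU/GIoU helpers above: identical boxes give IoU = GIoU = 1, while disjoint boxes give IoU = 0 and a negative GIoU, because the enclosing-box penalty kicks in.

    import torch
    from groundingdino.util.box_ops import box_iou, generalized_box_iou

    a = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
    b = torch.tensor([[0.0, 0.0, 2.0, 2.0],    # identical box
                      [3.0, 3.0, 5.0, 5.0]])   # disjoint box

    iou, _ = box_iou(a, b)
    giou = generalized_box_iou(a, b)
    print(iou)   # ~[[1.00, 0.00]]
    print(giou)  # ~[[1.00, -0.68]]  (enclosing area 25, union 8 -> 0 - 17/25)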
/GroundingDINO/groundingdino/util/get_tokenlizer.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
2 |
3 |
4 | def get_tokenlizer(text_encoder_type):
5 | if not isinstance(text_encoder_type, str):
6 | # print("text_encoder_type is not a str")
7 | if hasattr(text_encoder_type, "text_encoder_type"):
8 | text_encoder_type = text_encoder_type.text_encoder_type
9 | elif text_encoder_type.get("text_encoder_type", False):
10 | text_encoder_type = text_encoder_type.get("text_encoder_type")
11 | else:
12 | raise ValueError(
13 | "Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
14 | )
15 | print("final text_encoder_type: {}".format(text_encoder_type))
16 |
17 | tokenizer = AutoTokenizer.from_pretrained(text_encoder_type)
18 | return tokenizer
19 |
20 |
21 | def get_pretrained_language_model(text_encoder_type):
22 | if text_encoder_type == "bert-base-uncased":
23 | return BertModel.from_pretrained(text_encoder_type)
24 | if text_encoder_type == "roberta-base":
25 | return RobertaModel.from_pretrained(text_encoder_type)
26 | raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))
27 |
--------------------------------------------------------------------------------
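A minimal sketch of the two helpers above, assuming network access (or a local HuggingFace cache) for the "bert-base-uncased" weights, which is the text encoder named in the default GroundingDINO_SwinT_OGC config.

    from groundingdino.util.get_tokenlizer import get_tokenlizer, get_pretrained_language_model

    tokenizer = get_tokenlizer("bert-base-uncased")
    bert = get_pretrained_language_model("bert-base-uncased")
    tokens = tokenizer("a cat . a dog .", return_tensors="pt")
    print(tokens["input_ids"].shape)  # (1, num_tokens)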
/GroundingDINO/groundingdino/util/inference.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, List
2 |
3 | import cv2
4 | import numpy as np
5 | import supervision as sv
6 | import torch
7 | from PIL import Image
8 | from torchvision.ops import box_convert
9 |
10 | import groundingdino.datasets.transforms as T
11 | from groundingdino.models import build_model
12 | from groundingdino.util.misc import clean_state_dict
13 | from groundingdino.util.slconfig import SLConfig
14 | from groundingdino.util.utils import get_phrases_from_posmap
15 |
16 |
17 | def preprocess_caption(caption: str) -> str:
18 | result = caption.lower().strip()
19 | if result.endswith("."):
20 | return result
21 | return result + "."
22 |
23 |
24 | def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
25 | args = SLConfig.fromfile(model_config_path)
26 | args.device = device
27 | model = build_model(args)
28 | checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
29 | model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
30 | model.eval()
31 | return model
32 |
33 |
34 | def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
35 | transform = T.Compose(
36 | [
37 | T.RandomResize([800], max_size=1333),
38 | T.ToTensor(),
39 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
40 | ]
41 | )
42 | image_source = Image.open(image_path).convert("RGB")
43 | image = np.asarray(image_source)
44 | image_transformed, _ = transform(image_source, None)
45 | return image, image_transformed
46 |
47 |
48 | def predict(
49 | model,
50 | image: torch.Tensor,
51 | caption: str,
52 | box_threshold: float,
53 | text_threshold: float,
54 | device: str = "cuda"
55 | ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
56 | caption = preprocess_caption(caption=caption)
57 |
58 | model = model.to(device)
59 | image = image.to(device)
60 |
61 | with torch.no_grad():
62 | outputs = model(image[None], captions=[caption])
63 |
64 | prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256)
65 | prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4)
66 |
67 | mask = prediction_logits.max(dim=1)[0] > box_threshold
68 | logits = prediction_logits[mask] # logits.shape = (n, 256)
69 | boxes = prediction_boxes[mask] # boxes.shape = (n, 4)
70 |
71 | tokenizer = model.tokenizer
72 | tokenized = tokenizer(caption)
73 |
74 | phrases = [
75 | get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
76 | for logit
77 | in logits
78 | ]
79 |
80 | return boxes, logits.max(dim=1)[0], phrases
81 |
82 |
83 | def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
84 | h, w, _ = image_source.shape
85 | boxes = boxes * torch.Tensor([w, h, w, h])
86 | xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
87 | detections = sv.Detections(xyxy=xyxy)
88 |
89 | labels = [
90 | f"{phrase} {logit:.2f}"
91 | for phrase, logit
92 | in zip(phrases, logits)
93 | ]
94 |
95 | box_annotator = sv.BoxAnnotator()
96 | annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
97 | annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
98 | return annotated_frame
99 |
--------------------------------------------------------------------------------
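A minimal end-to-end sketch of the helpers above; the checkpoint path, image path, caption and thresholds are placeholders to substitute with your own.

    import cv2
    from groundingdino.util.inference import load_model, load_image, predict, annotate

    model = load_model(
        "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
        "weights/groundingdino_swint_ogc.pth",   # downloaded checkpoint (placeholder path)
    )
    image_source, image = load_image("assets/demo1.jpg")
    boxes, logits, phrases = predict(
        model=model, image=image, caption="chair . person . dog .",
        box_threshold=0.35, text_threshold=0.25,
    )
    annotated = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
    cv2.imwrite("annotated_demo1.jpg", annotated)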
/GroundingDINO/groundingdino/util/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import functools
3 | import logging
4 | import os
5 | import sys
6 |
7 | from termcolor import colored
8 |
9 |
10 | class _ColorfulFormatter(logging.Formatter):
11 | def __init__(self, *args, **kwargs):
12 | self._root_name = kwargs.pop("root_name") + "."
13 | self._abbrev_name = kwargs.pop("abbrev_name", "")
14 | if len(self._abbrev_name):
15 | self._abbrev_name = self._abbrev_name + "."
16 | super(_ColorfulFormatter, self).__init__(*args, **kwargs)
17 |
18 | def formatMessage(self, record):
19 | record.name = record.name.replace(self._root_name, self._abbrev_name)
20 | log = super(_ColorfulFormatter, self).formatMessage(record)
21 | if record.levelno == logging.WARNING:
22 | prefix = colored("WARNING", "red", attrs=["blink"])
23 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
24 | prefix = colored("ERROR", "red", attrs=["blink", "underline"])
25 | else:
26 | return log
27 | return prefix + " " + log
28 |
29 |
30 | # so that calling setup_logger multiple times won't add many handlers
31 | @functools.lru_cache()
32 | def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None):
33 | """
34 | Initialize the detectron2 logger and set its verbosity level to "INFO".
35 |
36 | Args:
37 | output (str): a file name or a directory to save log. If None, will not save log file.
38 | If ends with ".txt" or ".log", assumed to be a file name.
39 | Otherwise, logs will be saved to `output/log.txt`.
40 | name (str): the root module name of this logger
41 |
42 | Returns:
43 | logging.Logger: a logger
44 | """
45 | logger = logging.getLogger(name)
46 | logger.setLevel(logging.DEBUG)
47 | logger.propagate = False
48 |
49 | if abbrev_name is None:
50 | abbrev_name = name
51 |
52 | plain_formatter = logging.Formatter(
53 | "[%(asctime)s.%(msecs)03d]: %(message)s", datefmt="%m/%d %H:%M:%S"
54 | )
55 | # stdout logging: master only
56 | if distributed_rank == 0:
57 | ch = logging.StreamHandler(stream=sys.stdout)
58 | ch.setLevel(logging.DEBUG)
59 | if color:
60 | formatter = _ColorfulFormatter(
61 | colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s",
62 | datefmt="%m/%d %H:%M:%S",
63 | root_name=name,
64 | abbrev_name=str(abbrev_name),
65 | )
66 | else:
67 | formatter = plain_formatter
68 | ch.setFormatter(formatter)
69 | logger.addHandler(ch)
70 |
71 | # file logging: all workers
72 | if output is not None:
73 | if output.endswith(".txt") or output.endswith(".log"):
74 | filename = output
75 | else:
76 | filename = os.path.join(output, "log.txt")
77 | if distributed_rank > 0:
78 | filename = filename + f".rank{distributed_rank}"
79 | os.makedirs(os.path.dirname(filename), exist_ok=True)
80 |
81 | fh = logging.StreamHandler(_cached_log_stream(filename))
82 | fh.setLevel(logging.DEBUG)
83 | fh.setFormatter(plain_formatter)
84 | logger.addHandler(fh)
85 |
86 | return logger
87 |
88 |
89 | # cache the opened file object, so that different calls to `setup_logger`
90 | # with the same file name can safely write to the same file.
91 | @functools.lru_cache(maxsize=None)
92 | def _cached_log_stream(filename):
93 | return open(filename, "a")
94 |
--------------------------------------------------------------------------------
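A minimal sketch of the logger helper above; when output is a directory, log lines also go to <output>/log.txt.

    from groundingdino.util.logger import setup_logger

    logger = setup_logger(output="logs", name="groundingdino", color=True)
    logger.info("model loaded")
    logger.warning("low box_threshold may produce noisy detections")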
/GroundingDINO/groundingdino/util/slio.py:
--------------------------------------------------------------------------------
1 | # ==========================================================
2 | # Modified from mmcv
3 | # ==========================================================
4 |
5 | import json
6 | import pickle
7 | from abc import ABCMeta, abstractmethod
8 | from pathlib import Path
9 |
10 | import yaml
11 |
12 | try:
13 | from yaml import CLoader as Loader, CDumper as Dumper
14 | except ImportError:
15 | from yaml import Loader, Dumper
16 |
17 |
18 | # ===========================
19 | # Register handlers
20 | # ===========================
21 |
22 |
23 | class BaseFileHandler(metaclass=ABCMeta):
24 | @abstractmethod
25 | def load_from_fileobj(self, file, **kwargs):
26 | pass
27 |
28 | @abstractmethod
29 | def dump_to_fileobj(self, obj, file, **kwargs):
30 | pass
31 |
32 | @abstractmethod
33 | def dump_to_str(self, obj, **kwargs):
34 | pass
35 |
36 | def load_from_path(self, filepath, mode="r", **kwargs):
37 | with open(filepath, mode) as f:
38 | return self.load_from_fileobj(f, **kwargs)
39 |
40 | def dump_to_path(self, obj, filepath, mode="w", **kwargs):
41 | with open(filepath, mode) as f:
42 | self.dump_to_fileobj(obj, f, **kwargs)
43 |
44 |
45 | class JsonHandler(BaseFileHandler):
46 | def load_from_fileobj(self, file):
47 | return json.load(file)
48 |
49 | def dump_to_fileobj(self, obj, file, **kwargs):
50 | json.dump(obj, file, **kwargs)
51 |
52 | def dump_to_str(self, obj, **kwargs):
53 | return json.dumps(obj, **kwargs)
54 |
55 |
56 | class PickleHandler(BaseFileHandler):
57 | def load_from_fileobj(self, file, **kwargs):
58 | return pickle.load(file, **kwargs)
59 |
60 | def load_from_path(self, filepath, **kwargs):
61 | return super(PickleHandler, self).load_from_path(filepath, mode="rb", **kwargs)
62 |
63 | def dump_to_str(self, obj, **kwargs):
64 | kwargs.setdefault("protocol", 2)
65 | return pickle.dumps(obj, **kwargs)
66 |
67 | def dump_to_fileobj(self, obj, file, **kwargs):
68 | kwargs.setdefault("protocol", 2)
69 | pickle.dump(obj, file, **kwargs)
70 |
71 | def dump_to_path(self, obj, filepath, **kwargs):
72 | super(PickleHandler, self).dump_to_path(obj, filepath, mode="wb", **kwargs)
73 |
74 |
75 | class YamlHandler(BaseFileHandler):
76 | def load_from_fileobj(self, file, **kwargs):
77 | kwargs.setdefault("Loader", Loader)
78 | return yaml.load(file, **kwargs)
79 |
80 | def dump_to_fileobj(self, obj, file, **kwargs):
81 | kwargs.setdefault("Dumper", Dumper)
82 | yaml.dump(obj, file, **kwargs)
83 |
84 | def dump_to_str(self, obj, **kwargs):
85 | kwargs.setdefault("Dumper", Dumper)
86 | return yaml.dump(obj, **kwargs)
87 |
88 |
89 | file_handlers = {
90 | "json": JsonHandler(),
91 | "yaml": YamlHandler(),
92 | "yml": YamlHandler(),
93 | "pickle": PickleHandler(),
94 | "pkl": PickleHandler(),
95 | }
96 |
97 | # ===========================
98 | # load and dump
99 | # ===========================
100 |
101 |
102 | def is_str(x):
103 | """Whether the input is an string instance.
104 |
105 | Note: This method is deprecated since python 2 is no longer supported.
106 | """
107 | return isinstance(x, str)
108 |
109 |
110 | def slload(file, file_format=None, **kwargs):
111 | """Load data from json/yaml/pickle files.
112 |
113 | This method provides a unified api for loading data from serialized files.
114 |
115 | Args:
116 | file (str or :obj:`Path` or file-like object): Filename or a file-like
117 | object.
118 | file_format (str, optional): If not specified, the file format will be
119 | inferred from the file extension, otherwise use the specified one.
120 | Currently supported formats include "json", "yaml/yml" and
121 | "pickle/pkl".
122 |
123 | Returns:
124 | The content from the file.
125 | """
126 | if isinstance(file, Path):
127 | file = str(file)
128 | if file_format is None and is_str(file):
129 | file_format = file.split(".")[-1]
130 | if file_format not in file_handlers:
131 | raise TypeError(f"Unsupported format: {file_format}")
132 |
133 | handler = file_handlers[file_format]
134 | if is_str(file):
135 | obj = handler.load_from_path(file, **kwargs)
136 | elif hasattr(file, "read"):
137 | obj = handler.load_from_fileobj(file, **kwargs)
138 | else:
139 | raise TypeError('"file" must be a filepath str or a file-object')
140 | return obj
141 |
142 |
143 | def sldump(obj, file=None, file_format=None, **kwargs):
144 | """Dump data to json/yaml/pickle strings or files.
145 |
146 | This method provides a unified api for dumping data as strings or to files,
147 | and also supports custom arguments for each file format.
148 |
149 | Args:
150 | obj (any): The python object to be dumped.
151 | file (str or :obj:`Path` or file-like object, optional): If not
152 | specified, then the object is dumped to a str, otherwise to a file
153 | specified by the filename or file-like object.
154 | file_format (str, optional): Same as :func:`load`.
155 |
156 | Returns:
157 | bool: True for success, False otherwise.
158 | """
159 | if isinstance(file, Path):
160 | file = str(file)
161 | if file_format is None:
162 | if is_str(file):
163 | file_format = file.split(".")[-1]
164 | elif file is None:
165 | raise ValueError("file_format must be specified since file is None")
166 | if file_format not in file_handlers:
167 | raise TypeError(f"Unsupported format: {file_format}")
168 |
169 | handler = file_handlers[file_format]
170 | if file is None:
171 | return handler.dump_to_str(obj, **kwargs)
172 | elif is_str(file):
173 | handler.dump_to_path(obj, file, **kwargs)
174 | elif hasattr(file, "write"):
175 | handler.dump_to_fileobj(obj, file, **kwargs)
176 | else:
177 | raise TypeError('"file" must be a filename str or a file-object')
178 |
--------------------------------------------------------------------------------
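A minimal sketch of the unified load/dump helpers above: the handler is picked from the file extension, or from file_format when dumping to a string.

    from groundingdino.util.slio import sldump, slload

    cfg = {"box_threshold": 0.35, "text_threshold": 0.25}

    sldump(cfg, "thresholds.json")          # handler picked from the ".json" extension
    print(slload("thresholds.json"))        # {'box_threshold': 0.35, 'text_threshold': 0.25}
    print(sldump(cfg, file_format="yaml"))  # no file given -> returns a YAML string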
/GroundingDINO/groundingdino/util/time_counter.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 |
4 |
5 | class TimeCounter:
6 | def __init__(self) -> None:
7 | pass
8 |
9 | def clear(self):
10 | self.timedict = {}
11 | self.basetime = time.perf_counter()
12 |
13 | def timeit(self, name):
14 | nowtime = time.perf_counter() - self.basetime
15 | self.timedict[name] = nowtime
16 | self.basetime = time.perf_counter()
17 |
18 |
19 | class TimeHolder:
20 | def __init__(self) -> None:
21 | self.timedict = {}
22 |
23 | def update(self, _timedict: dict):
24 | for k, v in _timedict.items():
25 | if k not in self.timedict:
26 | self.timedict[k] = AverageMeter(name=k, val_only=True)
27 | self.timedict[k].update(val=v)
28 |
29 | def final_res(self):
30 | return {k: v.avg for k, v in self.timedict.items()}
31 |
32 | def __str__(self):
33 | return json.dumps(self.final_res(), indent=2)
34 |
35 |
36 | class AverageMeter(object):
37 | """Computes and stores the average and current value"""
38 |
39 | def __init__(self, name, fmt=":f", val_only=False):
40 | self.name = name
41 | self.fmt = fmt
42 | self.val_only = val_only
43 | self.reset()
44 |
45 | def reset(self):
46 | self.val = 0
47 | self.avg = 0
48 | self.sum = 0
49 | self.count = 0
50 |
51 | def update(self, val, n=1):
52 | self.val = val
53 | self.sum += val * n
54 | self.count += n
55 | self.avg = self.sum / self.count
56 |
57 | def __str__(self):
58 | if self.val_only:
59 | fmtstr = "{name} {val" + self.fmt + "}"
60 | else:
61 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
62 | return fmtstr.format(**self.__dict__)
63 |
--------------------------------------------------------------------------------
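A minimal usage sketch; note that TimeCounter.clear() has to be called once before timeit(), because __init__ does not initialize timedict or basetime.

    import time
    from groundingdino.util.time_counter import TimeCounter, TimeHolder

    tc = TimeCounter()
    tc.clear()                 # creates timedict and the reference time
    time.sleep(0.1)
    tc.timeit("preprocess")
    time.sleep(0.2)
    tc.timeit("forward")

    holder = TimeHolder()
    holder.update(tc.timedict)
    print(holder)              # per-stage averages, rendered as JSON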
/GroundingDINO/groundingdino/util/vl_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from typing import List
4 |
5 | import torch
6 |
7 |
8 | def create_positive_map_from_span(tokenized, token_span, max_text_len=256):
9 | """construct a map such that positive_map[i,j] = True iff box i is associated to token j
10 | Input:
11 | - tokenized:
12 | - input_ids: Tensor[1, ntokens]
13 | - attention_mask: Tensor[1, ntokens]
14 | - token_span: list with length num_boxes.
15 | - each item: [start_idx, end_idx]
16 | """
17 | positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float)
18 | for j, tok_list in enumerate(token_span):
19 | for (beg, end) in tok_list:
20 | beg_pos = tokenized.char_to_token(beg)
21 | end_pos = tokenized.char_to_token(end - 1)
22 | if beg_pos is None:
23 | try:
24 | beg_pos = tokenized.char_to_token(beg + 1)
25 | if beg_pos is None:
26 | beg_pos = tokenized.char_to_token(beg + 2)
27 | except:
28 | beg_pos = None
29 | if end_pos is None:
30 | try:
31 | end_pos = tokenized.char_to_token(end - 2)
32 | if end_pos is None:
33 | end_pos = tokenized.char_to_token(end - 3)
34 | except:
35 | end_pos = None
36 | if beg_pos is None or end_pos is None:
37 | continue
38 |
39 | assert beg_pos is not None and end_pos is not None
40 | if os.environ.get("SHILONG_DEBUG_ONLY_ONE_POS", None) == "TRUE":
41 | positive_map[j, beg_pos] = 1
42 | break
43 | else:
44 | positive_map[j, beg_pos : end_pos + 1].fill_(1)
45 |
46 | return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
47 |
48 |
49 | def build_captions_and_token_span(cat_list, force_lowercase):
50 | """
51 | Return:
52 | captions: str
53 | cat2tokenspan: dict
54 | {
55 | 'dog': [[0, 2]],
56 | ...
57 | }
58 | """
59 |
60 | cat2tokenspan = {}
61 | captions = ""
62 | for catname in cat_list:
63 | class_name = catname
64 | if force_lowercase:
65 | class_name = class_name.lower()
66 | if "/" in class_name:
67 | class_name_list: List = class_name.strip().split("/")
68 | class_name_list.append(class_name)
69 | class_name: str = random.choice(class_name_list)
70 |
71 | tokens_positive_i = []
72 | subnamelist = [i.strip() for i in class_name.strip().split(" ")]
73 | for subname in subnamelist:
74 | if len(subname) == 0:
75 | continue
76 | if len(captions) > 0:
77 | captions = captions + " "
78 | strat_idx = len(captions)
79 | end_idx = strat_idx + len(subname)
80 | tokens_positive_i.append([strat_idx, end_idx])
81 | captions = captions + subname
82 |
83 | if len(tokens_positive_i) > 0:
84 | captions = captions + " ."
85 | cat2tokenspan[class_name] = tokens_positive_i
86 |
87 | return captions, cat2tokenspan
88 |
89 |
90 | def build_id2posspan_and_caption(category_dict: dict):
91 | """Build id2pos_span and caption from category_dict
92 |
93 | Args:
94 | category_dict (dict): category_dict
95 | """
96 | cat_list = [item["name"].lower() for item in category_dict]
97 | id2catname = {item["id"]: item["name"].lower() for item in category_dict}
98 | caption, cat2posspan = build_captions_and_token_span(cat_list, force_lowercase=True)
99 | id2posspan = {catid: cat2posspan[catname] for catid, catname in id2catname.items()}
100 | return id2posspan, caption
101 |
--------------------------------------------------------------------------------
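A minimal sketch tying these helpers to a tokenizer; "bert-base-uncased" is used here only for illustration, and any fast tokenizer with char_to_token would do.

    from transformers import AutoTokenizer
    from groundingdino.util.vl_utils import build_captions_and_token_span, create_positive_map_from_span

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    captions, cat2tokenspan = build_captions_and_token_span(["cat", "dog"], force_lowercase=True)
    print(captions)                                   # "cat . dog ."
    tokenized = tokenizer(captions, return_tensors="pt")
    token_spans = [cat2tokenspan["cat"], cat2tokenspan["dog"]]
    positive_map = create_positive_map_from_span(tokenized, token_spans)
    print(positive_map.shape)                         # torch.Size([2, 256])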
/GroundingDINO/groundingdino/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.1.0'
2 |
--------------------------------------------------------------------------------
/GroundingDINO/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | transformers
4 | addict
5 | yapf
6 | timm
7 | numpy
8 | opencv-python
9 | supervision==0.3.2
10 | pycocotools
--------------------------------------------------------------------------------
/GroundingDINO/setup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The IDEA Authors. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------------------------------
16 | # Modified from
17 | # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/setup.py
18 | # https://github.com/facebookresearch/detectron2/blob/main/setup.py
19 | # https://github.com/open-mmlab/mmdetection/blob/master/setup.py
20 | # https://github.com/Oneflow-Inc/libai/blob/main/setup.py
21 | # ------------------------------------------------------------------------------------------------
22 |
23 | import glob
24 | import os
25 | import subprocess
26 |
27 | import torch
28 | from setuptools import find_packages, setup
29 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
30 |
31 | # groundingdino version info
32 | version = "0.1.0"
33 | package_name = "groundingdino"
34 | cwd = os.path.dirname(os.path.abspath(__file__))
35 |
36 |
37 | sha = "Unknown"
38 | try:
39 | sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
40 | except Exception:
41 | pass
42 |
43 |
44 | def write_version_file():
45 | version_path = os.path.join(cwd, "groundingdino", "version.py")
46 | with open(version_path, "w") as f:
47 | f.write(f"__version__ = '{version}'\n")
48 | # f.write(f"git_version = {repr(sha)}\n")
49 |
50 |
51 | requirements = ["torch", "torchvision"]
52 |
53 | torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
54 |
55 |
56 | def get_extensions():
57 | this_dir = os.path.dirname(os.path.abspath(__file__))
58 | extensions_dir = os.path.join(this_dir, "groundingdino", "models", "GroundingDINO", "csrc")
59 |
60 | main_source = os.path.join(extensions_dir, "vision.cpp")
61 | sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
62 | source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob(
63 | os.path.join(extensions_dir, "*.cu")
64 | )
65 |
66 | sources = [main_source] + sources
67 |
68 | extension = CppExtension
69 |
70 | extra_compile_args = {"cxx": []}
71 | define_macros = []
72 |
73 | if torch.cuda.is_available() and CUDA_HOME is not None:
74 | print("Compiling with CUDA")
75 | extension = CUDAExtension
76 | sources += source_cuda
77 | define_macros += [("WITH_CUDA", None)]
78 | extra_compile_args["nvcc"] = [
79 | "-DCUDA_HAS_FP16=1",
80 | "-D__CUDA_NO_HALF_OPERATORS__",
81 | "-D__CUDA_NO_HALF_CONVERSIONS__",
82 | "-D__CUDA_NO_HALF2_OPERATORS__",
83 | ]
84 | else:
85 | print("Compiling without CUDA")
86 | define_macros += [("WITH_HIP", None)]
87 | extra_compile_args["nvcc"] = []
88 | return None
89 |
90 | sources = [os.path.join(extensions_dir, s) for s in sources]
91 | include_dirs = [extensions_dir]
92 |
93 | ext_modules = [
94 | extension(
95 | "groundingdino._C",
96 | sources,
97 | include_dirs=include_dirs,
98 | define_macros=define_macros,
99 | extra_compile_args=extra_compile_args,
100 | )
101 | ]
102 |
103 | return ext_modules
104 |
105 |
106 | def parse_requirements(fname="requirements.txt", with_version=True):
107 | """Parse the package dependencies listed in a requirements file but strips
108 | specific versioning information.
109 |
110 | Args:
111 | fname (str): path to requirements file
112 | with_version (bool, default=True): if True, include version specs
113 |
114 | Returns:
115 | List[str]: list of requirements items
116 |
117 | CommandLine:
118 | python -c "import setup; print(setup.parse_requirements())"
119 | """
120 | import re
121 | import sys
122 | from os.path import exists
123 |
124 | require_fpath = fname
125 |
126 | def parse_line(line):
127 | """Parse information from a line in a requirements text file."""
128 | if line.startswith("-r "):
129 | # Allow specifying requirements in other files
130 | target = line.split(" ")[1]
131 | for info in parse_require_file(target):
132 | yield info
133 | else:
134 | info = {"line": line}
135 | if line.startswith("-e "):
136 | info["package"] = line.split("#egg=")[1]
137 | elif "@git+" in line:
138 | info["package"] = line
139 | else:
140 | # Remove versioning from the package
141 | pat = "(" + "|".join([">=", "==", ">"]) + ")"
142 | parts = re.split(pat, line, maxsplit=1)
143 | parts = [p.strip() for p in parts]
144 |
145 | info["package"] = parts[0]
146 | if len(parts) > 1:
147 | op, rest = parts[1:]
148 | if ";" in rest:
149 | # Handle platform specific dependencies
150 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
151 | version, platform_deps = map(str.strip, rest.split(";"))
152 | info["platform_deps"] = platform_deps
153 | else:
154 | version = rest # NOQA
155 | info["version"] = (op, version)
156 | yield info
157 |
158 | def parse_require_file(fpath):
159 | with open(fpath, "r") as f:
160 | for line in f.readlines():
161 | line = line.strip()
162 | if line and not line.startswith("#"):
163 | for info in parse_line(line):
164 | yield info
165 |
166 | def gen_packages_items():
167 | if exists(require_fpath):
168 | for info in parse_require_file(require_fpath):
169 | parts = [info["package"]]
170 | if with_version and "version" in info:
171 | parts.extend(info["version"])
172 | if not sys.version.startswith("3.4"):
173 | # apparently package_deps are broken in 3.4
174 | platform_deps = info.get("platform_deps")
175 | if platform_deps is not None:
176 | parts.append(";" + platform_deps)
177 | item = "".join(parts)
178 | yield item
179 |
180 | packages = list(gen_packages_items())
181 | return packages
182 |
183 |
184 | if __name__ == "__main__":
185 | print(f"Building wheel {package_name}-{version}")
186 |
187 | with open("LICENSE", "r", encoding="utf-8") as f:
188 | license = f.read()
189 |
190 | write_version_file()
191 |
192 | setup(
193 | name="groundingdino",
194 | version="0.1.0",
195 | author="International Digital Economy Academy, Shilong Liu",
196 | url="https://github.com/IDEA-Research/GroundingDINO",
197 | description="open-set object detector",
198 | license=license,
199 | install_requires=parse_requirements("requirements.txt"),
200 | packages=find_packages(
201 | exclude=(
202 | "configs",
203 | "tests",
204 | )
205 | ),
206 | ext_modules=get_extensions(),
207 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
208 | )
209 |
--------------------------------------------------------------------------------
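The build is CUDA-aware: get_extensions() only emits the groundingdino._C extension when torch.cuda.is_available() and CUDA_HOME is set at build time. A quick post-install sanity check, assuming the package was installed (e.g. with pip install -e GroundingDINO) on a machine where CUDA was visible during the build:

    from groundingdino.version import __version__
    from groundingdino import _C   # present only if the CUDA extension was compiled

    print(__version__)             # 0.1.0
    print(_C.ms_deform_attn_forward)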
/assets/Grounded-SAM_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/Grounded-SAM_logo.png
--------------------------------------------------------------------------------
/assets/automatic_label_output/demo1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/automatic_label_output/demo1.jpg
--------------------------------------------------------------------------------
/assets/automatic_label_output/demo2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/automatic_label_output/demo2.jpg
--------------------------------------------------------------------------------
/assets/automatic_label_output/demo4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/automatic_label_output/demo4.jpg
--------------------------------------------------------------------------------
/assets/automatic_label_output/demo8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/automatic_label_output/demo8.jpg
--------------------------------------------------------------------------------
/assets/automatic_label_output_demo3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/automatic_label_output_demo3.jpg
--------------------------------------------------------------------------------
/assets/demo1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo1.jpg
--------------------------------------------------------------------------------
/assets/demo2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo2.jpg
--------------------------------------------------------------------------------
/assets/demo3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo3.jpg
--------------------------------------------------------------------------------
/assets/demo4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo4.jpg
--------------------------------------------------------------------------------
/assets/demo5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo5.jpg
--------------------------------------------------------------------------------
/assets/demo6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo6.jpg
--------------------------------------------------------------------------------
/assets/demo7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo7.jpg
--------------------------------------------------------------------------------
/assets/demo8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo8.jpg
--------------------------------------------------------------------------------
/assets/gradio_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/gradio_demo.png
--------------------------------------------------------------------------------
/assets/grounded_sam.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/grounded_sam.jpg
--------------------------------------------------------------------------------
/assets/grounded_sam2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/grounded_sam2.png
--------------------------------------------------------------------------------
/assets/grounded_sam_demo3_demo4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/grounded_sam_demo3_demo4.png
--------------------------------------------------------------------------------
/assets/grounded_sam_inpainting_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/grounded_sam_inpainting_demo.png
--------------------------------------------------------------------------------
/assets/grounded_sam_output_demo1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/grounded_sam_output_demo1.jpg
--------------------------------------------------------------------------------
/assets/grounding_dino_output_demo1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/grounding_dino_output_demo1.jpg
--------------------------------------------------------------------------------
/assets/inpaint_demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/inpaint_demo.jpg
--------------------------------------------------------------------------------
/grounded_sam_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import copy
4 |
5 | import numpy as np
6 | import json
7 | import torch
8 | from PIL import Image, ImageDraw, ImageFont
9 |
10 | # Grounding DINO
11 | import GroundingDINO.groundingdino.datasets.transforms as T
12 | from GroundingDINO.groundingdino.models import build_model
13 | from GroundingDINO.groundingdino.util import box_ops
14 | from GroundingDINO.groundingdino.util.slconfig import SLConfig
15 | from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
16 |
17 | # segment anything
18 | from segment_anything import build_sam, SamPredictor
19 | import cv2
20 | import numpy as np
21 | import matplotlib.pyplot as plt
22 |
23 |
24 | def load_image(image_path):
25 | # load image
26 | image_pil = Image.open(image_path).convert("RGB") # load image
27 |
28 | transform = T.Compose(
29 | [
30 | T.RandomResize([800], max_size=1333),
31 | T.ToTensor(),
32 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
33 | ]
34 | )
35 | image, _ = transform(image_pil, None) # 3, h, w
36 | return image_pil, image
37 |
38 |
39 | def load_model(model_config_path, model_checkpoint_path, device):
40 | args = SLConfig.fromfile(model_config_path)
41 | args.device = device
42 | model = build_model(args)
43 | checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
44 | load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
45 | print(load_res)
46 | _ = model.eval()
47 | return model
48 |
49 |
50 | def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
51 | caption = caption.lower()
52 | caption = caption.strip()
53 | if not caption.endswith("."):
54 | caption = caption + "."
55 | model = model.to(device)
56 | image = image.to(device)
57 | with torch.no_grad():
58 | outputs = model(image[None], captions=[caption])
59 | logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
60 | boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
61 | logits.shape[0]
62 |
63 | # filter output
64 | logits_filt = logits.clone()
65 | boxes_filt = boxes.clone()
66 | filt_mask = logits_filt.max(dim=1)[0] > box_threshold
67 | logits_filt = logits_filt[filt_mask] # num_filt, 256
68 | boxes_filt = boxes_filt[filt_mask] # num_filt, 4
69 | logits_filt.shape[0]
70 |
71 | # get phrase
72 | tokenlizer = model.tokenizer
73 | tokenized = tokenlizer(caption)
74 | # build pred
75 | pred_phrases = []
76 | for logit, box in zip(logits_filt, boxes_filt):
77 | pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
78 | if with_logits:
79 | pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
80 | else:
81 | pred_phrases.append(pred_phrase)
82 |
83 | return boxes_filt, pred_phrases
84 |
85 | def show_mask(mask, ax, random_color=False):
86 | if random_color:
87 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
88 | else:
89 | color = np.array([30/255, 144/255, 255/255, 0.6])
90 | h, w = mask.shape[-2:]
91 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
92 | ax.imshow(mask_image)
93 |
94 |
95 | def show_box(box, ax, label):
96 | x0, y0 = box[0], box[1]
97 | w, h = box[2] - box[0], box[3] - box[1]
98 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
99 | ax.text(x0, y0, label)
100 |
101 |
102 | def save_mask_data(output_dir, mask_list, box_list, label_list):
103 | value = 0 # 0 for background
104 |
105 | mask_img = torch.zeros(mask_list.shape[-2:])
106 | for idx, mask in enumerate(mask_list):
107 | mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
108 | plt.figure(figsize=(10, 10))
109 | plt.imshow(mask_img.numpy())
110 | plt.axis('off')
111 | plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)
112 |
113 | json_data = [{
114 | 'value': value,
115 | 'label': 'background'
116 | }]
117 | for label, box in zip(label_list, box_list):
118 | value += 1
119 | name, logit = label.split('(')
120 | logit = logit[:-1] # the last is ')'
121 | json_data.append({
122 | 'value': value,
123 | 'label': name,
124 | 'logit': float(logit),
125 | 'box': box.numpy().tolist(),
126 | })
127 | with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
128 | json.dump(json_data, f)
129 |
130 |
131 | if __name__ == "__main__":
132 |
133 | parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
134 | parser.add_argument("--config", type=str, required=True, help="path to config file")
135 | parser.add_argument(
136 | "--grounded_checkpoint", type=str, required=True, help="path to GroundingDINO checkpoint file"
137 | )
138 | parser.add_argument(
139 | "--sam_checkpoint", type=str, required=True, help="path to SAM checkpoint file"
140 | )
141 | parser.add_argument("--input_image", type=str, required=True, help="path to image file")
142 | parser.add_argument("--text_prompt", type=str, required=True, help="text prompt")
143 | parser.add_argument(
144 | "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
145 | )
146 |
147 | parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
148 | parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
149 |
150 | parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
151 | args = parser.parse_args()
152 |
153 | # cfg
154 | config_file = args.config # change the path of the model config file
155 | grounded_checkpoint = args.grounded_checkpoint # change the path of the model
156 | sam_checkpoint = args.sam_checkpoint
157 | image_path = args.input_image
158 | text_prompt = args.text_prompt
159 | output_dir = args.output_dir
160 | box_threshold = args.box_threshold
161 | text_threshold = args.text_threshold
162 | device = args.device
163 |
164 | # make dir
165 | os.makedirs(output_dir, exist_ok=True)
166 | # load image
167 | image_pil, image = load_image(image_path)
168 | # load model
169 | model = load_model(config_file, grounded_checkpoint, device=device)
170 |
171 | # visualize raw image
172 | image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
173 |
174 | # run grounding dino model
175 | boxes_filt, pred_phrases = get_grounding_output(
176 | model, image, text_prompt, box_threshold, text_threshold, device=device
177 | )
178 |
179 | # initialize SAM
180 | predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
181 | image = cv2.imread(image_path)
182 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
183 | predictor.set_image(image)
184 |
185 | size = image_pil.size
186 | H, W = size[1], size[0]
187 | for i in range(boxes_filt.size(0)):
188 | boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
189 | boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
190 | boxes_filt[i][2:] += boxes_filt[i][:2]
191 |
192 | boxes_filt = boxes_filt.cpu()
193 | transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
194 |
195 | masks, _, _ = predictor.predict_torch(
196 | point_coords = None,
197 | point_labels = None,
198 | boxes = transformed_boxes,
199 | multimask_output = False,
200 | )
201 |
202 | # draw output image
203 | plt.figure(figsize=(10, 10))
204 | plt.imshow(image)
205 | for mask in masks:
206 | show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
207 | for box, label in zip(boxes_filt, pred_phrases):
208 | show_box(box.numpy(), plt.gca(), label)
209 |
210 | plt.axis('off')
211 | plt.savefig(
212 | os.path.join(output_dir, "grounded_sam_output.jpg"),
213 | bbox_inches="tight", dpi=300, pad_inches=0.0
214 | )
215 |
216 | save_mask_data(output_dir, masks, boxes_filt, pred_phrases)
217 |
218 |
--------------------------------------------------------------------------------
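For quick reference, the helpers in `grounded_sam_demo.py` can also be driven from Python instead of the CLI. The sketch below assumes it is run from the repo root; the checkpoint paths, the demo image, and the text prompt are placeholders, not values prescribed by the repo.

```python
# Minimal programmatic sketch of the grounded_sam_demo.py pipeline (placeholder paths).
import cv2
import torch
from segment_anything import build_sam, SamPredictor
from grounded_sam_demo import load_image, load_model, get_grounding_output

CONFIG = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
DINO_CKPT = "groundingdino_swint_ogc.pth"  # placeholder checkpoint path
SAM_CKPT = "sam_vit_h_4b8939.pth"          # placeholder checkpoint path
IMAGE = "assets/demo1.jpg"

image_pil, image_tensor = load_image(IMAGE)
dino = load_model(CONFIG, DINO_CKPT, device="cpu")
boxes, phrases = get_grounding_output(dino, image_tensor, "dog", 0.3, 0.25)  # placeholder prompt

# Same conversion as the demo: normalized (cx, cy, w, h) -> absolute (x0, y0, x1, y1) pixels.
W, H = image_pil.size
boxes = boxes * torch.tensor([W, H, W, H], dtype=boxes.dtype)
boxes[:, :2] -= boxes[:, 2:] / 2
boxes[:, 2:] += boxes[:, :2]

predictor = SamPredictor(build_sam(checkpoint=SAM_CKPT))
bgr = cv2.imread(IMAGE)
predictor.set_image(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
transformed = predictor.transform.apply_boxes_torch(boxes, bgr.shape[:2])
masks, _, _ = predictor.predict_torch(
    point_coords=None, point_labels=None, boxes=transformed, multimask_output=False
)
print(phrases, masks.shape)
```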
/grounded_sam_inpainting_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import copy
4 |
5 | import numpy as np
6 | import torch
7 | from PIL import Image, ImageDraw, ImageFont
8 |
9 | # Grounding DINO
10 | import GroundingDINO.groundingdino.datasets.transforms as T
11 | from GroundingDINO.groundingdino.models import build_model
12 | from GroundingDINO.groundingdino.util import box_ops
13 | from GroundingDINO.groundingdino.util.slconfig import SLConfig
14 | from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
15 |
16 | # segment anything
17 | from segment_anything import build_sam, SamPredictor
18 | import cv2
19 | import numpy as np
20 | import matplotlib.pyplot as plt
21 |
22 |
23 | # diffusers
24 | import PIL
25 | import requests
26 | import torch
27 | from io import BytesIO
28 | from diffusers import StableDiffusionInpaintPipeline
29 |
30 |
31 | def load_image(image_path):
32 | # load image
33 | image_pil = Image.open(image_path).convert("RGB") # load image
34 |
35 | transform = T.Compose(
36 | [
37 | T.RandomResize([800], max_size=1333),
38 | T.ToTensor(),
39 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
40 | ]
41 | )
42 | image, _ = transform(image_pil, None) # 3, h, w
43 | return image_pil, image
44 |
45 |
46 | def load_model(model_config_path, model_checkpoint_path, device):
47 | args = SLConfig.fromfile(model_config_path)
48 | args.device = device
49 | model = build_model(args)
50 | checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
51 | load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
52 | print(load_res)
53 | _ = model.eval()
54 | return model
55 |
56 |
57 | def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
58 | caption = caption.lower()
59 | caption = caption.strip()
60 | if not caption.endswith("."):
61 | caption = caption + "."
62 | model = model.to(device)
63 | image = image.to(device)
64 | with torch.no_grad():
65 | outputs = model(image[None], captions=[caption])
66 | logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
67 | boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
68 | logits.shape[0]
69 |
70 | # filter output
71 | logits_filt = logits.clone()
72 | boxes_filt = boxes.clone()
73 | filt_mask = logits_filt.max(dim=1)[0] > box_threshold
74 | logits_filt = logits_filt[filt_mask] # num_filt, 256
75 | boxes_filt = boxes_filt[filt_mask] # num_filt, 4
76 | logits_filt.shape[0]
77 |
78 | # get phrase
79 | tokenlizer = model.tokenizer
80 | tokenized = tokenlizer(caption)
81 | # build pred
82 | pred_phrases = []
83 | for logit, box in zip(logits_filt, boxes_filt):
84 | pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
85 | if with_logits:
86 | pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
87 | else:
88 | pred_phrases.append(pred_phrase)
89 |
90 | return boxes_filt, pred_phrases
91 |
92 | def show_mask(mask, ax, random_color=False):
93 | if random_color:
94 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
95 | else:
96 | color = np.array([30/255, 144/255, 255/255, 0.6])
97 | h, w = mask.shape[-2:]
98 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
99 | ax.imshow(mask_image)
100 |
101 |
102 | def show_box(box, ax, label):
103 | x0, y0 = box[0], box[1]
104 | w, h = box[2] - box[0], box[3] - box[1]
105 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
106 | ax.text(x0, y0, label)
107 |
108 |
109 | if __name__ == "__main__":
110 |
111 | parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
112 | parser.add_argument("--config", type=str, required=True, help="path to config file")
113 | parser.add_argument(
114 | "--grounded_checkpoint", type=str, required=True, help="path to GroundingDINO checkpoint file"
115 | )
116 | parser.add_argument(
117 | "--sam_checkpoint", type=str, required=True, help="path to SAM checkpoint file"
118 | )
119 | parser.add_argument("--input_image", type=str, required=True, help="path to image file")
120 | parser.add_argument("--det_prompt", type=str, required=True, help="text prompt")
121 | parser.add_argument("--inpaint_prompt", type=str, required=True, help="inpaint prompt")
122 | parser.add_argument(
123 | "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
124 | )
125 |
126 | parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
127 | parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
128 | parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
129 | args = parser.parse_args()
130 |
131 | # cfg
132 | config_file = args.config # change the path of the model config file
133 | grounded_checkpoint = args.grounded_checkpoint # change the path of the model
134 | sam_checkpoint = args.sam_checkpoint
135 | image_path = args.input_image
136 | det_prompt = args.det_prompt
137 | inpaint_prompt = args.inpaint_prompt
138 | output_dir = args.output_dir
139 | box_threshold = args.box_threshold
140 | text_threshold = args.text_threshold
141 | device = args.device
142 |
143 | # make dir
144 | os.makedirs(output_dir, exist_ok=True)
145 | # load image
146 | image_pil, image = load_image(image_path)
147 | # load model
148 | model = load_model(config_file, grounded_checkpoint, device=device)
149 |
150 | # visualize raw image
151 | image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
152 |
153 | # run grounding dino model
154 | boxes_filt, pred_phrases = get_grounding_output(
155 | model, image, det_prompt, box_threshold, text_threshold, device=device
156 | )
157 |
158 | # initialize SAM
159 | predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
160 | image = cv2.imread(image_path)
161 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
162 | predictor.set_image(image)
163 |
164 | size = image_pil.size
165 | H, W = size[1], size[0]
166 | for i in range(boxes_filt.size(0)):
167 | boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
168 | boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
169 | boxes_filt[i][2:] += boxes_filt[i][:2]
170 |
171 | boxes_filt = boxes_filt.cpu()
172 | transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
173 |
174 | masks, _, _ = predictor.predict_torch(
175 | point_coords = None,
176 | point_labels = None,
177 | boxes = transformed_boxes,
178 | multimask_output = False,
179 | )
180 |
181 | # masks: [1, 1, 512, 512]
182 |
183 | # inpainting pipeline
184 | mask = masks[0][0].cpu().numpy() # simply choose the first mask; this will be refined in a future release
185 | mask_pil = Image.fromarray(mask)
186 | image_pil = Image.fromarray(image)
187 |
188 | pipe = StableDiffusionInpaintPipeline.from_pretrained(
189 | "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
190 | )
191 | pipe = pipe.to("cuda")
192 |
193 | # prompt = "A sofa, high quality, detailed"
194 | image = pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0]
195 | image.save(os.path.join(output_dir, "grounded_sam_inpainting_output.jpg"))
196 |
197 | # draw output image
198 | # plt.figure(figsize=(10, 10))
199 | # plt.imshow(image)
200 | # for mask in masks:
201 | # show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
202 | # for box, label in zip(boxes_filt, pred_phrases):
203 | # show_box(box.numpy(), plt.gca(), label)
204 | # plt.axis('off')
205 | # plt.savefig(os.path.join(output_dir, "grounded_sam_output.jpg"), bbox_inches="tight")
206 |
207 |
--------------------------------------------------------------------------------
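The inpainting-specific step above (SAM mask into `StableDiffusionInpaintPipeline`) can be isolated into a small helper. This is only a sketch under the same assumptions as the demo (a CUDA device for fp16 and the `runwayml/stable-diffusion-inpainting` weights); converting the boolean mask to `uint8` sidesteps Pillow dtype issues, and depending on the diffusers version you may still need to resize image and mask to 512x512 beforehand.

```python
# Sketch: repaint the region selected by one SAM mask with a text prompt.
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline


def inpaint_with_mask(image_rgb: np.ndarray, mask_bool: np.ndarray, prompt: str) -> Image.Image:
    # White pixels in the mask mark the area to repaint.
    mask_pil = Image.fromarray(mask_bool.astype(np.uint8) * 255)
    image_pil = Image.fromarray(image_rgb)

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
    ).to("cuda")
    return pipe(prompt=prompt, image=image_pil, mask_image=mask_pil).images[0]


# Usage inside the demo (placeholder prompt):
# result = inpaint_with_mask(image, masks[0][0].cpu().numpy(), "a sofa, high quality, detailed")
```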
/grounding_dino_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 | from PIL import Image, ImageDraw, ImageFont
7 |
8 | import GroundingDINO.groundingdino.datasets.transforms as T
9 | from GroundingDINO.groundingdino.models import build_model
10 | from GroundingDINO.groundingdino.util import box_ops
11 | from GroundingDINO.groundingdino.util.slconfig import SLConfig
12 | from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
13 |
14 |
15 | def plot_boxes_to_image(image_pil, tgt):
16 | H, W = tgt["size"]
17 | boxes = tgt["boxes"]
18 | labels = tgt["labels"]
19 | assert len(boxes) == len(labels), "boxes and labels must have same length"
20 |
21 | draw = ImageDraw.Draw(image_pil)
22 | mask = Image.new("L", image_pil.size, 0)
23 | mask_draw = ImageDraw.Draw(mask)
24 |
25 | # draw boxes and masks
26 | for box, label in zip(boxes, labels):
27 | # from 0..1 to 0..W, 0..H
28 | box = box * torch.Tensor([W, H, W, H])
29 | # from xywh to xyxy
30 | box[:2] -= box[2:] / 2
31 | box[2:] += box[:2]
32 | # random color
33 | color = tuple(np.random.randint(0, 255, size=3).tolist())
34 | # draw
35 | x0, y0, x1, y1 = box
36 | x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
37 |
38 | draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
39 | # draw.text((x0, y0), str(label), fill=color)
40 |
41 | font = ImageFont.load_default()
42 | if hasattr(font, "getbbox"):
43 | bbox = draw.textbbox((x0, y0), str(label), font)
44 | else:
45 | w, h = draw.textsize(str(label), font)
46 | bbox = (x0, y0, w + x0, y0 + h)
47 | # bbox = draw.textbbox((x0, y0), str(label))
48 | draw.rectangle(bbox, fill=color)
49 | draw.text((x0, y0), str(label), fill="white")
50 |
51 | mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
52 |
53 | return image_pil, mask
54 |
55 |
56 | def load_image(image_path):
57 | # load image
58 | image_pil = Image.open(image_path).convert("RGB") # load image
59 |
60 | transform = T.Compose(
61 | [
62 | T.RandomResize([800], max_size=1333),
63 | T.ToTensor(),
64 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
65 | ]
66 | )
67 | image, _ = transform(image_pil, None) # 3, h, w
68 | return image_pil, image
69 |
70 |
71 | def load_model(model_config_path, model_checkpoint_path, device="cpu"):
72 | args = SLConfig.fromfile(model_config_path)
73 | args.device = device
74 | model = build_model(args)
75 | checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
76 | load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
77 | print(load_res)
78 | _ = model.eval()
79 | return model
80 |
81 |
82 | def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
83 | caption = caption.lower()
84 | caption = caption.strip()
85 | if not caption.endswith("."):
86 | caption = caption + "."
87 | model = model.to(device)
88 | image = image.to(device)
89 | with torch.no_grad():
90 | outputs = model(image[None], captions=[caption])
91 | logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
92 | boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
93 | logits.shape[0]
94 |
95 | # filter output
96 | logits_filt = logits.clone()
97 | boxes_filt = boxes.clone()
98 | filt_mask = logits_filt.max(dim=1)[0] > box_threshold
99 | logits_filt = logits_filt[filt_mask] # num_filt, 256
100 | boxes_filt = boxes_filt[filt_mask] # num_filt, 4
101 | logits_filt.shape[0]
102 |
103 | # get phrase
104 | tokenlizer = model.tokenizer
105 | tokenized = tokenlizer(caption)
106 | # build pred
107 | pred_phrases = []
108 | for logit, box in zip(logits_filt, boxes_filt):
109 | pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
110 | if with_logits:
111 | pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
112 | else:
113 | pred_phrases.append(pred_phrase)
114 |
115 | return boxes_filt, pred_phrases
116 |
117 |
118 | if __name__ == "__main__":
119 |
120 | parser = argparse.ArgumentParser("Grounding DINO example", add_help=True)
121 | parser.add_argument("--config", type=str, required=True, help="path to config file")
122 | parser.add_argument(
123 | "--grounded_checkpoint", type=str, required=True, help="path to checkpoint file"
124 | )
125 | parser.add_argument("--input_image", type=str, required=True, help="path to image file")
126 | parser.add_argument("--text_prompt", type=str, required=True, help="text prompt")
127 | parser.add_argument(
128 | "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
129 | )
130 |
131 | parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
132 | parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
133 |
134 | parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
135 | args = parser.parse_args()
136 |
137 | # cfg
138 | config_file = args.config # change the path of the model config file
139 | grounded_checkpoint = args.grounded_checkpoint # change the path of the model
140 | image_path = args.input_image
141 | text_prompt = args.text_prompt
142 | output_dir = args.output_dir
143 | box_threshold = args.box_threshold
144 | text_threshold = args.text_threshold
145 | device = args.device
146 |
147 | # make dir
148 | os.makedirs(output_dir, exist_ok=True)
149 | # load image
150 | image_pil, image = load_image(image_path)
151 | # load model
152 | model = load_model(config_file, grounded_checkpoint, device=device)
153 |
154 | # visualize raw image
155 | # image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
156 |
157 | # run model
158 | boxes_filt, pred_phrases = get_grounding_output(
159 | model, image, text_prompt, box_threshold, text_threshold, device=device
160 | )
161 |
162 | # visualize pred
163 | size = image_pil.size
164 | pred_dict = {
165 | "boxes": boxes_filt,
166 | "size": [size[1], size[0]], # H,W
167 | "labels": pred_phrases,
168 | }
169 | # import ipdb; ipdb.set_trace()
170 | image_with_box = plot_boxes_to_image(image_pil, pred_dict)[0]
171 | image_with_box.save(os.path.join(output_dir, "grounding_dino_output.jpg"))
--------------------------------------------------------------------------------
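The box conversion embedded in `plot_boxes_to_image` (and repeated in the SAM demos) is worth stating on its own: GroundingDINO predicts boxes as normalized `(cx, cy, w, h)`, while drawing and SAM prompting need absolute `(x0, y0, x1, y1)` pixels. A standalone version with a quick sanity check:

```python
import torch


def cxcywh_norm_to_xyxy_pixels(boxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """Convert normalized (cx, cy, w, h) boxes to absolute (x0, y0, x1, y1) pixels."""
    boxes = boxes * torch.tensor([width, height, width, height], dtype=boxes.dtype)
    xyxy = boxes.clone()
    xyxy[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2  # top-left corner
    xyxy[:, 2:] = boxes[:, :2] + boxes[:, 2:] / 2  # bottom-right corner
    return xyxy


# A box centered in a 640x480 image covering half of each dimension:
print(cxcywh_norm_to_xyxy_pixels(torch.tensor([[0.5, 0.5, 0.5, 0.5]]), 640, 480))
# tensor([[160., 120., 480., 360.]])
```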
/requirements.txt:
--------------------------------------------------------------------------------
1 | addict
2 | diffusers
3 | gradio
4 | huggingface_hub
5 | matplotlib
6 | numpy
7 | onnxruntime
8 | opencv_python
9 | Pillow
10 | pycocotools
11 | PyYAML
12 | requests
13 | setuptools
14 | supervision
15 | termcolor
16 | timm
17 | torch
18 | torchvision
19 | transformers
20 | yapf
21 |
--------------------------------------------------------------------------------
/segment_anything/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811, C101, EXE001, EXE002
3 | max-line-length = 100
4 | max-complexity = 18
5 | select = B,C,E,F,W,T4,B9
6 | per-file-ignores =
7 | **/__init__.py:F401,F403,E402
8 |
--------------------------------------------------------------------------------
/segment_anything/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when there is a
56 | reasonable belief that an individual's behavior may have a negative impact on
57 | the project or its community.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported by contacting the project team at <opensource-conduct@fb.com>. All
63 | complaints will be reviewed and investigated and will result in a response that
64 | is deemed necessary and appropriate to the circumstances. The project team is
65 | obligated to maintain confidentiality with regard to the reporter of an incident.
66 | Further details of specific enforcement policies may be posted separately.
67 |
68 | Project maintainers who do not follow or enforce the Code of Conduct in good
69 | faith may face temporary or permanent repercussions as determined by other
70 | members of the project's leadership.
71 |
72 | ## Attribution
73 |
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 |
77 | [homepage]: https://www.contributor-covenant.org
78 |
79 | For answers to common questions about this code of conduct, see
80 | https://www.contributor-covenant.org/faq
81 |
--------------------------------------------------------------------------------
/segment_anything/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to segment-anything
2 | We want to make contributing to this project as easy and transparent as
3 | possible.
4 |
5 | ## Pull Requests
6 | We actively welcome your pull requests.
7 |
8 | 1. Fork the repo and create your branch from `main`.
9 | 2. If you've added code that should be tested, add tests.
10 | 3. If you've changed APIs, update the documentation.
11 | 4. Ensure the test suite passes.
12 | 5. Make sure your code lints, using the `linter.sh` script in the project's root directory. Linting requires `black==23.*`, `isort==5.12.0`, `flake8`, and `mypy`.
13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14 |
15 | ## Contributor License Agreement ("CLA")
16 | In order to accept your pull request, we need you to submit a CLA. You only need
17 | to do this once to work on any of Facebook's open source projects.
18 |
19 | Complete your CLA here: <https://code.facebook.com/cla>
20 |
21 | ## Issues
22 | We use GitHub issues to track public bugs. Please ensure your description is
23 | clear and has sufficient instructions to be able to reproduce the issue.
24 |
25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26 | disclosure of security bugs. In those cases, please go through the process
27 | outlined on that page and do not file a public issue.
28 |
29 | ## License
30 | By contributing to segment-anything, you agree that your contributions will be licensed
31 | under the LICENSE file in the root directory of this source tree.
32 |
--------------------------------------------------------------------------------
/segment_anything/README.md:
--------------------------------------------------------------------------------
1 | # Segment Anything
2 |
3 | **[Meta AI Research, FAIR](https://ai.facebook.com/research/)**
4 |
5 | [Alexander Kirillov](https://alexander-kirillov.github.io/), [Eric Mintun](https://ericmintun.github.io/), [Nikhila Ravi](https://nikhilaravi.com/), [Hanzi Mao](https://hanzimao.me/), Chloe Rolland, Laura Gustafson, [Tete Xiao](https://tetexiao.com), [Spencer Whitehead](https://www.spencerwhitehead.com/), Alex Berg, Wan-Yen Lo, [Piotr Dollar](https://pdollar.github.io/), [Ross Girshick](https://www.rossgirshick.info/)
6 |
7 | [[`Paper`](https://ai.facebook.com/research/publications/segment-anything/)] [[`Project`](https://segment-anything.com/)] [[`Demo`](https://segment-anything.com/demo)] [[`Dataset`](https://segment-anything.com/dataset/index.html)] [[`Blog`](https://ai.facebook.com/blog/segment-anything-foundation-model-image-segmentation/)]
8 |
9 | 
10 |
11 | The **Segment Anything Model (SAM)** produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a [dataset](https://segment-anything.com/dataset/index.html) of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.
12 |
13 |
14 | 
15 | 
16 |
17 |
18 | ## Installation
19 |
20 | The code requires `python>=3.8`, as well as `pytorch>=1.7` and `torchvision>=0.8`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended.
21 |
22 | Install Segment Anything:
23 |
24 | ```
25 | pip install git+https://github.com/facebookresearch/segment-anything.git
26 | ```
27 |
28 | or clone the repository locally and install with
29 |
30 | ```
31 | git clone git@github.com:facebookresearch/segment-anything.git
32 | cd segment-anything; pip install -e .
33 | ```
34 |
35 | The following optional dependencies are necessary for mask post-processing, saving masks in COCO format, the example notebooks, and exporting the model in ONNX format. `jupyter` is also required to run the example notebooks.
36 | ```
37 | pip install opencv-python pycocotools matplotlib onnxruntime onnx
38 | ```
39 |
40 |
41 | ## Getting Started
42 |
43 | First download a [model checkpoint](#model-checkpoints). Then the model can be used in just a few lines to get masks from a given prompt:
44 |
45 | ```
46 | from segment_anything import build_sam, SamPredictor
47 | predictor = SamPredictor(build_sam(checkpoint="<path/to/checkpoint>"))
48 | predictor.set_image(<your_image>)
49 | masks, _, _ = predictor.predict(<input_prompts>)
50 | ```
51 |
52 | or generate masks for an entire image:
53 |
54 | ```
55 | from segment_anything import build_sam, SamAutomaticMaskGenerator
56 | mask_generator = SamAutomaticMaskGenerator(build_sam(checkpoint="<path/to/checkpoint>"))
57 | masks = mask_generator.generate(<your_image>)
58 | ```
59 |
60 | Additionally, masks can be generated for images from the command line:
61 |
62 | ```
63 | python scripts/amg.py --checkpoint <path/to/checkpoint> --input <image_or_folder> --output <output_directory>
64 | ```
65 |
66 | See the examples notebooks on [using SAM with prompts](/notebooks/predictor_example.ipynb) and [automatically generating masks](/notebooks/automatic_mask_generator_example.ipynb) for more details.
67 |
68 |
69 |
70 | 
71 | 
72 |
73 | ## ONNX Export
74 |
75 | SAM's lightweight mask decoder can be exported to ONNX format so that it can be run in any environment that supports ONNX runtime, such as in-browser as showcased in the [demo](https://segment-anything.com/demo). Export the model with
76 |
77 | ```
78 | python scripts/export_onnx_model.py --checkpoint <path/to/checkpoint> --output <path/to/output>
79 | ```
80 |
81 | See the [example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb) for details on how to combine image preprocessing via SAM's backbone with mask prediction using the ONNX model. It is recommended to use the latest stable version of PyTorch for ONNX export.
82 |
83 | ## Model Checkpoints
84 |
85 | Three versions of the model are available with different backbone sizes. These models can be instantiated by running
86 | ```
87 | from segment_anything import sam_model_registry
88 | sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
89 | ```
90 | Click the links below to download the checkpoint for the corresponding model name. The default model in bold can also be instantiated with `build_sam`, as in the examples in [Getting Started](#getting-started).
91 |
92 | * **`default` or `vit_h`: [ViT-H SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)**
93 | * `vit_l`: [ViT-L SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth)
94 | * `vit_b`: [ViT-B SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
95 |
96 | ## License
97 | The model is licensed under the [Apache 2.0 license](LICENSE).
98 |
99 | ## Contributing
100 |
101 | See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
102 |
103 | ## Contributors
104 |
105 | The Segment Anything project was made possible with the help of many contributors (alphabetical):
106 |
107 | Aaron Adcock, Vaibhav Aggarwal, Morteza Behrooz, Cheng-Yang Fu, Ashley Gabriel, Ahuva Goldstand, Allen Goodman, Sumanth Gurram, Jiabo Hu, Somya Jain, Devansh Kukreja, Robert Kuo, Joshua Lane, Yanghao Li, Lilian Luong, Jitendra Malik, Mallika Malhotra, William Ngan, Omkar Parkhi, Nikhil Raina, Dirk Rowe, Neil Sejoor, Vanessa Stark, Bala Varadarajan, Bram Wasti, Zachary Winstrom
108 |
--------------------------------------------------------------------------------
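Tying the checkpoint table back to the Getting Started snippet, here is a sketch of prompting SAM with a single box. The checkpoint path and the box coordinates are placeholders; the image ships with the repo under `notebooks/images/`.

```python
# Sketch: instantiate a model from the registry and prompt it with one xyxy box.
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")  # placeholder path
sam.to(device="cuda")

image = cv2.cvtColor(cv2.imread("notebooks/images/truck.jpg"), cv2.COLOR_BGR2RGB)
predictor = SamPredictor(sam)
predictor.set_image(image)

box = np.array([425, 600, 700, 875])  # illustrative pixel coordinates
masks, scores, _ = predictor.predict(box=box, multimask_output=False)
print(masks.shape, scores)
```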
/segment_anything/assets/masks1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/assets/masks1.png
--------------------------------------------------------------------------------
/segment_anything/assets/masks2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/assets/masks2.jpg
--------------------------------------------------------------------------------
/segment_anything/assets/model_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/assets/model_diagram.png
--------------------------------------------------------------------------------
/segment_anything/assets/notebook1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/assets/notebook1.png
--------------------------------------------------------------------------------
/segment_anything/assets/notebook2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/assets/notebook2.png
--------------------------------------------------------------------------------
/segment_anything/linter.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -e
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | {
5 | black --version | grep -E "23\." > /dev/null
6 | } || {
7 | echo "Linter requires 'black==23.*' !"
8 | exit 1
9 | }
10 |
11 | ISORT_VERSION=$(isort --version-number)
12 | if [[ "$ISORT_VERSION" != 5.12* ]]; then
13 | echo "Linter requires isort==5.12.0 !"
14 | exit 1
15 | fi
16 |
17 | echo "Running isort ..."
18 | isort . --atomic
19 |
20 | echo "Running black ..."
21 | black -l 100 .
22 |
23 | echo "Running flake8 ..."
24 | if [ -x "$(command -v flake8)" ]; then
25 | flake8 .
26 | else
27 | python3 -m flake8 .
28 | fi
29 |
30 | echo "Running mypy..."
31 |
32 | mypy --exclude 'setup.py|notebooks' .
33 |
--------------------------------------------------------------------------------
/segment_anything/notebooks/images/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/notebooks/images/dog.jpg
--------------------------------------------------------------------------------
/segment_anything/notebooks/images/groceries.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/notebooks/images/groceries.jpg
--------------------------------------------------------------------------------
/segment_anything/notebooks/images/truck.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/notebooks/images/truck.jpg
--------------------------------------------------------------------------------
/segment_anything/scripts/amg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import cv2 # type: ignore
8 |
9 | from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
10 |
11 | import argparse
12 | import json
13 | import os
14 | from typing import Any, Dict, List
15 |
16 | parser = argparse.ArgumentParser(
17 | description=(
18 | "Runs automatic mask generation on an input image or directory of images, "
19 | "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
20 | "as well as pycocotools if saving in RLE format."
21 | )
22 | )
23 |
24 | parser.add_argument(
25 | "--input",
26 | type=str,
27 | required=True,
28 | help="Path to either a single input image or folder of images.",
29 | )
30 |
31 | parser.add_argument(
32 | "--output",
33 | type=str,
34 | required=True,
35 | help=(
36 | "Path to the directory where masks will be output. Output will be either a folder "
37 | "of PNGs per image or a single json with COCO-style masks."
38 | ),
39 | )
40 |
41 | parser.add_argument(
42 | "--model-type",
43 | type=str,
44 | default="default",
45 | help="The type of model to load, in ['default', 'vit_l', 'vit_b']",
46 | )
47 |
48 | parser.add_argument(
49 | "--checkpoint",
50 | type=str,
51 | required=True,
52 | help="The path to the SAM checkpoint to use for mask generation.",
53 | )
54 |
55 | parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.")
56 |
57 | parser.add_argument(
58 | "--convert-to-rle",
59 | action="store_true",
60 | help=(
61 | "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
62 | "Requires pycocotools."
63 | ),
64 | )
65 |
66 | amg_settings = parser.add_argument_group("AMG Settings")
67 |
68 | amg_settings.add_argument(
69 | "--points-per-side",
70 | type=int,
71 | default=None,
72 | help="Generate masks by sampling a grid over the image with this many points to a side.",
73 | )
74 |
75 | amg_settings.add_argument(
76 | "--points-per-batch",
77 | type=int,
78 | default=None,
79 | help="How many input points to process simultaneously in one batch.",
80 | )
81 |
82 | amg_settings.add_argument(
83 | "--pred-iou-thresh",
84 | type=float,
85 | default=None,
86 | help="Exclude masks with a predicted score from the model that is lower than this threshold.",
87 | )
88 |
89 | amg_settings.add_argument(
90 | "--stability-score-thresh",
91 | type=float,
92 | default=None,
93 | help="Exclude masks with a stability score lower than this threshold.",
94 | )
95 |
96 | amg_settings.add_argument(
97 | "--stability-score-offset",
98 | type=float,
99 | default=None,
100 | help="Larger values perturb the mask more when measuring stability score.",
101 | )
102 |
103 | amg_settings.add_argument(
104 | "--box-nms-thresh",
105 | type=float,
106 | default=None,
107 | help="The overlap threshold for excluding a duplicate mask.",
108 | )
109 |
110 | amg_settings.add_argument(
111 | "--crop-n-layers",
112 | type=int,
113 | default=None,
114 | help=(
115 | "If >0, mask generation is run on smaller crops of the image to generate more masks. "
116 | "The value sets how many different scales to crop at."
117 | ),
118 | )
119 |
120 | amg_settings.add_argument(
121 | "--crop-nms-thresh",
122 | type=float,
123 | default=None,
124 | help="The overlap threshold for excluding duplicate masks across different crops.",
125 | )
126 |
127 | amg_settings.add_argument(
128 | "--crop-overlap-ratio",
129 | type=int,
130 | default=None,
131 | help="Larger numbers mean image crops will overlap more.",
132 | )
133 |
134 | amg_settings.add_argument(
135 | "--crop-n-points-downscale-factor",
136 | type=int,
137 | default=None,
138 | help="The number of points-per-side in each layer of crop is reduced by this factor.",
139 | )
140 |
141 | amg_settings.add_argument(
142 | "--min-mask-region-area",
143 | type=int,
144 | default=None,
145 | help=(
146 | "Disconnected mask regions or holes with area smaller than this value "
147 | "in pixels are removed by postprocessing."
148 | ),
149 | )
150 |
151 |
152 | def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
153 | header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa
154 | metadata = [header]
155 | for i, mask_data in enumerate(masks):
156 | mask = mask_data["segmentation"]
157 | filename = f"{i}.png"
158 | cv2.imwrite(os.path.join(path, filename), mask * 255)
159 | mask_metadata = [
160 | str(i),
161 | str(mask_data["area"]),
162 | *[str(x) for x in mask_data["bbox"]],
163 | *[str(x) for x in mask_data["point_coords"][0]],
164 | str(mask_data["predicted_iou"]),
165 | str(mask_data["stability_score"]),
166 | *[str(x) for x in mask_data["crop_box"]],
167 | ]
168 | row = ",".join(mask_metadata)
169 | metadata.append(row)
170 | metadata_path = os.path.join(path, "metadata.csv")
171 | with open(metadata_path, "w") as f:
172 | f.write("\n".join(metadata))
173 |
174 | return
175 |
176 |
177 | def get_amg_kwargs(args):
178 | amg_kwargs = {
179 | "points_per_side": args.points_per_side,
180 | "points_per_batch": args.points_per_batch,
181 | "pred_iou_thresh": args.pred_iou_thresh,
182 | "stability_score_thresh": args.stability_score_thresh,
183 | "stability_score_offset": args.stability_score_offset,
184 | "box_nms_thresh": args.box_nms_thresh,
185 | "crop_n_layers": args.crop_n_layers,
186 | "crop_nms_thresh": args.crop_nms_thresh,
187 | "crop_overlap_ratio": args.crop_overlap_ratio,
188 | "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
189 | "min_mask_region_area": args.min_mask_region_area,
190 | }
191 | amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
192 | return amg_kwargs
193 |
194 |
195 | def main(args: argparse.Namespace) -> None:
196 | print("Loading model...")
197 | sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
198 | _ = sam.to(device=args.device)
199 | output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
200 | amg_kwargs = get_amg_kwargs(args)
201 | generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)
202 |
203 | if not os.path.isdir(args.input):
204 | targets = [args.input]
205 | else:
206 | targets = [
207 | f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f))
208 | ]
209 | targets = [os.path.join(args.input, f) for f in targets]
210 |
211 | os.makedirs(args.output, exist_ok=True)
212 |
213 | for t in targets:
214 | print(f"Processing '{t}'...")
215 | image = cv2.imread(t)
216 | if image is None:
217 | print(f"Could not load '{t}' as an image, skipping...")
218 | continue
219 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
220 |
221 | masks = generator.generate(image)
222 |
223 | base = os.path.basename(t)
224 | base = os.path.splitext(base)[0]
225 | save_base = os.path.join(args.output, base)
226 | if output_mode == "binary_mask":
227 | os.makedirs(save_base, exist_ok=False)
228 | write_masks_to_folder(masks, save_base)
229 | else:
230 | save_file = save_base + ".json"
231 | with open(save_file, "w") as f:
232 | json.dump(masks, f)
233 | print("Done!")
234 |
235 |
236 | if __name__ == "__main__":
237 | args = parser.parse_args()
238 | main(args)
239 |
--------------------------------------------------------------------------------
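`scripts/amg.py` is a thin CLI around `SamAutomaticMaskGenerator`, so the same settings can be passed directly in Python. A sketch with placeholder paths, mirroring a few of the "AMG Settings" flags above with illustrative values:

```python
# Sketch: call the mask generator directly instead of via the CLI (placeholder paths).
import cv2
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

sam = sam_model_registry["default"](checkpoint="sam_vit_h_4b8939.pth")  # placeholder path
sam.to(device="cuda")

generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=32,           # --points-per-side
    pred_iou_thresh=0.88,         # --pred-iou-thresh
    stability_score_thresh=0.95,  # --stability-score-thresh
    min_mask_region_area=100,     # --min-mask-region-area (post-processing needs opencv)
    output_mode="binary_mask",    # or "coco_rle", as with --convert-to-rle
)

image = cv2.cvtColor(cv2.imread("notebooks/images/dog.jpg"), cv2.COLOR_BGR2RGB)
masks = generator.generate(image)
print(len(masks), "masks;", sorted(masks[0].keys()))
```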
/segment_anything/scripts/export_onnx_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | from segment_anything import build_sam, build_sam_vit_b, build_sam_vit_l
10 | from segment_anything.utils.onnx import SamOnnxModel
11 |
12 | import argparse
13 | import warnings
14 |
15 | try:
16 | import onnxruntime # type: ignore
17 |
18 | onnxruntime_exists = True
19 | except ImportError:
20 | onnxruntime_exists = False
21 |
22 | parser = argparse.ArgumentParser(
23 | description="Export the SAM prompt encoder and mask decoder to an ONNX model."
24 | )
25 |
26 | parser.add_argument(
27 | "--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint."
28 | )
29 |
30 | parser.add_argument(
31 | "--output", type=str, required=True, help="The filename to save the ONNX model to."
32 | )
33 |
34 | parser.add_argument(
35 | "--model-type",
36 | type=str,
37 | default="default",
38 | help="In ['default', 'vit_b', 'vit_l']. Which type of SAM model to export.",
39 | )
40 |
41 | parser.add_argument(
42 | "--return-single-mask",
43 | action="store_true",
44 | help=(
45 | "If true, the exported ONNX model will only return the best mask, "
46 | "instead of returning multiple masks. For high resolution images "
47 | "this can improve runtime when upscaling masks is expensive."
48 | ),
49 | )
50 |
51 | parser.add_argument(
52 | "--opset",
53 | type=int,
54 | default=17,
55 | help="The ONNX opset version to use. Must be >=11",
56 | )
57 |
58 | parser.add_argument(
59 | "--quantize-out",
60 | type=str,
61 | default=None,
62 | help=(
63 | "If set, will quantize the model and save it with this name. "
64 | "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
65 | ),
66 | )
67 |
68 | parser.add_argument(
69 | "--gelu-approximate",
70 | action="store_true",
71 | help=(
72 | "Replace GELU operations with approximations using tanh. Useful "
73 | "for some runtimes that have slow or unimplemented erf ops, used in GELU."
74 | ),
75 | )
76 |
77 | parser.add_argument(
78 | "--use-stability-score",
79 | action="store_true",
80 | help=(
81 | "Replaces the model's predicted mask quality score with the stability "
82 | "score calculated on the low resolution masks using an offset of 1.0. "
83 | ),
84 | )
85 |
86 | parser.add_argument(
87 | "--return-extra-metrics",
88 | action="store_true",
89 | help=(
90 | "The model will return five results: (masks, scores, stability_scores, "
91 | "areas, low_res_logits) instead of the usual three. This can be "
92 | "significantly slower for high resolution outputs."
93 | ),
94 | )
95 |
96 |
97 | def run_export(
98 | model_type: str,
99 | checkpoint: str,
100 | output: str,
101 | opset: int,
102 | return_single_mask: bool,
103 | gelu_approximate: bool = False,
104 | use_stability_score: bool = False,
105 | return_extra_metrics=False,
106 | ):
107 | print("Loading model...")
108 | if model_type == "vit_b":
109 | sam = build_sam_vit_b(checkpoint)
110 | elif model_type == "vit_l":
111 | sam = build_sam_vit_l(checkpoint)
112 | else:
113 | sam = build_sam(checkpoint)
114 |
115 | onnx_model = SamOnnxModel(
116 | model=sam,
117 | return_single_mask=return_single_mask,
118 | use_stability_score=use_stability_score,
119 | return_extra_metrics=return_extra_metrics,
120 | )
121 |
122 | if gelu_approximate:
123 | for n, m in onnx_model.named_modules():
124 | if isinstance(m, torch.nn.GELU):
125 | m.approximate = "tanh"
126 |
127 | dynamic_axes = {
128 | "point_coords": {1: "num_points"},
129 | "point_labels": {1: "num_points"},
130 | }
131 |
132 | embed_dim = sam.prompt_encoder.embed_dim
133 | embed_size = sam.prompt_encoder.image_embedding_size
134 | mask_input_size = [4 * x for x in embed_size]
135 | dummy_inputs = {
136 | "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
137 | "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
138 | "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
139 | "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
140 | "has_mask_input": torch.tensor([1], dtype=torch.float),
141 | "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
142 | }
143 |
144 | _ = onnx_model(**dummy_inputs)
145 |
146 | output_names = ["masks", "iou_predictions", "low_res_masks"]
147 |
148 | with warnings.catch_warnings():
149 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
150 | warnings.filterwarnings("ignore", category=UserWarning)
151 | with open(output, "wb") as f:
152 | print(f"Exporing onnx model to {output}...")
153 | torch.onnx.export(
154 | onnx_model,
155 | tuple(dummy_inputs.values()),
156 | f,
157 | export_params=True,
158 | verbose=False,
159 | opset_version=opset,
160 | do_constant_folding=True,
161 | input_names=list(dummy_inputs.keys()),
162 | output_names=output_names,
163 | dynamic_axes=dynamic_axes,
164 | )
165 |
166 | if onnxruntime_exists:
167 | ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
168 | ort_session = onnxruntime.InferenceSession(output)
169 | _ = ort_session.run(None, ort_inputs)
170 | print("Model has successfully been run with ONNXRuntime.")
171 |
172 |
173 | def to_numpy(tensor):
174 | return tensor.cpu().numpy()
175 |
176 |
177 | if __name__ == "__main__":
178 | args = parser.parse_args()
179 | run_export(
180 | model_type=args.model_type,
181 | checkpoint=args.checkpoint,
182 | output=args.output,
183 | opset=args.opset,
184 | return_single_mask=args.return_single_mask,
185 | gelu_approximate=args.gelu_approximate,
186 | use_stability_score=args.use_stability_score,
187 | return_extra_metrics=args.return_extra_metrics,
188 | )
189 |
190 | if args.quantize_out is not None:
191 | assert onnxruntime_exists, "onnxruntime is required to quantize the model."
192 | from onnxruntime.quantization import QuantType # type: ignore
193 | from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore
194 |
195 | print(f"Quantizing model and writing to {args.quantize_out}...")
196 | quantize_dynamic(
197 | model_input=args.output,
198 | model_output=args.quantize_out,
199 | optimize_model=True,
200 | per_channel=False,
201 | reduce_range=False,
202 | weight_type=QuantType.QUInt8,
203 | )
204 | print("Done!")
205 |
--------------------------------------------------------------------------------
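Once exported, the decoder can be exercised with `onnxruntime` using the same input names and shapes as `dummy_inputs` above. The sketch below feeds a random embedding purely to check that the graph runs; a real application would substitute the embedding produced by SAM's image encoder and point coordinates transformed into the resized 1024-pixel frame. The `.onnx` path is a placeholder.

```python
# Sketch: smoke-test an exported SAM decoder with onnxruntime (placeholder path).
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("sam_decoder.onnx")

embed_dim, embed_size = 256, (64, 64)  # prompt_embed_dim and 1024 // 16, as in build_sam.py
inputs = {
    "image_embeddings": np.random.randn(1, embed_dim, *embed_size).astype(np.float32),
    "point_coords": np.array([[[500.0, 375.0], [0.0, 0.0]]], dtype=np.float32),
    "point_labels": np.array([[1.0, -1.0]], dtype=np.float32),  # -1 marks the padding point
    "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
    "has_mask_input": np.zeros(1, dtype=np.float32),
    "orig_im_size": np.array([1500, 2250], dtype=np.float32),
}
masks, iou_predictions, low_res_masks = session.run(None, inputs)
print(masks.shape, iou_predictions.shape, low_res_masks.shape)
```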
/segment_anything/segment_anything/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .build_sam import (
8 | build_sam,
9 | build_sam_vit_h,
10 | build_sam_vit_l,
11 | build_sam_vit_b,
12 | sam_model_registry,
13 | )
14 | from .predictor import SamPredictor
15 | from .automatic_mask_generator import SamAutomaticMaskGenerator
16 |
--------------------------------------------------------------------------------
/segment_anything/segment_anything/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | from functools import partial
10 |
11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
12 |
13 |
14 | def build_sam_vit_h(checkpoint=None):
15 | return _build_sam(
16 | encoder_embed_dim=1280,
17 | encoder_depth=32,
18 | encoder_num_heads=16,
19 | encoder_global_attn_indexes=[7, 15, 23, 31],
20 | checkpoint=checkpoint,
21 | )
22 |
23 |
24 | build_sam = build_sam_vit_h
25 |
26 |
27 | def build_sam_vit_l(checkpoint=None):
28 | return _build_sam(
29 | encoder_embed_dim=1024,
30 | encoder_depth=24,
31 | encoder_num_heads=16,
32 | encoder_global_attn_indexes=[5, 11, 17, 23],
33 | checkpoint=checkpoint,
34 | )
35 |
36 |
37 | def build_sam_vit_b(checkpoint=None):
38 | return _build_sam(
39 | encoder_embed_dim=768,
40 | encoder_depth=12,
41 | encoder_num_heads=12,
42 | encoder_global_attn_indexes=[2, 5, 8, 11],
43 | checkpoint=checkpoint,
44 | )
45 |
46 |
47 | sam_model_registry = {
48 | "default": build_sam,
49 | "vit_h": build_sam,
50 | "vit_l": build_sam_vit_l,
51 | "vit_b": build_sam_vit_b,
52 | }
53 |
54 |
55 | def _build_sam(
56 | encoder_embed_dim,
57 | encoder_depth,
58 | encoder_num_heads,
59 | encoder_global_attn_indexes,
60 | checkpoint=None,
61 | ):
62 | prompt_embed_dim = 256
63 | image_size = 1024
64 | vit_patch_size = 16
65 | image_embedding_size = image_size // vit_patch_size
66 | sam = Sam(
67 | image_encoder=ImageEncoderViT(
68 | depth=encoder_depth,
69 | embed_dim=encoder_embed_dim,
70 | img_size=image_size,
71 | mlp_ratio=4,
72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
73 | num_heads=encoder_num_heads,
74 | patch_size=vit_patch_size,
75 | qkv_bias=True,
76 | use_rel_pos=True,
77 | global_attn_indexes=encoder_global_attn_indexes,
78 | window_size=14,
79 | out_chans=prompt_embed_dim,
80 | ),
81 | prompt_encoder=PromptEncoder(
82 | embed_dim=prompt_embed_dim,
83 | image_embedding_size=(image_embedding_size, image_embedding_size),
84 | input_image_size=(image_size, image_size),
85 | mask_in_chans=16,
86 | ),
87 | mask_decoder=MaskDecoder(
88 | num_multimask_outputs=3,
89 | transformer=TwoWayTransformer(
90 | depth=2,
91 | embedding_dim=prompt_embed_dim,
92 | mlp_dim=2048,
93 | num_heads=8,
94 | ),
95 | transformer_dim=prompt_embed_dim,
96 | iou_head_depth=3,
97 | iou_head_hidden_dim=256,
98 | ),
99 | pixel_mean=[123.675, 116.28, 103.53],
100 | pixel_std=[58.395, 57.12, 57.375],
101 | )
102 | sam.eval()
103 | if checkpoint is not None:
104 | with open(checkpoint, "rb") as f:
105 | state_dict = torch.load(f)
106 | sam.load_state_dict(state_dict)
107 | return sam
108 |
--------------------------------------------------------------------------------
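
A minimal usage sketch for the builders above (illustrative, not repository code); pass a real checkpoint path instead of None to load pretrained weights.

import torch
from segment_anything import sam_model_registry

# "vit_b" is the smallest configuration; checkpoint=None gives random weights,
# while a path to a downloaded .pth file loads the pretrained state dict.
sam = sam_model_registry["vit_b"](checkpoint=None)
sam.to(torch.device("cpu")).eval()
print(sum(p.numel() for p in sam.parameters()), "parameters")
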
/segment_anything/segment_anything/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .sam import Sam
8 | from .image_encoder import ImageEncoderViT
9 | from .mask_decoder import MaskDecoder
10 | from .prompt_encoder import PromptEncoder
11 | from .transformer import TwoWayTransformer
12 |
--------------------------------------------------------------------------------
/segment_anything/segment_anything/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from typing import Type
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
--------------------------------------------------------------------------------
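
A quick illustrative check (not repository code) that LayerNorm2d normalizes over the channel dimension of an NCHW tensor at every spatial location:

import torch
from segment_anything.modeling.common import LayerNorm2d

x = torch.randn(2, 8, 4, 4)          # N x C x H x W
y = LayerNorm2d(8)(x)
# Per-location channel statistics: mean ~0 and (biased) std ~1 after normalization.
print(y.mean(dim=1).abs().max(), y.std(dim=1, unbiased=False).mean())
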
/segment_anything/segment_anything/modeling/mask_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from typing import List, Tuple, Type
12 |
13 | from .common import LayerNorm2d
14 |
15 |
16 | class MaskDecoder(nn.Module):
17 | def __init__(
18 | self,
19 | *,
20 | transformer_dim: int,
21 | transformer: nn.Module,
22 | num_multimask_outputs: int = 3,
23 | activation: Type[nn.Module] = nn.GELU,
24 | iou_head_depth: int = 3,
25 | iou_head_hidden_dim: int = 256,
26 | ) -> None:
27 | """
28 | Predicts masks given an image and prompt embeddings, using a
 29 |         transformer architecture.
30 |
31 | Arguments:
32 | transformer_dim (int): the channel dimension of the transformer
33 | transformer (nn.Module): the transformer used to predict masks
34 | num_multimask_outputs (int): the number of masks to predict
35 | when disambiguating masks
36 | activation (nn.Module): the type of activation to use when
37 | upscaling masks
38 | iou_head_depth (int): the depth of the MLP used to predict
39 | mask quality
40 | iou_head_hidden_dim (int): the hidden dimension of the MLP
41 | used to predict mask quality
42 | """
43 | super().__init__()
44 | self.transformer_dim = transformer_dim
45 | self.transformer = transformer
46 |
47 | self.num_multimask_outputs = num_multimask_outputs
48 |
49 | self.iou_token = nn.Embedding(1, transformer_dim)
50 | self.num_mask_tokens = num_multimask_outputs + 1
51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52 |
53 | self.output_upscaling = nn.Sequential(
54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
55 | LayerNorm2d(transformer_dim // 4),
56 | activation(),
57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
58 | activation(),
59 | )
60 | self.output_hypernetworks_mlps = nn.ModuleList(
61 | [
62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
63 | for i in range(self.num_mask_tokens)
64 | ]
65 | )
66 |
67 | self.iou_prediction_head = MLP(
68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
69 | )
70 |
71 | def forward(
72 | self,
73 | image_embeddings: torch.Tensor,
74 | image_pe: torch.Tensor,
75 | sparse_prompt_embeddings: torch.Tensor,
76 | dense_prompt_embeddings: torch.Tensor,
77 | multimask_output: bool,
78 | ) -> Tuple[torch.Tensor, torch.Tensor]:
79 | """
80 | Predict masks given image and prompt embeddings.
81 |
82 | Arguments:
83 | image_embeddings (torch.Tensor): the embeddings from the image encoder
84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
87 | multimask_output (bool): Whether to return multiple masks or a single
88 | mask.
89 |
90 | Returns:
91 | torch.Tensor: batched predicted masks
92 | torch.Tensor: batched predictions of mask quality
93 | """
94 | masks, iou_pred = self.predict_masks(
95 | image_embeddings=image_embeddings,
96 | image_pe=image_pe,
97 | sparse_prompt_embeddings=sparse_prompt_embeddings,
98 | dense_prompt_embeddings=dense_prompt_embeddings,
99 | )
100 |
101 |         # Select the correct mask or masks for output
102 | if multimask_output:
103 | mask_slice = slice(1, None)
104 | else:
105 | mask_slice = slice(0, 1)
106 | masks = masks[:, mask_slice, :, :]
107 | iou_pred = iou_pred[:, mask_slice]
108 |
109 | # Prepare output
110 | return masks, iou_pred
111 |
112 | def predict_masks(
113 | self,
114 | image_embeddings: torch.Tensor,
115 | image_pe: torch.Tensor,
116 | sparse_prompt_embeddings: torch.Tensor,
117 | dense_prompt_embeddings: torch.Tensor,
118 | ) -> Tuple[torch.Tensor, torch.Tensor]:
119 | """Predicts masks. See 'forward' for more details."""
120 | # Concatenate output tokens
121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
124 |
125 | # Expand per-image data in batch direction to be per-mask
126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
127 | src = src + dense_prompt_embeddings
128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
129 | b, c, h, w = src.shape
130 |
131 | # Run the transformer
132 | hs, src = self.transformer(src, pos_src, tokens)
133 | iou_token_out = hs[:, 0, :]
134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
135 |
136 | # Upscale mask embeddings and predict masks using the mask tokens
137 | src = src.transpose(1, 2).view(b, c, h, w)
138 | upscaled_embedding = self.output_upscaling(src)
139 | hyper_in_list: List[torch.Tensor] = []
140 | for i in range(self.num_mask_tokens):
141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
142 | hyper_in = torch.stack(hyper_in_list, dim=1)
143 | b, c, h, w = upscaled_embedding.shape
144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
145 |
146 | # Generate mask quality predictions
147 | iou_pred = self.iou_prediction_head(iou_token_out)
148 |
149 | return masks, iou_pred
150 |
151 |
152 | # Lightly adapted from
153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
154 | class MLP(nn.Module):
155 | def __init__(
156 | self,
157 | input_dim: int,
158 | hidden_dim: int,
159 | output_dim: int,
160 | num_layers: int,
161 | sigmoid_output: bool = False,
162 | ) -> None:
163 | super().__init__()
164 | self.num_layers = num_layers
165 | h = [hidden_dim] * (num_layers - 1)
166 | self.layers = nn.ModuleList(
167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
168 | )
169 | self.sigmoid_output = sigmoid_output
170 |
171 | def forward(self, x):
172 | for i, layer in enumerate(self.layers):
173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
174 | if self.sigmoid_output:
175 | x = F.sigmoid(x)
176 | return x
177 |
--------------------------------------------------------------------------------
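
An illustrative shape sketch (not repository code) of the decoder with the default SAM sizes configured in build_sam.py (transformer_dim=256, a 64x64 image embedding):

import torch
from segment_anything.modeling import MaskDecoder, TwoWayTransformer

decoder = MaskDecoder(
    transformer_dim=256,
    transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
)
image_embeddings = torch.randn(1, 256, 64, 64)
image_pe = torch.randn(1, 256, 64, 64)
sparse = torch.randn(1, 2, 256)        # e.g. one point prompt plus its padding point
dense = torch.randn(1, 256, 64, 64)

# multimask_output=True keeps the 3 disambiguation masks (mask tokens 1..3).
masks, iou_pred = decoder(image_embeddings, image_pe, sparse, dense, multimask_output=True)
print(masks.shape, iou_pred.shape)     # (1, 3, 256, 256) and (1, 3)
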
/segment_anything/segment_anything/modeling/prompt_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch import nn
10 |
11 | from typing import Any, Optional, Tuple, Type
12 |
13 | from .common import LayerNorm2d
14 |
15 |
16 | class PromptEncoder(nn.Module):
17 | def __init__(
18 | self,
19 | embed_dim: int,
20 | image_embedding_size: Tuple[int, int],
21 | input_image_size: Tuple[int, int],
22 | mask_in_chans: int,
23 | activation: Type[nn.Module] = nn.GELU,
24 | ) -> None:
25 | """
26 | Encodes prompts for input to SAM's mask decoder.
27 |
28 | Arguments:
29 | embed_dim (int): The prompts' embedding dimension
30 | image_embedding_size (tuple(int, int)): The spatial size of the
31 | image embedding, as (H, W).
 32 |           input_image_size (tuple(int, int)): The padded size of the image as input
33 | to the image encoder, as (H, W).
34 | mask_in_chans (int): The number of hidden channels used for
35 | encoding input masks.
36 | activation (nn.Module): The activation to use when encoding
37 | input masks.
38 | """
39 | super().__init__()
40 | self.embed_dim = embed_dim
41 | self.input_image_size = input_image_size
42 | self.image_embedding_size = image_embedding_size
43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44 |
45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
47 | self.point_embeddings = nn.ModuleList(point_embeddings)
48 | self.not_a_point_embed = nn.Embedding(1, embed_dim)
49 |
50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
51 | self.mask_downscaling = nn.Sequential(
52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
53 | LayerNorm2d(mask_in_chans // 4),
54 | activation(),
55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
56 | LayerNorm2d(mask_in_chans),
57 | activation(),
58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
59 | )
60 | self.no_mask_embed = nn.Embedding(1, embed_dim)
61 |
62 | def get_dense_pe(self) -> torch.Tensor:
63 | """
64 | Returns the positional encoding used to encode point prompts,
65 | applied to a dense set of points the shape of the image encoding.
66 |
67 | Returns:
68 | torch.Tensor: Positional encoding with shape
69 | 1x(embed_dim)x(embedding_h)x(embedding_w)
70 | """
71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0)
72 |
73 | def _embed_points(
74 | self,
75 | points: torch.Tensor,
76 | labels: torch.Tensor,
77 | pad: bool,
78 | ) -> torch.Tensor:
79 | """Embeds point prompts."""
80 | points = points + 0.5 # Shift to center of pixel
81 | if pad:
82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
84 | points = torch.cat([points, padding_point], dim=1)
85 | labels = torch.cat([labels, padding_label], dim=1)
86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
87 | point_embedding[labels == -1] = 0.0
88 | point_embedding[labels == -1] += self.not_a_point_embed.weight
89 | point_embedding[labels == 0] += self.point_embeddings[0].weight
90 | point_embedding[labels == 1] += self.point_embeddings[1].weight
91 | return point_embedding
92 |
93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
94 | """Embeds box prompts."""
95 | boxes = boxes + 0.5 # Shift to center of pixel
96 | coords = boxes.reshape(-1, 2, 2)
97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight
99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight
100 | return corner_embedding
101 |
102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
103 | """Embeds mask inputs."""
104 | mask_embedding = self.mask_downscaling(masks)
105 | return mask_embedding
106 |
107 | def _get_batch_size(
108 | self,
109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
110 | boxes: Optional[torch.Tensor],
111 | masks: Optional[torch.Tensor],
112 | ) -> int:
113 | """
114 | Gets the batch size of the output given the batch size of the input prompts.
115 | """
116 | if points is not None:
117 | return points[0].shape[0]
118 | elif boxes is not None:
119 | return boxes.shape[0]
120 | elif masks is not None:
121 | return masks.shape[0]
122 | else:
123 | return 1
124 |
125 | def _get_device(self) -> torch.device:
126 | return self.point_embeddings[0].weight.device
127 |
128 | def forward(
129 | self,
130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
131 | boxes: Optional[torch.Tensor],
132 | masks: Optional[torch.Tensor],
133 | ) -> Tuple[torch.Tensor, torch.Tensor]:
134 | """
135 | Embeds different types of prompts, returning both sparse and dense
136 | embeddings.
137 |
138 | Arguments:
139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
140 | and labels to embed.
141 | boxes (torch.Tensor or none): boxes to embed
142 | masks (torch.Tensor or none): masks to embed
143 |
144 | Returns:
145 | torch.Tensor: sparse embeddings for the points and boxes, with shape
146 | BxNx(embed_dim), where N is determined by the number of input points
147 | and boxes.
148 | torch.Tensor: dense embeddings for the masks, in the shape
149 | Bx(embed_dim)x(embed_H)x(embed_W)
150 | """
151 | bs = self._get_batch_size(points, boxes, masks)
152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
153 | if points is not None:
154 | coords, labels = points
155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
157 | if boxes is not None:
158 | box_embeddings = self._embed_boxes(boxes)
159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
160 |
161 | if masks is not None:
162 | dense_embeddings = self._embed_masks(masks)
163 | else:
164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
166 | )
167 |
168 | return sparse_embeddings, dense_embeddings
169 |
170 |
171 | class PositionEmbeddingRandom(nn.Module):
172 | """
173 | Positional encoding using random spatial frequencies.
174 | """
175 |
176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
177 | super().__init__()
178 | if scale is None or scale <= 0.0:
179 | scale = 1.0
180 | self.register_buffer(
181 | "positional_encoding_gaussian_matrix",
182 | scale * torch.randn((2, num_pos_feats)),
183 | )
184 |
185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
186 | """Positionally encode points that are normalized to [0,1]."""
187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
188 | coords = 2 * coords - 1
189 | coords = coords @ self.positional_encoding_gaussian_matrix
190 | coords = 2 * np.pi * coords
191 | # outputs d_1 x ... x d_n x C shape
192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
193 |
194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor:
195 | """Generate positional encoding for a grid of the specified size."""
196 | h, w = size
197 | device: Any = self.positional_encoding_gaussian_matrix.device
198 | grid = torch.ones((h, w), device=device, dtype=torch.float32)
199 | y_embed = grid.cumsum(dim=0) - 0.5
200 | x_embed = grid.cumsum(dim=1) - 0.5
201 | y_embed = y_embed / h
202 | x_embed = x_embed / w
203 |
204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
205 | return pe.permute(2, 0, 1) # C x H x W
206 |
207 | def forward_with_coords(
208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int]
209 | ) -> torch.Tensor:
210 | """Positionally encode points that are not normalized to [0,1]."""
211 | coords = coords_input.clone()
212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1]
213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0]
214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C
215 |
--------------------------------------------------------------------------------
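
An illustrative sketch (not repository code) of encoding a single foreground click, using the sizes from build_sam.py:

import torch
from segment_anything.modeling import PromptEncoder

prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
# One foreground point (label 1) in the 1024x1024 input frame.
points = (torch.tensor([[[500.0, 375.0]]]), torch.tensor([[1.0]]))
sparse, dense = prompt_encoder(points=points, boxes=None, masks=None)
# The point is padded with a "not a point" entry because no box is given.
print(sparse.shape, dense.shape)       # (1, 2, 256) and (1, 256, 64, 64)
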
/segment_anything/segment_anything/modeling/sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Any, Dict, List, Tuple
12 |
13 | from .image_encoder import ImageEncoderViT
14 | from .mask_decoder import MaskDecoder
15 | from .prompt_encoder import PromptEncoder
16 |
17 |
18 | class Sam(nn.Module):
19 | mask_threshold: float = 0.0
20 | image_format: str = "RGB"
21 |
22 | def __init__(
23 | self,
24 | image_encoder: ImageEncoderViT,
25 | prompt_encoder: PromptEncoder,
26 | mask_decoder: MaskDecoder,
27 | pixel_mean: List[float] = [123.675, 116.28, 103.53],
28 | pixel_std: List[float] = [58.395, 57.12, 57.375],
29 | ) -> None:
30 | """
31 | SAM predicts object masks from an image and input prompts.
32 |
33 | Arguments:
34 | image_encoder (ImageEncoderViT): The backbone used to encode the
35 | image into image embeddings that allow for efficient mask prediction.
36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts.
37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings
38 | and encoded prompts.
39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
40 | pixel_std (list(float)): Std values for normalizing pixels in the input image.
41 | """
42 | super().__init__()
43 | self.image_encoder = image_encoder
44 | self.prompt_encoder = prompt_encoder
45 | self.mask_decoder = mask_decoder
46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
48 |
49 | @property
50 | def device(self) -> Any:
51 | return self.pixel_mean.device
52 |
53 | @torch.no_grad()
54 | def forward(
55 | self,
56 | batched_input: List[Dict[str, Any]],
57 | multimask_output: bool,
58 | ) -> List[Dict[str, torch.Tensor]]:
59 | """
60 | Predicts masks end-to-end from provided images and prompts.
61 | If prompts are not known in advance, using SamPredictor is
62 | recommended over calling the model directly.
63 |
64 | Arguments:
65 | batched_input (list(dict)): A list over input images, each a
66 | dictionary with the following keys. A prompt key can be
67 | excluded if it is not present.
68 | 'image': The image as a torch tensor in 3xHxW format,
69 | already transformed for input to the model.
70 | 'original_size': (tuple(int, int)) The original size of
71 | the image before transformation, as (H, W).
72 | 'point_coords': (torch.Tensor) Batched point prompts for
73 | this image, with shape BxNx2. Already transformed to the
74 | input frame of the model.
75 | 'point_labels': (torch.Tensor) Batched labels for point prompts,
76 | with shape BxN.
77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
78 | Already transformed to the input frame of the model.
79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
80 | in the form Bx1xHxW.
81 | multimask_output (bool): Whether the model should predict multiple
82 | disambiguating masks, or return a single mask.
83 |
84 | Returns:
85 | (list(dict)): A list over input images, where each element is
 86 |             a dictionary with the following keys.
87 | 'masks': (torch.Tensor) Batched binary mask predictions,
 88 |                 with shape BxCxHxW, where B is the number of input prompts,
 89 |                 C is determined by multimask_output, and (H, W) is the
90 | original size of the image.
91 | 'iou_predictions': (torch.Tensor) The model's predictions
92 | of mask quality, in shape BxC.
93 | 'low_res_logits': (torch.Tensor) Low resolution logits with
94 | shape BxCxHxW, where H=W=256. Can be passed as mask input
95 | to subsequent iterations of prediction.
96 | """
97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
98 | image_embeddings = self.image_encoder(input_images)
99 |
100 | outputs = []
101 | for image_record, curr_embedding in zip(batched_input, image_embeddings):
102 | if "point_coords" in image_record:
103 | points = (image_record["point_coords"], image_record["point_labels"])
104 | else:
105 | points = None
106 | sparse_embeddings, dense_embeddings = self.prompt_encoder(
107 | points=points,
108 | boxes=image_record.get("boxes", None),
109 | masks=image_record.get("mask_inputs", None),
110 | )
111 | low_res_masks, iou_predictions = self.mask_decoder(
112 | image_embeddings=curr_embedding.unsqueeze(0),
113 | image_pe=self.prompt_encoder.get_dense_pe(),
114 | sparse_prompt_embeddings=sparse_embeddings,
115 | dense_prompt_embeddings=dense_embeddings,
116 | multimask_output=multimask_output,
117 | )
118 | masks = self.postprocess_masks(
119 | low_res_masks,
120 | input_size=image_record["image"].shape[-2:],
121 | original_size=image_record["original_size"],
122 | )
123 | masks = masks > self.mask_threshold
124 | outputs.append(
125 | {
126 | "masks": masks,
127 | "iou_predictions": iou_predictions,
128 | "low_res_logits": low_res_masks,
129 | }
130 | )
131 | return outputs
132 |
133 | def postprocess_masks(
134 | self,
135 | masks: torch.Tensor,
136 | input_size: Tuple[int, ...],
137 | original_size: Tuple[int, ...],
138 | ) -> torch.Tensor:
139 | """
140 | Remove padding and upscale masks to the original image size.
141 |
142 | Arguments:
143 | masks (torch.Tensor): Batched masks from the mask_decoder,
144 | in BxCxHxW format.
145 | input_size (tuple(int, int)): The size of the image input to the
146 | model, in (H, W) format. Used to remove padding.
147 | original_size (tuple(int, int)): The original size of the image
148 | before resizing for input to the model, in (H, W) format.
149 |
150 | Returns:
151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
152 | is given by original_size.
153 | """
154 | masks = F.interpolate(
155 | masks,
156 | (self.image_encoder.img_size, self.image_encoder.img_size),
157 | mode="bilinear",
158 | align_corners=False,
159 | )
160 | masks = masks[..., : input_size[0], : input_size[1]]
161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
162 | return masks
163 |
164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
165 | """Normalize pixel values and pad to a square input."""
166 | # Normalize colors
167 | x = (x - self.pixel_mean) / self.pixel_std
168 |
169 | # Pad
170 | h, w = x.shape[-2:]
171 | padh = self.image_encoder.img_size - h
172 | padw = self.image_encoder.img_size - w
173 | x = F.pad(x, (0, padw, 0, padh))
174 | return x
175 |
--------------------------------------------------------------------------------
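
An end-to-end sketch (not repository code; random weights, so the masks are meaningless) of the batched_input format documented in Sam.forward:

import torch
from segment_anything import sam_model_registry

sam = sam_model_registry["vit_b"]()                    # no checkpoint: random weights
image = torch.zeros(3, 768, 1024)                      # already resized so the longest side is 1024
batched_input = [{
    "image": image,                                    # padding to 1024x1024 happens in preprocess
    "original_size": (768, 1024),                      # (H, W) before resizing
    "point_coords": torch.tensor([[[512.0, 384.0]]]),  # BxNx2, in the model's input frame
    "point_labels": torch.tensor([[1.0]]),             # BxN, 1 = foreground
}]
outputs = sam(batched_input, multimask_output=True)
print(outputs[0]["masks"].shape)                       # (1, 3, 768, 1024) boolean masks
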
/segment_anything/segment_anything/modeling/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import Tensor, nn
9 |
10 | import math
11 | from typing import Tuple, Type
12 |
13 | from .common import MLPBlock
14 |
15 |
16 | class TwoWayTransformer(nn.Module):
17 | def __init__(
18 | self,
19 | depth: int,
20 | embedding_dim: int,
21 | num_heads: int,
22 | mlp_dim: int,
23 | activation: Type[nn.Module] = nn.ReLU,
24 | attention_downsample_rate: int = 2,
25 | ) -> None:
26 | """
27 | A transformer decoder that attends to an input image using
28 | queries whose positional embedding is supplied.
29 |
30 | Args:
31 | depth (int): number of layers in the transformer
32 | embedding_dim (int): the channel dimension for the input embeddings
33 | num_heads (int): the number of heads for multihead attention. Must
34 | divide embedding_dim
35 | mlp_dim (int): the channel dimension internal to the MLP block
36 | activation (nn.Module): the activation to use in the MLP block
37 | """
38 | super().__init__()
39 | self.depth = depth
40 | self.embedding_dim = embedding_dim
41 | self.num_heads = num_heads
42 | self.mlp_dim = mlp_dim
43 | self.layers = nn.ModuleList()
44 |
45 | for i in range(depth):
46 | self.layers.append(
47 | TwoWayAttentionBlock(
48 | embedding_dim=embedding_dim,
49 | num_heads=num_heads,
50 | mlp_dim=mlp_dim,
51 | activation=activation,
52 | attention_downsample_rate=attention_downsample_rate,
53 | skip_first_layer_pe=(i == 0),
54 | )
55 | )
56 |
57 | self.final_attn_token_to_image = Attention(
58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
59 | )
60 | self.norm_final_attn = nn.LayerNorm(embedding_dim)
61 |
62 | def forward(
63 | self,
64 | image_embedding: Tensor,
65 | image_pe: Tensor,
66 | point_embedding: Tensor,
67 | ) -> Tuple[Tensor, Tensor]:
68 | """
69 | Args:
70 | image_embedding (torch.Tensor): image to attend to. Should be shape
71 | B x embedding_dim x h x w for any h and w.
72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must
73 | have the same shape as image_embedding.
74 | point_embedding (torch.Tensor): the embedding to add to the query points.
75 | Must have shape B x N_points x embedding_dim for any N_points.
76 |
77 | Returns:
78 | torch.Tensor: the processed point_embedding
79 | torch.Tensor: the processed image_embedding
80 | """
81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C
82 | bs, c, h, w = image_embedding.shape
83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
84 | image_pe = image_pe.flatten(2).permute(0, 2, 1)
85 |
86 | # Prepare queries
87 | queries = point_embedding
88 | keys = image_embedding
89 |
90 | # Apply transformer blocks and final layernorm
91 | for layer in self.layers:
92 | queries, keys = layer(
93 | queries=queries,
94 | keys=keys,
95 | query_pe=point_embedding,
96 | key_pe=image_pe,
97 | )
98 |
 99 |         # Apply the final attention layer from the points to the image
100 | q = queries + point_embedding
101 | k = keys + image_pe
102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
103 | queries = queries + attn_out
104 | queries = self.norm_final_attn(queries)
105 |
106 | return queries, keys
107 |
108 |
109 | class TwoWayAttentionBlock(nn.Module):
110 | def __init__(
111 | self,
112 | embedding_dim: int,
113 | num_heads: int,
114 | mlp_dim: int = 2048,
115 | activation: Type[nn.Module] = nn.ReLU,
116 | attention_downsample_rate: int = 2,
117 | skip_first_layer_pe: bool = False,
118 | ) -> None:
119 | """
120 | A transformer block with four layers: (1) self-attention of sparse
121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse
123 | inputs.
124 |
125 | Arguments:
126 | embedding_dim (int): the channel dimension of the embeddings
127 | num_heads (int): the number of heads in the attention layers
128 | mlp_dim (int): the hidden dimension of the mlp block
129 | activation (nn.Module): the activation of the mlp block
130 | skip_first_layer_pe (bool): skip the PE on the first layer
131 | """
132 | super().__init__()
133 | self.self_attn = Attention(embedding_dim, num_heads)
134 | self.norm1 = nn.LayerNorm(embedding_dim)
135 |
136 | self.cross_attn_token_to_image = Attention(
137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
138 | )
139 | self.norm2 = nn.LayerNorm(embedding_dim)
140 |
141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
142 | self.norm3 = nn.LayerNorm(embedding_dim)
143 |
144 | self.norm4 = nn.LayerNorm(embedding_dim)
145 | self.cross_attn_image_to_token = Attention(
146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147 | )
148 |
149 | self.skip_first_layer_pe = skip_first_layer_pe
150 |
151 | def forward(
152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
153 | ) -> Tuple[Tensor, Tensor]:
154 | # Self attention block
155 | if self.skip_first_layer_pe:
156 | queries = self.self_attn(q=queries, k=queries, v=queries)
157 | else:
158 | q = queries + query_pe
159 | attn_out = self.self_attn(q=q, k=q, v=queries)
160 | queries = queries + attn_out
161 | queries = self.norm1(queries)
162 |
163 | # Cross attention block, tokens attending to image embedding
164 | q = queries + query_pe
165 | k = keys + key_pe
166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
167 | queries = queries + attn_out
168 | queries = self.norm2(queries)
169 |
170 | # MLP block
171 | mlp_out = self.mlp(queries)
172 | queries = queries + mlp_out
173 | queries = self.norm3(queries)
174 |
175 | # Cross attention block, image embedding attending to tokens
176 | q = queries + query_pe
177 | k = keys + key_pe
178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
179 | keys = keys + attn_out
180 | keys = self.norm4(keys)
181 |
182 | return queries, keys
183 |
184 |
185 | class Attention(nn.Module):
186 | """
187 | An attention layer that allows for downscaling the size of the embedding
188 | after projection to queries, keys, and values.
189 | """
190 |
191 | def __init__(
192 | self,
193 | embedding_dim: int,
194 | num_heads: int,
195 | downsample_rate: int = 1,
196 | ) -> None:
197 | super().__init__()
198 | self.embedding_dim = embedding_dim
199 | self.internal_dim = embedding_dim // downsample_rate
200 | self.num_heads = num_heads
201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
202 |
203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
207 |
208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
209 | b, n, c = x.shape
210 | x = x.reshape(b, n, num_heads, c // num_heads)
211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
212 |
213 | def _recombine_heads(self, x: Tensor) -> Tensor:
214 | b, n_heads, n_tokens, c_per_head = x.shape
215 | x = x.transpose(1, 2)
216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
217 |
218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
219 | # Input projections
220 | q = self.q_proj(q)
221 | k = self.k_proj(k)
222 | v = self.v_proj(v)
223 |
224 | # Separate into heads
225 | q = self._separate_heads(q, self.num_heads)
226 | k = self._separate_heads(k, self.num_heads)
227 | v = self._separate_heads(v, self.num_heads)
228 |
229 | # Attention
230 | _, _, _, c_per_head = q.shape
231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
232 | attn = attn / math.sqrt(c_per_head)
233 | attn = torch.softmax(attn, dim=-1)
234 |
235 | # Get output
236 | out = attn @ v
237 | out = self._recombine_heads(out)
238 | out = self.out_proj(out)
239 |
240 | return out
241 |
--------------------------------------------------------------------------------
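
A shape sketch (not repository code) of the TwoWayTransformer on its own:

import torch
from segment_anything.modeling import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 64, 64)
image_pe = torch.randn(1, 256, 64, 64)
point_embedding = torch.randn(1, 7, 256)   # e.g. iou token + 4 mask tokens + 2 prompt tokens

queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape, keys.shape)           # (1, 7, 256) and (1, 4096, 256); 4096 = 64*64 image tokens
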
/segment_anything/segment_anything/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/segment_anything/segment_anything/utils/onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Tuple
12 |
13 | from ..modeling import Sam
14 | from .amg import calculate_stability_score
15 |
16 |
17 | class SamOnnxModel(nn.Module):
18 | """
19 | This model should not be called directly, but is used in ONNX export.
20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21 | with some functions modified to enable model tracing. Also supports extra
 22 |     options controlling what information is returned. See the ONNX export script for details.
23 | """
24 |
25 | def __init__(
26 | self,
27 | model: Sam,
28 | return_single_mask: bool,
29 | use_stability_score: bool = False,
30 | return_extra_metrics: bool = False,
31 | ) -> None:
32 | super().__init__()
33 | self.mask_decoder = model.mask_decoder
34 | self.model = model
35 | self.img_size = model.image_encoder.img_size
36 | self.return_single_mask = return_single_mask
37 | self.use_stability_score = use_stability_score
38 | self.stability_score_offset = 1.0
39 | self.return_extra_metrics = return_extra_metrics
40 |
41 | @staticmethod
42 | def resize_longest_image_size(
43 | input_image_size: torch.Tensor, longest_side: int
44 | ) -> torch.Tensor:
45 | input_image_size = input_image_size.to(torch.float32)
46 | scale = longest_side / torch.max(input_image_size)
47 | transformed_size = scale * input_image_size
48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
49 | return transformed_size
50 |
51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
52 | point_coords = point_coords + 0.5
53 | point_coords = point_coords / self.img_size
54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
56 |
57 | point_embedding = point_embedding * (point_labels != -1)
58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
59 | point_labels == -1
60 | )
61 |
62 | for i in range(self.model.prompt_encoder.num_point_embeddings):
63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
64 | i
65 | ].weight * (point_labels == i)
66 |
67 | return point_embedding
68 |
69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
71 | mask_embedding = mask_embedding + (
72 | 1 - has_mask_input
73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
74 | return mask_embedding
75 |
76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
77 | masks = F.interpolate(
78 | masks,
79 | size=(self.img_size, self.img_size),
80 | mode="bilinear",
81 | align_corners=False,
82 | )
83 |
84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size)
85 | masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]
86 |
87 | orig_im_size = orig_im_size.to(torch.int64)
88 | h, w = orig_im_size[0], orig_im_size[1]
89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
90 | return masks
91 |
92 | def select_masks(
93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
94 | ) -> Tuple[torch.Tensor, torch.Tensor]:
95 | # Determine if we should return the multiclick mask or not from the number of points.
96 | # The reweighting is used to avoid control flow.
97 | score_reweight = torch.tensor(
98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
99 | ).to(iou_preds.device)
100 | score = iou_preds + (num_points - 2.5) * score_reweight
101 | best_idx = torch.argmax(score, dim=1)
102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
104 |
105 | return masks, iou_preds
106 |
107 | @torch.no_grad()
108 | def forward(
109 | self,
110 | image_embeddings: torch.Tensor,
111 | point_coords: torch.Tensor,
112 | point_labels: torch.Tensor,
113 | mask_input: torch.Tensor,
114 | has_mask_input: torch.Tensor,
115 | orig_im_size: torch.Tensor,
116 | ):
117 | sparse_embedding = self._embed_points(point_coords, point_labels)
118 | dense_embedding = self._embed_masks(mask_input, has_mask_input)
119 |
120 | masks, scores = self.model.mask_decoder.predict_masks(
121 | image_embeddings=image_embeddings,
122 | image_pe=self.model.prompt_encoder.get_dense_pe(),
123 | sparse_prompt_embeddings=sparse_embedding,
124 | dense_prompt_embeddings=dense_embedding,
125 | )
126 |
127 | if self.use_stability_score:
128 | scores = calculate_stability_score(
129 | masks, self.model.mask_threshold, self.stability_score_offset
130 | )
131 |
132 | if self.return_single_mask:
133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
134 |
135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
136 |
137 | if self.return_extra_metrics:
138 | stability_scores = calculate_stability_score(
139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset
140 | )
141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
142 | return upscaled_masks, scores, stability_scores, areas, masks
143 |
144 | return upscaled_masks, scores, masks
145 |
--------------------------------------------------------------------------------
/segment_anything/segment_anything/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn import functional as F
10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11 |
12 | from copy import deepcopy
13 | from typing import Tuple
14 |
15 |
16 | class ResizeLongestSide:
17 | """
 18 |     Resizes images so that their longest side equals 'target_length', and
 19 |     provides methods for resizing coordinates and boxes accordingly. Works on
 20 |     both numpy arrays and batched torch tensors.
21 | """
22 |
23 | def __init__(self, target_length: int) -> None:
24 | self.target_length = target_length
25 |
26 | def apply_image(self, image: np.ndarray) -> np.ndarray:
27 | """
28 | Expects a numpy array with shape HxWxC in uint8 format.
29 | """
30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31 | return np.array(resize(to_pil_image(image), target_size))
32 |
33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34 | """
35 | Expects a numpy array of length 2 in the final dimension. Requires the
36 | original image size in (H, W) format.
37 | """
38 | old_h, old_w = original_size
39 | new_h, new_w = self.get_preprocess_shape(
40 | original_size[0], original_size[1], self.target_length
41 | )
42 | coords = deepcopy(coords).astype(float)
43 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
44 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
45 | return coords
46 |
47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
48 | """
49 | Expects a numpy array shape Bx4. Requires the original image size
50 | in (H, W) format.
51 | """
52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
53 | return boxes.reshape(-1, 4)
54 |
55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
56 | """
57 | Expects batched images with shape BxCxHxW and float format. This
58 | transformation may not exactly match apply_image. apply_image is
59 | the transformation expected by the model.
60 | """
61 | # Expects an image in BCHW format. May not exactly match apply_image.
 62 |         target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
63 | return F.interpolate(
64 | image, target_size, mode="bilinear", align_corners=False, antialias=True
65 | )
66 |
67 | def apply_coords_torch(
68 | self, coords: torch.Tensor, original_size: Tuple[int, ...]
69 | ) -> torch.Tensor:
70 | """
71 | Expects a torch tensor with length 2 in the last dimension. Requires the
72 | original image size in (H, W) format.
73 | """
74 | old_h, old_w = original_size
75 | new_h, new_w = self.get_preprocess_shape(
76 | original_size[0], original_size[1], self.target_length
77 | )
78 | coords = deepcopy(coords).to(torch.float)
79 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
80 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
81 | return coords
82 |
83 | def apply_boxes_torch(
84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...]
85 | ) -> torch.Tensor:
86 | """
87 | Expects a torch tensor with shape Bx4. Requires the original image
88 | size in (H, W) format.
89 | """
90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
91 | return boxes.reshape(-1, 4)
92 |
93 | @staticmethod
94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
95 | """
96 | Compute the output size given input size and target long side length.
97 | """
98 | scale = long_side_length * 1.0 / max(oldh, oldw)
99 | newh, neww = oldh * scale, oldw * scale
100 | neww = int(neww + 0.5)
101 | newh = int(newh + 0.5)
102 | return (newh, neww)
103 |
--------------------------------------------------------------------------------
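
An illustrative sketch (not repository code) of ResizeLongestSide with the model's 1024-pixel target length:

import numpy as np
from segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(1024)
image = np.zeros((600, 800, 3), dtype=np.uint8)              # H x W x C, uint8
resized = transform.apply_image(image)
print(resized.shape)                                         # (768, 1024, 3): scale = 1024 / 800
# Coordinates are rescaled with the same factors, given the original (H, W).
print(transform.apply_coords(np.array([[400.0, 300.0]]), (600, 800)))   # [[512. 384.]]
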
/segment_anything/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | line_length=100
3 | multi_line_output=3
4 | include_trailing_comma=True
5 | known_standard_library=numpy,setuptools
6 | skip_glob=*/__init__.py
7 | known_myself=segment_anything
8 | known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort
9 | no_lines_before=STDLIB,THIRDPARTY
10 | sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER
11 | default_section=FIRSTPARTY
12 |
--------------------------------------------------------------------------------
/segment_anything/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from setuptools import find_packages, setup
8 |
9 | setup(
10 | name="segment_anything",
11 | version="1.0",
12 | install_requires=[],
13 |     packages=find_packages(exclude=("notebooks",)),
14 | extras_require={
15 | "all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime"],
16 | "dev": ["flake8", "isort", "black", "mypy"],
17 | },
18 | )
19 |
--------------------------------------------------------------------------------