├── .gitignore ├── GroundingDINO ├── .asset │ ├── COCO.png │ ├── GD_GLIGEN.png │ ├── GD_SD.png │ ├── ODinW.png │ ├── arch.png │ ├── cats.png │ └── hero_figure.png ├── LICENSE ├── README.md ├── demo │ ├── gradio_app.py │ └── inference_on_a_image.py ├── groundingdino │ ├── __init__.py │ ├── config │ │ └── GroundingDINO_SwinT_OGC.py │ ├── datasets │ │ └── transforms.py │ ├── models │ │ ├── GroundingDINO │ │ │ ├── __init__.py │ │ │ ├── backbone │ │ │ │ ├── __init__.py │ │ │ │ ├── backbone.py │ │ │ │ ├── position_encoding.py │ │ │ │ └── swin_transformer.py │ │ │ ├── bertwarper.py │ │ │ ├── csrc │ │ │ │ ├── MsDeformAttn │ │ │ │ │ ├── ms_deform_attn.h │ │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ │ ├── ms_deform_attn_cpu.h │ │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ │ ├── cuda_version.cu │ │ │ │ └── vision.cpp │ │ │ ├── fuse_modules.py │ │ │ ├── groundingdino.py │ │ │ ├── ms_deform_attn.py │ │ │ ├── transformer.py │ │ │ ├── transformer_vanilla.py │ │ │ └── utils.py │ │ ├── __init__.py │ │ └── registry.py │ ├── util │ │ ├── __init__.py │ │ ├── box_ops.py │ │ ├── get_tokenlizer.py │ │ ├── inference.py │ │ ├── logger.py │ │ ├── misc.py │ │ ├── slconfig.py │ │ ├── slio.py │ │ ├── time_counter.py │ │ ├── utils.py │ │ ├── visualizer.py │ │ └── vl_utils.py │ └── version.py ├── requirements.txt └── setup.py ├── LICENSE ├── README.md ├── assets ├── Grounded-SAM_logo.png ├── automatic_label_output │ ├── demo1.jpg │ ├── demo2.jpg │ ├── demo4.jpg │ └── demo8.jpg ├── automatic_label_output_demo3.jpg ├── demo1.jpg ├── demo2.jpg ├── demo3.jpg ├── demo4.jpg ├── demo5.jpg ├── demo6.jpg ├── demo7.jpg ├── demo8.jpg ├── gradio_demo.png ├── grounded_sam.jpg ├── grounded_sam2.png ├── grounded_sam_demo3_demo4.png ├── grounded_sam_inpainting_demo.png ├── grounded_sam_output_demo1.jpg ├── grounding_dino_output_demo1.jpg └── inpaint_demo.jpg ├── automatic_label_demo.py ├── gradio_app.py ├── grounded_dino_sam_inpainting_demo.py ├── grounded_sam.ipynb ├── grounded_sam_demo.py ├── grounded_sam_inpainting_demo.py ├── grounding_dino_demo.py ├── gsa_api.py ├── requirements.txt └── segment_anything ├── .flake8 ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets ├── masks1.png ├── masks2.jpg ├── model_diagram.png ├── notebook1.png └── notebook2.png ├── linter.sh ├── notebooks ├── automatic_mask_generator_example.ipynb ├── images │ ├── dog.jpg │ ├── groceries.jpg │ └── truck.jpg ├── onnx_model_example.ipynb └── predictor_example.ipynb ├── scripts ├── amg.py └── export_onnx_model.py ├── segment_anything ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ └── transformer.py ├── predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── onnx.py │ └── transforms.py ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script 
from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # checkpoint 132 | *.pth -------------------------------------------------------------------------------- /GroundingDINO/.asset/COCO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/.asset/COCO.png -------------------------------------------------------------------------------- /GroundingDINO/.asset/GD_GLIGEN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/.asset/GD_GLIGEN.png -------------------------------------------------------------------------------- /GroundingDINO/.asset/GD_SD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/.asset/GD_SD.png -------------------------------------------------------------------------------- /GroundingDINO/.asset/ODinW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/.asset/ODinW.png -------------------------------------------------------------------------------- /GroundingDINO/.asset/arch.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/.asset/arch.png -------------------------------------------------------------------------------- /GroundingDINO/.asset/cats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/.asset/cats.png -------------------------------------------------------------------------------- /GroundingDINO/.asset/hero_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/.asset/hero_figure.png -------------------------------------------------------------------------------- /GroundingDINO/README.md: -------------------------------------------------------------------------------- 1 | # Grounding DINO 2 | 3 | --- 4 | 5 | [![arXiv](https://img.shields.io/badge/arXiv-2303.05499-b31b1b.svg)](https://arxiv.org/abs/2303.05499) 6 | [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/wxWDt5UiwY8) 7 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) 8 | [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/cMa77r3YrDk) 9 | [![HuggingFace space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo) 10 | 11 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-mscoco)](https://paperswithcode.com/sota/zero-shot-object-detection-on-mscoco?p=grounding-dino-marrying-dino-with-grounded) \ 12 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-odinw)](https://paperswithcode.com/sota/zero-shot-object-detection-on-odinw?p=grounding-dino-marrying-dino-with-grounded) \ 13 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=grounding-dino-marrying-dino-with-grounded) \ 14 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=grounding-dino-marrying-dino-with-grounded) 15 | 16 | 17 | 18 | Official PyTorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499), a stronger open-set object detector. Code is available now! 19 | 20 | 21 | ## Highlight 22 | 23 | - **Open-Set Detection.** Detect **everything** with language! 24 | - **High Performancce.** COCO zero-shot **52.5 AP** (training without COCO data!). COCO fine-tune **63.0 AP**. 25 | - **Flexible.** Collaboration with Stable Diffusion for Image Editting. 26 | 27 | ## News 28 | [2023/03/28] A YouTube [video](https://youtu.be/cMa77r3YrDk) about Grounding DINO and basic object detection prompt engineering. 
[[SkalskiP](https://github.com/SkalskiP)] \ 29 | [2023/03/28] Added a [demo](https://huggingface.co/spaces/ShilongLiu/Grounding_DINO_demo) on Hugging Face Space! \ 30 | [2023/03/27] Support CPU-only mode; the model can now run on machines without GPUs. \ 31 | [2023/03/25] A [demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) for Grounding DINO is available on Colab. [[SkalskiP](https://github.com/SkalskiP)] \ 32 | [2023/03/22] Code is now available! 33 | 34 |
35 | **Description** 36 | 37 | ![ODinW](.asset/hero_figure.png) 38 | 39 |
40 | 41 | 42 | 43 | ## TODO 44 | 45 | - [x] Release inference code and demo. 46 | - [x] Release checkpoints. 47 | - [ ] Grounding DINO with Stable Diffusion and GLIGEN demos. 48 | - [ ] Release training codes. 49 | 50 | ## Install 51 | 52 | If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set. It will be compiled under CPU-only mode if no CUDA available. 53 | 54 | ```bash 55 | pip install -e . 56 | ``` 57 | 58 | ## Demo 59 | 60 | ```bash 61 | CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \ 62 | -c /path/to/config \ 63 | -p /path/to/checkpoint \ 64 | -i .asset/cats.png \ 65 | -o "outputs/0" \ 66 | -t "cat ear." \ 67 | [--cpu-only] # open it for cpu mode 68 | ``` 69 | See the `demo/inference_on_a_image.py` for more details. 70 | 71 | **Web UI** 72 | 73 | We also provide a demo code to integrate Grounding DINO with Gradio Web UI. See the file `demo/gradio_app.py` for more details. 74 | 75 | ## Checkpoints 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 |
|   | name | backbone | Data | box AP on COCO | Checkpoint | Config |
|---|------|----------|------|----------------|------------|--------|
| 1 | GroundingDINO-T | Swin-T | O365, GoldG, Cap4M | 48.4 (zero-shot) / 57.2 (fine-tune) | Github link \| HF link | link |
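For quick programmatic use of the checkpoint above, the sketch below mirrors the loading-and-inference path used in `demo/gradio_app.py`: build the model from the SwinT config, fetch `groundingdino_swint_ogc.pth` from the `ShilongLiu/GroundingDINO` Hugging Face repo, then call `predict`/`annotate` from `groundingdino.util.inference`. Treat it as a minimal CPU-only sketch; the input image path and output filename are placeholders.

```python
# Minimal sketch (CPU-only) of checkpoint loading + one prediction, following demo/gradio_app.py.
import cv2
import numpy as np
import torch
from PIL import Image
from huggingface_hub import hf_hub_download

import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util.inference import annotate, predict
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict

# Build GroundingDINO-T from its config and load the released weights.
args = SLConfig.fromfile("groundingdino/config/GroundingDINO_SwinT_OGC.py")
args.device = "cpu"
model = build_model(args)
ckpt_path = hf_hub_download("ShilongLiu/GroundingDINO", "groundingdino_swint_ogc.pth")
state = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(clean_state_dict(state["model"]), strict=False)
model.eval()

# Preprocess the image exactly as the demos do.
transform = T.Compose([
    T.RandomResize([800], max_size=1333),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
image_pil = Image.open(".asset/cats.png").convert("RGB")  # placeholder image path
image_tensor, _ = transform(image_pil, None)

# Detect phrases from the text prompt and draw the boxes.
boxes, logits, phrases = predict(
    model, image_tensor, "cat ear.",
    box_threshold=0.25, text_threshold=0.25, device="cpu",
)
annotated = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)).save("annotated_cats.jpg")
```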
102 | 103 | ## Results 104 | 105 |
106 | **COCO Object Detection Results** 107 | 108 | ![COCO](.asset/COCO.png) 109 | 110 |
111 | 112 |
113 | **ODinW Object Detection Results** 114 | 115 | ![ODinW](.asset/ODinW.png) 116 | 117 |
118 | 119 |
120 | **Marrying Grounding DINO with Stable Diffusion for Image Editing** 121 | 122 | ![GD_SD](.asset/GD_SD.png) 123 | 124 |
125 | 126 |
127 | **Marrying Grounding DINO with GLIGEN for more Detailed Image Editing** 128 | 129 | ![GD_GLIGEN](.asset/GD_GLIGEN.png) 130 | 131 |
132 | 133 | ## Model 134 | 135 | Includes: a text backbone, an image backbone, a feature enhancer, a language-guided query selection, and a cross-modality decoder. 136 | 137 | ![arch](.asset/arch.png) 138 | 139 | 140 | ## Acknowledgement 141 | 142 | Our model is related to [DINO](https://github.com/IDEA-Research/DINO) and [GLIP](https://github.com/microsoft/GLIP). Thanks for their great work! 143 | 144 | We also thank great previous work including DETR, Deformable DETR, SMCA, Conditional DETR, Anchor DETR, Dynamic DETR, DAB-DETR, DN-DETR, etc. More related work are available at [Awesome Detection Transformer](https://github.com/IDEACVR/awesome-detection-transformer). A new toolbox [detrex](https://github.com/IDEA-Research/detrex) is available as well. 145 | 146 | Thanks [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) and [GLIGEN](https://github.com/gligen/GLIGEN) for their awesome models. 147 | 148 | 149 | ## Citation 150 | 151 | If you find our work helpful for your research, please consider citing the following BibTeX entry. 152 | 153 | ```bibtex 154 | @inproceedings{ShilongLiu2023GroundingDM, 155 | title={Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection}, 156 | author={Shilong Liu and Zhaoyang Zeng and Tianhe Ren and Feng Li and Hao Zhang and Jie Yang and Chunyuan Li and Jianwei Yang and Hang Su and Jun Zhu and Lei Zhang}, 157 | year={2023} 158 | } 159 | ``` 160 | 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /GroundingDINO/demo/gradio_app.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from functools import partial 3 | import cv2 4 | import requests 5 | import os 6 | from io import BytesIO 7 | from PIL import Image 8 | import numpy as np 9 | from pathlib import Path 10 | 11 | 12 | import warnings 13 | 14 | import torch 15 | 16 | # prepare the environment 17 | os.system("python setup.py build develop --user") 18 | os.system("pip install packaging==21.3") 19 | os.system("pip install gradio") 20 | 21 | 22 | warnings.filterwarnings("ignore") 23 | 24 | import gradio as gr 25 | 26 | from groundingdino.models import build_model 27 | from groundingdino.util.slconfig import SLConfig 28 | from groundingdino.util.utils import clean_state_dict 29 | from groundingdino.util.inference import annotate, load_image, predict 30 | import groundingdino.datasets.transforms as T 31 | 32 | from huggingface_hub import hf_hub_download 33 | 34 | 35 | 36 | # Use this command for evaluate the GLIP-T model 37 | config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py" 38 | ckpt_repo_id = "ShilongLiu/GroundingDINO" 39 | ckpt_filenmae = "groundingdino_swint_ogc.pth" 40 | 41 | 42 | def load_model_hf(model_config_path, repo_id, filename, device='cpu'): 43 | args = SLConfig.fromfile(model_config_path) 44 | model = build_model(args) 45 | args.device = device 46 | 47 | cache_file = hf_hub_download(repo_id=repo_id, filename=filename) 48 | checkpoint = torch.load(cache_file, map_location='cpu') 49 | log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False) 50 | print("Model loaded from {} \n => {}".format(cache_file, log)) 51 | _ = model.eval() 52 | return model 53 | 54 | def image_transform_grounding(init_image): 55 | transform = T.Compose([ 56 | T.RandomResize([800], max_size=1333), 57 | T.ToTensor(), 58 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 59 | ]) 60 | image, _ = transform(init_image, None) # 3, 
h, w 61 | return init_image, image 62 | 63 | def image_transform_grounding_for_vis(init_image): 64 | transform = T.Compose([ 65 | T.RandomResize([800], max_size=1333), 66 | ]) 67 | image, _ = transform(init_image, None) # 3, h, w 68 | return image 69 | 70 | model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae) 71 | 72 | def run_grounding(input_image, grounding_caption, box_threshold, text_threshold): 73 | init_image = input_image.convert("RGB") 74 | original_size = init_image.size 75 | 76 | _, image_tensor = image_transform_grounding(init_image) 77 | image_pil: Image = image_transform_grounding_for_vis(init_image) 78 | 79 | # run grounidng 80 | boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu') 81 | annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases) 82 | image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)) 83 | 84 | 85 | return image_with_box 86 | 87 | if __name__ == "__main__": 88 | 89 | parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True) 90 | parser.add_argument("--debug", action="store_true", help="using debug mode") 91 | parser.add_argument("--share", action="store_true", help="share the app") 92 | args = parser.parse_args() 93 | 94 | block = gr.Blocks().queue() 95 | with block: 96 | gr.Markdown("# [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO)") 97 | gr.Markdown("### Open-World Detection with Grounding DINO") 98 | 99 | with gr.Row(): 100 | with gr.Column(): 101 | input_image = gr.Image(source='upload', type="pil") 102 | grounding_caption = gr.Textbox(label="Detection Prompt") 103 | run_button = gr.Button(label="Run") 104 | with gr.Accordion("Advanced options", open=False): 105 | box_threshold = gr.Slider( 106 | label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001 107 | ) 108 | text_threshold = gr.Slider( 109 | label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001 110 | ) 111 | 112 | with gr.Column(): 113 | gallery = gr.outputs.Image( 114 | type="pil", 115 | # label="grounding results" 116 | ).style(full_width=True, full_height=True) 117 | # gallery = gr.Gallery(label="Generated images", show_label=False).style( 118 | # grid=[1], height="auto", container=True, full_width=True, full_height=True) 119 | 120 | run_button.click(fn=run_grounding, inputs=[ 121 | input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery]) 122 | 123 | 124 | block.launch(server_name='0.0.0.0', server_port=7579, debug=args.debug, share=args.share) 125 | 126 | -------------------------------------------------------------------------------- /GroundingDINO/demo/inference_on_a_image.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image, ImageDraw, ImageFont 8 | 9 | import groundingdino.datasets.transforms as T 10 | from groundingdino.models import build_model 11 | from groundingdino.util import box_ops 12 | from groundingdino.util.slconfig import SLConfig 13 | from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap 14 | 15 | 16 | def plot_boxes_to_image(image_pil, tgt): 17 | H, W = tgt["size"] 18 | boxes = tgt["boxes"] 19 | labels = tgt["labels"] 20 | assert len(boxes) == len(labels), "boxes and labels must have same length" 21 | 22 | draw = ImageDraw.Draw(image_pil) 23 | mask = 
Image.new("L", image_pil.size, 0) 24 | mask_draw = ImageDraw.Draw(mask) 25 | 26 | # draw boxes and masks 27 | for box, label in zip(boxes, labels): 28 | # from 0..1 to 0..W, 0..H 29 | box = box * torch.Tensor([W, H, W, H]) 30 | # from xywh to xyxy 31 | box[:2] -= box[2:] / 2 32 | box[2:] += box[:2] 33 | # random color 34 | color = tuple(np.random.randint(0, 255, size=3).tolist()) 35 | # draw 36 | x0, y0, x1, y1 = box 37 | x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1) 38 | 39 | draw.rectangle([x0, y0, x1, y1], outline=color, width=6) 40 | # draw.text((x0, y0), str(label), fill=color) 41 | 42 | font = ImageFont.load_default() 43 | if hasattr(font, "getbbox"): 44 | bbox = draw.textbbox((x0, y0), str(label), font) 45 | else: 46 | w, h = draw.textsize(str(label), font) 47 | bbox = (x0, y0, w + x0, y0 + h) 48 | # bbox = draw.textbbox((x0, y0), str(label)) 49 | draw.rectangle(bbox, fill=color) 50 | draw.text((x0, y0), str(label), fill="white") 51 | 52 | mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6) 53 | 54 | return image_pil, mask 55 | 56 | 57 | def load_image(image_path): 58 | # load image 59 | image_pil = Image.open(image_path).convert("RGB") # load image 60 | 61 | transform = T.Compose( 62 | [ 63 | T.RandomResize([800], max_size=1333), 64 | T.ToTensor(), 65 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 66 | ] 67 | ) 68 | image, _ = transform(image_pil, None) # 3, h, w 69 | return image_pil, image 70 | 71 | 72 | def load_model(model_config_path, model_checkpoint_path, cpu_only=False): 73 | args = SLConfig.fromfile(model_config_path) 74 | args.device = "cuda" if not cpu_only else "cpu" 75 | model = build_model(args) 76 | checkpoint = torch.load(model_checkpoint_path, map_location="cpu") 77 | load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) 78 | print(load_res) 79 | _ = model.eval() 80 | return model 81 | 82 | 83 | def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, cpu_only=False): 84 | caption = caption.lower() 85 | caption = caption.strip() 86 | if not caption.endswith("."): 87 | caption = caption + "." 
88 | device = "cuda" if not cpu_only else "cpu" 89 | model = model.to(device) 90 | image = image.to(device) 91 | with torch.no_grad(): 92 | outputs = model(image[None], captions=[caption]) 93 | logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256) 94 | boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4) 95 | logits.shape[0] 96 | 97 | # filter output 98 | logits_filt = logits.clone() 99 | boxes_filt = boxes.clone() 100 | filt_mask = logits_filt.max(dim=1)[0] > box_threshold 101 | logits_filt = logits_filt[filt_mask] # num_filt, 256 102 | boxes_filt = boxes_filt[filt_mask] # num_filt, 4 103 | logits_filt.shape[0] 104 | 105 | # get phrase 106 | tokenlizer = model.tokenizer 107 | tokenized = tokenlizer(caption) 108 | # build pred 109 | pred_phrases = [] 110 | for logit, box in zip(logits_filt, boxes_filt): 111 | pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer) 112 | if with_logits: 113 | pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})") 114 | else: 115 | pred_phrases.append(pred_phrase) 116 | 117 | return boxes_filt, pred_phrases 118 | 119 | 120 | if __name__ == "__main__": 121 | 122 | parser = argparse.ArgumentParser("Grounding DINO example", add_help=True) 123 | parser.add_argument("--config_file", "-c", type=str, required=True, help="path to config file") 124 | parser.add_argument( 125 | "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file" 126 | ) 127 | parser.add_argument("--image_path", "-i", type=str, required=True, help="path to image file") 128 | parser.add_argument("--text_prompt", "-t", type=str, required=True, help="text prompt") 129 | parser.add_argument( 130 | "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory" 131 | ) 132 | 133 | parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold") 134 | parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold") 135 | 136 | parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False") 137 | args = parser.parse_args() 138 | 139 | # cfg 140 | config_file = args.config_file # change the path of the model config file 141 | checkpoint_path = args.checkpoint_path # change the path of the model 142 | image_path = args.image_path 143 | text_prompt = args.text_prompt 144 | output_dir = args.output_dir 145 | box_threshold = args.box_threshold 146 | text_threshold = args.box_threshold 147 | 148 | # make dir 149 | os.makedirs(output_dir, exist_ok=True) 150 | # load image 151 | image_pil, image = load_image(image_path) 152 | # load model 153 | model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only) 154 | 155 | # visualize raw image 156 | image_pil.save(os.path.join(output_dir, "raw_image.jpg")) 157 | 158 | # run model 159 | boxes_filt, pred_phrases = get_grounding_output( 160 | model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only 161 | ) 162 | 163 | # visualize pred 164 | size = image_pil.size 165 | pred_dict = { 166 | "boxes": boxes_filt, 167 | "size": [size[1], size[0]], # H,W 168 | "labels": pred_phrases, 169 | } 170 | # import ipdb; ipdb.set_trace() 171 | image_with_box = plot_boxes_to_image(image_pil, pred_dict)[0] 172 | image_with_box.save(os.path.join(output_dir, "pred.jpg")) 173 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/GroundingDINO/groundingdino/__init__.py -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py: -------------------------------------------------------------------------------- 1 | batch_size = 1 2 | modelname = "groundingdino" 3 | backbone = "swin_T_224_1k" 4 | position_embedding = "sine" 5 | pe_temperatureH = 20 6 | pe_temperatureW = 20 7 | return_interm_indices = [1, 2, 3] 8 | backbone_freeze_keywords = None 9 | enc_layers = 6 10 | dec_layers = 6 11 | pre_norm = False 12 | dim_feedforward = 2048 13 | hidden_dim = 256 14 | dropout = 0.0 15 | nheads = 8 16 | num_queries = 900 17 | query_dim = 4 18 | num_patterns = 0 19 | num_feature_levels = 4 20 | enc_n_points = 4 21 | dec_n_points = 4 22 | two_stage_type = "standard" 23 | two_stage_bbox_embed_share = False 24 | two_stage_class_embed_share = False 25 | transformer_activation = "relu" 26 | dec_pred_bbox_embed_share = True 27 | dn_box_noise_scale = 1.0 28 | dn_label_noise_ratio = 0.5 29 | dn_label_coef = 1.0 30 | dn_bbox_coef = 1.0 31 | embed_init_tgt = True 32 | dn_labelbook_size = 2000 33 | max_text_len = 256 34 | text_encoder_type = "bert-base-uncased" 35 | use_text_enhancer = True 36 | use_fusion_layer = True 37 | use_checkpoint = True 38 | use_transformer_ckpt = True 39 | use_text_cross_attention = True 40 | text_dropout = 0.0 41 | fusion_dropout = 0.0 42 | fusion_droppath = 0.1 43 | sub_sentence_present = True 44 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Transforms and data augmentation for both image + bbox. 4 | """ 5 | import os 6 | import random 7 | 8 | import PIL 9 | import torch 10 | import torchvision.transforms as T 11 | import torchvision.transforms.functional as F 12 | 13 | from groundingdino.util.box_ops import box_xyxy_to_cxcywh 14 | from groundingdino.util.misc import interpolate 15 | 16 | 17 | def crop(image, target, region): 18 | cropped_image = F.crop(image, *region) 19 | 20 | target = target.copy() 21 | i, j, h, w = region 22 | 23 | # should we do something wrt the original size? 24 | target["size"] = torch.tensor([h, w]) 25 | 26 | fields = ["labels", "area", "iscrowd", "positive_map"] 27 | 28 | if "boxes" in target: 29 | boxes = target["boxes"] 30 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 31 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 32 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 33 | cropped_boxes = cropped_boxes.clamp(min=0) 34 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 35 | target["boxes"] = cropped_boxes.reshape(-1, 4) 36 | target["area"] = area 37 | fields.append("boxes") 38 | 39 | if "masks" in target: 40 | # FIXME should we update the area here if there are no boxes? 
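        # Crop the masks with the same (top, left, height, width) region used for the image.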
41 | target["masks"] = target["masks"][:, i : i + h, j : j + w] 42 | fields.append("masks") 43 | 44 | # remove elements for which the boxes or masks that have zero area 45 | if "boxes" in target or "masks" in target: 46 | # favor boxes selection when defining which elements to keep 47 | # this is compatible with previous implementation 48 | if "boxes" in target: 49 | cropped_boxes = target["boxes"].reshape(-1, 2, 2) 50 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 51 | else: 52 | keep = target["masks"].flatten(1).any(1) 53 | 54 | for field in fields: 55 | if field in target: 56 | target[field] = target[field][keep] 57 | 58 | if os.environ.get("IPDB_SHILONG_DEBUG", None) == "INFO": 59 | # for debug and visualization only. 60 | if "strings_positive" in target: 61 | target["strings_positive"] = [ 62 | _i for _i, _j in zip(target["strings_positive"], keep) if _j 63 | ] 64 | 65 | return cropped_image, target 66 | 67 | 68 | def hflip(image, target): 69 | flipped_image = F.hflip(image) 70 | 71 | w, h = image.size 72 | 73 | target = target.copy() 74 | if "boxes" in target: 75 | boxes = target["boxes"] 76 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor( 77 | [w, 0, w, 0] 78 | ) 79 | target["boxes"] = boxes 80 | 81 | if "masks" in target: 82 | target["masks"] = target["masks"].flip(-1) 83 | 84 | return flipped_image, target 85 | 86 | 87 | def resize(image, target, size, max_size=None): 88 | # size can be min_size (scalar) or (w, h) tuple 89 | 90 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 91 | w, h = image_size 92 | if max_size is not None: 93 | min_original_size = float(min((w, h))) 94 | max_original_size = float(max((w, h))) 95 | if max_original_size / min_original_size * size > max_size: 96 | size = int(round(max_size * min_original_size / max_original_size)) 97 | 98 | if (w <= h and w == size) or (h <= w and h == size): 99 | return (h, w) 100 | 101 | if w < h: 102 | ow = size 103 | oh = int(size * h / w) 104 | else: 105 | oh = size 106 | ow = int(size * w / h) 107 | 108 | return (oh, ow) 109 | 110 | def get_size(image_size, size, max_size=None): 111 | if isinstance(size, (list, tuple)): 112 | return size[::-1] 113 | else: 114 | return get_size_with_aspect_ratio(image_size, size, max_size) 115 | 116 | size = get_size(image.size, size, max_size) 117 | rescaled_image = F.resize(image, size) 118 | 119 | if target is None: 120 | return rescaled_image, None 121 | 122 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 123 | ratio_width, ratio_height = ratios 124 | 125 | target = target.copy() 126 | if "boxes" in target: 127 | boxes = target["boxes"] 128 | scaled_boxes = boxes * torch.as_tensor( 129 | [ratio_width, ratio_height, ratio_width, ratio_height] 130 | ) 131 | target["boxes"] = scaled_boxes 132 | 133 | if "area" in target: 134 | area = target["area"] 135 | scaled_area = area * (ratio_width * ratio_height) 136 | target["area"] = scaled_area 137 | 138 | h, w = size 139 | target["size"] = torch.tensor([h, w]) 140 | 141 | if "masks" in target: 142 | target["masks"] = ( 143 | interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5 144 | ) 145 | 146 | return rescaled_image, target 147 | 148 | 149 | def pad(image, target, padding): 150 | # assumes that we only pad on the bottom right corners 151 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 152 | if target is None: 153 | return padded_image, None 154 | target = target.copy() 155 
| # should we do something wrt the original size? 156 | target["size"] = torch.tensor(padded_image.size[::-1]) 157 | if "masks" in target: 158 | target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1])) 159 | return padded_image, target 160 | 161 | 162 | class ResizeDebug(object): 163 | def __init__(self, size): 164 | self.size = size 165 | 166 | def __call__(self, img, target): 167 | return resize(img, target, self.size) 168 | 169 | 170 | class RandomCrop(object): 171 | def __init__(self, size): 172 | self.size = size 173 | 174 | def __call__(self, img, target): 175 | region = T.RandomCrop.get_params(img, self.size) 176 | return crop(img, target, region) 177 | 178 | 179 | class RandomSizeCrop(object): 180 | def __init__(self, min_size: int, max_size: int, respect_boxes: bool = False): 181 | # respect_boxes: True to keep all boxes 182 | # False to tolerence box filter 183 | self.min_size = min_size 184 | self.max_size = max_size 185 | self.respect_boxes = respect_boxes 186 | 187 | def __call__(self, img: PIL.Image.Image, target: dict): 188 | init_boxes = len(target["boxes"]) 189 | max_patience = 10 190 | for i in range(max_patience): 191 | w = random.randint(self.min_size, min(img.width, self.max_size)) 192 | h = random.randint(self.min_size, min(img.height, self.max_size)) 193 | region = T.RandomCrop.get_params(img, [h, w]) 194 | result_img, result_target = crop(img, target, region) 195 | if ( 196 | not self.respect_boxes 197 | or len(result_target["boxes"]) == init_boxes 198 | or i == max_patience - 1 199 | ): 200 | return result_img, result_target 201 | return result_img, result_target 202 | 203 | 204 | class CenterCrop(object): 205 | def __init__(self, size): 206 | self.size = size 207 | 208 | def __call__(self, img, target): 209 | image_width, image_height = img.size 210 | crop_height, crop_width = self.size 211 | crop_top = int(round((image_height - crop_height) / 2.0)) 212 | crop_left = int(round((image_width - crop_width) / 2.0)) 213 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 214 | 215 | 216 | class RandomHorizontalFlip(object): 217 | def __init__(self, p=0.5): 218 | self.p = p 219 | 220 | def __call__(self, img, target): 221 | if random.random() < self.p: 222 | return hflip(img, target) 223 | return img, target 224 | 225 | 226 | class RandomResize(object): 227 | def __init__(self, sizes, max_size=None): 228 | assert isinstance(sizes, (list, tuple)) 229 | self.sizes = sizes 230 | self.max_size = max_size 231 | 232 | def __call__(self, img, target=None): 233 | size = random.choice(self.sizes) 234 | return resize(img, target, size, self.max_size) 235 | 236 | 237 | class RandomPad(object): 238 | def __init__(self, max_pad): 239 | self.max_pad = max_pad 240 | 241 | def __call__(self, img, target): 242 | pad_x = random.randint(0, self.max_pad) 243 | pad_y = random.randint(0, self.max_pad) 244 | return pad(img, target, (pad_x, pad_y)) 245 | 246 | 247 | class RandomSelect(object): 248 | """ 249 | Randomly selects between transforms1 and transforms2, 250 | with probability p for transforms1 and (1 - p) for transforms2 251 | """ 252 | 253 | def __init__(self, transforms1, transforms2, p=0.5): 254 | self.transforms1 = transforms1 255 | self.transforms2 = transforms2 256 | self.p = p 257 | 258 | def __call__(self, img, target): 259 | if random.random() < self.p: 260 | return self.transforms1(img, target) 261 | return self.transforms2(img, target) 262 | 263 | 264 | class ToTensor(object): 265 | def __call__(self, img, target): 
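        # Convert the PIL image to a tensor; the target dict passes through unchanged.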
266 | return F.to_tensor(img), target 267 | 268 | 269 | class RandomErasing(object): 270 | def __init__(self, *args, **kwargs): 271 | self.eraser = T.RandomErasing(*args, **kwargs) 272 | 273 | def __call__(self, img, target): 274 | return self.eraser(img), target 275 | 276 | 277 | class Normalize(object): 278 | def __init__(self, mean, std): 279 | self.mean = mean 280 | self.std = std 281 | 282 | def __call__(self, image, target=None): 283 | image = F.normalize(image, mean=self.mean, std=self.std) 284 | if target is None: 285 | return image, None 286 | target = target.copy() 287 | h, w = image.shape[-2:] 288 | if "boxes" in target: 289 | boxes = target["boxes"] 290 | boxes = box_xyxy_to_cxcywh(boxes) 291 | boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) 292 | target["boxes"] = boxes 293 | return image, target 294 | 295 | 296 | class Compose(object): 297 | def __init__(self, transforms): 298 | self.transforms = transforms 299 | 300 | def __call__(self, image, target): 301 | for t in self.transforms: 302 | image, target = t(image, target) 303 | return image, target 304 | 305 | def __repr__(self): 306 | format_string = self.__class__.__name__ + "(" 307 | for t in self.transforms: 308 | format_string += "\n" 309 | format_string += " {0}".format(t) 310 | format_string += "\n)" 311 | return format_string 312 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/models/GroundingDINO/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Grounding DINO 3 | # url: https://github.com/IDEA-Research/GroundingDINO 4 | # Copyright (c) 2023 IDEA. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------ 7 | # Conditional DETR 8 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 9 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 10 | # ------------------------------------------------------------------------ 11 | # Copied from DETR (https://github.com/facebookresearch/detr) 12 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 13 | # ------------------------------------------------------------------------ 14 | 15 | from .groundingdino import build_groundingdino 16 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/models/GroundingDINO/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbone import build_backbone 2 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/models/GroundingDINO/backbone/backbone.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Grounding DINO 3 | # url: https://github.com/IDEA-Research/GroundingDINO 4 | # Copyright (c) 2023 IDEA. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------ 7 | # Conditional DETR 8 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 
9 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 10 | # ------------------------------------------------------------------------ 11 | # Copied from DETR (https://github.com/facebookresearch/detr) 12 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 13 | # ------------------------------------------------------------------------ 14 | 15 | """ 16 | Backbone modules. 17 | """ 18 | 19 | from typing import Dict, List 20 | 21 | import torch 22 | import torch.nn.functional as F 23 | import torchvision 24 | from torch import nn 25 | from torchvision.models._utils import IntermediateLayerGetter 26 | 27 | from groundingdino.util.misc import NestedTensor, clean_state_dict, is_main_process 28 | 29 | from .position_encoding import build_position_encoding 30 | from .swin_transformer import build_swin_transformer 31 | 32 | 33 | class FrozenBatchNorm2d(torch.nn.Module): 34 | """ 35 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 36 | 37 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 38 | without which any other models than torchvision.models.resnet[18,34,50,101] 39 | produce nans. 40 | """ 41 | 42 | def __init__(self, n): 43 | super(FrozenBatchNorm2d, self).__init__() 44 | self.register_buffer("weight", torch.ones(n)) 45 | self.register_buffer("bias", torch.zeros(n)) 46 | self.register_buffer("running_mean", torch.zeros(n)) 47 | self.register_buffer("running_var", torch.ones(n)) 48 | 49 | def _load_from_state_dict( 50 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 51 | ): 52 | num_batches_tracked_key = prefix + "num_batches_tracked" 53 | if num_batches_tracked_key in state_dict: 54 | del state_dict[num_batches_tracked_key] 55 | 56 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 57 | state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 58 | ) 59 | 60 | def forward(self, x): 61 | # move reshapes to the beginning 62 | # to make it fuser-friendly 63 | w = self.weight.reshape(1, -1, 1, 1) 64 | b = self.bias.reshape(1, -1, 1, 1) 65 | rv = self.running_var.reshape(1, -1, 1, 1) 66 | rm = self.running_mean.reshape(1, -1, 1, 1) 67 | eps = 1e-5 68 | scale = w * (rv + eps).rsqrt() 69 | bias = b - rm * scale 70 | return x * scale + bias 71 | 72 | 73 | class BackboneBase(nn.Module): 74 | def __init__( 75 | self, 76 | backbone: nn.Module, 77 | train_backbone: bool, 78 | num_channels: int, 79 | return_interm_indices: list, 80 | ): 81 | super().__init__() 82 | for name, parameter in backbone.named_parameters(): 83 | if ( 84 | not train_backbone 85 | or "layer2" not in name 86 | and "layer3" not in name 87 | and "layer4" not in name 88 | ): 89 | parameter.requires_grad_(False) 90 | 91 | return_layers = {} 92 | for idx, layer_index in enumerate(return_interm_indices): 93 | return_layers.update( 94 | {"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)} 95 | ) 96 | 97 | # if len: 98 | # if use_stage1_feature: 99 | # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 100 | # else: 101 | # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} 102 | # else: 103 | # return_layers = {'layer4': "0"} 104 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 105 | self.num_channels = num_channels 106 | 107 | def forward(self, tensor_list: NestedTensor): 108 | xs = self.body(tensor_list.tensors) 109 | out: Dict[str, NestedTensor] = {} 110 | for name, x in 
xs.items(): 111 | m = tensor_list.mask 112 | assert m is not None 113 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 114 | out[name] = NestedTensor(x, mask) 115 | # import ipdb; ipdb.set_trace() 116 | return out 117 | 118 | 119 | class Backbone(BackboneBase): 120 | """ResNet backbone with frozen BatchNorm.""" 121 | 122 | def __init__( 123 | self, 124 | name: str, 125 | train_backbone: bool, 126 | dilation: bool, 127 | return_interm_indices: list, 128 | batch_norm=FrozenBatchNorm2d, 129 | ): 130 | if name in ["resnet18", "resnet34", "resnet50", "resnet101"]: 131 | backbone = getattr(torchvision.models, name)( 132 | replace_stride_with_dilation=[False, False, dilation], 133 | pretrained=is_main_process(), 134 | norm_layer=batch_norm, 135 | ) 136 | else: 137 | raise NotImplementedError("Why you can get here with name {}".format(name)) 138 | # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 139 | assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available." 140 | assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]] 141 | num_channels_all = [256, 512, 1024, 2048] 142 | num_channels = num_channels_all[4 - len(return_interm_indices) :] 143 | super().__init__(backbone, train_backbone, num_channels, return_interm_indices) 144 | 145 | 146 | class Joiner(nn.Sequential): 147 | def __init__(self, backbone, position_embedding): 148 | super().__init__(backbone, position_embedding) 149 | 150 | def forward(self, tensor_list: NestedTensor): 151 | xs = self[0](tensor_list) 152 | out: List[NestedTensor] = [] 153 | pos = [] 154 | for name, x in xs.items(): 155 | out.append(x) 156 | # position encoding 157 | pos.append(self[1](x).to(x.tensors.dtype)) 158 | 159 | return out, pos 160 | 161 | 162 | def build_backbone(args): 163 | """ 164 | Useful args: 165 | - backbone: backbone name 166 | - lr_backbone: 167 | - dilation 168 | - return_interm_indices: available: [0,1,2,3], [1,2,3], [3] 169 | - backbone_freeze_keywords: 170 | - use_checkpoint: for swin only for now 171 | 172 | """ 173 | position_embedding = build_position_encoding(args) 174 | train_backbone = True 175 | if not train_backbone: 176 | raise ValueError("Please set lr_backbone > 0") 177 | return_interm_indices = args.return_interm_indices 178 | assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]] 179 | args.backbone_freeze_keywords 180 | use_checkpoint = getattr(args, "use_checkpoint", False) 181 | 182 | if args.backbone in ["resnet50", "resnet101"]: 183 | backbone = Backbone( 184 | args.backbone, 185 | train_backbone, 186 | args.dilation, 187 | return_interm_indices, 188 | batch_norm=FrozenBatchNorm2d, 189 | ) 190 | bb_num_channels = backbone.num_channels 191 | elif args.backbone in [ 192 | "swin_T_224_1k", 193 | "swin_B_224_22k", 194 | "swin_B_384_22k", 195 | "swin_L_224_22k", 196 | "swin_L_384_22k", 197 | ]: 198 | pretrain_img_size = int(args.backbone.split("_")[-2]) 199 | backbone = build_swin_transformer( 200 | args.backbone, 201 | pretrain_img_size=pretrain_img_size, 202 | out_indices=tuple(return_interm_indices), 203 | dilation=False, 204 | use_checkpoint=use_checkpoint, 205 | ) 206 | 207 | bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :] 208 | else: 209 | raise NotImplementedError("Unknown backbone {}".format(args.backbone)) 210 | 211 | assert len(bb_num_channels) == len( 212 | return_interm_indices 213 | ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}" 214 | 215 | 
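    # Joiner chains the backbone with the positional encoding; its forward returns (feature maps, position embeddings).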
model = Joiner(backbone, position_embedding) 216 | model.num_channels = bb_num_channels 217 | assert isinstance( 218 | bb_num_channels, List 219 | ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels)) 220 | # import ipdb; ipdb.set_trace() 221 | return model 222 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/models/GroundingDINO/backbone/position_encoding.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Grounding DINO 3 | # url: https://github.com/IDEA-Research/GroundingDINO 4 | # Copyright (c) 2023 IDEA. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------ 7 | # DINO 8 | # Copyright (c) 2022 IDEA. All Rights Reserved. 9 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 10 | # ------------------------------------------------------------------------ 11 | # Conditional DETR 12 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 13 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 14 | # ------------------------------------------------------------------------ 15 | # Copied from DETR (https://github.com/facebookresearch/detr) 16 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 17 | # ------------------------------------------------------------------------ 18 | 19 | """ 20 | Various positional encodings for the transformer. 21 | """ 22 | import math 23 | 24 | import torch 25 | from torch import nn 26 | 27 | from groundingdino.util.misc import NestedTensor 28 | 29 | 30 | class PositionEmbeddingSine(nn.Module): 31 | """ 32 | This is a more standard version of the position embedding, very similar to the one 33 | used by the Attention is all you need paper, generalized to work on images. 
34 | """ 35 | 36 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 37 | super().__init__() 38 | self.num_pos_feats = num_pos_feats 39 | self.temperature = temperature 40 | self.normalize = normalize 41 | if scale is not None and normalize is False: 42 | raise ValueError("normalize should be True if scale is passed") 43 | if scale is None: 44 | scale = 2 * math.pi 45 | self.scale = scale 46 | 47 | def forward(self, tensor_list: NestedTensor): 48 | x = tensor_list.tensors 49 | mask = tensor_list.mask 50 | assert mask is not None 51 | not_mask = ~mask 52 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 53 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 54 | if self.normalize: 55 | eps = 1e-6 56 | # if os.environ.get("SHILONG_AMP", None) == '1': 57 | # eps = 1e-4 58 | # else: 59 | # eps = 1e-6 60 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 61 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 62 | 63 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 64 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 65 | 66 | pos_x = x_embed[:, :, :, None] / dim_t 67 | pos_y = y_embed[:, :, :, None] / dim_t 68 | pos_x = torch.stack( 69 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 70 | ).flatten(3) 71 | pos_y = torch.stack( 72 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 73 | ).flatten(3) 74 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 75 | return pos 76 | 77 | 78 | class PositionEmbeddingSineHW(nn.Module): 79 | """ 80 | This is a more standard version of the position embedding, very similar to the one 81 | used by the Attention is all you need paper, generalized to work on images. 82 | """ 83 | 84 | def __init__( 85 | self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None 86 | ): 87 | super().__init__() 88 | self.num_pos_feats = num_pos_feats 89 | self.temperatureH = temperatureH 90 | self.temperatureW = temperatureW 91 | self.normalize = normalize 92 | if scale is not None and normalize is False: 93 | raise ValueError("normalize should be True if scale is passed") 94 | if scale is None: 95 | scale = 2 * math.pi 96 | self.scale = scale 97 | 98 | def forward(self, tensor_list: NestedTensor): 99 | x = tensor_list.tensors 100 | mask = tensor_list.mask 101 | assert mask is not None 102 | not_mask = ~mask 103 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 104 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 105 | 106 | # import ipdb; ipdb.set_trace() 107 | 108 | if self.normalize: 109 | eps = 1e-6 110 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 111 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 112 | 113 | dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 114 | dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats) 115 | pos_x = x_embed[:, :, :, None] / dim_tx 116 | 117 | dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 118 | dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats) 119 | pos_y = y_embed[:, :, :, None] / dim_ty 120 | 121 | pos_x = torch.stack( 122 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 123 | ).flatten(3) 124 | pos_y = torch.stack( 125 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 126 | ).flatten(3) 127 | pos = torch.cat((pos_y, pos_x), 
dim=3).permute(0, 3, 1, 2) 128 | 129 | # import ipdb; ipdb.set_trace() 130 | 131 | return pos 132 | 133 | 134 | class PositionEmbeddingLearned(nn.Module): 135 | """ 136 | Absolute pos embedding, learned. 137 | """ 138 | 139 | def __init__(self, num_pos_feats=256): 140 | super().__init__() 141 | self.row_embed = nn.Embedding(50, num_pos_feats) 142 | self.col_embed = nn.Embedding(50, num_pos_feats) 143 | self.reset_parameters() 144 | 145 | def reset_parameters(self): 146 | nn.init.uniform_(self.row_embed.weight) 147 | nn.init.uniform_(self.col_embed.weight) 148 | 149 | def forward(self, tensor_list: NestedTensor): 150 | x = tensor_list.tensors 151 | h, w = x.shape[-2:] 152 | i = torch.arange(w, device=x.device) 153 | j = torch.arange(h, device=x.device) 154 | x_emb = self.col_embed(i) 155 | y_emb = self.row_embed(j) 156 | pos = ( 157 | torch.cat( 158 | [ 159 | x_emb.unsqueeze(0).repeat(h, 1, 1), 160 | y_emb.unsqueeze(1).repeat(1, w, 1), 161 | ], 162 | dim=-1, 163 | ) 164 | .permute(2, 0, 1) 165 | .unsqueeze(0) 166 | .repeat(x.shape[0], 1, 1, 1) 167 | ) 168 | return pos 169 | 170 | 171 | def build_position_encoding(args): 172 | N_steps = args.hidden_dim // 2 173 | if args.position_embedding in ("v2", "sine"): 174 | # TODO find a better way of exposing other arguments 175 | position_embedding = PositionEmbeddingSineHW( 176 | N_steps, 177 | temperatureH=args.pe_temperatureH, 178 | temperatureW=args.pe_temperatureW, 179 | normalize=True, 180 | ) 181 | elif args.position_embedding in ("v3", "learned"): 182 | position_embedding = PositionEmbeddingLearned(N_steps) 183 | else: 184 | raise ValueError(f"not supported {args.position_embedding}") 185 | 186 | return position_embedding 187 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | namespace groundingdino { 20 | 21 | at::Tensor 22 | ms_deform_attn_forward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const int im2col_step) 29 | { 30 | if (value.type().is_cuda()) 31 | { 32 | #ifdef WITH_CUDA 33 | return ms_deform_attn_cuda_forward( 34 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 35 | #else 36 | AT_ERROR("Not compiled with GPU support"); 37 | #endif 38 | } 39 | AT_ERROR("Not implemented on the CPU"); 40 | } 41 | 42 | std::vector 43 | ms_deform_attn_backward( 44 | const at::Tensor &value, 45 | const at::Tensor &spatial_shapes, 46 | const at::Tensor &level_start_index, 47 | const at::Tensor &sampling_loc, 48 | const at::Tensor &attn_weight, 49 | const at::Tensor &grad_output, 50 | const int im2col_step) 51 | { 52 | if (value.type().is_cuda()) 53 | { 54 | #ifdef WITH_CUDA 55 | return ms_deform_attn_cuda_backward( 56 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 57 | #else 58 | AT_ERROR("Not compiled with GPU support"); 59 | #endif 60 | } 61 | AT_ERROR("Not implemented on the CPU"); 62 | } 63 | 64 | } // namespace groundingdino -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | namespace groundingdino { 17 | 18 | at::Tensor 19 | ms_deform_attn_cpu_forward( 20 | const at::Tensor &value, 21 | const at::Tensor &spatial_shapes, 22 | const at::Tensor &level_start_index, 23 | const at::Tensor &sampling_loc, 24 | const at::Tensor &attn_weight, 25 | const int im2col_step) 26 | { 27 | AT_ERROR("Not implement on cpu"); 28 | } 29 | 30 | std::vector 31 | ms_deform_attn_cpu_backward( 32 | const at::Tensor &value, 33 | const at::Tensor &spatial_shapes, 34 | const at::Tensor &level_start_index, 35 | const at::Tensor &sampling_loc, 36 | const at::Tensor &attn_weight, 37 | const at::Tensor &grad_output, 38 | const int im2col_step) 39 | { 40 | AT_ERROR("Not implement on cpu"); 41 | } 42 | 43 | } // namespace groundingdino 44 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | namespace groundingdino { 15 | 16 | at::Tensor 17 | ms_deform_attn_cpu_forward( 18 | const at::Tensor &value, 19 | const at::Tensor &spatial_shapes, 20 | const at::Tensor &level_start_index, 21 | const at::Tensor &sampling_loc, 22 | const at::Tensor &attn_weight, 23 | const int im2col_step); 24 | 25 | std::vector 26 | ms_deform_attn_cpu_backward( 27 | const at::Tensor &value, 28 | const at::Tensor &spatial_shapes, 29 | const at::Tensor &level_start_index, 30 | const at::Tensor &sampling_loc, 31 | const at::Tensor &attn_weight, 32 | const at::Tensor &grad_output, 33 | const int im2col_step); 34 | 35 | } // namespace groundingdino 36 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include <vector> 12 | #include "ms_deform_im2col_cuda.cuh" 13 | 14 | #include <ATen/ATen.h> 15 | #include <ATen/cuda/CUDAContext.h> 16 | #include <cuda.h> 17 | #include <cuda_runtime.h> 18 | 19 | namespace groundingdino { 20 | 21 | at::Tensor ms_deform_attn_cuda_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 30 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 31 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 32 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 33 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 34 | 35 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 36 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 37 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 38 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 39 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 40 | 41 | const int batch = value.size(0); 42 | const int spatial_size = value.size(1); 43 | const int num_heads = value.size(2); 44 | const int channels = value.size(3); 45 | 46 | const int num_levels = spatial_shapes.size(0); 47 | 48 | const int num_query = sampling_loc.size(1); 49 | const int num_point = sampling_loc.size(4); 50 | 51 | const int im2col_step_ = std::min(batch, im2col_step); 52 | 53 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 54 | 55 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 56 | 57 | const int batch_n = im2col_step_; 58 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 59 | auto per_value_size = spatial_size * num_heads * channels; 60 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 61 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 62 | for (int n = 0; n < batch/im2col_step_; ++n) 63 | { 64 | auto columns = output_n.select(0, n); 65 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 66 | ms_deformable_im2col_cuda<scalar_t>(at::cuda::getCurrentCUDAStream(), 67 | value.data<scalar_t>() + n * im2col_step_ * per_value_size, 68 | spatial_shapes.data<int64_t>(), 69 | level_start_index.data<int64_t>(), 70 | sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size, 71 | attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size, 72 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 73 | columns.data<scalar_t>()); 74 | 75 | })); 76 | } 77 | 78 | output = output.view({batch, num_query, num_heads*channels}); 79 | 80 | return output; 81 | } 82 | 83 | 84 | std::vector<at::Tensor> ms_deform_attn_cuda_backward( 85 | const at::Tensor &value, 86 | const at::Tensor &spatial_shapes, 87 | const at::Tensor
&level_start_index, 88 | const at::Tensor &sampling_loc, 89 | const at::Tensor &attn_weight, 90 | const at::Tensor &grad_output, 91 | const int im2col_step) 92 | { 93 | 94 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 95 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 96 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 97 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 98 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 99 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 100 | 101 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 102 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 103 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 104 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 105 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 106 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 107 | 108 | const int batch = value.size(0); 109 | const int spatial_size = value.size(1); 110 | const int num_heads = value.size(2); 111 | const int channels = value.size(3); 112 | 113 | const int num_levels = spatial_shapes.size(0); 114 | 115 | const int num_query = sampling_loc.size(1); 116 | const int num_point = sampling_loc.size(4); 117 | 118 | const int im2col_step_ = std::min(batch, im2col_step); 119 | 120 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 121 | 122 | auto grad_value = at::zeros_like(value); 123 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 124 | auto grad_attn_weight = at::zeros_like(attn_weight); 125 | 126 | const int batch_n = im2col_step_; 127 | auto per_value_size = spatial_size * num_heads * channels; 128 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 129 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 130 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 131 | 132 | for (int n = 0; n < batch/im2col_step_; ++n) 133 | { 134 | auto grad_output_g = grad_output_n.select(0, n); 135 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 136 | ms_deformable_col2im_cuda<scalar_t>(at::cuda::getCurrentCUDAStream(), 137 | grad_output_g.data<scalar_t>(), 138 | value.data<scalar_t>() + n * im2col_step_ * per_value_size, 139 | spatial_shapes.data<int64_t>(), 140 | level_start_index.data<int64_t>(), 141 | sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size, 142 | attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size, 143 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 144 | grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size, 145 | grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size, 146 | grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size); 147 | 148 | })); 149 | } 150 | 151 | return { 152 | grad_value, grad_sampling_loc, grad_attn_weight 153 | }; 154 | } 155 | 156 | } // namespace groundingdino -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.h:
-------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include <torch/extension.h> 13 | 14 | namespace groundingdino { 15 | 16 | at::Tensor ms_deform_attn_cuda_forward( 17 | const at::Tensor &value, 18 | const at::Tensor &spatial_shapes, 19 | const at::Tensor &level_start_index, 20 | const at::Tensor &sampling_loc, 21 | const at::Tensor &attn_weight, 22 | const int im2col_step); 23 | 24 | std::vector<at::Tensor> ms_deform_attn_cuda_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | } // namespace groundingdino -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/models/GroundingDINO/csrc/cuda_version.cu: -------------------------------------------------------------------------------- 1 | #include <cuda_runtime_api.h> 2 | 3 | namespace groundingdino { 4 | int get_cudart_version() { 5 | return CUDART_VERSION; 6 | } 7 | } // namespace groundingdino 8 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/models/GroundingDINO/csrc/vision.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | #include "MsDeformAttn/ms_deform_attn.h" 4 | 5 | namespace groundingdino { 6 | 7 | #ifdef WITH_CUDA 8 | extern int get_cudart_version(); 9 | #endif 10 | 11 | std::string get_cuda_version() { 12 | #ifdef WITH_CUDA 13 | std::ostringstream oss; 14 | 15 | // copied from 16 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231 17 | auto printCudaStyleVersion = [&](int v) { 18 | oss << (v / 1000) << "." << (v / 10 % 100); 19 | if (v % 10 != 0) { 20 | oss << "." << (v % 10); 21 | } 22 | }; 23 | printCudaStyleVersion(get_cudart_version()); 24 | return oss.str(); 25 | #else 26 | return std::string("not available"); 27 | #endif 28 | } 29 | 30 | // similar to 31 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp 32 | std::string get_compiler_version() { 33 | std::ostringstream ss; 34 | #if defined(__GNUC__) 35 | #ifndef __clang__ 36 | { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; } 37 | #endif 38 | #endif 39 | 40 | #if defined(__clang_major__) 41 | { 42 | ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
43 | << __clang_patchlevel__; 44 | } 45 | #endif 46 | 47 | #if defined(_MSC_VER) 48 | { ss << "MSVC " << _MSC_FULL_VER; } 49 | #endif 50 | return ss.str(); 51 | } 52 | 53 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 54 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 55 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 56 | } 57 | 58 | } // namespace groundingdino -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/models/GroundingDINO/transformer_vanilla.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Grounding DINO 3 | # url: https://github.com/IDEA-Research/GroundingDINO 4 | # Copyright (c) 2023 IDEA. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------ 7 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved 8 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 9 | """ 10 | DETR Transformer class. 11 | 12 | Copy-paste from torch.nn.Transformer with modifications: 13 | * positional encodings are passed in MHattention 14 | * extra LN at the end of encoder is removed 15 | * decoder returns a stack of activations from all decoding layers 16 | """ 17 | from typing import Optional 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | from torch import Tensor, nn 22 | 23 | from .utils import ( 24 | MLP, 25 | _get_activation_fn, 26 | _get_clones, 27 | gen_encoder_output_proposals, 28 | gen_sineembed_for_position, 29 | sigmoid_focal_loss, 30 | ) 31 | 32 | 33 | class TextTransformer(nn.Module): 34 | def __init__(self, num_layers, d_model=256, nheads=8, dim_feedforward=2048, dropout=0.1): 35 | super().__init__() 36 | self.num_layers = num_layers 37 | self.d_model = d_model 38 | self.nheads = nheads 39 | self.dim_feedforward = dim_feedforward 40 | self.norm = None 41 | 42 | single_encoder_layer = TransformerEncoderLayer( 43 | d_model=d_model, nhead=nheads, dim_feedforward=dim_feedforward, dropout=dropout 44 | ) 45 | self.layers = _get_clones(single_encoder_layer, num_layers) 46 | 47 | def forward(self, memory_text: torch.Tensor, text_attention_mask: torch.Tensor): 48 | """ 49 | 50 | Args: 51 | text_attention_mask: bs, num_token 52 | memory_text: bs, num_token, d_model 53 | 54 | Raises: 55 | RuntimeError: _description_ 56 | 57 | Returns: 58 | output: bs, num_token, d_model 59 | """ 60 | 61 | output = memory_text.transpose(0, 1) 62 | 63 | for layer in self.layers: 64 | output = layer(output, src_key_padding_mask=text_attention_mask) 65 | 66 | if self.norm is not None: 67 | output = self.norm(output) 68 | 69 | return output.transpose(0, 1) 70 | 71 | 72 | class TransformerEncoderLayer(nn.Module): 73 | def __init__( 74 | self, 75 | d_model, 76 | nhead, 77 | dim_feedforward=2048, 78 | dropout=0.1, 79 | activation="relu", 80 | normalize_before=False, 81 | ): 82 | super().__init__() 83 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 84 | # Implementation of Feedforward model 85 | self.linear1 = nn.Linear(d_model, dim_feedforward) 86 | self.dropout = nn.Dropout(dropout) 87 | self.linear2 = nn.Linear(dim_feedforward, d_model) 88 | 89 | self.norm1 = nn.LayerNorm(d_model) 90 | self.norm2 = nn.LayerNorm(d_model) 91 | 
self.dropout1 = nn.Dropout(dropout) 92 | self.dropout2 = nn.Dropout(dropout) 93 | 94 | self.activation = _get_activation_fn(activation) 95 | self.normalize_before = normalize_before 96 | self.nhead = nhead 97 | 98 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 99 | return tensor if pos is None else tensor + pos 100 | 101 | def forward( 102 | self, 103 | src, 104 | src_mask: Optional[Tensor] = None, 105 | src_key_padding_mask: Optional[Tensor] = None, 106 | pos: Optional[Tensor] = None, 107 | ): 108 | # repeat attn mask 109 | if src_mask.dim() == 3 and src_mask.shape[0] == src.shape[1]: 110 | # bs, num_q, num_k 111 | src_mask = src_mask.repeat(self.nhead, 1, 1) 112 | 113 | q = k = self.with_pos_embed(src, pos) 114 | 115 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)[0] 116 | 117 | # src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] 118 | src = src + self.dropout1(src2) 119 | src = self.norm1(src) 120 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 121 | src = src + self.dropout2(src2) 122 | src = self.norm2(src) 123 | return src 124 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/models/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Grounding DINO 3 | # url: https://github.com/IDEA-Research/GroundingDINO 4 | # Copyright (c) 2023 IDEA. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------ 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | from .GroundingDINO import build_groundingdino 9 | 10 | 11 | def build_model(args): 12 | # we use register to maintain models from catdet6 on. 13 | from .registry import MODULE_BUILD_FUNCS 14 | 15 | assert args.modelname in MODULE_BUILD_FUNCS._module_dict 16 | build_func = MODULE_BUILD_FUNCS.get(args.modelname) 17 | model = build_func(args) 18 | return model 19 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/models/registry.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Grounding DINO 3 | # url: https://github.com/IDEA-Research/GroundingDINO 4 | # Copyright (c) 2023 IDEA. All Rights Reserved. 
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------ 7 | # -*- coding: utf-8 -*- 8 | # @Author: Yihao Chen 9 | # @Date: 2021-08-16 16:03:17 10 | # @Last Modified by: Shilong Liu 11 | # @Last Modified time: 2022-01-23 15:26 12 | # modified from mmcv 13 | 14 | import inspect 15 | from functools import partial 16 | 17 | 18 | class Registry(object): 19 | def __init__(self, name): 20 | self._name = name 21 | self._module_dict = dict() 22 | 23 | def __repr__(self): 24 | format_str = self.__class__.__name__ + "(name={}, items={})".format( 25 | self._name, list(self._module_dict.keys()) 26 | ) 27 | return format_str 28 | 29 | def __len__(self): 30 | return len(self._module_dict) 31 | 32 | @property 33 | def name(self): 34 | return self._name 35 | 36 | @property 37 | def module_dict(self): 38 | return self._module_dict 39 | 40 | def get(self, key): 41 | return self._module_dict.get(key, None) 42 | 43 | def registe_with_name(self, module_name=None, force=False): 44 | return partial(self.register, module_name=module_name, force=force) 45 | 46 | def register(self, module_build_function, module_name=None, force=False): 47 | """Register a module build function. 48 | Args: 49 | module (:obj:`nn.Module`): Module to be registered. 50 | """ 51 | if not inspect.isfunction(module_build_function): 52 | raise TypeError( 53 | "module_build_function must be a function, but got {}".format( 54 | type(module_build_function) 55 | ) 56 | ) 57 | if module_name is None: 58 | module_name = module_build_function.__name__ 59 | if not force and module_name in self._module_dict: 60 | raise KeyError("{} is already registered in {}".format(module_name, self.name)) 61 | self._module_dict[module_name] = module_build_function 62 | 63 | return module_build_function 64 | 65 | 66 | MODULE_BUILD_FUNCS = Registry("model build functions") 67 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/util/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 
4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] 12 | return torch.stack(b, dim=-1) 13 | 14 | 15 | def box_xyxy_to_cxcywh(x): 16 | x0, y0, x1, y1 = x.unbind(-1) 17 | b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] 18 | return torch.stack(b, dim=-1) 19 | 20 | 21 | # modified from torchvision to also return the union 22 | def box_iou(boxes1, boxes2): 23 | area1 = box_area(boxes1) 24 | area2 = box_area(boxes2) 25 | 26 | # import ipdb; ipdb.set_trace() 27 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 28 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 29 | 30 | wh = (rb - lt).clamp(min=0) # [N,M,2] 31 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 32 | 33 | union = area1[:, None] + area2 - inter 34 | 35 | iou = inter / (union + 1e-6) 36 | return iou, union 37 | 38 | 39 | def generalized_box_iou(boxes1, boxes2): 40 | """ 41 | Generalized IoU from https://giou.stanford.edu/ 42 | 43 | The boxes should be in [x0, y0, x1, y1] format 44 | 45 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 46 | and M = len(boxes2) 47 | """ 48 | # degenerate boxes gives inf / nan results 49 | # so do an early check 50 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 51 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 52 | # except: 53 | # import ipdb; ipdb.set_trace() 54 | iou, union = box_iou(boxes1, boxes2) 55 | 56 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 57 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 58 | 59 | wh = (rb - lt).clamp(min=0) # [N,M,2] 60 | area = wh[:, :, 0] * wh[:, :, 1] 61 | 62 | return iou - (area - union) / (area + 1e-6) 63 | 64 | 65 | # modified from torchvision to also return the union 66 | def box_iou_pairwise(boxes1, boxes2): 67 | area1 = box_area(boxes1) 68 | area2 = box_area(boxes2) 69 | 70 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2] 71 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2] 72 | 73 | wh = (rb - lt).clamp(min=0) # [N,2] 74 | inter = wh[:, 0] * wh[:, 1] # [N] 75 | 76 | union = area1 + area2 - inter 77 | 78 | iou = inter / union 79 | return iou, union 80 | 81 | 82 | def generalized_box_iou_pairwise(boxes1, boxes2): 83 | """ 84 | Generalized IoU from https://giou.stanford.edu/ 85 | 86 | Input: 87 | - boxes1, boxes2: N,4 88 | Output: 89 | - giou: N, 4 90 | """ 91 | # degenerate boxes gives inf / nan results 92 | # so do an early check 93 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 94 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 95 | assert boxes1.shape == boxes2.shape 96 | iou, union = box_iou_pairwise(boxes1, boxes2) # N, 4 97 | 98 | lt = torch.min(boxes1[:, :2], boxes2[:, :2]) 99 | rb = torch.max(boxes1[:, 2:], boxes2[:, 2:]) 100 | 101 | wh = (rb - lt).clamp(min=0) # [N,2] 102 | area = wh[:, 0] * wh[:, 1] 103 | 104 | return iou - (area - union) / area 105 | 106 | 107 | def masks_to_boxes(masks): 108 | """Compute the bounding boxes around the provided masks 109 | 110 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
111 | 112 | Returns a [N, 4] tensors, with the boxes in xyxy format 113 | """ 114 | if masks.numel() == 0: 115 | return torch.zeros((0, 4), device=masks.device) 116 | 117 | h, w = masks.shape[-2:] 118 | 119 | y = torch.arange(0, h, dtype=torch.float) 120 | x = torch.arange(0, w, dtype=torch.float) 121 | y, x = torch.meshgrid(y, x) 122 | 123 | x_mask = masks * x.unsqueeze(0) 124 | x_max = x_mask.flatten(1).max(-1)[0] 125 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 126 | 127 | y_mask = masks * y.unsqueeze(0) 128 | y_max = y_mask.flatten(1).max(-1)[0] 129 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 130 | 131 | return torch.stack([x_min, y_min, x_max, y_max], 1) 132 | 133 | 134 | if __name__ == "__main__": 135 | x = torch.rand(5, 4) 136 | y = torch.rand(3, 4) 137 | iou, union = box_iou(x, y) 138 | import ipdb 139 | 140 | ipdb.set_trace() 141 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/util/get_tokenlizer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast 2 | 3 | 4 | def get_tokenlizer(text_encoder_type): 5 | if not isinstance(text_encoder_type, str): 6 | # print("text_encoder_type is not a str") 7 | if hasattr(text_encoder_type, "text_encoder_type"): 8 | text_encoder_type = text_encoder_type.text_encoder_type 9 | elif text_encoder_type.get("text_encoder_type", False): 10 | text_encoder_type = text_encoder_type.get("text_encoder_type") 11 | else: 12 | raise ValueError( 13 | "Unknown type of text_encoder_type: {}".format(type(text_encoder_type)) 14 | ) 15 | print("final text_encoder_type: {}".format(text_encoder_type)) 16 | 17 | tokenizer = AutoTokenizer.from_pretrained(text_encoder_type) 18 | return tokenizer 19 | 20 | 21 | def get_pretrained_language_model(text_encoder_type): 22 | if text_encoder_type == "bert-base-uncased": 23 | return BertModel.from_pretrained(text_encoder_type) 24 | if text_encoder_type == "roberta-base": 25 | return RobertaModel.from_pretrained(text_encoder_type) 26 | raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type)) 27 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/util/inference.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | import cv2 4 | import numpy as np 5 | import supervision as sv 6 | import torch 7 | from PIL import Image 8 | from torchvision.ops import box_convert 9 | 10 | import groundingdino.datasets.transforms as T 11 | from groundingdino.models import build_model 12 | from groundingdino.util.misc import clean_state_dict 13 | from groundingdino.util.slconfig import SLConfig 14 | from groundingdino.util.utils import get_phrases_from_posmap 15 | 16 | 17 | def preprocess_caption(caption: str) -> str: 18 | result = caption.lower().strip() 19 | if result.endswith("."): 20 | return result 21 | return result + "." 
22 | 23 | 24 | def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"): 25 | args = SLConfig.fromfile(model_config_path) 26 | args.device = device 27 | model = build_model(args) 28 | checkpoint = torch.load(model_checkpoint_path, map_location="cpu") 29 | model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) 30 | model.eval() 31 | return model 32 | 33 | 34 | def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]: 35 | transform = T.Compose( 36 | [ 37 | T.RandomResize([800], max_size=1333), 38 | T.ToTensor(), 39 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 40 | ] 41 | ) 42 | image_source = Image.open(image_path).convert("RGB") 43 | image = np.asarray(image_source) 44 | image_transformed, _ = transform(image_source, None) 45 | return image, image_transformed 46 | 47 | 48 | def predict( 49 | model, 50 | image: torch.Tensor, 51 | caption: str, 52 | box_threshold: float, 53 | text_threshold: float, 54 | device: str = "cuda" 55 | ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: 56 | caption = preprocess_caption(caption=caption) 57 | 58 | model = model.to(device) 59 | image = image.to(device) 60 | 61 | with torch.no_grad(): 62 | outputs = model(image[None], captions=[caption]) 63 | 64 | prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256) 65 | prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4) 66 | 67 | mask = prediction_logits.max(dim=1)[0] > box_threshold 68 | logits = prediction_logits[mask] # logits.shape = (n, 256) 69 | boxes = prediction_boxes[mask] # boxes.shape = (n, 4) 70 | 71 | tokenizer = model.tokenizer 72 | tokenized = tokenizer(caption) 73 | 74 | phrases = [ 75 | get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '') 76 | for logit 77 | in logits 78 | ] 79 | 80 | return boxes, logits.max(dim=1)[0], phrases 81 | 82 | 83 | def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray: 84 | h, w, _ = image_source.shape 85 | boxes = boxes * torch.Tensor([w, h, w, h]) 86 | xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy() 87 | detections = sv.Detections(xyxy=xyxy) 88 | 89 | labels = [ 90 | f"{phrase} {logit:.2f}" 91 | for phrase, logit 92 | in zip(phrases, logits) 93 | ] 94 | 95 | box_annotator = sv.BoxAnnotator() 96 | annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR) 97 | annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels) 98 | return annotated_frame 99 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/util/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import functools 3 | import logging 4 | import os 5 | import sys 6 | 7 | from termcolor import colored 8 | 9 | 10 | class _ColorfulFormatter(logging.Formatter): 11 | def __init__(self, *args, **kwargs): 12 | self._root_name = kwargs.pop("root_name") + "." 13 | self._abbrev_name = kwargs.pop("abbrev_name", "") 14 | if len(self._abbrev_name): 15 | self._abbrev_name = self._abbrev_name + "." 
16 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 17 | 18 | def formatMessage(self, record): 19 | record.name = record.name.replace(self._root_name, self._abbrev_name) 20 | log = super(_ColorfulFormatter, self).formatMessage(record) 21 | if record.levelno == logging.WARNING: 22 | prefix = colored("WARNING", "red", attrs=["blink"]) 23 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 24 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 25 | else: 26 | return log 27 | return prefix + " " + log 28 | 29 | 30 | # so that calling setup_logger multiple times won't add many handlers 31 | @functools.lru_cache() 32 | def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None): 33 | """ 34 | Initialize the detectron2 logger and set its verbosity level to "INFO". 35 | 36 | Args: 37 | output (str): a file name or a directory to save log. If None, will not save log file. 38 | If ends with ".txt" or ".log", assumed to be a file name. 39 | Otherwise, logs will be saved to `output/log.txt`. 40 | name (str): the root module name of this logger 41 | 42 | Returns: 43 | logging.Logger: a logger 44 | """ 45 | logger = logging.getLogger(name) 46 | logger.setLevel(logging.DEBUG) 47 | logger.propagate = False 48 | 49 | if abbrev_name is None: 50 | abbrev_name = name 51 | 52 | plain_formatter = logging.Formatter( 53 | "[%(asctime)s.%(msecs)03d]: %(message)s", datefmt="%m/%d %H:%M:%S" 54 | ) 55 | # stdout logging: master only 56 | if distributed_rank == 0: 57 | ch = logging.StreamHandler(stream=sys.stdout) 58 | ch.setLevel(logging.DEBUG) 59 | if color: 60 | formatter = _ColorfulFormatter( 61 | colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s", 62 | datefmt="%m/%d %H:%M:%S", 63 | root_name=name, 64 | abbrev_name=str(abbrev_name), 65 | ) 66 | else: 67 | formatter = plain_formatter 68 | ch.setFormatter(formatter) 69 | logger.addHandler(ch) 70 | 71 | # file logging: all workers 72 | if output is not None: 73 | if output.endswith(".txt") or output.endswith(".log"): 74 | filename = output 75 | else: 76 | filename = os.path.join(output, "log.txt") 77 | if distributed_rank > 0: 78 | filename = filename + f".rank{distributed_rank}" 79 | os.makedirs(os.path.dirname(filename), exist_ok=True) 80 | 81 | fh = logging.StreamHandler(_cached_log_stream(filename)) 82 | fh.setLevel(logging.DEBUG) 83 | fh.setFormatter(plain_formatter) 84 | logger.addHandler(fh) 85 | 86 | return logger 87 | 88 | 89 | # cache the opened file object, so that different calls to `setup_logger` 90 | # with the same file name can safely write to the same file. 
91 | @functools.lru_cache(maxsize=None) 92 | def _cached_log_stream(filename): 93 | return open(filename, "a") 94 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/util/slio.py: -------------------------------------------------------------------------------- 1 | # ========================================================== 2 | # Modified from mmcv 3 | # ========================================================== 4 | 5 | import json 6 | import pickle 7 | from abc import ABCMeta, abstractmethod 8 | from pathlib import Path 9 | 10 | import yaml 11 | 12 | try: 13 | from yaml import CLoader as Loader, CDumper as Dumper 14 | except ImportError: 15 | from yaml import Loader, Dumper 16 | 17 | 18 | # =========================== 19 | # Rigister handler 20 | # =========================== 21 | 22 | 23 | class BaseFileHandler(metaclass=ABCMeta): 24 | @abstractmethod 25 | def load_from_fileobj(self, file, **kwargs): 26 | pass 27 | 28 | @abstractmethod 29 | def dump_to_fileobj(self, obj, file, **kwargs): 30 | pass 31 | 32 | @abstractmethod 33 | def dump_to_str(self, obj, **kwargs): 34 | pass 35 | 36 | def load_from_path(self, filepath, mode="r", **kwargs): 37 | with open(filepath, mode) as f: 38 | return self.load_from_fileobj(f, **kwargs) 39 | 40 | def dump_to_path(self, obj, filepath, mode="w", **kwargs): 41 | with open(filepath, mode) as f: 42 | self.dump_to_fileobj(obj, f, **kwargs) 43 | 44 | 45 | class JsonHandler(BaseFileHandler): 46 | def load_from_fileobj(self, file): 47 | return json.load(file) 48 | 49 | def dump_to_fileobj(self, obj, file, **kwargs): 50 | json.dump(obj, file, **kwargs) 51 | 52 | def dump_to_str(self, obj, **kwargs): 53 | return json.dumps(obj, **kwargs) 54 | 55 | 56 | class PickleHandler(BaseFileHandler): 57 | def load_from_fileobj(self, file, **kwargs): 58 | return pickle.load(file, **kwargs) 59 | 60 | def load_from_path(self, filepath, **kwargs): 61 | return super(PickleHandler, self).load_from_path(filepath, mode="rb", **kwargs) 62 | 63 | def dump_to_str(self, obj, **kwargs): 64 | kwargs.setdefault("protocol", 2) 65 | return pickle.dumps(obj, **kwargs) 66 | 67 | def dump_to_fileobj(self, obj, file, **kwargs): 68 | kwargs.setdefault("protocol", 2) 69 | pickle.dump(obj, file, **kwargs) 70 | 71 | def dump_to_path(self, obj, filepath, **kwargs): 72 | super(PickleHandler, self).dump_to_path(obj, filepath, mode="wb", **kwargs) 73 | 74 | 75 | class YamlHandler(BaseFileHandler): 76 | def load_from_fileobj(self, file, **kwargs): 77 | kwargs.setdefault("Loader", Loader) 78 | return yaml.load(file, **kwargs) 79 | 80 | def dump_to_fileobj(self, obj, file, **kwargs): 81 | kwargs.setdefault("Dumper", Dumper) 82 | yaml.dump(obj, file, **kwargs) 83 | 84 | def dump_to_str(self, obj, **kwargs): 85 | kwargs.setdefault("Dumper", Dumper) 86 | return yaml.dump(obj, **kwargs) 87 | 88 | 89 | file_handlers = { 90 | "json": JsonHandler(), 91 | "yaml": YamlHandler(), 92 | "yml": YamlHandler(), 93 | "pickle": PickleHandler(), 94 | "pkl": PickleHandler(), 95 | } 96 | 97 | # =========================== 98 | # load and dump 99 | # =========================== 100 | 101 | 102 | def is_str(x): 103 | """Whether the input is an string instance. 104 | 105 | Note: This method is deprecated since python 2 is no longer supported. 106 | """ 107 | return isinstance(x, str) 108 | 109 | 110 | def slload(file, file_format=None, **kwargs): 111 | """Load data from json/yaml/pickle files. 
112 | 113 | This method provides a unified api for loading data from serialized files. 114 | 115 | Args: 116 | file (str or :obj:`Path` or file-like object): Filename or a file-like 117 | object. 118 | file_format (str, optional): If not specified, the file format will be 119 | inferred from the file extension, otherwise use the specified one. 120 | Currently supported formats include "json", "yaml/yml" and 121 | "pickle/pkl". 122 | 123 | Returns: 124 | The content from the file. 125 | """ 126 | if isinstance(file, Path): 127 | file = str(file) 128 | if file_format is None and is_str(file): 129 | file_format = file.split(".")[-1] 130 | if file_format not in file_handlers: 131 | raise TypeError(f"Unsupported format: {file_format}") 132 | 133 | handler = file_handlers[file_format] 134 | if is_str(file): 135 | obj = handler.load_from_path(file, **kwargs) 136 | elif hasattr(file, "read"): 137 | obj = handler.load_from_fileobj(file, **kwargs) 138 | else: 139 | raise TypeError('"file" must be a filepath str or a file-object') 140 | return obj 141 | 142 | 143 | def sldump(obj, file=None, file_format=None, **kwargs): 144 | """Dump data to json/yaml/pickle strings or files. 145 | 146 | This method provides a unified api for dumping data as strings or to files, 147 | and also supports custom arguments for each file format. 148 | 149 | Args: 150 | obj (any): The python object to be dumped. 151 | file (str or :obj:`Path` or file-like object, optional): If not 152 | specified, then the object is dump to a str, otherwise to a file 153 | specified by the filename or file-like object. 154 | file_format (str, optional): Same as :func:`load`. 155 | 156 | Returns: 157 | bool: True for success, False otherwise. 158 | """ 159 | if isinstance(file, Path): 160 | file = str(file) 161 | if file_format is None: 162 | if is_str(file): 163 | file_format = file.split(".")[-1] 164 | elif file is None: 165 | raise ValueError("file_format must be specified since file is None") 166 | if file_format not in file_handlers: 167 | raise TypeError(f"Unsupported format: {file_format}") 168 | 169 | handler = file_handlers[file_format] 170 | if file is None: 171 | return handler.dump_to_str(obj, **kwargs) 172 | elif is_str(file): 173 | handler.dump_to_path(obj, file, **kwargs) 174 | elif hasattr(file, "write"): 175 | handler.dump_to_fileobj(obj, file, **kwargs) 176 | else: 177 | raise TypeError('"file" must be a filename str or a file-object') 178 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/util/time_counter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | 4 | 5 | class TimeCounter: 6 | def __init__(self) -> None: 7 | pass 8 | 9 | def clear(self): 10 | self.timedict = {} 11 | self.basetime = time.perf_counter() 12 | 13 | def timeit(self, name): 14 | nowtime = time.perf_counter() - self.basetime 15 | self.timedict[name] = nowtime 16 | self.basetime = time.perf_counter() 17 | 18 | 19 | class TimeHolder: 20 | def __init__(self) -> None: 21 | self.timedict = {} 22 | 23 | def update(self, _timedict: dict): 24 | for k, v in _timedict.items(): 25 | if k not in self.timedict: 26 | self.timedict[k] = AverageMeter(name=k, val_only=True) 27 | self.timedict[k].update(val=v) 28 | 29 | def final_res(self): 30 | return {k: v.avg for k, v in self.timedict.items()} 31 | 32 | def __str__(self): 33 | return json.dumps(self.final_res(), indent=2) 34 | 35 | 36 | class AverageMeter(object): 37 | """Computes and 
stores the average and current value""" 38 | 39 | def __init__(self, name, fmt=":f", val_only=False): 40 | self.name = name 41 | self.fmt = fmt 42 | self.val_only = val_only 43 | self.reset() 44 | 45 | def reset(self): 46 | self.val = 0 47 | self.avg = 0 48 | self.sum = 0 49 | self.count = 0 50 | 51 | def update(self, val, n=1): 52 | self.val = val 53 | self.sum += val * n 54 | self.count += n 55 | self.avg = self.sum / self.count 56 | 57 | def __str__(self): 58 | if self.val_only: 59 | fmtstr = "{name} {val" + self.fmt + "}" 60 | else: 61 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 62 | return fmtstr.format(**self.__dict__) 63 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/util/vl_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from typing import List 4 | 5 | import torch 6 | 7 | 8 | def create_positive_map_from_span(tokenized, token_span, max_text_len=256): 9 | """construct a map such that positive_map[i,j] = True iff box i is associated to token j 10 | Input: 11 | - tokenized: 12 | - input_ids: Tensor[1, ntokens] 13 | - attention_mask: Tensor[1, ntokens] 14 | - token_span: list with length num_boxes. 15 | - each item: [start_idx, end_idx] 16 | """ 17 | positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float) 18 | for j, tok_list in enumerate(token_span): 19 | for (beg, end) in tok_list: 20 | beg_pos = tokenized.char_to_token(beg) 21 | end_pos = tokenized.char_to_token(end - 1) 22 | if beg_pos is None: 23 | try: 24 | beg_pos = tokenized.char_to_token(beg + 1) 25 | if beg_pos is None: 26 | beg_pos = tokenized.char_to_token(beg + 2) 27 | except: 28 | beg_pos = None 29 | if end_pos is None: 30 | try: 31 | end_pos = tokenized.char_to_token(end - 2) 32 | if end_pos is None: 33 | end_pos = tokenized.char_to_token(end - 3) 34 | except: 35 | end_pos = None 36 | if beg_pos is None or end_pos is None: 37 | continue 38 | 39 | assert beg_pos is not None and end_pos is not None 40 | if os.environ.get("SHILONG_DEBUG_ONLY_ONE_POS", None) == "TRUE": 41 | positive_map[j, beg_pos] = 1 42 | break 43 | else: 44 | positive_map[j, beg_pos : end_pos + 1].fill_(1) 45 | 46 | return positive_map / (positive_map.sum(-1)[:, None] + 1e-6) 47 | 48 | 49 | def build_captions_and_token_span(cat_list, force_lowercase): 50 | """ 51 | Return: 52 | captions: str 53 | cat2tokenspan: dict 54 | { 55 | 'dog': [[0, 2]], 56 | ... 57 | } 58 | """ 59 | 60 | cat2tokenspan = {} 61 | captions = "" 62 | for catname in cat_list: 63 | class_name = catname 64 | if force_lowercase: 65 | class_name = class_name.lower() 66 | if "/" in class_name: 67 | class_name_list: List = class_name.strip().split("/") 68 | class_name_list.append(class_name) 69 | class_name: str = random.choice(class_name_list) 70 | 71 | tokens_positive_i = [] 72 | subnamelist = [i.strip() for i in class_name.strip().split(" ")] 73 | for subname in subnamelist: 74 | if len(subname) == 0: 75 | continue 76 | if len(captions) > 0: 77 | captions = captions + " " 78 | strat_idx = len(captions) 79 | end_idx = strat_idx + len(subname) 80 | tokens_positive_i.append([strat_idx, end_idx]) 81 | captions = captions + subname 82 | 83 | if len(tokens_positive_i) > 0: 84 | captions = captions + " ." 
85 | cat2tokenspan[class_name] = tokens_positive_i 86 | 87 | return captions, cat2tokenspan 88 | 89 | 90 | def build_id2posspan_and_caption(category_dict: dict): 91 | """Build id2pos_span and caption from category_dict 92 | 93 | Args: 94 | category_dict (dict): category_dict 95 | """ 96 | cat_list = [item["name"].lower() for item in category_dict] 97 | id2catname = {item["id"]: item["name"].lower() for item in category_dict} 98 | caption, cat2posspan = build_captions_and_token_span(cat_list, force_lowercase=True) 99 | id2posspan = {catid: cat2posspan[catname] for catid, catname in id2catname.items()} 100 | return id2posspan, caption 101 | -------------------------------------------------------------------------------- /GroundingDINO/groundingdino/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.0' 2 | -------------------------------------------------------------------------------- /GroundingDINO/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | transformers 4 | addict 5 | yapf 6 | timm 7 | numpy 8 | opencv-python 9 | supervision==0.3.2 10 | pycocotools -------------------------------------------------------------------------------- /GroundingDINO/setup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The IDEA Authors. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ------------------------------------------------------------------------------------------------ 16 | # Modified from 17 | # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/setup.py 18 | # https://github.com/facebookresearch/detectron2/blob/main/setup.py 19 | # https://github.com/open-mmlab/mmdetection/blob/master/setup.py 20 | # https://github.com/Oneflow-Inc/libai/blob/main/setup.py 21 | # ------------------------------------------------------------------------------------------------ 22 | 23 | import glob 24 | import os 25 | import subprocess 26 | 27 | import torch 28 | from setuptools import find_packages, setup 29 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension 30 | 31 | # groundingdino version info 32 | version = "0.1.0" 33 | package_name = "groundingdino" 34 | cwd = os.path.dirname(os.path.abspath(__file__)) 35 | 36 | 37 | sha = "Unknown" 38 | try: 39 | sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip() 40 | except Exception: 41 | pass 42 | 43 | 44 | def write_version_file(): 45 | version_path = os.path.join(cwd, "groundingdino", "version.py") 46 | with open(version_path, "w") as f: 47 | f.write(f"__version__ = '{version}'\n") 48 | # f.write(f"git_version = {repr(sha)}\n") 49 | 50 | 51 | requirements = ["torch", "torchvision"] 52 | 53 | torch_ver = [int(x) for x in torch.__version__.split(".")[:2]] 54 | 55 | 56 | def get_extensions(): 57 | this_dir = os.path.dirname(os.path.abspath(__file__)) 58 | extensions_dir = os.path.join(this_dir, "groundingdino", "models", "GroundingDINO", "csrc") 59 | 60 | main_source = os.path.join(extensions_dir, "vision.cpp") 61 | sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp")) 62 | source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob( 63 | os.path.join(extensions_dir, "*.cu") 64 | ) 65 | 66 | sources = [main_source] + sources 67 | 68 | extension = CppExtension 69 | 70 | extra_compile_args = {"cxx": []} 71 | define_macros = [] 72 | 73 | if torch.cuda.is_available() and CUDA_HOME is not None: 74 | print("Compiling with CUDA") 75 | extension = CUDAExtension 76 | sources += source_cuda 77 | define_macros += [("WITH_CUDA", None)] 78 | extra_compile_args["nvcc"] = [ 79 | "-DCUDA_HAS_FP16=1", 80 | "-D__CUDA_NO_HALF_OPERATORS__", 81 | "-D__CUDA_NO_HALF_CONVERSIONS__", 82 | "-D__CUDA_NO_HALF2_OPERATORS__", 83 | ] 84 | else: 85 | print("Compiling without CUDA") 86 | define_macros += [("WITH_HIP", None)] 87 | extra_compile_args["nvcc"] = [] 88 | return None 89 | 90 | sources = [os.path.join(extensions_dir, s) for s in sources] 91 | include_dirs = [extensions_dir] 92 | 93 | ext_modules = [ 94 | extension( 95 | "groundingdino._C", 96 | sources, 97 | include_dirs=include_dirs, 98 | define_macros=define_macros, 99 | extra_compile_args=extra_compile_args, 100 | ) 101 | ] 102 | 103 | return ext_modules 104 | 105 | 106 | def parse_requirements(fname="requirements.txt", with_version=True): 107 | """Parse the package dependencies listed in a requirements file but strips 108 | specific versioning information. 
109 | 110 | Args: 111 | fname (str): path to requirements file 112 | with_version (bool, default=False): if True include version specs 113 | 114 | Returns: 115 | List[str]: list of requirements items 116 | 117 | CommandLine: 118 | python -c "import setup; print(setup.parse_requirements())" 119 | """ 120 | import re 121 | import sys 122 | from os.path import exists 123 | 124 | require_fpath = fname 125 | 126 | def parse_line(line): 127 | """Parse information from a line in a requirements text file.""" 128 | if line.startswith("-r "): 129 | # Allow specifying requirements in other files 130 | target = line.split(" ")[1] 131 | for info in parse_require_file(target): 132 | yield info 133 | else: 134 | info = {"line": line} 135 | if line.startswith("-e "): 136 | info["package"] = line.split("#egg=")[1] 137 | elif "@git+" in line: 138 | info["package"] = line 139 | else: 140 | # Remove versioning from the package 141 | pat = "(" + "|".join([">=", "==", ">"]) + ")" 142 | parts = re.split(pat, line, maxsplit=1) 143 | parts = [p.strip() for p in parts] 144 | 145 | info["package"] = parts[0] 146 | if len(parts) > 1: 147 | op, rest = parts[1:] 148 | if ";" in rest: 149 | # Handle platform specific dependencies 150 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies 151 | version, platform_deps = map(str.strip, rest.split(";")) 152 | info["platform_deps"] = platform_deps 153 | else: 154 | version = rest # NOQA 155 | info["version"] = (op, version) 156 | yield info 157 | 158 | def parse_require_file(fpath): 159 | with open(fpath, "r") as f: 160 | for line in f.readlines(): 161 | line = line.strip() 162 | if line and not line.startswith("#"): 163 | for info in parse_line(line): 164 | yield info 165 | 166 | def gen_packages_items(): 167 | if exists(require_fpath): 168 | for info in parse_require_file(require_fpath): 169 | parts = [info["package"]] 170 | if with_version and "version" in info: 171 | parts.extend(info["version"]) 172 | if not sys.version.startswith("3.4"): 173 | # apparently package_deps are broken in 3.4 174 | platform_deps = info.get("platform_deps") 175 | if platform_deps is not None: 176 | parts.append(";" + platform_deps) 177 | item = "".join(parts) 178 | yield item 179 | 180 | packages = list(gen_packages_items()) 181 | return packages 182 | 183 | 184 | if __name__ == "__main__": 185 | print(f"Building wheel {package_name}-{version}") 186 | 187 | with open("LICENSE", "r", encoding="utf-8") as f: 188 | license = f.read() 189 | 190 | write_version_file() 191 | 192 | setup( 193 | name="groundingdino", 194 | version="0.1.0", 195 | author="International Digital Economy Academy, Shilong Liu", 196 | url="https://github.com/IDEA-Research/GroundingDINO", 197 | description="open-set object detector", 198 | license=license, 199 | install_requires=parse_requirements("requirements.txt"), 200 | packages=find_packages( 201 | exclude=( 202 | "configs", 203 | "tests", 204 | ) 205 | ), 206 | ext_modules=get_extensions(), 207 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 208 | ) 209 | -------------------------------------------------------------------------------- /assets/Grounded-SAM_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/Grounded-SAM_logo.png -------------------------------------------------------------------------------- 
/assets/automatic_label_output/demo1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/automatic_label_output/demo1.jpg -------------------------------------------------------------------------------- /assets/automatic_label_output/demo2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/automatic_label_output/demo2.jpg -------------------------------------------------------------------------------- /assets/automatic_label_output/demo4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/automatic_label_output/demo4.jpg -------------------------------------------------------------------------------- /assets/automatic_label_output/demo8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/automatic_label_output/demo8.jpg -------------------------------------------------------------------------------- /assets/automatic_label_output_demo3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/automatic_label_output_demo3.jpg -------------------------------------------------------------------------------- /assets/demo1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo1.jpg -------------------------------------------------------------------------------- /assets/demo2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo2.jpg -------------------------------------------------------------------------------- /assets/demo3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo3.jpg -------------------------------------------------------------------------------- /assets/demo4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo4.jpg -------------------------------------------------------------------------------- /assets/demo5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo5.jpg -------------------------------------------------------------------------------- /assets/demo6.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo6.jpg -------------------------------------------------------------------------------- /assets/demo7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo7.jpg -------------------------------------------------------------------------------- /assets/demo8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/demo8.jpg -------------------------------------------------------------------------------- /assets/gradio_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/gradio_demo.png -------------------------------------------------------------------------------- /assets/grounded_sam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/grounded_sam.jpg -------------------------------------------------------------------------------- /assets/grounded_sam2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/grounded_sam2.png -------------------------------------------------------------------------------- /assets/grounded_sam_demo3_demo4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/grounded_sam_demo3_demo4.png -------------------------------------------------------------------------------- /assets/grounded_sam_inpainting_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/grounded_sam_inpainting_demo.png -------------------------------------------------------------------------------- /assets/grounded_sam_output_demo1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/grounded_sam_output_demo1.jpg -------------------------------------------------------------------------------- /assets/grounding_dino_output_demo1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/grounding_dino_output_demo1.jpg -------------------------------------------------------------------------------- /assets/inpaint_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/assets/inpaint_demo.jpg -------------------------------------------------------------------------------- 
/grounded_sam_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import copy 4 | 5 | import numpy as np 6 | import json 7 | import torch 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | # Grounding DINO 11 | import GroundingDINO.groundingdino.datasets.transforms as T 12 | from GroundingDINO.groundingdino.models import build_model 13 | from GroundingDINO.groundingdino.util import box_ops 14 | from GroundingDINO.groundingdino.util.slconfig import SLConfig 15 | from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap 16 | 17 | # segment anything 18 | from segment_anything import build_sam, SamPredictor 19 | import cv2 20 | import numpy as np 21 | import matplotlib.pyplot as plt 22 | 23 | 24 | def load_image(image_path): 25 | # load image 26 | image_pil = Image.open(image_path).convert("RGB") # load image 27 | 28 | transform = T.Compose( 29 | [ 30 | T.RandomResize([800], max_size=1333), 31 | T.ToTensor(), 32 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 33 | ] 34 | ) 35 | image, _ = transform(image_pil, None) # 3, h, w 36 | return image_pil, image 37 | 38 | 39 | def load_model(model_config_path, model_checkpoint_path, device): 40 | args = SLConfig.fromfile(model_config_path) 41 | args.device = device 42 | model = build_model(args) 43 | checkpoint = torch.load(model_checkpoint_path, map_location="cpu") 44 | load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) 45 | print(load_res) 46 | _ = model.eval() 47 | return model 48 | 49 | 50 | def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"): 51 | caption = caption.lower() 52 | caption = caption.strip() 53 | if not caption.endswith("."): 54 | caption = caption + "." 
55 | model = model.to(device) 56 | image = image.to(device) 57 | with torch.no_grad(): 58 | outputs = model(image[None], captions=[caption]) 59 | logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256) 60 | boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4) 61 | logits.shape[0] 62 | 63 | # filter output 64 | logits_filt = logits.clone() 65 | boxes_filt = boxes.clone() 66 | filt_mask = logits_filt.max(dim=1)[0] > box_threshold 67 | logits_filt = logits_filt[filt_mask] # num_filt, 256 68 | boxes_filt = boxes_filt[filt_mask] # num_filt, 4 69 | logits_filt.shape[0] 70 | 71 | # get phrase 72 | tokenlizer = model.tokenizer 73 | tokenized = tokenlizer(caption) 74 | # build pred 75 | pred_phrases = [] 76 | for logit, box in zip(logits_filt, boxes_filt): 77 | pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer) 78 | if with_logits: 79 | pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})") 80 | else: 81 | pred_phrases.append(pred_phrase) 82 | 83 | return boxes_filt, pred_phrases 84 | 85 | def show_mask(mask, ax, random_color=False): 86 | if random_color: 87 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 88 | else: 89 | color = np.array([30/255, 144/255, 255/255, 0.6]) 90 | h, w = mask.shape[-2:] 91 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 92 | ax.imshow(mask_image) 93 | 94 | 95 | def show_box(box, ax, label): 96 | x0, y0 = box[0], box[1] 97 | w, h = box[2] - box[0], box[3] - box[1] 98 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) 99 | ax.text(x0, y0, label) 100 | 101 | 102 | def save_mask_data(output_dir, mask_list, box_list, label_list): 103 | value = 0 # 0 for background 104 | 105 | mask_img = torch.zeros(mask_list.shape[-2:]) 106 | for idx, mask in enumerate(mask_list): 107 | mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1 108 | plt.figure(figsize=(10, 10)) 109 | plt.imshow(mask_img.numpy()) 110 | plt.axis('off') 111 | plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0) 112 | 113 | json_data = [{ 114 | 'value': value, 115 | 'label': 'background' 116 | }] 117 | for label, box in zip(label_list, box_list): 118 | value += 1 119 | name, logit = label.split('(') 120 | logit = logit[:-1] # the last is ')' 121 | json_data.append({ 122 | 'value': value, 123 | 'label': name, 124 | 'logit': float(logit), 125 | 'box': box.numpy().tolist(), 126 | }) 127 | with open(os.path.join(output_dir, 'mask.json'), 'w') as f: 128 | json.dump(json_data, f) 129 | 130 | 131 | if __name__ == "__main__": 132 | 133 | parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True) 134 | parser.add_argument("--config", type=str, required=True, help="path to config file") 135 | parser.add_argument( 136 | "--grounded_checkpoint", type=str, required=True, help="path to checkpoint file" 137 | ) 138 | parser.add_argument( 139 | "--sam_checkpoint", type=str, required=True, help="path to checkpoint file" 140 | ) 141 | parser.add_argument("--input_image", type=str, required=True, help="path to image file") 142 | parser.add_argument("--text_prompt", type=str, required=True, help="text prompt") 143 | parser.add_argument( 144 | "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory" 145 | ) 146 | 147 | parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold") 148 | parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold") 149 | 150 
| parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False") 151 | args = parser.parse_args() 152 | 153 | # cfg 154 | config_file = args.config # change the path of the model config file 155 | grounded_checkpoint = args.grounded_checkpoint # change the path of the model 156 | sam_checkpoint = args.sam_checkpoint 157 | image_path = args.input_image 158 | text_prompt = args.text_prompt 159 | output_dir = args.output_dir 160 | box_threshold = args.box_threshold 161 | text_threshold = args.box_threshold 162 | device = args.device 163 | 164 | # make dir 165 | os.makedirs(output_dir, exist_ok=True) 166 | # load image 167 | image_pil, image = load_image(image_path) 168 | # load model 169 | model = load_model(config_file, grounded_checkpoint, device=device) 170 | 171 | # visualize raw image 172 | image_pil.save(os.path.join(output_dir, "raw_image.jpg")) 173 | 174 | # run grounding dino model 175 | boxes_filt, pred_phrases = get_grounding_output( 176 | model, image, text_prompt, box_threshold, text_threshold, device=device 177 | ) 178 | 179 | # initialize SAM 180 | predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint)) 181 | image = cv2.imread(image_path) 182 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 183 | predictor.set_image(image) 184 | 185 | size = image_pil.size 186 | H, W = size[1], size[0] 187 | for i in range(boxes_filt.size(0)): 188 | boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H]) 189 | boxes_filt[i][:2] -= boxes_filt[i][2:] / 2 190 | boxes_filt[i][2:] += boxes_filt[i][:2] 191 | 192 | boxes_filt = boxes_filt.cpu() 193 | transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]) 194 | 195 | masks, _, _ = predictor.predict_torch( 196 | point_coords = None, 197 | point_labels = None, 198 | boxes = transformed_boxes, 199 | multimask_output = False, 200 | ) 201 | 202 | # draw output image 203 | plt.figure(figsize=(10, 10)) 204 | plt.imshow(image) 205 | for mask in masks: 206 | show_mask(mask.cpu().numpy(), plt.gca(), random_color=True) 207 | for box, label in zip(boxes_filt, pred_phrases): 208 | show_box(box.numpy(), plt.gca(), label) 209 | 210 | plt.axis('off') 211 | plt.savefig( 212 | os.path.join(output_dir, "grounded_sam_output.jpg"), 213 | bbox_inches="tight", dpi=300, pad_inches=0.0 214 | ) 215 | 216 | save_mask_data(output_dir, masks, boxes_filt, pred_phrases) 217 | 218 | -------------------------------------------------------------------------------- /grounded_sam_inpainting_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import copy 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image, ImageDraw, ImageFont 8 | 9 | # Grounding DINO 10 | import GroundingDINO.groundingdino.datasets.transforms as T 11 | from GroundingDINO.groundingdino.models import build_model 12 | from GroundingDINO.groundingdino.util import box_ops 13 | from GroundingDINO.groundingdino.util.slconfig import SLConfig 14 | from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap 15 | 16 | # segment anything 17 | from segment_anything import build_sam, SamPredictor 18 | import cv2 19 | import numpy as np 20 | import matplotlib.pyplot as plt 21 | 22 | 23 | # diffusers 24 | import PIL 25 | import requests 26 | import torch 27 | from io import BytesIO 28 | from diffusers import StableDiffusionInpaintPipeline 29 | 30 | 31 | def load_image(image_path): 32 | # load image 33 | image_pil = 
Image.open(image_path).convert("RGB") # load image 34 | 35 | transform = T.Compose( 36 | [ 37 | T.RandomResize([800], max_size=1333), 38 | T.ToTensor(), 39 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 40 | ] 41 | ) 42 | image, _ = transform(image_pil, None) # 3, h, w 43 | return image_pil, image 44 | 45 | 46 | def load_model(model_config_path, model_checkpoint_path, device): 47 | args = SLConfig.fromfile(model_config_path) 48 | args.device = device 49 | model = build_model(args) 50 | checkpoint = torch.load(model_checkpoint_path, map_location="cpu") 51 | load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) 52 | print(load_res) 53 | _ = model.eval() 54 | return model 55 | 56 | 57 | def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"): 58 | caption = caption.lower() 59 | caption = caption.strip() 60 | if not caption.endswith("."): 61 | caption = caption + "." 62 | model = model.to(device) 63 | image = image.to(device) 64 | with torch.no_grad(): 65 | outputs = model(image[None], captions=[caption]) 66 | logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256) 67 | boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4) 68 | logits.shape[0] 69 | 70 | # filter output 71 | logits_filt = logits.clone() 72 | boxes_filt = boxes.clone() 73 | filt_mask = logits_filt.max(dim=1)[0] > box_threshold 74 | logits_filt = logits_filt[filt_mask] # num_filt, 256 75 | boxes_filt = boxes_filt[filt_mask] # num_filt, 4 76 | logits_filt.shape[0] 77 | 78 | # get phrase 79 | tokenlizer = model.tokenizer 80 | tokenized = tokenlizer(caption) 81 | # build pred 82 | pred_phrases = [] 83 | for logit, box in zip(logits_filt, boxes_filt): 84 | pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer) 85 | if with_logits: 86 | pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})") 87 | else: 88 | pred_phrases.append(pred_phrase) 89 | 90 | return boxes_filt, pred_phrases 91 | 92 | def show_mask(mask, ax, random_color=False): 93 | if random_color: 94 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 95 | else: 96 | color = np.array([30/255, 144/255, 255/255, 0.6]) 97 | h, w = mask.shape[-2:] 98 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 99 | ax.imshow(mask_image) 100 | 101 | 102 | def show_box(box, ax, label): 103 | x0, y0 = box[0], box[1] 104 | w, h = box[2] - box[0], box[3] - box[1] 105 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) 106 | ax.text(x0, y0, label) 107 | 108 | 109 | if __name__ == "__main__": 110 | 111 | parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True) 112 | parser.add_argument("--config", type=str, required=True, help="path to config file") 113 | parser.add_argument( 114 | "--grounded_checkpoint", type=str, required=True, help="path to checkpoint file" 115 | ) 116 | parser.add_argument( 117 | "--sam_checkpoint", type=str, required=True, help="path to checkpoint file" 118 | ) 119 | parser.add_argument("--input_image", type=str, required=True, help="path to image file") 120 | parser.add_argument("--det_prompt", type=str, required=True, help="text prompt") 121 | parser.add_argument("--inpaint_prompt", type=str, required=True, help="inpaint prompt") 122 | parser.add_argument( 123 | "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory" 124 | ) 125 | 126 | parser.add_argument("--box_threshold", type=float, 
default=0.3, help="box threshold") 127 | parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold") 128 | parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False") 129 | args = parser.parse_args() 130 | 131 | # cfg 132 | config_file = args.config # change the path of the model config file 133 | grounded_checkpoint = args.grounded_checkpoint # change the path of the model 134 | sam_checkpoint = args.sam_checkpoint 135 | image_path = args.input_image 136 | det_prompt = args.det_prompt 137 | inpaint_prompt = args.inpaint_prompt 138 | output_dir = args.output_dir 139 | box_threshold = args.box_threshold 140 | text_threshold = args.box_threshold 141 | device = args.device 142 | 143 | # make dir 144 | os.makedirs(output_dir, exist_ok=True) 145 | # load image 146 | image_pil, image = load_image(image_path) 147 | # load model 148 | model = load_model(config_file, grounded_checkpoint, device=device) 149 | 150 | # visualize raw image 151 | image_pil.save(os.path.join(output_dir, "raw_image.jpg")) 152 | 153 | # run grounding dino model 154 | boxes_filt, pred_phrases = get_grounding_output( 155 | model, image, det_prompt, box_threshold, text_threshold, device=device 156 | ) 157 | 158 | # initialize SAM 159 | predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint)) 160 | image = cv2.imread(image_path) 161 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 162 | predictor.set_image(image) 163 | 164 | size = image_pil.size 165 | H, W = size[1], size[0] 166 | for i in range(boxes_filt.size(0)): 167 | boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H]) 168 | boxes_filt[i][:2] -= boxes_filt[i][2:] / 2 169 | boxes_filt[i][2:] += boxes_filt[i][:2] 170 | 171 | boxes_filt = boxes_filt.cpu() 172 | transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]) 173 | 174 | masks, _, _ = predictor.predict_torch( 175 | point_coords = None, 176 | point_labels = None, 177 | boxes = transformed_boxes, 178 | multimask_output = False, 179 | ) 180 | 181 | # masks: [1, 1, 512, 512] 182 | 183 | # inpainting pipeline 184 | mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refine in the future release 185 | mask_pil = Image.fromarray(mask) 186 | image_pil = Image.fromarray(image) 187 | 188 | pipe = StableDiffusionInpaintPipeline.from_pretrained( 189 | "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16 190 | ) 191 | pipe = pipe.to("cuda") 192 | 193 | # prompt = "A sofa, high quality, detailed" 194 | image = pipe(prompt=inpaint_prompt, image=image_pil, mask_image=mask_pil).images[0] 195 | image.save(os.path.join(output_dir, "grounded_sam_inpainting_output.jpg")) 196 | 197 | # draw output image 198 | # plt.figure(figsize=(10, 10)) 199 | # plt.imshow(image) 200 | # for mask in masks: 201 | # show_mask(mask.cpu().numpy(), plt.gca(), random_color=True) 202 | # for box, label in zip(boxes_filt, pred_phrases): 203 | # show_box(box.numpy(), plt.gca(), label) 204 | # plt.axis('off') 205 | # plt.savefig(os.path.join(output_dir, "grounded_sam_output.jpg"), bbox_inches="tight") 206 | 207 | -------------------------------------------------------------------------------- /grounding_dino_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image, ImageDraw, ImageFont 7 | 8 | import GroundingDINO.groundingdino.datasets.transforms as T 9 | from 
GroundingDINO.groundingdino.models import build_model 10 | from GroundingDINO.groundingdino.util import box_ops 11 | from GroundingDINO.groundingdino.util.slconfig import SLConfig 12 | from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap 13 | 14 | 15 | def plot_boxes_to_image(image_pil, tgt): 16 | H, W = tgt["size"] 17 | boxes = tgt["boxes"] 18 | labels = tgt["labels"] 19 | assert len(boxes) == len(labels), "boxes and labels must have same length" 20 | 21 | draw = ImageDraw.Draw(image_pil) 22 | mask = Image.new("L", image_pil.size, 0) 23 | mask_draw = ImageDraw.Draw(mask) 24 | 25 | # draw boxes and masks 26 | for box, label in zip(boxes, labels): 27 | # from 0..1 to 0..W, 0..H 28 | box = box * torch.Tensor([W, H, W, H]) 29 | # from xywh to xyxy 30 | box[:2] -= box[2:] / 2 31 | box[2:] += box[:2] 32 | # random color 33 | color = tuple(np.random.randint(0, 255, size=3).tolist()) 34 | # draw 35 | x0, y0, x1, y1 = box 36 | x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1) 37 | 38 | draw.rectangle([x0, y0, x1, y1], outline=color, width=6) 39 | # draw.text((x0, y0), str(label), fill=color) 40 | 41 | font = ImageFont.load_default() 42 | if hasattr(font, "getbbox"): 43 | bbox = draw.textbbox((x0, y0), str(label), font) 44 | else: 45 | w, h = draw.textsize(str(label), font) 46 | bbox = (x0, y0, w + x0, y0 + h) 47 | # bbox = draw.textbbox((x0, y0), str(label)) 48 | draw.rectangle(bbox, fill=color) 49 | draw.text((x0, y0), str(label), fill="white") 50 | 51 | mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6) 52 | 53 | return image_pil, mask 54 | 55 | 56 | def load_image(image_path): 57 | # load image 58 | image_pil = Image.open(image_path).convert("RGB") # load image 59 | 60 | transform = T.Compose( 61 | [ 62 | T.RandomResize([800], max_size=1333), 63 | T.ToTensor(), 64 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 65 | ] 66 | ) 67 | image, _ = transform(image_pil, None) # 3, h, w 68 | return image_pil, image 69 | 70 | 71 | def load_model(model_config_path, model_checkpoint_path, device="cpu"): 72 | args = SLConfig.fromfile(model_config_path) 73 | args.device = device 74 | model = build_model(args) 75 | checkpoint = torch.load(model_checkpoint_path, map_location="cpu") 76 | load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) 77 | print(load_res) 78 | _ = model.eval() 79 | return model 80 | 81 | 82 | def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"): 83 | caption = caption.lower() 84 | caption = caption.strip() 85 | if not caption.endswith("."): 86 | caption = caption + "." 
87 | model = model.to(device) 88 | image = image.to(device) 89 | with torch.no_grad(): 90 | outputs = model(image[None], captions=[caption]) 91 | logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256) 92 | boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4) 93 | logits.shape[0] 94 | 95 | # filter output 96 | logits_filt = logits.clone() 97 | boxes_filt = boxes.clone() 98 | filt_mask = logits_filt.max(dim=1)[0] > box_threshold 99 | logits_filt = logits_filt[filt_mask] # num_filt, 256 100 | boxes_filt = boxes_filt[filt_mask] # num_filt, 4 101 | logits_filt.shape[0] 102 | 103 | # get phrase 104 | tokenlizer = model.tokenizer 105 | tokenized = tokenlizer(caption) 106 | # build pred 107 | pred_phrases = [] 108 | for logit, box in zip(logits_filt, boxes_filt): 109 | pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer) 110 | if with_logits: 111 | pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})") 112 | else: 113 | pred_phrases.append(pred_phrase) 114 | 115 | return boxes_filt, pred_phrases 116 | 117 | 118 | if __name__ == "__main__": 119 | 120 | parser = argparse.ArgumentParser("Grounding DINO example", add_help=True) 121 | parser.add_argument("--config", type=str, required=True, help="path to config file") 122 | parser.add_argument( 123 | "--grounded_checkpoint", type=str, required=True, help="path to checkpoint file" 124 | ) 125 | parser.add_argument("--input_image", type=str, required=True, help="path to image file") 126 | parser.add_argument("--text_prompt", type=str, required=True, help="text prompt") 127 | parser.add_argument( 128 | "--output_dir", "-o", type=str, default="outputs", required=True, help="output directory" 129 | ) 130 | 131 | parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold") 132 | parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold") 133 | 134 | parser.add_argument("--device", type=str, default="cpu", help="device to run on, 'cpu' or 'cuda' (default: 'cpu')") 135 | args = parser.parse_args() 136 | 137 | # cfg 138 | config_file = args.config # change the path of the model config file 139 | grounded_checkpoint = args.grounded_checkpoint # change the path of the model 140 | image_path = args.input_image 141 | text_prompt = args.text_prompt 142 | output_dir = args.output_dir 143 | box_threshold = args.box_threshold 144 | text_threshold = args.text_threshold 145 | device = args.device 146 | 147 | # make dir 148 | os.makedirs(output_dir, exist_ok=True) 149 | # load image 150 | image_pil, image = load_image(image_path) 151 | # load model 152 | model = load_model(config_file, grounded_checkpoint, device=device) 153 | 154 | # visualize raw image 155 | # image_pil.save(os.path.join(output_dir, "raw_image.jpg")) 156 | 157 | # run model 158 | boxes_filt, pred_phrases = get_grounding_output( 159 | model, image, text_prompt, box_threshold, text_threshold, device=device 160 | ) 161 | 162 | # visualize pred 163 | size = image_pil.size 164 | pred_dict = { 165 | "boxes": boxes_filt, 166 | "size": [size[1], size[0]], # H,W 167 | "labels": pred_phrases, 168 | } 169 | # import ipdb; ipdb.set_trace() 170 | image_with_box = plot_boxes_to_image(image_pil, pred_dict)[0] 171 | image_with_box.save(os.path.join(output_dir, "grounding_dino_output.jpg")) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict 2 | diffusers 3 | gradio 4 | huggingface_hub 5 |
matplotlib 6 | numpy 7 | onnxruntime 8 | opencv_python 9 | Pillow 10 | pycocotools 11 | PyYAML 12 | requests 13 | setuptools 14 | supervision 15 | termcolor 16 | timm 17 | torch 18 | torchvision 19 | transformers 20 | yapf 21 | -------------------------------------------------------------------------------- /segment_anything/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811, C101, EXE001, EXE002 3 | max-line-length = 100 4 | max-complexity = 18 5 | select = B,C,E,F,W,T4,B9 6 | per-file-ignores = 7 | **/__init__.py:F401,F403,E402 8 | -------------------------------------------------------------------------------- /segment_anything/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 
54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /segment_anything/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to segment-anything 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints, using the `linter.sh` script in the project's root directory. Linting requires `black==23.*`, `isort==5.12.0`, `flake8`, and `mypy`. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to segment-anything, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
32 | -------------------------------------------------------------------------------- /segment_anything/README.md: -------------------------------------------------------------------------------- 1 | # Segment Anything 2 | 3 | **[Meta AI Research, FAIR](https://ai.facebook.com/research/)** 4 | 5 | [Alexander Kirillov](https://alexander-kirillov.github.io/), [Eric Mintun](https://ericmintun.github.io/), [Nikhila Ravi](https://nikhilaravi.com/), [Hanzi Mao](https://hanzimao.me/), Chloe Rolland, Laura Gustafson, [Tete Xiao](https://tetexiao.com), [Spencer Whitehead](https://www.spencerwhitehead.com/), Alex Berg, Wan-Yen Lo, [Piotr Dollar](https://pdollar.github.io/), [Ross Girshick](https://www.rossgirshick.info/) 6 | 7 | [[`Paper`](https://ai.facebook.com/research/publications/segment-anything/)] [[`Project`](https://segment-anything.com/)] [[`Demo`](https://segment-anything.com/demo)] [[`Dataset`](https://segment-anything.com/dataset/index.html)] [[`Blog`](https://ai.facebook.com/blog/segment-anything-foundation-model-image-segmentation/)] 8 | 9 | ![SAM design](assets/model_diagram.png?raw=true) 10 | 11 | The **Segment Anything Model (SAM)** produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a [dataset](https://segment-anything.com/dataset/index.html) of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks. 12 | 13 |

14 | 15 | 16 |

17 | 18 | ## Installation 19 | 20 | The code requires `python>=3.8`, as well as `pytorch>=1.7` and `torchvision>=0.8`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended. 21 | 22 | Install Segment Anything: 23 | 24 | ``` 25 | pip install git+https://github.com/facebookresearch/segment-anything.git 26 | ``` 27 | 28 | or clone the repository locally and install with 29 | 30 | ``` 31 | git clone git@github.com:facebookresearch/segment-anything.git 32 | cd segment-anything; pip install -e . 33 | ``` 34 | 35 | The following optional dependencies are necessary for mask post-processing, saving masks in COCO format, the example notebooks, and exporting the model in ONNX format. `jupyter` is also required to run the example notebooks. 36 | ``` 37 | pip install opencv-python pycocotools matplotlib onnxruntime onnx 38 | ``` 39 | 40 | 41 | ## Getting Started 42 | 43 | First download a [model checkpoint](#model-checkpoints). Then the model can be used in just a few lines to get masks from a given prompt: 44 | 45 | ``` 46 | from segment_anything import build_sam, SamPredictor 47 | predictor = SamPredictor(build_sam(checkpoint="")) 48 | predictor.set_image() 49 | masks, _, _ = predictor.predict() 50 | ``` 51 | 52 | or generate masks for an entire image: 53 | 54 | ``` 55 | from segment_anything import build_sam, SamAutomaticMaskGenerator 56 | mask_generator = SamAutomaticMaskGenerator(build_sam(checkpoint="")) 57 | masks = mask_generator_generate() 58 | ``` 59 | 60 | Additionally, masks can be generated for images from the command line: 61 | 62 | ``` 63 | python scripts/amg.py --checkpoint --input --output 64 | ``` 65 | 66 | See the examples notebooks on [using SAM with prompts](/notebooks/predictor_example.ipynb) and [automatically generating masks](/notebooks/automatic_mask_generator_example.ipynb) for more details. 67 | 68 |
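
As a slightly fuller sketch of the two quick-start paths above, the snippet below loads one checkpoint, runs a single box prompt through `SamPredictor`, and then runs `SamAutomaticMaskGenerator` over the same image. The checkpoint path, image path, and box coordinates are illustrative assumptions rather than files or values defined by the repository; see Model Checkpoints below for the actual download links.

```
import cv2
import numpy as np
from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator

checkpoint = "sam_vit_h_4b8939.pth"               # assumed local path to the downloaded ViT-H weights
image = cv2.imread("notebooks/images/truck.jpg")  # any image readable by OpenCV
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)    # SAM expects an RGB HxWx3 uint8 array

sam = build_sam(checkpoint=checkpoint)            # build_sam always constructs the ViT-H variant

# Prompted segmentation: one box prompt in absolute (x0, y0, x1, y1) pixel coordinates.
predictor = SamPredictor(sam)
predictor.set_image(image)
masks, scores, _ = predictor.predict(
    box=np.array([100, 100, 600, 400]),           # made-up box, replace with your own prompt
    multimask_output=True,
)
print(masks.shape, scores)                        # (3, H, W) boolean masks and their quality scores

# Fully automatic segmentation of the whole image.
mask_generator = SamAutomaticMaskGenerator(sam)
annotations = mask_generator.generate(image)
print(len(annotations), sorted(annotations[0].keys()))
```

For the smaller `vit_l` and `vit_b` backbones, construct the model through `sam_model_registry` instead, as shown in the Model Checkpoints section.
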

69 | 70 | 71 |

72 | 73 | ## ONNX Export 74 | 75 | SAM's lightweight mask decoder can be exported to ONNX format so that it can be run in any environment that supports ONNX runtime, such as in-browser as showcased in the [demo](https://segment-anything.com/demo). Export the model with 76 | 77 | ``` 78 | python scripts/export_onnx_model.py --checkpoint --output 79 | ``` 80 | 81 | See the [example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb) for details on how to combine image preprocessing via SAM's backbone with mask prediction using the ONNX model. It is recommended to use the latest stable version of PyTorch for ONNX export. 82 | 83 | ## Model Checkpoints 84 | 85 | Three model versions of the model are available with different backbone sizes. These models can be instantiated by running 86 | ``` 87 | from segment_anything import sam_model_registry 88 | sam = sam_model_registry[""](checkpoint="") 89 | ``` 90 | Click the links below to download the checkpoint for the corresponding model name. The default model in bold can also be instantiated with `build_sam`, as in the examples in [Getting Started](#getting-started). 91 | 92 | * **`default` or `vit_h`: [ViT-H SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)** 93 | * `vit_l`: [ViT-L SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth) 94 | * `vit_b`: [ViT-B SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) 95 | 96 | ## License 97 | The model is licensed under the [Apache 2.0 license](LICENSE). 98 | 99 | ## Contributing 100 | 101 | See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). 102 | 103 | ## Contributors 104 | 105 | The Segment Anything project was made possible with the help of many contributors (alphabetical): 106 | 107 | Aaron Adcock, Vaibhav Aggarwal, Morteza Behrooz, Cheng-Yang Fu, Ashley Gabriel, Ahuva Goldstand, Allen Goodman, Sumanth Gurram, Jiabo Hu, Somya Jain, Devansh Kukreja, Robert Kuo, Joshua Lane, Yanghao Li, Lilian Luong, Jitendra Malik, Mallika Malhotra, William Ngan, Omkar Parkhi, Nikhil Raina, Dirk Rowe, Neil Sejoor, Vanessa Stark, Bala Varadarajan, Bram Wasti, Zachary Winstrom 108 | -------------------------------------------------------------------------------- /segment_anything/assets/masks1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/assets/masks1.png -------------------------------------------------------------------------------- /segment_anything/assets/masks2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/assets/masks2.jpg -------------------------------------------------------------------------------- /segment_anything/assets/model_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/assets/model_diagram.png -------------------------------------------------------------------------------- /segment_anything/assets/notebook1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/assets/notebook1.png -------------------------------------------------------------------------------- /segment_anything/assets/notebook2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/assets/notebook2.png -------------------------------------------------------------------------------- /segment_anything/linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | { 5 | black --version | grep -E "23\." > /dev/null 6 | } || { 7 | echo "Linter requires 'black==23.*' !" 8 | exit 1 9 | } 10 | 11 | ISORT_VERSION=$(isort --version-number) 12 | if [[ "$ISORT_VERSION" != 5.12* ]]; then 13 | echo "Linter requires isort==5.12.0 !" 14 | exit 1 15 | fi 16 | 17 | echo "Running isort ..." 18 | isort . --atomic 19 | 20 | echo "Running black ..." 21 | black -l 100 . 22 | 23 | echo "Running flake8 ..." 24 | if [ -x "$(command -v flake8)" ]; then 25 | flake8 . 26 | else 27 | python3 -m flake8 . 28 | fi 29 | 30 | echo "Running mypy..." 31 | 32 | mypy --exclude 'setup.py|notebooks' . 33 | -------------------------------------------------------------------------------- /segment_anything/notebooks/images/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/notebooks/images/dog.jpg -------------------------------------------------------------------------------- /segment_anything/notebooks/images/groceries.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/notebooks/images/groceries.jpg -------------------------------------------------------------------------------- /segment_anything/notebooks/images/truck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeanFly/Grounded-Segment-Anything-API/030e3cc5910546686e4749d75c905e2040a59b7e/segment_anything/notebooks/images/truck.jpg -------------------------------------------------------------------------------- /segment_anything/scripts/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import cv2 # type: ignore 8 | 9 | from segment_anything import SamAutomaticMaskGenerator, sam_model_registry 10 | 11 | import argparse 12 | import json 13 | import os 14 | from typing import Any, Dict, List 15 | 16 | parser = argparse.ArgumentParser( 17 | description=( 18 | "Runs automatic mask generation on an input image or directory of images, " 19 | "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, " 20 | "as well as pycocotools if saving in RLE format." 
21 | ) 22 | ) 23 | 24 | parser.add_argument( 25 | "--input", 26 | type=str, 27 | required=True, 28 | help="Path to either a single input image or folder of images.", 29 | ) 30 | 31 | parser.add_argument( 32 | "--output", 33 | type=str, 34 | required=True, 35 | help=( 36 | "Path to the directory where masks will be output. Output will be either a folder " 37 | "of PNGs per image or a single json with COCO-style masks." 38 | ), 39 | ) 40 | 41 | parser.add_argument( 42 | "--model-type", 43 | type=str, 44 | default="default", 45 | help="The type of model to load, in ['default', 'vit_l', 'vit_b']", 46 | ) 47 | 48 | parser.add_argument( 49 | "--checkpoint", 50 | type=str, 51 | required=True, 52 | help="The path to the SAM checkpoint to use for mask generation.", 53 | ) 54 | 55 | parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.") 56 | 57 | parser.add_argument( 58 | "--convert-to-rle", 59 | action="store_true", 60 | help=( 61 | "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. " 62 | "Requires pycocotools." 63 | ), 64 | ) 65 | 66 | amg_settings = parser.add_argument_group("AMG Settings") 67 | 68 | amg_settings.add_argument( 69 | "--points-per-side", 70 | type=int, 71 | default=None, 72 | help="Generate masks by sampling a grid over the image with this many points to a side.", 73 | ) 74 | 75 | amg_settings.add_argument( 76 | "--points-per-batch", 77 | type=int, 78 | default=None, 79 | help="How many input points to process simultaneously in one batch.", 80 | ) 81 | 82 | amg_settings.add_argument( 83 | "--pred-iou-thresh", 84 | type=float, 85 | default=None, 86 | help="Exclude masks with a predicted score from the model that is lower than this threshold.", 87 | ) 88 | 89 | amg_settings.add_argument( 90 | "--stability-score-thresh", 91 | type=float, 92 | default=None, 93 | help="Exclude masks with a stability score lower than this threshold.", 94 | ) 95 | 96 | amg_settings.add_argument( 97 | "--stability-score-offset", 98 | type=float, 99 | default=None, 100 | help="Larger values perturb the mask more when measuring stability score.", 101 | ) 102 | 103 | amg_settings.add_argument( 104 | "--box-nms-thresh", 105 | type=float, 106 | default=None, 107 | help="The overlap threshold for excluding a duplicate mask.", 108 | ) 109 | 110 | amg_settings.add_argument( 111 | "--crop-n-layers", 112 | type=int, 113 | default=None, 114 | help=( 115 | "If >0, mask generation is run on smaller crops of the image to generate more masks. " 116 | "The value sets how many different scales to crop at." 117 | ), 118 | ) 119 | 120 | amg_settings.add_argument( 121 | "--crop-nms-thresh", 122 | type=float, 123 | default=None, 124 | help="The overlap threshold for excluding duplicate masks across different crops.", 125 | ) 126 | 127 | amg_settings.add_argument( 128 | "--crop-overlap-ratio", 129 | type=int, 130 | default=None, 131 | help="Larger numbers mean image crops will overlap more.", 132 | ) 133 | 134 | amg_settings.add_argument( 135 | "--crop-n-points-downscale-factor", 136 | type=int, 137 | default=None, 138 | help="The number of points-per-side in each layer of crop is reduced by this factor.", 139 | ) 140 | 141 | amg_settings.add_argument( 142 | "--min-mask-region-area", 143 | type=int, 144 | default=None, 145 | help=( 146 | "Disconnected mask regions or holes with area smaller than this value " 147 | "in pixels are removed by postprocessing." 
148 | ), 149 | ) 150 | 151 | 152 | def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None: 153 | header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa 154 | metadata = [header] 155 | for i, mask_data in enumerate(masks): 156 | mask = mask_data["segmentation"] 157 | filename = f"{i}.png" 158 | cv2.imwrite(os.path.join(path, filename), mask * 255) 159 | mask_metadata = [ 160 | str(i), 161 | str(mask_data["area"]), 162 | *[str(x) for x in mask_data["bbox"]], 163 | *[str(x) for x in mask_data["point_coords"][0]], 164 | str(mask_data["predicted_iou"]), 165 | str(mask_data["stability_score"]), 166 | *[str(x) for x in mask_data["crop_box"]], 167 | ] 168 | row = ",".join(mask_metadata) 169 | metadata.append(row) 170 | metadata_path = os.path.join(path, "metadata.csv") 171 | with open(metadata_path, "w") as f: 172 | f.write("\n".join(metadata)) 173 | 174 | return 175 | 176 | 177 | def get_amg_kwargs(args): 178 | amg_kwargs = { 179 | "points_per_side": args.points_per_side, 180 | "points_per_batch": args.points_per_batch, 181 | "pred_iou_thresh": args.pred_iou_thresh, 182 | "stability_score_thresh": args.stability_score_thresh, 183 | "stability_score_offset": args.stability_score_offset, 184 | "box_nms_thresh": args.box_nms_thresh, 185 | "crop_n_layers": args.crop_n_layers, 186 | "crop_nms_thresh": args.crop_nms_thresh, 187 | "crop_overlap_ratio": args.crop_overlap_ratio, 188 | "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor, 189 | "min_mask_region_area": args.min_mask_region_area, 190 | } 191 | amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None} 192 | return amg_kwargs 193 | 194 | 195 | def main(args: argparse.Namespace) -> None: 196 | print("Loading model...") 197 | sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint) 198 | _ = sam.to(device=args.device) 199 | output_mode = "coco_rle" if args.convert_to_rle else "binary_mask" 200 | amg_kwargs = get_amg_kwargs(args) 201 | generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs) 202 | 203 | if not os.path.isdir(args.input): 204 | targets = [args.input] 205 | else: 206 | targets = [ 207 | f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f)) 208 | ] 209 | targets = [os.path.join(args.input, f) for f in targets] 210 | 211 | os.makedirs(args.output, exist_ok=True) 212 | 213 | for t in targets: 214 | print(f"Processing '{t}'...") 215 | image = cv2.imread(t) 216 | if image is None: 217 | print(f"Could not load '{t}' as an image, skipping...") 218 | continue 219 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 220 | 221 | masks = generator.generate(image) 222 | 223 | base = os.path.basename(t) 224 | base = os.path.splitext(base)[0] 225 | save_base = os.path.join(args.output, base) 226 | if output_mode == "binary_mask": 227 | os.makedirs(save_base, exist_ok=False) 228 | write_masks_to_folder(masks, save_base) 229 | else: 230 | save_file = save_base + ".json" 231 | with open(save_file, "w") as f: 232 | json.dump(masks, f) 233 | print("Done!") 234 | 235 | 236 | if __name__ == "__main__": 237 | args = parser.parse_args() 238 | main(args) 239 | -------------------------------------------------------------------------------- /segment_anything/scripts/export_onnx_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from segment_anything import build_sam, build_sam_vit_b, build_sam_vit_l 10 | from segment_anything.utils.onnx import SamOnnxModel 11 | 12 | import argparse 13 | import warnings 14 | 15 | try: 16 | import onnxruntime # type: ignore 17 | 18 | onnxruntime_exists = True 19 | except ImportError: 20 | onnxruntime_exists = False 21 | 22 | parser = argparse.ArgumentParser( 23 | description="Export the SAM prompt encoder and mask decoder to an ONNX model." 24 | ) 25 | 26 | parser.add_argument( 27 | "--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint." 28 | ) 29 | 30 | parser.add_argument( 31 | "--output", type=str, required=True, help="The filename to save the ONNX model to." 32 | ) 33 | 34 | parser.add_argument( 35 | "--model-type", 36 | type=str, 37 | default="default", 38 | help="In ['default', 'vit_b', 'vit_l']. Which type of SAM model to export.", 39 | ) 40 | 41 | parser.add_argument( 42 | "--return-single-mask", 43 | action="store_true", 44 | help=( 45 | "If true, the exported ONNX model will only return the best mask, " 46 | "instead of returning multiple masks. For high resolution images " 47 | "this can improve runtime when upscaling masks is expensive." 48 | ), 49 | ) 50 | 51 | parser.add_argument( 52 | "--opset", 53 | type=int, 54 | default=17, 55 | help="The ONNX opset version to use. Must be >=11", 56 | ) 57 | 58 | parser.add_argument( 59 | "--quantize-out", 60 | type=str, 61 | default=None, 62 | help=( 63 | "If set, will quantize the model and save it with this name. " 64 | "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize." 65 | ), 66 | ) 67 | 68 | parser.add_argument( 69 | "--gelu-approximate", 70 | action="store_true", 71 | help=( 72 | "Replace GELU operations with approximations using tanh. Useful " 73 | "for some runtimes that have slow or unimplemented erf ops, used in GELU." 74 | ), 75 | ) 76 | 77 | parser.add_argument( 78 | "--use-stability-score", 79 | action="store_true", 80 | help=( 81 | "Replaces the model's predicted mask quality score with the stability " 82 | "score calculated on the low resolution masks using an offset of 1.0. " 83 | ), 84 | ) 85 | 86 | parser.add_argument( 87 | "--return-extra-metrics", 88 | action="store_true", 89 | help=( 90 | "The model will return five results: (masks, scores, stability_scores, " 91 | "areas, low_res_logits) instead of the usual three. This can be " 92 | "significantly slower for high resolution outputs." 
93 | ), 94 | ) 95 | 96 | 97 | def run_export( 98 | model_type: str, 99 | checkpoint: str, 100 | output: str, 101 | opset: int, 102 | return_single_mask: bool, 103 | gelu_approximate: bool = False, 104 | use_stability_score: bool = False, 105 | return_extra_metrics=False, 106 | ): 107 | print("Loading model...") 108 | if model_type == "vit_b": 109 | sam = build_sam_vit_b(checkpoint) 110 | elif model_type == "vit_l": 111 | sam = build_sam_vit_l(checkpoint) 112 | else: 113 | sam = build_sam(checkpoint) 114 | 115 | onnx_model = SamOnnxModel( 116 | model=sam, 117 | return_single_mask=return_single_mask, 118 | use_stability_score=use_stability_score, 119 | return_extra_metrics=return_extra_metrics, 120 | ) 121 | 122 | if gelu_approximate: 123 | for n, m in onnx_model.named_modules(): 124 | if isinstance(m, torch.nn.GELU): 125 | m.approximate = "tanh" 126 | 127 | dynamic_axes = { 128 | "point_coords": {1: "num_points"}, 129 | "point_labels": {1: "num_points"}, 130 | } 131 | 132 | embed_dim = sam.prompt_encoder.embed_dim 133 | embed_size = sam.prompt_encoder.image_embedding_size 134 | mask_input_size = [4 * x for x in embed_size] 135 | dummy_inputs = { 136 | "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float), 137 | "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float), 138 | "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float), 139 | "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float), 140 | "has_mask_input": torch.tensor([1], dtype=torch.float), 141 | "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float), 142 | } 143 | 144 | _ = onnx_model(**dummy_inputs) 145 | 146 | output_names = ["masks", "iou_predictions", "low_res_masks"] 147 | 148 | with warnings.catch_warnings(): 149 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) 150 | warnings.filterwarnings("ignore", category=UserWarning) 151 | with open(output, "wb") as f: 152 | print(f"Exporing onnx model to {output}...") 153 | torch.onnx.export( 154 | onnx_model, 155 | tuple(dummy_inputs.values()), 156 | f, 157 | export_params=True, 158 | verbose=False, 159 | opset_version=opset, 160 | do_constant_folding=True, 161 | input_names=list(dummy_inputs.keys()), 162 | output_names=output_names, 163 | dynamic_axes=dynamic_axes, 164 | ) 165 | 166 | if onnxruntime_exists: 167 | ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()} 168 | ort_session = onnxruntime.InferenceSession(output) 169 | _ = ort_session.run(None, ort_inputs) 170 | print("Model has successfully been run with ONNXRuntime.") 171 | 172 | 173 | def to_numpy(tensor): 174 | return tensor.cpu().numpy() 175 | 176 | 177 | if __name__ == "__main__": 178 | args = parser.parse_args() 179 | run_export( 180 | model_type=args.model_type, 181 | checkpoint=args.checkpoint, 182 | output=args.output, 183 | opset=args.opset, 184 | return_single_mask=args.return_single_mask, 185 | gelu_approximate=args.gelu_approximate, 186 | use_stability_score=args.use_stability_score, 187 | return_extra_metrics=args.return_extra_metrics, 188 | ) 189 | 190 | if args.quantize_out is not None: 191 | assert onnxruntime_exists, "onnxruntime is required to quantize the model." 
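
Once `run_export` above has written the `.onnx` file, a minimal consumer looks like the sketch below. It assumes the export was run with default settings, so the input names, output order, and the 1500x2250 original-image size mirror the `dummy_inputs` dictionary in `run_export`; the file name, the random embedding, and the click coordinates are illustrative assumptions. In real use the image embedding comes from SAM's image encoder (e.g. `SamPredictor.set_image` followed by `get_image_embedding`) rather than from random numbers.

```
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("sam_onnx_example.onnx")  # assumed --output path

# One foreground click plus the padding point/label (-1) convention used in the ONNX example notebook.
point_coords = np.array([[[500.0, 375.0], [0.0, 0.0]]], dtype=np.float32)
point_labels = np.array([[1.0, -1.0]], dtype=np.float32)

ort_inputs = {
    "image_embeddings": np.random.randn(1, 256, 64, 64).astype(np.float32),  # stand-in embedding
    "point_coords": point_coords,
    "point_labels": point_labels,
    "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
    "has_mask_input": np.zeros(1, dtype=np.float32),
    "orig_im_size": np.array([1500, 2250], dtype=np.float32),
}
masks, iou_predictions, low_res_masks = session.run(None, ort_inputs)
print(masks.shape, iou_predictions.shape, low_res_masks.shape)
```
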
192 | from onnxruntime.quantization import QuantType # type: ignore 193 | from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore 194 | 195 | print(f"Quantizing model and writing to {args.quantize_out}...") 196 | quantize_dynamic( 197 | model_input=args.output, 198 | model_output=args.quantize_out, 199 | optimize_model=True, 200 | per_channel=False, 201 | reduce_range=False, 202 | weight_type=QuantType.QUInt8, 203 | ) 204 | print("Done!") 205 | -------------------------------------------------------------------------------- /segment_anything/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /segment_anything/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam, 49 | "vit_h": build_sam, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | 
input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /segment_anything/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /segment_anything/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /segment_anything/segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
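
Tying together the pieces defined in build_sam.py above: `sam_model_registry` maps a backbone name to a builder, and `_build_sam` returns the assembled `Sam` module in eval mode, loading weights only when a checkpoint path is given. The sketch below builds a randomly initialised ViT-B variant, so it runs without downloading anything; the commented checkpoint file name is an assumption matching the README's Model Checkpoints section.

```
from segment_anything import sam_model_registry
from segment_anything.modeling import Sam

# checkpoint=None -> random weights, which is enough to inspect the architecture.
sam: Sam = sam_model_registry["vit_b"](checkpoint=None)

print(type(sam.image_encoder).__name__)   # ImageEncoderViT
print(type(sam.prompt_encoder).__name__)  # PromptEncoder
print(type(sam.mask_decoder).__name__)    # MaskDecoder
print(f"{sum(p.numel() for p in sam.parameters()) / 1e6:.1f}M parameters")

# With real weights the model can be handed straight to SamPredictor or SamAutomaticMaskGenerator:
# sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
```
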
6 | 
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 | 
11 | from typing import List, Tuple, Type
12 | 
13 | from .common import LayerNorm2d
14 | 
15 | 
16 | class MaskDecoder(nn.Module):
17 |     def __init__(
18 |         self,
19 |         *,
20 |         transformer_dim: int,
21 |         transformer: nn.Module,
22 |         num_multimask_outputs: int = 3,
23 |         activation: Type[nn.Module] = nn.GELU,
24 |         iou_head_depth: int = 3,
25 |         iou_head_hidden_dim: int = 256,
26 |     ) -> None:
27 |         """
28 |         Predicts masks given an image and prompt embeddings, using a
29 |         transformer architecture.
30 | 
31 |         Arguments:
32 |           transformer_dim (int): the channel dimension of the transformer
33 |           transformer (nn.Module): the transformer used to predict masks
34 |           num_multimask_outputs (int): the number of masks to predict
35 |             when disambiguating masks
36 |           activation (nn.Module): the type of activation to use when
37 |             upscaling masks
38 |           iou_head_depth (int): the depth of the MLP used to predict
39 |             mask quality
40 |           iou_head_hidden_dim (int): the hidden dimension of the MLP
41 |             used to predict mask quality
42 |         """
43 |         super().__init__()
44 |         self.transformer_dim = transformer_dim
45 |         self.transformer = transformer
46 | 
47 |         self.num_multimask_outputs = num_multimask_outputs
48 | 
49 |         self.iou_token = nn.Embedding(1, transformer_dim)
50 |         self.num_mask_tokens = num_multimask_outputs + 1
51 |         self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52 | 
53 |         self.output_upscaling = nn.Sequential(
54 |             nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
55 |             LayerNorm2d(transformer_dim // 4),
56 |             activation(),
57 |             nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
58 |             activation(),
59 |         )
60 |         self.output_hypernetworks_mlps = nn.ModuleList(
61 |             [
62 |                 MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
63 |                 for i in range(self.num_mask_tokens)
64 |             ]
65 |         )
66 | 
67 |         self.iou_prediction_head = MLP(
68 |             transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
69 |         )
70 | 
71 |     def forward(
72 |         self,
73 |         image_embeddings: torch.Tensor,
74 |         image_pe: torch.Tensor,
75 |         sparse_prompt_embeddings: torch.Tensor,
76 |         dense_prompt_embeddings: torch.Tensor,
77 |         multimask_output: bool,
78 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
79 |         """
80 |         Predict masks given image and prompt embeddings.
81 | 
82 |         Arguments:
83 |           image_embeddings (torch.Tensor): the embeddings from the image encoder
84 |           image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
85 |           sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
86 |           dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
87 |           multimask_output (bool): Whether to return multiple masks or a single
88 |             mask.
89 | 
90 |         Returns:
91 |           torch.Tensor: batched predicted masks
92 |           torch.Tensor: batched predictions of mask quality
93 |         """
94 |         masks, iou_pred = self.predict_masks(
95 |             image_embeddings=image_embeddings,
96 |             image_pe=image_pe,
97 |             sparse_prompt_embeddings=sparse_prompt_embeddings,
98 |             dense_prompt_embeddings=dense_prompt_embeddings,
99 |         )
100 | 
101 |         # Select the correct mask or masks for output
102 |         if multimask_output:
103 |             mask_slice = slice(1, None)
104 |         else:
105 |             mask_slice = slice(0, 1)
106 |         masks = masks[:, mask_slice, :, :]
107 |         iou_pred = iou_pred[:, mask_slice]
108 | 
109 |         # Prepare output
110 |         return masks, iou_pred
111 | 
112 |     def predict_masks(
113 |         self,
114 |         image_embeddings: torch.Tensor,
115 |         image_pe: torch.Tensor,
116 |         sparse_prompt_embeddings: torch.Tensor,
117 |         dense_prompt_embeddings: torch.Tensor,
118 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
119 |         """Predicts masks. See 'forward' for more details."""
120 |         # Concatenate output tokens
121 |         output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
122 |         output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
123 |         tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
124 | 
125 |         # Expand per-image data in batch direction to be per-mask
126 |         src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
127 |         src = src + dense_prompt_embeddings
128 |         pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
129 |         b, c, h, w = src.shape
130 | 
131 |         # Run the transformer
132 |         hs, src = self.transformer(src, pos_src, tokens)
133 |         iou_token_out = hs[:, 0, :]
134 |         mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
135 | 
136 |         # Upscale mask embeddings and predict masks using the mask tokens
137 |         src = src.transpose(1, 2).view(b, c, h, w)
138 |         upscaled_embedding = self.output_upscaling(src)
139 |         hyper_in_list: List[torch.Tensor] = []
140 |         for i in range(self.num_mask_tokens):
141 |             hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
142 |         hyper_in = torch.stack(hyper_in_list, dim=1)
143 |         b, c, h, w = upscaled_embedding.shape
144 |         masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
145 | 
146 |         # Generate mask quality predictions
147 |         iou_pred = self.iou_prediction_head(iou_token_out)
148 | 
149 |         return masks, iou_pred
150 | 
151 | 
152 | # Lightly adapted from
153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
154 | class MLP(nn.Module):
155 |     def __init__(
156 |         self,
157 |         input_dim: int,
158 |         hidden_dim: int,
159 |         output_dim: int,
160 |         num_layers: int,
161 |         sigmoid_output: bool = False,
162 |     ) -> None:
163 |         super().__init__()
164 |         self.num_layers = num_layers
165 |         h = [hidden_dim] * (num_layers - 1)
166 |         self.layers = nn.ModuleList(
167 |             nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
168 |         )
169 |         self.sigmoid_output = sigmoid_output
170 | 
171 |     def forward(self, x):
172 |         for i, layer in enumerate(self.layers):
173 |             x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
174 |         if self.sigmoid_output:
175 |             x = F.sigmoid(x)
176 |         return x
177 | 
-------------------------------------------------------------------------------- /segment_anything/segment_anything/modeling/prompt_encoder.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 
66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 
148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /segment_anything/segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 | 
11 | from typing import Any, Dict, List, Tuple
12 | 
13 | from .image_encoder import ImageEncoderViT
14 | from .mask_decoder import MaskDecoder
15 | from .prompt_encoder import PromptEncoder
16 | 
17 | 
18 | class Sam(nn.Module):
19 |     mask_threshold: float = 0.0
20 |     image_format: str = "RGB"
21 | 
22 |     def __init__(
23 |         self,
24 |         image_encoder: ImageEncoderViT,
25 |         prompt_encoder: PromptEncoder,
26 |         mask_decoder: MaskDecoder,
27 |         pixel_mean: List[float] = [123.675, 116.28, 103.53],
28 |         pixel_std: List[float] = [58.395, 57.12, 57.375],
29 |     ) -> None:
30 |         """
31 |         SAM predicts object masks from an image and input prompts.
32 | 
33 |         Arguments:
34 |           image_encoder (ImageEncoderViT): The backbone used to encode the
35 |             image into image embeddings that allow for efficient mask prediction.
36 |           prompt_encoder (PromptEncoder): Encodes various types of input prompts.
37 |           mask_decoder (MaskDecoder): Predicts masks from the image embeddings
38 |             and encoded prompts.
39 |           pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
40 |           pixel_std (list(float)): Std values for normalizing pixels in the input image.
41 |         """
42 |         super().__init__()
43 |         self.image_encoder = image_encoder
44 |         self.prompt_encoder = prompt_encoder
45 |         self.mask_decoder = mask_decoder
46 |         self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
47 |         self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
48 | 
49 |     @property
50 |     def device(self) -> Any:
51 |         return self.pixel_mean.device
52 | 
53 |     @torch.no_grad()
54 |     def forward(
55 |         self,
56 |         batched_input: List[Dict[str, Any]],
57 |         multimask_output: bool,
58 |     ) -> List[Dict[str, torch.Tensor]]:
59 |         """
60 |         Predicts masks end-to-end from provided images and prompts.
61 |         If prompts are not known in advance, using SamPredictor is
62 |         recommended over calling the model directly.
63 | 
64 |         Arguments:
65 |           batched_input (list(dict)): A list over input images, each a
66 |             dictionary with the following keys. A prompt key can be
67 |             excluded if it is not present.
68 |               'image': The image as a torch tensor in 3xHxW format,
69 |                 already transformed for input to the model.
70 |               'original_size': (tuple(int, int)) The original size of
71 |                 the image before transformation, as (H, W).
72 |               'point_coords': (torch.Tensor) Batched point prompts for
73 |                 this image, with shape BxNx2. Already transformed to the
74 |                 input frame of the model.
75 |               'point_labels': (torch.Tensor) Batched labels for point prompts,
76 |                 with shape BxN.
77 |               'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
78 |                 Already transformed to the input frame of the model.
79 |               'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
80 |                 in the form Bx1xHxW.
81 |           multimask_output (bool): Whether the model should predict multiple
82 |             disambiguating masks, or return a single mask.
83 | 
84 |         Returns:
85 |           (list(dict)): A list over input images, where each element is
86 |             a dictionary with the following keys.
87 |               'masks': (torch.Tensor) Batched binary mask predictions,
88 |                 with shape BxCxHxW, where B is the number of input prompts,
89 |                 C is determined by multimask_output, and (H, W) is the
90 |                 original size of the image.
91 |               'iou_predictions': (torch.Tensor) The model's predictions
92 |                 of mask quality, in shape BxC.
93 |               'low_res_logits': (torch.Tensor) Low resolution logits with
94 |                 shape BxCxHxW, where H=W=256.
Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /segment_anything/segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 
77 |         Returns:
78 |           torch.Tensor: the processed point_embedding
79 |           torch.Tensor: the processed image_embedding
80 |         """
81 |         # BxCxHxW -> BxHWxC == B x N_image_tokens x C
82 |         bs, c, h, w = image_embedding.shape
83 |         image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
84 |         image_pe = image_pe.flatten(2).permute(0, 2, 1)
85 | 
86 |         # Prepare queries
87 |         queries = point_embedding
88 |         keys = image_embedding
89 | 
90 |         # Apply transformer blocks and final layernorm
91 |         for layer in self.layers:
92 |             queries, keys = layer(
93 |                 queries=queries,
94 |                 keys=keys,
95 |                 query_pe=point_embedding,
96 |                 key_pe=image_pe,
97 |             )
98 | 
99 |         # Apply the final attention layer from the points to the image
100 |         q = queries + point_embedding
101 |         k = keys + image_pe
102 |         attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
103 |         queries = queries + attn_out
104 |         queries = self.norm_final_attn(queries)
105 | 
106 |         return queries, keys
107 | 
108 | 
109 | class TwoWayAttentionBlock(nn.Module):
110 |     def __init__(
111 |         self,
112 |         embedding_dim: int,
113 |         num_heads: int,
114 |         mlp_dim: int = 2048,
115 |         activation: Type[nn.Module] = nn.ReLU,
116 |         attention_downsample_rate: int = 2,
117 |         skip_first_layer_pe: bool = False,
118 |     ) -> None:
119 |         """
120 |         A transformer block with four layers: (1) self-attention of sparse
121 |         inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
122 |         block on sparse inputs, and (4) cross attention of dense inputs to sparse
123 |         inputs.
124 | 
125 |         Arguments:
126 |           embedding_dim (int): the channel dimension of the embeddings
127 |           num_heads (int): the number of heads in the attention layers
128 |           mlp_dim (int): the hidden dimension of the mlp block
129 |           activation (nn.Module): the activation of the mlp block
130 |           skip_first_layer_pe (bool): skip the PE on the first layer
131 |         """
132 |         super().__init__()
133 |         self.self_attn = Attention(embedding_dim, num_heads)
134 |         self.norm1 = nn.LayerNorm(embedding_dim)
135 | 
136 |         self.cross_attn_token_to_image = Attention(
137 |             embedding_dim, num_heads, downsample_rate=attention_downsample_rate
138 |         )
139 |         self.norm2 = nn.LayerNorm(embedding_dim)
140 | 
141 |         self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
142 |         self.norm3 = nn.LayerNorm(embedding_dim)
143 | 
144 |         self.norm4 = nn.LayerNorm(embedding_dim)
145 |         self.cross_attn_image_to_token = Attention(
146 |             embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147 |         )
148 | 
149 |         self.skip_first_layer_pe = skip_first_layer_pe
150 | 
151 |     def forward(
152 |         self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
153 |     ) -> Tuple[Tensor, Tensor]:
154 |         # Self attention block
155 |         if self.skip_first_layer_pe:
156 |             queries = self.self_attn(q=queries, k=queries, v=queries)
157 |         else:
158 |             q = queries + query_pe
159 |             attn_out = self.self_attn(q=q, k=q, v=queries)
160 |             queries = queries + attn_out
161 |         queries = self.norm1(queries)
162 | 
163 |         # Cross attention block, tokens attending to image embedding
164 |         q = queries + query_pe
165 |         k = keys + key_pe
166 |         attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
167 |         queries = queries + attn_out
168 |         queries = self.norm2(queries)
169 | 
170 |         # MLP block
171 |         mlp_out = self.mlp(queries)
172 |         queries = queries + mlp_out
173 |         queries = self.norm3(queries)
174 | 
175 |         # Cross attention block, image embedding attending to tokens
176 |         q = queries + query_pe
177 |         k = keys + key_pe
178 |         attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
179 |         keys = keys + attn_out
180 |         keys = self.norm4(keys)
181 | 
182 |         return queries, keys
183 | 
184 | 
185 | class Attention(nn.Module):
186 |     """
187 |     An attention layer that allows for downscaling the size of the embedding
188 |     after projection to queries, keys, and values.
189 |     """
190 | 
191 |     def __init__(
192 |         self,
193 |         embedding_dim: int,
194 |         num_heads: int,
195 |         downsample_rate: int = 1,
196 |     ) -> None:
197 |         super().__init__()
198 |         self.embedding_dim = embedding_dim
199 |         self.internal_dim = embedding_dim // downsample_rate
200 |         self.num_heads = num_heads
201 |         assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
202 | 
203 |         self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
204 |         self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
205 |         self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
206 |         self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
207 | 
208 |     def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
209 |         b, n, c = x.shape
210 |         x = x.reshape(b, n, num_heads, c // num_heads)
211 |         return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
212 | 
213 |     def _recombine_heads(self, x: Tensor) -> Tensor:
214 |         b, n_heads, n_tokens, c_per_head = x.shape
215 |         x = x.transpose(1, 2)
216 |         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
217 | 
218 |     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
219 |         # Input projections
220 |         q = self.q_proj(q)
221 |         k = self.k_proj(k)
222 |         v = self.v_proj(v)
223 | 
224 |         # Separate into heads
225 |         q = self._separate_heads(q, self.num_heads)
226 |         k = self._separate_heads(k, self.num_heads)
227 |         v = self._separate_heads(v, self.num_heads)
228 | 
229 |         # Attention
230 |         _, _, _, c_per_head = q.shape
231 |         attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
232 |         attn = attn / math.sqrt(c_per_head)
233 |         attn = torch.softmax(attn, dim=-1)
234 | 
235 |         # Get output
236 |         out = attn @ v
237 |         out = self._recombine_heads(out)
238 |         out = self.out_proj(out)
239 | 
240 |         return out
241 | 
-------------------------------------------------------------------------------- /segment_anything/segment_anything/utils/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
-------------------------------------------------------------------------------- /segment_anything/segment_anything/utils/onnx.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import functional as F
10 | 
11 | from typing import Tuple
12 | 
13 | from ..modeling import Sam
14 | from .amg import calculate_stability_score
15 | 
16 | 
17 | class SamOnnxModel(nn.Module):
18 |     """
19 |     This model should not be called directly, but is used in ONNX export.
20 |     It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21 |     with some functions modified to enable model tracing. Also supports extra
22 |     options controlling what information is returned. See the ONNX export script for details.
23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size) 85 | masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])] 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 
97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /segment_anything/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
37 |         """
38 |         old_h, old_w = original_size
39 |         new_h, new_w = self.get_preprocess_shape(
40 |             original_size[0], original_size[1], self.target_length
41 |         )
42 |         coords = deepcopy(coords).astype(float)
43 |         coords[..., 0] = coords[..., 0] * (new_w / old_w)
44 |         coords[..., 1] = coords[..., 1] * (new_h / old_h)
45 |         return coords
46 | 
47 |     def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
48 |         """
49 |         Expects a numpy array shape Bx4. Requires the original image size
50 |         in (H, W) format.
51 |         """
52 |         boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
53 |         return boxes.reshape(-1, 4)
54 | 
55 |     def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
56 |         """
57 |         Expects batched images with shape BxCxHxW and float format. This
58 |         transformation may not exactly match apply_image. apply_image is
59 |         the transformation expected by the model.
60 |         """
61 |         # Expects an image in BCHW format. May not exactly match apply_image.
62 |         target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
63 |         return F.interpolate(
64 |             image, target_size, mode="bilinear", align_corners=False, antialias=True
65 |         )
66 | 
67 |     def apply_coords_torch(
68 |         self, coords: torch.Tensor, original_size: Tuple[int, ...]
69 |     ) -> torch.Tensor:
70 |         """
71 |         Expects a torch tensor with length 2 in the last dimension. Requires the
72 |         original image size in (H, W) format.
73 |         """
74 |         old_h, old_w = original_size
75 |         new_h, new_w = self.get_preprocess_shape(
76 |             original_size[0], original_size[1], self.target_length
77 |         )
78 |         coords = deepcopy(coords).to(torch.float)
79 |         coords[..., 0] = coords[..., 0] * (new_w / old_w)
80 |         coords[..., 1] = coords[..., 1] * (new_h / old_h)
81 |         return coords
82 | 
83 |     def apply_boxes_torch(
84 |         self, boxes: torch.Tensor, original_size: Tuple[int, ...]
85 |     ) -> torch.Tensor:
86 |         """
87 |         Expects a torch tensor with shape Bx4. Requires the original image
88 |         size in (H, W) format.
89 |         """
90 |         boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
91 |         return boxes.reshape(-1, 4)
92 | 
93 |     @staticmethod
94 |     def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
95 |         """
96 |         Compute the output size given input size and target long side length.
97 |         """
98 |         scale = long_side_length * 1.0 / max(oldh, oldw)
99 |         newh, neww = oldh * scale, oldw * scale
100 |         neww = int(neww + 0.5)
101 |         newh = int(newh + 0.5)
102 |         return (newh, neww)
103 | 
-------------------------------------------------------------------------------- /segment_anything/setup.cfg: --------------------------------------------------------------------------------
1 | [isort]
2 | line_length=100
3 | multi_line_output=3
4 | include_trailing_comma=True
5 | known_standard_library=numpy,setuptools
6 | skip_glob=*/__init__.py
7 | known_myself=segment_anything
8 | known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort
9 | no_lines_before=STDLIB,THIRDPARTY
10 | sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER
11 | default_section=FIRSTPARTY
12 | 
-------------------------------------------------------------------------------- /segment_anything/setup.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from setuptools import find_packages, setup
8 | 
9 | setup(
10 |     name="segment_anything",
11 |     version="1.0",
12 |     install_requires=[],
13 |     packages=find_packages(exclude=("notebooks",)),
14 |     extras_require={
15 |         "all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime"],
16 |         "dev": ["flake8", "isort", "black", "mypy"],
17 |     },
18 | )
19 | 
--------------------------------------------------------------------------------
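
Usage sketch (not part of the packaged sources above): the segment_anything package is normally driven through sam_model_registry and SamPredictor, both re-exported from segment_anything/__init__.py. The snippet below is a minimal, illustrative example of that flow under stated assumptions: the checkpoint filename ("sam_vit_h_4b8939.pth") and the image path ("truck.jpg") are placeholders, and the predict() call follows the upstream SamPredictor interface, whose source (predictor.py) is not reproduced in this section.

# Minimal usage sketch for the segment_anything package above.
# Assumptions: a ViT-H checkpoint saved locally as "sam_vit_h_4b8939.pth" and an
# RGB image at "truck.jpg"; both paths are placeholders, not files from this dump.
import cv2
import numpy as np
import torch

from segment_anything import SamPredictor, sam_model_registry

device = "cuda" if torch.cuda.is_available() else "cpu"

# sam_model_registry maps "vit_h"/"vit_l"/"vit_b" to the builders in build_sam.py,
# which assemble the image encoder, prompt encoder, and mask decoder and then
# load the checkpoint's state dict.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device)

# SamPredictor embeds the image once so it can be prompted repeatedly; this is
# the path recommended in Sam.forward's docstring when prompts are interactive.
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("truck.jpg"), cv2.COLOR_BGR2RGB)  # HxWx3 uint8, RGB
predictor.set_image(image)

# One positive click (label 1) in original-image pixel coordinates.
point_coords = np.array([[500, 375]])
point_labels = np.array([1])

masks, scores, low_res_logits = predictor.predict(
    point_coords=point_coords,
    point_labels=point_labels,
    multimask_output=True,  # returns num_multimask_outputs (3) candidate masks
)
best_mask = masks[np.argmax(scores)]  # boolean HxW array for the top-scoring candidate
print(best_mask.shape, scores)

For batched, non-interactive use where all prompts are known up front, Sam.forward can instead be called directly with the batched_input dictionaries documented in sam.py.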