├── .gitignore ├── LICENSE ├── README.md ├── demo.py ├── evaluate.py ├── evaluate.sh ├── evaluate_bboxes.py ├── material ├── 1.jpg ├── 2.jpg ├── 3.jpg ├── architecture.jpg ├── preview.png └── qualitative.png ├── models ├── DQE.py ├── __init__.py ├── backbone.py ├── common.py ├── geco.py ├── geco_infer.py ├── matcher.py ├── prompt_encoder.py ├── regression.py ├── sam_ViT.py └── transformer.py ├── pretrain.py ├── pretrain.sh ├── segment_anything ├── LICENSE(1) ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ └── transformer.py ├── predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── onnx.py │ └── transforms.py ├── train.py ├── train.sh └── utils ├── __init__.py ├── arg_parser.py ├── box_ops.py ├── data.py └── losses.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | *.jpg 8 | *.png 9 | *.jpeg 10 | *.pkl 11 | *.json 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | env/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | *.txt 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *,cover 54 | .hypothesis/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # IPython Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # dotenv 87 | .env 88 | 89 | # virtualenv 90 | venv/ 91 | ENV/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | ### VirtualEnv template 99 | # Virtualenv 100 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 101 | [Bb]in 102 | [Ii]nclude 103 | [Ll]ib 104 | [Ll]ib64 105 | [Ll]ocal 106 | [Ss]cripts 107 | pyvenv.cfg 108 | .venv 109 | pip-selfcheck.json 110 | 111 | ### JetBrains template 112 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 113 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 114 | 115 | # User-specific stuff 116 | .idea/**/workspace.xml 117 | .idea/**/tasks.xml 118 | .idea/**/usage.statistics.xml 119 | .idea/**/dictionaries 120 | .idea/**/shelf 121 | 122 | # AWS User-specific 123 | .idea/**/aws.xml 124 | 125 | # Generated files 126 | .idea/**/contentModel.xml 127 | 128 | # Sensitive or high-churn files 129 | .idea/**/dataSources/ 130 | .idea/**/dataSources.ids 131 | .idea/**/dataSources.local.xml 132 | .idea/**/sqlDataSources.xml 133 | 
.idea/**/dynamic.xml 134 | .idea/**/uiDesigner.xml 135 | .idea/**/dbnavigator.xml 136 | 137 | # Gradle 138 | .idea/**/gradle.xml 139 | .idea/**/libraries 140 | imgs 141 | # Gradle and Maven with auto-import 142 | # When using Gradle or Maven with auto-import, you should exclude module files, 143 | # since they will be recreated, and may cause churn. Uncomment if using 144 | # auto-import. 145 | # .idea/artifacts 146 | # .idea/compiler.xml 147 | # .idea/jarRepositories.xml 148 | # .idea/modules.xml 149 | # .idea/*.iml 150 | # .idea/modules 151 | # *.iml 152 | # *.ipr 153 | 154 | # CMake 155 | cmake-build-*/ 156 | 157 | # Mongo Explorer plugin 158 | .idea/**/mongoSettings.xml 159 | 160 | # File-based project format 161 | *.iws 162 | 163 | # IntelliJ 164 | out/ 165 | *.pth 166 | # mpeltonen/sbt-idea plugin 167 | .idea_modules/ 168 | .idea 169 | # JIRA plugin 170 | atlassian-ide-plugin.xml 171 | 172 | # Cursive Clojure plugin 173 | .idea/replstate.xml 174 | 175 | # SonarLint plugin 176 | .idea/sonarlint/ 177 | 178 | # Crashlytics plugin (for Android Studio and IntelliJ) 179 | com_crashlytics_export_strings.xml 180 | crashlytics.properties 181 | crashlytics-build.properties 182 | fabric.properties 183 | 184 | # Editor-based Rest Client 185 | .idea/httpRequests 186 | 187 | # Android studio 3.1+ serialized cache file 188 | .idea/caches/build_file_checksums.ser 189 | 190 | # idea folder, uncomment if you don't need it 191 | # .idea -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 jerpelhan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # GeCo (A Novel Unified Architecture for Low-Shot Counting by Detection and Segmentation)
2 | 
3 | 
4 | 
5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-novel-unified-architecture-for-low-shot/few-shot-object-counting-and-detection-on)](https://paperswithcode.com/sota/few-shot-object-counting-and-detection-on?p=a-novel-unified-architecture-for-low-shot)
6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-novel-unified-architecture-for-low-shot/object-counting-on-fsc147)](https://paperswithcode.com/sota/object-counting-on-fsc147?p=a-novel-unified-architecture-for-low-shot)
7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-novel-unified-architecture-for-low-shot/exemplar-free-counting-on-fsc147)](https://paperswithcode.com/sota/exemplar-free-counting-on-fsc147?p=a-novel-unified-architecture-for-low-shot)
8 | 
9 | 
10 | This repository holds the official PyTorch implementation of the paper [A Novel Unified Architecture for Low-Shot Counting by Detection and Segmentation](https://arxiv.org/pdf/2409.18686), accepted at NeurIPS 2024.
11 | 
12 | 
13 | 
14 | https://github.com/user-attachments/assets/cbdd1fbb-5b07-43c0-95e4-6bf0e5b2896a
15 | 
16 | 
17 | 
18 | 
19 | 
20 | ## Abstract
21 | Low-shot object counters estimate the number of objects in an image using few or no annotated exemplars. Objects are localized by matching them to prototypes, which are constructed by unsupervised image-wide object appearance aggregation. Due to potentially diverse object appearances, the existing approaches often lead to overgeneralization and false positive detections.
22 | Furthermore, the best-performing methods train object localization with a surrogate loss that predicts a unit Gaussian at each object center. This loss is sensitive to annotation errors and hyperparameters, and it does not directly optimize the detection task, leading to suboptimal counts. We introduce GeCo, a novel low-shot counter that achieves accurate object detection, segmentation, and count estimation in a unified architecture. GeCo robustly generalizes the prototypes across object appearances through a novel dense object query formulation. In addition, a novel counting loss is proposed that directly optimizes the detection task and avoids the issues of the standard surrogate loss. GeCo surpasses the leading few-shot detection-based counters by ~25% in total count MAE, achieves superior detection accuracy, and sets a solid new state-of-the-art result across all low-shot counting setups.
23 | ![](material/architecture.jpg)
24 | 
25 | 
26 | ## Quick demo
27 | 
28 | To install the required dependencies, run the following commands:
29 | 
30 | ```bash
31 | conda create -n geco_test python=3.8
32 | conda activate geco_test
33 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
34 | pip install matplotlib
35 | ```
36 | 
37 | To run the demo, you need to download the [pretrained weights](https://drive.google.com/file/d/1wjOF9MWkrVJVo5uG3gVqZEW9pwRq_aIk/view?usp=sharing) and put them in the `MODEL_folder`.
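Before running the demo you can quickly verify that the environment and checkpoint are in place. The snippet below is a minimal sketch; it assumes the weights file is named `GeCo.pth` and sits in the working directory, which is the path `demo.py` loads from (adjust `ckpt_path` if you keep the file in `MODEL_folder`):

```python
# Optional sanity check before running the demo (sketch; assumes the checkpoint
# is saved as GeCo.pth, the filename demo.py loads from the working directory).
import torch

ckpt_path = "GeCo.pth"  # adjust if the weights live in MODEL_folder
print("CUDA available:", torch.cuda.is_available())

checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
print("Top-level keys:", list(checkpoint.keys()))      # demo.py reads checkpoint['model']
print("Parameter tensors:", len(checkpoint["model"]))
```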
38 | 
39 | **Run the demo:**
40 | 
41 | ```bash
42 | python demo.py --image_path ./material/1.jpg --output_masks
43 | ```
44 | 
45 | 
46 | ## Evaluation on FSC147
47 | 
48 | To evaluate GeCo on FSC147, additionally install:
49 | 
50 | ```bash
51 | pip install tqdm
52 | pip install pycocotools
53 | pip install scipy
54 | python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
55 | ```
56 | 
57 | Then download all the required data:
58 | 1. The original FSC147 dataset from [Learning to Count Everything](https://drive.google.com/file/d/1ymDYrGs9DSRicfZbSCDiOu0ikGDh5k6S/view?usp=sharing) (put in the `DATA_folder`),
59 | 
60 | 2. Box annotations for the validation and test splits from [Counting-DETR](https://drive.google.com/drive/folders/1Jvr2Bu2cD_yn4W_DjKIW6YjdAiUsw_WA) (put in `DATA_folder/annotations`),
61 | 
62 | 3. [**Pretrained weights**](https://drive.google.com/file/d/1wjOF9MWkrVJVo5uG3gVqZEW9pwRq_aIk/view?usp=sharing) (put in the `MODEL_folder`).
63 | 
64 | Finally, compute the density maps:
65 | ```bash
66 | python utils/data.py --data_path DATA_folder
67 | ```
68 | (The density maps must be recomputed because the FSCD147 annotations are incompatible with the original FSC147 annotations.)
69 | 
70 | **Run inference on FSC147:**
71 | 
72 | ```bash
73 | python evaluate.py --data_path DATA_folder --model_path MODEL_folder
74 | ```
75 | **Run bbox evaluation on FSC147:**
76 | 
77 | ```bash
78 | python evaluate_bboxes.py --data_path DATA_folder
79 | ```
80 | 
81 | ![](material/qualitative.png)
82 | 
83 | ### Training
84 | 
85 | To train the model, first follow the data-preparation steps for evaluation on FSC147 and correct the paths in `train.sh` and `pretrain.sh`. Then download the box annotations for the [train split](https://drive.google.com/file/d/15_qpEZ7f0ZBrcTmgFnxx71lCdxAGtuTz/view?usp=sharing) and put them in `DATA_folder`, download the SAM-HQ pretrained [weights](https://drive.google.com/file/d/1qobFYrI4eyIANfBSmYcGuWRaSIXfMOQ8/view?usp=sharing) and put them in `MODEL_folder`, and run the following commands:
86 | 
87 | First, generate density maps:
88 | ```bash
89 | python utils/data.py
90 | ```
91 | 
92 | Next, run the pretraining:
93 | ```bash
94 | sbatch pretrain.sh
95 | ```
96 | 
97 | Then run the main training:
98 | ```bash
99 | sbatch train.sh
100 | ```
101 | 
102 | ## Citation
103 | ```bibtex
104 | @article{pelhan2024novel,
105 |   title={A Novel Unified Architecture for Low-Shot Counting by Detection and Segmentation},
106 |   author={Pelhan, Jer and Lukezic, Alan and Zavrtanik, Vitjan and Kristan, Matej},
107 |   journal={Advances in Neural Information Processing Systems},
108 |   volume={37},
109 |   pages={66260--66282},
110 |   year={2024}
111 | }
112 | ```
113 | 
114 | ## Possible applications
115 | 
116 | 
117 | https://github.com/user-attachments/assets/e61c791d-389a-486e-a1bd-3713455df0a9
118 | 
119 | 
120 | 
121 | 
-------------------------------------------------------------------------------- /demo.py: --------------------------------------------------------------------------------
1 | from torch.nn import DataParallel
2 | from models.geco_infer import build_model
3 | from utils.arg_parser import get_argparser
4 | import argparse
5 | import torch
6 | from torchvision import transforms as T
7 | import matplotlib.patches as patches
8 | from PIL import Image
9 | from torchvision import ops
10 | from utils.data import resize_and_pad
11 | import matplotlib.pyplot as plt
12 | 
13 | bounding_boxes = []
14 | global clicked
15 | 
16 | # Global variables to track drawing state
17 | rect = None
18 | start_x, start_y = None, None
19 | 
20 | # Event handler for mouse press (start drawing)
21 | def
on_press(event): 22 | global start_x, start_y, rect 23 | if event.inaxes: 24 | start_x, start_y = event.xdata, event.ydata # Store starting point 25 | # Create a rectangle (but do not draw yet) 26 | rect = patches.Rectangle((start_x, start_y), 0, 0, linewidth=2, edgecolor='r', facecolor='none') 27 | event.inaxes.add_patch(rect) 28 | plt.draw() # Update plot to show rectangle (even if not yet drawn) 29 | 30 | # Event handler for mouse motion (while drawing) 31 | def on_motion(event): 32 | global start_x, start_y, rect 33 | if rect is not None and event.inaxes: 34 | # Update the width and height of the rectangle based on mouse position 35 | width = event.xdata - start_x 36 | height = event.ydata - start_y 37 | rect.set_width(width) 38 | rect.set_height(height) 39 | plt.draw() # Redraw to update the rectangle while dragging 40 | 41 | # Event handler for mouse release (end drawing) 42 | def on_release(event): 43 | global rect 44 | # Once mouse is released, we finalize the bounding box 45 | if rect is not None: 46 | bounding_boxes.append([rect.get_x(), rect.get_y(), rect.get_x() + rect.get_width(), rect.get_y() + rect.get_height()]) 47 | rect = None # Reset rect after release 48 | 49 | 50 | @torch.no_grad() 51 | def demo(args): 52 | img_path = args.image_path 53 | global fig, ax 54 | 55 | gpu = 0 56 | torch.cuda.set_device(gpu) 57 | device = torch.device(gpu) 58 | 59 | model = DataParallel( 60 | build_model(args).to(device), 61 | device_ids=[gpu], 62 | output_device=gpu 63 | ) 64 | model.load_state_dict( 65 | torch.load('GeCo.pth', weights_only=True)['model'], strict=False, 66 | ) 67 | 68 | model.eval() 69 | 70 | image = T.ToTensor()(Image.open(img_path).convert("RGB")) 71 | 72 | # Create a figure and axis 73 | fig, ax = plt.subplots(1) 74 | ax.imshow(image.permute(1,2,0)) 75 | plt.axis('off') 76 | # Connect the click event 77 | fig.canvas.mpl_connect('button_press_event', on_press) 78 | fig.canvas.mpl_connect('motion_notify_event', on_motion) 79 | fig.canvas.mpl_connect('button_release_event', on_release) 80 | 81 | plt.title("Click and drag to draw bboxes, then close window") 82 | # Show the image 83 | plt.show() 84 | 85 | bboxes = torch.tensor(bounding_boxes, dtype=torch.float32) 86 | 87 | img, bboxes, scale = resize_and_pad(image, bboxes, full_stretch=False) 88 | img = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img).unsqueeze(0).to(device) 89 | bboxes = bboxes.unsqueeze(0).to(device) 90 | 91 | outputs, _, _, _, masks = model(img, bboxes) 92 | del _ 93 | idx = 0 94 | thr = 4 95 | keep = ops.nms(outputs[idx]['pred_boxes'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / thr], 96 | outputs[idx]['box_v'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / thr], 0.5) 97 | 98 | boxes = (outputs[idx]['pred_boxes'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / thr])[keep] 99 | 100 | bboxes = torch.clamp(boxes, 0, 1) 101 | 102 | plt.clf() 103 | plt.imshow(image.permute(1, 2, 0)) 104 | if args.output_masks: 105 | masks_ = masks[idx][(outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / thr)[0]] 106 | N_masks = masks_.shape[0] 107 | indices = torch.randint(1, N_masks + 1, (1, N_masks), device=masks_.device).view(-1, 1, 1) 108 | masks = (masks_ * indices).sum(dim=0) 109 | mask_display = ( 110 | T.Resize((int(img.shape[2] / scale), int(img.shape[3] / scale)), interpolation=T.InterpolationMode.NEAREST)( 111 | masks.cpu().unsqueeze(0))[0])[:image.shape[1], :image.shape[2]] 112 | cmap = plt.cm.tab20 # Use a colormap with distinct colors 113 | norm = 
plt.Normalize(vmin=0, vmax=N_masks) 114 | del masks 115 | del masks_ 116 | del outputs 117 | rgba_image = cmap(norm(mask_display)) 118 | rgba_image[mask_display == 0, -1] = 0 119 | plt.imshow(rgba_image, alpha=0.6) 120 | 121 | pred_boxes = bboxes.cpu() / torch.tensor([scale, scale, scale, scale]) * img.shape[-1] 122 | for i in range(len(pred_boxes)): 123 | box = pred_boxes[i] 124 | 125 | plt.plot([box[0], box[0], box[2], box[2], box[0]], [box[1], box[3], box[3], box[1], box[1]], linewidth=0.7, 126 | color='orange') 127 | 128 | pred_boxes = bounding_boxes 129 | for i in range(len(pred_boxes)): 130 | box = pred_boxes[i] 131 | plt.plot([box[0], box[0], box[2], box[2], box[0]], [box[1], box[3], box[3], box[1], box[1]], linewidth=2, 132 | color='red') 133 | plt.title("Number of selected objects:" + str(len(bboxes))) 134 | plt.axis('off') 135 | plt.show() 136 | 137 | if __name__ == '__main__': 138 | parser = argparse.ArgumentParser('GeCo', parents=[get_argparser()]) 139 | args = parser.parse_args() 140 | demo(args) 141 | -------------------------------------------------------------------------------- /evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GeCo 3 | #SBATCH --output=evaluation/test_GeCo%j.txt 4 | #SBATCH --error=evaluation/test_GeCo%j.txt 5 | #SBATCH --nodes=1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=12 8 | #SBATCH --partition=gpu 9 | #SBATCH --gres=gpu:1 10 | #SBATCH --time=0-02:00:00 11 | 12 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 13 | export MASTER_ADDR=$master_addr 14 | export MASTER_PORT=50197 15 | export NCCL_P2P_DISABLE=1 16 | export NCCL_IB_DISABLE=1 17 | export NCCL_BLOCKING_WAIT=1 18 | export TORCH_DISTRIBUTED_DEBUG=DETAIL 19 | 20 | module load Anaconda3 21 | source activate geco 22 | conda activate base 23 | conda activate geco 24 | 25 | srun --unbuffered python evaluate.py \ 26 | --model_name GeCo \ 27 | --data_path /d/hpc/projects/FRI/pelhanj/fsc147 \ 28 | --model_path /d/hpc/projects/FRI/pelhanj/fsc147/models/ -------------------------------------------------------------------------------- /material/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerpelhan/GeCo/44b71e572b11d41822da3a52a9b7c3be85b7194e/material/1.jpg -------------------------------------------------------------------------------- /material/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerpelhan/GeCo/44b71e572b11d41822da3a52a9b7c3be85b7194e/material/2.jpg -------------------------------------------------------------------------------- /material/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerpelhan/GeCo/44b71e572b11d41822da3a52a9b7c3be85b7194e/material/3.jpg -------------------------------------------------------------------------------- /material/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerpelhan/GeCo/44b71e572b11d41822da3a52a9b7c3be85b7194e/material/architecture.jpg -------------------------------------------------------------------------------- /material/preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerpelhan/GeCo/44b71e572b11d41822da3a52a9b7c3be85b7194e/material/preview.png 
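For reference, the detection post-processing used in `demo.py` above (keep boxes whose confidence exceeds a quarter of the top score, then apply IoU-0.5 NMS) can be exercised on its own. Below is a minimal sketch with made-up boxes and scores; only `torch` and `torchvision` are required:

```python
# Standalone sketch of the score-gating + NMS step from demo.py.
# The box coordinates and scores below are dummy values for illustration.
import torch
from torchvision import ops

pred_boxes = torch.tensor([[0.10, 0.10, 0.20, 0.20],
                           [0.11, 0.11, 0.21, 0.21],   # near-duplicate of the first box
                           [0.50, 0.50, 0.60, 0.60]])  # normalized xyxy, as in the model output
box_v = torch.tensor([0.9, 0.8, 0.05])                 # per-box confidence ("box_v" in the outputs)

thr = 4                                                # same divisor as demo.py: keep scores > max / 4
keep_mask = box_v > box_v.max() / thr
keep = ops.nms(pred_boxes[keep_mask], box_v[keep_mask], iou_threshold=0.5)

final_boxes = torch.clamp(pred_boxes[keep_mask][keep], 0, 1)
print("Predicted count:", len(final_boxes))            # -> 1 (duplicate suppressed, low-score box dropped)
```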
-------------------------------------------------------------------------------- /material/qualitative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerpelhan/GeCo/44b71e572b11d41822da3a52a9b7c3be85b7194e/material/qualitative.png -------------------------------------------------------------------------------- /models/DQE.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from typing import Tuple 11 | 12 | from models.regression import UpsamplingLayer 13 | from models.transformer import SelfCrossAttentionBlock, PrototypeAttentionBlock, ImgToPrototypeAttentionBlock 14 | 15 | 16 | class DQE(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | num_prototype_attn_steps: int, 22 | num_image_attn_steps: int, 23 | zero_shot: bool = False 24 | 25 | ) -> None: 26 | """ 27 | 28 | Arguments: 29 | """ 30 | super().__init__() 31 | self.transformer_dim = transformer_dim 32 | self.prototype_attention = nn.ModuleList() 33 | self.image_attention = nn.ModuleList() 34 | self.zero_shot = zero_shot 35 | 36 | if self.zero_shot: 37 | self.image_to_prototype_attn =ImgToPrototypeAttentionBlock( 38 | embedding_dim=transformer_dim, 39 | num_heads=8, 40 | ) 41 | 42 | 43 | for _ in range(num_prototype_attn_steps): 44 | self.prototype_attention.append( 45 | PrototypeAttentionBlock( 46 | embedding_dim=transformer_dim, 47 | num_heads=8, 48 | ) 49 | ) 50 | 51 | for _ in range(num_image_attn_steps): 52 | self.image_attention.append(SelfCrossAttentionBlock( 53 | embedding_dim=transformer_dim, 54 | num_heads=8, 55 | )) 56 | 57 | self.upscale = nn.Sequential( 58 | UpsamplingLayer(transformer_dim, transformer_dim), 59 | UpsamplingLayer(transformer_dim, transformer_dim)) 60 | self.upscale_hq = UpsamplingLayer(transformer_dim + 32, transformer_dim) 61 | 62 | def init_weights(m): 63 | if isinstance(m, nn.Linear): 64 | torch.nn.init.xavier_uniform(m.weight) 65 | m.bias.data.fill_(0.01) 66 | 67 | def forward( 68 | self, 69 | image_embeddings: torch.Tensor, 70 | image_pe: torch.Tensor, 71 | prototype_embeddings: torch.Tensor, 72 | hq_features: torch.Tensor 73 | ) -> Tuple[torch.Tensor, torch.Tensor]: 74 | """ 75 | 76 | """ 77 | b, c, h, w = image_embeddings.shape 78 | image_pe = torch.repeat_interleave(image_pe, image_embeddings.shape[0], dim=0) 79 | if image_pe.shape[1:] != image_embeddings.shape[1:]: 80 | upsample_pos_emb = nn.UpsamplingBilinear2d(scale_factor=1.5) 81 | image_pe = upsample_pos_emb(image_pe) 82 | image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1) 83 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 84 | src = image_embeddings 85 | 86 | if self.zero_shot: 87 | prototype_embeddings = self.image_to_prototype_attn(image_f=src, 88 | prototypes=prototype_embeddings) 89 | 90 | for layer in self.prototype_attention: 91 | src = layer(image_f=src, 92 | prototypes=prototype_embeddings) 93 | 94 | for layer in self.image_attention: 95 | src = layer(image_f=src, 96 | adapted_image_f=image_embeddings, 97 | pos_enc=image_pe) 98 | src = src.transpose(1, 2).view(b, c, h, w) 99 | src = self.upscale(src) 100 | src = torch.cat([src, hq_features], dim=1) 101 | src = self.upscale_hq(src) 102 | 103 | return src 104 | 
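A minimal shape sanity check for the `DQE` module above, written as a sketch: it assumes the GeCo repository is on `PYTHONPATH` (the attention blocks come from `models/transformer.py`, which is not listed here) and uses the tensor sizes the backbone produces for a 1024x1024 input: 64x64 image embeddings with 256 channels and 256x256 HQ features with 32 channels.

```python
# Shape check for DQE (sketch; assumes the GeCo repo is importable).
import torch
from models.DQE import DQE

dqe = DQE(transformer_dim=256, num_prototype_attn_steps=3, num_image_attn_steps=2)

B, N = 1, 3                                          # batch size, number of exemplar boxes
image_embeddings = torch.randn(B, 256, 64, 64)       # SAM-ViT neck output
image_pe = torch.randn(1, 256, 64, 64)               # dense positional encoding, same spatial size
prototypes = torch.randn(B, 2 * N, 256)              # appearance + shape tokens, as built in geco.py (kernel_dim=1)
hq_features = torch.randn(B, 32, 256, 256)           # high-quality features from the backbone

with torch.no_grad():
    adapted = dqe(image_embeddings=image_embeddings, image_pe=image_pe,
                  prototype_embeddings=prototypes, hq_features=hq_features)
print(adapted.shape)                                 # expected: torch.Size([1, 256, 512, 512])
```

The two upsampling stages plus the HQ-feature fusion bring the 64x64 embedding up to the 512x512 resolution on which `geco.py` predicts per-location objectness scores and box offsets.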
-------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerpelhan/GeCo/44b71e572b11d41822da3a52a9b7c3be85b7194e/models/__init__.py -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torchvision import models 7 | from torchvision.ops.misc import FrozenBatchNorm2d 8 | 9 | from models.common import LayerNorm2d 10 | from models.sam_ViT import ImageEncoderViT 11 | from functools import partial 12 | 13 | 14 | class Backbone(nn.Module): 15 | 16 | def __init__( 17 | self, 18 | requires_grad: bool, 19 | image_size: int, 20 | model_path: str = None, 21 | ): 22 | 23 | super(Backbone, self).__init__() 24 | 25 | 26 | vit_patch_size = 16 27 | image_embedding_size = image_size // vit_patch_size 28 | self.num_channels = image_emb_size = 256 29 | vit_dim = 1280 30 | transformer_dim = 256 31 | vit_encoder = ImageEncoderViT( 32 | depth=32, 33 | embed_dim=1280, 34 | img_size=image_size, 35 | mlp_ratio=4, 36 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 37 | num_heads=16, 38 | patch_size=vit_patch_size, 39 | qkv_bias=True, 40 | use_rel_pos=True, 41 | global_attn_indexes=[7, 15, 23, 31], 42 | window_size=14, 43 | out_chans=256, 44 | ) 45 | self.backbone = vit_encoder 46 | 47 | self.compress_vit_feat = nn.Sequential( 48 | nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2), 49 | LayerNorm2d(transformer_dim), 50 | nn.GELU(), 51 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2)) 52 | 53 | self.embedding_encoder = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | nn.GELU(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | ) 59 | if model_path is not None: 60 | 61 | checkpoint = torch.load(os.path.join(model_path, "sam_hq_vit_h.pth"), map_location="cpu" 62 | ) 63 | state_dict = {k.replace("image_encoder.", ""): v for k, v in checkpoint.items() if "image_encoder" in k} 64 | self.backbone.load_state_dict(state_dict) 65 | 66 | state_dict = {k.replace("mask_decoder.compress_vit_feat.", ""): v for k, v in checkpoint.items() if 67 | "compress_vit_feat" in k} 68 | self.compress_vit_feat.load_state_dict(state_dict) 69 | 70 | state_dict = {k.replace("mask_decoder.embedding_encoder.", ""): v for k, v in checkpoint.items() if 71 | "embedding_encoder" in k} 72 | self.embedding_encoder.load_state_dict(state_dict) 73 | 74 | for n, param in self.named_parameters(): 75 | param.requires_grad_(requires_grad) 76 | 77 | def forward(self, x): 78 | x = self.backbone.patch_embed(x) 79 | if self.backbone.pos_embed is not None: 80 | if self.backbone.pos_embed.shape[1:] != x.shape[1:]: 81 | upsample_pos_emb = nn.UpsamplingBilinear2d(scale_factor=1.5) 82 | pos_emb = upsample_pos_emb(self.backbone.pos_embed.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 83 | else: 84 | pos_emb = self.backbone.pos_embed 85 | x = x + pos_emb 86 | 87 | interm_embeddings = [] 88 | for blk in self.backbone.blocks: 89 | x = blk(x) 90 | if blk.window_size == 0: 91 | interm_embeddings.append(x) 92 | image_embeddings = self.backbone.neck(x.permute(0, 3, 1, 2)) 93 | 94 | 
vit_features = interm_embeddings[0].permute(0, 3, 1, 2) 95 | hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features) 96 | 97 | return image_embeddings, hq_features 98 | -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | from typing import Type 11 | 12 | class MLP(nn.Module): 13 | """ Very simple multi-layer perceptron (also called FFN)""" 14 | 15 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 16 | super().__init__() 17 | self.num_layers = num_layers 18 | h = [hidden_dim] * (num_layers - 1) 19 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 20 | 21 | def forward(self, x): 22 | for i, layer in enumerate(self.layers): 23 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 24 | return x 25 | 26 | class MLPBlock(nn.Module): 27 | def __init__( 28 | self, 29 | embedding_dim: int, 30 | mlp_dim: int, 31 | act: Type[nn.Module] = nn.GELU, 32 | ) -> None: 33 | super().__init__() 34 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 35 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 36 | self.act = act() 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | return self.lin2(self.act(self.lin1(x))) 40 | 41 | 42 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 43 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 44 | class LayerNorm2d(nn.Module): 45 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 46 | super().__init__() 47 | self.weight = nn.Parameter(torch.ones(num_channels)) 48 | self.bias = nn.Parameter(torch.zeros(num_channels)) 49 | self.eps = eps 50 | 51 | def forward(self, x: torch.Tensor) -> torch.Tensor: 52 | u = x.mean(1, keepdim=True) 53 | s = (x - u).pow(2).mean(1, keepdim=True) 54 | x = (x - u) / torch.sqrt(s + self.eps) 55 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 56 | return x -------------------------------------------------------------------------------- /models/geco.py: -------------------------------------------------------------------------------- 1 | from torchvision.ops import roi_align 2 | 3 | from utils.box_ops import boxes_with_scores 4 | from .backbone import Backbone 5 | from .common import MLP 6 | from .DQE import DQE 7 | from .prompt_encoder import PromptEncoder_DQE 8 | import torch 9 | from torch import nn 10 | from torchvision.transforms import Resize 11 | 12 | 13 | class GeCo(nn.Module): 14 | 15 | def __init__( 16 | self, 17 | image_size: int, 18 | num_objects: int, 19 | emb_dim: int, 20 | num_heads: int, 21 | kernel_dim: int, 22 | train_backbone: bool, 23 | reduction: int, 24 | zero_shot: bool, 25 | model_path: str 26 | ): 27 | super(GeCo, self).__init__() 28 | 29 | self.emb_dim = emb_dim 30 | self.num_objects = num_objects 31 | self.reduction = reduction 32 | self.kernel_dim = kernel_dim 33 | self.image_size = image_size 34 | self.zero_shot = zero_shot 35 | self.num_heads = num_heads 36 | self.num_classes = 1 37 | 
self.model_path = model_path 38 | self.backbone = Backbone(requires_grad=train_backbone, image_size=image_size, model_path=model_path) 39 | 40 | self.class_embed = nn.Sequential(nn.Linear(emb_dim, 1), nn.LeakyReLU()) 41 | self.bbox_embed = MLP(emb_dim, emb_dim, 4, 3) 42 | 43 | self.emb_dim = 256 44 | self.adapt_features = DQE( 45 | transformer_dim=self.emb_dim, 46 | num_prototype_attn_steps=3, 47 | num_image_attn_steps=2, 48 | ) 49 | 50 | self.prompt_encoder = PromptEncoder_DQE( 51 | embed_dim=256, 52 | image_embedding_size=(64, 64), 53 | input_image_size=(image_size, image_size), 54 | mask_in_chans=16, 55 | ) 56 | 57 | self.shape_or_objectness = nn.Sequential( 58 | nn.Linear(2, 64), 59 | nn.ReLU(), 60 | nn.Linear(64, emb_dim), 61 | nn.ReLU(), 62 | nn.Linear(emb_dim, 1 ** 2 * emb_dim) 63 | ) 64 | self.resize = Resize((512, 512)) 65 | 66 | 67 | def forward(self, x, bboxes): 68 | num_objects = bboxes.size(1) if not self.zero_shot else self.num_objects 69 | 70 | # src = x 71 | src, src_hq = self.backbone(x) 72 | 73 | bs, c, h, w = src.size() 74 | 75 | bboxes_roi = torch.cat([ 76 | torch.arange( 77 | bs, requires_grad=False 78 | ).to(bboxes.device).repeat_interleave(self.num_objects).reshape(-1, 1), 79 | bboxes.flatten(0, 1), 80 | ], dim=1) 81 | 82 | # Roi align 83 | exemplars = roi_align( 84 | src, 85 | boxes=bboxes_roi, output_size=self.kernel_dim, 86 | spatial_scale=1.0 / self.reduction, aligned=True 87 | ).permute(0, 2, 3, 1).reshape(bs, self.num_objects * self.kernel_dim ** 2, self.emb_dim) 88 | 89 | box_hw = torch.zeros(bboxes.size(0), bboxes.size(1), 2).to(bboxes.device) 90 | box_hw[:, :, 0] = bboxes[:, :, 2] - bboxes[:, :, 0] 91 | box_hw[:, :, 1] = bboxes[:, :, 3] - bboxes[:, :, 1] 92 | 93 | # Encode shape 94 | shape = self.shape_or_objectness(box_hw).reshape( 95 | bs, -1, self.emb_dim 96 | ) 97 | prototype_embeddings = torch.cat([exemplars, shape], dim=1) 98 | 99 | # adapt image feature with prototypes 100 | adapted_f = self.adapt_features( 101 | image_embeddings=src, 102 | image_pe=self.prompt_encoder.get_dense_pe(), 103 | prototype_embeddings=prototype_embeddings, 104 | hq_features=src_hq 105 | ) 106 | 107 | # Predict class [fg, bg] and l,r,t,b 108 | bs, c, w, h = adapted_f.shape 109 | adapted_f = adapted_f.view(bs, self.emb_dim, -1).permute(0, 2, 1) 110 | centerness = self.class_embed(adapted_f).view(bs, w, h, 1).permute(0, 3, 1, 2) 111 | outputs_coord = self.bbox_embed(adapted_f).sigmoid().view(bs, w, h, 4).permute(0, 3, 1, 2) 112 | outputs, ref_points = boxes_with_scores(centerness, outputs_coord) 113 | 114 | return outputs, ref_points, centerness, outputs_coord 115 | 116 | 117 | 118 | 119 | def build_model(args): 120 | assert args.reduction in [4, 8, 16] 121 | 122 | return GeCo( 123 | image_size=args.image_size, 124 | num_objects=args.num_objects, 125 | zero_shot=args.zero_shot, 126 | emb_dim=args.emb_dim, 127 | num_heads=args.num_heads, 128 | kernel_dim=args.kernel_dim, 129 | train_backbone=args.backbone_lr > 0, 130 | reduction=args.reduction, 131 | model_path=args.model_path 132 | 133 | ) 134 | -------------------------------------------------------------------------------- /models/geco_infer.py: -------------------------------------------------------------------------------- 1 | from torchvision.ops import roi_align 2 | 3 | from utils.box_ops import boxes_with_scores 4 | from .backbone import Backbone 5 | from .common import MLP 6 | from .DQE import DQE 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from torchvision.transforms import 
Resize 11 | 12 | 13 | class GeCo(nn.Module): 14 | 15 | def __init__( 16 | self, 17 | image_size: int, 18 | num_objects: int, 19 | emb_dim: int, 20 | num_heads: int, 21 | kernel_dim: int, 22 | train_backbone: bool, 23 | reduction: int, 24 | zero_shot: bool, 25 | model_path: str, 26 | return_masks: bool = False 27 | ): 28 | 29 | super(GeCo, self).__init__() 30 | 31 | self.emb_dim = emb_dim 32 | self.num_objects = num_objects 33 | self.reduction = reduction 34 | self.kernel_dim = kernel_dim 35 | self.image_size = image_size 36 | self.zero_shot = zero_shot 37 | self.num_heads = num_heads 38 | self.num_classes = 1 39 | self.model_path = model_path 40 | self.backbone = Backbone(requires_grad=train_backbone, image_size=image_size) 41 | self.class_embed = nn.Sequential(nn.Linear(emb_dim, 1), nn.LeakyReLU()) 42 | self.bbox_embed = MLP(emb_dim, emb_dim, 4, 3) 43 | self.return_masks = return_masks 44 | 45 | self.emb_dim = 256 46 | self.adapt_features = DQE( 47 | transformer_dim=self.emb_dim, 48 | num_prototype_attn_steps=3, 49 | num_image_attn_steps=2, 50 | zero_shot=zero_shot, 51 | ) 52 | from .prompt_encoder import PromptEncoder_DQE 53 | self.prompt_encoder = PromptEncoder_DQE( 54 | embed_dim=256, 55 | image_embedding_size=(64, 64), 56 | input_image_size=(image_size, image_size), 57 | mask_in_chans=16, 58 | ) 59 | 60 | from segment_anything.modeling import MaskDecoder, TwoWayTransformer, PromptEncoder 61 | 62 | prompt_embed_dim = 256 63 | image_embedding_size = 64 64 | image_size = 1024 65 | self.prompt_encoder_sam = PromptEncoder( 66 | embed_dim=prompt_embed_dim, 67 | image_embedding_size=(image_embedding_size, image_embedding_size), 68 | input_image_size=(image_size, image_size), 69 | mask_in_chans=16, 70 | ) 71 | image_embedding_size = 96 72 | image_size = 1536 73 | self.prompt_encoder_sam_ = PromptEncoder( 74 | embed_dim=prompt_embed_dim, 75 | image_embedding_size=(image_embedding_size, image_embedding_size), 76 | input_image_size=(image_size, image_size), 77 | mask_in_chans=16, 78 | ) 79 | self.mask_decoder = MaskDecoder( 80 | num_multimask_outputs=3, 81 | transformer=TwoWayTransformer( 82 | depth=2, 83 | embedding_dim=prompt_embed_dim, 84 | mlp_dim=2048, 85 | num_heads=8, 86 | ), 87 | transformer_dim=prompt_embed_dim, 88 | iou_head_depth=3, 89 | iou_head_hidden_dim=256, 90 | ) 91 | 92 | checkpoint = torch.hub.load_state_dict_from_url( 93 | 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', 94 | map_location="cpu" 95 | ) 96 | state_dict = {k.replace("mask_decoder.", ""): v for k, v in checkpoint.items() if "mask_decoder" in k} 97 | self.mask_decoder.load_state_dict(state_dict) 98 | state_dict = {k.replace("prompt_encoder.", ""): v for k, v in checkpoint.items() if "prompt_encoder" in k} 99 | self.prompt_encoder_sam.load_state_dict(state_dict) 100 | self.prompt_encoder_sam_.load_state_dict(state_dict) 101 | 102 | if self.zero_shot: 103 | self.exemplars = nn.Parameter(torch.randn(1, emb_dim)) 104 | else: 105 | self.shape_or_objectness = nn.Sequential( 106 | nn.Linear(2, 64), 107 | nn.ReLU(), 108 | nn.Linear(64, emb_dim), 109 | nn.ReLU(), 110 | nn.Linear(emb_dim, 1 ** 2 * emb_dim) 111 | ) 112 | self.resize = Resize((512, 512)) 113 | 114 | def refine_bounding_boxes(self, features, outputs, return_masks=False): 115 | 116 | batch_masks = [] 117 | batch_iou = [] 118 | batch_bboxes = [] 119 | for i in range(len(outputs)): 120 | step = 50 121 | masks = [] 122 | iou_predictions = [] 123 | corrected_bboxes_ = [] 124 | for box_i in range(step, len(outputs[i]['pred_boxes'][0]) + step, 
step): 125 | box = outputs[i]['pred_boxes'][0][(box_i - step):box_i] * features.shape[-1] * 16 126 | if features.shape[-1] * 16 == 1024: 127 | sparse_embeddings, dense_embeddings = self.prompt_encoder_sam( 128 | points=None, 129 | boxes=box, 130 | masks=None, 131 | ) 132 | else: 133 | sparse_embeddings, dense_embeddings = self.prompt_encoder_sam_( 134 | points=None, 135 | boxes=box, 136 | masks=None, 137 | ) 138 | # # # Predict masks 139 | masks_, iou_predictions_ = self.mask_decoder( 140 | image_embeddings=features[i].unsqueeze(0), 141 | image_pe=self.prompt_encoder_sam.get_dense_pe(), 142 | sparse_prompt_embeddings=sparse_embeddings, 143 | dense_prompt_embeddings=dense_embeddings, 144 | multimask_output=False, 145 | ) 146 | 147 | masks_ = F.interpolate(masks_, (features.shape[-1] * 16, features.shape[-1] * 16), mode="bilinear", 148 | align_corners=False) 149 | masks_ = masks_ > 0 150 | if return_masks: 151 | masks_ = masks_[..., : 1024, : 1024] 152 | masks.append(masks_) 153 | iou_predictions.append(iou_predictions_) 154 | 155 | corrected_bboxes = torch.zeros((masks_.shape[0], 4), dtype=torch.float) 156 | masks_ = masks_[:, 0] 157 | for index, mask_i in enumerate(masks_): 158 | y, x = torch.where(mask_i != 0) 159 | if y.shape[0] > 0 and x.shape[0] > 0: 160 | corrected_bboxes[index, 0] = torch.min(x) 161 | corrected_bboxes[index, 1] = torch.min(y) 162 | corrected_bboxes[index, 2] = torch.max(x) 163 | corrected_bboxes[index, 3] = torch.max(y) 164 | corrected_bboxes_.append(corrected_bboxes) 165 | if len(corrected_bboxes_) > 0: 166 | if return_masks: 167 | batch_masks.append(torch.cat(masks, dim=0)[:, 0]) 168 | else: 169 | batch_masks.append([]) 170 | batch_bboxes.append(torch.cat(corrected_bboxes_)) 171 | batch_iou.append(torch.cat(iou_predictions).permute(1, 0)) 172 | else: 173 | batch_masks.append([]) 174 | batch_bboxes.append(torch.tensor([]).to(features.device)) 175 | batch_iou.append(torch.tensor([]).to(features.device)) 176 | return batch_masks, batch_iou, batch_bboxes 177 | 178 | def forward(self, img, bboxes): 179 | num_objects = bboxes.size(1) if not self.zero_shot else self.num_objects 180 | 181 | src, src_hq = self.backbone(img) 182 | bs, c, h, w = src.size() 183 | 184 | if not self.zero_shot: 185 | prototype_embeddings = self.create_prototypes(src, bboxes) 186 | 187 | else: # zero shot 188 | prototype_embeddings = self.exemplars.expand(bs, -1, -1) 189 | adapted_f = self.adapt_features( 190 | image_embeddings=src, 191 | image_pe=self.prompt_encoder.get_dense_pe(), 192 | prototype_embeddings=prototype_embeddings, 193 | hq_features=src_hq 194 | ) 195 | 196 | # Predict class [fg, bg] and l,r,t,b 197 | bs, c, w, h = adapted_f.shape 198 | adapted_f = adapted_f.view(bs, self.emb_dim, -1).permute(0, 2, 1) 199 | centerness = self.class_embed(adapted_f).view(bs, w, h, 1).permute(0, 3, 1, 2) 200 | outputs_coord = self.bbox_embed(adapted_f).sigmoid().view(bs, w, h, 4).permute(0, 3, 1, 2) 201 | outputs, ref_points = boxes_with_scores(centerness, outputs_coord, batch_thresh=0.001) 202 | masks, ious, corrected_bboxes = self.refine_bounding_boxes(src, outputs, return_masks=self.return_masks) 203 | 204 | for i in range(len(outputs)): 205 | outputs[i]["scores"] = ious[i] 206 | outputs[i]["pred_boxes"] = corrected_bboxes[i].to(outputs[i]["pred_boxes"].device).unsqueeze(0) / img.shape[ 207 | -1] 208 | 209 | return outputs, ref_points, centerness, outputs_coord, masks 210 | 211 | def create_prototypes(self, src, bboxes): 212 | bs = src.size(0) 213 | self.num_objects = bboxes.size(1) 214 | 215 | 
bboxes_roi = torch.cat([ 216 | torch.arange( 217 | bs, requires_grad=False 218 | ).to(bboxes.device).repeat_interleave(self.num_objects).reshape(-1, 1), 219 | bboxes.flatten(0, 1), 220 | ], dim=1) 221 | self.kernel_dim = 1 222 | 223 | exemplars = roi_align( 224 | src, 225 | boxes=bboxes_roi, output_size=self.kernel_dim, 226 | spatial_scale=1.0 / self.reduction, aligned=True 227 | ).permute(0, 2, 3, 1).reshape(bs, self.num_objects * self.kernel_dim ** 2, self.emb_dim) 228 | 229 | box_hw = torch.zeros(bboxes.size(0), bboxes.size(1), 2).to(bboxes.device) 230 | box_hw[:, :, 0] = bboxes[:, :, 2] - bboxes[:, :, 0] 231 | box_hw[:, :, 1] = bboxes[:, :, 3] - bboxes[:, :, 1] 232 | 233 | shape = self.shape_or_objectness(box_hw).reshape( 234 | bs, -1, self.emb_dim 235 | ) 236 | prototype_embeddings = torch.cat([exemplars, shape], dim=1) 237 | return prototype_embeddings 238 | 239 | 240 | def build_model(args): 241 | assert args.reduction in [4, 8, 16] 242 | 243 | return GeCo( 244 | image_size=args.image_size, 245 | num_objects=args.num_objects, 246 | zero_shot=args.zero_shot, 247 | emb_dim=args.emb_dim, 248 | num_heads=args.num_heads, 249 | kernel_dim=args.kernel_dim, 250 | train_backbone=args.backbone_lr > 0, 251 | reduction=args.reduction, 252 | model_path=args.model_path, 253 | return_masks=args.output_masks 254 | ) 255 | -------------------------------------------------------------------------------- /models/matcher.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy.optimize import linear_sum_assignment 4 | from torch import nn 5 | from utils.box_ops import generalized_box_iou, box_iou 6 | 7 | 8 | class GeCoMatcher(nn.Module): 9 | 10 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): 11 | super().__init__() 12 | self.cost_class = cost_class 13 | self.cost_bbox = cost_bbox 14 | self.cost_giou = cost_giou 15 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 16 | 17 | def forward(self, outputs, targets, ref_points=None): 18 | with torch.no_grad(): 19 | bs, num_queries = outputs["box_v"].shape[:2] 20 | 21 | # We flatten to compute the cost matrices in a batch 22 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 23 | 24 | # Also concat the target labels and boxes 25 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 26 | 27 | # Compute the L1 cost between boxes 28 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 29 | 30 | # Compute the giou cost betwen boxes 31 | iou, unions = box_iou(out_bbox, tgt_bbox) 32 | cost_giou = - generalized_box_iou(out_bbox, tgt_bbox) 33 | 34 | # Final cost matrix 35 | C = self.cost_bbox * cost_bbox + self.cost_giou * cost_giou 36 | C = C.view(bs, num_queries, -1).cpu() 37 | 38 | sizes = [len(v["boxes"]) for v in targets] 39 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 40 | 41 | non_mathced_gt_bbox_idx = \ 42 | np.nonzero(np.logical_not(np.in1d(np.array([i for i in range(tgt_bbox.shape[0])]), indices[0][1])))[0] 43 | non_mathced_gt_bbox_idx = np.concatenate( 44 | (non_mathced_gt_bbox_idx, torch.where(iou.max(dim=0)[0] == 0)[0].cpu().numpy())) 45 | non_mathced_gt_bbox_idx = [torch.tensor(non_mathced_gt_bbox_idx, dtype=torch.int64).unique()] 46 | remove_mask = np.logical_not(np.in1d(indices[0][1], non_mathced_gt_bbox_idx[ 47 | 0].cpu())) 48 | ind0 = indices[0][0][remove_mask] 49 | ind1 = indices[0][1][remove_mask] 50 | non_mathced_pred_bbox_idx 
= \ 51 | np.nonzero(np.logical_not(np.in1d(np.array([i for i in range(out_bbox.shape[0])]), indices[0][0])))[0] 52 | 53 | match_indexes = [(torch.as_tensor(ind0, dtype=torch.int64), torch.as_tensor(ind1, dtype=torch.int64))] 54 | return match_indexes, non_mathced_gt_bbox_idx, non_mathced_pred_bbox_idx 55 | 56 | 57 | def build_matcher(args): 58 | return GeCoMatcher(args.cost_class, args.cost_bbox, args.cost_giou) 59 | -------------------------------------------------------------------------------- /models/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | 14 | 15 | class PromptEncoder_DQE(nn.Module): 16 | def __init__( 17 | self, 18 | embed_dim: int, 19 | image_embedding_size: Tuple[int, int], 20 | input_image_size: Tuple[int, int], 21 | mask_in_chans: int, 22 | activation: Type[nn.Module] = nn.GELU, 23 | ) -> None: 24 | """ 25 | Encodes prompts for input to SAM's mask decoder. 26 | 27 | Arguments: 28 | embed_dim (int): The prompts' embedding dimension 29 | image_embedding_size (tuple(int, int)): The spatial size of the 30 | image embedding, as (H, W). 31 | input_image_size (int): The padded size of the image as input 32 | to the image encoder, as (H, W). 33 | mask_in_chans (int): The number of hidden channels used for 34 | encoding input masks. 35 | activation (nn.Module): The activation to use when encoding 36 | input masks. 37 | """ 38 | super().__init__() 39 | self.embed_dim = embed_dim 40 | self.input_image_size = input_image_size 41 | self.image_embedding_size = image_embedding_size 42 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 43 | 44 | def get_dense_pe(self) -> torch.Tensor: 45 | """ 46 | Returns the positional encoding used to encode point prompts, 47 | applied to a dense set of points the shape of the image encoding. 
48 | 49 | Returns: 50 | torch.Tensor: Positional encoding with shape 51 | 1x(embed_dim)x(embedding_h)x(embedding_w) 52 | """ 53 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 54 | 55 | def _embed_points( 56 | self, 57 | points: torch.Tensor, 58 | labels: torch.Tensor, 59 | pad: bool, 60 | ) -> torch.Tensor: 61 | """Embeds point prompts.""" 62 | points = points + 0.5 # Shift to center of pixel 63 | if pad: 64 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 65 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 66 | points = torch.cat([points, padding_point], dim=1) 67 | labels = torch.cat([labels, padding_label], dim=1) 68 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 69 | point_embedding[labels == -1] = 0.0 70 | point_embedding[labels == -1] += self.not_a_point_embed.weight 71 | point_embedding[labels == 0] += self.point_embeddings[0].weight 72 | point_embedding[labels == 1] += self.point_embeddings[1].weight 73 | return point_embedding 74 | 75 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 76 | """Embeds box prompts.""" 77 | boxes = boxes + 0.5 # Shift to center of pixel 78 | coords = boxes.reshape(-1, 2, 2) 79 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 80 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 81 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 82 | return corner_embedding 83 | 84 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 85 | """Embeds mask inputs.""" 86 | mask_embedding = self.mask_downscaling(masks) 87 | return mask_embedding 88 | 89 | def _get_batch_size( 90 | self, 91 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 92 | boxes: Optional[torch.Tensor], 93 | masks: Optional[torch.Tensor], 94 | ) -> int: 95 | """ 96 | Gets the batch size of the output given the batch size of the input prompts. 97 | """ 98 | if points is not None: 99 | return points[0].shape[0] 100 | elif boxes is not None: 101 | return boxes.shape[0] 102 | elif masks is not None: 103 | return masks.shape[0] 104 | else: 105 | return 1 106 | 107 | def _get_device(self) -> torch.device: 108 | return self.point_embeddings[0].weight.device 109 | 110 | def forward( 111 | self, 112 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 113 | boxes: Optional[torch.Tensor], 114 | masks: Optional[torch.Tensor], 115 | ) -> Tuple[torch.Tensor, torch.Tensor]: 116 | """ 117 | Embeds different types of prompts, returning both sparse and dense 118 | embeddings. 119 | 120 | Arguments: 121 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 122 | and labels to embed. 123 | boxes (torch.Tensor or none): boxes to embed 124 | masks (torch.Tensor or none): masks to embed 125 | 126 | Returns: 127 | torch.Tensor: sparse embeddings for the points and boxes, with shape 128 | BxNx(embed_dim), where N is determined by the number of input points 129 | and boxes. 
130 | torch.Tensor: dense embeddings for the masks, in the shape 131 | Bx(embed_dim)x(embed_H)x(embed_W) 132 | """ 133 | bs = self._get_batch_size(points, boxes, masks) 134 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 135 | if points is not None: 136 | coords, labels = points 137 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 138 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 139 | if boxes is not None: 140 | box_embeddings = self._embed_boxes(boxes) 141 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 142 | 143 | if masks is not None: 144 | dense_embeddings = self._embed_masks(masks) 145 | else: 146 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 147 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 148 | ) 149 | 150 | return sparse_embeddings, dense_embeddings 151 | 152 | 153 | class PositionEmbeddingRandom(nn.Module): 154 | """ 155 | Positional encoding using random spatial frequencies. 156 | """ 157 | 158 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 159 | super().__init__() 160 | if scale is None or scale <= 0.0: 161 | scale = 1.0 162 | self.register_buffer( 163 | "positional_encoding_gaussian_matrix", 164 | scale * torch.randn((2, num_pos_feats)), 165 | ) 166 | 167 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 168 | """Positionally encode points that are normalized to [0,1].""" 169 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 170 | coords = 2 * coords - 1 171 | coords = coords @ self.positional_encoding_gaussian_matrix 172 | coords = 2 * np.pi * coords 173 | # outputs d_1 x ... x d_n x C shape 174 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 175 | 176 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 177 | """Generate positional encoding for a grid of the specified size.""" 178 | h, w = size 179 | device: Any = self.positional_encoding_gaussian_matrix.device 180 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 181 | y_embed = grid.cumsum(dim=0) - 0.5 182 | x_embed = grid.cumsum(dim=1) - 0.5 183 | y_embed = y_embed / h 184 | x_embed = x_embed / w 185 | 186 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 187 | return pe.permute(2, 0, 1) # C x H x W 188 | 189 | def forward_with_coords( 190 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 191 | ) -> torch.Tensor: 192 | """Positionally encode points that are not normalized to [0,1].""" 193 | coords = coords_input.clone() 194 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 195 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 196 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 197 | -------------------------------------------------------------------------------- /models/regression.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class UpsamplingLayer(nn.Module): 5 | 6 | def __init__(self, in_channels, out_channels, leaky=True): 7 | 8 | super(UpsamplingLayer, self).__init__() 9 | 10 | self.layer = nn.Sequential( 11 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 12 | nn.LeakyReLU() if leaky else nn.ReLU(), 13 | nn.UpsamplingBilinear2d(scale_factor=2) 14 | ) 15 | 16 | def forward(self, x): 17 | return self.layer(x) 18 | 19 | 
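Looking back at `models/prompt_encoder.py` above: GeCo uses `PromptEncoder_DQE` only for its random-Fourier positional encoding (`get_dense_pe`). A short sketch of the shapes involved, assuming the repository is importable and a 1024x1024 input with a 16-pixel patch size (i.e. a 64x64 embedding grid):

```python
# Sketch: the dense positional encoding that GeCo feeds into DQE
# (assumes the GeCo repository is importable; sizes follow a 1024x1024 input).
import torch
from models.prompt_encoder import PromptEncoder_DQE

enc = PromptEncoder_DQE(
    embed_dim=256,
    image_embedding_size=(64, 64),     # 1024 / 16 patch size
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
dense_pe = enc.get_dense_pe()
print(dense_pe.shape)                  # torch.Size([1, 256, 64, 64])

# The same random-Fourier layer can also encode explicit pixel coordinates,
# e.g. box corners (the values here are arbitrary).
corners = torch.tensor([[[100.0, 150.0], [300.0, 400.0]]])          # B x N x 2
print(enc.pe_layer.forward_with_coords(corners, (1024, 1024)).shape)  # torch.Size([1, 2, 256])
```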
-------------------------------------------------------------------------------- /models/sam_ViT.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d, MLPBlock 14 | 15 | 16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 17 | class ImageEncoderViT(nn.Module): 18 | def __init__( 19 | self, 20 | img_size: int = 1024, 21 | patch_size: int = 16, 22 | in_chans: int = 3, 23 | embed_dim: int = 1280, 24 | depth: int = 32, 25 | num_heads: int = 16, 26 | mlp_ratio: float = 4.0, 27 | out_chans: int = 256, 28 | qkv_bias: bool = True, 29 | norm_layer: Type[nn.Module] = nn.LayerNorm, 30 | act_layer: Type[nn.Module] = nn.GELU, 31 | use_abs_pos: bool = True, 32 | use_rel_pos: bool = False, 33 | rel_pos_zero_init: bool = True, 34 | window_size: int = 0, 35 | global_attn_indexes: Tuple[int, ...] = (), 36 | ) -> None: 37 | """ 38 | Args: 39 | img_size (int): Input image size. 40 | patch_size (int): Patch size. 41 | in_chans (int): Number of input image channels. 42 | embed_dim (int): Patch embedding dimension. 43 | depth (int): Depth of ViT. 44 | num_heads (int): Number of attention heads in each ViT block. 45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 46 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 47 | norm_layer (nn.Module): Normalization layer. 48 | act_layer (nn.Module): Activation layer. 49 | use_abs_pos (bool): If True, use absolute positional embeddings. 50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 52 | window_size (int): Window size for window attention blocks. 53 | global_attn_indexes (list): Indexes for blocks using global attention. 54 | """ 55 | super().__init__() 56 | self.img_size = img_size 57 | 58 | self.patch_embed = PatchEmbed( 59 | kernel_size=(patch_size, patch_size), 60 | stride=(patch_size, patch_size), 61 | in_chans=in_chans, 62 | embed_dim=embed_dim, 63 | ) 64 | 65 | self.pos_embed: Optional[nn.Parameter] = None 66 | if use_abs_pos: 67 | # Initialize absolute positional embedding with pretrain image size. 
68 | self.pos_embed = nn.Parameter( 69 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 70 | ) 71 | 72 | self.blocks = nn.ModuleList() 73 | for i in range(depth): 74 | block = Block( 75 | dim=embed_dim, 76 | num_heads=num_heads, 77 | mlp_ratio=mlp_ratio, 78 | qkv_bias=qkv_bias, 79 | norm_layer=norm_layer, 80 | act_layer=act_layer, 81 | use_rel_pos=use_rel_pos, 82 | rel_pos_zero_init=rel_pos_zero_init, 83 | window_size=window_size if i not in global_attn_indexes else 0, 84 | input_size=(img_size // patch_size, img_size // patch_size), 85 | ) 86 | self.blocks.append(block) 87 | 88 | self.neck = nn.Sequential( 89 | nn.Conv2d( 90 | embed_dim, 91 | out_chans, 92 | kernel_size=1, 93 | bias=False, 94 | ), 95 | LayerNorm2d(out_chans), 96 | nn.Conv2d( 97 | out_chans, 98 | out_chans, 99 | kernel_size=3, 100 | padding=1, 101 | bias=False, 102 | ), 103 | LayerNorm2d(out_chans), 104 | ) 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | x = self.patch_embed(x) 108 | if self.pos_embed is not None: 109 | x = x + self.pos_embed 110 | 111 | for blk in self.blocks: 112 | x = blk(x) 113 | 114 | x = self.neck(x.permute(0, 3, 1, 2)) 115 | 116 | return x 117 | 118 | 119 | class Block(nn.Module): 120 | """Transformer blocks with support of window attention and residual propagation blocks""" 121 | 122 | def __init__( 123 | self, 124 | dim: int, 125 | num_heads: int, 126 | mlp_ratio: float = 4.0, 127 | qkv_bias: bool = True, 128 | norm_layer: Type[nn.Module] = nn.LayerNorm, 129 | act_layer: Type[nn.Module] = nn.GELU, 130 | use_rel_pos: bool = False, 131 | rel_pos_zero_init: bool = True, 132 | window_size: int = 0, 133 | input_size: Optional[Tuple[int, int]] = None, 134 | ) -> None: 135 | """ 136 | Args: 137 | dim (int): Number of input channels. 138 | num_heads (int): Number of attention heads in each ViT block. 139 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 140 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 141 | norm_layer (nn.Module): Normalization layer. 142 | act_layer (nn.Module): Activation layer. 143 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 144 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 145 | window_size (int): Window size for window attention blocks. If it equals 0, then 146 | use global attention. 147 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 148 | positional parameter size. 
149 | """ 150 | super().__init__() 151 | self.norm1 = norm_layer(dim) 152 | self.attn = Attention( 153 | dim, 154 | num_heads=num_heads, 155 | qkv_bias=qkv_bias, 156 | use_rel_pos=use_rel_pos, 157 | rel_pos_zero_init=rel_pos_zero_init, 158 | input_size=input_size if window_size == 0 else (window_size, window_size), 159 | ) 160 | 161 | self.norm2 = norm_layer(dim) 162 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 163 | 164 | self.window_size = window_size 165 | 166 | def forward(self, x: torch.Tensor) -> torch.Tensor: 167 | shortcut = x 168 | x = self.norm1(x) 169 | # Window partition 170 | if self.window_size > 0: 171 | H, W = x.shape[1], x.shape[2] 172 | x, pad_hw = window_partition(x, self.window_size) 173 | 174 | x = self.attn(x) 175 | # Reverse window partition 176 | if self.window_size > 0: 177 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 178 | 179 | x = shortcut + x 180 | x = x + self.mlp(self.norm2(x)) 181 | 182 | return x 183 | 184 | 185 | class Attention(nn.Module): 186 | """Multi-head Attention block with relative position embeddings.""" 187 | 188 | def __init__( 189 | self, 190 | dim: int, 191 | num_heads: int = 8, 192 | qkv_bias: bool = True, 193 | use_rel_pos: bool = False, 194 | rel_pos_zero_init: bool = True, 195 | input_size: Optional[Tuple[int, int]] = None, 196 | ) -> None: 197 | """ 198 | Args: 199 | dim (int): Number of input channels. 200 | num_heads (int): Number of attention heads. 201 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 202 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 203 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 204 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 205 | positional parameter size. 206 | """ 207 | super().__init__() 208 | self.num_heads = num_heads 209 | head_dim = dim // num_heads 210 | self.scale = head_dim**-0.5 211 | 212 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 213 | self.proj = nn.Linear(dim, dim) 214 | 215 | self.use_rel_pos = use_rel_pos 216 | if self.use_rel_pos: 217 | assert ( 218 | input_size is not None 219 | ), "Input size must be provided if using relative positional encoding." 220 | # initialize relative positional embeddings 221 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 222 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 223 | 224 | def forward(self, x: torch.Tensor) -> torch.Tensor: 225 | B, H, W, _ = x.shape 226 | # qkv with shape (3, B, nHead, H * W, C) 227 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 228 | # q, k, v with shape (B * nHead, H * W, C) 229 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 230 | 231 | attn = (q * self.scale) @ k.transpose(-2, -1) 232 | 233 | if self.use_rel_pos: 234 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 235 | 236 | attn = attn.softmax(dim=-1) 237 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 238 | x = self.proj(x) 239 | 240 | return x 241 | 242 | 243 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 244 | """ 245 | Partition into non-overlapping windows with padding if needed. 246 | Args: 247 | x (tensor): input tokens with [B, H, W, C]. 248 | window_size (int): window size. 
249 | 250 | Returns: 251 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 252 | (Hp, Wp): padded height and width before partition 253 | """ 254 | B, H, W, C = x.shape 255 | 256 | pad_h = (window_size - H % window_size) % window_size 257 | pad_w = (window_size - W % window_size) % window_size 258 | if pad_h > 0 or pad_w > 0: 259 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 260 | Hp, Wp = H + pad_h, W + pad_w 261 | 262 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 263 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 264 | return windows, (Hp, Wp) 265 | 266 | 267 | def window_unpartition( 268 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 269 | ) -> torch.Tensor: 270 | """ 271 | Window unpartition into original sequences and removing padding. 272 | Args: 273 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 274 | window_size (int): window size. 275 | pad_hw (Tuple): padded height and width (Hp, Wp). 276 | hw (Tuple): original height and width (H, W) before padding. 277 | 278 | Returns: 279 | x: unpartitioned sequences with [B, H, W, C]. 280 | """ 281 | Hp, Wp = pad_hw 282 | H, W = hw 283 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 284 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 285 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 286 | 287 | if Hp > H or Wp > W: 288 | x = x[:, :H, :W, :].contiguous() 289 | return x 290 | 291 | 292 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 293 | """ 294 | Get relative positional embeddings according to the relative positions of 295 | query and key sizes. 296 | Args: 297 | q_size (int): size of query q. 298 | k_size (int): size of key k. 299 | rel_pos (Tensor): relative position embeddings (L, C). 300 | 301 | Returns: 302 | Extracted positional embeddings according to relative positions. 303 | """ 304 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 305 | # Interpolate rel pos if needed. 306 | if rel_pos.shape[0] != max_rel_dist: 307 | # Interpolate rel pos. 308 | rel_pos_resized = F.interpolate( 309 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 310 | size=max_rel_dist, 311 | mode="linear", 312 | ) 313 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 314 | else: 315 | rel_pos_resized = rel_pos 316 | 317 | # Scale the coords with short length if shapes for q and k are different. 318 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 319 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 320 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 321 | 322 | return rel_pos_resized[relative_coords.long()] 323 | 324 | 325 | def add_decomposed_rel_pos( 326 | attn: torch.Tensor, 327 | q: torch.Tensor, 328 | rel_pos_h: torch.Tensor, 329 | rel_pos_w: torch.Tensor, 330 | q_size: Tuple[int, int], 331 | k_size: Tuple[int, int], 332 | ) -> torch.Tensor: 333 | """ 334 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 335 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 336 | Args: 337 | attn (Tensor): attention map. 338 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 
339 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 340 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 341 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 342 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 343 | 344 | Returns: 345 | attn (Tensor): attention map with added relative positional embeddings. 346 | """ 347 | q_h, q_w = q_size 348 | k_h, k_w = k_size 349 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 350 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 351 | 352 | B, _, dim = q.shape 353 | r_q = q.reshape(B, q_h, q_w, dim) 354 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 355 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 356 | 357 | attn = ( 358 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 359 | ).view(B, q_h * q_w, k_h * k_w) 360 | 361 | return attn 362 | 363 | 364 | class PatchEmbed(nn.Module): 365 | """ 366 | Image to Patch Embedding. 367 | """ 368 | 369 | def __init__( 370 | self, 371 | kernel_size: Tuple[int, int] = (16, 16), 372 | stride: Tuple[int, int] = (16, 16), 373 | padding: Tuple[int, int] = (0, 0), 374 | in_chans: int = 3, 375 | embed_dim: int = 768, 376 | ) -> None: 377 | """ 378 | Args: 379 | kernel_size (Tuple): kernel size of the projection layer. 380 | stride (Tuple): stride of the projection layer. 381 | padding (Tuple): padding size of the projection layer. 382 | in_chans (int): Number of input image channels. 383 | embed_dim (int): Patch embedding dimension. 384 | """ 385 | super().__init__() 386 | 387 | self.proj = nn.Conv2d( 388 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 389 | ) 390 | 391 | def forward(self, x: torch.Tensor) -> torch.Tensor: 392 | x = self.proj(x) 393 | # B C H W -> B H W C 394 | x = x.permute(0, 2, 3, 1) 395 | return x 396 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import Tensor 5 | from typing import Tuple 6 | from torch import nn 7 | 8 | 9 | class SelfCrossAttentionBlock(nn.Module): 10 | def __init__( 11 | self, 12 | embedding_dim: int, 13 | num_heads: int, 14 | ) -> None: 15 | """ 16 | """ 17 | super().__init__() 18 | self.self_attention = Attention(embedding_dim, num_heads) 19 | self.cross_attention = Attention(embedding_dim, num_heads) 20 | self.norm1 = nn.LayerNorm(embedding_dim) 21 | self.norm2 = nn.LayerNorm(embedding_dim) 22 | 23 | def forward( 24 | self, image_f: Tensor, adapted_image_f: Tensor, pos_enc: Tensor, 25 | ) -> Tuple[Tensor, Tensor]: 26 | adapted_image_f = adapted_image_f + self.self_attention(q=adapted_image_f + pos_enc, 27 | k=adapted_image_f + pos_enc, 28 | v=adapted_image_f + pos_enc) 29 | adapted_image_f = self.norm1(adapted_image_f) 30 | adapted_image_f = adapted_image_f + self.cross_attention(q=adapted_image_f + pos_enc, 31 | k=image_f + pos_enc, 32 | v=image_f + pos_enc) 33 | adapted_image_f = self.norm2(adapted_image_f) 34 | return adapted_image_f 35 | 36 | 37 | class PrototypeAttentionBlock(nn.Module): 38 | def __init__( 39 | self, 40 | embedding_dim: int, 41 | num_heads: int, 42 | ) -> None: 43 | """ 44 | """ 45 | super().__init__() 46 | self.cross_attention = Attention(embedding_dim, num_heads) 47 | self.norm = nn.LayerNorm(embedding_dim) 48 | 49 | def forward( 50 | self, image_f: Tensor, prototypes: Tensor, 51 | ) -> Tuple[Tensor, Tensor]: 52 | 
image_f = image_f + self.cross_attention(q=image_f, 53 | k=prototypes, 54 | v=prototypes) 55 | image_f = self.norm(image_f) 56 | return image_f 57 | 58 | class ImgToPrototypeAttentionBlock(nn.Module): 59 | def __init__( 60 | self, 61 | embedding_dim: int, 62 | num_heads: int, 63 | ) -> None: 64 | """ 65 | """ 66 | super().__init__() 67 | self.cross_attention = Attention(embedding_dim, num_heads) 68 | self.norm = nn.LayerNorm(embedding_dim) 69 | 70 | def forward( 71 | self, image_f: Tensor, prototypes: Tensor, 72 | ) -> Tuple[Tensor, Tensor]: 73 | 74 | prototypes = prototypes + self.cross_attention(q=prototypes, 75 | k=image_f, 76 | v=image_f) 77 | prototypes = self.norm(prototypes) 78 | return prototypes 79 | 80 | 81 | 82 | class Attention(nn.Module): 83 | """ 84 | An attention layer that allows for downscaling the size of the embedding 85 | after projection to queries, keys, and values. 86 | """ 87 | 88 | def __init__( 89 | self, 90 | embedding_dim: int, 91 | num_heads: int, 92 | downsample_rate: int = 1, 93 | ) -> None: 94 | super().__init__() 95 | self.embedding_dim = embedding_dim 96 | self.internal_dim = embedding_dim // downsample_rate 97 | self.num_heads = num_heads 98 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 99 | 100 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 101 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 102 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 103 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 104 | 105 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 106 | b, n, c = x.shape 107 | x = x.reshape(b, n, num_heads, c // num_heads) 108 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 109 | 110 | def _recombine_heads(self, x: Tensor) -> Tensor: 111 | b, n_heads, n_tokens, c_per_head = x.shape 112 | x = x.transpose(1, 2) 113 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 114 | 115 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 116 | # Input projections 117 | q = self.q_proj(q) 118 | k = self.k_proj(k) 119 | v = self.v_proj(v) 120 | 121 | # Separate into heads 122 | q = self._separate_heads(q, self.num_heads) 123 | k = self._separate_heads(k, self.num_heads) 124 | v = self._separate_heads(v, self.num_heads) 125 | 126 | # Attention 127 | _, _, _, c_per_head = q.shape 128 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 129 | attn = attn / math.sqrt(c_per_head) 130 | attn = torch.softmax(attn, dim=-1) 131 | 132 | # Get output 133 | out = attn @ v 134 | out = self._recombine_heads(out) 135 | out = self.out_proj(out) 136 | 137 | return out 138 | -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | from models.geco import build_model 2 | from utils.box_ops import compute_location, BoxList 3 | from utils.data import FSC147Dataset 4 | from utils.arg_parser import get_argparser 5 | from utils.losses import ObjectNormalizedL2Loss, Detection_criterion 6 | from time import perf_counter 7 | import argparse 8 | import os 9 | 10 | import torch 11 | from torch import nn 12 | from torch.utils.data import DataLoader, DistributedSampler 13 | from torch.nn.parallel import DistributedDataParallel 14 | from torch import distributed as dist 15 | from utils.data import pad_collate 16 | import numpy as np 17 | import random 18 | 19 | 20 | torch.manual_seed(0) 21 | 
random.seed(0) 22 | np.random.seed(0) 23 | 24 | DATASETS = { 25 | 'fsc147': FSC147Dataset 26 | } 27 | 28 | 29 | def train(args): 30 | if 'SLURM_PROCID' in os.environ: 31 | world_size = int(os.environ['SLURM_NTASKS']) 32 | rank = int(os.environ['SLURM_PROCID']) 33 | gpu = rank % torch.cuda.device_count() 34 | print("Running on SLURM", world_size, rank, gpu) 35 | else: 36 | world_size = int(os.environ['WORLD_SIZE']) 37 | rank = int(os.environ['RANK']) 38 | gpu = int(os.environ['LOCAL_RANK']) 39 | 40 | torch.cuda.set_device(gpu) 41 | device = torch.device(gpu) 42 | 43 | dist.init_process_group( 44 | backend='nccl', init_method='env://', 45 | world_size=world_size, rank=rank 46 | ) 47 | 48 | model = DistributedDataParallel( 49 | build_model(args).to(device), 50 | device_ids=[gpu], 51 | output_device=gpu 52 | ) 53 | 54 | backbone_params = dict() 55 | non_backbone_params = dict() 56 | for n, p in model.named_parameters(): 57 | if 'backbone' in n: 58 | backbone_params[n] = p 59 | else: 60 | non_backbone_params[n] = p 61 | 62 | optimizer = torch.optim.AdamW( 63 | [ 64 | {'params': non_backbone_params.values()}, 65 | {'params': backbone_params.values(), 'lr': args.backbone_lr} 66 | ], 67 | lr=args.lr, 68 | weight_decay=args.weight_decay, 69 | ) 70 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop, gamma=0.25) 71 | if args.resume_training: 72 | checkpoint = torch.load(os.path.join(args.model_path, f'{args.model_name}.pth')) 73 | model.load_state_dict(checkpoint['model']) 74 | start_epoch = checkpoint['epoch'] 75 | best = checkpoint['best_val_ae'] 76 | optimizer.load_state_dict(checkpoint['optimizer']) 77 | scheduler.load_state_dict(checkpoint['scheduler']) 78 | else: 79 | start_epoch = 0 80 | best = 10000000000000 81 | 82 | criterion = ObjectNormalizedL2Loss() 83 | det_criterion = Detection_criterion( 84 | [[-1, 512]], # sizes, 85 | 'giou', # iou_loss_type, 86 | True, # center_sample, 87 | [1], # fpn_strides, 88 | 30, # pos_radius, 89 | ) 90 | 91 | train = DATASETS[args.dataset]( 92 | args.data_path, 93 | args.image_size, 94 | split='train', 95 | num_objects=args.num_objects, 96 | tiling_p=args.tiling_p, 97 | zero_shot=args.zero_shot 98 | ) 99 | val = DATASETS[args.dataset]( 100 | args.data_path, 101 | args.image_size, 102 | split='val', 103 | num_objects=args.num_objects, 104 | tiling_p=args.tiling_p 105 | ) 106 | train_loader = DataLoader( 107 | train, 108 | sampler=DistributedSampler(train), 109 | batch_size=args.batch_size, 110 | drop_last=True, 111 | num_workers=args.num_workers, 112 | collate_fn=pad_collate 113 | ) 114 | val_loader = DataLoader( 115 | val, 116 | sampler=DistributedSampler(val), 117 | batch_size=args.batch_size, 118 | drop_last=False, 119 | num_workers=args.num_workers, 120 | collate_fn=pad_collate 121 | ) 122 | 123 | print(rank) 124 | for epoch in range(start_epoch + 1, args.epochs + 1): 125 | if rank == 0: 126 | start = perf_counter() 127 | train_loss = torch.tensor(0.0).to(device) 128 | val_loss = torch.tensor(0.0).to(device) 129 | train_ae = torch.tensor(0.0).to(device) 130 | val_ae = torch.tensor(0.0).to(device) 131 | val_rmse = torch.tensor(0.0).to(device) 132 | 133 | train_loader.sampler.set_epoch(epoch) 134 | model.train() 135 | 136 | for img, bboxes, img_name, gt_bboxes, density_map in train_loader: 137 | img = img.to(device) 138 | bboxes = bboxes.to(device) 139 | density_map = density_map.to(device) 140 | 141 | optimizer.zero_grad() 142 | _,_,centerness, lrtb = model(img, bboxes) 143 | 144 | lrtb = lrtb * 512 145 | location = compute_location(lrtb) 
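# lrtb is rescaled to the 512x512 training resolution above; compute_location
# (imported from utils.box_ops) produces the per-location reference points
# that are passed to Detection_criterion together with lrtb and the resized
# ground-truth boxes built below.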
146 | targets = BoxList(gt_bboxes, (args.image_size, args.image_size), mode='xyxy').to(device).resize( 147 | (512, 512)) 148 | 149 | # obtain the number of objects in batch 150 | with torch.no_grad(): 151 | num_objects = density_map.sum() 152 | dist.all_reduce(num_objects) 153 | det_loss = det_criterion(location, lrtb, targets) / num_objects 154 | main_loss = criterion(centerness, density_map, num_objects) 155 | 156 | loss = main_loss + det_loss 157 | loss.backward() 158 | if args.max_grad_norm > 0: 159 | nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 160 | optimizer.step() 161 | 162 | train_loss += main_loss * img.size(0) 163 | train_ae += torch.abs( 164 | density_map.flatten(1).sum(dim=1) - centerness.flatten(1).sum(dim=1) 165 | ).sum() 166 | 167 | model.eval() 168 | with torch.no_grad(): 169 | index_val = 0 170 | for img, bboxes, img_name, gt_bboxes, density_map in val_loader: 171 | img = img.to(device) 172 | bboxes = bboxes.to(device) 173 | density_map = density_map.to(device) 174 | 175 | optimizer.zero_grad() 176 | 177 | _,_,centerness, lrtb = model(img, bboxes) 178 | 179 | lrtb = lrtb * 512 180 | location = compute_location(lrtb) 181 | targets = BoxList(gt_bboxes, (args.image_size, args.image_size), mode='xyxy').to(device).resize( 182 | (512, 512)) 183 | 184 | # obtain the number of objects in batch 185 | with torch.no_grad(): 186 | num_objects = density_map.sum() 187 | dist.all_reduce(num_objects) 188 | det_loss = det_criterion(location, lrtb, targets) 189 | main_loss = criterion(centerness, density_map, num_objects) 190 | 191 | loss = main_loss + det_loss 192 | val_loss += loss 193 | val_ae += torch.abs( 194 | density_map.flatten(1).sum(dim=1) - centerness.flatten(1).sum(dim=1) 195 | ).sum() 196 | val_rmse += torch.pow( 197 | density_map.flatten(1).sum(dim=1) - centerness.flatten(1).sum(dim=1), 2 198 | ).sum() 199 | 200 | if args.max_grad_norm > 0: 201 | nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 202 | optimizer.step() 203 | 204 | dist.all_reduce(train_loss) 205 | dist.all_reduce(val_loss) 206 | dist.all_reduce(val_rmse) 207 | dist.all_reduce(train_ae) 208 | dist.all_reduce(val_ae) 209 | 210 | scheduler.step() 211 | 212 | if rank == 0: 213 | end = perf_counter() 214 | best_epoch = False 215 | if val_rmse.item() / len(val) < best: 216 | best = val_rmse.item() / len(val) 217 | checkpoint = { 218 | 'epoch': epoch, 219 | 'model': model.state_dict(), 220 | 'optimizer': optimizer.state_dict(), 221 | 'scheduler': scheduler.state_dict(), 222 | 'best_val_ae': val_ae.item() / len(val) 223 | } 224 | 225 | torch.save( 226 | checkpoint, 227 | os.path.join(args.model_path, f'{args.model_name}.pth') 228 | ) 229 | best_epoch = True 230 | torch.save( 231 | checkpoint, 232 | os.path.join(args.model_path, f'{args.model_name}_last.pth') 233 | ) 234 | 235 | print( 236 | f"Epoch: {epoch}", 237 | f"Train loss: {train_loss.item():.3f}", 238 | f"Val loss: {val_loss.item():.3f}", 239 | f"Train MAE: {train_ae.item() / len(train):.3f}", 240 | f"Val MAE: {val_ae.item() / len(val):.3f}", 241 | f"Val RMSE: {torch.sqrt(val_rmse / len(val)).item():.2f}", 242 | f"Epoch time: {end - start:.3f} seconds", 243 | 'best' if best_epoch else '' 244 | ) 245 | 246 | dist.destroy_process_group() 247 | 248 | 249 | if __name__ == '__main__': 250 | parser = argparse.ArgumentParser('GeCo', parents=[get_argparser()]) 251 | args = parser.parse_args() 252 | print(args) 253 | train(args) 254 | -------------------------------------------------------------------------------- /pretrain.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GeCo 3 | #SBATCH --output=train/GeCo_pretrain_%j.txt 4 | #SBATCH --error=train/GeCo_pretrain_%j.txt 5 | #SBATCH --nodes=1 6 | #SBATCH --ntasks-per-node=2 7 | #SBATCH --cpus-per-task=12 8 | #SBATCH --partition=gpu 9 | #SBATCH --gres=gpu:2 10 | #SBATCH --time=4-00:00:00 11 | 12 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 13 | export MASTER_ADDR=$master_addr 14 | export MASTER_PORT=50197 15 | export NCCL_P2P_DISABLE=1 16 | export NCCL_IB_DISABLE=1 17 | 18 | module load Anaconda3 19 | source activate cotr 20 | conda activate base 21 | conda activate cotr 22 | 23 | srun --unbuffered python pretrain.py \ 24 | --model_name GeCo_PRETRAIN \ 25 | --data_path /d/hpc/projects/FRI/pelhanj/fsc147 \ 26 | --epochs 150 \ 27 | --lr 1e-4 \ 28 | --backbone_lr 0 \ 29 | --lr_drop 150 \ 30 | --weight_decay 1e-4 \ 31 | --batch_size 4 \ 32 | --tiling_p 0.2 -------------------------------------------------------------------------------- /segment_anything/LICENSE(1): -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
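# Usage sketch (illustrative; the checkpoint filename is only an example and
# SamPredictor's API is taken from predictor.py): the builders re-exported
# below construct a Sam model and optionally load pretrained weights. Note
# that _build_sam in build_sam.py currently reads the weights from a
# hard-coded local path whenever a checkpoint is requested.
#
#     from segment_anything import sam_model_registry, SamPredictor
#     sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
#     predictor = SamPredictor(sam)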
6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open('/home/jer/PycharmProjects/segment-anything/sam_vit_h_4b8939.pth', "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
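Note: unlike the reference SAM decoder, forward() in this copy selects and
returns only the single mask with the highest predicted IoU (see the
argmax over iou_pred in forward()), so multimask_output has no effect.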
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | #find max value in the tensor iou_pred, in the second dimension, and use masks with the same index 101 | ids = torch.argmax(iou_pred, dim=1) 102 | masks = masks[torch.arange(masks.size(0)), ids].unsqueeze(0).permute(1,0,2,3) 103 | iou_pred = iou_pred[torch.arange(iou_pred.size(0)), ids].unsqueeze(0).permute(1,0) 104 | 105 | # # Select the correct mask or masks for output 106 | # if multimask_output: 107 | # mask_slice = slice(1, None) 108 | # else: 109 | # mask_slice = slice(0, 1) 110 | # masks = masks[:, mask_slice, :, :] 111 | # iou_pred = iou_pred[:, mask_slice] 112 | 113 | # Prepare output 114 | return masks, iou_pred 115 | 116 | def predict_masks( 117 | self, 118 | image_embeddings: torch.Tensor, 119 | image_pe: torch.Tensor, 120 | sparse_prompt_embeddings: torch.Tensor, 121 | dense_prompt_embeddings: torch.Tensor, 122 | ) -> Tuple[torch.Tensor, torch.Tensor]: 123 | """Predicts masks. 
See 'forward' for more details.""" 124 | # Concatenate output tokens 125 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 126 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 127 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 128 | 129 | # Expand per-image data in batch direction to be per-mask 130 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 131 | if dense_prompt_embeddings.shape[1:] != image_embeddings.shape[1:]: 132 | upsample_pos_emb = nn.UpsamplingBilinear2d(scale_factor=1.5) 133 | dense_prompt_embeddings = upsample_pos_emb(dense_prompt_embeddings) 134 | src = src + dense_prompt_embeddings 135 | if image_pe.shape[1:] != image_embeddings.shape[1:]: 136 | upsample_pos_emb = nn.UpsamplingBilinear2d(scale_factor=1.5) 137 | image_pe = upsample_pos_emb(image_pe) 138 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 139 | 140 | 141 | b, c, h, w = src.shape 142 | 143 | # Run the transformer 144 | hs, src = self.transformer(src, pos_src, tokens) 145 | iou_token_out = hs[:, 0, :] 146 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 147 | 148 | # Upscale mask embeddings and predict masks using the mask tokens 149 | src = src.transpose(1, 2).view(b, c, h, w) 150 | upscaled_embedding = self.output_upscaling(src) 151 | hyper_in_list: List[torch.Tensor] = [] 152 | for i in range(self.num_mask_tokens): 153 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 154 | hyper_in = torch.stack(hyper_in_list, dim=1) 155 | b, c, h, w = upscaled_embedding.shape 156 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 157 | 158 | # Generate mask quality predictions 159 | iou_pred = self.iou_prediction_head(iou_token_out) 160 | 161 | return masks, iou_pred 162 | 163 | 164 | # Lightly adapted from 165 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 166 | class MLP(nn.Module): 167 | def __init__( 168 | self, 169 | input_dim: int, 170 | hidden_dim: int, 171 | output_dim: int, 172 | num_layers: int, 173 | sigmoid_output: bool = False, 174 | ) -> None: 175 | super().__init__() 176 | self.num_layers = num_layers 177 | h = [hidden_dim] * (num_layers - 1) 178 | self.layers = nn.ModuleList( 179 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 180 | ) 181 | self.sigmoid_output = sigmoid_output 182 | 183 | def forward(self, x): 184 | for i, layer in enumerate(self.layers): 185 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 186 | if self.sigmoid_output: 187 | x = F.sigmoid(x) 188 | return x 189 | -------------------------------------------------------------------------------- /segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
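# Usage sketch (illustrative shapes, using the settings from build_sam.py:
# embed_dim=256, image_embedding_size=(64, 64), input_image_size=(1024, 1024)):
#
#     pe = PromptEncoder(embed_dim=256, image_embedding_size=(64, 64),
#                        input_image_size=(1024, 1024), mask_in_chans=16)
#     boxes = torch.rand(4, 4) * 1024      # random values, shapes only
#     sparse, dense = pe(points=None, boxes=boxes, masks=None)
#     # sparse: (4, 2, 256) box-corner embeddings; dense: (4, 256, 64, 64)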
6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 
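With the build_sam.py settings this is a (1, 256, 64, 64) tensor matching
image_embedding_size.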
66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 
148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
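# Usage sketch (illustrative; key names follow the forward() docstring below):
#
#     batched_input = [{
#         "image": image_3hw,              # 3xHxW tensor, already transformed
#         "original_size": (orig_h, orig_w),
#         "boxes": boxes_xyxy,             # Bx4, in the model's input frame
#     }]
#     outputs = sam(batched_input, multimask_output=False)
#     masks = outputs[0]["masks"]          # BxCxHxW boolean masks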
6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. 
Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
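# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original sources): a self-contained
# restatement of what Sam.preprocess above does — per-channel normalization
# with the pixel_mean/pixel_std registered in Sam.__init__, followed by
# zero-padding on the right and bottom so the tensor becomes a square of side
# image_encoder.img_size. The 683 x 1024 input is an arbitrary example of an
# image whose long side has already been resized to 1024.
import torch
import torch.nn.functional as F

img_size = 1024
pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

x = torch.rand(3, 683, 1024) * 255.0               # 3 x H x W image, values in [0, 255]
x = (x - pixel_mean) / pixel_std                   # normalize colors
h, w = x.shape[-2:]
x = F.pad(x, (0, img_size - w, 0, img_size - h))   # pad (left=0, right, top=0, bottom)
assert x.shape[-2:] == (img_size, img_size)
# ---------------------------------------------------------------------------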
6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /segment_anything/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from segment_anything.modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. 
Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | # and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array given a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded masks logits 122 | instead of a binary mask. 
123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 142 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 146 | if box is not None: 147 | box = self.transform.apply_boxes(box, self.original_size) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 149 | box_torch = box_torch[None, :] 150 | if mask_input is not None: 151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 152 | mask_input_torch = mask_input_torch[None, :, :, :] 153 | 154 | masks, iou_predictions, low_res_masks = self.predict_torch( 155 | coords_torch, 156 | labels_torch, 157 | box_torch, 158 | mask_input_torch, 159 | multimask_output, 160 | return_logits=return_logits, 161 | ) 162 | 163 | masks_np = masks[0].detach().cpu().numpy() 164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy() 165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy() 166 | return masks_np, iou_predictions_np, low_res_masks_np 167 | 168 | @torch.no_grad() 169 | def predict_torch( 170 | self, 171 | point_coords: Optional[torch.Tensor], 172 | point_labels: Optional[torch.Tensor], 173 | boxes: Optional[torch.Tensor] = None, 174 | mask_input: Optional[torch.Tensor] = None, 175 | multimask_output: bool = True, 176 | return_logits: bool = False, 177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | """ 179 | Predict masks for the given input prompts, using the currently set image. 180 | Input prompts are batched torch tensors and are expected to already be 181 | transformed to the input frame using ResizeLongestSide. 182 | 183 | Arguments: 184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 185 | model. Each point is in (X,Y) in pixels. 186 | point_labels (torch.Tensor or None): A BxN array of labels for the 187 | point prompts. 1 indicates a foreground point and 0 indicates a 188 | background point. 189 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 190 | model, in XYXY format. 191 | mask_input (np.ndarray): A low resolution mask input to the model, typically 192 | coming from a previous prediction iteration. Has form Bx1xHxW, where 193 | for SAM, H=W=256. Masks returned by a previous iteration of the 194 | predict method do not need further transformation. 195 | multimask_output (bool): If true, the model will return three masks. 
196 | For ambiguous input prompts (such as a single click), this will often 197 | produce better masks than a single prediction. If only a single 198 | mask is needed, the model's predicted quality score can be used 199 | to select the best mask. For non-ambiguous prompts, such as multiple 200 | input prompts, multimask_output=False can give better results. 201 | return_logits (bool): If true, returns un-thresholded masks logits 202 | instead of a binary mask. 203 | 204 | Returns: 205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 206 | number of masks, and (H, W) is the original image size. 207 | (torch.Tensor): An array of shape BxC containing the model's 208 | predictions for the quality of each mask. 209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 210 | of masks and H=W=256. These low res logits can be passed to 211 | a subsequent iteration as mask input. 212 | """ 213 | if not self.is_image_set: 214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 215 | 216 | if point_coords is not None: 217 | points = (point_coords, point_labels) 218 | else: 219 | points = None 220 | 221 | # Embed prompts 222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 223 | points=points, 224 | boxes=boxes, 225 | masks=mask_input, 226 | ) 227 | 228 | # Predict masks 229 | low_res_masks, iou_predictions = self.model.mask_decoder( 230 | image_embeddings=self.features, 231 | image_pe=self.model.prompt_encoder.get_dense_pe(), 232 | sparse_prompt_embeddings=sparse_embeddings, 233 | dense_prompt_embeddings=dense_embeddings, 234 | multimask_output=multimask_output, 235 | ) 236 | 237 | # Upscale the masks to the original image resolution 238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 239 | 240 | if not return_logits: 241 | masks = masks > self.model.mask_threshold 242 | 243 | return masks, iou_predictions, low_res_masks 244 | 245 | def get_image_embedding(self) -> torch.Tensor: 246 | """ 247 | Returns the image embeddings for the currently set image, with 248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 249 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 250 | """ 251 | if not self.is_image_set: 252 | raise RuntimeError( 253 | "An image must be set with .set_image(...) to generate an embedding." 254 | ) 255 | assert self.features is not None, "Features must exist if an image has been set." 256 | return self.features 257 | 258 | @property 259 | def device(self) -> torch.device: 260 | return self.model.device 261 | 262 | def reset_image(self) -> None: 263 | """Resets the currently set image.""" 264 | self.is_image_set = False 265 | self.features = None 266 | self.orig_h = None 267 | self.orig_w = None 268 | self.input_h = None 269 | self.input_w = None 270 | -------------------------------------------------------------------------------- /segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /segment_anything/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
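# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original sources): typical use of the
# SamPredictor class defined in predictor.py above. How the `sam` model is
# built is an assumption — the builder names below mirror the upstream SAM
# release (segment_anything/build_sam.py), and the checkpoint filename is
# hypothetical; this repository's build_sam.py is not shown in this part of
# the listing.
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h.pth")   # assumed builder + weights file
predictor = SamPredictor(sam)

image = np.zeros((768, 1024, 3), dtype=np.uint8)                # HWC uint8 image in RGB
predictor.set_image(image)                                      # computes the image embedding once
masks, scores, low_res_logits = predictor.predict(
    point_coords=np.array([[500, 375]]),                        # one foreground click, (X, Y) pixels
    point_labels=np.array([1]),
    multimask_output=True,
)
# masks: CxHxW bool, scores: length-C quality estimates, low_res_logits: Cx256x256
# ---------------------------------------------------------------------------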
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - 
box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 
165 | # Save memory by preventing unnecessary cast to torch.int64 166 | intersections = ( 167 | (masks > (mask_threshold + threshold_offset)) 168 | .sum(-1, dtype=torch.int16) 169 | .sum(-1, dtype=torch.int32) 170 | ) 171 | unions = ( 172 | (masks > (mask_threshold - threshold_offset)) 173 | .sum(-1, dtype=torch.int16) 174 | .sum(-1, dtype=torch.int32) 175 | ) 176 | return intersections / unions 177 | 178 | 179 | def build_point_grid(n_per_side: int) -> np.ndarray: 180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 181 | offset = 1 / (2 * n_per_side) 182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 186 | return points 187 | 188 | 189 | def build_all_layer_point_grids( 190 | n_per_side: int, n_layers: int, scale_per_layer: int 191 | ) -> List[np.ndarray]: 192 | """Generates point grids for all crop layers.""" 193 | points_by_layer = [] 194 | for i in range(n_layers + 1): 195 | n_points = int(n_per_side / (scale_per_layer**i)) 196 | points_by_layer.append(build_point_grid(n_points)) 197 | return points_by_layer 198 | 199 | 200 | def generate_crop_boxes( 201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 202 | ) -> Tuple[List[List[int]], List[int]]: 203 | """ 204 | Generates a list of crop boxes of different sizes. Each layer 205 | has (2**i)**2 boxes for the ith layer. 206 | """ 207 | crop_boxes, layer_idxs = [], [] 208 | im_h, im_w = im_size 209 | short_side = min(im_h, im_w) 210 | 211 | # Original image 212 | crop_boxes.append([0, 0, im_w, im_h]) 213 | layer_idxs.append(0) 214 | 215 | def crop_len(orig_len, n_crops, overlap): 216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 217 | 218 | for i_layer in range(n_layers): 219 | n_crops_per_side = 2 ** (i_layer + 1) 220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 221 | 222 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 223 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 224 | 225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 227 | 228 | # Crops in XYWH format 229 | for x0, y0 in product(crop_box_x0, crop_box_y0): 230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 231 | crop_boxes.append(box) 232 | layer_idxs.append(i_layer + 1) 233 | 234 | return crop_boxes, layer_idxs 235 | 236 | 237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 238 | x0, y0, _, _ = crop_box 239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 240 | # Check if boxes has a channel dimension 241 | if len(boxes.shape) == 3: 242 | offset = offset.unsqueeze(1) 243 | return boxes + offset 244 | 245 | 246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 247 | x0, y0, _, _ = crop_box 248 | offset = torch.tensor([[x0, y0]], device=points.device) 249 | # Check if points has a channel dimension 250 | if len(points.shape) == 3: 251 | offset = offset.unsqueeze(1) 252 | return points + offset 253 | 254 | 255 | def uncrop_masks( 256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 257 | ) -> torch.Tensor: 258 | x0, y0, x1, y1 = crop_box 259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 260 | return masks 261 | # Coordinate 
transform masks 262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 263 | pad = (x0, pad_x - x0, y0, pad_y - y0) 264 | return torch.nn.functional.pad(masks, pad, value=0) 265 | 266 | 267 | def remove_small_regions( 268 | mask: np.ndarray, area_thresh: float, mode: str 269 | ) -> Tuple[np.ndarray, bool]: 270 | """ 271 | Removes small disconnected regions and holes in a mask. Returns the 272 | mask and an indicator of if the mask has been modified. 273 | """ 274 | import cv2 # type: ignore 275 | 276 | assert mode in ["holes", "islands"] 277 | correct_holes = mode == "holes" 278 | working_mask = (correct_holes ^ mask).astype(np.uint8) 279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 280 | sizes = stats[:, -1][1:] # Row 0 is background label 281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 282 | if len(small_regions) == 0: 283 | return mask, False 284 | fill_labels = [0] + small_regions 285 | if not correct_holes: 286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 287 | # If every region is below threshold, keep largest 288 | if len(fill_labels) == 0: 289 | fill_labels = [int(np.argmax(sizes)) + 1] 290 | mask = np.isin(regions, fill_labels) 291 | return mask, True 292 | 293 | 294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 295 | from pycocotools import mask as mask_utils # type: ignore 296 | 297 | h, w = uncompressed_rle["size"] 298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 300 | return rle 301 | 302 | 303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 307 | """ 308 | # torch.max below raises an error on empty inputs, just skip in this case 309 | if torch.numel(masks) == 0: 310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 311 | 312 | # Normalize shape to CxHxW 313 | shape = masks.shape 314 | h, w = shape[-2:] 315 | if len(shape) > 2: 316 | masks = masks.flatten(0, -3) 317 | else: 318 | masks = masks.unsqueeze(0) 319 | 320 | # Get top and bottom edges 321 | in_height, _ = torch.max(masks, dim=-1) 322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 324 | in_height_coords = in_height_coords + h * (~in_height) 325 | top_edges, _ = torch.min(in_height_coords, dim=-1) 326 | 327 | # Get left and right edges 328 | in_width, _ = torch.max(masks, dim=-2) 329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 330 | right_edges, _ = torch.max(in_width_coords, dim=-1) 331 | in_width_coords = in_width_coords + w * (~in_width) 332 | left_edges, _ = torch.min(in_width_coords, dim=-1) 333 | 334 | # If the mask is empty the right edge will be to the left of the left edge. 
335 | # Replace these boxes with [0, 0, 0, 0] 336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 338 | out = out * (~empty_filter).unsqueeze(-1) 339 | 340 | # Return to original shape 341 | if len(shape) > 2: 342 | out = out.reshape(*shape[:-2], 4) 343 | else: 344 | out = out[0] 345 | 346 | return out 347 | -------------------------------------------------------------------------------- /segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: 
torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. 
Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from models.geco import build_model 2 | from models.matcher import build_matcher 3 | from torchvision import ops 4 | from utils.data import FSC147Dataset 5 | from utils.arg_parser import get_argparser 6 | from utils.losses import SetCriterion 7 | from time import perf_counter 8 | import argparse 9 | import os 10 | import torch 11 | from torch import nn 12 | from torch.utils.data import DataLoader, DistributedSampler 13 | from torch.nn.parallel import DistributedDataParallel 14 | from torch import distributed as dist 15 | from utils.data import pad_collate 16 | import numpy as np 17 | import random 18 | 19 | torch.manual_seed(0) 20 | random.seed(0) 21 | np.random.seed(0) 22 | 23 | DATASETS = { 24 | 'fsc147': FSC147Dataset 25 | } 26 | 27 | def train(args): 28 | if 'SLURM_PROCID' in os.environ: 29 | world_size = int(os.environ['SLURM_NTASKS']) 30 | rank = int(os.environ['SLURM_PROCID']) 31 | gpu = rank % torch.cuda.device_count() 32 | print("Running on SLURM", world_size, rank, gpu) 33 | else: 34 | world_size = int(os.environ['WORLD_SIZE']) 35 | rank = int(os.environ['RANK']) 36 | gpu = int(os.environ['LOCAL_RANK']) 37 | 38 | torch.cuda.set_device(gpu) 39 | device = torch.device(gpu) 40 | 41 | dist.init_process_group( 42 | backend='nccl', init_method='env://', 43 | world_size=world_size, rank=rank 44 | ) 45 | 46 | model = DistributedDataParallel( 47 | build_model(args).to(device), 48 | device_ids=[gpu], 49 | output_device=gpu 50 | ) 51 | 52 | backbone_params = dict() 53 | non_backbone_params = dict() 54 | for n, p in model.named_parameters(): 55 | if 'backbone' in n: 56 | backbone_params[n] = p 57 | else: 58 | non_backbone_params[n] = p 59 | 60 | optimizer = torch.optim.AdamW( 61 | [ 62 | {'params': non_backbone_params.values()}, 63 | {'params': backbone_params.values(), 'lr': args.backbone_lr} 64 | ], 65 | lr=args.lr, 66 | weight_decay=args.weight_decay, 67 | ) 68 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop, gamma=0.25) 69 | if args.resume_training: 70 | checkpoint = torch.load(os.path.join(args.model_path, f'{args.model_name}.pth')) 71 | model.load_state_dict(checkpoint['model'], strict=False) 72 | 73 | start_epoch = 0 74 | best = 10000000000000 75 | matcher = build_matcher(args) 76 | criterion = SetCriterion(0, matcher, {"loss_giou": args.giou_loss_coef}, ["bboxes", "ce"], 77 | focal_alpha=args.focal_alpha) 78 | criterion.to(device) 79 | 80 | train = DATASETS[args.dataset]( 81 | args.data_path, 82 | args.image_size, 83 | split='train', 84 | num_objects=args.num_objects, 85 | tiling_p=args.tiling_p, 86 | zero_shot=args.zero_shot 87 | ) 88 | val = DATASETS[args.dataset]( 89 | args.data_path, 90 | args.image_size, 91 | split='val', 92 | num_objects=args.num_objects, 93 | tiling_p=args.tiling_p 94 | ) 95 | train_loader = DataLoader( 96 | train, 97 | sampler=DistributedSampler(train), 98 | batch_size=args.batch_size, 99 | drop_last=True, 100 | num_workers=args.num_workers, 101 | collate_fn=pad_collate 102 | ) 103 | val_loader = DataLoader( 104 | val, 105 | sampler=DistributedSampler(val), 106 | batch_size=args.batch_size, 107 | drop_last=False, 108 | num_workers=args.num_workers, 109 | 
collate_fn=pad_collate 110 | ) 111 | 112 | 113 | print(rank) 114 | for epoch in range(start_epoch + 1, args.epochs + 1): 115 | if rank == 0: 116 | start = perf_counter() 117 | train_loss = torch.tensor(0.0).to(device) 118 | val_loss = torch.tensor(0.0).to(device) 119 | train_ae = torch.tensor(0.0).to(device) 120 | val_ae = torch.tensor(0.0).to(device) 121 | val_rmse = torch.tensor(0.0).to(device) 122 | 123 | train_loader.sampler.set_epoch(epoch) 124 | model.train() 125 | criterion.train() 126 | for img, bboxes, img_name, gt_bboxes, _ in train_loader: 127 | img = img.to(device) 128 | bboxes = bboxes.to(device) 129 | gt_bboxes = gt_bboxes.to(device) 130 | 131 | optimizer.zero_grad() 132 | outputs, ref_points, centerness, outputs_coord = model(img, bboxes) 133 | 134 | losses = [] 135 | num_objects_gt = [] 136 | num_objects_pred = [] 137 | 138 | nms_bboxes = [] 139 | for idx in range(img.shape[0]): 140 | target_bboxes = gt_bboxes[idx][torch.logical_not((gt_bboxes[idx] == 0).all(dim=1))] / 1024 141 | 142 | l = criterion(outputs[idx], 143 | [{"boxes": target_bboxes, "labels": torch.tensor([0] * target_bboxes.shape[0])}], 144 | centerness[idx], ref_points[idx]) 145 | keep = ops.nms(outputs[idx]['pred_boxes'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / 8], 146 | outputs[idx]['box_v'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / 8], 0.5) 147 | 148 | num_objects_gt.append(len(target_bboxes)) 149 | 150 | boxes = (outputs[idx]['pred_boxes'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / 8])[keep] 151 | nms_bboxes.append(boxes) 152 | num_objects_pred.append(len(boxes)) 153 | losses.append(l['loss_giou'] + l["loss_l2"] + + l["loss_bbox"]) 154 | loss = sum(losses) 155 | 156 | loss.backward() 157 | 158 | if args.max_grad_norm > 0: 159 | nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 160 | optimizer.step() 161 | train_loss += loss 162 | train_ae += torch.abs(torch.tensor(num_objects_gt) - torch.tensor(num_objects_pred)).sum() 163 | criterion.eval() 164 | model.eval() 165 | with torch.no_grad(): 166 | for img, bboxes, img_name, gt_bboxes, _ in val_loader: 167 | img = img.to(device) 168 | bboxes = bboxes.to(device) 169 | gt_bboxes = gt_bboxes.to(device) 170 | 171 | optimizer.zero_grad() 172 | outputs, ref_points, centerness, outputs_coord = model(img, bboxes) 173 | 174 | losses = [] 175 | num_objects_gt = [] 176 | num_objects_pred = [] 177 | nms_bboxes = [] 178 | 179 | for idx in range(img.shape[0]): 180 | # print(img_name[idx]) 181 | target_bboxes = gt_bboxes[idx][torch.logical_not((gt_bboxes[idx] == 0).all(dim=1))] / 1024 182 | 183 | l = criterion(outputs[idx], 184 | [{"boxes": target_bboxes, "labels": torch.tensor([0] * target_bboxes.shape[0])}], 185 | centerness[idx], ref_points[idx]) 186 | keep = ops.nms(outputs[idx]['pred_boxes'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / 8], 187 | outputs[idx]['box_v'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / 8], 0.5) 188 | 189 | num_objects_gt.append(len(target_bboxes)) 190 | 191 | boxes = (outputs[idx]['pred_boxes'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / 8])[keep] 192 | nms_bboxes.append(boxes) 193 | num_objects_pred.append(len(boxes)) 194 | losses.append(l['loss_giou'] + l["loss_l2"] + l["loss_bbox"]) 195 | loss = sum(losses) 196 | 197 | train_loss += loss 198 | num_objects_gt = torch.tensor(num_objects_gt) 199 | num_objects_pred = torch.tensor(num_objects_pred) 200 | 201 | val_loss += loss 202 | val_ae += torch.abs( 203 | num_objects_gt - num_objects_pred 204 | ).sum() 205 | 
val_rmse += torch.pow( 206 | num_objects_gt - num_objects_pred, 2 207 | ).sum() 208 | 209 | if args.max_grad_norm > 0: 210 | nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 211 | optimizer.step() 212 | 213 | dist.all_reduce(train_loss) 214 | dist.all_reduce(val_loss) 215 | dist.all_reduce(val_rmse) 216 | dist.all_reduce(train_ae) 217 | dist.all_reduce(val_ae) 218 | 219 | scheduler.step() 220 | 221 | if rank == 0: 222 | end = perf_counter() 223 | best_epoch = False 224 | if val_rmse.item() / len(val) < best: 225 | best = val_rmse.item() / len(val) 226 | checkpoint = { 227 | 'epoch': epoch, 228 | 'model': model.state_dict(), 229 | 'optimizer': optimizer.state_dict(), 230 | 'scheduler': scheduler.state_dict(), 231 | 'best_val_ae': val_rmse.item() / len(val) 232 | } 233 | 234 | torch.save( 235 | checkpoint, 236 | os.path.join(args.model_path, f'{args.model_name_resumed}.pth') 237 | ) 238 | 239 | best_epoch = True 240 | 241 | print( 242 | f"Epoch: {epoch}", 243 | f"Train loss: {train_loss.item():.3f}", 244 | f"Val loss: {val_loss.item():.3f}", 245 | f"Train MAE: {train_ae.item() / len(train):.3f}", 246 | f"Val MAE: {val_ae.item() / len(val):.3f}", 247 | f"Val RMSE: {torch.sqrt(val_rmse / len(val)).item():.2f}", 248 | f"Epoch time: {end - start:.3f} seconds", 249 | 'best' if best_epoch else '' 250 | ) 251 | 252 | dist.destroy_process_group() 253 | 254 | 255 | if __name__ == '__main__': 256 | parser = argparse.ArgumentParser('GeCo', parents=[get_argparser()]) 257 | args = parser.parse_args() 258 | print(args) 259 | train(args) 260 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GECO 3 | #SBATCH --output=train/GeCo_%j.txt 4 | #SBATCH --error=train/GeCo_%j.txt 5 | #SBATCH --nodes=1 6 | #SBATCH --ntasks-per-node=2 7 | #SBATCH --cpus-per-task=12 8 | #SBATCH --partition=gpu 9 | #SBATCH --gres=gpu:2 10 | #SBATCH --time=4-00:00:00 11 | 12 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 13 | export MASTER_ADDR=$master_addr 14 | export MASTER_PORT=50197 15 | export NCCL_P2P_DISABLE=1 16 | export NCCL_IB_DISABLE=1 17 | export NCCL_BLOCKING_WAIT=1 18 | export TORCH_DISTRIBUTED_DEBUG=DETAIL 19 | 20 | module load Anaconda3 21 | source activate geco 22 | conda activate base 23 | conda activate geco 24 | 25 | srun --unbuffered python train.py \ 26 | --resume_training \ 27 | --model_name GeCo_PRETRAIN \ 28 | --model_name_resumed GeCo \ 29 | --data_path /d/hpc/projects/FRI/pelhanj/fsc147 \ 30 | --epochs 200 \ 31 | --lr 1e-4 \ 32 | --backbone_lr 0 \ 33 | --lr_drop 200 \ 34 | --weight_decay 1e-4 \ 35 | --batch_size 4 \ 36 | --tiling_p 0.5 -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerpelhan/GeCo/44b71e572b11d41822da3a52a9b7c3be85b7194e/utils/__init__.py -------------------------------------------------------------------------------- /utils/arg_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_argparser(): 5 | 6 | parser = argparse.ArgumentParser("LOCA parser", add_help=False) 7 | 8 | parser.add_argument('--model_name', default='GeCo_updated', type=str) 9 | parser.add_argument('--model_name_resumed', default='', type=str) 10 | parser.add_argument( 11 | 
'--data_path', 12 | default='/storage/datasets/fsc147', 13 | type=str 14 | ) 15 | parser.add_argument( 16 | '--model_path', 17 | default='./', 18 | type=str 19 | ) 20 | parser.add_argument( 21 | '--image_path', 22 | default='./', 23 | type=str 24 | ) 25 | parser.add_argument('--dataset', default='fsc147', type=str) 26 | parser.add_argument('--reduction', default=16, type=int) 27 | parser.add_argument('--image_size', default=1024, type=int) 28 | parser.add_argument('--emb_dim', default=256, type=int) 29 | parser.add_argument('--num_heads', default=8, type=int) 30 | parser.add_argument('--kernel_dim', default=1, type=int) 31 | parser.add_argument('--num_objects', default=3, type=int) 32 | parser.add_argument('--epochs', default=200, type=int) 33 | parser.add_argument('--resume_training', action='store_true') 34 | parser.add_argument('--lr', default=1e-4, type=float) 35 | parser.add_argument('--backbone_lr', default=0, type=float) 36 | parser.add_argument('--lr_drop', default=200, type=int) 37 | parser.add_argument('--weight_decay', default=1e-4, type=float) 38 | parser.add_argument('--batch_size', default=1, type=int) 39 | parser.add_argument('--num_workers', default=8, type=int) 40 | parser.add_argument('--max_grad_norm', default=0.1, type=float) 41 | parser.add_argument('--tiling_p', default=0.5, type=float) 42 | parser.add_argument('--zero_shot', action='store_true') 43 | parser.add_argument("--giou_loss_coef", default=2, type=float) 44 | parser.add_argument("--cost_class", default=2, type=float, help="Class coefficient in the matching cost") 45 | parser.add_argument("--cost_bbox", default=1, type=float, help="L1 box coefficient in the matching cost") 46 | parser.add_argument("--cost_giou", default=2, type=float, help="giou box coefficient in the matching cost") 47 | parser.add_argument("--focal_alpha", default=0.25, type=float) 48 | parser.add_argument('--output_masks', action='store_true') 49 | 50 | return parser 51 | -------------------------------------------------------------------------------- /utils/box_ops.py: -------------------------------------------------------------------------------- 1 | from torchvision import ops 2 | import torch 3 | from torchvision.ops.boxes import box_area 4 | from torch.nn import functional as F 5 | 6 | 7 | def boxes_with_scores(density_map, tlrb, sort=False, batch_thresh=None): 8 | B, C, _, _ = density_map.shape # B, 1, H, W 9 | 10 | pooled = F.max_pool2d(density_map, 3, 1, 1) 11 | if batch_thresh is None: 12 | batch_thresh = torch.median(density_map.reshape(B, -1), dim=-1).values.view(B, C, 1, 1) 13 | 14 | mask = (pooled == density_map) & (density_map > batch_thresh) 15 | 16 | out_batch = [] 17 | ref_points_batch = [] 18 | for i in range(B): 19 | # select the masked density maps and box offsets 20 | bbox_scores = density_map[i, mask[i]] 21 | ref_points = mask[i].nonzero()[:, -2:] 22 | 23 | # normalize center locations 24 | bbox_centers = ref_points / torch.tensor(mask.shape[2:], device=mask.device) 25 | 26 | # select masked box offsets, permute to keep channels last 27 | tlrb_ = tlrb[i].permute(1, 2, 0) 28 | bbox_offsets = tlrb_[mask[i].permute(1, 2, 0).expand_as(tlrb_)].reshape(-1, 4) 29 | 30 | sign = torch.tensor([-1, -1, 1, 1], device=mask.device) 31 | bbox_xyxy = bbox_centers.flip(-1).repeat(1, 2) + sign * bbox_offsets 32 | 33 | # sort by bbox scores 34 | if sort: 35 | perm = torch.argsort(bbox_scores, descending=True) 36 | bbox_scores = bbox_scores[perm] 37 | bbox_xyxy = bbox_xyxy[perm] 38 | ref_points = ref_points[perm] 39 | 40 | 
out_batch.append({ 41 | "pred_boxes": bbox_xyxy.unsqueeze(0), 42 | "box_v": bbox_scores.unsqueeze(0) 43 | }) 44 | ref_points_batch.append(ref_points.T) 45 | 46 | return out_batch, ref_points_batch 47 | 48 | # modified from torchvision to also return the union 49 | def box_iou(boxes1, boxes2): 50 | area1 = box_area(boxes1) 51 | area2 = box_area(boxes2) 52 | 53 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 54 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 55 | 56 | wh = (rb - lt).clamp(min=0) # [N,M,2] 57 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 58 | 59 | union = area1[:, None] + area2 - inter + 1e-16 # [N,M] 60 | 61 | iou = inter / union 62 | return iou, union 63 | 64 | 65 | def generalized_box_iou(boxes1, boxes2): 66 | """ 67 | Generalized IoU from https://giou.stanford.edu/ 68 | 69 | The boxes should be in [x0, y0, x1, y1] format 70 | 71 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 72 | and M = len(boxes2) 73 | """ 74 | # degenerate boxes gives inf / nan results 75 | # so do an early check 76 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 77 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 78 | iou, union = box_iou(boxes1, boxes2) 79 | 80 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 81 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 82 | 83 | wh = (rb - lt).clamp(min=0) # [N,M,2] 84 | area = wh[:, :, 0] * wh[:, :, 1] + 1e-16 # [N,M] 85 | 86 | return iou - (area - union) / area 87 | 88 | 89 | 90 | 91 | class BoxList: 92 | def __init__(self, box, image_size, mode='xyxy'): 93 | device = box.device if hasattr(box, 'device') else 'cpu' 94 | if torch.is_tensor(box): 95 | box = torch.as_tensor(box, dtype=torch.float32, device=device) 96 | else: 97 | box = torch.as_tensor(np.array(box), dtype=torch.float32, device=device) 98 | 99 | self.box = box 100 | self.size = image_size 101 | self.mode = mode 102 | 103 | self.fields = {} 104 | 105 | def convert(self, mode): 106 | if mode == self.mode: 107 | return self 108 | 109 | x_min, y_min, x_max, y_max = self.split_to_xyxy() 110 | 111 | if mode == 'xyxy': 112 | box = torch.cat([x_min, y_min, x_max, y_max], -1) 113 | box = BoxList(box, self.size, mode=mode) 114 | 115 | elif mode == 'xywh': 116 | remove = 1 117 | box = torch.cat( 118 | [x_min, y_min, x_max - x_min + remove, y_max - y_min + remove], -1 119 | ) 120 | box = BoxList(box, self.size, mode=mode) 121 | 122 | box.copy_field(self) 123 | 124 | return box 125 | 126 | def copy_field(self, box): 127 | for k, v in box.fields.items(): 128 | self.fields[k] = v 129 | 130 | def area(self): 131 | box = self.box 132 | 133 | if self.mode == 'xyxy': 134 | remove = 1 135 | 136 | area = (box[:, 2] - box[:, 0] + remove) * (box[:, 3] - box[:, 1] + remove) 137 | 138 | elif self.mode == 'xywh': 139 | area = box[:, 2] * box[:, 3] 140 | 141 | return area 142 | 143 | def split_to_xyxy(self): 144 | if self.mode == 'xyxy': 145 | x_min, y_min, x_max, y_max = self.box.split(1, dim=-1) 146 | 147 | return x_min, y_min, x_max, y_max 148 | 149 | elif self.mode == 'xywh': 150 | remove = 1 151 | x_min, y_min, w, h = self.box.split(1, dim=-1) 152 | 153 | return ( 154 | x_min, 155 | y_min, 156 | x_min + (w - remove).clamp(min=0), 157 | y_min + (h - remove).clamp(min=0), 158 | ) 159 | 160 | def __len__(self): 161 | return self.box.shape[0] 162 | 163 | def __getitem__(self, index): 164 | box = BoxList(self.box[index], self.size, self.mode) 165 | 166 | return box 167 | 168 | def resize(self, size, *args, **kwargs): 169 | ratios = tuple(float(s) / float(s_orig) for s, s_orig 
in zip(size, self.size)) 170 | 171 | if ratios[0] == ratios[1]: 172 | ratio = ratios[0] 173 | scaled = self.box * ratio 174 | box = BoxList(scaled, size, mode=self.mode) 175 | 176 | for k, v in self.fields.items(): 177 | if not isinstance(v, torch.Tensor): 178 | v = v.resize(size, *args, **kwargs) 179 | 180 | box.fields[k] = v 181 | 182 | return box 183 | 184 | ratio_w, ratio_h = ratios 185 | x_min, y_min, x_max, y_max = self.split_to_xyxy() 186 | scaled_x_min = x_min * ratio_w 187 | scaled_x_max = x_max * ratio_w 188 | scaled_y_min = y_min * ratio_h 189 | scaled_y_max = y_max * ratio_h 190 | scaled = torch.cat([scaled_x_min, scaled_y_min, scaled_x_max, scaled_y_max], -1) 191 | box = BoxList(scaled, size, mode='xyxy') 192 | 193 | for k, v in self.fields.items(): 194 | if not isinstance(v, torch.Tensor): 195 | v = v.resize(size, *args, **kwargs) 196 | 197 | box.fields[k] = v 198 | 199 | return box.convert(self.mode) 200 | 201 | def clip(self, remove_empty=True): 202 | remove = 1 203 | 204 | max_width = self.size[0] - remove 205 | max_height = self.size[1] - remove 206 | 207 | self.box[:, 0].clamp_(min=0, max=max_width) 208 | self.box[:, 1].clamp_(min=0, max=max_height) 209 | self.box[:, 2].clamp_(min=0, max=max_width) 210 | self.box[:, 3].clamp_(min=0, max=max_height) 211 | 212 | if remove_empty: 213 | box = self.box 214 | keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0]) 215 | 216 | return self[keep] 217 | 218 | else: 219 | return self 220 | 221 | def to(self, device): 222 | box = BoxList(self.box.to(device), self.size, self.mode) 223 | 224 | for k, v in self.fields.items(): 225 | if hasattr(v, 'to'): 226 | v = v.to(device) 227 | 228 | box.fields[k] = v 229 | 230 | return box 231 | 232 | 233 | 234 | def compute_location(features): 235 | locations = [] 236 | _, _, height, width = features.shape 237 | location_per_level = compute_location_per_level( 238 | height, width, 1, features.device 239 | ) 240 | locations.append(location_per_level) 241 | 242 | return locations 243 | 244 | def compute_location_per_level(height, width, stride, device): 245 | shift_x = torch.arange( 246 | 0, width * stride, step=stride, dtype=torch.float32, device=device 247 | ) 248 | shift_y = torch.arange( 249 | 0, height * stride, step=stride, dtype=torch.float32, device=device 250 | ) 251 | shift_y, shift_x = torch.meshgrid(shift_y, shift_x) 252 | shift_x = shift_x.reshape(-1) 253 | shift_y = shift_y.reshape(-1) 254 | location = torch.stack((shift_x, shift_y), 1) + stride // 2 255 | 256 | return location -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from PIL import Image 5 | import numpy as np 6 | 7 | import torch 8 | from torch.utils.data import Dataset 9 | from torchvision import transforms as T 10 | from torchvision.transforms import functional as TVF 11 | 12 | 13 | from torch.nn.utils.rnn import pad_sequence 14 | 15 | 16 | def pad_collate(batch): 17 | (img, bboxes, image_names, gt_bboxes, dmap) = zip(*batch) 18 | if None in gt_bboxes: 19 | return None, None, None, torch.stack(image_names), None, None 20 | gt_bboxes_pad = pad_sequence(gt_bboxes, batch_first=True, padding_value=0) 21 | img = torch.stack(img) 22 | bboxes = torch.stack(bboxes) 23 | image_names = torch.stack(image_names) 24 | dmaps = torch.stack(dmap) 25 | gt_bboxes = gt_bboxes_pad 26 | return img, bboxes, image_names, gt_bboxes, dmaps 27 | 28 | 29 | def 
xywh_to_x1y1x2y2(xywh): 30 | x, y, w, h = xywh 31 | x1 = x 32 | y1 = y 33 | x2 = x + w 34 | y2 = y + h 35 | return [x1, y1, x2, y2] 36 | 37 | 38 | def resize_and_pad(img, bboxes, density_map=None, size=1024.0, gt_bboxes=None, full_stretch=True, downscale_factor=1): 39 | channels, original_height, original_width = img.shape 40 | longer_dimension = max(original_height, original_width) 41 | scaling_factor = size / longer_dimension 42 | 43 | if not full_stretch: 44 | scaled_bboxes = bboxes * scaling_factor 45 | 46 | a_dim = ((scaled_bboxes[:, 2] - scaled_bboxes[:, 0]).mean() + ( 47 | scaled_bboxes[:, 3] - scaled_bboxes[:, 1]).mean()) / 2 48 | scaling_factor = min(1.0, 80 / a_dim.item()) * scaling_factor 49 | 50 | if downscale_factor != 1: 51 | scaling_factor = scaling_factor * downscale_factor 52 | 53 | resized_img = torch.nn.functional.interpolate(img.unsqueeze(0), scale_factor=scaling_factor, mode='bilinear', 54 | align_corners=False) 55 | 56 | if max(resized_img.shape) <= 1024: 57 | size = 1024 58 | size = int(size) 59 | pad_height = max(0, size - resized_img.shape[2]) 60 | pad_width = max(0, size - resized_img.shape[3]) 61 | 62 | padded_img = torch.nn.functional.pad(resized_img, (0, pad_width, 0, pad_height), mode='constant', value=0)[0] 63 | 64 | if density_map is not None: 65 | original_sum = density_map.sum() 66 | _, w0, h0 = density_map.shape 67 | _, W, H = img.shape 68 | resized_density_map = torch.nn.functional.interpolate(density_map.unsqueeze(0), size=(W, H), mode='bilinear', 69 | align_corners=False) 70 | resized_density_map = torch.nn.functional.interpolate(resized_density_map, scale_factor=scaling_factor, 71 | mode='bilinear', 72 | align_corners=False) 73 | padded_density_map = \ 74 | torch.nn.functional.pad(resized_density_map, (0, pad_width, 0, pad_height), mode='constant', value=0)[0] 75 | padded_density_map = T.Resize((512, 512), antialias=True)(padded_density_map) 76 | padded_density_map = padded_density_map / padded_density_map.sum() * original_sum 77 | 78 | bboxes = bboxes * torch.tensor([scaling_factor, scaling_factor, scaling_factor, scaling_factor]) 79 | 80 | if gt_bboxes is not None and density_map is not None: 81 | gt_bboxes = gt_bboxes * torch.tensor([scaling_factor, scaling_factor, scaling_factor, scaling_factor]) 82 | return padded_img, bboxes, padded_density_map, gt_bboxes, scaling_factor, (pad_width, pad_height) 83 | if gt_bboxes is not None: 84 | return padded_img, bboxes, gt_bboxes, scaling_factor, (pad_width, pad_height) 85 | if density_map is None and gt_bboxes is None: 86 | return padded_img, bboxes, scaling_factor 87 | 88 | return padded_img, bboxes, padded_density_map 89 | 90 | 91 | def tiling_augmentation(img, bboxes, resize, jitter, tile_size, hflip_p, gt_bboxes=None, density_map=None): 92 | def apply_hflip(tensor, apply): 93 | return TVF.hflip(tensor) if apply else tensor 94 | 95 | def make_tile(x, num_tiles, hflip, hflip_p, jitter=None): 96 | result = list() 97 | for j in range(num_tiles): 98 | row = list() 99 | for k in range(num_tiles): 100 | t = jitter(x) if jitter is not None else x 101 | # if hflip[j, k] < hflip_p: 102 | # t = TVF.hflip(t) 103 | row.append(t) 104 | result.append(torch.cat(row, dim=-1)) 105 | return torch.cat(result, dim=-2) 106 | 107 | x_tile, y_tile = tile_size 108 | y_target, x_target = resize.size 109 | num_tiles = max(int(x_tile.ceil()), int(y_tile.ceil())) 110 | # whether to horizontally flip each tile 111 | hflip = torch.rand(num_tiles, num_tiles) 112 | 113 | img = make_tile(img, num_tiles, hflip, hflip_p, jitter) 114 | c, 
h, w = img.shape 115 | img = resize(img[..., :int(y_tile * y_target), :int(x_tile * x_target)]) 116 | if density_map is not None: 117 | density_map = make_tile(density_map, num_tiles, hflip, hflip_p) 118 | density_map = density_map[..., :int(y_tile * y_target), :int(x_tile * x_target)] 119 | original_sum = density_map.sum() 120 | density_map = T.Resize((512, 512), antialias=True)(density_map) 121 | density_map = density_map / density_map.sum() * original_sum 122 | 123 | bboxes = bboxes / torch.tensor([w, h, w, h]) * resize.size[0] 124 | if gt_bboxes is not None: 125 | gt_bboxes_ = gt_bboxes / torch.tensor([w, h, w, h]) * resize.size[0] 126 | gt_bboxes_tiled = torch.cat([gt_bboxes_, 127 | gt_bboxes_ + torch.tensor([0, 512, 0, 512]), 128 | gt_bboxes_ + torch.tensor([512, 0, 512, 0]), 129 | gt_bboxes_ + torch.tensor([512, 512, 512, 512])]) 130 | if density_map is None: 131 | return img, bboxes, gt_bboxes_tiled 132 | else: 133 | return img, bboxes, density_map, gt_bboxes_tiled 134 | 135 | return img, bboxes, density_map 136 | 137 | 138 | class FSC147Dataset(Dataset): 139 | 140 | def __init__( 141 | self, data_path, img_size, split='train', num_objects=3, 142 | tiling_p=0.5, zero_shot=False, return_ids=False, evaluation=False 143 | ): 144 | from pycocotools.coco import COCO 145 | self.split = split 146 | self.data_path = data_path 147 | self.horizontal_flip_p = 0.5 148 | self.tiling_p = tiling_p 149 | self.img_size = img_size 150 | self.resize = T.Resize((img_size, img_size), antialias=True) 151 | self.resize512 = T.Resize((512, 512), antialias=True) 152 | self.jitter = T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8) 153 | self.num_objects = num_objects 154 | self.zero_shot = zero_shot 155 | self.return_ids = return_ids 156 | self.evaluation = evaluation 157 | 158 | with open( 159 | os.path.join(self.data_path, 'annotations', 'Train_Test_Val_FSC_147.json'), 'rb' 160 | ) as file: 161 | splits = json.load(file) 162 | self.image_names = splits[split] 163 | with open( 164 | os.path.join(self.data_path, 'annotations', 'annotation_FSC147_384.json'), 'rb' 165 | ) as file: 166 | self.annotations = json.load(file) 167 | self.labels = COCO(os.path.join(self.data_path, 'annotations', 'instances_' + split + '.json')) 168 | self.img_name_to_ori_id = self.map_img_name_to_ori_id() 169 | 170 | def get_gt_bboxes(self, idx): 171 | coco_im_id = self.img_name_to_ori_id[self.image_names[idx]] 172 | anno_ids = self.labels.getAnnIds([coco_im_id]) 173 | annotations = self.labels.loadAnns(anno_ids) 174 | bboxes = [] 175 | for a in annotations: 176 | bboxes.append(xywh_to_x1y1x2y2(a['bbox'])) 177 | return bboxes 178 | 179 | def __getitem__(self, idx: int): 180 | img = Image.open(os.path.join( 181 | self.data_path, 182 | 'images_384_VarV2', 183 | self.image_names[idx] 184 | )).convert("RGB") 185 | 186 | gt_bboxes = torch.tensor(self.get_gt_bboxes(idx)) 187 | 188 | img = T.Compose([ 189 | T.ToTensor(), 190 | ])(img) 191 | 192 | bboxes = torch.tensor( 193 | self.annotations[self.image_names[idx]]['box_examples_coordinates'], 194 | dtype=torch.float32 195 | )[:3, [0, 2], :].reshape(-1, 4)[:self.num_objects, ...] 
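# each exemplar is annotated with four corner points; corners 0 and 2 (two opposite corners) form an (x1, y1, x2, y2) box, and at most num_objects exemplars are kept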
196 | 197 | density_map = torch.from_numpy(np.load(os.path.join( 198 | self.data_path, 199 | 'gt_density_map_adaptive_1024_1024_SAME', 200 | os.path.splitext(self.image_names[idx])[0] + '.npy', 201 | ))).unsqueeze(0) 202 | 203 | tiled = False 204 | 205 | # data augmentation 206 | if self.split == 'train' and torch.rand(1) < self.tiling_p: 207 | tiled = True 208 | tile_size = (torch.rand(1) + 1, torch.rand(1) + 1) 209 | img, bboxes, density_map, gt_bboxes = tiling_augmentation( 210 | img, bboxes, self.resize, 211 | self.jitter, tile_size, self.horizontal_flip_p, gt_bboxes=gt_bboxes, density_map=density_map 212 | ) 213 | 214 | elif self.split == 'train': 215 | img, bboxes, density_map, gt_bboxes, scaling_factor, padwh = resize_and_pad(img, bboxes, density_map, 216 | full_stretch=True, 217 | gt_bboxes=gt_bboxes) 218 | elif not self.evaluation: 219 | img, bboxes, density_map, gt_bboxes, scaling_factor, padwh = resize_and_pad(img, bboxes, density_map, 220 | gt_bboxes=gt_bboxes, 221 | full_stretch=False, 222 | size=1024.0) 223 | else: 224 | img_, bboxes_, density_map_, gt_bboxes_, scaling_factor_, padwh_ = resize_and_pad(img, bboxes, 225 | density_map, 226 | gt_bboxes=gt_bboxes, 227 | full_stretch=False if not self.zero_shot else True, 228 | size=1024.0) 229 | if (bboxes_[:, 2] - bboxes_[:, 0]).min() < 25 and ( 230 | bboxes_[:, 3] - bboxes_[:, 1]).min() < 25 and not self.zero_shot: 231 | img, bboxes, density_map, gt_bboxes, scaling_factor, padwh = resize_and_pad(img, bboxes, 232 | density_map, 233 | gt_bboxes=gt_bboxes, 234 | full_stretch=False, 235 | size=1536.0) 236 | else: 237 | img, bboxes, density_map, gt_bboxes, scaling_factor, padwh = img_, bboxes_, density_map_, gt_bboxes_, scaling_factor_, padwh_ 238 | 239 | if self.split == 'train': 240 | if not tiled: 241 | img = self.jitter(img) 242 | img = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img) 243 | 244 | if self.split == 'train' and not tiled and torch.rand(1) < self.horizontal_flip_p: 245 | img = TVF.hflip(img) 246 | density_map = TVF.hflip(density_map) 247 | bboxes[:, [0, 2]] = self.img_size - bboxes[:, [2, 0]] 248 | gt_bboxes[:, [0, 2]] = self.img_size - gt_bboxes[:, [2, 0]] 249 | 250 | gt_bboxes = torch.clamp(gt_bboxes, min=0, max=1024) 251 | 252 | if self.evaluation: 253 | return img, bboxes, density_map, torch.tensor(idx), gt_bboxes, scaling_factor, padwh 254 | 255 | else: 256 | return img, bboxes, torch.tensor(idx), gt_bboxes, density_map 257 | 258 | def __len__(self): 259 | return len(self.image_names) 260 | 261 | def map_img_name_to_ori_id(self, ): 262 | all_coco_imgs = self.labels.imgs 263 | map_name_2_id = dict() 264 | for k, v in all_coco_imgs.items(): 265 | img_id = v["id"] 266 | img_name = v["file_name"] 267 | map_name_2_id[img_name] = img_id 268 | return map_name_2_id 269 | 270 | 271 | def generate_density_maps(data_path, target_size=(1024, 1024)): 272 | from tqdm import tqdm 273 | from scipy.ndimage import gaussian_filter 274 | with open( 275 | os.path.join(data_path, 'annotations/annotation_FSC147_384.json'), 'rb' 276 | ) as file: 277 | annotations = json.load(file) 278 | 279 | if not os.path.exists(os.path.join(data_path, 'gt_density_map_adaptive_1024_1024_SAME')): 280 | os.makedirs(os.path.join(data_path, 'gt_density_map_adaptive_1024_1024_SAME')) 281 | 282 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 283 | 284 | for i, (image_name, ann) in enumerate(tqdm(annotations.items())): 285 | _, h, w = T.ToTensor()(Image.open(os.path.join( 286 | data_path, 287 | 
'images_384_VarV2', 288 | image_name 289 | ))).size() 290 | h_ratio, w_ratio = target_size[0] / h, target_size[1] / w 291 | 292 | points = ( 293 | torch.tensor(ann['points'], device=device) * 294 | torch.tensor([w_ratio, h_ratio], device=device) 295 | ).long() 296 | points[:, 0] = points[:, 0].clip(0, target_size[1] - 1) 297 | points[:, 1] = points[:, 1].clip(0, target_size[0] - 1) 298 | 299 | sigmas = np.array([2, 2]) 300 | 301 | dmap = torch.zeros(*target_size) 302 | for p in range(points.size(0)): 303 | dmap[points[p, 1], points[p, 0]] += 1 304 | dmap = gaussian_filter(dmap.cpu().numpy(), sigmas) 305 | 306 | np.save(os.path.join( 307 | data_path, 308 | 'gt_density_map_adaptive_1024_1024_SAME', 309 | os.path.splitext(image_name)[0] + '.npy', 310 | ), dmap) 311 | 312 | 313 | if __name__ == '__main__': 314 | parser = argparse.ArgumentParser("Density map generator", add_help=False) 315 | parser.add_argument( 316 | '--data_path', 317 | default='/storage/datasets/fsc147/', 318 | type=str 319 | ) 320 | parser.add_argument('--image_size', default=1024, type=int) 321 | args = parser.parse_args() 322 | generate_density_maps(args.data_path, (args.image_size, args.image_size)) 323 | --------------------------------------------------------------------------------
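Usage note (a minimal sketch, not part of the repository): the snippet below shows how the utilities above fit together — get_argparser for configuration, FSC147Dataset with pad_collate for batching, and boxes_with_scores for turning a predicted density map and tlrb offset map into scored boxes. The data path and the random stand-in predictions are assumptions; the actual GeCo forward pass lives in models/geco.py, which is not shown in this section.

import torch
from torch.utils.data import DataLoader

from utils.arg_parser import get_argparser
from utils.data import FSC147Dataset, pad_collate
from utils.box_ops import boxes_with_scores

# hypothetical configuration; '--data_path' must point to an existing FSC-147 copy
args = get_argparser().parse_args(['--data_path', '/storage/datasets/fsc147', '--batch_size', '2'])

dataset = FSC147Dataset(
    args.data_path, args.image_size, split='train',
    num_objects=args.num_objects, tiling_p=args.tiling_p,
)
loader = DataLoader(
    dataset, batch_size=args.batch_size, shuffle=True, collate_fn=pad_collate,
)

img, exemplars, img_ids, gt_bboxes, density_map = next(iter(loader))

# stand-in predictions with the shapes boxes_with_scores expects:
# a density map (B, 1, h, w) and per-location t/l/r/b offsets (B, 4, h, w)
pred_density = torch.rand(img.shape[0], 1, 256, 256)
pred_tlrb = torch.rand(img.shape[0], 4, 256, 256)

# local maxima of the density map above a per-image threshold (the median by default)
# are decoded into xyxy boxes ('pred_boxes') with their peak values as scores ('box_v')
outputs, ref_points = boxes_with_scores(pred_density, pred_tlrb, sort=True)
print(outputs[0]['pred_boxes'].shape, outputs[0]['box_v'].shape)

For evaluation, constructing the dataset with evaluation=True additionally returns the scaling factor and padding applied in resize_and_pad, alongside the ground-truth boxes.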