├── images
│   └── truck.jpg
├── examples
│   ├── box_f1.png
│   ├── box_s1.png
│   ├── point_f1.png
│   └── point_s1.png
├── LICENSE
├── README.md
├── .gitignore
├── demo.ipynb
└── mini_segment_anything.py

/images/truck.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hujiecpp/Mini-Segment-Anything/HEAD/images/truck.jpg

--------------------------------------------------------------------------------
/examples/box_f1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hujiecpp/Mini-Segment-Anything/HEAD/examples/box_f1.png

--------------------------------------------------------------------------------
/examples/box_s1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hujiecpp/Mini-Segment-Anything/HEAD/examples/box_s1.png

--------------------------------------------------------------------------------
/examples/point_f1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hujiecpp/Mini-Segment-Anything/HEAD/examples/point_f1.png

--------------------------------------------------------------------------------
/examples/point_s1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hujiecpp/Mini-Segment-Anything/HEAD/examples/point_s1.png

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Jie Hu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | In this repo, we distill the powerful [Segment Anything](https://github.com/facebookresearch/segment-anything) models into lightweight [YOSO](https://github.com/hujiecpp/YOSO) models for efficient image segmentation (YOSO is a framework aimed at real-time panoptic segmentation).
2 | Specifically, we replace the heavy ViT image encoder with the YOSO image encoder.
3 | This provides an over 10x speedup for image feature extraction on a single RTX 3090 GPU.
4 | We currently release demo checkpoints here and will update them as more accurate ones become available.
5 | 
6 | 
7 | ## Examples
8 | - Good segmentation results
9 | 
10 | 
11 | 
12 | - Bad segmentation results
13 | 
14 | 
15 | 
16 | ## Getting Started
17 | ### Installation
18 | We recommend using [Anaconda](https://www.anaconda.com/) for installation.
19 | ```bash
20 | conda create --name sam_mini python=3.8 -y
21 | conda activate sam_mini
22 | conda install pytorch==1.10.1 torchvision==0.11.2 cudatoolkit=11.3 -c pytorch
23 | pip install git+https://github.com/facebookresearch/detectron2.git
24 | pip install git+https://github.com/facebookresearch/segment-anything.git
25 | pip install opencv-python pycocotools matplotlib
26 | ```
27 | 
28 | ### Checkpoints
29 | - 2023/04/27: [sam_yoso_r50_13a999.pth](https://github.com/hujiecpp/Mini-Segment-Anything/releases/download/checkpoint/sam_yoso_r50_13a999.pth)
30 | 
31 | 
32 | ### Demos
33 | After loading the lightweight model, all remaining operations are the same as in the original Segment Anything.
34 | - Predictor:
35 | ```python
36 | from mini_segment_anything import build_sam_yoso_r50
37 | from segment_anything import SamPredictor
38 | 
39 | sam_checkpoint = './sam_yoso_r50_13a999.pth'
40 | sam_mini = build_sam_yoso_r50(checkpoint=sam_checkpoint).to('cuda')
41 | predictor_mini = SamPredictor(sam_mini)
42 | ```
43 | 
44 | - Automatic mask generator:
45 | ```python
46 | from mini_segment_anything import build_sam_yoso_r50
47 | from segment_anything import SamAutomaticMaskGenerator
48 | 
49 | sam_checkpoint = './sam_yoso_r50_13a999.pth'
50 | sam_mini = build_sam_yoso_r50(checkpoint=sam_checkpoint).to('cuda')
51 | mask_generator_mini = SamAutomaticMaskGenerator(sam_mini)
52 | ```
53 | More details can be found in `./demo.ipynb`.
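54 | 
55 | For example, point-prompted prediction works through the standard `SamPredictor` API unchanged. The minimal sketch below continues from the two snippets above (it reuses `predictor_mini` and `mask_generator_mini`); the point coordinates are the illustrative ones from `demo.ipynb`:
56 | ```python
57 | import cv2
58 | import numpy as np
59 | 
60 | image = cv2.imread('./images/truck.jpg')
61 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
62 | 
63 | predictor_mini.set_image(image)  # runs the YOSO image encoder once per image
64 | masks, scores, logits = predictor_mini.predict(
65 |     point_coords=np.array([[600, 655]]),  # (x, y) pixel coordinates
66 |     point_labels=np.array([1]),           # 1 = foreground, 0 = background
67 |     multimask_output=False,
68 | )
69 | 
70 | # The automatic mask generator is likewise unchanged: it returns a list of
71 | # dicts, one per region, each holding a boolean 'segmentation' mask.
72 | anns = mask_generator_mini.generate(image)
73 | ```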
74 | 
75 | ## Todo List
76 | - [x] Distillation with the image encoder.
77 | - [ ] Distillation with the mask decoder.
78 | 
79 | ## Citation
80 | 
81 | If you find this project helpful for your research, please consider citing the following BibTeX entries.
82 | 
83 | ```BibTeX
84 | @article{kirillov2023segany,
85 |   title={Segment Anything},
86 |   author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
87 |   journal={arXiv preprint arXiv:2304.02643},
88 |   year={2023}
89 | }
90 | 
91 | @article{hu2023yoso,
92 |   title={You Only Segment Once: Towards Real-Time Panoptic Segmentation},
93 |   author={Hu, Jie and Huang, Linyan and Ren, Tianhe and Zhang, Shengchuan and Ji, Rongrong and Cao, Liujuan},
94 |   journal={arXiv preprint arXiv:2303.14651},
95 |   year={2023}
96 | }
97 | ```
98 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from mini_segment_anything import build_sam_yoso_r50\n", 10 | "from segment_anything import sam_model_registry, SamPredictor\n", 11 | "\n", 12 | "# sam-huge\n", 13 | "sam_checkpoint = './sam_vit_h_4b8939.pth'\n", 14 | "model_type = 'vit_h'\n", 15 | "sam_huge = sam_model_registry[model_type](checkpoint=sam_checkpoint).to('cuda')\n", 16 | "predictor_huge = SamPredictor(sam_huge)\n", 17 | "\n", 18 | "# sam-mini\n", 19 | "sam_checkpoint = './sam_yoso_r50_13a999.pth'\n", 20 | "sam_mini = build_sam_yoso_r50(checkpoint=sam_checkpoint).to('cuda')\n", 21 | "predictor_mini = SamPredictor(sam_mini)" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import cv2\n", 31 | "import time\n", 32 | "import numpy as np\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "\n", 35 | "def show_mask(mask, ax, random_color=False):\n", 36 | " if random_color:\n", 37 | " color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n", 38 | " else:\n", 39 | " color = np.array([30/255, 144/255, 255/255, 0.6])\n", 40 | " h, w = mask.shape[-2:]\n", 41 | " mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n", 42 | " ax.imshow(mask_image)\n", 43 | " \n", 44 | "def show_points(coords, labels, ax, marker_size=375):\n", 45 | " pos_points = coords[labels==1]\n", 46 | " neg_points = coords[labels==0]\n", 47 | " ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n", 48 | " ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) \n", 49 | " \n", 50 | "def show_box(box, ax):\n", 51 | " x0, y0 = box[0], box[1]\n", 52 | " w, h = box[2] - box[0], box[3] - box[1]\n", 53 | " ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "image = cv2.imread('./images/truck.jpg')\n", 63 | "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", 64 | "input_point = np.array([[600, 655]])\n", 65 | "input_label = np.array([1])\n", 66 | "input_box = np.array([400, 600, 700, 900])" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "start = time.time()\n", 76 | "predictor_huge.set_image(image)\n", 77 | "end = time.time()\n", 78 | "\n", 79 | "masks, scores, logits = predictor_huge.predict(\n", 80 | " point_coords=input_point,\n", 81 | " point_labels=input_label,\n", 82 | " # box=input_box[None, :],\n", 83 | " multimask_output=False,\n", 84 | ")\n", 85 | "\n", 86 | "plt.figure(figsize=(10,10))\n", 87 | "plt.imshow(image)\n", 88 | "show_mask(masks, plt.gca())\n", 89 | "show_points(input_point, input_label, plt.gca())\n", 90 | "# show_box(input_box, plt.gca())\n", 91 | "plt.axis('on')\n", 92 | "plt.title(\"SAM_huge time: \" + str(end - start) + \"s\")\n", 93 | "plt.show()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "start 
= time.time()\n", 103 | "predictor_mini.set_image(image)\n", 104 | "end = time.time()\n", 105 | "\n", 106 | "masks, scores, logits = predictor_mini.predict(\n", 107 | " point_coords=input_point,\n", 108 | " point_labels=input_label,\n", 109 | " # box=input_box[None, :],\n", 110 | " multimask_output=False,\n", 111 | ")\n", 112 | "\n", 113 | "plt.figure(figsize=(10,10))\n", 114 | "plt.imshow(image)\n", 115 | "show_mask(masks, plt.gca())\n", 116 | "show_points(input_point, input_label, plt.gca())\n", 117 | "# show_box(input_box, plt.gca())\n", 118 | "plt.title(\"SAM_mini time: \" + str(end - start) + \"s\")\n", 119 | "plt.axis('on')\n", 120 | "plt.show() " 121 | ] 122 | } 123 | ], 124 | "metadata": { 125 | "kernelspec": { 126 | "display_name": "sam_mini", 127 | "language": "python", 128 | "name": "python3" 129 | }, 130 | "language_info": { 131 | "codemirror_mode": { 132 | "name": "ipython", 133 | "version": 3 134 | }, 135 | "file_extension": ".py", 136 | "mimetype": "text/x-python", 137 | "name": "python", 138 | "nbconvert_exporter": "python", 139 | "pygments_lexer": "ipython3", 140 | "version": "3.8.16" 141 | }, 142 | "orig_nbformat": 4 143 | }, 144 | "nbformat": 4, 145 | "nbformat_minor": 2 146 | } 147 | -------------------------------------------------------------------------------- /mini_segment_anything.py: -------------------------------------------------------------------------------- 1 | # ImageEncoderYOSO for SAM-mini 2 | # from https://github.com/hujiecpp/YOSO/blob/main/projects/YOSO/yoso/neck.py 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from detectron2.config import get_cfg 9 | from detectron2.config import CfgNode as CN 10 | import fvcore.nn.weight_init as weight_init 11 | from detectron2.modeling import build_backbone 12 | from detectron2.layers import DeformConv, ModulatedDeformConv 13 | 14 | class DeformLayer(nn.Module): 15 | def __init__(self, in_planes, out_planes, deconv_kernel=4, deconv_stride=2, deconv_pad=1, deconv_out_pad=0, modulate_deform=True, num_groups=1, deform_num_groups=1, dilation=1): 16 | super(DeformLayer, self).__init__() 17 | self.deform_modulated = modulate_deform 18 | if modulate_deform: 19 | deform_conv_op = ModulatedDeformConv 20 | offset_channels = 27 21 | else: 22 | deform_conv_op = DeformConv 23 | offset_channels = 18 24 | 25 | self.dcn_offset = nn.Conv2d(in_planes, offset_channels * deform_num_groups, kernel_size=3, stride=1, padding=1*dilation, dilation=dilation) 26 | self.dcn = deform_conv_op(in_planes, out_planes, kernel_size=3, stride=1, padding=1*dilation, bias=False, groups=num_groups, dilation=dilation, deformable_groups=deform_num_groups) 27 | for layer in [self.dcn]: 28 | weight_init.c2_msra_fill(layer) 29 | 30 | nn.init.constant_(self.dcn_offset.weight, 0) 31 | nn.init.constant_(self.dcn_offset.bias, 0) 32 | 33 | self.dcn_bn = nn.SyncBatchNorm(out_planes) 34 | self.up_sample = nn.ConvTranspose2d(in_channels=out_planes, out_channels=out_planes, kernel_size=deconv_kernel, stride=deconv_stride, padding=deconv_pad, output_padding=deconv_out_pad, bias=False) 35 | self._deconv_init() 36 | self.up_bn = nn.SyncBatchNorm(out_planes) 37 | self.relu = nn.ReLU() 38 | 39 | def forward(self, x): 40 | out = x 41 | if self.deform_modulated: 42 | offset_mask = self.dcn_offset(out) 43 | offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1) 44 | offset = torch.cat((offset_x, offset_y), dim=1) 45 | mask = mask.sigmoid() 46 | out = self.dcn(out, offset, mask) 47 | else: 48 | offset = 
self.dcn_offset(out)  # plain deformable conv: predicted (x, y) offsets only, no modulation mask
49 |             out = self.dcn(out, offset)
50 |         x = out
51 | 
52 |         x = self.dcn_bn(x)
53 |         x = self.relu(x)
54 |         x = self.up_sample(x)  # 2x transposed-conv upsampling (bilinear-initialized)
55 |         x = self.up_bn(x)
56 |         x = self.relu(x)
57 |         return x
58 | 
59 |     def _deconv_init(self):
60 |         w = self.up_sample.weight.data  # initialize the transposed conv as a bilinear upsampling kernel
61 |         f = math.ceil(w.size(2) / 2)
62 |         c = (2 * f - 1 - f % 2) / (2. * f)
63 |         for i in range(w.size(2)):
64 |             for j in range(w.size(3)):
65 |                 w[0, 0, i, j] = \
66 |                     (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
67 |         for c in range(1, w.size(0)):
68 |             w[c, 0, :, :] = w[0, 0, :, :]  # replicate the bilinear kernel across channels
69 | 
70 | 
71 | class LiteDeformConv(nn.Module):
72 |     def __init__(self, cfg, backbone_shape):
73 |         super(LiteDeformConv, self).__init__()
74 |         in_features = cfg.MODEL.YOSO.IN_FEATURES
75 |         in_channels = []
76 |         out_channels = [cfg.MODEL.YOSO.AGG_DIM]
77 |         for feat in in_features:
78 |             tmp = backbone_shape[feat].channels
79 |             in_channels.append(tmp)
80 |             out_channels.append(tmp//2)  # with the R50 backbone: out_channels = [128, 128, 256, 512, 1024]
81 | 
82 |         self.lateral_conv0 = nn.Conv2d(in_channels=in_channels[-1], out_channels=out_channels[-1], kernel_size=1, stride=1, padding=0)
83 |         self.deform_conv1 = DeformLayer(in_planes=out_channels[-1], out_planes=out_channels[-2])
84 |         self.lateral_conv1 = nn.Conv2d(in_channels=in_channels[-2], out_channels=out_channels[-2], kernel_size=1, stride=1, padding=0)
85 |         self.deform_conv2 = DeformLayer(in_planes=out_channels[-2], out_planes=out_channels[-3])
86 |         self.lateral_conv2 = nn.Conv2d(in_channels=in_channels[-3], out_channels=out_channels[-3], kernel_size=1, stride=1, padding=0)
87 |         self.deform_conv3 = DeformLayer(in_planes=out_channels[-3], out_planes=out_channels[-4])
88 |         self.lateral_conv3 = nn.Conv2d(in_channels=in_channels[-4], out_channels=out_channels[-4], kernel_size=1, stride=1, padding=0)
89 |         self.output_conv = nn.Conv2d(in_channels=out_channels[-5], out_channels=out_channels[-5], kernel_size=3, stride=1, padding=1)
90 |         self.bias = nn.Parameter(torch.FloatTensor(1,out_channels[-5],1,1), requires_grad=True)
91 |         self.bias.data.fill_(0.0)
92 | 
93 |         self.conv_a5 = nn.Conv2d(in_channels=out_channels[-1], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
94 |         self.conv_a4 = nn.Conv2d(in_channels=out_channels[-2], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
95 |         self.conv_a3 = nn.Conv2d(in_channels=out_channels[-3], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
96 |         self.conv_a2 = nn.Conv2d(in_channels=out_channels[-4], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
97 | 
98 |     def forward(self, features_list):
99 |         p5 = self.lateral_conv0(features_list[-1])  # top-down FPN-style pathway, starting from res5 (1/32 scale)
100 |         x5 = p5
101 |         x = self.deform_conv1(x5)
102 | 
103 |         p4 = self.lateral_conv1(features_list[-2])
104 |         x4 = p4 + x
105 |         x = self.deform_conv2(x4)
106 | 
107 |         p3 = self.lateral_conv2(features_list[-3])
108 |         x3 = p3 + x
109 |         x = self.deform_conv3(x3)
110 | 
111 |         p2 = self.lateral_conv3(features_list[-4])
112 |         x2 = p2 + x  # finest level, 1/4 scale
113 | 
114 |         # CFA: aggregate all pyramid levels at 1/4 scale via 1x1 convs and bilinear upsampling
115 |         x5 = F.interpolate(self.conv_a5(x5), scale_factor=8, align_corners=False, mode='bilinear')
116 |         x4 = F.interpolate(self.conv_a4(x4), scale_factor=4, align_corners=False, mode='bilinear')
117 |         x3 = F.interpolate(self.conv_a3(x3), scale_factor=2, align_corners=False, mode='bilinear')
118 |         x2 = self.conv_a2(x2)
119 |         x = x5 + x4 + x3 + x2 + self.bias
120 | 
121 |         x = self.output_conv(x)
122 | 
123 |         return x
124 | 
125 | 
126 | class YOSONeck(nn.Module):
127 |     def __init__(self, cfg, backbone_shape):
128 |         super().__init__()
129 | 
130 |         self.deconv = LiteDeformConv(cfg=cfg, backbone_shape=backbone_shape)
131 |         self.loc_conv = nn.Conv2d(in_channels=128+2, out_channels=cfg.MODEL.YOSO.HIDDEN_DIM, kernel_size=1, stride=1)  # 128 neck channels + 2 coordinate channels
132 | 
133 |         self.conv_1 = nn.Conv2d(in_channels=cfg.MODEL.YOSO.HIDDEN_DIM, out_channels=cfg.MODEL.YOSO.HIDDEN_DIM, kernel_size=4, stride=2, padding=1)  # 1/4 -> 1/8 scale
134 |         self.relu = nn.ReLU()
135 |         self.conv_2 = nn.Conv2d(in_channels=cfg.MODEL.YOSO.HIDDEN_DIM, out_channels=cfg.MODEL.YOSO.HIDDEN_DIM, kernel_size=4, stride=2, padding=1)  # 1/8 -> 1/16 scale, matching SAM's embedding grid
136 | 
137 |     def generate_coord(self, input_feat):
138 |         x_range = torch.linspace(-1, 1, input_feat.shape[-1], device=input_feat.device)  # CoordConv-style normalized coordinates in [-1, 1]
139 |         y_range = torch.linspace(-1, 1, input_feat.shape[-2], device=input_feat.device)
140 |         y, x = torch.meshgrid(y_range, x_range)
141 |         y = y.expand([input_feat.shape[0], 1, -1, -1])
142 |         x = x.expand([input_feat.shape[0], 1, -1, -1])
143 |         coord_feat = torch.cat([x, y], 1)
144 |         return coord_feat
145 | 
146 |     def forward(self, features_list):
147 |         features = self.deconv(features_list)
148 |         coord_feat = self.generate_coord(features)
149 |         features = torch.cat([features, coord_feat], 1)
150 |         features = self.loc_conv(features)
151 | 
152 |         # features = F.interpolate(features, scale_factor=0.25, mode="bilinear", align_corners=False)
153 |         features = self.relu(features)
154 |         features = self.relu(self.conv_1(features))
155 |         features = self.conv_2(features)
156 | 
157 |         return features
158 | 
159 | class ImageEncoderYOSO(nn.Module):
160 |     def __init__(self, img_size: int = 1024):
161 |         super().__init__()
162 |         self.img_size = img_size
163 | 
164 |         cfg = get_cfg()
165 |         cfg.MODEL.YOSO = CN()
166 |         cfg.MODEL.YOSO.IN_FEATURES = ["res2", "res3", "res4", "res5"]
167 |         cfg.MODEL.YOSO.HIDDEN_DIM = 256  # must equal SAM's prompt_embed_dim
168 |         cfg.MODEL.YOSO.AGG_DIM = 128
169 |         cfg.MODEL.BACKBONE.NAME = "build_resnet_backbone"
170 |         cfg.MODEL.BACKBONE.FREEZE_AT = 0
171 |         cfg.MODEL.RESNETS.DEPTH = 50
172 |         cfg.MODEL.RESNETS.STRIDE_IN_1X1 = False
173 |         cfg.MODEL.RESNETS.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
174 | 
175 |         self.in_features = cfg.MODEL.YOSO.IN_FEATURES
176 |         self.backbone = build_backbone(cfg)
177 |         self.yoso_neck = YOSONeck(cfg=cfg, backbone_shape=self.backbone.output_shape())
178 |         # self.to('cuda')
179 | 
180 |     def forward(self, images):
181 |         backbone_feats = self.backbone(images)
182 |         # print(backbone_feats)
183 |         features = list()
184 |         for f in self.in_features:
185 |             features.append(backbone_feats[f])
186 | 
187 |         neck_feats = self.yoso_neck(features)  # (B, 256, H/16, W/16), same shape as SAM's ViT embeddings
188 |         return neck_feats
189 | 
190 | # build SAM_mini
191 | from segment_anything.modeling import MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
192 | 
193 | def build_sam_yoso_r50(checkpoint=None):
194 |     prompt_embed_dim = 256
195 |     image_size = 1024
196 |     vit_patch_size = 16  # kept from the ViT config so the embedding grid stays 64x64
197 |     image_embedding_size = image_size // vit_patch_size
198 |     sam = Sam(
199 |         image_encoder=ImageEncoderYOSO(),  # distilled replacement for SAM's ViT encoder
200 |         prompt_encoder=PromptEncoder(
201 |             embed_dim=prompt_embed_dim,
202 |             image_embedding_size=(image_embedding_size, image_embedding_size),
203 |             input_image_size=(image_size, image_size),
204 |             mask_in_chans=16,
205 |         ),
206 |         mask_decoder=MaskDecoder(
207 |             num_multimask_outputs=3,
208 |             transformer=TwoWayTransformer(
209 |                 depth=2,
210 |                 embedding_dim=prompt_embed_dim,
211 |                 mlp_dim=2048,
212 |                 num_heads=8,
213 |             ),
214 |             transformer_dim=prompt_embed_dim,
215 |             iou_head_depth=3,
216 |             iou_head_hidden_dim=256,
217 |         ),
218 |         pixel_mean=[123.675, 116.28, 103.53],
219 |         pixel_std=[58.395, 57.12, 57.375],
220 |     )
221 |     sam.eval()
222 |     if checkpoint is not None:
223 |         with open(checkpoint, "rb") as f:
224 |             state_dict = torch.load(f)  # tensors load onto their saved device; pass map_location='cpu' for CPU-only use
225 |         sam.load_state_dict(state_dict, strict=True)
226 |     return sam
--------------------------------------------------------------------------------
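
A note on shape compatibility: because the prompt encoder and mask decoder are reused from SAM unchanged, `ImageEncoderYOSO` must emit embeddings of the same shape as SAM's ViT encoder, i.e. `(B, 256, 64, 64)` for a 1024x1024 input; that is why `YOSONeck` ends with two stride-2 convs that bring the 1/4-scale neck output down to 1/16. A minimal sanity-check sketch (illustrative, not part of the repo; it assumes a CUDA device, which the detectron2 deformable-conv ops require):

```python
import torch
from mini_segment_anything import ImageEncoderYOSO

# Randomly initialized weights are fine here; we only check the output shape.
encoder = ImageEncoderYOSO().to('cuda').eval()
with torch.no_grad():
    x = torch.randn(1, 3, 1024, 1024, device='cuda')
    emb = encoder(x)
print(emb.shape)  # expected: torch.Size([1, 256, 64, 64])
```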