├── images
│   └── truck.jpg
├── examples
│   ├── box_f1.png
│   ├── box_s1.png
│   ├── point_f1.png
│   └── point_s1.png
├── LICENSE
├── README.md
├── .gitignore
├── demo.ipynb
└── mini_segment_anything.py
/images/truck.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hujiecpp/Mini-Segment-Anything/HEAD/images/truck.jpg
--------------------------------------------------------------------------------
/examples/box_f1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hujiecpp/Mini-Segment-Anything/HEAD/examples/box_f1.png
--------------------------------------------------------------------------------
/examples/box_s1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hujiecpp/Mini-Segment-Anything/HEAD/examples/box_s1.png
--------------------------------------------------------------------------------
/examples/point_f1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hujiecpp/Mini-Segment-Anything/HEAD/examples/point_f1.png
--------------------------------------------------------------------------------
/examples/point_s1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hujiecpp/Mini-Segment-Anything/HEAD/examples/point_s1.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Jie Hu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | In this repo, we distill the powerful [Segment Anything](https://github.com/facebookresearch/segment-anything) models into lightweight [YOSO](https://github.com/hujiecpp/YOSO) models for efficient image segmentation (YOSO is a framework aimed at real-time panoptic segmentation).
2 | Specifically, we replace the heavy ViT image encoder with the YOSO image encoder.
3 | This yields an over 10x speedup when extracting image features on a single RTX 3090 GPU (see the timing sketch after this README).
4 | For now, we release demo checkpoints here and will update them as more accurate ones become available.
5 |
6 |
7 | ## Examples
8 | - Good segmentation results
9 |   ![point_s1](./examples/point_s1.png)
10 |   ![box_s1](./examples/box_s1.png)
11 |
12 | - Bad segmentation results
13 |   ![point_f1](./examples/point_f1.png)
14 |   ![box_f1](./examples/box_f1.png)
15 |
16 | ## Getting Started
17 | ### Installation
18 | We recommend using [Anaconda](https://www.anaconda.com/) for installation.
19 | ```bash
20 | conda create --name sam_mini python=3.8 -y
21 | conda install pytorch==1.10.1 torchvision==0.11.2 cudatoolkit=11.3 -c pytorch
22 | pip install git+https://github.com/facebookresearch/detectron2.git
23 | pip install git+https://github.com/facebookresearch/segment-anything.git
24 | pip install opencv-python pycocotools matplotlib
25 | ```
26 |
27 | ### Checkpoints
28 | - 2023/04/27: [sam_yoso_r50_13a999.pth](https://github.com/hujiecpp/Mini-Segment-Anything/releases/download/checkpoint/sam_yoso_r50_13a999.pth)
29 |
30 |
31 | ### Demos
32 | After loading the lightweight model, the remaining operations are the same as in Segment Anything.
33 | - Predictor:
34 | ```python
35 | from mini_segment_anything import build_sam_yoso_r50
36 | from segment_anything import SamPredictor
37 |
38 | sam_checkpoint = './sam_yoso_r50_13a999.pth'
39 | sam_mini = build_sam_yoso_r50(checkpoint=sam_checkpoint).to('cuda')
40 | predictor_mini = SamPredictor(sam_mini)
41 | ```
42 |
43 | - Automatic mask generator:
44 | ```python
45 | from mini_segment_anything import build_sam_yoso_r50
46 | from segment_anything import SamAutomaticMaskGenerator
47 |
48 | sam_checkpoint = './sam_yoso_r50_13a999.pth'
49 | sam_mini = build_sam_yoso_r50(checkpoint=sam_checkpoint).to('cuda')
50 | mask_generator_mini = SamAutomaticMaskGenerator(sam_mini)
51 | ```
52 | More details can be found in `./demo.ipynb`.
53 |
54 | ## Todo List
55 | - [x] Distill with image encoder.
56 | - [ ] Distill with mask decoder.
57 |
58 | ## Citation
59 |
60 | If you find this project helpful for your research, please consider citing the following BibTeX entries.
61 |
62 | ```BibTeX
63 | @article{kirillov2023segany,
64 | title={Segment Anything},
65 | author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
66 | journal={arXiv:2304.02643},
67 | year={2023}
68 | }
69 |
70 | @article{hu2023yoso,
71 | title={You Only Segment Once: Towards Real-Time Panoptic Segmentation},
72 | author={Hu, Jie and Huang, Linyan and Ren, Tianhe and Zhang, Shengchuan and Ji, Rongrong and Cao, Liujuan},
73 | journal={arXiv preprint arXiv:2303.14651},
74 | year={2023}
75 | }
76 | ```
77 |
--------------------------------------------------------------------------------
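A quick way to check the README's speedup claim. The sketch below is not part of the repo; it times `SamPredictor.set_image`, which runs the image encoder, for both models. It assumes a CUDA GPU and that `sam_vit_h_4b8939.pth` and `sam_yoso_r50_13a999.pth` have been downloaded to the working directory, mirroring the cells in `demo.ipynb`.

```python
import time
import cv2
import torch
from mini_segment_anything import build_sam_yoso_r50
from segment_anything import sam_model_registry, SamPredictor

image = cv2.cvtColor(cv2.imread('./images/truck.jpg'), cv2.COLOR_BGR2RGB)

def time_encoder(predictor, image, warmup=1, runs=5):
    """Average wall-clock time of set_image, i.e. of the image encoder."""
    for _ in range(warmup):
        predictor.set_image(image)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(runs):
        predictor.set_image(image)
    torch.cuda.synchronize()
    return (time.time() - start) / runs

predictor_huge = SamPredictor(sam_model_registry['vit_h'](checkpoint='./sam_vit_h_4b8939.pth').to('cuda'))
predictor_mini = SamPredictor(build_sam_yoso_r50(checkpoint='./sam_yoso_r50_13a999.pth').to('cuda'))

t_huge = time_encoder(predictor_huge, image)
t_mini = time_encoder(predictor_mini, image)
print(f'ViT-H: {t_huge:.3f}s  YOSO-R50: {t_mini:.3f}s  speedup: {t_huge / t_mini:.1f}x')
```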
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from mini_segment_anything import build_sam_yoso_r50\n",
10 | "from segment_anything import sam_model_registry, SamPredictor\n",
11 | "\n",
12 | "# sam-huge\n",
13 | "sam_checkpoint = './sam_vit_h_4b8939.pth'\n",
14 | "model_type = 'vit_h'\n",
15 | "sam_huge = sam_model_registry[model_type](checkpoint=sam_checkpoint).to('cuda')\n",
16 | "predictor_huge = SamPredictor(sam_huge)\n",
17 | "\n",
18 | "# sam-mini\n",
19 | "sam_checkpoint = './sam_yoso_r50_13a999.pth'\n",
20 | "sam_mini = build_sam_yoso_r50(checkpoint=sam_checkpoint).to('cuda')\n",
21 | "predictor_mini = SamPredictor(sam_mini)"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "import cv2\n",
31 | "import time\n",
32 | "import numpy as np\n",
33 | "import matplotlib.pyplot as plt\n",
34 | "\n",
35 | "def show_mask(mask, ax, random_color=False):\n",
36 | " if random_color:\n",
37 | " color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
38 | " else:\n",
39 | " color = np.array([30/255, 144/255, 255/255, 0.6])\n",
40 | " h, w = mask.shape[-2:]\n",
41 | " mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n",
42 | " ax.imshow(mask_image)\n",
43 | " \n",
44 | "def show_points(coords, labels, ax, marker_size=375):\n",
45 | " pos_points = coords[labels==1]\n",
46 | " neg_points = coords[labels==0]\n",
47 | " ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
48 | " ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) \n",
49 | " \n",
50 | "def show_box(box, ax):\n",
51 | " x0, y0 = box[0], box[1]\n",
52 | " w, h = box[2] - box[0], box[3] - box[1]\n",
53 | " ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "image = cv2.imread('./images/truck.jpg')\n",
63 | "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
64 | "input_point = np.array([[600, 655]])\n",
65 | "input_label = np.array([1])\n",
66 | "input_box = np.array([400, 600, 700, 900])"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "start = time.time()\n",
76 | "predictor_huge.set_image(image)\n",
77 | "end = time.time()\n",
78 | "\n",
79 | "masks, scores, logits = predictor_huge.predict(\n",
80 | " point_coords=input_point,\n",
81 | " point_labels=input_label,\n",
82 | " # box=input_box[None, :],\n",
83 | " multimask_output=False,\n",
84 | ")\n",
85 | "\n",
86 | "plt.figure(figsize=(10,10))\n",
87 | "plt.imshow(image)\n",
88 | "show_mask(masks, plt.gca())\n",
89 | "show_points(input_point, input_label, plt.gca())\n",
90 | "# show_box(input_box, plt.gca())\n",
91 | "plt.axis('on')\n",
92 | "plt.title(\"SAM_huge time: \" + str(end - start) + \"s\")\n",
93 | "plt.show()"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "start = time.time()\n",
103 | "predictor_mini.set_image(image)\n",
104 | "end = time.time()\n",
105 | "\n",
106 | "masks, scores, logits = predictor_mini.predict(\n",
107 | " point_coords=input_point,\n",
108 | " point_labels=input_label,\n",
109 | " # box=input_box[None, :],\n",
110 | " multimask_output=False,\n",
111 | ")\n",
112 | "\n",
113 | "plt.figure(figsize=(10,10))\n",
114 | "plt.imshow(image)\n",
115 | "show_mask(masks, plt.gca())\n",
116 | "show_points(input_point, input_label, plt.gca())\n",
117 | "# show_box(input_box, plt.gca())\n",
118 | "plt.title(\"SAM_mini time: \" + str(end - start) + \"s\")\n",
119 | "plt.axis('on')\n",
120 | "plt.show() "
121 | ]
122 | }
123 | ],
124 | "metadata": {
125 | "kernelspec": {
126 | "display_name": "sam_mini",
127 | "language": "python",
128 | "name": "python3"
129 | },
130 | "language_info": {
131 | "codemirror_mode": {
132 | "name": "ipython",
133 | "version": 3
134 | },
135 | "file_extension": ".py",
136 | "mimetype": "text/x-python",
137 | "name": "python",
138 | "nbconvert_exporter": "python",
139 | "pygments_lexer": "ipython3",
140 | "version": "3.8.16"
141 | },
142 | "orig_nbformat": 4
143 | },
144 | "nbformat": 4,
145 | "nbformat_minor": 2
146 | }
147 |
--------------------------------------------------------------------------------
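The notebook keeps `box=input_box[None, :]` commented out in both predict cells; enabling it switches from point to box prompts, which is how the `box_*.png` examples were produced. A minimal sketch of that variant, assuming the cells above have been run so that `predictor_mini`, `image`, and `input_box` are defined:

```python
# Box prompt instead of the point prompt; SamPredictor.predict accepts
# an XYXY box (here shaped (1, 4) via None-indexing, as in the notebook).
masks, scores, logits = predictor_mini.predict(
    box=input_box[None, :],
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_box(input_box, plt.gca())
plt.axis('on')
plt.show()
```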
/mini_segment_anything.py:
--------------------------------------------------------------------------------
1 | # ImageEncoderYOSO for SAM-mini
2 | # from https://github.com/hujiecpp/YOSO/blob/main/projects/YOSO/yoso/neck.py
3 |
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from detectron2.config import get_cfg
9 | from detectron2.config import CfgNode as CN
10 | import fvcore.nn.weight_init as weight_init
11 | from detectron2.modeling import build_backbone
12 | from detectron2.layers import DeformConv, ModulatedDeformConv
13 |
14 | class DeformLayer(nn.Module):
15 | def __init__(self, in_planes, out_planes, deconv_kernel=4, deconv_stride=2, deconv_pad=1, deconv_out_pad=0, modulate_deform=True, num_groups=1, deform_num_groups=1, dilation=1):
16 | super(DeformLayer, self).__init__()
17 | self.deform_modulated = modulate_deform
18 | if modulate_deform:
19 | deform_conv_op = ModulatedDeformConv
20 | offset_channels = 27
21 | else:
22 | deform_conv_op = DeformConv
23 | offset_channels = 18
24 |
25 | self.dcn_offset = nn.Conv2d(in_planes, offset_channels * deform_num_groups, kernel_size=3, stride=1, padding=1*dilation, dilation=dilation)
26 | self.dcn = deform_conv_op(in_planes, out_planes, kernel_size=3, stride=1, padding=1*dilation, bias=False, groups=num_groups, dilation=dilation, deformable_groups=deform_num_groups)
27 | for layer in [self.dcn]:
28 | weight_init.c2_msra_fill(layer)
29 |
30 | nn.init.constant_(self.dcn_offset.weight, 0)
31 | nn.init.constant_(self.dcn_offset.bias, 0)
32 |
33 | self.dcn_bn = nn.SyncBatchNorm(out_planes)
34 | self.up_sample = nn.ConvTranspose2d(in_channels=out_planes, out_channels=out_planes, kernel_size=deconv_kernel, stride=deconv_stride, padding=deconv_pad, output_padding=deconv_out_pad, bias=False)
35 | self._deconv_init()
36 | self.up_bn = nn.SyncBatchNorm(out_planes)
37 | self.relu = nn.ReLU()
38 |
39 | def forward(self, x):
40 | out = x
41 | if self.deform_modulated:
42 | offset_mask = self.dcn_offset(out)
43 | offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1)
44 | offset = torch.cat((offset_x, offset_y), dim=1)
45 | mask = mask.sigmoid()
46 | out = self.dcn(out, offset, mask)
47 | else:
48 | offset = self.dcn_offset(out)
49 | out = self.dcn(out, offset)
50 | x = out
51 |
52 | x = self.dcn_bn(x)
53 | x = self.relu(x)
54 | x = self.up_sample(x)
55 | x = self.up_bn(x)
56 | x = self.relu(x)
57 | return x
58 |
59 | def _deconv_init(self):
60 | w = self.up_sample.weight.data
61 | f = math.ceil(w.size(2) / 2)
62 | c = (2 * f - 1 - f % 2) / (2. * f)
63 | for i in range(w.size(2)):
64 | for j in range(w.size(3)):
65 | w[0, 0, i, j] = \
66 | (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
67 | for c in range(1, w.size(0)):
68 | w[c, 0, :, :] = w[0, 0, :, :]
69 |
70 |
71 | class LiteDeformConv(nn.Module):
72 | def __init__(self, cfg, backbone_shape):
73 | super(LiteDeformConv, self).__init__()
74 | in_features = cfg.MODEL.YOSO.IN_FEATURES
75 | in_channels = []
76 | out_channels = [cfg.MODEL.YOSO.AGG_DIM]
77 | for feat in in_features:
78 | tmp = backbone_shape[feat].channels
79 | in_channels.append(tmp)
80 | out_channels.append(tmp//2)
81 |
82 | self.lateral_conv0 = nn.Conv2d(in_channels=in_channels[-1], out_channels=out_channels[-1], kernel_size=1, stride=1, padding=0)
83 | self.deform_conv1 = DeformLayer(in_planes=out_channels[-1], out_planes=out_channels[-2])
84 | self.lateral_conv1 = nn.Conv2d(in_channels=in_channels[-2], out_channels=out_channels[-2], kernel_size=1, stride=1, padding=0)
85 | self.deform_conv2 = DeformLayer(in_planes=out_channels[-2], out_planes=out_channels[-3])
86 | self.lateral_conv2 = nn.Conv2d(in_channels=in_channels[-3], out_channels=out_channels[-3], kernel_size=1, stride=1, padding=0)
87 | self.deform_conv3 = DeformLayer(in_planes=out_channels[-3], out_planes=out_channels[-4])
88 | self.lateral_conv3 = nn.Conv2d(in_channels=in_channels[-4], out_channels=out_channels[-4], kernel_size=1, stride=1, padding=0)
89 | self.output_conv = nn.Conv2d(in_channels=out_channels[-5], out_channels=out_channels[-5], kernel_size=3, stride=1, padding=1)
90 | self.bias = nn.Parameter(torch.FloatTensor(1,out_channels[-5],1,1), requires_grad=True)
91 | self.bias.data.fill_(0.0)
92 |
93 | self.conv_a5 = nn.Conv2d(in_channels=out_channels[-1], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
94 | self.conv_a4 = nn.Conv2d(in_channels=out_channels[-2], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
95 | self.conv_a3 = nn.Conv2d(in_channels=out_channels[-3], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
96 | self.conv_a2 = nn.Conv2d(in_channels=out_channels[-4], out_channels=out_channels[-5], kernel_size=1, stride=1, padding=0, bias=False)
97 |
98 | def forward(self, features_list):
99 | p5 = self.lateral_conv0(features_list[-1])
100 | x5 = p5
101 | x = self.deform_conv1(x5)
102 |
103 | p4 = self.lateral_conv1(features_list[-2])
104 | x4 = p4 + x
105 | x = self.deform_conv2(x4)
106 |
107 | p3 = self.lateral_conv2(features_list[-3])
108 | x3 = p3 + x
109 | x = self.deform_conv3(x3)
110 |
111 | p2 = self.lateral_conv3(features_list[-4])
112 | x2 = p2 + x
113 |
114 | # CFA
115 | x5 = F.interpolate(self.conv_a5(x5), scale_factor=8, align_corners=False, mode='bilinear')
116 | x4 = F.interpolate(self.conv_a4(x4), scale_factor=4, align_corners=False, mode='bilinear')
117 | x3 = F.interpolate(self.conv_a3(x3), scale_factor=2, align_corners=False, mode='bilinear')
118 | x2 = self.conv_a2(x2)
119 | x = x5 + x4 + x3 + x2 + self.bias
120 |
121 | x = self.output_conv(x)
122 |
123 | return x
124 |
125 |
126 | class YOSONeck(nn.Module):
127 | def __init__(self, cfg, backbone_shape):
128 | super().__init__()
129 |
130 | self.deconv = LiteDeformConv(cfg=cfg, backbone_shape=backbone_shape)
131 | self.loc_conv = nn.Conv2d(in_channels=cfg.MODEL.YOSO.AGG_DIM + 2, out_channels=cfg.MODEL.YOSO.HIDDEN_DIM, kernel_size=1, stride=1)  # +2 for the x/y coordinate channels
132 |
133 | self.conv_1 = nn.Conv2d(in_channels=cfg.MODEL.YOSO.HIDDEN_DIM, out_channels=cfg.MODEL.YOSO.HIDDEN_DIM, kernel_size=4, stride=2, padding=1)
134 | self.relu = nn.ReLU()
135 | self.conv_2 = nn.Conv2d(in_channels=cfg.MODEL.YOSO.HIDDEN_DIM, out_channels=cfg.MODEL.YOSO.HIDDEN_DIM, kernel_size=4, stride=2, padding=1)
136 |
137 | def generate_coord(self, input_feat):
138 | x_range = torch.linspace(-1, 1, input_feat.shape[-1], device=input_feat.device)
139 | y_range = torch.linspace(-1, 1, input_feat.shape[-2], device=input_feat.device)
140 | y, x = torch.meshgrid(y_range, x_range, indexing='ij')  # explicit 'ij' matches the pre-1.10 default and silences the deprecation warning
141 | y = y.expand([input_feat.shape[0], 1, -1, -1])
142 | x = x.expand([input_feat.shape[0], 1, -1, -1])
143 | coord_feat = torch.cat([x, y], 1)
144 | return coord_feat
145 |
146 | def forward(self, features_list):
147 | features = self.deconv(features_list)
148 | coord_feat = self.generate_coord(features)
149 | features = torch.cat([features, coord_feat], 1)
150 | features = self.loc_conv(features)
151 |
152 | # features = F.interpolate(features, scale_factor=0.25, mode="bilinear", align_corners=False)
153 | features = self.relu(features)
154 | features = self.relu(self.conv_1(features))
155 | features = self.conv_2(features)
156 |
157 | return features
158 |
159 | class ImageEncoderYOSO(nn.Module):
160 | def __init__(self, img_size: int = 1024):
161 | super().__init__()
162 | self.img_size = img_size
163 |
164 | cfg = get_cfg()
165 | cfg.MODEL.YOSO = CN()
166 | cfg.MODEL.YOSO.IN_FEATURES = ["res2", "res3", "res4", "res5"]
167 | cfg.MODEL.YOSO.HIDDEN_DIM = 256
168 | cfg.MODEL.YOSO.AGG_DIM = 128
169 | cfg.MODEL.BACKBONE.NAME = "build_resnet_backbone"
170 | cfg.MODEL.BACKBONE.FREEZE_AT = 0
171 | cfg.MODEL.RESNETS.DEPTH = 50
172 | cfg.MODEL.RESNETS.STRIDE_IN_1X1 = False
173 | cfg.MODEL.RESNETS.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
174 |
175 | self.in_features = cfg.MODEL.YOSO.IN_FEATURES
176 | self.backbone = build_backbone(cfg)
177 | self.yoso_neck = YOSONeck(cfg=cfg, backbone_shape=self.backbone.output_shape())
178 | # self.to('cuda')
179 |
180 | def forward(self, images):
181 | backbone_feats = self.backbone(images)
182 | # collect features in the order given by in_features: res2, res3, res4, res5
183 | features = list()
184 | for f in self.in_features:
185 | features.append(backbone_feats[f])
186 |
187 | neck_feats = self.yoso_neck(features)
188 | return neck_feats
189 |
190 | # build SAM_mini
191 | from segment_anything.modeling import MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
192 |
193 | def build_sam_yoso_r50(checkpoint=None):
194 | prompt_embed_dim = 256
195 | image_size = 1024
196 | vit_patch_size = 16
197 | image_embedding_size = image_size // vit_patch_size
198 | sam = Sam(
199 | image_encoder=ImageEncoderYOSO(),
200 | prompt_encoder=PromptEncoder(
201 | embed_dim=prompt_embed_dim,
202 | image_embedding_size=(image_embedding_size, image_embedding_size),
203 | input_image_size=(image_size, image_size),
204 | mask_in_chans=16,
205 | ),
206 | mask_decoder=MaskDecoder(
207 | num_multimask_outputs=3,
208 | transformer=TwoWayTransformer(
209 | depth=2,
210 | embedding_dim=prompt_embed_dim,
211 | mlp_dim=2048,
212 | num_heads=8,
213 | ),
214 | transformer_dim=prompt_embed_dim,
215 | iou_head_depth=3,
216 | iou_head_hidden_dim=256,
217 | ),
218 | pixel_mean=[123.675, 116.28, 103.53],
219 | pixel_std=[58.395, 57.12, 57.375],
220 | )
221 | sam.eval()
222 | if checkpoint is not None:
223 | with open(checkpoint, "rb") as f:
224 | state_dict = torch.load(f)
225 | sam.load_state_dict(state_dict, strict=True)
226 | return sam
--------------------------------------------------------------------------------
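As a final sanity check (a sketch, not part of the repo): `ImageEncoderYOSO` reduces a 1024x1024 input by an overall stride of 16 (the neck aggregates everything to the stride-4 `res2` resolution, and the two stride-2 convolutions in `YOSONeck` bring it to stride 16), so its output should match the 64x64, 256-channel image embedding that SAM's prompt encoder and mask decoder expect. detectron2's deformable convolution has no CPU kernel, so a CUDA GPU is assumed.

```python
import torch
from mini_segment_anything import ImageEncoderYOSO

# eval(): SyncBatchNorm uses running stats and needs no distributed process group
encoder = ImageEncoderYOSO().to('cuda').eval()
x = torch.randn(1, 3, 1024, 1024, device='cuda')
with torch.no_grad():
    feats = encoder(x)
print(feats.shape)  # expected: torch.Size([1, 256, 64, 64]), i.e. 1024 / 16 = 64
```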