├── .gitignore ├── LICENSE ├── README.md ├── configs └── FERNet300_coco.py └── mmdet ├── backbones └── compositebackbone.py └── dense_heads └── prs_head.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 EdSong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dual Refinement Underwater Object Detection Network
2 | 
3 | This is an unofficial implementation of [Dual Refinement Underwater Object Detection Network](https://link.springer.com/chapter/10.1007/978-3-030-58565-5_17) (ECCV 2020).
4 | 
5 | ### 2022/04/15 Update
6 | 
7 | 1. Compatible with the current version of MMDetection.
8 | 2. Borrowed code from Cascade-RPN (Adaptive Conv) to implement the refinement stage. In detail, we use one shared Adaptive Conv followed by two separate cls and reg convs to obtain the final class predictions and bboxes.
9 | 3. The performance of FERNet300 can now surpass SSD300.
10 | 
11 | ### SETUP
12 | 
13 | This implementation is based on [mmdetection](https://github.com/open-mmlab/mmdetection).
14 | 
15 | Follow the installation instructions of [mmdetection](https://github.com/open-mmlab/mmdetection), and copy the code of this repository into the [mmdetection](https://github.com/open-mmlab/mmdetection) directory.
16 | 
17 | ### ENVIRONMENT
18 | 
19 | Python == 3.7.6
20 | 
21 | PyTorch == 1.5.1
22 | 
23 | torchvision == 0.6.1
24 | 
25 | numpy == 1.18.1
26 | 
27 | pillow == 7.0.0
28 | 
29 | ### DATASET
30 | 
31 | Since the authors of FERNet have not open-sourced the UWD dataset, we use the UTDAC2020 dataset instead; the download link is given below.
32 | 
33 | https://drive.google.com/file/d/1avyB-ht3VxNERHpAwNTuBRFOxiXDMczI/view?usp=sharing
34 | 
35 | It is recommended to symlink the dataset directory to the repository root.
36 | 
37 | ```
38 | FERNet
39 | ├── data
40 | │   ├── UTDAC2020
41 | │   │   ├── train2017
42 | │   │   ├── val2017
43 | │   │   ├── annotations
44 | ```
45 | 
46 | ### TRAIN
47 | 
48 | ```shell
49 | python tools/train.py configs/FERNet300_coco.py
50 | ```
51 | 
52 | ### TEST
53 | 
54 | ```shell
55 | python tools/test.py configs/FERNet300_coco.py path/to/checkpoints
56 | ```
57 | 
58 | ### DETAILS OF THE IMPLEMENTATION
59 | 
60 | Since the official paper omits many implementation details, I decided them myself when reimplementing the code, as listed below.
61 | 
62 | - In CCB, the paper selects the $150 \times 150$, $75 \times 75$ and $38 \times 38$ feature maps of the lead backbone, but it does not clearly point out which layers are used for feature fusion (features from several layers share the same resolution). We use the VGG16 features right after the channel size reaches 64, 256 and 512 respectively. For ResNet-50, we use the feature after the first conv (before the first max pooling) and the outputs of Stage 1 and Stage 2 for feature fusion.
63 | - It seems that the original SSD does its downsampling in the extra layers. If RFAM also downsampled, the sizes of the prediction feature maps would no longer match, so I remove the stride in the RFAM block.
64 | - In PRS, the paper mentions applying dilation in the DCN but does not give the dilation rate. I set it to 3; see the sketch below.
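  A minimal sketch of that choice, using the same `mmcv.ops.DeformConv2d` op imported in `prs_head.py` (the 256-channel, 38×38 tensor sizes here are only illustrative, not the exact ones used in the head):

  ```python
  import torch
  from mmcv.ops import DeformConv2d  # the deform-conv op generally needs mmcv-full built with CUDA

  feat = torch.rand(1, 256, 38, 38).cuda()
  # 3x3 deformable conv with dilation rate 3; padding=3 keeps the spatial size.
  dcn = DeformConv2d(256, 256, kernel_size=3, padding=3, dilation=3).cuda()
  # two offsets (dy, dx) per kernel location: 2 * 3 * 3 = 18 channels
  offset = torch.zeros(1, 18, 38, 38).cuda()
  out = dcn(feat, offset)  # -> (1, 256, 38, 38)
  ```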
65 | - The paper says that the pre-processing phase performs binary classification and that the refinement phase is fine-tuned. I keep the pre-processing phase as multi-class classification and sum the softmax logits of all classes except the background.
66 | - In PRS, since each anchor at a given pixel has its own offset ($\Delta x, \Delta y$), we use grouped DCN and set $groups = deform\_groups = num\_anchors$.
67 | - Instead of feeding the offset ($\Delta x, \Delta y$) to the DCN directly, I use an FC layer to process the offset, which seems to give higher performance.
68 | - Since mmdetection has no implementation of random warm-up, I still follow the pre-defined schedule_2x setting in mmdetection.
69 | 
70 | ### SOMETHING TO SAY ABOUT THE EXPERIMENTS
71 | 
72 | - Although I only train the model for 24 epochs, it seems that neither PRS nor CCB improves performance. Thus, I need your help to point out the mistakes in my implementation and to improve it. If the authors of the paper see this repository, please open-source the code as soon as possible.
73 | - In PRS, if you use the feature after the pre-processing phase ($X_{end}$) for the DCN, the performance is much lower than just using $X_{out}$.
74 | - You can fine-tune the refinement phase by feeding feat.detach() and offset.detach() to the DCN. In this way, the DCN will not affect the training of the upstream parameters.
75 | - In RFAM, the first branch and the input have no spatial downsampling while the other two branches do downsample, which leads to a size mismatch. Besides, the authors do not mention the padding scheme for the dilated convolutions, so I use zero padding.
76 | 
--------------------------------------------------------------------------------
/configs/FERNet300_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 |     '../_base_/models/ssd300.py', '../_base_/datasets/utdac_detection_coco.py',
3 |     '../_base_/schedules/schedule_2x.py', '../_base_/default_runtime.py'
4 | ]
5 | # model settings
6 | input_size = 300
7 | model = dict(
8 |     type='SingleStageDetector',
9 |     pretrained='open-mmlab://vgg16_caffe',
10 |     backbone=dict(
11 |         type='CCB',
12 |         input_size=input_size,
13 |         depth=16,
14 |         with_last_pool=False,
15 |         ceil_mode=True,
16 |         out_indices=(3, 4),
17 |         out_feature_indices=(22, 34),
18 |         l2_norm_scale=20,
19 |         init_cfg=dict(
20 |             type='Pretrained', checkpoint='open-mmlab://vgg16_caffe')
21 |     ),
22 |     neck=None,
23 |     bbox_head=dict(
24 |         type='PRSHead',
25 |         in_channels=(512, 1024, 512, 256, 256, 256),
26 |         num_classes=4,
27 |         anchor_generator=dict(
28 |             type='SSDAnchorGenerator',
29 |             scale_major=False,
30 |             input_size=input_size,
31 |             basesize_ratio_range=(0.15, 0.9),
32 |             strides=[8, 16, 32, 64, 100, 300],
33 |             ratios=[[2, 3], [2, 3], [2, 3], [2, 3], [2], [2]]),
34 |         bbox_coder=dict(
35 |             type='DeltaXYWHBBoxCoder',
36 |             target_means=[.0, .0, .0, .0],
37 |             target_stds=[0.1, 0.1, 0.2, 0.2]),
38 |     ),
39 |     train_cfg = dict(
40 |         assigner=dict(
41 |             type='MaxIoUAssigner',
42 |             pos_iou_thr=0.5,
43 |             neg_iou_thr=0.5,
44 |             min_pos_iou=0.,
45 |             ignore_iof_thr=-1,
46 |             gt_max_assign_all=False),
47 |         smoothl1_beta=1.,
48 |         allowed_border=-1,
49 |         pos_weight=-1,
50 |         neg_pos_ratio=3,
51 |         debug=False),
52 |     test_cfg = dict(
53 |         nms_pre=1000,
54 |         nms=dict(type='nms', iou_threshold=0.45),
55 |         min_bbox_size=0,
56 |         score_thr=0.02,
57 |         max_per_img=200))
58 | cudnn_benchmark = True
59 | 
60 | # dataset settings
61 | dataset_type = 'UTDACDataset'
62 | data_root = 'data/UTDAC2020/'
63 | 
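# The four UTDAC2020 category names below must stay consistent with
# num_classes=4 in bbox_head above (the background class is appended inside
# the head, so cls_out_channels becomes num_classes + 1).
# std=[1, 1, 1] in img_norm_cfg means images are only mean-subtracted, which
# matches the caffe-style 'open-mmlab://vgg16_caffe' weights, as in the
# upstream SSD300 config.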
classes = ('echinus','starfish','holothurian','scallop') 64 | img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[1, 1, 1], to_rgb=True) 65 | train_pipeline = [ 66 | dict(type='LoadImageFromFile', to_float32=True), 67 | dict(type='LoadAnnotations', with_bbox=True), 68 | dict( 69 | type='PhotoMetricDistortion', 70 | brightness_delta=32, 71 | contrast_range=(0.5, 1.5), 72 | saturation_range=(0.5, 1.5), 73 | hue_delta=18), 74 | dict( 75 | type='Expand', 76 | mean=img_norm_cfg['mean'], 77 | to_rgb=img_norm_cfg['to_rgb'], 78 | ratio_range=(1, 4)), 79 | dict( 80 | type='MinIoURandomCrop', 81 | min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), 82 | min_crop_size=0.3), 83 | dict(type='Resize', img_scale=(300, 300), keep_ratio=False), 84 | dict(type='Normalize', **img_norm_cfg), 85 | dict(type='RandomFlip', flip_ratio=0.5), 86 | dict(type='DefaultFormatBundle'), 87 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), 88 | ] 89 | test_pipeline = [ 90 | dict(type='LoadImageFromFile'), 91 | dict( 92 | type='MultiScaleFlipAug', 93 | img_scale=(300, 300), 94 | flip=False, 95 | transforms=[ 96 | dict(type='Resize', keep_ratio=False), 97 | dict(type='Normalize', **img_norm_cfg), 98 | dict(type='ImageToTensor', keys=['img']), 99 | dict(type='Collect', keys=['img']), 100 | ]) 101 | ] 102 | data = dict( 103 | samples_per_gpu=12, 104 | workers_per_gpu=6, 105 | train=dict( 106 | _delete_=True, 107 | type='RepeatDataset', 108 | times=5, 109 | dataset=dict( 110 | type=dataset_type, 111 | ann_file=data_root + 'annotations/instances_train2017.json', 112 | img_prefix=data_root + 'train2017/', 113 | pipeline=train_pipeline)), 114 | val=dict(pipeline=test_pipeline), 115 | test=dict(pipeline=test_pipeline)) 116 | # optimizer 117 | optimizer = dict(type='SGD', lr=0.005, momentum=0.9, weight_decay=5e-4) 118 | optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) -------------------------------------------------------------------------------- /mmdet/backbones/compositebackbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from mmcv.cnn import VGG, constant_init, kaiming_init, normal_init, xavier_init 5 | from mmcv.runner import load_checkpoint 6 | 7 | from mmdet.utils import get_root_logger 8 | from ..builder import BACKBONES 9 | from .resnet import ResNet 10 | from mmcv.runner import BaseModule 11 | import warnings 12 | 13 | @BACKBONES.register_module() 14 | class CCB(VGG, BaseModule): 15 | """VGG Backbone network for single-shot-detection. 16 | 17 | Args: 18 | input_size (int): width and height of input, from {300, 512}. 19 | depth (int): Depth of vgg, from {11, 13, 16, 19}. 20 | out_indices (Sequence[int]): Output from which stages. 21 | 22 | Example: 23 | >>> self = CCB(input_size=300, depth=11) 24 | >>> self.eval() 25 | >>> inputs = torch.rand(1, 3, 300, 300) 26 | >>> level_outputs = self.forward(inputs) 27 | >>> for level_out in level_outputs: 28 | ... 
print(tuple(level_out.shape)) 29 | (1, 1024, 19, 19) 30 | (1, 512, 10, 10) 31 | (1, 256, 5, 5) 32 | (1, 256, 3, 3) 33 | (1, 256, 1, 1) 34 | """ 35 | extra_setting = { 36 | 300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256), 37 | 512: (256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128), 38 | } 39 | 40 | def __init__(self, 41 | input_size, 42 | depth, 43 | with_last_pool=False, 44 | ceil_mode=True, 45 | out_indices=(3, 4), 46 | out_feature_indices=(22, 34), 47 | pretrained=None, 48 | init_cfg=None, 49 | l2_norm_scale=20.): 50 | # TODO: in_channels for mmcv.VGG 51 | super(CCB, self).__init__( 52 | depth, 53 | with_last_pool=with_last_pool, 54 | ceil_mode=ceil_mode, 55 | out_indices=out_indices) 56 | assert input_size in (300, 512) 57 | self.input_size = input_size 58 | 59 | self.features.add_module( 60 | str(len(self.features)), 61 | RFAM(512, 0.1) ) 62 | self.features.add_module( 63 | str(len(self.features)), 64 | nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)) 65 | self.features.add_module( 66 | str(len(self.features)), nn.ReLU(inplace=True)) 67 | self.features.add_module( 68 | str(len(self.features)), nn.Conv2d(1024, 1024, kernel_size=1)) 69 | self.features.add_module( 70 | str(len(self.features)), nn.ReLU(inplace=True)) 71 | self.out_feature_indices = out_feature_indices 72 | 73 | self.inplanes = 1024 74 | self.extra = self._make_extra_layers(self.extra_setting[input_size]) 75 | self.l2_norm = L2Norm( 76 | self.features[out_feature_indices[0] - 1].out_channels, 77 | l2_norm_scale) 78 | 79 | 80 | self.resnet50 = ResNet(depth=50, init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')) 81 | 82 | if init_cfg is not None: 83 | self.init_cfg = init_cfg 84 | elif isinstance(pretrained, str): 85 | warnings.warn('DeprecationWarning: pretrained is deprecated, ' 86 | 'please use "init_cfg" instead') 87 | self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) 88 | elif pretrained is None: 89 | self.init_cfg = [ 90 | dict(type='Kaiming', layer='Conv2d'), 91 | dict(type='Constant', val=1, layer='BatchNorm2d'), 92 | dict(type='Normal', std=0.01, layer='Linear'), 93 | ] 94 | else: 95 | raise TypeError('pretrained must be a str or None') 96 | 97 | self.chaAdj = nn.Sequential(nn.Conv2d(128, 64, kernel_size=1), nn.Conv2d(384, 128, kernel_size=1), nn.Conv2d(768, 256, kernel_size=1)) 98 | self.RFAM_PRO = RFAM_PRO(512, 0.1) 99 | self.RFAMs = nn.Sequential( RFAM(512, 0.1), RFAM(1024, 0.1), RFAM(512, 0.1)) 100 | 101 | 102 | def init_weights(self, pretrained=None): 103 | """Initialize the weights in backbone. 104 | 105 | Args: 106 | pretrained (str, optional): Path to pre-trained weights. 107 | Defaults to None. 
108 | """ 109 | if isinstance(pretrained, str): 110 | logger = get_root_logger() 111 | load_checkpoint(self, pretrained, strict=False, logger=logger) 112 | 113 | elif pretrained is None: 114 | for m in self.features.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | kaiming_init(m) 117 | elif isinstance(m, nn.BatchNorm2d): 118 | constant_init(m, 1) 119 | elif isinstance(m, nn.Linear): 120 | normal_init(m, std=0.01) 121 | else: 122 | raise TypeError('pretrained must be a str or None') 123 | 124 | for m in self.extra.modules(): 125 | if isinstance(m, nn.Conv2d): 126 | xavier_init(m, distribution='uniform') 127 | constant_init(self.l2_norm, self.l2_norm.scale) 128 | # self.resnet50.init_weights('torchvision://resnet50') 129 | 130 | 131 | def forward(self, x): 132 | """Forward function.""" 133 | # self.resnet5 134 | outs = [] 135 | resnet_feat = [] 136 | feat = self.resnet50.conv1(x) 137 | feat = self.resnet50.norm1(feat) 138 | feat = self.resnet50.relu(feat) 139 | resnet_feat.append(feat) 140 | feat = self.resnet50.maxpool(feat) 141 | feat = self.resnet50.layer1(feat) 142 | resnet_feat.append(feat) 143 | feat = self.resnet50.layer2(feat) 144 | resnet_feat.append(feat) 145 | count = 0 146 | rfam_count = 0 147 | for i, layer in enumerate(self.features): 148 | x = layer(x) 149 | if type(layer) == nn.MaxPool2d and count < 3: 150 | x = torch.cat((x,resnet_feat[count]), dim=1) 151 | x = self.chaAdj[count](x) 152 | count += 1 153 | if i in self.out_feature_indices: 154 | if i == 22: 155 | outs.append(self.RFAM_PRO(x)) 156 | else: 157 | outs.append(x) 158 | x = self.RFAMs[rfam_count](x) 159 | rfam_count += 1 160 | for i, layer in enumerate(self.extra): 161 | x = F.relu(layer(x), inplace=True) 162 | if i % 2 == 1: 163 | outs.append(x) 164 | if rfam_count < 3: 165 | x = self.RFAMs[rfam_count](x) 166 | rfam_count += 1 167 | outs[0] = self.l2_norm(outs[0]) 168 | if len(outs) == 1: 169 | return outs[0] 170 | else: 171 | return tuple(outs) 172 | 173 | def _make_extra_layers(self, outplanes): 174 | layers = [] 175 | kernel_sizes = (1, 3) 176 | num_layers = 0 177 | outplane = None 178 | for i in range(len(outplanes)): 179 | if self.inplanes == 'S': 180 | self.inplanes = outplane 181 | continue 182 | k = kernel_sizes[num_layers % 2] 183 | if outplanes[i] == 'S': 184 | outplane = outplanes[i + 1] 185 | conv = nn.Conv2d( 186 | self.inplanes, outplane, k, stride=2, padding=1) 187 | else: 188 | outplane = outplanes[i] 189 | conv = nn.Conv2d( 190 | self.inplanes, outplane, k, stride=1, padding=0) 191 | layers.append(conv) 192 | self.inplanes = outplanes[i] 193 | num_layers += 1 194 | if self.input_size == 512: 195 | layers.append(nn.Conv2d(self.inplanes, 256, 4, padding=1)) 196 | 197 | return nn.Sequential(*layers) 198 | 199 | 200 | 201 | class L2Norm(nn.Module): 202 | 203 | def __init__(self, n_dims, scale=20., eps=1e-10): 204 | """L2 normalization layer. 205 | 206 | Args: 207 | n_dims (int): Number of dimensions to be normalized 208 | scale (float, optional): Defaults to 20.. 209 | eps (float, optional): Used to avoid division by zero. 210 | Defaults to 1e-10. 
211 | """ 212 | super(L2Norm, self).__init__() 213 | self.n_dims = n_dims 214 | self.weight = nn.Parameter(torch.Tensor(self.n_dims)) 215 | self.eps = eps 216 | self.scale = scale 217 | 218 | def forward(self, x): 219 | """Forward function.""" 220 | # normalization layer convert to FP32 in FP16 training 221 | x_float = x.float() 222 | norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps 223 | return (self.weight[None, :, None, None].float().expand_as(x_float) * 224 | x_float / norm).type_as(x) 225 | 226 | # class RFAM(nn.Module): 227 | # def __init__(self, indim, scale): 228 | # super(RFAM, self).__init__() 229 | # embdim = indim//4 230 | # self.branch1 = nn.Sequential(nn.Conv2d(indim, embdim, kernel_size=1), nn.Conv2d(embdim, embdim, kernel_size=3, padding=1), nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 231 | # self.branch2 = nn.Sequential(nn.Conv2d(indim, embdim, kernel_size=1), nn.Conv2d(embdim, embdim, kernel_size=3, stride=2, padding=1), nn.Conv2d(embdim, embdim, kernel_size=5, dilation=3, padding=6)) 232 | # self.branch3 = nn.Sequential(nn.Conv2d(indim, embdim, kernel_size=1), 233 | # nn.Conv2d(embdim, embdim, kernel_size=5, stride=2, padding=2), 234 | # nn.Conv2d(embdim, embdim, kernel_size=3, dilation=5, padding=5)) 235 | # self.conv = nn.Conv2d(3*embdim, indim, kernel_size = 1) 236 | # self.relu = nn.ReLU() 237 | # self.scale = scale 238 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 239 | # 240 | # def forward(self, x): 241 | # x1 = self.branch1(x) 242 | # x2 = self.branch2(x) 243 | # x3 = self.branch3(x) 244 | # x = self.maxpool(x) 245 | # res = torch.cat((x1,x2,x3), dim = 1) 246 | # res = self.conv(res) 247 | # x = x + (res * self.scale) 248 | # x = self.relu(x) 249 | # return x 250 | 251 | class RFAM(nn.Module): 252 | def __init__(self, indim, scale): 253 | super(RFAM, self).__init__() 254 | embdim = indim//4 255 | self.branch1 = nn.Sequential(nn.Conv2d(indim, embdim, kernel_size=1), nn.Conv2d(embdim, embdim, kernel_size=3, padding=1)) 256 | self.branch2 = nn.Sequential(nn.Conv2d(indim, embdim, kernel_size=1), nn.Conv2d(embdim, embdim, kernel_size=3, padding=1), nn.Conv2d(embdim, embdim, kernel_size=5, dilation=3, padding=6)) 257 | self.branch3 = nn.Sequential(nn.Conv2d(indim, embdim, kernel_size=1), 258 | nn.Conv2d(embdim, embdim, kernel_size=5, padding=3), 259 | nn.Conv2d(embdim, embdim, kernel_size=3, dilation=5, padding=4)) 260 | self.conv = nn.Conv2d(3*embdim, indim, kernel_size = 1) 261 | self.relu = nn.ReLU() 262 | self.scale = scale 263 | 264 | def forward(self, x): 265 | x1 = self.branch1(x) 266 | x2 = self.branch2(x) 267 | x3 = self.branch3(x) 268 | res = torch.cat((x1,x2,x3), dim = 1) 269 | res = self.conv(res) 270 | x = x + (res * self.scale) 271 | x = self.relu(x) 272 | return x 273 | 274 | class RFAM_PRO(nn.Module): 275 | def __init__(self, indim, scale): 276 | super(RFAM_PRO, self).__init__() 277 | embdim = indim // 4 278 | self.branch1 = nn.Sequential(nn.Conv2d(indim, embdim, kernel_size=1), 279 | nn.Conv2d(embdim, embdim, kernel_size=3, padding=1)) 280 | self.branch2 = nn.Sequential(nn.Conv2d(indim, embdim, kernel_size=1), 281 | nn.Conv2d(embdim, embdim, kernel_size=(3,1), padding=(1,0)), 282 | nn.Conv2d(embdim, embdim, kernel_size=3, dilation=3, padding=3)) 283 | self.branch3 = nn.Sequential(nn.Conv2d(indim, embdim, kernel_size=1), 284 | nn.Conv2d(embdim, embdim, kernel_size=(1,3), padding=(0,1)), 285 | nn.Conv2d(embdim, embdim, kernel_size=3, dilation=3, padding=3)) 286 | self.branch4 = nn.Sequential(nn.Conv2d(indim, 
embdim, kernel_size=1), 287 | nn.Conv2d(embdim, embdim, kernel_size=(1,3), padding=(0,1)), 288 | nn.Conv2d(embdim, embdim, kernel_size=(3, 1), padding=(1, 0)), 289 | nn.Conv2d(embdim, embdim, kernel_size=3, dilation=5, padding=5)) 290 | self.conv = nn.Conv2d(4 * embdim, indim, kernel_size=1) 291 | self.relu = nn.ReLU() 292 | self.scale = scale 293 | 294 | def forward(self, x): 295 | x1 = self.branch1(x) 296 | x2 = self.branch2(x) 297 | x3 = self.branch3(x) 298 | x4 = self.branch4(x) 299 | res = torch.cat((x1, x2, x3, x4), dim=1) 300 | res = self.conv(res) 301 | x = x + (res * self.scale) 302 | x = self.relu(x) 303 | return x 304 | 305 | if __name__ == '__main__': 306 | x = torch.rand(8,512,32,32) 307 | conv = RFAM(512,0.1) 308 | x = conv(x) -------------------------------------------------------------------------------- /mmdet/dense_heads/prs_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from mmcv.cnn import xavier_init 5 | from mmcv.runner import force_fp32 6 | 7 | from mmdet.core import (build_anchor_generator, build_assigner, 8 | build_bbox_coder, build_sampler, multi_apply) 9 | from ..builder import HEADS, build_loss 10 | from ..losses import smooth_l1_loss 11 | from .anchor_head import AnchorHead 12 | from mmcv.ops import DeformConv2d 13 | from mmcv.runner import BaseModule 14 | 15 | def images_to_levels(target, num_levels): 16 | """Convert targets by image to targets by feature level. 17 | 18 | [target_img0, target_img1] -> [target_level0, target_level1, ...] 19 | """ 20 | target = torch.stack(target, 0) 21 | level_targets = [] 22 | start = 0 23 | for n in num_levels: 24 | end = start + n 25 | # level_targets.append(target[:, start:end].squeeze(0)) 26 | level_targets.append(target[:, start:end]) 27 | start = end 28 | return level_targets 29 | 30 | class AdaptiveConv(BaseModule): 31 | def __init__(self, 32 | in_channels, 33 | out_channels, 34 | kernel_size=3, 35 | stride=1, 36 | padding=1, 37 | dilation=3, 38 | groups=1, 39 | deform_groups=1, 40 | bias=False, 41 | type='offset', 42 | init_cfg=dict( 43 | type='Normal', std=0.01, override=dict(name='conv'))): 44 | super(AdaptiveConv, self).__init__(init_cfg) 45 | assert type in ['offset', 'dilation'] 46 | self.adapt_type = type 47 | 48 | assert kernel_size == 3, 'Adaptive conv only supports kernels 3' 49 | if self.adapt_type == 'offset': 50 | assert stride == 1 and padding == 1, \ 51 | 'Adaptive conv offset mode only supports padding: {1}, ' \ 52 | f'stride: {1}, groups: {1}' 53 | self.conv = DeformConv2d( 54 | in_channels, 55 | out_channels, 56 | kernel_size, 57 | padding=padding, 58 | stride=stride, 59 | groups=groups, 60 | deform_groups=deform_groups, 61 | bias=bias) 62 | else: 63 | self.conv = nn.Conv2d( 64 | in_channels, 65 | out_channels, 66 | kernel_size, 67 | padding=dilation, 68 | dilation=dilation) 69 | 70 | def forward(self, x, offset, num_anchors): 71 | """Forward function.""" 72 | if self.adapt_type == 'offset': 73 | N, _, H, W = x.shape 74 | assert offset is not None 75 | # reshape [N, NA, 18] to (N, 18, H, W) 76 | offset = offset.reshape(N, H, W, -1) 77 | offset = offset.permute(0, 3, 1, 2) 78 | offset = offset.contiguous() 79 | x = self.conv(x, offset) 80 | else: 81 | assert offset is None 82 | x = self.conv(x) 83 | return x 84 | 85 | # TODO: add loss evaluator for SSD 86 | @HEADS.register_module() 87 | class PRSHead(AnchorHead): 88 | 89 | def __init__(self, 90 | num_classes=80, 91 | 
in_channels=(512, 1024, 512, 256, 256, 256), 92 | anchor_generator=dict( 93 | type='SSDAnchorGenerator', 94 | scale_major=False, 95 | input_size=300, 96 | strides=[8, 16, 32, 64, 100, 300], 97 | ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]), 98 | basesize_ratio_range=(0.1, 0.9)), 99 | background_label=None, 100 | bbox_coder=dict( 101 | type='DeltaXYWHBBoxCoder', 102 | target_means=[.0, .0, .0, .0], 103 | target_stds=[1.0, 1.0, 1.0, 1.0], 104 | ), 105 | reg_decoded_bbox=False, 106 | train_cfg=None, 107 | test_cfg=None, 108 | init_cfg=dict( 109 | type='Normal', 110 | layer='Conv2d', 111 | std=0.01, 112 | override=dict( 113 | type='Normal', 114 | name='cls_convs_refine', 115 | std=0.01, 116 | bias_prob=0.01)), 117 | **kwargs): 118 | super(AnchorHead, self).__init__(init_cfg, **kwargs) 119 | self.num_classes = num_classes 120 | self.in_channels = in_channels 121 | self.cls_out_channels = num_classes + 1 # add background class 122 | self.anchor_generator = build_anchor_generator(anchor_generator) 123 | num_anchors = self.anchor_generator.num_base_anchors 124 | self.num_anchors = num_anchors 125 | self.anchor_strides = anchor_generator['strides'] 126 | reg_convs = [] 127 | cls_convs = [] 128 | for i in range(len(in_channels)): 129 | reg_convs.append( 130 | nn.Conv2d( 131 | in_channels[i], 132 | num_anchors[i] * 4, 133 | kernel_size=3, 134 | padding=1)) 135 | cls_convs.append( 136 | nn.Conv2d( 137 | in_channels[i], 138 | num_anchors[i], 139 | kernel_size=3, 140 | padding=1)) 141 | self.reg_convs = nn.ModuleList(reg_convs) 142 | self.cls_convs = nn.ModuleList(cls_convs) 143 | 144 | self.background_label = ( 145 | num_classes if background_label is None else background_label) 146 | # background_label should be either 0 or num_classes 147 | assert (self.background_label == 0 148 | or self.background_label == num_classes) 149 | 150 | self.bbox_coder = build_bbox_coder(bbox_coder) 151 | self.reg_decoded_bbox = reg_decoded_bbox 152 | self.use_sigmoid_cls = False 153 | self.cls_focal_loss = False 154 | self.train_cfg = train_cfg 155 | self.test_cfg = test_cfg 156 | # set sampling=False for archor_target 157 | self.sampling = False 158 | if self.train_cfg: 159 | self.assigner = build_assigner(self.train_cfg.assigner) 160 | # SSD sampling=False so use PseudoSampler 161 | sampler_cfg = dict(type='PseudoSampler') 162 | self.sampler = build_sampler(sampler_cfg, context=self) 163 | self.fp16_enabled = False 164 | 165 | # dcn 166 | dcn = [] 167 | reg_convs_refine = [] 168 | cls_convs_refine = [] 169 | for i in range(len(in_channels)): 170 | dcn.append(AdaptiveConv(num_anchors[i]*in_channels[i], in_channels[i],deform_groups=self.num_anchors[i])) 171 | reg_convs_refine.append( 172 | nn.Conv2d( 173 | in_channels[i], 174 | num_anchors[i] * 4, 175 | kernel_size=3, 176 | padding=1)) 177 | cls_convs_refine.append( 178 | nn.Conv2d( 179 | in_channels[i], 180 | num_anchors[i] * self.cls_out_channels, 181 | kernel_size=3, 182 | padding=1)) 183 | self.dcn = nn.ModuleList(dcn) 184 | self.reg_convs_refine = nn.ModuleList(reg_convs_refine) 185 | self.cls_convs_refine = nn.ModuleList(cls_convs_refine) 186 | self.relu = nn.ReLU(inplace=True) 187 | # self.BCE = build_loss(loss_cls_pre) 188 | 189 | def init_weights(self): 190 | """Initialize weights of the head.""" 191 | for m in self.modules(): 192 | if isinstance(m, nn.Conv2d): 193 | xavier_init(m, distribution='uniform', bias=0) 194 | 195 | def forward_train(self, 196 | x, 197 | img_metas, 198 | gt_bboxes, 199 | gt_labels=None, 200 | gt_bboxes_ignore=None, 201 | 
proposal_cfg=None, 202 | **kwargs): 203 | fg_scores, bbox_preds = self(x) 204 | 205 | featmap_sizes = [featmap.size()[-2:] for featmap in fg_scores] 206 | assert len(featmap_sizes) == self.anchor_generator.num_levels 207 | 208 | device = fg_scores[0].device 209 | anchor_list, valid_flag_list = self.get_anchors( 210 | featmap_sizes, img_metas, device=device) 211 | 212 | if gt_labels is None: 213 | loss_inputs = (anchor_list, valid_flag_list, fg_scores, bbox_preds, gt_bboxes, img_metas) 214 | else: 215 | loss_inputs = (anchor_list, valid_flag_list, fg_scores, bbox_preds, gt_bboxes, gt_labels, img_metas) 216 | losses = self.loss_pre(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) 217 | 218 | new_feats = [] 219 | 220 | # pre-processing 221 | for i in range(len(fg_scores)): 222 | score = fg_scores[i] 223 | s,_ = torch.max(score, dim = 1) 224 | s = s.unsqueeze(1) 225 | s = F.sigmoid(s) 226 | new_feats.append(s*x[i]+x[i]) 227 | # new_feats = x 228 | 229 | anchor_list_refine = self.refine_bboxes(anchor_list, bbox_preds, img_metas) 230 | offset_list = self.anchor_offset(anchor_list_refine, self.anchor_strides, featmap_sizes) 231 | 232 | cls_scores, bbox_preds_refine = self.forward_post(new_feats, offset_list) 233 | if gt_labels is None: 234 | loss_inputs = (anchor_list_refine, valid_flag_list, cls_scores, bbox_preds_refine, gt_bboxes, img_metas) 235 | else: 236 | loss_inputs = (anchor_list_refine, valid_flag_list, cls_scores, bbox_preds_refine, gt_bboxes, gt_labels, img_metas) 237 | losses_post = self.loss_post(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) 238 | 239 | losses.update(losses_post) 240 | if proposal_cfg is None: 241 | return losses 242 | else: 243 | proposal_list = self.get_bboxes(anchor_list_refine, cls_scores, bbox_preds_refine, img_metas, cfg=proposal_cfg) 244 | return losses, proposal_list 245 | 246 | def forward(self, feats): 247 | cls_scores = [] 248 | bbox_preds = [] 249 | 250 | for feat, reg_conv, cls_conv in zip(feats, self.reg_convs, 251 | self.cls_convs): 252 | cls_scores.append(cls_conv(feat)) 253 | bbox_preds.append(reg_conv(feat)) 254 | 255 | return cls_scores, bbox_preds 256 | 257 | def forward_post(self, feats, offset_list): 258 | cls_scores = [] 259 | bbox_preds = [] 260 | for i in range(len(feats)): 261 | x = feats[i] 262 | shape = list(x.shape) 263 | x = x.unsqueeze(1).expand((shape[0],self.num_anchors[i],shape[1],shape[2],shape[3])) 264 | shape[1] = shape[1] * self.num_anchors[i] 265 | x = x.reshape(shape) 266 | offset = offset_list[i] 267 | feat = self.relu(self.dcn[i](x, offset, self.num_anchors[i])) 268 | cls_score = self.cls_convs_refine[i](feat) 269 | bbox_pred = self.reg_convs_refine[i](feat) 270 | cls_scores.append(cls_score) 271 | bbox_preds.append(bbox_pred) 272 | return cls_scores, bbox_preds 273 | 274 | def loss_single_post(self, cls_score, bbox_pred, anchor, labels, label_weights, 275 | bbox_targets, bbox_weights, num_total_samples): 276 | loss_cls_all = F.cross_entropy( 277 | cls_score, labels, reduction='none') * label_weights 278 | # FG cat_id: [0, num_classes -1], BG cat_id: num_classes 279 | pos_inds = ((labels >= 0) & 280 | (labels < self.background_label)).nonzero().reshape(-1) 281 | neg_inds = (labels == self.background_label).nonzero().view(-1) 282 | 283 | num_pos_samples = pos_inds.size(0) 284 | num_neg_samples = self.train_cfg.neg_pos_ratio * num_pos_samples 285 | if num_neg_samples > neg_inds.size(0): 286 | num_neg_samples = neg_inds.size(0) 287 | topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples) 288 | loss_cls_pos = 
loss_cls_all[pos_inds].sum() 289 | loss_cls_neg = topk_loss_cls_neg.sum() 290 | loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples 291 | 292 | if self.reg_decoded_bbox: 293 | bbox_pred = self.bbox_coder.decode(anchor, bbox_pred) 294 | 295 | loss_bbox = smooth_l1_loss( 296 | bbox_pred, 297 | bbox_targets, 298 | bbox_weights*0.5, 299 | beta=self.train_cfg.smoothl1_beta, 300 | avg_factor=num_total_samples) 301 | 302 | return loss_cls[None], loss_bbox 303 | 304 | def loss_single_pre(self, fg_scores, bbox_pred, anchor, fg_labels, label_weights, 305 | bbox_targets, bbox_weights, num_total_samples): 306 | fg_loss = (F.binary_cross_entropy( 307 | F.sigmoid(fg_scores), fg_labels, reduction='none') * label_weights).mean() 308 | 309 | 310 | if self.reg_decoded_bbox: 311 | bbox_pred = self.bbox_coder.decode(anchor, bbox_pred) 312 | 313 | loss_bbox = smooth_l1_loss( 314 | bbox_pred, 315 | bbox_targets, 316 | bbox_weights*0.5, 317 | beta=self.train_cfg.smoothl1_beta, 318 | avg_factor=num_total_samples) 319 | 320 | 321 | return fg_loss, loss_bbox 322 | 323 | def loss_pre(self, 324 | anchor_list, 325 | valid_flag_list, 326 | fg_scores, 327 | bbox_preds, 328 | gt_bboxes, 329 | gt_labels, 330 | img_metas, 331 | gt_bboxes_ignore=None): 332 | 333 | featmap_sizes = [featmap.size()[-2:] for featmap in fg_scores] 334 | assert len(featmap_sizes) == self.anchor_generator.num_levels 335 | 336 | # device = fg_scores[0].device 337 | 338 | cls_reg_targets = self.get_targets( 339 | anchor_list, 340 | valid_flag_list, 341 | gt_bboxes, 342 | img_metas, 343 | gt_bboxes_ignore_list=gt_bboxes_ignore, 344 | gt_labels_list=gt_labels, 345 | label_channels=1, 346 | unmap_outputs=False) 347 | if cls_reg_targets is None: 348 | return None 349 | (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, 350 | num_total_pos, num_total_neg) = cls_reg_targets 351 | 352 | num_images = len(img_metas) 353 | all_labels = torch.cat(labels_list, -1).view(num_images, -1) 354 | all_label_weights = torch.cat(label_weights_list, 355 | -1).view(num_images, -1) 356 | all_bbox_preds = torch.cat([ 357 | b.permute(0, 2, 3, 1).reshape(num_images, -1, 4) 358 | for b in bbox_preds 359 | ], -2) 360 | all_bbox_targets = torch.cat(bbox_targets_list, 361 | -2).view(num_images, -1, 4) 362 | all_bbox_weights = torch.cat(bbox_weights_list, 363 | -2).view(num_images, -1, 4) 364 | 365 | # concat all level anchors to a single tensor 366 | all_anchors = [] 367 | for i in range(num_images): 368 | all_anchors.append(torch.cat(anchor_list[i])) 369 | 370 | # check NaN and Inf 371 | assert torch.isfinite(all_bbox_preds).all().item(), \ 372 | 'bbox predications become infinite or NaN!' 
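        # PRS pre-processing stage: collapse the (num_classes + 1)-way targets
        # into binary objectness targets for the BCE term in loss_single_pre.
        # With num_classes=4 (UTDAC2020), e.g. labels [0, 2, 4, 4] become
        # fg_labels [1., 1., 0., 0.], since 4 is the background index.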
373 | 374 | fg_labels = all_labels.clone().float() 375 | fg_labels[fg_labels != self.num_classes] = 1 376 | fg_labels[fg_labels == self.num_classes] = 0 377 | 378 | all_fg_scores = torch.cat([ 379 | f.permute(0, 2, 3, 1).reshape( 380 | num_images, -1) for f in fg_scores 381 | ], 1) 382 | 383 | 384 | fg_losses, losses_bbox = multi_apply( 385 | self.loss_single_pre, 386 | all_fg_scores, 387 | all_bbox_preds, 388 | all_anchors, 389 | fg_labels, 390 | all_label_weights, 391 | all_bbox_targets, 392 | all_bbox_weights, 393 | num_total_samples=num_total_pos) 394 | return dict(losses_fg=fg_losses, loss_bbox=losses_bbox) 395 | 396 | def loss_post(self, 397 | anchor_list, 398 | valid_flag_list, 399 | cls_scores, 400 | bbox_preds, 401 | gt_bboxes, 402 | gt_labels, 403 | img_metas, 404 | gt_bboxes_ignore=None): 405 | 406 | featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] 407 | assert len(featmap_sizes) == self.anchor_generator.num_levels 408 | 409 | # device = cls_scores[0].device 410 | 411 | cls_reg_targets = self.get_targets( 412 | anchor_list, 413 | valid_flag_list, 414 | gt_bboxes, 415 | img_metas, 416 | gt_bboxes_ignore_list=gt_bboxes_ignore, 417 | gt_labels_list=gt_labels, 418 | label_channels=1, 419 | unmap_outputs=False) 420 | if cls_reg_targets is None: 421 | return None 422 | (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, 423 | num_total_pos, num_total_neg) = cls_reg_targets 424 | 425 | num_images = len(img_metas) 426 | all_cls_scores = torch.cat([ 427 | s.permute(0, 2, 3, 1).reshape( 428 | num_images, -1, self.cls_out_channels) for s in cls_scores 429 | ], 1) 430 | all_labels = torch.cat(labels_list, -1).view(num_images, -1) 431 | all_label_weights = torch.cat(label_weights_list, 432 | -1).view(num_images, -1) 433 | all_bbox_preds = torch.cat([ 434 | b.permute(0, 2, 3, 1).reshape(num_images, -1, 4) 435 | for b in bbox_preds 436 | ], -2) 437 | all_bbox_targets = torch.cat(bbox_targets_list, 438 | -2).view(num_images, -1, 4) 439 | all_bbox_weights = torch.cat(bbox_weights_list, 440 | -2).view(num_images, -1, 4) 441 | 442 | # concat all level anchors to a single tensor 443 | all_anchors = [] 444 | for i in range(num_images): 445 | all_anchors.append(torch.cat(anchor_list[i])) 446 | 447 | # check NaN and Inf 448 | assert torch.isfinite(all_cls_scores).all().item(), \ 449 | 'classification scores become infinite or NaN!' 450 | assert torch.isfinite(all_bbox_preds).all().item(), \ 451 | 'bbox predications become infinite or NaN!' 
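        # The refinement stage is supervised with the full (num_classes + 1)-way
        # cross-entropy in loss_single_post, so the binary fg_labels computed
        # below do not feed into the losses returned here; they appear to be
        # kept only for parity with loss_pre.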
452 | 453 | fg_labels = all_labels.clone().float() 454 | fg_labels[fg_labels != self.num_classes] = 1 455 | fg_labels[fg_labels == self.num_classes] = 0 456 | 457 | 458 | losses_cls, losses_bbox_ref = multi_apply( 459 | self.loss_single_post, 460 | all_cls_scores, 461 | all_bbox_preds, 462 | all_anchors, 463 | all_labels, 464 | all_label_weights, 465 | all_bbox_targets, 466 | all_bbox_weights, 467 | num_total_samples=num_total_pos) 468 | return dict(loss_cls=losses_cls, losses_bbox_ref=losses_bbox_ref) 469 | 470 | def anchor_offset(self, anchor_list, anchor_strides, featmap_sizes): 471 | def _shape_offset(anchors, stride, ks=3, dilation=1): 472 | # currently support kernel_size=3 and dilation=1 473 | assert ks == 3 and dilation == 1 474 | pad = (ks - 1) // 2 475 | idx = torch.arange(-pad, pad + 1, dtype=dtype, device=device) 476 | yy, xx = torch.meshgrid(idx, idx) # return order matters 477 | xx = xx.reshape(-1) 478 | yy = yy.reshape(-1) 479 | w = (anchors[:, 2] - anchors[:, 0]) / stride 480 | h = (anchors[:, 3] - anchors[:, 1]) / stride 481 | w = w / (ks - 1) - dilation 482 | h = h / (ks - 1) - dilation 483 | offset_x = w[:, None] * xx # (NA, ks**2) 484 | offset_y = h[:, None] * yy # (NA, ks**2) 485 | return offset_x, offset_y 486 | 487 | def _ctr_offset(anchors, stride, featmap_size, num_anchors): 488 | feat_h, feat_w = featmap_size 489 | 490 | x = (anchors[:, 0] + anchors[:, 2]) * 0.5 491 | y = (anchors[:, 1] + anchors[:, 3]) * 0.5 492 | # compute centers on feature map 493 | x = x / stride 494 | y = y / stride 495 | # compute predefine centers 496 | xx = torch.arange(0, feat_w, device=anchors.device) 497 | yy = torch.arange(0, feat_h, device=anchors.device) 498 | yy, xx = torch.meshgrid(yy, xx) 499 | xx = xx.reshape(-1).type_as(x) 500 | yy = yy.reshape(-1).type_as(y) 501 | 502 | xx = xx.unsqueeze(1).expand(xx.shape+(num_anchors,)).reshape(-1) 503 | yy = yy.unsqueeze(1).expand(yy.shape+(num_anchors,)).reshape(-1) 504 | 505 | offset_x = x - xx # (NA, ) 506 | offset_y = y - yy # (NA, ) 507 | return offset_x, offset_y 508 | 509 | num_imgs = len(anchor_list) 510 | num_lvls = len(anchor_list[0]) 511 | dtype = anchor_list[0][0].dtype 512 | device = anchor_list[0][0].device 513 | num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] 514 | 515 | offset_list = [] 516 | for i in range(num_imgs): 517 | mlvl_offset = [] 518 | for lvl in range(num_lvls): 519 | c_offset_x, c_offset_y = _ctr_offset(anchor_list[i][lvl], 520 | anchor_strides[lvl], 521 | featmap_sizes[lvl], 522 | self.num_anchors[lvl]) 523 | s_offset_x, s_offset_y = _shape_offset(anchor_list[i][lvl], 524 | anchor_strides[lvl]) 525 | 526 | # offset = ctr_offset + shape_offset 527 | offset_x = s_offset_x + c_offset_x[:, None] 528 | offset_y = s_offset_y + c_offset_y[:, None] 529 | 530 | # offset order (y0, x0, y1, x2, .., y8, x8, y9, x9) 531 | offset = torch.stack([offset_y, offset_x], dim=-1) 532 | offset = offset.reshape(offset.size(0), -1) # [NA, 2*ks**2] 533 | mlvl_offset.append(offset) 534 | offset_list.append(torch.cat(mlvl_offset)) # [totalNA, 2*ks**2] 535 | offset_list = images_to_levels(offset_list, num_level_anchors) 536 | return offset_list 537 | 538 | def refine_bboxes(self, anchor_list, bbox_preds, img_metas): 539 | """Refine bboxes through stages.""" 540 | num_levels = len(bbox_preds) 541 | new_anchor_list = [] 542 | for img_id in range(len(img_metas)): 543 | mlvl_anchors = [] 544 | for i in range(num_levels): 545 | bbox_pred = bbox_preds[i][img_id].detach() 546 | bbox_pred = bbox_pred.permute(1, 2, 
0).reshape(-1, 4) 547 | img_shape = img_metas[img_id]['img_shape'] 548 | bboxes = self.bbox_coder.decode(anchor_list[img_id][i], 549 | bbox_pred, img_shape) 550 | mlvl_anchors.append(bboxes) 551 | new_anchor_list.append(mlvl_anchors) 552 | return new_anchor_list 553 | 554 | def simple_test_bboxes(self, feats, img_metas, rescale=False): 555 | fg_scores, bbox_preds = self(feats) 556 | 557 | featmap_sizes = [featmap.size()[-2:] for featmap in fg_scores] 558 | assert len(featmap_sizes) == self.anchor_generator.num_levels 559 | 560 | device = fg_scores[0].device 561 | anchor_list, valid_flag_list = self.get_anchors( 562 | featmap_sizes, img_metas, device=device) 563 | 564 | new_feats = [] 565 | 566 | # pre-processing 567 | # for i in range(len(cls_scores)): 568 | # score = fg_scores[i] 569 | # s,_ = torch.max(score, dim = 1) 570 | # s = s.unsqueeze(1) 571 | # s = F.sigmoid(s) 572 | # new_feats.append(s*feats[i]+feats[i]) 573 | new_feats = feats 574 | 575 | anchor_list_refine = self.refine_bboxes(anchor_list, bbox_preds, img_metas) 576 | offset_list = self.anchor_offset(anchor_list_refine, self.anchor_strides, featmap_sizes) 577 | cls_scores, bbox_preds_refine = self.forward_post(new_feats, offset_list) 578 | results_list = self.get_bboxes(anchor_list_refine[0], cls_scores, bbox_preds_refine, img_metas, rescale=rescale) 579 | return results_list 580 | 581 | @force_fp32(apply_to=('cls_scores', 'bbox_preds')) 582 | def get_bboxes(self, 583 | anchor_list, 584 | cls_scores, 585 | bbox_preds, 586 | img_metas, 587 | cfg=None, 588 | rescale=False, 589 | with_nms=True): 590 | 591 | assert len(cls_scores) == len(bbox_preds) 592 | num_levels = len(cls_scores) 593 | 594 | # device = cls_scores[0].device 595 | # featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] 596 | # mlvl_anchors2 = self.anchor_generator.grid_anchors( 597 | # featmap_sizes, device=device) 598 | # anchor_list = images_to_levels(anchor_list, num_levels) 599 | mlvl_anchors = [anchor_list[i].detach() for i in range(num_levels)] 600 | mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)] 601 | mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)] 602 | 603 | if torch.onnx.is_in_onnx_export(): 604 | assert len( 605 | img_metas 606 | ) == 1, 'Only support one input image while in exporting to ONNX' 607 | img_shapes = img_metas[0]['img_shape_for_onnx'] 608 | else: 609 | img_shapes = [ 610 | img_metas[i]['img_shape'] 611 | for i in range(cls_scores[0].shape[0]) 612 | ] 613 | scale_factors = [ 614 | img_metas[i]['scale_factor'] for i in range(cls_scores[0].shape[0]) 615 | ] 616 | 617 | if with_nms: 618 | # some heads don't support with_nms argument 619 | result_list = self._get_bboxes(mlvl_cls_scores, mlvl_bbox_preds, 620 | mlvl_anchors, img_shapes, 621 | scale_factors, cfg, rescale) 622 | else: 623 | result_list = self._get_bboxes(mlvl_cls_scores, mlvl_bbox_preds, 624 | mlvl_anchors, img_shapes, 625 | scale_factors, cfg, rescale, 626 | with_nms) 627 | return result_list --------------------------------------------------------------------------------