├── .gitignore ├── LICENSE ├── README.md ├── trident.py ├── trident_gn.py └── trident_paper.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 dl19940602 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Trident-block-pytorch 2 | This is a PyTorch implementation of the trident block, built around convolution weights shared across its three dilated branches.
3 | It is an implement of classifiaction,following are the diff backbone based on trident-block 4 | 5 | | models | Params | FLOPs | 6 | | ------ | ------ | ------ | 7 | | Resnet50 | 25.59M | 3.63G | 8 | | Resnet101 | 44.59M | 7.36G | 9 | | Resnet152 | 60.23M | 11.09G | 10 | 11 | trident-paper 12 | 13 | | models | Params | FLOPs | 14 | | ------ | ------ | ------ | 15 | | Resnet50 | 25.59M | 4.6G | 16 | | Resnet101 | 44.69M | 4.65G | 17 | | Resnet152 | 60.41M | 5.57G | 18 | -------------------------------------------------------------------------------- /trident.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | import math 11 | import torch.utils.model_zoo as model_zoo 12 | import pdb 13 | 14 | 15 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101','resnet_trident_101', 16 | 'resnet152'] 17 | 18 | 19 | model_urls = { 20 | 'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', 21 | 'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth', 22 | 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', 23 | 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth', 24 | 'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth', 25 | } 26 | 27 | def conv3x3(in_planes, out_planes, stride=1): 28 | 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=1, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, inplanes, planes, stride=1, downsample=None): 37 | super(BasicBlock, self).__init__() 38 | self.conv1 = conv3x3(inplanes, planes, stride) 39 | self.bn1 = 
nn.BatchNorm2d(planes) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.conv2 = conv3x3(planes, planes) 42 | self.bn2 = nn.BatchNorm2d(planes) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | 56 | if self.downsample is not None: 57 | residual = self.downsample(x) 58 | 59 | out += residual 60 | out = self.relu(out) 61 | 62 | return out 63 | 64 | class Bottleneck(nn.Module): 65 | expansion = 4 66 | #CLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True) 67 | def __init__(self, inplanes, planes, stride=1, downsample=None):#inplanes输入channel,planes输出channel 68 | super(Bottleneck, self).__init__() 69 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) # change 70 | self.bn1 = nn.BatchNorm2d(planes) 71 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, # change 72 | padding=1, bias=False) 73 | self.bn2 = nn.BatchNorm2d(planes) 74 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 75 | self.bn3 = nn.BatchNorm2d(planes * 4) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.downsample = downsample 78 | self.stride = stride 79 | 80 | def forward(self, x): 81 | residual = x 82 | 83 | out = self.conv1(x) 84 | out = self.bn1(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv2(out) 88 | out = self.bn2(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv3(out) 92 | out = self.bn3(out) 93 | 94 | if self.downsample is not None: 95 | residual = self.downsample(x) 96 | 97 | out += residual 98 | out = self.relu(out) 99 | 100 | return out 101 | 102 | class trident_block(nn.Module): 103 | expansion = 4 104 | def __init__(self, inplanes, planes, stride=1, downsample=None, padding=[1, 2, 3], dilate=[1, 2, 3]): 105 | super(trident_block, self).__init__() 106 | 
self.stride = stride 107 | self.padding = padding 108 | self.dilate = dilate 109 | self.downsample = downsample 110 | self.share_weight4conv1 = nn.Parameter(torch.randn(planes, inplanes, 1, 1)) 111 | self.share_weight4conv2 = nn.Parameter(torch.randn(planes, planes, 3, 3)) 112 | self.share_weight4conv3 = nn.Parameter(torch.randn(planes * self.expansion, planes, 1, 1))#1*1/64, 3*3/64, 1*1/256 113 | 114 | self.bn11 = nn.BatchNorm2d(planes)#bn层 115 | self.bn12 = nn.BatchNorm2d(planes) 116 | self.bn13 = nn.BatchNorm2d(planes * self.expansion) 117 | 118 | self.bn21 = nn.BatchNorm2d(planes) 119 | self.bn22 = nn.BatchNorm2d(planes) 120 | self.bn23 = nn.BatchNorm2d(planes * self.expansion) 121 | 122 | self.bn31 = nn.BatchNorm2d(planes) 123 | self.bn32 = nn.BatchNorm2d(planes) 124 | self.bn33 = nn.BatchNorm2d(planes * self.expansion) 125 | 126 | self.relu1 = nn.ReLU(inplace=True)#relu层 127 | self.relu2 = nn.ReLU(inplace=True) 128 | self.relu3 = nn.ReLU(inplace=True) 129 | 130 | def forward_for_small(self, x): 131 | residual = x 132 | out = nn.functional.conv2d(x, self.share_weight4conv1, bias=None) 133 | out = self.bn11(out) 134 | out = self.relu1(out) 135 | 136 | out = nn.functional.conv2d(out, self.share_weight4conv2, bias=None, stride=self.stride, padding=self.padding[0], dilation=self.dilate[0]) 137 | 138 | out = self.bn12(out) 139 | out = self.relu1(out) 140 | 141 | out = nn.functional.conv2d(out, self.share_weight4conv3, bias=None) 142 | out = self.bn13(out) 143 | 144 | if self.downsample is not None: 145 | residual = self.downsample(x) 146 | 147 | out += residual 148 | out = self.relu1(out) 149 | 150 | return out 151 | 152 | def forward_for_middle(self, x): 153 | residual = x 154 | out = nn.functional.conv2d(x, self.share_weight4conv1, bias=None) 155 | out = self.bn21(out) 156 | out = self.relu2(out) 157 | 158 | out = nn.functional.conv2d(out, self.share_weight4conv2, bias=None, stride=self.stride, padding=self.padding[1],dilation=self.dilate[1]) 159 | 160 | out = 
self.bn22(out) 161 | out = self.relu2(out) 162 | 163 | out = nn.functional.conv2d(out, self.share_weight4conv3, bias=None) 164 | out = self.bn23(out) 165 | 166 | if self.downsample is not None: 167 | residual = self.downsample(x) 168 | print(out.shape) 169 | print(residual.shape) 170 | 171 | out += residual 172 | out = self.relu2(out) 173 | 174 | return out 175 | 176 | def forward_for_big(self, x): 177 | residual = x 178 | out = nn.functional.conv2d(x, self.share_weight4conv1, bias=None) 179 | out = self.bn31(out) 180 | out = self.relu3(out) 181 | 182 | out = nn.functional.conv2d(out, self.share_weight4conv2, bias=None, stride=self.stride, padding=self.padding[2], dilation=self.dilate[2]) 183 | 184 | out = self.bn32(out) 185 | out = self.relu3(out) 186 | 187 | out = nn.functional.conv2d(out, self.share_weight4conv3, bias=None)#对输入平面实施2D卷积 188 | out = self.bn33(out) 189 | 190 | if self.downsample is not None: 191 | residual = self.downsample(x) 192 | 193 | out += residual 194 | out = self.relu3(out) 195 | 196 | return out 197 | 198 | def forward(self, x): 199 | xm=x 200 | base_feat=[]#重新定义数组 201 | if self.downsample is not None:#衔接段需要downsample 202 | x1 = self.forward_for_small(x) 203 | base_feat.append(x1) 204 | x2 = self.forward_for_middle(x) 205 | base_feat.append(x2) 206 | x3 = self.forward_for_big(x) 207 | base_feat.append(x3) 208 | else: 209 | x1 = self.forward_for_small(xm[0]) 210 | base_feat.append(x1) 211 | x2 = self.forward_for_middle(xm[1]) 212 | base_feat.append(x2) 213 | x3 = self.forward_for_big(xm[2]) 214 | base_feat.append(x3) 215 | return base_feat #三个分支 216 | 217 | class ResNet(nn.Module): 218 | # def __init__(self, block, layers, num_classes=1000):#layers数组,units个数 219 | def __init__(self, block = Bottleneck, block1 = trident_block, layers = [3, 4, 6, 3], num_classes=1000):#layers数组,units个数 220 | self.inplanes = 64 221 | super(ResNet, self).__init__() 222 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,#通道输入3,输出64 223 | 
bias=False) 224 | self.bn1 = nn.BatchNorm2d(64) 225 | self.relu = nn.ReLU(inplace=True) 226 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # 3*3 maxpooling 227 | self.layer1 = self._make_layer(block, 64, layers[0]) 228 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 229 | # self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 230 | # self.layer3 = trident_block(1024,256)#加入14个trident-block,输出为1024维,3个分支,feature map大小一样 231 | 232 | self.layer3= self._make_layer(block, 256, layers[2], stride=2) 233 | self.layer4 = self._make_layer1(block1, 512, layers[3], stride=2)#需要修改 234 | # it is slightly better whereas slower to set stride = 1 235 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=1) 236 | self.avgpool = nn.AvgPool2d(7)#resnet自身的 237 | self.fc = nn.Linear(512 * block.expansion, num_classes)#全连接分类 238 | 239 | for m in self.modules(): 240 | if isinstance(m, nn.Conv2d): 241 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 242 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 243 | elif isinstance(m, nn.BatchNorm2d): 244 | m.weight.data.fill_(1) 245 | m.bias.data.zero_() 246 | 247 | def _make_layer(self, block, planes, blocks, stride=1): 248 | downsample = None 249 | if stride != 1 or self.inplanes != planes * block.expansion: 250 | downsample = nn.Sequential( 251 | nn.Conv2d(self.inplanes, planes * block.expansion, 252 | kernel_size=1, stride=stride, bias=False), 253 | nn.BatchNorm2d(planes * block.expansion),#shortcut用1*1卷积 254 | ) 255 | 256 | layers = [] 257 | layers.append(block(self.inplanes, planes, stride, downsample))#衔接段会出现通道不匹配,需要借助downsample 258 | self.inplanes = planes * block.expansion#维度保持一致 259 | for i in range(1, blocks): 260 | layers.append(block(self.inplanes, planes))#堆叠的block 261 | 262 | return nn.Sequential(*layers)#一个resnet-unit卷积 263 | 264 | def _make_layer1(self, block1, planes, blocks, stride=1): 265 | downsample = None 266 | if stride != 1 or self.inplanes != planes * block1.expansion: 267 | downsample = nn.Sequential( 268 | nn.Conv2d(self.inplanes, planes * block1.expansion, 269 | kernel_size=1, stride=stride, bias=False), 270 | nn.BatchNorm2d(planes * block1.expansion),#shortcut用1*1卷积 271 | ) 272 | 273 | layers = [] 274 | layers.append(block1(self.inplanes, planes, stride, downsample))#衔接段会出现通道不匹配,需要借助downsample 275 | self.inplanes = planes * block1.expansion#维度保持一致 276 | for i in range(1, blocks): 277 | layers.append(block1(self.inplanes, planes))#堆叠的block 278 | 279 | return nn.Sequential(*layers)#一个trident-block卷积 280 | 281 | def forward(self, x): 282 | x = self.conv1(x) 283 | x = self.bn1(x) 284 | x = self.relu(x) 285 | x = self.maxpool(x) 286 | 287 | x = self.layer1(x) 288 | x = self.layer2(x) 289 | x = self.layer3(x) 290 | x = self.layer4(x)#三个分支输出(进入RPN)-base feat=1*3,feature map大小一样 291 | #在这需要分三个分支,参数共享 292 | result = np.array(x) 293 | print(result[0].shape) 294 | print(result[1].shape) 295 | print(result[2].shape) 296 | for i in range(3): 297 | x[i] = self.avgpool(x[i]) 298 | x[i] = 
x[i].view(x[i].size(0), -1) 299 | x[i] = self.fc(x[i]) 300 | 301 | return x# 1*3输出 302 | 303 | 304 | def resnet18(pretrained=False): 305 | 306 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 307 | if pretrained: 308 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 309 | return model 310 | 311 | 312 | def resnet34(pretrained=False): 313 | 314 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 315 | if pretrained: 316 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 317 | return model 318 | 319 | 320 | def resnet50(pretrained=False): 321 | 322 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 323 | if pretrained: 324 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 325 | return model 326 | 327 | 328 | def resnet101(pretrained=False): 329 | 330 | model = ResNet(Bottleneck, trident_block, [3, 4, 23, 3]) 331 | if pretrained: 332 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 333 | return model 334 | 335 | def trident_50(): 336 | 337 | # model = ResNet(Bottleneck, trident_block, [3, 4, 6, 3])#论文采用15个trident-block 338 | model = ResNet() 339 | return model 340 | 341 | def resnet152(pretrained=False): 342 | 343 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 344 | if pretrained: 345 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 346 | return model 347 | 348 | 349 | 350 | -------------------------------------------------------------------------------- /trident_gn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from ..registry import BACKBONES 6 | from torch.nn import init 7 | from mmcv.cnn import constant_init, kaiming_init 8 | from mmcv.runner import load_checkpoint 9 | import math 10 | import numpy as np 11 | from .shufflenet_block import * 12 | import logging 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | 17 | return 
nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | 20 | 21 | class BasicBlock(nn.Module): 22 | expansion = 1 23 | 24 | def __init__(self, inplanes, planes, stride=1, downsample=None): 25 | super(BasicBlock, self).__init__() 26 | self.conv1 = conv3x3(inplanes, planes, stride) 27 | self.bn1 = nn.GroupNorm(8, planes) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.conv2 = conv3x3(planes, planes) 30 | self.bn2 = nn.GroupNorm(8, planes) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | residual = x 36 | out = self.conv1(x) 37 | out = self.bn1(out) 38 | out = self.relu(out) 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | #CLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True) 54 | def __init__(self, inplanes, planes, stride=1, downsample=None):#inplanes输入channel,planes输出channel 55 | super(Bottleneck, self).__init__() 56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) # change 57 | self.bn1 = nn.GroupNorm(8, planes) 58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, # change 59 | padding=1, bias=False) 60 | self.bn2 = nn.GroupNorm(8, planes) 61 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 62 | self.bn3 = nn.GroupNorm(8, planes * 4) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x): 68 | residual = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv3(out) 79 | out = self.bn3(out) 80 | 81 | if self.downsample is not 
None: 82 | residual = self.downsample(x) 83 | 84 | out += residual 85 | out = self.relu(out) 86 | 87 | return out 88 | 89 | class trident_block(nn.Module): 90 | expansion = 4 91 | def __init__(self, inplanes, planes, stride=1, downsample=None, padding=[1, 2, 3], dilate=[1, 2, 3]): 92 | super(trident_block, self).__init__() 93 | self.stride = stride 94 | self.padding = padding 95 | self.dilate = dilate 96 | self.downsample = downsample 97 | self.share_weight4conv1 = nn.Parameter(torch.randn(planes, inplanes, 1, 1)) 98 | self.share_weight4conv2 = nn.Parameter(torch.randn(planes, planes, 3, 3)) 99 | self.share_weight4conv3 = nn.Parameter(torch.randn(planes * self.expansion, planes, 1, 1))#1*1/64, 3*3/64, 1*1/256 100 | 101 | self.bn11 = nn.GroupNorm(8, planes)#bn层 102 | self.bn12 = nn.GroupNorm(8, planes) 103 | self.bn13 = nn.GroupNorm(8, planes * self.expansion) 104 | 105 | self.bn21 = nn.GroupNorm(8, planes) 106 | self.bn22 = nn.GroupNorm(8, planes) 107 | self.bn23 = nn.GroupNorm(8, planes * self.expansion) 108 | 109 | self.bn31 = nn.GroupNorm(8, planes) 110 | self.bn32 = nn.GroupNorm(8, planes) 111 | self.bn33 = nn.GroupNorm(8, planes * self.expansion) 112 | 113 | self.relu1 = nn.ReLU(inplace=True)#relu层 114 | self.relu2 = nn.ReLU(inplace=True) 115 | self.relu3 = nn.ReLU(inplace=True) 116 | 117 | def forward_for_small(self, x): 118 | residual = x 119 | out = nn.functional.conv2d(x, self.share_weight4conv1, bias=None) 120 | out = self.bn11(out) 121 | out = self.relu1(out) 122 | 123 | out = nn.functional.conv2d(out, self.share_weight4conv2, bias=None, stride=self.stride, padding=self.padding[0], dilation=self.dilate[0]) 124 | 125 | out = self.bn12(out) 126 | out = self.relu1(out) 127 | 128 | out = nn.functional.conv2d(out, self.share_weight4conv3, bias=None) 129 | out = self.bn13(out) 130 | 131 | if self.downsample is not None: 132 | residual = self.downsample(x) 133 | 134 | out += residual 135 | out = self.relu1(out) 136 | 137 | return out 138 | 139 | def 
forward_for_middle(self, x): 140 | residual = x 141 | out = nn.functional.conv2d(x, self.share_weight4conv1, bias=None) 142 | out = self.bn21(out) 143 | out = self.relu2(out) 144 | 145 | out = nn.functional.conv2d(out, self.share_weight4conv2, bias=None, stride=self.stride, padding=self.padding[1],dilation=self.dilate[1]) 146 | 147 | out = self.bn22(out) 148 | out = self.relu2(out) 149 | 150 | out = nn.functional.conv2d(out, self.share_weight4conv3, bias=None) 151 | out = self.bn23(out) 152 | 153 | if self.downsample is not None: 154 | residual = self.downsample(x) 155 | # print(out.shape) 156 | # print(residual.shape) 157 | 158 | out += residual 159 | out = self.relu2(out) 160 | 161 | return out 162 | 163 | def forward_for_big(self, x): 164 | residual = x 165 | out = nn.functional.conv2d(x, self.share_weight4conv1, bias=None) 166 | out = self.bn31(out) 167 | out = self.relu3(out) 168 | 169 | out = nn.functional.conv2d(out, self.share_weight4conv2, bias=None, stride=self.stride, padding=self.padding[2], dilation=self.dilate[2]) 170 | 171 | out = self.bn32(out) 172 | out = self.relu3(out) 173 | 174 | out = nn.functional.conv2d(out, self.share_weight4conv3, bias=None)#对输入平面实施2D卷积 175 | out = self.bn33(out) 176 | 177 | if self.downsample is not None: 178 | residual = self.downsample(x) 179 | 180 | out += residual 181 | out = self.relu3(out) 182 | 183 | return out 184 | 185 | def forward(self, x): 186 | xm=x 187 | base_feat=[]#重新定义数组 188 | if self.downsample is not None:#衔接段需要downsample 189 | x1 = self.forward_for_small(x) 190 | base_feat.append(x1) 191 | x2 = self.forward_for_middle(x) 192 | base_feat.append(x2) 193 | x3 = self.forward_for_big(x) 194 | base_feat.append(x3) 195 | else: 196 | x1 = self.forward_for_small(xm[0]) 197 | base_feat.append(x1) 198 | x2 = self.forward_for_middle(xm[1]) 199 | base_feat.append(x2) 200 | x3 = self.forward_for_big(xm[2]) 201 | base_feat.append(x3) 202 | return base_feat #三个分支 203 | 204 | @BACKBONES.register_module 205 | class 
TridentNet(nn.Module): 206 | # def __init__(self, block, layers, num_classes=1000):#layers数组,units个数 207 | def __init__(self, block=Bottleneck, block1=trident_block, layers=[3,4,6,3], num_classes=1000, norm_eval=True):#layers数组,units个数 208 | self.inplanes = 64 209 | super(TridentNet, self).__init__() 210 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 211 | bias=False) 212 | self.bn1 = nn.GroupNorm(8, 64) 213 | self.relu = nn.ReLU(inplace=True) 214 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # 3*3 maxpooling 215 | self.layer1 = self._make_layer(block, 64, layers[0]) 216 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 217 | self.layer3 = self._make_layer1(block1, 256, layers[2], stride=2) 218 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 219 | self.avgpool = nn.AvgPool2d(7) 220 | self.fc = nn.Linear(512 * block.expansion, num_classes)#全连接分类 221 | self.norm_eval = norm_eval 222 | 223 | 224 | def _make_layer(self, block, planes, blocks, stride=1): 225 | downsample = None 226 | if stride != 1 or self.inplanes != planes * block.expansion: 227 | downsample = nn.Sequential( 228 | nn.Conv2d(self.inplanes, planes * block.expansion, 229 | kernel_size=1, stride=stride, bias=False), 230 | nn.GroupNorm(8, planes * block.expansion),#shortcut用1*1卷积 231 | ) 232 | 233 | layers = [] 234 | layers.append(block(self.inplanes, planes, stride, downsample))#衔接段会出现通道不匹配,需要借助downsample 235 | self.inplanes = planes * block.expansion#维度保持一致 236 | for i in range(1, blocks): 237 | layers.append(block(self.inplanes, planes))#堆叠的block 238 | 239 | return nn.Sequential(*layers)#一个resnet-unit卷积 240 | 241 | def _make_layer1(self, block1, planes, blocks, stride=1): 242 | downsample = None 243 | if stride != 1 or self.inplanes != planes * block1.expansion: 244 | downsample = nn.Sequential( 245 | nn.Conv2d(self.inplanes, planes * block1.expansion, 246 | kernel_size=1, stride=stride, bias=False), 247 | 
nn.GroupNorm(8, planes * block1.expansion),#shortcut用1*1卷积 248 | ) 249 | 250 | layers = [] 251 | layers.append(block1(self.inplanes, planes, stride, downsample))#衔接段会出现通道不匹配,需要借助downsample 252 | self.inplanes = planes * block1.expansion#维度保持一致 253 | for i in range(1, blocks): 254 | layers.append(block1(self.inplanes, planes))#堆叠的block 255 | 256 | return nn.Sequential(*layers)#一个trident-block卷积 257 | 258 | def init_weights(self, pretrained=None): 259 | if isinstance(pretrained, str): 260 | logger = logging.getLogger() 261 | load_checkpoint(self, pretrained, strict=False, logger=logger) 262 | elif pretrained is None: 263 | for m in self.modules(): 264 | if isinstance(m, nn.Conv2d): 265 | kaiming_init(m) 266 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 267 | constant_init(m, 1) 268 | 269 | 270 | def forward(self, x): 271 | x = self.conv1(x) 272 | x = self.bn1(x) 273 | x = self.relu(x) 274 | x = self.maxpool(x) 275 | 276 | x = self.layer1(x) 277 | x = self.layer2(x) 278 | x = self.layer3(x) 279 | x = x[2] 280 | return x 281 | 282 | def train(self, mode=True): 283 | super(TridentNet, self).train(mode) 284 | if mode and self.norm_eval: 285 | for m in self.modules(): 286 | # trick: eval have effect on BatchNorm only 287 | if isinstance(m, (nn.BatchNorm2d)): 288 | m.eval() 289 | 290 | 291 | 292 | -------------------------------------------------------------------------------- /trident_paper.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | import math 11 | import torch.utils.model_zoo as model_zoo 12 | import pdb 13 | 14 | 15 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101','resnet_trident_101', 16 | 'resnet152'] 17 | 18 | 19 | model_urls = 
{ 20 | 'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', 21 | 'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth', 22 | 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', 23 | 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth', 24 | 'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth', 25 | } 26 | 27 | def conv3x3(in_planes, out_planes, stride=1): 28 | 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=1, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, inplanes, planes, stride=1, downsample=None): 37 | super(BasicBlock, self).__init__() 38 | self.conv1 = conv3x3(inplanes, planes, stride) 39 | self.bn1 = nn.BatchNorm2d(planes) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.conv2 = conv3x3(planes, planes) 42 | self.bn2 = nn.BatchNorm2d(planes) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | 56 | if self.downsample is not None: 57 | residual = self.downsample(x) 58 | 59 | out += residual 60 | out = self.relu(out) 61 | 62 | return out 63 | 64 | class Bottleneck(nn.Module): 65 | expansion = 4 66 | #CLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True) 67 | def __init__(self, inplanes, planes, stride=1, downsample=None):#inplanes输入channel,planes输出channel 68 | super(Bottleneck, self).__init__() 69 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) # change 70 | self.bn1 = nn.BatchNorm2d(planes) 71 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, # change 72 | padding=1, bias=False) 73 | self.bn2 = nn.BatchNorm2d(planes) 74 | self.conv3 = 
nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 75 | self.bn3 = nn.BatchNorm2d(planes * 4) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.downsample = downsample 78 | self.stride = stride 79 | 80 | def forward(self, x): 81 | residual = x 82 | 83 | out = self.conv1(x) 84 | out = self.bn1(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv2(out) 88 | out = self.bn2(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv3(out) 92 | out = self.bn3(out) 93 | 94 | if self.downsample is not None: 95 | residual = self.downsample(x) 96 | 97 | out += residual 98 | out = self.relu(out) 99 | 100 | return out 101 | 102 | class trident_block(nn.Module): 103 | expansion = 4 104 | def __init__(self, inplanes, planes, stride=1, downsample=None, padding=[1, 2, 3], dilate=[1, 2, 3]): 105 | super(trident_block, self).__init__() 106 | self.stride = stride 107 | self.padding = padding 108 | self.dilate = dilate 109 | self.downsample = downsample 110 | self.share_weight4conv1 = nn.Parameter(torch.randn(planes, inplanes, 1, 1)) 111 | self.share_weight4conv2 = nn.Parameter(torch.randn(planes, planes, 3, 3)) 112 | self.share_weight4conv3 = nn.Parameter(torch.randn(planes * self.expansion, planes, 1, 1))#1*1/64, 3*3/64, 1*1/256 113 | 114 | self.bn11 = nn.BatchNorm2d(planes)#bn层 115 | self.bn12 = nn.BatchNorm2d(planes) 116 | self.bn13 = nn.BatchNorm2d(planes * self.expansion) 117 | 118 | self.bn21 = nn.BatchNorm2d(planes) 119 | self.bn22 = nn.BatchNorm2d(planes) 120 | self.bn23 = nn.BatchNorm2d(planes * self.expansion) 121 | 122 | self.bn31 = nn.BatchNorm2d(planes) 123 | self.bn32 = nn.BatchNorm2d(planes) 124 | self.bn33 = nn.BatchNorm2d(planes * self.expansion) 125 | 126 | self.relu1 = nn.ReLU(inplace=True)#relu层 127 | self.relu2 = nn.ReLU(inplace=True) 128 | self.relu3 = nn.ReLU(inplace=True) 129 | 130 | def forward_for_small(self, x): 131 | residual = x 132 | out = nn.functional.conv2d(x, self.share_weight4conv1, bias=None) 133 | out = self.bn11(out) 134 | out = 
self.relu1(out) 135 | 136 | out = nn.functional.conv2d(out, self.share_weight4conv2, bias=None, stride=self.stride, padding=self.padding[0], dilation=self.dilate[0]) 137 | 138 | out = self.bn12(out) 139 | out = self.relu1(out) 140 | 141 | out = nn.functional.conv2d(out, self.share_weight4conv3, bias=None) 142 | out = self.bn13(out) 143 | 144 | if self.downsample is not None: 145 | residual = self.downsample(x) 146 | 147 | out += residual 148 | out = self.relu1(out) 149 | 150 | return out 151 | 152 | def forward_for_middle(self, x): 153 | residual = x 154 | out = nn.functional.conv2d(x, self.share_weight4conv1, bias=None) 155 | out = self.bn21(out) 156 | out = self.relu2(out) 157 | 158 | out = nn.functional.conv2d(out, self.share_weight4conv2, bias=None, stride=self.stride, padding=self.padding[1],dilation=self.dilate[1]) 159 | 160 | out = self.bn22(out) 161 | out = self.relu2(out) 162 | 163 | out = nn.functional.conv2d(out, self.share_weight4conv3, bias=None) 164 | out = self.bn23(out) 165 | 166 | if self.downsample is not None: 167 | residual = self.downsample(x) 168 | print(out.shape) 169 | print(residual.shape) 170 | 171 | out += residual 172 | out = self.relu2(out) 173 | 174 | return out 175 | 176 | def forward_for_big(self, x): 177 | residual = x 178 | out = nn.functional.conv2d(x, self.share_weight4conv1, bias=None) 179 | out = self.bn31(out) 180 | out = self.relu3(out) 181 | 182 | out = nn.functional.conv2d(out, self.share_weight4conv2, bias=None, stride=self.stride, padding=self.padding[2], dilation=self.dilate[2]) 183 | 184 | out = self.bn32(out) 185 | out = self.relu3(out) 186 | 187 | out = nn.functional.conv2d(out, self.share_weight4conv3, bias=None)#对输入平面实施2D卷积 188 | out = self.bn33(out) 189 | 190 | if self.downsample is not None: 191 | residual = self.downsample(x) 192 | 193 | out += residual 194 | out = self.relu3(out) 195 | 196 | return out 197 | 198 | def forward(self, x): 199 | xm=x 200 | base_feat=[]#重新定义数组 201 | if self.downsample is not 
None:#衔接段需要downsample 202 | x1 = self.forward_for_small(x) 203 | base_feat.append(x1) 204 | x2 = self.forward_for_middle(x) 205 | base_feat.append(x2) 206 | x3 = self.forward_for_big(x) 207 | base_feat.append(x3) 208 | else: 209 | x1 = self.forward_for_small(xm[0]) 210 | base_feat.append(x1) 211 | x2 = self.forward_for_middle(xm[1]) 212 | base_feat.append(x2) 213 | x3 = self.forward_for_big(xm[2]) 214 | base_feat.append(x3) 215 | return base_feat #三个分支 216 | 217 | class ResNet(nn.Module): 218 | # def __init__(self, block, layers, num_classes=1000):#layers数组,units个数 219 | def __init__(self, block = Bottleneck, block1 = trident_block, layers = [3, 4, 6, 3], num_classes=1000):#layers数组,units个数 220 | self.inplanes = 64 221 | super(ResNet, self).__init__() 222 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,#通道输入3,输出64 223 | bias=False) 224 | self.bn1 = nn.BatchNorm2d(64) 225 | self.relu = nn.ReLU(inplace=True) 226 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # 3*3 maxpooling 227 | self.layer1 = self._make_layer(block, 64, layers[0]) 228 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 229 | # self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 230 | # self.layer3 = trident_block(1024,256)#加入14个trident-block,输出为1024维,3个分支,feature map大小一样 231 | 232 | self.layer3= self._make_layer1(block1, 256, layers[2], stride=2) 233 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)#需要修改 234 | # it is slightly better whereas slower to set stride = 1 235 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=1) 236 | self.avgpool = nn.AvgPool2d(7)#resnet自身的 237 | self.fc = nn.Linear(512 * block.expansion, num_classes)#全连接分类 238 | 239 | for m in self.modules(): 240 | if isinstance(m, nn.Conv2d): 241 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 242 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 243 | elif isinstance(m, nn.BatchNorm2d): 244 | m.weight.data.fill_(1) 245 | m.bias.data.zero_() 246 | 247 | def _make_layer(self, block, planes, blocks, stride=1): 248 | downsample = None 249 | if stride != 1 or self.inplanes != planes * block.expansion: 250 | downsample = nn.Sequential( 251 | nn.Conv2d(self.inplanes, planes * block.expansion, 252 | kernel_size=1, stride=stride, bias=False), 253 | nn.BatchNorm2d(planes * block.expansion),#shortcut用1*1卷积 254 | ) 255 | 256 | layers = [] 257 | layers.append(block(self.inplanes, planes, stride, downsample))#衔接段会出现通道不匹配,需要借助downsample 258 | self.inplanes = planes * block.expansion#维度保持一致 259 | for i in range(1, blocks): 260 | layers.append(block(self.inplanes, planes))#堆叠的block 261 | 262 | return nn.Sequential(*layers)#一个resnet-unit卷积 263 | 264 | def _make_layer1(self, block1, planes, blocks, stride=1): 265 | downsample = None 266 | if stride != 1 or self.inplanes != planes * block1.expansion: 267 | downsample = nn.Sequential( 268 | nn.Conv2d(self.inplanes, planes * block1.expansion, 269 | kernel_size=1, stride=stride, bias=False), 270 | nn.BatchNorm2d(planes * block1.expansion),#shortcut用1*1卷积 271 | ) 272 | 273 | layers = [] 274 | layers.append(block1(self.inplanes, planes, stride, downsample))#衔接段会出现通道不匹配,需要借助downsample 275 | self.inplanes = planes * block1.expansion#维度保持一致 276 | for i in range(1, blocks): 277 | layers.append(block1(self.inplanes, planes))#堆叠的block 278 | 279 | return nn.Sequential(*layers)#一个trident-block卷积 280 | 281 | def forward(self, x): 282 | x = self.conv1(x) 283 | x = self.bn1(x) 284 | x = self.relu(x) 285 | x = self.maxpool(x) 286 | 287 | x = self.layer1(x) 288 | x = self.layer2(x) 289 | x = self.layer3(x)#三个分支输出(进入RPN)-base feat=1*3,feature map大小一样 290 | #在这需要分三个分支,参数共享 291 | result = np.array(x) 292 | print(result[0].shape) 293 | print(result[1].shape) 294 | print(result[2].shape) 295 | x = torch.cat((x[0],x[1],x[2]),0) 296 | # x = x[0]+x[1]+x[2] 297 | print(x.shape) 298 | x = 
self.layer4(x) 299 | x = self.avgpool(x) 300 | x = x.view(x.size(0), -1) 301 | print(x.shape) 302 | x = self.fc(x) 303 | return x# 1*3输出 304 | 305 | 306 | def resnet18(pretrained=False): 307 | 308 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 309 | if pretrained: 310 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 311 | return model 312 | 313 | 314 | def resnet34(pretrained=False): 315 | 316 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 317 | if pretrained: 318 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 319 | return model 320 | 321 | 322 | def resnet50(pretrained=False): 323 | 324 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 325 | if pretrained: 326 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 327 | return model 328 | 329 | 330 | def resnet101(pretrained=False): 331 | 332 | model = ResNet(Bottleneck, trident_block, [3, 4, 23, 3]) 333 | if pretrained: 334 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 335 | return model 336 | 337 | def trident_50(): 338 | 339 | # model = ResNet(Bottleneck, trident_block, [3, 4, 6, 3])#论文采用15个trident-block 340 | model = ResNet() 341 | return model 342 | 343 | def resnet152(pretrained=False): 344 | 345 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 346 | if pretrained: 347 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 348 | return model 349 | 350 | --------------------------------------------------------------------------------