├── LICENSE
├── README.md
├── attentionnet3D.py
└── attentionnet.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 CHANG Lufan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Residual-Attention-Network-for-Image-Classification-2D-3D-pytorch-implement
2D and 3D (volumetric) versions of the Residual Attention Network.

Residual Attention Network for Image Classification (CVPR 2017 Spotlight) by Fei Wang, Mengqing Jiang, Chen Qian, Shuo Yang, Chen Li, Honggang Zhang, Xiaogang Wang, Xiaoou Tang (https://openaccess.thecvf.com/content_cvpr_2017/papers/Wang_Residual_Attention_Network_CVPR_2017_paper.pdf)

## Usage Examples
### 2D:
- `from attentionnet import attention56, attention92`
- `attention56(num_classes=1000)`
- `attention92(num_classes=1000)`
### 3D:
- `from attentionnet3D import attention3d56, attention3d92`
- `attention3d56(num_classes=1000)`
- `attention3d92(num_classes=1000)`

For usage within a larger library, please refer to my fork of pretorched (https://github.com/moyiliyi/pretorched-x).

Only the network architectures are implemented here; you need to write your own train/test scripts.
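
As a quick sanity check, something along these lines should work once both files are importable (the input sizes and class counts below are arbitrary examples, not requirements):

```python
import torch

from attentionnet import attention56
from attentionnet3D import attention3d56

# 2D: dummy batch of CIFAR-sized images, shape (N, C, H, W)
model_2d = attention56(num_classes=100)
out_2d = model_2d(torch.randn(2, 3, 32, 32))
print(out_2d.shape)  # expected: torch.Size([2, 100])

# 3D: dummy batch of volumes, shape (N, C, D, H, W); the 3D stem also expects 3 input channels
model_3d = attention3d56(num_classes=10)
out_3d = model_3d(torch.randn(1, 3, 32, 32, 32))
print(out_3d.shape)  # expected: torch.Size([1, 10])
```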

## Reference
This code is based on the following repos:
- https://github.com/weiaicunzai/pytorch-cifar100
- https://github.com/pytorch/vision/tree/master/torchvision
- https://github.com/Tencent/MedicalNet
--------------------------------------------------------------------------------
/attentionnet3D.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

__all__ = [
    'Attention3D', 'attention3d56', 'attention3d92'
]


''' Residual Bottleneck from Tencent/MedicalNet '''
class ResidualBlock(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        # Added for consistency
        planes = int(planes / 4)
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        # self.downsample = downsample
        # Added: auto downsample
        self.downsample = nn.Sequential(
            nn.Conv3d(inplanes, planes * 4, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm3d(planes * 4)
        )
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class AttentionModule1(nn.Module):

    def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
        super().__init__()
        # The hyperparameter p denotes the number of preprocessing Residual
        # Units before splitting into trunk branch and mask branch. t denotes
        # the number of Residual Units in the trunk branch. r denotes the number
        # of Residual Units between adjacent pooling layers in the mask branch.
        assert in_channels == out_channels

        self.pre = self._make_residual(in_channels, out_channels, p)
        self.trunk = self._make_residual(in_channels, out_channels, t)
        self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown3 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown4 = self._make_residual(in_channels, out_channels, r)

        self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup3 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup4 = self._make_residual(in_channels, out_channels, r)

        self.shortcut_short = ResidualBlock(in_channels, out_channels, 1)
        self.shortcut_long = ResidualBlock(in_channels, out_channels, 1)

        self.sigmoid = nn.Sequential(
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.last = self._make_residual(in_channels, out_channels, p)

    def forward(self, x):
        # We make the size of the smallest output map in each mask branch 7*7 to be
        # consistent with the smallest trunk output map size. Thus 3, 2, 1 max-pooling
        # layers are used in the mask branches with input sizes 56*56, 28*28 and 14*14 respectively.
        x = self.pre(x)
        input_size = (x.size(2), x.size(3), x.size(4))

        x_t = self.trunk(x)

        # first downsample, out 28
        x_s = F.max_pool3d(x, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown1(x_s)

        # 28 shortcut
        shape1 = (x_s.size(2), x_s.size(3), x_s.size(4))
        shortcut_long = self.shortcut_long(x_s)

        # second downsample, out 14
        x_s = F.max_pool3d(x_s, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown2(x_s)

        # 14 shortcut
        shape2 = (x_s.size(2), x_s.size(3), x_s.size(4))
        shortcut_short = self.shortcut_short(x_s)

        # third downsample, out 7
        x_s = F.max_pool3d(x_s, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown3(x_s)

        # mid (not in the original paper)
        x_s = self.soft_resdown4(x_s)
        x_s = self.soft_resup1(x_s)

        # first upsample, out 14
        x_s = self.soft_resup2(x_s)
        x_s = F.interpolate(x_s, size=shape2)
        x_s += shortcut_short

        # second upsample, out 28
        x_s = self.soft_resup3(x_s)
        x_s = F.interpolate(x_s, size=shape1)
        x_s += shortcut_long

        # third upsample, out 56
        x_s = self.soft_resup4(x_s)
        x_s = F.interpolate(x_s, size=input_size)

        x_s = self.sigmoid(x_s)
        x = (1 + x_s) * x_t
        x = self.last(x)

        return x

    def _make_residual(self, in_channels, out_channels, p):

        layers = []
        for _ in range(p):
            layers.append(ResidualBlock(in_channels, out_channels, 1))

        return nn.Sequential(*layers)


class AttentionModule2(nn.Module):

    def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
        super().__init__()
        # The hyperparameter p denotes the number of preprocessing Residual
        # Units before splitting into trunk branch and mask branch. t denotes
        # the number of Residual Units in the trunk branch. r denotes the number
        # of Residual Units between adjacent pooling layers in the mask branch.
        assert in_channels == out_channels

        self.pre = self._make_residual(in_channels, out_channels, p)
        self.trunk = self._make_residual(in_channels, out_channels, t)
        self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown3 = self._make_residual(in_channels, out_channels, r)

        self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup3 = self._make_residual(in_channels, out_channels, r)

        self.shortcut = ResidualBlock(in_channels, out_channels, 1)

        self.sigmoid = nn.Sequential(
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.last = self._make_residual(in_channels, out_channels, p)

    def forward(self, x):
        x = self.pre(x)
        input_size = (x.size(2), x.size(3), x.size(4))

        x_t = self.trunk(x)

        # first downsample, out 14
        x_s = F.max_pool3d(x, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown1(x_s)

        # 14 shortcut
        shape1 = (x_s.size(2), x_s.size(3), x_s.size(4))
        shortcut = self.shortcut(x_s)

        # second downsample, out 7
        x_s = F.max_pool3d(x_s, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown2(x_s)

        # mid
        x_s = self.soft_resdown3(x_s)
        x_s = self.soft_resup1(x_s)

        # first upsample, out 14
        x_s = self.soft_resup2(x_s)
        x_s = F.interpolate(x_s, size=shape1)
        x_s += shortcut

        # second upsample, out 28
        x_s = self.soft_resup3(x_s)
        x_s = F.interpolate(x_s, size=input_size)

        x_s = self.sigmoid(x_s)
        x = (1 + x_s) * x_t
        x = self.last(x)

        return x

    def _make_residual(self, in_channels, out_channels, p):

        layers = []
        for _ in range(p):
            layers.append(ResidualBlock(in_channels, out_channels, 1))

        return nn.Sequential(*layers)

class AttentionModule3(nn.Module):

    def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
        super().__init__()

        assert in_channels == out_channels

        self.pre = self._make_residual(in_channels, out_channels, p)
        self.trunk = self._make_residual(in_channels, out_channels, t)
        self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)

        self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup2 = self._make_residual(in_channels, out_channels, r)

        self.shortcut = ResidualBlock(in_channels, out_channels, 1)

        self.sigmoid = nn.Sequential(
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.last = self._make_residual(in_channels, out_channels, p)

    def forward(self, x):
        x = self.pre(x)
        input_size = (x.size(2), x.size(3), x.size(4))

        x_t = self.trunk(x)

        # first downsample, out 7
        x_s = F.max_pool3d(x, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown1(x_s)

        # mid
        x_s = self.soft_resdown2(x_s)
        x_s = self.soft_resup1(x_s)

        # first upsample, out 14
        x_s = self.soft_resup2(x_s)
        x_s = F.interpolate(x_s, size=input_size)

        x_s = self.sigmoid(x_s)
        x = (1 + x_s) * x_t
        x = self.last(x)

        return x

    def _make_residual(self, in_channels, out_channels, p):

        layers = []
        for _ in range(p):
            layers.append(ResidualBlock(in_channels, out_channels, 1))

        return nn.Sequential(*layers)


class Attention3D(nn.Module):
    """Residual Attention Network (3D)
    Args:
        block_num: number of attention modules in each stage
    """

    def __init__(self, block_num, num_classes=100, pretrained=None):

        super().__init__()
        self.pre_conv = nn.Sequential(
            nn.Conv3d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True)
        )

        self.stage1 = self._make_stage(64, 256, block_num[0], AttentionModule1)
        self.stage2 = self._make_stage(256, 512, block_num[1], AttentionModule2)
        self.stage3 = self._make_stage(512, 1024, block_num[2], AttentionModule3)
        self.stage4 = nn.Sequential(
            ResidualBlock(1024, 2048, stride=2),
            ResidualBlock(2048, 2048, stride=1),
            ResidualBlock(2048, 2048, stride=1)
        )
        self.avg = nn.AdaptiveAvgPool3d(1)
        self.classifier = nn.Linear(2048, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.pre_conv(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.avg(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)

        return x

    def _make_stage(self, in_channels, out_channels, num, block):

        layers = []
        layers.append(ResidualBlock(in_channels, out_channels, stride=2))

        for _ in range(num):
            layers.append(block(out_channels, out_channels))

        return nn.Sequential(*layers)


def attention3d56(**kwargs):
    return Attention3D([1, 1, 1], **kwargs)


def attention3d92(**kwargs):
    return Attention3D([1, 2, 3], **kwargs)
--------------------------------------------------------------------------------
/attentionnet.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from typing import Type, Any, Callable, Union, List, Optional

__all__ = [
    'Attention', 'attention56', 'attention92'
]

''' This code is based on https://github.com/weiaicunzai/pytorch-cifar100.git, with modifications to the base network '''

''' Residual Bottleneck from torchvision '''
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class ResidualBlock(nn.Module):
    expansion = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(ResidualBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # width = int(planes * (base_width / 64.)) * groups
        planes = int(planes / 4)
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = norm_layer(planes)
        self.conv2 = conv3x3(planes, planes, stride, groups, dilation)
        self.bn2 = norm_layer(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        # self.downsample = downsample
        self.stride = stride
        # Added: auto downsample
        self.downsample = nn.Sequential(
            conv1x1(inplanes, planes * self.expansion, stride),
            norm_layer(planes * self.expansion),
        )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class AttentionModule1(nn.Module):

    def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
        super().__init__()
        # The hyperparameter p denotes the number of preprocessing Residual
        # Units before splitting into trunk branch and mask branch. t denotes
        # the number of Residual Units in the trunk branch. r denotes the number
        # of Residual Units between adjacent pooling layers in the mask branch.
        assert in_channels == out_channels

        self.pre = self._make_residual(in_channels, out_channels, p)
        self.trunk = self._make_residual(in_channels, out_channels, t)
        self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown3 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown4 = self._make_residual(in_channels, out_channels, r)

        self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup3 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup4 = self._make_residual(in_channels, out_channels, r)

        self.shortcut_short = ResidualBlock(in_channels, out_channels, 1)
        self.shortcut_long = ResidualBlock(in_channels, out_channels, 1)

        self.sigmoid = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.last = self._make_residual(in_channels, out_channels, p)

    def forward(self, x):
        # We make the size of the smallest output map in each mask branch 7*7 to be
        # consistent with the smallest trunk output map size. Thus 3, 2, 1 max-pooling
        # layers are used in the mask branches with input sizes 56*56, 28*28 and 14*14 respectively.
        x = self.pre(x)
        input_size = (x.size(2), x.size(3))

        x_t = self.trunk(x)

        # first downsample, out 28
        x_s = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown1(x_s)

        # 28 shortcut
        shape1 = (x_s.size(2), x_s.size(3))
        shortcut_long = self.shortcut_long(x_s)

        # second downsample, out 14
        x_s = F.max_pool2d(x_s, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown2(x_s)

        # 14 shortcut
        shape2 = (x_s.size(2), x_s.size(3))
        shortcut_short = self.shortcut_short(x_s)

        # third downsample, out 7
        x_s = F.max_pool2d(x_s, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown3(x_s)

        # mid
        x_s = self.soft_resdown4(x_s)
        x_s = self.soft_resup1(x_s)

        # first upsample, out 14
        x_s = self.soft_resup2(x_s)
        x_s = F.interpolate(x_s, size=shape2)
        x_s += shortcut_short

        # second upsample, out 28
        x_s = self.soft_resup3(x_s)
        x_s = F.interpolate(x_s, size=shape1)
        x_s += shortcut_long

        # third upsample, out 56
        x_s = self.soft_resup4(x_s)
        x_s = F.interpolate(x_s, size=input_size)

        x_s = self.sigmoid(x_s)
        x = (1 + x_s) * x_t
        x = self.last(x)

        return x

    def _make_residual(self, in_channels, out_channels, p):

        layers = []
        for _ in range(p):
            layers.append(ResidualBlock(in_channels, out_channels, 1))

        return nn.Sequential(*layers)


class AttentionModule2(nn.Module):

    def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
        super().__init__()
        # The hyperparameter p denotes the number of preprocessing Residual
        # Units before splitting into trunk branch and mask branch. t denotes
        # the number of Residual Units in the trunk branch. r denotes the number
        # of Residual Units between adjacent pooling layers in the mask branch.
        assert in_channels == out_channels

        self.pre = self._make_residual(in_channels, out_channels, p)
        self.trunk = self._make_residual(in_channels, out_channels, t)
        self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown3 = self._make_residual(in_channels, out_channels, r)

        self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup3 = self._make_residual(in_channels, out_channels, r)

        self.shortcut = ResidualBlock(in_channels, out_channels, 1)

        self.sigmoid = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.last = self._make_residual(in_channels, out_channels, p)

    def forward(self, x):
        x = self.pre(x)
        input_size = (x.size(2), x.size(3))

        x_t = self.trunk(x)

        # first downsample, out 14
        x_s = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown1(x_s)

        # 14 shortcut
        shape1 = (x_s.size(2), x_s.size(3))
        shortcut = self.shortcut(x_s)

        # second downsample, out 7
        x_s = F.max_pool2d(x_s, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown2(x_s)

        # mid
        x_s = self.soft_resdown3(x_s)
        x_s = self.soft_resup1(x_s)

        # first upsample, out 14
        x_s = self.soft_resup2(x_s)
        x_s = F.interpolate(x_s, size=shape1)
        x_s += shortcut

        # second upsample, out 28
        x_s = self.soft_resup3(x_s)
        x_s = F.interpolate(x_s, size=input_size)

        x_s = self.sigmoid(x_s)
        x = (1 + x_s) * x_t
        x = self.last(x)

        return x

    def _make_residual(self, in_channels, out_channels, p):

        layers = []
        for _ in range(p):
            layers.append(ResidualBlock(in_channels, out_channels, 1))

        return nn.Sequential(*layers)


class AttentionModule3(nn.Module):

    def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
        super().__init__()

        assert in_channels == out_channels

        self.pre = self._make_residual(in_channels, out_channels, p)
        self.trunk = self._make_residual(in_channels, out_channels, t)
        self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)

        self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup2 = self._make_residual(in_channels, out_channels, r)

        self.shortcut = ResidualBlock(in_channels, out_channels, 1)

        self.sigmoid = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.last = self._make_residual(in_channels, out_channels, p)

    def forward(self, x):
        x = self.pre(x)
        input_size = (x.size(2), x.size(3))

        x_t = self.trunk(x)

        # first downsample, out 7
        x_s = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown1(x_s)

        # mid
        x_s = self.soft_resdown2(x_s)
        x_s = self.soft_resup1(x_s)

        # first upsample, out 14
        x_s = self.soft_resup2(x_s)
        x_s = F.interpolate(x_s, size=input_size)

        x_s = self.sigmoid(x_s)
        x = (1 + x_s) * x_t
        x = self.last(x)

        return x

    def _make_residual(self, in_channels, out_channels, p):

        layers = []
        for _ in range(p):
            layers.append(ResidualBlock(in_channels, out_channels, 1))

        return nn.Sequential(*layers)


class Attention(nn.Module):
    """Residual Attention Network (2D)
    Args:
        block_num: number of attention modules in each stage
    """

    def __init__(self, block_num, num_classes=100, pretrained=None):

        super().__init__()
        self.pre_conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.stage1 = self._make_stage(64, 256, block_num[0], AttentionModule1)
        self.stage2 = self._make_stage(256, 512, block_num[1], AttentionModule2)
        self.stage3 = self._make_stage(512, 1024, block_num[2], AttentionModule3)
        self.stage4 = nn.Sequential(
            ResidualBlock(1024, 2048, 2),
            ResidualBlock(2048, 2048, 1),
            ResidualBlock(2048, 2048, 1)
        )
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(2048, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.pre_conv(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.avg(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)

        return x

    def _make_stage(self, in_channels, out_channels, num, block):

        layers = []
        layers.append(ResidualBlock(in_channels, out_channels, 2))

        for _ in range(num):
            layers.append(block(out_channels, out_channels))

        return nn.Sequential(*layers)


def attention56(**kwargs):
    return Attention([1, 1, 1], **kwargs)


def attention92(**kwargs):
    return Attention([1, 2, 3], **kwargs)
--------------------------------------------------------------------------------