├── CBAMNet.py └── README.md /CBAMNet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | 6 | class CBAM_Module(nn.Module): 7 | 8 | def __init__(self, channels, reduction): 9 | super(CBAM_Module, self).__init__() 10 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 11 | self.max_pool = nn.AdaptiveMaxPool2d(1) 12 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, 13 | padding=0) 14 | self.relu = nn.ReLU(inplace=True) 15 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, 16 | padding=0) 17 | self.sigmoid_channel = nn.Sigmoid() 18 | self.conv_after_concat = nn.Conv2d(2, 1, kernel_size = 3, stride=1, padding = 1) 19 | self.sigmoid_spatial = nn.Sigmoid() 20 | 21 | def forward(self, x): 22 | # Channel attention module 23 | module_input = x 24 | avg = self.avg_pool(x) 25 | mx = self.max_pool(x) 26 | avg = self.fc1(avg) 27 | mx = self.fc1(mx) 28 | avg = self.relu(avg) 29 | mx = self.relu(mx) 30 | avg = self.fc2(avg) 31 | mx = self.fc2(mx) 32 | x = avg + mx 33 | x = self.sigmoid_channel(x) 34 | # Spatial attention module 35 | x = module_input * x 36 | module_input = x 37 | avg = torch.mean(x, 1, True) 38 | mx, _ = torch.max(x, 1, True) 39 | x = torch.cat((avg, mx), 1) 40 | x = self.conv_after_concat(x) 41 | x = self.sigmoid_spatial(x) 42 | x = module_input * x 43 | return x 44 | 45 | 46 | class Bottleneck(nn.Module): 47 | """ 48 | Base class for bottlenecks that implements `forward()` method. 49 | """ 50 | def forward(self, x): 51 | residual = x 52 | 53 | out = self.conv1(x) 54 | out = self.bn1(out) 55 | out = self.relu(out) 56 | 57 | out = self.conv2(out) 58 | out = self.bn2(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv3(out) 62 | out = self.bn3(out) 63 | 64 | if self.downsample is not None: 65 | residual = self.downsample(x) 66 | 67 | out = self.se_module(out) + residual 68 | out = self.relu(out) 69 | 70 | return out 71 | 72 | class CBAMResNetBottleneck(Bottleneck): 73 | """ 74 | ResNet bottleneck with a CBAM_Module. It follows Caffe 75 | implementation and uses `stride=stride` in `conv1` and not in `conv2` 76 | (the latter is used in the torchvision implementation of ResNet). 77 | """ 78 | expansion = 4 79 | 80 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 81 | downsample=None): 82 | super(CBAMResNetBottleneck, self).__init__() 83 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, 84 | stride=stride) 85 | self.bn1 = nn.BatchNorm2d(planes) 86 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, 87 | groups=groups, bias=False) 88 | self.bn2 = nn.BatchNorm2d(planes) 89 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 90 | self.bn3 = nn.BatchNorm2d(planes * 4) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.se_module = CBAM_Module(planes * 4, reduction=reduction) 93 | self.downsample = downsample 94 | self.stride = stride 95 | 96 | 97 | 98 | class CABMNet(nn.Module): 99 | 100 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2, 101 | inplanes=128, input_3x3=True, downsample_kernel_size=3, 102 | downsample_padding=1, num_classes=1000): 103 | super(CABMNet, self).__init__() 104 | self.inplanes = inplanes 105 | if input_3x3: 106 | layer0_modules = [ 107 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1, 108 | bias=False)), 109 | ('bn1', nn.BatchNorm2d(64)), 110 | ('relu1', nn.ReLU(inplace=True)), 111 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, 112 | bias=False)), 113 | ('bn2', nn.BatchNorm2d(64)), 114 | ('relu2', nn.ReLU(inplace=True)), 115 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, 116 | bias=False)), 117 | ('bn3', nn.BatchNorm2d(inplanes)), 118 | ('relu3', nn.ReLU(inplace=True)), 119 | ] 120 | else: 121 | layer0_modules = [ 122 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2, 123 | padding=3, bias=False)), 124 | ('bn1', nn.BatchNorm2d(inplanes)), 125 | ('relu1', nn.ReLU(inplace=True)), 126 | ] 127 | # To preserve compatibility with Caffe weights `ceil_mode=True` 128 | # is used instead of `padding=1`. 129 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2, 130 | ceil_mode=True))) 131 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) 132 | self.layer1 = self._make_layer( 133 | block, 134 | planes=64, 135 | blocks=layers[0], 136 | groups=groups, 137 | reduction=reduction, 138 | downsample_kernel_size=1, 139 | downsample_padding=0 140 | ) 141 | self.layer2 = self._make_layer( 142 | block, 143 | planes=128, 144 | blocks=layers[1], 145 | stride=2, 146 | groups=groups, 147 | reduction=reduction, 148 | downsample_kernel_size=downsample_kernel_size, 149 | downsample_padding=downsample_padding 150 | ) 151 | self.layer3 = self._make_layer( 152 | block, 153 | planes=256, 154 | blocks=layers[2], 155 | stride=2, 156 | groups=groups, 157 | reduction=reduction, 158 | downsample_kernel_size=downsample_kernel_size, 159 | downsample_padding=downsample_padding 160 | ) 161 | self.layer4 = self._make_layer( 162 | block, 163 | planes=512, 164 | blocks=layers[3], 165 | stride=2, 166 | groups=groups, 167 | reduction=reduction, 168 | downsample_kernel_size=downsample_kernel_size, 169 | downsample_padding=downsample_padding 170 | ) 171 | self.avg_pool = nn.AvgPool2d(7, stride=1) 172 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None 173 | self.last_linear = nn.Linear(512 * block.expansion, num_classes) 174 | 175 | for m in self.modules(): 176 | if isinstance(m, nn.Conv2d): 177 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 178 | m.weight.data.normal_(0, math.sqrt(2. / n)) 179 | elif isinstance(m, nn.BatchNorm2d): 180 | m.weight.data.fill_(1) 181 | m.bias.data.zero_() 182 | 183 | # for m in self.modules(): 184 | # if isinstance(m, nn.Conv2d): 185 | # nn.init.kaiming_normal(m.weight.data) 186 | # elif isinstance(m, nn.BatchNorm2d): 187 | # m.weight.data.fill_(1) 188 | # m.bias.data.zero_() 189 | 190 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, 191 | downsample_kernel_size=1, downsample_padding=0): 192 | downsample = None 193 | if stride != 1 or self.inplanes != planes * block.expansion: 194 | downsample = nn.Sequential( 195 | nn.Conv2d(self.inplanes, planes * block.expansion, 196 | kernel_size=downsample_kernel_size, stride=stride, 197 | padding=downsample_padding, bias=False), 198 | nn.BatchNorm2d(planes * block.expansion), 199 | ) 200 | 201 | layers = [] 202 | layers.append(block(self.inplanes, planes, groups, reduction, stride, 203 | downsample)) 204 | self.inplanes = planes * block.expansion 205 | for i in range(1, blocks): 206 | layers.append(block(self.inplanes, planes, groups, reduction)) 207 | 208 | return nn.Sequential(*layers) 209 | 210 | def features(self, x): 211 | x = self.layer0(x) 212 | x = self.layer1(x) 213 | x = self.layer2(x) 214 | x = self.layer3(x) 215 | x = self.layer4(x) 216 | return x 217 | 218 | def logits(self, x): 219 | x = self.avg_pool(x) 220 | if self.dropout is not None: 221 | x = self.dropout(x) 222 | x = x.view(x.size(0), -1) 223 | x = self.last_linear(x) 224 | return x 225 | 226 | def forward(self, x): 227 | x = self.features(x) 228 | x = self.logits(x) 229 | return x 230 | 231 | 232 | 233 | 234 | def cbam_resnet50(num_classes=1000): 235 | model = CABMNet(CBAMResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16, 236 | dropout_p=None, inplanes=64, input_3x3=False, 237 | downsample_kernel_size=1, downsample_padding=0, 238 | num_classes=num_classes) 239 | return model 240 | 241 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Convolutional-Block-Attention-Module 2 | - Unofficial implementation of [**Convolutional Block Attention Module (CBAM)**](https://arxiv.org/abs/1807.06521). 3 | --------------------------------------------------------------------------------