"""DBT_Net: PyTorch version of the model from the NeurIPS'19 paper
"Learning Deep Bilinear Transformation for Fine-grained Image Representation".

The code is ported from https://github.com/researchmm/DBTNet.
"""

import math
import os

import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupConv(nn.Module):
    """1x1 convolution + BN + ReLU that also records a channel-grouping loss.

    The convolution weights are initialised to 1.0 and the bias to 0.1.
    Every call to :meth:`forward` stores a loss tensor of shape ``(batch,)``
    in ``self.loss``; it measures how far the channel-to-channel correlation
    of the activations is from a block-diagonal matrix made of ``num_group``
    blocks of ones.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; must be a positive multiple
            of ``num_group``.
        width: Accepted for interface compatibility; not used.
        num_group: Number of channel groups.

    Raises:
        ValueError: If ``out_channels`` is not divisible by ``num_group``.
    """

    def __init__(self, in_channels, out_channels, width, num_group):
        super(GroupConv, self).__init__()
        if num_group <= 0 or out_channels % num_group != 0:
            raise ValueError(
                "out_channels ({}) must be a positive multiple of num_group ({})".format(
                    out_channels, num_group
                )
            )
        self.num_group = num_group
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Group mapping matrix: a convolution layer with kernel size 1.
        self.matrix_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(True)

        nn.init.constant_(self.matrix_conv.weight, 1.0)
        nn.init.constant_(self.matrix_conv.bias, 0.1)
        self.loss = 0

    def _group_target(self, channels, like):
        """Return the (channels, channels) block-diagonal 0/1 target on ``like``'s device/dtype."""
        per_group = channels // self.num_group
        # Created next to the activations so the subtraction in forward()
        # also works when the model lives on an accelerator.
        eye = torch.eye(self.num_group, device=like.device, dtype=like.dtype)
        return eye.repeat_interleave(per_group, dim=0).repeat_interleave(per_group, dim=1)

    def forward(self, x):
        """Apply conv/BN/ReLU to ``x`` and refresh ``self.loss``.

        Args:
            x: 4-D tensor whose second dimension has ``in_channels`` entries.
                The two trailing (spatial) dimensions may differ in size.

        Returns:
            The activated feature map (same spatial size as ``x``,
            ``out_channels`` channels).
        """
        matrix_act = self.relu(self.bn(self.matrix_conv(x)))

        batch, channels, dim_a, dim_b = matrix_act.shape
        n_spatial = dim_a * dim_b

        # Each (sample, channel) map becomes one L2-normalised row. The 0.001
        # offset is applied first (presumably to keep all-zero maps non-degenerate).
        feats = (matrix_act + 0.001).reshape(batch * channels, n_spatial)
        feats = F.normalize(feats, p=2, dim=1)
        # Regroup as one row per channel, concatenating all samples.
        feats = feats.view(batch, channels, n_spatial).permute(1, 0, 2)
        feats = feats.reshape(channels, batch * n_spatial)

        co = feats.mm(feats.t()).view(1, channels * channels) / 128
        gt = self._group_target(channels, co).reshape(1, channels * channels)

        diff = co - gt
        loss_single = torch.sum(diff * diff * 0.001, dim=1)
        loss = loss_single.repeat(batch)
        loss = loss / ((channels / 512.0) * (channels / 512.0))

        self.loss = loss
        return matrix_act


class GroupBillinear(nn.Module):
    """Group bilinear residual layer.

    For every spatial position the channel vector is passed through a
    residual linear layer, split into ``num_group`` groups, and the groups'
    outer products are summed (via ``bmm``) and squashed with ``tanh``. The
    result is resized back to ``channels`` values with nearest-neighbour
    interpolation, batch-normalised and added to the input.

    Args:
        num_group: Number of channel groups.
        width: Accepted for interface compatibility; not used.
        channels: Number of input/output channels; must be a positive
            multiple of ``num_group``.

    Raises:
        ValueError: If ``channels`` is not divisible by ``num_group``.
    """

    def __init__(self, num_group, width, channels):
        super(GroupBillinear, self).__init__()
        if num_group <= 0 or channels % num_group != 0:
            raise ValueError(
                "channels ({}) must be a positive multiple of num_group ({})".format(
                    channels, num_group
                )
            )
        self.num_group = num_group
        self.num_per_group = channels // num_group
        self.channels = channels
        self.fc = nn.Linear(channels, channels, bias=True)
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``x`` plus the batch-normalised group-bilinear response.

        Args:
            x: 4-D tensor with ``channels`` entries in its second dimension.
                The two trailing (spatial) dimensions may differ in size.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, channels, dim_a, dim_b = x.shape
        positions = x.numel() // self.channels

        # Channels-last, one row per spatial position.
        feats = x.permute(0, 2, 3, 1).reshape(positions, self.channels)
        feats = feats + self.fc(feats)

        feats = feats.reshape(positions, self.num_group, self.num_per_group)
        bilinear = torch.tanh(torch.bmm(feats.permute(0, 2, 1), feats) / 32)

        bilinear = bilinear.reshape(
            batch, dim_a, dim_b, self.num_per_group * self.num_per_group
        )
        # interpolate() treats the last two dimensions as spatial: dim_b is
        # left unchanged and the bilinear vector is resized to `channels`.
        bilinear = F.interpolate(bilinear, (dim_b, channels))
        bilinear = bilinear.permute(0, 3, 1, 2)

        return x + self.bn(bilinear)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def _make_shortcut(inplanes, outplanes, stride):
    """Return a 1x1-conv + BN shortcut projection, or None when the identity already fits."""
    if stride == 1 and inplanes == outplanes:
        return None
    return nn.Sequential(
        conv1x1(inplanes, outplanes, stride),
        nn.BatchNorm2d(outplanes),
    )


class BasicBlock(nn.Module):
    """Two-layer 3x3 residual block.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first 3x3 convolution.
        downsample: Optional module applied to the input to produce the
            shortcut. When it is ``None`` and ``inplanes != planes``, a 1x1
            convolution followed by ``bn1`` and ReLU is used instead.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.inplaces = inplanes  # attribute name kept as in the original file
        self.planes = planes
        self.conv_ch = conv1x1(inplanes, planes, stride=1)

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        if self.downsample is not None:
            identity = self.downsample(x)
        elif self.inplaces != self.planes:
            # Only evaluated when it is actually used: running it and then
            # discarding the result would push an extra, unrelated batch
            # through bn1's running statistics.
            identity = self.relu(self.bn1(self.conv_ch(x)))
        else:
            identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += identity
        return self.relu(out)


class Bottleneck(nn.Module):
    """Bottleneck residual block, optionally fronted by GroupConv + GroupBillinear.

    Args:
        inplanes: Number of input channels.
        planes: Bottleneck width; the block outputs ``planes * 4`` channels.
        stride: Stride of the middle 3x3 convolution.
        downsample: Optional module producing the shortcut from the input.
        use_SG_GB: When true, the first 1x1 convolution is replaced by
            ``GroupConv`` (16 groups) -> ``GroupBillinear`` -> 3x3 convolution.
        featuremap_size: Forwarded to the ``width`` argument of the two group
            layers, which do not use it.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_SG_GB=False, featuremap_size=0):
        super(Bottleneck, self).__init__()

        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.use_SG_GB = use_SG_GB
        if self.use_SG_GB:
            self.SG = GroupConv(inplanes, planes, featuremap_size, 16)
            self.GB = GroupBillinear(16, featuremap_size, planes)
            self.conv1 = conv3x3(planes, planes)
        else:
            self.conv1 = conv1x1(inplanes, planes)

    def forward(self, x):
        """Return ``relu(bottleneck(x) + shortcut(x))``."""
        identity = x

        if self.use_SG_GB:
            out = self.conv1(self.GB(self.SG(x)))
        else:
            out = self.conv1(x)
        # NOTE(review): bn1/relu are taken to follow conv1 on both branches;
        # confirm against the upstream file's indentation.
        out = self.relu(self.bn1(out))

        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return self.relu(out)


class ResNet(nn.Module):
    """Plain ResNet backbone with a Linear + Sigmoid head.

    All four stages are built with stride 2.

    Args:
        block: Residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: Sequence of four ints, the number of blocks per stage.
        num_classes: Size of the output vector.
        zero_init_residual: Accepted for interface compatibility; not used.
    """

    def __init__(self, block, layers, num_classes=4, zero_init_residual=True):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Single-channel stem; registered but not used by forward().
        self.conv1_sim = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.out = nn.Sequential(
            nn.Linear(512 * block.expansion, num_classes),
            nn.Sigmoid(),
        )

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage: a (possibly strided) first block plus ``blocks - 1`` more."""
        downsample = _make_shortcut(self.inplanes, planes * block.expansion, stride)

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return sigmoid scores of shape ``(batch, num_classes)``."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.out(x)


class ResNet_SG_GB(nn.Module):
    """ResNet whose last two stages and head use GroupConv / GroupBillinear.

    Args:
        block: Residual block class; stages 3 and 4 construct it with the
            ``use_SG_GB`` / ``featuremap_size`` arguments, so it must accept
            them and expose an ``SG`` attribute (as ``Bottleneck`` does).
        layers: Sequence of four ints, the number of blocks per stage.
        num_classes: Size of the output vector.
        zero_init_residual: Accepted for interface compatibility; not used.
        down_1: When true, the first stage is built with stride 2 instead of 1.
    """

    def __init__(self, block, layers, num_classes=4, zero_init_residual=True, down_1=False):
        super(ResNet_SG_GB, self).__init__()
        self.inplanes = 64
        # Nominal feature-map size, starting from 224 and halved after each
        # stage. It is only forwarded to the `width` arguments of the group
        # layers, which do not use it.
        self.featuremap_size = 224
        self.down_1 = down_1
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.featuremap_size //= 2
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.featuremap_size //= 2
        # Plain list of every GroupConv, in construction order. The modules
        # themselves are registered through the stages and `SG_end`.
        self.all_gconvs = []

        if down_1:
            self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        else:
            self.layer1 = self._make_layer(block, 64, layers[0])
        self.featuremap_size //= 2
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.featuremap_size //= 2
        self.layer3 = self._make_layer_SG_GB(block, 256, layers[2], stride=2)
        self.featuremap_size //= 2
        self.layer4 = self._make_layer_SG_GB(block, 512, layers[3], stride=2)
        self.featuremap_size //= 2

        self.SG_end = GroupConv(512 * block.expansion, 512 * block.expansion, self.featuremap_size, 32)
        self.all_gconvs.append(self.SG_end)
        self.GB_end = GroupBillinear(32, self.featuremap_size, 512 * block.expansion)
        self.bn_end = nn.BatchNorm2d(512 * block.expansion)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.out = nn.Sequential(
            nn.Linear(512 * block.expansion, num_classes),
            nn.Sigmoid(),
        )

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one plain stage: a first block plus ``blocks - 1`` more."""
        downsample = _make_shortcut(self.inplanes, planes * block.expansion, stride)

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_layer_SG_GB(self, block, planes, blocks, stride=1):
        """Build one stage of SG/GB-enabled blocks and record their GroupConv layers."""
        downsample = _make_shortcut(self.inplanes, planes * block.expansion, stride)

        layers = [block(self.inplanes, planes, stride, downsample, True, self.featuremap_size)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, 1, None, True, self.featuremap_size))

        self.all_gconvs.extend(layer.SG for layer in layers)
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network.

        Returns:
            Tuple ``(scores, loss_sg)``: sigmoid scores of shape
            ``(batch, num_classes)`` and the mean of the ``loss`` tensors
            stored by all GroupConv layers during this pass.
        """
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.SG_end(x)
        x = self.GB_end(x)
        x = self.bn_end(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.out(x)

        # Walk the registered sub-modules (same order as `all_gconvs`) so that
        # a replicated copy of the model reads the losses of its own layers
        # rather than those of the instance the list was built from.
        losses = [m.loss for m in self.modules() if isinstance(m, GroupConv)]
        loss_sg = sum(losses[1:], losses[0]) / len(losses)

        return x, loss_sg


# Example usage kept from the original file (disabled):
#
# temp_model = ResNet_SG_GB(Bottleneck, (3,4,6,3), 5)
#
# print(temp_model.parameters)
#
# temp_input = torch.randn(5, 3, 224, 224)
# temp_label = torch.empty(5, dtype=torch.long).random_(3)
#
# loss_fun = nn.CrossEntropyLoss()
# loss_fun_2 = nn.MSELoss(reduction=False)
#
# temp_model.zero_grad()
# temp_out, temp_loss = temp_model(temp_input)
#
#
# loss = loss_fun(temp_out, temp_label)
# temp_loss_ = temp_loss*0
# loss_matrix = loss_fun_2(temp_loss, temp_loss_)*1e-4
#
# loss_all = loss + loss_matrix
# loss_all.backward()
#
#
# temp_model_SG = GroupConv(in_channels=64, out_channels=64, width=128, num_group=8)
# temp_model_GB = GroupBillinear(num_group=8, width=128, channels=64)
# temp_input = torch.randn(2, 64, 128, 128)
# temp_out = temp_model_SG(temp_input)
# out = temp_model_GB(temp_out)
#
# print('ok')