├── Densenet_attention.py
├── LICENSE
├── README.md
├── ckpt
│   └── AADFNet
│       └── ckptPlaceHolder
├── config.py
├── datasets.py
├── densenet
│   ├── Densenet.py
│   ├── Densenet.pyc
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── config.py
│   └── config.pyc
├── infer.py
├── joint_transforms.py
├── misc.py
└── train_densenet_attention.py

/Densenet_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from densenet import Densenet
5 | 
6 | 
7 | class AADFNet(nn.Module):
8 | 
9 |     def __init__(self):
10 |         super(AADFNet, self).__init__()
11 |         densenet = Densenet()
12 |         self.layer0 = densenet.layer0
13 |         self.layer1 = densenet.layer1
14 |         self.layer2 = densenet.layer2
15 |         self.layer3 = densenet.layer3
16 |         self.layer4 = densenet.layer4
17 | 
18 |         self.aspp_layer4 = _ASPP_attention(2208, 32)
19 |         self.aspp_layer3 = _ASPP_attention(2112, 32)
20 |         self.aspp_layer2 = _ASPP_attention(768, 32)
21 |         self.aspp_layer1 = _ASPP_attention(384, 32)
22 | 
23 |         self.predict4 = nn.Sequential(
24 |             nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
25 |             nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
26 |             nn.Conv2d(64, 1, kernel_size=1)
27 |         )
28 |         self.predict41 = nn.Conv2d(32, 1, kernel_size=1)
29 |         self.predict42 = nn.Conv2d(32, 1, kernel_size=1)
30 |         self.predict43 = nn.Conv2d(32, 1, kernel_size=1)
31 |         self.predict44 = nn.Conv2d(32, 1, kernel_size=1)
32 | 
33 | 
34 |         self.predict3 = nn.Sequential(
35 |             nn.Conv2d(129, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
36 |             nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
37 |             nn.Conv2d(64, 1, kernel_size=1)
38 |         )
39 |         self.predict31 = nn.Conv2d(33, 1, kernel_size=1)
40 |         self.predict32 = nn.Conv2d(33, 1, kernel_size=1)
41 |         self.predict33 = nn.Conv2d(33, 1, kernel_size=1)
42 |         self.predict34 = nn.Conv2d(33, 1, kernel_size=1)
43 | 
44 |         self.predict2 = nn.Sequential(
45 |             nn.Conv2d(129, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
46 |             nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
47 |             nn.Conv2d(64, 1, kernel_size=1)
48 |         )
49 |         self.predict21 = nn.Conv2d(33, 1, kernel_size=1)
50 |         self.predict22 = nn.Conv2d(33, 1, kernel_size=1)
51 |         self.predict23 = nn.Conv2d(33, 1, kernel_size=1)
52 |         self.predict24 = nn.Conv2d(33, 1, kernel_size=1)
53 | 
54 |         self.predict1 = nn.Sequential(
55 |             nn.Conv2d(129, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
56 |             nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
57 |             nn.Conv2d(64, 1, kernel_size=1)
58 |         )
59 |         self.predict11 = nn.Conv2d(33, 1, kernel_size=1)
60 |         self.predict12 = nn.Conv2d(33, 1, kernel_size=1)
61 |         self.predict13 = nn.Conv2d(33, 1, kernel_size=1)
62 |         self.predict14 = nn.Conv2d(33, 1, kernel_size=1)
63 | 
64 |         self.predict4_2 = nn.Sequential(
65 |             nn.Conv2d(129, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
66 |             nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
67 |             nn.Conv2d(64, 1, kernel_size=1)
68 |         )
69 | 
70 |         self.predict3_2 = nn.Sequential(
71 |             nn.Conv2d(129, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
72 |             nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
73 |             nn.Conv2d(64, 1, kernel_size=1)
74 |         )
75 | 
76 |         self.predict2_2 = nn.Sequential(
77 |             nn.Conv2d(129, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(),
78 | 
nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU(), 79 | nn.Conv2d(64, 1, kernel_size=1) 80 | ) 81 | 82 | self.residual3 = nn.Sequential( 83 | nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 84 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 85 | nn.Conv2d(128, 128, kernel_size=1) 86 | ) 87 | self.residual2 = nn.Sequential( 88 | nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 89 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 90 | nn.Conv2d(128, 128, kernel_size=1) 91 | ) 92 | self.residual1 = nn.Sequential( 93 | nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 94 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 95 | nn.Conv2d(128, 128, kernel_size=1) 96 | ) 97 | self.residual2_2 = nn.Sequential( 98 | nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 99 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 100 | nn.Conv2d(128, 128, kernel_size=1) 101 | ) 102 | self.residual3_2 = nn.Sequential( 103 | nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 104 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 105 | nn.Conv2d(128, 128, kernel_size=1) 106 | ) 107 | self.residual4_2 = nn.Sequential( 108 | nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 109 | nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU(), 110 | nn.Conv2d(128, 128, kernel_size=1) 111 | ) 112 | 113 | for m in self.modules(): 114 | if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout): 115 | m.inplace = True 116 | 117 | def forward(self, x): 118 | layer0 = self.layer0(x) 119 | layer1 = self.layer1(layer0) 120 | layer2 = self.layer2(layer1) 121 | layer3 = self.layer3(layer2) 122 | # layer3 = self.nlb(layer3) 123 | layer4 = self.layer4(layer3) 124 | 125 | aspp_layer41, aspp_layer42, aspp_layer43, aspp_layer44, aspp_layer4 = self.aspp_layer4(layer4) 126 | aspp_layer31, aspp_layer32, aspp_layer33, aspp_layer34, aspp_layer3 = self.aspp_layer3(layer3) 127 | aspp_layer21, aspp_layer22, aspp_layer23, aspp_layer24, aspp_layer2 = self.aspp_layer2(layer2) 128 | aspp_layer11, aspp_layer12, aspp_layer13, aspp_layer14, aspp_layer1 = self.aspp_layer1(layer1) 129 | 130 | 131 | predict41 = self.predict41(aspp_layer41) 132 | predict42 = self.predict42(aspp_layer42) 133 | predict43 = self.predict43(aspp_layer43) 134 | predict44 = self.predict44(aspp_layer44) 135 | 136 | predict4 = self.predict4(aspp_layer4) 137 | predict4 = F.upsample(predict4, size=layer3.size()[2:], mode='bilinear') 138 | aspp_layer4 = F.upsample(aspp_layer4, size=layer3.size()[2:], mode='bilinear') 139 | 140 | predict31 = self.predict31(torch.cat((predict4, aspp_layer31), 1)) + predict4 141 | predict32 = self.predict32(torch.cat((predict4, aspp_layer32), 1)) + predict4 142 | predict33 = self.predict33(torch.cat((predict4, aspp_layer33), 1)) + predict4 143 | predict34 = self.predict34(torch.cat((predict4, aspp_layer34), 1)) + predict4 144 | 145 | fpn_layer3 = aspp_layer3 + self.residual3(torch.cat((aspp_layer4, aspp_layer3), 1)) 146 | predict3 = self.predict3(torch.cat((predict4, fpn_layer3), 1)) + predict4 147 | predict3 = F.upsample(predict3, size=layer2.size()[2:], mode='bilinear') 148 | fpn_layer3 = F.upsample(fpn_layer3, size=layer2.size()[2:], mode='bilinear') 149 | 150 | predict21 
= self.predict21(torch.cat((predict3, aspp_layer21), 1)) + predict3 151 | predict22 = self.predict22(torch.cat((predict3, aspp_layer22), 1)) + predict3 152 | predict23 = self.predict23(torch.cat((predict3, aspp_layer23), 1)) + predict3 153 | predict24 = self.predict24(torch.cat((predict3, aspp_layer24), 1)) + predict3 154 | 155 | fpn_layer2 = aspp_layer2 + self.residual2(torch.cat((fpn_layer3, aspp_layer2), 1)) 156 | predict2 = self.predict2(torch.cat((predict3, fpn_layer2), 1)) + predict3 157 | predict2 = F.upsample(predict2, size=layer1.size()[2:], mode='bilinear') 158 | fpn_layer2 = F.upsample(fpn_layer2, size=layer1.size()[2:], mode='bilinear') 159 | 160 | predict11 = self.predict11(torch.cat((predict2, aspp_layer11), 1)) + predict2 161 | predict12 = self.predict12(torch.cat((predict2, aspp_layer12), 1)) + predict2 162 | predict13 = self.predict13(torch.cat((predict2, aspp_layer13), 1)) + predict2 163 | predict14 = self.predict14(torch.cat((predict2, aspp_layer14), 1)) + predict2 164 | 165 | fpn_layer1 = aspp_layer1 + self.residual1(torch.cat((fpn_layer2, aspp_layer1), 1)) 166 | predict1 = self.predict1(torch.cat((predict2, fpn_layer1), 1)) + predict2 167 | 168 | fpn_layer4 = F.upsample(aspp_layer4, size=layer1.size()[2:], mode='bilinear') 169 | fpn_layer3 = F.upsample(fpn_layer3, size=layer1.size()[2:], mode='bilinear') 170 | fpn_layer2 = F.upsample(fpn_layer2, size=layer1.size()[2:], mode='bilinear') 171 | 172 | fpn_layer2_2 = fpn_layer2 + self.residual2_2(torch.cat((fpn_layer2, fpn_layer1), 1)) 173 | predict2_2 = self.predict2_2(torch.cat((predict1, fpn_layer2_2), 1)) + predict1 174 | 175 | fpn_layer3_2 = fpn_layer3 + self.residual3_2(torch.cat((fpn_layer3, fpn_layer2), 1)) 176 | predict3_2 = self.predict3_2(torch.cat((predict2_2, fpn_layer3_2), 1)) + predict2_2 177 | 178 | fpn_layer4_2 = fpn_layer4 + self.residual4_2(torch.cat((fpn_layer4, fpn_layer3), 1)) 179 | predict4_2 = self.predict4_2(torch.cat((predict3_2, fpn_layer4_2), 1)) + predict3_2 180 | 181 | predict4 = F.upsample(predict4, size=x.size()[2:], mode='bilinear') 182 | predict3 = F.upsample(predict3, size=x.size()[2:], mode='bilinear') 183 | predict2 = F.upsample(predict2, size=x.size()[2:], mode='bilinear') 184 | predict1 = F.upsample(predict1, size=x.size()[2:], mode='bilinear') 185 | 186 | predict4_2 = F.upsample(predict4_2, size=x.size()[2:], mode='bilinear') 187 | predict3_2 = F.upsample(predict3_2, size=x.size()[2:], mode='bilinear') 188 | predict2_2 = F.upsample(predict2_2, size=x.size()[2:], mode='bilinear') 189 | 190 | predict44 = F.upsample(predict44, size=x.size()[2:], mode='bilinear') 191 | predict43 = F.upsample(predict43, size=x.size()[2:], mode='bilinear') 192 | predict42 = F.upsample(predict42, size=x.size()[2:], mode='bilinear') 193 | predict41 = F.upsample(predict41, size=x.size()[2:], mode='bilinear') 194 | 195 | predict34 = F.upsample(predict34, size=x.size()[2:], mode='bilinear') 196 | predict33 = F.upsample(predict33, size=x.size()[2:], mode='bilinear') 197 | predict32 = F.upsample(predict32, size=x.size()[2:], mode='bilinear') 198 | predict31 = F.upsample(predict31, size=x.size()[2:], mode='bilinear') 199 | 200 | predict24 = F.upsample(predict24, size=x.size()[2:], mode='bilinear') 201 | predict23 = F.upsample(predict23, size=x.size()[2:], mode='bilinear') 202 | predict22 = F.upsample(predict22, size=x.size()[2:], mode='bilinear') 203 | predict21 = F.upsample(predict21, size=x.size()[2:], mode='bilinear') 204 | 205 | predict14 = F.upsample(predict14, size=x.size()[2:], mode='bilinear') 206 | 
predict13 = F.upsample(predict13, size=x.size()[2:], mode='bilinear') 207 | predict12 = F.upsample(predict12, size=x.size()[2:], mode='bilinear') 208 | predict11 = F.upsample(predict11, size=x.size()[2:], mode='bilinear') 209 | 210 | if self.training: 211 | return predict4_2, predict3_2, predict2_2, predict1, predict2, predict3, predict4,\ 212 | predict41, predict42, predict43, predict44, \ 213 | predict31, predict32, predict33, predict34, \ 214 | predict21, predict22, predict23, predict24, \ 215 | predict11, predict12, predict13, predict14, 216 | return F.sigmoid(predict4_2) 217 | 218 | 219 | class _ASPP_attention(nn.Module): 220 | 221 | def __init__(self, in_dim, out_dim): 222 | super(_ASPP_attention, self).__init__() 223 | 224 | self.conv1 = nn.Sequential( 225 | nn.Conv2d(in_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU() 226 | ) 227 | 228 | self.conv2 = nn.Sequential( 229 | nn.Conv2d(in_dim + out_dim, out_dim, kernel_size=3, dilation=2, padding=2), nn.BatchNorm2d(out_dim), nn.PReLU() 230 | ) 231 | self.conv3 = nn.Sequential( 232 | nn.Conv2d(in_dim + out_dim * 2, out_dim, kernel_size=3, dilation=4, padding=4), nn.BatchNorm2d(out_dim), nn.PReLU() 233 | ) 234 | self.conv4 = nn.Sequential( 235 | nn.Conv2d(in_dim + out_dim * 3, out_dim, kernel_size=3, dilation=6, padding=6), nn.BatchNorm2d(out_dim), nn.PReLU() 236 | ) 237 | 238 | self.fuse1 = nn.Sequential( 239 | nn.Conv2d(3 * out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 240 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 241 | nn.Conv2d(out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU() 242 | ) 243 | self.fuse2 = nn.Sequential( 244 | nn.Conv2d(3 * out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 245 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 246 | nn.Conv2d(out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU() 247 | ) 248 | self.fuse3 = nn.Sequential( 249 | nn.Conv2d(3 * out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 250 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 251 | nn.Conv2d(out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU() 252 | ) 253 | self.fuse4 = nn.Sequential( 254 | nn.Conv2d(3 * out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 255 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 256 | nn.Conv2d(out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU() 257 | ) 258 | 259 | self.attention4_local = nn.Sequential( 260 | nn.Conv2d(2 * out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 261 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 262 | nn.Conv2d(out_dim, out_dim, kernel_size=1), nn.Softmax2d() 263 | ) 264 | self.attention3_local = nn.Sequential( 265 | nn.Conv2d(2 * out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 266 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 267 | nn.Conv2d(out_dim, out_dim, kernel_size=1), nn.Softmax2d() 268 | ) 269 | self.attention2_local = nn.Sequential( 270 | nn.Conv2d(2 * out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 271 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 272 | nn.Conv2d(out_dim, 
out_dim, kernel_size=1), nn.Softmax2d() 273 | ) 274 | self.attention1_local = nn.Sequential( 275 | nn.Conv2d(2 * out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 276 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 277 | nn.Conv2d(out_dim, out_dim, kernel_size=1), nn.Softmax2d() 278 | ) 279 | 280 | self.attention4_global = nn.Sequential( 281 | nn.Conv2d(2 * out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 282 | nn.Conv2d(out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 283 | nn.Conv2d(out_dim, out_dim, kernel_size=1), nn.Softmax2d() 284 | ) 285 | self.attention3_global = nn.Sequential( 286 | nn.Conv2d(2 * out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 287 | nn.Conv2d(out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 288 | nn.Conv2d(out_dim, out_dim, kernel_size=1), nn.Softmax2d() 289 | ) 290 | self.attention2_global = nn.Sequential( 291 | nn.Conv2d(2 * out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 292 | nn.Conv2d(out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 293 | nn.Conv2d(out_dim, out_dim, kernel_size=1), nn.Softmax2d() 294 | ) 295 | self.attention1_global = nn.Sequential( 296 | nn.Conv2d(2 * out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 297 | nn.Conv2d(out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 298 | nn.Conv2d(out_dim, out_dim, kernel_size=1), nn.Softmax2d() 299 | ) 300 | 301 | 302 | self.refine4 = nn.Sequential( 303 | nn.Conv2d(3 * out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 304 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 305 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU() 306 | ) 307 | self.refine3 = nn.Sequential( 308 | nn.Conv2d(3 * out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 309 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 310 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU() 311 | ) 312 | self.refine2 = nn.Sequential( 313 | nn.Conv2d(3 * out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 314 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 315 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU() 316 | ) 317 | self.refine1 = nn.Sequential( 318 | nn.Conv2d(3 * out_dim, out_dim, kernel_size=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 319 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU(), 320 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), nn.BatchNorm2d(out_dim), nn.PReLU() 321 | ) 322 | 323 | 324 | 325 | def forward(self, x): 326 | conv1 = self.conv1(x) 327 | conv2 = self.conv2(torch.cat((x, conv1), 1)) 328 | conv3 = self.conv3(torch.cat((x, conv1, conv2), 1)) 329 | conv4 = self.conv4(torch.cat((x, conv1, conv2, conv3), 1)) 330 | 331 | fusion4_123 = self.fuse4(torch.cat((conv1, conv2, conv3), 1)) 332 | fusion3_124 = self.fuse3(torch.cat((conv1, conv2, conv4), 1)) 333 | fusion2_134 = self.fuse2(torch.cat((conv1, conv3, conv4), 1)) 334 | fusion1_234 = self.fuse1(torch.cat((conv2, conv3, conv4), 1)) 335 | 336 | attention4_local = self.attention4_local(torch.cat((conv4, fusion4_123), 1)) 337 | attention3_local = 
self.attention3_local(torch.cat((conv3, fusion3_124), 1)) 338 | attention2_local = self.attention2_local(torch.cat((conv2, fusion2_134), 1)) 339 | attention1_local = self.attention1_local(torch.cat((conv1, fusion1_234), 1)) 340 | 341 | attention4_global = F.upsample(self.attention4_global(F.adaptive_avg_pool2d(torch.cat((conv4, fusion4_123), 1), 1)), 342 | size=x.size()[2:], mode='bilinear', align_corners=True) 343 | attention3_global = F.upsample(self.attention3_global(F.adaptive_avg_pool2d(torch.cat((conv3, fusion3_124), 1), 1)), 344 | size=x.size()[2:], mode='bilinear', align_corners=True) 345 | attention2_global = F.upsample(self.attention2_global(F.adaptive_avg_pool2d(torch.cat((conv2, fusion2_134), 1), 1)), 346 | size=x.size()[2:], mode='bilinear', align_corners=True) 347 | attention1_global = F.upsample(self.attention1_global(F.adaptive_avg_pool2d(torch.cat((conv1, fusion1_234), 1), 1)), 348 | size=x.size()[2:], mode='bilinear', align_corners=True) 349 | 350 | refine4 = self.refine4(torch.cat((fusion4_123 * attention4_local, fusion4_123 * attention4_global, conv4), 1)) 351 | refine3 = self.refine3(torch.cat((fusion3_124 * attention3_local, fusion3_124 * attention3_global, conv3), 1)) 352 | refine2 = self.refine2(torch.cat((fusion2_134 * attention2_local, fusion2_134 * attention2_global, conv2), 1)) 353 | refine1 = self.refine1(torch.cat((fusion1_234 * attention1_local, fusion1_234 * attention1_global, conv1), 1)) 354 | refine_fusion = torch.cat((refine1, refine2, refine3, refine4), 1) 355 | 356 | return refine1, refine2, refine3, refine4, refine_fusion 357 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Jiaxing Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Aggregating Attentional Dilated Features for Salient Object Detection 2 | by Lei Zhu^, Jiaxing Chen^, Xiaowei Hu, Chi-Wing Fu, Xuemiao Xu, Jing Qin, and Pheng-Ann Heng (^ joint 1st authors)[[paper link](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8836095)] 3 | 4 | This implementation is written by Jiaxing Chen at the South China University of Technology. 
5 | 
6 | ## Citation
7 | 
8 |     @article{zhu2019aggregating,
9 |       title={Aggregating Attentional Dilated Features for Salient Object Detection},
10 |       author={Zhu, Lei and Chen, Jiaxing and Hu, Xiaowei and Fu, Chi-Wing and Xu, Xuemiao and Qin, Jing and Heng, Pheng-Ann},
11 |       journal={IEEE Transactions on Circuits and Systems for Video Technology},
12 |       year={2019},
13 |       publisher={IEEE}
14 |     }
15 | 
16 | ## Saliency Map
17 | 
18 | The results of salient object detection on seven datasets (ECSSD, HKU-IS, PASCAL-S, SOD, DUT-OMRON, DUTS-TE, SOC) can be found at [Google Drive](https://drive.google.com/open?id=1tv72yWNH0ANHoSU4qMOwD7g5r53wSZEe).
19 | 
20 | ## Trained Model
21 | 
22 | You can download the trained model reported in our paper at [Google Drive](https://drive.google.com/file/d/1AWFG6x2lLNTUttBIxzrCFH277wkRuFX-/view?usp=sharing).
23 | 
24 | ## Requirement
25 | 
26 | - Python 2.7
27 | - PyTorch 0.4.0
28 | - torchvision
29 | - numpy
30 | - Cython
31 | - pydensecrf (see [here](https://github.com/Andrew-Qibin/dss_crf) for installation)
32 | 
33 | ## Training
34 | 
35 | 1. Set the path of the pretrained DenseNet model in densenet/config.py
36 | 2. Set the path of the DUTS dataset in config.py
37 | 3. Run `python train_densenet_attention.py`
38 | 
39 | *Hyper-parameters* of training are gathered at the beginning of *train_densenet_attention.py*; you can conveniently change them as needed.
40 | 
41 | ## Testing
42 | 
43 | 1. Set the paths of the seven benchmark datasets in config.py
44 | 2. Put the trained model in ckpt/AADFNet
45 | 3. Run `python infer.py`
46 | 
47 | *Settings* of testing are gathered at the beginning of *infer.py*; you can conveniently change them as needed.
48 | 
49 | ## Dataset links
50 | 
51 | - [ECSSD](http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html), [HKU-IS](https://sites.google.com/site/ligb86/hkuis), [PASCAL-S](http://cbi.gatech.edu/salobj/), [SOD](http://elderlab.yorku.ca/SOD/), [DUT-OMRON](http://ice.dlut.edu.cn/lu/DUT-OMRON/Homepage.htm), [DUTS](http://saliencydetection.net/duts/), [SOC](http://dpfan.net/SOCBenchmark/): the seven benchmark datasets
--------------------------------------------------------------------------------
/ckpt/AADFNet/ckptPlaceHolder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/githubBingoChen/AADF-Net/f5b2149fba1cc88bbe65c1c7e407e46889ac5b09/ckpt/AADFNet/ckptPlaceHolder
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 | 
4 | datasets_root = '/home/b3-542/Documents/datasets/SaliencyDatasets'
5 | 
6 | # For each dataset, I put images and masks together
7 | msra10k_path = os.path.join(datasets_root, 'msra10k')
8 | ecssd_path = os.path.join(datasets_root, 'ecssd')
9 | hkuis_path = os.path.join(datasets_root, 'hkuis')
10 | pascals_path = os.path.join(datasets_root, 'pascals')
11 | dutomron_path = os.path.join(datasets_root, 'dutomron')
12 | duts_path = os.path.join(datasets_root, 'duts')
13 | duts_train_path = os.path.join(datasets_root, 'duts_train')
14 | sod_path = os.path.join(datasets_root, 'sod')
15 | soc_path = os.path.join(datasets_root, 'soc')
16 | soc_val_path = os.path.join(datasets_root, 'soc_val')
17 | thur15k_path = os.path.join(datasets_root, 'thur15k')
18 | 
19 | pytorch_pretrained_root = '/home/b3-542/Packages/Models/PyTorch Pretrained'
20 | pretrained_res50_path = os.path.join(pytorch_pretrained_root,
'ResNet', 'resnet50-19c8e357.pth') 21 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | 4 | import torch.utils.data as data 5 | from PIL import Image 6 | 7 | 8 | def make_dataset(root): 9 | img_list = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')] 10 | return [(os.path.join(root, img_name + '.jpg'), os.path.join(root, img_name + '.png')) for img_name in img_list] 11 | 12 | 13 | class ImageFolder(data.Dataset): 14 | # image and gt should be in the same folder and have same filename except extended name (jpg and png respectively) 15 | def __init__(self, root, joint_transform=None, transform=None, target_transform=None): 16 | self.root = root 17 | self.imgs = make_dataset(root) 18 | self.joint_transform = joint_transform 19 | self.transform = transform 20 | self.target_transform = target_transform 21 | 22 | def __getitem__(self, index): 23 | img_path, gt_path = self.imgs[index] 24 | img = Image.open(img_path).convert('RGB') 25 | target = Image.open(gt_path).convert('L') 26 | if self.joint_transform is not None: 27 | img, target = self.joint_transform(img, target) 28 | if self.transform is not None: 29 | img = self.transform(img) 30 | if self.target_transform is not None: 31 | target = self.target_transform(target) 32 | 33 | return img, target 34 | 35 | def __len__(self): 36 | return len(self.imgs) 37 | -------------------------------------------------------------------------------- /densenet/Densenet.py: -------------------------------------------------------------------------------- 1 | from torchvision import models 2 | import torch 3 | from torch import nn 4 | from config import pretrained_densenet_path 5 | import re 6 | 7 | 8 | class Densenet(nn.Module): 9 | def __init__(self): 10 | super(Densenet, self).__init__() 11 | densenet = models.densenet161() 12 | 13 | pattern = re.compile( 14 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 15 | state_dict = torch.load(pretrained_densenet_path) 16 | for key in list(state_dict.keys()): 17 | res = pattern.match(key) 18 | if res: 19 | new_key = res.group(1) + res.group(2) 20 | state_dict[new_key] = state_dict[key] 21 | del state_dict[key] 22 | 23 | densenet.load_state_dict(state_dict) 24 | 25 | self.layer0 = nn.Sequential(densenet.features.conv0, densenet.features.norm0, densenet.features.relu0,densenet.features.pool0) 26 | self.layer1 = nn.Sequential(densenet.features.denseblock1) 27 | self.layer2 = nn.Sequential(densenet.features.transition1, densenet.features.denseblock2) 28 | self.layer3 = nn.Sequential(densenet.features.transition2, densenet.features.denseblock3) 29 | self.layer4 = nn.Sequential(densenet.features.transition3, densenet.features.denseblock4) 30 | 31 | 32 | def forward(self, x): 33 | layer0 = self.layer0(x) 34 | 35 | layer1 = self.layer1(layer0) 36 | layer2 = self.layer2(layer1) 37 | layer3 = self.layer3(layer2) 38 | layer4 = self.layer4(layer3) 39 | return layer4 40 | 41 | 42 | -------------------------------------------------------------------------------- /densenet/Densenet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/githubBingoChen/AADF-Net/f5b2149fba1cc88bbe65c1c7e407e46889ac5b09/densenet/Densenet.pyc -------------------------------------------------------------------------------- /densenet/__init__.py: 
-------------------------------------------------------------------------------- 1 | from Densenet import Densenet -------------------------------------------------------------------------------- /densenet/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/githubBingoChen/AADF-Net/f5b2149fba1cc88bbe65c1c7e407e46889ac5b09/densenet/__init__.pyc -------------------------------------------------------------------------------- /densenet/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | pytorch_pretrained_root = '/home/b3-542/Packages/Models/PyTorch Pretrained' 3 | pretrained_densenet_path = os.path.join(pytorch_pretrained_root, 'DenseNet', 'densenet161-17b70270.pth') 4 | -------------------------------------------------------------------------------- /densenet/config.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/githubBingoChen/AADF-Net/f5b2149fba1cc88bbe65c1c7e407e46889ac5b09/densenet/config.pyc -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import time 4 | 5 | import torch 6 | from PIL import Image 7 | from torch.autograd import Variable 8 | from torchvision import transforms 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | from config import ecssd_path, hkuis_path, pascals_path, sod_path, dutomron_path, duts_path, thur15k_path, soc_path 13 | from misc import check_mkdir, crf_refine, AvgMeter, cal_precision_recall_mae, cal_fmeasure, cal_fmeasure_both 14 | 15 | from Densenet_attention import AADFNet 16 | 17 | torch.manual_seed(2018) 18 | 19 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 20 | # set which gpu to use 21 | # torch.cuda.set_device(1) 22 | 23 | # the following two args specify the location of the file of trained model (pth extension) 24 | # you should have the pth file in the folder './$ckpt_path$/$exp_name$' 25 | args = { 26 | 'snapshot': '30000', # your snapshot filename (exclude extension name) 27 | 'crf_refine': True, # whether to use crf to refine results 28 | 'save_results': False # whether to save the resulting masks 29 | } 30 | 31 | ckpt_path = './ckpt' 32 | exp_name = 'AADFNet' 33 | exp_predict = args['snapshot'] + ' predict1' 34 | 35 | 36 | 37 | 38 | img_transform = transforms.Compose([ 39 | transforms.Resize((400, 400)), 40 | transforms.ToTensor(), 41 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 42 | ]) 43 | 44 | to_pil = transforms.ToPILImage() 45 | to_test = {'ecssd': ecssd_path, 'hkuis': hkuis_path, 'pascal': pascals_path, 'dutomron': dutomron_path, 'duts': duts_path, 'sod': sod_path,'soc': soc_path} 46 | 47 | def main(): 48 | net = AADFNet().cuda() 49 | net = nn.DataParallel(net, device_ids=[0]) 50 | 51 | print exp_name + 'crf: '+ str(args['crf_refine']) 52 | print 'load snapshot \'%s\' for testing' % args['snapshot'] 53 | net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'))) 54 | net.eval() 55 | 56 | with torch.no_grad(): 57 | results = {} 58 | 59 | for name, root in to_test.iteritems(): 60 | 61 | precision_record, recall_record, = [AvgMeter() for _ in range(256)], [AvgMeter() for _ in range(256)] 62 | mae_record = AvgMeter() 63 | time_record = AvgMeter() 64 | 65 | img_list = [os.path.splitext(f)[0] for f in 
os.listdir(root) if f.endswith('.jpg')]
66 | 
67 |             for idx, img_name in enumerate(img_list):
68 |                 img_name = img_list[idx]
69 |                 print 'predicting for %s: %d / %d' % (name, idx + 1, len(img_list))
70 |                 check_mkdir(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
71 | 
72 |                 start = time.time()
73 |                 img = Image.open(os.path.join(root, img_name + '.jpg')).convert('RGB')
74 |                 img_var = Variable(img_transform(img).unsqueeze(0)).cuda()  # volatile is unnecessary inside torch.no_grad()
75 |                 prediction = net(img_var)
76 |                 W, H = img.size
77 |                 prediction = F.upsample_bilinear(prediction, size=(H, W))
78 |                 prediction = np.array(to_pil(prediction.data.squeeze(0).cpu()))
79 | 
80 | 
81 |                 if args['crf_refine']:
82 |                     prediction = crf_refine(np.array(img), prediction)
83 | 
84 |                 end = time.time()
85 | 
86 |                 gt = np.array(Image.open(os.path.join(root, img_name + '.png')).convert('L'))
87 |                 precision, recall, mae = cal_precision_recall_mae(prediction, gt)
88 |                 for pidx, pdata in enumerate(zip(precision, recall)):
89 |                     p, r = pdata
90 |                     precision_record[pidx].update(p)
91 |                     recall_record[pidx].update(r)
92 | 
93 |                 mae_record.update(mae)
94 |                 time_record.update(end - start)
95 | 
96 | 
97 | 
98 |                 if args['save_results']:
99 |                     Image.fromarray(prediction).save(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (
100 |                         exp_name, name, args['snapshot']), img_name + '.png'))
101 | 
102 |             max_fmeasure, mean_fmeasure = cal_fmeasure_both([precord.avg for precord in precision_record],
103 |                                                             [rrecord.avg for rrecord in recall_record])
104 |             results[name] = {'max_fmeasure': max_fmeasure, 'mae': mae_record.avg, 'mean_fmeasure': mean_fmeasure}
105 | 
106 |         print 'test results:'
107 |         print results
108 | 
109 |         with open('Result', 'a') as f:
110 |             if args['crf_refine']:
111 |                 f.write('with CRF\n')
112 | 
113 |             f.write('Running time %.6f \n' % time_record.avg)
114 |             f.write('\n%s\n iter:%s\n' % (exp_name, args['snapshot']))
115 |             for name, value in results.iteritems():
116 |                 f.write('%s: mean_fmeasure: %.10f, mae: %.10f, max_fmeasure: %.10f\n' % (
117 |                     name, value['mean_fmeasure'], value['mae'], value['max_fmeasure']))
118 | 
119 | 
120 | 
121 | 
122 | 
123 | if __name__ == '__main__':
124 |     main()
125 | 
--------------------------------------------------------------------------------
/joint_transforms.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import random
3 | 
4 | from PIL import Image, ImageOps
5 | 
6 | 
7 | class Compose(object):
8 |     def __init__(self, transforms):
9 |         self.transforms = transforms
10 | 
11 |     def __call__(self, img, mask):
12 |         assert img.size == mask.size
13 |         for t in self.transforms:
14 |             img, mask = t(img, mask)
15 |         return img, mask
16 | 
17 | 
18 | class RandomCrop(object):
19 |     def __init__(self, size, padding=0):
20 |         if isinstance(size, numbers.Number):
21 |             self.size = (int(size), int(size))
22 |         else:
23 |             self.size = size
24 |         self.padding = padding
25 | 
26 |     def __call__(self, img, mask):
27 |         if self.padding > 0:
28 |             img = ImageOps.expand(img, border=self.padding, fill=0)
29 |             mask = ImageOps.expand(mask, border=self.padding, fill=0)
30 | 
31 |         assert img.size == mask.size
32 |         w, h = img.size
33 |         th, tw = self.size
34 |         if w == tw and h == th:
35 |             return img, mask
36 |         if w < tw or h < th:
37 |             return img.resize((tw, th), Image.BILINEAR), mask.resize((tw, th), Image.NEAREST)
38 | 
39 |         x1 = random.randint(0, w - tw)
40 |         y1 = random.randint(0, h - th)
41 |         return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th))
42 | 
43 | 
44 | class RandomHorizontallyFlip(object):
45 |     def __call__(self, img, mask):
46 |         if random.random() < 0.5:
47 |             return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT)
48 |         return img, mask
49 | 
50 | 
51 | class RandomRotate(object):
52 |     def __init__(self, degree):
53 |         self.degree = degree
54 | 
55 |     def __call__(self, img, mask):
56 |         rotate_degree = random.random() * 2 * self.degree - self.degree
57 |         return img.rotate(rotate_degree, Image.BILINEAR), mask.rotate(rotate_degree, Image.NEAREST)
58 | 
--------------------------------------------------------------------------------
/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | 
4 | import pydensecrf.densecrf as dcrf
5 | 
6 | 
7 | class AvgMeter(object):
8 |     def __init__(self):
9 |         self.reset()
10 | 
11 |     def reset(self):
12 |         self.val = 0
13 |         self.avg = 0
14 |         self.sum = 0
15 |         self.count = 0
16 | 
17 |     def update(self, val, n=1):
18 |         self.val = val
19 |         self.sum += val * n
20 |         self.count += n
21 |         self.avg = self.sum / self.count
22 | 
23 | 
24 | def check_mkdir(dir_name):
25 |     if not os.path.exists(dir_name):
26 |         os.mkdir(dir_name)
27 | 
28 | 
29 | def cal_precision_recall_mae(prediction, gt):
30 |     # input should be np array with data type uint8
31 |     assert prediction.dtype == np.uint8
32 |     assert gt.dtype == np.uint8
33 |     assert prediction.shape == gt.shape
34 | 
35 |     eps = 1e-4
36 | 
37 |     prediction = prediction / 255.
38 |     gt = gt / 255.
39 | 
40 |     mae = np.mean(np.abs(prediction - gt))
41 | 
42 |     hard_gt = np.zeros(prediction.shape)
43 |     hard_gt[gt > 0.5] = 1
44 |     t = np.sum(hard_gt)
45 | 
46 |     precision, recall = [], []
47 |     # calculating precision and recall at 256 different binarizing thresholds
48 |     for threshold in range(256):
49 |         threshold = threshold / 255.
50 | 51 | hard_prediction = np.zeros(prediction.shape) 52 | hard_prediction[prediction > threshold] = 1 53 | 54 | tp = np.sum(hard_prediction * hard_gt) 55 | p = np.sum(hard_prediction) 56 | 57 | precision.append((tp + eps) / (p + eps)) 58 | recall.append((tp + eps) / (t + eps)) 59 | 60 | return precision, recall, mae 61 | 62 | 63 | def cal_fmeasure(precision, recall): 64 | assert len(precision) == 256 65 | assert len(recall) == 256 66 | beta_square = 0.3 67 | max_fmeasure = max([(1 + beta_square) * p * r / (beta_square * p + r) for p, r in zip(precision, recall)]) 68 | # max_fmeasure = np.mean([(1 + beta_square) * p * r / (beta_square * p + r) for p, r in zip(precision, recall)]) 69 | 70 | return max_fmeasure 71 | 72 | 73 | def cal_fmeasure_all(precision, recall): 74 | assert len(precision) == 256 75 | assert len(recall) == 256 76 | beta_square = 0.3 77 | fmeasure = [(1 + beta_square) * p * r / (beta_square * p + r) for p, r in zip(precision, recall)] 78 | # max_fmeasure = max([(1 + beta_square) * p * r / (beta_square * p + r) for p, r in zip(precision, recall)]) 79 | # max_fmeasure = np.mean([(1 + beta_square) * p * r / (beta_square * p + r) for p, r in zip(precision, recall)]) 80 | 81 | return fmeasure, precision, recall 82 | 83 | 84 | def cal_fmeasure_both(precision, recall): 85 | assert len(precision) == 256 86 | assert len(recall) == 256 87 | beta_square = 0.3 88 | fmeasure = [(1 + beta_square) * p * r / (beta_square * p + r) for p, r in zip(precision, recall)] 89 | max_fmeasure = max(fmeasure) 90 | mean_fmeasure = np.mean(fmeasure) 91 | 92 | return max_fmeasure, mean_fmeasure 93 | 94 | # codes of this function are borrowed from https://github.com/Andrew-Qibin/dss_crf 95 | def crf_refine(img, annos): 96 | def _sigmoid(x): 97 | return 1 / (1 + np.exp(-x)) 98 | 99 | assert img.dtype == np.uint8 100 | assert annos.dtype == np.uint8 101 | assert img.shape[:2] == annos.shape 102 | 103 | # img and annos should be np array with data type uint8 104 | 105 | EPSILON = 1e-8 106 | 107 | M = 2 # salient or not 108 | tau = 1.05 109 | # Setup the CRF model 110 | d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M) 111 | 112 | anno_norm = annos / 255. 
113 | 114 | n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm)) 115 | p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm)) 116 | 117 | U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32') 118 | U[0, :] = n_energy.flatten() 119 | U[1, :] = p_energy.flatten() 120 | 121 | d.setUnaryEnergy(U) 122 | 123 | d.addPairwiseGaussian(sxy=3, compat=3) 124 | d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5) 125 | 126 | # Do the inference 127 | infer = np.array(d.inference(1)).astype('float32') 128 | res = infer[1, :] 129 | 130 | res = res * 255 131 | res = res.reshape(img.shape[:2]) 132 | return res.astype('uint8') 133 | -------------------------------------------------------------------------------- /train_densenet_attention.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | 4 | import torch 5 | from torch import nn 6 | from torch import optim 7 | from torch.autograd import Variable 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | 11 | import joint_transforms 12 | from config import duts_train_path 13 | from datasets import ImageFolder 14 | from misc import AvgMeter, check_mkdir 15 | 16 | from Densenet_attention import AADFNet 17 | from torch.backends import cudnn 18 | 19 | cudnn.benchmark = True 20 | 21 | torch.manual_seed(2018) 22 | os.environ["CUDA_VISIBLE_DEVICES"] = "1,0" 23 | # torch.cuda.set_device(0) 24 | 25 | ckpt_path = './ckpt' 26 | 27 | 28 | args = { 29 | 'iter_num': 30000, 30 | 'train_batch_size': 10, 31 | 'last_iter': 0, 32 | 'lr': 1e-3, 33 | 'lr_decay': 0.9, 34 | 'weight_decay': 5e-4, 35 | 'momentum': 0.9, 36 | 'snapshot': '' 37 | } 38 | joint_transform = joint_transforms.Compose([ 39 | joint_transforms.RandomCrop(400), 40 | joint_transforms.RandomHorizontallyFlip(), 41 | joint_transforms.RandomRotate(10) 42 | ]) 43 | img_transform = transforms.Compose([ 44 | transforms.ToTensor(), 45 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 46 | ]) 47 | target_transform = transforms.ToTensor() 48 | 49 | train_set = ImageFolder(duts_train_path, joint_transform, img_transform, target_transform) 50 | train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=12, shuffle=True, drop_last=True) 51 | 52 | criterion = nn.BCEWithLogitsLoss().cuda() 53 | 54 | save_points = range(8000, 30002, 1000) 55 | 56 | def main(): 57 | 58 | exp_name = 'AADFNet' 59 | train(exp_name) 60 | 61 | 62 | def train(exp_name): 63 | 64 | net = AADFNet().cuda().train() 65 | net = nn.DataParallel(net, device_ids=[0, 1]) 66 | 67 | optimizer = optim.SGD([ 68 | {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'], 69 | 'lr': 2 * args['lr']}, 70 | {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'], 71 | 'lr': args['lr'], 'weight_decay': args['weight_decay']} 72 | ], momentum=args['momentum']) 73 | 74 | 75 | if len(args['snapshot']) > 0: 76 | print('training resumes from ' + args['snapshot']) 77 | net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'))) 78 | optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '_optim.pth'))) 79 | optimizer.param_groups[0]['lr'] = 2 * args['lr'] 80 | optimizer.param_groups[1]['lr'] = args['lr'] 81 | 82 | check_mkdir(ckpt_path) 83 | check_mkdir(os.path.join(ckpt_path, exp_name)) 84 | log_path = os.path.join(ckpt_path, exp_name, 
str(datetime.datetime.now()) + '.txt')
85 |     open(log_path, 'w').write(str(args) + '\n\n')
86 |     print 'start to train'
87 | 
88 | 
89 | 
90 |     curr_iter = args['last_iter']
91 |     while True:
92 |         total_loss_record, loss1_record, loss2_record = AvgMeter(), AvgMeter(), AvgMeter()
93 |         loss3_record, loss4_record = AvgMeter(), AvgMeter()
94 |         loss2_2_record, loss3_2_record, loss4_2_record = AvgMeter(), AvgMeter(), AvgMeter()
95 | 
96 |         loss44_record, loss43_record, loss42_record, loss41_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
97 |         loss34_record, loss33_record, loss32_record, loss31_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
98 |         loss24_record, loss23_record, loss22_record, loss21_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
99 |         loss14_record, loss13_record, loss12_record, loss11_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
100 | 
101 |         for i, data in enumerate(train_loader):
102 |             optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (1 - float(curr_iter) / args['iter_num']
103 |                                                                 ) ** args['lr_decay']  # 'poly' decay; biases use 2x base lr
104 |             optimizer.param_groups[1]['lr'] = args['lr'] * (1 - float(curr_iter) / args['iter_num']
105 |                                                             ) ** args['lr_decay']
106 | 
107 |             inputs, labels = data
108 |             batch_size = inputs.size(0)
109 |             inputs = Variable(inputs).cuda()
110 |             labels = Variable(labels).cuda()
111 | 
112 |             optimizer.zero_grad()
113 | 
114 |             outputs4_2, outputs3_2, outputs2_2, outputs1, outputs2, outputs3, outputs4, \
115 |             predict41, predict42, predict43, predict44, \
116 |             predict31, predict32, predict33, predict34, \
117 |             predict21, predict22, predict23, predict24, \
118 |             predict11, predict12, predict13, predict14 = net(inputs)
119 | 
120 |             loss1 = criterion(outputs1, labels)
121 |             loss2 = criterion(outputs2, labels)
122 |             loss3 = criterion(outputs3, labels)
123 |             loss4 = criterion(outputs4, labels)
124 | 
125 |             loss2_2 = criterion(outputs2_2, labels)
126 |             loss3_2 = criterion(outputs3_2, labels)
127 |             loss4_2 = criterion(outputs4_2, labels)
128 | 
129 |             loss44 = criterion(predict44, labels)
130 |             loss43 = criterion(predict43, labels)
131 |             loss42 = criterion(predict42, labels)
132 |             loss41 = criterion(predict41, labels)
133 | 
134 |             loss34 = criterion(predict34, labels)
135 |             loss33 = criterion(predict33, labels)
136 |             loss32 = criterion(predict32, labels)
137 |             loss31 = criterion(predict31, labels)
138 | 
139 |             loss24 = criterion(predict24, labels)
140 |             loss23 = criterion(predict23, labels)
141 |             loss22 = criterion(predict22, labels)
142 |             loss21 = criterion(predict21, labels)
143 | 
144 |             loss14 = criterion(predict14, labels)
145 |             loss13 = criterion(predict13, labels)
146 |             loss12 = criterion(predict12, labels)
147 |             loss11 = criterion(predict11, labels)
148 | 
149 |             total_loss = loss1 + loss2 + loss3 + loss4 + loss2_2 + loss3_2 + loss4_2 \
150 |                 + (loss44 + loss43 + loss42 + loss41) / 10 \
151 |                 + (loss34 + loss33 + loss32 + loss31) / 10 \
152 |                 + (loss24 + loss23 + loss22 + loss21) / 10 \
153 |                 + (loss14 + loss13 + loss12 + loss11) / 10
154 | 
155 |             # total_loss = loss1 + loss2 + loss3 + loss4  # leftover override, now disabled: it silently discarded the deeply supervised side losses computed above
156 | 
157 |             total_loss.backward()
158 |             optimizer.step()
159 | 
160 |             total_loss_record.update(total_loss.item(), batch_size)
161 |             loss1_record.update(loss1.item(), batch_size)
162 |             loss2_record.update(loss2.item(), batch_size)
163 |             loss3_record.update(loss3.item(), batch_size)
164 |             loss4_record.update(loss4.item(), batch_size)
165 | 
166 |             loss2_2_record.update(loss2_2.item(), batch_size)
167 |             loss3_2_record.update(loss3_2.item(), batch_size)
168 |             loss4_2_record.update(loss4_2.item(), batch_size)
169 | 170 | loss44_record.update(loss44.item(), batch_size) 171 | loss43_record.update(loss43.item(), batch_size) 172 | loss42_record.update(loss42.item(), batch_size) 173 | loss41_record.update(loss41.item(), batch_size) 174 | 175 | loss34_record.update(loss34.item(), batch_size) 176 | loss33_record.update(loss33.item(), batch_size) 177 | loss32_record.update(loss32.item(), batch_size) 178 | loss31_record.update(loss31.item(), batch_size) 179 | 180 | loss24_record.update(loss24.item(), batch_size) 181 | loss23_record.update(loss23.item(), batch_size) 182 | loss22_record.update(loss22.item(), batch_size) 183 | loss21_record.update(loss21.item(), batch_size) 184 | 185 | loss14_record.update(loss14.item(), batch_size) 186 | loss13_record.update(loss13.item(), batch_size) 187 | loss12_record.update(loss12.item(), batch_size) 188 | loss11_record.update(loss11.item(), batch_size) 189 | 190 | 191 | curr_iter += 1 192 | 193 | log = '[iter %d], [total loss %.5f], ' \ 194 | '[loss4_2 %.5f], [loss3_2 %.5f], [loss2_2 %.5f], [loss1 %.5f], ' \ 195 | '[loss2 %.5f], [loss3 %.5f], [loss4 %.5f], ' \ 196 | '[loss44 %.5f], [loss43 %.5f], [loss42 %.5f], [loss41 %.5f], ' \ 197 | '[loss34 %.5f], [loss33 %.5f], [loss32 %.5f], [loss31 %.5f], ' \ 198 | '[loss24 %.5f], [loss23 %.5f], [loss22 %.5f], [loss21 %.5f], ' \ 199 | '[loss14 %.5f], [loss13 %.5f], [loss12 %.5f], [loss11 %.5f], ' \ 200 | '[lr %.13f]' % \ 201 | (curr_iter, total_loss_record.avg, 202 | loss4_2_record.avg, loss3_2_record.avg, 203 | loss2_2_record.avg, loss1_record.avg, loss2_record.avg, 204 | loss3_record.avg, loss4_record.avg, 205 | loss44_record.avg, loss43_record.avg, loss42_record.avg, loss41_record.avg, 206 | loss34_record.avg, loss33_record.avg, loss32_record.avg, loss31_record.avg, 207 | loss24_record.avg, loss23_record.avg, loss22_record.avg, loss21_record.avg, 208 | loss14_record.avg, loss13_record.avg, loss12_record.avg, loss11_record.avg, 209 | optimizer.param_groups[1]['lr']) 210 | 211 | print log 212 | open(log_path, 'a').write(log + '\n') 213 | 214 | 215 | if curr_iter == args['iter_num']: 216 | torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter)) 217 | torch.save(optimizer.state_dict(), 218 | os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter)) 219 | return 220 | 221 | 222 | if __name__ == '__main__': 223 | main() 224 | --------------------------------------------------------------------------------
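A note on data layout: `make_dataset` in datasets.py pairs every `*.jpg` image with a same-named `*.png` mask in a single folder, so `duts_train_path` (and each test dataset folder) is expected to look like the sketch below; the filenames are illustrative placeholders:

```
duts_train/
├── 0001.jpg   # RGB image
├── 0001.png   # ground-truth saliency mask (single channel)
├── 0002.jpg
├── 0002.png
└── ...
```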
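The optimizer in train_densenet_attention.py follows the common "poly" schedule: with `lr = 1e-3`, `lr_decay = 0.9`, and `iter_num = 30000`, the weight learning rate at iteration t is

$$\mathrm{lr}(t) = 10^{-3}\left(1 - \frac{t}{30000}\right)^{0.9},$$

while bias parameters (parameter group 0) are trained at twice this rate and without weight decay.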
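For quick sanity-checking outside infer.py, a minimal single-image forward pass might look like the sketch below. This is a sketch, not part of the repository: it assumes the Python 2.7 / PyTorch 0.4.0 environment from the README, and the snapshot name `30000.pth` and the image path `example.jpg` are illustrative placeholders.

```python
# Minimal single-image inference sketch (mirrors infer.py; paths are placeholders).
import torch
from torch import nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from Densenet_attention import AADFNet

img_transform = transforms.Compose([
    transforms.Resize((400, 400)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Checkpoints were saved from a DataParallel-wrapped model, so wrap before loading.
net = nn.DataParallel(AADFNet().cuda(), device_ids=[0])
net.load_state_dict(torch.load('ckpt/AADFNet/30000.pth'))
net.eval()

img = Image.open('example.jpg').convert('RGB')
with torch.no_grad():
    x = img_transform(img).unsqueeze(0).cuda()
    pred = net(x)                                  # eval mode returns the sigmoid saliency map
    w, h = img.size
    pred = F.upsample_bilinear(pred, size=(h, w))  # back to the original resolution
saliency = transforms.ToPILImage()(pred.data.squeeze(0).cpu())
saliency.save('example_saliency.png')
```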
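The scores written by infer.py come from `cal_fmeasure_both` in misc.py, which evaluates the standard weighted F-measure with β² = 0.3 at each of the 256 binarization thresholds:

$$F_\beta = \frac{(1 + \beta^2)\,P\,R}{\beta^2 P + R}, \qquad \beta^2 = 0.3,$$

where P and R are the threshold-wise precision and recall; `max_fmeasure` is the maximum of this score over the thresholds and `mean_fmeasure` is its mean.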