├── DepthBranchDecoder ├── DepthBranchDecoder.py ├── __init__.py └── __pycache__ │ ├── DepthBranchDecoder.cpython-36.pyc │ ├── FlowBranchDecoder.cpython-36.pyc │ └── __init__.cpython-36.pyc ├── DepthBranchEncoder ├── DepthBranchEncoder.py ├── __init__.py └── __pycache__ │ ├── DepthBranchEncoder.cpython-36.pyc │ ├── FlowBranchEncoder.cpython-36.pyc │ └── __init__.cpython-36.pyc ├── ImageBranchDecoder ├── ImageBranchDecoder.py ├── __init__.py └── __pycache__ │ ├── ImageBranchDecoder.cpython-36.pyc │ ├── ImageBranchDecoder_parts.cpython-36.pyc │ └── __init__.cpython-36.pyc ├── ImageBranchEncoder ├── ImageBranchEncoder.py ├── __init__.py └── __pycache__ │ ├── ImageBranchEncoder.cpython-36.pyc │ └── __init__.cpython-36.pyc ├── ImageDepthNet ├── ImageDepthNet.py ├── __init__.py └── __pycache__ │ ├── ImageDepthNet.cpython-36.pyc │ ├── ImageFlowNet.cpython-36.pyc │ ├── ImageFlowNet_parts.cpython-36.pyc │ └── __init__.cpython-36.pyc ├── README.md ├── RGBdDataset_processed └── NLPR │ └── testset │ └── depth │ ├── 10_01-00-59.bmp │ ├── 10_01-02-01.bmp │ ├── 10_01-05-11.bmp │ ├── 10_01-06-39.bmp │ ├── 10_01-10-53.bmp │ ├── 10_01-11-07.bmp │ ├── 10_01-13-37.bmp │ ├── 10_01-16-00.bmp │ ├── 10_01-16-04.bmp │ ├── 10_01-16-21.bmp │ ├── 10_01-17-25.bmp │ ├── 10_01-17-42.bmp │ ├── 10_01-18-10.bmp │ ├── 10_01-18-15.bmp │ ├── 10_01-20-50.bmp │ ├── 10_01-21-12.bmp │ ├── 10_01-21-21.bmp │ ├── 10_01-21-41.bmp │ ├── 10_01-22-41.bmp │ ├── 10_01-30-24.bmp │ ├── 10_02-58-26.bmp │ └── 10_02-58-46.bmp ├── dataset.py ├── finetune_DUT_RGBD.py ├── generate_list.py ├── list ├── test │ └── test_list.txt └── train │ ├── DUT_train_list.txt │ └── train_list.txt ├── model_epoch_loss └── loss.txt ├── output └── NLPR │ └── S2MA.pth │ ├── 1_02-03-44.png │ ├── 1_02-08-35.png │ ├── 1_02-54-18.png │ └── 1_02-59-20.png ├── parameter.py ├── parameter_finetune_DUT_RGBD.py ├── test.py ├── train.py └── transforms.py /DepthBranchDecoder/DepthBranchDecoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import parameter 5 | 6 | 7 | class decoder_module_part1(nn.Module): 8 | def __init__(self, in_channels, out_channels, fusing=True): 9 | super(decoder_module_part1, self).__init__() 10 | if fusing: 11 | self.enc_fea_proc = nn.Sequential( 12 | nn.BatchNorm2d(in_channels, momentum=parameter.bn_momentum), 13 | nn.ReLU(inplace=True), 14 | ) 15 | in_channels = in_channels*2 16 | self.decoding1 = nn.Sequential( 17 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), 18 | ) 19 | 20 | def forward(self, enc_fea, dec_fea=None): 21 | if dec_fea is not None: 22 | enc_fea = self.enc_fea_proc(enc_fea) 23 | if dec_fea.size(2) != enc_fea.size(2): 24 | dec_fea = F.upsample(dec_fea, size=[enc_fea.size(2), enc_fea.size(3)], mode='bilinear', align_corners=True) 25 | enc_fea = torch.cat([enc_fea, dec_fea], dim=1) 26 | output = self.decoding1(enc_fea) 27 | 28 | return output 29 | 30 | 31 | class decoder_module_part2(nn.Module): 32 | def __init__(self, out_channels): 33 | super(decoder_module_part2, self).__init__() 34 | 35 | self.decoding1_resPart = nn.Sequential( 36 | nn.BatchNorm2d(out_channels, momentum=parameter.bn_momentum), 37 | nn.ReLU(inplace=True), 38 | ) 39 | self.decoding2 = nn.Sequential( 40 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), 41 | nn.BatchNorm2d(out_channels, momentum=parameter.bn_momentum), 42 | nn.ReLU(inplace=True), 43 | ) 44 | 45 | def forward(self, enc_fea): 46 | 47 
| output = self.decoding1_resPart(enc_fea) 48 | output = self.decoding2(output) 49 | 50 | return output 51 | 52 | 53 | class DepthBranchDecoder(nn.Module): 54 | def __init__(self): 55 | 56 | super(DepthBranchDecoder, self).__init__() 57 | channels = [64, 128, 256, 512, 512, 512] 58 | 59 | self.decoder6_part1 = decoder_module_part1(channels[5], channels[4], False) 60 | self.decoder6_part2 = decoder_module_part2(channels[4]) 61 | 62 | self.decoder5_part1 = decoder_module_part1(channels[4], channels[3]) 63 | self.decoder5_part2 = decoder_module_part2(channels[3]) 64 | 65 | self.decoder4_part1 = decoder_module_part1(channels[3], channels[2]) 66 | self.decoder4_part2 = decoder_module_part2(channels[2]) 67 | 68 | self.decoder3_part1 = decoder_module_part1(channels[2], channels[1]) 69 | self.decoder3_part2 = decoder_module_part2(channels[1]) 70 | 71 | self.decoder2_part1 = decoder_module_part1(channels[1], channels[0]) 72 | self.decoder2_part2 = decoder_module_part2(channels[0]) 73 | 74 | self.decoder1_part1 = decoder_module_part1(channels[0], channels[0]) 75 | self.decoder1_part2 = decoder_module_part2(channels[0]) 76 | 77 | self.conv_loss6 = nn.Conv2d(in_channels=channels[4], out_channels=1, kernel_size=3, padding=1) 78 | self.conv_loss5 = nn.Conv2d(in_channels=channels[3], out_channels=1, kernel_size=3, padding=1) 79 | self.conv_loss4 = nn.Conv2d(in_channels=channels[2], out_channels=1, kernel_size=3, padding=1) 80 | self.conv_loss3 = nn.Conv2d(in_channels=channels[1], out_channels=1, kernel_size=3, padding=1) 81 | self.conv_loss2 = nn.Conv2d(in_channels=channels[0], out_channels=1, kernel_size=3, padding=1) 82 | self.conv_loss1 = nn.Conv2d(in_channels=channels[0], out_channels=1, kernel_size=3, padding=1) 83 | 84 | def forward(self, enc_fea, AfterDASPP): 85 | 86 | encoder_conv1, encoder_conv2, encoder_conv3, encoder_conv4, encoder_conv5, x7 = enc_fea 87 | 88 | dec_fea_6_part1 = self.decoder6_part1(AfterDASPP) 89 | dec_fea_6_part2 = self.decoder6_part2(dec_fea_6_part1) 90 | 91 | mask6 = self.conv_loss6(dec_fea_6_part2) 92 | 93 | dec_fea_5_part1 = self.decoder5_part1(encoder_conv5, dec_fea_6_part2) 94 | dec_fea_5_part2 = self.decoder5_part2(dec_fea_5_part1) 95 | 96 | mask5 = self.conv_loss5(dec_fea_5_part2) 97 | 98 | dec_fea_4_part1 = self.decoder4_part1(encoder_conv4, dec_fea_5_part2) 99 | dec_fea_4_part2 = self.decoder4_part2(dec_fea_4_part1) 100 | 101 | mask4 = self.conv_loss4(dec_fea_4_part2) 102 | 103 | dec_fea_3_part1 = self.decoder3_part1(encoder_conv3, dec_fea_4_part2) 104 | dec_fea_3_part2 = self.decoder3_part2(dec_fea_3_part1) 105 | 106 | mask3 = self.conv_loss3(dec_fea_3_part2) 107 | 108 | dec_fea_2_part1 = self.decoder2_part1(encoder_conv2, dec_fea_3_part2) 109 | dec_fea_2_part2 = self.decoder2_part2(dec_fea_2_part1) 110 | 111 | mask2 = self.conv_loss2(dec_fea_2_part2) 112 | 113 | dec_fea_1_part1 = self.decoder1_part1(encoder_conv1, dec_fea_2_part2) 114 | dec_fea_1_part2 = self.decoder1_part2(dec_fea_1_part1) 115 | 116 | mask1 = self.conv_loss1(dec_fea_1_part2) 117 | 118 | return mask6, mask5, mask4, mask3, mask2, mask1 119 | -------------------------------------------------------------------------------- /DepthBranchDecoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .DepthBranchDecoder import DepthBranchDecoder -------------------------------------------------------------------------------- /DepthBranchDecoder/__pycache__/DepthBranchDecoder.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/DepthBranchDecoder/__pycache__/DepthBranchDecoder.cpython-36.pyc -------------------------------------------------------------------------------- /DepthBranchDecoder/__pycache__/FlowBranchDecoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/DepthBranchDecoder/__pycache__/FlowBranchDecoder.cpython-36.pyc -------------------------------------------------------------------------------- /DepthBranchDecoder/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/DepthBranchDecoder/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /DepthBranchEncoder/DepthBranchEncoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class DepthBranchEncoder(nn.Module): 5 | def __init__(self, n_channels): 6 | 7 | super(DepthBranchEncoder, self).__init__() 8 | 9 | self.conv1 = nn.Sequential( 10 | nn.Conv2d(n_channels, out_channels=64, kernel_size=3, stride=1, padding=1), 11 | nn.ReLU(inplace=True), 12 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 13 | ) 14 | self.conv2 = nn.Sequential( 15 | nn.ReLU(), 16 | nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True), 17 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 20 | ) 21 | self.conv3 = nn.Sequential( 22 | nn.ReLU(), 23 | nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True), 24 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), 29 | ) 30 | self.conv4 = nn.Sequential( 31 | nn.ReLU(), 32 | nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True), 33 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1), 34 | nn.ReLU(inplace=True), 35 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 38 | ) 39 | self.conv5 = nn.Sequential( 40 | nn.ReLU(), 41 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), 42 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, dilation=2, stride=1, padding=2), 43 | nn.ReLU(inplace=True), 44 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, dilation=2, stride=1, padding=2), 45 | nn.ReLU(inplace=True), 46 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, dilation=2, stride=1, padding=2), 47 | ) 48 | self.fc6 = nn.Sequential( 49 | nn.ReLU(), 50 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), 51 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, dilation=12, padding=12), 52 | nn.ReLU(inplace=True), 53 | ) 54 | 55 | self.dropout = 
nn.Dropout(0.5) 56 | 57 | self.fc7 = nn.Sequential( 58 | nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1), 59 | nn.ReLU(inplace=True), 60 | ) 61 | 62 | def forward(self, x): 63 | out_conv1 = self.conv1(x) 64 | out_conv2 = self.conv2(out_conv1) 65 | out_conv3 = self.conv3(out_conv2) 66 | out_conv4 = self.conv4(out_conv3) 67 | out_conv5 = self.conv5(out_conv4) 68 | x6 = self.fc6(out_conv5) 69 | x7 = self.fc7(self.dropout(x6)) 70 | 71 | return out_conv1, out_conv2, out_conv3, out_conv4, out_conv5, x7 72 | -------------------------------------------------------------------------------- /DepthBranchEncoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .DepthBranchEncoder import DepthBranchEncoder -------------------------------------------------------------------------------- /DepthBranchEncoder/__pycache__/DepthBranchEncoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/DepthBranchEncoder/__pycache__/DepthBranchEncoder.cpython-36.pyc -------------------------------------------------------------------------------- /DepthBranchEncoder/__pycache__/FlowBranchEncoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/DepthBranchEncoder/__pycache__/FlowBranchEncoder.cpython-36.pyc -------------------------------------------------------------------------------- /DepthBranchEncoder/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/DepthBranchEncoder/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ImageBranchDecoder/ImageBranchDecoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import parameter 5 | from DepthBranchDecoder import DepthBranchDecoder 6 | 7 | 8 | class Res_part(nn.Module): 9 | 10 | def __init__(self, in_channels): 11 | super(Res_part, self).__init__() 12 | self.res_bn_relu = nn.Sequential( 13 | nn.BatchNorm2d(in_channels*2), 14 | nn.ReLU(inplace=True), 15 | ) 16 | self.res_conv = nn.Conv2d(in_channels*2, in_channels, kernel_size=3, stride=1, padding=1) 17 | 18 | def forward(self, ImageFea, DepthFea): 19 | 20 | ImageFlow_Fea = torch.cat([ImageFea, DepthFea], dim=1) 21 | ImageFlow_resFea = self.res_bn_relu(ImageFlow_Fea) 22 | ImageFlow_resFea = self.res_conv(ImageFlow_resFea) 23 | 24 | return ImageFea + ImageFlow_resFea 25 | 26 | 27 | class decoder_module(nn.Module): 28 | def __init__(self, in_channels, out_channels, fusing=True): 29 | super(decoder_module, self).__init__() 30 | if fusing: 31 | self.enc_fea_proc = nn.Sequential( 32 | nn.BatchNorm2d(in_channels, momentum=parameter.bn_momentum), 33 | nn.ReLU(inplace=True), 34 | ) 35 | in_channels = in_channels*2 36 | 37 | self.ResPart = Res_part(out_channels) 38 | 39 | self.decoding1 = nn.Sequential( 40 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), 41 | nn.BatchNorm2d(out_channels, momentum=parameter.bn_momentum), 42 | nn.ReLU(inplace=True), 43 | ) 44 | 45 | self.decoding2 = nn.Sequential( 46 | nn.Conv2d(out_channels, out_channels, 
kernel_size=3, stride=1, padding=1), 47 | nn.BatchNorm2d(out_channels, momentum=parameter.bn_momentum), 48 | nn.ReLU(inplace=True), 49 | ) 50 | 51 | def forward(self, enc_fea, depth_fea=None, dec_fea=None): 52 | if (dec_fea is not None) and (depth_fea is not None): 53 | # process encoder feature 54 | enc_fea = self.enc_fea_proc(enc_fea) 55 | if dec_fea.size(2) != enc_fea.size(2): 56 | dec_fea = F.upsample(dec_fea, size=[enc_fea.size(2), enc_fea.size(3)], mode='bilinear', align_corners=True) 57 | enc_fea = torch.cat([enc_fea, dec_fea], dim=1) 58 | 59 | # split conv1/bn/relu to conv1, ResPart, bn, relu 60 | # conv1 61 | output = self.decoding1[0](enc_fea) 62 | 63 | output = self.ResPart(output, depth_fea) 64 | 65 | # bn/relu 66 | output = self.decoding1[1](output) 67 | output = self.decoding1[2](output) 68 | 69 | # conv2 70 | output = self.decoding2(output) 71 | else: 72 | output = self.decoding1(enc_fea) 73 | output = self.decoding2(output) 74 | 75 | return output 76 | 77 | 78 | class ImageBranchDecoder(nn.Module): 79 | def __init__(self): 80 | 81 | super(ImageBranchDecoder, self).__init__() 82 | channels = [64, 128, 256, 512, 512, 512] 83 | 84 | self.decoder6 = decoder_module(channels[5], channels[4], False) 85 | self.decoder5 = decoder_module(channels[4], channels[3]) 86 | self.decoder4 = decoder_module(channels[3], channels[2]) 87 | self.decoder3 = decoder_module(channels[2], channels[1]) 88 | self.decoder2 = decoder_module(channels[1], channels[0]) 89 | self.decoder1 = decoder_module(channels[0], channels[0]) 90 | 91 | self.conv_loss6 = nn.Conv2d(in_channels=channels[4], out_channels=1, kernel_size=3, padding=1) 92 | self.conv_loss5 = nn.Conv2d(in_channels=channels[3], out_channels=1, kernel_size=3, padding=1) 93 | self.conv_loss4 = nn.Conv2d(in_channels=channels[2], out_channels=1, kernel_size=3, padding=1) 94 | self.conv_loss3 = nn.Conv2d(in_channels=channels[1], out_channels=1, kernel_size=3, padding=1) 95 | self.conv_loss2 = nn.Conv2d(in_channels=channels[0], out_channels=1, kernel_size=3, padding=1) 96 | self.conv_loss1 = nn.Conv2d(in_channels=channels[0], out_channels=1, kernel_size=3, padding=1) 97 | 98 | self.DepthBranchDecoder = DepthBranchDecoder() 99 | 100 | def forward(self, image_feas, ImageAfterAtt, depth_feas, DepthAfterAtt): 101 | 102 | encoder_conv1, encoder_conv2, encoder_conv3, encoder_conv4, encoder_conv5, x7 = image_feas 103 | depth_encoder_conv1, depth_encoder_conv2, depth_encoder_conv3, depth_encoder_conv4, depth_encoder_conv5, depth_x7 = depth_feas 104 | 105 | # depth (decoder6) 106 | depth_dec_fea_6_part1 = self.DepthBranchDecoder.decoder6_part1(DepthAfterAtt) 107 | depth_dec_fea_6_part2 = self.DepthBranchDecoder.decoder6_part2(depth_dec_fea_6_part1) 108 | depth_mask6 = self.DepthBranchDecoder.conv_loss6(depth_dec_fea_6_part2) 109 | # image (decoder6) 110 | dec_fea_6 = self.decoder6(ImageAfterAtt) 111 | mask6 = self.conv_loss6(dec_fea_6) 112 | 113 | # depth (decoder5) 114 | depth_dec_fea_5_part1 = self.DepthBranchDecoder.decoder5_part1(depth_encoder_conv5, depth_dec_fea_6_part2) 115 | depth_dec_fea_5_part2 = self.DepthBranchDecoder.decoder5_part2(depth_dec_fea_5_part1) 116 | depth_mask5 = self.DepthBranchDecoder.conv_loss5(depth_dec_fea_5_part2) 117 | # image (decoder5) 118 | dec_fea_5 = self.decoder5(encoder_conv5, depth_dec_fea_5_part1, dec_fea_6) 119 | mask5 = self.conv_loss5(dec_fea_5) 120 | 121 | # depth (decoder4) 122 | depth_dec_fea_4_part1 = self.DepthBranchDecoder.decoder4_part1(depth_encoder_conv4, depth_dec_fea_5_part2) 123 | depth_dec_fea_4_part2 = 
self.DepthBranchDecoder.decoder4_part2(depth_dec_fea_4_part1) 124 | depth_mask4 = self.DepthBranchDecoder.conv_loss4(depth_dec_fea_4_part2) 125 | # image (decoder4) 126 | dec_fea_4 = self.decoder4(encoder_conv4, depth_dec_fea_4_part1, dec_fea_5) 127 | mask4 = self.conv_loss4(dec_fea_4) 128 | 129 | # depth (decoder3) 130 | depth_dec_fea_3_part1 = self.DepthBranchDecoder.decoder3_part1(depth_encoder_conv3, depth_dec_fea_4_part2) 131 | depth_dec_fea_3_part2 = self.DepthBranchDecoder.decoder3_part2(depth_dec_fea_3_part1) 132 | depth_mask3 = self.DepthBranchDecoder.conv_loss3(depth_dec_fea_3_part2) 133 | # image (decoder3) 134 | dec_fea_3 = self.decoder3(encoder_conv3, depth_dec_fea_3_part1, dec_fea_4) 135 | mask3 = self.conv_loss3(dec_fea_3) 136 | 137 | # depth (decoder2) 138 | depth_dec_fea_2_part1 = self.DepthBranchDecoder.decoder2_part1(depth_encoder_conv2, depth_dec_fea_3_part2) 139 | depth_dec_fea_2_part2 = self.DepthBranchDecoder.decoder2_part2(depth_dec_fea_2_part1) 140 | depth_mask2 = self.DepthBranchDecoder.conv_loss2(depth_dec_fea_2_part2) 141 | # image (decoder2) 142 | dec_fea_2 = self.decoder2(encoder_conv2, depth_dec_fea_2_part1, dec_fea_3) 143 | mask2 = self.conv_loss2(dec_fea_2) 144 | 145 | # depth (decoder1) 146 | depth_dec_fea_1_part1 = self.DepthBranchDecoder.decoder1_part1(depth_encoder_conv1, depth_dec_fea_2_part2) 147 | depth_dec_fea_1_part2 = self.DepthBranchDecoder.decoder1_part2(depth_dec_fea_1_part1) 148 | depth_mask1 = self.DepthBranchDecoder.conv_loss1(depth_dec_fea_1_part2) 149 | # image (decoder1) 150 | dec_fea_1 = self.decoder1(encoder_conv1, depth_dec_fea_1_part1, dec_fea_2) 151 | mask1 = self.conv_loss1(dec_fea_1) 152 | 153 | return [mask6, mask5, mask4, mask3, mask2, mask1], [depth_mask6, depth_mask5, depth_mask4, depth_mask3, depth_mask2, depth_mask1] 154 | -------------------------------------------------------------------------------- /ImageBranchDecoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .ImageBranchDecoder import ImageBranchDecoder -------------------------------------------------------------------------------- /ImageBranchDecoder/__pycache__/ImageBranchDecoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/ImageBranchDecoder/__pycache__/ImageBranchDecoder.cpython-36.pyc -------------------------------------------------------------------------------- /ImageBranchDecoder/__pycache__/ImageBranchDecoder_parts.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/ImageBranchDecoder/__pycache__/ImageBranchDecoder_parts.cpython-36.pyc -------------------------------------------------------------------------------- /ImageBranchDecoder/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/ImageBranchDecoder/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ImageBranchEncoder/ImageBranchEncoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class ImageBranchEncoder(nn.Module): 5 | def __init__(self, n_channels): 6 | 7 | 
super(ImageBranchEncoder, self).__init__() 8 | 9 | self.conv1 = nn.Sequential( 10 | nn.Conv2d(n_channels, out_channels=64, kernel_size=3, stride=1, padding=1), 11 | nn.ReLU(inplace=True), 12 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 13 | ) 14 | self.conv2 = nn.Sequential( 15 | nn.ReLU(), 16 | nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True), 17 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 20 | ) 21 | self.conv3 = nn.Sequential( 22 | nn.ReLU(), 23 | nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True), 24 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), 29 | ) 30 | self.conv4 = nn.Sequential( 31 | nn.ReLU(), 32 | nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True), 33 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1), 34 | nn.ReLU(inplace=True), 35 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 38 | ) 39 | self.conv5 = nn.Sequential( 40 | nn.ReLU(), 41 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), 42 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, dilation=2, stride=1, padding=2), 43 | nn.ReLU(inplace=True), 44 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, dilation=2, stride=1, padding=2), 45 | nn.ReLU(inplace=True), 46 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, dilation=2, stride=1, padding=2), 47 | ) 48 | self.fc6 = nn.Sequential( 49 | nn.ReLU(), 50 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), 51 | nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, dilation=12, padding=12), 52 | nn.ReLU(inplace=True), 53 | ) 54 | 55 | self.dropout = nn.Dropout(0.5) 56 | 57 | self.fc7 = nn.Sequential( 58 | nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1), 59 | nn.ReLU(inplace=True), 60 | ) 61 | 62 | def forward(self, x): 63 | out_conv1 = self.conv1(x) 64 | out_conv2 = self.conv2(out_conv1) 65 | out_conv3 = self.conv3(out_conv2) 66 | out_conv4 = self.conv4(out_conv3) 67 | out_conv5 = self.conv5(out_conv4) 68 | x6 = self.fc6(out_conv5) 69 | x7 = self.fc7(self.dropout(x6)) 70 | 71 | return out_conv1, out_conv2, out_conv3, out_conv4, out_conv5, x7 72 | -------------------------------------------------------------------------------- /ImageBranchEncoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .ImageBranchEncoder import ImageBranchEncoder -------------------------------------------------------------------------------- /ImageBranchEncoder/__pycache__/ImageBranchEncoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/ImageBranchEncoder/__pycache__/ImageBranchEncoder.cpython-36.pyc -------------------------------------------------------------------------------- /ImageBranchEncoder/__pycache__/__init__.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/ImageBranchEncoder/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ImageDepthNet/ImageDepthNet.py: -------------------------------------------------------------------------------- 1 | from ImageBranchEncoder import ImageBranchEncoder 2 | from ImageBranchDecoder import ImageBranchDecoder 3 | from DepthBranchEncoder import DepthBranchEncoder 4 | 5 | import torch.nn as nn 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.nn import BatchNorm2d as bn 9 | 10 | 11 | class NonLocalBlock(nn.Module): 12 | """ NonLocalBlock Module""" 13 | def __init__(self, in_channels): 14 | super(NonLocalBlock, self).__init__() 15 | 16 | conv_nd = nn.Conv2d 17 | 18 | self.in_channels = in_channels 19 | self.inter_channels = self.in_channels // 2 20 | 21 | self.ImageAfterASPP_bnRelu = nn.Sequential( 22 | nn.BatchNorm2d(self.in_channels), 23 | nn.ReLU(inplace=True), 24 | ) 25 | 26 | self.DepthAfterASPP_bnRelu = nn.Sequential( 27 | nn.BatchNorm2d(self.in_channels), 28 | nn.ReLU(inplace=True), 29 | ) 30 | 31 | self.R_g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 32 | kernel_size=1, stride=1, padding=0) 33 | self.R_theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 34 | kernel_size=1, stride=1, padding=0) 35 | self.R_phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 36 | kernel_size=1, stride=1, padding=0) 37 | self.R_W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 38 | kernel_size=1, stride=1, padding=0) 39 | 40 | self.F_g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 41 | kernel_size=1, stride=1, padding=0) 42 | self.F_theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 43 | kernel_size=1, stride=1, padding=0) 44 | self.F_phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 45 | kernel_size=1, stride=1, padding=0) 46 | self.F_W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 47 | kernel_size=1, stride=1, padding=0) 48 | 49 | def forward(self, self_fea, mutual_fea, alpha, selfImage): 50 | 51 | if selfImage: 52 | selfNonLocal_fea = self.ImageAfterASPP_bnRelu(self_fea) 53 | mutualNonLocal_fea = self.DepthAfterASPP_bnRelu(mutual_fea) 54 | 55 | batch_size = selfNonLocal_fea.size(0) 56 | 57 | g_x = self.R_g(selfNonLocal_fea).view(batch_size, self.inter_channels, -1) 58 | g_x = g_x.permute(0, 2, 1) 59 | 60 | # using mutual feature to generate attention 61 | theta_x = self.F_theta(mutualNonLocal_fea).view(batch_size, self.inter_channels, -1) 62 | theta_x = theta_x.permute(0, 2, 1) 63 | phi_x = self.F_phi(mutualNonLocal_fea).view(batch_size, self.inter_channels, -1) 64 | f = torch.matmul(theta_x, phi_x) 65 | 66 | # using self feature to generate attention 67 | self_theta_x = self.R_theta(selfNonLocal_fea).view(batch_size, self.inter_channels, -1) 68 | self_theta_x = self_theta_x.permute(0, 2, 1) 69 | self_phi_x = self.R_phi(selfNonLocal_fea).view(batch_size, self.inter_channels, -1) 70 | self_f = torch.matmul(self_theta_x, self_phi_x) 71 | 72 | # add self_f and mutual f 73 | f_div_C = F.softmax(alpha*f + self_f, dim=-1) 74 | 75 | y = torch.matmul(f_div_C, g_x) 76 | y = y.permute(0, 2, 1).contiguous() 77 | y = y.view(batch_size, self.inter_channels, 
*selfNonLocal_fea.size()[2:]) 78 | W_y = self.R_W(y) 79 | z = W_y + self_fea 80 | return z 81 | 82 | else: 83 | selfNonLocal_fea = self.DepthAfterASPP_bnRelu(self_fea) 84 | mutualNonLocal_fea = self.ImageAfterASPP_bnRelu(mutual_fea) 85 | 86 | batch_size = selfNonLocal_fea.size(0) 87 | 88 | g_x = self.F_g(selfNonLocal_fea).view(batch_size, self.inter_channels, -1) 89 | g_x = g_x.permute(0, 2, 1) 90 | 91 | # using mutual feature to generate attention 92 | theta_x = self.R_theta(mutualNonLocal_fea).view(batch_size, self.inter_channels, -1) 93 | theta_x = theta_x.permute(0, 2, 1) 94 | phi_x = self.R_phi(mutualNonLocal_fea).view(batch_size, self.inter_channels, -1) 95 | f = torch.matmul(theta_x, phi_x) 96 | 97 | # using self feature to generate attention 98 | self_theta_x = self.F_theta(selfNonLocal_fea).view(batch_size, self.inter_channels, -1) 99 | self_theta_x = self_theta_x.permute(0, 2, 1) 100 | self_phi_x = self.F_phi(selfNonLocal_fea).view(batch_size, self.inter_channels, -1) 101 | self_f = torch.matmul(self_theta_x, self_phi_x) 102 | 103 | # add self_f and mutual f 104 | f_div_C = F.softmax(alpha*f+self_f, dim=-1) 105 | 106 | y = torch.matmul(f_div_C, g_x) 107 | y = y.permute(0, 2, 1).contiguous() 108 | y = y.view(batch_size, self.inter_channels, *selfNonLocal_fea.size()[2:]) 109 | W_y = self.F_W(y) 110 | z = W_y + self_fea 111 | return z 112 | 113 | 114 | class _DenseAsppBlock(nn.Sequential): 115 | """ ConvNet block for building DenseASPP. """ 116 | 117 | def __init__(self, input_num, num1, num2, dilation_rate): 118 | super(_DenseAsppBlock, self).__init__() 119 | 120 | self.conv1 = nn.Conv2d(in_channels=input_num, out_channels=num1, kernel_size=1) 121 | self.bn1 = bn(num1, momentum=0.0003) 122 | self.relu1 = nn.ReLU(inplace=True) 123 | 124 | self.conv2 = nn.Conv2d(in_channels=num1, out_channels=num2, kernel_size=3, 125 | dilation=dilation_rate, padding=dilation_rate) 126 | self.bn2 = bn(num2, momentum=0.0003) 127 | self.relu2 = nn.ReLU(inplace=True) 128 | 129 | def forward(self, input): 130 | 131 | feature = self.relu1(self.bn1(self.conv1(input))) 132 | feature = self.relu2(self.bn2(self.conv2(feature))) 133 | 134 | return feature 135 | 136 | 137 | class DASPPmodule(nn.Module): 138 | def __init__(self): 139 | super(DASPPmodule, self).__init__() 140 | num_features = 512 141 | d_feature1 = 176 142 | d_feature0 = num_features//2 143 | 144 | self.AvgPool = nn.Sequential( 145 | nn.AvgPool2d([32, 32], [32, 32]), 146 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1), 147 | nn.BatchNorm2d(512), 148 | nn.ReLU(inplace=True), 149 | nn.Upsample(size=32, mode='nearest'), 150 | ) 151 | self.ASPP_2 = _DenseAsppBlock(input_num=num_features, num1=d_feature0, num2=d_feature1, 152 | dilation_rate=2) 153 | 154 | self.ASPP_4 = _DenseAsppBlock(input_num=num_features + d_feature1 * 1, num1=d_feature0, num2=d_feature1, 155 | dilation_rate=4) 156 | 157 | self.ASPP_8 = _DenseAsppBlock(input_num=num_features + d_feature1 * 2, num1=d_feature0, num2=d_feature1, 158 | dilation_rate=8) 159 | 160 | self.afterASPP = nn.Sequential( 161 | nn.Conv2d(in_channels=512*2 + 176*3, out_channels=512, kernel_size=1)) 162 | 163 | def forward(self, encoder_fea): 164 | 165 | imgAvgPool = self.AvgPool(encoder_fea) 166 | 167 | aspp2 = self.ASPP_2(encoder_fea) 168 | feature = torch.cat([aspp2, encoder_fea], dim=1) 169 | 170 | aspp4 = self.ASPP_4(feature) 171 | feature = torch.cat([aspp4, feature], dim=1) 172 | 173 | aspp8 = self.ASPP_8(feature) 174 | feature = torch.cat([aspp8, feature], dim=1) 175 | 176 | asppFea = 
torch.cat([feature, imgAvgPool], dim=1) 177 | AfterASPP = self.afterASPP(asppFea) 178 | 179 | return AfterASPP 180 | 181 | 182 | class ImageDepthNet(nn.Module): 183 | def __init__(self, n_channels): 184 | super(ImageDepthNet, self).__init__() 185 | 186 | # encoder part 187 | self.ImageBranchEncoder = ImageBranchEncoder(n_channels) 188 | self.DepthBranchEncoder = DepthBranchEncoder(n_channels) 189 | 190 | self.ImageBranch_fc7_1 = nn.Sequential( 191 | nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1), 192 | nn.BatchNorm2d(512), 193 | nn.ReLU(inplace=True), 194 | ) 195 | 196 | self.DepthBranch_fc7_1 = nn.Sequential( 197 | nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1), 198 | nn.BatchNorm2d(512), 199 | nn.ReLU(inplace=True), 200 | ) 201 | 202 | self.affinityAttConv = nn.Sequential( 203 | nn.Conv2d(in_channels=1024, out_channels=2, kernel_size=1), 204 | nn.BatchNorm2d(2), 205 | nn.ReLU(inplace=True), 206 | ) 207 | 208 | # DASPP 209 | self.ImageBranch_DASPP = DASPPmodule() 210 | self.DepthBranch_DASPP = DASPPmodule() 211 | 212 | # S2MA module 213 | self.NonLocal = NonLocalBlock(in_channels=512) 214 | 215 | self.image_bn_relu = nn.Sequential( 216 | nn.BatchNorm2d(512), 217 | nn.ReLU(inplace=True)) 218 | 219 | self.depth_bn_relu = nn.Sequential( 220 | nn.BatchNorm2d(512), 221 | nn.ReLU(inplace=True)) 222 | 223 | # decoder part 224 | self.ImageBranchDecoder = ImageBranchDecoder() 225 | 226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | nn.init.xavier_uniform_(m.weight), 229 | nn.init.constant_(m.bias, 0), 230 | 231 | def forward(self, image_Input, depth_Input): 232 | 233 | image_feas = self.ImageBranchEncoder(image_Input) 234 | ImageAfterDASPP = self.ImageBranch_DASPP(self.ImageBranch_fc7_1(image_feas[-1])) 235 | 236 | depth_feas = self.DepthBranchEncoder(depth_Input) 237 | DepthAfterDASPP = self.DepthBranch_DASPP(self.DepthBranch_fc7_1(depth_feas[-1])) 238 | 239 | bs, ch, hei, wei = ImageAfterDASPP.size() 240 | 241 | affinityAtt = F.softmax(self.affinityAttConv(torch.cat([ImageAfterDASPP, DepthAfterDASPP], dim=1))) 242 | alphaD = affinityAtt[:, 0, :, :].reshape([bs, hei * wei, 1]) 243 | alphaR = affinityAtt[:, 1, :, :].reshape([bs, hei * wei, 1]) 244 | 245 | alphaD = alphaD.expand([bs, hei * wei, hei * wei]) 246 | alphaR = alphaR.expand([bs, hei * wei, hei * wei]) 247 | 248 | ImageAfterAtt1 = self.NonLocal(ImageAfterDASPP, DepthAfterDASPP, alphaD, selfImage=True) 249 | DepthAfterAtt1 = self.NonLocal(DepthAfterDASPP, ImageAfterDASPP, alphaR, selfImage=False) 250 | 251 | ImageAfterAtt = self.image_bn_relu(ImageAfterAtt1) 252 | DepthAfterAtt = self.depth_bn_relu(DepthAfterAtt1) 253 | 254 | outputs_image, outputs_depth = self.ImageBranchDecoder(image_feas, ImageAfterAtt, depth_feas, DepthAfterAtt) 255 | return outputs_image, outputs_depth 256 | 257 | def init_parameters(self, pretrain_vgg16_1024): 258 | 259 | rgb_conv_blocks = [self.ImageBranchEncoder.conv1, 260 | self.ImageBranchEncoder.conv2, 261 | self.ImageBranchEncoder.conv3, 262 | self.ImageBranchEncoder.conv4, 263 | self.ImageBranchEncoder.conv5, 264 | self.ImageBranchEncoder.fc6, 265 | self.ImageBranchEncoder.fc7] 266 | 267 | depth_conv_blocks = [self.DepthBranchEncoder.conv1, 268 | self.DepthBranchEncoder.conv2, 269 | self.DepthBranchEncoder.conv3, 270 | self.DepthBranchEncoder.conv4, 271 | self.DepthBranchEncoder.conv5, 272 | self.DepthBranchEncoder.fc6, 273 | self.DepthBranchEncoder.fc7] 274 | 275 | listkey = [['conv1_1', 'conv1_2'], ['conv2_1', 'conv2_2'], ['conv3_1', 'conv3_2', 'conv3_3'], 276 
| ['conv4_1', 'conv4_2', 'conv4_3'], ['conv5_1', 'conv5_2', 'conv5_3'], ['fc6'], ['fc7']] 277 | 278 | for idx, conv_block in enumerate(rgb_conv_blocks): 279 | num_conv = 0 280 | for l2 in conv_block: 281 | if isinstance(l2, nn.Conv2d): 282 | num_conv += 1 283 | l2.weight.data = pretrain_vgg16_1024[str(listkey[idx][num_conv - 1]) + '.weight'] 284 | l2.bias.data = pretrain_vgg16_1024[str(listkey[idx][num_conv - 1]) + '.bias'].squeeze(0).squeeze(0).squeeze(0).squeeze(0) 285 | 286 | for idx, conv_block in enumerate(depth_conv_blocks): 287 | num_conv = 0 288 | for l2 in conv_block: 289 | if isinstance(l2, nn.Conv2d): 290 | num_conv += 1 291 | l2.weight.data = pretrain_vgg16_1024[str(listkey[idx][num_conv - 1]) + '.weight'] 292 | l2.bias.data = pretrain_vgg16_1024[str(listkey[idx][num_conv - 1]) + '.bias'].squeeze(0).squeeze( 293 | 0).squeeze(0).squeeze(0) 294 | return self 295 | -------------------------------------------------------------------------------- /ImageDepthNet/__init__.py: -------------------------------------------------------------------------------- 1 | from .ImageDepthNet import ImageDepthNet 2 | -------------------------------------------------------------------------------- /ImageDepthNet/__pycache__/ImageDepthNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/ImageDepthNet/__pycache__/ImageDepthNet.cpython-36.pyc -------------------------------------------------------------------------------- /ImageDepthNet/__pycache__/ImageFlowNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/ImageDepthNet/__pycache__/ImageFlowNet.cpython-36.pyc -------------------------------------------------------------------------------- /ImageDepthNet/__pycache__/ImageFlowNet_parts.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/ImageDepthNet/__pycache__/ImageFlowNet_parts.cpython-36.pyc -------------------------------------------------------------------------------- /ImageDepthNet/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/ImageDepthNet/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # S2MA 2 | source code for our CVPR 2020 paper “Learning Selective Self-Mutual Attention for RGB-D Saliency Detection” by Nian Liu, Ni Zhang and Junwei Han. 3 | 4 | created by Ni Zhang, email: nnizhang.1995@gmail.com 5 | 6 | ## Usage 7 | 8 | ### Requirement 9 | 1. pytorch 0.4.1 10 | 2. torchvision 0.1.8 11 | 12 | ### Training 13 | 1. 
download the RGB-D datasets [[baidu pan](https://pan.baidu.com/s/1q4g9n_n4X_b4WbrhiFuxOw) fetch code: chdz | [Google drive](https://drive.google.com/drive/folders/1ZKK7Le5veXJVD3DZ8OdrO9CdqL2QOFAl?usp=sharing)] and the pretrained VGG model [[baidu pan](https://pan.baidu.com/s/19cik8v7Ix5YOo7sdEosp9A) fetch code: dyt4 | [Google drive](https://drive.google.com/drive/folders/1ZKK7Le5veXJVD3DZ8OdrO9CdqL2QOFAl?usp=sharing)], then put them in the ./RGBdDataset_processed directory and the ./pretrained_model directory, respectively. 14 | 2. run `python generate_list.py` to generate the image lists. 15 | 3. modify the settings in parameter.py 16 | 4. start training with `python train.py` 17 | 18 | 19 | ### Testing 20 | 1. download our models [[baidu pan](https://pan.baidu.com/s/16hfdk-yE5-sy9B9v6oT1oQ) fetch code: ly9k | [Google drive](https://drive.google.com/drive/folders/1ZKK7Le5veXJVD3DZ8OdrO9CdqL2QOFAl?usp=sharing)] and put them in the ./models directory. After downloading, you will find two models (S2MA.pth and S2MA_DUT.pth). S2MA_DUT.pth is used for testing on the DUT-RGBD dataset and S2MA.pth is used for testing on the remaining datasets. 21 | 2. modify the settings in parameter.py 22 | 3. start testing with `python test.py`; the saliency maps will be generated in the ./output directory. 23 | 24 | Our saliency maps can be downloaded from [[baidu pan](https://pan.baidu.com/s/1G-M18V7taJZb44awqxg4tw) fetch code: frzb | [Google drive](https://drive.google.com/drive/folders/1ZKK7Le5veXJVD3DZ8OdrO9CdqL2QOFAl?usp=sharing)]. 25 | 26 | ## Acknowledgement 27 | We use some open-source code from [Non-local_pytorch](https://github.com/AlexHex7/Non-local_pytorch) and [denseASPP](https://github.com/DeepMotionAIResearch/DenseASPP). Thanks to the authors. 28 | 29 | ## Citing our work 30 | If you think our work is helpful, please cite 31 | ``` 32 | @inproceedings{liu2020S2MA, 33 | title={Learning Selective Self-Mutual Attention for RGB-D Saliency Detection}, 34 | author={Liu, Nian and Zhang, Ni and Han, Junwei}, 35 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 36 | pages={13756--13765}, 37 | year={2020} 38 | } 39 | ``` 40 | -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-00-59.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-00-59.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-02-01.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-02-01.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-05-11.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-05-11.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-06-39.bmp: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-06-39.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-10-53.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-10-53.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-11-07.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-11-07.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-13-37.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-13-37.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-16-00.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-16-00.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-16-04.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-16-04.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-16-21.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-16-21.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-17-25.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-17-25.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-17-42.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-17-42.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-18-10.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-18-10.bmp -------------------------------------------------------------------------------- 
/RGBdDataset_processed/NLPR/testset/depth/10_01-18-15.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-18-15.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-20-50.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-20-50.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-21-12.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-21-12.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-21-21.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-21-21.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-21-41.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-21-41.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-22-41.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-22-41.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_01-30-24.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_01-30-24.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_02-58-26.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_02-58-26.bmp -------------------------------------------------------------------------------- /RGBdDataset_processed/NLPR/testset/depth/10_02-58-46.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/RGBdDataset_processed/NLPR/testset/depth/10_02-58-46.bmp -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch 3 | from torch.utils import data 4 | import transforms as trans 5 | from torchvision import transforms 6 | import 
random 7 | from parameter import * 8 | 9 | 10 | def load_list(file): 11 | 12 | with open(file) as f: 13 | lines = f.read().splitlines() 14 | 15 | files = [] 16 | depths = [] 17 | labels = [] 18 | 19 | for line in lines: 20 | files.append(line.split(' ')[0]) 21 | depths.append(line.split(' ')[1]) 22 | labels.append(line.split(' ')[2]) 23 | 24 | return files, depths, labels 25 | 26 | 27 | def load_test_list(file): 28 | 29 | with open(file) as f: 30 | lines = f.read().splitlines() 31 | 32 | files = [] 33 | depths = [] 34 | for line in lines: 35 | files.append(line.split(' ')[0]) 36 | depths.append(line.split(' ')[1]) 37 | 38 | return files, depths 39 | 40 | 41 | class ImageData(data.Dataset): 42 | def __init__(self, img_root, transform, depth_transform, t_transform, label_32_transform, label_64_transform, label_128_transform, mode): 43 | 44 | if mode == 'train': 45 | self.image_path, self.depth_path, self.label_path = load_list(img_root) 46 | else: 47 | self.image_path, self.depth_path = load_test_list(img_root) 48 | 49 | self.transform = transform 50 | self.depth_transform = depth_transform 51 | self.t_transform = t_transform 52 | self.label_32_transform = label_32_transform 53 | self.label_64_transform = label_64_transform 54 | self.label_128_transform = label_128_transform 55 | self.mode = mode 56 | 57 | def __getitem__(self, item): 58 | fn = self.image_path[item].split('/') 59 | 60 | filename = fn[-1] 61 | image = Image.open(self.image_path[item]).convert('RGB') 62 | image_w, image_h = int(image.size[0]), int(image.size[1]) 63 | depth = Image.open(self.depth_path[item]).convert('L') 64 | 65 | # data augmentation 66 | if self.mode == 'train': 67 | 68 | label = Image.open(self.label_path[item]).convert('L') 69 | random_size = scale_size 70 | 71 | new_img = trans.Scale((random_size, random_size))(image) 72 | new_depth = trans.Scale((random_size, random_size))(depth) 73 | new_label = trans.Scale((random_size, random_size), interpolation=Image.NEAREST)(label) 74 | 75 | # random crop 76 | w, h = new_img.size 77 | if w != img_size and h != img_size: 78 | x1 = random.randint(0, w - img_size) 79 | y1 = random.randint(0, h - img_size) 80 | new_img = new_img.crop((x1, y1, x1 + img_size, y1 + img_size)) 81 | new_depth = new_depth.crop((x1, y1, x1 + img_size, y1 + img_size)) 82 | new_label = new_label.crop((x1, y1, x1 + img_size, y1 + img_size)) 83 | 84 | # random flip 85 | if random.random() < 0.5: 86 | new_img = new_img.transpose(Image.FLIP_LEFT_RIGHT) 87 | new_depth = new_depth.transpose(Image.FLIP_LEFT_RIGHT) 88 | new_label = new_label.transpose(Image.FLIP_LEFT_RIGHT) 89 | 90 | new_img = self.transform(new_img) 91 | new_depth = self.depth_transform(new_depth) 92 | 93 | new_depth = new_depth.expand(3, img_size, img_size) 94 | label_256 = self.t_transform(new_label) 95 | if self.label_32_transform is not None and self.label_64_transform is not None and self.label_128_transform is\ 96 | not None: 97 | label_32 = self.label_32_transform(new_label) 98 | label_64 = self.label_64_transform(new_label) 99 | label_128 = self.label_128_transform(new_label) 100 | return new_img, new_depth, label_256, label_32, label_64, label_128, filename 101 | else: 102 | 103 | image = self.transform(image) 104 | depth = self.depth_transform(depth) 105 | depth = depth.expand(3, img_size, img_size) 106 | 107 | return image, depth, image_w, image_h, self.image_path[item] 108 | 109 | def __len__(self): 110 | return len(self.image_path) 111 | 112 | 113 | def get_loader(img_root, img_size, batch_size, mode='train', 
num_thread=1): 114 | shuffle = False 115 | 116 | mean_bgr = torch.Tensor(3, 256, 256) 117 | mean_bgr[0, :, :] = 104.008 # B 118 | mean_bgr[1, :, :] = 116.669 # G 119 | mean_bgr[2, :, :] = 122.675 # R 120 | 121 | depth_mean_bgr = torch.Tensor(1, 256, 256) 122 | depth_mean_bgr[0, :, :] = 115.8695 123 | 124 | if mode == 'train': 125 | transform = trans.Compose([ 126 | # trans.ToTensor image -> [0,255] 127 | trans.ToTensor_BGR(), 128 | trans.Lambda(lambda x: x - mean_bgr) 129 | ]) 130 | 131 | depth_transform = trans.Compose([ 132 | # trans.ToTensor image -> [0,255] 133 | trans.ToTensor(), 134 | trans.Lambda(lambda x: x - depth_mean_bgr) 135 | ]) 136 | 137 | t_transform = trans.Compose([ 138 | # transform.ToTensor label -> [0,1] 139 | transforms.ToTensor(), 140 | ]) 141 | label_32_transform = trans.Compose([ 142 | trans.Scale((32, 32), interpolation=Image.NEAREST), 143 | transforms.ToTensor(), 144 | ]) 145 | label_64_transform = trans.Compose([ 146 | trans.Scale((64, 64), interpolation=Image.NEAREST), 147 | transforms.ToTensor(), 148 | ]) 149 | label_128_transform = trans.Compose([ 150 | trans.Scale((128, 128), interpolation=Image.NEAREST), 151 | transforms.ToTensor(), 152 | ]) 153 | shuffle = True 154 | else: 155 | transform = trans.Compose([ 156 | trans.Scale((img_size, img_size)), 157 | trans.ToTensor_BGR(), 158 | trans.Lambda(lambda x: x - mean_bgr) 159 | ]) 160 | 161 | depth_transform = trans.Compose([ 162 | trans.Scale((img_size, img_size)), 163 | trans.ToTensor(), 164 | trans.Lambda(lambda x: x - depth_mean_bgr) 165 | ]) 166 | 167 | t_transform = trans.Compose([ 168 | trans.Scale((img_size, img_size), interpolation=Image.NEAREST), 169 | transforms.ToTensor(), 170 | ]) 171 | if mode == 'train': 172 | dataset = ImageData(img_root, transform, depth_transform, t_transform, label_32_transform, label_64_transform, label_128_transform, mode) 173 | else: 174 | dataset = ImageData(img_root, transform, depth_transform, t_transform, label_32_transform=None, label_64_transform=None, label_128_transform=None, mode=mode) 175 | 176 | data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_thread) 177 | return data_loader 178 | 179 | -------------------------------------------------------------------------------- /finetune_DUT_RGBD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.backends.cudnn as cudnn 3 | import torch.nn as nn 4 | from torch import optim 5 | from torch.autograd import Variable 6 | 7 | from dataset import get_loader 8 | import math 9 | from parameter_finetune_DUT_RGBD import * 10 | 11 | from ImageDepthNet import ImageDepthNet 12 | import os 13 | 14 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id 15 | cudnn.benchmark = True 16 | 17 | 18 | def save_loss(save_dir, whole_iter_num, epoch_total_loss, epoch_loss, epoch): 19 | fh = open(save_dir, 'a') 20 | epoch_total_loss = str(epoch_total_loss) 21 | epoch_loss = str(epoch_loss) 22 | fh.write('until_' + str(epoch) + '_run_iter_num' + str(whole_iter_num) + '\n') 23 | fh.write(str(epoch) + '_epoch_total_loss' + epoch_total_loss + '\n') 24 | fh.write(str(epoch) + '_epoch_loss' + epoch_loss + '\n') 25 | fh.write('\n') 26 | fh.close() 27 | 28 | 29 | def adjust_learning_rate(optimizer, decay_rate=.1): 30 | update_lr_group = optimizer.param_groups 31 | for param_group in update_lr_group: 32 | print('before lr: ', param_group['lr']) 33 | param_group['lr'] = param_group['lr'] * decay_rate 34 | print('after lr: ', param_group['lr']) 35 | 
return optimizer 36 | 37 | 38 | def save_lr(save_dir, optimizer): 39 | update_lr_group = optimizer.param_groups[0] 40 | fh = open(save_dir, 'a') 41 | fh.write('encode:update:lr' + str(update_lr_group['lr']) + '\n') 42 | fh.write('decode:update:lr' + str(update_lr_group['lr']) + '\n') 43 | fh.write('\n') 44 | fh.close() 45 | 46 | 47 | def train_net(net): 48 | 49 | train_loader = get_loader(train_dir_img, img_size, batch_size, mode='train', 50 | num_thread=4) 51 | 52 | print(''' 53 | Starting training: 54 | Train steps: {} 55 | Batch size: {} 56 | Learning rate: {} 57 | Training size: {} 58 | '''.format(train_steps, batch_size, lr, len(train_loader.dataset))) 59 | 60 | N_train = len(train_loader) * batch_size 61 | 62 | optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) 63 | 64 | criterion = nn.BCEWithLogitsLoss() 65 | whole_iter_num = 0 66 | iter_num = math.ceil(len(train_loader.dataset) / batch_size) 67 | for epoch in range(epochs): 68 | 69 | print('Starting epoch {}/{}.'.format(epoch + 1, epochs)) 70 | print('epoch:{0}-------lr:{1}'.format(epoch + 1, lr)) 71 | 72 | epoch_total_loss = 0 73 | epoch_loss = 0 74 | 75 | for i, data_batch in enumerate(train_loader): 76 | if (i + 1) > iter_num: break 77 | images, depths, label_256, label_32, label_64, label_128, filename = data_batch 78 | images, depths, label_256 = Variable(images.cuda()), Variable(depths.cuda()), Variable(label_256.cuda()) 79 | label_32, label_64, label_128 = Variable(label_32.cuda()), Variable(label_64.cuda()), \ 80 | Variable(label_128.cuda()) 81 | 82 | outputs_image, outputs_depth = net(images, depths) 83 | for_loss6, for_loss5, for_loss4, for_loss3, for_loss2, for_loss1 = outputs_image 84 | depth_for_loss6, depth_for_loss5, depth_for_loss4, depth_for_loss3, depth_for_loss2, depth_for_loss1 = outputs_depth 85 | 86 | # loss 87 | loss6 = criterion(for_loss6, label_32) 88 | loss5 = criterion(for_loss5, label_32) 89 | loss4 = criterion(for_loss4, label_32) 90 | loss3 = criterion(for_loss3, label_64) 91 | loss2 = criterion(for_loss2, label_128) 92 | loss1 = criterion(for_loss1, label_256) 93 | 94 | img_total_loss = loss_weights[0] * loss1 + loss_weights[1] * loss2 + loss_weights[2] * loss3\ 95 | + loss_weights[3] * loss4 + loss_weights[4] * loss5 + loss_weights[5] * loss6 96 | 97 | # depth loss 98 | 99 | depth_loss6 = criterion(depth_for_loss6, label_32) 100 | depth_loss5 = criterion(depth_for_loss5, label_32) 101 | depth_loss4 = criterion(depth_for_loss4, label_32) 102 | depth_loss3 = criterion(depth_for_loss3, label_64) 103 | depth_loss2 = criterion(depth_for_loss2, label_128) 104 | depth_loss1 = criterion(depth_for_loss1, label_256) 105 | 106 | depth_total_loss = loss_weights[0] * depth_loss1 + loss_weights[1] * depth_loss2 + loss_weights[2] * depth_loss3\ 107 | + loss_weights[3] * depth_loss4 + loss_weights[4] * depth_loss5 + loss_weights[5] * depth_loss6 108 | 109 | total_loss = img_total_loss + depth_total_loss 110 | 111 | epoch_total_loss += total_loss.cpu().data.item() 112 | epoch_loss += loss1.cpu().data.item() 113 | 114 | print('whole_iter_num: {0} --- {1:.4f} --- total_loss: {2:.6f} --- loss: {3:.6f}'.format((whole_iter_num + 1), 115 | (i + 1) * batch_size / N_train, total_loss.item(), loss1.item())) 116 | 117 | optimizer.zero_grad() 118 | 119 | total_loss.backward() 120 | 121 | optimizer.step() 122 | whole_iter_num += 1 123 | 124 | if whole_iter_num == train_steps: 125 | torch.save(net.state_dict(), 126 | save_model_dir + 'iterations{}.pth'.format(train_steps)) 127 | return 128 | 129 | 
if whole_iter_num == stepvalue1 or whole_iter_num == stepvalue2: 130 | optimizer = adjust_learning_rate(optimizer, decay_rate=lr_decay_gamma) 131 | save_lr(save_lossdir, optimizer) 132 | print('have updated lr!!') 133 | 134 | print('Epoch finished ! Loss: {}'.format(epoch_total_loss / iter_num)) 135 | 136 | save_loss(save_lossdir, whole_iter_num, epoch_total_loss / iter_num, epoch_loss/iter_num, epoch+1) 137 | torch.save(net.state_dict(), 138 | save_model_dir + 'MODEL_EPOCH{}.pth'.format(epoch + 1)) 139 | print('Saved') 140 | 141 | 142 | if __name__ == '__main__': 143 | 144 | net = ImageDepthNet(3) 145 | 146 | net.load_state_dict(torch.load(load_model)) 147 | print('load model:', load_model) 148 | 149 | net.train() 150 | net.cuda() 151 | 152 | train_net(net) 153 | 154 | -------------------------------------------------------------------------------- /generate_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | # dataset_dir = 'your rgb-d dataset path' 5 | dataset_dir = '/data/zhangni/Data/RGB-D_Saliency/RGBdDataset_processed' 6 | 7 | 8 | # train 9 | NJU2K_train = True 10 | NLPR_train = True 11 | DUT_RGBD_train = True 12 | 13 | # test 14 | NJU2K_test = True 15 | NLPR_test = True 16 | DUT_RGBD_test = True 17 | RGBD135 = True 18 | LFSD = True 19 | SSD = True 20 | STERE = True 21 | 22 | 23 | if NJU2K_train: 24 | root = dataset_dir + '/NJU2K/trainset' 25 | imgs = os.listdir(os.path.join(root, 'RGB')) 26 | 27 | for img in imgs: 28 | f = open('list/train/NJU2K_NLRP_train_list.txt', 'a') 29 | f.write(root + '/RGB/' + img + ' ' + root + '/depth/' + img.replace('.jpg', '.bmp') + ' ' + root 30 | + '/GT/' + img.replace('.jpg', '.png') + '\n') 31 | 32 | if NLPR_train: 33 | root = dataset_dir + '/NLPR/trainset' 34 | imgs = os.listdir(os.path.join(root, 'RGB')) 35 | 36 | for img in imgs: 37 | f = open('list/train/NJU2K_NLRP_train_list.txt', 'a') 38 | f.write(root + '/RGB/' + img + ' ' + root + '/depth/' + img.replace('.jpg', '.bmp') + ' ' + root 39 | + '/GT/' + img.replace('.jpg', '.png') + '\n') 40 | 41 | 42 | if DUT_RGBD_train: 43 | root = dataset_dir + '/DUT-RGBD/trainset' 44 | imgs = os.listdir(os.path.join(root, 'RGB')) 45 | 46 | for img in imgs: 47 | f = open('list/train/DUT_train_list.txt', 'a') 48 | f.write(root + '/RGB/' + img + ' ' + root + '/depth/' + img.replace('.jpg', '.png') + ' ' + root 49 | + '/GT/' + img.replace('.jpg', '.png') + '\n') 50 | 51 | 52 | if NJU2K_test: 53 | root = dataset_dir + '/NJU2K/testset' 54 | imgs = os.listdir(os.path.join(root, 'RGB')) 55 | 56 | for img in imgs: 57 | f = open('list/test/NJU2K_test_list.txt', 'a') 58 | f.write(root + '/RGB/' + img + ' ' + root + '/depth/' + img.replace('.jpg', '.bmp') + ' ' + root 59 | + '/GT/' + img.replace('.jpg', '.png') + '\n') 60 | 61 | if NLPR_test: 62 | root = dataset_dir + '/NLPR/testset' 63 | imgs = os.listdir(os.path.join(root, 'RGB')) 64 | 65 | for img in imgs: 66 | f = open('list/test/NLRP_test_list.txt', 'a') 67 | f.write(root + '/RGB/' + img + ' ' + root + '/depth/' + img.replace('.jpg', '.bmp') + ' ' + root 68 | + '/GT/' + img.replace('.jpg', '.png') + '\n') 69 | 70 | 71 | if DUT_RGBD_test: 72 | root = dataset_dir + '/DUT-RGBD/testset' 73 | imgs = os.listdir(os.path.join(root, 'RGB')) 74 | 75 | for img in imgs: 76 | f = open('list/test/DUT_RGBD_test_list.txt', 'a') 77 | f.write(root + '/RGB/' + img + ' ' + root + '/depth/' + img.replace('.jpg', '.png') + ' ' + root 78 | + '/GT/' + img.replace('.jpg', '.png') + '\n') 79 | 80 | 81 | if RGBD135: 82 | root = 
dataset_dir + '/RGBD135' 83 | imgs = os.listdir(os.path.join(root, 'RGB')) 84 | 85 | for img in imgs: 86 | f = open('list/test/RGBD135_test_list.txt', 'a') 87 | f.write(root + '/RGB/' + img + ' ' + root + '/depth/' + img.replace('.jpg', '.bmp') + ' ' + root 88 | + '/GT/' + img.replace('.jpg', '.png') + '\n') 89 | 90 | if LFSD: 91 | root = dataset_dir + '/LFSD' 92 | imgs = os.listdir(os.path.join(root, 'RGB')) 93 | 94 | for img in imgs: 95 | f = open('list/test/LFSD_test_list.txt', 'a') 96 | f.write(root + '/RGB/' + img + ' ' + root + '/depth/' + img.replace('.jpg', '.bmp') + ' ' + root 97 | + '/GT/' + img.replace('.jpg', '.png') + '\n') 98 | 99 | if SSD: 100 | root = dataset_dir + '/SSD100' 101 | imgs = os.listdir(os.path.join(root, 'RGB')) 102 | 103 | for img in imgs: 104 | f = open('list/test/SSD_test_list.txt', 'a') 105 | f.write(root + '/RGB/' + img + ' ' + root + '/depth/' + img.replace('.jpg', '.bmp') + ' ' + root 106 | + '/GT/' + img.replace('.jpg', '.png') + '\n') 107 | 108 | if STERE: 109 | root = dataset_dir + '/STERE' 110 | imgs = os.listdir(os.path.join(root, 'GT')) 111 | 112 | for img in imgs: 113 | f = open('list/test/STERE_test_list.txt', 'a') 114 | f.write(root + '/RGB/' + img.replace('.png', '.jpg') + ' ' + root + '/depth/' + img + ' ' + root 115 | + '/GT/' + img + '\n') -------------------------------------------------------------------------------- /model_epoch_loss/loss.txt: -------------------------------------------------------------------------------- 1 | until_1_run_iter_num257 2 | 1_epoch_total_loss2.9147255685078957 3 | 1_epoch_loss0.315449187147942 4 | 5 | until_2_run_iter_num514 6 | 2_epoch_total_loss2.3672789905785585 7 | 2_epoch_loss0.2537469699514979 8 | 9 | until_3_run_iter_num771 10 | 3_epoch_total_loss2.2081247197050993 11 | 3_epoch_loss0.23648921258254738 12 | 13 | until_4_run_iter_num1028 14 | 4_epoch_total_loss2.035446706217087 15 | 4_epoch_loss0.2171872167445806 16 | 17 | until_5_run_iter_num1285 18 | 5_epoch_total_loss1.9525965127499652 19 | 5_epoch_loss0.20640648840582324 20 | 21 | until_6_run_iter_num1542 22 | 6_epoch_total_loss1.8969010675927545 23 | 6_epoch_loss0.20080818424312985 24 | 25 | until_7_run_iter_num1799 26 | 7_epoch_total_loss1.903412134962787 27 | 7_epoch_loss0.20123525088515262 28 | 29 | until_8_run_iter_num2056 30 | 8_epoch_total_loss1.789656961010588 31 | 8_epoch_loss0.18779330158511953 32 | 33 | until_9_run_iter_num2313 34 | 9_epoch_total_loss1.7295767590693463 35 | 9_epoch_loss0.18207231346619268 36 | 37 | until_10_run_iter_num2570 38 | 10_epoch_total_loss1.709370740656723 39 | 10_epoch_loss0.17973679005519888 40 | 41 | until_11_run_iter_num2827 42 | 11_epoch_total_loss1.655074855232981 43 | 11_epoch_loss0.17493339662769888 44 | 45 | until_12_run_iter_num3084 46 | 12_epoch_total_loss1.6083429008142494 47 | 12_epoch_loss0.16876627226515967 48 | 49 | until_13_run_iter_num3341 50 | 13_epoch_total_loss1.6429965196880385 51 | 13_epoch_loss0.17326313165897061 52 | 53 | until_14_run_iter_num3598 54 | 14_epoch_total_loss1.5823060007410754 55 | 14_epoch_loss0.16589014451037584 56 | 57 | until_15_run_iter_num3855 58 | 15_epoch_total_loss1.6396020799295448 59 | 15_epoch_loss0.17362643897185528 60 | 61 | until_16_run_iter_num4112 62 | 16_epoch_total_loss1.5857775981324191 63 | 16_epoch_loss0.1661415981991282 64 | 65 | until_17_run_iter_num4369 66 | 17_epoch_total_loss1.5108513811219064 67 | 17_epoch_loss0.15883862588134257 68 | 69 | until_18_run_iter_num4626 70 | 18_epoch_total_loss1.5025985475180215 71 | 18_epoch_loss0.15660575542468505 72 
| 73 | until_19_run_iter_num4883 74 | 19_epoch_total_loss1.456015890906293 75 | 19_epoch_loss0.1513756904243727 76 | 77 | until_20_run_iter_num5140 78 | 20_epoch_total_loss1.4264661442444946 79 | 20_epoch_loss0.14774624288719915 80 | 81 | until_21_run_iter_num5397 82 | 21_epoch_total_loss1.4134679174608757 83 | 21_epoch_loss0.14611534713423205 84 | 85 | until_22_run_iter_num5654 86 | 22_epoch_total_loss1.361613309568932 87 | 22_epoch_loss0.13995363202366384 88 | 89 | until_23_run_iter_num5911 90 | 23_epoch_total_loss1.3814432993241321 91 | 23_epoch_loss0.142436565925523 92 | 93 | until_24_run_iter_num6168 94 | 24_epoch_total_loss1.3931302337794915 95 | 24_epoch_loss0.14455866717990734 96 | 97 | until_25_run_iter_num6425 98 | 25_epoch_total_loss1.40492737965825 99 | 25_epoch_loss0.1440146196966969 100 | 101 | until_26_run_iter_num6682 102 | 26_epoch_total_loss1.3753630602406157 103 | 26_epoch_loss0.14161113551038712 104 | 105 | until_27_run_iter_num6939 106 | 27_epoch_total_loss1.2652437132620162 107 | 27_epoch_loss0.12959771442448118 108 | 109 | until_28_run_iter_num7196 110 | 28_epoch_total_loss1.3094334402900725 111 | 28_epoch_loss0.13394566996436175 112 | 113 | until_29_run_iter_num7453 114 | 29_epoch_total_loss1.282642172236387 115 | 29_epoch_loss0.13152918081861062 116 | 117 | until_30_run_iter_num7710 118 | 30_epoch_total_loss1.504320443604243 119 | 30_epoch_loss0.15680220121597502 120 | 121 | until_31_run_iter_num7967 122 | 31_epoch_total_loss1.2424432192331158 123 | 31_epoch_loss0.12624149701416723 124 | 125 | until_32_run_iter_num8224 126 | 32_epoch_total_loss1.2452753865765227 127 | 32_epoch_loss0.12546248239707855 128 | 129 | until_33_run_iter_num8481 130 | 33_epoch_total_loss1.270309661844825 131 | 33_epoch_loss0.12728911232391685 132 | 133 | until_34_run_iter_num8738 134 | 34_epoch_total_loss1.2159730858144129 135 | 34_epoch_loss0.12211641406908573 136 | 137 | until_35_run_iter_num8995 138 | 35_epoch_total_loss1.1736004027875018 139 | 35_epoch_loss0.11635396527380099 140 | 141 | until_36_run_iter_num9252 142 | 36_epoch_total_loss1.21599005562786 143 | 36_epoch_loss0.12250238355495587 144 | 145 | until_37_run_iter_num9509 146 | 37_epoch_total_loss1.1510989443337407 147 | 37_epoch_loss0.11473687917921793 148 | 149 | until_38_run_iter_num9766 150 | 38_epoch_total_loss1.117409409484047 151 | 38_epoch_loss0.1108188926013991 152 | 153 | until_39_run_iter_num10023 154 | 39_epoch_total_loss1.1681004572471292 155 | 39_epoch_loss0.11763532507570337 156 | 157 | until_40_run_iter_num10280 158 | 40_epoch_total_loss1.1694527045762029 159 | 40_epoch_loss0.11756650835333632 160 | 161 | until_41_run_iter_num10537 162 | 41_epoch_total_loss1.2503190131039008 163 | 41_epoch_loss0.12518644124153522 164 | 165 | until_42_run_iter_num10794 166 | 42_epoch_total_loss1.1561251325598023 167 | 42_epoch_loss0.11672705224921731 168 | 169 | until_43_run_iter_num11051 170 | 43_epoch_total_loss1.1944208133545366 171 | 43_epoch_loss0.12068063156237166 172 | 173 | until_44_run_iter_num11308 174 | 44_epoch_total_loss1.088732536084921 175 | 44_epoch_loss0.10707175698319761 176 | 177 | until_45_run_iter_num11565 178 | 45_epoch_total_loss1.0972721162705106 179 | 45_epoch_loss0.1087632247533084 180 | 181 | until_46_run_iter_num11822 182 | 46_epoch_total_loss1.1076280652085166 183 | 46_epoch_loss0.10931139816222024 184 | 185 | until_47_run_iter_num12079 186 | 47_epoch_total_loss1.1115526846874548 187 | 47_epoch_loss0.11111226561932248 188 | 189 | until_48_run_iter_num12336 190 | 48_epoch_total_loss1.0578289484235563 
191 | 48_epoch_loss0.10394000006904862 192 | 193 | until_49_run_iter_num12593 194 | 49_epoch_total_loss1.034591721537512 195 | 49_epoch_loss0.10160712302021016 196 | 197 | until_50_run_iter_num12850 198 | 50_epoch_total_loss1.0637416328205673 199 | 50_epoch_loss0.10471972894169941 200 | 201 | until_51_run_iter_num13107 202 | 51_epoch_total_loss0.9864993151059874 203 | 51_epoch_loss0.09567699742282411 204 | 205 | until_52_run_iter_num13364 206 | 52_epoch_total_loss1.128371318258664 207 | 52_epoch_loss0.11329957715993029 208 | 209 | until_53_run_iter_num13621 210 | 53_epoch_total_loss1.0044042179324748 211 | 53_epoch_loss0.09740827621797875 212 | 213 | until_54_run_iter_num13878 214 | 54_epoch_total_loss0.9527588265183371 215 | 54_epoch_loss0.09227437552079153 216 | 217 | until_55_run_iter_num14135 218 | 55_epoch_total_loss1.067799113140032 219 | 55_epoch_loss0.10549948018058729 220 | 221 | until_56_run_iter_num14392 222 | 56_epoch_total_loss1.0037250494446737 223 | 56_epoch_loss0.09738325593623438 224 | 225 | until_57_run_iter_num14649 226 | 57_epoch_total_loss1.0048139448991547 227 | 57_epoch_loss0.09810962346799179 228 | 229 | until_58_run_iter_num14906 230 | 58_epoch_total_loss1.086166602512278 231 | 58_epoch_loss0.10866496922151124 232 | 233 | until_59_run_iter_num15163 234 | 59_epoch_total_loss1.1080880610395498 235 | 59_epoch_loss0.10998567686586529 236 | 237 | until_60_run_iter_num15420 238 | 60_epoch_total_loss0.9847406962966176 239 | 60_epoch_loss0.09515144283147638 240 | 241 | until_61_run_iter_num15677 242 | 61_epoch_total_loss0.8655691803197453 243 | 61_epoch_loss0.08108170419670496 244 | 245 | until_62_run_iter_num15934 246 | 62_epoch_total_loss0.9167658494140387 247 | 62_epoch_loss0.08843703966602277 248 | 249 | until_63_run_iter_num16191 250 | 63_epoch_total_loss0.9811931890272445 251 | 63_epoch_loss0.09570551518684231 252 | 253 | until_64_run_iter_num16448 254 | 64_epoch_total_loss0.9523202685298623 255 | 64_epoch_loss0.0914396683573143 256 | 257 | until_65_run_iter_num16705 258 | 65_epoch_total_loss0.9092262264355612 259 | 65_epoch_loss0.08711089811320434 260 | 261 | until_66_run_iter_num16962 262 | 66_epoch_total_loss0.9030005573530605 263 | 66_epoch_loss0.08530807166074036 264 | 265 | until_67_run_iter_num17219 266 | 67_epoch_total_loss0.9949567871103027 267 | 67_epoch_loss0.09793605567735225 268 | 269 | until_68_run_iter_num17476 270 | 68_epoch_total_loss0.8997945252095679 271 | 68_epoch_loss0.085144241902151 272 | 273 | until_69_run_iter_num17733 274 | 69_epoch_total_loss0.8809142460618965 275 | 69_epoch_loss0.08259011528922194 276 | 277 | until_70_run_iter_num17990 278 | 70_epoch_total_loss0.8758771908886238 279 | 70_epoch_loss0.08297502832190072 280 | 281 | until_71_run_iter_num18247 282 | 71_epoch_total_loss0.8161530900558145 283 | 71_epoch_loss0.0760208034329841 284 | 285 | until_72_run_iter_num18504 286 | 72_epoch_total_loss0.866171500214343 287 | 72_epoch_loss0.081990071887636 288 | 289 | until_73_run_iter_num18761 290 | 73_epoch_total_loss0.9889248210632383 291 | 73_epoch_loss0.0960127595841073 292 | 293 | until_74_run_iter_num19018 294 | 74_epoch_total_loss0.8465735300041822 295 | 74_epoch_loss0.07819042175805291 296 | 297 | until_75_run_iter_num19275 298 | 75_epoch_total_loss0.8944964356691457 299 | 75_epoch_loss0.08484508948770246 300 | 301 | until_76_run_iter_num19532 302 | 76_epoch_total_loss1.0538172172201283 303 | 76_epoch_loss0.10389163353795672 304 | 305 | until_77_run_iter_num19789 306 | 77_epoch_total_loss0.8445532666338094 307 | 
77_epoch_loss0.07845945701959764 308 | 309 | encode:update:lr0.001 310 | decode:update:lr0.001 311 | 312 | until_78_run_iter_num20046 313 | 78_epoch_total_loss0.7819390322447751 314 | 78_epoch_loss0.07185748056418005 315 | 316 | until_79_run_iter_num20303 317 | 79_epoch_total_loss0.6370890648448514 318 | 79_epoch_loss0.05380115555128235 319 | 320 | until_80_run_iter_num20560 321 | 80_epoch_total_loss0.608880815454958 322 | 80_epoch_loss0.05057510671274671 323 | 324 | until_81_run_iter_num20817 325 | 81_epoch_total_loss0.5912573366545517 326 | 81_epoch_loss0.048611669710580015 327 | 328 | until_82_run_iter_num21074 329 | 82_epoch_total_loss0.5855925998103294 330 | 82_epoch_loss0.047767939647579936 331 | 332 | until_83_run_iter_num21331 333 | 83_epoch_total_loss0.5748265224440089 334 | 83_epoch_loss0.046724781843358903 335 | 336 | until_84_run_iter_num21588 337 | 84_epoch_total_loss0.5709986604373278 338 | 84_epoch_loss0.046138237574742925 339 | 340 | until_85_run_iter_num21845 341 | 85_epoch_total_loss0.5661133538192348 342 | 85_epoch_loss0.04595781586600417 343 | 344 | until_86_run_iter_num22102 345 | 86_epoch_total_loss0.5539664819778636 346 | 86_epoch_loss0.044184205654994056 347 | 348 | until_87_run_iter_num22359 349 | 87_epoch_total_loss0.5499959524851364 350 | 87_epoch_loss0.043900862082491124 351 | 352 | until_88_run_iter_num22616 353 | 88_epoch_total_loss0.5494662551100616 354 | 88_epoch_loss0.043509423138450555 355 | 356 | until_89_run_iter_num22873 357 | 89_epoch_total_loss0.5407229823129186 358 | 89_epoch_loss0.042557877825449876 359 | 360 | until_90_run_iter_num23130 361 | 90_epoch_total_loss0.5394098395735373 362 | 90_epoch_loss0.042500946272150085 363 | 364 | until_91_run_iter_num23387 365 | 91_epoch_total_loss0.5281071924977729 366 | 91_epoch_loss0.04132863341428426 367 | 368 | until_92_run_iter_num23644 369 | 92_epoch_total_loss0.5289088325277841 370 | 92_epoch_loss0.041061711504852494 371 | 372 | until_93_run_iter_num23901 373 | 93_epoch_total_loss0.5272036815199871 374 | 93_epoch_loss0.04113331990106328 375 | 376 | until_94_run_iter_num24158 377 | 94_epoch_total_loss0.5178754832378157 378 | 94_epoch_loss0.04018079099908529 379 | 380 | until_95_run_iter_num24415 381 | 95_epoch_total_loss0.5171937797551953 382 | 95_epoch_loss0.039843805619888734 383 | 384 | until_96_run_iter_num24672 385 | 96_epoch_total_loss0.5158559859726679 386 | 96_epoch_loss0.03961050346609683 387 | 388 | until_97_run_iter_num24929 389 | 97_epoch_total_loss0.5133076916061022 390 | 97_epoch_loss0.03971901201146586 391 | 392 | until_98_run_iter_num25186 393 | 98_epoch_total_loss0.5105405887625096 394 | 98_epoch_loss0.0391363617040014 395 | 396 | until_99_run_iter_num25443 397 | 99_epoch_total_loss0.5038399869713802 398 | 99_epoch_loss0.03819878200786587 399 | 400 | until_100_run_iter_num25700 401 | 100_epoch_total_loss0.5010551078542197 402 | 100_epoch_loss0.03794905020772019 403 | 404 | until_101_run_iter_num25957 405 | 101_epoch_total_loss0.4979936431817971 406 | 101_epoch_loss0.03744724006953754 407 | 408 | until_102_run_iter_num26214 409 | 102_epoch_total_loss0.49274639408412146 410 | 102_epoch_loss0.03692865161522353 411 | 412 | until_103_run_iter_num26471 413 | 103_epoch_total_loss0.4890569598062493 414 | 103_epoch_loss0.03686086322880325 415 | 416 | until_104_run_iter_num26728 417 | 104_epoch_total_loss0.4893108544298647 418 | 104_epoch_loss0.03668111516343944 419 | 420 | until_105_run_iter_num26985 421 | 105_epoch_total_loss0.4833360931868683 422 | 105_epoch_loss0.0358691472915302 423 | 424 | 
until_106_run_iter_num27242 425 | 106_epoch_total_loss0.485116560055588 426 | 106_epoch_loss0.036231909559655516 427 | 428 | until_107_run_iter_num27499 429 | 107_epoch_total_loss0.48045009548562045 430 | 107_epoch_loss0.03573642703722655 431 | 432 | until_108_run_iter_num27756 433 | 108_epoch_total_loss0.4791310576265424 434 | 108_epoch_loss0.0355711895465155 435 | 436 | until_109_run_iter_num28013 437 | 109_epoch_total_loss0.476651120974396 438 | 109_epoch_loss0.035509517213423894 439 | 440 | until_110_run_iter_num28270 441 | 110_epoch_total_loss0.47401201168386853 442 | 110_epoch_loss0.035145227163160354 443 | 444 | until_111_run_iter_num28527 445 | 111_epoch_total_loss0.4767846197817576 446 | 111_epoch_loss0.035144146336133844 447 | 448 | until_112_run_iter_num28784 449 | 112_epoch_total_loss0.4687355341034641 450 | 112_epoch_loss0.03458756940794478 451 | 452 | until_113_run_iter_num29041 453 | 113_epoch_total_loss0.4679742344158633 454 | 113_epoch_loss0.03422354329522366 455 | 456 | until_114_run_iter_num29298 457 | 114_epoch_total_loss0.46285581536562065 458 | 114_epoch_loss0.0337373405071664 459 | 460 | until_115_run_iter_num29555 461 | 115_epoch_total_loss0.4615465828764763 462 | 115_epoch_loss0.03343529821688795 463 | 464 | until_116_run_iter_num29812 465 | 116_epoch_total_loss0.4572947304768321 466 | 116_epoch_loss0.03315276170606047 467 | 468 | encode:update:lr0.0001 469 | decode:update:lr0.0001 470 | 471 | until_117_run_iter_num30069 472 | 117_epoch_total_loss0.4627621197630923 473 | 117_epoch_loss0.03360621864882077 474 | 475 | until_118_run_iter_num30326 476 | 118_epoch_total_loss0.4552246324051215 477 | 118_epoch_loss0.03256899915664577 478 | 479 | until_119_run_iter_num30583 480 | 119_epoch_total_loss0.45596304037227703 481 | 119_epoch_loss0.032713232989557056 482 | 483 | until_120_run_iter_num30840 484 | 120_epoch_total_loss0.45265980731652405 485 | 120_epoch_loss0.03240104200833038 486 | 487 | until_121_run_iter_num31097 488 | 121_epoch_total_loss0.4504545003987472 489 | 121_epoch_loss0.03222049908632782 490 | 491 | until_122_run_iter_num31354 492 | 122_epoch_total_loss0.4503245900461182 493 | 122_epoch_loss0.03210566632797514 494 | 495 | until_123_run_iter_num31611 496 | 123_epoch_total_loss0.4521118187370931 497 | 123_epoch_loss0.03242068726511433 498 | 499 | until_124_run_iter_num31868 500 | 124_epoch_total_loss0.45128802375338883 501 | 124_epoch_loss0.03238379156963487 502 | 503 | until_125_run_iter_num32125 504 | 125_epoch_total_loss0.45088195626837735 505 | 125_epoch_loss0.03226691228609373 506 | 507 | until_126_run_iter_num32382 508 | 126_epoch_total_loss0.4488588214964254 509 | 126_epoch_loss0.03216906836988397 510 | 511 | until_127_run_iter_num32639 512 | 127_epoch_total_loss0.45389116245253075 513 | 127_epoch_loss0.032103816156025526 514 | 515 | until_128_run_iter_num32896 516 | 128_epoch_total_loss0.4515056067403652 517 | 128_epoch_loss0.032224321281283747 518 | 519 | until_129_run_iter_num33153 520 | 129_epoch_total_loss0.4475089895470133 521 | 129_epoch_loss0.032081622989116244 522 | 523 | until_130_run_iter_num33410 524 | 130_epoch_total_loss0.44921680797862634 525 | 130_epoch_loss0.03177298748220683 526 | 527 | until_131_run_iter_num33667 528 | 131_epoch_total_loss0.4505925745003882 529 | 131_epoch_loss0.03193292310207734 530 | 531 | until_132_run_iter_num33924 532 | 132_epoch_total_loss0.4499709574860821 533 | 132_epoch_loss0.032228382739169587 534 | 535 | until_133_run_iter_num34181 536 | 133_epoch_total_loss0.4471386437402161 537 | 
133_epoch_loss0.03175366164732304 538 | 539 | until_134_run_iter_num34438 540 | 134_epoch_total_loss0.45070924891108205 541 | 134_epoch_loss0.032128087828363426 542 | 543 | until_135_run_iter_num34695 544 | 135_epoch_total_loss0.4468509164069877 545 | 135_epoch_loss0.03165623247043632 546 | 547 | until_136_run_iter_num34952 548 | 136_epoch_total_loss0.4570802058807143 549 | 136_epoch_loss0.032946177026264165 550 | 551 | until_137_run_iter_num35209 552 | 137_epoch_total_loss0.4471096976257948 553 | 137_epoch_loss0.03172565458871511 554 | 555 | until_138_run_iter_num35466 556 | 138_epoch_total_loss0.443956793804113 557 | 138_epoch_loss0.03151714502098148 558 | 559 | until_139_run_iter_num35723 560 | 139_epoch_total_loss0.44795521084900497 561 | 139_epoch_loss0.03178354841204005 562 | 563 | until_140_run_iter_num35980 564 | 140_epoch_total_loss0.44874818941962397 565 | 140_epoch_loss0.03178112199783673 566 | 567 | until_141_run_iter_num36237 568 | 141_epoch_total_loss0.44878244429015 569 | 141_epoch_loss0.0320592653586127 570 | 571 | until_142_run_iter_num36494 572 | 142_epoch_total_loss0.4476344460759181 573 | 142_epoch_loss0.03180497993827446 574 | 575 | until_143_run_iter_num36751 576 | 143_epoch_total_loss0.44871998244918276 577 | 143_epoch_loss0.03213348399470868 578 | 579 | until_144_run_iter_num37008 580 | 144_epoch_total_loss0.4455173012115612 581 | 144_epoch_loss0.031711736944022115 582 | 583 | until_145_run_iter_num37265 584 | 145_epoch_total_loss0.4442329010147065 585 | 145_epoch_loss0.031471870310517604 586 | 587 | until_146_run_iter_num37522 588 | 146_epoch_total_loss0.44660616506630346 589 | 146_epoch_loss0.03187825586574319 590 | 591 | until_147_run_iter_num37779 592 | 147_epoch_total_loss0.4463065349522268 593 | 147_epoch_loss0.03171071094181751 594 | 595 | until_148_run_iter_num38036 596 | 148_epoch_total_loss0.4443723172645161 597 | 148_epoch_loss0.0314907329338186 598 | 599 | until_149_run_iter_num38293 600 | 149_epoch_total_loss0.44571469289319526 601 | 149_epoch_loss0.031432333800956204 602 | 603 | until_150_run_iter_num38550 604 | 150_epoch_total_loss0.44427788930180473 605 | 150_epoch_loss0.03132501293466597 606 | 607 | until_151_run_iter_num38807 608 | 151_epoch_total_loss0.44324175932296056 609 | 151_epoch_loss0.03141437526719811 610 | 611 | until_152_run_iter_num39064 612 | 152_epoch_total_loss0.44463205731796385 613 | 152_epoch_loss0.031407727863538126 614 | 615 | until_153_run_iter_num39321 616 | 153_epoch_total_loss0.4475186238724898 617 | 153_epoch_loss0.03187719220114357 618 | 619 | until_154_run_iter_num39578 620 | 154_epoch_total_loss0.4444856185625499 621 | 154_epoch_loss0.031427699728263724 622 | 623 | until_155_run_iter_num39835 624 | 155_epoch_total_loss0.4427632737600386 625 | 155_epoch_loss0.03137549389052252 626 | 627 | -------------------------------------------------------------------------------- /output/NLPR/S2MA.pth/1_02-03-44.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/output/NLPR/S2MA.pth/1_02-03-44.png -------------------------------------------------------------------------------- /output/NLPR/S2MA.pth/1_02-08-35.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/output/NLPR/S2MA.pth/1_02-08-35.png -------------------------------------------------------------------------------- 
/output/NLPR/S2MA.pth/1_02-54-18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/output/NLPR/S2MA.pth/1_02-54-18.png -------------------------------------------------------------------------------- /output/NLPR/S2MA.pth/1_02-59-20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnizhang/S2MA/f94fceede09d644f285c271b1d8d41e384e0f8ed/output/NLPR/S2MA.pth/1_02-59-20.png -------------------------------------------------------------------------------- /parameter.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # train.py 4 | gpu_id = "5" 5 | img_size = 256 6 | scale_size = 288 7 | batch_size = 8 8 | lr = 0.01 9 | epochs = 200 10 | train_steps = 40000 11 | lr_decay_gamma = 0.1 12 | stepvalue1 = 20000 13 | stepvalue2 = 30000 14 | loss_weights = [1, 0.8, 0.8, 0.5, 0.5, 0.5] 15 | bn_momentum = 0.001 16 | 17 | load_vgg_model = './pretrained_model/vgg16_20M.caffemodel.pth' 18 | 19 | train_dir_img = 'list/train/NJU2K_NLRP_train_list.txt' 20 | save_lossdir = './model_epoch_loss/loss.txt' 21 | save_model_dir = './models/' 22 | if not os.path.exists(save_model_dir): 23 | os.makedirs(save_model_dir) 24 | 25 | 26 | # test.py 27 | 28 | test_lists = ['list/test/NLRP_test_list.txt', 'list/test/NJU2K_test_list.txt', 29 | 'list/test/STERE_test_list.txt', 'list/test/SSD_test_list.txt', 30 | 'list/test/RGBD135_test_list.txt', 'list/test/LFSD_test_list.txt'] 31 | test_model = 'S2MA.pth' 32 | 33 | 34 | # test on DUT-RGBD dataset 35 | # test_lists = ['list/test/DUT_RGBD_test_list.txt'] 36 | # test_model = 'S2MA_DUT.pth' 37 | 38 | test_model_dir = save_model_dir + test_model 39 | save_test_path_root = './output/' 40 | 41 | 42 | -------------------------------------------------------------------------------- /parameter_finetune_DUT_RGBD.py: -------------------------------------------------------------------------------- 1 | 2 | # train.py 3 | gpu_id = "8" 4 | img_size = 256 5 | scale_size = 288 6 | batch_size = 8 7 | lr = 0.001 8 | epochs = 200 9 | train_steps = 40000 10 | lr_decay_gamma = 0.1 11 | stepvalue1 = 20000 12 | stepvalue2 = 30000 13 | loss_weights = [1, 0.8, 0.8, 0.5, 0.5, 0.5] 14 | bn_momentum = 0.001 15 | 16 | load_model = 'models/S2MA.pth' 17 | 18 | train_dir_img = 'list/train/DUT_train_list.txt' 19 | save_lossdir = './model_epoch_loss/loss.txt' 20 | save_model_dir = './models/' 21 | 22 | 23 | # test.py 24 | 25 | # test_lists = ['list/test/NLRP_test_list.txt', 'list/test/NJU2K_test_list.txt', 26 | # 'list/test/STERE_test_list.txt', 'list/test/SSD_test_list.txt', 27 | # 'list/test/RGBD135_test_list.txt', 'list/test/LFSD_test_list.txt'] 28 | # test_model = 'S2MA.pth' 29 | 30 | 31 | # test on DUT-RGBD dataset 32 | test_lists = ['list/test/DUT_RGBD_test_list.txt'] 33 | test_model = 'S2MA_DUT.pth' 34 | 35 | test_model_dir = save_model_dir + test_model 36 | save_test_path_root = './output/' 37 | 38 | 39 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | import os 3 | import torch 4 | import torch.backends.cudnn as cudnn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from dataset import get_loader 8 | import transforms as trans 9 | from torchvision import transforms 10 | import 
math 11 | import time 12 | from parameter import * 13 | from ImageDepthNet import ImageDepthNet 14 | 15 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id 16 | cudnn.benchmark = True 17 | 18 | 19 | def test_net(net): 20 | 21 | for test_dir_img in test_lists: 22 | 23 | test_loader = get_loader(test_dir_img, img_size, 1, mode='test', num_thread=1) 24 | 25 | print(''' 26 | Starting testing: 27 | dataset: {} 28 | Testing size: {} 29 | '''.format(test_dir_img.split('/')[-1], len(test_loader.dataset))) 30 | 31 | for i, data_batch in enumerate(test_loader): 32 | print('{}/{}'.format(i, len(test_loader.dataset))) 33 | images, depths, image_w, image_h, image_path = data_batch 34 | images, depths = Variable(images.cuda()), Variable(depths.cuda()) 35 | 36 | outputs_image, outputs_depth = net(images, depths) 37 | _, _, _, _, _, imageBran_output = outputs_image 38 | _, _, _, _, _, depthBran_output = outputs_depth 39 | 40 | image_w, image_h = int(image_w[0]), int(image_h[0]) 41 | 42 | output_imageBran = F.sigmoid(imageBran_output) 43 | output_depthBran = F.sigmoid(depthBran_output) 44 | 45 | output_imageBran = output_imageBran.data.cpu().squeeze(0) 46 | output_depthBran = output_depthBran.data.cpu().squeeze(0) 47 | 48 | transform = trans.Compose([ 49 | transforms.ToPILImage(), 50 | trans.Scale((image_w, image_h)) 51 | ]) 52 | outputImageBranch = transform(output_imageBran) 53 | outputDepthBranch = transform(output_depthBran) 54 | 55 | dataset = image_path[0].split('RGBdDataset_processed')[1].split('/')[1] 56 | 57 | filename = image_path[0].split('/')[-1].split('.')[0] 58 | 59 | # save image branch output 60 | save_test_path = save_test_path_root + dataset + '/' + test_model + '/' 61 | if not os.path.exists(save_test_path): 62 | os.makedirs(save_test_path) 63 | outputImageBranch.save(os.path.join(save_test_path, filename + '.png')) 64 | 65 | 66 | if __name__ == '__main__': 67 | 68 | start = time.time() 69 | 70 | net = ImageDepthNet(3) 71 | net.cuda() 72 | net.eval() 73 | # load model 74 | net.load_state_dict(torch.load(test_model_dir)) 75 | print('Model loaded from {}'.format(test_model_dir)) 76 | 77 | test_net(net) 78 | print('total time {}'.format(time.time()-start)) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.backends.cudnn as cudnn 3 | import torch.nn as nn 4 | from torch import optim 5 | from torch.autograd import Variable 6 | 7 | from dataset import get_loader 8 | import math 9 | from parameter import * 10 | from ImageDepthNet import ImageDepthNet 11 | import os 12 | 13 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id 14 | cudnn.benchmark = True 15 | 16 | 17 | def save_loss(save_dir, whole_iter_num, epoch_total_loss, epoch_loss, epoch): 18 | fh = open(save_dir, 'a') 19 | epoch_total_loss = str(epoch_total_loss) 20 | epoch_loss = str(epoch_loss) 21 | fh.write('until_' + str(epoch) + '_run_iter_num' + str(whole_iter_num) + '\n') 22 | fh.write(str(epoch) + '_epoch_total_loss' + epoch_total_loss + '\n') 23 | fh.write(str(epoch) + '_epoch_loss' + epoch_loss + '\n') 24 | fh.write('\n') 25 | fh.close() 26 | 27 | 28 | def adjust_learning_rate(optimizer, decay_rate=.1): 29 | update_lr_group = optimizer.param_groups 30 | for param_group in update_lr_group: 31 | print('before lr: ', param_group['lr']) 32 | param_group['lr'] = param_group['lr'] * decay_rate 33 | print('after lr: ', param_group['lr']) 34 | return optimizer 35 | 36 | 37 | def save_lr(save_dir, 
optimizer): 38 | update_lr_group = optimizer.param_groups[0] 39 | fh = open(save_dir, 'a') 40 | fh.write('encode:update:lr' + str(update_lr_group['lr']) + '\n') 41 | fh.write('decode:update:lr' + str(update_lr_group['lr']) + '\n') 42 | fh.write('\n') 43 | fh.close() 44 | 45 | 46 | def train_net(net): 47 | 48 | train_loader = get_loader(train_dir_img, img_size, batch_size, mode='train', 49 | num_thread=4) 50 | 51 | print(''' 52 | Starting training: 53 | Train steps: {} 54 | Batch size: {} 55 | Learning rate: {} 56 | Training size: {} 57 | '''.format(train_steps, batch_size, lr, len(train_loader.dataset))) 58 | 59 | N_train = len(train_loader) * batch_size 60 | 61 | optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) 62 | 63 | criterion = nn.BCEWithLogitsLoss() 64 | whole_iter_num = 0 65 | iter_num = math.ceil(len(train_loader.dataset) / batch_size) 66 | for epoch in range(epochs): 67 | 68 | print('Starting epoch {}/{}.'.format(epoch + 1, epochs)) 69 | print('epoch:{0}-------lr:{1}'.format(epoch + 1, lr)) 70 | 71 | epoch_total_loss = 0 72 | epoch_loss = 0 73 | 74 | for i, data_batch in enumerate(train_loader): 75 | if (i + 1) > iter_num: break 76 | images, depths, label_256, label_32, label_64, label_128, filename = data_batch 77 | images, depths, label_256 = Variable(images.cuda()), Variable(depths.cuda()), Variable(label_256.cuda()) 78 | label_32, label_64, label_128 = Variable(label_32.cuda()), Variable(label_64.cuda()), \ 79 | Variable(label_128.cuda()) 80 | 81 | outputs_image, outputs_depth = net(images, depths) 82 | for_loss6, for_loss5, for_loss4, for_loss3, for_loss2, for_loss1 = outputs_image 83 | depth_for_loss6, depth_for_loss5, depth_for_loss4, depth_for_loss3, depth_for_loss2, depth_for_loss1 = outputs_depth 84 | 85 | # loss 86 | loss6 = criterion(for_loss6, label_32) 87 | loss5 = criterion(for_loss5, label_32) 88 | loss4 = criterion(for_loss4, label_32) 89 | loss3 = criterion(for_loss3, label_64) 90 | loss2 = criterion(for_loss2, label_128) 91 | loss1 = criterion(for_loss1, label_256) 92 | 93 | img_total_loss = loss_weights[0] * loss1 + loss_weights[1] * loss2 + loss_weights[2] * loss3\ 94 | + loss_weights[3] * loss4 + loss_weights[4] * loss5 + loss_weights[5] * loss6 95 | 96 | # depth loss 97 | 98 | depth_loss6 = criterion(depth_for_loss6, label_32) 99 | depth_loss5 = criterion(depth_for_loss5, label_32) 100 | depth_loss4 = criterion(depth_for_loss4, label_32) 101 | depth_loss3 = criterion(depth_for_loss3, label_64) 102 | depth_loss2 = criterion(depth_for_loss2, label_128) 103 | depth_loss1 = criterion(depth_for_loss1, label_256) 104 | 105 | depth_total_loss = loss_weights[0] * depth_loss1 + loss_weights[1] * depth_loss2 + loss_weights[2] * depth_loss3\ 106 | + loss_weights[3] * depth_loss4 + loss_weights[4] * depth_loss5 + loss_weights[5] * depth_loss6 107 | 108 | total_loss = img_total_loss + depth_total_loss 109 | 110 | epoch_total_loss += total_loss.cpu().data.item() 111 | epoch_loss += loss1.cpu().data.item() 112 | 113 | print('whole_iter_num: {0} --- {1:.4f} --- total_loss: {2:.6f} --- loss: {3:.6f}'.format((whole_iter_num + 1), 114 | (i + 1) * batch_size / N_train, total_loss.item(), loss1.item())) 115 | 116 | optimizer.zero_grad() 117 | 118 | total_loss.backward() 119 | 120 | optimizer.step() 121 | whole_iter_num += 1 122 | 123 | if whole_iter_num == train_steps: 124 | torch.save(net.state_dict(), 125 | save_model_dir + 'iterations{}.pth'.format(train_steps)) 126 | return 127 | 128 | if whole_iter_num == stepvalue1 or whole_iter_num == 
stepvalue2: 129 | optimizer = adjust_learning_rate(optimizer, decay_rate=lr_decay_gamma) 130 | save_lr(save_lossdir, optimizer) 131 | print('have updated lr!!') 132 | 133 | print('Epoch finished ! Loss: {}'.format(epoch_total_loss / iter_num)) 134 | 135 | save_loss(save_lossdir, whole_iter_num, epoch_total_loss / iter_num, epoch_loss/iter_num, epoch+1) 136 | torch.save(net.state_dict(), 137 | save_model_dir + 'MODEL_EPOCH{}.pth'.format(epoch + 1)) 138 | print('Saved') 139 | 140 | 141 | if __name__ == '__main__': 142 | 143 | net = ImageDepthNet(3) 144 | 145 | # load pretrain model for image and depth encoder 146 | vgg_model = torch.load(load_vgg_model) 147 | net = net.init_parameters(vgg_model) 148 | 149 | net.train() 150 | net.cuda() 151 | 152 | train_net(net) 153 | 154 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps 6 | import numpy as np 7 | import numbers 8 | import types 9 | import collections 10 | 11 | class Compose(object): 12 | """Composes several transforms together. 13 | 14 | Args: 15 | transforms (List[Transform]): list of transforms to compose. 16 | 17 | Example: 18 | >>> transforms.Compose([ 19 | >>> transforms.CenterCrop(10), 20 | >>> transforms.ToTensor(), 21 | >>> ]) 22 | """ 23 | 24 | def __init__(self, transforms): 25 | self.transforms = transforms 26 | 27 | def __call__(self, img): 28 | for t in self.transforms: 29 | img = t(img) 30 | return img 31 | 32 | 33 | class ToTensor(object): 34 | """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 35 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 36 | """ 37 | 38 | def __call__(self, pic): 39 | if isinstance(pic, np.ndarray): 40 | # handle numpy array 41 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 42 | # backard compability 43 | # return img.float().div(255) 44 | return img.float() 45 | # handle PIL Image 46 | if pic.mode == 'I': 47 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 48 | elif pic.mode == 'I;16': 49 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 50 | else: 51 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 52 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 53 | if pic.mode == 'YCbCr': 54 | nchannel = 3 55 | elif pic.mode == 'I;16': 56 | nchannel = 1 57 | else: 58 | nchannel = len(pic.mode) 59 | img = img.view(pic.size[1], pic.size[0], nchannel) 60 | # put it from HWC to CHW format 61 | # yikes, this transpose takes 80% of the loading time/CPU 62 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 63 | if isinstance(img, torch.ByteTensor): 64 | # return img.float().div(255) 65 | return img.float() 66 | 67 | else: 68 | return img 69 | 70 | 71 | class ToTensor_BGR(object): 72 | """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 73 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 
74 | """ 75 | 76 | def __call__(self, pic): 77 | if isinstance(pic, np.ndarray): 78 | # handle numpy array 79 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 80 | # backard compability 81 | # return img.float().div(255) 82 | return img.float() 83 | # handle PIL Image 84 | if pic.mode == 'I': 85 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 86 | elif pic.mode == 'I;16': 87 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 88 | else: 89 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 90 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 91 | if pic.mode == 'YCbCr': 92 | nchannel = 3 93 | elif pic.mode == 'I;16': 94 | nchannel = 1 95 | else: 96 | nchannel = len(pic.mode) 97 | img = img.view(pic.size[1], pic.size[0], nchannel) 98 | # put it from HWC to CHW format 99 | # yikes, this transpose takes 80% of the loading time/CPU 100 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 101 | if isinstance(img, torch.ByteTensor): 102 | # return img.float().div(255) 103 | 104 | img_bgr = img[[2, 1, 0], :, :] 105 | 106 | return img_bgr.float() 107 | 108 | else: 109 | return img 110 | 111 | 112 | class ToPILImage(object): 113 | """Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape 114 | H x W x C to a PIL.Image while preserving value range. 115 | """ 116 | 117 | def __call__(self, pic): 118 | npimg = pic 119 | mode = None 120 | if isinstance(pic, torch.FloatTensor): 121 | # pic = pic.mul(255).byte() 122 | pic = pic.byte() 123 | if torch.is_tensor(pic): 124 | npimg = np.transpose(pic.numpy(), (1, 2, 0)) 125 | assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray' 126 | if npimg.shape[2] == 1: 127 | npimg = npimg[:, :, 0] 128 | 129 | if npimg.dtype == np.uint8: 130 | mode = 'L' 131 | if npimg.dtype == np.int16: 132 | mode = 'I;16' 133 | if npimg.dtype == np.int32: 134 | mode = 'I' 135 | elif npimg.dtype == np.float32: 136 | mode = 'F' 137 | else: 138 | if npimg.dtype == np.uint8: 139 | mode = 'RGB' 140 | assert mode is not None, '{} is not supported'.format(npimg.dtype) 141 | return Image.fromarray(npimg, mode=mode) 142 | 143 | 144 | class Normalize(object): 145 | """Given mean: (R, G, B) and std: (R, G, B), 146 | will normalize each channel of the torch.*Tensor, i.e. 147 | channel = (channel - mean) / std 148 | """ 149 | 150 | def __init__(self, mean, std): 151 | self.mean = mean 152 | self.std = std 153 | 154 | def __call__(self, tensor): 155 | # TODO: make efficient 156 | for t, m, s in zip(tensor, self.mean, self.std): 157 | t.sub_(m).div_(s) 158 | return tensor 159 | 160 | 161 | class Scale(object): 162 | """Rescale the input PIL.Image to the given size. 163 | 164 | Args: 165 | size (sequence or int): Desired output size. If size is a sequence like 166 | (w, h), output size will be matched to this. If size is an int, 167 | smaller edge of the image will be matched to this number. 168 | i.e, if height > width, then image will be rescaled to 169 | (size * height / width, size) 170 | interpolation (int, optional): Desired interpolation. Default is 171 | ``PIL.Image.BILINEAR`` 172 | """ 173 | 174 | def __init__(self, size, interpolation=Image.BILINEAR): 175 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 176 | self.size = size 177 | self.interpolation = interpolation 178 | 179 | def __call__(self, img): 180 | """ 181 | Args: 182 | img (PIL.Image): Image to be scaled. 183 | 184 | Returns: 185 | PIL.Image: Rescaled image. 
186 | """ 187 | if isinstance(self.size, int): 188 | w, h = img.size 189 | if (w <= h and w == self.size) or (h <= w and h == self.size): 190 | return img 191 | if w < h: 192 | ow = self.size 193 | oh = int(self.size * h / w) 194 | return img.resize((ow, oh), self.interpolation) 195 | else: 196 | oh = self.size 197 | ow = int(self.size * w / h) 198 | return img.resize((ow, oh), self.interpolation) 199 | else: 200 | return img.resize(self.size, self.interpolation) 201 | 202 | 203 | class CenterCrop(object): 204 | """Crops the given PIL.Image at the center to have a region of 205 | the given size. size can be a tuple (target_height, target_width) 206 | or an integer, in which case the target will be of a square shape (size, size) 207 | """ 208 | 209 | def __init__(self, size): 210 | if isinstance(size, numbers.Number): 211 | self.size = (int(size), int(size)) 212 | else: 213 | self.size = size 214 | 215 | def __call__(self, img): 216 | w, h = img.size 217 | th, tw = self.size 218 | x1 = int(round((w - tw) / 2.)) 219 | y1 = int(round((h - th) / 2.)) 220 | return img.crop((x1, y1, x1 + tw, y1 + th)) 221 | 222 | 223 | class Pad(object): 224 | """Pads the given PIL.Image on all sides with the given "pad" value""" 225 | 226 | def __init__(self, padding, fill=0): 227 | assert isinstance(padding, numbers.Number) 228 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) 229 | self.padding = padding 230 | self.fill = fill 231 | 232 | def __call__(self, img): 233 | return ImageOps.expand(img, border=self.padding, fill=self.fill) 234 | 235 | 236 | class Lambda(object): 237 | """Applies a lambda as a transform.""" 238 | 239 | def __init__(self, lambd): 240 | assert isinstance(lambd, types.LambdaType) 241 | self.lambd = lambd 242 | 243 | def __call__(self, img): 244 | return self.lambd(img) 245 | 246 | 247 | class RandomCrop(object): 248 | """Crops the given PIL.Image at a random location to have a region of 249 | the given size. 
size can be a tuple (target_height, target_width) 250 | or an integer, in which case the target will be of a square shape (size, size) 251 | """ 252 | 253 | def __init__(self, size, padding=0): 254 | if isinstance(size, numbers.Number): 255 | self.size = (int(size), int(size)) 256 | else: 257 | self.size = size 258 | self.padding = padding 259 | 260 | def __call__(self, img): 261 | if self.padding > 0: 262 | img = ImageOps.expand(img, border=self.padding, fill=0) 263 | 264 | w, h = img.size 265 | th, tw = self.size 266 | if w == tw and h == th: 267 | return img 268 | 269 | x1 = random.randint(0, w - tw) 270 | y1 = random.randint(0, h - th) 271 | return img.crop((x1, y1, x1 + tw, y1 + th)) 272 | 273 | 274 | class RandomHorizontalFlip(object): 275 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 276 | """ 277 | 278 | def __call__(self, img): 279 | if random.random() < 0.5: 280 | return img.transpose(Image.FLIP_LEFT_RIGHT) 281 | return img 282 | 283 | 284 | class RandomSizedCrop(object): 285 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 286 | and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 287 | This is popularly used to train the Inception networks 288 | size: size of the smaller edge 289 | interpolation: Default: PIL.Image.BILINEAR 290 | """ 291 | 292 | def __init__(self, size, interpolation=Image.BILINEAR): 293 | self.size = size 294 | self.interpolation = interpolation 295 | 296 | def __call__(self, img): 297 | for attempt in range(10): 298 | area = img.size[0] * img.size[1] 299 | target_area = random.uniform(0.08, 1.0) * area 300 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 301 | 302 | w = int(round(math.sqrt(target_area * aspect_ratio))) 303 | h = int(round(math.sqrt(target_area / aspect_ratio))) 304 | 305 | if random.random() < 0.5: 306 | w, h = h, w 307 | 308 | if w <= img.size[0] and h <= img.size[1]: 309 | x1 = random.randint(0, img.size[0] - w) 310 | y1 = random.randint(0, img.size[1] - h) 311 | 312 | img = img.crop((x1, y1, x1 + w, y1 + h)) 313 | assert(img.size == (w, h)) 314 | 315 | return img.resize((self.size, self.size), self.interpolation) 316 | 317 | # Fallback 318 | scale = Scale(self.size, interpolation=self.interpolation) 319 | crop = CenterCrop(self.size) 320 | return crop(scale(img)) 321 | --------------------------------------------------------------------------------
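
Note: the following is a minimal preprocessing sketch, not a file from the repository. It shows how the classes defined in transforms.py compose for a single RGB/depth pair in the same way dataset.get_loader arranges them in 'test' mode; the mean values are the ones hard-coded in dataset.py, while the file paths and variable names are illustrative assumptions.

# Minimal sketch (illustrative, not a repository file): test-time preprocessing
# of one RGB/depth pair, mirroring what dataset.get_loader builds in 'test' mode.
import torch
from PIL import Image
import transforms as trans  # the repository's transforms.py

img_size = 256

# Per-channel BGR means subtracted from the RGB input (values from dataset.py).
mean_bgr = torch.Tensor(3, img_size, img_size)
mean_bgr[0, :, :] = 104.008  # B
mean_bgr[1, :, :] = 116.669  # G
mean_bgr[2, :, :] = 122.675  # R

# Single-channel mean subtracted from the depth input (value from dataset.py).
depth_mean = torch.Tensor(1, img_size, img_size)
depth_mean[0, :, :] = 115.8695

rgb_transform = trans.Compose([
    trans.Scale((img_size, img_size)),     # PIL bilinear resize to 256x256
    trans.ToTensor_BGR(),                  # HWC uint8 -> CHW float in [0, 255], channels reordered to BGR
    trans.Lambda(lambda x: x - mean_bgr),  # mean subtraction
])
depth_transform = trans.Compose([
    trans.Scale((img_size, img_size)),
    trans.ToTensor(),                      # CHW float in [0, 255]
    trans.Lambda(lambda x: x - depth_mean),
])

# 'example_rgb.jpg' and 'example_depth.bmp' are hypothetical paths used only for illustration.
rgb = rgb_transform(Image.open('example_rgb.jpg').convert('RGB'))
depth = depth_transform(Image.open('example_depth.bmp').convert('L'))
print(rgb.shape, depth.shape)  # expected: torch.Size([3, 256, 256]) torch.Size([1, 256, 256])

The resulting tensors have the same layout as the batches that test.py feeds to ImageDepthNet (after adding a batch dimension), which is why dataset.py keeps the label transforms (plain ToTensor) separate from these mean-subtracted image/depth transforms.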