├── README.md ├── lib ├── data_prefetcher.py ├── dataset.py └── transform.py ├── net.py ├── paper ├── PR-FM.png ├── Weakly_Alignment-free_RGBT_Salient_Object_Detection_with_Deep_Correlation_Network.pdf └── framework.png ├── test.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # Deep-Correlation-Network 2 | This is the code of "Weakly Alignment-free RGBT Salient Object Detection with Deep Correlation Network" 3 | 4 | ![](./paper/framework.png) 5 | 6 | ![](./paper/PR-FM.png) 7 | 8 | ### Saved Models 9 | https://pan.baidu.com/s/1Goq5K8qr-uVPvOTcaV1fIQ [egiy] 10 | 11 | ### Pretrained Model(VGG) 12 | https://pan.baidu.com/s/1EKUMEUrUz9XKu15X4SzHsg [uwm5] 13 | 14 | ### Unaligned Dataset 15 | https://pan.baidu.com/s/1W8rZFfN5K4-0RK-bXHJRYQ [nrkq] 16 | 17 | ### Saliency Maps 18 | https://pan.baidu.com/s/1xtkzcvRfkNtsZ9GUtIIaVQ [9tf7] 19 | -------------------------------------------------------------------------------- /lib/data_prefetcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class DataPrefetcher(object): 4 | def __init__(self, loader): 5 | self.loader = iter(loader) 6 | self.stream = torch.cuda.Stream() 7 | self.preload() 8 | 9 | def preload(self): 10 | try: 11 | self.next_rgb, self.next_t, self.next_gt,_,_ = next(self.loader) 12 | except StopIteration: 13 | self.next_rgb = None 14 | self.next_t = None 15 | self.next_gt = None 16 | return 17 | 18 | with torch.cuda.stream(self.stream): 19 | self.next_rgb = self.next_rgb.cuda(non_blocking=True).float() 20 | self.next_t = self.next_t.cuda(non_blocking=True).float() 21 | self.next_gt = self.next_gt.cuda(non_blocking=True).float() 22 | #self.next_rgb = self.next_rgb #if need 23 | #self.next_t = self.next_t #if need 24 | #self.next_gt = self.next_gt # if need 25 | 26 | def next(self): 27 | torch.cuda.current_stream().wait_stream(self.stream) 28 | rgb = self.next_rgb 29 | t= self.next_t 30 | gt = self.next_gt 31 | self.preload() 32 | return rgb, t, gt -------------------------------------------------------------------------------- /lib/dataset.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | import torch 7 | try: 8 | from . import transform 9 | except: 10 | import transform 11 | from torch.utils.data import Dataset 12 | 13 | def getRandomSample(rgb,t): 14 | n = np.random.randint(10) 15 | zero = np.random.randint(2) 16 | if n==1: 17 | if zero: 18 | rgb = torch.from_numpy(np.zeros_like(rgb)) 19 | else: 20 | rgb = torch.from_numpy(np.random.randn(*rgb.shape)) 21 | elif n==2: 22 | if zero: 23 | t = torch.from_numpy(np.zeros_like(t)) 24 | else: 25 | t = torch.from_numpy(np.random.randn(*t.shape)) 26 | return rgb,t 27 | 28 | class Data(Dataset): 29 | def __init__(self, root,mode='train'): 30 | self.samples = [] 31 | lines = os.listdir(os.path.join(root, 'GT')) 32 | self.mode = mode 33 | for line in lines: 34 | rgbpath = os.path.join(root, 'RGB', line[:-4]+'.jpg')# 35 | tpath = os.path.join(root, 'T', line[:-4]+'.jpg')# 36 | maskpath = os.path.join(root, 'GT', line) 37 | self.samples.append([rgbpath,tpath,maskpath]) 38 | 39 | if mode == 'train': 40 | self.transform = transform.Compose( transform.Normalize(), 41 | transform.Resize(352,352), 42 | transform.RandomHorizontalFlip(), 43 | transform.ToTensor(), 44 | #transform.RandomSpitalTransformation() 45 | ) 46 | 47 | elif mode == 'test': 48 | self.transform = transform.Compose( transform.Normalize(), 49 | transform.Resize(352,352), 50 | transform.ToTensor() 51 | ) 52 | else: 53 | raise ValueError 54 | 55 | def __getitem__(self, idx): 56 | rgbpath,tpath,maskpath = self.samples[idx] 57 | rgb = cv2.imread(rgbpath).astype(np.float32) 58 | t = cv2.imread(tpath).astype(np.float32) 59 | mask = cv2.imread(maskpath).astype(np.float32) 60 | H, W, C = mask.shape 61 | rgb,t,mask = self.transform(rgb,t,mask) 62 | if self.mode == 'train': 63 | rgb,t =getRandomSample(rgb,t) 64 | return rgb.float(), t.float(), mask.float(), (H, W), os.path.split(maskpath)[-1] 65 | 66 | def __len__(self): 67 | return len(self.samples) 68 | -------------------------------------------------------------------------------- /lib/transform.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import torch.nn.functional as F 6 | import os 7 | class Compose(object): 8 | def __init__(self, *ops): 9 | self.ops = ops 10 | 11 | def __call__(self, rgb,t, mask): 12 | for op in self.ops: 13 | rgb,t, mask = op(rgb,t, mask) 14 | return rgb,t, mask 15 | 16 | class Normalize(object): 17 | def __call__(self, rgb,t, mask): 18 | rgb = rgb/255 19 | t = t/ 255 20 | mask /= 255 21 | return rgb,t, mask 22 | 23 | class Minusmean(object): 24 | def __init__(self, mean1,mean2): 25 | self.mean1 = mean1 26 | self.mean2 = mean2 27 | 28 | def __call__(self, rgb,t, mask): 29 | rgb = rgb - self.mean1 30 | t = t - self.mean2 31 | mask /= 255 32 | return rgb,t, mask 33 | 34 | 35 | class Resize(object): 36 | def __init__(self, H, W): 37 | self.H = H 38 | self.W = W 39 | 40 | def __call__(self, rgb,t, mask): 41 | rgb = cv2.resize(rgb, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) 42 | t = cv2.resize(t, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) 43 | mask = cv2.resize( mask, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) 44 | return rgb,t, mask 45 | 46 | class RandomCrop(object): 47 | def __init__(self, H, W): 48 | self.H = H 49 | self.W = W 50 | 51 | def __call__(self, rgb,t, mask): 52 | H,W,_ = rgb.shape 53 | xmin = np.random.randint(W-self.W+1) 54 | ymin = np.random.randint(H-self.H+1) 55 | rgb = rgb[ymin:ymin+self.H, xmin:xmin+self.W, :] 56 | t = t[ymin:ymin + self.H, xmin:xmin + self.W, :] 57 | mask = mask[ymin:ymin+self.H, xmin:xmin+self.W, :] 58 | return rgb,t, mask 59 | 60 | class RandomHorizontalFlip(object): 61 | def __call__(self, rgb,t, mask): 62 | if np.random.randint(2)==1: 63 | rgb = rgb[:,::-1,:].copy() 64 | t = t[:, ::-1, :].copy() 65 | mask = mask[:,::-1,:].copy() 66 | return rgb,t, mask 67 | 68 | class ToTensor(object): 69 | def __call__(self, rgb,t, mask): 70 | rgb = torch.from_numpy(rgb) 71 | rgb = rgb.permute(2, 0, 1) 72 | t = torch.from_numpy(t) 73 | t = t.permute(2, 0, 1) 74 | mask = torch.from_numpy(mask) 75 | mask = mask.permute(2, 0, 1) 76 | return rgb.float(),t.float(),mask.mean(dim=0,keepdim=True).float() 77 | 78 | class RandomSpitalTransformation(object): 79 | def __call__(self, rgb,t, mask): 80 | RTI= 10*torch.rand(1)+2 # Random Transformation Intensity 81 | identity_theta = torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float) 82 | dtheta = 2*(torch.rand(6)-0.5)/RTI 83 | theta = dtheta + identity_theta 84 | resampling_grid = F.affine_grid(theta.view(-1,2, 3), rgb.unsqueeze(0).size(), align_corners=True) 85 | option = 10*torch.rand(1) 86 | if option < 5: 87 | rgb_wrap = F.grid_sample(rgb.unsqueeze(0).float(), resampling_grid, mode='bilinear', padding_mode='zeros', align_corners=True) 88 | mask_wrap = F.grid_sample(mask.unsqueeze(0).float(), resampling_grid, mode='bilinear', padding_mode='zeros', align_corners=True) 89 | return rgb_wrap.squeeze(0) , t, mask_wrap.squeeze(0) 90 | else: 91 | t_wrap = F.grid_sample(t.unsqueeze(0).float(), resampling_grid, mode='bilinear', padding_mode='zeros', align_corners=True) 92 | return rgb, t_wrap.squeeze(0), mask 93 | 94 | class RST_Test(object): 95 | def __call__(self, rgb,t, mask): 96 | RTI= 10*torch.rand(1)+2 # Random Transformation Intensity 97 | identity_theta = torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float) 98 | dtheta = 2*(torch.rand(6)-0.5)/RTI 99 | theta = dtheta + identity_theta 100 | resampling_grid = F.affine_grid(theta.view(-1,2, 3), rgb.unsqueeze(0).size(), align_corners=True) 101 | t_wrap = F.grid_sample(t.unsqueeze(0).float(), resampling_grid, mode='bilinear', padding_mode='zeros', align_corners=True) 102 | return rgb, t_wrap.squeeze(0), mask 103 | -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | # vgg16 5 | def vgg(cfg, i, batch_norm=False): 6 | layers = [] 7 | in_channels = i 8 | stage = 1 9 | for v in cfg: 10 | if v == 'M': 11 | stage += 1 12 | if stage == 6: 13 | layers += [nn.MaxPool2d(kernel_size=3, stride=1, padding=1)] 14 | else: 15 | layers += [nn.MaxPool2d(kernel_size=3, stride=2, padding=1)] 16 | else: 17 | if stage == 6: 18 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 19 | else: 20 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 21 | if batch_norm: 22 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 23 | else: 24 | layers += [conv2d, nn.ReLU(inplace=True)] 25 | in_channels = v 26 | return layers 27 | 28 | class vgg16(nn.Module): 29 | def __init__(self): 30 | super(vgg16, self).__init__() 31 | self.cfg =[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'] 32 | self.extract = [3, 8, 15, 22, 29] # [ 8, 15, 22, 29] 33 | # 64:1 -->128:1/2 -->256:1/4 -->512 :1/8 --> 512:1/16 -->M-> 512,1/16 34 | self.base = nn.ModuleList(vgg(self.cfg, 3)) 35 | for m in self.modules(): 36 | if isinstance(m, nn.Conv2d): 37 | m.weight.data.normal_(0, 0.01) 38 | elif isinstance(m, nn.BatchNorm2d): 39 | m.weight.data.fill_(1) 40 | m.bias.data.zero_() 41 | 42 | def forward(self, x): 43 | tmp_x = [] 44 | for k in range(len(self.base)): 45 | x = self.base[k](x) 46 | if k in self.extract: 47 | tmp_x.append(x) #collect feature maps 1(64) 1/2(128) 1/4(256) 1/8(512) 1/16(512) 48 | return tmp_x 49 | 50 | def convblock(in_, out_, ks, st, pad): 51 | return nn.Sequential( 52 | nn.Conv2d(in_, out_, ks, st, pad), 53 | nn.BatchNorm2d(out_), 54 | nn.ReLU(inplace=True) 55 | ) 56 | 57 | def kernel2d_conv(feat_in, kernel, ksize): 58 | """ 59 | """ 60 | channels = feat_in.size(1) 61 | N, kernels, H, W = kernel.size() 62 | pad = (ksize - 1) // 2 63 | 64 | feat_in = F.pad(feat_in, (pad, pad, pad, pad), mode="replicate") 65 | feat_in = feat_in.unfold(2, ksize, 1).unfold(3, ksize, 1) 66 | feat_in = feat_in.permute(0, 2, 3, 1, 5, 4).contiguous() 67 | feat_in = feat_in.reshape(N, H, W, channels, -1) 68 | 69 | kernel = kernel.permute(0, 2, 3, 1).reshape(N, H, W, channels, ksize, ksize) 70 | kernel = kernel.permute(0, 1, 2, 3, 5, 4).reshape(N, H, W, channels, -1) 71 | feat_out = torch.sum(feat_in * kernel, -1) 72 | feat_out = feat_out.permute(0, 3, 1, 2).contiguous() 73 | return feat_out 74 | 75 | class CA(nn.Module): 76 | def __init__(self,in_ch): 77 | super(CA, self).__init__() 78 | self.avg_weight = nn.AdaptiveAvgPool2d(1) 79 | self.max_weight = nn.AdaptiveMaxPool2d(1) 80 | self.fus = nn.Sequential( 81 | nn.Conv2d(in_ch, in_ch // 2, 1, 1, 0), 82 | nn.ReLU(), 83 | nn.Conv2d(in_ch // 2, in_ch, 1, 1, 0), 84 | ) 85 | self.c_mask = nn.Sigmoid() 86 | def forward(self, x): 87 | avg_map_c = self.avg_weight(x) 88 | max_map_c = self.max_weight(x) 89 | c_mask = self.c_mask(torch.add(self.fus(avg_map_c), self.fus(max_map_c))) 90 | return torch.mul(x, c_mask) 91 | 92 | class MySTN(nn.Module): 93 | def __init__(self, in_ch, mode='Curve'): 94 | super(MySTN, self).__init__() 95 | self.mode = mode 96 | self.down_block_1 = nn.Sequential( 97 | convblock(in_ch, 128, 3, 2, 1), 98 | convblock(128, 128, 1, 1, 0) 99 | ) 100 | self.down_block_2 = nn.Sequential( 101 | convblock(128, 128, 3, 2, 1), 102 | convblock(128, 128, 1, 1, 0) 103 | ) 104 | if mode =='Curve': 105 | self.up_blcok_1 = convblock(128, 128, 1, 1, 0) 106 | self.up_blcok_2 = convblock(128, 64, 1, 1, 0) 107 | self.wrap_filed = nn.Conv2d(64,2,3,1,1) 108 | self.wrap_filed.weight.data.normal_(mean=0.0, std=5e-4) 109 | self.wrap_filed.bias.data.zero_() 110 | self.wrap_grid = None 111 | elif mode =='Affine': 112 | self.down_block_3 = nn.Sequential( 113 | convblock(128, 128, 3, 2, 1), 114 | convblock(128, 128, 1, 1, 0), 115 | ) 116 | self.deta = nn.Sequential( 117 | nn.AdaptiveAvgPool2d(1), 118 | nn.Conv2d(128,6,1,1,0) 119 | ) 120 | # Start with identity transformation 121 | self.deta[-1].weight.data.normal_(mean=0.0, std=5e-4) 122 | self.deta[-1].bias.data.zero_() 123 | self.affine_matrix = None 124 | self.wrap_grid = None 125 | 126 | def forward(self, in_): 127 | size = in_.shape[2:] 128 | n1 = self.down_block_1(in_) 129 | n2 = self.down_block_2(n1) 130 | 131 | if self.mode=="Curve": 132 | n2 = self.up_blcok_1(F.interpolate(n2,size=n1.shape[2:],mode='bilinear',align_corners=True)) 133 | n2 = self.up_blcok_2(F.interpolate(n2,size=in_.shape[2:],mode='bilinear',align_corners=True)) 134 | 135 | xx = torch.linspace(-1, 1, size[1]).view(1, -1).repeat(size[0], 1) 136 | yy = torch.linspace(-1, 1, size[0]).view(-1, 1).repeat(1, size[1]) 137 | xx = xx.view(1, size[0], size[1]) 138 | yy = yy.view(1, size[0], size[1]) 139 | grid = torch.cat((xx, yy), 0).float().unsqueeze(0).repeat(in_.shape[0], 1, 1, 1) 140 | grid = grid.clone().detach().requires_grad_(False) 141 | if in_.is_cuda: 142 | grid = grid.cuda() 143 | 144 | filed_residal = self.wrap_filed(n2) 145 | self.wrap_grid = grid + filed_residal 146 | 147 | elif self.mode=="Affine": 148 | identity_theta = torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float).requires_grad_(False) 149 | if in_.is_cuda: 150 | identity_theta = identity_theta.cuda() 151 | 152 | n3 = self.down_block_3(n2) 153 | deta = self.deta(n3) 154 | bsize = deta.shape[0] 155 | self.affine_matrix = deta.view(bsize,-1) + identity_theta.unsqueeze(0).repeat(bsize, 1) 156 | self.wrap_grid = F.affine_grid(self.affine_matrix.view(-1, 2, 3), in_.size(),align_corners=True).permute(0, 3, 1, 2) 157 | 158 | def wrap(self, x): 159 | if not x.shape[-1] == self.wrap_grid.shape[-1]: 160 | sampled_grid = F.interpolate(self.wrap_grid, size=x.shape[2:], mode='bilinear', align_corners=True) 161 | wrap_x = F.grid_sample(x, sampled_grid.permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True) 162 | else: 163 | wrap_x = F.grid_sample(x, self.wrap_grid.permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True) 164 | return wrap_x 165 | 166 | def wrap_inverse(self, x): 167 | t1 ,t2 = self.affine_matrix.view(-1, 2, 3)[:,:,:2],self.affine_matrix.view(-1, 2, 3)[:,:,2].unsqueeze(2) 168 | matrix_inverse = torch.cat((t1.inverse(),-t2),dim=2) 169 | sampled_grid = F.affine_grid(matrix_inverse, x.size(),align_corners=True) 170 | wrap_x = F.grid_sample(x, sampled_grid, mode='bilinear', padding_mode='zeros', align_corners=True) 171 | return wrap_x 172 | 173 | class PPM(nn.Module): 174 | def __init__(self,in_ch): 175 | super(PPM, self).__init__() 176 | self.conv = convblock(in_ch, 128, 3, 1, 1) 177 | self.b0 = nn.Sequential( 178 | nn.AdaptiveMaxPool2d(9), 179 | nn.Conv2d(128, 128, 1, 1, 0, bias=False), 180 | nn.ReLU(inplace=True) 181 | ) 182 | 183 | self.b1 = nn.Sequential( 184 | nn.AdaptiveMaxPool2d(5), 185 | nn.Conv2d(128, 128, 1, 1, 0, bias=False), 186 | nn.ReLU(inplace=True) 187 | ) 188 | self.b2 = nn.Sequential( 189 | nn.AdaptiveMaxPool2d(3), 190 | nn.Conv2d(128, 128, 1, 1, 0, bias=False), 191 | nn.ReLU(inplace=True) 192 | ) 193 | self.b3 = nn.Sequential( 194 | nn.AdaptiveMaxPool2d(1), 195 | nn.Conv2d(128, 128, 1, 1, 0, bias=False), 196 | nn.ReLU(inplace=True) 197 | ) 198 | self.fus = convblock(640, 128, 1, 1, 0) 199 | self.score=nn.Conv2d(128,1,1,1,0) 200 | 201 | def forward(self, x): 202 | x_size = x.size()[2:] 203 | x = self.conv(x) 204 | b0 = F.interpolate(self.b0(x), x_size, mode='bilinear', align_corners=True) 205 | b1 = F.interpolate(self.b1(x), x_size, mode='bilinear', align_corners=True) 206 | b2 = F.interpolate(self.b2(x), x_size, mode='bilinear', align_corners=True) 207 | b3 = F.interpolate(self.b3(x), x_size, mode='bilinear', align_corners=True) 208 | out = self.fus(torch.cat((b0, b1, b2, b3, x), 1)) 209 | return out 210 | 211 | class MAM(nn.Module): 212 | def __init__(self): 213 | super(MAM, self).__init__() 214 | self.stn = MySTN(256, "Affine") 215 | 216 | self.fus1 = convblock(256, 64, 1, 1, 0) 217 | self.alpha = nn.Conv2d(128, 1, 1, 1, 0) 218 | self.bata = nn.Conv2d(128, 1, 1, 1, 0) 219 | self.fus2 = convblock(128, 64, 1, 1, 0) 220 | 221 | self.dynamic_filter = nn.Conv2d(128,3*3*128,3,1,1) 222 | self.fus3 = convblock(128, 64, 1, 1, 0) 223 | self.combine = convblock(192,128,3,1,1) 224 | 225 | def forward(self, gr, gt): 226 | 227 | self.stn(torch.cat([gr,gt],dim=1)) 228 | in1 = self.fus1(torch.cat([gr, self.stn.wrap(gt)],dim=1)) 229 | 230 | affine_gt = self.alpha(gr)*gt + self.bata(gr) 231 | in2 = self.fus2(gr+affine_gt) 232 | 233 | filter = self.dynamic_filter(gr) 234 | in3 = self.fus3(kernel2d_conv(gt,filter,3)+gr) 235 | return self.combine(torch.cat([in1,in2,in3],dim=1)) 236 | 237 | 238 | 239 | class LSTMCell(nn.Module): 240 | 241 | def __init__(self): 242 | super(LSTMCell, self).__init__() 243 | self.fus = convblock(256,128,1,1,0) 244 | self.conv = nn.Conv2d(256,512,3,1,1,bias=True) 245 | 246 | def forward(self, rgb, t,cur_state): 247 | h_cur, c_cur = cur_state 248 | in_ = self.fus(torch.cat([rgb, t],dim=1)) 249 | combined = torch.cat([in_, h_cur], dim=1) # concatenate along channel axis 250 | combined_conv = self.conv(combined) 251 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, 128, dim=1) 252 | i = torch.sigmoid(cc_i) 253 | f = torch.sigmoid(cc_f) 254 | o = torch.sigmoid(cc_o) 255 | g = torch.tanh(cc_g) 256 | 257 | c_next = f * c_cur + i * g 258 | h_next = o * torch.tanh(c_next) 259 | return h_next, c_next 260 | 261 | class MCLSTMCell(nn.Module): 262 | 263 | def __init__(self): 264 | """ 265 | Modality Correction ConvLSTMCell 266 | """ 267 | super(MCLSTMCell, self).__init__() 268 | #global-spatial context enhancement 269 | self.gc_enhance = nn.Conv2d(128,128,3,1,1) 270 | 271 | #modalities alignment at spatial location and pixel-wise correlation 272 | self.stn = MySTN(256,"Affine") 273 | self.fus1 = convblock(256,64,1,1,0) 274 | self.alpha = nn.Conv2d(128,1,1,1,0) 275 | self.bata = nn.Conv2d(128,1,1,1,0) 276 | self.fus2 = convblock(128, 64, 1, 1, 0) 277 | 278 | self.conv = nn.Conv2d(256,512,3,1,1,bias=True) 279 | 280 | def forward(self, rgb, t, global_context, cur_state): 281 | h_cur, c_cur = cur_state 282 | 283 | self.stn(torch.cat([rgb,t],dim=1)) 284 | in1 = self.fus1(torch.cat([rgb, self.stn.wrap(t)],dim=1)) 285 | 286 | affine_t = self.alpha(rgb)*t + self.bata(rgb) 287 | in2 = self.fus2(rgb+affine_t) 288 | 289 | combined = torch.cat([in1,in2,h_cur], dim=1) # concatenate along channel axis 290 | combined_conv = self.conv(combined) 291 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, 128, dim=1) 292 | i = torch.sigmoid(cc_i) 293 | f = torch.sigmoid(cc_f) 294 | o = torch.sigmoid(cc_o) 295 | g = torch.tanh(cc_g) 296 | 297 | gc_enhance = torch.tanh(self.gc_enhance(global_context)) 298 | 299 | c_next = f * c_cur + i * g + gc_enhance 300 | h_next = o * torch.tanh(c_next) 301 | 302 | return h_next, c_next 303 | 304 | class Mynet(nn.Module): 305 | def __init__(self): 306 | super(Mynet, self).__init__() 307 | self.backbone = vgg16() 308 | #transition 309 | trans_layers_mapping = [[128, 128], [256, 128], [512, 128], [512, 128]] 310 | self.trans = nn.ModuleList() 311 | for mapp in trans_layers_mapping: 312 | self.trans.append(convblock(mapp[0], mapp[1], 3, 2, 1)) 313 | self.globalcontex = PPM(128) 314 | self.mam = MAM() 315 | self.topdown = MCLSTMCell() 316 | self.buttonup = MCLSTMCell() 317 | self.lstm_refine = LSTMCell() 318 | 319 | self.score = nn.Conv2d(128, 1, 1, 1, 0) 320 | 321 | def forward(self, rgb, t): 322 | size = rgb.shape[2:] 323 | Rh = self.backbone(rgb)[1:] 324 | Th = self.backbone(t)[1:] 325 | for i in range(len(Rh)): 326 | Rh[i] = self.trans[i](Rh[i]) 327 | Th[i] =self.trans[i](Th[i]) 328 | 329 | gr = self.globalcontex(Rh[-1]) 330 | gt = self.globalcontex(Th[-1]) 331 | 332 | global_context = F.interpolate(self.mam(gr,gt),size=Rh[0].shape[2:],mode='bilinear',align_corners=True) 333 | scores = [F.interpolate(self.score(global_context),size=size,mode='bilinear',align_corners=True)] 334 | 335 | featnums = len(Rh) 336 | #print(featnums) 337 | 338 | refine_hide_feats = [] 339 | 340 | Rh =[F.interpolate(feat,size=Rh[0].shape[2:],mode='bilinear',align_corners=True) for feat in Rh] 341 | Th = [F.interpolate(feat, size=Th[0].shape[2:], mode='bilinear', align_corners=True) for feat in Th] 342 | cur_state_topdown = [torch.zeros_like(Rh[0]).detach(), torch.zeros_like(Rh[0]).detach()] 343 | cur_state_buttonup = [torch.zeros_like(Rh[0]).detach(), torch.zeros_like(Rh[0]).detach()] 344 | cur_state_refine = [torch.zeros_like(Rh[0]).detach(), torch.zeros_like(Rh[0]).detach()] 345 | for i in range(featnums): 346 | cur_state_topdown = self.topdown(Rh[featnums-i-1], Th[featnums-i-1],global_context, cur_state_topdown) 347 | cur_state_buttonup = self.buttonup(Rh[i],Th[i],global_context, cur_state_buttonup) 348 | cur_state_refine = self.lstm_refine(cur_state_topdown[0],cur_state_buttonup[0], cur_state_refine) 349 | refine_hide_feats.append(cur_state_refine[0]) 350 | 351 | for feat in refine_hide_feats: 352 | scores.append(F.interpolate(self.score(feat),size=size,mode='bilinear',align_corners=True)) 353 | return scores 354 | 355 | def load_pretrained_model(self): 356 | st = torch.load("vgg16.pth") 357 | st2 = {} 358 | for key in st.keys(): 359 | st2['base.' + key] = st[key] 360 | self.backbone.load_state_dict(st2) 361 | print('loading pretrained model success!') 362 | 363 | 364 | if __name__ == "__main__": 365 | rgb = torch.rand(2, 3, 352, 352) 366 | t = rgb 367 | net = Mynet() 368 | map_list= net(rgb, t) 369 | torch.save(net.state_dict(),'test.pth') 370 | print(len(map_list),map_list[0].shape) 371 | -------------------------------------------------------------------------------- /paper/PR-FM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lz118/Deep-Correlation-Network/f60c7c896c3eb9cdaa3416270bb8655dd0ccf57c/paper/PR-FM.png -------------------------------------------------------------------------------- /paper/Weakly_Alignment-free_RGBT_Salient_Object_Detection_with_Deep_Correlation_Network.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lz118/Deep-Correlation-Network/f60c7c896c3eb9cdaa3416270bb8655dd0ccf57c/paper/Weakly_Alignment-free_RGBT_Salient_Object_Detection_with_Deep_Correlation_Network.pdf -------------------------------------------------------------------------------- /paper/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lz118/Deep-Correlation-Network/f60c7c896c3eb9cdaa3416270bb8655dd0ccf57c/paper/framework.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | daicoding='utf-8' 2 | import os 3 | from torch.utils.data import DataLoader 4 | from lib.dataset import Data 5 | import torch.nn.functional as F 6 | import torch 7 | import cv2 8 | import time 9 | from net import Mynet 10 | import numpy as np 11 | #os.environ["CUDA_VISIBLE_DEVICES"] = "2" 12 | if __name__ == '__main__': 13 | model_path= 'model/unalign.pth' 14 | out_path = './VT821' 15 | data = Data(root='./data/VT821_unalign/',mode='test') 16 | loader = DataLoader(data, batch_size=1,shuffle=False) 17 | net = Mynet().cuda() 18 | print('loading model from %s...' % model_path) 19 | net.load_state_dict(torch.load(model_path)) 20 | if not os.path.exists(out_path): os.mkdir(out_path) 21 | time_s = time.time() 22 | img_num = len(loader) 23 | net.eval() 24 | with torch.no_grad(): 25 | for rgb, t, _, (H, W), name in loader: 26 | print(name[0]) 27 | scores = net(rgb.cuda().float(), t.cuda().float()) 28 | score = F.interpolate(scores[-1], size=(H, W), mode='bilinear', align_corners=True) 29 | pred = np.squeeze(score.cpu().data.numpy()) 30 | cv2.imwrite(os.path.join(out_path, name[0][:-4] + '.png'), 255 * pred) 31 | time_e = time.time() 32 | print('speed: %f FPS' % (img_num / (time_e - time_s))) 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | coding='utf-8' 2 | import os 3 | from net import Mynet 4 | import torch 5 | import random 6 | import numpy as np 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | from lib.dataset import Data 10 | from lib.data_prefetcher import DataPrefetcher 11 | from torch.nn import functional as F 12 | import cv2 13 | from eval import PR_Curve,MAE_Value 14 | from smooth_loss import get_saliency_smoothness,get_grid_smoothness 15 | #os.environ["CUDA_VISIBLE_DEVICES"] = "4" 16 | 17 | def myloss(scores,label): 18 | deepsal_loss = F.binary_cross_entropy(torch.sigmoid(scores[0]),label,reduction='mean')+\ 19 | F.binary_cross_entropy(torch.sigmoid(scores[1]),label,reduction='mean')+\ 20 | F.binary_cross_entropy(torch.sigmoid(scores[2]),label,reduction='mean')+ \ 21 | F.binary_cross_entropy(torch.sigmoid(scores[3]), label, reduction='mean') + \ 22 | F.binary_cross_entropy(torch.sigmoid(scores[4]),label,reduction='mean') 23 | return deepsal_loss 24 | 25 | if __name__ == '__main__': 26 | random.seed(118) 27 | np.random.seed(118) 28 | torch.manual_seed(118) 29 | torch.cuda.manual_seed(118) 30 | torch.cuda.manual_seed_all(118) 31 | 32 | # dataset 33 | img_root = './data/VT5000-Train_unalign/' 34 | save_path = './model' 35 | if not os.path.exists(save_path): os.mkdir(save_path) 36 | if not os.path.exists(temp_save_path): os.mkdir(temp_save_path) 37 | lr = 0.001 #2 38 | batch_size = 4 39 | epoch = 100 40 | lr_dec=[51] 41 | data = Data(img_root) 42 | loader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=1) 43 | 44 | net = Mynet().cuda() 45 | net.load_pretrained_model() 46 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=0.0005,momentum=0.9) 47 | 48 | iter_num = len(loader) 49 | net.train() 50 | for epochi in range(1, epoch + 1): 51 | 52 | if epochi in lr_dec : 53 | lr=lr/10 54 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=0.0005,momentum=0.9) 55 | print(lr) 56 | prefetcher = DataPrefetcher(loader) 57 | rgb, t, label = prefetcher.next() 58 | r_sal_loss = 0 59 | net.zero_grad() 60 | i = 0 61 | while rgb is not None: 62 | i+=1 63 | scores = net(rgb, t) 64 | loss = myloss(scores, label) 65 | r_sal_loss += loss.data 66 | loss.backward() 67 | optimizer.step() 68 | optimizer.zero_grad() 69 | if i % 100 == 0: 70 | print('epoch: [%2d/%2d], iter: [%5d/%5d] || loss : %5.4f' % ( 71 | epochi, epoch, i, iter_num, r_sal_loss / 100)) 72 | r_sal_loss = 0 73 | rgb, t, label = prefetcher.next() 74 | torch.save(net.state_dict(), '%s/final.pth' % (save_path)) --------------------------------------------------------------------------------