├── Multi-Interactive_Dual-Decoder_for_RGB-Thermal_Salient_Object_Detection.pdf
├── README.md
├── fig
│   ├── framework.png
│   └── pr_f.png
├── lib
│   ├── data_prefetcher.py
│   ├── dataset.py
│   └── transform.py
├── lib_rgbd
│   ├── data_prefetcher.py
│   ├── dataset.py
│   └── transform.py
├── net.py
├── net_resnet50.py
├── net_resnet50_4.py
├── smooth_loss.py
├── test.py
├── train.py
├── train_resnet50.py
├── train_resnet50_4.py
└── vgg.py

/Multi-Interactive_Dual-Decoder_for_RGB-Thermal_Salient_Object_Detection.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lz118/Multi-interactive-Dual-decoder/4d49a4d5d3e44e794b6fa105d18d6a0d157f2cf2/Multi-Interactive_Dual-Decoder_for_RGB-Thermal_Salient_Object_Detection.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-interactive-Dual-decoder-for-RGBT-Salient-Object-Detection
2 | 
3 | The PyTorch implementation of Multi-interactive Dual-decoder for RGBT Salient Object Detection.
4 | 
5 | ![framework](./fig/framework.png)
6 | ![framework](./fig/pr_f.png)
7 | 
8 | ## Train
9 | 
10 | - We use VT5000-Train to train our network. All the datasets are available at https://github.com/lz118/RGBT-Salient-Object-Detection
11 | - The pretrained VGG16 backbone (vgg16.pth) can be downloaded at https://pan.baidu.com/s/11lq3mUGRFP7TFvH9Eui14A [3513]
12 | 
13 | 
14 | ## Test
15 | 
16 | - The trained models for RGB-T SOD
17 | 
18 | https://pan.baidu.com/s/1Wj6bfi7lhp1KF5iCSVj0gQ [4zkx]
19 | 
20 | https://drive.google.com/file/d/11lU5TaRZMTXQ6QCbBLinG9iDvIUDrRP5/view?usp=sharing
21 | 
22 | - The trained models for RGB-D SOD
23 | 
24 | https://pan.baidu.com/s/1KlAKrVszQisG0bK1kiedzA [2ulc]
25 | 
26 | https://drive.google.com/file/d/1LKVn3iPDBI07DUBiirm4bk2-7yA3pTSM/view?usp=sharing
27 | 
28 | ## Evaluation
29 | 
30 | - For RGB-T SOD, we provide our saliency maps on VT821, VT1000 and VT5000-Test.
31 | 
32 | https://pan.baidu.com/s/1hEZJyEJ2j1n1JKUgUaZZLQ [0div]
33 | 
34 | - The saliency maps of all compared methods on VT821, VT1000 and VT5000-Test.
35 | https://pan.baidu.com/s/1s_pJ5qNJcQ8Q7ucusZHnRg [ax96]
36 | 
37 | - For RGB-D SOD, we provide our saliency maps on SIP, SSD, STERE, LFSD and DES.
38 | 
39 | https://pan.baidu.com/s/1ZHxvMh818RxlZGW1hQA70w [2oqx]
40 | 
41 | - The evaluation toolbox is provided by https://github.com/jiwei0921/Saliency-Evaluation-Toolbox
42 | 
--------------------------------------------------------------------------------
/fig/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lz118/Multi-interactive-Dual-decoder/4d49a4d5d3e44e794b6fa105d18d6a0d157f2cf2/fig/framework.png
--------------------------------------------------------------------------------
/fig/pr_f.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lz118/Multi-interactive-Dual-decoder/4d49a4d5d3e44e794b6fa105d18d6a0d157f2cf2/fig/pr_f.png
--------------------------------------------------------------------------------
/lib/data_prefetcher.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | class DataPrefetcher(object):
4 |     def __init__(self, loader):
5 |         self.loader = iter(loader)
6 |         self.stream = torch.cuda.Stream()
7 |         self.preload()
8 | 
9 |     def preload(self):
10 |         try:
11 |             self.next_rgb, self.next_t, self.next_gt, _, _ = next(self.loader)
12 |         except StopIteration:
13 |             self.next_rgb = None
14 |             self.next_t = None
15 |             self.next_gt = None
16 |             return
17 | 
18 |         with torch.cuda.stream(self.stream):
19 |             self.next_rgb = self.next_rgb.cuda(non_blocking=True).float()
20 |             self.next_t = self.next_t.cuda(non_blocking=True).float()
21 |             self.next_gt = self.next_gt.cuda(non_blocking=True).float()
22 |             #self.next_rgb = self.next_rgb #if need
23 |             #self.next_t = self.next_t #if need
24 |             #self.next_gt = self.next_gt # if need
25 | 
26 |     def next(self):
27 |         torch.cuda.current_stream().wait_stream(self.stream)
28 |         rgb = self.next_rgb
29 |         t = self.next_t
30 |         gt = self.next_gt
31 |         self.preload()
32 |         return rgb, t, gt
--------------------------------------------------------------------------------
/lib/dataset.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | import torch 7 | try: 8 | from . 
import transform 9 | except: 10 | import transform 11 | from torch.utils.data import Dataset 12 | #BGR 13 | mean_rgb = np.array([[[0.551, 0.619, 0.532]]])*255 14 | mean_t =np.array([[[0.341, 0.360, 0.753]]])*255 15 | std_rgb = np.array([[[0.241, 0.236, 0.244]]])*255 16 | std_t = np.array([[[0.208, 0.269, 0.241]]])*255 17 | 18 | def getRandomSample(rgb,t): 19 | n = np.random.randint(10) 20 | zero = np.random.randint(2) 21 | if n==1: 22 | if zero: 23 | rgb = torch.from_numpy(np.zeros_like(rgb)) 24 | else: 25 | rgb = torch.from_numpy(np.random.randn(*rgb.shape)) 26 | elif n==2: 27 | if zero: 28 | t = torch.from_numpy(np.zeros_like(t)) 29 | else: 30 | t = torch.from_numpy(np.random.randn(*t.shape)) 31 | return rgb,t 32 | 33 | class Data(Dataset): 34 | def __init__(self, root,mode='train'): 35 | self.samples = [] 36 | lines = os.listdir(os.path.join(root, 'GT')) 37 | self.mode = mode 38 | for line in lines: 39 | rgbpath = os.path.join(root, 'RGB', line[:-4]+'.jpg') 40 | tpath = os.path.join(root, 'T', line[:-4]+'.jpg') 41 | maskpath = os.path.join(root, 'GT', line) 42 | self.samples.append([rgbpath,tpath,maskpath]) 43 | 44 | if mode == 'train': 45 | self.transform = transform.Compose( transform.Normalize(mean1=mean_rgb,mean2=mean_t,std1=std_rgb,std2=std_t), 46 | transform.Resize(352,352), 47 | transform.RandomHorizontalFlip(),transform.ToTensor()) 48 | 49 | elif mode == 'test': 50 | self.transform = transform.Compose( transform.Normalize(mean1=mean_rgb,mean2=mean_t,std1=std_rgb,std2=std_t), 51 | transform.Resize(352,352), 52 | transform.ToTensor()) 53 | else: 54 | raise ValueError 55 | 56 | def __getitem__(self, idx): 57 | rgbpath,tpath,maskpath = self.samples[idx] 58 | rgb = cv2.imread(rgbpath).astype(np.float32) 59 | t = cv2.imread(tpath).astype(np.float32) 60 | mask = cv2.imread(maskpath).astype(np.float32) 61 | H, W, C = mask.shape 62 | rgb,t,mask = self.transform(rgb,t,mask) 63 | if self.mode == 'train': 64 | rgb,t =getRandomSample(rgb,t) 65 | return rgb,t,mask, (H, W), maskpath.split('/')[-1] 66 | 67 | def __len__(self): 68 | return len(self.samples) 69 | -------------------------------------------------------------------------------- /lib/transform.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | 5 | class Compose(object): 6 | def __init__(self, *ops): 7 | self.ops = ops 8 | 9 | def __call__(self, rgb,t, mask): 10 | for op in self.ops: 11 | rgb,t, mask = op(rgb,t, mask) 12 | return rgb,t, mask 13 | 14 | 15 | 16 | class Normalize(object): 17 | def __init__(self, mean1,mean2, std1,std2): 18 | self.mean1 = mean1 19 | self.mean2 = mean2 20 | self.std1 = std1 21 | self.std2 = std2 22 | 23 | def __call__(self, rgb,t, mask): 24 | rgb = (rgb - self.mean1)/self.std1 25 | t = (t - self.mean2) / self.std2 26 | mask /= 255 27 | return rgb,t, mask 28 | 29 | class Minusmean(object): 30 | def __init__(self, mean1,mean2): 31 | self.mean1 = mean1 32 | self.mean2 = mean2 33 | 34 | def __call__(self, rgb,t, mask): 35 | rgb = rgb - self.mean1 36 | t = t - self.mean2 37 | mask /= 255 38 | return rgb,t, mask 39 | 40 | 41 | class Resize(object): 42 | def __init__(self, H, W): 43 | self.H = H 44 | self.W = W 45 | 46 | def __call__(self, rgb,t, mask): 47 | rgb = cv2.resize(rgb, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) 48 | t = cv2.resize(t, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) 49 | mask = cv2.resize( mask, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) 50 | return rgb,t, mask 
51 | 52 | class RandomCrop(object): 53 | def __init__(self, H, W): 54 | self.H = H 55 | self.W = W 56 | 57 | def __call__(self, rgb,t, mask): 58 | H,W,_ = rgb.shape 59 | xmin = np.random.randint(W-self.W+1) 60 | ymin = np.random.randint(H-self.H+1) 61 | rgb = rgb[ymin:ymin+self.H, xmin:xmin+self.W, :] 62 | t = t[ymin:ymin + self.H, xmin:xmin + self.W, :] 63 | mask = mask[ymin:ymin+self.H, xmin:xmin+self.W, :] 64 | return rgb,t, mask 65 | 66 | class RandomHorizontalFlip(object): 67 | def __call__(self, rgb,t, mask): 68 | if np.random.randint(2)==1: 69 | rgb = rgb[:,::-1,:].copy() 70 | t = t[:, ::-1, :].copy() 71 | mask = mask[:,::-1,:].copy() 72 | return rgb,t, mask 73 | 74 | class ToTensor(object): 75 | def __call__(self, rgb,t, mask): 76 | rgb = torch.from_numpy(rgb) 77 | rgb = rgb.permute(2, 0, 1) 78 | t = torch.from_numpy(t) 79 | t = t.permute(2, 0, 1) 80 | mask = torch.from_numpy(mask) 81 | mask = mask.permute(2, 0, 1) 82 | return rgb,t,mask.mean(dim=0, keepdim=True) -------------------------------------------------------------------------------- /lib_rgbd/data_prefetcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class DataPrefetcher(object): 4 | def __init__(self, loader): 5 | self.loader = iter(loader) 6 | self.stream = torch.cuda.Stream() 7 | self.preload() 8 | 9 | def preload(self): 10 | try: 11 | self.next_rgb, self.next_t, self.next_gt,_,_ = next(self.loader) 12 | except StopIteration: 13 | self.next_rgb = None 14 | self.next_t = None 15 | self.next_gt = None 16 | return 17 | 18 | with torch.cuda.stream(self.stream): 19 | self.next_rgb = self.next_rgb.cuda(non_blocking=True).float() 20 | self.next_t = self.next_t.cuda(non_blocking=True).float() 21 | self.next_gt = self.next_gt.cuda(non_blocking=True).float() 22 | 23 | def next(self): 24 | torch.cuda.current_stream().wait_stream(self.stream) 25 | rgb = self.next_rgb 26 | t= self.next_t 27 | gt = self.next_gt 28 | self.preload() 29 | return rgb, t, gt -------------------------------------------------------------------------------- /lib_rgbd/dataset.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | import torch 7 | try: 8 | from . 
import transform 9 | except: 10 | import transform 11 | from torch.utils.data import Dataset 12 | 13 | #commmon trainset 14 | mean_rgb = np.array([[[0.43127787, 0.4015223, 0.44389117]]])*255 15 | std_rgb = np.array([[[0.25044188, 0.25923958, 0.25612995]]])*255 16 | 17 | mean_d = np.array([[[0.45592305, 0.45592305, 0.45592305]]])*255 18 | std_d = np.array([[[0.2845027, 0.2845027, 0.2845027]]])*255 19 | 20 | #DUTD 21 | #mean_rgb = np.array([[[ 0.4061459, 0.38510114,0.4457303]]])*255 22 | #std_rgb = np.array([[[ 0.25237563, 0.2545061,0.24679454]]])*255 23 | # 24 | #mean_d = np.array([[[0.6786454, 0.6786454, 0.6786454]]])*255 25 | #std_d = np.array([[[0.13604848, 0.13604848, 0.13604848]]])*255 26 | def getRandomSample(rgb,t): 27 | n = np.random.randint(10) 28 | zero = np.random.randint(2) 29 | if n==1: 30 | if zero: 31 | rgb = torch.from_numpy(np.zeros_like(rgb)) 32 | else: 33 | rgb = torch.from_numpy(np.random.randn(*rgb.shape)) 34 | elif n==2: 35 | if zero: 36 | t = torch.from_numpy(np.zeros_like(t)) 37 | else: 38 | t = torch.from_numpy(np.random.randn(*t.shape)) 39 | return rgb.float(),t.float() 40 | 41 | class Data(Dataset): 42 | def __init__(self, root,mode='train'): 43 | self.samples = [] 44 | lines = os.listdir(os.path.join(root,mode+'_images')) 45 | self.mode = mode 46 | for line in lines: 47 | rgbpath = os.path.join(root,mode+'_images', line) 48 | tpath = os.path.join(root,mode+'_depth', line[:-4]+'.png') 49 | maskpath = os.path.join(root,mode+'_masks', line[:-4]+'.png') 50 | self.samples.append([rgbpath,tpath,maskpath]) 51 | 52 | if mode == 'train': 53 | self.transform = transform.Compose( transform.Normalize(mean1=mean_rgb,std1=std_rgb), 54 | transform.Resize(256 ,256), 55 | transform.RandomHorizontalFlip(),transform.ToTensor()) 56 | 57 | elif mode == 'test': 58 | self.transform = transform.Compose( transform.Normalize(mean1=mean_rgb,std1=std_rgb), 59 | transform.Resize(256,256), 60 | transform.ToTensor()) 61 | else: 62 | raise ValueError 63 | 64 | def __getitem__(self, idx): 65 | rgbpath,tpath,maskpath = self.samples[idx] 66 | rgb = cv2.imread(rgbpath).astype(np.float32) 67 | t = cv2.imread(tpath).astype(np.float32) 68 | mask = cv2.imread(maskpath).astype(np.float32) 69 | H, W, C = mask.shape 70 | rgb,t,mask = self.transform(rgb,t,mask) 71 | # if self.mode == 'train': 72 | # rgb,t =getRandomSample(rgb,t) 73 | return rgb,t,mask, (H, W), maskpath.split('/')[-1] 74 | 75 | def __len__(self): 76 | return len(self.samples) 77 | 78 | 79 | 80 | if __name__=='__main__': 81 | data = Data('E:\VT5000\VT5000_clearall') 82 | for i,ba in enumerate(data): 83 | print(ba) -------------------------------------------------------------------------------- /lib_rgbd/transform.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | #coding=utf-8 3 | 4 | import cv2 5 | import torch 6 | import numpy as np 7 | 8 | class Compose(object): 9 | def __init__(self, *ops): 10 | self.ops = ops 11 | 12 | def __call__(self, rgb,t, mask): 13 | for op in self.ops: 14 | rgb,t, mask = op(rgb,t, mask) 15 | return rgb,t, mask 16 | # 17 | # 18 | # 19 | #class Normalize(object): 20 | # def __init__(self, mean1, std1,mean2, std2): 21 | # self.mean1 = mean1 22 | # self.std1 = std1 23 | # self.mean2 = mean2 24 | # self.std2 = std2 25 | # 26 | # def __call__(self, rgb,t, mask): 27 | # rgb = (rgb - self.mean1)/self.std1 28 | # t = (t - self.mean2)/self.std2 29 | # mask /= 255 30 | # return rgb,t, mask 31 | 32 | class Normalize(object): 33 | def __init__(self, mean1, std1): 34 
| self.mean1 = mean1 35 | self.std1 = std1 36 | 37 | def __call__(self, rgb,t, mask): 38 | rgb = (rgb - self.mean1)/self.std1 39 | t = t/255 40 | mask /= 255 41 | return rgb,t, mask 42 | class Resize(object): 43 | def __init__(self, H, W): 44 | self.H = H 45 | self.W = W 46 | 47 | def __call__(self, rgb,t, mask): 48 | rgb = cv2.resize(rgb, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) 49 | t = cv2.resize(t, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) 50 | mask = cv2.resize( mask, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR) 51 | return rgb,t, mask 52 | 53 | class RandomCrop(object): 54 | def __init__(self, H, W): 55 | self.H = H 56 | self.W = W 57 | 58 | def __call__(self, rgb,t, mask): 59 | H,W,_ = rgb.shape 60 | xmin = np.random.randint(W-self.W+1) 61 | ymin = np.random.randint(H-self.H+1) 62 | rgb = rgb[ymin:ymin+self.H, xmin:xmin+self.W, :] 63 | t = t[ymin:ymin + self.H, xmin:xmin + self.W, :] 64 | mask = mask[ymin:ymin+self.H, xmin:xmin+self.W, :] 65 | return rgb,t, mask 66 | 67 | class RandomHorizontalFlip(object): 68 | def __call__(self, rgb,t, mask): 69 | if np.random.randint(2)==1: 70 | rgb = rgb[:,::-1,:].copy() 71 | t = t[:, ::-1, :].copy() 72 | mask = mask[:,::-1,:].copy() 73 | return rgb,t, mask 74 | 75 | class ToTensor(object): 76 | def __call__(self, rgb,t, mask): 77 | rgb = torch.from_numpy(rgb) 78 | rgb = rgb.permute(2, 0, 1) 79 | t = torch.from_numpy(t) 80 | t = t.permute(2, 0, 1) 81 | mask = torch.from_numpy(mask) 82 | mask = mask.permute(2, 0, 1) 83 | return rgb,t,mask.mean(dim=0, keepdim=True) -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import vgg 5 | 6 | def convblock(in_,out_,ks,st,pad): 7 | return nn.Sequential( 8 | nn.Conv2d(in_,out_,ks,st,pad), 9 | nn.BatchNorm2d(out_), 10 | nn.ReLU(inplace=True) 11 | ) 12 | 13 | class GFB(nn.Module): 14 | def __init__(self,in_1,in_2): 15 | super(GFB, self).__init__() 16 | self.ca1 = CA(2*in_1) 17 | self.conv1 = convblock(2*in_1,128, 3, 1, 1) 18 | self.conv_globalinfo = convblock(512,128,3, 1, 1) 19 | self.ca2 = CA(in_2) 20 | self.conv_curfeat =convblock(in_2,128,3,1,1) 21 | self.conv_out= convblock(128,in_2,3,1,1) 22 | 23 | def forward(self, pre1,pre2,cur,global_info): 24 | cur_size = cur.size()[2:] 25 | pre = self.ca1(torch.cat((pre1,pre2),1)) 26 | pre =self.conv1(F.interpolate(pre,cur_size,mode='bilinear',align_corners=True)) 27 | 28 | global_info = self.conv_globalinfo(F.interpolate(global_info,cur_size,mode='bilinear',align_corners=True)) 29 | cur_feat =self.conv_curfeat(self.ca2(cur)) 30 | fus = pre + cur_feat + global_info 31 | return self.conv_out(fus) 32 | 33 | 34 | class GlobalInfo(nn.Module): 35 | def __init__(self): 36 | super(GlobalInfo, self).__init__() 37 | self.ca = CA(1024) 38 | self.de_chan = convblock(1024,256,3,1,1) 39 | 40 | self.b0 = nn.Sequential( 41 | nn.AdaptiveMaxPool2d(13), 42 | nn.Conv2d(256,128,1,1,0,bias=False), 43 | nn.ReLU(inplace=True) 44 | ) 45 | 46 | self.b1 = nn.Sequential( 47 | nn.AdaptiveMaxPool2d(9), 48 | nn.Conv2d(256,128,1,1,0,bias=False), 49 | nn.ReLU(inplace=True) 50 | ) 51 | self.b2 = nn.Sequential( 52 | nn.AdaptiveMaxPool2d(5), 53 | nn.Conv2d(256, 128, 1, 1, 0,bias=False), 54 | nn.ReLU(inplace=True) 55 | ) 56 | self.b3 = nn.Sequential( 57 | nn.AdaptiveMaxPool2d(1), 58 | nn.Conv2d(256, 128, 1, 1, 0,bias=False), 59 | nn.ReLU(inplace=True) 60 | ) 61 | 
self.fus = convblock(768,512,1,1,0) 62 | 63 | def forward(self, rgb,t): 64 | x_size=rgb.size()[2:] 65 | x=self.ca(torch.cat((rgb,t),1)) 66 | x=self.de_chan(x) 67 | b0 = F.interpolate(self.b0(x),x_size,mode='bilinear',align_corners=True) 68 | b1 = F.interpolate(self.b1(x),x_size,mode='bilinear',align_corners=True) 69 | b2 = F.interpolate(self.b2(x),x_size,mode='bilinear',align_corners=True) 70 | b3 = F.interpolate(self.b3(x),x_size,mode='bilinear',align_corners=True) 71 | out = self.fus(torch.cat((b0,b1,b2,b3,x),1)) 72 | return out 73 | 74 | class CA(nn.Module): 75 | def __init__(self,in_ch): 76 | super(CA, self).__init__() 77 | self.avg_weight = nn.AdaptiveAvgPool2d(1) 78 | self.max_weight = nn.AdaptiveMaxPool2d(1) 79 | self.fus = nn.Sequential( 80 | nn.Conv2d(in_ch, in_ch // 2, 1, 1, 0), 81 | nn.ReLU(), 82 | nn.Conv2d(in_ch // 2, in_ch, 1, 1, 0), 83 | ) 84 | self.c_mask = nn.Sigmoid() 85 | def forward(self, x): 86 | avg_map_c = self.avg_weight(x) 87 | max_map_c = self.max_weight(x) 88 | c_mask = self.c_mask(torch.add(self.fus(avg_map_c), self.fus(max_map_c))) 89 | return torch.mul(x, c_mask) 90 | 91 | class FinalScore(nn.Module): 92 | def __init__(self): 93 | super(FinalScore, self).__init__() 94 | self.ca =CA(256) 95 | self.score = nn.Conv2d(256, 1, 1, 1, 0) 96 | def forward(self,f1,f2,xsize): 97 | f1 = torch.cat((f1,f2),1) 98 | f1 = self.ca(f1) 99 | score = F.interpolate(self.score(f1), xsize, mode='bilinear', align_corners=True) 100 | return score 101 | 102 | class Decoder(nn.Module): 103 | def __init__(self): 104 | super(Decoder, self).__init__() 105 | self.global_info =GlobalInfo() 106 | self.score_global = nn.Conv2d(512, 1, 1, 1, 0) 107 | 108 | self.gfb4_1 = GFB(512,512) 109 | self.gfb3_1= GFB(512,256) 110 | self.gfb2_1= GFB(256,128) 111 | 112 | self.gfb4_2 = GFB(512, 512) #1/8 113 | self.gfb3_2 = GFB(512, 256)#1/4 114 | self.gfb2_2 = GFB(256, 128)#1/2 115 | 116 | self.score_1=nn.Conv2d(128, 1, 1, 1, 0) 117 | self.score_2 = nn.Conv2d(128, 1, 1, 1, 0) 118 | 119 | self.refine =FinalScore() 120 | 121 | 122 | def forward(self,rgb,t): 123 | xsize=rgb[0].size()[2:] 124 | global_info =self.global_info(rgb[4],t[4]) # 512 1/16 125 | d1=self.gfb4_1(global_info,t[4],rgb[3],global_info) 126 | d2=self.gfb4_2(global_info, rgb[4], t[3], global_info) 127 | #print(d1.shape,d2.shape) 128 | d3= self.gfb3_1(d1, d2,rgb[2],global_info) 129 | d4 = self.gfb3_2(d2, d1, t[2], global_info) 130 | d5 = self.gfb2_1(d3, d4, rgb[1], global_info) 131 | d6 = self.gfb2_2(d4, d3, t[1], global_info) #1/2 128 132 | 133 | score_global = self.score_global(global_info) 134 | 135 | score1=self.score_1(F.interpolate(d5,xsize,mode='bilinear',align_corners=True)) 136 | score2 = self.score_2(F.interpolate(d6, xsize, mode='bilinear', align_corners=True)) 137 | score =self.refine(d5,d6,xsize) 138 | return score,score1,score2,score_global 139 | 140 | class Mnet(nn.Module): 141 | def __init__(self): 142 | super(Mnet,self).__init__() 143 | self.rgb_net= vgg.a_vgg16() 144 | self.t_net= vgg.a_vgg16() 145 | self.decoder=Decoder() 146 | 147 | for m in self.modules(): 148 | if isinstance(m, nn.Conv2d): 149 | m.weight.data.normal_(0, 0.01) 150 | elif isinstance(m, nn.BatchNorm2d): 151 | m.weight.data.fill_(1) 152 | m.bias.data.zero_() 153 | 154 | def forward(self,rgb,t): 155 | rgb_f= self.rgb_net(rgb) 156 | t_f= self.t_net(t) 157 | score,score1,score2,score_g =self.decoder(rgb_f,t_f) 158 | return score,score1,score2,score_g 159 | 160 | def load_pretrained_model(self): 161 | st=torch.load("vgg16.pth") 162 | st2={} 163 | for key in 
st.keys(): 164 | st2['base.'+key]=st[key] 165 | self.rgb_net.load_state_dict(st2) 166 | self.t_net.load_state_dict(st2) 167 | print('loading pretrained model success!') -------------------------------------------------------------------------------- /net_resnet50.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from itertools import chain 5 | from torchvision.models import resnet50 6 | 7 | def convblock(in_,out_,ks,st,pad): 8 | return nn.Sequential( 9 | nn.Conv2d(in_,out_,ks,st,pad), 10 | nn.BatchNorm2d(out_), 11 | nn.ReLU(inplace=True) 12 | ) 13 | 14 | class GFB(nn.Module): 15 | def __init__(self,in_1,in_2): 16 | super(GFB, self).__init__() 17 | self.ca1 = CA(2*in_1) 18 | self.conv1 = convblock(2*in_1,128, 3, 1, 1) 19 | self.conv_globalinfo = convblock(512,128,3, 1, 1) 20 | self.ca2 = CA(in_2) 21 | self.conv_curfeat =convblock(in_2,128,3,1,1) 22 | self.conv_out= convblock(128,in_2,3,1,1) 23 | 24 | def forward(self, pre1,pre2,cur,global_info): 25 | cur_size = cur.size()[2:] 26 | pre = self.ca1(torch.cat((pre1,pre2),1)) 27 | pre =self.conv1(F.interpolate(pre,cur_size,mode='bilinear',align_corners=True)) 28 | 29 | global_info = self.conv_globalinfo(F.interpolate(global_info,cur_size,mode='bilinear',align_corners=True)) 30 | cur_feat =self.conv_curfeat(self.ca2(cur)) 31 | fus = pre + cur_feat + global_info 32 | return self.conv_out(fus) 33 | 34 | 35 | class GlobalInfo(nn.Module): 36 | def __init__(self): 37 | super(GlobalInfo, self).__init__() 38 | self.ca = CA(1024) 39 | self.de_chan = convblock(1024,256,3,1,1) 40 | 41 | self.b0 = nn.Sequential( 42 | nn.AdaptiveMaxPool2d(13), 43 | nn.Conv2d(256,128,1,1,0,bias=False), 44 | nn.ReLU(inplace=True) 45 | ) 46 | 47 | self.b1 = nn.Sequential( 48 | nn.AdaptiveMaxPool2d(9), 49 | nn.Conv2d(256,128,1,1,0,bias=False), 50 | nn.ReLU(inplace=True) 51 | ) 52 | self.b2 = nn.Sequential( 53 | nn.AdaptiveMaxPool2d(5), 54 | nn.Conv2d(256, 128, 1, 1, 0,bias=False), 55 | nn.ReLU(inplace=True) 56 | ) 57 | self.b3 = nn.Sequential( 58 | nn.AdaptiveMaxPool2d(1), 59 | nn.Conv2d(256, 128, 1, 1, 0,bias=False), 60 | nn.ReLU(inplace=True) 61 | ) 62 | self.fus = convblock(768,512,1,1,0) 63 | 64 | def forward(self, rgb,t): 65 | x_size=rgb.size()[2:] 66 | x=self.ca(torch.cat((rgb,t),1)) 67 | x=self.de_chan(x) 68 | b0 = F.interpolate(self.b0(x),x_size,mode='bilinear',align_corners=True) 69 | b1 = F.interpolate(self.b1(x),x_size,mode='bilinear',align_corners=True) 70 | b2 = F.interpolate(self.b2(x),x_size,mode='bilinear',align_corners=True) 71 | b3 = F.interpolate(self.b3(x),x_size,mode='bilinear',align_corners=True) 72 | out = self.fus(torch.cat((b0,b1,b2,b3,x),1)) 73 | return out 74 | 75 | class CA(nn.Module): 76 | def __init__(self,in_ch): 77 | super(CA, self).__init__() 78 | self.avg_weight = nn.AdaptiveAvgPool2d(1) 79 | self.max_weight = nn.AdaptiveMaxPool2d(1) 80 | self.fus = nn.Sequential( 81 | nn.Conv2d(in_ch, in_ch // 2, 1, 1, 0), 82 | nn.ReLU(), 83 | nn.Conv2d(in_ch // 2, in_ch, 1, 1, 0), 84 | ) 85 | self.c_mask = nn.Sigmoid() 86 | def forward(self, x): 87 | avg_map_c = self.avg_weight(x) 88 | max_map_c = self.max_weight(x) 89 | c_mask = self.c_mask(torch.add(self.fus(avg_map_c), self.fus(max_map_c))) 90 | return torch.mul(x, c_mask) 91 | 92 | class FinalScore(nn.Module): 93 | def __init__(self): 94 | super(FinalScore, self).__init__() 95 | self.ca =CA(256) 96 | self.score = nn.Conv2d(256, 1, 1, 1, 0) 97 | def forward(self,f1,f2,xsize): 98 | f1 = 
torch.cat((f1,f2),1) 99 | f1 = self.ca(f1) 100 | score = F.interpolate(self.score(f1), xsize, mode='bilinear', align_corners=True) 101 | return score 102 | 103 | class Decoder(nn.Module): 104 | def __init__(self): 105 | super(Decoder, self).__init__() 106 | self.global_info =GlobalInfo() 107 | self.score_global = nn.Conv2d(512, 1, 1, 1, 0) 108 | 109 | self.gfb4_1 = GFB(512,512) 110 | self.gfb3_1= GFB(512,256) 111 | self.gfb2_1= GFB(256,128) 112 | 113 | self.gfb4_2 = GFB(512, 512) #1/8 114 | self.gfb3_2 = GFB(512, 256)#1/4 115 | self.gfb2_2 = GFB(256, 128)#1/2 116 | 117 | self.score_1=nn.Conv2d(128, 1, 1, 1, 0) 118 | self.score_2 = nn.Conv2d(128, 1, 1, 1, 0) 119 | 120 | self.refine =FinalScore() 121 | 122 | 123 | def forward(self,rgb,t): 124 | xsize=rgb[0].size()[2:] 125 | global_info =self.global_info(rgb[3],t[3]) # 512 1/16 126 | d1=self.gfb4_1(global_info,t[3],rgb[2],global_info) 127 | d2=self.gfb4_2(global_info, rgb[3], t[2], global_info) 128 | #print(d1.shape,d2.shape) 129 | d3= self.gfb3_1(d1, d2,rgb[1],global_info) 130 | d4 = self.gfb3_2(d2, d1, t[1], global_info) 131 | d5 = self.gfb2_1(d3, d4, rgb[0], global_info) 132 | d6 = self.gfb2_2(d4, d3, t[0], global_info) #1/2 128 133 | 134 | score_global = self.score_global(global_info) 135 | 136 | score1=self.score_1(F.interpolate(d5,xsize,mode='bilinear',align_corners=True)) 137 | score2 = self.score_2(F.interpolate(d6, xsize, mode='bilinear', align_corners=True)) 138 | score =self.refine(d5,d6,xsize) 139 | return score,score1,score2,score_global 140 | 141 | class Mnet(nn.Module): 142 | def __init__(self,train=False): 143 | super(Mnet,self).__init__() 144 | self.rgb_net= resnet50(pretrained=train) 145 | self.t_net= resnet50(pretrained=train) 146 | trans_layers_mapping = [[256,128],[512,256],[1024,512],[2048,512]] 147 | self.trans_rgb = nn.ModuleList() 148 | self.trans_t = nn.ModuleList() 149 | for mapp in trans_layers_mapping: 150 | self.trans_rgb.append(convblock(mapp[0],mapp[1],1,1,0)) 151 | self.trans_t.append(convblock(mapp[0], mapp[1], 1, 1, 0)) 152 | self.decoder=Decoder() 153 | 154 | for m in chain(self.decoder.modules(),chain(self.trans_rgb.modules(),self.trans_t.modules())): 155 | if isinstance(m, nn.Conv2d): 156 | m.weight.data.normal_(0, 0.01) 157 | elif isinstance(m, nn.BatchNorm2d): 158 | m.weight.data.fill_(1) 159 | m.bias.data.zero_() 160 | 161 | def forward(self,rgb,t): 162 | rgb_f=[] 163 | t_f = [] 164 | x = self.rgb_net.layer1(self.rgb_net.maxpool(self.rgb_net.relu(self.rgb_net.bn1(self.rgb_net.conv1(rgb))))) 165 | rgb_f.append(self.trans_rgb[0](x)) #256->128 166 | x= self.rgb_net.layer2(x) #256->512 167 | rgb_f.append(self.trans_rgb[1](x)) # 512->256 168 | x = self.rgb_net.layer3(x) # 512->1024 169 | rgb_f.append(self.trans_rgb[2](x)) # 1024->512 170 | x = self.rgb_net.layer4(x) # 1024->2048 171 | rgb_f.append(self.trans_rgb[3](x)) # 2048->512 172 | 173 | x = self.t_net.layer1(self.t_net.maxpool(self.t_net.relu(self.t_net.bn1(self.t_net.conv1(t))))) 174 | t_f.append(self.trans_t[0](x)) # 256->128 175 | x = self.t_net.layer2(x) # 256->512 176 | t_f.append(self.trans_t[1](x)) # 512->256 177 | x = self.t_net.layer3(x) # 512->1024 178 | t_f.append(self.trans_t[2](x)) # 1024->512 179 | x = self.t_net.layer4(x) # 1024->2048 180 | t_f.append(self.trans_t[3](x)) # 2048->512 181 | 182 | score,score1,score2,score_g =self.decoder(rgb_f,t_f) 183 | return score,score1,score2,score_g -------------------------------------------------------------------------------- /net_resnet50_4.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from itertools import chain 5 | from torchvision.models import resnet50 6 | 7 | def convblock(in_,out_,ks,st,pad): 8 | return nn.Sequential( 9 | nn.Conv2d(in_,out_,ks,st,pad), 10 | nn.BatchNorm2d(out_), 11 | nn.ReLU(inplace=True) 12 | ) 13 | 14 | class GFB(nn.Module): 15 | def __init__(self,in_1,in_2): 16 | super(GFB, self).__init__() 17 | self.ca1 = CA(2*in_1) 18 | self.conv1 = convblock(2*in_1,128, 3, 1, 1) 19 | self.conv_globalinfo = convblock(512,128,3, 1, 1) 20 | self.ca2 = CA(in_2) 21 | self.conv_curfeat =convblock(in_2,128,3,1,1) 22 | self.conv_out= convblock(128,in_2,3,1,1) 23 | 24 | def forward(self, pre1,pre2,cur,global_info): 25 | cur_size = cur.size()[2:] 26 | pre = self.ca1(torch.cat((pre1,pre2),1)) 27 | pre =self.conv1(F.interpolate(pre,cur_size,mode='bilinear',align_corners=True)) 28 | 29 | global_info = self.conv_globalinfo(F.interpolate(global_info,cur_size,mode='bilinear',align_corners=True)) 30 | cur_feat =self.conv_curfeat(self.ca2(cur)) 31 | fus = pre + cur_feat + global_info 32 | return self.conv_out(fus) 33 | 34 | 35 | class GlobalInfo(nn.Module): 36 | def __init__(self): 37 | super(GlobalInfo, self).__init__() 38 | self.ca = CA(1024) 39 | self.de_chan = convblock(1024,256,3,1,1) 40 | 41 | self.b0 = nn.Sequential( 42 | nn.AdaptiveMaxPool2d(13), 43 | nn.Conv2d(256,128,1,1,0,bias=False), 44 | nn.ReLU(inplace=True) 45 | ) 46 | 47 | self.b1 = nn.Sequential( 48 | nn.AdaptiveMaxPool2d(9), 49 | nn.Conv2d(256,128,1,1,0,bias=False), 50 | nn.ReLU(inplace=True) 51 | ) 52 | self.b2 = nn.Sequential( 53 | nn.AdaptiveMaxPool2d(5), 54 | nn.Conv2d(256, 128, 1, 1, 0,bias=False), 55 | nn.ReLU(inplace=True) 56 | ) 57 | self.b3 = nn.Sequential( 58 | nn.AdaptiveMaxPool2d(1), 59 | nn.Conv2d(256, 128, 1, 1, 0,bias=False), 60 | nn.ReLU(inplace=True) 61 | ) 62 | self.fus = convblock(768,512,1,1,0) 63 | 64 | def forward(self, rgb,t): 65 | x_size=rgb.size()[2:] 66 | x=self.ca(torch.cat((rgb,t),1)) 67 | x=self.de_chan(x) 68 | b0 = F.interpolate(self.b0(x),x_size,mode='bilinear',align_corners=True) 69 | b1 = F.interpolate(self.b1(x),x_size,mode='bilinear',align_corners=True) 70 | b2 = F.interpolate(self.b2(x),x_size,mode='bilinear',align_corners=True) 71 | b3 = F.interpolate(self.b3(x),x_size,mode='bilinear',align_corners=True) 72 | out = self.fus(torch.cat((b0,b1,b2,b3,x),1)) 73 | return out 74 | 75 | class CA(nn.Module): 76 | def __init__(self,in_ch): 77 | super(CA, self).__init__() 78 | self.avg_weight = nn.AdaptiveAvgPool2d(1) 79 | self.max_weight = nn.AdaptiveMaxPool2d(1) 80 | self.fus = nn.Sequential( 81 | nn.Conv2d(in_ch, in_ch // 2, 1, 1, 0), 82 | nn.ReLU(), 83 | nn.Conv2d(in_ch // 2, in_ch, 1, 1, 0), 84 | ) 85 | self.c_mask = nn.Sigmoid() 86 | def forward(self, x): 87 | avg_map_c = self.avg_weight(x) 88 | max_map_c = self.max_weight(x) 89 | c_mask = self.c_mask(torch.add(self.fus(avg_map_c), self.fus(max_map_c))) 90 | return torch.mul(x, c_mask) 91 | 92 | class FinalScore(nn.Module): 93 | def __init__(self): 94 | super(FinalScore, self).__init__() 95 | self.ca =CA(128) 96 | self.score = nn.Conv2d(128, 1, 1, 1, 0) 97 | def forward(self,f1,f2): 98 | f1 = torch.cat((f1,f2),1) 99 | f1 = self.ca(f1) 100 | score = self.score(f1) 101 | return score 102 | 103 | class Decoder(nn.Module): 104 | def __init__(self): 105 | super(Decoder, self).__init__() 106 | self.global_info =GlobalInfo() 107 | self.score_global = nn.Conv2d(512, 1, 1, 1, 
0) 108 | 109 | self.gfb4_1 = GFB(512,512) 110 | self.gfb3_1= GFB(512,256) 111 | self.gfb2_1= GFB(256,128) 112 | self.gfb1_1 = GFB(128, 64) # 1/2 113 | 114 | self.gfb4_2 = GFB(512, 512) #1/16 115 | self.gfb3_2 = GFB(512, 256)#1/8 116 | self.gfb2_2 = GFB(256, 128)#1/4 117 | self.gfb1_2 = GFB(128, 64) # 1/2 118 | 119 | self.score_1=nn.Conv2d(64, 1, 1, 1, 0) 120 | self.score_2 = nn.Conv2d(64, 1, 1, 1, 0) 121 | 122 | self.refine =FinalScore() 123 | 124 | 125 | def forward(self,rgb,t): 126 | xsize=rgb[0].size()[2:] 127 | global_info =self.global_info(rgb[4],t[4]) # 512 1/16 128 | d1=self.gfb4_1(global_info,t[4],rgb[3],global_info) 129 | d2=self.gfb4_2(global_info, rgb[4], t[3], global_info) 130 | #print(d1.shape,d2.shape) 131 | d3= self.gfb3_1(d1, d2,rgb[2],global_info) 132 | d4 = self.gfb3_2(d2, d1, t[2], global_info) 133 | d5 = self.gfb2_1(d3, d4, rgb[1], global_info) 134 | d6 = self.gfb2_2(d4, d3, t[1], global_info) #1/2 128 135 | d7 = self.gfb1_1(d5, d6, rgb[0], global_info) 136 | d8 = self.gfb1_2(d6, d5, t[0], global_info) # 1/2 128 137 | 138 | score_global = self.score_global(global_info) 139 | 140 | score1=self.score_1(d7) 141 | score2 = self.score_2(d8) 142 | score =self.refine(d7,d8) 143 | return score,score1,score2,score_global 144 | 145 | class Mnet(nn.Module): 146 | def __init__(self,train=False): 147 | super(Mnet,self).__init__() 148 | self.rgb_net= resnet50(pretrained=train) 149 | self.t_net= resnet50(pretrained=train) 150 | trans_layers_mapping = [[256,128],[512,256],[1024,512],[2048,512]] 151 | self.trans_rgb = nn.ModuleList() 152 | self.trans_t = nn.ModuleList() 153 | for mapp in trans_layers_mapping: 154 | self.trans_rgb.append(convblock(mapp[0],mapp[1],1,1,0)) 155 | self.trans_t.append(convblock(mapp[0], mapp[1], 1, 1, 0)) 156 | self.decoder=Decoder() 157 | 158 | for m in chain(self.decoder.modules(),chain(self.trans_rgb.modules(),self.trans_t.modules())): 159 | if isinstance(m, nn.Conv2d): 160 | m.weight.data.normal_(0, 0.01) 161 | elif isinstance(m, nn.BatchNorm2d): 162 | m.weight.data.fill_(1) 163 | m.bias.data.zero_() 164 | 165 | def forward(self,rgb,t): 166 | rgb_f=[] 167 | t_f = [] 168 | x = self.rgb_net.relu(self.rgb_net.bn1(self.rgb_net.conv1(rgb))) 169 | rgb_f.append(x) # 64 170 | x = self.rgb_net.layer1(self.rgb_net.maxpool(x)) 171 | rgb_f.append(self.trans_rgb[0](x)) #256->128 172 | x= self.rgb_net.layer2(x) #256->512 173 | rgb_f.append(self.trans_rgb[1](x)) # 512->256 174 | x = self.rgb_net.layer3(x) # 512->1024 175 | rgb_f.append(self.trans_rgb[2](x)) # 1024->512 176 | x = self.rgb_net.layer4(x) # 1024->2048 177 | rgb_f.append(self.trans_rgb[3](x)) # 2048->512 178 | 179 | x = self.t_net.relu(self.t_net.bn1(self.t_net.conv1(t))) 180 | t_f.append(x) # 64 181 | x = self.t_net.layer1(self.t_net.maxpool(x)) 182 | t_f.append(self.trans_t[0](x)) # 256->128 183 | x = self.t_net.layer2(x) # 256->512 184 | t_f.append(self.trans_t[1](x)) # 512->256 185 | x = self.t_net.layer3(x) # 512->1024 186 | t_f.append(self.trans_t[2](x)) # 1024->512 187 | x = self.t_net.layer4(x) # 1024->2048 188 | t_f.append(self.trans_t[3](x)) # 2048->512 189 | 190 | score,score1,score2,score_g =self.decoder(rgb_f,t_f) 191 | return score,score1,score2,score_g -------------------------------------------------------------------------------- /smooth_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | def laplacian_edge(img): 4 | laplacian_filter = torch.Tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]) 5 | 
filter = torch.reshape(laplacian_filter, [1, 1, 3, 3]) 6 | filter = filter.cuda() 7 | lap_edge = F.conv2d(img, filter, stride=1, padding=1) 8 | return lap_edge 9 | 10 | def gradient_x(img): 11 | sobel = torch.Tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]) 12 | filter = torch.reshape(sobel,[1,1,3,3]) 13 | filter = filter.cuda() 14 | gx = F.conv2d(img, filter, stride=1, padding=1) 15 | return gx 16 | 17 | 18 | def gradient_y(img): 19 | sobel = torch.Tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]) 20 | filter = torch.reshape(sobel, [1, 1,3,3]) 21 | filter = filter.cuda() 22 | gy = F.conv2d(img, filter, stride=1, padding=1) 23 | return gy 24 | 25 | def charbonnier_penalty(s): 26 | cp_s = torch.pow(torch.pow(s, 2) + 0.001**2, 0.5) 27 | return cp_s 28 | 29 | def get_saliency_smoothness(pred, gt, size_average=True): 30 | alpha = 10 31 | s1 = 10 32 | s2 = 0 33 | ## first oder derivative: sobel 34 | sal_x = torch.abs(gradient_x(pred)) 35 | sal_y = torch.abs(gradient_y(pred)) 36 | gt_x = gradient_x(gt) 37 | gt_y = gradient_y(gt) 38 | w_x = torch.exp(torch.abs(gt_x) * (-alpha)) 39 | w_y = torch.exp(torch.abs(gt_y) * (-alpha)) 40 | cps_x = charbonnier_penalty(sal_x * w_x) 41 | cps_y = charbonnier_penalty(sal_y * w_y) 42 | cps_xy = cps_x + cps_y 43 | 44 | ## second order derivative: laplacian 45 | lap_sal = torch.abs(laplacian_edge(pred)) 46 | lap_gt = torch.abs(laplacian_edge(gt)) 47 | weight_lap = torch.exp(lap_gt * (-alpha)) 48 | weighted_lap = charbonnier_penalty(lap_sal*weight_lap) 49 | 50 | smooth_loss = s1*torch.mean(cps_xy) + s2*torch.mean(weighted_lap) 51 | 52 | return smooth_loss 53 | 54 | class smoothness_loss(torch.nn.Module): 55 | def __init__(self, size_average = True): 56 | super(smoothness_loss, self).__init__() 57 | self.size_average = size_average 58 | 59 | def forward(self, pred, target): 60 | 61 | return get_saliency_smoothness(pred, target, self.size_average) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.data import DataLoader 3 | from lib.dataset import Data 4 | import torch.nn.functional as F 5 | import torch 6 | import cv2 7 | import time 8 | from net import Mnet 9 | import numpy as np 10 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 11 | if __name__ == '__main__': 12 | model_path='./model/final.pth' 13 | out_path = './output' 14 | data = Data(root='/data/to/test',mode='test') 15 | loader = DataLoader(data, batch_size=1,shuffle=False) 16 | net = Mnet().cuda() 17 | print('loading model from %s...' 
% model_path) 18 | net.load_state_dict(torch.load(model_path)) 19 | if not os.path.exists(out_path): os.mkdir(out_path) 20 | time_s = time.time() 21 | img_num = len(loader) 22 | net.eval() 23 | with torch.no_grad(): 24 | for rgb, t, _ , (H, W), name in loader: 25 | name = name[0].split('\\')[-1] 26 | print(name) 27 | score, score1, score2,score_g = net(rgb.cuda().float(), t.cuda().float()) 28 | score = F.interpolate(score, size=(H, W), mode='bilinear',align_corners=True) 29 | pred = np.squeeze(torch.sigmoid(score).cpu().data.numpy()) 30 | cv2.imwrite(os.path.join(out_path, name[:-4] + '.png'), 255 * pred) 31 | time_e = time.time() 32 | print('speed: %f FPS' % (img_num / (time_e - time_s))) 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | coding='utf-8' 2 | import os 3 | from net import Mnet 4 | import torch 5 | import random 6 | import numpy as np 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | from lib.dataset import Data 10 | from lib.data_prefetcher import DataPrefetcher 11 | from torch.nn import functional as F 12 | from smooth_loss import get_saliency_smoothness 13 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 14 | def my_loss1(score,score1,score2,score_g,label): 15 | sal_loss2 = F.binary_cross_entropy_with_logits(score1, label, reduction='mean') 16 | sal_loss3 = F.binary_cross_entropy_with_logits(score2, label, reduction='mean') 17 | sal_loss1 = F.binary_cross_entropy_with_logits(score, label, reduction='mean') 18 | sml = get_saliency_smoothness(torch.sigmoid(score),label) 19 | label_g = F.interpolate(label, score_g.shape[2:], mode='bilinear', align_corners=True) 20 | sal_loss_g = F.binary_cross_entropy_with_logits(score_g, label_g, reduction='mean') 21 | return sal_loss1 + sal_loss2 + sal_loss3 + 0.5*sml + sal_loss_g 22 | 23 | if __name__ == '__main__': 24 | random.seed(118) 25 | np.random.seed(118) 26 | torch.manual_seed(118) 27 | torch.cuda.manual_seed(118) 28 | torch.cuda.manual_seed_all(118) 29 | 30 | # dataset 31 | img_root = '/data/to/train' 32 | save_path = './model' 33 | if not os.path.exists(save_path): os.mkdir(save_path) 34 | lr = 0.001 35 | batch_size = 4 36 | epoch = 100 37 | lr_dec=[21,51] 38 | data = Data(img_root) 39 | loader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=1) 40 | net = Mnet().cuda() 41 | net.load_pretrained_model() 42 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=0.0005,momentum=0.9) 43 | iter_num = len(loader) 44 | net.train() 45 | 46 | for epochi in range(1, epoch + 1): 47 | if epochi in lr_dec : 48 | lr=lr/10 49 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=0.0005,momentum=0.9) 50 | print(lr) 51 | prefetcher = DataPrefetcher(loader) 52 | rgb, t, label = prefetcher.next() 53 | r_sal_loss = 0 54 | net.zero_grad() 55 | i = 0 56 | while rgb is not None: 57 | i+=1 58 | score, score1, score2,g= net(rgb, t) 59 | sal_loss= my_loss1( score,score1,score2,g,label) 60 | r_sal_loss += sal_loss.data 61 | sal_loss.backward() 62 | optimizer.step() 63 | optimizer.zero_grad() 64 | if i % 100 == 0: 65 | print('epoch: [%2d/%2d], iter: [%5d/%5d] || loss : %5.4f' % ( 66 | epochi, epoch, i, iter_num, r_sal_loss / 100)) 67 | r_sal_loss = 0 68 | rgb, t, label = prefetcher.next() 69 | if epochi %5 ==0: 70 | torch.save(net.state_dict(), '%s/epoch_%d.pth' % (save_path, epochi)) 71 
| torch.save(net.state_dict(), '%s/final.pth' % (save_path)) -------------------------------------------------------------------------------- /train_resnet50.py: -------------------------------------------------------------------------------- 1 | coding='utf-8' 2 | import os 3 | from net_resnet50 import Mnet 4 | import torch 5 | import random 6 | import numpy as np 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | from lib.dataset import Data 10 | from lib.data_prefetcher import DataPrefetcher 11 | from torch.nn import functional as F 12 | from smooth_loss import get_saliency_smoothness 13 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 14 | 15 | def my_loss1(score,score1,score2,score_g,label): 16 | score = F.interpolate(score, label.shape[2:], mode='bilinear', align_corners=True) 17 | score1= F.interpolate(score1, label.shape[2:], mode='bilinear', align_corners=True) 18 | score2 = F.interpolate(score2, label.shape[2:], mode='bilinear', align_corners=True) 19 | sal_loss2 = F.binary_cross_entropy_with_logits(score1, label, reduction='mean') 20 | sal_loss3 = F.binary_cross_entropy_with_logits(score2, label, reduction='mean') 21 | sal_loss1 = F.binary_cross_entropy_with_logits(score, label, reduction='mean') 22 | sml = get_saliency_smoothness(torch.sigmoid(score),label) 23 | label_g = F.interpolate(label, score_g.shape[2:], mode='bilinear', align_corners=True) 24 | sal_loss_g = F.binary_cross_entropy_with_logits(score_g, label_g, reduction='mean') 25 | return sal_loss1 + sal_loss2 + sal_loss3 + 0.5*sml + sal_loss_g 26 | 27 | 28 | if __name__ == '__main__': 29 | random.seed(118) 30 | np.random.seed(118) 31 | torch.manual_seed(118) 32 | torch.cuda.manual_seed(118) 33 | torch.cuda.manual_seed_all(118) 34 | # dataset 35 | img_root = '/data/to/train' 36 | save_path = './model' 37 | if not os.path.exists(save_path): os.mkdir(save_path) 38 | lr = 0.001 39 | batch_size = 4 40 | epoch = 100 41 | lr_dec=[21,51] 42 | data = Data(img_root) 43 | loader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=1) 44 | net = Mnet(train=True).cuda() 45 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=0.0005,momentum=0.9) 46 | iter_num = len(loader) 47 | net.train() 48 | for epochi in range(1, epoch + 1): 49 | if epochi in lr_dec : 50 | lr=lr/10 51 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=0.0005,momentum=0.9) 52 | print(lr) 53 | prefetcher = DataPrefetcher(loader) 54 | rgb, t, label = prefetcher.next() 55 | r_sal_loss = 0 56 | net.zero_grad() 57 | i = 0 58 | while rgb is not None: 59 | i+=1 60 | score, score1, score2,g= net(rgb, t) 61 | sal_loss= my_loss1( score,score1,score2,g,label) 62 | r_sal_loss += sal_loss.data 63 | sal_loss.backward() 64 | optimizer.step() 65 | optimizer.zero_grad() 66 | if i % 100 == 0: 67 | print('epoch: [%2d/%2d], iter: [%5d/%5d] || loss : %5.4f' % ( 68 | epochi, epoch, i, iter_num, r_sal_loss / 100)) 69 | r_sal_loss = 0 70 | rgb, t, label = prefetcher.next() 71 | if epochi %5 ==0: 72 | torch.save(net.state_dict(), '%s/epoch_%d.pth' % (save_path, epochi)) 73 | torch.save(net.state_dict(), '%s/final.pth' % (save_path)) -------------------------------------------------------------------------------- /train_resnet50_4.py: -------------------------------------------------------------------------------- 1 | coding='utf-8' 2 | import os 3 | from net_resnet50_4 import Mnet 4 | import torch 5 | import random 6 | import numpy as np 7 | import torch.optim 
as optim 8 | from torch.utils.data import DataLoader 9 | from lib.dataset import Data 10 | from lib.data_prefetcher import DataPrefetcher 11 | from torch.nn import functional as F 12 | from smooth_loss import get_saliency_smoothness 13 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 14 | def my_loss1(score,score1,score2,score_g,label): 15 | score = F.interpolate(score, label.shape[2:], mode='bilinear', align_corners=True) 16 | score1= F.interpolate(score1, label.shape[2:], mode='bilinear', align_corners=True) 17 | score2 = F.interpolate(score2, label.shape[2:], mode='bilinear', align_corners=True) 18 | sal_loss2 = F.binary_cross_entropy_with_logits(score1, label, reduction='mean') 19 | sal_loss3 = F.binary_cross_entropy_with_logits(score2, label, reduction='mean') 20 | sal_loss1 = F.binary_cross_entropy_with_logits(score, label, reduction='mean') 21 | sml = get_saliency_smoothness(torch.sigmoid(score),label) 22 | label_g = F.interpolate(label, score_g.shape[2:], mode='bilinear', align_corners=True) 23 | sal_loss_g = F.binary_cross_entropy_with_logits(score_g, label_g, reduction='mean') 24 | return sal_loss1 + sal_loss2 + sal_loss3 + 0.5*sml + sal_loss_g 25 | if __name__ == '__main__': 26 | random.seed(118) 27 | np.random.seed(118) 28 | torch.manual_seed(118) 29 | torch.cuda.manual_seed(118) 30 | torch.cuda.manual_seed_all(118) 31 | # dataset 32 | img_root = '/data/to/train' 33 | save_path = './model' 34 | if not os.path.exists(save_path): os.mkdir(save_path) 35 | lr = 0.001 36 | batch_size = 4 37 | epoch = 100 38 | lr_dec=[21,51] 39 | data = Data(img_root) 40 | loader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=1) 41 | net = Mnet(train=True).cuda() 42 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=0.0005,momentum=0.9) 43 | iter_num = len(loader) 44 | net.train() 45 | for epochi in range(1, epoch + 1): 46 | if epochi in lr_dec : 47 | lr=lr/10 48 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=0.0005,momentum=0.9) 49 | print(lr) 50 | prefetcher = DataPrefetcher(loader) 51 | rgb, t, label = prefetcher.next() 52 | r_sal_loss = 0 53 | net.zero_grad() 54 | i = 0 55 | while rgb is not None: 56 | i+=1 57 | score, score1, score2,g= net(rgb, t) 58 | sal_loss= my_loss1( score,score1,score2,g,label) 59 | r_sal_loss += sal_loss.data 60 | sal_loss.backward() 61 | optimizer.step() 62 | optimizer.zero_grad() 63 | if i % 100 == 0: 64 | print('epoch: [%2d/%2d], iter: [%5d/%5d] || loss : %5.4f' % ( 65 | epochi, epoch, i, iter_num, r_sal_loss / 100)) 66 | r_sal_loss = 0 67 | rgb, t, label = prefetcher.next() 68 | if epochi %5 ==0: 69 | torch.save(net.state_dict(), '%s/epoch_%d.pth' % (save_path, epochi)) 70 | torch.save(net.state_dict(), '%s/final.pth' % (save_path)) -------------------------------------------------------------------------------- /vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | # vgg16 3 | def vgg(cfg, i, batch_norm=False): 4 | layers = [] 5 | in_channels = i 6 | stage = 1 7 | for v in cfg: 8 | if v == 'M': 9 | stage += 1 10 | if stage == 6: 11 | layers += [nn.MaxPool2d(kernel_size=3, stride=1, padding=1)] 12 | else: 13 | layers += [nn.MaxPool2d(kernel_size=3, stride=2, padding=1)] 14 | else: 15 | if stage == 6: 16 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 17 | else: 18 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 19 | if batch_norm: 20 | layers += [conv2d, 
nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 21 | else: 22 | layers += [conv2d, nn.ReLU(inplace=True)] 23 | in_channels = v 24 | return layers 25 | 26 | class a_vgg16(nn.Module): 27 | def __init__(self): 28 | super(a_vgg16, self).__init__() 29 | self.cfg =[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'] 30 | self.extract = [3, 8, 15, 22, 29] # [ 8, 15, 22, 29] 31 | # 64:1 -->128:1/2 -->256:1/4 -->512 :1/8 --> 512:1/16 -->M-> 512,1/16 32 | self.base = nn.ModuleList(vgg(self.cfg, 3)) 33 | for m in self.modules(): 34 | if isinstance(m, nn.Conv2d): 35 | m.weight.data.normal_(0, 0.01) 36 | elif isinstance(m, nn.BatchNorm2d): 37 | m.weight.data.fill_(1) 38 | m.bias.data.zero_() 39 | 40 | def forward(self, x): 41 | tmp_x = [] 42 | for k in range(len(self.base)): 43 | x = self.base[k](x) 44 | if k in self.extract: 45 | tmp_x.append(x) #collect feature maps 1(64) 1/2(128) 1/4(256) 1/8(512) 1/16(512) 46 | return tmp_x 47 | 48 | --------------------------------------------------------------------------------
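A minimal usage sketch (not a file in the repository): assuming PyTorch is installed and the script is run from the repository root so that net.py and vgg.py are importable, the snippet below feeds random 352x352 tensors (the resolution set in lib/dataset.py) through a randomly initialized Mnet on the CPU. The file name sanity_check.py and the fake inputs are illustrative only; without the pretrained weights the predicted maps are meaningless, but the printed shapes show how the pieces fit together: each a_vgg16 backbone returns a five-level feature pyramid (64/128/256/512/512 channels at strides 1/2/4/8/16) that the multi-interactive dual decoder consumes.

# sanity_check.py -- illustrative sketch, not part of this repository.
import torch
from net import Mnet

if __name__ == '__main__':
    net = Mnet()   # random init; net.load_pretrained_model() would load vgg16.pth if it is present
    net.eval()

    # Fake RGB / thermal pair at the 352x352 resolution used by lib/dataset.py.
    rgb = torch.randn(1, 3, 352, 352)
    t = torch.randn(1, 3, 352, 352)

    with torch.no_grad():
        # One backbone: five feature maps at 1, 1/2, 1/4, 1/8, 1/16 of the input resolution.
        rgb_f = net.rgb_net(rgb)
        print([tuple(f.shape) for f in rgb_f])

        # Full model: refined score, two decoder-branch scores, and the global score.
        score, score1, score2, score_g = net(rgb, t)
        print(score.shape, score1.shape, score2.shape, score_g.shape)
        # -> (1, 1, 352, 352) for score/score1/score2 and (1, 1, 22, 22) for score_g

Here score1 and score2 are the side outputs of the RGB-led and thermal-led decoder branches, score_g is the coarse prediction from the global-context block, and all four outputs are supervised together during training (see my_loss1 in train.py).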