├── image
│   └── logo.png
├── utils_downstream
│   ├── config.py
│   ├── misc.py
│   ├── test_data.py
│   ├── utils.py
│   ├── ssim_loss.py
│   ├── dataset_rgbd_strategy2.py
│   └── saliency_metric.py
├── get_contour.py
├── utils_ssl
│   ├── joint_transforms.py
│   ├── datasets_stage2.py
│   ├── datasets_stage1.py
│   └── misc.py
├── test_score.py
├── README.md
├── prediction_rgbd.py
├── model
│   ├── model_stage1.py
│   ├── model_stage2.py
│   └── model_stage3.py
├── train_stage3_downstream.py
├── train_stage2_pretext2.py
└── train_stage1_pretext1.py
/image/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xiaoqi-Zhao-DLUT/SSLSOD/HEAD/image/logo.png -------------------------------------------------------------------------------- /utils_downstream/config.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import os 3 | 4 | dutrgbd_root_test = '' 5 | njud_root_test = '' 6 | nlpr_root_test = '' 7 | stere_root_test ='' 8 | sip_root_test = '' 9 | rgbd135_root_test = '' 10 | ssd_root_test = '' 11 | lfsd_root_test = '' 12 | 13 | 14 | dutrgbd = os.path.join(dutrgbd_root_test) 15 | njud = os.path.join(njud_root_test) 16 | nlpr = os.path.join(nlpr_root_test) 17 | stere = os.path.join(stere_root_test) 18 | sip = os.path.join(sip_root_test) 19 | rgbd135 = os.path.join(rgbd135_root_test) 20 | ssd = os.path.join(ssd_root_test) 21 | lfsd = os.path.join(lfsd_root_test) 22 | 23 | 24 | SSLSOD = '' # model 25 | 26 | RGBD_SOD_Models = {'SSLSOD':SSLSOD} 27 | -------------------------------------------------------------------------------- /get_contour.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from torchvision import transforms 5 | import cv2 6 | test = '' 7 | to_test = {'contour':test} 8 | img_transform = transforms.Compose([ 9 | transforms.ToTensor()]) 10 | save_path = '' 11 | to_pil = transforms.ToPILImage() 12 | 13 | for name, root in to_test.items(): 14 | root1 = os.path.join(root) 15 | img_list = [os.path.splitext(f)[0] for f in os.listdir(root1) if f.endswith('.png')] 16 | for idx, img_name in enumerate(img_list): 17 | print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list))) 18 | img1 = Image.open(os.path.join(root, img_name + '.png')).convert('L') 19 | img1 = np.array(img1) 20 | kernel = np.ones((5,5),np.uint8) 21 | img2 = cv2.erode(img1,kernel) 22 | img3 = cv2.dilate(img1,kernel) 23 | img = np.array(img3-img2) 24 | img[img >= 6] = 255 25 | img[img<6] = 0 26 | cv2.imwrite(os.path.join(save_path, img_name + '.jpg'), img,[int(cv2.IMWRITE_JPEG_LUMA_QUALITY),50]) 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /utils_downstream/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pydensecrf.densecrf as dcrf 4 | 5 | 6 | class AvgMeter(object): 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | 22 | 23 | def check_mkdir(dir_name): 24 | if not os.path.isdir(dir_name): 25 | os.makedirs(dir_name) 26 | 27 | def crf_refine(img, annos): 28 | def _sigmoid(x): 29 | return 1 / (1 + np.exp(-x)) 30 | 31 | assert img.dtype == np.uint8 32 | assert annos.dtype == 
np.uint8 33 | print(img.shape[:2],annos.shape) 34 | assert img.shape[:2] == annos.shape 35 | 36 | # img and annos should be np array with data type uint8 37 | 38 | EPSILON = 1e-8 39 | 40 | M = 2 # salient or not 41 | tau = 1.05 42 | # Setup the CRF model 43 | d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M) 44 | 45 | anno_norm = annos / 255. 46 | 47 | n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm)) 48 | p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm)) 49 | 50 | U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32') # unary term U, same spatial size as the input image 51 | U[0, :] = n_energy.flatten() 52 | U[1, :] = p_energy.flatten() 53 | 54 | d.setUnaryEnergy(U) 55 | 56 | d.addPairwiseGaussian(sxy=3, compat=3) 57 | d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5) 58 | 59 | # Do the inference 60 | infer = np.array(d.inference(1)).astype('float32') 61 | res = infer[1, :] 62 | 63 | res = res * 255 64 | res = res.reshape(img.shape[:2]) # reshape back to the input image size 65 | return res.astype('uint8') 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /utils_ssl/joint_transforms.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import random 3 | from PIL import Image, ImageOps 4 | 5 | Image.MAX_IMAGE_PIXELS = 1000000000 6 | 7 | 8 | class Compose(object): 9 | def __init__(self, transforms): 10 | self.transforms = transforms 11 | def __call__(self, img, depth, mask): 12 | assert img.size == mask.size 13 | for t in self.transforms: 14 | img, depth, mask = t(img, depth, mask) 15 | return img, depth, mask 16 | 17 | 18 | 19 | class RandomCrop(object): 20 | def __init__(self, size,size1, padding=0): 21 | if isinstance(size, numbers.Number): 22 | self.size = (int(size), int(size1)) 23 | else: 24 | self.size = size 25 | self.padding = padding 26 | 27 | def __call__(self, img, depth, mask): 28 | if self.padding > 0: 29 | img = ImageOps.expand(img, border=self.padding, fill=0) 30 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 31 | depth = ImageOps.expand(depth, border=self.padding, fill=0) 32 | 33 | assert img.size == mask.size 34 | w, h = img.size 35 | th, tw = self.size 36 | if w == tw and h == th: 37 | return img, depth, mask 38 | if w < tw or h < th: 39 | return img.resize((tw, th), Image.BILINEAR), depth.resize((tw, th), Image.NEAREST), mask.resize((tw, th), Image.NEAREST) 40 | return img.resize((tw, th), Image.BILINEAR), depth.resize((tw, th), Image.NEAREST), mask.resize((tw, th), Image.NEAREST) 41 | 42 | class RandomHorizontallyFlip(object): 43 | def __call__(self, img, depth, mask): 44 | if random.random() < 0.5: 45 | return img.transpose(Image.FLIP_LEFT_RIGHT), depth.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT) 46 | return img, depth, mask 47 | 48 | 49 | class RandomRotate(object): 50 | def __init__(self, degree): 51 | self.degree = degree 52 | 53 | def __call__(self, img, depth, mask): 54 | rotate_degree = random.random() * 2 * self.degree - self.degree 55 | return img.rotate(rotate_degree, Image.BILINEAR), depth.rotate(rotate_degree, Image.NEAREST), mask.rotate(rotate_degree, Image.NEAREST) -------------------------------------------------------------------------------- /utils_downstream/test_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | 6 | 
class test_dataset: 7 | def __init__(self, image_root, gt_root): 8 | self.img_list_1 = [os.path.splitext(f)[0] for f in os.listdir(image_root) if f.endswith('.png') or f.endswith('.jpg') or f.endswith('.bmp')] 9 | self.img_list_2 = [os.path.splitext(f)[0] for f in os.listdir(gt_root) if f.endswith('.png') or f.endswith('.jpg') or f.endswith('.bmp')] 10 | self.img_list = list(set(self.img_list_1).intersection(set(self.img_list_2))) 11 | 12 | self.image_root = image_root 13 | self.gt_root = gt_root 14 | self.transform = transforms.Compose([ 15 | transforms.ToTensor(), 16 | ]) 17 | self.gt_transform = transforms.ToTensor() 18 | self.size = len(self.img_list) 19 | self.index = 0 20 | 21 | def load_data(self): 22 | #image = self.rgb_loader(self.images[self.index]) 23 | rgb_png_path = os.path.join(self.image_root,self.img_list[self.index]+ '.png') 24 | rgb_jpg_path = os.path.join(self.image_root,self.img_list[self.index]+ '.jpg') 25 | rgb_bmp_path = os.path.join(self.image_root,self.img_list[self.index]+ '.bmp') 26 | if os.path.exists(rgb_png_path): 27 | image = self.binary_loader(rgb_png_path) 28 | elif os.path.exists(rgb_jpg_path): 29 | image = self.binary_loader(rgb_jpg_path) 30 | else: 31 | image = self.binary_loader(rgb_bmp_path) 32 | if os.path.exists(os.path.join(self.gt_root,self.img_list[self.index] + '.png')): 33 | gt = self.binary_loader(os.path.join(self.gt_root,self.img_list[self.index] + '.png')) 34 | elif os.path.exists(os.path.join(self.gt_root,self.img_list[self.index] + '.jpg')): 35 | gt = self.binary_loader(os.path.join(self.gt_root,self.img_list[self.index] + '.jpg')) 36 | else: 37 | gt = self.binary_loader(os.path.join(self.gt_root, self.img_list[self.index] + '.bmp')) 38 | 39 | self.index += 1 40 | return image, gt 41 | 42 | def rgb_loader(self, path): 43 | with open(path, 'rb') as f: 44 | img = Image.open(f) 45 | return img.convert('RGB') 46 | 47 | def binary_loader(self, path): 48 | with open(path, 'rb') as f: 49 | img = Image.open(f) 50 | return img.convert('L') 51 | 52 | -------------------------------------------------------------------------------- /utils_downstream/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | 5 | def clip_gradient(optimizer, grad_clip): 6 | for group in optimizer.param_groups: 7 | for param in group['params']: 8 | if param.grad is not None: 9 | param.grad.data.clamp_(-grad_clip, grad_clip) 10 | 11 | 12 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=5): 13 | decay = decay_rate ** (epoch // decay_epoch) 14 | for param_group in optimizer.param_groups: 15 | param_group['lr'] *= decay 16 | 17 | 18 | def truncated_normal_(tensor, mean=0, std=1): 19 | size = tensor.shape 20 | tmp = tensor.new_empty(size + (4,)).normal_() 21 | valid = (tmp < 2) & (tmp > -2) 22 | ind = valid.max(-1, keepdim=True)[1] 23 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 24 | tensor.data.mul_(std).add_(mean) 25 | 26 | def init_weights(m): 27 | if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d: 28 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 29 | #nn.init.normal_(m.weight, std=0.001) 30 | #nn.init.normal_(m.bias, std=0.001) 31 | truncated_normal_(m.bias, mean=0, std=0.001) 32 | 33 | def init_weights_orthogonal_normal(m): 34 | if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d: 35 | nn.init.orthogonal_(m.weight) 36 | truncated_normal_(m.bias, mean=0, std=0.001) 37 | #nn.init.normal_(m.bias, std=0.001) 
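# Usage sketch (added; not part of the original utils.py). Both initializers
# index m.bias directly, so they assume conv layers created with bias=True:
#     net = RGBD_sal()                            # e.g. model.model_stage3.RGBD_sal
#     net.apply(init_weights)                     # Kaiming weights + truncated-normal biases
#     net.apply(init_weights_orthogonal_normal)   # orthogonal-weight variant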
38 | 39 | def l2_regularisation(m): 40 | l2_reg = None 41 | 42 | for W in m.parameters(): 43 | if l2_reg is None: 44 | l2_reg = W.norm(2) 45 | else: 46 | l2_reg = l2_reg + W.norm(2) 47 | return l2_reg 48 | 49 | class AvgMeter(object): 50 | def __init__(self, num=40): 51 | self.num = num 52 | self.reset() 53 | 54 | def reset(self): 55 | self.val = 0 56 | self.avg = 0 57 | self.sum = 0 58 | self.count = 0 59 | self.losses = [] 60 | 61 | def update(self, val, n=1): 62 | self.val = val 63 | self.sum += val * n 64 | self.count += n 65 | self.avg = self.sum / self.count 66 | self.losses.append(val) 67 | 68 | def show(self): 69 | a = len(self.losses) 70 | b = np.maximum(a-self.num, 0) 71 | c = self.losses[b:] 72 | #print(c) 73 | #d = torch.mean(torch.stack(c)) 74 | #print(d) 75 | return torch.mean(torch.stack(c)) 76 | 77 | # def save_mask_prediction_example(mask, pred, iter): 78 | # plt.imshow(pred[0,:,:],cmap='Greys') 79 | # plt.savefig('images/'+str(iter)+"_prediction.png") 80 | # plt.imshow(mask[0,:,:],cmap='Greys') 81 | # plt.savefig('images/'+str(iter)+"_mask.png") -------------------------------------------------------------------------------- /test_score.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from utils_downstream.test_data import test_dataset 4 | from utils_downstream.saliency_metric import cal_mae,cal_fm,cal_sm,cal_em,cal_wfm, cal_dice, cal_iou,cal_ber,cal_acc 5 | from utils_downstream.config import dutrgbd,njud,nlpr,stere,sip,rgbd135,ssd,lfsd 6 | from utils_downstream.config import RGBD_SOD_Models 7 | from tqdm import tqdm 8 | 9 | test_datasets = {'DUT-RGBD':dutrgbd,'NJU2K':njud,'NJUD':njud,'NLPR':nlpr,'STERE':stere,'SIP':sip,'DES':rgbd135,'RGBD135':rgbd135,'SSD':ssd,'LFSD':lfsd} 10 | for method_name,method_map_root in RGBD_SOD_Models.items(): 11 | print(method_name) 12 | for name, root in test_datasets.items(): 13 | print(name) 14 | sal_root = method_map_root +name 15 | print(sal_root) 16 | gt_root = root+'/GT' 17 | print(gt_root) 18 | if os.path.exists(sal_root): 19 | test_loader = test_dataset(sal_root, gt_root) 20 | mae,fm,sm,em,wfm, m_dice, m_iou,ber,acc= cal_mae(),cal_fm(test_loader.size),cal_sm(),cal_em(),cal_wfm(), cal_dice(), cal_iou(),cal_ber(),cal_acc() 21 | for i in tqdm(range(test_loader.size)): 22 | # print ('predicting for %d / %d' % ( i + 1, test_loader.size)) 23 | sal, gt = test_loader.load_data() 24 | if sal.size != gt.size: 25 | x, y = gt.size 26 | sal = sal.resize((x, y)) 27 | gt = np.asarray(gt, np.float32) 28 | gt /= (gt.max() + 1e-8) 29 | gt[gt > 0.5] = 1 30 | gt[gt != 1] = 0 31 | res = sal 32 | res = np.array(res) 33 | if res.max() == res.min(): 34 | res = res/255 35 | else: 36 | res = (res - res.min()) / (res.max() - res.min()) 37 | #二值化会提升mae和meanf,em 38 | # res[res > 0.5] = 1 39 | # res[res != 1] = 0 40 | 41 | mae.update(res, gt) 42 | sm.update(res,gt) 43 | fm.update(res, gt) 44 | em.update(res,gt) 45 | wfm.update(res,gt) 46 | m_dice.update(res,gt) 47 | m_iou.update(res,gt) 48 | ber.update(res,gt) 49 | acc.update(res,gt) 50 | 51 | MAE = mae.show() 52 | maxf,meanf,_,_ = fm.show() 53 | sm = sm.show() 54 | em = em.show() 55 | wfm = wfm.show() 56 | m_dice = m_dice.show() 57 | m_iou = m_iou.show() 58 | ber = ber.show() 59 | acc = acc.show() 60 | print('method_name: {} dataset: {} MAE: {:.4f} Ber: {:.4f} maxF: {:.4f} avgF: {:.4f} wfm: {:.4f} Sm: {:.4f} Em: {:.4f} M_dice: {:.4f} M_iou: {:.4f} Acc: {:.4f}'.format(method_name,name, MAE,ber, maxf,meanf,wfm,sm,em, m_dice, m_iou,acc)) 
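test_score.py only runs once the empty paths in utils_downstream/config.py are filled in. A minimal sketch with hypothetical local paths: test_score.py builds sal_root = method_map_root + name and gt_root = root + '/GT', so the model root needs a trailing slash and each dataset root needs a GT subfolder.

```python
# Hypothetical example values for utils_downstream/config.py (shipped empty).
dutrgbd_root_test = '/data/RGBD_SOD/DUT-RGBD'      # expects /data/RGBD_SOD/DUT-RGBD/GT
njud_root_test = '/data/RGBD_SOD/NJUD'             # one root per benchmark, same layout
SSLSOD = './saved_model/imagenet_based_model-50/'  # trailing slash: sal_root = SSLSOD + 'DUT-RGBD'
```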
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SSLSOD 2 | 3 | ![Logo](image/logo.png) 4 | 5 | ### Self-Supervised Pretraining for RGB-D Salient Object Detection 6 | 7 | Xiaoqi Zhao, Youwei Pang, Lihe Zhang, Huchuan Lu, Xiang Ruan 8 | 9 | ⭐ arXiv » / :fire: Slide & 极市平台 post 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | The official repo of the AAAI 2022 paper, Self-Supervised Pretraining for RGB-D Salient Object Detection. 18 | ## Saliency map 19 | [Google Drive](https://drive.google.com/file/d/1i5OElgml76p76N2l9eYlFc4Bm8jOlxUk/view?usp=sharing) / [BaiduYunPan(d63j)](https://pan.baidu.com/s/1qifMM7wgR5gPhb6ZlRU9Zw) 20 | ## Trained Model 21 | You can download all trained models at [Google Drive](https://drive.google.com/file/d/1mxX4yk6yOCTapJ_dn_5nZhnb8IpvzEt0/view?usp=sharing) / [BaiduYunPan(0401)](https://pan.baidu.com/s/1zruPGxeR-7j4bfNSrzU5Lg). 22 | ## Datasets 23 | * [Google Drive](https://drive.google.com/file/d/1Nxm8wr2jSW-Ntqu8cdm4GfZPVOClJbZE/view?usp=sharing) / [BaiduYunPan(1a4t)](https://pan.baidu.com/s/1DUHzxs4JP4hzWJIoz4Lqyg) 24 | * We believe that pre-training on a large amount of RGB-D data would yield an even stronger SSL-based model, one that may surpass the ImageNet-based model. This [survey](https://arxiv.org/pdf/2201.05761.pdf) of RGB-D datasets may be helpful to you. 25 | ## Training 26 | * SSL-based model 27 | 1. Run train_stage1_pretext1.py 28 | 2. Run get_contour.py (generates the depth-contour maps for the stage-2 training) 29 | 3. Load the pretext1 weights for Crossmodal_Autoendoer (model_stage1.py) and run train_stage2_pretext2.py 30 | 4. Load the pretext1 and pretext2 weights for RGBD_sal (model_stage3.py) as initialization and run train_stage3_downstream.py 31 | * ImageNet-based model 32 | Set pretrained=True for models.vgg16_bn in RGBD_sal (model_stage3.py) and run train_stage3_downstream.py 33 | ## Testing 34 | Run prediction_rgbd.py (generates the predicted saliency maps) 35 | Run test_score.py (evaluates the predicted saliency maps in terms of Fmax, Fmean, wFm, Sm, Em, MAE, mDice, mIoU, BER and Acc). 36 | ## BibTex 37 | ``` 38 | @inproceedings{SSLSOD, 39 | title={Self-Supervised Pretraining for RGB-D Salient Object Detection}, 40 | author={Zhao, Xiaoqi and Pang, Youwei and Zhang, Lihe and Lu, Huchuan and Ruan, Xiang}, 41 | booktitle={AAAI}, 42 | year={2022} 43 | } 44 | ``` 45 | -------------------------------------------------------------------------------- /utils_ssl/datasets_stage2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import torch.utils.data as data 4 | from PIL import Image 5 | import random 6 | import torch 7 | from torch.nn.functional import interpolate 8 | Image.MAX_IMAGE_PIXELS = 1000000000 9 | 10 | 11 | 12 | class ImageFolder(data.Dataset): 13 | def __init__(self, image_root, gt_root,depth_root, joint_transform=None, transform=None, target_transform=None): 14 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')or f.endswith('.bmp')] 15 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 16 | or f.endswith('.png')or f.endswith('.bmp')] 17 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.jpg') 18 | or f.endswith('.png')or f.endswith('.bmp')] 19 | self.images = sorted(self.images) 20 | self.gts = sorted(self.gts) 21 | self.depths = sorted(self.depths) 22 | 23 | self.joint_transform = joint_transform 24 | self.transform = transform 25 | self.target_transform = target_transform 26 | 27 | def __getitem__(self, index): 28 | img_path = self.images[index] 29 | gt_path = self.gts[index] 30 | depth_path = self.depths[index] 31 | 32 | img = Image.open(img_path).convert('RGB') 33 | gt = Image.open(gt_path).convert('L') 34 | depth = 
Image.open(depth_path).convert('L') 35 | 36 | 37 | if self.joint_transform is not None: 38 | img_raw, depth, gt = self.joint_transform(img, depth, gt) 39 | if self.transform is not None: 40 | img = self.transform(img_raw) 41 | if self.target_transform is not None: 42 | depth = self.target_transform(depth) 43 | gt = self.target_transform(gt) 44 | 45 | return img, depth, gt 46 | 47 | def __len__(self): 48 | return len(self.images) 49 | 50 | 51 | class ImageFolder_multi_scale(data.Dataset): 52 | def __init__(self, root, joint_transform=None, transform=None, target_transform=None): 53 | self.root = root 54 | self.imgs = make_dataset(root) 55 | self.joint_transform = joint_transform 56 | self.transform = transform 57 | self.target_transform = target_transform 58 | def __getitem__(self, index): 59 | img_path, gt_path = self.imgs[index] 60 | img = Image.open(img_path).convert('RGB') 61 | target = Image.open(gt_path).convert('L') 62 | 63 | if self.joint_transform is not None: 64 | img, target = self.joint_transform(img, target) 65 | if self.transform is not None: 66 | img = self.transform(img) 67 | if self.target_transform is not None: 68 | target = self.target_transform(target) 69 | 70 | return img, target 71 | 72 | def __len__(self): 73 | return len(self.imgs) 74 | 75 | 76 | ####可用可不用. GateNet论文中没有使用multi-scale to train 77 | def collate(self,batch): 78 | # size = [224, 256, 288, 320, 352][np.random.randint(0, 5)] 79 | # size_list = [224, 256, 288, 320, 352] 80 | # size_list = [128, 160, 192, 224, 256] 81 | size_list = [128, 192, 256, 320, 384] 82 | size = random.choice(size_list) 83 | 84 | img, target = [list(item) for item in zip(*batch)] 85 | img = torch.stack(img, dim=0) 86 | img = interpolate(img, size=(size, size), mode="bilinear", align_corners=False) 87 | target = torch.stack(target, dim=0) 88 | target = interpolate(target, size=(size, size), mode="bilinear") 89 | # print(img.shape) 90 | return img, target -------------------------------------------------------------------------------- /utils_ssl/datasets_stage1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import torch.utils.data as data 4 | from PIL import Image 5 | import random 6 | import torch 7 | from torch.nn.functional import interpolate 8 | Image.MAX_IMAGE_PIXELS = 1000000000 9 | 10 | 11 | 12 | class ImageFolder(data.Dataset): 13 | def __init__(self, image_root, gt_root,depth_root, joint_transform=None, transform=None, target_transform=None): 14 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')or f.endswith('.bmp')] 15 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 16 | or f.endswith('.png')or f.endswith('.bmp')] 17 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.jpg') 18 | or f.endswith('.png')or f.endswith('.bmp')] 19 | self.images = sorted(self.images) 20 | self.gts = sorted(self.gts) 21 | self.depths = sorted(self.depths) 22 | 23 | self.joint_transform = joint_transform 24 | self.transform = transform 25 | self.target_transform = target_transform 26 | 27 | def __getitem__(self, index): 28 | img_path = self.images[index] 29 | gt_path = self.gts[index] 30 | depth_path = self.depths[index] 31 | 32 | img = Image.open(img_path).convert('RGB') 33 | gt = Image.open(gt_path).convert('L') 34 | depth = Image.open(depth_path).convert('L') 35 | 36 | 37 | if self.joint_transform is not None: 38 | img_raw, depth, gt = self.joint_transform(img, 
depth, gt) 39 | if self.transform is not None: 40 | img = self.transform(img_raw) 41 | if self.target_transform is not None: 42 | img_raw = self.target_transform(img_raw) 43 | depth = self.target_transform(depth) 44 | gt = self.target_transform(gt) 45 | 46 | return img, img_raw, depth, gt 47 | 48 | def __len__(self): 49 | return len(self.images) 50 | 51 | 52 | class ImageFolder_multi_scale(data.Dataset): 53 | def __init__(self, root, joint_transform=None, transform=None, target_transform=None): 54 | self.root = root 55 | self.imgs = make_dataset(root) 56 | self.joint_transform = joint_transform 57 | self.transform = transform 58 | self.target_transform = target_transform 59 | def __getitem__(self, index): 60 | img_path, gt_path = self.imgs[index] 61 | img = Image.open(img_path).convert('RGB') 62 | target = Image.open(gt_path).convert('L') 63 | 64 | if self.joint_transform is not None: 65 | img, target = self.joint_transform(img, target) 66 | if self.transform is not None: 67 | img = self.transform(img) 68 | if self.target_transform is not None: 69 | target = self.target_transform(target) 70 | 71 | return img, target 72 | 73 | def __len__(self): 74 | return len(self.imgs) 75 | 76 | 77 | ####可用可不用. GateNet论文中没有使用multi-scale to train 78 | def collate(self,batch): 79 | # size = [224, 256, 288, 320, 352][np.random.randint(0, 5)] 80 | # size_list = [224, 256, 288, 320, 352] 81 | # size_list = [128, 160, 192, 224, 256] 82 | size_list = [128, 192, 256, 320, 384] 83 | size = random.choice(size_list) 84 | 85 | img, target = [list(item) for item in zip(*batch)] 86 | img = torch.stack(img, dim=0) 87 | img = interpolate(img, size=(size, size), mode="bilinear", align_corners=False) 88 | target = torch.stack(target, dim=0) 89 | target = interpolate(target, size=(size, size), mode="bilinear") 90 | # print(img.shape) 91 | return img, target -------------------------------------------------------------------------------- /utils_ssl/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pydensecrf.densecrf as dcrf 4 | 5 | 6 | class AvgMeter(object): 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | 22 | 23 | def check_mkdir(dir_name): 24 | if not os.path.exists(dir_name): 25 | os.mkdir(dir_name) 26 | 27 | def cal_precision_recall_mae(prediction, gt): 28 | 29 | assert prediction.dtype == np.uint8 30 | assert gt.dtype == np.uint8 31 | print(prediction.shape,gt.shape) 32 | assert prediction.shape == gt.shape 33 | eps = 1e-4 34 | gt = gt / 255 35 | 36 | prediction = (prediction-prediction.min())/(prediction.max()-prediction.min()+ eps) 37 | gt[gt>0.5] = 1 38 | gt[gt!=1] = 0 39 | mae = np.mean(np.abs(prediction - gt)) 40 | 41 | hard_gt = np.zeros(prediction.shape) 42 | hard_gt[gt > 0.5] = 1 43 | t = np.sum(hard_gt) 44 | precision, recall,iou= [], [],[] 45 | 46 | binary = np.zeros(gt.shape) 47 | th = 2 * prediction.mean() 48 | if th > 1: 49 | th = 1 50 | binary[prediction >= th] = 1 51 | sb = (binary * gt).sum() 52 | pre_th = (sb+eps) / (binary.sum() + eps) 53 | rec_th = (sb+eps) / (gt.sum() + eps) 54 | thfm = 1.3 * pre_th * rec_th / (0.3*pre_th + rec_th + eps) 55 | 56 | 57 | for threshold in range(256): 58 | threshold = threshold / 255. 
59 | 60 | hard_prediction = np.zeros(prediction.shape) 61 | hard_prediction[prediction > threshold] = 1 62 | 63 | tp = np.sum(hard_prediction * hard_gt) 64 | p = np.sum(hard_prediction) 65 | iou.append((tp + eps) / (p+t-tp + eps)) 66 | precision.append((tp + eps) / (p + eps)) 67 | recall.append((tp + eps) / (t + eps)) 68 | 69 | 70 | return precision, recall, iou,mae,thfm 71 | 72 | 73 | 74 | def cal_fmeasure(precision, recall,iou): #iou 75 | beta_square = 0.3 76 | 77 | max_fmeasure = max([(1 + beta_square) * p * r / (beta_square * p + r) for p, r in zip(precision, recall)]) 78 | loc = [(1 + beta_square) * p * r / (beta_square * p + r) for p, r in zip(precision, recall)] 79 | a = loc.index(max(loc)) 80 | max_iou = max(iou) 81 | 82 | return max_fmeasure,max_iou 83 | 84 | 85 | 86 | 87 | def crf_refine(img, annos): 88 | def _sigmoid(x): 89 | return 1 / (1 + np.exp(-x)) 90 | 91 | assert img.dtype == np.uint8 92 | assert annos.dtype == np.uint8 93 | print(img.shape[:2],annos.shape) 94 | assert img.shape[:2] == annos.shape 95 | 96 | # img and annos should be np array with data type uint8 97 | 98 | EPSILON = 1e-8 99 | 100 | M = 2 # salient or not 101 | tau = 1.05 102 | # Setup the CRF model 103 | d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M) 104 | 105 | anno_norm = annos / 255. 106 | 107 | n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm)) 108 | p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm)) 109 | 110 | U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32') # unary term U, same spatial size as the input image 111 | U[0, :] = n_energy.flatten() 112 | U[1, :] = p_energy.flatten() 113 | 114 | d.setUnaryEnergy(U) 115 | 116 | d.addPairwiseGaussian(sxy=3, compat=3) 117 | d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5) 118 | 119 | # Do the inference 120 | infer = np.array(d.inference(1)).astype('float32') 121 | res = infer[1, :] 122 | 123 | res = res * 255 124 | res = res.reshape(img.shape[:2]) # reshape back to the input image size 125 | return res.astype('uint8') 126 | -------------------------------------------------------------------------------- /prediction_rgbd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import time 4 | import torch 5 | from PIL import Image 6 | from torch.autograd import Variable 7 | from torchvision import transforms 8 | from utils_downstream.config import dutrgbd,njud,nlpr,stere,sip,rgbd135,ssd,lfsd 9 | from utils_downstream.misc import check_mkdir, crf_refine 10 | from model.model_stage3 import RGBD_sal 11 | import ttach as tta 12 | 13 | torch.manual_seed(2018) 14 | torch.cuda.set_device(0) 15 | ckpt_path = './saved_model' 16 | args = { 17 | 'snapshot': 'imagenet_based_model-50', 18 | 'crf_refine': False, 19 | 'save_results': True 20 | } 21 | 22 | 23 | 24 | img_transform = transforms.Compose([ 25 | transforms.ToTensor(), 26 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 27 | 28 | ]) 29 | 30 | depth_transform = transforms.ToTensor() 31 | target_transform = transforms.ToTensor() 32 | to_pil = transforms.ToPILImage() 33 | 34 | to_test = {'DUT-RGBD':dutrgbd,'NJUD':njud,'NLPR':nlpr,'STERE':stere,'SIP':sip,'RGBD135':rgbd135,'SSD':ssd,'LFSD':lfsd} 35 | 36 | transforms = tta.Compose( 37 | [ 38 | # tta.HorizontalFlip(), 39 | # tta.Scale(scales=[0.75, 1, 1.25], interpolation='bilinear', align_corners=False), 40 | tta.Scale(scales=[1], interpolation='bilinear', align_corners=False), 41 | ] 42 | ) 43 | 44 | def main(): 45 | t0 = time.time() 46 | net = RGBD_sal().cuda() 47 | print 
('load snapshot \'%s\' for testing' % args['snapshot']) 48 | net.load_state_dict(torch.load(os.path.join(ckpt_path, args['snapshot']+'.pth'),map_location={'cuda:1': 'cuda:1'})) 49 | net.eval() 50 | with torch.no_grad(): 51 | for name, root in to_test.items(): 52 | root1 = os.path.join(root,'depth') 53 | img_list = [os.path.splitext(f) for f in os.listdir(root1)] 54 | for idx, img_name in enumerate(img_list): 55 | 56 | print ('predicting for %s: %d / %d' % (name, idx + 1, len(img_list))) 57 | rgb_png_path = os.path.join(root, 'RGB', img_name[0] + '.png') 58 | rgb_jpg_path = os.path.join(root, 'RGB', img_name[0] + '.jpg') 59 | depth_jpg_path = os.path.join(root, 'depth', img_name[0] + '.jpg') 60 | depth_png_path = os.path.join(root, 'depth', img_name[0] + '.png') 61 | if os.path.exists(rgb_png_path): 62 | img = Image.open(rgb_png_path).convert('RGB') 63 | else: 64 | img = Image.open(rgb_jpg_path).convert('RGB') 65 | if os.path.exists(depth_jpg_path): 66 | depth = Image.open(depth_jpg_path).convert('L') 67 | else: 68 | depth = Image.open(depth_png_path).convert('L') 69 | 70 | 71 | w_,h_ = img.size 72 | img_resize = img.resize([256,256],Image.BILINEAR) # Foldconv cat是320 73 | depth_resize = depth.resize([256,256],Image.BILINEAR) # Foldconv cat是320 74 | img_var = Variable(img_transform(img_resize).unsqueeze(0), volatile=True).cuda() 75 | depth_var = Variable(depth_transform(depth_resize).unsqueeze(0), volatile=True).cuda() 76 | n, c, h, w = img_var.size() 77 | depth_3 = torch.cat((depth_var, depth_var, depth_var), 1) 78 | mask = [] 79 | for transformer in transforms: # custom transforms or e.g. tta.aliases.d4_transform() 80 | 81 | rgb_trans = transformer.augment_image(img_var) 82 | d_trans = transformer.augment_image(depth_3) 83 | model_output = net(rgb_trans,d_trans) 84 | deaug_mask = transformer.deaugment_mask(model_output) 85 | mask.append(deaug_mask) 86 | 87 | prediction = torch.mean(torch.stack(mask, dim=0), dim=0) 88 | prediction = prediction.sigmoid() 89 | prediction = to_pil(prediction.data.squeeze(0).cpu()) 90 | prediction = prediction.resize((w_, h_), Image.BILINEAR) 91 | if args['crf_refine']: 92 | prediction = crf_refine(np.array(img), np.array(prediction)) 93 | if args['save_results']: 94 | check_mkdir(os.path.join(ckpt_path, args['snapshot'],name)) 95 | prediction.save(os.path.join(ckpt_path, args['snapshot'],name, img_name[0] + '.png')) 96 | 97 | 98 | 99 | if __name__ == '__main__': 100 | main() 101 | -------------------------------------------------------------------------------- /utils_downstream/ssim_loss.py: -------------------------------------------------------------------------------- 1 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from math import exp 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 10 | return gauss/gauss.sum() 11 | 12 | def create_window(window_size, channel): 13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 16 | return window 17 | 18 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 19 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 20 | mu2 = F.conv2d(img2, window, 
padding = window_size//2, groups = channel) 21 | 22 | mu1_sq = mu1.pow(2) 23 | mu2_sq = mu2.pow(2) 24 | mu1_mu2 = mu1*mu2 25 | 26 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 27 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 28 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 29 | 30 | C1 = 0.01**2 31 | C2 = 0.03**2 32 | 33 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 34 | 35 | if size_average: 36 | return ssim_map.mean() 37 | else: 38 | return ssim_map.mean(1).mean(1).mean(1) 39 | 40 | class SSIM(torch.nn.Module): 41 | def __init__(self, window_size = 11, size_average = True): 42 | super(SSIM, self).__init__() 43 | self.window_size = window_size 44 | self.size_average = size_average 45 | self.channel = 1 46 | self.window = create_window(window_size, self.channel) 47 | 48 | def forward(self, img1, img2): 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def _logssim(img1, img2, window, window_size, channel, size_average = True): 67 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 68 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 69 | 70 | mu1_sq = mu1.pow(2) 71 | mu2_sq = mu2.pow(2) 72 | mu1_mu2 = mu1*mu2 73 | 74 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 75 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 76 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 77 | 78 | C1 = 0.01**2 79 | C2 = 0.03**2 80 | 81 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 82 | ssim_map = (ssim_map - torch.min(ssim_map))/(torch.max(ssim_map)-torch.min(ssim_map)) 83 | ssim_map = -torch.log(ssim_map + 1e-8) 84 | 85 | if size_average: 86 | return ssim_map.mean() 87 | else: 88 | return ssim_map.mean(1).mean(1).mean(1) 89 | 90 | class LOGSSIM(torch.nn.Module): 91 | def __init__(self, window_size = 11, size_average = True): 92 | super(LOGSSIM, self).__init__() 93 | self.window_size = window_size 94 | self.size_average = size_average 95 | self.channel = 1 96 | self.window = create_window(window_size, self.channel) 97 | 98 | def forward(self, img1, img2): 99 | (_, channel, _, _) = img1.size() 100 | 101 | if channel == self.channel and self.window.data.type() == img1.data.type(): 102 | window = self.window 103 | else: 104 | window = create_window(self.window_size, channel) 105 | 106 | if img1.is_cuda: 107 | window = window.cuda(img1.get_device()) 108 | window = window.type_as(img1) 109 | 110 | self.window = window 111 | self.channel = channel 112 | 113 | 114 | return _logssim(img1, img2, window, self.window_size, channel, self.size_average) 115 | 116 | 117 | def ssim(img1, img2, window_size = 11, size_average = True): 118 | (_, channel, _, _) = img1.size() 119 | window = create_window(window_size, channel) 120 | 121 | if img1.is_cuda: 122 | window 
= window.cuda(img1.get_device()) 123 | window = window.type_as(img1) 124 | 125 | return _ssim(img1, img2, window, window_size, channel, size_average) 126 | -------------------------------------------------------------------------------- /model/model_stage1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torchvision import models 5 | 6 | class Crossmodal_Autoendoer(nn.Module): 7 | 8 | def __init__(self): 9 | super(Crossmodal_Autoendoer, self).__init__() 10 | ################################vgg16####################################### 11 | ##set 'pretrained=False' for SSL model or 'pretrained=True' for ImageNet pretrained model. 12 | feats = list(models.vgg16_bn(pretrained=False).features.children()) 13 | feats1 = list(models.vgg16_bn(pretrained=False).features.children()) 14 | #self.conv0 = nn.Conv2d(4, 64, kernel_size=3, padding=1) 15 | self.conv1_RGB = nn.Sequential(*feats[0:6]) 16 | self.conv2_RGB = nn.Sequential(*feats[6:13]) 17 | self.conv3_RGB = nn.Sequential(*feats[13:23]) 18 | self.conv4_RGB = nn.Sequential(*feats[23:33]) 19 | self.conv5_RGB = nn.Sequential(*feats[33:43]) 20 | 21 | self.conv1_depth = nn.Sequential(*feats1[0:6]) 22 | self.conv2_depth = nn.Sequential(*feats1[6:13]) 23 | self.conv3_depth = nn.Sequential(*feats1[13:23]) 24 | self.conv4_depth = nn.Sequential(*feats1[23:33]) 25 | self.conv5_depth = nn.Sequential(*feats1[33:43]) 26 | 27 | self.output4_rgb = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU()) 28 | self.output3_rgb = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU()) 29 | self.output2_rgb = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1),nn.BatchNorm2d(64), nn.PReLU()) 30 | self.output1_rgbtod = nn.Sequential(nn.Conv2d(64, 1, kernel_size=3, padding=1)) 31 | 32 | self.output4_depth = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU()) 33 | self.output3_depth = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU()) 34 | self.output2_depth = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1),nn.BatchNorm2d(64), nn.PReLU()) 35 | self.output1_depthtorgb = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1)) 36 | 37 | self.sideout5_rgbtod = nn.Sequential(nn.Conv2d(512, 1, kernel_size=3, padding=1)) 38 | self.sideout4_rgbtod = nn.Sequential(nn.Conv2d(256, 1, kernel_size=3, padding=1)) 39 | self.sideout3_rgbtod = nn.Sequential(nn.Conv2d(128, 1, kernel_size=3, padding=1)) 40 | self.sideout2_rgbtod = nn.Sequential(nn.Conv2d(64, 1, kernel_size=3, padding=1)) 41 | 42 | self.sideout5_depthtorgb = nn.Sequential(nn.Conv2d(512, 3, kernel_size=3, padding=1)) 43 | self.sideout4_depthtorgb = nn.Sequential(nn.Conv2d(256, 3, kernel_size=3, padding=1)) 44 | self.sideout3_depthtorgb = nn.Sequential(nn.Conv2d(128, 3, kernel_size=3, padding=1)) 45 | self.sideout2_depthtorgb = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1)) 46 | 47 | 48 | 49 | for m in self.modules(): 50 | if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout): 51 | m.inplace = True 52 | 53 | def forward(self, x,depth): 54 | 55 | input = x 56 | B,_,_,_ = input.size() 57 | e1_rgb = self.conv1_RGB(x) 58 | e1_depth = self.conv1_depth(depth) 59 | e2_rgb = self.conv2_RGB(e1_rgb) 60 | e2_depth= self.conv2_depth(e1_depth) 61 | e3_rgb = self.conv3_RGB(e2_rgb) 62 | e3_depth = self.conv3_depth(e2_depth) 63 | 
e4_rgb = self.conv4_RGB(e3_rgb) 64 | e4_depth = self.conv4_depth(e3_depth) 65 | e5_rgb = self.conv5_RGB(e4_rgb) 66 | e5_depth = self.conv5_depth(e4_depth) 67 | 68 | sideout5_rgbtod = self.sideout5_rgbtod(e5_rgb) 69 | output4_rgb = self.output4_rgb(F.upsample(e5_rgb, size=e4_rgb.size()[2:], mode='bilinear')+e4_rgb) 70 | sideout4_rgbtod = self.sideout4_rgbtod(output4_rgb) 71 | output3_rgb = self.output3_rgb(F.upsample(output4_rgb, size=e3_rgb.size()[2:], mode='bilinear') + e3_rgb) 72 | sideout3_rgbtod = self.sideout3_rgbtod(output3_rgb) 73 | output2_rgb = self.output2_rgb(F.upsample(output3_rgb, size=e2_rgb.size()[2:], mode='bilinear') + e2_rgb) 74 | sideout2_rgbtod = self.sideout2_rgbtod(output2_rgb) 75 | sideout1_rgbtod = self.output1_rgbtod(F.upsample(output2_rgb, size=e1_rgb.size()[2:], mode='bilinear') + e1_rgb) 76 | 77 | sideout5_dtorgb = self.sideout5_depthtorgb(e5_depth) 78 | output4_d = self.output4_depth(F.upsample(e5_depth, size=e4_rgb.size()[2:], mode='bilinear')+e4_depth) 79 | sideout4_dtorgb = self.sideout4_depthtorgb(output4_d) 80 | output3_d = self.output3_depth(F.upsample(output4_d, size=e3_rgb.size()[2:], mode='bilinear') + e3_depth) 81 | sideout3_dtorgb = self.sideout3_depthtorgb(output3_d) 82 | output2_d = self.output2_depth(F.upsample(output3_d, size=e2_rgb.size()[2:], mode='bilinear') + e2_depth) 83 | sideout2_dtorgb = self.sideout2_depthtorgb(output2_d) 84 | sideout1_dtorgb = self.output1_depthtorgb(F.upsample(output2_d, size=e1_rgb.size()[2:], mode='bilinear') + e1_depth) 85 | 86 | 87 | if self.training: 88 | return sideout5_rgbtod,sideout4_rgbtod,sideout3_rgbtod,sideout2_rgbtod,sideout1_rgbtod,sideout5_dtorgb,sideout4_dtorgb,sideout3_dtorgb,sideout2_dtorgb,sideout1_dtorgb 89 | # return F.sigmoid(sideout1_rgbtod), F.sigmoid(sideout1_dtorgb) 90 | return e5_rgb,e4_rgb,e3_rgb,e2_rgb,e1_rgb, e5_depth,e4_depth,e3_depth,e2_depth,e1_depth 91 | # return e5_rgb,e4_rgb,e3_rgb,e2_rgb,e1_rgb 92 | 93 | -------------------------------------------------------------------------------- /train_stage3_downstream.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import os, argparse 5 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 6 | from datetime import datetime 7 | from model.model_stage3 import RGBD_sal 8 | from utils_downstream.dataset_rgbd_strategy2 import get_loader 9 | from utils_downstream.utils import adjust_lr, AvgMeter 10 | import torch.nn as nn 11 | from torch.cuda import amp 12 | 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--epoch', type=int, default=100, help='epoch number') 16 | parser.add_argument('--lr_gen', type=float, default=1e-4, help='learning rate') 17 | parser.add_argument('--batchsize', type=int, default=4, help='training batch size') 18 | parser.add_argument('--trainsize', type=int, default=256, help='training dataset size') 19 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin') 20 | parser.add_argument('--decay_rate', type=float, default=0.9, help='decay rate of learning rate') 21 | parser.add_argument('--decay_epoch', type=int, default=30, help='every n epochs decay learning rate') 22 | parser.add_argument('-beta1_gen', type=float, default=0.5,help='beta of Adam for generator') 23 | parser.add_argument('--weight_decay', type=float, default=0.001, help='weight_decay') 24 | parser.add_argument('--feat_channel', type=int, default=64, help='reduced channel of saliency feat') 
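# Note (added): the options defined above are parsed on the next line, but several
# are unused in this script as shipped -- beta1_gen and weight_decay are never
# forwarded to torch.optim.Adam(generator_params, opt.lr_gen), and clip /
# feat_channel are not referenced anywhere else in the file.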
25 | 26 | opt = parser.parse_args() 27 | print('Generator Learning Rate: {}'.format(opt.lr_gen)) 28 | # build models 29 | generator = RGBD_sal() 30 | generator.cuda() 31 | 32 | pretrained_dict = torch.load(os.path.join('./saved_model/pretext_task1.pth')) 33 | model_dict = generator.state_dict() 34 | # print(pretrained_dict.items()) 35 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 36 | model_dict.update(pretrained_dict) 37 | generator.load_state_dict(model_dict) 38 | 39 | pretrained_dict = torch.load(os.path.join('./saved_model/pretext_task2.pth')) 40 | model_dict = generator.state_dict() 41 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 42 | print(pretrained_dict.items()) 43 | model_dict.update(pretrained_dict) 44 | generator.load_state_dict(model_dict) 45 | 46 | generator_params = generator.parameters() 47 | generator_optimizer = torch.optim.Adam(generator_params, opt.lr_gen) 48 | 49 | ## load data 50 | image_root = '' 51 | depth_root = '' 52 | gt_root = '' 53 | 54 | train_loader = get_loader(image_root, gt_root,depth_root, batchsize=opt.batchsize, trainsize=opt.trainsize) 55 | total_step = len(train_loader) 56 | 57 | ## define loss 58 | 59 | CE = torch.nn.BCELoss() 60 | mse_loss = torch.nn.MSELoss(size_average=True, reduce=True) 61 | # size_rates = [0.75,1,1.25] # multi-scale training 62 | size_rates = [1] # multi-scale training 63 | criterion = nn.BCEWithLogitsLoss().cuda() 64 | criterion_mae = nn.L1Loss().cuda() 65 | criterion_mse = nn.MSELoss().cuda() 66 | use_fp16 = True 67 | scaler = amp.GradScaler(enabled=use_fp16) 68 | 69 | def structure_loss(pred, mask): 70 | weit = 1+5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)-mask) 71 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none') 72 | wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3)) 73 | 74 | 75 | pred = torch.sigmoid(pred) 76 | inter = ((pred * mask) * weit).sum(dim=(2, 3)) 77 | union = ((pred + mask) * weit).sum(dim=(2, 3)) 78 | wiou = 1-(inter+1)/(union-inter+1) 79 | 80 | return (wbce+wiou).mean() 81 | 82 | 83 | 84 | for epoch in range(1, opt.epoch+1): 85 | generator.train() 86 | loss_record = AvgMeter() 87 | print('Generator Learning Rate: {}'.format(generator_optimizer.param_groups[0]['lr'])) 88 | for i, pack in enumerate(train_loader, start=1): 89 | for rate in size_rates: 90 | generator_optimizer.zero_grad() 91 | images, gts, depths = pack 92 | images = Variable(images) 93 | gts = Variable(gts) 94 | depths = Variable(depths) 95 | images = images.cuda() 96 | gts = gts.cuda() 97 | depths = depths.cuda() 98 | # multi-scale training samples 99 | trainsize = int(round(opt.trainsize * rate / 32) * 32) 100 | if rate != 1: 101 | images = F.upsample(images, size=(trainsize, trainsize), mode='bilinear', 102 | align_corners=True) 103 | gts = F.upsample(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True) 104 | # contours = F.upsample(contours, size=(trainsize, trainsize), mode='bilinear', align_corners=True) 105 | depths = F.upsample(depths, size=(trainsize, trainsize), mode='bilinear', align_corners=True) 106 | 107 | b, c, h, w = gts.size() 108 | target_1 = F.upsample(gts, size=h // 2, mode='nearest') 109 | target_2 = F.upsample(gts, size=h // 4, mode='nearest') 110 | target_3 = F.upsample(gts, size=h // 8, mode='nearest') 111 | target_4 = F.upsample(gts, size=h // 16, mode='nearest') 112 | 113 | with amp.autocast(enabled=use_fp16): 114 | depth_3 = torch.cat((depths, depths, depths), 1) 115 | sideout5, 
sideout4, sideout3, sideout2, output1 = generator.forward(images, depth_3) # hed 116 | loss1 = structure_loss(sideout5, target_4) 117 | loss2 = structure_loss(sideout4, target_3) 118 | loss3 = structure_loss(sideout3, target_2) 119 | loss4 = structure_loss(sideout2, target_1) 120 | loss5 = structure_loss(output1, gts) 121 | 122 | loss = loss1 + loss2 + loss3 + loss4 + loss5 123 | 124 | generator_optimizer.zero_grad() 125 | scaler.scale(loss).backward() 126 | scaler.step(generator_optimizer) 127 | scaler.update() 128 | 129 | if rate == 1: 130 | loss_record.update(loss.data, opt.batchsize) 131 | 132 | 133 | if i % 10 == 0 or i == total_step: 134 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], gen Loss: {:.4f}'. 135 | format(datetime.now(), epoch, opt.epoch, i, total_step, loss_record.show())) 136 | # print(anneal_reg) 137 | 138 | 139 | adjust_lr(generator_optimizer, opt.lr_gen, epoch, opt.decay_rate, opt.decay_epoch) 140 | 141 | save_path = './saved_model/SSLSOD_v2' 142 | 143 | 144 | if not os.path.exists(save_path): 145 | os.makedirs(save_path) 146 | if epoch % opt.epoch == 0: 147 | torch.save(generator.state_dict(), save_path + 'Model' + '_%d' % epoch + '_gen.pth') 148 | -------------------------------------------------------------------------------- /train_stage2_pretext2.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import torch 4 | from torch import optim 5 | from torch.autograd import Variable 6 | from torchvision import transforms 7 | from torch.utils.data import DataLoader 8 | import utils_ssl.joint_transforms 9 | from utils_ssl.datasets_stage2 import ImageFolder 10 | from utils_ssl.misc import AvgMeter, check_mkdir 11 | from model.model_stage1 import Crossmodal_Autoendoer 12 | from model.model_stage2 import Contour_Estimation 13 | from torch.backends import cudnn 14 | from utils_downstream.ssim_loss import SSIM 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | cudnn.benchmark = True 18 | torch.manual_seed(2018) 19 | torch.cuda.set_device(0) 20 | 21 | ##########################hyperparameters############################### 22 | ckpt_path = './saved_model' 23 | exp_name = 'pretext_task2_stage2' 24 | args = { 25 | 'iter_num': 79900, #50epoch 26 | 'train_batch_size': 4, 27 | 'last_iter': 0, 28 | 'lr': 1e-3, 29 | 'lr_decay': 0.9, 30 | 'weight_decay': 0.0005, 31 | 'momentum': 0.9, 32 | 'snapshot': '' 33 | } 34 | ##########################data augmentation############################### 35 | joint_transform = utils_ssl.joint_transforms.Compose([ 36 | utils_ssl.joint_transforms.RandomCrop(256, 256), # change to resize 37 | utils_ssl.joint_transforms.RandomHorizontallyFlip(), 38 | utils_ssl.joint_transforms.RandomRotate(10) 39 | ]) 40 | img_transform = transforms.Compose([ 41 | transforms.ColorJitter(0.1, 0.1, 0.1), 42 | transforms.ToTensor(), 43 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 44 | ]) 45 | target_transform = transforms.ToTensor() 46 | ########################################################################## 47 | image_root = '' 48 | depth_root = '' 49 | gt_root = '' 50 | 51 | train_set = ImageFolder(image_root, gt_root,depth_root, joint_transform, img_transform, target_transform) 52 | train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=0, shuffle=True) 53 | 54 | 55 | criterion = nn.BCEWithLogitsLoss().cuda() 56 | criterion_BCE = nn.BCELoss().cuda() 57 | criterion_mse = nn.MSELoss().cuda() 58 | criterion_mae = 
nn.L1Loss().cuda() 59 | criterion_ssim = SSIM(window_size=11,size_average=True) 60 | def ssimmae(pre,gt): 61 | maeloss = criterion_mae(pre,gt) 62 | ssimloss = 1-criterion_ssim(pre,gt) 63 | loss = ssimloss+maeloss 64 | return loss 65 | 66 | log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt') 67 | 68 | 69 | def main(): 70 | #############################ResNet pretrained########################### 71 | #res18[2,2,2,2],res34[3,4,6,3],res50[3,4,6,3],res101[3,4,23,3],res152[3,8,36,3] 72 | model_pretext1 = Crossmodal_Autoendoer() 73 | net_pretext1 = model_pretext1.cuda() 74 | net_pretext1.load_state_dict(torch.load(os.path.join('./saved_model/pretext_task1.pth'))) 75 | net_pretext1.eval() 76 | 77 | model_pretext2 = Contour_Estimation() 78 | net = model_pretext2.cuda().train() 79 | 80 | 81 | optimizer = optim.SGD([ 82 | {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'], 83 | 'lr': 2 * args['lr']}, 84 | {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'], 85 | 'lr': args['lr'], 'weight_decay': args['weight_decay']} 86 | ], momentum=args['momentum']) 87 | if len(args['snapshot']) > 0: 88 | print('training resumes from ' + args['snapshot']) 89 | net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'))) 90 | optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '_optim.pth'))) 91 | optimizer.param_groups[0]['lr'] = 2 * args['lr'] 92 | optimizer.param_groups[1]['lr'] = args['lr'] 93 | check_mkdir(ckpt_path) 94 | check_mkdir(os.path.join(ckpt_path, exp_name)) 95 | open(log_path, 'w').write(str(args) + '\n\n') 96 | train(net_pretext1,net, optimizer) 97 | 98 | 99 | ######################################################################### 100 | 101 | def train(net_pretext1,net, optimizer): 102 | curr_iter = args['last_iter'] 103 | while True: 104 | total_loss_record, loss1_record, loss2_record, loss3_record, loss4_record, loss5_record, loss6_record, loss7_record, loss8_record,loss9_record,loss10_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter() 105 | for i, data in enumerate(train_loader): 106 | optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (1 - float(curr_iter) / args['iter_num'] 107 | ) ** args['lr_decay'] 108 | optimizer.param_groups[1]['lr'] = args['lr'] * (1 - float(curr_iter) / args['iter_num'] 109 | ) ** args['lr_decay'] 110 | # data\binarizing\Variable 111 | images, depths, gts = data 112 | gts[gts > 0.5] = 1 113 | gts[gts != 1] = 0 114 | batch_size = images.size(0) 115 | inputs = Variable(images).cuda() 116 | labels = Variable(gts).cuda() 117 | depths = Variable(depths).cuda() 118 | b, c, h, w = labels.size() 119 | optimizer.zero_grad() 120 | target_1 = F.upsample(labels, size=h // 2, mode='nearest') 121 | target_2 = F.upsample(labels, size=h // 4, mode='nearest') 122 | target_3 = F.upsample(labels, size=h // 8, mode='nearest') 123 | target_4 = F.upsample(labels, size=h // 16, mode='nearest') 124 | 125 | ##########loss############# 126 | depth_3 = torch.cat((depths, depths, depths), 1) 127 | e5_rgb,e4_rgb,e3_rgb,e2_rgb,e1_rgb, e5_depth,e4_depth,e3_depth,e2_depth,e1_depth = net_pretext1( 128 | inputs, depth_3) # hed 129 | sideout5, sideout4, sideout3, sideout2, output1 = net(e5_rgb, e4_rgb, e3_rgb, e2_rgb, e1_rgb, e5_depth, e4_depth, e3_depth, e2_depth, e1_depth) 130 | loss1 = criterion_mae(F.sigmoid(sideout5), target_4) 131 | loss2 = 
criterion_mae(F.sigmoid(sideout4), target_3) 132 | loss3 = criterion_mae(F.sigmoid(sideout3), target_2) 133 | loss4 = criterion_mae(F.sigmoid(sideout2), target_1) 134 | loss5 = criterion_mae(F.sigmoid(output1), labels) 135 | 136 | total_loss = loss1 + loss2 + loss3 + loss4 + loss5 137 | total_loss.backward() 138 | optimizer.step() 139 | total_loss_record.update(total_loss.item(), batch_size) 140 | loss1_record.update(loss1.item(), batch_size) 141 | loss2_record.update(loss2.item(), batch_size) 142 | loss3_record.update(loss3.item(), batch_size) 143 | loss4_record.update(loss4.item(), batch_size) 144 | loss5_record.update(loss5.item(), batch_size) 145 | 146 | #############log############### 147 | curr_iter += 1 148 | log = '[iter %d], [total loss %.5f],[loss4 %.5f],[loss5 %.5f],[lr %.13f] ' % \ 149 | (curr_iter, total_loss_record.avg, loss4_record.avg, loss5_record.avg, optimizer.param_groups[1]['lr']) 150 | print(log) 151 | open(log_path, 'a').write(log + '\n') 152 | if curr_iter == args['iter_num']: 153 | torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter)) 154 | torch.save(optimizer.state_dict(), 155 | os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter)) 156 | return 157 | ###############end############### 158 | 159 | 160 | if __name__ == '__main__': 161 | main() 162 | -------------------------------------------------------------------------------- /train_stage1_pretext1.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import torch 4 | from torch import optim 5 | from torch.autograd import Variable 6 | from torchvision import transforms 7 | from torch.utils.data import DataLoader 8 | import utils_ssl.joint_transforms 9 | from utils_ssl.datasets_stage1 import ImageFolder 10 | from utils_ssl.misc import AvgMeter, check_mkdir 11 | from model.model_stage1 import Crossmodal_Autoendoer 12 | from torch.backends import cudnn 13 | from utils_downstream.ssim_loss import SSIM 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | cudnn.benchmark = True 17 | torch.manual_seed(2018) 18 | torch.cuda.set_device(0) 19 | 20 | ##########################hyperparameters############################### 21 | ckpt_path = './model' 22 | exp_name = 'pretext_task1_stage1' 23 | args = { 24 | 'iter_num': 79900, 25 | 'train_batch_size': 4, 26 | 'last_iter': 0, 27 | 'lr': 1e-3, 28 | 'lr_decay': 0.9, 29 | 'weight_decay': 0.0005, 30 | 'momentum': 0.9, 31 | 'snapshot': '' 32 | } 33 | ##########################data augmentation############################### 34 | joint_transform = utils_ssl.joint_transforms.Compose([ 35 | utils_ssl.joint_transforms.RandomCrop(256, 256), # change to resize 36 | utils_ssl.joint_transforms.RandomHorizontallyFlip(), 37 | utils_ssl.joint_transforms.RandomRotate(10) 38 | ]) 39 | img_transform = transforms.Compose([ 40 | transforms.ColorJitter(0.1, 0.1, 0.1), 41 | transforms.ToTensor(), 42 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 43 | ]) 44 | target_transform = transforms.ToTensor() 45 | ########################################################################## 46 | image_root = '' 47 | depth_root = '' 48 | gt_root = '' 49 | 50 | train_set = ImageFolder(image_root, gt_root,depth_root, joint_transform, img_transform, target_transform) 51 | train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=0, shuffle=True) 52 | 53 | 54 | ###multi-scale-train 55 | # train_set = ImageFolder_multi_scale(train_data, 
joint_transform, img_transform, target_transform) 56 | # train_loader = DataLoader(train_set, collate_fn=train_set.collate, batch_size=args['train_batch_size'], num_workers=12, shuffle=True, drop_last=True) 57 | 58 | criterion = nn.BCEWithLogitsLoss().cuda() 59 | criterion_BCE = nn.BCELoss().cuda() 60 | criterion_mse = nn.MSELoss().cuda() 61 | criterion_mae = nn.L1Loss().cuda() 62 | criterion_ssim = SSIM(window_size=11,size_average=True) 63 | def ssimmae(pre,gt): 64 | maeloss = criterion_mae(pre,gt) 65 | ssimloss = 1-criterion_ssim(pre,gt) 66 | loss = ssimloss+maeloss 67 | return loss 68 | 69 | log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt') 70 | 71 | 72 | def main(): 73 | #############################ResNet pretrained########################### 74 | #res18[2,2,2,2],res34[3,4,6,3],res50[3,4,6,3],res101[3,4,23,3],res152[3,8,36,3] 75 | model = Crossmodal_Autoendoer() 76 | net = model.cuda().train() 77 | optimizer = optim.SGD([ 78 | {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'], 79 | 'lr': 2 * args['lr']}, 80 | {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'], 81 | 'lr': args['lr'], 'weight_decay': args['weight_decay']} 82 | ], momentum=args['momentum']) 83 | if len(args['snapshot']) > 0: 84 | print('training resumes from ' + args['snapshot']) 85 | net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth'))) 86 | optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '_optim.pth'))) 87 | optimizer.param_groups[0]['lr'] = 2 * args['lr'] 88 | optimizer.param_groups[1]['lr'] = args['lr'] 89 | check_mkdir(ckpt_path) 90 | check_mkdir(os.path.join(ckpt_path, exp_name)) 91 | open(log_path, 'w').write(str(args) + '\n\n') 92 | train(net, optimizer) 93 | 94 | 95 | ######################################################################### 96 | 97 | def train(net, optimizer): 98 | curr_iter = args['last_iter'] 99 | while True: 100 | total_loss_record, loss1_record, loss2_record, loss3_record, loss4_record, loss5_record, loss6_record, loss7_record, loss8_record,loss9_record,loss10_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter() 101 | for i, data in enumerate(train_loader): 102 | optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (1 - float(curr_iter) / args['iter_num'] 103 | ) ** args['lr_decay'] 104 | optimizer.param_groups[1]['lr'] = args['lr'] * (1 - float(curr_iter) / args['iter_num'] 105 | ) ** args['lr_decay'] 106 | # data\binarizing\Variable 107 | images, images_raw, gts, depths = data 108 | gts[gts > 0.5] = 1 109 | gts[gts != 1] = 0 110 | batch_size = images.size(0) 111 | inputs = Variable(images).cuda() 112 | labels = Variable(gts).cuda() 113 | images_raw = Variable(images_raw).cuda() 114 | depths = Variable(depths).cuda() 115 | b, c, h, w = labels.size() 116 | optimizer.zero_grad() 117 | target_1 = F.upsample(labels, size=h // 2, mode='nearest') 118 | target_2 = F.upsample(labels, size=h // 4, mode='nearest') 119 | target_3 = F.upsample(labels, size=h // 8, mode='nearest') 120 | target_4 = F.upsample(labels, size=h // 16, mode='nearest') 121 | 122 | images_raw_1 = F.upsample(images_raw, size=h // 2, mode='nearest') 123 | images_raw_2 = F.upsample(images_raw, size=h // 4, mode='nearest') 124 | images_raw_3 = F.upsample(images_raw, size=h // 8, mode='nearest') 125 | images_raw_4 = F.upsample(images_raw, size=h // 16, mode='nearest') 
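# Note on the pyramids built here (descriptive comment, not in the original file):
# labels, images_raw and depths are each downsampled to 1/2, 1/4, 1/8 and 1/16
# resolution so that every decoder side output has a matching target. The
# RGB-to-depth branch is supervised against the depths_* pyramid (built just
# below) and the depth-to-RGB branch against the images_raw_* pyramid, each
# scale with the combined SSIM + MAE loss from ssimmae().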
126 | 127 | depths_1 = F.upsample(depths, size=h // 2, mode='nearest') 128 | depths_2 = F.upsample(depths, size=h // 4, mode='nearest') 129 | depths_3 = F.upsample(depths, size=h // 8, mode='nearest') 130 | depths_4 = F.upsample(depths, size=h // 16, mode='nearest') 131 | 132 | ##########loss############# 133 | depth_3 = torch.cat((depths, depths, depths), 1) 134 | sideout5_rgbtod, sideout4_rgbtod, sideout3_rgbtod, sideout2_rgbtod, sideout1_rgbtod, sideout5_dtorgb, sideout4_dtorgb, sideout3_dtorgb, sideout2_dtorgb, sideout1_dtorgb = net( 135 | inputs, depth_3) # hed 136 | loss1 = ssimmae(F.sigmoid(sideout5_rgbtod), depths_4) 137 | loss2 = ssimmae(F.sigmoid(sideout4_rgbtod), depths_3) 138 | loss3 = ssimmae(F.sigmoid(sideout3_rgbtod), depths_2) 139 | loss4 = ssimmae(F.sigmoid(sideout2_rgbtod), depths_1) 140 | loss5 = ssimmae(F.sigmoid(sideout1_rgbtod), depths) 141 | loss6 = ssimmae(F.sigmoid(sideout5_dtorgb), images_raw_4) 142 | loss7 = ssimmae(F.sigmoid(sideout4_dtorgb), images_raw_3) 143 | loss8 = ssimmae(F.sigmoid(sideout3_dtorgb), images_raw_2) 144 | loss9 = ssimmae(F.sigmoid(sideout2_dtorgb), images_raw_1) 145 | loss10 = ssimmae(F.sigmoid(sideout1_dtorgb), images_raw) 146 | 147 | loss_depth = loss1 + loss2 + loss3 + loss4 + loss5 148 | loss_rgb = loss6 + loss7 + loss8 + loss9 + loss10 149 | total_loss = loss_depth + loss_rgb 150 | total_loss.backward() 151 | optimizer.step() 152 | total_loss_record.update(total_loss.item(), batch_size) 153 | loss1_record.update(loss1.item(), batch_size) 154 | loss2_record.update(loss2.item(), batch_size) 155 | loss3_record.update(loss3.item(), batch_size) 156 | loss4_record.update(loss4.item(), batch_size) 157 | loss5_record.update(loss5.item(), batch_size) 158 | loss6_record.update(loss6.item(), batch_size) 159 | loss7_record.update(loss7.item(), batch_size) 160 | loss8_record.update(loss8.item(), batch_size) 161 | loss9_record.update(loss9.item(), batch_size) 162 | loss10_record.update(loss10.item(), batch_size) 163 | 164 | #############log############### 165 | curr_iter += 1 166 | 167 | log = '[iter %d], [total loss %.5f],[loss5 %.5f],[loss10 %.5f],[lr %.13f] ' % \ 168 | (curr_iter, total_loss_record.avg, loss5_record.avg, loss10_record.avg, optimizer.param_groups[1]['lr']) 169 | print(log) 170 | open(log_path, 'a').write(log + '\n') 171 | 172 | if curr_iter == args['iter_num']: 173 | torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter)) 174 | torch.save(optimizer.state_dict(), 175 | os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter)) 176 | return 177 | ###############end############### 178 | 179 | 180 | if __name__ == '__main__': 181 | main() 182 | -------------------------------------------------------------------------------- /utils_downstream/dataset_rgbd_strategy2.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import random 6 | import numpy as np 7 | from PIL import ImageEnhance 8 | 9 | 10 | # several data augmentation strategies 11 | def cv_random_flip(img, label, depth): 12 | flip_flag = random.randint(0, 1) 13 | # flip_flag2= random.randint(0,1) 14 | # left right flip 15 | if flip_flag == 1: 16 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 17 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 18 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT) 19 | # top bottom flip 20 | # if flip_flag2==1: 21 | # img = 
img.transpose(Image.FLIP_TOP_BOTTOM) 22 | # label = label.transpose(Image.FLIP_TOP_BOTTOM) 23 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM) 24 | return img, label, depth 25 | 26 | 27 | def randomCrop(image, label, depth): 28 | border = 30 29 | image_width = image.size[0] 30 | image_height = image.size[1] 31 | crop_win_width = np.random.randint(image_width - border, image_width) 32 | crop_win_height = np.random.randint(image_height - border, image_height) 33 | random_region = ( 34 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 35 | (image_height + crop_win_height) >> 1) 36 | return image.crop(random_region), label.crop(random_region), depth.crop(random_region) 37 | 38 | 39 | def randomRotation(image, label, depth): 40 | mode = Image.BICUBIC 41 | if random.random() > 0.8: 42 | random_angle = np.random.randint(-15, 15) 43 | image = image.rotate(random_angle, mode) 44 | label = label.rotate(random_angle, mode) 45 | depth = depth.rotate(random_angle, mode) 46 | return image, label, depth 47 | 48 | 49 | def colorEnhance(image): 50 | bright_intensity = random.randint(5, 15) / 10.0 51 | image = ImageEnhance.Brightness(image).enhance(bright_intensity) 52 | contrast_intensity = random.randint(5, 15) / 10.0 53 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity) 54 | color_intensity = random.randint(0, 20) / 10.0 55 | image = ImageEnhance.Color(image).enhance(color_intensity) 56 | sharp_intensity = random.randint(0, 30) / 10.0 57 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) 58 | return image 59 | 60 | 61 | def randomGaussian(image, mean=0.1, sigma=0.35): 62 | def gaussianNoisy(im, mean=mean, sigma=sigma): 63 | for _i in range(len(im)): 64 | im[_i] += random.gauss(mean, sigma) 65 | return im 66 | 67 | img = np.asarray(image) 68 | height, width = img.shape # numpy shape is (rows, cols) for a grayscale image 69 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 70 | img = img.reshape([height, width]) 71 | return Image.fromarray(np.uint8(img)) 72 | 73 | 74 | def randomPeper(img): 75 | img = np.array(img) 76 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1]) 77 | for i in range(noiseNum): 78 | 79 | randX = random.randint(0, img.shape[0] - 1) 80 | 81 | randY = random.randint(0, img.shape[1] - 1) 82 | 83 | if random.randint(0, 1) == 0: 84 | 85 | img[randX, randY] = 0 86 | 87 | else: 88 | 89 | img[randX, randY] = 255 90 | return Image.fromarray(img) 91 | 92 | 93 | # dataset for training 94 | # The current loader does not use normalized depth maps for training and testing. If you use normalized depth maps 95 | # (e.g., 0 represents background and 1 represents foreground), the performance will be further improved. 
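# A minimal sketch (not part of the original code) of the depth normalization
# the comment above suggests: min-max rescale a depth map to [0, 1], optionally
# inverting it when the raw data encodes foreground with small values. The
# helper name and the `invert` flag are illustrative only; check your dataset's
# depth convention before using it.
def normalize_depth(depth_pil, invert=False):
    depth = np.asarray(depth_pil, dtype=np.float32)
    d_min, d_max = depth.min(), depth.max()
    depth = (depth - d_min) / (d_max - d_min + 1e-8)  # rescale to [0, 1]
    if invert:
        depth = 1.0 - depth  # make 1 correspond to (near) foreground
    return Image.fromarray(np.uint8(depth * 255))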
96 | class SalObjDataset(data.Dataset): 97 | def __init__(self, image_root, gt_root, depth_root, trainsize): 98 | self.trainsize = trainsize 99 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 100 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 101 | or f.endswith('.png')] 102 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 103 | or f.endswith('.png')] 104 | self.images = sorted(self.images) 105 | # print(self.images) 106 | self.gts = sorted(self.gts) 107 | # print(self.gts) 108 | # print(self.contours) 109 | self.depths = sorted(self.depths) 110 | # print(self.depths) 111 | self.filter_files() 112 | self.size = len(self.images) 113 | self.img_transform = transforms.Compose([ 114 | transforms.Resize((self.trainsize, self.trainsize)), 115 | transforms.ToTensor(), 116 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 117 | self.gt_transform = transforms.Compose([ 118 | transforms.Resize((self.trainsize, self.trainsize)), 119 | transforms.ToTensor()]) 120 | self.depths_transform = transforms.Compose( 121 | [transforms.Resize((self.trainsize, self.trainsize)), transforms.ToTensor()]) 122 | 123 | def __getitem__(self, index): 124 | image = self.rgb_loader(self.images[index]) 125 | gt = self.binary_loader(self.gts[index]) 126 | depth = self.binary_loader(self.depths[index]) 127 | image, gt, depth = cv_random_flip(image, gt, depth) 128 | image, gt, depth = randomCrop(image, gt, depth) 129 | image, gt, depth = randomRotation(image, gt, depth) 130 | image = colorEnhance(image) 131 | # gt=randomGaussian(gt) 132 | # gt = randomPeper(gt) 133 | image = self.img_transform(image) 134 | gt = self.gt_transform(gt) 135 | depth = self.depths_transform(depth) 136 | 137 | 138 | return image, gt, depth 139 | 140 | def filter_files(self): 141 | assert len(self.images) == len(self.gts) and len(self.gts) == len(self.depths) 142 | images = [] 143 | gts = [] 144 | depths = [] 145 | for img_path, gt_path, depth_path in zip(self.images, self.gts, self.depths): 146 | img = Image.open(img_path) 147 | gt = Image.open(gt_path) 148 | depth = Image.open(depth_path) 149 | if img.size == gt.size and gt.size == depth.size: 150 | images.append(img_path) 151 | gts.append(gt_path) 152 | depths.append(depth_path) 153 | self.images = images 154 | self.gts = gts 155 | self.depths = depths 156 | 157 | def rgb_loader(self, path): 158 | with open(path, 'rb') as f: 159 | img = Image.open(f) 160 | return img.convert('RGB') 161 | 162 | def binary_loader(self, path): 163 | with open(path, 'rb') as f: 164 | img = Image.open(f) 165 | return img.convert('L') 166 | 167 | def resize(self, img, gt, depth): 168 | assert img.size == gt.size and gt.size == depth.size 169 | w, h = img.size 170 | if h < self.trainsize or w < self.trainsize: 171 | h = max(h, self.trainsize) 172 | w = max(w, self.trainsize) 173 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h), 174 | Image.NEAREST) 175 | else: 176 | return img, gt, depth 177 | 178 | def __len__(self): 179 | return self.size 180 | 181 | 182 | # dataloader for training 183 | def get_loader(image_root, gt_root, depth_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=False): 184 | dataset = SalObjDataset(image_root, gt_root, depth_root, trainsize) 185 | data_loader = data.DataLoader(dataset=dataset, 186 | batch_size=batchsize, 187 | shuffle=shuffle, 188 | num_workers=num_workers, 189 | pin_memory=pin_memory) 190 | 
return data_loader 191 | 192 | 193 | # test dataset and loader 194 | class test_dataset: 195 | def __init__(self, image_root, depth_root, testsize): 196 | self.testsize = testsize 197 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 198 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 199 | or f.endswith('.png')] 200 | self.images = sorted(self.images) 201 | self.depths = sorted(self.depths) 202 | self.transform = transforms.Compose([ 203 | transforms.Resize((self.testsize, self.testsize)), 204 | transforms.ToTensor(), 205 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 206 | # self.gt_transform = transforms.Compose([ 207 | # transforms.Resize((self.trainsize, self.trainsize)), 208 | # transforms.ToTensor()]) 209 | self.depths_transform = transforms.Compose( 210 | [transforms.Resize((self.testsize, self.testsize)), transforms.ToTensor()]) 211 | self.size = len(self.images) 212 | self.index = 0 213 | 214 | def load_data(self): 215 | image = self.rgb_loader(self.images[self.index]) 216 | HH = image.size[0] 217 | WW = image.size[1] 218 | image = self.transform(image).unsqueeze(0) 219 | depth = self.rgb_loader(self.depths[self.index]) 220 | depth = self.depths_transform(depth).unsqueeze(0) 221 | 222 | name = self.images[self.index].split('/')[-1] 223 | # image_for_post=self.rgb_loader(self.images[self.index]) 224 | # image_for_post=image_for_post.resize(gt.size) 225 | if name.endswith('.jpg'): 226 | name = name.split('.jpg')[0] + '.png' 227 | self.index += 1 228 | self.index = self.index % self.size 229 | return image, depth, HH, WW, name 230 | 231 | def rgb_loader(self, path): 232 | with open(path, 'rb') as f: 233 | img = Image.open(f) 234 | return img.convert('RGB') 235 | 236 | def binary_loader(self, path): 237 | with open(path, 'rb') as f: 238 | img = Image.open(f) 239 | return img.convert('L') 240 | 241 | def __len__(self): 242 | return self.size 243 | 244 | -------------------------------------------------------------------------------- /model/model_stage2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torchvision import models 5 | 6 | class Contour_Estimation(nn.Module): 7 | 8 | def __init__(self): 9 | super(Contour_Estimation, self).__init__() 10 | self.fuse1_conv1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU()) 11 | self.fuse1_conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU()) 12 | self.fuse1_conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU()) 13 | self.fuse1_conv4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU()) 14 | self.fuse1_conv5 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU()) 15 | 16 | self.fuse1_conv1_fpn = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU()) 17 | self.fuse1_conv2_fpn = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU()) 18 | self.fuse1_conv3_fpn = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU()) 19 | self.fuse1_conv4_fpn = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU()) 20 | self.fuse1_conv5_fpn = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), 
nn.BatchNorm2d(64), nn.PReLU()) 21 | 22 | self.fuse2_conv1 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU()) 23 | self.fuse2_conv2 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU()) 24 | self.fuse2_conv3 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU()) 25 | self.fuse2_conv4 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU()) 26 | self.fuse2_conv5 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU()) 27 | 28 | self.fuse2_conv1_fpn = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU()) 29 | self.fuse2_conv2_fpn = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU()) 30 | self.fuse2_conv3_fpn = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU()) 31 | self.fuse2_conv4_fpn = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU()) 32 | self.fuse2_conv5_fpn = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU()) 33 | 34 | self.fuse3_conv1 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU()) 35 | self.fuse3_conv2 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU()) 36 | self.fuse3_conv3 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU()) 37 | self.fuse3_conv4 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU()) 38 | self.fuse3_conv5 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU()) 39 | 40 | self.fuse3_conv1_fpn = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU()) 41 | self.fuse3_conv2_fpn = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU()) 42 | self.fuse3_conv3_fpn = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU()) 43 | self.fuse3_conv4_fpn = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU()) 44 | self.fuse3_conv5_fpn = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU()) 45 | 46 | self.fuse4_conv1 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU()) 47 | self.fuse4_conv2 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU()) 48 | self.fuse4_conv3 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU()) 49 | self.fuse4_conv4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU()) 50 | self.fuse4_conv5 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU()) 51 | 52 | self.fuse4_conv1_fpn = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU()) 53 | self.fuse4_conv2_fpn = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU()) 54 | self.fuse4_conv3_fpn = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU()) 55 | self.fuse4_conv4_fpn = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU()) 56 | 
self.fuse4_conv5_fpn = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU()) 57 | 58 | self.fuse5_conv1 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU()) 59 | self.fuse5_conv2 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU()) 60 | self.fuse5_conv3 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU()) 61 | self.fuse5_conv4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU()) 62 | self.fuse5_conv5 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU()) 63 | 64 | self.sideout5 = nn.Sequential(nn.Conv2d(512, 1, kernel_size=3, padding=1)) 65 | self.sideout4 = nn.Sequential(nn.Conv2d(256, 1, kernel_size=3, padding=1)) 66 | self.sideout3 = nn.Sequential(nn.Conv2d(128, 1, kernel_size=3, padding=1)) 67 | self.sideout2 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=3, padding=1)) 68 | 69 | 70 | self.output4 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU()) 71 | self.output3 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU()) 72 | self.output2 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1),nn.BatchNorm2d(64), nn.PReLU()) 73 | self.output1 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=3, padding=1)) 74 | 75 | 76 | for m in self.modules(): 77 | if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout): 78 | m.inplace = True 79 | 80 | def forward(self, e5_rgb,e4_rgb,e3_rgb,e2_rgb,e1_rgb, e5_depth,e4_depth,e3_depth,e2_depth,e1_depth): 81 | 82 | certain_feature5 = e5_rgb * e5_depth 83 | fuse5_conv1 = self.fuse5_conv1(e5_rgb+certain_feature5) 84 | fuse5_conv2 = self.fuse5_conv2(e5_depth+certain_feature5) 85 | fuse5_certain = self.fuse5_conv3(fuse5_conv1+fuse5_conv2) 86 | 87 | uncertain_feature5 = self.fuse5_conv4(torch.abs(e5_rgb - e5_depth)) 88 | fuse5 = self.fuse5_conv5(fuse5_certain + uncertain_feature5) 89 | sideout5 = self.sideout5(fuse5) 90 | 91 | certain_feature4 = F.upsample(F.sigmoid(sideout5),size=e4_rgb.size()[2:], mode='bilinear')*e4_rgb * e4_depth 92 | fuse4_conv1 = self.fuse4_conv1(e4_rgb+certain_feature4) 93 | fuse4_conv2 = self.fuse4_conv2(e4_depth+certain_feature4) 94 | fuse4_certain = self.fuse4_conv3(fuse4_conv1+fuse4_conv2) 95 | uncertain_feature4 = self.fuse4_conv4(F.upsample(F.sigmoid(sideout5),size=e4_rgb.size()[2:], mode='bilinear')*torch.abs(e4_rgb - e4_depth)) 96 | fuse4 = self.fuse4_conv5(fuse4_certain + uncertain_feature4) 97 | ### 98 | fuse5_fpn = F.upsample(fuse5, size=fuse4.size()[2:], mode='bilinear') 99 | fpn_certain_feature4 = F.upsample(F.sigmoid(sideout5), size=fuse4.size()[2:], mode='bilinear') * fuse4 * fuse5_fpn 100 | fuse4_fpn_conv1 = self.fuse4_conv1_fpn(fuse5_fpn + fpn_certain_feature4) 101 | fuse4_fpn_conv2 = self.fuse4_conv2_fpn(fuse4 + fpn_certain_feature4) 102 | fuse4_certain_fpn = self.fuse4_conv3_fpn(fuse4_fpn_conv1 + fuse4_fpn_conv2) 103 | fpn_uncertain_feature4 = self.fuse4_conv4_fpn( 104 | F.upsample(F.sigmoid(sideout5), size=fuse4.size()[2:], mode='bilinear') * torch.abs(fuse4 - fuse5_fpn)) 105 | fuse4_fpn = self.fuse4_conv5_fpn(fuse4_certain_fpn + fpn_uncertain_feature4) 106 | output4 = self.output4(fuse4_fpn) 107 | sideout4 = self.sideout4(output4) 108 | 109 | certain_feature3 = F.upsample(F.sigmoid(sideout4),size=e3_rgb.size()[2:], mode='bilinear')*e3_rgb * e3_depth 110 | fuse3_conv1 = 
self.fuse3_conv1(e3_rgb + certain_feature3) 111 | fuse3_conv2 = self.fuse3_conv2(e3_depth + certain_feature3) 112 | fuse3_certain = self.fuse3_conv3(fuse3_conv1 + fuse3_conv2) 113 | uncertain_feature3 = self.fuse3_conv4( F.upsample(F.sigmoid(sideout4),size=e3_rgb.size()[2:], mode='bilinear')*torch.abs(e3_rgb - e3_depth)) 114 | fuse3 = self.fuse3_conv5(fuse3_certain + uncertain_feature3) 115 | ## 116 | output4_fpn = F.upsample(output4, size=fuse3.size()[2:], mode='bilinear') 117 | fpn_certain_feature3 = F.upsample(F.sigmoid(sideout4), size=fuse3.size()[2:], mode='bilinear') * output4_fpn * fuse3 118 | fuse3_fpn_conv1 = self.fuse3_conv1_fpn(output4_fpn + fpn_certain_feature3) 119 | fuse3_fpn_conv2 = self.fuse3_conv2_fpn(fuse3 + fpn_certain_feature3) 120 | fuse3_certain_fpn = self.fuse3_conv3_fpn(fuse3_fpn_conv1 + fuse3_fpn_conv2) 121 | 122 | fpn_uncertain_feature3 = self.fuse3_conv4_fpn( 123 | F.upsample(F.sigmoid(sideout4), size=fuse3.size()[2:], mode='bilinear') * torch.abs(fuse3 - output4_fpn)) 124 | fuse3_fpn = self.fuse3_conv5_fpn(fuse3_certain_fpn + fpn_uncertain_feature3) 125 | output3 = self.output3(fuse3_fpn) 126 | sideout3 = self.sideout3(output3) 127 | 128 | certain_feature2 = F.upsample(F.sigmoid(sideout3),size=e2_rgb.size()[2:], mode='bilinear')*e2_rgb* e2_depth 129 | fuse2_conv1 = self.fuse2_conv1(e2_rgb + certain_feature2) 130 | fuse2_conv2 = self.fuse2_conv2(e2_depth + certain_feature2) 131 | fuse2_certain = self.fuse2_conv3(fuse2_conv1 + fuse2_conv2) 132 | uncertain_feature2 = self.fuse2_conv4(F.upsample(F.sigmoid(sideout3),size=e2_rgb.size()[2:], mode='bilinear')*torch.abs(e2_rgb - e2_depth)) 133 | fuse2 = self.fuse2_conv5(fuse2_certain + uncertain_feature2) 134 | 135 | output3_fpn = F.upsample(output3, size=fuse2.size()[2:], mode='bilinear') 136 | fpn_certain_feature2 = F.upsample(F.sigmoid(sideout3), size=fuse2.size()[2:], mode='bilinear') * output3_fpn * fuse2 137 | fuse2_fpn_conv1 = self.fuse2_conv1_fpn(output3_fpn + fpn_certain_feature2) 138 | fuse2_fpn_conv2 = self.fuse2_conv2_fpn(fuse2 + fpn_certain_feature2) 139 | fuse2_certain_fpn = self.fuse2_conv3_fpn(fuse2_fpn_conv1 + fuse2_fpn_conv2) 140 | 141 | fpn_uncertain_feature2 = self.fuse2_conv4_fpn( 142 | F.upsample(F.sigmoid(sideout3), size=fuse2.size()[2:], mode='bilinear') * torch.abs(fuse2 - output3_fpn)) 143 | fuse2_fpn = self.fuse2_conv5_fpn(fuse2_certain_fpn + fpn_uncertain_feature2) 144 | output2 = self.output2(fuse2_fpn) 145 | sideout2 = self.sideout2(output2) 146 | 147 | certain_feature1 = F.upsample(F.sigmoid(sideout2),size=e1_rgb.size()[2:],mode='bilinear')*e1_rgb * e1_depth 148 | fuse1_conv1 = self.fuse1_conv1(e1_rgb+certain_feature1) 149 | fuse1_conv2 = self.fuse1_conv2(e1_depth+certain_feature1) 150 | fuse1_certain = self.fuse1_conv3(fuse1_conv1+fuse1_conv2) 151 | uncertain_feature1 = self.fuse1_conv4(F.upsample(F.sigmoid(sideout2),size=e1_rgb.size()[2:],mode='bilinear')*torch.abs(e1_rgb - e1_depth)) 152 | fuse1 = self.fuse1_conv5(fuse1_certain+uncertain_feature1) 153 | 154 | output2_fpn = F.upsample(output2, size=fuse1.size()[2:], mode='bilinear') 155 | fpn_certain_feature1 = F.upsample(F.sigmoid(sideout2), size=fuse1.size()[2:], 156 | mode='bilinear') * output2_fpn * fuse1 157 | fuse1_fpn_conv1 = self.fuse1_conv1_fpn(output2_fpn + fpn_certain_feature1) 158 | fuse1_fpn_conv2 = self.fuse1_conv2_fpn(fuse1 + fpn_certain_feature1) 159 | fuse1_certain_fpn = self.fuse1_conv3_fpn(fuse1_fpn_conv1 + fuse1_fpn_conv2) 160 | 161 | fpn_uncertain_feature1 = self.fuse1_conv4_fpn( 162 | F.upsample(F.sigmoid(sideout2), 
size=fuse1.size()[2:], mode='bilinear') * torch.abs(fuse1 - output2_fpn)) 163 | fuse1_fpn = self.fuse1_conv5_fpn(fuse1_certain_fpn + fpn_uncertain_feature1) 164 | output1 = self.output1(fuse1_fpn) 165 | 166 | if self.training: 167 | return sideout5,sideout4,sideout3,sideout2,output1 168 | return output1 169 | 170 | 171 | if __name__ == "__main__": 172 | # Quick FLOPs/params sanity check with thop (pip install thop). 173 | # Contour_Estimation expects the ten encoder features produced by the 174 | # stage-1 cross-modal autoencoder; the channel widths and spatial sizes 175 | # below assume a 256x256 input and a VGG-16-style feature pyramid. 176 | from thop import profile, clever_format 177 | model = Contour_Estimation() 178 | shapes = [(512, 16), (512, 32), (256, 64), (128, 128), (64, 256)] # (e5..e1) 179 | feats = [torch.randn(1, c, s, s) for c, s in shapes] 180 | flops, params = profile(model, inputs=(*feats, *feats)) # rgb + depth branches 181 | flops, params = clever_format([flops, params], "%.3f") 182 | print(flops, params) -------------------------------------------------------------------------------- /utils_downstream/saliency_metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import ndimage 3 | from scipy.ndimage import convolve, distance_transform_edt as bwdist 4 | 5 | 6 | class cal_fm(object): 7 | # Fmeasure(maxFm,meanFm)---Frequency-tuned salient region detection(CVPR 2009) 8 | def __init__(self, num, thds=255): 9 | self.num = num 10 | self.thds = thds 11 | self.precision = np.zeros((self.num, self.thds)) 12 | self.recall = np.zeros((self.num, self.thds)) 13 | self.meanF = np.zeros((self.num,1)) 14 | self.idx = 0 15 | 16 | def update(self, pred, gt): 17 | if gt.max() != 0: 18 | precision, recall, Fmeasure_temp = self.cal(pred, gt) 19 | self.precision[self.idx, :] = precision 20 | self.recall[self.idx, :] = recall 21 | self.meanF[self.idx, :] = Fmeasure_temp 22 | self.idx += 1 23 | 24 | def cal(self, pred, gt): 25 | ########################meanF############################## 26 | th = 2 * pred.mean() 27 | if th > 1: 28 | th = 1 29 | binary = np.zeros_like(pred) 30 | binary[pred >= th] = 1 31 | hard_gt = np.zeros_like(gt) 32 | hard_gt[gt > 0.5] = 1 33 | tp = (binary * hard_gt).sum() 34 | if tp == 0: 35 | meanF = 0 36 | else: 37 | pre = tp / binary.sum() 38 | rec = tp / hard_gt.sum() 39 | meanF = 1.3 * pre * rec / (0.3 * pre + rec) 40 | ########################maxF############################## 41 | pred = np.uint8(pred * 255) 42 | target = pred[gt > 0.5] 43 | nontarget = pred[gt <= 0.5] 44 | targetHist, _ = np.histogram(target, bins=range(256)) 45 | nontargetHist, _ = np.histogram(nontarget, bins=range(256)) 46 | targetHist = np.cumsum(np.flip(targetHist), axis=0) 47 | nontargetHist = np.cumsum(np.flip(nontargetHist), axis=0) 48 | precision = targetHist / (targetHist + nontargetHist + 1e-8) 49 | recall = targetHist / np.sum(gt) 50 | return precision, recall, meanF 51 | 52 | def show(self): 53 | assert self.num == self.idx 54 | precision = self.precision.mean(axis=0) 55 | recall = self.recall.mean(axis=0) 56 | fmeasure = 1.3 * precision * recall / (0.3 * precision + recall + 1e-8) 57 | fmeasure_avg = self.meanF.mean(axis=0) 58 | return fmeasure.max(),fmeasure_avg[0],precision,recall 59 | 60 | 61 | class cal_mae(object): 62 | # mean absolute error 63 | def __init__(self): 64 | self.prediction = [] 65 | 66 | def update(self, pred, gt): 67 | score = self.cal(pred, gt) 68 | self.prediction.append(score) 69 | 70 | def cal(self, pred, gt): 71 | return np.mean(np.abs(pred - gt)) 72 | 73 | def show(self): 74 | return np.mean(self.prediction) 75 | 76 | class cal_dice(object): 77 | # Dice coefficient 78 | def __init__(self): 79 | self.prediction = [] 80 | 81 | def update(self, pred, gt): 82 | score = self.cal(pred, gt) 83 | self.prediction.append(score) 84 | 85 | def cal(self, y_pred, y_true): 86 | # smooth = 1 87 | smooth = 1e-5 88 | y_true_f = 
y_true.flatten() 89 | y_pred_f = y_pred.flatten() 90 | intersection = np.sum(y_true_f * y_pred_f) 91 | return (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth) 92 | 93 | def show(self): 94 | return np.mean(self.prediction) 95 | 96 | class cal_ber(object): 97 | # balanced error rate (BER) 98 | def __init__(self): 99 | self.prediction = [] 100 | 101 | def update(self, pred, gt): 102 | score = self.cal(pred, gt) 103 | self.prediction.append(score) 104 | 105 | def cal(self, y_pred, y_true): 106 | binary = np.zeros_like(y_pred) 107 | binary[y_pred >= 0.5] = 1 108 | hard_gt = np.zeros_like(y_true) 109 | hard_gt[y_true > 0.5] = 1 110 | tp = (binary * hard_gt).sum() 111 | tn = ((1-binary) * (1-hard_gt)).sum() 112 | Np = hard_gt.sum() 113 | Nn = (1-hard_gt).sum() 114 | ber = (1-(tp/(Np+1e-8)+tn/(Nn+1e-8))/2) 115 | return ber 116 | 117 | def show(self): 118 | return np.mean(self.prediction) 119 | 120 | class cal_acc(object): 121 | # pixel accuracy 122 | def __init__(self): 123 | self.prediction = [] 124 | 125 | def update(self, pred, gt): 126 | score = self.cal(pred, gt) 127 | self.prediction.append(score) 128 | 129 | def cal(self, y_pred, y_true): 130 | binary = np.zeros_like(y_pred) 131 | binary[y_pred >= 0.5] = 1 132 | hard_gt = np.zeros_like(y_true) 133 | hard_gt[y_true > 0.5] = 1 134 | tp = (binary * hard_gt).sum() 135 | tn = ((1-binary) * (1-hard_gt)).sum() 136 | Np = hard_gt.sum() 137 | Nn = (1-hard_gt).sum() 138 | acc = ((tp+tn)/(Np+Nn)) 139 | return acc 140 | 141 | def show(self): 142 | return np.mean(self.prediction) 143 | 144 | class cal_iou(object): 145 | # intersection over union (IoU) 146 | def __init__(self): 147 | self.prediction = [] 148 | 149 | def update(self, pred, gt): 150 | score = self.cal(pred, gt) 151 | self.prediction.append(score) 152 | 153 | # def cal(self, input, target): 154 | # classes = 1 155 | # intersection = np.logical_and(target == classes, input == classes) 156 | # # print(intersection.any()) 157 | # union = np.logical_or(target == classes, input == classes) 158 | # return np.sum(intersection) / np.sum(union) 159 | 160 | def cal(self, input, target): 161 | smooth = 1e-5 162 | input = input > 0.5 163 | target_ = target > 0.5 164 | intersection = (input & target_).sum() 165 | union = (input | target_).sum() 166 | 167 | return (intersection + smooth) / (union + smooth) 168 | def show(self): 169 | return np.mean(self.prediction) 170 | 171 | # smooth = 1e-5 172 | # 173 | # if torch.is_tensor(output): 174 | # output = torch.sigmoid(output).data.cpu().numpy() 175 | # if torch.is_tensor(target): 176 | # target = target.data.cpu().numpy() 177 | # output_ = output > 0.5 178 | # target_ = target > 0.5 179 | # intersection = (output_ & target_).sum() 180 | # union = (output_ | target_).sum() 181 | 182 | # return (intersection + smooth) / (union + smooth) 183 | 184 | class cal_sm(object): 185 | # Structure-measure: A new way to evaluate foreground maps (ICCV 2017) 186 | def __init__(self, alpha=0.5): 187 | self.prediction = [] 188 | self.alpha = alpha 189 | 190 | def update(self, pred, gt): 191 | gt = gt > 0.5 192 | score = self.cal(pred, gt) 193 | self.prediction.append(score) 194 | 195 | def show(self): 196 | return np.mean(self.prediction) 197 | 198 | def cal(self, pred, gt): 199 | y = np.mean(gt) 200 | if y == 0: 201 | score = 1 - np.mean(pred) 202 | elif y == 1: 203 | score = np.mean(pred) 204 | else: 205 | score = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt) 206 | return score 207 | 208 | def object(self, pred, gt): 209 
| fg = pred * gt 210 | bg = (1 - pred) * (1 - gt) 211 | 212 | u = np.mean(gt) 213 | return u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, np.logical_not(gt)) 214 | 215 | def s_object(self, in1, in2): 216 | x = np.mean(in1[in2]) 217 | sigma_x = np.std(in1[in2]) 218 | return 2 * x / (pow(x, 2) + 1 + sigma_x + 1e-8) 219 | 220 | def region(self, pred, gt): 221 | [y, x] = ndimage.center_of_mass(gt) 222 | y = int(round(y)) + 1 223 | x = int(round(x)) + 1 224 | [gt1, gt2, gt3, gt4, w1, w2, w3, w4] = self.divideGT(gt, x, y) 225 | pred1, pred2, pred3, pred4 = self.dividePred(pred, x, y) 226 | 227 | score1 = self.ssim(pred1, gt1) 228 | score2 = self.ssim(pred2, gt2) 229 | score3 = self.ssim(pred3, gt3) 230 | score4 = self.ssim(pred4, gt4) 231 | 232 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 233 | 234 | def divideGT(self, gt, x, y): 235 | h, w = gt.shape 236 | area = h * w 237 | LT = gt[0:y, 0:x] 238 | RT = gt[0:y, x:w] 239 | LB = gt[y:h, 0:x] 240 | RB = gt[y:h, x:w] 241 | 242 | w1 = x * y / area 243 | w2 = y * (w - x) / area 244 | w3 = (h - y) * x / area 245 | w4 = (h - y) * (w - x) / area 246 | 247 | return LT, RT, LB, RB, w1, w2, w3, w4 248 | 249 | def dividePred(self, pred, x, y): 250 | h, w = pred.shape 251 | LT = pred[0:y, 0:x] 252 | RT = pred[0:y, x:w] 253 | LB = pred[y:h, 0:x] 254 | RB = pred[y:h, x:w] 255 | 256 | return LT, RT, LB, RB 257 | 258 | def ssim(self, in1, in2): 259 | in2 = np.float32(in2) 260 | h, w = in1.shape 261 | N = h * w 262 | 263 | x = np.mean(in1) 264 | y = np.mean(in2) 265 | sigma_x = np.var(in1) 266 | sigma_y = np.var(in2) 267 | sigma_xy = np.sum((in1 - x) * (in2 - y)) / (N - 1) 268 | 269 | alpha = 4 * x * y * sigma_xy 270 | beta = (x * x + y * y) * (sigma_x + sigma_y) 271 | 272 | if alpha != 0: 273 | score = alpha / (beta + 1e-8) 274 | elif alpha == 0 and beta == 0: 275 | score = 1 276 | else: 277 | score = 0 278 | 279 | return score 280 | 281 | class cal_em(object): 282 | #Enhanced-alignment Measure for Binary Foreground Map Evaluation (IJCAI 2018) 283 | def __init__(self): 284 | self.prediction = [] 285 | 286 | def update(self, pred, gt): 287 | score = self.cal(pred, gt) 288 | self.prediction.append(score) 289 | 290 | def cal(self, pred, gt): 291 | th = 2 * pred.mean() 292 | if th > 1: 293 | th = 1 294 | FM = np.zeros(gt.shape) 295 | FM[pred >= th] = 1 296 | FM = np.array(FM,dtype=bool) 297 | GT = np.array(gt,dtype=bool) 298 | dFM = np.double(FM) 299 | if (sum(sum(np.double(GT)))==0): 300 | enhanced_matrix = 1.0-dFM 301 | elif (sum(sum(np.double(~GT)))==0): 302 | enhanced_matrix = dFM 303 | else: 304 | dGT = np.double(GT) 305 | align_matrix = self.AlignmentTerm(dFM, dGT) 306 | enhanced_matrix = self.EnhancedAlignmentTerm(align_matrix) 307 | [w, h] = np.shape(GT) 308 | score = sum(sum(enhanced_matrix))/ (w * h - 1 + 1e-8) 309 | return score 310 | def AlignmentTerm(self,dFM,dGT): 311 | mu_FM = np.mean(dFM) 312 | mu_GT = np.mean(dGT) 313 | align_FM = dFM - mu_FM 314 | align_GT = dGT - mu_GT 315 | align_Matrix = 2. 
* (align_GT * align_FM)/ (align_GT* align_GT + align_FM* align_FM + 1e-8) 316 | return align_Matrix 317 | def EnhancedAlignmentTerm(self,align_Matrix): 318 | enhanced = np.power(align_Matrix + 1,2) / 4 319 | return enhanced 320 | def show(self): 321 | return np.mean(self.prediction) 322 | class cal_wfm(object): 323 | def __init__(self, beta=1): 324 | self.beta = beta 325 | self.eps = 1e-6 326 | self.scores_list = [] 327 | 328 | def update(self, pred, gt): 329 | assert pred.ndim == gt.ndim and pred.shape == gt.shape 330 | assert pred.max() <= 1 and pred.min() >= 0 331 | assert gt.max() <= 1 and gt.min() >= 0 332 | 333 | gt = gt > 0.5 334 | if gt.max() == 0: 335 | score = 0 336 | else: 337 | score = self.cal(pred, gt) 338 | self.scores_list.append(score) 339 | 340 | def matlab_style_gauss2D(self, shape=(7, 7), sigma=5): 341 | """ 342 | 2D gaussian mask - should give the same result as MATLAB's 343 | fspecial('gaussian',[shape],[sigma]) 344 | """ 345 | m, n = [(ss - 1.) / 2. for ss in shape] 346 | y, x = np.ogrid[-m:m + 1, -n:n + 1] 347 | h = np.exp(-(x * x + y * y) / (2. * sigma * sigma)) 348 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 349 | sumh = h.sum() 350 | if sumh != 0: 351 | h /= sumh 352 | return h 353 | 354 | def cal(self, pred, gt): 355 | # [Dst,IDXT] = bwdist(dGT); 356 | Dst, Idxt = bwdist(gt == 0, return_indices=True) 357 | 358 | # %Pixel dependency 359 | # E = abs(FG-dGT); 360 | E = np.abs(pred - gt) 361 | # Et = E; 362 | # Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region 363 | Et = np.copy(E) 364 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]] 365 | 366 | # K = fspecial('gaussian',7,5); 367 | # EA = imfilter(Et,K); 368 | # MIN_E_EA(GT & EA