├── image
│   └── logo.png
├── utils_downstream
│   ├── config.py
│   ├── misc.py
│   ├── test_data.py
│   ├── utils.py
│   ├── ssim_loss.py
│   ├── dataset_rgbd_strategy2.py
│   └── saliency_metric.py
├── get_contour.py
├── utils_ssl
│   ├── joint_transforms.py
│   ├── datasets_stage2.py
│   ├── datasets_stage1.py
│   └── misc.py
├── test_score.py
├── README.md
├── prediction_rgbd.py
├── model
│   ├── model_stage1.py
│   ├── model_stage2.py
│   └── model_stage3.py
├── train_stage3_downstream.py
├── train_stage2_pretext2.py
└── train_stage1_pretext1.py
/image/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xiaoqi-Zhao-DLUT/SSLSOD/HEAD/image/logo.png
--------------------------------------------------------------------------------
/utils_downstream/config.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 |
4 | dutrgbd_root_test = ''
5 | njud_root_test = ''
6 | nlpr_root_test = ''
7 | stere_root_test = ''
8 | sip_root_test = ''
9 | rgbd135_root_test = ''
10 | ssd_root_test = ''
11 | lfsd_root_test = ''
12 |
13 |
14 | dutrgbd = os.path.join(dutrgbd_root_test)
15 | njud = os.path.join(njud_root_test)
16 | nlpr = os.path.join(nlpr_root_test)
17 | stere = os.path.join(stere_root_test)
18 | sip = os.path.join(sip_root_test)
19 | rgbd135 = os.path.join(rgbd135_root_test)
20 | ssd = os.path.join(ssd_root_test)
21 | lfsd = os.path.join(lfsd_root_test)
22 |
23 |
24 | SSLSOD = '' # root folder of this model's predicted saliency maps
25 |
26 | RGBD_SOD_Models = {'SSLSOD':SSLSOD}
27 |
--------------------------------------------------------------------------------
/get_contour.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | from torchvision import transforms
5 | import cv2
6 | test = ''
7 | to_test = {'contour':test}
8 | img_transform = transforms.Compose([
9 | transforms.ToTensor()])
10 | save_path = ''
11 | to_pil = transforms.ToPILImage()
12 |
13 | for name, root in to_test.items():
14 | root1 = os.path.join(root)
15 | img_list = [os.path.splitext(f)[0] for f in os.listdir(root1) if f.endswith('.png')]
16 | for idx, img_name in enumerate(img_list):
17 | print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
18 | img1 = Image.open(os.path.join(root, img_name + '.png')).convert('L')
19 | img1 = np.array(img1)
20 |         kernel = np.ones((5, 5), np.uint8)  # 5x5 structuring element
21 |         img2 = cv2.erode(img1, kernel)      # shrunken mask
22 |         img3 = cv2.dilate(img1, kernel)     # expanded mask
23 |         img = np.array(img3 - img2)         # morphological gradient = contour band
24 |         img[img >= 6] = 255                 # binarize the contour map
25 |         img[img < 6] = 0
26 |         cv2.imwrite(os.path.join(save_path, img_name + '.jpg'), img, [int(cv2.IMWRITE_JPEG_LUMA_QUALITY), 50])
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/utils_downstream/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import pydensecrf.densecrf as dcrf
4 |
5 |
6 | class AvgMeter(object):
7 | def __init__(self):
8 | self.reset()
9 |
10 | def reset(self):
11 | self.val = 0
12 | self.avg = 0
13 | self.sum = 0
14 | self.count = 0
15 |
16 | def update(self, val, n=1):
17 | self.val = val
18 | self.sum += val * n
19 | self.count += n
20 | self.avg = self.sum / self.count
21 |
22 |
23 | def check_mkdir(dir_name):
24 | if not os.path.isdir(dir_name):
25 | os.makedirs(dir_name)
26 |
27 | def crf_refine(img, annos):
28 | def _sigmoid(x):
29 | return 1 / (1 + np.exp(-x))
30 |
31 | assert img.dtype == np.uint8
32 | assert annos.dtype == np.uint8
33 |     # print(img.shape[:2], annos.shape)
34 | assert img.shape[:2] == annos.shape
35 |
36 | # img and annos should be np array with data type uint8
37 |
38 | EPSILON = 1e-8
39 |
40 | M = 2 # salient or not
41 | tau = 1.05
42 | # Setup the CRF model
43 | d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M)
44 |
45 | anno_norm = annos / 255.
46 |
47 | n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm))
48 | p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm))
49 |
50 |     U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32')  # unary potentials, same spatial size as the input image
51 | U[0, :] = n_energy.flatten()
52 | U[1, :] = p_energy.flatten()
53 |
54 | d.setUnaryEnergy(U)
55 |
56 | d.addPairwiseGaussian(sxy=3, compat=3)
57 | d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)
58 |
59 | # Do the inference
60 | infer = np.array(d.inference(1)).astype('float32')
61 | res = infer[1, :]
62 |
63 | res = res * 255
64 |     res = res.reshape(img.shape[:2])  # back to the input image size
65 | return res.astype('uint8')
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
--------------------------------------------------------------------------------
/utils_ssl/joint_transforms.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import random
3 | from PIL import Image, ImageOps
4 |
5 | Image.MAX_IMAGE_PIXELS = 1000000000
6 |
7 |
8 | class Compose(object):
9 | def __init__(self, transforms):
10 | self.transforms = transforms
11 | def __call__(self, img, depth, mask):
12 | assert img.size == mask.size
13 | for t in self.transforms:
14 | img, depth, mask = t(img, depth, mask)
15 | return img, depth, mask
16 |
17 |
18 |
19 | class RandomCrop(object):
20 | def __init__(self, size,size1, padding=0):
21 | if isinstance(size, numbers.Number):
22 | self.size = (int(size), int(size1))
23 | else:
24 | self.size = size
25 | self.padding = padding
26 |
27 | def __call__(self, img, depth, mask):
28 | if self.padding > 0:
29 | img = ImageOps.expand(img, border=self.padding, fill=0)
30 | mask = ImageOps.expand(mask, border=self.padding, fill=0)
31 | depth = ImageOps.expand(depth, border=self.padding, fill=0)
32 |
33 | assert img.size == mask.size
34 | w, h = img.size
35 | th, tw = self.size
36 | if w == tw and h == th:
37 |             return img, depth, mask
38 | if w < tw or h < th:
39 | return img.resize((tw, th), Image.BILINEAR), depth.resize((tw, th), Image.NEAREST), mask.resize((tw, th), Image.NEAREST)
40 | return img.resize((tw, th), Image.BILINEAR), depth.resize((tw, th), Image.NEAREST), mask.resize((tw, th), Image.NEAREST)
41 |
42 | class RandomHorizontallyFlip(object):
43 | def __call__(self, img, depth, mask):
44 | if random.random() < 0.5:
45 | return img.transpose(Image.FLIP_LEFT_RIGHT), depth.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT)
46 | return img, depth, mask
47 |
48 |
49 | class RandomRotate(object):
50 | def __init__(self, degree):
51 | self.degree = degree
52 |
53 | def __call__(self, img, depth, mask):
54 | rotate_degree = random.random() * 2 * self.degree - self.degree
55 | return img.rotate(rotate_degree, Image.BILINEAR), depth.rotate(rotate_degree, Image.NEAREST), mask.rotate(rotate_degree, Image.NEAREST)
56 |
--------------------------------------------------------------------------------
/utils_downstream/test_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 |
6 | class test_dataset:
7 | def __init__(self, image_root, gt_root):
8 | self.img_list_1 = [os.path.splitext(f)[0] for f in os.listdir(image_root) if f.endswith('.png') or f.endswith('.jpg') or f.endswith('.bmp')]
9 | self.img_list_2 = [os.path.splitext(f)[0] for f in os.listdir(gt_root) if f.endswith('.png') or f.endswith('.jpg') or f.endswith('.bmp')]
10 | self.img_list = list(set(self.img_list_1).intersection(set(self.img_list_2)))
11 |
12 | self.image_root = image_root
13 | self.gt_root = gt_root
14 | self.transform = transforms.Compose([
15 | transforms.ToTensor(),
16 | ])
17 | self.gt_transform = transforms.ToTensor()
18 | self.size = len(self.img_list)
19 | self.index = 0
20 |
21 | def load_data(self):
22 | #image = self.rgb_loader(self.images[self.index])
23 | rgb_png_path = os.path.join(self.image_root,self.img_list[self.index]+ '.png')
24 | rgb_jpg_path = os.path.join(self.image_root,self.img_list[self.index]+ '.jpg')
25 | rgb_bmp_path = os.path.join(self.image_root,self.img_list[self.index]+ '.bmp')
26 | if os.path.exists(rgb_png_path):
27 | image = self.binary_loader(rgb_png_path)
28 | elif os.path.exists(rgb_jpg_path):
29 | image = self.binary_loader(rgb_jpg_path)
30 | else:
31 | image = self.binary_loader(rgb_bmp_path)
32 | if os.path.exists(os.path.join(self.gt_root,self.img_list[self.index] + '.png')):
33 | gt = self.binary_loader(os.path.join(self.gt_root,self.img_list[self.index] + '.png'))
34 | elif os.path.exists(os.path.join(self.gt_root,self.img_list[self.index] + '.jpg')):
35 | gt = self.binary_loader(os.path.join(self.gt_root,self.img_list[self.index] + '.jpg'))
36 | else:
37 | gt = self.binary_loader(os.path.join(self.gt_root, self.img_list[self.index] + '.bmp'))
38 |
39 | self.index += 1
40 | return image, gt
41 |
42 | def rgb_loader(self, path):
43 | with open(path, 'rb') as f:
44 | img = Image.open(f)
45 | return img.convert('RGB')
46 |
47 | def binary_loader(self, path):
48 | with open(path, 'rb') as f:
49 | img = Image.open(f)
50 | return img.convert('L')
51 |
52 |
--------------------------------------------------------------------------------
/utils_downstream/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import numpy as np
4 |
5 | def clip_gradient(optimizer, grad_clip):
6 | for group in optimizer.param_groups:
7 | for param in group['params']:
8 | if param.grad is not None:
9 | param.grad.data.clamp_(-grad_clip, grad_clip)
10 |
11 |
12 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=5):
13 | decay = decay_rate ** (epoch // decay_epoch)
14 | for param_group in optimizer.param_groups:
15 |         param_group['lr'] = decay * init_lr  # set from init_lr so the decay does not compound across calls
16 |
17 |
18 | def truncated_normal_(tensor, mean=0, std=1):
19 | size = tensor.shape
20 | tmp = tensor.new_empty(size + (4,)).normal_()
21 | valid = (tmp < 2) & (tmp > -2)
22 | ind = valid.max(-1, keepdim=True)[1]
23 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
24 | tensor.data.mul_(std).add_(mean)
25 |
26 | def init_weights(m):
27 | if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
28 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
29 | #nn.init.normal_(m.weight, std=0.001)
30 | #nn.init.normal_(m.bias, std=0.001)
31 | truncated_normal_(m.bias, mean=0, std=0.001)
32 |
33 | def init_weights_orthogonal_normal(m):
34 | if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
35 | nn.init.orthogonal_(m.weight)
36 | truncated_normal_(m.bias, mean=0, std=0.001)
37 | #nn.init.normal_(m.bias, std=0.001)
38 |
39 | def l2_regularisation(m):
40 | l2_reg = None
41 |
42 | for W in m.parameters():
43 | if l2_reg is None:
44 | l2_reg = W.norm(2)
45 | else:
46 | l2_reg = l2_reg + W.norm(2)
47 | return l2_reg
48 |
49 | class AvgMeter(object):
50 | def __init__(self, num=40):
51 | self.num = num
52 | self.reset()
53 |
54 | def reset(self):
55 | self.val = 0
56 | self.avg = 0
57 | self.sum = 0
58 | self.count = 0
59 | self.losses = []
60 |
61 | def update(self, val, n=1):
62 | self.val = val
63 | self.sum += val * n
64 | self.count += n
65 | self.avg = self.sum / self.count
66 | self.losses.append(val)
67 |
68 | def show(self):
69 | a = len(self.losses)
70 | b = np.maximum(a-self.num, 0)
71 | c = self.losses[b:]
72 | #print(c)
73 | #d = torch.mean(torch.stack(c))
74 | #print(d)
75 | return torch.mean(torch.stack(c))
76 |
77 | # def save_mask_prediction_example(mask, pred, iter):
78 | # plt.imshow(pred[0,:,:],cmap='Greys')
79 | # plt.savefig('images/'+str(iter)+"_prediction.png")
80 | # plt.imshow(mask[0,:,:],cmap='Greys')
81 | # plt.savefig('images/'+str(iter)+"_mask.png")
--------------------------------------------------------------------------------
/test_score.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from utils_downstream.test_data import test_dataset
4 | from utils_downstream.saliency_metric import cal_mae,cal_fm,cal_sm,cal_em,cal_wfm, cal_dice, cal_iou,cal_ber,cal_acc
5 | from utils_downstream.config import dutrgbd,njud,nlpr,stere,sip,rgbd135,ssd,lfsd
6 | from utils_downstream.config import RGBD_SOD_Models
7 | from tqdm import tqdm
8 |
9 | test_datasets = {'DUT-RGBD':dutrgbd,'NJU2K':njud,'NJUD':njud,'NLPR':nlpr,'STERE':stere,'SIP':sip,'DES':rgbd135,'RGBD135':rgbd135,'SSD':ssd,'LFSD':lfsd}
10 | for method_name,method_map_root in RGBD_SOD_Models.items():
11 | print(method_name)
12 | for name, root in test_datasets.items():
13 | print(name)
14 | sal_root = method_map_root +name
15 | print(sal_root)
16 | gt_root = root+'/GT'
17 | print(gt_root)
18 | if os.path.exists(sal_root):
19 | test_loader = test_dataset(sal_root, gt_root)
20 | mae,fm,sm,em,wfm, m_dice, m_iou,ber,acc= cal_mae(),cal_fm(test_loader.size),cal_sm(),cal_em(),cal_wfm(), cal_dice(), cal_iou(),cal_ber(),cal_acc()
21 | for i in tqdm(range(test_loader.size)):
22 | # print ('predicting for %d / %d' % ( i + 1, test_loader.size))
23 | sal, gt = test_loader.load_data()
24 | if sal.size != gt.size:
25 | x, y = gt.size
26 | sal = sal.resize((x, y))
27 | gt = np.asarray(gt, np.float32)
28 | gt /= (gt.max() + 1e-8)
29 | gt[gt > 0.5] = 1
30 | gt[gt != 1] = 0
31 | res = sal
32 | res = np.array(res)
33 | if res.max() == res.min():
34 | res = res/255
35 | else:
36 | res = (res - res.min()) / (res.max() - res.min())
37 |                 # binarization improves MAE, meanF and Em
38 | # res[res > 0.5] = 1
39 | # res[res != 1] = 0
40 |
41 | mae.update(res, gt)
42 | sm.update(res,gt)
43 | fm.update(res, gt)
44 | em.update(res,gt)
45 | wfm.update(res,gt)
46 | m_dice.update(res,gt)
47 | m_iou.update(res,gt)
48 | ber.update(res,gt)
49 | acc.update(res,gt)
50 |
51 | MAE = mae.show()
52 | maxf,meanf,_,_ = fm.show()
53 | sm = sm.show()
54 | em = em.show()
55 | wfm = wfm.show()
56 | m_dice = m_dice.show()
57 | m_iou = m_iou.show()
58 | ber = ber.show()
59 | acc = acc.show()
60 | print('method_name: {} dataset: {} MAE: {:.4f} Ber: {:.4f} maxF: {:.4f} avgF: {:.4f} wfm: {:.4f} Sm: {:.4f} Em: {:.4f} M_dice: {:.4f} M_iou: {:.4f} Acc: {:.4f}'.format(method_name,name, MAE,ber, maxf,meanf,wfm,sm,em, m_dice, m_iou,acc))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SSLSOD
2 |
3 |
4 |
5 |
6 | Self-Supervised Pretraining for RGB-D Salient Object Detection
7 |
8 |
9 | Xiaoqi Zhao, Youwei Pang, Lihe Zhang, Huchuan Lu, Xiang Ruan
10 |
11 | ⭐ arXiv »
12 | :fire:[Slide & 极市平台 (CVMart) article]
13 |
14 |
15 |
16 |
17 | The official repo of the AAAI 2022 paper, Self-Supervised Pretraining for RGB-D Salient Object Detection.
18 | ## Saliency map
19 | [Google Drive](https://drive.google.com/file/d/1i5OElgml76p76N2l9eYlFc4Bm8jOlxUk/view?usp=sharing) / [BaiduYunPan(d63j)](https://pan.baidu.com/s/1qifMM7wgR5gPhb6ZlRU9Zw)
20 | ## Trained Model
21 | You can download all trained models at [Google Drive](https://drive.google.com/file/d/1mxX4yk6yOCTapJ_dn_5nZhnb8IpvzEt0/view?usp=sharing) / [BaiduYunPan(0401)](https://pan.baidu.com/s/1zruPGxeR-7j4bfNSrzU5Lg).
22 | ## Datasets
23 | * [Google Drive](https://drive.google.com/file/d/1Nxm8wr2jSW-Ntqu8cdm4GfZPVOClJbZE/view?usp=sharing) / [BaiduYunPan(1a4t)](https://pan.baidu.com/s/1DUHzxs4JP4hzWJIoz4Lqyg)
24 | * We believe that pre-training on a large amount of RGB-D data would produce a very strong SSL-based model, possibly even surpassing the ImageNet-based one. This [survey](https://arxiv.org/pdf/2201.05761.pdf) of RGB-D datasets may be helpful to you.
25 | ## Training
26 | * SSL-based model
27 | 1. Run train_stage1_pretext1.py.
28 | 2. Run get_contour.py to generate the depth-contour maps for the stage-2 training.
29 | 3. Load the pretext1 weights into Crossmodal_Autoendoer (model_stage1.py) and run train_stage2_pretext2.py.
30 | 4. Load the pretext1 and pretext2 weights into RGBD_sal (model_stage3.py) as initialization and run train_stage3_downstream.py (see the weight-loading sketch below).
31 | * ImageNet-based model
32 | Set pretrained=True for models.vgg16_bn in RGBD_sal (model_stage3.py) and run train_stage3_downstream.py.
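
For reference, stage 3 initializes RGBD_sal by filtering each pretext checkpoint by parameter name before loading; a minimal sketch mirroring the loading code in train_stage3_downstream.py (the checkpoint paths are placeholders for your own saved weights):

```python
import torch
from model.model_stage3 import RGBD_sal

generator = RGBD_sal().cuda()
model_dict = generator.state_dict()
for ckpt in ('./saved_model/pretext_task1.pth', './saved_model/pretext_task2.pth'):
    pretrained_dict = torch.load(ckpt)
    # keep only the weights whose names also exist in RGBD_sal
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
generator.load_state_dict(model_dict)
```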
33 | ## Testing
34 | Run prediction_rgbd.py to generate the predicted saliency maps.
35 | Run test_score.py to evaluate the predicted maps in terms of maxF, meanF, wFm, Sm, Em, MAE, mDice, mIoU, BER and Acc. Both scripts read their dataset and model paths from utils_downstream/config.py (see the sketch below).
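
The path variables in utils_downstream/config.py ship empty; a hypothetical example of filling them in, assuming the layout the scripts expect (prediction_rgbd.py reads RGB/ and depth/ subfolders of each dataset root, while test_score.py appends '/GT' to each root and the dataset name to the map root):

```python
# utils_downstream/config.py (illustrative values only)
dutrgbd_root_test = '/data/RGBD-SOD/DUT-RGBD'  # contains RGB/, depth/ and GT/ subfolders
njud_root_test = '/data/RGBD-SOD/NJUD'
# ... fill the remaining *_root_test entries the same way ...

# folder with this model's predicted saliency maps; keep the trailing '/'
# because test_score.py concatenates it directly with the dataset name
SSLSOD = './saved_model/imagenet_based_model-50/'
```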
36 | ## BibTeX
37 | ```
38 | @inproceedings{SSLSOD,
39 | title={Self-Supervised Pretraining for RGB-D Salient Object Detection},
40 | author={Zhao, Xiaoqi and Pang, Youwei and Zhang, Lihe and and Lu, Huchuan and Ruan, Xiang},
41 | booktitle={AAAI},
42 | year={2022}
43 | }
44 | ```
45 |
--------------------------------------------------------------------------------
/utils_ssl/datasets_stage2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import torch.utils.data as data
4 | from PIL import Image
5 | import random
6 | import torch
7 | from torch.nn.functional import interpolate
8 | Image.MAX_IMAGE_PIXELS = 1000000000
9 |
10 |
11 |
12 | class ImageFolder(data.Dataset):
13 | def __init__(self, image_root, gt_root,depth_root, joint_transform=None, transform=None, target_transform=None):
14 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')or f.endswith('.bmp')]
15 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
16 | or f.endswith('.png')or f.endswith('.bmp')]
17 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.jpg')
18 | or f.endswith('.png')or f.endswith('.bmp')]
19 | self.images = sorted(self.images)
20 | self.gts = sorted(self.gts)
21 | self.depths = sorted(self.depths)
22 |
23 | self.joint_transform = joint_transform
24 | self.transform = transform
25 | self.target_transform = target_transform
26 |
27 | def __getitem__(self, index):
28 | img_path = self.images[index]
29 | gt_path = self.gts[index]
30 | depth_path = self.depths[index]
31 |
32 | img = Image.open(img_path).convert('RGB')
33 | gt = Image.open(gt_path).convert('L')
34 | depth = Image.open(depth_path).convert('L')
35 |
36 |
37 | if self.joint_transform is not None:
38 | img_raw, depth, gt = self.joint_transform(img, depth, gt)
39 | if self.transform is not None:
40 | img = self.transform(img_raw)
41 | if self.target_transform is not None:
42 | depth = self.target_transform(depth)
43 | gt = self.target_transform(gt)
44 |
45 | return img, depth, gt
46 |
47 | def __len__(self):
48 | return len(self.images)
49 |
50 |
51 | class ImageFolder_multi_scale(data.Dataset):
52 | def __init__(self, root, joint_transform=None, transform=None, target_transform=None):
53 | self.root = root
54 |         self.imgs = make_dataset(root)  # NOTE: expects a make_dataset(root) helper returning (img_path, gt_path) pairs; not defined in this repo
55 | self.joint_transform = joint_transform
56 | self.transform = transform
57 | self.target_transform = target_transform
58 | def __getitem__(self, index):
59 | img_path, gt_path = self.imgs[index]
60 | img = Image.open(img_path).convert('RGB')
61 | target = Image.open(gt_path).convert('L')
62 |
63 | if self.joint_transform is not None:
64 | img, target = self.joint_transform(img, target)
65 | if self.transform is not None:
66 | img = self.transform(img)
67 | if self.target_transform is not None:
68 | target = self.target_transform(target)
69 |
70 | return img, target
71 |
72 | def __len__(self):
73 | return len(self.imgs)
74 |
75 |
76 | #### Optional; the GateNet paper did not use multi-scale training.
77 | def collate(self,batch):
78 | # size = [224, 256, 288, 320, 352][np.random.randint(0, 5)]
79 | # size_list = [224, 256, 288, 320, 352]
80 | # size_list = [128, 160, 192, 224, 256]
81 | size_list = [128, 192, 256, 320, 384]
82 | size = random.choice(size_list)
83 |
84 | img, target = [list(item) for item in zip(*batch)]
85 | img = torch.stack(img, dim=0)
86 | img = interpolate(img, size=(size, size), mode="bilinear", align_corners=False)
87 | target = torch.stack(target, dim=0)
88 | target = interpolate(target, size=(size, size), mode="bilinear")
89 | # print(img.shape)
90 | return img, target
--------------------------------------------------------------------------------
/utils_ssl/datasets_stage1.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import torch.utils.data as data
4 | from PIL import Image
5 | import random
6 | import torch
7 | from torch.nn.functional import interpolate
8 | Image.MAX_IMAGE_PIXELS = 1000000000
9 |
10 |
11 |
12 | class ImageFolder(data.Dataset):
13 | def __init__(self, image_root, gt_root,depth_root, joint_transform=None, transform=None, target_transform=None):
14 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')or f.endswith('.bmp')]
15 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
16 | or f.endswith('.png')or f.endswith('.bmp')]
17 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.jpg')
18 | or f.endswith('.png')or f.endswith('.bmp')]
19 | self.images = sorted(self.images)
20 | self.gts = sorted(self.gts)
21 | self.depths = sorted(self.depths)
22 |
23 | self.joint_transform = joint_transform
24 | self.transform = transform
25 | self.target_transform = target_transform
26 |
27 | def __getitem__(self, index):
28 | img_path = self.images[index]
29 | gt_path = self.gts[index]
30 | depth_path = self.depths[index]
31 |
32 | img = Image.open(img_path).convert('RGB')
33 | gt = Image.open(gt_path).convert('L')
34 | depth = Image.open(depth_path).convert('L')
35 |
36 |
37 | if self.joint_transform is not None:
38 | img_raw, depth, gt = self.joint_transform(img, depth, gt)
39 | if self.transform is not None:
40 | img = self.transform(img_raw)
41 | if self.target_transform is not None:
42 | img_raw = self.target_transform(img_raw)
43 | depth = self.target_transform(depth)
44 | gt = self.target_transform(gt)
45 |
46 | return img, img_raw, depth, gt
47 |
48 | def __len__(self):
49 | return len(self.images)
50 |
51 |
52 | class ImageFolder_multi_scale(data.Dataset):
53 | def __init__(self, root, joint_transform=None, transform=None, target_transform=None):
54 | self.root = root
55 |         self.imgs = make_dataset(root)  # NOTE: expects a make_dataset(root) helper returning (img_path, gt_path) pairs; not defined in this repo
56 | self.joint_transform = joint_transform
57 | self.transform = transform
58 | self.target_transform = target_transform
59 | def __getitem__(self, index):
60 | img_path, gt_path = self.imgs[index]
61 | img = Image.open(img_path).convert('RGB')
62 | target = Image.open(gt_path).convert('L')
63 |
64 | if self.joint_transform is not None:
65 | img, target = self.joint_transform(img, target)
66 | if self.transform is not None:
67 | img = self.transform(img)
68 | if self.target_transform is not None:
69 | target = self.target_transform(target)
70 |
71 | return img, target
72 |
73 | def __len__(self):
74 | return len(self.imgs)
75 |
76 |
77 | #### Optional; the GateNet paper did not use multi-scale training.
78 | def collate(self,batch):
79 | # size = [224, 256, 288, 320, 352][np.random.randint(0, 5)]
80 | # size_list = [224, 256, 288, 320, 352]
81 | # size_list = [128, 160, 192, 224, 256]
82 | size_list = [128, 192, 256, 320, 384]
83 | size = random.choice(size_list)
84 |
85 | img, target = [list(item) for item in zip(*batch)]
86 | img = torch.stack(img, dim=0)
87 | img = interpolate(img, size=(size, size), mode="bilinear", align_corners=False)
88 | target = torch.stack(target, dim=0)
89 | target = interpolate(target, size=(size, size), mode="bilinear")
90 | # print(img.shape)
91 | return img, target
--------------------------------------------------------------------------------
/utils_ssl/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import pydensecrf.densecrf as dcrf
4 |
5 |
6 | class AvgMeter(object):
7 | def __init__(self):
8 | self.reset()
9 |
10 | def reset(self):
11 | self.val = 0
12 | self.avg = 0
13 | self.sum = 0
14 | self.count = 0
15 |
16 | def update(self, val, n=1):
17 | self.val = val
18 | self.sum += val * n
19 | self.count += n
20 | self.avg = self.sum / self.count
21 |
22 |
23 | def check_mkdir(dir_name):
24 |     if not os.path.exists(dir_name):
25 |         os.makedirs(dir_name)  # makedirs also creates missing parent directories
26 |
27 | def cal_precision_recall_mae(prediction, gt):
28 |
29 | assert prediction.dtype == np.uint8
30 | assert gt.dtype == np.uint8
31 |     # print(prediction.shape, gt.shape)
32 | assert prediction.shape == gt.shape
33 | eps = 1e-4
34 | gt = gt / 255
35 |
36 | prediction = (prediction-prediction.min())/(prediction.max()-prediction.min()+ eps)
37 | gt[gt>0.5] = 1
38 | gt[gt!=1] = 0
39 | mae = np.mean(np.abs(prediction - gt))
40 |
41 | hard_gt = np.zeros(prediction.shape)
42 | hard_gt[gt > 0.5] = 1
43 | t = np.sum(hard_gt)
44 | precision, recall,iou= [], [],[]
45 |
46 | binary = np.zeros(gt.shape)
47 | th = 2 * prediction.mean()
48 | if th > 1:
49 | th = 1
50 | binary[prediction >= th] = 1
51 | sb = (binary * gt).sum()
52 | pre_th = (sb+eps) / (binary.sum() + eps)
53 | rec_th = (sb+eps) / (gt.sum() + eps)
54 | thfm = 1.3 * pre_th * rec_th / (0.3*pre_th + rec_th + eps)
55 |
56 |
57 | for threshold in range(256):
58 | threshold = threshold / 255.
59 |
60 | hard_prediction = np.zeros(prediction.shape)
61 | hard_prediction[prediction > threshold] = 1
62 |
63 | tp = np.sum(hard_prediction * hard_gt)
64 | p = np.sum(hard_prediction)
65 | iou.append((tp + eps) / (p+t-tp + eps))
66 | precision.append((tp + eps) / (p + eps))
67 | recall.append((tp + eps) / (t + eps))
68 |
69 |
70 | return precision, recall, iou,mae,thfm
71 |
72 |
73 |
74 | def cal_fmeasure(precision, recall,iou): #iou
75 | beta_square = 0.3
76 |
77 | max_fmeasure = max([(1 + beta_square) * p * r / (beta_square * p + r) for p, r in zip(precision, recall)])
78 | loc = [(1 + beta_square) * p * r / (beta_square * p + r) for p, r in zip(precision, recall)]
79 | a = loc.index(max(loc))
80 | max_iou = max(iou)
81 |
82 | return max_fmeasure,max_iou
83 |
84 |
85 |
86 |
87 | def crf_refine(img, annos):
88 | def _sigmoid(x):
89 | return 1 / (1 + np.exp(-x))
90 |
91 | assert img.dtype == np.uint8
92 | assert annos.dtype == np.uint8
93 |     # print(img.shape[:2], annos.shape)
94 | assert img.shape[:2] == annos.shape
95 |
96 | # img and annos should be np array with data type uint8
97 |
98 | EPSILON = 1e-8
99 |
100 | M = 2 # salient or not
101 | tau = 1.05
102 | # Setup the CRF model
103 | d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M)
104 |
105 | anno_norm = annos / 255.
106 |
107 | n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm))
108 | p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm))
109 |
110 |     U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32')  # unary potentials, same spatial size as the input image
111 | U[0, :] = n_energy.flatten()
112 | U[1, :] = p_energy.flatten()
113 |
114 | d.setUnaryEnergy(U)
115 |
116 | d.addPairwiseGaussian(sxy=3, compat=3)
117 | d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)
118 |
119 | # Do the inference
120 | infer = np.array(d.inference(1)).astype('float32')
121 | res = infer[1, :]
122 |
123 | res = res * 255
124 |     res = res.reshape(img.shape[:2])  # back to the input image size
125 | return res.astype('uint8')
126 |
--------------------------------------------------------------------------------
/prediction_rgbd.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import time
4 | import torch
5 | from PIL import Image
6 | from torch.autograd import Variable
7 | from torchvision import transforms
8 | from utils_downstream.config import dutrgbd,njud,nlpr,stere,sip,rgbd135,ssd,lfsd
9 | from utils_downstream.misc import check_mkdir, crf_refine
10 | from model.model_stage3 import RGBD_sal
11 | import ttach as tta
12 |
13 | torch.manual_seed(2018)
14 | torch.cuda.set_device(0)
15 | ckpt_path = './saved_model'
16 | args = {
17 | 'snapshot': 'imagenet_based_model-50',
18 | 'crf_refine': False,
19 | 'save_results': True
20 | }
21 |
22 |
23 |
24 | img_transform = transforms.Compose([
25 | transforms.ToTensor(),
26 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
27 |
28 | ])
29 |
30 | depth_transform = transforms.ToTensor()
31 | target_transform = transforms.ToTensor()
32 | to_pil = transforms.ToPILImage()
33 |
34 | to_test = {'DUT-RGBD':dutrgbd,'NJUD':njud,'NLPR':nlpr,'STERE':stere,'SIP':sip,'RGBD135':rgbd135,'SSD':ssd,'LFSD':lfsd}
35 |
36 | tta_transforms = tta.Compose(
37 | [
38 | # tta.HorizontalFlip(),
39 | # tta.Scale(scales=[0.75, 1, 1.25], interpolation='bilinear', align_corners=False),
40 | tta.Scale(scales=[1], interpolation='bilinear', align_corners=False),
41 | ]
42 | )
43 |
44 | def main():
45 | t0 = time.time()
46 | net = RGBD_sal().cuda()
47 | print ('load snapshot \'%s\' for testing' % args['snapshot'])
48 | net.load_state_dict(torch.load(os.path.join(ckpt_path, args['snapshot']+'.pth'),map_location={'cuda:1': 'cuda:1'}))
49 | net.eval()
50 | with torch.no_grad():
51 | for name, root in to_test.items():
52 | root1 = os.path.join(root,'depth')
53 | img_list = [os.path.splitext(f) for f in os.listdir(root1)]
54 | for idx, img_name in enumerate(img_list):
55 |
56 | print ('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
57 | rgb_png_path = os.path.join(root, 'RGB', img_name[0] + '.png')
58 | rgb_jpg_path = os.path.join(root, 'RGB', img_name[0] + '.jpg')
59 | depth_jpg_path = os.path.join(root, 'depth', img_name[0] + '.jpg')
60 | depth_png_path = os.path.join(root, 'depth', img_name[0] + '.png')
61 | if os.path.exists(rgb_png_path):
62 | img = Image.open(rgb_png_path).convert('RGB')
63 | else:
64 | img = Image.open(rgb_jpg_path).convert('RGB')
65 | if os.path.exists(depth_jpg_path):
66 | depth = Image.open(depth_jpg_path).convert('L')
67 | else:
68 | depth = Image.open(depth_png_path).convert('L')
69 |
70 |
71 | w_,h_ = img.size
72 |                 img_resize = img.resize([256,256],Image.BILINEAR)  # use 320 for the Foldconv cat variant
73 |                 depth_resize = depth.resize([256,256],Image.BILINEAR)  # use 320 for the Foldconv cat variant
74 |                 img_var = Variable(img_transform(img_resize).unsqueeze(0)).cuda()  # volatile is obsolete inside torch.no_grad()
75 |                 depth_var = Variable(depth_transform(depth_resize).unsqueeze(0)).cuda()
76 | n, c, h, w = img_var.size()
77 | depth_3 = torch.cat((depth_var, depth_var, depth_var), 1)
78 | mask = []
79 |                 for transformer in tta_transforms: # custom transforms or e.g. tta.aliases.d4_transform()
80 |
81 | rgb_trans = transformer.augment_image(img_var)
82 | d_trans = transformer.augment_image(depth_3)
83 | model_output = net(rgb_trans,d_trans)
84 | deaug_mask = transformer.deaugment_mask(model_output)
85 | mask.append(deaug_mask)
86 |
87 | prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
88 | prediction = prediction.sigmoid()
89 | prediction = to_pil(prediction.data.squeeze(0).cpu())
90 | prediction = prediction.resize((w_, h_), Image.BILINEAR)
91 | if args['crf_refine']:
92 | prediction = crf_refine(np.array(img), np.array(prediction))
93 | if args['save_results']:
94 | check_mkdir(os.path.join(ckpt_path, args['snapshot'],name))
95 | prediction.save(os.path.join(ckpt_path, args['snapshot'],name, img_name[0] + '.png'))
96 |
97 |
98 |
99 | if __name__ == '__main__':
100 | main()
101 |
--------------------------------------------------------------------------------
/utils_downstream/ssim_loss.py:
--------------------------------------------------------------------------------
1 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import numpy as np
6 | from math import exp
7 |
8 | def gaussian(window_size, sigma):
9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
10 | return gauss/gauss.sum()
11 |
12 | def create_window(window_size, channel):
13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
16 | return window
17 |
18 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
19 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
20 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
21 |
22 | mu1_sq = mu1.pow(2)
23 | mu2_sq = mu2.pow(2)
24 | mu1_mu2 = mu1*mu2
25 |
26 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
27 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
28 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
29 |
30 | C1 = 0.01**2
31 | C2 = 0.03**2
32 |
33 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
34 |
35 | if size_average:
36 | return ssim_map.mean()
37 | else:
38 | return ssim_map.mean(1).mean(1).mean(1)
39 |
40 | class SSIM(torch.nn.Module):
41 | def __init__(self, window_size = 11, size_average = True):
42 | super(SSIM, self).__init__()
43 | self.window_size = window_size
44 | self.size_average = size_average
45 | self.channel = 1
46 | self.window = create_window(window_size, self.channel)
47 |
48 | def forward(self, img1, img2):
49 | (_, channel, _, _) = img1.size()
50 |
51 | if channel == self.channel and self.window.data.type() == img1.data.type():
52 | window = self.window
53 | else:
54 | window = create_window(self.window_size, channel)
55 |
56 | if img1.is_cuda:
57 | window = window.cuda(img1.get_device())
58 | window = window.type_as(img1)
59 |
60 | self.window = window
61 | self.channel = channel
62 |
63 |
64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
65 |
66 | def _logssim(img1, img2, window, window_size, channel, size_average = True):
67 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
68 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
69 |
70 | mu1_sq = mu1.pow(2)
71 | mu2_sq = mu2.pow(2)
72 | mu1_mu2 = mu1*mu2
73 |
74 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
75 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
76 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
77 |
78 | C1 = 0.01**2
79 | C2 = 0.03**2
80 |
81 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
82 | ssim_map = (ssim_map - torch.min(ssim_map))/(torch.max(ssim_map)-torch.min(ssim_map))
83 | ssim_map = -torch.log(ssim_map + 1e-8)
84 |
85 | if size_average:
86 | return ssim_map.mean()
87 | else:
88 | return ssim_map.mean(1).mean(1).mean(1)
89 |
90 | class LOGSSIM(torch.nn.Module):
91 | def __init__(self, window_size = 11, size_average = True):
92 | super(LOGSSIM, self).__init__()
93 | self.window_size = window_size
94 | self.size_average = size_average
95 | self.channel = 1
96 | self.window = create_window(window_size, self.channel)
97 |
98 | def forward(self, img1, img2):
99 | (_, channel, _, _) = img1.size()
100 |
101 | if channel == self.channel and self.window.data.type() == img1.data.type():
102 | window = self.window
103 | else:
104 | window = create_window(self.window_size, channel)
105 |
106 | if img1.is_cuda:
107 | window = window.cuda(img1.get_device())
108 | window = window.type_as(img1)
109 |
110 | self.window = window
111 | self.channel = channel
112 |
113 |
114 | return _logssim(img1, img2, window, self.window_size, channel, self.size_average)
115 |
116 |
117 | def ssim(img1, img2, window_size = 11, size_average = True):
118 | (_, channel, _, _) = img1.size()
119 | window = create_window(window_size, channel)
120 |
121 | if img1.is_cuda:
122 | window = window.cuda(img1.get_device())
123 | window = window.type_as(img1)
124 |
125 | return _ssim(img1, img2, window, window_size, channel, size_average)
126 |
--------------------------------------------------------------------------------
/model/model_stage1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torchvision import models
5 |
6 | class Crossmodal_Autoendoer(nn.Module):
7 |
8 | def __init__(self):
9 | super(Crossmodal_Autoendoer, self).__init__()
10 | ################################vgg16#######################################
11 | ##set 'pretrained=False' for SSL model or 'pretrained=True' for ImageNet pretrained model.
12 | feats = list(models.vgg16_bn(pretrained=False).features.children())
13 | feats1 = list(models.vgg16_bn(pretrained=False).features.children())
14 | #self.conv0 = nn.Conv2d(4, 64, kernel_size=3, padding=1)
15 | self.conv1_RGB = nn.Sequential(*feats[0:6])
16 | self.conv2_RGB = nn.Sequential(*feats[6:13])
17 | self.conv3_RGB = nn.Sequential(*feats[13:23])
18 | self.conv4_RGB = nn.Sequential(*feats[23:33])
19 | self.conv5_RGB = nn.Sequential(*feats[33:43])
20 |
21 | self.conv1_depth = nn.Sequential(*feats1[0:6])
22 | self.conv2_depth = nn.Sequential(*feats1[6:13])
23 | self.conv3_depth = nn.Sequential(*feats1[13:23])
24 | self.conv4_depth = nn.Sequential(*feats1[23:33])
25 | self.conv5_depth = nn.Sequential(*feats1[33:43])
26 |
27 | self.output4_rgb = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU())
28 | self.output3_rgb = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU())
29 | self.output2_rgb = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1),nn.BatchNorm2d(64), nn.PReLU())
30 | self.output1_rgbtod = nn.Sequential(nn.Conv2d(64, 1, kernel_size=3, padding=1))
31 |
32 | self.output4_depth = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU())
33 | self.output3_depth = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU())
34 | self.output2_depth = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1),nn.BatchNorm2d(64), nn.PReLU())
35 | self.output1_depthtorgb = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1))
36 |
37 | self.sideout5_rgbtod = nn.Sequential(nn.Conv2d(512, 1, kernel_size=3, padding=1))
38 | self.sideout4_rgbtod = nn.Sequential(nn.Conv2d(256, 1, kernel_size=3, padding=1))
39 | self.sideout3_rgbtod = nn.Sequential(nn.Conv2d(128, 1, kernel_size=3, padding=1))
40 | self.sideout2_rgbtod = nn.Sequential(nn.Conv2d(64, 1, kernel_size=3, padding=1))
41 |
42 | self.sideout5_depthtorgb = nn.Sequential(nn.Conv2d(512, 3, kernel_size=3, padding=1))
43 | self.sideout4_depthtorgb = nn.Sequential(nn.Conv2d(256, 3, kernel_size=3, padding=1))
44 | self.sideout3_depthtorgb = nn.Sequential(nn.Conv2d(128, 3, kernel_size=3, padding=1))
45 | self.sideout2_depthtorgb = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1))
46 |
47 |
48 |
49 | for m in self.modules():
50 | if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
51 | m.inplace = True
52 |
53 | def forward(self, x,depth):
54 |
55 | input = x
56 | B,_,_,_ = input.size()
57 | e1_rgb = self.conv1_RGB(x)
58 | e1_depth = self.conv1_depth(depth)
59 | e2_rgb = self.conv2_RGB(e1_rgb)
60 | e2_depth= self.conv2_depth(e1_depth)
61 | e3_rgb = self.conv3_RGB(e2_rgb)
62 | e3_depth = self.conv3_depth(e2_depth)
63 | e4_rgb = self.conv4_RGB(e3_rgb)
64 | e4_depth = self.conv4_depth(e3_depth)
65 | e5_rgb = self.conv5_RGB(e4_rgb)
66 | e5_depth = self.conv5_depth(e4_depth)
67 |
68 | sideout5_rgbtod = self.sideout5_rgbtod(e5_rgb)
69 | output4_rgb = self.output4_rgb(F.upsample(e5_rgb, size=e4_rgb.size()[2:], mode='bilinear')+e4_rgb)
70 | sideout4_rgbtod = self.sideout4_rgbtod(output4_rgb)
71 | output3_rgb = self.output3_rgb(F.upsample(output4_rgb, size=e3_rgb.size()[2:], mode='bilinear') + e3_rgb)
72 | sideout3_rgbtod = self.sideout3_rgbtod(output3_rgb)
73 | output2_rgb = self.output2_rgb(F.upsample(output3_rgb, size=e2_rgb.size()[2:], mode='bilinear') + e2_rgb)
74 | sideout2_rgbtod = self.sideout2_rgbtod(output2_rgb)
75 | sideout1_rgbtod = self.output1_rgbtod(F.upsample(output2_rgb, size=e1_rgb.size()[2:], mode='bilinear') + e1_rgb)
76 |
77 | sideout5_dtorgb = self.sideout5_depthtorgb(e5_depth)
78 | output4_d = self.output4_depth(F.upsample(e5_depth, size=e4_rgb.size()[2:], mode='bilinear')+e4_depth)
79 | sideout4_dtorgb = self.sideout4_depthtorgb(output4_d)
80 | output3_d = self.output3_depth(F.upsample(output4_d, size=e3_rgb.size()[2:], mode='bilinear') + e3_depth)
81 | sideout3_dtorgb = self.sideout3_depthtorgb(output3_d)
82 | output2_d = self.output2_depth(F.upsample(output3_d, size=e2_rgb.size()[2:], mode='bilinear') + e2_depth)
83 | sideout2_dtorgb = self.sideout2_depthtorgb(output2_d)
84 | sideout1_dtorgb = self.output1_depthtorgb(F.upsample(output2_d, size=e1_rgb.size()[2:], mode='bilinear') + e1_depth)
85 |
86 |
87 | if self.training:
88 | return sideout5_rgbtod,sideout4_rgbtod,sideout3_rgbtod,sideout2_rgbtod,sideout1_rgbtod,sideout5_dtorgb,sideout4_dtorgb,sideout3_dtorgb,sideout2_dtorgb,sideout1_dtorgb
89 | # return F.sigmoid(sideout1_rgbtod), F.sigmoid(sideout1_dtorgb)
90 | return e5_rgb,e4_rgb,e3_rgb,e2_rgb,e1_rgb, e5_depth,e4_depth,e3_depth,e2_depth,e1_depth
91 | # return e5_rgb,e4_rgb,e3_rgb,e2_rgb,e1_rgb
92 |
93 |
--------------------------------------------------------------------------------
/train_stage3_downstream.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import os, argparse
5 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
6 | from datetime import datetime
7 | from model.model_stage3 import RGBD_sal
8 | from utils_downstream.dataset_rgbd_strategy2 import get_loader
9 | from utils_downstream.utils import adjust_lr, AvgMeter
10 | import torch.nn as nn
11 | from torch.cuda import amp
12 |
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--epoch', type=int, default=100, help='epoch number')
16 | parser.add_argument('--lr_gen', type=float, default=1e-4, help='learning rate')
17 | parser.add_argument('--batchsize', type=int, default=4, help='training batch size')
18 | parser.add_argument('--trainsize', type=int, default=256, help='training dataset size')
19 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
20 | parser.add_argument('--decay_rate', type=float, default=0.9, help='decay rate of learning rate')
21 | parser.add_argument('--decay_epoch', type=int, default=30, help='every n epochs decay learning rate')
22 | parser.add_argument('-beta1_gen', type=float, default=0.5,help='beta of Adam for generator')
23 | parser.add_argument('--weight_decay', type=float, default=0.001, help='weight_decay')
24 | parser.add_argument('--feat_channel', type=int, default=64, help='reduced channel of saliency feat')
25 |
26 | opt = parser.parse_args()
27 | print('Generator Learning Rate: {}'.format(opt.lr_gen))
28 | # build models
29 | generator = RGBD_sal()
30 | generator.cuda()
31 |
32 | pretrained_dict = torch.load(os.path.join('./saved_model/pretext_task1.pth'))
33 | model_dict = generator.state_dict()
34 | # print(pretrained_dict.items())
35 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
36 | model_dict.update(pretrained_dict)
37 | generator.load_state_dict(model_dict)
38 |
39 | pretrained_dict = torch.load(os.path.join('./saved_model/pretext_task2.pth'))
40 | model_dict = generator.state_dict()
41 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
42 | print(pretrained_dict.items())
43 | model_dict.update(pretrained_dict)
44 | generator.load_state_dict(model_dict)
45 |
46 | generator_params = generator.parameters()
47 | generator_optimizer = torch.optim.Adam(generator_params, opt.lr_gen)
48 |
49 | ## load data
50 | image_root = ''
51 | depth_root = ''
52 | gt_root = ''
53 |
54 | train_loader = get_loader(image_root, gt_root,depth_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
55 | total_step = len(train_loader)
56 |
57 | ## define loss
58 |
59 | CE = torch.nn.BCELoss()
60 | mse_loss = torch.nn.MSELoss()  # mean reduction; the size_average/reduce args are deprecated
61 | # size_rates = [0.75,1,1.25] # multi-scale training
62 | size_rates = [1] # single-scale training; enable the line above for multi-scale
63 | criterion = nn.BCEWithLogitsLoss().cuda()
64 | criterion_mae = nn.L1Loss().cuda()
65 | criterion_mse = nn.MSELoss().cuda()
66 | use_fp16 = True
67 | scaler = amp.GradScaler(enabled=use_fp16)
68 |
69 | def structure_loss(pred, mask):
70 |     weit = 1+5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)-mask)  # boundary-aware pixel weights
71 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
72 | wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3))
73 |
74 |
75 |     pred = torch.sigmoid(pred)  # weighted IoU is computed on probabilities
76 | inter = ((pred * mask) * weit).sum(dim=(2, 3))
77 | union = ((pred + mask) * weit).sum(dim=(2, 3))
78 | wiou = 1-(inter+1)/(union-inter+1)
79 |
80 | return (wbce+wiou).mean()
81 |
82 |
83 |
84 | for epoch in range(1, opt.epoch+1):
85 | generator.train()
86 | loss_record = AvgMeter()
87 | print('Generator Learning Rate: {}'.format(generator_optimizer.param_groups[0]['lr']))
88 | for i, pack in enumerate(train_loader, start=1):
89 | for rate in size_rates:
90 | generator_optimizer.zero_grad()
91 | images, gts, depths = pack
92 | images = Variable(images)
93 | gts = Variable(gts)
94 | depths = Variable(depths)
95 | images = images.cuda()
96 | gts = gts.cuda()
97 | depths = depths.cuda()
98 | # multi-scale training samples
99 | trainsize = int(round(opt.trainsize * rate / 32) * 32)
100 | if rate != 1:
101 | images = F.upsample(images, size=(trainsize, trainsize), mode='bilinear',
102 | align_corners=True)
103 | gts = F.upsample(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
104 | # contours = F.upsample(contours, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
105 | depths = F.upsample(depths, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
106 |
107 | b, c, h, w = gts.size()
108 | target_1 = F.upsample(gts, size=h // 2, mode='nearest')
109 | target_2 = F.upsample(gts, size=h // 4, mode='nearest')
110 | target_3 = F.upsample(gts, size=h // 8, mode='nearest')
111 | target_4 = F.upsample(gts, size=h // 16, mode='nearest')
112 |
113 | with amp.autocast(enabled=use_fp16):
114 | depth_3 = torch.cat((depths, depths, depths), 1)
115 | sideout5, sideout4, sideout3, sideout2, output1 = generator.forward(images, depth_3) # hed
116 | loss1 = structure_loss(sideout5, target_4)
117 | loss2 = structure_loss(sideout4, target_3)
118 | loss3 = structure_loss(sideout3, target_2)
119 | loss4 = structure_loss(sideout2, target_1)
120 | loss5 = structure_loss(output1, gts)
121 |
122 | loss = loss1 + loss2 + loss3 + loss4 + loss5
123 |
124 | generator_optimizer.zero_grad()
125 | scaler.scale(loss).backward()
126 | scaler.step(generator_optimizer)
127 | scaler.update()
128 |
129 | if rate == 1:
130 | loss_record.update(loss.data, opt.batchsize)
131 |
132 |
133 | if i % 10 == 0 or i == total_step:
134 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], gen Loss: {:.4f}'.
135 | format(datetime.now(), epoch, opt.epoch, i, total_step, loss_record.show()))
136 | # print(anneal_reg)
137 |
138 |
139 | adjust_lr(generator_optimizer, opt.lr_gen, epoch, opt.decay_rate, opt.decay_epoch)
140 |
141 | save_path = './saved_model/SSLSOD_v2'
142 |
143 |
144 | if not os.path.exists(save_path):
145 | os.makedirs(save_path)
146 | if epoch % opt.epoch == 0:
147 |         torch.save(generator.state_dict(), os.path.join(save_path, 'Model_%d_gen.pth' % epoch))
148 |
--------------------------------------------------------------------------------
/train_stage2_pretext2.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import torch
4 | from torch import optim
5 | from torch.autograd import Variable
6 | from torchvision import transforms
7 | from torch.utils.data import DataLoader
8 | import utils_ssl.joint_transforms
9 | from utils_ssl.datasets_stage2 import ImageFolder
10 | from utils_ssl.misc import AvgMeter, check_mkdir
11 | from model.model_stage1 import Crossmodal_Autoendoer
12 | from model.model_stage2 import Contour_Estimation
13 | from torch.backends import cudnn
14 | from utils_downstream.ssim_loss import SSIM
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | cudnn.benchmark = True
18 | torch.manual_seed(2018)
19 | torch.cuda.set_device(0)
20 |
21 | ##########################hyperparameters###############################
22 | ckpt_path = './saved_model'
23 | exp_name = 'pretext_task2_stage2'
24 | args = {
25 | 'iter_num': 79900, #50epoch
26 | 'train_batch_size': 4,
27 | 'last_iter': 0,
28 | 'lr': 1e-3,
29 | 'lr_decay': 0.9,
30 | 'weight_decay': 0.0005,
31 | 'momentum': 0.9,
32 | 'snapshot': ''
33 | }
34 | ##########################data augmentation###############################
35 | joint_transform = utils_ssl.joint_transforms.Compose([
36 | utils_ssl.joint_transforms.RandomCrop(256, 256), # change to resize
37 | utils_ssl.joint_transforms.RandomHorizontallyFlip(),
38 | utils_ssl.joint_transforms.RandomRotate(10)
39 | ])
40 | img_transform = transforms.Compose([
41 | transforms.ColorJitter(0.1, 0.1, 0.1),
42 | transforms.ToTensor(),
43 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
44 | ])
45 | target_transform = transforms.ToTensor()
46 | ##########################################################################
47 | image_root = ''
48 | depth_root = ''
49 | gt_root = ''
50 |
51 | train_set = ImageFolder(image_root, gt_root,depth_root, joint_transform, img_transform, target_transform)
52 | train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=0, shuffle=True)
53 |
54 |
55 | criterion = nn.BCEWithLogitsLoss().cuda()
56 | criterion_BCE = nn.BCELoss().cuda()
57 | criterion_mse = nn.MSELoss().cuda()
58 | criterion_mae = nn.L1Loss().cuda()
59 | criterion_ssim = SSIM(window_size=11,size_average=True)
60 | def ssimmae(pre,gt):
61 | maeloss = criterion_mae(pre,gt)
62 | ssimloss = 1-criterion_ssim(pre,gt)
63 | loss = ssimloss+maeloss
64 | return loss
65 |
66 | log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')
67 |
68 |
69 | def main():
70 | #############################ResNet pretrained###########################
71 | #res18[2,2,2,2],res34[3,4,6,3],res50[3,4,6,3],res101[3,4,23,3],res152[3,8,36,3]
72 | model_pretext1 = Crossmodal_Autoendoer()
73 | net_pretext1 = model_pretext1.cuda()
74 | net_pretext1.load_state_dict(torch.load(os.path.join('./saved_model/pretext_task1.pth')))
75 | net_pretext1.eval()
76 |
77 | model_pretext2 = Contour_Estimation()
78 | net = model_pretext2.cuda().train()
79 |
80 |
81 | optimizer = optim.SGD([
82 | {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
83 | 'lr': 2 * args['lr']},
84 | {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
85 | 'lr': args['lr'], 'weight_decay': args['weight_decay']}
86 | ], momentum=args['momentum'])
87 | if len(args['snapshot']) > 0:
88 | print('training resumes from ' + args['snapshot'])
89 | net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
90 | optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '_optim.pth')))
91 | optimizer.param_groups[0]['lr'] = 2 * args['lr']
92 | optimizer.param_groups[1]['lr'] = args['lr']
93 | check_mkdir(ckpt_path)
94 | check_mkdir(os.path.join(ckpt_path, exp_name))
95 | open(log_path, 'w').write(str(args) + '\n\n')
96 | train(net_pretext1,net, optimizer)
97 |
98 |
99 | #########################################################################
100 |
101 | def train(net_pretext1,net, optimizer):
102 | curr_iter = args['last_iter']
103 | while True:
104 |         total_loss_record, loss1_record, loss2_record, loss3_record, loss4_record, loss5_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
105 | for i, data in enumerate(train_loader):
106 | optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (1 - float(curr_iter) / args['iter_num']
107 | ) ** args['lr_decay']
108 | optimizer.param_groups[1]['lr'] = args['lr'] * (1 - float(curr_iter) / args['iter_num']
109 | ) ** args['lr_decay']
110 | # data\binarizing\Variable
111 | images, depths, gts = data
112 | gts[gts > 0.5] = 1
113 | gts[gts != 1] = 0
114 | batch_size = images.size(0)
115 | inputs = Variable(images).cuda()
116 | labels = Variable(gts).cuda()
117 | depths = Variable(depths).cuda()
118 | b, c, h, w = labels.size()
119 | optimizer.zero_grad()
120 | target_1 = F.upsample(labels, size=h // 2, mode='nearest')
121 | target_2 = F.upsample(labels, size=h // 4, mode='nearest')
122 | target_3 = F.upsample(labels, size=h // 8, mode='nearest')
123 | target_4 = F.upsample(labels, size=h // 16, mode='nearest')
124 |
125 | ##########loss#############
126 | depth_3 = torch.cat((depths, depths, depths), 1)
127 | e5_rgb,e4_rgb,e3_rgb,e2_rgb,e1_rgb, e5_depth,e4_depth,e3_depth,e2_depth,e1_depth = net_pretext1(
128 | inputs, depth_3) # hed
129 | sideout5, sideout4, sideout3, sideout2, output1 = net(e5_rgb, e4_rgb, e3_rgb, e2_rgb, e1_rgb, e5_depth, e4_depth, e3_depth, e2_depth, e1_depth)
130 | loss1 = criterion_mae(F.sigmoid(sideout5), target_4)
131 | loss2 = criterion_mae(F.sigmoid(sideout4), target_3)
132 | loss3 = criterion_mae(F.sigmoid(sideout3), target_2)
133 | loss4 = criterion_mae(F.sigmoid(sideout2), target_1)
134 | loss5 = criterion_mae(F.sigmoid(output1), labels)
135 |
136 | total_loss = loss1 + loss2 + loss3 + loss4 + loss5
137 | total_loss.backward()
138 | optimizer.step()
139 | total_loss_record.update(total_loss.item(), batch_size)
140 | loss1_record.update(loss1.item(), batch_size)
141 | loss2_record.update(loss2.item(), batch_size)
142 | loss3_record.update(loss3.item(), batch_size)
143 | loss4_record.update(loss4.item(), batch_size)
144 | loss5_record.update(loss5.item(), batch_size)
145 |
146 | #############log###############
147 | curr_iter += 1
148 | log = '[iter %d], [total loss %.5f],[loss4 %.5f],[loss5 %.5f],[lr %.13f] ' % \
149 | (curr_iter, total_loss_record.avg, loss4_record.avg, loss5_record.avg, optimizer.param_groups[1]['lr'])
150 | print(log)
151 | open(log_path, 'a').write(log + '\n')
152 | if curr_iter == args['iter_num']:
153 | torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
154 | torch.save(optimizer.state_dict(),
155 | os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
156 | return
157 | ###############end###############
158 |
159 |
160 | if __name__ == '__main__':
161 | main()
162 |
--------------------------------------------------------------------------------
/train_stage1_pretext1.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import torch
4 | from torch import optim
5 | from torch.autograd import Variable
6 | from torchvision import transforms
7 | from torch.utils.data import DataLoader
8 | import utils_ssl.joint_transforms
9 | from utils_ssl.datasets_stage1 import ImageFolder
10 | from utils_ssl.misc import AvgMeter, check_mkdir
11 | from model.model_stage1 import Crossmodal_Autoendoer
12 | from torch.backends import cudnn
13 | from utils_downstream.ssim_loss import SSIM
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | cudnn.benchmark = True
17 | torch.manual_seed(2018)
18 | torch.cuda.set_device(0)
19 |
20 | ##########################hyperparameters###############################
21 | ckpt_path = './model'
22 | exp_name = 'pretext_task1_stage1'
23 | args = {
24 | 'iter_num': 79900,
25 | 'train_batch_size': 4,
26 | 'last_iter': 0,
27 | 'lr': 1e-3,
28 | 'lr_decay': 0.9,
29 | 'weight_decay': 0.0005,
30 | 'momentum': 0.9,
31 | 'snapshot': ''
32 | }
33 | ##########################data augmentation###############################
34 | joint_transform = utils_ssl.joint_transforms.Compose([
35 | utils_ssl.joint_transforms.RandomCrop(256, 256), # change to resize
36 | utils_ssl.joint_transforms.RandomHorizontallyFlip(),
37 | utils_ssl.joint_transforms.RandomRotate(10)
38 | ])
39 | img_transform = transforms.Compose([
40 | transforms.ColorJitter(0.1, 0.1, 0.1),
41 | transforms.ToTensor(),
42 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
43 | ])
44 | target_transform = transforms.ToTensor()
45 | ##########################################################################
46 | image_root = ''
47 | depth_root = ''
48 | gt_root = ''
49 |
50 | train_set = ImageFolder(image_root, gt_root,depth_root, joint_transform, img_transform, target_transform)
51 | train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=0, shuffle=True)
52 |
53 |
54 | ###multi-scale-train
55 | # train_set = ImageFolder_multi_scale(train_data, joint_transform, img_transform, target_transform)
56 | # train_loader = DataLoader(train_set, collate_fn=train_set.collate, batch_size=args['train_batch_size'], num_workers=12, shuffle=True, drop_last=True)
57 |
58 | criterion = nn.BCEWithLogitsLoss().cuda()
59 | criterion_BCE = nn.BCELoss().cuda()
60 | criterion_mse = nn.MSELoss().cuda()
61 | criterion_mae = nn.L1Loss().cuda()
62 | criterion_ssim = SSIM(window_size=11,size_average=True)
63 | def ssimmae(pre,gt):
64 | maeloss = criterion_mae(pre,gt)
65 | ssimloss = 1-criterion_ssim(pre,gt)
66 | loss = ssimloss+maeloss
67 | return loss
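# The hybrid loss above pairs L1 (pixel fidelity) with 1 - SSIM (local
# structural fidelity); train() below applies it to both the RGB->depth and
# depth->RGB reconstruction heads at every scale.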
68 |
69 | log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')
70 |
71 |
72 | def main():
73 | #############################ResNet pretrained###########################
74 | #res18[2,2,2,2],res34[3,4,6,3],res50[3,4,6,3],res101[3,4,23,3],res152[3,8,36,3]
75 | model = Crossmodal_Autoendoer()
76 | net = model.cuda().train()
77 | optimizer = optim.SGD([
78 | {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
79 | 'lr': 2 * args['lr']},
80 | {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
81 | 'lr': args['lr'], 'weight_decay': args['weight_decay']}
82 | ], momentum=args['momentum'])
83 | if len(args['snapshot']) > 0:
84 | print('training resumes from ' + args['snapshot'])
85 | net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
86 | optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '_optim.pth')))
87 | optimizer.param_groups[0]['lr'] = 2 * args['lr']
88 | optimizer.param_groups[1]['lr'] = args['lr']
89 | check_mkdir(ckpt_path)
90 | check_mkdir(os.path.join(ckpt_path, exp_name))
91 | open(log_path, 'w').write(str(args) + '\n\n')
92 | train(net, optimizer)
93 |
94 |
95 | #########################################################################
96 |
97 | def train(net, optimizer):
98 | curr_iter = args['last_iter']
99 | while True:
100 |         total_loss_record, loss1_record, loss2_record, loss3_record, loss4_record, loss5_record, loss6_record, loss7_record, loss8_record, loss9_record, loss10_record = [AvgMeter() for _ in range(11)]
101 | for i, data in enumerate(train_loader):
102 |             optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
103 |                 1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
104 |             optimizer.param_groups[1]['lr'] = args['lr'] * (
105 |                 1 - float(curr_iter) / args['iter_num']) ** args['lr_decay']
106 |             # load the batch, binarize the GT, and wrap tensors in Variable
107 | images, images_raw, gts, depths = data
108 | gts[gts > 0.5] = 1
109 | gts[gts != 1] = 0
110 | batch_size = images.size(0)
111 | inputs = Variable(images).cuda()
112 | labels = Variable(gts).cuda()
113 | images_raw = Variable(images_raw).cuda()
114 | depths = Variable(depths).cuda()
115 | b, c, h, w = labels.size()
116 | optimizer.zero_grad()
117 | target_1 = F.upsample(labels, size=h // 2, mode='nearest')
118 | target_2 = F.upsample(labels, size=h // 4, mode='nearest')
119 | target_3 = F.upsample(labels, size=h // 8, mode='nearest')
120 | target_4 = F.upsample(labels, size=h // 16, mode='nearest')
121 |
122 | images_raw_1 = F.upsample(images_raw, size=h // 2, mode='nearest')
123 | images_raw_2 = F.upsample(images_raw, size=h // 4, mode='nearest')
124 | images_raw_3 = F.upsample(images_raw, size=h // 8, mode='nearest')
125 | images_raw_4 = F.upsample(images_raw, size=h // 16, mode='nearest')
126 |
127 | depths_1 = F.upsample(depths, size=h // 2, mode='nearest')
128 | depths_2 = F.upsample(depths, size=h // 4, mode='nearest')
129 | depths_3 = F.upsample(depths, size=h // 8, mode='nearest')
130 | depths_4 = F.upsample(depths, size=h // 16, mode='nearest')
131 |
132 | ##########loss#############
133 | depth_3 = torch.cat((depths, depths, depths), 1)
134 | sideout5_rgbtod, sideout4_rgbtod, sideout3_rgbtod, sideout2_rgbtod, sideout1_rgbtod, sideout5_dtorgb, sideout4_dtorgb, sideout3_dtorgb, sideout2_dtorgb, sideout1_dtorgb = net(
135 | inputs, depth_3) # hed
136 | loss1 = ssimmae(F.sigmoid(sideout5_rgbtod), depths_4)
137 | loss2 = ssimmae(F.sigmoid(sideout4_rgbtod), depths_3)
138 | loss3 = ssimmae(F.sigmoid(sideout3_rgbtod), depths_2)
139 | loss4 = ssimmae(F.sigmoid(sideout2_rgbtod), depths_1)
140 | loss5 = ssimmae(F.sigmoid(sideout1_rgbtod), depths)
141 | loss6 = ssimmae(F.sigmoid(sideout5_dtorgb), images_raw_4)
142 | loss7 = ssimmae(F.sigmoid(sideout4_dtorgb), images_raw_3)
143 | loss8 = ssimmae(F.sigmoid(sideout3_dtorgb), images_raw_2)
144 | loss9 = ssimmae(F.sigmoid(sideout2_dtorgb), images_raw_1)
145 | loss10 = ssimmae(F.sigmoid(sideout1_dtorgb), images_raw)
146 |
147 | loss_depth = loss1 + loss2 + loss3 + loss4 + loss5
148 | loss_rgb = loss6 + loss7 + loss8 + loss9 + loss10
149 | total_loss = loss_depth + loss_rgb
150 | total_loss.backward()
151 | optimizer.step()
152 | total_loss_record.update(total_loss.item(), batch_size)
153 | loss1_record.update(loss1.item(), batch_size)
154 | loss2_record.update(loss2.item(), batch_size)
155 | loss3_record.update(loss3.item(), batch_size)
156 | loss4_record.update(loss4.item(), batch_size)
157 | loss5_record.update(loss5.item(), batch_size)
158 | loss6_record.update(loss6.item(), batch_size)
159 | loss7_record.update(loss7.item(), batch_size)
160 | loss8_record.update(loss8.item(), batch_size)
161 | loss9_record.update(loss9.item(), batch_size)
162 | loss10_record.update(loss10.item(), batch_size)
163 |
164 | #############log###############
165 | curr_iter += 1
166 |
167 | log = '[iter %d], [total loss %.5f],[loss5 %.5f],[loss10 %.5f],[lr %.13f] ' % \
168 | (curr_iter, total_loss_record.avg, loss5_record.avg, loss10_record.avg, optimizer.param_groups[1]['lr'])
169 | print(log)
170 | open(log_path, 'a').write(log + '\n')
171 |
172 | if curr_iter == args['iter_num']:
173 | torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % curr_iter))
174 | torch.save(optimizer.state_dict(),
175 | os.path.join(ckpt_path, exp_name, '%d_optim.pth' % curr_iter))
176 | return
177 | ###############end###############
178 |
179 |
180 | if __name__ == '__main__':
181 | main()
182 |
--------------------------------------------------------------------------------
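
Two details of the stage-1 trainer above recur across the repository: the ssimmae loss pairs L1 with 1 - SSIM at every scale, and the learning rate follows a "poly" decay with the bias parameter group held at twice the base rate. A small numeric sketch of the schedule (the printed values are illustrative):

def poly_lr(base_lr, curr_iter, iter_num, power=0.9):
    # 'poly' decay, as computed inside train() above
    return base_lr * (1 - float(curr_iter) / iter_num) ** power

base_lr, iter_num = 1e-3, 79900
for it in (0, 20000, 40000, 79899):
    print(it, poly_lr(base_lr, it, iter_num))
# the bias parameter group simply runs at 2 * poly_lr(...)
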
/utils_downstream/dataset_rgbd_strategy2.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import random
6 | import numpy as np
7 | from PIL import ImageEnhance
8 |
9 |
10 | # several data augmentation strategies
11 | def cv_random_flip(img, label, depth):
12 | flip_flag = random.randint(0, 1)
13 | # flip_flag2= random.randint(0,1)
14 | # left right flip
15 | if flip_flag == 1:
16 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
17 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
18 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT)
19 | # top bottom flip
20 | # if flip_flag2==1:
21 | # img = img.transpose(Image.FLIP_TOP_BOTTOM)
22 | # label = label.transpose(Image.FLIP_TOP_BOTTOM)
23 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM)
24 | return img, label, depth
25 |
26 |
27 | def randomCrop(image, label, depth):
28 | border = 30
29 | image_width = image.size[0]
30 | image_height = image.size[1]
31 | crop_win_width = np.random.randint(image_width - border, image_width)
32 | crop_win_height = np.random.randint(image_height - border, image_height)
33 | random_region = (
34 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
35 | (image_height + crop_win_height) >> 1)
36 | return image.crop(random_region), label.crop(random_region), depth.crop(random_region)
37 |
38 |
39 | def randomRotation(image, label, depth):
40 | mode = Image.BICUBIC
41 | if random.random() > 0.8:
42 | random_angle = np.random.randint(-15, 15)
43 | image = image.rotate(random_angle, mode)
44 | label = label.rotate(random_angle, mode)
45 | depth = depth.rotate(random_angle, mode)
46 | return image, label, depth
47 |
48 |
49 | def colorEnhance(image):
50 | bright_intensity = random.randint(5, 15) / 10.0
51 | image = ImageEnhance.Brightness(image).enhance(bright_intensity)
52 | contrast_intensity = random.randint(5, 15) / 10.0
53 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity)
54 | color_intensity = random.randint(0, 20) / 10.0
55 | image = ImageEnhance.Color(image).enhance(color_intensity)
56 | sharp_intensity = random.randint(0, 30) / 10.0
57 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)
58 | return image
59 |
60 |
61 | def randomGaussian(image, mean=0.1, sigma=0.35):
62 | def gaussianNoisy(im, mean=mean, sigma=sigma):
63 | for _i in range(len(im)):
64 | im[_i] += random.gauss(mean, sigma)
65 | return im
66 |
67 |     img = np.asarray(image)
68 |     rows, cols = img.shape  # np shape order is (rows, cols)
69 |     img = gaussianNoisy(img.flatten(), mean, sigma)
70 |     img = img.reshape([rows, cols])
71 | return Image.fromarray(np.uint8(img))
72 |
73 |
74 | def randomPeper(img):
75 |     img = np.array(img)
76 |     noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
77 |     for i in range(noiseNum):
78 |         randX = random.randint(0, img.shape[0] - 1)
79 |         randY = random.randint(0, img.shape[1] - 1)
80 |         if random.randint(0, 1) == 0:
81 |             img[randX, randY] = 0
82 |         else:
83 |             img[randX, randY] = 255
84 |     return Image.fromarray(img)
91 |
92 |
93 | # dataset for training
94 | # The current loader does not use normalized depth maps for training or testing. If you use normalized depth maps
95 | # (e.g., 0 represents background and 1 represents foreground), the performance will be further improved.
96 | class SalObjDataset(data.Dataset):
97 | def __init__(self, image_root, gt_root, depth_root, trainsize):
98 | self.trainsize = trainsize
99 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
100 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
101 | or f.endswith('.png')]
102 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
103 | or f.endswith('.png')]
104 | self.images = sorted(self.images)
105 | # print(self.images)
106 | self.gts = sorted(self.gts)
107 | # print(self.gts)
108 | # print(self.contours)
109 | self.depths = sorted(self.depths)
110 | # print(self.depths)
111 | self.filter_files()
112 | self.size = len(self.images)
113 | self.img_transform = transforms.Compose([
114 | transforms.Resize((self.trainsize, self.trainsize)),
115 | transforms.ToTensor(),
116 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
117 | self.gt_transform = transforms.Compose([
118 | transforms.Resize((self.trainsize, self.trainsize)),
119 | transforms.ToTensor()])
120 | self.depths_transform = transforms.Compose(
121 | [transforms.Resize((self.trainsize, self.trainsize)), transforms.ToTensor()])
122 |
123 | def __getitem__(self, index):
124 | image = self.rgb_loader(self.images[index])
125 | gt = self.binary_loader(self.gts[index])
126 | depth = self.binary_loader(self.depths[index])
127 | image, gt, depth = cv_random_flip(image, gt, depth)
128 | image, gt, depth = randomCrop(image, gt, depth)
129 | image, gt, depth = randomRotation(image, gt, depth)
130 | image = colorEnhance(image)
131 | # gt=randomGaussian(gt)
132 | # gt = randomPeper(gt)
133 | image = self.img_transform(image)
134 | gt = self.gt_transform(gt)
135 | depth = self.depths_transform(depth)
136 |
137 |
138 | return image, gt, depth
139 |
140 | def filter_files(self):
141 |         assert len(self.images) == len(self.gts) == len(self.depths)
142 | images = []
143 | gts = []
144 | depths = []
145 | for img_path, gt_path, depth_path in zip(self.images, self.gts, self.depths):
146 | img = Image.open(img_path)
147 | gt = Image.open(gt_path)
148 | depth = Image.open(depth_path)
149 | if img.size == gt.size and gt.size == depth.size:
150 | images.append(img_path)
151 | gts.append(gt_path)
152 | depths.append(depth_path)
153 | self.images = images
154 | self.gts = gts
155 | self.depths = depths
156 |
157 | def rgb_loader(self, path):
158 | with open(path, 'rb') as f:
159 | img = Image.open(f)
160 | return img.convert('RGB')
161 |
162 | def binary_loader(self, path):
163 | with open(path, 'rb') as f:
164 | img = Image.open(f)
165 | return img.convert('L')
166 |
167 | def resize(self, img, gt, depth):
168 | assert img.size == gt.size and gt.size == depth.size
169 | w, h = img.size
170 | if h < self.trainsize or w < self.trainsize:
171 | h = max(h, self.trainsize)
172 | w = max(w, self.trainsize)
173 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h),
174 | Image.NEAREST)
175 | else:
176 | return img, gt, depth
177 |
178 | def __len__(self):
179 | return self.size
180 |
181 |
182 | # dataloader for training
183 | def get_loader(image_root, gt_root, depth_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=False):
184 | dataset = SalObjDataset(image_root, gt_root, depth_root, trainsize)
185 | data_loader = data.DataLoader(dataset=dataset,
186 | batch_size=batchsize,
187 | shuffle=shuffle,
188 | num_workers=num_workers,
189 | pin_memory=pin_memory)
190 | return data_loader
191 |
192 |
193 | # test dataset and loader
194 | class test_dataset:
195 | def __init__(self, image_root, depth_root, testsize):
196 | self.testsize = testsize
197 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
198 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
199 | or f.endswith('.png')]
200 | self.images = sorted(self.images)
201 | self.depths = sorted(self.depths)
202 | self.transform = transforms.Compose([
203 | transforms.Resize((self.testsize, self.testsize)),
204 | transforms.ToTensor(),
205 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
206 | # self.gt_transform = transforms.Compose([
207 | # transforms.Resize((self.trainsize, self.trainsize)),
208 | # transforms.ToTensor()])
209 | self.depths_transform = transforms.Compose(
210 | [transforms.Resize((self.testsize, self.testsize)), transforms.ToTensor()])
211 | self.size = len(self.images)
212 | self.index = 0
213 |
214 | def load_data(self):
215 | image = self.rgb_loader(self.images[self.index])
216 |         HH = image.size[0]  # note: PIL's Image.size is (width, height)
217 |         WW = image.size[1]
218 | image = self.transform(image).unsqueeze(0)
219 | depth = self.rgb_loader(self.depths[self.index])
220 | depth = self.depths_transform(depth).unsqueeze(0)
221 |
222 | name = self.images[self.index].split('/')[-1]
223 | # image_for_post=self.rgb_loader(self.images[self.index])
224 | # image_for_post=image_for_post.resize(gt.size)
225 | if name.endswith('.jpg'):
226 | name = name.split('.jpg')[0] + '.png'
227 | self.index += 1
228 | self.index = self.index % self.size
229 | return image, depth, HH, WW, name
230 |
231 | def rgb_loader(self, path):
232 | with open(path, 'rb') as f:
233 | img = Image.open(f)
234 | return img.convert('RGB')
235 |
236 | def binary_loader(self, path):
237 | with open(path, 'rb') as f:
238 | img = Image.open(f)
239 | return img.convert('L')
240 |
241 | def __len__(self):
242 | return self.size
243 |
244 |
--------------------------------------------------------------------------------
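
The loader above feeds raw depth maps; its own comment notes that min-max-normalized depth (background near 0, foreground near 1) would further improve results. A hedged sketch of such a normalization step (normalize_depth is a hypothetical helper, not part of the repository; it would be applied to the PIL depth image before depths_transform):

import numpy as np
from PIL import Image

def normalize_depth(depth_img):
    # min-max normalize a single-channel PIL depth map to [0, 1], then back to
    # 8-bit so the existing Resize/ToTensor transforms still apply; depending on
    # the sensor convention the map may also need inverting so that foreground
    # ends up near 1
    d = np.asarray(depth_img, dtype=np.float32)
    d_min, d_max = float(d.min()), float(d.max())
    d = (d - d_min) / (d_max - d_min) if d_max > d_min else np.zeros_like(d)
    return Image.fromarray((d * 255).astype(np.uint8))
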
/model/model_stage2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torchvision import models
5 |
6 | class Contour_Estimation(nn.Module):
7 |
8 | def __init__(self):
9 | super(Contour_Estimation, self).__init__()
10 | self.fuse1_conv1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU())
11 | self.fuse1_conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU())
12 | self.fuse1_conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU())
13 | self.fuse1_conv4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU())
14 | self.fuse1_conv5 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU())
15 |
16 | self.fuse1_conv1_fpn = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU())
17 | self.fuse1_conv2_fpn = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU())
18 | self.fuse1_conv3_fpn = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU())
19 | self.fuse1_conv4_fpn = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU())
20 | self.fuse1_conv5_fpn = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.PReLU())
21 |
22 | self.fuse2_conv1 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU())
23 | self.fuse2_conv2 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU())
24 | self.fuse2_conv3 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU())
25 | self.fuse2_conv4 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU())
26 | self.fuse2_conv5 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU())
27 |
28 | self.fuse2_conv1_fpn = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU())
29 | self.fuse2_conv2_fpn = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU())
30 | self.fuse2_conv3_fpn = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU())
31 | self.fuse2_conv4_fpn = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU())
32 | self.fuse2_conv5_fpn = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU())
33 |
34 | self.fuse3_conv1 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU())
35 | self.fuse3_conv2 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU())
36 | self.fuse3_conv3 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU())
37 | self.fuse3_conv4 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU())
38 | self.fuse3_conv5 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU())
39 |
40 | self.fuse3_conv1_fpn = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU())
41 | self.fuse3_conv2_fpn = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU())
42 | self.fuse3_conv3_fpn = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU())
43 | self.fuse3_conv4_fpn = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU())
44 | self.fuse3_conv5_fpn = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU())
45 |
46 | self.fuse4_conv1 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU())
47 | self.fuse4_conv2 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU())
48 | self.fuse4_conv3 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU())
49 | self.fuse4_conv4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU())
50 | self.fuse4_conv5 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU())
51 |
52 | self.fuse4_conv1_fpn = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU())
53 | self.fuse4_conv2_fpn = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU())
54 | self.fuse4_conv3_fpn = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU())
55 | self.fuse4_conv4_fpn = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU())
56 | self.fuse4_conv5_fpn = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU())
57 |
58 | self.fuse5_conv1 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU())
59 | self.fuse5_conv2 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU())
60 | self.fuse5_conv3 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU())
61 | self.fuse5_conv4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU())
62 | self.fuse5_conv5 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.PReLU())
63 |
64 | self.sideout5 = nn.Sequential(nn.Conv2d(512, 1, kernel_size=3, padding=1))
65 | self.sideout4 = nn.Sequential(nn.Conv2d(256, 1, kernel_size=3, padding=1))
66 | self.sideout3 = nn.Sequential(nn.Conv2d(128, 1, kernel_size=3, padding=1))
67 | self.sideout2 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=3, padding=1))
68 |
69 |
70 | self.output4 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.PReLU())
71 | self.output3 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.PReLU())
72 | self.output2 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1),nn.BatchNorm2d(64), nn.PReLU())
73 | self.output1 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=3, padding=1))
74 |
75 |
76 | for m in self.modules():
77 | if isinstance(m, nn.ReLU) or isinstance(m, nn.Dropout):
78 | m.inplace = True
79 |
80 | def forward(self, e5_rgb,e4_rgb,e3_rgb,e2_rgb,e1_rgb, e5_depth,e4_depth,e3_depth,e2_depth,e1_depth):
81 |
82 | certain_feature5 = e5_rgb * e5_depth
83 | fuse5_conv1 = self.fuse5_conv1(e5_rgb+certain_feature5)
84 | fuse5_conv2 = self.fuse5_conv2(e5_depth+certain_feature5)
85 | fuse5_certain = self.fuse5_conv3(fuse5_conv1+fuse5_conv2)
86 |
87 | uncertain_feature5 = self.fuse5_conv4(torch.abs(e5_rgb - e5_depth))
88 | fuse5 = self.fuse5_conv5(fuse5_certain + uncertain_feature5)
89 | sideout5 = self.sideout5(fuse5)
90 |
91 | certain_feature4 = F.upsample(F.sigmoid(sideout5),size=e4_rgb.size()[2:], mode='bilinear')*e4_rgb * e4_depth
92 | fuse4_conv1 = self.fuse4_conv1(e4_rgb+certain_feature4)
93 | fuse4_conv2 = self.fuse4_conv2(e4_depth+certain_feature4)
94 | fuse4_certain = self.fuse4_conv3(fuse4_conv1+fuse4_conv2)
95 | uncertain_feature4 = self.fuse4_conv4(F.upsample(F.sigmoid(sideout5),size=e4_rgb.size()[2:], mode='bilinear')*torch.abs(e4_rgb - e4_depth))
96 | fuse4 = self.fuse4_conv5(fuse4_certain + uncertain_feature4)
97 | ###
98 | fuse5_fpn = F.upsample(fuse5, size=fuse4.size()[2:], mode='bilinear')
99 | fpn_certain_feature4 = F.upsample(F.sigmoid(sideout5), size=fuse4.size()[2:], mode='bilinear') * fuse4 * fuse5_fpn
100 | fuse4_fpn_conv1 = self.fuse4_conv1_fpn(fuse5_fpn + fpn_certain_feature4)
101 | fuse4_fpn_conv2 = self.fuse4_conv2_fpn(fuse4 + fpn_certain_feature4)
102 | fuse4_certain_fpn = self.fuse4_conv3_fpn(fuse4_fpn_conv1 + fuse4_fpn_conv2)
103 | fpn_uncertain_feature4 = self.fuse4_conv4_fpn(
104 | F.upsample(F.sigmoid(sideout5), size=fuse4.size()[2:], mode='bilinear') * torch.abs(fuse4 - fuse5_fpn))
105 | fuse4_fpn = self.fuse4_conv5_fpn(fuse4_certain_fpn + fpn_uncertain_feature4)
106 | output4 = self.output4(fuse4_fpn)
107 | sideout4 = self.sideout4(output4)
108 |
109 | certain_feature3 = F.upsample(F.sigmoid(sideout4),size=e3_rgb.size()[2:], mode='bilinear')*e3_rgb * e3_depth
110 | fuse3_conv1 = self.fuse3_conv1(e3_rgb + certain_feature3)
111 | fuse3_conv2 = self.fuse3_conv2(e3_depth + certain_feature3)
112 | fuse3_certain = self.fuse3_conv3(fuse3_conv1 + fuse3_conv2)
113 | uncertain_feature3 = self.fuse3_conv4( F.upsample(F.sigmoid(sideout4),size=e3_rgb.size()[2:], mode='bilinear')*torch.abs(e3_rgb - e3_depth))
114 | fuse3 = self.fuse3_conv5(fuse3_certain + uncertain_feature3)
115 | ##
116 | output4_fpn = F.upsample(output4, size=fuse3.size()[2:], mode='bilinear')
117 | fpn_certain_feature3 = F.upsample(F.sigmoid(sideout4), size=fuse3.size()[2:], mode='bilinear') * output4_fpn * fuse3
118 | fuse3_fpn_conv1 = self.fuse3_conv1_fpn(output4_fpn + fpn_certain_feature3)
119 | fuse3_fpn_conv2 = self.fuse3_conv2_fpn(fuse3 + fpn_certain_feature3)
120 | fuse3_certain_fpn = self.fuse3_conv3_fpn(fuse3_fpn_conv1 + fuse3_fpn_conv2)
121 |
122 | fpn_uncertain_feature3 = self.fuse3_conv4_fpn(
123 | F.upsample(F.sigmoid(sideout4), size=fuse3.size()[2:], mode='bilinear') * torch.abs(fuse3 - output4_fpn))
124 | fuse3_fpn = self.fuse3_conv5_fpn(fuse3_certain_fpn + fpn_uncertain_feature3)
125 | output3 = self.output3(fuse3_fpn)
126 | sideout3 = self.sideout3(output3)
127 |
128 | certain_feature2 = F.upsample(F.sigmoid(sideout3),size=e2_rgb.size()[2:], mode='bilinear')*e2_rgb* e2_depth
129 | fuse2_conv1 = self.fuse2_conv1(e2_rgb + certain_feature2)
130 | fuse2_conv2 = self.fuse2_conv2(e2_depth + certain_feature2)
131 | fuse2_certain = self.fuse2_conv3(fuse2_conv1 + fuse2_conv2)
132 | uncertain_feature2 = self.fuse2_conv4(F.upsample(F.sigmoid(sideout3),size=e2_rgb.size()[2:], mode='bilinear')*torch.abs(e2_rgb - e2_depth))
133 | fuse2 = self.fuse2_conv5(fuse2_certain + uncertain_feature2)
134 |
135 | output3_fpn = F.upsample(output3, size=fuse2.size()[2:], mode='bilinear')
136 | fpn_certain_feature2 = F.upsample(F.sigmoid(sideout3), size=fuse2.size()[2:], mode='bilinear') * output3_fpn * fuse2
137 | fuse2_fpn_conv1 = self.fuse2_conv1_fpn(output3_fpn + fpn_certain_feature2)
138 | fuse2_fpn_conv2 = self.fuse2_conv2_fpn(fuse2 + fpn_certain_feature2)
139 | fuse2_certain_fpn = self.fuse2_conv3_fpn(fuse2_fpn_conv1 + fuse2_fpn_conv2)
140 |
141 | fpn_uncertain_feature2 = self.fuse2_conv4_fpn(
142 | F.upsample(F.sigmoid(sideout3), size=fuse2.size()[2:], mode='bilinear') * torch.abs(fuse2 - output3_fpn))
143 | fuse2_fpn = self.fuse2_conv5_fpn(fuse2_certain_fpn + fpn_uncertain_feature2)
144 | output2 = self.output2(fuse2_fpn)
145 | sideout2 = self.sideout2(output2)
146 |
147 | certain_feature1 = F.upsample(F.sigmoid(sideout2),size=e1_rgb.size()[2:],mode='bilinear')*e1_rgb * e1_depth
148 | fuse1_conv1 = self.fuse1_conv1(e1_rgb+certain_feature1)
149 | fuse1_conv2 = self.fuse1_conv2(e1_depth+certain_feature1)
150 | fuse1_certain = self.fuse1_conv3(fuse1_conv1+fuse1_conv2)
151 | uncertain_feature1 = self.fuse1_conv4(F.upsample(F.sigmoid(sideout2),size=e1_rgb.size()[2:],mode='bilinear')*torch.abs(e1_rgb - e1_depth))
152 | fuse1 = self.fuse1_conv5(fuse1_certain+uncertain_feature1)
153 |
154 | output2_fpn = F.upsample(output2, size=fuse1.size()[2:], mode='bilinear')
155 | fpn_certain_feature1 = F.upsample(F.sigmoid(sideout2), size=fuse1.size()[2:],
156 | mode='bilinear') * output2_fpn * fuse1
157 | fuse1_fpn_conv1 = self.fuse1_conv1_fpn(output2_fpn + fpn_certain_feature1)
158 | fuse1_fpn_conv2 = self.fuse1_conv2_fpn(fuse1 + fpn_certain_feature1)
159 | fuse1_certain_fpn = self.fuse1_conv3_fpn(fuse1_fpn_conv1 + fuse1_fpn_conv2)
160 |
161 | fpn_uncertain_feature1 = self.fuse1_conv4_fpn(
162 | F.upsample(F.sigmoid(sideout2), size=fuse1.size()[2:], mode='bilinear') * torch.abs(fuse1 - output2_fpn))
163 | fuse1_fpn = self.fuse1_conv5_fpn(fuse1_certain_fpn + fpn_uncertain_feature1)
164 | output1 = self.output1(fuse1_fpn)
165 |
166 | if self.training:
167 | return sideout5,sideout4,sideout3,sideout2,output1
168 | return output1
169 | if __name__ == "__main__":
170 |     # Smoke test: Contour_Estimation consumes five RGB and five depth encoder
171 |     # feature maps (channels 512/512/256/128/64 from coarse to fine), not raw
172 |     # images; the spatial sizes below are illustrative. Requires the thop package.
173 |     from thop import profile, clever_format
174 |     model = Contour_Estimation()
175 |     shapes = [(512, 8), (512, 16), (256, 32), (128, 64), (64, 128)]
176 |     rgb_feats = [torch.randn(1, c, s, s) for c, s in shapes]
177 |     depth_feats = [torch.randn(1, c, s, s) for c, s in shapes]
178 |     flops, params = profile(model, inputs=(*rgb_feats, *depth_feats))
179 |     flops, params = clever_format([flops, params], "%.3f")
180 |     print(flops, params)
--------------------------------------------------------------------------------
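
Each fusion level above splits the two modalities into a "certain" stream (the element-wise product, where RGB and depth agree) and an "uncertain" stream (their absolute difference), both gated by the upsampled sigmoid of the previous side output (level 5 has no predecessor, so its certain term is the ungated product). A stripped-down sketch of that gating, written with torch.sigmoid and F.interpolate, the current spellings of the deprecated F.sigmoid and F.upsample used above (FuseLevel and its single conv are stand-ins for the real conv-BN-PReLU stacks):

import torch
import torch.nn.functional as F
from torch import nn

class FuseLevel(nn.Module):
    # stand-in for one certain/uncertain fusion level
    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size=3, padding=1)

    def forward(self, f_rgb, f_depth, prev_sideout):
        gate = torch.sigmoid(F.interpolate(prev_sideout, size=f_rgb.shape[2:],
                                           mode='bilinear', align_corners=False))
        certain = gate * f_rgb * f_depth            # where the modalities agree
        uncertain = gate * (f_rgb - f_depth).abs()  # where they disagree
        return self.conv(certain + uncertain)

# toy usage: 64-channel features gated by a coarser 1-channel side output
level = FuseLevel(64)
out = level(torch.randn(1, 64, 64, 64), torch.randn(1, 64, 64, 64),
            torch.randn(1, 1, 32, 32))
print(out.shape)  # torch.Size([1, 64, 64, 64])
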
/utils_downstream/saliency_metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import ndimage
3 | from scipy.ndimage import convolve, distance_transform_edt as bwdist
4 |
5 |
6 | class cal_fm(object):
7 |     # F-measure (maxF, meanF) --- Frequency-tuned salient region detection (CVPR 2009)
8 | def __init__(self, num, thds=255):
9 | self.num = num
10 | self.thds = thds
11 | self.precision = np.zeros((self.num, self.thds))
12 | self.recall = np.zeros((self.num, self.thds))
13 | self.meanF = np.zeros((self.num,1))
14 | self.idx = 0
15 |
16 | def update(self, pred, gt):
17 | if gt.max() != 0:
18 | prediction, recall, Fmeasure_temp = self.cal(pred, gt)
19 | self.precision[self.idx, :] = prediction
20 | self.recall[self.idx, :] = recall
21 | self.meanF[self.idx, :] = Fmeasure_temp
22 | self.idx += 1
23 |
24 | def cal(self, pred, gt):
25 | ########################meanF##############################
26 | th = 2 * pred.mean()
27 | if th > 1:
28 | th = 1
29 | binary = np.zeros_like(pred)
30 | binary[pred >= th] = 1
31 | hard_gt = np.zeros_like(gt)
32 | hard_gt[gt > 0.5] = 1
33 | tp = (binary * hard_gt).sum()
34 | if tp == 0:
35 | meanF = 0
36 | else:
37 | pre = tp / binary.sum()
38 | rec = tp / hard_gt.sum()
39 | meanF = 1.3 * pre * rec / (0.3 * pre + rec)
40 | ########################maxF##############################
41 | pred = np.uint8(pred * 255)
42 | target = pred[gt > 0.5]
43 | nontarget = pred[gt <= 0.5]
44 | targetHist, _ = np.histogram(target, bins=range(256))
45 | nontargetHist, _ = np.histogram(nontarget, bins=range(256))
46 | targetHist = np.cumsum(np.flip(targetHist), axis=0)
47 | nontargetHist = np.cumsum(np.flip(nontargetHist), axis=0)
48 | precision = targetHist / (targetHist + nontargetHist + 1e-8)
49 | recall = targetHist / np.sum(gt)
50 | return precision, recall, meanF
51 |
52 | def show(self):
53 | assert self.num == self.idx
54 | precision = self.precision.mean(axis=0)
55 | recall = self.recall.mean(axis=0)
56 | fmeasure = 1.3 * precision * recall / (0.3 * precision + recall + 1e-8)
57 | fmeasure_avg = self.meanF.mean(axis=0)
58 | return fmeasure.max(),fmeasure_avg[0],precision,recall
59 |
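# Implementation note: the maxF computation in cal_fm.cal() sweeps all 255
# thresholds without a loop -- flipping a histogram and taking its cumulative
# sum gives, for each threshold from high to low, the number of pixels at or
# above it, so the full precision/recall curves fall out of two cumulative sums.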
60 |
61 | class cal_mae(object):
62 | # mean absolute error
63 | def __init__(self):
64 | self.prediction = []
65 |
66 | def update(self, pred, gt):
67 | score = self.cal(pred, gt)
68 | self.prediction.append(score)
69 |
70 | def cal(self, pred, gt):
71 | return np.mean(np.abs(pred - gt))
72 |
73 | def show(self):
74 | return np.mean(self.prediction)
75 |
76 | class cal_dice(object):
77 | # Dice coefficient
78 | def __init__(self):
79 | self.prediction = []
80 |
81 | def update(self, pred, gt):
82 | score = self.cal(pred, gt)
83 | self.prediction.append(score)
84 |
85 | def cal(self, y_pred, y_true):
86 | # smooth = 1
87 | smooth = 1e-5
88 | y_true_f = y_true.flatten()
89 | y_pred_f = y_pred.flatten()
90 | intersection = np.sum(y_true_f * y_pred_f)
91 | return (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth)
92 |
93 | def show(self):
94 | return np.mean(self.prediction)
95 |
96 | class cal_ber(object):
97 | # balanced error rate (BER)
98 | def __init__(self):
99 | self.prediction = []
100 |
101 | def update(self, pred, gt):
102 | score = self.cal(pred, gt)
103 | self.prediction.append(score)
104 |
105 | def cal(self, y_pred, y_true):
106 | binary = np.zeros_like(y_pred)
107 | binary[y_pred >= 0.5] = 1
108 | hard_gt = np.zeros_like(y_true)
109 | hard_gt[y_true > 0.5] = 1
110 | tp = (binary * hard_gt).sum()
111 | tn = ((1-binary) * (1-hard_gt)).sum()
112 | Np = hard_gt.sum()
113 | Nn = (1-hard_gt).sum()
114 | ber = (1-(tp/(Np+1e-8)+tn/(Nn+1e-8))/2)
115 | return ber
116 |
117 | def show(self):
118 | return np.mean(self.prediction)
119 |
120 | class cal_acc(object):
121 | # pixel accuracy
122 | def __init__(self):
123 | self.prediction = []
124 |
125 | def update(self, pred, gt):
126 | score = self.cal(pred, gt)
127 | self.prediction.append(score)
128 |
129 | def cal(self, y_pred, y_true):
130 | binary = np.zeros_like(y_pred)
131 | binary[y_pred >= 0.5] = 1
132 | hard_gt = np.zeros_like(y_true)
133 | hard_gt[y_true > 0.5] = 1
134 | tp = (binary * hard_gt).sum()
135 | tn = ((1-binary) * (1-hard_gt)).sum()
136 | Np = hard_gt.sum()
137 | Nn = (1-hard_gt).sum()
138 | acc = ((tp+tn)/(Np+Nn))
139 | return acc
140 |
141 | def show(self):
142 | return np.mean(self.prediction)
143 |
144 | class cal_iou(object):
145 | # intersection over union (IoU)
146 | def __init__(self):
147 | self.prediction = []
148 |
149 | def update(self, pred, gt):
150 | score = self.cal(pred, gt)
151 | self.prediction.append(score)
152 |
153 | # def cal(self, input, target):
154 | # classes = 1
155 | # intersection = np.logical_and(target == classes, input == classes)
156 | # # print(intersection.any())
157 | # union = np.logical_or(target == classes, input == classes)
158 | # return np.sum(intersection) / np.sum(union)
159 |
160 | def cal(self, input, target):
161 | smooth = 1e-5
162 | input = input > 0.5
163 | target_ = target > 0.5
164 | intersection = (input & target_).sum()
165 | union = (input | target_).sum()
166 |
167 | return (intersection + smooth) / (union + smooth)
168 | def show(self):
169 | return np.mean(self.prediction)
170 |
183 |
184 | class cal_sm(object):
185 | # Structure-measure: A new way to evaluate foreground maps (ICCV 2017)
186 | def __init__(self, alpha=0.5):
187 | self.prediction = []
188 | self.alpha = alpha
189 |
190 | def update(self, pred, gt):
191 | gt = gt > 0.5
192 | score = self.cal(pred, gt)
193 | self.prediction.append(score)
194 |
195 | def show(self):
196 | return np.mean(self.prediction)
197 |
198 | def cal(self, pred, gt):
199 | y = np.mean(gt)
200 | if y == 0:
201 | score = 1 - np.mean(pred)
202 | elif y == 1:
203 | score = np.mean(pred)
204 | else:
205 | score = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt)
206 | return score
207 |
208 | def object(self, pred, gt):
209 | fg = pred * gt
210 | bg = (1 - pred) * (1 - gt)
211 |
212 | u = np.mean(gt)
213 | return u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, np.logical_not(gt))
214 |
215 | def s_object(self, in1, in2):
216 | x = np.mean(in1[in2])
217 | sigma_x = np.std(in1[in2])
218 | return 2 * x / (pow(x, 2) + 1 + sigma_x + 1e-8)
219 |
220 | def region(self, pred, gt):
221 | [y, x] = ndimage.center_of_mass(gt)
222 | y = int(round(y)) + 1
223 | x = int(round(x)) + 1
224 | [gt1, gt2, gt3, gt4, w1, w2, w3, w4] = self.divideGT(gt, x, y)
225 | pred1, pred2, pred3, pred4 = self.dividePred(pred, x, y)
226 |
227 | score1 = self.ssim(pred1, gt1)
228 | score2 = self.ssim(pred2, gt2)
229 | score3 = self.ssim(pred3, gt3)
230 | score4 = self.ssim(pred4, gt4)
231 |
232 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4
233 |
234 | def divideGT(self, gt, x, y):
235 | h, w = gt.shape
236 | area = h * w
237 | LT = gt[0:y, 0:x]
238 | RT = gt[0:y, x:w]
239 | LB = gt[y:h, 0:x]
240 | RB = gt[y:h, x:w]
241 |
242 | w1 = x * y / area
243 | w2 = y * (w - x) / area
244 | w3 = (h - y) * x / area
245 | w4 = (h - y) * (w - x) / area
246 |
247 | return LT, RT, LB, RB, w1, w2, w3, w4
248 |
249 | def dividePred(self, pred, x, y):
250 | h, w = pred.shape
251 | LT = pred[0:y, 0:x]
252 | RT = pred[0:y, x:w]
253 | LB = pred[y:h, 0:x]
254 | RB = pred[y:h, x:w]
255 |
256 | return LT, RT, LB, RB
257 |
258 | def ssim(self, in1, in2):
259 | in2 = np.float32(in2)
260 | h, w = in1.shape
261 | N = h * w
262 |
263 | x = np.mean(in1)
264 | y = np.mean(in2)
265 | sigma_x = np.var(in1)
266 | sigma_y = np.var(in2)
267 | sigma_xy = np.sum((in1 - x) * (in2 - y)) / (N - 1)
268 |
269 | alpha = 4 * x * y * sigma_xy
270 | beta = (x * x + y * y) * (sigma_x + sigma_y)
271 |
272 | if alpha != 0:
273 | score = alpha / (beta + 1e-8)
274 | elif alpha == 0 and beta == 0:
275 | score = 1
276 | else:
277 | score = 0
278 |
279 | return score
280 |
281 | class cal_em(object):
282 |     # Enhanced-alignment Measure for Binary Foreground Map Evaluation (IJCAI 2018)
283 | def __init__(self):
284 | self.prediction = []
285 |
286 | def update(self, pred, gt):
287 | score = self.cal(pred, gt)
288 | self.prediction.append(score)
289 |
290 | def cal(self, pred, gt):
291 | th = 2 * pred.mean()
292 | if th > 1:
293 | th = 1
294 | FM = np.zeros(gt.shape)
295 | FM[pred >= th] = 1
296 | FM = np.array(FM,dtype=bool)
297 | GT = np.array(gt,dtype=bool)
298 | dFM = np.double(FM)
299 | if (sum(sum(np.double(GT)))==0):
300 | enhanced_matrix = 1.0-dFM
301 | elif (sum(sum(np.double(~GT)))==0):
302 | enhanced_matrix = dFM
303 | else:
304 | dGT = np.double(GT)
305 | align_matrix = self.AlignmentTerm(dFM, dGT)
306 | enhanced_matrix = self.EnhancedAlignmentTerm(align_matrix)
307 | [w, h] = np.shape(GT)
308 | score = sum(sum(enhanced_matrix))/ (w * h - 1 + 1e-8)
309 | return score
310 | def AlignmentTerm(self,dFM,dGT):
311 | mu_FM = np.mean(dFM)
312 | mu_GT = np.mean(dGT)
313 | align_FM = dFM - mu_FM
314 | align_GT = dGT - mu_GT
315 | align_Matrix = 2. * (align_GT * align_FM)/ (align_GT* align_GT + align_FM* align_FM + 1e-8)
316 | return align_Matrix
317 | def EnhancedAlignmentTerm(self,align_Matrix):
318 | enhanced = np.power(align_Matrix + 1,2) / 4
319 | return enhanced
320 | def show(self):
321 | return np.mean(self.prediction)
322 | class cal_wfm(object):
323 | def __init__(self, beta=1):
324 | self.beta = beta
325 | self.eps = 1e-6
326 | self.scores_list = []
327 |
328 | def update(self, pred, gt):
329 | assert pred.ndim == gt.ndim and pred.shape == gt.shape
330 | assert pred.max() <= 1 and pred.min() >= 0
331 | assert gt.max() <= 1 and gt.min() >= 0
332 |
333 | gt = gt > 0.5
334 | if gt.max() == 0:
335 | score = 0
336 | else:
337 | score = self.cal(pred, gt)
338 | self.scores_list.append(score)
339 |
340 | def matlab_style_gauss2D(self, shape=(7, 7), sigma=5):
341 | """
342 | 2D gaussian mask - should give the same result as MATLAB's
343 | fspecial('gaussian',[shape],[sigma])
344 | """
345 | m, n = [(ss - 1.) / 2. for ss in shape]
346 | y, x = np.ogrid[-m:m + 1, -n:n + 1]
347 | h = np.exp(-(x * x + y * y) / (2. * sigma * sigma))
348 | h[h < np.finfo(h.dtype).eps * h.max()] = 0
349 | sumh = h.sum()
350 | if sumh != 0:
351 | h /= sumh
352 | return h
353 |
354 | def cal(self, pred, gt):
355 | # [Dst,IDXT] = bwdist(dGT);
356 | Dst, Idxt = bwdist(gt == 0, return_indices=True)
357 |
358 | # %Pixel dependency
359 | # E = abs(FG-dGT);
360 | E = np.abs(pred - gt)
361 | # Et = E;
362 | # Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region
363 | Et = np.copy(E)
364 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]]
365 |
366 | # K = fspecial('gaussian',7,5);
367 | # EA = imfilter(Et,K);
368 | # MIN_E_EA(GT & EA