├── smoothness ├── smoothness_loss ├── __pycache__ │ └── __init__.cpython-38.pyc └── __init__.py ├── main.jpg ├── RGBTScribble.pdf ├── __pycache__ ├── data.cpython-38.pyc ├── pamr.cpython-38.pyc ├── tools.cpython-38.pyc ├── utils.cpython-38.pyc └── lscloss.cpython-38.pyc ├── PVT_Model ├── __pycache__ │ ├── pvtv2.cpython-38.pyc │ └── pvtmodel.cpython-38.pyc ├── pvtmodel.py └── pvtv2.py ├── demo.py ├── README.md ├── tools.py ├── utils.py ├── test.py ├── pamr.py ├── lscloss.py ├── data.py └── train.py /smoothness/smoothness_loss: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /main.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuzywen/RGBTScribble-ICME2023/HEAD/main.jpg -------------------------------------------------------------------------------- /RGBTScribble.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuzywen/RGBTScribble-ICME2023/HEAD/RGBTScribble.pdf -------------------------------------------------------------------------------- /__pycache__/data.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuzywen/RGBTScribble-ICME2023/HEAD/__pycache__/data.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/pamr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuzywen/RGBTScribble-ICME2023/HEAD/__pycache__/pamr.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/tools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuzywen/RGBTScribble-ICME2023/HEAD/__pycache__/tools.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuzywen/RGBTScribble-ICME2023/HEAD/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/lscloss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuzywen/RGBTScribble-ICME2023/HEAD/__pycache__/lscloss.cpython-38.pyc -------------------------------------------------------------------------------- /PVT_Model/__pycache__/pvtv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuzywen/RGBTScribble-ICME2023/HEAD/PVT_Model/__pycache__/pvtv2.cpython-38.pyc -------------------------------------------------------------------------------- /PVT_Model/__pycache__/pvtmodel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuzywen/RGBTScribble-ICME2023/HEAD/PVT_Model/__pycache__/pvtmodel.cpython-38.pyc -------------------------------------------------------------------------------- /smoothness/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/liuzywen/RGBTScribble-ICME2023/HEAD/smoothness/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from BBS_Model.pvtv2 import pvt_v2_b2 3 | 4 | path = "D:\HXS\pretraining parameters\pvt_v2_b2.pth" 5 | state = torch.load(path) 6 | print(state.keys()) 7 | model = pvt_v2_b2() 8 | print(model) 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Scribble-Supervised RGB-T Salient Object Detection 2 | The [paper](RGBTScribble.pdf) has been accepted by ICME2023. 3 | 4 | 5 | ![Main](main.jpg) 6 | 7 | # RGBT-S dataset (RGB-T Scribble dataset) 8 | 链接:https://pan.baidu.com/s/1odSfb7XGnYQ6tUJ341UIdA 9 | 提取码:t67e 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | # Pretraining Parameters 18 | 链接:https://pan.baidu.com/s/10hcP9NjB8Z9VC9kzSt1eCQ 19 | 提取码:9eaa 20 | 21 | # Model Parameters 22 | 链接:https://pan.baidu.com/s/1DV_cKbjx5ZzRcomJCPcHSA 23 | 提取码:it8k 24 | 25 | # Evaluation Code 26 | 链接:https://pan.baidu.com/s/1W9L-twMVlM8CXfWJC40Y_g 27 | 提取码:t2vw 28 | 29 | # RGB-T Saliency Maps 30 | 链接:https://pan.baidu.com/s/1VGEHFnMm2fAiK_etfDRmSw 31 | 提取码:kd42 32 | 33 | ### Citation 34 | 35 | If you find the information useful, please consider citing: 36 | 37 | ``` 38 | @inproceedings{liu2023, 39 | author={Liu, Zhengyi and Huang, Xiaoshen and Zhang, Guanghui and Fang, Xianyong and Wang, Linbo and Tang, Bin}, 40 | journal={2023 IEEE International Conference on Multimedia and Expo (ICME)}, 41 | booktitle={Scribble-Supervised RGB-T Salient Object Detection}, 42 | pages={2369--2374}, 43 | year={2023}} 44 | ``` 45 | If you have any question, please email liuzywen@ahu.edu.cn 46 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | 5 | 6 | def ToLabel(E): 7 | fgs = np.argmax(E, axis=1).astype(np.float32) 8 | return fgs.astype(np.uint8) 9 | 10 | 11 | def SSIM(x, y): 12 | C1 = 0.01 ** 2 13 | C2 = 0.03 ** 2 14 | 15 | mu_x = nn.AvgPool2d(3, 1, 1)(x) 16 | mu_y = nn.AvgPool2d(3, 1, 1)(y) 17 | mu_x_mu_y = mu_x * mu_y 18 | mu_x_sq = mu_x.pow(2) 19 | mu_y_sq = mu_y.pow(2) 20 | 21 | sigma_x = nn.AvgPool2d(3, 1, 1)(x * x) - mu_x_sq 22 | sigma_y = nn.AvgPool2d(3, 1, 1)(y * y) - mu_y_sq 23 | sigma_xy = nn.AvgPool2d(3, 1, 1)(x * y) - mu_x_mu_y 24 | 25 | SSIM_n = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2) 26 | SSIM_d = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2) 27 | SSIM = SSIM_n / SSIM_d 28 | 29 | return torch.clamp((1 - SSIM) / 2, 0, 1) 30 | 31 | 32 | def SaliencyStructureConsistency(x, y, alpha): 33 | ssim = torch.mean(SSIM(x,y)) 34 | l1_loss = torch.mean(torch.abs(x-y)) 35 | loss_ssc = alpha*ssim + (1-alpha)*l1_loss 36 | return loss_ssc 37 | 38 | 39 | def SaliencyStructureConsistencynossim(x, y): 40 | l1_loss = torch.mean(torch.abs(x-y)) 41 | return l1_loss 42 | 43 | 44 | def set_seed(seed): 45 | torch.manual_seed(seed) 46 | torch.cuda.manual_seed_all(seed) 47 | np.random.seed(seed) 48 | random.seed(seed) 49 | torch.backends.cudnn.deterministic = True 50 | 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | 7 | fx = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).astype(np.float32) 8 | fy = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).astype(np.float32) 9 | fx = np.reshape(fx, (1, 1, 3, 3)) 10 | fy = np.reshape(fy, (1, 1, 3, 3)) 11 | fx = Variable(torch.from_numpy(fx)).cuda() 12 | fy = Variable(torch.from_numpy(fy)).cuda() 13 | contour_th = 1.5 14 | 15 | 16 | def label_edge_prediction(label): 17 | # convert label to edge 18 | label = label.gt(0.5).float() 19 | label = F.pad(label, (1, 1, 1, 1), mode='replicate') 20 | label_fx = F.conv2d(label, fx) 21 | label_fy = F.conv2d(label, fy) 22 | label_grad = torch.sqrt(torch.mul(label_fx, label_fx) + torch.mul(label_fy, label_fy)) 23 | label_grad = torch.gt(label_grad, contour_th).float() 24 | 25 | return label_grad 26 | 27 | 28 | def pred_edge_prediction(pred): 29 | # infer edge from prediction 30 | pred = F.pad(pred, (1, 1, 1, 1), mode='replicate') 31 | pred_fx = F.conv2d(pred, fx) 32 | pred_fy = F.conv2d(pred, fy) 33 | pred_grad = (pred_fx*pred_fx + pred_fy*pred_fy).sqrt().tanh() 34 | 35 | return pred_grad 36 | 37 | def clip_gradient(optimizer, grad_clip): 38 | for group in optimizer.param_groups: 39 | for param in group['params']: 40 | if param.grad is not None: 41 | param.grad.data.clamp_(-grad_clip, grad_clip) 42 | 43 | 44 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30): 45 | decay = decay_rate ** (epoch // decay_epoch) 46 | for param_group in optimizer.param_groups: 47 | param_group['lr'] *= decay 48 | -------------------------------------------------------------------------------- /smoothness/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | # from torch.autograd import Variable 4 | # import numpy as np 5 | def laplacian_edge(img): 6 | laplacian_filter = torch.Tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]) 7 | filter = torch.reshape(laplacian_filter, [1, 1, 3, 3]) 8 | filter = filter.cuda() 9 | lap_edge = F.conv2d(img, filter, stride=1, padding=1) 10 | return lap_edge 11 | 12 | def gradient_x(img): 13 | sobel = torch.Tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]) 14 | filter = torch.reshape(sobel,[1,1,3,3]) 15 | filter = filter.cuda() 16 | gx = F.conv2d(img, filter, stride=1, padding=1) 17 | return gx 18 | 19 | 20 | def gradient_y(img): 21 | sobel = torch.Tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]) 22 | filter = torch.reshape(sobel, [1, 1,3,3]) 23 | filter = filter.cuda() 24 | gy = F.conv2d(img, filter, stride=1, padding=1) 25 | return gy 26 | 27 | def charbonnier_penalty(s): 28 | cp_s = torch.pow(torch.pow(s, 2) + 0.001**2, 0.5) 29 | return cp_s 30 | 31 | def get_saliency_smoothness(pred, gt, size_average=True): 32 | alpha = 10 33 | s1 = 10 34 | s2 = 1 35 | ## first oder derivative: sobel 36 | sal_x = torch.abs(gradient_x(pred)) 37 | sal_y = torch.abs(gradient_y(pred)) 38 | gt_x = gradient_x(gt) 39 | gt_y = gradient_y(gt) 40 | w_x = torch.exp(torch.abs(gt_x) * (-alpha)) 41 | w_y = torch.exp(torch.abs(gt_y) * (-alpha)) 42 | cps_x = charbonnier_penalty(sal_x * w_x) 43 | cps_y = charbonnier_penalty(sal_y * w_y) 44 | cps_xy = cps_x + cps_y 45 | 46 | ## second order derivative: laplacian 47 | lap_sal = torch.abs(laplacian_edge(pred)) 48 | lap_gt = torch.abs(laplacian_edge(gt)) 49 | weight_lap = 
torch.exp(lap_gt * (-alpha)) 50 | weighted_lap = charbonnier_penalty(lap_sal*weight_lap) 51 | 52 | smooth_loss = s1*torch.mean(cps_xy) + s2*torch.mean(weighted_lap) 53 | 54 | return smooth_loss 55 | 56 | class smoothness_loss(torch.nn.Module): 57 | def __init__(self, size_average = True): 58 | super(smoothness_loss, self).__init__() 59 | self.size_average = size_average 60 | 61 | def forward(self, pred, target): 62 | 63 | return get_saliency_smoothness(pred, target, self.size_average) 64 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import sys 4 | sys.path.append('./models') 5 | import numpy as np 6 | import os , argparse 7 | import cv2 8 | from PVT_Model.pvtmodel import PvtNet 9 | from data import test_dataset 10 | 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | # parser.add_argument('--trainsize', type=int, default=256, help='testing size') 15 | parser.add_argument('--testsize', type=int, default=256, help='testing size') 16 | # parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id') 17 | parser.add_argument('--test_path', type=str, default='', help='test dataset path') 18 | # parser.add_argument('--load', type=str, default=None, help='train from checkpoints') 19 | opt = parser.parse_args() 20 | 21 | dataset_path = opt.test_path 22 | 23 | #set device for test 24 | # if opt.gpu_id=='0': 25 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0" 26 | # print('USE GPU 0') 27 | # elif opt.gpu_id=='1': 28 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1" 29 | # print('USE GPU 1') 30 | 31 | #load the model 32 | model = PvtNet(None) 33 | #Large epoch size may not generalize well. You can choose a good model to load according to the log file and pth files saved in ('./BBSNet_cpts/') when training. 34 | model.load_state_dict(torch.load('')) 35 | model.cuda() 36 | model.eval() 37 | 38 | #test 39 | test_datasets = ['VT821', 'VT1000'] 40 | # test_datasets = ['VT5000'] 41 | # test_datasets = ['DES'] 42 | for dataset in test_datasets: 43 | save_path = './test_maps/' + dataset + '/' 44 | if not os.path.exists(save_path): 45 | os.makedirs(save_path) 46 | image_root = dataset_path +dataset+'/'+'RGB'+'/' 47 | gt_root = dataset_path +dataset+'/'+'GT'+'/' 48 | depth_root = dataset_path +dataset+'/'+'T'+'/' 49 | test_loader = test_dataset(image_root, gt_root, depth_root, opt.testsize) 50 | for i in range(test_loader.size): 51 | image, gt, depth, name, image_for_post = test_loader.load_data()#image_for_post有什么用? 
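        # image_for_post is the original RGB image resized to the GT resolution (see test_dataset.load_data in data.py); it is returned for optional post-processing but is not used anywhere in this loop.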
52 | gt = np.asarray(gt, np.float32) 53 | gt /= (gt.max() + 1e-8) 54 | image = image.cuda() 55 | depth = depth.cuda() 56 | res, _, _, _ = model(image, depth) 57 | res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False) 58 | res = res.sigmoid().data.cpu().numpy().squeeze() 59 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 60 | print('save img to: ', save_path+name) 61 | cv2.imwrite(save_path+name, res*255) 62 | print('Test Done!') 63 | -------------------------------------------------------------------------------- /pamr.py: -------------------------------------------------------------------------------- 1 | ### 2 | # local pixel refinement 3 | ### 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | 10 | 11 | def get_kernel(): 12 | weight = torch.zeros(8, 1, 3, 3) 13 | weight[0, 0, 0, 0] = 1 14 | weight[1, 0, 0, 1] = 1 15 | weight[2, 0, 0, 2] = 1 16 | 17 | weight[3, 0, 1, 0] = 1 18 | weight[4, 0, 1, 2] = 1 19 | 20 | weight[5, 0, 2, 0] = 1 21 | weight[6, 0, 2, 1] = 1 22 | weight[7, 0, 2, 2] = 1 23 | 24 | return weight 25 | 26 | 27 | class PAR(nn.Module): 28 | 29 | def __init__(self, dilations, num_iter, ): 30 | super().__init__() 31 | self.dilations = dilations 32 | self.num_iter = num_iter 33 | kernel = get_kernel() 34 | self.register_buffer('kernel', kernel) 35 | self.pos = self.get_pos() 36 | self.dim = 2 37 | self.w1 = 0.3 38 | self.w2 = 0.01 39 | 40 | def get_dilated_neighbors(self, x): 41 | 42 | b, c, h, w = x.shape 43 | x_aff = [] 44 | for d in self.dilations: 45 | _x_pad = F.pad(x, [d] * 4, mode='replicate', value=0) 46 | _x_pad = _x_pad.reshape(b * c, -1, _x_pad.shape[-2], _x_pad.shape[-1]) 47 | _x = F.conv2d(_x_pad, self.kernel, dilation=d).view(b, c, -1, h, w) 48 | x_aff.append(_x) 49 | 50 | return torch.cat(x_aff, dim=2) 51 | 52 | def get_pos(self): 53 | pos_xy = [] 54 | 55 | ker = torch.ones(1, 1, 8, 1, 1) 56 | ker[0, 0, 0, 0, 0] = np.sqrt(2) 57 | ker[0, 0, 2, 0, 0] = np.sqrt(2) 58 | ker[0, 0, 5, 0, 0] = np.sqrt(2) 59 | ker[0, 0, 7, 0, 0] = np.sqrt(2) 60 | 61 | for d in self.dilations: 62 | pos_xy.append(ker * d) 63 | return torch.cat(pos_xy, dim=2) 64 | 65 | def forward(self, imgs, masks): 66 | 67 | masks = F.interpolate(masks, size=imgs.size()[-2:], mode="bilinear", align_corners=True) 68 | 69 | b, c, h, w = imgs.shape 70 | _imgs = self.get_dilated_neighbors(imgs) 71 | _pos = self.pos.to(_imgs.device) 72 | 73 | _imgs_rep = imgs.unsqueeze(self.dim).repeat(1, 1, _imgs.shape[self.dim], 1, 1) 74 | _pos_rep = _pos.repeat(b, 1, 1, h, w) 75 | 76 | _imgs_abs = torch.abs(_imgs - _imgs_rep) 77 | _imgs_std = torch.std(_imgs, dim=self.dim, keepdim=True) 78 | _pos_std = torch.std(_pos_rep, dim=self.dim, keepdim=True) 79 | 80 | aff = -(_imgs_abs / (_imgs_std + 1e-8) / self.w1) ** 2 81 | aff = aff.mean(dim=1, keepdim=True) 82 | 83 | pos_aff = -(_pos_rep / (_pos_std + 1e-8) / self.w1) ** 2 84 | # pos_aff = pos_aff.mean(dim=1, keepdim=True) 85 | 86 | aff = F.softmax(aff, dim=2) + self.w2 * F.softmax(pos_aff, dim=2) 87 | 88 | for _ in range(self.num_iter): 89 | _masks = self.get_dilated_neighbors(masks) 90 | masks = (_masks * aff).sum(2) 91 | 92 | return masks 93 | 94 | 95 | def run_pamr(im, mask): 96 | aff = PAR(num_iter=10, dilations=[1,2,4,8,12,24]) 97 | masks_dec = aff(im, mask) 98 | return masks_dec 99 | 100 | 101 | def BinaryPamr(img, sal, binary=0.4): 102 | pamr = PAR(num_iter=10, dilations=[1,2,4,8,12,24]).cuda() 103 | sal_pamr = pamr(img, sal) 104 | sal_pamr /= F.adaptive_max_pool2d(sal_pamr, (1, 
1)) + 1e-5 105 | if binary is not None: 106 | sal_pamr[sal_pamr < binary] = 0 107 | sal_pamr[sal_pamr > binary] = 1 108 | return sal_pamr 109 | -------------------------------------------------------------------------------- /lscloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class LocalSaliencyCoherence(torch.nn.Module): 6 | """ 7 | This loss function based on the following paper. 8 | Please consider using the following bibtex for citation: 9 | @article{obukhov2019gated, 10 | author={Anton Obukhov and Stamatios Georgoulis and Dengxin Dai and Luc {Van Gool}}, 11 | title={Gated {CRF} Loss for Weakly Supervised Semantic Image Segmentation}, 12 | journal={CoRR}, 13 | volume={abs/1906.04651}, 14 | year={2019}, 15 | url={http://arxiv.org/abs/1906.04651}, 16 | } 17 | """ 18 | def forward( 19 | self, y_hat_softmax, kernels_desc, kernels_radius, sample, height_input, width_input, 20 | mask_src=None, mask_dst=None, compatibility=None, custom_modality_downsamplers=None, out_kernels_vis=False 21 | ): 22 | """ 23 | Performs the forward pass of the loss. 24 | :param y_hat_softmax: A tensor of predicted per-pixel class probabilities of size NxCxHxW 25 | :param kernels_desc: A list of dictionaries, each describing one Gaussian kernel composition from modalities. 26 | The final kernel is a weighted sum of individual kernels. Following example is a composition of 27 | RGBXY and XY kernels: 28 | kernels_desc: [{ 29 | 'weight': 0.9, # Weight of RGBXY kernel 30 | 'xy': 6, # Sigma for XY 31 | 'rgb': 0.1, # Sigma for RGB 32 | },{ 33 | 'weight': 0.1, # Weight of XY kernel 34 | 'xy': 6, # Sigma for XY 35 | }] 36 | :param kernels_radius: Defines size of bounding box region around each pixel in which the kernel is constructed. 37 | :param sample: A dictionary with modalities (except 'xy') used in kernels_desc parameter. Each of the provided 38 | modalities is allowed to be larger than the shape of y_hat_softmax, in such case downsampling will be 39 | invoked. Default downsampling method is area resize; this can be overriden by setting. 40 | custom_modality_downsamplers parameter. 41 | :param width_input, height_input: Dimensions of the full scale resolution of modalities 42 | :param mask_src: (optional) Source mask. 43 | :param mask_dst: (optional) Destination mask. 44 | :param compatibility: (optional) Classes compatibility matrix, defaults to Potts model. 45 | :param custom_modality_downsamplers: A dictionary of modality downsampling functions. 46 | :param out_kernels_vis: Whether to return a tensor with kernels visualized with some step. 47 | :return: Loss function value. 
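        Example (mirroring the call in this repository's train.py; `logits` and `images` are placeholder tensors and the shapes are illustrative assumptions):
            loss_lsc = LocalSaliencyCoherence()
            kernels_desc = [{"weight": 1, "xy": 6, "rgb": 0.1}]
            radius = 5
            pred_small = F.interpolate(torch.sigmoid(logits), scale_factor=0.25, mode="bilinear", align_corners=True)
            rgb_small = F.interpolate(images, scale_factor=0.25, mode="bilinear", align_corners=True)
            out = loss_lsc(pred_small, kernels_desc, radius, {'rgb': rgb_small}, rgb_small.shape[2], rgb_small.shape[3])
            loss = out['loss']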
48 | """ 49 | assert y_hat_softmax.dim() == 4, 'Prediction must be a NCHW batch' 50 | N, C, height_pred, width_pred = y_hat_softmax.shape 51 | 52 | device = y_hat_softmax.device 53 | 54 | assert width_input % width_pred == 0 and height_input % height_pred == 0 and \ 55 | width_input * height_pred == height_input * width_pred, \ 56 | f'[{width_input}x{height_input}] !~= [{width_pred}x{height_pred}]' 57 | 58 | kernels = self._create_kernels( 59 | kernels_desc, kernels_radius, sample, N, height_pred, width_pred, device, custom_modality_downsamplers 60 | ) 61 | 62 | y_hat_unfolded = self._unfold(y_hat_softmax, kernels_radius) 63 | y_hat_unfolded = torch.abs(y_hat_unfolded[:, :, kernels_radius, kernels_radius, :, :].view(N, C, 1, 1, height_pred, width_pred) - y_hat_unfolded) 64 | 65 | loss = torch.mean((kernels * y_hat_unfolded).view(N, C, (kernels_radius * 2 + 1) ** 2, height_pred, width_pred).sum(dim=2, keepdim=True)) 66 | 67 | 68 | out = { 69 | 'loss': loss.mean(), 70 | } 71 | 72 | if out_kernels_vis: 73 | out['kernels_vis'] = self._visualize_kernels( 74 | kernels, kernels_radius, height_input, width_input, height_pred, width_pred 75 | ) 76 | 77 | return out 78 | 79 | @staticmethod 80 | def _downsample(img, modality, height_dst, width_dst, custom_modality_downsamplers): 81 | if custom_modality_downsamplers is not None and modality in custom_modality_downsamplers: 82 | f_down = custom_modality_downsamplers[modality] 83 | else: 84 | f_down = F.adaptive_avg_pool2d 85 | return f_down(img, (height_dst, width_dst)) 86 | 87 | @staticmethod 88 | def _create_kernels( 89 | kernels_desc, kernels_radius, sample, N, height_pred, width_pred, device, custom_modality_downsamplers 90 | ): 91 | kernels = None 92 | for i, desc in enumerate(kernels_desc): 93 | weight = desc['weight'] 94 | features = [] 95 | for modality, sigma in desc.items(): 96 | if modality == 'weight': 97 | continue 98 | if modality == 'xy': 99 | feature = LocalSaliencyCoherence._get_mesh(N, height_pred, width_pred, device) 100 | else: 101 | assert modality in sample, \ 102 | f'Modality {modality} is listed in {i}-th kernel descriptor, but not present in the sample' 103 | feature = sample[modality] 104 | # feature = LocalSaliencyCoherence._downsample( 105 | # feature, modality, height_pred, width_pred, custom_modality_downsamplers 106 | # ) 107 | feature /= sigma 108 | features.append(feature) 109 | features = torch.cat(features, dim=1) 110 | kernel = weight * LocalSaliencyCoherence._create_kernels_from_features(features, kernels_radius) 111 | kernels = kernel if kernels is None else kernel + kernels 112 | return kernels 113 | 114 | @staticmethod 115 | def _create_kernels_from_features(features, radius): 116 | assert features.dim() == 4, 'Features must be a NCHW batch' 117 | N, C, H, W = features.shape 118 | kernels = LocalSaliencyCoherence._unfold(features, radius) 119 | kernels = kernels - kernels[:, :, radius, radius, :, :].view(N, C, 1, 1, H, W) 120 | kernels = (-0.5 * kernels ** 2).sum(dim=1, keepdim=True).exp() 121 | # kernels[:, :, radius, radius, :, :] = 0 122 | return kernels 123 | 124 | @staticmethod 125 | def _get_mesh(N, H, W, device): 126 | return torch.cat(( 127 | torch.arange(0, W, 1, dtype=torch.float32, device=device).view(1, 1, 1, W).repeat(N, 1, H, 1), 128 | torch.arange(0, H, 1, dtype=torch.float32, device=device).view(1, 1, H, 1).repeat(N, 1, 1, W) 129 | ), 1) 130 | 131 | @staticmethod 132 | def _unfold(img, radius): 133 | assert img.dim() == 4, 'Unfolding requires NCHW batch' 134 | N, C, H, W = img.shape 135 | diameter 
= 2 * radius + 1 136 | return F.unfold(img, diameter, 1, radius).view(N, C, diameter, diameter, H, W) 137 | 138 | @staticmethod 139 | def _visualize_kernels(kernels, radius, height_input, width_input, height_pred, width_pred): 140 | diameter = 2 * radius + 1 141 | vis = kernels[:, :, :, :, radius::diameter, radius::diameter] 142 | vis_nh, vis_nw = vis.shape[-2:] 143 | vis = vis.permute(0, 1, 4, 2, 5, 3).contiguous().view(kernels.shape[0], 1, diameter * vis_nh, diameter * vis_nw) 144 | if vis.shape[2] > height_pred: 145 | vis = vis[:, :, :height_pred, :] 146 | if vis.shape[3] > width_pred: 147 | vis = vis[:, :, :, :width_pred] 148 | if vis.shape[2:] != (height_pred, width_pred): 149 | vis = F.pad(vis, [0, width_pred-vis.shape[3], 0, height_pred-vis.shape[2]]) 150 | vis = F.interpolate(vis, (height_input, width_input), mode='nearest') 151 | return vis 152 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import numpy as np 5 | import torch 6 | import torchvision.transforms as transforms 7 | import cv2 8 | from fast_slic import Slic 9 | class SalObjDataset(data.Dataset): 10 | def __init__(self, image_root,depth_root, gt_root, mask_root, gray_root,edge_root,trainsize): 11 | self.trainsize = trainsize 12 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 13 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.png') or f.endswith('.jpg')] 14 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 15 | or f.endswith('.png')] 16 | self.masks = [mask_root + f for f in os.listdir(mask_root) if f.endswith('.png')] 17 | self.grays = [gray_root + f for f in os.listdir(gray_root) if f.endswith('.png')] 18 | self.edges = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.png')] 19 | self.images = sorted(self.images) 20 | self.gts = sorted(self.gts) 21 | self.masks = sorted(self.masks) 22 | self.grays = sorted(self.grays) 23 | self.edges = sorted(self.edges) 24 | self.filter_files() 25 | self.size = len(self.images) 26 | self.img_transform = transforms.Compose([ 27 | transforms.Resize((self.trainsize, self.trainsize)), 28 | transforms.ToTensor(), 29 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 30 | self.depth_transform = transforms.Compose([ 31 | transforms.Resize((self.trainsize, self.trainsize)), 32 | transforms.ToTensor()]) 33 | self.gt_transform = transforms.Compose([ 34 | transforms.Resize((self.trainsize, self.trainsize)), 35 | transforms.ToTensor()]) 36 | self.mask_transform = transforms.Compose([ 37 | transforms.Resize((self.trainsize, self.trainsize)), 38 | transforms.ToTensor()]) 39 | self.gray_transform = transforms.Compose([ 40 | transforms.Resize((self.trainsize, self.trainsize)), 41 | transforms.ToTensor()]) 42 | self.edge_transform = transforms.Compose([ 43 | transforms.Resize((self.trainsize, self.trainsize)), 44 | transforms.ToTensor()]) 45 | 46 | def __getitem__(self, index): 47 | image = self.rgb_loader(self.images[index]) 48 | depth = self.rgb_loader(self.depths[index]) 49 | gt = self.binary_loader(self.gts[index]) 50 | mask = self.binary_loader(self.masks[index]) 51 | gray = self.binary_loader(self.grays[index]) 52 | edge = self.binary_loader(self.edges[index]) 53 | name = self.images[index].split('/')[-1] 54 | np_img = np.array(image) 55 | np_img = 
cv2.resize(np_img, dsize=(self.trainsize, self.trainsize), interpolation=cv2.INTER_LINEAR) 56 | 57 | np_depth = np.array(depth) 58 | np_depth = cv2.resize(np_depth, dsize=(self.trainsize, self.trainsize), interpolation=cv2.INTER_LINEAR) 59 | 60 | np_gt = np.array(gt) 61 | np_gt = cv2.resize(np_gt, dsize=(self.trainsize, self.trainsize), interpolation=cv2.INTER_LINEAR) / 255 62 | 63 | np_mask = np.array(mask) 64 | np_mask = cv2.resize(np_mask, dsize=(self.trainsize, self.trainsize), interpolation=cv2.INTER_LINEAR) / 255 65 | 66 | slic = Slic(num_components=40, compactness=10) 67 | SS_map = slic.iterate(np_img) 68 | SS_map_depth = slic.iterate(np_depth) 69 | 70 | SS_map = SS_map + 1 71 | SS_map_depth = SS_map_depth + 1 72 | 73 | SS_maps_label = [] 74 | SS_maps = [] 75 | 76 | # SS_maps_label_mask = [] 77 | # SS_maps_mask = [] 78 | 79 | SS_maps_label_depth = [] 80 | SS_maps_depth = [] 81 | 82 | # SS_maps_label_mask_depth = [] 83 | # SS_maps_mask_depth = [] 84 | 85 | label_gt = np.zeros((1, self.trainsize, self.trainsize)) 86 | # label_mask = np.zeros((1, self.trainsize, self.trainsize)) 87 | 88 | label_gt_depth = np.zeros((1, self.trainsize, self.trainsize)) 89 | # label_mask_depth = np.zeros((1, self.trainsize, self.trainsize)) 90 | 91 | for i in range(1, 40+ 1): 92 | buffer = np.copy(SS_map) 93 | buffer[buffer != i] = 0 94 | buffer[buffer == i] = 1 95 | 96 | if np.sum(buffer) != 0: 97 | if np.sum(buffer * np_gt) > 1: 98 | label_gt = label_gt+buffer 99 | SS_maps_label.append(1) 100 | 101 | else: 102 | SS_maps_label.append(0) 103 | else: 104 | SS_maps_label.append(0) 105 | SS_maps.append(buffer) 106 | label_gt = torch.tensor(label_gt) 107 | label_gt = label_gt.to(torch.float32) 108 | 109 | 110 | 111 | for i in range(1, 40 + 1): 112 | buffer = np.copy(SS_map_depth) 113 | buffer[buffer != i] = 0 114 | buffer[buffer == i] = 1 115 | 116 | if np.sum(buffer) != 0: 117 | if np.sum(buffer * np_gt) > 1: 118 | label_gt_depth = label_gt_depth+buffer 119 | SS_maps_label_depth.append(1) 120 | 121 | else: 122 | SS_maps_label_depth.append(0) 123 | else: 124 | SS_maps_label_depth.append(0) 125 | SS_maps_depth.append(buffer) 126 | label_gt_depth = torch.tensor(label_gt_depth) 127 | label_gt_depth = label_gt_depth.to(torch.float32) 128 | 129 | 130 | 131 | image = self.img_transform(image) 132 | depth = self.depth_transform(depth) 133 | gt = self.gt_transform(gt) 134 | mask = self.mask_transform(mask) 135 | gray = self.gray_transform(gray) 136 | edge = self.edge_transform(edge) 137 | return image, depth, gt, mask, gray, edge, label_gt, label_gt_depth,name 138 | 139 | def filter_files(self): 140 | assert len(self.images) == len(self.gts) 141 | images = [] 142 | depths = [] 143 | gts = [] 144 | masks = [] 145 | grays = [] 146 | edges = [] 147 | for img_path,depth_path, gt_path, mask_path, gray_path,edge_path in zip(self.images, self.depths, self.gts, self.masks, self.grays,self.edges): 148 | img = Image.open(img_path) 149 | depth = Image.open(depth_path) 150 | gt = Image.open(gt_path) 151 | mask = Image.open(mask_path) 152 | gray = Image.open(gray_path) 153 | edge = Image.open(edge_path) 154 | if img.size == gt.size: 155 | images.append(img_path) 156 | depths.append(depth_path) 157 | gts.append(gt_path) 158 | masks.append(mask_path) 159 | grays.append(gray_path) 160 | edges.append(edge_path) 161 | self.images = images 162 | self.depths = depths 163 | self.gts = gts 164 | self.masks = masks 165 | self.grays = grays 166 | self.edges = edges 167 | 168 | def rgb_loader(self, path): 169 | with open(path, 'rb') as f: 
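            # PIL's Image.open is lazy; convert('RGB') below forces the pixel data to be decoded while the file handle is still open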
170 | img = Image.open(f) 171 | return img.convert('RGB') 172 | 173 | def binary_loader(self, path): 174 | with open(path, 'rb') as f: 175 | img = Image.open(f) 176 | # return img.convert('1') 177 | return img.convert('L') 178 | 179 | def resize(self, img, gt): 180 | assert img.size == gt.size 181 | w, h = img.size 182 | if h < self.trainsize or w < self.trainsize: 183 | h = max(h, self.trainsize) 184 | w = max(w, self.trainsize) 185 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST) 186 | else: 187 | return img, gt 188 | 189 | def __len__(self): 190 | return self.size 191 | 192 | 193 | def get_loader(image_root,depth_root, gt_root, mask_root, gray_root,edge_root, batchsize, trainsize, shuffle=True, num_workers=0, pin_memory=True): 194 | 195 | dataset = SalObjDataset(image_root, depth_root, gt_root, mask_root, gray_root, edge_root, trainsize) 196 | data_loader = data.DataLoader(dataset=dataset, 197 | batch_size=batchsize, 198 | shuffle=shuffle, 199 | num_workers=num_workers, 200 | pin_memory=pin_memory) 201 | return data_loader 202 | 203 | class test_dataset: 204 | def __init__(self, image_root, gt_root,depth_root, testsize): 205 | self.testsize = testsize 206 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 207 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 208 | or f.endswith('.png')] 209 | self.depths=[depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 210 | or f.endswith('.jpg')] 211 | self.images = sorted(self.images) 212 | self.gts = sorted(self.gts) 213 | self.depths=sorted(self.depths) 214 | self.transform = transforms.Compose([ 215 | transforms.Resize((self.testsize, self.testsize)), 216 | transforms.ToTensor(), 217 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 218 | self.gt_transform = transforms.ToTensor() 219 | # self.gt_transform = transforms.Compose([ 220 | # transforms.Resize((self.trainsize, self.trainsize)), 221 | # transforms.ToTensor()]) 222 | self.depths_transform = transforms.Compose([transforms.Resize((self.testsize, self.testsize)),transforms.ToTensor()]) 223 | self.size = len(self.images) 224 | self.index = 0 225 | 226 | def load_data(self): 227 | image = self.rgb_loader(self.images[self.index]) 228 | image = self.transform(image).unsqueeze(0) 229 | gt = self.binary_loader(self.gts[self.index]) 230 | depth=self.rgb_loader(self.depths[self.index]) 231 | depth=self.depths_transform(depth).unsqueeze(0) 232 | name = self.images[self.index].split('/')[-1] 233 | image_for_post=self.rgb_loader(self.images[self.index]) 234 | image_for_post=image_for_post.resize(gt.size) 235 | if name.endswith('.jpg'): 236 | name = name.split('.jpg')[0] + '.jpg' 237 | self.index += 1 238 | self.index = self.index % self.size 239 | return image, gt,depth, name,np.array(image_for_post) 240 | 241 | 242 | def rgb_loader(self, path): 243 | with open(path, 'rb') as f: 244 | img = Image.open(f) 245 | return img.convert('RGB') 246 | 247 | def binary_loader(self, path): 248 | with open(path, 'rb') as f: 249 | img = Image.open(f) 250 | return img.convert('L') 251 | 252 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | from PVT_Model.pvtmodel import PvtNet 5 | import numpy as np 6 | import pdb, os, argparse 7 | from datetime import datetime 8 | from 
data import get_loader, test_dataset 9 | from utils import clip_gradient, adjust_lr 10 | from pamr import BinaryPamr 11 | import os 12 | import logging 13 | from scipy import misc 14 | from fast_slic import Slic 15 | import smoothness 16 | from tools import * 17 | import imageio 18 | from lscloss import * 19 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--epoch', type=int, default=300, help='epoch number') 23 | parser.add_argument('--lr', type=float, default=5e-5, help='learning rate') 24 | parser.add_argument('--batchsize', type=int, default=16, help='training batch size') 25 | parser.add_argument('--trainsize', type=int, default=256, help='training dataset size') 26 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin') 27 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate') 28 | parser.add_argument('--decay_epoch', type=int, default=200, help='every n epochs decay learning rate') 29 | parser.add_argument('--sm_loss_weight', type=float, default=0.3, help='weight for smoothness loss') 30 | parser.add_argument('--edge_loss_weight', type=float, default=1.0, help='weight for edge loss') 31 | parser.add_argument('--save_path', type=str, default='', help='the path to save models and logs') 32 | parser.add_argument('--load', type=str, default='', help='train from checkpoints') 33 | opt = parser.parse_args() 34 | 35 | print('Learning Rate: {}'.format(opt.lr)) 36 | # build models 37 | model = PvtNet(opt) 38 | model.encoder_rgb.load_state_dict(torch.load(""), strict=True) 39 | model.encoder_depth.load_state_dict(torch.load(""), strict=True) 40 | # if(opt.load is not None): 41 | # model.pvtb2.init_weights(opt.load) 42 | model.cuda() 43 | params = model.parameters() 44 | optimizer = torch.optim.Adam(params, opt.lr) 45 | 46 | 47 | image_root = '' 48 | depth_root = '' 49 | gt_root = '' 50 | mask_root = '' 51 | grayimg_root = '' 52 | edge_root = '' 53 | test_image_root = '' 54 | test_gt_root ='' 55 | test_depth_root ='' 56 | train_loader = get_loader(image_root, depth_root, gt_root, mask_root, grayimg_root, edge_root, batchsize=opt.batchsize, trainsize=opt.trainsize) 57 | test_loader = test_dataset(test_image_root, test_gt_root, test_depth_root, opt.trainsize) 58 | total_step = len(train_loader) 59 | 60 | CE = torch.nn.BCELoss() 61 | smooth_loss = smoothness.smoothness_loss(size_average=True) 62 | 63 | best_mae = 1 64 | best_epoch = 0 65 | 66 | loss_lsc = LocalSaliencyCoherence().cuda() 67 | loss_lsc_kernels_desc_defaults = [{"weight": 1, "xy": 6, "rgb": 0.1}] 68 | 69 | loss_lsc_radius = 5 70 | save_path = opt.save_path 71 | 72 | logging.basicConfig(filename=save_path + 'log.log', format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', 73 | level=logging.INFO, filemode='a', datefmt='%Y-%m-%d %I:%M:%S %p') 74 | logging.info("scribbleNet-Train") 75 | logging.info("Config") 76 | logging.info( 77 | 'epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};load:{};save_path:{};decay_epoch:{}'.format( 78 | opt.epoch, opt.lr, opt.batchsize, opt.trainsize, opt.clip, opt.decay_rate, opt.load, save_path, 79 | opt.decay_epoch)) 80 | 81 | def structure_loss(pred, mask): 82 | weit = 1+5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)-mask) 83 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none') 84 | wbce = (weit*wbce).sum(dim=(2, 3))/weit.sum(dim=(2, 3)) 85 | 86 | pred = torch.sigmoid(pred) 87 | inter = 
((pred*mask)*weit).sum(dim=(2,3)) 88 | union = ((pred+mask)*weit).sum(dim=(2,3)) 89 | wiou = 1-(inter+1)/(union-inter+1) 90 | return (wbce+wiou).mean() 91 | 92 | # def visualize_prediction1(pred,name): 93 | # for kk in range(pred.shape[0]): 94 | # pred_edge_kk = pred[kk, :, :, :] 95 | # pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze() 96 | # pred_edge_kk = (pred_edge_kk - pred_edge_kk.min()) / (pred_edge_kk.max() - pred_edge_kk.min() + 1e-8) 97 | # pred_edge_kk *= 255.0 98 | # pred_edge_kk = pred_edge_kk.astype(np.uint8) 99 | # save_path = './label_gt/' 100 | # if not os.path.exists(save_path): 101 | # os.makedirs(save_path) 102 | # # name = '{:02d}_sal1.png'.format(kk) 103 | # imageio.imsave(save_path + name[kk], pred_edge_kk) 104 | 105 | 106 | def run_pamr(img, sal): 107 | lbl_self = BinaryPamr(img, sal.clone().detach(), binary=0.4) 108 | return lbl_self 109 | 110 | 111 | def train(train_loader, model, optimizer, epoch): 112 | # global step 113 | model.train() 114 | loss_all = 0 115 | epoch_step = 0 116 | for i, pack in enumerate(train_loader, start=1): 117 | optimizer.zero_grad() 118 | images, depths, gts, masks, grays, edges, label_gt, label_gt_depth, name = pack 119 | images = Variable(images) 120 | depths = Variable(depths) 121 | gts = Variable(gts) 122 | masks = Variable(masks) 123 | grays = Variable(grays) 124 | # edges = Variable(edges) 125 | label_gt = Variable(label_gt) 126 | label_gt_depth = Variable(label_gt_depth) 127 | images = images.cuda() 128 | # depths = depths.repeat(1, 3, 1, 1, ).cuda() 129 | depths = depths.cuda() 130 | gts = gts.cuda() 131 | masks = masks.cuda() 132 | grays = grays.cuda() 133 | # edges = edges.cuda() 134 | label_gt = label_gt.cuda() 135 | label_gt_depth = label_gt_depth.cuda() 136 | 137 | 138 | img_size = images.size(2) * images.size(3) * images.size(0) 139 | ratio = img_size / torch.sum(masks) 140 | 141 | result_final, mask4, mask3, mask2, sal1, sal2 = model(images, depths) 142 | 143 | # BCEloss for the 1st DF 144 | sal1_loss = CE(sal1, label_gt) 145 | 146 | # BCEloss for the 2nd DF 147 | sal2_loss = CE(sal2, label_gt_depth) 148 | 149 | # visualize_prediction1(sal1, name) 150 | # The self-supervision term between 1st DF and 2nd DF 151 | 152 | 153 | # Guidance loss for the final saliency decoder 154 | lbl_tea = run_pamr(images, (sal1 + sal2) / 2) 155 | # visualize_prediction1(lbl_tea, name) 156 | loss_reult_final = structure_loss(torch.sigmoid(result_final), lbl_tea) 157 | loss_reult_mask4 = structure_loss(torch.sigmoid(mask4), lbl_tea) 158 | loss_reult_mask3 = structure_loss(torch.sigmoid(mask3), lbl_tea) 159 | loss_reult_mask2 = structure_loss(torch.sigmoid(mask2), lbl_tea) 160 | 161 | image_scale = F.interpolate(images, scale_factor=0.25, mode='bilinear', align_corners=True) 162 | depth_scale = F.interpolate(depths, scale_factor=0.25, mode='bilinear', align_corners=True) 163 | # 164 | result_final_scale, mask4_s, mask3_s, mask2_s, sal1_s, sal2_s = model(image_scale, depth_scale) 165 | result_out_scale = F.interpolate(torch.sigmoid(result_final), scale_factor=0.25, mode='bilinear', align_corners=True) 166 | loss_ssc = SaliencyStructureConsistency(torch.sigmoid(result_final_scale), result_out_scale, 0.85) 167 | 168 | images_ = F.interpolate(images, scale_factor=0.25, mode="bilinear", align_corners=True) 169 | sample_rgb = {'rgb': images_} 170 | 171 | # 172 | final_prob = torch.sigmoid(result_final) 173 | final_prob = final_prob * masks 174 | smoothLoss_cur_final = opt.sm_loss_weight * smooth_loss(torch.sigmoid(result_final), grays) 175 | 
sal_loss_final = ratio * CE(final_prob, gts * masks) + smoothLoss_cur_final 176 | 177 | result_final_ = F.interpolate(torch.sigmoid(result_final), scale_factor=0.25, mode="bilinear", align_corners=True) 178 | lossfinal_lsc = loss_lsc(result_final_, loss_lsc_kernels_desc_defaults, loss_lsc_radius, sample_rgb, images_.shape[2],images_.shape[3])['loss'] 179 | lossfinal = sal_loss_final + lossfinal_lsc + loss_ssc + loss_reult_final 180 | 181 | mask4_prob = torch.sigmoid(mask4) 182 | mask4_prob = mask4_prob * masks 183 | smoothLoss_cur_mask4 = opt.sm_loss_weight * smooth_loss(torch.sigmoid(mask4), grays) 184 | sal_loss_mask4 = ratio * CE(mask4_prob, gts * masks) + smoothLoss_cur_mask4 185 | 186 | 187 | mask4_ = F.interpolate(torch.sigmoid(mask4), scale_factor=0.25, mode="bilinear", align_corners=True) 188 | lossmask4_lsc = loss_lsc(mask4_, loss_lsc_kernels_desc_defaults, loss_lsc_radius, sample_rgb, images_.shape[2],images_.shape[3])['loss'] 189 | lossmask4 = sal_loss_mask4 + lossmask4_lsc +loss_reult_mask4 190 | 191 | 192 | mask3_prob = torch.sigmoid(mask3) 193 | mask3_prob = mask3_prob * masks 194 | smoothLoss_cur_mask3 = opt.sm_loss_weight * smooth_loss(torch.sigmoid(mask3), grays) 195 | sal_loss_mask3 = ratio * CE(mask3_prob, gts * masks) + smoothLoss_cur_mask3 196 | 197 | 198 | mask3_ = F.interpolate(torch.sigmoid(mask3), scale_factor=0.25, mode="bilinear", align_corners=True) 199 | lossmask3_lsc = loss_lsc(mask3_, loss_lsc_kernels_desc_defaults, loss_lsc_radius, sample_rgb, images_.shape[2],images_.shape[3])['loss'] 200 | lossmask3 = sal_loss_mask3 + lossmask3_lsc +loss_reult_mask3 201 | 202 | mask2_prob = torch.sigmoid(mask2) 203 | mask2_prob = mask2_prob * masks 204 | smoothLoss_cur_mask2 = opt.sm_loss_weight * smooth_loss(torch.sigmoid(mask2), grays) 205 | sal_loss_mask2 = ratio * CE(mask2_prob, gts * masks) + smoothLoss_cur_mask2 206 | 207 | 208 | mask2_ = F.interpolate(torch.sigmoid(mask2), scale_factor=0.25, mode="bilinear", align_corners=True) 209 | lossmask2_lsc = loss_lsc(mask2_, loss_lsc_kernels_desc_defaults, loss_lsc_radius, sample_rgb, images_.shape[2], images_.shape[3])['loss'] 210 | lossmask2 = sal_loss_mask2 + lossmask2_lsc +loss_reult_mask2 211 | loss = lossfinal * 1 + lossmask2 * 0.8 + lossmask3 * 0.6 + lossmask4 * 0.4 + sal1_loss + sal2_loss 212 | 213 | loss.backward() 214 | 215 | clip_gradient(optimizer, opt.clip) 216 | optimizer.step() 217 | # if i % 10 == 0 or i == total_step: 218 | # print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], sal1_loss: {:0.4f}, loss: {:0.4f}, sal2_loss: {:0.4f}'. 219 | # format(datetime.now(), epoch, opt.epoch, i, total_step, loss1.data, loss.data, loss2.data)) 220 | if i % 100 == 0 or i == total_step or i == 1: 221 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Lossfinal: {:.4f}'. 222 | format(datetime.now(), epoch, opt.epoch, i, total_step, loss.data)) 223 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Lossfinal: {:.4f}'. 
224 | format(epoch, opt.epoch, i, total_step, loss.data)) 225 | 226 | if not os.path.exists(save_path): 227 | os.makedirs(save_path) 228 | if epoch % 30 == 0: 229 | torch.save(model.state_dict(), save_path + 'scribble' + '_%d' % epoch + '.pth') 230 | 231 | def test(test_loader, model, epoch, save_path): 232 | global best_mae, best_epoch 233 | # 神经网络沿用batch normalization的值,并不使用drop out 234 | model.eval() 235 | with torch.no_grad(): 236 | mae_sum = 0 237 | for i in range(test_loader.size): 238 | image, gt, depth, name, img_for_post = test_loader.load_data() 239 | gt = np.asarray(gt, np.float32) 240 | gt /= (gt.max() + 1e-8) 241 | image = image.cuda() 242 | # depth = depth.repeat(1, 3, 1, 1, ).cuda() 243 | depth = depth.cuda() 244 | res, _, _, _ = model(image, depth) 245 | res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False) 246 | res = res.sigmoid().data.cpu().numpy().squeeze() 247 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 248 | mae_sum += np.sum(np.abs(res - gt)) * 1.0 / (gt.shape[0] * gt.shape[1]) 249 | mae = mae_sum / test_loader.size 250 | # writer.add_scalar('MAE', torch.tensor(mae), global_step=epoch) 251 | print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch, mae, best_mae, best_epoch)) 252 | if epoch == 1: 253 | best_mae = mae 254 | else: 255 | if mae < best_mae: 256 | best_mae = mae 257 | best_epoch = epoch 258 | torch.save(model.state_dict(), save_path + 'scribble_epoch_best.pth') 259 | print('best epoch:{}'.format(epoch)) 260 | logging.info('#TEST#:Epoch:{} MAE:{} bestEpoch:{} bestMAE:{}'.format(epoch, mae, best_epoch, best_mae)) 261 | 262 | print("Starting!") 263 | for epoch in range(1, opt.epoch+1): 264 | adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch) 265 | train(train_loader, model, optimizer, epoch) 266 | test(test_loader, model, epoch, save_path) 267 | -------------------------------------------------------------------------------- /PVT_Model/pvtmodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | # from torchsummary import summary 5 | from .pvtv2 import pvt_v2_b2 6 | import numpy as np 7 | import cv2 8 | 9 | 10 | class feature_fuse(nn.Module): 11 | def __init__(self, in_channel=128, out_channel=128): 12 | super(feature_fuse, self).__init__() 13 | self.dim = in_channel 14 | self.out_dim = out_channel 15 | self.fuseconv = nn.Sequential(nn.Conv2d(2 * self.dim, self.out_dim, 1, 1, 0, bias=False), 16 | nn.BatchNorm2d(self.out_dim), 17 | nn.ReLU(True)) 18 | self.conv = nn.Sequential(nn.Conv2d(self.out_dim, self.out_dim, 3, 1, 1, bias=False), 19 | nn.BatchNorm2d(self.out_dim), 20 | nn.ReLU(True)) 21 | 22 | def forward(self, Ri, Di): 23 | assert Ri.ndim == 4 24 | RDi = torch.cat((Ri, Di), dim=1) 25 | RDi = self.fuseconv(RDi) 26 | RDi = self.conv(RDi) 27 | return RDi 28 | class CALayer(nn.Module): 29 | def __init__(self, channel, reduction=16): 30 | super(CALayer, self).__init__() 31 | # global average pooling: feature --> point 32 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 33 | # feature channel downscale and upscale --> channel weight 34 | self.conv_du = nn.Sequential( 35 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 38 | nn.Sigmoid() 39 | ) 40 | 41 | def forward(self, x): 42 | y = self.avg_pool(x) 43 | y = self.conv_du(y) 44 | return x * y 45 | 46 | ## Residual 
Channel Attention Block (RCAB) 47 | class RCAB(nn.Module): 48 | def __init__( 49 | self, n_feat, kernel_size=3, reduction=16, 50 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 51 | 52 | super(RCAB, self).__init__() 53 | modules_body = [] 54 | for i in range(2): 55 | modules_body.append(self.default_conv(n_feat, n_feat, kernel_size, bias=bias)) 56 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 57 | if i == 0: modules_body.append(act) 58 | modules_body.append(CALayer(n_feat, reduction)) 59 | self.body = nn.Sequential(*modules_body) 60 | self.res_scale = res_scale 61 | 62 | def default_conv(self, in_channels, out_channels, kernel_size, bias=True): 63 | return nn.Conv2d(in_channels, out_channels, kernel_size,padding=(kernel_size // 2), bias=bias) 64 | 65 | def forward(self, x): 66 | res = self.body(x) 67 | #res = self.body(x).mul(self.res_scale) 68 | res += x 69 | return res 70 | 71 | class _AtrousSpatialPyramidPoolingModule(nn.Module): 72 | ''' 73 | operations performed: 74 | 1x1 x depth 75 | 3x3 x depth dilation 6 76 | 3x3 x depth dilation 12 77 | 3x3 x depth dilation 18 78 | image pooling 79 | concatenate all together 80 | Final 1x1 conv 81 | ''' 82 | 83 | def __init__(self, in_dim, reduction_dim=256, output_stride=16, rates=[6, 12, 18]): 84 | super(_AtrousSpatialPyramidPoolingModule, self).__init__() 85 | 86 | # Check if we are using distributed BN and use the nn from encoding.nn 87 | # library rather than using standard pytorch.nn 88 | 89 | if output_stride == 8: 90 | rates = [2 * r for r in rates] 91 | elif output_stride == 16: 92 | pass 93 | else: 94 | raise 'output stride of {} not supported'.format(output_stride) 95 | 96 | self.features = [] 97 | # 1x1 98 | self.features.append( 99 | nn.Sequential(nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), 100 | nn.ReLU(inplace=True))) 101 | # other rates 102 | for r in rates: 103 | self.features.append(nn.Sequential( 104 | nn.Conv2d(in_dim, reduction_dim, kernel_size=3, 105 | dilation=r, padding=r, bias=False), 106 | nn.ReLU(inplace=True) 107 | )) 108 | self.features = torch.nn.ModuleList(self.features) 109 | 110 | # img level features 111 | self.img_pooling = nn.AdaptiveAvgPool2d(1) 112 | self.img_conv = nn.Sequential( 113 | nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), 114 | nn.ReLU(inplace=True)) 115 | self.edge_conv = nn.Sequential( 116 | nn.Conv2d(1, reduction_dim, kernel_size=1, bias=False), 117 | nn.ReLU(inplace=True)) 118 | 119 | def forward(self, x): 120 | x_size = x.size() 121 | 122 | img_features = self.img_pooling(x) 123 | img_features = self.img_conv(img_features) 124 | img_features = F.interpolate(img_features, x_size[2:], 125 | mode='bilinear', align_corners=True) 126 | out = img_features 127 | 128 | # edge_features = F.interpolate(edge, x_size[2:], 129 | # mode='bilinear', align_corners=True) 130 | # edge_features = self.edge_conv(edge_features) 131 | # out = torch.cat((out, edge_features), 1) 132 | 133 | for f in self.features: 134 | y = f(x) 135 | out = torch.cat((out, y), 1) 136 | return out 137 | 138 | 139 | class Edge_Module(nn.Module): 140 | 141 | def __init__(self, in_fea=[128, 320, 512], mid_fea=32): 142 | super(Edge_Module, self).__init__() 143 | self.relu = nn.ReLU(inplace=True) 144 | self.conv2 = nn.Conv2d(in_fea[0], mid_fea, 1) 145 | self.conv4 = nn.Conv2d(in_fea[1], mid_fea, 1) 146 | self.conv5 = nn.Conv2d(in_fea[2], mid_fea, 1) 147 | self.conv5_2 = nn.Conv2d(mid_fea, mid_fea, 3, padding=1) 148 | self.conv5_4 = nn.Conv2d(mid_fea, mid_fea, 3, padding=1) 149 | self.conv5_5 = 
nn.Conv2d(mid_fea, mid_fea, 3, padding=1) 150 | 151 | self.classifer = nn.Conv2d(mid_fea * 3, 1, kernel_size=3, padding=1) 152 | self.rcab = RCAB(mid_fea * 3) 153 | 154 | def forward(self,input, x2, x4, x5): 155 | _, _, h, w = input.size() 156 | edge2_fea = self.relu(self.conv2(x2)) 157 | edge2 = self.relu(self.conv5_2(edge2_fea)) 158 | edge4_fea = self.relu(self.conv4(x4)) 159 | edge4 = self.relu(self.conv5_4(edge4_fea)) 160 | edge5_fea = self.relu(self.conv5(x5)) 161 | edge5 = self.relu(self.conv5_5(edge5_fea)) 162 | 163 | edge2 = F.interpolate(edge2, size=(h, w), mode='bilinear', align_corners=True) 164 | edge4 = F.interpolate(edge4, size=(h, w), mode='bilinear', align_corners=True) 165 | edge5 = F.interpolate(edge5, size=(h, w), mode='bilinear', align_corners=True) 166 | 167 | edge = torch.cat([edge2, edge4, edge5], dim=1) 168 | edge = self.rcab(edge) 169 | edge = self.classifer(edge) 170 | return edge 171 | 172 | class Decoder(nn.Module): 173 | def __init__(self, dim=128): 174 | super(Decoder, self).__init__() 175 | self.dim = dim 176 | self.out_dim = dim 177 | self.fuse1 = feature_fuse(in_channel=64, out_channel=128) 178 | self.fuse2 = feature_fuse(in_channel=128, out_channel=128) 179 | self.fuse3 = feature_fuse(in_channel=320, out_channel=128) 180 | self.fuse4 = feature_fuse(in_channel=320, out_channel=128) 181 | 182 | self.up2 = nn.Upsample(scale_factor=2, mode="bilinear") 183 | self.up4 = nn.Upsample(scale_factor=4, mode="bilinear") 184 | 185 | self.Conv43 = nn.Sequential(nn.Conv2d(2 * self.out_dim, self.out_dim, 1, 1, 0, bias=False), 186 | nn.BatchNorm2d(self.out_dim), 187 | nn.ReLU(True), nn.Conv2d(self.out_dim, self.out_dim, 3, 1, 1, bias=False), 188 | nn.BatchNorm2d(self.out_dim), 189 | nn.ReLU(True)) 190 | 191 | self.Conv432 = nn.Sequential(nn.Conv2d(2 * self.out_dim, self.out_dim, 1, 1, 0, bias=False), 192 | nn.BatchNorm2d(self.out_dim), 193 | nn.ReLU(True), nn.Conv2d(self.out_dim, self.out_dim, 3, 1, 1, bias=False), 194 | nn.BatchNorm2d(self.out_dim), 195 | nn.ReLU(True)) 196 | self.Conv4321 = nn.Sequential(nn.Conv2d(2 * self.out_dim, self.out_dim, 1, 1, 0, bias=False), 197 | nn.BatchNorm2d(self.out_dim), 198 | nn.ReLU(True), nn.Conv2d(self.out_dim, self.out_dim, 3, 1, 1, bias=False), 199 | nn.BatchNorm2d(self.out_dim), 200 | nn.ReLU(True)) 201 | 202 | self.sal_pred = nn.Sequential(nn.Conv2d(self.out_dim, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64), 203 | nn.ReLU(True), 204 | nn.Conv2d(64, 1, 3, 1, 1, bias=False)) 205 | 206 | self.linear4 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1) 207 | self.linear3 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1) 208 | self.linear2 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1) 209 | 210 | self.aspp_rgb = _AtrousSpatialPyramidPoolingModule(512, 320, 211 | output_stride=16) 212 | self.aspp_depth = _AtrousSpatialPyramidPoolingModule(512, 320, 213 | output_stride=16) 214 | self.after_aspp_conv_rgb = nn.Conv2d(320 * 5, 320, kernel_size=1, bias=False) 215 | self.after_aspp_conv_depth = nn.Conv2d(320 * 5, 320, kernel_size=1, bias=False) 216 | 217 | self.edge_conv = nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False) 218 | self.rcab_sal_edge = RCAB(32 * 2) 219 | self.fused_edge_sal = nn.Conv2d(64, 1, kernel_size=3, padding=1, bias=False) 220 | self.sal_conv = nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False) 221 | self.relu = nn.ReLU(True) 222 | 223 | 224 | 225 | def forward(self, x, feature_list, feature_list_depth): 226 | R1, R2, R3, R4 = feature_list[0], feature_list[1], feature_list[2], feature_list[3] 
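        # Features from the RGB encoder (the thermal stream D1-D4 below has matching shapes). Assuming the default
        # trainsize of 256, pvt_v2_b2 yields roughly 64@64x64, 128@32x32, 320@16x16 and 512@8x8 (channels@spatial);
        # R4/D4 are then reduced to 320 channels by the ASPP modules below.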
227 | D1, D2, D3, D4 = feature_list_depth[0], feature_list_depth[1], feature_list_depth[2], feature_list_depth[3] 228 | 229 | R4 = self.aspp_rgb(R4) 230 | D4 = self.aspp_depth(D4) 231 | R4 = self.after_aspp_conv_rgb(R4) 232 | D4 = self.after_aspp_conv_depth(D4) 233 | 234 | RD1 = self.fuse1(R1, D1) 235 | RD2 = self.fuse2(R2, D2) 236 | RD3 = self.fuse3(R3, D3) 237 | RD4 = self.fuse4(R4, D4) 238 | 239 | RD43 = self.up2(RD4) 240 | RD43 = torch.cat((RD43, RD3), dim=1) 241 | RD43 = self.Conv43(RD43) 242 | 243 | RD432 = self.up2(RD43) 244 | RD432 = torch.cat((RD432, RD2), dim=1) 245 | RD432 = self.Conv432(RD432) 246 | 247 | RD4321 = self.up2(RD432) 248 | RD4321 = torch.cat((RD4321, RD1), dim=1) 249 | RD4321 = self.Conv4321(RD4321) # [B, 128, 56, 56] 250 | 251 | sal_map = self.sal_pred(RD4321) 252 | sal_out = self.up4(sal_map) 253 | 254 | mask4 = F.interpolate(self.linear4(RD4), size=x.size()[2:], mode='bilinear', align_corners=False) 255 | mask3 = F.interpolate(self.linear3(RD43), size=x.size()[2:], mode='bilinear', align_corners=False) 256 | mask2 = F.interpolate(self.linear4(RD432), size=x.size()[2:], mode='bilinear', align_corners=False) 257 | 258 | 259 | return sal_out, mask4, mask3, mask2 260 | 261 | 262 | class PvtNet(nn.Module): 263 | def __init__(self, args): 264 | super().__init__() 265 | key = [] 266 | self.encoder_rgb = pvt_v2_b2() 267 | self.encoder_depth = pvt_v2_b2() 268 | self.decoder = Decoder(dim=128) 269 | self.edge_layer = Edge_Module() 270 | self.fuse_canny_edge = nn.Conv2d(2, 1, kernel_size=1, padding=0, bias=False) 271 | 272 | # ------------------------ rgb prediction module ---------------------------- # 273 | self.conv1_1 = nn.Conv2d(512, 320, kernel_size=(3, 3), padding=(1, 1)) 274 | self.conv1_3 = nn.Conv2d(320, 128, kernel_size=(3, 3), padding=(1, 1)) 275 | self.conv1_5 = nn.Conv2d(128, 64, kernel_size=(3, 3), padding=(1, 1)) 276 | self.conv1_7 = nn.Conv2d(64, 32, kernel_size=(3, 3), padding=(1, 1)) 277 | self.conv1_9 = nn.Conv2d(32, 32, kernel_size=(3, 3), padding=(1, 1)) 278 | self.conv1_11 = nn.Conv2d(32, 1, kernel_size=(3, 3), padding=(1, 1)) 279 | 280 | # ------------------------ t prediction module ---------------------------- # 281 | self.conv2_1 = nn.Conv2d(512, 320, kernel_size=(3, 3), padding=(1, 1)) 282 | self.conv2_3 = nn.Conv2d(320, 128, kernel_size=(3, 3), padding=(1, 1)) 283 | self.conv2_5 = nn.Conv2d(128, 64, kernel_size=(3, 3), padding=(1, 1)) 284 | self.conv2_7 = nn.Conv2d(64, 32, kernel_size=(3, 3), padding=(1, 1)) 285 | self.conv2_9 = nn.Conv2d(32, 32, kernel_size=(3, 3), padding=(1, 1)) 286 | self.conv2_11 = nn.Conv2d(32, 1, kernel_size=(3, 3), padding=(1, 1)) 287 | 288 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest') 289 | 290 | def forward(self, input_rgb, input_depth): 291 | # output of backbone 292 | # x_size = input_rgb.size() 293 | rgb_feats = self.encoder_rgb(input_rgb) 294 | depth_feats = self.encoder_depth(input_depth) 295 | if self.training is True: 296 | # ------------------------ rgb prediction module ---------------------------- # 297 | sal1 = self.conv1_11(self.upsample(self.conv1_9(self.upsample(self.conv1_7( 298 | self.upsample(self.conv1_5( 299 | self.upsample(self.conv1_3( 300 | self.upsample(self.conv1_1(rgb_feats[3]))))))))))) 301 | 302 | # ------------------------ t prediction module ---------------------------- # 303 | sal2 = self.conv2_11(self.upsample(self.conv2_9(self.upsample(self.conv2_7( 304 | self.upsample(self.conv2_5( 305 | self.upsample(self.conv2_3( 306 | 
self.upsample(self.conv2_1(depth_feats[3]))))))))))) 307 | 308 | 309 | result_final, mask4, mask3, mask2 = self.decoder(input_rgb, rgb_feats, depth_feats) 310 | 311 | return result_final, mask4, mask3, mask2, torch.sigmoid(sal1), torch.sigmoid(sal2) 312 | else: 313 | result_final, mask4, mask3, mask2 = self.decoder(input_rgb, rgb_feats, depth_feats) 314 | return result_final, mask4, mask3, mask2 315 | 316 | -------------------------------------------------------------------------------- /PVT_Model/pvtv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | from timm.models.vision_transformer import _cfg 9 | import math 10 | 11 | class Mlp(nn.Module): 12 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): 13 | super().__init__() 14 | out_features = out_features or in_features 15 | hidden_features = hidden_features or in_features 16 | self.fc1 = nn.Linear(in_features, hidden_features) 17 | self.dwconv = DWConv(hidden_features) 18 | self.act = act_layer() 19 | self.fc2 = nn.Linear(hidden_features, out_features) 20 | self.drop = nn.Dropout(drop) 21 | self.linear = linear 22 | if self.linear: 23 | self.relu = nn.ReLU(inplace=True) 24 | self.apply(self._init_weights) 25 | 26 | def _init_weights(self, m): 27 | if isinstance(m, nn.Linear): 28 | trunc_normal_(m.weight, std=.02) 29 | if isinstance(m, nn.Linear) and m.bias is not None: 30 | nn.init.constant_(m.bias, 0) 31 | elif isinstance(m, nn.LayerNorm): 32 | nn.init.constant_(m.bias, 0) 33 | nn.init.constant_(m.weight, 1.0) 34 | elif isinstance(m, nn.Conv2d): 35 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 36 | fan_out //= m.groups 37 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 38 | if m.bias is not None: 39 | m.bias.data.zero_() 40 | 41 | def forward(self, x, H, W): 42 | x = self.fc1(x) 43 | if self.linear: 44 | x = self.relu(x) 45 | x = self.dwconv(x, H, W) 46 | x = self.act(x) 47 | x = self.drop(x) 48 | x = self.fc2(x) 49 | x = self.drop(x) 50 | return x 51 | 52 | 53 | class Attention(nn.Module): 54 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False): 55 | super().__init__() 56 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
57 | 58 | self.dim = dim 59 | self.num_heads = num_heads 60 | head_dim = dim // num_heads 61 | self.scale = qk_scale or head_dim ** -0.5 62 | 63 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 64 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 65 | self.attn_drop = nn.Dropout(attn_drop) 66 | self.proj = nn.Linear(dim, dim) 67 | self.proj_drop = nn.Dropout(proj_drop) 68 | 69 | self.linear = linear 70 | self.sr_ratio = sr_ratio 71 | if not linear: 72 | if sr_ratio > 1: 73 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 74 | self.norm = nn.LayerNorm(dim) 75 | else: 76 | self.pool = nn.AdaptiveAvgPool2d(7) 77 | self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) 78 | self.norm = nn.LayerNorm(dim) 79 | self.act = nn.GELU() 80 | self.apply(self._init_weights) 81 | 82 | def _init_weights(self, m): 83 | if isinstance(m, nn.Linear): 84 | trunc_normal_(m.weight, std=.02) 85 | if isinstance(m, nn.Linear) and m.bias is not None: 86 | nn.init.constant_(m.bias, 0) 87 | elif isinstance(m, nn.LayerNorm): 88 | nn.init.constant_(m.bias, 0) 89 | nn.init.constant_(m.weight, 1.0) 90 | elif isinstance(m, nn.Conv2d): 91 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 92 | fan_out //= m.groups 93 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 94 | if m.bias is not None: 95 | m.bias.data.zero_() 96 | 97 | def forward(self, x, H, W): 98 | B, N, C = x.shape 99 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 100 | 101 | if not self.linear: 102 | if self.sr_ratio > 1: 103 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 104 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 105 | x_ = self.norm(x_) 106 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 107 | else: 108 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 109 | else: 110 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 111 | x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) 112 | x_ = self.norm(x_) 113 | x_ = self.act(x_) 114 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 115 | k, v = kv[0], kv[1] 116 | 117 | attn = (q @ k.transpose(-2, -1)) * self.scale 118 | attn = attn.softmax(dim=-1) 119 | attn = self.attn_drop(attn) 120 | 121 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 122 | x = self.proj(x) 123 | x = self.proj_drop(x) 124 | 125 | return x 126 | 127 | 128 | class Block(nn.Module): 129 | 130 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 131 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False): 132 | super().__init__() 133 | self.norm1 = norm_layer(dim) 134 | self.attn = Attention( 135 | dim, 136 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 137 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear) 138 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 139 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 140 | self.norm2 = norm_layer(dim) 141 | mlp_hidden_dim = int(dim * mlp_ratio) 142 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear) 143 | 144 | self.apply(self._init_weights) 145 | 146 | def _init_weights(self, m): 147 | if isinstance(m, nn.Linear): 148 | trunc_normal_(m.weight, std=.02) 149 | if isinstance(m, nn.Linear) and m.bias is not None: 150 | nn.init.constant_(m.bias, 0) 151 | elif isinstance(m, nn.LayerNorm): 152 | nn.init.constant_(m.bias, 0) 153 | nn.init.constant_(m.weight, 1.0) 154 | elif isinstance(m, nn.Conv2d): 155 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 156 | fan_out //= m.groups 157 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 158 | if m.bias is not None: 159 | m.bias.data.zero_() 160 | 161 | def forward(self, x, H, W): 162 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 163 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 164 | 165 | return x 166 | 167 | 168 | class OverlapPatchEmbed(nn.Module): 169 | """ Image to Patch Embedding 170 | """ 171 | 172 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 173 | super().__init__() 174 | img_size = to_2tuple(img_size) 175 | patch_size = to_2tuple(patch_size) 176 | 177 | self.img_size = img_size 178 | self.patch_size = patch_size 179 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 180 | self.num_patches = self.H * self.W 181 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 182 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 183 | self.norm = nn.LayerNorm(embed_dim) 184 | 185 | self.apply(self._init_weights) 186 | 187 | def _init_weights(self, m): 188 | if isinstance(m, nn.Linear): 189 | trunc_normal_(m.weight, std=.02) 190 | if isinstance(m, nn.Linear) and m.bias is not None: 191 | nn.init.constant_(m.bias, 0) 192 | elif isinstance(m, nn.LayerNorm): 193 | nn.init.constant_(m.bias, 0) 194 | nn.init.constant_(m.weight, 1.0) 195 | elif isinstance(m, nn.Conv2d): 196 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 197 | fan_out //= m.groups 198 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 199 | if m.bias is not None: 200 | m.bias.data.zero_() 201 | 202 | def forward(self, x): 203 | x = self.proj(x) 204 | _, _, H, W = x.shape 205 | x = x.flatten(2).transpose(1, 2) 206 | x = self.norm(x) 207 | 208 | return x, H, W 209 | 210 | class OverlapPatchEmbed1(nn.Module): 211 | """ Image to Patch Embedding 212 | """ 213 | 214 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 215 | super().__init__() 216 | img_size = to_2tuple(img_size) 217 | patch_size = to_2tuple(patch_size) 218 | 219 | self.img_size = img_size 220 | self.patch_size = patch_size 221 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 222 | self.num_patches = self.H * self.W 223 | self.proj1 = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 224 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 225 | self.norm = nn.LayerNorm(embed_dim) 226 | 227 | self.apply(self._init_weights) 228 | 229 | def _init_weights(self, m): 230 | if isinstance(m, nn.Linear): 231 | trunc_normal_(m.weight, std=.02) 232 | if isinstance(m, nn.Linear) and m.bias is not None: 233 | nn.init.constant_(m.bias, 0) 234 | elif isinstance(m, nn.LayerNorm): 235 | nn.init.constant_(m.bias, 0) 236 | nn.init.constant_(m.weight, 1.0) 237 | elif isinstance(m, nn.Conv2d): 238 | 
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 239 | fan_out //= m.groups 240 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 241 | if m.bias is not None: 242 | m.bias.data.zero_() 243 | 244 | def forward(self, x): 245 | x = self.proj1(x) 246 | _, _, H, W = x.shape 247 | x = x.flatten(2).transpose(1, 2) 248 | x = self.norm(x) 249 | 250 | return x, H, W 251 | 252 | class PyramidVisionTransformerV2(nn.Module): 253 | def __init__(self, img_size=224, patch_size=16, in_chans=39, num_classes=1000, embed_dims=[64, 128, 256, 512], 254 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 255 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, depths=[3, 4, 6, 3], 256 | sr_ratios=[8, 4, 2, 1], num_stages=4, linear=False, pretrained=None): 257 | super().__init__() 258 | # self.num_classes = num_classes 259 | self.depths = depths 260 | self.num_stages = num_stages 261 | self.linear = linear 262 | 263 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 264 | cur = 0 265 | 266 | for i in range(num_stages): 267 | patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), 268 | patch_size=7 if i == 0 else 3, 269 | stride=4 if i == 0 else 2, 270 | in_chans=in_chans if i == 0 else embed_dims[i - 1], 271 | embed_dim=embed_dims[i]) 272 | 273 | block = nn.ModuleList([Block( 274 | dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, 275 | qk_scale=qk_scale, 276 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer, 277 | sr_ratio=sr_ratios[i], linear=linear) 278 | for j in range(depths[i])]) 279 | norm = norm_layer(embed_dims[i]) 280 | cur += depths[i] 281 | 282 | setattr(self, f"patch_embed{i + 1}", patch_embed) 283 | setattr(self, f"block{i + 1}", block) 284 | setattr(self, f"norm{i + 1}", norm) 285 | 286 | # classification head 287 | self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 288 | 289 | self.apply(self._init_weights) 290 | self.init_weights(pretrained) 291 | 292 | def _init_weights(self, m): 293 | if isinstance(m, nn.Linear): 294 | trunc_normal_(m.weight, std=.02) 295 | if isinstance(m, nn.Linear) and m.bias is not None: 296 | nn.init.constant_(m.bias, 0) 297 | elif isinstance(m, nn.LayerNorm): 298 | nn.init.constant_(m.bias, 0) 299 | nn.init.constant_(m.weight, 1.0) 300 | elif isinstance(m, nn.Conv2d): 301 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 302 | fan_out //= m.groups 303 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 304 | if m.bias is not None: 305 | m.bias.data.zero_() 306 | 307 | def init_weights(self, pretrained=None): 308 | if isinstance(pretrained, str): 309 | print('from {} load pretrained...'.format(pretrained)) 310 | self.load_state_dict(torch.load(pretrained), strict=False) 311 | 312 | 313 | def freeze_patch_emb(self): 314 | self.patch_embed1.requires_grad = False 315 | 316 | @torch.jit.ignore 317 | def no_weight_decay(self): 318 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better 319 | 320 | def get_classifier(self): 321 | return self.head 322 | 323 | def reset_classifier(self, num_classes, global_pool=''): 324 | self.num_classes = num_classes 325 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 326 | 327 | def forward_features(self, x): 328 | B = x.shape[0] 329 | outs = [] 330 | 331 | for 
i in range(self.num_stages): 332 | patch_embed = getattr(self, f"patch_embed{i + 1}") 333 | block = getattr(self, f"block{i + 1}") 334 | norm = getattr(self, f"norm{i + 1}") 335 | x, H, W = patch_embed(x) 336 | for blk in block: 337 | x = blk(x, H, W) 338 | x = norm(x) 339 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 340 | outs.append(x) 341 | 342 | return outs 343 | 344 | 345 | def forward(self, x): 346 | x = self.forward_features(x) 347 | # x = self.head(x) 348 | 349 | return x 350 | 351 | 352 | class DWConv(nn.Module): 353 | def __init__(self, dim=768): 354 | super(DWConv, self).__init__() 355 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 356 | 357 | def forward(self, x, H, W): 358 | B, N, C = x.shape 359 | x = x.transpose(1, 2).view(B, C, H, W) 360 | x = self.dwconv(x) 361 | x = x.flatten(2).transpose(1, 2) 362 | 363 | return x 364 | 365 | 366 | def _conv_filter(state_dict, patch_size=16): 367 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 368 | out_dict = {} 369 | for k, v in state_dict.items(): 370 | if 'patch_embed.proj.weight' in k: 371 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 372 | out_dict[k] = v 373 | 374 | return out_dict 375 | 376 | 377 | 378 | class pvt_v2_b0(PyramidVisionTransformerV2): 379 | def __init__(self, **kwargs): 380 | super(pvt_v2_b0, self).__init__( 381 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 382 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 383 | drop_rate=0.0, drop_path_rate=0.1, pretrained=kwargs['pretrained']) 384 | 385 | 386 | 387 | class pvt_v2_b1(PyramidVisionTransformerV2): 388 | def __init__(self, **kwargs): 389 | super(pvt_v2_b1, self).__init__( 390 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 391 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 392 | drop_rate=0.0, drop_path_rate=0.1, pretrained=kwargs['pretrained']) 393 | 394 | 395 | 396 | class pvt_v2_b2(PyramidVisionTransformerV2): 397 | def __init__(self, **kwargs): 398 | super(pvt_v2_b2, self).__init__( 399 | in_chans=3, patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 400 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 401 | drop_rate=0.0, drop_path_rate=0.1) 402 | 403 | 404 | 405 | class pvt_v2_b2_li(PyramidVisionTransformerV2): 406 | def __init__(self, **kwargs): 407 | super(pvt_v2_b2_li, self).__init__( 408 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 409 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 410 | drop_rate=0.0, drop_path_rate=0.1, linear=True) 411 | 412 | 413 | 414 | class pvt_v2_b3(PyramidVisionTransformerV2): 415 | def __init__(self, **kwargs): 416 | super(pvt_v2_b3, self).__init__( 417 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 418 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 419 | drop_rate=0.0, drop_path_rate=0.1, pretrained=kwargs['pretrained']) 420 | 421 | 422 | 423 | class pvt_v2_b4(PyramidVisionTransformerV2): 424 | def __init__(self, **kwargs): 425 | super(pvt_v2_b4, self).__init__( 426 | patch_size=4, embed_dims=[64, 128, 320, 512], 
num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 427 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 428 | drop_rate=0.0, drop_path_rate=0.1, pretrained=kwargs['pretrained']) 429 | 430 | 431 | 432 | class pvt_v2_b5(PyramidVisionTransformerV2): 433 | def __init__(self, **kwargs): 434 | super(pvt_v2_b5, self).__init__( 435 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 436 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 437 | drop_rate=0.0, drop_path_rate=0.1, pretrained=kwargs['pretrained']) --------------------------------------------------------------------------------
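
A minimal usage sketch of the `pvt_v2_b2` backbone defined above: it instantiates the encoder and prints the four-stage feature pyramid that `PvtNet` feeds to its `Decoder` as `feature_list` / `feature_list_depth`. This is only an illustration, assuming the repository root is on `PYTHONPATH`; the 224×224 dummy input and batch size are placeholders, not values taken from the training code.

```
import torch
from PVT_Model.pvtv2 import pvt_v2_b2

# Build the RGB encoder used by PvtNet (embed_dims = [64, 128, 320, 512]);
# the thermal branch uses an identical, separately instantiated encoder.
backbone = pvt_v2_b2()
backbone.eval()

# Illustrative dummy input (batch 1, 3 channels, 224x224).
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = backbone(x)  # list of 4 stage outputs at strides 4, 8, 16, 32

for i, f in enumerate(feats, start=1):
    print(f"stage {i}: {tuple(f.shape)}")
# Expected shapes for a 224x224 input:
# (1, 64, 56, 56), (1, 128, 28, 28), (1, 320, 14, 14), (1, 512, 7, 7)
```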